"""
|
||
main.py
|
||
智能体 Demo 程序入口(OpenAI Function Calling 驱动)
|
||
|
||
运行模式:
|
||
python main.py → 交互模式
|
||
python main.py demo → 演示模式
|
||
python main.py config → 打印当前配置
|
||
python main.py health → 检测 OpenAI API 连通性
|
||
LLM_API_KEY=sk-xxx python main.py → 指定 API Key
|
||
LLM_MODEL_NAME=gpt-4-turbo python main.py→ 指定模型
|
||
AGENT_CONFIG_PATH=my.yaml python main.py → 指定配置文件
|
||
"""
|
||
|
||
import sys
|
||
|
||
from client.agent_client import AgentClient, AgentResponse
|
||
from config.settings import settings
|
||
from llm.llm_engine import LLMEngine
|
||
from mcp.mcp_server import MCPServer
|
||
from memory.memory_store import MemoryStore
|
||
from tools.calculator import CalculatorTool
|
||
from tools.code_executor import CodeExecutorTool
|
||
from tools.file_reader import FileReaderTool
|
||
from tools.web_search import WebSearchTool
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("SYSTEM")

# Registry of every available tool class, keyed by the tool's public name.
# build_agent() registers each of these classes with the MCP server.
_ALL_TOOLS = {
    "calculator": CalculatorTool,
    "web_search": WebSearchTool,
    "file_reader": FileReaderTool,
    "code_executor": CodeExecutorTool,
}
|
||
|
||
|
||
# ── System assembly ────────────────────────────────────────────
def build_agent() -> AgentClient:
    """Factory: assemble the Agent system from ``settings``.

    Registers every tool class from ``_ALL_TOOLS`` with a fresh MCP
    server, then wires the LLM engine, the memory store and the agent
    client together and returns the client.
    """
    logger.info("🔧 开始组装 Agent 系统(OpenAI Function Calling 模式)...")
    logger.info(settings.display())

    server = MCPServer()
    for tool_class in _ALL_TOOLS.values():
        server.register_tool(tool_class)

    agent = AgentClient(
        llm=LLMEngine(),
        mcp_server=server,
        memory=MemoryStore(max_history=settings.memory.max_history),
    )

    logger.info(f"✅ Agent 组装完成,已注册工具: {server.list_tools()}")
    return agent
|
||
|
||
|
||
# ── Result printing ────────────────────────────────────────────
def print_response(response: AgentResponse) -> None:
    """Pretty-print an ``AgentResponse`` to stdout.

    Shows the user input, a per-step breakdown of the tool-call chain
    (when one exists) and the agent's final reply.

    Args:
        response: the agent's reply object; ``chain_result`` may be
            falsy when no tools were invoked.
    """
    print(f"\n{'═' * 62}")
    print(f"👤 用户: {response.user_input}")
    print(f"{'─' * 62}")

    if response.chain_result:
        cr = response.chain_result
        tag = "🔗 多步串行" if response.is_multi_step else "🔧 单步调用"
        status = "✅ 全部成功" if cr.success else f"⚠️ 步骤 {cr.failed_step} 失败"
        print(f"{tag} | {cr.completed_steps}/{cr.total_steps} 步 | {status}")
        print()
        for r in cr.step_results:
            icon = "✅" if r.success else "❌"
            print(f" {icon} Step {r.step_id} [{r.tool_name}]")
            if r.success:
                # Build the preview only for successful steps — a failed
                # step's output may be None, and calling .replace() on it
                # would raise. Append "..." only when text was truncated.
                flat = r.output.replace("\n", " ")
                preview = flat[:90] + ("..." if len(flat) > 90 else "")
                print(f" └─ {preview}")
            else:
                print(f" └─ 错误: {r.error}")
        print()

    print(f"🤖 Agent 回复:\n{response.final_reply}")
    print(f"{'═' * 62}\n")
|
||
|
||
|
||
# ── API health check ───────────────────────────────────────────
def run_health_check() -> None:
    """Probe OpenAI API connectivity and print a short status report.

    Prints provider/model/key/base-URL info, bails out early when no
    API key is configured, otherwise asks the provider for a health
    check and reports the outcome.
    """
    rule = "─" * 50
    key = settings.llm.api_key
    masked = "***" + key[-4:] if len(key) > 4 else "(未设置)"

    print(f"\n{rule}")
    print(" 🏥 OpenAI API 健康检测")
    print(rule)
    print(f" Provider : {settings.llm.provider}")
    print(f" Model : {settings.llm.model_name}")
    print(f" API Key : {masked}")
    print(f" Base URL : {settings.llm.api_base_url or 'https://api.openai.com/v1'}")
    print(rule)

    if not key:
        print(" ❌ API Key 未设置")
        print(" 💡 请设置环境变量: export LLM_API_KEY=sk-...")
        print(f"{rule}\n")
        return

    print(" ⏳ 正在检测连通性...")
    engine = LLMEngine()
    if engine.provider.health_check():
        print(f" ✅ API 连通正常,模型 [{settings.llm.model_name}] 可用")
    else:
        print(" ❌ API 连接失败,请检查网络或 API Key")
        print(" 💡 可尝试设置代理: export LLM_API_BASE_URL=https://your-proxy/v1")
    print(f"{rule}\n")
|
||
|
||
|
||
# ── Demo scenarios ─────────────────────────────────────────────
def run_demo(client: AgentClient) -> None:
    """Run the preset demo scenarios against *client* and print each result.

    Finishes by printing the client's memory statistics.
    """
    demo_cases = [
        ("🔢 单步: 数学计算",
         "计算 (100 + 200) × 3 等于多少?"),
        ("🌐 单步: 网络搜索",
         "搜索 Python 3.12 的主要新特性"),
        ("🔗 两步: 搜索 + 计算",
         "搜索 Python 最新版本号,然后计算 3.12 × 100 的结果"),
        ("🔗 两步: 读取文件 + 执行代码",
         "读取 script.py 文件然后执行里面的代码"),
        ("💬 无工具: 直接问答",
         "你好,请介绍一下你自己"),
    ]

    banner = "═" * 62
    logger.info("\n" + banner)
    logger.info(f"🎬 演示模式 | 模型: {settings.llm.model_name} | "
                f"Provider: {settings.llm.provider}")
    logger.info(banner)

    for title, question in demo_cases:
        logger.info(f"\n📌 场景: {title}")
        print_response(client.chat(question))

    print(f"📊 Memory 统计: {client.get_memory_stats()}\n")
|
||
|
||
|
||
# ── Interactive mode ───────────────────────────────────────────
def run_interactive(client: AgentClient) -> None:
    """Run the interactive command-line chat loop until the user quits.

    Recognized commands (case-insensitive): config / health / tools /
    chains / stats / clear / quit / exit; anything else is forwarded to
    ``client.chat``. Ctrl-C / EOF exits the loop.
    """
    heavy = "═" * 62
    light = "─" * 62
    print("\n" + heavy)
    print(f" 🤖 Agent | {settings.llm.model_name} | {settings.llm.provider}")
    print(f" Function Calling: {'✅ 开启' if settings.llm.function_calling else '❌ 关闭(规则引擎)'}")
    print(f" Fallback Rules : {'✅ 开启' if settings.agent.fallback_to_rules else '❌ 关闭'}")
    print(light)
    print(" 💡 示例:")
    print(" 计算 (100+200) × 3")
    print(" 搜索 Python 新特性,然后计算 3.12 × 100")
    print(" 读取 config.json 文件然后执行代码")
    print(light)
    print(" 🛠 命令: config / health / tools / chains / stats / clear / quit")
    print(heavy + "\n")

    while True:
        try:
            user_input = input("👤 你: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\n👋 再见!")
            break

        if not user_input:
            continue

        command = user_input.lower()
        if command in ("quit", "exit"):
            print("👋 再见!")
            break
        elif command == "config":
            print(settings.display())
        elif command == "health":
            run_health_check()
        elif command == "clear":
            client.clear_session()
            print("✅ 会话已清空\n")
        elif command == "stats":
            print(f"📊 {client.get_memory_stats()}\n")
        elif command == "tools":
            schemas = client.mcp_server.get_tool_schemas()
            print(f"🔧 已注册工具 ({len(schemas)} 个):")
            for schema in schemas:
                print(f" • [{schema.name}] {schema.description}")
            print()
        elif command == "chains":
            chains = client.memory.get_chain_history()
            if not chains:
                print("🔗 暂无调用链历史\n")
            else:
                print(f"🔗 调用链历史 ({len(chains)} 条):")
                for i, chain in enumerate(chains, 1):
                    steps = " → ".join(s["tool_name"] for s in chain["steps"])
                    ok_cnt = sum(1 for s in chain["steps"] if s["success"])
                    total = len(chain["steps"])
                    print(f" {i}. [{chain['timestamp'][11:19]}] {chain['goal'][:38]}...")
                    print(f" 链路: {steps} ({ok_cnt}/{total} 步成功)")
                print()
        else:
            print_response(client.chat(user_input))
|
||
|
||
|
||
# ── Config printing ────────────────────────────────────────────
def run_show_config() -> None:
    """Print the resolved configuration, lookup paths and env overrides."""
    print(settings.display())
    print("\n📁 配置文件查找路径(按优先级):")
    print(" 1. 环境变量 AGENT_CONFIG_PATH")
    print(" 2. ./config/config.yaml")
    print(" 3. ./config.yaml")
    print("\n🌍 支持的环境变量覆盖:")
    # (variable, description) pairs, rendered as an aligned table below.
    env_vars = (
        ("LLM_API_KEY", "OpenAI API 密钥(sk-...)"),
        ("LLM_MODEL_NAME", "模型名称,如 gpt-4o / gpt-4-turbo"),
        ("LLM_API_BASE_URL", "自定义 API 地址(兼容代理)"),
        ("LLM_MODEL_PATH", "本地模型路径"),
        ("SEARCH_API_KEY", "搜索 API 密钥"),
        ("LOG_LEVEL", "日志级别 DEBUG/INFO/WARNING/ERROR"),
        ("AGENT_CONFIG_PATH", "配置文件路径"),
    )
    for name, description in env_vars:
        print(f" {name:<22} → {description}")
    print()
|
||
|
||
|
||
# ── Main function ──────────────────────────────────────────────
def main() -> None:
    """CLI entry point: dispatch on the first positional argument.

    ``config`` and ``health`` run without building the agent; every
    other mode (``demo`` or the default interactive mode) builds the
    agent first.
    """
    mode = sys.argv[1] if len(sys.argv) > 1 else "interactive"

    if mode == "config":
        run_show_config()
    elif mode == "health":
        run_health_check()
    else:
        client = build_agent()
        if mode == "demo":
            run_demo(client)
        else:
            run_interactive(client)
|
||
|
||
|
||
# Script entry point.
if __name__ == "__main__":
    main()