base_agent/main.py

253 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
main.py
智能体 Demo 程序入口OpenAI Function Calling 驱动)
运行模式:
python main.py → 交互模式
python main.py demo → 演示模式
python main.py config → 打印当前配置
python main.py health → 检测 OpenAI API 连通性
LLM_API_KEY=sk-xxx python main.py → 指定 API Key
LLM_MODEL_NAME=gpt-4-turbo python main.py→ 指定模型
AGENT_CONFIG_PATH=my.yaml python main.py → 指定配置文件
"""
import sys
from client.agent_client import AgentClient, AgentResponse
from config.settings import settings
from llm.llm_engine import LLMEngine
from mcp.mcp_server import MCPServer
from memory.memory_store import MemoryStore
from tools.calculator import CalculatorTool
from tools.code_executor import CodeExecutorTool
from tools.file_reader import FileReaderTool
from tools.web_search import WebSearchTool
from utils.logger import get_logger
logger = get_logger("SYSTEM")
_ALL_TOOLS = {
"calculator": CalculatorTool,
"web_search": WebSearchTool,
"file_reader": FileReaderTool,
"code_executor": CodeExecutorTool,
}
# ── 系统组装 ───────────────────────────────────────────────────
def build_agent() -> AgentClient:
"""工厂函数:由 settings 驱动的 Agent 组装"""
logger.info("🔧 开始组装 Agent 系统OpenAI Function Calling 模式)...")
logger.info(settings.display())
mcp_server = MCPServer()
for tool_cls in _ALL_TOOLS.values():
mcp_server.register_tool(tool_cls)
llm = LLMEngine()
memory = MemoryStore(max_history=settings.memory.max_history)
client = AgentClient(llm=llm, mcp_server=mcp_server, memory=memory)
logger.info(f"✅ Agent 组装完成,已注册工具: {mcp_server.list_tools()}")
return client
# ── 结果打印 ───────────────────────────────────────────────────
def print_response(response: AgentResponse) -> None:
"""格式化打印 AgentResponse"""
print(f"\n{'' * 62}")
print(f"👤 用户: {response.user_input}")
print(f"{'' * 62}")
if response.chain_result:
cr = response.chain_result
tag = "🔗 多步串行" if response.is_multi_step else "🔧 单步调用"
status = "✅ 全部成功" if cr.success else f"⚠️ 步骤 {cr.failed_step} 失败"
print(f"{tag} | {cr.completed_steps}/{cr.total_steps} 步 | {status}")
print()
for r in cr.step_results:
icon = "" if r.success else ""
preview = r.output.replace("\n", " ")[:90]
print(f" {icon} Step {r.step_id} [{r.tool_name}]")
if r.success:
print(f" └─ {preview}...")
else:
print(f" └─ 错误: {r.error}")
print()
print(f"🤖 Agent 回复:\n{response.final_reply}")
print(f"{'' * 62}\n")
# ── API 健康检测 ───────────────────────────────────────────────
def run_health_check() -> None:
"""检测 OpenAI API 连通性"""
print(f"\n{'' * 50}")
print(f" 🏥 OpenAI API 健康检测")
print(f"{'' * 50}")
print(f" Provider : {settings.llm.provider}")
print(f" Model : {settings.llm.model_name}")
print(f" API Key : {'***' + settings.llm.api_key[-4:] if len(settings.llm.api_key) > 4 else '(未设置)'}")
print(f" Base URL : {settings.llm.api_base_url or 'https://api.openai.com/v1'}")
print(f"{'' * 50}")
if not settings.llm.api_key:
print(" ❌ API Key 未设置")
print(" 💡 请设置环境变量: export LLM_API_KEY=sk-...")
print(f"{'' * 50}\n")
return
print(" ⏳ 正在检测连通性...")
llm = LLMEngine()
ok = llm.provider.health_check()
if ok:
print(f" ✅ API 连通正常,模型 [{settings.llm.model_name}] 可用")
else:
print(f" ❌ API 连接失败,请检查网络或 API Key")
print(f" 💡 可尝试设置代理: export LLM_API_BASE_URL=https://your-proxy/v1")
print(f"{'' * 50}\n")
# ── 演示场景 ───────────────────────────────────────────────────
def run_demo(client: AgentClient) -> None:
"""运行预设演示场景"""
demo_cases = [
("🔢 单步: 数学计算",
"计算 (100 + 200) × 3 等于多少?"),
("🌐 单步: 网络搜索",
"搜索 Python 3.12 的主要新特性"),
("🔗 两步: 搜索 + 计算",
"搜索 Python 最新版本号,然后计算 3.12 × 100 的结果"),
("🔗 两步: 读取文件 + 执行代码",
"读取 script.py 文件然后执行里面的代码"),
("💬 无工具: 直接问答",
"你好,请介绍一下你自己"),
]
logger.info("\n" + "" * 62)
logger.info(f"🎬 演示模式 | 模型: {settings.llm.model_name} | "
f"Provider: {settings.llm.provider}")
logger.info("" * 62)
for title, question in demo_cases:
logger.info(f"\n📌 场景: {title}")
response = client.chat(question)
print_response(response)
stats = client.get_memory_stats()
print(f"📊 Memory 统计: {stats}\n")
# ── 交互模式 ───────────────────────────────────────────────────
def run_interactive(client: AgentClient) -> None:
"""启动交互式命令行对话"""
print("\n" + "" * 62)
print(f" 🤖 Agent | {settings.llm.model_name} | {settings.llm.provider}")
print(f" Function Calling: {'✅ 开启' if settings.llm.function_calling else '❌ 关闭(规则引擎)'}")
print(f" Fallback Rules : {'✅ 开启' if settings.agent.fallback_to_rules else '❌ 关闭'}")
print("" * 62)
print(" 💡 示例:")
print(" 计算 (100+200) × 3")
print(" 搜索 Python 新特性,然后计算 3.12 × 100")
print(" 读取 config.json 文件然后执行代码")
print("" * 62)
print(" 🛠 命令: config / health / tools / chains / stats / clear / quit")
print("" * 62 + "\n")
while True:
try:
user_input = input("👤 你: ").strip()
except (KeyboardInterrupt, EOFError):
print("\n👋 再见!")
break
if not user_input:
continue
match user_input.lower():
case "quit" | "exit":
print("👋 再见!")
break
case "config":
print(settings.display())
case "health":
run_health_check()
case "clear":
client.clear_session()
print("✅ 会话已清空\n")
case "stats":
print(f"📊 {client.get_memory_stats()}\n")
case "tools":
schemas = client.mcp_server.get_tool_schemas()
print(f"🔧 已注册工具 ({len(schemas)} 个):")
for s in schemas:
print(f" • [{s.name}] {s.description}")
print()
case "chains":
chains = client.memory.get_chain_history()
if not chains:
print("🔗 暂无调用链历史\n")
else:
print(f"🔗 调用链历史 ({len(chains)} 条):")
for i, c in enumerate(chains, 1):
steps = "".join(s["tool_name"] for s in c["steps"])
ok_cnt = sum(1 for s in c["steps"] if s["success"])
total = len(c["steps"])
print(f" {i}. [{c['timestamp'][11:19]}] {c['goal'][:38]}...")
print(f" 链路: {steps} ({ok_cnt}/{total} 步成功)")
print()
case _:
response = client.chat(user_input)
print_response(response)
# ── 配置打印 ───────────────────────────────────────────────────
def run_show_config() -> None:
print(settings.display())
print("\n📁 配置文件查找路径(按优先级):")
print(" 1. 环境变量 AGENT_CONFIG_PATH")
print(" 2. ./config/config.yaml")
print(" 3. ./config.yaml")
print("\n🌍 支持的环境变量覆盖:")
env_vars = [
("LLM_API_KEY", "OpenAI API 密钥sk-..."),
("LLM_MODEL_NAME", "模型名称,如 gpt-4o / gpt-4-turbo"),
("LLM_API_BASE_URL", "自定义 API 地址(兼容代理)"),
("LLM_MODEL_PATH", "本地模型路径"),
("SEARCH_API_KEY", "搜索 API 密钥"),
("LOG_LEVEL", "日志级别 DEBUG/INFO/WARNING/ERROR"),
("AGENT_CONFIG_PATH","配置文件路径"),
]
for var, desc in env_vars:
print(f" {var:<22}{desc}")
print()
# ── 主函数 ─────────────────────────────────────────────────────
def main() -> None:
mode = sys.argv[1] if len(sys.argv) > 1 else "interactive"
if mode == "config":
run_show_config()
return
if mode == "health":
run_health_check()
return
client = build_agent()
if mode == "demo":
run_demo(client)
else:
run_interactive(client)
if __name__ == "__main__":
main()