"""
agent/agent.py

Agent 核心 —— 通过 SkillRegistry 统一调用本地工具和在线 MCP Skill
(Agent core: invokes local tools and online MCP skills uniformly through SkillRegistry.)
"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import time
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from config.settings import settings
|
|||
|
|
from mcp.skill_registry import DispatchResult, SkillRegistry
|
|||
|
|
from utils.logger import get_logger
|
|||
|
|
|
|||
|
|
# Module-wide logger, obtained from the project's utils.logger helper under the name "Agent".
logger = get_logger("Agent")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ════════════════════════════════════════════════════════════════
# Messages / history
# ════════════════════════════════════════════════════════════════


@dataclass
class Message:
    """A single conversation entry that can be serialized for the chat API."""

    role: str  # system | user | assistant | tool
    content: str = ""
    tool_call_id: str = ""  # links a "tool" message back to the call it answers
    tool_name: str = ""  # bookkeeping only; never serialized by to_api_dict
    tool_calls: list[dict] = field(default_factory=list)  # calls issued by an assistant turn

    def to_api_dict(self) -> dict:
        """Return this message as an OpenAI-style chat message dict.

        - ``tool`` messages: ``role``, ``content`` and ``tool_call_id``.
        - Messages carrying ``tool_calls``: ``content`` becomes ``None`` when
          empty, and the ``tool_calls`` list is attached.
        - Anything else: ``role`` and ``content`` only.
        """
        payload: dict[str, Any] = {"role": self.role}

        if self.role == "tool":
            payload.update(content=self.content, tool_call_id=self.tool_call_id)
            return payload

        if not self.tool_calls:
            payload["content"] = self.content
            return payload

        payload["content"] = self.content if self.content else None
        payload["tool_calls"] = self.tool_calls
        return payload


# ════════════════════════════════════════════════════════════════
|
|||
|
|
# LLM 客户端(OpenAI-compatible)
|
|||
|
|
# ════════════════════════════════════════════════════════════════
|
|||
|
|
|
|||
|
|
class LLMClient:
|
|||
|
|
"""
|
|||
|
|
OpenAI-compatible LLM 客户端
|
|||
|
|
支持 function calling / tool_calls
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self):
|
|||
|
|
try:
|
|||
|
|
from openai import OpenAI
|
|||
|
|
self._client = OpenAI(
|
|||
|
|
api_key=settings.llm.api_key or "sk-placeholder",
|
|||
|
|
base_url=settings.llm.api_base_url or None,
|
|||
|
|
)
|
|||
|
|
self._available = True
|
|||
|
|
except ImportError:
|
|||
|
|
logger.warning("⚠️ openai 未安装,LLM 调用将使用 mock 模式")
|
|||
|
|
self._available = False
|
|||
|
|
|
|||
|
|
def chat(
|
|||
|
|
self,
|
|||
|
|
messages: list[dict],
|
|||
|
|
tools: list[dict] | None = None,
|
|||
|
|
stream: bool = False,
|
|||
|
|
) -> dict:
|
|||
|
|
"""
|
|||
|
|
发送对话请求
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
{
|
|||
|
|
"content": str, # 文本回复(可能为空)
|
|||
|
|
"tool_calls": list[dict], # function calling 调用列表
|
|||
|
|
"finish_reason": str,
|
|||
|
|
}
|
|||
|
|
"""
|
|||
|
|
if not self._available:
|
|||
|
|
return self._mock_response(messages, tools)
|
|||
|
|
|
|||
|
|
kwargs: dict[str, Any] = {
|
|||
|
|
"model": settings.llm.model_name,
|
|||
|
|
"messages": messages,
|
|||
|
|
"max_tokens": settings.llm.max_tokens,
|
|||
|
|
"temperature": settings.llm.temperature,
|
|||
|
|
"stream": stream,
|
|||
|
|
}
|
|||
|
|
if tools and settings.llm.function_calling:
|
|||
|
|
kwargs["tools"] = [{"type": "function", "function": t} for t in tools]
|
|||
|
|
kwargs["tool_choice"] = "auto"
|
|||
|
|
|
|||
|
|
resp = self._client.chat.completions.create(**kwargs)
|
|||
|
|
|
|||
|
|
if stream:
|
|||
|
|
return self._collect_stream(resp)
|
|||
|
|
|
|||
|
|
msg = resp.choices[0].message
|
|||
|
|
finish = resp.choices[0].finish_reason
|
|||
|
|
return {
|
|||
|
|
"content": msg.content or "",
|
|||
|
|
"tool_calls": self._parse_tool_calls(msg.tool_calls),
|
|||
|
|
"finish_reason": finish,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _parse_tool_calls(raw) -> list[dict]:
|
|||
|
|
if not raw:
|
|||
|
|
return []
|
|||
|
|
result = []
|
|||
|
|
for tc in raw:
|
|||
|
|
try:
|
|||
|
|
args = json.loads(tc.function.arguments)
|
|||
|
|
except (json.JSONDecodeError, AttributeError):
|
|||
|
|
args = {}
|
|||
|
|
result.append({
|
|||
|
|
"id": tc.id,
|
|||
|
|
"name": tc.function.name,
|
|||
|
|
"arguments": args,
|
|||
|
|
})
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _collect_stream(stream) -> dict:
|
|||
|
|
content = ""
|
|||
|
|
tool_calls = []
|
|||
|
|
finish = ""
|
|||
|
|
tc_buffers: dict[int, dict] = {}
|
|||
|
|
|
|||
|
|
for chunk in stream:
|
|||
|
|
delta = chunk.choices[0].delta
|
|||
|
|
finish = chunk.choices[0].finish_reason or finish
|
|||
|
|
|
|||
|
|
if delta.content:
|
|||
|
|
content += delta.content
|
|||
|
|
|
|||
|
|
if delta.tool_calls:
|
|||
|
|
for tc in delta.tool_calls:
|
|||
|
|
idx = tc.index
|
|||
|
|
if idx not in tc_buffers:
|
|||
|
|
tc_buffers[idx] = {
|
|||
|
|
"id": tc.id or "",
|
|||
|
|
"name": tc.function.name if tc.function else "",
|
|||
|
|
"args": "",
|
|||
|
|
}
|
|||
|
|
if tc.function and tc.function.arguments:
|
|||
|
|
tc_buffers[idx]["args"] += tc.function.arguments
|
|||
|
|
|
|||
|
|
for buf in tc_buffers.values():
|
|||
|
|
try:
|
|||
|
|
args = json.loads(buf["args"])
|
|||
|
|
except json.JSONDecodeError:
|
|||
|
|
args = {}
|
|||
|
|
tool_calls.append({
|
|||
|
|
"id": buf["id"],
|
|||
|
|
"name": buf["name"],
|
|||
|
|
"arguments": args,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
return {"content": content, "tool_calls": tool_calls, "finish_reason": finish}
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _mock_response(messages: list[dict], tools: list[dict] | None) -> dict:
|
|||
|
|
"""无 LLM 时的 mock 响应(用于测试)"""
|
|||
|
|
last = messages[-1].get("content", "") if messages else ""
|
|||
|
|
if tools and ("搜索" in last or "search" in last.lower()):
|
|||
|
|
return {
|
|||
|
|
"content": None,
|
|||
|
|
"tool_calls": [{
|
|||
|
|
"id": "mock_001",
|
|||
|
|
"name": tools[0]["name"],
|
|||
|
|
"arguments": {"query": last},
|
|||
|
|
}],
|
|||
|
|
"finish_reason": "tool_calls",
|
|||
|
|
}
|
|||
|
|
return {
|
|||
|
|
"content": f"[Mock LLM] 收到: {last[:100]}",
|
|||
|
|
"tool_calls": [],
|
|||
|
|
"finish_reason": "stop",
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ════════════════════════════════════════════════════════════════
# Agent core
# ════════════════════════════════════════════════════════════════


class Agent:
    """
    Agent core.

    Invokes local tools and online MCP skills uniformly through a
    SkillRegistry, so the agent never needs to know where a tool comes from.

    Usage:
        registry = SkillRegistry()
        registry.register_local_many(CalculatorTool(), WebSearchTool())
        registry.connect_skills()   # connect online MCP skills

        agent = Agent(registry)
        reply = agent.chat("帮我搜索 Python 最新版本")
        print(reply)
    """

    SYSTEM_PROMPT = (
        "你是一个智能助手,可以调用工具完成用户的任务。\n"
        "调用工具时请确保参数完整准确。\n"
        "工具调用结果会自动返回给你,请根据结果给出最终回答。"
    )

    def __init__(
        self,
        registry: SkillRegistry,
        system_prompt: str | None = None,
    ):
        """
        Args:
            registry: Provides tool schemas and dispatches tool calls.
            system_prompt: Replaces ``SYSTEM_PROMPT`` when given; a falsy value
                falls back to the default.
        """
        self.registry = registry
        self.llm = LLMClient()
        # Persistent history holds only user turns and final assistant replies;
        # intermediate tool-call messages live in _run_loop's local copy.
        self.history: list[Message] = []
        self.system_prompt = system_prompt or self.SYSTEM_PROMPT
        self._max_steps = settings.agent.max_chain_steps

        logger.info(
            f"🤖 Agent 初始化完成\n"
            f" LLM : {settings.llm.provider} / {settings.llm.model_name}\n"
            f" 工具总数 : {len(registry.get_all_schemas())} 个\n"
            f" 最大步数 : {self._max_steps}\n"
            f" 工具列表 :\n" +
            "\n".join(
                f" {'🔵' if t['source'] == 'local' else '🟢'} "
                f"[{t['source']:20s}] {t['name']}"
                for t in registry.list_all_tools()
            )
        )

    # ── Conversation entry ─────────────────────────────────────

    def chat(self, user_input: str) -> str:
        """
        Handle one user turn, including any multi-step tool-call chain.

        Args:
            user_input: The user's text.

        Returns:
            The final reply text (also appended to ``self.history``).
        """
        logger.info(f"💬 用户输入: {user_input}")
        self.history.append(Message(role="user", content=user_input))

        reply = self._run_loop()

        self.history.append(Message(role="assistant", content=reply))
        # Truncate history (config.yaml memory.max_history). A non-positive
        # limit disables truncation, which is what a limit of 0 already did
        # implicitly (history[-0:] is the whole list).
        max_h = settings.memory.max_history
        if max_h > 0 and len(self.history) > max_h:
            self.history = self.history[-max_h:]

        return reply

    def reset(self) -> None:
        """Clear the conversation history."""
        self.history.clear()
        logger.info("🔄 对话历史已清空")

    # ── Multi-step reasoning loop ──────────────────────────────

    def _run_loop(self) -> str:
        """
        Multi-step tool-calling loop.

        1. Build the message list and call the LLM.
        2. If the LLM returns tool_calls: execute them, append the results,
           and go back to 1 so the LLM can use (and summarize) them.
        3. If the LLM returns plain text: that text is the answer.

        Returns:
            The final answer. When ``max_chain_steps`` runs out, falls back to
            the most recent non-empty assistant text produced during *this*
            turn, or a fixed notice if there is none.
        """
        tools = self.registry.get_all_schemas()
        loop_history: list[Message] = list(self.history)
        turn_start = len(loop_history)  # messages after this index belong to this turn

        for step in range(1, self._max_steps + 1):
            logger.info(f"🔁 推理步骤 {step}/{self._max_steps}")

            llm_resp = self.llm.chat(
                messages=self._build_messages(loop_history),
                tools=tools if settings.llm.function_calling else None,
                stream=settings.llm.stream,
            )

            content = llm_resp.get("content", "") or ""
            tool_calls = llm_resp.get("tool_calls") or []
            finish = llm_resp.get("finish_reason", "stop")

            logger.debug(
                f" LLM 响应: finish={finish} "
                f"tool_calls={len(tool_calls)} "
                f"content={content[:80]}"
            )

            # ── No tool calls: the text is the final answer ────
            if not tool_calls:
                return content or "(无回复)"

            # ── Tool calls: record them, execute, append results ──
            loop_history.append(self._tool_call_message(content, tool_calls))
            for tc in tool_calls:
                result = self._execute_tool(tc["name"], tc["arguments"])
                loop_history.append(Message(
                    role="tool",
                    content=str(result),
                    tool_call_id=tc["id"],
                    tool_name=tc["name"],
                ))
            # Always loop again after running tools, even if finish_reason is
            # "stop" (some backends report that alongside tool_calls), so the
            # LLM gets to read the results and produce the final answer.

        # Step budget exhausted: fall back to the latest assistant text from
        # this turn only — never a reply left over from an earlier turn.
        for msg in reversed(loop_history[turn_start:]):
            if msg.role == "assistant" and msg.content:
                return msg.content
        return "(已达最大推理步数,无法给出最终回答)"

    @staticmethod
    def _tool_call_message(content: str, tool_calls: list[dict]) -> Message:
        """Build the assistant message that records ``tool_calls`` in API wire format."""
        return Message(
            role="assistant",
            content=content,
            tool_calls=[
                {
                    "id": tc["id"],
                    "type": "function",
                    "function": {
                        "name": tc["name"],
                        "arguments": json.dumps(tc["arguments"], ensure_ascii=False),
                    },
                }
                for tc in tool_calls
            ],
        )

    # ── Tool execution ─────────────────────────────────────────

    def _execute_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> str:
        """Dispatch one tool call through the registry and return its result as a string."""
        logger.info(
            f"🔧 执行工具: {tool_name}\n"
            f" 来源: {self.registry.get_tool_info(tool_name)}\n"
            f" 参数: {json.dumps(arguments, ensure_ascii=False)[:200]}"
        )
        start = time.perf_counter()  # monotonic clock: immune to wall-clock jumps
        result = self.registry.dispatch(tool_name, arguments)
        elapsed = time.perf_counter() - start

        text = str(result)
        icon = "✅" if result.success else "❌"
        logger.info(
            f"{icon} 工具结果: {tool_name} "
            f"source={result.source} 耗时={elapsed:.2f}s\n"
            f" {text[:200]}"
        )
        return text

    # ── Message construction ───────────────────────────────────

    def _build_messages(self, history: list[Message]) -> list[dict]:
        """Prefix the system prompt and serialize ``history`` for the chat API."""
        return [
            {"role": "system", "content": self.system_prompt},
            *(m.to_api_dict() for m in history),
        ]

    # ── Debug helpers ──────────────────────────────────────────

    def show_tools(self) -> str:
        """Return a printable listing of all available tools with their source."""
        tools = self.registry.list_all_tools()
        lines = [f"📦 可用工具(共 {len(tools)} 个):", "─" * 50]
        for t in tools:
            icon = "🔵" if t["source"] == "local" else "🟢"
            lines.append(
                f" {icon} [{t['source']:25s}] {t['name']}\n"
                f" {t['description']}"
            )
        return "\n".join(lines)


# ════════════════════════════════════════════════════════════════
# Demo entry point
# ════════════════════════════════════════════════════════════════


def create_agent() -> tuple[Agent, SkillRegistry]:
    """
    Factory: create and initialize an Agent together with its SkillRegistry.

    Local tools are registered according to ``settings.mcp.enabled_tools``
    (config.yaml ``mcp.enabled_tools``). Only the enabled tools' modules are
    imported, so a disabled tool with missing optional dependencies cannot
    break startup. Unknown tool names are logged and skipped.

    Returns:
        (agent, registry). The caller must call ``registry.close()`` at exit.
        If setup fails, the registry is closed here before the error propagates.
    """
    import importlib

    # Tool name in config → (module path, class name), imported lazily below.
    local_tools = {
        "calculator": ("tools.calculator", "CalculatorTool"),
        "web_search": ("tools.web_search", "WebSearchTool"),
        "file_reader": ("tools.file_reader", "FileReaderTool"),
        "code_executor": ("tools.code_executor", "CodeExecutorTool"),
        "static_analyzer": ("tools.static_analyzer", "StaticAnalyzerTool"),
        "ssh_docker": ("tools.ssh_docker", "SSHDockerTool"),
    }

    registry = SkillRegistry()
    try:
        for name in settings.mcp.enabled_tools:
            spec = local_tools.get(name)
            if spec is None:
                logger.warning(f"⚠️ 未知的本地工具名称,已忽略: {name}")
                continue
            module_path, class_name = spec
            tool_cls = getattr(importlib.import_module(module_path), class_name)
            registry.register_local(tool_cls())

        # Connect online MCP skills (config.yaml mcp_skills).
        registry.connect_skills()

        agent = Agent(registry)
    except BaseException:
        # Don't leak open skill connections when setup fails part-way.
        registry.close()
        raise

    return agent, registry


if __name__ == "__main__":
    import atexit

    print(settings.display())

    agent, registry = create_agent()
    atexit.register(registry.close)  # close skill connections when the program exits

    print(agent.show_tools())
    print("─" * 60)
    print("💡 输入 'exit' 退出,'reset' 清空历史,'tools' 查看工具列表")
    print("─" * 60)

    while True:
        try:
            user_input = input("\n🧑 You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n👋 再见!")
            break

        if not user_input:
            continue

        command = user_input.lower()
        if command == "exit":
            print("👋 再见!")
            break
        if command == "reset":
            agent.reset()
            print("🔄 对话历史已清空")
            continue
        if command == "tools":
            print(agent.show_tools())
            continue

        # A failing turn (network error, tool crash, Ctrl-C during a long
        # request) should not tear down the whole REPL session.
        try:
            reply = agent.chat(user_input)
        except KeyboardInterrupt:
            print("\n⏹️ 已中断本轮对话")
            continue
        except Exception as exc:
            logger.exception("❌ 对话处理失败")
            print(f"\n⚠️ 出错了: {exc}")
            continue

        print(f"\n🤖 Agent: {reply}")
|