base_agent/agent/agent.py

466 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
agent/agent.py

Agent core — drives both local tools and online MCP skills through a
single SkillRegistry.
"""
import json
import time
from dataclasses import dataclass, field
from typing import Any

from config.settings import settings
from mcp.skill_registry import DispatchResult, SkillRegistry
from utils.logger import get_logger

logger = get_logger("Agent")
# ════════════════════════════════════════════════════════════════
# Messages / history
# ════════════════════════════════════════════════════════════════
@dataclass
class Message:
    """One conversation entry, serializable to the OpenAI messages format."""

    role: str                   # system | user | assistant | tool
    content: str = ""
    tool_call_id: str = ""      # set on role == "tool" replies
    tool_name: str = ""
    tool_calls: list[dict] = field(default_factory=list)  # assistant-side calls

    def to_api_dict(self) -> dict:
        """Serialize to the dict shape the chat-completions API expects."""
        if self.role == "tool":
            # Tool results must echo the call id they answer.
            return {
                "role": self.role,
                "content": self.content,
                "tool_call_id": self.tool_call_id,
            }
        if self.tool_calls:
            # Assistant message requesting tool calls; empty text becomes None.
            return {
                "role": self.role,
                "content": self.content or None,
                "tool_calls": self.tool_calls,
            }
        return {"role": self.role, "content": self.content}
# ════════════════════════════════════════════════════════════════
# LLM client (OpenAI-compatible)
# ════════════════════════════════════════════════════════════════
class LLMClient:
"""
OpenAI-compatible LLM 客户端
支持 function calling / tool_calls
"""
def __init__(self):
try:
from openai import OpenAI
self._client = OpenAI(
api_key=settings.llm.api_key or "sk-placeholder",
base_url=settings.llm.api_base_url or None,
)
self._available = True
except ImportError:
logger.warning("⚠️ openai 未安装LLM 调用将使用 mock 模式")
self._available = False
def chat(
self,
messages: list[dict],
tools: list[dict] | None = None,
stream: bool = False,
) -> dict:
"""
发送对话请求
Returns:
{
"content": str, # 文本回复(可能为空)
"tool_calls": list[dict], # function calling 调用列表
"finish_reason": str,
}
"""
if not self._available:
return self._mock_response(messages, tools)
kwargs: dict[str, Any] = {
"model": settings.llm.model_name,
"messages": messages,
"max_tokens": settings.llm.max_tokens,
"temperature": settings.llm.temperature,
"stream": stream,
}
if tools and settings.llm.function_calling:
kwargs["tools"] = [{"type": "function", "function": t} for t in tools[:128]]
kwargs["tool_choice"] = "auto"
resp = self._client.chat.completions.create(**kwargs)
if stream:
return self._collect_stream(resp)
msg = resp.choices[0].message
finish = resp.choices[0].finish_reason
return {
"content": msg.content or "",
"tool_calls": self._parse_tool_calls(msg.tool_calls),
"finish_reason": finish,
}
@staticmethod
def _parse_tool_calls(raw) -> list[dict]:
if not raw:
return []
result = []
for tc in raw:
try:
args = json.loads(tc.function.arguments)
except (json.JSONDecodeError, AttributeError):
args = {}
result.append({
"id": tc.id,
"name": tc.function.name,
"arguments": args,
})
return result
@staticmethod
def _collect_stream(stream) -> dict:
content = ""
tool_calls = []
finish = ""
tc_buffers: dict[int, dict] = {}
for chunk in stream:
delta = chunk.choices[0].delta
finish = chunk.choices[0].finish_reason or finish
if delta.content:
content += delta.content
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tc_buffers:
tc_buffers[idx] = {
"id": tc.id or "",
"name": tc.function.name if tc.function else "",
"args": "",
}
if tc.function and tc.function.arguments:
tc_buffers[idx]["args"] += tc.function.arguments
for buf in tc_buffers.values():
try:
args = json.loads(buf["args"])
except json.JSONDecodeError:
args = {}
tool_calls.append({
"id": buf["id"],
"name": buf["name"],
"arguments": args,
})
return {"content": content, "tool_calls": tool_calls, "finish_reason": finish}
@staticmethod
def _mock_response(messages: list[dict], tools: list[dict] | None) -> dict:
"""无 LLM 时的 mock 响应(用于测试)"""
last = messages[-1].get("content", "") if messages else ""
if tools and ("搜索" in last or "search" in last.lower()):
return {
"content": None,
"tool_calls": [{
"id": "mock_001",
"name": tools[0]["name"],
"arguments": {"query": last},
}],
"finish_reason": "tool_calls",
}
return {
"content": f"[Mock LLM] 收到: {last[:100]}",
"tool_calls": [],
"finish_reason": "stop",
}
# ════════════════════════════════════════════════════════════════
# Agent core
# ════════════════════════════════════════════════════════════════
class Agent:
    """
    Agent core.

    Dispatches every tool call — local tools and online MCP skills alike —
    through a single SkillRegistry, so the agent never needs to know where
    a tool actually runs.

    Usage:
        registry = SkillRegistry()
        registry.register_local_many(CalculatorTool(), WebSearchTool())
        registry.connect_skills()  # connect online MCP skills
        agent = Agent(registry)
        reply = agent.chat("帮我搜索 Python 最新版本")
        print(reply)
    """

    SYSTEM_PROMPT = (
        "你是一个智能助手,可以调用工具完成用户的任务。\n"
        "调用工具时请确保参数完整准确。\n"
        "工具调用结果会自动返回给你,请根据结果给出最终回答。"
    )

    def __init__(
        self,
        registry: SkillRegistry,
        system_prompt: str | None = None,
    ):
        self.registry = registry
        self.llm = LLMClient()
        self.history: list[Message] = []
        self.system_prompt = system_prompt or self.SYSTEM_PROMPT
        # Upper bound on LLM → tool → LLM iterations per user turn.
        self._max_steps = settings.agent.max_chain_steps
        logger.info(
            f"🤖 Agent 初始化完成\n"
            f" LLM : {settings.llm.provider} / {settings.llm.model_name}\n"
            f" 工具总数 : {len(registry.get_all_schemas())}\n"
            f" 最大步数 : {self._max_steps}\n"
            f" 工具列表 :\n" +
            "\n".join(
                f" {'🔵' if t['source'] == 'local' else '🟢'} "
                f"[{t['source']:20s}] {t['name']}"
                for t in registry.list_all_tools()
            )
        )

    # ── Conversation entry point ──────────────────────────────
    def chat(self, user_input: str) -> str:
        """
        Run one conversation turn (supports multi-step tool-call chains).

        Args:
            user_input: the user's text input.

        Returns:
            The final reply text.
        """
        logger.info(f"💬 用户输入: {user_input}")
        self.history.append(Message(role="user", content=user_input))
        reply = self._run_loop()
        self.history.append(Message(role="assistant", content=reply))
        # Truncate history (config.yaml memory.max_history).
        max_h = settings.memory.max_history
        if len(self.history) > max_h:
            self.history = self.history[-max_h:]
        return reply

    def reset(self) -> None:
        """Clear the conversation history."""
        self.history.clear()
        logger.info("🔄 对话历史已清空")

    # ── Multi-step reasoning loop ─────────────────────────────
    def _run_loop(self) -> str:
        """
        Multi-step tool-call loop.

        Flow:
            1. Build the message list → call the LLM.
            2. If the LLM returns tool_calls → run the tools → append
               results → back to 1.
            3. If the LLM returns plain text → done, return it.

        Note: intermediate assistant/tool messages live only in the local
        ``loop_history``; chat() persists just the final reply.
        """
        tools = self.registry.get_all_schemas()
        step = 0
        loop_history: list[Message] = list(self.history)
        while step < self._max_steps:
            step += 1
            logger.info(f"🔁 推理步骤 {step}/{self._max_steps}")
            # Build the API message list (system prompt + history).
            messages = self._build_messages(loop_history)
            llm_resp = self.llm.chat(
                messages=messages,
                tools=tools if settings.llm.function_calling else None,
                stream=settings.llm.stream,
            )
            content = llm_resp.get("content", "") or ""
            tool_calls = llm_resp.get("tool_calls", [])
            finish = llm_resp.get("finish_reason", "stop")
            logger.debug(
                f" LLM 响应: finish={finish} "
                f"tool_calls={len(tool_calls)} "
                f"content={content[:80]}"
            )
            # No tool calls → the text is the final answer.
            if not tool_calls:
                return content or "(无回复)"
            # Append the assistant message carrying the tool_calls so the
            # following tool messages have a matching parent call id.
            loop_history.append(Message(
                role="assistant",
                content=content,
                tool_calls=[{
                    "id": tc["id"],
                    "type": "function",
                    "function": {
                        "name": tc["name"],
                        "arguments": json.dumps(tc["arguments"], ensure_ascii=False),
                    },
                } for tc in tool_calls],
            ))
            # Execute each requested tool and append its result.
            for tc in tool_calls:
                result = self._execute_tool(tc["name"], tc["arguments"])
                loop_history.append(Message(
                    role="tool",
                    content=str(result),
                    tool_call_id=tc["id"],
                    tool_name=tc["name"],
                ))
            # NOTE(review): the original comment said to KEEP looping when
            # finish == "stop" so the LLM can summarize the tool results,
            # but the code breaks out (dropping those results). Behavior is
            # preserved as written — confirm which was intended.
            if finish == "stop":
                break
        # Step budget exhausted: fall back to the last assistant text.
        for msg in reversed(loop_history):
            if msg.role == "assistant" and msg.content:
                return msg.content
        return "(已达最大推理步数,无法给出最终回答)"

    # ── Tool execution ────────────────────────────────────────
    def _execute_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> str:
        """Dispatch one tool call via the registry; return the result as a string."""
        logger.info(
            f"🔧 执行工具: {tool_name}\n"
            f" 来源: {self.registry.get_tool_info(tool_name)}\n"
            f" 参数: {json.dumps(arguments, ensure_ascii=False)[:200]}"
        )
        start = time.time()
        result = self.registry.dispatch(tool_name, arguments)
        elapsed = time.time() - start
        # Fix: both branches were "" — the ✅/❌ glyphs were stripped as
        # "ambiguous Unicode" in transit; restored (confirm original glyphs).
        icon = "✅" if result.success else "❌"
        logger.info(
            f"{icon} 工具结果: {tool_name} "
            f"source={result.source} 耗时={elapsed:.2f}s\n"
            f" {str(result)[:200]}"
        )
        return str(result)

    # ── Message construction ──────────────────────────────────
    def _build_messages(self, history: list[Message]) -> list[dict]:
        """Prepend the system prompt and serialize history to API dicts."""
        messages = [{"role": "system", "content": self.system_prompt}]
        messages += [m.to_api_dict() for m in history]
        return messages

    # ── Debug helpers ─────────────────────────────────────────
    def show_tools(self) -> str:
        """Render all available tools (with their source) as one string."""
        tools = self.registry.list_all_tools()
        # Fix: separator was '"" * 50' (empty) — the glyph was stripped in
        # transit; restored as "─" (confirm original glyph).
        lines = [f"📦 可用工具(共 {len(tools)} 个):", "─" * 50]
        for t in tools:
            icon = "🔵" if t["source"] == "local" else "🟢"
            lines.append(
                f" {icon} [{t['source']:25s}] {t['name']}\n"
                f" {t['description']}"
            )
        return "\n".join(lines)
# ════════════════════════════════════════════════════════════════
# Demo entry point
# ════════════════════════════════════════════════════════════════
def create_agent() -> tuple[Agent, SkillRegistry]:
    """
    Factory: build and initialize an Agent plus its SkillRegistry.

    Returns:
        (agent, registry) — the caller must invoke registry.close() on
        program shutdown.
    """
    from tools.calculator import CalculatorTool
    from tools.code_executor import CodeExecutorTool
    from tools.file_reader import FileReaderTool
    from tools.ssh_docker import SSHDockerTool
    from tools.static_analyzer import StaticAnalyzerTool
    from tools.web_search import WebSearchTool

    # All known local tools, keyed by their config name.
    available: dict[str, type] = {
        "calculator": CalculatorTool,
        "web_search": WebSearchTool,
        "file_reader": FileReaderTool,
        "code_executor": CodeExecutorTool,
        "static_analyzer": StaticAnalyzerTool,
        "ssh_docker": SSHDockerTool,
    }

    registry = SkillRegistry()
    # Register only the local tools enabled via config.yaml mcp.enabled_tools.
    for tool_name in settings.mcp.enabled_tools:
        tool_cls = available.get(tool_name)
        if tool_cls is not None:
            registry.register_local(tool_cls())
    # Connect online MCP skills declared under config.yaml mcp_skills.
    registry.connect_skills()
    return Agent(registry), registry
if __name__ == "__main__":
    import atexit

    print(settings.display())
    agent, registry = create_agent()
    atexit.register(registry.close)  # close MCP connections on exit

    print(agent.show_tools())
    # Fix: these separators printed an empty string ('"" * 60') — the glyph
    # was stripped as "ambiguous Unicode" in transit; restored as "─"
    # (confirm original glyph).
    print("─" * 60)
    print("💡 输入 'exit' 退出,'reset' 清空历史,'tools' 查看工具列表")
    print("─" * 60)

    # Simple REPL: exit / reset / tools are handled locally, everything
    # else goes through the agent.
    while True:
        try:
            user_input = input("\n🧑 You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n👋 再见!")
            break
        if not user_input:
            continue
        command = user_input.lower()
        if command == "exit":
            print("👋 再见!")
            break
        if command == "reset":
            agent.reset()
            print("🔄 对话历史已清空")
            continue
        if command == "tools":
            print(agent.show_tools())
            continue
        reply = agent.chat(user_input)
        print(f"\n🤖 Agent: {reply}")