""" memory/memory_store.py Agent 记忆模块:管理对话历史(短期)与关键事实(长期) 新增: add_chain_result() 记录完整多步骤调用链 """ from collections import deque from dataclasses import dataclass, field from datetime import datetime from typing import TYPE_CHECKING, Literal from utils.logger import get_logger if TYPE_CHECKING: from mcp.mcp_protocol import ChainResult # ── 消息数据结构 ─────────────────────────────────────────────── @dataclass class Message: """单条对话消息""" role: Literal["user", "assistant", "tool", "chain"] content: str timestamp: str = field(default_factory=lambda: datetime.now().strftime("%H:%M:%S")) metadata: dict = field(default_factory=dict) def to_dict(self) -> dict: return { "role": self.role, "content": self.content, "timestamp": self.timestamp, } # ── 记忆存储 ─────────────────────────────────────────────────── class MemoryStore: """ 对话记忆存储 短期记忆: deque 保存最近 N 轮对话,自动滚动淘汰 长期记忆: 关键事实列表(生产环境可替换为向量数据库) 链路记录: 完整的多步骤调用链历史 使用示例: memory = MemoryStore(max_history=20) memory.add_user_message("你好") memory.add_chain_result(chain_result) """ def __init__(self, max_history: int = 20): self.logger = get_logger("MEMORY") self.max_history = max_history self._history: deque[Message] = deque(maxlen=max_history) self._facts: list[str] = [] self._chains: list[dict] = [] # 调用链历史记录 self.logger.info(f"💾 Memory 初始化,最大历史: {max_history} 条") # ── 写入接口 ──────────────────────────────────────────────── def add_user_message(self, content: str) -> None: self._add(Message(role="user", content=content)) def add_assistant_message(self, content: str) -> None: self._add(Message(role="assistant", content=content)) def add_tool_result(self, tool_name: str, result: str) -> None: self._add(Message( role="tool", content=result, metadata={"tool": tool_name}, )) def add_chain_result(self, chain_result: "ChainResult") -> None: """ 记录完整的多步骤调用链结果 Args: chain_result: ChainResult 实例 """ # 写入对话历史(assistant 角色) self._add(Message( role="chain", content=chain_result.final_reply, metadata={ "goal": chain_result.goal, "total_steps": chain_result.total_steps, "completed_steps": chain_result.completed_steps, "success": chain_result.success, "tools_used": [r.tool_name for r in chain_result.step_results], }, )) # 写入链路追踪记录 chain_record = { "timestamp": datetime.now().isoformat(), "goal": chain_result.goal, "steps": [ { "step_id": r.step_id, "tool_name": r.tool_name, "success": r.success, "output": r.output[:200], "error": r.error, } for r in chain_result.step_results ], "success": chain_result.success, } self._chains.append(chain_record) self.logger.info( f"🔗 调用链已记录: {chain_result.completed_steps}/{chain_result.total_steps} 步成功" ) def add_fact(self, fact: str) -> None: self._facts.append(fact) self.logger.debug(f"📌 长期记忆新增: {fact}") # ── 读取接口 ──────────────────────────────────────────────── def get_history(self, last_n: int | None = None) -> list[dict]: messages = list(self._history) if last_n: messages = messages[-last_n:] return [m.to_dict() for m in messages] def get_facts(self) -> list[str]: return list(self._facts) def get_chain_history(self) -> list[dict]: """获取所有调用链历史记录""" return list(self._chains) def get_context_summary(self) -> str: """生成上下文摘要,供 LLM Prompt 使用""" history = self.get_history(last_n=6) lines = [f"[{m['role'].upper()}] {m['content'][:80]}" for m in history] return "\n".join(lines) if lines else "(暂无对话历史)" # ── 管理接口 ──────────────────────────────────────────────── def clear_history(self) -> None: self._history.clear() self.logger.info("🗑 对话历史已清空") def stats(self) -> dict: return { "history_count": len(self._history), "facts_count": len(self._facts), "chain_count": len(self._chains), "max_history": self.max_history, } def _add(self, message: Message) -> None: self._history.append(message) self.logger.debug(f"💬 [{message.role.upper()}] {message.content[:60]}...")