base_agent/memory/memory_store.py

156 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
memory/memory_store.py
Agent 记忆模块:管理对话历史(短期)与关键事实(长期)
新增: add_chain_result() 记录完整多步骤调用链
"""
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Literal
from utils.logger import get_logger
if TYPE_CHECKING:
from mcp.mcp_protocol import ChainResult
# ── 消息数据结构 ───────────────────────────────────────────────
@dataclass
class Message:
    """A single entry in the conversation history.

    ``metadata`` is carried on the object but is not part of the dict
    produced by :meth:`to_dict`.
    """

    # Producer of the message.
    role: Literal["user", "assistant", "tool", "chain"]
    # Message text.
    content: str
    # Local wall-clock creation time formatted as HH:MM:SS (no date component).
    timestamp: str = field(
        default_factory=lambda: datetime.now().strftime("%H:%M:%S")
    )
    # Extra per-message data (e.g. the tool name); a fresh dict per instance.
    metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return ``role``, ``content`` and ``timestamp`` as a plain dict (metadata excluded)."""
        exported = ("role", "content", "timestamp")
        return {name: getattr(self, name) for name in exported}
# ── 记忆存储 ───────────────────────────────────────────────────
class MemoryStore:
"""
对话记忆存储
短期记忆: deque 保存最近 N 轮对话,自动滚动淘汰
长期记忆: 关键事实列表(生产环境可替换为向量数据库)
链路记录: 完整的多步骤调用链历史
使用示例:
memory = MemoryStore(max_history=20)
memory.add_user_message("你好")
memory.add_chain_result(chain_result)
"""
def __init__(self, max_history: int = 20):
self.logger = get_logger("MEMORY")
self.max_history = max_history
self._history: deque[Message] = deque(maxlen=max_history)
self._facts: list[str] = []
self._chains: list[dict] = [] # 调用链历史记录
self.logger.info(f"💾 Memory 初始化,最大历史: {max_history}")
# ── 写入接口 ────────────────────────────────────────────────
def add_user_message(self, content: str) -> None:
self._add(Message(role="user", content=content))
def add_assistant_message(self, content: str) -> None:
self._add(Message(role="assistant", content=content))
def add_tool_result(self, tool_name: str, result: str) -> None:
self._add(Message(
role="tool",
content=result,
metadata={"tool": tool_name},
))
def add_chain_result(self, chain_result: "ChainResult") -> None:
"""
记录完整的多步骤调用链结果
Args:
chain_result: ChainResult 实例
"""
# 写入对话历史assistant 角色)
self._add(Message(
role="chain",
content=chain_result.final_reply,
metadata={
"goal": chain_result.goal,
"total_steps": chain_result.total_steps,
"completed_steps": chain_result.completed_steps,
"success": chain_result.success,
"tools_used": [r.tool_name for r in chain_result.step_results],
},
))
# 写入链路追踪记录
chain_record = {
"timestamp": datetime.now().isoformat(),
"goal": chain_result.goal,
"steps": [
{
"step_id": r.step_id,
"tool_name": r.tool_name,
"success": r.success,
"output": r.output[:200],
"error": r.error,
}
for r in chain_result.step_results
],
"success": chain_result.success,
}
self._chains.append(chain_record)
self.logger.info(
f"🔗 调用链已记录: {chain_result.completed_steps}/{chain_result.total_steps} 步成功"
)
def add_fact(self, fact: str) -> None:
self._facts.append(fact)
self.logger.debug(f"📌 长期记忆新增: {fact}")
# ── 读取接口 ────────────────────────────────────────────────
def get_history(self, last_n: int | None = None) -> list[dict]:
messages = list(self._history)
if last_n:
messages = messages[-last_n:]
return [m.to_dict() for m in messages]
def get_facts(self) -> list[str]:
return list(self._facts)
def get_chain_history(self) -> list[dict]:
"""获取所有调用链历史记录"""
return list(self._chains)
def get_context_summary(self) -> str:
"""生成上下文摘要,供 LLM Prompt 使用"""
history = self.get_history(last_n=6)
lines = [f"[{m['role'].upper()}] {m['content'][:80]}" for m in history]
return "\n".join(lines) if lines else "(暂无对话历史)"
# ── 管理接口 ────────────────────────────────────────────────
def clear_history(self) -> None:
self._history.clear()
self.logger.info("🗑 对话历史已清空")
def stats(self) -> dict:
return {
"history_count": len(self._history),
"facts_count": len(self._facts),
"chain_count": len(self._chains),
"max_history": self.max_history,
}
def _add(self, message: Message) -> None:
self._history.append(message)
self.logger.debug(f"💬 [{message.role.upper()}] {message.content[:60]}...")