base_agent/memory/memory_store.py

128 lines
4.8 KiB
Python
Raw Normal View History

2026-02-28 08:21:35 +00:00
"""记忆模块:对话历史管理"""
"""
memory/memory_store.py
Agent 记忆模块管理对话历史短期记忆与关键信息摘要长期记忆
"""
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import Literal
from utils.logger import get_logger
# ── 消息数据结构 ───────────────────────────────────────────────
@dataclass
class Message:
"""单条对话消息"""
role: Literal["user", "assistant", "tool"]
content: str
timestamp: str = field(default_factory=lambda: datetime.now().strftime("%H:%M:%S"))
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
return {
"role": self.role,
"content": self.content,
"timestamp": self.timestamp,
}
# ── 记忆存储 ───────────────────────────────────────────────────
class MemoryStore:
"""
对话记忆存储
短期记忆: 使用 deque 保存最近 N 轮对话自动滚动淘汰旧消息
长期记忆: 保存关键事实摘要生产环境可替换为向量数据库
使用示例:
memory = MemoryStore(max_history=10)
memory.add_user_message("你好")
memory.add_assistant_message("你好!有什么可以帮你?")
history = memory.get_history()
"""
def __init__(self, max_history: int = 20):
"""
Args:
max_history: 短期记忆保留的最大消息条数
"""
self.logger = get_logger("MEMORY")
self.max_history = max_history
self._history: deque[Message] = deque(maxlen=max_history)
self._facts: list[str] = [] # 长期记忆:关键事实
self.logger.info(f"💾 Memory 初始化,最大历史: {max_history}")
# ── 写入接口 ────────────────────────────────────────────────
def add_user_message(self, content: str) -> None:
"""记录用户消息"""
self._add(Message(role="user", content=content))
def add_assistant_message(self, content: str) -> None:
"""记录 Agent 回复"""
self._add(Message(role="assistant", content=content))
def add_tool_result(self, tool_name: str, result: str) -> None:
"""记录工具调用结果"""
self._add(Message(
role="tool",
content=result,
metadata={"tool": tool_name},
))
def add_fact(self, fact: str) -> None:
"""向长期记忆中添加关键事实"""
self._facts.append(fact)
self.logger.debug(f"📌 长期记忆新增: {fact}")
# ── 读取接口 ────────────────────────────────────────────────
def get_history(self, last_n: int | None = None) -> list[dict]:
"""
获取对话历史LLM 上下文格式
Args:
last_n: 仅返回最近 N None 表示全部
Returns:
消息字典列表格式: [{"role": ..., "content": ...}, ...]
"""
messages = list(self._history)
if last_n:
messages = messages[-last_n:]
return [m.to_dict() for m in messages]
def get_facts(self) -> list[str]:
"""获取所有长期记忆事实"""
return list(self._facts)
def get_context_summary(self) -> str:
"""生成上下文摘要字符串,供 LLM Prompt 使用"""
history = self.get_history(last_n=6)
lines = [f"[{m['role'].upper()}] {m['content'][:80]}" for m in history]
return "\n".join(lines) if lines else "(暂无对话历史)"
# ── 管理接口 ────────────────────────────────────────────────
def clear_history(self) -> None:
"""清空短期对话历史"""
self._history.clear()
self.logger.info("🗑 对话历史已清空")
def stats(self) -> dict:
"""返回记忆统计信息"""
return {
"history_count": len(self._history),
"facts_count": len(self._facts),
"max_history": self.max_history,
}
# ── 私有方法 ────────────────────────────────────────────────
def _add(self, message: Message) -> None:
self._history.append(message)
self.logger.debug(f"💬 [{message.role.upper()}] {message.content[:60]}...")