base_agent/memory/memory_store.py

128 lines
4.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""记忆模块:对话历史管理"""
"""
memory/memory_store.py
Agent 记忆模块:管理对话历史(短期记忆)与关键信息摘要(长期记忆)
"""
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import Literal
from utils.logger import get_logger
# ── 消息数据结构 ───────────────────────────────────────────────
@dataclass
class Message:
"""单条对话消息"""
role: Literal["user", "assistant", "tool"]
content: str
timestamp: str = field(default_factory=lambda: datetime.now().strftime("%H:%M:%S"))
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
return {
"role": self.role,
"content": self.content,
"timestamp": self.timestamp,
}
# ── 记忆存储 ───────────────────────────────────────────────────
class MemoryStore:
"""
对话记忆存储
短期记忆: 使用 deque 保存最近 N 轮对话,自动滚动淘汰旧消息
长期记忆: 保存关键事实摘要(生产环境可替换为向量数据库)
使用示例:
memory = MemoryStore(max_history=10)
memory.add_user_message("你好")
memory.add_assistant_message("你好!有什么可以帮你?")
history = memory.get_history()
"""
def __init__(self, max_history: int = 20):
"""
Args:
max_history: 短期记忆保留的最大消息条数
"""
self.logger = get_logger("MEMORY")
self.max_history = max_history
self._history: deque[Message] = deque(maxlen=max_history)
self._facts: list[str] = [] # 长期记忆:关键事实
self.logger.info(f"💾 Memory 初始化,最大历史: {max_history}")
# ── 写入接口 ────────────────────────────────────────────────
def add_user_message(self, content: str) -> None:
"""记录用户消息"""
self._add(Message(role="user", content=content))
def add_assistant_message(self, content: str) -> None:
"""记录 Agent 回复"""
self._add(Message(role="assistant", content=content))
def add_tool_result(self, tool_name: str, result: str) -> None:
"""记录工具调用结果"""
self._add(Message(
role="tool",
content=result,
metadata={"tool": tool_name},
))
def add_fact(self, fact: str) -> None:
"""向长期记忆中添加关键事实"""
self._facts.append(fact)
self.logger.debug(f"📌 长期记忆新增: {fact}")
# ── 读取接口 ────────────────────────────────────────────────
def get_history(self, last_n: int | None = None) -> list[dict]:
"""
获取对话历史LLM 上下文格式)
Args:
last_n: 仅返回最近 N 条None 表示全部
Returns:
消息字典列表,格式: [{"role": ..., "content": ...}, ...]
"""
messages = list(self._history)
if last_n:
messages = messages[-last_n:]
return [m.to_dict() for m in messages]
def get_facts(self) -> list[str]:
"""获取所有长期记忆事实"""
return list(self._facts)
def get_context_summary(self) -> str:
"""生成上下文摘要字符串,供 LLM Prompt 使用"""
history = self.get_history(last_n=6)
lines = [f"[{m['role'].upper()}] {m['content'][:80]}" for m in history]
return "\n".join(lines) if lines else "(暂无对话历史)"
# ── 管理接口 ────────────────────────────────────────────────
def clear_history(self) -> None:
"""清空短期对话历史"""
self._history.clear()
self.logger.info("🗑 对话历史已清空")
def stats(self) -> dict:
"""返回记忆统计信息"""
return {
"history_count": len(self._history),
"facts_count": len(self._facts),
"max_history": self.max_history,
}
# ── 私有方法 ────────────────────────────────────────────────
def _add(self, message: Message) -> None:
self._history.append(message)
self.logger.debug(f"💬 [{message.role.upper()}] {message.content[:60]}...")