base_agent/tools/file_reader.py

178 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
tools/file_reader.py
文件读取工具 —— 读取本地文件内容,支持文本/JSON/CSV
配置通过 settings.tools['file_reader'] 读取
"""
import csv
import io
import json
from pathlib import Path
from config.settings import settings
from utils.logger import get_logger
logger = get_logger("TOOL.FileReader")
def _cfg(key: str, fallback=None):
    """Look up one option in the ``file_reader`` section of the tool settings.

    The lookup is delegated to the section's ``.get(key, fallback)``.
    """
    section = settings.tools['file_reader']
    return section.get(key, fallback)
class FileReaderTool:
    """Tool that reads a local text, JSON or CSV file and returns it as a string.

    Every outcome, including errors, is reported through the returned string;
    ``execute`` does not raise for the failure modes it checks.
    """

    name = "file_reader"
    description = (
        "读取本地文件内容,支持 .txt / .md / .py / .json / .csv / .yaml / .log 等文本文件。"
        "文件必须位于 config.yaml file_reader.allowed_root 目录下。"
    )
    parameters = {
        "type": "object",
        "properties": {
            "file_path": {
                "type": "string",
                "description": "文件路径(相对于 allowed_root 或绝对路径)",
            },
            "encoding": {
                "type": "string",
                "description": "文件编码,默认 utf-8",
            },
            "max_lines": {
                "type": "integer",
                "description": "最多读取行数,0 表示全部读取",
            },
        },
        "required": ["file_path"],
    }

    # Suffixes (lower-case, with leading dot) that are accepted as text files.
    _TEXT_EXTENSIONS = {
        ".txt", ".md", ".py", ".js", ".ts", ".java", ".c", ".cpp",
        ".h", ".hpp", ".go", ".rs", ".rb", ".php", ".sh", ".bash",
        ".yaml", ".yml", ".toml", ".ini", ".cfg", ".conf",
        ".json", ".csv", ".log", ".xml", ".html", ".css", ".sql",
        ".env", ".gitignore", ".dockerfile",
    }

    # Horizontal rule placed between the header line and the file content.
    # Plain ASCII so the file does not carry ambiguous Unicode characters.
    _SEPARATOR = "-" * 50

    # Number of CSV rows rendered when the caller gives no explicit limit.
    _CSV_DEFAULT_ROWS = 50

    def execute(
        self,
        file_path: str = "",
        encoding: str = "utf-8",
        max_lines: int = 0,
        **_,
    ) -> str:
        """Validate the request, then read and format the file.

        Args:
            file_path: Path relative to the configured ``allowed_root``, or an
                absolute path that must resolve inside it.
            encoding: Text encoding used to decode the file; a falsy value is
                treated as ``"utf-8"``.
            max_lines: Maximum number of lines/rows to return; zero or a
                negative value means "no explicit limit".
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            The formatted file content, or a message starting with ``❌``
            describing why the file could not be read.
        """
        if not file_path or not file_path.strip():
            return "❌ 参数错误: file_path 不能为空"

        # Tool arguments may arrive as strings/floats/None; normalise once so
        # the readers only ever see a non-negative int.
        try:
            max_lines = max(int(max_lines or 0), 0)
        except (TypeError, ValueError, OverflowError):
            return "❌ 参数错误: max_lines 必须是整数"
        encoding = encoding or "utf-8"

        allowed_root = Path(_cfg('allowed_root', './workspace')).resolve()
        max_size_kb = _cfg('max_file_size_kb', 512)

        # Resolve the path; relative paths are anchored at allowed_root.
        path = Path(file_path)
        if not path.is_absolute():
            path = allowed_root / path
        path = path.resolve()
        logger.info(
            f"📄 读取文件: {path}\n"
            f" allowed_root={allowed_root} "
            f"max_size={max_size_kb}KB [config]"
        )

        # Security check: the resolved path must stay inside allowed_root.
        try:
            path.relative_to(allowed_root)
        except ValueError:
            return (
                f"❌ 安全限制: 文件路径超出允许范围\n"
                f" 路径: {path}\n"
                f" 允许范围: {allowed_root}\n"
                f" 请在 config.yaml → tools.file_reader.allowed_root 中调整"
            )
        if not path.exists():
            return f"❌ 文件不存在: {path}"
        if not path.is_file():
            return f"❌ 路径不是文件: {path}"

        # File size check; stat() can fail (e.g. permissions), report it
        # like any other read failure instead of raising.
        try:
            size_kb = path.stat().st_size / 1024
        except OSError as e:
            return f"❌ 读取失败: {e}"
        if size_kb > max_size_kb:
            return (
                f"❌ 文件过大: {size_kb:.1f} KB > 限制 {max_size_kb} KB\n"
                f" 请在 config.yaml → tools.file_reader.max_file_size_kb 中调整"
            )

        # Extension check.
        suffix = self._effective_suffix(path)
        if suffix not in self._TEXT_EXTENSIONS:
            return (
                f"❌ 不支持的文件类型: {suffix}\n"
                f" 支持类型: {', '.join(sorted(self._TEXT_EXTENSIONS))}"
            )

        # Read the file with the reader matching its type.
        try:
            if suffix == ".json":
                return self._read_json(path, encoding)
            if suffix == ".csv":
                return self._read_csv(path, encoding, max_lines)
            return self._read_text(path, encoding, max_lines)
        except UnicodeDecodeError:
            return (
                f"❌ 编码错误: 无法以 {encoding} 解码文件\n"
                f" 请尝试指定 encoding 参数,例如 'gbk', 'latin-1'"
            )
        except Exception as e:
            # Tool boundary: any other failure is reported as text.
            return f"❌ 读取失败: {e}"

    @staticmethod
    def _effective_suffix(path: Path) -> str:
        """Return the lower-cased suffix used for the file-type check.

        ``Path.suffix`` is empty for names such as ``.env``, ``.gitignore`` or
        ``Dockerfile``; for those the whole name (with a leading dot) is used
        so they can match entries like ``.env`` and ``.dockerfile``.
        """
        suffix = path.suffix.lower()
        if suffix:
            return suffix
        name = path.name.lower()
        return name if name.startswith(".") else f".{name}"

    @staticmethod
    def _read_text(path: Path, encoding: str, max_lines: int) -> str:
        """Return the file as plain text, truncated to ``max_lines`` lines.

        A ``max_lines`` that is falsy, negative, or not smaller than the number
        of lines returns the whole content unchanged.
        """
        content = path.read_text(encoding=encoding)
        lines = content.splitlines()
        total = len(lines)
        sep = FileReaderTool._SEPARATOR
        if max_lines and 0 < max_lines < total:
            text = "\n".join(lines[:max_lines])
            omitted = total - max_lines
            return (
                f"📄 {path.name} ({total} 行,显示前 {max_lines} 行)\n"
                f"{sep}\n{text}\n"
                f"{sep}\n... 还有 {omitted} 行未显示"
            )
        return f"📄 {path.name} ({total} 行)\n{sep}\n{content}"

    @staticmethod
    def _read_json(path: Path, encoding: str) -> str:
        """Return the file pretty-printed as JSON.

        If the content is not valid JSON, a warning plus the first 500
        characters of the raw content is returned instead.
        """
        content = path.read_text(encoding=encoding)
        try:
            data = json.loads(content)
        except json.JSONDecodeError as e:
            return f"⚠️ JSON 解析失败: {e}\n原始内容:\n{content[:500]}"
        formatted = json.dumps(data, ensure_ascii=False, indent=2)
        return f"📄 {path.name} (JSON)\n{FileReaderTool._SEPARATOR}\n{formatted}"

    @staticmethod
    def _read_csv(path: Path, encoding: str, max_lines: int) -> str:
        """Return the CSV rendered as a ``|``-separated, column-aligned table.

        At most ``max_lines`` rows are shown; when ``max_lines`` is falsy or
        negative the first 50 rows are shown.
        """
        content = path.read_text(encoding=encoding)
        rows = list(csv.reader(io.StringIO(content)))
        total = len(rows)
        if max_lines and max_lines > 0:
            limit = max_lines
        else:
            limit = min(total, FileReaderTool._CSV_DEFAULT_ROWS)
        shown = rows[:limit]
        if not shown:
            return f"📄 {path.name} (CSV,空文件)"

        # Size the table by the widest shown row, not by the first one, so
        # ragged rows are not silently cut off.
        num_cols = max(len(row) for row in shown)
        col_widths = [
            max((len(row[i]) for row in shown if i < len(row)), default=0)
            for i in range(num_cols)
        ]
        lines = [f"📄 {path.name} (CSV,{total} 行)", FileReaderTool._SEPARATOR]
        for row in shown:
            cells = [
                row[i].ljust(col_widths[i]) if i < len(row) else ""
                for i in range(num_cols)
            ]
            lines.append(" | ".join(cells))
        if total > limit:
            lines.append(f"... 还有 {total - limit} 行未显示")
        return "\n".join(lines)