base_agent/tools/calculator.py

61 lines
2.7 KiB
Python
Raw Normal View History

2026-02-28 08:21:35 +00:00
# ════════════════════════════════════════════════════════════════
# tools/calculator.py
# ════════════════════════════════════════════════════════════════
"""安全的数学表达式计算工具(AST 解析,防注入)"""
import ast
import operator
2026-03-09 05:37:29 +00:00
from config.settings import settings
2026-02-28 08:21:35 +00:00
from tools.base_tool import BaseTool, ToolResult
class CalculatorTool(BaseTool):
name = "calculator"
description = "计算数学表达式,支持加减乘除、幂运算、括号等"
parameters = {
2026-03-09 05:37:29 +00:00
"expression": {"type": "string", "description": "数学表达式,例如 '(1+2)*3'"},
2026-02-28 08:21:35 +00:00
}
_OPERATORS = {
2026-03-09 05:37:29 +00:00
ast.Add: operator.add, ast.Sub: operator.sub,
ast.Mult: operator.mul, ast.Div: operator.truediv,
ast.Pow: operator.pow, ast.Mod: operator.mod,
2026-02-28 08:21:35 +00:00
ast.USub: operator.neg,
}
2026-03-09 05:37:29 +00:00
def __init__(self):
super().__init__()
# 从配置读取精度
self._precision = settings.tools.calculator.precision
self.logger.debug(f"⚙️ Calculator 精度: {self._precision}")
2026-02-28 08:21:35 +00:00
def execute(self, expression: str, **_) -> ToolResult:
try:
tree = ast.parse(expression, mode="eval")
result = self._eval_node(tree.body)
2026-03-09 05:37:29 +00:00
result = round(result, self._precision)
2026-02-28 08:21:35 +00:00
return ToolResult(
success=True,
output=f"{expression} = {result}",
metadata={"expression": expression, "result": result},
)
except (ValueError, TypeError, ZeroDivisionError) as exc:
return ToolResult(success=False, output=f"计算错误: {exc}")
def _eval_node(self, node: ast.AST) -> float:
match node:
case ast.Constant(value=v) if isinstance(v, (int, float)):
return v
case ast.BinOp(left=left, op=op, right=right):
fn = self._OPERATORS.get(type(op))
if fn is None:
raise ValueError(f"不支持的运算符: {type(op).__name__}")
return fn(self._eval_node(left), self._eval_node(right))
case ast.UnaryOp(op=op, operand=operand):
fn = self._OPERATORS.get(type(op))
if fn is None:
raise ValueError(f"不支持的一元运算符: {type(op).__name__}")
return fn(self._eval_node(operand))
case _:
2026-03-09 05:37:29 +00:00
raise ValueError(f"不支持的节点: {type(node).__name__}")