base_agent/tools/calculator.py

64 lines
2.6 KiB
Python
Raw Normal View History

2026-02-28 08:21:35 +00:00
"""计算器工具"""
# ════════════════════════════════════════════════════════════════
# tools/calculator.py — 数学计算工具
# ════════════════════════════════════════════════════════════════
"""
tools/calculator.py
安全的数学表达式计算工具,使用 ast 模块解析并求值,避免 eval 注入风险
"""
import ast
import operator
from tools.base_tool import BaseTool, ToolResult
class CalculatorTool(BaseTool):
name = "calculator"
description = "计算数学表达式,支持加减乘除、幂运算、括号等"
parameters = {
"expression": {
"type": "string",
"description": "数学表达式,例如 '(1+2)*3''2**10'",
}
}
# 允许的运算符白名单(防止注入)
_OPERATORS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.Pow: operator.pow,
ast.Mod: operator.mod,
ast.USub: operator.neg,
}
def execute(self, expression: str, **_) -> ToolResult:
try:
tree = ast.parse(expression, mode="eval")
result = self._eval_node(tree.body)
return ToolResult(
success=True,
output=f"{expression} = {result}",
metadata={"expression": expression, "result": result},
)
except (ValueError, TypeError, ZeroDivisionError) as exc:
return ToolResult(success=False, output=f"计算错误: {exc}")
def _eval_node(self, node: ast.AST) -> float:
"""递归解析 AST 节点"""
match node:
case ast.Constant(value=v) if isinstance(v, (int, float)):
return v
case ast.BinOp(left=left, op=op, right=right):
fn = self._OPERATORS.get(type(op))
if fn is None:
raise ValueError(f"不支持的运算符: {type(op).__name__}")
return fn(self._eval_node(left), self._eval_node(right))
case ast.UnaryOp(op=op, operand=operand):
fn = self._OPERATORS.get(type(op))
if fn is None:
raise ValueError(f"不支持的一元运算符: {type(op).__name__}")
return fn(self._eval_node(operand))
case _:
raise ValueError(f"不支持的表达式节点: {type(node).__name__}")