base_agent/tools/calculator.py

61 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ════════════════════════════════════════════════════════════════
# tools/calculator.py
# ════════════════════════════════════════════════════════════════
"""安全的数学表达式计算工具(AST 解析,防注入)"""
import ast
import operator
from config.settings import settings
from tools.base_tool import BaseTool, ToolResult
class CalculatorTool(BaseTool):
name = "calculator"
description = "计算数学表达式,支持加减乘除、幂运算、括号等"
parameters = {
"expression": {"type": "string", "description": "数学表达式,例如 '(1+2)*3'"},
}
_OPERATORS = {
ast.Add: operator.add, ast.Sub: operator.sub,
ast.Mult: operator.mul, ast.Div: operator.truediv,
ast.Pow: operator.pow, ast.Mod: operator.mod,
ast.USub: operator.neg,
}
def __init__(self):
super().__init__()
# 从配置读取精度
self._precision = settings.tools.calculator.precision
self.logger.debug(f"⚙️ Calculator 精度: {self._precision}")
def execute(self, expression: str, **_) -> ToolResult:
try:
tree = ast.parse(expression, mode="eval")
result = self._eval_node(tree.body)
result = round(result, self._precision)
return ToolResult(
success=True,
output=f"{expression} = {result}",
metadata={"expression": expression, "result": result},
)
except (ValueError, TypeError, ZeroDivisionError) as exc:
return ToolResult(success=False, output=f"计算错误: {exc}")
def _eval_node(self, node: ast.AST) -> float:
match node:
case ast.Constant(value=v) if isinstance(v, (int, float)):
return v
case ast.BinOp(left=left, op=op, right=right):
fn = self._OPERATORS.get(type(op))
if fn is None:
raise ValueError(f"不支持的运算符: {type(op).__name__}")
return fn(self._eval_node(left), self._eval_node(right))
case ast.UnaryOp(op=op, operand=operand):
fn = self._OPERATORS.get(type(op))
if fn is None:
raise ValueError(f"不支持的一元运算符: {type(op).__name__}")
return fn(self._eval_node(operand))
case _:
raise ValueError(f"不支持的节点: {type(node).__name__}")