61 lines
2.7 KiB
Python
61 lines
2.7 KiB
Python
# ════════════════════════════════════════════════════════════════
|
||
# tools/calculator.py
|
||
# ════════════════════════════════════════════════════════════════
|
||
"""安全的数学表达式计算工具(AST 解析,防注入)"""
|
||
|
||
import ast
|
||
import operator
|
||
|
||
from config.settings import settings
|
||
from tools.base_tool import BaseTool, ToolResult
|
||
|
||
|
||
class CalculatorTool(BaseTool):
|
||
name = "calculator"
|
||
description = "计算数学表达式,支持加减乘除、幂运算、括号等"
|
||
parameters = {
|
||
"expression": {"type": "string", "description": "数学表达式,例如 '(1+2)*3'"},
|
||
}
|
||
|
||
_OPERATORS = {
|
||
ast.Add: operator.add, ast.Sub: operator.sub,
|
||
ast.Mult: operator.mul, ast.Div: operator.truediv,
|
||
ast.Pow: operator.pow, ast.Mod: operator.mod,
|
||
ast.USub: operator.neg,
|
||
}
|
||
|
||
def __init__(self):
|
||
super().__init__()
|
||
# 从配置读取精度
|
||
self._precision = settings.tools['calculator']['precision']
|
||
self.logger.debug(f"⚙️ Calculator 精度: {self._precision}")
|
||
|
||
def execute(self, expression: str, **_) -> ToolResult:
|
||
try:
|
||
tree = ast.parse(expression, mode="eval")
|
||
result = self._eval_node(tree.body)
|
||
result = round(result, self._precision)
|
||
return ToolResult(
|
||
success=True,
|
||
output=f"{expression} = {result}",
|
||
metadata={"expression": expression, "result": result},
|
||
)
|
||
except (ValueError, TypeError, ZeroDivisionError) as exc:
|
||
return ToolResult(success=False, output=f"计算错误: {exc}")
|
||
|
||
def _eval_node(self, node: ast.AST) -> float:
|
||
match node:
|
||
case ast.Constant(value=v) if isinstance(v, (int, float)):
|
||
return v
|
||
case ast.BinOp(left=left, op=op, right=right):
|
||
fn = self._OPERATORS.get(type(op))
|
||
if fn is None:
|
||
raise ValueError(f"不支持的运算符: {type(op).__name__}")
|
||
return fn(self._eval_node(left), self._eval_node(right))
|
||
case ast.UnaryOp(op=op, operand=operand):
|
||
fn = self._OPERATORS.get(type(op))
|
||
if fn is None:
|
||
raise ValueError(f"不支持的一元运算符: {type(op).__name__}")
|
||
return fn(self._eval_node(operand))
|
||
case _:
|
||
raise ValueError(f"不支持的节点: {type(node).__name__}") |