64 lines
2.6 KiB
Python
64 lines
2.6 KiB
Python
"""计算器工具"""
|
|
# ════════════════════════════════════════════════════════════════
|
|
# tools/calculator.py — 数学计算工具
|
|
# ════════════════════════════════════════════════════════════════
|
|
"""
|
|
tools/calculator.py
|
|
安全的数学表达式计算工具(使用 ast 模块避免 eval 注入风险)
|
|
"""
|
|
|
|
import ast
|
|
import operator
|
|
from tools.base_tool import BaseTool, ToolResult
|
|
|
|
|
|
class CalculatorTool(BaseTool):
|
|
name = "calculator"
|
|
description = "计算数学表达式,支持加减乘除、幂运算、括号等"
|
|
parameters = {
|
|
"expression": {
|
|
"type": "string",
|
|
"description": "数学表达式,例如 '(1+2)*3' 或 '2**10'",
|
|
}
|
|
}
|
|
|
|
# 允许的运算符白名单(防止注入)
|
|
_OPERATORS = {
|
|
ast.Add: operator.add,
|
|
ast.Sub: operator.sub,
|
|
ast.Mult: operator.mul,
|
|
ast.Div: operator.truediv,
|
|
ast.Pow: operator.pow,
|
|
ast.Mod: operator.mod,
|
|
ast.USub: operator.neg,
|
|
}
|
|
|
|
def execute(self, expression: str, **_) -> ToolResult:
|
|
try:
|
|
tree = ast.parse(expression, mode="eval")
|
|
result = self._eval_node(tree.body)
|
|
return ToolResult(
|
|
success=True,
|
|
output=f"{expression} = {result}",
|
|
metadata={"expression": expression, "result": result},
|
|
)
|
|
except (ValueError, TypeError, ZeroDivisionError) as exc:
|
|
return ToolResult(success=False, output=f"计算错误: {exc}")
|
|
|
|
def _eval_node(self, node: ast.AST) -> float:
|
|
"""递归解析 AST 节点"""
|
|
match node:
|
|
case ast.Constant(value=v) if isinstance(v, (int, float)):
|
|
return v
|
|
case ast.BinOp(left=left, op=op, right=right):
|
|
fn = self._OPERATORS.get(type(op))
|
|
if fn is None:
|
|
raise ValueError(f"不支持的运算符: {type(op).__name__}")
|
|
return fn(self._eval_node(left), self._eval_node(right))
|
|
case ast.UnaryOp(op=op, operand=operand):
|
|
fn = self._OPERATORS.get(type(op))
|
|
if fn is None:
|
|
raise ValueError(f"不支持的一元运算符: {type(op).__name__}")
|
|
return fn(self._eval_node(operand))
|
|
case _:
|
|
raise ValueError(f"不支持的表达式节点: {type(node).__name__}") |