113 lines
3.5 KiB
Python
113 lines
3.5 KiB
Python
"""
|
|
tools/calculator.py
|
|
数学计算工具 —— 支持基本四则运算及常用数学函数
|
|
配置通过 settings.tools['calculator'] 读取
|
|
"""
|
|
|
|
import math
|
|
import operator
|
|
from typing import Any
|
|
|
|
from config.settings import settings
|
|
from tools.base_tool import BaseTool
|
|
from utils.logger import get_logger
|
|
|
|
logger = get_logger("TOOL.Calculator")
|
|
|
|
|
|
def _cfg(key: str, fallback=None):
|
|
return settings.tools['calculator'].get(key, fallback)
|
|
|
|
|
|
class Tool(BaseTool):
|
|
name = "calculator"
|
|
description = (
|
|
"执行数学计算,支持四则运算、幂运算、开方、三角函数、对数等。"
|
|
"输入数学表达式字符串,返回计算结果。"
|
|
)
|
|
parameters = {
|
|
"type": "object",
|
|
"properties": {
|
|
"expression": {
|
|
"type": "string",
|
|
"description": (
|
|
"数学表达式,例如: '2 + 3 * 4', 'sqrt(16)', "
|
|
"'sin(3.14159/2)', 'log(100, 10)', '2 ** 10'"
|
|
),
|
|
},
|
|
},
|
|
"required": ["expression"],
|
|
}
|
|
|
|
# 安全的内置函数白名单
|
|
_SAFE_GLOBALS: dict[str, Any] = {
|
|
"__builtins__": {},
|
|
# 基本数学
|
|
"abs": abs,
|
|
"round": round,
|
|
"pow": pow,
|
|
"min": min,
|
|
"max": max,
|
|
# math 模块常用函数
|
|
"sqrt": math.sqrt,
|
|
"ceil": math.ceil,
|
|
"floor": math.floor,
|
|
"log": math.log,
|
|
"log2": math.log2,
|
|
"log10": math.log10,
|
|
"exp": math.exp,
|
|
"sin": math.sin,
|
|
"cos": math.cos,
|
|
"tan": math.tan,
|
|
"asin": math.asin,
|
|
"acos": math.acos,
|
|
"atan": math.atan,
|
|
"atan2": math.atan2,
|
|
"pi": math.pi,
|
|
"e": math.e,
|
|
"inf": math.inf,
|
|
"factorial": math.factorial,
|
|
"gcd": math.gcd,
|
|
"hypot": math.hypot,
|
|
}
|
|
|
|
def execute(self, expression: str = "", **_) -> str:
|
|
if not expression or not expression.strip():
|
|
return "❌ 参数错误: expression 不能为空"
|
|
|
|
expr = expression.strip()
|
|
logger.info(f"🔢 计算表达式: {expr}")
|
|
|
|
# 安全检查:禁止危险关键字
|
|
forbidden = ["import", "exec", "eval", "open", "os", "sys",
|
|
"__", "compile", "globals", "locals", "getattr"]
|
|
for kw in forbidden:
|
|
if kw in expr:
|
|
return f"❌ 安全限制: 表达式包含禁止关键字 '{kw}'"
|
|
|
|
try:
|
|
precision = _cfg('precision', 10)
|
|
result = eval(expr, self._SAFE_GLOBALS, {}) # noqa: S307
|
|
|
|
# 格式化输出
|
|
if isinstance(result, float):
|
|
# 去除多余尾零
|
|
formatted = f"{result:.{precision}f}".rstrip("0").rstrip(".")
|
|
elif isinstance(result, complex):
|
|
formatted = str(result)
|
|
else:
|
|
formatted = str(result)
|
|
|
|
logger.info(f"✅ 计算结果: {expr} = {formatted}")
|
|
return f"{expr} = {formatted}"
|
|
|
|
except ZeroDivisionError:
|
|
return f"❌ 计算错误: 除零错误 表达式: {expr}"
|
|
except OverflowError:
|
|
return f"❌ 计算错误: 数值溢出 表达式: {expr}"
|
|
except ValueError as e:
|
|
return f"❌ 计算错误: {e} 表达式: {expr}"
|
|
except SyntaxError:
|
|
return f"❌ 语法错误: 无法解析表达式 '{expr}'"
|
|
except Exception as e:
|
|
return f"❌ 计算失败: {e} 表达式: {expr}" |