31 lines
869 B
Python
31 lines
869 B
Python
|
|
import inspect
|
||
|
|
from typing import Dict, Type, List
|
||
|
|
|
||
|
|
from app.tools.base import BaseTool
|
||
|
|
|
||
|
|
|
||
|
|
class ToolRegistry:
    """Class-level registry that maps tool class names to concrete tool classes."""

    # Shared mapping held on the class itself: keys are each registered
    # class's ``__name__``, values are the classes. Registering another class
    # under an existing name replaces the previous entry.
    _tools: Dict[str, Type['BaseTool']] = {}

    @classmethod
    def register(cls, tool_class: Type['BaseTool']) -> Type['BaseTool']:
        """Record *tool_class* in the registry, keyed by its ``__name__``.

        Classes that ``inspect.isabstract`` reports as abstract are not
        recorded. The argument is always returned unchanged, which lets this
        method be applied as a class decorator.
        """
        if inspect.isabstract(tool_class):
            # Abstract bases could not be instantiated by get_instances().
            return tool_class
        cls._tools[tool_class.__name__] = tool_class
        return tool_class

    @classmethod
    def get_all(cls) -> Dict[str, Type['BaseTool']]:
        """Return a shallow copy of the name -> tool-class mapping.

        Mutating the returned dict leaves the registry itself untouched.
        """
        return dict(cls._tools)

    @classmethod
    def get_instances(cls, **kwargs) -> List['BaseTool']:
        """Construct one instance of every registered tool class.

        The identical keyword arguments are passed to each constructor, and
        the instances come back in the registry's insertion order.
        """
        factories = cls._tools.values()
        return [factory(**kwargs) for factory in factories]

    @classmethod
    def clear(cls) -> None:
        """Remove every registered tool; intended mainly for tests."""
        # Empty the existing dict in place rather than rebinding the attribute.
        cls._tools.clear()
|