base_agent/database/database.py

64 lines
2.2 KiB
Python

# database.py
from datetime import datetime, timezone

from sqlalchemy import create_engine, Column, Integer, String, DateTime, true
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from config.settings import settings
# Declarative base shared by the ORM models defined in this module.
Base = declarative_base()


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value ``datetime.utcnow()`` did; that function is
    deprecated since Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class SkillUpload(Base):
    """ORM model for one uploaded skill, stored in the ``skill_uploads`` table."""

    __tablename__ = 'skill_uploads'

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Skill name; must be unique across all records.
    name = Column(String, nullable=False, unique=True)
    # Optional free-text description.
    description = Column(String, nullable=True)
    # Identifier of the uploaded skill; must also be unique.
    skill_id = Column(String, nullable=False, unique=True)
    # Naive UTC timestamp, filled in on insert when not supplied explicitly.
    upload_time = Column(DateTime, default=_utcnow)

    def __repr__(self):
        return f"<SkillUpload(name='{self.name}', upload_time='{self.upload_time}')>"
class DatabaseManager:
    """Database manager that stores and queries skill upload records.

    Each public method opens its own short-lived session from ``self.Session``
    and always closes it, even when the database operation raises.
    """

    def __init__(self):
        self.engine = create_engine(settings.database.url)
        # Create any tables registered on ``Base`` that do not exist yet.
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)

    def add_skill_upload(self, skill_name: str, skill_id: str,
                         description: str | None = None) -> None:
        """Insert a skill upload record and commit it.

        Args:
            skill_name: Skill name (unique column).
            skill_id: Skill identifier (unique column).
            description: Optional description; the column is nullable.

        Raises:
            Any error raised by the commit is re-raised after the transaction
            is rolled back, e.g. ``sqlalchemy.exc.IntegrityError`` when
            ``skill_name`` or ``skill_id`` violates a unique constraint.
        """
        session = self.Session()
        try:
            session.add(SkillUpload(name=skill_name,
                                    skill_id=skill_id,
                                    description=description))
            session.commit()
        except Exception:
            # Discard the failed transaction before the session is closed.
            session.rollback()
            raise
        finally:
            session.close()

    def get_skill_by_name(self, skill_name: str) -> SkillUpload | None:
        """Return the record whose name equals ``skill_name``, or ``None``.

        The session is closed before returning, so the instance is detached.
        """
        session = self.Session()
        try:
            return (session.query(SkillUpload)
                    .filter(SkillUpload.name == skill_name)
                    .first())
        finally:
            session.close()

    def get_all_skills(self) -> list[SkillUpload]:
        """Return every skill upload record (detached; the session is closed)."""
        session = self.Session()
        try:
            return session.query(SkillUpload).all()
        finally:
            session.close()

    def skill_exists(self, skill_name: str) -> bool:
        """Return ``True`` if a record named ``skill_name`` has been stored."""
        session = self.Session()
        try:
            # Fetch at most one primary key instead of COUNT(*) over matches.
            match = (session.query(SkillUpload.id)
                     .filter(SkillUpload.name == skill_name)
                     .first())
            return match is not None
        finally:
            session.close()


# Shared instance created at import time: importing this module builds the
# engine from ``settings.database.url`` and creates any missing tables.
db_manager: DatabaseManager = DatabaseManager()