# Source: AIDeveloper-PC/requirements_generator/database/db_manager.py
# (158 lines, 6.4 KiB, Python)

# database/db_manager.py - 数据库 CRUD 操作封装
import os
from typing import List, Optional
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
import config
from database.models import Base, Project, RawRequirement, FunctionalRequirement, CodeFile, ChangeHistory
class DBManager:
    """SQLite database manager that wraps all CRUD operations.

    Every public method opens its own short-lived session and closes it
    before returning, so any ORM instance handed back to the caller is
    detached from a session. Column values loaded inside the method stay
    readable on it; relationships that were not loaded inside the method
    cannot be lazy-loaded afterwards.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Open (creating if needed) the SQLite database and its tables.

        Args:
            db_path: Path of the SQLite database file. When falsy,
                ``config.DB_PATH`` is used instead. A bare file name (no
                directory component) or ``":memory:"`` is also accepted.
        """
        db_path = db_path or config.DB_PATH
        # os.path.dirname() returns "" for a bare file name or ":memory:",
        # and os.makedirs("") raises FileNotFoundError, so only create the
        # parent directory when the path actually has one.
        parent_dir = os.path.dirname(db_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        self.engine = create_engine(f"sqlite:///{db_path}", echo=False)
        # Create any tables declared on Base that do not exist yet.
        Base.metadata.create_all(self.engine)
        self._Session = sessionmaker(bind=self.engine)

    def _session(self) -> Session:
        """Return a new session; callers use it in a ``with`` block so it gets closed."""
        return self._Session()

    # ── shared helpers ────────────────────────────────

    def _insert(self, obj: Base) -> int:
        """Add ``obj`` as a new row, commit, and return its ``id``."""
        with self._session() as s:
            s.add(obj)
            s.commit()
            # Reload the committed row so obj's attributes (including any
            # database-generated values) are populated before the session
            # closes and detaches it.
            s.refresh(obj)
            return obj.id

    def _merge(self, obj: Base) -> None:
        """Merge ``obj``'s state into the database and commit.

        ``Session.merge`` matches on primary key: a matching row is updated,
        otherwise a new row is inserted. ``obj`` itself is not attached to
        the session.
        """
        with self._session() as s:
            s.merge(obj)
            s.commit()

    def _delete(self, model: type, obj_id: int) -> None:
        """Delete the ``model`` row whose primary key is ``obj_id``; no-op if absent."""
        with self._session() as s:
            obj = s.get(model, obj_id)
            if obj is not None:
                s.delete(obj)
                s.commit()

    # ══════════════════════════════════════════════════
    # Project
    # ══════════════════════════════════════════════════

    def create_project(self, project: Project) -> int:
        """Insert ``project`` and return its new ``id``."""
        return self._insert(project)

    def get_project_by_name(self, name: str) -> Optional[Project]:
        """Return the first project whose ``name`` equals ``name``, or ``None``."""
        with self._session() as s:
            return s.query(Project).filter_by(name=name).first()

    def get_project_by_id(self, project_id: int) -> Optional[Project]:
        """Return the project with primary key ``project_id``, or ``None``."""
        with self._session() as s:
            return s.get(Project, project_id)

    def update_project(self, project: Project) -> None:
        """Persist ``project``'s state via ``Session.merge`` (matched by primary key)."""
        self._merge(project)

    def list_projects(self) -> List[Project]:
        """Return all projects ordered by ``created_at`` descending."""
        with self._session() as s:
            return s.query(Project).order_by(Project.created_at.desc()).all()

    def delete_project(self, project_id: int) -> None:
        """Delete the project with ``project_id``; does nothing if it does not exist.

        NOTE(review): whether dependent rows (requirements, code files,
        change history) are removed as well depends on relationship cascade
        settings in ``database.models``, which are not visible here. SQLite
        also ignores FOREIGN KEY constraints unless ``PRAGMA foreign_keys=ON``
        is issued per connection, and this engine does not issue it.
        """
        self._delete(Project, project_id)

    # ══════════════════════════════════════════════════
    # RawRequirement
    # ══════════════════════════════════════════════════

    def create_raw_requirement(self, raw_req: RawRequirement) -> int:
        """Insert ``raw_req`` and return its new ``id``."""
        return self._insert(raw_req)

    def get_raw_requirement(self, raw_req_id: int) -> Optional[RawRequirement]:
        """Return the raw requirement with primary key ``raw_req_id``, or ``None``."""
        with self._session() as s:
            return s.get(RawRequirement, raw_req_id)

    # ══════════════════════════════════════════════════
    # FunctionalRequirement
    # ══════════════════════════════════════════════════

    def create_functional_requirement(self, req: FunctionalRequirement) -> int:
        """Insert ``req`` and return its new ``id``."""
        return self._insert(req)

    def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]:
        """Return the functional requirement with primary key ``req_id``, or ``None``."""
        with self._session() as s:
            return s.get(FunctionalRequirement, req_id)

    def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]:
        """Return the project's functional requirements ordered by ``index_no`` ascending."""
        with self._session() as s:
            return (
                s.query(FunctionalRequirement)
                .filter_by(project_id=project_id)
                .order_by(FunctionalRequirement.index_no)
                .all()
            )

    def update_functional_requirement(self, req: FunctionalRequirement) -> None:
        """Persist ``req``'s state via ``Session.merge`` (matched by primary key)."""
        self._merge(req)

    def delete_functional_requirement(self, req_id: int) -> None:
        """Delete the functional requirement with ``req_id``; no-op if it does not exist."""
        self._delete(FunctionalRequirement, req_id)

    # ══════════════════════════════════════════════════
    # CodeFile
    # ══════════════════════════════════════════════════

    def upsert_code_file(self, code_file: CodeFile) -> int:
        """Insert or update the code file belonging to ``code_file.func_req_id``.

        If a CodeFile row already exists for the same ``func_req_id`` (the
        first one found), its file_name, file_path, module, language and
        content are overwritten from ``code_file`` and that row's ``id`` is
        returned; ``code_file`` itself is not persisted in that case.
        Otherwise ``code_file`` is inserted and its new ``id`` is returned.
        """
        with self._session() as s:
            existing = (
                s.query(CodeFile)
                .filter_by(func_req_id=code_file.func_req_id)
                .first()
            )
            if existing is None:
                s.add(code_file)
                s.commit()
                s.refresh(code_file)
                return code_file.id

            existing.file_name = code_file.file_name
            existing.file_path = code_file.file_path
            existing.module = code_file.module
            existing.language = code_file.language
            existing.content = code_file.content
            # Capture the id before committing: with the default
            # expire_on_commit, reading it afterwards would issue another SELECT.
            existing_id = existing.id
            s.commit()
            return existing_id

    def list_code_files(self, project_id: int) -> List[CodeFile]:
        """Return every code file whose functional requirement belongs to ``project_id``."""
        with self._session() as s:
            return (
                s.query(CodeFile)
                .join(FunctionalRequirement)
                .filter(FunctionalRequirement.project_id == project_id)
                .all()
            )

    # ══════════════════════════════════════════════════
    # ChangeHistory
    # ══════════════════════════════════════════════════

    def create_change_history(self, change: ChangeHistory) -> int:
        """Insert ``change`` and return its new ``id``."""
        return self._insert(change)

    def list_change_history(self, project_id: int) -> List[ChangeHistory]:
        """Return the project's change history ordered by ``change_time`` descending."""
        with self._session() as s:
            return (
                s.query(ChangeHistory)
                .filter_by(project_id=project_id)
                .order_by(ChangeHistory.change_time.desc())
                .all()
            )