# database/db_manager.py - CRUD wrapper around the SQLite database
import os
import shutil
from typing import List, Optional, Dict, Any

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session

import config
from database.models import Base, Project, RawRequirement, FunctionalRequirement, CodeFile


class DBManager:
    """SQLite database manager that wraps all CRUD operations.

    Every public method opens a short-lived session via ``_session()`` and
    closes it when its ``with`` block exits. ORM objects returned by these
    methods are therefore detached from any session: attributes loaded by the
    query are readable, but relationships that were not loaded may raise
    ``DetachedInstanceError`` when accessed.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Open the SQLite database, creating its directory and tables if needed.

        Args:
            db_path: Path to the SQLite file. Any falsy value (``None``, ``""``)
                falls back to ``config.DB_PATH``.
        """
        db_path = db_path or config.DB_PATH
        db_dir = os.path.dirname(db_path)
        # A bare file name such as "app.db" (or ":memory:") has no directory
        # component, and os.makedirs("") raises FileNotFoundError.
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self.engine = create_engine(f"sqlite:///{db_path}", echo=False)
        # Create any tables declared on Base that do not exist yet.
        Base.metadata.create_all(self.engine)
        self._Session = sessionmaker(bind=self.engine)

    def _session(self) -> Session:
        """Return a new Session bound to this manager's engine."""
        return self._Session()

    # ══════════════════════════════════════════════════
    # Project
    # ══════════════════════════════════════════════════

    def create_project(self, project: Project) -> int:
        """Insert ``project`` and return its newly assigned primary key."""
        with self._session() as s:
            s.add(project)
            s.commit()
            # Reload all attributes after commit so the generated id is
            # populated on the caller's object before the session closes.
            s.refresh(project)
            return project.id

    def get_project_by_name(self, name: str) -> Optional[Project]:
        """Return the first project whose ``name`` matches, or ``None``."""
        with self._session() as s:
            return s.query(Project).filter_by(name=name).first()

    def get_project_by_id(self, project_id: int) -> Optional[Project]:
        """Return the project with primary key ``project_id``, or ``None``."""
        with self._session() as s:
            return s.get(Project, project_id)

    def update_project(self, project: Project) -> None:
        """Persist the state of a (possibly detached) ``project`` via ``merge``."""
        with self._session() as s:
            s.merge(project)
            s.commit()

    def list_projects(self) -> List[Project]:
        """Return all projects, newest first (``created_at`` descending)."""
        with self._session() as s:
            return s.query(Project).order_by(Project.created_at.desc()).all()

    def delete_project(self, project_id: int, delete_output: bool = False) -> bool:
        """Delete a project and, optionally, its output directory on disk.

        Related rows are removed together with the project only if the
        relationships in ``database.models`` are configured to cascade deletes.
        NOTE(review): that configuration is not visible here; confirm it.

        Args:
            project_id: ID of the project to delete.
            delete_output: If True, also remove the project's ``output_dir``
                from disk (errors during removal are ignored).

        Returns:
            True if the project was deleted, False if it does not exist.
        """
        with self._session() as s:
            project = s.get(Project, project_id)
            if project is None:
                return False
            # Read the path before deleting: once the delete is committed the
            # instance's attributes can no longer be loaded.
            output_dir = project.output_dir
            s.delete(project)
            s.commit()

        # Touch the filesystem only after the DB delete has committed, so a
        # failed commit never removes files.
        if delete_output and output_dir and os.path.isdir(output_dir):
            shutil.rmtree(output_dir, ignore_errors=True)
        return True

    def get_project_stats(self, project_id: int) -> Dict[str, int]:
        """Return summary counts for a project.

        Returns:
            ``{"raw_req_count": n, "func_req_count": n, "generated_count": n,
            "module_count": n, "code_file_count": n}`` where
            ``generated_count`` counts functional requirements whose status is
            ``"generated"`` and ``module_count`` counts distinct modules
            (an empty/``None`` module counts as ``config.DEFAULT_MODULE``).
        """
        with self._session() as s:
            raw_count = s.query(RawRequirement).filter_by(project_id=project_id).count()
            func_reqs = (
                s.query(FunctionalRequirement)
                .filter_by(project_id=project_id)
                .all()
            )
            code_count = (
                s.query(CodeFile)
                .join(FunctionalRequirement)
                .filter(FunctionalRequirement.project_id == project_id)
                .count()
            )
            return self._build_stats(raw_count, func_reqs, code_count)

    @staticmethod
    def _build_stats(
        raw_req_count: int,
        func_reqs: List[FunctionalRequirement],
        code_file_count: int,
    ) -> Dict[str, int]:
        """Assemble the stats dict shared by get_project_stats and get_project_full_info."""
        modules = {r.module or config.DEFAULT_MODULE for r in func_reqs}
        return {
            "raw_req_count": raw_req_count,
            "func_req_count": len(func_reqs),
            "generated_count": sum(1 for r in func_reqs if r.status == "generated"),
            "module_count": len(modules),
            "code_file_count": code_file_count,
        }

    # ══════════════════════════════════════════════════
    # Project Full Info (requirement - module - code relationships)
    # ══════════════════════════════════════════════════

    def get_project_full_info(self, project_id: int) -> Optional[Dict[str, Any]]:
        """Return a project together with its requirement/module/code tree.

        Returns::

            {
                "project": Project,
                "stats": {...},            # same shape as get_project_stats()
                "modules": {
                    "<module name>": {
                        "requirements": [
                            {
                                "req": FunctionalRequirement,
                                "code_files": [CodeFile, ...]
                            },
                            ...
                        ]
                    },
                    ...
                },
                "raw_requirements": [RawRequirement, ...]
            }

        Raw requirements are ordered by ``created_at`` and functional
        requirements by ``index_no``. Requirements with an empty/``None``
        module are grouped under ``config.DEFAULT_MODULE``.
        Returns ``None`` if the project does not exist.
        """
        with self._session() as s:
            project = s.get(Project, project_id)
            if project is None:
                return None

            raw_reqs = (
                s.query(RawRequirement)
                .filter_by(project_id=project_id)
                .order_by(RawRequirement.created_at)
                .all()
            )
            func_reqs = (
                s.query(FunctionalRequirement)
                .filter_by(project_id=project_id)
                .order_by(FunctionalRequirement.index_no)
                .all()
            )
            code_files = (
                s.query(CodeFile)
                .join(FunctionalRequirement)
                .filter(FunctionalRequirement.project_id == project_id)
                .all()
            )

        # Map func_req_id -> [CodeFile, ...]
        code_map: Dict[int, List[CodeFile]] = {}
        for cf in code_files:
            code_map.setdefault(cf.func_req_id, []).append(cf)

        # Group requirements by module, preserving index_no order.
        modules: Dict[str, Dict[str, Any]] = {}
        for req in func_reqs:
            mod = req.module or config.DEFAULT_MODULE
            modules.setdefault(mod, {"requirements": []})["requirements"].append({
                "req": req,
                "code_files": code_map.get(req.id, []),
            })

        # Derive stats from the rows loaded above instead of re-querying in a
        # second session, so the counts match the returned snapshot exactly.
        stats = self._build_stats(len(raw_reqs), func_reqs, len(code_files))

        return {
            "project": project,
            "stats": stats,
            "modules": modules,
            "raw_requirements": raw_reqs,
        }

    # ══════════════════════════════════════════════════
    # RawRequirement
    # ══════════════════════════════════════════════════

    def create_raw_requirement(self, raw_req: RawRequirement) -> int:
        """Insert ``raw_req`` and return its newly assigned primary key."""
        with self._session() as s:
            s.add(raw_req)
            s.commit()
            s.refresh(raw_req)
            return raw_req.id

    def get_raw_requirement(self, raw_req_id: int) -> Optional[RawRequirement]:
        """Return the raw requirement with id ``raw_req_id``, or ``None``."""
        with self._session() as s:
            return s.get(RawRequirement, raw_req_id)

    # ══════════════════════════════════════════════════
    # FunctionalRequirement
    # ══════════════════════════════════════════════════

    def create_functional_requirement(self, req: FunctionalRequirement) -> int:
        """Insert ``req`` and return its newly assigned primary key."""
        with self._session() as s:
            s.add(req)
            s.commit()
            s.refresh(req)
            return req.id

    def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]:
        """Return the functional requirement with id ``req_id``, or ``None``."""
        with self._session() as s:
            return s.get(FunctionalRequirement, req_id)

    def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]:
        """Return a project's functional requirements ordered by ``index_no``."""
        with self._session() as s:
            return (
                s.query(FunctionalRequirement)
                .filter_by(project_id=project_id)
                .order_by(FunctionalRequirement.index_no)
                .all()
            )

    def update_functional_requirement(self, req: FunctionalRequirement) -> None:
        """Persist the state of a (possibly detached) ``req`` via ``merge``."""
        with self._session() as s:
            s.merge(req)
            s.commit()

    def delete_functional_requirement(self, req_id: int) -> None:
        """Delete the functional requirement ``req_id``; no-op if it does not exist."""
        with self._session() as s:
            obj = s.get(FunctionalRequirement, req_id)
            if obj:
                s.delete(obj)
                s.commit()

    def bulk_update_modules(
        self,
        updates: List[Dict[str, Any]],
        *,
        project_id: Optional[int] = None,
    ) -> None:
        """Bulk-update the ``module`` field of functional requirements.

        Args:
            updates: ``[{"function_name": "...", "module": "..."}, ...]``.
                If a function name appears more than once, the last entry wins.
            project_id: If given, only requirements of this project are
                updated. Without it, every requirement with a matching
                ``function_name`` is updated, across ALL projects.
        """
        name_to_module = {u["function_name"]: u["module"] for u in updates}
        if not name_to_module:
            return  # nothing to update; avoid an empty IN () query

        with self._session() as s:
            query = s.query(FunctionalRequirement).filter(
                FunctionalRequirement.function_name.in_(list(name_to_module))
            )
            if project_id is not None:
                query = query.filter_by(project_id=project_id)
            for req in query.all():
                # The fallback covers DB-side matches whose exact spelling is
                # not a dict key (e.g. a case-insensitive column collation).
                req.module = name_to_module.get(req.function_name, config.DEFAULT_MODULE)
            s.commit()

    # ══════════════════════════════════════════════════
    # CodeFile
    # ══════════════════════════════════════════════════

    def upsert_code_file(self, code_file: CodeFile) -> int:
        """Insert or update the code file for ``code_file.func_req_id``.

        If a CodeFile with the same ``func_req_id`` exists, the first one found
        has its ``file_name``, ``file_path``, ``module``, ``language`` and
        ``content`` overwritten (``code_file`` itself is not added).
        Otherwise ``code_file`` is inserted.

        Returns:
            The id of the row that was updated or inserted.
        """
        with self._session() as s:
            existing = (
                s.query(CodeFile)
                .filter_by(func_req_id=code_file.func_req_id)
                .first()
            )
            if existing:
                existing.file_name = code_file.file_name
                existing.file_path = code_file.file_path
                existing.module = code_file.module
                existing.language = code_file.language
                existing.content = code_file.content
                s.commit()
                # Session is still open, so the expired id reloads here.
                return existing.id

            s.add(code_file)
            s.commit()
            s.refresh(code_file)
            return code_file.id

    def list_code_files(self, project_id: int) -> List[CodeFile]:
        """Return all code files attached to the project's functional requirements."""
        with self._session() as s:
            return (
                s.query(CodeFile)
                .join(FunctionalRequirement)
                .filter(FunctionalRequirement.project_id == project_id)
                .all()
            )