# database/db_manager.py - Database access manager
import sqlite3
import os
from datetime import datetime
from typing import List, Optional
from contextlib import contextmanager

from database.models import (
    CREATE_TABLES_SQL, Project, RawRequirement, FunctionalRequirement, CodeFile
)
import config


class DBManager:
    """SQLite database manager that wraps all CRUD operations.

    Each public method opens its own short-lived connection through
    ``_get_conn``, so every call is committed (or rolled back) on its own.
    """

    def __init__(self, db_path: str = config.DB_PATH):
        """Prepare the database file at *db_path* and create all tables.

        Args:
            db_path: Filesystem path of the SQLite database file. Defaults to
                ``config.DB_PATH``.
        """
        self.db_path = db_path
        db_dir = os.path.dirname(db_path)
        # A bare filename such as "app.db" has no directory component and
        # os.makedirs("") raises FileNotFoundError, so only create real dirs.
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self._init_db()

    # ── Connection context manager ─────────────────────────

    @contextmanager
    def _get_conn(self):
        """Yield a connection that commits on success and rolls back on error.

        Rows are returned as ``sqlite3.Row`` so columns can be read by name.
        Foreign-key enforcement is switched on for every connection because
        SQLite leaves it off by default.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA foreign_keys = ON")
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def _init_db(self):
        """Create all tables by running ``CREATE_TABLES_SQL`` as a script."""
        with self._get_conn() as conn:
            conn.executescript(CREATE_TABLES_SQL)

    # ══════════════════════════════════════════════════
    # Project CRUD
    # ══════════════════════════════════════════════════

    def create_project(self, project: Project) -> int:
        """Insert *project* and return the id of the new row.

        ``project.id`` is not modified; callers use the returned id.
        """
        sql = """
            INSERT INTO projects
                (name, description, language, output_dir, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """
        with self._get_conn() as conn:
            cur = conn.execute(sql, (
                project.name, project.description, project.language,
                project.output_dir, project.created_at, project.updated_at,
            ))
            return cur.lastrowid

    def get_project_by_id(self, project_id: int) -> Optional[Project]:
        """Return the project with the given id, or ``None`` if absent."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM projects WHERE id = ?", (project_id,)
            ).fetchone()
        return None if row is None else self._row_to_project(row)

    def get_project_by_name(self, name: str) -> Optional[Project]:
        """Return the first project whose name equals *name*, or ``None``."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM projects WHERE name = ?", (name,)
            ).fetchone()
        return None if row is None else self._row_to_project(row)

    def list_projects(self) -> List[Project]:
        """Return all projects, newest ``created_at`` first."""
        with self._get_conn() as conn:
            rows = conn.execute(
                "SELECT * FROM projects ORDER BY created_at DESC"
            ).fetchall()
        return [self._row_to_project(r) for r in rows]

    def update_project(self, project: Project) -> None:
        """Persist *project*'s editable fields, matched by ``project.id``.

        Side effect: ``project.updated_at`` is set to the current local time
        (ISO 8601 string) before writing.
        """
        project.updated_at = datetime.now().isoformat()
        sql = """
            UPDATE projects
            SET name=?, description=?, language=?, output_dir=?, updated_at=?
            WHERE id=?
        """
        with self._get_conn() as conn:
            conn.execute(sql, (
                project.name, project.description, project.language,
                project.output_dir, project.updated_at, project.id,
            ))

    def delete_project(self, project_id: int) -> None:
        """Delete the project with the given id.

        Removing related rows is intended to happen through foreign-key
        cascades (enforcement is enabled in ``_get_conn``).
        NOTE(review): this assumes ``CREATE_TABLES_SQL`` declares
        ``ON DELETE CASCADE`` on the child tables; confirm in the schema.
        """
        with self._get_conn() as conn:
            conn.execute("DELETE FROM projects WHERE id = ?", (project_id,))

    def _row_to_project(self, row) -> Project:
        """Convert a ``sqlite3.Row`` from ``projects`` into a ``Project``."""
        return Project(
            id=row["id"],
            name=row["name"],
            description=row["description"],
            language=row["language"],
            output_dir=row["output_dir"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    # ══════════════════════════════════════════════════
    # RawRequirement CRUD
    # ══════════════════════════════════════════════════

    def create_raw_requirement(self, req: RawRequirement) -> int:
        """Insert a raw requirement and return the id of the new row."""
        sql = """
            INSERT INTO raw_requirements
                (project_id, content, source_type, source_name, knowledge, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """
        with self._get_conn() as conn:
            cur = conn.execute(sql, (
                req.project_id, req.content, req.source_type,
                req.source_name, req.knowledge, req.created_at,
            ))
            return cur.lastrowid

    def get_raw_requirement(self, req_id: int) -> Optional[RawRequirement]:
        """Return the raw requirement with the given id, or ``None``."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM raw_requirements WHERE id = ?", (req_id,)
            ).fetchone()
        return None if row is None else self._row_to_raw_req(row)

    def list_raw_requirements_by_project(self, project_id: int) -> List[RawRequirement]:
        """Return all raw requirements of a project, oldest first."""
        with self._get_conn() as conn:
            rows = conn.execute(
                "SELECT * FROM raw_requirements WHERE project_id = ? ORDER BY created_at",
                (project_id,),
            ).fetchall()
        return [self._row_to_raw_req(r) for r in rows]

    def _row_to_raw_req(self, row) -> RawRequirement:
        """Convert a ``sqlite3.Row`` from ``raw_requirements`` into a model."""
        return RawRequirement(
            id=row["id"],
            project_id=row["project_id"],
            content=row["content"],
            source_type=row["source_type"],
            source_name=row["source_name"],
            knowledge=row["knowledge"],
            created_at=row["created_at"],
        )

    # ══════════════════════════════════════════════════
    # FunctionalRequirement CRUD
    # ══════════════════════════════════════════════════

    def create_functional_requirement(self, req: FunctionalRequirement) -> int:
        """Insert a functional requirement and return the id of the new row.

        ``is_custom`` is stored as 0/1 because SQLite has no boolean type.
        """
        sql = """
            INSERT INTO functional_requirements
                (project_id, raw_req_id, index_no, title, description,
                 function_name, priority, status, is_custom, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        with self._get_conn() as conn:
            cur = conn.execute(sql, (
                req.project_id, req.raw_req_id, req.index_no, req.title,
                req.description, req.function_name, req.priority, req.status,
                int(req.is_custom), req.created_at, req.updated_at,
            ))
            return cur.lastrowid

    def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]:
        """Return the functional requirement with the given id, or ``None``."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM functional_requirements WHERE id = ?", (req_id,)
            ).fetchone()
        return None if row is None else self._row_to_func_req(row)

    def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]:
        """Return all functional requirements of a project ordered by ``index_no``."""
        with self._get_conn() as conn:
            rows = conn.execute(
                """SELECT * FROM functional_requirements
                   WHERE project_id = ? ORDER BY index_no""",
                (project_id,),
            ).fetchall()
        return [self._row_to_func_req(r) for r in rows]

    def update_functional_requirement(self, req: FunctionalRequirement) -> None:
        """Persist *req*'s editable fields, matched by ``req.id``.

        Side effect: ``req.updated_at`` is set to the current local time
        (ISO 8601 string) before writing. ``project_id``, ``raw_req_id`` and
        ``is_custom`` are not updated.
        """
        req.updated_at = datetime.now().isoformat()
        sql = """
            UPDATE functional_requirements
            SET title=?, description=?, function_name=?, priority=?,
                status=?, index_no=?, updated_at=?
            WHERE id=?
        """
        with self._get_conn() as conn:
            conn.execute(sql, (
                req.title, req.description, req.function_name, req.priority,
                req.status, req.index_no, req.updated_at, req.id,
            ))

    def delete_functional_requirement(self, req_id: int) -> None:
        """Delete the functional requirement with the given id."""
        with self._get_conn() as conn:
            conn.execute(
                "DELETE FROM functional_requirements WHERE id = ?", (req_id,)
            )

    def _row_to_func_req(self, row) -> FunctionalRequirement:
        """Convert a ``sqlite3.Row`` into a ``FunctionalRequirement``.

        ``is_custom`` is stored as an integer and converted back to ``bool``.
        """
        return FunctionalRequirement(
            id=row["id"],
            project_id=row["project_id"],
            raw_req_id=row["raw_req_id"],
            index_no=row["index_no"],
            title=row["title"],
            description=row["description"],
            function_name=row["function_name"],
            priority=row["priority"],
            status=row["status"],
            is_custom=bool(row["is_custom"]),
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    # ══════════════════════════════════════════════════
    # CodeFile CRUD
    # ══════════════════════════════════════════════════

    def create_code_file(self, code_file: CodeFile) -> int:
        """Insert a code-file record and return the id of the new row."""
        with self._get_conn() as conn:
            return self._insert_code_file(conn, code_file)

    def _insert_code_file(self, conn, code_file: CodeFile) -> int:
        """Insert *code_file* on an already-open connection; return its row id."""
        sql = """
            INSERT INTO code_files
                (project_id, func_req_id, file_name, file_path, language,
                 content, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """
        cur = conn.execute(sql, (
            code_file.project_id, code_file.func_req_id, code_file.file_name,
            code_file.file_path, code_file.language, code_file.content,
            code_file.created_at, code_file.updated_at,
        ))
        return cur.lastrowid

    def upsert_code_file(self, code_file: CodeFile) -> int:
        """Insert or update the code file keyed by ``code_file.func_req_id``.

        If a row with the same ``func_req_id`` exists, its file_name,
        file_path, language and content are overwritten. In that case
        ``code_file.id`` is set to the existing row id and
        ``code_file.updated_at`` is refreshed. Otherwise a new row is inserted
        and ``code_file.id`` is left untouched.

        Returns:
            The id of the updated or newly inserted row.
        """
        # Lookup and write share one connection and commit, instead of opening
        # separate connections for the existence check and the write.
        # NOTE(review): the SELECT does not take a write lock, so concurrent
        # writers could still both insert. A UNIQUE constraint on func_req_id
        # plus INSERT ... ON CONFLICT would close that gap.
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT id FROM code_files WHERE func_req_id = ?",
                (code_file.func_req_id,),
            ).fetchone()
            if row is None:
                return self._insert_code_file(conn, code_file)

            code_file.id = row["id"]
            code_file.updated_at = datetime.now().isoformat()
            conn.execute(
                """
                UPDATE code_files
                SET file_name=?, file_path=?, language=?, content=?, updated_at=?
                WHERE id=?
                """,
                (
                    code_file.file_name, code_file.file_path, code_file.language,
                    code_file.content, code_file.updated_at, code_file.id,
                ),
            )
            return code_file.id

    def get_code_file_by_func_req(self, func_req_id: int) -> Optional[CodeFile]:
        """Return the code file linked to a functional requirement, or ``None``."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM code_files WHERE func_req_id = ?", (func_req_id,)
            ).fetchone()
        return None if row is None else self._row_to_code_file(row)

    def list_code_files_by_project(self, project_id: int) -> List[CodeFile]:
        """Return all code files of a project, oldest first."""
        with self._get_conn() as conn:
            rows = conn.execute(
                "SELECT * FROM code_files WHERE project_id = ? ORDER BY created_at",
                (project_id,),
            ).fetchall()
        return [self._row_to_code_file(r) for r in rows]

    def _row_to_code_file(self, row) -> CodeFile:
        """Convert a ``sqlite3.Row`` from ``code_files`` into a ``CodeFile``."""
        return CodeFile(
            id=row["id"],
            project_id=row["project_id"],
            func_req_id=row["func_req_id"],
            file_name=row["file_name"],
            file_path=row["file_path"],
            language=row["language"],
            content=row["content"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )