AIDeveloper-PC/requirements_generator/database/db_manager.py

314 lines
13 KiB
Python
Raw Normal View History

2026-03-04 18:09:45 +00:00
# database/db_manager.py - 数据库操作管理器
import sqlite3
import os
from datetime import datetime
from typing import List, Optional
from contextlib import contextmanager
from database.models import (
CREATE_TABLES_SQL, Project, RawRequirement,
FunctionalRequirement, CodeFile
)
import config
class DBManager:
    """SQLite database manager that wraps all CRUD operations.

    Every public method opens its own short-lived connection through
    ``_get_conn``. That connection commits on success, rolls back when an
    exception escapes, and is always closed.
    """

    def __init__(self, db_path: str = config.DB_PATH):
        """Create the manager and initialise the schema.

        Args:
            db_path: Path of the SQLite database file. Its parent directory is
                created when it does not exist yet.
        """
        self.db_path = db_path
        # os.path.dirname() returns "" for a bare file name such as "app.db",
        # and os.makedirs("") raises FileNotFoundError, so only create a
        # directory when the path actually contains one.
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self._init_db()

    # ── Connection context manager ──────────────────────────────────
    @contextmanager
    def _get_conn(self):
        """Yield a connection that commits on success and rolls back on error.

        Rows come back as ``sqlite3.Row`` so columns can be read by name, and
        foreign-key enforcement is enabled for this connection (SQLite keeps
        that setting per connection and defaults it to off).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            # Configure inside the try so the connection is closed even if
            # the PRAGMA itself fails.
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA foreign_keys = ON")
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def _init_db(self):
        """Create all tables by executing ``CREATE_TABLES_SQL`` as a script."""
        with self._get_conn() as conn:
            conn.executescript(CREATE_TABLES_SQL)

    # ══════════════════════════════════════════════════
    # Project CRUD
    # ══════════════════════════════════════════════════
    def create_project(self, project: Project) -> int:
        """Insert *project* and return the new row ID.

        ``project.id`` is not read; the database assigns the ID.
        """
        sql = """
            INSERT INTO projects (name, description, language, output_dir, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """
        with self._get_conn() as conn:
            cur = conn.execute(sql, (
                project.name, project.description, project.language,
                project.output_dir, project.created_at, project.updated_at,
            ))
            return cur.lastrowid

    def get_project_by_id(self, project_id: int) -> Optional[Project]:
        """Return the project with *project_id*, or ``None`` if absent."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM projects WHERE id = ?", (project_id,)
            ).fetchone()
        return None if row is None else self._row_to_project(row)

    def get_project_by_name(self, name: str) -> Optional[Project]:
        """Return the first project whose name equals *name*, or ``None``."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM projects WHERE name = ?", (name,)
            ).fetchone()
        return None if row is None else self._row_to_project(row)

    def list_projects(self) -> List[Project]:
        """Return all projects, newest ``created_at`` first."""
        with self._get_conn() as conn:
            rows = conn.execute(
                "SELECT * FROM projects ORDER BY created_at DESC"
            ).fetchall()
        return [self._row_to_project(r) for r in rows]

    def update_project(self, project: Project) -> None:
        """Persist the editable fields of *project* (matched by ``project.id``).

        Side effect: ``project.updated_at`` is overwritten with the current
        local time in ISO-8601 format before the write.
        """
        project.updated_at = datetime.now().isoformat()
        sql = """
            UPDATE projects
            SET name=?, description=?, language=?, output_dir=?, updated_at=?
            WHERE id=?
        """
        with self._get_conn() as conn:
            conn.execute(sql, (
                project.name, project.description, project.language,
                project.output_dir, project.updated_at, project.id,
            ))

    def delete_project(self, project_id: int) -> None:
        """Delete the project row with *project_id*.

        NOTE(review): the original docstring states that all related data is
        cascade-deleted. Foreign keys are enabled in ``_get_conn``, so that
        holds only if ``CREATE_TABLES_SQL`` declares ``ON DELETE CASCADE``;
        confirm against the schema.
        """
        with self._get_conn() as conn:
            conn.execute("DELETE FROM projects WHERE id = ?", (project_id,))

    def _row_to_project(self, row) -> Project:
        """Convert a ``projects`` row into a ``Project`` object."""
        return Project(
            id=row["id"], name=row["name"], description=row["description"],
            language=row["language"], output_dir=row["output_dir"],
            created_at=row["created_at"], updated_at=row["updated_at"],
        )

    # ══════════════════════════════════════════════════
    # RawRequirement CRUD
    # ══════════════════════════════════════════════════
    def create_raw_requirement(self, req: RawRequirement) -> int:
        """Insert a raw requirement and return the new row ID."""
        sql = """
            INSERT INTO raw_requirements
            (project_id, content, source_type, source_name, knowledge, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """
        with self._get_conn() as conn:
            cur = conn.execute(sql, (
                req.project_id, req.content, req.source_type,
                req.source_name, req.knowledge, req.created_at,
            ))
            return cur.lastrowid

    def get_raw_requirement(self, req_id: int) -> Optional[RawRequirement]:
        """Return the raw requirement with *req_id*, or ``None`` if absent."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM raw_requirements WHERE id = ?", (req_id,)
            ).fetchone()
        return None if row is None else self._row_to_raw_req(row)

    def list_raw_requirements_by_project(self, project_id: int) -> List[RawRequirement]:
        """Return all raw requirements of *project_id*, oldest first."""
        with self._get_conn() as conn:
            rows = conn.execute(
                "SELECT * FROM raw_requirements WHERE project_id = ? ORDER BY created_at",
                (project_id,),
            ).fetchall()
        return [self._row_to_raw_req(r) for r in rows]

    def _row_to_raw_req(self, row) -> RawRequirement:
        """Convert a ``raw_requirements`` row into a ``RawRequirement`` object."""
        return RawRequirement(
            id=row["id"], project_id=row["project_id"], content=row["content"],
            source_type=row["source_type"], source_name=row["source_name"],
            knowledge=row["knowledge"], created_at=row["created_at"],
        )

    # ══════════════════════════════════════════════════
    # FunctionalRequirement CRUD
    # ══════════════════════════════════════════════════
    def create_functional_requirement(self, req: FunctionalRequirement) -> int:
        """Insert a functional requirement and return the new row ID.

        ``req.is_custom`` is stored as an integer (0/1).
        """
        sql = """
            INSERT INTO functional_requirements
            (project_id, raw_req_id, index_no, title, description,
             function_name, priority, status, is_custom, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        with self._get_conn() as conn:
            cur = conn.execute(sql, (
                req.project_id, req.raw_req_id, req.index_no,
                req.title, req.description, req.function_name,
                req.priority, req.status, int(req.is_custom),
                req.created_at, req.updated_at,
            ))
            return cur.lastrowid

    def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]:
        """Return the functional requirement with *req_id*, or ``None``."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM functional_requirements WHERE id = ?", (req_id,)
            ).fetchone()
        return None if row is None else self._row_to_func_req(row)

    def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]:
        """Return all functional requirements of *project_id*, by ``index_no``."""
        with self._get_conn() as conn:
            rows = conn.execute(
                """SELECT * FROM functional_requirements
                   WHERE project_id = ? ORDER BY index_no""",
                (project_id,),
            ).fetchall()
        return [self._row_to_func_req(r) for r in rows]

    def update_functional_requirement(self, req: FunctionalRequirement) -> None:
        """Persist the editable fields of *req* (matched by ``req.id``).

        Only title, description, function_name, priority, status, index_no and
        updated_at are written; project_id, raw_req_id and is_custom are left
        unchanged. Side effect: ``req.updated_at`` is set to the current local
        time in ISO-8601 format.
        """
        req.updated_at = datetime.now().isoformat()
        sql = """
            UPDATE functional_requirements
            SET title=?, description=?, function_name=?, priority=?,
                status=?, index_no=?, updated_at=?
            WHERE id=?
        """
        with self._get_conn() as conn:
            conn.execute(sql, (
                req.title, req.description, req.function_name,
                req.priority, req.status, req.index_no,
                req.updated_at, req.id,
            ))

    def delete_functional_requirement(self, req_id: int) -> None:
        """Delete the functional requirement with *req_id*."""
        with self._get_conn() as conn:
            conn.execute(
                "DELETE FROM functional_requirements WHERE id = ?", (req_id,)
            )

    def _row_to_func_req(self, row) -> FunctionalRequirement:
        """Convert a ``functional_requirements`` row into an object.

        ``is_custom`` is stored as an integer and converted back to ``bool``.
        """
        return FunctionalRequirement(
            id=row["id"], project_id=row["project_id"],
            raw_req_id=row["raw_req_id"], index_no=row["index_no"],
            title=row["title"], description=row["description"],
            function_name=row["function_name"], priority=row["priority"],
            status=row["status"], is_custom=bool(row["is_custom"]),
            created_at=row["created_at"], updated_at=row["updated_at"],
        )

    # ══════════════════════════════════════════════════
    # CodeFile CRUD
    # ══════════════════════════════════════════════════
    def create_code_file(self, code_file: CodeFile) -> int:
        """Insert a code-file record and return the new row ID."""
        sql = """
            INSERT INTO code_files
            (project_id, func_req_id, file_name, file_path,
             language, content, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """
        with self._get_conn() as conn:
            cur = conn.execute(sql, (
                code_file.project_id, code_file.func_req_id,
                code_file.file_name, code_file.file_path,
                code_file.language, code_file.content,
                code_file.created_at, code_file.updated_at,
            ))
            return cur.lastrowid

    def upsert_code_file(self, code_file: CodeFile) -> int:
        """Insert or update the code file keyed by ``code_file.func_req_id``.

        If a row with the same ``func_req_id`` exists, its file_name,
        file_path, language, content and updated_at are overwritten, and
        ``code_file.id`` / ``code_file.updated_at`` are set on the passed
        object. Otherwise a new row is inserted.

        Returns:
            The ID of the updated or inserted row.

        NOTE(review): the lookup and the write use separate connections, so
        two concurrent callers can both insert. A UNIQUE(func_req_id)
        constraint plus ``INSERT ... ON CONFLICT`` would make this atomic;
        confirm the schema before relying on uniqueness.
        """
        existing = self.get_code_file_by_func_req(code_file.func_req_id)
        if existing is None:
            return self.create_code_file(code_file)

        code_file.id = existing.id
        code_file.updated_at = datetime.now().isoformat()
        sql = """
            UPDATE code_files
            SET file_name=?, file_path=?, language=?, content=?, updated_at=?
            WHERE id=?
        """
        with self._get_conn() as conn:
            conn.execute(sql, (
                code_file.file_name, code_file.file_path,
                code_file.language, code_file.content,
                code_file.updated_at, code_file.id,
            ))
        return code_file.id

    def get_code_file_by_func_req(self, func_req_id: int) -> Optional[CodeFile]:
        """Return the first code file linked to *func_req_id*, or ``None``."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM code_files WHERE func_req_id = ?", (func_req_id,)
            ).fetchone()
        return None if row is None else self._row_to_code_file(row)

    def list_code_files_by_project(self, project_id: int) -> List[CodeFile]:
        """Return all code files of *project_id*, oldest first."""
        with self._get_conn() as conn:
            rows = conn.execute(
                "SELECT * FROM code_files WHERE project_id = ? ORDER BY created_at",
                (project_id,),
            ).fetchall()
        return [self._row_to_code_file(r) for r in rows]

    def _row_to_code_file(self, row) -> CodeFile:
        """Convert a ``code_files`` row into a ``CodeFile`` object."""
        return CodeFile(
            id=row["id"], project_id=row["project_id"],
            func_req_id=row["func_req_id"], file_name=row["file_name"],
            file_path=row["file_path"], language=row["language"],
            content=row["content"], created_at=row["created_at"],
            updated_at=row["updated_at"],
        )