314 lines
13 KiB
Python
314 lines
13 KiB
Python
|
|
# database/db_manager.py - 数据库操作管理器
|
||
|
|
import sqlite3
|
||
|
|
import os
|
||
|
|
from datetime import datetime
|
||
|
|
from typing import List, Optional
|
||
|
|
from contextlib import contextmanager
|
||
|
|
|
||
|
|
from database.models import (
|
||
|
|
CREATE_TABLES_SQL, Project, RawRequirement,
|
||
|
|
FunctionalRequirement, CodeFile
|
||
|
|
)
|
||
|
|
import config
|
||
|
|
|
||
|
|
|
||
|
|
class DBManager:
    """SQLite database manager that wraps all CRUD operations.

    Every public method opens its own short-lived connection through
    ``_get_conn``, which commits on success, rolls back on any exception and
    always closes the connection afterwards.
    """

    def __init__(self, db_path: str = config.DB_PATH):
        """Create the manager and initialise the schema.

        Args:
            db_path: Path of the SQLite database file. Its parent directory
                is created when it does not exist yet.
        """
        self.db_path = db_path
        db_dir = os.path.dirname(db_path)
        # A bare filename such as "app.db" has an empty dirname, and
        # os.makedirs("") raises FileNotFoundError -- only create real dirs.
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self._init_db()

    # ── Connection context manager ───────────────────────

    @contextmanager
    def _get_conn(self):
        """Yield a connection; commit on success, roll back on error, always close."""
        conn = sqlite3.connect(self.db_path)
        try:
            # Mapping-style rows so callers can read row["column"].
            conn.row_factory = sqlite3.Row
            # Foreign-key enforcement is a per-connection setting in SQLite.
            # Executed inside the try so a failure here still closes conn.
            conn.execute("PRAGMA foreign_keys = ON")
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def _init_db(self):
        """Create the tables by running the CREATE_TABLES_SQL script."""
        with self._get_conn() as conn:
            conn.executescript(CREATE_TABLES_SQL)

    # ══════════════════════════════════════════════════
    # Project CRUD
    # ══════════════════════════════════════════════════

    def create_project(self, project: Project) -> int:
        """Insert *project* and return the new row id."""
        sql = """
            INSERT INTO projects (name, description, language, output_dir, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """
        with self._get_conn() as conn:
            cur = conn.execute(sql, (
                project.name, project.description, project.language,
                project.output_dir, project.created_at, project.updated_at,
            ))
            return cur.lastrowid

    def get_project_by_id(self, project_id: int) -> Optional[Project]:
        """Return the project with id *project_id*, or None if absent."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM projects WHERE id = ?", (project_id,)
            ).fetchone()
        if row is None:
            return None
        return self._row_to_project(row)

    def get_project_by_name(self, name: str) -> Optional[Project]:
        """Return the first project whose name equals *name*, or None."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM projects WHERE name = ?", (name,)
            ).fetchone()
        if row is None:
            return None
        return self._row_to_project(row)

    def list_projects(self) -> List[Project]:
        """Return all projects, newest ``created_at`` first."""
        with self._get_conn() as conn:
            rows = conn.execute(
                "SELECT * FROM projects ORDER BY created_at DESC"
            ).fetchall()
        return [self._row_to_project(r) for r in rows]

    def update_project(self, project: Project) -> None:
        """Persist name/description/language/output_dir of *project*.

        Side effect: sets ``project.updated_at`` on the passed object to the
        current local (naive) time in ISO format before writing it.
        """
        project.updated_at = datetime.now().isoformat()
        sql = """
            UPDATE projects
            SET name=?, description=?, language=?, output_dir=?, updated_at=?
            WHERE id=?
        """
        with self._get_conn() as conn:
            conn.execute(sql, (
                project.name, project.description, project.language,
                project.output_dir, project.updated_at, project.id,
            ))

    def delete_project(self, project_id: int) -> None:
        """Delete the project with id *project_id*.

        Foreign keys are enabled on the connection. Cascading removal of
        related rows presumably relies on ON DELETE CASCADE in
        CREATE_TABLES_SQL -- TODO confirm against the schema.
        """
        with self._get_conn() as conn:
            conn.execute("DELETE FROM projects WHERE id = ?", (project_id,))

    def _row_to_project(self, row) -> Project:
        """Convert a ``projects`` sqlite3.Row into a Project object."""
        return Project(
            id=row["id"], name=row["name"], description=row["description"],
            language=row["language"], output_dir=row["output_dir"],
            created_at=row["created_at"], updated_at=row["updated_at"],
        )

    # ══════════════════════════════════════════════════
    # RawRequirement CRUD
    # ══════════════════════════════════════════════════

    def create_raw_requirement(self, req: RawRequirement) -> int:
        """Insert a raw requirement and return the new row id."""
        sql = """
            INSERT INTO raw_requirements
            (project_id, content, source_type, source_name, knowledge, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """
        with self._get_conn() as conn:
            cur = conn.execute(sql, (
                req.project_id, req.content, req.source_type,
                req.source_name, req.knowledge, req.created_at,
            ))
            return cur.lastrowid

    def get_raw_requirement(self, req_id: int) -> Optional[RawRequirement]:
        """Return the raw requirement with id *req_id*, or None if absent."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM raw_requirements WHERE id = ?", (req_id,)
            ).fetchone()
        if row is None:
            return None
        return self._row_to_raw_req(row)

    def list_raw_requirements_by_project(self, project_id: int) -> List[RawRequirement]:
        """Return all raw requirements of a project, oldest first."""
        with self._get_conn() as conn:
            rows = conn.execute(
                "SELECT * FROM raw_requirements WHERE project_id = ? ORDER BY created_at",
                (project_id,),
            ).fetchall()
        return [self._row_to_raw_req(r) for r in rows]

    def _row_to_raw_req(self, row) -> RawRequirement:
        """Convert a ``raw_requirements`` sqlite3.Row into a RawRequirement."""
        return RawRequirement(
            id=row["id"], project_id=row["project_id"], content=row["content"],
            source_type=row["source_type"], source_name=row["source_name"],
            knowledge=row["knowledge"], created_at=row["created_at"],
        )

    # ══════════════════════════════════════════════════
    # FunctionalRequirement CRUD
    # ══════════════════════════════════════════════════

    def create_functional_requirement(self, req: FunctionalRequirement) -> int:
        """Insert a functional requirement and return the new row id."""
        sql = """
            INSERT INTO functional_requirements
            (project_id, raw_req_id, index_no, title, description,
             function_name, priority, status, is_custom, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        with self._get_conn() as conn:
            cur = conn.execute(sql, (
                req.project_id, req.raw_req_id, req.index_no,
                req.title, req.description, req.function_name,
                # SQLite has no boolean type; store the flag as 0/1.
                req.priority, req.status, int(req.is_custom),
                req.created_at, req.updated_at,
            ))
            return cur.lastrowid

    def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]:
        """Return the functional requirement with id *req_id*, or None."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM functional_requirements WHERE id = ?", (req_id,)
            ).fetchone()
        if row is None:
            return None
        return self._row_to_func_req(row)

    def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]:
        """Return all functional requirements of a project ordered by index_no."""
        with self._get_conn() as conn:
            rows = conn.execute(
                """SELECT * FROM functional_requirements
                   WHERE project_id = ? ORDER BY index_no""",
                (project_id,),
            ).fetchall()
        return [self._row_to_func_req(r) for r in rows]

    def update_functional_requirement(self, req: FunctionalRequirement) -> None:
        """Persist the editable fields of *req*.

        Only title, description, function_name, priority, status and
        index_no are written. Side effect: sets ``req.updated_at`` to the
        current local (naive) time in ISO format.
        """
        req.updated_at = datetime.now().isoformat()
        sql = """
            UPDATE functional_requirements
            SET title=?, description=?, function_name=?, priority=?,
                status=?, index_no=?, updated_at=?
            WHERE id=?
        """
        with self._get_conn() as conn:
            conn.execute(sql, (
                req.title, req.description, req.function_name,
                req.priority, req.status, req.index_no,
                req.updated_at, req.id,
            ))

    def delete_functional_requirement(self, req_id: int) -> None:
        """Delete the functional requirement with id *req_id*."""
        with self._get_conn() as conn:
            conn.execute(
                "DELETE FROM functional_requirements WHERE id = ?", (req_id,)
            )

    def _row_to_func_req(self, row) -> FunctionalRequirement:
        """Convert a ``functional_requirements`` sqlite3.Row into an object."""
        return FunctionalRequirement(
            id=row["id"], project_id=row["project_id"],
            raw_req_id=row["raw_req_id"], index_no=row["index_no"],
            title=row["title"], description=row["description"],
            function_name=row["function_name"], priority=row["priority"],
            # Stored as an integer; turn it back into a bool.
            status=row["status"], is_custom=bool(row["is_custom"]),
            created_at=row["created_at"], updated_at=row["updated_at"],
        )

    # ══════════════════════════════════════════════════
    # CodeFile CRUD
    # ══════════════════════════════════════════════════

    def create_code_file(self, code_file: CodeFile) -> int:
        """Insert a code-file record and return the new row id."""
        sql = """
            INSERT INTO code_files
            (project_id, func_req_id, file_name, file_path,
             language, content, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """
        with self._get_conn() as conn:
            cur = conn.execute(sql, (
                code_file.project_id, code_file.func_req_id,
                code_file.file_name, code_file.file_path,
                code_file.language, code_file.content,
                code_file.created_at, code_file.updated_at,
            ))
            return cur.lastrowid

    def upsert_code_file(self, code_file: CodeFile) -> int:
        """Insert or update the code file keyed by ``func_req_id``; return its id.

        If a row with the same func_req_id exists, ``code_file.id`` and
        ``code_file.updated_at`` are overwritten on the passed object, and
        file_name/file_path/language/content/updated_at are written.
        Otherwise a new row is inserted via ``create_code_file``.

        NOTE(review): the lookup and the write use separate connections, so
        this is not atomic. Also, ``func_req_id = NULL`` never matches in
        SQL, so a None func_req_id always inserts a new row.
        """
        existing = self.get_code_file_by_func_req(code_file.func_req_id)
        if existing is None:
            return self.create_code_file(code_file)

        code_file.id = existing.id
        code_file.updated_at = datetime.now().isoformat()
        sql = """
            UPDATE code_files
            SET file_name=?, file_path=?, language=?, content=?, updated_at=?
            WHERE id=?
        """
        with self._get_conn() as conn:
            conn.execute(sql, (
                code_file.file_name, code_file.file_path,
                code_file.language, code_file.content,
                code_file.updated_at, code_file.id,
            ))
        return code_file.id

    def get_code_file_by_func_req(self, func_req_id: int) -> Optional[CodeFile]:
        """Return the first code file linked to *func_req_id*, or None."""
        with self._get_conn() as conn:
            row = conn.execute(
                "SELECT * FROM code_files WHERE func_req_id = ?", (func_req_id,)
            ).fetchone()
        if row is None:
            return None
        return self._row_to_code_file(row)

    def list_code_files_by_project(self, project_id: int) -> List[CodeFile]:
        """Return all code files of a project, oldest ``created_at`` first."""
        with self._get_conn() as conn:
            rows = conn.execute(
                "SELECT * FROM code_files WHERE project_id = ? ORDER BY created_at",
                (project_id,),
            ).fetchall()
        return [self._row_to_code_file(r) for r in rows]

    def _row_to_code_file(self, row) -> CodeFile:
        """Convert a ``code_files`` sqlite3.Row into a CodeFile object."""
        return CodeFile(
            id=row["id"], project_id=row["project_id"],
            func_req_id=row["func_req_id"], file_name=row["file_name"],
            file_path=row["file_path"], language=row["language"],
            content=row["content"], created_at=row["created_at"],
            updated_at=row["updated_at"],
        )
|