AIDeveloper-PC/requirements_generator/database/db_manager.py

288 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# database/db_manager.py - 数据库 CRUD 操作封装
import os
import shutil
from typing import List, Optional, Dict, Any
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
import config
from database.models import Base, Project, RawRequirement, FunctionalRequirement, CodeFile
class DBManager:
    """SQLite database manager that wraps all CRUD operations.

    Every public method opens a short-lived session via :meth:`_session`
    inside a ``with`` block, so ORM instances returned to callers are
    detached once the method returns. Column attributes loaded inside the
    method remain readable afterwards.

    NOTE(review): lazy-loaded relationships on returned objects will
    presumably fail outside a session. Confirm against ``database.models``
    before relying on them.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Open the SQLite database, creating missing tables.

        Args:
            db_path: Path of the SQLite database file. Falls back to
                ``config.DB_PATH`` when ``None`` or empty. A bare file name
                (no directory component) is accepted.
        """
        db_path = db_path or config.DB_PATH
        db_dir = os.path.dirname(db_path)
        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # directory when the path actually has one.
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self.engine = create_engine(f"sqlite:///{db_path}", echo=False)
        Base.metadata.create_all(self.engine)
        self._Session = sessionmaker(bind=self.engine)

    def _session(self) -> Session:
        """Return a new ORM session (used as a context manager by callers)."""
        return self._Session()

    @staticmethod
    def _build_stats(raw_count: int, status_module_pairs, code_count: int) -> Dict[str, int]:
        """Assemble the project statistics dict from already-fetched values.

        Shared by :meth:`get_project_stats` and :meth:`get_project_full_info`
        so both report identical numbers.

        Args:
            raw_count: Number of raw requirements.
            status_module_pairs: Iterable of ``(status, module)`` tuples, one
                per functional requirement.
            code_count: Number of code files.
        """
        pairs = list(status_module_pairs)
        return {
            "raw_req_count": raw_count,
            "func_req_count": len(pairs),
            "generated_count": sum(1 for status, _ in pairs if status == "generated"),
            # A missing/empty module is counted under the default module name.
            "module_count": len({module or config.DEFAULT_MODULE for _, module in pairs}),
            "code_file_count": code_count,
        }

    # ══════════════════════════════════════════════════
    # Project
    # ══════════════════════════════════════════════════

    def create_project(self, project: Project) -> int:
        """Insert *project* and return its newly assigned primary key."""
        with self._session() as s:
            s.add(project)
            s.commit()
            # With the default sessionmaker settings commit() expires the
            # instance; refresh reloads it (including the generated id)
            # before the session closes.
            s.refresh(project)
            return project.id

    def get_project_by_name(self, name: str) -> Optional[Project]:
        """Return the first project whose ``name`` matches, or ``None``."""
        with self._session() as s:
            return s.query(Project).filter_by(name=name).first()

    def get_project_by_id(self, project_id: int) -> Optional[Project]:
        """Return the project with primary key *project_id*, or ``None``."""
        with self._session() as s:
            return s.get(Project, project_id)

    def update_project(self, project: Project) -> None:
        """Persist the state of *project* (merged by primary key)."""
        with self._session() as s:
            s.merge(project)
            s.commit()

    def list_projects(self) -> List[Project]:
        """Return all projects, newest first (``created_at`` descending)."""
        with self._session() as s:
            return s.query(Project).order_by(Project.created_at.desc()).all()

    def delete_project(self, project_id: int, delete_output: bool = False) -> bool:
        """Delete a project and, optionally, its output directory on disk.

        NOTE(review): dependent rows (requirements, code files) are removed
        only if the relationships in ``database.models`` declare an ORM
        cascade. SQLite does not enforce foreign keys unless
        ``PRAGMA foreign_keys=ON``. Confirm before relying on a cascading
        delete.

        Args:
            project_id: Primary key of the project to delete.
            delete_output: If True, also remove ``project.output_dir`` from
                disk (best effort; filesystem errors are ignored).

        Returns:
            True if the project was deleted, False if it does not exist.
        """
        with self._session() as s:
            project = s.get(Project, project_id)
            if project is None:
                return False
            # Capture the path before the row is deleted/expired.
            output_dir = project.output_dir
            s.delete(project)
            s.commit()

        # Optional: remove the output directory on disk (best effort).
        if delete_output and output_dir and os.path.isdir(output_dir):
            shutil.rmtree(output_dir, ignore_errors=True)
        return True

    def get_project_stats(self, project_id: int) -> Dict[str, int]:
        """Return counts of requirements, generated code, modules and code files.

        Returns:
            ``{"raw_req_count": n, "func_req_count": n, "generated_count": n,
            "module_count": n, "code_file_count": n}``
        """
        with self._session() as s:
            raw_count = s.query(RawRequirement).filter_by(project_id=project_id).count()
            # Only status and module are needed, so skip loading full ORM rows.
            status_module_pairs = (
                s.query(FunctionalRequirement.status, FunctionalRequirement.module)
                .filter(FunctionalRequirement.project_id == project_id)
                .all()
            )
            code_count = (
                s.query(CodeFile)
                .join(FunctionalRequirement)
                .filter(FunctionalRequirement.project_id == project_id)
                .count()
            )
        return self._build_stats(
            raw_count,
            ((status, module) for status, module in status_module_pairs),
            code_count,
        )

    # ══════════════════════════════════════════════════
    # Project full info (requirement → module → code relationships)
    # ══════════════════════════════════════════════════

    def get_project_full_info(self, project_id: int) -> Optional[Dict[str, Any]]:
        """Return a project together with its requirement/module/code tree.

        Returns::

            {
                "project": Project,
                "stats": {...},            # same shape as get_project_stats()
                "modules": {
                    "<module_name>": {
                        "requirements": [
                            {
                                "req": FunctionalRequirement,
                                "code_files": [CodeFile, ...]
                            },
                            ...
                        ]
                    },
                    ...
                },
                "raw_requirements": [RawRequirement, ...]
            }

        Returns ``None`` if the project does not exist. Modules appear in
        the order their first requirement appears by ``index_no``.
        """
        with self._session() as s:
            project = s.get(Project, project_id)
            if project is None:
                return None

            raw_reqs = (
                s.query(RawRequirement)
                .filter_by(project_id=project_id)
                .order_by(RawRequirement.created_at)
                .all()
            )
            func_reqs = (
                s.query(FunctionalRequirement)
                .filter_by(project_id=project_id)
                .order_by(FunctionalRequirement.index_no)
                .all()
            )
            code_files = (
                s.query(CodeFile)
                .join(FunctionalRequirement)
                .filter(FunctionalRequirement.project_id == project_id)
                .all()
            )

            # func_req_id -> [CodeFile, ...]
            code_map: Dict[int, List[CodeFile]] = {}
            for cf in code_files:
                code_map.setdefault(cf.func_req_id, []).append(cf)

            # Group requirements by module (missing module -> default module).
            modules: Dict[str, Dict] = {}
            for req in func_reqs:
                mod = req.module or config.DEFAULT_MODULE
                modules.setdefault(mod, {"requirements": []})
                modules[mod]["requirements"].append({
                    "req": req,
                    "code_files": code_map.get(req.id, []),
                })

            # Derive stats from the rows already loaded in this session rather
            # than opening a second session and re-running the same queries;
            # this also keeps the stats consistent with the returned tree.
            stats = self._build_stats(
                len(raw_reqs),
                ((req.status, req.module) for req in func_reqs),
                len(code_files),
            )

            return {
                "project": project,
                "stats": stats,
                "modules": modules,
                "raw_requirements": raw_reqs,
            }

    # ══════════════════════════════════════════════════
    # RawRequirement
    # ══════════════════════════════════════════════════

    def create_raw_requirement(self, raw_req: RawRequirement) -> int:
        """Insert *raw_req* and return its newly assigned primary key."""
        with self._session() as s:
            s.add(raw_req)
            s.commit()
            s.refresh(raw_req)
            return raw_req.id

    def get_raw_requirement(self, raw_req_id: int) -> Optional[RawRequirement]:
        """Return the raw requirement with primary key *raw_req_id*, or ``None``."""
        with self._session() as s:
            return s.get(RawRequirement, raw_req_id)

    # ══════════════════════════════════════════════════
    # FunctionalRequirement
    # ══════════════════════════════════════════════════

    def create_functional_requirement(self, req: FunctionalRequirement) -> int:
        """Insert *req* and return its newly assigned primary key."""
        with self._session() as s:
            s.add(req)
            s.commit()
            s.refresh(req)
            return req.id

    def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]:
        """Return the functional requirement with primary key *req_id*, or ``None``."""
        with self._session() as s:
            return s.get(FunctionalRequirement, req_id)

    def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]:
        """Return all functional requirements of a project, ordered by ``index_no``."""
        with self._session() as s:
            return (
                s.query(FunctionalRequirement)
                .filter_by(project_id=project_id)
                .order_by(FunctionalRequirement.index_no)
                .all()
            )

    def update_functional_requirement(self, req: FunctionalRequirement) -> None:
        """Persist the state of *req* (merged by primary key)."""
        with self._session() as s:
            s.merge(req)
            s.commit()

    def delete_functional_requirement(self, req_id: int) -> None:
        """Delete the functional requirement *req_id*; a missing id is a no-op."""
        with self._session() as s:
            obj = s.get(FunctionalRequirement, req_id)
            if obj:
                s.delete(obj)
                s.commit()

    def bulk_update_modules(
        self, updates: List[dict], project_id: Optional[int] = None
    ) -> None:
        """Batch-update the ``module`` field of functional requirements.

        Args:
            updates: ``[{"function_name": "...", "module": "..."}, ...]``.
                If a function name appears more than once, the last entry wins.
            project_id: When given, only requirements belonging to this
                project are updated. Defaults to ``None``, which keeps the
                original behaviour of matching ``function_name`` across all
                projects. Pass it whenever function names may repeat between
                projects.

        Raises:
            KeyError: If an entry lacks ``"function_name"`` or ``"module"``.
        """
        name_to_module = {u["function_name"]: u["module"] for u in updates}
        if not name_to_module:
            # Nothing to update: avoid opening a session and an empty IN ().
            return

        with self._session() as s:
            query = s.query(FunctionalRequirement).filter(
                FunctionalRequirement.function_name.in_(list(name_to_module))
            )
            if project_id is not None:
                query = query.filter(FunctionalRequirement.project_id == project_id)
            for req in query.all():
                req.module = name_to_module.get(req.function_name, config.DEFAULT_MODULE)
            s.commit()

    # ══════════════════════════════════════════════════
    # CodeFile
    # ══════════════════════════════════════════════════

    def upsert_code_file(self, code_file: CodeFile) -> int:
        """Insert or update the code file for ``code_file.func_req_id``.

        If a code file already exists for that functional requirement, its
        ``file_name``, ``file_path``, ``module``, ``language`` and ``content``
        are overwritten from *code_file* (which itself is not added to the
        session). Otherwise *code_file* is inserted.

        Returns:
            Primary key of the stored code-file row.
        """
        with self._session() as s:
            existing = (
                s.query(CodeFile)
                .filter_by(func_req_id=code_file.func_req_id)
                .first()
            )
            if existing:
                existing.file_name = code_file.file_name
                existing.file_path = code_file.file_path
                existing.module = code_file.module
                existing.language = code_file.language
                existing.content = code_file.content
                s.commit()
                return existing.id

            s.add(code_file)
            s.commit()
            s.refresh(code_file)
            return code_file.id

    def list_code_files(self, project_id: int) -> List[CodeFile]:
        """Return all code files attached to the project's functional requirements."""
        with self._session() as s:
            return (
                s.query(CodeFile)
                .join(FunctionalRequirement)
                .filter(FunctionalRequirement.project_id == project_id)
                .all()
            )