支持功能模块
This commit is contained in:
parent
3fc095652e
commit
29636b9b94
|
|
@ -1,146 +1,123 @@
|
|||
# config.py - 全局配置管理
|
||||
# config.py - 全局配置
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# ── LLM 配置 ──────────────────────────────────────────
|
||||
LLM_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
|
||||
LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0.3"))
|
||||
# ── LLM ──────────────────────────────────────────────
|
||||
LLM_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
LLM_API_BASE = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
|
||||
LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT", "60"))
|
||||
LLM_MAX_RETRY = int(os.getenv("LLM_MAX_RETRY", "3"))
|
||||
|
||||
# ── 数据库配置 ─────────────────────────────────────────
|
||||
DB_PATH = os.getenv("DB_PATH", "data/requirement_analyzer.db")
|
||||
# ── 数据库 ────────────────────────────────────────────
|
||||
DB_PATH = os.getenv("DB_PATH", "data/requirement_analyzer.db")
|
||||
|
||||
# ── 输出配置 ───────────────────────────────────────────
|
||||
# ── 输出目录 ──────────────────────────────────────────
|
||||
OUTPUT_BASE_DIR = os.getenv("OUTPUT_BASE_DIR", "output")
|
||||
DEFAULT_LANGUAGE = os.getenv("DEFAULT_LANGUAGE", "python")
|
||||
DEFAULT_MODULE = os.getenv("DEFAULT_MODULE", "default")
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Prompt 模板
|
||||
# ══════════════════════════════════════════════════════
|
||||
# ── Prompt 模板 ───────────────────────────────────────
|
||||
|
||||
DECOMPOSE_PROMPT_TEMPLATE = """
|
||||
你是一位资深软件架构师和产品经理。请根据以下信息,将原始需求分解为若干个可独立实现的功能需求。
|
||||
DECOMPOSE_PROMPT_TEMPLATE = """\
|
||||
你是一名资深软件架构师,请将以下原始需求分解为独立的功能需求列表。
|
||||
|
||||
{knowledge_section}
|
||||
|
||||
## 原始需求
|
||||
【原始需求】
|
||||
{raw_requirement}
|
||||
|
||||
## 输出要求
|
||||
请严格按照以下 JSON 格式输出,不要包含任何额外说明:
|
||||
{{
|
||||
"functional_requirements": [
|
||||
{{
|
||||
"index": 1,
|
||||
"title": "功能需求标题(简洁,10字以内)",
|
||||
"description": "功能需求详细描述(包含输入、处理逻辑、输出)",
|
||||
"function_name": "snake_case函数名",
|
||||
"priority": "high|medium|low"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
要求:
|
||||
1. 每个功能需求必须是独立可实现的最小单元
|
||||
2. function_name 使用 snake_case 命名,清晰表达函数用途
|
||||
3. 分解粒度适中,通常 5-15 个功能需求
|
||||
4. 优先级根据业务重要性判断
|
||||
"""
|
||||
|
||||
# ── 函数签名 JSON 生成 Prompt ──────────────────────────
|
||||
FUNC_SIGNATURE_PROMPT_TEMPLATE = """
|
||||
你是一位资深软件架构师。请根据以下功能需求描述,设计该函数的完整接口签名,并以 JSON 格式输出。
|
||||
|
||||
{knowledge_section}
|
||||
|
||||
## 功能需求
|
||||
需求编号:{requirement_id}
|
||||
标题:{title}
|
||||
函数名:{function_name}
|
||||
详细描述:{description}
|
||||
【输出要求】
|
||||
以 JSON 数组格式输出,每个元素包含以下字段:
|
||||
- title: 功能标题(简短,10字以内)
|
||||
- description: 功能描述(详细说明该功能的职责与边界,50字以内)
|
||||
- function_name: 对应的函数名(snake_case,动词开头)
|
||||
- priority: 优先级(high / medium / low)
|
||||
- module: 所属功能模块名称(snake_case,如 user_auth / order_service)
|
||||
|
||||
## 输出格式
|
||||
请严格按照以下 JSON 结构输出,不要包含任何额外说明或 markdown 标记:
|
||||
{{
|
||||
"name": "{function_name}",
|
||||
"requirement_id": "{requirement_id}",
|
||||
"description": "简洁的一句话功能描述(英文)",
|
||||
"type": "function",
|
||||
"parameters": {{
|
||||
"<param_name>": {{
|
||||
"type": "integer|string|boolean|float|list|dict|object",
|
||||
"inout": "in|out|inout",
|
||||
"description": "参数说明(英文)",
|
||||
"required": true
|
||||
}}
|
||||
}},
|
||||
"return": {{
|
||||
"type": "integer|string|boolean|float|list|dict|object|void",
|
||||
"description": "整体返回值说明(英文,一句话概括)",
|
||||
"on_success": {{
|
||||
"value": "具体成功返回值或范围,如 0、true、user object、list of items 等",
|
||||
"description": "成功时的返回值含义(英文)"
|
||||
}},
|
||||
"on_failure": {{
|
||||
"value": "具体失败返回值或范围,如 nonzero、false、null、empty list、raises Exception 等",
|
||||
"description": "失败时的返回值含义,或抛出的异常类型(英文)"
|
||||
}}
|
||||
【示例输出】
|
||||
[
|
||||
{{
|
||||
"title": "用户注册",
|
||||
"description": "接收用户名、密码、邮箱,校验合法性后创建用户账号并返回用户ID",
|
||||
"function_name": "register_user",
|
||||
"priority": "high",
|
||||
"module": "user_auth"
|
||||
}}
|
||||
}}
|
||||
]
|
||||
|
||||
## 设计规范
|
||||
1. 参数名使用 snake_case,类型使用通用类型(不绑定具体语言)
|
||||
2. inout 字段含义:
|
||||
- in = 仅输入参数
|
||||
- out = 仅输出参数(通过参数传出结果,如指针/引用)
|
||||
- inout = 既作输入又作输出
|
||||
3. 所有描述字段使用英文
|
||||
4. return 字段规则:
|
||||
- 若函数无返回值(void),type 填 "void",on_success/on_failure 均填 null
|
||||
- 若返回值只有成功场景(如纯查询),on_failure 可描述为 "null or empty"
|
||||
- on_success.value / on_failure.value 填写具体值或值域描述,不要填写空字符串
|
||||
5. 若函数无参数,parameters 填 {{}}
|
||||
6. required 字段为布尔值 true 或 false
|
||||
只输出 JSON 数组,不要有任何额外说明。
|
||||
"""
|
||||
|
||||
# ── 代码生成 Prompt(含签名约束)─────────────────────────
|
||||
CODE_GEN_PROMPT_TEMPLATE = """
|
||||
你是一位资深 {language} 工程师。请根据以下功能需求和【函数签名规范】,生成完整的 {language} 函数代码。
|
||||
FUNC_SIGNATURE_PROMPT_TEMPLATE = """\
|
||||
你是一名资深软件工程师,请根据以下功能需求生成标准函数签名信息。
|
||||
|
||||
【功能需求】
|
||||
- 需求编号: {requirement_id}
|
||||
- 标题: {title}
|
||||
- 描述: {description}
|
||||
- 函数名: {function_name}
|
||||
- 所属模块: {module}
|
||||
|
||||
{knowledge_section}
|
||||
|
||||
## 功能需求
|
||||
标题:{title}
|
||||
描述:{description}
|
||||
【输出要求】
|
||||
以 JSON 对象格式输出,包含以下字段:
|
||||
- name: 函数名(与上方一致)
|
||||
- requirement_id: 需求编号
|
||||
- description: 函数功能描述(英文,一句话)
|
||||
- type: 固定为 "function"
|
||||
- module: 所属模块名称
|
||||
- parameters: 参数字典,key 为参数名,value 包含:
|
||||
- type: 数据类型(integer/string/boolean/float/list/dict/object/void/any)
|
||||
- inout: in / out / inout
|
||||
- required: true / false
|
||||
- description: 参数说明
|
||||
- return:
|
||||
- type: 返回类型
|
||||
- on_success: {{ "value": "...", "description": "..." }} 或 null(void)
|
||||
- on_failure: {{ "value": "...", "description": "..." }} 或 null(void)
|
||||
|
||||
## 【必须严格遵守】函数签名规范
|
||||
以下 JSON 定义了函数的精确接口,生成的代码必须与之完全一致,不得擅自增减或改名参数:
|
||||
只输出 JSON 对象,不要有任何额外说明。
|
||||
"""
|
||||
|
||||
```json
|
||||
CODE_GEN_PROMPT_TEMPLATE = """\
|
||||
你是一名资深 {language} 工程师,请根据以下函数签名和功能描述生成完整的函数实现代码。
|
||||
|
||||
【函数签名】
|
||||
{signature_json}
|
||||
```
|
||||
|
||||
### 签名字段说明
|
||||
- `name`:函数名,必须完全一致
|
||||
- `parameters`:每个 key 即为参数名,`type` 为数据类型,`inout` 含义:
|
||||
- `in` = 普通输入参数
|
||||
- `out` = 输出参数(Python 中通过返回值或可变容器传出)
|
||||
- `inout` = 既作输入又作输出
|
||||
- `return.type`:返回值类型
|
||||
- `return.on_success`:成功时的返回值,代码实现必须与此一致
|
||||
- `return.on_failure`:失败时的返回值或异常,代码实现必须与此一致
|
||||
【功能描述】
|
||||
{description}
|
||||
|
||||
## 输出要求
|
||||
1. 只输出纯代码,不要包含 markdown 代码块标记
|
||||
2. 函数签名(名称、参数列表、返回类型)必须与上方 JSON 规范完全一致
|
||||
3. 成功/失败的返回值必须严格遵守 return.on_success / return.on_failure 的定义
|
||||
4. 包含完整的类型注解(Python 使用 type hints)
|
||||
5. 包含详细的 docstring,其中 Returns 段须注明成功值与失败值
|
||||
6. 包含必要的异常处理
|
||||
7. 代码风格遵循 PEP8(Python)或对应语言规范
|
||||
8. 在文件顶部用注释注明:需求编号、功能标题、函数签名摘要
|
||||
9. 如需导入第三方库,请在顶部统一导入
|
||||
{knowledge_section}
|
||||
|
||||
【输出要求】
|
||||
1. 只输出 {language} 代码,不要有任何 Markdown 标记(不要 ```)
|
||||
2. 包含完整的函数实现(含必要的 import)
|
||||
3. 包含函数文档注释(docstring / JSDoc 等)
|
||||
4. 包含基本的参数校验与错误处理
|
||||
5. 代码风格遵循 {language} 最佳实践
|
||||
"""
|
||||
|
||||
MODULE_CLASSIFY_PROMPT_TEMPLATE = """\
|
||||
你是一名资深软件架构师,请将以下功能需求列表分类到合适的功能模块中。
|
||||
|
||||
【功能需求列表】
|
||||
{requirements_json}
|
||||
|
||||
【输出要求】
|
||||
以 JSON 数组格式输出,每个元素包含:
|
||||
- function_name: 函数名(与输入一致)
|
||||
- module: 所属模块名称(snake_case,如 user_auth / order_service / payment)
|
||||
|
||||
模块划分原则:
|
||||
1. 功能相近的需求归入同一模块
|
||||
2. 模块名使用英文 snake_case
|
||||
3. 模块数量控制在 2~8 个之间
|
||||
4. 若某需求确实无法归类,使用 "default" 模块
|
||||
|
||||
只输出 JSON 数组,不要有任何额外说明。
|
||||
"""
|
||||
|
|
@ -1,99 +1,89 @@
|
|||
# core/code_generator.py - 代码生成核心逻辑(签名约束版)
|
||||
# core/code_generator.py - 代码生成(按模块路由到子目录)
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, List, Callable
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Callable
|
||||
|
||||
import config
|
||||
from core.llm_client import LLMClient
|
||||
from database.models import FunctionalRequirement, CodeFile
|
||||
from utils.output_writer import write_code_file, get_file_extension
|
||||
|
||||
|
||||
class CodeGenerator:
|
||||
"""
|
||||
根据功能需求 + 函数签名约束,使用 LLM 生成代码函数文件。
|
||||
签名由 RequirementAnalyzer.build_function_signature() 预先生成,
|
||||
注入 Prompt 后可确保代码参数列表与签名 JSON 完全一致。
|
||||
"""
|
||||
"""根据函数签名约束,调用 LLM 生成代码文件,并按模块写入子目录"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
"""
|
||||
初始化代码生成器
|
||||
|
||||
Args:
|
||||
llm_client: LLM 客户端实例,为 None 时自动创建
|
||||
"""
|
||||
self.llm = llm_client or LLMClient()
|
||||
def __init__(self, llm: LLMClient):
|
||||
self.llm = llm
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 单个生成
|
||||
# 单个代码文件生成
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def generate(
|
||||
self,
|
||||
func_req: FunctionalRequirement,
|
||||
func_req: FunctionalRequirement,
|
||||
output_dir: str,
|
||||
language: str = config.DEFAULT_LANGUAGE,
|
||||
knowledge: str = "",
|
||||
signature: Optional[dict] = None,
|
||||
language: str = None,
|
||||
knowledge: str = "",
|
||||
signature: dict = None,
|
||||
) -> CodeFile:
|
||||
"""
|
||||
为单个功能需求生成代码文件
|
||||
为单个功能需求生成代码文件,写入 output_dir/<module>/ 子目录。
|
||||
|
||||
Args:
|
||||
func_req: 功能需求对象(必须含有效 id)
|
||||
output_dir: 代码输出目录
|
||||
language: 目标编程语言
|
||||
knowledge: 知识库文本(可选)
|
||||
signature: 函数签名 dict(由 RequirementAnalyzer 生成)。
|
||||
传入后将作为强约束注入 Prompt,确保代码参数
|
||||
与签名 JSON 完全一致;为 None 时退化为无约束模式。
|
||||
func_req: 功能需求对象
|
||||
output_dir: 项目根输出目录
|
||||
language: 目标语言
|
||||
knowledge: 知识库文本
|
||||
signature: 函数签名 dict(可选,有则作为约束)
|
||||
|
||||
Returns:
|
||||
CodeFile 对象(含生成的代码内容和文件路径,未持久化)
|
||||
CodeFile 对象(未持久化)
|
||||
|
||||
Raises:
|
||||
ValueError: func_req.id 为 None
|
||||
Exception: LLM 调用失败或文件写入失败
|
||||
RuntimeError: LLM 调用失败
|
||||
"""
|
||||
if func_req.id is None:
|
||||
raise ValueError("FunctionalRequirement 必须先持久化(id 不能为 None)")
|
||||
language = language or config.DEFAULT_LANGUAGE
|
||||
module = (func_req.module or config.DEFAULT_MODULE).strip()
|
||||
|
||||
knowledge_section = self._build_knowledge_section(knowledge)
|
||||
signature_json = self._build_signature_json(signature, func_req)
|
||||
# 按模块创建子目录
|
||||
module_dir = os.path.join(output_dir, module)
|
||||
os.makedirs(module_dir, exist_ok=True)
|
||||
self._ensure_init_py(module_dir)
|
||||
|
||||
# 构建 Prompt
|
||||
sig_json = json.dumps(signature, ensure_ascii=False, indent=2) if signature else "{}"
|
||||
knowledge_section = f"【参考知识库】\n{knowledge}\n" if knowledge else ""
|
||||
prompt = config.CODE_GEN_PROMPT_TEMPLATE.format(
|
||||
language=language,
|
||||
knowledge_section=knowledge_section,
|
||||
title=func_req.title,
|
||||
description=func_req.description,
|
||||
signature_json=signature_json,
|
||||
language = language,
|
||||
signature_json = sig_json,
|
||||
description = func_req.description,
|
||||
knowledge_section = knowledge_section,
|
||||
)
|
||||
|
||||
code_content = self.llm.chat(
|
||||
system_prompt=(
|
||||
f"你是一位资深 {language} 工程师,只输出纯代码,"
|
||||
"不添加任何 markdown 标记。函数签名必须与提供的 JSON 规范完全一致。"
|
||||
),
|
||||
user_prompt=prompt,
|
||||
)
|
||||
try:
|
||||
code_content = self.llm.chat(
|
||||
prompt,
|
||||
system = f"You are an expert {language} developer. Output only code.",
|
||||
temperature = 0.2,
|
||||
max_tokens = 4096,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"代码生成失败 [{func_req.function_name}]: {e}")
|
||||
|
||||
file_path = write_code_file(
|
||||
output_dir=output_dir,
|
||||
function_name=func_req.function_name,
|
||||
language=language,
|
||||
content=code_content,
|
||||
)
|
||||
|
||||
ext = get_file_extension(language)
|
||||
# 写入文件
|
||||
ext = self._get_extension(language)
|
||||
file_name = f"{func_req.function_name}{ext}"
|
||||
file_path = os.path.join(module_dir, file_name)
|
||||
Path(file_path).write_text(code_content, encoding="utf-8")
|
||||
|
||||
return CodeFile(
|
||||
project_id=func_req.project_id,
|
||||
func_req_id=func_req.id,
|
||||
file_name=file_name,
|
||||
file_path=file_path,
|
||||
language=language,
|
||||
content=code_content,
|
||||
func_req_id = func_req.id,
|
||||
file_name = file_name,
|
||||
file_path = file_path,
|
||||
module = module,
|
||||
language = language,
|
||||
content = code_content,
|
||||
)
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
|
|
@ -102,115 +92,82 @@ class CodeGenerator:
|
|||
|
||||
def generate_batch(
|
||||
self,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
output_dir: str,
|
||||
language: str = config.DEFAULT_LANGUAGE,
|
||||
knowledge: str = "",
|
||||
signatures: Optional[List[dict]] = None,
|
||||
on_progress: Optional[Callable] = None,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
output_dir: str,
|
||||
language: str = None,
|
||||
knowledge: str = "",
|
||||
signatures: Optional[List[dict]] = None,
|
||||
on_progress: Optional[Callable] = None,
|
||||
) -> List[CodeFile]:
|
||||
"""
|
||||
批量生成代码文件
|
||||
批量生成代码文件。
|
||||
|
||||
Args:
|
||||
func_reqs: 功能需求列表
|
||||
output_dir: 输出目录
|
||||
output_dir: 项目根输出目录
|
||||
language: 目标语言
|
||||
knowledge: 知识库文本
|
||||
signatures: 与 func_reqs 等长的签名列表(索引对应)。
|
||||
为 None 时所有条目均以无约束模式生成。
|
||||
on_progress: 进度回调 fn(index, total, func_req, code_file, error)
|
||||
signatures: 与 func_reqs 等长的签名列表(索引对应)
|
||||
on_progress: 进度回调 fn(index, total, req, code_file, error)
|
||||
|
||||
Returns:
|
||||
成功生成的 CodeFile 列表
|
||||
"""
|
||||
results = []
|
||||
total = len(func_reqs)
|
||||
language = language or config.DEFAULT_LANGUAGE
|
||||
total = len(func_reqs)
|
||||
results = []
|
||||
sig_map = self._build_signature_map(func_reqs, signatures)
|
||||
|
||||
# 构建 func_req.id → signature 的快速查找表
|
||||
sig_map = self._build_signature_map(func_reqs, signatures)
|
||||
|
||||
for i, req in enumerate(func_reqs):
|
||||
for i, req in enumerate(func_reqs, 1):
|
||||
sig = sig_map.get(req.id)
|
||||
try:
|
||||
code_file = self.generate(
|
||||
func_req=req,
|
||||
output_dir=output_dir,
|
||||
language=language,
|
||||
knowledge=knowledge,
|
||||
signature=sig,
|
||||
func_req = req,
|
||||
output_dir = output_dir,
|
||||
language = language,
|
||||
knowledge = knowledge,
|
||||
signature = sig,
|
||||
)
|
||||
results.append(code_file)
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, req, code_file, None)
|
||||
on_progress(i, total, req, code_file, None)
|
||||
except Exception as e:
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, req, None, e)
|
||||
on_progress(i, total, req, None, e)
|
||||
|
||||
return results
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 私有工具方法
|
||||
# 工具方法
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
def _build_knowledge_section(knowledge: str) -> str:
|
||||
"""构建知识库 Prompt 段落"""
|
||||
if not knowledge or not knowledge.strip():
|
||||
return ""
|
||||
return (
|
||||
"## 参考知识库(实现时请遵循以下规范)\n"
|
||||
f"{knowledge}\n\n---\n"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_signature_json(
|
||||
signature: Optional[dict],
|
||||
func_req: FunctionalRequirement,
|
||||
) -> str:
|
||||
"""
|
||||
将签名 dict 序列化为格式化 JSON 字符串;
|
||||
若签名为 None,则构造最小占位签名,保持 Prompt 结构完整。
|
||||
|
||||
Args:
|
||||
signature: 签名 dict 或 None
|
||||
func_req: 对应的功能需求(用于占位签名)
|
||||
|
||||
Returns:
|
||||
JSON 字符串
|
||||
"""
|
||||
if signature:
|
||||
return json.dumps(signature, ensure_ascii=False, indent=2)
|
||||
# 无签名时的最小占位,提示 LLM 自行设计但保持格式
|
||||
fallback = {
|
||||
"name": func_req.function_name,
|
||||
"requirement_id": f"REQ.{func_req.index_no:02d}",
|
||||
"description": func_req.description,
|
||||
"type": "function",
|
||||
"parameters": "<<请根据功能描述自行设计参数>>",
|
||||
"return": "<<请根据功能描述自行设计返回值>>",
|
||||
}
|
||||
return json.dumps(fallback, ensure_ascii=False, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def _build_signature_map(
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
signatures: Optional[List[dict]],
|
||||
) -> dict:
|
||||
"""
|
||||
构建 func_req.id → signature 映射表
|
||||
|
||||
Args:
|
||||
func_reqs: 功能需求列表
|
||||
signatures: 与 func_reqs 等长的签名列表,或 None
|
||||
|
||||
Returns:
|
||||
{req_id: signature_dict} 字典
|
||||
"""
|
||||
"""构建 func_req.id → signature 的快速查找表"""
|
||||
if not signatures:
|
||||
return {}
|
||||
sig_map = {}
|
||||
for req, sig in zip(func_reqs, signatures):
|
||||
if req.id is not None and sig:
|
||||
sig_map[req.id] = sig
|
||||
return sig_map
|
||||
return {
|
||||
req.id: sig
|
||||
for req, sig in zip(func_reqs, signatures)
|
||||
if req.id is not None
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_extension(language: str) -> str:
|
||||
ext_map = {
|
||||
"python": ".py", "javascript": ".js", "typescript": ".ts",
|
||||
"java": ".java", "go": ".go", "rust": ".rs",
|
||||
"cpp": ".cpp", "c": ".c", "csharp": ".cs",
|
||||
"ruby": ".rb", "php": ".php", "swift": ".swift", "kotlin": ".kt",
|
||||
}
|
||||
return ext_map.get(language.lower(), ".txt")
|
||||
|
||||
@staticmethod
|
||||
def _ensure_init_py(directory: str) -> None:
|
||||
"""在目录中创建 __init__.py(Python 包标识)"""
|
||||
init = os.path.join(directory, "__init__.py")
|
||||
if not os.path.exists(init):
|
||||
Path(init).write_text("# Auto-generated module package\n", encoding="utf-8")
|
||||
|
|
@ -1,90 +1,114 @@
|
|||
# core/llm_client.py - LLM 客户端封装
|
||||
# core/llm_client.py - LLM API 调用封装
|
||||
import time
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from openai import OpenAI, APIError, APITimeoutError, RateLimitError
|
||||
|
||||
import config
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""
|
||||
OpenAI 兼容 LLM 客户端封装。
|
||||
支持任何兼容 OpenAI API 格式的服务(OpenAI / Azure / 本地模型等)。
|
||||
"""
|
||||
"""封装 OpenAI 兼容接口,提供统一的调用入口与重试机制"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = config.LLM_API_KEY,
|
||||
base_url: str = config.LLM_BASE_URL,
|
||||
model: str = config.LLM_MODEL,
|
||||
temperature: float = config.LLM_TEMPERATURE,
|
||||
api_key: str = None,
|
||||
api_base: str = None,
|
||||
model: str = None,
|
||||
):
|
||||
"""
|
||||
初始化 LLM 客户端
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
base_url: API 基础 URL
|
||||
model: 模型名称
|
||||
temperature: 生成温度(0~1,越低越确定)
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 openai 库
|
||||
ValueError: api_key 为空
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI
|
||||
except ImportError:
|
||||
raise ImportError("请安装 openai: pip install openai")
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置,请在 .env 文件中设置")
|
||||
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self._client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
def chat(self, system_prompt: str, user_prompt: str) -> str:
|
||||
"""
|
||||
发送对话请求,返回模型回复文本
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户输入
|
||||
|
||||
Returns:
|
||||
模型回复的文本内容
|
||||
|
||||
Raises:
|
||||
Exception: API 调用失败
|
||||
"""
|
||||
response = self._client.chat.completions.create(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
self.model = model or config.LLM_MODEL
|
||||
self.client = OpenAI(
|
||||
api_key = api_key or config.LLM_API_KEY,
|
||||
base_url = api_base or config.LLM_API_BASE,
|
||||
timeout = config.LLM_TIMEOUT,
|
||||
)
|
||||
return response.choices[0].message.content.strip()
|
||||
|
||||
def chat_json(self, system_prompt: str, user_prompt: str) -> dict:
|
||||
def chat(
|
||||
self,
|
||||
prompt: str,
|
||||
system: str = "You are a helpful assistant.",
|
||||
temperature: float = 0.2,
|
||||
max_tokens: int = 4096,
|
||||
) -> str:
|
||||
"""
|
||||
发送对话请求,解析并返回 JSON 结果
|
||||
发送单轮对话请求,返回模型回复文本。
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户输入
|
||||
prompt: 用户消息
|
||||
system: 系统提示词
|
||||
temperature: 采样温度
|
||||
max_tokens: 最大输出 token 数
|
||||
|
||||
Returns:
|
||||
解析后的 dict 对象
|
||||
模型回复的纯文本字符串
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: 模型返回非合法 JSON
|
||||
RuntimeError: 超过最大重试次数后仍失败
|
||||
"""
|
||||
raw = self.chat(system_prompt, user_prompt)
|
||||
last_error = None
|
||||
for attempt in range(1, config.LLM_MAX_RETRY + 1):
|
||||
try:
|
||||
resp = self.client.chat.completions.create(
|
||||
model = self.model,
|
||||
messages = [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature = temperature,
|
||||
max_tokens = max_tokens,
|
||||
)
|
||||
return resp.choices[0].message.content.strip()
|
||||
except RateLimitError as e:
|
||||
wait = 2 ** attempt
|
||||
last_error = e
|
||||
time.sleep(wait)
|
||||
except APITimeoutError as e:
|
||||
last_error = e
|
||||
time.sleep(2)
|
||||
except APIError as e:
|
||||
last_error = e
|
||||
if attempt >= config.LLM_MAX_RETRY:
|
||||
break
|
||||
time.sleep(1)
|
||||
raise RuntimeError(
|
||||
f"LLM 调用失败(已重试 {config.LLM_MAX_RETRY} 次): {last_error}"
|
||||
)
|
||||
|
||||
def chat_json(
|
||||
self,
|
||||
prompt: str,
|
||||
system: str = "You are a helpful assistant. Always respond with valid JSON.",
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
) -> any:
|
||||
"""
|
||||
发送请求并将回复解析为 JSON 对象。
|
||||
|
||||
Returns:
|
||||
解析后的 Python 对象(dict 或 list)
|
||||
|
||||
Raises:
|
||||
ValueError: JSON 解析失败
|
||||
RuntimeError: LLM 调用失败
|
||||
"""
|
||||
raw = self.chat(prompt, system=system, temperature=temperature, max_tokens=max_tokens)
|
||||
# 去除可能的 markdown 代码块包裹
|
||||
raw = raw.strip()
|
||||
if raw.startswith("```"):
|
||||
lines = raw.split("\n")
|
||||
raw = "\n".join(lines[1:-1])
|
||||
return json.loads(raw)
|
||||
cleaned = self._strip_markdown_code_block(raw)
|
||||
try:
|
||||
return json.loads(cleaned)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"LLM 返回内容无法解析为 JSON: {e}\n原始内容:\n{raw}")
|
||||
|
||||
@staticmethod
|
||||
def _strip_markdown_code_block(text: str) -> str:
|
||||
"""去除 ```json ... ``` 或 ``` ... ``` 包裹"""
|
||||
text = text.strip()
|
||||
if text.startswith("```"):
|
||||
lines = text.splitlines()
|
||||
# 去掉首行(```json 或 ```)和末行(```)
|
||||
inner = lines[1:] if lines[-1].strip() == "```" else lines[1:]
|
||||
if inner and inner[-1].strip() == "```":
|
||||
inner = inner[:-1]
|
||||
text = "\n".join(inner).strip()
|
||||
return text
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
# core/requirement_analyzer.py - 需求分解 & 函数签名生成
|
||||
import re
|
||||
from typing import List, Optional
|
||||
import json
|
||||
from typing import List, Optional, Callable
|
||||
|
||||
import config
|
||||
from core.llm_client import LLMClient
|
||||
|
|
@ -8,19 +8,10 @@ from database.models import FunctionalRequirement
|
|||
|
||||
|
||||
class RequirementAnalyzer:
|
||||
"""
|
||||
使用 LLM 将原始需求分解为功能需求列表,并生成函数接口签名。
|
||||
支持注入知识库上下文以提升分解质量。
|
||||
"""
|
||||
"""负责需求分解、模块分类、函数签名生成"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
"""
|
||||
初始化需求分析器
|
||||
|
||||
Args:
|
||||
llm_client: LLM 客户端实例,为 None 时自动创建
|
||||
"""
|
||||
self.llm = llm_client or LLMClient()
|
||||
def __init__(self, llm: LLMClient):
|
||||
self.llm = llm
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 需求分解
|
||||
|
|
@ -29,12 +20,12 @@ class RequirementAnalyzer:
|
|||
def decompose(
|
||||
self,
|
||||
raw_requirement: str,
|
||||
project_id: int,
|
||||
raw_req_id: int,
|
||||
knowledge: str = "",
|
||||
project_id: int,
|
||||
raw_req_id: int,
|
||||
knowledge: str = "",
|
||||
) -> List[FunctionalRequirement]:
|
||||
"""
|
||||
将原始需求分解为功能需求列表
|
||||
将原始需求文本分解为功能需求列表(含模块分类)。
|
||||
|
||||
Args:
|
||||
raw_requirement: 原始需求文本
|
||||
|
|
@ -44,196 +35,175 @@ class RequirementAnalyzer:
|
|||
|
||||
Returns:
|
||||
FunctionalRequirement 对象列表(未持久化,id=None)
|
||||
|
||||
Raises:
|
||||
ValueError: LLM 返回格式不合法
|
||||
json.JSONDecodeError: JSON 解析失败
|
||||
"""
|
||||
knowledge_section = self._build_knowledge_section(knowledge)
|
||||
knowledge_section = (
|
||||
f"【参考知识库】\n{knowledge}\n" if knowledge else ""
|
||||
)
|
||||
prompt = config.DECOMPOSE_PROMPT_TEMPLATE.format(
|
||||
knowledge_section=knowledge_section,
|
||||
raw_requirement=raw_requirement,
|
||||
raw_requirement = raw_requirement,
|
||||
knowledge_section = knowledge_section,
|
||||
)
|
||||
|
||||
result = self.llm.chat_json(
|
||||
system_prompt="你是一位资深软件架构师,擅长需求分析与系统设计。",
|
||||
user_prompt=prompt,
|
||||
)
|
||||
try:
|
||||
items = self.llm.chat_json(prompt)
|
||||
if not isinstance(items, list):
|
||||
raise ValueError("LLM 返回结果不是数组")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"需求分解失败: {e}")
|
||||
|
||||
items = result.get("functional_requirements", [])
|
||||
if not items:
|
||||
raise ValueError("LLM 未返回任何功能需求,请检查原始需求描述")
|
||||
|
||||
requirements = []
|
||||
for item in items:
|
||||
reqs = []
|
||||
for i, item in enumerate(items, 1):
|
||||
req = FunctionalRequirement(
|
||||
project_id=project_id,
|
||||
raw_req_id=raw_req_id,
|
||||
index_no=int(item.get("index", len(requirements) + 1)),
|
||||
title=item.get("title", "未命名功能"),
|
||||
description=item.get("description", ""),
|
||||
function_name=self._sanitize_function_name(
|
||||
item.get("function_name", f"func_{len(requirements)+1}")
|
||||
),
|
||||
priority=item.get("priority", "medium"),
|
||||
project_id = project_id,
|
||||
raw_req_id = raw_req_id,
|
||||
index_no = i,
|
||||
title = item.get("title", f"功能{i}"),
|
||||
description = item.get("description", ""),
|
||||
function_name = item.get("function_name", f"function_{i}"),
|
||||
priority = item.get("priority", "medium"),
|
||||
module = item.get("module", config.DEFAULT_MODULE),
|
||||
status = "pending",
|
||||
is_custom = False,
|
||||
)
|
||||
requirements.append(req)
|
||||
|
||||
return requirements
|
||||
reqs.append(req)
|
||||
return reqs
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 函数签名生成(新增)
|
||||
# 模块分类(独立步骤,可对已有需求列表重新分类)
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def classify_modules(
|
||||
self,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
knowledge: str = "",
|
||||
) -> List[dict]:
|
||||
"""
|
||||
对功能需求列表进行模块分类,返回 {function_name: module} 映射列表。
|
||||
|
||||
Args:
|
||||
func_reqs: 功能需求列表
|
||||
knowledge: 知识库文本(可选)
|
||||
|
||||
Returns:
|
||||
[{"function_name": "...", "module": "..."}, ...]
|
||||
"""
|
||||
req_list = [
|
||||
{
|
||||
"index_no": r.index_no,
|
||||
"title": r.title,
|
||||
"description": r.description,
|
||||
"function_name": r.function_name,
|
||||
}
|
||||
for r in func_reqs
|
||||
]
|
||||
knowledge_section = f"【参考知识库】\n{knowledge}\n" if knowledge else ""
|
||||
prompt = config.MODULE_CLASSIFY_PROMPT_TEMPLATE.format(
|
||||
requirements_json = json.dumps(req_list, ensure_ascii=False, indent=2),
|
||||
knowledge_section = knowledge_section,
|
||||
)
|
||||
try:
|
||||
result = self.llm.chat_json(prompt)
|
||||
if not isinstance(result, list):
|
||||
raise ValueError("LLM 返回结果不是数组")
|
||||
return result
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"模块分类失败: {e}")
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 函数签名生成
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def build_function_signature(
|
||||
self,
|
||||
func_req: FunctionalRequirement,
|
||||
knowledge: str = "",
|
||||
func_req: FunctionalRequirement,
|
||||
requirement_id: str = "",
|
||||
knowledge: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
为单个功能需求生成函数接口签名 JSON
|
||||
为单个功能需求生成函数签名 dict。
|
||||
|
||||
Args:
|
||||
func_req: 功能需求对象(需含有效 id)
|
||||
knowledge: 知识库文本(可选)
|
||||
func_req: 功能需求对象
|
||||
requirement_id: 需求编号字符串(如 "REQ.01")
|
||||
knowledge: 知识库文本
|
||||
|
||||
Returns:
|
||||
符合接口规范的 dict,包含 name/requirement_id/description/
|
||||
type/parameters/return 字段
|
||||
函数签名 dict
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: LLM 返回非合法 JSON
|
||||
RuntimeError: LLM 调用或解析失败
|
||||
"""
|
||||
requirement_id = self._format_requirement_id(func_req.index_no)
|
||||
knowledge_section = self._build_knowledge_section(knowledge)
|
||||
|
||||
knowledge_section = f"【参考知识库】\n{knowledge}\n" if knowledge else ""
|
||||
prompt = config.FUNC_SIGNATURE_PROMPT_TEMPLATE.format(
|
||||
knowledge_section=knowledge_section,
|
||||
requirement_id=requirement_id,
|
||||
title=func_req.title,
|
||||
function_name=func_req.function_name,
|
||||
description=func_req.description,
|
||||
requirement_id = requirement_id or f"REQ.{func_req.index_no:02d}",
|
||||
title = func_req.title,
|
||||
description = func_req.description,
|
||||
function_name = func_req.function_name,
|
||||
module = func_req.module or config.DEFAULT_MODULE,
|
||||
knowledge_section = knowledge_section,
|
||||
)
|
||||
|
||||
signature = self.llm.chat_json(
|
||||
system_prompt=(
|
||||
"你是一位资深软件架构师,专注于 API 接口设计。"
|
||||
"只输出合法 JSON,不添加任何说明文字。"
|
||||
),
|
||||
user_prompt=prompt,
|
||||
)
|
||||
|
||||
# 确保关键字段存在,做兜底处理
|
||||
signature.setdefault("name", func_req.function_name)
|
||||
signature.setdefault("requirement_id", requirement_id)
|
||||
signature.setdefault("description", func_req.description)
|
||||
signature.setdefault("type", "function")
|
||||
signature.setdefault("parameters", {})
|
||||
signature.setdefault("return", {"type": "void", "description": ""})
|
||||
|
||||
return signature
|
||||
try:
|
||||
sig = self.llm.chat_json(prompt)
|
||||
if not isinstance(sig, dict):
|
||||
raise ValueError("LLM 返回结果不是 dict")
|
||||
# 确保 module 字段存在
|
||||
if "module" not in sig:
|
||||
sig["module"] = func_req.module or config.DEFAULT_MODULE
|
||||
return sig
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"签名生成失败 [{func_req.function_name}]: {e}")
|
||||
|
||||
def build_function_signatures_batch(
|
||||
self,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
knowledge: str = "",
|
||||
on_progress=None,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
knowledge: str = "",
|
||||
on_progress: Optional[Callable] = None,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
批量为功能需求列表生成函数接口签名
|
||||
批量生成函数签名,失败时使用降级结构。
|
||||
|
||||
Args:
|
||||
func_reqs: 功能需求列表
|
||||
knowledge: 知识库文本(可选)
|
||||
on_progress: 进度回调 fn(index, total, func_req, signature, error)
|
||||
knowledge: 知识库文本
|
||||
on_progress: 进度回调 fn(index, total, req, signature, error)
|
||||
|
||||
Returns:
|
||||
签名 dict 列表,顺序与 func_reqs 一致;
|
||||
生成失败的条目使用降级结构填充,不中断整体流程
|
||||
与 func_reqs 等长的签名 dict 列表(索引一一对应)
|
||||
"""
|
||||
results = []
|
||||
total = len(func_reqs)
|
||||
signatures = []
|
||||
total = len(func_reqs)
|
||||
|
||||
for i, req in enumerate(func_reqs):
|
||||
for i, req in enumerate(func_reqs, 1):
|
||||
req_id = f"REQ.{req.index_no:02d}"
|
||||
try:
|
||||
sig = self.build_function_signature(req, knowledge)
|
||||
results.append(sig)
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, req, sig, None)
|
||||
sig = self.build_function_signature(req, req_id, knowledge)
|
||||
error = None
|
||||
except Exception as e:
|
||||
# 降级:用基础信息填充,保证 JSON 完整性
|
||||
fallback = self._build_fallback_signature(req)
|
||||
results.append(fallback)
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, req, fallback, e)
|
||||
sig = self._fallback_signature(req, req_id)
|
||||
error = e
|
||||
|
||||
return results
|
||||
signatures.append(sig)
|
||||
if on_progress:
|
||||
on_progress(i, total, req, sig, error)
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 私有工具方法
|
||||
# ══════════════════════════════════════════════════
|
||||
return signatures
|
||||
|
||||
@staticmethod
|
||||
def _build_knowledge_section(knowledge: str) -> str:
|
||||
"""构建知识库 Prompt 段落"""
|
||||
if not knowledge or not knowledge.strip():
|
||||
return ""
|
||||
return f"""## 参考知识库
|
||||
{knowledge}
|
||||
|
||||
---
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_function_name(name: str) -> str:
|
||||
"""
|
||||
清理函数名,确保符合 snake_case 规范
|
||||
|
||||
Args:
|
||||
name: 原始函数名
|
||||
|
||||
Returns:
|
||||
合法的 snake_case 函数名
|
||||
"""
|
||||
name = re.sub(r"[^a-zA-Z0-9_]", "_", name).lower()
|
||||
name = re.sub(r"_+", "_", name).strip("_")
|
||||
if name and name[0].isdigit():
|
||||
name = "func_" + name
|
||||
return name or "unnamed_function"
|
||||
|
||||
@staticmethod
|
||||
def _format_requirement_id(index_no: int) -> str:
|
||||
"""
|
||||
将序号格式化为需求编号字符串
|
||||
|
||||
Args:
|
||||
index_no: 功能需求序号(从 1 开始)
|
||||
|
||||
Returns:
|
||||
格式化编号,如 'REQ.01'、'REQ.12'
|
||||
"""
|
||||
return f"REQ.{index_no:02d}"
|
||||
|
||||
@staticmethod
|
||||
def _build_fallback_signature(func_req: FunctionalRequirement) -> dict:
|
||||
"""
|
||||
构建降级签名(LLM 调用失败时使用)
|
||||
|
||||
Args:
|
||||
func_req: 功能需求对象
|
||||
|
||||
Returns:
|
||||
包含基础信息的签名 dict
|
||||
"""
|
||||
def _fallback_signature(
|
||||
req: FunctionalRequirement,
|
||||
requirement_id: str,
|
||||
) -> dict:
|
||||
"""生成降级签名结构(LLM 失败时使用)"""
|
||||
return {
|
||||
"name": func_req.function_name,
|
||||
"requirement_id": f"REQ.{func_req.index_no:02d}",
|
||||
"description": func_req.description,
|
||||
"name": req.function_name,
|
||||
"requirement_id": requirement_id,
|
||||
"description": req.description,
|
||||
"type": "function",
|
||||
"module": req.module or config.DEFAULT_MODULE,
|
||||
"parameters": {},
|
||||
"return": {
|
||||
"type": "void",
|
||||
"description": "TODO: define return value"
|
||||
"return": {
|
||||
"type": "any",
|
||||
"on_success": {"value": "...", "description": "成功时返回值"},
|
||||
"on_failure": {"value": "None", "description": "失败时返回 None"},
|
||||
},
|
||||
"_note": "Auto-generated fallback due to LLM error"
|
||||
}
|
||||
|
|
@ -1,314 +1,152 @@
|
|||
# database/db_manager.py - 数据库操作管理器
|
||||
import sqlite3
|
||||
# database/db_manager.py - 数据库 CRUD 操作封装
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from contextlib import contextmanager
|
||||
|
||||
from database.models import (
|
||||
CREATE_TABLES_SQL, Project, RawRequirement,
|
||||
FunctionalRequirement, CodeFile
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
import config
|
||||
from database.models import Base, Project, RawRequirement, FunctionalRequirement, CodeFile
|
||||
|
||||
|
||||
class DBManager:
|
||||
"""SQLite 数据库管理器,封装所有 CRUD 操作"""
|
||||
|
||||
def __init__(self, db_path: str = config.DB_PATH):
|
||||
self.db_path = db_path
|
||||
def __init__(self, db_path: str = None):
|
||||
db_path = db_path or config.DB_PATH
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
self._init_db()
|
||||
self.engine = create_engine(f"sqlite:///{db_path}", echo=False)
|
||||
Base.metadata.create_all(self.engine)
|
||||
self._Session = sessionmaker(bind=self.engine)
|
||||
|
||||
# ── 连接上下文管理器 ──────────────────────────────────
|
||||
|
||||
@contextmanager
|
||||
def _get_conn(self):
|
||||
"""获取数据库连接(自动提交/回滚)"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _init_db(self):
|
||||
"""初始化数据库,创建所有表"""
|
||||
with self._get_conn() as conn:
|
||||
conn.executescript(CREATE_TABLES_SQL)
|
||||
def _session(self) -> Session:
|
||||
return self._Session()
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# Project CRUD
|
||||
# Project
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def create_project(self, project: Project) -> int:
|
||||
"""创建项目,返回新项目 ID"""
|
||||
sql = """
|
||||
INSERT INTO projects (name, description, language, output_dir, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
cur = conn.execute(sql, (
|
||||
project.name, project.description, project.language,
|
||||
project.output_dir, project.created_at, project.updated_at
|
||||
))
|
||||
return cur.lastrowid
|
||||
|
||||
def get_project_by_id(self, project_id: int) -> Optional[Project]:
|
||||
"""根据 ID 查询项目"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM projects WHERE id = ?", (project_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return Project(
|
||||
id=row["id"], name=row["name"], description=row["description"],
|
||||
language=row["language"], output_dir=row["output_dir"],
|
||||
created_at=row["created_at"], updated_at=row["updated_at"]
|
||||
)
|
||||
with self._session() as s:
|
||||
s.add(project)
|
||||
s.commit()
|
||||
s.refresh(project)
|
||||
return project.id
|
||||
|
||||
def get_project_by_name(self, name: str) -> Optional[Project]:
|
||||
"""根据名称查询项目"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM projects WHERE name = ?", (name,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return Project(
|
||||
id=row["id"], name=row["name"], description=row["description"],
|
||||
language=row["language"], output_dir=row["output_dir"],
|
||||
created_at=row["created_at"], updated_at=row["updated_at"]
|
||||
)
|
||||
with self._session() as s:
|
||||
return s.query(Project).filter_by(name=name).first()
|
||||
|
||||
def list_projects(self) -> List[Project]:
|
||||
"""列出所有项目"""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM projects ORDER BY created_at DESC"
|
||||
).fetchall()
|
||||
return [
|
||||
Project(
|
||||
id=r["id"], name=r["name"], description=r["description"],
|
||||
language=r["language"], output_dir=r["output_dir"],
|
||||
created_at=r["created_at"], updated_at=r["updated_at"]
|
||||
) for r in rows
|
||||
]
|
||||
def get_project_by_id(self, project_id: int) -> Optional[Project]:
|
||||
with self._session() as s:
|
||||
return s.get(Project, project_id)
|
||||
|
||||
def update_project(self, project: Project) -> None:
|
||||
"""更新项目信息"""
|
||||
project.updated_at = datetime.now().isoformat()
|
||||
sql = """
|
||||
UPDATE projects
|
||||
SET name=?, description=?, language=?, output_dir=?, updated_at=?
|
||||
WHERE id=?
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute(sql, (
|
||||
project.name, project.description, project.language,
|
||||
project.output_dir, project.updated_at, project.id
|
||||
))
|
||||
with self._session() as s:
|
||||
s.merge(project)
|
||||
s.commit()
|
||||
|
||||
def delete_project(self, project_id: int) -> None:
|
||||
"""删除项目(级联删除所有关联数据)"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute("DELETE FROM projects WHERE id = ?", (project_id,))
|
||||
def list_projects(self) -> List[Project]:
|
||||
with self._session() as s:
|
||||
return s.query(Project).order_by(Project.created_at.desc()).all()
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# RawRequirement CRUD
|
||||
# RawRequirement
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def create_raw_requirement(self, req: RawRequirement) -> int:
|
||||
"""创建原始需求,返回新记录 ID"""
|
||||
sql = """
|
||||
INSERT INTO raw_requirements
|
||||
(project_id, content, source_type, source_name, knowledge, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
cur = conn.execute(sql, (
|
||||
req.project_id, req.content, req.source_type,
|
||||
req.source_name, req.knowledge, req.created_at
|
||||
))
|
||||
return cur.lastrowid
|
||||
def create_raw_requirement(self, raw_req: RawRequirement) -> int:
|
||||
with self._session() as s:
|
||||
s.add(raw_req)
|
||||
s.commit()
|
||||
s.refresh(raw_req)
|
||||
return raw_req.id
|
||||
|
||||
def get_raw_requirement(self, req_id: int) -> Optional[RawRequirement]:
|
||||
"""根据 ID 查询原始需求"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM raw_requirements WHERE id = ?", (req_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return RawRequirement(
|
||||
id=row["id"], project_id=row["project_id"], content=row["content"],
|
||||
source_type=row["source_type"], source_name=row["source_name"],
|
||||
knowledge=row["knowledge"], created_at=row["created_at"]
|
||||
)
|
||||
|
||||
def list_raw_requirements_by_project(self, project_id: int) -> List[RawRequirement]:
|
||||
"""查询项目下所有原始需求"""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM raw_requirements WHERE project_id = ? ORDER BY created_at",
|
||||
(project_id,)
|
||||
).fetchall()
|
||||
return [
|
||||
RawRequirement(
|
||||
id=r["id"], project_id=r["project_id"], content=r["content"],
|
||||
source_type=r["source_type"], source_name=r["source_name"],
|
||||
knowledge=r["knowledge"], created_at=r["created_at"]
|
||||
) for r in rows
|
||||
]
|
||||
def get_raw_requirement(self, raw_req_id: int) -> Optional[RawRequirement]:
|
||||
with self._session() as s:
|
||||
return s.get(RawRequirement, raw_req_id)
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# FunctionalRequirement CRUD
|
||||
# FunctionalRequirement
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def create_functional_requirement(self, req: FunctionalRequirement) -> int:
|
||||
"""创建功能需求,返回新记录 ID"""
|
||||
sql = """
|
||||
INSERT INTO functional_requirements
|
||||
(project_id, raw_req_id, index_no, title, description,
|
||||
function_name, priority, status, is_custom, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
cur = conn.execute(sql, (
|
||||
req.project_id, req.raw_req_id, req.index_no,
|
||||
req.title, req.description, req.function_name,
|
||||
req.priority, req.status, int(req.is_custom),
|
||||
req.created_at, req.updated_at
|
||||
))
|
||||
return cur.lastrowid
|
||||
with self._session() as s:
|
||||
s.add(req)
|
||||
s.commit()
|
||||
s.refresh(req)
|
||||
return req.id
|
||||
|
||||
def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]:
|
||||
"""根据 ID 查询功能需求"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM functional_requirements WHERE id = ?", (req_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_func_req(row)
|
||||
with self._session() as s:
|
||||
return s.get(FunctionalRequirement, req_id)
|
||||
|
||||
def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]:
|
||||
"""查询项目下所有功能需求(按序号排序)"""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"""SELECT * FROM functional_requirements
|
||||
WHERE project_id = ? ORDER BY index_no""",
|
||||
(project_id,)
|
||||
).fetchall()
|
||||
return [self._row_to_func_req(r) for r in rows]
|
||||
|
||||
def update_functional_requirement(self, req: FunctionalRequirement) -> None:
|
||||
"""更新功能需求"""
|
||||
req.updated_at = datetime.now().isoformat()
|
||||
sql = """
|
||||
UPDATE functional_requirements
|
||||
SET title=?, description=?, function_name=?, priority=?,
|
||||
status=?, index_no=?, updated_at=?
|
||||
WHERE id=?
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute(sql, (
|
||||
req.title, req.description, req.function_name,
|
||||
req.priority, req.status, req.index_no,
|
||||
req.updated_at, req.id
|
||||
))
|
||||
|
||||
def delete_functional_requirement(self, req_id: int) -> None:
|
||||
"""删除功能需求"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute(
|
||||
"DELETE FROM functional_requirements WHERE id = ?", (req_id,)
|
||||
with self._session() as s:
|
||||
return (
|
||||
s.query(FunctionalRequirement)
|
||||
.filter_by(project_id=project_id)
|
||||
.order_by(FunctionalRequirement.index_no)
|
||||
.all()
|
||||
)
|
||||
|
||||
def _row_to_func_req(self, row) -> FunctionalRequirement:
|
||||
"""sqlite Row → FunctionalRequirement 对象"""
|
||||
return FunctionalRequirement(
|
||||
id=row["id"], project_id=row["project_id"],
|
||||
raw_req_id=row["raw_req_id"], index_no=row["index_no"],
|
||||
title=row["title"], description=row["description"],
|
||||
function_name=row["function_name"], priority=row["priority"],
|
||||
status=row["status"], is_custom=bool(row["is_custom"]),
|
||||
created_at=row["created_at"], updated_at=row["updated_at"]
|
||||
)
|
||||
def update_functional_requirement(self, req: FunctionalRequirement) -> None:
|
||||
with self._session() as s:
|
||||
s.merge(req)
|
||||
s.commit()
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# CodeFile CRUD
|
||||
# ══════════════════════════════════════════════════
|
||||
def delete_functional_requirement(self, req_id: int) -> None:
|
||||
with self._session() as s:
|
||||
obj = s.get(FunctionalRequirement, req_id)
|
||||
if obj:
|
||||
s.delete(obj)
|
||||
s.commit()
|
||||
|
||||
def create_code_file(self, code_file: CodeFile) -> int:
|
||||
"""创建代码文件记录,返回新记录 ID"""
|
||||
sql = """
|
||||
INSERT INTO code_files
|
||||
(project_id, func_req_id, file_name, file_path,
|
||||
language, content, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
def bulk_update_modules(self, updates: List[dict]) -> None:
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
cur = conn.execute(sql, (
|
||||
code_file.project_id, code_file.func_req_id,
|
||||
code_file.file_name, code_file.file_path,
|
||||
code_file.language, code_file.content,
|
||||
code_file.created_at, code_file.updated_at
|
||||
))
|
||||
return cur.lastrowid
|
||||
批量更新功能需求的 module 字段。
|
||||
|
||||
Args:
|
||||
updates: [{"function_name": "...", "module": "..."}, ...]
|
||||
"""
|
||||
with self._session() as s:
|
||||
name_to_module = {u["function_name"]: u["module"] for u in updates}
|
||||
reqs = s.query(FunctionalRequirement).filter(
|
||||
FunctionalRequirement.function_name.in_(name_to_module.keys())
|
||||
).all()
|
||||
for req in reqs:
|
||||
req.module = name_to_module.get(req.function_name, config.DEFAULT_MODULE)
|
||||
s.commit()
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# CodeFile
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def upsert_code_file(self, code_file: CodeFile) -> int:
|
||||
"""插入或更新代码文件(按 func_req_id 唯一键)"""
|
||||
existing = self.get_code_file_by_func_req(code_file.func_req_id)
|
||||
if existing:
|
||||
code_file.id = existing.id
|
||||
code_file.updated_at = datetime.now().isoformat()
|
||||
sql = """
|
||||
UPDATE code_files
|
||||
SET file_name=?, file_path=?, language=?, content=?, updated_at=?
|
||||
WHERE id=?
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute(sql, (
|
||||
code_file.file_name, code_file.file_path,
|
||||
code_file.language, code_file.content,
|
||||
code_file.updated_at, code_file.id
|
||||
))
|
||||
return code_file.id
|
||||
else:
|
||||
return self.create_code_file(code_file)
|
||||
with self._session() as s:
|
||||
existing = (
|
||||
s.query(CodeFile)
|
||||
.filter_by(func_req_id=code_file.func_req_id)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
existing.file_name = code_file.file_name
|
||||
existing.file_path = code_file.file_path
|
||||
existing.module = code_file.module
|
||||
existing.language = code_file.language
|
||||
existing.content = code_file.content
|
||||
s.commit()
|
||||
return existing.id
|
||||
else:
|
||||
s.add(code_file)
|
||||
s.commit()
|
||||
s.refresh(code_file)
|
||||
return code_file.id
|
||||
|
||||
def get_code_file_by_func_req(self, func_req_id: int) -> Optional[CodeFile]:
|
||||
"""根据功能需求 ID 查询代码文件"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM code_files WHERE func_req_id = ?", (func_req_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_code_file(row)
|
||||
|
||||
def list_code_files_by_project(self, project_id: int) -> List[CodeFile]:
|
||||
"""查询项目下所有代码文件"""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM code_files WHERE project_id = ? ORDER BY created_at",
|
||||
(project_id,)
|
||||
).fetchall()
|
||||
return [self._row_to_code_file(r) for r in rows]
|
||||
|
||||
def _row_to_code_file(self, row) -> CodeFile:
|
||||
"""sqlite Row → CodeFile 对象"""
|
||||
return CodeFile(
|
||||
id=row["id"], project_id=row["project_id"],
|
||||
func_req_id=row["func_req_id"], file_name=row["file_name"],
|
||||
file_path=row["file_path"], language=row["language"],
|
||||
content=row["content"], created_at=row["created_at"],
|
||||
updated_at=row["updated_at"]
|
||||
)
|
||||
def list_code_files(self, project_id: int) -> List[CodeFile]:
|
||||
with self._session() as s:
|
||||
return (
|
||||
s.query(CodeFile)
|
||||
.join(FunctionalRequirement)
|
||||
.filter(FunctionalRequirement.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
|
|
@ -1,122 +1,95 @@
|
|||
# database/models.py - 数据模型定义(SQLite 建表 DDL)
|
||||
from dataclasses import dataclass, field
|
||||
# database/models.py - SQLAlchemy ORM 数据模型
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy import (
|
||||
Column, Integer, String, Text, Boolean,
|
||||
DateTime, ForeignKey, Enum,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# DDL 建表语句
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
CREATE_TABLES_SQL = """
|
||||
PRAGMA foreign_keys = ON;
|
||||
|
||||
-- 项目表
|
||||
CREATE TABLE IF NOT EXISTS projects (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL UNIQUE, -- 项目名称
|
||||
description TEXT, -- 项目描述
|
||||
language TEXT NOT NULL DEFAULT 'python', -- 目标代码语言
|
||||
output_dir TEXT NOT NULL, -- 输出目录路径
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- 原始需求表
|
||||
CREATE TABLE IF NOT EXISTS raw_requirements (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
project_id INTEGER NOT NULL, -- 关联项目
|
||||
content TEXT NOT NULL, -- 需求原文
|
||||
source_type TEXT NOT NULL DEFAULT 'text', -- text | file
|
||||
source_name TEXT, -- 文件名(文件输入时)
|
||||
knowledge TEXT, -- 合并后的知识库内容
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- 功能需求表
|
||||
CREATE TABLE IF NOT EXISTS functional_requirements (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
project_id INTEGER NOT NULL,
|
||||
raw_req_id INTEGER NOT NULL, -- 关联原始需求
|
||||
index_no INTEGER NOT NULL, -- 序号
|
||||
title TEXT NOT NULL, -- 功能标题
|
||||
description TEXT NOT NULL, -- 功能描述
|
||||
function_name TEXT NOT NULL, -- 对应函数名
|
||||
priority TEXT NOT NULL DEFAULT 'medium', -- high|medium|low
|
||||
status TEXT NOT NULL DEFAULT 'pending', -- pending|generated|skipped
|
||||
is_custom INTEGER NOT NULL DEFAULT 0, -- 是否用户自定义 (0/1)
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (raw_req_id) REFERENCES raw_requirements(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- 代码文件表
|
||||
CREATE TABLE IF NOT EXISTS code_files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
project_id INTEGER NOT NULL,
|
||||
func_req_id INTEGER NOT NULL UNIQUE, -- 关联功能需求(1对1)
|
||||
file_name TEXT NOT NULL, -- 文件名
|
||||
file_path TEXT NOT NULL, -- 完整路径
|
||||
language TEXT NOT NULL, -- 代码语言
|
||||
content TEXT NOT NULL, -- 代码内容
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (func_req_id) REFERENCES functional_requirements(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# 数据类(Python 对象映射)
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
@dataclass
|
||||
class Project:
|
||||
name: str
|
||||
output_dir: str
|
||||
language: str = "python"
|
||||
description: str = ""
|
||||
id: Optional[int] = None
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RawRequirement:
|
||||
project_id: int
|
||||
content: str
|
||||
source_type: str = "text" # text | file
|
||||
source_name: Optional[str] = None
|
||||
knowledge: Optional[str] = None
|
||||
id: Optional[int] = None
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
class Project(Base):
|
||||
"""项目表"""
|
||||
__tablename__ = "projects"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String(200), nullable=False, unique=True)
|
||||
language = Column(String(50), nullable=False, default="python")
|
||||
description = Column(Text, nullable=True)
|
||||
output_dir = Column(String(500), nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
raw_requirements = relationship("RawRequirement", back_populates="project", cascade="all, delete-orphan")
|
||||
functional_requirements = relationship("FunctionalRequirement", back_populates="project", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Project(id={self.id}, name={self.name!r}, language={self.language!r})>"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionalRequirement:
|
||||
project_id: int
|
||||
raw_req_id: int
|
||||
index_no: int
|
||||
title: str
|
||||
description: str
|
||||
function_name: str
|
||||
priority: str = "medium"
|
||||
status: str = "pending"
|
||||
is_custom: bool = False
|
||||
id: Optional[int] = None
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
class RawRequirement(Base):
|
||||
"""原始需求表"""
|
||||
__tablename__ = "raw_requirements"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False)
|
||||
content = Column(Text, nullable=False)
|
||||
source_type = Column(String(20), nullable=False, default="text") # text / file
|
||||
source_name = Column(String(200), nullable=True)
|
||||
knowledge = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
project = relationship("Project", back_populates="raw_requirements")
|
||||
functional_requirements = relationship("FunctionalRequirement", back_populates="raw_requirement", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RawRequirement(id={self.id}, project_id={self.project_id})>"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeFile:
|
||||
project_id: int
|
||||
func_req_id: int
|
||||
file_name: str
|
||||
file_path: str
|
||||
language: str
|
||||
content: str
|
||||
id: Optional[int] = None
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
class FunctionalRequirement(Base):
|
||||
"""功能需求表"""
|
||||
__tablename__ = "functional_requirements"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False)
|
||||
raw_req_id = Column(Integer, ForeignKey("raw_requirements.id"), nullable=False)
|
||||
index_no = Column(Integer, nullable=False)
|
||||
title = Column(String(200), nullable=False)
|
||||
description = Column(Text, nullable=False)
|
||||
function_name = Column(String(200), nullable=False)
|
||||
priority = Column(Enum("high", "medium", "low"), default="medium")
|
||||
module = Column(String(100), nullable=True, default="default") # 功能模块
|
||||
status = Column(String(50), nullable=False, default="pending") # pending / generated / failed
|
||||
is_custom = Column(Boolean, nullable=False, default=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
project = relationship("Project", back_populates="functional_requirements")
|
||||
raw_requirement = relationship("RawRequirement", back_populates="functional_requirements")
|
||||
code_files = relationship("CodeFile", back_populates="functional_requirement", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<FunctionalRequirement(id={self.id}, title={self.title!r}, "
|
||||
f"module={self.module!r}, status={self.status!r})>"
|
||||
)
|
||||
|
||||
|
||||
class CodeFile(Base):
|
||||
"""生成代码文件表"""
|
||||
__tablename__ = "code_files"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
func_req_id = Column(Integer, ForeignKey("functional_requirements.id"), nullable=False)
|
||||
file_name = Column(String(200), nullable=False)
|
||||
file_path = Column(String(500), nullable=False)
|
||||
module = Column(String(100), nullable=True) # 冗余存储,方便查询
|
||||
language = Column(String(50), nullable=False)
|
||||
content = Column(Text, nullable=True)
|
||||
generated_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
functional_requirement = relationship("FunctionalRequirement", back_populates="code_files")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<CodeFile(id={self.id}, file_name={self.file_name!r}, module={self.module!r})>"
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,6 +1,7 @@
|
|||
openai>=1.0.0
|
||||
python-dotenv>=1.0.0
|
||||
openai>=1.30.0
|
||||
click>=8.1.0
|
||||
rich>=13.0.0
|
||||
python-docx>=0.8.11
|
||||
PyPDF2>=3.0.0
|
||||
click>=8.1.0
|
||||
sqlalchemy>=2.0.0
|
||||
python-dotenv>=1.0.0
|
||||
pypdf>=4.0.0
|
||||
python-docx>=1.1.0
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 设置环境变量
|
||||
set OPENAI_API_KEY="sk-AUmOuFI731Ty5Nob38jY26d8lydfDT-QkE2giqb0sCuPCAE2JH6zjLM4lZLpvL5WMYPOocaMe2FwVDmqM_9KimmKACjR"
|
||||
set OPENAI_BASE_URL="https://openapi.monica.im/v1" # 或其他兼容接口
|
||||
set LLM_MODEL="gpt-4o"
|
||||
|
||||
python main.py
|
||||
|
|
@ -1,95 +1,13 @@
|
|||
# utils/file_handler.py - 文件读取工具(支持 txt/md/pdf/docx)
|
||||
# utils/file_handler.py - 文件读取工具(支持 txt / md / pdf / docx)
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def read_text_file(file_path: str) -> str:
|
||||
"""
|
||||
读取纯文本文件内容(.txt / .md / .py 等)
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
文件文本内容
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 文件不存在
|
||||
UnicodeDecodeError: 编码错误时尝试 latin-1 兜底
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
try:
|
||||
return path.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return path.read_text(encoding="latin-1")
|
||||
|
||||
|
||||
def read_pdf_file(file_path: str) -> str:
|
||||
"""
|
||||
读取 PDF 文件内容
|
||||
|
||||
Args:
|
||||
file_path: PDF 文件路径
|
||||
|
||||
Returns:
|
||||
提取的文本内容
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 PyPDF2
|
||||
FileNotFoundError: 文件不存在
|
||||
"""
|
||||
try:
|
||||
import PyPDF2
|
||||
except ImportError:
|
||||
raise ImportError("请安装 PyPDF2: pip install PyPDF2")
|
||||
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
texts = []
|
||||
with open(file_path, "rb") as f:
|
||||
reader = PyPDF2.PdfReader(f)
|
||||
for page in reader.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
texts.append(text)
|
||||
return "\n".join(texts)
|
||||
|
||||
|
||||
def read_docx_file(file_path: str) -> str:
|
||||
"""
|
||||
读取 Word (.docx) 文件内容
|
||||
|
||||
Args:
|
||||
file_path: docx 文件路径
|
||||
|
||||
Returns:
|
||||
提取的文本内容(段落合并)
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 python-docx
|
||||
FileNotFoundError: 文件不存在
|
||||
"""
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError("请安装 python-docx: pip install python-docx")
|
||||
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
doc = Document(file_path)
|
||||
return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
|
||||
from typing import List
|
||||
|
||||
|
||||
def read_file_auto(file_path: str) -> str:
|
||||
"""
|
||||
根据文件扩展名自动选择读取方式
|
||||
自动识别文件类型并读取文本内容。
|
||||
|
||||
支持格式:.txt / .md / .pdf / .docx / 其他(按 UTF-8 读取)
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
|
@ -98,45 +16,66 @@ def read_file_auto(file_path: str) -> str:
|
|||
文件文本内容
|
||||
|
||||
Raises:
|
||||
ValueError: 不支持的文件类型
|
||||
FileNotFoundError: 文件不存在
|
||||
RuntimeError: 读取失败
|
||||
"""
|
||||
ext = Path(file_path).suffix.lower()
|
||||
readers = {
|
||||
".txt": read_text_file,
|
||||
".md": read_text_file,
|
||||
".py": read_text_file,
|
||||
".json": read_text_file,
|
||||
".yaml": read_text_file,
|
||||
".yml": read_text_file,
|
||||
".pdf": read_pdf_file,
|
||||
".docx": read_docx_file,
|
||||
}
|
||||
reader = readers.get(ext)
|
||||
if reader is None:
|
||||
raise ValueError(f"不支持的文件类型: {ext},支持: {list(readers.keys())}")
|
||||
return reader(file_path)
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
try:
|
||||
if ext == ".pdf":
|
||||
return _read_pdf(file_path)
|
||||
elif ext in (".docx", ".doc"):
|
||||
return _read_docx(file_path)
|
||||
else:
|
||||
# txt / md / 其他文本格式
|
||||
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"读取文件失败 [{file_path}]: {e}")
|
||||
|
||||
|
||||
def _read_pdf(file_path: str) -> str:
|
||||
"""读取 PDF 文件文本"""
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
except ImportError:
|
||||
raise RuntimeError("读取 PDF 需要安装 pypdf:pip install pypdf")
|
||||
|
||||
reader = PdfReader(file_path)
|
||||
pages = [page.extract_text() or "" for page in reader.pages]
|
||||
return "\n".join(pages)
|
||||
|
||||
|
||||
def _read_docx(file_path: str) -> str:
|
||||
"""读取 Word 文档文本"""
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise RuntimeError("读取 docx 需要安装 python-docx:pip install python-docx")
|
||||
|
||||
doc = Document(file_path)
|
||||
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
return "\n".join(paragraphs)
|
||||
|
||||
|
||||
def merge_knowledge_files(file_paths: List[str]) -> str:
|
||||
"""
|
||||
合并多个知识库文件为单一文本
|
||||
合并多个知识库文件为单一文本。
|
||||
|
||||
Args:
|
||||
file_paths: 知识库文件路径列表
|
||||
file_paths: 文件路径列表
|
||||
|
||||
Returns:
|
||||
合并后的知识库文本(包含文件名分隔符)
|
||||
合并后的文本(各文件以分隔线隔开)
|
||||
"""
|
||||
if not file_paths:
|
||||
return ""
|
||||
|
||||
sections = []
|
||||
for fp in file_paths:
|
||||
parts = []
|
||||
for path in file_paths:
|
||||
try:
|
||||
content = read_file_auto(fp)
|
||||
file_name = Path(fp).name
|
||||
sections.append(f"### 知识库文件: {file_name}\n{content}")
|
||||
content = read_file_auto(path)
|
||||
parts.append(f"--- {os.path.basename(path)} ---\n{content}")
|
||||
except Exception as e:
|
||||
sections.append(f"### 知识库文件: {fp}\n[读取失败: {e}]")
|
||||
|
||||
return "\n\n".join(sections)
|
||||
parts.append(f"--- {os.path.basename(path)} [读取失败: {e}] ---")
|
||||
return "\n\n".join(parts)
|
||||
|
|
@ -6,133 +6,74 @@ from typing import Dict, List
|
|||
|
||||
import config
|
||||
|
||||
|
||||
# 各语言文件扩展名映射
|
||||
LANGUAGE_EXT_MAP: Dict[str, str] = {
|
||||
"python": ".py",
|
||||
"javascript": ".js",
|
||||
"typescript": ".ts",
|
||||
"java": ".java",
|
||||
"go": ".go",
|
||||
"rust": ".rs",
|
||||
"cpp": ".cpp",
|
||||
"c": ".c",
|
||||
"csharp": ".cs",
|
||||
"ruby": ".rb",
|
||||
"php": ".php",
|
||||
"swift": ".swift",
|
||||
"kotlin": ".kt",
|
||||
}
|
||||
|
||||
# 合法的通用类型集合
|
||||
VALID_TYPES = {
|
||||
"integer", "string", "boolean", "float",
|
||||
"list", "dict", "object", "void", "any",
|
||||
}
|
||||
|
||||
# 合法的 inout 值
|
||||
VALID_INOUT = {"in", "out", "inout"}
|
||||
|
||||
|
||||
def get_file_extension(language: str) -> str:
|
||||
"""
|
||||
获取指定语言的文件扩展名
|
||||
|
||||
Args:
|
||||
language: 编程语言名称(小写)
|
||||
|
||||
Returns:
|
||||
文件扩展名(含点号,如 '.py')
|
||||
"""
|
||||
return LANGUAGE_EXT_MAP.get(language.lower(), ".txt")
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# 目录管理
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def build_project_output_dir(project_name: str) -> str:
|
||||
"""
|
||||
构建项目输出目录路径
|
||||
|
||||
Args:
|
||||
project_name: 项目名称
|
||||
|
||||
Returns:
|
||||
输出目录路径
|
||||
"""
|
||||
safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_name)
|
||||
return os.path.join(config.OUTPUT_BASE_DIR, safe_name)
|
||||
safe = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_name)
|
||||
return os.path.join(config.OUTPUT_BASE_DIR, safe)
|
||||
|
||||
|
||||
def ensure_project_dir(project_name: str) -> str:
|
||||
"""
|
||||
确保项目输出目录存在,不存在则创建
|
||||
|
||||
Args:
|
||||
project_name: 项目名称
|
||||
|
||||
Returns:
|
||||
创建好的目录路径
|
||||
"""
|
||||
"""确保项目根输出目录存在,并创建 __init__.py"""
|
||||
output_dir = build_project_output_dir(project_name)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
init_file = os.path.join(output_dir, "__init__.py")
|
||||
if not os.path.exists(init_file):
|
||||
Path(init_file).write_text(
|
||||
"# Auto-generated project package\n", encoding="utf-8"
|
||||
)
|
||||
Path(init_file).write_text("# Auto-generated project package\n", encoding="utf-8")
|
||||
return output_dir
|
||||
|
||||
|
||||
def write_code_file(
|
||||
output_dir: str,
|
||||
function_name: str,
|
||||
language: str,
|
||||
content: str,
|
||||
) -> str:
|
||||
"""
|
||||
将代码内容写入指定目录的文件
|
||||
def ensure_module_dir(output_dir: str, module: str) -> str:
|
||||
"""确保模块子目录存在,并创建 __init__.py"""
|
||||
module_dir = os.path.join(output_dir, module)
|
||||
os.makedirs(module_dir, exist_ok=True)
|
||||
init_file = os.path.join(module_dir, "__init__.py")
|
||||
if not os.path.exists(init_file):
|
||||
Path(init_file).write_text(
|
||||
f"# Auto-generated module package: {module}\n", encoding="utf-8"
|
||||
)
|
||||
return module_dir
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录路径
|
||||
function_name: 函数名(用于生成文件名)
|
||||
language: 编程语言
|
||||
content: 代码内容
|
||||
|
||||
Returns:
|
||||
写入的文件完整路径
|
||||
"""
|
||||
ext = get_file_extension(language)
|
||||
file_name = f"{function_name}{ext}"
|
||||
file_path = os.path.join(output_dir, file_name)
|
||||
Path(file_path).write_text(content, encoding="utf-8")
|
||||
return file_path
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# README
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def write_project_readme(
|
||||
output_dir: str,
|
||||
project_name: str,
|
||||
output_dir: str,
|
||||
project_name: str,
|
||||
project_description: str,
|
||||
requirements_summary: str,
|
||||
modules: List[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
在项目目录生成 README.md 文件
|
||||
"""生成项目 README.md"""
|
||||
module_section = ""
|
||||
if modules:
|
||||
module_list = "\n".join(f"- `{m}/`" for m in sorted(set(modules)))
|
||||
module_section = f"\n## 功能模块\n\n{module_list}\n"
|
||||
|
||||
Args:
|
||||
output_dir: 项目输出目录
|
||||
project_name: 项目名称
|
||||
requirements_summary: 功能需求摘要文本
|
||||
|
||||
Returns:
|
||||
README.md 文件路径
|
||||
"""
|
||||
readme_content = f"""# {project_name}
|
||||
content = f"""# {project_name}
|
||||
|
||||
> Auto-generated by Requirement Analyzer
|
||||
|
||||
{project_description or ""}
|
||||
{module_section}
|
||||
## 功能需求列表
|
||||
|
||||
{requirements_summary}
|
||||
"""
|
||||
readme_path = os.path.join(output_dir, "README.md")
|
||||
Path(readme_path).write_text(readme_content, encoding="utf-8")
|
||||
return readme_path
|
||||
path = os.path.join(output_dir, "README.md")
|
||||
Path(path).write_text(content, encoding="utf-8")
|
||||
return path
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
|
@ -140,26 +81,18 @@ def write_project_readme(
|
|||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def build_signatures_document(
|
||||
project_name: str,
|
||||
project_name: str,
|
||||
project_description: str,
|
||||
signatures: List[dict],
|
||||
signatures: List[dict],
|
||||
) -> dict:
|
||||
"""
|
||||
将函数签名列表包装为带项目信息的顶层文档结构。
|
||||
构建顶层签名文档结构::
|
||||
|
||||
Args:
|
||||
project_name: 项目名称,写入 "project" 字段
|
||||
project_description: 项目描述,写入 "description" 字段
|
||||
signatures: 函数签名 dict 列表,写入 "functions" 字段
|
||||
|
||||
Returns:
|
||||
顶层文档 dict,结构为::
|
||||
|
||||
{
|
||||
"project": "<project_name>",
|
||||
"description": "<project_description>",
|
||||
"functions": [ ... ]
|
||||
}
|
||||
{
|
||||
"project": "<name>",
|
||||
"description": "<description>",
|
||||
"functions": [ ... ]
|
||||
}
|
||||
"""
|
||||
return {
|
||||
"project": project_name,
|
||||
|
|
@ -169,118 +102,45 @@ def build_signatures_document(
|
|||
|
||||
|
||||
def patch_signatures_with_url(
|
||||
signatures: List[dict],
|
||||
signatures: List[dict],
|
||||
func_name_to_url: Dict[str, str],
|
||||
) -> List[dict]:
|
||||
"""
|
||||
将代码文件的路径(URL)回写到对应函数签名的 "url" 字段。
|
||||
|
||||
遍历签名列表,根据 signature["name"] 在 func_name_to_url 中查找
|
||||
对应路径,找到则写入 "url" 字段;未找到则写入空字符串,不抛出异常。
|
||||
|
||||
"url" 字段插入位置紧跟在 "type" 字段之后,以保持字段顺序的可读性::
|
||||
|
||||
{
|
||||
"name": "create_user",
|
||||
"requirement_id": "REQ.01",
|
||||
"description": "...",
|
||||
"type": "function",
|
||||
"url": "/abs/path/to/create_user.py", ← 新增
|
||||
"parameters": { ... },
|
||||
"return": { ... }
|
||||
}
|
||||
将代码文件路径回写到签名的 "url" 字段(紧跟 "type" 之后)。
|
||||
|
||||
Args:
|
||||
signatures: 原始签名列表(in-place 修改)
|
||||
func_name_to_url: {函数名: 代码文件绝对路径} 映射表,
|
||||
由 CodeGenerator.generate_batch() 的进度回调收集
|
||||
signatures: 签名列表(in-place 修改)
|
||||
func_name_to_url: {函数名: 文件绝对路径}
|
||||
|
||||
Returns:
|
||||
修改后的签名列表(与传入的同一对象,方便链式调用)
|
||||
修改后的签名列表
|
||||
"""
|
||||
for sig in signatures:
|
||||
func_name = sig.get("name", "")
|
||||
url = func_name_to_url.get(func_name, "")
|
||||
url = func_name_to_url.get(sig.get("name", ""), "")
|
||||
_insert_field_after(sig, after_key="type", new_key="url", new_value=url)
|
||||
return signatures
|
||||
|
||||
|
||||
def _insert_field_after(
|
||||
d: dict,
|
||||
after_key: str,
|
||||
new_key: str,
|
||||
new_value,
|
||||
) -> None:
|
||||
"""
|
||||
在有序 dict 中将 new_key 插入到 after_key 之后。
|
||||
若 after_key 不存在,则追加到末尾。
|
||||
若 new_key 已存在,则直接更新其值(不改变位置)。
|
||||
|
||||
Args:
|
||||
d: 目标 dict(Python 3.7+ 保证插入顺序)
|
||||
after_key: 参考键名
|
||||
new_key: 要插入的键名
|
||||
new_value: 要插入的值
|
||||
"""
|
||||
def _insert_field_after(d: dict, after_key: str, new_key: str, new_value) -> None:
|
||||
"""在有序 dict 中将 new_key 插入到 after_key 之后"""
|
||||
if new_key in d:
|
||||
d[new_key] = new_value
|
||||
return
|
||||
|
||||
items = list(d.items())
|
||||
insert_pos = len(items)
|
||||
for i, (k, _) in enumerate(items):
|
||||
if k == after_key:
|
||||
insert_pos = i + 1
|
||||
break
|
||||
|
||||
insert_pos = next((i + 1 for i, (k, _) in enumerate(items) if k == after_key), len(items))
|
||||
items.insert(insert_pos, (new_key, new_value))
|
||||
d.clear()
|
||||
d.update(items)
|
||||
|
||||
|
||||
def write_function_signatures_json(
|
||||
output_dir: str,
|
||||
signatures: List[dict],
|
||||
project_name: str,
|
||||
output_dir: str,
|
||||
signatures: List[dict],
|
||||
project_name: str,
|
||||
project_description: str,
|
||||
file_name: str = "function_signatures.json",
|
||||
file_name: str = "function_signatures.json",
|
||||
) -> str:
|
||||
"""
|
||||
将函数签名列表连同项目信息一起导出为 JSON 文件。
|
||||
|
||||
输出的 JSON 顶层结构为::
|
||||
|
||||
{
|
||||
"project": "<project_name>",
|
||||
"description": "<project_description>",
|
||||
"functions": [
|
||||
{
|
||||
"name": "...",
|
||||
"requirement_id": "...",
|
||||
"description": "...",
|
||||
"type": "function",
|
||||
"url": "/abs/path/to/xxx.py",
|
||||
"parameters": { ... },
|
||||
"return": { ... }
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
Args:
|
||||
output_dir: JSON 文件写入目录
|
||||
signatures: 函数签名 dict 列表(应已通过
|
||||
patch_signatures_with_url() 写入 "url" 字段)
|
||||
project_name: 项目名称
|
||||
project_description: 项目描述
|
||||
file_name: 输出文件名,默认 function_signatures.json
|
||||
|
||||
Returns:
|
||||
写入的 JSON 文件完整路径
|
||||
|
||||
Raises:
|
||||
OSError: 目录不可写
|
||||
"""
|
||||
"""将签名列表导出为 JSON 文件"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
document = build_signatures_document(project_name, project_description, signatures)
|
||||
file_path = os.path.join(output_dir, file_name)
|
||||
|
|
@ -294,137 +154,70 @@ def write_function_signatures_json(
|
|||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def validate_signature_schema(signature: dict) -> List[str]:
|
||||
"""
|
||||
校验单个函数签名 dict 是否符合规范。
|
||||
|
||||
校验范围:
|
||||
- 顶层必填字段:name / requirement_id / description / type / parameters
|
||||
- 可选字段 "url":若存在则必须为非空字符串
|
||||
- parameters:每个参数的 type / inout / required 字段
|
||||
- return:type 字段 + on_success / on_failure 子结构
|
||||
- void 函数:on_success / on_failure 应为 null
|
||||
- 非 void 函数:on_success / on_failure 必须存在,
|
||||
且 value(非空)与 description(非空)均需填写
|
||||
|
||||
Args:
|
||||
signature: 单个函数签名 dict
|
||||
|
||||
Returns:
|
||||
错误信息字符串列表,列表为空表示校验通过
|
||||
"""
|
||||
"""校验单个函数签名结构,返回错误列表(空列表表示通过)"""
|
||||
errors: List[str] = []
|
||||
|
||||
# ── 顶层必填字段 ──────────────────────────────────
|
||||
for key in ("name", "requirement_id", "description", "type", "parameters"):
|
||||
if key not in signature:
|
||||
errors.append(f"缺少顶层字段: '{key}'")
|
||||
|
||||
# ── url 字段(可选,存在时校验非空)─────────────────
|
||||
if "url" in signature:
|
||||
if not isinstance(signature["url"], str):
|
||||
errors.append("'url' 字段必须是字符串类型")
|
||||
elif signature["url"] == "":
|
||||
errors.append("'url' 字段不能为空字符串(代码文件路径未成功回写)")
|
||||
errors.append("'url' 字段不能为空字符串")
|
||||
|
||||
# ── parameters ────────────────────────────────────
|
||||
params = signature.get("parameters", {})
|
||||
if not isinstance(params, dict):
|
||||
errors.append("'parameters' 必须是 dict 类型")
|
||||
else:
|
||||
if isinstance(params, dict):
|
||||
for pname, pdef in params.items():
|
||||
if not isinstance(pdef, dict):
|
||||
errors.append(f"参数 '{pname}' 定义必须是 dict")
|
||||
continue
|
||||
# type(支持联合类型,如 "string|integer")
|
||||
if "type" not in pdef:
|
||||
errors.append(f"参数 '{pname}' 缺少 'type' 字段")
|
||||
errors.append(f"参数 '{pname}' 缺少 'type'")
|
||||
else:
|
||||
parts = [p.strip() for p in pdef["type"].split("|")]
|
||||
if not all(p in VALID_TYPES for p in parts):
|
||||
errors.append(
|
||||
f"参数 '{pname}' 的 type='{pdef['type']}' 含有不合法的类型"
|
||||
)
|
||||
# inout
|
||||
errors.append(f"参数 '{pname}' type='{pdef['type']}' 含不合法类型")
|
||||
if "inout" not in pdef:
|
||||
errors.append(f"参数 '{pname}' 缺少 'inout' 字段")
|
||||
errors.append(f"参数 '{pname}' 缺少 'inout'")
|
||||
elif pdef["inout"] not in VALID_INOUT:
|
||||
errors.append(
|
||||
f"参数 '{pname}' 的 inout='{pdef['inout']}' 应为 in/out/inout"
|
||||
)
|
||||
# required
|
||||
errors.append(f"参数 '{pname}' inout='{pdef['inout']}' 应为 in/out/inout")
|
||||
if "required" not in pdef:
|
||||
errors.append(f"参数 '{pname}' 缺少 'required' 字段")
|
||||
errors.append(f"参数 '{pname}' 缺少 'required'")
|
||||
elif not isinstance(pdef["required"], bool):
|
||||
errors.append(
|
||||
f"参数 '{pname}' 的 'required' 应为布尔值 true/false,"
|
||||
f"当前为: {pdef['required']!r}"
|
||||
)
|
||||
errors.append(f"参数 '{pname}' 'required' 应为布尔值")
|
||||
|
||||
# ── return ────────────────────────────────────────
|
||||
ret = signature.get("return")
|
||||
if ret is None:
|
||||
errors.append(
|
||||
"缺少 'return' 字段(void 函数请填 "
|
||||
"{\"type\": \"void\", \"on_success\": null, \"on_failure\": null})"
|
||||
)
|
||||
elif not isinstance(ret, dict):
|
||||
errors.append("'return' 必须是 dict 类型")
|
||||
else:
|
||||
errors.append("缺少 'return' 字段")
|
||||
elif isinstance(ret, dict):
|
||||
ret_type = ret.get("type")
|
||||
if not ret_type:
|
||||
errors.append("'return' 缺少 'type' 字段")
|
||||
errors.append("'return' 缺少 'type'")
|
||||
elif ret_type not in VALID_TYPES:
|
||||
errors.append(f"'return.type'='{ret_type}' 不在合法类型列表中")
|
||||
|
||||
errors.append(f"'return.type'='{ret_type}' 不合法")
|
||||
is_void = (ret_type == "void")
|
||||
|
||||
for sub_key in ("on_success", "on_failure"):
|
||||
sub = ret.get(sub_key)
|
||||
if is_void:
|
||||
if sub is not None:
|
||||
errors.append(
|
||||
f"void 函数的 'return.{sub_key}' 应为 null,"
|
||||
f"当前为: {sub!r}"
|
||||
)
|
||||
errors.append(f"void 函数 'return.{sub_key}' 应为 null")
|
||||
else:
|
||||
if sub is None:
|
||||
errors.append(
|
||||
f"非 void 函数缺少 'return.{sub_key}',"
|
||||
f"请描述{'成功' if sub_key == 'on_success' else '失败'}时的返回值"
|
||||
)
|
||||
elif not isinstance(sub, dict):
|
||||
errors.append(f"'return.{sub_key}' 必须是 dict 类型")
|
||||
else:
|
||||
if "value" not in sub:
|
||||
errors.append(f"'return.{sub_key}' 缺少 'value' 字段")
|
||||
elif sub["value"] == "":
|
||||
errors.append(
|
||||
f"'return.{sub_key}.value' 不能为空字符串,"
|
||||
f"请填写具体返回值、值域描述或结构示例"
|
||||
)
|
||||
if "description" not in sub or sub.get("description") in (None, ""):
|
||||
errors.append(f"非 void 函数缺少 'return.{sub_key}'")
|
||||
elif isinstance(sub, dict):
|
||||
if not sub.get("value"):
|
||||
errors.append(f"'return.{sub_key}.value' 不能为空")
|
||||
if not sub.get("description"):
|
||||
errors.append(f"'return.{sub_key}.description' 不能为空")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_all_signatures(signatures: List[dict]) -> Dict[str, List[str]]:
|
||||
"""
|
||||
批量校验函数签名列表。
|
||||
|
||||
注意:此函数接受的是纯签名列表(即顶层文档的 "functions" 字段),
|
||||
而非包含 project/description 的顶层文档。
|
||||
|
||||
Args:
|
||||
signatures: 函数签名 dict 列表
|
||||
|
||||
Returns:
|
||||
{函数名: [错误信息, ...]} 字典,仅包含有错误的条目
|
||||
"""
|
||||
report: Dict[str, List[str]] = {}
|
||||
for sig in signatures:
|
||||
name = sig.get("name", f"unknown_{id(sig)}")
|
||||
errs = validate_signature_schema(sig)
|
||||
if errs:
|
||||
report[name] = errs
|
||||
return report
|
||||
"""批量校验,返回 {函数名: [错误]} 字典(仅含有错误的条目)"""
|
||||
return {
|
||||
sig.get("name", f"unknown_{i}"): errs
|
||||
for i, sig in enumerate(signatures)
|
||||
if (errs := validate_signature_schema(sig))
|
||||
}
|
||||
Loading…
Reference in New Issue