支持功能模块
This commit is contained in:
parent
3fc095652e
commit
29636b9b94
|
|
@ -1,146 +1,123 @@
|
|||
# config.py - 全局配置管理
|
||||
# config.py - 全局配置
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# ── LLM 配置 ──────────────────────────────────────────
|
||||
# ── LLM ──────────────────────────────────────────────
|
||||
LLM_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
LLM_API_BASE = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
|
||||
LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0.3"))
|
||||
LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT", "60"))
|
||||
LLM_MAX_RETRY = int(os.getenv("LLM_MAX_RETRY", "3"))
|
||||
|
||||
# ── 数据库配置 ─────────────────────────────────────────
|
||||
# ── 数据库 ────────────────────────────────────────────
|
||||
DB_PATH = os.getenv("DB_PATH", "data/requirement_analyzer.db")
|
||||
|
||||
# ── 输出配置 ───────────────────────────────────────────
|
||||
# ── 输出目录 ──────────────────────────────────────────
|
||||
OUTPUT_BASE_DIR = os.getenv("OUTPUT_BASE_DIR", "output")
|
||||
DEFAULT_LANGUAGE = os.getenv("DEFAULT_LANGUAGE", "python")
|
||||
DEFAULT_MODULE = os.getenv("DEFAULT_MODULE", "default")
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Prompt 模板
|
||||
# ══════════════════════════════════════════════════════
|
||||
# ── Prompt 模板 ───────────────────────────────────────
|
||||
|
||||
DECOMPOSE_PROMPT_TEMPLATE = """
|
||||
你是一位资深软件架构师和产品经理。请根据以下信息,将原始需求分解为若干个可独立实现的功能需求。
|
||||
DECOMPOSE_PROMPT_TEMPLATE = """\
|
||||
你是一名资深软件架构师,请将以下原始需求分解为独立的功能需求列表。
|
||||
|
||||
{knowledge_section}
|
||||
|
||||
## 原始需求
|
||||
【原始需求】
|
||||
{raw_requirement}
|
||||
|
||||
## 输出要求
|
||||
请严格按照以下 JSON 格式输出,不要包含任何额外说明:
|
||||
{knowledge_section}
|
||||
|
||||
【输出要求】
|
||||
以 JSON 数组格式输出,每个元素包含以下字段:
|
||||
- title: 功能标题(简短,10字以内)
|
||||
- description: 功能描述(详细说明该功能的职责与边界,50字以内)
|
||||
- function_name: 对应的函数名(snake_case,动词开头)
|
||||
- priority: 优先级(high / medium / low)
|
||||
- module: 所属功能模块名称(snake_case,如 user_auth / order_service)
|
||||
|
||||
【示例输出】
|
||||
[
|
||||
{{
|
||||
"functional_requirements": [
|
||||
{{
|
||||
"index": 1,
|
||||
"title": "功能需求标题(简洁,10字以内)",
|
||||
"description": "功能需求详细描述(包含输入、处理逻辑、输出)",
|
||||
"function_name": "snake_case函数名",
|
||||
"priority": "high|medium|low"
|
||||
"title": "用户注册",
|
||||
"description": "接收用户名、密码、邮箱,校验合法性后创建用户账号并返回用户ID",
|
||||
"function_name": "register_user",
|
||||
"priority": "high",
|
||||
"module": "user_auth"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
要求:
|
||||
1. 每个功能需求必须是独立可实现的最小单元
|
||||
2. function_name 使用 snake_case 命名,清晰表达函数用途
|
||||
3. 分解粒度适中,通常 5-15 个功能需求
|
||||
4. 优先级根据业务重要性判断
|
||||
只输出 JSON 数组,不要有任何额外说明。
|
||||
"""
|
||||
|
||||
# ── 函数签名 JSON 生成 Prompt ──────────────────────────
|
||||
FUNC_SIGNATURE_PROMPT_TEMPLATE = """
|
||||
你是一位资深软件架构师。请根据以下功能需求描述,设计该函数的完整接口签名,并以 JSON 格式输出。
|
||||
FUNC_SIGNATURE_PROMPT_TEMPLATE = """\
|
||||
你是一名资深软件工程师,请根据以下功能需求生成标准函数签名信息。
|
||||
|
||||
【功能需求】
|
||||
- 需求编号: {requirement_id}
|
||||
- 标题: {title}
|
||||
- 描述: {description}
|
||||
- 函数名: {function_name}
|
||||
- 所属模块: {module}
|
||||
|
||||
{knowledge_section}
|
||||
|
||||
## 功能需求
|
||||
需求编号:{requirement_id}
|
||||
标题:{title}
|
||||
函数名:{function_name}
|
||||
详细描述:{description}
|
||||
【输出要求】
|
||||
以 JSON 对象格式输出,包含以下字段:
|
||||
- name: 函数名(与上方一致)
|
||||
- requirement_id: 需求编号
|
||||
- description: 函数功能描述(英文,一句话)
|
||||
- type: 固定为 "function"
|
||||
- module: 所属模块名称
|
||||
- parameters: 参数字典,key 为参数名,value 包含:
|
||||
- type: 数据类型(integer/string/boolean/float/list/dict/object/void/any)
|
||||
- inout: in / out / inout
|
||||
- required: true / false
|
||||
- description: 参数说明
|
||||
- return:
|
||||
- type: 返回类型
|
||||
- on_success: {{ "value": "...", "description": "..." }} 或 null(void)
|
||||
- on_failure: {{ "value": "...", "description": "..." }} 或 null(void)
|
||||
|
||||
## 输出格式
|
||||
请严格按照以下 JSON 结构输出,不要包含任何额外说明或 markdown 标记:
|
||||
{{
|
||||
"name": "{function_name}",
|
||||
"requirement_id": "{requirement_id}",
|
||||
"description": "简洁的一句话功能描述(英文)",
|
||||
"type": "function",
|
||||
"parameters": {{
|
||||
"<param_name>": {{
|
||||
"type": "integer|string|boolean|float|list|dict|object",
|
||||
"inout": "in|out|inout",
|
||||
"description": "参数说明(英文)",
|
||||
"required": true
|
||||
}}
|
||||
}},
|
||||
"return": {{
|
||||
"type": "integer|string|boolean|float|list|dict|object|void",
|
||||
"description": "整体返回值说明(英文,一句话概括)",
|
||||
"on_success": {{
|
||||
"value": "具体成功返回值或范围,如 0、true、user object、list of items 等",
|
||||
"description": "成功时的返回值含义(英文)"
|
||||
}},
|
||||
"on_failure": {{
|
||||
"value": "具体失败返回值或范围,如 nonzero、false、null、empty list、raises Exception 等",
|
||||
"description": "失败时的返回值含义,或抛出的异常类型(英文)"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
|
||||
## 设计规范
|
||||
1. 参数名使用 snake_case,类型使用通用类型(不绑定具体语言)
|
||||
2. inout 字段含义:
|
||||
- in = 仅输入参数
|
||||
- out = 仅输出参数(通过参数传出结果,如指针/引用)
|
||||
- inout = 既作输入又作输出
|
||||
3. 所有描述字段使用英文
|
||||
4. return 字段规则:
|
||||
- 若函数无返回值(void),type 填 "void",on_success/on_failure 均填 null
|
||||
- 若返回值只有成功场景(如纯查询),on_failure 可描述为 "null or empty"
|
||||
- on_success.value / on_failure.value 填写具体值或值域描述,不要填写空字符串
|
||||
5. 若函数无参数,parameters 填 {{}}
|
||||
6. required 字段为布尔值 true 或 false
|
||||
只输出 JSON 对象,不要有任何额外说明。
|
||||
"""
|
||||
|
||||
# ── 代码生成 Prompt(含签名约束)─────────────────────────
|
||||
CODE_GEN_PROMPT_TEMPLATE = """
|
||||
你是一位资深 {language} 工程师。请根据以下功能需求和【函数签名规范】,生成完整的 {language} 函数代码。
|
||||
CODE_GEN_PROMPT_TEMPLATE = """\
|
||||
你是一名资深 {language} 工程师,请根据以下函数签名和功能描述生成完整的函数实现代码。
|
||||
|
||||
{knowledge_section}
|
||||
|
||||
## 功能需求
|
||||
标题:{title}
|
||||
描述:{description}
|
||||
|
||||
## 【必须严格遵守】函数签名规范
|
||||
以下 JSON 定义了函数的精确接口,生成的代码必须与之完全一致,不得擅自增减或改名参数:
|
||||
|
||||
```json
|
||||
【函数签名】
|
||||
{signature_json}
|
||||
```
|
||||
|
||||
### 签名字段说明
|
||||
- `name`:函数名,必须完全一致
|
||||
- `parameters`:每个 key 即为参数名,`type` 为数据类型,`inout` 含义:
|
||||
- `in` = 普通输入参数
|
||||
- `out` = 输出参数(Python 中通过返回值或可变容器传出)
|
||||
- `inout` = 既作输入又作输出
|
||||
- `return.type`:返回值类型
|
||||
- `return.on_success`:成功时的返回值,代码实现必须与此一致
|
||||
- `return.on_failure`:失败时的返回值或异常,代码实现必须与此一致
|
||||
【功能描述】
|
||||
{description}
|
||||
|
||||
## 输出要求
|
||||
1. 只输出纯代码,不要包含 markdown 代码块标记
|
||||
2. 函数签名(名称、参数列表、返回类型)必须与上方 JSON 规范完全一致
|
||||
3. 成功/失败的返回值必须严格遵守 return.on_success / return.on_failure 的定义
|
||||
4. 包含完整的类型注解(Python 使用 type hints)
|
||||
5. 包含详细的 docstring,其中 Returns 段须注明成功值与失败值
|
||||
6. 包含必要的异常处理
|
||||
7. 代码风格遵循 PEP8(Python)或对应语言规范
|
||||
8. 在文件顶部用注释注明:需求编号、功能标题、函数签名摘要
|
||||
9. 如需导入第三方库,请在顶部统一导入
|
||||
{knowledge_section}
|
||||
|
||||
【输出要求】
|
||||
1. 只输出 {language} 代码,不要有任何 Markdown 标记(不要 ```)
|
||||
2. 包含完整的函数实现(含必要的 import)
|
||||
3. 包含函数文档注释(docstring / JSDoc 等)
|
||||
4. 包含基本的参数校验与错误处理
|
||||
5. 代码风格遵循 {language} 最佳实践
|
||||
"""
|
||||
|
||||
MODULE_CLASSIFY_PROMPT_TEMPLATE = """\
|
||||
你是一名资深软件架构师,请将以下功能需求列表分类到合适的功能模块中。
|
||||
|
||||
【功能需求列表】
|
||||
{requirements_json}
|
||||
|
||||
【输出要求】
|
||||
以 JSON 数组格式输出,每个元素包含:
|
||||
- function_name: 函数名(与输入一致)
|
||||
- module: 所属模块名称(snake_case,如 user_auth / order_service / payment)
|
||||
|
||||
模块划分原则:
|
||||
1. 功能相近的需求归入同一模块
|
||||
2. 模块名使用英文 snake_case
|
||||
3. 模块数量控制在 2~8 个之间
|
||||
4. 若某需求确实无法归类,使用 "default" 模块
|
||||
|
||||
只输出 JSON 数组,不要有任何额外说明。
|
||||
"""
|
||||
|
|
@ -1,97 +1,87 @@
|
|||
# core/code_generator.py - 代码生成核心逻辑(签名约束版)
|
||||
# core/code_generator.py - 代码生成(按模块路由到子目录)
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, List, Callable
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Callable
|
||||
|
||||
import config
|
||||
from core.llm_client import LLMClient
|
||||
from database.models import FunctionalRequirement, CodeFile
|
||||
from utils.output_writer import write_code_file, get_file_extension
|
||||
|
||||
|
||||
class CodeGenerator:
|
||||
"""
|
||||
根据功能需求 + 函数签名约束,使用 LLM 生成代码函数文件。
|
||||
签名由 RequirementAnalyzer.build_function_signature() 预先生成,
|
||||
注入 Prompt 后可确保代码参数列表与签名 JSON 完全一致。
|
||||
"""
|
||||
"""根据函数签名约束,调用 LLM 生成代码文件,并按模块写入子目录"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
"""
|
||||
初始化代码生成器
|
||||
|
||||
Args:
|
||||
llm_client: LLM 客户端实例,为 None 时自动创建
|
||||
"""
|
||||
self.llm = llm_client or LLMClient()
|
||||
def __init__(self, llm: LLMClient):
|
||||
self.llm = llm
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 单个生成
|
||||
# 单个代码文件生成
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def generate(
|
||||
self,
|
||||
func_req: FunctionalRequirement,
|
||||
output_dir: str,
|
||||
language: str = config.DEFAULT_LANGUAGE,
|
||||
language: str = None,
|
||||
knowledge: str = "",
|
||||
signature: Optional[dict] = None,
|
||||
signature: dict = None,
|
||||
) -> CodeFile:
|
||||
"""
|
||||
为单个功能需求生成代码文件
|
||||
为单个功能需求生成代码文件,写入 output_dir/<module>/ 子目录。
|
||||
|
||||
Args:
|
||||
func_req: 功能需求对象(必须含有效 id)
|
||||
output_dir: 代码输出目录
|
||||
language: 目标编程语言
|
||||
knowledge: 知识库文本(可选)
|
||||
signature: 函数签名 dict(由 RequirementAnalyzer 生成)。
|
||||
传入后将作为强约束注入 Prompt,确保代码参数
|
||||
与签名 JSON 完全一致;为 None 时退化为无约束模式。
|
||||
func_req: 功能需求对象
|
||||
output_dir: 项目根输出目录
|
||||
language: 目标语言
|
||||
knowledge: 知识库文本
|
||||
signature: 函数签名 dict(可选,有则作为约束)
|
||||
|
||||
Returns:
|
||||
CodeFile 对象(含生成的代码内容和文件路径,未持久化)
|
||||
CodeFile 对象(未持久化)
|
||||
|
||||
Raises:
|
||||
ValueError: func_req.id 为 None
|
||||
Exception: LLM 调用失败或文件写入失败
|
||||
RuntimeError: LLM 调用失败
|
||||
"""
|
||||
if func_req.id is None:
|
||||
raise ValueError("FunctionalRequirement 必须先持久化(id 不能为 None)")
|
||||
language = language or config.DEFAULT_LANGUAGE
|
||||
module = (func_req.module or config.DEFAULT_MODULE).strip()
|
||||
|
||||
knowledge_section = self._build_knowledge_section(knowledge)
|
||||
signature_json = self._build_signature_json(signature, func_req)
|
||||
# 按模块创建子目录
|
||||
module_dir = os.path.join(output_dir, module)
|
||||
os.makedirs(module_dir, exist_ok=True)
|
||||
self._ensure_init_py(module_dir)
|
||||
|
||||
# 构建 Prompt
|
||||
sig_json = json.dumps(signature, ensure_ascii=False, indent=2) if signature else "{}"
|
||||
knowledge_section = f"【参考知识库】\n{knowledge}\n" if knowledge else ""
|
||||
prompt = config.CODE_GEN_PROMPT_TEMPLATE.format(
|
||||
language = language,
|
||||
knowledge_section=knowledge_section,
|
||||
title=func_req.title,
|
||||
signature_json = sig_json,
|
||||
description = func_req.description,
|
||||
signature_json=signature_json,
|
||||
knowledge_section = knowledge_section,
|
||||
)
|
||||
|
||||
try:
|
||||
code_content = self.llm.chat(
|
||||
system_prompt=(
|
||||
f"你是一位资深 {language} 工程师,只输出纯代码,"
|
||||
"不添加任何 markdown 标记。函数签名必须与提供的 JSON 规范完全一致。"
|
||||
),
|
||||
user_prompt=prompt,
|
||||
prompt,
|
||||
system = f"You are an expert {language} developer. Output only code.",
|
||||
temperature = 0.2,
|
||||
max_tokens = 4096,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"代码生成失败 [{func_req.function_name}]: {e}")
|
||||
|
||||
file_path = write_code_file(
|
||||
output_dir=output_dir,
|
||||
function_name=func_req.function_name,
|
||||
language=language,
|
||||
content=code_content,
|
||||
)
|
||||
|
||||
ext = get_file_extension(language)
|
||||
# 写入文件
|
||||
ext = self._get_extension(language)
|
||||
file_name = f"{func_req.function_name}{ext}"
|
||||
file_path = os.path.join(module_dir, file_name)
|
||||
Path(file_path).write_text(code_content, encoding="utf-8")
|
||||
|
||||
return CodeFile(
|
||||
project_id=func_req.project_id,
|
||||
func_req_id = func_req.id,
|
||||
file_name = file_name,
|
||||
file_path = file_path,
|
||||
module = module,
|
||||
language = language,
|
||||
content = code_content,
|
||||
)
|
||||
|
|
@ -104,33 +94,31 @@ class CodeGenerator:
|
|||
self,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
output_dir: str,
|
||||
language: str = config.DEFAULT_LANGUAGE,
|
||||
language: str = None,
|
||||
knowledge: str = "",
|
||||
signatures: Optional[List[dict]] = None,
|
||||
on_progress: Optional[Callable] = None,
|
||||
) -> List[CodeFile]:
|
||||
"""
|
||||
批量生成代码文件
|
||||
批量生成代码文件。
|
||||
|
||||
Args:
|
||||
func_reqs: 功能需求列表
|
||||
output_dir: 输出目录
|
||||
output_dir: 项目根输出目录
|
||||
language: 目标语言
|
||||
knowledge: 知识库文本
|
||||
signatures: 与 func_reqs 等长的签名列表(索引对应)。
|
||||
为 None 时所有条目均以无约束模式生成。
|
||||
on_progress: 进度回调 fn(index, total, func_req, code_file, error)
|
||||
signatures: 与 func_reqs 等长的签名列表(索引对应)
|
||||
on_progress: 进度回调 fn(index, total, req, code_file, error)
|
||||
|
||||
Returns:
|
||||
成功生成的 CodeFile 列表
|
||||
"""
|
||||
results = []
|
||||
language = language or config.DEFAULT_LANGUAGE
|
||||
total = len(func_reqs)
|
||||
|
||||
# 构建 func_req.id → signature 的快速查找表
|
||||
results = []
|
||||
sig_map = self._build_signature_map(func_reqs, signatures)
|
||||
|
||||
for i, req in enumerate(func_reqs):
|
||||
for i, req in enumerate(func_reqs, 1):
|
||||
sig = sig_map.get(req.id)
|
||||
try:
|
||||
code_file = self.generate(
|
||||
|
|
@ -142,75 +130,44 @@ class CodeGenerator:
|
|||
)
|
||||
results.append(code_file)
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, req, code_file, None)
|
||||
on_progress(i, total, req, code_file, None)
|
||||
except Exception as e:
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, req, None, e)
|
||||
on_progress(i, total, req, None, e)
|
||||
|
||||
return results
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 私有工具方法
|
||||
# 工具方法
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
def _build_knowledge_section(knowledge: str) -> str:
|
||||
"""构建知识库 Prompt 段落"""
|
||||
if not knowledge or not knowledge.strip():
|
||||
return ""
|
||||
return (
|
||||
"## 参考知识库(实现时请遵循以下规范)\n"
|
||||
f"{knowledge}\n\n---\n"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_signature_json(
|
||||
signature: Optional[dict],
|
||||
func_req: FunctionalRequirement,
|
||||
) -> str:
|
||||
"""
|
||||
将签名 dict 序列化为格式化 JSON 字符串;
|
||||
若签名为 None,则构造最小占位签名,保持 Prompt 结构完整。
|
||||
|
||||
Args:
|
||||
signature: 签名 dict 或 None
|
||||
func_req: 对应的功能需求(用于占位签名)
|
||||
|
||||
Returns:
|
||||
JSON 字符串
|
||||
"""
|
||||
if signature:
|
||||
return json.dumps(signature, ensure_ascii=False, indent=2)
|
||||
# 无签名时的最小占位,提示 LLM 自行设计但保持格式
|
||||
fallback = {
|
||||
"name": func_req.function_name,
|
||||
"requirement_id": f"REQ.{func_req.index_no:02d}",
|
||||
"description": func_req.description,
|
||||
"type": "function",
|
||||
"parameters": "<<请根据功能描述自行设计参数>>",
|
||||
"return": "<<请根据功能描述自行设计返回值>>",
|
||||
}
|
||||
return json.dumps(fallback, ensure_ascii=False, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def _build_signature_map(
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
signatures: Optional[List[dict]],
|
||||
) -> dict:
|
||||
"""
|
||||
构建 func_req.id → signature 映射表
|
||||
|
||||
Args:
|
||||
func_reqs: 功能需求列表
|
||||
signatures: 与 func_reqs 等长的签名列表,或 None
|
||||
|
||||
Returns:
|
||||
{req_id: signature_dict} 字典
|
||||
"""
|
||||
"""构建 func_req.id → signature 的快速查找表"""
|
||||
if not signatures:
|
||||
return {}
|
||||
sig_map = {}
|
||||
for req, sig in zip(func_reqs, signatures):
|
||||
if req.id is not None and sig:
|
||||
sig_map[req.id] = sig
|
||||
return sig_map
|
||||
return {
|
||||
req.id: sig
|
||||
for req, sig in zip(func_reqs, signatures)
|
||||
if req.id is not None
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_extension(language: str) -> str:
|
||||
ext_map = {
|
||||
"python": ".py", "javascript": ".js", "typescript": ".ts",
|
||||
"java": ".java", "go": ".go", "rust": ".rs",
|
||||
"cpp": ".cpp", "c": ".c", "csharp": ".cs",
|
||||
"ruby": ".rb", "php": ".php", "swift": ".swift", "kotlin": ".kt",
|
||||
}
|
||||
return ext_map.get(language.lower(), ".txt")
|
||||
|
||||
@staticmethod
|
||||
def _ensure_init_py(directory: str) -> None:
|
||||
"""在目录中创建 __init__.py(Python 包标识)"""
|
||||
init = os.path.join(directory, "__init__.py")
|
||||
if not os.path.exists(init):
|
||||
Path(init).write_text("# Auto-generated module package\n", encoding="utf-8")
|
||||
|
|
@ -1,90 +1,114 @@
|
|||
# core/llm_client.py - LLM 客户端封装
|
||||
# core/llm_client.py - LLM API 调用封装
|
||||
import time
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from openai import OpenAI, APIError, APITimeoutError, RateLimitError
|
||||
|
||||
import config
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""
|
||||
OpenAI 兼容 LLM 客户端封装。
|
||||
支持任何兼容 OpenAI API 格式的服务(OpenAI / Azure / 本地模型等)。
|
||||
"""
|
||||
"""封装 OpenAI 兼容接口,提供统一的调用入口与重试机制"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = config.LLM_API_KEY,
|
||||
base_url: str = config.LLM_BASE_URL,
|
||||
model: str = config.LLM_MODEL,
|
||||
temperature: float = config.LLM_TEMPERATURE,
|
||||
api_key: str = None,
|
||||
api_base: str = None,
|
||||
model: str = None,
|
||||
):
|
||||
"""
|
||||
初始化 LLM 客户端
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
base_url: API 基础 URL
|
||||
model: 模型名称
|
||||
temperature: 生成温度(0~1,越低越确定)
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 openai 库
|
||||
ValueError: api_key 为空
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI
|
||||
except ImportError:
|
||||
raise ImportError("请安装 openai: pip install openai")
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置,请在 .env 文件中设置")
|
||||
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self._client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
def chat(self, system_prompt: str, user_prompt: str) -> str:
|
||||
"""
|
||||
发送对话请求,返回模型回复文本
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户输入
|
||||
|
||||
Returns:
|
||||
模型回复的文本内容
|
||||
|
||||
Raises:
|
||||
Exception: API 调用失败
|
||||
"""
|
||||
response = self._client.chat.completions.create(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
self.model = model or config.LLM_MODEL
|
||||
self.client = OpenAI(
|
||||
api_key = api_key or config.LLM_API_KEY,
|
||||
base_url = api_base or config.LLM_API_BASE,
|
||||
timeout = config.LLM_TIMEOUT,
|
||||
)
|
||||
return response.choices[0].message.content.strip()
|
||||
|
||||
def chat_json(self, system_prompt: str, user_prompt: str) -> dict:
|
||||
def chat(
|
||||
self,
|
||||
prompt: str,
|
||||
system: str = "You are a helpful assistant.",
|
||||
temperature: float = 0.2,
|
||||
max_tokens: int = 4096,
|
||||
) -> str:
|
||||
"""
|
||||
发送对话请求,解析并返回 JSON 结果
|
||||
发送单轮对话请求,返回模型回复文本。
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户输入
|
||||
prompt: 用户消息
|
||||
system: 系统提示词
|
||||
temperature: 采样温度
|
||||
max_tokens: 最大输出 token 数
|
||||
|
||||
Returns:
|
||||
解析后的 dict 对象
|
||||
模型回复的纯文本字符串
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: 模型返回非合法 JSON
|
||||
RuntimeError: 超过最大重试次数后仍失败
|
||||
"""
|
||||
raw = self.chat(system_prompt, user_prompt)
|
||||
last_error = None
|
||||
for attempt in range(1, config.LLM_MAX_RETRY + 1):
|
||||
try:
|
||||
resp = self.client.chat.completions.create(
|
||||
model = self.model,
|
||||
messages = [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature = temperature,
|
||||
max_tokens = max_tokens,
|
||||
)
|
||||
return resp.choices[0].message.content.strip()
|
||||
except RateLimitError as e:
|
||||
wait = 2 ** attempt
|
||||
last_error = e
|
||||
time.sleep(wait)
|
||||
except APITimeoutError as e:
|
||||
last_error = e
|
||||
time.sleep(2)
|
||||
except APIError as e:
|
||||
last_error = e
|
||||
if attempt >= config.LLM_MAX_RETRY:
|
||||
break
|
||||
time.sleep(1)
|
||||
raise RuntimeError(
|
||||
f"LLM 调用失败(已重试 {config.LLM_MAX_RETRY} 次): {last_error}"
|
||||
)
|
||||
|
||||
def chat_json(
|
||||
self,
|
||||
prompt: str,
|
||||
system: str = "You are a helpful assistant. Always respond with valid JSON.",
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
) -> any:
|
||||
"""
|
||||
发送请求并将回复解析为 JSON 对象。
|
||||
|
||||
Returns:
|
||||
解析后的 Python 对象(dict 或 list)
|
||||
|
||||
Raises:
|
||||
ValueError: JSON 解析失败
|
||||
RuntimeError: LLM 调用失败
|
||||
"""
|
||||
raw = self.chat(prompt, system=system, temperature=temperature, max_tokens=max_tokens)
|
||||
# 去除可能的 markdown 代码块包裹
|
||||
raw = raw.strip()
|
||||
if raw.startswith("```"):
|
||||
lines = raw.split("\n")
|
||||
raw = "\n".join(lines[1:-1])
|
||||
return json.loads(raw)
|
||||
cleaned = self._strip_markdown_code_block(raw)
|
||||
try:
|
||||
return json.loads(cleaned)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"LLM 返回内容无法解析为 JSON: {e}\n原始内容:\n{raw}")
|
||||
|
||||
@staticmethod
|
||||
def _strip_markdown_code_block(text: str) -> str:
|
||||
"""去除 ```json ... ``` 或 ``` ... ``` 包裹"""
|
||||
text = text.strip()
|
||||
if text.startswith("```"):
|
||||
lines = text.splitlines()
|
||||
# 去掉首行(```json 或 ```)和末行(```)
|
||||
inner = lines[1:] if lines[-1].strip() == "```" else lines[1:]
|
||||
if inner and inner[-1].strip() == "```":
|
||||
inner = inner[:-1]
|
||||
text = "\n".join(inner).strip()
|
||||
return text
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
# core/requirement_analyzer.py - 需求分解 & 函数签名生成
|
||||
import re
|
||||
from typing import List, Optional
|
||||
import json
|
||||
from typing import List, Optional, Callable
|
||||
|
||||
import config
|
||||
from core.llm_client import LLMClient
|
||||
|
|
@ -8,19 +8,10 @@ from database.models import FunctionalRequirement
|
|||
|
||||
|
||||
class RequirementAnalyzer:
|
||||
"""
|
||||
使用 LLM 将原始需求分解为功能需求列表,并生成函数接口签名。
|
||||
支持注入知识库上下文以提升分解质量。
|
||||
"""
|
||||
"""负责需求分解、模块分类、函数签名生成"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
"""
|
||||
初始化需求分析器
|
||||
|
||||
Args:
|
||||
llm_client: LLM 客户端实例,为 None 时自动创建
|
||||
"""
|
||||
self.llm = llm_client or LLMClient()
|
||||
def __init__(self, llm: LLMClient):
|
||||
self.llm = llm
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 需求分解
|
||||
|
|
@ -34,7 +25,7 @@ class RequirementAnalyzer:
|
|||
knowledge: str = "",
|
||||
) -> List[FunctionalRequirement]:
|
||||
"""
|
||||
将原始需求分解为功能需求列表
|
||||
将原始需求文本分解为功能需求列表(含模块分类)。
|
||||
|
||||
Args:
|
||||
raw_requirement: 原始需求文本
|
||||
|
|
@ -44,196 +35,175 @@ class RequirementAnalyzer:
|
|||
|
||||
Returns:
|
||||
FunctionalRequirement 对象列表(未持久化,id=None)
|
||||
|
||||
Raises:
|
||||
ValueError: LLM 返回格式不合法
|
||||
json.JSONDecodeError: JSON 解析失败
|
||||
"""
|
||||
knowledge_section = self._build_knowledge_section(knowledge)
|
||||
knowledge_section = (
|
||||
f"【参考知识库】\n{knowledge}\n" if knowledge else ""
|
||||
)
|
||||
prompt = config.DECOMPOSE_PROMPT_TEMPLATE.format(
|
||||
knowledge_section=knowledge_section,
|
||||
raw_requirement = raw_requirement,
|
||||
knowledge_section = knowledge_section,
|
||||
)
|
||||
|
||||
result = self.llm.chat_json(
|
||||
system_prompt="你是一位资深软件架构师,擅长需求分析与系统设计。",
|
||||
user_prompt=prompt,
|
||||
)
|
||||
try:
|
||||
items = self.llm.chat_json(prompt)
|
||||
if not isinstance(items, list):
|
||||
raise ValueError("LLM 返回结果不是数组")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"需求分解失败: {e}")
|
||||
|
||||
items = result.get("functional_requirements", [])
|
||||
if not items:
|
||||
raise ValueError("LLM 未返回任何功能需求,请检查原始需求描述")
|
||||
|
||||
requirements = []
|
||||
for item in items:
|
||||
reqs = []
|
||||
for i, item in enumerate(items, 1):
|
||||
req = FunctionalRequirement(
|
||||
project_id = project_id,
|
||||
raw_req_id = raw_req_id,
|
||||
index_no=int(item.get("index", len(requirements) + 1)),
|
||||
title=item.get("title", "未命名功能"),
|
||||
index_no = i,
|
||||
title = item.get("title", f"功能{i}"),
|
||||
description = item.get("description", ""),
|
||||
function_name=self._sanitize_function_name(
|
||||
item.get("function_name", f"func_{len(requirements)+1}")
|
||||
),
|
||||
function_name = item.get("function_name", f"function_{i}"),
|
||||
priority = item.get("priority", "medium"),
|
||||
module = item.get("module", config.DEFAULT_MODULE),
|
||||
status = "pending",
|
||||
is_custom = False,
|
||||
)
|
||||
requirements.append(req)
|
||||
|
||||
return requirements
|
||||
reqs.append(req)
|
||||
return reqs
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 函数签名生成(新增)
|
||||
# 模块分类(独立步骤,可对已有需求列表重新分类)
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def classify_modules(
|
||||
self,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
knowledge: str = "",
|
||||
) -> List[dict]:
|
||||
"""
|
||||
对功能需求列表进行模块分类,返回 {function_name: module} 映射列表。
|
||||
|
||||
Args:
|
||||
func_reqs: 功能需求列表
|
||||
knowledge: 知识库文本(可选)
|
||||
|
||||
Returns:
|
||||
[{"function_name": "...", "module": "..."}, ...]
|
||||
"""
|
||||
req_list = [
|
||||
{
|
||||
"index_no": r.index_no,
|
||||
"title": r.title,
|
||||
"description": r.description,
|
||||
"function_name": r.function_name,
|
||||
}
|
||||
for r in func_reqs
|
||||
]
|
||||
knowledge_section = f"【参考知识库】\n{knowledge}\n" if knowledge else ""
|
||||
prompt = config.MODULE_CLASSIFY_PROMPT_TEMPLATE.format(
|
||||
requirements_json = json.dumps(req_list, ensure_ascii=False, indent=2),
|
||||
knowledge_section = knowledge_section,
|
||||
)
|
||||
try:
|
||||
result = self.llm.chat_json(prompt)
|
||||
if not isinstance(result, list):
|
||||
raise ValueError("LLM 返回结果不是数组")
|
||||
return result
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"模块分类失败: {e}")
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 函数签名生成
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def build_function_signature(
|
||||
self,
|
||||
func_req: FunctionalRequirement,
|
||||
requirement_id: str = "",
|
||||
knowledge: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
为单个功能需求生成函数接口签名 JSON
|
||||
为单个功能需求生成函数签名 dict。
|
||||
|
||||
Args:
|
||||
func_req: 功能需求对象(需含有效 id)
|
||||
knowledge: 知识库文本(可选)
|
||||
func_req: 功能需求对象
|
||||
requirement_id: 需求编号字符串(如 "REQ.01")
|
||||
knowledge: 知识库文本
|
||||
|
||||
Returns:
|
||||
符合接口规范的 dict,包含 name/requirement_id/description/
|
||||
type/parameters/return 字段
|
||||
函数签名 dict
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: LLM 返回非合法 JSON
|
||||
RuntimeError: LLM 调用或解析失败
|
||||
"""
|
||||
requirement_id = self._format_requirement_id(func_req.index_no)
|
||||
knowledge_section = self._build_knowledge_section(knowledge)
|
||||
|
||||
knowledge_section = f"【参考知识库】\n{knowledge}\n" if knowledge else ""
|
||||
prompt = config.FUNC_SIGNATURE_PROMPT_TEMPLATE.format(
|
||||
knowledge_section=knowledge_section,
|
||||
requirement_id=requirement_id,
|
||||
requirement_id = requirement_id or f"REQ.{func_req.index_no:02d}",
|
||||
title = func_req.title,
|
||||
function_name=func_req.function_name,
|
||||
description = func_req.description,
|
||||
function_name = func_req.function_name,
|
||||
module = func_req.module or config.DEFAULT_MODULE,
|
||||
knowledge_section = knowledge_section,
|
||||
)
|
||||
|
||||
signature = self.llm.chat_json(
|
||||
system_prompt=(
|
||||
"你是一位资深软件架构师,专注于 API 接口设计。"
|
||||
"只输出合法 JSON,不添加任何说明文字。"
|
||||
),
|
||||
user_prompt=prompt,
|
||||
)
|
||||
|
||||
# 确保关键字段存在,做兜底处理
|
||||
signature.setdefault("name", func_req.function_name)
|
||||
signature.setdefault("requirement_id", requirement_id)
|
||||
signature.setdefault("description", func_req.description)
|
||||
signature.setdefault("type", "function")
|
||||
signature.setdefault("parameters", {})
|
||||
signature.setdefault("return", {"type": "void", "description": ""})
|
||||
|
||||
return signature
|
||||
try:
|
||||
sig = self.llm.chat_json(prompt)
|
||||
if not isinstance(sig, dict):
|
||||
raise ValueError("LLM 返回结果不是 dict")
|
||||
# 确保 module 字段存在
|
||||
if "module" not in sig:
|
||||
sig["module"] = func_req.module or config.DEFAULT_MODULE
|
||||
return sig
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"签名生成失败 [{func_req.function_name}]: {e}")
|
||||
|
||||
def build_function_signatures_batch(
|
||||
self,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
knowledge: str = "",
|
||||
on_progress=None,
|
||||
on_progress: Optional[Callable] = None,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
批量为功能需求列表生成函数接口签名
|
||||
批量生成函数签名,失败时使用降级结构。
|
||||
|
||||
Args:
|
||||
func_reqs: 功能需求列表
|
||||
knowledge: 知识库文本(可选)
|
||||
on_progress: 进度回调 fn(index, total, func_req, signature, error)
|
||||
knowledge: 知识库文本
|
||||
on_progress: 进度回调 fn(index, total, req, signature, error)
|
||||
|
||||
Returns:
|
||||
签名 dict 列表,顺序与 func_reqs 一致;
|
||||
生成失败的条目使用降级结构填充,不中断整体流程
|
||||
与 func_reqs 等长的签名 dict 列表(索引一一对应)
|
||||
"""
|
||||
results = []
|
||||
signatures = []
|
||||
total = len(func_reqs)
|
||||
|
||||
for i, req in enumerate(func_reqs):
|
||||
for i, req in enumerate(func_reqs, 1):
|
||||
req_id = f"REQ.{req.index_no:02d}"
|
||||
try:
|
||||
sig = self.build_function_signature(req, knowledge)
|
||||
results.append(sig)
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, req, sig, None)
|
||||
sig = self.build_function_signature(req, req_id, knowledge)
|
||||
error = None
|
||||
except Exception as e:
|
||||
# 降级:用基础信息填充,保证 JSON 完整性
|
||||
fallback = self._build_fallback_signature(req)
|
||||
results.append(fallback)
|
||||
sig = self._fallback_signature(req, req_id)
|
||||
error = e
|
||||
|
||||
signatures.append(sig)
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, req, fallback, e)
|
||||
on_progress(i, total, req, sig, error)
|
||||
|
||||
return results
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 私有工具方法
|
||||
# ══════════════════════════════════════════════════
|
||||
return signatures
|
||||
|
||||
@staticmethod
|
||||
def _build_knowledge_section(knowledge: str) -> str:
|
||||
"""构建知识库 Prompt 段落"""
|
||||
if not knowledge or not knowledge.strip():
|
||||
return ""
|
||||
return f"""## 参考知识库
|
||||
{knowledge}
|
||||
|
||||
---
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_function_name(name: str) -> str:
|
||||
"""
|
||||
清理函数名,确保符合 snake_case 规范
|
||||
|
||||
Args:
|
||||
name: 原始函数名
|
||||
|
||||
Returns:
|
||||
合法的 snake_case 函数名
|
||||
"""
|
||||
name = re.sub(r"[^a-zA-Z0-9_]", "_", name).lower()
|
||||
name = re.sub(r"_+", "_", name).strip("_")
|
||||
if name and name[0].isdigit():
|
||||
name = "func_" + name
|
||||
return name or "unnamed_function"
|
||||
|
||||
@staticmethod
|
||||
def _format_requirement_id(index_no: int) -> str:
|
||||
"""
|
||||
将序号格式化为需求编号字符串
|
||||
|
||||
Args:
|
||||
index_no: 功能需求序号(从 1 开始)
|
||||
|
||||
Returns:
|
||||
格式化编号,如 'REQ.01'、'REQ.12'
|
||||
"""
|
||||
return f"REQ.{index_no:02d}"
|
||||
|
||||
@staticmethod
|
||||
def _build_fallback_signature(func_req: FunctionalRequirement) -> dict:
|
||||
"""
|
||||
构建降级签名(LLM 调用失败时使用)
|
||||
|
||||
Args:
|
||||
func_req: 功能需求对象
|
||||
|
||||
Returns:
|
||||
包含基础信息的签名 dict
|
||||
"""
|
||||
def _fallback_signature(
|
||||
req: FunctionalRequirement,
|
||||
requirement_id: str,
|
||||
) -> dict:
|
||||
"""生成降级签名结构(LLM 失败时使用)"""
|
||||
return {
|
||||
"name": func_req.function_name,
|
||||
"requirement_id": f"REQ.{func_req.index_no:02d}",
|
||||
"description": func_req.description,
|
||||
"name": req.function_name,
|
||||
"requirement_id": requirement_id,
|
||||
"description": req.description,
|
||||
"type": "function",
|
||||
"module": req.module or config.DEFAULT_MODULE,
|
||||
"parameters": {},
|
||||
"return": {
|
||||
"type": "void",
|
||||
"description": "TODO: define return value"
|
||||
"type": "any",
|
||||
"on_success": {"value": "...", "description": "成功时返回值"},
|
||||
"on_failure": {"value": "None", "description": "失败时返回 None"},
|
||||
},
|
||||
"_note": "Auto-generated fallback due to LLM error"
|
||||
}
|
||||
|
|
@ -1,314 +1,152 @@
|
|||
# database/db_manager.py - 数据库操作管理器
|
||||
import sqlite3
|
||||
# database/db_manager.py - 数据库 CRUD 操作封装
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from contextlib import contextmanager
|
||||
|
||||
from database.models import (
|
||||
CREATE_TABLES_SQL, Project, RawRequirement,
|
||||
FunctionalRequirement, CodeFile
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
import config
|
||||
from database.models import Base, Project, RawRequirement, FunctionalRequirement, CodeFile
|
||||
|
||||
|
||||
class DBManager:
|
||||
"""SQLite 数据库管理器,封装所有 CRUD 操作"""
|
||||
|
||||
def __init__(self, db_path: str = config.DB_PATH):
|
||||
self.db_path = db_path
|
||||
def __init__(self, db_path: str = None):
|
||||
db_path = db_path or config.DB_PATH
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
self._init_db()
|
||||
self.engine = create_engine(f"sqlite:///{db_path}", echo=False)
|
||||
Base.metadata.create_all(self.engine)
|
||||
self._Session = sessionmaker(bind=self.engine)
|
||||
|
||||
# ── 连接上下文管理器 ──────────────────────────────────
|
||||
|
||||
@contextmanager
|
||||
def _get_conn(self):
|
||||
"""获取数据库连接(自动提交/回滚)"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _init_db(self):
|
||||
"""初始化数据库,创建所有表"""
|
||||
with self._get_conn() as conn:
|
||||
conn.executescript(CREATE_TABLES_SQL)
|
||||
def _session(self) -> Session:
|
||||
return self._Session()
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# Project CRUD
|
||||
# Project
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def create_project(self, project: Project) -> int:
|
||||
"""创建项目,返回新项目 ID"""
|
||||
sql = """
|
||||
INSERT INTO projects (name, description, language, output_dir, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
cur = conn.execute(sql, (
|
||||
project.name, project.description, project.language,
|
||||
project.output_dir, project.created_at, project.updated_at
|
||||
))
|
||||
return cur.lastrowid
|
||||
|
||||
def get_project_by_id(self, project_id: int) -> Optional[Project]:
|
||||
"""根据 ID 查询项目"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM projects WHERE id = ?", (project_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return Project(
|
||||
id=row["id"], name=row["name"], description=row["description"],
|
||||
language=row["language"], output_dir=row["output_dir"],
|
||||
created_at=row["created_at"], updated_at=row["updated_at"]
|
||||
)
|
||||
with self._session() as s:
|
||||
s.add(project)
|
||||
s.commit()
|
||||
s.refresh(project)
|
||||
return project.id
|
||||
|
||||
def get_project_by_name(self, name: str) -> Optional[Project]:
|
||||
"""根据名称查询项目"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM projects WHERE name = ?", (name,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return Project(
|
||||
id=row["id"], name=row["name"], description=row["description"],
|
||||
language=row["language"], output_dir=row["output_dir"],
|
||||
created_at=row["created_at"], updated_at=row["updated_at"]
|
||||
)
|
||||
with self._session() as s:
|
||||
return s.query(Project).filter_by(name=name).first()
|
||||
|
||||
def list_projects(self) -> List[Project]:
|
||||
"""列出所有项目"""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM projects ORDER BY created_at DESC"
|
||||
).fetchall()
|
||||
return [
|
||||
Project(
|
||||
id=r["id"], name=r["name"], description=r["description"],
|
||||
language=r["language"], output_dir=r["output_dir"],
|
||||
created_at=r["created_at"], updated_at=r["updated_at"]
|
||||
) for r in rows
|
||||
]
|
||||
def get_project_by_id(self, project_id: int) -> Optional[Project]:
|
||||
with self._session() as s:
|
||||
return s.get(Project, project_id)
|
||||
|
||||
def update_project(self, project: Project) -> None:
|
||||
"""更新项目信息"""
|
||||
project.updated_at = datetime.now().isoformat()
|
||||
sql = """
|
||||
UPDATE projects
|
||||
SET name=?, description=?, language=?, output_dir=?, updated_at=?
|
||||
WHERE id=?
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute(sql, (
|
||||
project.name, project.description, project.language,
|
||||
project.output_dir, project.updated_at, project.id
|
||||
))
|
||||
with self._session() as s:
|
||||
s.merge(project)
|
||||
s.commit()
|
||||
|
||||
def delete_project(self, project_id: int) -> None:
|
||||
"""删除项目(级联删除所有关联数据)"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute("DELETE FROM projects WHERE id = ?", (project_id,))
|
||||
def list_projects(self) -> List[Project]:
|
||||
with self._session() as s:
|
||||
return s.query(Project).order_by(Project.created_at.desc()).all()
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# RawRequirement CRUD
|
||||
# RawRequirement
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def create_raw_requirement(self, req: RawRequirement) -> int:
|
||||
"""创建原始需求,返回新记录 ID"""
|
||||
sql = """
|
||||
INSERT INTO raw_requirements
|
||||
(project_id, content, source_type, source_name, knowledge, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
cur = conn.execute(sql, (
|
||||
req.project_id, req.content, req.source_type,
|
||||
req.source_name, req.knowledge, req.created_at
|
||||
))
|
||||
return cur.lastrowid
|
||||
def create_raw_requirement(self, raw_req: RawRequirement) -> int:
|
||||
with self._session() as s:
|
||||
s.add(raw_req)
|
||||
s.commit()
|
||||
s.refresh(raw_req)
|
||||
return raw_req.id
|
||||
|
||||
def get_raw_requirement(self, req_id: int) -> Optional[RawRequirement]:
|
||||
"""根据 ID 查询原始需求"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM raw_requirements WHERE id = ?", (req_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return RawRequirement(
|
||||
id=row["id"], project_id=row["project_id"], content=row["content"],
|
||||
source_type=row["source_type"], source_name=row["source_name"],
|
||||
knowledge=row["knowledge"], created_at=row["created_at"]
|
||||
)
|
||||
|
||||
def list_raw_requirements_by_project(self, project_id: int) -> List[RawRequirement]:
|
||||
"""查询项目下所有原始需求"""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM raw_requirements WHERE project_id = ? ORDER BY created_at",
|
||||
(project_id,)
|
||||
).fetchall()
|
||||
return [
|
||||
RawRequirement(
|
||||
id=r["id"], project_id=r["project_id"], content=r["content"],
|
||||
source_type=r["source_type"], source_name=r["source_name"],
|
||||
knowledge=r["knowledge"], created_at=r["created_at"]
|
||||
) for r in rows
|
||||
]
|
||||
def get_raw_requirement(self, raw_req_id: int) -> Optional[RawRequirement]:
|
||||
with self._session() as s:
|
||||
return s.get(RawRequirement, raw_req_id)
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# FunctionalRequirement CRUD
|
||||
# FunctionalRequirement
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def create_functional_requirement(self, req: FunctionalRequirement) -> int:
|
||||
"""创建功能需求,返回新记录 ID"""
|
||||
sql = """
|
||||
INSERT INTO functional_requirements
|
||||
(project_id, raw_req_id, index_no, title, description,
|
||||
function_name, priority, status, is_custom, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
cur = conn.execute(sql, (
|
||||
req.project_id, req.raw_req_id, req.index_no,
|
||||
req.title, req.description, req.function_name,
|
||||
req.priority, req.status, int(req.is_custom),
|
||||
req.created_at, req.updated_at
|
||||
))
|
||||
return cur.lastrowid
|
||||
with self._session() as s:
|
||||
s.add(req)
|
||||
s.commit()
|
||||
s.refresh(req)
|
||||
return req.id
|
||||
|
||||
def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]:
|
||||
"""根据 ID 查询功能需求"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM functional_requirements WHERE id = ?", (req_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_func_req(row)
|
||||
with self._session() as s:
|
||||
return s.get(FunctionalRequirement, req_id)
|
||||
|
||||
def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]:
|
||||
"""查询项目下所有功能需求(按序号排序)"""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"""SELECT * FROM functional_requirements
|
||||
WHERE project_id = ? ORDER BY index_no""",
|
||||
(project_id,)
|
||||
).fetchall()
|
||||
return [self._row_to_func_req(r) for r in rows]
|
||||
with self._session() as s:
|
||||
return (
|
||||
s.query(FunctionalRequirement)
|
||||
.filter_by(project_id=project_id)
|
||||
.order_by(FunctionalRequirement.index_no)
|
||||
.all()
|
||||
)
|
||||
|
||||
def update_functional_requirement(self, req: FunctionalRequirement) -> None:
|
||||
"""更新功能需求"""
|
||||
req.updated_at = datetime.now().isoformat()
|
||||
sql = """
|
||||
UPDATE functional_requirements
|
||||
SET title=?, description=?, function_name=?, priority=?,
|
||||
status=?, index_no=?, updated_at=?
|
||||
WHERE id=?
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute(sql, (
|
||||
req.title, req.description, req.function_name,
|
||||
req.priority, req.status, req.index_no,
|
||||
req.updated_at, req.id
|
||||
))
|
||||
with self._session() as s:
|
||||
s.merge(req)
|
||||
s.commit()
|
||||
|
||||
def delete_functional_requirement(self, req_id: int) -> None:
|
||||
"""删除功能需求"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute(
|
||||
"DELETE FROM functional_requirements WHERE id = ?", (req_id,)
|
||||
)
|
||||
with self._session() as s:
|
||||
obj = s.get(FunctionalRequirement, req_id)
|
||||
if obj:
|
||||
s.delete(obj)
|
||||
s.commit()
|
||||
|
||||
def _row_to_func_req(self, row) -> FunctionalRequirement:
|
||||
"""sqlite Row → FunctionalRequirement 对象"""
|
||||
return FunctionalRequirement(
|
||||
id=row["id"], project_id=row["project_id"],
|
||||
raw_req_id=row["raw_req_id"], index_no=row["index_no"],
|
||||
title=row["title"], description=row["description"],
|
||||
function_name=row["function_name"], priority=row["priority"],
|
||||
status=row["status"], is_custom=bool(row["is_custom"]),
|
||||
created_at=row["created_at"], updated_at=row["updated_at"]
|
||||
)
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# CodeFile CRUD
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def create_code_file(self, code_file: CodeFile) -> int:
|
||||
"""创建代码文件记录,返回新记录 ID"""
|
||||
sql = """
|
||||
INSERT INTO code_files
|
||||
(project_id, func_req_id, file_name, file_path,
|
||||
language, content, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
def bulk_update_modules(self, updates: List[dict]) -> None:
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
cur = conn.execute(sql, (
|
||||
code_file.project_id, code_file.func_req_id,
|
||||
code_file.file_name, code_file.file_path,
|
||||
code_file.language, code_file.content,
|
||||
code_file.created_at, code_file.updated_at
|
||||
))
|
||||
return cur.lastrowid
|
||||
批量更新功能需求的 module 字段。
|
||||
|
||||
Args:
|
||||
updates: [{"function_name": "...", "module": "..."}, ...]
|
||||
"""
|
||||
with self._session() as s:
|
||||
name_to_module = {u["function_name"]: u["module"] for u in updates}
|
||||
reqs = s.query(FunctionalRequirement).filter(
|
||||
FunctionalRequirement.function_name.in_(name_to_module.keys())
|
||||
).all()
|
||||
for req in reqs:
|
||||
req.module = name_to_module.get(req.function_name, config.DEFAULT_MODULE)
|
||||
s.commit()
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# CodeFile
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def upsert_code_file(self, code_file: CodeFile) -> int:
|
||||
"""插入或更新代码文件(按 func_req_id 唯一键)"""
|
||||
existing = self.get_code_file_by_func_req(code_file.func_req_id)
|
||||
if existing:
|
||||
code_file.id = existing.id
|
||||
code_file.updated_at = datetime.now().isoformat()
|
||||
sql = """
|
||||
UPDATE code_files
|
||||
SET file_name=?, file_path=?, language=?, content=?, updated_at=?
|
||||
WHERE id=?
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute(sql, (
|
||||
code_file.file_name, code_file.file_path,
|
||||
code_file.language, code_file.content,
|
||||
code_file.updated_at, code_file.id
|
||||
))
|
||||
return code_file.id
|
||||
else:
|
||||
return self.create_code_file(code_file)
|
||||
|
||||
def get_code_file_by_func_req(self, func_req_id: int) -> Optional[CodeFile]:
|
||||
"""根据功能需求 ID 查询代码文件"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM code_files WHERE func_req_id = ?", (func_req_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_code_file(row)
|
||||
|
||||
def list_code_files_by_project(self, project_id: int) -> List[CodeFile]:
|
||||
"""查询项目下所有代码文件"""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM code_files WHERE project_id = ? ORDER BY created_at",
|
||||
(project_id,)
|
||||
).fetchall()
|
||||
return [self._row_to_code_file(r) for r in rows]
|
||||
|
||||
def _row_to_code_file(self, row) -> CodeFile:
|
||||
"""sqlite Row → CodeFile 对象"""
|
||||
return CodeFile(
|
||||
id=row["id"], project_id=row["project_id"],
|
||||
func_req_id=row["func_req_id"], file_name=row["file_name"],
|
||||
file_path=row["file_path"], language=row["language"],
|
||||
content=row["content"], created_at=row["created_at"],
|
||||
updated_at=row["updated_at"]
|
||||
with self._session() as s:
|
||||
existing = (
|
||||
s.query(CodeFile)
|
||||
.filter_by(func_req_id=code_file.func_req_id)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
existing.file_name = code_file.file_name
|
||||
existing.file_path = code_file.file_path
|
||||
existing.module = code_file.module
|
||||
existing.language = code_file.language
|
||||
existing.content = code_file.content
|
||||
s.commit()
|
||||
return existing.id
|
||||
else:
|
||||
s.add(code_file)
|
||||
s.commit()
|
||||
s.refresh(code_file)
|
||||
return code_file.id
|
||||
|
||||
def list_code_files(self, project_id: int) -> List[CodeFile]:
|
||||
with self._session() as s:
|
||||
return (
|
||||
s.query(CodeFile)
|
||||
.join(FunctionalRequirement)
|
||||
.filter(FunctionalRequirement.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
|
|
@ -1,122 +1,95 @@
|
|||
# database/models.py - 数据模型定义(SQLite 建表 DDL)
|
||||
from dataclasses import dataclass, field
|
||||
# database/models.py - SQLAlchemy ORM 数据模型
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy import (
|
||||
Column, Integer, String, Text, Boolean,
|
||||
DateTime, ForeignKey, Enum,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# DDL 建表语句
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
CREATE_TABLES_SQL = """
|
||||
PRAGMA foreign_keys = ON;
|
||||
|
||||
-- 项目表
|
||||
CREATE TABLE IF NOT EXISTS projects (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL UNIQUE, -- 项目名称
|
||||
description TEXT, -- 项目描述
|
||||
language TEXT NOT NULL DEFAULT 'python', -- 目标代码语言
|
||||
output_dir TEXT NOT NULL, -- 输出目录路径
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- 原始需求表
|
||||
CREATE TABLE IF NOT EXISTS raw_requirements (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
project_id INTEGER NOT NULL, -- 关联项目
|
||||
content TEXT NOT NULL, -- 需求原文
|
||||
source_type TEXT NOT NULL DEFAULT 'text', -- text | file
|
||||
source_name TEXT, -- 文件名(文件输入时)
|
||||
knowledge TEXT, -- 合并后的知识库内容
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- 功能需求表
|
||||
CREATE TABLE IF NOT EXISTS functional_requirements (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
project_id INTEGER NOT NULL,
|
||||
raw_req_id INTEGER NOT NULL, -- 关联原始需求
|
||||
index_no INTEGER NOT NULL, -- 序号
|
||||
title TEXT NOT NULL, -- 功能标题
|
||||
description TEXT NOT NULL, -- 功能描述
|
||||
function_name TEXT NOT NULL, -- 对应函数名
|
||||
priority TEXT NOT NULL DEFAULT 'medium', -- high|medium|low
|
||||
status TEXT NOT NULL DEFAULT 'pending', -- pending|generated|skipped
|
||||
is_custom INTEGER NOT NULL DEFAULT 0, -- 是否用户自定义 (0/1)
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (raw_req_id) REFERENCES raw_requirements(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- 代码文件表
|
||||
CREATE TABLE IF NOT EXISTS code_files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
project_id INTEGER NOT NULL,
|
||||
func_req_id INTEGER NOT NULL UNIQUE, -- 关联功能需求(1对1)
|
||||
file_name TEXT NOT NULL, -- 文件名
|
||||
file_path TEXT NOT NULL, -- 完整路径
|
||||
language TEXT NOT NULL, -- 代码语言
|
||||
content TEXT NOT NULL, -- 代码内容
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (func_req_id) REFERENCES functional_requirements(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# 数据类(Python 对象映射)
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
@dataclass
|
||||
class Project:
|
||||
name: str
|
||||
output_dir: str
|
||||
language: str = "python"
|
||||
description: str = ""
|
||||
id: Optional[int] = None
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RawRequirement:
|
||||
project_id: int
|
||||
content: str
|
||||
source_type: str = "text" # text | file
|
||||
source_name: Optional[str] = None
|
||||
knowledge: Optional[str] = None
|
||||
id: Optional[int] = None
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
class Project(Base):
|
||||
"""项目表"""
|
||||
__tablename__ = "projects"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String(200), nullable=False, unique=True)
|
||||
language = Column(String(50), nullable=False, default="python")
|
||||
description = Column(Text, nullable=True)
|
||||
output_dir = Column(String(500), nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
raw_requirements = relationship("RawRequirement", back_populates="project", cascade="all, delete-orphan")
|
||||
functional_requirements = relationship("FunctionalRequirement", back_populates="project", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Project(id={self.id}, name={self.name!r}, language={self.language!r})>"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionalRequirement:
|
||||
project_id: int
|
||||
raw_req_id: int
|
||||
index_no: int
|
||||
title: str
|
||||
description: str
|
||||
function_name: str
|
||||
priority: str = "medium"
|
||||
status: str = "pending"
|
||||
is_custom: bool = False
|
||||
id: Optional[int] = None
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
class RawRequirement(Base):
|
||||
"""原始需求表"""
|
||||
__tablename__ = "raw_requirements"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False)
|
||||
content = Column(Text, nullable=False)
|
||||
source_type = Column(String(20), nullable=False, default="text") # text / file
|
||||
source_name = Column(String(200), nullable=True)
|
||||
knowledge = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
project = relationship("Project", back_populates="raw_requirements")
|
||||
functional_requirements = relationship("FunctionalRequirement", back_populates="raw_requirement", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RawRequirement(id={self.id}, project_id={self.project_id})>"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeFile:
|
||||
project_id: int
|
||||
func_req_id: int
|
||||
file_name: str
|
||||
file_path: str
|
||||
language: str
|
||||
content: str
|
||||
id: Optional[int] = None
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
class FunctionalRequirement(Base):
|
||||
"""功能需求表"""
|
||||
__tablename__ = "functional_requirements"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False)
|
||||
raw_req_id = Column(Integer, ForeignKey("raw_requirements.id"), nullable=False)
|
||||
index_no = Column(Integer, nullable=False)
|
||||
title = Column(String(200), nullable=False)
|
||||
description = Column(Text, nullable=False)
|
||||
function_name = Column(String(200), nullable=False)
|
||||
priority = Column(Enum("high", "medium", "low"), default="medium")
|
||||
module = Column(String(100), nullable=True, default="default") # 功能模块
|
||||
status = Column(String(50), nullable=False, default="pending") # pending / generated / failed
|
||||
is_custom = Column(Boolean, nullable=False, default=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
project = relationship("Project", back_populates="functional_requirements")
|
||||
raw_requirement = relationship("RawRequirement", back_populates="functional_requirements")
|
||||
code_files = relationship("CodeFile", back_populates="functional_requirement", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<FunctionalRequirement(id={self.id}, title={self.title!r}, "
|
||||
f"module={self.module!r}, status={self.status!r})>"
|
||||
)
|
||||
|
||||
|
||||
class CodeFile(Base):
|
||||
"""生成代码文件表"""
|
||||
__tablename__ = "code_files"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
func_req_id = Column(Integer, ForeignKey("functional_requirements.id"), nullable=False)
|
||||
file_name = Column(String(200), nullable=False)
|
||||
file_path = Column(String(500), nullable=False)
|
||||
module = Column(String(100), nullable=True) # 冗余存储,方便查询
|
||||
language = Column(String(50), nullable=False)
|
||||
content = Column(Text, nullable=True)
|
||||
generated_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
functional_requirement = relationship("FunctionalRequirement", back_populates="code_files")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<CodeFile(id={self.id}, file_name={self.file_name!r}, module={self.module!r})>"
|
||||
|
|
@ -1,17 +1,8 @@
|
|||
#!/usr/bin/env python3
|
||||
# encoding: utf-8
|
||||
# main.py - 主入口:支持交互式 & 非交互式(CLI 参数)两种运行模式
|
||||
#
|
||||
# 交互式: python main.py
|
||||
# 非交互式:python main.py --non-interactive \
|
||||
# --project-name "MyProject" \
|
||||
# --language python \
|
||||
# --requirement-text "用户管理系统,包含注册、登录、修改密码功能"
|
||||
#
|
||||
# 完整参数见:python main.py --help
|
||||
# main.py - 主入口:支持交互式 & 非交互式两种运行模式
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
|
||||
import click
|
||||
from rich.console import Console
|
||||
|
|
@ -27,9 +18,9 @@ from core.requirement_analyzer import RequirementAnalyzer
|
|||
from core.code_generator import CodeGenerator
|
||||
from utils.file_handler import read_file_auto, merge_knowledge_files
|
||||
from utils.output_writer import (
|
||||
ensure_project_dir, build_project_output_dir, write_project_readme,
|
||||
write_function_signatures_json, validate_all_signatures,
|
||||
patch_signatures_with_url,
|
||||
ensure_project_dir, build_project_output_dir,
|
||||
write_project_readme, write_function_signatures_json,
|
||||
validate_all_signatures, patch_signatures_with_url,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
|
@ -44,20 +35,20 @@ def print_banner():
|
|||
console.print(Panel.fit(
|
||||
"[bold cyan]🚀 需求分析 & 代码生成工具[/bold cyan]\n"
|
||||
"[dim]Powered by LLM · SQLite · Python[/dim]",
|
||||
border_style="cyan"
|
||||
border_style="cyan",
|
||||
))
|
||||
|
||||
|
||||
def print_functional_requirements(reqs: list):
|
||||
"""以表格形式展示功能需求列表"""
|
||||
def print_functional_requirements(reqs: List[FunctionalRequirement]):
|
||||
"""以表格形式展示功能需求列表(含模块列)"""
|
||||
table = Table(title="📋 功能需求列表", show_lines=True)
|
||||
table.add_column("序号", style="cyan", width=6)
|
||||
table.add_column("ID", style="dim", width=6)
|
||||
table.add_column("标题", style="bold", width=20)
|
||||
table.add_column("模块", style="magenta", width=15)
|
||||
table.add_column("标题", style="bold", width=18)
|
||||
table.add_column("函数名", width=25)
|
||||
table.add_column("优先级", width=8)
|
||||
table.add_column("类型", width=8)
|
||||
table.add_column("描述", width=40)
|
||||
table.add_column("描述", width=35)
|
||||
|
||||
priority_color = {"high": "red", "medium": "yellow", "low": "green"}
|
||||
for req in reqs:
|
||||
|
|
@ -65,53 +56,51 @@ def print_functional_requirements(reqs: list):
|
|||
table.add_row(
|
||||
str(req.index_no),
|
||||
str(req.id) if req.id else "-",
|
||||
req.module or config.DEFAULT_MODULE,
|
||||
req.title,
|
||||
f"[code]{req.function_name}[/code]",
|
||||
f"[{color}]{req.priority}[/{color}]",
|
||||
"[magenta]自定义[/magenta]" if req.is_custom else "LLM生成",
|
||||
req.description[:60] + "..." if len(req.description) > 60 else req.description,
|
||||
req.description[:50] + "..." if len(req.description) > 50 else req.description,
|
||||
)
|
||||
console.print(table)
|
||||
|
||||
|
||||
def print_signatures_preview(signatures: list):
|
||||
"""
|
||||
以表格形式预览函数签名列表(含 url 字段)
|
||||
def print_module_summary(reqs: List[FunctionalRequirement]):
|
||||
"""打印模块分组摘要"""
|
||||
module_map: Dict[str, List[str]] = {}
|
||||
for req in reqs:
|
||||
m = req.module or config.DEFAULT_MODULE
|
||||
module_map.setdefault(m, []).append(req.function_name)
|
||||
|
||||
Args:
|
||||
signatures: 纯签名列表(顶层文档的 "functions" 字段)
|
||||
"""
|
||||
table = Table(title="📦 功能模块分组", show_lines=True)
|
||||
table.add_column("模块", style="magenta bold", width=20)
|
||||
table.add_column("函数数量", style="cyan", width=8)
|
||||
table.add_column("函数列表", width=50)
|
||||
for module, funcs in sorted(module_map.items()):
|
||||
table.add_row(module, str(len(funcs)), ", ".join(funcs))
|
||||
console.print(table)
|
||||
|
||||
|
||||
def print_signatures_preview(signatures: List[dict]):
|
||||
"""以表格形式预览函数签名列表(含 module / url 列)"""
|
||||
table = Table(title="📄 函数签名预览", show_lines=True)
|
||||
table.add_column("需求编号", style="cyan", width=10)
|
||||
table.add_column("函数名", style="bold", width=25)
|
||||
table.add_column("参数数量", width=8)
|
||||
table.add_column("需求编号", style="cyan", width=8)
|
||||
table.add_column("模块", style="magenta", width=15)
|
||||
table.add_column("函数名", style="bold", width=22)
|
||||
table.add_column("参数数", width=6)
|
||||
table.add_column("返回类型", width=10)
|
||||
table.add_column("成功返回值", width=18)
|
||||
table.add_column("失败返回值", width=18)
|
||||
table.add_column("URL", style="dim", width=30)
|
||||
|
||||
def _fmt_value(v) -> str:
|
||||
if v is None:
|
||||
return "-"
|
||||
if isinstance(v, dict):
|
||||
return "{" + ", ".join(v.keys()) + "}"
|
||||
return str(v)[:16]
|
||||
table.add_column("URL", style="dim", width=28)
|
||||
|
||||
for sig in signatures:
|
||||
ret = sig.get("return") or {}
|
||||
on_success = ret.get("on_success") or {}
|
||||
on_failure = ret.get("on_failure") or {}
|
||||
url = sig.get("url", "")
|
||||
# 只显示文件名部分,避免路径过长
|
||||
url_display = os.path.basename(url) if url else "[dim]待生成[/dim]"
|
||||
|
||||
table.add_row(
|
||||
sig.get("requirement_id", "-"),
|
||||
sig.get("module", "-"),
|
||||
sig.get("name", "-"),
|
||||
str(len(sig.get("parameters", {}))),
|
||||
ret.get("type", "void"),
|
||||
_fmt_value(on_success.get("value")),
|
||||
_fmt_value(on_failure.get("value")),
|
||||
url_display,
|
||||
)
|
||||
console.print(table)
|
||||
|
|
@ -127,12 +116,15 @@ def step_init_project(
|
|||
description: str = "",
|
||||
non_interactive: bool = False,
|
||||
) -> Project:
|
||||
console.print(
|
||||
"\n[bold]Step 1 · 项目配置[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
if not non_interactive:
|
||||
console.print("\n[bold]Step 1 · 项目配置[/bold]", style="blue")
|
||||
project_name = project_name or Prompt.ask("📁 请输入项目名称")
|
||||
project_name = project_name or Prompt.ask("📁 项目名称")
|
||||
language = language or Prompt.ask(
|
||||
"💻 目标代码语言",
|
||||
default=config.DEFAULT_LANGUAGE,
|
||||
"💻 目标语言", default=config.DEFAULT_LANGUAGE,
|
||||
choices=["python","javascript","typescript","java","go","rust"],
|
||||
)
|
||||
description = description or Prompt.ask("📝 项目描述(可选)", default="")
|
||||
|
|
@ -140,25 +132,22 @@ def step_init_project(
|
|||
if not project_name:
|
||||
raise ValueError("非交互模式下 --project-name 为必填项")
|
||||
language = language or config.DEFAULT_LANGUAGE
|
||||
console.print("\n[bold]Step 1 · 项目配置[/bold] [dim](非交互)[/dim]", style="blue")
|
||||
console.print(f" 项目名称: {project_name} 语言: {language}")
|
||||
console.print(f" 项目: {project_name} 语言: {language}")
|
||||
|
||||
existing = db.get_project_by_name(project_name)
|
||||
if existing:
|
||||
if non_interactive:
|
||||
console.print(f"[green]✓ 已加载已有项目: {project_name} (ID={existing.id})[/green]")
|
||||
console.print(f"[green]✓ 已加载项目: {project_name} (ID={existing.id})[/green]")
|
||||
return existing
|
||||
use_existing = Confirm.ask(f"⚠️ 项目 '{project_name}' 已存在,是否继续使用?")
|
||||
if use_existing:
|
||||
if Confirm.ask(f"⚠️ 项目 '{project_name}' 已存在,继续使用?"):
|
||||
console.print(f"[green]✓ 已加载项目: {project_name} (ID={existing.id})[/green]")
|
||||
return existing
|
||||
project_name = Prompt.ask("请输入新的项目名称")
|
||||
|
||||
output_dir = build_project_output_dir(project_name)
|
||||
project = Project(
|
||||
name = project_name,
|
||||
language = language,
|
||||
output_dir=output_dir,
|
||||
output_dir = build_project_output_dir(project_name),
|
||||
description = description,
|
||||
)
|
||||
project.id = db.create_project(project)
|
||||
|
|
@ -178,11 +167,10 @@ def step_input_requirement(
|
|||
non_interactive: bool = False,
|
||||
) -> tuple:
|
||||
console.print(
|
||||
f"\n[bold]Step 2 · 输入原始需求[/bold]"
|
||||
"\n[bold]Step 2 · 输入原始需求[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
||||
raw_text = ""
|
||||
source_name = None
|
||||
source_type = "text"
|
||||
|
|
@ -195,7 +183,6 @@ def step_input_requirement(
|
|||
console.print(f" 需求文件: {source_name} ({len(raw_text)} 字符)")
|
||||
elif requirement_text:
|
||||
raw_text = requirement_text
|
||||
source_type = "text"
|
||||
console.print(f" 需求文本: {raw_text[:80]}{'...' if len(raw_text)>80 else ''}")
|
||||
else:
|
||||
raise ValueError("非交互模式下必须提供 --requirement-text 或 --requirement-file")
|
||||
|
|
@ -210,13 +197,12 @@ def step_input_requirement(
|
|||
break
|
||||
lines.append(line)
|
||||
raw_text = "\n".join(lines)
|
||||
source_type = "text"
|
||||
else:
|
||||
file_path = Prompt.ask("📂 需求文件路径")
|
||||
raw_text = read_file_auto(file_path)
|
||||
source_name = os.path.basename(file_path)
|
||||
fp = Prompt.ask("📂 需求文件路径")
|
||||
raw_text = read_file_auto(fp)
|
||||
source_name = os.path.basename(fp)
|
||||
source_type = "file"
|
||||
console.print(f"[green]✓ 已读取文件: {source_name} ({len(raw_text)} 字符)[/green]")
|
||||
console.print(f"[green]✓ 已读取: {source_name} ({len(raw_text)} 字符)[/green]")
|
||||
|
||||
knowledge_text = ""
|
||||
if non_interactive:
|
||||
|
|
@ -224,18 +210,17 @@ def step_input_requirement(
|
|||
knowledge_text = merge_knowledge_files(list(knowledge_files))
|
||||
console.print(f" 知识库: {len(knowledge_files)} 个文件,{len(knowledge_text)} 字符")
|
||||
else:
|
||||
use_kb = Confirm.ask("📚 是否输入知识库文件?", default=False)
|
||||
if use_kb:
|
||||
if Confirm.ask("📚 是否输入知识库文件?", default=False):
|
||||
kb_paths = []
|
||||
while True:
|
||||
kb_path = Prompt.ask("知识库文件路径(留空结束)", default="")
|
||||
if not kb_path:
|
||||
p = Prompt.ask("知识库文件路径(留空结束)", default="")
|
||||
if not p:
|
||||
break
|
||||
if os.path.exists(kb_path):
|
||||
kb_paths.append(kb_path)
|
||||
console.print(f" [green]+ {kb_path}[/green]")
|
||||
if os.path.exists(p):
|
||||
kb_paths.append(p)
|
||||
console.print(f" [green]+ {p}[/green]")
|
||||
else:
|
||||
console.print(f" [red]文件不存在: {kb_path}[/red]")
|
||||
console.print(f" [red]文件不存在: {p}[/red]")
|
||||
if kb_paths:
|
||||
knowledge_text = merge_knowledge_files(kb_paths)
|
||||
console.print(f"[green]✓ 知识库已合并 ({len(knowledge_text)} 字符)[/green]")
|
||||
|
|
@ -256,11 +241,10 @@ def step_decompose_requirements(
|
|||
non_interactive: bool = False,
|
||||
) -> tuple:
|
||||
console.print(
|
||||
f"\n[bold]Step 3 · LLM 需求分解[/bold]"
|
||||
"\n[bold]Step 3 · LLM 需求分解[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
||||
raw_req = RawRequirement(
|
||||
project_id = project.id,
|
||||
content = raw_text,
|
||||
|
|
@ -271,7 +255,7 @@ def step_decompose_requirements(
|
|||
raw_req_id = db.create_raw_requirement(raw_req)
|
||||
console.print(f"[dim]原始需求已存储 (ID={raw_req_id})[/dim]")
|
||||
|
||||
with console.status("[bold yellow]🤖 LLM 正在分解需求,请稍候...[/bold yellow]"):
|
||||
with console.status("[bold yellow]🤖 LLM 正在分解需求...[/bold yellow]"):
|
||||
llm = LLMClient()
|
||||
analyzer = RequirementAnalyzer(llm)
|
||||
func_reqs = analyzer.decompose(
|
||||
|
|
@ -289,18 +273,107 @@ def step_decompose_requirements(
|
|||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Step 4:用户编辑功能需求
|
||||
# Step 4:模块分类(可选重新分类)
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def step_classify_modules(
|
||||
project: Project,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
knowledge_text: str = "",
|
||||
non_interactive: bool = False,
|
||||
) -> List[FunctionalRequirement]:
|
||||
"""
|
||||
Step 4:对功能需求进行模块分类。
|
||||
|
||||
- 非交互模式:直接使用 LLM 分类结果
|
||||
- 交互模式:展示 LLM 分类结果,允许用户手动调整
|
||||
"""
|
||||
console.print(
|
||||
"\n[bold]Step 4 · 功能模块分类[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
||||
# LLM 自动分类
|
||||
with console.status("[bold yellow]🤖 LLM 正在进行模块分类...[/bold yellow]"):
|
||||
llm = LLMClient()
|
||||
analyzer = RequirementAnalyzer(llm)
|
||||
try:
|
||||
updates = analyzer.classify_modules(func_reqs, knowledge_text)
|
||||
# 回写 module 到 func_reqs 对象
|
||||
name_to_module = {u["function_name"]: u["module"] for u in updates}
|
||||
for req in func_reqs:
|
||||
req.module = name_to_module.get(req.function_name, config.DEFAULT_MODULE)
|
||||
db.update_functional_requirement(req)
|
||||
console.print(f"[green]✓ LLM 模块分类完成[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[yellow]⚠ 模块分类失败,保留原有模块: {e}[/yellow]")
|
||||
|
||||
print_module_summary(func_reqs)
|
||||
|
||||
if non_interactive:
|
||||
return func_reqs
|
||||
|
||||
# 交互式调整
|
||||
while True:
|
||||
console.print(
|
||||
"\n模块操作: [cyan]r[/cyan]=重新分类 "
|
||||
"[cyan]e[/cyan]=手动编辑某需求的模块 [cyan]ok[/cyan]=确认继续"
|
||||
)
|
||||
action = Prompt.ask("请选择操作", default="ok").strip().lower()
|
||||
|
||||
if action == "ok":
|
||||
break
|
||||
|
||||
elif action == "r":
|
||||
# 重新触发 LLM 分类
|
||||
with console.status("[bold yellow]🤖 重新分类中...[/bold yellow]"):
|
||||
try:
|
||||
updates = analyzer.classify_modules(func_reqs, knowledge_text)
|
||||
name_to_module = {u["function_name"]: u["module"] for u in updates}
|
||||
for req in func_reqs:
|
||||
req.module = name_to_module.get(req.function_name, config.DEFAULT_MODULE)
|
||||
db.update_functional_requirement(req)
|
||||
console.print("[green]✓ 重新分类完成[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]重新分类失败: {e}[/red]")
|
||||
print_module_summary(func_reqs)
|
||||
|
||||
elif action == "e":
|
||||
print_functional_requirements(func_reqs)
|
||||
idx_str = Prompt.ask("输入要修改模块的需求序号")
|
||||
if not idx_str.isdigit():
|
||||
continue
|
||||
idx = int(idx_str)
|
||||
target = next((r for r in func_reqs if r.index_no == idx), None)
|
||||
if target is None:
|
||||
console.print("[red]序号不存在[/red]")
|
||||
continue
|
||||
new_module = Prompt.ask(
|
||||
f"新模块名(当前: {target.module})",
|
||||
default=target.module or config.DEFAULT_MODULE,
|
||||
)
|
||||
target.module = new_module.strip() or config.DEFAULT_MODULE
|
||||
db.update_functional_requirement(target)
|
||||
console.print(f"[green]✓ 已更新 '{target.function_name}' → 模块: {target.module}[/green]")
|
||||
print_module_summary(func_reqs)
|
||||
|
||||
return func_reqs
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Step 5:编辑功能需求
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def step_edit_requirements(
|
||||
project: Project,
|
||||
func_reqs: list,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
raw_req_id: int,
|
||||
non_interactive: bool = False,
|
||||
skip_indices: list = None,
|
||||
) -> list:
|
||||
) -> List[FunctionalRequirement]:
|
||||
console.print(
|
||||
f"\n[bold]Step 4 · 编辑功能需求[/bold]"
|
||||
"\n[bold]Step 5 · 编辑功能需求[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
|
@ -334,8 +407,9 @@ def step_edit_requirements(
|
|||
|
||||
if action == "ok":
|
||||
break
|
||||
|
||||
elif action == "d":
|
||||
idx_str = Prompt.ask("输入要删除的功能需求序号(多个用逗号分隔)")
|
||||
idx_str = Prompt.ask("输入要删除的序号(多个用逗号分隔)")
|
||||
to_delete = {int(x.strip()) for x in idx_str.split(",") if x.strip().isdigit()}
|
||||
removed, kept = [], []
|
||||
for req in func_reqs:
|
||||
|
|
@ -349,6 +423,7 @@ def step_edit_requirements(
|
|||
req.index_no = i
|
||||
db.update_functional_requirement(req)
|
||||
console.print(f"[red]✗ 已删除: {', '.join(removed)}[/red]")
|
||||
|
||||
elif action == "a":
|
||||
title = Prompt.ask("功能标题")
|
||||
description = Prompt.ask("功能描述")
|
||||
|
|
@ -356,6 +431,10 @@ def step_edit_requirements(
|
|||
priority = Prompt.ask(
|
||||
"优先级", choices=["high","medium","low"], default="medium"
|
||||
)
|
||||
module = Prompt.ask(
|
||||
"所属模块(snake_case,留空使用默认)",
|
||||
default=config.DEFAULT_MODULE,
|
||||
)
|
||||
new_req = FunctionalRequirement(
|
||||
project_id = project.id,
|
||||
raw_req_id = raw_req_id,
|
||||
|
|
@ -364,13 +443,15 @@ def step_edit_requirements(
|
|||
description = description,
|
||||
function_name = func_name,
|
||||
priority = priority,
|
||||
module = module.strip() or config.DEFAULT_MODULE,
|
||||
is_custom = True,
|
||||
)
|
||||
new_req.id = db.create_functional_requirement(new_req)
|
||||
func_reqs.append(new_req)
|
||||
console.print(f"[green]✓ 已添加自定义需求: {title}[/green]")
|
||||
console.print(f"[green]✓ 已添加: {title} → 模块: {new_req.module}[/green]")
|
||||
|
||||
elif action == "e":
|
||||
idx_str = Prompt.ask("输入要编辑的功能需求序号")
|
||||
idx_str = Prompt.ask("输入要编辑的序号")
|
||||
if not idx_str.isdigit():
|
||||
continue
|
||||
idx = int(idx_str)
|
||||
|
|
@ -384,6 +465,9 @@ def step_edit_requirements(
|
|||
target.priority = Prompt.ask(
|
||||
"新优先级", choices=["high","medium","low"], default=target.priority
|
||||
)
|
||||
target.module = Prompt.ask(
|
||||
"新模块", default=target.module or config.DEFAULT_MODULE
|
||||
).strip() or config.DEFAULT_MODULE
|
||||
db.update_functional_requirement(target)
|
||||
console.print(f"[green]✓ 已更新: {target.title}[/green]")
|
||||
|
||||
|
|
@ -391,33 +475,24 @@ def step_edit_requirements(
|
|||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Step 5A:生成函数签名 JSON(不含 url 字段,待 5C 回写)
|
||||
# Step 6A:生成函数签名 JSON(初版,不含 url)
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def step_generate_signatures(
|
||||
project: Project,
|
||||
func_reqs: list,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
output_dir: str,
|
||||
knowledge_text: str,
|
||||
json_file_name: str = "function_signatures.json",
|
||||
non_interactive: bool = False,
|
||||
) -> tuple:
|
||||
"""
|
||||
为所有功能需求生成函数签名,写入初版 JSON(不含 url 字段)。
|
||||
url 字段将在 Step 5C 代码生成完成后回写并刷新 JSON 文件。
|
||||
|
||||
Returns:
|
||||
(signatures: List[dict], json_path: str)
|
||||
"""
|
||||
console.print(
|
||||
f"\n[bold]Step 5A · 生成函数签名 JSON[/bold]"
|
||||
"\n[bold]Step 6A · 生成函数签名 JSON[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
||||
llm = LLMClient()
|
||||
analyzer = RequirementAnalyzer(llm)
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
|
|
@ -425,15 +500,14 @@ def step_generate_signatures(
|
|||
nonlocal success_count, fail_count
|
||||
if error:
|
||||
console.print(
|
||||
f" [{index}/{total}] [yellow]⚠ {req.title} 签名生成失败,"
|
||||
f"使用降级结构: {error}[/yellow]"
|
||||
f" [{index}/{total}] [yellow]⚠ {req.title} 签名生成失败"
|
||||
f"(降级): {error}[/yellow]"
|
||||
)
|
||||
fail_count += 1
|
||||
else:
|
||||
console.print(
|
||||
f" [{index}/{total}] [green]✓ {req.title}[/green] "
|
||||
f"→ [dim]{signature.get('name')}()[/dim] "
|
||||
f"params={len(signature.get('parameters', {}))}"
|
||||
f"[dim]{req.module}[/dim] → {signature.get('name')}()"
|
||||
)
|
||||
success_count += 1
|
||||
|
||||
|
|
@ -445,21 +519,20 @@ def step_generate_signatures(
|
|||
)
|
||||
|
||||
# 校验
|
||||
validation_report = validate_all_signatures(signatures)
|
||||
if validation_report:
|
||||
console.print(f"[yellow]⚠ 发现 {len(validation_report)} 个签名存在结构问题:[/yellow]")
|
||||
for fname, errors in validation_report.items():
|
||||
for err in errors:
|
||||
report = validate_all_signatures(signatures)
|
||||
if report:
|
||||
console.print(f"[yellow]⚠ {len(report)} 个签名存在结构问题:[/yellow]")
|
||||
for fname, errs in report.items():
|
||||
for err in errs:
|
||||
console.print(f" [yellow]· {fname}: {err}[/yellow]")
|
||||
else:
|
||||
console.print("[green]✓ 所有签名结构校验通过[/green]")
|
||||
|
||||
# 写入初版 JSON(url 字段尚未填入)
|
||||
json_path = write_function_signatures_json(
|
||||
output_dir = output_dir,
|
||||
signatures = signatures,
|
||||
project_name = project.name,
|
||||
project_description=project.description or "", # ← 传入项目描述
|
||||
project_description = project.description or "",
|
||||
file_name = json_file_name,
|
||||
)
|
||||
console.print(
|
||||
|
|
@ -470,35 +543,32 @@ def step_generate_signatures(
|
|||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Step 5B:生成代码文件,收集 {函数名: 文件路径} 映射
|
||||
# Step 6B:生成代码文件(按模块写入子目录)
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def step_generate_code(
|
||||
project: Project,
|
||||
func_reqs: list,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
output_dir: str,
|
||||
knowledge_text: str,
|
||||
signatures: list,
|
||||
signatures: List[dict],
|
||||
non_interactive: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
依据签名约束批量生成代码文件。
|
||||
批量生成代码文件,按 req.module 路由到 output_dir/<module>/ 子目录。
|
||||
|
||||
Returns:
|
||||
func_name_to_url: {函数名: 代码文件绝对路径} 映射表,
|
||||
供 Step 5C 回写 url 字段使用。
|
||||
生成失败的函数不会出现在映射表中。
|
||||
func_name_to_url: {函数名: 代码文件绝对路径}
|
||||
"""
|
||||
console.print(
|
||||
f"\n[bold]Step 5B · 生成代码文件[/bold]"
|
||||
"\n[bold]Step 6B · 生成代码文件[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
||||
generator = CodeGenerator(LLMClient())
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
func_name_to_url: Dict[str, str] = {} # ← 收集 函数名 → 文件绝对路径
|
||||
func_name_to_url: Dict[str, str] = {}
|
||||
|
||||
def on_progress(index, total, req, code_file, error):
|
||||
nonlocal success_count, fail_count
|
||||
|
|
@ -509,15 +579,16 @@ def step_generate_code(
|
|||
db.upsert_code_file(code_file)
|
||||
req.status = "generated"
|
||||
db.update_functional_requirement(req)
|
||||
# 收集 函数名 → 绝对文件路径(作为 url 回写)
|
||||
func_name_to_url[req.function_name] = os.path.abspath(code_file.file_path)
|
||||
console.print(
|
||||
f" [{index}/{total}] [green]✓ {req.title}[/green] "
|
||||
f"→ [dim]{code_file.file_name}[/dim]"
|
||||
f"[dim]{req.module}/{code_file.file_name}[/dim]"
|
||||
)
|
||||
success_count += 1
|
||||
|
||||
console.print(f"[yellow]开始生成 {len(func_reqs)} 个代码文件(签名约束模式)...[/yellow]")
|
||||
console.print(
|
||||
f"[yellow]开始生成 {len(func_reqs)} 个代码文件(按模块分目录)...[/yellow]"
|
||||
)
|
||||
generator.generate_batch(
|
||||
func_reqs = func_reqs,
|
||||
output_dir = output_dir,
|
||||
|
|
@ -527,11 +598,19 @@ def step_generate_code(
|
|||
on_progress = on_progress,
|
||||
)
|
||||
|
||||
# 生成 README(含模块列表)
|
||||
modules = list({req.module or config.DEFAULT_MODULE for req in func_reqs})
|
||||
req_summary = "\n".join(
|
||||
f"{i+1}. **{r.title}** (`{r.function_name}`) - {r.description[:80]}"
|
||||
f"{i+1}. **{r.title}** (`{r.module}/{r.function_name}`) - {r.description[:80]}"
|
||||
for i, r in enumerate(func_reqs)
|
||||
)
|
||||
write_project_readme(output_dir, project.name, req_summary)
|
||||
write_project_readme(
|
||||
output_dir = output_dir,
|
||||
project_name = project.name,
|
||||
project_description = project.description or "",
|
||||
requirements_summary = req_summary,
|
||||
modules = modules,
|
||||
)
|
||||
|
||||
console.print(Panel(
|
||||
f"[bold green]✅ 代码生成完成![/bold green]\n"
|
||||
|
|
@ -543,57 +622,31 @@ def step_generate_code(
|
|||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Step 5C:回写 url 字段并刷新 JSON
|
||||
# Step 6C:回写 url 字段并刷新 JSON
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def step_patch_signatures_url(
|
||||
project: Project,
|
||||
signatures: list,
|
||||
signatures: List[dict],
|
||||
func_name_to_url: Dict[str, str],
|
||||
output_dir: str,
|
||||
json_file_name: str,
|
||||
non_interactive: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
将代码文件路径回写到签名的 "url" 字段,并重新写入 JSON 文件。
|
||||
|
||||
执行流程:
|
||||
1. 调用 patch_signatures_with_url() 原地修改签名列表
|
||||
2. 打印最终签名预览(含 url 列)
|
||||
3. 重新调用 write_function_signatures_json() 覆盖写入 JSON
|
||||
|
||||
Args:
|
||||
project: 项目对象(提供 name 与 description)
|
||||
signatures: Step 5A 产出的签名列表(将被原地修改)
|
||||
func_name_to_url: Step 5B 收集的 {函数名: 文件绝对路径} 映射
|
||||
output_dir: JSON 文件所在目录
|
||||
json_file_name: JSON 文件名
|
||||
non_interactive: 是否非交互模式
|
||||
|
||||
Returns:
|
||||
刷新后的 JSON 文件绝对路径
|
||||
"""
|
||||
console.print(
|
||||
f"\n[bold]Step 5C · 回写代码文件路径(url)到签名 JSON[/bold]"
|
||||
"\n[bold]Step 6C · 回写代码路径(url)到签名 JSON[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
||||
# 原地回写 url 字段
|
||||
patch_signatures_with_url(signatures, func_name_to_url)
|
||||
|
||||
patched = sum(1 for s in signatures if s.get("url"))
|
||||
unpatched = len(signatures) - patched
|
||||
if unpatched:
|
||||
console.print(
|
||||
f"[yellow]⚠ {unpatched} 个函数未能写入 url"
|
||||
f"(对应代码文件生成失败)[/yellow]"
|
||||
)
|
||||
console.print(f"[yellow]⚠ {unpatched} 个函数 url 未回写(代码生成失败)[/yellow]")
|
||||
|
||||
# 打印最终预览(含 url 列)
|
||||
print_signatures_preview(signatures)
|
||||
|
||||
# 覆盖写入 JSON(含 project.description)
|
||||
json_path = write_function_signatures_json(
|
||||
output_dir = output_dir,
|
||||
signatures = signatures,
|
||||
|
|
@ -601,7 +654,6 @@ def step_patch_signatures_url(
|
|||
project_description = project.description or "",
|
||||
file_name = json_file_name,
|
||||
)
|
||||
|
||||
console.print(
|
||||
f"[green]✓ 签名 JSON 已更新(含 url): "
|
||||
f"[cyan]{os.path.abspath(json_path)}[/cyan][/green]\n"
|
||||
|
|
@ -625,10 +677,10 @@ def run_workflow(
|
|||
json_file_name: str = "function_signatures.json",
|
||||
non_interactive: bool = False,
|
||||
):
|
||||
"""完整工作流(Step 1 → 5C)"""
|
||||
"""完整工作流 Step 1 → 6C"""
|
||||
print_banner()
|
||||
|
||||
# Step 1
|
||||
# Step 1:项目初始化
|
||||
project = step_init_project(
|
||||
project_name = project_name,
|
||||
language = language,
|
||||
|
|
@ -636,7 +688,7 @@ def run_workflow(
|
|||
non_interactive = non_interactive,
|
||||
)
|
||||
|
||||
# Step 2
|
||||
# Step 2:输入原始需求
|
||||
raw_text, knowledge_text, source_name, source_type = step_input_requirement(
|
||||
project = project,
|
||||
requirement_text = requirement_text,
|
||||
|
|
@ -645,7 +697,7 @@ def run_workflow(
|
|||
non_interactive = non_interactive,
|
||||
)
|
||||
|
||||
# Step 3
|
||||
# Step 3:LLM 需求分解
|
||||
raw_req_id, func_reqs = step_decompose_requirements(
|
||||
project = project,
|
||||
raw_text = raw_text,
|
||||
|
|
@ -655,7 +707,15 @@ def run_workflow(
|
|||
non_interactive = non_interactive,
|
||||
)
|
||||
|
||||
# Step 4
|
||||
# Step 4:模块分类
|
||||
func_reqs = step_classify_modules(
|
||||
project = project,
|
||||
func_reqs = func_reqs,
|
||||
knowledge_text = knowledge_text,
|
||||
non_interactive = non_interactive,
|
||||
)
|
||||
|
||||
# Step 5:编辑功能需求
|
||||
func_reqs = step_edit_requirements(
|
||||
project = project,
|
||||
func_reqs = func_reqs,
|
||||
|
|
@ -670,7 +730,7 @@ def run_workflow(
|
|||
|
||||
output_dir = ensure_project_dir(project.name)
|
||||
|
||||
# Step 5A:生成签名(初版,不含 url)
|
||||
# Step 6A:生成函数签名
|
||||
signatures, json_path = step_generate_signatures(
|
||||
project = project,
|
||||
func_reqs = func_reqs,
|
||||
|
|
@ -680,7 +740,7 @@ def run_workflow(
|
|||
non_interactive = non_interactive,
|
||||
)
|
||||
|
||||
# Step 5B:生成代码,收集 {函数名: 文件路径}
|
||||
# Step 6B:生成代码文件
|
||||
func_name_to_url = step_generate_code(
|
||||
project = project,
|
||||
func_reqs = func_reqs,
|
||||
|
|
@ -690,7 +750,7 @@ def run_workflow(
|
|||
non_interactive = non_interactive,
|
||||
)
|
||||
|
||||
# Step 5C:回写 url 字段,刷新 JSON
|
||||
# Step 6C:回写 url,刷新 JSON
|
||||
json_path = step_patch_signatures_url(
|
||||
project = project,
|
||||
signatures = signatures,
|
||||
|
|
@ -700,10 +760,13 @@ def run_workflow(
|
|||
non_interactive = non_interactive,
|
||||
)
|
||||
|
||||
# 最终汇总
|
||||
modules = sorted({req.module or config.DEFAULT_MODULE for req in func_reqs})
|
||||
console.print(Panel(
|
||||
f"[bold cyan]🎉 全部流程完成![/bold cyan]\n"
|
||||
f"项目: [bold]{project.name}[/bold]\n"
|
||||
f"描述: {project.description or '(无)'}\n"
|
||||
f"模块: {', '.join(modules)}\n"
|
||||
f"代码目录: [cyan]{os.path.abspath(output_dir)}[/cyan]\n"
|
||||
f"签名文件: [cyan]{json_path}[/cyan]",
|
||||
border_style="cyan",
|
||||
|
|
@ -716,24 +779,21 @@ def run_workflow(
|
|||
|
||||
@click.command()
|
||||
@click.option("--non-interactive", is_flag=True, default=False,
|
||||
help="以非交互模式运行(所有参数通过命令行传入)")
|
||||
help="以非交互模式运行")
|
||||
@click.option("--project-name", "-p", default=None, help="项目名称")
|
||||
@click.option("--language", "-l", default=None,
|
||||
type=click.Choice(["python","javascript","typescript","java","go","rust"]),
|
||||
help=f"目标代码语言(默认: {config.DEFAULT_LANGUAGE})")
|
||||
@click.option("--description", "-d", default="", help="项目描述")
|
||||
@click.option("--requirement-text","-r", default=None,
|
||||
help="原始需求文本(与 --requirement-file 二选一)")
|
||||
@click.option("--requirement-text","-r", default=None, help="原始需求文本")
|
||||
@click.option("--requirement-file","-f", default=None,
|
||||
type=click.Path(exists=True),
|
||||
help="原始需求文件路径(支持 .txt/.md/.pdf/.docx)")
|
||||
type=click.Path(exists=True), help="原始需求文件路径")
|
||||
@click.option("--knowledge-file", "-k", default=None, multiple=True,
|
||||
type=click.Path(exists=True),
|
||||
help="知识库文件路径(可多次指定,如 -k a.md -k b.pdf)")
|
||||
type=click.Path(exists=True), help="知识库文件(可多次指定)")
|
||||
@click.option("--skip-index", "-s", default=None, multiple=True, type=int,
|
||||
help="要跳过的功能需求序号(可多次指定,如 -s 2 -s 5)")
|
||||
help="跳过的功能需求序号(可多次指定)")
|
||||
@click.option("--json-file-name", "-j", default="function_signatures.json",
|
||||
help="函数签名 JSON 文件名(默认: function_signatures.json)")
|
||||
help="签名 JSON 文件名")
|
||||
def cli(
|
||||
non_interactive, project_name, language, description,
|
||||
requirement_text, requirement_file, knowledge_file,
|
||||
|
|
@ -743,7 +803,7 @@ def cli(
|
|||
需求分析 & 代码生成工具
|
||||
|
||||
\b
|
||||
交互式运行(推荐初次使用):
|
||||
交互式运行:
|
||||
python main.py
|
||||
|
||||
\b
|
||||
|
|
@ -752,16 +812,7 @@ def cli(
|
|||
--project-name "UserSystem" \\
|
||||
--description "用户管理系统后端服务" \\
|
||||
--language python \\
|
||||
--requirement-text "用户管理系统,包含注册、登录、修改密码功能" \\
|
||||
--knowledge-file docs/api_spec.md \\
|
||||
--json-file-name api_signatures.json
|
||||
|
||||
\b
|
||||
从文件读取需求 + 跳过部分功能需求:
|
||||
python main.py --non-interactive \\
|
||||
--project-name "MyProject" \\
|
||||
--requirement-file requirements.md \\
|
||||
--skip-index 3 --skip-index 7
|
||||
--requirement-text "用户管理系统,包含注册、登录、修改密码功能"
|
||||
"""
|
||||
try:
|
||||
run_workflow(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
openai>=1.0.0
|
||||
python-dotenv>=1.0.0
|
||||
rich>=13.0.0
|
||||
python-docx>=0.8.11
|
||||
PyPDF2>=3.0.0
|
||||
openai>=1.30.0
|
||||
click>=8.1.0
|
||||
rich>=13.0.0
|
||||
sqlalchemy>=2.0.0
|
||||
python-dotenv>=1.0.0
|
||||
pypdf>=4.0.0
|
||||
python-docx>=1.1.0
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 设置环境变量
|
||||
set OPENAI_API_KEY="sk-AUmOuFI731Ty5Nob38jY26d8lydfDT-QkE2giqb0sCuPCAE2JH6zjLM4lZLpvL5WMYPOocaMe2FwVDmqM_9KimmKACjR"
|
||||
set OPENAI_BASE_URL="https://openapi.monica.im/v1" # 或其他兼容接口
|
||||
set LLM_MODEL="gpt-4o"
|
||||
|
||||
python main.py
|
||||
|
|
@ -1,95 +1,13 @@
|
|||
# utils/file_handler.py - 文件读取工具(支持 txt / md / pdf / docx)
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def read_text_file(file_path: str) -> str:
|
||||
"""
|
||||
读取纯文本文件内容(.txt / .md / .py 等)
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
文件文本内容
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 文件不存在
|
||||
UnicodeDecodeError: 编码错误时尝试 latin-1 兜底
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
try:
|
||||
return path.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return path.read_text(encoding="latin-1")
|
||||
|
||||
|
||||
def read_pdf_file(file_path: str) -> str:
|
||||
"""
|
||||
读取 PDF 文件内容
|
||||
|
||||
Args:
|
||||
file_path: PDF 文件路径
|
||||
|
||||
Returns:
|
||||
提取的文本内容
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 PyPDF2
|
||||
FileNotFoundError: 文件不存在
|
||||
"""
|
||||
try:
|
||||
import PyPDF2
|
||||
except ImportError:
|
||||
raise ImportError("请安装 PyPDF2: pip install PyPDF2")
|
||||
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
texts = []
|
||||
with open(file_path, "rb") as f:
|
||||
reader = PyPDF2.PdfReader(f)
|
||||
for page in reader.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
texts.append(text)
|
||||
return "\n".join(texts)
|
||||
|
||||
|
||||
def read_docx_file(file_path: str) -> str:
|
||||
"""
|
||||
读取 Word (.docx) 文件内容
|
||||
|
||||
Args:
|
||||
file_path: docx 文件路径
|
||||
|
||||
Returns:
|
||||
提取的文本内容(段落合并)
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 python-docx
|
||||
FileNotFoundError: 文件不存在
|
||||
"""
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError("请安装 python-docx: pip install python-docx")
|
||||
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
doc = Document(file_path)
|
||||
return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
|
||||
from typing import List
|
||||
|
||||
|
||||
def read_file_auto(file_path: str) -> str:
|
||||
"""
|
||||
根据文件扩展名自动选择读取方式
|
||||
自动识别文件类型并读取文本内容。
|
||||
|
||||
支持格式:.txt / .md / .pdf / .docx / 其他(按 UTF-8 读取)
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
|
@ -98,45 +16,66 @@ def read_file_auto(file_path: str) -> str:
|
|||
文件文本内容
|
||||
|
||||
Raises:
|
||||
ValueError: 不支持的文件类型
|
||||
FileNotFoundError: 文件不存在
|
||||
RuntimeError: 读取失败
|
||||
"""
|
||||
ext = Path(file_path).suffix.lower()
|
||||
readers = {
|
||||
".txt": read_text_file,
|
||||
".md": read_text_file,
|
||||
".py": read_text_file,
|
||||
".json": read_text_file,
|
||||
".yaml": read_text_file,
|
||||
".yml": read_text_file,
|
||||
".pdf": read_pdf_file,
|
||||
".docx": read_docx_file,
|
||||
}
|
||||
reader = readers.get(ext)
|
||||
if reader is None:
|
||||
raise ValueError(f"不支持的文件类型: {ext},支持: {list(readers.keys())}")
|
||||
return reader(file_path)
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
try:
|
||||
if ext == ".pdf":
|
||||
return _read_pdf(file_path)
|
||||
elif ext in (".docx", ".doc"):
|
||||
return _read_docx(file_path)
|
||||
else:
|
||||
# txt / md / 其他文本格式
|
||||
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"读取文件失败 [{file_path}]: {e}")
|
||||
|
||||
|
||||
def _read_pdf(file_path: str) -> str:
|
||||
"""读取 PDF 文件文本"""
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
except ImportError:
|
||||
raise RuntimeError("读取 PDF 需要安装 pypdf:pip install pypdf")
|
||||
|
||||
reader = PdfReader(file_path)
|
||||
pages = [page.extract_text() or "" for page in reader.pages]
|
||||
return "\n".join(pages)
|
||||
|
||||
|
||||
def _read_docx(file_path: str) -> str:
|
||||
"""读取 Word 文档文本"""
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise RuntimeError("读取 docx 需要安装 python-docx:pip install python-docx")
|
||||
|
||||
doc = Document(file_path)
|
||||
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
return "\n".join(paragraphs)
|
||||
|
||||
|
||||
def merge_knowledge_files(file_paths: List[str]) -> str:
|
||||
"""
|
||||
合并多个知识库文件为单一文本
|
||||
合并多个知识库文件为单一文本。
|
||||
|
||||
Args:
|
||||
file_paths: 知识库文件路径列表
|
||||
file_paths: 文件路径列表
|
||||
|
||||
Returns:
|
||||
合并后的知识库文本(包含文件名分隔符)
|
||||
合并后的文本(各文件以分隔线隔开)
|
||||
"""
|
||||
if not file_paths:
|
||||
return ""
|
||||
|
||||
sections = []
|
||||
for fp in file_paths:
|
||||
parts = []
|
||||
for path in file_paths:
|
||||
try:
|
||||
content = read_file_auto(fp)
|
||||
file_name = Path(fp).name
|
||||
sections.append(f"### 知识库文件: {file_name}\n{content}")
|
||||
content = read_file_auto(path)
|
||||
parts.append(f"--- {os.path.basename(path)} ---\n{content}")
|
||||
except Exception as e:
|
||||
sections.append(f"### 知识库文件: {fp}\n[读取失败: {e}]")
|
||||
|
||||
return "\n\n".join(sections)
|
||||
parts.append(f"--- {os.path.basename(path)} [读取失败: {e}] ---")
|
||||
return "\n\n".join(parts)
|
||||
|
|
@ -6,133 +6,74 @@ from typing import Dict, List
|
|||
|
||||
import config
|
||||
|
||||
|
||||
# 各语言文件扩展名映射
|
||||
LANGUAGE_EXT_MAP: Dict[str, str] = {
|
||||
"python": ".py",
|
||||
"javascript": ".js",
|
||||
"typescript": ".ts",
|
||||
"java": ".java",
|
||||
"go": ".go",
|
||||
"rust": ".rs",
|
||||
"cpp": ".cpp",
|
||||
"c": ".c",
|
||||
"csharp": ".cs",
|
||||
"ruby": ".rb",
|
||||
"php": ".php",
|
||||
"swift": ".swift",
|
||||
"kotlin": ".kt",
|
||||
}
|
||||
|
||||
# 合法的通用类型集合
|
||||
VALID_TYPES = {
|
||||
"integer", "string", "boolean", "float",
|
||||
"list", "dict", "object", "void", "any",
|
||||
}
|
||||
|
||||
# 合法的 inout 值
|
||||
VALID_INOUT = {"in", "out", "inout"}
|
||||
|
||||
|
||||
def get_file_extension(language: str) -> str:
|
||||
"""
|
||||
获取指定语言的文件扩展名
|
||||
|
||||
Args:
|
||||
language: 编程语言名称(小写)
|
||||
|
||||
Returns:
|
||||
文件扩展名(含点号,如 '.py')
|
||||
"""
|
||||
return LANGUAGE_EXT_MAP.get(language.lower(), ".txt")
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# 目录管理
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def build_project_output_dir(project_name: str) -> str:
|
||||
"""
|
||||
构建项目输出目录路径
|
||||
|
||||
Args:
|
||||
project_name: 项目名称
|
||||
|
||||
Returns:
|
||||
输出目录路径
|
||||
"""
|
||||
safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_name)
|
||||
return os.path.join(config.OUTPUT_BASE_DIR, safe_name)
|
||||
safe = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_name)
|
||||
return os.path.join(config.OUTPUT_BASE_DIR, safe)
|
||||
|
||||
|
||||
def ensure_project_dir(project_name: str) -> str:
|
||||
"""
|
||||
确保项目输出目录存在,不存在则创建
|
||||
|
||||
Args:
|
||||
project_name: 项目名称
|
||||
|
||||
Returns:
|
||||
创建好的目录路径
|
||||
"""
|
||||
"""确保项目根输出目录存在,并创建 __init__.py"""
|
||||
output_dir = build_project_output_dir(project_name)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
init_file = os.path.join(output_dir, "__init__.py")
|
||||
if not os.path.exists(init_file):
|
||||
Path(init_file).write_text(
|
||||
"# Auto-generated project package\n", encoding="utf-8"
|
||||
)
|
||||
Path(init_file).write_text("# Auto-generated project package\n", encoding="utf-8")
|
||||
return output_dir
|
||||
|
||||
|
||||
def write_code_file(
|
||||
output_dir: str,
|
||||
function_name: str,
|
||||
language: str,
|
||||
content: str,
|
||||
) -> str:
|
||||
"""
|
||||
将代码内容写入指定目录的文件
|
||||
def ensure_module_dir(output_dir: str, module: str) -> str:
|
||||
"""确保模块子目录存在,并创建 __init__.py"""
|
||||
module_dir = os.path.join(output_dir, module)
|
||||
os.makedirs(module_dir, exist_ok=True)
|
||||
init_file = os.path.join(module_dir, "__init__.py")
|
||||
if not os.path.exists(init_file):
|
||||
Path(init_file).write_text(
|
||||
f"# Auto-generated module package: {module}\n", encoding="utf-8"
|
||||
)
|
||||
return module_dir
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录路径
|
||||
function_name: 函数名(用于生成文件名)
|
||||
language: 编程语言
|
||||
content: 代码内容
|
||||
|
||||
Returns:
|
||||
写入的文件完整路径
|
||||
"""
|
||||
ext = get_file_extension(language)
|
||||
file_name = f"{function_name}{ext}"
|
||||
file_path = os.path.join(output_dir, file_name)
|
||||
Path(file_path).write_text(content, encoding="utf-8")
|
||||
return file_path
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# README
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def write_project_readme(
|
||||
output_dir: str,
|
||||
project_name: str,
|
||||
project_description: str,
|
||||
requirements_summary: str,
|
||||
modules: List[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
在项目目录生成 README.md 文件
|
||||
"""生成项目 README.md"""
|
||||
module_section = ""
|
||||
if modules:
|
||||
module_list = "\n".join(f"- `{m}/`" for m in sorted(set(modules)))
|
||||
module_section = f"\n## 功能模块\n\n{module_list}\n"
|
||||
|
||||
Args:
|
||||
output_dir: 项目输出目录
|
||||
project_name: 项目名称
|
||||
requirements_summary: 功能需求摘要文本
|
||||
|
||||
Returns:
|
||||
README.md 文件路径
|
||||
"""
|
||||
readme_content = f"""# {project_name}
|
||||
content = f"""# {project_name}
|
||||
|
||||
> Auto-generated by Requirement Analyzer
|
||||
|
||||
{project_description or ""}
|
||||
{module_section}
|
||||
## 功能需求列表
|
||||
|
||||
{requirements_summary}
|
||||
"""
|
||||
readme_path = os.path.join(output_dir, "README.md")
|
||||
Path(readme_path).write_text(readme_content, encoding="utf-8")
|
||||
return readme_path
|
||||
path = os.path.join(output_dir, "README.md")
|
||||
Path(path).write_text(content, encoding="utf-8")
|
||||
return path
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
|
@ -145,19 +86,11 @@ def build_signatures_document(
|
|||
signatures: List[dict],
|
||||
) -> dict:
|
||||
"""
|
||||
将函数签名列表包装为带项目信息的顶层文档结构。
|
||||
|
||||
Args:
|
||||
project_name: 项目名称,写入 "project" 字段
|
||||
project_description: 项目描述,写入 "description" 字段
|
||||
signatures: 函数签名 dict 列表,写入 "functions" 字段
|
||||
|
||||
Returns:
|
||||
顶层文档 dict,结构为::
|
||||
构建顶层签名文档结构::
|
||||
|
||||
{
|
||||
"project": "<project_name>",
|
||||
"description": "<project_description>",
|
||||
"project": "<name>",
|
||||
"description": "<description>",
|
||||
"functions": [ ... ]
|
||||
}
|
||||
"""
|
||||
|
|
@ -173,66 +106,28 @@ def patch_signatures_with_url(
|
|||
func_name_to_url: Dict[str, str],
|
||||
) -> List[dict]:
|
||||
"""
|
||||
将代码文件的路径(URL)回写到对应函数签名的 "url" 字段。
|
||||
|
||||
遍历签名列表,根据 signature["name"] 在 func_name_to_url 中查找
|
||||
对应路径,找到则写入 "url" 字段;未找到则写入空字符串,不抛出异常。
|
||||
|
||||
"url" 字段插入位置紧跟在 "type" 字段之后,以保持字段顺序的可读性::
|
||||
|
||||
{
|
||||
"name": "create_user",
|
||||
"requirement_id": "REQ.01",
|
||||
"description": "...",
|
||||
"type": "function",
|
||||
"url": "/abs/path/to/create_user.py", ← 新增
|
||||
"parameters": { ... },
|
||||
"return": { ... }
|
||||
}
|
||||
将代码文件路径回写到签名的 "url" 字段(紧跟 "type" 之后)。
|
||||
|
||||
Args:
|
||||
signatures: 原始签名列表(in-place 修改)
|
||||
func_name_to_url: {函数名: 代码文件绝对路径} 映射表,
|
||||
由 CodeGenerator.generate_batch() 的进度回调收集
|
||||
signatures: 签名列表(in-place 修改)
|
||||
func_name_to_url: {函数名: 文件绝对路径}
|
||||
|
||||
Returns:
|
||||
修改后的签名列表(与传入的同一对象,方便链式调用)
|
||||
修改后的签名列表
|
||||
"""
|
||||
for sig in signatures:
|
||||
func_name = sig.get("name", "")
|
||||
url = func_name_to_url.get(func_name, "")
|
||||
url = func_name_to_url.get(sig.get("name", ""), "")
|
||||
_insert_field_after(sig, after_key="type", new_key="url", new_value=url)
|
||||
return signatures
|
||||
|
||||
|
||||
def _insert_field_after(
|
||||
d: dict,
|
||||
after_key: str,
|
||||
new_key: str,
|
||||
new_value,
|
||||
) -> None:
|
||||
"""
|
||||
在有序 dict 中将 new_key 插入到 after_key 之后。
|
||||
若 after_key 不存在,则追加到末尾。
|
||||
若 new_key 已存在,则直接更新其值(不改变位置)。
|
||||
|
||||
Args:
|
||||
d: 目标 dict(Python 3.7+ 保证插入顺序)
|
||||
after_key: 参考键名
|
||||
new_key: 要插入的键名
|
||||
new_value: 要插入的值
|
||||
"""
|
||||
def _insert_field_after(d: dict, after_key: str, new_key: str, new_value) -> None:
|
||||
"""在有序 dict 中将 new_key 插入到 after_key 之后"""
|
||||
if new_key in d:
|
||||
d[new_key] = new_value
|
||||
return
|
||||
|
||||
items = list(d.items())
|
||||
insert_pos = len(items)
|
||||
for i, (k, _) in enumerate(items):
|
||||
if k == after_key:
|
||||
insert_pos = i + 1
|
||||
break
|
||||
|
||||
insert_pos = next((i + 1 for i, (k, _) in enumerate(items) if k == after_key), len(items))
|
||||
items.insert(insert_pos, (new_key, new_value))
|
||||
d.clear()
|
||||
d.update(items)
|
||||
|
|
@ -245,42 +140,7 @@ def write_function_signatures_json(
|
|||
project_description: str,
|
||||
file_name: str = "function_signatures.json",
|
||||
) -> str:
|
||||
"""
|
||||
将函数签名列表连同项目信息一起导出为 JSON 文件。
|
||||
|
||||
输出的 JSON 顶层结构为::
|
||||
|
||||
{
|
||||
"project": "<project_name>",
|
||||
"description": "<project_description>",
|
||||
"functions": [
|
||||
{
|
||||
"name": "...",
|
||||
"requirement_id": "...",
|
||||
"description": "...",
|
||||
"type": "function",
|
||||
"url": "/abs/path/to/xxx.py",
|
||||
"parameters": { ... },
|
||||
"return": { ... }
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
Args:
|
||||
output_dir: JSON 文件写入目录
|
||||
signatures: 函数签名 dict 列表(应已通过
|
||||
patch_signatures_with_url() 写入 "url" 字段)
|
||||
project_name: 项目名称
|
||||
project_description: 项目描述
|
||||
file_name: 输出文件名,默认 function_signatures.json
|
||||
|
||||
Returns:
|
||||
写入的 JSON 文件完整路径
|
||||
|
||||
Raises:
|
||||
OSError: 目录不可写
|
||||
"""
|
||||
"""将签名列表导出为 JSON 文件"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
document = build_signatures_document(project_name, project_description, signatures)
|
||||
file_path = os.path.join(output_dir, file_name)
|
||||
|
|
@ -294,137 +154,70 @@ def write_function_signatures_json(
|
|||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def validate_signature_schema(signature: dict) -> List[str]:
|
||||
"""
|
||||
校验单个函数签名 dict 是否符合规范。
|
||||
|
||||
校验范围:
|
||||
- 顶层必填字段:name / requirement_id / description / type / parameters
|
||||
- 可选字段 "url":若存在则必须为非空字符串
|
||||
- parameters:每个参数的 type / inout / required 字段
|
||||
- return:type 字段 + on_success / on_failure 子结构
|
||||
- void 函数:on_success / on_failure 应为 null
|
||||
- 非 void 函数:on_success / on_failure 必须存在,
|
||||
且 value(非空)与 description(非空)均需填写
|
||||
|
||||
Args:
|
||||
signature: 单个函数签名 dict
|
||||
|
||||
Returns:
|
||||
错误信息字符串列表,列表为空表示校验通过
|
||||
"""
|
||||
"""校验单个函数签名结构,返回错误列表(空列表表示通过)"""
|
||||
errors: List[str] = []
|
||||
|
||||
# ── 顶层必填字段 ──────────────────────────────────
|
||||
for key in ("name", "requirement_id", "description", "type", "parameters"):
|
||||
if key not in signature:
|
||||
errors.append(f"缺少顶层字段: '{key}'")
|
||||
|
||||
# ── url 字段(可选,存在时校验非空)─────────────────
|
||||
if "url" in signature:
|
||||
if not isinstance(signature["url"], str):
|
||||
errors.append("'url' 字段必须是字符串类型")
|
||||
elif signature["url"] == "":
|
||||
errors.append("'url' 字段不能为空字符串(代码文件路径未成功回写)")
|
||||
errors.append("'url' 字段不能为空字符串")
|
||||
|
||||
# ── parameters ────────────────────────────────────
|
||||
params = signature.get("parameters", {})
|
||||
if not isinstance(params, dict):
|
||||
errors.append("'parameters' 必须是 dict 类型")
|
||||
else:
|
||||
if isinstance(params, dict):
|
||||
for pname, pdef in params.items():
|
||||
if not isinstance(pdef, dict):
|
||||
errors.append(f"参数 '{pname}' 定义必须是 dict")
|
||||
continue
|
||||
# type(支持联合类型,如 "string|integer")
|
||||
if "type" not in pdef:
|
||||
errors.append(f"参数 '{pname}' 缺少 'type' 字段")
|
||||
errors.append(f"参数 '{pname}' 缺少 'type'")
|
||||
else:
|
||||
parts = [p.strip() for p in pdef["type"].split("|")]
|
||||
if not all(p in VALID_TYPES for p in parts):
|
||||
errors.append(
|
||||
f"参数 '{pname}' 的 type='{pdef['type']}' 含有不合法的类型"
|
||||
)
|
||||
# inout
|
||||
errors.append(f"参数 '{pname}' type='{pdef['type']}' 含不合法类型")
|
||||
if "inout" not in pdef:
|
||||
errors.append(f"参数 '{pname}' 缺少 'inout' 字段")
|
||||
errors.append(f"参数 '{pname}' 缺少 'inout'")
|
||||
elif pdef["inout"] not in VALID_INOUT:
|
||||
errors.append(
|
||||
f"参数 '{pname}' 的 inout='{pdef['inout']}' 应为 in/out/inout"
|
||||
)
|
||||
# required
|
||||
errors.append(f"参数 '{pname}' inout='{pdef['inout']}' 应为 in/out/inout")
|
||||
if "required" not in pdef:
|
||||
errors.append(f"参数 '{pname}' 缺少 'required' 字段")
|
||||
errors.append(f"参数 '{pname}' 缺少 'required'")
|
||||
elif not isinstance(pdef["required"], bool):
|
||||
errors.append(
|
||||
f"参数 '{pname}' 的 'required' 应为布尔值 true/false,"
|
||||
f"当前为: {pdef['required']!r}"
|
||||
)
|
||||
errors.append(f"参数 '{pname}' 'required' 应为布尔值")
|
||||
|
||||
# ── return ────────────────────────────────────────
|
||||
ret = signature.get("return")
|
||||
if ret is None:
|
||||
errors.append(
|
||||
"缺少 'return' 字段(void 函数请填 "
|
||||
"{\"type\": \"void\", \"on_success\": null, \"on_failure\": null})"
|
||||
)
|
||||
elif not isinstance(ret, dict):
|
||||
errors.append("'return' 必须是 dict 类型")
|
||||
else:
|
||||
errors.append("缺少 'return' 字段")
|
||||
elif isinstance(ret, dict):
|
||||
ret_type = ret.get("type")
|
||||
if not ret_type:
|
||||
errors.append("'return' 缺少 'type' 字段")
|
||||
errors.append("'return' 缺少 'type'")
|
||||
elif ret_type not in VALID_TYPES:
|
||||
errors.append(f"'return.type'='{ret_type}' 不在合法类型列表中")
|
||||
|
||||
errors.append(f"'return.type'='{ret_type}' 不合法")
|
||||
is_void = (ret_type == "void")
|
||||
|
||||
for sub_key in ("on_success", "on_failure"):
|
||||
sub = ret.get(sub_key)
|
||||
if is_void:
|
||||
if sub is not None:
|
||||
errors.append(
|
||||
f"void 函数的 'return.{sub_key}' 应为 null,"
|
||||
f"当前为: {sub!r}"
|
||||
)
|
||||
errors.append(f"void 函数 'return.{sub_key}' 应为 null")
|
||||
else:
|
||||
if sub is None:
|
||||
errors.append(
|
||||
f"非 void 函数缺少 'return.{sub_key}',"
|
||||
f"请描述{'成功' if sub_key == 'on_success' else '失败'}时的返回值"
|
||||
)
|
||||
elif not isinstance(sub, dict):
|
||||
errors.append(f"'return.{sub_key}' 必须是 dict 类型")
|
||||
else:
|
||||
if "value" not in sub:
|
||||
errors.append(f"'return.{sub_key}' 缺少 'value' 字段")
|
||||
elif sub["value"] == "":
|
||||
errors.append(
|
||||
f"'return.{sub_key}.value' 不能为空字符串,"
|
||||
f"请填写具体返回值、值域描述或结构示例"
|
||||
)
|
||||
if "description" not in sub or sub.get("description") in (None, ""):
|
||||
errors.append(f"非 void 函数缺少 'return.{sub_key}'")
|
||||
elif isinstance(sub, dict):
|
||||
if not sub.get("value"):
|
||||
errors.append(f"'return.{sub_key}.value' 不能为空")
|
||||
if not sub.get("description"):
|
||||
errors.append(f"'return.{sub_key}.description' 不能为空")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_all_signatures(signatures: List[dict]) -> Dict[str, List[str]]:
|
||||
"""
|
||||
批量校验函数签名列表。
|
||||
|
||||
注意:此函数接受的是纯签名列表(即顶层文档的 "functions" 字段),
|
||||
而非包含 project/description 的顶层文档。
|
||||
|
||||
Args:
|
||||
signatures: 函数签名 dict 列表
|
||||
|
||||
Returns:
|
||||
{函数名: [错误信息, ...]} 字典,仅包含有错误的条目
|
||||
"""
|
||||
report: Dict[str, List[str]] = {}
|
||||
for sig in signatures:
|
||||
name = sig.get("name", f"unknown_{id(sig)}")
|
||||
errs = validate_signature_schema(sig)
|
||||
if errs:
|
||||
report[name] = errs
|
||||
return report
|
||||
"""批量校验,返回 {函数名: [错误]} 字典(仅含有错误的条目)"""
|
||||
return {
|
||||
sig.get("name", f"unknown_{i}"): errs
|
||||
for i, sig in enumerate(signatures)
|
||||
if (errs := validate_signature_schema(sig))
|
||||
}
|
||||
Loading…
Reference in New Issue