支持功能模块

This commit is contained in:
liusongtao 2026-03-05 13:38:26 +08:00
parent 3fc095652e
commit 29636b9b94
12 changed files with 1142 additions and 1608 deletions

View File

@ -1,146 +1,123 @@
# config.py - 全局配置管理 # config.py - 全局配置
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
# ── LLM 配置 ────────────────────────────────────────── # ── LLM ──────────────────────────────────────────────
LLM_API_KEY = os.getenv("OPENAI_API_KEY", "") LLM_API_KEY = os.getenv("OPENAI_API_KEY", "")
LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") LLM_API_BASE = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o") LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0.3")) LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT", "60"))
LLM_MAX_RETRY = int(os.getenv("LLM_MAX_RETRY", "3"))
# ── 数据库配置 ───────────────────────────────────────── # ── 数据库 ────────────────────────────────────────────
DB_PATH = os.getenv("DB_PATH", "data/requirement_analyzer.db") DB_PATH = os.getenv("DB_PATH", "data/requirement_analyzer.db")
# ── 输出配置 ─────────────────────────────────────────── # ── 输出目录 ──────────────────────────────────────────
OUTPUT_BASE_DIR = os.getenv("OUTPUT_BASE_DIR", "output") OUTPUT_BASE_DIR = os.getenv("OUTPUT_BASE_DIR", "output")
DEFAULT_LANGUAGE = os.getenv("DEFAULT_LANGUAGE", "python") DEFAULT_LANGUAGE = os.getenv("DEFAULT_LANGUAGE", "python")
DEFAULT_MODULE = os.getenv("DEFAULT_MODULE", "default")
# ══════════════════════════════════════════════════════ # ── Prompt 模板 ───────────────────────────────────────
# Prompt 模板
# ══════════════════════════════════════════════════════
DECOMPOSE_PROMPT_TEMPLATE = """ DECOMPOSE_PROMPT_TEMPLATE = """\
你是一位资深软件架构师和产品经理请根据以下信息将原始需求分解为若干个可独立实现的功能需求 你是一名资深软件架构师请将以下原始需求分解为独立的功能需求列表
{knowledge_section} 原始需求
## 原始需求
{raw_requirement} {raw_requirement}
## 输出要求 {knowledge_section}
请严格按照以下 JSON 格式输出不要包含任何额外说明
{{ 输出要求
"functional_requirements": [ JSON 数组格式输出每个元素包含以下字段
- title: 功能标题简短10字以内
- description: 功能描述详细说明该功能的职责与边界50字以内
- function_name: 对应的函数名snake_case动词开头
- priority: 优先级high / medium / low
- module: 所属功能模块名称snake_case user_auth / order_service
示例输出
[
{{ {{
"index": 1, "title": "用户注册",
"title": "功能需求标题简洁10字以内", "description": "接收用户名、密码、邮箱校验合法性后创建用户账号并返回用户ID",
"description": "功能需求详细描述(包含输入、处理逻辑、输出)", "function_name": "register_user",
"function_name": "snake_case函数名", "priority": "high",
"priority": "high|medium|low" "module": "user_auth"
}} }}
] ]
}}
要求 只输出 JSON 数组不要有任何额外说明
1. 每个功能需求必须是独立可实现的最小单元
2. function_name 使用 snake_case 命名清晰表达函数用途
3. 分解粒度适中通常 5-15 个功能需求
4. 优先级根据业务重要性判断
""" """
# ── 函数签名 JSON 生成 Prompt ────────────────────────── FUNC_SIGNATURE_PROMPT_TEMPLATE = """\
FUNC_SIGNATURE_PROMPT_TEMPLATE = """ 你是一名资深软件工程师请根据以下功能需求生成标准函数签名信息
你是一位资深软件架构师请根据以下功能需求描述设计该函数的完整接口签名并以 JSON 格式输出
功能需求
- 需求编号: {requirement_id}
- 标题: {title}
- 描述: {description}
- 函数名: {function_name}
- 所属模块: {module}
{knowledge_section} {knowledge_section}
## 功能需求 输出要求
需求编号{requirement_id} JSON 对象格式输出包含以下字段
标题{title} - name: 函数名与上方一致
函数名{function_name} - requirement_id: 需求编号
详细描述{description} - description: 函数功能描述英文一句话
- type: 固定为 "function"
- module: 所属模块名称
- parameters: 参数字典key 为参数名value 包含
- type: 数据类型integer/string/boolean/float/list/dict/object/void/any
- inout: in / out / inout
- required: true / false
- description: 参数说明
- return:
- type: 返回类型
- on_success: {{ "value": "...", "description": "..." }} nullvoid
- on_failure: {{ "value": "...", "description": "..." }} nullvoid
## 输出格式 只输出 JSON 对象不要有任何额外说明
请严格按照以下 JSON 结构输出不要包含任何额外说明或 markdown 标记
{{
"name": "{function_name}",
"requirement_id": "{requirement_id}",
"description": "简洁的一句话功能描述(英文)",
"type": "function",
"parameters": {{
"<param_name>": {{
"type": "integer|string|boolean|float|list|dict|object",
"inout": "in|out|inout",
"description": "参数说明(英文)",
"required": true
}}
}},
"return": {{
"type": "integer|string|boolean|float|list|dict|object|void",
"description": "整体返回值说明(英文,一句话概括)",
"on_success": {{
"value": "具体成功返回值或范围,如 0、true、user object、list of items 等",
"description": "成功时的返回值含义(英文)"
}},
"on_failure": {{
"value": "具体失败返回值或范围,如 nonzero、false、null、empty list、raises Exception 等",
"description": "失败时的返回值含义,或抛出的异常类型(英文)"
}}
}}
}}
## 设计规范
1. 参数名使用 snake_case类型使用通用类型不绑定具体语言
2. inout 字段含义
- in = 仅输入参数
- out = 仅输出参数通过参数传出结果如指针/引用
- inout = 既作输入又作输出
3. 所有描述字段使用英文
4. return 字段规则
- 若函数无返回值voidtype "void"on_success/on_failure 均填 null
- 若返回值只有成功场景如纯查询on_failure 可描述为 "null or empty"
- on_success.value / on_failure.value 填写具体值或值域描述不要填写空字符串
5. 若函数无参数parameters {{}}
6. required 字段为布尔值 true false
""" """
# ── 代码生成 Prompt含签名约束───────────────────────── CODE_GEN_PROMPT_TEMPLATE = """\
CODE_GEN_PROMPT_TEMPLATE = """ 你是一名资深 {language} 工程师请根据以下函数签名和功能描述生成完整的函数实现代码
你是一位资深 {language} 工程师请根据以下功能需求和函数签名规范生成完整的 {language} 函数代码
{knowledge_section} 函数签名
## 功能需求
标题{title}
描述{description}
## 【必须严格遵守】函数签名规范
以下 JSON 定义了函数的精确接口生成的代码必须与之完全一致不得擅自增减或改名参数
```json
{signature_json} {signature_json}
```
### 签名字段说明 功能描述
- `name`函数名必须完全一致 {description}
- `parameters`每个 key 即为参数名`type` 为数据类型`inout` 含义
- `in` = 普通输入参数
- `out` = 输出参数Python 中通过返回值或可变容器传出
- `inout` = 既作输入又作输出
- `return.type`返回值类型
- `return.on_success`成功时的返回值代码实现必须与此一致
- `return.on_failure`失败时的返回值或异常代码实现必须与此一致
## 输出要求 {knowledge_section}
1. 只输出纯代码不要包含 markdown 代码块标记
2. 函数签名名称参数列表返回类型必须与上方 JSON 规范完全一致 输出要求
3. 成功/失败的返回值必须严格遵守 return.on_success / return.on_failure 的定义 1. 只输出 {language} 代码不要有任何 Markdown 标记不要 ```
4. 包含完整的类型注解Python 使用 type hints 2. 包含完整的函数实现含必要的 import
5. 包含详细的 docstring其中 Returns 段须注明成功值与失败值 3. 包含函数文档注释docstring / JSDoc
6. 包含必要的异常处理 4. 包含基本的参数校验与错误处理
7. 代码风格遵循 PEP8Python或对应语言规范 5. 代码风格遵循 {language} 最佳实践
8. 在文件顶部用注释注明需求编号功能标题函数签名摘要 """
9. 如需导入第三方库请在顶部统一导入
MODULE_CLASSIFY_PROMPT_TEMPLATE = """\
你是一名资深软件架构师请将以下功能需求列表分类到合适的功能模块中
功能需求列表
{requirements_json}
输出要求
JSON 数组格式输出每个元素包含
- function_name: 函数名与输入一致
- module: 所属模块名称snake_case user_auth / order_service / payment
模块划分原则
1. 功能相近的需求归入同一模块
2. 模块名使用英文 snake_case
3. 模块数量控制在 2~8 个之间
4. 若某需求确实无法归类使用 "default" 模块
只输出 JSON 数组不要有任何额外说明
""" """

View File

@ -1,99 +1,89 @@
# core/code_generator.py - 代码生成核心逻辑(签名约束版) # core/code_generator.py - 代码生成(按模块路由到子目录)
import os
import json import json
from typing import Optional, List, Callable from pathlib import Path
from typing import List, Optional, Callable
import config import config
from core.llm_client import LLMClient from core.llm_client import LLMClient
from database.models import FunctionalRequirement, CodeFile from database.models import FunctionalRequirement, CodeFile
from utils.output_writer import write_code_file, get_file_extension
class CodeGenerator: class CodeGenerator:
""" """根据函数签名约束,调用 LLM 生成代码文件,并按模块写入子目录"""
根据功能需求 + 函数签名约束使用 LLM 生成代码函数文件
签名由 RequirementAnalyzer.build_function_signature() 预先生成
注入 Prompt 后可确保代码参数列表与签名 JSON 完全一致
"""
def __init__(self, llm_client: Optional[LLMClient] = None): def __init__(self, llm: LLMClient):
""" self.llm = llm
初始化代码生成器
Args:
llm_client: LLM 客户端实例 None 时自动创建
"""
self.llm = llm_client or LLMClient()
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
# 单个生成 # 单个代码文件生成
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
def generate( def generate(
self, self,
func_req: FunctionalRequirement, func_req: FunctionalRequirement,
output_dir: str, output_dir: str,
language: str = config.DEFAULT_LANGUAGE, language: str = None,
knowledge: str = "", knowledge: str = "",
signature: Optional[dict] = None, signature: dict = None,
) -> CodeFile: ) -> CodeFile:
""" """
为单个功能需求生成代码文件 为单个功能需求生成代码文件写入 output_dir/<module>/ 子目录
Args: Args:
func_req: 功能需求对象必须含有效 id func_req: 功能需求对象
output_dir: 代码输出目录 output_dir: 项目根输出目录
language: 目标编程语言 language: 目标语言
knowledge: 知识库文本可选 knowledge: 知识库文本
signature: 函数签名 dict RequirementAnalyzer 生成 signature: 函数签名 dict可选有则作为约束
传入后将作为强约束注入 Prompt确保代码参数
与签名 JSON 完全一致 None 时退化为无约束模式
Returns: Returns:
CodeFile 对象含生成的代码内容和文件路径未持久化 CodeFile 对象未持久化
Raises: Raises:
ValueError: func_req.id None RuntimeError: LLM 调用失败
Exception: LLM 调用失败或文件写入失败
""" """
if func_req.id is None: language = language or config.DEFAULT_LANGUAGE
raise ValueError("FunctionalRequirement 必须先持久化id 不能为 None") module = (func_req.module or config.DEFAULT_MODULE).strip()
knowledge_section = self._build_knowledge_section(knowledge) # 按模块创建子目录
signature_json = self._build_signature_json(signature, func_req) module_dir = os.path.join(output_dir, module)
os.makedirs(module_dir, exist_ok=True)
self._ensure_init_py(module_dir)
# 构建 Prompt
sig_json = json.dumps(signature, ensure_ascii=False, indent=2) if signature else "{}"
knowledge_section = f"【参考知识库】\n{knowledge}\n" if knowledge else ""
prompt = config.CODE_GEN_PROMPT_TEMPLATE.format( prompt = config.CODE_GEN_PROMPT_TEMPLATE.format(
language=language, language = language,
knowledge_section=knowledge_section, signature_json = sig_json,
title=func_req.title, description = func_req.description,
description=func_req.description, knowledge_section = knowledge_section,
signature_json=signature_json,
) )
try:
code_content = self.llm.chat( code_content = self.llm.chat(
system_prompt=( prompt,
f"你是一位资深 {language} 工程师,只输出纯代码," system = f"You are an expert {language} developer. Output only code.",
"不添加任何 markdown 标记。函数签名必须与提供的 JSON 规范完全一致。" temperature = 0.2,
), max_tokens = 4096,
user_prompt=prompt,
) )
except Exception as e:
raise RuntimeError(f"代码生成失败 [{func_req.function_name}]: {e}")
file_path = write_code_file( # 写入文件
output_dir=output_dir, ext = self._get_extension(language)
function_name=func_req.function_name,
language=language,
content=code_content,
)
ext = get_file_extension(language)
file_name = f"{func_req.function_name}{ext}" file_name = f"{func_req.function_name}{ext}"
file_path = os.path.join(module_dir, file_name)
Path(file_path).write_text(code_content, encoding="utf-8")
return CodeFile( return CodeFile(
project_id=func_req.project_id, func_req_id = func_req.id,
func_req_id=func_req.id, file_name = file_name,
file_name=file_name, file_path = file_path,
file_path=file_path, module = module,
language=language, language = language,
content=code_content, content = code_content,
) )
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
@ -104,113 +94,80 @@ class CodeGenerator:
self, self,
func_reqs: List[FunctionalRequirement], func_reqs: List[FunctionalRequirement],
output_dir: str, output_dir: str,
language: str = config.DEFAULT_LANGUAGE, language: str = None,
knowledge: str = "", knowledge: str = "",
signatures: Optional[List[dict]] = None, signatures: Optional[List[dict]] = None,
on_progress: Optional[Callable] = None, on_progress: Optional[Callable] = None,
) -> List[CodeFile]: ) -> List[CodeFile]:
""" """
批量生成代码文件 批量生成代码文件
Args: Args:
func_reqs: 功能需求列表 func_reqs: 功能需求列表
output_dir: 输出目录 output_dir: 项目根输出目录
language: 目标语言 language: 目标语言
knowledge: 知识库文本 knowledge: 知识库文本
signatures: func_reqs 等长的签名列表索引对应 signatures: func_reqs 等长的签名列表索引对应
None 时所有条目均以无约束模式生成 on_progress: 进度回调 fn(index, total, req, code_file, error)
on_progress: 进度回调 fn(index, total, func_req, code_file, error)
Returns: Returns:
成功生成的 CodeFile 列表 成功生成的 CodeFile 列表
""" """
results = [] language = language or config.DEFAULT_LANGUAGE
total = len(func_reqs) total = len(func_reqs)
results = []
# 构建 func_req.id → signature 的快速查找表
sig_map = self._build_signature_map(func_reqs, signatures) sig_map = self._build_signature_map(func_reqs, signatures)
for i, req in enumerate(func_reqs): for i, req in enumerate(func_reqs, 1):
sig = sig_map.get(req.id) sig = sig_map.get(req.id)
try: try:
code_file = self.generate( code_file = self.generate(
func_req=req, func_req = req,
output_dir=output_dir, output_dir = output_dir,
language=language, language = language,
knowledge=knowledge, knowledge = knowledge,
signature=sig, signature = sig,
) )
results.append(code_file) results.append(code_file)
if on_progress: if on_progress:
on_progress(i + 1, total, req, code_file, None) on_progress(i, total, req, code_file, None)
except Exception as e: except Exception as e:
if on_progress: if on_progress:
on_progress(i + 1, total, req, None, e) on_progress(i, total, req, None, e)
return results return results
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
# 私有工具方法 # 工具方法
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
@staticmethod
def _build_knowledge_section(knowledge: str) -> str:
"""构建知识库 Prompt 段落"""
if not knowledge or not knowledge.strip():
return ""
return (
"## 参考知识库(实现时请遵循以下规范)\n"
f"{knowledge}\n\n---\n"
)
@staticmethod
def _build_signature_json(
signature: Optional[dict],
func_req: FunctionalRequirement,
) -> str:
"""
将签名 dict 序列化为格式化 JSON 字符串
若签名为 None则构造最小占位签名保持 Prompt 结构完整
Args:
signature: 签名 dict None
func_req: 对应的功能需求用于占位签名
Returns:
JSON 字符串
"""
if signature:
return json.dumps(signature, ensure_ascii=False, indent=2)
# 无签名时的最小占位,提示 LLM 自行设计但保持格式
fallback = {
"name": func_req.function_name,
"requirement_id": f"REQ.{func_req.index_no:02d}",
"description": func_req.description,
"type": "function",
"parameters": "<<请根据功能描述自行设计参数>>",
"return": "<<请根据功能描述自行设计返回值>>",
}
return json.dumps(fallback, ensure_ascii=False, indent=2)
@staticmethod @staticmethod
def _build_signature_map( def _build_signature_map(
func_reqs: List[FunctionalRequirement], func_reqs: List[FunctionalRequirement],
signatures: Optional[List[dict]], signatures: Optional[List[dict]],
) -> dict: ) -> dict:
""" """构建 func_req.id → signature 的快速查找表"""
构建 func_req.id signature 映射表
Args:
func_reqs: 功能需求列表
signatures: func_reqs 等长的签名列表 None
Returns:
{req_id: signature_dict} 字典
"""
if not signatures: if not signatures:
return {} return {}
sig_map = {} return {
for req, sig in zip(func_reqs, signatures): req.id: sig
if req.id is not None and sig: for req, sig in zip(func_reqs, signatures)
sig_map[req.id] = sig if req.id is not None
return sig_map }
@staticmethod
def _get_extension(language: str) -> str:
ext_map = {
"python": ".py", "javascript": ".js", "typescript": ".ts",
"java": ".java", "go": ".go", "rust": ".rs",
"cpp": ".cpp", "c": ".c", "csharp": ".cs",
"ruby": ".rb", "php": ".php", "swift": ".swift", "kotlin": ".kt",
}
return ext_map.get(language.lower(), ".txt")
@staticmethod
def _ensure_init_py(directory: str) -> None:
"""在目录中创建 __init__.pyPython 包标识)"""
init = os.path.join(directory, "__init__.py")
if not os.path.exists(init):
Path(init).write_text("# Auto-generated module package\n", encoding="utf-8")

View File

@ -1,90 +1,114 @@
# core/llm_client.py - LLM 客户端封装 # core/llm_client.py - LLM API 调用封装
import time
import json import json
from typing import Optional from typing import Optional
from openai import OpenAI, APIError, APITimeoutError, RateLimitError
import config import config
class LLMClient: class LLMClient:
""" """封装 OpenAI 兼容接口,提供统一的调用入口与重试机制"""
OpenAI 兼容 LLM 客户端封装
支持任何兼容 OpenAI API 格式的服务OpenAI / Azure / 本地模型等
"""
def __init__( def __init__(
self, self,
api_key: str = config.LLM_API_KEY, api_key: str = None,
base_url: str = config.LLM_BASE_URL, api_base: str = None,
model: str = config.LLM_MODEL, model: str = None,
temperature: float = config.LLM_TEMPERATURE,
): ):
""" self.model = model or config.LLM_MODEL
初始化 LLM 客户端 self.client = OpenAI(
api_key = api_key or config.LLM_API_KEY,
Args: base_url = api_base or config.LLM_API_BASE,
api_key: API 密钥 timeout = config.LLM_TIMEOUT,
base_url: API 基础 URL
model: 模型名称
temperature: 生成温度0~1越低越确定
Raises:
ImportError: 未安装 openai
ValueError: api_key 为空
"""
try:
from openai import OpenAI
except ImportError:
raise ImportError("请安装 openai: pip install openai")
if not api_key:
raise ValueError("LLM_API_KEY 未配置,请在 .env 文件中设置")
self.model = model
self.temperature = temperature
self._client = OpenAI(api_key=api_key, base_url=base_url)
def chat(self, system_prompt: str, user_prompt: str) -> str:
"""
发送对话请求返回模型回复文本
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
Returns:
模型回复的文本内容
Raises:
Exception: API 调用失败
"""
response = self._client.chat.completions.create(
model=self.model,
temperature=self.temperature,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
) )
return response.choices[0].message.content.strip()
def chat_json(self, system_prompt: str, user_prompt: str) -> dict: def chat(
self,
prompt: str,
system: str = "You are a helpful assistant.",
temperature: float = 0.2,
max_tokens: int = 4096,
) -> str:
""" """
发送对话请求解析并返回 JSON 结果 发送单轮对话请求返回模型回复文本
Args: Args:
system_prompt: 系统提示词 prompt: 用户消息
user_prompt: 用户输入 system: 系统提示词
temperature: 采样温度
max_tokens: 最大输出 token
Returns: Returns:
解析后的 dict 对象 模型回复的纯文本字符串
Raises: Raises:
json.JSONDecodeError: 模型返回非合法 JSON RuntimeError: 超过最大重试次数后仍失败
""" """
raw = self.chat(system_prompt, user_prompt) last_error = None
for attempt in range(1, config.LLM_MAX_RETRY + 1):
try:
resp = self.client.chat.completions.create(
model = self.model,
messages = [
{"role": "system", "content": system},
{"role": "user", "content": prompt},
],
temperature = temperature,
max_tokens = max_tokens,
)
return resp.choices[0].message.content.strip()
except RateLimitError as e:
wait = 2 ** attempt
last_error = e
time.sleep(wait)
except APITimeoutError as e:
last_error = e
time.sleep(2)
except APIError as e:
last_error = e
if attempt >= config.LLM_MAX_RETRY:
break
time.sleep(1)
raise RuntimeError(
f"LLM 调用失败(已重试 {config.LLM_MAX_RETRY} 次): {last_error}"
)
def chat_json(
self,
prompt: str,
system: str = "You are a helpful assistant. Always respond with valid JSON.",
temperature: float = 0.1,
max_tokens: int = 4096,
) -> any:
"""
发送请求并将回复解析为 JSON 对象
Returns:
解析后的 Python 对象dict list
Raises:
ValueError: JSON 解析失败
RuntimeError: LLM 调用失败
"""
raw = self.chat(prompt, system=system, temperature=temperature, max_tokens=max_tokens)
# 去除可能的 markdown 代码块包裹 # 去除可能的 markdown 代码块包裹
raw = raw.strip() cleaned = self._strip_markdown_code_block(raw)
if raw.startswith("```"): try:
lines = raw.split("\n") return json.loads(cleaned)
raw = "\n".join(lines[1:-1]) except json.JSONDecodeError as e:
return json.loads(raw) raise ValueError(f"LLM 返回内容无法解析为 JSON: {e}\n原始内容:\n{raw}")
@staticmethod
def _strip_markdown_code_block(text: str) -> str:
"""去除 ```json ... ``` 或 ``` ... ``` 包裹"""
text = text.strip()
if text.startswith("```"):
lines = text.splitlines()
# 去掉首行(```json 或 ```)和末行(```
inner = lines[1:] if lines[-1].strip() == "```" else lines[1:]
if inner and inner[-1].strip() == "```":
inner = inner[:-1]
text = "\n".join(inner).strip()
return text

View File

@ -1,6 +1,6 @@
# core/requirement_analyzer.py - 需求分解 & 函数签名生成 # core/requirement_analyzer.py - 需求分解 & 函数签名生成
import re import json
from typing import List, Optional from typing import List, Optional, Callable
import config import config
from core.llm_client import LLMClient from core.llm_client import LLMClient
@ -8,19 +8,10 @@ from database.models import FunctionalRequirement
class RequirementAnalyzer: class RequirementAnalyzer:
""" """负责需求分解、模块分类、函数签名生成"""
使用 LLM 将原始需求分解为功能需求列表并生成函数接口签名
支持注入知识库上下文以提升分解质量
"""
def __init__(self, llm_client: Optional[LLMClient] = None): def __init__(self, llm: LLMClient):
""" self.llm = llm
初始化需求分析器
Args:
llm_client: LLM 客户端实例 None 时自动创建
"""
self.llm = llm_client or LLMClient()
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
# 需求分解 # 需求分解
@ -34,7 +25,7 @@ class RequirementAnalyzer:
knowledge: str = "", knowledge: str = "",
) -> List[FunctionalRequirement]: ) -> List[FunctionalRequirement]:
""" """
将原始需求分解为功能需求列表 将原始需求文本分解为功能需求列表含模块分类
Args: Args:
raw_requirement: 原始需求文本 raw_requirement: 原始需求文本
@ -44,196 +35,175 @@ class RequirementAnalyzer:
Returns: Returns:
FunctionalRequirement 对象列表未持久化id=None FunctionalRequirement 对象列表未持久化id=None
Raises:
ValueError: LLM 返回格式不合法
json.JSONDecodeError: JSON 解析失败
""" """
knowledge_section = self._build_knowledge_section(knowledge) knowledge_section = (
f"【参考知识库】\n{knowledge}\n" if knowledge else ""
)
prompt = config.DECOMPOSE_PROMPT_TEMPLATE.format( prompt = config.DECOMPOSE_PROMPT_TEMPLATE.format(
knowledge_section=knowledge_section, raw_requirement = raw_requirement,
raw_requirement=raw_requirement, knowledge_section = knowledge_section,
) )
result = self.llm.chat_json( try:
system_prompt="你是一位资深软件架构师,擅长需求分析与系统设计。", items = self.llm.chat_json(prompt)
user_prompt=prompt, if not isinstance(items, list):
) raise ValueError("LLM 返回结果不是数组")
except Exception as e:
raise RuntimeError(f"需求分解失败: {e}")
items = result.get("functional_requirements", []) reqs = []
if not items: for i, item in enumerate(items, 1):
raise ValueError("LLM 未返回任何功能需求,请检查原始需求描述")
requirements = []
for item in items:
req = FunctionalRequirement( req = FunctionalRequirement(
project_id=project_id, project_id = project_id,
raw_req_id=raw_req_id, raw_req_id = raw_req_id,
index_no=int(item.get("index", len(requirements) + 1)), index_no = i,
title=item.get("title", "未命名功能"), title = item.get("title", f"功能{i}"),
description=item.get("description", ""), description = item.get("description", ""),
function_name=self._sanitize_function_name( function_name = item.get("function_name", f"function_{i}"),
item.get("function_name", f"func_{len(requirements)+1}") priority = item.get("priority", "medium"),
), module = item.get("module", config.DEFAULT_MODULE),
priority=item.get("priority", "medium"), status = "pending",
is_custom = False,
) )
requirements.append(req) reqs.append(req)
return reqs
return requirements
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
# 函数签名生成(新增) # 模块分类(独立步骤,可对已有需求列表重新分类)
# ══════════════════════════════════════════════════
def classify_modules(
self,
func_reqs: List[FunctionalRequirement],
knowledge: str = "",
) -> List[dict]:
"""
对功能需求列表进行模块分类返回 {function_name: module} 映射列表
Args:
func_reqs: 功能需求列表
knowledge: 知识库文本可选
Returns:
[{"function_name": "...", "module": "..."}, ...]
"""
req_list = [
{
"index_no": r.index_no,
"title": r.title,
"description": r.description,
"function_name": r.function_name,
}
for r in func_reqs
]
knowledge_section = f"【参考知识库】\n{knowledge}\n" if knowledge else ""
prompt = config.MODULE_CLASSIFY_PROMPT_TEMPLATE.format(
requirements_json = json.dumps(req_list, ensure_ascii=False, indent=2),
knowledge_section = knowledge_section,
)
try:
result = self.llm.chat_json(prompt)
if not isinstance(result, list):
raise ValueError("LLM 返回结果不是数组")
return result
except Exception as e:
raise RuntimeError(f"模块分类失败: {e}")
# ══════════════════════════════════════════════════
# 函数签名生成
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
def build_function_signature( def build_function_signature(
self, self,
func_req: FunctionalRequirement, func_req: FunctionalRequirement,
requirement_id: str = "",
knowledge: str = "", knowledge: str = "",
) -> dict: ) -> dict:
""" """
为单个功能需求生成函数接口签名 JSON 为单个功能需求生成函数签名 dict
Args: Args:
func_req: 功能需求对象需含有效 id func_req: 功能需求对象
knowledge: 知识库文本可选 requirement_id: 需求编号字符串 "REQ.01"
knowledge: 知识库文本
Returns: Returns:
符合接口规范的 dict包含 name/requirement_id/description/ 函数签名 dict
type/parameters/return 字段
Raises: Raises:
json.JSONDecodeError: LLM 返回非合法 JSON RuntimeError: LLM 调用或解析失败
""" """
requirement_id = self._format_requirement_id(func_req.index_no) knowledge_section = f"【参考知识库】\n{knowledge}\n" if knowledge else ""
knowledge_section = self._build_knowledge_section(knowledge)
prompt = config.FUNC_SIGNATURE_PROMPT_TEMPLATE.format( prompt = config.FUNC_SIGNATURE_PROMPT_TEMPLATE.format(
knowledge_section=knowledge_section, requirement_id = requirement_id or f"REQ.{func_req.index_no:02d}",
requirement_id=requirement_id, title = func_req.title,
title=func_req.title, description = func_req.description,
function_name=func_req.function_name, function_name = func_req.function_name,
description=func_req.description, module = func_req.module or config.DEFAULT_MODULE,
knowledge_section = knowledge_section,
) )
try:
signature = self.llm.chat_json( sig = self.llm.chat_json(prompt)
system_prompt=( if not isinstance(sig, dict):
"你是一位资深软件架构师,专注于 API 接口设计。" raise ValueError("LLM 返回结果不是 dict")
"只输出合法 JSON不添加任何说明文字。" # 确保 module 字段存在
), if "module" not in sig:
user_prompt=prompt, sig["module"] = func_req.module or config.DEFAULT_MODULE
) return sig
except Exception as e:
# 确保关键字段存在,做兜底处理 raise RuntimeError(f"签名生成失败 [{func_req.function_name}]: {e}")
signature.setdefault("name", func_req.function_name)
signature.setdefault("requirement_id", requirement_id)
signature.setdefault("description", func_req.description)
signature.setdefault("type", "function")
signature.setdefault("parameters", {})
signature.setdefault("return", {"type": "void", "description": ""})
return signature
def build_function_signatures_batch( def build_function_signatures_batch(
self, self,
func_reqs: List[FunctionalRequirement], func_reqs: List[FunctionalRequirement],
knowledge: str = "", knowledge: str = "",
on_progress=None, on_progress: Optional[Callable] = None,
) -> List[dict]: ) -> List[dict]:
""" """
批量为功能需求列表生成函数接口签名 批量生成函数签名失败时使用降级结构
Args: Args:
func_reqs: 功能需求列表 func_reqs: 功能需求列表
knowledge: 知识库文本可选 knowledge: 知识库文本
on_progress: 进度回调 fn(index, total, func_req, signature, error) on_progress: 进度回调 fn(index, total, req, signature, error)
Returns: Returns:
签名 dict 列表顺序与 func_reqs 一致 func_reqs 等长的签名 dict 列表索引一一对应
生成失败的条目使用降级结构填充不中断整体流程
""" """
results = [] signatures = []
total = len(func_reqs) total = len(func_reqs)
for i, req in enumerate(func_reqs): for i, req in enumerate(func_reqs, 1):
req_id = f"REQ.{req.index_no:02d}"
try: try:
sig = self.build_function_signature(req, knowledge) sig = self.build_function_signature(req, req_id, knowledge)
results.append(sig) error = None
if on_progress:
on_progress(i + 1, total, req, sig, None)
except Exception as e: except Exception as e:
# 降级:用基础信息填充,保证 JSON 完整性 sig = self._fallback_signature(req, req_id)
fallback = self._build_fallback_signature(req) error = e
results.append(fallback)
signatures.append(sig)
if on_progress: if on_progress:
on_progress(i + 1, total, req, fallback, e) on_progress(i, total, req, sig, error)
return results return signatures
# ══════════════════════════════════════════════════
# 私有工具方法
# ══════════════════════════════════════════════════
@staticmethod @staticmethod
def _build_knowledge_section(knowledge: str) -> str: def _fallback_signature(
"""构建知识库 Prompt 段落""" req: FunctionalRequirement,
if not knowledge or not knowledge.strip(): requirement_id: str,
return "" ) -> dict:
return f"""## 参考知识库 """生成降级签名结构LLM 失败时使用)"""
{knowledge}
---
"""
@staticmethod
def _sanitize_function_name(name: str) -> str:
"""
清理函数名确保符合 snake_case 规范
Args:
name: 原始函数名
Returns:
合法的 snake_case 函数名
"""
name = re.sub(r"[^a-zA-Z0-9_]", "_", name).lower()
name = re.sub(r"_+", "_", name).strip("_")
if name and name[0].isdigit():
name = "func_" + name
return name or "unnamed_function"
@staticmethod
def _format_requirement_id(index_no: int) -> str:
"""
将序号格式化为需求编号字符串
Args:
index_no: 功能需求序号 1 开始
Returns:
格式化编号 'REQ.01''REQ.12'
"""
return f"REQ.{index_no:02d}"
@staticmethod
def _build_fallback_signature(func_req: FunctionalRequirement) -> dict:
"""
构建降级签名LLM 调用失败时使用
Args:
func_req: 功能需求对象
Returns:
包含基础信息的签名 dict
"""
return { return {
"name": func_req.function_name, "name": req.function_name,
"requirement_id": f"REQ.{func_req.index_no:02d}", "requirement_id": requirement_id,
"description": func_req.description, "description": req.description,
"type": "function", "type": "function",
"module": req.module or config.DEFAULT_MODULE,
"parameters": {}, "parameters": {},
"return": { "return": {
"type": "void", "type": "any",
"description": "TODO: define return value" "on_success": {"value": "...", "description": "成功时返回值"},
"on_failure": {"value": "None", "description": "失败时返回 None"},
}, },
"_note": "Auto-generated fallback due to LLM error"
} }

View File

@ -1,314 +1,152 @@
# database/db_manager.py - 数据库操作管理器 # database/db_manager.py - 数据库 CRUD 操作封装
import sqlite3
import os import os
from datetime import datetime
from typing import List, Optional from typing import List, Optional
from contextlib import contextmanager
from database.models import ( from sqlalchemy import create_engine
CREATE_TABLES_SQL, Project, RawRequirement, from sqlalchemy.orm import sessionmaker, Session
FunctionalRequirement, CodeFile
)
import config import config
from database.models import Base, Project, RawRequirement, FunctionalRequirement, CodeFile
class DBManager: class DBManager:
"""SQLite 数据库管理器,封装所有 CRUD 操作""" """SQLite 数据库管理器,封装所有 CRUD 操作"""
def __init__(self, db_path: str = config.DB_PATH): def __init__(self, db_path: str = None):
self.db_path = db_path db_path = db_path or config.DB_PATH
os.makedirs(os.path.dirname(db_path), exist_ok=True) os.makedirs(os.path.dirname(db_path), exist_ok=True)
self._init_db() self.engine = create_engine(f"sqlite:///{db_path}", echo=False)
Base.metadata.create_all(self.engine)
self._Session = sessionmaker(bind=self.engine)
# ── 连接上下文管理器 ────────────────────────────────── def _session(self) -> Session:
return self._Session()
@contextmanager
def _get_conn(self):
"""获取数据库连接(自动提交/回滚)"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA foreign_keys = ON")
try:
yield conn
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
def _init_db(self):
"""初始化数据库,创建所有表"""
with self._get_conn() as conn:
conn.executescript(CREATE_TABLES_SQL)
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
# Project CRUD # Project
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
def create_project(self, project: Project) -> int: def create_project(self, project: Project) -> int:
"""创建项目,返回新项目 ID""" with self._session() as s:
sql = """ s.add(project)
INSERT INTO projects (name, description, language, output_dir, created_at, updated_at) s.commit()
VALUES (?, ?, ?, ?, ?, ?) s.refresh(project)
""" return project.id
with self._get_conn() as conn:
cur = conn.execute(sql, (
project.name, project.description, project.language,
project.output_dir, project.created_at, project.updated_at
))
return cur.lastrowid
def get_project_by_id(self, project_id: int) -> Optional[Project]:
"""根据 ID 查询项目"""
with self._get_conn() as conn:
row = conn.execute(
"SELECT * FROM projects WHERE id = ?", (project_id,)
).fetchone()
if row is None:
return None
return Project(
id=row["id"], name=row["name"], description=row["description"],
language=row["language"], output_dir=row["output_dir"],
created_at=row["created_at"], updated_at=row["updated_at"]
)
def get_project_by_name(self, name: str) -> Optional[Project]: def get_project_by_name(self, name: str) -> Optional[Project]:
"""根据名称查询项目""" with self._session() as s:
with self._get_conn() as conn: return s.query(Project).filter_by(name=name).first()
row = conn.execute(
"SELECT * FROM projects WHERE name = ?", (name,)
).fetchone()
if row is None:
return None
return Project(
id=row["id"], name=row["name"], description=row["description"],
language=row["language"], output_dir=row["output_dir"],
created_at=row["created_at"], updated_at=row["updated_at"]
)
def list_projects(self) -> List[Project]: def get_project_by_id(self, project_id: int) -> Optional[Project]:
"""列出所有项目""" with self._session() as s:
with self._get_conn() as conn: return s.get(Project, project_id)
rows = conn.execute(
"SELECT * FROM projects ORDER BY created_at DESC"
).fetchall()
return [
Project(
id=r["id"], name=r["name"], description=r["description"],
language=r["language"], output_dir=r["output_dir"],
created_at=r["created_at"], updated_at=r["updated_at"]
) for r in rows
]
def update_project(self, project: Project) -> None: def update_project(self, project: Project) -> None:
"""更新项目信息""" with self._session() as s:
project.updated_at = datetime.now().isoformat() s.merge(project)
sql = """ s.commit()
UPDATE projects
SET name=?, description=?, language=?, output_dir=?, updated_at=?
WHERE id=?
"""
with self._get_conn() as conn:
conn.execute(sql, (
project.name, project.description, project.language,
project.output_dir, project.updated_at, project.id
))
def delete_project(self, project_id: int) -> None: def list_projects(self) -> List[Project]:
"""删除项目(级联删除所有关联数据)""" with self._session() as s:
with self._get_conn() as conn: return s.query(Project).order_by(Project.created_at.desc()).all()
conn.execute("DELETE FROM projects WHERE id = ?", (project_id,))
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
# RawRequirement CRUD # RawRequirement
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
def create_raw_requirement(self, req: RawRequirement) -> int: def create_raw_requirement(self, raw_req: RawRequirement) -> int:
"""创建原始需求,返回新记录 ID""" with self._session() as s:
sql = """ s.add(raw_req)
INSERT INTO raw_requirements s.commit()
(project_id, content, source_type, source_name, knowledge, created_at) s.refresh(raw_req)
VALUES (?, ?, ?, ?, ?, ?) return raw_req.id
"""
with self._get_conn() as conn:
cur = conn.execute(sql, (
req.project_id, req.content, req.source_type,
req.source_name, req.knowledge, req.created_at
))
return cur.lastrowid
def get_raw_requirement(self, req_id: int) -> Optional[RawRequirement]: def get_raw_requirement(self, raw_req_id: int) -> Optional[RawRequirement]:
"""根据 ID 查询原始需求""" with self._session() as s:
with self._get_conn() as conn: return s.get(RawRequirement, raw_req_id)
row = conn.execute(
"SELECT * FROM raw_requirements WHERE id = ?", (req_id,)
).fetchone()
if row is None:
return None
return RawRequirement(
id=row["id"], project_id=row["project_id"], content=row["content"],
source_type=row["source_type"], source_name=row["source_name"],
knowledge=row["knowledge"], created_at=row["created_at"]
)
def list_raw_requirements_by_project(self, project_id: int) -> List[RawRequirement]:
"""查询项目下所有原始需求"""
with self._get_conn() as conn:
rows = conn.execute(
"SELECT * FROM raw_requirements WHERE project_id = ? ORDER BY created_at",
(project_id,)
).fetchall()
return [
RawRequirement(
id=r["id"], project_id=r["project_id"], content=r["content"],
source_type=r["source_type"], source_name=r["source_name"],
knowledge=r["knowledge"], created_at=r["created_at"]
) for r in rows
]
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
# FunctionalRequirement CRUD # FunctionalRequirement
# ══════════════════════════════════════════════════ # ══════════════════════════════════════════════════
def create_functional_requirement(self, req: FunctionalRequirement) -> int: def create_functional_requirement(self, req: FunctionalRequirement) -> int:
"""创建功能需求,返回新记录 ID""" with self._session() as s:
sql = """ s.add(req)
INSERT INTO functional_requirements s.commit()
(project_id, raw_req_id, index_no, title, description, s.refresh(req)
function_name, priority, status, is_custom, created_at, updated_at) return req.id
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
with self._get_conn() as conn:
cur = conn.execute(sql, (
req.project_id, req.raw_req_id, req.index_no,
req.title, req.description, req.function_name,
req.priority, req.status, int(req.is_custom),
req.created_at, req.updated_at
))
return cur.lastrowid
def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]: def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]:
"""根据 ID 查询功能需求""" with self._session() as s:
with self._get_conn() as conn: return s.get(FunctionalRequirement, req_id)
row = conn.execute(
"SELECT * FROM functional_requirements WHERE id = ?", (req_id,)
).fetchone()
if row is None:
return None
return self._row_to_func_req(row)
def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]: def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]:
"""查询项目下所有功能需求(按序号排序)""" with self._session() as s:
with self._get_conn() as conn: return (
rows = conn.execute( s.query(FunctionalRequirement)
"""SELECT * FROM functional_requirements .filter_by(project_id=project_id)
WHERE project_id = ? ORDER BY index_no""", .order_by(FunctionalRequirement.index_no)
(project_id,) .all()
).fetchall() )
return [self._row_to_func_req(r) for r in rows]
def update_functional_requirement(self, req: FunctionalRequirement) -> None: def update_functional_requirement(self, req: FunctionalRequirement) -> None:
"""更新功能需求""" with self._session() as s:
req.updated_at = datetime.now().isoformat() s.merge(req)
sql = """ s.commit()
UPDATE functional_requirements
SET title=?, description=?, function_name=?, priority=?,
status=?, index_no=?, updated_at=?
WHERE id=?
"""
with self._get_conn() as conn:
conn.execute(sql, (
req.title, req.description, req.function_name,
req.priority, req.status, req.index_no,
req.updated_at, req.id
))
def delete_functional_requirement(self, req_id: int) -> None: def delete_functional_requirement(self, req_id: int) -> None:
"""删除功能需求""" with self._session() as s:
with self._get_conn() as conn: obj = s.get(FunctionalRequirement, req_id)
conn.execute( if obj:
"DELETE FROM functional_requirements WHERE id = ?", (req_id,) s.delete(obj)
) s.commit()
def _row_to_func_req(self, row) -> FunctionalRequirement: def bulk_update_modules(self, updates: List[dict]) -> None:
"""sqlite Row → FunctionalRequirement 对象"""
return FunctionalRequirement(
id=row["id"], project_id=row["project_id"],
raw_req_id=row["raw_req_id"], index_no=row["index_no"],
title=row["title"], description=row["description"],
function_name=row["function_name"], priority=row["priority"],
status=row["status"], is_custom=bool(row["is_custom"]),
created_at=row["created_at"], updated_at=row["updated_at"]
)
# ══════════════════════════════════════════════════
# CodeFile CRUD
# ══════════════════════════════════════════════════
def create_code_file(self, code_file: CodeFile) -> int:
"""创建代码文件记录,返回新记录 ID"""
sql = """
INSERT INTO code_files
(project_id, func_req_id, file_name, file_path,
language, content, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""" """
with self._get_conn() as conn: 批量更新功能需求的 module 字段
cur = conn.execute(sql, (
code_file.project_id, code_file.func_req_id, Args:
code_file.file_name, code_file.file_path, updates: [{"function_name": "...", "module": "..."}, ...]
code_file.language, code_file.content, """
code_file.created_at, code_file.updated_at with self._session() as s:
)) name_to_module = {u["function_name"]: u["module"] for u in updates}
return cur.lastrowid reqs = s.query(FunctionalRequirement).filter(
FunctionalRequirement.function_name.in_(name_to_module.keys())
).all()
for req in reqs:
req.module = name_to_module.get(req.function_name, config.DEFAULT_MODULE)
s.commit()
# ══════════════════════════════════════════════════
# CodeFile
# ══════════════════════════════════════════════════
def upsert_code_file(self, code_file: CodeFile) -> int: def upsert_code_file(self, code_file: CodeFile) -> int:
"""插入或更新代码文件(按 func_req_id 唯一键)""" with self._session() as s:
existing = self.get_code_file_by_func_req(code_file.func_req_id) existing = (
if existing: s.query(CodeFile)
code_file.id = existing.id .filter_by(func_req_id=code_file.func_req_id)
code_file.updated_at = datetime.now().isoformat() .first()
sql = """ )
UPDATE code_files if existing:
SET file_name=?, file_path=?, language=?, content=?, updated_at=? existing.file_name = code_file.file_name
WHERE id=? existing.file_path = code_file.file_path
""" existing.module = code_file.module
with self._get_conn() as conn: existing.language = code_file.language
conn.execute(sql, ( existing.content = code_file.content
code_file.file_name, code_file.file_path, s.commit()
code_file.language, code_file.content, return existing.id
code_file.updated_at, code_file.id else:
)) s.add(code_file)
return code_file.id s.commit()
else: s.refresh(code_file)
return self.create_code_file(code_file) return code_file.id
def get_code_file_by_func_req(self, func_req_id: int) -> Optional[CodeFile]: def list_code_files(self, project_id: int) -> List[CodeFile]:
"""根据功能需求 ID 查询代码文件""" with self._session() as s:
with self._get_conn() as conn: return (
row = conn.execute( s.query(CodeFile)
"SELECT * FROM code_files WHERE func_req_id = ?", (func_req_id,) .join(FunctionalRequirement)
).fetchone() .filter(FunctionalRequirement.project_id == project_id)
if row is None: .all()
return None
return self._row_to_code_file(row)
def list_code_files_by_project(self, project_id: int) -> List[CodeFile]:
"""查询项目下所有代码文件"""
with self._get_conn() as conn:
rows = conn.execute(
"SELECT * FROM code_files WHERE project_id = ? ORDER BY created_at",
(project_id,)
).fetchall()
return [self._row_to_code_file(r) for r in rows]
def _row_to_code_file(self, row) -> CodeFile:
"""sqlite Row → CodeFile 对象"""
return CodeFile(
id=row["id"], project_id=row["project_id"],
func_req_id=row["func_req_id"], file_name=row["file_name"],
file_path=row["file_path"], language=row["language"],
content=row["content"], created_at=row["created_at"],
updated_at=row["updated_at"]
) )

View File

@ -1,122 +1,95 @@
# database/models.py - 数据模型定义SQLite 建表 DDL # database/models.py - SQLAlchemy ORM 数据模型
from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from typing import Optional from sqlalchemy import (
Column, Integer, String, Text, Boolean,
DateTime, ForeignKey, Enum,
)
from sqlalchemy.orm import declarative_base, relationship
# ══════════════════════════════════════════════════════ Base = declarative_base()
# DDL 建表语句
# ══════════════════════════════════════════════════════
CREATE_TABLES_SQL = """
PRAGMA foreign_keys = ON;
-- 项目表
CREATE TABLE IF NOT EXISTS projects (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE, -- 项目名称
description TEXT, -- 项目描述
language TEXT NOT NULL DEFAULT 'python', -- 目标代码语言
output_dir TEXT NOT NULL, -- 输出目录路径
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
-- 原始需求表
CREATE TABLE IF NOT EXISTS raw_requirements (
id INTEGER PRIMARY KEY AUTOINCREMENT,
project_id INTEGER NOT NULL, -- 关联项目
content TEXT NOT NULL, -- 需求原文
source_type TEXT NOT NULL DEFAULT 'text', -- text | file
source_name TEXT, -- 文件名文件输入时
knowledge TEXT, -- 合并后的知识库内容
created_at TEXT NOT NULL,
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
);
-- 功能需求表
CREATE TABLE IF NOT EXISTS functional_requirements (
id INTEGER PRIMARY KEY AUTOINCREMENT,
project_id INTEGER NOT NULL,
raw_req_id INTEGER NOT NULL, -- 关联原始需求
index_no INTEGER NOT NULL, -- 序号
title TEXT NOT NULL, -- 功能标题
description TEXT NOT NULL, -- 功能描述
function_name TEXT NOT NULL, -- 对应函数名
priority TEXT NOT NULL DEFAULT 'medium', -- high|medium|low
status TEXT NOT NULL DEFAULT 'pending', -- pending|generated|skipped
is_custom INTEGER NOT NULL DEFAULT 0, -- 是否用户自定义 (0/1)
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
FOREIGN KEY (raw_req_id) REFERENCES raw_requirements(id) ON DELETE CASCADE
);
-- 代码文件表
CREATE TABLE IF NOT EXISTS code_files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
project_id INTEGER NOT NULL,
func_req_id INTEGER NOT NULL UNIQUE, -- 关联功能需求1对1
file_name TEXT NOT NULL, -- 文件名
file_path TEXT NOT NULL, -- 完整路径
language TEXT NOT NULL, -- 代码语言
content TEXT NOT NULL, -- 代码内容
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
FOREIGN KEY (func_req_id) REFERENCES functional_requirements(id) ON DELETE CASCADE
);
"""
# ══════════════════════════════════════════════════════
# 数据类Python 对象映射)
# ══════════════════════════════════════════════════════
@dataclass
class Project:
name: str
output_dir: str
language: str = "python"
description: str = ""
id: Optional[int] = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
@dataclass class Project(Base):
class RawRequirement: """项目表"""
project_id: int __tablename__ = "projects"
content: str
source_type: str = "text" # text | file id = Column(Integer, primary_key=True, autoincrement=True)
source_name: Optional[str] = None name = Column(String(200), nullable=False, unique=True)
knowledge: Optional[str] = None language = Column(String(50), nullable=False, default="python")
id: Optional[int] = None description = Column(Text, nullable=True)
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) output_dir = Column(String(500), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
raw_requirements = relationship("RawRequirement", back_populates="project", cascade="all, delete-orphan")
functional_requirements = relationship("FunctionalRequirement", back_populates="project", cascade="all, delete-orphan")
def __repr__(self):
return f"<Project(id={self.id}, name={self.name!r}, language={self.language!r})>"
@dataclass class RawRequirement(Base):
class FunctionalRequirement: """原始需求表"""
project_id: int __tablename__ = "raw_requirements"
raw_req_id: int
index_no: int id = Column(Integer, primary_key=True, autoincrement=True)
title: str project_id = Column(Integer, ForeignKey("projects.id"), nullable=False)
description: str content = Column(Text, nullable=False)
function_name: str source_type = Column(String(20), nullable=False, default="text") # text / file
priority: str = "medium" source_name = Column(String(200), nullable=True)
status: str = "pending" knowledge = Column(Text, nullable=True)
is_custom: bool = False created_at = Column(DateTime, default=datetime.utcnow)
id: Optional[int] = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) project = relationship("Project", back_populates="raw_requirements")
updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) functional_requirements = relationship("FunctionalRequirement", back_populates="raw_requirement", cascade="all, delete-orphan")
def __repr__(self):
return f"<RawRequirement(id={self.id}, project_id={self.project_id})>"
@dataclass class FunctionalRequirement(Base):
class CodeFile: """功能需求表"""
project_id: int __tablename__ = "functional_requirements"
func_req_id: int
file_name: str id = Column(Integer, primary_key=True, autoincrement=True)
file_path: str project_id = Column(Integer, ForeignKey("projects.id"), nullable=False)
language: str raw_req_id = Column(Integer, ForeignKey("raw_requirements.id"), nullable=False)
content: str index_no = Column(Integer, nullable=False)
id: Optional[int] = None title = Column(String(200), nullable=False)
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) description = Column(Text, nullable=False)
updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) function_name = Column(String(200), nullable=False)
priority = Column(Enum("high", "medium", "low"), default="medium")
module = Column(String(100), nullable=True, default="default") # 功能模块
status = Column(String(50), nullable=False, default="pending") # pending / generated / failed
is_custom = Column(Boolean, nullable=False, default=False)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
project = relationship("Project", back_populates="functional_requirements")
raw_requirement = relationship("RawRequirement", back_populates="functional_requirements")
code_files = relationship("CodeFile", back_populates="functional_requirement", cascade="all, delete-orphan")
def __repr__(self):
return (
f"<FunctionalRequirement(id={self.id}, title={self.title!r}, "
f"module={self.module!r}, status={self.status!r})>"
)
class CodeFile(Base):
"""生成代码文件表"""
__tablename__ = "code_files"
id = Column(Integer, primary_key=True, autoincrement=True)
func_req_id = Column(Integer, ForeignKey("functional_requirements.id"), nullable=False)
file_name = Column(String(200), nullable=False)
file_path = Column(String(500), nullable=False)
module = Column(String(100), nullable=True) # 冗余存储,方便查询
language = Column(String(50), nullable=False)
content = Column(Text, nullable=True)
generated_at = Column(DateTime, default=datetime.utcnow)
functional_requirement = relationship("FunctionalRequirement", back_populates="code_files")
def __repr__(self):
return f"<CodeFile(id={self.id}, file_name={self.file_name!r}, module={self.module!r})>"

View File

@ -1,17 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# encoding: utf-8 # main.py - 主入口:支持交互式 & 非交互式两种运行模式
# main.py - 主入口:支持交互式 & 非交互式CLI 参数)两种运行模式
#
# 交互式: python main.py
# 非交互式python main.py --non-interactive \
# --project-name "MyProject" \
# --language python \
# --requirement-text "用户管理系统,包含注册、登录、修改密码功能"
#
# 完整参数见python main.py --help
import os import os
import sys import sys
from typing import Dict from typing import Dict, List
import click import click
from rich.console import Console from rich.console import Console
@ -27,9 +18,9 @@ from core.requirement_analyzer import RequirementAnalyzer
from core.code_generator import CodeGenerator from core.code_generator import CodeGenerator
from utils.file_handler import read_file_auto, merge_knowledge_files from utils.file_handler import read_file_auto, merge_knowledge_files
from utils.output_writer import ( from utils.output_writer import (
ensure_project_dir, build_project_output_dir, write_project_readme, ensure_project_dir, build_project_output_dir,
write_function_signatures_json, validate_all_signatures, write_project_readme, write_function_signatures_json,
patch_signatures_with_url, validate_all_signatures, patch_signatures_with_url,
) )
console = Console() console = Console()
@ -44,20 +35,20 @@ def print_banner():
console.print(Panel.fit( console.print(Panel.fit(
"[bold cyan]🚀 需求分析 & 代码生成工具[/bold cyan]\n" "[bold cyan]🚀 需求分析 & 代码生成工具[/bold cyan]\n"
"[dim]Powered by LLM · SQLite · Python[/dim]", "[dim]Powered by LLM · SQLite · Python[/dim]",
border_style="cyan" border_style="cyan",
)) ))
def print_functional_requirements(reqs: list): def print_functional_requirements(reqs: List[FunctionalRequirement]):
"""以表格形式展示功能需求列表""" """以表格形式展示功能需求列表(含模块列)"""
table = Table(title="📋 功能需求列表", show_lines=True) table = Table(title="📋 功能需求列表", show_lines=True)
table.add_column("序号", style="cyan", width=6) table.add_column("序号", style="cyan", width=6)
table.add_column("ID", style="dim", width=6) table.add_column("ID", style="dim", width=6)
table.add_column("标题", style="bold", width=20) table.add_column("模块", style="magenta", width=15)
table.add_column("标题", style="bold", width=18)
table.add_column("函数名", width=25) table.add_column("函数名", width=25)
table.add_column("优先级", width=8) table.add_column("优先级", width=8)
table.add_column("类型", width=8) table.add_column("描述", width=35)
table.add_column("描述", width=40)
priority_color = {"high": "red", "medium": "yellow", "low": "green"} priority_color = {"high": "red", "medium": "yellow", "low": "green"}
for req in reqs: for req in reqs:
@ -65,53 +56,51 @@ def print_functional_requirements(reqs: list):
table.add_row( table.add_row(
str(req.index_no), str(req.index_no),
str(req.id) if req.id else "-", str(req.id) if req.id else "-",
req.module or config.DEFAULT_MODULE,
req.title, req.title,
f"[code]{req.function_name}[/code]", f"[code]{req.function_name}[/code]",
f"[{color}]{req.priority}[/{color}]", f"[{color}]{req.priority}[/{color}]",
"[magenta]自定义[/magenta]" if req.is_custom else "LLM生成", req.description[:50] + "..." if len(req.description) > 50 else req.description,
req.description[:60] + "..." if len(req.description) > 60 else req.description,
) )
console.print(table) console.print(table)
def print_signatures_preview(signatures: list): def print_module_summary(reqs: List[FunctionalRequirement]):
""" """打印模块分组摘要"""
以表格形式预览函数签名列表 url 字段 module_map: Dict[str, List[str]] = {}
for req in reqs:
m = req.module or config.DEFAULT_MODULE
module_map.setdefault(m, []).append(req.function_name)
Args: table = Table(title="📦 功能模块分组", show_lines=True)
signatures: 纯签名列表顶层文档的 "functions" 字段 table.add_column("模块", style="magenta bold", width=20)
""" table.add_column("函数数量", style="cyan", width=8)
table.add_column("函数列表", width=50)
for module, funcs in sorted(module_map.items()):
table.add_row(module, str(len(funcs)), ", ".join(funcs))
console.print(table)
def print_signatures_preview(signatures: List[dict]):
"""以表格形式预览函数签名列表(含 module / url 列)"""
table = Table(title="📄 函数签名预览", show_lines=True) table = Table(title="📄 函数签名预览", show_lines=True)
table.add_column("需求编号", style="cyan", width=10) table.add_column("需求编号", style="cyan", width=8)
table.add_column("函数名", style="bold", width=25) table.add_column("模块", style="magenta", width=15)
table.add_column("参数数量", width=8) table.add_column("函数名", style="bold", width=22)
table.add_column("参数数", width=6)
table.add_column("返回类型", width=10) table.add_column("返回类型", width=10)
table.add_column("成功返回值", width=18) table.add_column("URL", style="dim", width=28)
table.add_column("失败返回值", width=18)
table.add_column("URL", style="dim", width=30)
def _fmt_value(v) -> str:
if v is None:
return "-"
if isinstance(v, dict):
return "{" + ", ".join(v.keys()) + "}"
return str(v)[:16]
for sig in signatures: for sig in signatures:
ret = sig.get("return") or {} ret = sig.get("return") or {}
on_success = ret.get("on_success") or {}
on_failure = ret.get("on_failure") or {}
url = sig.get("url", "") url = sig.get("url", "")
# 只显示文件名部分,避免路径过长
url_display = os.path.basename(url) if url else "[dim]待生成[/dim]" url_display = os.path.basename(url) if url else "[dim]待生成[/dim]"
table.add_row( table.add_row(
sig.get("requirement_id", "-"), sig.get("requirement_id", "-"),
sig.get("module", "-"),
sig.get("name", "-"), sig.get("name", "-"),
str(len(sig.get("parameters", {}))), str(len(sig.get("parameters", {}))),
ret.get("type", "void"), ret.get("type", "void"),
_fmt_value(on_success.get("value")),
_fmt_value(on_failure.get("value")),
url_display, url_display,
) )
console.print(table) console.print(table)
@ -127,39 +116,39 @@ def step_init_project(
description: str = "", description: str = "",
non_interactive: bool = False, non_interactive: bool = False,
) -> Project: ) -> Project:
console.print(
"\n[bold]Step 1 · 项目配置[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue",
)
if not non_interactive: if not non_interactive:
console.print("\n[bold]Step 1 · 项目配置[/bold]", style="blue") project_name = project_name or Prompt.ask("📁 项目名称")
project_name = project_name or Prompt.ask("📁 请输入项目名称")
language = language or Prompt.ask( language = language or Prompt.ask(
"💻 目标代码语言", "💻 目标语言", default=config.DEFAULT_LANGUAGE,
default=config.DEFAULT_LANGUAGE, choices=["python","javascript","typescript","java","go","rust"],
choices=["python", "javascript", "typescript", "java", "go", "rust"],
) )
description = description or Prompt.ask("📝 项目描述(可选)", default="") description = description or Prompt.ask("📝 项目描述(可选)", default="")
else: else:
if not project_name: if not project_name:
raise ValueError("非交互模式下 --project-name 为必填项") raise ValueError("非交互模式下 --project-name 为必填项")
language = language or config.DEFAULT_LANGUAGE language = language or config.DEFAULT_LANGUAGE
console.print("\n[bold]Step 1 · 项目配置[/bold] [dim](非交互)[/dim]", style="blue") console.print(f" 项目: {project_name} 语言: {language}")
console.print(f" 项目名称: {project_name} 语言: {language}")
existing = db.get_project_by_name(project_name) existing = db.get_project_by_name(project_name)
if existing: if existing:
if non_interactive: if non_interactive:
console.print(f"[green]✓ 已加载已有项目: {project_name} (ID={existing.id})[/green]") console.print(f"[green]✓ 已加载项目: {project_name} (ID={existing.id})[/green]")
return existing return existing
use_existing = Confirm.ask(f"⚠️ 项目 '{project_name}' 已存在,是否继续使用?") if Confirm.ask(f"⚠️ 项目 '{project_name}' 已存在,继续使用?"):
if use_existing:
console.print(f"[green]✓ 已加载项目: {project_name} (ID={existing.id})[/green]") console.print(f"[green]✓ 已加载项目: {project_name} (ID={existing.id})[/green]")
return existing return existing
project_name = Prompt.ask("请输入新的项目名称") project_name = Prompt.ask("请输入新的项目名称")
output_dir = build_project_output_dir(project_name)
project = Project( project = Project(
name=project_name, name = project_name,
language=language, language = language,
output_dir=output_dir, output_dir = build_project_output_dir(project_name),
description=description, description = description,
) )
project.id = db.create_project(project) project.id = db.create_project(project)
console.print(f"[green]✓ 项目已创建: {project_name} (ID={project.id})[/green]") console.print(f"[green]✓ 项目已创建: {project_name} (ID={project.id})[/green]")
@ -178,11 +167,10 @@ def step_input_requirement(
non_interactive: bool = False, non_interactive: bool = False,
) -> tuple: ) -> tuple:
console.print( console.print(
f"\n[bold]Step 2 · 输入原始需求[/bold]" "\n[bold]Step 2 · 输入原始需求[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""), + (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue", style="blue",
) )
raw_text = "" raw_text = ""
source_name = None source_name = None
source_type = "text" source_type = "text"
@ -195,12 +183,11 @@ def step_input_requirement(
console.print(f" 需求文件: {source_name} ({len(raw_text)} 字符)") console.print(f" 需求文件: {source_name} ({len(raw_text)} 字符)")
elif requirement_text: elif requirement_text:
raw_text = requirement_text raw_text = requirement_text
source_type = "text" console.print(f" 需求文本: {raw_text[:80]}{'...' if len(raw_text)>80 else ''}")
console.print(f" 需求文本: {raw_text[:80]}{'...' if len(raw_text) > 80 else ''}")
else: else:
raise ValueError("非交互模式下必须提供 --requirement-text 或 --requirement-file") raise ValueError("非交互模式下必须提供 --requirement-text 或 --requirement-file")
else: else:
input_type = Prompt.ask("📥 需求输入方式", choices=["text", "file"], default="text") input_type = Prompt.ask("📥 需求输入方式", choices=["text","file"], default="text")
if input_type == "text": if input_type == "text":
console.print("[dim]请输入原始需求(输入空行结束):[/dim]") console.print("[dim]请输入原始需求(输入空行结束):[/dim]")
lines = [] lines = []
@ -210,13 +197,12 @@ def step_input_requirement(
break break
lines.append(line) lines.append(line)
raw_text = "\n".join(lines) raw_text = "\n".join(lines)
source_type = "text"
else: else:
file_path = Prompt.ask("📂 需求文件路径") fp = Prompt.ask("📂 需求文件路径")
raw_text = read_file_auto(file_path) raw_text = read_file_auto(fp)
source_name = os.path.basename(file_path) source_name = os.path.basename(fp)
source_type = "file" source_type = "file"
console.print(f"[green]✓ 已读取文件: {source_name} ({len(raw_text)} 字符)[/green]") console.print(f"[green]✓ 已读取: {source_name} ({len(raw_text)} 字符)[/green]")
knowledge_text = "" knowledge_text = ""
if non_interactive: if non_interactive:
@ -224,18 +210,17 @@ def step_input_requirement(
knowledge_text = merge_knowledge_files(list(knowledge_files)) knowledge_text = merge_knowledge_files(list(knowledge_files))
console.print(f" 知识库: {len(knowledge_files)} 个文件,{len(knowledge_text)} 字符") console.print(f" 知识库: {len(knowledge_files)} 个文件,{len(knowledge_text)} 字符")
else: else:
use_kb = Confirm.ask("📚 是否输入知识库文件?", default=False) if Confirm.ask("📚 是否输入知识库文件?", default=False):
if use_kb:
kb_paths = [] kb_paths = []
while True: while True:
kb_path = Prompt.ask("知识库文件路径(留空结束)", default="") p = Prompt.ask("知识库文件路径(留空结束)", default="")
if not kb_path: if not p:
break break
if os.path.exists(kb_path): if os.path.exists(p):
kb_paths.append(kb_path) kb_paths.append(p)
console.print(f" [green]+ {kb_path}[/green]") console.print(f" [green]+ {p}[/green]")
else: else:
console.print(f" [red]文件不存在: {kb_path}[/red]") console.print(f" [red]文件不存在: {p}[/red]")
if kb_paths: if kb_paths:
knowledge_text = merge_knowledge_files(kb_paths) knowledge_text = merge_knowledge_files(kb_paths)
console.print(f"[green]✓ 知识库已合并 ({len(knowledge_text)} 字符)[/green]") console.print(f"[green]✓ 知识库已合并 ({len(knowledge_text)} 字符)[/green]")
@ -256,29 +241,28 @@ def step_decompose_requirements(
non_interactive: bool = False, non_interactive: bool = False,
) -> tuple: ) -> tuple:
console.print( console.print(
f"\n[bold]Step 3 · LLM 需求分解[/bold]" "\n[bold]Step 3 · LLM 需求分解[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""), + (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue", style="blue",
) )
raw_req = RawRequirement( raw_req = RawRequirement(
project_id=project.id, project_id = project.id,
content=raw_text, content = raw_text,
source_type=source_type, source_type = source_type,
source_name=source_name, source_name = source_name,
knowledge=knowledge_text or None, knowledge = knowledge_text or None,
) )
raw_req_id = db.create_raw_requirement(raw_req) raw_req_id = db.create_raw_requirement(raw_req)
console.print(f"[dim]原始需求已存储 (ID={raw_req_id})[/dim]") console.print(f"[dim]原始需求已存储 (ID={raw_req_id})[/dim]")
with console.status("[bold yellow]🤖 LLM 正在分解需求,请稍候...[/bold yellow]"): with console.status("[bold yellow]🤖 LLM 正在分解需求...[/bold yellow]"):
llm = LLMClient() llm = LLMClient()
analyzer = RequirementAnalyzer(llm) analyzer = RequirementAnalyzer(llm)
func_reqs = analyzer.decompose( func_reqs = analyzer.decompose(
raw_requirement=raw_text, raw_requirement = raw_text,
project_id=project.id, project_id = project.id,
raw_req_id=raw_req_id, raw_req_id = raw_req_id,
knowledge=knowledge_text, knowledge = knowledge_text,
) )
for req in func_reqs: for req in func_reqs:
@ -289,18 +273,107 @@ def step_decompose_requirements(
# ══════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════
# Step 4用户编辑功能需求 # Step 4模块分类可选重新分类
# ══════════════════════════════════════════════════════
def step_classify_modules(
project: Project,
func_reqs: List[FunctionalRequirement],
knowledge_text: str = "",
non_interactive: bool = False,
) -> List[FunctionalRequirement]:
"""
Step 4对功能需求进行模块分类
- 非交互模式直接使用 LLM 分类结果
- 交互模式展示 LLM 分类结果允许用户手动调整
"""
console.print(
"\n[bold]Step 4 · 功能模块分类[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue",
)
# LLM 自动分类
with console.status("[bold yellow]🤖 LLM 正在进行模块分类...[/bold yellow]"):
llm = LLMClient()
analyzer = RequirementAnalyzer(llm)
try:
updates = analyzer.classify_modules(func_reqs, knowledge_text)
# 回写 module 到 func_reqs 对象
name_to_module = {u["function_name"]: u["module"] for u in updates}
for req in func_reqs:
req.module = name_to_module.get(req.function_name, config.DEFAULT_MODULE)
db.update_functional_requirement(req)
console.print(f"[green]✓ LLM 模块分类完成[/green]")
except Exception as e:
console.print(f"[yellow]⚠ 模块分类失败,保留原有模块: {e}[/yellow]")
print_module_summary(func_reqs)
if non_interactive:
return func_reqs
# 交互式调整
while True:
console.print(
"\n模块操作: [cyan]r[/cyan]=重新分类 "
"[cyan]e[/cyan]=手动编辑某需求的模块 [cyan]ok[/cyan]=确认继续"
)
action = Prompt.ask("请选择操作", default="ok").strip().lower()
if action == "ok":
break
elif action == "r":
# 重新触发 LLM 分类
with console.status("[bold yellow]🤖 重新分类中...[/bold yellow]"):
try:
updates = analyzer.classify_modules(func_reqs, knowledge_text)
name_to_module = {u["function_name"]: u["module"] for u in updates}
for req in func_reqs:
req.module = name_to_module.get(req.function_name, config.DEFAULT_MODULE)
db.update_functional_requirement(req)
console.print("[green]✓ 重新分类完成[/green]")
except Exception as e:
console.print(f"[red]重新分类失败: {e}[/red]")
print_module_summary(func_reqs)
elif action == "e":
print_functional_requirements(func_reqs)
idx_str = Prompt.ask("输入要修改模块的需求序号")
if not idx_str.isdigit():
continue
idx = int(idx_str)
target = next((r for r in func_reqs if r.index_no == idx), None)
if target is None:
console.print("[red]序号不存在[/red]")
continue
new_module = Prompt.ask(
f"新模块名(当前: {target.module}",
default=target.module or config.DEFAULT_MODULE,
)
target.module = new_module.strip() or config.DEFAULT_MODULE
db.update_functional_requirement(target)
console.print(f"[green]✓ 已更新 '{target.function_name}' → 模块: {target.module}[/green]")
print_module_summary(func_reqs)
return func_reqs
# ══════════════════════════════════════════════════════
# Step 5编辑功能需求
# ══════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════
def step_edit_requirements( def step_edit_requirements(
project: Project, project: Project,
func_reqs: list, func_reqs: List[FunctionalRequirement],
raw_req_id: int, raw_req_id: int,
non_interactive: bool = False, non_interactive: bool = False,
skip_indices: list = None, skip_indices: list = None,
) -> list: ) -> List[FunctionalRequirement]:
console.print( console.print(
f"\n[bold]Step 4 · 编辑功能需求[/bold]" "\n[bold]Step 5 · 编辑功能需求[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""), + (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue", style="blue",
) )
@ -334,8 +407,9 @@ def step_edit_requirements(
if action == "ok": if action == "ok":
break break
elif action == "d": elif action == "d":
idx_str = Prompt.ask("输入要删除的功能需求序号(多个用逗号分隔)") idx_str = Prompt.ask("输入要删除的序号(多个用逗号分隔)")
to_delete = {int(x.strip()) for x in idx_str.split(",") if x.strip().isdigit()} to_delete = {int(x.strip()) for x in idx_str.split(",") if x.strip().isdigit()}
removed, kept = [], [] removed, kept = [], []
for req in func_reqs: for req in func_reqs:
@ -349,28 +423,35 @@ def step_edit_requirements(
req.index_no = i req.index_no = i
db.update_functional_requirement(req) db.update_functional_requirement(req)
console.print(f"[red]✗ 已删除: {', '.join(removed)}[/red]") console.print(f"[red]✗ 已删除: {', '.join(removed)}[/red]")
elif action == "a": elif action == "a":
title = Prompt.ask("功能标题") title = Prompt.ask("功能标题")
description = Prompt.ask("功能描述") description = Prompt.ask("功能描述")
func_name = Prompt.ask("函数名 (snake_case)") func_name = Prompt.ask("函数名 (snake_case)")
priority = Prompt.ask( priority = Prompt.ask(
"优先级", choices=["high", "medium", "low"], default="medium" "优先级", choices=["high","medium","low"], default="medium"
)
module = Prompt.ask(
"所属模块snake_case留空使用默认",
default=config.DEFAULT_MODULE,
) )
new_req = FunctionalRequirement( new_req = FunctionalRequirement(
project_id=project.id, project_id = project.id,
raw_req_id=raw_req_id, raw_req_id = raw_req_id,
index_no=len(func_reqs) + 1, index_no = len(func_reqs) + 1,
title=title, title = title,
description=description, description = description,
function_name=func_name, function_name = func_name,
priority=priority, priority = priority,
is_custom=True, module = module.strip() or config.DEFAULT_MODULE,
is_custom = True,
) )
new_req.id = db.create_functional_requirement(new_req) new_req.id = db.create_functional_requirement(new_req)
func_reqs.append(new_req) func_reqs.append(new_req)
console.print(f"[green]✓ 已添加自定义需求: {title}[/green]") console.print(f"[green]✓ 已添加: {title} → 模块: {new_req.module}[/green]")
elif action == "e": elif action == "e":
idx_str = Prompt.ask("输入要编辑的功能需求序号") idx_str = Prompt.ask("输入要编辑的序号")
if not idx_str.isdigit(): if not idx_str.isdigit():
continue continue
idx = int(idx_str) idx = int(idx_str)
@ -382,8 +463,11 @@ def step_edit_requirements(
target.description = Prompt.ask("新描述", default=target.description) target.description = Prompt.ask("新描述", default=target.description)
target.function_name = Prompt.ask("新函数名", default=target.function_name) target.function_name = Prompt.ask("新函数名", default=target.function_name)
target.priority = Prompt.ask( target.priority = Prompt.ask(
"新优先级", choices=["high", "medium", "low"], default=target.priority "新优先级", choices=["high","medium","low"], default=target.priority
) )
target.module = Prompt.ask(
"新模块", default=target.module or config.DEFAULT_MODULE
).strip() or config.DEFAULT_MODULE
db.update_functional_requirement(target) db.update_functional_requirement(target)
console.print(f"[green]✓ 已更新: {target.title}[/green]") console.print(f"[green]✓ 已更新: {target.title}[/green]")
@ -391,33 +475,24 @@ def step_edit_requirements(
# ══════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════
# Step 5A生成函数签名 JSON不含 url 字段,待 5C 回写 # Step 6A生成函数签名 JSON初版不含 url
# ══════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════
def step_generate_signatures( def step_generate_signatures(
project: Project, project: Project,
func_reqs: list, func_reqs: List[FunctionalRequirement],
output_dir: str, output_dir: str,
knowledge_text: str, knowledge_text: str,
json_file_name: str = "function_signatures.json", json_file_name: str = "function_signatures.json",
non_interactive: bool = False, non_interactive: bool = False,
) -> tuple: ) -> tuple:
"""
为所有功能需求生成函数签名写入初版 JSON不含 url 字段
url 字段将在 Step 5C 代码生成完成后回写并刷新 JSON 文件
Returns:
(signatures: List[dict], json_path: str)
"""
console.print( console.print(
f"\n[bold]Step 5A · 生成函数签名 JSON[/bold]" "\n[bold]Step 6A · 生成函数签名 JSON[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""), + (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue", style="blue",
) )
llm = LLMClient() llm = LLMClient()
analyzer = RequirementAnalyzer(llm) analyzer = RequirementAnalyzer(llm)
success_count = 0 success_count = 0
fail_count = 0 fail_count = 0
@ -425,42 +500,40 @@ def step_generate_signatures(
nonlocal success_count, fail_count nonlocal success_count, fail_count
if error: if error:
console.print( console.print(
f" [{index}/{total}] [yellow]⚠ {req.title} 签名生成失败" f" [{index}/{total}] [yellow]⚠ {req.title} 签名生成失败"
f"使用降级结构: {error}[/yellow]" f"(降级): {error}[/yellow]"
) )
fail_count += 1 fail_count += 1
else: else:
console.print( console.print(
f" [{index}/{total}] [green]✓ {req.title}[/green] " f" [{index}/{total}] [green]✓ {req.title}[/green] "
f"→ [dim]{signature.get('name')}()[/dim] " f"[dim]{req.module}[/dim] → {signature.get('name')}()"
f"params={len(signature.get('parameters', {}))}"
) )
success_count += 1 success_count += 1
console.print(f"[yellow]正在为 {len(func_reqs)} 个功能需求生成函数签名...[/yellow]") console.print(f"[yellow]正在为 {len(func_reqs)} 个功能需求生成函数签名...[/yellow]")
signatures = analyzer.build_function_signatures_batch( signatures = analyzer.build_function_signatures_batch(
func_reqs=func_reqs, func_reqs = func_reqs,
knowledge=knowledge_text, knowledge = knowledge_text,
on_progress=on_progress, on_progress = on_progress,
) )
# 校验 # 校验
validation_report = validate_all_signatures(signatures) report = validate_all_signatures(signatures)
if validation_report: if report:
console.print(f"[yellow]⚠ 发现 {len(validation_report)} 个签名存在结构问题:[/yellow]") console.print(f"[yellow]⚠ {len(report)} 个签名存在结构问题:[/yellow]")
for fname, errors in validation_report.items(): for fname, errs in report.items():
for err in errors: for err in errs:
console.print(f" [yellow]· {fname}: {err}[/yellow]") console.print(f" [yellow]· {fname}: {err}[/yellow]")
else: else:
console.print("[green]✓ 所有签名结构校验通过[/green]") console.print("[green]✓ 所有签名结构校验通过[/green]")
# 写入初版 JSONurl 字段尚未填入)
json_path = write_function_signatures_json( json_path = write_function_signatures_json(
output_dir=output_dir, output_dir = output_dir,
signatures=signatures, signatures = signatures,
project_name=project.name, project_name = project.name,
project_description=project.description or "", # ← 传入项目描述 project_description = project.description or "",
file_name=json_file_name, file_name = json_file_name,
) )
console.print( console.print(
f"[green]✓ 签名 JSON 初版已写入: [cyan]{os.path.abspath(json_path)}[/cyan][/green]\n" f"[green]✓ 签名 JSON 初版已写入: [cyan]{os.path.abspath(json_path)}[/cyan][/green]\n"
@ -470,35 +543,32 @@ def step_generate_signatures(
# ══════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════
# Step 5B生成代码文件收集 {函数名: 文件路径} 映射 # Step 6B生成代码文件按模块写入子目录
# ══════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════
def step_generate_code( def step_generate_code(
project: Project, project: Project,
func_reqs: list, func_reqs: List[FunctionalRequirement],
output_dir: str, output_dir: str,
knowledge_text: str, knowledge_text: str,
signatures: list, signatures: List[dict],
non_interactive: bool = False, non_interactive: bool = False,
) -> Dict[str, str]: ) -> Dict[str, str]:
""" """
依据签名约束批量生成代码文件 批量生成代码文件 req.module 路由到 output_dir/<module>/ 子目录
Returns: Returns:
func_name_to_url: {函数名: 代码文件绝对路径} 映射表 func_name_to_url: {函数名: 代码文件绝对路径}
Step 5C 回写 url 字段使用
生成失败的函数不会出现在映射表中
""" """
console.print( console.print(
f"\n[bold]Step 5B · 生成代码文件[/bold]" "\n[bold]Step 6B · 生成代码文件[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""), + (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue", style="blue",
) )
generator = CodeGenerator(LLMClient()) generator = CodeGenerator(LLMClient())
success_count = 0 success_count = 0
fail_count = 0 fail_count = 0
func_name_to_url: Dict[str, str] = {} # ← 收集 函数名 → 文件绝对路径 func_name_to_url: Dict[str, str] = {}
def on_progress(index, total, req, code_file, error): def on_progress(index, total, req, code_file, error):
nonlocal success_count, fail_count nonlocal success_count, fail_count
@ -509,29 +579,38 @@ def step_generate_code(
db.upsert_code_file(code_file) db.upsert_code_file(code_file)
req.status = "generated" req.status = "generated"
db.update_functional_requirement(req) db.update_functional_requirement(req)
# 收集 函数名 → 绝对文件路径(作为 url 回写)
func_name_to_url[req.function_name] = os.path.abspath(code_file.file_path) func_name_to_url[req.function_name] = os.path.abspath(code_file.file_path)
console.print( console.print(
f" [{index}/{total}] [green]✓ {req.title}[/green] " f" [{index}/{total}] [green]✓ {req.title}[/green] "
f"[dim]{code_file.file_name}[/dim]" f"[dim]{req.module}/{code_file.file_name}[/dim]"
) )
success_count += 1 success_count += 1
console.print(f"[yellow]开始生成 {len(func_reqs)} 个代码文件(签名约束模式)...[/yellow]") console.print(
f"[yellow]开始生成 {len(func_reqs)} 个代码文件(按模块分目录)...[/yellow]"
)
generator.generate_batch( generator.generate_batch(
func_reqs=func_reqs, func_reqs = func_reqs,
output_dir=output_dir, output_dir = output_dir,
language=project.language, language = project.language,
knowledge=knowledge_text, knowledge = knowledge_text,
signatures=signatures, signatures = signatures,
on_progress=on_progress, on_progress = on_progress,
) )
# 生成 README含模块列表
modules = list({req.module or config.DEFAULT_MODULE for req in func_reqs})
req_summary = "\n".join( req_summary = "\n".join(
f"{i+1}. **{r.title}** (`{r.function_name}`) - {r.description[:80]}" f"{i+1}. **{r.title}** (`{r.module}/{r.function_name}`) - {r.description[:80]}"
for i, r in enumerate(func_reqs) for i, r in enumerate(func_reqs)
) )
write_project_readme(output_dir, project.name, req_summary) write_project_readme(
output_dir = output_dir,
project_name = project.name,
project_description = project.description or "",
requirements_summary = req_summary,
modules = modules,
)
console.print(Panel( console.print(Panel(
f"[bold green]✅ 代码生成完成![/bold green]\n" f"[bold green]✅ 代码生成完成![/bold green]\n"
@ -543,65 +622,38 @@ def step_generate_code(
# ══════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════
# Step 5C回写 url 字段并刷新 JSON # Step 6C回写 url 字段并刷新 JSON
# ══════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════
def step_patch_signatures_url( def step_patch_signatures_url(
project: Project, project: Project,
signatures: list, signatures: List[dict],
func_name_to_url: Dict[str, str], func_name_to_url: Dict[str, str],
output_dir: str, output_dir: str,
json_file_name: str, json_file_name: str,
non_interactive: bool = False, non_interactive: bool = False,
) -> str: ) -> str:
"""
将代码文件路径回写到签名的 "url" 字段并重新写入 JSON 文件
执行流程
1. 调用 patch_signatures_with_url() 原地修改签名列表
2. 打印最终签名预览 url
3. 重新调用 write_function_signatures_json() 覆盖写入 JSON
Args:
project: 项目对象提供 name description
signatures: Step 5A 产出的签名列表将被原地修改
func_name_to_url: Step 5B 收集的 {函数名: 文件绝对路径} 映射
output_dir: JSON 文件所在目录
json_file_name: JSON 文件名
non_interactive: 是否非交互模式
Returns:
刷新后的 JSON 文件绝对路径
"""
console.print( console.print(
f"\n[bold]Step 5C · 回写代码文件路径url到签名 JSON[/bold]" "\n[bold]Step 6C · 回写代码路径url到签名 JSON[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""), + (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue", style="blue",
) )
# 原地回写 url 字段
patch_signatures_with_url(signatures, func_name_to_url) patch_signatures_with_url(signatures, func_name_to_url)
patched = sum(1 for s in signatures if s.get("url")) patched = sum(1 for s in signatures if s.get("url"))
unpatched = len(signatures) - patched unpatched = len(signatures) - patched
if unpatched: if unpatched:
console.print( console.print(f"[yellow]⚠ {unpatched} 个函数 url 未回写(代码生成失败)[/yellow]")
f"[yellow]⚠ {unpatched} 个函数未能写入 url"
f"(对应代码文件生成失败)[/yellow]"
)
# 打印最终预览(含 url 列)
print_signatures_preview(signatures) print_signatures_preview(signatures)
# 覆盖写入 JSON含 project.description
json_path = write_function_signatures_json( json_path = write_function_signatures_json(
output_dir=output_dir, output_dir = output_dir,
signatures=signatures, signatures = signatures,
project_name=project.name, project_name = project.name,
project_description=project.description or "", project_description = project.description or "",
file_name=json_file_name, file_name = json_file_name,
) )
console.print( console.print(
f"[green]✓ 签名 JSON 已更新(含 url: " f"[green]✓ 签名 JSON 已更新(含 url: "
f"[cyan]{os.path.abspath(json_path)}[/cyan][/green]\n" f"[cyan]{os.path.abspath(json_path)}[/cyan][/green]\n"
@ -625,43 +677,51 @@ def run_workflow(
json_file_name: str = "function_signatures.json", json_file_name: str = "function_signatures.json",
non_interactive: bool = False, non_interactive: bool = False,
): ):
"""完整工作流Step 1 → 5C""" """完整工作流 Step 1 → 6C"""
print_banner() print_banner()
# Step 1 # Step 1:项目初始化
project = step_init_project( project = step_init_project(
project_name=project_name, project_name = project_name,
language=language, language = language,
description=description, description = description,
non_interactive=non_interactive, non_interactive = non_interactive,
) )
# Step 2 # Step 2:输入原始需求
raw_text, knowledge_text, source_name, source_type = step_input_requirement( raw_text, knowledge_text, source_name, source_type = step_input_requirement(
project=project, project = project,
requirement_text=requirement_text, requirement_text = requirement_text,
requirement_file=requirement_file, requirement_file = requirement_file,
knowledge_files=list(knowledge_files) if knowledge_files else [], knowledge_files = list(knowledge_files) if knowledge_files else [],
non_interactive=non_interactive, non_interactive = non_interactive,
) )
# Step 3 # Step 3LLM 需求分解
raw_req_id, func_reqs = step_decompose_requirements( raw_req_id, func_reqs = step_decompose_requirements(
project=project, project = project,
raw_text=raw_text, raw_text = raw_text,
knowledge_text=knowledge_text, knowledge_text = knowledge_text,
source_name=source_name, source_name = source_name,
source_type=source_type, source_type = source_type,
non_interactive=non_interactive, non_interactive = non_interactive,
) )
# Step 4 # Step 4模块分类
func_reqs = step_classify_modules(
project = project,
func_reqs = func_reqs,
knowledge_text = knowledge_text,
non_interactive = non_interactive,
)
# Step 5编辑功能需求
func_reqs = step_edit_requirements( func_reqs = step_edit_requirements(
project=project, project = project,
func_reqs=func_reqs, func_reqs = func_reqs,
raw_req_id=raw_req_id, raw_req_id = raw_req_id,
non_interactive=non_interactive, non_interactive = non_interactive,
skip_indices=skip_indices or [], skip_indices = skip_indices or [],
) )
if not func_reqs: if not func_reqs:
@ -670,40 +730,43 @@ def run_workflow(
output_dir = ensure_project_dir(project.name) output_dir = ensure_project_dir(project.name)
# Step 5A生成签名初版不含 url # Step 6A生成函数签名
signatures, json_path = step_generate_signatures( signatures, json_path = step_generate_signatures(
project=project, project = project,
func_reqs=func_reqs, func_reqs = func_reqs,
output_dir=output_dir, output_dir = output_dir,
knowledge_text=knowledge_text, knowledge_text = knowledge_text,
json_file_name=json_file_name, json_file_name = json_file_name,
non_interactive=non_interactive, non_interactive = non_interactive,
) )
# Step 5B生成代码收集 {函数名: 文件路径} # Step 6B生成代码文件
func_name_to_url = step_generate_code( func_name_to_url = step_generate_code(
project=project, project = project,
func_reqs=func_reqs, func_reqs = func_reqs,
output_dir=output_dir, output_dir = output_dir,
knowledge_text=knowledge_text, knowledge_text = knowledge_text,
signatures=signatures, signatures = signatures,
non_interactive=non_interactive, non_interactive = non_interactive,
) )
# Step 5C回写 url 字段,刷新 JSON # Step 6C回写 url,刷新 JSON
json_path = step_patch_signatures_url( json_path = step_patch_signatures_url(
project=project, project = project,
signatures=signatures, signatures = signatures,
func_name_to_url=func_name_to_url, func_name_to_url = func_name_to_url,
output_dir=output_dir, output_dir = output_dir,
json_file_name=json_file_name, json_file_name = json_file_name,
non_interactive=non_interactive, non_interactive = non_interactive,
) )
# 最终汇总
modules = sorted({req.module or config.DEFAULT_MODULE for req in func_reqs})
console.print(Panel( console.print(Panel(
f"[bold cyan]🎉 全部流程完成![/bold cyan]\n" f"[bold cyan]🎉 全部流程完成![/bold cyan]\n"
f"项目: [bold]{project.name}[/bold]\n" f"项目: [bold]{project.name}[/bold]\n"
f"描述: {project.description or '(无)'}\n" f"描述: {project.description or '(无)'}\n"
f"模块: {', '.join(modules)}\n"
f"代码目录: [cyan]{os.path.abspath(output_dir)}[/cyan]\n" f"代码目录: [cyan]{os.path.abspath(output_dir)}[/cyan]\n"
f"签名文件: [cyan]{json_path}[/cyan]", f"签名文件: [cyan]{json_path}[/cyan]",
border_style="cyan", border_style="cyan",
@ -716,24 +779,21 @@ def run_workflow(
@click.command() @click.command()
@click.option("--non-interactive", is_flag=True, default=False, @click.option("--non-interactive", is_flag=True, default=False,
help="以非交互模式运行(所有参数通过命令行传入)") help="以非交互模式运行")
@click.option("--project-name", "-p", default=None, help="项目名称") @click.option("--project-name", "-p", default=None, help="项目名称")
@click.option("--language", "-l", default=None, @click.option("--language", "-l", default=None,
type=click.Choice(["python","javascript","typescript","java","go","rust"]), type=click.Choice(["python","javascript","typescript","java","go","rust"]),
help=f"目标代码语言(默认: {config.DEFAULT_LANGUAGE}") help=f"目标代码语言(默认: {config.DEFAULT_LANGUAGE}")
@click.option("--description", "-d", default="", help="项目描述") @click.option("--description", "-d", default="", help="项目描述")
@click.option("--requirement-text","-r", default=None, @click.option("--requirement-text","-r", default=None, help="原始需求文本")
help="原始需求文本(与 --requirement-file 二选一)")
@click.option("--requirement-file","-f", default=None, @click.option("--requirement-file","-f", default=None,
type=click.Path(exists=True), type=click.Path(exists=True), help="原始需求文件路径")
help="原始需求文件路径(支持 .txt/.md/.pdf/.docx")
@click.option("--knowledge-file", "-k", default=None, multiple=True, @click.option("--knowledge-file", "-k", default=None, multiple=True,
type=click.Path(exists=True), type=click.Path(exists=True), help="知识库文件(可多次指定)")
help="知识库文件路径(可多次指定,如 -k a.md -k b.pdf")
@click.option("--skip-index", "-s", default=None, multiple=True, type=int, @click.option("--skip-index", "-s", default=None, multiple=True, type=int,
help="跳过的功能需求序号(可多次指定,如 -s 2 -s 5") help="跳过的功能需求序号(可多次指定)")
@click.option("--json-file-name", "-j", default="function_signatures.json", @click.option("--json-file-name", "-j", default="function_signatures.json",
help="函数签名 JSON 文件名(默认: function_signatures.json") help="签名 JSON 文件名")
def cli( def cli(
non_interactive, project_name, language, description, non_interactive, project_name, language, description,
requirement_text, requirement_file, knowledge_file, requirement_text, requirement_file, knowledge_file,
@ -743,7 +803,7 @@ def cli(
需求分析 & 代码生成工具 需求分析 & 代码生成工具
\b \b
交互式运行推荐初次使用 交互式运行
python main.py python main.py
\b \b
@ -752,28 +812,19 @@ def cli(
--project-name "UserSystem" \\ --project-name "UserSystem" \\
--description "用户管理系统后端服务" \\ --description "用户管理系统后端服务" \\
--language python \\ --language python \\
--requirement-text "用户管理系统,包含注册、登录、修改密码功能" \\ --requirement-text "用户管理系统,包含注册、登录、修改密码功能"
--knowledge-file docs/api_spec.md \\
--json-file-name api_signatures.json
\b
从文件读取需求 + 跳过部分功能需求
python main.py --non-interactive \\
--project-name "MyProject" \\
--requirement-file requirements.md \\
--skip-index 3 --skip-index 7
""" """
try: try:
run_workflow( run_workflow(
project_name=project_name, project_name = project_name,
language=language, language = language,
description=description, description = description,
requirement_text=requirement_text, requirement_text = requirement_text,
requirement_file=requirement_file, requirement_file = requirement_file,
knowledge_files=knowledge_file, knowledge_files = knowledge_file,
skip_indices=list(skip_index) if skip_index else [], skip_indices = list(skip_index) if skip_index else [],
json_file_name=json_file_name, json_file_name = json_file_name,
non_interactive=non_interactive, non_interactive = non_interactive,
) )
except KeyboardInterrupt: except KeyboardInterrupt:
console.print("\n[yellow]用户中断,退出[/yellow]") console.print("\n[yellow]用户中断,退出[/yellow]")

View File

@ -1,6 +1,7 @@
openai>=1.0.0 openai>=1.30.0
python-dotenv>=1.0.0
rich>=13.0.0
python-docx>=0.8.11
PyPDF2>=3.0.0
click>=8.1.0 click>=8.1.0
rich>=13.0.0
sqlalchemy>=2.0.0
python-dotenv>=1.0.0
pypdf>=4.0.0
python-docx>=1.1.0

View File

@ -0,0 +1,11 @@
#!/usr/bin/env bash
# 安装依赖
pip install -r requirements.txt
# 设置环境变量
set OPENAI_API_KEY="sk-AUmOuFI731Ty5Nob38jY26d8lydfDT-QkE2giqb0sCuPCAE2JH6zjLM4lZLpvL5WMYPOocaMe2FwVDmqM_9KimmKACjR"
set OPENAI_BASE_URL="https://openapi.monica.im/v1" # 或其他兼容接口
set LLM_MODEL="gpt-4o"
python main.py

0
requirements_generator/run.sh Executable file → Normal file
View File

View File

@ -1,95 +1,13 @@
# utils/file_handler.py - 文件读取工具(支持 txt/md/pdf/docx # utils/file_handler.py - 文件读取工具(支持 txt / md / pdf / docx
import os import os
from typing import List, Optional from typing import List
from pathlib import Path
def read_text_file(file_path: str) -> str:
"""
读取纯文本文件内容.txt / .md / .py
Args:
file_path: 文件路径
Returns:
文件文本内容
Raises:
FileNotFoundError: 文件不存在
UnicodeDecodeError: 编码错误时尝试 latin-1 兜底
"""
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}")
try:
return path.read_text(encoding="utf-8")
except UnicodeDecodeError:
return path.read_text(encoding="latin-1")
def read_pdf_file(file_path: str) -> str:
"""
读取 PDF 文件内容
Args:
file_path: PDF 文件路径
Returns:
提取的文本内容
Raises:
ImportError: 未安装 PyPDF2
FileNotFoundError: 文件不存在
"""
try:
import PyPDF2
except ImportError:
raise ImportError("请安装 PyPDF2: pip install PyPDF2")
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}")
texts = []
with open(file_path, "rb") as f:
reader = PyPDF2.PdfReader(f)
for page in reader.pages:
text = page.extract_text()
if text:
texts.append(text)
return "\n".join(texts)
def read_docx_file(file_path: str) -> str:
"""
读取 Word (.docx) 文件内容
Args:
file_path: docx 文件路径
Returns:
提取的文本内容段落合并
Raises:
ImportError: 未安装 python-docx
FileNotFoundError: 文件不存在
"""
try:
from docx import Document
except ImportError:
raise ImportError("请安装 python-docx: pip install python-docx")
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}")
doc = Document(file_path)
return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
def read_file_auto(file_path: str) -> str: def read_file_auto(file_path: str) -> str:
""" """
根据文件扩展名自动选择读取方式 自动识别文件类型并读取文本内容
支持格式.txt / .md / .pdf / .docx / 其他 UTF-8 读取
Args: Args:
file_path: 文件路径 file_path: 文件路径
@ -98,45 +16,66 @@ def read_file_auto(file_path: str) -> str:
文件文本内容 文件文本内容
Raises: Raises:
ValueError: 不支持的文件类型 FileNotFoundError: 文件不存在
RuntimeError: 读取失败
""" """
ext = Path(file_path).suffix.lower() if not os.path.exists(file_path):
readers = { raise FileNotFoundError(f"文件不存在: {file_path}")
".txt": read_text_file,
".md": read_text_file, ext = os.path.splitext(file_path)[1].lower()
".py": read_text_file,
".json": read_text_file, try:
".yaml": read_text_file, if ext == ".pdf":
".yml": read_text_file, return _read_pdf(file_path)
".pdf": read_pdf_file, elif ext in (".docx", ".doc"):
".docx": read_docx_file, return _read_docx(file_path)
} else:
reader = readers.get(ext) # txt / md / 其他文本格式
if reader is None: with open(file_path, "r", encoding="utf-8", errors="replace") as f:
raise ValueError(f"不支持的文件类型: {ext},支持: {list(readers.keys())}") return f.read()
return reader(file_path) except Exception as e:
raise RuntimeError(f"读取文件失败 [{file_path}]: {e}")
def _read_pdf(file_path: str) -> str:
"""读取 PDF 文件文本"""
try:
from pypdf import PdfReader
except ImportError:
raise RuntimeError("读取 PDF 需要安装 pypdfpip install pypdf")
reader = PdfReader(file_path)
pages = [page.extract_text() or "" for page in reader.pages]
return "\n".join(pages)
def _read_docx(file_path: str) -> str:
"""读取 Word 文档文本"""
try:
from docx import Document
except ImportError:
raise RuntimeError("读取 docx 需要安装 python-docxpip install python-docx")
doc = Document(file_path)
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
return "\n".join(paragraphs)
def merge_knowledge_files(file_paths: List[str]) -> str: def merge_knowledge_files(file_paths: List[str]) -> str:
""" """
合并多个知识库文件为单一文本 合并多个知识库文件为单一文本
Args: Args:
file_paths: 知识库文件路径列表 file_paths: 文件路径列表
Returns: Returns:
合并后的知识库文本包含文件名分隔符 合并后的文本各文件以分隔线隔开
""" """
if not file_paths: parts = []
return "" for path in file_paths:
sections = []
for fp in file_paths:
try: try:
content = read_file_auto(fp) content = read_file_auto(path)
file_name = Path(fp).name parts.append(f"--- {os.path.basename(path)} ---\n{content}")
sections.append(f"### 知识库文件: {file_name}\n{content}")
except Exception as e: except Exception as e:
sections.append(f"### 知识库文件: {fp}\n[读取失败: {e}]") parts.append(f"--- {os.path.basename(path)} [读取失败: {e}] ---")
return "\n\n".join(parts)
return "\n\n".join(sections)

View File

@ -6,133 +6,74 @@ from typing import Dict, List
import config import config
# 各语言文件扩展名映射
LANGUAGE_EXT_MAP: Dict[str, str] = {
"python": ".py",
"javascript": ".js",
"typescript": ".ts",
"java": ".java",
"go": ".go",
"rust": ".rs",
"cpp": ".cpp",
"c": ".c",
"csharp": ".cs",
"ruby": ".rb",
"php": ".php",
"swift": ".swift",
"kotlin": ".kt",
}
# 合法的通用类型集合
VALID_TYPES = { VALID_TYPES = {
"integer", "string", "boolean", "float", "integer", "string", "boolean", "float",
"list", "dict", "object", "void", "any", "list", "dict", "object", "void", "any",
} }
# 合法的 inout 值
VALID_INOUT = {"in", "out", "inout"} VALID_INOUT = {"in", "out", "inout"}
def get_file_extension(language: str) -> str: # ══════════════════════════════════════════════════════
""" # 目录管理
获取指定语言的文件扩展名 # ══════════════════════════════════════════════════════
Args:
language: 编程语言名称小写
Returns:
文件扩展名含点号 '.py'
"""
return LANGUAGE_EXT_MAP.get(language.lower(), ".txt")
def build_project_output_dir(project_name: str) -> str: def build_project_output_dir(project_name: str) -> str:
""" safe = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_name)
构建项目输出目录路径 return os.path.join(config.OUTPUT_BASE_DIR, safe)
Args:
project_name: 项目名称
Returns:
输出目录路径
"""
safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_name)
return os.path.join(config.OUTPUT_BASE_DIR, safe_name)
def ensure_project_dir(project_name: str) -> str: def ensure_project_dir(project_name: str) -> str:
""" """确保项目根输出目录存在,并创建 __init__.py"""
确保项目输出目录存在不存在则创建
Args:
project_name: 项目名称
Returns:
创建好的目录路径
"""
output_dir = build_project_output_dir(project_name) output_dir = build_project_output_dir(project_name)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
init_file = os.path.join(output_dir, "__init__.py") init_file = os.path.join(output_dir, "__init__.py")
if not os.path.exists(init_file): if not os.path.exists(init_file):
Path(init_file).write_text( Path(init_file).write_text("# Auto-generated project package\n", encoding="utf-8")
"# Auto-generated project package\n", encoding="utf-8"
)
return output_dir return output_dir
def write_code_file( def ensure_module_dir(output_dir: str, module: str) -> str:
output_dir: str, """确保模块子目录存在,并创建 __init__.py"""
function_name: str, module_dir = os.path.join(output_dir, module)
language: str, os.makedirs(module_dir, exist_ok=True)
content: str, init_file = os.path.join(module_dir, "__init__.py")
) -> str: if not os.path.exists(init_file):
""" Path(init_file).write_text(
将代码内容写入指定目录的文件 f"# Auto-generated module package: {module}\n", encoding="utf-8"
)
return module_dir
Args:
output_dir: 输出目录路径
function_name: 函数名用于生成文件名
language: 编程语言
content: 代码内容
Returns:
写入的文件完整路径
"""
ext = get_file_extension(language)
file_name = f"{function_name}{ext}"
file_path = os.path.join(output_dir, file_name)
Path(file_path).write_text(content, encoding="utf-8")
return file_path
# ══════════════════════════════════════════════════════
# README
# ══════════════════════════════════════════════════════
def write_project_readme( def write_project_readme(
output_dir: str, output_dir: str,
project_name: str, project_name: str,
project_description: str,
requirements_summary: str, requirements_summary: str,
modules: List[str] = None,
) -> str: ) -> str:
""" """生成项目 README.md"""
在项目目录生成 README.md 文件 module_section = ""
if modules:
module_list = "\n".join(f"- `{m}/`" for m in sorted(set(modules)))
module_section = f"\n## 功能模块\n\n{module_list}\n"
Args: content = f"""# {project_name}
output_dir: 项目输出目录
project_name: 项目名称
requirements_summary: 功能需求摘要文本
Returns:
README.md 文件路径
"""
readme_content = f"""# {project_name}
> Auto-generated by Requirement Analyzer > Auto-generated by Requirement Analyzer
{project_description or ""}
{module_section}
## 功能需求列表 ## 功能需求列表
{requirements_summary} {requirements_summary}
""" """
readme_path = os.path.join(output_dir, "README.md") path = os.path.join(output_dir, "README.md")
Path(readme_path).write_text(readme_content, encoding="utf-8") Path(path).write_text(content, encoding="utf-8")
return readme_path return path
# ══════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════
@ -145,19 +86,11 @@ def build_signatures_document(
signatures: List[dict], signatures: List[dict],
) -> dict: ) -> dict:
""" """
将函数签名列表包装为带项目信息的顶层文档结构 构建顶层签名文档结构::
Args:
project_name: 项目名称写入 "project" 字段
project_description: 项目描述写入 "description" 字段
signatures: 函数签名 dict 列表写入 "functions" 字段
Returns:
顶层文档 dict结构为::
{ {
"project": "<project_name>", "project": "<name>",
"description": "<project_description>", "description": "<description>",
"functions": [ ... ] "functions": [ ... ]
} }
""" """
@ -173,66 +106,28 @@ def patch_signatures_with_url(
func_name_to_url: Dict[str, str], func_name_to_url: Dict[str, str],
) -> List[dict]: ) -> List[dict]:
""" """
将代码文件的路径URL回写到对应函数签名的 "url" 字段 将代码文件路径回写到签名的 "url" 字段紧跟 "type" 之后
遍历签名列表根据 signature["name"] func_name_to_url 中查找
对应路径找到则写入 "url" 字段未找到则写入空字符串不抛出异常
"url" 字段插入位置紧跟在 "type" 字段之后以保持字段顺序的可读性::
{
"name": "create_user",
"requirement_id": "REQ.01",
"description": "...",
"type": "function",
"url": "/abs/path/to/create_user.py", 新增
"parameters": { ... },
"return": { ... }
}
Args: Args:
signatures: 原始签名列表in-place 修改 signatures: 签名列表in-place 修改
func_name_to_url: {函数名: 代码文件绝对路径} 映射表 func_name_to_url: {函数名: 文件绝对路径}
CodeGenerator.generate_batch() 的进度回调收集
Returns: Returns:
修改后的签名列表与传入的同一对象方便链式调用 修改后的签名列表
""" """
for sig in signatures: for sig in signatures:
func_name = sig.get("name", "") url = func_name_to_url.get(sig.get("name", ""), "")
url = func_name_to_url.get(func_name, "")
_insert_field_after(sig, after_key="type", new_key="url", new_value=url) _insert_field_after(sig, after_key="type", new_key="url", new_value=url)
return signatures return signatures
def _insert_field_after( def _insert_field_after(d: dict, after_key: str, new_key: str, new_value) -> None:
d: dict, """在有序 dict 中将 new_key 插入到 after_key 之后"""
after_key: str,
new_key: str,
new_value,
) -> None:
"""
在有序 dict 中将 new_key 插入到 after_key 之后
after_key 不存在则追加到末尾
new_key 已存在则直接更新其值不改变位置
Args:
d: 目标 dictPython 3.7+ 保证插入顺序
after_key: 参考键名
new_key: 要插入的键名
new_value: 要插入的值
"""
if new_key in d: if new_key in d:
d[new_key] = new_value d[new_key] = new_value
return return
items = list(d.items()) items = list(d.items())
insert_pos = len(items) insert_pos = next((i + 1 for i, (k, _) in enumerate(items) if k == after_key), len(items))
for i, (k, _) in enumerate(items):
if k == after_key:
insert_pos = i + 1
break
items.insert(insert_pos, (new_key, new_value)) items.insert(insert_pos, (new_key, new_value))
d.clear() d.clear()
d.update(items) d.update(items)
@ -245,42 +140,7 @@ def write_function_signatures_json(
project_description: str, project_description: str,
file_name: str = "function_signatures.json", file_name: str = "function_signatures.json",
) -> str: ) -> str:
""" """将签名列表导出为 JSON 文件"""
将函数签名列表连同项目信息一起导出为 JSON 文件
输出的 JSON 顶层结构为::
{
"project": "<project_name>",
"description": "<project_description>",
"functions": [
{
"name": "...",
"requirement_id": "...",
"description": "...",
"type": "function",
"url": "/abs/path/to/xxx.py",
"parameters": { ... },
"return": { ... }
},
...
]
}
Args:
output_dir: JSON 文件写入目录
signatures: 函数签名 dict 列表应已通过
patch_signatures_with_url() 写入 "url" 字段
project_name: 项目名称
project_description: 项目描述
file_name: 输出文件名默认 function_signatures.json
Returns:
写入的 JSON 文件完整路径
Raises:
OSError: 目录不可写
"""
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
document = build_signatures_document(project_name, project_description, signatures) document = build_signatures_document(project_name, project_description, signatures)
file_path = os.path.join(output_dir, file_name) file_path = os.path.join(output_dir, file_name)
@ -294,137 +154,70 @@ def write_function_signatures_json(
# ══════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════
def validate_signature_schema(signature: dict) -> List[str]: def validate_signature_schema(signature: dict) -> List[str]:
""" """校验单个函数签名结构,返回错误列表(空列表表示通过)"""
校验单个函数签名 dict 是否符合规范
校验范围
- 顶层必填字段name / requirement_id / description / type / parameters
- 可选字段 "url"若存在则必须为非空字符串
- parameters每个参数的 type / inout / required 字段
- returntype 字段 + on_success / on_failure 子结构
- void 函数on_success / on_failure 应为 null
- void 函数on_success / on_failure 必须存在
value非空 description非空均需填写
Args:
signature: 单个函数签名 dict
Returns:
错误信息字符串列表列表为空表示校验通过
"""
errors: List[str] = [] errors: List[str] = []
# ── 顶层必填字段 ──────────────────────────────────
for key in ("name", "requirement_id", "description", "type", "parameters"): for key in ("name", "requirement_id", "description", "type", "parameters"):
if key not in signature: if key not in signature:
errors.append(f"缺少顶层字段: '{key}'") errors.append(f"缺少顶层字段: '{key}'")
# ── url 字段(可选,存在时校验非空)─────────────────
if "url" in signature: if "url" in signature:
if not isinstance(signature["url"], str): if not isinstance(signature["url"], str):
errors.append("'url' 字段必须是字符串类型") errors.append("'url' 字段必须是字符串类型")
elif signature["url"] == "": elif signature["url"] == "":
errors.append("'url' 字段不能为空字符串(代码文件路径未成功回写)") errors.append("'url' 字段不能为空字符串")
# ── parameters ────────────────────────────────────
params = signature.get("parameters", {}) params = signature.get("parameters", {})
if not isinstance(params, dict): if isinstance(params, dict):
errors.append("'parameters' 必须是 dict 类型")
else:
for pname, pdef in params.items(): for pname, pdef in params.items():
if not isinstance(pdef, dict): if not isinstance(pdef, dict):
errors.append(f"参数 '{pname}' 定义必须是 dict") errors.append(f"参数 '{pname}' 定义必须是 dict")
continue continue
# type支持联合类型如 "string|integer"
if "type" not in pdef: if "type" not in pdef:
errors.append(f"参数 '{pname}' 缺少 'type' 字段") errors.append(f"参数 '{pname}' 缺少 'type'")
else: else:
parts = [p.strip() for p in pdef["type"].split("|")] parts = [p.strip() for p in pdef["type"].split("|")]
if not all(p in VALID_TYPES for p in parts): if not all(p in VALID_TYPES for p in parts):
errors.append( errors.append(f"参数 '{pname}' type='{pdef['type']}' 含不合法类型")
f"参数 '{pname}' 的 type='{pdef['type']}' 含有不合法的类型"
)
# inout
if "inout" not in pdef: if "inout" not in pdef:
errors.append(f"参数 '{pname}' 缺少 'inout' 字段") errors.append(f"参数 '{pname}' 缺少 'inout'")
elif pdef["inout"] not in VALID_INOUT: elif pdef["inout"] not in VALID_INOUT:
errors.append( errors.append(f"参数 '{pname}' inout='{pdef['inout']}' 应为 in/out/inout")
f"参数 '{pname}' 的 inout='{pdef['inout']}' 应为 in/out/inout"
)
# required
if "required" not in pdef: if "required" not in pdef:
errors.append(f"参数 '{pname}' 缺少 'required' 字段") errors.append(f"参数 '{pname}' 缺少 'required'")
elif not isinstance(pdef["required"], bool): elif not isinstance(pdef["required"], bool):
errors.append( errors.append(f"参数 '{pname}' 'required' 应为布尔值")
f"参数 '{pname}''required' 应为布尔值 true/false"
f"当前为: {pdef['required']!r}"
)
# ── return ────────────────────────────────────────
ret = signature.get("return") ret = signature.get("return")
if ret is None: if ret is None:
errors.append( errors.append("缺少 'return' 字段")
"缺少 'return' 字段void 函数请填 " elif isinstance(ret, dict):
"{\"type\": \"void\", \"on_success\": null, \"on_failure\": null}"
)
elif not isinstance(ret, dict):
errors.append("'return' 必须是 dict 类型")
else:
ret_type = ret.get("type") ret_type = ret.get("type")
if not ret_type: if not ret_type:
errors.append("'return' 缺少 'type' 字段") errors.append("'return' 缺少 'type'")
elif ret_type not in VALID_TYPES: elif ret_type not in VALID_TYPES:
errors.append(f"'return.type'='{ret_type}' 不在合法类型列表中") errors.append(f"'return.type'='{ret_type}' 不合法")
is_void = (ret_type == "void") is_void = (ret_type == "void")
for sub_key in ("on_success", "on_failure"): for sub_key in ("on_success", "on_failure"):
sub = ret.get(sub_key) sub = ret.get(sub_key)
if is_void: if is_void:
if sub is not None: if sub is not None:
errors.append( errors.append(f"void 函数 'return.{sub_key}' 应为 null")
f"void 函数的 'return.{sub_key}' 应为 null"
f"当前为: {sub!r}"
)
else: else:
if sub is None: if sub is None:
errors.append( errors.append(f"非 void 函数缺少 'return.{sub_key}'")
f"非 void 函数缺少 'return.{sub_key}'" elif isinstance(sub, dict):
f"请描述{'成功' if sub_key == 'on_success' else '失败'}时的返回值" if not sub.get("value"):
) errors.append(f"'return.{sub_key}.value' 不能为空")
elif not isinstance(sub, dict): if not sub.get("description"):
errors.append(f"'return.{sub_key}' 必须是 dict 类型")
else:
if "value" not in sub:
errors.append(f"'return.{sub_key}' 缺少 'value' 字段")
elif sub["value"] == "":
errors.append(
f"'return.{sub_key}.value' 不能为空字符串,"
f"请填写具体返回值、值域描述或结构示例"
)
if "description" not in sub or sub.get("description") in (None, ""):
errors.append(f"'return.{sub_key}.description' 不能为空") errors.append(f"'return.{sub_key}.description' 不能为空")
return errors return errors
def validate_all_signatures(signatures: List[dict]) -> Dict[str, List[str]]: def validate_all_signatures(signatures: List[dict]) -> Dict[str, List[str]]:
""" """批量校验,返回 {函数名: [错误]} 字典(仅含有错误的条目)"""
批量校验函数签名列表 return {
sig.get("name", f"unknown_{i}"): errs
注意此函数接受的是纯签名列表即顶层文档的 "functions" 字段 for i, sig in enumerate(signatures)
而非包含 project/description 的顶层文档 if (errs := validate_signature_schema(sig))
}
Args:
signatures: 函数签名 dict 列表
Returns:
{函数名: [错误信息, ...]} 字典仅包含有错误的条目
"""
report: Dict[str, List[str]] = {}
for sig in signatures:
name = sig.get("name", f"unknown_{id(sig)}")
errs = validate_signature_schema(sig)
if errs:
report[name] = errs
return report