diff --git a/requirements_generator/config.py b/requirements_generator/config.py index a9702aa..906a2a6 100644 --- a/requirements_generator/config.py +++ b/requirements_generator/config.py @@ -1,146 +1,123 @@ -# config.py - 全局配置管理 +# config.py - 全局配置 import os from dotenv import load_dotenv load_dotenv() -# ── LLM 配置 ────────────────────────────────────────── -LLM_API_KEY = os.getenv("OPENAI_API_KEY", "") -LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") -LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o") -LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0.3")) +# ── LLM ────────────────────────────────────────────── +LLM_API_KEY = os.getenv("OPENAI_API_KEY", "") +LLM_API_BASE = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") +LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o") +LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT", "60")) +LLM_MAX_RETRY = int(os.getenv("LLM_MAX_RETRY", "3")) -# ── 数据库配置 ───────────────────────────────────────── -DB_PATH = os.getenv("DB_PATH", "data/requirement_analyzer.db") +# ── 数据库 ──────────────────────────────────────────── +DB_PATH = os.getenv("DB_PATH", "data/requirement_analyzer.db") -# ── 输出配置 ─────────────────────────────────────────── +# ── 输出目录 ────────────────────────────────────────── OUTPUT_BASE_DIR = os.getenv("OUTPUT_BASE_DIR", "output") DEFAULT_LANGUAGE = os.getenv("DEFAULT_LANGUAGE", "python") +DEFAULT_MODULE = os.getenv("DEFAULT_MODULE", "default") -# ══════════════════════════════════════════════════════ -# Prompt 模板 -# ══════════════════════════════════════════════════════ +# ── Prompt 模板 ─────────────────────────────────────── -DECOMPOSE_PROMPT_TEMPLATE = """ -你是一位资深软件架构师和产品经理。请根据以下信息,将原始需求分解为若干个可独立实现的功能需求。 +DECOMPOSE_PROMPT_TEMPLATE = """\ +你是一名资深软件架构师,请将以下原始需求分解为独立的功能需求列表。 -{knowledge_section} - -## 原始需求 +【原始需求】 {raw_requirement} -## 输出要求 -请严格按照以下 JSON 格式输出,不要包含任何额外说明: -{{ - "functional_requirements": [ - {{ - "index": 1, - "title": "功能需求标题(简洁,10字以内)", - "description": "功能需求详细描述(包含输入、处理逻辑、输出)", - "function_name": "snake_case函数名", - "priority": "high|medium|low" - }} - ] -}} - -要求: -1. 每个功能需求必须是独立可实现的最小单元 -2. function_name 使用 snake_case 命名,清晰表达函数用途 -3. 分解粒度适中,通常 5-15 个功能需求 -4. 优先级根据业务重要性判断 -""" - -# ── 函数签名 JSON 生成 Prompt ────────────────────────── -FUNC_SIGNATURE_PROMPT_TEMPLATE = """ -你是一位资深软件架构师。请根据以下功能需求描述,设计该函数的完整接口签名,并以 JSON 格式输出。 - {knowledge_section} -## 功能需求 -需求编号:{requirement_id} -标题:{title} -函数名:{function_name} -详细描述:{description} +【输出要求】 +以 JSON 数组格式输出,每个元素包含以下字段: +- title: 功能标题(简短,10字以内) +- description: 功能描述(详细说明该功能的职责与边界,50字以内) +- function_name: 对应的函数名(snake_case,动词开头) +- priority: 优先级(high / medium / low) +- module: 所属功能模块名称(snake_case,如 user_auth / order_service) -## 输出格式 -请严格按照以下 JSON 结构输出,不要包含任何额外说明或 markdown 标记: -{{ - "name": "{function_name}", - "requirement_id": "{requirement_id}", - "description": "简洁的一句话功能描述(英文)", - "type": "function", - "parameters": {{ - "": {{ - "type": "integer|string|boolean|float|list|dict|object", - "inout": "in|out|inout", - "description": "参数说明(英文)", - "required": true - }} - }}, - "return": {{ - "type": "integer|string|boolean|float|list|dict|object|void", - "description": "整体返回值说明(英文,一句话概括)", - "on_success": {{ - "value": "具体成功返回值或范围,如 0、true、user object、list of items 等", - "description": "成功时的返回值含义(英文)" - }}, - "on_failure": {{ - "value": "具体失败返回值或范围,如 nonzero、false、null、empty list、raises Exception 等", - "description": "失败时的返回值含义,或抛出的异常类型(英文)" - }} +【示例输出】 +[ + {{ + "title": "用户注册", + "description": "接收用户名、密码、邮箱,校验合法性后创建用户账号并返回用户ID", + "function_name": "register_user", + "priority": "high", + "module": "user_auth" }} -}} +] -## 设计规范 -1. 参数名使用 snake_case,类型使用通用类型(不绑定具体语言) -2. inout 字段含义: - - in = 仅输入参数 - - out = 仅输出参数(通过参数传出结果,如指针/引用) - - inout = 既作输入又作输出 -3. 所有描述字段使用英文 -4. return 字段规则: - - 若函数无返回值(void),type 填 "void",on_success/on_failure 均填 null - - 若返回值只有成功场景(如纯查询),on_failure 可描述为 "null or empty" - - on_success.value / on_failure.value 填写具体值或值域描述,不要填写空字符串 -5. 若函数无参数,parameters 填 {{}} -6. required 字段为布尔值 true 或 false +只输出 JSON 数组,不要有任何额外说明。 """ -# ── 代码生成 Prompt(含签名约束)───────────────────────── -CODE_GEN_PROMPT_TEMPLATE = """ -你是一位资深 {language} 工程师。请根据以下功能需求和【函数签名规范】,生成完整的 {language} 函数代码。 +FUNC_SIGNATURE_PROMPT_TEMPLATE = """\ +你是一名资深软件工程师,请根据以下功能需求生成标准函数签名信息。 + +【功能需求】 +- 需求编号: {requirement_id} +- 标题: {title} +- 描述: {description} +- 函数名: {function_name} +- 所属模块: {module} {knowledge_section} -## 功能需求 -标题:{title} -描述:{description} +【输出要求】 +以 JSON 对象格式输出,包含以下字段: +- name: 函数名(与上方一致) +- requirement_id: 需求编号 +- description: 函数功能描述(英文,一句话) +- type: 固定为 "function" +- module: 所属模块名称 +- parameters: 参数字典,key 为参数名,value 包含: + - type: 数据类型(integer/string/boolean/float/list/dict/object/void/any) + - inout: in / out / inout + - required: true / false + - description: 参数说明 +- return: + - type: 返回类型 + - on_success: {{ "value": "...", "description": "..." }} 或 null(void) + - on_failure: {{ "value": "...", "description": "..." }} 或 null(void) -## 【必须严格遵守】函数签名规范 -以下 JSON 定义了函数的精确接口,生成的代码必须与之完全一致,不得擅自增减或改名参数: +只输出 JSON 对象,不要有任何额外说明。 +""" -```json +CODE_GEN_PROMPT_TEMPLATE = """\ +你是一名资深 {language} 工程师,请根据以下函数签名和功能描述生成完整的函数实现代码。 + +【函数签名】 {signature_json} -``` -### 签名字段说明 -- `name`:函数名,必须完全一致 -- `parameters`:每个 key 即为参数名,`type` 为数据类型,`inout` 含义: - - `in` = 普通输入参数 - - `out` = 输出参数(Python 中通过返回值或可变容器传出) - - `inout` = 既作输入又作输出 -- `return.type`:返回值类型 -- `return.on_success`:成功时的返回值,代码实现必须与此一致 -- `return.on_failure`:失败时的返回值或异常,代码实现必须与此一致 +【功能描述】 +{description} -## 输出要求 -1. 只输出纯代码,不要包含 markdown 代码块标记 -2. 函数签名(名称、参数列表、返回类型)必须与上方 JSON 规范完全一致 -3. 成功/失败的返回值必须严格遵守 return.on_success / return.on_failure 的定义 -4. 包含完整的类型注解(Python 使用 type hints) -5. 包含详细的 docstring,其中 Returns 段须注明成功值与失败值 -6. 包含必要的异常处理 -7. 代码风格遵循 PEP8(Python)或对应语言规范 -8. 在文件顶部用注释注明:需求编号、功能标题、函数签名摘要 -9. 如需导入第三方库,请在顶部统一导入 +{knowledge_section} + +【输出要求】 +1. 只输出 {language} 代码,不要有任何 Markdown 标记(不要 ```) +2. 包含完整的函数实现(含必要的 import) +3. 包含函数文档注释(docstring / JSDoc 等) +4. 包含基本的参数校验与错误处理 +5. 代码风格遵循 {language} 最佳实践 +""" + +MODULE_CLASSIFY_PROMPT_TEMPLATE = """\ +你是一名资深软件架构师,请将以下功能需求列表分类到合适的功能模块中。 + +【功能需求列表】 +{requirements_json} + +【输出要求】 +以 JSON 数组格式输出,每个元素包含: +- function_name: 函数名(与输入一致) +- module: 所属模块名称(snake_case,如 user_auth / order_service / payment) + +模块划分原则: +1. 功能相近的需求归入同一模块 +2. 模块名使用英文 snake_case +3. 模块数量控制在 2~8 个之间 +4. 若某需求确实无法归类,使用 "default" 模块 + +只输出 JSON 数组,不要有任何额外说明。 """ \ No newline at end of file diff --git a/requirements_generator/core/code_generator.py b/requirements_generator/core/code_generator.py index 2418726..4199bce 100644 --- a/requirements_generator/core/code_generator.py +++ b/requirements_generator/core/code_generator.py @@ -1,99 +1,89 @@ -# core/code_generator.py - 代码生成核心逻辑(签名约束版) +# core/code_generator.py - 代码生成(按模块路由到子目录) +import os import json -from typing import Optional, List, Callable +from pathlib import Path +from typing import List, Optional, Callable import config from core.llm_client import LLMClient from database.models import FunctionalRequirement, CodeFile -from utils.output_writer import write_code_file, get_file_extension class CodeGenerator: - """ - 根据功能需求 + 函数签名约束,使用 LLM 生成代码函数文件。 - 签名由 RequirementAnalyzer.build_function_signature() 预先生成, - 注入 Prompt 后可确保代码参数列表与签名 JSON 完全一致。 - """ + """根据函数签名约束,调用 LLM 生成代码文件,并按模块写入子目录""" - def __init__(self, llm_client: Optional[LLMClient] = None): - """ - 初始化代码生成器 - - Args: - llm_client: LLM 客户端实例,为 None 时自动创建 - """ - self.llm = llm_client or LLMClient() + def __init__(self, llm: LLMClient): + self.llm = llm # ══════════════════════════════════════════════════ - # 单个生成 + # 单个代码文件生成 # ══════════════════════════════════════════════════ def generate( self, - func_req: FunctionalRequirement, + func_req: FunctionalRequirement, output_dir: str, - language: str = config.DEFAULT_LANGUAGE, - knowledge: str = "", - signature: Optional[dict] = None, + language: str = None, + knowledge: str = "", + signature: dict = None, ) -> CodeFile: """ - 为单个功能需求生成代码文件 + 为单个功能需求生成代码文件,写入 output_dir// 子目录。 Args: - func_req: 功能需求对象(必须含有效 id) - output_dir: 代码输出目录 - language: 目标编程语言 - knowledge: 知识库文本(可选) - signature: 函数签名 dict(由 RequirementAnalyzer 生成)。 - 传入后将作为强约束注入 Prompt,确保代码参数 - 与签名 JSON 完全一致;为 None 时退化为无约束模式。 + func_req: 功能需求对象 + output_dir: 项目根输出目录 + language: 目标语言 + knowledge: 知识库文本 + signature: 函数签名 dict(可选,有则作为约束) Returns: - CodeFile 对象(含生成的代码内容和文件路径,未持久化) + CodeFile 对象(未持久化) Raises: - ValueError: func_req.id 为 None - Exception: LLM 调用失败或文件写入失败 + RuntimeError: LLM 调用失败 """ - if func_req.id is None: - raise ValueError("FunctionalRequirement 必须先持久化(id 不能为 None)") + language = language or config.DEFAULT_LANGUAGE + module = (func_req.module or config.DEFAULT_MODULE).strip() - knowledge_section = self._build_knowledge_section(knowledge) - signature_json = self._build_signature_json(signature, func_req) + # 按模块创建子目录 + module_dir = os.path.join(output_dir, module) + os.makedirs(module_dir, exist_ok=True) + self._ensure_init_py(module_dir) + # 构建 Prompt + sig_json = json.dumps(signature, ensure_ascii=False, indent=2) if signature else "{}" + knowledge_section = f"【参考知识库】\n{knowledge}\n" if knowledge else "" prompt = config.CODE_GEN_PROMPT_TEMPLATE.format( - language=language, - knowledge_section=knowledge_section, - title=func_req.title, - description=func_req.description, - signature_json=signature_json, + language = language, + signature_json = sig_json, + description = func_req.description, + knowledge_section = knowledge_section, ) - code_content = self.llm.chat( - system_prompt=( - f"你是一位资深 {language} 工程师,只输出纯代码," - "不添加任何 markdown 标记。函数签名必须与提供的 JSON 规范完全一致。" - ), - user_prompt=prompt, - ) + try: + code_content = self.llm.chat( + prompt, + system = f"You are an expert {language} developer. Output only code.", + temperature = 0.2, + max_tokens = 4096, + ) + except Exception as e: + raise RuntimeError(f"代码生成失败 [{func_req.function_name}]: {e}") - file_path = write_code_file( - output_dir=output_dir, - function_name=func_req.function_name, - language=language, - content=code_content, - ) - - ext = get_file_extension(language) + # 写入文件 + ext = self._get_extension(language) file_name = f"{func_req.function_name}{ext}" + file_path = os.path.join(module_dir, file_name) + Path(file_path).write_text(code_content, encoding="utf-8") return CodeFile( - project_id=func_req.project_id, - func_req_id=func_req.id, - file_name=file_name, - file_path=file_path, - language=language, - content=code_content, + func_req_id = func_req.id, + file_name = file_name, + file_path = file_path, + module = module, + language = language, + content = code_content, ) # ══════════════════════════════════════════════════ @@ -102,115 +92,82 @@ class CodeGenerator: def generate_batch( self, - func_reqs: List[FunctionalRequirement], - output_dir: str, - language: str = config.DEFAULT_LANGUAGE, - knowledge: str = "", - signatures: Optional[List[dict]] = None, - on_progress: Optional[Callable] = None, + func_reqs: List[FunctionalRequirement], + output_dir: str, + language: str = None, + knowledge: str = "", + signatures: Optional[List[dict]] = None, + on_progress: Optional[Callable] = None, ) -> List[CodeFile]: """ - 批量生成代码文件 + 批量生成代码文件。 Args: func_reqs: 功能需求列表 - output_dir: 输出目录 + output_dir: 项目根输出目录 language: 目标语言 knowledge: 知识库文本 - signatures: 与 func_reqs 等长的签名列表(索引对应)。 - 为 None 时所有条目均以无约束模式生成。 - on_progress: 进度回调 fn(index, total, func_req, code_file, error) + signatures: 与 func_reqs 等长的签名列表(索引对应) + on_progress: 进度回调 fn(index, total, req, code_file, error) Returns: 成功生成的 CodeFile 列表 """ - results = [] - total = len(func_reqs) + language = language or config.DEFAULT_LANGUAGE + total = len(func_reqs) + results = [] + sig_map = self._build_signature_map(func_reqs, signatures) - # 构建 func_req.id → signature 的快速查找表 - sig_map = self._build_signature_map(func_reqs, signatures) - - for i, req in enumerate(func_reqs): + for i, req in enumerate(func_reqs, 1): sig = sig_map.get(req.id) try: code_file = self.generate( - func_req=req, - output_dir=output_dir, - language=language, - knowledge=knowledge, - signature=sig, + func_req = req, + output_dir = output_dir, + language = language, + knowledge = knowledge, + signature = sig, ) results.append(code_file) if on_progress: - on_progress(i + 1, total, req, code_file, None) + on_progress(i, total, req, code_file, None) except Exception as e: if on_progress: - on_progress(i + 1, total, req, None, e) + on_progress(i, total, req, None, e) return results # ══════════════════════════════════════════════════ - # 私有工具方法 + # 工具方法 # ══════════════════════════════════════════════════ - @staticmethod - def _build_knowledge_section(knowledge: str) -> str: - """构建知识库 Prompt 段落""" - if not knowledge or not knowledge.strip(): - return "" - return ( - "## 参考知识库(实现时请遵循以下规范)\n" - f"{knowledge}\n\n---\n" - ) - - @staticmethod - def _build_signature_json( - signature: Optional[dict], - func_req: FunctionalRequirement, - ) -> str: - """ - 将签名 dict 序列化为格式化 JSON 字符串; - 若签名为 None,则构造最小占位签名,保持 Prompt 结构完整。 - - Args: - signature: 签名 dict 或 None - func_req: 对应的功能需求(用于占位签名) - - Returns: - JSON 字符串 - """ - if signature: - return json.dumps(signature, ensure_ascii=False, indent=2) - # 无签名时的最小占位,提示 LLM 自行设计但保持格式 - fallback = { - "name": func_req.function_name, - "requirement_id": f"REQ.{func_req.index_no:02d}", - "description": func_req.description, - "type": "function", - "parameters": "<<请根据功能描述自行设计参数>>", - "return": "<<请根据功能描述自行设计返回值>>", - } - return json.dumps(fallback, ensure_ascii=False, indent=2) - @staticmethod def _build_signature_map( - func_reqs: List[FunctionalRequirement], + func_reqs: List[FunctionalRequirement], signatures: Optional[List[dict]], ) -> dict: - """ - 构建 func_req.id → signature 映射表 - - Args: - func_reqs: 功能需求列表 - signatures: 与 func_reqs 等长的签名列表,或 None - - Returns: - {req_id: signature_dict} 字典 - """ + """构建 func_req.id → signature 的快速查找表""" if not signatures: return {} - sig_map = {} - for req, sig in zip(func_reqs, signatures): - if req.id is not None and sig: - sig_map[req.id] = sig - return sig_map \ No newline at end of file + return { + req.id: sig + for req, sig in zip(func_reqs, signatures) + if req.id is not None + } + + @staticmethod + def _get_extension(language: str) -> str: + ext_map = { + "python": ".py", "javascript": ".js", "typescript": ".ts", + "java": ".java", "go": ".go", "rust": ".rs", + "cpp": ".cpp", "c": ".c", "csharp": ".cs", + "ruby": ".rb", "php": ".php", "swift": ".swift", "kotlin": ".kt", + } + return ext_map.get(language.lower(), ".txt") + + @staticmethod + def _ensure_init_py(directory: str) -> None: + """在目录中创建 __init__.py(Python 包标识)""" + init = os.path.join(directory, "__init__.py") + if not os.path.exists(init): + Path(init).write_text("# Auto-generated module package\n", encoding="utf-8") \ No newline at end of file diff --git a/requirements_generator/core/llm_client.py b/requirements_generator/core/llm_client.py index 9a2701e..ce2aa10 100644 --- a/requirements_generator/core/llm_client.py +++ b/requirements_generator/core/llm_client.py @@ -1,90 +1,114 @@ -# core/llm_client.py - LLM 客户端封装 +# core/llm_client.py - LLM API 调用封装 +import time import json from typing import Optional +from openai import OpenAI, APIError, APITimeoutError, RateLimitError + import config class LLMClient: - """ - OpenAI 兼容 LLM 客户端封装。 - 支持任何兼容 OpenAI API 格式的服务(OpenAI / Azure / 本地模型等)。 - """ + """封装 OpenAI 兼容接口,提供统一的调用入口与重试机制""" def __init__( self, - api_key: str = config.LLM_API_KEY, - base_url: str = config.LLM_BASE_URL, - model: str = config.LLM_MODEL, - temperature: float = config.LLM_TEMPERATURE, + api_key: str = None, + api_base: str = None, + model: str = None, ): - """ - 初始化 LLM 客户端 - - Args: - api_key: API 密钥 - base_url: API 基础 URL - model: 模型名称 - temperature: 生成温度(0~1,越低越确定) - - Raises: - ImportError: 未安装 openai 库 - ValueError: api_key 为空 - """ - try: - from openai import OpenAI - except ImportError: - raise ImportError("请安装 openai: pip install openai") - - if not api_key: - raise ValueError("LLM_API_KEY 未配置,请在 .env 文件中设置") - - self.model = model - self.temperature = temperature - self._client = OpenAI(api_key=api_key, base_url=base_url) - - def chat(self, system_prompt: str, user_prompt: str) -> str: - """ - 发送对话请求,返回模型回复文本 - - Args: - system_prompt: 系统提示词 - user_prompt: 用户输入 - - Returns: - 模型回复的文本内容 - - Raises: - Exception: API 调用失败 - """ - response = self._client.chat.completions.create( - model=self.model, - temperature=self.temperature, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], + self.model = model or config.LLM_MODEL + self.client = OpenAI( + api_key = api_key or config.LLM_API_KEY, + base_url = api_base or config.LLM_API_BASE, + timeout = config.LLM_TIMEOUT, ) - return response.choices[0].message.content.strip() - def chat_json(self, system_prompt: str, user_prompt: str) -> dict: + def chat( + self, + prompt: str, + system: str = "You are a helpful assistant.", + temperature: float = 0.2, + max_tokens: int = 4096, + ) -> str: """ - 发送对话请求,解析并返回 JSON 结果 + 发送单轮对话请求,返回模型回复文本。 Args: - system_prompt: 系统提示词 - user_prompt: 用户输入 + prompt: 用户消息 + system: 系统提示词 + temperature: 采样温度 + max_tokens: 最大输出 token 数 Returns: - 解析后的 dict 对象 + 模型回复的纯文本字符串 Raises: - json.JSONDecodeError: 模型返回非合法 JSON + RuntimeError: 超过最大重试次数后仍失败 """ - raw = self.chat(system_prompt, user_prompt) + last_error = None + for attempt in range(1, config.LLM_MAX_RETRY + 1): + try: + resp = self.client.chat.completions.create( + model = self.model, + messages = [ + {"role": "system", "content": system}, + {"role": "user", "content": prompt}, + ], + temperature = temperature, + max_tokens = max_tokens, + ) + return resp.choices[0].message.content.strip() + except RateLimitError as e: + wait = 2 ** attempt + last_error = e + time.sleep(wait) + except APITimeoutError as e: + last_error = e + time.sleep(2) + except APIError as e: + last_error = e + if attempt >= config.LLM_MAX_RETRY: + break + time.sleep(1) + raise RuntimeError( + f"LLM 调用失败(已重试 {config.LLM_MAX_RETRY} 次): {last_error}" + ) + + def chat_json( + self, + prompt: str, + system: str = "You are a helpful assistant. Always respond with valid JSON.", + temperature: float = 0.1, + max_tokens: int = 4096, + ) -> any: + """ + 发送请求并将回复解析为 JSON 对象。 + + Returns: + 解析后的 Python 对象(dict 或 list) + + Raises: + ValueError: JSON 解析失败 + RuntimeError: LLM 调用失败 + """ + raw = self.chat(prompt, system=system, temperature=temperature, max_tokens=max_tokens) # 去除可能的 markdown 代码块包裹 - raw = raw.strip() - if raw.startswith("```"): - lines = raw.split("\n") - raw = "\n".join(lines[1:-1]) - return json.loads(raw) \ No newline at end of file + cleaned = self._strip_markdown_code_block(raw) + try: + return json.loads(cleaned) + except json.JSONDecodeError as e: + raise ValueError(f"LLM 返回内容无法解析为 JSON: {e}\n原始内容:\n{raw}") + + @staticmethod + def _strip_markdown_code_block(text: str) -> str: + """去除 ```json ... ``` 或 ``` ... ``` 包裹""" + text = text.strip() + if text.startswith("```"): + lines = text.splitlines() + # 去掉首行(```json 或 ```)和末行(```) + inner = lines[1:] if lines[-1].strip() == "```" else lines[1:] + if inner and inner[-1].strip() == "```": + inner = inner[:-1] + text = "\n".join(inner).strip() + return text \ No newline at end of file diff --git a/requirements_generator/core/requirement_analyzer.py b/requirements_generator/core/requirement_analyzer.py index 15adda7..dcf7551 100644 --- a/requirements_generator/core/requirement_analyzer.py +++ b/requirements_generator/core/requirement_analyzer.py @@ -1,6 +1,6 @@ # core/requirement_analyzer.py - 需求分解 & 函数签名生成 -import re -from typing import List, Optional +import json +from typing import List, Optional, Callable import config from core.llm_client import LLMClient @@ -8,19 +8,10 @@ from database.models import FunctionalRequirement class RequirementAnalyzer: - """ - 使用 LLM 将原始需求分解为功能需求列表,并生成函数接口签名。 - 支持注入知识库上下文以提升分解质量。 - """ + """负责需求分解、模块分类、函数签名生成""" - def __init__(self, llm_client: Optional[LLMClient] = None): - """ - 初始化需求分析器 - - Args: - llm_client: LLM 客户端实例,为 None 时自动创建 - """ - self.llm = llm_client or LLMClient() + def __init__(self, llm: LLMClient): + self.llm = llm # ══════════════════════════════════════════════════ # 需求分解 @@ -29,12 +20,12 @@ class RequirementAnalyzer: def decompose( self, raw_requirement: str, - project_id: int, - raw_req_id: int, - knowledge: str = "", + project_id: int, + raw_req_id: int, + knowledge: str = "", ) -> List[FunctionalRequirement]: """ - 将原始需求分解为功能需求列表 + 将原始需求文本分解为功能需求列表(含模块分类)。 Args: raw_requirement: 原始需求文本 @@ -44,196 +35,175 @@ class RequirementAnalyzer: Returns: FunctionalRequirement 对象列表(未持久化,id=None) - - Raises: - ValueError: LLM 返回格式不合法 - json.JSONDecodeError: JSON 解析失败 """ - knowledge_section = self._build_knowledge_section(knowledge) + knowledge_section = ( + f"【参考知识库】\n{knowledge}\n" if knowledge else "" + ) prompt = config.DECOMPOSE_PROMPT_TEMPLATE.format( - knowledge_section=knowledge_section, - raw_requirement=raw_requirement, + raw_requirement = raw_requirement, + knowledge_section = knowledge_section, ) - result = self.llm.chat_json( - system_prompt="你是一位资深软件架构师,擅长需求分析与系统设计。", - user_prompt=prompt, - ) + try: + items = self.llm.chat_json(prompt) + if not isinstance(items, list): + raise ValueError("LLM 返回结果不是数组") + except Exception as e: + raise RuntimeError(f"需求分解失败: {e}") - items = result.get("functional_requirements", []) - if not items: - raise ValueError("LLM 未返回任何功能需求,请检查原始需求描述") - - requirements = [] - for item in items: + reqs = [] + for i, item in enumerate(items, 1): req = FunctionalRequirement( - project_id=project_id, - raw_req_id=raw_req_id, - index_no=int(item.get("index", len(requirements) + 1)), - title=item.get("title", "未命名功能"), - description=item.get("description", ""), - function_name=self._sanitize_function_name( - item.get("function_name", f"func_{len(requirements)+1}") - ), - priority=item.get("priority", "medium"), + project_id = project_id, + raw_req_id = raw_req_id, + index_no = i, + title = item.get("title", f"功能{i}"), + description = item.get("description", ""), + function_name = item.get("function_name", f"function_{i}"), + priority = item.get("priority", "medium"), + module = item.get("module", config.DEFAULT_MODULE), + status = "pending", + is_custom = False, ) - requirements.append(req) - - return requirements + reqs.append(req) + return reqs # ══════════════════════════════════════════════════ - # 函数签名生成(新增) + # 模块分类(独立步骤,可对已有需求列表重新分类) + # ══════════════════════════════════════════════════ + + def classify_modules( + self, + func_reqs: List[FunctionalRequirement], + knowledge: str = "", + ) -> List[dict]: + """ + 对功能需求列表进行模块分类,返回 {function_name: module} 映射列表。 + + Args: + func_reqs: 功能需求列表 + knowledge: 知识库文本(可选) + + Returns: + [{"function_name": "...", "module": "..."}, ...] + """ + req_list = [ + { + "index_no": r.index_no, + "title": r.title, + "description": r.description, + "function_name": r.function_name, + } + for r in func_reqs + ] + knowledge_section = f"【参考知识库】\n{knowledge}\n" if knowledge else "" + prompt = config.MODULE_CLASSIFY_PROMPT_TEMPLATE.format( + requirements_json = json.dumps(req_list, ensure_ascii=False, indent=2), + knowledge_section = knowledge_section, + ) + try: + result = self.llm.chat_json(prompt) + if not isinstance(result, list): + raise ValueError("LLM 返回结果不是数组") + return result + except Exception as e: + raise RuntimeError(f"模块分类失败: {e}") + + # ══════════════════════════════════════════════════ + # 函数签名生成 # ══════════════════════════════════════════════════ def build_function_signature( self, - func_req: FunctionalRequirement, - knowledge: str = "", + func_req: FunctionalRequirement, + requirement_id: str = "", + knowledge: str = "", ) -> dict: """ - 为单个功能需求生成函数接口签名 JSON + 为单个功能需求生成函数签名 dict。 Args: - func_req: 功能需求对象(需含有效 id) - knowledge: 知识库文本(可选) + func_req: 功能需求对象 + requirement_id: 需求编号字符串(如 "REQ.01") + knowledge: 知识库文本 Returns: - 符合接口规范的 dict,包含 name/requirement_id/description/ - type/parameters/return 字段 + 函数签名 dict Raises: - json.JSONDecodeError: LLM 返回非合法 JSON + RuntimeError: LLM 调用或解析失败 """ - requirement_id = self._format_requirement_id(func_req.index_no) - knowledge_section = self._build_knowledge_section(knowledge) - + knowledge_section = f"【参考知识库】\n{knowledge}\n" if knowledge else "" prompt = config.FUNC_SIGNATURE_PROMPT_TEMPLATE.format( - knowledge_section=knowledge_section, - requirement_id=requirement_id, - title=func_req.title, - function_name=func_req.function_name, - description=func_req.description, + requirement_id = requirement_id or f"REQ.{func_req.index_no:02d}", + title = func_req.title, + description = func_req.description, + function_name = func_req.function_name, + module = func_req.module or config.DEFAULT_MODULE, + knowledge_section = knowledge_section, ) - - signature = self.llm.chat_json( - system_prompt=( - "你是一位资深软件架构师,专注于 API 接口设计。" - "只输出合法 JSON,不添加任何说明文字。" - ), - user_prompt=prompt, - ) - - # 确保关键字段存在,做兜底处理 - signature.setdefault("name", func_req.function_name) - signature.setdefault("requirement_id", requirement_id) - signature.setdefault("description", func_req.description) - signature.setdefault("type", "function") - signature.setdefault("parameters", {}) - signature.setdefault("return", {"type": "void", "description": ""}) - - return signature + try: + sig = self.llm.chat_json(prompt) + if not isinstance(sig, dict): + raise ValueError("LLM 返回结果不是 dict") + # 确保 module 字段存在 + if "module" not in sig: + sig["module"] = func_req.module or config.DEFAULT_MODULE + return sig + except Exception as e: + raise RuntimeError(f"签名生成失败 [{func_req.function_name}]: {e}") def build_function_signatures_batch( self, - func_reqs: List[FunctionalRequirement], - knowledge: str = "", - on_progress=None, + func_reqs: List[FunctionalRequirement], + knowledge: str = "", + on_progress: Optional[Callable] = None, ) -> List[dict]: """ - 批量为功能需求列表生成函数接口签名 + 批量生成函数签名,失败时使用降级结构。 Args: func_reqs: 功能需求列表 - knowledge: 知识库文本(可选) - on_progress: 进度回调 fn(index, total, func_req, signature, error) + knowledge: 知识库文本 + on_progress: 进度回调 fn(index, total, req, signature, error) Returns: - 签名 dict 列表,顺序与 func_reqs 一致; - 生成失败的条目使用降级结构填充,不中断整体流程 + 与 func_reqs 等长的签名 dict 列表(索引一一对应) """ - results = [] - total = len(func_reqs) + signatures = [] + total = len(func_reqs) - for i, req in enumerate(func_reqs): + for i, req in enumerate(func_reqs, 1): + req_id = f"REQ.{req.index_no:02d}" try: - sig = self.build_function_signature(req, knowledge) - results.append(sig) - if on_progress: - on_progress(i + 1, total, req, sig, None) + sig = self.build_function_signature(req, req_id, knowledge) + error = None except Exception as e: - # 降级:用基础信息填充,保证 JSON 完整性 - fallback = self._build_fallback_signature(req) - results.append(fallback) - if on_progress: - on_progress(i + 1, total, req, fallback, e) + sig = self._fallback_signature(req, req_id) + error = e - return results + signatures.append(sig) + if on_progress: + on_progress(i, total, req, sig, error) - # ══════════════════════════════════════════════════ - # 私有工具方法 - # ══════════════════════════════════════════════════ + return signatures @staticmethod - def _build_knowledge_section(knowledge: str) -> str: - """构建知识库 Prompt 段落""" - if not knowledge or not knowledge.strip(): - return "" - return f"""## 参考知识库 -{knowledge} - ---- -""" - - @staticmethod - def _sanitize_function_name(name: str) -> str: - """ - 清理函数名,确保符合 snake_case 规范 - - Args: - name: 原始函数名 - - Returns: - 合法的 snake_case 函数名 - """ - name = re.sub(r"[^a-zA-Z0-9_]", "_", name).lower() - name = re.sub(r"_+", "_", name).strip("_") - if name and name[0].isdigit(): - name = "func_" + name - return name or "unnamed_function" - - @staticmethod - def _format_requirement_id(index_no: int) -> str: - """ - 将序号格式化为需求编号字符串 - - Args: - index_no: 功能需求序号(从 1 开始) - - Returns: - 格式化编号,如 'REQ.01'、'REQ.12' - """ - return f"REQ.{index_no:02d}" - - @staticmethod - def _build_fallback_signature(func_req: FunctionalRequirement) -> dict: - """ - 构建降级签名(LLM 调用失败时使用) - - Args: - func_req: 功能需求对象 - - Returns: - 包含基础信息的签名 dict - """ + def _fallback_signature( + req: FunctionalRequirement, + requirement_id: str, + ) -> dict: + """生成降级签名结构(LLM 失败时使用)""" return { - "name": func_req.function_name, - "requirement_id": f"REQ.{func_req.index_no:02d}", - "description": func_req.description, + "name": req.function_name, + "requirement_id": requirement_id, + "description": req.description, "type": "function", + "module": req.module or config.DEFAULT_MODULE, "parameters": {}, - "return": { - "type": "void", - "description": "TODO: define return value" + "return": { + "type": "any", + "on_success": {"value": "...", "description": "成功时返回值"}, + "on_failure": {"value": "None", "description": "失败时返回 None"}, }, - "_note": "Auto-generated fallback due to LLM error" } \ No newline at end of file diff --git a/requirements_generator/database/db_manager.py b/requirements_generator/database/db_manager.py index 95ad904..e5e56a7 100644 --- a/requirements_generator/database/db_manager.py +++ b/requirements_generator/database/db_manager.py @@ -1,314 +1,152 @@ -# database/db_manager.py - 数据库操作管理器 -import sqlite3 +# database/db_manager.py - 数据库 CRUD 操作封装 import os -from datetime import datetime from typing import List, Optional -from contextlib import contextmanager -from database.models import ( - CREATE_TABLES_SQL, Project, RawRequirement, - FunctionalRequirement, CodeFile -) +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, Session + import config +from database.models import Base, Project, RawRequirement, FunctionalRequirement, CodeFile class DBManager: """SQLite 数据库管理器,封装所有 CRUD 操作""" - def __init__(self, db_path: str = config.DB_PATH): - self.db_path = db_path + def __init__(self, db_path: str = None): + db_path = db_path or config.DB_PATH os.makedirs(os.path.dirname(db_path), exist_ok=True) - self._init_db() + self.engine = create_engine(f"sqlite:///{db_path}", echo=False) + Base.metadata.create_all(self.engine) + self._Session = sessionmaker(bind=self.engine) - # ── 连接上下文管理器 ────────────────────────────────── - - @contextmanager - def _get_conn(self): - """获取数据库连接(自动提交/回滚)""" - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA foreign_keys = ON") - try: - yield conn - conn.commit() - except Exception: - conn.rollback() - raise - finally: - conn.close() - - def _init_db(self): - """初始化数据库,创建所有表""" - with self._get_conn() as conn: - conn.executescript(CREATE_TABLES_SQL) + def _session(self) -> Session: + return self._Session() # ══════════════════════════════════════════════════ - # Project CRUD + # Project # ══════════════════════════════════════════════════ def create_project(self, project: Project) -> int: - """创建项目,返回新项目 ID""" - sql = """ - INSERT INTO projects (name, description, language, output_dir, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?) - """ - with self._get_conn() as conn: - cur = conn.execute(sql, ( - project.name, project.description, project.language, - project.output_dir, project.created_at, project.updated_at - )) - return cur.lastrowid - - def get_project_by_id(self, project_id: int) -> Optional[Project]: - """根据 ID 查询项目""" - with self._get_conn() as conn: - row = conn.execute( - "SELECT * FROM projects WHERE id = ?", (project_id,) - ).fetchone() - if row is None: - return None - return Project( - id=row["id"], name=row["name"], description=row["description"], - language=row["language"], output_dir=row["output_dir"], - created_at=row["created_at"], updated_at=row["updated_at"] - ) + with self._session() as s: + s.add(project) + s.commit() + s.refresh(project) + return project.id def get_project_by_name(self, name: str) -> Optional[Project]: - """根据名称查询项目""" - with self._get_conn() as conn: - row = conn.execute( - "SELECT * FROM projects WHERE name = ?", (name,) - ).fetchone() - if row is None: - return None - return Project( - id=row["id"], name=row["name"], description=row["description"], - language=row["language"], output_dir=row["output_dir"], - created_at=row["created_at"], updated_at=row["updated_at"] - ) + with self._session() as s: + return s.query(Project).filter_by(name=name).first() - def list_projects(self) -> List[Project]: - """列出所有项目""" - with self._get_conn() as conn: - rows = conn.execute( - "SELECT * FROM projects ORDER BY created_at DESC" - ).fetchall() - return [ - Project( - id=r["id"], name=r["name"], description=r["description"], - language=r["language"], output_dir=r["output_dir"], - created_at=r["created_at"], updated_at=r["updated_at"] - ) for r in rows - ] + def get_project_by_id(self, project_id: int) -> Optional[Project]: + with self._session() as s: + return s.get(Project, project_id) def update_project(self, project: Project) -> None: - """更新项目信息""" - project.updated_at = datetime.now().isoformat() - sql = """ - UPDATE projects - SET name=?, description=?, language=?, output_dir=?, updated_at=? - WHERE id=? - """ - with self._get_conn() as conn: - conn.execute(sql, ( - project.name, project.description, project.language, - project.output_dir, project.updated_at, project.id - )) + with self._session() as s: + s.merge(project) + s.commit() - def delete_project(self, project_id: int) -> None: - """删除项目(级联删除所有关联数据)""" - with self._get_conn() as conn: - conn.execute("DELETE FROM projects WHERE id = ?", (project_id,)) + def list_projects(self) -> List[Project]: + with self._session() as s: + return s.query(Project).order_by(Project.created_at.desc()).all() # ══════════════════════════════════════════════════ - # RawRequirement CRUD + # RawRequirement # ══════════════════════════════════════════════════ - def create_raw_requirement(self, req: RawRequirement) -> int: - """创建原始需求,返回新记录 ID""" - sql = """ - INSERT INTO raw_requirements - (project_id, content, source_type, source_name, knowledge, created_at) - VALUES (?, ?, ?, ?, ?, ?) - """ - with self._get_conn() as conn: - cur = conn.execute(sql, ( - req.project_id, req.content, req.source_type, - req.source_name, req.knowledge, req.created_at - )) - return cur.lastrowid + def create_raw_requirement(self, raw_req: RawRequirement) -> int: + with self._session() as s: + s.add(raw_req) + s.commit() + s.refresh(raw_req) + return raw_req.id - def get_raw_requirement(self, req_id: int) -> Optional[RawRequirement]: - """根据 ID 查询原始需求""" - with self._get_conn() as conn: - row = conn.execute( - "SELECT * FROM raw_requirements WHERE id = ?", (req_id,) - ).fetchone() - if row is None: - return None - return RawRequirement( - id=row["id"], project_id=row["project_id"], content=row["content"], - source_type=row["source_type"], source_name=row["source_name"], - knowledge=row["knowledge"], created_at=row["created_at"] - ) - - def list_raw_requirements_by_project(self, project_id: int) -> List[RawRequirement]: - """查询项目下所有原始需求""" - with self._get_conn() as conn: - rows = conn.execute( - "SELECT * FROM raw_requirements WHERE project_id = ? ORDER BY created_at", - (project_id,) - ).fetchall() - return [ - RawRequirement( - id=r["id"], project_id=r["project_id"], content=r["content"], - source_type=r["source_type"], source_name=r["source_name"], - knowledge=r["knowledge"], created_at=r["created_at"] - ) for r in rows - ] + def get_raw_requirement(self, raw_req_id: int) -> Optional[RawRequirement]: + with self._session() as s: + return s.get(RawRequirement, raw_req_id) # ══════════════════════════════════════════════════ - # FunctionalRequirement CRUD + # FunctionalRequirement # ══════════════════════════════════════════════════ def create_functional_requirement(self, req: FunctionalRequirement) -> int: - """创建功能需求,返回新记录 ID""" - sql = """ - INSERT INTO functional_requirements - (project_id, raw_req_id, index_no, title, description, - function_name, priority, status, is_custom, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """ - with self._get_conn() as conn: - cur = conn.execute(sql, ( - req.project_id, req.raw_req_id, req.index_no, - req.title, req.description, req.function_name, - req.priority, req.status, int(req.is_custom), - req.created_at, req.updated_at - )) - return cur.lastrowid + with self._session() as s: + s.add(req) + s.commit() + s.refresh(req) + return req.id def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]: - """根据 ID 查询功能需求""" - with self._get_conn() as conn: - row = conn.execute( - "SELECT * FROM functional_requirements WHERE id = ?", (req_id,) - ).fetchone() - if row is None: - return None - return self._row_to_func_req(row) + with self._session() as s: + return s.get(FunctionalRequirement, req_id) def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]: - """查询项目下所有功能需求(按序号排序)""" - with self._get_conn() as conn: - rows = conn.execute( - """SELECT * FROM functional_requirements - WHERE project_id = ? ORDER BY index_no""", - (project_id,) - ).fetchall() - return [self._row_to_func_req(r) for r in rows] - - def update_functional_requirement(self, req: FunctionalRequirement) -> None: - """更新功能需求""" - req.updated_at = datetime.now().isoformat() - sql = """ - UPDATE functional_requirements - SET title=?, description=?, function_name=?, priority=?, - status=?, index_no=?, updated_at=? - WHERE id=? - """ - with self._get_conn() as conn: - conn.execute(sql, ( - req.title, req.description, req.function_name, - req.priority, req.status, req.index_no, - req.updated_at, req.id - )) - - def delete_functional_requirement(self, req_id: int) -> None: - """删除功能需求""" - with self._get_conn() as conn: - conn.execute( - "DELETE FROM functional_requirements WHERE id = ?", (req_id,) + with self._session() as s: + return ( + s.query(FunctionalRequirement) + .filter_by(project_id=project_id) + .order_by(FunctionalRequirement.index_no) + .all() ) - def _row_to_func_req(self, row) -> FunctionalRequirement: - """sqlite Row → FunctionalRequirement 对象""" - return FunctionalRequirement( - id=row["id"], project_id=row["project_id"], - raw_req_id=row["raw_req_id"], index_no=row["index_no"], - title=row["title"], description=row["description"], - function_name=row["function_name"], priority=row["priority"], - status=row["status"], is_custom=bool(row["is_custom"]), - created_at=row["created_at"], updated_at=row["updated_at"] - ) + def update_functional_requirement(self, req: FunctionalRequirement) -> None: + with self._session() as s: + s.merge(req) + s.commit() - # ══════════════════════════════════════════════════ - # CodeFile CRUD - # ══════════════════════════════════════════════════ + def delete_functional_requirement(self, req_id: int) -> None: + with self._session() as s: + obj = s.get(FunctionalRequirement, req_id) + if obj: + s.delete(obj) + s.commit() - def create_code_file(self, code_file: CodeFile) -> int: - """创建代码文件记录,返回新记录 ID""" - sql = """ - INSERT INTO code_files - (project_id, func_req_id, file_name, file_path, - language, content, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + def bulk_update_modules(self, updates: List[dict]) -> None: """ - with self._get_conn() as conn: - cur = conn.execute(sql, ( - code_file.project_id, code_file.func_req_id, - code_file.file_name, code_file.file_path, - code_file.language, code_file.content, - code_file.created_at, code_file.updated_at - )) - return cur.lastrowid + 批量更新功能需求的 module 字段。 + + Args: + updates: [{"function_name": "...", "module": "..."}, ...] + """ + with self._session() as s: + name_to_module = {u["function_name"]: u["module"] for u in updates} + reqs = s.query(FunctionalRequirement).filter( + FunctionalRequirement.function_name.in_(name_to_module.keys()) + ).all() + for req in reqs: + req.module = name_to_module.get(req.function_name, config.DEFAULT_MODULE) + s.commit() + + # ══════════════════════════════════════════════════ + # CodeFile + # ══════════════════════════════════════════════════ def upsert_code_file(self, code_file: CodeFile) -> int: - """插入或更新代码文件(按 func_req_id 唯一键)""" - existing = self.get_code_file_by_func_req(code_file.func_req_id) - if existing: - code_file.id = existing.id - code_file.updated_at = datetime.now().isoformat() - sql = """ - UPDATE code_files - SET file_name=?, file_path=?, language=?, content=?, updated_at=? - WHERE id=? - """ - with self._get_conn() as conn: - conn.execute(sql, ( - code_file.file_name, code_file.file_path, - code_file.language, code_file.content, - code_file.updated_at, code_file.id - )) - return code_file.id - else: - return self.create_code_file(code_file) + with self._session() as s: + existing = ( + s.query(CodeFile) + .filter_by(func_req_id=code_file.func_req_id) + .first() + ) + if existing: + existing.file_name = code_file.file_name + existing.file_path = code_file.file_path + existing.module = code_file.module + existing.language = code_file.language + existing.content = code_file.content + s.commit() + return existing.id + else: + s.add(code_file) + s.commit() + s.refresh(code_file) + return code_file.id - def get_code_file_by_func_req(self, func_req_id: int) -> Optional[CodeFile]: - """根据功能需求 ID 查询代码文件""" - with self._get_conn() as conn: - row = conn.execute( - "SELECT * FROM code_files WHERE func_req_id = ?", (func_req_id,) - ).fetchone() - if row is None: - return None - return self._row_to_code_file(row) - - def list_code_files_by_project(self, project_id: int) -> List[CodeFile]: - """查询项目下所有代码文件""" - with self._get_conn() as conn: - rows = conn.execute( - "SELECT * FROM code_files WHERE project_id = ? ORDER BY created_at", - (project_id,) - ).fetchall() - return [self._row_to_code_file(r) for r in rows] - - def _row_to_code_file(self, row) -> CodeFile: - """sqlite Row → CodeFile 对象""" - return CodeFile( - id=row["id"], project_id=row["project_id"], - func_req_id=row["func_req_id"], file_name=row["file_name"], - file_path=row["file_path"], language=row["language"], - content=row["content"], created_at=row["created_at"], - updated_at=row["updated_at"] - ) \ No newline at end of file + def list_code_files(self, project_id: int) -> List[CodeFile]: + with self._session() as s: + return ( + s.query(CodeFile) + .join(FunctionalRequirement) + .filter(FunctionalRequirement.project_id == project_id) + .all() + ) \ No newline at end of file diff --git a/requirements_generator/database/models.py b/requirements_generator/database/models.py index b077013..d4147cd 100644 --- a/requirements_generator/database/models.py +++ b/requirements_generator/database/models.py @@ -1,122 +1,95 @@ -# database/models.py - 数据模型定义(SQLite 建表 DDL) -from dataclasses import dataclass, field +# database/models.py - SQLAlchemy ORM 数据模型 from datetime import datetime -from typing import Optional +from sqlalchemy import ( + Column, Integer, String, Text, Boolean, + DateTime, ForeignKey, Enum, +) +from sqlalchemy.orm import declarative_base, relationship -# ══════════════════════════════════════════════════════ -# DDL 建表语句 -# ══════════════════════════════════════════════════════ - -CREATE_TABLES_SQL = """ -PRAGMA foreign_keys = ON; - --- 项目表 -CREATE TABLE IF NOT EXISTS projects ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, -- 项目名称 - description TEXT, -- 项目描述 - language TEXT NOT NULL DEFAULT 'python', -- 目标代码语言 - output_dir TEXT NOT NULL, -- 输出目录路径 - created_at TEXT NOT NULL, - updated_at TEXT NOT NULL -); - --- 原始需求表 -CREATE TABLE IF NOT EXISTS raw_requirements ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - project_id INTEGER NOT NULL, -- 关联项目 - content TEXT NOT NULL, -- 需求原文 - source_type TEXT NOT NULL DEFAULT 'text', -- text | file - source_name TEXT, -- 文件名(文件输入时) - knowledge TEXT, -- 合并后的知识库内容 - created_at TEXT NOT NULL, - FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE -); - --- 功能需求表 -CREATE TABLE IF NOT EXISTS functional_requirements ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - project_id INTEGER NOT NULL, - raw_req_id INTEGER NOT NULL, -- 关联原始需求 - index_no INTEGER NOT NULL, -- 序号 - title TEXT NOT NULL, -- 功能标题 - description TEXT NOT NULL, -- 功能描述 - function_name TEXT NOT NULL, -- 对应函数名 - priority TEXT NOT NULL DEFAULT 'medium', -- high|medium|low - status TEXT NOT NULL DEFAULT 'pending', -- pending|generated|skipped - is_custom INTEGER NOT NULL DEFAULT 0, -- 是否用户自定义 (0/1) - created_at TEXT NOT NULL, - updated_at TEXT NOT NULL, - FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, - FOREIGN KEY (raw_req_id) REFERENCES raw_requirements(id) ON DELETE CASCADE -); - --- 代码文件表 -CREATE TABLE IF NOT EXISTS code_files ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - project_id INTEGER NOT NULL, - func_req_id INTEGER NOT NULL UNIQUE, -- 关联功能需求(1对1) - file_name TEXT NOT NULL, -- 文件名 - file_path TEXT NOT NULL, -- 完整路径 - language TEXT NOT NULL, -- 代码语言 - content TEXT NOT NULL, -- 代码内容 - created_at TEXT NOT NULL, - updated_at TEXT NOT NULL, - FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, - FOREIGN KEY (func_req_id) REFERENCES functional_requirements(id) ON DELETE CASCADE -); -""" - -# ══════════════════════════════════════════════════════ -# 数据类(Python 对象映射) -# ══════════════════════════════════════════════════════ - -@dataclass -class Project: - name: str - output_dir: str - language: str = "python" - description: str = "" - id: Optional[int] = None - created_at: str = field(default_factory=lambda: datetime.now().isoformat()) - updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) +Base = declarative_base() -@dataclass -class RawRequirement: - project_id: int - content: str - source_type: str = "text" # text | file - source_name: Optional[str] = None - knowledge: Optional[str] = None - id: Optional[int] = None - created_at: str = field(default_factory=lambda: datetime.now().isoformat()) +class Project(Base): + """项目表""" + __tablename__ = "projects" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(200), nullable=False, unique=True) + language = Column(String(50), nullable=False, default="python") + description = Column(Text, nullable=True) + output_dir = Column(String(500), nullable=True) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + raw_requirements = relationship("RawRequirement", back_populates="project", cascade="all, delete-orphan") + functional_requirements = relationship("FunctionalRequirement", back_populates="project", cascade="all, delete-orphan") + + def __repr__(self): + return f"" -@dataclass -class FunctionalRequirement: - project_id: int - raw_req_id: int - index_no: int - title: str - description: str - function_name: str - priority: str = "medium" - status: str = "pending" - is_custom: bool = False - id: Optional[int] = None - created_at: str = field(default_factory=lambda: datetime.now().isoformat()) - updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) +class RawRequirement(Base): + """原始需求表""" + __tablename__ = "raw_requirements" + + id = Column(Integer, primary_key=True, autoincrement=True) + project_id = Column(Integer, ForeignKey("projects.id"), nullable=False) + content = Column(Text, nullable=False) + source_type = Column(String(20), nullable=False, default="text") # text / file + source_name = Column(String(200), nullable=True) + knowledge = Column(Text, nullable=True) + created_at = Column(DateTime, default=datetime.utcnow) + + project = relationship("Project", back_populates="raw_requirements") + functional_requirements = relationship("FunctionalRequirement", back_populates="raw_requirement", cascade="all, delete-orphan") + + def __repr__(self): + return f"" -@dataclass -class CodeFile: - project_id: int - func_req_id: int - file_name: str - file_path: str - language: str - content: str - id: Optional[int] = None - created_at: str = field(default_factory=lambda: datetime.now().isoformat()) - updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) \ No newline at end of file +class FunctionalRequirement(Base): + """功能需求表""" + __tablename__ = "functional_requirements" + + id = Column(Integer, primary_key=True, autoincrement=True) + project_id = Column(Integer, ForeignKey("projects.id"), nullable=False) + raw_req_id = Column(Integer, ForeignKey("raw_requirements.id"), nullable=False) + index_no = Column(Integer, nullable=False) + title = Column(String(200), nullable=False) + description = Column(Text, nullable=False) + function_name = Column(String(200), nullable=False) + priority = Column(Enum("high", "medium", "low"), default="medium") + module = Column(String(100), nullable=True, default="default") # 功能模块 + status = Column(String(50), nullable=False, default="pending") # pending / generated / failed + is_custom = Column(Boolean, nullable=False, default=False) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + project = relationship("Project", back_populates="functional_requirements") + raw_requirement = relationship("RawRequirement", back_populates="functional_requirements") + code_files = relationship("CodeFile", back_populates="functional_requirement", cascade="all, delete-orphan") + + def __repr__(self): + return ( + f"" + ) + + +class CodeFile(Base): + """生成代码文件表""" + __tablename__ = "code_files" + + id = Column(Integer, primary_key=True, autoincrement=True) + func_req_id = Column(Integer, ForeignKey("functional_requirements.id"), nullable=False) + file_name = Column(String(200), nullable=False) + file_path = Column(String(500), nullable=False) + module = Column(String(100), nullable=True) # 冗余存储,方便查询 + language = Column(String(50), nullable=False) + content = Column(Text, nullable=True) + generated_at = Column(DateTime, default=datetime.utcnow) + + functional_requirement = relationship("FunctionalRequirement", back_populates="code_files") + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/requirements_generator/main.py b/requirements_generator/main.py index fae0d51..36c434c 100644 --- a/requirements_generator/main.py +++ b/requirements_generator/main.py @@ -1,17 +1,8 @@ #!/usr/bin/env python3 -# encoding: utf-8 -# main.py - 主入口:支持交互式 & 非交互式(CLI 参数)两种运行模式 -# -# 交互式: python main.py -# 非交互式:python main.py --non-interactive \ -# --project-name "MyProject" \ -# --language python \ -# --requirement-text "用户管理系统,包含注册、登录、修改密码功能" -# -# 完整参数见:python main.py --help +# main.py - 主入口:支持交互式 & 非交互式两种运行模式 import os import sys -from typing import Dict +from typing import Dict, List import click from rich.console import Console @@ -27,9 +18,9 @@ from core.requirement_analyzer import RequirementAnalyzer from core.code_generator import CodeGenerator from utils.file_handler import read_file_auto, merge_knowledge_files from utils.output_writer import ( - ensure_project_dir, build_project_output_dir, write_project_readme, - write_function_signatures_json, validate_all_signatures, - patch_signatures_with_url, + ensure_project_dir, build_project_output_dir, + write_project_readme, write_function_signatures_json, + validate_all_signatures, patch_signatures_with_url, ) console = Console() @@ -44,20 +35,20 @@ def print_banner(): console.print(Panel.fit( "[bold cyan]🚀 需求分析 & 代码生成工具[/bold cyan]\n" "[dim]Powered by LLM · SQLite · Python[/dim]", - border_style="cyan" + border_style="cyan", )) -def print_functional_requirements(reqs: list): - """以表格形式展示功能需求列表""" +def print_functional_requirements(reqs: List[FunctionalRequirement]): + """以表格形式展示功能需求列表(含模块列)""" table = Table(title="📋 功能需求列表", show_lines=True) - table.add_column("序号", style="cyan", width=6) - table.add_column("ID", style="dim", width=6) - table.add_column("标题", style="bold", width=20) - table.add_column("函数名", width=25) - table.add_column("优先级", width=8) - table.add_column("类型", width=8) - table.add_column("描述", width=40) + table.add_column("序号", style="cyan", width=6) + table.add_column("ID", style="dim", width=6) + table.add_column("模块", style="magenta", width=15) + table.add_column("标题", style="bold", width=18) + table.add_column("函数名", width=25) + table.add_column("优先级", width=8) + table.add_column("描述", width=35) priority_color = {"high": "red", "medium": "yellow", "low": "green"} for req in reqs: @@ -65,53 +56,51 @@ def print_functional_requirements(reqs: list): table.add_row( str(req.index_no), str(req.id) if req.id else "-", + req.module or config.DEFAULT_MODULE, req.title, f"[code]{req.function_name}[/code]", f"[{color}]{req.priority}[/{color}]", - "[magenta]自定义[/magenta]" if req.is_custom else "LLM生成", - req.description[:60] + "..." if len(req.description) > 60 else req.description, + req.description[:50] + "..." if len(req.description) > 50 else req.description, ) console.print(table) -def print_signatures_preview(signatures: list): - """ - 以表格形式预览函数签名列表(含 url 字段) +def print_module_summary(reqs: List[FunctionalRequirement]): + """打印模块分组摘要""" + module_map: Dict[str, List[str]] = {} + for req in reqs: + m = req.module or config.DEFAULT_MODULE + module_map.setdefault(m, []).append(req.function_name) - Args: - signatures: 纯签名列表(顶层文档的 "functions" 字段) - """ + table = Table(title="📦 功能模块分组", show_lines=True) + table.add_column("模块", style="magenta bold", width=20) + table.add_column("函数数量", style="cyan", width=8) + table.add_column("函数列表", width=50) + for module, funcs in sorted(module_map.items()): + table.add_row(module, str(len(funcs)), ", ".join(funcs)) + console.print(table) + + +def print_signatures_preview(signatures: List[dict]): + """以表格形式预览函数签名列表(含 module / url 列)""" table = Table(title="📄 函数签名预览", show_lines=True) - table.add_column("需求编号", style="cyan", width=10) - table.add_column("函数名", style="bold", width=25) - table.add_column("参数数量", width=8) - table.add_column("返回类型", width=10) - table.add_column("成功返回值", width=18) - table.add_column("失败返回值", width=18) - table.add_column("URL", style="dim", width=30) - - def _fmt_value(v) -> str: - if v is None: - return "-" - if isinstance(v, dict): - return "{" + ", ".join(v.keys()) + "}" - return str(v)[:16] + table.add_column("需求编号", style="cyan", width=8) + table.add_column("模块", style="magenta", width=15) + table.add_column("函数名", style="bold", width=22) + table.add_column("参数数", width=6) + table.add_column("返回类型", width=10) + table.add_column("URL", style="dim", width=28) for sig in signatures: - ret = sig.get("return") or {} - on_success = ret.get("on_success") or {} - on_failure = ret.get("on_failure") or {} - url = sig.get("url", "") - # 只显示文件名部分,避免路径过长 + ret = sig.get("return") or {} + url = sig.get("url", "") url_display = os.path.basename(url) if url else "[dim]待生成[/dim]" - table.add_row( sig.get("requirement_id", "-"), + sig.get("module", "-"), sig.get("name", "-"), str(len(sig.get("parameters", {}))), ret.get("type", "void"), - _fmt_value(on_success.get("value")), - _fmt_value(on_failure.get("value")), url_display, ) console.print(table) @@ -122,44 +111,44 @@ def print_signatures_preview(signatures: list): # ══════════════════════════════════════════════════════ def step_init_project( - project_name: str = None, - language: str = None, - description: str = "", + project_name: str = None, + language: str = None, + description: str = "", non_interactive: bool = False, ) -> Project: + console.print( + "\n[bold]Step 1 · 项目配置[/bold]" + + (" [dim](非交互)[/dim]" if non_interactive else ""), + style="blue", + ) if not non_interactive: - console.print("\n[bold]Step 1 · 项目配置[/bold]", style="blue") - project_name = project_name or Prompt.ask("📁 请输入项目名称") + project_name = project_name or Prompt.ask("📁 项目名称") language = language or Prompt.ask( - "💻 目标代码语言", - default=config.DEFAULT_LANGUAGE, - choices=["python", "javascript", "typescript", "java", "go", "rust"], + "💻 目标语言", default=config.DEFAULT_LANGUAGE, + choices=["python","javascript","typescript","java","go","rust"], ) description = description or Prompt.ask("📝 项目描述(可选)", default="") else: if not project_name: raise ValueError("非交互模式下 --project-name 为必填项") language = language or config.DEFAULT_LANGUAGE - console.print("\n[bold]Step 1 · 项目配置[/bold] [dim](非交互)[/dim]", style="blue") - console.print(f" 项目名称: {project_name} 语言: {language}") + console.print(f" 项目: {project_name} 语言: {language}") existing = db.get_project_by_name(project_name) if existing: if non_interactive: - console.print(f"[green]✓ 已加载已有项目: {project_name} (ID={existing.id})[/green]") + console.print(f"[green]✓ 已加载项目: {project_name} (ID={existing.id})[/green]") return existing - use_existing = Confirm.ask(f"⚠️ 项目 '{project_name}' 已存在,是否继续使用?") - if use_existing: + if Confirm.ask(f"⚠️ 项目 '{project_name}' 已存在,继续使用?"): console.print(f"[green]✓ 已加载项目: {project_name} (ID={existing.id})[/green]") return existing project_name = Prompt.ask("请输入新的项目名称") - output_dir = build_project_output_dir(project_name) project = Project( - name=project_name, - language=language, - output_dir=output_dir, - description=description, + name = project_name, + language = language, + output_dir = build_project_output_dir(project_name), + description = description, ) project.id = db.create_project(project) console.print(f"[green]✓ 项目已创建: {project_name} (ID={project.id})[/green]") @@ -171,18 +160,17 @@ def step_init_project( # ══════════════════════════════════════════════════════ def step_input_requirement( - project: Project, - requirement_text: str = None, - requirement_file: str = None, - knowledge_files: list = None, - non_interactive: bool = False, + project: Project, + requirement_text: str = None, + requirement_file: str = None, + knowledge_files: list = None, + non_interactive: bool = False, ) -> tuple: console.print( - f"\n[bold]Step 2 · 输入原始需求[/bold]" + "\n[bold]Step 2 · 输入原始需求[/bold]" + (" [dim](非交互)[/dim]" if non_interactive else ""), style="blue", ) - raw_text = "" source_name = None source_type = "text" @@ -195,12 +183,11 @@ def step_input_requirement( console.print(f" 需求文件: {source_name} ({len(raw_text)} 字符)") elif requirement_text: raw_text = requirement_text - source_type = "text" - console.print(f" 需求文本: {raw_text[:80]}{'...' if len(raw_text) > 80 else ''}") + console.print(f" 需求文本: {raw_text[:80]}{'...' if len(raw_text)>80 else ''}") else: raise ValueError("非交互模式下必须提供 --requirement-text 或 --requirement-file") else: - input_type = Prompt.ask("📥 需求输入方式", choices=["text", "file"], default="text") + input_type = Prompt.ask("📥 需求输入方式", choices=["text","file"], default="text") if input_type == "text": console.print("[dim]请输入原始需求(输入空行结束):[/dim]") lines = [] @@ -209,14 +196,13 @@ def step_input_requirement( if line == "" and lines: break lines.append(line) - raw_text = "\n".join(lines) - source_type = "text" + raw_text = "\n".join(lines) else: - file_path = Prompt.ask("📂 需求文件路径") - raw_text = read_file_auto(file_path) - source_name = os.path.basename(file_path) + fp = Prompt.ask("📂 需求文件路径") + raw_text = read_file_auto(fp) + source_name = os.path.basename(fp) source_type = "file" - console.print(f"[green]✓ 已读取文件: {source_name} ({len(raw_text)} 字符)[/green]") + console.print(f"[green]✓ 已读取: {source_name} ({len(raw_text)} 字符)[/green]") knowledge_text = "" if non_interactive: @@ -224,18 +210,17 @@ def step_input_requirement( knowledge_text = merge_knowledge_files(list(knowledge_files)) console.print(f" 知识库: {len(knowledge_files)} 个文件,{len(knowledge_text)} 字符") else: - use_kb = Confirm.ask("📚 是否输入知识库文件?", default=False) - if use_kb: + if Confirm.ask("📚 是否输入知识库文件?", default=False): kb_paths = [] while True: - kb_path = Prompt.ask("知识库文件路径(留空结束)", default="") - if not kb_path: + p = Prompt.ask("知识库文件路径(留空结束)", default="") + if not p: break - if os.path.exists(kb_path): - kb_paths.append(kb_path) - console.print(f" [green]+ {kb_path}[/green]") + if os.path.exists(p): + kb_paths.append(p) + console.print(f" [green]+ {p}[/green]") else: - console.print(f" [red]文件不存在: {kb_path}[/red]") + console.print(f" [red]文件不存在: {p}[/red]") if kb_paths: knowledge_text = merge_knowledge_files(kb_paths) console.print(f"[green]✓ 知识库已合并 ({len(knowledge_text)} 字符)[/green]") @@ -248,37 +233,36 @@ def step_input_requirement( # ══════════════════════════════════════════════════════ def step_decompose_requirements( - project: Project, - raw_text: str, - knowledge_text: str, - source_name: str, - source_type: str, + project: Project, + raw_text: str, + knowledge_text: str, + source_name: str, + source_type: str, non_interactive: bool = False, ) -> tuple: console.print( - f"\n[bold]Step 3 · LLM 需求分解[/bold]" + "\n[bold]Step 3 · LLM 需求分解[/bold]" + (" [dim](非交互)[/dim]" if non_interactive else ""), style="blue", ) - raw_req = RawRequirement( - project_id=project.id, - content=raw_text, - source_type=source_type, - source_name=source_name, - knowledge=knowledge_text or None, + project_id = project.id, + content = raw_text, + source_type = source_type, + source_name = source_name, + knowledge = knowledge_text or None, ) raw_req_id = db.create_raw_requirement(raw_req) console.print(f"[dim]原始需求已存储 (ID={raw_req_id})[/dim]") - with console.status("[bold yellow]🤖 LLM 正在分解需求,请稍候...[/bold yellow]"): + with console.status("[bold yellow]🤖 LLM 正在分解需求...[/bold yellow]"): llm = LLMClient() analyzer = RequirementAnalyzer(llm) func_reqs = analyzer.decompose( - raw_requirement=raw_text, - project_id=project.id, - raw_req_id=raw_req_id, - knowledge=knowledge_text, + raw_requirement = raw_text, + project_id = project.id, + raw_req_id = raw_req_id, + knowledge = knowledge_text, ) for req in func_reqs: @@ -289,18 +273,107 @@ def step_decompose_requirements( # ══════════════════════════════════════════════════════ -# Step 4:用户编辑功能需求 +# Step 4:模块分类(可选重新分类) +# ══════════════════════════════════════════════════════ + +def step_classify_modules( + project: Project, + func_reqs: List[FunctionalRequirement], + knowledge_text: str = "", + non_interactive: bool = False, +) -> List[FunctionalRequirement]: + """ + Step 4:对功能需求进行模块分类。 + + - 非交互模式:直接使用 LLM 分类结果 + - 交互模式:展示 LLM 分类结果,允许用户手动调整 + """ + console.print( + "\n[bold]Step 4 · 功能模块分类[/bold]" + + (" [dim](非交互)[/dim]" if non_interactive else ""), + style="blue", + ) + + # LLM 自动分类 + with console.status("[bold yellow]🤖 LLM 正在进行模块分类...[/bold yellow]"): + llm = LLMClient() + analyzer = RequirementAnalyzer(llm) + try: + updates = analyzer.classify_modules(func_reqs, knowledge_text) + # 回写 module 到 func_reqs 对象 + name_to_module = {u["function_name"]: u["module"] for u in updates} + for req in func_reqs: + req.module = name_to_module.get(req.function_name, config.DEFAULT_MODULE) + db.update_functional_requirement(req) + console.print(f"[green]✓ LLM 模块分类完成[/green]") + except Exception as e: + console.print(f"[yellow]⚠ 模块分类失败,保留原有模块: {e}[/yellow]") + + print_module_summary(func_reqs) + + if non_interactive: + return func_reqs + + # 交互式调整 + while True: + console.print( + "\n模块操作: [cyan]r[/cyan]=重新分类 " + "[cyan]e[/cyan]=手动编辑某需求的模块 [cyan]ok[/cyan]=确认继续" + ) + action = Prompt.ask("请选择操作", default="ok").strip().lower() + + if action == "ok": + break + + elif action == "r": + # 重新触发 LLM 分类 + with console.status("[bold yellow]🤖 重新分类中...[/bold yellow]"): + try: + updates = analyzer.classify_modules(func_reqs, knowledge_text) + name_to_module = {u["function_name"]: u["module"] for u in updates} + for req in func_reqs: + req.module = name_to_module.get(req.function_name, config.DEFAULT_MODULE) + db.update_functional_requirement(req) + console.print("[green]✓ 重新分类完成[/green]") + except Exception as e: + console.print(f"[red]重新分类失败: {e}[/red]") + print_module_summary(func_reqs) + + elif action == "e": + print_functional_requirements(func_reqs) + idx_str = Prompt.ask("输入要修改模块的需求序号") + if not idx_str.isdigit(): + continue + idx = int(idx_str) + target = next((r for r in func_reqs if r.index_no == idx), None) + if target is None: + console.print("[red]序号不存在[/red]") + continue + new_module = Prompt.ask( + f"新模块名(当前: {target.module})", + default=target.module or config.DEFAULT_MODULE, + ) + target.module = new_module.strip() or config.DEFAULT_MODULE + db.update_functional_requirement(target) + console.print(f"[green]✓ 已更新 '{target.function_name}' → 模块: {target.module}[/green]") + print_module_summary(func_reqs) + + return func_reqs + + +# ══════════════════════════════════════════════════════ +# Step 5:编辑功能需求 # ══════════════════════════════════════════════════════ def step_edit_requirements( - project: Project, - func_reqs: list, - raw_req_id: int, + project: Project, + func_reqs: List[FunctionalRequirement], + raw_req_id: int, non_interactive: bool = False, - skip_indices: list = None, -) -> list: + skip_indices: list = None, +) -> List[FunctionalRequirement]: console.print( - f"\n[bold]Step 4 · 编辑功能需求[/bold]" + "\n[bold]Step 5 · 编辑功能需求[/bold]" + (" [dim](非交互)[/dim]" if non_interactive else ""), style="blue", ) @@ -334,8 +407,9 @@ def step_edit_requirements( if action == "ok": break + elif action == "d": - idx_str = Prompt.ask("输入要删除的功能需求序号(多个用逗号分隔)") + idx_str = Prompt.ask("输入要删除的序号(多个用逗号分隔)") to_delete = {int(x.strip()) for x in idx_str.split(",") if x.strip().isdigit()} removed, kept = [], [] for req in func_reqs: @@ -349,28 +423,35 @@ def step_edit_requirements( req.index_no = i db.update_functional_requirement(req) console.print(f"[red]✗ 已删除: {', '.join(removed)}[/red]") + elif action == "a": - title = Prompt.ask("功能标题") - description = Prompt.ask("功能描述") - func_name = Prompt.ask("函数名 (snake_case)") - priority = Prompt.ask( - "优先级", choices=["high", "medium", "low"], default="medium" + title = Prompt.ask("功能标题") + description = Prompt.ask("功能描述") + func_name = Prompt.ask("函数名 (snake_case)") + priority = Prompt.ask( + "优先级", choices=["high","medium","low"], default="medium" + ) + module = Prompt.ask( + "所属模块(snake_case,留空使用默认)", + default=config.DEFAULT_MODULE, ) new_req = FunctionalRequirement( - project_id=project.id, - raw_req_id=raw_req_id, - index_no=len(func_reqs) + 1, - title=title, - description=description, - function_name=func_name, - priority=priority, - is_custom=True, + project_id = project.id, + raw_req_id = raw_req_id, + index_no = len(func_reqs) + 1, + title = title, + description = description, + function_name = func_name, + priority = priority, + module = module.strip() or config.DEFAULT_MODULE, + is_custom = True, ) new_req.id = db.create_functional_requirement(new_req) func_reqs.append(new_req) - console.print(f"[green]✓ 已添加自定义需求: {title}[/green]") + console.print(f"[green]✓ 已添加: {title} → 模块: {new_req.module}[/green]") + elif action == "e": - idx_str = Prompt.ask("输入要编辑的功能需求序号") + idx_str = Prompt.ask("输入要编辑的序号") if not idx_str.isdigit(): continue idx = int(idx_str) @@ -382,8 +463,11 @@ def step_edit_requirements( target.description = Prompt.ask("新描述", default=target.description) target.function_name = Prompt.ask("新函数名", default=target.function_name) target.priority = Prompt.ask( - "新优先级", choices=["high", "medium", "low"], default=target.priority + "新优先级", choices=["high","medium","low"], default=target.priority ) + target.module = Prompt.ask( + "新模块", default=target.module or config.DEFAULT_MODULE + ).strip() or config.DEFAULT_MODULE db.update_functional_requirement(target) console.print(f"[green]✓ 已更新: {target.title}[/green]") @@ -391,33 +475,24 @@ def step_edit_requirements( # ══════════════════════════════════════════════════════ -# Step 5A:生成函数签名 JSON(不含 url 字段,待 5C 回写) +# Step 6A:生成函数签名 JSON(初版,不含 url) # ══════════════════════════════════════════════════════ def step_generate_signatures( - project: Project, - func_reqs: list, - output_dir: str, - knowledge_text: str, - json_file_name: str = "function_signatures.json", + project: Project, + func_reqs: List[FunctionalRequirement], + output_dir: str, + knowledge_text: str, + json_file_name: str = "function_signatures.json", non_interactive: bool = False, ) -> tuple: - """ - 为所有功能需求生成函数签名,写入初版 JSON(不含 url 字段)。 - url 字段将在 Step 5C 代码生成完成后回写并刷新 JSON 文件。 - - Returns: - (signatures: List[dict], json_path: str) - """ console.print( - f"\n[bold]Step 5A · 生成函数签名 JSON[/bold]" + "\n[bold]Step 6A · 生成函数签名 JSON[/bold]" + (" [dim](非交互)[/dim]" if non_interactive else ""), style="blue", ) - llm = LLMClient() analyzer = RequirementAnalyzer(llm) - success_count = 0 fail_count = 0 @@ -425,42 +500,40 @@ def step_generate_signatures( nonlocal success_count, fail_count if error: console.print( - f" [{index}/{total}] [yellow]⚠ {req.title} 签名生成失败," - f"使用降级结构: {error}[/yellow]" + f" [{index}/{total}] [yellow]⚠ {req.title} 签名生成失败" + f"(降级): {error}[/yellow]" ) fail_count += 1 else: console.print( f" [{index}/{total}] [green]✓ {req.title}[/green] " - f"→ [dim]{signature.get('name')}()[/dim] " - f"params={len(signature.get('parameters', {}))}" + f"[dim]{req.module}[/dim] → {signature.get('name')}()" ) success_count += 1 console.print(f"[yellow]正在为 {len(func_reqs)} 个功能需求生成函数签名...[/yellow]") signatures = analyzer.build_function_signatures_batch( - func_reqs=func_reqs, - knowledge=knowledge_text, - on_progress=on_progress, + func_reqs = func_reqs, + knowledge = knowledge_text, + on_progress = on_progress, ) # 校验 - validation_report = validate_all_signatures(signatures) - if validation_report: - console.print(f"[yellow]⚠ 发现 {len(validation_report)} 个签名存在结构问题:[/yellow]") - for fname, errors in validation_report.items(): - for err in errors: + report = validate_all_signatures(signatures) + if report: + console.print(f"[yellow]⚠ {len(report)} 个签名存在结构问题:[/yellow]") + for fname, errs in report.items(): + for err in errs: console.print(f" [yellow]· {fname}: {err}[/yellow]") else: console.print("[green]✓ 所有签名结构校验通过[/green]") - # 写入初版 JSON(url 字段尚未填入) json_path = write_function_signatures_json( - output_dir=output_dir, - signatures=signatures, - project_name=project.name, - project_description=project.description or "", # ← 传入项目描述 - file_name=json_file_name, + output_dir = output_dir, + signatures = signatures, + project_name = project.name, + project_description = project.description or "", + file_name = json_file_name, ) console.print( f"[green]✓ 签名 JSON 初版已写入: [cyan]{os.path.abspath(json_path)}[/cyan][/green]\n" @@ -470,35 +543,32 @@ def step_generate_signatures( # ══════════════════════════════════════════════════════ -# Step 5B:生成代码文件,收集 {函数名: 文件路径} 映射 +# Step 6B:生成代码文件(按模块写入子目录) # ══════════════════════════════════════════════════════ def step_generate_code( - project: Project, - func_reqs: list, - output_dir: str, - knowledge_text: str, - signatures: list, + project: Project, + func_reqs: List[FunctionalRequirement], + output_dir: str, + knowledge_text: str, + signatures: List[dict], non_interactive: bool = False, ) -> Dict[str, str]: """ - 依据签名约束批量生成代码文件。 + 批量生成代码文件,按 req.module 路由到 output_dir// 子目录。 Returns: - func_name_to_url: {函数名: 代码文件绝对路径} 映射表, - 供 Step 5C 回写 url 字段使用。 - 生成失败的函数不会出现在映射表中。 + func_name_to_url: {函数名: 代码文件绝对路径} """ console.print( - f"\n[bold]Step 5B · 生成代码文件[/bold]" + "\n[bold]Step 6B · 生成代码文件[/bold]" + (" [dim](非交互)[/dim]" if non_interactive else ""), style="blue", ) - - generator = CodeGenerator(LLMClient()) - success_count = 0 - fail_count = 0 - func_name_to_url: Dict[str, str] = {} # ← 收集 函数名 → 文件绝对路径 + generator = CodeGenerator(LLMClient()) + success_count = 0 + fail_count = 0 + func_name_to_url: Dict[str, str] = {} def on_progress(index, total, req, code_file, error): nonlocal success_count, fail_count @@ -509,29 +579,38 @@ def step_generate_code( db.upsert_code_file(code_file) req.status = "generated" db.update_functional_requirement(req) - # 收集 函数名 → 绝对文件路径(作为 url 回写) func_name_to_url[req.function_name] = os.path.abspath(code_file.file_path) console.print( f" [{index}/{total}] [green]✓ {req.title}[/green] " - f"→ [dim]{code_file.file_name}[/dim]" + f"[dim]{req.module}/{code_file.file_name}[/dim]" ) success_count += 1 - console.print(f"[yellow]开始生成 {len(func_reqs)} 个代码文件(签名约束模式)...[/yellow]") + console.print( + f"[yellow]开始生成 {len(func_reqs)} 个代码文件(按模块分目录)...[/yellow]" + ) generator.generate_batch( - func_reqs=func_reqs, - output_dir=output_dir, - language=project.language, - knowledge=knowledge_text, - signatures=signatures, - on_progress=on_progress, + func_reqs = func_reqs, + output_dir = output_dir, + language = project.language, + knowledge = knowledge_text, + signatures = signatures, + on_progress = on_progress, ) + # 生成 README(含模块列表) + modules = list({req.module or config.DEFAULT_MODULE for req in func_reqs}) req_summary = "\n".join( - f"{i+1}. **{r.title}** (`{r.function_name}`) - {r.description[:80]}" + f"{i+1}. **{r.title}** (`{r.module}/{r.function_name}`) - {r.description[:80]}" for i, r in enumerate(func_reqs) ) - write_project_readme(output_dir, project.name, req_summary) + write_project_readme( + output_dir = output_dir, + project_name = project.name, + project_description = project.description or "", + requirements_summary = req_summary, + modules = modules, + ) console.print(Panel( f"[bold green]✅ 代码生成完成![/bold green]\n" @@ -543,65 +622,38 @@ def step_generate_code( # ══════════════════════════════════════════════════════ -# Step 5C:回写 url 字段并刷新 JSON +# Step 6C:回写 url 字段并刷新 JSON # ══════════════════════════════════════════════════════ def step_patch_signatures_url( - project: Project, - signatures: list, + project: Project, + signatures: List[dict], func_name_to_url: Dict[str, str], - output_dir: str, - json_file_name: str, - non_interactive: bool = False, + output_dir: str, + json_file_name: str, + non_interactive: bool = False, ) -> str: - """ - 将代码文件路径回写到签名的 "url" 字段,并重新写入 JSON 文件。 - - 执行流程: - 1. 调用 patch_signatures_with_url() 原地修改签名列表 - 2. 打印最终签名预览(含 url 列) - 3. 重新调用 write_function_signatures_json() 覆盖写入 JSON - - Args: - project: 项目对象(提供 name 与 description) - signatures: Step 5A 产出的签名列表(将被原地修改) - func_name_to_url: Step 5B 收集的 {函数名: 文件绝对路径} 映射 - output_dir: JSON 文件所在目录 - json_file_name: JSON 文件名 - non_interactive: 是否非交互模式 - - Returns: - 刷新后的 JSON 文件绝对路径 - """ console.print( - f"\n[bold]Step 5C · 回写代码文件路径(url)到签名 JSON[/bold]" + "\n[bold]Step 6C · 回写代码路径(url)到签名 JSON[/bold]" + (" [dim](非交互)[/dim]" if non_interactive else ""), style="blue", ) - - # 原地回写 url 字段 patch_signatures_with_url(signatures, func_name_to_url) patched = sum(1 for s in signatures if s.get("url")) unpatched = len(signatures) - patched if unpatched: - console.print( - f"[yellow]⚠ {unpatched} 个函数未能写入 url" - f"(对应代码文件生成失败)[/yellow]" - ) + console.print(f"[yellow]⚠ {unpatched} 个函数 url 未回写(代码生成失败)[/yellow]") - # 打印最终预览(含 url 列) print_signatures_preview(signatures) - # 覆盖写入 JSON(含 project.description) json_path = write_function_signatures_json( - output_dir=output_dir, - signatures=signatures, - project_name=project.name, - project_description=project.description or "", - file_name=json_file_name, + output_dir = output_dir, + signatures = signatures, + project_name = project.name, + project_description = project.description or "", + file_name = json_file_name, ) - console.print( f"[green]✓ 签名 JSON 已更新(含 url): " f"[cyan]{os.path.abspath(json_path)}[/cyan][/green]\n" @@ -615,53 +667,61 @@ def step_patch_signatures_url( # ══════════════════════════════════════════════════════ def run_workflow( - project_name: str = None, - language: str = None, - description: str = "", - requirement_text: str = None, - requirement_file: str = None, - knowledge_files: tuple = (), - skip_indices: list = None, - json_file_name: str = "function_signatures.json", - non_interactive: bool = False, + project_name: str = None, + language: str = None, + description: str = "", + requirement_text: str = None, + requirement_file: str = None, + knowledge_files: tuple = (), + skip_indices: list = None, + json_file_name: str = "function_signatures.json", + non_interactive: bool = False, ): - """完整工作流(Step 1 → 5C)""" + """完整工作流 Step 1 → 6C""" print_banner() - # Step 1 + # Step 1:项目初始化 project = step_init_project( - project_name=project_name, - language=language, - description=description, - non_interactive=non_interactive, + project_name = project_name, + language = language, + description = description, + non_interactive = non_interactive, ) - # Step 2 + # Step 2:输入原始需求 raw_text, knowledge_text, source_name, source_type = step_input_requirement( - project=project, - requirement_text=requirement_text, - requirement_file=requirement_file, - knowledge_files=list(knowledge_files) if knowledge_files else [], - non_interactive=non_interactive, + project = project, + requirement_text = requirement_text, + requirement_file = requirement_file, + knowledge_files = list(knowledge_files) if knowledge_files else [], + non_interactive = non_interactive, ) - # Step 3 + # Step 3:LLM 需求分解 raw_req_id, func_reqs = step_decompose_requirements( - project=project, - raw_text=raw_text, - knowledge_text=knowledge_text, - source_name=source_name, - source_type=source_type, - non_interactive=non_interactive, + project = project, + raw_text = raw_text, + knowledge_text = knowledge_text, + source_name = source_name, + source_type = source_type, + non_interactive = non_interactive, ) - # Step 4 + # Step 4:模块分类 + func_reqs = step_classify_modules( + project = project, + func_reqs = func_reqs, + knowledge_text = knowledge_text, + non_interactive = non_interactive, + ) + + # Step 5:编辑功能需求 func_reqs = step_edit_requirements( - project=project, - func_reqs=func_reqs, - raw_req_id=raw_req_id, - non_interactive=non_interactive, - skip_indices=skip_indices or [], + project = project, + func_reqs = func_reqs, + raw_req_id = raw_req_id, + non_interactive = non_interactive, + skip_indices = skip_indices or [], ) if not func_reqs: @@ -670,40 +730,43 @@ def run_workflow( output_dir = ensure_project_dir(project.name) - # Step 5A:生成签名(初版,不含 url) + # Step 6A:生成函数签名 signatures, json_path = step_generate_signatures( - project=project, - func_reqs=func_reqs, - output_dir=output_dir, - knowledge_text=knowledge_text, - json_file_name=json_file_name, - non_interactive=non_interactive, + project = project, + func_reqs = func_reqs, + output_dir = output_dir, + knowledge_text = knowledge_text, + json_file_name = json_file_name, + non_interactive = non_interactive, ) - # Step 5B:生成代码,收集 {函数名: 文件路径} + # Step 6B:生成代码文件 func_name_to_url = step_generate_code( - project=project, - func_reqs=func_reqs, - output_dir=output_dir, - knowledge_text=knowledge_text, - signatures=signatures, - non_interactive=non_interactive, + project = project, + func_reqs = func_reqs, + output_dir = output_dir, + knowledge_text = knowledge_text, + signatures = signatures, + non_interactive = non_interactive, ) - # Step 5C:回写 url 字段,刷新 JSON + # Step 6C:回写 url,刷新 JSON json_path = step_patch_signatures_url( - project=project, - signatures=signatures, - func_name_to_url=func_name_to_url, - output_dir=output_dir, - json_file_name=json_file_name, - non_interactive=non_interactive, + project = project, + signatures = signatures, + func_name_to_url = func_name_to_url, + output_dir = output_dir, + json_file_name = json_file_name, + non_interactive = non_interactive, ) + # 最终汇总 + modules = sorted({req.module or config.DEFAULT_MODULE for req in func_reqs}) console.print(Panel( f"[bold cyan]🎉 全部流程完成![/bold cyan]\n" f"项目: [bold]{project.name}[/bold]\n" f"描述: {project.description or '(无)'}\n" + f"模块: {', '.join(modules)}\n" f"代码目录: [cyan]{os.path.abspath(output_dir)}[/cyan]\n" f"签名文件: [cyan]{json_path}[/cyan]", border_style="cyan", @@ -715,25 +778,22 @@ def run_workflow( # ══════════════════════════════════════════════════════ @click.command() -@click.option("--non-interactive", is_flag=True, default=False, - help="以非交互模式运行(所有参数通过命令行传入)") +@click.option("--non-interactive", is_flag=True, default=False, + help="以非交互模式运行") @click.option("--project-name", "-p", default=None, help="项目名称") @click.option("--language", "-l", default=None, type=click.Choice(["python","javascript","typescript","java","go","rust"]), help=f"目标代码语言(默认: {config.DEFAULT_LANGUAGE})") @click.option("--description", "-d", default="", help="项目描述") -@click.option("--requirement-text","-r", default=None, - help="原始需求文本(与 --requirement-file 二选一)") +@click.option("--requirement-text","-r", default=None, help="原始需求文本") @click.option("--requirement-file","-f", default=None, - type=click.Path(exists=True), - help="原始需求文件路径(支持 .txt/.md/.pdf/.docx)") + type=click.Path(exists=True), help="原始需求文件路径") @click.option("--knowledge-file", "-k", default=None, multiple=True, - type=click.Path(exists=True), - help="知识库文件路径(可多次指定,如 -k a.md -k b.pdf)") + type=click.Path(exists=True), help="知识库文件(可多次指定)") @click.option("--skip-index", "-s", default=None, multiple=True, type=int, - help="要跳过的功能需求序号(可多次指定,如 -s 2 -s 5)") + help="跳过的功能需求序号(可多次指定)") @click.option("--json-file-name", "-j", default="function_signatures.json", - help="函数签名 JSON 文件名(默认: function_signatures.json)") + help="签名 JSON 文件名") def cli( non_interactive, project_name, language, description, requirement_text, requirement_file, knowledge_file, @@ -743,7 +803,7 @@ def cli( 需求分析 & 代码生成工具 \b - 交互式运行(推荐初次使用): + 交互式运行: python main.py \b @@ -752,28 +812,19 @@ def cli( --project-name "UserSystem" \\ --description "用户管理系统后端服务" \\ --language python \\ - --requirement-text "用户管理系统,包含注册、登录、修改密码功能" \\ - --knowledge-file docs/api_spec.md \\ - --json-file-name api_signatures.json - - \b - 从文件读取需求 + 跳过部分功能需求: - python main.py --non-interactive \\ - --project-name "MyProject" \\ - --requirement-file requirements.md \\ - --skip-index 3 --skip-index 7 + --requirement-text "用户管理系统,包含注册、登录、修改密码功能" """ try: run_workflow( - project_name=project_name, - language=language, - description=description, - requirement_text=requirement_text, - requirement_file=requirement_file, - knowledge_files=knowledge_file, - skip_indices=list(skip_index) if skip_index else [], - json_file_name=json_file_name, - non_interactive=non_interactive, + project_name = project_name, + language = language, + description = description, + requirement_text = requirement_text, + requirement_file = requirement_file, + knowledge_files = knowledge_file, + skip_indices = list(skip_index) if skip_index else [], + json_file_name = json_file_name, + non_interactive = non_interactive, ) except KeyboardInterrupt: console.print("\n[yellow]用户中断,退出[/yellow]") diff --git a/requirements_generator/requirements.txt b/requirements_generator/requirements.txt index e265d13..1674436 100644 --- a/requirements_generator/requirements.txt +++ b/requirements_generator/requirements.txt @@ -1,6 +1,7 @@ -openai>=1.0.0 -python-dotenv>=1.0.0 +openai>=1.30.0 +click>=8.1.0 rich>=13.0.0 -python-docx>=0.8.11 -PyPDF2>=3.0.0 -click>=8.1.0 \ No newline at end of file +sqlalchemy>=2.0.0 +python-dotenv>=1.0.0 +pypdf>=4.0.0 +python-docx>=1.1.0 \ No newline at end of file diff --git a/requirements_generator/run.bat b/requirements_generator/run.bat new file mode 100644 index 0000000..5c0e2ce --- /dev/null +++ b/requirements_generator/run.bat @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +# 安装依赖 +pip install -r requirements.txt + +# 设置环境变量 +set OPENAI_API_KEY="sk-AUmOuFI731Ty5Nob38jY26d8lydfDT-QkE2giqb0sCuPCAE2JH6zjLM4lZLpvL5WMYPOocaMe2FwVDmqM_9KimmKACjR" +set OPENAI_BASE_URL="https://openapi.monica.im/v1" # 或其他兼容接口 +set LLM_MODEL="gpt-4o" + +python main.py diff --git a/requirements_generator/run.sh b/requirements_generator/run.sh old mode 100755 new mode 100644 diff --git a/requirements_generator/utils/file_handler.py b/requirements_generator/utils/file_handler.py index f32738d..298aaf0 100644 --- a/requirements_generator/utils/file_handler.py +++ b/requirements_generator/utils/file_handler.py @@ -1,95 +1,13 @@ -# utils/file_handler.py - 文件读取工具(支持 txt/md/pdf/docx) +# utils/file_handler.py - 文件读取工具(支持 txt / md / pdf / docx) import os -from typing import List, Optional -from pathlib import Path - - -def read_text_file(file_path: str) -> str: - """ - 读取纯文本文件内容(.txt / .md / .py 等) - - Args: - file_path: 文件路径 - - Returns: - 文件文本内容 - - Raises: - FileNotFoundError: 文件不存在 - UnicodeDecodeError: 编码错误时尝试 latin-1 兜底 - """ - path = Path(file_path) - if not path.exists(): - raise FileNotFoundError(f"文件不存在: {file_path}") - try: - return path.read_text(encoding="utf-8") - except UnicodeDecodeError: - return path.read_text(encoding="latin-1") - - -def read_pdf_file(file_path: str) -> str: - """ - 读取 PDF 文件内容 - - Args: - file_path: PDF 文件路径 - - Returns: - 提取的文本内容 - - Raises: - ImportError: 未安装 PyPDF2 - FileNotFoundError: 文件不存在 - """ - try: - import PyPDF2 - except ImportError: - raise ImportError("请安装 PyPDF2: pip install PyPDF2") - - path = Path(file_path) - if not path.exists(): - raise FileNotFoundError(f"文件不存在: {file_path}") - - texts = [] - with open(file_path, "rb") as f: - reader = PyPDF2.PdfReader(f) - for page in reader.pages: - text = page.extract_text() - if text: - texts.append(text) - return "\n".join(texts) - - -def read_docx_file(file_path: str) -> str: - """ - 读取 Word (.docx) 文件内容 - - Args: - file_path: docx 文件路径 - - Returns: - 提取的文本内容(段落合并) - - Raises: - ImportError: 未安装 python-docx - FileNotFoundError: 文件不存在 - """ - try: - from docx import Document - except ImportError: - raise ImportError("请安装 python-docx: pip install python-docx") - - path = Path(file_path) - if not path.exists(): - raise FileNotFoundError(f"文件不存在: {file_path}") - - doc = Document(file_path) - return "\n".join(para.text for para in doc.paragraphs if para.text.strip()) +from typing import List def read_file_auto(file_path: str) -> str: """ - 根据文件扩展名自动选择读取方式 + 自动识别文件类型并读取文本内容。 + + 支持格式:.txt / .md / .pdf / .docx / 其他(按 UTF-8 读取) Args: file_path: 文件路径 @@ -98,45 +16,66 @@ def read_file_auto(file_path: str) -> str: 文件文本内容 Raises: - ValueError: 不支持的文件类型 + FileNotFoundError: 文件不存在 + RuntimeError: 读取失败 """ - ext = Path(file_path).suffix.lower() - readers = { - ".txt": read_text_file, - ".md": read_text_file, - ".py": read_text_file, - ".json": read_text_file, - ".yaml": read_text_file, - ".yml": read_text_file, - ".pdf": read_pdf_file, - ".docx": read_docx_file, - } - reader = readers.get(ext) - if reader is None: - raise ValueError(f"不支持的文件类型: {ext},支持: {list(readers.keys())}") - return reader(file_path) + if not os.path.exists(file_path): + raise FileNotFoundError(f"文件不存在: {file_path}") + + ext = os.path.splitext(file_path)[1].lower() + + try: + if ext == ".pdf": + return _read_pdf(file_path) + elif ext in (".docx", ".doc"): + return _read_docx(file_path) + else: + # txt / md / 其他文本格式 + with open(file_path, "r", encoding="utf-8", errors="replace") as f: + return f.read() + except Exception as e: + raise RuntimeError(f"读取文件失败 [{file_path}]: {e}") + + +def _read_pdf(file_path: str) -> str: + """读取 PDF 文件文本""" + try: + from pypdf import PdfReader + except ImportError: + raise RuntimeError("读取 PDF 需要安装 pypdf:pip install pypdf") + + reader = PdfReader(file_path) + pages = [page.extract_text() or "" for page in reader.pages] + return "\n".join(pages) + + +def _read_docx(file_path: str) -> str: + """读取 Word 文档文本""" + try: + from docx import Document + except ImportError: + raise RuntimeError("读取 docx 需要安装 python-docx:pip install python-docx") + + doc = Document(file_path) + paragraphs = [p.text for p in doc.paragraphs if p.text.strip()] + return "\n".join(paragraphs) def merge_knowledge_files(file_paths: List[str]) -> str: """ - 合并多个知识库文件为单一文本 + 合并多个知识库文件为单一文本。 Args: - file_paths: 知识库文件路径列表 + file_paths: 文件路径列表 Returns: - 合并后的知识库文本(包含文件名分隔符) + 合并后的文本(各文件以分隔线隔开) """ - if not file_paths: - return "" - - sections = [] - for fp in file_paths: + parts = [] + for path in file_paths: try: - content = read_file_auto(fp) - file_name = Path(fp).name - sections.append(f"### 知识库文件: {file_name}\n{content}") + content = read_file_auto(path) + parts.append(f"--- {os.path.basename(path)} ---\n{content}") except Exception as e: - sections.append(f"### 知识库文件: {fp}\n[读取失败: {e}]") - - return "\n\n".join(sections) \ No newline at end of file + parts.append(f"--- {os.path.basename(path)} [读取失败: {e}] ---") + return "\n\n".join(parts) \ No newline at end of file diff --git a/requirements_generator/utils/output_writer.py b/requirements_generator/utils/output_writer.py index 6fb9e58..ecca0dc 100644 --- a/requirements_generator/utils/output_writer.py +++ b/requirements_generator/utils/output_writer.py @@ -6,133 +6,74 @@ from typing import Dict, List import config - -# 各语言文件扩展名映射 -LANGUAGE_EXT_MAP: Dict[str, str] = { - "python": ".py", - "javascript": ".js", - "typescript": ".ts", - "java": ".java", - "go": ".go", - "rust": ".rs", - "cpp": ".cpp", - "c": ".c", - "csharp": ".cs", - "ruby": ".rb", - "php": ".php", - "swift": ".swift", - "kotlin": ".kt", -} - -# 合法的通用类型集合 VALID_TYPES = { "integer", "string", "boolean", "float", "list", "dict", "object", "void", "any", } - -# 合法的 inout 值 VALID_INOUT = {"in", "out", "inout"} -def get_file_extension(language: str) -> str: - """ - 获取指定语言的文件扩展名 - - Args: - language: 编程语言名称(小写) - - Returns: - 文件扩展名(含点号,如 '.py') - """ - return LANGUAGE_EXT_MAP.get(language.lower(), ".txt") - +# ══════════════════════════════════════════════════════ +# 目录管理 +# ══════════════════════════════════════════════════════ def build_project_output_dir(project_name: str) -> str: - """ - 构建项目输出目录路径 - - Args: - project_name: 项目名称 - - Returns: - 输出目录路径 - """ - safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_name) - return os.path.join(config.OUTPUT_BASE_DIR, safe_name) + safe = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_name) + return os.path.join(config.OUTPUT_BASE_DIR, safe) def ensure_project_dir(project_name: str) -> str: - """ - 确保项目输出目录存在,不存在则创建 - - Args: - project_name: 项目名称 - - Returns: - 创建好的目录路径 - """ + """确保项目根输出目录存在,并创建 __init__.py""" output_dir = build_project_output_dir(project_name) os.makedirs(output_dir, exist_ok=True) init_file = os.path.join(output_dir, "__init__.py") if not os.path.exists(init_file): - Path(init_file).write_text( - "# Auto-generated project package\n", encoding="utf-8" - ) + Path(init_file).write_text("# Auto-generated project package\n", encoding="utf-8") return output_dir -def write_code_file( - output_dir: str, - function_name: str, - language: str, - content: str, -) -> str: - """ - 将代码内容写入指定目录的文件 +def ensure_module_dir(output_dir: str, module: str) -> str: + """确保模块子目录存在,并创建 __init__.py""" + module_dir = os.path.join(output_dir, module) + os.makedirs(module_dir, exist_ok=True) + init_file = os.path.join(module_dir, "__init__.py") + if not os.path.exists(init_file): + Path(init_file).write_text( + f"# Auto-generated module package: {module}\n", encoding="utf-8" + ) + return module_dir - Args: - output_dir: 输出目录路径 - function_name: 函数名(用于生成文件名) - language: 编程语言 - content: 代码内容 - - Returns: - 写入的文件完整路径 - """ - ext = get_file_extension(language) - file_name = f"{function_name}{ext}" - file_path = os.path.join(output_dir, file_name) - Path(file_path).write_text(content, encoding="utf-8") - return file_path +# ══════════════════════════════════════════════════════ +# README +# ══════════════════════════════════════════════════════ def write_project_readme( - output_dir: str, - project_name: str, + output_dir: str, + project_name: str, + project_description: str, requirements_summary: str, + modules: List[str] = None, ) -> str: - """ - 在项目目录生成 README.md 文件 + """生成项目 README.md""" + module_section = "" + if modules: + module_list = "\n".join(f"- `{m}/`" for m in sorted(set(modules))) + module_section = f"\n## 功能模块\n\n{module_list}\n" - Args: - output_dir: 项目输出目录 - project_name: 项目名称 - requirements_summary: 功能需求摘要文本 - - Returns: - README.md 文件路径 - """ - readme_content = f"""# {project_name} + content = f"""# {project_name} > Auto-generated by Requirement Analyzer +{project_description or ""} +{module_section} ## 功能需求列表 {requirements_summary} """ - readme_path = os.path.join(output_dir, "README.md") - Path(readme_path).write_text(readme_content, encoding="utf-8") - return readme_path + path = os.path.join(output_dir, "README.md") + Path(path).write_text(content, encoding="utf-8") + return path # ══════════════════════════════════════════════════════ @@ -140,26 +81,18 @@ def write_project_readme( # ══════════════════════════════════════════════════════ def build_signatures_document( - project_name: str, + project_name: str, project_description: str, - signatures: List[dict], + signatures: List[dict], ) -> dict: """ - 将函数签名列表包装为带项目信息的顶层文档结构。 + 构建顶层签名文档结构:: - Args: - project_name: 项目名称,写入 "project" 字段 - project_description: 项目描述,写入 "description" 字段 - signatures: 函数签名 dict 列表,写入 "functions" 字段 - - Returns: - 顶层文档 dict,结构为:: - - { - "project": "", - "description": "", - "functions": [ ... ] - } + { + "project": "", + "description": "", + "functions": [ ... ] + } """ return { "project": project_name, @@ -169,118 +102,45 @@ def build_signatures_document( def patch_signatures_with_url( - signatures: List[dict], + signatures: List[dict], func_name_to_url: Dict[str, str], ) -> List[dict]: """ - 将代码文件的路径(URL)回写到对应函数签名的 "url" 字段。 - - 遍历签名列表,根据 signature["name"] 在 func_name_to_url 中查找 - 对应路径,找到则写入 "url" 字段;未找到则写入空字符串,不抛出异常。 - - "url" 字段插入位置紧跟在 "type" 字段之后,以保持字段顺序的可读性:: - - { - "name": "create_user", - "requirement_id": "REQ.01", - "description": "...", - "type": "function", - "url": "/abs/path/to/create_user.py", ← 新增 - "parameters": { ... }, - "return": { ... } - } + 将代码文件路径回写到签名的 "url" 字段(紧跟 "type" 之后)。 Args: - signatures: 原始签名列表(in-place 修改) - func_name_to_url: {函数名: 代码文件绝对路径} 映射表, - 由 CodeGenerator.generate_batch() 的进度回调收集 + signatures: 签名列表(in-place 修改) + func_name_to_url: {函数名: 文件绝对路径} Returns: - 修改后的签名列表(与传入的同一对象,方便链式调用) + 修改后的签名列表 """ for sig in signatures: - func_name = sig.get("name", "") - url = func_name_to_url.get(func_name, "") + url = func_name_to_url.get(sig.get("name", ""), "") _insert_field_after(sig, after_key="type", new_key="url", new_value=url) return signatures -def _insert_field_after( - d: dict, - after_key: str, - new_key: str, - new_value, -) -> None: - """ - 在有序 dict 中将 new_key 插入到 after_key 之后。 - 若 after_key 不存在,则追加到末尾。 - 若 new_key 已存在,则直接更新其值(不改变位置)。 - - Args: - d: 目标 dict(Python 3.7+ 保证插入顺序) - after_key: 参考键名 - new_key: 要插入的键名 - new_value: 要插入的值 - """ +def _insert_field_after(d: dict, after_key: str, new_key: str, new_value) -> None: + """在有序 dict 中将 new_key 插入到 after_key 之后""" if new_key in d: d[new_key] = new_value return - items = list(d.items()) - insert_pos = len(items) - for i, (k, _) in enumerate(items): - if k == after_key: - insert_pos = i + 1 - break - + insert_pos = next((i + 1 for i, (k, _) in enumerate(items) if k == after_key), len(items)) items.insert(insert_pos, (new_key, new_value)) d.clear() d.update(items) def write_function_signatures_json( - output_dir: str, - signatures: List[dict], - project_name: str, + output_dir: str, + signatures: List[dict], + project_name: str, project_description: str, - file_name: str = "function_signatures.json", + file_name: str = "function_signatures.json", ) -> str: - """ - 将函数签名列表连同项目信息一起导出为 JSON 文件。 - - 输出的 JSON 顶层结构为:: - - { - "project": "", - "description": "", - "functions": [ - { - "name": "...", - "requirement_id": "...", - "description": "...", - "type": "function", - "url": "/abs/path/to/xxx.py", - "parameters": { ... }, - "return": { ... } - }, - ... - ] - } - - Args: - output_dir: JSON 文件写入目录 - signatures: 函数签名 dict 列表(应已通过 - patch_signatures_with_url() 写入 "url" 字段) - project_name: 项目名称 - project_description: 项目描述 - file_name: 输出文件名,默认 function_signatures.json - - Returns: - 写入的 JSON 文件完整路径 - - Raises: - OSError: 目录不可写 - """ + """将签名列表导出为 JSON 文件""" os.makedirs(output_dir, exist_ok=True) document = build_signatures_document(project_name, project_description, signatures) file_path = os.path.join(output_dir, file_name) @@ -294,137 +154,70 @@ def write_function_signatures_json( # ══════════════════════════════════════════════════════ def validate_signature_schema(signature: dict) -> List[str]: - """ - 校验单个函数签名 dict 是否符合规范。 - - 校验范围: - - 顶层必填字段:name / requirement_id / description / type / parameters - - 可选字段 "url":若存在则必须为非空字符串 - - parameters:每个参数的 type / inout / required 字段 - - return:type 字段 + on_success / on_failure 子结构 - - void 函数:on_success / on_failure 应为 null - - 非 void 函数:on_success / on_failure 必须存在, - 且 value(非空)与 description(非空)均需填写 - - Args: - signature: 单个函数签名 dict - - Returns: - 错误信息字符串列表,列表为空表示校验通过 - """ + """校验单个函数签名结构,返回错误列表(空列表表示通过)""" errors: List[str] = [] - # ── 顶层必填字段 ────────────────────────────────── for key in ("name", "requirement_id", "description", "type", "parameters"): if key not in signature: errors.append(f"缺少顶层字段: '{key}'") - # ── url 字段(可选,存在时校验非空)───────────────── if "url" in signature: if not isinstance(signature["url"], str): errors.append("'url' 字段必须是字符串类型") elif signature["url"] == "": - errors.append("'url' 字段不能为空字符串(代码文件路径未成功回写)") + errors.append("'url' 字段不能为空字符串") - # ── parameters ──────────────────────────────────── params = signature.get("parameters", {}) - if not isinstance(params, dict): - errors.append("'parameters' 必须是 dict 类型") - else: + if isinstance(params, dict): for pname, pdef in params.items(): if not isinstance(pdef, dict): errors.append(f"参数 '{pname}' 定义必须是 dict") continue - # type(支持联合类型,如 "string|integer") if "type" not in pdef: - errors.append(f"参数 '{pname}' 缺少 'type' 字段") + errors.append(f"参数 '{pname}' 缺少 'type'") else: parts = [p.strip() for p in pdef["type"].split("|")] if not all(p in VALID_TYPES for p in parts): - errors.append( - f"参数 '{pname}' 的 type='{pdef['type']}' 含有不合法的类型" - ) - # inout + errors.append(f"参数 '{pname}' type='{pdef['type']}' 含不合法类型") if "inout" not in pdef: - errors.append(f"参数 '{pname}' 缺少 'inout' 字段") + errors.append(f"参数 '{pname}' 缺少 'inout'") elif pdef["inout"] not in VALID_INOUT: - errors.append( - f"参数 '{pname}' 的 inout='{pdef['inout']}' 应为 in/out/inout" - ) - # required + errors.append(f"参数 '{pname}' inout='{pdef['inout']}' 应为 in/out/inout") if "required" not in pdef: - errors.append(f"参数 '{pname}' 缺少 'required' 字段") + errors.append(f"参数 '{pname}' 缺少 'required'") elif not isinstance(pdef["required"], bool): - errors.append( - f"参数 '{pname}' 的 'required' 应为布尔值 true/false," - f"当前为: {pdef['required']!r}" - ) + errors.append(f"参数 '{pname}' 'required' 应为布尔值") - # ── return ──────────────────────────────────────── ret = signature.get("return") if ret is None: - errors.append( - "缺少 'return' 字段(void 函数请填 " - "{\"type\": \"void\", \"on_success\": null, \"on_failure\": null})" - ) - elif not isinstance(ret, dict): - errors.append("'return' 必须是 dict 类型") - else: + errors.append("缺少 'return' 字段") + elif isinstance(ret, dict): ret_type = ret.get("type") if not ret_type: - errors.append("'return' 缺少 'type' 字段") + errors.append("'return' 缺少 'type'") elif ret_type not in VALID_TYPES: - errors.append(f"'return.type'='{ret_type}' 不在合法类型列表中") - + errors.append(f"'return.type'='{ret_type}' 不合法") is_void = (ret_type == "void") - for sub_key in ("on_success", "on_failure"): sub = ret.get(sub_key) if is_void: if sub is not None: - errors.append( - f"void 函数的 'return.{sub_key}' 应为 null," - f"当前为: {sub!r}" - ) + errors.append(f"void 函数 'return.{sub_key}' 应为 null") else: if sub is None: - errors.append( - f"非 void 函数缺少 'return.{sub_key}'," - f"请描述{'成功' if sub_key == 'on_success' else '失败'}时的返回值" - ) - elif not isinstance(sub, dict): - errors.append(f"'return.{sub_key}' 必须是 dict 类型") - else: - if "value" not in sub: - errors.append(f"'return.{sub_key}' 缺少 'value' 字段") - elif sub["value"] == "": - errors.append( - f"'return.{sub_key}.value' 不能为空字符串," - f"请填写具体返回值、值域描述或结构示例" - ) - if "description" not in sub or sub.get("description") in (None, ""): + errors.append(f"非 void 函数缺少 'return.{sub_key}'") + elif isinstance(sub, dict): + if not sub.get("value"): + errors.append(f"'return.{sub_key}.value' 不能为空") + if not sub.get("description"): errors.append(f"'return.{sub_key}.description' 不能为空") - return errors def validate_all_signatures(signatures: List[dict]) -> Dict[str, List[str]]: - """ - 批量校验函数签名列表。 - - 注意:此函数接受的是纯签名列表(即顶层文档的 "functions" 字段), - 而非包含 project/description 的顶层文档。 - - Args: - signatures: 函数签名 dict 列表 - - Returns: - {函数名: [错误信息, ...]} 字典,仅包含有错误的条目 - """ - report: Dict[str, List[str]] = {} - for sig in signatures: - name = sig.get("name", f"unknown_{id(sig)}") - errs = validate_signature_schema(sig) - if errs: - report[name] = errs - return report \ No newline at end of file + """批量校验,返回 {函数名: [错误]} 字典(仅含有错误的条目)""" + return { + sig.get("name", f"unknown_{i}"): errs + for i, sig in enumerate(signatures) + if (errs := validate_signature_schema(sig)) + } \ No newline at end of file