diff --git a/MCPServers/requirement.py b/MCPServers/requirement.py new file mode 100644 index 0000000..afd778f --- /dev/null +++ b/MCPServers/requirement.py @@ -0,0 +1,16 @@ +from mcp.server.fastmcp import FastMCP +from mcp.server.transport_security import TransportSecuritySettings + +mcp = FastMCP("requirements", + transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False)) + + +@mcp.tool() +def generate_requirement_document(project: str, requirements: list, standard: str = "GJB438C"): + print(project, requirements, standard) + +if __name__ == "__main__": + mcp.settings.host = "0.0.0.0" + mcp.settings.port = 7999 + + mcp.run(transport="streamable-http") \ No newline at end of file diff --git a/ai_test_generator/__init__.py b/ai_test_generator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ai_test_generator/config.py b/ai_test_generator/config.py new file mode 100644 index 0000000..6d416ac --- /dev/null +++ b/ai_test_generator/config.py @@ -0,0 +1,21 @@ +import os +from dataclasses import dataclass + +@dataclass +class Config: + # LLM配置(以OpenAI兼容接口为例,可替换为其他LLM) + LLM_API_KEY: str = os.getenv("LLM_API_KEY", "your-api-key-here") + LLM_BASE_URL: str = os.getenv("LLM_BASE_URL", "https://api.openai.com/v1") + LLM_MODEL: str = os.getenv("LLM_MODEL", "gpt-4o") + LLM_TEMPERATURE: float = 0.2 + LLM_MAX_TOKENS: int = 4096 + + # 测试配置 + GENERATED_TESTS_DIR: str = "generated_tests" + TEST_TIMEOUT: int = 30 # 单个测试超时时间(秒) + + # HTTP测试配置 + HTTP_BASE_URL: str = os.getenv("HTTP_BASE_URL", "http://localhost:8080") + HTTP_TIMEOUT: int = 10 + +config = Config() \ No newline at end of file diff --git a/ai_test_generator/core/__init__.py b/ai_test_generator/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ai_test_generator/core/analyzer.py b/ai_test_generator/core/analyzer.py new file mode 100644 index 0000000..6c99a27 --- /dev/null +++ b/ai_test_generator/core/analyzer.py @@ -0,0 +1,609 @@ +""" +覆盖率分析与缺口识别 +══════════════════════════════════════════════════════════════ +修复: + - 移除对已删除属性 request_parameters / url_parameters / response 的引用 + - HTTP 与 function 接口统一使用 iface.in_parameters / iface.out_parameters + - HTTP 接口的返回字段覆盖统一从 return_info.on_success / on_failure 读取 +""" + +from __future__ import annotations +import logging +from dataclasses import dataclass, field +from core.parser import InterfaceInfo + +logger = logging.getLogger(__name__) + + +# ══════════════════════════════════════════════════════════════ +# 数据结构 +# ══════════════════════════════════════════════════════════════ + +@dataclass +class InterfaceCoverage: + interface_name: str + protocol: str + requirement_id: str = "" + is_covered: bool = False + + # 入参覆盖(in / inout) + all_in_params: set[str] = field(default_factory=set) + covered_in_params: set[str] = field(default_factory=set) + + # 出参断言覆盖(out / inout) + all_out_params: set[str] = field(default_factory=set) + asserted_out_params: set[str] = field(default_factory=set) + + # on_success 返回字段断言覆盖 + all_success_fields: set[str] = field(default_factory=set) + covered_success_fields: set[str] = field(default_factory=set) + + # on_failure 返回字段断言覆盖 + all_failure_fields: set[str] = field(default_factory=set) + covered_failure_fields: set[str] = field(default_factory=set) + + # 异常断言(on_failure = raises Exception) + expects_exception: bool = False + exception_case_covered: bool = False + + has_positive_case: bool = False + has_negative_case: bool = False + covering_test_ids: list[str] = field(default_factory=list) + + # ── 覆盖率计算 ──────────────────────────────────────────── + + @property + def in_param_coverage_rate(self) -> float: + if not self.all_in_params: + return 1.0 + return len(self.covered_in_params) / len(self.all_in_params) + + @property + def out_param_coverage_rate(self) -> float: + if not self.all_out_params: + return 1.0 + return len(self.asserted_out_params) / len(self.all_out_params) + + @property + def success_field_coverage_rate(self) -> float: + if not self.all_success_fields: + return 1.0 + return len(self.covered_success_fields) / len(self.all_success_fields) + + @property + def failure_field_coverage_rate(self) -> float: + if not self.all_failure_fields: + return 1.0 + return len(self.covered_failure_fields) / len(self.all_failure_fields) + + @property + def uncovered_in_params(self) -> set[str]: + return self.all_in_params - self.covered_in_params + + @property + def uncovered_out_params(self) -> set[str]: + return self.all_out_params - self.asserted_out_params + + @property + def uncovered_success_fields(self) -> set[str]: + return self.all_success_fields - self.covered_success_fields + + @property + def uncovered_failure_fields(self) -> set[str]: + return self.all_failure_fields - self.covered_failure_fields + + +@dataclass +class RequirementCoverage: + requirement: str + requirement_id: str = "" + covering_test_ids: list[str] = field(default_factory=list) + + @property + def is_covered(self) -> bool: + return bool(self.covering_test_ids) + + +@dataclass +class Gap: + gap_type: str # 见下方常量 + severity: str # critical | high | medium | low + target: str + detail: str + suggestion: str + + +# gap_type 常量 +GAP_INTERFACE_NOT_COVERED = "interface_not_covered" +GAP_IN_PARAM_NOT_COVERED = "in_param_not_covered" +GAP_OUT_PARAM_NOT_ASSERTED = "out_param_not_asserted" +GAP_SUCCESS_FIELD_NOT_ASSERTED = "success_field_not_asserted" +GAP_FAILURE_FIELD_NOT_ASSERTED = "failure_field_not_asserted" +GAP_EXCEPTION_NOT_TESTED = "exception_case_not_tested" +GAP_MISSING_POSITIVE = "missing_positive_case" +GAP_MISSING_NEGATIVE = "missing_negative_case" +GAP_REQUIREMENT_NOT_COVERED = "requirement_not_covered" +GAP_TEST_FAILED = "test_failed" +GAP_TEST_ERROR = "test_error" + + +@dataclass +class CoverageReport: + total_interfaces: int = 0 + covered_interfaces: int = 0 + total_requirements: int = 0 + covered_requirements: int = 0 + total_test_cases: int = 0 + passed_test_cases: int = 0 + failed_test_cases: int = 0 + error_test_cases: int = 0 + + interface_coverages: list[InterfaceCoverage] = field(default_factory=list) + requirement_coverages: list[RequirementCoverage] = field(default_factory=list) + gaps: list[Gap] = field(default_factory=list) + + @property + def interface_coverage_rate(self) -> float: + return self.covered_interfaces / self.total_interfaces \ + if self.total_interfaces else 0.0 + + @property + def requirement_coverage_rate(self) -> float: + return self.covered_requirements / self.total_requirements \ + if self.total_requirements else 0.0 + + @property + def pass_rate(self) -> float: + return self.passed_test_cases / self.total_test_cases \ + if self.total_test_cases else 0.0 + + @property + def avg_in_param_coverage_rate(self) -> float: + rates = [ + ic.in_param_coverage_rate + for ic in self.interface_coverages + if ic.all_in_params + ] + return sum(rates) / len(rates) if rates else 1.0 + + @property + def avg_success_field_coverage_rate(self) -> float: + rates = [ + ic.success_field_coverage_rate + for ic in self.interface_coverages + if ic.all_success_fields + ] + return sum(rates) / len(rates) if rates else 1.0 + + @property + def avg_failure_field_coverage_rate(self) -> float: + rates = [ + ic.failure_field_coverage_rate + for ic in self.interface_coverages + if ic.all_failure_fields + ] + return sum(rates) / len(rates) if rates else 1.0 + + @property + def critical_gap_count(self) -> int: + return sum(1 for g in self.gaps if g.severity == "critical") + + @property + def high_gap_count(self) -> int: + return sum(1 for g in self.gaps if g.severity == "high") + + +# ══════════════════════════════════════════════════════════════ +# 分析器 +# ══════════════════════════════════════════════════════════════ + +class CoverageAnalyzer: + + def __init__( + self, + interfaces: list[InterfaceInfo], + requirements: list[str], + test_cases: list[dict], + run_results: list, + ): + self.interfaces = interfaces + self.requirements = requirements + self.test_cases = test_cases + self.run_results = run_results + self._iface_map = {i.name: i for i in interfaces} + self._result_map = { + getattr(r, "test_id", ""): r for r in run_results + } + + # ── 入口 ────────────────────────────────────────────────── + + def analyze(self) -> CoverageReport: + logger.info("Starting coverage analysis …") + report = CoverageReport() + + iface_cov_map = self._init_interface_coverages() + req_cov_map = self._init_requirement_coverages() + + self._scan_test_cases(iface_cov_map, req_cov_map) + self._scan_run_results(report) + + report.interface_coverages = list(iface_cov_map.values()) + report.requirement_coverages = list(req_cov_map.values()) + report.total_interfaces = len(self.interfaces) + report.covered_interfaces = sum( + 1 for ic in iface_cov_map.values() if ic.is_covered + ) + report.total_requirements = len(self.requirements) + report.covered_requirements = sum( + 1 for rc in req_cov_map.values() if rc.is_covered + ) + report.total_test_cases = len(self.test_cases) + report.gaps = self._identify_gaps(iface_cov_map, req_cov_map) + + logger.info( + f"Analysis done — " + f"interface={report.interface_coverage_rate:.0%}, " + f"requirement={report.requirement_coverage_rate:.0%}, " + f"pass={report.pass_rate:.0%}, " + f"gaps={len(report.gaps)}" + ) + return report + + # ══════════════════════════════════════════════════════════ + # 初始化接口覆盖对象 + # ══════════════════════════════════════════════════════════ + + def _init_interface_coverages(self) -> dict[str, InterfaceCoverage]: + result: dict[str, InterfaceCoverage] = {} + + for iface in self.interfaces: + ic = InterfaceCoverage( + interface_name=iface.name, + protocol=iface.protocol, + requirement_id=iface.requirement_id, + ) + + # ── 入参(in / inout)──────────────────────────── + # HTTP 与 function 统一使用 in_parameters + ic.all_in_params = {p.name for p in iface.in_parameters} + + # ── 出参(out / inout)─────────────────────────── + ic.all_out_params = {p.name for p in iface.out_parameters} + + # ── 返回值字段(on_success)────────────────────── + ri = iface.return_info + success_fields = set(ri.on_success.value_fields.keys()) + + # boolean / 非 dict 类型:用虚拟字段 "return" 代表返回值本身 + if ri.type in ("boolean", "integer", "string", "float") or ( + ri.on_success.value is not None + and not isinstance(ri.on_success.value, dict) + ): + success_fields.add("return") + + ic.all_success_fields = success_fields + + # ── 返回值字段(on_failure)────────────────────── + if ri.on_failure.is_exception: + ic.expects_exception = True + else: + failure_fields = set(ri.on_failure.value_fields.keys()) + if ri.type in ("boolean", "integer", "string", "float") or ( + ri.on_failure.value is not None + and not isinstance(ri.on_failure.value, dict) + and not ri.on_failure.is_exception + ): + failure_fields.add("return") + ic.all_failure_fields = failure_fields + + result[iface.name] = ic + + return result + + # ══════════════════════════════════════════════════════════ + # 初始化需求覆盖对象 + # ══════════════════════════════════════════════════════════ + + def _init_requirement_coverages(self) -> dict[str, RequirementCoverage]: + result: dict[str, RequirementCoverage] = {} + for req in self.requirements: + result[req] = RequirementCoverage(requirement=req) + + # requirement_id → requirement text 的辅助映射 + self._req_id_map: dict[str, str] = {} + for iface in self.interfaces: + if iface.requirement_id: + matched = self._fuzzy_match_req(iface.name, result) + if matched: + self._req_id_map[iface.requirement_id] = matched + + return result + + # ══════════════════════════════════════════════════════════ + # 扫描测试用例 + # ══════════════════════════════════════════════════════════ + + def _scan_test_cases( + self, + iface_cov_map: dict[str, InterfaceCoverage], + req_cov_map: dict[str, RequirementCoverage], + ): + for tc in self.test_cases: + test_id = tc.get("test_id", "") + requirement = tc.get("requirement", "") + req_id = tc.get("requirement_id", "") + description = ( + tc.get("description", "") + " " + tc.get("test_name", "") + ).lower() + + is_negative = any( + kw in description + for kw in ( + "negative", "invalid", "missing", "fail", "error", + "wrong", "bad", "boundary", "负向", "exception", + ) + ) + + # 需求覆盖匹配 + matched_req = ( + self._match_by_req_id(req_id, req_cov_map) + or self._match_by_text(requirement, req_cov_map) + ) + if matched_req: + req_cov_map[matched_req].covering_test_ids.append(test_id) + + # 遍历步骤 + for step in tc.get("steps", []): + iface_name = step.get("interface_name", "") + ic = iface_cov_map.get(iface_name) + if ic is None: + continue + + ic.is_covered = True + if test_id not in ic.covering_test_ids: + ic.covering_test_ids.append(test_id) + + if is_negative: + ic.has_negative_case = True + else: + ic.has_positive_case = True + + # 入参覆盖 + for param_name in step.get("input", {}).keys(): + if param_name in ic.all_in_params: + ic.covered_in_params.add(param_name) + + # 断言覆盖 + for assertion in step.get("assertions", []): + f = assertion.get("field", "") + operator = assertion.get("operator", "") + + # 异常断言 + if f == "exception" and operator == "raised": + ic.exception_case_covered = True + continue + + # 出参断言 + if f in ic.all_out_params: + ic.asserted_out_params.add(f) + + # on_success 字段(正向用例) + if not is_negative and f in ic.all_success_fields: + ic.covered_success_fields.add(f) + + # on_failure 字段(负向用例) + if is_negative and f in ic.all_failure_fields: + ic.covered_failure_fields.add(f) + + # boolean / scalar return 断言 + if f == "return": + if not is_negative: + ic.covered_success_fields.add("return") + else: + ic.covered_failure_fields.add("return") + + # ══════════════════════════════════════════════════════════ + # 扫描执行结果 + # ══════════════════════════════════════════════════════════ + + def _scan_run_results(self, report: CoverageReport): + for r in self.run_results: + s = getattr(r, "status", "") + if s == "PASS": + report.passed_test_cases += 1 + elif s == "FAIL": + report.failed_test_cases += 1 + else: + report.error_test_cases += 1 + + # ══════════════════════════════════════════════════════════ + # 缺口识别 + # ══════════════════════════════════════════════════════════ + + def _identify_gaps( + self, + iface_cov_map: dict[str, InterfaceCoverage], + req_cov_map: dict[str, RequirementCoverage], + ) -> list[Gap]: + gaps: list[Gap] = [] + + for ic in iface_cov_map.values(): + n = ic.interface_name + rid = f"[{ic.requirement_id}] " if ic.requirement_id else "" + + # ① 接口完全未覆盖 + if not ic.is_covered: + gaps.append(Gap( + gap_type=GAP_INTERFACE_NOT_COVERED, + severity="critical", target=n, + detail=f"{rid}'{n}' has NO test case.", + suggestion=f"Add positive + negative test cases for '{n}'.", + )) + continue # 后续缺口依赖覆盖,跳过 + + # ② 缺少正向用例 + if not ic.has_positive_case: + gaps.append(Gap( + gap_type=GAP_MISSING_POSITIVE, + severity="critical", target=n, + detail=f"{rid}'{n}' has no positive (happy-path) test.", + suggestion=f"Add a positive test case with valid inputs for '{n}'.", + )) + + # ③ 缺少负向用例 + if not ic.has_negative_case: + gaps.append(Gap( + gap_type=GAP_MISSING_NEGATIVE, + severity="high", target=n, + detail=f"{rid}'{n}' has no negative test case.", + suggestion=( + f"Add negative tests for '{n}': " + f"missing required params, invalid types, boundary values." + ), + )) + + # ④ 入参未覆盖 + for param in sorted(ic.uncovered_in_params): + gaps.append(Gap( + gap_type=GAP_IN_PARAM_NOT_COVERED, + severity="high", target=n, + detail=f"{rid}Input param '{param}' of '{n}' never used in tests.", + suggestion=f"Add a test that explicitly passes '{param}' to '{n}'.", + )) + + # ⑤ 出参未断言 + for param in sorted(ic.uncovered_out_params): + gaps.append(Gap( + gap_type=GAP_OUT_PARAM_NOT_ASSERTED, + severity="high", target=n, + detail=f"{rid}Output param '{param}' of '{n}' never asserted.", + suggestion=( + f"Add an assertion on output param '{param}' " + f"after calling '{n}'." + ), + )) + + # ⑥ on_success 返回字段未断言 + for f in sorted(ic.uncovered_success_fields): + gaps.append(Gap( + gap_type=GAP_SUCCESS_FIELD_NOT_ASSERTED, + severity="high", target=n, + detail=( + f"{rid}on_success field '{f}' of '{n}' " + f"never asserted in positive tests." + ), + suggestion=( + f"Add assertion on '{f}' in the positive " + f"test step for '{n}'." + ), + )) + + # ⑦ on_failure 返回字段未断言 + for f in sorted(ic.uncovered_failure_fields): + gaps.append(Gap( + gap_type=GAP_FAILURE_FIELD_NOT_ASSERTED, + severity="medium", target=n, + detail=( + f"{rid}on_failure field '{f}' of '{n}' " + f"never asserted in negative tests." + ), + suggestion=( + f"Add assertion on '{f}' in the negative " + f"test step for '{n}'." + ), + )) + + # ⑧ 异常场景未测试 + if ic.expects_exception and not ic.exception_case_covered: + gaps.append(Gap( + gap_type=GAP_EXCEPTION_NOT_TESTED, + severity="high", target=n, + detail=( + f"{rid}'{n}' declares on_failure = raises Exception, " + f"but no test asserts that the exception is raised." + ), + suggestion=( + f"Add a negative test for '{n}' that wraps the call " + f"in try/except and asserts the exception is raised." + ), + )) + + # ⑨ 需求未覆盖 + for rc in req_cov_map.values(): + if not rc.is_covered: + gaps.append(Gap( + gap_type=GAP_REQUIREMENT_NOT_COVERED, + severity="critical", + target=rc.requirement, + detail=f"Requirement '{rc.requirement}' has no test case.", + suggestion=f"Generate test cases for: '{rc.requirement}'.", + )) + + # ⑩ 执行失败 / 错误 + for r in self.run_results: + s = getattr(r, "status", "") + if s == "FAIL": + gaps.append(Gap( + gap_type=GAP_TEST_FAILED, + severity="high", + target=getattr(r, "test_id", ""), + detail=f"Test '{r.test_id}' FAILED: {r.message}", + suggestion=( + "Investigate failure; fix implementation or test data." + ), + )) + elif s in ("ERROR", "TIMEOUT"): + gaps.append(Gap( + gap_type=GAP_TEST_ERROR, + severity="medium", + target=getattr(r, "test_id", ""), + detail=f"Test '{r.test_id}' {s}: {r.message}", + suggestion=( + "Check script for runtime errors or increase TEST_TIMEOUT." + ), + )) + + # 按严重程度排序 + _order = {"critical": 0, "high": 1, "medium": 2, "low": 3} + gaps.sort(key=lambda g: _order.get(g.severity, 9)) + return gaps + + # ══════════════════════════════════════════════════════════ + # 工具方法 + # ══════════════════════════════════════════════════════════ + + def _match_by_req_id( + self, + req_id: str, + req_cov_map: dict[str, RequirementCoverage], + ) -> str | None: + if not req_id: + return None + mapped = self._req_id_map.get(req_id) + if mapped and mapped in req_cov_map: + return mapped + return None + + def _match_by_text( + self, + requirement: str, + req_cov_map: dict[str, RequirementCoverage], + ) -> str | None: + if requirement in req_cov_map: + return requirement + req_lower = requirement.lower() + for key in req_cov_map: + if key.lower() in req_lower or req_lower in key.lower(): + return key + return None + + def _fuzzy_match_req( + self, + iface_name: str, + req_cov_map: dict[str, RequirementCoverage], + ) -> str | None: + name_lower = iface_name.lower().replace("_", " ") + for key in req_cov_map: + if name_lower in key.lower() or key.lower() in name_lower: + return key + return None \ No newline at end of file diff --git a/ai_test_generator/core/llm_client.py b/ai_test_generator/core/llm_client.py new file mode 100644 index 0000000..a74f090 --- /dev/null +++ b/ai_test_generator/core/llm_client.py @@ -0,0 +1,58 @@ +import json +import re +import logging +from openai import OpenAI +from config import config + +logger = logging.getLogger(__name__) + + +class LLMClient: + """封装 LLM API 调用(OpenAI 兼容接口)""" + + def __init__(self): + self.client = OpenAI( + api_key=config.LLM_API_KEY, + base_url=config.LLM_BASE_URL, + ) + + def generate_test_cases(self, system_prompt: str, user_prompt: str) -> list[dict]: + logger.info(f"Calling LLM: model={config.LLM_MODEL}") + logger.debug(f"--- USER PROMPT ---\n{user_prompt}\n---") + + response = self.client.chat.completions.create( + model=config.LLM_MODEL, + temperature=config.LLM_TEMPERATURE, + max_tokens=config.LLM_MAX_TOKENS, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + ) + raw = response.choices[0].message.content + logger.debug(f"--- LLM RAW RESPONSE ---\n{raw[:800]}\n---") + return self._parse_json(raw) + + # ── 解析 ────────────────────────────────────────────────── + + def _parse_json(self, content: str) -> list[dict]: + content = content.strip() + # 去除可能的 markdown 代码块 + content = re.sub(r"^```(?:json)?\s*", "", content, flags=re.MULTILINE) + content = re.sub(r"\s*```\s*$", "", content, flags=re.MULTILINE) + content = content.strip() + + try: + data = json.loads(content) + except json.JSONDecodeError: + # 尝试提取第一个 JSON 数组 + match = re.search(r"\[.*\]", content, re.DOTALL) + if match: + data = json.loads(match.group()) + else: + logger.error(f"Cannot parse LLM response as JSON:\n{content[:600]}") + raise + + if not isinstance(data, list): + raise ValueError(f"Expected JSON array, got {type(data)}") + return data \ No newline at end of file diff --git a/ai_test_generator/core/parser.py b/ai_test_generator/core/parser.py new file mode 100644 index 0000000..b90acaa --- /dev/null +++ b/ai_test_generator/core/parser.py @@ -0,0 +1,388 @@ +""" +接口描述解析器 +══════════════════════════════════════════════════════════════ +支持格式(当前版): + 顶层字段: + - project : 项目名称,用于命名输出目录 + - description : 项目描述 + - units : 接口/函数描述列表 + + 接口公共字段: + - name : 接口/函数名称 + - requirement_id : 需求编号 + - description : 接口描述 + - type : "function" | "http" + - url : function → Python 源文件路径,如 "create_user.py" + http → base URL,如 "http://127.0.0.1/api" + - parameters : 统一参数字段(HTTP 与 function 两种协议共用) + ⚠️ HTTP 接口不再使用 url_parameters / request_parameters + - return : 返回值描述(含 on_success / on_failure) + + HTTP 专用字段: + - method : "get" | "post" | "put" | "delete" 等 + - 完整请求 URL = url.rstrip("/") + "/" + name.lstrip("/") + + Function 专用: + - url 为 .py 文件路径,转换为 Python 模块路径供 import 使用 + - 函数名即 name 字段 +""" + +from __future__ import annotations +import json +import re +from typing import Any +from dataclasses import dataclass, field + + +# ══════════════════════════════════════════════════════════════ +# 基础数据结构 +# ══════════════════════════════════════════════════════════════ + +@dataclass +class ParameterInfo: + name: str + type: str # 支持复合类型 "string|integer" + description: str + required: bool = True + default: Any = None + inout: str = "in" # "in" | "out" | "inout" + + +@dataclass +class ReturnCase: + """on_success 或 on_failure 的返回描述""" + value: Any = None + description: str = "" + + @property + def is_exception(self) -> bool: + """判断是否为抛出异常的场景""" + return ( + isinstance(self.value, str) + and "exception" in self.value.lower() + ) + + @property + def value_fields(self) -> dict[str, Any]: + """若 value 是 dict,返回其字段;否则返回空 dict""" + return self.value if isinstance(self.value, dict) else {} + + @property + def value_summary(self) -> str: + """返回值的简短文字描述,用于 LLM 提示词""" + if isinstance(self.value, dict): + return json.dumps(self.value, ensure_ascii=False) + if isinstance(self.value, bool): + return str(self.value).lower() + return str(self.value) if self.value is not None else "" + + +@dataclass +class ReturnInfo: + type: str = "" + description: str = "" + on_success: ReturnCase = field(default_factory=ReturnCase) + on_failure: ReturnCase = field(default_factory=ReturnCase) + + @property + def all_fields(self) -> set[str]: + """汇总 on_success + on_failure 中出现的所有字段名""" + fields: set[str] = set() + for case in (self.on_success, self.on_failure): + fields.update(case.value_fields.keys()) + return fields + + +# ══════════════════════════════════════════════════════════════ +# 接口描述 +# ══════════════════════════════════════════════════════════════ + +@dataclass +class InterfaceInfo: + name: str + description: str + protocol: str # "http" | "function" + url: str = "" # 原始 url 字段 + requirement_id: str = "" + + # HTTP 专用 + method: str = "get" + + # 统一参数列表(HTTP 与 function 两种协议共用,均来自 "parameters" 字段) + parameters: list[ParameterInfo] = field(default_factory=list) + + # 返回值描述 + return_info: ReturnInfo = field(default_factory=ReturnInfo) + + # ── 便捷属性:参数分组 ──────────────────────────────────── + + @property + def in_parameters(self) -> list[ParameterInfo]: + """inout = "in" 或 "inout" 的参数(作为输入)""" + return [p for p in self.parameters if p.inout in ("in", "inout")] + + @property + def out_parameters(self) -> list[ParameterInfo]: + """inout = "out" 或 "inout" 的参数(作为输出)""" + return [p for p in self.parameters if p.inout in ("out", "inout")] + + # ── 便捷属性:HTTP ──────────────────────────────────────── + + @property + def http_full_url(self) -> str: + """ + HTTP 协议的完整请求 URL: + url (base) + name (path) + 例:url="http://127.0.0.1/api", name="/delete_user" + → "http://127.0.0.1/api/delete_user" + """ + if self.protocol != "http": + return "" + base = self.url.rstrip("/") + path = self.name.lstrip("/") + return f"{base}/{path}" if base else f"/{path}" + + # ── 便捷属性:Function ──────────────────────────────────── + + @property + def module_path(self) -> str: + """ + 将 url(.py 文件路径)转换为 Python 模块导入路径。 + 规则: + 1. 去除 .py 后缀 + 2. 去除前导 ./ 或 / + 3. 将路径分隔符替换为 . + 示例: + "create_user.py" → "create_user" + "myapp/services/user_service.py" → "myapp.services.user_service" + "./services/user_service.py" → "services.user_service" + "myapp.services.user_service" → "myapp.services.user_service" + """ + if self.protocol != "function" or not self.url: + return "" + u = self.url.strip() + # 已经是点分模块路径(无路径分隔符且无 .py 后缀) + if re.match(r'^[\w]+(\.[\w]+)*$', u) and not u.endswith(".py"): + return u + # 文件路径 → 模块路径 + u = u.replace("\\", "/") + u = re.sub(r'\.py$', '', u) + u = re.sub(r'^\.?/', '', u) + return u.replace("/", ".") + + @property + def source_file(self) -> str: + """function 协议的源文件路径(原始 url 字段)""" + return self.url if self.protocol == "function" else "" + + +# ══════════════════════════════════════════════════════════════ +# 描述文件根结构 +# ══════════════════════════════════════════════════════════════ + +@dataclass +class ApiDescriptor: + project: str + description: str + interfaces: list[InterfaceInfo] + + +# ══════════════════════════════════════════════════════════════ +# 解析器 +# ══════════════════════════════════════════════════════════════ + +class InterfaceParser: + + # ── 公开接口 ────────────────────────────────────────────── + + def parse_file(self, file_path: str) -> ApiDescriptor: + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + return self.parse(data) + + def parse(self, data: dict | list) -> ApiDescriptor: + """ + 兼容三种格式: + 新格式 :{"project": "...", "description": "...", "units": [...]} + 旧格式1:{"project": "...", "units": [...]} + 旧格式2:[...] 直接是接口数组 + """ + if isinstance(data, list): + return ApiDescriptor( + project="default", + description="", + interfaces=self._parse_units(data), + ) + return ApiDescriptor( + project=data.get("project", "default"), + description=data.get("description", ""), + interfaces=self._parse_units(data.get("units", [])), + ) + + # ── 单元解析 ────────────────────────────────────────────── + + def _parse_units(self, units: list[dict]) -> list[InterfaceInfo]: + result = [] + for item in units: + protocol = ( + item.get("protocol") or item.get("type") or "function" + ).lower() + if protocol == "http": + result.append(self._parse_http(item)) + else: + result.append(self._parse_function(item)) + return result + + def _parse_http(self, item: dict) -> InterfaceInfo: + """ + HTTP 接口解析。 + 参数统一从 "parameters" 字段读取, + 同时兼容旧格式的 "request_parameters" / "url_parameters"。 + 优先级:parameters > request_parameters + url_parameters + """ + # 优先使用新统一字段 "parameters" + raw_params: dict = item.get("parameters", {}) + + # 旧格式兼容:合并 url_parameters + request_parameters + if not raw_params: + raw_params = { + **item.get("url_parameters", {}), + **item.get("request_parameters", {}), + } + + return InterfaceInfo( + name=item["name"], + description=item.get("description", ""), + protocol="http", + url=item.get("url", ""), + requirement_id=item.get("requirement_id", ""), + method=item.get("method", "get").lower(), + parameters=self._parse_param_dict(raw_params, with_inout=True), + return_info=self._parse_return(item.get("return", {})), + ) + + def _parse_function(self, item: dict) -> InterfaceInfo: + return InterfaceInfo( + name=item["name"], + description=item.get("description", ""), + protocol="function", + url=item.get("url", ""), + requirement_id=item.get("requirement_id", ""), + parameters=self._parse_param_dict( + item.get("parameters", {}), with_inout=True, + ), + return_info=self._parse_return(item.get("return", {})), + ) + + # ── 参数解析 ────────────────────────────────────────────── + + def _parse_param_dict( + self, + params: dict, + with_inout: bool = False, + ) -> list[ParameterInfo]: + result = [] + for name, info in params.items(): + result.append(ParameterInfo( + name=name, + type=info.get("type", "string"), + description=info.get("description", ""), + required=info.get("required", True), + default=info.get("default", None), + inout=info.get("inout", "in") if with_inout else "in", + )) + return result + + # ── 返回值解析 ──────────────────────────────────────────── + + def _parse_return(self, ret: dict) -> ReturnInfo: + if not ret: + return ReturnInfo() + return ReturnInfo( + type=ret.get("type", ""), + description=ret.get("description", ""), + on_success=self._parse_return_case(ret.get("on_success", {})), + on_failure=self._parse_return_case(ret.get("on_failure", {})), + ) + + def _parse_return_case(self, case: dict) -> ReturnCase: + if not case: + return ReturnCase() + return ReturnCase( + value=case.get("value"), + description=case.get("description", ""), + ) + + # ══════════════════════════════════════════════════════════ + # 转换为 LLM 可读摘要 + # ══════════════════════════════════════════════════════════ + + def to_summary_dict(self, interfaces: list[InterfaceInfo]) -> list[dict]: + return [ + self._http_summary(i) if i.protocol == "http" + else self._function_summary(i) + for i in interfaces + ] + + def _http_summary(self, iface: InterfaceInfo) -> dict: + ri = iface.return_info + return { + "name": iface.name, + "requirement_id": iface.requirement_id, + "description": iface.description, + "protocol": "http", + "full_url": iface.http_full_url, + "method": iface.method, + # HTTP 接口统一用 "parameters" 输出给 LLM + "parameters": self._params_to_dict(iface.parameters), + "return": { + "type": ri.type, + "description": ri.description, + "on_success": { + "value": ri.on_success.value, + "description": ri.on_success.description, + }, + "on_failure": { + "value": ri.on_failure.value, + "description": ri.on_failure.description, + }, + }, + } + + def _function_summary(self, iface: InterfaceInfo) -> dict: + ri = iface.return_info + return { + "name": iface.name, + "requirement_id": iface.requirement_id, + "description": iface.description, + "protocol": "function", + "source_file": iface.source_file, + "module_path": iface.module_path, + "parameters": self._params_to_dict(iface.parameters), + "return": { + "type": ri.type, + "description": ri.description, + "on_success": { + "value": ri.on_success.value, + "description": ri.on_success.description, + }, + "on_failure": { + "value": ri.on_failure.value, + "description": ri.on_failure.description, + }, + }, + } + + def _params_to_dict(self, params: list[ParameterInfo]) -> dict: + result = {} + for p in params: + entry: dict = { + "type": p.type, + "inout": p.inout, + "description": p.description, + "required": p.required, + } + if p.default is not None: + entry["default"] = p.default + result[p.name] = entry + return result \ No newline at end of file diff --git a/ai_test_generator/core/prompt_builder.py b/ai_test_generator/core/prompt_builder.py new file mode 100644 index 0000000..14e82d4 --- /dev/null +++ b/ai_test_generator/core/prompt_builder.py @@ -0,0 +1,167 @@ +""" +提示词构建器 +升级点: + - 传递项目 description + - function 协议:source_file(.py) → module_path → from import + - HTTP 协议:full_url = url(base) + name(path) + - parameters 统一字段,inout 区分输入/输出 +""" + +import json +from core.parser import InterfaceInfo, InterfaceParser + +SYSTEM_PROMPT = """ +You are a senior software test engineer. Generate test cases based on the provided +interface descriptions and test requirements. + +═══════════════════════════════════════════════════════ +OUTPUT FORMAT +Return ONLY a valid JSON array. No markdown fences. No extra text. +═══════════════════════════════════════════════════════ +Each element = ONE test case: +{ + "test_id" : "unique id, e.g. TC_001", + "test_name" : "short descriptive name", + "description" : "what this test verifies", + "requirement" : "the original requirement text this case maps to", + "requirement_id" : "e.g. REQ.01", + "steps": [ + { + "step_no" : 1, + "interface_name" : "function or endpoint name", + "protocol" : "http or function", + "url" : "source_file (function) or full_url (http)", + "purpose" : "why this step is needed", + "input": { + // Only parameters with inout = "in" or "inout" + "param_name": + }, + "use_output_of": { // optional + "step_no" : 1, + "field" : "user_id", + "as_param" : "user_id" + }, + "assertions": [ + { + "field" : "field_name or 'return' or 'exception'", + "operator" : "eq|ne|gt|lt|gte|lte|in|not_null|contains|raised|not_raised", + "expected" : , + "message" : "human readable description" + } + ] + } + ], + "test_data_notes" : "explanation of auto-generated test data", + "test_code" : "" +} + +═══════════════════════════════════════════════════════ +TEST CODE RULES +═══════════════════════════════════════════════════════ +1. Complete, runnable Python script. No external test framework. +2. Allowed imports: standard library + `requests` + the actual module under test. + +── FUNCTION INTERFACES ──────────────────────────────── +3. Each function interface has: + "source_file" : e.g. "create_user.py" + "module_path" : e.g. "create_user" (derived from source_file) +4. Import the REAL function using module_path: + try: + from import + except ImportError: + # Stub fallback — simulate on_success for positive tests, + # on_failure / raise Exception for negative tests + def (): + return # or raise Exception(...) +5. Call the function with ONLY "in"/"inout" parameters. +6. Capture the return value; assert against on_success or on_failure structure. +7. If on_failure.value contains "exception" or "raises": + - In negative tests: wrap call in try/except, assert exception IS raised. + - Use field="exception", operator="raised", expected="Exception". + +── HTTP INTERFACES ──────────────────────────────────── +8. Each HTTP interface has: + "full_url" : complete URL, e.g. "http://127.0.0.1/api/delete_user" + "method" : "get" | "post" | "put" | "delete" etc. +9. Send request: + resp = requests.("", json=) +10. Assert on resp.status_code and resp.json() fields. + +── MULTI-STEP ───────────────────────────────────────── +11. Execute steps in order; extract fields from previous step's return value + and pass to next step via use_output_of mapping. + +── STRUCTURED OUTPUT (REQUIRED) ─────────────────────── +12. After EACH step, print: + ##STEP_RESULT## {"step_no":,"interface_name":"...","status":"PASS|FAIL", + "assertions":[{"field":"...","operator":"...","expected":..., + "actual":...,"passed":true|false,"message":"..."}]} +13. Final line must be: + PASS: or FAIL: +14. Wrap entire test body in try/except; print FAIL on any unhandled exception. +15. Do NOT use unittest or pytest. + +═══════════════════════════════════════════════════════ +ASSERTION RULES BY RETURN TYPE +═══════════════════════════════════════════════════════ +return.type = "dict" → assert each key in on_success.value (positive) + or on_failure.value (negative) +return.type = "boolean" → field="return", operator="eq", expected=true/false +return.type = "integer" → field="return", operator="eq"/"gt"/"lt" +on_failure = exception → field="exception", operator="raised" + +═══════════════════════════════════════════════════════ +COVERAGE GUIDELINES +═══════════════════════════════════════════════════════ +- Per requirement: at least 1 positive + 1 negative test case. +- Negative test_name/description MUST contain "negative" or "invalid". +- Multi-interface requirements: single multi-step test case. +- Cover ALL "in"/"inout" parameters (including optional ones). +- Assert ALL fields described in on_success (positive) and on_failure (negative). +- For "out" parameters: verify their values in assertions after the call. +""" + + +class PromptBuilder: + + def get_system_prompt(self) -> str: + return SYSTEM_PROMPT.strip() + + def build_user_prompt( + self, + requirements: list[str], + interfaces: list[InterfaceInfo], + parser: InterfaceParser, + project: str = "", + project_desc: str = "", + ) -> str: + iface_summary = parser.to_summary_dict(interfaces) + req_lines = "\n".join( + f"{i + 1}. {r}" for i, r in enumerate(requirements) + ) + + # 项目信息头 + project_section = "" + if project or project_desc: + project_section = "## Project\n" + if project: + project_section += f"Name : {project}\n" + if project_desc: + project_section += f"Description: {project_desc}\n" + project_section += "\n" + + return ( + f"{project_section}" + f"## Available Interfaces\n" + f"{json.dumps(iface_summary, ensure_ascii=False, indent=2)}\n\n" + f"## Test Requirements\n" + f"{req_lines}\n\n" + f"Generate comprehensive test cases (positive + negative) for every requirement.\n" + f"- For function interfaces: use 'module_path' for import, " + f"'source_file' for reference.\n" + f"- For HTTP interfaces: use 'full_url' as the request URL.\n" + f"- Use requirement_id from the interface to populate the test case's " + f"requirement_id field.\n" + f"- Chain multiple interface calls in one test case when a requirement " + f"involves more than one interface.\n" + ).strip() \ No newline at end of file diff --git a/ai_test_generator/core/report_generator.py b/ai_test_generator/core/report_generator.py new file mode 100644 index 0000000..f1ac64e --- /dev/null +++ b/ai_test_generator/core/report_generator.py @@ -0,0 +1,492 @@ +""" +报告生成器:JSON 摘要 + HTML 可视化报告 +升级点: + - 新增 requirement_id 列 + - 新增 out 参数覆盖率、on_success/on_failure 字段覆盖率 + - 缺口类型标签中文化扩充 +""" + +from __future__ import annotations +import json +import logging +from datetime import datetime +from core.analyzer import CoverageReport, Gap + +logger = logging.getLogger(__name__) + + +class ReportGenerator: + + def save_json(self, report: CoverageReport, path: str): + with open(path, "w", encoding="utf-8") as f: + json.dump(self._to_dict(report), f, ensure_ascii=False, indent=2) + logger.info(f"Coverage JSON → {path}") + + def save_html(self, report: CoverageReport, path: str): + with open(path, "w", encoding="utf-8") as f: + f.write(self._build_html(report)) + logger.info(f"Coverage HTML → {path}") + + # ── JSON 序列化 ─────────────────────────────────────────── + + def _to_dict(self, r: CoverageReport) -> dict: + return { + "generated_at": datetime.now().isoformat(timespec="seconds"), + "summary": { + "interface_coverage": f"{r.interface_coverage_rate:.1%}", + "requirement_coverage": f"{r.requirement_coverage_rate:.1%}", + "avg_in_param_coverage": f"{r.avg_in_param_coverage_rate:.1%}", + "avg_success_field_coverage": f"{r.avg_success_field_coverage_rate:.1%}", + "avg_failure_field_coverage": f"{r.avg_failure_field_coverage_rate:.1%}", + "pass_rate": f"{r.pass_rate:.1%}", + "total_interfaces": r.total_interfaces, + "covered_interfaces": r.covered_interfaces, + "total_requirements": r.total_requirements, + "covered_requirements": r.covered_requirements, + "total_test_cases": r.total_test_cases, + "passed": r.passed_test_cases, + "failed": r.failed_test_cases, + "errors": r.error_test_cases, + "critical_gaps": r.critical_gap_count, + "high_gaps": r.high_gap_count, + }, + "interface_details": [ + { + "name": ic.interface_name, + "requirement_id": ic.requirement_id, + "protocol": ic.protocol, + "is_covered": ic.is_covered, + "has_positive_case": ic.has_positive_case, + "has_negative_case": ic.has_negative_case, + "in_param_coverage": f"{ic.in_param_coverage_rate:.1%}", + "covered_in_params": sorted(ic.covered_in_params), + "uncovered_in_params": sorted(ic.uncovered_in_params), + "out_param_coverage": f"{ic.out_param_coverage_rate:.1%}", + "covered_out_params": sorted(ic.asserted_out_params), + "uncovered_out_params": sorted(ic.uncovered_out_params), + "success_field_coverage": f"{ic.success_field_coverage_rate:.1%}", + "covered_success_fields": sorted(ic.covered_success_fields), + "uncovered_success_fields": sorted(ic.uncovered_success_fields), + "failure_field_coverage": f"{ic.failure_field_coverage_rate:.1%}", + "covered_failure_fields": sorted(ic.covered_failure_fields), + "uncovered_failure_fields": sorted(ic.uncovered_failure_fields), + "expects_exception": ic.expects_exception, + "exception_case_covered": ( + ic.exception_case_covered if ic.expects_exception else None + ), + "covering_test_ids": ic.covering_test_ids, + } + for ic in r.interface_coverages + ], + "requirement_details": [ + { + "requirement": rc.requirement, + "requirement_id": rc.requirement_id, + "is_covered": rc.is_covered, + "covering_test_ids": rc.covering_test_ids, + } + for rc in r.requirement_coverages + ], + "gaps": [ + { + "gap_type": g.gap_type, + "severity": g.severity, + "target": g.target, + "detail": g.detail, + "suggestion": g.suggestion, + } + for g in r.gaps + ], + } + + # ══════════════════════════════════════════════════════════ + # HTML 报告 + # ══════════════════════════════════════════════════════════ + + _GAP_LABELS: dict[str, str] = { + "interface_not_covered": "接口未覆盖", + "in_param_not_covered": "入参未覆盖", + "out_param_not_asserted": "出参未断言", + "success_field_not_asserted": "成功返回字段未断言", + "failure_field_not_asserted": "失败返回字段未断言", + "exception_case_not_tested": "异常场景未测试", + "missing_positive_case": "缺少正向用例", + "missing_negative_case": "缺少负向用例", + "requirement_not_covered": "需求未覆盖", + "test_failed": "用例执行失败", + "test_error": "用例执行异常", + } + + # ── 样式辅助 ────────────────────────────────────────────── + + @staticmethod + def _rate_color(rate: float) -> str: + if rate >= 0.8: return "#27ae60" + if rate >= 0.5: return "#f39c12" + return "#e74c3c" + + @staticmethod + def _sev_badge(sev: str) -> str: + colors = { + "critical": "#e74c3c", "high": "#e67e22", + "medium": "#f1c40f", "low": "#3498db", + } + c = colors.get(sev, "#95a5a6") + return ( + f'' + f'{sev.upper()}' + ) + + @staticmethod + def _bool_badge(val: bool) -> str: + if val: + return '' + return '' + + def _pct_bar(self, rate: float, w: int = 120) -> str: + color = self._rate_color(rate) + filled = int(w * rate) + return ( + f'
' + f'
' + f'
' + f'' + f'{rate * 100:.0f}%
' + ) + + def _metric_card( + self, label: str, value: str, sub: str = "", color: str = "#2c3e50" + ) -> str: + return ( + f'
' + f'
{value}
' + f'
{label}
' + f'
{sub}
' + f'
' + ) + + # ── 汇总卡片 ────────────────────────────────────────────── + + def _build_cards(self, r: CoverageReport) -> str: + rc = self._rate_color + return "".join([ + self._metric_card( + "接口覆盖率", f"{r.interface_coverage_rate:.0%}", + f"{r.covered_interfaces}/{r.total_interfaces}", + rc(r.interface_coverage_rate), + ), + self._metric_card( + "需求覆盖率", f"{r.requirement_coverage_rate:.0%}", + f"{r.covered_requirements}/{r.total_requirements}", + rc(r.requirement_coverage_rate), + ), + self._metric_card( + "入参覆盖率", f"{r.avg_in_param_coverage_rate:.0%}", + "avg in-params", + rc(r.avg_in_param_coverage_rate), + ), + self._metric_card( + "成功返回覆盖", f"{r.avg_success_field_coverage_rate:.0%}", + "on_success fields", + rc(r.avg_success_field_coverage_rate), + ), + self._metric_card( + "失败返回覆盖", f"{r.avg_failure_field_coverage_rate:.0%}", + "on_failure fields", + rc(r.avg_failure_field_coverage_rate), + ), + self._metric_card( + "用例通过率", f"{r.pass_rate:.0%}", + f"{r.passed_test_cases}/{r.total_test_cases}", + rc(r.pass_rate), + ), + self._metric_card( + "Critical 缺口", str(r.critical_gap_count), + f"High: {r.high_gap_count}", + "#e74c3c" if r.critical_gap_count > 0 else "#27ae60", + ), + ]) + + # ── 接口覆盖表 ──────────────────────────────────────────── + + def _build_interface_table(self, r: CoverageReport) -> str: + rows = "" + for ic in r.interface_coverages: + proto_color = "#3498db" if ic.protocol == "http" else "#9b59b6" + + # 异常测试单元格 + if ic.expects_exception: + exc_cell = self._bool_badge(ic.exception_case_covered) + else: + exc_cell = 'N/A' + + # 未覆盖项 tooltip + def missing_tip(items: set[str], label: str) -> str: + if not items: + return "" + joined = ", ".join(sorted(items)) + return ( + f'
' + f'缺: {joined}
' + ) + + rows += f""" + + + {ic.interface_name} + + + {ic.requirement_id or "—"} + + + {ic.protocol.upper()} + + {self._bool_badge(ic.is_covered)} + {self._bool_badge(ic.has_positive_case)} + {self._bool_badge(ic.has_negative_case)} + + {self._pct_bar(ic.in_param_coverage_rate)} + {missing_tip(ic.uncovered_in_params, "缺")} + + + {self._pct_bar(ic.out_param_coverage_rate)} + {missing_tip(ic.uncovered_out_params, "缺")} + + + {self._pct_bar(ic.success_field_coverage_rate)} + {missing_tip(ic.uncovered_success_fields, "缺")} + + + {self._pct_bar(ic.failure_field_coverage_rate)} + {missing_tip(ic.uncovered_failure_fields, "缺")} + + {exc_cell} + + {", ".join(ic.covering_test_ids) or "—"} + + """ + return rows + + # ── 需求覆盖表 ──────────────────────────────────────────── + + def _build_requirement_table(self, r: CoverageReport) -> str: + rows = "" + for rc in r.requirement_coverages: + rows += f""" + + + {rc.requirement_id or "—"} + + {rc.requirement} + {self._bool_badge(rc.is_covered)} + + {", ".join(rc.covering_test_ids) or "—"} + + """ + return rows + + # ── 缺口清单表 ──────────────────────────────────────────── + + def _build_gap_table(self, r: CoverageReport) -> str: + if not r.gaps: + return ( + '🎉 No gaps found! Full coverage achieved.' + ) + rows = "" + for i, g in enumerate(r.gaps, 1): + rows += f""" + + {i} + {self._sev_badge(g.severity)} + + + {self._GAP_LABELS.get(g.gap_type, g.gap_type)} + + + {g.target} + {g.detail} + {g.suggestion} + """ + return rows + + # ── 缺口统计图(纯 CSS 横向柱状图)──────────────────────── + + def _build_gap_chart(self, r: CoverageReport) -> str: + from collections import Counter + type_counts = Counter(g.gap_type for g in r.gaps) + if not type_counts: + return "" + + max_count = max(type_counts.values(), default=1) + bars = "" + sev_color_map: dict[str, str] = {} + for g in r.gaps: + sev_color_map[g.gap_type] = { + "critical": "#e74c3c", "high": "#e67e22", + "medium": "#f1c40f", "low": "#3498db", + }.get(g.severity, "#95a5a6") + + for gap_type, count in sorted(type_counts.items(), key=lambda x: -x[1]): + label = self._GAP_LABELS.get(gap_type, gap_type) + color = sev_color_map.get(gap_type, "#95a5a6") + width = int(count / max_count * 320) + bars += f""" +
+
{label}
+
+ {count} +
""" + + return f""" +
+
+ 缺口类型分布 +
+ {bars} +
""" + + # ── 组装完整 HTML ───────────────────────────────────────── + + def _build_html(self, r: CoverageReport) -> str: + ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + css = """ + * { box-sizing: border-box; margin: 0; padding: 0; } + body { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; + background: #f4f6f9; color: #2c3e50; + } + .wrap { max-width: 1440px; margin: 0 auto; padding: 32px 20px; } + h1 { font-size: 24px; font-weight: 700; margin-bottom: 4px; } + .sub { color: #7f8c8d; font-size: 12px; margin-bottom: 26px; } + .cards { display: flex; flex-wrap: wrap; gap: 14px; margin-bottom: 32px; } + h2 { + font-size: 16px; font-weight: 600; margin: 28px 0 12px; + padding-left: 10px; border-left: 4px solid #3498db; + } + table { + width: 100%; border-collapse: collapse; background: #fff; + border-radius: 10px; overflow: hidden; + box-shadow: 0 2px 8px rgba(0,0,0,.07); + } + th { + background: #2c3e50; color: #fff; padding: 10px 12px; + font-size: 12px; text-align: left; white-space: nowrap; + } + td { + padding: 9px 12px; font-size: 12px; + border-bottom: 1px solid #ecf0f1; vertical-align: middle; + } + tr:last-child td { border-bottom: none; } + tr:hover td { background: #f8f9fa; } + code { + background: #ecf0f1; padding: 1px 4px; border-radius: 3px; + font-family: "SFMono-Regular", Consolas, monospace; + } + .footer { + text-align: center; color: #bdc3c7; + font-size: 11px; margin-top: 36px; padding-bottom: 20px; + } + """ + + return f""" + + + + +AI Test Coverage Report + + + +
+ + +

🧪 AI Test Coverage Report

+
Generated at {ts}  |  + Interfaces: {r.total_interfaces}  |  + Test Cases: {r.total_test_cases}  |  + Gaps: {len(r.gaps)} +
+ + +
{self._build_cards(r)}
+ + + {self._build_gap_chart(r)} + + +

📡 接口覆盖详情

+
+ + + + + + + + + + + + + + + + + + {self._build_interface_table(r)} +
接口名称REQ ID协议已覆盖正向负向入参覆盖率出参断言率成功返回覆盖失败返回覆盖异常测试覆盖用例
+
+ + +

📋 需求覆盖详情

+ + + + + + + + + + {self._build_requirement_table(r)} +
REQ ID需求描述已覆盖覆盖用例
+ + +

🔍 缺口清单(共 {len(r.gaps)} 项, + Critical: {r.critical_gap_count}, + High: {r.high_gap_count}) +

+
+ + + + + + + + + + + + {self._build_gap_table(r)} +
#严重程度缺口类型目标详情补充建议
+
+ + +
+ +""" \ No newline at end of file diff --git a/ai_test_generator/core/test_generator.py b/ai_test_generator/core/test_generator.py new file mode 100644 index 0000000..ebe4365 --- /dev/null +++ b/ai_test_generator/core/test_generator.py @@ -0,0 +1,138 @@ +""" +测试文件生成器 +升级点: + - 接收 project / project_desc,写入头部注释 + - 头部注释增加 url / source_file / module_path 信息 + - 输出到 // 子目录 +""" + +from __future__ import annotations +import json +import re +import logging +from pathlib import Path +from config import config + +logger = logging.getLogger(__name__) + + +class TestGenerator: + + def __init__(self, project: str = "", project_desc: str = ""): + self.project = project + self.project_desc = project_desc + + safe = self._safe_name(project) if project else "" + base = Path(config.GENERATED_TESTS_DIR) + self.output_dir = base / safe if safe else base + self.output_dir.mkdir(parents=True, exist_ok=True) + + # ── 公开接口 ────────────────────────────────────────────── + + def generate(self, test_cases: list[dict]) -> list[Path]: + self._save_summary(test_cases) + files: list[Path] = [] + for tc in test_cases: + path = self._write_file(tc) + if path: + files.append(path) + logger.info( + f"Generated {len(files)} test file(s) → '{self.output_dir}'" + ) + return files + + # ── 写入单个测试文件 ────────────────────────────────────── + + def _write_file(self, tc: dict) -> Path | None: + test_id = tc.get("test_id", "TC_UNKNOWN") + test_code = tc.get("test_code", "").strip() + if not test_code: + logger.warning(f"[{test_id}] No test_code, skipping.") + return None + + file_path = self.output_dir / f"{self._safe_name(test_id)}.py" + with open(file_path, "w", encoding="utf-8") as f: + f.write(self._build_header(tc)) + f.write("\n") + f.write(test_code) + f.write("\n") + + logger.info(f" Written: {file_path}") + return file_path + + # ── 文件头注释 ──────────────────────────────────────────── + + def _build_header(self, tc: dict) -> str: + req_id = tc.get("requirement_id", "") + req_text = tc.get("requirement", "") + data_notes = tc.get("test_data_notes", "") + steps = tc.get("steps", []) + + # 项目信息行 + proj_lines = "" + if self.project: + proj_lines += f"# Project : {self.project}\n" + if self.project_desc: + proj_lines += f"# Proj Desc : {self.project_desc}\n" + + # 步骤注释 + step_lines = "" + for s in steps: + protocol = s.get("protocol", "").upper() + iface = s.get("interface_name", "") + url = s.get("url", "") + purpose = s.get("purpose", "") + + step_lines += ( + f"# Step {s.get('step_no')}: [{protocol}] {iface}\n" + f"# URL : {url}\n" + f"# Purpose: {purpose}\n" + ) + + inputs = s.get("input", {}) + if inputs: + step_lines += ( + f"# Input : " + f"{json.dumps(inputs, ensure_ascii=False)}\n" + ) + + for a in s.get("assertions", []): + msg = a.get("message", "") + step_lines += ( + f"# Assert : {a.get('field')} " + f"{a.get('operator')} {a.get('expected')}" + f"{' — ' + msg if msg else ''}\n" + ) + + return ( + f"# {'=' * 60}\n" + f"{proj_lines}" + f"# Test ID : {tc.get('test_id', '')}\n" + f"# Test Name : {tc.get('test_name', '')}\n" + f"# Requirement: [{req_id}] {req_text}\n" + f"# Description: {tc.get('description', '')}\n" + f"# Test Data : {data_notes}\n" + f"# Steps:\n" + f"{step_lines.rstrip()}\n" + f"# {'=' * 60}\n" + f"import os\n" + f"import json\n" + ) + + # ── 保存摘要 JSON ───────────────────────────────────────── + + def _save_summary(self, test_cases: list[dict]): + summary = [ + {k: v for k, v in tc.items() if k != "test_code"} + for tc in test_cases + ] + path = self.output_dir / "test_cases_summary.json" + with open(path, "w", encoding="utf-8") as f: + json.dump(summary, f, ensure_ascii=False, indent=2) + logger.info(f"Summary → {path}") + + # ── 工具 ────────────────────────────────────────────────── + + @staticmethod + def _safe_name(name: str) -> str: + return re.sub(r'[^\w\-]', '_', name).strip("_") \ No newline at end of file diff --git a/ai_test_generator/core/test_runner.py b/ai_test_generator/core/test_runner.py new file mode 100644 index 0000000..90a3e7f --- /dev/null +++ b/ai_test_generator/core/test_runner.py @@ -0,0 +1,168 @@ +import os +import sys +import json +import subprocess +import time +import logging +from pathlib import Path +from dataclasses import dataclass, field +from config import config + +logger = logging.getLogger(__name__) + + +@dataclass +class StepResult: + step_no: int + interface_name: str + status: str # PASS | FAIL | ERROR + assertions: list[dict] = field(default_factory=list) + # 每条 assertion: {"field","operator","expected","actual","passed","message"} + + +@dataclass +class TestResult: + test_id: str + file_path: str + status: str # PASS | FAIL | ERROR | TIMEOUT + message: str = "" + duration: float = 0.0 + stdout: str = "" + stderr: str = "" + step_results: list[StepResult] = field(default_factory=list) + + +class TestRunner: + + def run_all(self, test_files: list[Path]) -> list[TestResult]: + results = [] + total = len(test_files) + print(f"\n{'═'*62}") + print(f" Running {total} test(s) …") + print(f"{'═'*62}") + for idx, fp in enumerate(test_files, 1): + print(f"\n[{idx}/{total}] {fp.name}") + result = self._run_one(fp) + results.append(result) + self._print_result(result) + return results + + # ── 执行单个脚本 ────────────────────────────────────────── + + def _run_one(self, file_path: Path) -> TestResult: + test_id = file_path.stem + t0 = time.time() + try: + proc = subprocess.run( + [sys.executable, str(file_path)], + capture_output=True, text=True, + timeout=config.TEST_TIMEOUT, + env=self._env(), + ) + duration = time.time() - t0 + stdout = proc.stdout.strip() + stderr = proc.stderr.strip() + status, message = self._parse_output(stdout, proc.returncode) + step_results = self._parse_step_results(stdout) + return TestResult( + test_id=test_id, file_path=str(file_path), + status=status, message=message, + duration=duration, stdout=stdout, stderr=stderr, + step_results=step_results, + ) + except subprocess.TimeoutExpired: + return TestResult( + test_id=test_id, file_path=str(file_path), + status="TIMEOUT", message=f"Exceeded {config.TEST_TIMEOUT}s", + duration=time.time() - t0, + ) + except Exception as e: + return TestResult( + test_id=test_id, file_path=str(file_path), + status="ERROR", message=str(e), + duration=time.time() - t0, + ) + + # ── 解析输出 ────────────────────────────────────────────── + + def _parse_output(self, stdout: str, returncode: int) -> tuple[str, str]: + if not stdout: + return "FAIL", f"No output (exit={returncode})" + last = stdout.strip().splitlines()[-1].strip() + upper = last.upper() + if upper.startswith("PASS"): + return "PASS", last[5:].strip() + if upper.startswith("FAIL"): + return "FAIL", last[5:].strip() + return ("PASS" if returncode == 0 else "FAIL"), last + + def _parse_step_results(self, stdout: str) -> list[StepResult]: + """ + 解析脚本中以 ##STEP_RESULT## 开头的结构化输出行 + 格式:##STEP_RESULT## + """ + results = [] + for line in stdout.splitlines(): + line = line.strip() + if line.startswith("##STEP_RESULT##"): + try: + data = json.loads(line[len("##STEP_RESULT##"):].strip()) + results.append(StepResult( + step_no=data.get("step_no", 0), + interface_name=data.get("interface_name", ""), + status=data.get("status", ""), + assertions=data.get("assertions", []), + )) + except Exception: + pass + return results + + def _env(self) -> dict: + env = os.environ.copy() + env["HTTP_BASE_URL"] = config.HTTP_BASE_URL + return env + + # ── 打印 ────────────────────────────────────────────────── + + def _print_result(self, r: TestResult): + icon = {"PASS": "✅", "FAIL": "❌", "TIMEOUT": "⏱️", "ERROR": "⚠️"}.get(r.status, "❓") + print(f" {icon} [{r.status}] {r.test_id} ({r.duration:.2f}s)") + print(f" {r.message}") + for sr in r.step_results: + s_icon = "✅" if sr.status == "PASS" else "❌" + print(f" {s_icon} Step {sr.step_no}: {sr.interface_name}") + for a in sr.assertions: + a_icon = "✔" if a.get("passed") else "✘" + print(f" {a_icon} {a.get('message','')} " + f"(expected={a.get('expected')}, actual={a.get('actual')})") + if r.stderr: + print(f" stderr: {r.stderr[:300]}") + + def print_summary(self, results: list[TestResult]): + total = len(results) + passed = sum(1 for r in results if r.status == "PASS") + failed = sum(1 for r in results if r.status == "FAIL") + errors = sum(1 for r in results if r.status in ("ERROR", "TIMEOUT")) + print(f"\n{'═'*62}") + print(f" TEST SUMMARY") + print(f"{'═'*62}") + print(f" Total : {total}") + print(f" ✅ PASS : {passed}") + print(f" ❌ FAIL : {failed}") + print(f" ⚠️ ERROR : {errors}") + print(f" Pass Rate: {passed/total*100:.1f}%" if total else " Pass Rate: N/A") + print(f"{'═'*62}\n") + + def save_results(self, results: list[TestResult], path: str): + with open(path, "w", encoding="utf-8") as f: + json.dump([{ + "test_id": r.test_id, "status": r.status, + "message": r.message, "duration": r.duration, + "stdout": r.stdout, "stderr": r.stderr, + "step_results": [ + {"step_no": sr.step_no, "interface_name": sr.interface_name, + "status": sr.status, "assertions": sr.assertions} + for sr in r.step_results + ], + } for r in results], f, ensure_ascii=False, indent=2) + logger.info(f"Run results → {path}") \ No newline at end of file diff --git a/ai_test_generator/examples/api_desc.json b/ai_test_generator/examples/api_desc.json new file mode 100644 index 0000000..9add8f7 --- /dev/null +++ b/ai_test_generator/examples/api_desc.json @@ -0,0 +1,91 @@ +{ + "project": "project name", + "description": "this is my first project", + "units": [ + { + "name": "create_user", + "requirement_id": "REQ.01", + "description": "Creates a new user account with provided credentials and information.", + "type": "function", + "url": "create_user.py", + "parameters": { + "username": { + "type": "string", + "inout": "in", + "description": "The desired username for the new account.", + "required": true + }, + "password": { + "type": "string", + "inout": "in", + "description": "The password for the new account, meeting security requirements.", + "required": true + }, + "email": { + "type": "string", + "inout": "in", + "description": "The email address associated with the new account.", + "required": true + }, + "phone_number": { + "type": "string", + "inout": "in", + "description": "The phone number associated with the new account.", + "required": false + }, + "additional_info": { + "type": "dict", + "inout": "in", + "description": "Optional additional information about the user.", + "required": false + } + }, + "return": { + "type": "dict", + "description": "Returns the created user object on success or an error message on failure.", + "on_success": { + "value": "User object containing user details such as id, username, email, etc.", + "description": "The user object representing the newly created account." + }, + "on_failure": { + "value": "Error message or raises an exception.", + "description": "Details about why the user creation failed, such as validation errors or duplicate username." + } + } + }, + { + "name": "/delete_user", + "requirement_id": "REQ.02", + "description": "Deletes a user by user ID or username from the system.", + "type": "http", + "url": "http://127.0.0.1/api", + "method": "post", + "parameters": { + "user_id": { + "type": "integer", + "inout": "in", + "description": "Unique identifier of the user to be deleted.", + "required": false + }, + "username": { + "type": "string", + "inout": "in", + "description": "Username of the user to be deleted.", + "required": false + } + }, + "return": { + "type": "boolean", + "description": "Indicates whether the user was successfully deleted.", + "on_success": { + "value": true, + "description": "The user was successfully deleted." + }, + "on_failure": { + "value": false, + "description": "The user could not be deleted, or the user does not exist." + } + } + } + ] +} \ No newline at end of file diff --git a/ai_test_generator/examples/function_signatures.json b/ai_test_generator/examples/function_signatures.json new file mode 100644 index 0000000..2ac803f --- /dev/null +++ b/ai_test_generator/examples/function_signatures.json @@ -0,0 +1,91 @@ +{ + "project": "project name", + "description": "this is my first project", + "units": [ + { + "name": "create_user", + "requirement_id": "REQ.01", + "description": "Creates a new user account with provided credentials and information.", + "type": "function", + "url": "/Users/sontolau/Workspace/AIDeveloper/requirements_generator/output/my_first_project/create_user.py", + "parameters": { + "username": { + "type": "string", + "inout": "in", + "description": "The desired username for the new account.", + "required": true + }, + "password": { + "type": "string", + "inout": "in", + "description": "The password for the new account, meeting security requirements.", + "required": true + }, + "email": { + "type": "string", + "inout": "in", + "description": "The email address associated with the new account.", + "required": true + }, + "phone_number": { + "type": "string", + "inout": "in", + "description": "The phone number associated with the new account.", + "required": false + }, + "additional_info": { + "type": "dict", + "inout": "in", + "description": "Optional additional information about the user.", + "required": false + } + }, + "return": { + "type": "dict", + "description": "Returns the created user object on success or an error message on failure.", + "on_success": { + "value": "User object containing user details such as id, username, email, etc.", + "description": "The user object representing the newly created account." + }, + "on_failure": { + "value": "Error message or raises an exception.", + "description": "Details about why the user creation failed, such as validation errors or duplicate username." + } + } + }, + { + "name": "/delete_user", + "requirement_id": "REQ.02", + "description": "Deletes a user by user ID or username from the system.", + "type": "http", + "url": "http://127.0.0.1/api", + "method": "post", + "parameters": { + "user_id": { + "type": "integer", + "inout": "in", + "description": "Unique identifier of the user to be deleted.", + "required": false + }, + "username": { + "type": "string", + "inout": "in", + "description": "Username of the user to be deleted.", + "required": false + } + }, + "return": { + "type": "boolean", + "description": "Indicates whether the user was successfully deleted.", + "on_success": { + "value": true, + "description": "The user was successfully deleted." + }, + "on_failure": { + "value": false, + "description": "The user could not be deleted, or the user does not exist." + } + } + } + ] +} \ No newline at end of file diff --git a/ai_test_generator/main.py b/ai_test_generator/main.py new file mode 100644 index 0000000..5714c79 --- /dev/null +++ b/ai_test_generator/main.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +""" +AI-Powered API Test Generator, Runner & Coverage Analyzer +────────────────────────────────────────────────────────── +Usage: + python main.py --api-desc examples/api_desc.json \\ + --requirements "创建用户,删除用户" + + python main.py --api-desc examples/api_desc.json \\ + --req-file examples/requirements.txt +""" + +import argparse +import logging +import sys +from pathlib import Path + +from config import config +from core.parser import InterfaceParser, ApiDescriptor +from core.prompt_builder import PromptBuilder +from core.llm_client import LLMClient +from core.test_generator import TestGenerator +from core.test_runner import TestRunner +from core.analyzer import CoverageAnalyzer, CoverageReport +from core.report_generator import ReportGenerator + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s — %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger(__name__) + + +# ── CLI ─────────────────────────────────────────────────────── + +def build_arg_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + description="AI-Powered API Test Generator & Coverage Analyzer" + ) + p.add_argument( + "--api-desc", required=True, + help="接口描述 JSON 文件(含 project / description / units 字段)", + ) + p.add_argument( + "--requirements", default="", + help="测试需求(逗号分隔)", + ) + p.add_argument( + "--req-file", default="", + help="测试需求文件(每行一条)", + ) + p.add_argument( + "--skip-run", action="store_true", + help="只生成测试文件,不执行", + ) + p.add_argument( + "--skip-analyze", action="store_true", + help="跳过覆盖率分析", + ) + p.add_argument( + "--output-dir", default="", + help="测试文件根输出目录(默认 config.GENERATED_TESTS_DIR)", + ) + p.add_argument( + "--debug", action="store_true", + help="开启 DEBUG 日志", + ) + return p + + +def load_requirements(args) -> list[str]: + reqs: list[str] = [] + if args.requirements: + reqs += [r.strip() for r in args.requirements.split(",") if r.strip()] + if args.req_file: + with open(args.req_file, "r", encoding="utf-8") as f: + reqs += [line.strip() for line in f if line.strip()] + if not reqs: + logger.error( + "No requirements provided. Use --requirements or --req-file." + ) + sys.exit(1) + return reqs + + +# ── Main ────────────────────────────────────────────────────── + +def main(): + args = build_arg_parser().parse_args() + + if args.debug: + logging.getLogger().setLevel(logging.DEBUG) + if args.output_dir: + config.GENERATED_TESTS_DIR = args.output_dir + + # ── Step 1: 解析接口描述 ────────────────────────────────── + logger.info("▶ Step 1: Parsing interface description …") + parser: InterfaceParser = InterfaceParser() + descriptor: ApiDescriptor = parser.parse_file(args.api_desc) + + project = descriptor.project + project_desc = descriptor.description + interfaces = descriptor.interfaces + + logger.info(f" Project : {project}") + logger.info(f" Description: {project_desc}") + logger.info( + f" Interfaces : {len(interfaces)} — " + f"{[i.name for i in interfaces]}" + ) + # 打印每个接口的 url 解析结果 + for iface in interfaces: + if iface.protocol == "function": + logger.info( + f" [{iface.name}] source={iface.source_file} " + f"module={iface.module_path}" + ) + else: + logger.info( + f" [{iface.name}] full_url={iface.http_full_url}" + ) + + # ── Step 2: 加载测试需求 ────────────────────────────────── + logger.info("▶ Step 2: Loading test requirements …") + requirements = load_requirements(args) + for i, req in enumerate(requirements, 1): + logger.info(f" {i}. {req}") + + # ── Step 3: 调用 LLM 生成测试用例 ──────────────────────── + logger.info("▶ Step 3: Calling LLM …") + builder = PromptBuilder() + test_cases = LLMClient().generate_test_cases( + builder.get_system_prompt(), + builder.build_user_prompt( + requirements, interfaces, parser, + project=project, + project_desc=project_desc, + ), + ) + logger.info(f" LLM returned {len(test_cases)} test case(s)") + + # ── Step 4: 生成测试文件 ────────────────────────────────── + logger.info( + f"▶ Step 4: Generating test files (project='{project}') …" + ) + generator = TestGenerator(project=project, project_desc=project_desc) + test_files = generator.generate(test_cases) + out = generator.output_dir + + run_results = [] + + if not args.skip_run: + # ── Step 5: 执行测试 ────────────────────────────────── + logger.info("▶ Step 5: Running tests …") + runner = TestRunner() + run_results = runner.run_all(test_files) + runner.print_summary(run_results) + runner.save_results(run_results, str(out / "run_results.json")) + + if not args.skip_analyze: + # ── Step 6: 覆盖率分析 ──────────────────────────────── + logger.info("▶ Step 6: Analyzing coverage …") + report = CoverageAnalyzer( + interfaces=interfaces, + requirements=requirements, + test_cases=test_cases, + run_results=run_results, + ).analyze() + + # ── Step 7: 生成报告 ────────────────────────────────── + logger.info("▶ Step 7: Generating reports …") + rg = ReportGenerator() + rg.save_json(report, str(out / "coverage_report.json")) + rg.save_html(report, str(out / "coverage_report.html")) + + _print_terminal_summary(report, out, project) + + logger.info(f"\n✅ Done. Output: {out.resolve()}") + + +# ── 终端摘要 ────────────────────────────────────────────────── + +def _print_terminal_summary( + report: CoverageReport, out: Path, project: str +): + W = 66 + + def bar(rate: float, w: int = 20) -> str: + filled = int(rate * w) + empty = w - filled + icon = "✅" if rate >= 0.8 else ("⚠️ " if rate >= 0.5 else "❌") + return f"{'█' * filled}{'░' * empty} {rate * 100:.1f}% {icon}" + + print(f"\n{'═' * W}") + print(f" PROJECT : {project}") + print(f" COVERAGE SUMMARY") + print(f"{'═' * W}") + print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}") + print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}") + print(f" 入参覆盖率 {bar(report.avg_in_param_coverage_rate)}") + print(f" 成功返回字段覆盖 {bar(report.avg_success_field_coverage_rate)}") + print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}") + print(f" 用例通过率 {bar(report.pass_rate)}") + print(f"{'─' * W}") + print(f" Total Gaps : {len(report.gaps)}") + print(f" 🔴 Critical: {report.critical_gap_count}") + print(f" 🟠 High : {report.high_gap_count}") + + if report.gaps: + print(f"{'─' * W}") + print(" Top Gaps (up to 8):") + icons = { + "critical": "🔴", "high": "🟠", + "medium": "🟡", "low": "🔵", + } + for g in report.gaps[:8]: + icon = icons.get(g.severity, "⚪") + print(f" {icon} [{g.gap_type}] {g.target}") + print(f" → {g.suggestion}") + + print(f"{'─' * W}") + print(f" Output : {out.resolve()}") + print(f" • coverage_report.html ← open in browser") + print(f" • coverage_report.json") + print(f" • run_results.json") + print(f" • test_cases_summary.json") + print(f"{'═' * W}\n") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ai_test_generator/requirements.txt b/ai_test_generator/requirements.txt new file mode 100644 index 0000000..16e06ee --- /dev/null +++ b/ai_test_generator/requirements.txt @@ -0,0 +1,2 @@ +openai>=1.30.0 +requests>=2.31.0 \ No newline at end of file diff --git a/ai_test_generator/run.sh b/ai_test_generator/run.sh new file mode 100755 index 0000000..4cc0dc3 --- /dev/null +++ b/ai_test_generator/run.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +# 安装依赖 +pip install -r requirements.txt + +# 设置环境变量 +export LLM_API_KEY="sk-AUmOuFI731Ty5Nob38jY26d8lydfDT-QkE2giqb0sCuPCAE2JH6zjLM4lZLpvL5WMYPOocaMe2FwVDmqM_9KimmKACjR" +export LLM_BASE_URL="https://openapi.monica.im/v1" # 或其他兼容接口 +export LLM_MODEL="gpt-4o" +export HTTP_BASE_URL="http://localhost:8080" + +# 运行(生成 + 执行) +# python main.py \ +# --api-desc examples/api_desc.json \ +# --requirements "创建用户,自动生成测试数据,更改指定用户的密码,自动生成测试数据" + +# 只生成不执行 +python main.py --api-desc examples/function_signatures.json \ + --requirements "1.对每个接口进行测试,支持特殊值、边界值测试" diff --git a/auto-test.txt b/auto-test.txt new file mode 100644 index 0000000..ed0e1fa --- /dev/null +++ b/auto-test.txt @@ -0,0 +1,59 @@ +你现在是一名高级ai软件工程师,非常擅长通过llm模型解决软件应用需求,现在需要构建python工程利用ai实现基于api的自动化测试代码生成与执行,支持将用户每个测试需求分解成对1个或多个接口的调用,并基于接口返回描述对接口返回值进行测试,设计如下: +# 工作流程 +1. 用户输入测试需求、接口描述文件(文件采用json格式,包含接口名称、请求参数描述、返回参数描述,接口描述等信息)。例如: +## 接口描述示例 +[ + { + "name": "/api/create_user", + "description": "to create new user account in the system", + "protocol": "http", + "base_url": "http://123.0.0.1:8000" + "method": "post", + "url_parameters": { + "version": { + "type": "integer", + "description": "the api's version", + "required": "false" + } + }, + "request_parameters": { + "username": { "type": "string", "description": "user name", "required": true }, + "email": { "type": "string", "description": "the user's email", "required": false }, + "password": { "type": "string", "description": "the user's password", "required": true } + }, + "response": { + "code": { + "type": "integer", + "description": "0 is successful, nonzero is failure" + }, + "user_id": { + "type": "integer", + "description": "the created user's id " + } + } + }, + { + "name": "change_password", + "description": "replace old password with new password", + "protocol": "function", + "parameters": { + "user_id": { "type": "integer", "inout","in", "description": "the user's id", "required": true }, + "old_password": { "type": "string", "inout","in", "description": "the old password", "required": true }, + "new_password": { "type": "string", "inout","in", "description": "the user's new password", "required": true } + }, + "return": { + "type": "integer", + "description": "0 is successful, nonzero is failure" + } + }, + ... +] +## 测试需求示例 +(1)创建用户,自动生成测试数据 +(2)更改指定用户的密码,自动生成测试数据;支持先创建用户,再改变该用户的密码; + +2. 软件将接口信息、指令、输出要求等通过提示词传递给llm模型。 +3. llm模型结合用户测试需求,并根据接口列表生成针对每个测试需求的测试用例(包含生成的测试数据、代码等),测试用例支持基于测试逻辑对多个接口的调用与返回值测试 +4. 输出的测试用例以json格式返回。 +5. 软件解析json数据并生成测试用例文件 +6. 软件运行测试用例文件并输出每个测试用例的结果; diff --git a/init_request.json b/init_request.json new file mode 100644 index 0000000..e69de29 diff --git a/mcp-client.py b/mcp-client.py new file mode 100644 index 0000000..6f38b6a --- /dev/null +++ b/mcp-client.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +import requests +import json + +session_id="123456789" + +class MCPClient: + def __init__(self, url): + self.url = url + self.session = requests.Session() + self.session_id = None + + def initialize(self): + """初始化 MCP 会话""" + payload = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "python-client", + "version": "1.0.0" + } + } + } + + response = self.session.post( + self.url, + json=payload, + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + ) + self.session_id = response.headers['Mcp-Session-Id'] + + + def call_method(self, method, params=None, request_id=None): + """调用 MCP 方法""" + payload = { + "jsonrpc": "2.0", + "id": request_id or method, + "method": method + } + + if params: + payload["params"] = params + + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream" + } + + # 如果有 session ID,添加到头部 + if self.session_id: + headers["Mcp-Session-Id"] = self.session_id + + response = self.session.post( + self.url, + json=payload, + headers=headers + ) + print(response.text) + return response.json() + + +# 使用示例 +if __name__ == "__main__": + client = MCPClient("http://iot.gesukj.com:7999/mcp") + + # 初始化 + print("Initializing...") + init_result = client.initialize() + # print(json.dumps(init_result, indent=2)) + + # 列出工具 + print("\nListing tools...") + tools_result = client.call_method("tools/list", request_id=4) + print(json.dumps(tools_result, indent=2)) diff --git a/requirements_generator/__init__.py b/requirements_generator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/requirements_generator/config.py b/requirements_generator/config.py new file mode 100644 index 0000000..a9702aa --- /dev/null +++ b/requirements_generator/config.py @@ -0,0 +1,146 @@ +# config.py - 全局配置管理 +import os +from dotenv import load_dotenv + +load_dotenv() + +# ── LLM 配置 ────────────────────────────────────────── +LLM_API_KEY = os.getenv("OPENAI_API_KEY", "") +LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") +LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o") +LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0.3")) + +# ── 数据库配置 ───────────────────────────────────────── +DB_PATH = os.getenv("DB_PATH", "data/requirement_analyzer.db") + +# ── 输出配置 ─────────────────────────────────────────── +OUTPUT_BASE_DIR = os.getenv("OUTPUT_BASE_DIR", "output") +DEFAULT_LANGUAGE = os.getenv("DEFAULT_LANGUAGE", "python") + +# ══════════════════════════════════════════════════════ +# Prompt 模板 +# ══════════════════════════════════════════════════════ + +DECOMPOSE_PROMPT_TEMPLATE = """ +你是一位资深软件架构师和产品经理。请根据以下信息,将原始需求分解为若干个可独立实现的功能需求。 + +{knowledge_section} + +## 原始需求 +{raw_requirement} + +## 输出要求 +请严格按照以下 JSON 格式输出,不要包含任何额外说明: +{{ + "functional_requirements": [ + {{ + "index": 1, + "title": "功能需求标题(简洁,10字以内)", + "description": "功能需求详细描述(包含输入、处理逻辑、输出)", + "function_name": "snake_case函数名", + "priority": "high|medium|low" + }} + ] +}} + +要求: +1. 每个功能需求必须是独立可实现的最小单元 +2. function_name 使用 snake_case 命名,清晰表达函数用途 +3. 分解粒度适中,通常 5-15 个功能需求 +4. 优先级根据业务重要性判断 +""" + +# ── 函数签名 JSON 生成 Prompt ────────────────────────── +FUNC_SIGNATURE_PROMPT_TEMPLATE = """ +你是一位资深软件架构师。请根据以下功能需求描述,设计该函数的完整接口签名,并以 JSON 格式输出。 + +{knowledge_section} + +## 功能需求 +需求编号:{requirement_id} +标题:{title} +函数名:{function_name} +详细描述:{description} + +## 输出格式 +请严格按照以下 JSON 结构输出,不要包含任何额外说明或 markdown 标记: +{{ + "name": "{function_name}", + "requirement_id": "{requirement_id}", + "description": "简洁的一句话功能描述(英文)", + "type": "function", + "parameters": {{ + "": {{ + "type": "integer|string|boolean|float|list|dict|object", + "inout": "in|out|inout", + "description": "参数说明(英文)", + "required": true + }} + }}, + "return": {{ + "type": "integer|string|boolean|float|list|dict|object|void", + "description": "整体返回值说明(英文,一句话概括)", + "on_success": {{ + "value": "具体成功返回值或范围,如 0、true、user object、list of items 等", + "description": "成功时的返回值含义(英文)" + }}, + "on_failure": {{ + "value": "具体失败返回值或范围,如 nonzero、false、null、empty list、raises Exception 等", + "description": "失败时的返回值含义,或抛出的异常类型(英文)" + }} + }} +}} + +## 设计规范 +1. 参数名使用 snake_case,类型使用通用类型(不绑定具体语言) +2. inout 字段含义: + - in = 仅输入参数 + - out = 仅输出参数(通过参数传出结果,如指针/引用) + - inout = 既作输入又作输出 +3. 所有描述字段使用英文 +4. return 字段规则: + - 若函数无返回值(void),type 填 "void",on_success/on_failure 均填 null + - 若返回值只有成功场景(如纯查询),on_failure 可描述为 "null or empty" + - on_success.value / on_failure.value 填写具体值或值域描述,不要填写空字符串 +5. 若函数无参数,parameters 填 {{}} +6. required 字段为布尔值 true 或 false +""" + +# ── 代码生成 Prompt(含签名约束)───────────────────────── +CODE_GEN_PROMPT_TEMPLATE = """ +你是一位资深 {language} 工程师。请根据以下功能需求和【函数签名规范】,生成完整的 {language} 函数代码。 + +{knowledge_section} + +## 功能需求 +标题:{title} +描述:{description} + +## 【必须严格遵守】函数签名规范 +以下 JSON 定义了函数的精确接口,生成的代码必须与之完全一致,不得擅自增减或改名参数: + +```json +{signature_json} +``` + +### 签名字段说明 +- `name`:函数名,必须完全一致 +- `parameters`:每个 key 即为参数名,`type` 为数据类型,`inout` 含义: + - `in` = 普通输入参数 + - `out` = 输出参数(Python 中通过返回值或可变容器传出) + - `inout` = 既作输入又作输出 +- `return.type`:返回值类型 +- `return.on_success`:成功时的返回值,代码实现必须与此一致 +- `return.on_failure`:失败时的返回值或异常,代码实现必须与此一致 + +## 输出要求 +1. 只输出纯代码,不要包含 markdown 代码块标记 +2. 函数签名(名称、参数列表、返回类型)必须与上方 JSON 规范完全一致 +3. 成功/失败的返回值必须严格遵守 return.on_success / return.on_failure 的定义 +4. 包含完整的类型注解(Python 使用 type hints) +5. 包含详细的 docstring,其中 Returns 段须注明成功值与失败值 +6. 包含必要的异常处理 +7. 代码风格遵循 PEP8(Python)或对应语言规范 +8. 在文件顶部用注释注明:需求编号、功能标题、函数签名摘要 +9. 如需导入第三方库,请在顶部统一导入 +""" \ No newline at end of file diff --git a/requirements_generator/core/__init__.py b/requirements_generator/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/requirements_generator/core/code_generator.py b/requirements_generator/core/code_generator.py new file mode 100644 index 0000000..2418726 --- /dev/null +++ b/requirements_generator/core/code_generator.py @@ -0,0 +1,216 @@ +# core/code_generator.py - 代码生成核心逻辑(签名约束版) +import json +from typing import Optional, List, Callable + +import config +from core.llm_client import LLMClient +from database.models import FunctionalRequirement, CodeFile +from utils.output_writer import write_code_file, get_file_extension + + +class CodeGenerator: + """ + 根据功能需求 + 函数签名约束,使用 LLM 生成代码函数文件。 + 签名由 RequirementAnalyzer.build_function_signature() 预先生成, + 注入 Prompt 后可确保代码参数列表与签名 JSON 完全一致。 + """ + + def __init__(self, llm_client: Optional[LLMClient] = None): + """ + 初始化代码生成器 + + Args: + llm_client: LLM 客户端实例,为 None 时自动创建 + """ + self.llm = llm_client or LLMClient() + + # ══════════════════════════════════════════════════ + # 单个生成 + # ══════════════════════════════════════════════════ + + def generate( + self, + func_req: FunctionalRequirement, + output_dir: str, + language: str = config.DEFAULT_LANGUAGE, + knowledge: str = "", + signature: Optional[dict] = None, + ) -> CodeFile: + """ + 为单个功能需求生成代码文件 + + Args: + func_req: 功能需求对象(必须含有效 id) + output_dir: 代码输出目录 + language: 目标编程语言 + knowledge: 知识库文本(可选) + signature: 函数签名 dict(由 RequirementAnalyzer 生成)。 + 传入后将作为强约束注入 Prompt,确保代码参数 + 与签名 JSON 完全一致;为 None 时退化为无约束模式。 + + Returns: + CodeFile 对象(含生成的代码内容和文件路径,未持久化) + + Raises: + ValueError: func_req.id 为 None + Exception: LLM 调用失败或文件写入失败 + """ + if func_req.id is None: + raise ValueError("FunctionalRequirement 必须先持久化(id 不能为 None)") + + knowledge_section = self._build_knowledge_section(knowledge) + signature_json = self._build_signature_json(signature, func_req) + + prompt = config.CODE_GEN_PROMPT_TEMPLATE.format( + language=language, + knowledge_section=knowledge_section, + title=func_req.title, + description=func_req.description, + signature_json=signature_json, + ) + + code_content = self.llm.chat( + system_prompt=( + f"你是一位资深 {language} 工程师,只输出纯代码," + "不添加任何 markdown 标记。函数签名必须与提供的 JSON 规范完全一致。" + ), + user_prompt=prompt, + ) + + file_path = write_code_file( + output_dir=output_dir, + function_name=func_req.function_name, + language=language, + content=code_content, + ) + + ext = get_file_extension(language) + file_name = f"{func_req.function_name}{ext}" + + return CodeFile( + project_id=func_req.project_id, + func_req_id=func_req.id, + file_name=file_name, + file_path=file_path, + language=language, + content=code_content, + ) + + # ══════════════════════════════════════════════════ + # 批量生成 + # ══════════════════════════════════════════════════ + + def generate_batch( + self, + func_reqs: List[FunctionalRequirement], + output_dir: str, + language: str = config.DEFAULT_LANGUAGE, + knowledge: str = "", + signatures: Optional[List[dict]] = None, + on_progress: Optional[Callable] = None, + ) -> List[CodeFile]: + """ + 批量生成代码文件 + + Args: + func_reqs: 功能需求列表 + output_dir: 输出目录 + language: 目标语言 + knowledge: 知识库文本 + signatures: 与 func_reqs 等长的签名列表(索引对应)。 + 为 None 时所有条目均以无约束模式生成。 + on_progress: 进度回调 fn(index, total, func_req, code_file, error) + + Returns: + 成功生成的 CodeFile 列表 + """ + results = [] + total = len(func_reqs) + + # 构建 func_req.id → signature 的快速查找表 + sig_map = self._build_signature_map(func_reqs, signatures) + + for i, req in enumerate(func_reqs): + sig = sig_map.get(req.id) + try: + code_file = self.generate( + func_req=req, + output_dir=output_dir, + language=language, + knowledge=knowledge, + signature=sig, + ) + results.append(code_file) + if on_progress: + on_progress(i + 1, total, req, code_file, None) + except Exception as e: + if on_progress: + on_progress(i + 1, total, req, None, e) + + return results + + # ══════════════════════════════════════════════════ + # 私有工具方法 + # ══════════════════════════════════════════════════ + + @staticmethod + def _build_knowledge_section(knowledge: str) -> str: + """构建知识库 Prompt 段落""" + if not knowledge or not knowledge.strip(): + return "" + return ( + "## 参考知识库(实现时请遵循以下规范)\n" + f"{knowledge}\n\n---\n" + ) + + @staticmethod + def _build_signature_json( + signature: Optional[dict], + func_req: FunctionalRequirement, + ) -> str: + """ + 将签名 dict 序列化为格式化 JSON 字符串; + 若签名为 None,则构造最小占位签名,保持 Prompt 结构完整。 + + Args: + signature: 签名 dict 或 None + func_req: 对应的功能需求(用于占位签名) + + Returns: + JSON 字符串 + """ + if signature: + return json.dumps(signature, ensure_ascii=False, indent=2) + # 无签名时的最小占位,提示 LLM 自行设计但保持格式 + fallback = { + "name": func_req.function_name, + "requirement_id": f"REQ.{func_req.index_no:02d}", + "description": func_req.description, + "type": "function", + "parameters": "<<请根据功能描述自行设计参数>>", + "return": "<<请根据功能描述自行设计返回值>>", + } + return json.dumps(fallback, ensure_ascii=False, indent=2) + + @staticmethod + def _build_signature_map( + func_reqs: List[FunctionalRequirement], + signatures: Optional[List[dict]], + ) -> dict: + """ + 构建 func_req.id → signature 映射表 + + Args: + func_reqs: 功能需求列表 + signatures: 与 func_reqs 等长的签名列表,或 None + + Returns: + {req_id: signature_dict} 字典 + """ + if not signatures: + return {} + sig_map = {} + for req, sig in zip(func_reqs, signatures): + if req.id is not None and sig: + sig_map[req.id] = sig + return sig_map \ No newline at end of file diff --git a/requirements_generator/core/llm_client.py b/requirements_generator/core/llm_client.py new file mode 100644 index 0000000..9a2701e --- /dev/null +++ b/requirements_generator/core/llm_client.py @@ -0,0 +1,90 @@ +# core/llm_client.py - LLM 客户端封装 +import json +from typing import Optional + +import config + + +class LLMClient: + """ + OpenAI 兼容 LLM 客户端封装。 + 支持任何兼容 OpenAI API 格式的服务(OpenAI / Azure / 本地模型等)。 + """ + + def __init__( + self, + api_key: str = config.LLM_API_KEY, + base_url: str = config.LLM_BASE_URL, + model: str = config.LLM_MODEL, + temperature: float = config.LLM_TEMPERATURE, + ): + """ + 初始化 LLM 客户端 + + Args: + api_key: API 密钥 + base_url: API 基础 URL + model: 模型名称 + temperature: 生成温度(0~1,越低越确定) + + Raises: + ImportError: 未安装 openai 库 + ValueError: api_key 为空 + """ + try: + from openai import OpenAI + except ImportError: + raise ImportError("请安装 openai: pip install openai") + + if not api_key: + raise ValueError("LLM_API_KEY 未配置,请在 .env 文件中设置") + + self.model = model + self.temperature = temperature + self._client = OpenAI(api_key=api_key, base_url=base_url) + + def chat(self, system_prompt: str, user_prompt: str) -> str: + """ + 发送对话请求,返回模型回复文本 + + Args: + system_prompt: 系统提示词 + user_prompt: 用户输入 + + Returns: + 模型回复的文本内容 + + Raises: + Exception: API 调用失败 + """ + response = self._client.chat.completions.create( + model=self.model, + temperature=self.temperature, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + ) + return response.choices[0].message.content.strip() + + def chat_json(self, system_prompt: str, user_prompt: str) -> dict: + """ + 发送对话请求,解析并返回 JSON 结果 + + Args: + system_prompt: 系统提示词 + user_prompt: 用户输入 + + Returns: + 解析后的 dict 对象 + + Raises: + json.JSONDecodeError: 模型返回非合法 JSON + """ + raw = self.chat(system_prompt, user_prompt) + # 去除可能的 markdown 代码块包裹 + raw = raw.strip() + if raw.startswith("```"): + lines = raw.split("\n") + raw = "\n".join(lines[1:-1]) + return json.loads(raw) \ No newline at end of file diff --git a/requirements_generator/core/requirement_analyzer.py b/requirements_generator/core/requirement_analyzer.py new file mode 100644 index 0000000..15adda7 --- /dev/null +++ b/requirements_generator/core/requirement_analyzer.py @@ -0,0 +1,239 @@ +# core/requirement_analyzer.py - 需求分解 & 函数签名生成 +import re +from typing import List, Optional + +import config +from core.llm_client import LLMClient +from database.models import FunctionalRequirement + + +class RequirementAnalyzer: + """ + 使用 LLM 将原始需求分解为功能需求列表,并生成函数接口签名。 + 支持注入知识库上下文以提升分解质量。 + """ + + def __init__(self, llm_client: Optional[LLMClient] = None): + """ + 初始化需求分析器 + + Args: + llm_client: LLM 客户端实例,为 None 时自动创建 + """ + self.llm = llm_client or LLMClient() + + # ══════════════════════════════════════════════════ + # 需求分解 + # ══════════════════════════════════════════════════ + + def decompose( + self, + raw_requirement: str, + project_id: int, + raw_req_id: int, + knowledge: str = "", + ) -> List[FunctionalRequirement]: + """ + 将原始需求分解为功能需求列表 + + Args: + raw_requirement: 原始需求文本 + project_id: 所属项目 ID + raw_req_id: 原始需求记录 ID + knowledge: 知识库文本(可选) + + Returns: + FunctionalRequirement 对象列表(未持久化,id=None) + + Raises: + ValueError: LLM 返回格式不合法 + json.JSONDecodeError: JSON 解析失败 + """ + knowledge_section = self._build_knowledge_section(knowledge) + prompt = config.DECOMPOSE_PROMPT_TEMPLATE.format( + knowledge_section=knowledge_section, + raw_requirement=raw_requirement, + ) + + result = self.llm.chat_json( + system_prompt="你是一位资深软件架构师,擅长需求分析与系统设计。", + user_prompt=prompt, + ) + + items = result.get("functional_requirements", []) + if not items: + raise ValueError("LLM 未返回任何功能需求,请检查原始需求描述") + + requirements = [] + for item in items: + req = FunctionalRequirement( + project_id=project_id, + raw_req_id=raw_req_id, + index_no=int(item.get("index", len(requirements) + 1)), + title=item.get("title", "未命名功能"), + description=item.get("description", ""), + function_name=self._sanitize_function_name( + item.get("function_name", f"func_{len(requirements)+1}") + ), + priority=item.get("priority", "medium"), + ) + requirements.append(req) + + return requirements + + # ══════════════════════════════════════════════════ + # 函数签名生成(新增) + # ══════════════════════════════════════════════════ + + def build_function_signature( + self, + func_req: FunctionalRequirement, + knowledge: str = "", + ) -> dict: + """ + 为单个功能需求生成函数接口签名 JSON + + Args: + func_req: 功能需求对象(需含有效 id) + knowledge: 知识库文本(可选) + + Returns: + 符合接口规范的 dict,包含 name/requirement_id/description/ + type/parameters/return 字段 + + Raises: + json.JSONDecodeError: LLM 返回非合法 JSON + """ + requirement_id = self._format_requirement_id(func_req.index_no) + knowledge_section = self._build_knowledge_section(knowledge) + + prompt = config.FUNC_SIGNATURE_PROMPT_TEMPLATE.format( + knowledge_section=knowledge_section, + requirement_id=requirement_id, + title=func_req.title, + function_name=func_req.function_name, + description=func_req.description, + ) + + signature = self.llm.chat_json( + system_prompt=( + "你是一位资深软件架构师,专注于 API 接口设计。" + "只输出合法 JSON,不添加任何说明文字。" + ), + user_prompt=prompt, + ) + + # 确保关键字段存在,做兜底处理 + signature.setdefault("name", func_req.function_name) + signature.setdefault("requirement_id", requirement_id) + signature.setdefault("description", func_req.description) + signature.setdefault("type", "function") + signature.setdefault("parameters", {}) + signature.setdefault("return", {"type": "void", "description": ""}) + + return signature + + def build_function_signatures_batch( + self, + func_reqs: List[FunctionalRequirement], + knowledge: str = "", + on_progress=None, + ) -> List[dict]: + """ + 批量为功能需求列表生成函数接口签名 + + Args: + func_reqs: 功能需求列表 + knowledge: 知识库文本(可选) + on_progress: 进度回调 fn(index, total, func_req, signature, error) + + Returns: + 签名 dict 列表,顺序与 func_reqs 一致; + 生成失败的条目使用降级结构填充,不中断整体流程 + """ + results = [] + total = len(func_reqs) + + for i, req in enumerate(func_reqs): + try: + sig = self.build_function_signature(req, knowledge) + results.append(sig) + if on_progress: + on_progress(i + 1, total, req, sig, None) + except Exception as e: + # 降级:用基础信息填充,保证 JSON 完整性 + fallback = self._build_fallback_signature(req) + results.append(fallback) + if on_progress: + on_progress(i + 1, total, req, fallback, e) + + return results + + # ══════════════════════════════════════════════════ + # 私有工具方法 + # ══════════════════════════════════════════════════ + + @staticmethod + def _build_knowledge_section(knowledge: str) -> str: + """构建知识库 Prompt 段落""" + if not knowledge or not knowledge.strip(): + return "" + return f"""## 参考知识库 +{knowledge} + +--- +""" + + @staticmethod + def _sanitize_function_name(name: str) -> str: + """ + 清理函数名,确保符合 snake_case 规范 + + Args: + name: 原始函数名 + + Returns: + 合法的 snake_case 函数名 + """ + name = re.sub(r"[^a-zA-Z0-9_]", "_", name).lower() + name = re.sub(r"_+", "_", name).strip("_") + if name and name[0].isdigit(): + name = "func_" + name + return name or "unnamed_function" + + @staticmethod + def _format_requirement_id(index_no: int) -> str: + """ + 将序号格式化为需求编号字符串 + + Args: + index_no: 功能需求序号(从 1 开始) + + Returns: + 格式化编号,如 'REQ.01'、'REQ.12' + """ + return f"REQ.{index_no:02d}" + + @staticmethod + def _build_fallback_signature(func_req: FunctionalRequirement) -> dict: + """ + 构建降级签名(LLM 调用失败时使用) + + Args: + func_req: 功能需求对象 + + Returns: + 包含基础信息的签名 dict + """ + return { + "name": func_req.function_name, + "requirement_id": f"REQ.{func_req.index_no:02d}", + "description": func_req.description, + "type": "function", + "parameters": {}, + "return": { + "type": "void", + "description": "TODO: define return value" + }, + "_note": "Auto-generated fallback due to LLM error" + } \ No newline at end of file diff --git a/requirements_generator/database/__init__.py b/requirements_generator/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/requirements_generator/database/db_manager.py b/requirements_generator/database/db_manager.py new file mode 100644 index 0000000..95ad904 --- /dev/null +++ b/requirements_generator/database/db_manager.py @@ -0,0 +1,314 @@ +# database/db_manager.py - 数据库操作管理器 +import sqlite3 +import os +from datetime import datetime +from typing import List, Optional +from contextlib import contextmanager + +from database.models import ( + CREATE_TABLES_SQL, Project, RawRequirement, + FunctionalRequirement, CodeFile +) +import config + + +class DBManager: + """SQLite 数据库管理器,封装所有 CRUD 操作""" + + def __init__(self, db_path: str = config.DB_PATH): + self.db_path = db_path + os.makedirs(os.path.dirname(db_path), exist_ok=True) + self._init_db() + + # ── 连接上下文管理器 ────────────────────────────────── + + @contextmanager + def _get_conn(self): + """获取数据库连接(自动提交/回滚)""" + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys = ON") + try: + yield conn + conn.commit() + except Exception: + conn.rollback() + raise + finally: + conn.close() + + def _init_db(self): + """初始化数据库,创建所有表""" + with self._get_conn() as conn: + conn.executescript(CREATE_TABLES_SQL) + + # ══════════════════════════════════════════════════ + # Project CRUD + # ══════════════════════════════════════════════════ + + def create_project(self, project: Project) -> int: + """创建项目,返回新项目 ID""" + sql = """ + INSERT INTO projects (name, description, language, output_dir, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + """ + with self._get_conn() as conn: + cur = conn.execute(sql, ( + project.name, project.description, project.language, + project.output_dir, project.created_at, project.updated_at + )) + return cur.lastrowid + + def get_project_by_id(self, project_id: int) -> Optional[Project]: + """根据 ID 查询项目""" + with self._get_conn() as conn: + row = conn.execute( + "SELECT * FROM projects WHERE id = ?", (project_id,) + ).fetchone() + if row is None: + return None + return Project( + id=row["id"], name=row["name"], description=row["description"], + language=row["language"], output_dir=row["output_dir"], + created_at=row["created_at"], updated_at=row["updated_at"] + ) + + def get_project_by_name(self, name: str) -> Optional[Project]: + """根据名称查询项目""" + with self._get_conn() as conn: + row = conn.execute( + "SELECT * FROM projects WHERE name = ?", (name,) + ).fetchone() + if row is None: + return None + return Project( + id=row["id"], name=row["name"], description=row["description"], + language=row["language"], output_dir=row["output_dir"], + created_at=row["created_at"], updated_at=row["updated_at"] + ) + + def list_projects(self) -> List[Project]: + """列出所有项目""" + with self._get_conn() as conn: + rows = conn.execute( + "SELECT * FROM projects ORDER BY created_at DESC" + ).fetchall() + return [ + Project( + id=r["id"], name=r["name"], description=r["description"], + language=r["language"], output_dir=r["output_dir"], + created_at=r["created_at"], updated_at=r["updated_at"] + ) for r in rows + ] + + def update_project(self, project: Project) -> None: + """更新项目信息""" + project.updated_at = datetime.now().isoformat() + sql = """ + UPDATE projects + SET name=?, description=?, language=?, output_dir=?, updated_at=? + WHERE id=? + """ + with self._get_conn() as conn: + conn.execute(sql, ( + project.name, project.description, project.language, + project.output_dir, project.updated_at, project.id + )) + + def delete_project(self, project_id: int) -> None: + """删除项目(级联删除所有关联数据)""" + with self._get_conn() as conn: + conn.execute("DELETE FROM projects WHERE id = ?", (project_id,)) + + # ══════════════════════════════════════════════════ + # RawRequirement CRUD + # ══════════════════════════════════════════════════ + + def create_raw_requirement(self, req: RawRequirement) -> int: + """创建原始需求,返回新记录 ID""" + sql = """ + INSERT INTO raw_requirements + (project_id, content, source_type, source_name, knowledge, created_at) + VALUES (?, ?, ?, ?, ?, ?) + """ + with self._get_conn() as conn: + cur = conn.execute(sql, ( + req.project_id, req.content, req.source_type, + req.source_name, req.knowledge, req.created_at + )) + return cur.lastrowid + + def get_raw_requirement(self, req_id: int) -> Optional[RawRequirement]: + """根据 ID 查询原始需求""" + with self._get_conn() as conn: + row = conn.execute( + "SELECT * FROM raw_requirements WHERE id = ?", (req_id,) + ).fetchone() + if row is None: + return None + return RawRequirement( + id=row["id"], project_id=row["project_id"], content=row["content"], + source_type=row["source_type"], source_name=row["source_name"], + knowledge=row["knowledge"], created_at=row["created_at"] + ) + + def list_raw_requirements_by_project(self, project_id: int) -> List[RawRequirement]: + """查询项目下所有原始需求""" + with self._get_conn() as conn: + rows = conn.execute( + "SELECT * FROM raw_requirements WHERE project_id = ? ORDER BY created_at", + (project_id,) + ).fetchall() + return [ + RawRequirement( + id=r["id"], project_id=r["project_id"], content=r["content"], + source_type=r["source_type"], source_name=r["source_name"], + knowledge=r["knowledge"], created_at=r["created_at"] + ) for r in rows + ] + + # ══════════════════════════════════════════════════ + # FunctionalRequirement CRUD + # ══════════════════════════════════════════════════ + + def create_functional_requirement(self, req: FunctionalRequirement) -> int: + """创建功能需求,返回新记录 ID""" + sql = """ + INSERT INTO functional_requirements + (project_id, raw_req_id, index_no, title, description, + function_name, priority, status, is_custom, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """ + with self._get_conn() as conn: + cur = conn.execute(sql, ( + req.project_id, req.raw_req_id, req.index_no, + req.title, req.description, req.function_name, + req.priority, req.status, int(req.is_custom), + req.created_at, req.updated_at + )) + return cur.lastrowid + + def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]: + """根据 ID 查询功能需求""" + with self._get_conn() as conn: + row = conn.execute( + "SELECT * FROM functional_requirements WHERE id = ?", (req_id,) + ).fetchone() + if row is None: + return None + return self._row_to_func_req(row) + + def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]: + """查询项目下所有功能需求(按序号排序)""" + with self._get_conn() as conn: + rows = conn.execute( + """SELECT * FROM functional_requirements + WHERE project_id = ? ORDER BY index_no""", + (project_id,) + ).fetchall() + return [self._row_to_func_req(r) for r in rows] + + def update_functional_requirement(self, req: FunctionalRequirement) -> None: + """更新功能需求""" + req.updated_at = datetime.now().isoformat() + sql = """ + UPDATE functional_requirements + SET title=?, description=?, function_name=?, priority=?, + status=?, index_no=?, updated_at=? + WHERE id=? + """ + with self._get_conn() as conn: + conn.execute(sql, ( + req.title, req.description, req.function_name, + req.priority, req.status, req.index_no, + req.updated_at, req.id + )) + + def delete_functional_requirement(self, req_id: int) -> None: + """删除功能需求""" + with self._get_conn() as conn: + conn.execute( + "DELETE FROM functional_requirements WHERE id = ?", (req_id,) + ) + + def _row_to_func_req(self, row) -> FunctionalRequirement: + """sqlite Row → FunctionalRequirement 对象""" + return FunctionalRequirement( + id=row["id"], project_id=row["project_id"], + raw_req_id=row["raw_req_id"], index_no=row["index_no"], + title=row["title"], description=row["description"], + function_name=row["function_name"], priority=row["priority"], + status=row["status"], is_custom=bool(row["is_custom"]), + created_at=row["created_at"], updated_at=row["updated_at"] + ) + + # ══════════════════════════════════════════════════ + # CodeFile CRUD + # ══════════════════════════════════════════════════ + + def create_code_file(self, code_file: CodeFile) -> int: + """创建代码文件记录,返回新记录 ID""" + sql = """ + INSERT INTO code_files + (project_id, func_req_id, file_name, file_path, + language, content, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """ + with self._get_conn() as conn: + cur = conn.execute(sql, ( + code_file.project_id, code_file.func_req_id, + code_file.file_name, code_file.file_path, + code_file.language, code_file.content, + code_file.created_at, code_file.updated_at + )) + return cur.lastrowid + + def upsert_code_file(self, code_file: CodeFile) -> int: + """插入或更新代码文件(按 func_req_id 唯一键)""" + existing = self.get_code_file_by_func_req(code_file.func_req_id) + if existing: + code_file.id = existing.id + code_file.updated_at = datetime.now().isoformat() + sql = """ + UPDATE code_files + SET file_name=?, file_path=?, language=?, content=?, updated_at=? + WHERE id=? + """ + with self._get_conn() as conn: + conn.execute(sql, ( + code_file.file_name, code_file.file_path, + code_file.language, code_file.content, + code_file.updated_at, code_file.id + )) + return code_file.id + else: + return self.create_code_file(code_file) + + def get_code_file_by_func_req(self, func_req_id: int) -> Optional[CodeFile]: + """根据功能需求 ID 查询代码文件""" + with self._get_conn() as conn: + row = conn.execute( + "SELECT * FROM code_files WHERE func_req_id = ?", (func_req_id,) + ).fetchone() + if row is None: + return None + return self._row_to_code_file(row) + + def list_code_files_by_project(self, project_id: int) -> List[CodeFile]: + """查询项目下所有代码文件""" + with self._get_conn() as conn: + rows = conn.execute( + "SELECT * FROM code_files WHERE project_id = ? ORDER BY created_at", + (project_id,) + ).fetchall() + return [self._row_to_code_file(r) for r in rows] + + def _row_to_code_file(self, row) -> CodeFile: + """sqlite Row → CodeFile 对象""" + return CodeFile( + id=row["id"], project_id=row["project_id"], + func_req_id=row["func_req_id"], file_name=row["file_name"], + file_path=row["file_path"], language=row["language"], + content=row["content"], created_at=row["created_at"], + updated_at=row["updated_at"] + ) \ No newline at end of file diff --git a/requirements_generator/database/models.py b/requirements_generator/database/models.py new file mode 100644 index 0000000..b077013 --- /dev/null +++ b/requirements_generator/database/models.py @@ -0,0 +1,122 @@ +# database/models.py - 数据模型定义(SQLite 建表 DDL) +from dataclasses import dataclass, field +from datetime import datetime +from typing import Optional + +# ══════════════════════════════════════════════════════ +# DDL 建表语句 +# ══════════════════════════════════════════════════════ + +CREATE_TABLES_SQL = """ +PRAGMA foreign_keys = ON; + +-- 项目表 +CREATE TABLE IF NOT EXISTS projects ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, -- 项目名称 + description TEXT, -- 项目描述 + language TEXT NOT NULL DEFAULT 'python', -- 目标代码语言 + output_dir TEXT NOT NULL, -- 输出目录路径 + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +); + +-- 原始需求表 +CREATE TABLE IF NOT EXISTS raw_requirements ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project_id INTEGER NOT NULL, -- 关联项目 + content TEXT NOT NULL, -- 需求原文 + source_type TEXT NOT NULL DEFAULT 'text', -- text | file + source_name TEXT, -- 文件名(文件输入时) + knowledge TEXT, -- 合并后的知识库内容 + created_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE +); + +-- 功能需求表 +CREATE TABLE IF NOT EXISTS functional_requirements ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project_id INTEGER NOT NULL, + raw_req_id INTEGER NOT NULL, -- 关联原始需求 + index_no INTEGER NOT NULL, -- 序号 + title TEXT NOT NULL, -- 功能标题 + description TEXT NOT NULL, -- 功能描述 + function_name TEXT NOT NULL, -- 对应函数名 + priority TEXT NOT NULL DEFAULT 'medium', -- high|medium|low + status TEXT NOT NULL DEFAULT 'pending', -- pending|generated|skipped + is_custom INTEGER NOT NULL DEFAULT 0, -- 是否用户自定义 (0/1) + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, + FOREIGN KEY (raw_req_id) REFERENCES raw_requirements(id) ON DELETE CASCADE +); + +-- 代码文件表 +CREATE TABLE IF NOT EXISTS code_files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project_id INTEGER NOT NULL, + func_req_id INTEGER NOT NULL UNIQUE, -- 关联功能需求(1对1) + file_name TEXT NOT NULL, -- 文件名 + file_path TEXT NOT NULL, -- 完整路径 + language TEXT NOT NULL, -- 代码语言 + content TEXT NOT NULL, -- 代码内容 + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, + FOREIGN KEY (func_req_id) REFERENCES functional_requirements(id) ON DELETE CASCADE +); +""" + +# ══════════════════════════════════════════════════════ +# 数据类(Python 对象映射) +# ══════════════════════════════════════════════════════ + +@dataclass +class Project: + name: str + output_dir: str + language: str = "python" + description: str = "" + id: Optional[int] = None + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) + + +@dataclass +class RawRequirement: + project_id: int + content: str + source_type: str = "text" # text | file + source_name: Optional[str] = None + knowledge: Optional[str] = None + id: Optional[int] = None + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) + + +@dataclass +class FunctionalRequirement: + project_id: int + raw_req_id: int + index_no: int + title: str + description: str + function_name: str + priority: str = "medium" + status: str = "pending" + is_custom: bool = False + id: Optional[int] = None + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) + + +@dataclass +class CodeFile: + project_id: int + func_req_id: int + file_name: str + file_path: str + language: str + content: str + id: Optional[int] = None + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) + updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) \ No newline at end of file diff --git a/requirements_generator/main.py b/requirements_generator/main.py new file mode 100644 index 0000000..fae0d51 --- /dev/null +++ b/requirements_generator/main.py @@ -0,0 +1,789 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# main.py - 主入口:支持交互式 & 非交互式(CLI 参数)两种运行模式 +# +# 交互式: python main.py +# 非交互式:python main.py --non-interactive \ +# --project-name "MyProject" \ +# --language python \ +# --requirement-text "用户管理系统,包含注册、登录、修改密码功能" +# +# 完整参数见:python main.py --help +import os +import sys +from typing import Dict + +import click +from rich.console import Console +from rich.table import Table +from rich.panel import Panel +from rich.prompt import Prompt, Confirm + +import config +from database.db_manager import DBManager +from database.models import Project, RawRequirement, FunctionalRequirement +from core.llm_client import LLMClient +from core.requirement_analyzer import RequirementAnalyzer +from core.code_generator import CodeGenerator +from utils.file_handler import read_file_auto, merge_knowledge_files +from utils.output_writer import ( + ensure_project_dir, build_project_output_dir, write_project_readme, + write_function_signatures_json, validate_all_signatures, + patch_signatures_with_url, +) + +console = Console() +db = DBManager() + + +# ══════════════════════════════════════════════════════ +# 显示工具函数 +# ══════════════════════════════════════════════════════ + +def print_banner(): + console.print(Panel.fit( + "[bold cyan]🚀 需求分析 & 代码生成工具[/bold cyan]\n" + "[dim]Powered by LLM · SQLite · Python[/dim]", + border_style="cyan" + )) + + +def print_functional_requirements(reqs: list): + """以表格形式展示功能需求列表""" + table = Table(title="📋 功能需求列表", show_lines=True) + table.add_column("序号", style="cyan", width=6) + table.add_column("ID", style="dim", width=6) + table.add_column("标题", style="bold", width=20) + table.add_column("函数名", width=25) + table.add_column("优先级", width=8) + table.add_column("类型", width=8) + table.add_column("描述", width=40) + + priority_color = {"high": "red", "medium": "yellow", "low": "green"} + for req in reqs: + color = priority_color.get(req.priority, "white") + table.add_row( + str(req.index_no), + str(req.id) if req.id else "-", + req.title, + f"[code]{req.function_name}[/code]", + f"[{color}]{req.priority}[/{color}]", + "[magenta]自定义[/magenta]" if req.is_custom else "LLM生成", + req.description[:60] + "..." if len(req.description) > 60 else req.description, + ) + console.print(table) + + +def print_signatures_preview(signatures: list): + """ + 以表格形式预览函数签名列表(含 url 字段) + + Args: + signatures: 纯签名列表(顶层文档的 "functions" 字段) + """ + table = Table(title="📄 函数签名预览", show_lines=True) + table.add_column("需求编号", style="cyan", width=10) + table.add_column("函数名", style="bold", width=25) + table.add_column("参数数量", width=8) + table.add_column("返回类型", width=10) + table.add_column("成功返回值", width=18) + table.add_column("失败返回值", width=18) + table.add_column("URL", style="dim", width=30) + + def _fmt_value(v) -> str: + if v is None: + return "-" + if isinstance(v, dict): + return "{" + ", ".join(v.keys()) + "}" + return str(v)[:16] + + for sig in signatures: + ret = sig.get("return") or {} + on_success = ret.get("on_success") or {} + on_failure = ret.get("on_failure") or {} + url = sig.get("url", "") + # 只显示文件名部分,避免路径过长 + url_display = os.path.basename(url) if url else "[dim]待生成[/dim]" + + table.add_row( + sig.get("requirement_id", "-"), + sig.get("name", "-"), + str(len(sig.get("parameters", {}))), + ret.get("type", "void"), + _fmt_value(on_success.get("value")), + _fmt_value(on_failure.get("value")), + url_display, + ) + console.print(table) + + +# ══════════════════════════════════════════════════════ +# Step 1:项目初始化 +# ══════════════════════════════════════════════════════ + +def step_init_project( + project_name: str = None, + language: str = None, + description: str = "", + non_interactive: bool = False, +) -> Project: + if not non_interactive: + console.print("\n[bold]Step 1 · 项目配置[/bold]", style="blue") + project_name = project_name or Prompt.ask("📁 请输入项目名称") + language = language or Prompt.ask( + "💻 目标代码语言", + default=config.DEFAULT_LANGUAGE, + choices=["python", "javascript", "typescript", "java", "go", "rust"], + ) + description = description or Prompt.ask("📝 项目描述(可选)", default="") + else: + if not project_name: + raise ValueError("非交互模式下 --project-name 为必填项") + language = language or config.DEFAULT_LANGUAGE + console.print("\n[bold]Step 1 · 项目配置[/bold] [dim](非交互)[/dim]", style="blue") + console.print(f" 项目名称: {project_name} 语言: {language}") + + existing = db.get_project_by_name(project_name) + if existing: + if non_interactive: + console.print(f"[green]✓ 已加载已有项目: {project_name} (ID={existing.id})[/green]") + return existing + use_existing = Confirm.ask(f"⚠️ 项目 '{project_name}' 已存在,是否继续使用?") + if use_existing: + console.print(f"[green]✓ 已加载项目: {project_name} (ID={existing.id})[/green]") + return existing + project_name = Prompt.ask("请输入新的项目名称") + + output_dir = build_project_output_dir(project_name) + project = Project( + name=project_name, + language=language, + output_dir=output_dir, + description=description, + ) + project.id = db.create_project(project) + console.print(f"[green]✓ 项目已创建: {project_name} (ID={project.id})[/green]") + return project + + +# ══════════════════════════════════════════════════════ +# Step 2:输入原始需求 & 知识库 +# ══════════════════════════════════════════════════════ + +def step_input_requirement( + project: Project, + requirement_text: str = None, + requirement_file: str = None, + knowledge_files: list = None, + non_interactive: bool = False, +) -> tuple: + console.print( + f"\n[bold]Step 2 · 输入原始需求[/bold]" + + (" [dim](非交互)[/dim]" if non_interactive else ""), + style="blue", + ) + + raw_text = "" + source_name = None + source_type = "text" + + if non_interactive: + if requirement_file: + raw_text = read_file_auto(requirement_file) + source_name = os.path.basename(requirement_file) + source_type = "file" + console.print(f" 需求文件: {source_name} ({len(raw_text)} 字符)") + elif requirement_text: + raw_text = requirement_text + source_type = "text" + console.print(f" 需求文本: {raw_text[:80]}{'...' if len(raw_text) > 80 else ''}") + else: + raise ValueError("非交互模式下必须提供 --requirement-text 或 --requirement-file") + else: + input_type = Prompt.ask("📥 需求输入方式", choices=["text", "file"], default="text") + if input_type == "text": + console.print("[dim]请输入原始需求(输入空行结束):[/dim]") + lines = [] + while True: + line = input() + if line == "" and lines: + break + lines.append(line) + raw_text = "\n".join(lines) + source_type = "text" + else: + file_path = Prompt.ask("📂 需求文件路径") + raw_text = read_file_auto(file_path) + source_name = os.path.basename(file_path) + source_type = "file" + console.print(f"[green]✓ 已读取文件: {source_name} ({len(raw_text)} 字符)[/green]") + + knowledge_text = "" + if non_interactive: + if knowledge_files: + knowledge_text = merge_knowledge_files(list(knowledge_files)) + console.print(f" 知识库: {len(knowledge_files)} 个文件,{len(knowledge_text)} 字符") + else: + use_kb = Confirm.ask("📚 是否输入知识库文件?", default=False) + if use_kb: + kb_paths = [] + while True: + kb_path = Prompt.ask("知识库文件路径(留空结束)", default="") + if not kb_path: + break + if os.path.exists(kb_path): + kb_paths.append(kb_path) + console.print(f" [green]+ {kb_path}[/green]") + else: + console.print(f" [red]文件不存在: {kb_path}[/red]") + if kb_paths: + knowledge_text = merge_knowledge_files(kb_paths) + console.print(f"[green]✓ 知识库已合并 ({len(knowledge_text)} 字符)[/green]") + + return raw_text, knowledge_text, source_name, source_type + + +# ══════════════════════════════════════════════════════ +# Step 3:LLM 分解需求 +# ══════════════════════════════════════════════════════ + +def step_decompose_requirements( + project: Project, + raw_text: str, + knowledge_text: str, + source_name: str, + source_type: str, + non_interactive: bool = False, +) -> tuple: + console.print( + f"\n[bold]Step 3 · LLM 需求分解[/bold]" + + (" [dim](非交互)[/dim]" if non_interactive else ""), + style="blue", + ) + + raw_req = RawRequirement( + project_id=project.id, + content=raw_text, + source_type=source_type, + source_name=source_name, + knowledge=knowledge_text or None, + ) + raw_req_id = db.create_raw_requirement(raw_req) + console.print(f"[dim]原始需求已存储 (ID={raw_req_id})[/dim]") + + with console.status("[bold yellow]🤖 LLM 正在分解需求,请稍候...[/bold yellow]"): + llm = LLMClient() + analyzer = RequirementAnalyzer(llm) + func_reqs = analyzer.decompose( + raw_requirement=raw_text, + project_id=project.id, + raw_req_id=raw_req_id, + knowledge=knowledge_text, + ) + + for req in func_reqs: + req.id = db.create_functional_requirement(req) + + console.print(f"[green]✓ 已生成 {len(func_reqs)} 个功能需求[/green]") + return raw_req_id, func_reqs + + +# ══════════════════════════════════════════════════════ +# Step 4:用户编辑功能需求 +# ══════════════════════════════════════════════════════ + +def step_edit_requirements( + project: Project, + func_reqs: list, + raw_req_id: int, + non_interactive: bool = False, + skip_indices: list = None, +) -> list: + console.print( + f"\n[bold]Step 4 · 编辑功能需求[/bold]" + + (" [dim](非交互)[/dim]" if non_interactive else ""), + style="blue", + ) + + if non_interactive: + if skip_indices: + to_skip = set(skip_indices) + removed, kept = [], [] + for req in func_reqs: + if req.index_no in to_skip: + db.delete_functional_requirement(req.id) + removed.append(req.title) + else: + kept.append(req) + func_reqs = kept + for i, req in enumerate(func_reqs, 1): + req.index_no = i + db.update_functional_requirement(req) + if removed: + console.print(f" [red]已跳过: {', '.join(removed)}[/red]") + print_functional_requirements(func_reqs) + return func_reqs + + while True: + print_functional_requirements(func_reqs) + console.print( + "\n操作: [cyan]d[/cyan]=删除 [cyan]a[/cyan]=添加 " + "[cyan]e[/cyan]=编辑 [cyan]ok[/cyan]=确认继续" + ) + action = Prompt.ask("请选择操作", default="ok").strip().lower() + + if action == "ok": + break + elif action == "d": + idx_str = Prompt.ask("输入要删除的功能需求序号(多个用逗号分隔)") + to_delete = {int(x.strip()) for x in idx_str.split(",") if x.strip().isdigit()} + removed, kept = [], [] + for req in func_reqs: + if req.index_no in to_delete: + db.delete_functional_requirement(req.id) + removed.append(req.title) + else: + kept.append(req) + func_reqs = kept + for i, req in enumerate(func_reqs, 1): + req.index_no = i + db.update_functional_requirement(req) + console.print(f"[red]✗ 已删除: {', '.join(removed)}[/red]") + elif action == "a": + title = Prompt.ask("功能标题") + description = Prompt.ask("功能描述") + func_name = Prompt.ask("函数名 (snake_case)") + priority = Prompt.ask( + "优先级", choices=["high", "medium", "low"], default="medium" + ) + new_req = FunctionalRequirement( + project_id=project.id, + raw_req_id=raw_req_id, + index_no=len(func_reqs) + 1, + title=title, + description=description, + function_name=func_name, + priority=priority, + is_custom=True, + ) + new_req.id = db.create_functional_requirement(new_req) + func_reqs.append(new_req) + console.print(f"[green]✓ 已添加自定义需求: {title}[/green]") + elif action == "e": + idx_str = Prompt.ask("输入要编辑的功能需求序号") + if not idx_str.isdigit(): + continue + idx = int(idx_str) + target = next((r for r in func_reqs if r.index_no == idx), None) + if target is None: + console.print("[red]序号不存在[/red]") + continue + target.title = Prompt.ask("新标题", default=target.title) + target.description = Prompt.ask("新描述", default=target.description) + target.function_name = Prompt.ask("新函数名", default=target.function_name) + target.priority = Prompt.ask( + "新优先级", choices=["high", "medium", "low"], default=target.priority + ) + db.update_functional_requirement(target) + console.print(f"[green]✓ 已更新: {target.title}[/green]") + + return func_reqs + + +# ══════════════════════════════════════════════════════ +# Step 5A:生成函数签名 JSON(不含 url 字段,待 5C 回写) +# ══════════════════════════════════════════════════════ + +def step_generate_signatures( + project: Project, + func_reqs: list, + output_dir: str, + knowledge_text: str, + json_file_name: str = "function_signatures.json", + non_interactive: bool = False, +) -> tuple: + """ + 为所有功能需求生成函数签名,写入初版 JSON(不含 url 字段)。 + url 字段将在 Step 5C 代码生成完成后回写并刷新 JSON 文件。 + + Returns: + (signatures: List[dict], json_path: str) + """ + console.print( + f"\n[bold]Step 5A · 生成函数签名 JSON[/bold]" + + (" [dim](非交互)[/dim]" if non_interactive else ""), + style="blue", + ) + + llm = LLMClient() + analyzer = RequirementAnalyzer(llm) + + success_count = 0 + fail_count = 0 + + def on_progress(index, total, req, signature, error): + nonlocal success_count, fail_count + if error: + console.print( + f" [{index}/{total}] [yellow]⚠ {req.title} 签名生成失败," + f"使用降级结构: {error}[/yellow]" + ) + fail_count += 1 + else: + console.print( + f" [{index}/{total}] [green]✓ {req.title}[/green] " + f"→ [dim]{signature.get('name')}()[/dim] " + f"params={len(signature.get('parameters', {}))}" + ) + success_count += 1 + + console.print(f"[yellow]正在为 {len(func_reqs)} 个功能需求生成函数签名...[/yellow]") + signatures = analyzer.build_function_signatures_batch( + func_reqs=func_reqs, + knowledge=knowledge_text, + on_progress=on_progress, + ) + + # 校验 + validation_report = validate_all_signatures(signatures) + if validation_report: + console.print(f"[yellow]⚠ 发现 {len(validation_report)} 个签名存在结构问题:[/yellow]") + for fname, errors in validation_report.items(): + for err in errors: + console.print(f" [yellow]· {fname}: {err}[/yellow]") + else: + console.print("[green]✓ 所有签名结构校验通过[/green]") + + # 写入初版 JSON(url 字段尚未填入) + json_path = write_function_signatures_json( + output_dir=output_dir, + signatures=signatures, + project_name=project.name, + project_description=project.description or "", # ← 传入项目描述 + file_name=json_file_name, + ) + console.print( + f"[green]✓ 签名 JSON 初版已写入: [cyan]{os.path.abspath(json_path)}[/cyan][/green]\n" + f" 成功: {success_count} 降级: {fail_count}" + ) + return signatures, json_path + + +# ══════════════════════════════════════════════════════ +# Step 5B:生成代码文件,收集 {函数名: 文件路径} 映射 +# ══════════════════════════════════════════════════════ + +def step_generate_code( + project: Project, + func_reqs: list, + output_dir: str, + knowledge_text: str, + signatures: list, + non_interactive: bool = False, +) -> Dict[str, str]: + """ + 依据签名约束批量生成代码文件。 + + Returns: + func_name_to_url: {函数名: 代码文件绝对路径} 映射表, + 供 Step 5C 回写 url 字段使用。 + 生成失败的函数不会出现在映射表中。 + """ + console.print( + f"\n[bold]Step 5B · 生成代码文件[/bold]" + + (" [dim](非交互)[/dim]" if non_interactive else ""), + style="blue", + ) + + generator = CodeGenerator(LLMClient()) + success_count = 0 + fail_count = 0 + func_name_to_url: Dict[str, str] = {} # ← 收集 函数名 → 文件绝对路径 + + def on_progress(index, total, req, code_file, error): + nonlocal success_count, fail_count + if error: + console.print(f" [{index}/{total}] [red]✗ {req.title}: {error}[/red]") + fail_count += 1 + else: + db.upsert_code_file(code_file) + req.status = "generated" + db.update_functional_requirement(req) + # 收集 函数名 → 绝对文件路径(作为 url 回写) + func_name_to_url[req.function_name] = os.path.abspath(code_file.file_path) + console.print( + f" [{index}/{total}] [green]✓ {req.title}[/green] " + f"→ [dim]{code_file.file_name}[/dim]" + ) + success_count += 1 + + console.print(f"[yellow]开始生成 {len(func_reqs)} 个代码文件(签名约束模式)...[/yellow]") + generator.generate_batch( + func_reqs=func_reqs, + output_dir=output_dir, + language=project.language, + knowledge=knowledge_text, + signatures=signatures, + on_progress=on_progress, + ) + + req_summary = "\n".join( + f"{i+1}. **{r.title}** (`{r.function_name}`) - {r.description[:80]}" + for i, r in enumerate(func_reqs) + ) + write_project_readme(output_dir, project.name, req_summary) + + console.print(Panel( + f"[bold green]✅ 代码生成完成![/bold green]\n" + f"成功: {success_count} 失败: {fail_count}\n" + f"输出目录: [cyan]{os.path.abspath(output_dir)}[/cyan]", + border_style="green", + )) + return func_name_to_url + + +# ══════════════════════════════════════════════════════ +# Step 5C:回写 url 字段并刷新 JSON +# ══════════════════════════════════════════════════════ + +def step_patch_signatures_url( + project: Project, + signatures: list, + func_name_to_url: Dict[str, str], + output_dir: str, + json_file_name: str, + non_interactive: bool = False, +) -> str: + """ + 将代码文件路径回写到签名的 "url" 字段,并重新写入 JSON 文件。 + + 执行流程: + 1. 调用 patch_signatures_with_url() 原地修改签名列表 + 2. 打印最终签名预览(含 url 列) + 3. 重新调用 write_function_signatures_json() 覆盖写入 JSON + + Args: + project: 项目对象(提供 name 与 description) + signatures: Step 5A 产出的签名列表(将被原地修改) + func_name_to_url: Step 5B 收集的 {函数名: 文件绝对路径} 映射 + output_dir: JSON 文件所在目录 + json_file_name: JSON 文件名 + non_interactive: 是否非交互模式 + + Returns: + 刷新后的 JSON 文件绝对路径 + """ + console.print( + f"\n[bold]Step 5C · 回写代码文件路径(url)到签名 JSON[/bold]" + + (" [dim](非交互)[/dim]" if non_interactive else ""), + style="blue", + ) + + # 原地回写 url 字段 + patch_signatures_with_url(signatures, func_name_to_url) + + patched = sum(1 for s in signatures if s.get("url")) + unpatched = len(signatures) - patched + if unpatched: + console.print( + f"[yellow]⚠ {unpatched} 个函数未能写入 url" + f"(对应代码文件生成失败)[/yellow]" + ) + + # 打印最终预览(含 url 列) + print_signatures_preview(signatures) + + # 覆盖写入 JSON(含 project.description) + json_path = write_function_signatures_json( + output_dir=output_dir, + signatures=signatures, + project_name=project.name, + project_description=project.description or "", + file_name=json_file_name, + ) + + console.print( + f"[green]✓ 签名 JSON 已更新(含 url): " + f"[cyan]{os.path.abspath(json_path)}[/cyan][/green]\n" + f" 已回写: {patched} 未回写: {unpatched}" + ) + return os.path.abspath(json_path) + + +# ══════════════════════════════════════════════════════ +# 核心工作流 +# ══════════════════════════════════════════════════════ + +def run_workflow( + project_name: str = None, + language: str = None, + description: str = "", + requirement_text: str = None, + requirement_file: str = None, + knowledge_files: tuple = (), + skip_indices: list = None, + json_file_name: str = "function_signatures.json", + non_interactive: bool = False, +): + """完整工作流(Step 1 → 5C)""" + print_banner() + + # Step 1 + project = step_init_project( + project_name=project_name, + language=language, + description=description, + non_interactive=non_interactive, + ) + + # Step 2 + raw_text, knowledge_text, source_name, source_type = step_input_requirement( + project=project, + requirement_text=requirement_text, + requirement_file=requirement_file, + knowledge_files=list(knowledge_files) if knowledge_files else [], + non_interactive=non_interactive, + ) + + # Step 3 + raw_req_id, func_reqs = step_decompose_requirements( + project=project, + raw_text=raw_text, + knowledge_text=knowledge_text, + source_name=source_name, + source_type=source_type, + non_interactive=non_interactive, + ) + + # Step 4 + func_reqs = step_edit_requirements( + project=project, + func_reqs=func_reqs, + raw_req_id=raw_req_id, + non_interactive=non_interactive, + skip_indices=skip_indices or [], + ) + + if not func_reqs: + console.print("[red]⚠ 功能需求列表为空,流程终止[/red]") + return + + output_dir = ensure_project_dir(project.name) + + # Step 5A:生成签名(初版,不含 url) + signatures, json_path = step_generate_signatures( + project=project, + func_reqs=func_reqs, + output_dir=output_dir, + knowledge_text=knowledge_text, + json_file_name=json_file_name, + non_interactive=non_interactive, + ) + + # Step 5B:生成代码,收集 {函数名: 文件路径} + func_name_to_url = step_generate_code( + project=project, + func_reqs=func_reqs, + output_dir=output_dir, + knowledge_text=knowledge_text, + signatures=signatures, + non_interactive=non_interactive, + ) + + # Step 5C:回写 url 字段,刷新 JSON + json_path = step_patch_signatures_url( + project=project, + signatures=signatures, + func_name_to_url=func_name_to_url, + output_dir=output_dir, + json_file_name=json_file_name, + non_interactive=non_interactive, + ) + + console.print(Panel( + f"[bold cyan]🎉 全部流程完成![/bold cyan]\n" + f"项目: [bold]{project.name}[/bold]\n" + f"描述: {project.description or '(无)'}\n" + f"代码目录: [cyan]{os.path.abspath(output_dir)}[/cyan]\n" + f"签名文件: [cyan]{json_path}[/cyan]", + border_style="cyan", + )) + + +# ══════════════════════════════════════════════════════ +# CLI 入口(click) +# ══════════════════════════════════════════════════════ + +@click.command() +@click.option("--non-interactive", is_flag=True, default=False, + help="以非交互模式运行(所有参数通过命令行传入)") +@click.option("--project-name", "-p", default=None, help="项目名称") +@click.option("--language", "-l", default=None, + type=click.Choice(["python","javascript","typescript","java","go","rust"]), + help=f"目标代码语言(默认: {config.DEFAULT_LANGUAGE})") +@click.option("--description", "-d", default="", help="项目描述") +@click.option("--requirement-text","-r", default=None, + help="原始需求文本(与 --requirement-file 二选一)") +@click.option("--requirement-file","-f", default=None, + type=click.Path(exists=True), + help="原始需求文件路径(支持 .txt/.md/.pdf/.docx)") +@click.option("--knowledge-file", "-k", default=None, multiple=True, + type=click.Path(exists=True), + help="知识库文件路径(可多次指定,如 -k a.md -k b.pdf)") +@click.option("--skip-index", "-s", default=None, multiple=True, type=int, + help="要跳过的功能需求序号(可多次指定,如 -s 2 -s 5)") +@click.option("--json-file-name", "-j", default="function_signatures.json", + help="函数签名 JSON 文件名(默认: function_signatures.json)") +def cli( + non_interactive, project_name, language, description, + requirement_text, requirement_file, knowledge_file, + skip_index, json_file_name, +): + """ + 需求分析 & 代码生成工具 + + \b + 交互式运行(推荐初次使用): + python main.py + + \b + 非交互式运行示例: + python main.py --non-interactive \\ + --project-name "UserSystem" \\ + --description "用户管理系统后端服务" \\ + --language python \\ + --requirement-text "用户管理系统,包含注册、登录、修改密码功能" \\ + --knowledge-file docs/api_spec.md \\ + --json-file-name api_signatures.json + + \b + 从文件读取需求 + 跳过部分功能需求: + python main.py --non-interactive \\ + --project-name "MyProject" \\ + --requirement-file requirements.md \\ + --skip-index 3 --skip-index 7 + """ + try: + run_workflow( + project_name=project_name, + language=language, + description=description, + requirement_text=requirement_text, + requirement_file=requirement_file, + knowledge_files=knowledge_file, + skip_indices=list(skip_index) if skip_index else [], + json_file_name=json_file_name, + non_interactive=non_interactive, + ) + except KeyboardInterrupt: + console.print("\n[yellow]用户中断,退出[/yellow]") + sys.exit(0) + except Exception as e: + console.print(f"\n[bold red]❌ 错误: {e}[/bold red]") + import traceback + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + cli() \ No newline at end of file diff --git a/requirements_generator/prompt.txt b/requirements_generator/prompt.txt new file mode 100644 index 0000000..f18d056 --- /dev/null +++ b/requirements_generator/prompt.txt @@ -0,0 +1,51 @@ +你是一个产品经理与高级软件工程师,非常善于将用户原始需求进行理解、分析并生成条目化的功能需求,再根据每个功能需求生成指定的功能函数,现在生成python工程实现以上功能,设计如下: + +# 工作流程设计 +1. 用户输入:知识库,原始需求描述、项目名称。其中原始需求支持文本与文件格式,知识库可选,通过一个或多个文件输入; +2. 软件通过llm模型,结合输入的知识库,将原始需求分解成多个可实现的功能需求,并显示给用户; +3. 软件支持用户删除其中不需要的功能需求,同时支持添加自定义功能需求; +4. 软件根据配置后的功能需求,生成指定代码(默认使用python)的功能函数文件输出到指定项目名称的目录中,一个功能需求对应一个函数、一个代码文件; + +# 数据库设计 +1. 使用sqlite数据库存储相关信息,包含项目表、原始需求表,功能需求表、代码文件表; +2. 数据库支持通过ID实现项目、原始需求、功能需求、代码文件之间的关联关系; + +更新以上代码,支持以json格式将所有的功能函数描述输出到json文件,格式如下: +[ + { + "name": "change_password", + "requirement_id": "the id related to the requirement. e.g.REQ.01", + "description": "replace old password with new password", + "type": "function", + "parameters": { + "user_id": { "type": "integer", "inout": "in", "description": "the user's id", "required": true }, + "old_password": { "type": "string", "inout": "in", "description": "the old password", "required": true }, + "new_password": { "type": "string", "inout": "in", "description": "the user's new password", "required": true } + }, + "return": { + "type": "integer", + "description": "0 is successful, nonzero is failure" + } + }, + { + "name": "set_and_get_system_status", + "requirement_id": "the id related to the requirement. e.g. REQ.02", + "description": "set and get the system status", + "type": "function", + "parameters": { + "new_status": { + "type": "integer", + "intou": "in", + "required": "true", + "description": "the new system status" + }, + "current_status": { + "type": "integer", + "inout": "out", + "required": "true", + "description": "get current system status" + } + } + }, + ... +] \ No newline at end of file diff --git a/requirements_generator/requirements.txt b/requirements_generator/requirements.txt new file mode 100644 index 0000000..e265d13 --- /dev/null +++ b/requirements_generator/requirements.txt @@ -0,0 +1,6 @@ +openai>=1.0.0 +python-dotenv>=1.0.0 +rich>=13.0.0 +python-docx>=0.8.11 +PyPDF2>=3.0.0 +click>=8.1.0 \ No newline at end of file diff --git a/requirements_generator/run.sh b/requirements_generator/run.sh new file mode 100755 index 0000000..90f28b0 --- /dev/null +++ b/requirements_generator/run.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +# 安装依赖 +pip install -r requirements.txt + +# 设置环境变量 +export OPENAI_API_KEY="sk-AUmOuFI731Ty5Nob38jY26d8lydfDT-QkE2giqb0sCuPCAE2JH6zjLM4lZLpvL5WMYPOocaMe2FwVDmqM_9KimmKACjR" +export OPENAI_BASE_URL="https://openapi.monica.im/v1" # 或其他兼容接口 +export LLM_MODEL="gpt-4o" + +python main.py diff --git a/requirements_generator/utils/__init__.py b/requirements_generator/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/requirements_generator/utils/file_handler.py b/requirements_generator/utils/file_handler.py new file mode 100644 index 0000000..f32738d --- /dev/null +++ b/requirements_generator/utils/file_handler.py @@ -0,0 +1,142 @@ +# utils/file_handler.py - 文件读取工具(支持 txt/md/pdf/docx) +import os +from typing import List, Optional +from pathlib import Path + + +def read_text_file(file_path: str) -> str: + """ + 读取纯文本文件内容(.txt / .md / .py 等) + + Args: + file_path: 文件路径 + + Returns: + 文件文本内容 + + Raises: + FileNotFoundError: 文件不存在 + UnicodeDecodeError: 编码错误时尝试 latin-1 兜底 + """ + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"文件不存在: {file_path}") + try: + return path.read_text(encoding="utf-8") + except UnicodeDecodeError: + return path.read_text(encoding="latin-1") + + +def read_pdf_file(file_path: str) -> str: + """ + 读取 PDF 文件内容 + + Args: + file_path: PDF 文件路径 + + Returns: + 提取的文本内容 + + Raises: + ImportError: 未安装 PyPDF2 + FileNotFoundError: 文件不存在 + """ + try: + import PyPDF2 + except ImportError: + raise ImportError("请安装 PyPDF2: pip install PyPDF2") + + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"文件不存在: {file_path}") + + texts = [] + with open(file_path, "rb") as f: + reader = PyPDF2.PdfReader(f) + for page in reader.pages: + text = page.extract_text() + if text: + texts.append(text) + return "\n".join(texts) + + +def read_docx_file(file_path: str) -> str: + """ + 读取 Word (.docx) 文件内容 + + Args: + file_path: docx 文件路径 + + Returns: + 提取的文本内容(段落合并) + + Raises: + ImportError: 未安装 python-docx + FileNotFoundError: 文件不存在 + """ + try: + from docx import Document + except ImportError: + raise ImportError("请安装 python-docx: pip install python-docx") + + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"文件不存在: {file_path}") + + doc = Document(file_path) + return "\n".join(para.text for para in doc.paragraphs if para.text.strip()) + + +def read_file_auto(file_path: str) -> str: + """ + 根据文件扩展名自动选择读取方式 + + Args: + file_path: 文件路径 + + Returns: + 文件文本内容 + + Raises: + ValueError: 不支持的文件类型 + """ + ext = Path(file_path).suffix.lower() + readers = { + ".txt": read_text_file, + ".md": read_text_file, + ".py": read_text_file, + ".json": read_text_file, + ".yaml": read_text_file, + ".yml": read_text_file, + ".pdf": read_pdf_file, + ".docx": read_docx_file, + } + reader = readers.get(ext) + if reader is None: + raise ValueError(f"不支持的文件类型: {ext},支持: {list(readers.keys())}") + return reader(file_path) + + +def merge_knowledge_files(file_paths: List[str]) -> str: + """ + 合并多个知识库文件为单一文本 + + Args: + file_paths: 知识库文件路径列表 + + Returns: + 合并后的知识库文本(包含文件名分隔符) + """ + if not file_paths: + return "" + + sections = [] + for fp in file_paths: + try: + content = read_file_auto(fp) + file_name = Path(fp).name + sections.append(f"### 知识库文件: {file_name}\n{content}") + except Exception as e: + sections.append(f"### 知识库文件: {fp}\n[读取失败: {e}]") + + return "\n\n".join(sections) \ No newline at end of file diff --git a/requirements_generator/utils/output_writer.py b/requirements_generator/utils/output_writer.py new file mode 100644 index 0000000..6fb9e58 --- /dev/null +++ b/requirements_generator/utils/output_writer.py @@ -0,0 +1,430 @@ +# utils/output_writer.py - 代码文件 & JSON 输出工具 +import os +import json +from pathlib import Path +from typing import Dict, List + +import config + + +# 各语言文件扩展名映射 +LANGUAGE_EXT_MAP: Dict[str, str] = { + "python": ".py", + "javascript": ".js", + "typescript": ".ts", + "java": ".java", + "go": ".go", + "rust": ".rs", + "cpp": ".cpp", + "c": ".c", + "csharp": ".cs", + "ruby": ".rb", + "php": ".php", + "swift": ".swift", + "kotlin": ".kt", +} + +# 合法的通用类型集合 +VALID_TYPES = { + "integer", "string", "boolean", "float", + "list", "dict", "object", "void", "any", +} + +# 合法的 inout 值 +VALID_INOUT = {"in", "out", "inout"} + + +def get_file_extension(language: str) -> str: + """ + 获取指定语言的文件扩展名 + + Args: + language: 编程语言名称(小写) + + Returns: + 文件扩展名(含点号,如 '.py') + """ + return LANGUAGE_EXT_MAP.get(language.lower(), ".txt") + + +def build_project_output_dir(project_name: str) -> str: + """ + 构建项目输出目录路径 + + Args: + project_name: 项目名称 + + Returns: + 输出目录路径 + """ + safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_name) + return os.path.join(config.OUTPUT_BASE_DIR, safe_name) + + +def ensure_project_dir(project_name: str) -> str: + """ + 确保项目输出目录存在,不存在则创建 + + Args: + project_name: 项目名称 + + Returns: + 创建好的目录路径 + """ + output_dir = build_project_output_dir(project_name) + os.makedirs(output_dir, exist_ok=True) + init_file = os.path.join(output_dir, "__init__.py") + if not os.path.exists(init_file): + Path(init_file).write_text( + "# Auto-generated project package\n", encoding="utf-8" + ) + return output_dir + + +def write_code_file( + output_dir: str, + function_name: str, + language: str, + content: str, +) -> str: + """ + 将代码内容写入指定目录的文件 + + Args: + output_dir: 输出目录路径 + function_name: 函数名(用于生成文件名) + language: 编程语言 + content: 代码内容 + + Returns: + 写入的文件完整路径 + """ + ext = get_file_extension(language) + file_name = f"{function_name}{ext}" + file_path = os.path.join(output_dir, file_name) + Path(file_path).write_text(content, encoding="utf-8") + return file_path + + +def write_project_readme( + output_dir: str, + project_name: str, + requirements_summary: str, +) -> str: + """ + 在项目目录生成 README.md 文件 + + Args: + output_dir: 项目输出目录 + project_name: 项目名称 + requirements_summary: 功能需求摘要文本 + + Returns: + README.md 文件路径 + """ + readme_content = f"""# {project_name} + +> Auto-generated by Requirement Analyzer + +## 功能需求列表 + +{requirements_summary} +""" + readme_path = os.path.join(output_dir, "README.md") + Path(readme_path).write_text(readme_content, encoding="utf-8") + return readme_path + + +# ══════════════════════════════════════════════════════ +# 函数签名 JSON 导出 +# ══════════════════════════════════════════════════════ + +def build_signatures_document( + project_name: str, + project_description: str, + signatures: List[dict], +) -> dict: + """ + 将函数签名列表包装为带项目信息的顶层文档结构。 + + Args: + project_name: 项目名称,写入 "project" 字段 + project_description: 项目描述,写入 "description" 字段 + signatures: 函数签名 dict 列表,写入 "functions" 字段 + + Returns: + 顶层文档 dict,结构为:: + + { + "project": "", + "description": "", + "functions": [ ... ] + } + """ + return { + "project": project_name, + "description": project_description or "", + "functions": signatures, + } + + +def patch_signatures_with_url( + signatures: List[dict], + func_name_to_url: Dict[str, str], +) -> List[dict]: + """ + 将代码文件的路径(URL)回写到对应函数签名的 "url" 字段。 + + 遍历签名列表,根据 signature["name"] 在 func_name_to_url 中查找 + 对应路径,找到则写入 "url" 字段;未找到则写入空字符串,不抛出异常。 + + "url" 字段插入位置紧跟在 "type" 字段之后,以保持字段顺序的可读性:: + + { + "name": "create_user", + "requirement_id": "REQ.01", + "description": "...", + "type": "function", + "url": "/abs/path/to/create_user.py", ← 新增 + "parameters": { ... }, + "return": { ... } + } + + Args: + signatures: 原始签名列表(in-place 修改) + func_name_to_url: {函数名: 代码文件绝对路径} 映射表, + 由 CodeGenerator.generate_batch() 的进度回调收集 + + Returns: + 修改后的签名列表(与传入的同一对象,方便链式调用) + """ + for sig in signatures: + func_name = sig.get("name", "") + url = func_name_to_url.get(func_name, "") + _insert_field_after(sig, after_key="type", new_key="url", new_value=url) + return signatures + + +def _insert_field_after( + d: dict, + after_key: str, + new_key: str, + new_value, +) -> None: + """ + 在有序 dict 中将 new_key 插入到 after_key 之后。 + 若 after_key 不存在,则追加到末尾。 + 若 new_key 已存在,则直接更新其值(不改变位置)。 + + Args: + d: 目标 dict(Python 3.7+ 保证插入顺序) + after_key: 参考键名 + new_key: 要插入的键名 + new_value: 要插入的值 + """ + if new_key in d: + d[new_key] = new_value + return + + items = list(d.items()) + insert_pos = len(items) + for i, (k, _) in enumerate(items): + if k == after_key: + insert_pos = i + 1 + break + + items.insert(insert_pos, (new_key, new_value)) + d.clear() + d.update(items) + + +def write_function_signatures_json( + output_dir: str, + signatures: List[dict], + project_name: str, + project_description: str, + file_name: str = "function_signatures.json", +) -> str: + """ + 将函数签名列表连同项目信息一起导出为 JSON 文件。 + + 输出的 JSON 顶层结构为:: + + { + "project": "", + "description": "", + "functions": [ + { + "name": "...", + "requirement_id": "...", + "description": "...", + "type": "function", + "url": "/abs/path/to/xxx.py", + "parameters": { ... }, + "return": { ... } + }, + ... + ] + } + + Args: + output_dir: JSON 文件写入目录 + signatures: 函数签名 dict 列表(应已通过 + patch_signatures_with_url() 写入 "url" 字段) + project_name: 项目名称 + project_description: 项目描述 + file_name: 输出文件名,默认 function_signatures.json + + Returns: + 写入的 JSON 文件完整路径 + + Raises: + OSError: 目录不可写 + """ + os.makedirs(output_dir, exist_ok=True) + document = build_signatures_document(project_name, project_description, signatures) + file_path = os.path.join(output_dir, file_name) + with open(file_path, "w", encoding="utf-8") as f: + json.dump(document, f, ensure_ascii=False, indent=2) + return file_path + + +# ══════════════════════════════════════════════════════ +# 签名结构校验 +# ══════════════════════════════════════════════════════ + +def validate_signature_schema(signature: dict) -> List[str]: + """ + 校验单个函数签名 dict 是否符合规范。 + + 校验范围: + - 顶层必填字段:name / requirement_id / description / type / parameters + - 可选字段 "url":若存在则必须为非空字符串 + - parameters:每个参数的 type / inout / required 字段 + - return:type 字段 + on_success / on_failure 子结构 + - void 函数:on_success / on_failure 应为 null + - 非 void 函数:on_success / on_failure 必须存在, + 且 value(非空)与 description(非空)均需填写 + + Args: + signature: 单个函数签名 dict + + Returns: + 错误信息字符串列表,列表为空表示校验通过 + """ + errors: List[str] = [] + + # ── 顶层必填字段 ────────────────────────────────── + for key in ("name", "requirement_id", "description", "type", "parameters"): + if key not in signature: + errors.append(f"缺少顶层字段: '{key}'") + + # ── url 字段(可选,存在时校验非空)───────────────── + if "url" in signature: + if not isinstance(signature["url"], str): + errors.append("'url' 字段必须是字符串类型") + elif signature["url"] == "": + errors.append("'url' 字段不能为空字符串(代码文件路径未成功回写)") + + # ── parameters ──────────────────────────────────── + params = signature.get("parameters", {}) + if not isinstance(params, dict): + errors.append("'parameters' 必须是 dict 类型") + else: + for pname, pdef in params.items(): + if not isinstance(pdef, dict): + errors.append(f"参数 '{pname}' 定义必须是 dict") + continue + # type(支持联合类型,如 "string|integer") + if "type" not in pdef: + errors.append(f"参数 '{pname}' 缺少 'type' 字段") + else: + parts = [p.strip() for p in pdef["type"].split("|")] + if not all(p in VALID_TYPES for p in parts): + errors.append( + f"参数 '{pname}' 的 type='{pdef['type']}' 含有不合法的类型" + ) + # inout + if "inout" not in pdef: + errors.append(f"参数 '{pname}' 缺少 'inout' 字段") + elif pdef["inout"] not in VALID_INOUT: + errors.append( + f"参数 '{pname}' 的 inout='{pdef['inout']}' 应为 in/out/inout" + ) + # required + if "required" not in pdef: + errors.append(f"参数 '{pname}' 缺少 'required' 字段") + elif not isinstance(pdef["required"], bool): + errors.append( + f"参数 '{pname}' 的 'required' 应为布尔值 true/false," + f"当前为: {pdef['required']!r}" + ) + + # ── return ──────────────────────────────────────── + ret = signature.get("return") + if ret is None: + errors.append( + "缺少 'return' 字段(void 函数请填 " + "{\"type\": \"void\", \"on_success\": null, \"on_failure\": null})" + ) + elif not isinstance(ret, dict): + errors.append("'return' 必须是 dict 类型") + else: + ret_type = ret.get("type") + if not ret_type: + errors.append("'return' 缺少 'type' 字段") + elif ret_type not in VALID_TYPES: + errors.append(f"'return.type'='{ret_type}' 不在合法类型列表中") + + is_void = (ret_type == "void") + + for sub_key in ("on_success", "on_failure"): + sub = ret.get(sub_key) + if is_void: + if sub is not None: + errors.append( + f"void 函数的 'return.{sub_key}' 应为 null," + f"当前为: {sub!r}" + ) + else: + if sub is None: + errors.append( + f"非 void 函数缺少 'return.{sub_key}'," + f"请描述{'成功' if sub_key == 'on_success' else '失败'}时的返回值" + ) + elif not isinstance(sub, dict): + errors.append(f"'return.{sub_key}' 必须是 dict 类型") + else: + if "value" not in sub: + errors.append(f"'return.{sub_key}' 缺少 'value' 字段") + elif sub["value"] == "": + errors.append( + f"'return.{sub_key}.value' 不能为空字符串," + f"请填写具体返回值、值域描述或结构示例" + ) + if "description" not in sub or sub.get("description") in (None, ""): + errors.append(f"'return.{sub_key}.description' 不能为空") + + return errors + + +def validate_all_signatures(signatures: List[dict]) -> Dict[str, List[str]]: + """ + 批量校验函数签名列表。 + + 注意:此函数接受的是纯签名列表(即顶层文档的 "functions" 字段), + 而非包含 project/description 的顶层文档。 + + Args: + signatures: 函数签名 dict 列表 + + Returns: + {函数名: [错误信息, ...]} 字典,仅包含有错误的条目 + """ + report: Dict[str, List[str]] = {} + for sig in signatures: + name = sig.get("name", f"unknown_{id(sig)}") + errs = validate_signature_schema(sig) + if errs: + report[name] = errs + return report \ No newline at end of file