Default Changelist
This commit is contained in:
parent
b16661ac91
commit
9d42954d44
|
|
@ -0,0 +1,16 @@
|
|||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.transport_security import TransportSecuritySettings
|
||||
|
||||
mcp = FastMCP("requirements",
|
||||
transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False))
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def generate_requirement_document(project: str, requirements: list, standard: str = "GJB438C"):
|
||||
print(project, requirements, standard)
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.settings.host = "0.0.0.0"
|
||||
mcp.settings.port = 7999
|
||||
|
||||
mcp.run(transport="streamable-http")
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
# LLM配置(以OpenAI兼容接口为例,可替换为其他LLM)
|
||||
LLM_API_KEY: str = os.getenv("LLM_API_KEY", "your-api-key-here")
|
||||
LLM_BASE_URL: str = os.getenv("LLM_BASE_URL", "https://api.openai.com/v1")
|
||||
LLM_MODEL: str = os.getenv("LLM_MODEL", "gpt-4o")
|
||||
LLM_TEMPERATURE: float = 0.2
|
||||
LLM_MAX_TOKENS: int = 4096
|
||||
|
||||
# 测试配置
|
||||
GENERATED_TESTS_DIR: str = "generated_tests"
|
||||
TEST_TIMEOUT: int = 30 # 单个测试超时时间(秒)
|
||||
|
||||
# HTTP测试配置
|
||||
HTTP_BASE_URL: str = os.getenv("HTTP_BASE_URL", "http://localhost:8080")
|
||||
HTTP_TIMEOUT: int = 10
|
||||
|
||||
config = Config()
|
||||
|
|
@ -0,0 +1,609 @@
|
|||
"""
|
||||
覆盖率分析与缺口识别
|
||||
══════════════════════════════════════════════════════════════
|
||||
修复:
|
||||
- 移除对已删除属性 request_parameters / url_parameters / response 的引用
|
||||
- HTTP 与 function 接口统一使用 iface.in_parameters / iface.out_parameters
|
||||
- HTTP 接口的返回字段覆盖统一从 return_info.on_success / on_failure 读取
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from core.parser import InterfaceInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# 数据结构
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
|
||||
@dataclass
|
||||
class InterfaceCoverage:
|
||||
interface_name: str
|
||||
protocol: str
|
||||
requirement_id: str = ""
|
||||
is_covered: bool = False
|
||||
|
||||
# 入参覆盖(in / inout)
|
||||
all_in_params: set[str] = field(default_factory=set)
|
||||
covered_in_params: set[str] = field(default_factory=set)
|
||||
|
||||
# 出参断言覆盖(out / inout)
|
||||
all_out_params: set[str] = field(default_factory=set)
|
||||
asserted_out_params: set[str] = field(default_factory=set)
|
||||
|
||||
# on_success 返回字段断言覆盖
|
||||
all_success_fields: set[str] = field(default_factory=set)
|
||||
covered_success_fields: set[str] = field(default_factory=set)
|
||||
|
||||
# on_failure 返回字段断言覆盖
|
||||
all_failure_fields: set[str] = field(default_factory=set)
|
||||
covered_failure_fields: set[str] = field(default_factory=set)
|
||||
|
||||
# 异常断言(on_failure = raises Exception)
|
||||
expects_exception: bool = False
|
||||
exception_case_covered: bool = False
|
||||
|
||||
has_positive_case: bool = False
|
||||
has_negative_case: bool = False
|
||||
covering_test_ids: list[str] = field(default_factory=list)
|
||||
|
||||
# ── 覆盖率计算 ────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
def in_param_coverage_rate(self) -> float:
|
||||
if not self.all_in_params:
|
||||
return 1.0
|
||||
return len(self.covered_in_params) / len(self.all_in_params)
|
||||
|
||||
@property
|
||||
def out_param_coverage_rate(self) -> float:
|
||||
if not self.all_out_params:
|
||||
return 1.0
|
||||
return len(self.asserted_out_params) / len(self.all_out_params)
|
||||
|
||||
@property
|
||||
def success_field_coverage_rate(self) -> float:
|
||||
if not self.all_success_fields:
|
||||
return 1.0
|
||||
return len(self.covered_success_fields) / len(self.all_success_fields)
|
||||
|
||||
@property
|
||||
def failure_field_coverage_rate(self) -> float:
|
||||
if not self.all_failure_fields:
|
||||
return 1.0
|
||||
return len(self.covered_failure_fields) / len(self.all_failure_fields)
|
||||
|
||||
@property
|
||||
def uncovered_in_params(self) -> set[str]:
|
||||
return self.all_in_params - self.covered_in_params
|
||||
|
||||
@property
|
||||
def uncovered_out_params(self) -> set[str]:
|
||||
return self.all_out_params - self.asserted_out_params
|
||||
|
||||
@property
|
||||
def uncovered_success_fields(self) -> set[str]:
|
||||
return self.all_success_fields - self.covered_success_fields
|
||||
|
||||
@property
|
||||
def uncovered_failure_fields(self) -> set[str]:
|
||||
return self.all_failure_fields - self.covered_failure_fields
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequirementCoverage:
|
||||
requirement: str
|
||||
requirement_id: str = ""
|
||||
covering_test_ids: list[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def is_covered(self) -> bool:
|
||||
return bool(self.covering_test_ids)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Gap:
|
||||
gap_type: str # 见下方常量
|
||||
severity: str # critical | high | medium | low
|
||||
target: str
|
||||
detail: str
|
||||
suggestion: str
|
||||
|
||||
|
||||
# gap_type 常量
|
||||
GAP_INTERFACE_NOT_COVERED = "interface_not_covered"
|
||||
GAP_IN_PARAM_NOT_COVERED = "in_param_not_covered"
|
||||
GAP_OUT_PARAM_NOT_ASSERTED = "out_param_not_asserted"
|
||||
GAP_SUCCESS_FIELD_NOT_ASSERTED = "success_field_not_asserted"
|
||||
GAP_FAILURE_FIELD_NOT_ASSERTED = "failure_field_not_asserted"
|
||||
GAP_EXCEPTION_NOT_TESTED = "exception_case_not_tested"
|
||||
GAP_MISSING_POSITIVE = "missing_positive_case"
|
||||
GAP_MISSING_NEGATIVE = "missing_negative_case"
|
||||
GAP_REQUIREMENT_NOT_COVERED = "requirement_not_covered"
|
||||
GAP_TEST_FAILED = "test_failed"
|
||||
GAP_TEST_ERROR = "test_error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CoverageReport:
|
||||
total_interfaces: int = 0
|
||||
covered_interfaces: int = 0
|
||||
total_requirements: int = 0
|
||||
covered_requirements: int = 0
|
||||
total_test_cases: int = 0
|
||||
passed_test_cases: int = 0
|
||||
failed_test_cases: int = 0
|
||||
error_test_cases: int = 0
|
||||
|
||||
interface_coverages: list[InterfaceCoverage] = field(default_factory=list)
|
||||
requirement_coverages: list[RequirementCoverage] = field(default_factory=list)
|
||||
gaps: list[Gap] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def interface_coverage_rate(self) -> float:
|
||||
return self.covered_interfaces / self.total_interfaces \
|
||||
if self.total_interfaces else 0.0
|
||||
|
||||
@property
|
||||
def requirement_coverage_rate(self) -> float:
|
||||
return self.covered_requirements / self.total_requirements \
|
||||
if self.total_requirements else 0.0
|
||||
|
||||
@property
|
||||
def pass_rate(self) -> float:
|
||||
return self.passed_test_cases / self.total_test_cases \
|
||||
if self.total_test_cases else 0.0
|
||||
|
||||
@property
|
||||
def avg_in_param_coverage_rate(self) -> float:
|
||||
rates = [
|
||||
ic.in_param_coverage_rate
|
||||
for ic in self.interface_coverages
|
||||
if ic.all_in_params
|
||||
]
|
||||
return sum(rates) / len(rates) if rates else 1.0
|
||||
|
||||
@property
|
||||
def avg_success_field_coverage_rate(self) -> float:
|
||||
rates = [
|
||||
ic.success_field_coverage_rate
|
||||
for ic in self.interface_coverages
|
||||
if ic.all_success_fields
|
||||
]
|
||||
return sum(rates) / len(rates) if rates else 1.0
|
||||
|
||||
@property
|
||||
def avg_failure_field_coverage_rate(self) -> float:
|
||||
rates = [
|
||||
ic.failure_field_coverage_rate
|
||||
for ic in self.interface_coverages
|
||||
if ic.all_failure_fields
|
||||
]
|
||||
return sum(rates) / len(rates) if rates else 1.0
|
||||
|
||||
@property
|
||||
def critical_gap_count(self) -> int:
|
||||
return sum(1 for g in self.gaps if g.severity == "critical")
|
||||
|
||||
@property
|
||||
def high_gap_count(self) -> int:
|
||||
return sum(1 for g in self.gaps if g.severity == "high")
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# 分析器
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
|
||||
class CoverageAnalyzer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
interfaces: list[InterfaceInfo],
|
||||
requirements: list[str],
|
||||
test_cases: list[dict],
|
||||
run_results: list,
|
||||
):
|
||||
self.interfaces = interfaces
|
||||
self.requirements = requirements
|
||||
self.test_cases = test_cases
|
||||
self.run_results = run_results
|
||||
self._iface_map = {i.name: i for i in interfaces}
|
||||
self._result_map = {
|
||||
getattr(r, "test_id", ""): r for r in run_results
|
||||
}
|
||||
|
||||
# ── 入口 ──────────────────────────────────────────────────
|
||||
|
||||
def analyze(self) -> CoverageReport:
|
||||
logger.info("Starting coverage analysis …")
|
||||
report = CoverageReport()
|
||||
|
||||
iface_cov_map = self._init_interface_coverages()
|
||||
req_cov_map = self._init_requirement_coverages()
|
||||
|
||||
self._scan_test_cases(iface_cov_map, req_cov_map)
|
||||
self._scan_run_results(report)
|
||||
|
||||
report.interface_coverages = list(iface_cov_map.values())
|
||||
report.requirement_coverages = list(req_cov_map.values())
|
||||
report.total_interfaces = len(self.interfaces)
|
||||
report.covered_interfaces = sum(
|
||||
1 for ic in iface_cov_map.values() if ic.is_covered
|
||||
)
|
||||
report.total_requirements = len(self.requirements)
|
||||
report.covered_requirements = sum(
|
||||
1 for rc in req_cov_map.values() if rc.is_covered
|
||||
)
|
||||
report.total_test_cases = len(self.test_cases)
|
||||
report.gaps = self._identify_gaps(iface_cov_map, req_cov_map)
|
||||
|
||||
logger.info(
|
||||
f"Analysis done — "
|
||||
f"interface={report.interface_coverage_rate:.0%}, "
|
||||
f"requirement={report.requirement_coverage_rate:.0%}, "
|
||||
f"pass={report.pass_rate:.0%}, "
|
||||
f"gaps={len(report.gaps)}"
|
||||
)
|
||||
return report
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# 初始化接口覆盖对象
|
||||
# ══════════════════════════════════════════════════════════
|
||||
|
||||
def _init_interface_coverages(self) -> dict[str, InterfaceCoverage]:
|
||||
result: dict[str, InterfaceCoverage] = {}
|
||||
|
||||
for iface in self.interfaces:
|
||||
ic = InterfaceCoverage(
|
||||
interface_name=iface.name,
|
||||
protocol=iface.protocol,
|
||||
requirement_id=iface.requirement_id,
|
||||
)
|
||||
|
||||
# ── 入参(in / inout)────────────────────────────
|
||||
# HTTP 与 function 统一使用 in_parameters
|
||||
ic.all_in_params = {p.name for p in iface.in_parameters}
|
||||
|
||||
# ── 出参(out / inout)───────────────────────────
|
||||
ic.all_out_params = {p.name for p in iface.out_parameters}
|
||||
|
||||
# ── 返回值字段(on_success)──────────────────────
|
||||
ri = iface.return_info
|
||||
success_fields = set(ri.on_success.value_fields.keys())
|
||||
|
||||
# boolean / 非 dict 类型:用虚拟字段 "return" 代表返回值本身
|
||||
if ri.type in ("boolean", "integer", "string", "float") or (
|
||||
ri.on_success.value is not None
|
||||
and not isinstance(ri.on_success.value, dict)
|
||||
):
|
||||
success_fields.add("return")
|
||||
|
||||
ic.all_success_fields = success_fields
|
||||
|
||||
# ── 返回值字段(on_failure)──────────────────────
|
||||
if ri.on_failure.is_exception:
|
||||
ic.expects_exception = True
|
||||
else:
|
||||
failure_fields = set(ri.on_failure.value_fields.keys())
|
||||
if ri.type in ("boolean", "integer", "string", "float") or (
|
||||
ri.on_failure.value is not None
|
||||
and not isinstance(ri.on_failure.value, dict)
|
||||
and not ri.on_failure.is_exception
|
||||
):
|
||||
failure_fields.add("return")
|
||||
ic.all_failure_fields = failure_fields
|
||||
|
||||
result[iface.name] = ic
|
||||
|
||||
return result
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# 初始化需求覆盖对象
|
||||
# ══════════════════════════════════════════════════════════
|
||||
|
||||
def _init_requirement_coverages(self) -> dict[str, RequirementCoverage]:
|
||||
result: dict[str, RequirementCoverage] = {}
|
||||
for req in self.requirements:
|
||||
result[req] = RequirementCoverage(requirement=req)
|
||||
|
||||
# requirement_id → requirement text 的辅助映射
|
||||
self._req_id_map: dict[str, str] = {}
|
||||
for iface in self.interfaces:
|
||||
if iface.requirement_id:
|
||||
matched = self._fuzzy_match_req(iface.name, result)
|
||||
if matched:
|
||||
self._req_id_map[iface.requirement_id] = matched
|
||||
|
||||
return result
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# 扫描测试用例
|
||||
# ══════════════════════════════════════════════════════════
|
||||
|
||||
def _scan_test_cases(
|
||||
self,
|
||||
iface_cov_map: dict[str, InterfaceCoverage],
|
||||
req_cov_map: dict[str, RequirementCoverage],
|
||||
):
|
||||
for tc in self.test_cases:
|
||||
test_id = tc.get("test_id", "")
|
||||
requirement = tc.get("requirement", "")
|
||||
req_id = tc.get("requirement_id", "")
|
||||
description = (
|
||||
tc.get("description", "") + " " + tc.get("test_name", "")
|
||||
).lower()
|
||||
|
||||
is_negative = any(
|
||||
kw in description
|
||||
for kw in (
|
||||
"negative", "invalid", "missing", "fail", "error",
|
||||
"wrong", "bad", "boundary", "负向", "exception",
|
||||
)
|
||||
)
|
||||
|
||||
# 需求覆盖匹配
|
||||
matched_req = (
|
||||
self._match_by_req_id(req_id, req_cov_map)
|
||||
or self._match_by_text(requirement, req_cov_map)
|
||||
)
|
||||
if matched_req:
|
||||
req_cov_map[matched_req].covering_test_ids.append(test_id)
|
||||
|
||||
# 遍历步骤
|
||||
for step in tc.get("steps", []):
|
||||
iface_name = step.get("interface_name", "")
|
||||
ic = iface_cov_map.get(iface_name)
|
||||
if ic is None:
|
||||
continue
|
||||
|
||||
ic.is_covered = True
|
||||
if test_id not in ic.covering_test_ids:
|
||||
ic.covering_test_ids.append(test_id)
|
||||
|
||||
if is_negative:
|
||||
ic.has_negative_case = True
|
||||
else:
|
||||
ic.has_positive_case = True
|
||||
|
||||
# 入参覆盖
|
||||
for param_name in step.get("input", {}).keys():
|
||||
if param_name in ic.all_in_params:
|
||||
ic.covered_in_params.add(param_name)
|
||||
|
||||
# 断言覆盖
|
||||
for assertion in step.get("assertions", []):
|
||||
f = assertion.get("field", "")
|
||||
operator = assertion.get("operator", "")
|
||||
|
||||
# 异常断言
|
||||
if f == "exception" and operator == "raised":
|
||||
ic.exception_case_covered = True
|
||||
continue
|
||||
|
||||
# 出参断言
|
||||
if f in ic.all_out_params:
|
||||
ic.asserted_out_params.add(f)
|
||||
|
||||
# on_success 字段(正向用例)
|
||||
if not is_negative and f in ic.all_success_fields:
|
||||
ic.covered_success_fields.add(f)
|
||||
|
||||
# on_failure 字段(负向用例)
|
||||
if is_negative and f in ic.all_failure_fields:
|
||||
ic.covered_failure_fields.add(f)
|
||||
|
||||
# boolean / scalar return 断言
|
||||
if f == "return":
|
||||
if not is_negative:
|
||||
ic.covered_success_fields.add("return")
|
||||
else:
|
||||
ic.covered_failure_fields.add("return")
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# 扫描执行结果
|
||||
# ══════════════════════════════════════════════════════════
|
||||
|
||||
def _scan_run_results(self, report: CoverageReport):
|
||||
for r in self.run_results:
|
||||
s = getattr(r, "status", "")
|
||||
if s == "PASS":
|
||||
report.passed_test_cases += 1
|
||||
elif s == "FAIL":
|
||||
report.failed_test_cases += 1
|
||||
else:
|
||||
report.error_test_cases += 1
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# 缺口识别
|
||||
# ══════════════════════════════════════════════════════════
|
||||
|
||||
def _identify_gaps(
|
||||
self,
|
||||
iface_cov_map: dict[str, InterfaceCoverage],
|
||||
req_cov_map: dict[str, RequirementCoverage],
|
||||
) -> list[Gap]:
|
||||
gaps: list[Gap] = []
|
||||
|
||||
for ic in iface_cov_map.values():
|
||||
n = ic.interface_name
|
||||
rid = f"[{ic.requirement_id}] " if ic.requirement_id else ""
|
||||
|
||||
# ① 接口完全未覆盖
|
||||
if not ic.is_covered:
|
||||
gaps.append(Gap(
|
||||
gap_type=GAP_INTERFACE_NOT_COVERED,
|
||||
severity="critical", target=n,
|
||||
detail=f"{rid}'{n}' has NO test case.",
|
||||
suggestion=f"Add positive + negative test cases for '{n}'.",
|
||||
))
|
||||
continue # 后续缺口依赖覆盖,跳过
|
||||
|
||||
# ② 缺少正向用例
|
||||
if not ic.has_positive_case:
|
||||
gaps.append(Gap(
|
||||
gap_type=GAP_MISSING_POSITIVE,
|
||||
severity="critical", target=n,
|
||||
detail=f"{rid}'{n}' has no positive (happy-path) test.",
|
||||
suggestion=f"Add a positive test case with valid inputs for '{n}'.",
|
||||
))
|
||||
|
||||
# ③ 缺少负向用例
|
||||
if not ic.has_negative_case:
|
||||
gaps.append(Gap(
|
||||
gap_type=GAP_MISSING_NEGATIVE,
|
||||
severity="high", target=n,
|
||||
detail=f"{rid}'{n}' has no negative test case.",
|
||||
suggestion=(
|
||||
f"Add negative tests for '{n}': "
|
||||
f"missing required params, invalid types, boundary values."
|
||||
),
|
||||
))
|
||||
|
||||
# ④ 入参未覆盖
|
||||
for param in sorted(ic.uncovered_in_params):
|
||||
gaps.append(Gap(
|
||||
gap_type=GAP_IN_PARAM_NOT_COVERED,
|
||||
severity="high", target=n,
|
||||
detail=f"{rid}Input param '{param}' of '{n}' never used in tests.",
|
||||
suggestion=f"Add a test that explicitly passes '{param}' to '{n}'.",
|
||||
))
|
||||
|
||||
# ⑤ 出参未断言
|
||||
for param in sorted(ic.uncovered_out_params):
|
||||
gaps.append(Gap(
|
||||
gap_type=GAP_OUT_PARAM_NOT_ASSERTED,
|
||||
severity="high", target=n,
|
||||
detail=f"{rid}Output param '{param}' of '{n}' never asserted.",
|
||||
suggestion=(
|
||||
f"Add an assertion on output param '{param}' "
|
||||
f"after calling '{n}'."
|
||||
),
|
||||
))
|
||||
|
||||
# ⑥ on_success 返回字段未断言
|
||||
for f in sorted(ic.uncovered_success_fields):
|
||||
gaps.append(Gap(
|
||||
gap_type=GAP_SUCCESS_FIELD_NOT_ASSERTED,
|
||||
severity="high", target=n,
|
||||
detail=(
|
||||
f"{rid}on_success field '{f}' of '{n}' "
|
||||
f"never asserted in positive tests."
|
||||
),
|
||||
suggestion=(
|
||||
f"Add assertion on '{f}' in the positive "
|
||||
f"test step for '{n}'."
|
||||
),
|
||||
))
|
||||
|
||||
# ⑦ on_failure 返回字段未断言
|
||||
for f in sorted(ic.uncovered_failure_fields):
|
||||
gaps.append(Gap(
|
||||
gap_type=GAP_FAILURE_FIELD_NOT_ASSERTED,
|
||||
severity="medium", target=n,
|
||||
detail=(
|
||||
f"{rid}on_failure field '{f}' of '{n}' "
|
||||
f"never asserted in negative tests."
|
||||
),
|
||||
suggestion=(
|
||||
f"Add assertion on '{f}' in the negative "
|
||||
f"test step for '{n}'."
|
||||
),
|
||||
))
|
||||
|
||||
# ⑧ 异常场景未测试
|
||||
if ic.expects_exception and not ic.exception_case_covered:
|
||||
gaps.append(Gap(
|
||||
gap_type=GAP_EXCEPTION_NOT_TESTED,
|
||||
severity="high", target=n,
|
||||
detail=(
|
||||
f"{rid}'{n}' declares on_failure = raises Exception, "
|
||||
f"but no test asserts that the exception is raised."
|
||||
),
|
||||
suggestion=(
|
||||
f"Add a negative test for '{n}' that wraps the call "
|
||||
f"in try/except and asserts the exception is raised."
|
||||
),
|
||||
))
|
||||
|
||||
# ⑨ 需求未覆盖
|
||||
for rc in req_cov_map.values():
|
||||
if not rc.is_covered:
|
||||
gaps.append(Gap(
|
||||
gap_type=GAP_REQUIREMENT_NOT_COVERED,
|
||||
severity="critical",
|
||||
target=rc.requirement,
|
||||
detail=f"Requirement '{rc.requirement}' has no test case.",
|
||||
suggestion=f"Generate test cases for: '{rc.requirement}'.",
|
||||
))
|
||||
|
||||
# ⑩ 执行失败 / 错误
|
||||
for r in self.run_results:
|
||||
s = getattr(r, "status", "")
|
||||
if s == "FAIL":
|
||||
gaps.append(Gap(
|
||||
gap_type=GAP_TEST_FAILED,
|
||||
severity="high",
|
||||
target=getattr(r, "test_id", ""),
|
||||
detail=f"Test '{r.test_id}' FAILED: {r.message}",
|
||||
suggestion=(
|
||||
"Investigate failure; fix implementation or test data."
|
||||
),
|
||||
))
|
||||
elif s in ("ERROR", "TIMEOUT"):
|
||||
gaps.append(Gap(
|
||||
gap_type=GAP_TEST_ERROR,
|
||||
severity="medium",
|
||||
target=getattr(r, "test_id", ""),
|
||||
detail=f"Test '{r.test_id}' {s}: {r.message}",
|
||||
suggestion=(
|
||||
"Check script for runtime errors or increase TEST_TIMEOUT."
|
||||
),
|
||||
))
|
||||
|
||||
# 按严重程度排序
|
||||
_order = {"critical": 0, "high": 1, "medium": 2, "low": 3}
|
||||
gaps.sort(key=lambda g: _order.get(g.severity, 9))
|
||||
return gaps
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# 工具方法
|
||||
# ══════════════════════════════════════════════════════════
|
||||
|
||||
def _match_by_req_id(
|
||||
self,
|
||||
req_id: str,
|
||||
req_cov_map: dict[str, RequirementCoverage],
|
||||
) -> str | None:
|
||||
if not req_id:
|
||||
return None
|
||||
mapped = self._req_id_map.get(req_id)
|
||||
if mapped and mapped in req_cov_map:
|
||||
return mapped
|
||||
return None
|
||||
|
||||
def _match_by_text(
|
||||
self,
|
||||
requirement: str,
|
||||
req_cov_map: dict[str, RequirementCoverage],
|
||||
) -> str | None:
|
||||
if requirement in req_cov_map:
|
||||
return requirement
|
||||
req_lower = requirement.lower()
|
||||
for key in req_cov_map:
|
||||
if key.lower() in req_lower or req_lower in key.lower():
|
||||
return key
|
||||
return None
|
||||
|
||||
def _fuzzy_match_req(
|
||||
self,
|
||||
iface_name: str,
|
||||
req_cov_map: dict[str, RequirementCoverage],
|
||||
) -> str | None:
|
||||
name_lower = iface_name.lower().replace("_", " ")
|
||||
for key in req_cov_map:
|
||||
if name_lower in key.lower() or key.lower() in name_lower:
|
||||
return key
|
||||
return None
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
import json
|
||||
import re
|
||||
import logging
|
||||
from openai import OpenAI
|
||||
from config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""封装 LLM API 调用(OpenAI 兼容接口)"""
|
||||
|
||||
def __init__(self):
|
||||
self.client = OpenAI(
|
||||
api_key=config.LLM_API_KEY,
|
||||
base_url=config.LLM_BASE_URL,
|
||||
)
|
||||
|
||||
def generate_test_cases(self, system_prompt: str, user_prompt: str) -> list[dict]:
|
||||
logger.info(f"Calling LLM: model={config.LLM_MODEL}")
|
||||
logger.debug(f"--- USER PROMPT ---\n{user_prompt}\n---")
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=config.LLM_MODEL,
|
||||
temperature=config.LLM_TEMPERATURE,
|
||||
max_tokens=config.LLM_MAX_TOKENS,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
)
|
||||
raw = response.choices[0].message.content
|
||||
logger.debug(f"--- LLM RAW RESPONSE ---\n{raw[:800]}\n---")
|
||||
return self._parse_json(raw)
|
||||
|
||||
# ── 解析 ──────────────────────────────────────────────────
|
||||
|
||||
def _parse_json(self, content: str) -> list[dict]:
|
||||
content = content.strip()
|
||||
# 去除可能的 markdown 代码块
|
||||
content = re.sub(r"^```(?:json)?\s*", "", content, flags=re.MULTILINE)
|
||||
content = re.sub(r"\s*```\s*$", "", content, flags=re.MULTILINE)
|
||||
content = content.strip()
|
||||
|
||||
try:
|
||||
data = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 尝试提取第一个 JSON 数组
|
||||
match = re.search(r"\[.*\]", content, re.DOTALL)
|
||||
if match:
|
||||
data = json.loads(match.group())
|
||||
else:
|
||||
logger.error(f"Cannot parse LLM response as JSON:\n{content[:600]}")
|
||||
raise
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise ValueError(f"Expected JSON array, got {type(data)}")
|
||||
return data
|
||||
|
|
@ -0,0 +1,388 @@
|
|||
"""
|
||||
接口描述解析器
|
||||
══════════════════════════════════════════════════════════════
|
||||
支持格式(当前版):
|
||||
顶层字段:
|
||||
- project : 项目名称,用于命名输出目录
|
||||
- description : 项目描述
|
||||
- units : 接口/函数描述列表
|
||||
|
||||
接口公共字段:
|
||||
- name : 接口/函数名称
|
||||
- requirement_id : 需求编号
|
||||
- description : 接口描述
|
||||
- type : "function" | "http"
|
||||
- url : function → Python 源文件路径,如 "create_user.py"
|
||||
http → base URL,如 "http://127.0.0.1/api"
|
||||
- parameters : 统一参数字段(HTTP 与 function 两种协议共用)
|
||||
⚠️ HTTP 接口不再使用 url_parameters / request_parameters
|
||||
- return : 返回值描述(含 on_success / on_failure)
|
||||
|
||||
HTTP 专用字段:
|
||||
- method : "get" | "post" | "put" | "delete" 等
|
||||
- 完整请求 URL = url.rstrip("/") + "/" + name.lstrip("/")
|
||||
|
||||
Function 专用:
|
||||
- url 为 .py 文件路径,转换为 Python 模块路径供 import 使用
|
||||
- 函数名即 name 字段
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# 基础数据结构
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
|
||||
@dataclass
|
||||
class ParameterInfo:
|
||||
name: str
|
||||
type: str # 支持复合类型 "string|integer"
|
||||
description: str
|
||||
required: bool = True
|
||||
default: Any = None
|
||||
inout: str = "in" # "in" | "out" | "inout"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReturnCase:
|
||||
"""on_success 或 on_failure 的返回描述"""
|
||||
value: Any = None
|
||||
description: str = ""
|
||||
|
||||
@property
|
||||
def is_exception(self) -> bool:
|
||||
"""判断是否为抛出异常的场景"""
|
||||
return (
|
||||
isinstance(self.value, str)
|
||||
and "exception" in self.value.lower()
|
||||
)
|
||||
|
||||
@property
|
||||
def value_fields(self) -> dict[str, Any]:
|
||||
"""若 value 是 dict,返回其字段;否则返回空 dict"""
|
||||
return self.value if isinstance(self.value, dict) else {}
|
||||
|
||||
@property
|
||||
def value_summary(self) -> str:
|
||||
"""返回值的简短文字描述,用于 LLM 提示词"""
|
||||
if isinstance(self.value, dict):
|
||||
return json.dumps(self.value, ensure_ascii=False)
|
||||
if isinstance(self.value, bool):
|
||||
return str(self.value).lower()
|
||||
return str(self.value) if self.value is not None else ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReturnInfo:
|
||||
type: str = ""
|
||||
description: str = ""
|
||||
on_success: ReturnCase = field(default_factory=ReturnCase)
|
||||
on_failure: ReturnCase = field(default_factory=ReturnCase)
|
||||
|
||||
@property
|
||||
def all_fields(self) -> set[str]:
|
||||
"""汇总 on_success + on_failure 中出现的所有字段名"""
|
||||
fields: set[str] = set()
|
||||
for case in (self.on_success, self.on_failure):
|
||||
fields.update(case.value_fields.keys())
|
||||
return fields
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# 接口描述
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
|
||||
@dataclass
|
||||
class InterfaceInfo:
|
||||
name: str
|
||||
description: str
|
||||
protocol: str # "http" | "function"
|
||||
url: str = "" # 原始 url 字段
|
||||
requirement_id: str = ""
|
||||
|
||||
# HTTP 专用
|
||||
method: str = "get"
|
||||
|
||||
# 统一参数列表(HTTP 与 function 两种协议共用,均来自 "parameters" 字段)
|
||||
parameters: list[ParameterInfo] = field(default_factory=list)
|
||||
|
||||
# 返回值描述
|
||||
return_info: ReturnInfo = field(default_factory=ReturnInfo)
|
||||
|
||||
# ── 便捷属性:参数分组 ────────────────────────────────────
|
||||
|
||||
@property
|
||||
def in_parameters(self) -> list[ParameterInfo]:
|
||||
"""inout = "in" 或 "inout" 的参数(作为输入)"""
|
||||
return [p for p in self.parameters if p.inout in ("in", "inout")]
|
||||
|
||||
@property
|
||||
def out_parameters(self) -> list[ParameterInfo]:
|
||||
"""inout = "out" 或 "inout" 的参数(作为输出)"""
|
||||
return [p for p in self.parameters if p.inout in ("out", "inout")]
|
||||
|
||||
# ── 便捷属性:HTTP ────────────────────────────────────────
|
||||
|
||||
@property
|
||||
def http_full_url(self) -> str:
|
||||
"""
|
||||
HTTP 协议的完整请求 URL:
|
||||
url (base) + name (path)
|
||||
例:url="http://127.0.0.1/api", name="/delete_user"
|
||||
→ "http://127.0.0.1/api/delete_user"
|
||||
"""
|
||||
if self.protocol != "http":
|
||||
return ""
|
||||
base = self.url.rstrip("/")
|
||||
path = self.name.lstrip("/")
|
||||
return f"{base}/{path}" if base else f"/{path}"
|
||||
|
||||
# ── 便捷属性:Function ────────────────────────────────────
|
||||
|
||||
@property
|
||||
def module_path(self) -> str:
|
||||
"""
|
||||
将 url(.py 文件路径)转换为 Python 模块导入路径。
|
||||
规则:
|
||||
1. 去除 .py 后缀
|
||||
2. 去除前导 ./ 或 /
|
||||
3. 将路径分隔符替换为 .
|
||||
示例:
|
||||
"create_user.py" → "create_user"
|
||||
"myapp/services/user_service.py" → "myapp.services.user_service"
|
||||
"./services/user_service.py" → "services.user_service"
|
||||
"myapp.services.user_service" → "myapp.services.user_service"
|
||||
"""
|
||||
if self.protocol != "function" or not self.url:
|
||||
return ""
|
||||
u = self.url.strip()
|
||||
# 已经是点分模块路径(无路径分隔符且无 .py 后缀)
|
||||
if re.match(r'^[\w]+(\.[\w]+)*$', u) and not u.endswith(".py"):
|
||||
return u
|
||||
# 文件路径 → 模块路径
|
||||
u = u.replace("\\", "/")
|
||||
u = re.sub(r'\.py$', '', u)
|
||||
u = re.sub(r'^\.?/', '', u)
|
||||
return u.replace("/", ".")
|
||||
|
||||
@property
|
||||
def source_file(self) -> str:
|
||||
"""function 协议的源文件路径(原始 url 字段)"""
|
||||
return self.url if self.protocol == "function" else ""
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# 描述文件根结构
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
|
||||
@dataclass
|
||||
class ApiDescriptor:
|
||||
project: str
|
||||
description: str
|
||||
interfaces: list[InterfaceInfo]
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
# 解析器
|
||||
# ══════════════════════════════════════════════════════════════
|
||||
|
||||
class InterfaceParser:
|
||||
|
||||
# ── 公开接口 ──────────────────────────────────────────────
|
||||
|
||||
def parse_file(self, file_path: str) -> ApiDescriptor:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return self.parse(data)
|
||||
|
||||
def parse(self, data: dict | list) -> ApiDescriptor:
|
||||
"""
|
||||
兼容三种格式:
|
||||
新格式 :{"project": "...", "description": "...", "units": [...]}
|
||||
旧格式1:{"project": "...", "units": [...]}
|
||||
旧格式2:[...] 直接是接口数组
|
||||
"""
|
||||
if isinstance(data, list):
|
||||
return ApiDescriptor(
|
||||
project="default",
|
||||
description="",
|
||||
interfaces=self._parse_units(data),
|
||||
)
|
||||
return ApiDescriptor(
|
||||
project=data.get("project", "default"),
|
||||
description=data.get("description", ""),
|
||||
interfaces=self._parse_units(data.get("units", [])),
|
||||
)
|
||||
|
||||
# ── 单元解析 ──────────────────────────────────────────────
|
||||
|
||||
def _parse_units(self, units: list[dict]) -> list[InterfaceInfo]:
|
||||
result = []
|
||||
for item in units:
|
||||
protocol = (
|
||||
item.get("protocol") or item.get("type") or "function"
|
||||
).lower()
|
||||
if protocol == "http":
|
||||
result.append(self._parse_http(item))
|
||||
else:
|
||||
result.append(self._parse_function(item))
|
||||
return result
|
||||
|
||||
def _parse_http(self, item: dict) -> InterfaceInfo:
|
||||
"""
|
||||
HTTP 接口解析。
|
||||
参数统一从 "parameters" 字段读取,
|
||||
同时兼容旧格式的 "request_parameters" / "url_parameters"。
|
||||
优先级:parameters > request_parameters + url_parameters
|
||||
"""
|
||||
# 优先使用新统一字段 "parameters"
|
||||
raw_params: dict = item.get("parameters", {})
|
||||
|
||||
# 旧格式兼容:合并 url_parameters + request_parameters
|
||||
if not raw_params:
|
||||
raw_params = {
|
||||
**item.get("url_parameters", {}),
|
||||
**item.get("request_parameters", {}),
|
||||
}
|
||||
|
||||
return InterfaceInfo(
|
||||
name=item["name"],
|
||||
description=item.get("description", ""),
|
||||
protocol="http",
|
||||
url=item.get("url", ""),
|
||||
requirement_id=item.get("requirement_id", ""),
|
||||
method=item.get("method", "get").lower(),
|
||||
parameters=self._parse_param_dict(raw_params, with_inout=True),
|
||||
return_info=self._parse_return(item.get("return", {})),
|
||||
)
|
||||
|
||||
def _parse_function(self, item: dict) -> InterfaceInfo:
|
||||
return InterfaceInfo(
|
||||
name=item["name"],
|
||||
description=item.get("description", ""),
|
||||
protocol="function",
|
||||
url=item.get("url", ""),
|
||||
requirement_id=item.get("requirement_id", ""),
|
||||
parameters=self._parse_param_dict(
|
||||
item.get("parameters", {}), with_inout=True,
|
||||
),
|
||||
return_info=self._parse_return(item.get("return", {})),
|
||||
)
|
||||
|
||||
# ── 参数解析 ──────────────────────────────────────────────
|
||||
|
||||
def _parse_param_dict(
|
||||
self,
|
||||
params: dict,
|
||||
with_inout: bool = False,
|
||||
) -> list[ParameterInfo]:
|
||||
result = []
|
||||
for name, info in params.items():
|
||||
result.append(ParameterInfo(
|
||||
name=name,
|
||||
type=info.get("type", "string"),
|
||||
description=info.get("description", ""),
|
||||
required=info.get("required", True),
|
||||
default=info.get("default", None),
|
||||
inout=info.get("inout", "in") if with_inout else "in",
|
||||
))
|
||||
return result
|
||||
|
||||
# ── 返回值解析 ────────────────────────────────────────────
|
||||
|
||||
def _parse_return(self, ret: dict) -> ReturnInfo:
|
||||
if not ret:
|
||||
return ReturnInfo()
|
||||
return ReturnInfo(
|
||||
type=ret.get("type", ""),
|
||||
description=ret.get("description", ""),
|
||||
on_success=self._parse_return_case(ret.get("on_success", {})),
|
||||
on_failure=self._parse_return_case(ret.get("on_failure", {})),
|
||||
)
|
||||
|
||||
def _parse_return_case(self, case: dict) -> ReturnCase:
|
||||
if not case:
|
||||
return ReturnCase()
|
||||
return ReturnCase(
|
||||
value=case.get("value"),
|
||||
description=case.get("description", ""),
|
||||
)
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# 转换为 LLM 可读摘要
|
||||
# ══════════════════════════════════════════════════════════
|
||||
|
||||
def to_summary_dict(self, interfaces: list[InterfaceInfo]) -> list[dict]:
|
||||
return [
|
||||
self._http_summary(i) if i.protocol == "http"
|
||||
else self._function_summary(i)
|
||||
for i in interfaces
|
||||
]
|
||||
|
||||
def _http_summary(self, iface: InterfaceInfo) -> dict:
|
||||
ri = iface.return_info
|
||||
return {
|
||||
"name": iface.name,
|
||||
"requirement_id": iface.requirement_id,
|
||||
"description": iface.description,
|
||||
"protocol": "http",
|
||||
"full_url": iface.http_full_url,
|
||||
"method": iface.method,
|
||||
# HTTP 接口统一用 "parameters" 输出给 LLM
|
||||
"parameters": self._params_to_dict(iface.parameters),
|
||||
"return": {
|
||||
"type": ri.type,
|
||||
"description": ri.description,
|
||||
"on_success": {
|
||||
"value": ri.on_success.value,
|
||||
"description": ri.on_success.description,
|
||||
},
|
||||
"on_failure": {
|
||||
"value": ri.on_failure.value,
|
||||
"description": ri.on_failure.description,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _function_summary(self, iface: InterfaceInfo) -> dict:
|
||||
ri = iface.return_info
|
||||
return {
|
||||
"name": iface.name,
|
||||
"requirement_id": iface.requirement_id,
|
||||
"description": iface.description,
|
||||
"protocol": "function",
|
||||
"source_file": iface.source_file,
|
||||
"module_path": iface.module_path,
|
||||
"parameters": self._params_to_dict(iface.parameters),
|
||||
"return": {
|
||||
"type": ri.type,
|
||||
"description": ri.description,
|
||||
"on_success": {
|
||||
"value": ri.on_success.value,
|
||||
"description": ri.on_success.description,
|
||||
},
|
||||
"on_failure": {
|
||||
"value": ri.on_failure.value,
|
||||
"description": ri.on_failure.description,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _params_to_dict(self, params: list[ParameterInfo]) -> dict:
|
||||
result = {}
|
||||
for p in params:
|
||||
entry: dict = {
|
||||
"type": p.type,
|
||||
"inout": p.inout,
|
||||
"description": p.description,
|
||||
"required": p.required,
|
||||
}
|
||||
if p.default is not None:
|
||||
entry["default"] = p.default
|
||||
result[p.name] = entry
|
||||
return result
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
"""
|
||||
提示词构建器
|
||||
升级点:
|
||||
- 传递项目 description
|
||||
- function 协议:source_file(.py) → module_path → from <module> import <func>
|
||||
- HTTP 协议:full_url = url(base) + name(path)
|
||||
- parameters 统一字段,inout 区分输入/输出
|
||||
"""
|
||||
|
||||
import json
|
||||
from core.parser import InterfaceInfo, InterfaceParser
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
You are a senior software test engineer. Generate test cases based on the provided
|
||||
interface descriptions and test requirements.
|
||||
|
||||
═══════════════════════════════════════════════════════
|
||||
OUTPUT FORMAT
|
||||
Return ONLY a valid JSON array. No markdown fences. No extra text.
|
||||
═══════════════════════════════════════════════════════
|
||||
Each element = ONE test case:
|
||||
{
|
||||
"test_id" : "unique id, e.g. TC_001",
|
||||
"test_name" : "short descriptive name",
|
||||
"description" : "what this test verifies",
|
||||
"requirement" : "the original requirement text this case maps to",
|
||||
"requirement_id" : "e.g. REQ.01",
|
||||
"steps": [
|
||||
{
|
||||
"step_no" : 1,
|
||||
"interface_name" : "function or endpoint name",
|
||||
"protocol" : "http or function",
|
||||
"url" : "source_file (function) or full_url (http)",
|
||||
"purpose" : "why this step is needed",
|
||||
"input": {
|
||||
// Only parameters with inout = "in" or "inout"
|
||||
"param_name": <value>
|
||||
},
|
||||
"use_output_of": { // optional
|
||||
"step_no" : 1,
|
||||
"field" : "user_id",
|
||||
"as_param" : "user_id"
|
||||
},
|
||||
"assertions": [
|
||||
{
|
||||
"field" : "field_name or 'return' or 'exception'",
|
||||
"operator" : "eq|ne|gt|lt|gte|lte|in|not_null|contains|raised|not_raised",
|
||||
"expected" : <value>,
|
||||
"message" : "human readable description"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"test_data_notes" : "explanation of auto-generated test data",
|
||||
"test_code" : "<complete Python script — see rules below>"
|
||||
}
|
||||
|
||||
═══════════════════════════════════════════════════════
|
||||
TEST CODE RULES
|
||||
═══════════════════════════════════════════════════════
|
||||
1. Complete, runnable Python script. No external test framework.
|
||||
2. Allowed imports: standard library + `requests` + the actual module under test.
|
||||
|
||||
── FUNCTION INTERFACES ────────────────────────────────
|
||||
3. Each function interface has:
|
||||
"source_file" : e.g. "create_user.py"
|
||||
"module_path" : e.g. "create_user" (derived from source_file)
|
||||
4. Import the REAL function using module_path:
|
||||
try:
|
||||
from <module_path> import <function_name>
|
||||
except ImportError:
|
||||
# Stub fallback — simulate on_success for positive tests,
|
||||
# on_failure / raise Exception for negative tests
|
||||
def <function_name>(<in_params>):
|
||||
return <on_success_value> # or raise Exception(...)
|
||||
5. Call the function with ONLY "in"/"inout" parameters.
|
||||
6. Capture the return value; assert against on_success or on_failure structure.
|
||||
7. If on_failure.value contains "exception" or "raises":
|
||||
- In negative tests: wrap call in try/except, assert exception IS raised.
|
||||
- Use field="exception", operator="raised", expected="Exception".
|
||||
|
||||
── HTTP INTERFACES ────────────────────────────────────
|
||||
8. Each HTTP interface has:
|
||||
"full_url" : complete URL, e.g. "http://127.0.0.1/api/delete_user"
|
||||
"method" : "get" | "post" | "put" | "delete" etc.
|
||||
9. Send request:
|
||||
resp = requests.<method>("<full_url>", json=<input_dict>)
|
||||
10. Assert on resp.status_code and resp.json() fields.
|
||||
|
||||
── MULTI-STEP ─────────────────────────────────────────
|
||||
11. Execute steps in order; extract fields from previous step's return value
|
||||
and pass to next step via use_output_of mapping.
|
||||
|
||||
── STRUCTURED OUTPUT (REQUIRED) ───────────────────────
|
||||
12. After EACH step, print:
|
||||
##STEP_RESULT## {"step_no":<n>,"interface_name":"...","status":"PASS|FAIL",
|
||||
"assertions":[{"field":"...","operator":"...","expected":...,
|
||||
"actual":...,"passed":true|false,"message":"..."}]}
|
||||
13. Final line must be:
|
||||
PASS: <summary> or FAIL: <summary>
|
||||
14. Wrap entire test body in try/except; print FAIL on any unhandled exception.
|
||||
15. Do NOT use unittest or pytest.
|
||||
|
||||
═══════════════════════════════════════════════════════
|
||||
ASSERTION RULES BY RETURN TYPE
|
||||
═══════════════════════════════════════════════════════
|
||||
return.type = "dict" → assert each key in on_success.value (positive)
|
||||
or on_failure.value (negative)
|
||||
return.type = "boolean" → field="return", operator="eq", expected=true/false
|
||||
return.type = "integer" → field="return", operator="eq"/"gt"/"lt"
|
||||
on_failure = exception → field="exception", operator="raised"
|
||||
|
||||
═══════════════════════════════════════════════════════
|
||||
COVERAGE GUIDELINES
|
||||
═══════════════════════════════════════════════════════
|
||||
- Per requirement: at least 1 positive + 1 negative test case.
|
||||
- Negative test_name/description MUST contain "negative" or "invalid".
|
||||
- Multi-interface requirements: single multi-step test case.
|
||||
- Cover ALL "in"/"inout" parameters (including optional ones).
|
||||
- Assert ALL fields described in on_success (positive) and on_failure (negative).
|
||||
- For "out" parameters: verify their values in assertions after the call.
|
||||
"""
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
|
||||
def get_system_prompt(self) -> str:
|
||||
return SYSTEM_PROMPT.strip()
|
||||
|
||||
def build_user_prompt(
|
||||
self,
|
||||
requirements: list[str],
|
||||
interfaces: list[InterfaceInfo],
|
||||
parser: InterfaceParser,
|
||||
project: str = "",
|
||||
project_desc: str = "",
|
||||
) -> str:
|
||||
iface_summary = parser.to_summary_dict(interfaces)
|
||||
req_lines = "\n".join(
|
||||
f"{i + 1}. {r}" for i, r in enumerate(requirements)
|
||||
)
|
||||
|
||||
# 项目信息头
|
||||
project_section = ""
|
||||
if project or project_desc:
|
||||
project_section = "## Project\n"
|
||||
if project:
|
||||
project_section += f"Name : {project}\n"
|
||||
if project_desc:
|
||||
project_section += f"Description: {project_desc}\n"
|
||||
project_section += "\n"
|
||||
|
||||
return (
|
||||
f"{project_section}"
|
||||
f"## Available Interfaces\n"
|
||||
f"{json.dumps(iface_summary, ensure_ascii=False, indent=2)}\n\n"
|
||||
f"## Test Requirements\n"
|
||||
f"{req_lines}\n\n"
|
||||
f"Generate comprehensive test cases (positive + negative) for every requirement.\n"
|
||||
f"- For function interfaces: use 'module_path' for import, "
|
||||
f"'source_file' for reference.\n"
|
||||
f"- For HTTP interfaces: use 'full_url' as the request URL.\n"
|
||||
f"- Use requirement_id from the interface to populate the test case's "
|
||||
f"requirement_id field.\n"
|
||||
f"- Chain multiple interface calls in one test case when a requirement "
|
||||
f"involves more than one interface.\n"
|
||||
).strip()
|
||||
|
|
@ -0,0 +1,492 @@
|
|||
"""
|
||||
报告生成器:JSON 摘要 + HTML 可视化报告
|
||||
升级点:
|
||||
- 新增 requirement_id 列
|
||||
- 新增 out 参数覆盖率、on_success/on_failure 字段覆盖率
|
||||
- 缺口类型标签中文化扩充
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from core.analyzer import CoverageReport, Gap
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReportGenerator:
|
||||
|
||||
def save_json(self, report: CoverageReport, path: str):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(self._to_dict(report), f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"Coverage JSON → {path}")
|
||||
|
||||
def save_html(self, report: CoverageReport, path: str):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(self._build_html(report))
|
||||
logger.info(f"Coverage HTML → {path}")
|
||||
|
||||
# ── JSON 序列化 ───────────────────────────────────────────
|
||||
|
||||
def _to_dict(self, r: CoverageReport) -> dict:
|
||||
return {
|
||||
"generated_at": datetime.now().isoformat(timespec="seconds"),
|
||||
"summary": {
|
||||
"interface_coverage": f"{r.interface_coverage_rate:.1%}",
|
||||
"requirement_coverage": f"{r.requirement_coverage_rate:.1%}",
|
||||
"avg_in_param_coverage": f"{r.avg_in_param_coverage_rate:.1%}",
|
||||
"avg_success_field_coverage": f"{r.avg_success_field_coverage_rate:.1%}",
|
||||
"avg_failure_field_coverage": f"{r.avg_failure_field_coverage_rate:.1%}",
|
||||
"pass_rate": f"{r.pass_rate:.1%}",
|
||||
"total_interfaces": r.total_interfaces,
|
||||
"covered_interfaces": r.covered_interfaces,
|
||||
"total_requirements": r.total_requirements,
|
||||
"covered_requirements": r.covered_requirements,
|
||||
"total_test_cases": r.total_test_cases,
|
||||
"passed": r.passed_test_cases,
|
||||
"failed": r.failed_test_cases,
|
||||
"errors": r.error_test_cases,
|
||||
"critical_gaps": r.critical_gap_count,
|
||||
"high_gaps": r.high_gap_count,
|
||||
},
|
||||
"interface_details": [
|
||||
{
|
||||
"name": ic.interface_name,
|
||||
"requirement_id": ic.requirement_id,
|
||||
"protocol": ic.protocol,
|
||||
"is_covered": ic.is_covered,
|
||||
"has_positive_case": ic.has_positive_case,
|
||||
"has_negative_case": ic.has_negative_case,
|
||||
"in_param_coverage": f"{ic.in_param_coverage_rate:.1%}",
|
||||
"covered_in_params": sorted(ic.covered_in_params),
|
||||
"uncovered_in_params": sorted(ic.uncovered_in_params),
|
||||
"out_param_coverage": f"{ic.out_param_coverage_rate:.1%}",
|
||||
"covered_out_params": sorted(ic.asserted_out_params),
|
||||
"uncovered_out_params": sorted(ic.uncovered_out_params),
|
||||
"success_field_coverage": f"{ic.success_field_coverage_rate:.1%}",
|
||||
"covered_success_fields": sorted(ic.covered_success_fields),
|
||||
"uncovered_success_fields": sorted(ic.uncovered_success_fields),
|
||||
"failure_field_coverage": f"{ic.failure_field_coverage_rate:.1%}",
|
||||
"covered_failure_fields": sorted(ic.covered_failure_fields),
|
||||
"uncovered_failure_fields": sorted(ic.uncovered_failure_fields),
|
||||
"expects_exception": ic.expects_exception,
|
||||
"exception_case_covered": (
|
||||
ic.exception_case_covered if ic.expects_exception else None
|
||||
),
|
||||
"covering_test_ids": ic.covering_test_ids,
|
||||
}
|
||||
for ic in r.interface_coverages
|
||||
],
|
||||
"requirement_details": [
|
||||
{
|
||||
"requirement": rc.requirement,
|
||||
"requirement_id": rc.requirement_id,
|
||||
"is_covered": rc.is_covered,
|
||||
"covering_test_ids": rc.covering_test_ids,
|
||||
}
|
||||
for rc in r.requirement_coverages
|
||||
],
|
||||
"gaps": [
|
||||
{
|
||||
"gap_type": g.gap_type,
|
||||
"severity": g.severity,
|
||||
"target": g.target,
|
||||
"detail": g.detail,
|
||||
"suggestion": g.suggestion,
|
||||
}
|
||||
for g in r.gaps
|
||||
],
|
||||
}
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# HTML 报告
|
||||
# ══════════════════════════════════════════════════════════
|
||||
|
||||
_GAP_LABELS: dict[str, str] = {
|
||||
"interface_not_covered": "接口未覆盖",
|
||||
"in_param_not_covered": "入参未覆盖",
|
||||
"out_param_not_asserted": "出参未断言",
|
||||
"success_field_not_asserted": "成功返回字段未断言",
|
||||
"failure_field_not_asserted": "失败返回字段未断言",
|
||||
"exception_case_not_tested": "异常场景未测试",
|
||||
"missing_positive_case": "缺少正向用例",
|
||||
"missing_negative_case": "缺少负向用例",
|
||||
"requirement_not_covered": "需求未覆盖",
|
||||
"test_failed": "用例执行失败",
|
||||
"test_error": "用例执行异常",
|
||||
}
|
||||
|
||||
# ── 样式辅助 ──────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _rate_color(rate: float) -> str:
|
||||
if rate >= 0.8: return "#27ae60"
|
||||
if rate >= 0.5: return "#f39c12"
|
||||
return "#e74c3c"
|
||||
|
||||
@staticmethod
|
||||
def _sev_badge(sev: str) -> str:
|
||||
colors = {
|
||||
"critical": "#e74c3c", "high": "#e67e22",
|
||||
"medium": "#f1c40f", "low": "#3498db",
|
||||
}
|
||||
c = colors.get(sev, "#95a5a6")
|
||||
return (
|
||||
f'<span style="background:{c};color:#fff;padding:2px 8px;'
|
||||
f'border-radius:4px;font-size:11px;font-weight:600">'
|
||||
f'{sev.upper()}</span>'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _bool_badge(val: bool) -> str:
|
||||
if val:
|
||||
return '<span style="color:#27ae60;font-size:15px;font-weight:700">✔</span>'
|
||||
return '<span style="color:#e74c3c;font-size:15px;font-weight:700">✘</span>'
|
||||
|
||||
def _pct_bar(self, rate: float, w: int = 120) -> str:
|
||||
color = self._rate_color(rate)
|
||||
filled = int(w * rate)
|
||||
return (
|
||||
f'<div style="display:inline-flex;align-items:center;gap:5px">'
|
||||
f'<div style="width:{w}px;background:#ecf0f1;border-radius:3px;height:11px">'
|
||||
f'<div style="width:{filled}px;background:{color};'
|
||||
f'height:11px;border-radius:3px"></div></div>'
|
||||
f'<span style="font-size:12px;font-weight:600;color:{color}">'
|
||||
f'{rate * 100:.0f}%</span></div>'
|
||||
)
|
||||
|
||||
def _metric_card(
|
||||
self, label: str, value: str, sub: str = "", color: str = "#2c3e50"
|
||||
) -> str:
|
||||
return (
|
||||
f'<div style="background:#fff;border-radius:10px;padding:16px 22px;'
|
||||
f'box-shadow:0 2px 10px rgba(0,0,0,.08);min-width:148px;text-align:center">'
|
||||
f'<div style="font-size:26px;font-weight:700;color:{color}">{value}</div>'
|
||||
f'<div style="font-size:12px;color:#7f8c8d;margin-top:4px">{label}</div>'
|
||||
f'<div style="font-size:11px;color:#bdc3c7;margin-top:2px">{sub}</div>'
|
||||
f'</div>'
|
||||
)
|
||||
|
||||
# ── 汇总卡片 ──────────────────────────────────────────────
|
||||
|
||||
def _build_cards(self, r: CoverageReport) -> str:
|
||||
rc = self._rate_color
|
||||
return "".join([
|
||||
self._metric_card(
|
||||
"接口覆盖率", f"{r.interface_coverage_rate:.0%}",
|
||||
f"{r.covered_interfaces}/{r.total_interfaces}",
|
||||
rc(r.interface_coverage_rate),
|
||||
),
|
||||
self._metric_card(
|
||||
"需求覆盖率", f"{r.requirement_coverage_rate:.0%}",
|
||||
f"{r.covered_requirements}/{r.total_requirements}",
|
||||
rc(r.requirement_coverage_rate),
|
||||
),
|
||||
self._metric_card(
|
||||
"入参覆盖率", f"{r.avg_in_param_coverage_rate:.0%}",
|
||||
"avg in-params",
|
||||
rc(r.avg_in_param_coverage_rate),
|
||||
),
|
||||
self._metric_card(
|
||||
"成功返回覆盖", f"{r.avg_success_field_coverage_rate:.0%}",
|
||||
"on_success fields",
|
||||
rc(r.avg_success_field_coverage_rate),
|
||||
),
|
||||
self._metric_card(
|
||||
"失败返回覆盖", f"{r.avg_failure_field_coverage_rate:.0%}",
|
||||
"on_failure fields",
|
||||
rc(r.avg_failure_field_coverage_rate),
|
||||
),
|
||||
self._metric_card(
|
||||
"用例通过率", f"{r.pass_rate:.0%}",
|
||||
f"{r.passed_test_cases}/{r.total_test_cases}",
|
||||
rc(r.pass_rate),
|
||||
),
|
||||
self._metric_card(
|
||||
"Critical 缺口", str(r.critical_gap_count),
|
||||
f"High: {r.high_gap_count}",
|
||||
"#e74c3c" if r.critical_gap_count > 0 else "#27ae60",
|
||||
),
|
||||
])
|
||||
|
||||
# ── 接口覆盖表 ────────────────────────────────────────────
|
||||
|
||||
def _build_interface_table(self, r: CoverageReport) -> str:
|
||||
rows = ""
|
||||
for ic in r.interface_coverages:
|
||||
proto_color = "#3498db" if ic.protocol == "http" else "#9b59b6"
|
||||
|
||||
# 异常测试单元格
|
||||
if ic.expects_exception:
|
||||
exc_cell = self._bool_badge(ic.exception_case_covered)
|
||||
else:
|
||||
exc_cell = '<span style="color:#bdc3c7;font-size:13px">N/A</span>'
|
||||
|
||||
# 未覆盖项 tooltip
|
||||
def missing_tip(items: set[str], label: str) -> str:
|
||||
if not items:
|
||||
return ""
|
||||
joined = ", ".join(sorted(items))
|
||||
return (
|
||||
f'<div style="font-size:10px;color:#e74c3c;margin-top:2px">'
|
||||
f'缺: {joined}</div>'
|
||||
)
|
||||
|
||||
rows += f"""
|
||||
<tr>
|
||||
<td>
|
||||
<code style="font-size:12px">{ic.interface_name}</code>
|
||||
</td>
|
||||
<td style="font-size:11px;color:#7f8c8d;white-space:nowrap">
|
||||
{ic.requirement_id or "—"}
|
||||
</td>
|
||||
<td>
|
||||
<span style="background:{proto_color};color:#fff;padding:1px 7px;
|
||||
border-radius:3px;font-size:11px">{ic.protocol.upper()}</span>
|
||||
</td>
|
||||
<td style="text-align:center">{self._bool_badge(ic.is_covered)}</td>
|
||||
<td style="text-align:center">{self._bool_badge(ic.has_positive_case)}</td>
|
||||
<td style="text-align:center">{self._bool_badge(ic.has_negative_case)}</td>
|
||||
<td>
|
||||
{self._pct_bar(ic.in_param_coverage_rate)}
|
||||
{missing_tip(ic.uncovered_in_params, "缺")}
|
||||
</td>
|
||||
<td>
|
||||
{self._pct_bar(ic.out_param_coverage_rate)}
|
||||
{missing_tip(ic.uncovered_out_params, "缺")}
|
||||
</td>
|
||||
<td>
|
||||
{self._pct_bar(ic.success_field_coverage_rate)}
|
||||
{missing_tip(ic.uncovered_success_fields, "缺")}
|
||||
</td>
|
||||
<td>
|
||||
{self._pct_bar(ic.failure_field_coverage_rate)}
|
||||
{missing_tip(ic.uncovered_failure_fields, "缺")}
|
||||
</td>
|
||||
<td style="text-align:center">{exc_cell}</td>
|
||||
<td style="font-size:11px;color:#7f8c8d">
|
||||
{", ".join(ic.covering_test_ids) or "—"}
|
||||
</td>
|
||||
</tr>"""
|
||||
return rows
|
||||
|
||||
# ── 需求覆盖表 ────────────────────────────────────────────
|
||||
|
||||
def _build_requirement_table(self, r: CoverageReport) -> str:
|
||||
rows = ""
|
||||
for rc in r.requirement_coverages:
|
||||
rows += f"""
|
||||
<tr>
|
||||
<td style="font-size:11px;color:#7f8c8d;white-space:nowrap">
|
||||
{rc.requirement_id or "—"}
|
||||
</td>
|
||||
<td style="max-width:320px;font-size:13px">{rc.requirement}</td>
|
||||
<td style="text-align:center">{self._bool_badge(rc.is_covered)}</td>
|
||||
<td style="font-size:11px;color:#7f8c8d">
|
||||
{", ".join(rc.covering_test_ids) or "—"}
|
||||
</td>
|
||||
</tr>"""
|
||||
return rows
|
||||
|
||||
# ── 缺口清单表 ────────────────────────────────────────────
|
||||
|
||||
def _build_gap_table(self, r: CoverageReport) -> str:
|
||||
if not r.gaps:
|
||||
return (
|
||||
'<tr><td colspan="6" style="text-align:center;color:#27ae60;'
|
||||
'padding:24px;font-size:14px">🎉 No gaps found! Full coverage achieved.</td></tr>'
|
||||
)
|
||||
rows = ""
|
||||
for i, g in enumerate(r.gaps, 1):
|
||||
rows += f"""
|
||||
<tr>
|
||||
<td style="text-align:center;color:#7f8c8d;font-size:12px">{i}</td>
|
||||
<td style="white-space:nowrap">{self._sev_badge(g.severity)}</td>
|
||||
<td>
|
||||
<span style="font-size:11px;background:#ecf0f1;padding:2px 6px;
|
||||
border-radius:3px;white-space:nowrap">
|
||||
{self._GAP_LABELS.get(g.gap_type, g.gap_type)}
|
||||
</span>
|
||||
</td>
|
||||
<td><code style="font-size:11px">{g.target}</code></td>
|
||||
<td style="font-size:12px;max-width:280px">{g.detail}</td>
|
||||
<td style="font-size:12px;color:#2980b9;max-width:280px">{g.suggestion}</td>
|
||||
</tr>"""
|
||||
return rows
|
||||
|
||||
# ── 缺口统计图(纯 CSS 横向柱状图)────────────────────────
|
||||
|
||||
def _build_gap_chart(self, r: CoverageReport) -> str:
|
||||
from collections import Counter
|
||||
type_counts = Counter(g.gap_type for g in r.gaps)
|
||||
if not type_counts:
|
||||
return ""
|
||||
|
||||
max_count = max(type_counts.values(), default=1)
|
||||
bars = ""
|
||||
sev_color_map: dict[str, str] = {}
|
||||
for g in r.gaps:
|
||||
sev_color_map[g.gap_type] = {
|
||||
"critical": "#e74c3c", "high": "#e67e22",
|
||||
"medium": "#f1c40f", "low": "#3498db",
|
||||
}.get(g.severity, "#95a5a6")
|
||||
|
||||
for gap_type, count in sorted(type_counts.items(), key=lambda x: -x[1]):
|
||||
label = self._GAP_LABELS.get(gap_type, gap_type)
|
||||
color = sev_color_map.get(gap_type, "#95a5a6")
|
||||
width = int(count / max_count * 320)
|
||||
bars += f"""
|
||||
<div style="display:flex;align-items:center;gap:10px;margin-bottom:8px">
|
||||
<div style="width:160px;font-size:12px;color:#555;text-align:right;
|
||||
white-space:nowrap;overflow:hidden;text-overflow:ellipsis"
|
||||
title="{label}">{label}</div>
|
||||
<div style="width:{width}px;height:18px;background:{color};
|
||||
border-radius:3px;transition:width .3s"></div>
|
||||
<span style="font-size:12px;font-weight:600;color:{color}">{count}</span>
|
||||
</div>"""
|
||||
|
||||
return f"""
|
||||
<div style="background:#fff;border-radius:10px;padding:20px 24px;
|
||||
box-shadow:0 2px 8px rgba(0,0,0,.07);margin-bottom:8px">
|
||||
<div style="font-size:14px;font-weight:600;margin-bottom:16px;color:#2c3e50">
|
||||
缺口类型分布
|
||||
</div>
|
||||
{bars}
|
||||
</div>"""
|
||||
|
||||
# ── 组装完整 HTML ─────────────────────────────────────────
|
||||
|
||||
def _build_html(self, r: CoverageReport) -> str:
|
||||
ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
css = """
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
|
||||
background: #f4f6f9; color: #2c3e50;
|
||||
}
|
||||
.wrap { max-width: 1440px; margin: 0 auto; padding: 32px 20px; }
|
||||
h1 { font-size: 24px; font-weight: 700; margin-bottom: 4px; }
|
||||
.sub { color: #7f8c8d; font-size: 12px; margin-bottom: 26px; }
|
||||
.cards { display: flex; flex-wrap: wrap; gap: 14px; margin-bottom: 32px; }
|
||||
h2 {
|
||||
font-size: 16px; font-weight: 600; margin: 28px 0 12px;
|
||||
padding-left: 10px; border-left: 4px solid #3498db;
|
||||
}
|
||||
table {
|
||||
width: 100%; border-collapse: collapse; background: #fff;
|
||||
border-radius: 10px; overflow: hidden;
|
||||
box-shadow: 0 2px 8px rgba(0,0,0,.07);
|
||||
}
|
||||
th {
|
||||
background: #2c3e50; color: #fff; padding: 10px 12px;
|
||||
font-size: 12px; text-align: left; white-space: nowrap;
|
||||
}
|
||||
td {
|
||||
padding: 9px 12px; font-size: 12px;
|
||||
border-bottom: 1px solid #ecf0f1; vertical-align: middle;
|
||||
}
|
||||
tr:last-child td { border-bottom: none; }
|
||||
tr:hover td { background: #f8f9fa; }
|
||||
code {
|
||||
background: #ecf0f1; padding: 1px 4px; border-radius: 3px;
|
||||
font-family: "SFMono-Regular", Consolas, monospace;
|
||||
}
|
||||
.footer {
|
||||
text-align: center; color: #bdc3c7;
|
||||
font-size: 11px; margin-top: 36px; padding-bottom: 20px;
|
||||
}
|
||||
"""
|
||||
|
||||
return f"""<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>AI Test Coverage Report</title>
|
||||
<style>{css}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="wrap">
|
||||
|
||||
<!-- 标题 -->
|
||||
<h1>🧪 AI Test Coverage Report</h1>
|
||||
<div class="sub">Generated at {ts} |
|
||||
Interfaces: {r.total_interfaces} |
|
||||
Test Cases: {r.total_test_cases} |
|
||||
Gaps: {len(r.gaps)}
|
||||
</div>
|
||||
|
||||
<!-- 汇总卡片 -->
|
||||
<div class="cards">{self._build_cards(r)}</div>
|
||||
|
||||
<!-- 缺口分布图 -->
|
||||
{self._build_gap_chart(r)}
|
||||
|
||||
<!-- 接口覆盖详情 -->
|
||||
<h2>📡 接口覆盖详情</h2>
|
||||
<div style="overflow-x:auto">
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>接口名称</th>
|
||||
<th>REQ ID</th>
|
||||
<th>协议</th>
|
||||
<th>已覆盖</th>
|
||||
<th>正向</th>
|
||||
<th>负向</th>
|
||||
<th>入参覆盖率</th>
|
||||
<th>出参断言率</th>
|
||||
<th>成功返回覆盖</th>
|
||||
<th>失败返回覆盖</th>
|
||||
<th>异常测试</th>
|
||||
<th>覆盖用例</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>{self._build_interface_table(r)}</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- 需求覆盖详情 -->
|
||||
<h2>📋 需求覆盖详情</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>REQ ID</th>
|
||||
<th>需求描述</th>
|
||||
<th>已覆盖</th>
|
||||
<th>覆盖用例</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>{self._build_requirement_table(r)}</tbody>
|
||||
</table>
|
||||
|
||||
<!-- 缺口清单 -->
|
||||
<h2>🔍 缺口清单(共 {len(r.gaps)} 项,
|
||||
<span style="color:#e74c3c">Critical: {r.critical_gap_count}</span>,
|
||||
<span style="color:#e67e22">High: {r.high_gap_count}</span>)
|
||||
</h2>
|
||||
<div style="overflow-x:auto">
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>#</th>
|
||||
<th>严重程度</th>
|
||||
<th>缺口类型</th>
|
||||
<th>目标</th>
|
||||
<th>详情</th>
|
||||
<th>补充建议</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>{self._build_gap_table(r)}</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<div class="footer">
|
||||
AI-Powered Test Generator · Coverage Analyzer ·
|
||||
Report generated at {ts}
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>"""
|
||||
|
|
@ -0,0 +1,138 @@
|
|||
"""
|
||||
测试文件生成器
|
||||
升级点:
|
||||
- 接收 project / project_desc,写入头部注释
|
||||
- 头部注释增加 url / source_file / module_path 信息
|
||||
- 输出到 <GENERATED_TESTS_DIR>/<project>/ 子目录
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestGenerator:
|
||||
|
||||
def __init__(self, project: str = "", project_desc: str = ""):
|
||||
self.project = project
|
||||
self.project_desc = project_desc
|
||||
|
||||
safe = self._safe_name(project) if project else ""
|
||||
base = Path(config.GENERATED_TESTS_DIR)
|
||||
self.output_dir = base / safe if safe else base
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ── 公开接口 ──────────────────────────────────────────────
|
||||
|
||||
def generate(self, test_cases: list[dict]) -> list[Path]:
|
||||
self._save_summary(test_cases)
|
||||
files: list[Path] = []
|
||||
for tc in test_cases:
|
||||
path = self._write_file(tc)
|
||||
if path:
|
||||
files.append(path)
|
||||
logger.info(
|
||||
f"Generated {len(files)} test file(s) → '{self.output_dir}'"
|
||||
)
|
||||
return files
|
||||
|
||||
# ── 写入单个测试文件 ──────────────────────────────────────
|
||||
|
||||
def _write_file(self, tc: dict) -> Path | None:
|
||||
test_id = tc.get("test_id", "TC_UNKNOWN")
|
||||
test_code = tc.get("test_code", "").strip()
|
||||
if not test_code:
|
||||
logger.warning(f"[{test_id}] No test_code, skipping.")
|
||||
return None
|
||||
|
||||
file_path = self.output_dir / f"{self._safe_name(test_id)}.py"
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(self._build_header(tc))
|
||||
f.write("\n")
|
||||
f.write(test_code)
|
||||
f.write("\n")
|
||||
|
||||
logger.info(f" Written: {file_path}")
|
||||
return file_path
|
||||
|
||||
# ── 文件头注释 ────────────────────────────────────────────
|
||||
|
||||
def _build_header(self, tc: dict) -> str:
|
||||
req_id = tc.get("requirement_id", "")
|
||||
req_text = tc.get("requirement", "")
|
||||
data_notes = tc.get("test_data_notes", "")
|
||||
steps = tc.get("steps", [])
|
||||
|
||||
# 项目信息行
|
||||
proj_lines = ""
|
||||
if self.project:
|
||||
proj_lines += f"# Project : {self.project}\n"
|
||||
if self.project_desc:
|
||||
proj_lines += f"# Proj Desc : {self.project_desc}\n"
|
||||
|
||||
# 步骤注释
|
||||
step_lines = ""
|
||||
for s in steps:
|
||||
protocol = s.get("protocol", "").upper()
|
||||
iface = s.get("interface_name", "")
|
||||
url = s.get("url", "")
|
||||
purpose = s.get("purpose", "")
|
||||
|
||||
step_lines += (
|
||||
f"# Step {s.get('step_no')}: [{protocol}] {iface}\n"
|
||||
f"# URL : {url}\n"
|
||||
f"# Purpose: {purpose}\n"
|
||||
)
|
||||
|
||||
inputs = s.get("input", {})
|
||||
if inputs:
|
||||
step_lines += (
|
||||
f"# Input : "
|
||||
f"{json.dumps(inputs, ensure_ascii=False)}\n"
|
||||
)
|
||||
|
||||
for a in s.get("assertions", []):
|
||||
msg = a.get("message", "")
|
||||
step_lines += (
|
||||
f"# Assert : {a.get('field')} "
|
||||
f"{a.get('operator')} {a.get('expected')}"
|
||||
f"{' — ' + msg if msg else ''}\n"
|
||||
)
|
||||
|
||||
return (
|
||||
f"# {'=' * 60}\n"
|
||||
f"{proj_lines}"
|
||||
f"# Test ID : {tc.get('test_id', '')}\n"
|
||||
f"# Test Name : {tc.get('test_name', '')}\n"
|
||||
f"# Requirement: [{req_id}] {req_text}\n"
|
||||
f"# Description: {tc.get('description', '')}\n"
|
||||
f"# Test Data : {data_notes}\n"
|
||||
f"# Steps:\n"
|
||||
f"{step_lines.rstrip()}\n"
|
||||
f"# {'=' * 60}\n"
|
||||
f"import os\n"
|
||||
f"import json\n"
|
||||
)
|
||||
|
||||
# ── 保存摘要 JSON ─────────────────────────────────────────
|
||||
|
||||
def _save_summary(self, test_cases: list[dict]):
|
||||
summary = [
|
||||
{k: v for k, v in tc.items() if k != "test_code"}
|
||||
for tc in test_cases
|
||||
]
|
||||
path = self.output_dir / "test_cases_summary.json"
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(summary, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"Summary → {path}")
|
||||
|
||||
# ── 工具 ──────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _safe_name(name: str) -> str:
|
||||
return re.sub(r'[^\w\-]', '_', name).strip("_")
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
import subprocess
|
||||
import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepResult:
|
||||
step_no: int
|
||||
interface_name: str
|
||||
status: str # PASS | FAIL | ERROR
|
||||
assertions: list[dict] = field(default_factory=list)
|
||||
# 每条 assertion: {"field","operator","expected","actual","passed","message"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
test_id: str
|
||||
file_path: str
|
||||
status: str # PASS | FAIL | ERROR | TIMEOUT
|
||||
message: str = ""
|
||||
duration: float = 0.0
|
||||
stdout: str = ""
|
||||
stderr: str = ""
|
||||
step_results: list[StepResult] = field(default_factory=list)
|
||||
|
||||
|
||||
class TestRunner:
|
||||
|
||||
def run_all(self, test_files: list[Path]) -> list[TestResult]:
|
||||
results = []
|
||||
total = len(test_files)
|
||||
print(f"\n{'═'*62}")
|
||||
print(f" Running {total} test(s) …")
|
||||
print(f"{'═'*62}")
|
||||
for idx, fp in enumerate(test_files, 1):
|
||||
print(f"\n[{idx}/{total}] {fp.name}")
|
||||
result = self._run_one(fp)
|
||||
results.append(result)
|
||||
self._print_result(result)
|
||||
return results
|
||||
|
||||
# ── 执行单个脚本 ──────────────────────────────────────────
|
||||
|
||||
def _run_one(self, file_path: Path) -> TestResult:
|
||||
test_id = file_path.stem
|
||||
t0 = time.time()
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
[sys.executable, str(file_path)],
|
||||
capture_output=True, text=True,
|
||||
timeout=config.TEST_TIMEOUT,
|
||||
env=self._env(),
|
||||
)
|
||||
duration = time.time() - t0
|
||||
stdout = proc.stdout.strip()
|
||||
stderr = proc.stderr.strip()
|
||||
status, message = self._parse_output(stdout, proc.returncode)
|
||||
step_results = self._parse_step_results(stdout)
|
||||
return TestResult(
|
||||
test_id=test_id, file_path=str(file_path),
|
||||
status=status, message=message,
|
||||
duration=duration, stdout=stdout, stderr=stderr,
|
||||
step_results=step_results,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
return TestResult(
|
||||
test_id=test_id, file_path=str(file_path),
|
||||
status="TIMEOUT", message=f"Exceeded {config.TEST_TIMEOUT}s",
|
||||
duration=time.time() - t0,
|
||||
)
|
||||
except Exception as e:
|
||||
return TestResult(
|
||||
test_id=test_id, file_path=str(file_path),
|
||||
status="ERROR", message=str(e),
|
||||
duration=time.time() - t0,
|
||||
)
|
||||
|
||||
# ── 解析输出 ──────────────────────────────────────────────
|
||||
|
||||
def _parse_output(self, stdout: str, returncode: int) -> tuple[str, str]:
|
||||
if not stdout:
|
||||
return "FAIL", f"No output (exit={returncode})"
|
||||
last = stdout.strip().splitlines()[-1].strip()
|
||||
upper = last.upper()
|
||||
if upper.startswith("PASS"):
|
||||
return "PASS", last[5:].strip()
|
||||
if upper.startswith("FAIL"):
|
||||
return "FAIL", last[5:].strip()
|
||||
return ("PASS" if returncode == 0 else "FAIL"), last
|
||||
|
||||
def _parse_step_results(self, stdout: str) -> list[StepResult]:
|
||||
"""
|
||||
解析脚本中以 ##STEP_RESULT## 开头的结构化输出行
|
||||
格式:##STEP_RESULT## <json>
|
||||
"""
|
||||
results = []
|
||||
for line in stdout.splitlines():
|
||||
line = line.strip()
|
||||
if line.startswith("##STEP_RESULT##"):
|
||||
try:
|
||||
data = json.loads(line[len("##STEP_RESULT##"):].strip())
|
||||
results.append(StepResult(
|
||||
step_no=data.get("step_no", 0),
|
||||
interface_name=data.get("interface_name", ""),
|
||||
status=data.get("status", ""),
|
||||
assertions=data.get("assertions", []),
|
||||
))
|
||||
except Exception:
|
||||
pass
|
||||
return results
|
||||
|
||||
def _env(self) -> dict:
|
||||
env = os.environ.copy()
|
||||
env["HTTP_BASE_URL"] = config.HTTP_BASE_URL
|
||||
return env
|
||||
|
||||
# ── 打印 ──────────────────────────────────────────────────
|
||||
|
||||
def _print_result(self, r: TestResult):
|
||||
icon = {"PASS": "✅", "FAIL": "❌", "TIMEOUT": "⏱️", "ERROR": "⚠️"}.get(r.status, "❓")
|
||||
print(f" {icon} [{r.status}] {r.test_id} ({r.duration:.2f}s)")
|
||||
print(f" {r.message}")
|
||||
for sr in r.step_results:
|
||||
s_icon = "✅" if sr.status == "PASS" else "❌"
|
||||
print(f" {s_icon} Step {sr.step_no}: {sr.interface_name}")
|
||||
for a in sr.assertions:
|
||||
a_icon = "✔" if a.get("passed") else "✘"
|
||||
print(f" {a_icon} {a.get('message','')} "
|
||||
f"(expected={a.get('expected')}, actual={a.get('actual')})")
|
||||
if r.stderr:
|
||||
print(f" stderr: {r.stderr[:300]}")
|
||||
|
||||
def print_summary(self, results: list[TestResult]):
|
||||
total = len(results)
|
||||
passed = sum(1 for r in results if r.status == "PASS")
|
||||
failed = sum(1 for r in results if r.status == "FAIL")
|
||||
errors = sum(1 for r in results if r.status in ("ERROR", "TIMEOUT"))
|
||||
print(f"\n{'═'*62}")
|
||||
print(f" TEST SUMMARY")
|
||||
print(f"{'═'*62}")
|
||||
print(f" Total : {total}")
|
||||
print(f" ✅ PASS : {passed}")
|
||||
print(f" ❌ FAIL : {failed}")
|
||||
print(f" ⚠️ ERROR : {errors}")
|
||||
print(f" Pass Rate: {passed/total*100:.1f}%" if total else " Pass Rate: N/A")
|
||||
print(f"{'═'*62}\n")
|
||||
|
||||
def save_results(self, results: list[TestResult], path: str):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump([{
|
||||
"test_id": r.test_id, "status": r.status,
|
||||
"message": r.message, "duration": r.duration,
|
||||
"stdout": r.stdout, "stderr": r.stderr,
|
||||
"step_results": [
|
||||
{"step_no": sr.step_no, "interface_name": sr.interface_name,
|
||||
"status": sr.status, "assertions": sr.assertions}
|
||||
for sr in r.step_results
|
||||
],
|
||||
} for r in results], f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"Run results → {path}")
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
{
|
||||
"project": "project name",
|
||||
"description": "this is my first project",
|
||||
"units": [
|
||||
{
|
||||
"name": "create_user",
|
||||
"requirement_id": "REQ.01",
|
||||
"description": "Creates a new user account with provided credentials and information.",
|
||||
"type": "function",
|
||||
"url": "create_user.py",
|
||||
"parameters": {
|
||||
"username": {
|
||||
"type": "string",
|
||||
"inout": "in",
|
||||
"description": "The desired username for the new account.",
|
||||
"required": true
|
||||
},
|
||||
"password": {
|
||||
"type": "string",
|
||||
"inout": "in",
|
||||
"description": "The password for the new account, meeting security requirements.",
|
||||
"required": true
|
||||
},
|
||||
"email": {
|
||||
"type": "string",
|
||||
"inout": "in",
|
||||
"description": "The email address associated with the new account.",
|
||||
"required": true
|
||||
},
|
||||
"phone_number": {
|
||||
"type": "string",
|
||||
"inout": "in",
|
||||
"description": "The phone number associated with the new account.",
|
||||
"required": false
|
||||
},
|
||||
"additional_info": {
|
||||
"type": "dict",
|
||||
"inout": "in",
|
||||
"description": "Optional additional information about the user.",
|
||||
"required": false
|
||||
}
|
||||
},
|
||||
"return": {
|
||||
"type": "dict",
|
||||
"description": "Returns the created user object on success or an error message on failure.",
|
||||
"on_success": {
|
||||
"value": "User object containing user details such as id, username, email, etc.",
|
||||
"description": "The user object representing the newly created account."
|
||||
},
|
||||
"on_failure": {
|
||||
"value": "Error message or raises an exception.",
|
||||
"description": "Details about why the user creation failed, such as validation errors or duplicate username."
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "/delete_user",
|
||||
"requirement_id": "REQ.02",
|
||||
"description": "Deletes a user by user ID or username from the system.",
|
||||
"type": "http",
|
||||
"url": "http://127.0.0.1/api",
|
||||
"method": "post",
|
||||
"parameters": {
|
||||
"user_id": {
|
||||
"type": "integer",
|
||||
"inout": "in",
|
||||
"description": "Unique identifier of the user to be deleted.",
|
||||
"required": false
|
||||
},
|
||||
"username": {
|
||||
"type": "string",
|
||||
"inout": "in",
|
||||
"description": "Username of the user to be deleted.",
|
||||
"required": false
|
||||
}
|
||||
},
|
||||
"return": {
|
||||
"type": "boolean",
|
||||
"description": "Indicates whether the user was successfully deleted.",
|
||||
"on_success": {
|
||||
"value": true,
|
||||
"description": "The user was successfully deleted."
|
||||
},
|
||||
"on_failure": {
|
||||
"value": false,
|
||||
"description": "The user could not be deleted, or the user does not exist."
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
{
|
||||
"project": "project name",
|
||||
"description": "this is my first project",
|
||||
"units": [
|
||||
{
|
||||
"name": "create_user",
|
||||
"requirement_id": "REQ.01",
|
||||
"description": "Creates a new user account with provided credentials and information.",
|
||||
"type": "function",
|
||||
"url": "/Users/sontolau/Workspace/AIDeveloper/requirements_generator/output/my_first_project/create_user.py",
|
||||
"parameters": {
|
||||
"username": {
|
||||
"type": "string",
|
||||
"inout": "in",
|
||||
"description": "The desired username for the new account.",
|
||||
"required": true
|
||||
},
|
||||
"password": {
|
||||
"type": "string",
|
||||
"inout": "in",
|
||||
"description": "The password for the new account, meeting security requirements.",
|
||||
"required": true
|
||||
},
|
||||
"email": {
|
||||
"type": "string",
|
||||
"inout": "in",
|
||||
"description": "The email address associated with the new account.",
|
||||
"required": true
|
||||
},
|
||||
"phone_number": {
|
||||
"type": "string",
|
||||
"inout": "in",
|
||||
"description": "The phone number associated with the new account.",
|
||||
"required": false
|
||||
},
|
||||
"additional_info": {
|
||||
"type": "dict",
|
||||
"inout": "in",
|
||||
"description": "Optional additional information about the user.",
|
||||
"required": false
|
||||
}
|
||||
},
|
||||
"return": {
|
||||
"type": "dict",
|
||||
"description": "Returns the created user object on success or an error message on failure.",
|
||||
"on_success": {
|
||||
"value": "User object containing user details such as id, username, email, etc.",
|
||||
"description": "The user object representing the newly created account."
|
||||
},
|
||||
"on_failure": {
|
||||
"value": "Error message or raises an exception.",
|
||||
"description": "Details about why the user creation failed, such as validation errors or duplicate username."
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "/delete_user",
|
||||
"requirement_id": "REQ.02",
|
||||
"description": "Deletes a user by user ID or username from the system.",
|
||||
"type": "http",
|
||||
"url": "http://127.0.0.1/api",
|
||||
"method": "post",
|
||||
"parameters": {
|
||||
"user_id": {
|
||||
"type": "integer",
|
||||
"inout": "in",
|
||||
"description": "Unique identifier of the user to be deleted.",
|
||||
"required": false
|
||||
},
|
||||
"username": {
|
||||
"type": "string",
|
||||
"inout": "in",
|
||||
"description": "Username of the user to be deleted.",
|
||||
"required": false
|
||||
}
|
||||
},
|
||||
"return": {
|
||||
"type": "boolean",
|
||||
"description": "Indicates whether the user was successfully deleted.",
|
||||
"on_success": {
|
||||
"value": true,
|
||||
"description": "The user was successfully deleted."
|
||||
},
|
||||
"on_failure": {
|
||||
"value": false,
|
||||
"description": "The user could not be deleted, or the user does not exist."
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,232 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
AI-Powered API Test Generator, Runner & Coverage Analyzer
|
||||
──────────────────────────────────────────────────────────
|
||||
Usage:
|
||||
python main.py --api-desc examples/api_desc.json \\
|
||||
--requirements "创建用户,删除用户"
|
||||
|
||||
python main.py --api-desc examples/api_desc.json \\
|
||||
--req-file examples/requirements.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from config import config
|
||||
from core.parser import InterfaceParser, ApiDescriptor
|
||||
from core.prompt_builder import PromptBuilder
|
||||
from core.llm_client import LLMClient
|
||||
from core.test_generator import TestGenerator
|
||||
from core.test_runner import TestRunner
|
||||
from core.analyzer import CoverageAnalyzer, CoverageReport
|
||||
from core.report_generator import ReportGenerator
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── CLI ───────────────────────────────────────────────────────
|
||||
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(
|
||||
description="AI-Powered API Test Generator & Coverage Analyzer"
|
||||
)
|
||||
p.add_argument(
|
||||
"--api-desc", required=True,
|
||||
help="接口描述 JSON 文件(含 project / description / units 字段)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--requirements", default="",
|
||||
help="测试需求(逗号分隔)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--req-file", default="",
|
||||
help="测试需求文件(每行一条)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--skip-run", action="store_true",
|
||||
help="只生成测试文件,不执行",
|
||||
)
|
||||
p.add_argument(
|
||||
"--skip-analyze", action="store_true",
|
||||
help="跳过覆盖率分析",
|
||||
)
|
||||
p.add_argument(
|
||||
"--output-dir", default="",
|
||||
help="测试文件根输出目录(默认 config.GENERATED_TESTS_DIR)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--debug", action="store_true",
|
||||
help="开启 DEBUG 日志",
|
||||
)
|
||||
return p
|
||||
|
||||
|
||||
def load_requirements(args) -> list[str]:
|
||||
reqs: list[str] = []
|
||||
if args.requirements:
|
||||
reqs += [r.strip() for r in args.requirements.split(",") if r.strip()]
|
||||
if args.req_file:
|
||||
with open(args.req_file, "r", encoding="utf-8") as f:
|
||||
reqs += [line.strip() for line in f if line.strip()]
|
||||
if not reqs:
|
||||
logger.error(
|
||||
"No requirements provided. Use --requirements or --req-file."
|
||||
)
|
||||
sys.exit(1)
|
||||
return reqs
|
||||
|
||||
|
||||
# ── Main ──────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
args = build_arg_parser().parse_args()
|
||||
|
||||
if args.debug:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
if args.output_dir:
|
||||
config.GENERATED_TESTS_DIR = args.output_dir
|
||||
|
||||
# ── Step 1: 解析接口描述 ──────────────────────────────────
|
||||
logger.info("▶ Step 1: Parsing interface description …")
|
||||
parser: InterfaceParser = InterfaceParser()
|
||||
descriptor: ApiDescriptor = parser.parse_file(args.api_desc)
|
||||
|
||||
project = descriptor.project
|
||||
project_desc = descriptor.description
|
||||
interfaces = descriptor.interfaces
|
||||
|
||||
logger.info(f" Project : {project}")
|
||||
logger.info(f" Description: {project_desc}")
|
||||
logger.info(
|
||||
f" Interfaces : {len(interfaces)} — "
|
||||
f"{[i.name for i in interfaces]}"
|
||||
)
|
||||
# 打印每个接口的 url 解析结果
|
||||
for iface in interfaces:
|
||||
if iface.protocol == "function":
|
||||
logger.info(
|
||||
f" [{iface.name}] source={iface.source_file} "
|
||||
f"module={iface.module_path}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f" [{iface.name}] full_url={iface.http_full_url}"
|
||||
)
|
||||
|
||||
# ── Step 2: 加载测试需求 ──────────────────────────────────
|
||||
logger.info("▶ Step 2: Loading test requirements …")
|
||||
requirements = load_requirements(args)
|
||||
for i, req in enumerate(requirements, 1):
|
||||
logger.info(f" {i}. {req}")
|
||||
|
||||
# ── Step 3: 调用 LLM 生成测试用例 ────────────────────────
|
||||
logger.info("▶ Step 3: Calling LLM …")
|
||||
builder = PromptBuilder()
|
||||
test_cases = LLMClient().generate_test_cases(
|
||||
builder.get_system_prompt(),
|
||||
builder.build_user_prompt(
|
||||
requirements, interfaces, parser,
|
||||
project=project,
|
||||
project_desc=project_desc,
|
||||
),
|
||||
)
|
||||
logger.info(f" LLM returned {len(test_cases)} test case(s)")
|
||||
|
||||
# ── Step 4: 生成测试文件 ──────────────────────────────────
|
||||
logger.info(
|
||||
f"▶ Step 4: Generating test files (project='{project}') …"
|
||||
)
|
||||
generator = TestGenerator(project=project, project_desc=project_desc)
|
||||
test_files = generator.generate(test_cases)
|
||||
out = generator.output_dir
|
||||
|
||||
run_results = []
|
||||
|
||||
if not args.skip_run:
|
||||
# ── Step 5: 执行测试 ──────────────────────────────────
|
||||
logger.info("▶ Step 5: Running tests …")
|
||||
runner = TestRunner()
|
||||
run_results = runner.run_all(test_files)
|
||||
runner.print_summary(run_results)
|
||||
runner.save_results(run_results, str(out / "run_results.json"))
|
||||
|
||||
if not args.skip_analyze:
|
||||
# ── Step 6: 覆盖率分析 ────────────────────────────────
|
||||
logger.info("▶ Step 6: Analyzing coverage …")
|
||||
report = CoverageAnalyzer(
|
||||
interfaces=interfaces,
|
||||
requirements=requirements,
|
||||
test_cases=test_cases,
|
||||
run_results=run_results,
|
||||
).analyze()
|
||||
|
||||
# ── Step 7: 生成报告 ──────────────────────────────────
|
||||
logger.info("▶ Step 7: Generating reports …")
|
||||
rg = ReportGenerator()
|
||||
rg.save_json(report, str(out / "coverage_report.json"))
|
||||
rg.save_html(report, str(out / "coverage_report.html"))
|
||||
|
||||
_print_terminal_summary(report, out, project)
|
||||
|
||||
logger.info(f"\n✅ Done. Output: {out.resolve()}")
|
||||
|
||||
|
||||
# ── 终端摘要 ──────────────────────────────────────────────────
|
||||
|
||||
def _print_terminal_summary(
|
||||
report: CoverageReport, out: Path, project: str
|
||||
):
|
||||
W = 66
|
||||
|
||||
def bar(rate: float, w: int = 20) -> str:
|
||||
filled = int(rate * w)
|
||||
empty = w - filled
|
||||
icon = "✅" if rate >= 0.8 else ("⚠️ " if rate >= 0.5 else "❌")
|
||||
return f"{'█' * filled}{'░' * empty} {rate * 100:.1f}% {icon}"
|
||||
|
||||
print(f"\n{'═' * W}")
|
||||
print(f" PROJECT : {project}")
|
||||
print(f" COVERAGE SUMMARY")
|
||||
print(f"{'═' * W}")
|
||||
print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}")
|
||||
print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}")
|
||||
print(f" 入参覆盖率 {bar(report.avg_in_param_coverage_rate)}")
|
||||
print(f" 成功返回字段覆盖 {bar(report.avg_success_field_coverage_rate)}")
|
||||
print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}")
|
||||
print(f" 用例通过率 {bar(report.pass_rate)}")
|
||||
print(f"{'─' * W}")
|
||||
print(f" Total Gaps : {len(report.gaps)}")
|
||||
print(f" 🔴 Critical: {report.critical_gap_count}")
|
||||
print(f" 🟠 High : {report.high_gap_count}")
|
||||
|
||||
if report.gaps:
|
||||
print(f"{'─' * W}")
|
||||
print(" Top Gaps (up to 8):")
|
||||
icons = {
|
||||
"critical": "🔴", "high": "🟠",
|
||||
"medium": "🟡", "low": "🔵",
|
||||
}
|
||||
for g in report.gaps[:8]:
|
||||
icon = icons.get(g.severity, "⚪")
|
||||
print(f" {icon} [{g.gap_type}] {g.target}")
|
||||
print(f" → {g.suggestion}")
|
||||
|
||||
print(f"{'─' * W}")
|
||||
print(f" Output : {out.resolve()}")
|
||||
print(f" • coverage_report.html ← open in browser")
|
||||
print(f" • coverage_report.json")
|
||||
print(f" • run_results.json")
|
||||
print(f" • test_cases_summary.json")
|
||||
print(f"{'═' * W}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
openai>=1.30.0
|
||||
requests>=2.31.0
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 设置环境变量
|
||||
export LLM_API_KEY="sk-AUmOuFI731Ty5Nob38jY26d8lydfDT-QkE2giqb0sCuPCAE2JH6zjLM4lZLpvL5WMYPOocaMe2FwVDmqM_9KimmKACjR"
|
||||
export LLM_BASE_URL="https://openapi.monica.im/v1" # 或其他兼容接口
|
||||
export LLM_MODEL="gpt-4o"
|
||||
export HTTP_BASE_URL="http://localhost:8080"
|
||||
|
||||
# 运行(生成 + 执行)
|
||||
# python main.py \
|
||||
# --api-desc examples/api_desc.json \
|
||||
# --requirements "创建用户,自动生成测试数据,更改指定用户的密码,自动生成测试数据"
|
||||
|
||||
# 只生成不执行
|
||||
python main.py --api-desc examples/function_signatures.json \
|
||||
--requirements "1.对每个接口进行测试,支持特殊值、边界值测试"
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
你现在是一名高级ai软件工程师,非常擅长通过llm模型解决软件应用需求,现在需要构建python工程利用ai实现基于api的自动化测试代码生成与执行,支持将用户每个测试需求分解成对1个或多个接口的调用,并基于接口返回描述对接口返回值进行测试,设计如下:
|
||||
# 工作流程
|
||||
1. 用户输入测试需求、接口描述文件(文件采用json格式,包含接口名称、请求参数描述、返回参数描述,接口描述等信息)。例如:
|
||||
## 接口描述示例
|
||||
[
|
||||
{
|
||||
"name": "/api/create_user",
|
||||
"description": "to create new user account in the system",
|
||||
"protocol": "http",
|
||||
"base_url": "http://123.0.0.1:8000"
|
||||
"method": "post",
|
||||
"url_parameters": {
|
||||
"version": {
|
||||
"type": "integer",
|
||||
"description": "the api's version",
|
||||
"required": "false"
|
||||
}
|
||||
},
|
||||
"request_parameters": {
|
||||
"username": { "type": "string", "description": "user name", "required": true },
|
||||
"email": { "type": "string", "description": "the user's email", "required": false },
|
||||
"password": { "type": "string", "description": "the user's password", "required": true }
|
||||
},
|
||||
"response": {
|
||||
"code": {
|
||||
"type": "integer",
|
||||
"description": "0 is successful, nonzero is failure"
|
||||
},
|
||||
"user_id": {
|
||||
"type": "integer",
|
||||
"description": "the created user's id "
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "change_password",
|
||||
"description": "replace old password with new password",
|
||||
"protocol": "function",
|
||||
"parameters": {
|
||||
"user_id": { "type": "integer", "inout","in", "description": "the user's id", "required": true },
|
||||
"old_password": { "type": "string", "inout","in", "description": "the old password", "required": true },
|
||||
"new_password": { "type": "string", "inout","in", "description": "the user's new password", "required": true }
|
||||
},
|
||||
"return": {
|
||||
"type": "integer",
|
||||
"description": "0 is successful, nonzero is failure"
|
||||
}
|
||||
},
|
||||
...
|
||||
]
|
||||
## 测试需求示例
|
||||
(1)创建用户,自动生成测试数据
|
||||
(2)更改指定用户的密码,自动生成测试数据;支持先创建用户,再改变该用户的密码;
|
||||
|
||||
2. 软件将接口信息、指令、输出要求等通过提示词传递给llm模型。
|
||||
3. llm模型结合用户测试需求,并根据接口列表生成针对每个测试需求的测试用例(包含生成的测试数据、代码等),测试用例支持基于测试逻辑对多个接口的调用与返回值测试
|
||||
4. 输出的测试用例以json格式返回。
|
||||
5. 软件解析json数据并生成测试用例文件
|
||||
6. 软件运行测试用例文件并输出每个测试用例的结果;
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
#!/usr/bin/env python3
|
||||
import requests
|
||||
import json
|
||||
|
||||
session_id="123456789"
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self, url):
|
||||
self.url = url
|
||||
self.session = requests.Session()
|
||||
self.session_id = None
|
||||
|
||||
def initialize(self):
|
||||
"""初始化 MCP 会话"""
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "python-client",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response = self.session.post(
|
||||
self.url,
|
||||
json=payload,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
}
|
||||
)
|
||||
self.session_id = response.headers['Mcp-Session-Id']
|
||||
|
||||
|
||||
def call_method(self, method, params=None, request_id=None):
|
||||
"""调用 MCP 方法"""
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id or method,
|
||||
"method": method
|
||||
}
|
||||
|
||||
if params:
|
||||
payload["params"] = params
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream"
|
||||
}
|
||||
|
||||
# 如果有 session ID,添加到头部
|
||||
if self.session_id:
|
||||
headers["Mcp-Session-Id"] = self.session_id
|
||||
|
||||
response = self.session.post(
|
||||
self.url,
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
print(response.text)
|
||||
return response.json()
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
client = MCPClient("http://iot.gesukj.com:7999/mcp")
|
||||
|
||||
# 初始化
|
||||
print("Initializing...")
|
||||
init_result = client.initialize()
|
||||
# print(json.dumps(init_result, indent=2))
|
||||
|
||||
# 列出工具
|
||||
print("\nListing tools...")
|
||||
tools_result = client.call_method("tools/list", request_id=4)
|
||||
print(json.dumps(tools_result, indent=2))
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
# config.py - 全局配置管理
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# ── LLM 配置 ──────────────────────────────────────────
|
||||
LLM_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
|
||||
LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0.3"))
|
||||
|
||||
# ── 数据库配置 ─────────────────────────────────────────
|
||||
DB_PATH = os.getenv("DB_PATH", "data/requirement_analyzer.db")
|
||||
|
||||
# ── 输出配置 ───────────────────────────────────────────
|
||||
OUTPUT_BASE_DIR = os.getenv("OUTPUT_BASE_DIR", "output")
|
||||
DEFAULT_LANGUAGE = os.getenv("DEFAULT_LANGUAGE", "python")
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Prompt 模板
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
DECOMPOSE_PROMPT_TEMPLATE = """
|
||||
你是一位资深软件架构师和产品经理。请根据以下信息,将原始需求分解为若干个可独立实现的功能需求。
|
||||
|
||||
{knowledge_section}
|
||||
|
||||
## 原始需求
|
||||
{raw_requirement}
|
||||
|
||||
## 输出要求
|
||||
请严格按照以下 JSON 格式输出,不要包含任何额外说明:
|
||||
{{
|
||||
"functional_requirements": [
|
||||
{{
|
||||
"index": 1,
|
||||
"title": "功能需求标题(简洁,10字以内)",
|
||||
"description": "功能需求详细描述(包含输入、处理逻辑、输出)",
|
||||
"function_name": "snake_case函数名",
|
||||
"priority": "high|medium|low"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
要求:
|
||||
1. 每个功能需求必须是独立可实现的最小单元
|
||||
2. function_name 使用 snake_case 命名,清晰表达函数用途
|
||||
3. 分解粒度适中,通常 5-15 个功能需求
|
||||
4. 优先级根据业务重要性判断
|
||||
"""
|
||||
|
||||
# ── 函数签名 JSON 生成 Prompt ──────────────────────────
|
||||
FUNC_SIGNATURE_PROMPT_TEMPLATE = """
|
||||
你是一位资深软件架构师。请根据以下功能需求描述,设计该函数的完整接口签名,并以 JSON 格式输出。
|
||||
|
||||
{knowledge_section}
|
||||
|
||||
## 功能需求
|
||||
需求编号:{requirement_id}
|
||||
标题:{title}
|
||||
函数名:{function_name}
|
||||
详细描述:{description}
|
||||
|
||||
## 输出格式
|
||||
请严格按照以下 JSON 结构输出,不要包含任何额外说明或 markdown 标记:
|
||||
{{
|
||||
"name": "{function_name}",
|
||||
"requirement_id": "{requirement_id}",
|
||||
"description": "简洁的一句话功能描述(英文)",
|
||||
"type": "function",
|
||||
"parameters": {{
|
||||
"<param_name>": {{
|
||||
"type": "integer|string|boolean|float|list|dict|object",
|
||||
"inout": "in|out|inout",
|
||||
"description": "参数说明(英文)",
|
||||
"required": true
|
||||
}}
|
||||
}},
|
||||
"return": {{
|
||||
"type": "integer|string|boolean|float|list|dict|object|void",
|
||||
"description": "整体返回值说明(英文,一句话概括)",
|
||||
"on_success": {{
|
||||
"value": "具体成功返回值或范围,如 0、true、user object、list of items 等",
|
||||
"description": "成功时的返回值含义(英文)"
|
||||
}},
|
||||
"on_failure": {{
|
||||
"value": "具体失败返回值或范围,如 nonzero、false、null、empty list、raises Exception 等",
|
||||
"description": "失败时的返回值含义,或抛出的异常类型(英文)"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
|
||||
## 设计规范
|
||||
1. 参数名使用 snake_case,类型使用通用类型(不绑定具体语言)
|
||||
2. inout 字段含义:
|
||||
- in = 仅输入参数
|
||||
- out = 仅输出参数(通过参数传出结果,如指针/引用)
|
||||
- inout = 既作输入又作输出
|
||||
3. 所有描述字段使用英文
|
||||
4. return 字段规则:
|
||||
- 若函数无返回值(void),type 填 "void",on_success/on_failure 均填 null
|
||||
- 若返回值只有成功场景(如纯查询),on_failure 可描述为 "null or empty"
|
||||
- on_success.value / on_failure.value 填写具体值或值域描述,不要填写空字符串
|
||||
5. 若函数无参数,parameters 填 {{}}
|
||||
6. required 字段为布尔值 true 或 false
|
||||
"""
|
||||
|
||||
# ── 代码生成 Prompt(含签名约束)─────────────────────────
|
||||
CODE_GEN_PROMPT_TEMPLATE = """
|
||||
你是一位资深 {language} 工程师。请根据以下功能需求和【函数签名规范】,生成完整的 {language} 函数代码。
|
||||
|
||||
{knowledge_section}
|
||||
|
||||
## 功能需求
|
||||
标题:{title}
|
||||
描述:{description}
|
||||
|
||||
## 【必须严格遵守】函数签名规范
|
||||
以下 JSON 定义了函数的精确接口,生成的代码必须与之完全一致,不得擅自增减或改名参数:
|
||||
|
||||
```json
|
||||
{signature_json}
|
||||
```
|
||||
|
||||
### 签名字段说明
|
||||
- `name`:函数名,必须完全一致
|
||||
- `parameters`:每个 key 即为参数名,`type` 为数据类型,`inout` 含义:
|
||||
- `in` = 普通输入参数
|
||||
- `out` = 输出参数(Python 中通过返回值或可变容器传出)
|
||||
- `inout` = 既作输入又作输出
|
||||
- `return.type`:返回值类型
|
||||
- `return.on_success`:成功时的返回值,代码实现必须与此一致
|
||||
- `return.on_failure`:失败时的返回值或异常,代码实现必须与此一致
|
||||
|
||||
## 输出要求
|
||||
1. 只输出纯代码,不要包含 markdown 代码块标记
|
||||
2. 函数签名(名称、参数列表、返回类型)必须与上方 JSON 规范完全一致
|
||||
3. 成功/失败的返回值必须严格遵守 return.on_success / return.on_failure 的定义
|
||||
4. 包含完整的类型注解(Python 使用 type hints)
|
||||
5. 包含详细的 docstring,其中 Returns 段须注明成功值与失败值
|
||||
6. 包含必要的异常处理
|
||||
7. 代码风格遵循 PEP8(Python)或对应语言规范
|
||||
8. 在文件顶部用注释注明:需求编号、功能标题、函数签名摘要
|
||||
9. 如需导入第三方库,请在顶部统一导入
|
||||
"""
|
||||
|
|
@ -0,0 +1,216 @@
|
|||
# core/code_generator.py - 代码生成核心逻辑(签名约束版)
|
||||
import json
|
||||
from typing import Optional, List, Callable
|
||||
|
||||
import config
|
||||
from core.llm_client import LLMClient
|
||||
from database.models import FunctionalRequirement, CodeFile
|
||||
from utils.output_writer import write_code_file, get_file_extension
|
||||
|
||||
|
||||
class CodeGenerator:
|
||||
"""
|
||||
根据功能需求 + 函数签名约束,使用 LLM 生成代码函数文件。
|
||||
签名由 RequirementAnalyzer.build_function_signature() 预先生成,
|
||||
注入 Prompt 后可确保代码参数列表与签名 JSON 完全一致。
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
"""
|
||||
初始化代码生成器
|
||||
|
||||
Args:
|
||||
llm_client: LLM 客户端实例,为 None 时自动创建
|
||||
"""
|
||||
self.llm = llm_client or LLMClient()
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 单个生成
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def generate(
|
||||
self,
|
||||
func_req: FunctionalRequirement,
|
||||
output_dir: str,
|
||||
language: str = config.DEFAULT_LANGUAGE,
|
||||
knowledge: str = "",
|
||||
signature: Optional[dict] = None,
|
||||
) -> CodeFile:
|
||||
"""
|
||||
为单个功能需求生成代码文件
|
||||
|
||||
Args:
|
||||
func_req: 功能需求对象(必须含有效 id)
|
||||
output_dir: 代码输出目录
|
||||
language: 目标编程语言
|
||||
knowledge: 知识库文本(可选)
|
||||
signature: 函数签名 dict(由 RequirementAnalyzer 生成)。
|
||||
传入后将作为强约束注入 Prompt,确保代码参数
|
||||
与签名 JSON 完全一致;为 None 时退化为无约束模式。
|
||||
|
||||
Returns:
|
||||
CodeFile 对象(含生成的代码内容和文件路径,未持久化)
|
||||
|
||||
Raises:
|
||||
ValueError: func_req.id 为 None
|
||||
Exception: LLM 调用失败或文件写入失败
|
||||
"""
|
||||
if func_req.id is None:
|
||||
raise ValueError("FunctionalRequirement 必须先持久化(id 不能为 None)")
|
||||
|
||||
knowledge_section = self._build_knowledge_section(knowledge)
|
||||
signature_json = self._build_signature_json(signature, func_req)
|
||||
|
||||
prompt = config.CODE_GEN_PROMPT_TEMPLATE.format(
|
||||
language=language,
|
||||
knowledge_section=knowledge_section,
|
||||
title=func_req.title,
|
||||
description=func_req.description,
|
||||
signature_json=signature_json,
|
||||
)
|
||||
|
||||
code_content = self.llm.chat(
|
||||
system_prompt=(
|
||||
f"你是一位资深 {language} 工程师,只输出纯代码,"
|
||||
"不添加任何 markdown 标记。函数签名必须与提供的 JSON 规范完全一致。"
|
||||
),
|
||||
user_prompt=prompt,
|
||||
)
|
||||
|
||||
file_path = write_code_file(
|
||||
output_dir=output_dir,
|
||||
function_name=func_req.function_name,
|
||||
language=language,
|
||||
content=code_content,
|
||||
)
|
||||
|
||||
ext = get_file_extension(language)
|
||||
file_name = f"{func_req.function_name}{ext}"
|
||||
|
||||
return CodeFile(
|
||||
project_id=func_req.project_id,
|
||||
func_req_id=func_req.id,
|
||||
file_name=file_name,
|
||||
file_path=file_path,
|
||||
language=language,
|
||||
content=code_content,
|
||||
)
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 批量生成
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def generate_batch(
|
||||
self,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
output_dir: str,
|
||||
language: str = config.DEFAULT_LANGUAGE,
|
||||
knowledge: str = "",
|
||||
signatures: Optional[List[dict]] = None,
|
||||
on_progress: Optional[Callable] = None,
|
||||
) -> List[CodeFile]:
|
||||
"""
|
||||
批量生成代码文件
|
||||
|
||||
Args:
|
||||
func_reqs: 功能需求列表
|
||||
output_dir: 输出目录
|
||||
language: 目标语言
|
||||
knowledge: 知识库文本
|
||||
signatures: 与 func_reqs 等长的签名列表(索引对应)。
|
||||
为 None 时所有条目均以无约束模式生成。
|
||||
on_progress: 进度回调 fn(index, total, func_req, code_file, error)
|
||||
|
||||
Returns:
|
||||
成功生成的 CodeFile 列表
|
||||
"""
|
||||
results = []
|
||||
total = len(func_reqs)
|
||||
|
||||
# 构建 func_req.id → signature 的快速查找表
|
||||
sig_map = self._build_signature_map(func_reqs, signatures)
|
||||
|
||||
for i, req in enumerate(func_reqs):
|
||||
sig = sig_map.get(req.id)
|
||||
try:
|
||||
code_file = self.generate(
|
||||
func_req=req,
|
||||
output_dir=output_dir,
|
||||
language=language,
|
||||
knowledge=knowledge,
|
||||
signature=sig,
|
||||
)
|
||||
results.append(code_file)
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, req, code_file, None)
|
||||
except Exception as e:
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, req, None, e)
|
||||
|
||||
return results
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 私有工具方法
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
def _build_knowledge_section(knowledge: str) -> str:
|
||||
"""构建知识库 Prompt 段落"""
|
||||
if not knowledge or not knowledge.strip():
|
||||
return ""
|
||||
return (
|
||||
"## 参考知识库(实现时请遵循以下规范)\n"
|
||||
f"{knowledge}\n\n---\n"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_signature_json(
|
||||
signature: Optional[dict],
|
||||
func_req: FunctionalRequirement,
|
||||
) -> str:
|
||||
"""
|
||||
将签名 dict 序列化为格式化 JSON 字符串;
|
||||
若签名为 None,则构造最小占位签名,保持 Prompt 结构完整。
|
||||
|
||||
Args:
|
||||
signature: 签名 dict 或 None
|
||||
func_req: 对应的功能需求(用于占位签名)
|
||||
|
||||
Returns:
|
||||
JSON 字符串
|
||||
"""
|
||||
if signature:
|
||||
return json.dumps(signature, ensure_ascii=False, indent=2)
|
||||
# 无签名时的最小占位,提示 LLM 自行设计但保持格式
|
||||
fallback = {
|
||||
"name": func_req.function_name,
|
||||
"requirement_id": f"REQ.{func_req.index_no:02d}",
|
||||
"description": func_req.description,
|
||||
"type": "function",
|
||||
"parameters": "<<请根据功能描述自行设计参数>>",
|
||||
"return": "<<请根据功能描述自行设计返回值>>",
|
||||
}
|
||||
return json.dumps(fallback, ensure_ascii=False, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def _build_signature_map(
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
signatures: Optional[List[dict]],
|
||||
) -> dict:
|
||||
"""
|
||||
构建 func_req.id → signature 映射表
|
||||
|
||||
Args:
|
||||
func_reqs: 功能需求列表
|
||||
signatures: 与 func_reqs 等长的签名列表,或 None
|
||||
|
||||
Returns:
|
||||
{req_id: signature_dict} 字典
|
||||
"""
|
||||
if not signatures:
|
||||
return {}
|
||||
sig_map = {}
|
||||
for req, sig in zip(func_reqs, signatures):
|
||||
if req.id is not None and sig:
|
||||
sig_map[req.id] = sig
|
||||
return sig_map
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
# core/llm_client.py - LLM 客户端封装
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import config
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""
|
||||
OpenAI 兼容 LLM 客户端封装。
|
||||
支持任何兼容 OpenAI API 格式的服务(OpenAI / Azure / 本地模型等)。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = config.LLM_API_KEY,
|
||||
base_url: str = config.LLM_BASE_URL,
|
||||
model: str = config.LLM_MODEL,
|
||||
temperature: float = config.LLM_TEMPERATURE,
|
||||
):
|
||||
"""
|
||||
初始化 LLM 客户端
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
base_url: API 基础 URL
|
||||
model: 模型名称
|
||||
temperature: 生成温度(0~1,越低越确定)
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 openai 库
|
||||
ValueError: api_key 为空
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI
|
||||
except ImportError:
|
||||
raise ImportError("请安装 openai: pip install openai")
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置,请在 .env 文件中设置")
|
||||
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self._client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
def chat(self, system_prompt: str, user_prompt: str) -> str:
|
||||
"""
|
||||
发送对话请求,返回模型回复文本
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户输入
|
||||
|
||||
Returns:
|
||||
模型回复的文本内容
|
||||
|
||||
Raises:
|
||||
Exception: API 调用失败
|
||||
"""
|
||||
response = self._client.chat.completions.create(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
)
|
||||
return response.choices[0].message.content.strip()
|
||||
|
||||
def chat_json(self, system_prompt: str, user_prompt: str) -> dict:
|
||||
"""
|
||||
发送对话请求,解析并返回 JSON 结果
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户输入
|
||||
|
||||
Returns:
|
||||
解析后的 dict 对象
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: 模型返回非合法 JSON
|
||||
"""
|
||||
raw = self.chat(system_prompt, user_prompt)
|
||||
# 去除可能的 markdown 代码块包裹
|
||||
raw = raw.strip()
|
||||
if raw.startswith("```"):
|
||||
lines = raw.split("\n")
|
||||
raw = "\n".join(lines[1:-1])
|
||||
return json.loads(raw)
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
# core/requirement_analyzer.py - 需求分解 & 函数签名生成
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
import config
|
||||
from core.llm_client import LLMClient
|
||||
from database.models import FunctionalRequirement
|
||||
|
||||
|
||||
class RequirementAnalyzer:
|
||||
"""
|
||||
使用 LLM 将原始需求分解为功能需求列表,并生成函数接口签名。
|
||||
支持注入知识库上下文以提升分解质量。
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
"""
|
||||
初始化需求分析器
|
||||
|
||||
Args:
|
||||
llm_client: LLM 客户端实例,为 None 时自动创建
|
||||
"""
|
||||
self.llm = llm_client or LLMClient()
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 需求分解
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def decompose(
|
||||
self,
|
||||
raw_requirement: str,
|
||||
project_id: int,
|
||||
raw_req_id: int,
|
||||
knowledge: str = "",
|
||||
) -> List[FunctionalRequirement]:
|
||||
"""
|
||||
将原始需求分解为功能需求列表
|
||||
|
||||
Args:
|
||||
raw_requirement: 原始需求文本
|
||||
project_id: 所属项目 ID
|
||||
raw_req_id: 原始需求记录 ID
|
||||
knowledge: 知识库文本(可选)
|
||||
|
||||
Returns:
|
||||
FunctionalRequirement 对象列表(未持久化,id=None)
|
||||
|
||||
Raises:
|
||||
ValueError: LLM 返回格式不合法
|
||||
json.JSONDecodeError: JSON 解析失败
|
||||
"""
|
||||
knowledge_section = self._build_knowledge_section(knowledge)
|
||||
prompt = config.DECOMPOSE_PROMPT_TEMPLATE.format(
|
||||
knowledge_section=knowledge_section,
|
||||
raw_requirement=raw_requirement,
|
||||
)
|
||||
|
||||
result = self.llm.chat_json(
|
||||
system_prompt="你是一位资深软件架构师,擅长需求分析与系统设计。",
|
||||
user_prompt=prompt,
|
||||
)
|
||||
|
||||
items = result.get("functional_requirements", [])
|
||||
if not items:
|
||||
raise ValueError("LLM 未返回任何功能需求,请检查原始需求描述")
|
||||
|
||||
requirements = []
|
||||
for item in items:
|
||||
req = FunctionalRequirement(
|
||||
project_id=project_id,
|
||||
raw_req_id=raw_req_id,
|
||||
index_no=int(item.get("index", len(requirements) + 1)),
|
||||
title=item.get("title", "未命名功能"),
|
||||
description=item.get("description", ""),
|
||||
function_name=self._sanitize_function_name(
|
||||
item.get("function_name", f"func_{len(requirements)+1}")
|
||||
),
|
||||
priority=item.get("priority", "medium"),
|
||||
)
|
||||
requirements.append(req)
|
||||
|
||||
return requirements
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 函数签名生成(新增)
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def build_function_signature(
|
||||
self,
|
||||
func_req: FunctionalRequirement,
|
||||
knowledge: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
为单个功能需求生成函数接口签名 JSON
|
||||
|
||||
Args:
|
||||
func_req: 功能需求对象(需含有效 id)
|
||||
knowledge: 知识库文本(可选)
|
||||
|
||||
Returns:
|
||||
符合接口规范的 dict,包含 name/requirement_id/description/
|
||||
type/parameters/return 字段
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: LLM 返回非合法 JSON
|
||||
"""
|
||||
requirement_id = self._format_requirement_id(func_req.index_no)
|
||||
knowledge_section = self._build_knowledge_section(knowledge)
|
||||
|
||||
prompt = config.FUNC_SIGNATURE_PROMPT_TEMPLATE.format(
|
||||
knowledge_section=knowledge_section,
|
||||
requirement_id=requirement_id,
|
||||
title=func_req.title,
|
||||
function_name=func_req.function_name,
|
||||
description=func_req.description,
|
||||
)
|
||||
|
||||
signature = self.llm.chat_json(
|
||||
system_prompt=(
|
||||
"你是一位资深软件架构师,专注于 API 接口设计。"
|
||||
"只输出合法 JSON,不添加任何说明文字。"
|
||||
),
|
||||
user_prompt=prompt,
|
||||
)
|
||||
|
||||
# 确保关键字段存在,做兜底处理
|
||||
signature.setdefault("name", func_req.function_name)
|
||||
signature.setdefault("requirement_id", requirement_id)
|
||||
signature.setdefault("description", func_req.description)
|
||||
signature.setdefault("type", "function")
|
||||
signature.setdefault("parameters", {})
|
||||
signature.setdefault("return", {"type": "void", "description": ""})
|
||||
|
||||
return signature
|
||||
|
||||
def build_function_signatures_batch(
|
||||
self,
|
||||
func_reqs: List[FunctionalRequirement],
|
||||
knowledge: str = "",
|
||||
on_progress=None,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
批量为功能需求列表生成函数接口签名
|
||||
|
||||
Args:
|
||||
func_reqs: 功能需求列表
|
||||
knowledge: 知识库文本(可选)
|
||||
on_progress: 进度回调 fn(index, total, func_req, signature, error)
|
||||
|
||||
Returns:
|
||||
签名 dict 列表,顺序与 func_reqs 一致;
|
||||
生成失败的条目使用降级结构填充,不中断整体流程
|
||||
"""
|
||||
results = []
|
||||
total = len(func_reqs)
|
||||
|
||||
for i, req in enumerate(func_reqs):
|
||||
try:
|
||||
sig = self.build_function_signature(req, knowledge)
|
||||
results.append(sig)
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, req, sig, None)
|
||||
except Exception as e:
|
||||
# 降级:用基础信息填充,保证 JSON 完整性
|
||||
fallback = self._build_fallback_signature(req)
|
||||
results.append(fallback)
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, req, fallback, e)
|
||||
|
||||
return results
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# 私有工具方法
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
def _build_knowledge_section(knowledge: str) -> str:
|
||||
"""构建知识库 Prompt 段落"""
|
||||
if not knowledge or not knowledge.strip():
|
||||
return ""
|
||||
return f"""## 参考知识库
|
||||
{knowledge}
|
||||
|
||||
---
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_function_name(name: str) -> str:
|
||||
"""
|
||||
清理函数名,确保符合 snake_case 规范
|
||||
|
||||
Args:
|
||||
name: 原始函数名
|
||||
|
||||
Returns:
|
||||
合法的 snake_case 函数名
|
||||
"""
|
||||
name = re.sub(r"[^a-zA-Z0-9_]", "_", name).lower()
|
||||
name = re.sub(r"_+", "_", name).strip("_")
|
||||
if name and name[0].isdigit():
|
||||
name = "func_" + name
|
||||
return name or "unnamed_function"
|
||||
|
||||
@staticmethod
|
||||
def _format_requirement_id(index_no: int) -> str:
|
||||
"""
|
||||
将序号格式化为需求编号字符串
|
||||
|
||||
Args:
|
||||
index_no: 功能需求序号(从 1 开始)
|
||||
|
||||
Returns:
|
||||
格式化编号,如 'REQ.01'、'REQ.12'
|
||||
"""
|
||||
return f"REQ.{index_no:02d}"
|
||||
|
||||
@staticmethod
|
||||
def _build_fallback_signature(func_req: FunctionalRequirement) -> dict:
|
||||
"""
|
||||
构建降级签名(LLM 调用失败时使用)
|
||||
|
||||
Args:
|
||||
func_req: 功能需求对象
|
||||
|
||||
Returns:
|
||||
包含基础信息的签名 dict
|
||||
"""
|
||||
return {
|
||||
"name": func_req.function_name,
|
||||
"requirement_id": f"REQ.{func_req.index_no:02d}",
|
||||
"description": func_req.description,
|
||||
"type": "function",
|
||||
"parameters": {},
|
||||
"return": {
|
||||
"type": "void",
|
||||
"description": "TODO: define return value"
|
||||
},
|
||||
"_note": "Auto-generated fallback due to LLM error"
|
||||
}
|
||||
|
|
@ -0,0 +1,314 @@
|
|||
# database/db_manager.py - 数据库操作管理器
|
||||
import sqlite3
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from contextlib import contextmanager
|
||||
|
||||
from database.models import (
|
||||
CREATE_TABLES_SQL, Project, RawRequirement,
|
||||
FunctionalRequirement, CodeFile
|
||||
)
|
||||
import config
|
||||
|
||||
|
||||
class DBManager:
|
||||
"""SQLite 数据库管理器,封装所有 CRUD 操作"""
|
||||
|
||||
def __init__(self, db_path: str = config.DB_PATH):
|
||||
self.db_path = db_path
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
self._init_db()
|
||||
|
||||
# ── 连接上下文管理器 ──────────────────────────────────
|
||||
|
||||
@contextmanager
|
||||
def _get_conn(self):
|
||||
"""获取数据库连接(自动提交/回滚)"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _init_db(self):
|
||||
"""初始化数据库,创建所有表"""
|
||||
with self._get_conn() as conn:
|
||||
conn.executescript(CREATE_TABLES_SQL)
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# Project CRUD
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def create_project(self, project: Project) -> int:
|
||||
"""创建项目,返回新项目 ID"""
|
||||
sql = """
|
||||
INSERT INTO projects (name, description, language, output_dir, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
cur = conn.execute(sql, (
|
||||
project.name, project.description, project.language,
|
||||
project.output_dir, project.created_at, project.updated_at
|
||||
))
|
||||
return cur.lastrowid
|
||||
|
||||
def get_project_by_id(self, project_id: int) -> Optional[Project]:
|
||||
"""根据 ID 查询项目"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM projects WHERE id = ?", (project_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return Project(
|
||||
id=row["id"], name=row["name"], description=row["description"],
|
||||
language=row["language"], output_dir=row["output_dir"],
|
||||
created_at=row["created_at"], updated_at=row["updated_at"]
|
||||
)
|
||||
|
||||
def get_project_by_name(self, name: str) -> Optional[Project]:
|
||||
"""根据名称查询项目"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM projects WHERE name = ?", (name,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return Project(
|
||||
id=row["id"], name=row["name"], description=row["description"],
|
||||
language=row["language"], output_dir=row["output_dir"],
|
||||
created_at=row["created_at"], updated_at=row["updated_at"]
|
||||
)
|
||||
|
||||
def list_projects(self) -> List[Project]:
|
||||
"""列出所有项目"""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM projects ORDER BY created_at DESC"
|
||||
).fetchall()
|
||||
return [
|
||||
Project(
|
||||
id=r["id"], name=r["name"], description=r["description"],
|
||||
language=r["language"], output_dir=r["output_dir"],
|
||||
created_at=r["created_at"], updated_at=r["updated_at"]
|
||||
) for r in rows
|
||||
]
|
||||
|
||||
def update_project(self, project: Project) -> None:
|
||||
"""更新项目信息"""
|
||||
project.updated_at = datetime.now().isoformat()
|
||||
sql = """
|
||||
UPDATE projects
|
||||
SET name=?, description=?, language=?, output_dir=?, updated_at=?
|
||||
WHERE id=?
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute(sql, (
|
||||
project.name, project.description, project.language,
|
||||
project.output_dir, project.updated_at, project.id
|
||||
))
|
||||
|
||||
def delete_project(self, project_id: int) -> None:
|
||||
"""删除项目(级联删除所有关联数据)"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute("DELETE FROM projects WHERE id = ?", (project_id,))
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# RawRequirement CRUD
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def create_raw_requirement(self, req: RawRequirement) -> int:
|
||||
"""创建原始需求,返回新记录 ID"""
|
||||
sql = """
|
||||
INSERT INTO raw_requirements
|
||||
(project_id, content, source_type, source_name, knowledge, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
cur = conn.execute(sql, (
|
||||
req.project_id, req.content, req.source_type,
|
||||
req.source_name, req.knowledge, req.created_at
|
||||
))
|
||||
return cur.lastrowid
|
||||
|
||||
def get_raw_requirement(self, req_id: int) -> Optional[RawRequirement]:
|
||||
"""根据 ID 查询原始需求"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM raw_requirements WHERE id = ?", (req_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return RawRequirement(
|
||||
id=row["id"], project_id=row["project_id"], content=row["content"],
|
||||
source_type=row["source_type"], source_name=row["source_name"],
|
||||
knowledge=row["knowledge"], created_at=row["created_at"]
|
||||
)
|
||||
|
||||
def list_raw_requirements_by_project(self, project_id: int) -> List[RawRequirement]:
|
||||
"""查询项目下所有原始需求"""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM raw_requirements WHERE project_id = ? ORDER BY created_at",
|
||||
(project_id,)
|
||||
).fetchall()
|
||||
return [
|
||||
RawRequirement(
|
||||
id=r["id"], project_id=r["project_id"], content=r["content"],
|
||||
source_type=r["source_type"], source_name=r["source_name"],
|
||||
knowledge=r["knowledge"], created_at=r["created_at"]
|
||||
) for r in rows
|
||||
]
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# FunctionalRequirement CRUD
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def create_functional_requirement(self, req: FunctionalRequirement) -> int:
|
||||
"""创建功能需求,返回新记录 ID"""
|
||||
sql = """
|
||||
INSERT INTO functional_requirements
|
||||
(project_id, raw_req_id, index_no, title, description,
|
||||
function_name, priority, status, is_custom, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
cur = conn.execute(sql, (
|
||||
req.project_id, req.raw_req_id, req.index_no,
|
||||
req.title, req.description, req.function_name,
|
||||
req.priority, req.status, int(req.is_custom),
|
||||
req.created_at, req.updated_at
|
||||
))
|
||||
return cur.lastrowid
|
||||
|
||||
def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]:
|
||||
"""根据 ID 查询功能需求"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM functional_requirements WHERE id = ?", (req_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_func_req(row)
|
||||
|
||||
def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]:
|
||||
"""查询项目下所有功能需求(按序号排序)"""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"""SELECT * FROM functional_requirements
|
||||
WHERE project_id = ? ORDER BY index_no""",
|
||||
(project_id,)
|
||||
).fetchall()
|
||||
return [self._row_to_func_req(r) for r in rows]
|
||||
|
||||
def update_functional_requirement(self, req: FunctionalRequirement) -> None:
|
||||
"""更新功能需求"""
|
||||
req.updated_at = datetime.now().isoformat()
|
||||
sql = """
|
||||
UPDATE functional_requirements
|
||||
SET title=?, description=?, function_name=?, priority=?,
|
||||
status=?, index_no=?, updated_at=?
|
||||
WHERE id=?
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute(sql, (
|
||||
req.title, req.description, req.function_name,
|
||||
req.priority, req.status, req.index_no,
|
||||
req.updated_at, req.id
|
||||
))
|
||||
|
||||
def delete_functional_requirement(self, req_id: int) -> None:
|
||||
"""删除功能需求"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute(
|
||||
"DELETE FROM functional_requirements WHERE id = ?", (req_id,)
|
||||
)
|
||||
|
||||
def _row_to_func_req(self, row) -> FunctionalRequirement:
|
||||
"""sqlite Row → FunctionalRequirement 对象"""
|
||||
return FunctionalRequirement(
|
||||
id=row["id"], project_id=row["project_id"],
|
||||
raw_req_id=row["raw_req_id"], index_no=row["index_no"],
|
||||
title=row["title"], description=row["description"],
|
||||
function_name=row["function_name"], priority=row["priority"],
|
||||
status=row["status"], is_custom=bool(row["is_custom"]),
|
||||
created_at=row["created_at"], updated_at=row["updated_at"]
|
||||
)
|
||||
|
||||
# ══════════════════════════════════════════════════
|
||||
# CodeFile CRUD
|
||||
# ══════════════════════════════════════════════════
|
||||
|
||||
def create_code_file(self, code_file: CodeFile) -> int:
|
||||
"""创建代码文件记录,返回新记录 ID"""
|
||||
sql = """
|
||||
INSERT INTO code_files
|
||||
(project_id, func_req_id, file_name, file_path,
|
||||
language, content, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
cur = conn.execute(sql, (
|
||||
code_file.project_id, code_file.func_req_id,
|
||||
code_file.file_name, code_file.file_path,
|
||||
code_file.language, code_file.content,
|
||||
code_file.created_at, code_file.updated_at
|
||||
))
|
||||
return cur.lastrowid
|
||||
|
||||
def upsert_code_file(self, code_file: CodeFile) -> int:
|
||||
"""插入或更新代码文件(按 func_req_id 唯一键)"""
|
||||
existing = self.get_code_file_by_func_req(code_file.func_req_id)
|
||||
if existing:
|
||||
code_file.id = existing.id
|
||||
code_file.updated_at = datetime.now().isoformat()
|
||||
sql = """
|
||||
UPDATE code_files
|
||||
SET file_name=?, file_path=?, language=?, content=?, updated_at=?
|
||||
WHERE id=?
|
||||
"""
|
||||
with self._get_conn() as conn:
|
||||
conn.execute(sql, (
|
||||
code_file.file_name, code_file.file_path,
|
||||
code_file.language, code_file.content,
|
||||
code_file.updated_at, code_file.id
|
||||
))
|
||||
return code_file.id
|
||||
else:
|
||||
return self.create_code_file(code_file)
|
||||
|
||||
def get_code_file_by_func_req(self, func_req_id: int) -> Optional[CodeFile]:
|
||||
"""根据功能需求 ID 查询代码文件"""
|
||||
with self._get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM code_files WHERE func_req_id = ?", (func_req_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_code_file(row)
|
||||
|
||||
def list_code_files_by_project(self, project_id: int) -> List[CodeFile]:
|
||||
"""查询项目下所有代码文件"""
|
||||
with self._get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM code_files WHERE project_id = ? ORDER BY created_at",
|
||||
(project_id,)
|
||||
).fetchall()
|
||||
return [self._row_to_code_file(r) for r in rows]
|
||||
|
||||
def _row_to_code_file(self, row) -> CodeFile:
|
||||
"""sqlite Row → CodeFile 对象"""
|
||||
return CodeFile(
|
||||
id=row["id"], project_id=row["project_id"],
|
||||
func_req_id=row["func_req_id"], file_name=row["file_name"],
|
||||
file_path=row["file_path"], language=row["language"],
|
||||
content=row["content"], created_at=row["created_at"],
|
||||
updated_at=row["updated_at"]
|
||||
)
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
# database/models.py - 数据模型定义(SQLite 建表 DDL)
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# DDL 建表语句
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
CREATE_TABLES_SQL = """
|
||||
PRAGMA foreign_keys = ON;
|
||||
|
||||
-- 项目表
|
||||
CREATE TABLE IF NOT EXISTS projects (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL UNIQUE, -- 项目名称
|
||||
description TEXT, -- 项目描述
|
||||
language TEXT NOT NULL DEFAULT 'python', -- 目标代码语言
|
||||
output_dir TEXT NOT NULL, -- 输出目录路径
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- 原始需求表
|
||||
CREATE TABLE IF NOT EXISTS raw_requirements (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
project_id INTEGER NOT NULL, -- 关联项目
|
||||
content TEXT NOT NULL, -- 需求原文
|
||||
source_type TEXT NOT NULL DEFAULT 'text', -- text | file
|
||||
source_name TEXT, -- 文件名(文件输入时)
|
||||
knowledge TEXT, -- 合并后的知识库内容
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- 功能需求表
|
||||
CREATE TABLE IF NOT EXISTS functional_requirements (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
project_id INTEGER NOT NULL,
|
||||
raw_req_id INTEGER NOT NULL, -- 关联原始需求
|
||||
index_no INTEGER NOT NULL, -- 序号
|
||||
title TEXT NOT NULL, -- 功能标题
|
||||
description TEXT NOT NULL, -- 功能描述
|
||||
function_name TEXT NOT NULL, -- 对应函数名
|
||||
priority TEXT NOT NULL DEFAULT 'medium', -- high|medium|low
|
||||
status TEXT NOT NULL DEFAULT 'pending', -- pending|generated|skipped
|
||||
is_custom INTEGER NOT NULL DEFAULT 0, -- 是否用户自定义 (0/1)
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (raw_req_id) REFERENCES raw_requirements(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- 代码文件表
|
||||
CREATE TABLE IF NOT EXISTS code_files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
project_id INTEGER NOT NULL,
|
||||
func_req_id INTEGER NOT NULL UNIQUE, -- 关联功能需求(1对1)
|
||||
file_name TEXT NOT NULL, -- 文件名
|
||||
file_path TEXT NOT NULL, -- 完整路径
|
||||
language TEXT NOT NULL, -- 代码语言
|
||||
content TEXT NOT NULL, -- 代码内容
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (func_req_id) REFERENCES functional_requirements(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# 数据类(Python 对象映射)
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
@dataclass
|
||||
class Project:
|
||||
name: str
|
||||
output_dir: str
|
||||
language: str = "python"
|
||||
description: str = ""
|
||||
id: Optional[int] = None
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
|
||||
@dataclass
|
||||
class RawRequirement:
|
||||
project_id: int
|
||||
content: str
|
||||
source_type: str = "text" # text | file
|
||||
source_name: Optional[str] = None
|
||||
knowledge: Optional[str] = None
|
||||
id: Optional[int] = None
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionalRequirement:
|
||||
project_id: int
|
||||
raw_req_id: int
|
||||
index_no: int
|
||||
title: str
|
||||
description: str
|
||||
function_name: str
|
||||
priority: str = "medium"
|
||||
status: str = "pending"
|
||||
is_custom: bool = False
|
||||
id: Optional[int] = None
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeFile:
|
||||
project_id: int
|
||||
func_req_id: int
|
||||
file_name: str
|
||||
file_path: str
|
||||
language: str
|
||||
content: str
|
||||
id: Optional[int] = None
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
|
@ -0,0 +1,789 @@
|
|||
#!/usr/bin/env python3
|
||||
# encoding: utf-8
|
||||
# main.py - 主入口:支持交互式 & 非交互式(CLI 参数)两种运行模式
|
||||
#
|
||||
# 交互式: python main.py
|
||||
# 非交互式:python main.py --non-interactive \
|
||||
# --project-name "MyProject" \
|
||||
# --language python \
|
||||
# --requirement-text "用户管理系统,包含注册、登录、修改密码功能"
|
||||
#
|
||||
# 完整参数见:python main.py --help
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict
|
||||
|
||||
import click
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt, Confirm
|
||||
|
||||
import config
|
||||
from database.db_manager import DBManager
|
||||
from database.models import Project, RawRequirement, FunctionalRequirement
|
||||
from core.llm_client import LLMClient
|
||||
from core.requirement_analyzer import RequirementAnalyzer
|
||||
from core.code_generator import CodeGenerator
|
||||
from utils.file_handler import read_file_auto, merge_knowledge_files
|
||||
from utils.output_writer import (
|
||||
ensure_project_dir, build_project_output_dir, write_project_readme,
|
||||
write_function_signatures_json, validate_all_signatures,
|
||||
patch_signatures_with_url,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
db = DBManager()
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# 显示工具函数
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def print_banner():
|
||||
console.print(Panel.fit(
|
||||
"[bold cyan]🚀 需求分析 & 代码生成工具[/bold cyan]\n"
|
||||
"[dim]Powered by LLM · SQLite · Python[/dim]",
|
||||
border_style="cyan"
|
||||
))
|
||||
|
||||
|
||||
def print_functional_requirements(reqs: list):
|
||||
"""以表格形式展示功能需求列表"""
|
||||
table = Table(title="📋 功能需求列表", show_lines=True)
|
||||
table.add_column("序号", style="cyan", width=6)
|
||||
table.add_column("ID", style="dim", width=6)
|
||||
table.add_column("标题", style="bold", width=20)
|
||||
table.add_column("函数名", width=25)
|
||||
table.add_column("优先级", width=8)
|
||||
table.add_column("类型", width=8)
|
||||
table.add_column("描述", width=40)
|
||||
|
||||
priority_color = {"high": "red", "medium": "yellow", "low": "green"}
|
||||
for req in reqs:
|
||||
color = priority_color.get(req.priority, "white")
|
||||
table.add_row(
|
||||
str(req.index_no),
|
||||
str(req.id) if req.id else "-",
|
||||
req.title,
|
||||
f"[code]{req.function_name}[/code]",
|
||||
f"[{color}]{req.priority}[/{color}]",
|
||||
"[magenta]自定义[/magenta]" if req.is_custom else "LLM生成",
|
||||
req.description[:60] + "..." if len(req.description) > 60 else req.description,
|
||||
)
|
||||
console.print(table)
|
||||
|
||||
|
||||
def print_signatures_preview(signatures: list):
|
||||
"""
|
||||
以表格形式预览函数签名列表(含 url 字段)
|
||||
|
||||
Args:
|
||||
signatures: 纯签名列表(顶层文档的 "functions" 字段)
|
||||
"""
|
||||
table = Table(title="📄 函数签名预览", show_lines=True)
|
||||
table.add_column("需求编号", style="cyan", width=10)
|
||||
table.add_column("函数名", style="bold", width=25)
|
||||
table.add_column("参数数量", width=8)
|
||||
table.add_column("返回类型", width=10)
|
||||
table.add_column("成功返回值", width=18)
|
||||
table.add_column("失败返回值", width=18)
|
||||
table.add_column("URL", style="dim", width=30)
|
||||
|
||||
def _fmt_value(v) -> str:
|
||||
if v is None:
|
||||
return "-"
|
||||
if isinstance(v, dict):
|
||||
return "{" + ", ".join(v.keys()) + "}"
|
||||
return str(v)[:16]
|
||||
|
||||
for sig in signatures:
|
||||
ret = sig.get("return") or {}
|
||||
on_success = ret.get("on_success") or {}
|
||||
on_failure = ret.get("on_failure") or {}
|
||||
url = sig.get("url", "")
|
||||
# 只显示文件名部分,避免路径过长
|
||||
url_display = os.path.basename(url) if url else "[dim]待生成[/dim]"
|
||||
|
||||
table.add_row(
|
||||
sig.get("requirement_id", "-"),
|
||||
sig.get("name", "-"),
|
||||
str(len(sig.get("parameters", {}))),
|
||||
ret.get("type", "void"),
|
||||
_fmt_value(on_success.get("value")),
|
||||
_fmt_value(on_failure.get("value")),
|
||||
url_display,
|
||||
)
|
||||
console.print(table)
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Step 1:项目初始化
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def step_init_project(
|
||||
project_name: str = None,
|
||||
language: str = None,
|
||||
description: str = "",
|
||||
non_interactive: bool = False,
|
||||
) -> Project:
|
||||
if not non_interactive:
|
||||
console.print("\n[bold]Step 1 · 项目配置[/bold]", style="blue")
|
||||
project_name = project_name or Prompt.ask("📁 请输入项目名称")
|
||||
language = language or Prompt.ask(
|
||||
"💻 目标代码语言",
|
||||
default=config.DEFAULT_LANGUAGE,
|
||||
choices=["python", "javascript", "typescript", "java", "go", "rust"],
|
||||
)
|
||||
description = description or Prompt.ask("📝 项目描述(可选)", default="")
|
||||
else:
|
||||
if not project_name:
|
||||
raise ValueError("非交互模式下 --project-name 为必填项")
|
||||
language = language or config.DEFAULT_LANGUAGE
|
||||
console.print("\n[bold]Step 1 · 项目配置[/bold] [dim](非交互)[/dim]", style="blue")
|
||||
console.print(f" 项目名称: {project_name} 语言: {language}")
|
||||
|
||||
existing = db.get_project_by_name(project_name)
|
||||
if existing:
|
||||
if non_interactive:
|
||||
console.print(f"[green]✓ 已加载已有项目: {project_name} (ID={existing.id})[/green]")
|
||||
return existing
|
||||
use_existing = Confirm.ask(f"⚠️ 项目 '{project_name}' 已存在,是否继续使用?")
|
||||
if use_existing:
|
||||
console.print(f"[green]✓ 已加载项目: {project_name} (ID={existing.id})[/green]")
|
||||
return existing
|
||||
project_name = Prompt.ask("请输入新的项目名称")
|
||||
|
||||
output_dir = build_project_output_dir(project_name)
|
||||
project = Project(
|
||||
name=project_name,
|
||||
language=language,
|
||||
output_dir=output_dir,
|
||||
description=description,
|
||||
)
|
||||
project.id = db.create_project(project)
|
||||
console.print(f"[green]✓ 项目已创建: {project_name} (ID={project.id})[/green]")
|
||||
return project
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Step 2:输入原始需求 & 知识库
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def step_input_requirement(
|
||||
project: Project,
|
||||
requirement_text: str = None,
|
||||
requirement_file: str = None,
|
||||
knowledge_files: list = None,
|
||||
non_interactive: bool = False,
|
||||
) -> tuple:
|
||||
console.print(
|
||||
f"\n[bold]Step 2 · 输入原始需求[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
||||
raw_text = ""
|
||||
source_name = None
|
||||
source_type = "text"
|
||||
|
||||
if non_interactive:
|
||||
if requirement_file:
|
||||
raw_text = read_file_auto(requirement_file)
|
||||
source_name = os.path.basename(requirement_file)
|
||||
source_type = "file"
|
||||
console.print(f" 需求文件: {source_name} ({len(raw_text)} 字符)")
|
||||
elif requirement_text:
|
||||
raw_text = requirement_text
|
||||
source_type = "text"
|
||||
console.print(f" 需求文本: {raw_text[:80]}{'...' if len(raw_text) > 80 else ''}")
|
||||
else:
|
||||
raise ValueError("非交互模式下必须提供 --requirement-text 或 --requirement-file")
|
||||
else:
|
||||
input_type = Prompt.ask("📥 需求输入方式", choices=["text", "file"], default="text")
|
||||
if input_type == "text":
|
||||
console.print("[dim]请输入原始需求(输入空行结束):[/dim]")
|
||||
lines = []
|
||||
while True:
|
||||
line = input()
|
||||
if line == "" and lines:
|
||||
break
|
||||
lines.append(line)
|
||||
raw_text = "\n".join(lines)
|
||||
source_type = "text"
|
||||
else:
|
||||
file_path = Prompt.ask("📂 需求文件路径")
|
||||
raw_text = read_file_auto(file_path)
|
||||
source_name = os.path.basename(file_path)
|
||||
source_type = "file"
|
||||
console.print(f"[green]✓ 已读取文件: {source_name} ({len(raw_text)} 字符)[/green]")
|
||||
|
||||
knowledge_text = ""
|
||||
if non_interactive:
|
||||
if knowledge_files:
|
||||
knowledge_text = merge_knowledge_files(list(knowledge_files))
|
||||
console.print(f" 知识库: {len(knowledge_files)} 个文件,{len(knowledge_text)} 字符")
|
||||
else:
|
||||
use_kb = Confirm.ask("📚 是否输入知识库文件?", default=False)
|
||||
if use_kb:
|
||||
kb_paths = []
|
||||
while True:
|
||||
kb_path = Prompt.ask("知识库文件路径(留空结束)", default="")
|
||||
if not kb_path:
|
||||
break
|
||||
if os.path.exists(kb_path):
|
||||
kb_paths.append(kb_path)
|
||||
console.print(f" [green]+ {kb_path}[/green]")
|
||||
else:
|
||||
console.print(f" [red]文件不存在: {kb_path}[/red]")
|
||||
if kb_paths:
|
||||
knowledge_text = merge_knowledge_files(kb_paths)
|
||||
console.print(f"[green]✓ 知识库已合并 ({len(knowledge_text)} 字符)[/green]")
|
||||
|
||||
return raw_text, knowledge_text, source_name, source_type
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Step 3:LLM 分解需求
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def step_decompose_requirements(
|
||||
project: Project,
|
||||
raw_text: str,
|
||||
knowledge_text: str,
|
||||
source_name: str,
|
||||
source_type: str,
|
||||
non_interactive: bool = False,
|
||||
) -> tuple:
|
||||
console.print(
|
||||
f"\n[bold]Step 3 · LLM 需求分解[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
||||
raw_req = RawRequirement(
|
||||
project_id=project.id,
|
||||
content=raw_text,
|
||||
source_type=source_type,
|
||||
source_name=source_name,
|
||||
knowledge=knowledge_text or None,
|
||||
)
|
||||
raw_req_id = db.create_raw_requirement(raw_req)
|
||||
console.print(f"[dim]原始需求已存储 (ID={raw_req_id})[/dim]")
|
||||
|
||||
with console.status("[bold yellow]🤖 LLM 正在分解需求,请稍候...[/bold yellow]"):
|
||||
llm = LLMClient()
|
||||
analyzer = RequirementAnalyzer(llm)
|
||||
func_reqs = analyzer.decompose(
|
||||
raw_requirement=raw_text,
|
||||
project_id=project.id,
|
||||
raw_req_id=raw_req_id,
|
||||
knowledge=knowledge_text,
|
||||
)
|
||||
|
||||
for req in func_reqs:
|
||||
req.id = db.create_functional_requirement(req)
|
||||
|
||||
console.print(f"[green]✓ 已生成 {len(func_reqs)} 个功能需求[/green]")
|
||||
return raw_req_id, func_reqs
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Step 4:用户编辑功能需求
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def step_edit_requirements(
|
||||
project: Project,
|
||||
func_reqs: list,
|
||||
raw_req_id: int,
|
||||
non_interactive: bool = False,
|
||||
skip_indices: list = None,
|
||||
) -> list:
|
||||
console.print(
|
||||
f"\n[bold]Step 4 · 编辑功能需求[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
||||
if non_interactive:
|
||||
if skip_indices:
|
||||
to_skip = set(skip_indices)
|
||||
removed, kept = [], []
|
||||
for req in func_reqs:
|
||||
if req.index_no in to_skip:
|
||||
db.delete_functional_requirement(req.id)
|
||||
removed.append(req.title)
|
||||
else:
|
||||
kept.append(req)
|
||||
func_reqs = kept
|
||||
for i, req in enumerate(func_reqs, 1):
|
||||
req.index_no = i
|
||||
db.update_functional_requirement(req)
|
||||
if removed:
|
||||
console.print(f" [red]已跳过: {', '.join(removed)}[/red]")
|
||||
print_functional_requirements(func_reqs)
|
||||
return func_reqs
|
||||
|
||||
while True:
|
||||
print_functional_requirements(func_reqs)
|
||||
console.print(
|
||||
"\n操作: [cyan]d[/cyan]=删除 [cyan]a[/cyan]=添加 "
|
||||
"[cyan]e[/cyan]=编辑 [cyan]ok[/cyan]=确认继续"
|
||||
)
|
||||
action = Prompt.ask("请选择操作", default="ok").strip().lower()
|
||||
|
||||
if action == "ok":
|
||||
break
|
||||
elif action == "d":
|
||||
idx_str = Prompt.ask("输入要删除的功能需求序号(多个用逗号分隔)")
|
||||
to_delete = {int(x.strip()) for x in idx_str.split(",") if x.strip().isdigit()}
|
||||
removed, kept = [], []
|
||||
for req in func_reqs:
|
||||
if req.index_no in to_delete:
|
||||
db.delete_functional_requirement(req.id)
|
||||
removed.append(req.title)
|
||||
else:
|
||||
kept.append(req)
|
||||
func_reqs = kept
|
||||
for i, req in enumerate(func_reqs, 1):
|
||||
req.index_no = i
|
||||
db.update_functional_requirement(req)
|
||||
console.print(f"[red]✗ 已删除: {', '.join(removed)}[/red]")
|
||||
elif action == "a":
|
||||
title = Prompt.ask("功能标题")
|
||||
description = Prompt.ask("功能描述")
|
||||
func_name = Prompt.ask("函数名 (snake_case)")
|
||||
priority = Prompt.ask(
|
||||
"优先级", choices=["high", "medium", "low"], default="medium"
|
||||
)
|
||||
new_req = FunctionalRequirement(
|
||||
project_id=project.id,
|
||||
raw_req_id=raw_req_id,
|
||||
index_no=len(func_reqs) + 1,
|
||||
title=title,
|
||||
description=description,
|
||||
function_name=func_name,
|
||||
priority=priority,
|
||||
is_custom=True,
|
||||
)
|
||||
new_req.id = db.create_functional_requirement(new_req)
|
||||
func_reqs.append(new_req)
|
||||
console.print(f"[green]✓ 已添加自定义需求: {title}[/green]")
|
||||
elif action == "e":
|
||||
idx_str = Prompt.ask("输入要编辑的功能需求序号")
|
||||
if not idx_str.isdigit():
|
||||
continue
|
||||
idx = int(idx_str)
|
||||
target = next((r for r in func_reqs if r.index_no == idx), None)
|
||||
if target is None:
|
||||
console.print("[red]序号不存在[/red]")
|
||||
continue
|
||||
target.title = Prompt.ask("新标题", default=target.title)
|
||||
target.description = Prompt.ask("新描述", default=target.description)
|
||||
target.function_name = Prompt.ask("新函数名", default=target.function_name)
|
||||
target.priority = Prompt.ask(
|
||||
"新优先级", choices=["high", "medium", "low"], default=target.priority
|
||||
)
|
||||
db.update_functional_requirement(target)
|
||||
console.print(f"[green]✓ 已更新: {target.title}[/green]")
|
||||
|
||||
return func_reqs
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Step 5A:生成函数签名 JSON(不含 url 字段,待 5C 回写)
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def step_generate_signatures(
|
||||
project: Project,
|
||||
func_reqs: list,
|
||||
output_dir: str,
|
||||
knowledge_text: str,
|
||||
json_file_name: str = "function_signatures.json",
|
||||
non_interactive: bool = False,
|
||||
) -> tuple:
|
||||
"""
|
||||
为所有功能需求生成函数签名,写入初版 JSON(不含 url 字段)。
|
||||
url 字段将在 Step 5C 代码生成完成后回写并刷新 JSON 文件。
|
||||
|
||||
Returns:
|
||||
(signatures: List[dict], json_path: str)
|
||||
"""
|
||||
console.print(
|
||||
f"\n[bold]Step 5A · 生成函数签名 JSON[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
||||
llm = LLMClient()
|
||||
analyzer = RequirementAnalyzer(llm)
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
def on_progress(index, total, req, signature, error):
|
||||
nonlocal success_count, fail_count
|
||||
if error:
|
||||
console.print(
|
||||
f" [{index}/{total}] [yellow]⚠ {req.title} 签名生成失败,"
|
||||
f"使用降级结构: {error}[/yellow]"
|
||||
)
|
||||
fail_count += 1
|
||||
else:
|
||||
console.print(
|
||||
f" [{index}/{total}] [green]✓ {req.title}[/green] "
|
||||
f"→ [dim]{signature.get('name')}()[/dim] "
|
||||
f"params={len(signature.get('parameters', {}))}"
|
||||
)
|
||||
success_count += 1
|
||||
|
||||
console.print(f"[yellow]正在为 {len(func_reqs)} 个功能需求生成函数签名...[/yellow]")
|
||||
signatures = analyzer.build_function_signatures_batch(
|
||||
func_reqs=func_reqs,
|
||||
knowledge=knowledge_text,
|
||||
on_progress=on_progress,
|
||||
)
|
||||
|
||||
# 校验
|
||||
validation_report = validate_all_signatures(signatures)
|
||||
if validation_report:
|
||||
console.print(f"[yellow]⚠ 发现 {len(validation_report)} 个签名存在结构问题:[/yellow]")
|
||||
for fname, errors in validation_report.items():
|
||||
for err in errors:
|
||||
console.print(f" [yellow]· {fname}: {err}[/yellow]")
|
||||
else:
|
||||
console.print("[green]✓ 所有签名结构校验通过[/green]")
|
||||
|
||||
# 写入初版 JSON(url 字段尚未填入)
|
||||
json_path = write_function_signatures_json(
|
||||
output_dir=output_dir,
|
||||
signatures=signatures,
|
||||
project_name=project.name,
|
||||
project_description=project.description or "", # ← 传入项目描述
|
||||
file_name=json_file_name,
|
||||
)
|
||||
console.print(
|
||||
f"[green]✓ 签名 JSON 初版已写入: [cyan]{os.path.abspath(json_path)}[/cyan][/green]\n"
|
||||
f" 成功: {success_count} 降级: {fail_count}"
|
||||
)
|
||||
return signatures, json_path
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Step 5B:生成代码文件,收集 {函数名: 文件路径} 映射
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def step_generate_code(
|
||||
project: Project,
|
||||
func_reqs: list,
|
||||
output_dir: str,
|
||||
knowledge_text: str,
|
||||
signatures: list,
|
||||
non_interactive: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
依据签名约束批量生成代码文件。
|
||||
|
||||
Returns:
|
||||
func_name_to_url: {函数名: 代码文件绝对路径} 映射表,
|
||||
供 Step 5C 回写 url 字段使用。
|
||||
生成失败的函数不会出现在映射表中。
|
||||
"""
|
||||
console.print(
|
||||
f"\n[bold]Step 5B · 生成代码文件[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
||||
generator = CodeGenerator(LLMClient())
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
func_name_to_url: Dict[str, str] = {} # ← 收集 函数名 → 文件绝对路径
|
||||
|
||||
def on_progress(index, total, req, code_file, error):
|
||||
nonlocal success_count, fail_count
|
||||
if error:
|
||||
console.print(f" [{index}/{total}] [red]✗ {req.title}: {error}[/red]")
|
||||
fail_count += 1
|
||||
else:
|
||||
db.upsert_code_file(code_file)
|
||||
req.status = "generated"
|
||||
db.update_functional_requirement(req)
|
||||
# 收集 函数名 → 绝对文件路径(作为 url 回写)
|
||||
func_name_to_url[req.function_name] = os.path.abspath(code_file.file_path)
|
||||
console.print(
|
||||
f" [{index}/{total}] [green]✓ {req.title}[/green] "
|
||||
f"→ [dim]{code_file.file_name}[/dim]"
|
||||
)
|
||||
success_count += 1
|
||||
|
||||
console.print(f"[yellow]开始生成 {len(func_reqs)} 个代码文件(签名约束模式)...[/yellow]")
|
||||
generator.generate_batch(
|
||||
func_reqs=func_reqs,
|
||||
output_dir=output_dir,
|
||||
language=project.language,
|
||||
knowledge=knowledge_text,
|
||||
signatures=signatures,
|
||||
on_progress=on_progress,
|
||||
)
|
||||
|
||||
req_summary = "\n".join(
|
||||
f"{i+1}. **{r.title}** (`{r.function_name}`) - {r.description[:80]}"
|
||||
for i, r in enumerate(func_reqs)
|
||||
)
|
||||
write_project_readme(output_dir, project.name, req_summary)
|
||||
|
||||
console.print(Panel(
|
||||
f"[bold green]✅ 代码生成完成![/bold green]\n"
|
||||
f"成功: {success_count} 失败: {fail_count}\n"
|
||||
f"输出目录: [cyan]{os.path.abspath(output_dir)}[/cyan]",
|
||||
border_style="green",
|
||||
))
|
||||
return func_name_to_url
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# Step 5C:回写 url 字段并刷新 JSON
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def step_patch_signatures_url(
|
||||
project: Project,
|
||||
signatures: list,
|
||||
func_name_to_url: Dict[str, str],
|
||||
output_dir: str,
|
||||
json_file_name: str,
|
||||
non_interactive: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
将代码文件路径回写到签名的 "url" 字段,并重新写入 JSON 文件。
|
||||
|
||||
执行流程:
|
||||
1. 调用 patch_signatures_with_url() 原地修改签名列表
|
||||
2. 打印最终签名预览(含 url 列)
|
||||
3. 重新调用 write_function_signatures_json() 覆盖写入 JSON
|
||||
|
||||
Args:
|
||||
project: 项目对象(提供 name 与 description)
|
||||
signatures: Step 5A 产出的签名列表(将被原地修改)
|
||||
func_name_to_url: Step 5B 收集的 {函数名: 文件绝对路径} 映射
|
||||
output_dir: JSON 文件所在目录
|
||||
json_file_name: JSON 文件名
|
||||
non_interactive: 是否非交互模式
|
||||
|
||||
Returns:
|
||||
刷新后的 JSON 文件绝对路径
|
||||
"""
|
||||
console.print(
|
||||
f"\n[bold]Step 5C · 回写代码文件路径(url)到签名 JSON[/bold]"
|
||||
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
|
||||
style="blue",
|
||||
)
|
||||
|
||||
# 原地回写 url 字段
|
||||
patch_signatures_with_url(signatures, func_name_to_url)
|
||||
|
||||
patched = sum(1 for s in signatures if s.get("url"))
|
||||
unpatched = len(signatures) - patched
|
||||
if unpatched:
|
||||
console.print(
|
||||
f"[yellow]⚠ {unpatched} 个函数未能写入 url"
|
||||
f"(对应代码文件生成失败)[/yellow]"
|
||||
)
|
||||
|
||||
# 打印最终预览(含 url 列)
|
||||
print_signatures_preview(signatures)
|
||||
|
||||
# 覆盖写入 JSON(含 project.description)
|
||||
json_path = write_function_signatures_json(
|
||||
output_dir=output_dir,
|
||||
signatures=signatures,
|
||||
project_name=project.name,
|
||||
project_description=project.description or "",
|
||||
file_name=json_file_name,
|
||||
)
|
||||
|
||||
console.print(
|
||||
f"[green]✓ 签名 JSON 已更新(含 url): "
|
||||
f"[cyan]{os.path.abspath(json_path)}[/cyan][/green]\n"
|
||||
f" 已回写: {patched} 未回写: {unpatched}"
|
||||
)
|
||||
return os.path.abspath(json_path)
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# 核心工作流
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def run_workflow(
|
||||
project_name: str = None,
|
||||
language: str = None,
|
||||
description: str = "",
|
||||
requirement_text: str = None,
|
||||
requirement_file: str = None,
|
||||
knowledge_files: tuple = (),
|
||||
skip_indices: list = None,
|
||||
json_file_name: str = "function_signatures.json",
|
||||
non_interactive: bool = False,
|
||||
):
|
||||
"""完整工作流(Step 1 → 5C)"""
|
||||
print_banner()
|
||||
|
||||
# Step 1
|
||||
project = step_init_project(
|
||||
project_name=project_name,
|
||||
language=language,
|
||||
description=description,
|
||||
non_interactive=non_interactive,
|
||||
)
|
||||
|
||||
# Step 2
|
||||
raw_text, knowledge_text, source_name, source_type = step_input_requirement(
|
||||
project=project,
|
||||
requirement_text=requirement_text,
|
||||
requirement_file=requirement_file,
|
||||
knowledge_files=list(knowledge_files) if knowledge_files else [],
|
||||
non_interactive=non_interactive,
|
||||
)
|
||||
|
||||
# Step 3
|
||||
raw_req_id, func_reqs = step_decompose_requirements(
|
||||
project=project,
|
||||
raw_text=raw_text,
|
||||
knowledge_text=knowledge_text,
|
||||
source_name=source_name,
|
||||
source_type=source_type,
|
||||
non_interactive=non_interactive,
|
||||
)
|
||||
|
||||
# Step 4
|
||||
func_reqs = step_edit_requirements(
|
||||
project=project,
|
||||
func_reqs=func_reqs,
|
||||
raw_req_id=raw_req_id,
|
||||
non_interactive=non_interactive,
|
||||
skip_indices=skip_indices or [],
|
||||
)
|
||||
|
||||
if not func_reqs:
|
||||
console.print("[red]⚠ 功能需求列表为空,流程终止[/red]")
|
||||
return
|
||||
|
||||
output_dir = ensure_project_dir(project.name)
|
||||
|
||||
# Step 5A:生成签名(初版,不含 url)
|
||||
signatures, json_path = step_generate_signatures(
|
||||
project=project,
|
||||
func_reqs=func_reqs,
|
||||
output_dir=output_dir,
|
||||
knowledge_text=knowledge_text,
|
||||
json_file_name=json_file_name,
|
||||
non_interactive=non_interactive,
|
||||
)
|
||||
|
||||
# Step 5B:生成代码,收集 {函数名: 文件路径}
|
||||
func_name_to_url = step_generate_code(
|
||||
project=project,
|
||||
func_reqs=func_reqs,
|
||||
output_dir=output_dir,
|
||||
knowledge_text=knowledge_text,
|
||||
signatures=signatures,
|
||||
non_interactive=non_interactive,
|
||||
)
|
||||
|
||||
# Step 5C:回写 url 字段,刷新 JSON
|
||||
json_path = step_patch_signatures_url(
|
||||
project=project,
|
||||
signatures=signatures,
|
||||
func_name_to_url=func_name_to_url,
|
||||
output_dir=output_dir,
|
||||
json_file_name=json_file_name,
|
||||
non_interactive=non_interactive,
|
||||
)
|
||||
|
||||
console.print(Panel(
|
||||
f"[bold cyan]🎉 全部流程完成![/bold cyan]\n"
|
||||
f"项目: [bold]{project.name}[/bold]\n"
|
||||
f"描述: {project.description or '(无)'}\n"
|
||||
f"代码目录: [cyan]{os.path.abspath(output_dir)}[/cyan]\n"
|
||||
f"签名文件: [cyan]{json_path}[/cyan]",
|
||||
border_style="cyan",
|
||||
))
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# CLI 入口(click)
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
@click.command()
|
||||
@click.option("--non-interactive", is_flag=True, default=False,
|
||||
help="以非交互模式运行(所有参数通过命令行传入)")
|
||||
@click.option("--project-name", "-p", default=None, help="项目名称")
|
||||
@click.option("--language", "-l", default=None,
|
||||
type=click.Choice(["python","javascript","typescript","java","go","rust"]),
|
||||
help=f"目标代码语言(默认: {config.DEFAULT_LANGUAGE})")
|
||||
@click.option("--description", "-d", default="", help="项目描述")
|
||||
@click.option("--requirement-text","-r", default=None,
|
||||
help="原始需求文本(与 --requirement-file 二选一)")
|
||||
@click.option("--requirement-file","-f", default=None,
|
||||
type=click.Path(exists=True),
|
||||
help="原始需求文件路径(支持 .txt/.md/.pdf/.docx)")
|
||||
@click.option("--knowledge-file", "-k", default=None, multiple=True,
|
||||
type=click.Path(exists=True),
|
||||
help="知识库文件路径(可多次指定,如 -k a.md -k b.pdf)")
|
||||
@click.option("--skip-index", "-s", default=None, multiple=True, type=int,
|
||||
help="要跳过的功能需求序号(可多次指定,如 -s 2 -s 5)")
|
||||
@click.option("--json-file-name", "-j", default="function_signatures.json",
|
||||
help="函数签名 JSON 文件名(默认: function_signatures.json)")
|
||||
def cli(
|
||||
non_interactive, project_name, language, description,
|
||||
requirement_text, requirement_file, knowledge_file,
|
||||
skip_index, json_file_name,
|
||||
):
|
||||
"""
|
||||
需求分析 & 代码生成工具
|
||||
|
||||
\b
|
||||
交互式运行(推荐初次使用):
|
||||
python main.py
|
||||
|
||||
\b
|
||||
非交互式运行示例:
|
||||
python main.py --non-interactive \\
|
||||
--project-name "UserSystem" \\
|
||||
--description "用户管理系统后端服务" \\
|
||||
--language python \\
|
||||
--requirement-text "用户管理系统,包含注册、登录、修改密码功能" \\
|
||||
--knowledge-file docs/api_spec.md \\
|
||||
--json-file-name api_signatures.json
|
||||
|
||||
\b
|
||||
从文件读取需求 + 跳过部分功能需求:
|
||||
python main.py --non-interactive \\
|
||||
--project-name "MyProject" \\
|
||||
--requirement-file requirements.md \\
|
||||
--skip-index 3 --skip-index 7
|
||||
"""
|
||||
try:
|
||||
run_workflow(
|
||||
project_name=project_name,
|
||||
language=language,
|
||||
description=description,
|
||||
requirement_text=requirement_text,
|
||||
requirement_file=requirement_file,
|
||||
knowledge_files=knowledge_file,
|
||||
skip_indices=list(skip_index) if skip_index else [],
|
||||
json_file_name=json_file_name,
|
||||
non_interactive=non_interactive,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[yellow]用户中断,退出[/yellow]")
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
console.print(f"\n[bold red]❌ 错误: {e}[/bold red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
你是一个产品经理与高级软件工程师,非常善于将用户原始需求进行理解、分析并生成条目化的功能需求,再根据每个功能需求生成指定的功能函数,现在生成python工程实现以上功能,设计如下:
|
||||
|
||||
# 工作流程设计
|
||||
1. 用户输入:知识库,原始需求描述、项目名称。其中原始需求支持文本与文件格式,知识库可选,通过一个或多个文件输入;
|
||||
2. 软件通过llm模型,结合输入的知识库,将原始需求分解成多个可实现的功能需求,并显示给用户;
|
||||
3. 软件支持用户删除其中不需要的功能需求,同时支持添加自定义功能需求;
|
||||
4. 软件根据配置后的功能需求,生成指定代码(默认使用python)的功能函数文件输出到指定项目名称的目录中,一个功能需求对应一个函数、一个代码文件;
|
||||
|
||||
# 数据库设计
|
||||
1. 使用sqlite数据库存储相关信息,包含项目表、原始需求表,功能需求表、代码文件表;
|
||||
2. 数据库支持通过ID实现项目、原始需求、功能需求、代码文件之间的关联关系;
|
||||
|
||||
更新以上代码,支持以json格式将所有的功能函数描述输出到json文件,格式如下:
|
||||
[
|
||||
{
|
||||
"name": "change_password",
|
||||
"requirement_id": "the id related to the requirement. e.g.REQ.01",
|
||||
"description": "replace old password with new password",
|
||||
"type": "function",
|
||||
"parameters": {
|
||||
"user_id": { "type": "integer", "inout": "in", "description": "the user's id", "required": true },
|
||||
"old_password": { "type": "string", "inout": "in", "description": "the old password", "required": true },
|
||||
"new_password": { "type": "string", "inout": "in", "description": "the user's new password", "required": true }
|
||||
},
|
||||
"return": {
|
||||
"type": "integer",
|
||||
"description": "0 is successful, nonzero is failure"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "set_and_get_system_status",
|
||||
"requirement_id": "the id related to the requirement. e.g. REQ.02",
|
||||
"description": "set and get the system status",
|
||||
"type": "function",
|
||||
"parameters": {
|
||||
"new_status": {
|
||||
"type": "integer",
|
||||
"intou": "in",
|
||||
"required": "true",
|
||||
"description": "the new system status"
|
||||
},
|
||||
"current_status": {
|
||||
"type": "integer",
|
||||
"inout": "out",
|
||||
"required": "true",
|
||||
"description": "get current system status"
|
||||
}
|
||||
}
|
||||
},
|
||||
...
|
||||
]
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
openai>=1.0.0
|
||||
python-dotenv>=1.0.0
|
||||
rich>=13.0.0
|
||||
python-docx>=0.8.11
|
||||
PyPDF2>=3.0.0
|
||||
click>=8.1.0
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 设置环境变量
|
||||
export OPENAI_API_KEY="sk-AUmOuFI731Ty5Nob38jY26d8lydfDT-QkE2giqb0sCuPCAE2JH6zjLM4lZLpvL5WMYPOocaMe2FwVDmqM_9KimmKACjR"
|
||||
export OPENAI_BASE_URL="https://openapi.monica.im/v1" # 或其他兼容接口
|
||||
export LLM_MODEL="gpt-4o"
|
||||
|
||||
python main.py
|
||||
|
|
@ -0,0 +1,142 @@
|
|||
# utils/file_handler.py - 文件读取工具(支持 txt/md/pdf/docx)
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def read_text_file(file_path: str) -> str:
|
||||
"""
|
||||
读取纯文本文件内容(.txt / .md / .py 等)
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
文件文本内容
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 文件不存在
|
||||
UnicodeDecodeError: 编码错误时尝试 latin-1 兜底
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
try:
|
||||
return path.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return path.read_text(encoding="latin-1")
|
||||
|
||||
|
||||
def read_pdf_file(file_path: str) -> str:
|
||||
"""
|
||||
读取 PDF 文件内容
|
||||
|
||||
Args:
|
||||
file_path: PDF 文件路径
|
||||
|
||||
Returns:
|
||||
提取的文本内容
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 PyPDF2
|
||||
FileNotFoundError: 文件不存在
|
||||
"""
|
||||
try:
|
||||
import PyPDF2
|
||||
except ImportError:
|
||||
raise ImportError("请安装 PyPDF2: pip install PyPDF2")
|
||||
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
texts = []
|
||||
with open(file_path, "rb") as f:
|
||||
reader = PyPDF2.PdfReader(f)
|
||||
for page in reader.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
texts.append(text)
|
||||
return "\n".join(texts)
|
||||
|
||||
|
||||
def read_docx_file(file_path: str) -> str:
|
||||
"""
|
||||
读取 Word (.docx) 文件内容
|
||||
|
||||
Args:
|
||||
file_path: docx 文件路径
|
||||
|
||||
Returns:
|
||||
提取的文本内容(段落合并)
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 python-docx
|
||||
FileNotFoundError: 文件不存在
|
||||
"""
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError("请安装 python-docx: pip install python-docx")
|
||||
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
doc = Document(file_path)
|
||||
return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
|
||||
|
||||
|
||||
def read_file_auto(file_path: str) -> str:
|
||||
"""
|
||||
根据文件扩展名自动选择读取方式
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
文件文本内容
|
||||
|
||||
Raises:
|
||||
ValueError: 不支持的文件类型
|
||||
"""
|
||||
ext = Path(file_path).suffix.lower()
|
||||
readers = {
|
||||
".txt": read_text_file,
|
||||
".md": read_text_file,
|
||||
".py": read_text_file,
|
||||
".json": read_text_file,
|
||||
".yaml": read_text_file,
|
||||
".yml": read_text_file,
|
||||
".pdf": read_pdf_file,
|
||||
".docx": read_docx_file,
|
||||
}
|
||||
reader = readers.get(ext)
|
||||
if reader is None:
|
||||
raise ValueError(f"不支持的文件类型: {ext},支持: {list(readers.keys())}")
|
||||
return reader(file_path)
|
||||
|
||||
|
||||
def merge_knowledge_files(file_paths: List[str]) -> str:
|
||||
"""
|
||||
合并多个知识库文件为单一文本
|
||||
|
||||
Args:
|
||||
file_paths: 知识库文件路径列表
|
||||
|
||||
Returns:
|
||||
合并后的知识库文本(包含文件名分隔符)
|
||||
"""
|
||||
if not file_paths:
|
||||
return ""
|
||||
|
||||
sections = []
|
||||
for fp in file_paths:
|
||||
try:
|
||||
content = read_file_auto(fp)
|
||||
file_name = Path(fp).name
|
||||
sections.append(f"### 知识库文件: {file_name}\n{content}")
|
||||
except Exception as e:
|
||||
sections.append(f"### 知识库文件: {fp}\n[读取失败: {e}]")
|
||||
|
||||
return "\n\n".join(sections)
|
||||
|
|
@ -0,0 +1,430 @@
|
|||
# utils/output_writer.py - 代码文件 & JSON 输出工具
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import config
|
||||
|
||||
|
||||
# 各语言文件扩展名映射
|
||||
LANGUAGE_EXT_MAP: Dict[str, str] = {
|
||||
"python": ".py",
|
||||
"javascript": ".js",
|
||||
"typescript": ".ts",
|
||||
"java": ".java",
|
||||
"go": ".go",
|
||||
"rust": ".rs",
|
||||
"cpp": ".cpp",
|
||||
"c": ".c",
|
||||
"csharp": ".cs",
|
||||
"ruby": ".rb",
|
||||
"php": ".php",
|
||||
"swift": ".swift",
|
||||
"kotlin": ".kt",
|
||||
}
|
||||
|
||||
# 合法的通用类型集合
|
||||
VALID_TYPES = {
|
||||
"integer", "string", "boolean", "float",
|
||||
"list", "dict", "object", "void", "any",
|
||||
}
|
||||
|
||||
# 合法的 inout 值
|
||||
VALID_INOUT = {"in", "out", "inout"}
|
||||
|
||||
|
||||
def get_file_extension(language: str) -> str:
|
||||
"""
|
||||
获取指定语言的文件扩展名
|
||||
|
||||
Args:
|
||||
language: 编程语言名称(小写)
|
||||
|
||||
Returns:
|
||||
文件扩展名(含点号,如 '.py')
|
||||
"""
|
||||
return LANGUAGE_EXT_MAP.get(language.lower(), ".txt")
|
||||
|
||||
|
||||
def build_project_output_dir(project_name: str) -> str:
|
||||
"""
|
||||
构建项目输出目录路径
|
||||
|
||||
Args:
|
||||
project_name: 项目名称
|
||||
|
||||
Returns:
|
||||
输出目录路径
|
||||
"""
|
||||
safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_name)
|
||||
return os.path.join(config.OUTPUT_BASE_DIR, safe_name)
|
||||
|
||||
|
||||
def ensure_project_dir(project_name: str) -> str:
|
||||
"""
|
||||
确保项目输出目录存在,不存在则创建
|
||||
|
||||
Args:
|
||||
project_name: 项目名称
|
||||
|
||||
Returns:
|
||||
创建好的目录路径
|
||||
"""
|
||||
output_dir = build_project_output_dir(project_name)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
init_file = os.path.join(output_dir, "__init__.py")
|
||||
if not os.path.exists(init_file):
|
||||
Path(init_file).write_text(
|
||||
"# Auto-generated project package\n", encoding="utf-8"
|
||||
)
|
||||
return output_dir
|
||||
|
||||
|
||||
def write_code_file(
|
||||
output_dir: str,
|
||||
function_name: str,
|
||||
language: str,
|
||||
content: str,
|
||||
) -> str:
|
||||
"""
|
||||
将代码内容写入指定目录的文件
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录路径
|
||||
function_name: 函数名(用于生成文件名)
|
||||
language: 编程语言
|
||||
content: 代码内容
|
||||
|
||||
Returns:
|
||||
写入的文件完整路径
|
||||
"""
|
||||
ext = get_file_extension(language)
|
||||
file_name = f"{function_name}{ext}"
|
||||
file_path = os.path.join(output_dir, file_name)
|
||||
Path(file_path).write_text(content, encoding="utf-8")
|
||||
return file_path
|
||||
|
||||
|
||||
def write_project_readme(
|
||||
output_dir: str,
|
||||
project_name: str,
|
||||
requirements_summary: str,
|
||||
) -> str:
|
||||
"""
|
||||
在项目目录生成 README.md 文件
|
||||
|
||||
Args:
|
||||
output_dir: 项目输出目录
|
||||
project_name: 项目名称
|
||||
requirements_summary: 功能需求摘要文本
|
||||
|
||||
Returns:
|
||||
README.md 文件路径
|
||||
"""
|
||||
readme_content = f"""# {project_name}
|
||||
|
||||
> Auto-generated by Requirement Analyzer
|
||||
|
||||
## 功能需求列表
|
||||
|
||||
{requirements_summary}
|
||||
"""
|
||||
readme_path = os.path.join(output_dir, "README.md")
|
||||
Path(readme_path).write_text(readme_content, encoding="utf-8")
|
||||
return readme_path
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# 函数签名 JSON 导出
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def build_signatures_document(
|
||||
project_name: str,
|
||||
project_description: str,
|
||||
signatures: List[dict],
|
||||
) -> dict:
|
||||
"""
|
||||
将函数签名列表包装为带项目信息的顶层文档结构。
|
||||
|
||||
Args:
|
||||
project_name: 项目名称,写入 "project" 字段
|
||||
project_description: 项目描述,写入 "description" 字段
|
||||
signatures: 函数签名 dict 列表,写入 "functions" 字段
|
||||
|
||||
Returns:
|
||||
顶层文档 dict,结构为::
|
||||
|
||||
{
|
||||
"project": "<project_name>",
|
||||
"description": "<project_description>",
|
||||
"functions": [ ... ]
|
||||
}
|
||||
"""
|
||||
return {
|
||||
"project": project_name,
|
||||
"description": project_description or "",
|
||||
"functions": signatures,
|
||||
}
|
||||
|
||||
|
||||
def patch_signatures_with_url(
|
||||
signatures: List[dict],
|
||||
func_name_to_url: Dict[str, str],
|
||||
) -> List[dict]:
|
||||
"""
|
||||
将代码文件的路径(URL)回写到对应函数签名的 "url" 字段。
|
||||
|
||||
遍历签名列表,根据 signature["name"] 在 func_name_to_url 中查找
|
||||
对应路径,找到则写入 "url" 字段;未找到则写入空字符串,不抛出异常。
|
||||
|
||||
"url" 字段插入位置紧跟在 "type" 字段之后,以保持字段顺序的可读性::
|
||||
|
||||
{
|
||||
"name": "create_user",
|
||||
"requirement_id": "REQ.01",
|
||||
"description": "...",
|
||||
"type": "function",
|
||||
"url": "/abs/path/to/create_user.py", ← 新增
|
||||
"parameters": { ... },
|
||||
"return": { ... }
|
||||
}
|
||||
|
||||
Args:
|
||||
signatures: 原始签名列表(in-place 修改)
|
||||
func_name_to_url: {函数名: 代码文件绝对路径} 映射表,
|
||||
由 CodeGenerator.generate_batch() 的进度回调收集
|
||||
|
||||
Returns:
|
||||
修改后的签名列表(与传入的同一对象,方便链式调用)
|
||||
"""
|
||||
for sig in signatures:
|
||||
func_name = sig.get("name", "")
|
||||
url = func_name_to_url.get(func_name, "")
|
||||
_insert_field_after(sig, after_key="type", new_key="url", new_value=url)
|
||||
return signatures
|
||||
|
||||
|
||||
def _insert_field_after(
|
||||
d: dict,
|
||||
after_key: str,
|
||||
new_key: str,
|
||||
new_value,
|
||||
) -> None:
|
||||
"""
|
||||
在有序 dict 中将 new_key 插入到 after_key 之后。
|
||||
若 after_key 不存在,则追加到末尾。
|
||||
若 new_key 已存在,则直接更新其值(不改变位置)。
|
||||
|
||||
Args:
|
||||
d: 目标 dict(Python 3.7+ 保证插入顺序)
|
||||
after_key: 参考键名
|
||||
new_key: 要插入的键名
|
||||
new_value: 要插入的值
|
||||
"""
|
||||
if new_key in d:
|
||||
d[new_key] = new_value
|
||||
return
|
||||
|
||||
items = list(d.items())
|
||||
insert_pos = len(items)
|
||||
for i, (k, _) in enumerate(items):
|
||||
if k == after_key:
|
||||
insert_pos = i + 1
|
||||
break
|
||||
|
||||
items.insert(insert_pos, (new_key, new_value))
|
||||
d.clear()
|
||||
d.update(items)
|
||||
|
||||
|
||||
def write_function_signatures_json(
|
||||
output_dir: str,
|
||||
signatures: List[dict],
|
||||
project_name: str,
|
||||
project_description: str,
|
||||
file_name: str = "function_signatures.json",
|
||||
) -> str:
|
||||
"""
|
||||
将函数签名列表连同项目信息一起导出为 JSON 文件。
|
||||
|
||||
输出的 JSON 顶层结构为::
|
||||
|
||||
{
|
||||
"project": "<project_name>",
|
||||
"description": "<project_description>",
|
||||
"functions": [
|
||||
{
|
||||
"name": "...",
|
||||
"requirement_id": "...",
|
||||
"description": "...",
|
||||
"type": "function",
|
||||
"url": "/abs/path/to/xxx.py",
|
||||
"parameters": { ... },
|
||||
"return": { ... }
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
Args:
|
||||
output_dir: JSON 文件写入目录
|
||||
signatures: 函数签名 dict 列表(应已通过
|
||||
patch_signatures_with_url() 写入 "url" 字段)
|
||||
project_name: 项目名称
|
||||
project_description: 项目描述
|
||||
file_name: 输出文件名,默认 function_signatures.json
|
||||
|
||||
Returns:
|
||||
写入的 JSON 文件完整路径
|
||||
|
||||
Raises:
|
||||
OSError: 目录不可写
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
document = build_signatures_document(project_name, project_description, signatures)
|
||||
file_path = os.path.join(output_dir, file_name)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(document, f, ensure_ascii=False, indent=2)
|
||||
return file_path
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════
|
||||
# 签名结构校验
|
||||
# ══════════════════════════════════════════════════════
|
||||
|
||||
def validate_signature_schema(signature: dict) -> List[str]:
|
||||
"""
|
||||
校验单个函数签名 dict 是否符合规范。
|
||||
|
||||
校验范围:
|
||||
- 顶层必填字段:name / requirement_id / description / type / parameters
|
||||
- 可选字段 "url":若存在则必须为非空字符串
|
||||
- parameters:每个参数的 type / inout / required 字段
|
||||
- return:type 字段 + on_success / on_failure 子结构
|
||||
- void 函数:on_success / on_failure 应为 null
|
||||
- 非 void 函数:on_success / on_failure 必须存在,
|
||||
且 value(非空)与 description(非空)均需填写
|
||||
|
||||
Args:
|
||||
signature: 单个函数签名 dict
|
||||
|
||||
Returns:
|
||||
错误信息字符串列表,列表为空表示校验通过
|
||||
"""
|
||||
errors: List[str] = []
|
||||
|
||||
# ── 顶层必填字段 ──────────────────────────────────
|
||||
for key in ("name", "requirement_id", "description", "type", "parameters"):
|
||||
if key not in signature:
|
||||
errors.append(f"缺少顶层字段: '{key}'")
|
||||
|
||||
# ── url 字段(可选,存在时校验非空)─────────────────
|
||||
if "url" in signature:
|
||||
if not isinstance(signature["url"], str):
|
||||
errors.append("'url' 字段必须是字符串类型")
|
||||
elif signature["url"] == "":
|
||||
errors.append("'url' 字段不能为空字符串(代码文件路径未成功回写)")
|
||||
|
||||
# ── parameters ────────────────────────────────────
|
||||
params = signature.get("parameters", {})
|
||||
if not isinstance(params, dict):
|
||||
errors.append("'parameters' 必须是 dict 类型")
|
||||
else:
|
||||
for pname, pdef in params.items():
|
||||
if not isinstance(pdef, dict):
|
||||
errors.append(f"参数 '{pname}' 定义必须是 dict")
|
||||
continue
|
||||
# type(支持联合类型,如 "string|integer")
|
||||
if "type" not in pdef:
|
||||
errors.append(f"参数 '{pname}' 缺少 'type' 字段")
|
||||
else:
|
||||
parts = [p.strip() for p in pdef["type"].split("|")]
|
||||
if not all(p in VALID_TYPES for p in parts):
|
||||
errors.append(
|
||||
f"参数 '{pname}' 的 type='{pdef['type']}' 含有不合法的类型"
|
||||
)
|
||||
# inout
|
||||
if "inout" not in pdef:
|
||||
errors.append(f"参数 '{pname}' 缺少 'inout' 字段")
|
||||
elif pdef["inout"] not in VALID_INOUT:
|
||||
errors.append(
|
||||
f"参数 '{pname}' 的 inout='{pdef['inout']}' 应为 in/out/inout"
|
||||
)
|
||||
# required
|
||||
if "required" not in pdef:
|
||||
errors.append(f"参数 '{pname}' 缺少 'required' 字段")
|
||||
elif not isinstance(pdef["required"], bool):
|
||||
errors.append(
|
||||
f"参数 '{pname}' 的 'required' 应为布尔值 true/false,"
|
||||
f"当前为: {pdef['required']!r}"
|
||||
)
|
||||
|
||||
# ── return ────────────────────────────────────────
|
||||
ret = signature.get("return")
|
||||
if ret is None:
|
||||
errors.append(
|
||||
"缺少 'return' 字段(void 函数请填 "
|
||||
"{\"type\": \"void\", \"on_success\": null, \"on_failure\": null})"
|
||||
)
|
||||
elif not isinstance(ret, dict):
|
||||
errors.append("'return' 必须是 dict 类型")
|
||||
else:
|
||||
ret_type = ret.get("type")
|
||||
if not ret_type:
|
||||
errors.append("'return' 缺少 'type' 字段")
|
||||
elif ret_type not in VALID_TYPES:
|
||||
errors.append(f"'return.type'='{ret_type}' 不在合法类型列表中")
|
||||
|
||||
is_void = (ret_type == "void")
|
||||
|
||||
for sub_key in ("on_success", "on_failure"):
|
||||
sub = ret.get(sub_key)
|
||||
if is_void:
|
||||
if sub is not None:
|
||||
errors.append(
|
||||
f"void 函数的 'return.{sub_key}' 应为 null,"
|
||||
f"当前为: {sub!r}"
|
||||
)
|
||||
else:
|
||||
if sub is None:
|
||||
errors.append(
|
||||
f"非 void 函数缺少 'return.{sub_key}',"
|
||||
f"请描述{'成功' if sub_key == 'on_success' else '失败'}时的返回值"
|
||||
)
|
||||
elif not isinstance(sub, dict):
|
||||
errors.append(f"'return.{sub_key}' 必须是 dict 类型")
|
||||
else:
|
||||
if "value" not in sub:
|
||||
errors.append(f"'return.{sub_key}' 缺少 'value' 字段")
|
||||
elif sub["value"] == "":
|
||||
errors.append(
|
||||
f"'return.{sub_key}.value' 不能为空字符串,"
|
||||
f"请填写具体返回值、值域描述或结构示例"
|
||||
)
|
||||
if "description" not in sub or sub.get("description") in (None, ""):
|
||||
errors.append(f"'return.{sub_key}.description' 不能为空")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_all_signatures(signatures: List[dict]) -> Dict[str, List[str]]:
|
||||
"""
|
||||
批量校验函数签名列表。
|
||||
|
||||
注意:此函数接受的是纯签名列表(即顶层文档的 "functions" 字段),
|
||||
而非包含 project/description 的顶层文档。
|
||||
|
||||
Args:
|
||||
signatures: 函数签名 dict 列表
|
||||
|
||||
Returns:
|
||||
{函数名: [错误信息, ...]} 字典,仅包含有错误的条目
|
||||
"""
|
||||
report: Dict[str, List[str]] = {}
|
||||
for sig in signatures:
|
||||
name = sig.get("name", f"unknown_{id(sig)}")
|
||||
errs = validate_signature_schema(sig)
|
||||
if errs:
|
||||
report[name] = errs
|
||||
return report
|
||||
Loading…
Reference in New Issue