Default Changelist

This commit is contained in:
sontolau 2026-03-05 02:09:45 +08:00
parent b16661ac91
commit 9d42954d44
35 changed files with 5188 additions and 0 deletions

16
MCPServers/requirement.py Normal file
View File

@ -0,0 +1,16 @@
from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings
mcp = FastMCP("requirements",
transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False))
@mcp.tool()
def generate_requirement_document(project: str, requirements: list, standard: str = "GJB438C"):
print(project, requirements, standard)
if __name__ == "__main__":
mcp.settings.host = "0.0.0.0"
mcp.settings.port = 7999
mcp.run(transport="streamable-http")

View File

View File

@ -0,0 +1,21 @@
import os
from dataclasses import dataclass
@dataclass
class Config:
# LLM配置以OpenAI兼容接口为例可替换为其他LLM
LLM_API_KEY: str = os.getenv("LLM_API_KEY", "your-api-key-here")
LLM_BASE_URL: str = os.getenv("LLM_BASE_URL", "https://api.openai.com/v1")
LLM_MODEL: str = os.getenv("LLM_MODEL", "gpt-4o")
LLM_TEMPERATURE: float = 0.2
LLM_MAX_TOKENS: int = 4096
# 测试配置
GENERATED_TESTS_DIR: str = "generated_tests"
TEST_TIMEOUT: int = 30 # 单个测试超时时间(秒)
# HTTP测试配置
HTTP_BASE_URL: str = os.getenv("HTTP_BASE_URL", "http://localhost:8080")
HTTP_TIMEOUT: int = 10
config = Config()

View File

View File

@ -0,0 +1,609 @@
"""
覆盖率分析与缺口识别
修复
- 移除对已删除属性 request_parameters / url_parameters / response 的引用
- HTTP function 接口统一使用 iface.in_parameters / iface.out_parameters
- HTTP 接口的返回字段覆盖统一从 return_info.on_success / on_failure 读取
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from core.parser import InterfaceInfo
logger = logging.getLogger(__name__)
# ══════════════════════════════════════════════════════════════
# 数据结构
# ══════════════════════════════════════════════════════════════
@dataclass
class InterfaceCoverage:
interface_name: str
protocol: str
requirement_id: str = ""
is_covered: bool = False
# 入参覆盖in / inout
all_in_params: set[str] = field(default_factory=set)
covered_in_params: set[str] = field(default_factory=set)
# 出参断言覆盖out / inout
all_out_params: set[str] = field(default_factory=set)
asserted_out_params: set[str] = field(default_factory=set)
# on_success 返回字段断言覆盖
all_success_fields: set[str] = field(default_factory=set)
covered_success_fields: set[str] = field(default_factory=set)
# on_failure 返回字段断言覆盖
all_failure_fields: set[str] = field(default_factory=set)
covered_failure_fields: set[str] = field(default_factory=set)
# 异常断言on_failure = raises Exception
expects_exception: bool = False
exception_case_covered: bool = False
has_positive_case: bool = False
has_negative_case: bool = False
covering_test_ids: list[str] = field(default_factory=list)
# ── 覆盖率计算 ────────────────────────────────────────────
@property
def in_param_coverage_rate(self) -> float:
if not self.all_in_params:
return 1.0
return len(self.covered_in_params) / len(self.all_in_params)
@property
def out_param_coverage_rate(self) -> float:
if not self.all_out_params:
return 1.0
return len(self.asserted_out_params) / len(self.all_out_params)
@property
def success_field_coverage_rate(self) -> float:
if not self.all_success_fields:
return 1.0
return len(self.covered_success_fields) / len(self.all_success_fields)
@property
def failure_field_coverage_rate(self) -> float:
if not self.all_failure_fields:
return 1.0
return len(self.covered_failure_fields) / len(self.all_failure_fields)
@property
def uncovered_in_params(self) -> set[str]:
return self.all_in_params - self.covered_in_params
@property
def uncovered_out_params(self) -> set[str]:
return self.all_out_params - self.asserted_out_params
@property
def uncovered_success_fields(self) -> set[str]:
return self.all_success_fields - self.covered_success_fields
@property
def uncovered_failure_fields(self) -> set[str]:
return self.all_failure_fields - self.covered_failure_fields
@dataclass
class RequirementCoverage:
requirement: str
requirement_id: str = ""
covering_test_ids: list[str] = field(default_factory=list)
@property
def is_covered(self) -> bool:
return bool(self.covering_test_ids)
@dataclass
class Gap:
gap_type: str # 见下方常量
severity: str # critical | high | medium | low
target: str
detail: str
suggestion: str
# gap_type 常量
GAP_INTERFACE_NOT_COVERED = "interface_not_covered"
GAP_IN_PARAM_NOT_COVERED = "in_param_not_covered"
GAP_OUT_PARAM_NOT_ASSERTED = "out_param_not_asserted"
GAP_SUCCESS_FIELD_NOT_ASSERTED = "success_field_not_asserted"
GAP_FAILURE_FIELD_NOT_ASSERTED = "failure_field_not_asserted"
GAP_EXCEPTION_NOT_TESTED = "exception_case_not_tested"
GAP_MISSING_POSITIVE = "missing_positive_case"
GAP_MISSING_NEGATIVE = "missing_negative_case"
GAP_REQUIREMENT_NOT_COVERED = "requirement_not_covered"
GAP_TEST_FAILED = "test_failed"
GAP_TEST_ERROR = "test_error"
@dataclass
class CoverageReport:
total_interfaces: int = 0
covered_interfaces: int = 0
total_requirements: int = 0
covered_requirements: int = 0
total_test_cases: int = 0
passed_test_cases: int = 0
failed_test_cases: int = 0
error_test_cases: int = 0
interface_coverages: list[InterfaceCoverage] = field(default_factory=list)
requirement_coverages: list[RequirementCoverage] = field(default_factory=list)
gaps: list[Gap] = field(default_factory=list)
@property
def interface_coverage_rate(self) -> float:
return self.covered_interfaces / self.total_interfaces \
if self.total_interfaces else 0.0
@property
def requirement_coverage_rate(self) -> float:
return self.covered_requirements / self.total_requirements \
if self.total_requirements else 0.0
@property
def pass_rate(self) -> float:
return self.passed_test_cases / self.total_test_cases \
if self.total_test_cases else 0.0
@property
def avg_in_param_coverage_rate(self) -> float:
rates = [
ic.in_param_coverage_rate
for ic in self.interface_coverages
if ic.all_in_params
]
return sum(rates) / len(rates) if rates else 1.0
@property
def avg_success_field_coverage_rate(self) -> float:
rates = [
ic.success_field_coverage_rate
for ic in self.interface_coverages
if ic.all_success_fields
]
return sum(rates) / len(rates) if rates else 1.0
@property
def avg_failure_field_coverage_rate(self) -> float:
rates = [
ic.failure_field_coverage_rate
for ic in self.interface_coverages
if ic.all_failure_fields
]
return sum(rates) / len(rates) if rates else 1.0
@property
def critical_gap_count(self) -> int:
return sum(1 for g in self.gaps if g.severity == "critical")
@property
def high_gap_count(self) -> int:
return sum(1 for g in self.gaps if g.severity == "high")
# ══════════════════════════════════════════════════════════════
# 分析器
# ══════════════════════════════════════════════════════════════
class CoverageAnalyzer:
def __init__(
self,
interfaces: list[InterfaceInfo],
requirements: list[str],
test_cases: list[dict],
run_results: list,
):
self.interfaces = interfaces
self.requirements = requirements
self.test_cases = test_cases
self.run_results = run_results
self._iface_map = {i.name: i for i in interfaces}
self._result_map = {
getattr(r, "test_id", ""): r for r in run_results
}
# ── 入口 ──────────────────────────────────────────────────
def analyze(self) -> CoverageReport:
logger.info("Starting coverage analysis …")
report = CoverageReport()
iface_cov_map = self._init_interface_coverages()
req_cov_map = self._init_requirement_coverages()
self._scan_test_cases(iface_cov_map, req_cov_map)
self._scan_run_results(report)
report.interface_coverages = list(iface_cov_map.values())
report.requirement_coverages = list(req_cov_map.values())
report.total_interfaces = len(self.interfaces)
report.covered_interfaces = sum(
1 for ic in iface_cov_map.values() if ic.is_covered
)
report.total_requirements = len(self.requirements)
report.covered_requirements = sum(
1 for rc in req_cov_map.values() if rc.is_covered
)
report.total_test_cases = len(self.test_cases)
report.gaps = self._identify_gaps(iface_cov_map, req_cov_map)
logger.info(
f"Analysis done — "
f"interface={report.interface_coverage_rate:.0%}, "
f"requirement={report.requirement_coverage_rate:.0%}, "
f"pass={report.pass_rate:.0%}, "
f"gaps={len(report.gaps)}"
)
return report
# ══════════════════════════════════════════════════════════
# 初始化接口覆盖对象
# ══════════════════════════════════════════════════════════
def _init_interface_coverages(self) -> dict[str, InterfaceCoverage]:
result: dict[str, InterfaceCoverage] = {}
for iface in self.interfaces:
ic = InterfaceCoverage(
interface_name=iface.name,
protocol=iface.protocol,
requirement_id=iface.requirement_id,
)
# ── 入参in / inout────────────────────────────
# HTTP 与 function 统一使用 in_parameters
ic.all_in_params = {p.name for p in iface.in_parameters}
# ── 出参out / inout───────────────────────────
ic.all_out_params = {p.name for p in iface.out_parameters}
# ── 返回值字段on_success──────────────────────
ri = iface.return_info
success_fields = set(ri.on_success.value_fields.keys())
# boolean / 非 dict 类型:用虚拟字段 "return" 代表返回值本身
if ri.type in ("boolean", "integer", "string", "float") or (
ri.on_success.value is not None
and not isinstance(ri.on_success.value, dict)
):
success_fields.add("return")
ic.all_success_fields = success_fields
# ── 返回值字段on_failure──────────────────────
if ri.on_failure.is_exception:
ic.expects_exception = True
else:
failure_fields = set(ri.on_failure.value_fields.keys())
if ri.type in ("boolean", "integer", "string", "float") or (
ri.on_failure.value is not None
and not isinstance(ri.on_failure.value, dict)
and not ri.on_failure.is_exception
):
failure_fields.add("return")
ic.all_failure_fields = failure_fields
result[iface.name] = ic
return result
# ══════════════════════════════════════════════════════════
# 初始化需求覆盖对象
# ══════════════════════════════════════════════════════════
def _init_requirement_coverages(self) -> dict[str, RequirementCoverage]:
result: dict[str, RequirementCoverage] = {}
for req in self.requirements:
result[req] = RequirementCoverage(requirement=req)
# requirement_id → requirement text 的辅助映射
self._req_id_map: dict[str, str] = {}
for iface in self.interfaces:
if iface.requirement_id:
matched = self._fuzzy_match_req(iface.name, result)
if matched:
self._req_id_map[iface.requirement_id] = matched
return result
# ══════════════════════════════════════════════════════════
# 扫描测试用例
# ══════════════════════════════════════════════════════════
def _scan_test_cases(
self,
iface_cov_map: dict[str, InterfaceCoverage],
req_cov_map: dict[str, RequirementCoverage],
):
for tc in self.test_cases:
test_id = tc.get("test_id", "")
requirement = tc.get("requirement", "")
req_id = tc.get("requirement_id", "")
description = (
tc.get("description", "") + " " + tc.get("test_name", "")
).lower()
is_negative = any(
kw in description
for kw in (
"negative", "invalid", "missing", "fail", "error",
"wrong", "bad", "boundary", "负向", "exception",
)
)
# 需求覆盖匹配
matched_req = (
self._match_by_req_id(req_id, req_cov_map)
or self._match_by_text(requirement, req_cov_map)
)
if matched_req:
req_cov_map[matched_req].covering_test_ids.append(test_id)
# 遍历步骤
for step in tc.get("steps", []):
iface_name = step.get("interface_name", "")
ic = iface_cov_map.get(iface_name)
if ic is None:
continue
ic.is_covered = True
if test_id not in ic.covering_test_ids:
ic.covering_test_ids.append(test_id)
if is_negative:
ic.has_negative_case = True
else:
ic.has_positive_case = True
# 入参覆盖
for param_name in step.get("input", {}).keys():
if param_name in ic.all_in_params:
ic.covered_in_params.add(param_name)
# 断言覆盖
for assertion in step.get("assertions", []):
f = assertion.get("field", "")
operator = assertion.get("operator", "")
# 异常断言
if f == "exception" and operator == "raised":
ic.exception_case_covered = True
continue
# 出参断言
if f in ic.all_out_params:
ic.asserted_out_params.add(f)
# on_success 字段(正向用例)
if not is_negative and f in ic.all_success_fields:
ic.covered_success_fields.add(f)
# on_failure 字段(负向用例)
if is_negative and f in ic.all_failure_fields:
ic.covered_failure_fields.add(f)
# boolean / scalar return 断言
if f == "return":
if not is_negative:
ic.covered_success_fields.add("return")
else:
ic.covered_failure_fields.add("return")
# ══════════════════════════════════════════════════════════
# 扫描执行结果
# ══════════════════════════════════════════════════════════
def _scan_run_results(self, report: CoverageReport):
for r in self.run_results:
s = getattr(r, "status", "")
if s == "PASS":
report.passed_test_cases += 1
elif s == "FAIL":
report.failed_test_cases += 1
else:
report.error_test_cases += 1
# ══════════════════════════════════════════════════════════
# 缺口识别
# ══════════════════════════════════════════════════════════
def _identify_gaps(
self,
iface_cov_map: dict[str, InterfaceCoverage],
req_cov_map: dict[str, RequirementCoverage],
) -> list[Gap]:
gaps: list[Gap] = []
for ic in iface_cov_map.values():
n = ic.interface_name
rid = f"[{ic.requirement_id}] " if ic.requirement_id else ""
# ① 接口完全未覆盖
if not ic.is_covered:
gaps.append(Gap(
gap_type=GAP_INTERFACE_NOT_COVERED,
severity="critical", target=n,
detail=f"{rid}'{n}' has NO test case.",
suggestion=f"Add positive + negative test cases for '{n}'.",
))
continue # 后续缺口依赖覆盖,跳过
# ② 缺少正向用例
if not ic.has_positive_case:
gaps.append(Gap(
gap_type=GAP_MISSING_POSITIVE,
severity="critical", target=n,
detail=f"{rid}'{n}' has no positive (happy-path) test.",
suggestion=f"Add a positive test case with valid inputs for '{n}'.",
))
# ③ 缺少负向用例
if not ic.has_negative_case:
gaps.append(Gap(
gap_type=GAP_MISSING_NEGATIVE,
severity="high", target=n,
detail=f"{rid}'{n}' has no negative test case.",
suggestion=(
f"Add negative tests for '{n}': "
f"missing required params, invalid types, boundary values."
),
))
# ④ 入参未覆盖
for param in sorted(ic.uncovered_in_params):
gaps.append(Gap(
gap_type=GAP_IN_PARAM_NOT_COVERED,
severity="high", target=n,
detail=f"{rid}Input param '{param}' of '{n}' never used in tests.",
suggestion=f"Add a test that explicitly passes '{param}' to '{n}'.",
))
# ⑤ 出参未断言
for param in sorted(ic.uncovered_out_params):
gaps.append(Gap(
gap_type=GAP_OUT_PARAM_NOT_ASSERTED,
severity="high", target=n,
detail=f"{rid}Output param '{param}' of '{n}' never asserted.",
suggestion=(
f"Add an assertion on output param '{param}' "
f"after calling '{n}'."
),
))
# ⑥ on_success 返回字段未断言
for f in sorted(ic.uncovered_success_fields):
gaps.append(Gap(
gap_type=GAP_SUCCESS_FIELD_NOT_ASSERTED,
severity="high", target=n,
detail=(
f"{rid}on_success field '{f}' of '{n}' "
f"never asserted in positive tests."
),
suggestion=(
f"Add assertion on '{f}' in the positive "
f"test step for '{n}'."
),
))
# ⑦ on_failure 返回字段未断言
for f in sorted(ic.uncovered_failure_fields):
gaps.append(Gap(
gap_type=GAP_FAILURE_FIELD_NOT_ASSERTED,
severity="medium", target=n,
detail=(
f"{rid}on_failure field '{f}' of '{n}' "
f"never asserted in negative tests."
),
suggestion=(
f"Add assertion on '{f}' in the negative "
f"test step for '{n}'."
),
))
# ⑧ 异常场景未测试
if ic.expects_exception and not ic.exception_case_covered:
gaps.append(Gap(
gap_type=GAP_EXCEPTION_NOT_TESTED,
severity="high", target=n,
detail=(
f"{rid}'{n}' declares on_failure = raises Exception, "
f"but no test asserts that the exception is raised."
),
suggestion=(
f"Add a negative test for '{n}' that wraps the call "
f"in try/except and asserts the exception is raised."
),
))
# ⑨ 需求未覆盖
for rc in req_cov_map.values():
if not rc.is_covered:
gaps.append(Gap(
gap_type=GAP_REQUIREMENT_NOT_COVERED,
severity="critical",
target=rc.requirement,
detail=f"Requirement '{rc.requirement}' has no test case.",
suggestion=f"Generate test cases for: '{rc.requirement}'.",
))
# ⑩ 执行失败 / 错误
for r in self.run_results:
s = getattr(r, "status", "")
if s == "FAIL":
gaps.append(Gap(
gap_type=GAP_TEST_FAILED,
severity="high",
target=getattr(r, "test_id", ""),
detail=f"Test '{r.test_id}' FAILED: {r.message}",
suggestion=(
"Investigate failure; fix implementation or test data."
),
))
elif s in ("ERROR", "TIMEOUT"):
gaps.append(Gap(
gap_type=GAP_TEST_ERROR,
severity="medium",
target=getattr(r, "test_id", ""),
detail=f"Test '{r.test_id}' {s}: {r.message}",
suggestion=(
"Check script for runtime errors or increase TEST_TIMEOUT."
),
))
# 按严重程度排序
_order = {"critical": 0, "high": 1, "medium": 2, "low": 3}
gaps.sort(key=lambda g: _order.get(g.severity, 9))
return gaps
# ══════════════════════════════════════════════════════════
# 工具方法
# ══════════════════════════════════════════════════════════
def _match_by_req_id(
self,
req_id: str,
req_cov_map: dict[str, RequirementCoverage],
) -> str | None:
if not req_id:
return None
mapped = self._req_id_map.get(req_id)
if mapped and mapped in req_cov_map:
return mapped
return None
def _match_by_text(
self,
requirement: str,
req_cov_map: dict[str, RequirementCoverage],
) -> str | None:
if requirement in req_cov_map:
return requirement
req_lower = requirement.lower()
for key in req_cov_map:
if key.lower() in req_lower or req_lower in key.lower():
return key
return None
def _fuzzy_match_req(
self,
iface_name: str,
req_cov_map: dict[str, RequirementCoverage],
) -> str | None:
name_lower = iface_name.lower().replace("_", " ")
for key in req_cov_map:
if name_lower in key.lower() or key.lower() in name_lower:
return key
return None

View File

@ -0,0 +1,58 @@
import json
import re
import logging
from openai import OpenAI
from config import config
logger = logging.getLogger(__name__)
class LLMClient:
"""封装 LLM API 调用OpenAI 兼容接口)"""
def __init__(self):
self.client = OpenAI(
api_key=config.LLM_API_KEY,
base_url=config.LLM_BASE_URL,
)
def generate_test_cases(self, system_prompt: str, user_prompt: str) -> list[dict]:
logger.info(f"Calling LLM: model={config.LLM_MODEL}")
logger.debug(f"--- USER PROMPT ---\n{user_prompt}\n---")
response = self.client.chat.completions.create(
model=config.LLM_MODEL,
temperature=config.LLM_TEMPERATURE,
max_tokens=config.LLM_MAX_TOKENS,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
)
raw = response.choices[0].message.content
logger.debug(f"--- LLM RAW RESPONSE ---\n{raw[:800]}\n---")
return self._parse_json(raw)
# ── 解析 ──────────────────────────────────────────────────
def _parse_json(self, content: str) -> list[dict]:
content = content.strip()
# 去除可能的 markdown 代码块
content = re.sub(r"^```(?:json)?\s*", "", content, flags=re.MULTILINE)
content = re.sub(r"\s*```\s*$", "", content, flags=re.MULTILINE)
content = content.strip()
try:
data = json.loads(content)
except json.JSONDecodeError:
# 尝试提取第一个 JSON 数组
match = re.search(r"\[.*\]", content, re.DOTALL)
if match:
data = json.loads(match.group())
else:
logger.error(f"Cannot parse LLM response as JSON:\n{content[:600]}")
raise
if not isinstance(data, list):
raise ValueError(f"Expected JSON array, got {type(data)}")
return data

View File

@ -0,0 +1,388 @@
"""
接口描述解析器
支持格式当前版
顶层字段
- project : 项目名称用于命名输出目录
- description : 项目描述
- units : 接口/函数描述列表
接口公共字段
- name : 接口/函数名称
- requirement_id : 需求编号
- description : 接口描述
- type : "function" | "http"
- url : function Python 源文件路径 "create_user.py"
http base URL "http://127.0.0.1/api"
- parameters : 统一参数字段HTTP function 两种协议共用
HTTP 接口不再使用 url_parameters / request_parameters
- return : 返回值描述 on_success / on_failure
HTTP 专用字段
- method : "get" | "post" | "put" | "delete"
- 完整请求 URL = url.rstrip("/") + "/" + name.lstrip("/")
Function 专用
- url .py 文件路径转换为 Python 模块路径供 import 使用
- 函数名即 name 字段
"""
from __future__ import annotations
import json
import re
from typing import Any
from dataclasses import dataclass, field
# ══════════════════════════════════════════════════════════════
# 基础数据结构
# ══════════════════════════════════════════════════════════════
@dataclass
class ParameterInfo:
name: str
type: str # 支持复合类型 "string|integer"
description: str
required: bool = True
default: Any = None
inout: str = "in" # "in" | "out" | "inout"
@dataclass
class ReturnCase:
"""on_success 或 on_failure 的返回描述"""
value: Any = None
description: str = ""
@property
def is_exception(self) -> bool:
"""判断是否为抛出异常的场景"""
return (
isinstance(self.value, str)
and "exception" in self.value.lower()
)
@property
def value_fields(self) -> dict[str, Any]:
"""若 value 是 dict返回其字段否则返回空 dict"""
return self.value if isinstance(self.value, dict) else {}
@property
def value_summary(self) -> str:
"""返回值的简短文字描述,用于 LLM 提示词"""
if isinstance(self.value, dict):
return json.dumps(self.value, ensure_ascii=False)
if isinstance(self.value, bool):
return str(self.value).lower()
return str(self.value) if self.value is not None else ""
@dataclass
class ReturnInfo:
type: str = ""
description: str = ""
on_success: ReturnCase = field(default_factory=ReturnCase)
on_failure: ReturnCase = field(default_factory=ReturnCase)
@property
def all_fields(self) -> set[str]:
"""汇总 on_success + on_failure 中出现的所有字段名"""
fields: set[str] = set()
for case in (self.on_success, self.on_failure):
fields.update(case.value_fields.keys())
return fields
# ══════════════════════════════════════════════════════════════
# 接口描述
# ══════════════════════════════════════════════════════════════
@dataclass
class InterfaceInfo:
name: str
description: str
protocol: str # "http" | "function"
url: str = "" # 原始 url 字段
requirement_id: str = ""
# HTTP 专用
method: str = "get"
# 统一参数列表HTTP 与 function 两种协议共用,均来自 "parameters" 字段)
parameters: list[ParameterInfo] = field(default_factory=list)
# 返回值描述
return_info: ReturnInfo = field(default_factory=ReturnInfo)
# ── 便捷属性:参数分组 ────────────────────────────────────
@property
def in_parameters(self) -> list[ParameterInfo]:
"""inout = "in""inout" 的参数(作为输入)"""
return [p for p in self.parameters if p.inout in ("in", "inout")]
@property
def out_parameters(self) -> list[ParameterInfo]:
"""inout = "out""inout" 的参数(作为输出)"""
return [p for p in self.parameters if p.inout in ("out", "inout")]
# ── 便捷属性HTTP ────────────────────────────────────────
@property
def http_full_url(self) -> str:
"""
HTTP 协议的完整请求 URL
url (base) + name (path)
url="http://127.0.0.1/api", name="/delete_user"
"http://127.0.0.1/api/delete_user"
"""
if self.protocol != "http":
return ""
base = self.url.rstrip("/")
path = self.name.lstrip("/")
return f"{base}/{path}" if base else f"/{path}"
# ── 便捷属性Function ────────────────────────────────────
@property
def module_path(self) -> str:
"""
url.py 文件路径转换为 Python 模块导入路径
规则
1. 去除 .py 后缀
2. 去除前导 ./ /
3. 将路径分隔符替换为 .
示例
"create_user.py" "create_user"
"myapp/services/user_service.py" "myapp.services.user_service"
"./services/user_service.py" "services.user_service"
"myapp.services.user_service" "myapp.services.user_service"
"""
if self.protocol != "function" or not self.url:
return ""
u = self.url.strip()
# 已经是点分模块路径(无路径分隔符且无 .py 后缀)
if re.match(r'^[\w]+(\.[\w]+)*$', u) and not u.endswith(".py"):
return u
# 文件路径 → 模块路径
u = u.replace("\\", "/")
u = re.sub(r'\.py$', '', u)
u = re.sub(r'^\.?/', '', u)
return u.replace("/", ".")
@property
def source_file(self) -> str:
"""function 协议的源文件路径(原始 url 字段)"""
return self.url if self.protocol == "function" else ""
# ══════════════════════════════════════════════════════════════
# 描述文件根结构
# ══════════════════════════════════════════════════════════════
@dataclass
class ApiDescriptor:
project: str
description: str
interfaces: list[InterfaceInfo]
# ══════════════════════════════════════════════════════════════
# 解析器
# ══════════════════════════════════════════════════════════════
class InterfaceParser:
# ── 公开接口 ──────────────────────────────────────────────
def parse_file(self, file_path: str) -> ApiDescriptor:
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
return self.parse(data)
def parse(self, data: dict | list) -> ApiDescriptor:
"""
兼容三种格式
新格式 {"project": "...", "description": "...", "units": [...]}
旧格式1{"project": "...", "units": [...]}
旧格式2[...] 直接是接口数组
"""
if isinstance(data, list):
return ApiDescriptor(
project="default",
description="",
interfaces=self._parse_units(data),
)
return ApiDescriptor(
project=data.get("project", "default"),
description=data.get("description", ""),
interfaces=self._parse_units(data.get("units", [])),
)
# ── 单元解析 ──────────────────────────────────────────────
def _parse_units(self, units: list[dict]) -> list[InterfaceInfo]:
result = []
for item in units:
protocol = (
item.get("protocol") or item.get("type") or "function"
).lower()
if protocol == "http":
result.append(self._parse_http(item))
else:
result.append(self._parse_function(item))
return result
def _parse_http(self, item: dict) -> InterfaceInfo:
"""
HTTP 接口解析
参数统一从 "parameters" 字段读取
同时兼容旧格式的 "request_parameters" / "url_parameters"
优先级parameters > request_parameters + url_parameters
"""
# 优先使用新统一字段 "parameters"
raw_params: dict = item.get("parameters", {})
# 旧格式兼容:合并 url_parameters + request_parameters
if not raw_params:
raw_params = {
**item.get("url_parameters", {}),
**item.get("request_parameters", {}),
}
return InterfaceInfo(
name=item["name"],
description=item.get("description", ""),
protocol="http",
url=item.get("url", ""),
requirement_id=item.get("requirement_id", ""),
method=item.get("method", "get").lower(),
parameters=self._parse_param_dict(raw_params, with_inout=True),
return_info=self._parse_return(item.get("return", {})),
)
def _parse_function(self, item: dict) -> InterfaceInfo:
return InterfaceInfo(
name=item["name"],
description=item.get("description", ""),
protocol="function",
url=item.get("url", ""),
requirement_id=item.get("requirement_id", ""),
parameters=self._parse_param_dict(
item.get("parameters", {}), with_inout=True,
),
return_info=self._parse_return(item.get("return", {})),
)
# ── 参数解析 ──────────────────────────────────────────────
def _parse_param_dict(
self,
params: dict,
with_inout: bool = False,
) -> list[ParameterInfo]:
result = []
for name, info in params.items():
result.append(ParameterInfo(
name=name,
type=info.get("type", "string"),
description=info.get("description", ""),
required=info.get("required", True),
default=info.get("default", None),
inout=info.get("inout", "in") if with_inout else "in",
))
return result
# ── 返回值解析 ────────────────────────────────────────────
def _parse_return(self, ret: dict) -> ReturnInfo:
if not ret:
return ReturnInfo()
return ReturnInfo(
type=ret.get("type", ""),
description=ret.get("description", ""),
on_success=self._parse_return_case(ret.get("on_success", {})),
on_failure=self._parse_return_case(ret.get("on_failure", {})),
)
def _parse_return_case(self, case: dict) -> ReturnCase:
if not case:
return ReturnCase()
return ReturnCase(
value=case.get("value"),
description=case.get("description", ""),
)
# ══════════════════════════════════════════════════════════
# 转换为 LLM 可读摘要
# ══════════════════════════════════════════════════════════
def to_summary_dict(self, interfaces: list[InterfaceInfo]) -> list[dict]:
return [
self._http_summary(i) if i.protocol == "http"
else self._function_summary(i)
for i in interfaces
]
def _http_summary(self, iface: InterfaceInfo) -> dict:
ri = iface.return_info
return {
"name": iface.name,
"requirement_id": iface.requirement_id,
"description": iface.description,
"protocol": "http",
"full_url": iface.http_full_url,
"method": iface.method,
# HTTP 接口统一用 "parameters" 输出给 LLM
"parameters": self._params_to_dict(iface.parameters),
"return": {
"type": ri.type,
"description": ri.description,
"on_success": {
"value": ri.on_success.value,
"description": ri.on_success.description,
},
"on_failure": {
"value": ri.on_failure.value,
"description": ri.on_failure.description,
},
},
}
def _function_summary(self, iface: InterfaceInfo) -> dict:
ri = iface.return_info
return {
"name": iface.name,
"requirement_id": iface.requirement_id,
"description": iface.description,
"protocol": "function",
"source_file": iface.source_file,
"module_path": iface.module_path,
"parameters": self._params_to_dict(iface.parameters),
"return": {
"type": ri.type,
"description": ri.description,
"on_success": {
"value": ri.on_success.value,
"description": ri.on_success.description,
},
"on_failure": {
"value": ri.on_failure.value,
"description": ri.on_failure.description,
},
},
}
def _params_to_dict(self, params: list[ParameterInfo]) -> dict:
result = {}
for p in params:
entry: dict = {
"type": p.type,
"inout": p.inout,
"description": p.description,
"required": p.required,
}
if p.default is not None:
entry["default"] = p.default
result[p.name] = entry
return result

View File

@ -0,0 +1,167 @@
"""
提示词构建器
升级点
- 传递项目 description
- function 协议source_file(.py) module_path from <module> import <func>
- HTTP 协议full_url = url(base) + name(path)
- parameters 统一字段inout 区分输入/输出
"""
import json
from core.parser import InterfaceInfo, InterfaceParser
SYSTEM_PROMPT = """
You are a senior software test engineer. Generate test cases based on the provided
interface descriptions and test requirements.
OUTPUT FORMAT
Return ONLY a valid JSON array. No markdown fences. No extra text.
Each element = ONE test case:
{
"test_id" : "unique id, e.g. TC_001",
"test_name" : "short descriptive name",
"description" : "what this test verifies",
"requirement" : "the original requirement text this case maps to",
"requirement_id" : "e.g. REQ.01",
"steps": [
{
"step_no" : 1,
"interface_name" : "function or endpoint name",
"protocol" : "http or function",
"url" : "source_file (function) or full_url (http)",
"purpose" : "why this step is needed",
"input": {
// Only parameters with inout = "in" or "inout"
"param_name": <value>
},
"use_output_of": { // optional
"step_no" : 1,
"field" : "user_id",
"as_param" : "user_id"
},
"assertions": [
{
"field" : "field_name or 'return' or 'exception'",
"operator" : "eq|ne|gt|lt|gte|lte|in|not_null|contains|raised|not_raised",
"expected" : <value>,
"message" : "human readable description"
}
]
}
],
"test_data_notes" : "explanation of auto-generated test data",
"test_code" : "<complete Python script — see rules below>"
}
TEST CODE RULES
1. Complete, runnable Python script. No external test framework.
2. Allowed imports: standard library + `requests` + the actual module under test.
FUNCTION INTERFACES
3. Each function interface has:
"source_file" : e.g. "create_user.py"
"module_path" : e.g. "create_user" (derived from source_file)
4. Import the REAL function using module_path:
try:
from <module_path> import <function_name>
except ImportError:
# Stub fallback — simulate on_success for positive tests,
# on_failure / raise Exception for negative tests
def <function_name>(<in_params>):
return <on_success_value> # or raise Exception(...)
5. Call the function with ONLY "in"/"inout" parameters.
6. Capture the return value; assert against on_success or on_failure structure.
7. If on_failure.value contains "exception" or "raises":
- In negative tests: wrap call in try/except, assert exception IS raised.
- Use field="exception", operator="raised", expected="Exception".
HTTP INTERFACES
8. Each HTTP interface has:
"full_url" : complete URL, e.g. "http://127.0.0.1/api/delete_user"
"method" : "get" | "post" | "put" | "delete" etc.
9. Send request:
resp = requests.<method>("<full_url>", json=<input_dict>)
10. Assert on resp.status_code and resp.json() fields.
MULTI-STEP
11. Execute steps in order; extract fields from previous step's return value
and pass to next step via use_output_of mapping.
STRUCTURED OUTPUT (REQUIRED)
12. After EACH step, print:
##STEP_RESULT## {"step_no":<n>,"interface_name":"...","status":"PASS|FAIL",
"assertions":[{"field":"...","operator":"...","expected":...,
"actual":...,"passed":true|false,"message":"..."}]}
13. Final line must be:
PASS: <summary> or FAIL: <summary>
14. Wrap entire test body in try/except; print FAIL on any unhandled exception.
15. Do NOT use unittest or pytest.
ASSERTION RULES BY RETURN TYPE
return.type = "dict" assert each key in on_success.value (positive)
or on_failure.value (negative)
return.type = "boolean" field="return", operator="eq", expected=true/false
return.type = "integer" field="return", operator="eq"/"gt"/"lt"
on_failure = exception field="exception", operator="raised"
COVERAGE GUIDELINES
- Per requirement: at least 1 positive + 1 negative test case.
- Negative test_name/description MUST contain "negative" or "invalid".
- Multi-interface requirements: single multi-step test case.
- Cover ALL "in"/"inout" parameters (including optional ones).
- Assert ALL fields described in on_success (positive) and on_failure (negative).
- For "out" parameters: verify their values in assertions after the call.
"""
class PromptBuilder:
def get_system_prompt(self) -> str:
return SYSTEM_PROMPT.strip()
def build_user_prompt(
self,
requirements: list[str],
interfaces: list[InterfaceInfo],
parser: InterfaceParser,
project: str = "",
project_desc: str = "",
) -> str:
iface_summary = parser.to_summary_dict(interfaces)
req_lines = "\n".join(
f"{i + 1}. {r}" for i, r in enumerate(requirements)
)
# 项目信息头
project_section = ""
if project or project_desc:
project_section = "## Project\n"
if project:
project_section += f"Name : {project}\n"
if project_desc:
project_section += f"Description: {project_desc}\n"
project_section += "\n"
return (
f"{project_section}"
f"## Available Interfaces\n"
f"{json.dumps(iface_summary, ensure_ascii=False, indent=2)}\n\n"
f"## Test Requirements\n"
f"{req_lines}\n\n"
f"Generate comprehensive test cases (positive + negative) for every requirement.\n"
f"- For function interfaces: use 'module_path' for import, "
f"'source_file' for reference.\n"
f"- For HTTP interfaces: use 'full_url' as the request URL.\n"
f"- Use requirement_id from the interface to populate the test case's "
f"requirement_id field.\n"
f"- Chain multiple interface calls in one test case when a requirement "
f"involves more than one interface.\n"
).strip()

View File

@ -0,0 +1,492 @@
"""
报告生成器JSON 摘要 + HTML 可视化报告
升级点
- 新增 requirement_id
- 新增 out 参数覆盖率on_success/on_failure 字段覆盖率
- 缺口类型标签中文化扩充
"""
from __future__ import annotations
import json
import logging
from datetime import datetime
from core.analyzer import CoverageReport, Gap
logger = logging.getLogger(__name__)
class ReportGenerator:
def save_json(self, report: CoverageReport, path: str):
with open(path, "w", encoding="utf-8") as f:
json.dump(self._to_dict(report), f, ensure_ascii=False, indent=2)
logger.info(f"Coverage JSON → {path}")
def save_html(self, report: CoverageReport, path: str):
with open(path, "w", encoding="utf-8") as f:
f.write(self._build_html(report))
logger.info(f"Coverage HTML → {path}")
# ── JSON 序列化 ───────────────────────────────────────────
def _to_dict(self, r: CoverageReport) -> dict:
return {
"generated_at": datetime.now().isoformat(timespec="seconds"),
"summary": {
"interface_coverage": f"{r.interface_coverage_rate:.1%}",
"requirement_coverage": f"{r.requirement_coverage_rate:.1%}",
"avg_in_param_coverage": f"{r.avg_in_param_coverage_rate:.1%}",
"avg_success_field_coverage": f"{r.avg_success_field_coverage_rate:.1%}",
"avg_failure_field_coverage": f"{r.avg_failure_field_coverage_rate:.1%}",
"pass_rate": f"{r.pass_rate:.1%}",
"total_interfaces": r.total_interfaces,
"covered_interfaces": r.covered_interfaces,
"total_requirements": r.total_requirements,
"covered_requirements": r.covered_requirements,
"total_test_cases": r.total_test_cases,
"passed": r.passed_test_cases,
"failed": r.failed_test_cases,
"errors": r.error_test_cases,
"critical_gaps": r.critical_gap_count,
"high_gaps": r.high_gap_count,
},
"interface_details": [
{
"name": ic.interface_name,
"requirement_id": ic.requirement_id,
"protocol": ic.protocol,
"is_covered": ic.is_covered,
"has_positive_case": ic.has_positive_case,
"has_negative_case": ic.has_negative_case,
"in_param_coverage": f"{ic.in_param_coverage_rate:.1%}",
"covered_in_params": sorted(ic.covered_in_params),
"uncovered_in_params": sorted(ic.uncovered_in_params),
"out_param_coverage": f"{ic.out_param_coverage_rate:.1%}",
"covered_out_params": sorted(ic.asserted_out_params),
"uncovered_out_params": sorted(ic.uncovered_out_params),
"success_field_coverage": f"{ic.success_field_coverage_rate:.1%}",
"covered_success_fields": sorted(ic.covered_success_fields),
"uncovered_success_fields": sorted(ic.uncovered_success_fields),
"failure_field_coverage": f"{ic.failure_field_coverage_rate:.1%}",
"covered_failure_fields": sorted(ic.covered_failure_fields),
"uncovered_failure_fields": sorted(ic.uncovered_failure_fields),
"expects_exception": ic.expects_exception,
"exception_case_covered": (
ic.exception_case_covered if ic.expects_exception else None
),
"covering_test_ids": ic.covering_test_ids,
}
for ic in r.interface_coverages
],
"requirement_details": [
{
"requirement": rc.requirement,
"requirement_id": rc.requirement_id,
"is_covered": rc.is_covered,
"covering_test_ids": rc.covering_test_ids,
}
for rc in r.requirement_coverages
],
"gaps": [
{
"gap_type": g.gap_type,
"severity": g.severity,
"target": g.target,
"detail": g.detail,
"suggestion": g.suggestion,
}
for g in r.gaps
],
}
# ══════════════════════════════════════════════════════════
# HTML 报告
# ══════════════════════════════════════════════════════════
_GAP_LABELS: dict[str, str] = {
"interface_not_covered": "接口未覆盖",
"in_param_not_covered": "入参未覆盖",
"out_param_not_asserted": "出参未断言",
"success_field_not_asserted": "成功返回字段未断言",
"failure_field_not_asserted": "失败返回字段未断言",
"exception_case_not_tested": "异常场景未测试",
"missing_positive_case": "缺少正向用例",
"missing_negative_case": "缺少负向用例",
"requirement_not_covered": "需求未覆盖",
"test_failed": "用例执行失败",
"test_error": "用例执行异常",
}
# ── 样式辅助 ──────────────────────────────────────────────
@staticmethod
def _rate_color(rate: float) -> str:
if rate >= 0.8: return "#27ae60"
if rate >= 0.5: return "#f39c12"
return "#e74c3c"
@staticmethod
def _sev_badge(sev: str) -> str:
colors = {
"critical": "#e74c3c", "high": "#e67e22",
"medium": "#f1c40f", "low": "#3498db",
}
c = colors.get(sev, "#95a5a6")
return (
f'<span style="background:{c};color:#fff;padding:2px 8px;'
f'border-radius:4px;font-size:11px;font-weight:600">'
f'{sev.upper()}</span>'
)
@staticmethod
def _bool_badge(val: bool) -> str:
if val:
return '<span style="color:#27ae60;font-size:15px;font-weight:700">✔</span>'
return '<span style="color:#e74c3c;font-size:15px;font-weight:700">✘</span>'
def _pct_bar(self, rate: float, w: int = 120) -> str:
color = self._rate_color(rate)
filled = int(w * rate)
return (
f'<div style="display:inline-flex;align-items:center;gap:5px">'
f'<div style="width:{w}px;background:#ecf0f1;border-radius:3px;height:11px">'
f'<div style="width:{filled}px;background:{color};'
f'height:11px;border-radius:3px"></div></div>'
f'<span style="font-size:12px;font-weight:600;color:{color}">'
f'{rate * 100:.0f}%</span></div>'
)
def _metric_card(
self, label: str, value: str, sub: str = "", color: str = "#2c3e50"
) -> str:
return (
f'<div style="background:#fff;border-radius:10px;padding:16px 22px;'
f'box-shadow:0 2px 10px rgba(0,0,0,.08);min-width:148px;text-align:center">'
f'<div style="font-size:26px;font-weight:700;color:{color}">{value}</div>'
f'<div style="font-size:12px;color:#7f8c8d;margin-top:4px">{label}</div>'
f'<div style="font-size:11px;color:#bdc3c7;margin-top:2px">{sub}</div>'
f'</div>'
)
# ── 汇总卡片 ──────────────────────────────────────────────
def _build_cards(self, r: CoverageReport) -> str:
rc = self._rate_color
return "".join([
self._metric_card(
"接口覆盖率", f"{r.interface_coverage_rate:.0%}",
f"{r.covered_interfaces}/{r.total_interfaces}",
rc(r.interface_coverage_rate),
),
self._metric_card(
"需求覆盖率", f"{r.requirement_coverage_rate:.0%}",
f"{r.covered_requirements}/{r.total_requirements}",
rc(r.requirement_coverage_rate),
),
self._metric_card(
"入参覆盖率", f"{r.avg_in_param_coverage_rate:.0%}",
"avg in-params",
rc(r.avg_in_param_coverage_rate),
),
self._metric_card(
"成功返回覆盖", f"{r.avg_success_field_coverage_rate:.0%}",
"on_success fields",
rc(r.avg_success_field_coverage_rate),
),
self._metric_card(
"失败返回覆盖", f"{r.avg_failure_field_coverage_rate:.0%}",
"on_failure fields",
rc(r.avg_failure_field_coverage_rate),
),
self._metric_card(
"用例通过率", f"{r.pass_rate:.0%}",
f"{r.passed_test_cases}/{r.total_test_cases}",
rc(r.pass_rate),
),
self._metric_card(
"Critical 缺口", str(r.critical_gap_count),
f"High: {r.high_gap_count}",
"#e74c3c" if r.critical_gap_count > 0 else "#27ae60",
),
])
# ── 接口覆盖表 ────────────────────────────────────────────
def _build_interface_table(self, r: CoverageReport) -> str:
rows = ""
for ic in r.interface_coverages:
proto_color = "#3498db" if ic.protocol == "http" else "#9b59b6"
# 异常测试单元格
if ic.expects_exception:
exc_cell = self._bool_badge(ic.exception_case_covered)
else:
exc_cell = '<span style="color:#bdc3c7;font-size:13px">N/A</span>'
# 未覆盖项 tooltip
def missing_tip(items: set[str], label: str) -> str:
if not items:
return ""
joined = ", ".join(sorted(items))
return (
f'<div style="font-size:10px;color:#e74c3c;margin-top:2px">'
f'缺: {joined}</div>'
)
rows += f"""
<tr>
<td>
<code style="font-size:12px">{ic.interface_name}</code>
</td>
<td style="font-size:11px;color:#7f8c8d;white-space:nowrap">
{ic.requirement_id or ""}
</td>
<td>
<span style="background:{proto_color};color:#fff;padding:1px 7px;
border-radius:3px;font-size:11px">{ic.protocol.upper()}</span>
</td>
<td style="text-align:center">{self._bool_badge(ic.is_covered)}</td>
<td style="text-align:center">{self._bool_badge(ic.has_positive_case)}</td>
<td style="text-align:center">{self._bool_badge(ic.has_negative_case)}</td>
<td>
{self._pct_bar(ic.in_param_coverage_rate)}
{missing_tip(ic.uncovered_in_params, "")}
</td>
<td>
{self._pct_bar(ic.out_param_coverage_rate)}
{missing_tip(ic.uncovered_out_params, "")}
</td>
<td>
{self._pct_bar(ic.success_field_coverage_rate)}
{missing_tip(ic.uncovered_success_fields, "")}
</td>
<td>
{self._pct_bar(ic.failure_field_coverage_rate)}
{missing_tip(ic.uncovered_failure_fields, "")}
</td>
<td style="text-align:center">{exc_cell}</td>
<td style="font-size:11px;color:#7f8c8d">
{", ".join(ic.covering_test_ids) or ""}
</td>
</tr>"""
return rows
# ── 需求覆盖表 ────────────────────────────────────────────
def _build_requirement_table(self, r: CoverageReport) -> str:
rows = ""
for rc in r.requirement_coverages:
rows += f"""
<tr>
<td style="font-size:11px;color:#7f8c8d;white-space:nowrap">
{rc.requirement_id or ""}
</td>
<td style="max-width:320px;font-size:13px">{rc.requirement}</td>
<td style="text-align:center">{self._bool_badge(rc.is_covered)}</td>
<td style="font-size:11px;color:#7f8c8d">
{", ".join(rc.covering_test_ids) or ""}
</td>
</tr>"""
return rows
# ── 缺口清单表 ────────────────────────────────────────────
def _build_gap_table(self, r: CoverageReport) -> str:
if not r.gaps:
return (
'<tr><td colspan="6" style="text-align:center;color:#27ae60;'
'padding:24px;font-size:14px">🎉 No gaps found! Full coverage achieved.</td></tr>'
)
rows = ""
for i, g in enumerate(r.gaps, 1):
rows += f"""
<tr>
<td style="text-align:center;color:#7f8c8d;font-size:12px">{i}</td>
<td style="white-space:nowrap">{self._sev_badge(g.severity)}</td>
<td>
<span style="font-size:11px;background:#ecf0f1;padding:2px 6px;
border-radius:3px;white-space:nowrap">
{self._GAP_LABELS.get(g.gap_type, g.gap_type)}
</span>
</td>
<td><code style="font-size:11px">{g.target}</code></td>
<td style="font-size:12px;max-width:280px">{g.detail}</td>
<td style="font-size:12px;color:#2980b9;max-width:280px">{g.suggestion}</td>
</tr>"""
return rows
# ── 缺口统计图(纯 CSS 横向柱状图)────────────────────────
def _build_gap_chart(self, r: CoverageReport) -> str:
from collections import Counter
type_counts = Counter(g.gap_type for g in r.gaps)
if not type_counts:
return ""
max_count = max(type_counts.values(), default=1)
bars = ""
sev_color_map: dict[str, str] = {}
for g in r.gaps:
sev_color_map[g.gap_type] = {
"critical": "#e74c3c", "high": "#e67e22",
"medium": "#f1c40f", "low": "#3498db",
}.get(g.severity, "#95a5a6")
for gap_type, count in sorted(type_counts.items(), key=lambda x: -x[1]):
label = self._GAP_LABELS.get(gap_type, gap_type)
color = sev_color_map.get(gap_type, "#95a5a6")
width = int(count / max_count * 320)
bars += f"""
<div style="display:flex;align-items:center;gap:10px;margin-bottom:8px">
<div style="width:160px;font-size:12px;color:#555;text-align:right;
white-space:nowrap;overflow:hidden;text-overflow:ellipsis"
title="{label}">{label}</div>
<div style="width:{width}px;height:18px;background:{color};
border-radius:3px;transition:width .3s"></div>
<span style="font-size:12px;font-weight:600;color:{color}">{count}</span>
</div>"""
return f"""
<div style="background:#fff;border-radius:10px;padding:20px 24px;
box-shadow:0 2px 8px rgba(0,0,0,.07);margin-bottom:8px">
<div style="font-size:14px;font-weight:600;margin-bottom:16px;color:#2c3e50">
缺口类型分布
</div>
{bars}
</div>"""
# ── 组装完整 HTML ─────────────────────────────────────────
def _build_html(self, r: CoverageReport) -> str:
ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
css = """
* { box-sizing: border-box; margin: 0; padding: 0; }
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
background: #f4f6f9; color: #2c3e50;
}
.wrap { max-width: 1440px; margin: 0 auto; padding: 32px 20px; }
h1 { font-size: 24px; font-weight: 700; margin-bottom: 4px; }
.sub { color: #7f8c8d; font-size: 12px; margin-bottom: 26px; }
.cards { display: flex; flex-wrap: wrap; gap: 14px; margin-bottom: 32px; }
h2 {
font-size: 16px; font-weight: 600; margin: 28px 0 12px;
padding-left: 10px; border-left: 4px solid #3498db;
}
table {
width: 100%; border-collapse: collapse; background: #fff;
border-radius: 10px; overflow: hidden;
box-shadow: 0 2px 8px rgba(0,0,0,.07);
}
th {
background: #2c3e50; color: #fff; padding: 10px 12px;
font-size: 12px; text-align: left; white-space: nowrap;
}
td {
padding: 9px 12px; font-size: 12px;
border-bottom: 1px solid #ecf0f1; vertical-align: middle;
}
tr:last-child td { border-bottom: none; }
tr:hover td { background: #f8f9fa; }
code {
background: #ecf0f1; padding: 1px 4px; border-radius: 3px;
font-family: "SFMono-Regular", Consolas, monospace;
}
.footer {
text-align: center; color: #bdc3c7;
font-size: 11px; margin-top: 36px; padding-bottom: 20px;
}
"""
return f"""<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>AI Test Coverage Report</title>
<style>{css}</style>
</head>
<body>
<div class="wrap">
<!-- 标题 -->
<h1>🧪 AI Test Coverage Report</h1>
<div class="sub">Generated at {ts} &nbsp;|&nbsp;
Interfaces: {r.total_interfaces} &nbsp;|&nbsp;
Test Cases: {r.total_test_cases} &nbsp;|&nbsp;
Gaps: {len(r.gaps)}
</div>
<!-- 汇总卡片 -->
<div class="cards">{self._build_cards(r)}</div>
<!-- 缺口分布图 -->
{self._build_gap_chart(r)}
<!-- 接口覆盖详情 -->
<h2>📡 接口覆盖详情</h2>
<div style="overflow-x:auto">
<table>
<thead>
<tr>
<th>接口名称</th>
<th>REQ ID</th>
<th>协议</th>
<th>已覆盖</th>
<th>正向</th>
<th>负向</th>
<th>入参覆盖率</th>
<th>出参断言率</th>
<th>成功返回覆盖</th>
<th>失败返回覆盖</th>
<th>异常测试</th>
<th>覆盖用例</th>
</tr>
</thead>
<tbody>{self._build_interface_table(r)}</tbody>
</table>
</div>
<!-- 需求覆盖详情 -->
<h2>📋 需求覆盖详情</h2>
<table>
<thead>
<tr>
<th>REQ ID</th>
<th>需求描述</th>
<th>已覆盖</th>
<th>覆盖用例</th>
</tr>
</thead>
<tbody>{self._build_requirement_table(r)}</tbody>
</table>
<!-- 缺口清单 -->
<h2>🔍 缺口清单 {len(r.gaps)}
<span style="color:#e74c3c">Critical: {r.critical_gap_count}</span>
<span style="color:#e67e22">High: {r.high_gap_count}</span>
</h2>
<div style="overflow-x:auto">
<table>
<thead>
<tr>
<th>#</th>
<th>严重程度</th>
<th>缺口类型</th>
<th>目标</th>
<th>详情</th>
<th>补充建议</th>
</tr>
</thead>
<tbody>{self._build_gap_table(r)}</tbody>
</table>
</div>
<div class="footer">
AI-Powered Test Generator &nbsp;·&nbsp; Coverage Analyzer &nbsp;·&nbsp;
Report generated at {ts}
</div>
</div>
</body>
</html>"""

View File

@ -0,0 +1,138 @@
"""
测试文件生成器
升级点
- 接收 project / project_desc写入头部注释
- 头部注释增加 url / source_file / module_path 信息
- 输出到 <GENERATED_TESTS_DIR>/<project>/ 子目录
"""
from __future__ import annotations
import json
import re
import logging
from pathlib import Path
from config import config
logger = logging.getLogger(__name__)
class TestGenerator:
def __init__(self, project: str = "", project_desc: str = ""):
self.project = project
self.project_desc = project_desc
safe = self._safe_name(project) if project else ""
base = Path(config.GENERATED_TESTS_DIR)
self.output_dir = base / safe if safe else base
self.output_dir.mkdir(parents=True, exist_ok=True)
# ── 公开接口 ──────────────────────────────────────────────
def generate(self, test_cases: list[dict]) -> list[Path]:
self._save_summary(test_cases)
files: list[Path] = []
for tc in test_cases:
path = self._write_file(tc)
if path:
files.append(path)
logger.info(
f"Generated {len(files)} test file(s) → '{self.output_dir}'"
)
return files
# ── 写入单个测试文件 ──────────────────────────────────────
def _write_file(self, tc: dict) -> Path | None:
test_id = tc.get("test_id", "TC_UNKNOWN")
test_code = tc.get("test_code", "").strip()
if not test_code:
logger.warning(f"[{test_id}] No test_code, skipping.")
return None
file_path = self.output_dir / f"{self._safe_name(test_id)}.py"
with open(file_path, "w", encoding="utf-8") as f:
f.write(self._build_header(tc))
f.write("\n")
f.write(test_code)
f.write("\n")
logger.info(f" Written: {file_path}")
return file_path
# ── 文件头注释 ────────────────────────────────────────────
def _build_header(self, tc: dict) -> str:
req_id = tc.get("requirement_id", "")
req_text = tc.get("requirement", "")
data_notes = tc.get("test_data_notes", "")
steps = tc.get("steps", [])
# 项目信息行
proj_lines = ""
if self.project:
proj_lines += f"# Project : {self.project}\n"
if self.project_desc:
proj_lines += f"# Proj Desc : {self.project_desc}\n"
# 步骤注释
step_lines = ""
for s in steps:
protocol = s.get("protocol", "").upper()
iface = s.get("interface_name", "")
url = s.get("url", "")
purpose = s.get("purpose", "")
step_lines += (
f"# Step {s.get('step_no')}: [{protocol}] {iface}\n"
f"# URL : {url}\n"
f"# Purpose: {purpose}\n"
)
inputs = s.get("input", {})
if inputs:
step_lines += (
f"# Input : "
f"{json.dumps(inputs, ensure_ascii=False)}\n"
)
for a in s.get("assertions", []):
msg = a.get("message", "")
step_lines += (
f"# Assert : {a.get('field')} "
f"{a.get('operator')} {a.get('expected')}"
f"{'' + msg if msg else ''}\n"
)
return (
f"# {'=' * 60}\n"
f"{proj_lines}"
f"# Test ID : {tc.get('test_id', '')}\n"
f"# Test Name : {tc.get('test_name', '')}\n"
f"# Requirement: [{req_id}] {req_text}\n"
f"# Description: {tc.get('description', '')}\n"
f"# Test Data : {data_notes}\n"
f"# Steps:\n"
f"{step_lines.rstrip()}\n"
f"# {'=' * 60}\n"
f"import os\n"
f"import json\n"
)
# ── 保存摘要 JSON ─────────────────────────────────────────
def _save_summary(self, test_cases: list[dict]):
summary = [
{k: v for k, v in tc.items() if k != "test_code"}
for tc in test_cases
]
path = self.output_dir / "test_cases_summary.json"
with open(path, "w", encoding="utf-8") as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
logger.info(f"Summary → {path}")
# ── 工具 ──────────────────────────────────────────────────
@staticmethod
def _safe_name(name: str) -> str:
return re.sub(r'[^\w\-]', '_', name).strip("_")

View File

@ -0,0 +1,168 @@
import os
import sys
import json
import subprocess
import time
import logging
from pathlib import Path
from dataclasses import dataclass, field
from config import config
logger = logging.getLogger(__name__)
@dataclass
class StepResult:
step_no: int
interface_name: str
status: str # PASS | FAIL | ERROR
assertions: list[dict] = field(default_factory=list)
# 每条 assertion: {"field","operator","expected","actual","passed","message"}
@dataclass
class TestResult:
test_id: str
file_path: str
status: str # PASS | FAIL | ERROR | TIMEOUT
message: str = ""
duration: float = 0.0
stdout: str = ""
stderr: str = ""
step_results: list[StepResult] = field(default_factory=list)
class TestRunner:
def run_all(self, test_files: list[Path]) -> list[TestResult]:
results = []
total = len(test_files)
print(f"\n{''*62}")
print(f" Running {total} test(s) …")
print(f"{''*62}")
for idx, fp in enumerate(test_files, 1):
print(f"\n[{idx}/{total}] {fp.name}")
result = self._run_one(fp)
results.append(result)
self._print_result(result)
return results
# ── 执行单个脚本 ──────────────────────────────────────────
def _run_one(self, file_path: Path) -> TestResult:
test_id = file_path.stem
t0 = time.time()
try:
proc = subprocess.run(
[sys.executable, str(file_path)],
capture_output=True, text=True,
timeout=config.TEST_TIMEOUT,
env=self._env(),
)
duration = time.time() - t0
stdout = proc.stdout.strip()
stderr = proc.stderr.strip()
status, message = self._parse_output(stdout, proc.returncode)
step_results = self._parse_step_results(stdout)
return TestResult(
test_id=test_id, file_path=str(file_path),
status=status, message=message,
duration=duration, stdout=stdout, stderr=stderr,
step_results=step_results,
)
except subprocess.TimeoutExpired:
return TestResult(
test_id=test_id, file_path=str(file_path),
status="TIMEOUT", message=f"Exceeded {config.TEST_TIMEOUT}s",
duration=time.time() - t0,
)
except Exception as e:
return TestResult(
test_id=test_id, file_path=str(file_path),
status="ERROR", message=str(e),
duration=time.time() - t0,
)
# ── 解析输出 ──────────────────────────────────────────────
def _parse_output(self, stdout: str, returncode: int) -> tuple[str, str]:
if not stdout:
return "FAIL", f"No output (exit={returncode})"
last = stdout.strip().splitlines()[-1].strip()
upper = last.upper()
if upper.startswith("PASS"):
return "PASS", last[5:].strip()
if upper.startswith("FAIL"):
return "FAIL", last[5:].strip()
return ("PASS" if returncode == 0 else "FAIL"), last
def _parse_step_results(self, stdout: str) -> list[StepResult]:
"""
解析脚本中以 ##STEP_RESULT## 开头的结构化输出行
格式##STEP_RESULT## <json>
"""
results = []
for line in stdout.splitlines():
line = line.strip()
if line.startswith("##STEP_RESULT##"):
try:
data = json.loads(line[len("##STEP_RESULT##"):].strip())
results.append(StepResult(
step_no=data.get("step_no", 0),
interface_name=data.get("interface_name", ""),
status=data.get("status", ""),
assertions=data.get("assertions", []),
))
except Exception:
pass
return results
def _env(self) -> dict:
env = os.environ.copy()
env["HTTP_BASE_URL"] = config.HTTP_BASE_URL
return env
# ── 打印 ──────────────────────────────────────────────────
def _print_result(self, r: TestResult):
icon = {"PASS": "", "FAIL": "", "TIMEOUT": "⏱️", "ERROR": "⚠️"}.get(r.status, "")
print(f" {icon} [{r.status}] {r.test_id} ({r.duration:.2f}s)")
print(f" {r.message}")
for sr in r.step_results:
s_icon = "" if sr.status == "PASS" else ""
print(f" {s_icon} Step {sr.step_no}: {sr.interface_name}")
for a in sr.assertions:
a_icon = "" if a.get("passed") else ""
print(f" {a_icon} {a.get('message','')} "
f"(expected={a.get('expected')}, actual={a.get('actual')})")
if r.stderr:
print(f" stderr: {r.stderr[:300]}")
def print_summary(self, results: list[TestResult]):
total = len(results)
passed = sum(1 for r in results if r.status == "PASS")
failed = sum(1 for r in results if r.status == "FAIL")
errors = sum(1 for r in results if r.status in ("ERROR", "TIMEOUT"))
print(f"\n{''*62}")
print(f" TEST SUMMARY")
print(f"{''*62}")
print(f" Total : {total}")
print(f" ✅ PASS : {passed}")
print(f" ❌ FAIL : {failed}")
print(f" ⚠️ ERROR : {errors}")
print(f" Pass Rate: {passed/total*100:.1f}%" if total else " Pass Rate: N/A")
print(f"{''*62}\n")
def save_results(self, results: list[TestResult], path: str):
with open(path, "w", encoding="utf-8") as f:
json.dump([{
"test_id": r.test_id, "status": r.status,
"message": r.message, "duration": r.duration,
"stdout": r.stdout, "stderr": r.stderr,
"step_results": [
{"step_no": sr.step_no, "interface_name": sr.interface_name,
"status": sr.status, "assertions": sr.assertions}
for sr in r.step_results
],
} for r in results], f, ensure_ascii=False, indent=2)
logger.info(f"Run results → {path}")

View File

@ -0,0 +1,91 @@
{
"project": "project name",
"description": "this is my first project",
"units": [
{
"name": "create_user",
"requirement_id": "REQ.01",
"description": "Creates a new user account with provided credentials and information.",
"type": "function",
"url": "create_user.py",
"parameters": {
"username": {
"type": "string",
"inout": "in",
"description": "The desired username for the new account.",
"required": true
},
"password": {
"type": "string",
"inout": "in",
"description": "The password for the new account, meeting security requirements.",
"required": true
},
"email": {
"type": "string",
"inout": "in",
"description": "The email address associated with the new account.",
"required": true
},
"phone_number": {
"type": "string",
"inout": "in",
"description": "The phone number associated with the new account.",
"required": false
},
"additional_info": {
"type": "dict",
"inout": "in",
"description": "Optional additional information about the user.",
"required": false
}
},
"return": {
"type": "dict",
"description": "Returns the created user object on success or an error message on failure.",
"on_success": {
"value": "User object containing user details such as id, username, email, etc.",
"description": "The user object representing the newly created account."
},
"on_failure": {
"value": "Error message or raises an exception.",
"description": "Details about why the user creation failed, such as validation errors or duplicate username."
}
}
},
{
"name": "/delete_user",
"requirement_id": "REQ.02",
"description": "Deletes a user by user ID or username from the system.",
"type": "http",
"url": "http://127.0.0.1/api",
"method": "post",
"parameters": {
"user_id": {
"type": "integer",
"inout": "in",
"description": "Unique identifier of the user to be deleted.",
"required": false
},
"username": {
"type": "string",
"inout": "in",
"description": "Username of the user to be deleted.",
"required": false
}
},
"return": {
"type": "boolean",
"description": "Indicates whether the user was successfully deleted.",
"on_success": {
"value": true,
"description": "The user was successfully deleted."
},
"on_failure": {
"value": false,
"description": "The user could not be deleted, or the user does not exist."
}
}
}
]
}

View File

@ -0,0 +1,91 @@
{
"project": "project name",
"description": "this is my first project",
"units": [
{
"name": "create_user",
"requirement_id": "REQ.01",
"description": "Creates a new user account with provided credentials and information.",
"type": "function",
"url": "/Users/sontolau/Workspace/AIDeveloper/requirements_generator/output/my_first_project/create_user.py",
"parameters": {
"username": {
"type": "string",
"inout": "in",
"description": "The desired username for the new account.",
"required": true
},
"password": {
"type": "string",
"inout": "in",
"description": "The password for the new account, meeting security requirements.",
"required": true
},
"email": {
"type": "string",
"inout": "in",
"description": "The email address associated with the new account.",
"required": true
},
"phone_number": {
"type": "string",
"inout": "in",
"description": "The phone number associated with the new account.",
"required": false
},
"additional_info": {
"type": "dict",
"inout": "in",
"description": "Optional additional information about the user.",
"required": false
}
},
"return": {
"type": "dict",
"description": "Returns the created user object on success or an error message on failure.",
"on_success": {
"value": "User object containing user details such as id, username, email, etc.",
"description": "The user object representing the newly created account."
},
"on_failure": {
"value": "Error message or raises an exception.",
"description": "Details about why the user creation failed, such as validation errors or duplicate username."
}
}
},
{
"name": "/delete_user",
"requirement_id": "REQ.02",
"description": "Deletes a user by user ID or username from the system.",
"type": "http",
"url": "http://127.0.0.1/api",
"method": "post",
"parameters": {
"user_id": {
"type": "integer",
"inout": "in",
"description": "Unique identifier of the user to be deleted.",
"required": false
},
"username": {
"type": "string",
"inout": "in",
"description": "Username of the user to be deleted.",
"required": false
}
},
"return": {
"type": "boolean",
"description": "Indicates whether the user was successfully deleted.",
"on_success": {
"value": true,
"description": "The user was successfully deleted."
},
"on_failure": {
"value": false,
"description": "The user could not be deleted, or the user does not exist."
}
}
}
]
}

232
ai_test_generator/main.py Normal file
View File

@ -0,0 +1,232 @@
#!/usr/bin/env python3
"""
AI-Powered API Test Generator, Runner & Coverage Analyzer
Usage:
python main.py --api-desc examples/api_desc.json \\
--requirements "创建用户,删除用户"
python main.py --api-desc examples/api_desc.json \\
--req-file examples/requirements.txt
"""
import argparse
import logging
import sys
from pathlib import Path
from config import config
from core.parser import InterfaceParser, ApiDescriptor
from core.prompt_builder import PromptBuilder
from core.llm_client import LLMClient
from core.test_generator import TestGenerator
from core.test_runner import TestRunner
from core.analyzer import CoverageAnalyzer, CoverageReport
from core.report_generator import ReportGenerator
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s%(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
# ── CLI ───────────────────────────────────────────────────────
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
description="AI-Powered API Test Generator & Coverage Analyzer"
)
p.add_argument(
"--api-desc", required=True,
help="接口描述 JSON 文件(含 project / description / units 字段)",
)
p.add_argument(
"--requirements", default="",
help="测试需求(逗号分隔)",
)
p.add_argument(
"--req-file", default="",
help="测试需求文件(每行一条)",
)
p.add_argument(
"--skip-run", action="store_true",
help="只生成测试文件,不执行",
)
p.add_argument(
"--skip-analyze", action="store_true",
help="跳过覆盖率分析",
)
p.add_argument(
"--output-dir", default="",
help="测试文件根输出目录(默认 config.GENERATED_TESTS_DIR",
)
p.add_argument(
"--debug", action="store_true",
help="开启 DEBUG 日志",
)
return p
def load_requirements(args) -> list[str]:
reqs: list[str] = []
if args.requirements:
reqs += [r.strip() for r in args.requirements.split(",") if r.strip()]
if args.req_file:
with open(args.req_file, "r", encoding="utf-8") as f:
reqs += [line.strip() for line in f if line.strip()]
if not reqs:
logger.error(
"No requirements provided. Use --requirements or --req-file."
)
sys.exit(1)
return reqs
# ── Main ──────────────────────────────────────────────────────
def main():
args = build_arg_parser().parse_args()
if args.debug:
logging.getLogger().setLevel(logging.DEBUG)
if args.output_dir:
config.GENERATED_TESTS_DIR = args.output_dir
# ── Step 1: 解析接口描述 ──────────────────────────────────
logger.info("▶ Step 1: Parsing interface description …")
parser: InterfaceParser = InterfaceParser()
descriptor: ApiDescriptor = parser.parse_file(args.api_desc)
project = descriptor.project
project_desc = descriptor.description
interfaces = descriptor.interfaces
logger.info(f" Project : {project}")
logger.info(f" Description: {project_desc}")
logger.info(
f" Interfaces : {len(interfaces)}"
f"{[i.name for i in interfaces]}"
)
# 打印每个接口的 url 解析结果
for iface in interfaces:
if iface.protocol == "function":
logger.info(
f" [{iface.name}] source={iface.source_file} "
f"module={iface.module_path}"
)
else:
logger.info(
f" [{iface.name}] full_url={iface.http_full_url}"
)
# ── Step 2: 加载测试需求 ──────────────────────────────────
logger.info("▶ Step 2: Loading test requirements …")
requirements = load_requirements(args)
for i, req in enumerate(requirements, 1):
logger.info(f" {i}. {req}")
# ── Step 3: 调用 LLM 生成测试用例 ────────────────────────
logger.info("▶ Step 3: Calling LLM …")
builder = PromptBuilder()
test_cases = LLMClient().generate_test_cases(
builder.get_system_prompt(),
builder.build_user_prompt(
requirements, interfaces, parser,
project=project,
project_desc=project_desc,
),
)
logger.info(f" LLM returned {len(test_cases)} test case(s)")
# ── Step 4: 生成测试文件 ──────────────────────────────────
logger.info(
f"▶ Step 4: Generating test files (project='{project}') …"
)
generator = TestGenerator(project=project, project_desc=project_desc)
test_files = generator.generate(test_cases)
out = generator.output_dir
run_results = []
if not args.skip_run:
# ── Step 5: 执行测试 ──────────────────────────────────
logger.info("▶ Step 5: Running tests …")
runner = TestRunner()
run_results = runner.run_all(test_files)
runner.print_summary(run_results)
runner.save_results(run_results, str(out / "run_results.json"))
if not args.skip_analyze:
# ── Step 6: 覆盖率分析 ────────────────────────────────
logger.info("▶ Step 6: Analyzing coverage …")
report = CoverageAnalyzer(
interfaces=interfaces,
requirements=requirements,
test_cases=test_cases,
run_results=run_results,
).analyze()
# ── Step 7: 生成报告 ──────────────────────────────────
logger.info("▶ Step 7: Generating reports …")
rg = ReportGenerator()
rg.save_json(report, str(out / "coverage_report.json"))
rg.save_html(report, str(out / "coverage_report.html"))
_print_terminal_summary(report, out, project)
logger.info(f"\n✅ Done. Output: {out.resolve()}")
# ── 终端摘要 ──────────────────────────────────────────────────
def _print_terminal_summary(
report: CoverageReport, out: Path, project: str
):
W = 66
def bar(rate: float, w: int = 20) -> str:
filled = int(rate * w)
empty = w - filled
icon = "" if rate >= 0.8 else ("⚠️ " if rate >= 0.5 else "")
return f"{'' * filled}{'' * empty} {rate * 100:.1f}% {icon}"
print(f"\n{'' * W}")
print(f" PROJECT : {project}")
print(f" COVERAGE SUMMARY")
print(f"{'' * W}")
print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}")
print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}")
print(f" 入参覆盖率 {bar(report.avg_in_param_coverage_rate)}")
print(f" 成功返回字段覆盖 {bar(report.avg_success_field_coverage_rate)}")
print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}")
print(f" 用例通过率 {bar(report.pass_rate)}")
print(f"{'' * W}")
print(f" Total Gaps : {len(report.gaps)}")
print(f" 🔴 Critical: {report.critical_gap_count}")
print(f" 🟠 High : {report.high_gap_count}")
if report.gaps:
print(f"{'' * W}")
print(" Top Gaps (up to 8):")
icons = {
"critical": "🔴", "high": "🟠",
"medium": "🟡", "low": "🔵",
}
for g in report.gaps[:8]:
icon = icons.get(g.severity, "")
print(f" {icon} [{g.gap_type}] {g.target}")
print(f"{g.suggestion}")
print(f"{'' * W}")
print(f" Output : {out.resolve()}")
print(f" • coverage_report.html ← open in browser")
print(f" • coverage_report.json")
print(f" • run_results.json")
print(f" • test_cases_summary.json")
print(f"{'' * W}\n")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,2 @@
openai>=1.30.0
requests>=2.31.0

19
ai_test_generator/run.sh Executable file
View File

@ -0,0 +1,19 @@
#!/usr/bin/env bash
# 安装依赖
pip install -r requirements.txt
# 设置环境变量
export LLM_API_KEY="sk-AUmOuFI731Ty5Nob38jY26d8lydfDT-QkE2giqb0sCuPCAE2JH6zjLM4lZLpvL5WMYPOocaMe2FwVDmqM_9KimmKACjR"
export LLM_BASE_URL="https://openapi.monica.im/v1" # 或其他兼容接口
export LLM_MODEL="gpt-4o"
export HTTP_BASE_URL="http://localhost:8080"
# 运行(生成 + 执行)
# python main.py \
# --api-desc examples/api_desc.json \
# --requirements "创建用户,自动生成测试数据,更改指定用户的密码,自动生成测试数据"
# 只生成不执行
python main.py --api-desc examples/function_signatures.json \
--requirements "1.对每个接口进行测试,支持特殊值、边界值测试"

59
auto-test.txt Normal file
View File

@ -0,0 +1,59 @@
你现在是一名高级ai软件工程师非常擅长通过llm模型解决软件应用需求现在需要构建python工程利用ai实现基于api的自动化测试代码生成与执行支持将用户每个测试需求分解成对1个或多个接口的调用并基于接口返回描述对接口返回值进行测试设计如下
# 工作流程
1. 用户输入测试需求、接口描述文件文件采用json格式包含接口名称、请求参数描述、返回参数描述接口描述等信息。例如
## 接口描述示例
[
{
"name": "/api/create_user",
"description": "to create new user account in the system",
"protocol": "http",
"base_url": "http://123.0.0.1:8000"
"method": "post",
"url_parameters": {
"version": {
"type": "integer",
"description": "the api's version",
"required": "false"
}
},
"request_parameters": {
"username": { "type": "string", "description": "user name", "required": true },
"email": { "type": "string", "description": "the user's email", "required": false },
"password": { "type": "string", "description": "the user's password", "required": true }
},
"response": {
"code": {
"type": "integer",
"description": "0 is successful, nonzero is failure"
},
"user_id": {
"type": "integer",
"description": "the created user's id "
}
}
},
{
"name": "change_password",
"description": "replace old password with new password",
"protocol": "function",
"parameters": {
"user_id": { "type": "integer", "inout","in", "description": "the user's id", "required": true },
"old_password": { "type": "string", "inout","in", "description": "the old password", "required": true },
"new_password": { "type": "string", "inout","in", "description": "the user's new password", "required": true }
},
"return": {
"type": "integer",
"description": "0 is successful, nonzero is failure"
}
}
...
]
## 测试需求示例
1创建用户自动生成测试数据
2更改指定用户的密码自动生成测试数据支持先创建用户再改变该用户的密码
2. 软件将接口信息、指令、输出要求等通过提示词传递给llm模型。
3. llm模型结合用户测试需求并根据接口列表生成针对每个测试需求的测试用例包含生成的测试数据、代码等测试用例支持基于测试逻辑对多个接口的调用与返回值测试
4. 输出的测试用例以json格式返回。
5. 软件解析json数据并生成测试用例文件
6. 软件运行测试用例文件并输出每个测试用例的结果;

0
init_request.json Normal file
View File

81
mcp-client.py Normal file
View File

@ -0,0 +1,81 @@
#!/usr/bin/env python3
import requests
import json
session_id="123456789"
class MCPClient:
def __init__(self, url):
self.url = url
self.session = requests.Session()
self.session_id = None
def initialize(self):
"""初始化 MCP 会话"""
payload = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {
"name": "python-client",
"version": "1.0.0"
}
}
}
response = self.session.post(
self.url,
json=payload,
headers={
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
}
)
self.session_id = response.headers['Mcp-Session-Id']
def call_method(self, method, params=None, request_id=None):
"""调用 MCP 方法"""
payload = {
"jsonrpc": "2.0",
"id": request_id or method,
"method": method
}
if params:
payload["params"] = params
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream"
}
# 如果有 session ID添加到头部
if self.session_id:
headers["Mcp-Session-Id"] = self.session_id
response = self.session.post(
self.url,
json=payload,
headers=headers
)
print(response.text)
return response.json()
# 使用示例
if __name__ == "__main__":
client = MCPClient("http://iot.gesukj.com:7999/mcp")
# 初始化
print("Initializing...")
init_result = client.initialize()
# print(json.dumps(init_result, indent=2))
# 列出工具
print("\nListing tools...")
tools_result = client.call_method("tools/list", request_id=4)
print(json.dumps(tools_result, indent=2))

View File

View File

@ -0,0 +1,146 @@
# config.py - 全局配置管理
import os
from dotenv import load_dotenv
load_dotenv()
# ── LLM 配置 ──────────────────────────────────────────
LLM_API_KEY = os.getenv("OPENAI_API_KEY", "")
LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0.3"))
# ── 数据库配置 ─────────────────────────────────────────
DB_PATH = os.getenv("DB_PATH", "data/requirement_analyzer.db")
# ── 输出配置 ───────────────────────────────────────────
OUTPUT_BASE_DIR = os.getenv("OUTPUT_BASE_DIR", "output")
DEFAULT_LANGUAGE = os.getenv("DEFAULT_LANGUAGE", "python")
# ══════════════════════════════════════════════════════
# Prompt 模板
# ══════════════════════════════════════════════════════
DECOMPOSE_PROMPT_TEMPLATE = """
你是一位资深软件架构师和产品经理请根据以下信息将原始需求分解为若干个可独立实现的功能需求
{knowledge_section}
## 原始需求
{raw_requirement}
## 输出要求
请严格按照以下 JSON 格式输出不要包含任何额外说明
{{
"functional_requirements": [
{{
"index": 1,
"title": "功能需求标题简洁10字以内",
"description": "功能需求详细描述(包含输入、处理逻辑、输出)",
"function_name": "snake_case函数名",
"priority": "high|medium|low"
}}
]
}}
要求
1. 每个功能需求必须是独立可实现的最小单元
2. function_name 使用 snake_case 命名清晰表达函数用途
3. 分解粒度适中通常 5-15 个功能需求
4. 优先级根据业务重要性判断
"""
# ── 函数签名 JSON 生成 Prompt ──────────────────────────
FUNC_SIGNATURE_PROMPT_TEMPLATE = """
你是一位资深软件架构师请根据以下功能需求描述设计该函数的完整接口签名并以 JSON 格式输出
{knowledge_section}
## 功能需求
需求编号{requirement_id}
标题{title}
函数名{function_name}
详细描述{description}
## 输出格式
请严格按照以下 JSON 结构输出不要包含任何额外说明或 markdown 标记
{{
"name": "{function_name}",
"requirement_id": "{requirement_id}",
"description": "简洁的一句话功能描述(英文)",
"type": "function",
"parameters": {{
"<param_name>": {{
"type": "integer|string|boolean|float|list|dict|object",
"inout": "in|out|inout",
"description": "参数说明(英文)",
"required": true
}}
}},
"return": {{
"type": "integer|string|boolean|float|list|dict|object|void",
"description": "整体返回值说明(英文,一句话概括)",
"on_success": {{
"value": "具体成功返回值或范围,如 0、true、user object、list of items 等",
"description": "成功时的返回值含义(英文)"
}},
"on_failure": {{
"value": "具体失败返回值或范围,如 nonzero、false、null、empty list、raises Exception 等",
"description": "失败时的返回值含义,或抛出的异常类型(英文)"
}}
}}
}}
## 设计规范
1. 参数名使用 snake_case类型使用通用类型不绑定具体语言
2. inout 字段含义
- in = 仅输入参数
- out = 仅输出参数通过参数传出结果如指针/引用
- inout = 既作输入又作输出
3. 所有描述字段使用英文
4. return 字段规则
- 若函数无返回值voidtype "void"on_success/on_failure 均填 null
- 若返回值只有成功场景如纯查询on_failure 可描述为 "null or empty"
- on_success.value / on_failure.value 填写具体值或值域描述不要填写空字符串
5. 若函数无参数parameters {{}}
6. required 字段为布尔值 true false
"""
# ── 代码生成 Prompt含签名约束─────────────────────────
CODE_GEN_PROMPT_TEMPLATE = """
你是一位资深 {language} 工程师请根据以下功能需求和函数签名规范生成完整的 {language} 函数代码
{knowledge_section}
## 功能需求
标题{title}
描述{description}
## 【必须严格遵守】函数签名规范
以下 JSON 定义了函数的精确接口生成的代码必须与之完全一致不得擅自增减或改名参数
```json
{signature_json}
```
### 签名字段说明
- `name`函数名必须完全一致
- `parameters`每个 key 即为参数名`type` 为数据类型`inout` 含义
- `in` = 普通输入参数
- `out` = 输出参数Python 中通过返回值或可变容器传出
- `inout` = 既作输入又作输出
- `return.type`返回值类型
- `return.on_success`成功时的返回值代码实现必须与此一致
- `return.on_failure`失败时的返回值或异常代码实现必须与此一致
## 输出要求
1. 只输出纯代码不要包含 markdown 代码块标记
2. 函数签名名称参数列表返回类型必须与上方 JSON 规范完全一致
3. 成功/失败的返回值必须严格遵守 return.on_success / return.on_failure 的定义
4. 包含完整的类型注解Python 使用 type hints
5. 包含详细的 docstring其中 Returns 段须注明成功值与失败值
6. 包含必要的异常处理
7. 代码风格遵循 PEP8Python或对应语言规范
8. 在文件顶部用注释注明需求编号功能标题函数签名摘要
9. 如需导入第三方库请在顶部统一导入
"""

View File

View File

@ -0,0 +1,216 @@
# core/code_generator.py - 代码生成核心逻辑(签名约束版)
import json
from typing import Optional, List, Callable
import config
from core.llm_client import LLMClient
from database.models import FunctionalRequirement, CodeFile
from utils.output_writer import write_code_file, get_file_extension
class CodeGenerator:
"""
根据功能需求 + 函数签名约束使用 LLM 生成代码函数文件
签名由 RequirementAnalyzer.build_function_signature() 预先生成
注入 Prompt 后可确保代码参数列表与签名 JSON 完全一致
"""
def __init__(self, llm_client: Optional[LLMClient] = None):
"""
初始化代码生成器
Args:
llm_client: LLM 客户端实例 None 时自动创建
"""
self.llm = llm_client or LLMClient()
# ══════════════════════════════════════════════════
# 单个生成
# ══════════════════════════════════════════════════
def generate(
self,
func_req: FunctionalRequirement,
output_dir: str,
language: str = config.DEFAULT_LANGUAGE,
knowledge: str = "",
signature: Optional[dict] = None,
) -> CodeFile:
"""
为单个功能需求生成代码文件
Args:
func_req: 功能需求对象必须含有效 id
output_dir: 代码输出目录
language: 目标编程语言
knowledge: 知识库文本可选
signature: 函数签名 dict RequirementAnalyzer 生成
传入后将作为强约束注入 Prompt确保代码参数
与签名 JSON 完全一致 None 时退化为无约束模式
Returns:
CodeFile 对象含生成的代码内容和文件路径未持久化
Raises:
ValueError: func_req.id None
Exception: LLM 调用失败或文件写入失败
"""
if func_req.id is None:
raise ValueError("FunctionalRequirement 必须先持久化id 不能为 None")
knowledge_section = self._build_knowledge_section(knowledge)
signature_json = self._build_signature_json(signature, func_req)
prompt = config.CODE_GEN_PROMPT_TEMPLATE.format(
language=language,
knowledge_section=knowledge_section,
title=func_req.title,
description=func_req.description,
signature_json=signature_json,
)
code_content = self.llm.chat(
system_prompt=(
f"你是一位资深 {language} 工程师,只输出纯代码,"
"不添加任何 markdown 标记。函数签名必须与提供的 JSON 规范完全一致。"
),
user_prompt=prompt,
)
file_path = write_code_file(
output_dir=output_dir,
function_name=func_req.function_name,
language=language,
content=code_content,
)
ext = get_file_extension(language)
file_name = f"{func_req.function_name}{ext}"
return CodeFile(
project_id=func_req.project_id,
func_req_id=func_req.id,
file_name=file_name,
file_path=file_path,
language=language,
content=code_content,
)
# ══════════════════════════════════════════════════
# 批量生成
# ══════════════════════════════════════════════════
def generate_batch(
self,
func_reqs: List[FunctionalRequirement],
output_dir: str,
language: str = config.DEFAULT_LANGUAGE,
knowledge: str = "",
signatures: Optional[List[dict]] = None,
on_progress: Optional[Callable] = None,
) -> List[CodeFile]:
"""
批量生成代码文件
Args:
func_reqs: 功能需求列表
output_dir: 输出目录
language: 目标语言
knowledge: 知识库文本
signatures: func_reqs 等长的签名列表索引对应
None 时所有条目均以无约束模式生成
on_progress: 进度回调 fn(index, total, func_req, code_file, error)
Returns:
成功生成的 CodeFile 列表
"""
results = []
total = len(func_reqs)
# 构建 func_req.id → signature 的快速查找表
sig_map = self._build_signature_map(func_reqs, signatures)
for i, req in enumerate(func_reqs):
sig = sig_map.get(req.id)
try:
code_file = self.generate(
func_req=req,
output_dir=output_dir,
language=language,
knowledge=knowledge,
signature=sig,
)
results.append(code_file)
if on_progress:
on_progress(i + 1, total, req, code_file, None)
except Exception as e:
if on_progress:
on_progress(i + 1, total, req, None, e)
return results
# ══════════════════════════════════════════════════
# 私有工具方法
# ══════════════════════════════════════════════════
@staticmethod
def _build_knowledge_section(knowledge: str) -> str:
"""构建知识库 Prompt 段落"""
if not knowledge or not knowledge.strip():
return ""
return (
"## 参考知识库(实现时请遵循以下规范)\n"
f"{knowledge}\n\n---\n"
)
@staticmethod
def _build_signature_json(
signature: Optional[dict],
func_req: FunctionalRequirement,
) -> str:
"""
将签名 dict 序列化为格式化 JSON 字符串
若签名为 None则构造最小占位签名保持 Prompt 结构完整
Args:
signature: 签名 dict None
func_req: 对应的功能需求用于占位签名
Returns:
JSON 字符串
"""
if signature:
return json.dumps(signature, ensure_ascii=False, indent=2)
# 无签名时的最小占位,提示 LLM 自行设计但保持格式
fallback = {
"name": func_req.function_name,
"requirement_id": f"REQ.{func_req.index_no:02d}",
"description": func_req.description,
"type": "function",
"parameters": "<<请根据功能描述自行设计参数>>",
"return": "<<请根据功能描述自行设计返回值>>",
}
return json.dumps(fallback, ensure_ascii=False, indent=2)
@staticmethod
def _build_signature_map(
func_reqs: List[FunctionalRequirement],
signatures: Optional[List[dict]],
) -> dict:
"""
构建 func_req.id signature 映射表
Args:
func_reqs: 功能需求列表
signatures: func_reqs 等长的签名列表 None
Returns:
{req_id: signature_dict} 字典
"""
if not signatures:
return {}
sig_map = {}
for req, sig in zip(func_reqs, signatures):
if req.id is not None and sig:
sig_map[req.id] = sig
return sig_map

View File

@ -0,0 +1,90 @@
# core/llm_client.py - LLM 客户端封装
import json
from typing import Optional
import config
class LLMClient:
"""
OpenAI 兼容 LLM 客户端封装
支持任何兼容 OpenAI API 格式的服务OpenAI / Azure / 本地模型等
"""
def __init__(
self,
api_key: str = config.LLM_API_KEY,
base_url: str = config.LLM_BASE_URL,
model: str = config.LLM_MODEL,
temperature: float = config.LLM_TEMPERATURE,
):
"""
初始化 LLM 客户端
Args:
api_key: API 密钥
base_url: API 基础 URL
model: 模型名称
temperature: 生成温度0~1越低越确定
Raises:
ImportError: 未安装 openai
ValueError: api_key 为空
"""
try:
from openai import OpenAI
except ImportError:
raise ImportError("请安装 openai: pip install openai")
if not api_key:
raise ValueError("LLM_API_KEY 未配置,请在 .env 文件中设置")
self.model = model
self.temperature = temperature
self._client = OpenAI(api_key=api_key, base_url=base_url)
def chat(self, system_prompt: str, user_prompt: str) -> str:
"""
发送对话请求返回模型回复文本
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
Returns:
模型回复的文本内容
Raises:
Exception: API 调用失败
"""
response = self._client.chat.completions.create(
model=self.model,
temperature=self.temperature,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
)
return response.choices[0].message.content.strip()
def chat_json(self, system_prompt: str, user_prompt: str) -> dict:
"""
发送对话请求解析并返回 JSON 结果
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
Returns:
解析后的 dict 对象
Raises:
json.JSONDecodeError: 模型返回非合法 JSON
"""
raw = self.chat(system_prompt, user_prompt)
# 去除可能的 markdown 代码块包裹
raw = raw.strip()
if raw.startswith("```"):
lines = raw.split("\n")
raw = "\n".join(lines[1:-1])
return json.loads(raw)

View File

@ -0,0 +1,239 @@
# core/requirement_analyzer.py - 需求分解 & 函数签名生成
import re
from typing import List, Optional
import config
from core.llm_client import LLMClient
from database.models import FunctionalRequirement
class RequirementAnalyzer:
"""
使用 LLM 将原始需求分解为功能需求列表并生成函数接口签名
支持注入知识库上下文以提升分解质量
"""
def __init__(self, llm_client: Optional[LLMClient] = None):
"""
初始化需求分析器
Args:
llm_client: LLM 客户端实例 None 时自动创建
"""
self.llm = llm_client or LLMClient()
# ══════════════════════════════════════════════════
# 需求分解
# ══════════════════════════════════════════════════
def decompose(
self,
raw_requirement: str,
project_id: int,
raw_req_id: int,
knowledge: str = "",
) -> List[FunctionalRequirement]:
"""
将原始需求分解为功能需求列表
Args:
raw_requirement: 原始需求文本
project_id: 所属项目 ID
raw_req_id: 原始需求记录 ID
knowledge: 知识库文本可选
Returns:
FunctionalRequirement 对象列表未持久化id=None
Raises:
ValueError: LLM 返回格式不合法
json.JSONDecodeError: JSON 解析失败
"""
knowledge_section = self._build_knowledge_section(knowledge)
prompt = config.DECOMPOSE_PROMPT_TEMPLATE.format(
knowledge_section=knowledge_section,
raw_requirement=raw_requirement,
)
result = self.llm.chat_json(
system_prompt="你是一位资深软件架构师,擅长需求分析与系统设计。",
user_prompt=prompt,
)
items = result.get("functional_requirements", [])
if not items:
raise ValueError("LLM 未返回任何功能需求,请检查原始需求描述")
requirements = []
for item in items:
req = FunctionalRequirement(
project_id=project_id,
raw_req_id=raw_req_id,
index_no=int(item.get("index", len(requirements) + 1)),
title=item.get("title", "未命名功能"),
description=item.get("description", ""),
function_name=self._sanitize_function_name(
item.get("function_name", f"func_{len(requirements)+1}")
),
priority=item.get("priority", "medium"),
)
requirements.append(req)
return requirements
# ══════════════════════════════════════════════════
# 函数签名生成(新增)
# ══════════════════════════════════════════════════
def build_function_signature(
self,
func_req: FunctionalRequirement,
knowledge: str = "",
) -> dict:
"""
为单个功能需求生成函数接口签名 JSON
Args:
func_req: 功能需求对象需含有效 id
knowledge: 知识库文本可选
Returns:
符合接口规范的 dict包含 name/requirement_id/description/
type/parameters/return 字段
Raises:
json.JSONDecodeError: LLM 返回非合法 JSON
"""
requirement_id = self._format_requirement_id(func_req.index_no)
knowledge_section = self._build_knowledge_section(knowledge)
prompt = config.FUNC_SIGNATURE_PROMPT_TEMPLATE.format(
knowledge_section=knowledge_section,
requirement_id=requirement_id,
title=func_req.title,
function_name=func_req.function_name,
description=func_req.description,
)
signature = self.llm.chat_json(
system_prompt=(
"你是一位资深软件架构师,专注于 API 接口设计。"
"只输出合法 JSON不添加任何说明文字。"
),
user_prompt=prompt,
)
# 确保关键字段存在,做兜底处理
signature.setdefault("name", func_req.function_name)
signature.setdefault("requirement_id", requirement_id)
signature.setdefault("description", func_req.description)
signature.setdefault("type", "function")
signature.setdefault("parameters", {})
signature.setdefault("return", {"type": "void", "description": ""})
return signature
def build_function_signatures_batch(
self,
func_reqs: List[FunctionalRequirement],
knowledge: str = "",
on_progress=None,
) -> List[dict]:
"""
批量为功能需求列表生成函数接口签名
Args:
func_reqs: 功能需求列表
knowledge: 知识库文本可选
on_progress: 进度回调 fn(index, total, func_req, signature, error)
Returns:
签名 dict 列表顺序与 func_reqs 一致
生成失败的条目使用降级结构填充不中断整体流程
"""
results = []
total = len(func_reqs)
for i, req in enumerate(func_reqs):
try:
sig = self.build_function_signature(req, knowledge)
results.append(sig)
if on_progress:
on_progress(i + 1, total, req, sig, None)
except Exception as e:
# 降级:用基础信息填充,保证 JSON 完整性
fallback = self._build_fallback_signature(req)
results.append(fallback)
if on_progress:
on_progress(i + 1, total, req, fallback, e)
return results
# ══════════════════════════════════════════════════
# 私有工具方法
# ══════════════════════════════════════════════════
@staticmethod
def _build_knowledge_section(knowledge: str) -> str:
"""构建知识库 Prompt 段落"""
if not knowledge or not knowledge.strip():
return ""
return f"""## 参考知识库
{knowledge}
---
"""
@staticmethod
def _sanitize_function_name(name: str) -> str:
"""
清理函数名确保符合 snake_case 规范
Args:
name: 原始函数名
Returns:
合法的 snake_case 函数名
"""
name = re.sub(r"[^a-zA-Z0-9_]", "_", name).lower()
name = re.sub(r"_+", "_", name).strip("_")
if name and name[0].isdigit():
name = "func_" + name
return name or "unnamed_function"
@staticmethod
def _format_requirement_id(index_no: int) -> str:
"""
将序号格式化为需求编号字符串
Args:
index_no: 功能需求序号 1 开始
Returns:
格式化编号 'REQ.01''REQ.12'
"""
return f"REQ.{index_no:02d}"
@staticmethod
def _build_fallback_signature(func_req: FunctionalRequirement) -> dict:
"""
构建降级签名LLM 调用失败时使用
Args:
func_req: 功能需求对象
Returns:
包含基础信息的签名 dict
"""
return {
"name": func_req.function_name,
"requirement_id": f"REQ.{func_req.index_no:02d}",
"description": func_req.description,
"type": "function",
"parameters": {},
"return": {
"type": "void",
"description": "TODO: define return value"
},
"_note": "Auto-generated fallback due to LLM error"
}

View File

@ -0,0 +1,314 @@
# database/db_manager.py - 数据库操作管理器
import sqlite3
import os
from datetime import datetime
from typing import List, Optional
from contextlib import contextmanager
from database.models import (
CREATE_TABLES_SQL, Project, RawRequirement,
FunctionalRequirement, CodeFile
)
import config
class DBManager:
"""SQLite 数据库管理器,封装所有 CRUD 操作"""
def __init__(self, db_path: str = config.DB_PATH):
self.db_path = db_path
os.makedirs(os.path.dirname(db_path), exist_ok=True)
self._init_db()
# ── 连接上下文管理器 ──────────────────────────────────
@contextmanager
def _get_conn(self):
"""获取数据库连接(自动提交/回滚)"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA foreign_keys = ON")
try:
yield conn
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
def _init_db(self):
"""初始化数据库,创建所有表"""
with self._get_conn() as conn:
conn.executescript(CREATE_TABLES_SQL)
# ══════════════════════════════════════════════════
# Project CRUD
# ══════════════════════════════════════════════════
def create_project(self, project: Project) -> int:
"""创建项目,返回新项目 ID"""
sql = """
INSERT INTO projects (name, description, language, output_dir, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
"""
with self._get_conn() as conn:
cur = conn.execute(sql, (
project.name, project.description, project.language,
project.output_dir, project.created_at, project.updated_at
))
return cur.lastrowid
def get_project_by_id(self, project_id: int) -> Optional[Project]:
"""根据 ID 查询项目"""
with self._get_conn() as conn:
row = conn.execute(
"SELECT * FROM projects WHERE id = ?", (project_id,)
).fetchone()
if row is None:
return None
return Project(
id=row["id"], name=row["name"], description=row["description"],
language=row["language"], output_dir=row["output_dir"],
created_at=row["created_at"], updated_at=row["updated_at"]
)
def get_project_by_name(self, name: str) -> Optional[Project]:
"""根据名称查询项目"""
with self._get_conn() as conn:
row = conn.execute(
"SELECT * FROM projects WHERE name = ?", (name,)
).fetchone()
if row is None:
return None
return Project(
id=row["id"], name=row["name"], description=row["description"],
language=row["language"], output_dir=row["output_dir"],
created_at=row["created_at"], updated_at=row["updated_at"]
)
def list_projects(self) -> List[Project]:
"""列出所有项目"""
with self._get_conn() as conn:
rows = conn.execute(
"SELECT * FROM projects ORDER BY created_at DESC"
).fetchall()
return [
Project(
id=r["id"], name=r["name"], description=r["description"],
language=r["language"], output_dir=r["output_dir"],
created_at=r["created_at"], updated_at=r["updated_at"]
) for r in rows
]
def update_project(self, project: Project) -> None:
"""更新项目信息"""
project.updated_at = datetime.now().isoformat()
sql = """
UPDATE projects
SET name=?, description=?, language=?, output_dir=?, updated_at=?
WHERE id=?
"""
with self._get_conn() as conn:
conn.execute(sql, (
project.name, project.description, project.language,
project.output_dir, project.updated_at, project.id
))
def delete_project(self, project_id: int) -> None:
"""删除项目(级联删除所有关联数据)"""
with self._get_conn() as conn:
conn.execute("DELETE FROM projects WHERE id = ?", (project_id,))
# ══════════════════════════════════════════════════
# RawRequirement CRUD
# ══════════════════════════════════════════════════
def create_raw_requirement(self, req: RawRequirement) -> int:
"""创建原始需求,返回新记录 ID"""
sql = """
INSERT INTO raw_requirements
(project_id, content, source_type, source_name, knowledge, created_at)
VALUES (?, ?, ?, ?, ?, ?)
"""
with self._get_conn() as conn:
cur = conn.execute(sql, (
req.project_id, req.content, req.source_type,
req.source_name, req.knowledge, req.created_at
))
return cur.lastrowid
def get_raw_requirement(self, req_id: int) -> Optional[RawRequirement]:
"""根据 ID 查询原始需求"""
with self._get_conn() as conn:
row = conn.execute(
"SELECT * FROM raw_requirements WHERE id = ?", (req_id,)
).fetchone()
if row is None:
return None
return RawRequirement(
id=row["id"], project_id=row["project_id"], content=row["content"],
source_type=row["source_type"], source_name=row["source_name"],
knowledge=row["knowledge"], created_at=row["created_at"]
)
def list_raw_requirements_by_project(self, project_id: int) -> List[RawRequirement]:
"""查询项目下所有原始需求"""
with self._get_conn() as conn:
rows = conn.execute(
"SELECT * FROM raw_requirements WHERE project_id = ? ORDER BY created_at",
(project_id,)
).fetchall()
return [
RawRequirement(
id=r["id"], project_id=r["project_id"], content=r["content"],
source_type=r["source_type"], source_name=r["source_name"],
knowledge=r["knowledge"], created_at=r["created_at"]
) for r in rows
]
# ══════════════════════════════════════════════════
# FunctionalRequirement CRUD
# ══════════════════════════════════════════════════
def create_functional_requirement(self, req: FunctionalRequirement) -> int:
"""创建功能需求,返回新记录 ID"""
sql = """
INSERT INTO functional_requirements
(project_id, raw_req_id, index_no, title, description,
function_name, priority, status, is_custom, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
with self._get_conn() as conn:
cur = conn.execute(sql, (
req.project_id, req.raw_req_id, req.index_no,
req.title, req.description, req.function_name,
req.priority, req.status, int(req.is_custom),
req.created_at, req.updated_at
))
return cur.lastrowid
def get_functional_requirement(self, req_id: int) -> Optional[FunctionalRequirement]:
"""根据 ID 查询功能需求"""
with self._get_conn() as conn:
row = conn.execute(
"SELECT * FROM functional_requirements WHERE id = ?", (req_id,)
).fetchone()
if row is None:
return None
return self._row_to_func_req(row)
def list_functional_requirements(self, project_id: int) -> List[FunctionalRequirement]:
"""查询项目下所有功能需求(按序号排序)"""
with self._get_conn() as conn:
rows = conn.execute(
"""SELECT * FROM functional_requirements
WHERE project_id = ? ORDER BY index_no""",
(project_id,)
).fetchall()
return [self._row_to_func_req(r) for r in rows]
def update_functional_requirement(self, req: FunctionalRequirement) -> None:
"""更新功能需求"""
req.updated_at = datetime.now().isoformat()
sql = """
UPDATE functional_requirements
SET title=?, description=?, function_name=?, priority=?,
status=?, index_no=?, updated_at=?
WHERE id=?
"""
with self._get_conn() as conn:
conn.execute(sql, (
req.title, req.description, req.function_name,
req.priority, req.status, req.index_no,
req.updated_at, req.id
))
def delete_functional_requirement(self, req_id: int) -> None:
"""删除功能需求"""
with self._get_conn() as conn:
conn.execute(
"DELETE FROM functional_requirements WHERE id = ?", (req_id,)
)
def _row_to_func_req(self, row) -> FunctionalRequirement:
"""sqlite Row → FunctionalRequirement 对象"""
return FunctionalRequirement(
id=row["id"], project_id=row["project_id"],
raw_req_id=row["raw_req_id"], index_no=row["index_no"],
title=row["title"], description=row["description"],
function_name=row["function_name"], priority=row["priority"],
status=row["status"], is_custom=bool(row["is_custom"]),
created_at=row["created_at"], updated_at=row["updated_at"]
)
# ══════════════════════════════════════════════════
# CodeFile CRUD
# ══════════════════════════════════════════════════
def create_code_file(self, code_file: CodeFile) -> int:
"""创建代码文件记录,返回新记录 ID"""
sql = """
INSERT INTO code_files
(project_id, func_req_id, file_name, file_path,
language, content, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
"""
with self._get_conn() as conn:
cur = conn.execute(sql, (
code_file.project_id, code_file.func_req_id,
code_file.file_name, code_file.file_path,
code_file.language, code_file.content,
code_file.created_at, code_file.updated_at
))
return cur.lastrowid
def upsert_code_file(self, code_file: CodeFile) -> int:
"""插入或更新代码文件(按 func_req_id 唯一键)"""
existing = self.get_code_file_by_func_req(code_file.func_req_id)
if existing:
code_file.id = existing.id
code_file.updated_at = datetime.now().isoformat()
sql = """
UPDATE code_files
SET file_name=?, file_path=?, language=?, content=?, updated_at=?
WHERE id=?
"""
with self._get_conn() as conn:
conn.execute(sql, (
code_file.file_name, code_file.file_path,
code_file.language, code_file.content,
code_file.updated_at, code_file.id
))
return code_file.id
else:
return self.create_code_file(code_file)
def get_code_file_by_func_req(self, func_req_id: int) -> Optional[CodeFile]:
"""根据功能需求 ID 查询代码文件"""
with self._get_conn() as conn:
row = conn.execute(
"SELECT * FROM code_files WHERE func_req_id = ?", (func_req_id,)
).fetchone()
if row is None:
return None
return self._row_to_code_file(row)
def list_code_files_by_project(self, project_id: int) -> List[CodeFile]:
"""查询项目下所有代码文件"""
with self._get_conn() as conn:
rows = conn.execute(
"SELECT * FROM code_files WHERE project_id = ? ORDER BY created_at",
(project_id,)
).fetchall()
return [self._row_to_code_file(r) for r in rows]
def _row_to_code_file(self, row) -> CodeFile:
"""sqlite Row → CodeFile 对象"""
return CodeFile(
id=row["id"], project_id=row["project_id"],
func_req_id=row["func_req_id"], file_name=row["file_name"],
file_path=row["file_path"], language=row["language"],
content=row["content"], created_at=row["created_at"],
updated_at=row["updated_at"]
)

View File

@ -0,0 +1,122 @@
# database/models.py - 数据模型定义SQLite 建表 DDL
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional
# ══════════════════════════════════════════════════════
# DDL 建表语句
# ══════════════════════════════════════════════════════
CREATE_TABLES_SQL = """
PRAGMA foreign_keys = ON;
-- 项目表
CREATE TABLE IF NOT EXISTS projects (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE, -- 项目名称
description TEXT, -- 项目描述
language TEXT NOT NULL DEFAULT 'python', -- 目标代码语言
output_dir TEXT NOT NULL, -- 输出目录路径
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
-- 原始需求表
CREATE TABLE IF NOT EXISTS raw_requirements (
id INTEGER PRIMARY KEY AUTOINCREMENT,
project_id INTEGER NOT NULL, -- 关联项目
content TEXT NOT NULL, -- 需求原文
source_type TEXT NOT NULL DEFAULT 'text', -- text | file
source_name TEXT, -- 文件名文件输入时
knowledge TEXT, -- 合并后的知识库内容
created_at TEXT NOT NULL,
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
);
-- 功能需求表
CREATE TABLE IF NOT EXISTS functional_requirements (
id INTEGER PRIMARY KEY AUTOINCREMENT,
project_id INTEGER NOT NULL,
raw_req_id INTEGER NOT NULL, -- 关联原始需求
index_no INTEGER NOT NULL, -- 序号
title TEXT NOT NULL, -- 功能标题
description TEXT NOT NULL, -- 功能描述
function_name TEXT NOT NULL, -- 对应函数名
priority TEXT NOT NULL DEFAULT 'medium', -- high|medium|low
status TEXT NOT NULL DEFAULT 'pending', -- pending|generated|skipped
is_custom INTEGER NOT NULL DEFAULT 0, -- 是否用户自定义 (0/1)
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
FOREIGN KEY (raw_req_id) REFERENCES raw_requirements(id) ON DELETE CASCADE
);
-- 代码文件表
CREATE TABLE IF NOT EXISTS code_files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
project_id INTEGER NOT NULL,
func_req_id INTEGER NOT NULL UNIQUE, -- 关联功能需求1对1
file_name TEXT NOT NULL, -- 文件名
file_path TEXT NOT NULL, -- 完整路径
language TEXT NOT NULL, -- 代码语言
content TEXT NOT NULL, -- 代码内容
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
FOREIGN KEY (func_req_id) REFERENCES functional_requirements(id) ON DELETE CASCADE
);
"""
# ══════════════════════════════════════════════════════
# 数据类Python 对象映射)
# ══════════════════════════════════════════════════════
@dataclass
class Project:
name: str
output_dir: str
language: str = "python"
description: str = ""
id: Optional[int] = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
@dataclass
class RawRequirement:
project_id: int
content: str
source_type: str = "text" # text | file
source_name: Optional[str] = None
knowledge: Optional[str] = None
id: Optional[int] = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
@dataclass
class FunctionalRequirement:
project_id: int
raw_req_id: int
index_no: int
title: str
description: str
function_name: str
priority: str = "medium"
status: str = "pending"
is_custom: bool = False
id: Optional[int] = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
@dataclass
class CodeFile:
project_id: int
func_req_id: int
file_name: str
file_path: str
language: str
content: str
id: Optional[int] = None
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())

View File

@ -0,0 +1,789 @@
#!/usr/bin/env python3
# encoding: utf-8
# main.py - 主入口:支持交互式 & 非交互式CLI 参数)两种运行模式
#
# 交互式: python main.py
# 非交互式python main.py --non-interactive \
# --project-name "MyProject" \
# --language python \
# --requirement-text "用户管理系统,包含注册、登录、修改密码功能"
#
# 完整参数见python main.py --help
import os
import sys
from typing import Dict
import click
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.prompt import Prompt, Confirm
import config
from database.db_manager import DBManager
from database.models import Project, RawRequirement, FunctionalRequirement
from core.llm_client import LLMClient
from core.requirement_analyzer import RequirementAnalyzer
from core.code_generator import CodeGenerator
from utils.file_handler import read_file_auto, merge_knowledge_files
from utils.output_writer import (
ensure_project_dir, build_project_output_dir, write_project_readme,
write_function_signatures_json, validate_all_signatures,
patch_signatures_with_url,
)
console = Console()
db = DBManager()
# ══════════════════════════════════════════════════════
# 显示工具函数
# ══════════════════════════════════════════════════════
def print_banner():
console.print(Panel.fit(
"[bold cyan]🚀 需求分析 & 代码生成工具[/bold cyan]\n"
"[dim]Powered by LLM · SQLite · Python[/dim]",
border_style="cyan"
))
def print_functional_requirements(reqs: list):
"""以表格形式展示功能需求列表"""
table = Table(title="📋 功能需求列表", show_lines=True)
table.add_column("序号", style="cyan", width=6)
table.add_column("ID", style="dim", width=6)
table.add_column("标题", style="bold", width=20)
table.add_column("函数名", width=25)
table.add_column("优先级", width=8)
table.add_column("类型", width=8)
table.add_column("描述", width=40)
priority_color = {"high": "red", "medium": "yellow", "low": "green"}
for req in reqs:
color = priority_color.get(req.priority, "white")
table.add_row(
str(req.index_no),
str(req.id) if req.id else "-",
req.title,
f"[code]{req.function_name}[/code]",
f"[{color}]{req.priority}[/{color}]",
"[magenta]自定义[/magenta]" if req.is_custom else "LLM生成",
req.description[:60] + "..." if len(req.description) > 60 else req.description,
)
console.print(table)
def print_signatures_preview(signatures: list):
"""
以表格形式预览函数签名列表 url 字段
Args:
signatures: 纯签名列表顶层文档的 "functions" 字段
"""
table = Table(title="📄 函数签名预览", show_lines=True)
table.add_column("需求编号", style="cyan", width=10)
table.add_column("函数名", style="bold", width=25)
table.add_column("参数数量", width=8)
table.add_column("返回类型", width=10)
table.add_column("成功返回值", width=18)
table.add_column("失败返回值", width=18)
table.add_column("URL", style="dim", width=30)
def _fmt_value(v) -> str:
if v is None:
return "-"
if isinstance(v, dict):
return "{" + ", ".join(v.keys()) + "}"
return str(v)[:16]
for sig in signatures:
ret = sig.get("return") or {}
on_success = ret.get("on_success") or {}
on_failure = ret.get("on_failure") or {}
url = sig.get("url", "")
# 只显示文件名部分,避免路径过长
url_display = os.path.basename(url) if url else "[dim]待生成[/dim]"
table.add_row(
sig.get("requirement_id", "-"),
sig.get("name", "-"),
str(len(sig.get("parameters", {}))),
ret.get("type", "void"),
_fmt_value(on_success.get("value")),
_fmt_value(on_failure.get("value")),
url_display,
)
console.print(table)
# ══════════════════════════════════════════════════════
# Step 1项目初始化
# ══════════════════════════════════════════════════════
def step_init_project(
project_name: str = None,
language: str = None,
description: str = "",
non_interactive: bool = False,
) -> Project:
if not non_interactive:
console.print("\n[bold]Step 1 · 项目配置[/bold]", style="blue")
project_name = project_name or Prompt.ask("📁 请输入项目名称")
language = language or Prompt.ask(
"💻 目标代码语言",
default=config.DEFAULT_LANGUAGE,
choices=["python", "javascript", "typescript", "java", "go", "rust"],
)
description = description or Prompt.ask("📝 项目描述(可选)", default="")
else:
if not project_name:
raise ValueError("非交互模式下 --project-name 为必填项")
language = language or config.DEFAULT_LANGUAGE
console.print("\n[bold]Step 1 · 项目配置[/bold] [dim](非交互)[/dim]", style="blue")
console.print(f" 项目名称: {project_name} 语言: {language}")
existing = db.get_project_by_name(project_name)
if existing:
if non_interactive:
console.print(f"[green]✓ 已加载已有项目: {project_name} (ID={existing.id})[/green]")
return existing
use_existing = Confirm.ask(f"⚠️ 项目 '{project_name}' 已存在,是否继续使用?")
if use_existing:
console.print(f"[green]✓ 已加载项目: {project_name} (ID={existing.id})[/green]")
return existing
project_name = Prompt.ask("请输入新的项目名称")
output_dir = build_project_output_dir(project_name)
project = Project(
name=project_name,
language=language,
output_dir=output_dir,
description=description,
)
project.id = db.create_project(project)
console.print(f"[green]✓ 项目已创建: {project_name} (ID={project.id})[/green]")
return project
# ══════════════════════════════════════════════════════
# Step 2输入原始需求 & 知识库
# ══════════════════════════════════════════════════════
def step_input_requirement(
project: Project,
requirement_text: str = None,
requirement_file: str = None,
knowledge_files: list = None,
non_interactive: bool = False,
) -> tuple:
console.print(
f"\n[bold]Step 2 · 输入原始需求[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue",
)
raw_text = ""
source_name = None
source_type = "text"
if non_interactive:
if requirement_file:
raw_text = read_file_auto(requirement_file)
source_name = os.path.basename(requirement_file)
source_type = "file"
console.print(f" 需求文件: {source_name} ({len(raw_text)} 字符)")
elif requirement_text:
raw_text = requirement_text
source_type = "text"
console.print(f" 需求文本: {raw_text[:80]}{'...' if len(raw_text) > 80 else ''}")
else:
raise ValueError("非交互模式下必须提供 --requirement-text 或 --requirement-file")
else:
input_type = Prompt.ask("📥 需求输入方式", choices=["text", "file"], default="text")
if input_type == "text":
console.print("[dim]请输入原始需求(输入空行结束):[/dim]")
lines = []
while True:
line = input()
if line == "" and lines:
break
lines.append(line)
raw_text = "\n".join(lines)
source_type = "text"
else:
file_path = Prompt.ask("📂 需求文件路径")
raw_text = read_file_auto(file_path)
source_name = os.path.basename(file_path)
source_type = "file"
console.print(f"[green]✓ 已读取文件: {source_name} ({len(raw_text)} 字符)[/green]")
knowledge_text = ""
if non_interactive:
if knowledge_files:
knowledge_text = merge_knowledge_files(list(knowledge_files))
console.print(f" 知识库: {len(knowledge_files)} 个文件,{len(knowledge_text)} 字符")
else:
use_kb = Confirm.ask("📚 是否输入知识库文件?", default=False)
if use_kb:
kb_paths = []
while True:
kb_path = Prompt.ask("知识库文件路径(留空结束)", default="")
if not kb_path:
break
if os.path.exists(kb_path):
kb_paths.append(kb_path)
console.print(f" [green]+ {kb_path}[/green]")
else:
console.print(f" [red]文件不存在: {kb_path}[/red]")
if kb_paths:
knowledge_text = merge_knowledge_files(kb_paths)
console.print(f"[green]✓ 知识库已合并 ({len(knowledge_text)} 字符)[/green]")
return raw_text, knowledge_text, source_name, source_type
# ══════════════════════════════════════════════════════
# Step 3LLM 分解需求
# ══════════════════════════════════════════════════════
def step_decompose_requirements(
project: Project,
raw_text: str,
knowledge_text: str,
source_name: str,
source_type: str,
non_interactive: bool = False,
) -> tuple:
console.print(
f"\n[bold]Step 3 · LLM 需求分解[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue",
)
raw_req = RawRequirement(
project_id=project.id,
content=raw_text,
source_type=source_type,
source_name=source_name,
knowledge=knowledge_text or None,
)
raw_req_id = db.create_raw_requirement(raw_req)
console.print(f"[dim]原始需求已存储 (ID={raw_req_id})[/dim]")
with console.status("[bold yellow]🤖 LLM 正在分解需求,请稍候...[/bold yellow]"):
llm = LLMClient()
analyzer = RequirementAnalyzer(llm)
func_reqs = analyzer.decompose(
raw_requirement=raw_text,
project_id=project.id,
raw_req_id=raw_req_id,
knowledge=knowledge_text,
)
for req in func_reqs:
req.id = db.create_functional_requirement(req)
console.print(f"[green]✓ 已生成 {len(func_reqs)} 个功能需求[/green]")
return raw_req_id, func_reqs
# ══════════════════════════════════════════════════════
# Step 4用户编辑功能需求
# ══════════════════════════════════════════════════════
def step_edit_requirements(
project: Project,
func_reqs: list,
raw_req_id: int,
non_interactive: bool = False,
skip_indices: list = None,
) -> list:
console.print(
f"\n[bold]Step 4 · 编辑功能需求[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue",
)
if non_interactive:
if skip_indices:
to_skip = set(skip_indices)
removed, kept = [], []
for req in func_reqs:
if req.index_no in to_skip:
db.delete_functional_requirement(req.id)
removed.append(req.title)
else:
kept.append(req)
func_reqs = kept
for i, req in enumerate(func_reqs, 1):
req.index_no = i
db.update_functional_requirement(req)
if removed:
console.print(f" [red]已跳过: {', '.join(removed)}[/red]")
print_functional_requirements(func_reqs)
return func_reqs
while True:
print_functional_requirements(func_reqs)
console.print(
"\n操作: [cyan]d[/cyan]=删除 [cyan]a[/cyan]=添加 "
"[cyan]e[/cyan]=编辑 [cyan]ok[/cyan]=确认继续"
)
action = Prompt.ask("请选择操作", default="ok").strip().lower()
if action == "ok":
break
elif action == "d":
idx_str = Prompt.ask("输入要删除的功能需求序号(多个用逗号分隔)")
to_delete = {int(x.strip()) for x in idx_str.split(",") if x.strip().isdigit()}
removed, kept = [], []
for req in func_reqs:
if req.index_no in to_delete:
db.delete_functional_requirement(req.id)
removed.append(req.title)
else:
kept.append(req)
func_reqs = kept
for i, req in enumerate(func_reqs, 1):
req.index_no = i
db.update_functional_requirement(req)
console.print(f"[red]✗ 已删除: {', '.join(removed)}[/red]")
elif action == "a":
title = Prompt.ask("功能标题")
description = Prompt.ask("功能描述")
func_name = Prompt.ask("函数名 (snake_case)")
priority = Prompt.ask(
"优先级", choices=["high", "medium", "low"], default="medium"
)
new_req = FunctionalRequirement(
project_id=project.id,
raw_req_id=raw_req_id,
index_no=len(func_reqs) + 1,
title=title,
description=description,
function_name=func_name,
priority=priority,
is_custom=True,
)
new_req.id = db.create_functional_requirement(new_req)
func_reqs.append(new_req)
console.print(f"[green]✓ 已添加自定义需求: {title}[/green]")
elif action == "e":
idx_str = Prompt.ask("输入要编辑的功能需求序号")
if not idx_str.isdigit():
continue
idx = int(idx_str)
target = next((r for r in func_reqs if r.index_no == idx), None)
if target is None:
console.print("[red]序号不存在[/red]")
continue
target.title = Prompt.ask("新标题", default=target.title)
target.description = Prompt.ask("新描述", default=target.description)
target.function_name = Prompt.ask("新函数名", default=target.function_name)
target.priority = Prompt.ask(
"新优先级", choices=["high", "medium", "low"], default=target.priority
)
db.update_functional_requirement(target)
console.print(f"[green]✓ 已更新: {target.title}[/green]")
return func_reqs
# ══════════════════════════════════════════════════════
# Step 5A生成函数签名 JSON不含 url 字段,待 5C 回写)
# ══════════════════════════════════════════════════════
def step_generate_signatures(
project: Project,
func_reqs: list,
output_dir: str,
knowledge_text: str,
json_file_name: str = "function_signatures.json",
non_interactive: bool = False,
) -> tuple:
"""
为所有功能需求生成函数签名写入初版 JSON不含 url 字段
url 字段将在 Step 5C 代码生成完成后回写并刷新 JSON 文件
Returns:
(signatures: List[dict], json_path: str)
"""
console.print(
f"\n[bold]Step 5A · 生成函数签名 JSON[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue",
)
llm = LLMClient()
analyzer = RequirementAnalyzer(llm)
success_count = 0
fail_count = 0
def on_progress(index, total, req, signature, error):
nonlocal success_count, fail_count
if error:
console.print(
f" [{index}/{total}] [yellow]⚠ {req.title} 签名生成失败,"
f"使用降级结构: {error}[/yellow]"
)
fail_count += 1
else:
console.print(
f" [{index}/{total}] [green]✓ {req.title}[/green] "
f"→ [dim]{signature.get('name')}()[/dim] "
f"params={len(signature.get('parameters', {}))}"
)
success_count += 1
console.print(f"[yellow]正在为 {len(func_reqs)} 个功能需求生成函数签名...[/yellow]")
signatures = analyzer.build_function_signatures_batch(
func_reqs=func_reqs,
knowledge=knowledge_text,
on_progress=on_progress,
)
# 校验
validation_report = validate_all_signatures(signatures)
if validation_report:
console.print(f"[yellow]⚠ 发现 {len(validation_report)} 个签名存在结构问题:[/yellow]")
for fname, errors in validation_report.items():
for err in errors:
console.print(f" [yellow]· {fname}: {err}[/yellow]")
else:
console.print("[green]✓ 所有签名结构校验通过[/green]")
# 写入初版 JSONurl 字段尚未填入)
json_path = write_function_signatures_json(
output_dir=output_dir,
signatures=signatures,
project_name=project.name,
project_description=project.description or "", # ← 传入项目描述
file_name=json_file_name,
)
console.print(
f"[green]✓ 签名 JSON 初版已写入: [cyan]{os.path.abspath(json_path)}[/cyan][/green]\n"
f" 成功: {success_count} 降级: {fail_count}"
)
return signatures, json_path
# ══════════════════════════════════════════════════════
# Step 5B生成代码文件收集 {函数名: 文件路径} 映射
# ══════════════════════════════════════════════════════
def step_generate_code(
project: Project,
func_reqs: list,
output_dir: str,
knowledge_text: str,
signatures: list,
non_interactive: bool = False,
) -> Dict[str, str]:
"""
依据签名约束批量生成代码文件
Returns:
func_name_to_url: {函数名: 代码文件绝对路径} 映射表
Step 5C 回写 url 字段使用
生成失败的函数不会出现在映射表中
"""
console.print(
f"\n[bold]Step 5B · 生成代码文件[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue",
)
generator = CodeGenerator(LLMClient())
success_count = 0
fail_count = 0
func_name_to_url: Dict[str, str] = {} # ← 收集 函数名 → 文件绝对路径
def on_progress(index, total, req, code_file, error):
nonlocal success_count, fail_count
if error:
console.print(f" [{index}/{total}] [red]✗ {req.title}: {error}[/red]")
fail_count += 1
else:
db.upsert_code_file(code_file)
req.status = "generated"
db.update_functional_requirement(req)
# 收集 函数名 → 绝对文件路径(作为 url 回写)
func_name_to_url[req.function_name] = os.path.abspath(code_file.file_path)
console.print(
f" [{index}/{total}] [green]✓ {req.title}[/green] "
f"→ [dim]{code_file.file_name}[/dim]"
)
success_count += 1
console.print(f"[yellow]开始生成 {len(func_reqs)} 个代码文件(签名约束模式)...[/yellow]")
generator.generate_batch(
func_reqs=func_reqs,
output_dir=output_dir,
language=project.language,
knowledge=knowledge_text,
signatures=signatures,
on_progress=on_progress,
)
req_summary = "\n".join(
f"{i+1}. **{r.title}** (`{r.function_name}`) - {r.description[:80]}"
for i, r in enumerate(func_reqs)
)
write_project_readme(output_dir, project.name, req_summary)
console.print(Panel(
f"[bold green]✅ 代码生成完成![/bold green]\n"
f"成功: {success_count} 失败: {fail_count}\n"
f"输出目录: [cyan]{os.path.abspath(output_dir)}[/cyan]",
border_style="green",
))
return func_name_to_url
# ══════════════════════════════════════════════════════
# Step 5C回写 url 字段并刷新 JSON
# ══════════════════════════════════════════════════════
def step_patch_signatures_url(
project: Project,
signatures: list,
func_name_to_url: Dict[str, str],
output_dir: str,
json_file_name: str,
non_interactive: bool = False,
) -> str:
"""
将代码文件路径回写到签名的 "url" 字段并重新写入 JSON 文件
执行流程
1. 调用 patch_signatures_with_url() 原地修改签名列表
2. 打印最终签名预览 url
3. 重新调用 write_function_signatures_json() 覆盖写入 JSON
Args:
project: 项目对象提供 name description
signatures: Step 5A 产出的签名列表将被原地修改
func_name_to_url: Step 5B 收集的 {函数名: 文件绝对路径} 映射
output_dir: JSON 文件所在目录
json_file_name: JSON 文件名
non_interactive: 是否非交互模式
Returns:
刷新后的 JSON 文件绝对路径
"""
console.print(
f"\n[bold]Step 5C · 回写代码文件路径url到签名 JSON[/bold]"
+ (" [dim](非交互)[/dim]" if non_interactive else ""),
style="blue",
)
# 原地回写 url 字段
patch_signatures_with_url(signatures, func_name_to_url)
patched = sum(1 for s in signatures if s.get("url"))
unpatched = len(signatures) - patched
if unpatched:
console.print(
f"[yellow]⚠ {unpatched} 个函数未能写入 url"
f"(对应代码文件生成失败)[/yellow]"
)
# 打印最终预览(含 url 列)
print_signatures_preview(signatures)
# 覆盖写入 JSON含 project.description
json_path = write_function_signatures_json(
output_dir=output_dir,
signatures=signatures,
project_name=project.name,
project_description=project.description or "",
file_name=json_file_name,
)
console.print(
f"[green]✓ 签名 JSON 已更新(含 url: "
f"[cyan]{os.path.abspath(json_path)}[/cyan][/green]\n"
f" 已回写: {patched} 未回写: {unpatched}"
)
return os.path.abspath(json_path)
# ══════════════════════════════════════════════════════
# 核心工作流
# ══════════════════════════════════════════════════════
def run_workflow(
project_name: str = None,
language: str = None,
description: str = "",
requirement_text: str = None,
requirement_file: str = None,
knowledge_files: tuple = (),
skip_indices: list = None,
json_file_name: str = "function_signatures.json",
non_interactive: bool = False,
):
"""完整工作流Step 1 → 5C"""
print_banner()
# Step 1
project = step_init_project(
project_name=project_name,
language=language,
description=description,
non_interactive=non_interactive,
)
# Step 2
raw_text, knowledge_text, source_name, source_type = step_input_requirement(
project=project,
requirement_text=requirement_text,
requirement_file=requirement_file,
knowledge_files=list(knowledge_files) if knowledge_files else [],
non_interactive=non_interactive,
)
# Step 3
raw_req_id, func_reqs = step_decompose_requirements(
project=project,
raw_text=raw_text,
knowledge_text=knowledge_text,
source_name=source_name,
source_type=source_type,
non_interactive=non_interactive,
)
# Step 4
func_reqs = step_edit_requirements(
project=project,
func_reqs=func_reqs,
raw_req_id=raw_req_id,
non_interactive=non_interactive,
skip_indices=skip_indices or [],
)
if not func_reqs:
console.print("[red]⚠ 功能需求列表为空,流程终止[/red]")
return
output_dir = ensure_project_dir(project.name)
# Step 5A生成签名初版不含 url
signatures, json_path = step_generate_signatures(
project=project,
func_reqs=func_reqs,
output_dir=output_dir,
knowledge_text=knowledge_text,
json_file_name=json_file_name,
non_interactive=non_interactive,
)
# Step 5B生成代码收集 {函数名: 文件路径}
func_name_to_url = step_generate_code(
project=project,
func_reqs=func_reqs,
output_dir=output_dir,
knowledge_text=knowledge_text,
signatures=signatures,
non_interactive=non_interactive,
)
# Step 5C回写 url 字段,刷新 JSON
json_path = step_patch_signatures_url(
project=project,
signatures=signatures,
func_name_to_url=func_name_to_url,
output_dir=output_dir,
json_file_name=json_file_name,
non_interactive=non_interactive,
)
console.print(Panel(
f"[bold cyan]🎉 全部流程完成![/bold cyan]\n"
f"项目: [bold]{project.name}[/bold]\n"
f"描述: {project.description or '(无)'}\n"
f"代码目录: [cyan]{os.path.abspath(output_dir)}[/cyan]\n"
f"签名文件: [cyan]{json_path}[/cyan]",
border_style="cyan",
))
# ══════════════════════════════════════════════════════
# CLI 入口click
# ══════════════════════════════════════════════════════
@click.command()
@click.option("--non-interactive", is_flag=True, default=False,
help="以非交互模式运行(所有参数通过命令行传入)")
@click.option("--project-name", "-p", default=None, help="项目名称")
@click.option("--language", "-l", default=None,
type=click.Choice(["python","javascript","typescript","java","go","rust"]),
help=f"目标代码语言(默认: {config.DEFAULT_LANGUAGE}")
@click.option("--description", "-d", default="", help="项目描述")
@click.option("--requirement-text","-r", default=None,
help="原始需求文本(与 --requirement-file 二选一)")
@click.option("--requirement-file","-f", default=None,
type=click.Path(exists=True),
help="原始需求文件路径(支持 .txt/.md/.pdf/.docx")
@click.option("--knowledge-file", "-k", default=None, multiple=True,
type=click.Path(exists=True),
help="知识库文件路径(可多次指定,如 -k a.md -k b.pdf")
@click.option("--skip-index", "-s", default=None, multiple=True, type=int,
help="要跳过的功能需求序号(可多次指定,如 -s 2 -s 5")
@click.option("--json-file-name", "-j", default="function_signatures.json",
help="函数签名 JSON 文件名(默认: function_signatures.json")
def cli(
non_interactive, project_name, language, description,
requirement_text, requirement_file, knowledge_file,
skip_index, json_file_name,
):
"""
需求分析 & 代码生成工具
\b
交互式运行推荐初次使用
python main.py
\b
非交互式运行示例
python main.py --non-interactive \\
--project-name "UserSystem" \\
--description "用户管理系统后端服务" \\
--language python \\
--requirement-text "用户管理系统,包含注册、登录、修改密码功能" \\
--knowledge-file docs/api_spec.md \\
--json-file-name api_signatures.json
\b
从文件读取需求 + 跳过部分功能需求
python main.py --non-interactive \\
--project-name "MyProject" \\
--requirement-file requirements.md \\
--skip-index 3 --skip-index 7
"""
try:
run_workflow(
project_name=project_name,
language=language,
description=description,
requirement_text=requirement_text,
requirement_file=requirement_file,
knowledge_files=knowledge_file,
skip_indices=list(skip_index) if skip_index else [],
json_file_name=json_file_name,
non_interactive=non_interactive,
)
except KeyboardInterrupt:
console.print("\n[yellow]用户中断,退出[/yellow]")
sys.exit(0)
except Exception as e:
console.print(f"\n[bold red]❌ 错误: {e}[/bold red]")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
cli()

View File

@ -0,0 +1,51 @@
你是一个产品经理与高级软件工程师非常善于将用户原始需求进行理解、分析并生成条目化的功能需求再根据每个功能需求生成指定的功能函数现在生成python工程实现以上功能设计如下
# 工作流程设计
1. 用户输入:知识库,原始需求描述、项目名称。其中原始需求支持文本与文件格式,知识库可选,通过一个或多个文件输入;
2. 软件通过llm模型结合输入的知识库将原始需求分解成多个可实现的功能需求并显示给用户
3. 软件支持用户删除其中不需要的功能需求,同时支持添加自定义功能需求;
4. 软件根据配置后的功能需求生成指定代码默认使用python的功能函数文件输出到指定项目名称的目录中一个功能需求对应一个函数、一个代码文件
# 数据库设计
1. 使用sqlite数据库存储相关信息包含项目表、原始需求表功能需求表、代码文件表
2. 数据库支持通过ID实现项目、原始需求、功能需求、代码文件之间的关联关系
更新以上代码支持以json格式将所有的功能函数描述输出到json文件格式如下
[
{
"name": "change_password",
"requirement_id": "the id related to the requirement. e.g.REQ.01",
"description": "replace old password with new password",
"type": "function",
"parameters": {
"user_id": { "type": "integer", "inout": "in", "description": "the user's id", "required": true },
"old_password": { "type": "string", "inout": "in", "description": "the old password", "required": true },
"new_password": { "type": "string", "inout": "in", "description": "the user's new password", "required": true }
},
"return": {
"type": "integer",
"description": "0 is successful, nonzero is failure"
}
},
{
"name": "set_and_get_system_status",
"requirement_id": "the id related to the requirement. e.g. REQ.02",
"description": "set and get the system status",
"type": "function",
"parameters": {
"new_status": {
"type": "integer",
"intou": "in",
"required": "true",
"description": "the new system status"
},
"current_status": {
"type": "integer",
"inout": "out",
"required": "true",
"description": "get current system status"
}
}
},
...
]

View File

@ -0,0 +1,6 @@
openai>=1.0.0
python-dotenv>=1.0.0
rich>=13.0.0
python-docx>=0.8.11
PyPDF2>=3.0.0
click>=8.1.0

11
requirements_generator/run.sh Executable file
View File

@ -0,0 +1,11 @@
#!/usr/bin/env bash
# 安装依赖
pip install -r requirements.txt
# 设置环境变量
export OPENAI_API_KEY="sk-AUmOuFI731Ty5Nob38jY26d8lydfDT-QkE2giqb0sCuPCAE2JH6zjLM4lZLpvL5WMYPOocaMe2FwVDmqM_9KimmKACjR"
export OPENAI_BASE_URL="https://openapi.monica.im/v1" # 或其他兼容接口
export LLM_MODEL="gpt-4o"
python main.py

View File

View File

@ -0,0 +1,142 @@
# utils/file_handler.py - 文件读取工具(支持 txt/md/pdf/docx
import os
from typing import List, Optional
from pathlib import Path
def read_text_file(file_path: str) -> str:
"""
读取纯文本文件内容.txt / .md / .py
Args:
file_path: 文件路径
Returns:
文件文本内容
Raises:
FileNotFoundError: 文件不存在
UnicodeDecodeError: 编码错误时尝试 latin-1 兜底
"""
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}")
try:
return path.read_text(encoding="utf-8")
except UnicodeDecodeError:
return path.read_text(encoding="latin-1")
def read_pdf_file(file_path: str) -> str:
"""
读取 PDF 文件内容
Args:
file_path: PDF 文件路径
Returns:
提取的文本内容
Raises:
ImportError: 未安装 PyPDF2
FileNotFoundError: 文件不存在
"""
try:
import PyPDF2
except ImportError:
raise ImportError("请安装 PyPDF2: pip install PyPDF2")
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}")
texts = []
with open(file_path, "rb") as f:
reader = PyPDF2.PdfReader(f)
for page in reader.pages:
text = page.extract_text()
if text:
texts.append(text)
return "\n".join(texts)
def read_docx_file(file_path: str) -> str:
"""
读取 Word (.docx) 文件内容
Args:
file_path: docx 文件路径
Returns:
提取的文本内容段落合并
Raises:
ImportError: 未安装 python-docx
FileNotFoundError: 文件不存在
"""
try:
from docx import Document
except ImportError:
raise ImportError("请安装 python-docx: pip install python-docx")
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}")
doc = Document(file_path)
return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
def read_file_auto(file_path: str) -> str:
"""
根据文件扩展名自动选择读取方式
Args:
file_path: 文件路径
Returns:
文件文本内容
Raises:
ValueError: 不支持的文件类型
"""
ext = Path(file_path).suffix.lower()
readers = {
".txt": read_text_file,
".md": read_text_file,
".py": read_text_file,
".json": read_text_file,
".yaml": read_text_file,
".yml": read_text_file,
".pdf": read_pdf_file,
".docx": read_docx_file,
}
reader = readers.get(ext)
if reader is None:
raise ValueError(f"不支持的文件类型: {ext},支持: {list(readers.keys())}")
return reader(file_path)
def merge_knowledge_files(file_paths: List[str]) -> str:
"""
合并多个知识库文件为单一文本
Args:
file_paths: 知识库文件路径列表
Returns:
合并后的知识库文本包含文件名分隔符
"""
if not file_paths:
return ""
sections = []
for fp in file_paths:
try:
content = read_file_auto(fp)
file_name = Path(fp).name
sections.append(f"### 知识库文件: {file_name}\n{content}")
except Exception as e:
sections.append(f"### 知识库文件: {fp}\n[读取失败: {e}]")
return "\n\n".join(sections)

View File

@ -0,0 +1,430 @@
# utils/output_writer.py - 代码文件 & JSON 输出工具
import os
import json
from pathlib import Path
from typing import Dict, List
import config
# 各语言文件扩展名映射
LANGUAGE_EXT_MAP: Dict[str, str] = {
"python": ".py",
"javascript": ".js",
"typescript": ".ts",
"java": ".java",
"go": ".go",
"rust": ".rs",
"cpp": ".cpp",
"c": ".c",
"csharp": ".cs",
"ruby": ".rb",
"php": ".php",
"swift": ".swift",
"kotlin": ".kt",
}
# 合法的通用类型集合
VALID_TYPES = {
"integer", "string", "boolean", "float",
"list", "dict", "object", "void", "any",
}
# 合法的 inout 值
VALID_INOUT = {"in", "out", "inout"}
def get_file_extension(language: str) -> str:
"""
获取指定语言的文件扩展名
Args:
language: 编程语言名称小写
Returns:
文件扩展名含点号 '.py'
"""
return LANGUAGE_EXT_MAP.get(language.lower(), ".txt")
def build_project_output_dir(project_name: str) -> str:
"""
构建项目输出目录路径
Args:
project_name: 项目名称
Returns:
输出目录路径
"""
safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_name)
return os.path.join(config.OUTPUT_BASE_DIR, safe_name)
def ensure_project_dir(project_name: str) -> str:
"""
确保项目输出目录存在不存在则创建
Args:
project_name: 项目名称
Returns:
创建好的目录路径
"""
output_dir = build_project_output_dir(project_name)
os.makedirs(output_dir, exist_ok=True)
init_file = os.path.join(output_dir, "__init__.py")
if not os.path.exists(init_file):
Path(init_file).write_text(
"# Auto-generated project package\n", encoding="utf-8"
)
return output_dir
def write_code_file(
output_dir: str,
function_name: str,
language: str,
content: str,
) -> str:
"""
将代码内容写入指定目录的文件
Args:
output_dir: 输出目录路径
function_name: 函数名用于生成文件名
language: 编程语言
content: 代码内容
Returns:
写入的文件完整路径
"""
ext = get_file_extension(language)
file_name = f"{function_name}{ext}"
file_path = os.path.join(output_dir, file_name)
Path(file_path).write_text(content, encoding="utf-8")
return file_path
def write_project_readme(
output_dir: str,
project_name: str,
requirements_summary: str,
) -> str:
"""
在项目目录生成 README.md 文件
Args:
output_dir: 项目输出目录
project_name: 项目名称
requirements_summary: 功能需求摘要文本
Returns:
README.md 文件路径
"""
readme_content = f"""# {project_name}
> Auto-generated by Requirement Analyzer
## 功能需求列表
{requirements_summary}
"""
readme_path = os.path.join(output_dir, "README.md")
Path(readme_path).write_text(readme_content, encoding="utf-8")
return readme_path
# ══════════════════════════════════════════════════════
# 函数签名 JSON 导出
# ══════════════════════════════════════════════════════
def build_signatures_document(
project_name: str,
project_description: str,
signatures: List[dict],
) -> dict:
"""
将函数签名列表包装为带项目信息的顶层文档结构
Args:
project_name: 项目名称写入 "project" 字段
project_description: 项目描述写入 "description" 字段
signatures: 函数签名 dict 列表写入 "functions" 字段
Returns:
顶层文档 dict结构为::
{
"project": "<project_name>",
"description": "<project_description>",
"functions": [ ... ]
}
"""
return {
"project": project_name,
"description": project_description or "",
"functions": signatures,
}
def patch_signatures_with_url(
signatures: List[dict],
func_name_to_url: Dict[str, str],
) -> List[dict]:
"""
将代码文件的路径URL回写到对应函数签名的 "url" 字段
遍历签名列表根据 signature["name"] func_name_to_url 中查找
对应路径找到则写入 "url" 字段未找到则写入空字符串不抛出异常
"url" 字段插入位置紧跟在 "type" 字段之后以保持字段顺序的可读性::
{
"name": "create_user",
"requirement_id": "REQ.01",
"description": "...",
"type": "function",
"url": "/abs/path/to/create_user.py", 新增
"parameters": { ... },
"return": { ... }
}
Args:
signatures: 原始签名列表in-place 修改
func_name_to_url: {函数名: 代码文件绝对路径} 映射表
CodeGenerator.generate_batch() 的进度回调收集
Returns:
修改后的签名列表与传入的同一对象方便链式调用
"""
for sig in signatures:
func_name = sig.get("name", "")
url = func_name_to_url.get(func_name, "")
_insert_field_after(sig, after_key="type", new_key="url", new_value=url)
return signatures
def _insert_field_after(
d: dict,
after_key: str,
new_key: str,
new_value,
) -> None:
"""
在有序 dict 中将 new_key 插入到 after_key 之后
after_key 不存在则追加到末尾
new_key 已存在则直接更新其值不改变位置
Args:
d: 目标 dictPython 3.7+ 保证插入顺序
after_key: 参考键名
new_key: 要插入的键名
new_value: 要插入的值
"""
if new_key in d:
d[new_key] = new_value
return
items = list(d.items())
insert_pos = len(items)
for i, (k, _) in enumerate(items):
if k == after_key:
insert_pos = i + 1
break
items.insert(insert_pos, (new_key, new_value))
d.clear()
d.update(items)
def write_function_signatures_json(
output_dir: str,
signatures: List[dict],
project_name: str,
project_description: str,
file_name: str = "function_signatures.json",
) -> str:
"""
将函数签名列表连同项目信息一起导出为 JSON 文件
输出的 JSON 顶层结构为::
{
"project": "<project_name>",
"description": "<project_description>",
"functions": [
{
"name": "...",
"requirement_id": "...",
"description": "...",
"type": "function",
"url": "/abs/path/to/xxx.py",
"parameters": { ... },
"return": { ... }
},
...
]
}
Args:
output_dir: JSON 文件写入目录
signatures: 函数签名 dict 列表应已通过
patch_signatures_with_url() 写入 "url" 字段
project_name: 项目名称
project_description: 项目描述
file_name: 输出文件名默认 function_signatures.json
Returns:
写入的 JSON 文件完整路径
Raises:
OSError: 目录不可写
"""
os.makedirs(output_dir, exist_ok=True)
document = build_signatures_document(project_name, project_description, signatures)
file_path = os.path.join(output_dir, file_name)
with open(file_path, "w", encoding="utf-8") as f:
json.dump(document, f, ensure_ascii=False, indent=2)
return file_path
# ══════════════════════════════════════════════════════
# 签名结构校验
# ══════════════════════════════════════════════════════
def validate_signature_schema(signature: dict) -> List[str]:
"""
校验单个函数签名 dict 是否符合规范
校验范围
- 顶层必填字段name / requirement_id / description / type / parameters
- 可选字段 "url"若存在则必须为非空字符串
- parameters每个参数的 type / inout / required 字段
- returntype 字段 + on_success / on_failure 子结构
- void 函数on_success / on_failure 应为 null
- void 函数on_success / on_failure 必须存在
value非空 description非空均需填写
Args:
signature: 单个函数签名 dict
Returns:
错误信息字符串列表列表为空表示校验通过
"""
errors: List[str] = []
# ── 顶层必填字段 ──────────────────────────────────
for key in ("name", "requirement_id", "description", "type", "parameters"):
if key not in signature:
errors.append(f"缺少顶层字段: '{key}'")
# ── url 字段(可选,存在时校验非空)─────────────────
if "url" in signature:
if not isinstance(signature["url"], str):
errors.append("'url' 字段必须是字符串类型")
elif signature["url"] == "":
errors.append("'url' 字段不能为空字符串(代码文件路径未成功回写)")
# ── parameters ────────────────────────────────────
params = signature.get("parameters", {})
if not isinstance(params, dict):
errors.append("'parameters' 必须是 dict 类型")
else:
for pname, pdef in params.items():
if not isinstance(pdef, dict):
errors.append(f"参数 '{pname}' 定义必须是 dict")
continue
# type支持联合类型如 "string|integer"
if "type" not in pdef:
errors.append(f"参数 '{pname}' 缺少 'type' 字段")
else:
parts = [p.strip() for p in pdef["type"].split("|")]
if not all(p in VALID_TYPES for p in parts):
errors.append(
f"参数 '{pname}' 的 type='{pdef['type']}' 含有不合法的类型"
)
# inout
if "inout" not in pdef:
errors.append(f"参数 '{pname}' 缺少 'inout' 字段")
elif pdef["inout"] not in VALID_INOUT:
errors.append(
f"参数 '{pname}' 的 inout='{pdef['inout']}' 应为 in/out/inout"
)
# required
if "required" not in pdef:
errors.append(f"参数 '{pname}' 缺少 'required' 字段")
elif not isinstance(pdef["required"], bool):
errors.append(
f"参数 '{pname}''required' 应为布尔值 true/false"
f"当前为: {pdef['required']!r}"
)
# ── return ────────────────────────────────────────
ret = signature.get("return")
if ret is None:
errors.append(
"缺少 'return' 字段void 函数请填 "
"{\"type\": \"void\", \"on_success\": null, \"on_failure\": null}"
)
elif not isinstance(ret, dict):
errors.append("'return' 必须是 dict 类型")
else:
ret_type = ret.get("type")
if not ret_type:
errors.append("'return' 缺少 'type' 字段")
elif ret_type not in VALID_TYPES:
errors.append(f"'return.type'='{ret_type}' 不在合法类型列表中")
is_void = (ret_type == "void")
for sub_key in ("on_success", "on_failure"):
sub = ret.get(sub_key)
if is_void:
if sub is not None:
errors.append(
f"void 函数的 'return.{sub_key}' 应为 null"
f"当前为: {sub!r}"
)
else:
if sub is None:
errors.append(
f"非 void 函数缺少 'return.{sub_key}'"
f"请描述{'成功' if sub_key == 'on_success' else '失败'}时的返回值"
)
elif not isinstance(sub, dict):
errors.append(f"'return.{sub_key}' 必须是 dict 类型")
else:
if "value" not in sub:
errors.append(f"'return.{sub_key}' 缺少 'value' 字段")
elif sub["value"] == "":
errors.append(
f"'return.{sub_key}.value' 不能为空字符串,"
f"请填写具体返回值、值域描述或结构示例"
)
if "description" not in sub or sub.get("description") in (None, ""):
errors.append(f"'return.{sub_key}.description' 不能为空")
return errors
def validate_all_signatures(signatures: List[dict]) -> Dict[str, List[str]]:
"""
批量校验函数签名列表
注意此函数接受的是纯签名列表即顶层文档的 "functions" 字段
而非包含 project/description 的顶层文档
Args:
signatures: 函数签名 dict 列表
Returns:
{函数名: [错误信息, ...]} 字典仅包含有错误的条目
"""
report: Dict[str, List[str]] = {}
for sig in signatures:
name = sig.get("name", f"unknown_{id(sig)}")
errs = validate_signature_schema(sig)
if errs:
report[name] = errs
return report