#!/usr/bin/env python3
"""
AI-Powered API Test Generator, Runner & Coverage Analyzer
──────────────────────────────────────────────────────────
Usage:
    python main.py --api-desc examples/api_desc.json \\
        --requirements "创建用户,删除用户"

    python main.py --api-desc examples/api_desc.json \\
        --req-file examples/requirements.txt
"""

import argparse
import logging
import sys
from pathlib import Path

from config import config
from core.parser import InterfaceParser, ApiDescriptor
from core.prompt_builder import PromptBuilder
from core.llm_client import LLMClient
from core.test_generator import TestGenerator
from core.test_runner import TestRunner
from core.analyzer import CoverageAnalyzer, CoverageReport
from core.report_generator import ReportGenerator

# Root logging config: INFO to stdout; --debug raises the level later in main().
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)


# ── CLI ───────────────────────────────────────────────────────
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the tool.

    Returns:
        argparse.ArgumentParser: parser configured with every CLI option.
    """
    parser = argparse.ArgumentParser(
        description="AI-Powered API Test Generator & Coverage Analyzer"
    )
    # (flag, add_argument keyword arguments) pairs, in display order.
    options = [
        ("--api-desc",
         dict(required=True,
              help="接口描述 JSON 文件(含 project / description / units 字段)")),
        ("--requirements",
         dict(default="", help="测试需求(逗号分隔)")),
        ("--req-file",
         dict(default="", help="测试需求文件(每行一条)")),
        ("--skip-run",
         dict(action="store_true", help="只生成测试文件,不执行")),
        ("--skip-analyze",
         dict(action="store_true", help="跳过覆盖率分析")),
        ("--output-dir",
         dict(default="",
              help="测试文件根输出目录(默认 config.GENERATED_TESTS_DIR)")),
        ("--debug",
         dict(action="store_true", help="开启 DEBUG 日志")),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser
def load_requirements(args) -> list[str]:
    """Gather test requirements from the parsed CLI arguments.

    ``--requirements`` contributes comma-separated entries and
    ``--req-file`` contributes one entry per line, in that order;
    blank entries are dropped.  Exits the process with status 1
    when no requirement was supplied at all.
    """
    collected: list[str] = []
    if args.requirements:
        collected.extend(
            piece.strip()
            for piece in args.requirements.split(",")
            if piece.strip()
        )
    if args.req_file:
        text = Path(args.req_file).read_text(encoding="utf-8")
        collected.extend(
            line.strip() for line in text.splitlines() if line.strip()
        )
    if not collected:
        logger.error(
            "No requirements provided. Use --requirements or --req-file."
        )
        sys.exit(1)
    return collected


# ── Main ──────────────────────────────────────────────────────
def main():
    """Command-line entry point.

    Pipeline: parse the API description, load the requirements, ask the
    LLM for test cases, write the test files, then (unless skipped) run
    them and produce coverage reports.
    """
    cli = build_arg_parser().parse_args()

    if cli.debug:
        logging.getLogger().setLevel(logging.DEBUG)
    if cli.output_dir:
        config.GENERATED_TESTS_DIR = cli.output_dir

    # ── Step 1: parse the interface description ───────────────
    logger.info("▶ Step 1: Parsing interface description …")
    iface_parser = InterfaceParser()
    spec: ApiDescriptor = iface_parser.parse_file(cli.api_desc)

    project = spec.project
    project_desc = spec.description
    interfaces = spec.interfaces

    logger.info(f" Project : {project}")
    logger.info(f" Description: {project_desc}")
    logger.info(
        f" Interfaces : {len(interfaces)} — "
        f"{[i.name for i in interfaces]}"
    )
    # Log how each interface's target was resolved.
    for iface in interfaces:
        if iface.protocol == "function":
            logger.info(
                f" [{iface.name}] source={iface.source_file} "
                f"module={iface.module_path}"
            )
        else:
            logger.info(
                f" [{iface.name}] full_url={iface.http_full_url}"
            )

    # ── Step 2: load the test requirements ────────────────────
    logger.info("▶ Step 2: Loading test requirements …")
    requirements = load_requirements(cli)
    for idx, req in enumerate(requirements, 1):
        logger.info(f" {idx}. {req}")

    # ── Step 3: have the LLM generate test cases ──────────────
    logger.info("▶ Step 3: Calling LLM …")
    builder = PromptBuilder()
    system_prompt = builder.get_system_prompt()
    user_prompt = builder.build_user_prompt(
        requirements, interfaces, iface_parser,
        project=project,
        project_desc=project_desc,
    )
    test_cases = LLMClient().generate_test_cases(system_prompt, user_prompt)
    logger.info(f" LLM returned {len(test_cases)} test case(s)")

    # ── Step 4: write the test files ──────────────────────────
    logger.info(
        f"▶ Step 4: Generating test files (project='{project}') …"
    )
    generator = TestGenerator(project=project, project_desc=project_desc)
    test_files = generator.generate(test_cases)
    out = generator.output_dir

    run_results = []

    if not cli.skip_run:
        # ── Step 5: execute the generated tests ───────────────
        logger.info("▶ Step 5: Running tests …")
        runner = TestRunner()
        run_results = runner.run_all(test_files)
        runner.print_summary(run_results)
        runner.save_results(run_results, str(out / "run_results.json"))

    if not cli.skip_analyze:
        # ── Step 6: coverage analysis ─────────────────────────
        logger.info("▶ Step 6: Analyzing coverage …")
        analyzer = CoverageAnalyzer(
            interfaces=interfaces,
            requirements=requirements,
            test_cases=test_cases,
            run_results=run_results,
        )
        report = analyzer.analyze()

        # ── Step 7: emit the reports ──────────────────────────
        logger.info("▶ Step 7: Generating reports …")
        rg = ReportGenerator()
        rg.save_json(report, str(out / "coverage_report.json"))
        rg.save_html(report, str(out / "coverage_report.html"))

        _print_terminal_summary(report, out, project)

    logger.info(f"\n✅ Done. Output: {out.resolve()}")


# ── Terminal summary ──────────────────────────────────────────
def _print_terminal_summary(
    report: CoverageReport, out: Path, project: str
):
    """Render a boxed coverage summary for *report* to stdout.

    Args:
        report: aggregated coverage metrics and the gap list.
        out: directory containing the generated report files.
        project: project name shown in the header.
    """
    width = 66
    heavy = "═" * width
    light = "─" * width

    def gauge(rate: float, w: int = 20) -> str:
        # Progress bar plus a status icon keyed to the rate thresholds.
        done = int(rate * w)
        if rate >= 0.8:
            icon = "✅"
        elif rate >= 0.5:
            icon = "⚠️ "
        else:
            icon = "❌"
        return "█" * done + "░" * (w - done) + f" {rate * 100:.1f}% {icon}"

    print(f"\n{heavy}")
    print(f" PROJECT : {project}")
    print(" COVERAGE SUMMARY")
    print(heavy)
    for label, rate in (
        ("接口覆盖率", report.interface_coverage_rate),
        ("需求覆盖率", report.requirement_coverage_rate),
        ("入参覆盖率", report.avg_in_param_coverage_rate),
        ("成功返回字段覆盖", report.avg_success_field_coverage_rate),
        ("失败返回字段覆盖", report.avg_failure_field_coverage_rate),
        ("用例通过率", report.pass_rate),
    ):
        print(f" {label} {gauge(rate)}")
    print(light)
    print(f" Total Gaps : {len(report.gaps)}")
    print(f" 🔴 Critical: {report.critical_gap_count}")
    print(f" 🟠 High : {report.high_gap_count}")

    if report.gaps:
        print(light)
        print(" Top Gaps (up to 8):")
        severity_icons = {
            "critical": "🔴", "high": "🟠",
            "medium": "🟡", "low": "🔵",
        }
        for gap in report.gaps[:8]:
            mark = severity_icons.get(gap.severity, "⚪")
            print(f" {mark} [{gap.gap_type}] {gap.target}")
            print(f" → {gap.suggestion}")

    print(light)
    print(f" Output : {out.resolve()}")
    print(" • coverage_report.html ← open in browser")
    print(" • coverage_report.json")
    print(" • run_results.json")
    print(" • test_cases_summary.json")
    print(f"{heavy}\n")
if __name__ == "__main__":
    main()