AIDeveloper-PC/ai_test_generator/main.py

232 lines
8.7 KiB
Python
Raw Normal View History

2026-03-04 18:09:45 +00:00
#!/usr/bin/env python3
"""
AI-Powered API Test Generator, Runner & Coverage Analyzer
Usage:
python main.py --api-desc examples/api_desc.json \\
--requirements "创建用户,删除用户"
python main.py --api-desc examples/api_desc.json \\
--req-file examples/requirements.txt
"""
import argparse
import logging
import sys
from pathlib import Path
from config import config
from core.parser import InterfaceParser, ApiDescriptor
from core.prompt_builder import PromptBuilder
from core.llm_client import LLMClient
from core.test_generator import TestGenerator
from core.test_runner import TestRunner
from core.analyzer import CoverageAnalyzer, CoverageReport
from core.report_generator import ReportGenerator
# Log line layout: timestamp, [LEVEL], logger name, then the message.
# The " - " separator keeps the logger name from running into the message.
LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s - %(message)s"

# Send everything to stdout so pipeline logs interleave in order with the
# print()-based terminal summary.
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
# ── CLI ───────────────────────────────────────────────────────
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the test-generation pipeline.

    Only ``--api-desc`` is required. Requirements may come from
    ``--requirements`` and/or ``--req-file``; the check that at least one was
    given happens later, in ``load_requirements``.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    p = argparse.ArgumentParser(
        description="AI-Powered API Test Generator & Coverage Analyzer"
    )
    # Input: interface description and the test requirements.
    p.add_argument(
        "--api-desc", required=True,
        help="接口描述 JSON 文件(含 project / description / units 字段)",
    )
    p.add_argument(
        "--requirements", default="",
        help="测试需求(逗号分隔)",
    )
    p.add_argument(
        "--req-file", default="",
        help="测试需求文件(每行一条)",
    )
    # Pipeline switches.
    p.add_argument(
        "--skip-run", action="store_true",
        help="只生成测试文件,不执行",
    )
    p.add_argument(
        "--skip-analyze", action="store_true",
        help="跳过覆盖率分析",
    )
    # Output and diagnostics.
    p.add_argument(
        "--output-dir", default="",
        help="测试文件根输出目录(默认 config.GENERATED_TESTS_DIR)",
    )
    p.add_argument(
        "--debug", action="store_true",
        help="开启 DEBUG 日志",
    )
    return p
def load_requirements(args) -> list[str]:
    """Collect test requirements from the parsed CLI arguments.

    Comma-separated entries from ``args.requirements`` come first, followed
    by one entry per non-blank line of ``args.req_file``. Surrounding
    whitespace is stripped and empty entries are dropped.

    Args:
        args: Namespace with ``requirements`` (str) and ``req_file``
            (path str; "" means not given).

    Returns:
        The non-empty list of requirement strings.

    Exits:
        Logs an error and calls ``sys.exit(1)`` when the requirements file
        cannot be read or decoded, or when no requirement was found at all.
    """
    reqs: list[str] = []
    if args.requirements:
        reqs += [r.strip() for r in args.requirements.split(",") if r.strip()]
    if args.req_file:
        try:
            # utf-8-sig also accepts files saved with a BOM (e.g. by Windows
            # Notepad). Plain utf-8 would leave "\ufeff" glued to the first
            # requirement, because str.strip() does not remove it.
            with open(args.req_file, "r", encoding="utf-8-sig") as f:
                reqs += [line.strip() for line in f if line.strip()]
        except (OSError, UnicodeDecodeError) as e:
            # Report bad paths / non-UTF-8 files cleanly instead of crashing
            # with a traceback; same exit status as the "no requirements" case.
            logger.error(
                f"Cannot read requirements file {args.req_file!r}: {e}"
            )
            sys.exit(1)
    if not reqs:
        logger.error(
            "No requirements provided. Use --requirements or --req-file."
        )
        sys.exit(1)
    return reqs
# ── Main ──────────────────────────────────────────────────────
def main():
    """Run the end-to-end pipeline driven by the command-line arguments.

    Steps:
      1. Parse the interface description (``--api-desc``).
      2. Load test requirements (``--requirements`` / ``--req-file``).
      3. Ask the LLM for test cases.
      4. Write test files into the generator's output directory.
      5. Run them and save ``run_results.json`` (skipped by ``--skip-run``).
      6-7. Analyze coverage, write ``coverage_report.json`` / ``.html`` and
           print a terminal summary (skipped by ``--skip-analyze``).
    """
    args = build_arg_parser().parse_args()
    if args.debug:
        # Lower the root logger so DEBUG records from every module are shown.
        logging.getLogger().setLevel(logging.DEBUG)
    if args.output_dir:
        # Overrides the global config before TestGenerator is constructed
        # below; presumably TestGenerator derives output_dir from it — confirm.
        config.GENERATED_TESTS_DIR = args.output_dir

    # ── Step 1: parse the interface description ──────────────
    logger.info("▶ Step 1: Parsing interface description …")
    parser: InterfaceParser = InterfaceParser()
    descriptor: ApiDescriptor = parser.parse_file(args.api_desc)
    project = descriptor.project
    project_desc = descriptor.description
    interfaces = descriptor.interfaces
    logger.info(f" Project : {project}")
    logger.info(f" Description: {project_desc}")
    # Separate the count from the name list so the two don't run together.
    logger.info(
        f" Interfaces : {len(interfaces)} → "
        f"{[i.name for i in interfaces]}"
    )
    # Show what each interface resolved to: "function" interfaces point at a
    # source file / module, all others at a full HTTP URL.
    for iface in interfaces:
        if iface.protocol == "function":
            logger.info(
                f" [{iface.name}] source={iface.source_file} "
                f"module={iface.module_path}"
            )
        else:
            logger.info(
                f" [{iface.name}] full_url={iface.http_full_url}"
            )

    # ── Step 2: load test requirements ───────────────────────
    logger.info("▶ Step 2: Loading test requirements …")
    requirements = load_requirements(args)  # exits with status 1 if empty
    for i, req in enumerate(requirements, 1):
        logger.info(f" {i}. {req}")

    # ── Step 3: ask the LLM for test cases ───────────────────
    logger.info("▶ Step 3: Calling LLM …")
    builder = PromptBuilder()
    test_cases = LLMClient().generate_test_cases(
        builder.get_system_prompt(),
        builder.build_user_prompt(
            requirements, interfaces, parser,
            project=project,
            project_desc=project_desc,
        ),
    )
    logger.info(f" LLM returned {len(test_cases)} test case(s)")

    # ── Step 4: write test files ─────────────────────────────
    logger.info(
        f"▶ Step 4: Generating test files (project='{project}') …"
    )
    generator = TestGenerator(project=project, project_desc=project_desc)
    test_files = generator.generate(test_cases)
    out = generator.output_dir
    # Stays empty under --skip-run; the analyzer still receives it.
    run_results = []

    if not args.skip_run:
        # ── Step 5: execute the generated tests ──────────────
        logger.info("▶ Step 5: Running tests …")
        runner = TestRunner()
        run_results = runner.run_all(test_files)
        runner.print_summary(run_results)
        runner.save_results(run_results, str(out / "run_results.json"))

    if not args.skip_analyze:
        # ── Step 6: coverage analysis ────────────────────────
        logger.info("▶ Step 6: Analyzing coverage …")
        report = CoverageAnalyzer(
            interfaces=interfaces,
            requirements=requirements,
            test_cases=test_cases,
            run_results=run_results,
        ).analyze()

        # ── Step 7: reports ──────────────────────────────────
        logger.info("▶ Step 7: Generating reports …")
        rg = ReportGenerator()
        rg.save_json(report, str(out / "coverage_report.json"))
        rg.save_html(report, str(out / "coverage_report.html"))
        _print_terminal_summary(report, out, project)

    logger.info(f"\n✅ Done. Output: {out.resolve()}")
# ── 终端摘要 ──────────────────────────────────────────────────
def _print_terminal_summary(
    report: CoverageReport, out: Path, project: str
):
    """Print a boxed coverage summary for *report* to stdout.

    Args:
        report: Result of ``CoverageAnalyzer.analyze()``. Only the rate
            attributes, the gap counts and ``gaps`` read below are used.
        out: Output directory holding the generated artefacts.
        project: Project name shown in the header.
    """
    W = 66
    rule = "─" * W

    def bar(rate: float, w: int = 20) -> str:
        """Render *rate* (expected 0..1) as a w-cell bar, percentage and icon."""
        # Clamp so an out-of-range rate can neither overflow the bar nor
        # produce a negative empty segment.
        filled = max(0, min(w, int(rate * w)))
        empty = w - filled
        icon = "✅" if rate >= 0.8 else ("⚠️ " if rate >= 0.5 else "❌")
        return f"{'█' * filled}{'░' * empty} {rate * 100:.1f}% {icon}"

    print(f"\n{rule}")
    print(f" PROJECT : {project}")
    print(" COVERAGE SUMMARY")
    print(rule)
    metrics = (
        ("接口覆盖率", report.interface_coverage_rate),
        ("需求覆盖率", report.requirement_coverage_rate),
        ("入参覆盖率", report.avg_in_param_coverage_rate),
        ("成功返回字段覆盖", report.avg_success_field_coverage_rate),
        ("失败返回字段覆盖", report.avg_failure_field_coverage_rate),
        ("用例通过率", report.pass_rate),
    )
    for label, rate in metrics:
        print(f" {label} {bar(rate)}")
    print(rule)
    print(f" Total Gaps : {len(report.gaps)}")
    print(f" 🔴 Critical: {report.critical_gap_count}")
    print(f" 🟠 High : {report.high_gap_count}")

    if report.gaps:
        print(rule)
        print(" Top Gaps (up to 8):")
        icons = {
            "critical": "🔴", "high": "🟠",
            "medium": "🟡", "low": "🔵",
        }
        for g in report.gaps[:8]:
            icon = icons.get(g.severity, "")
            print(f" {icon} [{g.gap_type}] {g.target}")
            # Indent the suggestion under its gap so the pair reads as one item.
            print(f"     → {g.suggestion}")

    print(rule)
    print(f" Output : {out.resolve()}")
    artefacts = (
        ("coverage_report.html", " ← open in browser"),
        ("coverage_report.json", ""),
        ("run_results.json", ""),
        ("test_cases_summary.json", ""),
    )
    for fname, note in artefacts:
        # e.g. run_results.json is never written under --skip-run, so only
        # advertise files that are actually present.
        if (out / fname).exists():
            print(f" • {fname}{note}")
    print(f"{rule}\n")
# Script entry point: run the pipeline only when executed directly, not when
# this module is imported.
if __name__ == "__main__":
    main()