AIDeveloper-PC/ai_test_generator/main.py

232 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
AI-Powered API Test Generator, Runner & Coverage Analyzer
──────────────────────────────────────────────────────────
Usage:
python main.py --api-desc examples/api_desc.json \\
--requirements "创建用户,删除用户"
python main.py --api-desc examples/api_desc.json \\
--req-file examples/requirements.txt
"""
import argparse
import logging
import sys
from pathlib import Path
from config import config
from core.parser import InterfaceParser, ApiDescriptor
from core.prompt_builder import PromptBuilder
from core.llm_client import LLMClient
from core.test_generator import TestGenerator
from core.test_runner import TestRunner
from core.analyzer import CoverageAnalyzer, CoverageReport
from core.report_generator import ReportGenerator
# Log record layout for the whole CLI. The " - " separator keeps the logger
# name visually apart from the message text.
LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s - %(message)s"

# Route all logging to stdout at INFO level. basicConfig is a no-op when the
# root logger already has handlers, so importing this module from a host that
# configured logging first does not override that configuration.
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
# ── CLI ───────────────────────────────────────────────────────
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the test-generation pipeline.

    Only ``--api-desc`` is required by argparse. Supplying at least one of
    ``--requirements`` / ``--req-file`` is enforced later by
    ``load_requirements``, not here.

    Returns:
        A configured :class:`argparse.ArgumentParser`; every optional string
        argument defaults to ``""`` and every flag to ``False``.
    """
    p = argparse.ArgumentParser(
        description="AI-Powered API Test Generator & Coverage Analyzer"
    )
    p.add_argument(
        "--api-desc", required=True,
        help="接口描述 JSON 文件(含 project / description / units 字段)",
    )
    p.add_argument(
        "--requirements", default="",
        help="测试需求(逗号分隔)",
    )
    p.add_argument(
        "--req-file", default="",
        help="测试需求文件(每行一条)",
    )
    p.add_argument(
        "--skip-run", action="store_true",
        help="只生成测试文件,不执行",
    )
    p.add_argument(
        "--skip-analyze", action="store_true",
        help="跳过覆盖率分析",
    )
    p.add_argument(
        "--output-dir", default="",
        # Closing full-width parenthesis restored so the help text is balanced.
        help="测试文件根输出目录(默认 config.GENERATED_TESTS_DIR)",
    )
    p.add_argument(
        "--debug", action="store_true",
        help="开启 DEBUG 日志",
    )
    return p
def load_requirements(args) -> list[str]:
    """Collect test requirements from the parsed CLI arguments.

    Two optional sources are concatenated in this order:

    * ``args.requirements`` - one string split on commas. Both the ASCII
      comma ``,`` and the full-width comma ``,`` (typed by Chinese input
      methods) act as separators.
    * ``args.req_file`` - a UTF-8 text file with one requirement per line.
      A leading byte-order mark, as written by some Windows editors, is
      ignored.

    Surrounding whitespace is stripped and blank entries are dropped.

    Args:
        args: Namespace exposing ``requirements`` and ``req_file`` string
            attributes; an empty string (or other falsy value) means
            "not given".

    Returns:
        The non-empty list of requirement strings.

    Exits:
        Logs an error and calls ``sys.exit(1)`` when the requirements file
        cannot be opened/read, or when no requirement was collected at all.
    """
    reqs: list[str] = []
    if args.requirements:
        # Normalize full-width commas so "a,b" splits the same as "a,b".
        normalized = args.requirements.replace(",", ",")
        reqs += [r.strip() for r in normalized.split(",") if r.strip()]
    if args.req_file:
        try:
            # utf-8-sig drops a leading BOM; BOM-less UTF-8 decodes exactly
            # as with plain utf-8. Without it the BOM (not whitespace)
            # would survive .strip() and corrupt the first requirement.
            with open(args.req_file, "r", encoding="utf-8-sig") as f:
                reqs += [line.strip() for line in f if line.strip()]
        except OSError as exc:
            # Same log-and-exit handling as the "no requirements" case below.
            logger.error(
                "Cannot read requirements file %s: %s", args.req_file, exc
            )
            sys.exit(1)
    if not reqs:
        logger.error(
            "No requirements provided. Use --requirements or --req-file."
        )
        sys.exit(1)
    return reqs
# ── Main ──────────────────────────────────────────────────────
def main():
    """Run the CLI pipeline end to end.

    Steps, in order:

    1. Parse the interface description file given by ``--api-desc``.
    2. Load test requirements (``--requirements`` / ``--req-file``).
    3. Ask the LLM for test cases using the PromptBuilder's system prompt
       and a user prompt built from requirements + interfaces.
    4. Write the test files with :class:`TestGenerator`.
    5. Unless ``--skip-run``: execute them, print a summary and save
       ``run_results.json`` into the generator's output directory.
    6. Unless ``--skip-analyze``: compute coverage (with an empty result list
       when step 5 was skipped).
    7. Still under ``--skip-analyze`` being unset: write
       ``coverage_report.json`` / ``coverage_report.html`` and print the
       terminal summary.
    """
    args = build_arg_parser().parse_args()
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)
    if args.output_dir:
        # Override before TestGenerator is constructed so it picks this up
        # (assumes TestGenerator reads config at construction - confirm).
        config.GENERATED_TESTS_DIR = args.output_dir

    # ── Step 1: parse the interface description ──────────────
    logger.info("▶ Step 1: Parsing interface description …")
    parser: InterfaceParser = InterfaceParser()
    descriptor: ApiDescriptor = parser.parse_file(args.api_desc)
    project = descriptor.project
    project_desc = descriptor.description
    interfaces = descriptor.interfaces
    logger.info(f" Project : {project}")
    logger.info(f" Description: {project_desc}")
    # Separate the count from the name list so the line stays readable.
    logger.info(
        f" Interfaces : {len(interfaces)} "
        f"{[i.name for i in interfaces]}"
    )
    # Log how each interface's call target was resolved: source/module for
    # in-process functions, the full URL for everything else.
    for iface in interfaces:
        if iface.protocol == "function":
            logger.info(
                f" [{iface.name}] source={iface.source_file} "
                f"module={iface.module_path}"
            )
        else:
            logger.info(
                f" [{iface.name}] full_url={iface.http_full_url}"
            )

    # ── Step 2: load test requirements ───────────────────────
    logger.info("▶ Step 2: Loading test requirements …")
    requirements = load_requirements(args)
    for i, req in enumerate(requirements, 1):
        logger.info(f" {i}. {req}")

    # ── Step 3: ask the LLM for test cases ───────────────────
    logger.info("▶ Step 3: Calling LLM …")
    builder = PromptBuilder()
    test_cases = LLMClient().generate_test_cases(
        builder.get_system_prompt(),
        builder.build_user_prompt(
            requirements, interfaces, parser,
            project=project,
            project_desc=project_desc,
        ),
    )
    logger.info(f" LLM returned {len(test_cases)} test case(s)")

    # ── Step 4: write the test files ─────────────────────────
    logger.info(
        f"▶ Step 4: Generating test files (project='{project}') …"
    )
    generator = TestGenerator(project=project, project_desc=project_desc)
    test_files = generator.generate(test_cases)
    out = generator.output_dir

    # Stays empty under --skip-run; the analyzer below still accepts it.
    run_results = []
    if not args.skip_run:
        # ── Step 5: execute the tests ────────────────────────
        logger.info("▶ Step 5: Running tests …")
        runner = TestRunner()
        run_results = runner.run_all(test_files)
        runner.print_summary(run_results)
        runner.save_results(run_results, str(out / "run_results.json"))

    if not args.skip_analyze:
        # ── Step 6: coverage analysis ────────────────────────
        logger.info("▶ Step 6: Analyzing coverage …")
        report = CoverageAnalyzer(
            interfaces=interfaces,
            requirements=requirements,
            test_cases=test_cases,
            run_results=run_results,
        ).analyze()
        # ── Step 7: reports ──────────────────────────────────
        logger.info("▶ Step 7: Generating reports …")
        rg = ReportGenerator()
        rg.save_json(report, str(out / "coverage_report.json"))
        rg.save_html(report, str(out / "coverage_report.html"))
        _print_terminal_summary(report, out, project)

    logger.info(f"\n✅ Done. Output: {out.resolve()}")
# ── 终端摘要 ──────────────────────────────────────────────────
def _print_terminal_summary(
    report: "CoverageReport", out: Path, project: str
):
    """Print a human-readable coverage summary to stdout.

    Shows one progress bar per coverage metric of *report*, the gap totals,
    up to eight gaps with their suggestions, and the report files expected
    in *out*.

    Args:
        report: Result of ``CoverageAnalyzer.analyze()``. Only the rate
            attributes, ``gaps`` (items with ``severity``, ``gap_type``,
            ``target``, ``suggestion``), ``critical_gap_count`` and
            ``high_gap_count`` are read.
        out: Output directory; printed in resolved (absolute) form.
        project: Project name shown in the header.
    """
    width = 66
    heavy_rule = "═" * width
    light_rule = "─" * width

    def bar(rate: float, w: int = 20) -> str:
        """Render *rate* (nominally 0..1) as a *w*-cell bar plus percentage."""
        # Clamp only the drawn bar so out-of-range rates cannot overflow
        # (or, when negative, erase) it; the printed percentage stays exact.
        shown = min(max(rate, 0.0), 1.0)
        filled = int(shown * w)
        empty = w - filled
        icon = "✅" if rate >= 0.8 else ("⚠️ " if rate >= 0.5 else "❌")
        return f"{'█' * filled}{'░' * empty} {rate * 100:.1f}% {icon}"

    print(f"\n{heavy_rule}")
    print(f" PROJECT : {project}")
    print(" COVERAGE SUMMARY")
    print(heavy_rule)
    print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}")
    print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}")
    print(f" 入参覆盖率 {bar(report.avg_in_param_coverage_rate)}")
    print(f" 成功返回字段覆盖 {bar(report.avg_success_field_coverage_rate)}")
    print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}")
    print(f" 用例通过率 {bar(report.pass_rate)}")
    print(light_rule)
    print(f" Total Gaps : {len(report.gaps)}")
    print(f" 🔴 Critical: {report.critical_gap_count}")
    print(f" 🟠 High : {report.high_gap_count}")
    if report.gaps:
        print(light_rule)
        print(" Top Gaps (up to 8):")
        icons = {
            "critical": "🔴", "high": "🟠",
            "medium": "🟡", "low": "🔵",
        }
        for g in report.gaps[:8]:
            # Unknown severities get a neutral marker instead of nothing.
            icon = icons.get(g.severity, "⚪")
            print(f" {icon} [{g.gap_type}] {g.target}")
            # Indent the suggestion under its gap line.
            print(f"     → {g.suggestion}")
    print(light_rule)
    print(f" Output : {out.resolve()}")
    print(" • coverage_report.html ← open in browser")
    print(" • coverage_report.json")
    print(" • run_results.json")
    print(" • test_cases_summary.json")
    print(f"{heavy_rule}\n")
if __name__ == "__main__":
    # Run the CLI pipeline only when executed as a script, never on import.
    main()