AIDeveloper-PC/ai_test_generator/main.py

289 lines
13 KiB
Python
Raw Normal View History

2026-03-04 18:09:45 +00:00
#!/usr/bin/env python3
"""
AI-Powered API Test Generator, Runner & Coverage Analyzer
Usage:
2026-03-09 03:48:19 +00:00
# 基础用法
2026-03-04 18:09:45 +00:00
python main.py --api-desc examples/api_desc.json \\
--requirements "创建用户,删除用户"
2026-03-09 03:48:19 +00:00
# 在需求中自然语言描述参数生成要求
2026-03-04 18:09:45 +00:00
python main.py --api-desc examples/api_desc.json \\
2026-03-09 03:48:19 +00:00
--requirements "创建用户测试参数不少于10组覆盖边界值和异常值,删除用户"
# 从文件读取需求(每行一条,支持参数生成指令)
python main.py --api-desc examples/api_desc.json \\
--req-file examples/requirements.txt \\
--batch-size 10
2026-03-04 18:09:45 +00:00
"""
import argparse
import logging
import sys
from pathlib import Path
from config import config
from core.parser import InterfaceParser, ApiDescriptor
2026-03-09 03:48:19 +00:00
from core.requirement_parser import RequirementParser, ParamConstraint
2026-03-04 18:09:45 +00:00
from core.prompt_builder import PromptBuilder
from core.llm_client import LLMClient
from core.test_generator import TestGenerator
from core.test_runner import TestRunner
from core.analyzer import CoverageAnalyzer, CoverageReport
from core.report_generator import ReportGenerator
2026-03-09 03:48:19 +00:00
# ══════════════════════════════════════════════════════════════
# Logging setup
# ══════════════════════════════════════════════════════════════

# Layout of every log record. The ": " between logger name and message keeps
# lines such as "core.parser: ..." readable (without it the two run together).
LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"

# Third-party SDK / HTTP-client loggers that are always pinned to WARNING.
_NOISY_LOGGERS = ("openai", "openai._base_client", "anthropic", "httpx", "httpcore")


def setup_logging(debug: bool = False) -> None:
    """Configure root logging for the CLI, writing to stdout.

    Args:
        debug: When True the root level is DEBUG, otherwise INFO.

    The loggers in ``_NOISY_LOGGERS`` are forced to WARNING even in debug
    mode. This works around a bug in openai-python v2.24.0 + pydantic-core:
    their debug logging calls ``model_dump(by_alias=None)``, which makes
    pydantic-core's Rust layer raise a TypeError.
    Ref: https://github.com/openai/openai-python/issues/2921

    Note: ``logging.basicConfig`` is a no-op when the root logger already has
    handlers, so only the first configuration of the process takes effect.
    """
    logging.basicConfig(
        level=logging.DEBUG if debug else logging.INFO,
        format=LOG_FORMAT,
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    for name in _NOISY_LOGGERS:
        logging.getLogger(name).setLevel(logging.WARNING)


logger = logging.getLogger(__name__)
2026-03-09 03:48:19 +00:00
# ══════════════════════════════════════════════════════════════
# CLI
# ══════════════════════════════════════════════════════════════

def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the test-generation pipeline.

    Only ``--api-desc`` is required. ``--batch-size`` / ``--workers`` default
    to 0 and ``--output-dir`` to "", which the caller treats as "keep the
    value already in ``config``".
    """
    # Raw formatter keeps the epilog's line breaks and indentation verbatim.
    epilog = (
        "\n"
        "参数生成指令示例(可直接写在需求文本中):\n"
        '  "创建用户测试参数不少于10组覆盖边界值和异常值"\n'
        '  "查询接口至少8组边界值"\n'
        '  "所有接口参数不少于5组覆盖等价类"\n'
    )
    p = argparse.ArgumentParser(
        description="AI-Powered API Test Generator & Coverage Analyzer",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog,
    )
    p.add_argument("--api-desc", required=True, help="接口描述 JSON 文件")
    p.add_argument("--requirements", default="", help="测试需求(逗号分隔,支持参数生成指令)")
    p.add_argument("--req-file", default="", help="测试需求文件(每行一条)")
    p.add_argument("--batch-size", type=int, default=0,
                   help="每批接口数量(0=使用 config.LLM_BATCH_SIZE)")
    p.add_argument("--workers", type=int, default=0,
                   help="并行执行线程数(0=使用 config.TEST_MAX_WORKERS)")
    p.add_argument("--skip-run", action="store_true", help="只生成,不执行")
    p.add_argument("--skip-analyze", action="store_true", help="跳过覆盖率分析")
    p.add_argument("--output-dir", default="", help="测试文件根输出目录")
    p.add_argument("--debug", action="store_true",
                   help="开启 DEBUG 日志(第三方 SDK 仍保持 WARNING)")
    return p
def load_requirements(args) -> list[str]:
    """Collect test requirements from ``--requirements`` and ``--req-file``.

    ``args.requirements`` is split on ASCII commas only, so one requirement may
    itself contain full-width commas. ``args.req_file`` contributes one
    requirement per non-blank line. Inline requirements come first, then the
    file's, each in original order; surrounding whitespace is stripped.

    Exits the process with status 1 (after logging an error) when the file
    cannot be read or when no requirement was supplied at all.
    """
    reqs: list[str] = []
    if args.requirements:
        reqs += [r.strip() for r in args.requirements.split(",") if r.strip()]
    if args.req_file:
        try:
            # "utf-8-sig" drops a leading BOM (as written by some Windows
            # editors). With plain "utf-8" the BOM (U+FEFF) survives
            # str.strip() and ends up glued to the first requirement.
            with open(args.req_file, "r", encoding="utf-8-sig") as f:
                reqs += [line.strip() for line in f if line.strip()]
        except OSError as exc:
            # Same log-and-exit convention as the "no requirements" case below,
            # instead of an unhandled traceback.
            logger.error("无法读取需求文件 %s: %s", args.req_file, exc)
            sys.exit(1)
    if not reqs:
        logger.error("未提供任何测试需求,请使用 --requirements 或 --req-file 指定。")
        sys.exit(1)
    return reqs
2026-03-09 03:48:19 +00:00
# ══════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════

def _log_interfaces(interfaces) -> None:
    """Log one summary line per parsed interface.

    Interfaces whose ``protocol`` is "function" show their source file and
    module path; every other interface is logged as HTTP with its full URL.
    """
    for iface in interfaces:
        if iface.protocol == "function":
            logger.info(
                f" [function] {iface.name} "
                f"→ {iface.source_file} (module: {iface.module_path})"
            )
        else:
            # Explicit separator so the name and URL do not run together.
            logger.info(f" [http] {iface.name} → {iface.http_full_url}")


def _log_requirements(raw_requirements, clean_reqs, per_req_constraints, global_constraint) -> None:
    """Log each raw requirement, plus its parsed directive/business text when it has one."""
    for i, (raw, clean, c) in enumerate(
        zip(raw_requirements, clean_reqs, per_req_constraints), 1
    ):
        if c.has_param_directive:
            logger.info(
                f" {i}. {raw}\n"
                f" └─ 参数约束:{c} | 业务需求:{clean}"
            )
        else:
            logger.info(f" {i}. {raw}")
    if global_constraint.has_param_directive:
        logger.info(f" 全局参数约束:{global_constraint}")


def main():
    """Run the full pipeline.

    Steps: parse the API description, then parse the requirements, then
    generate test cases with the LLM, then write the test files. After that,
    unless skipped: run the tests (``--skip-run``), then do coverage analysis
    and write the JSON/HTML reports (``--skip-analyze``).
    """
    args = build_arg_parser().parse_args()
    setup_logging(debug=args.debug)

    # CLI overrides: "" / 0 keep whatever value ``config`` already holds.
    if args.output_dir:
        config.GENERATED_TESTS_DIR = args.output_dir
    if args.batch_size > 0:
        config.LLM_BATCH_SIZE = args.batch_size
    if args.workers > 0:
        config.TEST_MAX_WORKERS = args.workers

    # ── Step 1: parse the API description ────────────────────
    logger.info("▶ Step 1: 解析接口描述 …")
    parser = InterfaceParser()
    descriptor = parser.parse_file(args.api_desc)
    project = descriptor.project
    project_desc = descriptor.description
    interfaces = descriptor.interfaces
    logger.info(f" 项目名称:{project}")
    logger.info(f" 项目描述:{project_desc}")
    logger.info(f" 接口数量:{len(interfaces)}")
    _log_interfaces(interfaces)

    # ── Step 2: load and parse requirements (incl. param directives) ──
    logger.info("▶ Step 2: 加载并解析测试需求 …")
    raw_requirements = load_requirements(args)
    req_parser = RequirementParser()
    clean_reqs, per_req_constraints, global_constraint = req_parser.parse(
        raw_requirements
    )
    _log_requirements(raw_requirements, clean_reqs, per_req_constraints, global_constraint)
    # The LLM receives the cleaned requirements (parameter directives removed,
    # pure business description); coverage analysis in Step 6 uses
    # raw_requirements so it matches the user's original input.
    # NOTE(review): per_req_constraints are only logged here; only
    # global_constraint is forwarded to the prompt/LLM. Confirm that
    # RequirementParser.parse folds per-requirement directives into
    # global_constraint, otherwise those directives are dropped.
    effective_requirements = clean_reqs

    # ── Step 3: build prompts and ask the LLM for test cases ─
    logger.info(
        f"▶ Step 3: 调用 LLM 生成测试用例 "
        f"(batch_size={getattr(config, 'LLM_BATCH_SIZE', 10)}) …"
    )
    builder = PromptBuilder()
    iface_summary = parser.to_summary_dict(interfaces)
    project_header = builder.build_project_header(project, project_desc)
    # The system prompt carries the global parameter-constraint rules.
    system_prompt = builder.get_system_prompt(global_constraint=global_constraint)
    test_cases = LLMClient().generate_test_cases(
        system_prompt=system_prompt,
        user_prompt="",
        iface_summaries=iface_summary,
        requirements=effective_requirements,
        project_header=project_header,
        param_constraint=global_constraint,
    )
    logger.info(f" 共生成测试用例:{len(test_cases)}")

    # ── Step 4: write test files ─────────────────────────────
    logger.info("▶ Step 4: 生成测试文件 …")
    generator = TestGenerator(project=project, project_desc=project_desc)
    test_files = generator.generate(test_cases)
    out = generator.output_dir
    logger.info(f" 输出目录:{out.resolve()}")

    run_results = []
    if not args.skip_run:
        # ── Step 5: execute tests ────────────────────────────
        logger.info(
            f"▶ Step 5: 执行测试 "
            f"(workers={getattr(config, 'TEST_MAX_WORKERS', 8)}) …"
        )
        runner = TestRunner()
        run_results = runner.run_all(test_files)
        runner.print_summary(run_results)
        runner.save_results(run_results, str(out / "run_results.json"))

    if not args.skip_analyze:
        # ── Step 6: coverage analysis (uses the raw requirements) ──
        logger.info("▶ Step 6: 覆盖率分析 …")
        report = CoverageAnalyzer(
            interfaces=interfaces,
            requirements=raw_requirements,
            test_cases=test_cases,
            run_results=run_results,
        ).analyze()

        # ── Step 7: reports ──────────────────────────────────
        logger.info("▶ Step 7: 生成报告 …")
        rg = ReportGenerator()
        rg.save_json(report, str(out / "coverage_report.json"))
        rg.save_html(report, str(out / "coverage_report.html"))
        _print_terminal_summary(report, out, project, global_constraint)

    logger.info(f"\n✅ 完成。输出目录:{out.resolve()}")
2026-03-04 18:09:45 +00:00
2026-03-09 03:48:19 +00:00
# ══════════════════════════════════════════════════════════════
# 终端摘要
# ══════════════════════════════════════════════════════════════
2026-03-04 18:09:45 +00:00
def _print_terminal_summary(
2026-03-09 03:48:19 +00:00
report: "CoverageReport",
out: Path,
project: str,
global_constraint: ParamConstraint,
2026-03-04 18:09:45 +00:00
):
2026-03-09 03:48:19 +00:00
W = 68
2026-03-04 18:09:45 +00:00
def bar(rate: float, w: int = 20) -> str:
filled = int(rate * w)
empty = w - filled
icon = "" if rate >= 0.8 else ("⚠️ " if rate >= 0.5 else "")
return f"{'' * filled}{'' * empty} {rate * 100:.1f}% {icon}"
print(f"\n{'' * W}")
2026-03-09 03:48:19 +00:00
print(f" 项目:{project}")
print(f" 覆盖率摘要")
2026-03-04 18:09:45 +00:00
print(f"{'' * W}")
2026-03-09 03:48:19 +00:00
print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}")
print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}")
print(f" 入参覆盖率 {bar(report.avg_in_param_coverage_rate)}")
print(f" 成功返回字段覆盖 {bar(report.avg_success_field_coverage_rate)}")
print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}")
print(f" 用例通过率 {bar(report.pass_rate)}")
2026-03-04 18:09:45 +00:00
print(f"{'' * W}")
2026-03-09 03:48:19 +00:00
print(f" 测试用例总数 {report.total_test_cases}")
print(f" 覆盖缺口总数 {len(report.gaps)}")
print(f" 🔴 严重缺口 {report.critical_gap_count}")
print(f" 🟠 高优先级缺口 {report.high_gap_count}")
# 参数约束达成情况
if global_constraint.has_param_directive:
print(f"{'' * W}")
print(f" 参数生成约束 {global_constraint}")
actual = report.total_test_cases
needed = global_constraint.min_groups
status = "✅ 已满足" if actual >= needed else f"❌ 未满足(需 {needed} 组,实际 {actual} 组)"
print(f" 参数组数要求 {status}")
2026-03-04 18:09:45 +00:00
if report.gaps:
print(f"{'' * W}")
2026-03-09 03:48:19 +00:00
print(" Top 缺口最多显示8条")
icons = {"critical": "🔴", "high": "🟠", "medium": "🟡", "low": "🔵"}
2026-03-04 18:09:45 +00:00
for g in report.gaps[:8]:
2026-03-09 03:48:19 +00:00
print(f" {icons.get(g.severity, '')} [{g.gap_type}] {g.target}")
2026-03-04 18:09:45 +00:00
print(f"{g.suggestion}")
print(f"{'' * W}")
2026-03-09 03:48:19 +00:00
print(f" 输出目录:{out.resolve()}")
print(f" • coverage_report.html")
2026-03-04 18:09:45 +00:00
print(f" • coverage_report.json")
print(f" • run_results.json")
print(f" • test_cases_summary.json")
print(f"{'' * W}\n")
if __name__ == "__main__":
main()