AIDeveloper-PC/ai_test_generator/main.py

289 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
AI-Powered API Test Generator, Runner & Coverage Analyzer
──────────────────────────────────────────────────────────
Usage:
    # Basic usage
    python main.py --api-desc examples/api_desc.json \\
        --requirements "创建用户,删除用户"

    # Describe parameter-generation constraints in natural language inside a requirement
    python main.py --api-desc examples/api_desc.json \\
        --requirements "创建用户测试参数不少于10组覆盖边界值和异常值,删除用户"

    # Read requirements from a file (one per line; parameter directives are supported)
    python main.py --api-desc examples/api_desc.json \\
        --req-file examples/requirements.txt \\
        --batch-size 10
"""
import argparse
import logging
import sys
from pathlib import Path
from config import config
from core.parser import InterfaceParser, ApiDescriptor
from core.requirement_parser import RequirementParser, ParamConstraint
from core.prompt_builder import PromptBuilder
from core.llm_client import LLMClient
from core.test_generator import TestGenerator
from core.test_runner import TestRunner
from core.analyzer import CoverageAnalyzer, CoverageReport
from core.report_generator import ReportGenerator
# ══════════════════════════════════════════════════════════════
# 日志初始化
# ══════════════════════════════════════════════════════════════
def setup_logging(debug: bool = False) -> None:
    """Configure root logging and pin noisy third-party SDK loggers.

    The root logger is set to DEBUG when ``debug`` is true, otherwise INFO,
    and writes to stdout.

    Workaround (openai-python v2.24.0 + pydantic-core): the third-party SDK
    loggers are forced to WARNING regardless of ``debug`` to avoid triggering
    ``model_dump(by_alias=None)``, which raises a TypeError in the
    pydantic-core Rust layer.
    Ref: https://github.com/openai/openai-python/issues/2921

    Args:
        debug: Enable DEBUG output for the application's own loggers.
    """
    logging.basicConfig(
        level=logging.DEBUG if debug else logging.INFO,
        # Keep an explicit separator between logger name and message;
        # without it the two fields run together in every log line.
        format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # These loggers stay at WARNING even in debug mode (see docstring).
    for sdk_logger in ("openai", "openai._base_client", "anthropic", "httpx", "httpcore"):
        logging.getLogger(sdk_logger).setLevel(logging.WARNING)
# Module-level logger used by load_requirements() and main() in this file.
logger = logging.getLogger(__name__)
# ══════════════════════════════════════════════════════════════
# CLI
# ══════════════════════════════════════════════════════════════
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the test generator.

    Options are declared in a table and registered in order, so the help
    output lists them in the same sequence as the table.

    Returns:
        A parser with ``--api-desc`` (required) plus optional requirement
        sources, batching/worker overrides, skip switches, output directory
        and a debug flag.
    """
    directive_examples = """
参数生成指令示例(可直接写在需求文本中):
"创建用户测试参数不少于10组覆盖边界值和异常值"
"查询接口至少8组边界值"
"所有接口参数不少于5组覆盖等价类"
"""
    cli = argparse.ArgumentParser(
        description="AI-Powered API Test Generator & Coverage Analyzer",
        # Raw formatter keeps the line breaks of the epilog examples.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=directive_examples,
    )

    # (flag, add_argument keyword options), in registration order.
    option_table = (
        ("--api-desc", {"required": True, "help": "接口描述 JSON 文件"}),
        ("--requirements", {"default": "", "help": "测试需求(逗号分隔,支持参数生成指令)"}),
        ("--req-file", {"default": "", "help": "测试需求文件(每行一条)"}),
        ("--batch-size", {"type": int, "default": 0, "help": "每批接口数量0=使用 config.LLM_BATCH_SIZE"}),
        ("--workers", {"type": int, "default": 0, "help": "并行执行线程数0=使用 config.TEST_MAX_WORKERS"}),
        ("--skip-run", {"action": "store_true", "help": "只生成,不执行"}),
        ("--skip-analyze", {"action": "store_true", "help": "跳过覆盖率分析"}),
        ("--output-dir", {"default": "", "help": "测试文件根输出目录"}),
        ("--debug", {"action": "store_true", "help": "开启 DEBUG 日志(第三方 SDK 仍保持 WARNING"}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli
def load_requirements(args) -> list[str]:
    """Collect test requirements from ``--requirements`` and ``--req-file``.

    Inline requirements are split on ASCII commas; file requirements are read
    one per line. Every entry is stripped of surrounding whitespace and empty
    entries are dropped. Inline entries come first, followed by file entries.

    Args:
        args: Parsed CLI namespace providing ``requirements`` and ``req_file``.

    Returns:
        The non-empty list of requirement strings.

    Raises:
        SystemExit: With status 1 (after logging an error) when neither
            source yields a requirement.
    """
    reqs: list[str] = []
    if args.requirements:
        inline = (part.strip() for part in args.requirements.split(","))
        reqs.extend(part for part in inline if part)
    if args.req_file:
        # "utf-8-sig" reads plain UTF-8 unchanged but also drops a leading
        # BOM; with plain "utf-8" the BOM survives str.strip() and ends up
        # glued to the first requirement.
        with open(args.req_file, "r", encoding="utf-8-sig") as f:
            lines = (line.strip() for line in f)
            reqs.extend(line for line in lines if line)
    if not reqs:
        logger.error("未提供任何测试需求,请使用 --requirements 或 --req-file 指定。")
        sys.exit(1)
    return reqs
# ══════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════
def main():
    """Run the full pipeline: parse, generate, execute, analyze, report.

    Steps, in order: parse the interface description, load and parse the
    test requirements, ask the LLM for test cases, write the test files,
    optionally run them (skipped with ``--skip-run``), and optionally
    analyze coverage and write reports (skipped with ``--skip-analyze``).
    """
    cli_args = build_arg_parser().parse_args()
    setup_logging(debug=cli_args.debug)

    # CLI overrides take effect only when explicitly provided.
    if cli_args.output_dir:
        config.GENERATED_TESTS_DIR = cli_args.output_dir
    if cli_args.batch_size > 0:
        config.LLM_BATCH_SIZE = cli_args.batch_size
    if cli_args.workers > 0:
        config.TEST_MAX_WORKERS = cli_args.workers

    # -- Step 1: parse the interface description ---------------------
    logger.info("▶ Step 1: 解析接口描述 …")
    iface_parser = InterfaceParser()
    api_descriptor = iface_parser.parse_file(cli_args.api_desc)
    project_name = api_descriptor.project
    project_summary = api_descriptor.description
    iface_list = api_descriptor.interfaces
    logger.info(f" 项目名称:{project_name}")
    logger.info(f" 项目描述:{project_summary}")
    logger.info(f" 接口数量:{len(iface_list)}")
    for entry in iface_list:
        if iface_is_function := entry.protocol == "function":
            line = (
                f" [function] {entry.name} "
                f"{entry.source_file} (module: {entry.module_path})"
            )
        else:
            line = f" [http] {entry.name}{entry.http_full_url}"
        logger.info(line)

    # -- Step 2: load and parse requirements (incl. param directives) --
    logger.info("▶ Step 2: 加载并解析测试需求 …")
    raw_requirements = load_requirements(cli_args)
    parsed = RequirementParser().parse(raw_requirements)
    clean_reqs, per_req_constraints, global_constraint = parsed

    # Echo each raw requirement together with what was parsed out of it.
    numbered = enumerate(
        zip(raw_requirements, clean_reqs, per_req_constraints), start=1
    )
    for idx, (raw_req, clean_req, constraint) in numbered:
        if not constraint.has_param_directive:
            logger.info(f" {idx}. {raw_req}")
            continue
        logger.info(
            f" {idx}. {raw_req}\n"
            f" └─ 参数约束:{constraint} | 业务需求:{clean_req}"
        )
    if global_constraint.has_param_directive:
        logger.info(f" 全局参数约束:{global_constraint}")

    # -- Step 3: build the prompt and ask the LLM for test cases -------
    batch_size = getattr(config, "LLM_BATCH_SIZE", 10)
    logger.info(
        f"▶ Step 3: 调用 LLM 生成测试用例 "
        f"(batch_size={batch_size}) …"
    )
    prompt_builder = PromptBuilder()
    iface_summary = iface_parser.to_summary_dict(iface_list)
    project_header = prompt_builder.build_project_header(
        project_name, project_summary
    )
    # The system prompt carries the global parameter-constraint rules.
    system_prompt = prompt_builder.get_system_prompt(
        global_constraint=global_constraint
    )
    test_cases = LLMClient().generate_test_cases(
        system_prompt=system_prompt,
        user_prompt="",
        iface_summaries=iface_summary,
        # The LLM receives the cleaned requirements (parameter directives
        # removed); the raw ones are kept for coverage analysis below.
        requirements=clean_reqs,
        project_header=project_header,
        param_constraint=global_constraint,
    )
    logger.info(f" 共生成测试用例:{len(test_cases)}")

    # -- Step 4: write the test files ----------------------------------
    logger.info("▶ Step 4: 生成测试文件 …")
    generator = TestGenerator(project=project_name, project_desc=project_summary)
    test_files = generator.generate(test_cases)
    out = generator.output_dir
    logger.info(f" 输出目录:{out.resolve()}")

    run_results = []
    if not cli_args.skip_run:
        # -- Step 5: execute the tests ---------------------------------
        worker_count = getattr(config, "TEST_MAX_WORKERS", 8)
        logger.info(
            f"▶ Step 5: 执行测试 "
            f"(workers={worker_count}) …"
        )
        runner = TestRunner()
        run_results = runner.run_all(test_files)
        runner.print_summary(run_results)
        runner.save_results(run_results, str(out / "run_results.json"))

    if not cli_args.skip_analyze:
        # -- Step 6: coverage analysis ---------------------------------
        logger.info("▶ Step 6: 覆盖率分析 …")
        analyzer = CoverageAnalyzer(
            interfaces=iface_list,
            # Analysis uses the raw requirements, exactly as the user gave them.
            requirements=raw_requirements,
            test_cases=test_cases,
            run_results=run_results,
        )
        report = analyzer.analyze()

        # -- Step 7: write reports -------------------------------------
        logger.info("▶ Step 7: 生成报告 …")
        reporter = ReportGenerator()
        reporter.save_json(report, str(out / "coverage_report.json"))
        reporter.save_html(report, str(out / "coverage_report.html"))
        _print_terminal_summary(report, out, project_name, global_constraint)

    logger.info(f"\n✅ 完成。输出目录:{out.resolve()}")
# ══════════════════════════════════════════════════════════════
# 终端摘要
# ══════════════════════════════════════════════════════════════
def _print_terminal_summary(
report: "CoverageReport",
out: Path,
project: str,
global_constraint: ParamConstraint,
):
W = 68
def bar(rate: float, w: int = 20) -> str:
filled = int(rate * w)
empty = w - filled
icon = "" if rate >= 0.8 else ("⚠️ " if rate >= 0.5 else "")
return f"{'' * filled}{'' * empty} {rate * 100:.1f}% {icon}"
print(f"\n{'' * W}")
print(f" 项目:{project}")
print(f" 覆盖率摘要")
print(f"{'' * W}")
print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}")
print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}")
print(f" 入参覆盖率 {bar(report.avg_in_param_coverage_rate)}")
print(f" 成功返回字段覆盖 {bar(report.avg_success_field_coverage_rate)}")
print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}")
print(f" 用例通过率 {bar(report.pass_rate)}")
print(f"{'' * W}")
print(f" 测试用例总数 {report.total_test_cases}")
print(f" 覆盖缺口总数 {len(report.gaps)}")
print(f" 🔴 严重缺口 {report.critical_gap_count}")
print(f" 🟠 高优先级缺口 {report.high_gap_count}")
# 参数约束达成情况
if global_constraint.has_param_directive:
print(f"{'' * W}")
print(f" 参数生成约束 {global_constraint}")
actual = report.total_test_cases
needed = global_constraint.min_groups
status = "✅ 已满足" if actual >= needed else f"❌ 未满足(需 {needed} 组,实际 {actual} 组)"
print(f" 参数组数要求 {status}")
if report.gaps:
print(f"{'' * W}")
print(" Top 缺口最多显示8条")
icons = {"critical": "🔴", "high": "🟠", "medium": "🟡", "low": "🔵"}
for g in report.gaps[:8]:
print(f" {icons.get(g.severity, '')} [{g.gap_type}] {g.target}")
print(f"{g.suggestion}")
print(f"{'' * W}")
print(f" 输出目录:{out.resolve()}")
print(f" • coverage_report.html")
print(f" • coverage_report.json")
print(f" • run_results.json")
print(f" • test_cases_summary.json")
print(f"{'' * W}\n")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()