diff --git a/ai_test_generator/core/llm_client.py b/ai_test_generator/core/llm_client.py
index a74f090..d2030a2 100644
--- a/ai_test_generator/core/llm_client.py
+++ b/ai_test_generator/core/llm_client.py
@@ -1,58 +1,204 @@
+"""
+LLM 客户端
+══════════════════════════════════════════════════════════════
+变更:
+ - generate_test_cases() 新增 param_constraint 参数
+ - 分批调用时将约束透传给 PromptBuilder.build_batch_prompt()
+"""
+
+from __future__ import annotations
import json
-import re
import logging
+import re
+import time
+from typing import Any
+
from openai import OpenAI
from config import config
logger = logging.getLogger(__name__)
-class LLMClient:
- """封装 LLM API 调用(OpenAI 兼容接口)"""
+# ══════════════════════════════════════════════════════════════
+# JSON 健壮解析工具
+# ══════════════════════════════════════════════════════════════
- def __init__(self):
- self.client = OpenAI(
- api_key=config.LLM_API_KEY,
- base_url=config.LLM_BASE_URL,
- )
- def generate_test_cases(self, system_prompt: str, user_prompt: str) -> list[dict]:
- logger.info(f"Calling LLM: model={config.LLM_MODEL}")
- logger.debug(f"--- USER PROMPT ---\n{user_prompt}\n---")
- response = self.client.chat.completions.create(
- model=config.LLM_MODEL,
- temperature=config.LLM_TEMPERATURE,
- max_tokens=config.LLM_MAX_TOKENS,
- messages=[
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": user_prompt},
- ],
- )
- raw = response.choices[0].message.content
- logger.debug(f"--- LLM RAW RESPONSE ---\n{raw[:800]}\n---")
- return self._parse_json(raw)
-
- # ── 解析 ──────────────────────────────────────────────────
-
- def _parse_json(self, content: str) -> list[dict]:
- content = content.strip()
- # 去除可能的 markdown 代码块
- content = re.sub(r"^```(?:json)?\s*", "", content, flags=re.MULTILINE)
- content = re.sub(r"\s*```\s*$", "", content, flags=re.MULTILINE)
- content = content.strip()
- try:
- data = json.loads(content)
- except json.JSONDecodeError:
- # 尝试提取第一个 JSON 数组
- match = re.search(r"\[.*\]", content, re.DOTALL)
- if match:
- data = json.loads(match.group())
- else:
- logger.error(f"Cannot parse LLM response as JSON:\n{content[:600]}")
- raise
- if not isinstance(data, list):
- raise ValueError(f"Expected JSON array, got {type(data)}")
- return data
\ No newline at end of file
+class RobustJSONParser:
+    """Best-effort extraction of a JSON array of test cases from raw LLM output.
+
+    Malformed input never raises: the parser returns an empty list instead so
+    that the caller can decide whether to retry.
+    """
+
+    def parse(self, text: str) -> list[dict]:
+        """Return the JSON array embedded in *text*, or ``[]`` if none is recoverable.
+
+        Markdown code fences are stripped, everything before the first ``[``
+        is discarded, and the array is decoded. If decoding fails the text is
+        passed through ``_fix_truncated`` and decoded once more.
+        """
+        text = text.strip()
+        text = re.sub(r'^```(?:json)?\s*', '', text, flags=re.MULTILINE)
+        text = re.sub(r'\s*```$', '', text, flags=re.MULTILINE)
+        text = text.strip()
+
+        start = text.find('[')
+        if start == -1:
+            logger.warning("LLM 响应中未找到 JSON 数组。")
+            return []
+        text = text[start:]
+
+        try:
+            # raw_decode stops at the end of the first JSON value, so prose
+            # that follows a complete array no longer fails with "Extra data".
+            result, _ = json.JSONDecoder().raw_decode(text)
+        except json.JSONDecodeError:
+            pass
+        else:
+            if isinstance(result, list):
+                return result
+
+        text = self._fix_truncated(text)
+        try:
+            result = json.loads(text)
+        except json.JSONDecodeError as e:
+            logger.error(f"JSON 解析失败(修复后仍无效):{e}")
+            return []
+        if isinstance(result, list):
+            logger.warning("JSON 截断修复成功,部分用例可能丢失。")
+            return result
+        return []
+
+    def _fix_truncated(self, text: str) -> str:
+        """Rebuild a valid JSON array from the complete objects at the start of *text*.
+
+        Top-level elements are decoded one at a time and scanning stops at the
+        first element that cannot be decoded, so a cut inside a nested list
+        only drops the unfinished test case. Only object elements are kept:
+        an object has an explicit closing brace, whereas a scalar at the very
+        end of the text may itself be cut short.
+
+        Returns *text* unchanged when no complete object can be recovered.
+        """
+        pos = text.find('[')
+        if pos == -1:
+            return text
+        pos += 1
+
+        decoder = json.JSONDecoder()
+        end = len(text)
+        complete: list[dict] = []
+        while pos < end:
+            # Skip separators between elements.
+            while pos < end and text[pos] in ' \t\r\n,':
+                pos += 1
+            if pos >= end or text[pos] == ']':
+                break
+            try:
+                item, pos = decoder.raw_decode(text, pos)
+            except json.JSONDecodeError:
+                break
+            if isinstance(item, dict):
+                complete.append(item)
+
+        if not complete:
+            return text
+        return json.dumps(complete, ensure_ascii=False)
+
+
+# ══════════════════════════════════════════════════════════════
+# LLM 客户端
+# ══════════════════════════════════════════════════════════════
+
+class LLMClient:
+    """Client for an OpenAI-compatible chat-completions API that returns test cases.
+
+    Two modes are offered by ``generate_test_cases``: a single request built
+    from a ready-made user prompt, and a batched mode in which the interface
+    list is split into chunks and one request is sent per chunk.
+    """
+
+    DEFAULT_BATCH_SIZE = 10
+
+    def __init__(self):
+        """Create the API client and read the tuning knobs from ``config``."""
+        self.client = OpenAI(
+            api_key=config.LLM_API_KEY,
+            base_url=getattr(config, "LLM_BASE_URL", None),
+        )
+        self.model = config.LLM_MODEL
+        # Optional: when the config does not define a temperature the request
+        # omits the argument and the server-side default applies.
+        self.temperature = getattr(config, "LLM_TEMPERATURE", None)
+        self.max_tokens = getattr(config, "LLM_MAX_TOKENS", 8192)
+        self.batch_size = getattr(config, "LLM_BATCH_SIZE", self.DEFAULT_BATCH_SIZE)
+        # ``retry`` is the total number of attempts; always try at least once.
+        self.retry = max(1, getattr(config, "LLM_RETRY", 3))
+        self.retry_delay = getattr(config, "LLM_RETRY_DELAY", 5)
+        self.batch_interval = getattr(config, "LLM_BATCH_INTERVAL", 1)
+        self._parser = RobustJSONParser()
+
+        # Imported here rather than at module level.
+        # NOTE(review): presumably to avoid an import cycle — confirm.
+        from core.prompt_builder import PromptBuilder
+        self._prompt_builder = PromptBuilder()
+
+    # ── Main entry point ──────────────────────────────────────
+
+    def generate_test_cases(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        iface_summaries: list[dict] | None = None,
+        requirements: list[str] | None = None,
+        project_header: str = "",
+        param_constraint: "ParamConstraint | None" = None,
+    ) -> list[dict]:
+        """Generate test cases with the LLM.
+
+        Batched mode is used when both ``iface_summaries`` and
+        ``requirements`` are given; ``user_prompt`` is then ignored and a
+        prompt is built per batch. Otherwise a single request is sent with
+        ``user_prompt`` as-is.
+
+        Args:
+            system_prompt: System message sent with every request.
+            user_prompt: User message for single-call mode.
+            iface_summaries: Interface descriptions to split into batches.
+            requirements: Requirement texts included in every batch prompt.
+            project_header: Optional text placed at the top of each batch prompt.
+            param_constraint: Parameter constraint forwarded to the prompt
+                builder; only used in batched mode.
+
+        Returns:
+            The parsed test cases, or an empty list if every attempt failed.
+        """
+        if iface_summaries is not None and requirements is not None:
+            return self._generate_batched(
+                system_prompt, iface_summaries, requirements,
+                project_header, param_constraint,
+            )
+        return self._call_with_retry(system_prompt, user_prompt)
+
+    # ── Batched calls ─────────────────────────────────────────
+
+    def _generate_batched(
+        self,
+        system_prompt: str,
+        iface_summaries: list[dict],
+        requirements: list[str],
+        project_header: str,
+        param_constraint: Any,
+    ) -> list[dict]:
+        """Send one request per batch of interfaces and concatenate the results.
+
+        A batch that yields no cases contributes nothing; the remaining
+        batches are still processed.
+        """
+        total = len(iface_summaries)
+        batches = self._make_batches(iface_summaries, self.batch_size)
+        all_cases: list[dict] = []
+
+        logger.info(
+            f"分批模式:共 {total} 个接口 → "
+            f"{len(batches)} 批 × 每批最多 {self.batch_size} 个"
+        )
+        if param_constraint and param_constraint.has_param_directive:
+            logger.info(
+                f"参数约束已启用:{param_constraint}"
+            )
+
+        for idx, batch in enumerate(batches, 1):
+            names = [i.get("name", "?") for i in batch]
+            logger.info(f" 第 {idx}/{len(batches)} 批:{names}")
+
+            user_prompt = self._prompt_builder.build_batch_prompt(
+                batch=batch,
+                requirements=requirements,
+                project_header=project_header,
+                param_constraint=param_constraint,
+            )
+
+            cases = self._call_with_retry(system_prompt, user_prompt)
+            logger.info(f" 第 {idx} 批 → 生成 {len(cases)} 个测试用例")
+            all_cases.extend(cases)
+
+            # Pause between batches, but not after the last one.
+            if idx < len(batches):
+                time.sleep(self.batch_interval)
+
+        logger.info(f"全部批次完成,共生成 {len(all_cases)} 个测试用例")
+        return all_cases
+
+    # ── Single call with retries ──────────────────────────────
+
+    def _call_with_retry(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+    ) -> list[dict]:
+        """Call the LLM up to ``self.retry`` times and return the parsed cases.
+
+        An attempt counts as failed when the request raises or when the
+        response yields no test cases. Returns ``[]`` after the last failure.
+        """
+        request_kwargs: dict[str, Any] = {
+            "model": self.model,
+            "max_tokens": self.max_tokens,
+            "messages": [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
+        }
+        temperature = getattr(self, "temperature", None)
+        if temperature is not None:
+            request_kwargs["temperature"] = temperature
+
+        last_error: Exception | None = None
+
+        for attempt in range(1, self.retry + 1):
+            # Track the outcome of the most recent attempt only.
+            last_error = None
+            try:
+                logger.debug(f"LLM 调用第 {attempt}/{self.retry} 次 …")
+                response = self.client.chat.completions.create(**request_kwargs)
+                raw = response.choices[0].message.content or ""
+                logger.debug(f"LLM 响应长度:{len(raw)} 字符")
+
+                cases = self._parser.parse(raw)
+                if cases:
+                    return cases
+
+                logger.warning(f"第 {attempt} 次调用返回空结果,准备重试 …")
+
+            except Exception as e:
+                # Broad on purpose: any failure of the remote call is retried.
+                last_error = e
+                logger.warning(f"第 {attempt} 次调用失败:{e}")
+
+            if attempt < self.retry:
+                time.sleep(self.retry_delay)
+
+        # Without an exception the last attempt failed because nothing parsed.
+        reason = last_error if last_error is not None else "响应为空或无法解析为 JSON 数组"
+        logger.error(
+            f"已重试 {self.retry} 次,全部失败。最后一次错误:{reason}"
+        )
+        return []
+
+    @staticmethod
+    def _make_batches(items: list[Any], size: int) -> list[list[Any]]:
+        """Split *items* into consecutive chunks of at most *size* elements.
+
+        Raises:
+            ValueError: If *size* is smaller than 1.
+        """
+        if size < 1:
+            raise ValueError(f"batch size must be a positive integer, got {size!r}")
+        return [items[i:i + size] for i in range(0, len(items), size)]
\ No newline at end of file
diff --git a/ai_test_generator/core/param_generator.py b/ai_test_generator/core/param_generator.py
new file mode 100644
index 0000000..da61601
--- /dev/null
+++ b/ai_test_generator/core/param_generator.py
@@ -0,0 +1,393 @@
+"""
+测试参数生成器
+══════════════════════════════════════════════════════════════
+职责:
+ 根据接口参数的数据类型和参数生成策略,
+ 自动生成覆盖正常值、边界值、异常值的测试数据集。
+
+支持的数据类型:
+ string / str
+ integer / int / number
+ float / double / decimal
+ boolean / bool
+ array / list
+ object / dict / map
+
+输出格式(供注入 prompt):
+ {
+ "param_name": {
+ "type": "string",
+ "groups": [
+ {"label": "正常值", "value": "hello", "category": "normal"},
+ {"label": "空字符串", "value": "", "category": "boundary"},
+ {"label": "超长字符串","value": "a"*1000, "category": "exception"},
+ ...
+ ]
+ }
+ }
+"""
+
+from __future__ import annotations
+from dataclasses import dataclass, field
+from typing import Any
+
+from core.requirement_parser import ParamStrategy
+
+
+# ══════════════════════════════════════════════════════════════
+# 数据结构
+# ══════════════════════════════════════════════════════════════
+
+@dataclass
+class ParamGroup:
+    """One candidate test value for a parameter, tagged with its category."""
+    label: str     # human-readable description, e.g. "空字符串"
+    value: Any     # the concrete parameter value
+    category: str  # "normal" | "boundary" | "exception" | "equivalence"
+
+    def to_dict(self) -> dict:
+        """Return the group as a plain dict with keys label, value and category."""
+        return dict(
+            label=self.label,
+            value=self.value,
+            category=self.category,
+        )
+
+
+@dataclass
+class ParamDataSet:
+    """All generated test values for a single parameter."""
+    param_name: str
+    param_type: str
+    groups: list[ParamGroup] = field(default_factory=list)
+
+    def to_dict(self) -> dict:
+        """Return the type and the serialized groups; the parameter name is not included."""
+        serialized = []
+        for group in self.groups:
+            serialized.append(group.to_dict())
+        return {"type": self.param_type, "groups": serialized}
+
+    def _values_in(self, category: str) -> list[Any]:
+        """Collect, in order, the values of every group whose category equals *category*."""
+        return [group.value for group in self.groups if group.category == category]
+
+    @property
+    def normal_values(self) -> list[Any]:
+        """Values of the groups in the "normal" category."""
+        return self._values_in("normal")
+
+    @property
+    def boundary_values(self) -> list[Any]:
+        """Values of the groups in the "boundary" category."""
+        return self._values_in("boundary")
+
+    @property
+    def exception_values(self) -> list[Any]:
+        """Values of the groups in the "exception" category."""
+        return self._values_in("exception")
+
+
+# ══════════════════════════════════════════════════════════════
+# 各类型测试数据库
+# ══════════════════════════════════════════════════════════════
+
+# ── string ────────────────────────────────────────────────────
+_STRING_NORMAL: list[ParamGroup] = [
+ ParamGroup("普通字符串", "hello", "normal"),
+ ParamGroup("中文字符串", "测试数据", "normal"),
+ ParamGroup("字母数字混合", "user123", "normal"),
+]
+_STRING_BOUNDARY: list[ParamGroup] = [
+ ParamGroup("空字符串", "", "boundary"),
+ ParamGroup("单字符", "a", "boundary"),
+ ParamGroup("最大长度(255)", "a" * 255, "boundary"),
+ ParamGroup("恰好256字符", "a" * 256, "boundary"),
+ ParamGroup("含空格", "hello world", "boundary"),
+ ParamGroup("首尾空格", " hello ", "boundary"),
+]
+# Invalid or hostile string inputs used for negative test cases.
+_STRING_EXCEPTION: list[ParamGroup] = [
+    ParamGroup("None/null", None, "exception"),
+    ParamGroup("超长字符串", "a" * 1000, "exception"),
+    ParamGroup("纯空格", " ", "exception"),
+    ParamGroup("特殊字符", "!@#$%^&*()", "exception"),
+    ParamGroup("换行符", "line1\nline2", "exception"),
+    ParamGroup("SQL注入", "' OR '1'='1", "exception"),
+    # The XSS case has to carry real markup; an empty string would only
+    # repeat the "空字符串" boundary case and exercise no injection path.
+    ParamGroup("XSS注入", "<script>alert(1)</script>", "exception"),
+    ParamGroup("Unicode特殊字符", "你好\u0000世界", "exception"),
+    ParamGroup("数字类型误传", 12345, "exception"),
+]
+
+# ── integer ───────────────────────────────────────────────────
+_INTEGER_NORMAL: list[ParamGroup] = [
+ ParamGroup("典型正整数", 1, "normal"),
+ ParamGroup("较大正整数", 100, "normal"),
+]
+_INTEGER_BOUNDARY: list[ParamGroup] = [
+ ParamGroup("零", 0, "boundary"),
+ ParamGroup("正边界值1", 1, "boundary"),
+ ParamGroup("负边界值-1", -1, "boundary"),
+ ParamGroup("最大值(int32)", 2_147_483_647, "boundary"),
+ ParamGroup("最小值(int32)", -2_147_483_648, "boundary"),
+ ParamGroup("最大值+1", 2_147_483_648, "boundary"),
+]
+_INTEGER_EXCEPTION: list[ParamGroup] = [
+ ParamGroup("None/null", None, "exception"),
+ ParamGroup("字符串类型", "abc", "exception"),
+ ParamGroup("浮点数", 3.14, "exception"),
+ ParamGroup("布尔值true", True, "exception"),
+ ParamGroup("空字符串", "", "exception"),
+ ParamGroup("超大整数", 10 ** 20, "exception"),
+]
+
+# ── float ─────────────────────────────────────────────────────
+_FLOAT_NORMAL: list[ParamGroup] = [
+ ParamGroup("典型浮点数", 3.14, "normal"),
+ ParamGroup("整数形式", 1.0, "normal"),
+]
+_FLOAT_BOUNDARY: list[ParamGroup] = [
+ ParamGroup("零", 0.0, "boundary"),
+ ParamGroup("极小正数", 1e-10, "boundary"),
+ ParamGroup("极大正数", 1e10, "boundary"),
+ ParamGroup("负数", -3.14, "boundary"),
+ ParamGroup("负无穷", float('-inf'), "boundary"),
+ ParamGroup("正无穷", float('inf'), "boundary"),
+]
+_FLOAT_EXCEPTION: list[ParamGroup] = [
+ ParamGroup("None/null", None, "exception"),
+ ParamGroup("NaN", float('nan'), "exception"),
+ ParamGroup("字符串类型", "abc", "exception"),
+ ParamGroup("空字符串", "", "exception"),
+]
+
+# ── boolean ───────────────────────────────────────────────────
+_BOOLEAN_NORMAL: list[ParamGroup] = [
+ ParamGroup("true", True, "normal"),
+ ParamGroup("false", False, "normal"),
+]
+_BOOLEAN_BOUNDARY: list[ParamGroup] = [
+ ParamGroup("整数1", 1, "boundary"),
+ ParamGroup("整数0", 0, "boundary"),
+]
+_BOOLEAN_EXCEPTION: list[ParamGroup] = [
+ ParamGroup("None/null", None, "exception"),
+ ParamGroup("字符串'true'", "true", "exception"),
+ ParamGroup("字符串'false'", "false", "exception"),
+ ParamGroup("字符串'yes'", "yes", "exception"),
+ ParamGroup("随机字符串", "abc", "exception"),
+]
+
+# ── array ─────────────────────────────────────────────────────
+_ARRAY_NORMAL: list[ParamGroup] = [
+ ParamGroup("普通数组", [1, 2, 3], "normal"),
+ ParamGroup("字符串数组", ["a", "b", "c"], "normal"),
+]
+_ARRAY_BOUNDARY: list[ParamGroup] = [
+ ParamGroup("空数组", [], "boundary"),
+ ParamGroup("单元素数组", [1], "boundary"),
+ ParamGroup("大数组(100元素)", list(range(100)), "boundary"),
+ ParamGroup("含None元素", [1, None, 3], "boundary"),
+]
+_ARRAY_EXCEPTION: list[ParamGroup] = [
+ ParamGroup("None/null", None, "exception"),
+ ParamGroup("字符串类型", "not_an_array", "exception"),
+ ParamGroup("整数类型", 123, "exception"),
+ ParamGroup("嵌套过深", [[[[[1]]]]], "exception"),
+]
+
+# ── object ────────────────────────────────────────────────────
+_OBJECT_NORMAL: list[ParamGroup] = [
+ ParamGroup("完整字段对象", {"id": 1, "name": "test"}, "normal"),
+]
+_OBJECT_BOUNDARY: list[ParamGroup] = [
+ ParamGroup("空对象", {}, "boundary"),
+ ParamGroup("仅必填字段", {"id": 1}, "boundary"),
+ ParamGroup("含额外字段", {"id": 1, "extra": "x"}, "boundary"),
+]
+_OBJECT_EXCEPTION: list[ParamGroup] = [
+ ParamGroup("None/null", None, "exception"),
+ ParamGroup("字符串类型", "not_an_object", "exception"),
+ ParamGroup("数组类型", [1, 2, 3], "exception"),
+ ParamGroup("字段值类型错误", {"id": "not_int"}, "exception"),
+]
+
+# 类型 → 数据映射表
+_TYPE_DATA: dict[str, dict[str, list[ParamGroup]]] = {
+ "string": {"normal": _STRING_NORMAL, "boundary": _STRING_BOUNDARY, "exception": _STRING_EXCEPTION},
+ "integer": {"normal": _INTEGER_NORMAL, "boundary": _INTEGER_BOUNDARY, "exception": _INTEGER_EXCEPTION},
+ "float": {"normal": _FLOAT_NORMAL, "boundary": _FLOAT_BOUNDARY, "exception": _FLOAT_EXCEPTION},
+ "boolean": {"normal": _BOOLEAN_NORMAL, "boundary": _BOOLEAN_BOUNDARY, "exception": _BOOLEAN_EXCEPTION},
+ "array": {"normal": _ARRAY_NORMAL, "boundary": _ARRAY_BOUNDARY, "exception": _ARRAY_EXCEPTION},
+ "object": {"normal": _OBJECT_NORMAL, "boundary": _OBJECT_BOUNDARY, "exception": _OBJECT_EXCEPTION},
+}
+
+# 类型别名归一化
+_TYPE_ALIAS: dict[str, str] = {
+ "str": "string", "text": "string", "varchar": "string",
+ "int": "integer", "long": "integer", "number": "integer",
+ "float": "float", "double": "float", "decimal": "float",
+ "bool": "boolean",
+ "list": "array",
+ "dict": "object", "map": "object", "json": "object",
+}
+
+
+# ══════════════════════════════════════════════════════════════
+# 参数生成器
+# ══════════════════════════════════════════════════════════════
+
+class ParamGenerator:
+    """Build per-parameter test data sets from a parameter type and a list of strategies."""
+
+    # Display names used by to_prompt_text() for each ParamGroup.category;
+    # a category missing from this map is printed as-is.
+    _CATEGORY_LABELS = {
+        "normal": "正常值",
+        "boundary": "边界值",
+        "exception": "异常值",
+        "equivalence": "等价类",
+    }
+
+    def generate(
+        self,
+        param_name: str,
+        param_type: str,
+        strategies: list[ParamStrategy],
+        min_groups: int = 2,
+    ) -> ParamDataSet:
+        """Generate the test data set for one parameter.
+
+        Normal values are always included. Each requested strategy then adds
+        its groups, duplicates (by label) are dropped, and if fewer than
+        ``min_groups`` groups remain the set is topped up from the boundary
+        and exception tables.
+
+        Args:
+            param_name: Name of the parameter.
+            param_type: Data type; aliases are accepted and unknown types
+                fall back to "string".
+            strategies: Strategies to cover; ``None`` is treated as empty.
+            min_groups: Minimum number of groups to aim for.
+
+        Returns:
+            A ``ParamDataSet`` carrying the normalized type and the groups.
+        """
+        norm_type = self._normalize_type(param_type)
+        type_data = _TYPE_DATA.get(norm_type, _TYPE_DATA["string"])
+
+        groups: list[ParamGroup] = list(type_data["normal"])
+
+        for strategy in strategies or ():
+            if strategy == ParamStrategy.BOUNDARY:
+                groups.extend(type_data["boundary"])
+            elif strategy == ParamStrategy.EXCEPTION:
+                groups.extend(type_data["exception"])
+            elif strategy == ParamStrategy.EQUIVALENCE:
+                # Equivalence classes: normal values plus the first few boundary values.
+                groups.extend(type_data["boundary"][:3])
+            elif strategy == ParamStrategy.RANDOM:
+                groups.extend(self._random_groups(norm_type))
+
+        # De-duplicate by label, keeping the first occurrence.
+        seen_labels: set[str] = set()
+        unique_groups: list[ParamGroup] = []
+        for g in groups:
+            if g.label not in seen_labels:
+                seen_labels.add(g.label)
+                unique_groups.append(g)
+
+        # Top up to min_groups from the boundary, then the exception table.
+        for g in type_data["boundary"] + type_data["exception"]:
+            if len(unique_groups) >= min_groups:
+                break
+            if g.label not in seen_labels:
+                seen_labels.add(g.label)
+                unique_groups.append(g)
+
+        return ParamDataSet(
+            param_name=param_name,
+            param_type=norm_type,
+            groups=unique_groups,
+        )
+
+    def generate_for_interface(
+        self,
+        interface_summary: dict,
+        strategies: list[ParamStrategy],
+        min_groups: int = 2,
+    ) -> dict[str, ParamDataSet]:
+        """Generate data sets for every input parameter of one interface.
+
+        Reads ``interface_summary["params"]``, a mapping from parameter name
+        to either an info dict (keys ``type`` and ``inout``) or a bare type.
+        Parameters whose ``inout`` is neither "in" nor "inout" are skipped.
+
+        Returns:
+            A dict mapping parameter name to its ``ParamDataSet``; empty when
+            the interface has no ``params`` entry.
+        """
+        result: dict[str, ParamDataSet] = {}
+        # ``or {}`` also covers a "params" key that is present but None.
+        params = interface_summary.get("params") or {}
+
+        for param_name, param_info in params.items():
+            if isinstance(param_info, dict):
+                param_type = param_info.get("type") or "string"
+                inout = param_info.get("inout") or "in"
+            else:
+                param_type = str(param_info)
+                inout = "in"
+
+            # Test data is only generated for input parameters.
+            if inout not in ("in", "inout"):
+                continue
+
+            result[param_name] = self.generate(
+                param_name=param_name,
+                param_type=param_type,
+                strategies=strategies,
+                min_groups=min_groups,
+            )
+
+        return result
+
+    # ── Helpers ───────────────────────────────────────────────
+
+    @staticmethod
+    def _normalize_type(t: str) -> str:
+        """Map a type name or alias to a canonical type; anything unknown becomes "string"."""
+        if not isinstance(t, str):
+            return "string"
+        t = t.lower().strip()
+        return _TYPE_ALIAS.get(t, t if t in _TYPE_DATA else "string")
+
+    @staticmethod
+    def _random_groups(norm_type: str) -> list[ParamGroup]:
+        """Return one random value for string, integer or float; nothing for other types.
+
+        The generated group is tagged with the "equivalence" category.
+        """
+        import random
+        import string
+        if norm_type == "string":
+            val = ''.join(random.choices(string.ascii_letters, k=8))
+            return [ParamGroup(f"随机字符串({val})", val, "equivalence")]
+        if norm_type == "integer":
+            val = random.randint(-1000, 1000)
+            return [ParamGroup(f"随机整数({val})", val, "equivalence")]
+        if norm_type == "float":
+            val = round(random.uniform(-100, 100), 4)
+            return [ParamGroup(f"随机浮点数({val})", val, "equivalence")]
+        return []
+
+    def to_prompt_text(
+        self,
+        datasets: dict[str, ParamDataSet],
+        min_groups: int,
+    ) -> str:
+        """Render *datasets* as the text block that is injected into the prompt.
+
+        Each value is shown through ``repr``. Returns an empty string when
+        *datasets* is empty.
+        """
+        if not datasets:
+            return ""
+
+        lines = [
+            f"### 参数测试数据集(每个参数至少覆盖 {min_groups} 组)",
+            "",
+        ]
+        for param_name, ds in datasets.items():
+            lines.append(f"**参数:{param_name}**(类型:{ds.param_type})")
+            for g in ds.groups:
+                category_label = self._CATEGORY_LABELS.get(g.category, g.category)
+                lines.append(
+                    f" - [{category_label}] {g.label}:`{repr(g.value)}`"
+                )
+            lines.append("")
+
+        lines.append(
+            f"> 请从以上数据集中选取参数值组合,"
+            f"确保生成的测试用例总数不少于 **{min_groups} 组**,"
+            f"并覆盖正常值、边界值和异常值场景。"
+        )
+        return "\n".join(lines)
\ No newline at end of file
diff --git a/ai_test_generator/core/prompt_builder.py b/ai_test_generator/core/prompt_builder.py
index 1612df2..2ab07e8 100644
--- a/ai_test_generator/core/prompt_builder.py
+++ b/ai_test_generator/core/prompt_builder.py
@@ -1,167 +1,332 @@
"""
-提示词构建器
-升级点:
- - 传递项目 description
- - function 协议:source_file(.py) → module_path → from import
- - HTTP 协议:full_url = url(base) + name(path)
- - parameters 统一字段,inout 区分输入/输出
+提示词构建器(中文版)
+══════════════════════════════════════════════════════════════
+变更说明:
+ - 新增 build_param_directive_section():将参数约束注入 System Prompt
+ - 新增 build_param_dataset_section():将参数数据集注入 User Prompt
+ - build_batch_prompt / build_user_prompt 支持传入参数数据集
"""
import json
from core.parser import InterfaceInfo, InterfaceParser
+from core.requirement_parser import ParamConstraint, ParamStrategy
+from core.param_generator import ParamGenerator, ParamDataSet
-SYSTEM_PROMPT = """
-You are a senior software test engineer. Generate test cases based on the provided
-interface descriptions and test requirements.
+
+# ══════════════════════════════════════════════════════════════
+# System Prompt(中文)
+# ══════════════════════════════════════════════════════════════
+
+_SYSTEM_PROMPT_BASE = """
+你是一名资深软件测试工程师,擅长根据接口描述和测试需求生成高质量的测试用例。
═══════════════════════════════════════════════════════
-OUTPUT FORMAT
-Return ONLY a valid JSON array. No markdown fences. No extra text.
+【输出格式要求】
═══════════════════════════════════════════════════════
-Each element = ONE test case:
+- 只输出一个合法的 JSON 数组,不要包含任何 markdown 代码块(```)
+- 不要在 JSON 前后添加任何解释性文字
+- 每个数组元素代表一个测试用例,结构如下:
+
{
- "test_id" : "unique id, e.g. TC_001",
- "test_name" : "short descriptive name",
- "description" : "what this test verifies",
- "requirement" : "the original requirement text this case maps to",
- "requirement_id" : "e.g. REQ.01",
+ "test_id" : "唯一编号,如 TC_001",
+ "test_name" : "简短的测试名称",
+ "description" : "本用例验证的内容",
+ "requirement" : "对应的原始需求文本",
+ "requirement_id" : "需求编号,如 REQ.01(从接口描述中的 requirement_id 字段获取)",
"steps": [
{
- "step_no" : 1,
- "interface_name" : "function or endpoint name",
- "protocol" : "http or function",
- "url" : "source_file (function) or full_url (http)",
- "purpose" : "why this step is needed",
+ "step_no" : 1,
+ "interface_name" : "接口名称或函数名",
+ "protocol" : "http 或 function",
+ "url" : "function 协议填 source_file,http 协议填 full_url",
+ "purpose" : "本步骤的目的",
"input": {
- // Only parameters with inout = "in" or "inout"
- "param_name":
+ "参数名": "参数值"
},
- "use_output_of": { // optional
+ "use_output_of": {
"step_no" : 1,
- "field" : "user_id",
- "as_param" : "user_id"
+ "field" : "上一步返回值中的字段名",
+ "as_param" : "作为本步骤的参数名"
},
"assertions": [
{
- "field" : "field_name or 'return' or 'exception'",
- "operator" : "eq|ne|gt|lt|gte|lte|in|not_null|contains|raised|not_raised",
- "expected" : ,
- "message" : "human readable description"
+ "field" : "断言的字段名,或 'return'(整体返回值)或 'exception'(异常)",
+ "operator" : "eq | ne | gt | lt | gte | lte | in | not_null | contains | raised | not_raised",
+ "expected" : "期望值",
+ "message" : "断言说明"
}
]
}
],
- "test_data_notes" : "explanation of auto-generated test data",
- "test_code" : ""
+ "test_data_notes" : "测试数据说明(说明本用例使用了哪类参数:正常值/边界值/异常值)",
+ "test_code" : "完整可运行的 Python 测试脚本(见下方规则)"
}
═══════════════════════════════════════════════════════
-TEST CODE RULES
+【测试代码编写规则】
═══════════════════════════════════════════════════════
-1. Complete, runnable Python script. No external test framework.
-2. Allowed imports: standard library + `requests` + the actual module under test.
+1. test_code 必须是完整、可直接运行的 Python 脚本
+2. 不得使用 unittest 或 pytest 框架
+3. 只允许导入:Python 标准库、requests、被测模块
-── FUNCTION INTERFACES ────────────────────────────────
-3. Each function interface has:
- "source_file" : e.g. "create_user.py"
- "module_path" : e.g. "create_user" (derived from source_file)
-4. Import the REAL function using module_path:
+── function 协议接口 ──────────────────────────────────
+4. 使用 module_path 字段导入真实函数,若导入失败则使用桩函数:
try:
from import
except ImportError:
- # Stub fallback — simulate on_success for positive tests,
- # on_failure / raise Exception for negative tests
- def ():
- return # or raise Exception(...)
-5. Call the function with ONLY "in"/"inout" parameters.
-6. Capture the return value; assert against on_success or on_failure structure.
-7. If on_failure.value contains "exception" or "raises":
- - In negative tests: wrap call in try/except, assert exception IS raised.
- - Use field="exception", operator="raised", expected="Exception".
+ def (<入参列表>):
+ return # 桩函数,仅供结构验证
+5. 调用函数时只传 inout 为 "in" 或 "inout" 的参数
+6. 若 on_failure 包含 "exception",负向测试需用 try/except 捕获异常
-── HTTP INTERFACES ────────────────────────────────────
-8. Each HTTP interface has:
- "full_url" : complete URL, e.g. "http://127.0.0.1/api/delete_user"
- "method" : "get" | "post" | "put" | "delete" etc.
-9. Send request:
+── http 协议接口 ──────────────────────────────────────
+7. 使用 full_url 字段作为请求地址:
resp = requests.("", json=)
-10. Assert on resp.status_code and resp.json() fields.
+8. 断言 resp.status_code 和 resp.json() 中的字段
-── MULTI-STEP ─────────────────────────────────────────
-11. Execute steps in order; extract fields from previous step's return value
- and pass to next step via use_output_of mapping.
-
-── STRUCTURED OUTPUT (REQUIRED) ───────────────────────
-12. After EACH step, print:
- ##STEP_RESULT## {"step_no":,"interface_name":"...","status":"PASS|FAIL",
- "assertions":[{"field":"...","operator":"...","expected":...,
- "actual":...,"passed":true|false,"message":"..."}]}
-13. Final line must be:
- PASS: or FAIL:
-14. Wrap entire test body in try/except; print FAIL on any unhandled exception.
-15. Do NOT use unittest or pytest.
+── 结构化输出(必须遵守)─────────────────────────────
+9. 每个步骤执行后,必须在标准输出打印如下格式的一行:
+ ##STEP_RESULT## {"step_no":,"interface_name":"...","status":"PASS 或 FAIL",
+ "assertions":[{"field":"...","operator":"...","expected":...,
+ "actual":...,"passed":true 或 false,"message":"..."}]}
+10. 脚本最后一行必须是:
+ PASS: <摘要> 或 FAIL: <摘要>
+11. 整个脚本主体用 try/except 包裹,捕获未处理异常时打印 FAIL
═══════════════════════════════════════════════════════
-ASSERTION RULES BY RETURN TYPE
+【断言编写规则】
═══════════════════════════════════════════════════════
-return.type = "dict" → assert each key in on_success.value (positive)
- or on_failure.value (negative)
-return.type = "boolean" → field="return", operator="eq", expected=true/false
-return.type = "integer" → field="return", operator="eq"/"gt"/"lt"
-on_failure = exception → field="exception", operator="raised"
+- 返回值为 dict → 对 on_success / on_failure 中每个 key 单独断言
+- 返回值为 bool → field="return",operator="eq",expected=true 或 false
+- 返回值为 int → field="return",operator="eq" / "gt" / "lt"
+- 预期抛出异常 → field="exception",operator="raised"
+- 预期不抛出异常 → field="exception",operator="not_raised"
═══════════════════════════════════════════════════════
-COVERAGE GUIDELINES
+【覆盖度要求】
═══════════════════════════════════════════════════════
-- Per requirement: at least 1 positive + 1 negative test case.
-- Negative test_name/description MUST contain "negative" or "invalid".
-- Multi-interface requirements: single multi-step test case.
-- Cover ALL "in"/"inout" parameters (including optional ones).
-- Assert ALL fields described in on_success (positive) and on_failure (negative).
-- For "out" parameters: verify their values in assertions after the call.
+- 每条需求至少生成 1 个正向用例 + 1 个负向用例
+- 负向用例的 test_name 必须包含"负向"或"无效"
+- 覆盖接口的所有 inout 为 "in" 或 "inout" 的参数
+- 正向用例断言 on_success 的所有字段
+- 负向用例断言 on_failure 的所有字段
+"""
+
+# 参数生成规则段落(有参数约束时追加到 System Prompt)
+_PARAM_DIRECTIVE_TEMPLATE = """
+═══════════════════════════════════════════════════════
+【参数生成规则(本次任务特殊要求)】
+═══════════════════════════════════════════════════════
+{directives}
+- test_data_notes 字段必须说明本用例使用的参数类别(正常值/边界值/异常值)及具体值
+- 每个测试用例的 input 字段中的参数值,必须从下方"参数测试数据集"中选取
+- 禁止在 input 中使用未在数据集中出现的随机值
"""
+# ══════════════════════════════════════════════════════════════
+# PromptBuilder
+# ══════════════════════════════════════════════════════════════
+
class PromptBuilder:
- def get_system_prompt(self) -> str:
- return SYSTEM_PROMPT.strip()
+ def __init__(self):
+ # Generator used by build_param_dataset_section() to produce per-interface test data.
+ self._param_gen = ParamGenerator()
+
+ # ── System Prompt ────────────────────────────────────────
+
+    def get_system_prompt(
+        self,
+        global_constraint: "ParamConstraint | None" = None,
+    ) -> str:
+        """Return the system prompt, extended with parameter rules when applicable.
+
+        Args:
+            global_constraint: Optional parameter constraint. The rules
+                section is appended only when it is given and its
+                ``has_param_directive`` flag is set.
+
+        Returns:
+            The stripped base prompt, followed by the rules section if one
+            was produced.
+        """
+        base = _SYSTEM_PROMPT_BASE.strip()
+        if not (global_constraint and global_constraint.has_param_directive):
+            return base
+
+        directive_section = self._build_param_directive_section(
+            global_constraint
+        ).strip()
+        # The builder may return an empty section; in that case return the
+        # base prompt unchanged instead of appending a dangling newline.
+        if not directive_section:
+            return base
+        return base + "\n" + directive_section
+
+    def _build_param_directive_section(
+        self,
+        constraint: ParamConstraint,
+    ) -> str:
+        """Build the parameter-rules section that is appended to the system prompt.
+
+        A rule about the number of cases is emitted when ``min_groups`` is
+        greater than 2, and a rule about value types when at least one of the
+        requested strategies has a description. Returns an empty string when
+        neither rule applies.
+        """
+        strategy_labels = {
+            ParamStrategy.BOUNDARY: "边界值(空值、最大值、最小值、临界值)",
+            ParamStrategy.EXCEPTION: "异常值(None、类型错误、超长、特殊字符等)",
+            ParamStrategy.EQUIVALENCE: "等价类(有效等价类 + 无效等价类)",
+            ParamStrategy.RANDOM: "随机值(覆盖典型随机场景)",
+        }
+        lines: list[str] = []
+
+        if constraint.min_groups > 2:
+            lines.append(
+                f"- 每个接口的测试用例总数不少于 **{constraint.min_groups} 组**"
+                f"(正向 + 负向合计)"
+            )
+
+        # Strategies without a description (e.g. CUSTOM) are skipped; the
+        # rule is only emitted when something is left to list.
+        covered = [
+            strategy_labels[s]
+            for s in constraint.strategies or ()
+            if s in strategy_labels
+        ]
+        if covered:
+            strategy_str = "、".join(covered)
+            lines.append(f"- 参数值必须覆盖以下类型:{strategy_str}")
+
+        if not lines:
+            return ""
+
+        directives = "\n".join(lines)
+        return _PARAM_DIRECTIVE_TEMPLATE.format(directives=directives)
+
+ # ── 项目信息头 ───────────────────────────────────────────
+
+    def build_project_header(
+        self,
+        project: str = "",
+        project_desc: str = "",
+    ) -> str:
+        """Return the project-info header, or an empty string when both fields are empty.
+
+        The header is a title line followed by one line per non-empty field.
+        """
+        labelled = (("项目名称", project), ("项目描述", project_desc))
+        body = [f"{caption}:{text}" for caption, text in labelled if text]
+        if not body:
+            return ""
+        return "\n".join(["## 项目信息", *body])
+
+ # ── 参数数据集段落 ───────────────────────────────────────
+
+    def build_param_dataset_section(
+        self,
+        iface_summaries: list[dict],
+        constraint: ParamConstraint,
+    ) -> str:
+        """Render the generated parameter data sets of all given interfaces.
+
+        Returns an empty string when the constraint carries no parameter
+        directive. Interfaces for which no data set is generated are left out.
+        """
+        if not constraint.has_param_directive:
+            return ""
+
+        group_floor = constraint.min_groups
+        intro = (
+            f"> 以下数据集由系统自动生成,请从中选取参数值组合,"
+            f"确保每个接口的测试用例总数不少于 **{group_floor} 组**。"
+        )
+        section: list[str] = ["## 参数测试数据集", intro, ""]
+
+        for summary in iface_summaries:
+            per_param = self._param_gen.generate_for_interface(
+                interface_summary=summary,
+                strategies=constraint.strategies,
+                min_groups=group_floor,
+            )
+            if per_param:
+                section += [
+                    f"### 接口:{summary.get('name', '未知接口')}",
+                    self._param_gen.to_prompt_text(per_param, group_floor),
+                    "",
+                ]
+
+        return "\n".join(section)
+
+ # ── 分批 user_prompt ────────────────────────────────────
+
+    def build_batch_prompt(
+        self,
+        batch: list[dict],
+        requirements: list[str],
+        project_header: str = "",
+        param_constraint: ParamConstraint = None,
+    ) -> str:
+        """Build the user prompt for one batch of interfaces.
+
+        Sections, separated by blank lines: the optional project header, the
+        batch as indented JSON, the numbered requirements, the parameter data
+        set (only when the constraint carries a directive and a section was
+        produced), and the closing instructions.
+        """
+        numbered = [f"{no}. {req}" for no, req in enumerate(requirements, start=1)]
+        interfaces_json = json.dumps(batch, ensure_ascii=False, indent=2)
+
+        sections: list[str] = [project_header] if project_header else []
+        sections.append("## 本批次接口描述\n" + interfaces_json)
+        sections.append("## 测试需求\n" + "\n".join(numbered))
+
+        wants_dataset = bool(param_constraint) and param_constraint.has_param_directive
+        if wants_dataset:
+            dataset_text = self.build_param_dataset_section(
+                iface_summaries=batch,
+                constraint=param_constraint,
+            )
+            if dataset_text:
+                sections.append(dataset_text)
+
+        closing = (
+            "请根据以上接口描述、测试需求和参数数据集,生成完整的测试用例(包含正向和负向)。\n"
+            "- function 协议:使用 module_path 字段导入被测函数\n"
+            "- http 协议:使用 full_url 字段作为请求地址\n"
+            "- 每条需求的 requirement_id 须填入对应用例的 requirement_id 字段\n"
+            "- test_data_notes 须说明本用例使用的参数类别和具体值"
+        )
+        sections.append(closing)
+        return "\n\n".join(sections)
+
+ # ── 单次 user_prompt ────────────────────────────────────
def build_user_prompt(
self,
- requirements: list[str],
- interfaces: list[InterfaceInfo],
- parser: InterfaceParser,
- project: str = "",
- project_desc: str = "",
+ requirements: list[str],
+ interfaces: list[InterfaceInfo],
+ parser: InterfaceParser,
+ project: str = "",
+ project_desc: str = "",
+ param_constraint: ParamConstraint = None,
) -> str:
+ """
+ Build the user prompt for single-call mode (used when the number of
+ interfaces is ≤ batch_size).
+
+ Sections, joined by blank lines: the optional project header, the
+ interface summary as indented JSON, the numbered requirements, the
+ parameter data set (only when param_constraint carries a directive and
+ a section was produced), and the closing instructions.
+ """
iface_summary = parser.to_summary_dict(interfaces)
req_lines = "\n".join(
f"{i + 1}. {r}" for i, r in enumerate(requirements)
)
+ header = self.build_project_header(project, project_desc)
- # 项目信息头
- project_section = ""
- if project or project_desc:
- project_section = "## Project\n"
- if project:
- project_section += f"Name : {project}\n"
- if project_desc:
- project_section += f"Description: {project_desc}\n"
- project_section += "\n"
+ parts: list[str] = []
+ if header:
+ parts.append(header)
+ parts.append(
+ "## 接口描述\n"
+ + json.dumps(iface_summary, ensure_ascii=False, indent=2)
+ )
+ parts.append("## 测试需求\n" + req_lines)
- return (
- f"{project_section}"
- f"## Available Interfaces\n"
- f"{json.dumps(iface_summary, ensure_ascii=False, indent=2)}\n\n"
- f"## Test Requirements\n"
- f"Generate comprehensive test cases (positive + negative) for every interface.\n"
- f"- The generated test cases for each interface must meet the following requirements: {req_lines}\n"
- f"- For function interfaces: use 'module_path' for import, "
- f"'source_file' for reference.\n"
- f"- For HTTP interfaces: use 'full_url' as the request URL.\n"
- f"- Use requirement_id from the interface to populate the test case's "
- f"requirement_id field.\n"
- f"- Chain multiple interface calls in one test case when a requirement "
- f"involves more than one interface.\n"
- ).strip()
\ No newline at end of file
+ if param_constraint and param_constraint.has_param_directive:
+ # The summary is accepted as a list or as a mapping; a mapping is
+ # reduced to its values before the data sets are generated.
+ dataset_section = self.build_param_dataset_section(
+ iface_summaries=iface_summary
+ if isinstance(iface_summary, list) else list(iface_summary.values()),
+ constraint=param_constraint,
+ )
+ if dataset_section:
+ parts.append(dataset_section)
+
+ parts.append(
+ "请根据以上接口描述、测试需求和参数数据集,生成完整的测试用例(包含正向和负向)。\n"
+ "- function 协议:使用 module_path 字段导入被测函数\n"
+ "- http 协议:使用 full_url 字段作为请求地址\n"
+ "- 每条需求的 requirement_id 须填入对应用例的 requirement_id 字段\n"
+ "- test_data_notes 须说明本用例使用的参数类别和具体值"
+ )
+ return "\n\n".join(parts)
\ No newline at end of file
diff --git a/ai_test_generator/core/requirement_parser.py b/ai_test_generator/core/requirement_parser.py
new file mode 100644
index 0000000..7cce453
--- /dev/null
+++ b/ai_test_generator/core/requirement_parser.py
@@ -0,0 +1,211 @@
+"""
+需求解析器
+══════════════════════════════════════════════════════════════
+职责:
+ 从用户填写的测试需求文本中,识别并提取参数生成约束指令:
+ - 参数组数约束:如"不少于10组"、"至少5组"、"生成8组"
+ - 参数策略约束:如"边界值"、"异常值"、"等价类"、"随机值"
+ - 数据类型约束:如"字符串边界值"、"整数异常值"
+ 解析结果供 ParamGenerator 和 PromptBuilder 使用。
+"""
+
+from __future__ import annotations
+import re
+from dataclasses import dataclass, field
+from enum import Enum
+
+
+# ══════════════════════════════════════════════════════════════
+# 枚举:参数生成策略
+# ══════════════════════════════════════════════════════════════
+
+class ParamStrategy(str, Enum): # parameter-generation strategies a requirement can ask for
+ BOUNDARY = "boundary" # boundary values
+ EXCEPTION = "exception" # abnormal / erroneous values
+ EQUIVALENCE = "equivalence" # equivalence classes
+ RANDOM = "random" # random values
+ CUSTOM = "custom" # custom (specified directly by the user in the requirement)
+
+
+# ══════════════════════════════════════════════════════════════
+# 数据结构:解析结果
+# ══════════════════════════════════════════════════════════════
+
+@dataclass
+class ParamConstraint:
+ """
+ Parameter-generation constraint parsed from a single requirement text.
+ """
+ # Minimum number of parameter groups to generate (positive + negative combined)
+ min_groups: int = 2
+
+ # Strategies that must be covered
+ strategies: list[ParamStrategy] = field(default_factory=list)
+
+ # Requirement text with the parameter directives removed (business part only)
+ clean_requirement: str = ""
+
+ # Whether a parameter directive was given explicitly (distinguishes a user request from the defaults)
+ has_param_directive: bool = False
+
+ def __str__(self) -> str: # short label text: group count (only when above 2) plus strategy labels
+ parts = []
+ if self.min_groups > 2:
+ parts.append(f"至少 {self.min_groups} 组参数")
+ for s in self.strategies:
+ parts.append(_STRATEGY_LABEL[s]) # _STRATEGY_LABEL is defined after this class; looked up at call time
+ return "、".join(parts) if parts else "默认(1正1负)"
+
+
+_STRATEGY_LABEL: dict[ParamStrategy, str] = { # display label per strategy, used by ParamConstraint.__str__
+ ParamStrategy.BOUNDARY: "边界值",
+ ParamStrategy.EXCEPTION: "异常值",
+ ParamStrategy.EQUIVALENCE: "等价类",
+ ParamStrategy.RANDOM: "随机值",
+ ParamStrategy.CUSTOM: "自定义",
+}
+
+
+# ══════════════════════════════════════════════════════════════
+# 解析规则
+# ══════════════════════════════════════════════════════════════
+
+# Count constraint: matches phrases such as "不少于N组", "至少N组", "最少N组", "生成N组", "N组以上"; N lands in group 1 or group 2
+_RE_MIN_GROUPS = re.compile(
+ r'(?:不少于|至少|最少|生成|共|需要|要求)\s*(\d+)\s*组'
+ r'|(\d+)\s*组(?:以上|及以上|或以上)',
+ re.UNICODE,
+)
+
+# Strategy keyword map (list of keywords → strategy enum member); matching is plain substring search
+_STRATEGY_KEYWORDS: list[tuple[list[str], ParamStrategy]] = [
+ (["边界值", "边界", "boundary", "边界测试"], ParamStrategy.BOUNDARY),
+ (["异常值", "异常", "错误值", "非法值",
+ "exception", "invalid", "error"], ParamStrategy.EXCEPTION),
+ (["等价类", "等价", "equivalence"], ParamStrategy.EQUIVALENCE),
+ (["随机", "随机值", "random"], ParamStrategy.RANDOM),
+]
+
+# Whole-directive detection: a parameter/data noun followed (lazily, on the same line) by a quantity or strategy trigger word
+_RE_PARAM_DIRECTIVE = re.compile(
+ r'(?:参数|测试参数|测试数据|数据|入参)'
+ r'.*?(?:不少于|至少|最少|生成|边界|异常|等价|随机|\d+\s*组)',
+ re.UNICODE,
+)
+
+
+# ══════════════════════════════════════════════════════════════
+# 解析器
+# ══════════════════════════════════════════════════════════════
+
+class RequirementParser:
+ """
+ Parse a list of test requirements and extract parameter-generation constraints. Produces:
+ - clean_requirements : the requirements with parameter directives removed
+ - constraints : one ParamConstraint per requirement
+ - global_constraint : merged constraint of the requirements worded as global (e.g. "全局" / "所有接口")
+ """
+
+ def parse(
+ self,
+ requirements: list[str],
+ ) -> tuple[list[str], list[ParamConstraint], ParamConstraint]:
+ """
+ Returns:
+ (clean_requirements, per_req_constraints, global_constraint)
+ """
+ clean_reqs: list[str] = []
+ constraints: list[ParamConstraint] = []
+ global_c = ParamConstraint()
+
+ for req in requirements:
+ c = self._parse_one(req)
+ constraints.append(c)
+ clean_reqs.append(c.clean_requirement)
+
+ # A requirement worded as global (e.g. "所有接口参数不少于10组")
+ # has its constraint folded into the global constraint.
+ if self._is_global(req) and c.has_param_directive:
+ global_c = self._merge(global_c, c)
+
+ return clean_reqs, constraints, global_c
+
+ # ── Parse a single requirement ────────────────────────────
+
+ def _parse_one(self, req: str) -> ParamConstraint: # build the constraint for one requirement string
+ c = ParamConstraint(clean_requirement=req)
+
+ # 1. No directive detected: return the defaults with the text untouched
+ if not self._has_directive(req):
+ return c
+
+ c.has_param_directive = True
+
+ # 2. Extract the count constraint
+ m = _RE_MIN_GROUPS.search(req)
+ if m:
+ n = int(m.group(1) or m.group(2))
+ c.min_groups = max(n, 2) # keep at least 1 positive + 1 negative
+
+ # 3. Extract strategies in _STRATEGY_KEYWORDS order. Collecting them in a set made the
+ # resulting order change between runs, because enum members hash by name and str hashes are randomized.
+ c.strategies = [
+ strategy for keywords, strategy in _STRATEGY_KEYWORDS
+ if any(kw in req for kw in keywords)
+ ]
+
+ # 4. No strategy named: default to covering boundary and exception values
+ if not c.strategies:
+ c.strategies = [ParamStrategy.BOUNDARY, ParamStrategy.EXCEPTION]
+
+ # 5. Clean the text: drop the directive clauses, keep the business description
+ c.clean_requirement = self._strip_directive(req)
+
+ return c
+
+ # ── Helpers ───────────────────────────────────────────────
+
+ @staticmethod
+ def _has_directive(req: str) -> bool:
+ """Return True if the text matches either directive pattern."""
+ return bool(_RE_PARAM_DIRECTIVE.search(req)) or bool(
+ _RE_MIN_GROUPS.search(req)
+ )
+
+ @staticmethod
+ def _is_global(req: str) -> bool:
+ """Return True if the text contains one of the global-scope keywords."""
+ return any(
+ kw in req
+ for kw in ["所有接口", "全部接口", "全局", "所有测试", "全部测试"]
+ )
+
+ @staticmethod
+ def _strip_directive(req: str) -> str:
+ """
+ Remove the parameter-directive clauses from the text and keep the business description.
+ Approach: split on commas, semicolons, brackets and newlines, then drop every clause that matches a directive pattern.
+ """
+ # Split on the common separators
+ parts = re.split(r'[,,;;(())【】\n]', req)
+ clean_parts = []
+ for part in parts:
+ part = part.strip()
+ if not part:
+ continue
+ # Skip clauses that match a directive pattern
+ if _RE_PARAM_DIRECTIVE.search(part) or _RE_MIN_GROUPS.search(part):
+ continue
+ clean_parts.append(part)
+ result = ",".join(clean_parts).strip()
+ return result if result else req # fall back to the original text if every clause was dropped
+
+ @staticmethod
+ def _merge(base: ParamConstraint, other: ParamConstraint) -> ParamConstraint:
+ """Merge two constraints: larger min_groups, union of strategies in first-seen order."""
+ merged = ParamConstraint(
+ min_groups=max(base.min_groups, other.min_groups),
+ strategies=list(dict.fromkeys([*base.strategies, *other.strategies])),
+ has_param_directive=True,
+ )
+ return merged
\ No newline at end of file
diff --git a/ai_test_generator/core/test_runner.py b/ai_test_generator/core/test_runner.py
index 90a3e7f..0f91940 100644
--- a/ai_test_generator/core/test_runner.py
+++ b/ai_test_generator/core/test_runner.py
@@ -1,168 +1,278 @@
-import os
-import sys
+"""
+测试执行器
+══════════════════════════════════════════════════════════════
+大规模支持改造:
+ 1. 并行执行:ThreadPoolExecutor,默认 8 个并发
+ 2. 超时控制:单个用例超时不阻塞整体
+ 3. 步骤级结果解析:##STEP_RESULT## 行
+ 4. 进度显示:实时打印完成数
+ 5. 结果持久化:保存 run_results.json
+"""
+
+from __future__ import annotations
import json
-import subprocess
-import time
import logging
-from pathlib import Path
+import subprocess
+import sys
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
+from pathlib import Path
+
from config import config
logger = logging.getLogger(__name__)
+STEP_RESULT_PREFIX = "##STEP_RESULT##"
+
+
+# ══════════════════════════════════════════════════════════════
+# 数据结构
+# ══════════════════════════════════════════════════════════════
+
+@dataclass
+class AssertionResult: # typed form of one item of a step's "assertions" list (built in TestRunner._parse_step_results)
+ field: str
+ operator: str
+ expected: object
+ actual: object
+ passed: bool
+ message: str = ""
+
@dataclass
class StepResult:
step_no: int
interface_name: str
- status: str # PASS | FAIL | ERROR
- assertions: list[dict] = field(default_factory=list)
- # 每条 assertion: {"field","operator","expected","actual","passed","message"}
+ status: str # "PASS" | "FAIL"
+ assertions: list[AssertionResult] = field(default_factory=list) # typed records; the previous version stored raw dicts
@dataclass
class TestResult:
test_id: str
file_path: str
- status: str # PASS | FAIL | ERROR | TIMEOUT
- message: str = ""
- duration: float = 0.0
- stdout: str = ""
- stderr: str = ""
+ status: str # "PASS" | "FAIL" | "ERROR" | "TIMEOUT"
+ message: str = ""
+ duration: float = 0.0 # seconds, measured with time.monotonic() in TestRunner._run_one
step_results: list[StepResult] = field(default_factory=list)
+ stdout: str = "" # captured stdout of the test process (left empty on TIMEOUT / ERROR)
+ stderr: str = "" # captured stderr of the test process (left empty on TIMEOUT / ERROR)
+# ══════════════════════════════════════════════════════════════
+# 执行器
+# ══════════════════════════════════════════════════════════════
+
class TestRunner:
+ def __init__(self): # read runner settings from config, falling back to built-in defaults
+ self.timeout = getattr(config, "TEST_TIMEOUT", 60) # per-test subprocess timeout, in seconds
+ self.max_workers = getattr(config, "TEST_MAX_WORKERS", 8) # thread-pool size used by run_all
+ self.python_bin = getattr(config, "TEST_PYTHON_BIN", sys.executable) # interpreter that runs each test file
+
+ # ── 主入口 ────────────────────────────────────────────────
+
def run_all(self, test_files: list[Path]) -> list[TestResult]:
- results = []
- total = len(test_files)
- print(f"\n{'═'*62}")
- print(f" Running {total} test(s) …")
- print(f"{'═'*62}")
- for idx, fp in enumerate(test_files, 1):
- print(f"\n[{idx}/{total}] {fp.name}")
- result = self._run_one(fp)
- results.append(result)
- self._print_result(result)
+ """
+ Run every test file in a thread pool and return the results; an empty input returns [] after a warning.
+ """
+ if not test_files:
+ logger.warning("No test files to run.")
+ return []
+
+ total = len(test_files)
+ results: list[TestResult] = []
+ done = 0
+
+ logger.info(
+ f"Running {total} test(s) with "
+ f"max_workers={self.max_workers}, timeout={self.timeout}s"
+ )
+
+ with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+ future_map = {
+ executor.submit(self._run_one, f): f
+ for f in test_files
+ }
+ for future in as_completed(future_map): # yields futures in completion order, not submission order
+ result = future.result()
+ results.append(result)
+ done += 1
+ icon = "✅" if result.status == "PASS" else "❌"
+ logger.info(
+ f" [{done}/{total}] {icon} {result.test_id} "
+ f"({result.status}) {result.duration:.2f}s"
+ )
+
+ # Sort by file path so the returned order is stable regardless of completion order
+ results.sort(key=lambda r: r.file_path)
return results
- # ── 执行单个脚本 ──────────────────────────────────────────
+ # ── 执行单个测试文件 ──────────────────────────────────────
def _run_one(self, file_path: Path) -> TestResult:
- test_id = file_path.stem
- t0 = time.time()
+ test_id = file_path.stem
+ start = time.monotonic() # monotonic clock, so the duration is unaffected by wall-clock adjustments
+
try:
proc = subprocess.run(
- [sys.executable, str(file_path)],
- capture_output=True, text=True,
- timeout=config.TEST_TIMEOUT,
- env=self._env(),
+ [self.python_bin, str(file_path)],
+ capture_output=True,
+ text=True,
+ timeout=self.timeout, # NOTE(review): the previous version also passed env= with HTTP_BASE_URL taken from config; the child now only inherits the parent environment — confirm generated tests do not rely on it
)
- duration = time.time() - t0
- stdout = proc.stdout.strip()
- stderr = proc.stderr.strip()
- status, message = self._parse_output(stdout, proc.returncode)
- step_results = self._parse_step_results(stdout)
- return TestResult(
- test_id=test_id, file_path=str(file_path),
- status=status, message=message,
- duration=duration, stdout=stdout, stderr=stderr,
- step_results=step_results,
+ duration = time.monotonic() - start
+ return self._parse_output( # an exception raised while parsing is reported as ERROR by the handler below
+ test_id, str(file_path), proc, duration
)
+
except subprocess.TimeoutExpired:
+ duration = time.monotonic() - start
+ logger.warning(f" TIMEOUT: {test_id} ({duration:.1f}s)")
return TestResult(
- test_id=test_id, file_path=str(file_path),
- status="TIMEOUT", message=f"Exceeded {config.TEST_TIMEOUT}s",
- duration=time.time() - t0,
+ test_id=test_id,
+ file_path=str(file_path),
+ status="TIMEOUT",
+ message=f"Exceeded {self.timeout}s timeout",
+ duration=duration,
)
except Exception as e:
+ duration = time.monotonic() - start
+ logger.error(f" ERROR running {test_id}: {e}")
return TestResult(
- test_id=test_id, file_path=str(file_path),
- status="ERROR", message=str(e),
- duration=time.time() - t0,
+ test_id=test_id,
+ file_path=str(file_path),
+ status="ERROR",
+ message=str(e),
+ duration=duration,
)
# ── 解析输出 ──────────────────────────────────────────────
- def _parse_output(self, stdout: str, returncode: int) -> tuple[str, str]:
- if not stdout:
- return "FAIL", f"No output (exit={returncode})"
- last = stdout.strip().splitlines()[-1].strip()
- upper = last.upper()
- if upper.startswith("PASS"):
- return "PASS", last[5:].strip()
- if upper.startswith("FAIL"):
- return "FAIL", last[5:].strip()
- return ("PASS" if returncode == 0 else "FAIL"), last
+ def _parse_output( # turn a finished test process into a TestResult
+ self,
+ test_id: str,
+ file_path: str,
+ proc: subprocess.CompletedProcess,
+ duration: float,
+ ) -> TestResult:
+ stdout = proc.stdout or ""
+ stderr = proc.stderr or ""
+ step_results = self._parse_step_results(stdout)
+
+ # The last output line decides the overall status; it is stripped and compared case-insensitively, as the pre-rewrite parser did
+ last_line = stdout.strip().splitlines()[-1].strip() if stdout.strip() else ""
+ if last_line.upper().startswith("PASS"):
+ status = "PASS"
+ message = last_line
+ elif last_line.upper().startswith("FAIL"):
+ status = "FAIL"
+ message = last_line
+ elif proc.returncode != 0: # no verdict line: a non-zero exit code counts as a failure
+ status = "FAIL"
+ message = f"Exit code {proc.returncode}: {stderr[:200]}"
+ else: # exit code 0 but no verdict line
+ status = "ERROR"
+ message = "No PASS/FAIL line found in output"
+
+ return TestResult(
+ test_id=test_id,
+ file_path=file_path,
+ status=status,
+ message=message,
+ duration=duration,
+ step_results=step_results,
+ stdout=stdout,
+ stderr=stderr,
+ )
def _parse_step_results(self, stdout: str) -> list[StepResult]:
- """
- 解析脚本中以 ##STEP_RESULT## 开头的结构化输出行
- 格式:##STEP_RESULT##
- """
results = []
for line in stdout.splitlines():
line = line.strip()
- if line.startswith("##STEP_RESULT##"):
- try:
- data = json.loads(line[len("##STEP_RESULT##"):].strip())
- results.append(StepResult(
- step_no=data.get("step_no", 0),
- interface_name=data.get("interface_name", ""),
- status=data.get("status", ""),
- assertions=data.get("assertions", []),
- ))
- except Exception:
- pass
+ if not line.startswith(STEP_RESULT_PREFIX): # only prefixed lines carry a JSON step record
+ continue
+ try:
+ raw = json.loads(line[len(STEP_RESULT_PREFIX):].strip())
+ assertions = [
+ AssertionResult(
+ field=a.get("field", ""),
+ operator=a.get("operator", ""),
+ expected=a.get("expected"),
+ actual=a.get("actual"),
+ passed=bool(a.get("passed", False)),
+ message=a.get("message", ""),
+ )
+ for a in raw.get("assertions", [])
+ ]
+ results.append(StepResult(
+ step_no=raw.get("step_no", 0),
+ interface_name=raw.get("interface_name", ""),
+ status=raw.get("status", "FAIL"),
+ assertions=assertions,
+ ))
+ except (json.JSONDecodeError, KeyError, AttributeError, TypeError) as e: # AttributeError/TypeError: valid JSON that is not the expected object shape; skip the line instead of failing the whole test
+ logger.debug(f"Step result parse error: {e}")
return results
- def _env(self) -> dict:
- env = os.environ.copy()
- env["HTTP_BASE_URL"] = config.HTTP_BASE_URL
- return env
-
- # ── 打印 ──────────────────────────────────────────────────
-
- def _print_result(self, r: TestResult):
- icon = {"PASS": "✅", "FAIL": "❌", "TIMEOUT": "⏱️", "ERROR": "⚠️"}.get(r.status, "❓")
- print(f" {icon} [{r.status}] {r.test_id} ({r.duration:.2f}s)")
- print(f" {r.message}")
- for sr in r.step_results:
- s_icon = "✅" if sr.status == "PASS" else "❌"
- print(f" {s_icon} Step {sr.step_no}: {sr.interface_name}")
- for a in sr.assertions:
- a_icon = "✔" if a.get("passed") else "✘"
- print(f" {a_icon} {a.get('message','')} "
- f"(expected={a.get('expected')}, actual={a.get('actual')})")
- if r.stderr:
- print(f" stderr: {r.stderr[:300]}")
+ # ── 摘要打印 ──────────────────────────────────────────────
def print_summary(self, results: list[TestResult]):
- total = len(results)
- passed = sum(1 for r in results if r.status == "PASS")
- failed = sum(1 for r in results if r.status == "FAIL")
- errors = sum(1 for r in results if r.status in ("ERROR", "TIMEOUT"))
- print(f"\n{'═'*62}")
- print(f" TEST SUMMARY")
- print(f"{'═'*62}")
- print(f" Total : {total}")
+ total = len(results)
+ passed = sum(1 for r in results if r.status == "PASS")
+ failed = sum(1 for r in results if r.status == "FAIL")
+ errors = sum(1 for r in results if r.status in ("ERROR", "TIMEOUT")) # timeouts are counted together with errors
+ avg_dur = sum(r.duration for r in results) / total if total else 0 # mean duration; 0 when there are no results
+
+ print(f"\n{'─' * 56}")
+ print(f" Test Run Summary ({total} cases)")
+ print(f"{'─' * 56}")
print(f" ✅ PASS : {passed}")
print(f" ❌ FAIL : {failed}")
print(f" ⚠️ ERROR : {errors}")
- print(f" Pass Rate: {passed/total*100:.1f}%" if total else " Pass Rate: N/A")
- print(f"{'═'*62}\n")
+ print(f" ⏱ Avg : {avg_dur:.2f}s / case")
+ print(f"{'─' * 56}")
+
+ if failed or errors: # list every non-passing case with its message
+ print(" Failed / Error cases:")
+ for r in results:
+ if r.status not in ("PASS",):
+ print(f" [{r.status}] {r.test_id}: {r.message}")
+ print()
+
+ # ── 持久化 ────────────────────────────────────────────────
def save_results(self, results: list[TestResult], path: str):
- with open(path, "w", encoding="utf-8") as f:
- json.dump([{
- "test_id": r.test_id, "status": r.status,
- "message": r.message, "duration": r.duration,
- "stdout": r.stdout, "stderr": r.stderr,
+ data = [ # one JSON object per test; NOTE(review): the previous version also wrote stdout/stderr, which are no longer persisted — confirm nothing downstream reads them
+ {
+ "test_id": r.test_id,
+ "file_path": r.file_path,
+ "status": r.status,
+ "message": r.message,
+ "duration": round(r.duration, 3),
"step_results": [
- {"step_no": sr.step_no, "interface_name": sr.interface_name,
- "status": sr.status, "assertions": sr.assertions}
- for sr in r.step_results
+ {
+ "step_no": s.step_no,
+ "interface_name": s.interface_name,
+ "status": s.status,
+ "assertions": [ # AssertionResult objects are expanded into plain dicts for JSON output
+ {
+ "field": a.field,
+ "operator": a.operator,
+ "expected": a.expected,
+ "actual": a.actual,
+ "passed": a.passed,
+ "message": a.message,
+ }
+ for a in s.assertions
+ ],
+ }
+ for s in r.step_results
],
- } for r in results], f, ensure_ascii=False, indent=2)
+ }
+ for r in results
+ ]
+ with open(path, "w", encoding="utf-8") as f:
+ json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"Run results → {path}")
\ No newline at end of file
diff --git a/ai_test_generator/main.py b/ai_test_generator/main.py
index 5714c79..7ff44e7 100644
--- a/ai_test_generator/main.py
+++ b/ai_test_generator/main.py
@@ -3,11 +3,18 @@
AI-Powered API Test Generator, Runner & Coverage Analyzer
──────────────────────────────────────────────────────────
Usage:
+ # 基础用法
python main.py --api-desc examples/api_desc.json \\
--requirements "创建用户,删除用户"
+ # 在需求中自然语言描述参数生成要求
python main.py --api-desc examples/api_desc.json \\
- --req-file examples/requirements.txt
+ --requirements "创建用户(测试参数不少于10组,覆盖边界值和异常值),删除用户"
+
+ # 从文件读取需求(每行一条,支持参数生成指令)
+ python main.py --api-desc examples/api_desc.json \\
+ --req-file examples/requirements.txt \\
+ --batch-size 10
"""
import argparse
@@ -17,6 +24,7 @@ from pathlib import Path
from config import config
from core.parser import InterfaceParser, ApiDescriptor
+from core.requirement_parser import RequirementParser, ParamConstraint
from core.prompt_builder import PromptBuilder
from core.llm_client import LLMClient
from core.test_generator import TestGenerator
@@ -24,48 +32,57 @@ from core.test_runner import TestRunner
from core.analyzer import CoverageAnalyzer, CoverageReport
from core.report_generator import ReportGenerator
-logging.basicConfig(
- level=logging.INFO,
- format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
- handlers=[logging.StreamHandler(sys.stdout)],
-)
+
+# ══════════════════════════════════════════════════════════════
+# 日志初始化
+# ══════════════════════════════════════════════════════════════
+
+def setup_logging(debug: bool = False):
+ """
+ Configure log levels: root at DEBUG when debug is set, INFO otherwise.
+
+ ⚠️ Bug workaround (openai-python v2.24.0 + pydantic-core):
+ third-party SDK loggers are pinned to WARNING to avoid triggering
+ model_dump(by_alias=None) → TypeError in the pydantic-core Rust layer.
+ Ref: https://github.com/openai/openai-python/issues/2921
+ """
+ root_level = logging.DEBUG if debug else logging.INFO
+ logging.basicConfig(
+ level=root_level,
+ format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+ for name in ["openai", "openai._base_client", "anthropic", "httpx", "httpcore"]: # these stay at WARNING even with debug=True
+ logging.getLogger(name).setLevel(logging.WARNING)
+
+
logger = logging.getLogger(__name__)
-# ── CLI ───────────────────────────────────────────────────────
+# ══════════════════════════════════════════════════════════════
+# CLI
+# ══════════════════════════════════════════════════════════════
def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
- description="AI-Powered API Test Generator & Coverage Analyzer"
- )
- p.add_argument(
- "--api-desc", required=True,
- help="接口描述 JSON 文件(含 project / description / units 字段)",
- )
- p.add_argument(
- "--requirements", default="",
- help="测试需求(逗号分隔)",
- )
- p.add_argument(
- "--req-file", default="",
- help="测试需求文件(每行一条)",
- )
- p.add_argument(
- "--skip-run", action="store_true",
- help="只生成测试文件,不执行",
- )
- p.add_argument(
- "--skip-analyze", action="store_true",
- help="跳过覆盖率分析",
- )
- p.add_argument(
- "--output-dir", default="",
- help="测试文件根输出目录(默认 config.GENERATED_TESTS_DIR)",
- )
- p.add_argument(
- "--debug", action="store_true",
- help="开启 DEBUG 日志",
+ description="AI-Powered API Test Generator & Coverage Analyzer",
+ formatter_class=argparse.RawDescriptionHelpFormatter, # keeps the line breaks of the epilog below
+ epilog="""
+参数生成指令示例(可直接写在需求文本中):
+ "创建用户,测试参数不少于10组,覆盖边界值和异常值"
+ "查询接口(至少8组,边界值)"
+ "所有接口参数不少于5组,覆盖等价类"
+ """,
)
+ p.add_argument("--api-desc", required=True, help="接口描述 JSON 文件")
+ p.add_argument("--requirements", default="", help="测试需求(逗号分隔,支持参数生成指令)")
+ p.add_argument("--req-file", default="", help="测试需求文件(每行一条)")
+ p.add_argument("--batch-size", type=int, default=0, help="每批接口数量(0=使用 config.LLM_BATCH_SIZE)") # 0 is a sentinel: main() overrides config only when the value is > 0
+ p.add_argument("--workers", type=int, default=0, help="并行执行线程数(0=使用 config.TEST_MAX_WORKERS)") # 0 is a sentinel: main() overrides config only when the value is > 0
+ p.add_argument("--skip-run", action="store_true", help="只生成,不执行")
+ p.add_argument("--skip-analyze", action="store_true", help="跳过覆盖率分析")
+ p.add_argument("--output-dir", default="", help="测试文件根输出目录")
+ p.add_argument("--debug", action="store_true", help="开启 DEBUG 日志(第三方 SDK 仍保持 WARNING)")
return p
@@ -77,82 +94,111 @@ def load_requirements(args) -> list[str]:
with open(args.req_file, "r", encoding="utf-8") as f:
reqs += [line.strip() for line in f if line.strip()]
if not reqs:
- logger.error(
- "No requirements provided. Use --requirements or --req-file."
- )
+ logger.error("未提供任何测试需求,请使用 --requirements 或 --req-file 指定。")
sys.exit(1)
return reqs
-# ── Main ──────────────────────────────────────────────────────
+# ══════════════════════════════════════════════════════════════
+# Main
+# ══════════════════════════════════════════════════════════════
def main():
args = build_arg_parser().parse_args()
+ setup_logging(debug=args.debug)
- if args.debug:
- logging.getLogger().setLevel(logging.DEBUG)
- if args.output_dir:
- config.GENERATED_TESTS_DIR = args.output_dir
+ if args.output_dir: config.GENERATED_TESTS_DIR = args.output_dir
+ if args.batch_size > 0: config.LLM_BATCH_SIZE = args.batch_size
+ if args.workers > 0: config.TEST_MAX_WORKERS = args.workers
# ── Step 1: 解析接口描述 ──────────────────────────────────
- logger.info("▶ Step 1: Parsing interface description …")
- parser: InterfaceParser = InterfaceParser()
- descriptor: ApiDescriptor = parser.parse_file(args.api_desc)
+ logger.info("▶ Step 1: 解析接口描述 …")
+ parser = InterfaceParser()
+ descriptor = parser.parse_file(args.api_desc)
project = descriptor.project
project_desc = descriptor.description
interfaces = descriptor.interfaces
- logger.info(f" Project : {project}")
- logger.info(f" Description: {project_desc}")
- logger.info(
- f" Interfaces : {len(interfaces)} — "
- f"{[i.name for i in interfaces]}"
- )
- # 打印每个接口的 url 解析结果
+ logger.info(f" 项目名称:{project}")
+ logger.info(f" 项目描述:{project_desc}")
+ logger.info(f" 接口数量:{len(interfaces)}")
for iface in interfaces:
if iface.protocol == "function":
logger.info(
- f" [{iface.name}] source={iface.source_file} "
- f"module={iface.module_path}"
+ f" [function] {iface.name} "
+ f"← {iface.source_file} (module: {iface.module_path})"
)
else:
- logger.info(
- f" [{iface.name}] full_url={iface.http_full_url}"
- )
+ logger.info(f" [http] {iface.name} ← {iface.http_full_url}")
- # ── Step 2: 加载测试需求 ──────────────────────────────────
- logger.info("▶ Step 2: Loading test requirements …")
- requirements = load_requirements(args)
- for i, req in enumerate(requirements, 1):
- logger.info(f" {i}. {req}")
+ # ── Step 2: 加载并解析测试需求(含参数生成指令)──────────
+ logger.info("▶ Step 2: 加载并解析测试需求 …")
+ raw_requirements = load_requirements(args)
- # ── Step 3: 调用 LLM 生成测试用例 ────────────────────────
- logger.info("▶ Step 3: Calling LLM …")
- builder = PromptBuilder()
- test_cases = LLMClient().generate_test_cases(
- builder.get_system_prompt(),
- builder.build_user_prompt(
- requirements, interfaces, parser,
- project=project,
- project_desc=project_desc,
- ),
+ req_parser = RequirementParser()
+ clean_reqs, per_req_constraints, global_constraint = req_parser.parse(
+ raw_requirements
)
- logger.info(f" LLM returned {len(test_cases)} test case(s)")
+
+ # 打印原始需求 & 解析结果
+ for i, (raw, clean, c) in enumerate(
+ zip(raw_requirements, clean_reqs, per_req_constraints), 1
+ ):
+ if c.has_param_directive:
+ logger.info(
+ f" {i}. {raw}\n"
+ f" └─ 参数约束:{c} | 业务需求:{clean}"
+ )
+ else:
+ logger.info(f" {i}. {raw}")
+
+ if global_constraint.has_param_directive:
+ logger.info(f" 全局参数约束:{global_constraint}")
+
+ # 使用"清洗后的需求"传给 LLM(去掉参数指令,保留纯业务描述)
+ # 同时保留原始需求用于覆盖率分析(保证与用户输入一致)
+ effective_requirements = clean_reqs
+
+ # ── Step 3: 构建 Prompt 并调用 LLM 生成测试用例 ──────────
+ logger.info(
+ f"▶ Step 3: 调用 LLM 生成测试用例 "
+ f"(batch_size={getattr(config, 'LLM_BATCH_SIZE', 10)}) …"
+ )
+ builder = PromptBuilder()
+ iface_summary = parser.to_summary_dict(interfaces)
+ project_header = builder.build_project_header(project, project_desc)
+
+ # System Prompt:注入全局参数约束规则
+ system_prompt = builder.get_system_prompt(
+ global_constraint=global_constraint
+ )
+
+ test_cases = LLMClient().generate_test_cases(
+ system_prompt=system_prompt,
+ user_prompt="",
+ iface_summaries=iface_summary,
+ requirements=effective_requirements,
+ project_header=project_header,
+ param_constraint=global_constraint, # ← 注入参数约束
+ )
+ logger.info(f" 共生成测试用例:{len(test_cases)} 个")
# ── Step 4: 生成测试文件 ──────────────────────────────────
- logger.info(
- f"▶ Step 4: Generating test files (project='{project}') …"
- )
+ logger.info("▶ Step 4: 生成测试文件 …")
generator = TestGenerator(project=project, project_desc=project_desc)
test_files = generator.generate(test_cases)
out = generator.output_dir
+ logger.info(f" 输出目录:{out.resolve()}")
run_results = []
if not args.skip_run:
- # ── Step 5: 执行测试 ──────────────────────────────────
- logger.info("▶ Step 5: Running tests …")
+ # ── Step 5: 并行执行测试 ──────────────────────────────
+ logger.info(
+ f"▶ Step 5: 执行测试 "
+ f"(workers={getattr(config, 'TEST_MAX_WORKERS', 8)}) …"
+ )
runner = TestRunner()
run_results = runner.run_all(test_files)
runner.print_summary(run_results)
@@ -160,31 +206,36 @@ def main():
if not args.skip_analyze:
# ── Step 6: 覆盖率分析 ────────────────────────────────
- logger.info("▶ Step 6: Analyzing coverage …")
+ logger.info("▶ Step 6: 覆盖率分析 …")
report = CoverageAnalyzer(
interfaces=interfaces,
- requirements=requirements,
+ requirements=raw_requirements, # 用原始需求做覆盖率分析
test_cases=test_cases,
run_results=run_results,
).analyze()
# ── Step 7: 生成报告 ──────────────────────────────────
- logger.info("▶ Step 7: Generating reports …")
+ logger.info("▶ Step 7: 生成报告 …")
rg = ReportGenerator()
rg.save_json(report, str(out / "coverage_report.json"))
rg.save_html(report, str(out / "coverage_report.html"))
- _print_terminal_summary(report, out, project)
+ _print_terminal_summary(report, out, project, global_constraint)
- logger.info(f"\n✅ Done. Output: {out.resolve()}")
+ logger.info(f"\n✅ 完成。输出目录:{out.resolve()}")
-# ── 终端摘要 ──────────────────────────────────────────────────
+# ══════════════════════════════════════════════════════════════
+# 终端摘要
+# ══════════════════════════════════════════════════════════════
def _print_terminal_summary(
- report: CoverageReport, out: Path, project: str
+ report: "CoverageReport",
+ out: Path,
+ project: str,
+ global_constraint: ParamConstraint,
):
- W = 66
+ W = 68
def bar(rate: float, w: int = 20) -> str:
filled = int(rate * w)
@@ -193,35 +244,41 @@ def _print_terminal_summary(
return f"{'█' * filled}{'░' * empty} {rate * 100:.1f}% {icon}"
print(f"\n{'═' * W}")
- print(f" PROJECT : {project}")
- print(f" COVERAGE SUMMARY")
+ print(f" 项目:{project}")
+ print(f" 覆盖率摘要")
print(f"{'═' * W}")
- print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}")
- print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}")
- print(f" 入参覆盖率 {bar(report.avg_in_param_coverage_rate)}")
- print(f" 成功返回字段覆盖 {bar(report.avg_success_field_coverage_rate)}")
- print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}")
- print(f" 用例通过率 {bar(report.pass_rate)}")
+ print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}")
+ print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}")
+ print(f" 入参覆盖率 {bar(report.avg_in_param_coverage_rate)}")
+ print(f" 成功返回字段覆盖 {bar(report.avg_success_field_coverage_rate)}")
+ print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}")
+ print(f" 用例通过率 {bar(report.pass_rate)}")
print(f"{'─' * W}")
- print(f" Total Gaps : {len(report.gaps)}")
- print(f" 🔴 Critical: {report.critical_gap_count}")
- print(f" 🟠 High : {report.high_gap_count}")
+ print(f" 测试用例总数 {report.total_test_cases}")
+ print(f" 覆盖缺口总数 {len(report.gaps)}")
+ print(f" 🔴 严重缺口 {report.critical_gap_count}")
+ print(f" 🟠 高优先级缺口 {report.high_gap_count}")
+
+ # 参数约束达成情况
+ if global_constraint.has_param_directive:
+ print(f"{'─' * W}")
+ print(f" 参数生成约束 {global_constraint}")
+ actual = report.total_test_cases
+ needed = global_constraint.min_groups
+ status = "✅ 已满足" if actual >= needed else f"❌ 未满足(需 {needed} 组,实际 {actual} 组)"
+ print(f" 参数组数要求 {status}")
if report.gaps:
print(f"{'─' * W}")
- print(" Top Gaps (up to 8):")
- icons = {
- "critical": "🔴", "high": "🟠",
- "medium": "🟡", "low": "🔵",
- }
+ print(" Top 缺口(最多显示8条):")
+ icons = {"critical": "🔴", "high": "🟠", "medium": "🟡", "low": "🔵"}
for g in report.gaps[:8]:
- icon = icons.get(g.severity, "⚪")
- print(f" {icon} [{g.gap_type}] {g.target}")
+ print(f" {icons.get(g.severity, '⚪')} [{g.gap_type}] {g.target}")
print(f" → {g.suggestion}")
print(f"{'─' * W}")
- print(f" Output : {out.resolve()}")
- print(f" • coverage_report.html ← open in browser")
+ print(f" 输出目录:{out.resolve()}")
+ print(f" • coverage_report.html")
print(f" • coverage_report.json")
print(f" • run_results.json")
print(f" • test_cases_summary.json")
diff --git a/ai_test_generator/run.sh b/ai_test_generator/run.sh
index 6da9f8c..c4d8f5e 100755
--- a/ai_test_generator/run.sh
+++ b/ai_test_generator/run.sh
@@ -16,4 +16,4 @@ export HTTP_BASE_URL="http://localhost:8080"
# 只生成不执行
python main.py --api-desc examples/api_desc.json \
- --requirements "1.对每个接口进行测试,支持特殊值、边界值测试"
+ --requirements "每个测试用例的参数不少于10组"