source optimization

This commit is contained in:
sontolau 2026-03-09 11:48:19 +08:00
parent c247a8f1dc
commit ce898d81ee
7 changed files with 1457 additions and 375 deletions

View File

@ -1,58 +1,204 @@
"""
LLM 客户端
变更
- generate_test_cases() 新增 param_constraint 参数
- 分批调用时将约束透传给 PromptBuilder.build_batch_prompt()
"""
from __future__ import annotations
import json import json
import re
import logging import logging
import re
import time
from typing import Any
from openai import OpenAI from openai import OpenAI
from config import config from config import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ══════════════════════════════════════════════════════════════
# JSON 健壮解析工具
# ══════════════════════════════════════════════════════════════
class RobustJSONParser:
def parse(self, text: str) -> list[dict]:
text = text.strip()
text = re.sub(r'^```(?:json)?\s*', '', text, flags=re.MULTILINE)
text = re.sub(r'\s*```$', '', text, flags=re.MULTILINE)
text = text.strip()
start = text.find('[')
if start == -1:
logger.warning("LLM 响应中未找到 JSON 数组。")
return []
text = text[start:]
try:
result = json.loads(text)
if isinstance(result, list):
return result
except json.JSONDecodeError:
pass
text = self._fix_truncated(text)
try:
result = json.loads(text)
if isinstance(result, list):
logger.warning("JSON 截断修复成功,部分用例可能丢失。")
return result
except json.JSONDecodeError as e:
logger.error(f"JSON 解析失败(修复后仍无效):{e}")
return []
def _fix_truncated(self, text: str) -> str:
last = text.rfind('},')
if last == -1:
last = text.rfind('}')
if last == -1:
return text
return text[:last + 1] + ']'
# ══════════════════════════════════════════════════════════════
# LLM 客户端
# ══════════════════════════════════════════════════════════════
class LLMClient: class LLMClient:
"""封装 LLM API 调用OpenAI 兼容接口)"""
DEFAULT_BATCH_SIZE = 10
def __init__(self): def __init__(self):
self.client = OpenAI( self.client = OpenAI(
api_key=config.LLM_API_KEY, api_key=config.LLM_API_KEY,
base_url=config.LLM_BASE_URL, base_url=getattr(config, "LLM_BASE_URL", None),
)
self.model = config.LLM_MODEL
self.max_tokens = getattr(config, "LLM_MAX_TOKENS", 8192)
self.batch_size = getattr(config, "LLM_BATCH_SIZE", self.DEFAULT_BATCH_SIZE)
self.retry = getattr(config, "LLM_RETRY", 3)
self.retry_delay = getattr(config, "LLM_RETRY_DELAY", 5)
self.batch_interval = getattr(config, "LLM_BATCH_INTERVAL", 1)
self._parser = RobustJSONParser()
from core.prompt_builder import PromptBuilder
self._prompt_builder = PromptBuilder()
# ── 主入口 ────────────────────────────────────────────────
def generate_test_cases(
self,
system_prompt: str,
user_prompt: str,
iface_summaries: list[dict] | None = None,
requirements: list[str] | None = None,
project_header: str = "",
param_constraint: "ParamConstraint | None" = None, # ← 新增
) -> list[dict]:
"""
生成测试用例
方式 A单次 user_prompt
方式 B分批 iface_summaries + requirements推荐大规模场景
param_constraint RequirementParser 解析出的全局参数约束
有值时自动注入参数数据集到每批 prompt
"""
if iface_summaries is not None and requirements is not None:
return self._generate_batched(
system_prompt, iface_summaries, requirements,
project_header, param_constraint,
)
return self._call_with_retry(system_prompt, user_prompt)
# ── 分批调用 ──────────────────────────────────────────────
def _generate_batched(
self,
system_prompt: str,
iface_summaries: list[dict],
requirements: list[str],
project_header: str,
param_constraint: Any,
) -> list[dict]:
total = len(iface_summaries)
batches = self._make_batches(iface_summaries, self.batch_size)
all_cases: list[dict] = []
logger.info(
f"分批模式:共 {total} 个接口 → "
f"{len(batches)}× 每批最多 {self.batch_size}"
)
if param_constraint and param_constraint.has_param_directive:
logger.info(
f"参数约束已启用:{param_constraint}"
) )
def generate_test_cases(self, system_prompt: str, user_prompt: str) -> list[dict]: for idx, batch in enumerate(batches, 1):
logger.info(f"Calling LLM: model={config.LLM_MODEL}") names = [i.get("name", "?") for i in batch]
logger.debug(f"--- USER PROMPT ---\n{user_prompt}\n---") logger.info(f"{idx}/{len(batches)} 批:{names}")
user_prompt = self._prompt_builder.build_batch_prompt(
batch=batch,
requirements=requirements,
project_header=project_header,
param_constraint=param_constraint, # ← 透传
)
cases = self._call_with_retry(system_prompt, user_prompt)
logger.info(f"{idx} 批 → 生成 {len(cases)} 个测试用例")
all_cases.extend(cases)
if idx < len(batches):
time.sleep(self.batch_interval)
logger.info(f"全部批次完成,共生成 {len(all_cases)} 个测试用例")
return all_cases
# ── 单次调用(含重试)────────────────────────────────────
def _call_with_retry(
self,
system_prompt: str,
user_prompt: str,
) -> list[dict]:
last_error: Exception | None = None
for attempt in range(1, self.retry + 1):
try:
logger.debug(f"LLM 调用第 {attempt}/{self.retry} 次 …")
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=config.LLM_MODEL, model=self.model,
temperature=config.LLM_TEMPERATURE, max_tokens=self.max_tokens,
max_tokens=config.LLM_MAX_TOKENS,
messages=[ messages=[
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}, {"role": "user", "content": user_prompt},
], ],
) )
raw = response.choices[0].message.content raw = response.choices[0].message.content or ""
logger.debug(f"--- LLM RAW RESPONSE ---\n{raw[:800]}\n---") logger.debug(f"LLM 响应长度:{len(raw)} 字符")
return self._parse_json(raw)
# ── 解析 ────────────────────────────────────────────────── cases = self._parser.parse(raw)
if cases:
return cases
def _parse_json(self, content: str) -> list[dict]: logger.warning(f"{attempt} 次调用返回空结果,准备重试 …")
content = content.strip()
# 去除可能的 markdown 代码块
content = re.sub(r"^```(?:json)?\s*", "", content, flags=re.MULTILINE)
content = re.sub(r"\s*```\s*$", "", content, flags=re.MULTILINE)
content = content.strip()
try: except Exception as e:
data = json.loads(content) last_error = e
except json.JSONDecodeError: logger.warning(f"{attempt} 次调用失败:{e}")
# 尝试提取第一个 JSON 数组
match = re.search(r"\[.*\]", content, re.DOTALL)
if match:
data = json.loads(match.group())
else:
logger.error(f"Cannot parse LLM response as JSON:\n{content[:600]}")
raise
if not isinstance(data, list): if attempt < self.retry:
raise ValueError(f"Expected JSON array, got {type(data)}") time.sleep(self.retry_delay)
return data
logger.error(
f"已重试 {self.retry} 次,全部失败。最后一次错误:{last_error}"
)
return []
@staticmethod
def _make_batches(items: list[Any], size: int) -> list[list[Any]]:
return [items[i:i + size] for i in range(0, len(items), size)]

View File

@ -0,0 +1,393 @@
"""
测试参数生成器
职责
根据接口参数的数据类型和参数生成策略
自动生成覆盖正常值边界值异常值的测试数据集
支持的数据类型
string / str
integer / int / number
float / double / decimal
boolean / bool
array / list
object / dict / map
输出格式供注入 prompt
{
"param_name": {
"type": "string",
"groups": [
{"label": "正常值", "value": "hello", "category": "normal"},
{"label": "空字符串", "value": "", "category": "boundary"},
{"label": "超长字符串","value": "a"*256, "category": "exception"},
...
]
}
}
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from core.requirement_parser import ParamStrategy
# ══════════════════════════════════════════════════════════════
# 数据结构
# ══════════════════════════════════════════════════════════════
@dataclass
class ParamGroup:
"""单组测试参数。"""
label: str # 人类可读的描述,如"空字符串"
value: Any # 实际参数值
category: str # "normal" | "boundary" | "exception" | "equivalence"
def to_dict(self) -> dict:
return {
"label": self.label,
"value": self.value,
"category": self.category,
}
@dataclass
class ParamDataSet:
"""单个参数的完整测试数据集。"""
param_name: str
param_type: str
groups: list[ParamGroup] = field(default_factory=list)
def to_dict(self) -> dict:
return {
"type": self.param_type,
"groups": [g.to_dict() for g in self.groups],
}
@property
def normal_values(self) -> list[Any]:
return [g.value for g in self.groups if g.category == "normal"]
@property
def boundary_values(self) -> list[Any]:
return [g.value for g in self.groups if g.category == "boundary"]
@property
def exception_values(self) -> list[Any]:
return [g.value for g in self.groups if g.category == "exception"]
# ══════════════════════════════════════════════════════════════
# 各类型测试数据库
# ══════════════════════════════════════════════════════════════
# ── string ────────────────────────────────────────────────────
_STRING_NORMAL: list[ParamGroup] = [
ParamGroup("普通字符串", "hello", "normal"),
ParamGroup("中文字符串", "测试数据", "normal"),
ParamGroup("字母数字混合", "user123", "normal"),
]
_STRING_BOUNDARY: list[ParamGroup] = [
ParamGroup("空字符串", "", "boundary"),
ParamGroup("单字符", "a", "boundary"),
ParamGroup("最大长度(255)", "a" * 255, "boundary"),
ParamGroup("恰好256字符", "a" * 256, "boundary"),
ParamGroup("含空格", "hello world", "boundary"),
ParamGroup("首尾空格", " hello ", "boundary"),
]
_STRING_EXCEPTION: list[ParamGroup] = [
ParamGroup("None/null", None, "exception"),
ParamGroup("超长字符串", "a" * 1000, "exception"),
ParamGroup("纯空格", " ", "exception"),
ParamGroup("特殊字符", "!@#$%^&*()", "exception"),
ParamGroup("换行符", "line1\nline2", "exception"),
ParamGroup("SQL注入", "' OR '1'='1", "exception"),
ParamGroup("XSS注入", "<script>alert(1)</script>", "exception"),
ParamGroup("Unicode特殊字符","你好\u0000世界", "exception"),
ParamGroup("数字类型误传", 12345, "exception"),
]
# ── integer ───────────────────────────────────────────────────
_INTEGER_NORMAL: list[ParamGroup] = [
ParamGroup("典型正整数", 1, "normal"),
ParamGroup("较大正整数", 100, "normal"),
]
_INTEGER_BOUNDARY: list[ParamGroup] = [
ParamGroup("", 0, "boundary"),
ParamGroup("正边界值1", 1, "boundary"),
ParamGroup("负边界值-1", -1, "boundary"),
ParamGroup("最大值(int32)", 2_147_483_647, "boundary"),
ParamGroup("最小值(int32)", -2_147_483_648, "boundary"),
ParamGroup("最大值+1", 2_147_483_648, "boundary"),
]
_INTEGER_EXCEPTION: list[ParamGroup] = [
ParamGroup("None/null", None, "exception"),
ParamGroup("字符串类型", "abc", "exception"),
ParamGroup("浮点数", 3.14, "exception"),
ParamGroup("布尔值true", True, "exception"),
ParamGroup("空字符串", "", "exception"),
ParamGroup("超大整数", 10 ** 20, "exception"),
]
# ── float ─────────────────────────────────────────────────────
_FLOAT_NORMAL: list[ParamGroup] = [
ParamGroup("典型浮点数", 3.14, "normal"),
ParamGroup("整数形式", 1.0, "normal"),
]
_FLOAT_BOUNDARY: list[ParamGroup] = [
ParamGroup("", 0.0, "boundary"),
ParamGroup("极小正数", 1e-10, "boundary"),
ParamGroup("极大正数", 1e10, "boundary"),
ParamGroup("负数", -3.14, "boundary"),
ParamGroup("负无穷", float('-inf'), "boundary"),
ParamGroup("正无穷", float('inf'), "boundary"),
]
_FLOAT_EXCEPTION: list[ParamGroup] = [
ParamGroup("None/null", None, "exception"),
ParamGroup("NaN", float('nan'), "exception"),
ParamGroup("字符串类型", "abc", "exception"),
ParamGroup("空字符串", "", "exception"),
]
# ── boolean ───────────────────────────────────────────────────
_BOOLEAN_NORMAL: list[ParamGroup] = [
ParamGroup("true", True, "normal"),
ParamGroup("false", False, "normal"),
]
_BOOLEAN_BOUNDARY: list[ParamGroup] = [
ParamGroup("整数1", 1, "boundary"),
ParamGroup("整数0", 0, "boundary"),
]
_BOOLEAN_EXCEPTION: list[ParamGroup] = [
ParamGroup("None/null", None, "exception"),
ParamGroup("字符串'true'", "true", "exception"),
ParamGroup("字符串'false'", "false", "exception"),
ParamGroup("字符串'yes'", "yes", "exception"),
ParamGroup("随机字符串", "abc", "exception"),
]
# ── array ─────────────────────────────────────────────────────
_ARRAY_NORMAL: list[ParamGroup] = [
ParamGroup("普通数组", [1, 2, 3], "normal"),
ParamGroup("字符串数组", ["a", "b", "c"], "normal"),
]
_ARRAY_BOUNDARY: list[ParamGroup] = [
ParamGroup("空数组", [], "boundary"),
ParamGroup("单元素数组", [1], "boundary"),
ParamGroup("大数组(100元素)", list(range(100)), "boundary"),
ParamGroup("含None元素", [1, None, 3], "boundary"),
]
_ARRAY_EXCEPTION: list[ParamGroup] = [
ParamGroup("None/null", None, "exception"),
ParamGroup("字符串类型", "not_an_array", "exception"),
ParamGroup("整数类型", 123, "exception"),
ParamGroup("嵌套过深", [[[[[1]]]]], "exception"),
]
# ── object ────────────────────────────────────────────────────
_OBJECT_NORMAL: list[ParamGroup] = [
ParamGroup("完整字段对象", {"id": 1, "name": "test"}, "normal"),
]
_OBJECT_BOUNDARY: list[ParamGroup] = [
ParamGroup("空对象", {}, "boundary"),
ParamGroup("仅必填字段", {"id": 1}, "boundary"),
ParamGroup("含额外字段", {"id": 1, "extra": "x"}, "boundary"),
]
_OBJECT_EXCEPTION: list[ParamGroup] = [
ParamGroup("None/null", None, "exception"),
ParamGroup("字符串类型", "not_an_object", "exception"),
ParamGroup("数组类型", [1, 2, 3], "exception"),
ParamGroup("字段值类型错误", {"id": "not_int"}, "exception"),
]
# 类型 → 数据映射表
_TYPE_DATA: dict[str, dict[str, list[ParamGroup]]] = {
"string": {"normal": _STRING_NORMAL, "boundary": _STRING_BOUNDARY, "exception": _STRING_EXCEPTION},
"integer": {"normal": _INTEGER_NORMAL, "boundary": _INTEGER_BOUNDARY, "exception": _INTEGER_EXCEPTION},
"float": {"normal": _FLOAT_NORMAL, "boundary": _FLOAT_BOUNDARY, "exception": _FLOAT_EXCEPTION},
"boolean": {"normal": _BOOLEAN_NORMAL, "boundary": _BOOLEAN_BOUNDARY, "exception": _BOOLEAN_EXCEPTION},
"array": {"normal": _ARRAY_NORMAL, "boundary": _ARRAY_BOUNDARY, "exception": _ARRAY_EXCEPTION},
"object": {"normal": _OBJECT_NORMAL, "boundary": _OBJECT_BOUNDARY, "exception": _OBJECT_EXCEPTION},
}
# 类型别名归一化
_TYPE_ALIAS: dict[str, str] = {
"str": "string", "text": "string", "varchar": "string",
"int": "integer", "long": "integer", "number": "integer",
"float": "float", "double": "float", "decimal": "float",
"bool": "boolean",
"list": "array",
"dict": "object", "map": "object", "json": "object",
}
# ══════════════════════════════════════════════════════════════
# 参数生成器
# ══════════════════════════════════════════════════════════════
class ParamGenerator:
"""
根据参数类型和策略生成测试数据集
"""
def generate(
self,
param_name: str,
param_type: str,
strategies: list[ParamStrategy],
min_groups: int = 2,
) -> ParamDataSet:
"""
生成单个参数的测试数据集
Args:
param_name : 参数名
param_type : 数据类型支持别名
strategies : 需要覆盖的策略列表
min_groups : 最少生成的参数组数
Returns:
ParamDataSet
"""
norm_type = self._normalize_type(param_type)
type_data = _TYPE_DATA.get(norm_type, _TYPE_DATA["string"])
groups: list[ParamGroup] = []
# 始终包含正常值
groups.extend(type_data["normal"])
# 按策略追加
for strategy in strategies:
if strategy == ParamStrategy.BOUNDARY:
groups.extend(type_data["boundary"])
elif strategy == ParamStrategy.EXCEPTION:
groups.extend(type_data["exception"])
elif strategy == ParamStrategy.EQUIVALENCE:
# 等价类:正常值 + 部分边界值
groups.extend(type_data["boundary"][:3])
elif strategy == ParamStrategy.RANDOM:
groups.extend(self._random_groups(norm_type))
# 去重(按 label
seen_labels: set[str] = set()
unique_groups: list[ParamGroup] = []
for g in groups:
if g.label not in seen_labels:
seen_labels.add(g.label)
unique_groups.append(g)
# 补齐到 min_groups
if len(unique_groups) < min_groups:
extra = type_data["boundary"] + type_data["exception"]
for g in extra:
if g.label not in seen_labels and len(unique_groups) < min_groups:
seen_labels.add(g.label)
unique_groups.append(g)
return ParamDataSet(
param_name=param_name,
param_type=norm_type,
groups=unique_groups,
)
def generate_for_interface(
self,
interface_summary: dict,
strategies: list[ParamStrategy],
min_groups: int = 2,
) -> dict[str, ParamDataSet]:
"""
为一个接口的所有入参生成测试数据集
Args:
interface_summary : parser.to_summary_dict() 中的单个接口 dict
strategies : 策略列表
min_groups : 最少组数
Returns:
{ param_name: ParamDataSet }
"""
result: dict[str, ParamDataSet] = {}
params = interface_summary.get("params", {})
for param_name, param_info in params.items():
if isinstance(param_info, dict):
param_type = param_info.get("type", "string")
inout = param_info.get("inout", "in")
else:
param_type = str(param_info)
inout = "in"
# 只对入参in / inout生成测试数据
if inout not in ("in", "inout"):
continue
result[param_name] = self.generate(
param_name=param_name,
param_type=param_type,
strategies=strategies,
min_groups=min_groups,
)
return result
# ── 工具 ──────────────────────────────────────────────────
@staticmethod
def _normalize_type(t: str) -> str:
t = t.lower().strip()
return _TYPE_ALIAS.get(t, t if t in _TYPE_DATA else "string")
@staticmethod
def _random_groups(norm_type: str) -> list[ParamGroup]:
"""生成少量随机值(补充策略)。"""
import random, string
if norm_type == "string":
val = ''.join(random.choices(string.ascii_letters, k=8))
return [ParamGroup(f"随机字符串({val})", val, "equivalence")]
if norm_type == "integer":
val = random.randint(-1000, 1000)
return [ParamGroup(f"随机整数({val})", val, "equivalence")]
if norm_type == "float":
val = round(random.uniform(-100, 100), 4)
return [ParamGroup(f"随机浮点数({val})", val, "equivalence")]
return []
def to_prompt_text(
self,
datasets: dict[str, ParamDataSet],
min_groups: int,
) -> str:
"""
将参数数据集转换为注入 prompt 的中文说明文本
"""
if not datasets:
return ""
lines = [
f"### 参数测试数据集(每个参数至少覆盖 {min_groups} 组)",
"",
]
for param_name, ds in datasets.items():
lines.append(f"**参数:{param_name}**(类型:{ds.param_type}")
for g in ds.groups:
category_label = {
"normal": "正常值",
"boundary": "边界值",
"exception": "异常值",
"equivalence": "等价类",
}.get(g.category, g.category)
lines.append(
f" - [{category_label}] {g.label}`{repr(g.value)}`"
)
lines.append("")
lines.append(
f"> 请从以上数据集中选取参数值组合,"
f"确保生成的测试用例总数不少于 **{min_groups} 组**"
f"并覆盖正常值、边界值和异常值场景。"
)
return "\n".join(lines)

View File

@ -1,131 +1,290 @@
""" """
提示词构建器 提示词构建器中文版
升级点
- 传递项目 description 变更说明
- function 协议source_file(.py) module_path from <module> import <func> - 新增 build_param_directive_section()将参数约束注入 System Prompt
- HTTP 协议full_url = url(base) + name(path) - 新增 build_param_dataset_section()将参数数据集注入 User Prompt
- parameters 统一字段inout 区分输入/输出 - build_batch_prompt / build_user_prompt 支持传入参数数据集
""" """
import json import json
from core.parser import InterfaceInfo, InterfaceParser from core.parser import InterfaceInfo, InterfaceParser
from core.requirement_parser import ParamConstraint, ParamStrategy
from core.param_generator import ParamGenerator, ParamDataSet
SYSTEM_PROMPT = """
You are a senior software test engineer. Generate test cases based on the provided # ══════════════════════════════════════════════════════════════
interface descriptions and test requirements. # System Prompt中文
# ══════════════════════════════════════════════════════════════
_SYSTEM_PROMPT_BASE = """
你是一名资深软件测试工程师擅长根据接口描述和测试需求生成高质量的测试用例
OUTPUT FORMAT 输出格式要求
Return ONLY a valid JSON array. No markdown fences. No extra text.
Each element = ONE test case: - 只输出一个合法的 JSON 数组不要包含任何 markdown 代码块```
- 不要在 JSON 前后添加任何解释性文字
- 每个数组元素代表一个测试用例结构如下
{ {
"test_id" : "unique id, e.g. TC_001", "test_id" : "唯一编号,如 TC_001",
"test_name" : "short descriptive name", "test_name" : "简短的测试名称",
"description" : "what this test verifies", "description" : "本用例验证的内容",
"requirement" : "the original requirement text this case maps to", "requirement" : "对应的原始需求文本",
"requirement_id" : "e.g. REQ.01", "requirement_id" : "需求编号,如 REQ.01(从接口描述中的 requirement_id 字段获取)",
"steps": [ "steps": [
{ {
"step_no" : 1, "step_no" : 1,
"interface_name" : "function or endpoint name", "interface_name" : "接口名称或函数名",
"protocol" : "http or function", "protocol" : "http 或 function",
"url" : "source_file (function) or full_url (http)", "url" : "function 协议填 source_filehttp 协议填 full_url",
"purpose" : "why this step is needed", "purpose" : "本步骤的目的",
"input": { "input": {
// Only parameters with inout = "in" or "inout" "参数名": "参数值"
"param_name": <value>
}, },
"use_output_of": { // optional "use_output_of": {
"step_no" : 1, "step_no" : 1,
"field" : "user_id", "field" : "上一步返回值中的字段名",
"as_param" : "user_id" "as_param" : "作为本步骤的参数名"
}, },
"assertions": [ "assertions": [
{ {
"field" : "field_name or 'return' or 'exception'", "field" : "断言的字段名,或 'return'(整体返回值)或 'exception'(异常)",
"operator" : "eq | ne | gt | lt | gte | lte | in | not_null | contains | raised | not_raised", "operator" : "eq | ne | gt | lt | gte | lte | in | not_null | contains | raised | not_raised",
"expected" : <value>, "expected" : "期望值",
"message" : "human readable description" "message" : "断言说明"
} }
] ]
} }
], ],
"test_data_notes" : "explanation of auto-generated test data", "test_data_notes" : "测试数据说明(说明本用例使用了哪类参数:正常值/边界值/异常值)",
"test_code" : "<complete Python script — see rules below>" "test_code" : "完整可运行的 Python 测试脚本(见下方规则)"
} }
TEST CODE RULES 测试代码编写规则
1. Complete, runnable Python script. No external test framework. 1. test_code 必须是完整可直接运行的 Python 脚本
2. Allowed imports: standard library + `requests` + the actual module under test. 2. 不得使用 unittest pytest 框架
3. 只允许导入Python 标准库requests被测模块
FUNCTION INTERFACES function 协议接口
3. Each function interface has: 4. 使用 module_path 字段导入真实函数若导入失败则使用桩函数
"source_file" : e.g. "create_user.py"
"module_path" : e.g. "create_user" (derived from source_file)
4. Import the REAL function using module_path:
try: try:
from <module_path> import <function_name> from <module_path> import <function_name>
except ImportError: except ImportError:
# Stub fallback — simulate on_success for positive tests, def <function_name>(<入参列表>):
# on_failure / raise Exception for negative tests return <on_success_value> # 桩函数,仅供结构验证
def <function_name>(<in_params>): 5. 调用函数时只传 inout "in" "inout" 的参数
return <on_success_value> # or raise Exception(...) 6. on_failure 包含 "exception"负向测试需用 try/except 捕获异常
5. Call the function with ONLY "in"/"inout" parameters.
6. Capture the return value; assert against on_success or on_failure structure.
7. If on_failure.value contains "exception" or "raises":
- In negative tests: wrap call in try/except, assert exception IS raised.
- Use field="exception", operator="raised", expected="Exception".
HTTP INTERFACES http 协议接口
8. Each HTTP interface has: 7. 使用 full_url 字段作为请求地址
"full_url" : complete URL, e.g. "http://127.0.0.1/api/delete_user"
"method" : "get" | "post" | "put" | "delete" etc.
9. Send request:
resp = requests.<method>("<full_url>", json=<input_dict>) resp = requests.<method>("<full_url>", json=<input_dict>)
10. Assert on resp.status_code and resp.json() fields. 8. 断言 resp.status_code resp.json() 中的字段
MULTI-STEP 结构化输出必须遵守
11. Execute steps in order; extract fields from previous step's return value 9. 每个步骤执行后必须在标准输出打印如下格式的一行
and pass to next step via use_output_of mapping. ##STEP_RESULT## {"step_no":<n>,"interface_name":"...","status":"PASS 或 FAIL",
STRUCTURED OUTPUT (REQUIRED)
12. After EACH step, print:
##STEP_RESULT## {"step_no":<n>,"interface_name":"...","status":"PASS|FAIL",
"assertions":[{"field":"...","operator":"...","expected":..., "assertions":[{"field":"...","operator":"...","expected":...,
"actual":...,"passed":true|false,"message":"..."}]} "actual":...,"passed":true false,"message":"..."}]}
13. Final line must be: 10. 脚本最后一行必须是
PASS: <summary> or FAIL: <summary> PASS: <摘要> FAIL: <摘要>
14. Wrap entire test body in try/except; print FAIL on any unhandled exception. 11. 整个脚本主体用 try/except 包裹捕获未处理异常时打印 FAIL
15. Do NOT use unittest or pytest.
ASSERTION RULES BY RETURN TYPE 断言编写规则
return.type = "dict" assert each key in on_success.value (positive) - 返回值为 dict on_success / on_failure 中每个 key 单独断言
or on_failure.value (negative) - 返回值为 bool field="return"operator="eq"expected=true false
return.type = "boolean" field="return", operator="eq", expected=true/false - 返回值为 int field="return"operator="eq" / "gt" / "lt"
return.type = "integer" field="return", operator="eq"/"gt"/"lt" - 预期抛出异常 field="exception"operator="raised"
on_failure = exception field="exception", operator="raised" - 预期不抛出异常 field="exception"operator="not_raised"
COVERAGE GUIDELINES 覆盖度要求
- Per requirement: at least 1 positive + 1 negative test case. - 每条需求至少生成 1 个正向用例 + 1 个负向用例
- Negative test_name/description MUST contain "negative" or "invalid". - 负向用例的 test_name 必须包含"负向""无效"
- Multi-interface requirements: single multi-step test case. - 覆盖接口的所有 inout "in" "inout" 的参数
- Cover ALL "in"/"inout" parameters (including optional ones). - 正向用例断言 on_success 的所有字段
- Assert ALL fields described in on_success (positive) and on_failure (negative). - 负向用例断言 on_failure 的所有字段
- For "out" parameters: verify their values in assertions after the call. """
# 参数生成规则段落(有参数约束时追加到 System Prompt
_PARAM_DIRECTIVE_TEMPLATE = """
参数生成规则本次任务特殊要求
{directives}
- test_data_notes 字段必须说明本用例使用的参数类别正常值/边界值/异常值及具体值
- 每个测试用例的 input 字段中的参数值必须从下方"参数测试数据集"中选取
- 禁止在 input 中使用未在数据集中出现的随机值
""" """
# ══════════════════════════════════════════════════════════════
# PromptBuilder
# ══════════════════════════════════════════════════════════════
class PromptBuilder: class PromptBuilder:
def get_system_prompt(self) -> str: def __init__(self):
return SYSTEM_PROMPT.strip() self._param_gen = ParamGenerator()
# ── System Prompt ────────────────────────────────────────
def get_system_prompt(
self,
global_constraint: ParamConstraint = None,
) -> str:
"""
构建 System Prompt
若存在全局参数约束追加参数生成规则段落
"""
base = _SYSTEM_PROMPT_BASE.strip()
if global_constraint and global_constraint.has_param_directive:
directive_section = self._build_param_directive_section(
global_constraint
)
return base + "\n" + directive_section.strip()
return base
def _build_param_directive_section(
self,
constraint: ParamConstraint,
) -> str:
"""构建参数生成规则段落,注入 System Prompt。"""
lines: list[str] = []
if constraint.min_groups > 2:
lines.append(
f"- 每个接口的测试用例总数不少于 **{constraint.min_groups} 组**"
f"(正向 + 负向合计)"
)
strategy_labels = {
ParamStrategy.BOUNDARY: "边界值(空值、最大值、最小值、临界值)",
ParamStrategy.EXCEPTION: "异常值None、类型错误、超长、特殊字符等",
ParamStrategy.EQUIVALENCE: "等价类(有效等价类 + 无效等价类)",
ParamStrategy.RANDOM: "随机值(覆盖典型随机场景)",
}
if constraint.strategies:
strategy_str = "".join(
strategy_labels[s]
for s in constraint.strategies
if s in strategy_labels
)
lines.append(f"- 参数值必须覆盖以下类型:{strategy_str}")
if not lines:
return ""
directives = "\n ".join(lines)
return _PARAM_DIRECTIVE_TEMPLATE.format(directives=directives)
# ── 项目信息头 ───────────────────────────────────────────
def build_project_header(
self,
project: str = "",
project_desc: str = "",
) -> str:
if not project and not project_desc:
return ""
lines = ["## 项目信息"]
if project:
lines.append(f"项目名称:{project}")
if project_desc:
lines.append(f"项目描述:{project_desc}")
return "\n".join(lines)
# ── 参数数据集段落 ───────────────────────────────────────
def build_param_dataset_section(
self,
iface_summaries: list[dict],
constraint: ParamConstraint,
) -> str:
"""
为批次内所有接口生成参数数据集返回注入 user_prompt 的文本段落
"""
if not constraint.has_param_directive:
return ""
all_lines: list[str] = [
"## 参数测试数据集",
f"> 以下数据集由系统自动生成,请从中选取参数值组合,"
f"确保每个接口的测试用例总数不少于 **{constraint.min_groups} 组**。",
"",
]
for iface in iface_summaries:
name = iface.get("name", "未知接口")
datasets = self._param_gen.generate_for_interface(
interface_summary=iface,
strategies=constraint.strategies,
min_groups=constraint.min_groups,
)
if not datasets:
continue
all_lines.append(f"### 接口:{name}")
text = self._param_gen.to_prompt_text(datasets, constraint.min_groups)
all_lines.append(text)
all_lines.append("")
return "\n".join(all_lines)
# ── 分批 user_prompt ────────────────────────────────────
def build_batch_prompt(
self,
batch: list[dict],
requirements: list[str],
project_header: str = "",
param_constraint: ParamConstraint = None,
) -> str:
"""
构建单批次的 user_prompt
若存在参数约束自动注入参数数据集
"""
req_lines = "\n".join(
f"{i + 1}. {r}" for i, r in enumerate(requirements)
)
parts: list[str] = []
if project_header:
parts.append(project_header)
parts.append(
"## 本批次接口描述\n"
+ json.dumps(batch, ensure_ascii=False, indent=2)
)
parts.append("## 测试需求\n" + req_lines)
# 注入参数数据集
if param_constraint and param_constraint.has_param_directive:
dataset_section = self.build_param_dataset_section(
iface_summaries=batch,
constraint=param_constraint,
)
if dataset_section:
parts.append(dataset_section)
parts.append(
"请根据以上接口描述、测试需求和参数数据集,生成完整的测试用例(包含正向和负向)。\n"
"- function 协议:使用 module_path 字段导入被测函数\n"
"- http 协议:使用 full_url 字段作为请求地址\n"
"- 每条需求的 requirement_id 须填入对应用例的 requirement_id 字段\n"
"- test_data_notes 须说明本用例使用的参数类别和具体值"
)
return "\n\n".join(parts)
# ── 单次 user_prompt ────────────────────────────────────
def build_user_prompt( def build_user_prompt(
self, self,
@ -134,34 +293,40 @@ class PromptBuilder:
parser: InterfaceParser, parser: InterfaceParser,
project: str = "", project: str = "",
project_desc: str = "", project_desc: str = "",
param_constraint: ParamConstraint = None,
) -> str: ) -> str:
"""
单次调用模式接口数 batch_size 时使用
"""
iface_summary = parser.to_summary_dict(interfaces) iface_summary = parser.to_summary_dict(interfaces)
req_lines = "\n".join( req_lines = "\n".join(
f"{i + 1}. {r}" for i, r in enumerate(requirements) f"{i + 1}. {r}" for i, r in enumerate(requirements)
) )
header = self.build_project_header(project, project_desc)
# 项目信息头 parts: list[str] = []
project_section = "" if header:
if project or project_desc: parts.append(header)
project_section = "## Project\n" parts.append(
if project: "## 接口描述\n"
project_section += f"Name : {project}\n" + json.dumps(iface_summary, ensure_ascii=False, indent=2)
if project_desc: )
project_section += f"Description: {project_desc}\n" parts.append("## 测试需求\n" + req_lines)
project_section += "\n"
return ( if param_constraint and param_constraint.has_param_directive:
f"{project_section}" dataset_section = self.build_param_dataset_section(
f"## Available Interfaces\n" iface_summaries=iface_summary
f"{json.dumps(iface_summary, ensure_ascii=False, indent=2)}\n\n" if isinstance(iface_summary, list) else list(iface_summary.values()),
f"## Test Requirements\n" constraint=param_constraint,
f"Generate comprehensive test cases (positive + negative) for every interface.\n" )
f"- The generated test cases for each interface must meet the following requirements: {req_lines}\n" if dataset_section:
f"- For function interfaces: use 'module_path' for import, " parts.append(dataset_section)
f"'source_file' for reference.\n"
f"- For HTTP interfaces: use 'full_url' as the request URL.\n" parts.append(
f"- Use requirement_id from the interface to populate the test case's " "请根据以上接口描述、测试需求和参数数据集,生成完整的测试用例(包含正向和负向)。\n"
f"requirement_id field.\n" "- function 协议:使用 module_path 字段导入被测函数\n"
f"- Chain multiple interface calls in one test case when a requirement " "- http 协议:使用 full_url 字段作为请求地址\n"
f"involves more than one interface.\n" "- 每条需求的 requirement_id 须填入对应用例的 requirement_id 字段\n"
).strip() "- test_data_notes 须说明本用例使用的参数类别和具体值"
)
return "\n\n".join(parts)

View File

@ -0,0 +1,211 @@
"""
需求解析器
职责
从用户填写的测试需求文本中识别并提取参数生成约束指令
- 参数组数约束"不少于10组""至少5组""生成8组"
- 参数策略约束"边界值""异常值""等价类""随机值"
- 数据类型约束"字符串边界值""整数异常值"
解析结果供 ParamGenerator PromptBuilder 使用
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from enum import Enum
# ══════════════════════════════════════════════════════════════
# 枚举:参数生成策略
# ══════════════════════════════════════════════════════════════
class ParamStrategy(str, Enum):
BOUNDARY = "boundary" # 边界值
EXCEPTION = "exception" # 异常值 / 错误值
EQUIVALENCE = "equivalence" # 等价类
RANDOM = "random" # 随机值
CUSTOM = "custom" # 自定义(用户在需求中直接指定)
# ══════════════════════════════════════════════════════════════
# 数据结构:解析结果
# ══════════════════════════════════════════════════════════════
@dataclass
class ParamConstraint:
"""
从单条需求文本中解析出的参数生成约束
"""
# 最少生成的参数组数(正向 + 负向合计)
min_groups: int = 2
# 需要覆盖的策略集合
strategies: list[ParamStrategy] = field(default_factory=list)
# 原始需求文本(去除参数指令后的纯业务需求部分)
clean_requirement: str = ""
# 是否显式指定了参数要求(用于区分"用户有要求"和"默认值"
has_param_directive: bool = False
def __str__(self) -> str:
parts = []
if self.min_groups > 2:
parts.append(f"至少 {self.min_groups} 组参数")
for s in self.strategies:
parts.append(_STRATEGY_LABEL[s])
return "".join(parts) if parts else "默认1正1负"
_STRATEGY_LABEL: dict[ParamStrategy, str] = {
ParamStrategy.BOUNDARY: "边界值",
ParamStrategy.EXCEPTION: "异常值",
ParamStrategy.EQUIVALENCE: "等价类",
ParamStrategy.RANDOM: "随机值",
ParamStrategy.CUSTOM: "自定义",
}
# ══════════════════════════════════════════════════════════════
# 解析规则
# ══════════════════════════════════════════════════════════════
# 数量约束:匹配"不少于N组"、"至少N组"、"最少N组"、"生成N组"、"N组以上"
_RE_MIN_GROUPS = re.compile(
r'(?:不少于|至少|最少|生成|共|需要|要求)\s*(\d+)\s*组'
r'|(\d+)\s*组(?:以上|及以上|或以上)',
re.UNICODE,
)
# 策略关键词映射(关键词 → 策略枚举)
_STRATEGY_KEYWORDS: list[tuple[list[str], ParamStrategy]] = [
(["边界值", "边界", "boundary", "边界测试"], ParamStrategy.BOUNDARY),
(["异常值", "异常", "错误值", "非法值",
"exception", "invalid", "error"], ParamStrategy.EXCEPTION),
(["等价类", "等价", "equivalence"], ParamStrategy.EQUIVALENCE),
(["随机", "随机值", "random"], ParamStrategy.RANDOM),
]
# 参数指令整体识别:包含上述任意关键词则认为是参数指令
_RE_PARAM_DIRECTIVE = re.compile(
r'(?:参数|测试参数|测试数据|数据|入参)'
r'.*?(?:不少于|至少|最少|生成|边界|异常|等价|随机|\d+\s*组)',
re.UNICODE,
)
# ══════════════════════════════════════════════════════════════
# 解析器
# ══════════════════════════════════════════════════════════════
class RequirementParser:
"""
解析测试需求列表提取参数生成约束返回
- clean_requirements : 去除参数指令后的纯业务需求列表
- constraints : 每条需求对应的 ParamConstraint
- global_constraint : 全局约束"全局"/"所有接口"等描述中提取
"""
def parse(
self,
requirements: list[str],
) -> tuple[list[str], list[ParamConstraint], ParamConstraint]:
"""
Returns:
(clean_requirements, per_req_constraints, global_constraint)
"""
clean_reqs: list[str] = []
constraints: list[ParamConstraint] = []
global_c = ParamConstraint()
for req in requirements:
c = self._parse_one(req)
constraints.append(c)
clean_reqs.append(c.clean_requirement)
# 若某条需求是全局性描述(如"所有接口参数不少于10组"
# 则将其约束提升为全局约束
if self._is_global(req) and c.has_param_directive:
global_c = self._merge(global_c, c)
return clean_reqs, constraints, global_c
# ── 解析单条需求 ──────────────────────────────────────────
def _parse_one(self, req: str) -> ParamConstraint:
c = ParamConstraint(clean_requirement=req)
# 1. 检测是否包含参数指令
if not self._has_directive(req):
return c
c.has_param_directive = True
# 2. 提取数量约束
m = _RE_MIN_GROUPS.search(req)
if m:
n = int(m.group(1) or m.group(2))
c.min_groups = max(n, 2) # 至少保留 1正1负
# 3. 提取策略约束
seen: set[ParamStrategy] = set()
for keywords, strategy in _STRATEGY_KEYWORDS:
if any(kw in req for kw in keywords):
seen.add(strategy)
c.strategies = list(seen)
# 4. 若未指定策略,默认同时覆盖边界值和异常值
if not c.strategies:
c.strategies = [ParamStrategy.BOUNDARY, ParamStrategy.EXCEPTION]
# 5. 清理需求文本:去除参数指令部分,保留业务描述
c.clean_requirement = self._strip_directive(req)
return c
# ── 工具方法 ──────────────────────────────────────────────
@staticmethod
def _has_directive(req: str) -> bool:
"""判断需求文本是否包含参数生成指令。"""
return bool(_RE_PARAM_DIRECTIVE.search(req)) or bool(
_RE_MIN_GROUPS.search(req)
)
@staticmethod
def _is_global(req: str) -> bool:
"""判断是否为全局性参数要求。"""
return any(
kw in req
for kw in ["所有接口", "全部接口", "全局", "所有测试", "全部测试"]
)
@staticmethod
def _strip_directive(req: str) -> str:
"""
去除需求文本中的参数指令部分保留纯业务描述
策略以中文逗号分号括号等分隔去除含参数指令的子句
"""
# 按常见分隔符拆分
parts = re.split(r'[,;()【】\n]', req)
clean_parts = []
for part in parts:
part = part.strip()
if not part:
continue
# 若该子句包含参数指令关键词,则跳过
if _RE_PARAM_DIRECTIVE.search(part) or _RE_MIN_GROUPS.search(part):
continue
clean_parts.append(part)
result = "".join(clean_parts).strip()
return result if result else req # 若全部被去除则保留原文
@staticmethod
def _merge(base: ParamConstraint, other: ParamConstraint) -> ParamConstraint:
"""合并两个约束,取较大值。"""
merged = ParamConstraint(
min_groups=max(base.min_groups, other.min_groups),
strategies=list(set(base.strategies) | set(other.strategies)),
has_param_directive=True,
)
return merged

View File

@ -1,168 +1,278 @@
import os """
import sys 测试执行器
大规模支持改造
1. 并行执行ThreadPoolExecutor默认 8 个并发
2. 超时控制单个用例超时不阻塞整体
3. 步骤级结果解析##STEP_RESULT## 行
4. 进度显示实时打印完成数
5. 结果持久化保存 run_results.json
"""
from __future__ import annotations
import json import json
import subprocess
import time
import logging import logging
from pathlib import Path import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path
from config import config from config import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
STEP_RESULT_PREFIX = "##STEP_RESULT##"
# ══════════════════════════════════════════════════════════════
# 数据结构
# ══════════════════════════════════════════════════════════════
@dataclass
class AssertionResult:
field: str
operator: str
expected: object
actual: object
passed: bool
message: str = ""
@dataclass @dataclass
class StepResult: class StepResult:
step_no: int step_no: int
interface_name: str interface_name: str
status: str # PASS | FAIL | ERROR status: str # "PASS" | "FAIL"
assertions: list[dict] = field(default_factory=list) assertions: list[AssertionResult] = field(default_factory=list)
# 每条 assertion: {"field","operator","expected","actual","passed","message"}
@dataclass @dataclass
class TestResult: class TestResult:
test_id: str test_id: str
file_path: str file_path: str
status: str # PASS | FAIL | ERROR | TIMEOUT status: str # "PASS" | "FAIL" | "ERROR" | "TIMEOUT"
message: str = "" message: str = ""
duration: float = 0.0 duration: float = 0.0
step_results: list[StepResult] = field(default_factory=list)
stdout: str = "" stdout: str = ""
stderr: str = "" stderr: str = ""
step_results: list[StepResult] = field(default_factory=list)
# ══════════════════════════════════════════════════════════════
# 执行器
# ══════════════════════════════════════════════════════════════
class TestRunner: class TestRunner:
def __init__(self):
self.timeout = getattr(config, "TEST_TIMEOUT", 60)
self.max_workers = getattr(config, "TEST_MAX_WORKERS", 8)
self.python_bin = getattr(config, "TEST_PYTHON_BIN", sys.executable)
# ── 主入口 ────────────────────────────────────────────────
def run_all(self, test_files: list[Path]) -> list[TestResult]: def run_all(self, test_files: list[Path]) -> list[TestResult]:
results = [] """
并行执行所有测试文件返回结果列表
"""
if not test_files:
logger.warning("No test files to run.")
return []
total = len(test_files) total = len(test_files)
print(f"\n{''*62}") results: list[TestResult] = []
print(f" Running {total} test(s) …") done = 0
print(f"{''*62}")
for idx, fp in enumerate(test_files, 1): logger.info(
print(f"\n[{idx}/{total}] {fp.name}") f"Running {total} test(s) with "
result = self._run_one(fp) f"max_workers={self.max_workers}, timeout={self.timeout}s"
)
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
future_map = {
executor.submit(self._run_one, f): f
for f in test_files
}
for future in as_completed(future_map):
result = future.result()
results.append(result) results.append(result)
self._print_result(result) done += 1
icon = "" if result.status == "PASS" else ""
logger.info(
f" [{done}/{total}] {icon} {result.test_id} "
f"({result.status}) {result.duration:.2f}s"
)
# 按文件名排序,保持输出稳定
results.sort(key=lambda r: r.file_path)
return results return results
# ── 执行单个脚本 ────────────────────────────────────────── # ── 执行单个测试文件 ──────────────────────────────────────
def _run_one(self, file_path: Path) -> TestResult: def _run_one(self, file_path: Path) -> TestResult:
test_id = file_path.stem test_id = file_path.stem
t0 = time.time() start = time.monotonic()
try: try:
proc = subprocess.run( proc = subprocess.run(
[sys.executable, str(file_path)], [self.python_bin, str(file_path)],
capture_output=True, text=True, capture_output=True,
timeout=config.TEST_TIMEOUT, text=True,
env=self._env(), timeout=self.timeout,
) )
duration = time.time() - t0 duration = time.monotonic() - start
stdout = proc.stdout.strip() return self._parse_output(
stderr = proc.stderr.strip() test_id, str(file_path), proc, duration
status, message = self._parse_output(stdout, proc.returncode)
step_results = self._parse_step_results(stdout)
return TestResult(
test_id=test_id, file_path=str(file_path),
status=status, message=message,
duration=duration, stdout=stdout, stderr=stderr,
step_results=step_results,
) )
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
duration = time.monotonic() - start
logger.warning(f" TIMEOUT: {test_id} ({duration:.1f}s)")
return TestResult( return TestResult(
test_id=test_id, file_path=str(file_path), test_id=test_id,
status="TIMEOUT", message=f"Exceeded {config.TEST_TIMEOUT}s", file_path=str(file_path),
duration=time.time() - t0, status="TIMEOUT",
message=f"Exceeded {self.timeout}s timeout",
duration=duration,
) )
except Exception as e: except Exception as e:
duration = time.monotonic() - start
logger.error(f" ERROR running {test_id}: {e}")
return TestResult( return TestResult(
test_id=test_id, file_path=str(file_path), test_id=test_id,
status="ERROR", message=str(e), file_path=str(file_path),
duration=time.time() - t0, status="ERROR",
message=str(e),
duration=duration,
) )
# ── 解析输出 ────────────────────────────────────────────── # ── 解析输出 ──────────────────────────────────────────────
def _parse_output(self, stdout: str, returncode: int) -> tuple[str, str]: def _parse_output(
if not stdout: self,
return "FAIL", f"No output (exit={returncode})" test_id: str,
last = stdout.strip().splitlines()[-1].strip() file_path: str,
upper = last.upper() proc: subprocess.CompletedProcess,
if upper.startswith("PASS"): duration: float,
return "PASS", last[5:].strip() ) -> TestResult:
if upper.startswith("FAIL"): stdout = proc.stdout or ""
return "FAIL", last[5:].strip() stderr = proc.stderr or ""
return ("PASS" if returncode == 0 else "FAIL"), last step_results = self._parse_step_results(stdout)
# 最后一行决定整体状态
last_line = stdout.strip().splitlines()[-1] if stdout.strip() else ""
if last_line.startswith("PASS"):
status = "PASS"
message = last_line
elif last_line.startswith("FAIL"):
status = "FAIL"
message = last_line
elif proc.returncode != 0:
status = "FAIL"
message = f"Exit code {proc.returncode}: {stderr[:200]}"
else:
status = "ERROR"
message = "No PASS/FAIL line found in output"
return TestResult(
test_id=test_id,
file_path=file_path,
status=status,
message=message,
duration=duration,
step_results=step_results,
stdout=stdout,
stderr=stderr,
)
def _parse_step_results(self, stdout: str) -> list[StepResult]: def _parse_step_results(self, stdout: str) -> list[StepResult]:
"""
解析脚本中以 ##STEP_RESULT## 开头的结构化输出行
格式##STEP_RESULT## <json>
"""
results = [] results = []
for line in stdout.splitlines(): for line in stdout.splitlines():
line = line.strip() line = line.strip()
if line.startswith("##STEP_RESULT##"): if not line.startswith(STEP_RESULT_PREFIX):
continue
try: try:
data = json.loads(line[len("##STEP_RESULT##"):].strip()) raw = json.loads(line[len(STEP_RESULT_PREFIX):].strip())
assertions = [
AssertionResult(
field=a.get("field", ""),
operator=a.get("operator", ""),
expected=a.get("expected"),
actual=a.get("actual"),
passed=bool(a.get("passed", False)),
message=a.get("message", ""),
)
for a in raw.get("assertions", [])
]
results.append(StepResult( results.append(StepResult(
step_no=data.get("step_no", 0), step_no=raw.get("step_no", 0),
interface_name=data.get("interface_name", ""), interface_name=raw.get("interface_name", ""),
status=data.get("status", ""), status=raw.get("status", "FAIL"),
assertions=data.get("assertions", []), assertions=assertions,
)) ))
except Exception: except (json.JSONDecodeError, KeyError) as e:
pass logger.debug(f"Step result parse error: {e}")
return results return results
def _env(self) -> dict: # ── 摘要打印 ──────────────────────────────────────────────
env = os.environ.copy()
env["HTTP_BASE_URL"] = config.HTTP_BASE_URL
return env
# ── 打印 ──────────────────────────────────────────────────
def _print_result(self, r: TestResult):
icon = {"PASS": "", "FAIL": "", "TIMEOUT": "⏱️", "ERROR": "⚠️"}.get(r.status, "")
print(f" {icon} [{r.status}] {r.test_id} ({r.duration:.2f}s)")
print(f" {r.message}")
for sr in r.step_results:
s_icon = "" if sr.status == "PASS" else ""
print(f" {s_icon} Step {sr.step_no}: {sr.interface_name}")
for a in sr.assertions:
a_icon = "" if a.get("passed") else ""
print(f" {a_icon} {a.get('message','')} "
f"(expected={a.get('expected')}, actual={a.get('actual')})")
if r.stderr:
print(f" stderr: {r.stderr[:300]}")
def print_summary(self, results: list[TestResult]): def print_summary(self, results: list[TestResult]):
total = len(results) total = len(results)
passed = sum(1 for r in results if r.status == "PASS") passed = sum(1 for r in results if r.status == "PASS")
failed = sum(1 for r in results if r.status == "FAIL") failed = sum(1 for r in results if r.status == "FAIL")
errors = sum(1 for r in results if r.status in ("ERROR", "TIMEOUT")) errors = sum(1 for r in results if r.status in ("ERROR", "TIMEOUT"))
print(f"\n{''*62}") avg_dur = sum(r.duration for r in results) / total if total else 0
print(f" TEST SUMMARY")
print(f"{''*62}") print(f"\n{'' * 56}")
print(f" Total : {total}") print(f" Test Run Summary ({total} cases)")
print(f"{'' * 56}")
print(f" ✅ PASS : {passed}") print(f" ✅ PASS : {passed}")
print(f" ❌ FAIL : {failed}") print(f" ❌ FAIL : {failed}")
print(f" ⚠️ ERROR : {errors}") print(f" ⚠️ ERROR : {errors}")
print(f" Pass Rate: {passed/total*100:.1f}%" if total else " Pass Rate: N/A") print(f" ⏱ Avg : {avg_dur:.2f}s / case")
print(f"{''*62}\n") print(f"{'' * 56}")
if failed or errors:
print(" Failed / Error cases:")
for r in results:
if r.status not in ("PASS",):
print(f" [{r.status}] {r.test_id}: {r.message}")
print()
# ── 持久化 ────────────────────────────────────────────────
def save_results(self, results: list[TestResult], path: str): def save_results(self, results: list[TestResult], path: str):
with open(path, "w", encoding="utf-8") as f: data = [
json.dump([{ {
"test_id": r.test_id, "status": r.status, "test_id": r.test_id,
"message": r.message, "duration": r.duration, "file_path": r.file_path,
"stdout": r.stdout, "stderr": r.stderr, "status": r.status,
"message": r.message,
"duration": round(r.duration, 3),
"step_results": [ "step_results": [
{"step_no": sr.step_no, "interface_name": sr.interface_name, {
"status": sr.status, "assertions": sr.assertions} "step_no": s.step_no,
for sr in r.step_results "interface_name": s.interface_name,
"status": s.status,
"assertions": [
{
"field": a.field,
"operator": a.operator,
"expected": a.expected,
"actual": a.actual,
"passed": a.passed,
"message": a.message,
}
for a in s.assertions
], ],
} for r in results], f, ensure_ascii=False, indent=2) }
for s in r.step_results
],
}
for r in results
]
with open(path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"Run results → {path}") logger.info(f"Run results → {path}")

View File

@ -3,11 +3,18 @@
AI-Powered API Test Generator, Runner & Coverage Analyzer AI-Powered API Test Generator, Runner & Coverage Analyzer
Usage: Usage:
# 基础用法
python main.py --api-desc examples/api_desc.json \\ python main.py --api-desc examples/api_desc.json \\
--requirements "创建用户,删除用户" --requirements "创建用户,删除用户"
# 在需求中自然语言描述参数生成要求
python main.py --api-desc examples/api_desc.json \\ python main.py --api-desc examples/api_desc.json \\
--req-file examples/requirements.txt --requirements "创建用户测试参数不少于10组覆盖边界值和异常值,删除用户"
# 从文件读取需求(每行一条,支持参数生成指令)
python main.py --api-desc examples/api_desc.json \\
--req-file examples/requirements.txt \\
--batch-size 10
""" """
import argparse import argparse
@ -17,6 +24,7 @@ from pathlib import Path
from config import config from config import config
from core.parser import InterfaceParser, ApiDescriptor from core.parser import InterfaceParser, ApiDescriptor
from core.requirement_parser import RequirementParser, ParamConstraint
from core.prompt_builder import PromptBuilder from core.prompt_builder import PromptBuilder
from core.llm_client import LLMClient from core.llm_client import LLMClient
from core.test_generator import TestGenerator from core.test_generator import TestGenerator
@ -24,48 +32,57 @@ from core.test_runner import TestRunner
from core.analyzer import CoverageAnalyzer, CoverageReport from core.analyzer import CoverageAnalyzer, CoverageReport
from core.report_generator import ReportGenerator from core.report_generator import ReportGenerator
# ══════════════════════════════════════════════════════════════
# 日志初始化
# ══════════════════════════════════════════════════════════════
def setup_logging(debug: bool = False):
"""
配置日志级别
Bug 规避openai-python v2.24.0 + pydantic-core
第三方 SDK logger 强制锁定在 WARNING避免触发
model_dump(by_alias=None) pydantic-core Rust TypeError
Ref: https://github.com/openai/openai-python/issues/2921
"""
root_level = logging.DEBUG if debug else logging.INFO
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=root_level,
format="%(asctime)s [%(levelname)s] %(name)s%(message)s", format="%(asctime)s [%(levelname)s] %(name)s%(message)s",
handlers=[logging.StreamHandler(sys.stdout)], handlers=[logging.StreamHandler(sys.stdout)],
) )
for name in ["openai", "openai._base_client", "anthropic", "httpx", "httpcore"]:
logging.getLogger(name).setLevel(logging.WARNING)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ── CLI ─────────────────────────────────────────────────────── # ══════════════════════════════════════════════════════════════
# CLI
# ══════════════════════════════════════════════════════════════
def build_arg_parser() -> argparse.ArgumentParser: def build_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser( p = argparse.ArgumentParser(
description="AI-Powered API Test Generator & Coverage Analyzer" description="AI-Powered API Test Generator & Coverage Analyzer",
) formatter_class=argparse.RawDescriptionHelpFormatter,
p.add_argument( epilog="""
"--api-desc", required=True, 参数生成指令示例可直接写在需求文本中
help="接口描述 JSON 文件(含 project / description / units 字段)", "创建用户测试参数不少于10组覆盖边界值和异常值"
) "查询接口至少8组边界值"
p.add_argument( "所有接口参数不少于5组覆盖等价类"
"--requirements", default="", """,
help="测试需求(逗号分隔)",
)
p.add_argument(
"--req-file", default="",
help="测试需求文件(每行一条)",
)
p.add_argument(
"--skip-run", action="store_true",
help="只生成测试文件,不执行",
)
p.add_argument(
"--skip-analyze", action="store_true",
help="跳过覆盖率分析",
)
p.add_argument(
"--output-dir", default="",
help="测试文件根输出目录(默认 config.GENERATED_TESTS_DIR",
)
p.add_argument(
"--debug", action="store_true",
help="开启 DEBUG 日志",
) )
p.add_argument("--api-desc", required=True, help="接口描述 JSON 文件")
p.add_argument("--requirements", default="", help="测试需求(逗号分隔,支持参数生成指令)")
p.add_argument("--req-file", default="", help="测试需求文件(每行一条)")
p.add_argument("--batch-size", type=int, default=0, help="每批接口数量0=使用 config.LLM_BATCH_SIZE")
p.add_argument("--workers", type=int, default=0, help="并行执行线程数0=使用 config.TEST_MAX_WORKERS")
p.add_argument("--skip-run", action="store_true", help="只生成,不执行")
p.add_argument("--skip-analyze", action="store_true", help="跳过覆盖率分析")
p.add_argument("--output-dir", default="", help="测试文件根输出目录")
p.add_argument("--debug", action="store_true", help="开启 DEBUG 日志(第三方 SDK 仍保持 WARNING")
return p return p
@ -77,82 +94,111 @@ def load_requirements(args) -> list[str]:
with open(args.req_file, "r", encoding="utf-8") as f: with open(args.req_file, "r", encoding="utf-8") as f:
reqs += [line.strip() for line in f if line.strip()] reqs += [line.strip() for line in f if line.strip()]
if not reqs: if not reqs:
logger.error( logger.error("未提供任何测试需求,请使用 --requirements 或 --req-file 指定。")
"No requirements provided. Use --requirements or --req-file."
)
sys.exit(1) sys.exit(1)
return reqs return reqs
# ── Main ────────────────────────────────────────────────────── # ══════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════
def main(): def main():
args = build_arg_parser().parse_args() args = build_arg_parser().parse_args()
setup_logging(debug=args.debug)
if args.debug: if args.output_dir: config.GENERATED_TESTS_DIR = args.output_dir
logging.getLogger().setLevel(logging.DEBUG) if args.batch_size > 0: config.LLM_BATCH_SIZE = args.batch_size
if args.output_dir: if args.workers > 0: config.TEST_MAX_WORKERS = args.workers
config.GENERATED_TESTS_DIR = args.output_dir
# ── Step 1: 解析接口描述 ────────────────────────────────── # ── Step 1: 解析接口描述 ──────────────────────────────────
logger.info("▶ Step 1: Parsing interface description") logger.info("▶ Step 1: 解析接口描述")
parser: InterfaceParser = InterfaceParser() parser = InterfaceParser()
descriptor: ApiDescriptor = parser.parse_file(args.api_desc) descriptor = parser.parse_file(args.api_desc)
project = descriptor.project project = descriptor.project
project_desc = descriptor.description project_desc = descriptor.description
interfaces = descriptor.interfaces interfaces = descriptor.interfaces
logger.info(f" Project : {project}") logger.info(f" 项目名称:{project}")
logger.info(f" Description: {project_desc}") logger.info(f" 项目描述:{project_desc}")
logger.info( logger.info(f" 接口数量:{len(interfaces)}")
f" Interfaces : {len(interfaces)}"
f"{[i.name for i in interfaces]}"
)
# 打印每个接口的 url 解析结果
for iface in interfaces: for iface in interfaces:
if iface.protocol == "function": if iface.protocol == "function":
logger.info( logger.info(
f" [{iface.name}] source={iface.source_file} " f" [function] {iface.name} "
f"module={iface.module_path}" f"{iface.source_file} (module: {iface.module_path})"
) )
else: else:
logger.info(f" [http] {iface.name}{iface.http_full_url}")
# ── Step 2: 加载并解析测试需求(含参数生成指令)──────────
logger.info("▶ Step 2: 加载并解析测试需求 …")
raw_requirements = load_requirements(args)
req_parser = RequirementParser()
clean_reqs, per_req_constraints, global_constraint = req_parser.parse(
raw_requirements
)
# 打印原始需求 & 解析结果
for i, (raw, clean, c) in enumerate(
zip(raw_requirements, clean_reqs, per_req_constraints), 1
):
if c.has_param_directive:
logger.info( logger.info(
f" [{iface.name}] full_url={iface.http_full_url}" f" {i}. {raw}\n"
f" └─ 参数约束:{c} | 业务需求:{clean}"
) )
else:
logger.info(f" {i}. {raw}")
# ── Step 2: 加载测试需求 ────────────────────────────────── if global_constraint.has_param_directive:
logger.info("▶ Step 2: Loading test requirements …") logger.info(f" 全局参数约束:{global_constraint}")
requirements = load_requirements(args)
for i, req in enumerate(requirements, 1):
logger.info(f" {i}. {req}")
# ── Step 3: 调用 LLM 生成测试用例 ──────────────────────── # 使用"清洗后的需求"传给 LLM去掉参数指令保留纯业务描述
logger.info("▶ Step 3: Calling LLM …") # 同时保留原始需求用于覆盖率分析(保证与用户输入一致)
effective_requirements = clean_reqs
# ── Step 3: 构建 Prompt 并调用 LLM 生成测试用例 ──────────
logger.info(
f"▶ Step 3: 调用 LLM 生成测试用例 "
f"(batch_size={getattr(config, 'LLM_BATCH_SIZE', 10)}) …"
)
builder = PromptBuilder() builder = PromptBuilder()
test_cases = LLMClient().generate_test_cases( iface_summary = parser.to_summary_dict(interfaces)
builder.get_system_prompt(), project_header = builder.build_project_header(project, project_desc)
builder.build_user_prompt(
requirements, interfaces, parser, # System Prompt注入全局参数约束规则
project=project, system_prompt = builder.get_system_prompt(
project_desc=project_desc, global_constraint=global_constraint
),
) )
logger.info(f" LLM returned {len(test_cases)} test case(s)")
test_cases = LLMClient().generate_test_cases(
system_prompt=system_prompt,
user_prompt="",
iface_summaries=iface_summary,
requirements=effective_requirements,
project_header=project_header,
param_constraint=global_constraint, # ← 注入参数约束
)
logger.info(f" 共生成测试用例:{len(test_cases)}")
# ── Step 4: 生成测试文件 ────────────────────────────────── # ── Step 4: 生成测试文件 ──────────────────────────────────
logger.info( logger.info("▶ Step 4: 生成测试文件 …")
f"▶ Step 4: Generating test files (project='{project}') …"
)
generator = TestGenerator(project=project, project_desc=project_desc) generator = TestGenerator(project=project, project_desc=project_desc)
test_files = generator.generate(test_cases) test_files = generator.generate(test_cases)
out = generator.output_dir out = generator.output_dir
logger.info(f" 输出目录:{out.resolve()}")
run_results = [] run_results = []
if not args.skip_run: if not args.skip_run:
# ── Step 5: 执行测试 ────────────────────────────────── # ── Step 5: 并行执行测试 ──────────────────────────────
logger.info("▶ Step 5: Running tests …") logger.info(
f"▶ Step 5: 执行测试 "
f"(workers={getattr(config, 'TEST_MAX_WORKERS', 8)}) …"
)
runner = TestRunner() runner = TestRunner()
run_results = runner.run_all(test_files) run_results = runner.run_all(test_files)
runner.print_summary(run_results) runner.print_summary(run_results)
@ -160,31 +206,36 @@ def main():
if not args.skip_analyze: if not args.skip_analyze:
# ── Step 6: 覆盖率分析 ──────────────────────────────── # ── Step 6: 覆盖率分析 ────────────────────────────────
logger.info("▶ Step 6: Analyzing coverage") logger.info("▶ Step 6: 覆盖率分析")
report = CoverageAnalyzer( report = CoverageAnalyzer(
interfaces=interfaces, interfaces=interfaces,
requirements=requirements, requirements=raw_requirements, # 用原始需求做覆盖率分析
test_cases=test_cases, test_cases=test_cases,
run_results=run_results, run_results=run_results,
).analyze() ).analyze()
# ── Step 7: 生成报告 ────────────────────────────────── # ── Step 7: 生成报告 ──────────────────────────────────
logger.info("▶ Step 7: Generating reports") logger.info("▶ Step 7: 生成报告")
rg = ReportGenerator() rg = ReportGenerator()
rg.save_json(report, str(out / "coverage_report.json")) rg.save_json(report, str(out / "coverage_report.json"))
rg.save_html(report, str(out / "coverage_report.html")) rg.save_html(report, str(out / "coverage_report.html"))
_print_terminal_summary(report, out, project) _print_terminal_summary(report, out, project, global_constraint)
logger.info(f"\nDone. Output: {out.resolve()}") logger.info(f"\n完成。输出目录:{out.resolve()}")
# ── 终端摘要 ────────────────────────────────────────────────── # ══════════════════════════════════════════════════════════════
# 终端摘要
# ══════════════════════════════════════════════════════════════
def _print_terminal_summary( def _print_terminal_summary(
report: CoverageReport, out: Path, project: str report: "CoverageReport",
out: Path,
project: str,
global_constraint: ParamConstraint,
): ):
W = 66 W = 68
def bar(rate: float, w: int = 20) -> str: def bar(rate: float, w: int = 20) -> str:
filled = int(rate * w) filled = int(rate * w)
@ -193,8 +244,8 @@ def _print_terminal_summary(
return f"{'' * filled}{'' * empty} {rate * 100:.1f}% {icon}" return f"{'' * filled}{'' * empty} {rate * 100:.1f}% {icon}"
print(f"\n{'' * W}") print(f"\n{'' * W}")
print(f" PROJECT : {project}") print(f" 项目:{project}")
print(f" COVERAGE SUMMARY") print(f" 覆盖率摘要")
print(f"{'' * W}") print(f"{'' * W}")
print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}") print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}")
print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}") print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}")
@ -203,25 +254,31 @@ def _print_terminal_summary(
print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}") print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}")
print(f" 用例通过率 {bar(report.pass_rate)}") print(f" 用例通过率 {bar(report.pass_rate)}")
print(f"{'' * W}") print(f"{'' * W}")
print(f" Total Gaps : {len(report.gaps)}") print(f" 测试用例总数 {report.total_test_cases}")
print(f" 🔴 Critical: {report.critical_gap_count}") print(f" 覆盖缺口总数 {len(report.gaps)}")
print(f" 🟠 High : {report.high_gap_count}") print(f" 🔴 严重缺口 {report.critical_gap_count}")
print(f" 🟠 高优先级缺口 {report.high_gap_count}")
# 参数约束达成情况
if global_constraint.has_param_directive:
print(f"{'' * W}")
print(f" 参数生成约束 {global_constraint}")
actual = report.total_test_cases
needed = global_constraint.min_groups
status = "✅ 已满足" if actual >= needed else f"❌ 未满足(需 {needed} 组,实际 {actual} 组)"
print(f" 参数组数要求 {status}")
if report.gaps: if report.gaps:
print(f"{'' * W}") print(f"{'' * W}")
print(" Top Gaps (up to 8):") print(" Top 缺口最多显示8条")
icons = { icons = {"critical": "🔴", "high": "🟠", "medium": "🟡", "low": "🔵"}
"critical": "🔴", "high": "🟠",
"medium": "🟡", "low": "🔵",
}
for g in report.gaps[:8]: for g in report.gaps[:8]:
icon = icons.get(g.severity, "") print(f" {icons.get(g.severity, '')} [{g.gap_type}] {g.target}")
print(f" {icon} [{g.gap_type}] {g.target}")
print(f"{g.suggestion}") print(f"{g.suggestion}")
print(f"{'' * W}") print(f"{'' * W}")
print(f" Output : {out.resolve()}") print(f" 输出目录:{out.resolve()}")
print(f" • coverage_report.html ← open in browser") print(f" • coverage_report.html")
print(f" • coverage_report.json") print(f" • coverage_report.json")
print(f" • run_results.json") print(f" • run_results.json")
print(f" • test_cases_summary.json") print(f" • test_cases_summary.json")

View File

@ -16,4 +16,4 @@ export HTTP_BASE_URL="http://localhost:8080"
# 只生成不执行 # 只生成不执行
python main.py --api-desc examples/api_desc.json \ python main.py --api-desc examples/api_desc.json \
--requirements "1.对每个接口进行测试,支持特殊值、边界值测试" --requirements "每个测试用例的参数不少于10组"