source optimization

This commit is contained in:
sontolau 2026-03-09 11:48:19 +08:00
parent c247a8f1dc
commit ce898d81ee
7 changed files with 1457 additions and 375 deletions

View File

@ -1,58 +1,204 @@
"""
LLM client.

Changes:
- generate_test_cases() gains a param_constraint parameter.
- In batched mode the constraint is passed through to
  PromptBuilder.build_batch_prompt().
"""
from __future__ import annotations
import json
import re
import logging
import re
import time
from typing import Any
from openai import OpenAI
from config import config
logger = logging.getLogger(__name__)
class LLMClient:
"""封装 LLM API 调用OpenAI 兼容接口)"""
# ══════════════════════════════════════════════════════════════
# JSON 健壮解析工具
# ══════════════════════════════════════════════════════════════
def __init__(self):
self.client = OpenAI(
api_key=config.LLM_API_KEY,
base_url=config.LLM_BASE_URL,
)
class RobustJSONParser:
def generate_test_cases(self, system_prompt: str, user_prompt: str) -> list[dict]:
logger.info(f"Calling LLM: model={config.LLM_MODEL}")
logger.debug(f"--- USER PROMPT ---\n{user_prompt}\n---")
def parse(self, text: str) -> list[dict]:
text = text.strip()
text = re.sub(r'^```(?:json)?\s*', '', text, flags=re.MULTILINE)
text = re.sub(r'\s*```$', '', text, flags=re.MULTILINE)
text = text.strip()
response = self.client.chat.completions.create(
model=config.LLM_MODEL,
temperature=config.LLM_TEMPERATURE,
max_tokens=config.LLM_MAX_TOKENS,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
)
raw = response.choices[0].message.content
logger.debug(f"--- LLM RAW RESPONSE ---\n{raw[:800]}\n---")
return self._parse_json(raw)
# ── 解析 ──────────────────────────────────────────────────
def _parse_json(self, content: str) -> list[dict]:
content = content.strip()
# 去除可能的 markdown 代码块
content = re.sub(r"^```(?:json)?\s*", "", content, flags=re.MULTILINE)
content = re.sub(r"\s*```\s*$", "", content, flags=re.MULTILINE)
content = content.strip()
start = text.find('[')
if start == -1:
logger.warning("LLM 响应中未找到 JSON 数组。")
return []
text = text[start:]
try:
data = json.loads(content)
result = json.loads(text)
if isinstance(result, list):
return result
except json.JSONDecodeError:
# 尝试提取第一个 JSON 数组
match = re.search(r"\[.*\]", content, re.DOTALL)
if match:
data = json.loads(match.group())
else:
logger.error(f"Cannot parse LLM response as JSON:\n{content[:600]}")
raise
pass
if not isinstance(data, list):
raise ValueError(f"Expected JSON array, got {type(data)}")
return data
text = self._fix_truncated(text)
try:
result = json.loads(text)
if isinstance(result, list):
logger.warning("JSON 截断修复成功,部分用例可能丢失。")
return result
except json.JSONDecodeError as e:
logger.error(f"JSON 解析失败(修复后仍无效):{e}")
return []
def _fix_truncated(self, text: str) -> str:
last = text.rfind('},')
if last == -1:
last = text.rfind('}')
if last == -1:
return text
return text[:last + 1] + ']'
# ══════════════════════════════════════════════════════════════
# LLM 客户端
# ══════════════════════════════════════════════════════════════
class LLMClient:
    """Client for an OpenAI-compatible chat endpoint that returns test cases.

    Two modes are offered by ``generate_test_cases``: a single call with a
    caller-built user prompt, or a batched mode that splits the interface
    summaries into chunks, builds one prompt per chunk through
    ``PromptBuilder`` and concatenates the parsed results.
    """

    # Interfaces per batch when config does not define LLM_BATCH_SIZE.
    DEFAULT_BATCH_SIZE = 10

    def __init__(self):
        """Read connection and tuning settings from ``config``.

        Only ``LLM_API_KEY`` and ``LLM_MODEL`` are read unconditionally; the
        other settings fall back to the defaults given below.
        """
        self.client = OpenAI(
            api_key=config.LLM_API_KEY,
            base_url=getattr(config, "LLM_BASE_URL", None),
        )
        self.model = config.LLM_MODEL
        self.max_tokens = getattr(config, "LLM_MAX_TOKENS", 8192)
        # Clamp to 1: a batch size of 0 would make the range() step in
        # _make_batches raise, and a negative one would yield no batches.
        self.batch_size = max(
            1, getattr(config, "LLM_BATCH_SIZE", self.DEFAULT_BATCH_SIZE)
        )
        self.retry = getattr(config, "LLM_RETRY", 3)
        self.retry_delay = getattr(config, "LLM_RETRY_DELAY", 5)
        self.batch_interval = getattr(config, "LLM_BATCH_INTERVAL", 1)
        self._parser = RobustJSONParser()
        # NOTE(review): imported here rather than at module top, presumably
        # to avoid a circular import with core.prompt_builder; confirm.
        from core.prompt_builder import PromptBuilder
        self._prompt_builder = PromptBuilder()

    # ── Main entry point ─────────────────────────────────────
    def generate_test_cases(
        self,
        system_prompt: str,
        user_prompt: str,
        iface_summaries: list[dict] | None = None,
        requirements: list[str] | None = None,
        project_header: str = "",
        param_constraint: "ParamConstraint | None" = None,
    ) -> list[dict]:
        """Generate test cases with one call or with one call per batch.

        Mode A (single call): used when ``iface_summaries`` or
        ``requirements`` is None; ``user_prompt`` is sent as-is.
        Mode B (batched): used when both are given; ``user_prompt`` is
        ignored and one prompt per batch is built instead.

        Args:
            system_prompt: System message sent with every request.
            user_prompt: User message for mode A.
            iface_summaries: Interface summary dicts for mode B.
            requirements: Requirement texts for mode B.
            project_header: Optional header passed to the batch prompt builder.
            param_constraint: Global parameter constraint; passed through to
                ``PromptBuilder.build_batch_prompt`` in mode B.

        Returns:
            The parsed test cases; an empty list when every attempt failed.
        """
        if iface_summaries is not None and requirements is not None:
            return self._generate_batched(
                system_prompt, iface_summaries, requirements,
                project_header, param_constraint,
            )
        return self._call_with_retry(system_prompt, user_prompt)

    # ── Batched calls ────────────────────────────────────────
    def _generate_batched(
        self,
        system_prompt: str,
        iface_summaries: list[dict],
        requirements: list[str],
        project_header: str,
        param_constraint: Any,
    ) -> list[dict]:
        """Call the model once per batch of interfaces and merge the results.

        A batch whose call fails contributes no cases; the remaining batches
        are still processed.
        """
        total = len(iface_summaries)
        batches = self._make_batches(iface_summaries, self.batch_size)
        all_cases: list[dict] = []

        logger.info(
            f"分批模式:共 {total} 个接口 → "
            f"{len(batches)}× 每批最多 {self.batch_size}"
        )
        if param_constraint and param_constraint.has_param_directive:
            logger.info(
                f"参数约束已启用:{param_constraint}"
            )

        for idx, batch in enumerate(batches, 1):
            names = [i.get("name", "?") for i in batch]
            logger.info(f"{idx}/{len(batches)} 批:{names}")

            user_prompt = self._prompt_builder.build_batch_prompt(
                batch=batch,
                requirements=requirements,
                project_header=project_header,
                param_constraint=param_constraint,
            )
            cases = self._call_with_retry(system_prompt, user_prompt)
            logger.info(f"{idx} 批 → 生成 {len(cases)} 个测试用例")
            all_cases.extend(cases)

            # Pause between batches, but not after the last one.
            if idx < len(batches):
                time.sleep(self.batch_interval)

        logger.info(f"全部批次完成,共生成 {len(all_cases)} 个测试用例")
        return all_cases

    # ── Single call with retries ─────────────────────────────
    def _call_with_retry(
        self,
        system_prompt: str,
        user_prompt: str,
    ) -> list[dict]:
        """Send one chat request, retrying on exceptions and empty results.

        Returns:
            The first non-empty parsed result, or ``[]`` after ``self.retry``
            unsuccessful attempts.
        """
        last_error: Exception | None = None
        for attempt in range(1, self.retry + 1):
            try:
                logger.debug(f"LLM 调用第 {attempt}/{self.retry} 次 …")
                response = self.client.chat.completions.create(
                    model=self.model,
                    max_tokens=self.max_tokens,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                )
                # content may be None; normalise to an empty string.
                raw = response.choices[0].message.content or ""
                logger.debug(f"LLM 响应长度:{len(raw)} 字符")
                cases = self._parser.parse(raw)
                if cases:
                    return cases
                logger.warning(f"{attempt} 次调用返回空结果,准备重试 …")
            except Exception as e:
                # Boundary with the remote service: record and retry.
                last_error = e
                logger.warning(f"{attempt} 次调用失败:{e}")

            if attempt < self.retry:
                time.sleep(self.retry_delay)

        # When no attempt raised, report the empty output instead of "None".
        if last_error is None:
            last_error = "LLM 未返回可解析的测试用例"
        logger.error(
            f"已重试 {self.retry} 次,全部失败。最后一次错误:{last_error}"
        )
        return []

    @staticmethod
    def _make_batches(items: list[Any], size: int) -> list[list[Any]]:
        """Split ``items`` into consecutive chunks of at most ``size`` elements.

        A ``size`` below 1 is treated as 1.
        """
        step = max(1, size)
        return [items[i:i + step] for i in range(0, len(items), step)]

View File

@ -0,0 +1,393 @@
"""
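Test parameter generator.

Responsibility:
    Given an interface parameter's data type and the requested generation
    strategies, build a test data set covering normal, boundary and
    exception values.

Supported data types (aliases are normalised):
    string / str
    integer / int / number
    float / double / decimal
    boolean / bool
    array / list
    object / dict / map

Output format (for injection into the prompt):
    {
        "param_name": {
            "type": "string",
            "groups": [
                {"label": "普通字符串", "value": "hello", "category": "normal"},
                {"label": "空字符串", "value": "", "category": "boundary"},
                {"label": "超长字符串", "value": "a"*1000, "category": "exception"},
                ...
            ]
        }
    }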
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from core.requirement_parser import ParamStrategy
# ══════════════════════════════════════════════════════════════
# 数据结构
# ══════════════════════════════════════════════════════════════
@dataclass
class ParamGroup:
    """One candidate test value for a parameter, tagged with its category."""

    # Human-readable description of the value, e.g. "空字符串".
    label: str
    # The concrete value to pass for the parameter.
    value: Any
    # One of "normal" | "boundary" | "exception" | "equivalence".
    category: str

    def to_dict(self) -> dict:
        """Return the group as a plain dict keyed label / value / category."""
        return {key: getattr(self, key) for key in ("label", "value", "category")}
@dataclass
class ParamDataSet:
    """All test value groups generated for a single parameter."""

    param_name: str
    param_type: str
    groups: list[ParamGroup] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return ``{"type": ..., "groups": [...]}`` with each group as a dict."""
        serialized = []
        for group in self.groups:
            serialized.append(group.to_dict())
        return {"type": self.param_type, "groups": serialized}

    def _values_in(self, category: str) -> list[Any]:
        """Collect, in order, the values of the groups tagged ``category``."""
        return [group.value for group in self.groups if group.category == category]

    @property
    def normal_values(self) -> list[Any]:
        """Values of the groups whose category is "normal"."""
        return self._values_in("normal")

    @property
    def boundary_values(self) -> list[Any]:
        """Values of the groups whose category is "boundary"."""
        return self._values_in("boundary")

    @property
    def exception_values(self) -> list[Any]:
        """Values of the groups whose category is "exception"."""
        return self._values_in("exception")
# ══════════════════════════════════════════════════════════════
# Static test data per type
#
# Each list holds the ParamGroup entries of one category for one type.
# List order matters to ParamGenerator.generate(): duplicates are dropped
# by label keeping the first occurrence, and the EQUIVALENCE strategy
# takes the first three boundary entries.
# ══════════════════════════════════════════════════════════════
# ── string ────────────────────────────────────────────────────
_STRING_NORMAL: list[ParamGroup] = [
    ParamGroup("普通字符串", "hello", "normal"),
    ParamGroup("中文字符串", "测试数据", "normal"),
    ParamGroup("字母数字混合", "user123", "normal"),
]
_STRING_BOUNDARY: list[ParamGroup] = [
    ParamGroup("空字符串", "", "boundary"),
    ParamGroup("单字符", "a", "boundary"),
    ParamGroup("最大长度(255)", "a" * 255, "boundary"),
    ParamGroup("恰好256字符", "a" * 256, "boundary"),
    ParamGroup("含空格", "hello world", "boundary"),
    ParamGroup("首尾空格", " hello ", "boundary"),
]
_STRING_EXCEPTION: list[ParamGroup] = [
    ParamGroup("None/null", None, "exception"),
    ParamGroup("超长字符串", "a" * 1000, "exception"),
    ParamGroup("纯空格", " ", "exception"),
    ParamGroup("特殊字符", "!@#$%^&*()", "exception"),
    ParamGroup("换行符", "line1\nline2", "exception"),
    ParamGroup("SQL注入", "' OR '1'='1", "exception"),
    ParamGroup("XSS注入", "<script>alert(1)</script>", "exception"),
    ParamGroup("Unicode特殊字符","你好\u0000世界", "exception"),
    ParamGroup("数字类型误传", 12345, "exception"),
]
# ── integer ───────────────────────────────────────────────────
_INTEGER_NORMAL: list[ParamGroup] = [
    ParamGroup("典型正整数", 1, "normal"),
    ParamGroup("较大正整数", 100, "normal"),
]
_INTEGER_BOUNDARY: list[ParamGroup] = [
    # NOTE(review): the label for the zero value is an empty string, which
    # looks unintended — confirm.
    ParamGroup("", 0, "boundary"),
    ParamGroup("正边界值1", 1, "boundary"),
    ParamGroup("负边界值-1", -1, "boundary"),
    ParamGroup("最大值(int32)", 2_147_483_647, "boundary"),
    ParamGroup("最小值(int32)", -2_147_483_648, "boundary"),
    ParamGroup("最大值+1", 2_147_483_648, "boundary"),
]
_INTEGER_EXCEPTION: list[ParamGroup] = [
    ParamGroup("None/null", None, "exception"),
    ParamGroup("字符串类型", "abc", "exception"),
    ParamGroup("浮点数", 3.14, "exception"),
    ParamGroup("布尔值true", True, "exception"),
    ParamGroup("空字符串", "", "exception"),
    ParamGroup("超大整数", 10 ** 20, "exception"),
]
# ── float ─────────────────────────────────────────────────────
_FLOAT_NORMAL: list[ParamGroup] = [
    ParamGroup("典型浮点数", 3.14, "normal"),
    ParamGroup("整数形式", 1.0, "normal"),
]
_FLOAT_BOUNDARY: list[ParamGroup] = [
    # NOTE(review): empty label for the zero value here as well — confirm.
    ParamGroup("", 0.0, "boundary"),
    ParamGroup("极小正数", 1e-10, "boundary"),
    ParamGroup("极大正数", 1e10, "boundary"),
    ParamGroup("负数", -3.14, "boundary"),
    ParamGroup("负无穷", float('-inf'), "boundary"),
    ParamGroup("正无穷", float('inf'), "boundary"),
]
_FLOAT_EXCEPTION: list[ParamGroup] = [
    ParamGroup("None/null", None, "exception"),
    ParamGroup("NaN", float('nan'), "exception"),
    ParamGroup("字符串类型", "abc", "exception"),
    ParamGroup("空字符串", "", "exception"),
]
# ── boolean ───────────────────────────────────────────────────
_BOOLEAN_NORMAL: list[ParamGroup] = [
    ParamGroup("true", True, "normal"),
    ParamGroup("false", False, "normal"),
]
_BOOLEAN_BOUNDARY: list[ParamGroup] = [
    ParamGroup("整数1", 1, "boundary"),
    ParamGroup("整数0", 0, "boundary"),
]
_BOOLEAN_EXCEPTION: list[ParamGroup] = [
    ParamGroup("None/null", None, "exception"),
    ParamGroup("字符串'true'", "true", "exception"),
    ParamGroup("字符串'false'", "false", "exception"),
    ParamGroup("字符串'yes'", "yes", "exception"),
    ParamGroup("随机字符串", "abc", "exception"),
]
# ── array ─────────────────────────────────────────────────────
_ARRAY_NORMAL: list[ParamGroup] = [
    ParamGroup("普通数组", [1, 2, 3], "normal"),
    ParamGroup("字符串数组", ["a", "b", "c"], "normal"),
]
_ARRAY_BOUNDARY: list[ParamGroup] = [
    ParamGroup("空数组", [], "boundary"),
    ParamGroup("单元素数组", [1], "boundary"),
    ParamGroup("大数组(100元素)", list(range(100)), "boundary"),
    ParamGroup("含None元素", [1, None, 3], "boundary"),
]
_ARRAY_EXCEPTION: list[ParamGroup] = [
    ParamGroup("None/null", None, "exception"),
    ParamGroup("字符串类型", "not_an_array", "exception"),
    ParamGroup("整数类型", 123, "exception"),
    ParamGroup("嵌套过深", [[[[[1]]]]], "exception"),
]
# ── object ────────────────────────────────────────────────────
_OBJECT_NORMAL: list[ParamGroup] = [
    ParamGroup("完整字段对象", {"id": 1, "name": "test"}, "normal"),
]
_OBJECT_BOUNDARY: list[ParamGroup] = [
    ParamGroup("空对象", {}, "boundary"),
    ParamGroup("仅必填字段", {"id": 1}, "boundary"),
    ParamGroup("含额外字段", {"id": 1, "extra": "x"}, "boundary"),
]
_OBJECT_EXCEPTION: list[ParamGroup] = [
    ParamGroup("None/null", None, "exception"),
    ParamGroup("字符串类型", "not_an_object", "exception"),
    ParamGroup("数组类型", [1, 2, 3], "exception"),
    ParamGroup("字段值类型错误", {"id": "not_int"}, "exception"),
]
# Canonical type name → {category → groups}.
_TYPE_DATA: dict[str, dict[str, list[ParamGroup]]] = {
    "string": {"normal": _STRING_NORMAL, "boundary": _STRING_BOUNDARY, "exception": _STRING_EXCEPTION},
    "integer": {"normal": _INTEGER_NORMAL, "boundary": _INTEGER_BOUNDARY, "exception": _INTEGER_EXCEPTION},
    "float": {"normal": _FLOAT_NORMAL, "boundary": _FLOAT_BOUNDARY, "exception": _FLOAT_EXCEPTION},
    "boolean": {"normal": _BOOLEAN_NORMAL, "boundary": _BOOLEAN_BOUNDARY, "exception": _BOOLEAN_EXCEPTION},
    "array": {"normal": _ARRAY_NORMAL, "boundary": _ARRAY_BOUNDARY, "exception": _ARRAY_EXCEPTION},
    "object": {"normal": _OBJECT_NORMAL, "boundary": _OBJECT_BOUNDARY, "exception": _OBJECT_EXCEPTION},
}
# Alias → canonical type name. Names that are already canonical need no
# entry: ParamGenerator._normalize_type checks _TYPE_DATA for those.
_TYPE_ALIAS: dict[str, str] = {
    "str": "string", "text": "string", "varchar": "string",
    "int": "integer", "long": "integer", "number": "integer",
    "float": "float", "double": "float", "decimal": "float",
    "bool": "boolean",
    "list": "array",
    "dict": "object", "map": "object", "json": "object",
}
# ══════════════════════════════════════════════════════════════
# 参数生成器
# ══════════════════════════════════════════════════════════════
class ParamGenerator:
    """Build per-parameter test data sets from the static tables in this module."""

    def generate(
        self,
        param_name: str,
        param_type: str,
        strategies: list[ParamStrategy],
        min_groups: int = 2,
    ) -> ParamDataSet:
        """Generate the test data set for a single parameter.

        Normal values are always included; further groups are appended per
        requested strategy, duplicates (same label) are dropped, and the
        result is topped up from the boundary and exception tables while it
        holds fewer than ``min_groups`` groups.

        Args:
            param_name: Parameter name.
            param_type: Data type; aliases are accepted and unknown types
                are treated as "string".
            strategies: Strategies to cover. Strategies other than
                BOUNDARY / EXCEPTION / EQUIVALENCE / RANDOM add nothing.
            min_groups: Minimum number of groups to aim for.

        Returns:
            A ParamDataSet carrying the normalised type name.
        """
        norm_type = self._normalize_type(param_type)
        type_data = _TYPE_DATA.get(norm_type, _TYPE_DATA["string"])

        # Normal values always come first.
        groups: list[ParamGroup] = list(type_data["normal"])

        for strategy in strategies:
            if strategy == ParamStrategy.BOUNDARY:
                groups.extend(type_data["boundary"])
            elif strategy == ParamStrategy.EXCEPTION:
                groups.extend(type_data["exception"])
            elif strategy == ParamStrategy.EQUIVALENCE:
                # Equivalence classes: normal values plus the first few
                # boundary values.
                groups.extend(type_data["boundary"][:3])
            elif strategy == ParamStrategy.RANDOM:
                groups.extend(self._random_groups(norm_type))

        # De-duplicate by label, keeping the first occurrence.
        seen_labels: set[str] = set()
        unique_groups: list[ParamGroup] = []
        for g in groups:
            if g.label not in seen_labels:
                seen_labels.add(g.label)
                unique_groups.append(g)

        # Top up to min_groups from boundary, then exception values.
        if len(unique_groups) < min_groups:
            for g in type_data["boundary"] + type_data["exception"]:
                if len(unique_groups) >= min_groups:
                    break
                if g.label not in seen_labels:
                    seen_labels.add(g.label)
                    unique_groups.append(g)

        return ParamDataSet(
            param_name=param_name,
            param_type=norm_type,
            groups=unique_groups,
        )

    def generate_for_interface(
        self,
        interface_summary: dict,
        strategies: list[ParamStrategy],
        min_groups: int = 2,
    ) -> dict[str, ParamDataSet]:
        """Generate data sets for every input parameter of one interface.

        Args:
            interface_summary: One interface dict; its "params" entry is read
                as a mapping of parameter name to either an info dict (keys
                "type" and "inout") or a bare type value.
                NOTE(review): the exact schema comes from the interface
                parser, which is not visible here — confirm.
            strategies: Strategies to cover.
            min_groups: Minimum number of groups per parameter.

        Returns:
            ``{param_name: ParamDataSet}`` for parameters whose direction is
            "in" or "inout".
        """
        result: dict[str, ParamDataSet] = {}
        # `or {}` also covers an explicit null "params" entry.
        params = interface_summary.get("params") or {}

        for param_name, param_info in params.items():
            if isinstance(param_info, dict):
                param_type = param_info.get("type", "string")
                inout = param_info.get("inout", "in")
            else:
                param_type = str(param_info)
                inout = "in"

            # Only input parameters get test data.
            if inout not in ("in", "inout"):
                continue

            result[param_name] = self.generate(
                param_name=param_name,
                param_type=param_type,
                strategies=strategies,
                min_groups=min_groups,
            )
        return result

    # ── Helpers ──────────────────────────────────────────────
    @staticmethod
    def _normalize_type(t: str) -> str:
        """Map a type name or alias to a key of ``_TYPE_DATA``.

        Matching ignores case and surrounding whitespace. ``None``,
        non-string values and unknown names fall back to "string".
        """
        key = "" if t is None else str(t).lower().strip()
        return _TYPE_ALIAS.get(key, key if key in _TYPE_DATA else "string")

    @staticmethod
    def _random_groups(norm_type: str) -> list[ParamGroup]:
        """Return one random value for string / integer / float, else nothing.

        The generated group is tagged with category "equivalence" and its
        label embeds the value.
        """
        import random
        import string

        if norm_type == "string":
            val = ''.join(random.choices(string.ascii_letters, k=8))
            return [ParamGroup(f"随机字符串({val})", val, "equivalence")]
        if norm_type == "integer":
            val = random.randint(-1000, 1000)
            return [ParamGroup(f"随机整数({val})", val, "equivalence")]
        if norm_type == "float":
            val = round(random.uniform(-100, 100), 4)
            return [ParamGroup(f"随机浮点数({val})", val, "equivalence")]
        return []

    def to_prompt_text(
        self,
        datasets: dict[str, ParamDataSet],
        min_groups: int,
    ) -> str:
        """Render data sets as the Chinese markdown text injected into a prompt.

        Returns an empty string when ``datasets`` is empty. Categories
        without a known display name are shown verbatim.
        """
        if not datasets:
            return ""

        # Built once instead of once per group.
        category_names = {
            "normal": "正常值",
            "boundary": "边界值",
            "exception": "异常值",
            "equivalence": "等价类",
        }

        lines = [
            f"### 参数测试数据集(每个参数至少覆盖 {min_groups} 组)",
            "",
        ]
        for param_name, ds in datasets.items():
            lines.append(f"**参数:{param_name}**(类型:{ds.param_type}")
            for g in ds.groups:
                category_label = category_names.get(g.category, g.category)
                lines.append(
                    f" - [{category_label}] {g.label}`{repr(g.value)}`"
                )
            lines.append("")

        lines.append(
            f"> 请从以上数据集中选取参数值组合,"
            f"确保生成的测试用例总数不少于 **{min_groups} 组**"
            f"并覆盖正常值、边界值和异常值场景。"
        )
        return "\n".join(lines)

View File

@ -1,167 +1,332 @@
"""
提示词构建器
升级点
- 传递项目 description
- function 协议source_file(.py) module_path from <module> import <func>
- HTTP 协议full_url = url(base) + name(path)
- parameters 统一字段inout 区分输入/输出
提示词构建器中文版
变更说明
- 新增 build_param_directive_section()将参数约束注入 System Prompt
- 新增 build_param_dataset_section()将参数数据集注入 User Prompt
- build_batch_prompt / build_user_prompt 支持传入参数数据集
"""
import json
from core.parser import InterfaceInfo, InterfaceParser
from core.requirement_parser import ParamConstraint, ParamStrategy
from core.param_generator import ParamGenerator, ParamDataSet
SYSTEM_PROMPT = """
You are a senior software test engineer. Generate test cases based on the provided
interface descriptions and test requirements.
# ══════════════════════════════════════════════════════════════
# System Prompt中文
# ══════════════════════════════════════════════════════════════
_SYSTEM_PROMPT_BASE = """
你是一名资深软件测试工程师擅长根据接口描述和测试需求生成高质量的测试用例
OUTPUT FORMAT
Return ONLY a valid JSON array. No markdown fences. No extra text.
输出格式要求
Each element = ONE test case:
- 只输出一个合法的 JSON 数组不要包含任何 markdown 代码块```
- 不要在 JSON 前后添加任何解释性文字
- 每个数组元素代表一个测试用例结构如下
{
"test_id" : "unique id, e.g. TC_001",
"test_name" : "short descriptive name",
"description" : "what this test verifies",
"requirement" : "the original requirement text this case maps to",
"requirement_id" : "e.g. REQ.01",
"test_id" : "唯一编号,如 TC_001",
"test_name" : "简短的测试名称",
"description" : "本用例验证的内容",
"requirement" : "对应的原始需求文本",
"requirement_id" : "需求编号,如 REQ.01(从接口描述中的 requirement_id 字段获取)",
"steps": [
{
"step_no" : 1,
"interface_name" : "function or endpoint name",
"protocol" : "http or function",
"url" : "source_file (function) or full_url (http)",
"purpose" : "why this step is needed",
"step_no" : 1,
"interface_name" : "接口名称或函数名",
"protocol" : "http 或 function",
"url" : "function 协议填 source_filehttp 协议填 full_url",
"purpose" : "本步骤的目的",
"input": {
// Only parameters with inout = "in" or "inout"
"param_name": <value>
"参数名": "参数值"
},
"use_output_of": { // optional
"use_output_of": {
"step_no" : 1,
"field" : "user_id",
"as_param" : "user_id"
"field" : "上一步返回值中的字段名",
"as_param" : "作为本步骤的参数名"
},
"assertions": [
{
"field" : "field_name or 'return' or 'exception'",
"operator" : "eq|ne|gt|lt|gte|lte|in|not_null|contains|raised|not_raised",
"expected" : <value>,
"message" : "human readable description"
"field" : "断言的字段名,或 'return'(整体返回值)或 'exception'(异常)",
"operator" : "eq | ne | gt | lt | gte | lte | in | not_null | contains | raised | not_raised",
"expected" : "期望值",
"message" : "断言说明"
}
]
}
],
"test_data_notes" : "explanation of auto-generated test data",
"test_code" : "<complete Python script — see rules below>"
"test_data_notes" : "测试数据说明(说明本用例使用了哪类参数:正常值/边界值/异常值)",
"test_code" : "完整可运行的 Python 测试脚本(见下方规则)"
}
TEST CODE RULES
测试代码编写规则
1. Complete, runnable Python script. No external test framework.
2. Allowed imports: standard library + `requests` + the actual module under test.
1. test_code 必须是完整可直接运行的 Python 脚本
2. 不得使用 unittest pytest 框架
3. 只允许导入Python 标准库requests被测模块
FUNCTION INTERFACES
3. Each function interface has:
"source_file" : e.g. "create_user.py"
"module_path" : e.g. "create_user" (derived from source_file)
4. Import the REAL function using module_path:
function 协议接口
4. 使用 module_path 字段导入真实函数若导入失败则使用桩函数
try:
from <module_path> import <function_name>
except ImportError:
# Stub fallback — simulate on_success for positive tests,
# on_failure / raise Exception for negative tests
def <function_name>(<in_params>):
return <on_success_value> # or raise Exception(...)
5. Call the function with ONLY "in"/"inout" parameters.
6. Capture the return value; assert against on_success or on_failure structure.
7. If on_failure.value contains "exception" or "raises":
- In negative tests: wrap call in try/except, assert exception IS raised.
- Use field="exception", operator="raised", expected="Exception".
def <function_name>(<入参列表>):
return <on_success_value> # 桩函数,仅供结构验证
5. 调用函数时只传 inout "in" "inout" 的参数
6. on_failure 包含 "exception"负向测试需用 try/except 捕获异常
HTTP INTERFACES
8. Each HTTP interface has:
"full_url" : complete URL, e.g. "http://127.0.0.1/api/delete_user"
"method" : "get" | "post" | "put" | "delete" etc.
9. Send request:
http 协议接口
7. 使用 full_url 字段作为请求地址
resp = requests.<method>("<full_url>", json=<input_dict>)
10. Assert on resp.status_code and resp.json() fields.
8. 断言 resp.status_code resp.json() 中的字段
MULTI-STEP
11. Execute steps in order; extract fields from previous step's return value
and pass to next step via use_output_of mapping.
STRUCTURED OUTPUT (REQUIRED)
12. After EACH step, print:
##STEP_RESULT## {"step_no":<n>,"interface_name":"...","status":"PASS|FAIL",
"assertions":[{"field":"...","operator":"...","expected":...,
"actual":...,"passed":true|false,"message":"..."}]}
13. Final line must be:
PASS: <summary> or FAIL: <summary>
14. Wrap entire test body in try/except; print FAIL on any unhandled exception.
15. Do NOT use unittest or pytest.
结构化输出必须遵守
9. 每个步骤执行后必须在标准输出打印如下格式的一行
##STEP_RESULT## {"step_no":<n>,"interface_name":"...","status":"PASS 或 FAIL",
"assertions":[{"field":"...","operator":"...","expected":...,
"actual":...,"passed":true false,"message":"..."}]}
10. 脚本最后一行必须是
PASS: <摘要> FAIL: <摘要>
11. 整个脚本主体用 try/except 包裹捕获未处理异常时打印 FAIL
ASSERTION RULES BY RETURN TYPE
断言编写规则
return.type = "dict" assert each key in on_success.value (positive)
or on_failure.value (negative)
return.type = "boolean" field="return", operator="eq", expected=true/false
return.type = "integer" field="return", operator="eq"/"gt"/"lt"
on_failure = exception field="exception", operator="raised"
- 返回值为 dict on_success / on_failure 中每个 key 单独断言
- 返回值为 bool field="return"operator="eq"expected=true false
- 返回值为 int field="return"operator="eq" / "gt" / "lt"
- 预期抛出异常 field="exception"operator="raised"
- 预期不抛出异常 field="exception"operator="not_raised"
COVERAGE GUIDELINES
覆盖度要求
- Per requirement: at least 1 positive + 1 negative test case.
- Negative test_name/description MUST contain "negative" or "invalid".
- Multi-interface requirements: single multi-step test case.
- Cover ALL "in"/"inout" parameters (including optional ones).
- Assert ALL fields described in on_success (positive) and on_failure (negative).
- For "out" parameters: verify their values in assertions after the call.
- 每条需求至少生成 1 个正向用例 + 1 个负向用例
- 负向用例的 test_name 必须包含"负向""无效"
- 覆盖接口的所有 inout "in" "inout" 的参数
- 正向用例断言 on_success 的所有字段
- 负向用例断言 on_failure 的所有字段
"""
# Parameter-generation rules section, appended to the system prompt when a
# parameter constraint is present. The {directives} placeholder is filled
# through str.format() in _build_param_directive_section().
_PARAM_DIRECTIVE_TEMPLATE = """
参数生成规则本次任务特殊要求
{directives}
- test_data_notes 字段必须说明本用例使用的参数类别正常值/边界值/异常值及具体值
- 每个测试用例的 input 字段中的参数值必须从下方"参数测试数据集"中选取
- 禁止在 input 中使用未在数据集中出现的随机值
"""
# ══════════════════════════════════════════════════════════════
# PromptBuilder
# ══════════════════════════════════════════════════════════════
class PromptBuilder:
def get_system_prompt(self) -> str:
return SYSTEM_PROMPT.strip()
def __init__(self):
    """Create the builder together with the ParamGenerator it uses for dataset sections."""
    self._param_gen = ParamGenerator()
# ── System Prompt ────────────────────────────────────────
def get_system_prompt(
self,
global_constraint: ParamConstraint = None,
) -> str:
"""
构建 System Prompt
若存在全局参数约束追加参数生成规则段落
"""
base = _SYSTEM_PROMPT_BASE.strip()
if global_constraint and global_constraint.has_param_directive:
directive_section = self._build_param_directive_section(
global_constraint
)
return base + "\n" + directive_section.strip()
return base
def _build_param_directive_section(
self,
constraint: ParamConstraint,
) -> str:
"""构建参数生成规则段落,注入 System Prompt。"""
lines: list[str] = []
if constraint.min_groups > 2:
lines.append(
f"- 每个接口的测试用例总数不少于 **{constraint.min_groups} 组**"
f"(正向 + 负向合计)"
)
strategy_labels = {
ParamStrategy.BOUNDARY: "边界值(空值、最大值、最小值、临界值)",
ParamStrategy.EXCEPTION: "异常值None、类型错误、超长、特殊字符等",
ParamStrategy.EQUIVALENCE: "等价类(有效等价类 + 无效等价类)",
ParamStrategy.RANDOM: "随机值(覆盖典型随机场景)",
}
if constraint.strategies:
strategy_str = "".join(
strategy_labels[s]
for s in constraint.strategies
if s in strategy_labels
)
lines.append(f"- 参数值必须覆盖以下类型:{strategy_str}")
if not lines:
return ""
directives = "\n ".join(lines)
return _PARAM_DIRECTIVE_TEMPLATE.format(directives=directives)
# ── 项目信息头 ───────────────────────────────────────────
def build_project_header(
    self,
    project: str = "",
    project_desc: str = "",
) -> str:
    """Render the project info header, or "" when both fields are empty.

    The header is a title line followed by one line per non-empty field,
    project name first, joined with newlines.
    """
    entries: list[str] = []
    if project:
        entries.append(f"项目名称:{project}")
    if project_desc:
        entries.append(f"项目描述:{project_desc}")

    # No field given: no header at all, not even the title.
    if not entries:
        return ""
    return "\n".join(["## 项目信息", *entries])
# ── 参数数据集段落 ───────────────────────────────────────
def build_param_dataset_section(
self,
iface_summaries: list[dict],
constraint: ParamConstraint,
) -> str:
"""
为批次内所有接口生成参数数据集返回注入 user_prompt 的文本段落
"""
if not constraint.has_param_directive:
return ""
all_lines: list[str] = [
"## 参数测试数据集",
f"> 以下数据集由系统自动生成,请从中选取参数值组合,"
f"确保每个接口的测试用例总数不少于 **{constraint.min_groups} 组**。",
"",
]
for iface in iface_summaries:
name = iface.get("name", "未知接口")
datasets = self._param_gen.generate_for_interface(
interface_summary=iface,
strategies=constraint.strategies,
min_groups=constraint.min_groups,
)
if not datasets:
continue
all_lines.append(f"### 接口:{name}")
text = self._param_gen.to_prompt_text(datasets, constraint.min_groups)
all_lines.append(text)
all_lines.append("")
return "\n".join(all_lines)
# ── 分批 user_prompt ────────────────────────────────────
def build_batch_prompt(
self,
batch: list[dict],
requirements: list[str],
project_header: str = "",
param_constraint: ParamConstraint = None,
) -> str:
"""
构建单批次的 user_prompt
若存在参数约束自动注入参数数据集
"""
req_lines = "\n".join(
f"{i + 1}. {r}" for i, r in enumerate(requirements)
)
parts: list[str] = []
if project_header:
parts.append(project_header)
parts.append(
"## 本批次接口描述\n"
+ json.dumps(batch, ensure_ascii=False, indent=2)
)
parts.append("## 测试需求\n" + req_lines)
# 注入参数数据集
if param_constraint and param_constraint.has_param_directive:
dataset_section = self.build_param_dataset_section(
iface_summaries=batch,
constraint=param_constraint,
)
if dataset_section:
parts.append(dataset_section)
parts.append(
"请根据以上接口描述、测试需求和参数数据集,生成完整的测试用例(包含正向和负向)。\n"
"- function 协议:使用 module_path 字段导入被测函数\n"
"- http 协议:使用 full_url 字段作为请求地址\n"
"- 每条需求的 requirement_id 须填入对应用例的 requirement_id 字段\n"
"- test_data_notes 须说明本用例使用的参数类别和具体值"
)
return "\n\n".join(parts)
# ── 单次 user_prompt ────────────────────────────────────
def build_user_prompt(
self,
requirements: list[str],
interfaces: list[InterfaceInfo],
parser: InterfaceParser,
project: str = "",
project_desc: str = "",
requirements: list[str],
interfaces: list[InterfaceInfo],
parser: InterfaceParser,
project: str = "",
project_desc: str = "",
param_constraint: ParamConstraint = None,
) -> str:
"""
单次调用模式接口数 batch_size 时使用
"""
iface_summary = parser.to_summary_dict(interfaces)
req_lines = "\n".join(
f"{i + 1}. {r}" for i, r in enumerate(requirements)
)
header = self.build_project_header(project, project_desc)
# 项目信息头
project_section = ""
if project or project_desc:
project_section = "## Project\n"
if project:
project_section += f"Name : {project}\n"
if project_desc:
project_section += f"Description: {project_desc}\n"
project_section += "\n"
parts: list[str] = []
if header:
parts.append(header)
parts.append(
"## 接口描述\n"
+ json.dumps(iface_summary, ensure_ascii=False, indent=2)
)
parts.append("## 测试需求\n" + req_lines)
return (
f"{project_section}"
f"## Available Interfaces\n"
f"{json.dumps(iface_summary, ensure_ascii=False, indent=2)}\n\n"
f"## Test Requirements\n"
f"Generate comprehensive test cases (positive + negative) for every interface.\n"
f"- The generated test cases for each interface must meet the following requirements: {req_lines}\n"
f"- For function interfaces: use 'module_path' for import, "
f"'source_file' for reference.\n"
f"- For HTTP interfaces: use 'full_url' as the request URL.\n"
f"- Use requirement_id from the interface to populate the test case's "
f"requirement_id field.\n"
f"- Chain multiple interface calls in one test case when a requirement "
f"involves more than one interface.\n"
).strip()
if param_constraint and param_constraint.has_param_directive:
dataset_section = self.build_param_dataset_section(
iface_summaries=iface_summary
if isinstance(iface_summary, list) else list(iface_summary.values()),
constraint=param_constraint,
)
if dataset_section:
parts.append(dataset_section)
parts.append(
"请根据以上接口描述、测试需求和参数数据集,生成完整的测试用例(包含正向和负向)。\n"
"- function 协议:使用 module_path 字段导入被测函数\n"
"- http 协议:使用 full_url 字段作为请求地址\n"
"- 每条需求的 requirement_id 须填入对应用例的 requirement_id 字段\n"
"- test_data_notes 须说明本用例使用的参数类别和具体值"
)
return "\n\n".join(parts)

View File

@ -0,0 +1,211 @@
"""
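Requirement parser.

Responsibility:
    Detect and extract parameter-generation directives from the test
    requirement text written by the user:
    - group-count constraints, e.g. "不少于10组", "至少5组", "生成8组"
    - strategy constraints, e.g. "边界值", "异常值", "等价类", "随机值"
    - data-type constraints, e.g. "字符串边界值", "整数异常值"

The parse result is consumed by ParamGenerator and PromptBuilder.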
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from enum import Enum
# ══════════════════════════════════════════════════════════════
# 枚举:参数生成策略
# ══════════════════════════════════════════════════════════════
class ParamStrategy(str, Enum):
    """Parameter-generation strategies; members are also plain ``str`` values."""
    BOUNDARY = "boundary"  # boundary values
    EXCEPTION = "exception"  # exception / error values
    EQUIVALENCE = "equivalence"  # equivalence classes
    RANDOM = "random"  # random values
    CUSTOM = "custom"  # custom (stated directly in the requirement by the user)
# ══════════════════════════════════════════════════════════════
# 数据结构:解析结果
# ══════════════════════════════════════════════════════════════
@dataclass
class ParamConstraint:
    """Parameter-generation constraint parsed from one requirement text."""

    # Minimum number of parameter groups (positive + negative combined).
    min_groups: int = 2
    # Strategies that must be covered.
    strategies: list[ParamStrategy] = field(default_factory=list)
    # Requirement text with the parameter directive removed, leaving the
    # business requirement.
    clean_requirement: str = ""
    # True when the requirement explicitly stated parameter rules; tells a
    # user-specified constraint apart from the defaults.
    has_param_directive: bool = False

    def __str__(self) -> str:
        """Return a short Chinese summary of the constraint for log output.

        The group count is listed only when above the default of 2, followed
        by one label per strategy. Parts are separated by "、". A strategy
        without an entry in the label table is shown by its string form
        instead of raising.
        """
        parts = []
        if self.min_groups > 2:
            parts.append(f"至少 {self.min_groups} 组参数")
        parts.extend(_STRATEGY_LABEL.get(s, str(s)) for s in self.strategies)
        return "、".join(parts) if parts else "默认1正1负"
# Short display label per strategy, used by ParamConstraint.__str__.
_STRATEGY_LABEL: dict[ParamStrategy, str] = {
    ParamStrategy.BOUNDARY: "边界值",
    ParamStrategy.EXCEPTION: "异常值",
    ParamStrategy.EQUIVALENCE: "等价类",
    ParamStrategy.RANDOM: "随机值",
    ParamStrategy.CUSTOM: "自定义",
}
# ══════════════════════════════════════════════════════════════
# Parsing rules
# ══════════════════════════════════════════════════════════════
# Group-count constraint. Two alternatives:
#   group 1: a lead-in word followed by "N组", e.g. "不少于N组", "至少N组",
#            "最少N组", "生成N组"
#   group 2: "N组" followed by 以上 / 及以上 / 或以上
_RE_MIN_GROUPS = re.compile(
    r'(?:不少于|至少|最少|生成|共|需要|要求)\s*(\d+)\s*组'
    r'|(\d+)\s*组(?:以上|及以上|或以上)',
    re.UNICODE,
)
# Strategy keyword table: (keywords, strategy). A strategy is selected when
# any of its keywords occurs as a substring of the requirement text.
_STRATEGY_KEYWORDS: list[tuple[list[str], ParamStrategy]] = [
    (["边界值", "边界", "boundary", "边界测试"], ParamStrategy.BOUNDARY),
    (["异常值", "异常", "错误值", "非法值",
      "exception", "invalid", "error"], ParamStrategy.EXCEPTION),
    (["等价类", "等价", "equivalence"], ParamStrategy.EQUIVALENCE),
    (["随机", "随机值", "random"], ParamStrategy.RANDOM),
]
# Overall directive detection: a data noun (参数 / 测试参数 / 测试数据 /
# 数据 / 入参) followed, anywhere later on the same line, by a directive
# word or by "N组". A strategy keyword alone, without such a noun before
# it, does not match this pattern.
_RE_PARAM_DIRECTIVE = re.compile(
    r'(?:参数|测试参数|测试数据|数据|入参)'
    r'.*?(?:不少于|至少|最少|生成|边界|异常|等价|随机|\d+\s*组)',
    re.UNICODE,
)
# ══════════════════════════════════════════════════════════════
# 解析器
# ══════════════════════════════════════════════════════════════
class RequirementParser:
    """Parse test requirements and extract parameter-generation constraints.

    ``parse`` returns three values:
      - clean requirements: each requirement with its parameter-directive
        clauses removed (business description only)
      - constraints: one ParamConstraint per requirement
      - global constraint: merged from requirements that are worded as
        applying to everything ("全局", "所有接口", ...)
    """

    def parse(
        self,
        requirements: list[str],
    ) -> tuple[list[str], list[ParamConstraint], ParamConstraint]:
        """Parse every requirement.

        Args:
            requirements: Raw requirement texts, one entry per requirement.

        Returns:
            (clean_requirements, per_req_constraints, global_constraint)
        """
        clean_reqs: list[str] = []
        constraints: list[ParamConstraint] = []
        global_c = ParamConstraint()
        for req in requirements:
            c = self._parse_one(req)
            constraints.append(c)
            clean_reqs.append(c.clean_requirement)
            # A requirement worded globally (e.g. "所有接口参数不少于10组") also
            # contributes its constraint to the global constraint.
            if self._is_global(req) and c.has_param_directive:
                global_c = self._merge(global_c, c)
        return clean_reqs, constraints, global_c

    # ── Single requirement ────────────────────────────────────
    def _parse_one(self, req: str) -> ParamConstraint:
        """Build the ParamConstraint for one requirement text."""
        c = ParamConstraint(clean_requirement=req)
        # 1. Without a parameter directive the defaults apply unchanged.
        if not self._has_directive(req):
            return c
        c.has_param_directive = True
        # 2. Quantity constraint: the number sits in group 1 or group 2
        #    depending on which alternative of the pattern matched.
        m = _RE_MIN_GROUPS.search(req)
        if m:
            n = int(m.group(1) or m.group(2))
            c.min_groups = max(n, 2)  # never drop below one positive + one negative
        # 3. Strategy constraint, collected in the declaration order of
        #    _STRATEGY_KEYWORDS so the result is identical on every run
        #    (a set of enum members has no stable iteration order).
        matched = [
            strategy
            for keywords, strategy in _STRATEGY_KEYWORDS
            if any(kw in req for kw in keywords)
        ]
        c.strategies = list(dict.fromkeys(matched))
        # 4. No strategy named: cover boundary and exception values by default.
        if not c.strategies:
            c.strategies = [ParamStrategy.BOUNDARY, ParamStrategy.EXCEPTION]
        # 5. Remove the directive clauses and keep the business description.
        c.clean_requirement = self._strip_directive(req)
        return c

    # ── Helpers ───────────────────────────────────────────────
    @staticmethod
    def _has_directive(req: str) -> bool:
        """Return True when the text contains a parameter-generation directive."""
        return bool(_RE_PARAM_DIRECTIVE.search(req)) or bool(
            _RE_MIN_GROUPS.search(req)
        )

    @staticmethod
    def _is_global(req: str) -> bool:
        """Return True when the text is worded as applying to all interfaces/tests."""
        return any(
            kw in req
            for kw in ("所有接口", "全部接口", "全局", "所有测试", "全部测试")
        )

    @staticmethod
    def _strip_directive(req: str) -> str:
        """Remove directive clauses from a requirement, keeping the business text.

        The text is split into clauses on commas, semicolons and parentheses
        (ASCII and full-width forms), on 【】 and on newlines. Clauses that
        contain a parameter directive are dropped; the remaining clauses are
        joined with a comma. If nothing remains, the original text is returned.
        """
        clauses = re.split(r'[,;()【】\n,;()]', req)
        kept: list[str] = []
        for clause in clauses:
            clause = clause.strip()
            if not clause:
                continue
            # Skip clauses that carry a parameter directive.
            if _RE_PARAM_DIRECTIVE.search(clause) or _RE_MIN_GROUPS.search(clause):
                continue
            kept.append(clause)
        # A delimiter keeps unrelated clauses from being fused into one word.
        result = ",".join(kept)
        return result if result else req

    @staticmethod
    def _merge(base: ParamConstraint, other: ParamConstraint) -> ParamConstraint:
        """Merge two constraints: larger group count, ordered union of strategies.

        The merged constraint does not carry a clean_requirement.
        """
        return ParamConstraint(
            min_groups=max(base.min_groups, other.min_groups),
            # dict.fromkeys de-duplicates while keeping first-seen order.
            strategies=list(dict.fromkeys([*base.strategies, *other.strategies])),
            has_param_directive=True,
        )

View File

@ -1,168 +1,278 @@
import os
import sys
"""
测试执行器
大规模支持改造
1. 并行执行ThreadPoolExecutor默认 8 个并发
2. 超时控制单个用例超时不阻塞整体
3. 步骤级结果解析##STEP_RESULT## 行
4. 进度显示实时打印完成数
5. 结果持久化保存 run_results.json
"""
from __future__ import annotations
import json
import subprocess
import time
import logging
from pathlib import Path
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from config import config
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Marker that starts a structured step-result line in a test script's stdout;
# the rest of such a line is a JSON payload.
STEP_RESULT_PREFIX = "##STEP_RESULT##"
# ══════════════════════════════════════════════════════════════
# 数据结构
# ══════════════════════════════════════════════════════════════
@dataclass
class AssertionResult:
    """Outcome of one assertion reported inside a step-result line."""

    # Name of the checked field. NOTE: this attribute shares its name with the
    # dataclasses.field helper imported by this module; it is annotation-only
    # here, so the helper is not rebound inside the class body.
    field: str
    # Comparison operator as reported by the test script.
    operator: str
    # Expected and actual values exactly as decoded from the JSON payload.
    expected: object
    actual: object
    # Whether the assertion held.
    passed: bool
    # Optional explanatory text; empty when the script supplied none.
    message: str = ""
@dataclass
class StepResult:
    """Result of one step of a test script, parsed from a step-result line.

    Attributes:
        step_no: Step number reported by the script.
        interface_name: Name of the interface exercised by the step.
        status: "PASS" or "FAIL".
        assertions: Assertion outcomes recorded for the step.
    """

    step_no: int
    interface_name: str
    status: str  # "PASS" | "FAIL"
    # default_factory gives every instance its own list.
    assertions: list[AssertionResult] = field(default_factory=list)
@dataclass
class TestResult:
    """Result of running one generated test script.

    Attributes:
        test_id: Identifier of the test (the script's file stem).
        file_path: Path of the executed script.
        status: "PASS", "FAIL", "ERROR" or "TIMEOUT".
        message: Short explanation of the status.
        duration: Wall-clock run time in seconds.
        step_results: Per-step results parsed from the script's stdout.
        stdout: Captured standard output of the script.
        stderr: Captured standard error of the script.
    """

    test_id: str
    file_path: str
    status: str  # "PASS" | "FAIL" | "ERROR" | "TIMEOUT"
    message: str = ""
    duration: float = 0.0
    # default_factory gives every instance its own list.
    step_results: list[StepResult] = field(default_factory=list)
    stdout: str = ""
    stderr: str = ""
# ══════════════════════════════════════════════════════════════
# 执行器
# ══════════════════════════════════════════════════════════════
class TestRunner:
    """Run generated test scripts as subprocesses and report the results.

    Scripts are executed concurrently on a thread pool. Each script runs in
    its own subprocess with a timeout, and the structured step lines found in
    its stdout are parsed into StepResult objects.
    """

    def __init__(self):
        # Each setting falls back to a default when config does not define it.
        self.timeout = getattr(config, "TEST_TIMEOUT", 60)
        self.max_workers = getattr(config, "TEST_MAX_WORKERS", 8)
        self.python_bin = getattr(config, "TEST_PYTHON_BIN", sys.executable)

    # ── Entry point ───────────────────────────────────────────
    def run_all(self, test_files: list[Path]) -> list[TestResult]:
        """Run all test files in parallel and return their results.

        Args:
            test_files: Paths of the test scripts to execute.

        Returns:
            One TestResult per file, sorted by file path so the order is
            stable regardless of which script finishes first. An empty input
            yields an empty list.
        """
        if not test_files:
            logger.warning("No test files to run.")
            return []
        total = len(test_files)
        results: list[TestResult] = []
        logger.info(
            f"Running {total} test(s) with "
            f"max_workers={self.max_workers}, timeout={self.timeout}s"
        )
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(self._run_one, f) for f in test_files]
            # Report progress in completion order.
            for done, future in enumerate(as_completed(futures), 1):
                result = future.result()
                results.append(result)
                icon = "✅" if result.status == "PASS" else "❌"
                logger.info(
                    f" [{done}/{total}] {icon} {result.test_id} "
                    f"({result.status}) {result.duration:.2f}s"
                )
        # Sort by file path to keep the output stable.
        results.sort(key=lambda r: r.file_path)
        return results

    # ── Run a single test file ────────────────────────────────
    def _run_one(self, file_path: Path) -> TestResult:
        """Execute one script and convert the outcome into a TestResult.

        A timeout produces status "TIMEOUT"; any other failure to run the
        script produces status "ERROR" with the exception text as message.
        """
        test_id = file_path.stem
        start = time.monotonic()
        try:
            proc = subprocess.run(
                [self.python_bin, str(file_path)],
                capture_output=True,
                text=True,
                timeout=self.timeout,
            )
            duration = time.monotonic() - start
            return self._parse_output(test_id, str(file_path), proc, duration)
        except subprocess.TimeoutExpired:
            duration = time.monotonic() - start
            logger.warning(f" TIMEOUT: {test_id} ({duration:.1f}s)")
            return TestResult(
                test_id=test_id,
                file_path=str(file_path),
                status="TIMEOUT",
                message=f"Exceeded {self.timeout}s timeout",
                duration=duration,
            )
        except Exception as e:
            # Boundary handler: one broken script must not abort the whole run.
            duration = time.monotonic() - start
            logger.error(f" ERROR running {test_id}: {e}")
            return TestResult(
                test_id=test_id,
                file_path=str(file_path),
                status="ERROR",
                message=str(e),
                duration=duration,
            )

    # ── Output parsing ────────────────────────────────────────
    def _parse_output(
        self,
        test_id: str,
        file_path: str,
        proc: subprocess.CompletedProcess,
        duration: float,
    ) -> TestResult:
        """Derive the overall status from a finished process.

        The last non-blank line of stdout decides the status when it starts
        with "PASS" or "FAIL". Otherwise a non-zero exit code means "FAIL",
        and a zero exit code without a verdict line means "ERROR".
        """
        stdout = proc.stdout or ""
        stderr = proc.stderr or ""
        step_results = self._parse_step_results(stdout)
        stripped = stdout.strip()
        # Strip the line itself as well, so an indented verdict is still recognised.
        last_line = stripped.splitlines()[-1].strip() if stripped else ""
        if last_line.startswith("PASS"):
            status, message = "PASS", last_line
        elif last_line.startswith("FAIL"):
            status, message = "FAIL", last_line
        elif proc.returncode != 0:
            status = "FAIL"
            message = f"Exit code {proc.returncode}: {stderr[:200]}"
        else:
            status = "ERROR"
            message = "No PASS/FAIL line found in output"
        return TestResult(
            test_id=test_id,
            file_path=file_path,
            status=status,
            message=message,
            duration=duration,
            step_results=step_results,
            stdout=stdout,
            stderr=stderr,
        )

    def _parse_step_results(self, stdout: str) -> list[StepResult]:
        """Parse the structured step lines of a script's stdout.

        A step line has the form ``##STEP_RESULT## <json>``. Lines whose
        payload is not valid JSON, or is not shaped like a step object, are
        logged at DEBUG level and skipped.
        """
        results: list[StepResult] = []
        for line in stdout.splitlines():
            line = line.strip()
            if not line.startswith(STEP_RESULT_PREFIX):
                continue
            try:
                raw = json.loads(line[len(STEP_RESULT_PREFIX):].strip())
                assertions = [
                    AssertionResult(
                        field=a.get("field", ""),
                        operator=a.get("operator", ""),
                        expected=a.get("expected"),
                        actual=a.get("actual"),
                        passed=bool(a.get("passed", False)),
                        message=a.get("message", ""),
                    )
                    # "assertions": null is treated like a missing key.
                    for a in raw.get("assertions") or []
                ]
                results.append(StepResult(
                    step_no=raw.get("step_no", 0),
                    interface_name=raw.get("interface_name", ""),
                    status=raw.get("status", "FAIL"),
                    assertions=assertions,
                ))
            except (ValueError, AttributeError, TypeError) as e:
                # ValueError covers json.JSONDecodeError; AttributeError and
                # TypeError cover payloads that are valid JSON but not objects.
                logger.debug(f"Step result parse error: {e}")
        return results

    # ── Summary ───────────────────────────────────────────────
    def print_summary(self, results: list[TestResult]):
        """Print pass/fail/error counts, the average duration and failing cases."""
        total = len(results)
        passed = sum(1 for r in results if r.status == "PASS")
        failed = sum(1 for r in results if r.status == "FAIL")
        errors = sum(1 for r in results if r.status in ("ERROR", "TIMEOUT"))
        # Guard against division by zero for an empty run.
        avg_dur = sum(r.duration for r in results) / total if total else 0
        rule = "═" * 56
        print(f"\n{rule}")
        print(f" Test Run Summary ({total} cases)")
        print(rule)
        print(f" ✅ PASS : {passed}")
        print(f" ❌ FAIL : {failed}")
        print(f" ⚠️ ERROR : {errors}")
        print(f" ⏱ Avg : {avg_dur:.2f}s / case")
        print(rule)
        if failed or errors:
            print(" Failed / Error cases:")
            for r in results:
                if r.status != "PASS":
                    print(f" [{r.status}] {r.test_id}: {r.message}")
        print()

    # ── Persistence ───────────────────────────────────────────
    def save_results(self, results: list[TestResult], path: str):
        """Write the results, including step and assertion details, to a JSON file.

        Captured stdout/stderr are not written.
        """
        data = [
            {
                "test_id": r.test_id,
                "file_path": r.file_path,
                "status": r.status,
                "message": r.message,
                "duration": round(r.duration, 3),
                "step_results": [
                    {
                        "step_no": s.step_no,
                        "interface_name": s.interface_name,
                        "status": s.status,
                        "assertions": [
                            {
                                "field": a.field,
                                "operator": a.operator,
                                "expected": a.expected,
                                "actual": a.actual,
                                "passed": a.passed,
                                "message": a.message,
                            }
                            for a in s.assertions
                        ],
                    }
                    for s in r.step_results
                ],
            }
            for r in results
        ]
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        logger.info(f"Run results → {path}")

View File

@ -3,11 +3,18 @@
AI-Powered API Test Generator, Runner & Coverage Analyzer
Usage:
# 基础用法
python main.py --api-desc examples/api_desc.json \\
--requirements "创建用户,删除用户"
# 在需求中自然语言描述参数生成要求
python main.py --api-desc examples/api_desc.json \\
--req-file examples/requirements.txt
--requirements "创建用户测试参数不少于10组覆盖边界值和异常值,删除用户"
# 从文件读取需求(每行一条,支持参数生成指令)
python main.py --api-desc examples/api_desc.json \\
--req-file examples/requirements.txt \\
--batch-size 10
"""
import argparse
@ -17,6 +24,7 @@ from pathlib import Path
from config import config
from core.parser import InterfaceParser, ApiDescriptor
from core.requirement_parser import RequirementParser, ParamConstraint
from core.prompt_builder import PromptBuilder
from core.llm_client import LLMClient
from core.test_generator import TestGenerator
@ -24,48 +32,57 @@ from core.test_runner import TestRunner
from core.analyzer import CoverageAnalyzer, CoverageReport
from core.report_generator import ReportGenerator
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s%(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
)
# ══════════════════════════════════════════════════════════════
# 日志初始化
# ══════════════════════════════════════════════════════════════
def setup_logging(debug: bool = False):
    """Configure root logging and keep third-party SDK loggers quiet.

    Args:
        debug: When True the root logger runs at DEBUG, otherwise at INFO.

    The loggers of the listed third-party packages are always held at
    WARNING. Per the original author's note this works around a TypeError
    seen with openai-python v2.24.0 + pydantic-core when the SDK logs at
    DEBUG level (model_dump(by_alias=None)).
    Ref: https://github.com/openai/openai-python/issues/2921
    """
    root_level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(
        level=root_level,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # basicConfig() does nothing when the root logger already has handlers,
    # so set the level explicitly to make sure the debug flag takes effect.
    logging.getLogger().setLevel(root_level)
    for name in ("openai", "openai._base_client", "anthropic", "httpx", "httpcore"):
        logging.getLogger(name).setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
# ── CLI ───────────────────────────────────────────────────────
# ══════════════════════════════════════════════════════════════
# CLI
# ══════════════════════════════════════════════════════════════
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the test generator.

    Returns:
        A parser with one required option (--api-desc) and optional settings
        for the requirements source, batch size, worker count, output
        directory, skipping of run/analysis, and debug logging.
    """
    p = argparse.ArgumentParser(
        description="AI-Powered API Test Generator & Coverage Analyzer",
        # Keep the epilog's line breaks instead of re-wrapping them.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
参数生成指令示例可直接写在需求文本中
  "创建用户测试参数不少于10组覆盖边界值和异常值"
  "查询接口至少8组边界值"
  "所有接口参数不少于5组覆盖等价类"
""",
    )
    p.add_argument("--api-desc", required=True, help="接口描述 JSON 文件")
    p.add_argument("--requirements", default="", help="测试需求(逗号分隔,支持参数生成指令)")
    p.add_argument("--req-file", default="", help="测试需求文件(每行一条)")
    # 0 means "not given on the command line".
    p.add_argument("--batch-size", type=int, default=0, help="每批接口数量(0=使用 config.LLM_BATCH_SIZE)")
    p.add_argument("--workers", type=int, default=0, help="并行执行线程数(0=使用 config.TEST_MAX_WORKERS)")
    p.add_argument("--skip-run", action="store_true", help="只生成,不执行")
    p.add_argument("--skip-analyze", action="store_true", help="跳过覆盖率分析")
    p.add_argument("--output-dir", default="", help="测试文件根输出目录")
    p.add_argument("--debug", action="store_true", help="开启 DEBUG 日志(第三方 SDK 仍保持 WARNING)")
    return p
@ -77,82 +94,111 @@ def load_requirements(args) -> list[str]:
with open(args.req_file, "r", encoding="utf-8") as f:
reqs += [line.strip() for line in f if line.strip()]
if not reqs:
logger.error(
"No requirements provided. Use --requirements or --req-file."
)
logger.error("未提供任何测试需求,请使用 --requirements 或 --req-file 指定。")
sys.exit(1)
return reqs
# ── Main ──────────────────────────────────────────────────────
# ══════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════
def main():
args = build_arg_parser().parse_args()
setup_logging(debug=args.debug)
if args.debug:
logging.getLogger().setLevel(logging.DEBUG)
if args.output_dir:
config.GENERATED_TESTS_DIR = args.output_dir
if args.output_dir: config.GENERATED_TESTS_DIR = args.output_dir
if args.batch_size > 0: config.LLM_BATCH_SIZE = args.batch_size
if args.workers > 0: config.TEST_MAX_WORKERS = args.workers
# ── Step 1: 解析接口描述 ──────────────────────────────────
logger.info("▶ Step 1: Parsing interface description")
parser: InterfaceParser = InterfaceParser()
descriptor: ApiDescriptor = parser.parse_file(args.api_desc)
logger.info("▶ Step 1: 解析接口描述")
parser = InterfaceParser()
descriptor = parser.parse_file(args.api_desc)
project = descriptor.project
project_desc = descriptor.description
interfaces = descriptor.interfaces
logger.info(f" Project : {project}")
logger.info(f" Description: {project_desc}")
logger.info(
f" Interfaces : {len(interfaces)}"
f"{[i.name for i in interfaces]}"
)
# 打印每个接口的 url 解析结果
logger.info(f" 项目名称:{project}")
logger.info(f" 项目描述:{project_desc}")
logger.info(f" 接口数量:{len(interfaces)}")
for iface in interfaces:
if iface.protocol == "function":
logger.info(
f" [{iface.name}] source={iface.source_file} "
f"module={iface.module_path}"
f" [function] {iface.name} "
f"{iface.source_file} (module: {iface.module_path})"
)
else:
logger.info(
f" [{iface.name}] full_url={iface.http_full_url}"
)
logger.info(f" [http] {iface.name}{iface.http_full_url}")
# ── Step 2: 加载测试需求 ──────────────────────────────────
logger.info("▶ Step 2: Loading test requirements …")
requirements = load_requirements(args)
for i, req in enumerate(requirements, 1):
logger.info(f" {i}. {req}")
# ── Step 2: 加载并解析测试需求(含参数生成指令)──────────
logger.info("▶ Step 2: 加载并解析测试需求 …")
raw_requirements = load_requirements(args)
# ── Step 3: 调用 LLM 生成测试用例 ────────────────────────
logger.info("▶ Step 3: Calling LLM …")
builder = PromptBuilder()
test_cases = LLMClient().generate_test_cases(
builder.get_system_prompt(),
builder.build_user_prompt(
requirements, interfaces, parser,
project=project,
project_desc=project_desc,
),
req_parser = RequirementParser()
clean_reqs, per_req_constraints, global_constraint = req_parser.parse(
raw_requirements
)
logger.info(f" LLM returned {len(test_cases)} test case(s)")
# 打印原始需求 & 解析结果
for i, (raw, clean, c) in enumerate(
zip(raw_requirements, clean_reqs, per_req_constraints), 1
):
if c.has_param_directive:
logger.info(
f" {i}. {raw}\n"
f" └─ 参数约束:{c} | 业务需求:{clean}"
)
else:
logger.info(f" {i}. {raw}")
if global_constraint.has_param_directive:
logger.info(f" 全局参数约束:{global_constraint}")
# 使用"清洗后的需求"传给 LLM去掉参数指令保留纯业务描述
# 同时保留原始需求用于覆盖率分析(保证与用户输入一致)
effective_requirements = clean_reqs
# ── Step 3: 构建 Prompt 并调用 LLM 生成测试用例 ──────────
logger.info(
f"▶ Step 3: 调用 LLM 生成测试用例 "
f"(batch_size={getattr(config, 'LLM_BATCH_SIZE', 10)}) …"
)
builder = PromptBuilder()
iface_summary = parser.to_summary_dict(interfaces)
project_header = builder.build_project_header(project, project_desc)
# System Prompt注入全局参数约束规则
system_prompt = builder.get_system_prompt(
global_constraint=global_constraint
)
test_cases = LLMClient().generate_test_cases(
system_prompt=system_prompt,
user_prompt="",
iface_summaries=iface_summary,
requirements=effective_requirements,
project_header=project_header,
param_constraint=global_constraint, # ← 注入参数约束
)
logger.info(f" 共生成测试用例:{len(test_cases)}")
# ── Step 4: 生成测试文件 ──────────────────────────────────
logger.info(
f"▶ Step 4: Generating test files (project='{project}') …"
)
logger.info("▶ Step 4: 生成测试文件 …")
generator = TestGenerator(project=project, project_desc=project_desc)
test_files = generator.generate(test_cases)
out = generator.output_dir
logger.info(f" 输出目录:{out.resolve()}")
run_results = []
if not args.skip_run:
# ── Step 5: 执行测试 ──────────────────────────────────
logger.info("▶ Step 5: Running tests …")
# ── Step 5: 并行执行测试 ──────────────────────────────
logger.info(
f"▶ Step 5: 执行测试 "
f"(workers={getattr(config, 'TEST_MAX_WORKERS', 8)}) …"
)
runner = TestRunner()
run_results = runner.run_all(test_files)
runner.print_summary(run_results)
@ -160,31 +206,36 @@ def main():
if not args.skip_analyze:
# ── Step 6: 覆盖率分析 ────────────────────────────────
logger.info("▶ Step 6: Analyzing coverage")
logger.info("▶ Step 6: 覆盖率分析")
report = CoverageAnalyzer(
interfaces=interfaces,
requirements=requirements,
requirements=raw_requirements, # 用原始需求做覆盖率分析
test_cases=test_cases,
run_results=run_results,
).analyze()
# ── Step 7: 生成报告 ──────────────────────────────────
logger.info("▶ Step 7: Generating reports")
logger.info("▶ Step 7: 生成报告")
rg = ReportGenerator()
rg.save_json(report, str(out / "coverage_report.json"))
rg.save_html(report, str(out / "coverage_report.html"))
_print_terminal_summary(report, out, project)
_print_terminal_summary(report, out, project, global_constraint)
logger.info(f"\nDone. Output: {out.resolve()}")
logger.info(f"\n完成。输出目录:{out.resolve()}")
# ── 终端摘要 ──────────────────────────────────────────────────
# ══════════════════════════════════════════════════════════════
# 终端摘要
# ══════════════════════════════════════════════════════════════
def _print_terminal_summary(
report: CoverageReport, out: Path, project: str
report: "CoverageReport",
out: Path,
project: str,
global_constraint: ParamConstraint,
):
W = 66
W = 68
def bar(rate: float, w: int = 20) -> str:
filled = int(rate * w)
@ -193,35 +244,41 @@ def _print_terminal_summary(
return f"{'' * filled}{'' * empty} {rate * 100:.1f}% {icon}"
print(f"\n{'' * W}")
print(f" PROJECT : {project}")
print(f" COVERAGE SUMMARY")
print(f" 项目:{project}")
print(f" 覆盖率摘要")
print(f"{'' * W}")
print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}")
print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}")
print(f" 入参覆盖率 {bar(report.avg_in_param_coverage_rate)}")
print(f" 成功返回字段覆盖 {bar(report.avg_success_field_coverage_rate)}")
print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}")
print(f" 用例通过率 {bar(report.pass_rate)}")
print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}")
print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}")
print(f" 入参覆盖率 {bar(report.avg_in_param_coverage_rate)}")
print(f" 成功返回字段覆盖 {bar(report.avg_success_field_coverage_rate)}")
print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}")
print(f" 用例通过率 {bar(report.pass_rate)}")
print(f"{'' * W}")
print(f" Total Gaps : {len(report.gaps)}")
print(f" 🔴 Critical: {report.critical_gap_count}")
print(f" 🟠 High : {report.high_gap_count}")
print(f" 测试用例总数 {report.total_test_cases}")
print(f" 覆盖缺口总数 {len(report.gaps)}")
print(f" 🔴 严重缺口 {report.critical_gap_count}")
print(f" 🟠 高优先级缺口 {report.high_gap_count}")
# 参数约束达成情况
if global_constraint.has_param_directive:
print(f"{'' * W}")
print(f" 参数生成约束 {global_constraint}")
actual = report.total_test_cases
needed = global_constraint.min_groups
status = "✅ 已满足" if actual >= needed else f"❌ 未满足(需 {needed} 组,实际 {actual} 组)"
print(f" 参数组数要求 {status}")
if report.gaps:
print(f"{'' * W}")
print(" Top Gaps (up to 8):")
icons = {
"critical": "🔴", "high": "🟠",
"medium": "🟡", "low": "🔵",
}
print(" Top 缺口最多显示8条")
icons = {"critical": "🔴", "high": "🟠", "medium": "🟡", "low": "🔵"}
for g in report.gaps[:8]:
icon = icons.get(g.severity, "")
print(f" {icon} [{g.gap_type}] {g.target}")
print(f" {icons.get(g.severity, '')} [{g.gap_type}] {g.target}")
print(f"{g.suggestion}")
print(f"{'' * W}")
print(f" Output : {out.resolve()}")
print(f" • coverage_report.html ← open in browser")
print(f" 输出目录:{out.resolve()}")
print(f" • coverage_report.html")
print(f" • coverage_report.json")
print(f" • run_results.json")
print(f" • test_cases_summary.json")

View File

@ -16,4 +16,4 @@ export HTTP_BASE_URL="http://localhost:8080"
# 只生成不执行
python main.py --api-desc examples/api_desc.json \
--requirements "1.对每个接口进行测试,支持特殊值、边界值测试"
--requirements "每个测试用例的参数不少于10组"