diff --git a/ai_test_generator/core/llm_client.py b/ai_test_generator/core/llm_client.py index a74f090..d2030a2 100644 --- a/ai_test_generator/core/llm_client.py +++ b/ai_test_generator/core/llm_client.py @@ -1,58 +1,204 @@ +""" +LLM 客户端 +══════════════════════════════════════════════════════════════ +变更: + - generate_test_cases() 新增 param_constraint 参数 + - 分批调用时将约束透传给 PromptBuilder.build_batch_prompt() +""" + +from __future__ import annotations import json -import re import logging +import re +import time +from typing import Any + from openai import OpenAI from config import config logger = logging.getLogger(__name__) -class LLMClient: - """封装 LLM API 调用(OpenAI 兼容接口)""" +# ══════════════════════════════════════════════════════════════ +# JSON 健壮解析工具 +# ══════════════════════════════════════════════════════════════ - def __init__(self): - self.client = OpenAI( - api_key=config.LLM_API_KEY, - base_url=config.LLM_BASE_URL, - ) +class RobustJSONParser: - def generate_test_cases(self, system_prompt: str, user_prompt: str) -> list[dict]: - logger.info(f"Calling LLM: model={config.LLM_MODEL}") - logger.debug(f"--- USER PROMPT ---\n{user_prompt}\n---") + def parse(self, text: str) -> list[dict]: + text = text.strip() + text = re.sub(r'^```(?:json)?\s*', '', text, flags=re.MULTILINE) + text = re.sub(r'\s*```$', '', text, flags=re.MULTILINE) + text = text.strip() - response = self.client.chat.completions.create( - model=config.LLM_MODEL, - temperature=config.LLM_TEMPERATURE, - max_tokens=config.LLM_MAX_TOKENS, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - ) - raw = response.choices[0].message.content - logger.debug(f"--- LLM RAW RESPONSE ---\n{raw[:800]}\n---") - return self._parse_json(raw) - - # ── 解析 ────────────────────────────────────────────────── - - def _parse_json(self, content: str) -> list[dict]: - content = content.strip() - # 去除可能的 markdown 代码块 - content = re.sub(r"^```(?:json)?\s*", "", content, flags=re.MULTILINE) - content = re.sub(r"\s*```\s*$", "", content, flags=re.MULTILINE) - content = content.strip() + start = text.find('[') + if start == -1: + logger.warning("LLM 响应中未找到 JSON 数组。") + return [] + text = text[start:] try: - data = json.loads(content) + result = json.loads(text) + if isinstance(result, list): + return result except json.JSONDecodeError: - # 尝试提取第一个 JSON 数组 - match = re.search(r"\[.*\]", content, re.DOTALL) - if match: - data = json.loads(match.group()) - else: - logger.error(f"Cannot parse LLM response as JSON:\n{content[:600]}") - raise + pass - if not isinstance(data, list): - raise ValueError(f"Expected JSON array, got {type(data)}") - return data \ No newline at end of file + text = self._fix_truncated(text) + try: + result = json.loads(text) + if isinstance(result, list): + logger.warning("JSON 截断修复成功,部分用例可能丢失。") + return result + except json.JSONDecodeError as e: + logger.error(f"JSON 解析失败(修复后仍无效):{e}") + return [] + + def _fix_truncated(self, text: str) -> str: + last = text.rfind('},') + if last == -1: + last = text.rfind('}') + if last == -1: + return text + return text[:last + 1] + ']' + + +# ══════════════════════════════════════════════════════════════ +# LLM 客户端 +# ══════════════════════════════════════════════════════════════ + +class LLMClient: + + DEFAULT_BATCH_SIZE = 10 + + def __init__(self): + self.client = OpenAI( + api_key=config.LLM_API_KEY, + base_url=getattr(config, "LLM_BASE_URL", None), + ) + self.model = config.LLM_MODEL + self.max_tokens = getattr(config, "LLM_MAX_TOKENS", 8192) + self.batch_size = getattr(config, "LLM_BATCH_SIZE", self.DEFAULT_BATCH_SIZE) + self.retry = getattr(config, "LLM_RETRY", 3) + self.retry_delay = getattr(config, "LLM_RETRY_DELAY", 5) + self.batch_interval = getattr(config, "LLM_BATCH_INTERVAL", 1) + self._parser = RobustJSONParser() + + from core.prompt_builder import PromptBuilder + self._prompt_builder = PromptBuilder() + + # ── 主入口 ──────────────────────────────────────────────── + + def generate_test_cases( + self, + system_prompt: str, + user_prompt: str, + iface_summaries: list[dict] | None = None, + requirements: list[str] | None = None, + project_header: str = "", + param_constraint: "ParamConstraint | None" = None, # ← 新增 + ) -> list[dict]: + """ + 生成测试用例。 + + 方式 A(单次):传 user_prompt。 + 方式 B(分批):传 iface_summaries + requirements,推荐大规模场景。 + + param_constraint:由 RequirementParser 解析出的全局参数约束, + 有值时自动注入参数数据集到每批 prompt。 + """ + if iface_summaries is not None and requirements is not None: + return self._generate_batched( + system_prompt, iface_summaries, requirements, + project_header, param_constraint, + ) + return self._call_with_retry(system_prompt, user_prompt) + + # ── 分批调用 ────────────────────────────────────────────── + + def _generate_batched( + self, + system_prompt: str, + iface_summaries: list[dict], + requirements: list[str], + project_header: str, + param_constraint: Any, + ) -> list[dict]: + total = len(iface_summaries) + batches = self._make_batches(iface_summaries, self.batch_size) + all_cases: list[dict] = [] + + logger.info( + f"分批模式:共 {total} 个接口 → " + f"{len(batches)} 批 × 每批最多 {self.batch_size} 个" + ) + if param_constraint and param_constraint.has_param_directive: + logger.info( + f"参数约束已启用:{param_constraint}" + ) + + for idx, batch in enumerate(batches, 1): + names = [i.get("name", "?") for i in batch] + logger.info(f" 第 {idx}/{len(batches)} 批:{names}") + + user_prompt = self._prompt_builder.build_batch_prompt( + batch=batch, + requirements=requirements, + project_header=project_header, + param_constraint=param_constraint, # ← 透传 + ) + + cases = self._call_with_retry(system_prompt, user_prompt) + logger.info(f" 第 {idx} 批 → 生成 {len(cases)} 个测试用例") + all_cases.extend(cases) + + if idx < len(batches): + time.sleep(self.batch_interval) + + logger.info(f"全部批次完成,共生成 {len(all_cases)} 个测试用例") + return all_cases + + # ── 单次调用(含重试)──────────────────────────────────── + + def _call_with_retry( + self, + system_prompt: str, + user_prompt: str, + ) -> list[dict]: + last_error: Exception | None = None + + for attempt in range(1, self.retry + 1): + try: + logger.debug(f"LLM 调用第 {attempt}/{self.retry} 次 …") + response = self.client.chat.completions.create( + model=self.model, + max_tokens=self.max_tokens, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + ) + raw = response.choices[0].message.content or "" + logger.debug(f"LLM 响应长度:{len(raw)} 字符") + + cases = self._parser.parse(raw) + if cases: + return cases + + logger.warning(f"第 {attempt} 次调用返回空结果,准备重试 …") + + except Exception as e: + last_error = e + logger.warning(f"第 {attempt} 次调用失败:{e}") + + if attempt < self.retry: + time.sleep(self.retry_delay) + + logger.error( + f"已重试 {self.retry} 次,全部失败。最后一次错误:{last_error}" + ) + return [] + + @staticmethod + def _make_batches(items: list[Any], size: int) -> list[list[Any]]: + return [items[i:i + size] for i in range(0, len(items), size)] \ No newline at end of file diff --git a/ai_test_generator/core/param_generator.py b/ai_test_generator/core/param_generator.py new file mode 100644 index 0000000..da61601 --- /dev/null +++ b/ai_test_generator/core/param_generator.py @@ -0,0 +1,393 @@ +""" +测试参数生成器 +══════════════════════════════════════════════════════════════ +职责: + 根据接口参数的数据类型和参数生成策略, + 自动生成覆盖正常值、边界值、异常值的测试数据集。 + +支持的数据类型: + string / str + integer / int / number + float / double / decimal + boolean / bool + array / list + object / dict / map + +输出格式(供注入 prompt): + { + "param_name": { + "type": "string", + "groups": [ + {"label": "正常值", "value": "hello", "category": "normal"}, + {"label": "空字符串", "value": "", "category": "boundary"}, + {"label": "超长字符串","value": "a"*256, "category": "exception"}, + ... + ] + } + } +""" + +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Any + +from core.requirement_parser import ParamStrategy + + +# ══════════════════════════════════════════════════════════════ +# 数据结构 +# ══════════════════════════════════════════════════════════════ + +@dataclass +class ParamGroup: + """单组测试参数。""" + label: str # 人类可读的描述,如"空字符串" + value: Any # 实际参数值 + category: str # "normal" | "boundary" | "exception" | "equivalence" + + def to_dict(self) -> dict: + return { + "label": self.label, + "value": self.value, + "category": self.category, + } + + +@dataclass +class ParamDataSet: + """单个参数的完整测试数据集。""" + param_name: str + param_type: str + groups: list[ParamGroup] = field(default_factory=list) + + def to_dict(self) -> dict: + return { + "type": self.param_type, + "groups": [g.to_dict() for g in self.groups], + } + + @property + def normal_values(self) -> list[Any]: + return [g.value for g in self.groups if g.category == "normal"] + + @property + def boundary_values(self) -> list[Any]: + return [g.value for g in self.groups if g.category == "boundary"] + + @property + def exception_values(self) -> list[Any]: + return [g.value for g in self.groups if g.category == "exception"] + + +# ══════════════════════════════════════════════════════════════ +# 各类型测试数据库 +# ══════════════════════════════════════════════════════════════ + +# ── string ──────────────────────────────────────────────────── +_STRING_NORMAL: list[ParamGroup] = [ + ParamGroup("普通字符串", "hello", "normal"), + ParamGroup("中文字符串", "测试数据", "normal"), + ParamGroup("字母数字混合", "user123", "normal"), +] +_STRING_BOUNDARY: list[ParamGroup] = [ + ParamGroup("空字符串", "", "boundary"), + ParamGroup("单字符", "a", "boundary"), + ParamGroup("最大长度(255)", "a" * 255, "boundary"), + ParamGroup("恰好256字符", "a" * 256, "boundary"), + ParamGroup("含空格", "hello world", "boundary"), + ParamGroup("首尾空格", " hello ", "boundary"), +] +_STRING_EXCEPTION: list[ParamGroup] = [ + ParamGroup("None/null", None, "exception"), + ParamGroup("超长字符串", "a" * 1000, "exception"), + ParamGroup("纯空格", " ", "exception"), + ParamGroup("特殊字符", "!@#$%^&*()", "exception"), + ParamGroup("换行符", "line1\nline2", "exception"), + ParamGroup("SQL注入", "' OR '1'='1", "exception"), + ParamGroup("XSS注入", "", "exception"), + ParamGroup("Unicode特殊字符","你好\u0000世界", "exception"), + ParamGroup("数字类型误传", 12345, "exception"), +] + +# ── integer ─────────────────────────────────────────────────── +_INTEGER_NORMAL: list[ParamGroup] = [ + ParamGroup("典型正整数", 1, "normal"), + ParamGroup("较大正整数", 100, "normal"), +] +_INTEGER_BOUNDARY: list[ParamGroup] = [ + ParamGroup("零", 0, "boundary"), + ParamGroup("正边界值1", 1, "boundary"), + ParamGroup("负边界值-1", -1, "boundary"), + ParamGroup("最大值(int32)", 2_147_483_647, "boundary"), + ParamGroup("最小值(int32)", -2_147_483_648, "boundary"), + ParamGroup("最大值+1", 2_147_483_648, "boundary"), +] +_INTEGER_EXCEPTION: list[ParamGroup] = [ + ParamGroup("None/null", None, "exception"), + ParamGroup("字符串类型", "abc", "exception"), + ParamGroup("浮点数", 3.14, "exception"), + ParamGroup("布尔值true", True, "exception"), + ParamGroup("空字符串", "", "exception"), + ParamGroup("超大整数", 10 ** 20, "exception"), +] + +# ── float ───────────────────────────────────────────────────── +_FLOAT_NORMAL: list[ParamGroup] = [ + ParamGroup("典型浮点数", 3.14, "normal"), + ParamGroup("整数形式", 1.0, "normal"), +] +_FLOAT_BOUNDARY: list[ParamGroup] = [ + ParamGroup("零", 0.0, "boundary"), + ParamGroup("极小正数", 1e-10, "boundary"), + ParamGroup("极大正数", 1e10, "boundary"), + ParamGroup("负数", -3.14, "boundary"), + ParamGroup("负无穷", float('-inf'), "boundary"), + ParamGroup("正无穷", float('inf'), "boundary"), +] +_FLOAT_EXCEPTION: list[ParamGroup] = [ + ParamGroup("None/null", None, "exception"), + ParamGroup("NaN", float('nan'), "exception"), + ParamGroup("字符串类型", "abc", "exception"), + ParamGroup("空字符串", "", "exception"), +] + +# ── boolean ─────────────────────────────────────────────────── +_BOOLEAN_NORMAL: list[ParamGroup] = [ + ParamGroup("true", True, "normal"), + ParamGroup("false", False, "normal"), +] +_BOOLEAN_BOUNDARY: list[ParamGroup] = [ + ParamGroup("整数1", 1, "boundary"), + ParamGroup("整数0", 0, "boundary"), +] +_BOOLEAN_EXCEPTION: list[ParamGroup] = [ + ParamGroup("None/null", None, "exception"), + ParamGroup("字符串'true'", "true", "exception"), + ParamGroup("字符串'false'", "false", "exception"), + ParamGroup("字符串'yes'", "yes", "exception"), + ParamGroup("随机字符串", "abc", "exception"), +] + +# ── array ───────────────────────────────────────────────────── +_ARRAY_NORMAL: list[ParamGroup] = [ + ParamGroup("普通数组", [1, 2, 3], "normal"), + ParamGroup("字符串数组", ["a", "b", "c"], "normal"), +] +_ARRAY_BOUNDARY: list[ParamGroup] = [ + ParamGroup("空数组", [], "boundary"), + ParamGroup("单元素数组", [1], "boundary"), + ParamGroup("大数组(100元素)", list(range(100)), "boundary"), + ParamGroup("含None元素", [1, None, 3], "boundary"), +] +_ARRAY_EXCEPTION: list[ParamGroup] = [ + ParamGroup("None/null", None, "exception"), + ParamGroup("字符串类型", "not_an_array", "exception"), + ParamGroup("整数类型", 123, "exception"), + ParamGroup("嵌套过深", [[[[[1]]]]], "exception"), +] + +# ── object ──────────────────────────────────────────────────── +_OBJECT_NORMAL: list[ParamGroup] = [ + ParamGroup("完整字段对象", {"id": 1, "name": "test"}, "normal"), +] +_OBJECT_BOUNDARY: list[ParamGroup] = [ + ParamGroup("空对象", {}, "boundary"), + ParamGroup("仅必填字段", {"id": 1}, "boundary"), + ParamGroup("含额外字段", {"id": 1, "extra": "x"}, "boundary"), +] +_OBJECT_EXCEPTION: list[ParamGroup] = [ + ParamGroup("None/null", None, "exception"), + ParamGroup("字符串类型", "not_an_object", "exception"), + ParamGroup("数组类型", [1, 2, 3], "exception"), + ParamGroup("字段值类型错误", {"id": "not_int"}, "exception"), +] + +# 类型 → 数据映射表 +_TYPE_DATA: dict[str, dict[str, list[ParamGroup]]] = { + "string": {"normal": _STRING_NORMAL, "boundary": _STRING_BOUNDARY, "exception": _STRING_EXCEPTION}, + "integer": {"normal": _INTEGER_NORMAL, "boundary": _INTEGER_BOUNDARY, "exception": _INTEGER_EXCEPTION}, + "float": {"normal": _FLOAT_NORMAL, "boundary": _FLOAT_BOUNDARY, "exception": _FLOAT_EXCEPTION}, + "boolean": {"normal": _BOOLEAN_NORMAL, "boundary": _BOOLEAN_BOUNDARY, "exception": _BOOLEAN_EXCEPTION}, + "array": {"normal": _ARRAY_NORMAL, "boundary": _ARRAY_BOUNDARY, "exception": _ARRAY_EXCEPTION}, + "object": {"normal": _OBJECT_NORMAL, "boundary": _OBJECT_BOUNDARY, "exception": _OBJECT_EXCEPTION}, +} + +# 类型别名归一化 +_TYPE_ALIAS: dict[str, str] = { + "str": "string", "text": "string", "varchar": "string", + "int": "integer", "long": "integer", "number": "integer", + "float": "float", "double": "float", "decimal": "float", + "bool": "boolean", + "list": "array", + "dict": "object", "map": "object", "json": "object", +} + + +# ══════════════════════════════════════════════════════════════ +# 参数生成器 +# ══════════════════════════════════════════════════════════════ + +class ParamGenerator: + """ + 根据参数类型和策略生成测试数据集。 + """ + + def generate( + self, + param_name: str, + param_type: str, + strategies: list[ParamStrategy], + min_groups: int = 2, + ) -> ParamDataSet: + """ + 生成单个参数的测试数据集。 + + Args: + param_name : 参数名 + param_type : 数据类型(支持别名) + strategies : 需要覆盖的策略列表 + min_groups : 最少生成的参数组数 + + Returns: + ParamDataSet + """ + norm_type = self._normalize_type(param_type) + type_data = _TYPE_DATA.get(norm_type, _TYPE_DATA["string"]) + + groups: list[ParamGroup] = [] + + # 始终包含正常值 + groups.extend(type_data["normal"]) + + # 按策略追加 + for strategy in strategies: + if strategy == ParamStrategy.BOUNDARY: + groups.extend(type_data["boundary"]) + elif strategy == ParamStrategy.EXCEPTION: + groups.extend(type_data["exception"]) + elif strategy == ParamStrategy.EQUIVALENCE: + # 等价类:正常值 + 部分边界值 + groups.extend(type_data["boundary"][:3]) + elif strategy == ParamStrategy.RANDOM: + groups.extend(self._random_groups(norm_type)) + + # 去重(按 label) + seen_labels: set[str] = set() + unique_groups: list[ParamGroup] = [] + for g in groups: + if g.label not in seen_labels: + seen_labels.add(g.label) + unique_groups.append(g) + + # 补齐到 min_groups + if len(unique_groups) < min_groups: + extra = type_data["boundary"] + type_data["exception"] + for g in extra: + if g.label not in seen_labels and len(unique_groups) < min_groups: + seen_labels.add(g.label) + unique_groups.append(g) + + return ParamDataSet( + param_name=param_name, + param_type=norm_type, + groups=unique_groups, + ) + + def generate_for_interface( + self, + interface_summary: dict, + strategies: list[ParamStrategy], + min_groups: int = 2, + ) -> dict[str, ParamDataSet]: + """ + 为一个接口的所有入参生成测试数据集。 + + Args: + interface_summary : parser.to_summary_dict() 中的单个接口 dict + strategies : 策略列表 + min_groups : 最少组数 + + Returns: + { param_name: ParamDataSet } + """ + result: dict[str, ParamDataSet] = {} + params = interface_summary.get("params", {}) + + for param_name, param_info in params.items(): + if isinstance(param_info, dict): + param_type = param_info.get("type", "string") + inout = param_info.get("inout", "in") + else: + param_type = str(param_info) + inout = "in" + + # 只对入参(in / inout)生成测试数据 + if inout not in ("in", "inout"): + continue + + result[param_name] = self.generate( + param_name=param_name, + param_type=param_type, + strategies=strategies, + min_groups=min_groups, + ) + + return result + + # ── 工具 ────────────────────────────────────────────────── + + @staticmethod + def _normalize_type(t: str) -> str: + t = t.lower().strip() + return _TYPE_ALIAS.get(t, t if t in _TYPE_DATA else "string") + + @staticmethod + def _random_groups(norm_type: str) -> list[ParamGroup]: + """生成少量随机值(补充策略)。""" + import random, string + if norm_type == "string": + val = ''.join(random.choices(string.ascii_letters, k=8)) + return [ParamGroup(f"随机字符串({val})", val, "equivalence")] + if norm_type == "integer": + val = random.randint(-1000, 1000) + return [ParamGroup(f"随机整数({val})", val, "equivalence")] + if norm_type == "float": + val = round(random.uniform(-100, 100), 4) + return [ParamGroup(f"随机浮点数({val})", val, "equivalence")] + return [] + + def to_prompt_text( + self, + datasets: dict[str, ParamDataSet], + min_groups: int, + ) -> str: + """ + 将参数数据集转换为注入 prompt 的中文说明文本。 + """ + if not datasets: + return "" + + lines = [ + f"### 参数测试数据集(每个参数至少覆盖 {min_groups} 组)", + "", + ] + for param_name, ds in datasets.items(): + lines.append(f"**参数:{param_name}**(类型:{ds.param_type})") + for g in ds.groups: + category_label = { + "normal": "正常值", + "boundary": "边界值", + "exception": "异常值", + "equivalence": "等价类", + }.get(g.category, g.category) + lines.append( + f" - [{category_label}] {g.label}:`{repr(g.value)}`" + ) + lines.append("") + + lines.append( + f"> 请从以上数据集中选取参数值组合," + f"确保生成的测试用例总数不少于 **{min_groups} 组**," + f"并覆盖正常值、边界值和异常值场景。" + ) + return "\n".join(lines) \ No newline at end of file diff --git a/ai_test_generator/core/prompt_builder.py b/ai_test_generator/core/prompt_builder.py index 1612df2..2ab07e8 100644 --- a/ai_test_generator/core/prompt_builder.py +++ b/ai_test_generator/core/prompt_builder.py @@ -1,167 +1,332 @@ """ -提示词构建器 -升级点: - - 传递项目 description - - function 协议:source_file(.py) → module_path → from import - - HTTP 协议:full_url = url(base) + name(path) - - parameters 统一字段,inout 区分输入/输出 +提示词构建器(中文版) +══════════════════════════════════════════════════════════════ +变更说明: + - 新增 build_param_directive_section():将参数约束注入 System Prompt + - 新增 build_param_dataset_section():将参数数据集注入 User Prompt + - build_batch_prompt / build_user_prompt 支持传入参数数据集 """ import json from core.parser import InterfaceInfo, InterfaceParser +from core.requirement_parser import ParamConstraint, ParamStrategy +from core.param_generator import ParamGenerator, ParamDataSet -SYSTEM_PROMPT = """ -You are a senior software test engineer. Generate test cases based on the provided -interface descriptions and test requirements. + +# ══════════════════════════════════════════════════════════════ +# System Prompt(中文) +# ══════════════════════════════════════════════════════════════ + +_SYSTEM_PROMPT_BASE = """ +你是一名资深软件测试工程师,擅长根据接口描述和测试需求生成高质量的测试用例。 ═══════════════════════════════════════════════════════ -OUTPUT FORMAT -Return ONLY a valid JSON array. No markdown fences. No extra text. +【输出格式要求】 ═══════════════════════════════════════════════════════ -Each element = ONE test case: +- 只输出一个合法的 JSON 数组,不要包含任何 markdown 代码块(```) +- 不要在 JSON 前后添加任何解释性文字 +- 每个数组元素代表一个测试用例,结构如下: + { - "test_id" : "unique id, e.g. TC_001", - "test_name" : "short descriptive name", - "description" : "what this test verifies", - "requirement" : "the original requirement text this case maps to", - "requirement_id" : "e.g. REQ.01", + "test_id" : "唯一编号,如 TC_001", + "test_name" : "简短的测试名称", + "description" : "本用例验证的内容", + "requirement" : "对应的原始需求文本", + "requirement_id" : "需求编号,如 REQ.01(从接口描述中的 requirement_id 字段获取)", "steps": [ { - "step_no" : 1, - "interface_name" : "function or endpoint name", - "protocol" : "http or function", - "url" : "source_file (function) or full_url (http)", - "purpose" : "why this step is needed", + "step_no" : 1, + "interface_name" : "接口名称或函数名", + "protocol" : "http 或 function", + "url" : "function 协议填 source_file,http 协议填 full_url", + "purpose" : "本步骤的目的", "input": { - // Only parameters with inout = "in" or "inout" - "param_name": + "参数名": "参数值" }, - "use_output_of": { // optional + "use_output_of": { "step_no" : 1, - "field" : "user_id", - "as_param" : "user_id" + "field" : "上一步返回值中的字段名", + "as_param" : "作为本步骤的参数名" }, "assertions": [ { - "field" : "field_name or 'return' or 'exception'", - "operator" : "eq|ne|gt|lt|gte|lte|in|not_null|contains|raised|not_raised", - "expected" : , - "message" : "human readable description" + "field" : "断言的字段名,或 'return'(整体返回值)或 'exception'(异常)", + "operator" : "eq | ne | gt | lt | gte | lte | in | not_null | contains | raised | not_raised", + "expected" : "期望值", + "message" : "断言说明" } ] } ], - "test_data_notes" : "explanation of auto-generated test data", - "test_code" : "" + "test_data_notes" : "测试数据说明(说明本用例使用了哪类参数:正常值/边界值/异常值)", + "test_code" : "完整可运行的 Python 测试脚本(见下方规则)" } ═══════════════════════════════════════════════════════ -TEST CODE RULES +【测试代码编写规则】 ═══════════════════════════════════════════════════════ -1. Complete, runnable Python script. No external test framework. -2. Allowed imports: standard library + `requests` + the actual module under test. +1. test_code 必须是完整、可直接运行的 Python 脚本 +2. 不得使用 unittest 或 pytest 框架 +3. 只允许导入:Python 标准库、requests、被测模块 -── FUNCTION INTERFACES ──────────────────────────────── -3. Each function interface has: - "source_file" : e.g. "create_user.py" - "module_path" : e.g. "create_user" (derived from source_file) -4. Import the REAL function using module_path: +── function 协议接口 ────────────────────────────────── +4. 使用 module_path 字段导入真实函数,若导入失败则使用桩函数: try: from import except ImportError: - # Stub fallback — simulate on_success for positive tests, - # on_failure / raise Exception for negative tests - def (): - return # or raise Exception(...) -5. Call the function with ONLY "in"/"inout" parameters. -6. Capture the return value; assert against on_success or on_failure structure. -7. If on_failure.value contains "exception" or "raises": - - In negative tests: wrap call in try/except, assert exception IS raised. - - Use field="exception", operator="raised", expected="Exception". + def (<入参列表>): + return # 桩函数,仅供结构验证 +5. 调用函数时只传 inout 为 "in" 或 "inout" 的参数 +6. 若 on_failure 包含 "exception",负向测试需用 try/except 捕获异常 -── HTTP INTERFACES ──────────────────────────────────── -8. Each HTTP interface has: - "full_url" : complete URL, e.g. "http://127.0.0.1/api/delete_user" - "method" : "get" | "post" | "put" | "delete" etc. -9. Send request: +── http 协议接口 ────────────────────────────────────── +7. 使用 full_url 字段作为请求地址: resp = requests.("", json=) -10. Assert on resp.status_code and resp.json() fields. +8. 断言 resp.status_code 和 resp.json() 中的字段 -── MULTI-STEP ───────────────────────────────────────── -11. Execute steps in order; extract fields from previous step's return value - and pass to next step via use_output_of mapping. - -── STRUCTURED OUTPUT (REQUIRED) ─────────────────────── -12. After EACH step, print: - ##STEP_RESULT## {"step_no":,"interface_name":"...","status":"PASS|FAIL", - "assertions":[{"field":"...","operator":"...","expected":..., - "actual":...,"passed":true|false,"message":"..."}]} -13. Final line must be: - PASS: or FAIL: -14. Wrap entire test body in try/except; print FAIL on any unhandled exception. -15. Do NOT use unittest or pytest. +── 结构化输出(必须遵守)───────────────────────────── +9. 每个步骤执行后,必须在标准输出打印如下格式的一行: + ##STEP_RESULT## {"step_no":,"interface_name":"...","status":"PASS 或 FAIL", + "assertions":[{"field":"...","operator":"...","expected":..., + "actual":...,"passed":true 或 false,"message":"..."}]} +10. 脚本最后一行必须是: + PASS: <摘要> 或 FAIL: <摘要> +11. 整个脚本主体用 try/except 包裹,捕获未处理异常时打印 FAIL ═══════════════════════════════════════════════════════ -ASSERTION RULES BY RETURN TYPE +【断言编写规则】 ═══════════════════════════════════════════════════════ -return.type = "dict" → assert each key in on_success.value (positive) - or on_failure.value (negative) -return.type = "boolean" → field="return", operator="eq", expected=true/false -return.type = "integer" → field="return", operator="eq"/"gt"/"lt" -on_failure = exception → field="exception", operator="raised" +- 返回值为 dict → 对 on_success / on_failure 中每个 key 单独断言 +- 返回值为 bool → field="return",operator="eq",expected=true 或 false +- 返回值为 int → field="return",operator="eq" / "gt" / "lt" +- 预期抛出异常 → field="exception",operator="raised" +- 预期不抛出异常 → field="exception",operator="not_raised" ═══════════════════════════════════════════════════════ -COVERAGE GUIDELINES +【覆盖度要求】 ═══════════════════════════════════════════════════════ -- Per requirement: at least 1 positive + 1 negative test case. -- Negative test_name/description MUST contain "negative" or "invalid". -- Multi-interface requirements: single multi-step test case. -- Cover ALL "in"/"inout" parameters (including optional ones). -- Assert ALL fields described in on_success (positive) and on_failure (negative). -- For "out" parameters: verify their values in assertions after the call. +- 每条需求至少生成 1 个正向用例 + 1 个负向用例 +- 负向用例的 test_name 必须包含"负向"或"无效" +- 覆盖接口的所有 inout 为 "in" 或 "inout" 的参数 +- 正向用例断言 on_success 的所有字段 +- 负向用例断言 on_failure 的所有字段 +""" + +# 参数生成规则段落(有参数约束时追加到 System Prompt) +_PARAM_DIRECTIVE_TEMPLATE = """ +═══════════════════════════════════════════════════════ +【参数生成规则(本次任务特殊要求)】 +═══════════════════════════════════════════════════════ +{directives} +- test_data_notes 字段必须说明本用例使用的参数类别(正常值/边界值/异常值)及具体值 +- 每个测试用例的 input 字段中的参数值,必须从下方"参数测试数据集"中选取 +- 禁止在 input 中使用未在数据集中出现的随机值 """ +# ══════════════════════════════════════════════════════════════ +# PromptBuilder +# ══════════════════════════════════════════════════════════════ + class PromptBuilder: - def get_system_prompt(self) -> str: - return SYSTEM_PROMPT.strip() + def __init__(self): + self._param_gen = ParamGenerator() + + # ── System Prompt ──────────────────────────────────────── + + def get_system_prompt( + self, + global_constraint: ParamConstraint = None, + ) -> str: + """ + 构建 System Prompt。 + 若存在全局参数约束,追加参数生成规则段落。 + """ + base = _SYSTEM_PROMPT_BASE.strip() + + if global_constraint and global_constraint.has_param_directive: + directive_section = self._build_param_directive_section( + global_constraint + ) + return base + "\n" + directive_section.strip() + + return base + + def _build_param_directive_section( + self, + constraint: ParamConstraint, + ) -> str: + """构建参数生成规则段落,注入 System Prompt。""" + lines: list[str] = [] + + if constraint.min_groups > 2: + lines.append( + f"- 每个接口的测试用例总数不少于 **{constraint.min_groups} 组**" + f"(正向 + 负向合计)" + ) + + strategy_labels = { + ParamStrategy.BOUNDARY: "边界值(空值、最大值、最小值、临界值)", + ParamStrategy.EXCEPTION: "异常值(None、类型错误、超长、特殊字符等)", + ParamStrategy.EQUIVALENCE: "等价类(有效等价类 + 无效等价类)", + ParamStrategy.RANDOM: "随机值(覆盖典型随机场景)", + } + if constraint.strategies: + strategy_str = "、".join( + strategy_labels[s] + for s in constraint.strategies + if s in strategy_labels + ) + lines.append(f"- 参数值必须覆盖以下类型:{strategy_str}") + + if not lines: + return "" + + directives = "\n ".join(lines) + return _PARAM_DIRECTIVE_TEMPLATE.format(directives=directives) + + # ── 项目信息头 ─────────────────────────────────────────── + + def build_project_header( + self, + project: str = "", + project_desc: str = "", + ) -> str: + if not project and not project_desc: + return "" + lines = ["## 项目信息"] + if project: + lines.append(f"项目名称:{project}") + if project_desc: + lines.append(f"项目描述:{project_desc}") + return "\n".join(lines) + + # ── 参数数据集段落 ─────────────────────────────────────── + + def build_param_dataset_section( + self, + iface_summaries: list[dict], + constraint: ParamConstraint, + ) -> str: + """ + 为批次内所有接口生成参数数据集,返回注入 user_prompt 的文本段落。 + """ + if not constraint.has_param_directive: + return "" + + all_lines: list[str] = [ + "## 参数测试数据集", + f"> 以下数据集由系统自动生成,请从中选取参数值组合," + f"确保每个接口的测试用例总数不少于 **{constraint.min_groups} 组**。", + "", + ] + + for iface in iface_summaries: + name = iface.get("name", "未知接口") + datasets = self._param_gen.generate_for_interface( + interface_summary=iface, + strategies=constraint.strategies, + min_groups=constraint.min_groups, + ) + if not datasets: + continue + + all_lines.append(f"### 接口:{name}") + text = self._param_gen.to_prompt_text(datasets, constraint.min_groups) + all_lines.append(text) + all_lines.append("") + + return "\n".join(all_lines) + + # ── 分批 user_prompt ──────────────────────────────────── + + def build_batch_prompt( + self, + batch: list[dict], + requirements: list[str], + project_header: str = "", + param_constraint: ParamConstraint = None, + ) -> str: + """ + 构建单批次的 user_prompt。 + 若存在参数约束,自动注入参数数据集。 + """ + req_lines = "\n".join( + f"{i + 1}. {r}" for i, r in enumerate(requirements) + ) + parts: list[str] = [] + + if project_header: + parts.append(project_header) + + parts.append( + "## 本批次接口描述\n" + + json.dumps(batch, ensure_ascii=False, indent=2) + ) + parts.append("## 测试需求\n" + req_lines) + + # 注入参数数据集 + if param_constraint and param_constraint.has_param_directive: + dataset_section = self.build_param_dataset_section( + iface_summaries=batch, + constraint=param_constraint, + ) + if dataset_section: + parts.append(dataset_section) + + parts.append( + "请根据以上接口描述、测试需求和参数数据集,生成完整的测试用例(包含正向和负向)。\n" + "- function 协议:使用 module_path 字段导入被测函数\n" + "- http 协议:使用 full_url 字段作为请求地址\n" + "- 每条需求的 requirement_id 须填入对应用例的 requirement_id 字段\n" + "- test_data_notes 须说明本用例使用的参数类别和具体值" + ) + return "\n\n".join(parts) + + # ── 单次 user_prompt ──────────────────────────────────── def build_user_prompt( self, - requirements: list[str], - interfaces: list[InterfaceInfo], - parser: InterfaceParser, - project: str = "", - project_desc: str = "", + requirements: list[str], + interfaces: list[InterfaceInfo], + parser: InterfaceParser, + project: str = "", + project_desc: str = "", + param_constraint: ParamConstraint = None, ) -> str: + """ + 单次调用模式(接口数 ≤ batch_size 时使用)。 + """ iface_summary = parser.to_summary_dict(interfaces) req_lines = "\n".join( f"{i + 1}. {r}" for i, r in enumerate(requirements) ) + header = self.build_project_header(project, project_desc) - # 项目信息头 - project_section = "" - if project or project_desc: - project_section = "## Project\n" - if project: - project_section += f"Name : {project}\n" - if project_desc: - project_section += f"Description: {project_desc}\n" - project_section += "\n" + parts: list[str] = [] + if header: + parts.append(header) + parts.append( + "## 接口描述\n" + + json.dumps(iface_summary, ensure_ascii=False, indent=2) + ) + parts.append("## 测试需求\n" + req_lines) - return ( - f"{project_section}" - f"## Available Interfaces\n" - f"{json.dumps(iface_summary, ensure_ascii=False, indent=2)}\n\n" - f"## Test Requirements\n" - f"Generate comprehensive test cases (positive + negative) for every interface.\n" - f"- The generated test cases for each interface must meet the following requirements: {req_lines}\n" - f"- For function interfaces: use 'module_path' for import, " - f"'source_file' for reference.\n" - f"- For HTTP interfaces: use 'full_url' as the request URL.\n" - f"- Use requirement_id from the interface to populate the test case's " - f"requirement_id field.\n" - f"- Chain multiple interface calls in one test case when a requirement " - f"involves more than one interface.\n" - ).strip() \ No newline at end of file + if param_constraint and param_constraint.has_param_directive: + dataset_section = self.build_param_dataset_section( + iface_summaries=iface_summary + if isinstance(iface_summary, list) else list(iface_summary.values()), + constraint=param_constraint, + ) + if dataset_section: + parts.append(dataset_section) + + parts.append( + "请根据以上接口描述、测试需求和参数数据集,生成完整的测试用例(包含正向和负向)。\n" + "- function 协议:使用 module_path 字段导入被测函数\n" + "- http 协议:使用 full_url 字段作为请求地址\n" + "- 每条需求的 requirement_id 须填入对应用例的 requirement_id 字段\n" + "- test_data_notes 须说明本用例使用的参数类别和具体值" + ) + return "\n\n".join(parts) \ No newline at end of file diff --git a/ai_test_generator/core/requirement_parser.py b/ai_test_generator/core/requirement_parser.py new file mode 100644 index 0000000..7cce453 --- /dev/null +++ b/ai_test_generator/core/requirement_parser.py @@ -0,0 +1,211 @@ +""" +需求解析器 +══════════════════════════════════════════════════════════════ +职责: + 从用户填写的测试需求文本中,识别并提取参数生成约束指令: + - 参数组数约束:如"不少于10组"、"至少5组"、"生成8组" + - 参数策略约束:如"边界值"、"异常值"、"等价类"、"随机值" + - 数据类型约束:如"字符串边界值"、"整数异常值" + 解析结果供 ParamGenerator 和 PromptBuilder 使用。 +""" + +from __future__ import annotations +import re +from dataclasses import dataclass, field +from enum import Enum + + +# ══════════════════════════════════════════════════════════════ +# 枚举:参数生成策略 +# ══════════════════════════════════════════════════════════════ + +class ParamStrategy(str, Enum): + BOUNDARY = "boundary" # 边界值 + EXCEPTION = "exception" # 异常值 / 错误值 + EQUIVALENCE = "equivalence" # 等价类 + RANDOM = "random" # 随机值 + CUSTOM = "custom" # 自定义(用户在需求中直接指定) + + +# ══════════════════════════════════════════════════════════════ +# 数据结构:解析结果 +# ══════════════════════════════════════════════════════════════ + +@dataclass +class ParamConstraint: + """ + 从单条需求文本中解析出的参数生成约束。 + """ + # 最少生成的参数组数(正向 + 负向合计) + min_groups: int = 2 + + # 需要覆盖的策略集合 + strategies: list[ParamStrategy] = field(default_factory=list) + + # 原始需求文本(去除参数指令后的纯业务需求部分) + clean_requirement: str = "" + + # 是否显式指定了参数要求(用于区分"用户有要求"和"默认值") + has_param_directive: bool = False + + def __str__(self) -> str: + parts = [] + if self.min_groups > 2: + parts.append(f"至少 {self.min_groups} 组参数") + for s in self.strategies: + parts.append(_STRATEGY_LABEL[s]) + return "、".join(parts) if parts else "默认(1正1负)" + + +_STRATEGY_LABEL: dict[ParamStrategy, str] = { + ParamStrategy.BOUNDARY: "边界值", + ParamStrategy.EXCEPTION: "异常值", + ParamStrategy.EQUIVALENCE: "等价类", + ParamStrategy.RANDOM: "随机值", + ParamStrategy.CUSTOM: "自定义", +} + + +# ══════════════════════════════════════════════════════════════ +# 解析规则 +# ══════════════════════════════════════════════════════════════ + +# 数量约束:匹配"不少于N组"、"至少N组"、"最少N组"、"生成N组"、"N组以上" +_RE_MIN_GROUPS = re.compile( + r'(?:不少于|至少|最少|生成|共|需要|要求)\s*(\d+)\s*组' + r'|(\d+)\s*组(?:以上|及以上|或以上)', + re.UNICODE, +) + +# 策略关键词映射(关键词 → 策略枚举) +_STRATEGY_KEYWORDS: list[tuple[list[str], ParamStrategy]] = [ + (["边界值", "边界", "boundary", "边界测试"], ParamStrategy.BOUNDARY), + (["异常值", "异常", "错误值", "非法值", + "exception", "invalid", "error"], ParamStrategy.EXCEPTION), + (["等价类", "等价", "equivalence"], ParamStrategy.EQUIVALENCE), + (["随机", "随机值", "random"], ParamStrategy.RANDOM), +] + +# 参数指令整体识别:包含上述任意关键词则认为是参数指令 +_RE_PARAM_DIRECTIVE = re.compile( + r'(?:参数|测试参数|测试数据|数据|入参)' + r'.*?(?:不少于|至少|最少|生成|边界|异常|等价|随机|\d+\s*组)', + re.UNICODE, +) + + +# ══════════════════════════════════════════════════════════════ +# 解析器 +# ══════════════════════════════════════════════════════════════ + +class RequirementParser: + """ + 解析测试需求列表,提取参数生成约束,返回: + - clean_requirements : 去除参数指令后的纯业务需求列表 + - constraints : 每条需求对应的 ParamConstraint + - global_constraint : 全局约束(从"全局"/"所有接口"等描述中提取) + """ + + def parse( + self, + requirements: list[str], + ) -> tuple[list[str], list[ParamConstraint], ParamConstraint]: + """ + Returns: + (clean_requirements, per_req_constraints, global_constraint) + """ + clean_reqs: list[str] = [] + constraints: list[ParamConstraint] = [] + global_c = ParamConstraint() + + for req in requirements: + c = self._parse_one(req) + constraints.append(c) + clean_reqs.append(c.clean_requirement) + + # 若某条需求是全局性描述(如"所有接口参数不少于10组"), + # 则将其约束提升为全局约束 + if self._is_global(req) and c.has_param_directive: + global_c = self._merge(global_c, c) + + return clean_reqs, constraints, global_c + + # ── 解析单条需求 ────────────────────────────────────────── + + def _parse_one(self, req: str) -> ParamConstraint: + c = ParamConstraint(clean_requirement=req) + + # 1. 检测是否包含参数指令 + if not self._has_directive(req): + return c + + c.has_param_directive = True + + # 2. 提取数量约束 + m = _RE_MIN_GROUPS.search(req) + if m: + n = int(m.group(1) or m.group(2)) + c.min_groups = max(n, 2) # 至少保留 1正1负 + + # 3. 提取策略约束 + seen: set[ParamStrategy] = set() + for keywords, strategy in _STRATEGY_KEYWORDS: + if any(kw in req for kw in keywords): + seen.add(strategy) + c.strategies = list(seen) + + # 4. 若未指定策略,默认同时覆盖边界值和异常值 + if not c.strategies: + c.strategies = [ParamStrategy.BOUNDARY, ParamStrategy.EXCEPTION] + + # 5. 清理需求文本:去除参数指令部分,保留业务描述 + c.clean_requirement = self._strip_directive(req) + + return c + + # ── 工具方法 ────────────────────────────────────────────── + + @staticmethod + def _has_directive(req: str) -> bool: + """判断需求文本是否包含参数生成指令。""" + return bool(_RE_PARAM_DIRECTIVE.search(req)) or bool( + _RE_MIN_GROUPS.search(req) + ) + + @staticmethod + def _is_global(req: str) -> bool: + """判断是否为全局性参数要求。""" + return any( + kw in req + for kw in ["所有接口", "全部接口", "全局", "所有测试", "全部测试"] + ) + + @staticmethod + def _strip_directive(req: str) -> str: + """ + 去除需求文本中的参数指令部分,保留纯业务描述。 + 策略:以中文逗号、分号、括号等分隔,去除含参数指令的子句。 + """ + # 按常见分隔符拆分 + parts = re.split(r'[,,;;(())【】\n]', req) + clean_parts = [] + for part in parts: + part = part.strip() + if not part: + continue + # 若该子句包含参数指令关键词,则跳过 + if _RE_PARAM_DIRECTIVE.search(part) or _RE_MIN_GROUPS.search(part): + continue + clean_parts.append(part) + result = ",".join(clean_parts).strip() + return result if result else req # 若全部被去除则保留原文 + + @staticmethod + def _merge(base: ParamConstraint, other: ParamConstraint) -> ParamConstraint: + """合并两个约束,取较大值。""" + merged = ParamConstraint( + min_groups=max(base.min_groups, other.min_groups), + strategies=list(set(base.strategies) | set(other.strategies)), + has_param_directive=True, + ) + return merged \ No newline at end of file diff --git a/ai_test_generator/core/test_runner.py b/ai_test_generator/core/test_runner.py index 90a3e7f..0f91940 100644 --- a/ai_test_generator/core/test_runner.py +++ b/ai_test_generator/core/test_runner.py @@ -1,168 +1,278 @@ -import os -import sys +""" +测试执行器 +══════════════════════════════════════════════════════════════ +大规模支持改造: + 1. 并行执行:ThreadPoolExecutor,默认 8 个并发 + 2. 超时控制:单个用例超时不阻塞整体 + 3. 步骤级结果解析:##STEP_RESULT## 行 + 4. 进度显示:实时打印完成数 + 5. 结果持久化:保存 run_results.json +""" + +from __future__ import annotations import json -import subprocess -import time import logging -from pathlib import Path +import subprocess +import sys +import time +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field +from pathlib import Path + from config import config logger = logging.getLogger(__name__) +STEP_RESULT_PREFIX = "##STEP_RESULT##" + + +# ══════════════════════════════════════════════════════════════ +# 数据结构 +# ══════════════════════════════════════════════════════════════ + +@dataclass +class AssertionResult: + field: str + operator: str + expected: object + actual: object + passed: bool + message: str = "" + @dataclass class StepResult: step_no: int interface_name: str - status: str # PASS | FAIL | ERROR - assertions: list[dict] = field(default_factory=list) - # 每条 assertion: {"field","operator","expected","actual","passed","message"} + status: str # "PASS" | "FAIL" + assertions: list[AssertionResult] = field(default_factory=list) @dataclass class TestResult: test_id: str file_path: str - status: str # PASS | FAIL | ERROR | TIMEOUT - message: str = "" - duration: float = 0.0 - stdout: str = "" - stderr: str = "" + status: str # "PASS" | "FAIL" | "ERROR" | "TIMEOUT" + message: str = "" + duration: float = 0.0 step_results: list[StepResult] = field(default_factory=list) + stdout: str = "" + stderr: str = "" +# ══════════════════════════════════════════════════════════════ +# 执行器 +# ══════════════════════════════════════════════════════════════ + class TestRunner: + def __init__(self): + self.timeout = getattr(config, "TEST_TIMEOUT", 60) + self.max_workers = getattr(config, "TEST_MAX_WORKERS", 8) + self.python_bin = getattr(config, "TEST_PYTHON_BIN", sys.executable) + + # ── 主入口 ──────────────────────────────────────────────── + def run_all(self, test_files: list[Path]) -> list[TestResult]: - results = [] - total = len(test_files) - print(f"\n{'═'*62}") - print(f" Running {total} test(s) …") - print(f"{'═'*62}") - for idx, fp in enumerate(test_files, 1): - print(f"\n[{idx}/{total}] {fp.name}") - result = self._run_one(fp) - results.append(result) - self._print_result(result) + """ + 并行执行所有测试文件,返回结果列表。 + """ + if not test_files: + logger.warning("No test files to run.") + return [] + + total = len(test_files) + results: list[TestResult] = [] + done = 0 + + logger.info( + f"Running {total} test(s) with " + f"max_workers={self.max_workers}, timeout={self.timeout}s" + ) + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + future_map = { + executor.submit(self._run_one, f): f + for f in test_files + } + for future in as_completed(future_map): + result = future.result() + results.append(result) + done += 1 + icon = "✅" if result.status == "PASS" else "❌" + logger.info( + f" [{done}/{total}] {icon} {result.test_id} " + f"({result.status}) {result.duration:.2f}s" + ) + + # 按文件名排序,保持输出稳定 + results.sort(key=lambda r: r.file_path) return results - # ── 执行单个脚本 ────────────────────────────────────────── + # ── 执行单个测试文件 ────────────────────────────────────── def _run_one(self, file_path: Path) -> TestResult: - test_id = file_path.stem - t0 = time.time() + test_id = file_path.stem + start = time.monotonic() + try: proc = subprocess.run( - [sys.executable, str(file_path)], - capture_output=True, text=True, - timeout=config.TEST_TIMEOUT, - env=self._env(), + [self.python_bin, str(file_path)], + capture_output=True, + text=True, + timeout=self.timeout, ) - duration = time.time() - t0 - stdout = proc.stdout.strip() - stderr = proc.stderr.strip() - status, message = self._parse_output(stdout, proc.returncode) - step_results = self._parse_step_results(stdout) - return TestResult( - test_id=test_id, file_path=str(file_path), - status=status, message=message, - duration=duration, stdout=stdout, stderr=stderr, - step_results=step_results, + duration = time.monotonic() - start + return self._parse_output( + test_id, str(file_path), proc, duration ) + except subprocess.TimeoutExpired: + duration = time.monotonic() - start + logger.warning(f" TIMEOUT: {test_id} ({duration:.1f}s)") return TestResult( - test_id=test_id, file_path=str(file_path), - status="TIMEOUT", message=f"Exceeded {config.TEST_TIMEOUT}s", - duration=time.time() - t0, + test_id=test_id, + file_path=str(file_path), + status="TIMEOUT", + message=f"Exceeded {self.timeout}s timeout", + duration=duration, ) except Exception as e: + duration = time.monotonic() - start + logger.error(f" ERROR running {test_id}: {e}") return TestResult( - test_id=test_id, file_path=str(file_path), - status="ERROR", message=str(e), - duration=time.time() - t0, + test_id=test_id, + file_path=str(file_path), + status="ERROR", + message=str(e), + duration=duration, ) # ── 解析输出 ────────────────────────────────────────────── - def _parse_output(self, stdout: str, returncode: int) -> tuple[str, str]: - if not stdout: - return "FAIL", f"No output (exit={returncode})" - last = stdout.strip().splitlines()[-1].strip() - upper = last.upper() - if upper.startswith("PASS"): - return "PASS", last[5:].strip() - if upper.startswith("FAIL"): - return "FAIL", last[5:].strip() - return ("PASS" if returncode == 0 else "FAIL"), last + def _parse_output( + self, + test_id: str, + file_path: str, + proc: subprocess.CompletedProcess, + duration: float, + ) -> TestResult: + stdout = proc.stdout or "" + stderr = proc.stderr or "" + step_results = self._parse_step_results(stdout) + + # 最后一行决定整体状态 + last_line = stdout.strip().splitlines()[-1] if stdout.strip() else "" + if last_line.startswith("PASS"): + status = "PASS" + message = last_line + elif last_line.startswith("FAIL"): + status = "FAIL" + message = last_line + elif proc.returncode != 0: + status = "FAIL" + message = f"Exit code {proc.returncode}: {stderr[:200]}" + else: + status = "ERROR" + message = "No PASS/FAIL line found in output" + + return TestResult( + test_id=test_id, + file_path=file_path, + status=status, + message=message, + duration=duration, + step_results=step_results, + stdout=stdout, + stderr=stderr, + ) def _parse_step_results(self, stdout: str) -> list[StepResult]: - """ - 解析脚本中以 ##STEP_RESULT## 开头的结构化输出行 - 格式:##STEP_RESULT## - """ results = [] for line in stdout.splitlines(): line = line.strip() - if line.startswith("##STEP_RESULT##"): - try: - data = json.loads(line[len("##STEP_RESULT##"):].strip()) - results.append(StepResult( - step_no=data.get("step_no", 0), - interface_name=data.get("interface_name", ""), - status=data.get("status", ""), - assertions=data.get("assertions", []), - )) - except Exception: - pass + if not line.startswith(STEP_RESULT_PREFIX): + continue + try: + raw = json.loads(line[len(STEP_RESULT_PREFIX):].strip()) + assertions = [ + AssertionResult( + field=a.get("field", ""), + operator=a.get("operator", ""), + expected=a.get("expected"), + actual=a.get("actual"), + passed=bool(a.get("passed", False)), + message=a.get("message", ""), + ) + for a in raw.get("assertions", []) + ] + results.append(StepResult( + step_no=raw.get("step_no", 0), + interface_name=raw.get("interface_name", ""), + status=raw.get("status", "FAIL"), + assertions=assertions, + )) + except (json.JSONDecodeError, KeyError) as e: + logger.debug(f"Step result parse error: {e}") return results - def _env(self) -> dict: - env = os.environ.copy() - env["HTTP_BASE_URL"] = config.HTTP_BASE_URL - return env - - # ── 打印 ────────────────────────────────────────────────── - - def _print_result(self, r: TestResult): - icon = {"PASS": "✅", "FAIL": "❌", "TIMEOUT": "⏱️", "ERROR": "⚠️"}.get(r.status, "❓") - print(f" {icon} [{r.status}] {r.test_id} ({r.duration:.2f}s)") - print(f" {r.message}") - for sr in r.step_results: - s_icon = "✅" if sr.status == "PASS" else "❌" - print(f" {s_icon} Step {sr.step_no}: {sr.interface_name}") - for a in sr.assertions: - a_icon = "✔" if a.get("passed") else "✘" - print(f" {a_icon} {a.get('message','')} " - f"(expected={a.get('expected')}, actual={a.get('actual')})") - if r.stderr: - print(f" stderr: {r.stderr[:300]}") + # ── 摘要打印 ────────────────────────────────────────────── def print_summary(self, results: list[TestResult]): - total = len(results) - passed = sum(1 for r in results if r.status == "PASS") - failed = sum(1 for r in results if r.status == "FAIL") - errors = sum(1 for r in results if r.status in ("ERROR", "TIMEOUT")) - print(f"\n{'═'*62}") - print(f" TEST SUMMARY") - print(f"{'═'*62}") - print(f" Total : {total}") + total = len(results) + passed = sum(1 for r in results if r.status == "PASS") + failed = sum(1 for r in results if r.status == "FAIL") + errors = sum(1 for r in results if r.status in ("ERROR", "TIMEOUT")) + avg_dur = sum(r.duration for r in results) / total if total else 0 + + print(f"\n{'─' * 56}") + print(f" Test Run Summary ({total} cases)") + print(f"{'─' * 56}") print(f" ✅ PASS : {passed}") print(f" ❌ FAIL : {failed}") print(f" ⚠️ ERROR : {errors}") - print(f" Pass Rate: {passed/total*100:.1f}%" if total else " Pass Rate: N/A") - print(f"{'═'*62}\n") + print(f" ⏱ Avg : {avg_dur:.2f}s / case") + print(f"{'─' * 56}") + + if failed or errors: + print(" Failed / Error cases:") + for r in results: + if r.status not in ("PASS",): + print(f" [{r.status}] {r.test_id}: {r.message}") + print() + + # ── 持久化 ──────────────────────────────────────────────── def save_results(self, results: list[TestResult], path: str): - with open(path, "w", encoding="utf-8") as f: - json.dump([{ - "test_id": r.test_id, "status": r.status, - "message": r.message, "duration": r.duration, - "stdout": r.stdout, "stderr": r.stderr, + data = [ + { + "test_id": r.test_id, + "file_path": r.file_path, + "status": r.status, + "message": r.message, + "duration": round(r.duration, 3), "step_results": [ - {"step_no": sr.step_no, "interface_name": sr.interface_name, - "status": sr.status, "assertions": sr.assertions} - for sr in r.step_results + { + "step_no": s.step_no, + "interface_name": s.interface_name, + "status": s.status, + "assertions": [ + { + "field": a.field, + "operator": a.operator, + "expected": a.expected, + "actual": a.actual, + "passed": a.passed, + "message": a.message, + } + for a in s.assertions + ], + } + for s in r.step_results ], - } for r in results], f, ensure_ascii=False, indent=2) + } + for r in results + ] + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) logger.info(f"Run results → {path}") \ No newline at end of file diff --git a/ai_test_generator/main.py b/ai_test_generator/main.py index 5714c79..7ff44e7 100644 --- a/ai_test_generator/main.py +++ b/ai_test_generator/main.py @@ -3,11 +3,18 @@ AI-Powered API Test Generator, Runner & Coverage Analyzer ────────────────────────────────────────────────────────── Usage: + # 基础用法 python main.py --api-desc examples/api_desc.json \\ --requirements "创建用户,删除用户" + # 在需求中自然语言描述参数生成要求 python main.py --api-desc examples/api_desc.json \\ - --req-file examples/requirements.txt + --requirements "创建用户(测试参数不少于10组,覆盖边界值和异常值),删除用户" + + # 从文件读取需求(每行一条,支持参数生成指令) + python main.py --api-desc examples/api_desc.json \\ + --req-file examples/requirements.txt \\ + --batch-size 10 """ import argparse @@ -17,6 +24,7 @@ from pathlib import Path from config import config from core.parser import InterfaceParser, ApiDescriptor +from core.requirement_parser import RequirementParser, ParamConstraint from core.prompt_builder import PromptBuilder from core.llm_client import LLMClient from core.test_generator import TestGenerator @@ -24,48 +32,57 @@ from core.test_runner import TestRunner from core.analyzer import CoverageAnalyzer, CoverageReport from core.report_generator import ReportGenerator -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(name)s — %(message)s", - handlers=[logging.StreamHandler(sys.stdout)], -) + +# ══════════════════════════════════════════════════════════════ +# 日志初始化 +# ══════════════════════════════════════════════════════════════ + +def setup_logging(debug: bool = False): + """ + 配置日志级别。 + + ⚠️ Bug 规避(openai-python v2.24.0 + pydantic-core): + 第三方 SDK logger 强制锁定在 WARNING,避免触发 + model_dump(by_alias=None) → pydantic-core Rust 层 TypeError。 + Ref: https://github.com/openai/openai-python/issues/2921 + """ + root_level = logging.DEBUG if debug else logging.INFO + logging.basicConfig( + level=root_level, + format="%(asctime)s [%(levelname)s] %(name)s — %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + ) + for name in ["openai", "openai._base_client", "anthropic", "httpx", "httpcore"]: + logging.getLogger(name).setLevel(logging.WARNING) + + logger = logging.getLogger(__name__) -# ── CLI ─────────────────────────────────────────────────────── +# ══════════════════════════════════════════════════════════════ +# CLI +# ══════════════════════════════════════════════════════════════ def build_arg_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( - description="AI-Powered API Test Generator & Coverage Analyzer" - ) - p.add_argument( - "--api-desc", required=True, - help="接口描述 JSON 文件(含 project / description / units 字段)", - ) - p.add_argument( - "--requirements", default="", - help="测试需求(逗号分隔)", - ) - p.add_argument( - "--req-file", default="", - help="测试需求文件(每行一条)", - ) - p.add_argument( - "--skip-run", action="store_true", - help="只生成测试文件,不执行", - ) - p.add_argument( - "--skip-analyze", action="store_true", - help="跳过覆盖率分析", - ) - p.add_argument( - "--output-dir", default="", - help="测试文件根输出目录(默认 config.GENERATED_TESTS_DIR)", - ) - p.add_argument( - "--debug", action="store_true", - help="开启 DEBUG 日志", + description="AI-Powered API Test Generator & Coverage Analyzer", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +参数生成指令示例(可直接写在需求文本中): + "创建用户,测试参数不少于10组,覆盖边界值和异常值" + "查询接口(至少8组,边界值)" + "所有接口参数不少于5组,覆盖等价类" + """, ) + p.add_argument("--api-desc", required=True, help="接口描述 JSON 文件") + p.add_argument("--requirements", default="", help="测试需求(逗号分隔,支持参数生成指令)") + p.add_argument("--req-file", default="", help="测试需求文件(每行一条)") + p.add_argument("--batch-size", type=int, default=0, help="每批接口数量(0=使用 config.LLM_BATCH_SIZE)") + p.add_argument("--workers", type=int, default=0, help="并行执行线程数(0=使用 config.TEST_MAX_WORKERS)") + p.add_argument("--skip-run", action="store_true", help="只生成,不执行") + p.add_argument("--skip-analyze", action="store_true", help="跳过覆盖率分析") + p.add_argument("--output-dir", default="", help="测试文件根输出目录") + p.add_argument("--debug", action="store_true", help="开启 DEBUG 日志(第三方 SDK 仍保持 WARNING)") return p @@ -77,82 +94,111 @@ def load_requirements(args) -> list[str]: with open(args.req_file, "r", encoding="utf-8") as f: reqs += [line.strip() for line in f if line.strip()] if not reqs: - logger.error( - "No requirements provided. Use --requirements or --req-file." - ) + logger.error("未提供任何测试需求,请使用 --requirements 或 --req-file 指定。") sys.exit(1) return reqs -# ── Main ────────────────────────────────────────────────────── +# ══════════════════════════════════════════════════════════════ +# Main +# ══════════════════════════════════════════════════════════════ def main(): args = build_arg_parser().parse_args() + setup_logging(debug=args.debug) - if args.debug: - logging.getLogger().setLevel(logging.DEBUG) - if args.output_dir: - config.GENERATED_TESTS_DIR = args.output_dir + if args.output_dir: config.GENERATED_TESTS_DIR = args.output_dir + if args.batch_size > 0: config.LLM_BATCH_SIZE = args.batch_size + if args.workers > 0: config.TEST_MAX_WORKERS = args.workers # ── Step 1: 解析接口描述 ────────────────────────────────── - logger.info("▶ Step 1: Parsing interface description …") - parser: InterfaceParser = InterfaceParser() - descriptor: ApiDescriptor = parser.parse_file(args.api_desc) + logger.info("▶ Step 1: 解析接口描述 …") + parser = InterfaceParser() + descriptor = parser.parse_file(args.api_desc) project = descriptor.project project_desc = descriptor.description interfaces = descriptor.interfaces - logger.info(f" Project : {project}") - logger.info(f" Description: {project_desc}") - logger.info( - f" Interfaces : {len(interfaces)} — " - f"{[i.name for i in interfaces]}" - ) - # 打印每个接口的 url 解析结果 + logger.info(f" 项目名称:{project}") + logger.info(f" 项目描述:{project_desc}") + logger.info(f" 接口数量:{len(interfaces)}") for iface in interfaces: if iface.protocol == "function": logger.info( - f" [{iface.name}] source={iface.source_file} " - f"module={iface.module_path}" + f" [function] {iface.name} " + f"← {iface.source_file} (module: {iface.module_path})" ) else: - logger.info( - f" [{iface.name}] full_url={iface.http_full_url}" - ) + logger.info(f" [http] {iface.name} ← {iface.http_full_url}") - # ── Step 2: 加载测试需求 ────────────────────────────────── - logger.info("▶ Step 2: Loading test requirements …") - requirements = load_requirements(args) - for i, req in enumerate(requirements, 1): - logger.info(f" {i}. {req}") + # ── Step 2: 加载并解析测试需求(含参数生成指令)────────── + logger.info("▶ Step 2: 加载并解析测试需求 …") + raw_requirements = load_requirements(args) - # ── Step 3: 调用 LLM 生成测试用例 ──────────────────────── - logger.info("▶ Step 3: Calling LLM …") - builder = PromptBuilder() - test_cases = LLMClient().generate_test_cases( - builder.get_system_prompt(), - builder.build_user_prompt( - requirements, interfaces, parser, - project=project, - project_desc=project_desc, - ), + req_parser = RequirementParser() + clean_reqs, per_req_constraints, global_constraint = req_parser.parse( + raw_requirements ) - logger.info(f" LLM returned {len(test_cases)} test case(s)") + + # 打印原始需求 & 解析结果 + for i, (raw, clean, c) in enumerate( + zip(raw_requirements, clean_reqs, per_req_constraints), 1 + ): + if c.has_param_directive: + logger.info( + f" {i}. {raw}\n" + f" └─ 参数约束:{c} | 业务需求:{clean}" + ) + else: + logger.info(f" {i}. {raw}") + + if global_constraint.has_param_directive: + logger.info(f" 全局参数约束:{global_constraint}") + + # 使用"清洗后的需求"传给 LLM(去掉参数指令,保留纯业务描述) + # 同时保留原始需求用于覆盖率分析(保证与用户输入一致) + effective_requirements = clean_reqs + + # ── Step 3: 构建 Prompt 并调用 LLM 生成测试用例 ────────── + logger.info( + f"▶ Step 3: 调用 LLM 生成测试用例 " + f"(batch_size={getattr(config, 'LLM_BATCH_SIZE', 10)}) …" + ) + builder = PromptBuilder() + iface_summary = parser.to_summary_dict(interfaces) + project_header = builder.build_project_header(project, project_desc) + + # System Prompt:注入全局参数约束规则 + system_prompt = builder.get_system_prompt( + global_constraint=global_constraint + ) + + test_cases = LLMClient().generate_test_cases( + system_prompt=system_prompt, + user_prompt="", + iface_summaries=iface_summary, + requirements=effective_requirements, + project_header=project_header, + param_constraint=global_constraint, # ← 注入参数约束 + ) + logger.info(f" 共生成测试用例:{len(test_cases)} 个") # ── Step 4: 生成测试文件 ────────────────────────────────── - logger.info( - f"▶ Step 4: Generating test files (project='{project}') …" - ) + logger.info("▶ Step 4: 生成测试文件 …") generator = TestGenerator(project=project, project_desc=project_desc) test_files = generator.generate(test_cases) out = generator.output_dir + logger.info(f" 输出目录:{out.resolve()}") run_results = [] if not args.skip_run: - # ── Step 5: 执行测试 ────────────────────────────────── - logger.info("▶ Step 5: Running tests …") + # ── Step 5: 并行执行测试 ────────────────────────────── + logger.info( + f"▶ Step 5: 执行测试 " + f"(workers={getattr(config, 'TEST_MAX_WORKERS', 8)}) …" + ) runner = TestRunner() run_results = runner.run_all(test_files) runner.print_summary(run_results) @@ -160,31 +206,36 @@ def main(): if not args.skip_analyze: # ── Step 6: 覆盖率分析 ──────────────────────────────── - logger.info("▶ Step 6: Analyzing coverage …") + logger.info("▶ Step 6: 覆盖率分析 …") report = CoverageAnalyzer( interfaces=interfaces, - requirements=requirements, + requirements=raw_requirements, # 用原始需求做覆盖率分析 test_cases=test_cases, run_results=run_results, ).analyze() # ── Step 7: 生成报告 ────────────────────────────────── - logger.info("▶ Step 7: Generating reports …") + logger.info("▶ Step 7: 生成报告 …") rg = ReportGenerator() rg.save_json(report, str(out / "coverage_report.json")) rg.save_html(report, str(out / "coverage_report.html")) - _print_terminal_summary(report, out, project) + _print_terminal_summary(report, out, project, global_constraint) - logger.info(f"\n✅ Done. Output: {out.resolve()}") + logger.info(f"\n✅ 完成。输出目录:{out.resolve()}") -# ── 终端摘要 ────────────────────────────────────────────────── +# ══════════════════════════════════════════════════════════════ +# 终端摘要 +# ══════════════════════════════════════════════════════════════ def _print_terminal_summary( - report: CoverageReport, out: Path, project: str + report: "CoverageReport", + out: Path, + project: str, + global_constraint: ParamConstraint, ): - W = 66 + W = 68 def bar(rate: float, w: int = 20) -> str: filled = int(rate * w) @@ -193,35 +244,41 @@ def _print_terminal_summary( return f"{'█' * filled}{'░' * empty} {rate * 100:.1f}% {icon}" print(f"\n{'═' * W}") - print(f" PROJECT : {project}") - print(f" COVERAGE SUMMARY") + print(f" 项目:{project}") + print(f" 覆盖率摘要") print(f"{'═' * W}") - print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}") - print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}") - print(f" 入参覆盖率 {bar(report.avg_in_param_coverage_rate)}") - print(f" 成功返回字段覆盖 {bar(report.avg_success_field_coverage_rate)}") - print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}") - print(f" 用例通过率 {bar(report.pass_rate)}") + print(f" 接口覆盖率 {bar(report.interface_coverage_rate)}") + print(f" 需求覆盖率 {bar(report.requirement_coverage_rate)}") + print(f" 入参覆盖率 {bar(report.avg_in_param_coverage_rate)}") + print(f" 成功返回字段覆盖 {bar(report.avg_success_field_coverage_rate)}") + print(f" 失败返回字段覆盖 {bar(report.avg_failure_field_coverage_rate)}") + print(f" 用例通过率 {bar(report.pass_rate)}") print(f"{'─' * W}") - print(f" Total Gaps : {len(report.gaps)}") - print(f" 🔴 Critical: {report.critical_gap_count}") - print(f" 🟠 High : {report.high_gap_count}") + print(f" 测试用例总数 {report.total_test_cases}") + print(f" 覆盖缺口总数 {len(report.gaps)}") + print(f" 🔴 严重缺口 {report.critical_gap_count}") + print(f" 🟠 高优先级缺口 {report.high_gap_count}") + + # 参数约束达成情况 + if global_constraint.has_param_directive: + print(f"{'─' * W}") + print(f" 参数生成约束 {global_constraint}") + actual = report.total_test_cases + needed = global_constraint.min_groups + status = "✅ 已满足" if actual >= needed else f"❌ 未满足(需 {needed} 组,实际 {actual} 组)" + print(f" 参数组数要求 {status}") if report.gaps: print(f"{'─' * W}") - print(" Top Gaps (up to 8):") - icons = { - "critical": "🔴", "high": "🟠", - "medium": "🟡", "low": "🔵", - } + print(" Top 缺口(最多显示8条):") + icons = {"critical": "🔴", "high": "🟠", "medium": "🟡", "low": "🔵"} for g in report.gaps[:8]: - icon = icons.get(g.severity, "⚪") - print(f" {icon} [{g.gap_type}] {g.target}") + print(f" {icons.get(g.severity, '⚪')} [{g.gap_type}] {g.target}") print(f" → {g.suggestion}") print(f"{'─' * W}") - print(f" Output : {out.resolve()}") - print(f" • coverage_report.html ← open in browser") + print(f" 输出目录:{out.resolve()}") + print(f" • coverage_report.html") print(f" • coverage_report.json") print(f" • run_results.json") print(f" • test_cases_summary.json") diff --git a/ai_test_generator/run.sh b/ai_test_generator/run.sh index 6da9f8c..c4d8f5e 100755 --- a/ai_test_generator/run.sh +++ b/ai_test_generator/run.sh @@ -16,4 +16,4 @@ export HTTP_BASE_URL="http://localhost:8080" # 只生成不执行 python main.py --api-desc examples/api_desc.json \ - --requirements "1.对每个接口进行测试,支持特殊值、边界值测试" + --requirements "每个测试用例的参数不少于10组"