AIDeveloper-PC/requirements_generator/main.py

789 lines
30 KiB
Python
Raw Normal View History

2026-03-04 18:09:45 +00:00
#!/usr/bin/env python3
# encoding: utf-8
# main.py - 主入口:支持交互式 & 非交互式(CLI 参数)两种运行模式
#
# 交互式: python main.py
# 非交互式: python main.py --non-interactive \
#     --project-name "MyProject" \
#     --language python \
#     --requirement-text "用户管理系统,包含注册、登录、修改密码功能"
#
# 完整参数见: python main.py --help
import os
import sys
from typing import Dict
import click
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.prompt import Prompt, Confirm
import config
from database.db_manager import DBManager
from database.models import Project, RawRequirement, FunctionalRequirement
from core.llm_client import LLMClient
from core.requirement_analyzer import RequirementAnalyzer
from core.code_generator import CodeGenerator
from utils.file_handler import read_file_auto, merge_knowledge_files
from utils.output_writer import (
ensure_project_dir, build_project_output_dir, write_project_readme,
write_function_signatures_json, validate_all_signatures,
patch_signatures_with_url,
)
# Shared Rich console; every function in this module prints through it.
console = Console()
# Module-level database manager instance used by the workflow steps below.
# NOTE(review): instantiated at import time — presumably opens/initialises
# the database as a side effect; confirm in database.db_manager.
db = DBManager()
# ══════════════════════════════════════════════════════
# 显示工具函数
# ══════════════════════════════════════════════════════
def print_banner():
    """Print the tool's title banner inside a cyan-bordered, fitted panel."""
    banner_markup = (
        "[bold cyan]🚀 需求分析 & 代码生成工具[/bold cyan]\n"
        "[dim]Powered by LLM · SQLite · Python[/dim]"
    )
    banner = Panel.fit(banner_markup, border_style="cyan")
    console.print(banner)
def print_functional_requirements(reqs: list):
    """Render the functional requirements as a Rich table.

    Dynamic text (title, function name, priority, description) is escaped
    before being handed to Rich, so square brackets coming from user input
    or LLM output are displayed literally instead of being parsed as
    console markup. A missing (``None``) title or description is rendered
    as an empty cell rather than raising.

    Args:
        reqs: Requirement objects exposing ``index_no``, ``id``, ``title``,
            ``function_name``, ``priority``, ``is_custom`` and
            ``description`` attributes.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import list (rich is already a dependency of this file).
    from rich.markup import escape

    def _text(value) -> str:
        """Return *value* as markup-safe text; ``None`` becomes an empty string."""
        return "" if value is None else escape(str(value))

    table = Table(title="📋 功能需求列表", show_lines=True)
    table.add_column("序号", style="cyan", width=6)
    table.add_column("ID", style="dim", width=6)
    table.add_column("标题", style="bold", width=20)
    table.add_column("函数名", width=25)
    table.add_column("优先级", width=8)
    table.add_column("类型", width=8)
    table.add_column("描述", width=40)

    priority_color = {"high": "red", "medium": "yellow", "low": "green"}
    for req in reqs:
        color = priority_color.get(req.priority, "white")

        # Truncate the raw text first, then escape, so an escape sequence is
        # never cut in half by the slice.
        description = req.description or ""
        if len(description) > 60:
            description = description[:60] + "..."

        table.add_row(
            str(req.index_no),
            str(req.id) if req.id else "-",
            _text(req.title),
            f"[code]{escape(str(req.function_name))}[/code]",
            f"[{color}]{escape(str(req.priority))}[/{color}]",
            "[magenta]自定义[/magenta]" if req.is_custom else "LLM生成",
            _text(description),
        )
    console.print(table)
def print_signatures_preview(signatures: list):
    """Preview a list of function signatures as a Rich table (with URL column).

    The signature dicts are treated defensively: ``"return"``,
    ``"parameters"`` and ``"url"`` may be missing or ``None``, cell values
    are coerced to text (Rich cannot render e.g. an ``int`` cell), and
    dynamic text is markup-escaped so brackets are shown literally.

    Args:
        signatures: Plain list of signature dicts, i.e. the ``"functions"``
            field of the top-level signature document.
    """
    # Local import: rich is already a dependency of this file.
    from rich.markup import escape

    table = Table(title="📄 函数签名预览", show_lines=True)
    table.add_column("需求编号", style="cyan", width=10)
    table.add_column("函数名", style="bold", width=25)
    table.add_column("参数数量", width=8)
    table.add_column("返回类型", width=10)
    table.add_column("成功返回值", width=18)
    table.add_column("失败返回值", width=18)
    table.add_column("URL", style="dim", width=30)

    def _fmt_value(v) -> str:
        """Compact display form: '-' for None, '{k1, k2}' for dicts, else 16 chars."""
        if v is None:
            return "-"
        if isinstance(v, dict):
            # str() each key so non-string keys cannot break the join.
            return "{" + ", ".join(map(str, v)) + "}"
        return str(v)[:16]

    def _cell(value, default: str) -> str:
        """Markup-safe text for a table cell, falling back to *default* for None."""
        return default if value is None else escape(str(value))

    for sig in signatures:
        ret = sig.get("return") or {}
        on_success = ret.get("on_success") or {}
        on_failure = ret.get("on_failure") or {}
        # Same null-guard as "return": the key may be present with a None value.
        parameters = sig.get("parameters") or {}
        url = sig.get("url") or ""
        # Show only the file name so long paths do not blow up the column.
        url_display = escape(os.path.basename(url)) if url else "[dim]待生成[/dim]"
        table.add_row(
            _cell(sig.get("requirement_id"), "-"),
            _cell(sig.get("name"), "-"),
            str(len(parameters)),
            _cell(ret.get("type"), "void"),
            escape(_fmt_value(on_success.get("value"))),
            escape(_fmt_value(on_failure.get("value"))),
            url_display,
        )
    console.print(table)
# ══════════════════════════════════════════════════════
# Step 1: 项目初始化
# ══════════════════════════════════════════════════════
def step_init_project(
    project_name: str = None,
    language: str = None,
    description: str = "",
    non_interactive: bool = False,
) -> Project:
    """Step 1: resolve the project settings and load or create the project.

    In interactive mode any value not supplied by the caller is prompted
    for. If a project with the chosen name already exists, the user may
    reuse it or enter another name; the replacement name is checked again,
    so a second collision is detected instead of reaching
    ``db.create_project`` with a name that is already taken.

    In non-interactive mode an existing project with the same name is
    reused as-is.

    Args:
        project_name: Project name; required when ``non_interactive``.
        language: Target code language; defaults to
            ``config.DEFAULT_LANGUAGE`` when not given.
        description: Optional project description.
        non_interactive: When True, never prompt.

    Returns:
        The loaded or newly created ``Project`` (with ``id`` set).

    Raises:
        ValueError: In non-interactive mode when ``project_name`` is empty.
    """
    if non_interactive:
        if not project_name:
            raise ValueError("非交互模式下 --project-name 为必填项")
        language = language or config.DEFAULT_LANGUAGE
        console.print("\n[bold]Step 1 · 项目配置[/bold] [dim](非交互)[/dim]", style="blue")
        console.print(f" 项目名称: {project_name} 语言: {language}")
    else:
        console.print("\n[bold]Step 1 · 项目配置[/bold]", style="blue")
        project_name = project_name or Prompt.ask("📁 请输入项目名称")
        language = language or Prompt.ask(
            "💻 目标代码语言",
            default=config.DEFAULT_LANGUAGE,
            choices=["python", "javascript", "typescript", "java", "go", "rust"],
        )
        description = description or Prompt.ask("📝 项目描述(可选)", default="")

    # Keep resolving until the name is free or an existing project is chosen.
    while True:
        existing = db.get_project_by_name(project_name)
        if not existing:
            break
        if non_interactive:
            console.print(f"[green]✓ 已加载已有项目: {project_name} (ID={existing.id})[/green]")
            return existing
        if Confirm.ask(f"⚠️ 项目 '{project_name}' 已存在,是否继续使用?"):
            console.print(f"[green]✓ 已加载项目: {project_name} (ID={existing.id})[/green]")
            return existing
        # Declined: ask for another name and re-check it on the next pass.
        project_name = Prompt.ask("请输入新的项目名称")

    output_dir = build_project_output_dir(project_name)
    project = Project(
        name=project_name,
        language=language,
        output_dir=output_dir,
        description=description,
    )
    project.id = db.create_project(project)
    console.print(f"[green]✓ 项目已创建: {project_name} (ID={project.id})[/green]")
    return project
# ══════════════════════════════════════════════════════
# Step 2: 输入原始需求 & 知识库
# ══════════════════════════════════════════════════════
def step_input_requirement(
    project: Project,
    requirement_text: str = None,
    requirement_file: str = None,
    knowledge_files: list = None,
    non_interactive: bool = False,
) -> tuple:
    """Step 2: collect the raw requirement text and optional knowledge base.

    Non-interactive mode takes the requirement from ``requirement_file``
    (preferred) or ``requirement_text``. Interactive mode prompts for
    either typed text (terminated by a blank line) or a file path, and
    optionally for knowledge-base file paths.

    Typed input ignores blank lines entered before any content, treats
    end-of-input (EOF) like the terminating blank line, and an empty
    requirement is rejected instead of being passed on.

    Args:
        project: Current project (not read by this step).
        requirement_text: Requirement text (non-interactive mode).
        requirement_file: Requirement file path (non-interactive mode).
        knowledge_files: Knowledge-base file paths (non-interactive mode).
        non_interactive: When True, never prompt.

    Returns:
        ``(raw_text, knowledge_text, source_name, source_type)`` where
        ``source_type`` is ``"text"`` or ``"file"`` and ``source_name`` is
        the file's base name or ``None`` for typed text.

    Raises:
        ValueError: In non-interactive mode when neither requirement source
            is given, or in any mode when the requirement text is empty.
    """
    console.print(
        f"\n[bold]Step 2 · 输入原始需求[/bold]"
        + (" [dim](非交互)[/dim]" if non_interactive else ""),
        style="blue",
    )
    raw_text = ""
    source_name = None
    source_type = "text"

    if non_interactive:
        if requirement_file:
            raw_text = read_file_auto(requirement_file)
            source_name = os.path.basename(requirement_file)
            source_type = "file"
            console.print(f" 需求文件: {source_name} ({len(raw_text)} 字符)")
        elif requirement_text:
            raw_text = requirement_text
            source_type = "text"
            console.print(f" 需求文本: {raw_text[:80]}{'...' if len(raw_text) > 80 else ''}")
        else:
            raise ValueError("非交互模式下必须提供 --requirement-text 或 --requirement-file")
    else:
        input_type = Prompt.ask("📥 需求输入方式", choices=["text", "file"], default="text")
        if input_type == "text":
            console.print("[dim]请输入原始需求(输入空行结束):[/dim]")
            lines = []
            while True:
                try:
                    line = input()
                except EOFError:
                    # Piped/closed stdin: treat as end of the requirement text.
                    break
                if line == "":
                    if lines:
                        break
                    # Blank line before any content: ignore it rather than
                    # storing it (which would let the next blank line end
                    # the input with no real text).
                    continue
                lines.append(line)
            raw_text = "\n".join(lines)
            source_type = "text"
        else:
            file_path = Prompt.ask("📂 需求文件路径")
            raw_text = read_file_auto(file_path)
            source_name = os.path.basename(file_path)
            source_type = "file"
            console.print(f"[green]✓ 已读取文件: {source_name} ({len(raw_text)} 字符)[/green]")

    # An empty requirement cannot be decomposed; fail early with a clear message.
    if not raw_text or not raw_text.strip():
        raise ValueError("原始需求内容为空,无法继续")

    knowledge_text = ""
    if non_interactive:
        if knowledge_files:
            knowledge_text = merge_knowledge_files(list(knowledge_files))
            console.print(f" 知识库: {len(knowledge_files)} 个文件,{len(knowledge_text)} 字符")
    else:
        use_kb = Confirm.ask("📚 是否输入知识库文件?", default=False)
        if use_kb:
            kb_paths = []
            while True:
                kb_path = Prompt.ask("知识库文件路径(留空结束)", default="")
                if not kb_path:
                    break
                if os.path.exists(kb_path):
                    kb_paths.append(kb_path)
                    console.print(f" [green]+ {kb_path}[/green]")
                else:
                    console.print(f" [red]文件不存在: {kb_path}[/red]")
            if kb_paths:
                knowledge_text = merge_knowledge_files(kb_paths)
                console.print(f"[green]✓ 知识库已合并 ({len(knowledge_text)} 字符)[/green]")

    return raw_text, knowledge_text, source_name, source_type
# ══════════════════════════════════════════════════════
# Step 3: LLM 分解需求
# ══════════════════════════════════════════════════════
def step_decompose_requirements(
    project: Project,
    raw_text: str,
    knowledge_text: str,
    source_name: str,
    source_type: str,
    non_interactive: bool = False,
) -> tuple:
    """Step 3: store the raw requirement and let the LLM decompose it.

    The raw requirement is persisted first; the analyzer then splits it
    into functional requirements, each of which is persisted and given
    its database id.

    Args:
        project: Project the requirement belongs to.
        raw_text: Raw requirement text.
        knowledge_text: Merged knowledge-base text (may be empty).
        source_name: Source file name, or None for typed text.
        source_type: ``"text"`` or ``"file"``.
        non_interactive: Only affects the printed step heading.

    Returns:
        ``(raw_req_id, func_reqs)``.
    """
    mode_tag = " [dim](非交互)[/dim]" if non_interactive else ""
    console.print(f"\n[bold]Step 3 · LLM 需求分解[/bold]" + mode_tag, style="blue")

    # Persist the raw requirement; an empty knowledge text is stored as None.
    raw_req_id = db.create_raw_requirement(
        RawRequirement(
            project_id=project.id,
            content=raw_text,
            source_type=source_type,
            source_name=source_name,
            knowledge=knowledge_text or None,
        )
    )
    console.print(f"[dim]原始需求已存储 (ID={raw_req_id})[/dim]")

    # Show a spinner while the LLM call is in flight.
    with console.status("[bold yellow]🤖 LLM 正在分解需求,请稍候...[/bold yellow]"):
        analyzer = RequirementAnalyzer(LLMClient())
        func_reqs = analyzer.decompose(
            raw_requirement=raw_text,
            project_id=project.id,
            raw_req_id=raw_req_id,
            knowledge=knowledge_text,
        )

    # Save every decomposed requirement and record its new id on the object.
    for item in func_reqs:
        item.id = db.create_functional_requirement(item)

    console.print(f"[green]✓ 已生成 {len(func_reqs)} 个功能需求[/green]")
    return raw_req_id, func_reqs
# ══════════════════════════════════════════════════════
# Step 4: 用户编辑功能需求
# ══════════════════════════════════════════════════════
def _parse_index_list(text: str) -> set:
    """Parse a comma-separated list of positive integers; invalid tokens are skipped."""
    # Accept the full-width comma as well, which is what a CJK input method
    # produces by default.
    tokens = text.replace(",", ",").split(",")
    # isdecimal() only accepts characters that int() can convert.
    return {int(tok.strip()) for tok in tokens if tok.strip().isdecimal()}


def _drop_requirements_by_index(func_reqs: list, indices) -> tuple:
    """Delete requirements whose ``index_no`` is in *indices* and renumber the rest.

    Returns ``(kept_requirements, removed_titles)``.
    """
    targets = set(indices)
    removed_titles, kept = [], []
    for req in func_reqs:
        if req.index_no in targets:
            db.delete_functional_requirement(req.id)
            removed_titles.append(req.title)
        else:
            kept.append(req)
    # Close the numbering gaps and persist the new sequence numbers.
    for position, req in enumerate(kept, 1):
        req.index_no = position
        db.update_functional_requirement(req)
    return kept, removed_titles


def step_edit_requirements(
    project: Project,
    func_reqs: list,
    raw_req_id: int,
    non_interactive: bool = False,
    skip_indices: list = None,
) -> list:
    """Step 4: let the user (or ``skip_indices``) adjust the requirement list.

    Non-interactive mode removes the requirements listed in
    ``skip_indices``, renumbers the rest, prints the table and returns.
    Interactive mode loops over delete / add / edit actions until the user
    confirms with ``ok``; every change is persisted through ``db``.

    Args:
        project: Project that newly added requirements are attached to.
        func_reqs: Current functional requirements.
        raw_req_id: Id of the raw requirement, used for added requirements.
        non_interactive: When True, never prompt.
        skip_indices: Sequence numbers to drop in non-interactive mode.

    Returns:
        The resulting list of functional requirements.
    """
    console.print(
        f"\n[bold]Step 4 · 编辑功能需求[/bold]"
        + (" [dim](非交互)[/dim]" if non_interactive else ""),
        style="blue",
    )

    if non_interactive:
        if skip_indices:
            func_reqs, removed = _drop_requirements_by_index(func_reqs, skip_indices)
            if removed:
                console.print(f" [red]已跳过: {', '.join(removed)}[/red]")
        print_functional_requirements(func_reqs)
        return func_reqs

    while True:
        print_functional_requirements(func_reqs)
        console.print(
            "\n操作: [cyan]d[/cyan]=删除 [cyan]a[/cyan]=添加 "
            "[cyan]e[/cyan]=编辑 [cyan]ok[/cyan]=确认继续"
        )
        action = Prompt.ask("请选择操作", default="ok").strip().lower()

        if action == "ok":
            break

        elif action == "d":
            idx_str = Prompt.ask("输入要删除的功能需求序号(多个用逗号分隔)")
            func_reqs, removed = _drop_requirements_by_index(
                func_reqs, _parse_index_list(idx_str)
            )
            if removed:
                console.print(f"[red]✗ 已删除: {', '.join(removed)}[/red]")
            else:
                # Nothing matched: say so instead of printing an empty list.
                console.print("[red]序号不存在[/red]")

        elif action == "a":
            title = Prompt.ask("功能标题")
            description = Prompt.ask("功能描述")
            func_name = Prompt.ask("函数名 (snake_case)")
            priority = Prompt.ask(
                "优先级", choices=["high", "medium", "low"], default="medium"
            )
            new_req = FunctionalRequirement(
                project_id=project.id,
                raw_req_id=raw_req_id,
                index_no=len(func_reqs) + 1,
                title=title,
                description=description,
                function_name=func_name,
                priority=priority,
                is_custom=True,
            )
            new_req.id = db.create_functional_requirement(new_req)
            func_reqs.append(new_req)
            console.print(f"[green]✓ 已添加自定义需求: {title}[/green]")

        elif action == "e":
            idx_str = Prompt.ask("输入要编辑的功能需求序号").strip()
            target = None
            if idx_str.isdecimal():
                idx = int(idx_str)
                target = next((r for r in func_reqs if r.index_no == idx), None)
            if target is None:
                console.print("[red]序号不存在[/red]")
                continue
            target.title = Prompt.ask("新标题", default=target.title)
            target.description = Prompt.ask("新描述", default=target.description)
            target.function_name = Prompt.ask("新函数名", default=target.function_name)
            target.priority = Prompt.ask(
                "新优先级", choices=["high", "medium", "low"], default=target.priority
            )
            db.update_functional_requirement(target)
            console.print(f"[green]✓ 已更新: {target.title}[/green]")

        else:
            # Unknown input: tell the user instead of silently redrawing.
            console.print("[red]无效操作,请输入 d / a / e / ok[/red]")

    return func_reqs
# ══════════════════════════════════════════════════════
# Step 5A: 生成函数签名 JSON(不含 url 字段,待 5C 回写)
# ══════════════════════════════════════════════════════
def step_generate_signatures(
    project: Project,
    func_reqs: list,
    output_dir: str,
    knowledge_text: str,
    json_file_name: str = "function_signatures.json",
    non_interactive: bool = False,
) -> tuple:
    """Step 5A: generate function signatures and write the first JSON draft.

    The draft does not contain the ``url`` field yet; Step 5C patches it in
    after code generation and rewrites the JSON file.

    Returns:
        ``(signatures, json_path)`` — the list of signature dicts and the
        path of the written JSON file.
    """
    heading = f"\n[bold]Step 5A · 生成函数签名 JSON[/bold]"
    if non_interactive:
        heading += " [dim](非交互)[/dim]"
    console.print(heading, style="blue")

    analyzer = RequirementAnalyzer(LLMClient())

    # One entry per reported requirement: True = generated, False = fallback.
    outcomes = []

    def on_progress(index, total, req, signature, error):
        """Print one progress line per requirement and record the outcome."""
        if error:
            console.print(
                f" [{index}/{total}] [yellow]⚠ {req.title} 签名生成失败,"
                f"使用降级结构: {error}[/yellow]"
            )
            outcomes.append(False)
            return
        console.print(
            f" [{index}/{total}] [green]✓ {req.title}[/green] "
            f"→ [dim]{signature.get('name')}()[/dim] "
            f"params={len(signature.get('parameters', {}))}"
        )
        outcomes.append(True)

    console.print(f"[yellow]正在为 {len(func_reqs)} 个功能需求生成函数签名...[/yellow]")
    signatures = analyzer.build_function_signatures_batch(
        func_reqs=func_reqs,
        knowledge=knowledge_text,
        on_progress=on_progress,
    )

    # Structural validation: the report maps function name -> list of problems.
    validation_report = validate_all_signatures(signatures)
    if not validation_report:
        console.print("[green]✓ 所有签名结构校验通过[/green]")
    else:
        console.print(f"[yellow]⚠ 发现 {len(validation_report)} 个签名存在结构问题:[/yellow]")
        for fname, errors in validation_report.items():
            for err in errors:
                console.print(f" [yellow]· {fname}: {err}[/yellow]")

    # Write the first draft of the JSON (url field not filled in yet).
    json_path = write_function_signatures_json(
        output_dir=output_dir,
        signatures=signatures,
        project_name=project.name,
        project_description=project.description or "",
        file_name=json_file_name,
    )

    success_count = outcomes.count(True)
    fail_count = outcomes.count(False)
    console.print(
        f"[green]✓ 签名 JSON 初版已写入: [cyan]{os.path.abspath(json_path)}[/cyan][/green]\n"
        f" 成功: {success_count} 降级: {fail_count}"
    )
    return signatures, json_path
# ══════════════════════════════════════════════════════
# Step 5B: 生成代码文件,收集 {函数名: 文件路径} 映射
# ══════════════════════════════════════════════════════
def step_generate_code(
    project: Project,
    func_reqs: list,
    output_dir: str,
    knowledge_text: str,
    signatures: list,
    non_interactive: bool = False,
) -> Dict[str, str]:
    """Step 5B: generate the code files under the signature constraints.

    Returns:
        Mapping ``{function_name: absolute code file path}`` used by
        Step 5C to patch the ``url`` field. Functions whose generation
        failed do not appear in the mapping.
    """
    heading = f"\n[bold]Step 5B · 生成代码文件[/bold]"
    if non_interactive:
        heading += " [dim](非交互)[/dim]"
    console.print(heading, style="blue")

    generator = CodeGenerator(LLMClient())
    func_name_to_url: Dict[str, str] = {}
    # One entry per reported requirement: True = generated, False = failed.
    outcomes = []

    def on_progress(index, total, req, code_file, error):
        """Persist a generated file (or report the failure) for one requirement."""
        if error:
            console.print(f" [{index}/{total}] [red]✗ {req.title}: {error}[/red]")
            outcomes.append(False)
            return
        db.upsert_code_file(code_file)
        req.status = "generated"
        db.update_functional_requirement(req)
        # Remember function name -> absolute file path for the url patch step.
        func_name_to_url[req.function_name] = os.path.abspath(code_file.file_path)
        console.print(
            f" [{index}/{total}] [green]✓ {req.title}[/green] "
            f"→ [dim]{code_file.file_name}[/dim]"
        )
        outcomes.append(True)

    console.print(f"[yellow]开始生成 {len(func_reqs)} 个代码文件(签名约束模式)...[/yellow]")
    generator.generate_batch(
        func_reqs=func_reqs,
        output_dir=output_dir,
        language=project.language,
        knowledge=knowledge_text,
        signatures=signatures,
        on_progress=on_progress,
    )

    # Numbered Markdown summary of all requirements for the project README.
    summary_lines = [
        f"{number}. **{r.title}** (`{r.function_name}`) - {r.description[:80]}"
        for number, r in enumerate(func_reqs, start=1)
    ]
    req_summary = "\n".join(summary_lines)
    write_project_readme(output_dir, project.name, req_summary)

    success_count = outcomes.count(True)
    fail_count = outcomes.count(False)
    console.print(Panel(
        f"[bold green]✅ 代码生成完成![/bold green]\n"
        f"成功: {success_count} 失败: {fail_count}\n"
        f"输出目录: [cyan]{os.path.abspath(output_dir)}[/cyan]",
        border_style="green",
    ))
    return func_name_to_url
# ══════════════════════════════════════════════════════
# Step 5C: 回写 url 字段并刷新 JSON
# ══════════════════════════════════════════════════════
def step_patch_signatures_url(
    project: Project,
    signatures: list,
    func_name_to_url: Dict[str, str],
    output_dir: str,
    json_file_name: str,
    non_interactive: bool = False,
) -> str:
    """Step 5C: write code file paths into each signature's "url" and rewrite the JSON.

    Flow:
        1. ``patch_signatures_with_url()`` modifies the signature list in place.
        2. The final signature preview (including the url column) is printed.
        3. ``write_function_signatures_json()`` overwrites the JSON file.

    Args:
        project: Project object; provides ``name`` and ``description``.
        signatures: Signature list produced by Step 5A (modified in place).
        func_name_to_url: ``{function name: absolute file path}`` mapping
            collected by Step 5B.
        output_dir: Directory containing the JSON file.
        json_file_name: JSON file name.
        non_interactive: Only affects the printed step heading.

    Returns:
        Absolute path of the refreshed JSON file.
    """
    console.print(
        f"\n[bold]Step 5C · 回写代码文件路径(url)到签名 JSON[/bold]"
        + (" [dim](非交互)[/dim]" if non_interactive else ""),
        style="blue",
    )

    # Patch the url field in place.
    patch_signatures_with_url(signatures, func_name_to_url)
    patched = sum(1 for s in signatures if s.get("url"))
    unpatched = len(signatures) - patched
    if unpatched:
        console.print(
            f"[yellow]⚠ {unpatched} 个函数未能写入 url"
            f"(对应代码文件生成失败)[/yellow]"
        )

    # Final preview, now including the url column.
    print_signatures_preview(signatures)

    # Overwrite the JSON file (includes project.description).
    json_path = write_function_signatures_json(
        output_dir=output_dir,
        signatures=signatures,
        project_name=project.name,
        project_description=project.description or "",
        file_name=json_file_name,
    )
    absolute_json_path = os.path.abspath(json_path)
    console.print(
        f"[green]✓ 签名 JSON 已更新(含 url): "
        f"[cyan]{absolute_json_path}[/cyan][/green]\n"
        f" 已回写: {patched} 未回写: {unpatched}"
    )
    return absolute_json_path
# ══════════════════════════════════════════════════════
# 核心工作流
# ══════════════════════════════════════════════════════
def run_workflow(
    project_name: str = None,
    language: str = None,
    description: str = "",
    requirement_text: str = None,
    requirement_file: str = None,
    knowledge_files: tuple = (),
    skip_indices: list = None,
    json_file_name: str = "function_signatures.json",
    non_interactive: bool = False,
):
    """Run the complete workflow, Step 1 through Step 5C.

    Stops after Step 4 (with a message) when no functional requirements
    remain; otherwise finishes by printing a summary panel.
    """
    print_banner()

    # Step 1: load or create the project.
    project = step_init_project(
        project_name=project_name,
        language=language,
        description=description,
        non_interactive=non_interactive,
    )

    # Keyword arguments shared by every later step.
    shared = {"project": project, "non_interactive": non_interactive}

    # Step 2: raw requirement text and knowledge base.
    kb_files = list(knowledge_files) if knowledge_files else []
    raw_text, knowledge_text, source_name, source_type = step_input_requirement(
        requirement_text=requirement_text,
        requirement_file=requirement_file,
        knowledge_files=kb_files,
        **shared,
    )

    # Step 3: LLM decomposition into functional requirements.
    raw_req_id, func_reqs = step_decompose_requirements(
        raw_text=raw_text,
        knowledge_text=knowledge_text,
        source_name=source_name,
        source_type=source_type,
        **shared,
    )

    # Step 4: user edits (or skip list in non-interactive mode).
    func_reqs = step_edit_requirements(
        func_reqs=func_reqs,
        raw_req_id=raw_req_id,
        skip_indices=skip_indices or [],
        **shared,
    )
    if not func_reqs:
        console.print("[red]⚠ 功能需求列表为空,流程终止[/red]")
        return

    output_dir = ensure_project_dir(project.name)

    # Step 5A: first signature draft (no url field yet).
    signatures, json_path = step_generate_signatures(
        func_reqs=func_reqs,
        output_dir=output_dir,
        knowledge_text=knowledge_text,
        json_file_name=json_file_name,
        **shared,
    )

    # Step 5B: generate code, collecting {function name: file path}.
    func_name_to_url = step_generate_code(
        func_reqs=func_reqs,
        output_dir=output_dir,
        knowledge_text=knowledge_text,
        signatures=signatures,
        **shared,
    )

    # Step 5C: patch the url field and refresh the JSON file.
    json_path = step_patch_signatures_url(
        signatures=signatures,
        func_name_to_url=func_name_to_url,
        output_dir=output_dir,
        json_file_name=json_file_name,
        **shared,
    )

    summary = (
        f"[bold cyan]🎉 全部流程完成![/bold cyan]\n"
        f"项目: [bold]{project.name}[/bold]\n"
        f"描述: {project.description or '(无)'}\n"
        f"代码目录: [cyan]{os.path.abspath(output_dir)}[/cyan]\n"
        f"签名文件: [cyan]{json_path}[/cyan]"
    )
    console.print(Panel(summary, border_style="cyan"))
# ══════════════════════════════════════════════════════
# CLI 入口(click)
# ══════════════════════════════════════════════════════
# NOTE: the help= strings and the docstring below are rendered by click as
# the --help output, so they are runtime text and stay in the original
# language. The only change to them is restoring the closing parentheses
# that were missing from five help strings.
@click.command()
@click.option("--non-interactive", is_flag=True, default=False,
              help="以非交互模式运行(所有参数通过命令行传入)")
@click.option("--project-name", "-p", default=None, help="项目名称")
@click.option("--language", "-l", default=None,
              type=click.Choice(["python", "javascript", "typescript", "java", "go", "rust"]),
              help=f"目标代码语言(默认: {config.DEFAULT_LANGUAGE})")
@click.option("--description", "-d", default="", help="项目描述")
@click.option("--requirement-text", "-r", default=None,
              help="原始需求文本(与 --requirement-file 二选一)")
@click.option("--requirement-file", "-f", default=None,
              type=click.Path(exists=True),
              help="原始需求文件路径(支持 .txt/.md/.pdf/.docx)")
@click.option("--knowledge-file", "-k", default=None, multiple=True,
              type=click.Path(exists=True),
              help="知识库文件路径(可多次指定,如 -k a.md -k b.pdf)")
@click.option("--skip-index", "-s", default=None, multiple=True, type=int,
              help="要跳过的功能需求序号(可多次指定,如 -s 2 -s 5)")
@click.option("--json-file-name", "-j", default="function_signatures.json",
              help="函数签名 JSON 文件名(默认: function_signatures.json)")
def cli(
    non_interactive, project_name, language, description,
    requirement_text, requirement_file, knowledge_file,
    skip_index, json_file_name,
):
    """
    需求分析 & 代码生成工具
    \b
    交互式运行推荐初次使用
    python main.py
    \b
    非交互式运行示例
    python main.py --non-interactive \\
    --project-name "UserSystem" \\
    --description "用户管理系统后端服务" \\
    --language python \\
    --requirement-text "用户管理系统,包含注册、登录、修改密码功能" \\
    --knowledge-file docs/api_spec.md \\
    --json-file-name api_signatures.json
    \b
    从文件读取需求 + 跳过部分功能需求
    python main.py --non-interactive \\
    --project-name "MyProject" \\
    --requirement-file requirements.md \\
    --skip-index 3 --skip-index 7
    """
    # Top-level boundary: translate any failure into a message + exit code.
    try:
        run_workflow(
            project_name=project_name,
            language=language,
            description=description,
            requirement_text=requirement_text,
            requirement_file=requirement_file,
            knowledge_files=knowledge_file,
            # click passes a tuple for multiple=True options.
            skip_indices=list(skip_index) if skip_index else [],
            json_file_name=json_file_name,
            non_interactive=non_interactive,
        )
    except KeyboardInterrupt:
        # Ctrl-C is treated as a normal, user-requested exit.
        console.print("\n[yellow]用户中断,退出[/yellow]")
        sys.exit(0)
    except Exception as e:
        # Print a short message plus the full traceback, then exit non-zero.
        console.print(f"\n[bold red]❌ 错误: {e}[/bold red]")
        import traceback
        traceback.print_exc()
        sys.exit(1)
# Run the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli()