AIDeveloper-PC/requirements_generator/core/llm_client.py

113 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# core/llm_client.py - LLM API 调用封装
import json
import time
from typing import Any, Optional

from openai import OpenAI, APIError, APITimeoutError, RateLimitError

import config
class LLMClient:
    """Wrapper around an OpenAI-compatible chat-completions API.

    Offers one call entry point (`chat`) with a retry loop, plus a
    convenience method (`chat_json`) that parses the reply as JSON.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model: Optional[str] = None,
    ):
        """Create the underlying OpenAI client.

        Any argument that is None (or otherwise falsy) falls back to the
        corresponding value from the project `config` module
        (`LLM_API_KEY`, `LLM_API_BASE`, `LLM_MODEL`). The request timeout
        always comes from `config.LLM_TIMEOUT`.
        """
        self.model = model or config.LLM_MODEL
        # NOTE(review): the OpenAI SDK client may also retry some failures
        # internally (its `max_retries` option), so the real number of HTTP
        # requests can exceed LLM_MAX_RETRY -- confirm this is intended.
        self.client = OpenAI(
            api_key=api_key or config.LLM_API_KEY,
            base_url=api_base or config.LLM_API_BASE,
            timeout=config.LLM_TIMEOUT,
        )

    def chat(
        self,
        prompt: str,
        system: str = "You are a helpful assistant.",
        temperature: float = 0.2,
        max_tokens: int = 4096,
    ) -> str:
        """Send a single-turn chat request and return the reply text.

        Args:
            prompt: User message.
            system: System prompt.
            temperature: Sampling temperature.
            max_tokens: Maximum number of output tokens.

        Returns:
            The text of the first choice with surrounding whitespace removed.
            The result is an empty string if the response carried no text
            content.

        Raises:
            RuntimeError: All attempts failed, or the API rejected the request
                with a client error that a retry cannot fix.
        """
        # Always make at least one attempt, even if the config value is <= 0.
        max_attempts = max(1, config.LLM_MAX_RETRY)
        last_error = None
        for attempt in range(1, max_attempts + 1):
            try:
                resp = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system},
                        {"role": "user", "content": prompt},
                    ],
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
            # RateLimitError / APITimeoutError are caught before their common
            # base class APIError, so each gets its own back-off delay.
            except RateLimitError as e:
                last_error = e
                delay = 2 ** attempt  # exponential back-off for rate limits
            except APITimeoutError as e:
                last_error = e
                delay = 2
            except APIError as e:
                last_error = e
                if not self._is_retryable(e):
                    # e.g. bad request / auth / not found: retrying cannot help.
                    break
                delay = 1
            else:
                return self._extract_text(resp)
            # Sleeping after the final attempt would only delay the failure.
            if attempt < max_attempts:
                time.sleep(delay)
        raise RuntimeError(
            f"LLM 调用失败(已尝试 {attempt} 次): {last_error}"
        ) from last_error

    @staticmethod
    def _is_retryable(error: Exception) -> bool:
        """Return False for HTTP 4xx client errors that a retry cannot fix.

        Errors without an integer `status_code` attribute (e.g. connection
        failures) and 408 / 409 / 429 responses are treated as retryable.
        """
        status = getattr(error, "status_code", None)
        if not isinstance(status, int):
            return True
        if status in (408, 409, 429):
            return True
        return not 400 <= status < 500

    @staticmethod
    def _extract_text(resp) -> str:
        """Return the stripped text of the first choice, or "" if there is none."""
        if not resp.choices:
            return ""
        content = resp.choices[0].message.content
        # The message content may be None (no text produced); treat as empty
        # instead of crashing on `None.strip()`.
        return (content or "").strip()

    def chat_json(
        self,
        prompt: str,
        system: str = "You are a helpful assistant. Always respond with valid JSON.",
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> Any:
        """Send a request and parse the reply as JSON.

        Any surrounding Markdown code fence in the reply is removed before
        parsing. The arguments have the same meaning as in `chat`.

        Returns:
            The decoded JSON value (typically a dict or list).

        Raises:
            ValueError: The reply could not be parsed as JSON. The message
                includes the raw reply.
            RuntimeError: The underlying `chat` call failed.
        """
        raw = self.chat(prompt, system=system, temperature=temperature, max_tokens=max_tokens)
        # Models often wrap JSON in a ```json ... ``` block despite instructions.
        cleaned = self._strip_markdown_code_block(raw)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError as e:
            raise ValueError(f"LLM 返回内容无法解析为 JSON: {e}\n原始内容:\n{raw}") from e

    @staticmethod
    def _strip_markdown_code_block(text: str) -> str:
        """Remove a surrounding ```json ... ``` or ``` ... ``` fence.

        Text that does not start with a fence is returned with surrounding
        whitespace stripped and is otherwise unchanged. In the multi-line
        form, the first line (the opening fence plus an optional language
        tag) is dropped. A trailing ``` is removed even when it shares a line
        with content. In the single-line form (```...```), everything between
        the backticks is kept; a language tag is not recognized there.
        """
        text = text.strip()
        if not text.startswith("```"):
            return text
        lines = text.splitlines()
        if len(lines) == 1:
            body = lines[0][3:]
        else:
            body = "\n".join(lines[1:])
        body = body.rstrip()
        if body.endswith("```"):
            body = body[:-3]
        return body.strip()