2026-03-05 05:38:26 +00:00
|
|
|
|
# core/llm_client.py - LLM API 调用封装
|
|
|
|
|
|
import json
import re
import time
from typing import Any, Optional

from openai import OpenAI, APIError, APITimeoutError, RateLimitError

import config
|
2026-03-04 18:09:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LLMClient:
    """Thin wrapper around an OpenAI-compatible chat-completions API.

    Provides a single-turn ``chat`` entry point with retry handling, and a
    ``chat_json`` helper that parses the model's reply as JSON.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model: Optional[str] = None,
    ):
        """Create the underlying ``OpenAI`` client.

        Args:
            api_key: API key; falls back to ``config.LLM_API_KEY`` when falsy.
            api_base: API base URL; falls back to ``config.LLM_API_BASE`` when falsy.
            model: Model name; falls back to ``config.LLM_MODEL`` when falsy.
        """
        self.model = model or config.LLM_MODEL
        self.client = OpenAI(
            api_key=api_key or config.LLM_API_KEY,
            base_url=api_base or config.LLM_API_BASE,
            timeout=config.LLM_TIMEOUT,
        )

    def chat(
        self,
        prompt: str,
        system: str = "You are a helpful assistant.",
        temperature: float = 0.2,
        max_tokens: int = 4096,
    ) -> str:
        """Send a single-turn chat request and return the reply text.

        Up to ``config.LLM_MAX_RETRY`` attempts are made in total; at least one
        attempt is always made. Rate-limit errors back off exponentially
        (2s, 4s, 8s, ...). Timeouts wait 2s before the next attempt, and other
        retryable API errors wait 1s. API errors with a non-retryable HTTP status
        (see ``_is_retryable``) fail immediately. No sleep happens after the
        final attempt.

        Args:
            prompt: User message.
            system: System prompt.
            temperature: Sampling temperature.
            max_tokens: Maximum number of output tokens.

        Returns:
            The reply text with surrounding whitespace stripped. Returns ``""``
            when the response has no choices or the message content is empty/None.

        Raises:
            RuntimeError: The request still failed after the last attempt, or it
                failed with a non-retryable API error. The original exception is
                chained as ``__cause__``.
        """
        max_attempts = max(1, config.LLM_MAX_RETRY)
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ]
        # NOTE(review): the openai SDK client may also retry some failures
        # internally (its ``max_retries`` option); this loop runs on top of
        # that — confirm the combined retry budget is intended.
        last_error: Optional[Exception] = None
        attempts_made = 0
        for attempt in range(1, max_attempts + 1):
            attempts_made = attempt
            try:
                resp = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
            # RateLimitError / APITimeoutError subclass APIError, so they must
            # be matched before the generic APIError clause.
            except RateLimitError as e:
                last_error, delay = e, 2 ** attempt
            except APITimeoutError as e:
                last_error, delay = e, 2
            except APIError as e:
                last_error, delay = e, 1
                if not self._is_retryable(e):
                    break
            else:
                return self._extract_text(resp)
            # Sleeping after the final attempt would only delay the raise below.
            if attempt < max_attempts:
                time.sleep(delay)
        raise RuntimeError(
            f"LLM 调用失败(已尝试 {attempts_made} 次): {last_error}"
        ) from last_error

    @staticmethod
    def _is_retryable(error: Exception) -> bool:
        """Return whether an API error is worth retrying.

        Errors without a ``status_code`` attribute (e.g. connection-level
        failures) are retried. HTTP errors are retried only for 408, 409, 429
        and 5xx. Other 4xx statuses (bad request, auth, not found, ...) would
        fail identically on every attempt.
        """
        status = getattr(error, "status_code", None)
        if status is None:
            return True
        return status in (408, 409, 429) or status >= 500

    @staticmethod
    def _extract_text(resp) -> str:
        """Return the stripped content of the first choice, or ``""`` if absent."""
        choices = getattr(resp, "choices", None)
        if not choices:
            return ""
        content = choices[0].message.content
        # content can be None (e.g. a filtered or tool-call-only reply).
        return content.strip() if content else ""

    def chat_json(
        self,
        prompt: str,
        system: str = "You are a helpful assistant. Always respond with valid JSON.",
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> Any:
        """Send a request and parse the reply as JSON.

        Args:
            prompt: User message.
            system: System prompt (by default asks the model for valid JSON).
            temperature: Sampling temperature.
            max_tokens: Maximum number of output tokens.

        Returns:
            The value decoded by ``json.loads``. This is typically a dict or
            list, but any JSON value is possible.

        Raises:
            ValueError: The reply (after unwrapping a Markdown code fence) is not
                valid JSON; the decode error is chained as ``__cause__``.
            RuntimeError: The LLM call itself failed (see ``chat``).
        """
        raw = self.chat(prompt, system=system, temperature=temperature, max_tokens=max_tokens)
        # Models often wrap JSON in a Markdown code fence; unwrap before parsing.
        cleaned = self._strip_markdown_code_block(raw)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError as e:
            raise ValueError(f"LLM 返回内容无法解析为 JSON: {e}\n原始内容:\n{raw}") from e

    @staticmethod
    def _strip_markdown_code_block(text: str) -> str:
        """Return the body of the first Markdown code fence in ``text``.

        Handles these forms:
        - ```` ```json ... ``` ```` and ```` ``` ... ``` ```` wrappers;
        - prose before the opening fence or after the closing fence;
        - a closing fence on the same line as the last content line;
        - an unterminated fence (e.g. a reply truncated by ``max_tokens``);
        - the single-line form ```` ```json {...}``` ````.

        Text without any fence line is returned with whitespace stripped.
        """
        text = text.strip()
        lines = text.splitlines()
        start = next(
            (i for i, line in enumerate(lines) if line.lstrip().startswith("```")),
            None,
        )
        if start is None:
            return text

        opener = lines[start].strip()
        # Single-line fence: the body sits between the opening and closing ticks,
        # optionally preceded by a language tag followed by whitespace.
        if len(opener) > 6 and opener.endswith("```"):
            single = re.fullmatch(r"```(?:[\w+.-]+[ \t]+)?(.*?)[ \t]*```", opener)
            return single.group(1).strip() if single else ""

        # Multi-line fence: drop the opener line (```json / ```), collect lines up
        # to the closing fence, keeping any content that precedes it on that line.
        body = []
        for line in lines[start + 1:]:
            trimmed = line.rstrip()
            if trimmed.endswith("```"):
                body.append(trimmed[:-3])
                break
            body.append(line)
        return "\n".join(body).strip()
|