113 lines
3.7 KiB
Python
113 lines
3.7 KiB
Python
# core/llm_client.py - LLM API 调用封装
|
||
import json
import time
from typing import Any, Optional

from openai import OpenAI, APIError, APITimeoutError, RateLimitError

from gui_ai_developer import config
|
||
|
||
|
||
class LLMClient:
    """Wrapper around an OpenAI-compatible API: one call entry point with retry logic."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model: Optional[str] = None,
    ):
        """Create a client; any argument left as None falls back to the config module.

        Args:
            api_key: API key; defaults to ``config.LLM_API_KEY``.
            api_base: Base URL of the endpoint; defaults to ``config.LLM_API_BASE``.
            model: Model name; defaults to ``config.LLM_MODEL``.
        """
        self.model = model or config.LLM_MODEL
        self.client = OpenAI(
            api_key=api_key or config.LLM_API_KEY,
            base_url=api_base or config.LLM_API_BASE,
            timeout=config.LLM_TIMEOUT,
        )

    def chat(
        self,
        prompt: str,
        system: str = "You are a helpful assistant.",
        temperature: float = 0.2,
        max_tokens: int = 4096,
    ) -> str:
        """Send a single-turn chat request and return the model's reply text.

        Args:
            prompt: User message.
            system: System prompt.
            temperature: Sampling temperature.
            max_tokens: Maximum number of output tokens.

        Returns:
            The model reply as a plain-text string (stripped of surrounding whitespace).

        Raises:
            RuntimeError: The call still fails after ``config.LLM_MAX_RETRY`` attempts.
        """
        last_error = None
        for attempt in range(1, config.LLM_MAX_RETRY + 1):
            try:
                resp = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system},
                        {"role": "user", "content": prompt},
                    ],
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                # content may be None for some providers/finish reasons; treat as empty.
                return (resp.choices[0].message.content or "").strip()
            except RateLimitError as e:
                last_error = e
                if attempt < config.LLM_MAX_RETRY:
                    # Exponential backoff on rate limits; no pointless sleep
                    # after the final attempt.
                    time.sleep(2 ** attempt)
            except APITimeoutError as e:
                last_error = e
                if attempt < config.LLM_MAX_RETRY:
                    time.sleep(2)
            except APIError as e:
                last_error = e
                if attempt >= config.LLM_MAX_RETRY:
                    break
                time.sleep(1)
        raise RuntimeError(
            f"LLM 调用失败(已重试 {config.LLM_MAX_RETRY} 次): {last_error}"
        )

    def chat_json(
        self,
        prompt: str,
        system: str = "You are a helpful assistant. Always respond with valid JSON.",
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> Any:
        """Send a request and parse the reply as JSON.

        Args:
            prompt: User message.
            system: System prompt (biases the model toward JSON output).
            temperature: Sampling temperature.
            max_tokens: Maximum number of output tokens.

        Returns:
            The parsed Python object (typically a dict or a list).

        Raises:
            ValueError: The reply could not be parsed as JSON.
            RuntimeError: The underlying LLM call failed after all retries.
        """
        raw = self.chat(prompt, system=system, temperature=temperature, max_tokens=max_tokens)
        # Models often wrap JSON in a markdown code fence; strip it first.
        cleaned = self._strip_markdown_code_block(raw)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError as e:
            raise ValueError(f"LLM 返回内容无法解析为 JSON: {e}\n原始内容:\n{raw}") from e

    @staticmethod
    def _strip_markdown_code_block(text: str) -> str:
        """Strip a surrounding ```json ... ``` or ``` ... ``` fence, if present."""
        text = text.strip()
        if text.startswith("```"):
            lines = text.splitlines()
            inner = lines[1:]  # drop the opening fence line (``` or ```json)
            if inner and inner[-1].strip() == "```":
                inner = inner[:-1]  # drop the closing fence line
            text = "\n".join(inner).strip()
        return text