90 lines
2.5 KiB
Python
90 lines
2.5 KiB
Python
|
|
# core/llm_client.py - LLM 客户端封装
|
|||
|
|
import json
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
import config
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LLMClient:
|
|||
|
|
"""
|
|||
|
|
OpenAI 兼容 LLM 客户端封装。
|
|||
|
|
支持任何兼容 OpenAI API 格式的服务(OpenAI / Azure / 本地模型等)。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(
|
|||
|
|
self,
|
|||
|
|
api_key: str = config.LLM_API_KEY,
|
|||
|
|
base_url: str = config.LLM_BASE_URL,
|
|||
|
|
model: str = config.LLM_MODEL,
|
|||
|
|
temperature: float = config.LLM_TEMPERATURE,
|
|||
|
|
):
|
|||
|
|
"""
|
|||
|
|
初始化 LLM 客户端
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
api_key: API 密钥
|
|||
|
|
base_url: API 基础 URL
|
|||
|
|
model: 模型名称
|
|||
|
|
temperature: 生成温度(0~1,越低越确定)
|
|||
|
|
|
|||
|
|
Raises:
|
|||
|
|
ImportError: 未安装 openai 库
|
|||
|
|
ValueError: api_key 为空
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
from openai import OpenAI
|
|||
|
|
except ImportError:
|
|||
|
|
raise ImportError("请安装 openai: pip install openai")
|
|||
|
|
|
|||
|
|
if not api_key:
|
|||
|
|
raise ValueError("LLM_API_KEY 未配置,请在 .env 文件中设置")
|
|||
|
|
|
|||
|
|
self.model = model
|
|||
|
|
self.temperature = temperature
|
|||
|
|
self._client = OpenAI(api_key=api_key, base_url=base_url)
|
|||
|
|
|
|||
|
|
def chat(self, system_prompt: str, user_prompt: str) -> str:
|
|||
|
|
"""
|
|||
|
|
发送对话请求,返回模型回复文本
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
system_prompt: 系统提示词
|
|||
|
|
user_prompt: 用户输入
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
模型回复的文本内容
|
|||
|
|
|
|||
|
|
Raises:
|
|||
|
|
Exception: API 调用失败
|
|||
|
|
"""
|
|||
|
|
response = self._client.chat.completions.create(
|
|||
|
|
model=self.model,
|
|||
|
|
temperature=self.temperature,
|
|||
|
|
messages=[
|
|||
|
|
{"role": "system", "content": system_prompt},
|
|||
|
|
{"role": "user", "content": user_prompt},
|
|||
|
|
],
|
|||
|
|
)
|
|||
|
|
return response.choices[0].message.content.strip()
|
|||
|
|
|
|||
|
|
def chat_json(self, system_prompt: str, user_prompt: str) -> dict:
|
|||
|
|
"""
|
|||
|
|
发送对话请求,解析并返回 JSON 结果
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
system_prompt: 系统提示词
|
|||
|
|
user_prompt: 用户输入
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
解析后的 dict 对象
|
|||
|
|
|
|||
|
|
Raises:
|
|||
|
|
json.JSONDecodeError: 模型返回非合法 JSON
|
|||
|
|
"""
|
|||
|
|
raw = self.chat(system_prompt, user_prompt)
|
|||
|
|
# 去除可能的 markdown 代码块包裹
|
|||
|
|
raw = raw.strip()
|
|||
|
|
if raw.startswith("```"):
|
|||
|
|
lines = raw.split("\n")
|
|||
|
|
raw = "\n".join(lines[1:-1])
|
|||
|
|
return json.loads(raw)
|