90 lines
2.5 KiB
Python
90 lines
2.5 KiB
Python
# core/llm_client.py - LLM 客户端封装
|
||
import json
|
||
from typing import Optional
|
||
|
||
import config
|
||
|
||
|
||
class LLMClient:
|
||
"""
|
||
OpenAI 兼容 LLM 客户端封装。
|
||
支持任何兼容 OpenAI API 格式的服务(OpenAI / Azure / 本地模型等)。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
api_key: str = config.LLM_API_KEY,
|
||
base_url: str = config.LLM_BASE_URL,
|
||
model: str = config.LLM_MODEL,
|
||
temperature: float = config.LLM_TEMPERATURE,
|
||
):
|
||
"""
|
||
初始化 LLM 客户端
|
||
|
||
Args:
|
||
api_key: API 密钥
|
||
base_url: API 基础 URL
|
||
model: 模型名称
|
||
temperature: 生成温度(0~1,越低越确定)
|
||
|
||
Raises:
|
||
ImportError: 未安装 openai 库
|
||
ValueError: api_key 为空
|
||
"""
|
||
try:
|
||
from openai import OpenAI
|
||
except ImportError:
|
||
raise ImportError("请安装 openai: pip install openai")
|
||
|
||
if not api_key:
|
||
raise ValueError("LLM_API_KEY 未配置,请在 .env 文件中设置")
|
||
|
||
self.model = model
|
||
self.temperature = temperature
|
||
self._client = OpenAI(api_key=api_key, base_url=base_url)
|
||
|
||
def chat(self, system_prompt: str, user_prompt: str) -> str:
|
||
"""
|
||
发送对话请求,返回模型回复文本
|
||
|
||
Args:
|
||
system_prompt: 系统提示词
|
||
user_prompt: 用户输入
|
||
|
||
Returns:
|
||
模型回复的文本内容
|
||
|
||
Raises:
|
||
Exception: API 调用失败
|
||
"""
|
||
response = self._client.chat.completions.create(
|
||
model=self.model,
|
||
temperature=self.temperature,
|
||
messages=[
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_prompt},
|
||
],
|
||
)
|
||
return response.choices[0].message.content.strip()
|
||
|
||
def chat_json(self, system_prompt: str, user_prompt: str) -> dict:
|
||
"""
|
||
发送对话请求,解析并返回 JSON 结果
|
||
|
||
Args:
|
||
system_prompt: 系统提示词
|
||
user_prompt: 用户输入
|
||
|
||
Returns:
|
||
解析后的 dict 对象
|
||
|
||
Raises:
|
||
json.JSONDecodeError: 模型返回非合法 JSON
|
||
"""
|
||
raw = self.chat(system_prompt, user_prompt)
|
||
# 去除可能的 markdown 代码块包裹
|
||
raw = raw.strip()
|
||
if raw.startswith("```"):
|
||
lines = raw.split("\n")
|
||
raw = "\n".join(lines[1:-1])
|
||
return json.loads(raw) |