# ruwiki-test/src/adapters/llm.py
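"""OpenAI adapter for simplifying Russian Wikipedia articles.

Wraps chat-completion calls with tiktoken-based token counting, a
sliding-window RPM limit, a concurrency limiter, a circuit breaker,
and retry with backoff (the last three provided by `.base`).
"""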

import asyncio
import time

import openai
import structlog
import tiktoken
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion

from ..models import AppConfig
from .base import BaseAdapter, CircuitBreaker, RateLimiter, with_retry

logger = structlog.get_logger()


class LLMError(Exception):
    """Base error for LLM adapter failures."""


class LLMTokenLimitError(LLMError):
    """Raised when the input text exceeds the token budget."""


class LLMRateLimitError(LLMError):
    """Raised when OpenAI reports a rate limit."""
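

# Every completion request is funneled through four layers of protection:
# a concurrency semaphore (RateLimiter), a sliding-window RPM check, a
# CircuitBreaker, and retry with backoff (with_retry), all from `.base`.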


class LLMProviderAdapter(BaseAdapter):
    """Adapter around the OpenAI chat-completions API."""

    def __init__(self, config: AppConfig) -> None:
        super().__init__("llm_adapter")
        self.config = config
        self.client = AsyncOpenAI(api_key=config.openai_api_key)
        try:
            self.tokenizer = tiktoken.encoding_for_model(config.openai_model)
        except KeyError:
            # Unknown model name: fall back to the default cl100k_base encoding.
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.rate_limiter = RateLimiter(
            max_concurrent=config.max_concurrent_llm,
            name="llm_limiter",
        )
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=config.circuit_failure_threshold,
            recovery_timeout=config.circuit_recovery_timeout,
            name="llm_circuit",
        )
        # Timestamps of recent requests, used by the RPM sliding window.
        self.request_times: list[float] = []
        self.rpm_lock = asyncio.Lock()

    def count_tokens(self, text: str) -> int:
        try:
            return len(self.tokenizer.encode(text))
        except Exception as e:
            self.logger.warning("Token counting failed", error=str(e))
            # Rough fallback: assume ~4 characters per token.
            return len(text) // 4
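
    # The RPM guard keeps timestamps of requests from the last 60 seconds;
    # once the window holds `openai_rpm` entries, the caller sleeps until
    # the oldest entry ages out.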
    async def _check_rpm_limit(self) -> None:
        async with self.rpm_lock:
            current_time = time.time()
            # Drop timestamps older than the 60-second window.
            self.request_times = [
                req_time for req_time in self.request_times if current_time - req_time < 60
            ]
            if len(self.request_times) >= self.config.openai_rpm:
                oldest_request = min(self.request_times)
                wait_time = 60 - (current_time - oldest_request)
                if wait_time > 0:
                    self.logger.info(
                        "Waiting due to RPM limit",
                        wait_seconds=wait_time,
                        current_rpm=len(self.request_times),
                    )
                    await asyncio.sleep(wait_time)
            # Record the actual send time (we may have slept above).
            self.request_times.append(time.time())

    async def _make_completion_request(
        self,
        messages: list[dict[str, str]],
    ) -> ChatCompletion:
        try:
            response = await self.client.chat.completions.create(
                model=self.config.openai_model,
                messages=messages,
                temperature=self.config.openai_temperature,
                max_tokens=1500,
            )
            return response
        except openai.RateLimitError as e:
            raise LLMRateLimitError(f"Rate limit exceeded: {e}") from e
        except openai.APIError as e:
            raise LLMError(f"OpenAI API error: {e}") from e

    async def simplify_text(
        self,
        title: str,
        wiki_text: str,
        prompt_template: str,
    ) -> tuple[str, int, int]:
        """Simplify an article; returns (text, input_tokens, output_tokens)."""
        input_tokens = self.count_tokens(wiki_text)
        if input_tokens > 6000:
            raise LLMTokenLimitError(
                f"Text too long: {input_tokens} tokens (limit 6000)"
            )
        try:
            prompt_text = prompt_template.format(
                title=title,
                wiki_source_text=wiki_text,
            )
        except KeyError as e:
            raise LLMError(f"Prompt template error: missing key {e}") from e
        messages = self._parse_prompt_template(prompt_text)
        total_input_tokens = sum(self.count_tokens(msg["content"]) for msg in messages)
        async with self.rate_limiter:
            await self._check_rpm_limit()
            response = await self.circuit_breaker.call(
                lambda: with_retry(
                    lambda: self._make_completion_request(messages),
                    max_attempts=self.config.max_retries,
                    min_wait=self.config.retry_delay,
                    max_wait=self.config.retry_delay * 4,
                    retry_exceptions=(LLMRateLimitError, ConnectionError, TimeoutError),
                    name=f"simplify_{title}",
                )
            )
        if not response.choices:
            raise LLMError("Empty response from OpenAI")
        simplified_text = response.choices[0].message.content
        if not simplified_text:
            raise LLMError("OpenAI returned empty text")
        # Strip the end-of-output sentinel the prompt asks the model to emit.
        simplified_text = simplified_text.replace("###END###", "").strip()
        output_tokens = self.count_tokens(simplified_text)
        if output_tokens > 1200:
            self.logger.warning(
                "Simplified text exceeds the output limit",
                output_tokens=output_tokens,
                title=title,
            )
        self.logger.info(
            "Text simplified successfully",
            title=title,
            input_tokens=total_input_tokens,
            output_tokens=output_tokens,
        )
        return simplified_text, total_input_tokens, output_tokens

    def _parse_prompt_template(self, prompt_text: str) -> list[dict[str, str]]:
        messages: list[dict[str, str]] = []
        parts = prompt_text.split("### role:")
        for part in parts[1:]:
            lines = part.strip().split("\n", 1)
            if len(lines) < 2:
                continue
            role = lines[0].strip()
            content = lines[1].strip()
            if role in ("system", "user", "assistant"):
                messages.append({"role": role, "content": content})
        if not messages:
            # No role markers: send the whole template as one user message.
            messages = [{"role": "user", "content": prompt_text}]
        return messages
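
    # Illustrative template layout consumed by _parse_prompt_template
    # (the concrete wording lives in the prompt file, not in this module):
    #
    #   ### role: system
    #   You are an editor who simplifies Russian Wikipedia articles.
    #   ### role: user
    #   Simplify "{title}":
    #   {wiki_source_text}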
    async def health_check(self) -> bool:
        try:
            test_messages = [{"role": "user", "content": "Reply 'OK' if everything works."}]
            response = await self.client.chat.completions.create(
                model=self.config.openai_model,
                messages=test_messages,
                temperature=0,
                max_tokens=10,
            )
            return bool(response.choices and response.choices[0].message.content)
        except Exception as e:
            self.logger.error("LLM health check failed", error=str(e))
            return False
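
# Usage sketch (illustrative, not part of the module). Assumes an AppConfig
# populated with the fields referenced above and a prompt template built from
# "### role:" sections; `page` here is a hypothetical article object:
#
#   adapter = LLMProviderAdapter(config)
#   if await adapter.health_check():
#       text, tokens_in, tokens_out = await adapter.simplify_text(
#           title=page.title,
#           wiki_text=page.wikitext,
#           prompt_template=template,
#       )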