ruwiki-test/tests/test_adapters.py

279 lines
8.8 KiB
Python
Raw Normal View History

2025-07-11 21:28:58 +02:00
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from openai import APIError, RateLimitError

from src.adapters import (
    CircuitBreaker,
    CircuitBreakerError,
    LLMProviderAdapter,
    LLMRateLimitError,
    LLMTokenLimitError,
    RateLimiter,
    RuWikiAdapter,
)
class TestCircuitBreaker:
    """Tests for CircuitBreaker: pass-through, tripping, and recovery."""

    @pytest.mark.asyncio
    async def test_successful_call(self):
        """A successful coroutine's result is returned unchanged by the breaker."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        async def succeed():
            return "success"

        assert await breaker.call(succeed) == "success"

    @pytest.mark.asyncio
    async def test_failure_accumulation(self):
        """After two failures (the threshold) the next call is rejected by the breaker."""
        breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=1)

        async def fail():
            raise ValueError("Test error")

        # The first `failure_threshold` calls propagate the underlying error.
        for _ in range(2):
            with pytest.raises(ValueError):
                await breaker.call(fail)
        # Once the threshold is reached the breaker itself refuses the call.
        with pytest.raises(CircuitBreakerError):
            await breaker.call(fail)

    @pytest.mark.asyncio
    async def test_recovery(self):
        """After the recovery timeout elapses, a successful call goes through again."""
        breaker = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)

        async def fail():
            raise ValueError("Test error")

        async def recover():
            return "recovered"

        with pytest.raises(ValueError):
            await breaker.call(fail)
        with pytest.raises(CircuitBreakerError):
            await breaker.call(fail)
        # Wait longer than the 0.1 s recovery_timeout configured above.
        await asyncio.sleep(0.2)
        assert await breaker.call(recover) == "recovered"
class TestRateLimiter:
    """Tests for RateLimiter used as an async context manager."""

    @pytest.mark.asyncio
    async def test_concurrency_limit(self):
        """At most `max_concurrent` tasks are ever inside the limiter simultaneously."""
        limiter = RateLimiter(max_concurrent=2)
        events = []

        async def worker(worker_id: int):
            async with limiter:
                events.append(f"start_{worker_id}")
                await asyncio.sleep(0.1)
                events.append(f"end_{worker_id}")

        await asyncio.gather(*(worker(n) for n in range(3)))

        # Replay the event log: each start enters the limiter, each end leaves it.
        inside = 0
        peak = 0
        for event in events:
            if event.startswith("start_"):
                inside += 1
                peak = max(peak, inside)
            elif event.startswith("end_"):
                inside -= 1
        assert peak <= 2
class TestRuWikiAdapter:
    """Tests for RuWikiAdapter: URL parsing, wikitext cleaning, health check."""

    def test_extract_title_from_url(self):
        """Titles come out of plain, underscored and percent-encoded article URLs."""
        cases = [
            ("https://ru.wikipedia.org/wiki/Тест", "Тест"),
            ("https://ru.wikipedia.org/wiki/Тест_статья", "Тест статья"),
            ("https://ru.wikipedia.org/wiki/%D0%A2%D0%B5%D1%81%D1%82", "Тест"),
        ]
        # Called on the class itself: no adapter instance is needed here.
        for url, expected in cases:
            assert RuWikiAdapter.extract_title_from_url(url) == expected

    def test_extract_title_invalid_url(self):
        """Non-Wikipedia URLs and non-article paths are rejected with ValueError."""
        bad_urls = ("https://example.com/invalid", "https://ru.wikipedia.org/invalid")
        for url in bad_urls:
            with pytest.raises(ValueError):
                RuWikiAdapter.extract_title_from_url(url)

    def test_clean_wikitext(self, test_config, sample_wikitext):
        """Test wiki-text cleaning: templates/categories removed, markup kept."""
        cleaned = RuWikiAdapter(test_config)._clean_wikitext(sample_wikitext)
        for removed in ("{{навигация", "[[Категория:"):
            assert removed not in cleaned
        for kept in ("'''Тест'''", "== Определение =="):
            assert kept in cleaned

    @pytest.mark.asyncio
    async def test_health_check_success(self, test_config):
        """health_check is True when the (mocked) threaded API call returns a payload."""
        adapter = RuWikiAdapter(test_config)
        with patch.object(adapter, "_get_client") as get_client_mock:
            get_client_mock.return_value = AsyncMock()
            with patch("asyncio.to_thread") as to_thread_mock:
                to_thread_mock.return_value = {"query": {"general": {}}}
                outcome = await adapter.health_check()
        assert outcome is True

    @pytest.mark.asyncio
    async def test_health_check_failure(self, test_config):
        """health_check is False when obtaining the client raises a network error."""
        adapter = RuWikiAdapter(test_config)
        with patch.object(
            adapter, "_get_client", side_effect=ConnectionError("Network error")
        ):
            outcome = await adapter.health_check()
        assert outcome is False
class TestLLMProviderAdapter:
    """Tests for LLMProviderAdapter: tokens, RPM limiting, prompts and API calls."""

    @staticmethod
    def _fake_http_response(status_code: int) -> MagicMock:
        """Return a minimal stand-in for an ``httpx.Response``.

        In openai>=1.x, ``APIStatusError.__init__`` (the base of
        ``RateLimitError``) reads ``response.request``,
        ``response.status_code`` and ``response.headers``. Passing
        ``response=None`` therefore raises AttributeError while the test is
        still building its fixture.
        """
        response = MagicMock()
        response.status_code = status_code
        response.headers = {}
        return response

    def test_count_tokens(self, test_config):
        """Test token counting: non-empty text > 0 tokens, empty text == 0."""
        adapter = LLMProviderAdapter(test_config)
        count = adapter.count_tokens("Hello world")
        assert count > 0
        count = adapter.count_tokens("")
        assert count == 0

    @pytest.mark.asyncio
    async def test_rpm_limiting(self, test_config):
        """_check_rpm_limit waits when recent requests already fill the RPM budget."""
        test_config.openai_rpm = 2
        adapter = LLMProviderAdapter(test_config)
        current_time = time.time()
        # Two requests within the last minute; with openai_rpm = 2 the limiter
        # is expected to block before allowing another one.
        # NOTE(review): if _check_rpm_limit waits until these timestamps leave
        # a 60 s window, this test really sleeps ~50 s; consider patching the
        # sleep once the implementation in src.adapters is confirmed.
        adapter.request_times = [current_time - 10, current_time - 5]
        start_time = time.time()
        await adapter._check_rpm_limit()
        elapsed = time.time() - start_time
        assert elapsed > 0.01

    @pytest.mark.asyncio
    async def test_simplify_text_token_limit_error(self, test_config):
        """An over-long input is rejected with LLMTokenLimitError."""
        adapter = LLMProviderAdapter(test_config)
        long_text = "word " * 2000
        with pytest.raises(LLMTokenLimitError):
            await adapter.simplify_text("Test", long_text, "template")

    @pytest.mark.asyncio
    async def test_simplify_text_success(self, test_config, mock_openai_response):
        """A mocked completion yields cleaned text plus input/output token counts."""
        adapter = LLMProviderAdapter(test_config)
        with patch.object(adapter.client.chat.completions, "create") as mock_create:
            mock_create.return_value = mock_openai_response
            with patch.object(adapter, "_check_rpm_limit"):
                result = await adapter.simplify_text(
                    title="Тест",
                    wiki_text="Тестовый текст",
                    prompt_template="### role: user\n{wiki_source_text}",
                )
        simplified_text, input_tokens, output_tokens = result
        assert "Упрощённый текст для школьников" in simplified_text
        # The end marker must be stripped from the returned text.
        assert "###END###" not in simplified_text
        assert input_tokens > 0
        assert output_tokens > 0

    @pytest.mark.asyncio
    async def test_simplify_text_openai_error(self, test_config):
        """An openai RateLimitError is translated into LLMRateLimitError."""
        adapter = LLMProviderAdapter(test_config)
        with patch.object(adapter.client.chat.completions, "create") as mock_create:
            mock_create.side_effect = RateLimitError(
                "Rate limit exceeded",
                response=self._fake_http_response(429),
                body=None,
            )
            with patch.object(adapter, "_check_rpm_limit"):
                with pytest.raises(LLMRateLimitError):
                    await adapter.simplify_text(
                        title="Тест",
                        wiki_text="Тестовый текст",
                        prompt_template="### role: user\n{wiki_source_text}",
                    )

    def test_parse_prompt_template(self, test_config):
        """'### role: X' headers split the template into ordered chat messages."""
        adapter = LLMProviderAdapter(test_config)
        # Built with explicit newlines so the content does not depend on
        # the indentation of continuation lines in a triple-quoted string.
        template = (
            "### role: system\n"
            "Ты помощник.\n"
            "### role: user\n"
            "Задание: {task}"
        )
        messages = adapter._parse_prompt_template(template)
        assert len(messages) == 2
        assert messages[0]["role"] == "system"
        assert messages[0]["content"] == "Ты помощник."
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Задание: {task}"

    def test_parse_prompt_template_fallback(self, test_config):
        """A template without role headers becomes a single user message."""
        adapter = LLMProviderAdapter(test_config)
        template = "Обычный текст без ролей"
        messages = adapter._parse_prompt_template(template)
        assert len(messages) == 1
        assert messages[0]["role"] == "user"
        assert messages[0]["content"] == template

    @pytest.mark.asyncio
    async def test_health_check_success(self, test_config, mock_openai_response):
        """health_check is True when the completion call succeeds."""
        adapter = LLMProviderAdapter(test_config)
        with patch.object(adapter.client.chat.completions, "create") as mock_create:
            mock_create.return_value = mock_openai_response
            result = await adapter.health_check()
        assert result is True

    @pytest.mark.asyncio
    async def test_health_check_failure(self, test_config):
        """health_check is False when the completion call raises an APIError."""
        adapter = LLMProviderAdapter(test_config)
        with patch.object(adapter.client.chat.completions, "create") as mock_create:
            # openai>=1.x: APIError(message, request, *, body); it has no
            # `response` parameter, so the original call raised TypeError.
            mock_create.side_effect = APIError(
                "API Error", request=MagicMock(), body=None
            )
            result = await adapter.health_check()
        assert result is False