277 lines
8.7 KiB
Python
277 lines
8.7 KiB
Python
|
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from openai import APIError, RateLimitError

from src.adapters import (
    CircuitBreaker,
    CircuitBreakerError,
    LLMProviderAdapter,
    LLMRateLimitError,
    LLMTokenLimitError,
    RateLimiter,
    RuWikiAdapter,
)
|
|||
|
|
|||
|
|
|||
|
class TestCircuitBreaker:
    """Tests for ``CircuitBreaker``: pass-through, tripping and recovery."""

    @pytest.mark.asyncio
    async def test_successful_call(self):
        """The wrapped coroutine's return value is handed back by ``call``."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        async def succeeds():
            return "success"

        assert await breaker.call(succeeds) == "success"

    @pytest.mark.asyncio
    async def test_failure_accumulation(self):
        """After two failures (the threshold) the next call is rejected."""
        breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=1)

        async def always_fails():
            raise ValueError("Test error")

        # The first two calls surface the wrapped function's own error.
        for _ in range(2):
            with pytest.raises(ValueError):
                await breaker.call(always_fails)

        # The third call is expected to fail with the breaker's error instead.
        with pytest.raises(CircuitBreakerError):
            await breaker.call(always_fails)

    @pytest.mark.asyncio
    async def test_recovery(self):
        """Calls succeed again once the recovery timeout has elapsed."""
        breaker = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)

        async def always_fails():
            raise ValueError("Test error")

        async def succeeds():
            return "recovered"

        # A single failure reaches the threshold of one.
        with pytest.raises(ValueError):
            await breaker.call(always_fails)

        # The immediate follow-up call is expected to be rejected.
        with pytest.raises(CircuitBreakerError):
            await breaker.call(always_fails)

        # Wait longer than recovery_timeout (0.1 s) before trying again.
        await asyncio.sleep(0.2)

        assert await breaker.call(succeeds) == "recovered"
|
|||
|
|
|||
|
|
|||
|
class TestRateLimiter:
    """Tests for ``RateLimiter`` used as an async context manager."""

    @pytest.mark.asyncio
    async def test_concurrency_limit(self):
        """At most two of three tasks are inside the limiter at once."""
        limiter = RateLimiter(max_concurrent=2)
        events = []

        async def worker(task_id: int):
            # Log entry and exit so the overlap can be replayed afterwards.
            async with limiter:
                events.append(f"start_{task_id}")
                await asyncio.sleep(0.1)
                events.append(f"end_{task_id}")

        await asyncio.gather(*(worker(i) for i in range(3)))

        # Replay the event log, tracking how many tasks were inside at once.
        in_flight = 0
        peak = 0
        for event in events:
            if event.startswith("start_"):
                in_flight += 1
                if in_flight > peak:
                    peak = in_flight
            elif event.startswith("end_"):
                in_flight -= 1

        assert peak <= 2
|
|||
|
|
|||
|
|
|||
|
class TestRuWikiAdapter:
    """Tests for ``RuWikiAdapter``: URL parsing, wikitext cleaning, health check."""

    def test_extract_title_from_url(self):
        """Titles come from plain, underscore-separated and percent-encoded URLs."""
        # The method is invoked on the class itself, without an instance.
        cases = (
            ("https://ru.wikipedia.org/wiki/Тест", "Тест"),
            ("https://ru.wikipedia.org/wiki/Тест_статья", "Тест статья"),
            ("https://ru.wikipedia.org/wiki/%D0%A2%D0%B5%D1%81%D1%82", "Тест"),
        )
        for url, expected_title in cases:
            assert RuWikiAdapter.extract_title_from_url(url) == expected_title

    def test_extract_title_invalid_url(self):
        """A foreign host and a non-``/wiki/`` path both raise ``ValueError``."""
        bad_urls = (
            "https://example.com/invalid",
            "https://ru.wikipedia.org/invalid",
        )
        for url in bad_urls:
            with pytest.raises(ValueError):
                RuWikiAdapter.extract_title_from_url(url)

    def test_clean_wikitext(self, test_config, sample_wikitext):
        """Navigation templates and categories go; bold text and headings stay."""
        cleaned = RuWikiAdapter(test_config)._clean_wikitext(sample_wikitext)

        for removed in ("{{навигация", "[[Категория:"):
            assert removed not in cleaned

        for kept in ("'''Тест'''", "== Определение =="):
            assert kept in cleaned

    @pytest.mark.asyncio
    async def test_health_check_success(self, test_config):
        """``health_check`` returns True when the patched API call succeeds."""
        adapter = RuWikiAdapter(test_config)
        api_payload = {"query": {"general": {}}}

        # Both the client factory and asyncio.to_thread are replaced with mocks.
        with patch.object(adapter, "_get_client", return_value=AsyncMock()):
            with patch("asyncio.to_thread", return_value=api_payload):
                outcome = await adapter.health_check()
                assert outcome is True

    @pytest.mark.asyncio
    async def test_health_check_failure(self, test_config):
        """``health_check`` returns False when obtaining the client fails."""
        adapter = RuWikiAdapter(test_config)
        network_failure = ConnectionError("Network error")

        with patch.object(adapter, "_get_client", side_effect=network_failure):
            outcome = await adapter.health_check()
            assert outcome is False
|
|||
|
|
|||
|
|
|||
|
class TestLLMProviderAdapter:
    """Tests for ``LLMProviderAdapter``.

    Covers token counting, RPM throttling, prompt-template parsing,
    ``simplify_text`` (success, token-limit and rate-limit paths) and the
    health check. The OpenAI client is always patched; no network is used.
    """

    def test_count_tokens(self, test_config):
        """Non-empty text has a positive token count; empty text has zero."""
        adapter = LLMProviderAdapter(test_config)

        count = adapter.count_tokens("Hello world")
        assert count > 0

        count = adapter.count_tokens("")
        assert count == 0

    @pytest.mark.asyncio
    async def test_rpm_limiting(self, test_config):
        """``_check_rpm_limit`` takes measurable time once the limit is hit."""
        test_config.openai_rpm = 2
        adapter = LLMProviderAdapter(test_config)

        # Pre-fill the request log with two recent wall-clock timestamps so
        # the configured limit of two is already reached.
        # NOTE(review): assumes the adapter compares against time.time() - confirm.
        current_time = time.time()
        adapter.request_times = [current_time - 10, current_time - 5]

        # Time the call with a monotonic, high-resolution clock so a system
        # clock adjustment cannot distort the measured duration.
        start = time.perf_counter()
        await adapter._check_rpm_limit()
        elapsed = time.perf_counter() - start

        assert elapsed > 0.01

    @pytest.mark.asyncio
    async def test_simplify_text_token_limit_error(self, test_config):
        """An over-long source text raises ``LLMTokenLimitError``."""
        adapter = LLMProviderAdapter(test_config)

        long_text = "word " * 2000

        with pytest.raises(LLMTokenLimitError):
            await adapter.simplify_text("Test", long_text, "template")

    @pytest.mark.asyncio
    async def test_simplify_text_success(self, test_config, mock_openai_response):
        """A mocked completion yields the text and positive token counts."""
        adapter = LLMProviderAdapter(test_config)

        with patch.object(adapter.client.chat.completions, "create") as mock_create:
            mock_create.return_value = mock_openai_response

            # Bypass the RPM check so the test does not wait on throttling.
            with patch.object(adapter, "_check_rpm_limit"):
                result = await adapter.simplify_text(
                    title="Тест",
                    wiki_text="Тестовый текст",
                    prompt_template="### role: user\n{wiki_source_text}",
                )

                simplified_text, input_tokens, output_tokens = result

                assert "Упрощённый текст для школьников" in simplified_text
                # The end marker must not appear in the returned text.
                assert "###END###" not in simplified_text
                assert input_tokens > 0
                assert output_tokens > 0

    @pytest.mark.asyncio
    async def test_simplify_text_openai_error(self, test_config):
        """An OpenAI ``RateLimitError`` is re-raised as ``LLMRateLimitError``."""
        adapter = LLMProviderAdapter(test_config)

        with patch.object(adapter.client.chat.completions, "create") as mock_create:
            # openai>=1.0 status errors read ``response.request``,
            # ``response.status_code`` and ``response.headers`` in their
            # constructor, so ``response=None`` fails before the adapter
            # is ever called. A stand-in response object is used instead.
            fake_response = MagicMock(status_code=429, headers={})
            mock_create.side_effect = RateLimitError(
                "Rate limit exceeded", response=fake_response, body=None
            )

            with patch.object(adapter, "_check_rpm_limit"):
                with pytest.raises(LLMRateLimitError):
                    await adapter.simplify_text(
                        title="Тест",
                        wiki_text="Тестовый текст",
                        prompt_template="### role: user\n{wiki_source_text}",
                    )

    def test_parse_prompt_template(self, test_config):
        """``### role:`` sections become separate role/content messages."""
        adapter = LLMProviderAdapter(test_config)

        template = """### role: system
Ты помощник.

### role: user
Задание: {task}"""

        messages = adapter._parse_prompt_template(template)

        assert len(messages) == 2
        assert messages[0]["role"] == "system"
        assert messages[0]["content"] == "Ты помощник."
        assert messages[1]["role"] == "user"
        # Placeholders are left untouched by the parser.
        assert messages[1]["content"] == "Задание: {task}"

    def test_parse_prompt_template_fallback(self, test_config):
        """A template without role markers becomes a single user message."""
        adapter = LLMProviderAdapter(test_config)

        template = "Обычный текст без ролей"
        messages = adapter._parse_prompt_template(template)

        assert len(messages) == 1
        assert messages[0]["role"] == "user"
        assert messages[0]["content"] == template

    @pytest.mark.asyncio
    async def test_health_check_success(self, test_config, mock_openai_response):
        """``health_check`` returns True when the completion call succeeds."""
        adapter = LLMProviderAdapter(test_config)

        with patch.object(adapter.client.chat.completions, "create") as mock_create:
            mock_create.return_value = mock_openai_response

            result = await adapter.health_check()
            assert result is True

    @pytest.mark.asyncio
    async def test_health_check_failure(self, test_config):
        """``health_check`` returns False when the completion call raises."""
        adapter = LLMProviderAdapter(test_config)

        with patch.object(adapter.client.chat.completions, "create") as mock_create:
            # openai>=1.0 ``APIError`` takes the originating request as its
            # second argument and has no ``response`` parameter.
            mock_create.side_effect = APIError(
                "API Error", request=MagicMock(), body=None
            )

            result = await adapter.health_check()
            assert result is False
|