# dataloader/tests/unit/test_workers_base.py
# 447 lines, 16 KiB, Python
# Repository viewer timestamp: 2025-11-05 13:00:41 +01:00
# tests/unit/test_workers_base.py
from __future__ import annotations
import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from dataloader.workers.base import PGWorker, WorkerConfig
@pytest.mark.unit
class TestPGWorker:
    """
    Unit tests for PGWorker.

    The collaborators that ``dataloader.workers.base`` looks up at module level
    (``APP_CTX``, ``APP_CONFIG``, ``PGNotifyListener``, ``QueueRepository``,
    ``resolve_pipeline``) are replaced with mocks in each test.
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _make_sessionmaker() -> tuple[MagicMock, AsyncMock]:
        """
        Build a mocked async sessionmaker.

        Returns:
            ``(sessionmaker, session)`` such that
            ``async with sessionmaker() as s`` binds ``s`` to ``session``.
        """
        session = AsyncMock()
        sessionmaker = MagicMock()
        sessionmaker.return_value.__aenter__.return_value = session
        # __aexit__ must return a falsy value: a truthy result (for example a
        # mock object) makes ``async with`` swallow any exception raised inside
        # the block, which would hide real failures from the test.
        sessionmaker.return_value.__aexit__.return_value = False
        return sessionmaker, session

    @staticmethod
    def _claimed_job(args: dict) -> dict:
        """
        Return the job row that the mocked ``claim_one`` hands to the worker.

        Args:
            args: value stored under the ``"args"`` key of the job row.
        """
        return {
            "job_id": "test-job-id",
            "lease_ttl_sec": 60,
            "task": "test.task",
            "args": args,
        }

    @staticmethod
    def _claim_stub_stopping_after(stop_event: asyncio.Event, limit: int):
        """
        Create a replacement for ``_claim_and_execute_once``.

        The returned coroutine function always reports "no job" (``False``) and
        sets ``stop_event`` once it has been called ``limit`` times, which lets
        ``worker.run()`` terminate.
        """
        calls = 0

        async def fake_claim():
            nonlocal calls
            calls += 1
            if calls >= limit:
                stop_event.set()
            return False

        return fake_claim

    # ------------------------------------------------------------ construction

    def test_init_creates_worker_with_config(self):
        """
        A new worker keeps the given config and stop event, has no listener
        and starts with a cleared wake-up flag.
        """
        cfg = WorkerConfig(queue="test_queue", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            worker = PGWorker(cfg, stop_event)

            assert worker._cfg == cfg
            assert worker._stop == stop_event
            assert worker._listener is None
            assert not worker._notify_wakeup.is_set()

    # --------------------------------------------------------------------- run

    @pytest.mark.asyncio
    async def test_run_starts_listener_and_processes_jobs(self):
        """
        ``run()`` starts the listener once and stops it once.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
                patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \
                patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()
            mock_cfg.pg.url = "postgresql+asyncpg://test"

            mock_listener = Mock()
            mock_listener.start = AsyncMock()
            mock_listener.stop = AsyncMock()
            mock_listener_cls.return_value = mock_listener

            worker = PGWorker(cfg, stop_event)
            fake_claim = self._claim_stub_stopping_after(stop_event, 2)

            with patch.object(worker, "_claim_and_execute_once", side_effect=fake_claim):
                await worker.run()

            assert mock_listener.start.call_count == 1
            assert mock_listener.stop.call_count == 1

    @pytest.mark.asyncio
    async def test_run_falls_back_to_polling_if_listener_fails(self):
        """
        When the LISTEN/NOTIFY listener fails to start, ``run()`` continues
        without a listener and logs exactly one warning.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
                patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \
                patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls:
            mock_logger = Mock()
            mock_ctx.get_logger.return_value = mock_logger
            mock_ctx.sessionmaker = Mock()
            mock_cfg.pg.url = "postgresql+asyncpg://test"

            mock_listener = Mock()
            mock_listener.start = AsyncMock(side_effect=Exception("Connection failed"))
            mock_listener_cls.return_value = mock_listener

            worker = PGWorker(cfg, stop_event)
            fake_claim = self._claim_stub_stopping_after(stop_event, 2)

            with patch.object(worker, "_claim_and_execute_once", side_effect=fake_claim):
                await worker.run()

            assert worker._listener is None
            assert mock_logger.warning.call_count == 1

    # --------------------------------------------------------- _listen_or_sleep

    @pytest.mark.asyncio
    async def test_listen_or_sleep_with_listener_waits_for_notify(self):
        """
        With a listener present and the wake-up flag already set,
        ``_listen_or_sleep`` returns with the flag cleared.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            worker = PGWorker(cfg, stop_event)
            worker._listener = Mock()
            worker._notify_wakeup.set()

            await worker._listen_or_sleep(1)

            assert not worker._notify_wakeup.is_set()

    @pytest.mark.asyncio
    async def test_listen_or_sleep_without_listener_uses_timeout(self):
        """
        Without a listener, ``_listen_or_sleep`` waits for roughly the
        requested timeout.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            worker = PGWorker(cfg, stop_event)
            loop = asyncio.get_running_loop()

            started_at = loop.time()
            await worker._listen_or_sleep(1)
            elapsed = loop.time() - started_at

            # Event-loop timers may fire up to one clock-resolution tick early,
            # so allow a small tolerance instead of requiring a strict 1.0 s.
            assert elapsed >= 0.9

    # -------------------------------------------------- _claim_and_execute_once

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_returns_false_when_no_job(self):
        """
        ``_claim_and_execute_once`` returns False and commits once when
        ``claim_one`` yields no job.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
                patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
            mock_sm, mock_session = self._make_sessionmaker()
            mock_session.commit = AsyncMock()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = mock_sm

            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(return_value=None)
            mock_repo_cls.return_value = mock_repo

            worker = PGWorker(cfg, stop_event)
            result = await worker._claim_and_execute_once()

            assert result is False
            assert mock_session.commit.call_count == 1

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_executes_job_successfully(self):
        """
        A claimed job whose execution is not canceled is finished via
        ``finish_ok`` and the call returns True.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
                patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
            mock_sm, _session = self._make_sessionmaker()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = mock_sm

            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(
                return_value=self._claimed_job({"key": "value"})
            )
            mock_repo.finish_ok = AsyncMock()
            mock_repo_cls.return_value = mock_repo

            worker = PGWorker(cfg, stop_event)

            async def mock_pipeline(task, args):
                yield

            with patch.object(worker, "_pipeline", side_effect=mock_pipeline), \
                    patch.object(worker, "_execute_with_heartbeat", return_value=False):
                result = await worker._claim_and_execute_once()

            assert result is True
            assert mock_repo.finish_ok.call_count == 1

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_handles_cancellation(self):
        """
        When execution reports cancellation, the job is closed through
        ``finish_fail_or_retry`` with the "canceled by user" reason.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
                patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
            mock_sm, _session = self._make_sessionmaker()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = mock_sm

            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(return_value=self._claimed_job({}))
            mock_repo.finish_fail_or_retry = AsyncMock()
            mock_repo_cls.return_value = mock_repo

            worker = PGWorker(cfg, stop_event)

            with patch.object(worker, "_execute_with_heartbeat", return_value=True):
                result = await worker._claim_and_execute_once()

            assert result is True
            mock_repo.finish_fail_or_retry.assert_called_once()
            failure_call = mock_repo.finish_fail_or_retry.call_args
            # The reason must appear verbatim among the positional arguments.
            assert "canceled by user" in failure_call.args

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_handles_exceptions(self):
        """
        An exception raised during execution is reported through
        ``finish_fail_or_retry`` with the exception text.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
                patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
            mock_sm, _session = self._make_sessionmaker()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = mock_sm

            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(return_value=self._claimed_job({}))
            mock_repo.finish_fail_or_retry = AsyncMock()
            mock_repo_cls.return_value = mock_repo

            worker = PGWorker(cfg, stop_event)

            with patch.object(
                worker, "_execute_with_heartbeat", side_effect=ValueError("Test error")
            ):
                result = await worker._claim_and_execute_once()

            assert result is True
            mock_repo.finish_fail_or_retry.assert_called_once()
            failure_call = mock_repo.finish_fail_or_retry.call_args
            # The error text must appear verbatim among the positional arguments.
            assert "Test error" in failure_call.args

    # -------------------------------------------------- _execute_with_heartbeat

    @pytest.mark.asyncio
    async def test_execute_with_heartbeat_sends_heartbeats(self):
        """
        While a pipeline lasting longer than ``heartbeat_sec`` runs, at least
        one heartbeat is sent and the job is not reported as canceled.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
                patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
            mock_sm, _session = self._make_sessionmaker()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = mock_sm

            mock_repo = Mock()
            mock_repo.heartbeat = AsyncMock(return_value=(True, False))
            mock_repo_cls.return_value = mock_repo

            worker = PGWorker(cfg, stop_event)

            async def slow_pipeline():
                # 0.5 s + 0.6 s: total runtime exceeds the 1 s heartbeat interval.
                await asyncio.sleep(0.5)
                yield
                await asyncio.sleep(0.6)
                yield

            canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline())

            assert canceled is False
            assert mock_repo.heartbeat.call_count >= 1

    @pytest.mark.asyncio
    async def test_execute_with_heartbeat_detects_cancellation(self):
        """
        When the mocked heartbeat returns ``(True, True)``, execution is
        reported as canceled.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
                patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
            mock_sm, _session = self._make_sessionmaker()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = mock_sm

            mock_repo = Mock()
            mock_repo.heartbeat = AsyncMock(return_value=(True, True))
            mock_repo_cls.return_value = mock_repo

            worker = PGWorker(cfg, stop_event)

            async def slow_pipeline():
                # 0.5 s + 0.6 s: total runtime exceeds the 1 s heartbeat interval.
                await asyncio.sleep(0.5)
                yield
                await asyncio.sleep(0.6)
                yield

            canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline())

            assert canceled is True

    # ---------------------------------------------------------------- _pipeline

    @pytest.mark.asyncio
    async def test_pipeline_handles_sync_function(self):
        """
        A synchronous pipeline function makes ``_pipeline`` yield exactly once.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
                patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            def sync_pipeline(args):
                return "result"

            mock_resolve.return_value = sync_pipeline

            worker = PGWorker(cfg, stop_event)
            results = [step async for step in worker._pipeline("test.task", {})]

            assert len(results) == 1

    @pytest.mark.asyncio
    async def test_pipeline_handles_async_function(self):
        """
        A coroutine pipeline function makes ``_pipeline`` yield exactly once.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
                patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            async def async_pipeline(args):
                return "result"

            mock_resolve.return_value = async_pipeline

            worker = PGWorker(cfg, stop_event)
            results = [step async for step in worker._pipeline("test.task", {})]

            assert len(results) == 1

    @pytest.mark.asyncio
    async def test_pipeline_handles_async_generator(self):
        """
        An async-generator pipeline has each of its yields forwarded by
        ``_pipeline`` (three yields in, three steps out).
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
                patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            async def async_gen_pipeline(args):
                yield
                yield
                yield

            mock_resolve.return_value = async_gen_pipeline

            worker = PGWorker(cfg, stop_event)
            results = [step async for step in worker._pipeline("test.task", {})]

            assert len(results) == 3