From 9271f12b6a18b97f3474224fb914eec7c1288870 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Tue, 27 Jan 2026 01:39:36 +0100 Subject: [PATCH 1/3] moved init stuff to DI --- backend/app/api/routes/sse.py | 29 +- backend/app/core/dishka_lifespan.py | 33 +- backend/app/core/lifecycle.py | 62 -- backend/app/core/providers.py | 452 ++++---- backend/app/db/repositories/__init__.py | 12 + .../execution_queue_repository.py | 234 +++++ .../execution_state_repository.py | 65 ++ .../db/repositories/pod_state_repository.py | 180 ++++ .../db/repositories/resource_repository.py | 300 ++++++ backend/app/dlq/manager.py | 370 +++---- backend/app/events/core/__init__.py | 4 - backend/app/events/core/consumer.py | 277 +---- backend/app/events/core/producer.py | 120 +-- backend/app/events/event_store_consumer.py | 190 ---- backend/app/services/coordinator/__init__.py | 6 - .../app/services/coordinator/coordinator.py | 554 ++++------ .../app/services/coordinator/queue_manager.py | 271 ----- .../services/coordinator/resource_manager.py | 324 ------ backend/app/services/event_bus.py | 379 ++----- .../app/services/idempotency/middleware.py | 68 +- backend/app/services/k8s_worker/worker.py | 496 +++------ backend/app/services/kafka_event_service.py | 60 +- backend/app/services/notification_service.py | 976 +++++++----------- backend/app/services/pod_monitor/monitor.py | 441 ++------ .../services/result_processor/processor.py | 173 +--- backend/app/services/saga/__init__.py | 3 +- .../app/services/saga/saga_orchestrator.py | 435 +++----- .../app/services/sse/kafka_redis_bridge.py | 207 ++-- backend/app/services/sse/sse_service.py | 28 +- .../app/services/sse/sse_shutdown_manager.py | 40 - backend/app/services/user_settings_service.py | 40 +- backend/di_lifecycle_refactor_plan.md | 64 ++ .../tests/e2e/core/test_dishka_lifespan.py | 10 +- backend/tests/e2e/dlq/test_dlq_manager.py | 43 +- .../e2e/events/test_consume_roundtrip.py | 28 +- .../e2e/events/test_consumer_lifecycle.py | 44 +- .../tests/e2e/events/test_event_dispatcher.py | 33 +- .../e2e/events/test_producer_roundtrip.py | 44 +- .../idempotency/test_consumer_idempotent.py | 36 +- .../result_processor/test_result_processor.py | 134 --- .../coordinator/test_execution_coordinator.py | 148 +-- .../e2e/services/events/test_event_bus.py | 31 +- .../sse/test_partitioned_event_router.py | 81 -- .../tests/e2e/test_k8s_worker_create_pod.py | 41 +- .../coordinator/test_queue_manager.py | 41 - .../coordinator/test_resource_manager.py | 61 -- .../unit/services/pod_monitor/test_monitor.py | 784 +++----------- .../result_processor/test_processor.py | 30 +- .../saga/test_saga_orchestrator_unit.py | 37 +- .../services/sse/test_kafka_redis_bridge.py | 49 +- .../services/sse/test_shutdown_manager.py | 41 +- .../unit/services/sse/test_sse_service.py | 14 +- .../services/sse/test_sse_shutdown_manager.py | 32 +- backend/workers/dlq_processor.py | 4 +- backend/workers/run_coordinator.py | 44 +- backend/workers/run_event_replay.py | 46 +- backend/workers/run_k8s_worker.py | 43 +- backend/workers/run_pod_monitor.py | 45 +- backend/workers/run_result_processor.py | 86 +- backend/workers/run_saga_orchestrator.py | 37 +- 60 files changed, 3027 insertions(+), 5933 deletions(-) delete mode 100644 backend/app/core/lifecycle.py create mode 100644 backend/app/db/repositories/execution_queue_repository.py create mode 100644 backend/app/db/repositories/execution_state_repository.py create mode 100644 backend/app/db/repositories/pod_state_repository.py create mode 100644 
backend/app/db/repositories/resource_repository.py delete mode 100644 backend/app/events/event_store_consumer.py delete mode 100644 backend/app/services/coordinator/queue_manager.py delete mode 100644 backend/app/services/coordinator/resource_manager.py create mode 100644 backend/di_lifecycle_refactor_plan.md delete mode 100644 backend/tests/e2e/result_processor/test_result_processor.py delete mode 100644 backend/tests/e2e/services/sse/test_partitioned_event_router.py delete mode 100644 backend/tests/unit/services/coordinator/test_queue_manager.py delete mode 100644 backend/tests/unit/services/coordinator/test_resource_manager.py diff --git a/backend/app/api/routes/sse.py b/backend/app/api/routes/sse.py index 6b1b406f..ae8d1367 100644 --- a/backend/app/api/routes/sse.py +++ b/backend/app/api/routes/sse.py @@ -3,13 +3,7 @@ from fastapi import APIRouter, Request from sse_starlette.sse import EventSourceResponse -from app.domain.sse import SSEHealthDomain -from app.schemas_pydantic.sse import ( - ShutdownStatusResponse, - SSEExecutionEventData, - SSEHealthResponse, - SSENotificationEventData, -) +from app.schemas_pydantic.sse import SSEExecutionEventData, SSENotificationEventData from app.services.auth_service import AuthService from app.services.sse.sse_service import SSEService @@ -38,24 +32,3 @@ async def execution_events( return EventSourceResponse( sse_service.create_execution_stream(execution_id=execution_id, user_id=current_user.user_id) ) - - -@router.get("/health", response_model=SSEHealthResponse) -async def sse_health( - request: Request, - sse_service: FromDishka[SSEService], - auth_service: FromDishka[AuthService], -) -> SSEHealthResponse: - """Get SSE service health status.""" - _ = await auth_service.get_current_user(request) - domain: SSEHealthDomain = await sse_service.get_health_status() - return SSEHealthResponse( - status=domain.status, - kafka_enabled=domain.kafka_enabled, - active_connections=domain.active_connections, - active_executions=domain.active_executions, - active_consumers=domain.active_consumers, - max_connections_per_user=domain.max_connections_per_user, - shutdown=ShutdownStatusResponse(**vars(domain.shutdown)), - timestamp=domain.timestamp, - ) diff --git a/backend/app/core/dishka_lifespan.py b/backend/app/core/dishka_lifespan.py index 857222be..f0177a94 100644 --- a/backend/app/core/dishka_lifespan.py +++ b/backend/app/core/dishka_lifespan.py @@ -2,7 +2,7 @@ import asyncio import logging -from contextlib import AsyncExitStack, asynccontextmanager +from contextlib import asynccontextmanager from typing import AsyncGenerator import redis.asyncio as redis @@ -15,7 +15,6 @@ from app.core.startup import initialize_rate_limits from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS -from app.events.event_store_consumer import EventStoreConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.notification_service import NotificationService from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge @@ -76,26 +75,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: extra={"testing": settings.TESTING, "enable_tracing": settings.ENABLE_TRACING}, ) - # Phase 1: Resolve all DI dependencies in parallel - ( - schema_registry, - database, - redis_client, - rate_limit_metrics, - sse_bridge, - event_store_consumer, - notification_service, - ) = await asyncio.gather( + # Resolve DI dependencies in parallel (fail fast on config issues) + schema_registry, database, 
redis_client, rate_limit_metrics, _, _ = await asyncio.gather( container.get(SchemaRegistryManager), container.get(Database), container.get(redis.Redis), container.get(RateLimitMetrics), container.get(SSEKafkaRedisBridge), - container.get(EventStoreConsumer), container.get(NotificationService), ) - # Phase 2: Initialize infrastructure in parallel (independent subsystems) + # Initialize infrastructure in parallel await asyncio.gather( initialize_event_schemas(schema_registry), init_beanie(database=database, document_models=ALL_DOCUMENTS), @@ -103,16 +93,5 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: ) logger.info("Infrastructure initialized (schemas, beanie, rate limits)") - # Phase 3: Start Kafka consumers in parallel (providers already started them via async with, - # but __aenter__ is idempotent so this is safe and explicit) - async with AsyncExitStack() as stack: - stack.push_async_callback(sse_bridge.aclose) - stack.push_async_callback(event_store_consumer.aclose) - stack.push_async_callback(notification_service.aclose) - await asyncio.gather( - sse_bridge.__aenter__(), - event_store_consumer.__aenter__(), - notification_service.__aenter__(), - ) - logger.info("SSE bridge, EventStoreConsumer, and NotificationService started") - yield + yield + # Container close handles all cleanup automatically diff --git a/backend/app/core/lifecycle.py b/backend/app/core/lifecycle.py deleted file mode 100644 index 2e0d8f85..00000000 --- a/backend/app/core/lifecycle.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from types import TracebackType -from typing import Self - - -class LifecycleEnabled: - """Base class for services with async lifecycle management. - - Usage: - async with MyService() as service: - # service is running - # service is stopped - - Subclasses override _on_start() and _on_stop() for their logic. - Base class handles idempotency and context manager protocol. - - For internal component cleanup, use aclose() which follows Python's - standard async cleanup pattern (like aiofiles, aiohttp). - """ - - def __init__(self) -> None: - self._lifecycle_started: bool = False - - async def _on_start(self) -> None: - """Override with startup logic. Called once on enter.""" - pass - - async def _on_stop(self) -> None: - """Override with cleanup logic. Called once on exit.""" - pass - - async def aclose(self) -> None: - """Close the service. For internal component cleanup. - - Mirrors Python's standard aclose() pattern (like aiofiles, aiohttp). - Idempotent - safe to call multiple times. 
- """ - if not self._lifecycle_started: - return - self._lifecycle_started = False - await self._on_stop() - - @property - def is_running(self) -> bool: - """Check if service is currently running.""" - return self._lifecycle_started - - async def __aenter__(self) -> Self: - if self._lifecycle_started: - return self # Already started, idempotent - await self._on_start() - self._lifecycle_started = True - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc: BaseException | None, - tb: TracebackType | None, - ) -> None: - await self.aclose() diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 7dc457ec..d2465123 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -1,9 +1,8 @@ -from __future__ import annotations - import logging from typing import AsyncIterator import redis.asyncio as redis +from aiokafka import AIOKafkaProducer from dishka import Provider, Scope, from_context, provide from pymongo.asynchronous.mongo_client import AsyncMongoClient @@ -39,20 +38,22 @@ from app.db.repositories.admin.admin_settings_repository import AdminSettingsRepository from app.db.repositories.admin.admin_user_repository import AdminUserRepository from app.db.repositories.dlq_repository import DLQRepository +from app.db.repositories.execution_queue_repository import ExecutionQueueRepository +from app.db.repositories.execution_state_repository import ExecutionStateRepository +from app.db.repositories.pod_state_repository import PodStateRepository from app.db.repositories.replay_repository import ReplayRepository from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository +from app.db.repositories.resource_repository import ResourceRepository from app.db.repositories.user_settings_repository import UserSettingsRepository -from app.dlq.manager import DLQManager, create_dlq_manager +from app.dlq.manager import DLQManager from app.domain.saga.models import SagaConfig -from app.events.core import UnifiedProducer +from app.events.core import ProducerMetrics, UnifiedProducer from app.events.event_store import EventStore, create_event_store -from app.events.event_store_consumer import EventStoreConsumer, create_event_store_consumer from app.events.schema.schema_registry import SchemaRegistryManager -from app.infrastructure.kafka.topics import get_all_topics from app.services.admin import AdminEventsService, AdminSettingsService, AdminUserService from app.services.auth_service import AuthService from app.services.coordinator.coordinator import ExecutionCoordinator -from app.services.event_bus import EventBusManager +from app.services.event_bus import EventBus from app.services.event_replay.replay_service import EventReplayService from app.services.event_service import EventService from app.services.execution_service import ExecutionService @@ -70,13 +71,13 @@ from app.services.rate_limit_service import RateLimitService from app.services.replay_service import ReplayService from app.services.result_processor.resource_cleaner import ResourceCleaner -from app.services.saga import SagaOrchestrator, create_saga_orchestrator +from app.services.saga import SagaOrchestrator from app.services.saga.saga_service import SagaService from app.services.saved_script_service import SavedScriptService -from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge, create_sse_kafka_redis_bridge +from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge from app.services.sse.redis_bus import 
SSERedisBus from app.services.sse.sse_service import SSEService -from app.services.sse.sse_shutdown_manager import SSEShutdownManager, create_sse_shutdown_manager +from app.services.sse.sse_shutdown_manager import SSEShutdownManager from app.services.user_settings_service import UserSettingsService from app.settings import Settings @@ -113,12 +114,12 @@ async def get_redis_client(self, settings: Settings, logger: logging.Logger) -> socket_timeout=5, ) # Test connection - await client.ping() # type: ignore[misc] # redis-py returns Awaitable[bool] | bool + await client.ping() # type: ignore[misc] # redis-py dual sync/async return type logger.info(f"Redis connected: {settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}") try: yield client finally: - await client.close() + await client.aclose() @provide def get_rate_limit_service( @@ -127,6 +128,48 @@ def get_rate_limit_service( return RateLimitService(redis_client, settings, rate_limit_metrics) +class RedisRepositoryProvider(Provider): + """Provides Redis-backed state repositories for stateless services.""" + + scope = Scope.APP + + @provide + def get_execution_state_repository( + self, redis_client: redis.Redis, logger: logging.Logger + ) -> ExecutionStateRepository: + return ExecutionStateRepository(redis_client, logger) + + @provide + def get_execution_queue_repository( + self, redis_client: redis.Redis, logger: logging.Logger, settings: Settings + ) -> ExecutionQueueRepository: + return ExecutionQueueRepository( + redis_client, + logger, + max_queue_size=10000, + max_executions_per_user=100, + ) + + @provide + async def get_resource_repository( + self, redis_client: redis.Redis, logger: logging.Logger, settings: Settings + ) -> ResourceRepository: + repo = ResourceRepository( + redis_client, + logger, + total_cpu_cores=32.0, + total_memory_mb=65536, + ) + await repo.initialize() + return repo + + @provide + def get_pod_state_repository( + self, redis_client: redis.Redis, logger: logging.Logger + ) -> PodStateRepository: + return PodStateRepository(redis_client, logger) + + class DatabaseProvider(Provider): scope = Scope.APP @@ -155,27 +198,69 @@ def get_tracer_manager(self, settings: Settings) -> TracerManager: return TracerManager(tracer_name=settings.TRACING_SERVICE_NAME) -class MessagingProvider(Provider): +class KafkaProvider(Provider): + """Provides Kafka producer - low-level AIOKafkaProducer for DI.""" + scope = Scope.APP @provide - async def get_kafka_producer( - self, settings: Settings, schema_registry: SchemaRegistryManager, logger: logging.Logger, - event_metrics: EventMetrics - ) -> AsyncIterator[UnifiedProducer]: - async with UnifiedProducer(schema_registry, logger, settings, event_metrics) as producer: + async def get_aiokafka_producer( + self, settings: Settings, logger: logging.Logger + ) -> AsyncIterator[AIOKafkaProducer]: + producer = AIOKafkaProducer( + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + acks="all", + enable_idempotence=True, + max_request_size=10 * 1024 * 1024, # 10MB + ) + await producer.start() + logger.info(f"Kafka producer started: {settings.KAFKA_BOOTSTRAP_SERVERS}") + try: yield producer + finally: + await producer.stop() + + +class MessagingProvider(Provider): + scope = Scope.APP @provide - async def get_dlq_manager( + def get_producer_metrics(self) -> ProducerMetrics: + return ProducerMetrics() + + @provide + def get_unified_producer( + self, + aiokafka_producer: AIOKafkaProducer, + schema_registry: SchemaRegistryManager, + logger: logging.Logger, + event_metrics: EventMetrics, + 
producer_metrics: ProducerMetrics, + ) -> UnifiedProducer: + return UnifiedProducer( + producer=aiokafka_producer, + schema_registry_manager=schema_registry, + logger=logger, + event_metrics=event_metrics, + producer_metrics=producer_metrics, + ) + + @provide + def get_dlq_manager( self, settings: Settings, + aiokafka_producer: AIOKafkaProducer, schema_registry: SchemaRegistryManager, logger: logging.Logger, dlq_metrics: DLQMetrics, - ) -> AsyncIterator[DLQManager]: - async with create_dlq_manager(settings, schema_registry, logger, dlq_metrics) as manager: - yield manager + ) -> DLQManager: + return DLQManager( + settings=settings, + producer=aiokafka_producer, + schema_registry=schema_registry, + logger=logger, + dlq_metrics=dlq_metrics, + ) @provide def get_idempotency_repository(self, redis_client: redis.Redis) -> RedisIdempotencyRepository: @@ -203,7 +288,7 @@ def get_schema_registry(self, settings: Settings, logger: logging.Logger) -> Sch return SchemaRegistryManager(settings, logger) @provide - async def get_event_store( + def get_event_store( self, schema_registry: SchemaRegistryManager, logger: logging.Logger, event_metrics: EventMetrics ) -> EventStore: return create_event_store( @@ -211,36 +296,19 @@ async def get_event_store( ) @provide - async def get_event_store_consumer( + def get_event_bus( self, - event_store: EventStore, - schema_registry: SchemaRegistryManager, + aiokafka_producer: AIOKafkaProducer, settings: Settings, - kafka_producer: UnifiedProducer, logger: logging.Logger, - event_metrics: EventMetrics, - ) -> AsyncIterator[EventStoreConsumer]: - topics = get_all_topics() - async with create_event_store_consumer( - event_store=event_store, - topics=list(topics), - schema_registry_manager=schema_registry, - settings=settings, - producer=kafka_producer, - logger=logger, - event_metrics=event_metrics, - ) as consumer: - yield consumer - - @provide - async def get_event_bus_manager( - self, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics - ) -> AsyncIterator[EventBusManager]: - manager = EventBusManager(settings, logger, connection_metrics) - try: - yield manager - finally: - await manager.close() + connection_metrics: ConnectionMetrics, + ) -> EventBus: + return EventBus( + producer=aiokafka_producer, + settings=settings, + logger=logger, + connection_metrics=connection_metrics, + ) class KubernetesProvider(Provider): @@ -385,35 +453,25 @@ class SSEProvider(Provider): scope = Scope.APP @provide - async def get_sse_redis_bus( - self, redis_client: redis.Redis, logger: logging.Logger - ) -> AsyncIterator[SSERedisBus]: - bus = SSERedisBus(redis_client, logger) - yield bus + def get_sse_redis_bus(self, redis_client: redis.Redis, logger: logging.Logger) -> SSERedisBus: + return SSERedisBus(redis_client, logger) @provide - async def get_sse_kafka_redis_bridge( + def get_sse_kafka_redis_bridge( self, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_metrics: EventMetrics, sse_redis_bus: SSERedisBus, logger: logging.Logger, - ) -> AsyncIterator[SSEKafkaRedisBridge]: - async with create_sse_kafka_redis_bridge( - schema_registry=schema_registry, - settings=settings, - event_metrics=event_metrics, - sse_bus=sse_redis_bus, - logger=logger, - ) as bridge: - yield bridge + ) -> SSEKafkaRedisBridge: + return SSEKafkaRedisBridge( + sse_bus=sse_redis_bus, + logger=logger, + ) @provide(scope=Scope.REQUEST) def get_sse_shutdown_manager( self, logger: logging.Logger, connection_metrics: ConnectionMetrics ) -> SSEShutdownManager: - 
return create_sse_shutdown_manager(logger=logger, connection_metrics=connection_metrics) + return SSEShutdownManager(logger=logger, connection_metrics=connection_metrics) @provide(scope=Scope.REQUEST) def get_sse_service( @@ -426,7 +484,6 @@ def get_sse_service( logger: logging.Logger, connection_metrics: ConnectionMetrics, ) -> SSEService: - shutdown_manager.set_router(router) return SSEService( repository=sse_repository, router=router, @@ -483,12 +540,11 @@ async def get_user_settings_service( self, repository: UserSettingsRepository, kafka_event_service: KafkaEventService, - event_bus_manager: EventBusManager, - settings: Settings, + event_bus: EventBus, logger: logging.Logger, ) -> UserSettingsService: - service = UserSettingsService(repository, kafka_event_service, settings, logger) - await service.initialize(event_bus_manager) + service = UserSettingsService(repository, kafka_event_service, event_bus, logger) + await service.setup_event_subscription() return service @@ -513,31 +569,23 @@ def get_admin_settings_service( return AdminSettingsService(admin_settings_repository, logger) @provide - async def get_notification_service( + def get_notification_service( self, notification_repository: NotificationRepository, - kafka_event_service: KafkaEventService, - event_bus_manager: EventBusManager, - schema_registry: SchemaRegistryManager, + event_bus: EventBus, sse_redis_bus: SSERedisBus, settings: Settings, logger: logging.Logger, notification_metrics: NotificationMetrics, - event_metrics: EventMetrics, - ) -> AsyncIterator[NotificationService]: - service = NotificationService( + ) -> NotificationService: + return NotificationService( notification_repository=notification_repository, - event_service=kafka_event_service, - event_bus_manager=event_bus_manager, - schema_registry_manager=schema_registry, + event_bus=event_bus, sse_bus=sse_redis_bus, settings=settings, logger=logger, notification_metrics=notification_metrics, - event_metrics=event_metrics, ) - async with service: - yield service @provide def get_grafana_alert_processor( @@ -566,68 +614,120 @@ def _create_default_saga_config() -> SagaConfig: ) -# Standalone factory functions for lifecycle-managed services (eliminates duplication) -async def _provide_saga_orchestrator( - saga_repository: SagaRepository, - kafka_producer: UnifiedProducer, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, - resource_allocation_repository: ResourceAllocationRepository, - logger: logging.Logger, - event_metrics: EventMetrics, -) -> AsyncIterator[SagaOrchestrator]: - """Shared factory for SagaOrchestrator with lifecycle management.""" - async with create_saga_orchestrator( - saga_repository=saga_repository, +class CoordinatorProvider(Provider): + scope = Scope.APP + + @provide + def get_execution_coordinator( + self, + kafka_producer: UnifiedProducer, + execution_repository: ExecutionRepository, + state_repo: ExecutionStateRepository, + queue_repo: ExecutionQueueRepository, + resource_repo: ResourceRepository, + logger: logging.Logger, + coordinator_metrics: CoordinatorMetrics, + event_metrics: EventMetrics, + ) -> ExecutionCoordinator: + return ExecutionCoordinator( producer=kafka_producer, - schema_registry_manager=schema_registry, - settings=settings, - event_store=event_store, - idempotency_manager=idempotency_manager, - resource_allocation_repository=resource_allocation_repository, - config=_create_default_saga_config(), + execution_repository=execution_repository, 
+ state_repo=state_repo, + queue_repo=queue_repo, + resource_repo=resource_repo, logger=logger, + coordinator_metrics=coordinator_metrics, event_metrics=event_metrics, - ) as orchestrator: - yield orchestrator - - -async def _provide_execution_coordinator( - kafka_producer: UnifiedProducer, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - execution_repository: ExecutionRepository, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - coordinator_metrics: CoordinatorMetrics, - event_metrics: EventMetrics, -) -> AsyncIterator[ExecutionCoordinator]: - """Shared factory for ExecutionCoordinator with lifecycle management.""" - async with ExecutionCoordinator( + ) + + +class K8sWorkerProvider(Provider): + scope = Scope.APP + + @provide + def get_kubernetes_worker( + self, + kafka_producer: UnifiedProducer, + pod_state_repo: PodStateRepository, + k8s_clients: K8sClients, + logger: logging.Logger, + kubernetes_metrics: KubernetesMetrics, + execution_metrics: ExecutionMetrics, + event_metrics: EventMetrics, + ) -> KubernetesWorker: + config = K8sWorkerConfig() + return KubernetesWorker( + config=config, producer=kafka_producer, - schema_registry_manager=schema_registry, - settings=settings, - event_store=event_store, - execution_repository=execution_repository, - idempotency_manager=idempotency_manager, + pod_state_repo=pod_state_repo, + v1_client=k8s_clients.v1, + networking_v1_client=k8s_clients.networking_v1, + apps_v1_client=k8s_clients.apps_v1, + logger=logger, + kubernetes_metrics=kubernetes_metrics, + execution_metrics=execution_metrics, + event_metrics=event_metrics, + ) + + +class PodMonitorProvider(Provider): + scope = Scope.APP + + @provide + def get_event_mapper( + self, + logger: logging.Logger, + k8s_clients: K8sClients, + ) -> PodEventMapper: + return PodEventMapper(logger=logger, k8s_api=k8s_clients.v1) + + @provide + def get_pod_monitor( + self, + kafka_producer: UnifiedProducer, + pod_state_repo: PodStateRepository, + k8s_clients: K8sClients, + logger: logging.Logger, + event_mapper: PodEventMapper, + kubernetes_metrics: KubernetesMetrics, + ) -> PodMonitor: + config = PodMonitorConfig() + return PodMonitor( + config=config, + producer=kafka_producer, + pod_state_repo=pod_state_repo, + v1_client=k8s_clients.v1, + event_mapper=event_mapper, + logger=logger, + kubernetes_metrics=kubernetes_metrics, + ) + + +class SagaOrchestratorProvider(Provider): + scope = Scope.APP + + @provide + def get_saga_orchestrator( + self, + saga_repository: SagaRepository, + kafka_producer: UnifiedProducer, + resource_allocation_repository: ResourceAllocationRepository, + logger: logging.Logger, + event_metrics: EventMetrics, + ) -> SagaOrchestrator: + return SagaOrchestrator( + config=_create_default_saga_config(), + saga_repository=saga_repository, + producer=kafka_producer, + resource_allocation_repository=resource_allocation_repository, logger=logger, - coordinator_metrics=coordinator_metrics, event_metrics=event_metrics, - ) as coordinator: - yield coordinator + ) class BusinessServicesProvider(Provider): scope = Scope.REQUEST - def __init__(self) -> None: - super().__init__() - # Register shared factory functions on instance (avoids warning about missing self) - self.provide(_provide_execution_coordinator) - @provide def get_saga_service( self, @@ -697,82 +797,6 @@ def get_admin_user_service( ) -class CoordinatorProvider(Provider): - scope = Scope.APP - - def __init__(self) -> None: - super().__init__() - 
self.provide(_provide_execution_coordinator) - - -class K8sWorkerProvider(Provider): - scope = Scope.APP - - @provide - async def get_kubernetes_worker( - self, - kafka_producer: UnifiedProducer, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - event_metrics: EventMetrics, - ) -> AsyncIterator[KubernetesWorker]: - config = K8sWorkerConfig() - async with KubernetesWorker( - config=config, - producer=kafka_producer, - schema_registry_manager=schema_registry, - settings=settings, - event_store=event_store, - idempotency_manager=idempotency_manager, - logger=logger, - event_metrics=event_metrics, - ) as worker: - yield worker - - -class PodMonitorProvider(Provider): - scope = Scope.APP - - @provide - def get_event_mapper( - self, - logger: logging.Logger, - k8s_clients: K8sClients, - ) -> PodEventMapper: - return PodEventMapper(logger=logger, k8s_api=k8s_clients.v1) - - @provide - async def get_pod_monitor( - self, - kafka_event_service: KafkaEventService, - k8s_clients: K8sClients, - logger: logging.Logger, - event_mapper: PodEventMapper, - kubernetes_metrics: KubernetesMetrics, - ) -> AsyncIterator[PodMonitor]: - config = PodMonitorConfig() - async with PodMonitor( - config=config, - kafka_event_service=kafka_event_service, - logger=logger, - k8s_clients=k8s_clients, - event_mapper=event_mapper, - kubernetes_metrics=kubernetes_metrics, - ) as monitor: - yield monitor - - -class SagaOrchestratorProvider(Provider): - scope = Scope.APP - - def __init__(self) -> None: - super().__init__() - self.provide(_provide_saga_orchestrator) - - class EventReplayProvider(Provider): scope = Scope.APP diff --git a/backend/app/db/repositories/__init__.py b/backend/app/db/repositories/__init__.py index 1e985797..c5e0199c 100644 --- a/backend/app/db/repositories/__init__.py +++ b/backend/app/db/repositories/__init__.py @@ -1,9 +1,13 @@ from app.db.repositories.admin.admin_settings_repository import AdminSettingsRepository from app.db.repositories.admin.admin_user_repository import AdminUserRepository from app.db.repositories.event_repository import EventRepository +from app.db.repositories.execution_queue_repository import ExecutionQueueRepository, QueuePriority, QueueStats from app.db.repositories.execution_repository import ExecutionRepository +from app.db.repositories.execution_state_repository import ExecutionStateRepository from app.db.repositories.notification_repository import NotificationRepository +from app.db.repositories.pod_state_repository import PodStateRepository from app.db.repositories.replay_repository import ReplayRepository +from app.db.repositories.resource_repository import ResourceAllocation, ResourceRepository, ResourceStats from app.db.repositories.saga_repository import SagaRepository from app.db.repositories.saved_script_repository import SavedScriptRepository from app.db.repositories.sse_repository import SSERepository @@ -15,8 +19,16 @@ "AdminUserRepository", "EventRepository", "ExecutionRepository", + "ExecutionQueueRepository", + "ExecutionStateRepository", "NotificationRepository", + "PodStateRepository", + "QueuePriority", + "QueueStats", "ReplayRepository", + "ResourceAllocation", + "ResourceRepository", + "ResourceStats", "SagaRepository", "SavedScriptRepository", "SSERepository", diff --git a/backend/app/db/repositories/execution_queue_repository.py b/backend/app/db/repositories/execution_queue_repository.py new file mode 100644 index 00000000..d24af7bf --- 
/dev/null +++ b/backend/app/db/repositories/execution_queue_repository.py @@ -0,0 +1,234 @@ +"""Redis-backed execution queue repository. + +Replaces in-memory priority queue (QueueManager) with Redis sorted sets +for stateless, horizontally-scalable services. +""" + +from __future__ import annotations + +import json +import logging +import time +from dataclasses import dataclass +from enum import IntEnum + +import redis.asyncio as redis + + +class QueuePriority(IntEnum): + """Execution queue priorities. Lower value = higher priority.""" + + CRITICAL = 0 + HIGH = 1 + NORMAL = 5 + LOW = 8 + BACKGROUND = 10 + + +@dataclass +class QueueStats: + """Queue statistics.""" + + total_size: int + priority_distribution: dict[str, int] + max_queue_size: int + utilization_percent: float + + +class ExecutionQueueRepository: + """Redis-backed priority queue for executions. + + Uses Redis sorted sets for O(log N) priority queue operations. + Stores event data in hash maps for retrieval. + """ + + QUEUE_KEY = "exec:queue" + DATA_KEY_PREFIX = "exec:queue:data" + USER_COUNT_KEY = "exec:queue:user_count" + + def __init__( + self, + redis_client: redis.Redis, + logger: logging.Logger, + max_queue_size: int = 10000, + max_executions_per_user: int = 100, + stale_timeout_seconds: int = 3600, + ) -> None: + self._redis = redis_client + self._logger = logger + self.max_queue_size = max_queue_size + self.max_executions_per_user = max_executions_per_user + self.stale_timeout_seconds = stale_timeout_seconds + + async def enqueue( + self, + execution_id: str, + event_data: dict[str, object], + priority: QueuePriority, + user_id: str, + ) -> tuple[bool, int | None, str | None]: + """Add execution to queue. Returns (success, position, error).""" + # Check queue size + queue_size = await self._redis.zcard(self.QUEUE_KEY) + if queue_size >= self.max_queue_size: + return False, None, "Queue is full" + + # Check user limit + user_count = await self._redis.hincrby(self.USER_COUNT_KEY, user_id, 0) # type: ignore[misc] + if user_count >= self.max_executions_per_user: + return False, None, f"User execution limit exceeded ({self.max_executions_per_user})" + + # Score: priority * 1e12 + timestamp (lower = higher priority, earlier = higher priority) + timestamp = time.time() + score = priority.value * 1e12 + timestamp + + # Use pipeline for atomicity + pipe = self._redis.pipeline() + + # Add to sorted set + pipe.zadd(self.QUEUE_KEY, {execution_id: score}) + + # Store event data + data_key = f"{self.DATA_KEY_PREFIX}:{execution_id}" + event_data["_enqueue_timestamp"] = timestamp + event_data["_priority"] = priority.value + event_data["_user_id"] = user_id + pipe.hset(data_key, mapping={k: json.dumps(v) if not isinstance(v, str) else v for k, v in event_data.items()}) + pipe.expire(data_key, self.stale_timeout_seconds + 60) + + # Increment user count + pipe.hincrby(self.USER_COUNT_KEY, user_id, 1) + + await pipe.execute() + + # Get position + position = await self._redis.zrank(self.QUEUE_KEY, execution_id) + + self._logger.info( + f"Enqueued execution {execution_id}. Priority: {priority.name}, " + f"Position: {position}, Queue size: {queue_size + 1}" + ) + + return True, position, None + + async def dequeue(self) -> tuple[str, dict[str, object | float | str]] | None: + """Remove and return highest priority execution. 
Returns (execution_id, event_data) or None.""" + while True: + # Pop the lowest score (highest priority) + result = await self._redis.zpopmin(self.QUEUE_KEY, count=1) + if not result: + return None + + execution_id = result[0][0] + if isinstance(execution_id, bytes): + execution_id = execution_id.decode() + + # Get event data + data_key = f"{self.DATA_KEY_PREFIX}:{execution_id}" + raw_data = await self._redis.hgetall(data_key) # type: ignore[misc] + + if not raw_data: + # Data expired or missing, skip this entry + self._logger.warning(f"Queue entry {execution_id} has no data, skipping") + continue + + # Parse data + event_data: dict[str, object | float | str] = {} + for k, v in raw_data.items(): + key = k.decode() if isinstance(k, bytes) else k + val = v.decode() if isinstance(v, bytes) else v + try: + event_data[key] = json.loads(val) + except (json.JSONDecodeError, TypeError): + event_data[key] = val + + # Check if stale + enqueue_time_val = event_data.pop("_enqueue_timestamp", 0) + enqueue_time = float(enqueue_time_val) if isinstance(enqueue_time_val, (int, float, str)) else 0.0 + event_data.pop("_priority", None) + user_id_val = event_data.pop("_user_id", "anonymous") + user_id = str(user_id_val) + + age = time.time() - enqueue_time + if age > self.stale_timeout_seconds: + # Stale, clean up and continue + await self._redis.delete(data_key) + await self._redis.hincrby(self.USER_COUNT_KEY, user_id, -1) # type: ignore[misc] + self._logger.info(f"Skipped stale execution {execution_id} (age: {age:.2f}s)") + continue + + # Clean up + await self._redis.delete(data_key) + await self._redis.hincrby(self.USER_COUNT_KEY, user_id, -1) # type: ignore[misc] + + self._logger.info(f"Dequeued execution {execution_id}. Wait time: {age:.2f}s") + return execution_id, event_data + + async def remove(self, execution_id: str) -> bool: + """Remove specific execution from queue. 
Returns True if removed.""" + # Get user_id before removing + data_key = f"{self.DATA_KEY_PREFIX}:{execution_id}" + raw_data = await self._redis.hgetall(data_key) # type: ignore[misc] + + removed = await self._redis.zrem(self.QUEUE_KEY, execution_id) + if removed: + # Decrement user count + if raw_data: + user_id_raw = raw_data.get(b"_user_id") or raw_data.get("_user_id") + if user_id_raw: + user_id = user_id_raw.decode() if isinstance(user_id_raw, bytes) else user_id_raw + try: + user_id = json.loads(user_id) + except (json.JSONDecodeError, TypeError): + pass + await self._redis.hincrby(self.USER_COUNT_KEY, str(user_id), -1) # type: ignore[misc] + + await self._redis.delete(data_key) + self._logger.info(f"Removed execution {execution_id} from queue") + return True + return False + + async def get_position(self, execution_id: str) -> int | None: + """Get queue position of execution (0-indexed).""" + result = await self._redis.zrank(self.QUEUE_KEY, execution_id) + return int(result) if result is not None else None + + async def get_stats(self) -> QueueStats: + """Get queue statistics.""" + total_size = await self._redis.zcard(self.QUEUE_KEY) + + # Count by priority (sample first 1000) + priority_counts: dict[str, int] = {} + entries = await self._redis.zrange(self.QUEUE_KEY, 0, 999, withscores=True) + for _, score in entries: + priority_value = int(score // 1e12) + try: + priority_name = QueuePriority(priority_value).name + except ValueError: + priority_name = "UNKNOWN" + priority_counts[priority_name] = priority_counts.get(priority_name, 0) + 1 + + return QueueStats( + total_size=total_size, + priority_distribution=priority_counts, + max_queue_size=self.max_queue_size, + utilization_percent=(total_size / self.max_queue_size) * 100 if self.max_queue_size > 0 else 0, + ) + + async def cleanup_stale(self) -> int: + """Remove stale entries. Returns count removed. Call periodically.""" + removed = 0 + threshold_score = QueuePriority.BACKGROUND.value * 1e12 + (time.time() - self.stale_timeout_seconds) + + # Get entries older than threshold + stale_entries = await self._redis.zrangebyscore(self.QUEUE_KEY, "-inf", threshold_score, start=0, num=100) + + for entry in stale_entries: + execution_id = entry.decode() if isinstance(entry, bytes) else entry + if await self.remove(execution_id): + removed += 1 + + if removed: + self._logger.info(f"Cleaned {removed} stale executions from queue") + + return removed diff --git a/backend/app/db/repositories/execution_state_repository.py b/backend/app/db/repositories/execution_state_repository.py new file mode 100644 index 00000000..e343ff02 --- /dev/null +++ b/backend/app/db/repositories/execution_state_repository.py @@ -0,0 +1,65 @@ +"""Redis-backed execution state tracking repository. + +Replaces in-memory state tracking (_active_executions sets) with Redis +for stateless, horizontally-scalable services. +""" + +from __future__ import annotations + +import logging + +import redis.asyncio as redis + + +class ExecutionStateRepository: + """Redis-backed execution state tracking. + + Provides atomic claim/release operations for executions, + replacing in-memory sets like `_active_executions`. + """ + + KEY_PREFIX = "exec:active" + + def __init__(self, redis_client: redis.Redis, logger: logging.Logger) -> None: + self._redis = redis_client + self._logger = logger + + async def try_claim(self, execution_id: str, ttl_seconds: int = 3600) -> bool: + """Atomically claim an execution. Returns True if claimed, False if already claimed. 
+ + Uses Redis SETNX for atomic check-and-set. + TTL ensures cleanup if service crashes without releasing. + """ + key = f"{self.KEY_PREFIX}:{execution_id}" + result = await self._redis.set(key, "1", nx=True, ex=ttl_seconds) + if result: + self._logger.debug(f"Claimed execution {execution_id}") + return result is not None + + async def is_active(self, execution_id: str) -> bool: + """Check if an execution is currently active/claimed.""" + key = f"{self.KEY_PREFIX}:{execution_id}" + result = await self._redis.exists(key) + return bool(result) + + async def remove(self, execution_id: str) -> bool: + """Release/remove an execution claim. Returns True if was claimed.""" + key = f"{self.KEY_PREFIX}:{execution_id}" + deleted = await self._redis.delete(key) + if deleted: + self._logger.debug(f"Released execution {execution_id}") + return bool(deleted > 0) + + async def get_active_count(self) -> int: + """Get count of active executions. For metrics only.""" + pattern = f"{self.KEY_PREFIX}:*" + count = 0 + async for _ in self._redis.scan_iter(match=pattern, count=100): + count += 1 + return count + + async def extend_ttl(self, execution_id: str, ttl_seconds: int = 3600) -> bool: + """Extend the TTL of an active execution. Returns True if extended.""" + key = f"{self.KEY_PREFIX}:{execution_id}" + result = await self._redis.expire(key, ttl_seconds) + return bool(result) diff --git a/backend/app/db/repositories/pod_state_repository.py b/backend/app/db/repositories/pod_state_repository.py new file mode 100644 index 00000000..0e652720 --- /dev/null +++ b/backend/app/db/repositories/pod_state_repository.py @@ -0,0 +1,180 @@ +"""Redis-backed pod state tracking repository. + +Replaces in-memory pod state tracking (_tracked_pods, _active_creations) +for stateless, horizontally-scalable services. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from datetime import datetime, timezone + +import redis.asyncio as redis + + +@dataclass +class PodState: + """State of a tracked pod.""" + + pod_name: str + execution_id: str + status: str + created_at: datetime + updated_at: datetime + metadata: dict[str, object] | None = None + + +class PodStateRepository: + """Redis-backed pod state tracking. + + Provides atomic operations for pod creation tracking, + replacing in-memory sets like `_active_creations` and `_tracked_pods`. + """ + + CREATION_KEY_PREFIX = "pod:creating" + TRACKED_KEY_PREFIX = "pod:tracked" + RESOURCE_VERSION_KEY = "pod:resource_version" + + def __init__(self, redis_client: redis.Redis, logger: logging.Logger) -> None: + self._redis = redis_client + self._logger = logger + + # --- Active Creations (for KubernetesWorker) --- + + async def try_claim_creation(self, execution_id: str, ttl_seconds: int = 300) -> bool: + """Atomically claim a pod creation slot. 
Returns True if claimed.""" + key = f"{self.CREATION_KEY_PREFIX}:{execution_id}" + result = await self._redis.set(key, "1", nx=True, ex=ttl_seconds) + if result: + self._logger.debug(f"Claimed pod creation for {execution_id}") + return result is not None + + async def release_creation(self, execution_id: str) -> bool: + """Release a pod creation claim.""" + key = f"{self.CREATION_KEY_PREFIX}:{execution_id}" + deleted = await self._redis.delete(key) + if deleted: + self._logger.debug(f"Released pod creation for {execution_id}") + return bool(deleted) + + async def get_active_creations_count(self) -> int: + """Get count of active pod creations.""" + count = 0 + async for _ in self._redis.scan_iter(match=f"{self.CREATION_KEY_PREFIX}:*", count=100): + count += 1 + return count + + async def is_creation_active(self, execution_id: str) -> bool: + """Check if a pod creation is active.""" + key = f"{self.CREATION_KEY_PREFIX}:{execution_id}" + result = await self._redis.exists(key) + return bool(result) + + # --- Tracked Pods (for PodMonitor) --- + + async def track_pod( + self, + pod_name: str, + execution_id: str, + status: str, + metadata: dict[str, object] | None = None, + ttl_seconds: int = 7200, + ) -> None: + """Track a pod's state.""" + key = f"{self.TRACKED_KEY_PREFIX}:{pod_name}" + now = datetime.now(timezone.utc).isoformat() + + data = { + "pod_name": pod_name, + "execution_id": execution_id, + "status": status, + "created_at": now, + "updated_at": now, + "metadata": json.dumps(metadata) if metadata else "{}", + } + + await self._redis.hset(key, mapping=data) # type: ignore[misc] + await self._redis.expire(key, ttl_seconds) + self._logger.debug(f"Tracking pod {pod_name} for execution {execution_id}") + + async def update_pod_status(self, pod_name: str, status: str) -> bool: + """Update a tracked pod's status. Returns True if updated.""" + key = f"{self.TRACKED_KEY_PREFIX}:{pod_name}" + exists = await self._redis.exists(key) + if not exists: + return False + + now = datetime.now(timezone.utc).isoformat() + await self._redis.hset(key, mapping={"status": status, "updated_at": now}) # type: ignore[misc] + return True + + async def untrack_pod(self, pod_name: str) -> bool: + """Remove a pod from tracking. 
Returns True if removed.""" + key = f"{self.TRACKED_KEY_PREFIX}:{pod_name}" + deleted = await self._redis.delete(key) + if deleted: + self._logger.debug(f"Untracked pod {pod_name}") + return bool(deleted) + + async def get_pod_state(self, pod_name: str) -> PodState | None: + """Get state of a tracked pod.""" + key = f"{self.TRACKED_KEY_PREFIX}:{pod_name}" + data: dict[bytes | str, bytes | str] = await self._redis.hgetall(key) # type: ignore[misc] + if not data: + return None + + def get_str(k: str) -> str: + val = data.get(k.encode(), data.get(k, "")) + return val.decode() if isinstance(val, bytes) else str(val) + + metadata_str = get_str("metadata") + try: + metadata = json.loads(metadata_str) if metadata_str else None + except json.JSONDecodeError: + metadata = None + + return PodState( + pod_name=get_str("pod_name"), + execution_id=get_str("execution_id"), + status=get_str("status"), + created_at=datetime.fromisoformat(get_str("created_at")), + updated_at=datetime.fromisoformat(get_str("updated_at")), + metadata=metadata, + ) + + async def is_pod_tracked(self, pod_name: str) -> bool: + """Check if a pod is being tracked.""" + key = f"{self.TRACKED_KEY_PREFIX}:{pod_name}" + result = await self._redis.exists(key) + return bool(result) + + async def get_tracked_pods_count(self) -> int: + """Get count of tracked pods.""" + count = 0 + async for _ in self._redis.scan_iter(match=f"{self.TRACKED_KEY_PREFIX}:*", count=100): + count += 1 + return count + + async def get_tracked_pod_names(self) -> set[str]: + """Get set of all tracked pod names.""" + names: set[str] = set() + prefix_len = len(self.TRACKED_KEY_PREFIX) + 1 + async for key in self._redis.scan_iter(match=f"{self.TRACKED_KEY_PREFIX}:*", count=100): + key_str = key.decode() if isinstance(key, bytes) else key + names.add(key_str[prefix_len:]) + return names + + # --- Resource Version (for PodMonitor watch) --- + + async def get_resource_version(self) -> str | None: + """Get the last known resource version for watch resumption.""" + result = await self._redis.get(self.RESOURCE_VERSION_KEY) + if result: + return result.decode() if isinstance(result, bytes) else result + return None + + async def set_resource_version(self, version: str) -> None: + """Store the resource version for watch resumption.""" + await self._redis.set(self.RESOURCE_VERSION_KEY, version) diff --git a/backend/app/db/repositories/resource_repository.py b/backend/app/db/repositories/resource_repository.py new file mode 100644 index 00000000..1f6b54b0 --- /dev/null +++ b/backend/app/db/repositories/resource_repository.py @@ -0,0 +1,300 @@ +"""Redis-backed resource allocation repository. + +Replaces in-memory resource tracking (ResourceManager) with Redis +for stateless, horizontally-scalable services. 
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +import redis.asyncio as redis + + +@dataclass +class ResourceAllocation: + """Resource allocation for an execution.""" + + execution_id: str + cpu_cores: float + memory_mb: int + gpu_count: int = 0 + + @property + def cpu_millicores(self) -> int: + """Get CPU in millicores for Kubernetes.""" + return int(self.cpu_cores * 1000) + + @property + def memory_bytes(self) -> int: + """Get memory in bytes.""" + return self.memory_mb * 1024 * 1024 + + +@dataclass +class ResourceStats: + """Resource statistics.""" + + total_cpu: float + total_memory_mb: int + total_gpu: int + available_cpu: float + available_memory_mb: int + available_gpu: int + allocation_count: int + + +class ResourceRepository: + """Redis-backed resource allocation tracking. + + Uses Redis for atomic resource allocation with Lua scripts. + Replaces in-memory ResourceManager._allocations dict. + """ + + POOL_KEY = "resource:pool" + ALLOC_KEY_PREFIX = "resource:alloc" + + # Default allocations by language + DEFAULT_ALLOCATIONS = { + "python": (0.5, 512), + "javascript": (0.5, 512), + "go": (0.25, 256), + "rust": (0.5, 512), + "java": (1.0, 1024), + "cpp": (0.5, 512), + "r": (1.0, 2048), + } + + def __init__( + self, + redis_client: redis.Redis, + logger: logging.Logger, + total_cpu_cores: float = 32.0, + total_memory_mb: int = 65536, + total_gpu_count: int = 0, + overcommit_factor: float = 1.2, + max_cpu_per_execution: float = 4.0, + max_memory_per_execution_mb: int = 8192, + min_reserve_cpu: float = 2.0, + min_reserve_memory_mb: int = 4096, + ) -> None: + self._redis = redis_client + self._logger = logger + + # Apply overcommit + self._total_cpu = total_cpu_cores * overcommit_factor + self._total_memory = int(total_memory_mb * overcommit_factor) + self._total_gpu = total_gpu_count + + self._max_cpu_per_exec = max_cpu_per_execution + self._max_memory_per_exec = max_memory_per_execution_mb + + # Adjust reserves for small pools (max 10% of total) + self._min_reserve_cpu = min(min_reserve_cpu, 0.1 * self._total_cpu) + self._min_reserve_memory = min(min_reserve_memory_mb, int(0.1 * self._total_memory)) + + async def initialize(self) -> None: + """Initialize the resource pool if not exists.""" + exists = await self._redis.exists(self.POOL_KEY) + if not exists: + await self._redis.hset( # type: ignore[misc] + self.POOL_KEY, + mapping={ + "total_cpu": str(self._total_cpu), + "total_memory": str(self._total_memory), + "total_gpu": str(self._total_gpu), + "available_cpu": str(self._total_cpu), + "available_memory": str(self._total_memory), + "available_gpu": str(self._total_gpu), + }, + ) + self._logger.info( + f"Initialized resource pool: {self._total_cpu} CPU, " + f"{self._total_memory}MB RAM, {self._total_gpu} GPU" + ) + + async def allocate( + self, + execution_id: str, + language: str, + requested_cpu: float | None = None, + requested_memory_mb: int | None = None, + requested_gpu: int = 0, + ) -> ResourceAllocation | None: + """Allocate resources for execution. 
Returns allocation or None if insufficient.""" + # Check if already allocated + alloc_key = f"{self.ALLOC_KEY_PREFIX}:{execution_id}" + existing = await self._redis.hgetall(alloc_key) # type: ignore[misc] + if existing: + self._logger.warning(f"Execution {execution_id} already has allocation") + return ResourceAllocation( + execution_id=execution_id, + cpu_cores=float(existing.get(b"cpu", existing.get("cpu", 0))), + memory_mb=int(existing.get(b"memory", existing.get("memory", 0))), + gpu_count=int(existing.get(b"gpu", existing.get("gpu", 0))), + ) + + # Determine requested resources + if requested_cpu is None or requested_memory_mb is None: + default_cpu, default_memory = self.DEFAULT_ALLOCATIONS.get(language, (0.5, 512)) + requested_cpu = requested_cpu or default_cpu + requested_memory_mb = requested_memory_mb or default_memory + + # Apply limits + requested_cpu = min(requested_cpu, self._max_cpu_per_exec) + requested_memory_mb = min(requested_memory_mb, self._max_memory_per_exec) + + # Atomic allocation using Lua script + lua_script = """ + local pool_key = KEYS[1] + local alloc_key = KEYS[2] + local req_cpu = tonumber(ARGV[1]) + local req_memory = tonumber(ARGV[2]) + local req_gpu = tonumber(ARGV[3]) + local min_cpu = tonumber(ARGV[4]) + local min_memory = tonumber(ARGV[5]) + + local avail_cpu = tonumber(redis.call('HGET', pool_key, 'available_cpu') or '0') + local avail_memory = tonumber(redis.call('HGET', pool_key, 'available_memory') or '0') + local avail_gpu = tonumber(redis.call('HGET', pool_key, 'available_gpu') or '0') + + local cpu_after = avail_cpu - req_cpu + local memory_after = avail_memory - req_memory + local gpu_after = avail_gpu - req_gpu + + if cpu_after < min_cpu or memory_after < min_memory or gpu_after < 0 then + return 0 + end + + redis.call('HSET', pool_key, 'available_cpu', tostring(cpu_after)) + redis.call('HSET', pool_key, 'available_memory', tostring(memory_after)) + redis.call('HSET', pool_key, 'available_gpu', tostring(gpu_after)) + + redis.call('HSET', alloc_key, 'cpu', tostring(req_cpu), 'memory', tostring(req_memory), + 'gpu', tostring(req_gpu)) + redis.call('EXPIRE', alloc_key, 7200) + + return 1 + """ + + result = await self._redis.eval( # type: ignore[misc] + lua_script, + 2, + self.POOL_KEY, + alloc_key, + str(requested_cpu), + str(requested_memory_mb), + str(requested_gpu), + str(self._min_reserve_cpu), + str(self._min_reserve_memory), + ) + + if not result: + pool = await self._redis.hgetall(self.POOL_KEY) # type: ignore[misc] + avail_cpu = float(pool.get(b"available_cpu", pool.get("available_cpu", 0))) + avail_memory = int(float(pool.get(b"available_memory", pool.get("available_memory", 0)))) + self._logger.warning( + f"Insufficient resources for {execution_id}. " + f"Requested: {requested_cpu} CPU, {requested_memory_mb}MB. " + f"Available: {avail_cpu} CPU, {avail_memory}MB" + ) + return None + + self._logger.info( + f"Allocated resources for {execution_id}: " + f"{requested_cpu} CPU, {requested_memory_mb}MB RAM, {requested_gpu} GPU" + ) + + return ResourceAllocation( + execution_id=execution_id, + cpu_cores=requested_cpu, + memory_mb=requested_memory_mb, + gpu_count=requested_gpu, + ) + + async def release(self, execution_id: str) -> bool: + """Release resource allocation. 
Returns True if released.""" + alloc_key = f"{self.ALLOC_KEY_PREFIX}:{execution_id}" + + # Get current allocation + alloc = await self._redis.hgetall(alloc_key) # type: ignore[misc] + if not alloc: + self._logger.warning(f"No allocation found for {execution_id}") + return False + + cpu = float(alloc.get(b"cpu", alloc.get("cpu", 0))) + memory = int(float(alloc.get(b"memory", alloc.get("memory", 0)))) + gpu = int(alloc.get(b"gpu", alloc.get("gpu", 0))) + + # Release atomically + pipe = self._redis.pipeline() + pipe.hincrbyfloat(self.POOL_KEY, "available_cpu", cpu) + pipe.hincrbyfloat(self.POOL_KEY, "available_memory", memory) + pipe.hincrby(self.POOL_KEY, "available_gpu", gpu) + pipe.delete(alloc_key) + await pipe.execute() + + self._logger.info(f"Released resources for {execution_id}: {cpu} CPU, {memory}MB RAM, {gpu} GPU") + return True + + async def get_allocation(self, execution_id: str) -> ResourceAllocation | None: + """Get current allocation for execution.""" + alloc_key = f"{self.ALLOC_KEY_PREFIX}:{execution_id}" + alloc = await self._redis.hgetall(alloc_key) # type: ignore[misc] + if not alloc: + return None + + return ResourceAllocation( + execution_id=execution_id, + cpu_cores=float(alloc.get(b"cpu", alloc.get("cpu", 0))), + memory_mb=int(float(alloc.get(b"memory", alloc.get("memory", 0)))), + gpu_count=int(alloc.get(b"gpu", alloc.get("gpu", 0))), + ) + + async def get_stats(self) -> ResourceStats: + """Get resource statistics.""" + pool = await self._redis.hgetall(self.POOL_KEY) # type: ignore[misc] + + # Decode bytes if needed + def get_val(key: str, default: str = "0") -> str: + return str(pool.get(key.encode(), pool.get(key, default))) + + total_cpu = float(get_val("total_cpu")) + total_memory = int(float(get_val("total_memory"))) + total_gpu = int(get_val("total_gpu")) + available_cpu = float(get_val("available_cpu")) + available_memory = int(float(get_val("available_memory"))) + available_gpu = int(get_val("available_gpu")) + + # Count allocations + count = 0 + async for _ in self._redis.scan_iter(match=f"{self.ALLOC_KEY_PREFIX}:*", count=100): + count += 1 + + return ResourceStats( + total_cpu=total_cpu, + total_memory_mb=total_memory, + total_gpu=total_gpu, + available_cpu=available_cpu, + available_memory_mb=available_memory, + available_gpu=available_gpu, + allocation_count=count, + ) + + async def can_allocate(self, cpu_cores: float, memory_mb: int, gpu_count: int = 0) -> bool: + """Check if resources can be allocated.""" + pool = await self._redis.hgetall(self.POOL_KEY) # type: ignore[misc] + + def get_val(key: str) -> float: + return float(pool.get(key.encode(), pool.get(key, 0))) + + available_cpu = get_val("available_cpu") + available_memory = get_val("available_memory") + available_gpu = get_val("available_gpu") + + return ( + (available_cpu - cpu_cores) >= self._min_reserve_cpu + and (available_memory - memory_mb) >= self._min_reserve_memory + and (available_gpu - gpu_count) >= 0 + ) diff --git a/backend/app/dlq/manager.py b/backend/app/dlq/manager.py index 1e20dc23..c1f5472b 100644 --- a/backend/app/dlq/manager.py +++ b/backend/app/dlq/manager.py @@ -1,13 +1,19 @@ +"""DLQ Manager - stateless event handler. + +Manages Dead Letter Queue messages. Receives events, +processes them, and handles retries. No lifecycle management. 
+""" + +from __future__ import annotations + import asyncio import json import logging from datetime import datetime, timezone -from typing import Any, Callable -from aiokafka import AIOKafkaConsumer, AIOKafkaProducer +from aiokafka import AIOKafkaProducer from opentelemetry.trace import SpanKind -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import DLQMetrics from app.core.tracing import EventAttributes from app.core.tracing.utils import extract_trace_context, get_tracer, inject_trace_context @@ -21,7 +27,7 @@ RetryPolicy, RetryStrategy, ) -from app.domain.enums.kafka import GroupId, KafkaTopic +from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import ( DLQMessageDiscardedEvent, DLQMessageReceivedEvent, @@ -32,149 +38,118 @@ from app.settings import Settings -class DLQManager(LifecycleEnabled): +class DLQManager: + """Stateless DLQ manager - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + Worker entrypoint handles the consume loop. + """ + def __init__( self, settings: Settings, - consumer: AIOKafkaConsumer, producer: AIOKafkaProducer, schema_registry: SchemaRegistryManager, logger: logging.Logger, dlq_metrics: DLQMetrics, dlq_topic: KafkaTopic = KafkaTopic.DEAD_LETTER_QUEUE, retry_topic_suffix: str = "-retry", - default_retry_policy: RetryPolicy | None = None, - ): - super().__init__() - self.settings = settings - self.metrics = dlq_metrics - self.schema_registry = schema_registry - self.logger = logger - self.dlq_topic = dlq_topic - self.retry_topic_suffix = retry_topic_suffix - self.default_retry_policy = default_retry_policy or RetryPolicy( + ) -> None: + self._settings = settings + self._producer = producer + self._schema_registry = schema_registry + self._logger = logger + self._metrics = dlq_metrics + self._dlq_topic = dlq_topic + self._retry_topic_suffix = retry_topic_suffix + self._default_retry_policy = RetryPolicy( topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF ) - self.consumer: AIOKafkaConsumer = consumer - self.producer: AIOKafkaProducer = producer + self._retry_policies: dict[str, RetryPolicy] = {} + self._filters: list[object] = [] + self._dlq_events_topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.DLQ_EVENTS}" + self._event_metadata = EventMetadata(service_name="dlq-manager", service_version="1.0.0") - self._process_task: asyncio.Task[None] | None = None - self._monitor_task: asyncio.Task[None] | None = None + def set_retry_policy(self, topic: str, policy: RetryPolicy) -> None: + """Set retry policy for a specific topic.""" + self._retry_policies[topic] = policy - # Topic-specific retry policies - self._retry_policies: dict[str, RetryPolicy] = {} + def set_default_retry_policy(self, policy: RetryPolicy) -> None: + """Set the default retry policy.""" + self._default_retry_policy = policy - # Message filters - self._filters: list[Callable[[DLQMessage], bool]] = [] + def add_filter(self, filter_func: object) -> None: + """Add a message filter.""" + self._filters.append(filter_func) - self._dlq_events_topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.DLQ_EVENTS}" - self._event_metadata = EventMetadata(service_name="dlq-manager", service_version="1.0.0") + async def handle_dlq_message(self, raw_message: bytes, headers: dict[str, str]) -> None: + """Handle a DLQ message from Kafka. 
- def _kafka_msg_to_message(self, msg: Any) -> DLQMessage: - """Parse Kafka ConsumerRecord into DLQMessage.""" - data = json.loads(msg.value) - headers = {k: v.decode() for k, v in (msg.headers or [])} - return DLQMessage(**data, dlq_offset=msg.offset, dlq_partition=msg.partition, headers=headers) - - async def _on_start(self) -> None: - """Start DLQ manager.""" - # Start producer and consumer in parallel for faster startup - await asyncio.gather(self.producer.start(), self.consumer.start()) - - # Start processing tasks - self._process_task = asyncio.create_task(self._process_messages()) - self._monitor_task = asyncio.create_task(self._monitor_dlq()) - - self.logger.info("DLQ Manager started") - - async def _on_stop(self) -> None: - """Stop DLQ manager.""" - # Cancel tasks - for task in [self._process_task, self._monitor_task]: - if task: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - # Stop Kafka clients - await self.consumer.stop() - await self.producer.stop() - - self.logger.info("DLQ Manager stopped") - - async def _process_messages(self) -> None: - """Process DLQ messages using async iteration.""" - async for msg in self.consumer: - try: - start = asyncio.get_running_loop().time() - dlq_msg = self._kafka_msg_to_message(msg) - - # Record metrics - self.metrics.record_dlq_message_received(dlq_msg.original_topic, dlq_msg.event.event_type) - self.metrics.record_dlq_message_age((datetime.now(timezone.utc) - dlq_msg.failed_at).total_seconds()) - - # Process with tracing - ctx = extract_trace_context(dlq_msg.headers) - with get_tracer().start_as_current_span( - name="dlq.consume", - context=ctx, - kind=SpanKind.CONSUMER, - attributes={ - EventAttributes.KAFKA_TOPIC: self.dlq_topic, - EventAttributes.EVENT_TYPE: dlq_msg.event.event_type, - EventAttributes.EVENT_ID: dlq_msg.event.event_id, - }, - ): - await self._process_dlq_message(dlq_msg) - - # Commit and record duration - await self.consumer.commit() - self.metrics.record_dlq_processing_duration(asyncio.get_running_loop().time() - start, "process") + Called by worker entrypoint for each message from consume loop. 
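Since the consume loop that used to live in _process_messages now belongs to the worker entrypoint (backend/workers/dlq_processor.py in the diffstat), the worker side reduces to roughly the sketch below; only handle_dlq_message() comes from this patch, while the consumer wiring and manual commit mirror the deleted factory and loop:

    # Sketch of the worker-side loop; consumer is built, subscribed and started by the DI provider.
    from aiokafka import AIOKafkaConsumer

    async def run_dlq_worker(manager: DLQManager, consumer: AIOKafkaConsumer) -> None:
        async for msg in consumer:
            headers = {k: v.decode() for k, v in (msg.headers or [])}
            await manager.handle_dlq_message(msg.value, headers)
            await consumer.commit()  # manual commit, as with enable_auto_commit=False in the old factory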
+ """ + start = asyncio.get_running_loop().time() - except Exception as e: - self.logger.error(f"Error processing DLQ message: {e}") + try: + data = json.loads(raw_message) + dlq_msg = DLQMessage(**data, headers=headers) + + self._metrics.record_dlq_message_received(dlq_msg.original_topic, dlq_msg.event.event_type) + self._metrics.record_dlq_message_age( + (datetime.now(timezone.utc) - dlq_msg.failed_at).total_seconds() + ) + + ctx = extract_trace_context(dlq_msg.headers) + with get_tracer().start_as_current_span( + name="dlq.consume", + context=ctx, + kind=SpanKind.CONSUMER, + attributes={ + EventAttributes.KAFKA_TOPIC: str(self._dlq_topic), + EventAttributes.EVENT_TYPE: dlq_msg.event.event_type, + EventAttributes.EVENT_ID: dlq_msg.event.event_id, + }, + ): + await self._process_dlq_message(dlq_msg) + + self._metrics.record_dlq_processing_duration( + asyncio.get_running_loop().time() - start, "process" + ) + + except Exception as e: + self._logger.error(f"Error processing DLQ message: {e}") async def _process_dlq_message(self, message: DLQMessage) -> None: - # Apply filters + """Process a DLQ message.""" for filter_func in self._filters: - if not filter_func(message): - self.logger.info("Message filtered out", extra={"event_id": message.event.event_id}) + if not filter_func(message): # type: ignore[operator] + self._logger.info("Message filtered out", extra={"event_id": message.event.event_id}) return - # Store in MongoDB via Beanie await self._store_message(message) - # Get retry policy for topic - retry_policy = self._retry_policies.get(message.original_topic, self.default_retry_policy) + retry_policy = self._retry_policies.get(message.original_topic, self._default_retry_policy) - # Check if should retry if not retry_policy.should_retry(message): await self._discard_message(message, "max_retries_exceeded") return - # Calculate next retry time next_retry = retry_policy.get_next_retry_time(message) - # Update message status await self._update_message_status( message.event.event_id, DLQMessageUpdate(status=DLQMessageStatus.SCHEDULED, next_retry_at=next_retry), ) - # If immediate retry, process now if retry_policy.strategy == RetryStrategy.IMMEDIATE: await self._retry_message(message) async def _store_message(self, message: DLQMessage) -> None: - # Ensure message has proper status and timestamps + """Store DLQ message in MongoDB.""" message.status = DLQMessageStatus.PENDING message.last_updated = datetime.now(timezone.utc) doc = DLQMessageDocument(**message.model_dump()) - # Upsert using Beanie existing = await DLQMessageDocument.find_one({"event.event_id": message.event.event_id}) if existing: doc.id = existing.id @@ -183,11 +158,12 @@ async def _store_message(self, message: DLQMessage) -> None: await self._emit_message_received_event(message) async def _update_message_status(self, event_id: str, update: DLQMessageUpdate) -> None: + """Update DLQ message status.""" doc = await DLQMessageDocument.find_one({"event.event_id": event_id}) if not doc: return - update_dict: dict[str, Any] = {"status": update.status, "last_updated": datetime.now(timezone.utc)} + update_dict: dict[str, object] = {"status": update.status, "last_updated": datetime.now(timezone.utc)} if update.next_retry_at is not None: update_dict["next_retry_at"] = update.next_retry_at if update.retried_at is not None: @@ -204,8 +180,8 @@ async def _update_message_status(self, event_id: str, update: DLQMessageUpdate) await doc.set(update_dict) async def _retry_message(self, message: DLQMessage) -> None: - # Send to retry topic first 
(for monitoring) - retry_topic = f"{message.original_topic}{self.retry_topic_suffix}" + """Retry a DLQ message.""" + retry_topic = f"{message.original_topic}{self._retry_topic_suffix}" hdrs: dict[str, str] = { "dlq_retry_count": str(message.retry_count + 1), @@ -215,31 +191,26 @@ async def _retry_message(self, message: DLQMessage) -> None: hdrs = inject_trace_context(hdrs) kafka_headers: list[tuple[str, bytes]] = [(k, v.encode()) for k, v in hdrs.items()] - # Get the original event event = message.event - # Send to retry topic - await self.producer.send_and_wait( + await self._producer.send_and_wait( topic=retry_topic, value=json.dumps(event.model_dump(mode="json")).encode(), key=message.event.event_id.encode(), headers=kafka_headers, ) - # Send to original topic - await self.producer.send_and_wait( + await self._producer.send_and_wait( topic=message.original_topic, value=json.dumps(event.model_dump(mode="json")).encode(), key=message.event.event_id.encode(), headers=kafka_headers, ) - # Update metrics - self.metrics.record_dlq_message_retried(message.original_topic, message.event.event_type, "success") + self._metrics.record_dlq_message_retried(message.original_topic, message.event.event_type, "success") new_retry_count = message.retry_count + 1 - # Update status await self._update_message_status( message.event.event_id, DLQMessageUpdate( @@ -249,16 +220,14 @@ async def _retry_message(self, message: DLQMessage) -> None: ), ) - # Emit DLQ message retried event await self._emit_message_retried_event(message, retry_topic, new_retry_count) - self.logger.info("Successfully retried message", extra={"event_id": message.event.event_id}) + self._logger.info("Successfully retried message", extra={"event_id": message.event.event_id}) async def _discard_message(self, message: DLQMessage, reason: str) -> None: - # Update metrics - self.metrics.record_dlq_message_discarded(message.original_topic, message.event.event_type, reason) + """Discard a DLQ message.""" + self._metrics.record_dlq_message_discarded(message.original_topic, message.event.event_type, reason) - # Update status await self._update_message_status( message.event.event_id, DLQMessageUpdate( @@ -270,57 +239,49 @@ async def _discard_message(self, message: DLQMessage, reason: str) -> None: await self._emit_message_discarded_event(message, reason) - self.logger.warning("Discarded message", extra={"event_id": message.event.event_id, "reason": reason}) + self._logger.warning("Discarded message", extra={"event_id": message.event.event_id, "reason": reason}) - async def _monitor_dlq(self) -> None: - while self.is_running: - try: - # Find messages ready for retry using Beanie - now = datetime.now(timezone.utc) - - docs = ( - await DLQMessageDocument.find( - { - "status": DLQMessageStatus.SCHEDULED, - "next_retry_at": {"$lte": now}, - } - ) - .limit(100) - .to_list() - ) - - for doc in docs: - message = DLQMessage.model_validate(doc, from_attributes=True) - await self._retry_message(message) - - # Update queue size metrics - await self._update_queue_metrics() - - # Sleep before next check - await asyncio.sleep(10) + async def check_scheduled_retries(self, batch_size: int = 100) -> int: + """Check for scheduled messages ready for retry. - except Exception as e: - self.logger.error(f"Error in DLQ monitor: {e}") - await asyncio.sleep(60) + Should be called periodically from worker entrypoint. + Returns number of messages retried. 
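check_scheduled_retries is the pull-based replacement for the deleted _monitor_dlq task, so the worker entrypoint needs a small timer around it. A sketch, reusing the old loop's 10-second interval and 60-second error backoff; everything else is assumed:

    # Sketch of the periodic retry timer run by the worker entrypoint until cancelled.
    import asyncio
    import logging

    async def run_retry_scheduler(manager: DLQManager, logger: logging.Logger) -> None:
        while True:
            try:
                retried = await manager.check_scheduled_retries(batch_size=100)
                if retried:
                    logger.info("Retried %d scheduled DLQ messages", retried)
            except Exception:
                logger.exception("Error while checking scheduled retries")
                await asyncio.sleep(60)  # back off on errors, as the old monitor loop did
                continue
            await asyncio.sleep(10)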
+ """ + now = datetime.now(timezone.utc) + + docs = ( + await DLQMessageDocument.find( + { + "status": DLQMessageStatus.SCHEDULED, + "next_retry_at": {"$lte": now}, + } + ) + .limit(batch_size) + .to_list() + ) + + count = 0 + for doc in docs: + message = DLQMessage.model_validate(doc, from_attributes=True) + await self._retry_message(message) + count += 1 + + await self._update_queue_metrics() + + return count async def _update_queue_metrics(self) -> None: - # Get counts by topic using Beanie aggregation - pipeline: list[dict[str, Any]] = [ + """Update queue size metrics.""" + pipeline: list[dict[str, object]] = [ {"$match": {"status": {"$in": [DLQMessageStatus.PENDING, DLQMessageStatus.SCHEDULED]}}}, {"$group": {"_id": "$original_topic", "count": {"$sum": 1}}}, ] async for result in DLQMessageDocument.aggregate(pipeline): - self.metrics.update_dlq_queue_size(result["_id"], result["count"]) - - def set_retry_policy(self, topic: str, policy: RetryPolicy) -> None: - self._retry_policies[topic] = policy - - def add_filter(self, filter_func: Callable[[DLQMessage], bool]) -> None: - self._filters.append(filter_func) + self._metrics.update_dlq_queue_size(result["_id"], result["count"]) async def _emit_message_received_event(self, message: DLQMessage) -> None: - """Emit a DLQMessageReceivedEvent to the DLQ events topic.""" + """Emit a DLQMessageReceivedEvent.""" event = DLQMessageReceivedEvent( dlq_event_id=message.event.event_id, original_topic=message.original_topic, @@ -333,8 +294,10 @@ async def _emit_message_received_event(self, message: DLQMessage) -> None: ) await self._produce_dlq_event(event) - async def _emit_message_retried_event(self, message: DLQMessage, retry_topic: str, new_retry_count: int) -> None: - """Emit a DLQMessageRetriedEvent to the DLQ events topic.""" + async def _emit_message_retried_event( + self, message: DLQMessage, retry_topic: str, new_retry_count: int + ) -> None: + """Emit a DLQMessageRetriedEvent.""" event = DLQMessageRetriedEvent( dlq_event_id=message.event.event_id, original_topic=message.original_topic, @@ -346,7 +309,7 @@ async def _emit_message_retried_event(self, message: DLQMessage, retry_topic: st await self._produce_dlq_event(event) async def _emit_message_discarded_event(self, message: DLQMessage, reason: str) -> None: - """Emit a DLQMessageDiscardedEvent to the DLQ events topic.""" + """Emit a DLQMessageDiscardedEvent.""" event = DLQMessageDiscardedEvent( dlq_event_id=message.event.event_id, original_topic=message.original_topic, @@ -360,26 +323,26 @@ async def _emit_message_discarded_event(self, message: DLQMessage, reason: str) async def _produce_dlq_event( self, event: DLQMessageReceivedEvent | DLQMessageRetriedEvent | DLQMessageDiscardedEvent ) -> None: - """Produce a DLQ lifecycle event to the DLQ events topic.""" + """Produce a DLQ lifecycle event.""" try: - serialized = await self.schema_registry.serialize_event(event) - await self.producer.send_and_wait( + serialized = await self._schema_registry.serialize_event(event) + await self._producer.send_and_wait( topic=self._dlq_events_topic, value=serialized, key=event.event_id.encode(), ) except Exception as e: - self.logger.error(f"Failed to emit DLQ event {event.event_type}: {e}") + self._logger.error(f"Failed to emit DLQ event {event.event_type}: {e}") async def retry_message_manually(self, event_id: str) -> bool: + """Manually retry a DLQ message.""" doc = await DLQMessageDocument.find_one({"event.event_id": event_id}) if not doc: - self.logger.error("Message not found in DLQ", 
extra={"event_id": event_id}) + self._logger.error("Message not found in DLQ", extra={"event_id": event_id}) return False - # Guard against invalid states if doc.status in {DLQMessageStatus.DISCARDED, DLQMessageStatus.RETRIED}: - self.logger.info("Skipping manual retry", extra={"event_id": event_id, "status": doc.status}) + self._logger.info("Skipping manual retry", extra={"event_id": event_id, "status": doc.status}) return False message = DLQMessage.model_validate(doc, from_attributes=True) @@ -387,14 +350,7 @@ async def retry_message_manually(self, event_id: str) -> bool: return True async def retry_messages_batch(self, event_ids: list[str]) -> DLQBatchRetryResult: - """Retry multiple DLQ messages in batch. - - Args: - event_ids: List of event IDs to retry - - Returns: - Batch result with success/failure counts and details - """ + """Retry multiple DLQ messages in batch.""" details: list[DLQRetryResult] = [] successful = 0 failed = 0 @@ -409,78 +365,24 @@ async def retry_messages_batch(self, event_ids: list[str]) -> DLQBatchRetryResul failed += 1 details.append(DLQRetryResult(event_id=event_id, status="failed", error="Retry failed")) except Exception as e: - self.logger.error(f"Error retrying message {event_id}: {e}") + self._logger.error(f"Error retrying message {event_id}: {e}") failed += 1 details.append(DLQRetryResult(event_id=event_id, status="failed", error=str(e))) return DLQBatchRetryResult(total=len(event_ids), successful=successful, failed=failed, details=details) async def discard_message_manually(self, event_id: str, reason: str) -> bool: - """Manually discard a DLQ message with state validation. - - Args: - event_id: The event ID to discard - reason: Reason for discarding - - Returns: - True if discarded, False if not found or in terminal state - """ + """Manually discard a DLQ message.""" doc = await DLQMessageDocument.find_one({"event.event_id": event_id}) if not doc: - self.logger.error("Message not found in DLQ", extra={"event_id": event_id}) + self._logger.error("Message not found in DLQ", extra={"event_id": event_id}) return False - # Guard against invalid states (terminal states) if doc.status in {DLQMessageStatus.DISCARDED, DLQMessageStatus.RETRIED}: - self.logger.info("Skipping manual discard", extra={"event_id": event_id, "status": doc.status}) + self._logger.info("Skipping manual discard", extra={"event_id": event_id, "status": doc.status}) return False message = DLQMessage.model_validate(doc, from_attributes=True) await self._discard_message(message, reason) return True - -def create_dlq_manager( - settings: Settings, - schema_registry: SchemaRegistryManager, - logger: logging.Logger, - dlq_metrics: DLQMetrics, - dlq_topic: KafkaTopic = KafkaTopic.DEAD_LETTER_QUEUE, - retry_topic_suffix: str = "-retry", - default_retry_policy: RetryPolicy | None = None, -) -> DLQManager: - topic_name = f"{settings.KAFKA_TOPIC_PREFIX}{dlq_topic}" - consumer = AIOKafkaConsumer( - topic_name, - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=GroupId.DLQ_MANAGER, - enable_auto_commit=False, - auto_offset_reset="earliest", - client_id="dlq-manager-consumer", - session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - producer = AIOKafkaProducer( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - client_id="dlq-manager-producer", - acks="all", - compression_type="gzip", - 
max_batch_size=16384, - linger_ms=10, - enable_idempotence=True, - ) - if default_retry_policy is None: - default_retry_policy = RetryPolicy(topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF) - return DLQManager( - settings=settings, - consumer=consumer, - producer=producer, - schema_registry=schema_registry, - logger=logger, - dlq_metrics=dlq_metrics, - dlq_topic=dlq_topic, - retry_topic_suffix=retry_topic_suffix, - default_retry_policy=default_retry_policy, - ) diff --git a/backend/app/events/core/__init__.py b/backend/app/events/core/__init__.py index 3b12df76..1723502a 100644 --- a/backend/app/events/core/__init__.py +++ b/backend/app/events/core/__init__.py @@ -8,15 +8,11 @@ from .types import ( ConsumerConfig, ConsumerMetrics, - ConsumerState, ProducerMetrics, - ProducerState, ) __all__ = [ # Types - "ProducerState", - "ConsumerState", "ConsumerConfig", "ProducerMetrics", "ConsumerMetrics", diff --git a/backend/app/events/core/consumer.py b/backend/app/events/core/consumer.py index d0532f37..3365af98 100644 --- a/backend/app/events/core/consumer.py +++ b/backend/app/events/core/consumer.py @@ -1,258 +1,69 @@ -import asyncio +"""Unified Kafka consumer - pure message handler. + +Handles deserialization, dispatch, and metrics for Kafka messages. +No lifecycle, no properties, no state - just handle(). +Worker gets AIOKafkaConsumer directly from DI. +""" + +from __future__ import annotations + import logging -from collections.abc import Awaitable, Callable -from datetime import datetime, timezone -from typing import Any -from aiokafka import AIOKafkaConsumer, TopicPartition -from aiokafka.errors import KafkaError +from aiokafka import ConsumerRecord from opentelemetry.trace import SpanKind from app.core.metrics import EventMetrics from app.core.tracing import EventAttributes from app.core.tracing.utils import extract_trace_context, get_tracer -from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent from app.events.schema.schema_registry import SchemaRegistryManager -from app.settings import Settings from .dispatcher import EventDispatcher -from .types import ConsumerConfig, ConsumerMetrics, ConsumerMetricsSnapshot, ConsumerState, ConsumerStatus class UnifiedConsumer: + """Pure message handler - deserialize, dispatch, record metrics.""" + def __init__( self, - config: ConsumerConfig, event_dispatcher: EventDispatcher, schema_registry: SchemaRegistryManager, - settings: Settings, logger: logging.Logger, event_metrics: EventMetrics, - ): - self._config = config - self.logger = logger - self._schema_registry = schema_registry + group_id: str, + ) -> None: self._dispatcher = event_dispatcher - self._consumer: AIOKafkaConsumer | None = None - self._state = ConsumerState.STOPPED - self._running = False - self._metrics = ConsumerMetrics() + self._schema_registry = schema_registry + self._logger = logger self._event_metrics = event_metrics - self._error_callback: "Callable[[Exception, DomainEvent], Awaitable[None]] | None" = None - self._consume_task: asyncio.Task[None] | None = None - self._topic_prefix = settings.KAFKA_TOPIC_PREFIX - - async def start(self, topics: list[KafkaTopic]) -> None: - self._state = self._state if self._state != ConsumerState.STOPPED else ConsumerState.STARTING - - topic_strings = [f"{self._topic_prefix}{str(topic)}" for topic in topics] - - self._consumer = AIOKafkaConsumer( - *topic_strings, - bootstrap_servers=self._config.bootstrap_servers, - group_id=self._config.group_id, - client_id=self._config.client_id, - 
auto_offset_reset=self._config.auto_offset_reset, - enable_auto_commit=self._config.enable_auto_commit, - session_timeout_ms=self._config.session_timeout_ms, - heartbeat_interval_ms=self._config.heartbeat_interval_ms, - max_poll_interval_ms=self._config.max_poll_interval_ms, - request_timeout_ms=self._config.request_timeout_ms, - fetch_min_bytes=self._config.fetch_min_bytes, - fetch_max_wait_ms=self._config.fetch_max_wait_ms, - ) - - await self._consumer.start() - self._running = True - self._consume_task = asyncio.create_task(self._consume_loop()) - - self._state = ConsumerState.RUNNING - - self.logger.info(f"Consumer started for topics: {topic_strings}") - - async def stop(self) -> None: - self._state = ( - ConsumerState.STOPPING - if self._state not in (ConsumerState.STOPPED, ConsumerState.STOPPING) - else self._state - ) - - self._running = False - - if self._consume_task: - self._consume_task.cancel() - await asyncio.gather(self._consume_task, return_exceptions=True) - self._consume_task = None - - await self._cleanup() - self._state = ConsumerState.STOPPED - - async def _cleanup(self) -> None: - if self._consumer: - await self._consumer.stop() - self._consumer = None - - async def _consume_loop(self) -> None: - self.logger.info(f"Consumer loop started for group {self._config.group_id}") - poll_count = 0 - message_count = 0 - - while self._running and self._consumer: - poll_count += 1 - if poll_count % 100 == 0: # Log every 100 polls - self.logger.debug(f"Consumer loop active: polls={poll_count}, messages={message_count}") - + self._group_id = group_id + + async def handle(self, msg: ConsumerRecord) -> DomainEvent | None: + """Handle a Kafka message - deserialize, dispatch, record metrics.""" + if msg.value is None: + return None + + event = await self._schema_registry.deserialize_event(msg.value, msg.topic) + headers = {k: v.decode() for k, v in msg.headers} + + with get_tracer().start_as_current_span( + name="kafka.consume", + context=extract_trace_context(headers), + kind=SpanKind.CONSUMER, + attributes={ + EventAttributes.KAFKA_TOPIC: msg.topic, + EventAttributes.KAFKA_PARTITION: msg.partition, + EventAttributes.KAFKA_OFFSET: msg.offset, + EventAttributes.EVENT_TYPE: event.event_type, + EventAttributes.EVENT_ID: event.event_id, + }, + ): try: - # Use getone() with timeout for single message consumption - msg = await asyncio.wait_for( - self._consumer.getone(), - timeout=0.1 - ) - - message_count += 1 - self.logger.debug( - f"Message received from topic {msg.topic}, partition {msg.partition}, offset {msg.offset}" - ) - await self._process_message(msg) - if not self._config.enable_auto_commit: - await self._consumer.commit() - - except asyncio.TimeoutError: - # No message available within timeout, continue polling - await asyncio.sleep(0.01) - except KafkaError as e: - self.logger.error(f"Consumer error: {e}") - self._metrics.processing_errors += 1 - - self.logger.warning( - f"Consumer loop ended for group {self._config.group_id}: " - f"running={self._running}, consumer={self._consumer is not None}" - ) - - async def _process_message(self, message: Any) -> None: - """Process a ConsumerRecord from aiokafka.""" - topic = message.topic - if not topic: - self.logger.warning("Message with no topic received") - return - - raw_value = message.value - if not raw_value: - self.logger.warning(f"Empty message from topic {topic}") - return - - self.logger.debug(f"Deserializing message from topic {topic}, size={len(raw_value)} bytes") - event = await 
self._schema_registry.deserialize_event(raw_value, topic) - self.logger.info(f"Deserialized event: type={event.event_type}, id={event.event_id}") - - # Extract trace context from Kafka headers and start a consumer span - # aiokafka headers are list of tuples: [(key, value), ...] - header_list = message.headers or [] - headers: dict[str, str] = {} - for k, v in header_list: - headers[str(k)] = v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else (v or "") - ctx = extract_trace_context(headers) - tracer = get_tracer() - - # Dispatch event through EventDispatcher - try: - self.logger.debug(f"Dispatching {event.event_type} to handlers") - partition_val = message.partition - offset_val = message.offset - part_attr = partition_val if partition_val is not None else -1 - off_attr = offset_val if offset_val is not None else -1 - with tracer.start_as_current_span( - name="kafka.consume", - context=ctx, - kind=SpanKind.CONSUMER, - attributes={ - EventAttributes.KAFKA_TOPIC: topic, - EventAttributes.KAFKA_PARTITION: part_attr, - EventAttributes.KAFKA_OFFSET: off_attr, - EventAttributes.EVENT_TYPE: event.event_type, - EventAttributes.EVENT_ID: event.event_id, - }, - ): await self._dispatcher.dispatch(event) - self.logger.debug(f"Successfully dispatched {event.event_type}") - # Update metrics on successful dispatch - self._metrics.messages_consumed += 1 - self._metrics.bytes_consumed += len(raw_value) - self._metrics.last_message_time = datetime.now(timezone.utc) - # Record Kafka consumption metrics - self._event_metrics.record_kafka_message_consumed(topic=topic, consumer_group=self._config.group_id) - except Exception as e: - self.logger.error(f"Dispatcher error for event {event.event_type}: {e}") - self._metrics.processing_errors += 1 - # Record Kafka consumption error - self._event_metrics.record_kafka_consumption_error( - topic=topic, consumer_group=self._config.group_id, error_type=type(e).__name__ - ) - if self._error_callback: - await self._error_callback(e, event) - - def register_error_callback(self, callback: Callable[[Exception, DomainEvent], Awaitable[None]]) -> None: - self._error_callback = callback - - @property - def state(self) -> ConsumerState: - return self._state - - @property - def metrics(self) -> ConsumerMetrics: - return self._metrics - - @property - def is_running(self) -> bool: - return self._state == ConsumerState.RUNNING - - @property - def consumer(self) -> AIOKafkaConsumer | None: - return self._consumer - - def get_status(self) -> ConsumerStatus: - return ConsumerStatus( - state=self._state, - is_running=self.is_running, - group_id=self._config.group_id, - client_id=self._config.client_id, - metrics=ConsumerMetricsSnapshot( - messages_consumed=self._metrics.messages_consumed, - bytes_consumed=self._metrics.bytes_consumed, - consumer_lag=self._metrics.consumer_lag, - commit_failures=self._metrics.commit_failures, - processing_errors=self._metrics.processing_errors, - last_message_time=self._metrics.last_message_time, - last_updated=self._metrics.last_updated, - ), - ) - - async def seek_to_beginning(self) -> None: - """Seek all assigned partitions to the beginning.""" - if not self._consumer: - self.logger.warning("Cannot seek: consumer not initialized") - return - - assignment = self._consumer.assignment() - if assignment: - await self._consumer.seek_to_beginning(*assignment) - - async def seek_to_end(self) -> None: - """Seek all assigned partitions to the end.""" - if not self._consumer: - self.logger.warning("Cannot seek: consumer not initialized") - return - - 
assignment = self._consumer.assignment() - if assignment: - await self._consumer.seek_to_end(*assignment) - - async def seek_to_offset(self, topic: str, partition: int, offset: int) -> None: - """Seek a specific partition to a specific offset.""" - if not self._consumer: - self.logger.warning("Cannot seek to offset: consumer not initialized") - return + self._event_metrics.record_kafka_message_consumed(msg.topic, self._group_id) + except Exception as e: + self._logger.error(f"Dispatch error: {event.event_type}: {e}") + self._event_metrics.record_kafka_consumption_error(msg.topic, self._group_id, type(e).__name__) + raise - tp = TopicPartition(topic, partition) - self._consumer.seek(tp, offset) + return event diff --git a/backend/app/events/core/producer.py b/backend/app/events/core/producer.py index a41188c7..98ab0b74 100644 --- a/backend/app/events/core/producer.py +++ b/backend/app/events/core/producer.py @@ -1,122 +1,60 @@ +"""Unified Kafka producer - thin wrapper over AIOKafkaProducer. + +The producer receives a ready-to-use AIOKafkaProducer from DI. +No lifecycle management - DI provider handles start/stop. +""" + +from __future__ import annotations + import asyncio import json import logging import socket from datetime import datetime, timezone -from typing import Any from aiokafka import AIOKafkaProducer from aiokafka.errors import KafkaError -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics from app.dlq.models import DLQMessage, DLQMessageStatus from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent from app.events.schema.schema_registry import SchemaRegistryManager from app.infrastructure.kafka.mappings import EVENT_TYPE_TO_TOPIC -from app.settings import Settings -from .types import ProducerMetrics, ProducerState +from .types import ProducerMetrics -class UnifiedProducer(LifecycleEnabled): - """Fully async Kafka producer using aiokafka.""" +class UnifiedProducer: + """Kafka producer wrapper - receives ready-to-use producer from DI. + + No lifecycle methods (start/stop) - DI provider manages AIOKafkaProducer lifecycle. 
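With start/stop removed from the producer wrapper, the AIOKafkaProducer lifecycle is expected to sit in a DI provider (app/core/providers.py is changed in this patch but not shown in this hunk). A sketch of that pattern using dishka's generator-finalizer style, where the provider name and exact wiring are assumptions:

    # Sketch only; the real provider lives in app/core/providers.py.
    from typing import AsyncIterator

    from aiokafka import AIOKafkaProducer
    from dishka import Provider, Scope, provide

    from app.settings import Settings

    class KafkaProducerProvider(Provider):
        @provide(scope=Scope.APP)
        async def aiokafka_producer(self, settings: Settings) -> AsyncIterator[AIOKafkaProducer]:
            producer = AIOKafkaProducer(
                bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS,
                client_id=f"{settings.SERVICE_NAME}-producer",
                acks="all",
                enable_idempotence=True,
            )
            await producer.start()
            try:
                yield producer  # UnifiedProducer receives this started instance
            finally:
                await producer.stop()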
+ """ def __init__( self, + producer: AIOKafkaProducer, schema_registry_manager: SchemaRegistryManager, logger: logging.Logger, - settings: Settings, event_metrics: EventMetrics, - ): - super().__init__() - self._settings = settings + producer_metrics: ProducerMetrics, + topic_prefix: str = "", + ) -> None: + self._producer = producer self._schema_registry = schema_registry_manager - self.logger = logger - self._producer: AIOKafkaProducer | None = None - self._state = ProducerState.STOPPED - self._metrics = ProducerMetrics() + self._logger = logger self._event_metrics = event_metrics - self._topic_prefix = settings.KAFKA_TOPIC_PREFIX - - @property - def is_running(self) -> bool: - return self._state == ProducerState.RUNNING - - @property - def state(self) -> ProducerState: - return self._state - - @property - def metrics(self) -> ProducerMetrics: - return self._metrics - - @property - def producer(self) -> AIOKafkaProducer | None: - return self._producer - - async def _on_start(self) -> None: - """Start the Kafka producer.""" - self._state = ProducerState.STARTING - self.logger.info("Starting producer...") - - self._producer = AIOKafkaProducer( - bootstrap_servers=self._settings.KAFKA_BOOTSTRAP_SERVERS, - client_id=f"{self._settings.SERVICE_NAME}-producer", - acks="all", - compression_type="gzip", - max_batch_size=16384, - linger_ms=10, - enable_idempotence=True, - ) - - await self._producer.start() - self._state = ProducerState.RUNNING - self.logger.info(f"Producer started: {self._settings.KAFKA_BOOTSTRAP_SERVERS}") - - def get_status(self) -> dict[str, Any]: - return { - "state": self._state, - "running": self.is_running, - "config": { - "bootstrap_servers": self._settings.KAFKA_BOOTSTRAP_SERVERS, - "client_id": f"{self._settings.SERVICE_NAME}-producer", - }, - "metrics": { - "messages_sent": self._metrics.messages_sent, - "messages_failed": self._metrics.messages_failed, - "bytes_sent": self._metrics.bytes_sent, - "queue_size": self._metrics.queue_size, - "avg_latency_ms": self._metrics.avg_latency_ms, - "last_error": self._metrics.last_error, - "last_error_time": self._metrics.last_error_time.isoformat() if self._metrics.last_error_time else None, - }, - } - - async def _on_stop(self) -> None: - """Stop the Kafka producer.""" - self._state = ProducerState.STOPPING - self.logger.info("Stopping producer...") - - if self._producer: - await self._producer.stop() - self._producer = None - - self._state = ProducerState.STOPPED - self.logger.info("Producer stopped") + self._topic_prefix = topic_prefix + self._metrics = producer_metrics async def produce( self, event_to_produce: DomainEvent, key: str | None = None, headers: dict[str, str] | None = None ) -> None: """Produce a message to Kafka.""" - if not self._producer: - self.logger.error("Producer not running") - return + topic = f"{self._topic_prefix}{EVENT_TYPE_TO_TOPIC[event_to_produce.event_type]}" try: serialized_value = await self._schema_registry.serialize_event(event_to_produce) - topic = f"{self._topic_prefix}{EVENT_TYPE_TO_TOPIC[event_to_produce.event_type]}" # Convert headers to list of tuples format header_list = [(k, v.encode()) for k, v in headers.items()] if headers else None @@ -135,24 +73,20 @@ async def produce( # Record Kafka metrics self._event_metrics.record_kafka_message_produced(topic) - self.logger.debug(f"Message [{event_to_produce}] sent to topic: {topic}") + self._logger.debug(f"Message [{event_to_produce}] sent to topic: {topic}") except KafkaError as e: self._metrics.messages_failed += 1 self._metrics.last_error 
= str(e) self._metrics.last_error_time = datetime.now(timezone.utc) self._event_metrics.record_kafka_production_error(topic=topic, error_type=type(e).__name__) - self.logger.error(f"Failed to produce message: {e}") + self._logger.error(f"Failed to produce message: {e}") raise async def send_to_dlq( self, original_event: DomainEvent, original_topic: str, error: Exception, retry_count: int = 0 ) -> None: """Send a failed event to the Dead Letter Queue.""" - if not self._producer: - self.logger.error("Producer not running, cannot send to DLQ") - return - try: # Get producer ID (hostname + task name) current_task = asyncio.current_task() @@ -202,7 +136,7 @@ async def send_to_dlq( self._event_metrics.record_kafka_message_produced(dlq_topic) self._metrics.messages_sent += 1 - self.logger.warning( + self._logger.warning( f"Event {original_event.event_id} sent to DLQ. " f"Original topic: {original_topic}, Error: {error}, " f"Retry count: {retry_count}" @@ -210,7 +144,7 @@ async def send_to_dlq( except Exception as e: # If we can't send to DLQ, log critically but don't crash - self.logger.critical( + self._logger.critical( f"Failed to send event {original_event.event_id} to DLQ: {e}. Original error: {error}", exc_info=True ) self._metrics.messages_failed += 1 diff --git a/backend/app/events/event_store_consumer.py b/backend/app/events/event_store_consumer.py deleted file mode 100644 index 1dbdb83c..00000000 --- a/backend/app/events/event_store_consumer.py +++ /dev/null @@ -1,190 +0,0 @@ -import asyncio -import logging - -from opentelemetry.trace import SpanKind - -from app.core.lifecycle import LifecycleEnabled -from app.core.metrics import EventMetrics -from app.core.tracing.utils import trace_span -from app.domain.enums.events import EventType -from app.domain.enums.kafka import GroupId, KafkaTopic -from app.domain.events.typed import DomainEvent -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer, create_dlq_error_handler -from app.events.event_store import EventStore -from app.events.schema.schema_registry import SchemaRegistryManager -from app.settings import Settings - - -class EventStoreConsumer(LifecycleEnabled): - """Consumes events from Kafka and stores them in MongoDB.""" - - def __init__( - self, - event_store: EventStore, - topics: list[KafkaTopic], - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - logger: logging.Logger, - event_metrics: EventMetrics, - producer: UnifiedProducer | None = None, - group_id: GroupId = GroupId.EVENT_STORE_CONSUMER, - batch_size: int = 100, - batch_timeout_seconds: float = 5.0, - ): - super().__init__() - self.event_store = event_store - self.topics = topics - self.settings = settings - self.group_id = group_id - self.batch_size = batch_size - self.batch_timeout = batch_timeout_seconds - self.logger = logger - self.event_metrics = event_metrics - self.consumer: UnifiedConsumer | None = None - self.schema_registry_manager = schema_registry_manager - self.dispatcher = EventDispatcher(logger) - self.producer = producer # For DLQ handling - self._batch_buffer: list[DomainEvent] = [] - self._batch_lock = asyncio.Lock() - self._last_batch_time: float = 0.0 - self._batch_task: asyncio.Task[None] | None = None - - async def _on_start(self) -> None: - """Start consuming and storing events.""" - self._last_batch_time = asyncio.get_running_loop().time() - config = ConsumerConfig( - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=self.group_id, - enable_auto_commit=False, - 
max_poll_records=self.batch_size, - session_timeout_ms=self.settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - self.consumer = UnifiedConsumer( - config, - event_dispatcher=self.dispatcher, - schema_registry=self.schema_registry_manager, - settings=self.settings, - logger=self.logger, - event_metrics=self.event_metrics, - ) - - # Register handler for all event types - store everything - for event_type in EventType: - self.dispatcher.register(event_type)(self._handle_event) - - # Register error callback - use DLQ if producer is available - if self.producer: - # Use DLQ handler with retry logic - dlq_handler = create_dlq_error_handler( - producer=self.producer, - original_topic="event-store", # Generic topic name for event store - logger=self.logger, - max_retries=3, - ) - self.consumer.register_error_callback(dlq_handler) - else: - # Fallback to simple logging - self.consumer.register_error_callback(self._handle_error_with_event) - - await self.consumer.start(self.topics) - - self._batch_task = asyncio.create_task(self._batch_processor()) - - self.logger.info(f"Event store consumer started for topics: {self.topics}") - - async def _on_stop(self) -> None: - """Stop consumer.""" - await self._flush_batch() - - if self._batch_task: - self._batch_task.cancel() - try: - await self._batch_task - except asyncio.CancelledError: - pass - - if self.consumer: - await self.consumer.stop() - - self.logger.info("Event store consumer stopped") - - async def _handle_event(self, event: DomainEvent) -> None: - """Handle incoming event from dispatcher.""" - self.logger.info(f"Event store received event: {event.event_type} - {event.event_id}") - - async with self._batch_lock: - self._batch_buffer.append(event) - - if len(self._batch_buffer) >= self.batch_size: - await self._flush_batch() - - async def _handle_error_with_event(self, error: Exception, event: DomainEvent) -> None: - """Handle processing errors with event context.""" - self.logger.error(f"Error processing event {event.event_id} ({event.event_type}): {error}", exc_info=True) - - async def _batch_processor(self) -> None: - """Periodically flush batches based on timeout.""" - while self.is_running: - try: - await asyncio.sleep(1) - - async with self._batch_lock: - time_since_last_batch = asyncio.get_running_loop().time() - self._last_batch_time - - if self._batch_buffer and time_since_last_batch >= self.batch_timeout: - await self._flush_batch() - - except Exception as e: - self.logger.error(f"Error in batch processor: {e}") - - async def _flush_batch(self) -> None: - if not self._batch_buffer: - return - - batch = self._batch_buffer.copy() - self._batch_buffer.clear() - self._last_batch_time = asyncio.get_running_loop().time() - - self.logger.info(f"Event store flushing batch of {len(batch)} events") - with trace_span( - name="event_store.flush_batch", - kind=SpanKind.CONSUMER, - attributes={"events.batch.count": len(batch)}, - ): - results = await self.event_store.store_batch(batch) - - self.logger.info( - f"Stored event batch: total={results['total']}, " - f"stored={results['stored']}, duplicates={results['duplicates']}, " - f"failed={results['failed']}" - ) - - -def create_event_store_consumer( - event_store: EventStore, - topics: list[KafkaTopic], - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - logger: logging.Logger, - event_metrics: 
EventMetrics, - producer: UnifiedProducer | None = None, - group_id: GroupId = GroupId.EVENT_STORE_CONSUMER, - batch_size: int = 100, - batch_timeout_seconds: float = 5.0, -) -> EventStoreConsumer: - return EventStoreConsumer( - event_store=event_store, - topics=topics, - group_id=group_id, - batch_size=batch_size, - batch_timeout_seconds=batch_timeout_seconds, - schema_registry_manager=schema_registry_manager, - settings=settings, - logger=logger, - event_metrics=event_metrics, - producer=producer, - ) diff --git a/backend/app/services/coordinator/__init__.py b/backend/app/services/coordinator/__init__.py index b3890c9d..2c79d4a3 100644 --- a/backend/app/services/coordinator/__init__.py +++ b/backend/app/services/coordinator/__init__.py @@ -1,11 +1,5 @@ from app.services.coordinator.coordinator import ExecutionCoordinator -from app.services.coordinator.queue_manager import QueueManager, QueuePriority -from app.services.coordinator.resource_manager import ResourceAllocation, ResourceManager __all__ = [ "ExecutionCoordinator", - "QueueManager", - "QueuePriority", - "ResourceManager", - "ResourceAllocation", ] diff --git a/backend/app/services/coordinator/coordinator.py b/backend/app/services/coordinator/coordinator.py index e9e0591b..b8c9e6e5 100644 --- a/backend/app/services/coordinator/coordinator.py +++ b/backend/app/services/coordinator/coordinator.py @@ -1,15 +1,24 @@ -import asyncio +"""Execution Coordinator - stateless event handler. + +Coordinates execution scheduling across the system. Receives events, +processes them, and publishes results. No lifecycle management. +All state is stored in Redis repositories. +""" + +from __future__ import annotations + import logging import time -from collections.abc import Coroutine -from typing import Any, TypeAlias from uuid import uuid4 -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import CoordinatorMetrics, EventMetrics +from app.db.repositories import ( + ExecutionQueueRepository, + ExecutionStateRepository, + QueuePriority, + ResourceRepository, +) from app.db.repositories.execution_repository import ExecutionRepository -from app.domain.enums.events import EventType -from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId from app.domain.enums.storage import ExecutionErrorType from app.domain.events.typed import ( CreatePodCommandEvent, @@ -20,392 +29,219 @@ ExecutionFailedEvent, ExecutionRequestedEvent, ) -from app.domain.idempotency import KeyStrategy -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import ( - SchemaRegistryManager, -) -from app.services.coordinator.queue_manager import QueueManager, QueuePriority -from app.services.coordinator.resource_manager import ResourceAllocation, ResourceManager -from app.services.idempotency import IdempotencyManager -from app.services.idempotency.middleware import IdempotentConsumerWrapper -from app.settings import Settings +from app.events.core import UnifiedProducer -EventHandler: TypeAlias = Coroutine[Any, Any, None] -ExecutionMap: TypeAlias = dict[str, ResourceAllocation] +class ExecutionCoordinator: + """Stateless execution coordinator - pure event handler. -class ExecutionCoordinator(LifecycleEnabled): - """ - Coordinates execution scheduling across the system. - - This service: - 1. Consumes ExecutionRequested events - 2. Manages execution queue with priority - 3. Enforces rate limits - 4. Allocates resources - 5. 
Publishes ExecutionStarted events for workers + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + All state (active executions, queue, resources) stored in Redis. """ def __init__( self, producer: UnifiedProducer, - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, execution_repository: ExecutionRepository, - idempotency_manager: IdempotencyManager, + state_repo: ExecutionStateRepository, + queue_repo: ExecutionQueueRepository, + resource_repo: ResourceRepository, logger: logging.Logger, coordinator_metrics: CoordinatorMetrics, event_metrics: EventMetrics, - consumer_group: str = GroupId.EXECUTION_COORDINATOR, - max_concurrent_scheduling: int = 10, - scheduling_interval_seconds: float = 0.5, - ): - super().__init__() - self.logger = logger - self.metrics = coordinator_metrics + ) -> None: + self._producer = producer + self._execution_repository = execution_repository + self._state_repo = state_repo + self._queue_repo = queue_repo + self._resource_repo = resource_repo + self._logger = logger + self._metrics = coordinator_metrics self._event_metrics = event_metrics - self._settings = settings - - # Kafka configuration - self.kafka_servers = self._settings.KAFKA_BOOTSTRAP_SERVERS - self.consumer_group = consumer_group - - # Components - self.queue_manager = QueueManager( - logger=self.logger, - coordinator_metrics=coordinator_metrics, - max_queue_size=10000, - max_executions_per_user=100, - stale_timeout_seconds=3600, - ) - - self.resource_manager = ResourceManager( - logger=self.logger, - coordinator_metrics=coordinator_metrics, - total_cpu_cores=32.0, - total_memory_mb=65536, - total_gpu_count=0, - ) - - # Kafka components - self.consumer: UnifiedConsumer | None = None - self.idempotent_consumer: IdempotentConsumerWrapper | None = None - self.producer: UnifiedProducer = producer - - # Persistence via repositories - self.execution_repository = execution_repository - self.idempotency_manager = idempotency_manager - self._event_store = event_store - - # Scheduling - self.max_concurrent_scheduling = max_concurrent_scheduling - self.scheduling_interval = scheduling_interval_seconds - self._scheduling_semaphore = asyncio.Semaphore(max_concurrent_scheduling) - - # State tracking - self._scheduling_task: asyncio.Task[None] | None = None - self._active_executions: set[str] = set() - self._execution_resources: ExecutionMap = {} - self._schema_registry_manager = schema_registry_manager - self.dispatcher = EventDispatcher(logger=self.logger) - - async def _on_start(self) -> None: - """Start the coordinator service.""" - self.logger.info("Starting ExecutionCoordinator service...") - - await self.queue_manager.start() - - await self.idempotency_manager.initialize() - - consumer_config = ConsumerConfig( - bootstrap_servers=self.kafka_servers, - group_id=self.consumer_group, - enable_auto_commit=False, - session_timeout_ms=self._settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self._settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self._settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self._settings.KAFKA_REQUEST_TIMEOUT_MS, - max_poll_records=100, # Process max 100 messages at a time for flow control - fetch_max_wait_ms=500, # Wait max 500ms for data (reduces latency) - fetch_min_bytes=1, # Return immediately if any data available - ) - - self.consumer = UnifiedConsumer( - consumer_config, - event_dispatcher=self.dispatcher, - schema_registry=self._schema_registry_manager, - 
settings=self._settings, - logger=self.logger, - event_metrics=self._event_metrics, - ) - - # Register handlers with EventDispatcher BEFORE wrapping with idempotency - @self.dispatcher.register(EventType.EXECUTION_REQUESTED) - async def handle_requested(event: ExecutionRequestedEvent) -> None: - await self._route_execution_event(event) - - @self.dispatcher.register(EventType.EXECUTION_COMPLETED) - async def handle_completed(event: ExecutionCompletedEvent) -> None: - await self._route_execution_result(event) - - @self.dispatcher.register(EventType.EXECUTION_FAILED) - async def handle_failed(event: ExecutionFailedEvent) -> None: - await self._route_execution_result(event) - - @self.dispatcher.register(EventType.EXECUTION_CANCELLED) - async def handle_cancelled(event: ExecutionCancelledEvent) -> None: - await self._route_execution_event(event) - - self.idempotent_consumer = IdempotentConsumerWrapper( - consumer=self.consumer, - idempotency_manager=self.idempotency_manager, - dispatcher=self.dispatcher, - logger=self.logger, - default_key_strategy=KeyStrategy.EVENT_BASED, # Use event ID for deduplication - default_ttl_seconds=7200, # 2 hours TTL for coordinator events - enable_for_all_handlers=True, # Enable idempotency for ALL handlers - ) - self.logger.info("COORDINATOR: Event handlers registered with idempotency protection") - - await self.idempotent_consumer.start(list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.EXECUTION_COORDINATOR])) - - # Start scheduling task - self._scheduling_task = asyncio.create_task(self._scheduling_loop()) - - self.logger.info("ExecutionCoordinator service started successfully") - - async def _on_stop(self) -> None: - """Stop the coordinator service.""" - self.logger.info("Stopping ExecutionCoordinator service...") - - # Stop scheduling task - if self._scheduling_task: - self._scheduling_task.cancel() - try: - await self._scheduling_task - except asyncio.CancelledError: - pass - - # Stop consumer (idempotent wrapper only) - if self.idempotent_consumer: - await self.idempotent_consumer.stop() - - await self.queue_manager.stop() - - # Close idempotency manager - if hasattr(self, "idempotency_manager") and self.idempotency_manager: - await self.idempotency_manager.close() - - self.logger.info(f"ExecutionCoordinator service stopped. 
Active executions: {len(self._active_executions)}") - - async def _route_execution_event(self, event: ExecutionRequestedEvent | ExecutionCancelledEvent) -> None: - """Route execution events to appropriate handlers based on event type""" - self.logger.info( - f"COORDINATOR: Routing execution event - type: {event.event_type}, " - f"id: {event.event_id}, " - f"actual class: {type(event).__name__}" - ) - - if event.event_type == EventType.EXECUTION_REQUESTED: - await self._handle_execution_requested(event) - elif event.event_type == EventType.EXECUTION_CANCELLED: - await self._handle_execution_cancelled(event) - else: - self.logger.debug(f"Ignoring execution event type: {event.event_type}") - - async def _route_execution_result(self, event: ExecutionCompletedEvent | ExecutionFailedEvent) -> None: - """Route execution result events to appropriate handlers based on event type""" - if event.event_type == EventType.EXECUTION_COMPLETED: - await self._handle_execution_completed(event) - elif event.event_type == EventType.EXECUTION_FAILED: - await self._handle_execution_failed(event) - else: - self.logger.debug(f"Ignoring execution result event type: {event.event_type}") - - async def _handle_execution_requested(self, event: ExecutionRequestedEvent) -> None: - """Handle execution requested event - add to queue for processing""" - self.logger.info(f"HANDLER CALLED: _handle_execution_requested for event {event.event_id}") + async def handle_execution_requested(self, event: ExecutionRequestedEvent) -> None: + """Handle execution requested event - add to queue and try to schedule.""" + self._logger.info(f"Handling ExecutionRequestedEvent: {event.execution_id}") start_time = time.time() try: - # Add to queue with priority - success, position, error = await self.queue_manager.add_execution( - event, - priority=QueuePriority(event.priority), + priority = QueuePriority(event.priority) + user_id = event.metadata.user_id or "anonymous" + + # Add to Redis queue + success, position, error = await self._queue_repo.enqueue( + execution_id=event.execution_id, + event_data=event.model_dump(mode="json"), + priority=priority, + user_id=user_id, ) if not success: - # Publish queue full event await self._publish_queue_full(event, error or "Queue is full") - self.metrics.record_coordinator_execution_scheduled("queue_full") + self._metrics.record_coordinator_execution_scheduled("queue_full") return # Publish ExecutionAcceptedEvent - if position is None: - position = 0 - await self._publish_execution_accepted(event, position, event.priority) + await self._publish_execution_accepted(event, position or 0, event.priority) # Track metrics duration = time.time() - start_time - self.metrics.record_coordinator_scheduling_duration(duration) - self.metrics.record_coordinator_execution_scheduled("queued") + self._metrics.record_coordinator_scheduling_duration(duration) + self._metrics.record_coordinator_execution_scheduled("queued") - self.logger.info(f"Execution {event.execution_id} added to queue at position {position}") + self._logger.info(f"Execution {event.execution_id} added to queue at position {position}") - # Schedule immediately if at front of queue (position 0) + # If at front of queue (position 0), try to schedule immediately if position == 0: - await self._schedule_execution(event) + await self._try_schedule_next() except Exception as e: - self.logger.error(f"Failed to handle execution request {event.execution_id}: {e}", exc_info=True) - self.metrics.record_coordinator_execution_scheduled("error") + 
self._logger.error(f"Failed to handle execution request {event.execution_id}: {e}", exc_info=True) + self._metrics.record_coordinator_execution_scheduled("error") - async def _handle_execution_cancelled(self, event: ExecutionCancelledEvent) -> None: - """Handle execution cancelled event""" + async def handle_execution_completed(self, event: ExecutionCompletedEvent) -> None: + """Handle execution completed - release resources and try to schedule next.""" execution_id = event.execution_id + self._logger.info(f"Handling ExecutionCompletedEvent: {execution_id}") + + # Release resources + await self._resource_repo.release(execution_id) - removed = await self.queue_manager.remove_execution(execution_id) + # Remove from active state + await self._state_repo.remove(execution_id) - if execution_id in self._execution_resources: - await self.resource_manager.release_allocation(execution_id) - del self._execution_resources[execution_id] + # Update metrics + count = await self._state_repo.get_active_count() + self._metrics.update_coordinator_active_executions(count) - self._active_executions.discard(execution_id) - self.metrics.update_coordinator_active_executions(len(self._active_executions)) + self._logger.info(f"Execution {execution_id} completed, resources released") - if removed: - self.logger.info(f"Execution {execution_id} cancelled and removed from queue") + # Try to schedule next execution from queue + await self._try_schedule_next() - async def _handle_execution_completed(self, event: ExecutionCompletedEvent) -> None: - """Handle execution completed event""" + async def handle_execution_failed(self, event: ExecutionFailedEvent) -> None: + """Handle execution failed - release resources and try to schedule next.""" execution_id = event.execution_id + self._logger.info(f"Handling ExecutionFailedEvent: {execution_id}") - if execution_id in self._execution_resources: - await self.resource_manager.release_allocation(execution_id) - del self._execution_resources[execution_id] + # Release resources + await self._resource_repo.release(execution_id) + + # Remove from active state + await self._state_repo.remove(execution_id) - # Remove from active set - self._active_executions.discard(execution_id) - self.metrics.update_coordinator_active_executions(len(self._active_executions)) + # Update metrics + count = await self._state_repo.get_active_count() + self._metrics.update_coordinator_active_executions(count) - self.logger.info(f"Execution {execution_id} completed, resources released") + # Try to schedule next execution from queue + await self._try_schedule_next() - async def _handle_execution_failed(self, event: ExecutionFailedEvent) -> None: - """Handle execution failed event""" + async def handle_execution_cancelled(self, event: ExecutionCancelledEvent) -> None: + """Handle execution cancelled - remove from queue and release resources.""" execution_id = event.execution_id + self._logger.info(f"Handling ExecutionCancelledEvent: {execution_id}") - # Release resources - if execution_id in self._execution_resources: - await self.resource_manager.release_allocation(execution_id) - del self._execution_resources[execution_id] - - # Remove from active set - self._active_executions.discard(execution_id) - self.metrics.update_coordinator_active_executions(len(self._active_executions)) - - async def _scheduling_loop(self) -> None: - """Main scheduling loop""" - while self.is_running: - try: - # Get next execution from queue - execution = await self.queue_manager.get_next_execution() - - if execution: - # 
Schedule execution - asyncio.create_task(self._schedule_execution(execution)) - else: - # No executions in queue, wait - await asyncio.sleep(self.scheduling_interval) - - except Exception as e: - self.logger.error(f"Error in scheduling loop: {e}", exc_info=True) - await asyncio.sleep(5) # Wait before retrying + # Remove from queue if present + await self._queue_repo.remove(execution_id) + + # Release resources if allocated + await self._resource_repo.release(execution_id) + + # Remove from active state + await self._state_repo.remove(execution_id) + + # Update metrics + count = await self._state_repo.get_active_count() + self._metrics.update_coordinator_active_executions(count) + + async def _try_schedule_next(self) -> None: + """Try to schedule the next execution from the queue.""" + result = await self._queue_repo.dequeue() + if not result: + return + + execution_id, event_data = result + + # Reconstruct event from stored data + try: + event = ExecutionRequestedEvent.model_validate(event_data) + await self._schedule_execution(event) + except Exception as e: + self._logger.error(f"Failed to schedule execution {execution_id}: {e}", exc_info=True) async def _schedule_execution(self, event: ExecutionRequestedEvent) -> None: - """Schedule a single execution""" - async with self._scheduling_semaphore: - start_time = time.time() - execution_id = event.execution_id - - # Atomic check-and-claim: no await between check and add prevents TOCTOU race - # when both eager scheduling (position=0) and _scheduling_loop try to schedule - if execution_id in self._active_executions: - self.logger.debug(f"Execution {execution_id} already claimed, skipping") - return - self._active_executions.add(execution_id) - - try: - # Request resource allocation - allocation = await self.resource_manager.request_allocation( - execution_id, - event.language, - requested_cpu=None, # Use defaults for now - requested_memory_mb=None, - requested_gpu=0, - ) + """Schedule a single execution - allocate resources and publish command.""" + start_time = time.time() + execution_id = event.execution_id + + # Try to claim this execution atomically + claimed = await self._state_repo.try_claim(execution_id) + if not claimed: + self._logger.debug(f"Execution {execution_id} already claimed, skipping") + return - if not allocation: - # No resources available, release claim and requeue - self._active_executions.discard(execution_id) - await self.queue_manager.requeue_execution(event, increment_retry=False) - self.logger.info(f"No resources available for {execution_id}, requeued") - return - - # Track allocation (already in _active_executions from claim above) - self._execution_resources[execution_id] = allocation - self.metrics.update_coordinator_active_executions(len(self._active_executions)) - - # Publish execution started event for workers - self.logger.info(f"About to publish ExecutionStartedEvent for {event.execution_id}") - try: - await self._publish_execution_started(event) - self.logger.info(f"Successfully published ExecutionStartedEvent for {event.execution_id}") - except Exception as publish_error: - self.logger.error( - f"Failed to publish ExecutionStartedEvent for {event.execution_id}: {publish_error}", - exc_info=True, - ) - raise - - # Track metrics - queue_time = start_time - event.timestamp.timestamp() - priority = getattr(event, "priority", QueuePriority.NORMAL.value) - self.metrics.record_coordinator_queue_time(queue_time, QueuePriority(priority).name) - - scheduling_duration = time.time() - start_time - 
self.metrics.record_coordinator_scheduling_duration(scheduling_duration) - self.metrics.record_coordinator_execution_scheduled("scheduled") - - self.logger.info( - f"Scheduled execution {event.execution_id}. " - f"Queue time: {queue_time:.2f}s, " - f"Resources: {allocation.cpu_cores} CPU, " - f"{allocation.memory_mb}MB RAM" + try: + # Allocate resources + allocation = await self._resource_repo.allocate( + execution_id=execution_id, + language=event.language, + requested_cpu=None, + requested_memory_mb=None, + requested_gpu=0, + ) + + if not allocation: + # No resources available, release claim and requeue + await self._state_repo.remove(execution_id) + await self._queue_repo.enqueue( + execution_id=event.execution_id, + event_data=event.model_dump(mode="json"), + priority=QueuePriority(event.priority), + user_id=event.metadata.user_id or "anonymous", ) + self._logger.info(f"No resources available for {execution_id}, requeued") + return + + # Update metrics + count = await self._state_repo.get_active_count() + self._metrics.update_coordinator_active_executions(count) - except Exception as e: - self.logger.error(f"Failed to schedule execution {event.execution_id}: {e}", exc_info=True) + # Publish CreatePodCommand + await self._publish_execution_started(event) - # Release any allocated resources - if event.execution_id in self._execution_resources: - await self.resource_manager.release_allocation(event.execution_id) - del self._execution_resources[event.execution_id] + # Track metrics + queue_time = start_time - event.timestamp.timestamp() + priority = QueuePriority(event.priority) + self._metrics.record_coordinator_queue_time(queue_time, priority.name) + + scheduling_duration = time.time() - start_time + self._metrics.record_coordinator_scheduling_duration(scheduling_duration) + self._metrics.record_coordinator_execution_scheduled("scheduled") + + self._logger.info( + f"Scheduled execution {event.execution_id}. 
" + f"Queue time: {queue_time:.2f}s, " + f"Resources: {allocation.cpu_cores} CPU, {allocation.memory_mb}MB RAM" + ) + + except Exception as e: + self._logger.error(f"Failed to schedule execution {event.execution_id}: {e}", exc_info=True) - self._active_executions.discard(event.execution_id) - self.metrics.update_coordinator_active_executions(len(self._active_executions)) - self.metrics.record_coordinator_execution_scheduled("error") + # Release resources and claim + await self._resource_repo.release(execution_id) + await self._state_repo.remove(execution_id) - # Publish failure event - await self._publish_scheduling_failed(event, str(e)) + count = await self._state_repo.get_active_count() + self._metrics.update_coordinator_active_executions(count) + self._metrics.record_coordinator_execution_scheduled("error") + + # Publish failure event + await self._publish_scheduling_failed(event, str(e)) async def _build_command_metadata(self, request: ExecutionRequestedEvent) -> EventMetadata: """Build metadata for CreatePodCommandEvent with guaranteed user_id.""" - # Prefer execution record user_id to avoid missing attribution - exec_rec = await self.execution_repository.get_execution(request.execution_id) + exec_rec = await self._execution_repository.get_execution(request.execution_id) user_id: str = exec_rec.user_id if exec_rec and exec_rec.user_id else "system" return EventMetadata( @@ -416,7 +252,7 @@ async def _build_command_metadata(self, request: ExecutionRequestedEvent) -> Eve ) async def _publish_execution_started(self, request: ExecutionRequestedEvent) -> None: - """Send CreatePodCommandEvent to k8s-worker via SAGA_COMMANDS topic""" + """Send CreatePodCommandEvent to k8s-worker via SAGA_COMMANDS topic.""" metadata = await self._build_command_metadata(request) create_pod_cmd = CreatePodCommandEvent( @@ -437,64 +273,54 @@ async def _publish_execution_started(self, request: ExecutionRequestedEvent) -> metadata=metadata, ) - await self.producer.produce(event_to_produce=create_pod_cmd, key=request.execution_id) - - async def _publish_execution_accepted(self, request: ExecutionRequestedEvent, position: int, priority: int) -> None: - """Publish execution accepted event to notify that request was valid and queued""" - self.logger.info(f"Publishing ExecutionAcceptedEvent for execution {request.execution_id}") + await self._producer.produce(event_to_produce=create_pod_cmd, key=request.execution_id) + self._logger.info(f"Published CreatePodCommandEvent for {request.execution_id}") + async def _publish_execution_accepted( + self, request: ExecutionRequestedEvent, position: int, priority: int + ) -> None: + """Publish execution accepted event.""" event = ExecutionAcceptedEvent( execution_id=request.execution_id, queue_position=position, - estimated_wait_seconds=None, # Could calculate based on queue analysis + estimated_wait_seconds=None, priority=priority, metadata=request.metadata, ) - await self.producer.produce(event_to_produce=event) - self.logger.info(f"ExecutionAcceptedEvent published for {request.execution_id}") + await self._producer.produce(event_to_produce=event) + self._logger.info(f"ExecutionAcceptedEvent published for {request.execution_id}") async def _publish_queue_full(self, request: ExecutionRequestedEvent, error: str) -> None: - """Publish queue full event""" - # Get queue stats for context - queue_stats = await self.queue_manager.get_queue_stats() + """Publish queue full event.""" + queue_stats = await self._queue_repo.get_stats() event = ExecutionFailedEvent( 
execution_id=request.execution_id, error_type=ExecutionErrorType.RESOURCE_LIMIT, exit_code=-1, - stderr=f"Queue full: {error}. Queue size: {queue_stats.get('total_size', 'unknown')}", + stderr=f"Queue full: {error}. Queue size: {queue_stats.total_size}", resource_usage=None, metadata=request.metadata, error_message=error, ) - await self.producer.produce(event_to_produce=event, key=request.execution_id) + await self._producer.produce(event_to_produce=event, key=request.execution_id) async def _publish_scheduling_failed(self, request: ExecutionRequestedEvent, error: str) -> None: - """Publish scheduling failed event""" - # Get resource stats for context - resource_stats = await self.resource_manager.get_resource_stats() + """Publish scheduling failed event.""" + resource_stats = await self._resource_repo.get_stats() event = ExecutionFailedEvent( execution_id=request.execution_id, error_type=ExecutionErrorType.SYSTEM_ERROR, exit_code=-1, stderr=f"Failed to schedule execution: {error}. " - f"Available resources: CPU={resource_stats.available.cpu_cores}, " - f"Memory={resource_stats.available.memory_mb}MB", + f"Available resources: CPU={resource_stats.available_cpu}, " + f"Memory={resource_stats.available_memory_mb}MB", resource_usage=None, metadata=request.metadata, error_message=error, ) - await self.producer.produce(event_to_produce=event, key=request.execution_id) - - async def get_status(self) -> dict[str, Any]: - """Get coordinator status""" - return { - "running": self.is_running, - "active_executions": len(self._active_executions), - "queue_stats": await self.queue_manager.get_queue_stats(), - "resource_stats": await self.resource_manager.get_resource_stats(), - } + await self._producer.produce(event_to_produce=event, key=request.execution_id) diff --git a/backend/app/services/coordinator/queue_manager.py b/backend/app/services/coordinator/queue_manager.py deleted file mode 100644 index 8dab2643..00000000 --- a/backend/app/services/coordinator/queue_manager.py +++ /dev/null @@ -1,271 +0,0 @@ -import asyncio -import heapq -import logging -import time -from collections import defaultdict -from dataclasses import dataclass, field -from enum import IntEnum -from typing import Any - -from app.core.metrics import CoordinatorMetrics -from app.domain.events.typed import ExecutionRequestedEvent - - -class QueuePriority(IntEnum): - CRITICAL = 0 - HIGH = 1 - NORMAL = 5 - LOW = 8 - BACKGROUND = 10 - - -@dataclass(order=True) -class QueuedExecution: - priority: int - timestamp: float = field(compare=False) - event: ExecutionRequestedEvent = field(compare=False) - retry_count: int = field(default=0, compare=False) - - @property - def execution_id(self) -> str: - return self.event.execution_id - - @property - def user_id(self) -> str: - return self.event.metadata.user_id or "anonymous" - - @property - def age_seconds(self) -> float: - return time.time() - self.timestamp - - -class QueueManager: - def __init__( - self, - logger: logging.Logger, - coordinator_metrics: CoordinatorMetrics, - max_queue_size: int = 10000, - max_executions_per_user: int = 100, - stale_timeout_seconds: int = 3600, - ) -> None: - self.logger = logger - self.metrics = coordinator_metrics - self.max_queue_size = max_queue_size - self.max_executions_per_user = max_executions_per_user - self.stale_timeout_seconds = stale_timeout_seconds - - self._queue: list[QueuedExecution] = [] - self._queue_lock = asyncio.Lock() - self._user_execution_count: dict[str, int] = defaultdict(int) - self._execution_users: dict[str, str] = {} - 
self._cleanup_task: asyncio.Task[None] | None = None - self._running = False - - async def start(self) -> None: - if self._running: - return - - self._running = True - self._cleanup_task = asyncio.create_task(self._cleanup_stale_executions()) - self.logger.info("Queue manager started") - - async def stop(self) -> None: - if not self._running: - return - - self._running = False - - if self._cleanup_task: - self._cleanup_task.cancel() - try: - await self._cleanup_task - except asyncio.CancelledError: - pass - - self.logger.info(f"Queue manager stopped. Final queue size: {len(self._queue)}") - - async def add_execution( - self, event: ExecutionRequestedEvent, priority: QueuePriority | None = None - ) -> tuple[bool, int | None, str | None]: - async with self._queue_lock: - if len(self._queue) >= self.max_queue_size: - return False, None, "Queue is full" - - user_id = event.metadata.user_id or "anonymous" - - if self._user_execution_count[user_id] >= self.max_executions_per_user: - return False, None, f"User execution limit exceeded ({self.max_executions_per_user})" - - if priority is None: - priority = QueuePriority(event.priority) - - queued = QueuedExecution(priority=priority.value, timestamp=time.time(), event=event) - - heapq.heappush(self._queue, queued) - self._track_execution(event.execution_id, user_id) - position = self._get_queue_position(event.execution_id) - - # Update single authoritative metric for execution request queue depth - self.metrics.update_execution_request_queue_size(len(self._queue)) - - self.logger.info( - f"Added execution {event.execution_id} to queue. " - f"Priority: {priority.name}, Position: {position}, " - f"Queue size: {len(self._queue)}" - ) - - return True, position, None - - async def get_next_execution(self) -> ExecutionRequestedEvent | None: - async with self._queue_lock: - while self._queue: - queued = heapq.heappop(self._queue) - - if self._is_stale(queued): - self._untrack_execution(queued.execution_id) - self._record_removal("stale") - continue - - self._untrack_execution(queued.execution_id) - self._record_wait_time(queued) - # Update metric after removal from the queue - self.metrics.update_execution_request_queue_size(len(self._queue)) - - self.logger.info( - f"Retrieved execution {queued.execution_id} from queue. 
" - f"Wait time: {queued.age_seconds:.2f}s, Queue size: {len(self._queue)}" - ) - - return queued.event - - return None - - async def remove_execution(self, execution_id: str) -> bool: - async with self._queue_lock: - initial_size = len(self._queue) - self._queue = [q for q in self._queue if q.execution_id != execution_id] - - if len(self._queue) < initial_size: - heapq.heapify(self._queue) - self._untrack_execution(execution_id) - # Update metric after explicit removal - self.metrics.update_execution_request_queue_size(len(self._queue)) - self.logger.info(f"Removed execution {execution_id} from queue") - return True - - return False - - async def get_queue_position(self, execution_id: str) -> int | None: - async with self._queue_lock: - return self._get_queue_position(execution_id) - - async def get_queue_stats(self) -> dict[str, Any]: - async with self._queue_lock: - priority_counts: dict[str, int] = defaultdict(int) - user_counts: dict[str, int] = defaultdict(int) - - for queued in self._queue: - priority_name = QueuePriority(queued.priority).name - priority_counts[priority_name] += 1 - user_counts[queued.user_id] += 1 - - top_users = dict(sorted(user_counts.items(), key=lambda x: x[1], reverse=True)[:10]) - - return { - "total_size": len(self._queue), - "priority_distribution": dict(priority_counts), - "top_users": top_users, - "max_queue_size": self.max_queue_size, - "utilization_percent": (len(self._queue) / self.max_queue_size) * 100, - } - - async def requeue_execution( - self, event: ExecutionRequestedEvent, increment_retry: bool = True - ) -> tuple[bool, int | None, str | None]: - def _next_lower(p: QueuePriority) -> QueuePriority: - order = [ - QueuePriority.CRITICAL, - QueuePriority.HIGH, - QueuePriority.NORMAL, - QueuePriority.LOW, - QueuePriority.BACKGROUND, - ] - try: - idx = order.index(p) - except ValueError: - # Fallback: treat unknown numeric as NORMAL - idx = order.index(QueuePriority.NORMAL) - return order[min(idx + 1, len(order) - 1)] - - if increment_retry: - original_priority = QueuePriority(event.priority) - new_priority = _next_lower(original_priority) - else: - new_priority = QueuePriority(event.priority) - - return await self.add_execution(event, priority=new_priority) - - def _get_queue_position(self, execution_id: str) -> int | None: - for position, queued in enumerate(self._queue): - if queued.execution_id == execution_id: - return position - return None - - def _is_stale(self, queued: QueuedExecution) -> bool: - return queued.age_seconds > self.stale_timeout_seconds - - def _track_execution(self, execution_id: str, user_id: str) -> None: - self._user_execution_count[user_id] += 1 - self._execution_users[execution_id] = user_id - - def _untrack_execution(self, execution_id: str) -> None: - if execution_id in self._execution_users: - user_id = self._execution_users.pop(execution_id) - self._user_execution_count[user_id] -= 1 - if self._user_execution_count[user_id] <= 0: - del self._user_execution_count[user_id] - - def _record_removal(self, reason: str) -> None: - # No-op: we keep a single queue depth metric and avoid operation counters - return - - def _record_wait_time(self, queued: QueuedExecution) -> None: - self.metrics.record_queue_wait_time_by_priority( - queued.age_seconds, QueuePriority(queued.priority).name, "default" - ) - - def _update_add_metrics(self, priority: QueuePriority) -> None: - # Deprecated in favor of single execution queue depth metric - self.metrics.update_execution_request_queue_size(len(self._queue)) - - def 
_update_queue_size(self) -> None: - self.metrics.update_execution_request_queue_size(len(self._queue)) - - async def _cleanup_stale_executions(self) -> None: - while self._running: - try: - await asyncio.sleep(300) - - async with self._queue_lock: - stale_executions = [] - active_executions = [] - - for queued in self._queue: - if self._is_stale(queued): - stale_executions.append(queued) - else: - active_executions.append(queued) - - if stale_executions: - self._queue = active_executions - heapq.heapify(self._queue) - - for queued in stale_executions: - self._untrack_execution(queued.execution_id) - - # Update metric after stale cleanup - self.metrics.update_execution_request_queue_size(len(self._queue)) - self.logger.info(f"Cleaned {len(stale_executions)} stale executions from queue") - - except Exception as e: - self.logger.error(f"Error in queue cleanup: {e}") diff --git a/backend/app/services/coordinator/resource_manager.py b/backend/app/services/coordinator/resource_manager.py deleted file mode 100644 index bd0c2fbf..00000000 --- a/backend/app/services/coordinator/resource_manager.py +++ /dev/null @@ -1,324 +0,0 @@ -import asyncio -import logging -from dataclasses import dataclass - -from app.core.metrics import CoordinatorMetrics - - -@dataclass -class ResourceAllocation: - """Resource allocation for an execution""" - - cpu_cores: float - memory_mb: int - gpu_count: int = 0 - - @property - def cpu_millicores(self) -> int: - """Get CPU in millicores for Kubernetes""" - return int(self.cpu_cores * 1000) - - @property - def memory_bytes(self) -> int: - """Get memory in bytes""" - return self.memory_mb * 1024 * 1024 - - -@dataclass -class ResourcePool: - """Available resource pool""" - - total_cpu_cores: float - total_memory_mb: int - total_gpu_count: int - - available_cpu_cores: float - available_memory_mb: int - available_gpu_count: int - - # Resource limits per execution - max_cpu_per_execution: float = 4.0 - max_memory_per_execution_mb: int = 8192 - max_gpu_per_execution: int = 1 - - # Minimum resources to keep available - min_available_cpu_cores: float = 2.0 - min_available_memory_mb: int = 4096 - - -@dataclass -class ResourceGroup: - """Resource group with usage information""" - - cpu_cores: float - memory_mb: int - gpu_count: int - - -@dataclass -class ResourceStats: - """Resource statistics""" - - total: ResourceGroup - available: ResourceGroup - allocated: ResourceGroup - utilization: dict[str, float] - allocation_count: int - limits: dict[str, int | float] - - -@dataclass -class ResourceAllocationInfo: - """Information about a resource allocation""" - - execution_id: str - cpu_cores: float - memory_mb: int - gpu_count: int - cpu_percentage: float - memory_percentage: float - - -class ResourceManager: - """Manages resource allocation for executions""" - - def __init__( - self, - logger: logging.Logger, - coordinator_metrics: CoordinatorMetrics, - total_cpu_cores: float = 32.0, - total_memory_mb: int = 65536, # 64GB - total_gpu_count: int = 0, - overcommit_factor: float = 1.2, # Allow 20% overcommit - ): - self.logger = logger - self.metrics = coordinator_metrics - self.pool = ResourcePool( - total_cpu_cores=total_cpu_cores * overcommit_factor, - total_memory_mb=int(total_memory_mb * overcommit_factor), - total_gpu_count=total_gpu_count, - available_cpu_cores=total_cpu_cores * overcommit_factor, - available_memory_mb=int(total_memory_mb * overcommit_factor), - available_gpu_count=total_gpu_count, - ) - - # Adjust minimum reserve thresholds proportionally for small pools. 
- # Keep at most 10% of total as reserve (but not higher than defaults). - # This avoids refusing small, reasonable allocations on modest clusters. - self.pool.min_available_cpu_cores = min( - self.pool.min_available_cpu_cores, - max(0.1 * self.pool.total_cpu_cores, 0.0), - ) - self.pool.min_available_memory_mb = min( - self.pool.min_available_memory_mb, - max(int(0.1 * self.pool.total_memory_mb), 0), - ) - - # Track allocations - self._allocations: dict[str, ResourceAllocation] = {} - self._allocation_lock = asyncio.Lock() - - # Default allocations by language - self.default_allocations = { - "python": ResourceAllocation(cpu_cores=0.5, memory_mb=512), - "javascript": ResourceAllocation(cpu_cores=0.5, memory_mb=512), - "go": ResourceAllocation(cpu_cores=0.25, memory_mb=256), - "rust": ResourceAllocation(cpu_cores=0.5, memory_mb=512), - "java": ResourceAllocation(cpu_cores=1.0, memory_mb=1024), - "cpp": ResourceAllocation(cpu_cores=0.5, memory_mb=512), - "r": ResourceAllocation(cpu_cores=1.0, memory_mb=2048), - } - - # Update initial metrics - self._update_metrics() - - async def request_allocation( - self, - execution_id: str, - language: str, - requested_cpu: float | None = None, - requested_memory_mb: int | None = None, - requested_gpu: int = 0, - ) -> ResourceAllocation | None: - """ - Request resource allocation for execution - - Returns: - ResourceAllocation if successful, None if resources unavailable - """ - async with self._allocation_lock: - # Check if already allocated - if execution_id in self._allocations: - self.logger.warning(f"Execution {execution_id} already has allocation") - return self._allocations[execution_id] - - # Determine requested resources - if requested_cpu is None or requested_memory_mb is None: - # Use defaults based on language - default = self.default_allocations.get(language, ResourceAllocation(cpu_cores=0.5, memory_mb=512)) - requested_cpu = requested_cpu or default.cpu_cores - requested_memory_mb = requested_memory_mb or default.memory_mb - - # Apply limits - requested_cpu = min(requested_cpu, self.pool.max_cpu_per_execution) - requested_memory_mb = min(requested_memory_mb, self.pool.max_memory_per_execution_mb) - requested_gpu = min(requested_gpu, self.pool.max_gpu_per_execution) - - # Check availability (considering minimum reserves) - cpu_after = self.pool.available_cpu_cores - requested_cpu - memory_after = self.pool.available_memory_mb - requested_memory_mb - gpu_after = self.pool.available_gpu_count - requested_gpu - - if ( - cpu_after < self.pool.min_available_cpu_cores - or memory_after < self.pool.min_available_memory_mb - or gpu_after < 0 - ): - self.logger.warning( - f"Insufficient resources for execution {execution_id}. " - f"Requested: {requested_cpu} CPU, {requested_memory_mb}MB RAM, " - f"{requested_gpu} GPU. 
Available: {self.pool.available_cpu_cores} CPU, " - f"{self.pool.available_memory_mb}MB RAM, {self.pool.available_gpu_count} GPU" - ) - return None - - # Create allocation - allocation = ResourceAllocation( - cpu_cores=requested_cpu, memory_mb=requested_memory_mb, gpu_count=requested_gpu - ) - - # Update pool - self.pool.available_cpu_cores = cpu_after - self.pool.available_memory_mb = memory_after - self.pool.available_gpu_count = gpu_after - - # Track allocation - self._allocations[execution_id] = allocation - - # Update metrics - self._update_metrics() - - self.logger.info( - f"Allocated resources for execution {execution_id}: " - f"{allocation.cpu_cores} CPU, {allocation.memory_mb}MB RAM, " - f"{allocation.gpu_count} GPU" - ) - - return allocation - - async def release_allocation(self, execution_id: str) -> bool: - """Release resource allocation""" - async with self._allocation_lock: - if execution_id not in self._allocations: - self.logger.warning(f"No allocation found for execution {execution_id}") - return False - - allocation = self._allocations[execution_id] - - # Return resources to pool - self.pool.available_cpu_cores += allocation.cpu_cores - self.pool.available_memory_mb += allocation.memory_mb - self.pool.available_gpu_count += allocation.gpu_count - - # Remove allocation - del self._allocations[execution_id] - - # Update metrics - self._update_metrics() - - self.logger.info( - f"Released resources for execution {execution_id}: " - f"{allocation.cpu_cores} CPU, {allocation.memory_mb}MB RAM, " - f"{allocation.gpu_count} GPU" - ) - - return True - - async def get_allocation(self, execution_id: str) -> ResourceAllocation | None: - """Get current allocation for execution""" - async with self._allocation_lock: - return self._allocations.get(execution_id) - - async def can_allocate(self, cpu_cores: float, memory_mb: int, gpu_count: int = 0) -> bool: - """Check if resources can be allocated""" - async with self._allocation_lock: - cpu_after = self.pool.available_cpu_cores - cpu_cores - memory_after = self.pool.available_memory_mb - memory_mb - gpu_after = self.pool.available_gpu_count - gpu_count - - return ( - cpu_after >= self.pool.min_available_cpu_cores - and memory_after >= self.pool.min_available_memory_mb - and gpu_after >= 0 - ) - - async def get_resource_stats(self) -> ResourceStats: - """Get resource statistics""" - async with self._allocation_lock: - allocated_cpu = self.pool.total_cpu_cores - self.pool.available_cpu_cores - allocated_memory = self.pool.total_memory_mb - self.pool.available_memory_mb - allocated_gpu = self.pool.total_gpu_count - self.pool.available_gpu_count - - gpu_percent = (allocated_gpu / self.pool.total_gpu_count * 100) if self.pool.total_gpu_count > 0 else 0 - - return ResourceStats( - total=ResourceGroup( - cpu_cores=self.pool.total_cpu_cores, - memory_mb=self.pool.total_memory_mb, - gpu_count=self.pool.total_gpu_count, - ), - available=ResourceGroup( - cpu_cores=self.pool.available_cpu_cores, - memory_mb=self.pool.available_memory_mb, - gpu_count=self.pool.available_gpu_count, - ), - allocated=ResourceGroup(cpu_cores=allocated_cpu, memory_mb=allocated_memory, gpu_count=allocated_gpu), - utilization={ - "cpu_percent": (allocated_cpu / self.pool.total_cpu_cores * 100), - "memory_percent": (allocated_memory / self.pool.total_memory_mb * 100), - "gpu_percent": gpu_percent, - }, - allocation_count=len(self._allocations), - limits={ - "max_cpu_per_execution": self.pool.max_cpu_per_execution, - "max_memory_per_execution_mb": 
self.pool.max_memory_per_execution_mb, - "max_gpu_per_execution": self.pool.max_gpu_per_execution, - }, - ) - - async def get_allocations_by_resource_usage(self) -> list[ResourceAllocationInfo]: - """Get allocations sorted by resource usage""" - async with self._allocation_lock: - allocations = [] - for exec_id, allocation in self._allocations.items(): - allocations.append( - ResourceAllocationInfo( - execution_id=str(exec_id), - cpu_cores=allocation.cpu_cores, - memory_mb=allocation.memory_mb, - gpu_count=allocation.gpu_count, - cpu_percentage=(allocation.cpu_cores / self.pool.total_cpu_cores * 100), - memory_percentage=(allocation.memory_mb / self.pool.total_memory_mb * 100), - ) - ) - - # Sort by total resource usage - allocations.sort(key=lambda x: x.cpu_percentage + x.memory_percentage, reverse=True) - - return allocations - - def _update_metrics(self) -> None: - """Update metrics""" - cpu_usage = self.pool.total_cpu_cores - self.pool.available_cpu_cores - cpu_percent = cpu_usage / self.pool.total_cpu_cores * 100 - self.metrics.update_resource_usage("cpu", cpu_percent) - - memory_usage = self.pool.total_memory_mb - self.pool.available_memory_mb - memory_percent = memory_usage / self.pool.total_memory_mb * 100 - self.metrics.update_resource_usage("memory", memory_percent) - - gpu_usage = self.pool.total_gpu_count - self.pool.available_gpu_count - gpu_percent = gpu_usage / max(1, self.pool.total_gpu_count) * 100 - self.metrics.update_resource_usage("gpu", gpu_percent) - - self.metrics.update_coordinator_active_executions(len(self._allocations)) diff --git a/backend/app/services/event_bus.py b/backend/app/services/event_bus.py index 613b2ef6..6f9d62a0 100644 --- a/backend/app/services/event_bus.py +++ b/backend/app/services/event_bus.py @@ -1,323 +1,140 @@ +"""Event Bus - stateless pub/sub service. + +Distributed event bus for cross-instance communication via Kafka. +No lifecycle management - receives ready-to-use producer from DI. +""" + +from __future__ import annotations + import asyncio import fnmatch import json import logging from dataclasses import dataclass, field -from typing import Any, Callable +from datetime import datetime, timezone from uuid import uuid4 -from aiokafka import AIOKafkaConsumer, AIOKafkaProducer -from aiokafka.errors import KafkaError -from fastapi import Request +from aiokafka import AIOKafkaProducer +from pydantic import BaseModel, ConfigDict -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import ConnectionMetrics from app.domain.enums.kafka import KafkaTopic -from app.domain.events.typed import BaseEvent, DomainEvent, domain_event_adapter from app.settings import Settings +class EventBusEvent(BaseModel): + """Represents an event on the event bus.""" + + model_config = ConfigDict(from_attributes=True) + + id: str + event_type: str + timestamp: datetime + payload: dict[str, object] + + @dataclass class Subscription: """Represents a single event subscription.""" id: str = field(default_factory=lambda: str(uuid4())) pattern: str = "" - handler: Callable[[DomainEvent], Any] = field(default=lambda _: None) - - -class EventBus(LifecycleEnabled): - """ - Distributed event bus for cross-instance communication via Kafka. - - Publishers send events to Kafka. Subscribers receive events from OTHER instances - only - self-published messages are filtered out. This design means: - - Publishers should update their own state directly before calling publish() - - Handlers only run for events from other instances (cache invalidation, etc.) 
- - Supports pattern-based subscriptions using wildcards: - - execution.* - matches all execution events - - execution.123.* - matches all events for execution 123 - - *.completed - matches all completed events - """ - - def __init__(self, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics) -> None: - super().__init__() - self.logger = logger - self.settings = settings - self.metrics = connection_metrics - self.producer: AIOKafkaProducer | None = None - self.consumer: AIOKafkaConsumer | None = None - self._subscriptions: dict[str, Subscription] = {} # id -> Subscription - self._pattern_index: dict[str, set[str]] = {} # pattern -> set of subscription ids - self._consumer_task: asyncio.Task[None] | None = None + handler: object = field(default=None) + + +class EventBus: + """Stateless event bus - pure pub/sub service.""" + + def __init__( + self, + producer: AIOKafkaProducer, + settings: Settings, + logger: logging.Logger, + connection_metrics: ConnectionMetrics, + ) -> None: + self._producer = producer + self._settings = settings + self._logger = logger + self._metrics = connection_metrics + self._subscriptions: dict[str, Subscription] = {} + self._pattern_index: dict[str, set[str]] = {} self._lock = asyncio.Lock() - self._topic = f"{self.settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EVENT_BUS_STREAM}" - self._instance_id = str(uuid4()) # Unique ID for filtering self-published messages - - async def _on_start(self) -> None: - """Start the event bus with Kafka backing.""" - await self._initialize_kafka() - self._consumer_task = asyncio.create_task(self._kafka_listener()) - self.logger.info("Event bus started with Kafka backing") - - async def _initialize_kafka(self) -> None: - """Initialize Kafka producer and consumer.""" - # Producer setup - self.producer = AIOKafkaProducer( - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - client_id=f"event-bus-producer-{uuid4()}", - linger_ms=10, - max_batch_size=16384, - enable_idempotence=True, + self._topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EVENT_BUS_STREAM}" + self._instance_id = str(uuid4()) + + async def publish(self, event_type: str, data: dict[str, object]) -> None: + """Publish an event to Kafka for cross-instance distribution.""" + event = EventBusEvent( + id=str(uuid4()), + event_type=event_type, + timestamp=datetime.now(timezone.utc), + payload=data, ) - # Consumer setup - self.consumer = AIOKafkaConsumer( - self._topic, - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"event-bus-{uuid4()}", - auto_offset_reset="latest", - enable_auto_commit=True, - client_id=f"event-bus-consumer-{uuid4()}", - session_timeout_ms=self.settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - # Start both in parallel for faster startup - await asyncio.gather(self.producer.start(), self.consumer.start()) - - async def _on_stop(self) -> None: - """Stop the event bus and clean up resources.""" - # Cancel consumer task - if self._consumer_task and not self._consumer_task.done(): - self._consumer_task.cancel() - try: - await self._consumer_task - except asyncio.CancelledError: - pass - - # Stop Kafka components - if self.consumer: - await self.consumer.stop() - self.consumer = None - - if self.producer: - await self.producer.stop() - self.producer = None - - # Clear subscriptions - async with self._lock: - 
self._subscriptions.clear() - self._pattern_index.clear() - - self.logger.info("Event bus stopped") - - async def publish(self, event: BaseEvent) -> None: - """ - Publish a typed event to Kafka for cross-instance distribution. - - Local handlers receive events only from OTHER instances via the Kafka listener. - Publishers should update their own state directly before calling publish(). - - Args: - event: Typed domain event to publish - """ - if self.producer: - try: - value = event.model_dump_json().encode("utf-8") - key = event.event_type.encode("utf-8") - headers = [("source_instance", self._instance_id.encode("utf-8"))] - - await self.producer.send_and_wait( - topic=self._topic, - value=value, - key=key, - headers=headers, - ) - except Exception as e: - self.logger.error(f"Failed to publish to Kafka: {e}") - - async def subscribe(self, pattern: str, handler: Callable[[DomainEvent], Any]) -> str: - """ - Subscribe to events matching a pattern. - - Args: - pattern: Event pattern with wildcards (e.g., "execution.*") - handler: Async function to handle matching events + try: + await self._producer.send_and_wait( + topic=self._topic, + value=event.model_dump_json().encode(), + key=event_type.encode(), + headers=[("source_instance", self._instance_id.encode())], + ) + except Exception as e: + self._logger.error(f"Failed to publish to Kafka: {e}") - Returns: - Subscription ID for later unsubscribe - """ + async def subscribe(self, pattern: str, handler: object) -> str: + """Subscribe to events matching a pattern. Returns subscription ID.""" subscription = Subscription(pattern=pattern, handler=handler) async with self._lock: - # Store subscription self._subscriptions[subscription.id] = subscription - - # Update pattern index if pattern not in self._pattern_index: self._pattern_index[pattern] = set() self._pattern_index[pattern].add(subscription.id) + self._metrics.update_event_bus_subscribers(len(self._pattern_index[pattern]), pattern) - # Update metrics - self._update_metrics(pattern) - - self.logger.debug(f"Created subscription {subscription.id} for pattern: {pattern}") return subscription.id - async def unsubscribe(self, pattern: str, handler: Callable[[DomainEvent], Any]) -> None: - """Unsubscribe a specific handler from a pattern.""" + async def unsubscribe(self, pattern: str, handler: object) -> None: + """Unsubscribe a handler from a pattern.""" async with self._lock: - # Find subscription with matching pattern and handler - for sub_id, subscription in list(self._subscriptions.items()): - if subscription.pattern == pattern and subscription.handler == handler: - await self._remove_subscription(sub_id) + for sub_id, sub in list(self._subscriptions.items()): + if sub.pattern == pattern and sub.handler == handler: + del self._subscriptions[sub_id] + self._pattern_index[pattern].discard(sub_id) + if not self._pattern_index[pattern]: + del self._pattern_index[pattern] + self._metrics.update_event_bus_subscribers(0, pattern) + else: + self._metrics.update_event_bus_subscribers(len(self._pattern_index[pattern]), pattern) return - self.logger.warning(f"No subscription found for pattern {pattern} with given handler") - - async def _remove_subscription(self, subscription_id: str) -> None: - """Remove a subscription by ID (must be called within lock).""" - if subscription_id not in self._subscriptions: - self.logger.warning(f"Subscription {subscription_id} not found") - return - - subscription = self._subscriptions[subscription_id] - pattern = subscription.pattern - - # Remove from subscriptions - 
del self._subscriptions[subscription_id] - - # Update pattern index - if pattern in self._pattern_index: - self._pattern_index[pattern].discard(subscription_id) - if not self._pattern_index[pattern]: - del self._pattern_index[pattern] - - # Update metrics - self._update_metrics(pattern) - - self.logger.debug(f"Removed subscription {subscription_id} for pattern: {pattern}") - - async def _distribute_event(self, event: DomainEvent) -> None: - """Distribute event to all matching local subscribers.""" - # Find matching subscriptions - matching_handlers = await self._find_matching_handlers(event.event_type) - - if not matching_handlers: + async def handle_kafka_message(self, raw_message: bytes, headers: dict[str, str]) -> None: + """Handle a Kafka message. Skips messages from this instance.""" + if headers.get("source_instance") == self._instance_id: return - # Execute all handlers concurrently - results = await asyncio.gather( - *(self._invoke_handler(handler, event) for handler in matching_handlers), return_exceptions=True - ) - - # Log any errors - for _i, result in enumerate(results): - if isinstance(result, Exception): - self.logger.error(f"Handler failed for event {event.event_type}: {result}") - - async def _find_matching_handlers(self, event_type: str) -> list[Callable[[DomainEvent], Any]]: - """Find all handlers matching the event type.""" - async with self._lock: - handlers: list[Callable[[DomainEvent], Any]] = [] - for pattern, sub_ids in self._pattern_index.items(): - if fnmatch.fnmatch(event_type, pattern): - handlers.extend( - self._subscriptions[sub_id].handler for sub_id in sub_ids if sub_id in self._subscriptions - ) - return handlers - - async def _invoke_handler(self, handler: Callable[[DomainEvent], Any], event: DomainEvent) -> None: - """Invoke a single handler, handling both sync and async.""" - if asyncio.iscoroutinefunction(handler): - await handler(event) - else: - await asyncio.to_thread(handler, event) - - async def _kafka_listener(self) -> None: - """Listen for Kafka messages from OTHER instances and distribute to local subscribers.""" - if not self.consumer: - return - - self.logger.info("Kafka listener started") - try: - while self.is_running: - try: - msg = await asyncio.wait_for(self.consumer.getone(), timeout=0.1) - - # Skip messages from this instance - publisher handles its own state - headers = dict(msg.headers) if msg.headers else {} - source = headers.get("source_instance", b"").decode("utf-8") - if source == self._instance_id: - continue - - try: - event_dict = json.loads(msg.value.decode("utf-8")) - event = domain_event_adapter.validate_python(event_dict) - await self._distribute_event(event) - except Exception as e: - self.logger.error(f"Error processing Kafka message: {e}") - - except asyncio.TimeoutError: - continue - except KafkaError as e: - self.logger.error(f"Consumer error: {e}") - continue - - except asyncio.CancelledError: - self.logger.info("Kafka listener cancelled") + event = EventBusEvent.model_validate(json.loads(raw_message)) + await self._distribute_event(event) except Exception as e: - self.logger.error(f"Fatal error in Kafka listener: {e}") - - def _update_metrics(self, pattern: str) -> None: - """Update metrics for a pattern (must be called within lock).""" - if self.metrics: - count = len(self._pattern_index.get(pattern, set())) - self.metrics.update_event_bus_subscribers(count, pattern) + self._logger.error(f"Error processing Kafka message: {e}") - async def get_statistics(self) -> dict[str, Any]: - """Get event bus statistics.""" + 
async def _distribute_event(self, event: EventBusEvent) -> None: + """Distribute event to matching local subscribers.""" async with self._lock: - return { - "patterns": list(self._pattern_index.keys()), - "total_patterns": len(self._pattern_index), - "total_subscriptions": len(self._subscriptions), - "kafka_enabled": self.producer is not None, - "running": self.is_running, - } - - -class EventBusManager: - """Manages EventBus lifecycle as a singleton.""" - - def __init__(self, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics) -> None: - self.settings = settings - self.logger = logger - self._connection_metrics = connection_metrics - self._event_bus: EventBus | None = None - self._lock = asyncio.Lock() - - async def get_event_bus(self) -> EventBus: - """Get or create the event bus instance.""" - async with self._lock: - if self._event_bus is None: - self._event_bus = EventBus(self.settings, self.logger, self._connection_metrics) - await self._event_bus.__aenter__() - return self._event_bus - - async def close(self) -> None: - """Stop and clean up the event bus.""" - async with self._lock: - if self._event_bus: - await self._event_bus.aclose() - self._event_bus = None - - -async def get_event_bus(request: Request) -> EventBus: - manager: EventBusManager = request.app.state.event_bus_manager - return await manager.get_event_bus() + handlers = [ + self._subscriptions[sub_id].handler + for pattern, sub_ids in self._pattern_index.items() + if fnmatch.fnmatch(event.event_type, pattern) + for sub_id in sub_ids + if sub_id in self._subscriptions + ] + + for handler in handlers: + try: + if asyncio.iscoroutinefunction(handler): + await handler(event) + else: + handler(event) # type: ignore[operator] + except Exception as e: + self._logger.error(f"Handler failed for {event.event_type}: {e}") diff --git a/backend/app/services/idempotency/middleware.py b/backend/app/services/idempotency/middleware.py index 04a4f931..4dac5287 100644 --- a/backend/app/services/idempotency/middleware.py +++ b/backend/app/services/idempotency/middleware.py @@ -1,12 +1,11 @@ +"""Idempotent event processing middleware""" + import asyncio import logging -from collections.abc import Awaitable -from typing import Any, Callable +from typing import Any, Awaitable, Callable, Dict, Set from app.domain.enums.events import EventType -from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent -from app.domain.idempotency import KeyStrategy from app.events.core import EventDispatcher, UnifiedConsumer from app.services.idempotency.idempotency_manager import IdempotencyManager @@ -19,9 +18,9 @@ def __init__( handler: Callable[[DomainEvent], Awaitable[None]], idempotency_manager: IdempotencyManager, logger: logging.Logger, - key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, + key_strategy: str = "event_based", custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: set[str] | None = None, + fields: Set[str] | None = None, ttl_seconds: int | None = None, cache_result: bool = True, on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, @@ -44,7 +43,7 @@ async def __call__(self, event: DomainEvent) -> None: ) # Generate custom key if function provided custom_key = None - if self.key_strategy == KeyStrategy.CUSTOM and self.custom_key_func: + if self.key_strategy == "custom" and self.custom_key_func: custom_key = self.custom_key_func(event) # Check idempotency @@ -93,9 +92,9 @@ async def __call__(self, event: DomainEvent) -> None: def 
idempotent_handler( idempotency_manager: IdempotencyManager, logger: logging.Logger, - key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, + key_strategy: str = "event_based", custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: set[str] | None = None, + fields: Set[str] | None = None, ttl_seconds: int | None = None, cache_result: bool = True, on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, @@ -128,7 +127,7 @@ def __init__( idempotency_manager: IdempotencyManager, dispatcher: EventDispatcher, logger: logging.Logger, - default_key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, + default_key_strategy: str = "event_based", default_ttl_seconds: int = 3600, enable_for_all_handlers: bool = True, ): @@ -138,22 +137,19 @@ def __init__( self.logger = logger self.default_key_strategy = default_key_strategy self.default_ttl_seconds = default_ttl_seconds - self.enable_for_all_handlers = enable_for_all_handlers - self._original_handlers: dict[EventType, list[Callable[[DomainEvent], Awaitable[None]]]] = {} + self._original_handlers: Dict[EventType, list[Callable[[DomainEvent], Awaitable[None]]]] = {} - def make_handlers_idempotent(self) -> None: - """Wrap all registered handlers with idempotency""" - self.logger.info( - f"make_handlers_idempotent called: enable_for_all={self.enable_for_all_handlers}, " - f"dispatcher={self.dispatcher is not None}" - ) - if not self.enable_for_all_handlers or not self.dispatcher: - self.logger.warning("Skipping handler wrapping - conditions not met") + if enable_for_all_handlers: + self._wrap_handlers() + + def _wrap_handlers(self) -> None: + """Wrap all registered handlers with idempotency.""" + if not self.dispatcher: + self.logger.warning("No dispatcher available for handler wrapping") return - # Store original handlers using public API self._original_handlers = self.dispatcher.get_all_handlers() - self.logger.info(f"Got {len(self._original_handlers)} event types with handlers to wrap") + self.logger.debug(f"Wrapping {len(self._original_handlers)} event types with idempotency") # Wrap each handler for event_type, handlers in self._original_handlers.items(): @@ -169,21 +165,15 @@ def make_handlers_idempotent(self) -> None: ) wrapped_handlers.append(wrapped) - # Replace handlers using public API - self.logger.info( - f"Replacing {len(handlers)} handlers for {event_type} with {len(wrapped_handlers)} wrapped handlers" - ) self.dispatcher.replace_handlers(event_type, wrapped_handlers) - self.logger.info("Handler wrapping complete") - def subscribe_idempotent_handler( self, event_type: str, handler: Callable[[DomainEvent], Awaitable[None]], - key_strategy: KeyStrategy | None = None, + key_strategy: str | None = None, custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: set[str] | None = None, + fields: Set[str] | None = None, ttl_seconds: int | None = None, cache_result: bool = True, on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, @@ -259,21 +249,3 @@ async def dispatch_handler(event: DomainEvent) -> None: else: # Fallback to direct consumer registration if no dispatcher self.logger.error(f"No EventDispatcher available for registering idempotent handler for {event_type}") - - async def start(self, topics: list[KafkaTopic]) -> None: - """Start the consumer with idempotency""" - self.logger.info(f"IdempotentConsumerWrapper.start called with topics: {topics}") - # Make handlers idempotent before starting - self.make_handlers_idempotent() - - # Start the consumer with required topics parameter - await 
self.consumer.start(topics) - self.logger.info("IdempotentConsumerWrapper started successfully") - - async def stop(self) -> None: - """Stop the consumer""" - await self.consumer.stop() - - # Delegate other methods to the wrapped consumer - def __getattr__(self, name: str) -> Any: - return getattr(self.consumer, name) diff --git a/backend/app/services/k8s_worker/worker.py b/backend/app/services/k8s_worker/worker.py index eafceeca..c49ec98f 100644 --- a/backend/app/services/k8s_worker/worker.py +++ b/backend/app/services/k8s_worker/worker.py @@ -1,336 +1,169 @@ +"""Kubernetes Worker - stateless event handler. + +Creates Kubernetes pods from execution events. Receives events, +processes them, and publishes results. No lifecycle management. +All state is stored in Redis repositories. +""" + +from __future__ import annotations + import asyncio import logging -import os import time from pathlib import Path -from typing import Any from kubernetes import client as k8s_client -from kubernetes import config as k8s_config from kubernetes.client.rest import ApiException -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics, ExecutionMetrics, KubernetesMetrics -from app.domain.enums.events import EventType -from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId +from app.db.repositories.pod_state_repository import PodStateRepository from app.domain.enums.storage import ExecutionErrorType from app.domain.events.typed import ( CreatePodCommandEvent, DeletePodCommandEvent, - DomainEvent, ExecutionFailedEvent, ExecutionStartedEvent, PodCreatedEvent, ) -from app.domain.idempotency import KeyStrategy -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import ( - SchemaRegistryManager, -) +from app.events.core import UnifiedProducer from app.runtime_registry import RUNTIME_REGISTRY -from app.services.idempotency import IdempotencyManager -from app.services.idempotency.middleware import IdempotentConsumerWrapper from app.services.k8s_worker.config import K8sWorkerConfig from app.services.k8s_worker.pod_builder import PodBuilder -from app.settings import Settings -class KubernetesWorker(LifecycleEnabled): - """ - Worker service that creates Kubernetes pods from execution events. - - This service: - 1. Consumes ExecutionStarted events from Kafka - 2. Creates ConfigMaps with script content - 3. Creates Pods to execute the scripts - 4. Creates NetworkPolicies for security - 5. Publishes PodCreated events +class KubernetesWorker: + """Stateless Kubernetes worker - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + All state (active creations) stored in Redis via PodStateRepository. 
""" def __init__( - self, - config: K8sWorkerConfig, - producer: UnifiedProducer, - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - event_metrics: EventMetrics, - ): - super().__init__() + self, + config: K8sWorkerConfig, + producer: UnifiedProducer, + pod_state_repo: PodStateRepository, + v1_client: k8s_client.CoreV1Api, + networking_v1_client: k8s_client.NetworkingV1Api, + apps_v1_client: k8s_client.AppsV1Api, + logger: logging.Logger, + kubernetes_metrics: KubernetesMetrics, + execution_metrics: ExecutionMetrics, + event_metrics: EventMetrics, + ) -> None: + self._config = config + self._producer = producer + self._pod_state_repo = pod_state_repo + self._v1 = v1_client + self._networking_v1 = networking_v1_client + self._apps_v1 = apps_v1_client + self._logger = logger + self._metrics = kubernetes_metrics + self._execution_metrics = execution_metrics self._event_metrics = event_metrics - self.logger = logger - self.metrics = KubernetesMetrics(settings) - self.execution_metrics = ExecutionMetrics(settings) - self.config = config or K8sWorkerConfig() - self._settings = settings - - self.kafka_servers = self._settings.KAFKA_BOOTSTRAP_SERVERS - self._event_store = event_store - - # Kubernetes clients - self.v1: k8s_client.CoreV1Api | None = None - self.networking_v1: k8s_client.NetworkingV1Api | None = None - self.apps_v1: k8s_client.AppsV1Api | None = None - - # Components - self.pod_builder = PodBuilder(namespace=self.config.namespace, config=self.config) - self.consumer: UnifiedConsumer | None = None - self.idempotent_consumer: IdempotentConsumerWrapper | None = None - self.idempotency_manager: IdempotencyManager = idempotency_manager - self.dispatcher: EventDispatcher | None = None - self.producer: UnifiedProducer = producer - - # State tracking - self._active_creations: set[str] = set() - self._creation_semaphore = asyncio.Semaphore(self.config.max_concurrent_pods) - self._schema_registry_manager = schema_registry_manager - - async def _on_start(self) -> None: - """Start the Kubernetes worker.""" - self.logger.info("Starting KubernetesWorker service...") - self.logger.info("DEBUG: About to initialize Kubernetes client") - - if self.config.namespace == "default": - raise RuntimeError( - "KubernetesWorker namespace 'default' is forbidden. Set K8S_NAMESPACE to a dedicated namespace." 
- ) - - # Initialize Kubernetes client - self._initialize_kubernetes_client() - self.logger.info("DEBUG: Kubernetes client initialized") - - self.logger.info("Using provided producer") - - self.logger.info("Idempotency manager provided") - - # Create consumer configuration - consumer_config = ConsumerConfig( - bootstrap_servers=self.kafka_servers, - group_id=self.config.consumer_group, - enable_auto_commit=False, - session_timeout_ms=self._settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self._settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self._settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self._settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - # Create dispatcher and register handlers for saga commands - self.dispatcher = EventDispatcher(logger=self.logger) - self.dispatcher.register_handler(EventType.CREATE_POD_COMMAND, self._handle_create_pod_command_wrapper) - self.dispatcher.register_handler(EventType.DELETE_POD_COMMAND, self._handle_delete_pod_command_wrapper) - - # Create consumer with dispatcher - self.consumer = UnifiedConsumer( - consumer_config, - event_dispatcher=self.dispatcher, - schema_registry=self._schema_registry_manager, - settings=self._settings, - logger=self.logger, - event_metrics=self._event_metrics, - ) - - # Wrap consumer with idempotency - use content hash for pod commands - self.idempotent_consumer = IdempotentConsumerWrapper( - consumer=self.consumer, - idempotency_manager=self.idempotency_manager, - dispatcher=self.dispatcher, - logger=self.logger, - default_key_strategy=KeyStrategy.CONTENT_HASH, # Hash execution_id + script for deduplication - default_ttl_seconds=3600, # 1 hour TTL for pod creation events - enable_for_all_handlers=True, # Enable idempotency for all handlers - ) - - # Start the consumer with idempotency - topics from centralized config - await self.idempotent_consumer.start(list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.K8S_WORKER])) - - # Create daemonset for image pre-pulling - asyncio.create_task(self.ensure_image_pre_puller_daemonset()) - self.logger.info("Image pre-puller daemonset task scheduled") - - self.logger.info("KubernetesWorker service started successfully") - - async def _on_stop(self) -> None: - """Stop the Kubernetes worker.""" - self.logger.info("Stopping KubernetesWorker service...") - - # Wait for active creations to complete - if self._active_creations: - self.logger.info(f"Waiting for {len(self._active_creations)} active pod creations to complete...") - timeout = 30 - start_time = time.time() - - while self._active_creations and (time.time() - start_time) < timeout: - await asyncio.sleep(1) - - if self._active_creations: - self.logger.warning(f"Timeout waiting for pod creations, {len(self._active_creations)} still active") - - # Stop the consumer (idempotent wrapper only) - if self.idempotent_consumer: - await self.idempotent_consumer.stop() - - # Close idempotency manager - await self.idempotency_manager.close() - - # Note: producer is managed by DI container, not stopped here - - self.logger.info("KubernetesWorker service stopped") - - def _initialize_kubernetes_client(self) -> None: - """Initialize Kubernetes API clients""" - try: - # Load config - if self.config.in_cluster: - self.logger.info("Using in-cluster Kubernetes configuration") - k8s_config.load_incluster_config() - elif self.config.kubeconfig_path and os.path.exists(self.config.kubeconfig_path): - self.logger.info(f"Using kubeconfig from {self.config.kubeconfig_path}") - 
k8s_config.load_kube_config(config_file=self.config.kubeconfig_path) - else: - # Try default locations - if os.path.exists("/var/run/secrets/kubernetes.io/serviceaccount"): - self.logger.info("Detected in-cluster environment") - k8s_config.load_incluster_config() - else: - self.logger.info("Using default kubeconfig") - k8s_config.load_kube_config() - - # Get the default configuration that was set by load_kube_config - configuration = k8s_client.Configuration.get_default_copy() - - # The certificate data should already be configured by load_kube_config - # Log the configuration for debugging - self.logger.info(f"Kubernetes API host: {configuration.host}") - self.logger.info(f"SSL CA cert configured: {configuration.ssl_ca_cert is not None}") + self._pod_builder = PodBuilder(namespace=config.namespace, config=config) - # Create API clients with the configuration - api_client = k8s_client.ApiClient(configuration) - self.v1 = k8s_client.CoreV1Api(api_client) - self.networking_v1 = k8s_client.NetworkingV1Api(api_client) - self.apps_v1 = k8s_client.AppsV1Api(api_client) - - # Test connection with namespace-scoped operation - _ = self.v1.list_namespaced_pod(namespace=self.config.namespace, limit=1) - self.logger.info(f"Successfully connected to Kubernetes API, namespace {self.config.namespace} accessible") - - except Exception as e: - self.logger.error(f"Failed to initialize Kubernetes client: {e}") - raise - - async def _handle_create_pod_command_wrapper(self, event: DomainEvent) -> None: - """Wrapper for handling CreatePodCommandEvent with type safety.""" - assert isinstance(event, CreatePodCommandEvent) - self.logger.info(f"Processing create_pod_command for execution {event.execution_id} from saga {event.saga_id}") - await self._handle_create_pod_command(event) - - async def _handle_delete_pod_command_wrapper(self, event: DomainEvent) -> None: - """Wrapper for handling DeletePodCommandEvent.""" - assert isinstance(event, DeletePodCommandEvent) - self.logger.info(f"Processing delete_pod_command for execution {event.execution_id} from saga {event.saga_id}") - await self._handle_delete_pod_command(event) - - async def _handle_create_pod_command(self, command: CreatePodCommandEvent) -> None: - """Handle create pod command from saga orchestrator""" + async def handle_create_pod_command(self, command: CreatePodCommandEvent) -> None: + """Handle create pod command from saga orchestrator.""" execution_id = command.execution_id + self._logger.info(f"Processing create_pod_command for execution {execution_id} from saga {command.saga_id}") - # Check if already processing - if execution_id in self._active_creations: - self.logger.warning(f"Already creating pod for execution {execution_id}") + # Try to claim this creation atomically in Redis + claimed = await self._pod_state_repo.try_claim_creation(execution_id, ttl_seconds=300) + if not claimed: + self._logger.warning(f"Already creating pod for execution {execution_id}, skipping") return - # Create pod asynchronously - asyncio.create_task(self._create_pod_for_execution(command)) + await self._create_pod_for_execution(command) - async def _handle_delete_pod_command(self, command: DeletePodCommandEvent) -> None: - """Handle delete pod command from saga orchestrator (compensation)""" + async def handle_delete_pod_command(self, command: DeletePodCommandEvent) -> None: + """Handle delete pod command from saga orchestrator (compensation).""" execution_id = command.execution_id - self.logger.info(f"Deleting pod for execution {execution_id} due to: 
{command.reason}") + self._logger.info(f"Deleting pod for execution {execution_id} due to: {command.reason}") try: # Delete the pod pod_name = f"executor-{execution_id}" - if self.v1: - await asyncio.to_thread( - self.v1.delete_namespaced_pod, - name=pod_name, - namespace=self.config.namespace, - grace_period_seconds=30, - ) - self.logger.info(f"Successfully deleted pod {pod_name}") + await asyncio.to_thread( + self._v1.delete_namespaced_pod, + name=pod_name, + namespace=self._config.namespace, + grace_period_seconds=30, + ) + self._logger.info(f"Successfully deleted pod {pod_name}") # Delete associated ConfigMap configmap_name = f"script-{execution_id}" - if self.v1: - await asyncio.to_thread( - self.v1.delete_namespaced_config_map, name=configmap_name, namespace=self.config.namespace - ) - self.logger.info(f"Successfully deleted ConfigMap {configmap_name}") - - # NetworkPolicy cleanup is managed via a static cluster policy; no per-execution NP deletion + await asyncio.to_thread( + self._v1.delete_namespaced_config_map, + name=configmap_name, + namespace=self._config.namespace, + ) + self._logger.info(f"Successfully deleted ConfigMap {configmap_name}") except ApiException as e: if e.status == 404: - self.logger.warning(f"Resources for execution {execution_id} not found (may have already been deleted)") + self._logger.warning( + f"Resources for execution {execution_id} not found (may have already been deleted)" + ) else: - self.logger.error(f"Failed to delete resources for execution {execution_id}: {e}") + self._logger.error(f"Failed to delete resources for execution {execution_id}: {e}") async def _create_pod_for_execution(self, command: CreatePodCommandEvent) -> None: - """Create pod for execution""" - async with self._creation_semaphore: - execution_id = command.execution_id - self._active_creations.add(execution_id) - self.metrics.update_k8s_active_creations(len(self._active_creations)) - - # Queue depth is owned by the coordinator; do not modify here - - start_time = time.time() - - try: - # We now have the CreatePodCommandEvent directly from saga - script_content = command.script - entrypoint_content = await self._get_entrypoint_script() - - # Create ConfigMap - config_map = self.pod_builder.build_config_map( - command=command, script_content=script_content, entrypoint_content=entrypoint_content - ) + """Create pod for execution.""" + execution_id = command.execution_id + start_time = time.time() - await self._create_config_map(config_map) + try: + # Update metrics for active creations + active_count = await self._pod_state_repo.get_active_creations_count() + self._metrics.update_k8s_active_creations(active_count) + + # Build and create ConfigMap + script_content = command.script + entrypoint_content = await self._get_entrypoint_script() + + config_map = self._pod_builder.build_config_map( + command=command, + script_content=script_content, + entrypoint_content=entrypoint_content, + ) + await self._create_config_map(config_map) - pod = self.pod_builder.build_pod_manifest(command=command) - await self._create_pod(pod) + # Build and create Pod + pod = self._pod_builder.build_pod_manifest(command=command) + await self._create_pod(pod) - # Publish PodCreated event - await self._publish_pod_created(command, pod) + # Publish PodCreated event + await self._publish_pod_created(command, pod) - # Update metrics - duration = time.time() - start_time - self.metrics.record_k8s_pod_creation_duration(duration, command.language) - self.metrics.record_k8s_pod_created("success", command.language) + # 
Update metrics + duration = time.time() - start_time + self._metrics.record_k8s_pod_creation_duration(duration, command.language) + self._metrics.record_k8s_pod_created("success", command.language) - self.logger.info( - f"Successfully created pod {pod.metadata.name} for execution {execution_id}. " - f"Duration: {duration:.2f}s" - ) + self._logger.info( + f"Successfully created pod {pod.metadata.name} for execution {execution_id}. " + f"Duration: {duration:.2f}s" + ) - except Exception as e: - self.logger.error(f"Failed to create pod for execution {execution_id}: {e}", exc_info=True) + except Exception as e: + self._logger.error(f"Failed to create pod for execution {execution_id}: {e}", exc_info=True) + self._metrics.record_k8s_pod_created("failed", "unknown") - # Update metrics - self.metrics.record_k8s_pod_created("failed", "unknown") + # Publish failure event + await self._publish_pod_creation_failed(command, str(e)) - # Publish failure event - await self._publish_pod_creation_failed(command, str(e)) + finally: + # Release the creation claim + await self._pod_state_repo.release_creation(execution_id) - finally: - self._active_creations.discard(execution_id) - self.metrics.update_k8s_active_creations(len(self._active_creations)) + # Update metrics + active_count = await self._pod_state_repo.get_active_creations_count() + self._metrics.update_k8s_active_creations(active_count) async def _get_entrypoint_script(self) -> str: - """Get entrypoint script content""" + """Get entrypoint script content.""" entrypoint_path = Path("app/scripts/entrypoint.sh") if entrypoint_path.exists(): return await asyncio.to_thread(entrypoint_path.read_text) @@ -353,67 +186,62 @@ async def _get_entrypoint_script(self) -> str: """ async def _create_config_map(self, config_map: k8s_client.V1ConfigMap) -> None: - """Create ConfigMap in Kubernetes""" - if not self.v1: - raise RuntimeError("Kubernetes client not initialized") + """Create ConfigMap in Kubernetes.""" try: await asyncio.to_thread( - self.v1.create_namespaced_config_map, namespace=self.config.namespace, body=config_map + self._v1.create_namespaced_config_map, + namespace=self._config.namespace, + body=config_map, ) - self.metrics.record_k8s_config_map_created("success") - self.logger.debug(f"Created ConfigMap {config_map.metadata.name}") + self._metrics.record_k8s_config_map_created("success") + self._logger.debug(f"Created ConfigMap {config_map.metadata.name}") except ApiException as e: if e.status == 409: # Already exists - self.logger.warning(f"ConfigMap {config_map.metadata.name} already exists") - self.metrics.record_k8s_config_map_created("already_exists") + self._logger.warning(f"ConfigMap {config_map.metadata.name} already exists") + self._metrics.record_k8s_config_map_created("already_exists") else: - self.metrics.record_k8s_config_map_created("failed") + self._metrics.record_k8s_config_map_created("failed") raise async def _create_pod(self, pod: k8s_client.V1Pod) -> None: - """Create Pod in Kubernetes""" - if not self.v1: - raise RuntimeError("Kubernetes client not initialized") + """Create Pod in Kubernetes.""" try: - await asyncio.to_thread(self.v1.create_namespaced_pod, namespace=self.config.namespace, body=pod) - self.logger.debug(f"Created Pod {pod.metadata.name}") + await asyncio.to_thread( + self._v1.create_namespaced_pod, + namespace=self._config.namespace, + body=pod, + ) + self._logger.debug(f"Created Pod {pod.metadata.name}") except ApiException as e: if e.status == 409: # Already exists - self.logger.warning(f"Pod {pod.metadata.name} 
already exists") + self._logger.warning(f"Pod {pod.metadata.name} already exists") else: raise async def _publish_execution_started(self, command: CreatePodCommandEvent, pod: k8s_client.V1Pod) -> None: - """Publish execution started event""" + """Publish execution started event.""" event = ExecutionStartedEvent( execution_id=command.execution_id, - aggregate_id=command.execution_id, # Set aggregate_id to execution_id + aggregate_id=command.execution_id, pod_name=pod.metadata.name, node_name=pod.spec.node_name, - container_id=None, # Will be set when container actually starts + container_id=None, metadata=command.metadata, ) - if not self.producer: - self.logger.error("Producer not initialized") - return - await self.producer.produce(event_to_produce=event) + await self._producer.produce(event_to_produce=event) async def _publish_pod_created(self, command: CreatePodCommandEvent, pod: k8s_client.V1Pod) -> None: - """Publish pod created event""" + """Publish pod created event.""" event = PodCreatedEvent( execution_id=command.execution_id, pod_name=pod.metadata.name, namespace=pod.metadata.namespace, metadata=command.metadata, ) - - if not self.producer: - self.logger.error("Producer not initialized") - return - await self.producer.produce(event_to_produce=event) + await self._producer.produce(event_to_produce=event) async def _publish_pod_creation_failed(self, command: CreatePodCommandEvent, error: str) -> None: - """Publish pod creation failed event""" + """Publish pod creation failed event.""" event = ExecutionFailedEvent( execution_id=command.execution_id, error_type=ExecutionErrorType.SYSTEM_ERROR, @@ -423,33 +251,28 @@ async def _publish_pod_creation_failed(self, command: CreatePodCommandEvent, err metadata=command.metadata, error_message=str(error), ) + await self._producer.produce(event_to_produce=event, key=command.execution_id) - if not self.producer: - self.logger.error("Producer not initialized") - return - await self.producer.produce(event_to_produce=event) - - async def get_status(self) -> dict[str, Any]: - """Get worker status""" + async def get_status(self) -> dict[str, object]: + """Get worker status.""" + active_count = await self._pod_state_repo.get_active_creations_count() return { - "running": self.is_running, - "active_creations": len(self._active_creations), + "active_creations": active_count, "config": { - "namespace": self.config.namespace, - "max_concurrent_pods": self.config.max_concurrent_pods, + "namespace": self._config.namespace, + "max_concurrent_pods": self._config.max_concurrent_pods, "enable_network_policies": True, }, } async def ensure_image_pre_puller_daemonset(self) -> None: - """Ensure the runtime image pre-puller DaemonSet exists""" - if not self.apps_v1: - self.logger.warning("Kubernetes AppsV1Api client not initialized. Skipping DaemonSet creation.") - return + """Ensure the runtime image pre-puller DaemonSet exists. + This should be called once at startup from the worker entrypoint, + not as a background task. 
+ """ daemonset_name = "runtime-image-pre-puller" - namespace = self.config.namespace - await asyncio.sleep(5) + namespace = self._config.namespace try: init_containers = [] @@ -457,7 +280,7 @@ async def ensure_image_pre_puller_daemonset(self) -> None: for i, image_ref in enumerate(sorted(list(all_images))): sanitized_image_ref = image_ref.split("/")[-1].replace(":", "-").replace(".", "-").replace("_", "-") - self.logger.info(f"DAEMONSET: before: {image_ref} -> {sanitized_image_ref}") + self._logger.info(f"DAEMONSET: before: {image_ref} -> {sanitized_image_ref}") container_name = f"pull-{i}-{sanitized_image_ref}" init_containers.append( { @@ -468,7 +291,7 @@ async def ensure_image_pre_puller_daemonset(self) -> None: } ) - manifest: dict[str, Any] = { + manifest: dict[str, object] = { "apiVersion": "apps/v1", "kind": "DaemonSet", "metadata": {"name": daemonset_name, "namespace": namespace}, @@ -488,24 +311,31 @@ async def ensure_image_pre_puller_daemonset(self) -> None: try: await asyncio.to_thread( - self.apps_v1.read_namespaced_daemon_set, name=daemonset_name, namespace=namespace + self._apps_v1.read_namespaced_daemon_set, + name=daemonset_name, + namespace=namespace, ) - self.logger.info(f"DaemonSet '{daemonset_name}' exists. Replacing to ensure it is up-to-date.") + self._logger.info(f"DaemonSet '{daemonset_name}' exists. Replacing to ensure it is up-to-date.") await asyncio.to_thread( - self.apps_v1.replace_namespaced_daemon_set, name=daemonset_name, namespace=namespace, body=manifest + self._apps_v1.replace_namespaced_daemon_set, + name=daemonset_name, + namespace=namespace, + body=manifest, ) - self.logger.info(f"DaemonSet '{daemonset_name}' replaced successfully.") + self._logger.info(f"DaemonSet '{daemonset_name}' replaced successfully.") except ApiException as e: if e.status == 404: - self.logger.info(f"DaemonSet '{daemonset_name}' not found. Creating...") + self._logger.info(f"DaemonSet '{daemonset_name}' not found. 
Creating...") await asyncio.to_thread( - self.apps_v1.create_namespaced_daemon_set, namespace=namespace, body=manifest + self._apps_v1.create_namespaced_daemon_set, + namespace=namespace, + body=manifest, ) - self.logger.info(f"DaemonSet '{daemonset_name}' created successfully.") + self._logger.info(f"DaemonSet '{daemonset_name}' created successfully.") else: raise except ApiException as e: - self.logger.error(f"K8s API error applying DaemonSet '{daemonset_name}': {e.reason}", exc_info=True) + self._logger.error(f"K8s API error applying DaemonSet '{daemonset_name}': {e.reason}", exc_info=True) except Exception as e: - self.logger.error(f"Unexpected error applying image-puller DaemonSet: {e}", exc_info=True) + self._logger.error(f"Unexpected error applying image-puller DaemonSet: {e}", exc_info=True) diff --git a/backend/app/services/kafka_event_service.py b/backend/app/services/kafka_event_service.py index ac2207ca..9c152b97 100644 --- a/backend/app/services/kafka_event_service.py +++ b/backend/app/services/kafka_event_service.py @@ -1,7 +1,7 @@ import logging import time from datetime import datetime, timezone -from typing import Any +from typing import Any, Dict from uuid import uuid4 from opentelemetry import trace @@ -21,12 +21,12 @@ class KafkaEventService: def __init__( - self, - event_repository: EventRepository, - kafka_producer: UnifiedProducer, - settings: Settings, - logger: logging.Logger, - event_metrics: EventMetrics, + self, + event_repository: EventRepository, + kafka_producer: UnifiedProducer, + settings: Settings, + logger: logging.Logger, + event_metrics: EventMetrics, ): self.event_repository = event_repository self.kafka_producer = kafka_producer @@ -35,12 +35,12 @@ def __init__( self.settings = settings async def publish_event( - self, - event_type: EventType, - payload: dict[str, Any], - aggregate_id: str | None, - correlation_id: str | None = None, - metadata: EventMetadata | None = None, + self, + event_type: EventType, + payload: Dict[str, Any], + aggregate_id: str | None, + correlation_id: str | None = None, + metadata: EventMetadata | None = None, ) -> str: """ Publish an event to Kafka and store an audit copy via the repository @@ -90,7 +90,7 @@ async def publish_event( await self.event_repository.store_event(domain_event) # Prepare headers - headers: dict[str, str] = { + headers: Dict[str, str] = { "event_type": event_type, "correlation_id": event_metadata.correlation_id or "", "service": event_metadata.service_name, @@ -113,12 +113,12 @@ async def publish_event( return domain_event.event_id async def publish_execution_event( - self, - event_type: EventType, - execution_id: str, - status: str, - metadata: EventMetadata | None = None, - error_message: str | None = None, + self, + event_type: EventType, + execution_id: str, + status: str, + metadata: EventMetadata | None = None, + error_message: str | None = None, ) -> str: """Publish execution-related event using provided metadata (no framework coupling).""" self.logger.info( @@ -154,13 +154,13 @@ async def publish_execution_event( return event_id async def publish_pod_event( - self, - event_type: EventType, - pod_name: str, - execution_id: str, - namespace: str = "integr8scode", - status: str | None = None, - metadata: EventMetadata | None = None, + self, + event_type: EventType, + pod_name: str, + execution_id: str, + namespace: str = "integr8scode", + status: str | None = None, + metadata: EventMetadata | None = None, ) -> str: """Publish pod-related event""" payload = {"pod_name": pod_name, "execution_id": 
execution_id, "namespace": namespace} @@ -185,7 +185,7 @@ async def publish_domain_event(self, event: DomainEvent, key: str | None = None) start_time = time.time() await self.event_repository.store_event(event) - headers: dict[str, str] = { + headers: Dict[str, str] = { "event_type": event.event_type, "correlation_id": event.metadata.correlation_id or "", "service": event.metadata.service_name, @@ -201,7 +201,3 @@ async def publish_domain_event(self, event: DomainEvent, key: str | None = None) self.metrics.record_event_processing_duration(time.time() - start_time, event.event_type) self.logger.info("Domain event published", extra={"event_id": event.event_id}) return event.event_id - - async def close(self) -> None: - """Close event service resources""" - await self.kafka_producer.aclose() diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 2d005fbc..1e37d987 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -1,17 +1,21 @@ +"""Notification Service - stateless event handler. + +Handles notification creation and delivery. Receives events, +processes them, and delivers notifications. No lifecycle management. +""" + +from __future__ import annotations + import asyncio import logging from dataclasses import dataclass, field from datetime import UTC, datetime, timedelta -from typing import Awaitable, Callable import httpx -from app.core.lifecycle import LifecycleEnabled -from app.core.metrics import EventMetrics, NotificationMetrics +from app.core.metrics import NotificationMetrics from app.core.tracing.utils import add_span_attributes from app.db.repositories.notification_repository import NotificationRepository -from app.domain.enums.events import EventType -from app.domain.enums.kafka import GroupId from app.domain.enums.notification import ( NotificationChannel, NotificationSeverity, @@ -20,13 +24,9 @@ from app.domain.enums.user import UserRole from app.domain.events.typed import ( DomainEvent, - EventMetadata, ExecutionCompletedEvent, ExecutionFailedEvent, ExecutionTimeoutEvent, - NotificationAllReadEvent, - NotificationCreatedEvent, - NotificationReadEvent, ) from app.domain.notification import ( DomainNotification, @@ -39,25 +39,13 @@ NotificationThrottledError, NotificationValidationError, ) -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer -from app.events.schema.schema_registry import SchemaRegistryManager -from app.infrastructure.kafka.mappings import get_topic_for_event from app.schemas_pydantic.sse import RedisNotificationMessage -from app.services.event_bus import EventBusManager -from app.services.kafka_event_service import KafkaEventService +from app.services.event_bus import EventBus from app.services.sse.redis_bus import SSERedisBus from app.settings import Settings -# Constants ENTITY_EXECUTION_TAG = "entity:execution" -# Type aliases -type EventPayload = dict[str, object] -type NotificationContext = dict[str, object] -type ChannelHandler = Callable[[DomainNotification, DomainNotificationSubscription], Awaitable[None]] -type SystemNotificationStats = dict[str, int] -type SlackMessage = dict[str, object] - @dataclass class ThrottleCache: @@ -67,11 +55,11 @@ class ThrottleCache: _lock: asyncio.Lock = field(default_factory=asyncio.Lock) async def check_throttle( - self, - user_id: str, - severity: NotificationSeverity, - window_hours: int, - max_per_hour: int, + self, + user_id: str, + severity: NotificationSeverity, + window_hours: 
int, + max_per_hour: int, ) -> bool: """Check if notification should be throttled.""" key = f"{user_id}:{severity}" @@ -82,14 +70,11 @@ async def check_throttle( if key not in self._entries: self._entries[key] = [] - # Clean old entries self._entries[key] = [ts for ts in self._entries[key] if ts > window_start] - # Check limit if len(self._entries[key]) >= max_per_hour: return True - # Add new entry self._entries[key].append(now) return False @@ -105,157 +90,137 @@ class SystemConfig: throttle_exempt: bool -class NotificationService(LifecycleEnabled): +class NotificationService: + """Stateless notification service - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + Worker entrypoint handles the consume loop. + """ + def __init__( - self, - notification_repository: NotificationRepository, - event_service: KafkaEventService, - event_bus_manager: EventBusManager, - schema_registry_manager: SchemaRegistryManager, - sse_bus: SSERedisBus, - settings: Settings, - logger: logging.Logger, - notification_metrics: NotificationMetrics, - event_metrics: EventMetrics, + self, + notification_repository: NotificationRepository, + event_bus: EventBus, + sse_bus: SSERedisBus, + settings: Settings, + logger: logging.Logger, + notification_metrics: NotificationMetrics, ) -> None: - super().__init__() - self.repository = notification_repository - self.event_service = event_service - self.event_bus_manager = event_bus_manager - self.metrics = notification_metrics - self._event_metrics = event_metrics - self.settings = settings - self.schema_registry_manager = schema_registry_manager - self.sse_bus = sse_bus - self.logger = logger - - # State + self._repository = notification_repository + self._event_bus = event_bus + self._sse_bus = sse_bus + self._settings = settings + self._logger = logger + self._metrics = notification_metrics self._throttle_cache = ThrottleCache() - # Tasks - self._tasks: set[asyncio.Task[None]] = set() - - self._consumer: UnifiedConsumer | None = None - self._dispatcher: EventDispatcher | None = None - self._consumer_task: asyncio.Task[None] | None = None - - self.logger.info( - "NotificationService initialized", - extra={ - "repository": type(notification_repository).__name__, - "event_service": type(event_service).__name__, - "schema_registry": type(schema_registry_manager).__name__, - }, - ) - - # Channel handlers mapping - self._channel_handlers: dict[NotificationChannel, ChannelHandler] = { + self._channel_handlers: dict[NotificationChannel, object] = { NotificationChannel.IN_APP: self._send_in_app, NotificationChannel.WEBHOOK: self._send_webhook, NotificationChannel.SLACK: self._send_slack, } - async def _on_start(self) -> None: - """Start the notification service with Kafka consumer.""" - self.logger.info("Starting notification service...") - self._start_background_tasks() - await self._subscribe_to_events() - self.logger.info("Notification service started with Kafka consumer") - - async def _on_stop(self) -> None: - """Stop the notification service.""" - self.logger.info("Stopping notification service...") - - # Cancel all tasks - for task in self._tasks: - task.cancel() - - # Wait for cancellation - if self._tasks: - await asyncio.gather(*self._tasks, return_exceptions=True) - - # Stop consumer - if self._consumer: - await self._consumer.stop() - - # Clear cache - await self._throttle_cache.clear() - - self.logger.info("Notification service stopped") - - def _start_background_tasks(self) -> None: - """Start background 
processing tasks.""" - tasks = [ - asyncio.create_task(self._process_pending_notifications()), - asyncio.create_task(self._cleanup_old_notifications()), - ] - - for task in tasks: - self._tasks.add(task) - task.add_done_callback(self._tasks.discard) - - async def _subscribe_to_events(self) -> None: - """Subscribe to relevant events for notifications.""" - # Configure consumer for notification-relevant events - consumer_config = ConsumerConfig( - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=GroupId.NOTIFICATION_SERVICE, - max_poll_records=10, - enable_auto_commit=True, - auto_offset_reset="latest", - session_timeout_ms=self.settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, + self._logger.info("NotificationService initialized") + + async def handle_execution_event(self, event: DomainEvent) -> None: + """Handle execution result events. + + Called by worker entrypoint for each event. + """ + try: + if isinstance(event, ExecutionCompletedEvent): + await self._handle_execution_completed(event) + elif isinstance(event, ExecutionFailedEvent): + await self._handle_execution_failed(event) + elif isinstance(event, ExecutionTimeoutEvent): + await self._handle_execution_timeout(event) + else: + self._logger.warning(f"Unhandled execution event type: {event.event_type}") + except Exception as e: + self._logger.error(f"Error handling execution event: {e}", exc_info=True) + + async def _handle_execution_completed(self, event: ExecutionCompletedEvent) -> None: + """Handle execution completed event.""" + user_id = event.metadata.user_id + if not user_id: + self._logger.error("No user_id in event metadata") + return + + title = f"Execution Completed: {event.execution_id}" + duration = event.resource_usage.execution_time_wall_seconds if event.resource_usage else 0.0 + body = f"Your execution completed successfully. Duration: {duration:.2f}s." 
+ await self.create_notification( + user_id=user_id, + subject=title, + body=body, + severity=NotificationSeverity.MEDIUM, + tags=["execution", "completed", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], + metadata=event.model_dump( + exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} + ), ) - execution_results_topic = get_topic_for_event(EventType.EXECUTION_COMPLETED) - - # Log topics for debugging - self.logger.info(f"Notification service will subscribe to topics: {execution_results_topic}") - - # Create dispatcher and register handlers for specific event types - self._dispatcher = EventDispatcher(logger=self.logger) - # Use a single handler for execution result events (simpler and less brittle) - self._dispatcher.register_handler(EventType.EXECUTION_COMPLETED, self._handle_execution_event) - self._dispatcher.register_handler(EventType.EXECUTION_FAILED, self._handle_execution_event) - self._dispatcher.register_handler(EventType.EXECUTION_TIMEOUT, self._handle_execution_event) - - # Create consumer with dispatcher - self._consumer = UnifiedConsumer( - consumer_config, - event_dispatcher=self._dispatcher, - schema_registry=self.schema_registry_manager, - settings=self.settings, - logger=self.logger, - event_metrics=self._event_metrics, + async def _handle_execution_failed(self, event: ExecutionFailedEvent) -> None: + """Handle execution failed event.""" + user_id = event.metadata.user_id + if not user_id: + self._logger.error("No user_id in event metadata") + return + + event_data = event.model_dump( + exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} ) + event_data["stdout"] = event_data["stdout"][:200] + event_data["stderr"] = event_data["stderr"][:200] - # Start consumer - await self._consumer.start([execution_results_topic]) + title = f"Execution Failed: {event.execution_id}" + body = f"Your execution failed: {event.error_message}" + await self.create_notification( + user_id=user_id, + subject=title, + body=body, + severity=NotificationSeverity.HIGH, + tags=["execution", "failed", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], + metadata=event_data, + ) - # Start consumer task - self._consumer_task = asyncio.create_task(self._run_consumer()) - self._tasks.add(self._consumer_task) - self._consumer_task.add_done_callback(self._tasks.discard) + async def _handle_execution_timeout(self, event: ExecutionTimeoutEvent) -> None: + """Handle execution timeout event.""" + user_id = event.metadata.user_id + if not user_id: + self._logger.error("No user_id in event metadata") + return - self.logger.info("Notification service subscribed to execution events") + title = f"Execution Timeout: {event.execution_id}" + body = f"Your execution timed out after {event.timeout_seconds}s." 
+ await self.create_notification( + user_id=user_id, + subject=title, + body=body, + severity=NotificationSeverity.HIGH, + tags=["execution", "timeout", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], + metadata=event.model_dump( + exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} + ), + ) async def create_notification( - self, - user_id: str, - subject: str, - body: str, - tags: list[str], - severity: NotificationSeverity = NotificationSeverity.MEDIUM, - channel: NotificationChannel = NotificationChannel.IN_APP, - scheduled_for: datetime | None = None, - action_url: str | None = None, - metadata: NotificationContext | None = None, + self, + user_id: str, + subject: str, + body: str, + tags: list[str], + severity: NotificationSeverity = NotificationSeverity.MEDIUM, + channel: NotificationChannel = NotificationChannel.IN_APP, + scheduled_for: datetime | None = None, + action_url: str | None = None, + metadata: dict[str, object] | None = None, ) -> DomainNotification: + """Create a new notification.""" if not tags: raise NotificationValidationError("tags must be a non-empty list") - self.logger.info( + + self._logger.info( f"Creating notification for user {user_id}", extra={ "user_id": user_id, @@ -266,26 +231,24 @@ async def create_notification( }, ) - # Check throttling if await self._throttle_cache.check_throttle( - user_id, - severity, - window_hours=self.settings.NOTIF_THROTTLE_WINDOW_HOURS, - max_per_hour=self.settings.NOTIF_THROTTLE_MAX_PER_HOUR, + user_id, + severity, + window_hours=self._settings.NOTIF_THROTTLE_WINDOW_HOURS, + max_per_hour=self._settings.NOTIF_THROTTLE_MAX_PER_HOUR, ): error_msg = ( f"Notification rate limit exceeded for user {user_id}. " - f"Max {self.settings.NOTIF_THROTTLE_MAX_PER_HOUR} " - f"per {self.settings.NOTIF_THROTTLE_WINDOW_HOURS} hour(s)" + f"Max {self._settings.NOTIF_THROTTLE_MAX_PER_HOUR} " + f"per {self._settings.NOTIF_THROTTLE_WINDOW_HOURS} hour(s)" ) - self.logger.warning(error_msg) + self._logger.warning(error_msg) raise NotificationThrottledError( user_id, - self.settings.NOTIF_THROTTLE_MAX_PER_HOUR, - self.settings.NOTIF_THROTTLE_WINDOW_HOURS, + self._settings.NOTIF_THROTTLE_MAX_PER_HOUR, + self._settings.NOTIF_THROTTLE_WINDOW_HOURS, ) - # Create notification create_data = DomainNotificationCreate( user_id=user_id, channel=channel, @@ -298,26 +261,16 @@ async def create_notification( metadata=metadata or {}, ) - # Save to database - notification = await self.repository.create_notification(create_data) + notification = await self._repository.create_notification(create_data) - # Publish event - event_bus = await self.event_bus_manager.get_event_bus() - await event_bus.publish( - NotificationCreatedEvent( - notification_id=str(notification.notification_id), - user_id=user_id, - subject=subject, - body=body, - severity=severity, - tags=notification.tags, - channels=[channel], - metadata=EventMetadata( - service_name=self.settings.SERVICE_NAME, - service_version=self.settings.SERVICE_VERSION, - user_id=user_id, - ), - ) + await self._event_bus.publish( + "notifications.created", + { + "notification_id": str(notification.notification_id), + "user_id": user_id, + "severity": str(severity), + "tags": notification.tags, + }, ) await self._deliver_notification(notification) @@ -325,23 +278,21 @@ async def create_notification( return notification async def create_system_notification( - self, - title: str, - message: str, - severity: NotificationSeverity = NotificationSeverity.MEDIUM, - tags: list[str] | None = 
None, - metadata: dict[str, object] | None = None, - target_users: list[str] | None = None, - target_roles: list[UserRole] | None = None, - ) -> SystemNotificationStats: - """Create system notifications with streamlined control flow. - - Returns stats with totals and created/failed/throttled counts. - """ + self, + title: str, + message: str, + severity: NotificationSeverity = NotificationSeverity.MEDIUM, + tags: list[str] | None = None, + metadata: dict[str, object] | None = None, + target_users: list[str] | None = None, + target_roles: list[UserRole] | None = None, + ) -> dict[str, int]: + """Create system notifications with streamlined control flow.""" cfg = SystemConfig( - severity=severity, throttle_exempt=(severity in (NotificationSeverity.HIGH, NotificationSeverity.URGENT)) + severity=severity, + throttle_exempt=(severity in (NotificationSeverity.HIGH, NotificationSeverity.URGENT)), ) - base_context: NotificationContext = {"message": message, **(metadata or {})} + base_context: dict[str, object] = {"message": message, **(metadata or {})} users = await self._resolve_targets(target_users, target_roles) if not users: @@ -354,14 +305,16 @@ async def worker(uid: str) -> str: return await self._create_system_for_user(uid, cfg, title, base_context, tags or ["system"]) results = ( - [await worker(u) for u in users] if len(users) <= 20 else await asyncio.gather(*(worker(u) for u in users)) + [await worker(u) for u in users] + if len(users) <= 20 + else await asyncio.gather(*(worker(u) for u in users)) ) created = sum(1 for r in results if r == "created") throttled = sum(1 for r in results if r == "throttled") failed = sum(1 for r in results if r == "failed") - self.logger.info( + self._logger.info( "System notification completed", extra={ "severity": cfg.severity, @@ -376,31 +329,31 @@ async def worker(uid: str) -> str: return {"total_users": len(users), "created": created, "failed": failed, "throttled": throttled} async def _resolve_targets( - self, - target_users: list[str] | None, - target_roles: list[UserRole] | None, + self, + target_users: list[str] | None, + target_roles: list[UserRole] | None, ) -> list[str]: if target_users is not None: return target_users if target_roles: - return await self.repository.get_users_by_roles(target_roles) - return await self.repository.get_active_users(days=30) + return await self._repository.get_users_by_roles(target_roles) + return await self._repository.get_active_users(days=30) async def _create_system_for_user( - self, - user_id: str, - cfg: SystemConfig, - title: str, - base_context: NotificationContext, - tags: list[str], + self, + user_id: str, + cfg: SystemConfig, + title: str, + base_context: dict[str, object], + tags: list[str], ) -> str: try: if not cfg.throttle_exempt: throttled = await self._throttle_cache.check_throttle( user_id, cfg.severity, - window_hours=self.settings.NOTIF_THROTTLE_WINDOW_HOURS, - max_per_hour=self.settings.NOTIF_THROTTLE_MAX_PER_HOUR, + window_hours=self._settings.NOTIF_THROTTLE_WINDOW_HOURS, + max_per_hour=self._settings.NOTIF_THROTTLE_MAX_PER_HOUR, ) if throttled: return "throttled" @@ -416,27 +369,29 @@ async def _create_system_for_user( ) return "created" except Exception as e: - self.logger.error( - "Failed to create system notification for user", extra={"user_id": user_id, "error": str(e)} + self._logger.error( + "Failed to create system notification for user", + extra={"user_id": user_id, "error": str(e)}, ) return "failed" async def _send_in_app( - self, notification: DomainNotification, subscription: 
DomainNotificationSubscription + self, + notification: DomainNotification, + subscription: DomainNotificationSubscription, ) -> None: - """Send in-app notification via SSE bus (fan-out to connected clients).""" + """Send in-app notification via SSE bus.""" await self._publish_notification_sse(notification) async def _send_webhook( - self, notification: DomainNotification, subscription: DomainNotificationSubscription + self, + notification: DomainNotification, + subscription: DomainNotificationSubscription, ) -> None: """Send webhook notification.""" webhook_url = notification.webhook_url or subscription.webhook_url if not webhook_url: - raise ValueError( - f"No webhook URL configured for user {notification.user_id} on channel {notification.channel}. " - f"Configure in notification settings." - ) + raise ValueError(f"No webhook URL configured for user {notification.user_id}") payload = { "notification_id": str(notification.notification_id), @@ -453,15 +408,6 @@ async def _send_webhook( headers = notification.webhook_headers or {} headers["Content-Type"] = "application/json" - self.logger.debug( - f"Sending webhook notification to {webhook_url}", - extra={ - "notification_id": str(notification.notification_id), - "payload_size": len(str(payload)), - "webhook_url": webhook_url, - }, - ) - add_span_attributes( **{ "notification.id": str(notification.notification_id), @@ -472,25 +418,17 @@ async def _send_webhook( async with httpx.AsyncClient() as client: response = await client.post(webhook_url, json=payload, headers=headers, timeout=30.0) response.raise_for_status() - self.logger.debug( - "Webhook delivered successfully", - extra={ - "notification_id": str(notification.notification_id), - "status_code": response.status_code, - "response_time_ms": int(response.elapsed.total_seconds() * 1000), - }, - ) - async def _send_slack(self, notification: DomainNotification, subscription: DomainNotificationSubscription) -> None: + async def _send_slack( + self, + notification: DomainNotification, + subscription: DomainNotificationSubscription, + ) -> None: """Send Slack notification.""" if not subscription.slack_webhook: - raise ValueError( - f"No Slack webhook URL configured for user {notification.user_id}. " - f"Please configure Slack integration in notification settings." 
- ) + raise ValueError(f"No Slack webhook URL configured for user {notification.user_id}") - # Format message for Slack - slack_message: SlackMessage = { + slack_message: dict[str, object] = { "text": notification.subject, "attachments": [ { @@ -502,20 +440,12 @@ async def _send_slack(self, notification: DomainNotification, subscription: Doma ], } - # Add action button if URL provided if notification.action_url: attachments = slack_message.get("attachments", []) if attachments and isinstance(attachments, list): - attachments[0]["actions"] = [{"type": "button", "text": "View Details", "url": notification.action_url}] - - self.logger.debug( - "Sending Slack notification", - extra={ - "notification_id": str(notification.notification_id), - "has_action": notification.action_url is not None, - "priority_color": self._get_slack_color(notification.severity), - }, - ) + attachments[0]["actions"] = [ + {"type": "button", "text": "View Details", "url": notification.action_url} + ] add_span_attributes( **{ @@ -526,172 +456,170 @@ async def _send_slack(self, notification: DomainNotification, subscription: Doma async with httpx.AsyncClient() as client: response = await client.post(subscription.slack_webhook, json=slack_message, timeout=30.0) response.raise_for_status() - self.logger.debug( - "Slack notification delivered successfully", - extra={"notification_id": str(notification.notification_id), "status_code": response.status_code}, - ) def _get_slack_color(self, priority: NotificationSeverity) -> str: """Get Slack color based on severity.""" return { - NotificationSeverity.LOW: "#36a64f", # Green - NotificationSeverity.MEDIUM: "#ff9900", # Orange - NotificationSeverity.HIGH: "#ff0000", # Red - NotificationSeverity.URGENT: "#990000", # Dark Red - }.get(priority, "#808080") # Default gray - - async def _process_pending_notifications(self) -> None: - """Process pending notifications in background.""" - while self.is_running: - try: - # Find pending notifications - notifications = await self.repository.find_pending_notifications( - batch_size=self.settings.NOTIF_PENDING_BATCH_SIZE - ) + NotificationSeverity.LOW: "#36a64f", + NotificationSeverity.MEDIUM: "#ff9900", + NotificationSeverity.HIGH: "#ff0000", + NotificationSeverity.URGENT: "#990000", + }.get(priority, "#808080") - # Process each notification - for notification in notifications: - if not self.is_running: - break - await self._deliver_notification(notification) - - # Sleep between batches - await asyncio.sleep(5) - - except Exception as e: - self.logger.error(f"Error processing pending notifications: {e}") - await asyncio.sleep(10) - - async def _cleanup_old_notifications(self) -> None: - """Cleanup old notifications periodically.""" - while self.is_running: - try: - # Run cleanup once per day - await asyncio.sleep(86400) # 24 hours - - if not self.is_running: - break - - # Delete old notifications - deleted_count = await self.repository.cleanup_old_notifications(self.settings.NOTIF_OLD_DAYS) - - self.logger.info(f"Cleaned up {deleted_count} old notifications") - - except Exception as e: - self.logger.error(f"Error cleaning up old notifications: {e}") - - async def _run_consumer(self) -> None: - """Run the event consumer loop.""" - while self.is_running: - try: - # Consumer handles polling internally - await asyncio.sleep(1) - except asyncio.CancelledError: - self.logger.info("Notification consumer task cancelled") - break - except Exception as e: - self.logger.error(f"Error in notification consumer loop: {e}") - await asyncio.sleep(5) - - 
async def _handle_execution_timeout_typed(self, event: ExecutionTimeoutEvent) -> None: - """Handle typed execution timeout event.""" - user_id = event.metadata.user_id - if not user_id: - self.logger.error("No user_id in event metadata") - return + async def process_pending_notifications(self, batch_size: int = 10) -> int: + """Process pending notifications. - title = f"Execution Timeout: {event.execution_id}" - body = f"Your execution timed out after {event.timeout_seconds}s." - await self.create_notification( - user_id=user_id, - subject=title, - body=body, - severity=NotificationSeverity.HIGH, - tags=["execution", "timeout", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], - metadata=event.model_dump( - exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} - ), - ) + Should be called periodically from worker entrypoint. + Returns number of notifications processed. + """ + notifications = await self._repository.find_pending_notifications(batch_size=batch_size) + count = 0 - async def _handle_execution_completed_typed(self, event: ExecutionCompletedEvent) -> None: - """Handle typed execution completed event.""" - user_id = event.metadata.user_id - if not user_id: - self.logger.error("No user_id in event metadata") + for notification in notifications: + await self._deliver_notification(notification) + count += 1 + + return count + + async def cleanup_old_notifications(self, days: int = 30) -> int: + """Cleanup old notifications. + + Should be called periodically from worker entrypoint. + Returns number of notifications deleted. + """ + return await self._repository.cleanup_old_notifications(days) + + async def _should_skip_notification( + self, + notification: DomainNotification, + subscription: DomainNotificationSubscription, + ) -> str | None: + """Check if notification should be skipped based on subscription filters.""" + if not subscription.enabled: + return f"User {notification.user_id} has {notification.channel} disabled" + + if subscription.severities and notification.severity not in subscription.severities: + return f"Notification severity '{notification.severity}' filtered by user preferences" + + if subscription.include_tags and not any( + tag in subscription.include_tags for tag in (notification.tags or []) + ): + return f"Notification tags {notification.tags} not in include list" + + if subscription.exclude_tags and any( + tag in subscription.exclude_tags for tag in (notification.tags or []) + ): + return f"Notification tags {notification.tags} excluded by preferences" + + return None + + async def _deliver_notification(self, notification: DomainNotification) -> None: + """Deliver notification through configured channel.""" + claimed = await self._repository.try_claim_pending(notification.notification_id) + if not claimed: return - title = f"Execution Completed: {event.execution_id}" - duration = event.resource_usage.execution_time_wall_seconds if event.resource_usage else 0.0 - body = f"Your execution completed successfully. Duration: {duration:.2f}s." 
- await self.create_notification( - user_id=user_id, - subject=title, - body=body, - severity=NotificationSeverity.MEDIUM, - tags=["execution", "completed", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], - metadata=event.model_dump( - exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} - ), + self._logger.info( + f"Delivering notification {notification.notification_id}", + extra={ + "notification_id": str(notification.notification_id), + "user_id": notification.user_id, + "channel": notification.channel, + "severity": notification.severity, + "tags": list(notification.tags or []), + }, ) - async def _handle_execution_event(self, event: DomainEvent) -> None: - """Unified handler for execution result events.""" - try: - if isinstance(event, ExecutionCompletedEvent): - await self._handle_execution_completed_typed(event) - elif isinstance(event, ExecutionFailedEvent): - await self._handle_execution_failed_typed(event) - elif isinstance(event, ExecutionTimeoutEvent): - await self._handle_execution_timeout_typed(event) - else: - self.logger.warning(f"Unhandled execution event type: {event.event_type}") - except Exception as e: - self.logger.error(f"Error handling execution event: {e}", exc_info=True) + subscription = await self._repository.get_subscription(notification.user_id, notification.channel) - async def _handle_execution_failed_typed(self, event: ExecutionFailedEvent) -> None: - """Handle typed execution failed event.""" - user_id = event.metadata.user_id - if not user_id: - self.logger.error("No user_id in event metadata") + skip_reason = await self._should_skip_notification(notification, subscription) + if skip_reason: + self._logger.info(skip_reason) + await self._repository.update_notification( + notification.notification_id, + notification.user_id, + DomainNotificationUpdate(status=NotificationStatus.SKIPPED, error_message=skip_reason), + ) return - # Use model_dump to get all event data - event_data = event.model_dump( - exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} - ) + start_time = asyncio.get_running_loop().time() + try: + handler = self._channel_handlers.get(notification.channel) + if handler is None: + raise ValueError(f"No handler configured for channel: {notification.channel}") - # Truncate stdout/stderr for notification context - event_data["stdout"] = event_data["stdout"][:200] - event_data["stderr"] = event_data["stderr"][:200] + await handler(notification, subscription) # type: ignore + delivery_time = asyncio.get_running_loop().time() - start_time - title = f"Execution Failed: {event.execution_id}" - body = f"Your execution failed: {event.error_message}" - await self.create_notification( - user_id=user_id, - subject=title, - body=body, - severity=NotificationSeverity.HIGH, - tags=["execution", "failed", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], - metadata=event_data, + await self._repository.update_notification( + notification.notification_id, + notification.user_id, + DomainNotificationUpdate(status=NotificationStatus.DELIVERED, delivered_at=datetime.now(UTC)), + ) + + self._logger.info( + f"Successfully delivered notification {notification.notification_id}", + extra={ + "notification_id": str(notification.notification_id), + "channel": notification.channel, + "delivery_time_ms": int(delivery_time * 1000), + }, + ) + + self._metrics.record_notification_sent( + notification.severity, channel=notification.channel, severity=notification.severity + ) + 
self._metrics.record_notification_delivery_time(delivery_time, notification.severity) + + except Exception as e: + self._logger.error( + f"Failed to deliver notification {notification.notification_id}: {str(e)}", + exc_info=True, + ) + + new_retry_count = notification.retry_count + 1 + error_message = f"Delivery failed via {notification.channel}: {str(e)}" + failed_at = datetime.now(UTC) + + notif_status = NotificationStatus.PENDING \ + if new_retry_count < notification.max_retries else NotificationStatus.FAILED + await self._repository.update_notification( + notification.notification_id, + notification.user_id, + DomainNotificationUpdate( + status=notif_status, + failed_at=failed_at, + error_message=error_message, + retry_count=new_retry_count, + ), + ) + + async def _publish_notification_sse(self, notification: DomainNotification) -> None: + """Publish an in-app notification to the SSE bus.""" + message = RedisNotificationMessage( + notification_id=notification.notification_id, + severity=notification.severity, + status=notification.status, + tags=list(notification.tags or []), + subject=notification.subject, + body=notification.body, + action_url=notification.action_url or "", + created_at=notification.created_at, ) + await self._sse_bus.publish_notification(notification.user_id, message) async def mark_as_read(self, user_id: str, notification_id: str) -> bool: """Mark notification as read.""" - success = await self.repository.mark_as_read(notification_id, user_id) + success = await self._repository.mark_as_read(notification_id, user_id) - event_bus = await self.event_bus_manager.get_event_bus() if success: - await event_bus.publish( - NotificationReadEvent( - notification_id=str(notification_id), - user_id=user_id, - read_at=datetime.now(UTC), - metadata=EventMetadata( - service_name=self.settings.SERVICE_NAME, - service_version=self.settings.SERVICE_VERSION, - user_id=user_id, - ), - ) + await self._event_bus.publish( + "notifications.read", + { + "notification_id": str(notification_id), + "user_id": user_id, + "read_at": datetime.now(UTC).isoformat(), + }, ) else: raise NotificationNotFoundError(notification_id) @@ -700,21 +628,20 @@ async def mark_as_read(self, user_id: str, notification_id: str) -> bool: async def get_unread_count(self, user_id: str) -> int: """Get count of unread notifications.""" - return await self.repository.get_unread_count(user_id) + return await self._repository.get_unread_count(user_id) async def list_notifications( - self, - user_id: str, - status: NotificationStatus | None = None, - limit: int = 20, - offset: int = 0, - include_tags: list[str] | None = None, - exclude_tags: list[str] | None = None, - tag_prefix: str | None = None, + self, + user_id: str, + status: NotificationStatus | None = None, + limit: int = 20, + offset: int = 0, + include_tags: list[str] | None = None, + exclude_tags: list[str] | None = None, + tag_prefix: str | None = None, ) -> DomainNotificationListResult: """List notifications with pagination.""" - # Get notifications - notifications = await self.repository.list_notifications( + notifications = await self._repository.list_notifications( user_id=user_id, status=status, skip=offset, @@ -724,9 +651,8 @@ async def list_notifications( tag_prefix=tag_prefix, ) - # Get counts total, unread_count = await asyncio.gather( - self.repository.count_notifications( + self._repository.count_notifications( user_id=user_id, status=status, include_tags=include_tags, @@ -736,21 +662,24 @@ async def list_notifications( 
self.get_unread_count(user_id), ) - return DomainNotificationListResult(notifications=notifications, total=total, unread_count=unread_count) + return DomainNotificationListResult( + notifications=notifications, + total=total, + unread_count=unread_count, + ) async def update_subscription( - self, - user_id: str, - channel: NotificationChannel, - enabled: bool, - webhook_url: str | None = None, - slack_webhook: str | None = None, - severities: list[NotificationSeverity] | None = None, - include_tags: list[str] | None = None, - exclude_tags: list[str] | None = None, + self, + user_id: str, + channel: NotificationChannel, + enabled: bool, + webhook_url: str | None = None, + slack_webhook: str | None = None, + severities: list[NotificationSeverity] | None = None, + include_tags: list[str] | None = None, + exclude_tags: list[str] | None = None, ) -> DomainNotificationSubscription: """Update notification subscription preferences.""" - # Validate channel-specific requirements if channel == NotificationChannel.WEBHOOK and enabled: if not webhook_url: raise NotificationValidationError("webhook_url is required when enabling WEBHOOK") @@ -762,7 +691,6 @@ async def update_subscription( if not slack_webhook.startswith("https://hooks.slack.com/"): raise NotificationValidationError("slack_webhook must be a valid Slack webhook URL") - # Build update data update_data = DomainSubscriptionUpdate( enabled=enabled, webhook_url=webhook_url, @@ -772,193 +700,27 @@ async def update_subscription( exclude_tags=exclude_tags, ) - return await self.repository.upsert_subscription(user_id, channel, update_data) + return await self._repository.upsert_subscription(user_id, channel, update_data) async def mark_all_as_read(self, user_id: str) -> int: """Mark all notifications as read for a user.""" - count = await self.repository.mark_all_as_read(user_id) + count = await self._repository.mark_all_as_read(user_id) - event_bus = await self.event_bus_manager.get_event_bus() if count > 0: - await event_bus.publish( - NotificationAllReadEvent( - user_id=user_id, - count=count, - read_at=datetime.now(UTC), - metadata=EventMetadata( - service_name=self.settings.SERVICE_NAME, - service_version=self.settings.SERVICE_VERSION, - user_id=user_id, - ), - ) + await self._event_bus.publish( + "notifications.all_read", + {"user_id": user_id, "count": count, "read_at": datetime.now(UTC).isoformat()}, ) return count async def get_subscriptions(self, user_id: str) -> dict[NotificationChannel, DomainNotificationSubscription]: """Get all notification subscriptions for a user.""" - return await self.repository.get_all_subscriptions(user_id) + return await self._repository.get_all_subscriptions(user_id) async def delete_notification(self, user_id: str, notification_id: str) -> bool: """Delete a notification.""" - deleted = await self.repository.delete_notification(str(notification_id), user_id) + deleted = await self._repository.delete_notification(str(notification_id), user_id) if not deleted: raise NotificationNotFoundError(notification_id) return deleted - - async def _publish_notification_sse(self, notification: DomainNotification) -> None: - """Publish an in-app notification to the SSE bus for realtime delivery.""" - message = RedisNotificationMessage( - notification_id=notification.notification_id, - severity=notification.severity, - status=notification.status, - tags=list(notification.tags or []), - subject=notification.subject, - body=notification.body, - action_url=notification.action_url or "", - created_at=notification.created_at, - ) 
- await self.sse_bus.publish_notification(notification.user_id, message) - - async def _should_skip_notification( - self, notification: DomainNotification, subscription: DomainNotificationSubscription - ) -> str | None: - """Check if notification should be skipped based on subscription filters. - - Returns skip reason if should skip, None otherwise. - """ - if not subscription.enabled: - return f"User {notification.user_id} has {notification.channel} disabled; skipping delivery." - - if subscription.severities and notification.severity not in subscription.severities: - return ( - f"Notification severity '{notification.severity}' filtered by user preferences " - f"for {notification.channel}" - ) - - if subscription.include_tags and not any(tag in subscription.include_tags for tag in (notification.tags or [])): - return f"Notification tags {notification.tags} not in include list for {notification.channel}" - - if subscription.exclude_tags and any(tag in subscription.exclude_tags for tag in (notification.tags or [])): - return f"Notification tags {notification.tags} excluded by preferences for {notification.channel}" - - return None - - async def _deliver_notification(self, notification: DomainNotification) -> None: - """Deliver notification through configured channel using safe state transitions.""" - # Attempt to claim this notification for sending - claimed = await self.repository.try_claim_pending(notification.notification_id) - if not claimed: - return - - self.logger.info( - f"Delivering notification {notification.notification_id}", - extra={ - "notification_id": str(notification.notification_id), - "user_id": notification.user_id, - "channel": notification.channel, - "severity": notification.severity, - "tags": list(notification.tags or []), - }, - ) - - # Check user subscription for the channel - subscription = await self.repository.get_subscription(notification.user_id, notification.channel) - - # Check if notification should be skipped - skip_reason = await self._should_skip_notification(notification, subscription) - if skip_reason: - self.logger.info(skip_reason) - await self.repository.update_notification( - notification.notification_id, - notification.user_id, - DomainNotificationUpdate(status=NotificationStatus.SKIPPED, error_message=skip_reason), - ) - return - - # Send through channel - start_time = asyncio.get_running_loop().time() - try: - handler = self._channel_handlers.get(notification.channel) - if handler is None: - raise ValueError( - f"No handler configured for notification channel: {notification.channel}. 
" - f"Available channels: {list(self._channel_handlers.keys())}" - ) - - self.logger.debug(f"Using handler {handler.__name__} for channel {notification.channel}") - await handler(notification, subscription) - delivery_time = asyncio.get_running_loop().time() - start_time - - # Mark delivered - await self.repository.update_notification( - notification.notification_id, - notification.user_id, - DomainNotificationUpdate(status=NotificationStatus.DELIVERED, delivered_at=datetime.now(UTC)), - ) - - self.logger.info( - f"Successfully delivered notification {notification.notification_id}", - extra={ - "notification_id": str(notification.notification_id), - "channel": notification.channel, - "delivery_time_ms": int(delivery_time * 1000), - }, - ) - - # Metrics (use tag string or severity) - self.metrics.record_notification_sent( - notification.severity, channel=notification.channel, severity=notification.severity - ) - self.metrics.record_notification_delivery_time(delivery_time, notification.severity) - - except Exception as e: - error_details = { - "notification_id": str(notification.notification_id), - "channel": notification.channel, - "error_type": type(e).__name__, - "error_message": str(e), - "retry_count": notification.retry_count, - "max_retries": notification.max_retries, - } - - self.logger.error( - f"Failed to deliver notification {notification.notification_id}: {str(e)}", - extra=error_details, - exc_info=True, - ) - - new_retry_count = notification.retry_count + 1 - error_message = f"Delivery failed via {notification.channel}: {str(e)}" - failed_at = datetime.now(UTC) - - # Schedule retry if under limit - if new_retry_count < notification.max_retries: - retry_time = datetime.now(UTC) + timedelta(minutes=self.settings.NOTIF_RETRY_DELAY_MINUTES) - self.logger.info( - f"Scheduled retry {new_retry_count}/{notification.max_retries} for {notification.notification_id}", - extra={"retry_at": retry_time.isoformat()}, - ) - # Will be retried - keep as PENDING but with scheduled_for - # Note: scheduled_for not in DomainNotificationUpdate, so we update status fields only - await self.repository.update_notification( - notification.notification_id, - notification.user_id, - DomainNotificationUpdate( - status=NotificationStatus.PENDING, - failed_at=failed_at, - error_message=error_message, - retry_count=new_retry_count, - ), - ) - else: - await self.repository.update_notification( - notification.notification_id, - notification.user_id, - DomainNotificationUpdate( - status=NotificationStatus.FAILED, - failed_at=failed_at, - error_message=error_message, - retry_count=new_retry_count, - ), - ) diff --git a/backend/app/services/pod_monitor/monitor.py b/backend/app/services/pod_monitor/monitor.py index ecbb4556..046cbee9 100644 --- a/backend/app/services/pod_monitor/monitor.py +++ b/backend/app/services/pod_monitor/monitor.py @@ -1,35 +1,28 @@ +"""Pod Monitor - stateless event handler. + +Monitors Kubernetes pods and publishes lifecycle events. Receives events, +processes them, and publishes results. No lifecycle management. +All state is stored in Redis via PodStateRepository. 
+""" + +from __future__ import annotations + import asyncio import logging import time -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager from dataclasses import dataclass -from enum import auto from typing import Any from kubernetes import client as k8s_client -from kubernetes.client.rest import ApiException -from app.core.k8s_clients import K8sClients, close_k8s_clients, create_k8s_clients -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import KubernetesMetrics from app.core.utils import StringEnum +from app.db.repositories.pod_state_repository import PodStateRepository from app.domain.events.typed import DomainEvent -from app.services.kafka_event_service import KafkaEventService +from app.events.core import UnifiedProducer from app.services.pod_monitor.config import PodMonitorConfig from app.services.pod_monitor.event_mapper import PodEventMapper -# Type aliases -type PodName = str -type ResourceVersion = str -type EventType = str -type KubeEvent = dict[str, Any] -type StatusDict = dict[str, Any] - -# Constants -MAX_BACKOFF_SECONDS: int = 300 # 5 minutes -RECONCILIATION_LOG_INTERVAL: int = 60 # 1 minute - class WatchEventType(StringEnum): """Kubernetes watch event types.""" @@ -39,33 +32,13 @@ class WatchEventType(StringEnum): DELETED = "DELETED" -class MonitorState(StringEnum): - """Pod monitor states.""" - - IDLE = auto() - RUNNING = auto() - STOPPING = auto() - STOPPED = auto() - - class ErrorType(StringEnum): """Error types for metrics.""" - RESOURCE_VERSION_EXPIRED = auto() - API_ERROR = auto() - UNEXPECTED = auto() - PROCESSING_ERROR = auto() - - -@dataclass(frozen=True, slots=True) -class WatchContext: - """Immutable context for watch operations.""" - - namespace: str - label_selector: str - field_selector: str | None - timeout_seconds: int - resource_version: ResourceVersion | None + RESOURCE_VERSION_EXPIRED = "resource_version_expired" + API_ERROR = "api_error" + UNEXPECTED = "unexpected" + PROCESSING_ERROR = "processing_error" @dataclass(frozen=True, slots=True) @@ -74,206 +47,70 @@ class PodEvent: event_type: WatchEventType pod: k8s_client.V1Pod - resource_version: ResourceVersion | None + resource_version: str | None @dataclass(frozen=True, slots=True) class ReconciliationResult: """Result of state reconciliation.""" - missing_pods: set[PodName] - extra_pods: set[PodName] + missing_pods: set[str] + extra_pods: set[str] duration_seconds: float success: bool error: str | None = None -class PodMonitor(LifecycleEnabled): - """ - Monitors Kubernetes pods and publishes lifecycle events. +class PodMonitor: + """Stateless pod monitor - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + All state (tracked pods, resource version) stored in Redis via PodStateRepository. - This service watches pods with specific labels using the K8s watch API, - maps Kubernetes events to application events, and publishes them to Kafka. - Events are stored in the events collection AND published to Kafka via KafkaEventService. 
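+
+    The worker entrypoint drives two loops: the pod watch loop (below) and, when
+    config.enable_state_reconciliation is set, a periodic reconciliation pass.
+    A rough sketch of the latter (monitor and config here stand for the
+    DI-provided PodMonitor and PodMonitorConfig; the real loop lives in the
+    worker entrypoint, so treat this as illustrative only):
+        if config.enable_state_reconciliation:
+            while True:
+                await asyncio.sleep(config.reconcile_interval_seconds)
+                await monitor.reconcile_state()
+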
+ Worker entrypoint handles the watch loop: + watch = Watch() + for event in watch.stream(...): + await monitor.handle_raw_event(event) """ def __init__( self, config: PodMonitorConfig, - kafka_event_service: KafkaEventService, - logger: logging.Logger, - k8s_clients: K8sClients, + producer: UnifiedProducer, + pod_state_repo: PodStateRepository, + v1_client: k8s_client.CoreV1Api, event_mapper: PodEventMapper, + logger: logging.Logger, kubernetes_metrics: KubernetesMetrics, ) -> None: - """Initialize the pod monitor with all required dependencies. - - All dependencies must be provided - use create_pod_monitor() factory - for automatic dependency creation in production. - """ - super().__init__() - self.logger = logger - self.config = config - - # Kubernetes clients (required, no nullability) - self._clients = k8s_clients - self._v1 = k8s_clients.v1 - self._watch = k8s_clients.watch - - # Components (required, no nullability) + self._config = config + self._producer = producer + self._pod_state_repo = pod_state_repo + self._v1 = v1_client self._event_mapper = event_mapper - self._kafka_event_service = kafka_event_service - - # State - self._state = MonitorState.IDLE - self._tracked_pods: set[PodName] = set() - self._reconnect_attempts: int = 0 - self._last_resource_version: ResourceVersion | None = None - - # Tasks - self._watch_task: asyncio.Task[None] | None = None - self._reconcile_task: asyncio.Task[None] | None = None - - # Metrics + self._logger = logger self._metrics = kubernetes_metrics - @property - def state(self) -> MonitorState: - """Get current monitor state.""" - return self._state - - async def _on_start(self) -> None: - """Start the pod monitor.""" - self.logger.info("Starting PodMonitor service...") - - # Verify K8s connectivity (all clients already injected via __init__) - await asyncio.to_thread(self._v1.get_api_resources) - self.logger.info("Successfully connected to Kubernetes API") - - # Start monitoring - self._state = MonitorState.RUNNING - self._watch_task = asyncio.create_task(self._watch_pods()) - - # Start reconciliation if enabled - if self.config.enable_state_reconciliation: - self._reconcile_task = asyncio.create_task(self._reconciliation_loop()) - - self.logger.info("PodMonitor service started successfully") - - async def _on_stop(self) -> None: - """Stop the pod monitor.""" - self.logger.info("Stopping PodMonitor service...") - self._state = MonitorState.STOPPING - - # Cancel tasks - tasks = [t for t in [self._watch_task, self._reconcile_task] if t] - for task in tasks: - task.cancel() - - # Wait for cancellation - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - # Close watch - if self._watch: - self._watch.stop() - - # Clear state - self._tracked_pods.clear() - self._event_mapper.clear_cache() - - self._state = MonitorState.STOPPED - self.logger.info("PodMonitor service stopped") - - async def _watch_pods(self) -> None: - """Main watch loop for pods.""" - while self._state == MonitorState.RUNNING: - try: - self._reconnect_attempts = 0 - await self._watch_pod_events() - - except ApiException as e: - match e.status: - case 410: # Gone - resource version too old - self.logger.warning("Resource version expired, resetting watch") - self._last_resource_version = None - self._metrics.record_pod_monitor_watch_error(ErrorType.RESOURCE_VERSION_EXPIRED) - case _: - self.logger.error(f"API error in watch: {e}") - self._metrics.record_pod_monitor_watch_error(ErrorType.API_ERROR) - - await self._handle_watch_error() - - except Exception as e: - 
self.logger.error(f"Unexpected error in watch: {e}", exc_info=True) - self._metrics.record_pod_monitor_watch_error(ErrorType.UNEXPECTED) - await self._handle_watch_error() - - async def _watch_pod_events(self) -> None: - """Watch for pod events.""" - # self._v1 and self._watch are guaranteed initialized by start() - - context = WatchContext( - namespace=self.config.namespace, - label_selector=self.config.label_selector, - field_selector=self.config.field_selector, - timeout_seconds=self.config.watch_timeout_seconds, - resource_version=self._last_resource_version, - ) - - self.logger.info(f"Starting pod watch with selector: {context.label_selector}, namespace: {context.namespace}") - - # Create watch stream - kwargs = { - "namespace": context.namespace, - "label_selector": context.label_selector, - "timeout_seconds": context.timeout_seconds, - } - - if context.field_selector: - kwargs["field_selector"] = context.field_selector - - if context.resource_version: - kwargs["resource_version"] = context.resource_version - - # Watch stream (clients guaranteed by __init__) - stream = self._watch.stream(self._v1.list_namespaced_pod, **kwargs) + async def handle_raw_event(self, raw_event: dict[str, Any]) -> None: + """Process a raw Kubernetes watch event. + Called by worker entrypoint for each event from watch stream. + """ try: - for event in stream: - if self._state != MonitorState.RUNNING: - break - - await self._process_raw_event(event) - - finally: - # Store resource version for next watch - self._update_resource_version(stream) - - def _update_resource_version(self, stream: Any) -> None: - """Update last resource version from stream.""" - try: - if stream._stop_event and stream._stop_event.resource_version: - self._last_resource_version = stream._stop_event.resource_version - except AttributeError: - pass - - async def _process_raw_event(self, raw_event: KubeEvent) -> None: - """Process a raw Kubernetes watch event.""" - try: - # Parse event event = PodEvent( event_type=WatchEventType(raw_event["type"].upper()), pod=raw_event["object"], resource_version=( - raw_event["object"].metadata.resource_version if raw_event["object"].metadata else None + raw_event["object"].metadata.resource_version + if raw_event["object"].metadata + else None ), ) await self._process_pod_event(event) except (KeyError, ValueError) as e: - self.logger.error(f"Invalid event format: {e}") + self._logger.error(f"Invalid event format: {e}") self._metrics.record_pod_monitor_watch_error(ErrorType.PROCESSING_ERROR) async def _process_pod_event(self, event: PodEvent) -> None: @@ -281,25 +118,38 @@ async def _process_pod_event(self, event: PodEvent) -> None: start_time = time.time() try: - # Update resource version + # Update resource version in Redis if event.resource_version: - self._last_resource_version = event.resource_version + await self._pod_state_repo.set_resource_version(event.resource_version) # Skip ignored phases pod_phase = event.pod.status.phase if event.pod.status else None - if pod_phase in self.config.ignored_pod_phases: + if pod_phase in self._config.ignored_pod_phases: return - # Update tracked pods + # Get pod info pod_name = event.pod.metadata.name + execution_id = ( + event.pod.metadata.labels.get("execution-id") + if event.pod.metadata and event.pod.metadata.labels + else None + ) + + # Update tracked pods in Redis match event.event_type: case WatchEventType.ADDED | WatchEventType.MODIFIED: - self._tracked_pods.add(pod_name) + if execution_id: + await self._pod_state_repo.track_pod( + pod_name=pod_name, + 
execution_id=execution_id, + status=pod_phase or "Unknown", + ) case WatchEventType.DELETED: - self._tracked_pods.discard(pod_name) + await self._pod_state_repo.untrack_pod(pod_name) # Update metrics - self._metrics.update_pod_monitor_pods_watched(len(self._tracked_pods)) + tracked_count = await self._pod_state_repo.get_tracked_pods_count() + self._metrics.update_pod_monitor_pods_watched(tracked_count) # Map to application events app_events = self._event_mapper.map_pod_event(event.pod, event.event_type) @@ -310,7 +160,7 @@ async def _process_pod_event(self, event: PodEvent) -> None: # Log event if app_events: - self.logger.info( + self._logger.info( f"Processed {event.event_type} event for pod {pod_name} " f"(phase: {pod_phase or 'Unknown'}), " f"published {len(app_events)} events" @@ -321,11 +171,11 @@ async def _process_pod_event(self, event: PodEvent) -> None: self._metrics.record_pod_monitor_event_processing_duration(duration, event.event_type) except Exception as e: - self.logger.error(f"Error processing pod event: {e}", exc_info=True) + self._logger.error(f"Error processing pod event: {e}", exc_info=True) self._metrics.record_pod_monitor_watch_error(ErrorType.PROCESSING_ERROR) async def _publish_event(self, event: DomainEvent, pod: k8s_client.V1Pod) -> None: - """Publish event to Kafka and store in events collection.""" + """Publish event to Kafka.""" try: # Add correlation ID from pod labels if pod.metadata and pod.metadata.labels: @@ -334,94 +184,74 @@ async def _publish_event(self, event: DomainEvent, pod: k8s_client.V1Pod) -> Non execution_id = getattr(event, "execution_id", None) or event.aggregate_id key = str(execution_id or (pod.metadata.name if pod.metadata else "unknown")) - await self._kafka_event_service.publish_domain_event(event=event, key=key) + await self._producer.produce(event_to_produce=event, key=key) phase = pod.status.phase if pod.status else "Unknown" self._metrics.record_pod_monitor_event_published(event.event_type, phase) except Exception as e: - self.logger.error(f"Error publishing event: {e}", exc_info=True) - - async def _handle_watch_error(self) -> None: - """Handle watch errors with exponential backoff.""" - self._reconnect_attempts += 1 - - if self._reconnect_attempts > self.config.max_reconnect_attempts: - self.logger.error( - f"Max reconnect attempts ({self.config.max_reconnect_attempts}) exceeded, stopping pod monitor" - ) - self._state = MonitorState.STOPPING - return - - # Calculate exponential backoff - backoff = min(self.config.watch_reconnect_delay * (2 ** (self._reconnect_attempts - 1)), MAX_BACKOFF_SECONDS) - - self.logger.info( - f"Reconnecting watch in {backoff}s " - f"(attempt {self._reconnect_attempts}/{self.config.max_reconnect_attempts})" - ) - - self._metrics.increment_pod_monitor_watch_reconnects() - await asyncio.sleep(backoff) - - async def _reconciliation_loop(self) -> None: - """Periodically reconcile state with Kubernetes.""" - while self._state == MonitorState.RUNNING: - try: - await asyncio.sleep(self.config.reconcile_interval_seconds) - - if self._state == MonitorState.RUNNING: - result = await self._reconcile_state() - self._log_reconciliation_result(result) + self._logger.error(f"Error publishing event: {e}", exc_info=True) - except Exception as e: - self.logger.error(f"Error in reconciliation loop: {e}", exc_info=True) + async def reconcile_state(self) -> ReconciliationResult: + """Reconcile tracked pods with actual Kubernetes state. 
- async def _reconcile_state(self) -> ReconciliationResult: - """Reconcile tracked pods with actual state.""" + Should be called periodically from worker entrypoint if reconciliation + is enabled in config. + """ start_time = time.time() try: - self.logger.info("Starting pod state reconciliation") + self._logger.info("Starting pod state reconciliation") - # List all pods matching selector (clients guaranteed by __init__) + # List all pods matching selector pods = await asyncio.to_thread( - self._v1.list_namespaced_pod, namespace=self.config.namespace, label_selector=self.config.label_selector + self._v1.list_namespaced_pod, + namespace=self._config.namespace, + label_selector=self._config.label_selector, ) - # Get current pod names + # Get current pod names from K8s current_pods = {pod.metadata.name for pod in pods.items} + # Get tracked pods from Redis + tracked_pods = await self._pod_state_repo.get_tracked_pod_names() + # Find differences - missing_pods = current_pods - self._tracked_pods - extra_pods = self._tracked_pods - current_pods + missing_pods = current_pods - tracked_pods + extra_pods = tracked_pods - current_pods - # Process missing pods + # Process missing pods (add them to tracking) for pod in pods.items: if pod.metadata.name in missing_pods: - self.logger.info(f"Reconciling missing pod: {pod.metadata.name}") + self._logger.info(f"Reconciling missing pod: {pod.metadata.name}") event = PodEvent( - event_type=WatchEventType.ADDED, pod=pod, resource_version=pod.metadata.resource_version + event_type=WatchEventType.ADDED, + pod=pod, + resource_version=pod.metadata.resource_version, ) await self._process_pod_event(event) - # Remove extra pods + # Remove stale pods from Redis for pod_name in extra_pods: - self.logger.info(f"Removing stale pod from tracking: {pod_name}") - self._tracked_pods.discard(pod_name) + self._logger.info(f"Removing stale pod from tracking: {pod_name}") + await self._pod_state_repo.untrack_pod(pod_name) # Update metrics - self._metrics.update_pod_monitor_pods_watched(len(self._tracked_pods)) + tracked_count = await self._pod_state_repo.get_tracked_pods_count() + self._metrics.update_pod_monitor_pods_watched(tracked_count) self._metrics.record_pod_monitor_reconciliation_run("success") duration = time.time() - start_time return ReconciliationResult( - missing_pods=missing_pods, extra_pods=extra_pods, duration_seconds=duration, success=True + missing_pods=missing_pods, + extra_pods=extra_pods, + duration_seconds=duration, + success=True, ) except Exception as e: - self.logger.error(f"Failed to reconcile state: {e}", exc_info=True) + self._logger.error(f"Failed to reconcile state: {e}", exc_info=True) self._metrics.record_pod_monitor_reconciliation_run("failed") return ReconciliationResult( @@ -431,74 +261,3 @@ async def _reconcile_state(self) -> ReconciliationResult: success=False, error=str(e), ) - - def _log_reconciliation_result(self, result: ReconciliationResult) -> None: - """Log reconciliation result.""" - if result.success: - self.logger.info( - f"Reconciliation completed in {result.duration_seconds:.2f}s. 
" - f"Found {len(result.missing_pods)} missing, " - f"{len(result.extra_pods)} extra pods" - ) - else: - self.logger.error(f"Reconciliation failed after {result.duration_seconds:.2f}s: {result.error}") - - async def get_status(self) -> StatusDict: - """Get monitor status.""" - return { - "state": self._state, - "tracked_pods": len(self._tracked_pods), - "reconnect_attempts": self._reconnect_attempts, - "last_resource_version": self._last_resource_version, - "config": { - "namespace": self.config.namespace, - "label_selector": self.config.label_selector, - "enable_reconciliation": self.config.enable_state_reconciliation, - }, - } - - -@asynccontextmanager -async def create_pod_monitor( - config: PodMonitorConfig, - kafka_event_service: KafkaEventService, - logger: logging.Logger, - kubernetes_metrics: KubernetesMetrics, - k8s_clients: K8sClients | None = None, - event_mapper: PodEventMapper | None = None, -) -> AsyncIterator[PodMonitor]: - """Create and manage a pod monitor instance. - - This factory handles production dependency creation: - - Creates K8sClients if not provided (using config settings) - - Creates PodEventMapper if not provided - - Cleans up created K8sClients on exit - """ - # Track whether we created clients (so we know to close them) - owns_clients = k8s_clients is None - - if k8s_clients is None: - k8s_clients = create_k8s_clients( - logger=logger, - kubeconfig_path=config.kubeconfig_path, - in_cluster=config.in_cluster, - ) - - if event_mapper is None: - event_mapper = PodEventMapper(logger=logger, k8s_api=k8s_clients.v1) - - monitor = PodMonitor( - config=config, - kafka_event_service=kafka_event_service, - logger=logger, - k8s_clients=k8s_clients, - event_mapper=event_mapper, - kubernetes_metrics=kubernetes_metrics, - ) - - try: - async with monitor: - yield monitor - finally: - if owns_clients: - close_k8s_clients(k8s_clients) diff --git a/backend/app/services/result_processor/processor.py b/backend/app/services/result_processor/processor.py index 464584c7..c7d5f1c7 100644 --- a/backend/app/services/result_processor/processor.py +++ b/backend/app/services/result_processor/processor.py @@ -1,19 +1,21 @@ +"""Result Processor - stateless event handler. + +Processes execution completion events and stores results. +Receives events, processes them, and publishes results. No lifecycle management. 
+""" + +from __future__ import annotations + import logging -from enum import auto -from typing import Any from pydantic import BaseModel, ConfigDict, Field -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics, ExecutionMetrics -from app.core.utils import StringEnum from app.db.repositories.execution_repository import ExecutionRepository -from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId, KafkaTopic from app.domain.enums.storage import ExecutionErrorType, StorageType from app.domain.events.typed import ( - DomainEvent, EventMetadata, ExecutionCompletedEvent, ExecutionFailedEvent, @@ -22,22 +24,10 @@ ResultStoredEvent, ) from app.domain.execution import ExecutionNotFoundError, ExecutionResultDomain -from app.domain.idempotency import KeyStrategy -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency import IdempotencyManager -from app.services.idempotency.middleware import IdempotentConsumerWrapper +from app.events.core import UnifiedProducer from app.settings import Settings -class ProcessingState(StringEnum): - """Processing state enumeration.""" - - IDLE = auto() - PROCESSING = auto() - STOPPED = auto() - - class ResultProcessorConfig(BaseModel): """Configuration for result processor.""" @@ -52,126 +42,33 @@ class ResultProcessorConfig(BaseModel): processing_timeout: int = Field(default=300) -class ResultProcessor(LifecycleEnabled): - """Service for processing execution completion events and storing results.""" +class ResultProcessor: + """Stateless result processor - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + Worker entrypoint handles the consume loop. 
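+
+    A rough sketch of the wiring the worker entrypoint is expected to do
+    (processor stands for the DI-provided instance; EventDispatcher comes from
+    app.events.core and EventType from app.domain.enums.events; consumer
+    configuration and idempotency handling are omitted here and live in the
+    worker entrypoint):
+
+        dispatcher = EventDispatcher(logger=logger)
+        dispatcher.register_handler(EventType.EXECUTION_COMPLETED, processor.handle_execution_completed)
+        dispatcher.register_handler(EventType.EXECUTION_FAILED, processor.handle_execution_failed)
+        dispatcher.register_handler(EventType.EXECUTION_TIMEOUT, processor.handle_execution_timeout)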
+ """ def __init__( self, execution_repo: ExecutionRepository, producer: UnifiedProducer, - schema_registry: SchemaRegistryManager, settings: Settings, - idempotency_manager: IdempotencyManager, logger: logging.Logger, execution_metrics: ExecutionMetrics, event_metrics: EventMetrics, + config: ResultProcessorConfig | None = None, ) -> None: - """Initialize the result processor.""" - super().__init__() - self.config = ResultProcessorConfig() self._execution_repo = execution_repo self._producer = producer - self._schema_registry = schema_registry self._settings = settings + self._logger = logger self._metrics = execution_metrics self._event_metrics = event_metrics - self._idempotency_manager: IdempotencyManager = idempotency_manager - self._state = ProcessingState.IDLE - self._consumer: IdempotentConsumerWrapper | None = None - self._dispatcher: EventDispatcher | None = None - self.logger = logger - - async def _on_start(self) -> None: - """Start the result processor.""" - self.logger.info("Starting ResultProcessor...") - - # Initialize idempotency manager (safe to call multiple times) - await self._idempotency_manager.initialize() - self.logger.info("Idempotency manager initialized for ResultProcessor") - - self._dispatcher = self._create_dispatcher() - self._consumer = await self._create_consumer() - self._state = ProcessingState.PROCESSING - self.logger.info("ResultProcessor started successfully with idempotency protection") - - async def _on_stop(self) -> None: - """Stop the result processor.""" - self.logger.info("Stopping ResultProcessor...") - self._state = ProcessingState.STOPPED - - if self._consumer: - await self._consumer.stop() - - await self._idempotency_manager.close() - # Note: producer is managed by DI container, not stopped here - self.logger.info("ResultProcessor stopped") - - def _create_dispatcher(self) -> EventDispatcher: - """Create and configure event dispatcher with handlers.""" - dispatcher = EventDispatcher(logger=self.logger) - - # Register handlers for specific event types - dispatcher.register_handler(EventType.EXECUTION_COMPLETED, self._handle_completed_wrapper) - dispatcher.register_handler(EventType.EXECUTION_FAILED, self._handle_failed_wrapper) - dispatcher.register_handler(EventType.EXECUTION_TIMEOUT, self._handle_timeout_wrapper) - - return dispatcher - - async def _create_consumer(self) -> IdempotentConsumerWrapper: - """Create and configure idempotent Kafka consumer.""" - consumer_config = ConsumerConfig( - bootstrap_servers=self._settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=self.config.consumer_group, - max_poll_records=1, - enable_auto_commit=True, - auto_offset_reset="earliest", - session_timeout_ms=self._settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self._settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self._settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self._settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - # Create consumer with schema registry and dispatcher - if not self._dispatcher: - raise RuntimeError("Event dispatcher not initialized") - - base_consumer = UnifiedConsumer( - consumer_config, - event_dispatcher=self._dispatcher, - schema_registry=self._schema_registry, - settings=self._settings, - logger=self.logger, - event_metrics=self._event_metrics, - ) - wrapper = IdempotentConsumerWrapper( - consumer=base_consumer, - idempotency_manager=self._idempotency_manager, - dispatcher=self._dispatcher, - logger=self.logger, - default_key_strategy=KeyStrategy.CONTENT_HASH, - default_ttl_seconds=7200, - 
enable_for_all_handlers=True, - ) - await wrapper.start(self.config.topics) - return wrapper - - # Wrappers accepting DomainEvent to satisfy dispatcher typing + self._config = config or ResultProcessorConfig() - async def _handle_completed_wrapper(self, event: DomainEvent) -> None: - assert isinstance(event, ExecutionCompletedEvent) - await self._handle_completed(event) - - async def _handle_failed_wrapper(self, event: DomainEvent) -> None: - assert isinstance(event, ExecutionFailedEvent) - await self._handle_failed(event) - - async def _handle_timeout_wrapper(self, event: DomainEvent) -> None: - assert isinstance(event, ExecutionTimeoutEvent) - await self._handle_timeout(event) - - async def _handle_completed(self, event: ExecutionCompletedEvent) -> None: + async def handle_execution_completed(self, event: ExecutionCompletedEvent) -> None: """Handle execution completed event.""" - exec_obj = await self._execution_repo.get_execution(event.execution_id) if exec_obj is None: raise ExecutionNotFoundError(event.execution_id) @@ -190,7 +87,7 @@ async def _handle_completed(self, event: ExecutionCompletedEvent) -> None: # Calculate and record memory utilization percentage settings_limit = self._settings.K8S_POD_MEMORY_LIMIT - memory_limit_mib = int(settings_limit.rstrip("Mi")) # TODO: Less brittle acquisition of limit + memory_limit_mib = int(settings_limit.rstrip("Mi")) memory_percent = (memory_mib / memory_limit_mib) * 100 self._metrics.memory_utilization_percent.record( memory_percent, attributes={"lang_and_version": lang_and_version} @@ -203,20 +100,18 @@ async def _handle_completed(self, event: ExecutionCompletedEvent) -> None: stdout=event.stdout, stderr=event.stderr, resource_usage=event.resource_usage, - metadata=event.metadata, + metadata=event.metadata.model_dump(), ) try: await self._execution_repo.write_terminal_result(result) await self._publish_result_stored(result) except Exception as e: - self.logger.error(f"Failed to handle ExecutionCompletedEvent: {e}", exc_info=True) + self._logger.error(f"Failed to handle ExecutionCompletedEvent: {e}", exc_info=True) await self._publish_result_failed(event.execution_id, str(e)) - async def _handle_failed(self, event: ExecutionFailedEvent) -> None: + async def handle_execution_failed(self, event: ExecutionFailedEvent) -> None: """Handle execution failed event.""" - - # Fetch execution to get language and version for metrics exec_obj = await self._execution_repo.get_execution(event.execution_id) if exec_obj is None: raise ExecutionNotFoundError(event.execution_id) @@ -232,19 +127,19 @@ async def _handle_failed(self, event: ExecutionFailedEvent) -> None: stdout=event.stdout, stderr=event.stderr, resource_usage=event.resource_usage, - metadata=event.metadata, + metadata=event.metadata.model_dump(), error_type=event.error_type, ) + try: await self._execution_repo.write_terminal_result(result) await self._publish_result_stored(result) except Exception as e: - self.logger.error(f"Failed to handle ExecutionFailedEvent: {e}", exc_info=True) + self._logger.error(f"Failed to handle ExecutionFailedEvent: {e}", exc_info=True) await self._publish_result_failed(event.execution_id, str(e)) - async def _handle_timeout(self, event: ExecutionTimeoutEvent) -> None: + async def handle_execution_timeout(self, event: ExecutionTimeoutEvent) -> None: """Handle execution timeout event.""" - exec_obj = await self._execution_repo.get_execution(event.execution_id) if exec_obj is None: raise ExecutionNotFoundError(event.execution_id) @@ -263,19 +158,19 @@ async def 
_handle_timeout(self, event: ExecutionTimeoutEvent) -> None: stdout=event.stdout, stderr=event.stderr, resource_usage=event.resource_usage, - metadata=event.metadata, + metadata=event.metadata.model_dump(), error_type=ExecutionErrorType.TIMEOUT, ) + try: await self._execution_repo.write_terminal_result(result) await self._publish_result_stored(result) except Exception as e: - self.logger.error(f"Failed to handle ExecutionTimeoutEvent: {e}", exc_info=True) + self._logger.error(f"Failed to handle ExecutionTimeoutEvent: {e}", exc_info=True) await self._publish_result_failed(event.execution_id, str(e)) async def _publish_result_stored(self, result: ExecutionResultDomain) -> None: """Publish result stored event.""" - size_bytes = len(result.stdout) + len(result.stderr) event = ResultStoredEvent( execution_id=result.execution_id, @@ -292,7 +187,6 @@ async def _publish_result_stored(self, result: ExecutionResultDomain) -> None: async def _publish_result_failed(self, execution_id: str, error_message: str) -> None: """Publish result processing failed event.""" - event = ResultFailedEvent( execution_id=execution_id, error=error_message, @@ -303,10 +197,3 @@ async def _publish_result_failed(self, execution_id: str, error_message: str) -> ) await self._producer.produce(event_to_produce=event, key=execution_id) - - async def get_status(self) -> dict[str, Any]: - """Get processor status.""" - return { - "state": self._state, - "consumer_active": self._consumer is not None, - } diff --git a/backend/app/services/saga/__init__.py b/backend/app/services/saga/__init__.py index e89535ae..ec47a201 100644 --- a/backend/app/services/saga/__init__.py +++ b/backend/app/services/saga/__init__.py @@ -12,7 +12,7 @@ RemoveFromQueueCompensation, ValidateExecutionStep, ) -from app.services.saga.saga_orchestrator import SagaOrchestrator, create_saga_orchestrator +from app.services.saga.saga_orchestrator import SagaOrchestrator from app.services.saga.saga_step import CompensationStep, SagaContext, SagaStep __all__ = [ @@ -34,5 +34,4 @@ "ReleaseResourcesCompensation", "RemoveFromQueueCompensation", "DeletePodCompensation", - "create_saga_orchestrator", ] diff --git a/backend/app/services/saga/saga_orchestrator.py b/backend/app/services/saga/saga_orchestrator.py index 7d607bb6..e032d6a7 100644 --- a/backend/app/services/saga/saga_orchestrator.py +++ b/backend/app/services/saga/saga_orchestrator.py @@ -1,11 +1,18 @@ -import asyncio +"""Saga Orchestrator - stateless event handler. + +Orchestrates saga execution and compensation. Receives events, +processes them, and publishes results. No lifecycle management. +All state is stored in SagaRepository (MongoDB). 
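+
+The worker entrypoint owns the consumer: it derives topics from
+get_trigger_event_types(), feeds each decoded event to handle_event(), and runs
+check_timeouts() on an interval. A rough sketch (orchestrator stands for the
+DI-provided instance; get_topic_for_event comes from
+app.infrastructure.kafka.mappings; the real wiring lives in the worker
+entrypoint):
+
+    topics = {get_topic_for_event(et) for et in orchestrator.get_trigger_event_types()}
+
+    while True:
+        await asyncio.sleep(30)
+        await orchestrator.check_timeouts()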
+""" + +from __future__ import annotations + import logging from datetime import UTC, datetime, timedelta from uuid import uuid4 from opentelemetry.trace import SpanKind -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics from app.core.tracing import EventAttributes from app.core.tracing.utils import get_tracer @@ -14,165 +21,79 @@ from app.domain.enums.events import EventType from app.domain.enums.saga import SagaState from app.domain.events.typed import DomainEvent, EventMetadata, SagaCancelledEvent -from app.domain.idempotency import KeyStrategy from app.domain.saga.models import Saga, SagaConfig -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import SchemaRegistryManager -from app.infrastructure.kafka.mappings import get_topic_for_event -from app.services.idempotency import IdempotentConsumerWrapper -from app.services.idempotency.idempotency_manager import IdempotencyManager -from app.settings import Settings +from app.events.core import UnifiedProducer from .base_saga import BaseSaga from .execution_saga import ExecutionSaga from .saga_step import SagaContext -class SagaOrchestrator(LifecycleEnabled): - """Orchestrates saga execution and compensation""" +class SagaOrchestrator: + """Stateless saga orchestrator - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + All state stored in SagaRepository. Worker entrypoint handles the consume loop. + """ def __init__( self, config: SagaConfig, saga_repository: SagaRepository, producer: UnifiedProducer, - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, resource_allocation_repository: ResourceAllocationRepository, logger: logging.Logger, event_metrics: EventMetrics, - ): - super().__init__() - self.config = config - self._sagas: dict[str, type[BaseSaga]] = {} - self._running_instances: dict[str, Saga] = {} - self._consumer: IdempotentConsumerWrapper | None = None - self._idempotency_manager: IdempotencyManager = idempotency_manager + ) -> None: + self._config = config + self._repo = saga_repository self._producer = producer - self._schema_registry_manager = schema_registry_manager - self._settings = settings - self._event_store = event_store - self._repo: SagaRepository = saga_repository - self._alloc_repo: ResourceAllocationRepository = resource_allocation_repository - self._tasks: list[asyncio.Task[None]] = [] - self.logger = logger + self._alloc_repo = resource_allocation_repository + self._logger = logger self._event_metrics = event_metrics + self._sagas: dict[str, type[BaseSaga]] = {} + + # Register default sagas + self._register_default_sagas() def register_saga(self, saga_class: type[BaseSaga]) -> None: + """Register a saga class.""" self._sagas[saga_class.get_name()] = saga_class - self.logger.info(f"Registered saga: {saga_class.get_name()}") + self._logger.info(f"Registered saga: {saga_class.get_name()}") def _register_default_sagas(self) -> None: + """Register default sagas.""" self.register_saga(ExecutionSaga) - self.logger.info("Registered default sagas") - - async def _on_start(self) -> None: - """Start the saga orchestrator.""" - self.logger.info(f"Starting saga orchestrator: {self.config.name}") - - self._register_default_sagas() - - await self._start_consumer() - - timeout_task = 
asyncio.create_task(self._check_timeouts()) - self._tasks.append(timeout_task) + self._logger.info("Registered default sagas") - self.logger.info("Saga orchestrator started") + def get_trigger_event_types(self) -> set[EventType]: + """Get all event types that trigger sagas. - async def _on_stop(self) -> None: - """Stop the saga orchestrator.""" - self.logger.info("Stopping saga orchestrator...") - - if self._consumer: - await self._consumer.stop() - - await self._idempotency_manager.close() - - for task in self._tasks: - if not task.done(): - task.cancel() - - if self._tasks: - await asyncio.gather(*self._tasks, return_exceptions=True) - - self.logger.info("Saga orchestrator stopped") - - async def _start_consumer(self) -> None: - self.logger.info(f"Registered sagas: {list(self._sagas.keys())}") - topics = set() - event_types_to_register = set() + Helper for worker entrypoint to know which topics to subscribe to. + """ + event_types: set[EventType] = set() for saga_class in self._sagas.values(): - trigger_event_types = saga_class.get_trigger_events() - self.logger.info(f"Saga {saga_class.get_name()} triggers on event types: {trigger_event_types}") - - # Convert event types to topics for subscription - for event_type in trigger_event_types: - topic = get_topic_for_event(event_type) - topics.add(topic) - event_types_to_register.add(event_type) - self.logger.debug(f"Event type {event_type} maps to topic {topic}") - - # Also register handlers for completion events so execution sagas can complete - completion_event_types = { + trigger_events = saga_class.get_trigger_events() + event_types.update(trigger_events) + + # Also include completion events + completion_events = { EventType.EXECUTION_COMPLETED, EventType.EXECUTION_FAILED, EventType.EXECUTION_TIMEOUT, } - for event_type in completion_event_types: - topic = get_topic_for_event(event_type) - topics.add(topic) - event_types_to_register.add(event_type) - self.logger.debug(f"Completion event type {event_type} maps to topic {topic}") - - if not topics: - self.logger.warning("No trigger events found in registered sagas") - return + event_types.update(completion_events) - consumer_config = ConsumerConfig( - bootstrap_servers=self._settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"saga-{self.config.name}", - enable_auto_commit=False, - session_timeout_ms=self._settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self._settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self._settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self._settings.KAFKA_REQUEST_TIMEOUT_MS, - ) + return event_types - dispatcher = EventDispatcher(logger=self.logger) - for event_type in event_types_to_register: - dispatcher.register_handler(event_type, self._handle_event) - self.logger.info(f"Registered handler for event type: {event_type}") - - base_consumer = UnifiedConsumer( - config=consumer_config, - event_dispatcher=dispatcher, - schema_registry=self._schema_registry_manager, - settings=self._settings, - logger=self.logger, - event_metrics=self._event_metrics, - ) - self._consumer = IdempotentConsumerWrapper( - consumer=base_consumer, - idempotency_manager=self._idempotency_manager, - dispatcher=dispatcher, - logger=self.logger, - default_key_strategy=KeyStrategy.EVENT_BASED, - default_ttl_seconds=7200, - enable_for_all_handlers=False, - ) + async def handle_event(self, event: DomainEvent) -> None: + """Handle incoming event. 
- assert self._consumer is not None - await self._consumer.start(list(topics)) - - self.logger.info(f"Saga consumer started for topics: {topics}") + Called by worker entrypoint for each event. + """ + self._logger.info(f"Saga orchestrator handling event: type={event.event_type}, id={event.event_id}") - async def _handle_event(self, event: DomainEvent) -> None: - """Handle incoming event""" - self.logger.info(f"Saga orchestrator handling event: type={event.event_type}, id={event.event_id}") try: # Check if this is a completion event that should update an existing saga completion_events = { @@ -187,88 +108,87 @@ async def _handle_event(self, event: DomainEvent) -> None: # Check if this event should trigger a new saga saga_triggered = False for saga_name, saga_class in self._sagas.items(): - self.logger.debug(f"Checking if {saga_name} should be triggered by {event.event_type}") + self._logger.debug(f"Checking if {saga_name} should be triggered by {event.event_type}") if self._should_trigger_saga(saga_class, event): - self.logger.info(f"Event {event.event_type} triggers saga {saga_name}") + self._logger.info(f"Event {event.event_type} triggers saga {saga_name}") saga_triggered = True saga_id = await self._start_saga(saga_name, event) if not saga_id: raise RuntimeError(f"Failed to create saga {saga_name} for event {event.event_id}") if not saga_triggered: - self.logger.debug(f"Event {event.event_type} did not trigger any saga") + self._logger.debug(f"Event {event.event_type} did not trigger any saga") except Exception as e: - self.logger.error(f"Error handling event {event.event_id}: {e}", exc_info=True) + self._logger.error(f"Error handling event {event.event_id}: {e}", exc_info=True) raise async def _handle_completion_event(self, event: DomainEvent) -> None: """Handle execution completion events to update saga state.""" execution_id = getattr(event, "execution_id", None) if not execution_id: - self.logger.warning(f"Completion event {event.event_type} has no execution_id") + self._logger.warning(f"Completion event {event.event_type} has no execution_id") return - # Find the execution saga specifically (not other saga types) + # Find the execution saga specifically saga = await self._repo.get_saga_by_execution_and_name(execution_id, ExecutionSaga.get_name()) if not saga: - self.logger.debug(f"No execution_saga found for execution {execution_id}") + self._logger.debug(f"No execution_saga found for execution {execution_id}") return # Only update if saga is still in a running state if saga.state not in (SagaState.RUNNING, SagaState.CREATED): - self.logger.debug(f"Saga {saga.saga_id} already in terminal state {saga.state}") + self._logger.debug(f"Saga {saga.saga_id} already in terminal state {saga.state}") return # Update saga state based on completion event type if event.event_type == EventType.EXECUTION_COMPLETED: - self.logger.info(f"Marking saga {saga.saga_id} as COMPLETED due to execution completion") + self._logger.info(f"Marking saga {saga.saga_id} as COMPLETED due to execution completion") saga.state = SagaState.COMPLETED saga.completed_at = datetime.now(UTC) elif event.event_type == EventType.EXECUTION_TIMEOUT: timeout_seconds = getattr(event, "timeout_seconds", None) - self.logger.info(f"Marking saga {saga.saga_id} as TIMEOUT after {timeout_seconds}s") + self._logger.info(f"Marking saga {saga.saga_id} as TIMEOUT after {timeout_seconds}s") saga.state = SagaState.TIMEOUT saga.error_message = f"Execution timed out after {timeout_seconds} seconds" saga.completed_at = datetime.now(UTC) else: # 
EXECUTION_FAILED error_msg = getattr(event, "error_message", None) or f"Execution {event.event_type}" - self.logger.info(f"Marking saga {saga.saga_id} as FAILED: {error_msg}") + self._logger.info(f"Marking saga {saga.saga_id} as FAILED: {error_msg}") saga.state = SagaState.FAILED saga.error_message = error_msg saga.completed_at = datetime.now(UTC) await self._save_saga(saga) - self._running_instances.pop(saga.saga_id, None) def _should_trigger_saga(self, saga_class: type[BaseSaga], event: DomainEvent) -> bool: + """Check if event should trigger a saga.""" trigger_event_types = saga_class.get_trigger_events() should_trigger = event.event_type in trigger_event_types - self.logger.debug( + self._logger.debug( f"Saga {saga_class.get_name()} triggers on {trigger_event_types}, " f"event is {event.event_type}, should trigger: {should_trigger}" ) return should_trigger async def _start_saga(self, saga_name: str, trigger_event: DomainEvent) -> str | None: - """Start a new saga instance""" - self.logger.info(f"Starting saga {saga_name} for event {trigger_event.event_type}") + """Start a new saga instance.""" + self._logger.info(f"Starting saga {saga_name} for event {trigger_event.event_type}") saga_class = self._sagas.get(saga_name) if not saga_class: raise ValueError(f"Unknown saga: {saga_name}") execution_id = getattr(trigger_event, "execution_id", None) - self.logger.debug(f"Extracted execution_id={execution_id} from event") + self._logger.debug(f"Extracted execution_id={execution_id} from event") if not execution_id: - self.logger.warning(f"Could not extract execution ID from event: {trigger_event}") + self._logger.warning(f"Could not extract execution ID from event: {trigger_event}") return None existing = await self._repo.get_saga_by_execution_and_name(execution_id, saga_name) if existing: - self.logger.info(f"Saga {saga_name} already exists for execution {execution_id}") - saga_id: str = existing.saga_id - return saga_id + self._logger.info(f"Saga {saga_name} already exists for execution {execution_id}") + return existing.saga_id instance = Saga( saga_id=str(uuid4()), @@ -278,25 +198,21 @@ async def _start_saga(self, saga_name: str, trigger_event: DomainEvent) -> str | ) await self._save_saga(instance) - self._running_instances[instance.saga_id] = instance - - self.logger.info(f"Started saga {saga_name} (ID: {instance.saga_id}) for execution {execution_id}") + self._logger.info(f"Started saga {saga_name} (ID: {instance.saga_id}) for execution {execution_id}") + # Execute saga steps synchronously saga = saga_class() - # Inject runtime dependencies explicitly (no DI via context) try: saga.bind_dependencies( producer=self._producer, alloc_repo=self._alloc_repo, - publish_commands=bool(getattr(self.config, "publish_commands", False)), + publish_commands=bool(getattr(self._config, "publish_commands", False)), ) except Exception: - # Back-compat: if saga doesn't support binding, it will fallback to context where needed pass context = SagaContext(instance.saga_id, execution_id) - - asyncio.create_task(self._execute_saga(saga, instance, context, trigger_event)) + await self._execute_saga(saga, instance, context, trigger_event) return instance.saga_id @@ -307,24 +223,17 @@ async def _execute_saga( context: SagaContext, trigger_event: DomainEvent, ) -> None: - """Execute saga steps""" + """Execute saga steps synchronously.""" tracer = get_tracer() try: - # Get saga steps steps = saga.get_steps() - # Execute each step for step in steps: - if not self.is_running: - break - - # Update current step 
instance.current_step = step.name await self._save_saga(instance) - self.logger.info(f"Executing saga step: {step.name} for saga {instance.saga_id}") + self._logger.info(f"Executing saga step: {step.name} for saga {instance.saga_id}") - # Execute step within a span with tracer.start_as_current_span( name="saga.step", kind=SpanKind.INTERNAL, @@ -339,8 +248,6 @@ async def _execute_saga( if success: instance.completed_steps.append(step.name) - - # Persist only safe, public context (no ephemeral objects) instance.context_data = context.to_public_dict() await self._save_saga(instance) @@ -348,162 +255,126 @@ async def _execute_saga( if compensation: context.add_compensation(compensation) else: - # Step failed, start compensation - self.logger.error(f"Saga step {step.name} failed for saga {instance.saga_id}") + self._logger.error(f"Saga step {step.name} failed for saga {instance.saga_id}") - if self.config.enable_compensation: + if self._config.enable_compensation: await self._compensate_saga(instance, context) else: await self._fail_saga(instance, "Step failed without compensation") - return - # All steps completed successfully - # Execution saga waits for external completion events (EXECUTION_COMPLETED/FAILED) + # All steps completed if instance.saga_name == ExecutionSaga.get_name(): - self.logger.info(f"Saga {instance.saga_id} steps done, waiting for execution completion event") + self._logger.info(f"Saga {instance.saga_id} steps done, waiting for execution completion event") else: await self._complete_saga(instance) except Exception as e: - self.logger.error(f"Error executing saga {instance.saga_id}: {e}", exc_info=True) + self._logger.error(f"Error executing saga {instance.saga_id}: {e}", exc_info=True) - if self.config.enable_compensation: + if self._config.enable_compensation: await self._compensate_saga(instance, context) else: await self._fail_saga(instance, str(e)) async def _compensate_saga(self, instance: Saga, context: SagaContext) -> None: - """Execute compensation steps""" - self.logger.info(f"Starting compensation for saga {instance.saga_id}") + """Execute compensation steps.""" + self._logger.info(f"Starting compensation for saga {instance.saga_id}") - # Only update state if not already cancelled if instance.state != SagaState.CANCELLED: instance.state = SagaState.COMPENSATING await self._save_saga(instance) - # Execute compensations in reverse order for compensation in reversed(context.compensations): try: - self.logger.info(f"Executing compensation: {compensation.name} for saga {instance.saga_id}") - + self._logger.info(f"Executing compensation: {compensation.name} for saga {instance.saga_id}") success = await compensation.compensate(context) if success: instance.compensated_steps.append(compensation.name) else: - self.logger.error(f"Compensation {compensation.name} failed for saga {instance.saga_id}") + self._logger.error(f"Compensation {compensation.name} failed for saga {instance.saga_id}") except Exception as e: - self.logger.error(f"Error in compensation {compensation.name}: {e}", exc_info=True) + self._logger.error(f"Error in compensation {compensation.name}: {e}", exc_info=True) - # Mark saga as failed or keep as cancelled if instance.state == SagaState.CANCELLED: - # Keep cancelled state but update compensated steps instance.updated_at = datetime.now(UTC) await self._save_saga(instance) - self.logger.info(f"Saga {instance.saga_id} compensation completed after cancellation") + self._logger.info(f"Saga {instance.saga_id} compensation completed after cancellation") else: 
- # Mark as failed for non-cancelled compensations await self._fail_saga(instance, "Saga compensated due to failure") async def _complete_saga(self, instance: Saga) -> None: - """Mark saga as completed""" + """Mark saga as completed.""" instance.state = SagaState.COMPLETED instance.completed_at = datetime.now(UTC) await self._save_saga(instance) - - # Remove from running instances - self._running_instances.pop(instance.saga_id, None) - - self.logger.info(f"Saga {instance.saga_id} completed successfully") + self._logger.info(f"Saga {instance.saga_id} completed successfully") async def _fail_saga(self, instance: Saga, error_message: str) -> None: - """Mark saga as failed""" + """Mark saga as failed.""" instance.state = SagaState.FAILED instance.error_message = error_message instance.completed_at = datetime.now(UTC) await self._save_saga(instance) + self._logger.error(f"Saga {instance.saga_id} failed: {error_message}") - # Remove from running instances - self._running_instances.pop(instance.saga_id, None) - - self.logger.error(f"Saga {instance.saga_id} failed: {error_message}") + async def check_timeouts(self) -> int: + """Check for saga timeouts. - async def _check_timeouts(self) -> None: - """Check for saga timeouts""" - while self.is_running: - try: - # Check every 30 seconds - await asyncio.sleep(30) - - cutoff_time = datetime.now(UTC) - timedelta(seconds=self.config.timeout_seconds) - - timed_out = await self._repo.find_timed_out_sagas(cutoff_time) - - for instance in timed_out: - self.logger.warning(f"Saga {instance.saga_id} timed out") - - instance.state = SagaState.TIMEOUT - instance.error_message = f"Saga timed out after {self.config.timeout_seconds} seconds" - instance.completed_at = datetime.now(UTC) - - await self._save_saga(instance) - self._running_instances.pop(instance.saga_id, None) + Should be called periodically from worker entrypoint. + Returns number of timed out sagas. + """ + cutoff_time = datetime.now(UTC) - timedelta(seconds=self._config.timeout_seconds) + timed_out = await self._repo.find_timed_out_sagas(cutoff_time) + count = 0 + + for instance in timed_out: + self._logger.warning(f"Saga {instance.saga_id} timed out") + instance.state = SagaState.TIMEOUT + instance.error_message = f"Saga timed out after {self._config.timeout_seconds} seconds" + instance.completed_at = datetime.now(UTC) + await self._save_saga(instance) + count += 1 - except Exception as e: - self.logger.error(f"Error checking timeouts: {e}") + return count async def _save_saga(self, instance: Saga) -> None: - """Persist saga through repository""" + """Persist saga through repository.""" instance.updated_at = datetime.now(UTC) await self._repo.upsert_saga(instance) async def get_saga_status(self, saga_id: str) -> Saga | None: - """Get saga instance status""" - # Check memory first - if saga_id in self._running_instances: - return self._running_instances[saga_id] - + """Get saga instance status.""" return await self._repo.get_saga(saga_id) async def get_execution_sagas(self, execution_id: str) -> list[Saga]: - """Get all sagas for an execution, sorted by created_at descending (newest first)""" + """Get all sagas for an execution.""" result = await self._repo.get_sagas_by_execution(execution_id) return result.sagas async def cancel_saga(self, saga_id: str) -> bool: - """Cancel a running saga and trigger compensation. 
- - Args: - saga_id: The ID of the saga to cancel - - Returns: - True if cancelled successfully, False otherwise - """ + """Cancel a running saga and trigger compensation.""" try: - # Get saga instance saga_instance = await self.get_saga_status(saga_id) if not saga_instance: - self.logger.error("Saga not found", extra={"saga_id": saga_id}) + self._logger.error("Saga not found", extra={"saga_id": saga_id}) return False - # Check if saga can be cancelled if saga_instance.state not in [SagaState.RUNNING, SagaState.CREATED]: - self.logger.warning( - "Cannot cancel saga in current state. Only RUNNING or CREATED sagas can be cancelled.", + self._logger.warning( + "Cannot cancel saga in current state", extra={"saga_id": saga_id, "state": saga_instance.state}, ) return False - # Update state to CANCELLED saga_instance.state = SagaState.CANCELLED saga_instance.error_message = "Saga cancelled by user request" saga_instance.completed_at = datetime.now(UTC) - # Log cancellation with user context if available user_id = saga_instance.context_data.get("user_id") - self.logger.info( + self._logger.info( "Saga cancellation initiated", extra={ "saga_id": saga_id, @@ -512,38 +383,28 @@ async def cancel_saga(self, saga_id: str) -> bool: }, ) - # Save state await self._save_saga(saga_instance) - # Remove from running instances - self._running_instances.pop(saga_id, None) - - # Publish cancellation event - if self._producer and self.config.store_events: + if self._config.store_events: await self._publish_saga_cancelled_event(saga_instance) - # Trigger compensation if saga was running and has completed steps - if saga_instance.completed_steps and self.config.enable_compensation: - # Get saga class + if saga_instance.completed_steps and self._config.enable_compensation: saga_class = self._sagas.get(saga_instance.saga_name) if saga_class: - # Create saga instance and context saga = saga_class() try: saga.bind_dependencies( producer=self._producer, alloc_repo=self._alloc_repo, - publish_commands=bool(getattr(self.config, "publish_commands", False)), + publish_commands=bool(getattr(self._config, "publish_commands", False)), ) except Exception: pass - context = SagaContext(saga_instance.saga_id, saga_instance.execution_id) - # Restore context data + context = SagaContext(saga_instance.saga_id, saga_instance.execution_id) for key, value in saga_instance.context_data.items(): context.set(key, value) - # Get steps and build compensation list steps = saga.get_steps() for step in steps: if step.name in saga_instance.completed_steps: @@ -551,19 +412,18 @@ async def cancel_saga(self, saga_id: str) -> bool: if compensation: context.add_compensation(compensation) - # Execute compensation await self._compensate_saga(saga_instance, context) else: - self.logger.error( + self._logger.error( "Saga class not found for compensation", extra={"saga_name": saga_instance.saga_name, "saga_id": saga_id}, ) - self.logger.info("Saga cancelled successfully", extra={"saga_id": saga_id}) + self._logger.info("Saga cancelled successfully", extra={"saga_id": saga_id}) return True except Exception as e: - self.logger.error( + self._logger.error( "Error cancelling saga", extra={"saga_id": saga_id, "error": str(e)}, exc_info=True, @@ -571,11 +431,7 @@ async def cancel_saga(self, saga_id: str) -> bool: return False async def _publish_saga_cancelled_event(self, saga_instance: Saga) -> None: - """Publish saga cancelled event. 
- - Args: - saga_instance: The cancelled saga instance - """ + """Publish saga cancelled event.""" try: cancelled_by = saga_instance.context_data.get("user_id") if saga_instance.context_data else None metadata = EventMetadata( @@ -596,53 +452,8 @@ async def _publish_saga_cancelled_event(self, saga_instance: Saga) -> None: metadata=metadata, ) - if self._producer: - await self._producer.produce(event_to_produce=event, key=saga_instance.execution_id) - - self.logger.info(f"Published cancellation event for saga {saga_instance.saga_id}") + await self._producer.produce(event_to_produce=event, key=saga_instance.execution_id) + self._logger.info(f"Published cancellation event for saga {saga_instance.saga_id}") except Exception as e: - self.logger.error(f"Failed to publish saga cancellation event: {e}") - - -def create_saga_orchestrator( - saga_repository: SagaRepository, - producer: UnifiedProducer, - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, - resource_allocation_repository: ResourceAllocationRepository, - config: SagaConfig, - logger: logging.Logger, - event_metrics: EventMetrics, -) -> SagaOrchestrator: - """Factory function to create a saga orchestrator. - - Args: - saga_repository: Repository for saga persistence - producer: Kafka producer instance - schema_registry_manager: Schema registry manager for event serialization - settings: Application settings - event_store: Event store instance for event sourcing - idempotency_manager: Manager for idempotent event processing - resource_allocation_repository: Repository for resource allocations - config: Saga configuration - logger: Logger instance - event_metrics: Event metrics for tracking Kafka consumption - - Returns: - A new saga orchestrator instance - """ - return SagaOrchestrator( - config, - saga_repository=saga_repository, - producer=producer, - schema_registry_manager=schema_registry_manager, - settings=settings, - event_store=event_store, - idempotency_manager=idempotency_manager, - resource_allocation_repository=resource_allocation_repository, - logger=logger, - event_metrics=event_metrics, - ) + self._logger.error(f"Failed to publish saga cancellation event: {e}") diff --git a/backend/app/services/sse/kafka_redis_bridge.py b/backend/app/services/sse/kafka_redis_bridge.py index 0a4eb780..83d7604e 100644 --- a/backend/app/services/sse/kafka_redis_bridge.py +++ b/backend/app/services/sse/kafka_redis_bridge.py @@ -1,151 +1,84 @@ +"""SSE Kafka Redis Bridge - stateless event handler. + +Bridges Kafka events to Redis channels for SSE delivery. +No lifecycle management - worker entrypoint handles the consume loop. +""" + from __future__ import annotations -import asyncio import logging -from app.core.lifecycle import LifecycleEnabled -from app.core.metrics import EventMetrics from app.domain.enums.events import EventType -from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId from app.domain.events.typed import DomainEvent -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer -from app.events.schema.schema_registry import SchemaRegistryManager from app.services.sse.redis_bus import SSERedisBus -from app.settings import Settings - -class SSEKafkaRedisBridge(LifecycleEnabled): - """ - Bridges Kafka events to Redis channels for SSE delivery. 
- - - Consumes relevant Kafka topics using a small consumer pool - - Deserializes events and publishes them to Redis via SSERedisBus - - Keeps no in-process buffers; delivery to clients is via Redis only +# Event types relevant for SSE streaming +RELEVANT_EVENT_TYPES: set[EventType] = { + EventType.EXECUTION_REQUESTED, + EventType.EXECUTION_QUEUED, + EventType.EXECUTION_STARTED, + EventType.EXECUTION_RUNNING, + EventType.EXECUTION_COMPLETED, + EventType.EXECUTION_FAILED, + EventType.EXECUTION_TIMEOUT, + EventType.EXECUTION_CANCELLED, + EventType.RESULT_STORED, + EventType.POD_CREATED, + EventType.POD_SCHEDULED, + EventType.POD_RUNNING, + EventType.POD_SUCCEEDED, + EventType.POD_FAILED, + EventType.POD_TERMINATED, + EventType.POD_DELETED, +} + + +class SSEKafkaRedisBridge: + """Stateless SSE bridge - pure event handler. + + No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. + Worker entrypoint handles the consume loop. """ def __init__( - self, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_metrics: EventMetrics, - sse_bus: SSERedisBus, - logger: logging.Logger, + self, + sse_bus: SSERedisBus, + logger: logging.Logger, ) -> None: - super().__init__() - self.schema_registry = schema_registry - self.settings = settings - self.event_metrics = event_metrics - self.sse_bus = sse_bus - self.logger = logger - - self.num_consumers = settings.SSE_CONSUMER_POOL_SIZE - self.consumers: list[UnifiedConsumer] = [] - - async def _on_start(self) -> None: - """Start the SSE Kafka→Redis bridge.""" - self.logger.info(f"Starting SSE Kafka→Redis bridge with {self.num_consumers} consumers") - - # Phase 1: Build all consumers and track them immediately (no I/O) - self.consumers = [self._build_consumer(i) for i in range(self.num_consumers)] - - # Phase 2: Start all in parallel - already tracked in self.consumers for cleanup - topics = list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.WEBSOCKET_GATEWAY]) - await asyncio.gather(*[c.start(topics) for c in self.consumers]) - - self.logger.info("SSE Kafka→Redis bridge started successfully") - - async def _on_stop(self) -> None: - """Stop the SSE Kafka→Redis bridge.""" - self.logger.info("Stopping SSE Kafka→Redis bridge") - await asyncio.gather(*[c.stop() for c in self.consumers], return_exceptions=True) - self.consumers.clear() - self.logger.info("SSE Kafka→Redis bridge stopped") - - def _build_consumer(self, consumer_index: int) -> UnifiedConsumer: - """Build a consumer instance without starting it.""" - config = ConsumerConfig( - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - group_id="sse-bridge-pool", - client_id=f"sse-bridge-{consumer_index}", - enable_auto_commit=True, - auto_offset_reset="latest", - max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, - session_timeout_ms=self.settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, - request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - dispatcher = EventDispatcher(logger=self.logger) - self._register_routing_handlers(dispatcher) - - return UnifiedConsumer( - config=config, - event_dispatcher=dispatcher, - schema_registry=self.schema_registry, - settings=self.settings, - logger=self.logger, - event_metrics=self.event_metrics, - ) - - def _register_routing_handlers(self, dispatcher: EventDispatcher) -> None: - """Publish relevant events to Redis channels keyed by execution_id.""" - relevant_events = [ - EventType.EXECUTION_REQUESTED, - EventType.EXECUTION_QUEUED, - 
EventType.EXECUTION_STARTED, - EventType.EXECUTION_RUNNING, - EventType.EXECUTION_COMPLETED, - EventType.EXECUTION_FAILED, - EventType.EXECUTION_TIMEOUT, - EventType.EXECUTION_CANCELLED, - EventType.RESULT_STORED, - EventType.POD_CREATED, - EventType.POD_SCHEDULED, - EventType.POD_RUNNING, - EventType.POD_SUCCEEDED, - EventType.POD_FAILED, - EventType.POD_TERMINATED, - EventType.POD_DELETED, - ] - - async def route_event(event: DomainEvent) -> None: - data = event.model_dump() - execution_id = data.get("execution_id") - if not execution_id: - self.logger.debug(f"Event {event.event_type} has no execution_id") - return - try: - await self.sse_bus.publish_event(execution_id, event) - self.logger.info(f"Published {event.event_type} to Redis for {execution_id}") - except Exception as e: - self.logger.error( - f"Failed to publish {event.event_type} to Redis for {execution_id}: {e}", - exc_info=True, - ) - - for et in relevant_events: - dispatcher.register_handler(et, route_event) - - def get_stats(self) -> dict[str, int | bool]: + self._sse_bus = sse_bus + self._logger = logger + + @staticmethod + def get_relevant_event_types() -> set[EventType]: + """Get event types that should be routed to SSE. + + Helper for worker entrypoint to know which topics to subscribe to. + """ + return RELEVANT_EVENT_TYPES + + async def handle_event(self, event: DomainEvent) -> None: + """Handle an event and route to SSE bus. + + Called by worker entrypoint for each event from consume loop. + """ + data = event.model_dump() + execution_id = data.get("execution_id") + + if not execution_id: + self._logger.debug(f"Event {event.event_type} has no execution_id") + return + + try: + await self._sse_bus.publish_event(execution_id, event) + self._logger.debug(f"Published {event.event_type} to Redis for {execution_id}") + except Exception as e: + self._logger.error( + f"Failed to publish {event.event_type} to Redis for {execution_id}: {e}", + exc_info=True, + ) + + async def get_status(self) -> dict[str, list[str]]: + """Get bridge status.""" return { - "num_consumers": len(self.consumers), - "active_executions": 0, - "total_buffers": 0, - "is_running": self.is_running, + "relevant_event_types": [str(et) for et in RELEVANT_EVENT_TYPES], } - - -def create_sse_kafka_redis_bridge( - schema_registry: SchemaRegistryManager, - settings: Settings, - event_metrics: EventMetrics, - sse_bus: SSERedisBus, - logger: logging.Logger, -) -> SSEKafkaRedisBridge: - return SSEKafkaRedisBridge( - schema_registry=schema_registry, - settings=settings, - event_metrics=event_metrics, - sse_bus=sse_bus, - logger=logger, - ) diff --git a/backend/app/services/sse/sse_service.py b/backend/app/services/sse/sse_service.py index 9cdb13ee..6afb1f90 100644 --- a/backend/app/services/sse/sse_service.py +++ b/backend/app/services/sse/sse_service.py @@ -2,13 +2,12 @@ import logging from collections.abc import AsyncGenerator from datetime import datetime, timezone -from typing import Any +from typing import Any, Dict from app.core.metrics import ConnectionMetrics from app.db.repositories.sse_repository import SSERepository from app.domain.enums.events import EventType -from app.domain.enums.sse import SSEControlEvent, SSEHealthStatus, SSENotificationEvent -from app.domain.sse import SSEHealthDomain +from app.domain.enums.sse import SSEControlEvent, SSENotificationEvent from app.schemas_pydantic.execution import ExecutionResult from app.schemas_pydantic.sse import ( RedisNotificationMessage, @@ -50,7 +49,7 @@ def __init__( self.metrics = connection_metrics 
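With `SSEKafkaRedisBridge` reduced to the stateless `handle_event` above, the consume loop belongs to the worker entrypoint. A rough sketch of that wiring, assuming a `UnifiedConsumer` bound to the same `EventDispatcher` (the pattern used in the e2e tests further down) and an already-started `AIOKafkaConsumer` subscribed to the SSE-relevant topics; the function and parameter names are illustrative, and offset commits are left to the consumer's configuration:

```python
from aiokafka import AIOKafkaConsumer

from app.events.core import EventDispatcher, UnifiedConsumer
from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge


async def run_sse_bridge_loop(
    bridge: SSEKafkaRedisBridge,
    dispatcher: EventDispatcher,
    handler: UnifiedConsumer,
    consumer: AIOKafkaConsumer,
) -> None:
    """Drive the stateless bridge from a worker-owned Kafka consume loop."""
    # Route every SSE-relevant event type to the bridge's single handler.
    for event_type in SSEKafkaRedisBridge.get_relevant_event_types():
        dispatcher.register_handler(event_type, bridge.handle_event)

    # The worker owns the consumer lifecycle; the bridge only reacts per message.
    async for msg in consumer:
        await handler.handle(msg)
```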
self.heartbeat_interval = getattr(settings, "SSE_HEARTBEAT_INTERVAL", 30) - async def create_execution_stream(self, execution_id: str, user_id: str) -> AsyncGenerator[dict[str, Any], None]: + async def create_execution_stream(self, execution_id: str, user_id: str) -> AsyncGenerator[Dict[str, Any], None]: connection_id = f"sse_{execution_id}_{datetime.now(timezone.utc).timestamp()}" shutdown_event = await self.shutdown_manager.register_connection(execution_id, connection_id) @@ -125,7 +124,7 @@ async def _stream_events_redis( subscription: Any, shutdown_event: asyncio.Event, include_heartbeat: bool = True, - ) -> AsyncGenerator[dict[str, Any], None]: + ) -> AsyncGenerator[Dict[str, Any], None]: last_heartbeat = datetime.now(timezone.utc) while True: if shutdown_event.is_set(): @@ -195,7 +194,7 @@ async def _build_sse_event_from_redis(self, execution_id: str, msg: RedisSSEMess } ) - async def create_notification_stream(self, user_id: str) -> AsyncGenerator[dict[str, Any], None]: + async def create_notification_stream(self, user_id: str) -> AsyncGenerator[Dict[str, Any], None]: subscription = None try: @@ -258,23 +257,10 @@ async def create_notification_stream(self, user_id: str) -> AsyncGenerator[dict[ if subscription is not None: await asyncio.shield(subscription.close()) - async def get_health_status(self) -> SSEHealthDomain: - router_stats = self.router.get_stats() - return SSEHealthDomain( - status=SSEHealthStatus.DRAINING if self.shutdown_manager.is_shutting_down() else SSEHealthStatus.HEALTHY, - kafka_enabled=True, - active_connections=router_stats["active_executions"], - active_executions=router_stats["active_executions"], - active_consumers=router_stats["num_consumers"], - max_connections_per_user=5, - shutdown=self.shutdown_manager.get_shutdown_status(), - timestamp=datetime.now(timezone.utc), - ) - - def _format_sse_event(self, event: SSEExecutionEventData) -> dict[str, Any]: + def _format_sse_event(self, event: SSEExecutionEventData) -> Dict[str, Any]: """Format typed SSE event for sse-starlette.""" return {"data": event.model_dump_json(exclude_none=True)} - def _format_notification_event(self, event: SSENotificationEventData) -> dict[str, Any]: + def _format_notification_event(self, event: SSENotificationEventData) -> Dict[str, Any]: """Format typed notification SSE event for sse-starlette.""" return {"data": event.model_dump_json(exclude_none=True)} diff --git a/backend/app/services/sse/sse_shutdown_manager.py b/backend/app/services/sse/sse_shutdown_manager.py index c30ee855..dc799d10 100644 --- a/backend/app/services/sse/sse_shutdown_manager.py +++ b/backend/app/services/sse/sse_shutdown_manager.py @@ -3,7 +3,6 @@ import time from enum import Enum -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import ConnectionMetrics from app.domain.sse import ShutdownStatus @@ -56,9 +55,6 @@ def __init__( self._connection_callbacks: dict[str, asyncio.Event] = {} # connection_id -> shutdown event self._draining_connections: set[str] = set() - # Router reference (set during initialization) - self._router: LifecycleEnabled | None = None - # Synchronization self._lock = asyncio.Lock() self._shutdown_event = asyncio.Event() @@ -73,10 +69,6 @@ def __init__( extra={"drain_timeout": drain_timeout, "notification_timeout": notification_timeout}, ) - def set_router(self, router: LifecycleEnabled) -> None: - """Set the router reference for shutdown coordination.""" - self._router = router - async def register_connection(self, execution_id: str, connection_id: str) -> asyncio.Event 
| None: """ Register a new SSE connection. @@ -259,10 +251,6 @@ async def _force_close_connections(self) -> None: self._connection_callbacks.clear() self._draining_connections.clear() - # If we have a router, tell it to stop accepting new subscriptions - if self._router: - await self._router.aclose() - self.metrics.update_sse_draining_connections(0) self.logger.info("Force close phase complete") @@ -305,31 +293,3 @@ async def _wait_for_complete(self) -> None: """Wait for shutdown to complete""" while not self._shutdown_complete: await asyncio.sleep(0.1) - - -def create_sse_shutdown_manager( - logger: logging.Logger, - connection_metrics: ConnectionMetrics, - drain_timeout: float = 30.0, - notification_timeout: float = 5.0, - force_close_timeout: float = 10.0, -) -> SSEShutdownManager: - """Factory function to create an SSE shutdown manager. - - Args: - logger: Logger instance - connection_metrics: Connection metrics for tracking SSE connections - drain_timeout: Time to wait for connections to close gracefully - notification_timeout: Time to wait for shutdown notifications to be sent - force_close_timeout: Time before force closing connections - - Returns: - A new SSE shutdown manager instance - """ - return SSEShutdownManager( - logger=logger, - connection_metrics=connection_metrics, - drain_timeout=drain_timeout, - notification_timeout=notification_timeout, - force_close_timeout=force_close_timeout, - ) diff --git a/backend/app/services/user_settings_service.py b/backend/app/services/user_settings_service.py index fc9964d0..981f0164 100644 --- a/backend/app/services/user_settings_service.py +++ b/backend/app/services/user_settings_service.py @@ -8,7 +8,6 @@ from app.db.repositories.user_settings_repository import UserSettingsRepository from app.domain.enums import Theme from app.domain.enums.events import EventType -from app.domain.events.typed import DomainEvent, EventMetadata, UserSettingsUpdatedEvent from app.domain.user import ( DomainEditorSettings, DomainNotificationSettings, @@ -17,9 +16,8 @@ DomainUserSettingsChangedEvent, DomainUserSettingsUpdate, ) -from app.services.event_bus import EventBusManager +from app.services.event_bus import EventBus, EventBusEvent from app.services.kafka_event_service import KafkaEventService -from app.settings import Settings _settings_adapter = TypeAdapter(DomainUserSettings) _update_adapter = TypeAdapter(DomainUserSettingsUpdate) @@ -30,12 +28,12 @@ def __init__( self, repository: UserSettingsRepository, event_service: KafkaEventService, - settings: Settings, + event_bus: EventBus, logger: logging.Logger, ) -> None: self.repository = repository self.event_service = event_service - self.settings = settings + self._event_bus = event_bus self.logger = logger self._cache_ttl = timedelta(minutes=5) self._max_cache_size = 1000 @@ -43,8 +41,6 @@ def __init__( maxsize=self._max_cache_size, ttl=self._cache_ttl.total_seconds(), ) - self._event_bus_manager: EventBusManager | None = None - self._subscription_id: str | None = None self.logger.info( "UserSettingsService initialized", @@ -60,20 +56,19 @@ async def get_user_settings(self, user_id: str) -> DomainUserSettings: return await self.get_user_settings_fresh(user_id) - async def initialize(self, event_bus_manager: EventBusManager) -> None: + async def setup_event_subscription(self) -> None: """Subscribe to settings update events for cross-instance cache invalidation. Note: EventBus filters out self-published messages, so this handler only - runs for events from OTHER instances. 
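The subscription set up just below relies on the lightweight string-keyed `EventBus` for cross-instance cache invalidation. A minimal sketch of the pattern, using a plain dict where the real service keeps a TTL cache; the `subscribe`/`publish` signatures and the `EventBusEvent.payload` dict mirror how they are used in this patch, while the function names are made up:

```python
from app.services.event_bus import EventBus, EventBusEvent


async def wire_settings_cache_invalidation(bus: EventBus, cache: dict[str, object]) -> None:
    """Drop the local cache entry whenever another instance updates a user's settings."""

    async def on_settings_updated(event: EventBusEvent) -> None:
        user_id = event.payload.get("user_id")
        if user_id:
            cache.pop(str(user_id), None)  # stale entry, rebuild lazily on next read

    # Pattern subscription: matches "user.settings.updated" plus any suffixed variants.
    await bus.subscribe("user.settings.updated*", on_settings_updated)


async def announce_settings_update(bus: EventBus, user_id: str) -> None:
    """Tell other instances to invalidate their cached copy for this user."""
    await bus.publish("user.settings.updated", {"user_id": user_id})
```

Because the bus filters out self-published messages, the publishing instance still refreshes its own cache locally, exactly as the service does after a successful update.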
+ runs for events from OTHER instances. Called by DI provider after construction. """ - self._event_bus_manager = event_bus_manager - bus = await event_bus_manager.get_event_bus() - async def _handle(evt: DomainEvent) -> None: - if isinstance(evt, UserSettingsUpdatedEvent): - await self.invalidate_cache(evt.user_id) + async def _handle(evt: EventBusEvent) -> None: + uid = evt.payload.get("user_id") + if uid: + await self.invalidate_cache(str(uid)) - self._subscription_id = await bus.subscribe(f"{EventType.USER_SETTINGS_UPDATED}*", _handle) + await self._event_bus.subscribe("user.settings.updated*", _handle) async def get_user_settings_fresh(self, user_id: str) -> DomainUserSettings: """Bypass cache and rebuild settings from snapshot + events.""" @@ -114,20 +109,7 @@ async def update_user_settings( changes_json = _update_adapter.dump_python(updates, exclude_none=True, mode="json") await self._publish_settings_event(user_id, changes_json, reason) - if self._event_bus_manager is not None: - bus = await self._event_bus_manager.get_event_bus() - await bus.publish( - UserSettingsUpdatedEvent( - user_id=user_id, - changed_fields=list(changes_json.keys()), - reason=reason, - metadata=EventMetadata( - service_name=self.settings.SERVICE_NAME, - service_version=self.settings.SERVICE_VERSION, - user_id=user_id, - ), - ) - ) + await self._event_bus.publish("user.settings.updated", {"user_id": user_id}) self._add_to_cache(user_id, new_settings) if (await self.repository.count_events_since_snapshot(user_id)) >= 10: diff --git a/backend/di_lifecycle_refactor_plan.md b/backend/di_lifecycle_refactor_plan.md new file mode 100644 index 00000000..e19e65f6 --- /dev/null +++ b/backend/di_lifecycle_refactor_plan.md @@ -0,0 +1,64 @@ +# Backend DI Lifecycle Refactor Plan + +## Goals +- Push all service lifecycle management into Dishka providers; no ad-hoc threads, `is_running` flags, or manual `__aenter__/__aexit__`, `_on_start` / `_on_stop` hooks. +- Keep **zero lifecycle helper files or task-group utilities** in app code; Dishka provider primitives alone manage ownership, and container close is the only shutdown signal. +- Simplify worker entrypoints and FastAPI startup so shutting down the DI container is the only teardown required. + +## Current Pain Points (code refs) +- `app/core/lifecycle.py` mixes lifecycle concerns into every long-running service and leaks `is_running` flags across the codebase. +- Kafka-facing services (`events/core/consumer.py`, `events/core/producer.py`, `events/event_store_consumer.py`, `services/result_processor/processor.py`, `services/k8s_worker/worker.py`, `services/coordinator/coordinator.py`, `services/sse/kafka_redis_bridge.py`, `services/notification_service.py`, `dlq/manager.py`) start background loops via `asyncio.create_task` and manage stop signals manually. +- FastAPI lifespan (`app/core/dishka_lifespan.py`) manually enters/starts multiple services and stacks callbacks; workers use `while service.is_running` loops (e.g., `workers/run_saga_orchestrator.py`). +- `app/core/adaptive_sampling.py` spins a raw thread for periodic work. +- `EventBusManager` caches an `EventBus` started via `__aenter__`, duplicating lifecycle logic. + +## Target Architecture (Dishka-centric) +- Use Dishka `@provide` async-generator providers directly—no extra lifecycle helper modules. Providers must **not** call `start/stop/__aenter__/__aexit__`; objects are usable immediately after construction and simply released when the container closes. 
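A minimal sketch of the provider style this bullet describes, in its simplest synchronous form; the provider class is invented for illustration and is not one of the project's real providers:

```python
import logging

from dishka import Provider, Scope, provide

from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge
from app.services.sse.redis_bus import SSERedisBus


class SSEBridgeProvider(Provider):
    """Construct-and-release style: no start/stop calls, nothing to tear down."""

    scope = Scope.APP

    @provide
    def sse_bridge(self, sse_bus: SSERedisBus, logger: logging.Logger) -> SSEKafkaRedisBridge:
        # Ready to use as soon as it is built; closing the container
        # just releases the reference.
        return SSEKafkaRedisBridge(sse_bus=sse_bus, logger=logger)
```

An async-generator provider (yield instead of return) remains an option where something genuinely must be released on container close, but per the plan it still must not call start/stop on the object it yields.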
+- Keep services as pure orchestrators/handlers that assume dependencies are already constructed; no lifecycle methods (`start`, `stop`, `aclose`, `__aenter__`, `__aexit__`, `is_running`). +- FastAPI lifecycle only needs to create/close the container; bootstrap work (schema registry init, Beanie init, rate-limit seeding) runs inside an `APP`-scoped provider without explicit start/stop calls. +- Worker entrypoints resolve services (already constructed by providers), then block on a shutdown event; container.close() just releases references—no service teardown calls. + +## Step-by-Step Refactor +1) **Inline provider construction (no helper files, no start/stop)** + - Keep everything inside existing provider classes; do not add `core/di/*` helper modules. + - Providers construct dependencies once and yield them; **no start/stop or context-manager calls** inside providers. Objects must be drop-safe when the container releases them. + +2) **Retire `LifecycleEnabled`** + - Remove the base class and delete `is_running` state from all services. + - Convert services to plain classes whose constructors accept already-started collaborators (producer, dispatcher, repositories, etc.). + - Where a class only wrapped start/stop (e.g., `UnifiedProducer`, `UnifiedConsumer`), replace with lightweight functions or data classes plus provider-managed runners. + +3) **Kafka-facing services → passive, no lifecycle** + - Event ingestors (`EventStoreConsumer`, `ResultProcessor`, `NotificationService`, `SSEKafkaRedisBridge`, `DLQManager`, `ExecutionCoordinator`, `KubernetesWorker`, `SagaOrchestrator`, `PodMonitor`) become passive components: construction wires handlers/clients, but there is **no start/stop**. Message handling is invoked explicitly by callers (per-request or explicit trigger), not via background loops. + - Delete `_batch_task`, `_scheduling_task`, `_process_task`, and any `asyncio.create_task` usage. No runners, no background scheduling, no threads/processes. + +4) **FastAPI bootstrap simplification** + - Remove custom lifespan entirely; rely on FastAPI default lifecycle. + - Perform one-time bootstrap (schemas, Beanie, rate limits, tracing/metrics wiring) directly in `main.py` before constructing the app object and DI container. No dedicated provider or lifespan hook. + - Wiring stays declarative in `main.py`; providers stay free of bootstrap side-effects. + +5) **Worker entrypoints overhaul** + - Use signal-driven shutdown only: install handlers, wait on shutdown event, then rely on Dishka to close the container (no explicit teardown). Avoid polling loops, task groups, runners, lifecycle files. + - Providers should be sync where possible; prefer simple constructors over async generators so container cleanup is automatic and implicit. + - Each worker script builds settings → container, resolves the needed service (already constructed by its provider), logs readiness, then waits on the shutdown event. + +6) **Adaptive sampling & other threads** + - Replace `AdaptiveSampler` thread with on-demand, stateless computation (pure function or cached calculator). No background loop, no thread, no task group. + - Audit and remove any `threading.Thread` or `multiprocessing` usage; prefer synchronous or explicitly awaited calls executed by the caller. + +7) **Testing & migration** + - Update unit tests to drop assertions around `is_running` and context-manager behavior; add tests that closing the DI container cancels consumer loops and flushes Kafka commits. 
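A rough illustration of the worker-entrypoint shape described in step 5; the function is hypothetical, and how the container and providers are assembled is deliberately left out:

```python
import asyncio
import signal

from dishka import AsyncContainer

# Import path assumed from the patched module layout.
from app.services.saga.saga_orchestrator import SagaOrchestrator


async def run_worker(container: AsyncContainer) -> None:
    shutdown = asyncio.Event()
    loop = asyncio.get_running_loop()
    # POSIX-only: translate SIGINT/SIGTERM into the shutdown event.
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, shutdown.set)

    orchestrator = await container.get(SagaOrchestrator)  # already built by its provider
    # ...kick off whatever explicit work this worker drives (consume loop, timeout sweep)...

    await shutdown.wait()
    await container.close()  # the only teardown step
```

Nothing in this function stops or drains services; releasing the container's references is the whole shutdown story, which is exactly what the plan's definition of done requires.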
+ - Add a narrow integration test that spins an APP-scoped container with a fake consumer to verify provider-managed shutdown. + - Keep `uv run pytest` as the execution path; prefer `PYTEST_ADDOPTS=` override to disable xdist when debugging lifecycle issues locally. + +## Risks / Open Questions +- Kafka libraries typically expect explicit `start/stop`; shifting to construct-and-drop may require swapping implementations (e.g., per-call producer/consumer) to avoid leaks. +- Some startup routines (Kubernetes config load) are blocking; may still need threadpool execution even without explicit lifecycle. +- Need to ensure Dishka container close is called in all entrypoints so provider objects are released. + +## Definition of Done +- No class in `app/` inherits `LifecycleEnabled`; the file is removed. +- No service exposes `is_running` or `__aenter__/__aexit__`; lifecycle lives exclusively in providers. +- FastAPI and worker entrypoints use container close as the sole shutdown hook. +- No background loops/tasks/threads/processes; services do work only when explicitly invoked. +- Unit and targeted integration tests pass under `uv run pytest` with minimal/no external dependencies. diff --git a/backend/tests/e2e/core/test_dishka_lifespan.py b/backend/tests/e2e/core/test_dishka_lifespan.py index 39aada74..25d4de31 100644 --- a/backend/tests/e2e/core/test_dishka_lifespan.py +++ b/backend/tests/e2e/core/test_dishka_lifespan.py @@ -89,11 +89,11 @@ async def test_sse_bridge_available(self, scope: AsyncContainer) -> None: assert bridge is not None @pytest.mark.asyncio - async def test_event_store_consumer_available( + async def test_event_store_available( self, scope: AsyncContainer ) -> None: - """Event store consumer is available after lifespan.""" - from app.events.event_store_consumer import EventStoreConsumer + """Event store is available after lifespan.""" + from app.events.event_store import EventStore - consumer = await scope.get(EventStoreConsumer) - assert consumer is not None + store = await scope.get(EventStore) + assert store is not None diff --git a/backend/tests/e2e/dlq/test_dlq_manager.py b/backend/tests/e2e/dlq/test_dlq_manager.py index 381f90e2..d8888138 100644 --- a/backend/tests/e2e/dlq/test_dlq_manager.py +++ b/backend/tests/e2e/dlq/test_dlq_manager.py @@ -6,8 +6,7 @@ import pytest from aiokafka import AIOKafkaConsumer, AIOKafkaProducer -from app.core.metrics import DLQMetrics -from app.dlq.manager import create_dlq_manager +from app.dlq.manager import DLQManager from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DLQMessageReceivedEvent @@ -28,9 +27,8 @@ @pytest.mark.asyncio async def test_dlq_manager_persists_and_emits_event(scope: AsyncContainer, test_settings: Settings) -> None: """Test that DLQ manager persists messages and emits DLQMessageReceivedEvent.""" - schema_registry = SchemaRegistryManager(test_settings, _test_logger) - dlq_metrics: DLQMetrics = await scope.get(DLQMetrics) - manager = create_dlq_manager(settings=test_settings, schema_registry=schema_registry, logger=_test_logger, dlq_metrics=dlq_metrics) + schema_registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) + manager: DLQManager = await scope.get(DLQManager) prefix = test_settings.KAFKA_TOPIC_PREFIX ev = make_execution_requested_event(execution_id=f"exec-dlq-persist-{uuid.uuid4().hex[:8]}") @@ -72,31 +70,24 @@ async def consume_dlq_events() -> None: "producer_id": "tests", } - # Produce to DLQ topic BEFORE starting 
consumers (auto_offset_reset="earliest") - producer = AIOKafkaProducer(bootstrap_servers=test_settings.KAFKA_BOOTSTRAP_SERVERS) - await producer.start() - try: - await producer.send_and_wait( - topic=f"{prefix}{str(KafkaTopic.DEAD_LETTER_QUEUE)}", - key=ev.event_id.encode(), - value=json.dumps(payload).encode(), - ) - finally: - await producer.stop() - - # Start consumer for DLQ events + # Start consumer for DLQ events before producing await consumer.start() consume_task = asyncio.create_task(consume_dlq_events()) try: - # Start manager - it will consume from DLQ, persist, and emit DLQMessageReceivedEvent - async with manager: - # Await the DLQMessageReceivedEvent - true async, no polling - received = await asyncio.wait_for(received_future, timeout=15.0) - assert received.dlq_event_id == ev.event_id - assert received.event_type == EventType.DLQ_MESSAGE_RECEIVED - assert received.original_event_type == str(EventType.EXECUTION_REQUESTED) - assert received.error == "handler failed" + # Now produce to DLQ topic and call manager.handle_dlq_message directly + raw_message = json.dumps(payload).encode() + headers: dict[str, str] = {} + + # Manager handles the message (stateless handler) + await manager.handle_dlq_message(raw_message, headers) + + # Await the DLQMessageReceivedEvent - true async, no polling + received = await asyncio.wait_for(received_future, timeout=15.0) + assert received.dlq_event_id == ev.event_id + assert received.event_type == EventType.DLQ_MESSAGE_RECEIVED + assert received.original_event_type == str(EventType.EXECUTION_REQUESTED) + assert received.error == "handler failed" finally: consume_task.cancel() try: diff --git a/backend/tests/e2e/events/test_consume_roundtrip.py b/backend/tests/e2e/events/test_consume_roundtrip.py index 3b7d969b..3c64f706 100644 --- a/backend/tests/e2e/events/test_consume_roundtrip.py +++ b/backend/tests/e2e/events/test_consume_roundtrip.py @@ -3,13 +3,13 @@ import uuid import pytest +from aiokafka import AIOKafkaConsumer from app.core.metrics import EventMetrics from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent from app.events.core import UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher -from app.events.core.types import ConsumerConfig from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.settings import Settings from dishka import AsyncContainer @@ -43,22 +43,25 @@ async def _handle(_event: DomainEvent) -> None: received.set() group_id = f"test-consumer.{uuid.uuid4().hex[:6]}" - config = ConsumerConfig( + + # Create AIOKafkaConsumer directly for test + topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EXECUTION_EVENTS}" + kafka_consumer = AIOKafkaConsumer( + topic, bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, group_id=group_id, enable_auto_commit=True, auto_offset_reset="earliest", ) + await kafka_consumer.start() - consumer = UnifiedConsumer( - config, - dispatcher, + handler = UnifiedConsumer( + event_dispatcher=dispatcher, schema_registry=registry, - settings=settings, logger=_test_logger, event_metrics=event_metrics, + group_id=group_id, ) - await consumer.start([KafkaTopic.EXECUTION_EVENTS]) try: # Produce a request event @@ -67,6 +70,13 @@ async def _handle(_event: DomainEvent) -> None: await producer.produce(evt, key=execution_id) # Wait for the handler to be called - await asyncio.wait_for(received.wait(), timeout=10.0) + async def consume_until_received() 
-> None: + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + if received.is_set(): + break + + await asyncio.wait_for(consume_until_received(), timeout=10.0) finally: - await consumer.stop() + await kafka_consumer.stop() diff --git a/backend/tests/e2e/events/test_consumer_lifecycle.py b/backend/tests/e2e/events/test_consumer_lifecycle.py index 98c53a08..a7e102d9 100644 --- a/backend/tests/e2e/events/test_consumer_lifecycle.py +++ b/backend/tests/e2e/events/test_consumer_lifecycle.py @@ -2,9 +2,10 @@ from uuid import uuid4 import pytest +from aiokafka import AIOKafkaConsumer from app.core.metrics import EventMetrics from app.domain.enums.kafka import KafkaTopic -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer +from app.events.core import EventDispatcher, UnifiedConsumer from app.events.schema.schema_registry import SchemaRegistryManager from app.settings import Settings from dishka import AsyncContainer @@ -17,30 +18,29 @@ @pytest.mark.asyncio -async def test_consumer_start_status_seek_and_stop(scope: AsyncContainer) -> None: +async def test_consumer_seek_operations(scope: AsyncContainer) -> None: + """Test AIOKafkaConsumer seek operations work correctly.""" registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) settings: Settings = await scope.get(Settings) - event_metrics: EventMetrics = await scope.get(EventMetrics) - cfg = ConsumerConfig( + + group_id = f"test-consumer-{uuid4().hex[:6]}" + + # Create AIOKafkaConsumer directly + topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EXECUTION_EVENTS}" + kafka_consumer = AIOKafkaConsumer( + topic, bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"test-consumer-{uuid4().hex[:6]}", + group_id=group_id, + enable_auto_commit=True, + auto_offset_reset="earliest", ) - disp = EventDispatcher(logger=_test_logger) - c = UnifiedConsumer( - cfg, - event_dispatcher=disp, - schema_registry=registry, - settings=settings, - logger=_test_logger, - event_metrics=event_metrics, - ) - await c.start([KafkaTopic.EXECUTION_EVENTS]) + await kafka_consumer.start() + try: - st = c.get_status() - assert st.state == "running" and st.is_running is True - # Exercise seek functions; don't force specific partition offsets - await c.seek_to_beginning() - await c.seek_to_end() - # No need to sleep; just ensure we can call seek APIs while running + # Exercise seek functions on AIOKafkaConsumer directly + assignment = kafka_consumer.assignment() + if assignment: + await kafka_consumer.seek_to_beginning(*assignment) + await kafka_consumer.seek_to_end(*assignment) finally: - await c.stop() + await kafka_consumer.stop() diff --git a/backend/tests/e2e/events/test_event_dispatcher.py b/backend/tests/e2e/events/test_event_dispatcher.py index 2ead3aa3..126bdf04 100644 --- a/backend/tests/e2e/events/test_event_dispatcher.py +++ b/backend/tests/e2e/events/test_event_dispatcher.py @@ -3,13 +3,13 @@ import uuid import pytest +from aiokafka import AIOKafkaConsumer from app.core.metrics import EventMetrics from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent from app.events.core import UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher -from app.events.core.types import ConsumerConfig from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.settings import Settings from dishka import AsyncContainer @@ -44,22 +44,26 @@ 
async def h1(_e: DomainEvent) -> None: async def h2(_e: DomainEvent) -> None: h2_called.set() - # Real consumer against execution-events - cfg = ConsumerConfig( + group_id = f"dispatcher-it.{uuid.uuid4().hex[:6]}" + + # Create AIOKafkaConsumer directly for test + topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EXECUTION_EVENTS}" + kafka_consumer = AIOKafkaConsumer( + topic, bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"dispatcher-it.{uuid.uuid4().hex[:6]}", + group_id=group_id, enable_auto_commit=True, auto_offset_reset="earliest", ) - consumer = UnifiedConsumer( - cfg, - dispatcher, + await kafka_consumer.start() + + handler = UnifiedConsumer( + event_dispatcher=dispatcher, schema_registry=registry, - settings=settings, logger=_test_logger, event_metrics=event_metrics, + group_id=group_id, ) - await consumer.start([KafkaTopic.EXECUTION_EVENTS]) # Produce a request event via DI producer: UnifiedProducer = await scope.get(UnifiedProducer) @@ -67,6 +71,13 @@ async def h2(_e: DomainEvent) -> None: await producer.produce(evt, key="k") try: - await asyncio.wait_for(asyncio.gather(h1_called.wait(), h2_called.wait()), timeout=10.0) + async def consume_until_handled() -> None: + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + if h1_called.is_set() and h2_called.is_set(): + break + + await asyncio.wait_for(consume_until_handled(), timeout=10.0) finally: - await consumer.stop() + await kafka_consumer.stop() diff --git a/backend/tests/e2e/events/test_producer_roundtrip.py b/backend/tests/e2e/events/test_producer_roundtrip.py index 8340610b..ed1a4cdb 100644 --- a/backend/tests/e2e/events/test_producer_roundtrip.py +++ b/backend/tests/e2e/events/test_producer_roundtrip.py @@ -2,11 +2,8 @@ from uuid import uuid4 import pytest -from app.core.metrics import EventMetrics -from app.events.core import UnifiedProducer -from app.events.schema.schema_registry import SchemaRegistryManager +from app.events.core import ProducerMetrics, UnifiedProducer from app.infrastructure.kafka.mappings import get_topic_for_event -from app.settings import Settings from dishka import AsyncContainer from tests.conftest import make_execution_requested_event @@ -17,25 +14,20 @@ @pytest.mark.asyncio -async def test_unified_producer_start_produce_send_to_dlq_stop( - scope: AsyncContainer, test_settings: Settings -) -> None: - schema: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - event_metrics: EventMetrics = await scope.get(EventMetrics) - prod = UnifiedProducer( - schema, - logger=_test_logger, - settings=test_settings, - event_metrics=event_metrics, - ) - - async with prod: - ev = make_execution_requested_event(execution_id=f"exec-{uuid4().hex[:8]}") - await prod.produce(ev) - - # Exercise send_to_dlq path - topic = str(get_topic_for_event(ev.event_type)) - await prod.send_to_dlq(ev, original_topic=topic, error=RuntimeError("forced"), retry_count=1) - - st = prod.get_status() - assert st["running"] is True and st["state"] == "running" +async def test_unified_producer_produce_and_send_to_dlq(scope: AsyncContainer) -> None: + # Get producer and metrics from DI + prod: UnifiedProducer = await scope.get(UnifiedProducer) + metrics: ProducerMetrics = await scope.get(ProducerMetrics) + + initial_sent = metrics.messages_sent + + # Produce an event + ev = make_execution_requested_event(execution_id=f"exec-{uuid4().hex[:8]}") + await prod.produce(ev) + + # Exercise send_to_dlq path + topic = str(get_topic_for_event(ev.event_type)) + await 
prod.send_to_dlq(ev, original_topic=topic, error=RuntimeError("forced"), retry_count=1) + + # Verify metrics are being tracked + assert metrics.messages_sent >= initial_sent + 2 diff --git a/backend/tests/e2e/idempotency/test_consumer_idempotent.py b/backend/tests/e2e/idempotency/test_consumer_idempotent.py index 2ffae6ae..749a0ea3 100644 --- a/backend/tests/e2e/idempotency/test_consumer_idempotent.py +++ b/backend/tests/e2e/idempotency/test_consumer_idempotent.py @@ -3,11 +3,12 @@ import uuid import pytest +from aiokafka import AIOKafkaConsumer from app.core.metrics import EventMetrics from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer +from app.events.core import EventDispatcher, UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher as Disp from app.events.schema.schema_registry import SchemaRegistryManager from app.domain.idempotency import KeyStrategy @@ -57,23 +58,28 @@ async def handle(_ev: DomainEvent) -> None: await producer.produce(ev, key=execution_id) await producer.produce(ev, key=execution_id) - # Real consumer with idempotent wrapper - cfg = ConsumerConfig( + group_id = f"test-idem-consumer.{uuid.uuid4().hex[:6]}" + + # Create AIOKafkaConsumer directly + topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EXECUTION_EVENTS}" + kafka_consumer = AIOKafkaConsumer( + topic, bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"test-idem-consumer.{uuid.uuid4().hex[:6]}", + group_id=group_id, enable_auto_commit=True, auto_offset_reset="earliest", ) - base = UnifiedConsumer( - cfg, + await kafka_consumer.start() + + handler = UnifiedConsumer( event_dispatcher=disp, schema_registry=registry, - settings=settings, logger=_test_logger, event_metrics=event_metrics, + group_id=group_id, ) wrapper = IdempotentConsumerWrapper( - consumer=base, + consumer=handler, idempotency_manager=idm, dispatcher=disp, default_key_strategy=KeyStrategy.EVENT_BASED, @@ -81,10 +87,16 @@ async def handle(_ev: DomainEvent) -> None: logger=_test_logger, ) - await wrapper.start([KafkaTopic.EXECUTION_EVENTS]) try: - # Await the future directly - true async, no polling - await asyncio.wait_for(handled_future, timeout=10.0) + # Consume until handler is called + async def consume_until_handled() -> None: + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + if handled_future.done(): + break + + await asyncio.wait_for(consume_until_handled(), timeout=10.0) assert seen["n"] >= 1 finally: - await wrapper.stop() + await kafka_consumer.stop() diff --git a/backend/tests/e2e/result_processor/test_result_processor.py b/backend/tests/e2e/result_processor/test_result_processor.py deleted file mode 100644 index 4f5b11f3..00000000 --- a/backend/tests/e2e/result_processor/test_result_processor.py +++ /dev/null @@ -1,134 +0,0 @@ -import asyncio -import logging -import uuid - -import pytest -from app.core.database_context import Database -from app.core.metrics import EventMetrics, ExecutionMetrics -from app.db.repositories.execution_repository import ExecutionRepository -from app.domain.enums.events import EventType -from app.domain.enums.execution import ExecutionStatus -from app.domain.enums.kafka import KafkaTopic -from app.domain.events.typed import ( - EventMetadata, - ExecutionCompletedEvent, - ResourceUsageDomain, - ResultStoredEvent, -) -from 
app.domain.execution import DomainExecutionCreate -from app.events.core import UnifiedConsumer, UnifiedProducer -from app.events.core.dispatcher import EventDispatcher -from app.events.core.types import ConsumerConfig -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.idempotency import IdempotencyManager -from app.services.result_processor.processor import ResultProcessor -from app.settings import Settings -from dishka import AsyncContainer - -# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers -# instantiate Consumer() objects simultaneously. Serial execution prevents this. -pytestmark = [ - pytest.mark.e2e, - pytest.mark.kafka, - pytest.mark.mongodb, - pytest.mark.xdist_group("kafka_consumers"), -] - -_test_logger = logging.getLogger("test.result_processor.processor") - - -@pytest.mark.asyncio -async def test_result_processor_persists_and_emits(scope: AsyncContainer) -> None: - # Ensure schemas - registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - settings: Settings = await scope.get(Settings) - event_metrics: EventMetrics = await scope.get(EventMetrics) - execution_metrics: ExecutionMetrics = await scope.get(ExecutionMetrics) - await initialize_event_schemas(registry) - - # Dependencies - db: Database = await scope.get(Database) - repo: ExecutionRepository = await scope.get(ExecutionRepository) - producer: UnifiedProducer = await scope.get(UnifiedProducer) - idem: IdempotencyManager = await scope.get(IdempotencyManager) - - # Create a base execution to satisfy ResultProcessor lookup - created = await repo.create_execution(DomainExecutionCreate( - script="print('x')", - user_id="u1", - lang="python", - lang_version="3.11", - status=ExecutionStatus.RUNNING, - )) - execution_id = created.execution_id - - # Build and start the processor - processor = ResultProcessor( - execution_repo=repo, - producer=producer, - schema_registry=registry, - settings=settings, - idempotency_manager=idem, - logger=_test_logger, - execution_metrics=execution_metrics, - event_metrics=event_metrics, - ) - - # Setup a small consumer to capture ResultStoredEvent - dispatcher = EventDispatcher(logger=_test_logger) - stored_received = asyncio.Event() - - @dispatcher.register(EventType.RESULT_STORED) - async def _stored(event: ResultStoredEvent) -> None: - if event.execution_id == execution_id: - stored_received.set() - - group_id = f"rp-test.{uuid.uuid4().hex[:6]}" - cconf = ConsumerConfig( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=group_id, - enable_auto_commit=True, - auto_offset_reset="earliest", - ) - stored_consumer = UnifiedConsumer( - cconf, - dispatcher, - schema_registry=registry, - settings=settings, - logger=_test_logger, - event_metrics=event_metrics, - ) - - # Produce the event BEFORE starting consumers (auto_offset_reset="earliest" will read it) - usage = ResourceUsageDomain( - execution_time_wall_seconds=0.5, - cpu_time_jiffies=100, - clk_tck_hertz=100, - peak_memory_kb=1024, - ) - evt = ExecutionCompletedEvent( - execution_id=execution_id, - exit_code=0, - stdout="hello", - stderr="", - resource_usage=usage, - metadata=EventMetadata(service_name="tests", service_version="1.0.0"), - ) - await producer.produce(evt, key=execution_id) - - # Start consumers after producing - await stored_consumer.start([KafkaTopic.EXECUTION_RESULTS]) - - try: - async with processor: - # Await the ResultStoredEvent - signals that processing is complete - await 
asyncio.wait_for(stored_received.wait(), timeout=12.0) - - # Now verify DB persistence - should be done since event was emitted - doc = await db.get_collection("executions").find_one({"execution_id": execution_id}) - assert doc is not None, f"Execution {execution_id} not found in DB after ResultStoredEvent" - assert doc.get("status") == ExecutionStatus.COMPLETED, ( - f"Expected COMPLETED status, got {doc.get('status')}" - ) - finally: - await stored_consumer.stop() diff --git a/backend/tests/e2e/services/coordinator/test_execution_coordinator.py b/backend/tests/e2e/services/coordinator/test_execution_coordinator.py index 5406c7b4..472ebc0e 100644 --- a/backend/tests/e2e/services/coordinator/test_execution_coordinator.py +++ b/backend/tests/e2e/services/coordinator/test_execution_coordinator.py @@ -3,148 +3,36 @@ from dishka import AsyncContainer from tests.conftest import make_execution_requested_event -pytestmark = [pytest.mark.e2e, pytest.mark.kafka] +pytestmark = [pytest.mark.e2e, pytest.mark.kafka, pytest.mark.redis] -class TestHandleExecutionRequested: - """Tests for _handle_execution_requested method.""" +class TestExecutionCoordinator: + """Tests for ExecutionCoordinator handler methods.""" @pytest.mark.asyncio - async def test_handle_requested_schedules_execution( - self, scope: AsyncContainer - ) -> None: - """Handler schedules execution immediately.""" + async def test_handle_requested_does_not_raise(self, scope: AsyncContainer) -> None: + """Handler processes execution request without error.""" coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - ev = make_execution_requested_event(execution_id="e-sched-1") + ev = make_execution_requested_event(execution_id="e-test-1") - await coord._handle_execution_requested(ev) # noqa: SLF001 - - assert "e-sched-1" in coord._active_executions # noqa: SLF001 + # Should not raise + await coord.handle_execution_requested(ev) @pytest.mark.asyncio - async def test_handle_requested_with_priority( - self, scope: AsyncContainer - ) -> None: + async def test_handle_requested_with_priority(self, scope: AsyncContainer) -> None: """Handler respects execution priority.""" coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - ev = make_execution_requested_event( - execution_id="e-priority-1", - priority=10, # High priority - ) - - await coord._handle_execution_requested(ev) # noqa: SLF001 - - assert "e-priority-1" in coord._active_executions # noqa: SLF001 - - @pytest.mark.asyncio - async def test_handle_requested_unique_executions( - self, scope: AsyncContainer - ) -> None: - """Each execution gets unique tracking.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - - ev1 = make_execution_requested_event(execution_id="e-unique-1") - ev2 = make_execution_requested_event(execution_id="e-unique-2") - - await coord._handle_execution_requested(ev1) # noqa: SLF001 - await coord._handle_execution_requested(ev2) # noqa: SLF001 - - assert "e-unique-1" in coord._active_executions # noqa: SLF001 - assert "e-unique-2" in coord._active_executions # noqa: SLF001 - - -class TestGetStatus: - """Tests for get_status method.""" - - @pytest.mark.asyncio - async def test_get_status_returns_dict(self, scope: AsyncContainer) -> None: - """Get status returns dictionary with coordinator info.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - - status = await coord.get_status() - - assert isinstance(status, dict) - assert "running" in status - assert "active_executions" in status - assert 
"queue_stats" in status - assert "resource_stats" in status - - @pytest.mark.asyncio - async def test_get_status_tracks_active_executions( - self, scope: AsyncContainer - ) -> None: - """Status tracks number of active executions.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - - initial_status = await coord.get_status() - initial_active = initial_status.get("active_executions", 0) - - # Add execution - ev = make_execution_requested_event(execution_id="e-status-track-1") - await coord._handle_execution_requested(ev) # noqa: SLF001 - - new_status = await coord.get_status() - new_active = new_status.get("active_executions", 0) - - assert new_active == initial_active + 1, ( - f"Expected exactly one more active execution: {initial_active} -> {new_active}" - ) - - -class TestQueueManager: - """Tests for queue manager integration.""" - - @pytest.mark.asyncio - async def test_queue_manager_initialized(self, scope: AsyncContainer) -> None: - """Queue manager is properly initialized.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - - assert coord.queue_manager is not None - assert hasattr(coord.queue_manager, "add_execution") - assert hasattr(coord.queue_manager, "get_next_execution") - - -class TestResourceManager: - """Tests for resource manager integration.""" - - @pytest.mark.asyncio - async def test_resource_manager_initialized( - self, scope: AsyncContainer - ) -> None: - """Resource manager is properly initialized.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - - assert coord.resource_manager is not None - assert hasattr(coord.resource_manager, "request_allocation") - assert hasattr(coord.resource_manager, "release_allocation") - - @pytest.mark.asyncio - async def test_resource_manager_has_pool( - self, scope: AsyncContainer - ) -> None: - """Resource manager has resource pool configured.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - - # Check resource manager has pool with capacity - assert coord.resource_manager.pool is not None - assert coord.resource_manager.pool.total_cpu_cores > 0 - assert coord.resource_manager.pool.total_memory_mb > 0 - - -class TestCoordinatorLifecycle: - """Tests for coordinator lifecycle.""" - - @pytest.mark.asyncio - async def test_coordinator_has_consumer(self, scope: AsyncContainer) -> None: - """Coordinator has Kafka consumer configured.""" - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) + ev = make_execution_requested_event(execution_id="e-priority-1", priority=10) - # Consumer is set up during start, may be None before - assert hasattr(coord, "consumer") + await coord.handle_execution_requested(ev) @pytest.mark.asyncio - async def test_coordinator_has_producer(self, scope: AsyncContainer) -> None: - """Coordinator has Kafka producer configured.""" + async def test_coordinator_resolves_from_di(self, scope: AsyncContainer) -> None: + """Coordinator can be resolved from DI container.""" coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - assert coord.producer is not None + assert coord is not None + assert hasattr(coord, "handle_execution_requested") + assert hasattr(coord, "handle_execution_completed") + assert hasattr(coord, "handle_execution_failed") + assert hasattr(coord, "handle_execution_cancelled") diff --git a/backend/tests/e2e/services/events/test_event_bus.py b/backend/tests/e2e/services/events/test_event_bus.py index 5d87b290..f5f083eb 100644 --- a/backend/tests/e2e/services/events/test_event_bus.py +++ 
b/backend/tests/e2e/services/events/test_event_bus.py @@ -1,11 +1,11 @@ import asyncio +from datetime import datetime, timezone +from uuid import uuid4 import pytest from aiokafka import AIOKafkaProducer -from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic -from app.domain.events.typed import DomainEvent, EventMetadata, UserSettingsUpdatedEvent -from app.services.event_bus import EventBusManager +from app.services.event_bus import EventBus, EventBusEvent from app.settings import Settings from dishka import AsyncContainer @@ -15,24 +15,23 @@ @pytest.mark.asyncio async def test_event_bus_publish_subscribe(scope: AsyncContainer, test_settings: Settings) -> None: """Test EventBus receives events from other instances (cross-instance communication).""" - manager: EventBusManager = await scope.get(EventBusManager) - bus = await manager.get_event_bus() + bus: EventBus = await scope.get(EventBus) # Future resolves when handler receives the event - no polling needed - received_future: asyncio.Future[DomainEvent] = asyncio.get_running_loop().create_future() + received_future: asyncio.Future[EventBusEvent] = asyncio.get_running_loop().create_future() - async def handler(event: DomainEvent) -> None: + async def handler(event: EventBusEvent) -> None: if not received_future.done(): received_future.set_result(event) - await bus.subscribe(f"{EventType.USER_SETTINGS_UPDATED}*", handler) + await bus.subscribe("test.*", handler) # Simulate message from another instance by producing directly to Kafka - event = UserSettingsUpdatedEvent( - user_id="test-user", - changed_fields=["theme"], - reason="test", - metadata=EventMetadata(service_name="test", service_version="1.0"), + event = EventBusEvent( + id=str(uuid4()), + event_type="test.created", + timestamp=datetime.now(timezone.utc), + payload={"x": 1}, ) topic = f"{test_settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EVENT_BUS_STREAM}" @@ -42,7 +41,7 @@ async def handler(event: DomainEvent) -> None: await producer.send_and_wait( topic=topic, value=event.model_dump_json().encode("utf-8"), - key=EventType.USER_SETTINGS_UPDATED.encode("utf-8"), + key=b"test.created", headers=[("source_instance", b"other-instance")], ) finally: @@ -50,6 +49,4 @@ async def handler(event: DomainEvent) -> None: # Await the future directly - true async, no polling received = await asyncio.wait_for(received_future, timeout=10.0) - assert received.event_type == EventType.USER_SETTINGS_UPDATED - assert isinstance(received, UserSettingsUpdatedEvent) - assert received.user_id == "test-user" + assert received.event_type == "test.created" diff --git a/backend/tests/e2e/services/sse/test_partitioned_event_router.py b/backend/tests/e2e/services/sse/test_partitioned_event_router.py deleted file mode 100644 index 6bb6b71f..00000000 --- a/backend/tests/e2e/services/sse/test_partitioned_event_router.py +++ /dev/null @@ -1,81 +0,0 @@ -import asyncio -import logging -from uuid import uuid4 - -import pytest -import redis.asyncio as redis -from app.core.metrics import EventMetrics -from app.events.core import EventDispatcher -from app.events.schema.schema_registry import SchemaRegistryManager -from app.schemas_pydantic.sse import RedisSSEMessage -from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge -from app.services.sse.redis_bus import SSERedisBus -from app.settings import Settings - -from tests.conftest import make_execution_requested_event - -pytestmark = [pytest.mark.e2e, pytest.mark.redis] - -_test_logger = 
logging.getLogger("test.services.sse.partitioned_event_router_integration") - - -@pytest.mark.asyncio -async def test_router_bridges_to_redis(redis_client: redis.Redis, test_settings: Settings) -> None: - suffix = uuid4().hex[:6] - bus = SSERedisBus( - redis_client, - exec_prefix=f"sse:exec:{suffix}:", - notif_prefix=f"sse:notif:{suffix}:", - logger=_test_logger, - ) - router = SSEKafkaRedisBridge( - schema_registry=SchemaRegistryManager(settings=test_settings, logger=_test_logger), - settings=test_settings, - event_metrics=EventMetrics(test_settings), - sse_bus=bus, - logger=_test_logger, - ) - disp = EventDispatcher(logger=_test_logger) - router._register_routing_handlers(disp) - - # Open Redis subscription for our execution id - execution_id = f"e-{uuid4().hex[:8]}" - subscription = await bus.open_subscription(execution_id) - - ev = make_execution_requested_event(execution_id=execution_id) - handler = disp.get_handlers(ev.event_type)[0] - await handler(ev) - - # Await the subscription directly - true async, no polling - msg = await asyncio.wait_for(subscription.get(RedisSSEMessage), timeout=2.0) - assert msg is not None - assert str(msg.event_type) == str(ev.event_type) - - -@pytest.mark.asyncio -async def test_router_start_and_stop(redis_client: redis.Redis, test_settings: Settings) -> None: - test_settings.SSE_CONSUMER_POOL_SIZE = 1 - suffix = uuid4().hex[:6] - router = SSEKafkaRedisBridge( - schema_registry=SchemaRegistryManager(settings=test_settings, logger=_test_logger), - settings=test_settings, - event_metrics=EventMetrics(test_settings), - sse_bus=SSERedisBus( - redis_client, - exec_prefix=f"sse:exec:{suffix}:", - notif_prefix=f"sse:notif:{suffix}:", - logger=_test_logger, - ), - logger=_test_logger, - ) - - await router.__aenter__() - stats = router.get_stats() - assert stats["num_consumers"] == 1 - await router.aclose() - assert router.get_stats()["num_consumers"] == 0 - # idempotent start/stop - await router.__aenter__() - await router.__aenter__() - await router.aclose() - await router.aclose() diff --git a/backend/tests/e2e/test_k8s_worker_create_pod.py b/backend/tests/e2e/test_k8s_worker_create_pod.py index c43bb2e5..d1efcf80 100644 --- a/backend/tests/e2e/test_k8s_worker_create_pod.py +++ b/backend/tests/e2e/test_k8s_worker_create_pod.py @@ -2,13 +2,7 @@ import uuid import pytest -from app.core.metrics import EventMetrics from app.domain.events.typed import CreatePodCommandEvent, EventMetadata -from app.events.core import UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency import IdempotencyManager -from app.services.k8s_worker.config import K8sWorkerConfig from app.services.k8s_worker.worker import KubernetesWorker from app.settings import Settings from dishka import AsyncContainer @@ -25,27 +19,10 @@ async def test_worker_creates_configmap_and_pod( ) -> None: ns = test_settings.K8S_NAMESPACE - schema: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - store: EventStore = await scope.get(EventStore) - producer: UnifiedProducer = await scope.get(UnifiedProducer) - idem: IdempotencyManager = await scope.get(IdempotencyManager) - event_metrics: EventMetrics = await scope.get(EventMetrics) + # Get worker from DI (already configured with dependencies) + worker: KubernetesWorker = await scope.get(KubernetesWorker) - cfg = K8sWorkerConfig(namespace=ns, max_concurrent_pods=1) - worker = KubernetesWorker( - config=cfg, - producer=producer, - 
schema_registry_manager=schema, - settings=test_settings, - event_store=store, - idempotency_manager=idem, - logger=_test_logger, - event_metrics=event_metrics, - ) - - # Initialize k8s clients using worker's own method - worker._initialize_kubernetes_client() # noqa: SLF001 - if worker.v1 is None: + if worker._v1 is None: # noqa: SLF001 pytest.skip("Kubernetes cluster not available") exec_id = uuid.uuid4().hex[:8] @@ -68,7 +45,7 @@ async def test_worker_creates_configmap_and_pod( ) # Build and create ConfigMap + Pod - cm = worker.pod_builder.build_config_map( + cm = worker._pod_builder.build_config_map( # noqa: SLF001 command=cmd, script_content=cmd.script, entrypoint_content=await worker._get_entrypoint_script(), # noqa: SLF001 @@ -80,15 +57,15 @@ async def test_worker_creates_configmap_and_pod( pytest.skip(f"Insufficient permissions or namespace not found: {e}") raise - pod = worker.pod_builder.build_pod_manifest(cmd) + pod = worker._pod_builder.build_pod_manifest(cmd) # noqa: SLF001 await worker._create_pod(pod) # noqa: SLF001 # Verify resources exist - got_cm = worker.v1.read_namespaced_config_map(name=f"script-{exec_id}", namespace=ns) + got_cm = worker._v1.read_namespaced_config_map(name=f"script-{exec_id}", namespace=ns) # noqa: SLF001 assert got_cm is not None - got_pod = worker.v1.read_namespaced_pod(name=f"executor-{exec_id}", namespace=ns) + got_pod = worker._v1.read_namespaced_pod(name=f"executor-{exec_id}", namespace=ns) # noqa: SLF001 assert got_pod is not None # Cleanup - worker.v1.delete_namespaced_pod(name=f"executor-{exec_id}", namespace=ns) - worker.v1.delete_namespaced_config_map(name=f"script-{exec_id}", namespace=ns) + worker._v1.delete_namespaced_pod(name=f"executor-{exec_id}", namespace=ns) # noqa: SLF001 + worker._v1.delete_namespaced_config_map(name=f"script-{exec_id}", namespace=ns) # noqa: SLF001 diff --git a/backend/tests/unit/services/coordinator/test_queue_manager.py b/backend/tests/unit/services/coordinator/test_queue_manager.py deleted file mode 100644 index 671b19a7..00000000 --- a/backend/tests/unit/services/coordinator/test_queue_manager.py +++ /dev/null @@ -1,41 +0,0 @@ -import logging - -import pytest -from app.core.metrics import CoordinatorMetrics -from app.domain.events.typed import ExecutionRequestedEvent -from app.services.coordinator.queue_manager import QueueManager, QueuePriority - -from tests.conftest import make_execution_requested_event - -_test_logger = logging.getLogger("test.services.coordinator.queue_manager") - -pytestmark = pytest.mark.unit - - -def ev(execution_id: str, priority: int = QueuePriority.NORMAL.value) -> ExecutionRequestedEvent: - return make_execution_requested_event(execution_id=execution_id, priority=priority) - - -@pytest.mark.asyncio -async def test_requeue_execution_increments_priority(coordinator_metrics: CoordinatorMetrics) -> None: - qm = QueueManager(max_queue_size=10, logger=_test_logger, coordinator_metrics=coordinator_metrics) - await qm.start() - # Use NORMAL priority which can be incremented to LOW - e = ev("x", priority=QueuePriority.NORMAL.value) - await qm.add_execution(e) - await qm.requeue_execution(e, increment_retry=True) - nxt = await qm.get_next_execution() - assert nxt is not None - await qm.stop() - - -@pytest.mark.asyncio -async def test_queue_stats_empty_and_after_add(coordinator_metrics: CoordinatorMetrics) -> None: - qm = QueueManager(max_queue_size=5, logger=_test_logger, coordinator_metrics=coordinator_metrics) - await qm.start() - stats0 = await qm.get_queue_stats() - assert 
stats0["total_size"] == 0 - await qm.add_execution(ev("a")) - st = await qm.get_queue_stats() - assert st["total_size"] == 1 - await qm.stop() diff --git a/backend/tests/unit/services/coordinator/test_resource_manager.py b/backend/tests/unit/services/coordinator/test_resource_manager.py deleted file mode 100644 index 3624dae6..00000000 --- a/backend/tests/unit/services/coordinator/test_resource_manager.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging - -import pytest -from app.core.metrics import CoordinatorMetrics -from app.services.coordinator.resource_manager import ResourceManager - -_test_logger = logging.getLogger("test.services.coordinator.resource_manager") - - -@pytest.mark.asyncio -async def test_request_allocation_defaults_and_limits(coordinator_metrics: CoordinatorMetrics) -> None: - rm = ResourceManager(total_cpu_cores=8.0, total_memory_mb=16384, total_gpu_count=0, logger=_test_logger, coordinator_metrics=coordinator_metrics) - - # Default for python - alloc = await rm.request_allocation("e1", "python") - assert alloc is not None - - assert alloc.cpu_cores > 0 - assert alloc.memory_mb > 0 - - # Respect per-exec max cap - alloc2 = await rm.request_allocation("e2", "python", requested_cpu=100.0, requested_memory_mb=999999) - assert alloc2 is not None - assert alloc2.cpu_cores <= rm.pool.max_cpu_per_execution - assert alloc2.memory_mb <= rm.pool.max_memory_per_execution_mb - - -@pytest.mark.asyncio -async def test_release_and_can_allocate(coordinator_metrics: CoordinatorMetrics) -> None: - rm = ResourceManager(total_cpu_cores=4.0, total_memory_mb=8192, total_gpu_count=0, logger=_test_logger, coordinator_metrics=coordinator_metrics) - - a = await rm.request_allocation("e1", "python", requested_cpu=1.0, requested_memory_mb=512) - assert a is not None - - ok = await rm.release_allocation("e1") - assert ok is True - - # After release, can allocate near limits while preserving headroom. - # Use a tiny epsilon to avoid edge rounding issues in >= comparisons. 
- epsilon_cpu = 1e-6 - epsilon_mem = 1 - can = await rm.can_allocate(cpu_cores=rm.pool.total_cpu_cores - rm.pool.min_available_cpu_cores - epsilon_cpu, - memory_mb=rm.pool.total_memory_mb - rm.pool.min_available_memory_mb - epsilon_mem, - gpu_count=0) - assert can is True - - -@pytest.mark.asyncio -async def test_resource_stats(coordinator_metrics: CoordinatorMetrics) -> None: - rm = ResourceManager(total_cpu_cores=2.0, total_memory_mb=4096, total_gpu_count=0, logger=_test_logger, coordinator_metrics=coordinator_metrics) - # Make sure the allocation succeeds - alloc = await rm.request_allocation("e1", "python", requested_cpu=0.5, requested_memory_mb=256) - assert alloc is not None, "Allocation should have succeeded" - - stats = await rm.get_resource_stats() - - assert stats.total.cpu_cores > 0 - assert stats.available.cpu_cores >= 0 - assert stats.allocated.cpu_cores > 0 # Should be > 0 since we allocated - assert stats.utilization["cpu_percent"] >= 0 - assert stats.allocation_count >= 1 # Should be at least 1 (may have system allocations) diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py index 283d428e..8916af97 100644 --- a/backend/tests/unit/services/pod_monitor/test_monitor.py +++ b/backend/tests/unit/services/pod_monitor/test_monitor.py @@ -1,3 +1,5 @@ +"""Tests for stateless PodMonitor handler.""" + import asyncio import logging import types @@ -5,37 +7,24 @@ from unittest.mock import MagicMock import pytest -from app.core import k8s_clients as k8s_clients_module -from app.core.k8s_clients import K8sClients +from kubernetes import client as k8s_client + from app.core.metrics import EventMetrics, KubernetesMetrics -from app.db.repositories.event_repository import EventRepository -from app.domain.events.typed import ( - DomainEvent, - EventMetadata, - ExecutionCompletedEvent, - ExecutionStartedEvent, - ResourceUsageDomain, -) +from app.db.repositories.pod_state_repository import PodStateRepository +from app.domain.events.typed import DomainEvent, EventMetadata, ExecutionCompletedEvent +from app.domain.execution.models import ResourceUsageDomain from app.events.core import UnifiedProducer -from app.services.kafka_event_service import KafkaEventService from app.services.pod_monitor.config import PodMonitorConfig from app.services.pod_monitor.event_mapper import PodEventMapper from app.services.pod_monitor.monitor import ( - MonitorState, PodEvent, PodMonitor, ReconciliationResult, WatchEventType, - create_pod_monitor, ) -from app.settings import Settings -from kubernetes.client.rest import ApiException from tests.unit.services.pod_monitor.conftest import ( - MockWatchStream, - Pod, make_mock_v1_api, - make_mock_watch, make_pod, ) @@ -44,55 +33,52 @@ _test_logger = logging.getLogger("test.pod_monitor") -# ===== Test doubles for KafkaEventService dependencies ===== - - -class FakeEventRepository(EventRepository): - """In-memory event repository for testing.""" - - def __init__(self) -> None: - super().__init__(_test_logger) - self.stored_events: list[DomainEvent] = [] - - async def store_event(self, event: DomainEvent) -> str: - self.stored_events.append(event) - return event.event_id - - class FakeUnifiedProducer(UnifiedProducer): """Fake producer that captures events without Kafka.""" def __init__(self) -> None: - # Don't call super().__init__ - we don't need real Kafka self.produced_events: list[tuple[DomainEvent, str | None]] = [] - self.logger = _test_logger + self._logger = _test_logger async def produce( 
self, event_to_produce: DomainEvent, key: str | None = None, headers: dict[str, str] | None = None ) -> None: self.produced_events.append((event_to_produce, key)) - async def aclose(self) -> None: - pass +class FakePodStateRepository: + """Fake pod state repository for testing.""" -def create_test_kafka_event_service(event_metrics: EventMetrics) -> tuple[KafkaEventService, FakeUnifiedProducer]: - """Create real KafkaEventService with fake dependencies for testing.""" - fake_producer = FakeUnifiedProducer() - fake_repo = FakeEventRepository() - settings = Settings() # Uses defaults/env vars + def __init__(self) -> None: + self._tracked: set[str] = set() + self._resource_version: str | None = None - service = KafkaEventService( - event_repository=fake_repo, - kafka_producer=fake_producer, - settings=settings, - logger=_test_logger, - event_metrics=event_metrics, - ) - return service, fake_producer + async def track_pod( + self, pod_name: str, execution_id: str, status: str, + metadata: dict[str, object] | None = None, ttl_seconds: int = 7200, + ) -> None: + self._tracked.add(pod_name) + async def untrack_pod(self, pod_name: str) -> bool: + if pod_name in self._tracked: + self._tracked.discard(pod_name) + return True + return False -# ===== Helpers to create test instances with pure DI ===== + async def is_pod_tracked(self, pod_name: str) -> bool: + return pod_name in self._tracked + + async def get_tracked_pod_names(self) -> set[str]: + return self._tracked.copy() + + async def get_tracked_pods_count(self) -> int: + return len(self._tracked) + + async def get_resource_version(self) -> str | None: + return self._resource_version + + async def set_resource_version(self, version: str) -> None: + self._resource_version = version class SpyMapper: @@ -108,43 +94,28 @@ def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: # n return [] -def make_k8s_clients_di( - events: list[dict[str, Any]] | None = None, - resource_version: str = "rv1", - pods: list[Pod] | None = None, - logs: str = "{}", -) -> K8sClients: - """Create K8sClients for DI with mocks.""" - v1 = make_mock_v1_api(logs=logs, pods=pods) - watch = make_mock_watch(events or [], resource_version) - return K8sClients( - api_client=MagicMock(), - v1=v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=watch, - ) - - def make_pod_monitor( event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, config: PodMonitorConfig | None = None, - kafka_service: KafkaEventService | None = None, - k8s_clients: K8sClients | None = None, + producer: UnifiedProducer | None = None, + pod_state_repo: PodStateRepository | None = None, + v1_client: k8s_client.CoreV1Api | None = None, event_mapper: PodEventMapper | None = None, ) -> PodMonitor: """Create PodMonitor with sensible test defaults.""" cfg = config or PodMonitorConfig() - clients = k8s_clients or make_k8s_clients_di() + prod = producer or FakeUnifiedProducer() + repo = pod_state_repo or FakePodStateRepository() + v1 = v1_client or make_mock_v1_api("{}") mapper = event_mapper or PodEventMapper(logger=_test_logger, k8s_api=make_mock_v1_api("{}")) - service = kafka_service or create_test_kafka_event_service(event_metrics)[0] return PodMonitor( config=cfg, - kafka_event_service=service, - logger=_test_logger, - k8s_clients=clients, + producer=prod, + pod_state_repo=repo, # type: ignore[arg-type] + v1_client=v1, event_mapper=mapper, + logger=_test_logger, kubernetes_metrics=kubernetes_metrics, ) @@ -153,251 +124,127 @@ def make_pod_monitor( @pytest.mark.asyncio 
-async def test_start_and_stop_lifecycle(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False - - spy = SpyMapper() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, event_mapper=spy) # type: ignore[arg-type] +async def test_handle_raw_event_tracks_pod(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that handle_raw_event tracks new pods.""" + fake_repo = FakePodStateRepository() + pm = make_pod_monitor(event_metrics, kubernetes_metrics, pod_state_repo=fake_repo) # type: ignore[arg-type] - # Replace _watch_pods to avoid real watch loop - async def _quick_watch() -> None: - return None + pod = make_pod(name="test-pod", phase="Running", labels={"execution-id": "e1"}, resource_version="v1") + raw_event = {"type": "ADDED", "object": pod} - pm._watch_pods = _quick_watch # type: ignore[method-assign] + await pm.handle_raw_event(raw_event) - await pm.__aenter__() - assert pm.state == MonitorState.RUNNING - - await pm.aclose() - final_state: MonitorState = pm.state - assert final_state == MonitorState.STOPPED - assert spy.cleared is True + assert "test-pod" in fake_repo._tracked @pytest.mark.asyncio -async def test_watch_pod_events_flow_and_publish(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False +async def test_handle_raw_event_untracks_deleted_pod(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that handle_raw_event untracks deleted pods.""" + fake_repo = FakePodStateRepository() + fake_repo._tracked.add("test-pod") + pm = make_pod_monitor(event_metrics, kubernetes_metrics, pod_state_repo=fake_repo) # type: ignore[arg-type] - pod = make_pod(name="p", phase="Succeeded", labels={"execution-id": "e1"}, term_exit=0, resource_version="rv1") - k8s_clients = make_k8s_clients_di(events=[{"type": "MODIFIED", "object": pod}], resource_version="rv2") + pod = make_pod(name="test-pod", phase="Succeeded", labels={"execution-id": "e1"}, resource_version="v2") + raw_event = {"type": "DELETED", "object": pod} - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) - pm._state = MonitorState.RUNNING + await pm.handle_raw_event(raw_event) - await pm._watch_pod_events() - assert pm._last_resource_version == "rv2" + assert "test-pod" not in fake_repo._tracked @pytest.mark.asyncio -async def test_process_raw_event_invalid_and_handle_watch_error(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) +async def test_handle_raw_event_updates_resource_version(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that handle_raw_event updates resource version.""" + fake_repo = FakePodStateRepository() + pm = make_pod_monitor(event_metrics, kubernetes_metrics, pod_state_repo=fake_repo) # type: ignore[arg-type] - await pm._process_raw_event({}) + pod = make_pod(name="test-pod", phase="Running", labels={"execution-id": "e1"}, resource_version="v123") + raw_event = {"type": "ADDED", "object": pod} - pm.config.watch_reconnect_delay = 0 - pm._reconnect_attempts = 0 - await pm._handle_watch_error() - await pm._handle_watch_error() - assert pm._reconnect_attempts >= 2 + await pm.handle_raw_event(raw_event) + + assert fake_repo._resource_version == "v123" 
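The three tests above drive handle_raw_event with hand-built {"type": ..., "object": ...} dicts. In a running worker the same stateless handler would presumably be fed by a Kubernetes watch stream, which yields events of exactly that shape; the following is only a minimal sketch under that assumption (run_watch_loop and its wiring are illustrative, not the PR's actual pod-monitor worker):

import asyncio
from typing import Any

from kubernetes import client, watch


async def run_watch_loop(monitor: Any, namespace: str, label_selector: str) -> None:
    """Illustrative only: forward raw watch events into PodMonitor.handle_raw_event."""
    v1 = client.CoreV1Api()
    w = watch.Watch()
    stream = w.stream(v1.list_namespaced_pod, namespace=namespace, label_selector=label_selector)
    try:
        while True:
            # The blocking stream yields dicts shaped like {"type": "ADDED", "object": <V1Pod>},
            # the same structure the tests above pass to handle_raw_event.
            raw_event = await asyncio.to_thread(next, stream, None)
            if raw_event is None:
                break
            await monitor.handle_raw_event(raw_event)
    finally:
        w.stop()
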
@pytest.mark.asyncio -async def test_get_status(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.namespace = "test-ns" - cfg.label_selector = "app=test" - cfg.enable_state_reconciliation = True +async def test_handle_raw_event_invalid_event(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that handle_raw_event handles invalid events gracefully.""" + pm = make_pod_monitor(event_metrics, kubernetes_metrics) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._tracked_pods = {"pod1", "pod2"} - pm._reconnect_attempts = 3 - pm._last_resource_version = "v123" + # Should not raise for empty event + await pm.handle_raw_event({}) - status = await pm.get_status() - assert "idle" in status["state"].lower() - assert status["tracked_pods"] == 2 - assert status["reconnect_attempts"] == 3 - assert status["last_resource_version"] == "v123" - assert status["config"]["namespace"] == "test-ns" - assert status["config"]["label_selector"] == "app=test" - assert status["config"]["enable_reconciliation"] is True + # Should not raise for event without object + await pm.handle_raw_event({"type": "ADDED"}) @pytest.mark.asyncio -async def test_reconciliation_loop_and_state(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_handle_raw_event_ignored_phase(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that handle_raw_event ignores configured phases.""" cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = True - cfg.reconcile_interval_seconds = 0 # sleep(0) yields control immediately - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - reconcile_called: list[bool] = [] - - async def mock_reconcile() -> ReconciliationResult: - reconcile_called.append(True) - return ReconciliationResult(missing_pods={"p1"}, extra_pods={"p2"}, duration_seconds=0.1, success=True) - - evt = asyncio.Event() - - async def wrapped_reconcile() -> ReconciliationResult: - res = await mock_reconcile() - evt.set() - return res + cfg.ignored_pod_phases = ["Unknown"] + fake_repo = FakePodStateRepository() + pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, pod_state_repo=fake_repo) # type: ignore[arg-type] - pm._reconcile_state = wrapped_reconcile # type: ignore[method-assign] + pod = make_pod(name="ignored-pod", phase="Unknown", labels={"execution-id": "e1"}, resource_version="v1") + raw_event = {"type": "ADDED", "object": pod} - task = asyncio.create_task(pm._reconciliation_loop()) - await asyncio.wait_for(evt.wait(), timeout=1.0) - pm._state = MonitorState.STOPPED - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task + await pm.handle_raw_event(raw_event) - assert len(reconcile_called) > 0 + # Pod should not be tracked due to ignored phase + assert "ignored-pod" not in fake_repo._tracked @pytest.mark.asyncio -async def test_reconcile_state_success(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_reconcile_state_finds_missing_pods(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that reconcile_state identifies missing pods.""" cfg = PodMonitorConfig() cfg.namespace = "test" cfg.label_selector = "app=test" pod1 = make_pod(name="pod1", phase="Running", resource_version="v1") pod2 = make_pod(name="pod2", phase="Running", resource_version="v1") - k8s_clients = 
make_k8s_clients_di(pods=[pod1, pod2]) - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) - pm._tracked_pods = {"pod2", "pod3"} - processed: list[str] = [] + mock_v1 = MagicMock() + mock_v1.list_namespaced_pod.return_value = MagicMock(items=[pod1, pod2]) - async def mock_process(event: PodEvent) -> None: - processed.append(event.pod.metadata.name) + fake_repo = FakePodStateRepository() + fake_repo._tracked.add("pod2") + fake_repo._tracked.add("pod3") # Extra pod not in K8s - pm._process_pod_event = mock_process # type: ignore[method-assign] + pm = make_pod_monitor( + event_metrics, kubernetes_metrics, config=cfg, pod_state_repo=fake_repo, v1_client=mock_v1 # type: ignore[arg-type] + ) - result = await pm._reconcile_state() + result = await pm.reconcile_state() assert result.success is True assert result.missing_pods == {"pod1"} assert result.extra_pods == {"pod3"} - assert "pod1" in processed - assert "pod3" not in pm._tracked_pods @pytest.mark.asyncio -async def test_reconcile_state_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_reconcile_state_handles_api_error(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that reconcile_state handles API errors gracefully.""" cfg = PodMonitorConfig() - fail_v1 = MagicMock() - fail_v1.list_namespaced_pod.side_effect = RuntimeError("API error") + mock_v1 = MagicMock() + mock_v1.list_namespaced_pod.side_effect = RuntimeError("API error") - k8s_clients = K8sClients( - api_client=MagicMock(), - v1=fail_v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=make_mock_watch([]), - ) + pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, v1_client=mock_v1) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) + result = await pm.reconcile_state() - result = await pm._reconcile_state() assert result.success is False assert result.error is not None assert "API error" in result.error @pytest.mark.asyncio -async def test_process_pod_event_full_flow(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.ignored_pod_phases = ["Unknown"] - - class MockMapper: - def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: # noqa: ARG002 - class Event: - event_type = types.SimpleNamespace(value="test_event") - metadata = types.SimpleNamespace(correlation_id=None) - aggregate_id = "agg1" - - return [Event()] - - def clear_cache(self) -> None: - pass - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, event_mapper=MockMapper()) # type: ignore[arg-type] - - published: list[Any] = [] - - async def mock_publish(event: Any, pod: Any) -> None: # noqa: ARG001 - published.append(event) - - pm._publish_event = mock_publish # type: ignore[method-assign] - - event = PodEvent( - event_type=WatchEventType.ADDED, - pod=make_pod(name="test-pod", phase="Running"), - resource_version="v1", - ) - - await pm._process_pod_event(event) - assert "test-pod" in pm._tracked_pods - assert pm._last_resource_version == "v1" - assert len(published) == 1 - - event_del = PodEvent( - event_type=WatchEventType.DELETED, - pod=make_pod(name="test-pod", phase="Succeeded"), - resource_version="v2", - ) - - await pm._process_pod_event(event_del) - assert "test-pod" not in pm._tracked_pods - assert pm._last_resource_version == "v2" - - event_ignored = PodEvent( - event_type=WatchEventType.ADDED, - 
pod=make_pod(name="ignored-pod", phase="Unknown"), - resource_version="v3", - ) - - published.clear() - await pm._process_pod_event(event_ignored) - assert len(published) == 0 - - -@pytest.mark.asyncio -async def test_process_pod_event_exception_handling(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - - class FailMapper: - def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: - raise RuntimeError("Mapping failed") - - def clear_cache(self) -> None: - pass - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, event_mapper=FailMapper()) # type: ignore[arg-type] - - event = PodEvent( - event_type=WatchEventType.ADDED, - pod=make_pod(name="fail-pod", phase="Pending"), - resource_version=None, - ) - - # Should not raise - errors are caught and logged - await pm._process_pod_event(event) - - -@pytest.mark.asyncio -async def test_publish_event_full_flow(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - service, fake_producer = create_test_kafka_event_service(event_metrics) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, kafka_service=service) +async def test_publish_event(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that events are published correctly.""" + fake_producer = FakeUnifiedProducer() + pm = make_pod_monitor(event_metrics, kubernetes_metrics, producer=fake_producer) event = ExecutionCompletedEvent( execution_id="exec1", @@ -415,387 +262,58 @@ async def test_publish_event_full_flow(event_metrics: EventMetrics, kubernetes_m @pytest.mark.asyncio -async def test_publish_event_exception_handling(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - - class FailingProducer(FakeUnifiedProducer): - async def produce( - self, event_to_produce: DomainEvent, key: str | None = None, headers: dict[str, str] | None = None - ) -> None: - raise RuntimeError("Publish failed") - - # Create service with failing producer - failing_producer = FailingProducer() - fake_repo = FakeEventRepository() - failing_service = KafkaEventService( - event_repository=fake_repo, - kafka_producer=failing_producer, - settings=Settings(), - logger=_test_logger, - event_metrics=event_metrics, - ) - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, kafka_service=failing_service) - - event = ExecutionStartedEvent( - execution_id="exec1", - pod_name="test-pod", - metadata=EventMetadata(service_name="test", service_version="1.0"), - ) - - # Use pod with no metadata to exercise edge case - pod = make_pod(name="no-meta-pod", phase="Pending") - pod.metadata = None # type: ignore[assignment] - - # Should not raise - errors are caught and logged - await pm._publish_event(event, pod) - - -@pytest.mark.asyncio -async def test_handle_watch_error_max_attempts(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.max_reconnect_attempts = 2 - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - pm._reconnect_attempts = 2 - - await pm._handle_watch_error() - - assert pm._state == MonitorState.STOPPING - - -@pytest.mark.asyncio -async def test_watch_pods_main_loop(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state 
= MonitorState.RUNNING - - watch_count: list[int] = [] - - async def mock_watch() -> None: - watch_count.append(1) - if len(watch_count) > 2: - pm._state = MonitorState.STOPPED - - async def mock_handle_error() -> None: - pass - - pm._watch_pod_events = mock_watch # type: ignore[method-assign] - pm._handle_watch_error = mock_handle_error # type: ignore[method-assign] - - await pm._watch_pods() - assert len(watch_count) > 2 - - -@pytest.mark.asyncio -async def test_watch_pods_api_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - async def mock_watch() -> None: - raise ApiException(status=410) - - error_handled: list[bool] = [] - - async def mock_handle() -> None: - error_handled.append(True) - pm._state = MonitorState.STOPPED - - pm._watch_pod_events = mock_watch # type: ignore[method-assign] - pm._handle_watch_error = mock_handle # type: ignore[method-assign] - - await pm._watch_pods() - - assert pm._last_resource_version is None - assert len(error_handled) > 0 - - -@pytest.mark.asyncio -async def test_watch_pods_generic_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - async def mock_watch() -> None: - raise RuntimeError("Unexpected error") - - error_handled: list[bool] = [] - - async def mock_handle() -> None: - error_handled.append(True) - pm._state = MonitorState.STOPPED - - pm._watch_pod_events = mock_watch # type: ignore[method-assign] - pm._handle_watch_error = mock_handle # type: ignore[method-assign] - - await pm._watch_pods() - assert len(error_handled) > 0 - - -@pytest.mark.asyncio -async def test_create_pod_monitor_context_manager(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, monkeypatch: pytest.MonkeyPatch) -> None: - """Test create_pod_monitor factory with auto-created dependencies.""" - # Mock create_k8s_clients to avoid real K8s connection - mock_v1 = make_mock_v1_api() - mock_watch = make_mock_watch([]) - mock_clients = K8sClients( - api_client=MagicMock(), - v1=mock_v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=mock_watch, - ) - - def mock_create_clients( - logger: logging.Logger, # noqa: ARG001 - kubeconfig_path: str | None = None, # noqa: ARG001 - in_cluster: bool | None = None, # noqa: ARG001 - ) -> K8sClients: - return mock_clients - - monkeypatch.setattr(k8s_clients_module, "create_k8s_clients", mock_create_clients) - monkeypatch.setattr("app.services.pod_monitor.monitor.create_k8s_clients", mock_create_clients) - - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False - - service, _ = create_test_kafka_event_service(event_metrics) - - # Use the actual create_pod_monitor which will use our mocked create_k8s_clients - async with create_pod_monitor(cfg, service, _test_logger, kubernetes_metrics=kubernetes_metrics) as monitor: - assert monitor.state == MonitorState.RUNNING - - final_state: MonitorState = monitor.state - assert final_state == MonitorState.STOPPED - - -@pytest.mark.asyncio -async def test_create_pod_monitor_with_injected_k8s_clients(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test create_pod_monitor with injected K8sClients (DI path).""" - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False - - service, _ = 
create_test_kafka_event_service(event_metrics) - - mock_v1 = make_mock_v1_api() - mock_watch = make_mock_watch([]) - mock_k8s_clients = K8sClients( - api_client=MagicMock(), - v1=mock_v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=mock_watch, - ) - - async with create_pod_monitor( - cfg, service, _test_logger, k8s_clients=mock_k8s_clients, kubernetes_metrics=kubernetes_metrics - ) as monitor: - assert monitor.state == MonitorState.RUNNING - assert monitor._clients is mock_k8s_clients - assert monitor._v1 is mock_v1 - - final_state: MonitorState = monitor.state - assert final_state == MonitorState.STOPPED - - -@pytest.mark.asyncio -async def test_start_already_running(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test idempotent start via __aenter__.""" - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - - # Simulate already started state - pm._lifecycle_started = True - pm._state = MonitorState.RUNNING - - # Should be idempotent - just return self - await pm.__aenter__() - - -@pytest.mark.asyncio -async def test_stop_already_stopped(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test idempotent stop via aclose().""" - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.STOPPED - # Not started, so aclose should be a no-op - - await pm.aclose() - - -@pytest.mark.asyncio -async def test_stop_with_tasks(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test cleanup of tasks on aclose().""" - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - pm._lifecycle_started = True - - async def dummy_task() -> None: - await asyncio.Event().wait() - - pm._watch_task = asyncio.create_task(dummy_task()) - pm._reconcile_task = asyncio.create_task(dummy_task()) - pm._tracked_pods = {"pod1"} - - await pm.aclose() - - assert pm._state == MonitorState.STOPPED - assert len(pm._tracked_pods) == 0 - - -def test_update_resource_version(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - - class Stream: - _stop_event = types.SimpleNamespace(resource_version="v123") - - pm._update_resource_version(Stream()) - assert pm._last_resource_version == "v123" - - class BadStream: - pass - - pm._update_resource_version(BadStream()) - - -@pytest.mark.asyncio -async def test_process_raw_event_with_metadata(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - - processed: list[PodEvent] = [] - - async def mock_process(event: PodEvent) -> None: - processed.append(event) - - pm._process_pod_event = mock_process # type: ignore[method-assign] - - raw_event = { - "type": "ADDED", - "object": types.SimpleNamespace(metadata=types.SimpleNamespace(resource_version="v1")), - } - - await pm._process_raw_event(raw_event) - assert len(processed) == 1 - assert processed[0].resource_version == "v1" - - raw_event_no_meta = {"type": "MODIFIED", "object": types.SimpleNamespace(metadata=None)} - - await pm._process_raw_event(raw_event_no_meta) - assert len(processed) == 2 - assert processed[1].resource_version is None - - -@pytest.mark.asyncio -async def 
test_watch_pods_api_exception_other_status(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - async def mock_watch() -> None: - raise ApiException(status=500) - - error_handled: list[bool] = [] - - async def mock_handle() -> None: - error_handled.append(True) - pm._state = MonitorState.STOPPED - - pm._watch_pod_events = mock_watch # type: ignore[method-assign] - pm._handle_watch_error = mock_handle # type: ignore[method-assign] - - await pm._watch_pods() - assert len(error_handled) > 0 - - -@pytest.mark.asyncio -async def test_watch_pod_events_with_field_selector(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.field_selector = "status.phase=Running" - cfg.enable_state_reconciliation = False - - watch_kwargs: list[dict[str, Any]] = [] - - tracking_v1 = MagicMock() - - def track_list(namespace: str, label_selector: str) -> None: - watch_kwargs.append({"namespace": namespace, "label_selector": label_selector}) - return None - - tracking_v1.list_namespaced_pod.side_effect = track_list - - tracking_watch = MagicMock() +async def test_process_pod_event_publishes_mapped_events(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that _process_pod_event publishes events from mapper.""" + fake_producer = FakeUnifiedProducer() + fake_repo = FakePodStateRepository() - def track_stream(func: Any, **kwargs: Any) -> MockWatchStream: # noqa: ARG001 - watch_kwargs.append(kwargs) - return MockWatchStream([], "rv1") + class MockMapper: + def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: # noqa: ARG002 + return [ + ExecutionCompletedEvent( + execution_id="e1", + aggregate_id="e1", + exit_code=0, + resource_usage=ResourceUsageDomain(), + metadata=EventMetadata(service_name="test", service_version="1.0"), + ) + ] - tracking_watch.stream.side_effect = track_stream - tracking_watch.stop.return_value = None + def clear_cache(self) -> None: + pass - k8s_clients = K8sClients( - api_client=MagicMock(), - v1=tracking_v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=tracking_watch, + pm = make_pod_monitor( + event_metrics, + kubernetes_metrics, + producer=fake_producer, + pod_state_repo=fake_repo, # type: ignore[arg-type] + event_mapper=MockMapper(), # type: ignore[arg-type] ) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) - pm._state = MonitorState.RUNNING - - await pm._watch_pod_events() - - assert any("field_selector" in kw for kw in watch_kwargs) - - -@pytest.mark.asyncio -async def test_reconciliation_loop_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = True - cfg.reconcile_interval_seconds = 0 # sleep(0) yields control immediately - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - hit = asyncio.Event() + pod = make_pod(name="test-pod", phase="Running", labels={"execution-id": "e1"}) + event = PodEvent(event_type=WatchEventType.ADDED, pod=pod, resource_version="v1") - async def raising() -> ReconciliationResult: - hit.set() - raise RuntimeError("Reconcile error") - - pm._reconcile_state = raising # type: ignore[method-assign] + await pm._process_pod_event(event) - task = 
asyncio.create_task(pm._reconciliation_loop()) - await asyncio.wait_for(hit.wait(), timeout=1.0) - pm._state = MonitorState.STOPPED - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task + assert len(fake_producer.produced_events) == 1 + assert "test-pod" in fake_repo._tracked @pytest.mark.asyncio -async def test_start_with_reconciliation(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = True - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) +async def test_process_pod_event_handles_mapper_error(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: + """Test that _process_pod_event handles mapper errors gracefully.""" - async def mock_watch() -> None: - return None + class FailMapper: + def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: + raise RuntimeError("Mapping failed") - async def mock_reconcile() -> None: - return None + def clear_cache(self) -> None: + pass - pm._watch_pods = mock_watch # type: ignore[method-assign] - pm._reconciliation_loop = mock_reconcile # type: ignore[method-assign] + pm = make_pod_monitor(event_metrics, kubernetes_metrics, event_mapper=FailMapper()) # type: ignore[arg-type] - await pm.__aenter__() - assert pm._watch_task is not None - assert pm._reconcile_task is not None + pod = make_pod(name="fail-pod", phase="Pending") + event = PodEvent(event_type=WatchEventType.ADDED, pod=pod, resource_version=None) - await pm.aclose() + # Should not raise - errors are caught and logged + await pm._process_pod_event(event) diff --git a/backend/tests/unit/services/result_processor/test_processor.py b/backend/tests/unit/services/result_processor/test_processor.py index c13fe0ab..90f12556 100644 --- a/backend/tests/unit/services/result_processor/test_processor.py +++ b/backend/tests/unit/services/result_processor/test_processor.py @@ -1,16 +1,9 @@ -import logging -from unittest.mock import MagicMock - import pytest -from app.core.metrics import EventMetrics, ExecutionMetrics -from app.domain.enums.events import EventType from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId, KafkaTopic -from app.services.result_processor.processor import ResultProcessor, ResultProcessorConfig +from app.services.result_processor.processor import ResultProcessorConfig pytestmark = pytest.mark.unit -_test_logger = logging.getLogger("test.services.result_processor.processor") - class TestResultProcessorConfig: def test_default_values(self) -> None: @@ -27,24 +20,3 @@ def test_custom_values(self) -> None: config = ResultProcessorConfig(batch_size=20, processing_timeout=600) assert config.batch_size == 20 assert config.processing_timeout == 600 - - -def test_create_dispatcher_registers_handlers( - execution_metrics: ExecutionMetrics, event_metrics: EventMetrics -) -> None: - rp = ResultProcessor( - execution_repo=MagicMock(), - producer=MagicMock(), - schema_registry=MagicMock(), - settings=MagicMock(), - idempotency_manager=MagicMock(), - logger=_test_logger, - execution_metrics=execution_metrics, - event_metrics=event_metrics, - ) - dispatcher = rp._create_dispatcher() - assert dispatcher is not None - assert EventType.EXECUTION_COMPLETED in dispatcher._handlers - assert EventType.EXECUTION_FAILED in dispatcher._handlers - assert EventType.EXECUTION_TIMEOUT in dispatcher._handlers - diff --git a/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py 
b/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py index 8f2b35f9..c0bc6628 100644 --- a/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py +++ b/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py @@ -10,13 +10,9 @@ from app.domain.events.typed import DomainEvent, ExecutionRequestedEvent from app.domain.saga.models import Saga, SagaConfig from app.events.core import UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency.idempotency_manager import IdempotencyManager from app.services.saga.base_saga import BaseSaga from app.services.saga.saga_orchestrator import SagaOrchestrator from app.services.saga.saga_step import CompensationStep, SagaContext, SagaStep -from app.settings import Settings from tests.conftest import make_execution_requested_event @@ -52,23 +48,6 @@ async def produce( return None -class _FakeIdem(IdempotencyManager): - """Fake IdempotencyManager for testing.""" - - def __init__(self) -> None: - pass # Skip parent __init__ - - async def close(self) -> None: - return None - - -class _FakeStore(EventStore): - """Fake EventStore for testing.""" - - def __init__(self) -> None: - pass # Skip parent __init__ - - class _FakeAlloc(ResourceAllocationRepository): """Fake ResourceAllocationRepository for testing.""" @@ -105,10 +84,6 @@ def _orch(event_metrics: EventMetrics) -> SagaOrchestrator: config=SagaConfig(name="t", enable_compensation=True, store_events=True, publish_commands=False), saga_repository=_FakeRepo(), producer=_FakeProd(), - schema_registry_manager=MagicMock(spec=SchemaRegistryManager), - settings=MagicMock(spec=Settings), - event_store=_FakeStore(), - idempotency_manager=_FakeIdem(), resource_allocation_repository=_FakeAlloc(), logger=_test_logger, event_metrics=event_metrics, @@ -119,11 +94,9 @@ def _orch(event_metrics: EventMetrics) -> SagaOrchestrator: async def test_min_success_flow(event_metrics: EventMetrics) -> None: orch = _orch(event_metrics) orch.register_saga(_Saga) - # Set orchestrator running state via lifecycle property - orch._lifecycle_started = True - await orch._handle_event(make_execution_requested_event(execution_id="e")) - # basic sanity; deep behavior covered by integration - assert orch.is_running is True + # Stateless orchestrator - just call handle_event directly + await orch.handle_event(make_execution_requested_event(execution_id="e")) + # Basic sanity - no exception means success; deep behavior covered by integration @pytest.mark.asyncio @@ -133,10 +106,6 @@ async def test_should_trigger_and_existing_short_circuit(event_metrics: EventMet config=SagaConfig(name="t", enable_compensation=True, store_events=True, publish_commands=False), saga_repository=fake_repo, producer=_FakeProd(), - schema_registry_manager=MagicMock(spec=SchemaRegistryManager), - settings=MagicMock(spec=Settings), - event_store=_FakeStore(), - idempotency_manager=_FakeIdem(), resource_allocation_repository=_FakeAlloc(), logger=_test_logger, event_metrics=event_metrics, diff --git a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py index 6fa5d1ef..d2fd5ebd 100644 --- a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py +++ b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py @@ -1,15 +1,10 @@ import logging -from unittest.mock import MagicMock import pytest -from app.core.metrics import EventMetrics from app.domain.enums.events import 
EventType from app.domain.events.typed import DomainEvent, EventMetadata, ExecutionStartedEvent -from app.events.core import EventDispatcher -from app.events.schema.schema_registry import SchemaRegistryManager from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge from app.services.sse.redis_bus import SSERedisBus -from app.settings import Settings pytestmark = pytest.mark.unit @@ -31,34 +26,42 @@ def _make_metadata() -> EventMetadata: @pytest.mark.asyncio -async def test_register_and_route_events_without_kafka() -> None: - # Build the bridge but don't call start(); directly test routing handlers +async def test_handle_event_routes_to_redis_bus() -> None: + """Test that handle_event routes events to Redis bus.""" fake_bus = _FakeBus() - mock_settings = MagicMock(spec=Settings) - mock_settings.KAFKA_BOOTSTRAP_SERVERS = "kafka:9092" - mock_settings.SSE_CONSUMER_POOL_SIZE = 1 bridge = SSEKafkaRedisBridge( - schema_registry=MagicMock(spec=SchemaRegistryManager), - settings=mock_settings, - event_metrics=MagicMock(spec=EventMetrics), sse_bus=fake_bus, logger=_test_logger, ) - disp = EventDispatcher(_test_logger) - bridge._register_routing_handlers(disp) - handlers = disp.get_handlers(EventType.EXECUTION_STARTED) - assert len(handlers) > 0 - # Event with empty execution_id is ignored - h = handlers[0] - await h(ExecutionStartedEvent(execution_id="", pod_name="p", metadata=_make_metadata())) + await bridge.handle_event( + ExecutionStartedEvent(execution_id="", pod_name="p", metadata=_make_metadata()) + ) assert fake_bus.published == [] # Proper event is published - await h(ExecutionStartedEvent(execution_id="exec-123", pod_name="p", metadata=_make_metadata())) + await bridge.handle_event( + ExecutionStartedEvent(execution_id="exec-123", pod_name="p", metadata=_make_metadata()) + ) assert fake_bus.published and fake_bus.published[-1][0] == "exec-123" - s = bridge.get_stats() - assert s["num_consumers"] == 0 and s["is_running"] is False + +@pytest.mark.asyncio +async def test_get_status_returns_relevant_event_types() -> None: + """Test that get_status returns relevant event types.""" + fake_bus = _FakeBus() + bridge = SSEKafkaRedisBridge(sse_bus=fake_bus, logger=_test_logger) + + status = await bridge.get_status() + assert "relevant_event_types" in status + assert len(status["relevant_event_types"]) > 0 + + +def test_get_relevant_event_types() -> None: + """Test static method returns relevant event types.""" + event_types = SSEKafkaRedisBridge.get_relevant_event_types() + assert EventType.EXECUTION_STARTED in event_types + assert EventType.EXECUTION_COMPLETED in event_types + assert EventType.RESULT_STORED in event_types diff --git a/backend/tests/unit/services/sse/test_shutdown_manager.py b/backend/tests/unit/services/sse/test_shutdown_manager.py index 05f6e023..1dfc07cc 100644 --- a/backend/tests/unit/services/sse/test_shutdown_manager.py +++ b/backend/tests/unit/services/sse/test_shutdown_manager.py @@ -2,28 +2,22 @@ import logging import pytest -from app.core.lifecycle import LifecycleEnabled + from app.core.metrics import ConnectionMetrics from app.services.sse.sse_shutdown_manager import SSEShutdownManager _test_logger = logging.getLogger("test.services.sse.shutdown_manager") -class _FakeRouter(LifecycleEnabled): - """Fake router that tracks whether aclose was called.""" - - def __init__(self) -> None: - super().__init__() - self.stopped = False - self._lifecycle_started = True # Simulate already-started router - - async def _on_stop(self) -> None: - self.stopped = True - - 
@pytest.mark.asyncio async def test_shutdown_graceful_notify_and_drain(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager(drain_timeout=1.0, notification_timeout=0.01, force_close_timeout=0.1, logger=_test_logger, connection_metrics=connection_metrics) + mgr = SSEShutdownManager( + drain_timeout=1.0, + notification_timeout=0.01, + force_close_timeout=0.1, + logger=_test_logger, + connection_metrics=connection_metrics, + ) # Register two connections and arrange that they unregister when notified ev1 = await mgr.register_connection("e1", "c1") @@ -46,12 +40,14 @@ async def on_shutdown(event: asyncio.Event, cid: str) -> None: @pytest.mark.asyncio -async def test_shutdown_force_close_calls_router_stop_and_rejects_new(connection_metrics: ConnectionMetrics) -> None: +async def test_shutdown_force_close_and_rejects_new(connection_metrics: ConnectionMetrics) -> None: mgr = SSEShutdownManager( - drain_timeout=0.01, notification_timeout=0.01, force_close_timeout=0.01, logger=_test_logger, connection_metrics=connection_metrics + drain_timeout=0.01, + notification_timeout=0.01, + force_close_timeout=0.01, + logger=_test_logger, + connection_metrics=connection_metrics, ) - router = _FakeRouter() - mgr.set_router(router) # Register a connection but never unregister -> force close path ev = await mgr.register_connection("e1", "c1") @@ -59,7 +55,6 @@ async def test_shutdown_force_close_calls_router_stop_and_rejects_new(connection # Initiate shutdown await mgr.initiate_shutdown() - assert router.stopped is True assert mgr.is_shutting_down() is True status = mgr.get_shutdown_status() assert status.draining_connections == 0 @@ -71,7 +66,13 @@ async def test_shutdown_force_close_calls_router_stop_and_rejects_new(connection @pytest.mark.asyncio async def test_get_shutdown_status_transitions(connection_metrics: ConnectionMetrics) -> None: - m = SSEShutdownManager(drain_timeout=0.01, notification_timeout=0.0, force_close_timeout=0.0, logger=_test_logger, connection_metrics=connection_metrics) + m = SSEShutdownManager( + drain_timeout=0.01, + notification_timeout=0.0, + force_close_timeout=0.0, + logger=_test_logger, + connection_metrics=connection_metrics, + ) st0 = m.get_shutdown_status() assert st0.phase == "ready" await m.initiate_shutdown() diff --git a/backend/tests/unit/services/sse/test_sse_service.py b/backend/tests/unit/services/sse/test_sse_service.py index 3c86a15a..c33298ce 100644 --- a/backend/tests/unit/services/sse/test_sse_service.py +++ b/backend/tests/unit/services/sse/test_sse_service.py @@ -10,9 +10,8 @@ from app.db.repositories.sse_repository import SSERepository from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus -from app.domain.events import ResourceUsageDomain -from app.domain.execution import DomainExecution -from app.domain.sse import ShutdownStatus, SSEExecutionStatusDomain, SSEHealthDomain +from app.domain.execution import DomainExecution, ResourceUsageDomain +from app.domain.sse import ShutdownStatus, SSEExecutionStatusDomain from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge from app.services.sse.redis_bus import SSERedisBus, SSERedisSubscription from app.services.sse.sse_service import SSEService @@ -240,12 +239,3 @@ async def test_notification_stream_connected_and_heartbeat_and_message(connectio # Give the generator a chance to observe the flag and finish with pytest.raises(StopAsyncIteration): await asyncio.wait_for(agen.__anext__(), timeout=0.2) - - -@pytest.mark.asyncio -async def 
test_health_status_shape(connection_metrics: ConnectionMetrics) -> None: - svc = SSEService(repository=_FakeRepo(), router=_FakeRouter(), sse_bus=_FakeBus(), shutdown_manager=_FakeShutdown(), - settings=_make_fake_settings(), logger=_test_logger, connection_metrics=connection_metrics) - h = await svc.get_health_status() - assert isinstance(h, SSEHealthDomain) - assert h.active_consumers == 3 and h.active_executions == 2 diff --git a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py index fc7ffb3b..e15c427a 100644 --- a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py +++ b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py @@ -2,7 +2,7 @@ import logging import pytest -from app.core.lifecycle import LifecycleEnabled + from app.core.metrics import ConnectionMetrics from app.services.sse.sse_shutdown_manager import SSEShutdownManager @@ -11,22 +11,15 @@ _test_logger = logging.getLogger("test.services.sse.sse_shutdown_manager") -class _FakeRouter(LifecycleEnabled): - """Fake router that tracks whether aclose was called.""" - - def __init__(self) -> None: - super().__init__() - self.stopped = False - self._lifecycle_started = True # Simulate already-started router - - async def _on_stop(self) -> None: - self.stopped = True - - @pytest.mark.asyncio async def test_register_unregister_and_shutdown_flow(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager(drain_timeout=0.5, notification_timeout=0.1, force_close_timeout=0.1, logger=_test_logger, connection_metrics=connection_metrics) - mgr.set_router(_FakeRouter()) + mgr = SSEShutdownManager( + drain_timeout=0.5, + notification_timeout=0.1, + force_close_timeout=0.1, + logger=_test_logger, + connection_metrics=connection_metrics, + ) # Register two connections e1 = await mgr.register_connection("exec-1", "c1") @@ -52,8 +45,13 @@ async def test_register_unregister_and_shutdown_flow(connection_metrics: Connect @pytest.mark.asyncio async def test_reject_new_connection_during_shutdown(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager(drain_timeout=0.5, notification_timeout=0.01, force_close_timeout=0.01, - logger=_test_logger, connection_metrics=connection_metrics) + mgr = SSEShutdownManager( + drain_timeout=0.5, + notification_timeout=0.01, + force_close_timeout=0.01, + logger=_test_logger, + connection_metrics=connection_metrics, + ) # Pre-register one active connection - shutdown will block waiting for it e = await mgr.register_connection("e", "c0") assert e is not None diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py index 97598539..47b3ec53 100644 --- a/backend/workers/dlq_processor.py +++ b/backend/workers/dlq_processor.py @@ -46,14 +46,14 @@ def _configure_retry_policies(manager: DLQManager, logger: logging.Logger) -> No topic="websocket-events", strategy=RetryStrategy.FIXED_INTERVAL, max_retries=10, base_delay_seconds=10 ), ) - manager.default_retry_policy = RetryPolicy( + manager.set_default_retry_policy(RetryPolicy( topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF, max_retries=4, base_delay_seconds=60, max_delay_seconds=1800, retry_multiplier=2.5, - ) + )) def _configure_filters(manager: DLQManager, testing: bool, logger: logging.Logger) -> None: diff --git a/backend/workers/run_coordinator.py b/backend/workers/run_coordinator.py index 12004bf1..969b240d 100644 --- a/backend/workers/run_coordinator.py +++ b/backend/workers/run_coordinator.py @@ -1,15 +1,21 @@ 
+"""Coordinator worker entrypoint - stateless event processing. + +Consumes execution events from Kafka and dispatches to ExecutionCoordinator handlers. +DI container manages all lifecycle - worker just iterates over consumer. +""" + import asyncio import logging -import signal +from aiokafka import AIOKafkaConsumer from app.core.container import create_coordinator_container from app.core.database_context import Database from app.core.logging import setup_logger from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId +from app.events.core import UnifiedConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.coordinator.coordinator import ExecutionCoordinator from app.settings import Settings from beanie import init_beanie @@ -18,6 +24,7 @@ async def run_coordinator(settings: Settings) -> None: """Run the execution coordinator service.""" container = create_coordinator_container(settings) + logger = await container.get(logging.Logger) logger.info("Starting ExecutionCoordinator with DI container...") @@ -27,27 +34,18 @@ async def run_coordinator(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) - # Services are already started by the DI container providers - coordinator = await container.get(ExecutionCoordinator) - - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) - - logger.info("ExecutionCoordinator started and running") - - try: - # Wait for shutdown signal or service to stop - while coordinator.is_running and not shutdown_event.is_set(): - await asyncio.sleep(60) - status = await coordinator.get_status() - logger.info(f"Coordinator status: {status}") - finally: - # Container cleanup stops everything - logger.info("Initiating graceful shutdown...") - await container.close() + kafka_consumer = await container.get(AIOKafkaConsumer) + handler = await container.get(UnifiedConsumer) + + logger.info("ExecutionCoordinator started, consuming events...") + + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + + logger.info("ExecutionCoordinator shutdown complete") + + await container.close() def main() -> None: diff --git a/backend/workers/run_event_replay.py b/backend/workers/run_event_replay.py index 95c38dad..757cb369 100644 --- a/backend/workers/run_event_replay.py +++ b/backend/workers/run_event_replay.py @@ -1,60 +1,38 @@ +"""Event replay worker entrypoint - stateless replay service. + +Provides event replay capability. DI container manages all lifecycle. +This service doesn't consume from Kafka - it's an HTTP-driven replay service. 
+""" + import asyncio import logging -from contextlib import AsyncExitStack from app.core.container import create_event_replay_container from app.core.database_context import Database from app.core.logging import setup_logger from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS -from app.events.core import UnifiedProducer -from app.services.event_replay.replay_service import EventReplayService from app.settings import Settings from beanie import init_beanie -async def cleanup_task(replay_service: EventReplayService, logger: logging.Logger, interval_hours: int = 6) -> None: - """Periodically clean up old replay sessions""" - while True: - try: - await asyncio.sleep(interval_hours * 3600) - removed = await replay_service.cleanup_old_sessions(older_than_hours=48) - logger.info(f"Cleaned up {removed} old replay sessions") - except Exception as e: - logger.error(f"Error during cleanup: {e}") - - async def run_replay_service(settings: Settings) -> None: - """Run the event replay service with cleanup task.""" + """Run the event replay service.""" container = create_event_replay_container(settings) + logger = await container.get(logging.Logger) logger.info("Starting EventReplayService with DI container...") db = await container.get(Database) await init_beanie(database=db, document_models=ALL_DOCUMENTS) - producer = await container.get(UnifiedProducer) - replay_service = await container.get(EventReplayService) - - logger.info("Event replay service initialized") - - async with AsyncExitStack() as stack: - stack.push_async_callback(container.close) - await stack.enter_async_context(producer) - - task = asyncio.create_task(cleanup_task(replay_service, logger)) - - async def _cancel_task() -> None: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + logger.info("Event replay service initialized and ready") - stack.push_async_callback(_cancel_task) + # Service is HTTP-driven, wait for external shutdown + await asyncio.Event().wait() - await asyncio.Event().wait() + await container.close() def main() -> None: diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index d3b857ad..9e11a25b 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -1,15 +1,21 @@ +"""Kubernetes worker entrypoint - stateless event processing. + +Consumes pod creation events from Kafka and dispatches to KubernetesWorker handlers. +DI container manages all lifecycle - worker just iterates over consumer. 
+""" + import asyncio import logging -import signal +from aiokafka import AIOKafkaConsumer from app.core.container import create_k8s_worker_container from app.core.database_context import Database from app.core.logging import setup_logger from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId +from app.events.core import UnifiedConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.k8s_worker.worker import KubernetesWorker from app.settings import Settings from beanie import init_beanie @@ -27,27 +33,18 @@ async def run_kubernetes_worker(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) - # Services are already started by the DI container providers - worker = await container.get(KubernetesWorker) - - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) - - logger.info("KubernetesWorker started and running") - - try: - # Wait for shutdown signal or service to stop - while worker.is_running and not shutdown_event.is_set(): - await asyncio.sleep(60) - status = await worker.get_status() - logger.info(f"Kubernetes worker status: {status}") - finally: - # Container cleanup stops everything - logger.info("Initiating graceful shutdown...") - await container.close() + kafka_consumer = await container.get(AIOKafkaConsumer) + handler = await container.get(UnifiedConsumer) + + logger.info("KubernetesWorker started, consuming events...") + + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + + logger.info("KubernetesWorker shutdown complete") + + await container.close() def main() -> None: diff --git a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py index 4b4dd325..23997ed7 100644 --- a/backend/workers/run_pod_monitor.py +++ b/backend/workers/run_pod_monitor.py @@ -1,20 +1,24 @@ +"""Pod monitor worker entrypoint - consumes pod events from Kafka. + +Same pattern as other workers - pure Kafka consumer. +K8s watch is externalized to a separate component that publishes to Kafka. 
+""" + import asyncio import logging -import signal +from aiokafka import AIOKafkaConsumer from app.core.container import create_pod_monitor_container from app.core.database_context import Database from app.core.logging import setup_logger from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId +from app.events.core import UnifiedConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.pod_monitor.monitor import MonitorState, PodMonitor from app.settings import Settings from beanie import init_beanie -RECONCILIATION_LOG_INTERVAL: int = 60 - async def run_pod_monitor(settings: Settings) -> None: """Run the pod monitor service.""" @@ -29,27 +33,18 @@ async def run_pod_monitor(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) - # Services are already started by the DI container providers - monitor = await container.get(PodMonitor) - - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) - - logger.info("PodMonitor started and running") - - try: - # Wait for shutdown signal or service to stop - while monitor.state == MonitorState.RUNNING and not shutdown_event.is_set(): - await asyncio.sleep(RECONCILIATION_LOG_INTERVAL) - status = await monitor.get_status() - logger.info(f"Pod monitor status: {status}") - finally: - # Container cleanup stops everything - logger.info("Initiating graceful shutdown...") - await container.close() + kafka_consumer = await container.get(AIOKafkaConsumer) + handler = await container.get(UnifiedConsumer) + + logger.info("PodMonitor started, consuming events...") + + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + + logger.info("PodMonitor shutdown complete") + + await container.close() def main() -> None: diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index 5431b011..c7b557db 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -1,74 +1,50 @@ +"""Result processor worker entrypoint - stateless event processing. + +Consumes execution completion events from Kafka and dispatches to ResultProcessor handlers. +DI container manages all lifecycle - worker just iterates over consumer. 
+""" + import asyncio import logging -import signal -from contextlib import AsyncExitStack +from aiokafka import AIOKafkaConsumer from app.core.container import create_result_processor_container +from app.core.database_context import Database from app.core.logging import setup_logger -from app.core.metrics import EventMetrics, ExecutionMetrics from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS -from app.db.repositories.execution_repository import ExecutionRepository from app.domain.enums.kafka import GroupId -from app.events.core import UnifiedProducer -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency import IdempotencyManager -from app.services.result_processor.processor import ProcessingState, ResultProcessor +from app.events.core import UnifiedConsumer +from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.settings import Settings from beanie import init_beanie -from pymongo.asynchronous.mongo_client import AsyncMongoClient async def run_result_processor(settings: Settings) -> None: - - db_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( - settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 - ) - await init_beanie(database=db_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) + """Run the result processor service.""" container = create_result_processor_container(settings) - producer = await container.get(UnifiedProducer) - schema_registry = await container.get(SchemaRegistryManager) - idempotency_manager = await container.get(IdempotencyManager) - execution_repo = await container.get(ExecutionRepository) - execution_metrics = await container.get(ExecutionMetrics) - event_metrics = await container.get(EventMetrics) logger = await container.get(logging.Logger) - logger.info(f"Beanie ODM initialized with {len(ALL_DOCUMENTS)} document models") - - # ResultProcessor is manually created (not from DI), so we own its lifecycle - processor = ResultProcessor( - execution_repo=execution_repo, - producer=producer, - schema_registry=schema_registry, - settings=settings, - idempotency_manager=idempotency_manager, - logger=logger, - execution_metrics=execution_metrics, - event_metrics=event_metrics, - ) - - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) - - # We own the processor, so we use async with to manage its lifecycle - async with AsyncExitStack() as stack: - stack.callback(db_client.close) - stack.push_async_callback(container.close) - await stack.enter_async_context(processor) - - logger.info("ResultProcessor started and running") - - # Wait for shutdown signal or service to stop - while processor._state == ProcessingState.PROCESSING and not shutdown_event.is_set(): - await asyncio.sleep(60) - status = await processor.get_status() - logger.info(f"ResultProcessor status: {status}") - - logger.info("Initiating graceful shutdown...") + logger.info("Starting ResultProcessor with DI container...") + + db = await container.get(Database) + await init_beanie(database=db, document_models=ALL_DOCUMENTS) + + schema_registry = await container.get(SchemaRegistryManager) + await initialize_event_schemas(schema_registry) + + kafka_consumer = await container.get(AIOKafkaConsumer) + handler = await container.get(UnifiedConsumer) + + logger.info("ResultProcessor started, 
consuming events...") + + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() + + logger.info("ResultProcessor shutdown complete") + + await container.close() def main() -> None: diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 7fd0c359..3a230be8 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -1,21 +1,27 @@ +"""Saga orchestrator worker entrypoint - stateless event processing. + +Consumes execution events from Kafka and dispatches to SagaOrchestrator handlers. +DI container manages all lifecycle - worker just iterates over consumer. +""" + import asyncio import logging -import signal +from aiokafka import AIOKafkaConsumer from app.core.container import create_saga_orchestrator_container from app.core.database_context import Database from app.core.logging import setup_logger from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId +from app.events.core import UnifiedConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.saga import SagaOrchestrator from app.settings import Settings from beanie import init_beanie async def run_saga_orchestrator(settings: Settings) -> None: - """Run the saga orchestrator.""" + """Run the saga orchestrator service.""" container = create_saga_orchestrator_container(settings) logger = await container.get(logging.Logger) @@ -27,27 +33,18 @@ async def run_saga_orchestrator(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) - # Services are already started by the DI container providers - orchestrator = await container.get(SagaOrchestrator) + kafka_consumer = await container.get(AIOKafkaConsumer) + handler = await container.get(UnifiedConsumer) - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) + logger.info("SagaOrchestrator started, consuming events...") - logger.info("Saga orchestrator started and running") + async for msg in kafka_consumer: + await handler.handle(msg) + await kafka_consumer.commit() - try: - # Wait for shutdown signal or service to stop - while orchestrator.is_running and not shutdown_event.is_set(): - await asyncio.sleep(1) - finally: - # Container cleanup stops everything - logger.info("Initiating graceful shutdown...") - await container.close() + logger.info("SagaOrchestrator shutdown complete") - logger.warning("Saga orchestrator stopped") + await container.close() def main() -> None: From 528aaa5a3b7fb1f10cfa61749663567ce523df17 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Tue, 27 Jan 2026 16:52:08 +0100 Subject: [PATCH 2/3] mypy and failed test fix --- backend/app/services/idempotency/middleware.py | 11 ++++++----- backend/app/services/result_processor/processor.py | 6 +++--- .../tests/unit/services/pod_monitor/test_monitor.py | 2 +- backend/tests/unit/services/sse/test_sse_service.py | 3 ++- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/backend/app/services/idempotency/middleware.py b/backend/app/services/idempotency/middleware.py index 4dac5287..7fd3d1e3 100644 --- a/backend/app/services/idempotency/middleware.py +++ b/backend/app/services/idempotency/middleware.py @@ -6,6 +6,7 @@ from 
app.domain.enums.events import EventType from app.domain.events.typed import DomainEvent +from app.domain.idempotency import KeyStrategy from app.events.core import EventDispatcher, UnifiedConsumer from app.services.idempotency.idempotency_manager import IdempotencyManager @@ -18,7 +19,7 @@ def __init__( handler: Callable[[DomainEvent], Awaitable[None]], idempotency_manager: IdempotencyManager, logger: logging.Logger, - key_strategy: str = "event_based", + key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, custom_key_func: Callable[[DomainEvent], str] | None = None, fields: Set[str] | None = None, ttl_seconds: int | None = None, @@ -43,7 +44,7 @@ async def __call__(self, event: DomainEvent) -> None: ) # Generate custom key if function provided custom_key = None - if self.key_strategy == "custom" and self.custom_key_func: + if self.key_strategy == KeyStrategy.CUSTOM and self.custom_key_func: custom_key = self.custom_key_func(event) # Check idempotency @@ -92,7 +93,7 @@ async def __call__(self, event: DomainEvent) -> None: def idempotent_handler( idempotency_manager: IdempotencyManager, logger: logging.Logger, - key_strategy: str = "event_based", + key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, custom_key_func: Callable[[DomainEvent], str] | None = None, fields: Set[str] | None = None, ttl_seconds: int | None = None, @@ -127,7 +128,7 @@ def __init__( idempotency_manager: IdempotencyManager, dispatcher: EventDispatcher, logger: logging.Logger, - default_key_strategy: str = "event_based", + default_key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, default_ttl_seconds: int = 3600, enable_for_all_handlers: bool = True, ): @@ -171,7 +172,7 @@ def subscribe_idempotent_handler( self, event_type: str, handler: Callable[[DomainEvent], Awaitable[None]], - key_strategy: str | None = None, + key_strategy: KeyStrategy | None = None, custom_key_func: Callable[[DomainEvent], str] | None = None, fields: Set[str] | None = None, ttl_seconds: int | None = None, diff --git a/backend/app/services/result_processor/processor.py b/backend/app/services/result_processor/processor.py index c7d5f1c7..a71a4e4b 100644 --- a/backend/app/services/result_processor/processor.py +++ b/backend/app/services/result_processor/processor.py @@ -100,7 +100,7 @@ async def handle_execution_completed(self, event: ExecutionCompletedEvent) -> No stdout=event.stdout, stderr=event.stderr, resource_usage=event.resource_usage, - metadata=event.metadata.model_dump(), + metadata=event.metadata, ) try: @@ -127,7 +127,7 @@ async def handle_execution_failed(self, event: ExecutionFailedEvent) -> None: stdout=event.stdout, stderr=event.stderr, resource_usage=event.resource_usage, - metadata=event.metadata.model_dump(), + metadata=event.metadata, error_type=event.error_type, ) @@ -158,7 +158,7 @@ async def handle_execution_timeout(self, event: ExecutionTimeoutEvent) -> None: stdout=event.stdout, stderr=event.stderr, resource_usage=event.resource_usage, - metadata=event.metadata.model_dump(), + metadata=event.metadata, error_type=ExecutionErrorType.TIMEOUT, ) diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py index 8916af97..5a233fbd 100644 --- a/backend/tests/unit/services/pod_monitor/test_monitor.py +++ b/backend/tests/unit/services/pod_monitor/test_monitor.py @@ -12,7 +12,7 @@ from app.core.metrics import EventMetrics, KubernetesMetrics from app.db.repositories.pod_state_repository import PodStateRepository from app.domain.events.typed import DomainEvent, 
EventMetadata, ExecutionCompletedEvent -from app.domain.execution.models import ResourceUsageDomain +from app.domain.events.typed import ResourceUsageDomain from app.events.core import UnifiedProducer from app.services.pod_monitor.config import PodMonitorConfig from app.services.pod_monitor.event_mapper import PodEventMapper diff --git a/backend/tests/unit/services/sse/test_sse_service.py b/backend/tests/unit/services/sse/test_sse_service.py index c33298ce..310907c6 100644 --- a/backend/tests/unit/services/sse/test_sse_service.py +++ b/backend/tests/unit/services/sse/test_sse_service.py @@ -10,7 +10,8 @@ from app.db.repositories.sse_repository import SSERepository from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus -from app.domain.execution import DomainExecution, ResourceUsageDomain +from app.domain.events.typed import ResourceUsageDomain +from app.domain.execution import DomainExecution from app.domain.sse import ShutdownStatus, SSEExecutionStatusDomain from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge from app.services.sse.redis_bus import SSERedisBus, SSERedisSubscription From e12a9735984e13a32f31861a5e65d9ac153a197d Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Tue, 27 Jan 2026 21:24:50 +0100 Subject: [PATCH 3/3] other fixes --- backend/app/api/routes/dlq.py | 3 +- backend/app/api/routes/execution.py | 80 +-- backend/app/core/providers.py | 66 +- backend/app/db/docs/dlq.py | 4 +- backend/app/db/repositories/__init__.py | 27 +- backend/app/db/repositories/dlq_repository.py | 3 +- .../execution_queue_repository.py | 234 -------- .../execution_state_repository.py | 65 -- backend/app/db/repositories/redis/__init__.py | 10 + .../redis/idempotency_repository.py | 36 ++ .../{ => redis}/pod_state_repository.py | 17 - .../redis/user_limit_repository.py | 25 + .../db/repositories/resource_repository.py | 300 ---------- backend/app/dlq/manager.py | 80 +-- backend/app/dlq/models.py | 15 +- backend/app/domain/enums/__init__.py | 3 +- backend/app/domain/enums/execution.py | 12 + backend/app/domain/enums/kafka.py | 1 + backend/app/domain/events/typed.py | 16 +- backend/app/domain/idempotency/__init__.py | 13 - backend/app/domain/idempotency/models.py | 43 -- backend/app/events/core/consumer.py | 48 +- backend/app/events/core/dlq_handler.py | 5 +- backend/app/events/core/producer.py | 10 +- backend/app/schemas_pydantic/dlq.py | 9 +- .../app/services/coordinator/coordinator.py | 269 ++------- backend/app/services/execution_service.py | 8 +- backend/app/services/idempotency/__init__.py | 21 - .../idempotency/idempotency_manager.py | 331 ---------- .../app/services/idempotency/middleware.py | 252 -------- .../services/idempotency/redis_repository.py | 141 ----- backend/app/services/k8s_worker/worker.py | 9 +- backend/app/services/kafka_event_service.py | 1 + backend/app/services/notification_service.py | 13 - .../app/services/pod_monitor/event_mapper.py | 2 +- backend/app/services/pod_monitor/monitor.py | 2 +- .../services/result_processor/processor.py | 2 + backend/app/settings.py | 3 + backend/tests/conftest.py | 5 +- .../db/repositories/test_dlq_repository.py | 9 +- backend/tests/e2e/dlq/test_dlq_discard.py | 2 +- backend/tests/e2e/dlq/test_dlq_manager.py | 2 +- backend/tests/e2e/dlq/test_dlq_retry.py | 2 +- backend/tests/e2e/events/test_dlq_handler.py | 19 +- .../e2e/events/test_producer_roundtrip.py | 2 +- .../e2e/events/test_schema_registry_real.py | 2 +- backend/tests/e2e/idempotency/__init__.py | 0 
.../idempotency/test_consumer_idempotent.py | 102 ---- .../idempotency/test_decorator_idempotent.py | 53 -- .../tests/e2e/idempotency/test_idempotency.py | 564 ------------------ .../idempotency/test_idempotent_handler.py | 63 -- .../coordinator/test_execution_coordinator.py | 3 +- .../idempotency/test_redis_repository.py | 170 ------ .../tests/e2e/services/sse/test_redis_bus.py | 2 +- backend/tests/e2e/test_dlq_routes.py | 15 +- .../tests/e2e/test_k8s_worker_create_pod.py | 3 +- .../tests/unit/events/test_metadata_model.py | 5 +- .../events/test_schema_registry_manager.py | 2 +- .../unit/services/idempotency/__init__.py | 0 .../idempotency/test_idempotency_manager.py | 104 ---- .../services/idempotency/test_middleware.py | 122 ---- .../services/pod_monitor/test_event_mapper.py | 2 +- .../unit/services/pod_monitor/test_monitor.py | 8 +- .../services/saga/test_saga_step_and_base.py | 2 +- .../services/sse/test_kafka_redis_bridge.py | 2 +- .../tests/unit/services/test_pod_builder.py | 11 +- backend/workers/dlq_processor.py | 22 +- docs/architecture/idempotency.md | 142 ++--- docs/architecture/services-overview.md | 8 +- 69 files changed, 384 insertions(+), 3243 deletions(-) delete mode 100644 backend/app/db/repositories/execution_queue_repository.py delete mode 100644 backend/app/db/repositories/execution_state_repository.py create mode 100644 backend/app/db/repositories/redis/__init__.py create mode 100644 backend/app/db/repositories/redis/idempotency_repository.py rename backend/app/db/repositories/{ => redis}/pod_state_repository.py (91%) create mode 100644 backend/app/db/repositories/redis/user_limit_repository.py delete mode 100644 backend/app/db/repositories/resource_repository.py delete mode 100644 backend/app/domain/idempotency/__init__.py delete mode 100644 backend/app/domain/idempotency/models.py delete mode 100644 backend/app/services/idempotency/__init__.py delete mode 100644 backend/app/services/idempotency/idempotency_manager.py delete mode 100644 backend/app/services/idempotency/middleware.py delete mode 100644 backend/app/services/idempotency/redis_repository.py delete mode 100644 backend/tests/e2e/idempotency/__init__.py delete mode 100644 backend/tests/e2e/idempotency/test_consumer_idempotent.py delete mode 100644 backend/tests/e2e/idempotency/test_decorator_idempotent.py delete mode 100644 backend/tests/e2e/idempotency/test_idempotency.py delete mode 100644 backend/tests/e2e/idempotency/test_idempotent_handler.py delete mode 100644 backend/tests/e2e/services/idempotency/test_redis_repository.py delete mode 100644 backend/tests/unit/services/idempotency/__init__.py delete mode 100644 backend/tests/unit/services/idempotency/test_idempotency_manager.py delete mode 100644 backend/tests/unit/services/idempotency/test_middleware.py diff --git a/backend/app/api/routes/dlq.py b/backend/app/api/routes/dlq.py index bac0c1fa..061092d1 100644 --- a/backend/app/api/routes/dlq.py +++ b/backend/app/api/routes/dlq.py @@ -8,6 +8,7 @@ from app.dlq.manager import DLQManager from app.dlq.models import DLQMessageStatus from app.domain.enums.events import EventType +from app.domain.enums.kafka import KafkaTopic from app.schemas_pydantic.dlq import ( DLQBatchRetryResponse, DLQMessageDetail, @@ -35,7 +36,7 @@ async def get_dlq_statistics(repository: FromDishka[DLQRepository]) -> DLQStats: async def get_dlq_messages( repository: FromDishka[DLQRepository], status: DLQMessageStatus | None = Query(None), - topic: str | None = None, + topic: KafkaTopic | None = None, event_type: EventType | None = 
Query(None), limit: int = Query(50, ge=1, le=1000), offset: int = Query(0, ge=0), diff --git a/backend/app/api/routes/execution.py b/backend/app/api/routes/execution.py index 9dd2b6f6..cc05b106 100644 --- a/backend/app/api/routes/execution.py +++ b/backend/app/api/routes/execution.py @@ -1,6 +1,5 @@ -from datetime import datetime, timezone +from datetime import datetime from typing import Annotated -from uuid import uuid4 from dishka import FromDishka from dishka.integrations.fastapi import DishkaRoute, inject @@ -9,12 +8,12 @@ from app.api.dependencies import admin_user, current_user from app.core.tracing import EventAttributes, add_span_attributes from app.core.utils import get_client_ip +from app.db.repositories.redis.idempotency_repository import IdempotencyRepository from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus from app.domain.enums.user import UserRole -from app.domain.events.typed import BaseEvent, DomainEvent, EventMetadata +from app.domain.events.typed import DomainEvent, EventMetadata from app.domain.exceptions import DomainError -from app.domain.idempotency import KeyStrategy from app.schemas_pydantic.execution import ( CancelExecutionRequest, CancelResponse, @@ -31,7 +30,6 @@ from app.schemas_pydantic.user import UserResponse from app.services.event_service import EventService from app.services.execution_service import ExecutionService -from app.services.idempotency import IdempotencyManager from app.services.kafka_event_service import KafkaEventService from app.settings import Settings @@ -58,7 +56,7 @@ async def create_execution( current_user: Annotated[UserResponse, Depends(current_user)], execution: ExecutionRequest, execution_service: FromDishka[ExecutionService], - idempotency_manager: FromDishka[IdempotencyManager], + idempotency_repo: FromDishka[IdempotencyRepository], idempotency_key: Annotated[str | None, Header(alias="Idempotency-Key")] = None, ) -> ExecutionResponse: add_span_attributes( @@ -74,33 +72,14 @@ async def create_execution( ) # Handle idempotency if key provided - pseudo_event = None - if idempotency_key: - # Create a pseudo-event for idempotency tracking - pseudo_event = BaseEvent( - event_id=str(uuid4()), - event_type=EventType.EXECUTION_REQUESTED, - timestamp=datetime.now(timezone.utc), - metadata=EventMetadata( - user_id=current_user.user_id, correlation_id=str(uuid4()), service_name="api", service_version="1.0.0" - ), - ) - - # Check for duplicate request using custom key - idempotency_result = await idempotency_manager.check_and_reserve( - event=pseudo_event, - key_strategy=KeyStrategy.CUSTOM, - custom_key=f"http:{current_user.user_id}:{idempotency_key}", - ttl_seconds=86400, # 24 hours TTL for HTTP idempotency - ) + idem_key = f"exec:{current_user.user_id}:{idempotency_key}" if idempotency_key else None - if idempotency_result.is_duplicate: - cached_json = await idempotency_manager.get_cached_json( - event=pseudo_event, - key_strategy=KeyStrategy.CUSTOM, - custom_key=f"http:{current_user.user_id}:{idempotency_key}", - ) - return ExecutionResponse.model_validate_json(cached_json) + if idem_key: + is_new = await idempotency_repo.try_reserve(idem_key, ttl=86400) + if not is_new: + cached = await idempotency_repo.get_result(idem_key) + if cached: + return ExecutionResponse.model_validate_json(cached) try: client_ip = get_client_ip(request) @@ -114,37 +93,16 @@ async def create_execution( user_agent=user_agent, ) - # Store result for idempotency if key was provided - if idempotency_key and 
pseudo_event: - response_model = ExecutionResponse.model_validate(exec_result) - await idempotency_manager.mark_completed_with_json( - event=pseudo_event, - cached_json=response_model.model_dump_json(), - key_strategy=KeyStrategy.CUSTOM, - custom_key=f"http:{current_user.user_id}:{idempotency_key}", - ) - - return ExecutionResponse.model_validate(exec_result) - - except DomainError as e: - # Mark as failed for idempotency - if idempotency_key and pseudo_event: - await idempotency_manager.mark_failed( - event=pseudo_event, - error=str(e), - key_strategy=KeyStrategy.CUSTOM, - custom_key=f"http:{current_user.user_id}:{idempotency_key}", - ) + response = ExecutionResponse.model_validate(exec_result) + + if idem_key: + await idempotency_repo.store_result(idem_key, response.model_dump_json(), ttl=86400) + + return response + + except DomainError: raise except Exception as e: - # Mark as failed for idempotency - if idempotency_key and pseudo_event: - await idempotency_manager.mark_failed( - event=pseudo_event, - error=str(e), - key_strategy=KeyStrategy.CUSTOM, - custom_key=f"http:{current_user.user_id}:{idempotency_key}", - ) raise HTTPException(status_code=500, detail="Internal server error during script execution") from e diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index d2465123..2ba8c264 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -38,12 +38,11 @@ from app.db.repositories.admin.admin_settings_repository import AdminSettingsRepository from app.db.repositories.admin.admin_user_repository import AdminUserRepository from app.db.repositories.dlq_repository import DLQRepository -from app.db.repositories.execution_queue_repository import ExecutionQueueRepository -from app.db.repositories.execution_state_repository import ExecutionStateRepository -from app.db.repositories.pod_state_repository import PodStateRepository +from app.db.repositories.redis.idempotency_repository import IdempotencyRepository +from app.db.repositories.redis.pod_state_repository import PodStateRepository +from app.db.repositories.redis.user_limit_repository import UserLimitRepository from app.db.repositories.replay_repository import ReplayRepository from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository -from app.db.repositories.resource_repository import ResourceRepository from app.db.repositories.user_settings_repository import UserSettingsRepository from app.dlq.manager import DLQManager from app.domain.saga.models import SagaConfig @@ -58,9 +57,6 @@ from app.services.event_service import EventService from app.services.execution_service import ExecutionService from app.services.grafana_alert_processor import GrafanaAlertProcessor -from app.services.idempotency import IdempotencyConfig, IdempotencyManager -from app.services.idempotency.idempotency_manager import create_idempotency_manager -from app.services.idempotency.redis_repository import RedisIdempotencyRepository from app.services.k8s_worker.config import K8sWorkerConfig from app.services.k8s_worker.worker import KubernetesWorker from app.services.kafka_event_service import KafkaEventService @@ -134,34 +130,14 @@ class RedisRepositoryProvider(Provider): scope = Scope.APP @provide - def get_execution_state_repository( - self, redis_client: redis.Redis, logger: logging.Logger - ) -> ExecutionStateRepository: - return ExecutionStateRepository(redis_client, logger) - - @provide - def get_execution_queue_repository( - self, redis_client: redis.Redis, logger: 
logging.Logger, settings: Settings - ) -> ExecutionQueueRepository: - return ExecutionQueueRepository( - redis_client, - logger, - max_queue_size=10000, - max_executions_per_user=100, - ) + def get_idempotency_repository(self, redis_client: redis.Redis) -> IdempotencyRepository: + return IdempotencyRepository(redis_client) @provide - async def get_resource_repository( - self, redis_client: redis.Redis, logger: logging.Logger, settings: Settings - ) -> ResourceRepository: - repo = ResourceRepository( - redis_client, - logger, - total_cpu_cores=32.0, - total_memory_mb=65536, - ) - await repo.initialize() - return repo + def get_user_limit_repository( + self, redis_client: redis.Redis, settings: Settings + ) -> UserLimitRepository: + return UserLimitRepository(redis_client, max_per_user=settings.MAX_EXECUTIONS_PER_USER) @provide def get_pod_state_repository( @@ -262,22 +238,6 @@ def get_dlq_manager( dlq_metrics=dlq_metrics, ) - @provide - def get_idempotency_repository(self, redis_client: redis.Redis) -> RedisIdempotencyRepository: - return RedisIdempotencyRepository(redis_client, key_prefix="idempotency") - - @provide - async def get_idempotency_manager( - self, repo: RedisIdempotencyRepository, logger: logging.Logger, database_metrics: DatabaseMetrics - ) -> AsyncIterator[IdempotencyManager]: - manager = create_idempotency_manager( - repository=repo, config=IdempotencyConfig(), logger=logger, database_metrics=database_metrics - ) - await manager.initialize() - try: - yield manager - finally: - await manager.close() class EventProvider(Provider): @@ -622,9 +582,7 @@ def get_execution_coordinator( self, kafka_producer: UnifiedProducer, execution_repository: ExecutionRepository, - state_repo: ExecutionStateRepository, - queue_repo: ExecutionQueueRepository, - resource_repo: ResourceRepository, + user_limit_repo: UserLimitRepository, logger: logging.Logger, coordinator_metrics: CoordinatorMetrics, event_metrics: EventMetrics, @@ -632,9 +590,7 @@ def get_execution_coordinator( return ExecutionCoordinator( producer=kafka_producer, execution_repository=execution_repository, - state_repo=state_repo, - queue_repo=queue_repo, - resource_repo=resource_repo, + user_limit_repo=user_limit_repo, logger=logger, coordinator_metrics=coordinator_metrics, event_metrics=event_metrics, diff --git a/backend/app/db/docs/dlq.py b/backend/app/db/docs/dlq.py index 71e2f7a3..5aedcc73 100644 --- a/backend/app/db/docs/dlq.py +++ b/backend/app/db/docs/dlq.py @@ -5,6 +5,7 @@ from pymongo import ASCENDING, DESCENDING, IndexModel from app.dlq.models import DLQMessageStatus +from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent @@ -14,7 +15,7 @@ class DLQMessageDocument(Document): model_config = ConfigDict(from_attributes=True) event: DomainEvent # Discriminated union - contains event_id, event_type - original_topic: Indexed(str) = "" # type: ignore[valid-type] + original_topic: KafkaTopic error: str = "Unknown error" retry_count: Indexed(int) = 0 # type: ignore[valid-type] failed_at: Indexed(datetime) = Field(default_factory=lambda: datetime.now(timezone.utc)) # type: ignore[valid-type] @@ -37,6 +38,7 @@ class Settings: indexes = [ IndexModel([("event.event_id", ASCENDING)], unique=True, name="idx_dlq_event_id"), IndexModel([("event.event_type", ASCENDING)], name="idx_dlq_event_type"), + IndexModel([("original_topic", ASCENDING)], name="idx_dlq_original_topic"), IndexModel([("status", ASCENDING)], name="idx_dlq_status"), IndexModel([("failed_at", DESCENDING)], 
name="idx_dlq_failed_desc"), IndexModel([("created_at", ASCENDING)], name="idx_dlq_created_ttl", expireAfterSeconds=7 * 24 * 3600), diff --git a/backend/app/db/repositories/__init__.py b/backend/app/db/repositories/__init__.py index c5e0199c..07b45a7e 100644 --- a/backend/app/db/repositories/__init__.py +++ b/backend/app/db/repositories/__init__.py @@ -1,13 +1,14 @@ from app.db.repositories.admin.admin_settings_repository import AdminSettingsRepository from app.db.repositories.admin.admin_user_repository import AdminUserRepository +from app.db.repositories.dlq_repository import DLQRepository from app.db.repositories.event_repository import EventRepository -from app.db.repositories.execution_queue_repository import ExecutionQueueRepository, QueuePriority, QueueStats from app.db.repositories.execution_repository import ExecutionRepository -from app.db.repositories.execution_state_repository import ExecutionStateRepository from app.db.repositories.notification_repository import NotificationRepository -from app.db.repositories.pod_state_repository import PodStateRepository +from app.db.repositories.redis.idempotency_repository import IdempotencyRepository +from app.db.repositories.redis.pod_state_repository import PodState, PodStateRepository +from app.db.repositories.redis.user_limit_repository import UserLimitRepository from app.db.repositories.replay_repository import ReplayRepository -from app.db.repositories.resource_repository import ResourceAllocation, ResourceRepository, ResourceStats +from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository from app.db.repositories.saga_repository import SagaRepository from app.db.repositories.saved_script_repository import SavedScriptRepository from app.db.repositories.sse_repository import SSERepository @@ -15,23 +16,23 @@ from app.db.repositories.user_settings_repository import UserSettingsRepository __all__ = [ + # MongoDB repositories "AdminSettingsRepository", "AdminUserRepository", + "DLQRepository", "EventRepository", "ExecutionRepository", - "ExecutionQueueRepository", - "ExecutionStateRepository", "NotificationRepository", - "PodStateRepository", - "QueuePriority", - "QueueStats", "ReplayRepository", - "ResourceAllocation", - "ResourceRepository", - "ResourceStats", + "ResourceAllocationRepository", "SagaRepository", "SavedScriptRepository", "SSERepository", - "UserSettingsRepository", "UserRepository", + "UserSettingsRepository", + # Redis repositories + "IdempotencyRepository", + "PodState", + "PodStateRepository", + "UserLimitRepository", ] diff --git a/backend/app/db/repositories/dlq_repository.py b/backend/app/db/repositories/dlq_repository.py index 6390a7b2..52bf8e94 100644 --- a/backend/app/db/repositories/dlq_repository.py +++ b/backend/app/db/repositories/dlq_repository.py @@ -17,6 +17,7 @@ TopicStatistic, ) from app.domain.enums.events import EventType +from app.domain.enums.kafka import KafkaTopic class DLQRepository: @@ -88,7 +89,7 @@ async def get_dlq_stats(self) -> DLQStatistics: async def get_messages( self, status: DLQMessageStatus | None = None, - topic: str | None = None, + topic: KafkaTopic | None = None, event_type: EventType | None = None, limit: int = 50, offset: int = 0, diff --git a/backend/app/db/repositories/execution_queue_repository.py b/backend/app/db/repositories/execution_queue_repository.py deleted file mode 100644 index d24af7bf..00000000 --- a/backend/app/db/repositories/execution_queue_repository.py +++ /dev/null @@ -1,234 +0,0 @@ -"""Redis-backed execution queue 
repository. - -Replaces in-memory priority queue (QueueManager) with Redis sorted sets -for stateless, horizontally-scalable services. -""" - -from __future__ import annotations - -import json -import logging -import time -from dataclasses import dataclass -from enum import IntEnum - -import redis.asyncio as redis - - -class QueuePriority(IntEnum): - """Execution queue priorities. Lower value = higher priority.""" - - CRITICAL = 0 - HIGH = 1 - NORMAL = 5 - LOW = 8 - BACKGROUND = 10 - - -@dataclass -class QueueStats: - """Queue statistics.""" - - total_size: int - priority_distribution: dict[str, int] - max_queue_size: int - utilization_percent: float - - -class ExecutionQueueRepository: - """Redis-backed priority queue for executions. - - Uses Redis sorted sets for O(log N) priority queue operations. - Stores event data in hash maps for retrieval. - """ - - QUEUE_KEY = "exec:queue" - DATA_KEY_PREFIX = "exec:queue:data" - USER_COUNT_KEY = "exec:queue:user_count" - - def __init__( - self, - redis_client: redis.Redis, - logger: logging.Logger, - max_queue_size: int = 10000, - max_executions_per_user: int = 100, - stale_timeout_seconds: int = 3600, - ) -> None: - self._redis = redis_client - self._logger = logger - self.max_queue_size = max_queue_size - self.max_executions_per_user = max_executions_per_user - self.stale_timeout_seconds = stale_timeout_seconds - - async def enqueue( - self, - execution_id: str, - event_data: dict[str, object], - priority: QueuePriority, - user_id: str, - ) -> tuple[bool, int | None, str | None]: - """Add execution to queue. Returns (success, position, error).""" - # Check queue size - queue_size = await self._redis.zcard(self.QUEUE_KEY) - if queue_size >= self.max_queue_size: - return False, None, "Queue is full" - - # Check user limit - user_count = await self._redis.hincrby(self.USER_COUNT_KEY, user_id, 0) # type: ignore[misc] - if user_count >= self.max_executions_per_user: - return False, None, f"User execution limit exceeded ({self.max_executions_per_user})" - - # Score: priority * 1e12 + timestamp (lower = higher priority, earlier = higher priority) - timestamp = time.time() - score = priority.value * 1e12 + timestamp - - # Use pipeline for atomicity - pipe = self._redis.pipeline() - - # Add to sorted set - pipe.zadd(self.QUEUE_KEY, {execution_id: score}) - - # Store event data - data_key = f"{self.DATA_KEY_PREFIX}:{execution_id}" - event_data["_enqueue_timestamp"] = timestamp - event_data["_priority"] = priority.value - event_data["_user_id"] = user_id - pipe.hset(data_key, mapping={k: json.dumps(v) if not isinstance(v, str) else v for k, v in event_data.items()}) - pipe.expire(data_key, self.stale_timeout_seconds + 60) - - # Increment user count - pipe.hincrby(self.USER_COUNT_KEY, user_id, 1) - - await pipe.execute() - - # Get position - position = await self._redis.zrank(self.QUEUE_KEY, execution_id) - - self._logger.info( - f"Enqueued execution {execution_id}. Priority: {priority.name}, " - f"Position: {position}, Queue size: {queue_size + 1}" - ) - - return True, position, None - - async def dequeue(self) -> tuple[str, dict[str, object | float | str]] | None: - """Remove and return highest priority execution. 
Returns (execution_id, event_data) or None.""" - while True: - # Pop the lowest score (highest priority) - result = await self._redis.zpopmin(self.QUEUE_KEY, count=1) - if not result: - return None - - execution_id = result[0][0] - if isinstance(execution_id, bytes): - execution_id = execution_id.decode() - - # Get event data - data_key = f"{self.DATA_KEY_PREFIX}:{execution_id}" - raw_data = await self._redis.hgetall(data_key) # type: ignore[misc] - - if not raw_data: - # Data expired or missing, skip this entry - self._logger.warning(f"Queue entry {execution_id} has no data, skipping") - continue - - # Parse data - event_data: dict[str, object | float | str] = {} - for k, v in raw_data.items(): - key = k.decode() if isinstance(k, bytes) else k - val = v.decode() if isinstance(v, bytes) else v - try: - event_data[key] = json.loads(val) - except (json.JSONDecodeError, TypeError): - event_data[key] = val - - # Check if stale - enqueue_time_val = event_data.pop("_enqueue_timestamp", 0) - enqueue_time = float(enqueue_time_val) if isinstance(enqueue_time_val, (int, float, str)) else 0.0 - event_data.pop("_priority", None) - user_id_val = event_data.pop("_user_id", "anonymous") - user_id = str(user_id_val) - - age = time.time() - enqueue_time - if age > self.stale_timeout_seconds: - # Stale, clean up and continue - await self._redis.delete(data_key) - await self._redis.hincrby(self.USER_COUNT_KEY, user_id, -1) # type: ignore[misc] - self._logger.info(f"Skipped stale execution {execution_id} (age: {age:.2f}s)") - continue - - # Clean up - await self._redis.delete(data_key) - await self._redis.hincrby(self.USER_COUNT_KEY, user_id, -1) # type: ignore[misc] - - self._logger.info(f"Dequeued execution {execution_id}. Wait time: {age:.2f}s") - return execution_id, event_data - - async def remove(self, execution_id: str) -> bool: - """Remove specific execution from queue. 
Returns True if removed.""" - # Get user_id before removing - data_key = f"{self.DATA_KEY_PREFIX}:{execution_id}" - raw_data = await self._redis.hgetall(data_key) # type: ignore[misc] - - removed = await self._redis.zrem(self.QUEUE_KEY, execution_id) - if removed: - # Decrement user count - if raw_data: - user_id_raw = raw_data.get(b"_user_id") or raw_data.get("_user_id") - if user_id_raw: - user_id = user_id_raw.decode() if isinstance(user_id_raw, bytes) else user_id_raw - try: - user_id = json.loads(user_id) - except (json.JSONDecodeError, TypeError): - pass - await self._redis.hincrby(self.USER_COUNT_KEY, str(user_id), -1) # type: ignore[misc] - - await self._redis.delete(data_key) - self._logger.info(f"Removed execution {execution_id} from queue") - return True - return False - - async def get_position(self, execution_id: str) -> int | None: - """Get queue position of execution (0-indexed).""" - result = await self._redis.zrank(self.QUEUE_KEY, execution_id) - return int(result) if result is not None else None - - async def get_stats(self) -> QueueStats: - """Get queue statistics.""" - total_size = await self._redis.zcard(self.QUEUE_KEY) - - # Count by priority (sample first 1000) - priority_counts: dict[str, int] = {} - entries = await self._redis.zrange(self.QUEUE_KEY, 0, 999, withscores=True) - for _, score in entries: - priority_value = int(score // 1e12) - try: - priority_name = QueuePriority(priority_value).name - except ValueError: - priority_name = "UNKNOWN" - priority_counts[priority_name] = priority_counts.get(priority_name, 0) + 1 - - return QueueStats( - total_size=total_size, - priority_distribution=priority_counts, - max_queue_size=self.max_queue_size, - utilization_percent=(total_size / self.max_queue_size) * 100 if self.max_queue_size > 0 else 0, - ) - - async def cleanup_stale(self) -> int: - """Remove stale entries. Returns count removed. Call periodically.""" - removed = 0 - threshold_score = QueuePriority.BACKGROUND.value * 1e12 + (time.time() - self.stale_timeout_seconds) - - # Get entries older than threshold - stale_entries = await self._redis.zrangebyscore(self.QUEUE_KEY, "-inf", threshold_score, start=0, num=100) - - for entry in stale_entries: - execution_id = entry.decode() if isinstance(entry, bytes) else entry - if await self.remove(execution_id): - removed += 1 - - if removed: - self._logger.info(f"Cleaned {removed} stale executions from queue") - - return removed diff --git a/backend/app/db/repositories/execution_state_repository.py b/backend/app/db/repositories/execution_state_repository.py deleted file mode 100644 index e343ff02..00000000 --- a/backend/app/db/repositories/execution_state_repository.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Redis-backed execution state tracking repository. - -Replaces in-memory state tracking (_active_executions sets) with Redis -for stateless, horizontally-scalable services. -""" - -from __future__ import annotations - -import logging - -import redis.asyncio as redis - - -class ExecutionStateRepository: - """Redis-backed execution state tracking. - - Provides atomic claim/release operations for executions, - replacing in-memory sets like `_active_executions`. - """ - - KEY_PREFIX = "exec:active" - - def __init__(self, redis_client: redis.Redis, logger: logging.Logger) -> None: - self._redis = redis_client - self._logger = logger - - async def try_claim(self, execution_id: str, ttl_seconds: int = 3600) -> bool: - """Atomically claim an execution. Returns True if claimed, False if already claimed. 
- - Uses Redis SETNX for atomic check-and-set. - TTL ensures cleanup if service crashes without releasing. - """ - key = f"{self.KEY_PREFIX}:{execution_id}" - result = await self._redis.set(key, "1", nx=True, ex=ttl_seconds) - if result: - self._logger.debug(f"Claimed execution {execution_id}") - return result is not None - - async def is_active(self, execution_id: str) -> bool: - """Check if an execution is currently active/claimed.""" - key = f"{self.KEY_PREFIX}:{execution_id}" - result = await self._redis.exists(key) - return bool(result) - - async def remove(self, execution_id: str) -> bool: - """Release/remove an execution claim. Returns True if was claimed.""" - key = f"{self.KEY_PREFIX}:{execution_id}" - deleted = await self._redis.delete(key) - if deleted: - self._logger.debug(f"Released execution {execution_id}") - return bool(deleted > 0) - - async def get_active_count(self) -> int: - """Get count of active executions. For metrics only.""" - pattern = f"{self.KEY_PREFIX}:*" - count = 0 - async for _ in self._redis.scan_iter(match=pattern, count=100): - count += 1 - return count - - async def extend_ttl(self, execution_id: str, ttl_seconds: int = 3600) -> bool: - """Extend the TTL of an active execution. Returns True if extended.""" - key = f"{self.KEY_PREFIX}:{execution_id}" - result = await self._redis.expire(key, ttl_seconds) - return bool(result) diff --git a/backend/app/db/repositories/redis/__init__.py b/backend/app/db/repositories/redis/__init__.py new file mode 100644 index 00000000..2038e67a --- /dev/null +++ b/backend/app/db/repositories/redis/__init__.py @@ -0,0 +1,10 @@ +from app.db.repositories.redis.idempotency_repository import IdempotencyRepository +from app.db.repositories.redis.pod_state_repository import PodState, PodStateRepository +from app.db.repositories.redis.user_limit_repository import UserLimitRepository + +__all__ = [ + "IdempotencyRepository", + "PodState", + "PodStateRepository", + "UserLimitRepository", +] diff --git a/backend/app/db/repositories/redis/idempotency_repository.py b/backend/app/db/repositories/redis/idempotency_repository.py new file mode 100644 index 00000000..64700175 --- /dev/null +++ b/backend/app/db/repositories/redis/idempotency_repository.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import redis.asyncio as redis + + +class IdempotencyRepository: + """Simple idempotency using Redis SET NX. + + Pattern: + 1. try_reserve(key) - returns True if new, False if duplicate + 2. If duplicate and need cached result - get_result(key) + 3. After processing - store_result(key, json) + """ + + KEY_PREFIX = "idempotent" + + def __init__(self, redis_client: redis.Redis, default_ttl: int = 86400) -> None: + self._redis = redis_client + self._default_ttl = default_ttl + + async def try_reserve(self, key: str, ttl: int | None = None) -> bool: + """Reserve key atomically. 
Returns True if new (should process), False if duplicate.""" + full_key = f"{self.KEY_PREFIX}:{key}" + result = await self._redis.set(full_key, "1", nx=True, ex=ttl or self._default_ttl) + return result is not None + + async def store_result(self, key: str, result_json: str, ttl: int | None = None) -> None: + """Store result for duplicate requests to retrieve.""" + full_key = f"{self.KEY_PREFIX}:{key}" + await self._redis.set(full_key, result_json, ex=ttl or self._default_ttl) + + async def get_result(self, key: str) -> str | None: + """Get cached result if exists.""" + full_key = f"{self.KEY_PREFIX}:{key}" + result = await self._redis.get(full_key) + return str(result) if result is not None else None diff --git a/backend/app/db/repositories/pod_state_repository.py b/backend/app/db/repositories/redis/pod_state_repository.py similarity index 91% rename from backend/app/db/repositories/pod_state_repository.py rename to backend/app/db/repositories/redis/pod_state_repository.py index 0e652720..5f80359b 100644 --- a/backend/app/db/repositories/pod_state_repository.py +++ b/backend/app/db/repositories/redis/pod_state_repository.py @@ -1,12 +1,5 @@ -"""Redis-backed pod state tracking repository. - -Replaces in-memory pod state tracking (_tracked_pods, _active_creations) -for stateless, horizontally-scalable services. -""" - from __future__ import annotations -import json import logging from dataclasses import dataclass from datetime import datetime, timezone @@ -23,7 +16,6 @@ class PodState: status: str created_at: datetime updated_at: datetime - metadata: dict[str, object] | None = None class PodStateRepository: @@ -79,7 +71,6 @@ async def track_pod( pod_name: str, execution_id: str, status: str, - metadata: dict[str, object] | None = None, ttl_seconds: int = 7200, ) -> None: """Track a pod's state.""" @@ -92,7 +83,6 @@ async def track_pod( "status": status, "created_at": now, "updated_at": now, - "metadata": json.dumps(metadata) if metadata else "{}", } await self._redis.hset(key, mapping=data) # type: ignore[misc] @@ -129,19 +119,12 @@ def get_str(k: str) -> str: val = data.get(k.encode(), data.get(k, "")) return val.decode() if isinstance(val, bytes) else str(val) - metadata_str = get_str("metadata") - try: - metadata = json.loads(metadata_str) if metadata_str else None - except json.JSONDecodeError: - metadata = None - return PodState( pod_name=get_str("pod_name"), execution_id=get_str("execution_id"), status=get_str("status"), created_at=datetime.fromisoformat(get_str("created_at")), updated_at=datetime.fromisoformat(get_str("updated_at")), - metadata=metadata, ) async def is_pod_tracked(self, pod_name: str) -> bool: diff --git a/backend/app/db/repositories/redis/user_limit_repository.py b/backend/app/db/repositories/redis/user_limit_repository.py new file mode 100644 index 00000000..b6ce09d9 --- /dev/null +++ b/backend/app/db/repositories/redis/user_limit_repository.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import redis.asyncio as redis + + +class UserLimitRepository: + """Simple per-user execution counter.""" + + KEY = "exec:user_count" + + def __init__(self, redis_client: redis.Redis, max_per_user: int = 100) -> None: + self._redis = redis_client + self._max_per_user = max_per_user + + async def try_increment(self, user_id: str) -> bool: + """Increment user count. 
Returns True if under limit, False if limit exceeded.""" + count = await self._redis.hincrby(self.KEY, user_id, 1) # type: ignore[misc] + if count > self._max_per_user: + await self._redis.hincrby(self.KEY, user_id, -1) # type: ignore[misc] + return False + return True + + async def decrement(self, user_id: str) -> None: + """Decrement user count.""" + await self._redis.hincrby(self.KEY, user_id, -1) # type: ignore[misc] diff --git a/backend/app/db/repositories/resource_repository.py b/backend/app/db/repositories/resource_repository.py deleted file mode 100644 index 1f6b54b0..00000000 --- a/backend/app/db/repositories/resource_repository.py +++ /dev/null @@ -1,300 +0,0 @@ -"""Redis-backed resource allocation repository. - -Replaces in-memory resource tracking (ResourceManager) with Redis -for stateless, horizontally-scalable services. -""" - -from __future__ import annotations - -import logging -from dataclasses import dataclass - -import redis.asyncio as redis - - -@dataclass -class ResourceAllocation: - """Resource allocation for an execution.""" - - execution_id: str - cpu_cores: float - memory_mb: int - gpu_count: int = 0 - - @property - def cpu_millicores(self) -> int: - """Get CPU in millicores for Kubernetes.""" - return int(self.cpu_cores * 1000) - - @property - def memory_bytes(self) -> int: - """Get memory in bytes.""" - return self.memory_mb * 1024 * 1024 - - -@dataclass -class ResourceStats: - """Resource statistics.""" - - total_cpu: float - total_memory_mb: int - total_gpu: int - available_cpu: float - available_memory_mb: int - available_gpu: int - allocation_count: int - - -class ResourceRepository: - """Redis-backed resource allocation tracking. - - Uses Redis for atomic resource allocation with Lua scripts. - Replaces in-memory ResourceManager._allocations dict. 
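# Editor's note: a short sketch of how the new UserLimitRepository is meant to be
# used for per-user rate limiting: increment before doing work, always decrement
# afterwards. The reworked ExecutionCoordinator later in this patch increments on
# ExecutionRequested and decrements when completion/failure/cancellation events
# arrive; the try/finally below is a local stand-in for that event-driven release.
# The connection URL and the placeholder work are not part of the patch.
import asyncio

import redis.asyncio as redis

from app.db.repositories.redis.user_limit_repository import UserLimitRepository


async def submit(repo: UserLimitRepository, user_id: str) -> bool:
    if not await repo.try_increment(user_id):
        return False                      # over the per-user limit, reject
    try:
        await asyncio.sleep(0)            # placeholder for the real work
        return True
    finally:
        await repo.decrement(user_id)     # always release the slot


async def main() -> None:
    client = redis.from_url("redis://localhost:6379/0", decode_responses=True)
    repo = UserLimitRepository(client, max_per_user=2)
    print(await submit(repo, "user-1"))


if __name__ == "__main__":
    asyncio.run(main())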
- """ - - POOL_KEY = "resource:pool" - ALLOC_KEY_PREFIX = "resource:alloc" - - # Default allocations by language - DEFAULT_ALLOCATIONS = { - "python": (0.5, 512), - "javascript": (0.5, 512), - "go": (0.25, 256), - "rust": (0.5, 512), - "java": (1.0, 1024), - "cpp": (0.5, 512), - "r": (1.0, 2048), - } - - def __init__( - self, - redis_client: redis.Redis, - logger: logging.Logger, - total_cpu_cores: float = 32.0, - total_memory_mb: int = 65536, - total_gpu_count: int = 0, - overcommit_factor: float = 1.2, - max_cpu_per_execution: float = 4.0, - max_memory_per_execution_mb: int = 8192, - min_reserve_cpu: float = 2.0, - min_reserve_memory_mb: int = 4096, - ) -> None: - self._redis = redis_client - self._logger = logger - - # Apply overcommit - self._total_cpu = total_cpu_cores * overcommit_factor - self._total_memory = int(total_memory_mb * overcommit_factor) - self._total_gpu = total_gpu_count - - self._max_cpu_per_exec = max_cpu_per_execution - self._max_memory_per_exec = max_memory_per_execution_mb - - # Adjust reserves for small pools (max 10% of total) - self._min_reserve_cpu = min(min_reserve_cpu, 0.1 * self._total_cpu) - self._min_reserve_memory = min(min_reserve_memory_mb, int(0.1 * self._total_memory)) - - async def initialize(self) -> None: - """Initialize the resource pool if not exists.""" - exists = await self._redis.exists(self.POOL_KEY) - if not exists: - await self._redis.hset( # type: ignore[misc] - self.POOL_KEY, - mapping={ - "total_cpu": str(self._total_cpu), - "total_memory": str(self._total_memory), - "total_gpu": str(self._total_gpu), - "available_cpu": str(self._total_cpu), - "available_memory": str(self._total_memory), - "available_gpu": str(self._total_gpu), - }, - ) - self._logger.info( - f"Initialized resource pool: {self._total_cpu} CPU, " - f"{self._total_memory}MB RAM, {self._total_gpu} GPU" - ) - - async def allocate( - self, - execution_id: str, - language: str, - requested_cpu: float | None = None, - requested_memory_mb: int | None = None, - requested_gpu: int = 0, - ) -> ResourceAllocation | None: - """Allocate resources for execution. 
Returns allocation or None if insufficient.""" - # Check if already allocated - alloc_key = f"{self.ALLOC_KEY_PREFIX}:{execution_id}" - existing = await self._redis.hgetall(alloc_key) # type: ignore[misc] - if existing: - self._logger.warning(f"Execution {execution_id} already has allocation") - return ResourceAllocation( - execution_id=execution_id, - cpu_cores=float(existing.get(b"cpu", existing.get("cpu", 0))), - memory_mb=int(existing.get(b"memory", existing.get("memory", 0))), - gpu_count=int(existing.get(b"gpu", existing.get("gpu", 0))), - ) - - # Determine requested resources - if requested_cpu is None or requested_memory_mb is None: - default_cpu, default_memory = self.DEFAULT_ALLOCATIONS.get(language, (0.5, 512)) - requested_cpu = requested_cpu or default_cpu - requested_memory_mb = requested_memory_mb or default_memory - - # Apply limits - requested_cpu = min(requested_cpu, self._max_cpu_per_exec) - requested_memory_mb = min(requested_memory_mb, self._max_memory_per_exec) - - # Atomic allocation using Lua script - lua_script = """ - local pool_key = KEYS[1] - local alloc_key = KEYS[2] - local req_cpu = tonumber(ARGV[1]) - local req_memory = tonumber(ARGV[2]) - local req_gpu = tonumber(ARGV[3]) - local min_cpu = tonumber(ARGV[4]) - local min_memory = tonumber(ARGV[5]) - - local avail_cpu = tonumber(redis.call('HGET', pool_key, 'available_cpu') or '0') - local avail_memory = tonumber(redis.call('HGET', pool_key, 'available_memory') or '0') - local avail_gpu = tonumber(redis.call('HGET', pool_key, 'available_gpu') or '0') - - local cpu_after = avail_cpu - req_cpu - local memory_after = avail_memory - req_memory - local gpu_after = avail_gpu - req_gpu - - if cpu_after < min_cpu or memory_after < min_memory or gpu_after < 0 then - return 0 - end - - redis.call('HSET', pool_key, 'available_cpu', tostring(cpu_after)) - redis.call('HSET', pool_key, 'available_memory', tostring(memory_after)) - redis.call('HSET', pool_key, 'available_gpu', tostring(gpu_after)) - - redis.call('HSET', alloc_key, 'cpu', tostring(req_cpu), 'memory', tostring(req_memory), - 'gpu', tostring(req_gpu)) - redis.call('EXPIRE', alloc_key, 7200) - - return 1 - """ - - result = await self._redis.eval( # type: ignore[misc] - lua_script, - 2, - self.POOL_KEY, - alloc_key, - str(requested_cpu), - str(requested_memory_mb), - str(requested_gpu), - str(self._min_reserve_cpu), - str(self._min_reserve_memory), - ) - - if not result: - pool = await self._redis.hgetall(self.POOL_KEY) # type: ignore[misc] - avail_cpu = float(pool.get(b"available_cpu", pool.get("available_cpu", 0))) - avail_memory = int(float(pool.get(b"available_memory", pool.get("available_memory", 0)))) - self._logger.warning( - f"Insufficient resources for {execution_id}. " - f"Requested: {requested_cpu} CPU, {requested_memory_mb}MB. " - f"Available: {avail_cpu} CPU, {avail_memory}MB" - ) - return None - - self._logger.info( - f"Allocated resources for {execution_id}: " - f"{requested_cpu} CPU, {requested_memory_mb}MB RAM, {requested_gpu} GPU" - ) - - return ResourceAllocation( - execution_id=execution_id, - cpu_cores=requested_cpu, - memory_mb=requested_memory_mb, - gpu_count=requested_gpu, - ) - - async def release(self, execution_id: str) -> bool: - """Release resource allocation. 
Returns True if released.""" - alloc_key = f"{self.ALLOC_KEY_PREFIX}:{execution_id}" - - # Get current allocation - alloc = await self._redis.hgetall(alloc_key) # type: ignore[misc] - if not alloc: - self._logger.warning(f"No allocation found for {execution_id}") - return False - - cpu = float(alloc.get(b"cpu", alloc.get("cpu", 0))) - memory = int(float(alloc.get(b"memory", alloc.get("memory", 0)))) - gpu = int(alloc.get(b"gpu", alloc.get("gpu", 0))) - - # Release atomically - pipe = self._redis.pipeline() - pipe.hincrbyfloat(self.POOL_KEY, "available_cpu", cpu) - pipe.hincrbyfloat(self.POOL_KEY, "available_memory", memory) - pipe.hincrby(self.POOL_KEY, "available_gpu", gpu) - pipe.delete(alloc_key) - await pipe.execute() - - self._logger.info(f"Released resources for {execution_id}: {cpu} CPU, {memory}MB RAM, {gpu} GPU") - return True - - async def get_allocation(self, execution_id: str) -> ResourceAllocation | None: - """Get current allocation for execution.""" - alloc_key = f"{self.ALLOC_KEY_PREFIX}:{execution_id}" - alloc = await self._redis.hgetall(alloc_key) # type: ignore[misc] - if not alloc: - return None - - return ResourceAllocation( - execution_id=execution_id, - cpu_cores=float(alloc.get(b"cpu", alloc.get("cpu", 0))), - memory_mb=int(float(alloc.get(b"memory", alloc.get("memory", 0)))), - gpu_count=int(alloc.get(b"gpu", alloc.get("gpu", 0))), - ) - - async def get_stats(self) -> ResourceStats: - """Get resource statistics.""" - pool = await self._redis.hgetall(self.POOL_KEY) # type: ignore[misc] - - # Decode bytes if needed - def get_val(key: str, default: str = "0") -> str: - return str(pool.get(key.encode(), pool.get(key, default))) - - total_cpu = float(get_val("total_cpu")) - total_memory = int(float(get_val("total_memory"))) - total_gpu = int(get_val("total_gpu")) - available_cpu = float(get_val("available_cpu")) - available_memory = int(float(get_val("available_memory"))) - available_gpu = int(get_val("available_gpu")) - - # Count allocations - count = 0 - async for _ in self._redis.scan_iter(match=f"{self.ALLOC_KEY_PREFIX}:*", count=100): - count += 1 - - return ResourceStats( - total_cpu=total_cpu, - total_memory_mb=total_memory, - total_gpu=total_gpu, - available_cpu=available_cpu, - available_memory_mb=available_memory, - available_gpu=available_gpu, - allocation_count=count, - ) - - async def can_allocate(self, cpu_cores: float, memory_mb: int, gpu_count: int = 0) -> bool: - """Check if resources can be allocated.""" - pool = await self._redis.hgetall(self.POOL_KEY) # type: ignore[misc] - - def get_val(key: str) -> float: - return float(pool.get(key.encode(), pool.get(key, 0))) - - available_cpu = get_val("available_cpu") - available_memory = get_val("available_memory") - available_gpu = get_val("available_gpu") - - return ( - (available_cpu - cpu_cores) >= self._min_reserve_cpu - and (available_memory - memory_mb) >= self._min_reserve_memory - and (available_gpu - gpu_count) >= 0 - ) diff --git a/backend/app/dlq/manager.py b/backend/app/dlq/manager.py index c1f5472b..4125851e 100644 --- a/backend/app/dlq/manager.py +++ b/backend/app/dlq/manager.py @@ -1,9 +1,3 @@ -"""DLQ Manager - stateless event handler. - -Manages Dead Letter Queue messages. Receives events, -processes them, and handles retries. No lifecycle management. 
-""" - from __future__ import annotations import asyncio @@ -63,14 +57,14 @@ def __init__( self._dlq_topic = dlq_topic self._retry_topic_suffix = retry_topic_suffix self._default_retry_policy = RetryPolicy( - topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF + topic=KafkaTopic.DEAD_LETTER_QUEUE, strategy=RetryStrategy.EXPONENTIAL_BACKOFF ) - self._retry_policies: dict[str, RetryPolicy] = {} + self._retry_policies: dict[KafkaTopic, RetryPolicy] = {} self._filters: list[object] = [] self._dlq_events_topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.DLQ_EVENTS}" - self._event_metadata = EventMetadata(service_name="dlq-manager", service_version="1.0.0") + self._event_metadata = EventMetadata(service_name="dlq-manager", service_version="1.0.0", user_id="system") - def set_retry_policy(self, topic: str, policy: RetryPolicy) -> None: + def set_retry_policy(self, topic: KafkaTopic, policy: RetryPolicy) -> None: """Set retry policy for a specific topic.""" self._retry_policies[topic] = policy @@ -89,34 +83,30 @@ async def handle_dlq_message(self, raw_message: bytes, headers: dict[str, str]) """ start = asyncio.get_running_loop().time() - try: - data = json.loads(raw_message) - dlq_msg = DLQMessage(**data, headers=headers) - - self._metrics.record_dlq_message_received(dlq_msg.original_topic, dlq_msg.event.event_type) - self._metrics.record_dlq_message_age( - (datetime.now(timezone.utc) - dlq_msg.failed_at).total_seconds() - ) + data = json.loads(raw_message) + dlq_msg = DLQMessage(**data, headers=headers) - ctx = extract_trace_context(dlq_msg.headers) - with get_tracer().start_as_current_span( - name="dlq.consume", - context=ctx, - kind=SpanKind.CONSUMER, - attributes={ - EventAttributes.KAFKA_TOPIC: str(self._dlq_topic), - EventAttributes.EVENT_TYPE: dlq_msg.event.event_type, - EventAttributes.EVENT_ID: dlq_msg.event.event_id, - }, - ): - await self._process_dlq_message(dlq_msg) - - self._metrics.record_dlq_processing_duration( - asyncio.get_running_loop().time() - start, "process" - ) + self._metrics.record_dlq_message_received(dlq_msg.original_topic, dlq_msg.event.event_type) + self._metrics.record_dlq_message_age( + (datetime.now(timezone.utc) - dlq_msg.failed_at).total_seconds() + ) - except Exception as e: - self._logger.error(f"Error processing DLQ message: {e}") + ctx = extract_trace_context(dlq_msg.headers) + with get_tracer().start_as_current_span( + name="dlq.consume", + context=ctx, + kind=SpanKind.CONSUMER, + attributes={ + EventAttributes.KAFKA_TOPIC: str(self._dlq_topic), + EventAttributes.EVENT_TYPE: dlq_msg.event.event_type, + EventAttributes.EVENT_ID: dlq_msg.event.event_id, + }, + ): + await self._process_dlq_message(dlq_msg) + + self._metrics.record_dlq_processing_duration( + asyncio.get_running_loop().time() - start, "process" + ) async def _process_dlq_message(self, message: DLQMessage) -> None: """Process a DLQ message.""" @@ -163,21 +153,9 @@ async def _update_message_status(self, event_id: str, update: DLQMessageUpdate) if not doc: return - update_dict: dict[str, object] = {"status": update.status, "last_updated": datetime.now(timezone.utc)} - if update.next_retry_at is not None: - update_dict["next_retry_at"] = update.next_retry_at - if update.retried_at is not None: - update_dict["retried_at"] = update.retried_at - if update.discarded_at is not None: - update_dict["discarded_at"] = update.discarded_at - if update.retry_count is not None: - update_dict["retry_count"] = update.retry_count - if update.discard_reason is not None: - update_dict["discard_reason"] = 
update.discard_reason - if update.last_error is not None: - update_dict["last_error"] = update.last_error - - await doc.set(update_dict) + updates = {k: v for k, v in vars(update).items() if v is not None} + updates["last_updated"] = datetime.now(timezone.utc) + await doc.set(updates) async def _retry_message(self, message: DLQMessage) -> None: """Retry a DLQ message.""" diff --git a/backend/app/dlq/models.py b/backend/app/dlq/models.py index 66961243..e98be1f8 100644 --- a/backend/app/dlq/models.py +++ b/backend/app/dlq/models.py @@ -1,11 +1,11 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import Any from pydantic import BaseModel, ConfigDict, Field from app.core.utils import StringEnum from app.domain.enums.events import EventType +from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent @@ -34,7 +34,7 @@ class DLQMessage(BaseModel): model_config = ConfigDict(from_attributes=True) event: DomainEvent # Discriminated union - auto-validates from dict - original_topic: str = "" + original_topic: KafkaTopic error: str = "Unknown error" retry_count: int = 0 failed_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) @@ -63,7 +63,6 @@ class DLQMessageUpdate: retry_count: int | None = None discard_reason: str | None = None last_error: str | None = None - extra: dict[str, Any] = field(default_factory=dict) @dataclass @@ -71,7 +70,7 @@ class DLQMessageFilter: """Filter criteria for querying DLQ messages.""" status: DLQMessageStatus | None = None - topic: str | None = None + topic: KafkaTopic | None = None event_type: EventType | None = None @@ -79,7 +78,7 @@ class DLQMessageFilter: class RetryPolicy: """Retry policy configuration for DLQ messages.""" - topic: str + topic: KafkaTopic strategy: RetryStrategy max_retries: int = 5 base_delay_seconds: float = 60.0 @@ -121,7 +120,7 @@ class TopicStatistic(BaseModel): model_config = ConfigDict(from_attributes=True) - topic: str + topic: KafkaTopic count: int avg_retry_count: float @@ -191,7 +190,7 @@ class DLQTopicSummary(BaseModel): model_config = ConfigDict(from_attributes=True) - topic: str + topic: KafkaTopic total_messages: int status_breakdown: dict[str, int] oldest_message: datetime diff --git a/backend/app/domain/enums/__init__.py b/backend/app/domain/enums/__init__.py index f37aac67..31458a8e 100644 --- a/backend/app/domain/enums/__init__.py +++ b/backend/app/domain/enums/__init__.py @@ -1,5 +1,5 @@ from app.domain.enums.common import ErrorType, SortOrder, Theme -from app.domain.enums.execution import ExecutionStatus +from app.domain.enums.execution import ExecutionStatus, QueuePriority from app.domain.enums.health import AlertSeverity, AlertStatus, ComponentStatus from app.domain.enums.notification import ( NotificationChannel, @@ -17,6 +17,7 @@ "Theme", # Execution "ExecutionStatus", + "QueuePriority", # Health "AlertSeverity", "AlertStatus", diff --git a/backend/app/domain/enums/execution.py b/backend/app/domain/enums/execution.py index abb4809d..f04db8d0 100644 --- a/backend/app/domain/enums/execution.py +++ b/backend/app/domain/enums/execution.py @@ -1,3 +1,5 @@ +from enum import IntEnum + from app.core.utils import StringEnum @@ -12,3 +14,13 @@ class ExecutionStatus(StringEnum): TIMEOUT = "timeout" CANCELLED = "cancelled" ERROR = "error" + + +class QueuePriority(IntEnum): + """Execution queue priorities. 
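# Editor's note: a tiny illustration of the QueuePriority IntEnum convention
# introduced below ("lower value = higher priority"). Because the members are
# plain ints, sorting ascending yields most-urgent-first, which is what any
# consumer ordering work by these values relies on.
from app.domain.enums.execution import QueuePriority

pending = [QueuePriority.LOW, QueuePriority.CRITICAL, QueuePriority.NORMAL]
assert sorted(pending)[0] is QueuePriority.CRITICAL   # value 0 sorts first
assert QueuePriority.HIGH < QueuePriority.NORMAL      # 1 < 5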
Lower value = higher priority.""" + + CRITICAL = 0 + HIGH = 1 + NORMAL = 5 + LOW = 8 + BACKGROUND = 10 diff --git a/backend/app/domain/enums/kafka.py b/backend/app/domain/enums/kafka.py index ca48de6f..25e81d63 100644 --- a/backend/app/domain/enums/kafka.py +++ b/backend/app/domain/enums/kafka.py @@ -34,6 +34,7 @@ class KafkaTopic(StringEnum): # Resource topics RESOURCE_EVENTS = "resource_events" + RESOURCE_ALLOCATION = "resource_allocation" # Notification topics NOTIFICATION_EVENTS = "notification_events" diff --git a/backend/app/domain/events/typed.py b/backend/app/domain/events/typed.py index fca2a11d..b2da7781 100644 --- a/backend/app/domain/events/typed.py +++ b/backend/app/domain/events/typed.py @@ -8,6 +8,8 @@ from app.domain.enums.auth import LoginMethod from app.domain.enums.common import Environment from app.domain.enums.events import EventType +from app.domain.enums.execution import QueuePriority +from app.domain.enums.kafka import KafkaTopic from app.domain.enums.notification import NotificationChannel, NotificationSeverity from app.domain.enums.storage import ExecutionErrorType, StorageType @@ -30,8 +32,8 @@ class EventMetadata(AvroBase): service_name: str service_version: str + user_id: str correlation_id: str = Field(default_factory=lambda: str(uuid4())) - user_id: str | None = None ip_address: str | None = None user_agent: str | None = None environment: Environment = Environment.PRODUCTION @@ -75,7 +77,7 @@ class ExecutionRequestedEvent(BaseEvent): memory_limit: str cpu_request: str memory_request: str - priority: int = 5 + priority: QueuePriority = QueuePriority.NORMAL class ExecutionAcceptedEvent(BaseEvent): @@ -83,7 +85,7 @@ class ExecutionAcceptedEvent(BaseEvent): execution_id: str queue_position: int estimated_wait_seconds: float | None = None - priority: int = 5 + priority: QueuePriority = QueuePriority.NORMAL class ExecutionQueuedEvent(BaseEvent): @@ -428,7 +430,7 @@ class CreatePodCommandEvent(BaseEvent): memory_limit: str cpu_request: str memory_request: str - priority: int = 5 + priority: QueuePriority = QueuePriority.NORMAL class DeletePodCommandEvent(BaseEvent): @@ -558,7 +560,7 @@ class DLQMessageReceivedEvent(BaseEvent): event_type: Literal[EventType.DLQ_MESSAGE_RECEIVED] = EventType.DLQ_MESSAGE_RECEIVED dlq_event_id: str # The event_id of the failed message - original_topic: str + original_topic: KafkaTopic original_event_type: str error: str retry_count: int @@ -571,7 +573,7 @@ class DLQMessageRetriedEvent(BaseEvent): event_type: Literal[EventType.DLQ_MESSAGE_RETRIED] = EventType.DLQ_MESSAGE_RETRIED dlq_event_id: str # The event_id of the retried message - original_topic: str + original_topic: KafkaTopic original_event_type: str retry_count: int # New retry count after this retry retry_topic: str # Topic the message was retried to @@ -582,7 +584,7 @@ class DLQMessageDiscardedEvent(BaseEvent): event_type: Literal[EventType.DLQ_MESSAGE_DISCARDED] = EventType.DLQ_MESSAGE_DISCARDED dlq_event_id: str # The event_id of the discarded message - original_topic: str + original_topic: KafkaTopic original_event_type: str reason: str retry_count: int # Final retry count when discarded diff --git a/backend/app/domain/idempotency/__init__.py b/backend/app/domain/idempotency/__init__.py deleted file mode 100644 index 8529b2de..00000000 --- a/backend/app/domain/idempotency/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .models import ( - IdempotencyRecord, - IdempotencyStats, - IdempotencyStatus, - KeyStrategy, -) - -__all__ = [ - "IdempotencyStatus", - 
"IdempotencyRecord", - "IdempotencyStats", - "KeyStrategy", -] diff --git a/backend/app/domain/idempotency/models.py b/backend/app/domain/idempotency/models.py deleted file mode 100644 index 6d4eca45..00000000 --- a/backend/app/domain/idempotency/models.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -from datetime import datetime - -from pydantic.dataclasses import dataclass - -from app.core.utils import StringEnum - - -class IdempotencyStatus(StringEnum): - PROCESSING = "processing" - COMPLETED = "completed" - FAILED = "failed" - EXPIRED = "expired" - - -class KeyStrategy(StringEnum): - """Strategy for generating idempotency keys.""" - - EVENT_BASED = "event_based" - CONTENT_HASH = "content_hash" - CUSTOM = "custom" - - -@dataclass -class IdempotencyRecord: - key: str - status: IdempotencyStatus - event_type: str - event_id: str - created_at: datetime - ttl_seconds: int - completed_at: datetime | None = None - processing_duration_ms: int | None = None - error: str | None = None - result_json: str | None = None - - -@dataclass -class IdempotencyStats: - total_keys: int - status_counts: dict[IdempotencyStatus, int] - prefix: str diff --git a/backend/app/events/core/consumer.py b/backend/app/events/core/consumer.py index 3365af98..62079874 100644 --- a/backend/app/events/core/consumer.py +++ b/backend/app/events/core/consumer.py @@ -1,10 +1,3 @@ -"""Unified Kafka consumer - pure message handler. - -Handles deserialization, dispatch, and metrics for Kafka messages. -No lifecycle, no properties, no state - just handle(). -Worker gets AIOKafkaConsumer directly from DI. -""" - from __future__ import annotations import logging @@ -25,12 +18,12 @@ class UnifiedConsumer: """Pure message handler - deserialize, dispatch, record metrics.""" def __init__( - self, - event_dispatcher: EventDispatcher, - schema_registry: SchemaRegistryManager, - logger: logging.Logger, - event_metrics: EventMetrics, - group_id: str, + self, + event_dispatcher: EventDispatcher, + schema_registry: SchemaRegistryManager, + logger: logging.Logger, + event_metrics: EventMetrics, + group_id: str, ) -> None: self._dispatcher = event_dispatcher self._schema_registry = schema_registry @@ -47,23 +40,18 @@ async def handle(self, msg: ConsumerRecord) -> DomainEvent | None: headers = {k: v.decode() for k, v in msg.headers} with get_tracer().start_as_current_span( - name="kafka.consume", - context=extract_trace_context(headers), - kind=SpanKind.CONSUMER, - attributes={ - EventAttributes.KAFKA_TOPIC: msg.topic, - EventAttributes.KAFKA_PARTITION: msg.partition, - EventAttributes.KAFKA_OFFSET: msg.offset, - EventAttributes.EVENT_TYPE: event.event_type, - EventAttributes.EVENT_ID: event.event_id, - }, + name="kafka.consume", + context=extract_trace_context(headers), + kind=SpanKind.CONSUMER, + attributes={ + EventAttributes.KAFKA_TOPIC: msg.topic, + EventAttributes.KAFKA_PARTITION: msg.partition, + EventAttributes.KAFKA_OFFSET: msg.offset, + EventAttributes.EVENT_TYPE: event.event_type, + EventAttributes.EVENT_ID: event.event_id, + }, ): - try: - await self._dispatcher.dispatch(event) - self._event_metrics.record_kafka_message_consumed(msg.topic, self._group_id) - except Exception as e: - self._logger.error(f"Dispatch error: {event.event_type}: {e}") - self._event_metrics.record_kafka_consumption_error(msg.topic, self._group_id, type(e).__name__) - raise + await self._dispatcher.dispatch(event) + self._event_metrics.record_kafka_message_consumed(msg.topic, self._group_id) return event diff --git 
a/backend/app/events/core/dlq_handler.py b/backend/app/events/core/dlq_handler.py index 7de433a7..f92fae68 100644 --- a/backend/app/events/core/dlq_handler.py +++ b/backend/app/events/core/dlq_handler.py @@ -1,13 +1,14 @@ import logging from typing import Awaitable, Callable +from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent from .producer import UnifiedProducer def create_dlq_error_handler( - producer: UnifiedProducer, original_topic: str, logger: logging.Logger, max_retries: int = 3 + producer: UnifiedProducer, original_topic: KafkaTopic, logger: logging.Logger, max_retries: int = 3 ) -> Callable[[Exception, DomainEvent], Awaitable[None]]: """Create an error handler that sends failed events to DLQ after max retries.""" retry_counts: dict[str, int] = {} @@ -26,7 +27,7 @@ async def handle_error_with_dlq(error: Exception, event: DomainEvent) -> None: def create_immediate_dlq_handler( - producer: UnifiedProducer, original_topic: str, logger: logging.Logger + producer: UnifiedProducer, original_topic: KafkaTopic, logger: logging.Logger ) -> Callable[[Exception, DomainEvent], Awaitable[None]]: """Create an error handler that immediately sends failed events to DLQ.""" diff --git a/backend/app/events/core/producer.py b/backend/app/events/core/producer.py index 98ab0b74..30fffa08 100644 --- a/backend/app/events/core/producer.py +++ b/backend/app/events/core/producer.py @@ -1,9 +1,3 @@ -"""Unified Kafka producer - thin wrapper over AIOKafkaProducer. - -The producer receives a ready-to-use AIOKafkaProducer from DI. -No lifecycle management - DI provider handles start/stop. -""" - from __future__ import annotations import asyncio @@ -69,8 +63,6 @@ async def produce( # Update metrics on success self._metrics.messages_sent += 1 self._metrics.bytes_sent += len(serialized_value) - - # Record Kafka metrics self._event_metrics.record_kafka_message_produced(topic) self._logger.debug(f"Message [{event_to_produce}] sent to topic: {topic}") @@ -84,7 +76,7 @@ async def produce( raise async def send_to_dlq( - self, original_event: DomainEvent, original_topic: str, error: Exception, retry_count: int = 0 + self, original_event: DomainEvent, original_topic: KafkaTopic, error: Exception, retry_count: int = 0 ) -> None: """Send a failed event to the Dead Letter Queue.""" try: diff --git a/backend/app/schemas_pydantic/dlq.py b/backend/app/schemas_pydantic/dlq.py index 4093d03f..6b9cbc9b 100644 --- a/backend/app/schemas_pydantic/dlq.py +++ b/backend/app/schemas_pydantic/dlq.py @@ -10,6 +10,7 @@ RetryStrategy, TopicStatistic, ) +from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent @@ -31,7 +32,7 @@ class DLQMessageResponse(BaseModel): model_config = ConfigDict(from_attributes=True) event: DomainEvent - original_topic: str + original_topic: KafkaTopic error: str retry_count: int failed_at: datetime @@ -46,7 +47,7 @@ class DLQMessageResponse(BaseModel): class RetryPolicyRequest(BaseModel): """Request model for setting a retry policy.""" - topic: str + topic: KafkaTopic strategy: RetryStrategy max_retries: int = 5 base_delay_seconds: float = 60.0 @@ -87,7 +88,7 @@ class DLQTopicSummaryResponse(BaseModel): model_config = ConfigDict(from_attributes=True) - topic: str + topic: KafkaTopic total_messages: int status_breakdown: dict[str, int] oldest_message: datetime @@ -102,7 +103,7 @@ class DLQMessageDetail(BaseModel): model_config = ConfigDict(from_attributes=True) event: DomainEvent - original_topic: str + original_topic: KafkaTopic error: 
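# Editor's note: a minimal sketch of building one of the now KafkaTopic-typed DLQ
# error handlers from app/events/core/dlq_handler.py. The topic choice is
# illustrative, and how a given worker attaches the returned callback to its
# consumer loop is outside this hunk; per its docstring, the handler sends a
# failed event to the DLQ once the same event has failed max_retries times.
import logging

from app.domain.enums.kafka import KafkaTopic
from app.events.core.dlq_handler import create_dlq_error_handler


def build_on_error(producer):  # producer: UnifiedProducer resolved from DI
    logger = logging.getLogger("worker")
    return create_dlq_error_handler(producer, KafkaTopic.NOTIFICATION_EVENTS, logger, max_retries=3)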
str retry_count: int failed_at: datetime diff --git a/backend/app/services/coordinator/coordinator.py b/backend/app/services/coordinator/coordinator.py index b8c9e6e5..c4feb8bc 100644 --- a/backend/app/services/coordinator/coordinator.py +++ b/backend/app/services/coordinator/coordinator.py @@ -1,10 +1,3 @@ -"""Execution Coordinator - stateless event handler. - -Coordinates execution scheduling across the system. Receives events, -processes them, and publishes results. No lifecycle management. -All state is stored in Redis repositories. -""" - from __future__ import annotations import logging @@ -12,17 +5,11 @@ from uuid import uuid4 from app.core.metrics import CoordinatorMetrics, EventMetrics -from app.db.repositories import ( - ExecutionQueueRepository, - ExecutionStateRepository, - QueuePriority, - ResourceRepository, -) from app.db.repositories.execution_repository import ExecutionRepository +from app.db.repositories.redis.user_limit_repository import UserLimitRepository from app.domain.enums.storage import ExecutionErrorType from app.domain.events.typed import ( CreatePodCommandEvent, - EventMetadata, ExecutionAcceptedEvent, ExecutionCancelledEvent, ExecutionCompletedEvent, @@ -33,228 +20,73 @@ class ExecutionCoordinator: - """Stateless execution coordinator - pure event handler. + """Stateless execution coordinator - fire and forget to k8s. - No lifecycle methods (start/stop) - receives ready-to-use dependencies from DI. - All state (active executions, queue, resources) stored in Redis. + Only tracks per-user execution count for rate limiting. + K8s handles resource allocation and scheduling. """ def __init__( self, producer: UnifiedProducer, execution_repository: ExecutionRepository, - state_repo: ExecutionStateRepository, - queue_repo: ExecutionQueueRepository, - resource_repo: ResourceRepository, + user_limit_repo: UserLimitRepository, logger: logging.Logger, coordinator_metrics: CoordinatorMetrics, event_metrics: EventMetrics, ) -> None: self._producer = producer self._execution_repository = execution_repository - self._state_repo = state_repo - self._queue_repo = queue_repo - self._resource_repo = resource_repo + self._user_limit = user_limit_repo self._logger = logger self._metrics = coordinator_metrics self._event_metrics = event_metrics async def handle_execution_requested(self, event: ExecutionRequestedEvent) -> None: - """Handle execution requested event - add to queue and try to schedule.""" + """Handle execution request - check limit, fire to k8s.""" self._logger.info(f"Handling ExecutionRequestedEvent: {event.execution_id}") start_time = time.time() + user_id = event.metadata.user_id try: - priority = QueuePriority(event.priority) - user_id = event.metadata.user_id or "anonymous" - - # Add to Redis queue - success, position, error = await self._queue_repo.enqueue( - execution_id=event.execution_id, - event_data=event.model_dump(mode="json"), - priority=priority, - user_id=user_id, - ) - - if not success: - await self._publish_queue_full(event, error or "Queue is full") - self._metrics.record_coordinator_execution_scheduled("queue_full") + if not await self._user_limit.try_increment(user_id): + await self._publish_limit_exceeded(event) + self._metrics.record_coordinator_execution_scheduled("limit_exceeded") return - # Publish ExecutionAcceptedEvent - await self._publish_execution_accepted(event, position or 0, event.priority) + await self._publish_execution_accepted(event) + await self._publish_create_pod_command(event) - # Track metrics duration = time.time() - 
start_time self._metrics.record_coordinator_scheduling_duration(duration) - self._metrics.record_coordinator_execution_scheduled("queued") - - self._logger.info(f"Execution {event.execution_id} added to queue at position {position}") - - # If at front of queue (position 0), try to schedule immediately - if position == 0: - await self._try_schedule_next() + self._metrics.record_coordinator_execution_scheduled("scheduled") except Exception as e: + await self._user_limit.decrement(user_id) self._logger.error(f"Failed to handle execution request {event.execution_id}: {e}", exc_info=True) self._metrics.record_coordinator_execution_scheduled("error") async def handle_execution_completed(self, event: ExecutionCompletedEvent) -> None: - """Handle execution completed - release resources and try to schedule next.""" - execution_id = event.execution_id - self._logger.info(f"Handling ExecutionCompletedEvent: {execution_id}") - - # Release resources - await self._resource_repo.release(execution_id) - - # Remove from active state - await self._state_repo.remove(execution_id) - - # Update metrics - count = await self._state_repo.get_active_count() - self._metrics.update_coordinator_active_executions(count) - - self._logger.info(f"Execution {execution_id} completed, resources released") - - # Try to schedule next execution from queue - await self._try_schedule_next() + """Handle execution completed - decrement user counter.""" + self._logger.info(f"Handling ExecutionCompletedEvent: {event.execution_id}") + if event.metadata.user_id: + await self._user_limit.decrement(event.metadata.user_id) async def handle_execution_failed(self, event: ExecutionFailedEvent) -> None: - """Handle execution failed - release resources and try to schedule next.""" - execution_id = event.execution_id - self._logger.info(f"Handling ExecutionFailedEvent: {execution_id}") - - # Release resources - await self._resource_repo.release(execution_id) - - # Remove from active state - await self._state_repo.remove(execution_id) - - # Update metrics - count = await self._state_repo.get_active_count() - self._metrics.update_coordinator_active_executions(count) - - # Try to schedule next execution from queue - await self._try_schedule_next() + """Handle execution failed - decrement user counter.""" + self._logger.info(f"Handling ExecutionFailedEvent: {event.execution_id}") + if event.metadata.user_id: + await self._user_limit.decrement(event.metadata.user_id) async def handle_execution_cancelled(self, event: ExecutionCancelledEvent) -> None: - """Handle execution cancelled - remove from queue and release resources.""" - execution_id = event.execution_id - self._logger.info(f"Handling ExecutionCancelledEvent: {execution_id}") - - # Remove from queue if present - await self._queue_repo.remove(execution_id) - - # Release resources if allocated - await self._resource_repo.release(execution_id) - - # Remove from active state - await self._state_repo.remove(execution_id) - - # Update metrics - count = await self._state_repo.get_active_count() - self._metrics.update_coordinator_active_executions(count) - - async def _try_schedule_next(self) -> None: - """Try to schedule the next execution from the queue.""" - result = await self._queue_repo.dequeue() - if not result: - return - - execution_id, event_data = result - - # Reconstruct event from stored data - try: - event = ExecutionRequestedEvent.model_validate(event_data) - await self._schedule_execution(event) - except Exception as e: - self._logger.error(f"Failed to schedule execution 
{execution_id}: {e}", exc_info=True) - - async def _schedule_execution(self, event: ExecutionRequestedEvent) -> None: - """Schedule a single execution - allocate resources and publish command.""" - start_time = time.time() - execution_id = event.execution_id - - # Try to claim this execution atomically - claimed = await self._state_repo.try_claim(execution_id) - if not claimed: - self._logger.debug(f"Execution {execution_id} already claimed, skipping") - return - - try: - # Allocate resources - allocation = await self._resource_repo.allocate( - execution_id=execution_id, - language=event.language, - requested_cpu=None, - requested_memory_mb=None, - requested_gpu=0, - ) - - if not allocation: - # No resources available, release claim and requeue - await self._state_repo.remove(execution_id) - await self._queue_repo.enqueue( - execution_id=event.execution_id, - event_data=event.model_dump(mode="json"), - priority=QueuePriority(event.priority), - user_id=event.metadata.user_id or "anonymous", - ) - self._logger.info(f"No resources available for {execution_id}, requeued") - return - - # Update metrics - count = await self._state_repo.get_active_count() - self._metrics.update_coordinator_active_executions(count) - - # Publish CreatePodCommand - await self._publish_execution_started(event) - - # Track metrics - queue_time = start_time - event.timestamp.timestamp() - priority = QueuePriority(event.priority) - self._metrics.record_coordinator_queue_time(queue_time, priority.name) - - scheduling_duration = time.time() - start_time - self._metrics.record_coordinator_scheduling_duration(scheduling_duration) - self._metrics.record_coordinator_execution_scheduled("scheduled") - - self._logger.info( - f"Scheduled execution {event.execution_id}. " - f"Queue time: {queue_time:.2f}s, " - f"Resources: {allocation.cpu_cores} CPU, {allocation.memory_mb}MB RAM" - ) - - except Exception as e: - self._logger.error(f"Failed to schedule execution {event.execution_id}: {e}", exc_info=True) - - # Release resources and claim - await self._resource_repo.release(execution_id) - await self._state_repo.remove(execution_id) - - count = await self._state_repo.get_active_count() - self._metrics.update_coordinator_active_executions(count) - self._metrics.record_coordinator_execution_scheduled("error") - - # Publish failure event - await self._publish_scheduling_failed(event, str(e)) - - async def _build_command_metadata(self, request: ExecutionRequestedEvent) -> EventMetadata: - """Build metadata for CreatePodCommandEvent with guaranteed user_id.""" - exec_rec = await self._execution_repository.get_execution(request.execution_id) - user_id: str = exec_rec.user_id if exec_rec and exec_rec.user_id else "system" - - return EventMetadata( - service_name="execution-coordinator", - service_version="1.0.0", - user_id=user_id, - correlation_id=request.metadata.correlation_id, - ) - - async def _publish_execution_started(self, request: ExecutionRequestedEvent) -> None: - """Send CreatePodCommandEvent to k8s-worker via SAGA_COMMANDS topic.""" - metadata = await self._build_command_metadata(request) - + """Handle execution cancelled - decrement user counter.""" + self._logger.info(f"Handling ExecutionCancelledEvent: {event.execution_id}") + exec_rec = await self._execution_repository.get_execution(event.execution_id) + if exec_rec and exec_rec.user_id: + await self._user_limit.decrement(exec_rec.user_id) + + async def _publish_create_pod_command(self, request: ExecutionRequestedEvent) -> None: + """Send CreatePodCommandEvent to 
k8s-worker.""" create_pod_cmd = CreatePodCommandEvent( saga_id=str(uuid4()), execution_id=request.execution_id, @@ -270,57 +102,34 @@ async def _publish_execution_started(self, request: ExecutionRequestedEvent) -> cpu_request=request.cpu_request, memory_request=request.memory_request, priority=request.priority, - metadata=metadata, + metadata=request.metadata, ) await self._producer.produce(event_to_produce=create_pod_cmd, key=request.execution_id) self._logger.info(f"Published CreatePodCommandEvent for {request.execution_id}") - async def _publish_execution_accepted( - self, request: ExecutionRequestedEvent, position: int, priority: int - ) -> None: + async def _publish_execution_accepted(self, request: ExecutionRequestedEvent) -> None: """Publish execution accepted event.""" event = ExecutionAcceptedEvent( execution_id=request.execution_id, - queue_position=position, + queue_position=0, estimated_wait_seconds=None, - priority=priority, + priority=request.priority, metadata=request.metadata, ) await self._producer.produce(event_to_produce=event) - self._logger.info(f"ExecutionAcceptedEvent published for {request.execution_id}") - - async def _publish_queue_full(self, request: ExecutionRequestedEvent, error: str) -> None: - """Publish queue full event.""" - queue_stats = await self._queue_repo.get_stats() + async def _publish_limit_exceeded(self, request: ExecutionRequestedEvent) -> None: + """Publish limit exceeded failure event.""" event = ExecutionFailedEvent( execution_id=request.execution_id, error_type=ExecutionErrorType.RESOURCE_LIMIT, exit_code=-1, - stderr=f"Queue full: {error}. Queue size: {queue_stats.total_size}", - resource_usage=None, - metadata=request.metadata, - error_message=error, - ) - - await self._producer.produce(event_to_produce=event, key=request.execution_id) - - async def _publish_scheduling_failed(self, request: ExecutionRequestedEvent, error: str) -> None: - """Publish scheduling failed event.""" - resource_stats = await self._resource_repo.get_stats() - - event = ExecutionFailedEvent( - execution_id=request.execution_id, - error_type=ExecutionErrorType.SYSTEM_ERROR, - exit_code=-1, - stderr=f"Failed to schedule execution: {error}. 
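# Editor's note: a sketch of how a worker might register the slimmed-down
# coordinator handlers with an EventDispatcher, using the register(EventType)(...)
# decorator style visible elsewhere in this patch. The specific EventType members
# (EXECUTION_REQUESTED, EXECUTION_COMPLETED, ...) are assumed and not shown in
# this excerpt; the actual wiring (e.g. in workers/run_coordinator.py) may differ.
from app.domain.enums.events import EventType
from app.events.core import EventDispatcher
from app.services.coordinator.coordinator import ExecutionCoordinator


def register_coordinator(dispatcher: EventDispatcher, coordinator: ExecutionCoordinator) -> None:
    dispatcher.register(EventType.EXECUTION_REQUESTED)(coordinator.handle_execution_requested)
    dispatcher.register(EventType.EXECUTION_COMPLETED)(coordinator.handle_execution_completed)
    dispatcher.register(EventType.EXECUTION_FAILED)(coordinator.handle_execution_failed)
    dispatcher.register(EventType.EXECUTION_CANCELLED)(coordinator.handle_execution_cancelled)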
" - f"Available resources: CPU={resource_stats.available_cpu}, " - f"Memory={resource_stats.available_memory_mb}MB", + stderr="User execution limit exceeded", resource_usage=None, metadata=request.metadata, - error_message=error, + error_message="User execution limit exceeded", ) await self._producer.produce(event_to_produce=event, key=request.execution_id) diff --git a/backend/app/services/execution_service.py b/backend/app/services/execution_service.py index 8798e9c3..22503cbd 100644 --- a/backend/app/services/execution_service.py +++ b/backend/app/services/execution_service.py @@ -8,7 +8,7 @@ from app.core.metrics import ExecutionMetrics from app.db.repositories.execution_repository import ExecutionRepository from app.domain.enums.events import EventType -from app.domain.enums.execution import ExecutionStatus +from app.domain.enums.execution import ExecutionStatus, QueuePriority from app.domain.events.typed import ( DomainEvent, EventMetadata, @@ -96,7 +96,7 @@ async def get_example_scripts(self) -> dict[str, str]: def _create_event_metadata( self, - user_id: str | None = None, + user_id: str, client_ip: str | None = None, user_agent: str | None = None, ) -> EventMetadata: @@ -131,7 +131,7 @@ async def execute_script( user_agent: str | None, lang: str = "python", lang_version: str = "3.11", - priority: int = 5, + priority: QueuePriority = QueuePriority.NORMAL, timeout_override: int | None = None, ) -> DomainExecution: """ @@ -457,7 +457,7 @@ async def _publish_deletion_event(self, execution_id: str) -> None: Args: execution_id: UUID of deleted execution. """ - metadata = self._create_event_metadata() + metadata = self._create_event_metadata("system") event = ExecutionCancelledEvent( execution_id=execution_id, reason="user_requested", cancelled_by=metadata.user_id, metadata=metadata diff --git a/backend/app/services/idempotency/__init__.py b/backend/app/services/idempotency/__init__.py deleted file mode 100644 index 82af12f0..00000000 --- a/backend/app/services/idempotency/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from app.domain.idempotency import IdempotencyStatus -from app.services.idempotency.idempotency_manager import ( - IdempotencyConfig, - IdempotencyKeyStrategy, - IdempotencyManager, - IdempotencyResult, - create_idempotency_manager, -) -from app.services.idempotency.middleware import IdempotentConsumerWrapper, IdempotentEventHandler, idempotent_handler - -__all__ = [ - "IdempotencyManager", - "IdempotencyConfig", - "IdempotencyResult", - "IdempotencyStatus", - "IdempotencyKeyStrategy", - "create_idempotency_manager", - "IdempotentEventHandler", - "idempotent_handler", - "IdempotentConsumerWrapper", -] diff --git a/backend/app/services/idempotency/idempotency_manager.py b/backend/app/services/idempotency/idempotency_manager.py deleted file mode 100644 index af952289..00000000 --- a/backend/app/services/idempotency/idempotency_manager.py +++ /dev/null @@ -1,331 +0,0 @@ -import asyncio -import hashlib -import json -import logging -from datetime import datetime, timedelta, timezone -from typing import Protocol - -from pydantic import BaseModel -from pymongo.errors import DuplicateKeyError - -from app.core.metrics import DatabaseMetrics -from app.domain.events.typed import BaseEvent -from app.domain.idempotency import IdempotencyRecord, IdempotencyStats, IdempotencyStatus, KeyStrategy - - -class IdempotencyResult(BaseModel): - is_duplicate: bool - status: IdempotencyStatus - created_at: datetime - completed_at: datetime | None = None - processing_duration_ms: int | None = None - 
error: str | None = None - has_cached_result: bool = False - key: str - - -class IdempotencyConfig(BaseModel): - key_prefix: str = "idempotency" - default_ttl_seconds: int = 3600 - processing_timeout_seconds: int = 300 - enable_result_caching: bool = True - max_result_size_bytes: int = 1048576 - enable_metrics: bool = True - collection_name: str = "idempotency_keys" - - -class IdempotencyKeyStrategy: - @staticmethod - def event_based(event: BaseEvent) -> str: - return f"{event.event_type}:{event.event_id}" - - @staticmethod - def content_hash(event: BaseEvent, fields: set[str] | None = None) -> str: - event_dict = event.model_dump(mode="json") - event_dict.pop("event_id", None) - event_dict.pop("timestamp", None) - event_dict.pop("metadata", None) - - if fields: - event_dict = {k: v for k, v in event_dict.items() if k in fields} - - content = json.dumps(event_dict, sort_keys=True) - return hashlib.sha256(content.encode()).hexdigest() - - @staticmethod - def custom(event: BaseEvent, custom_key: str) -> str: - return f"{event.event_type}:{custom_key}" - - -class IdempotencyRepoProtocol(Protocol): - async def find_by_key(self, key: str) -> IdempotencyRecord | None: ... - async def insert_processing(self, record: IdempotencyRecord) -> None: ... - async def update_record(self, record: IdempotencyRecord) -> int: ... - async def delete_key(self, key: str) -> int: ... - async def aggregate_status_counts(self, key_prefix: str) -> dict[str, int]: ... - async def health_check(self) -> None: ... - - -class IdempotencyManager: - def __init__( - self, - config: IdempotencyConfig, - repository: IdempotencyRepoProtocol, - logger: logging.Logger, - database_metrics: DatabaseMetrics, - ) -> None: - self.config = config - self.metrics = database_metrics - self._repo: IdempotencyRepoProtocol = repository - self._stats_update_task: asyncio.Task[None] | None = None - self.logger = logger - - async def initialize(self) -> None: - if self.config.enable_metrics and self._stats_update_task is None: - self._stats_update_task = asyncio.create_task(self._update_stats_loop()) - self.logger.info("Idempotency manager ready") - - async def close(self) -> None: - if self._stats_update_task: - self._stats_update_task.cancel() - try: - await self._stats_update_task - except asyncio.CancelledError: - pass - self.logger.info("Closed idempotency manager") - - def _generate_key( - self, event: BaseEvent, key_strategy: KeyStrategy, custom_key: str | None = None, fields: set[str] | None = None - ) -> str: - if key_strategy == KeyStrategy.EVENT_BASED: - key = IdempotencyKeyStrategy.event_based(event) - elif key_strategy == KeyStrategy.CONTENT_HASH: - key = IdempotencyKeyStrategy.content_hash(event, fields) - elif key_strategy == KeyStrategy.CUSTOM and custom_key: - key = IdempotencyKeyStrategy.custom(event, custom_key) - else: - raise ValueError(f"Invalid key strategy: {key_strategy}") - return f"{self.config.key_prefix}:{key}" - - async def check_and_reserve( - self, - event: BaseEvent, - key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, - custom_key: str | None = None, - ttl_seconds: int | None = None, - fields: set[str] | None = None, - ) -> IdempotencyResult: - full_key = self._generate_key(event, key_strategy, custom_key, fields) - ttl = ttl_seconds or self.config.default_ttl_seconds - - existing = await self._repo.find_by_key(full_key) - if existing: - self.metrics.record_idempotency_cache_hit(event.event_type, "check_and_reserve") - return await self._handle_existing_key(existing, full_key, event.event_type) - - 
self.metrics.record_idempotency_cache_miss(event.event_type, "check_and_reserve") - return await self._create_new_key(full_key, event, ttl) - - async def _handle_existing_key( - self, - existing: IdempotencyRecord, - full_key: str, - event_type: str, - ) -> IdempotencyResult: - status = existing.status - if status == IdempotencyStatus.PROCESSING: - return await self._handle_processing_key(existing, full_key, event_type) - - self.metrics.record_idempotency_duplicate_blocked(event_type) - created_at = existing.created_at or datetime.now(timezone.utc) - return IdempotencyResult( - is_duplicate=True, - status=status, - created_at=created_at, - completed_at=existing.completed_at, - processing_duration_ms=existing.processing_duration_ms, - error=existing.error, - has_cached_result=existing.result_json is not None, - key=full_key, - ) - - async def _handle_processing_key( - self, - existing: IdempotencyRecord, - full_key: str, - event_type: str, - ) -> IdempotencyResult: - created_at = existing.created_at - now = datetime.now(timezone.utc) - - if now - created_at > timedelta(seconds=self.config.processing_timeout_seconds): - self.logger.warning(f"Idempotency key {full_key} processing timeout, allowing retry") - existing.created_at = now - existing.status = IdempotencyStatus.PROCESSING - await self._repo.update_record(existing) - return IdempotencyResult( - is_duplicate=False, status=IdempotencyStatus.PROCESSING, created_at=now, key=full_key - ) - - self.metrics.record_idempotency_duplicate_blocked(event_type) - return IdempotencyResult( - is_duplicate=True, - status=IdempotencyStatus.PROCESSING, - created_at=created_at, - has_cached_result=existing.result_json is not None, - key=full_key, - ) - - async def _create_new_key(self, full_key: str, event: BaseEvent, ttl: int) -> IdempotencyResult: - created_at = datetime.now(timezone.utc) - try: - record = IdempotencyRecord( - key=full_key, - status=IdempotencyStatus.PROCESSING, - event_type=event.event_type, - event_id=str(event.event_id), - created_at=created_at, - ttl_seconds=ttl, - ) - await self._repo.insert_processing(record) - return IdempotencyResult( - is_duplicate=False, status=IdempotencyStatus.PROCESSING, created_at=created_at, key=full_key - ) - except DuplicateKeyError: - # Race: someone inserted the same key concurrently — treat as existing - existing = await self._repo.find_by_key(full_key) - if existing: - return await self._handle_existing_key(existing, full_key, event.event_type) - # If for some reason it's still not found, allow processing - return IdempotencyResult( - is_duplicate=False, status=IdempotencyStatus.PROCESSING, created_at=created_at, key=full_key - ) - - async def _update_key_status( - self, - full_key: str, - existing: IdempotencyRecord, - status: IdempotencyStatus, - cached_json: str | None = None, - error: str | None = None, - ) -> bool: - created_at = existing.created_at - completed_at = datetime.now(timezone.utc) - duration_ms = int((completed_at - created_at).total_seconds() * 1000) - existing.status = status - existing.completed_at = completed_at - existing.processing_duration_ms = duration_ms - if error: - existing.error = error - if cached_json is not None and self.config.enable_result_caching: - if len(cached_json.encode()) <= self.config.max_result_size_bytes: - existing.result_json = cached_json - else: - self.logger.warning(f"Result too large to cache for key {full_key}") - return (await self._repo.update_record(existing)) > 0 - - async def mark_completed( - self, - event: BaseEvent, - key_strategy: 
KeyStrategy = KeyStrategy.EVENT_BASED, - custom_key: str | None = None, - fields: set[str] | None = None, - ) -> bool: - full_key = self._generate_key(event, key_strategy, custom_key, fields) - try: - existing = await self._repo.find_by_key(full_key) - except Exception as e: # Narrow DB op - self.logger.error(f"Failed to load idempotency key for completion: {e}") - return False - if not existing: - self.logger.warning(f"Idempotency key {full_key} not found when marking completed") - return False - # mark_completed does not accept arbitrary result today; use mark_completed_with_cache for cached payloads - return await self._update_key_status(full_key, existing, IdempotencyStatus.COMPLETED, cached_json=None) - - async def mark_failed( - self, - event: BaseEvent, - error: str, - key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, - custom_key: str | None = None, - fields: set[str] | None = None, - ) -> bool: - full_key = self._generate_key(event, key_strategy, custom_key, fields) - existing = await self._repo.find_by_key(full_key) - if not existing: - self.logger.warning(f"Idempotency key {full_key} not found when marking failed") - return False - return await self._update_key_status( - full_key, existing, IdempotencyStatus.FAILED, cached_json=None, error=error - ) - - async def mark_completed_with_json( - self, - event: BaseEvent, - cached_json: str, - key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, - custom_key: str | None = None, - fields: set[str] | None = None, - ) -> bool: - full_key = self._generate_key(event, key_strategy, custom_key, fields) - existing = await self._repo.find_by_key(full_key) - if not existing: - self.logger.warning(f"Idempotency key {full_key} not found when marking completed with cache") - return False - return await self._update_key_status(full_key, existing, IdempotencyStatus.COMPLETED, cached_json=cached_json) - - async def get_cached_json( - self, event: BaseEvent, key_strategy: KeyStrategy, custom_key: str | None, fields: set[str] | None = None - ) -> str: - full_key = self._generate_key(event, key_strategy, custom_key, fields) - existing = await self._repo.find_by_key(full_key) - assert existing and existing.result_json is not None, "Invariant: cached result must exist when requested" - return existing.result_json - - async def remove( - self, - event: BaseEvent, - key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, - custom_key: str | None = None, - fields: set[str] | None = None, - ) -> bool: - full_key = self._generate_key(event, key_strategy, custom_key, fields) - try: - deleted = await self._repo.delete_key(full_key) - return deleted > 0 - except Exception as e: - self.logger.error(f"Failed to remove idempotency key: {e}") - return False - - async def get_stats(self) -> IdempotencyStats: - counts_raw = await self._repo.aggregate_status_counts(self.config.key_prefix) - status_counts: dict[IdempotencyStatus, int] = { - IdempotencyStatus.PROCESSING: counts_raw.get(IdempotencyStatus.PROCESSING, 0), - IdempotencyStatus.COMPLETED: counts_raw.get(IdempotencyStatus.COMPLETED, 0), - IdempotencyStatus.FAILED: counts_raw.get(IdempotencyStatus.FAILED, 0), - } - total = sum(status_counts.values()) - return IdempotencyStats(total_keys=total, status_counts=status_counts, prefix=self.config.key_prefix) - - async def _update_stats_loop(self) -> None: - while True: - try: - stats = await self.get_stats() - self.metrics.update_idempotency_keys_active(stats.total_keys, self.config.key_prefix) - await asyncio.sleep(60) - except asyncio.CancelledError: - break - except 
Exception as e: - self.logger.error(f"Failed to update idempotency stats: {e}") - await asyncio.sleep(300) - - -def create_idempotency_manager( - *, - repository: IdempotencyRepoProtocol, - config: IdempotencyConfig | None = None, - logger: logging.Logger, - database_metrics: DatabaseMetrics, -) -> IdempotencyManager: - return IdempotencyManager(config or IdempotencyConfig(), repository, logger, database_metrics) diff --git a/backend/app/services/idempotency/middleware.py b/backend/app/services/idempotency/middleware.py deleted file mode 100644 index 7fd3d1e3..00000000 --- a/backend/app/services/idempotency/middleware.py +++ /dev/null @@ -1,252 +0,0 @@ -"""Idempotent event processing middleware""" - -import asyncio -import logging -from typing import Any, Awaitable, Callable, Dict, Set - -from app.domain.enums.events import EventType -from app.domain.events.typed import DomainEvent -from app.domain.idempotency import KeyStrategy -from app.events.core import EventDispatcher, UnifiedConsumer -from app.services.idempotency.idempotency_manager import IdempotencyManager - - -class IdempotentEventHandler: - """Wrapper for event handlers with idempotency support""" - - def __init__( - self, - handler: Callable[[DomainEvent], Awaitable[None]], - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, - custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: Set[str] | None = None, - ttl_seconds: int | None = None, - cache_result: bool = True, - on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, - ): - self.handler = handler - self.idempotency_manager = idempotency_manager - self.logger = logger - self.key_strategy = key_strategy - self.custom_key_func = custom_key_func - self.fields = fields - self.ttl_seconds = ttl_seconds - self.cache_result = cache_result - self.on_duplicate = on_duplicate - - async def __call__(self, event: DomainEvent) -> None: - """Process event with idempotency check""" - self.logger.info( - f"IdempotentEventHandler called for event {event.event_type}, " - f"id={event.event_id}, handler={self.handler.__name__}" - ) - # Generate custom key if function provided - custom_key = None - if self.key_strategy == KeyStrategy.CUSTOM and self.custom_key_func: - custom_key = self.custom_key_func(event) - - # Check idempotency - idempotency_result = await self.idempotency_manager.check_and_reserve( - event=event, - key_strategy=self.key_strategy, - custom_key=custom_key, - ttl_seconds=self.ttl_seconds, - fields=self.fields, - ) - - if idempotency_result.is_duplicate: - # Handle duplicate - self.logger.info( - f"Duplicate event detected: {event.event_type} ({event.event_id}), status: {idempotency_result.status}" - ) - - # Call duplicate handler if provided - if self.on_duplicate: - if asyncio.iscoroutinefunction(self.on_duplicate): - await self.on_duplicate(event, idempotency_result) - else: - await asyncio.to_thread(self.on_duplicate, event, idempotency_result) - - # For duplicate, just return without error - return - - # Not a duplicate, process the event - try: - # Call the actual handler - it returns None - await self.handler(event) - - # Mark as completed - await self.idempotency_manager.mark_completed( - event=event, key_strategy=self.key_strategy, custom_key=custom_key, fields=self.fields - ) - - except Exception as e: - # Mark as failed - await self.idempotency_manager.mark_failed( - event=event, error=str(e), key_strategy=self.key_strategy, custom_key=custom_key, fields=self.fields - ) - 
raise - - -def idempotent_handler( - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, - custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: Set[str] | None = None, - ttl_seconds: int | None = None, - cache_result: bool = True, - on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, -) -> Callable[[Callable[[DomainEvent], Awaitable[None]]], Callable[[DomainEvent], Awaitable[None]]]: - """Decorator for making event handlers idempotent""" - - def decorator(func: Callable[[DomainEvent], Awaitable[None]]) -> Callable[[DomainEvent], Awaitable[None]]: - handler = IdempotentEventHandler( - handler=func, - idempotency_manager=idempotency_manager, - logger=logger, - key_strategy=key_strategy, - custom_key_func=custom_key_func, - fields=fields, - ttl_seconds=ttl_seconds, - cache_result=cache_result, - on_duplicate=on_duplicate, - ) - return handler # IdempotentEventHandler is already callable with the right signature - - return decorator - - -class IdempotentConsumerWrapper: - """Wrapper for Kafka consumer with automatic idempotency""" - - def __init__( - self, - consumer: UnifiedConsumer, - idempotency_manager: IdempotencyManager, - dispatcher: EventDispatcher, - logger: logging.Logger, - default_key_strategy: KeyStrategy = KeyStrategy.EVENT_BASED, - default_ttl_seconds: int = 3600, - enable_for_all_handlers: bool = True, - ): - self.consumer = consumer - self.idempotency_manager = idempotency_manager - self.dispatcher = dispatcher - self.logger = logger - self.default_key_strategy = default_key_strategy - self.default_ttl_seconds = default_ttl_seconds - self._original_handlers: Dict[EventType, list[Callable[[DomainEvent], Awaitable[None]]]] = {} - - if enable_for_all_handlers: - self._wrap_handlers() - - def _wrap_handlers(self) -> None: - """Wrap all registered handlers with idempotency.""" - if not self.dispatcher: - self.logger.warning("No dispatcher available for handler wrapping") - return - - self._original_handlers = self.dispatcher.get_all_handlers() - self.logger.debug(f"Wrapping {len(self._original_handlers)} event types with idempotency") - - # Wrap each handler - for event_type, handlers in self._original_handlers.items(): - wrapped_handlers: list[Callable[[DomainEvent], Awaitable[None]]] = [] - for handler in handlers: - # Wrap with idempotency - IdempotentEventHandler is callable with the right signature - wrapped = IdempotentEventHandler( - handler=handler, - idempotency_manager=self.idempotency_manager, - logger=self.logger, - key_strategy=self.default_key_strategy, - ttl_seconds=self.default_ttl_seconds, - ) - wrapped_handlers.append(wrapped) - - self.dispatcher.replace_handlers(event_type, wrapped_handlers) - - def subscribe_idempotent_handler( - self, - event_type: str, - handler: Callable[[DomainEvent], Awaitable[None]], - key_strategy: KeyStrategy | None = None, - custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: Set[str] | None = None, - ttl_seconds: int | None = None, - cache_result: bool = True, - on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, - ) -> None: - """Subscribe an idempotent handler for specific event type""" - # Create the idempotent handler wrapper - idempotent_wrapper = IdempotentEventHandler( - handler=handler, - idempotency_manager=self.idempotency_manager, - logger=self.logger, - key_strategy=key_strategy or self.default_key_strategy, - custom_key_func=custom_key_func, - fields=fields, - ttl_seconds=ttl_seconds or 
self.default_ttl_seconds, - cache_result=cache_result, - on_duplicate=on_duplicate, - ) - - # Create an async handler that processes the message - async def async_handler(message: Any) -> Any: - self.logger.info(f"IDEMPOTENT HANDLER CALLED for {event_type}") - - # Extract event from confluent-kafka Message - if not hasattr(message, "value"): - self.logger.error(f"Received non-Message object for {event_type}: {type(message)}") - return None - - # Debug log to check message details - self.logger.info( - f"Handler for {event_type} - Message type: {type(message)}, " - f"has key: {hasattr(message, 'key')}, " - f"has topic: {hasattr(message, 'topic')}" - ) - - raw_value = message.value - - # Debug the raw value - self.logger.info(f"Raw value extracted: {raw_value[:100] if raw_value else 'None or empty'}") - - # Handle tombstone messages (null value for log compaction) - if raw_value is None: - self.logger.warning(f"Received empty message for {event_type} - tombstone or consumed value") - return None - - # Handle empty messages - if not raw_value: - self.logger.warning(f"Received empty message for {event_type} - empty bytes") - return None - - try: - # Deserialize using schema registry if available - event = await self.consumer._schema_registry.deserialize_event(raw_value, message.topic) - if not event: - self.logger.error(f"Failed to deserialize event for {event_type}") - return None - - # Call the idempotent wrapper directly in async context - await idempotent_wrapper(event) - - self.logger.debug(f"Successfully processed {event_type} event: {event.event_id}") - return None - except Exception as e: - self.logger.error(f"Failed to process message for {event_type}: {e}", exc_info=True) - raise - - # Register with the dispatcher if available - if self.dispatcher: - # Create wrapper for EventDispatcher - async def dispatch_handler(event: DomainEvent) -> None: - await idempotent_wrapper(event) - - self.dispatcher.register(EventType(event_type))(dispatch_handler) - else: - # Fallback to direct consumer registration if no dispatcher - self.logger.error(f"No EventDispatcher available for registering idempotent handler for {event_type}") diff --git a/backend/app/services/idempotency/redis_repository.py b/backend/app/services/idempotency/redis_repository.py deleted file mode 100644 index 54b9eb1a..00000000 --- a/backend/app/services/idempotency/redis_repository.py +++ /dev/null @@ -1,141 +0,0 @@ -from __future__ import annotations - -import json -from datetime import datetime, timezone -from typing import Any - -import redis.asyncio as redis -from pymongo.errors import DuplicateKeyError - -from app.domain.idempotency import IdempotencyRecord, IdempotencyStatus - - -def _iso(dt: datetime) -> str: - return dt.astimezone(timezone.utc).isoformat() - - -def _json_default(obj: Any) -> str: - if isinstance(obj, datetime): - return _iso(obj) - return str(obj) - - -def _parse_iso_datetime(v: str | None) -> datetime | None: - if not v: - return None - try: - return datetime.fromisoformat(v.replace("Z", "+00:00")) - except Exception: - return None - - -class RedisIdempotencyRepository: - """Redis-backed repository compatible with IdempotencyManager expectations. - - Key shape: : - Value: JSON document with fields similar to Mongo version. - Expiration: handled by Redis key expiry; initial EX set on insert. 
- """ - - def __init__(self, client: redis.Redis, key_prefix: str = "idempotency") -> None: - self._r = client - self._prefix = key_prefix.rstrip(":") - - def _full_key(self, key: str) -> str: - # If caller already namespaces, respect it; otherwise prefix. - return key if key.startswith(f"{self._prefix}:") else f"{self._prefix}:{key}" - - def _doc_to_record(self, doc: dict[str, Any]) -> IdempotencyRecord: - created_at = doc.get("created_at") - if isinstance(created_at, str): - created_at = _parse_iso_datetime(created_at) - completed_at = doc.get("completed_at") - if isinstance(completed_at, str): - completed_at = _parse_iso_datetime(completed_at) - return IdempotencyRecord( - key=str(doc.get("key", "")), - status=IdempotencyStatus(doc.get("status", IdempotencyStatus.PROCESSING)), - event_type=str(doc.get("event_type", "")), - event_id=str(doc.get("event_id", "")), - created_at=created_at, # type: ignore[arg-type] - ttl_seconds=int(doc.get("ttl_seconds", 0) or 0), - completed_at=completed_at, - processing_duration_ms=doc.get("processing_duration_ms"), - error=doc.get("error"), - result_json=doc.get("result"), - ) - - def _record_to_doc(self, rec: IdempotencyRecord) -> dict[str, Any]: - return { - "key": rec.key, - "status": rec.status, - "event_type": rec.event_type, - "event_id": rec.event_id, - "created_at": _iso(rec.created_at), - "ttl_seconds": rec.ttl_seconds, - "completed_at": _iso(rec.completed_at) if rec.completed_at else None, - "processing_duration_ms": rec.processing_duration_ms, - "error": rec.error, - "result": rec.result_json, - } - - async def find_by_key(self, key: str) -> IdempotencyRecord | None: - k = self._full_key(key) - raw = await self._r.get(k) - if not raw: - return None - try: - doc: dict[str, Any] = json.loads(raw) - except Exception: - return None - return self._doc_to_record(doc) - - async def insert_processing(self, record: IdempotencyRecord) -> None: - k = self._full_key(record.key) - doc = self._record_to_doc(record) - # SET NX with EX for atomic reservation - ok = await self._r.set(k, json.dumps(doc, default=_json_default), ex=record.ttl_seconds, nx=True) - if not ok: - # Mirror Mongo behavior so manager's DuplicateKeyError path is reused - raise DuplicateKeyError("Key already exists") - - async def update_record(self, record: IdempotencyRecord) -> int: - k = self._full_key(record.key) - # Read-modify-write while preserving TTL - pipe = self._r.pipeline() - pipe.ttl(k) - pipe.get(k) - ttl_val, raw = await pipe.execute() - if not raw: - return 0 - doc = self._record_to_doc(record) - # Write back, keep TTL if positive - payload = json.dumps(doc, default=_json_default) - if isinstance(ttl_val, int) and ttl_val > 0: - await self._r.set(k, payload, ex=ttl_val) - else: - await self._r.set(k, payload) - return 1 - - async def delete_key(self, key: str) -> int: - k = self._full_key(key) - return int(await self._r.delete(k) or 0) - - async def aggregate_status_counts(self, key_prefix: str) -> dict[str, int]: - pattern = f"{key_prefix.rstrip(':')}:*" - counts: dict[str, int] = {} - # SCAN to avoid blocking Redis - async for k in self._r.scan_iter(match=pattern, count=200): - try: - raw = await self._r.get(k) - if not raw: - continue - doc = json.loads(raw) - status = str(doc.get("status", "")) - counts[status] = counts.get(status, 0) + 1 - except Exception: - continue - return counts - - async def health_check(self) -> None: - await self._r.ping() # type: ignore[misc] # redis-py returns Awaitable[bool] | bool diff --git a/backend/app/services/k8s_worker/worker.py 
b/backend/app/services/k8s_worker/worker.py index c49ec98f..9966241d 100644 --- a/backend/app/services/k8s_worker/worker.py +++ b/backend/app/services/k8s_worker/worker.py @@ -1,10 +1,3 @@ -"""Kubernetes Worker - stateless event handler. - -Creates Kubernetes pods from execution events. Receives events, -processes them, and publishes results. No lifecycle management. -All state is stored in Redis repositories. -""" - from __future__ import annotations import asyncio @@ -16,7 +9,7 @@ from kubernetes.client.rest import ApiException from app.core.metrics import EventMetrics, ExecutionMetrics, KubernetesMetrics -from app.db.repositories.pod_state_repository import PodStateRepository +from app.db.repositories.redis.pod_state_repository import PodStateRepository from app.domain.enums.storage import ExecutionErrorType from app.domain.events.typed import ( CreatePodCommandEvent, diff --git a/backend/app/services/kafka_event_service.py b/backend/app/services/kafka_event_service.py index 9c152b97..1c9d15ae 100644 --- a/backend/app/services/kafka_event_service.py +++ b/backend/app/services/kafka_event_service.py @@ -68,6 +68,7 @@ async def publish_event( event_metadata = metadata or EventMetadata( service_name=self.settings.SERVICE_NAME, service_version=self.settings.SERVICE_VERSION, + user_id="system", correlation_id=correlation_id or str(uuid4()), ) if correlation_id and event_metadata.correlation_id != correlation_id: diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 1e37d987..f3b581ae 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -155,9 +155,6 @@ async def _handle_execution_completed(self, event: ExecutionCompletedEvent) -> N body=body, severity=NotificationSeverity.MEDIUM, tags=["execution", "completed", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], - metadata=event.model_dump( - exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} - ), ) async def _handle_execution_failed(self, event: ExecutionFailedEvent) -> None: @@ -167,12 +164,6 @@ async def _handle_execution_failed(self, event: ExecutionFailedEvent) -> None: self._logger.error("No user_id in event metadata") return - event_data = event.model_dump( - exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} - ) - event_data["stdout"] = event_data["stdout"][:200] - event_data["stderr"] = event_data["stderr"][:200] - title = f"Execution Failed: {event.execution_id}" body = f"Your execution failed: {event.error_message}" await self.create_notification( @@ -181,7 +172,6 @@ async def _handle_execution_failed(self, event: ExecutionFailedEvent) -> None: body=body, severity=NotificationSeverity.HIGH, tags=["execution", "failed", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], - metadata=event_data, ) async def _handle_execution_timeout(self, event: ExecutionTimeoutEvent) -> None: @@ -199,9 +189,6 @@ async def _handle_execution_timeout(self, event: ExecutionTimeoutEvent) -> None: body=body, severity=NotificationSeverity.HIGH, tags=["execution", "timeout", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], - metadata=event.model_dump( - exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} - ), ) async def create_notification( diff --git a/backend/app/services/pod_monitor/event_mapper.py b/backend/app/services/pod_monitor/event_mapper.py index bbdefc28..91305ada 100644 --- 
a/backend/app/services/pod_monitor/event_mapper.py +++ b/backend/app/services/pod_monitor/event_mapper.py @@ -183,7 +183,7 @@ def _create_metadata(self, pod: k8s_client.V1Pod) -> EventMetadata: correlation_id = annotations.get("integr8s.io/correlation-id") or labels.get("correlation-id") or "" md = EventMetadata( - user_id=labels.get("user-id"), + user_id=labels.get("user-id", "system"), service_name=GroupId.POD_MONITOR, service_version="1.0.0", correlation_id=correlation_id, diff --git a/backend/app/services/pod_monitor/monitor.py b/backend/app/services/pod_monitor/monitor.py index 046cbee9..b1604d7b 100644 --- a/backend/app/services/pod_monitor/monitor.py +++ b/backend/app/services/pod_monitor/monitor.py @@ -17,7 +17,7 @@ from app.core.metrics import KubernetesMetrics from app.core.utils import StringEnum -from app.db.repositories.pod_state_repository import PodStateRepository +from app.db.repositories.redis.pod_state_repository import PodStateRepository from app.domain.events.typed import DomainEvent from app.events.core import UnifiedProducer from app.services.pod_monitor.config import PodMonitorConfig diff --git a/backend/app/services/result_processor/processor.py b/backend/app/services/result_processor/processor.py index a71a4e4b..73c40244 100644 --- a/backend/app/services/result_processor/processor.py +++ b/backend/app/services/result_processor/processor.py @@ -180,6 +180,7 @@ async def _publish_result_stored(self, result: ExecutionResultDomain) -> None: metadata=EventMetadata( service_name=GroupId.RESULT_PROCESSOR, service_version="1.0.0", + user_id="system", ), ) @@ -193,6 +194,7 @@ async def _publish_result_failed(self, execution_id: str, error_message: str) -> metadata=EventMetadata( service_name=GroupId.RESULT_PROCESSOR, service_version="1.0.0", + user_id="system", ), ) diff --git a/backend/app/settings.py b/backend/app/settings.py index 57e7c752..8fb3d34f 100644 --- a/backend/app/settings.py +++ b/backend/app/settings.py @@ -130,6 +130,9 @@ class Settings(BaseSettings): RATE_LIMIT_REDIS_PREFIX: str = "rate_limit:" RATE_LIMIT_ALGORITHM: str = "sliding_window" # sliding_window or token_bucket + # Per-user execution limit + MAX_EXECUTIONS_PER_USER: int = 100 + # Service metadata SERVICE_NAME: str = "integr8scode-backend" SERVICE_VERSION: str = "1.0.0" diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index bb96bf9d..7103b4ea 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -8,6 +8,7 @@ import pytest_asyncio import redis.asyncio as redis from app.core.database_context import Database +from app.domain.enums.execution import QueuePriority from app.domain.events.typed import EventMetadata, ExecutionRequestedEvent from app.main import create_app from app.settings import Settings @@ -196,10 +197,10 @@ def make_execution_requested_event( memory_limit: str = "128Mi", cpu_request: str = "50m", memory_request: str = "64Mi", - priority: int = 5, + priority: QueuePriority = QueuePriority.NORMAL, service_name: str = "tests", service_version: str = "1.0.0", - user_id: str | None = None, + user_id: str = "test-user", ) -> ExecutionRequestedEvent: """Factory for ExecutionRequestedEvent with sensible defaults. 
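A minimal sketch of what call sites look like after the hunks above, assuming only the names that actually appear in them (EventMetadata, QueuePriority, make_execution_requested_event); the "svc"/"system" values and the exec-1 id are illustrative, and this snippet is editorial, not part of the patch:

    # Sketch: EventMetadata now carries an explicit user_id, and priorities use the
    # QueuePriority enum instead of a bare int (names taken from the hunks above).
    from app.domain.enums.execution import QueuePriority
    from app.domain.events.typed import EventMetadata

    from tests.conftest import make_execution_requested_event

    # Services publishing on their own behalf pass user_id="system", mirroring the
    # kafka_event_service and result_processor hunks.
    md = EventMetadata(service_name="svc", service_version="1.0", user_id="system")
    assert md.correlation_id  # still auto-generated when not supplied

    # Test factory call with the enum default rather than priority=5; assumes the
    # factory forwards priority unchanged onto the event.
    ev = make_execution_requested_event(execution_id="exec-1", priority=QueuePriority.NORMAL)
    assert ev.priority == QueuePriority.NORMAL
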
diff --git a/backend/tests/e2e/db/repositories/test_dlq_repository.py b/backend/tests/e2e/db/repositories/test_dlq_repository.py index 9464d087..04b4ae87 100644 --- a/backend/tests/e2e/db/repositories/test_dlq_repository.py +++ b/backend/tests/e2e/db/repositories/test_dlq_repository.py @@ -6,6 +6,7 @@ from app.db.repositories.dlq_repository import DLQRepository from app.dlq import DLQMessageStatus from app.domain.enums.events import EventType +from app.domain.enums.kafka import KafkaTopic pytestmark = pytest.mark.e2e @@ -30,7 +31,7 @@ async def insert_test_dlq_docs() -> None: "user_id": "u1", "login_method": "password", }, - original_topic="t1", + original_topic=KafkaTopic.USER_EVENTS, error="err", retry_count=0, failed_at=now, @@ -45,7 +46,7 @@ async def insert_test_dlq_docs() -> None: "user_id": "u1", "login_method": "password", }, - original_topic="t1", + original_topic=KafkaTopic.USER_EVENTS, error="err", retry_count=0, failed_at=now, @@ -60,7 +61,7 @@ async def insert_test_dlq_docs() -> None: "execution_id": "x1", "pod_name": "p1", }, - original_topic="t2", + original_topic=KafkaTopic.EXECUTION_EVENTS, error="err", retry_count=0, failed_at=now, @@ -88,4 +89,4 @@ async def test_stats_list_get_and_updates(repo: DLQRepository) -> None: assert await repo.mark_message_discarded("id1", "r") in (True, False) topics = await repo.get_topics_summary() - assert any(t.topic == "t1" for t in topics) + assert any(t.topic == KafkaTopic.USER_EVENTS for t in topics) diff --git a/backend/tests/e2e/dlq/test_dlq_discard.py b/backend/tests/e2e/dlq/test_dlq_discard.py index 2c4650f4..270950b6 100644 --- a/backend/tests/e2e/dlq/test_dlq_discard.py +++ b/backend/tests/e2e/dlq/test_dlq_discard.py @@ -32,7 +32,7 @@ async def _create_dlq_document( doc = DLQMessageDocument( event=event_dict, - original_topic=str(KafkaTopic.EXECUTION_EVENTS), + original_topic=KafkaTopic.EXECUTION_EVENTS, error="Test error", retry_count=0, failed_at=now, diff --git a/backend/tests/e2e/dlq/test_dlq_manager.py b/backend/tests/e2e/dlq/test_dlq_manager.py index d8888138..612bbbbb 100644 --- a/backend/tests/e2e/dlq/test_dlq_manager.py +++ b/backend/tests/e2e/dlq/test_dlq_manager.py @@ -5,7 +5,7 @@ from datetime import datetime, timezone import pytest -from aiokafka import AIOKafkaConsumer, AIOKafkaProducer +from aiokafka import AIOKafkaConsumer from app.dlq.manager import DLQManager from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic diff --git a/backend/tests/e2e/dlq/test_dlq_retry.py b/backend/tests/e2e/dlq/test_dlq_retry.py index d01fefe7..01b08ff4 100644 --- a/backend/tests/e2e/dlq/test_dlq_retry.py +++ b/backend/tests/e2e/dlq/test_dlq_retry.py @@ -32,7 +32,7 @@ async def _create_dlq_document( doc = DLQMessageDocument( event=event_dict, - original_topic=str(KafkaTopic.EXECUTION_EVENTS), + original_topic=KafkaTopic.EXECUTION_EVENTS, error="Test error", retry_count=0, failed_at=now, diff --git a/backend/tests/e2e/events/test_dlq_handler.py b/backend/tests/e2e/events/test_dlq_handler.py index d96dde5e..d93dd423 100644 --- a/backend/tests/e2e/events/test_dlq_handler.py +++ b/backend/tests/e2e/events/test_dlq_handler.py @@ -1,6 +1,7 @@ import logging import pytest +from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent, EventMetadata, SagaStartedEvent from app.events.core import UnifiedProducer, create_dlq_error_handler, create_immediate_dlq_handler from dishka import AsyncContainer @@ -13,21 +14,21 @@ @pytest.mark.asyncio async def 
test_dlq_handler_with_retries(scope: AsyncContainer, monkeypatch: pytest.MonkeyPatch) -> None: p: UnifiedProducer = await scope.get(UnifiedProducer) - calls: list[tuple[str | None, str, str, int]] = [] + calls: list[tuple[str | None, KafkaTopic, str, int]] = [] async def _record_send_to_dlq( - original_event: DomainEvent, original_topic: str, error: Exception, retry_count: int + original_event: DomainEvent, original_topic: KafkaTopic, error: Exception, retry_count: int ) -> None: calls.append((original_event.event_id, original_topic, str(error), retry_count)) monkeypatch.setattr(p, "send_to_dlq", _record_send_to_dlq) - h = create_dlq_error_handler(p, original_topic="t", max_retries=2, logger=_test_logger) + h = create_dlq_error_handler(p, original_topic=KafkaTopic.SAGA_EVENTS, max_retries=2, logger=_test_logger) e = SagaStartedEvent( saga_id="s", saga_name="n", execution_id="x", initial_event_id="i", - metadata=EventMetadata(service_name="a", service_version="1"), + metadata=EventMetadata(service_name="a", service_version="1", user_id="test"), ) # Call 1 and 2 should not send to DLQ await h(RuntimeError("boom"), e) @@ -36,27 +37,27 @@ async def _record_send_to_dlq( # 3rd call triggers DLQ await h(RuntimeError("boom"), e) assert len(calls) == 1 - assert calls[0][1] == "t" + assert calls[0][1] == KafkaTopic.SAGA_EVENTS @pytest.mark.asyncio async def test_immediate_dlq_handler(scope: AsyncContainer, monkeypatch: pytest.MonkeyPatch) -> None: p: UnifiedProducer = await scope.get(UnifiedProducer) - calls: list[tuple[str | None, str, str, int]] = [] + calls: list[tuple[str | None, KafkaTopic, str, int]] = [] async def _record_send_to_dlq( - original_event: DomainEvent, original_topic: str, error: Exception, retry_count: int + original_event: DomainEvent, original_topic: KafkaTopic, error: Exception, retry_count: int ) -> None: calls.append((original_event.event_id, original_topic, str(error), retry_count)) monkeypatch.setattr(p, "send_to_dlq", _record_send_to_dlq) - h = create_immediate_dlq_handler(p, original_topic="t", logger=_test_logger) + h = create_immediate_dlq_handler(p, original_topic=KafkaTopic.SAGA_EVENTS, logger=_test_logger) e = SagaStartedEvent( saga_id="s2", saga_name="n", execution_id="x", initial_event_id="i", - metadata=EventMetadata(service_name="a", service_version="1"), + metadata=EventMetadata(service_name="a", service_version="1", user_id="test"), ) await h(RuntimeError("x"), e) assert calls and calls[0][3] == 0 diff --git a/backend/tests/e2e/events/test_producer_roundtrip.py b/backend/tests/e2e/events/test_producer_roundtrip.py index ed1a4cdb..e844a3b0 100644 --- a/backend/tests/e2e/events/test_producer_roundtrip.py +++ b/backend/tests/e2e/events/test_producer_roundtrip.py @@ -26,7 +26,7 @@ async def test_unified_producer_produce_and_send_to_dlq(scope: AsyncContainer) - await prod.produce(ev) # Exercise send_to_dlq path - topic = str(get_topic_for_event(ev.event_type)) + topic = get_topic_for_event(ev.event_type) await prod.send_to_dlq(ev, original_topic=topic, error=RuntimeError("forced"), retry_count=1) # Verify metrics are being tracked diff --git a/backend/tests/e2e/events/test_schema_registry_real.py b/backend/tests/e2e/events/test_schema_registry_real.py index d6c182de..877c2ab7 100644 --- a/backend/tests/e2e/events/test_schema_registry_real.py +++ b/backend/tests/e2e/events/test_schema_registry_real.py @@ -19,7 +19,7 @@ async def test_serialize_and_deserialize_event_real_registry(test_settings: Sett execution_id="e1", pod_name="p", namespace="n", - 
metadata=EventMetadata(service_name="s", service_version="1"), + metadata=EventMetadata(service_name="s", service_version="1", user_id="test"), ) data = await m.serialize_event(ev) topic = str(get_topic_for_event(ev.event_type)) diff --git a/backend/tests/e2e/idempotency/__init__.py b/backend/tests/e2e/idempotency/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/backend/tests/e2e/idempotency/test_consumer_idempotent.py b/backend/tests/e2e/idempotency/test_consumer_idempotent.py deleted file mode 100644 index 749a0ea3..00000000 --- a/backend/tests/e2e/idempotency/test_consumer_idempotent.py +++ /dev/null @@ -1,102 +0,0 @@ -import asyncio -import logging -import uuid - -import pytest -from aiokafka import AIOKafkaConsumer -from app.core.metrics import EventMetrics -from app.domain.enums.events import EventType -from app.domain.enums.kafka import KafkaTopic -from app.domain.events.typed import DomainEvent -from app.events.core import EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.core.dispatcher import EventDispatcher as Disp -from app.events.schema.schema_registry import SchemaRegistryManager -from app.domain.idempotency import KeyStrategy -from app.services.idempotency.idempotency_manager import IdempotencyManager -from app.services.idempotency.middleware import IdempotentConsumerWrapper -from app.settings import Settings -from dishka import AsyncContainer - -from tests.conftest import make_execution_requested_event - -# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers -# instantiate Consumer() objects simultaneously. Serial execution prevents this. -pytestmark = [ - pytest.mark.e2e, - pytest.mark.kafka, - pytest.mark.redis, - pytest.mark.xdist_group("kafka_consumers"), -] - -_test_logger = logging.getLogger("test.idempotency.consumer_idempotent") - - -@pytest.mark.asyncio -async def test_consumer_idempotent_wrapper_blocks_duplicates(scope: AsyncContainer) -> None: - producer: UnifiedProducer = await scope.get(UnifiedProducer) - idm: IdempotencyManager = await scope.get(IdempotencyManager) - registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - settings: Settings = await scope.get(Settings) - event_metrics: EventMetrics = await scope.get(EventMetrics) - - # Future resolves when handler processes an event - no polling needed - handled_future: asyncio.Future[None] = asyncio.get_running_loop().create_future() - seen = {"n": 0} - - # Build a dispatcher that signals completion via future - disp: Disp = EventDispatcher(logger=_test_logger) - - @disp.register(EventType.EXECUTION_REQUESTED) - async def handle(_ev: DomainEvent) -> None: - seen["n"] += 1 - if not handled_future.done(): - handled_future.set_result(None) - - # Produce messages BEFORE starting consumer (auto_offset_reset="earliest" will read them) - execution_id = f"e-{uuid.uuid4().hex[:8]}" - ev = make_execution_requested_event(execution_id=execution_id) - await producer.produce(ev, key=execution_id) - await producer.produce(ev, key=execution_id) - - group_id = f"test-idem-consumer.{uuid.uuid4().hex[:6]}" - - # Create AIOKafkaConsumer directly - topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EXECUTION_EVENTS}" - kafka_consumer = AIOKafkaConsumer( - topic, - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=group_id, - enable_auto_commit=True, - auto_offset_reset="earliest", - ) - await kafka_consumer.start() - - handler = UnifiedConsumer( - event_dispatcher=disp, - schema_registry=registry, - logger=_test_logger, - 
event_metrics=event_metrics, - group_id=group_id, - ) - wrapper = IdempotentConsumerWrapper( - consumer=handler, - idempotency_manager=idm, - dispatcher=disp, - default_key_strategy=KeyStrategy.EVENT_BASED, - enable_for_all_handlers=True, - logger=_test_logger, - ) - - try: - # Consume until handler is called - async def consume_until_handled() -> None: - async for msg in kafka_consumer: - await handler.handle(msg) - await kafka_consumer.commit() - if handled_future.done(): - break - - await asyncio.wait_for(consume_until_handled(), timeout=10.0) - assert seen["n"] >= 1 - finally: - await kafka_consumer.stop() diff --git a/backend/tests/e2e/idempotency/test_decorator_idempotent.py b/backend/tests/e2e/idempotency/test_decorator_idempotent.py deleted file mode 100644 index 9638c4a3..00000000 --- a/backend/tests/e2e/idempotency/test_decorator_idempotent.py +++ /dev/null @@ -1,53 +0,0 @@ -import logging - -import pytest -from app.domain.events.typed import DomainEvent -from app.domain.idempotency import KeyStrategy -from app.services.idempotency.idempotency_manager import IdempotencyManager -from app.services.idempotency.middleware import idempotent_handler -from dishka import AsyncContainer - -from tests.conftest import make_execution_requested_event - -_test_logger = logging.getLogger("test.idempotency.decorator_idempotent") - - -pytestmark = [pytest.mark.e2e] - - -@pytest.mark.asyncio -async def test_decorator_blocks_duplicate_event(scope: AsyncContainer) -> None: - idm: IdempotencyManager = await scope.get(IdempotencyManager) - - calls = {"n": 0} - - @idempotent_handler(idempotency_manager=idm, key_strategy=KeyStrategy.EVENT_BASED, logger=_test_logger) - async def h(ev: DomainEvent) -> None: - calls["n"] += 1 - - ev = make_execution_requested_event(execution_id="exec-deco-1") - - await h(ev) - await h(ev) # duplicate - assert calls["n"] == 1 - - -@pytest.mark.asyncio -async def test_decorator_custom_key_blocks(scope: AsyncContainer) -> None: - idm: IdempotencyManager = await scope.get(IdempotencyManager) - - calls = {"n": 0} - - def fixed_key(_ev: DomainEvent) -> str: - return "fixed-key" - - @idempotent_handler(idempotency_manager=idm, key_strategy=KeyStrategy.CUSTOM, custom_key_func=fixed_key, logger=_test_logger) - async def h(ev: DomainEvent) -> None: - calls["n"] += 1 - - e1 = make_execution_requested_event(execution_id="exec-deco-2a") - e2 = make_execution_requested_event(execution_id="exec-deco-2b") - - await h(e1) - await h(e2) # different event ids but same custom key - assert calls["n"] == 1 diff --git a/backend/tests/e2e/idempotency/test_idempotency.py b/backend/tests/e2e/idempotency/test_idempotency.py deleted file mode 100644 index 9dc444ec..00000000 --- a/backend/tests/e2e/idempotency/test_idempotency.py +++ /dev/null @@ -1,564 +0,0 @@ -import asyncio -import json -import logging -import uuid -from collections.abc import AsyncGenerator -from datetime import datetime, timedelta, timezone -from typing import Any - -import pytest -import redis.asyncio as redis -from app.core.metrics import DatabaseMetrics -from app.domain.events.typed import DomainEvent -from app.domain.idempotency import IdempotencyRecord, IdempotencyStatus, KeyStrategy -from app.services.idempotency.idempotency_manager import IdempotencyConfig, IdempotencyManager -from app.services.idempotency.middleware import IdempotentEventHandler, idempotent_handler -from app.services.idempotency.redis_repository import RedisIdempotencyRepository -from app.settings import Settings - -from tests.conftest import 
make_execution_requested_event - -pytestmark = [pytest.mark.e2e, pytest.mark.redis] - -# Test logger for all tests -_test_logger = logging.getLogger("test.idempotency") - - -class TestIdempotencyManager: - """IdempotencyManager backed by real Redis repository (DI-provided client).""" - - @pytest.fixture - async def manager(self, redis_client: redis.Redis, test_settings: Settings) -> AsyncGenerator[IdempotencyManager, None]: - prefix = f"idemp_ut:{uuid.uuid4().hex[:6]}" - cfg = IdempotencyConfig( - key_prefix=prefix, - default_ttl_seconds=3600, - processing_timeout_seconds=5, - enable_result_caching=True, - max_result_size_bytes=1024, - enable_metrics=False, - ) - repo = RedisIdempotencyRepository(redis_client, key_prefix=prefix) - database_metrics = DatabaseMetrics(test_settings) - m = IdempotencyManager(cfg, repo, _test_logger, database_metrics=database_metrics) - await m.initialize() - try: - yield m - finally: - await m.close() - - @pytest.mark.asyncio - async def test_complete_flow_new_event(self, manager: IdempotencyManager) -> None: - """Test the complete flow for a new event""" - real_event = make_execution_requested_event(execution_id="exec-123") - # Check and reserve - result = await manager.check_and_reserve(real_event, key_strategy=KeyStrategy.EVENT_BASED) - - assert result.is_duplicate is False - assert result.status == IdempotencyStatus.PROCESSING - assert result.key.endswith(f"{real_event.event_type}:{real_event.event_id}") - assert result.key.startswith(f"{manager.config.key_prefix}:") - - # Verify it's in the repository - record = await manager._repo.find_by_key(result.key) - assert record is not None - assert record.status == IdempotencyStatus.PROCESSING - - # Mark as completed - success = await manager.mark_completed(real_event, key_strategy=KeyStrategy.EVENT_BASED) - assert success is True - - # Verify status updated - record = await manager._repo.find_by_key(result.key) - assert record is not None - assert record.status == IdempotencyStatus.COMPLETED - assert record.completed_at is not None - assert record.processing_duration_ms is not None - - @pytest.mark.asyncio - async def test_duplicate_detection(self, manager: IdempotencyManager) -> None: - """Test that duplicates are properly detected""" - real_event = make_execution_requested_event(execution_id="exec-dupe-1") - # First request - result1 = await manager.check_and_reserve(real_event, key_strategy=KeyStrategy.EVENT_BASED) - assert result1.is_duplicate is False - - # Mark as completed - await manager.mark_completed(real_event, key_strategy=KeyStrategy.EVENT_BASED) - - # Second request with same event - result2 = await manager.check_and_reserve(real_event, key_strategy=KeyStrategy.EVENT_BASED) - assert result2.is_duplicate is True - assert result2.status == IdempotencyStatus.COMPLETED - - @pytest.mark.asyncio - async def test_concurrent_requests_race_condition(self, manager: IdempotencyManager) -> None: - """Test handling of concurrent requests for the same event""" - real_event = make_execution_requested_event(execution_id="exec-race-1") - # Simulate concurrent requests - tasks = [ - manager.check_and_reserve(real_event, key_strategy=KeyStrategy.EVENT_BASED) - for _ in range(5) - ] - - results = await asyncio.gather(*tasks) - - # Only one should succeed - non_duplicate_count = sum(1 for r in results if not r.is_duplicate) - assert non_duplicate_count == 1 - - # Others should be marked as duplicates - duplicate_count = sum(1 for r in results if r.is_duplicate) - assert duplicate_count == 4 - - @pytest.mark.asyncio - 
async def test_processing_timeout_allows_retry(self, manager: IdempotencyManager) -> None: - """Test that stuck processing allows retry after timeout""" - real_event = make_execution_requested_event(execution_id="exec-timeout-1") - # First request - result1 = await manager.check_and_reserve(real_event, key_strategy=KeyStrategy.EVENT_BASED) - assert result1.is_duplicate is False - - # Manually update the created_at to simulate old processing - record = await manager._repo.find_by_key(result1.key) - assert record is not None - record.created_at = datetime.now(timezone.utc) - timedelta(seconds=10) - await manager._repo.update_record(record) - - # Second request should be allowed due to timeout - result2 = await manager.check_and_reserve(real_event, key_strategy=KeyStrategy.EVENT_BASED) - assert result2.is_duplicate is False # Allowed to retry - assert result2.status == IdempotencyStatus.PROCESSING - - @pytest.mark.asyncio - async def test_content_hash_strategy(self, manager: IdempotencyManager) -> None: - """Test content-based deduplication""" - # Two events with same content and same execution_id - event1 = make_execution_requested_event( - execution_id="exec-1", - service_name="test-service", - ) - - event2 = make_execution_requested_event( - execution_id="exec-1", - service_name="test-service", - ) - - # Use content hash strategy - result1 = await manager.check_and_reserve(event1, key_strategy=KeyStrategy.CONTENT_HASH) - assert result1.is_duplicate is False - - await manager.mark_completed(event1, key_strategy=KeyStrategy.CONTENT_HASH) - - # Second event with same content should be duplicate - result2 = await manager.check_and_reserve(event2, key_strategy=KeyStrategy.CONTENT_HASH) - assert result2.is_duplicate is True - - @pytest.mark.asyncio - async def test_failed_event_handling(self, manager: IdempotencyManager) -> None: - """Test marking events as failed""" - real_event = make_execution_requested_event(execution_id="exec-failed-1") - # Reserve - result = await manager.check_and_reserve(real_event, key_strategy=KeyStrategy.EVENT_BASED) - assert result.is_duplicate is False - - # Mark as failed - error_msg = "Execution failed: out of memory" - success = await manager.mark_failed(real_event, error=error_msg, key_strategy=KeyStrategy.EVENT_BASED) - assert success is True - - # Verify status and error - record = await manager._repo.find_by_key(result.key) - assert record is not None - assert record.status == IdempotencyStatus.FAILED - assert record.error == error_msg - assert record.completed_at is not None - - @pytest.mark.asyncio - async def test_result_caching(self, manager: IdempotencyManager) -> None: - """Test caching of results""" - real_event = make_execution_requested_event(execution_id="exec-cache-1") - # Reserve - result = await manager.check_and_reserve(real_event, key_strategy=KeyStrategy.EVENT_BASED) - assert result.is_duplicate is False - - # Complete with cached result - cached_result = json.dumps({"output": "Hello, World!", "exit_code": 0}) - success = await manager.mark_completed_with_json( - real_event, - cached_json=cached_result, - key_strategy=KeyStrategy.EVENT_BASED - ) - assert success is True - - # Retrieve cached result - retrieved = await manager.get_cached_json(real_event, KeyStrategy.EVENT_BASED, None) - assert retrieved == cached_result - - # Check duplicate with cached result - duplicate_result = await manager.check_and_reserve(real_event, key_strategy=KeyStrategy.EVENT_BASED) - assert duplicate_result.is_duplicate is True - assert 
duplicate_result.has_cached_result is True - - @pytest.mark.asyncio - async def test_stats_aggregation(self, manager: IdempotencyManager) -> None: - """Test statistics aggregation""" - # Create various events with different statuses - events = [] - for i in range(10): - event = make_execution_requested_event( - execution_id=f"exec-{i}", - script=f"print({i})", - service_name="test-service", - ) - events.append(event) - - # Process events with different outcomes - for i, event in enumerate(events): - await manager.check_and_reserve(event, key_strategy=KeyStrategy.EVENT_BASED) - - if i < 6: - await manager.mark_completed(event, key_strategy=KeyStrategy.EVENT_BASED) - elif i < 8: - await manager.mark_failed(event, "Test error", key_strategy=KeyStrategy.EVENT_BASED) - # Leave rest in processing - - # Get stats - stats = await manager.get_stats() - - assert stats.total_keys == 10 - assert stats.status_counts[IdempotencyStatus.COMPLETED] == 6 - assert stats.status_counts[IdempotencyStatus.FAILED] == 2 - assert stats.status_counts[IdempotencyStatus.PROCESSING] == 2 - assert stats.prefix == manager.config.key_prefix - - @pytest.mark.asyncio - async def test_remove_key(self, manager: IdempotencyManager) -> None: - """Test removing idempotency keys""" - real_event = make_execution_requested_event(execution_id="exec-remove-1") - # Add a key - result = await manager.check_and_reserve(real_event, key_strategy=KeyStrategy.EVENT_BASED) - assert result.is_duplicate is False - - # Remove it - removed = await manager.remove(real_event, key_strategy=KeyStrategy.EVENT_BASED) - assert removed is True - - # Verify it's gone - record = await manager._repo.find_by_key(result.key) - assert record is None - - # Can process again - result2 = await manager.check_and_reserve(real_event, key_strategy=KeyStrategy.EVENT_BASED) - assert result2.is_duplicate is False - - -class TestIdempotentEventHandlerIntegration: - """Test IdempotentEventHandler with real components""" - - @pytest.fixture - async def manager(self, redis_client: redis.Redis, test_settings: Settings) -> AsyncGenerator[IdempotencyManager, None]: - prefix = f"handler_test:{uuid.uuid4().hex[:6]}" - config = IdempotencyConfig(key_prefix=prefix, enable_metrics=False) - repo = RedisIdempotencyRepository(redis_client, key_prefix=prefix) - database_metrics = DatabaseMetrics(test_settings) - m = IdempotencyManager(config, repo, _test_logger, database_metrics=database_metrics) - await m.initialize() - try: - yield m - finally: - await m.close() - - @pytest.mark.asyncio - async def test_handler_processes_new_event(self, manager: IdempotencyManager) -> None: - """Test that handler processes new events""" - processed_events: list[DomainEvent] = [] - - async def actual_handler(event: DomainEvent) -> None: - processed_events.append(event) - - # Create idempotent handler - handler = IdempotentEventHandler( - handler=actual_handler, - idempotency_manager=manager, - key_strategy=KeyStrategy.EVENT_BASED, - logger=_test_logger, - ) - - # Process event - real_event = make_execution_requested_event(execution_id="handler-test-123") - await handler(real_event) - - # Verify event was processed - assert len(processed_events) == 1 - assert processed_events[0] == real_event - - @pytest.mark.asyncio - async def test_handler_blocks_duplicate(self, manager: IdempotencyManager) -> None: - """Test that handler blocks duplicate events""" - processed_events: list[DomainEvent] = [] - - async def actual_handler(event: DomainEvent) -> None: - processed_events.append(event) - - # Create 
idempotent handler - handler = IdempotentEventHandler( - handler=actual_handler, - idempotency_manager=manager, - key_strategy=KeyStrategy.EVENT_BASED, - logger=_test_logger, - ) - - # Process event twice - real_event = make_execution_requested_event(execution_id="handler-dup-123") - await handler(real_event) - await handler(real_event) - - # Verify event was processed only once - assert len(processed_events) == 1 - - @pytest.mark.asyncio - async def test_handler_with_failure(self, manager: IdempotencyManager) -> None: - """Test handler marks failure on exception""" - - async def failing_handler(event: DomainEvent) -> None: # noqa: ARG001 - raise ValueError("Processing failed") - - handler = IdempotentEventHandler( - handler=failing_handler, - idempotency_manager=manager, - key_strategy=KeyStrategy.EVENT_BASED, - logger=_test_logger, - ) - - # Process event (should raise) - real_event = make_execution_requested_event(execution_id="handler-fail-1") - with pytest.raises(ValueError, match="Processing failed"): - await handler(real_event) - - # Verify marked as failed - key = f"{manager.config.key_prefix}:{real_event.event_type}:{real_event.event_id}" - record = await manager._repo.find_by_key(key) - assert record is not None - assert record.status == IdempotencyStatus.FAILED - assert record.error is not None - assert "Processing failed" in record.error - - @pytest.mark.asyncio - async def test_handler_duplicate_callback(self, manager: IdempotencyManager) -> None: - """Test duplicate callback is invoked""" - duplicate_events: list[tuple[DomainEvent, Any]] = [] - - async def actual_handler(event: DomainEvent) -> None: # noqa: ARG001 - pass # Do nothing - - async def on_duplicate(event: DomainEvent, result: Any) -> None: - duplicate_events.append((event, result)) - - handler = IdempotentEventHandler( - handler=actual_handler, - idempotency_manager=manager, - key_strategy=KeyStrategy.EVENT_BASED, - on_duplicate=on_duplicate, - logger=_test_logger, - ) - - # Process twice - real_event = make_execution_requested_event(execution_id="handler-dup-cb-1") - await handler(real_event) - await handler(real_event) - - # Verify duplicate callback was called - assert len(duplicate_events) == 1 - assert duplicate_events[0][0] == real_event - assert duplicate_events[0][1].is_duplicate is True - - @pytest.mark.asyncio - async def test_decorator_integration(self, manager: IdempotencyManager) -> None: - """Test the @idempotent_handler decorator""" - processed_events: list[DomainEvent] = [] - - @idempotent_handler( - idempotency_manager=manager, - key_strategy=KeyStrategy.CONTENT_HASH, - ttl_seconds=300, - logger=_test_logger, - ) - async def my_handler(event: DomainEvent) -> None: - processed_events.append(event) - - # Process same event twice - real_event = make_execution_requested_event(execution_id="decor-1") - await my_handler(real_event) - await my_handler(real_event) - - # Should only process once - assert len(processed_events) == 1 - - # Create event with same ID and same content for content hash match - similar_event = make_execution_requested_event( - execution_id=real_event.execution_id, - script=real_event.script, - ) - - # Should still be blocked (content hash) - await my_handler(similar_event) - assert len(processed_events) == 1 # Still only one - - @pytest.mark.asyncio - async def test_custom_key_function(self, manager: IdempotencyManager) -> None: - """Test handler with custom key function""" - processed_scripts: list[str] = [] - - async def process_script(event: DomainEvent) -> None: - script: str 
= getattr(event, "script", "") - processed_scripts.append(script) - - def extract_script_key(event: DomainEvent) -> str: - # Custom key based on script content only - script: str = getattr(event, "script", "") - return f"script:{hash(script)}" - - handler = IdempotentEventHandler( - handler=process_script, - idempotency_manager=manager, - key_strategy=KeyStrategy.CUSTOM, - custom_key_func=extract_script_key, - logger=_test_logger, - ) - - # Events with same script - event1 = make_execution_requested_event( - execution_id="id1", - script="print('hello')", - service_name="test-service", - ) - - event2 = make_execution_requested_event( - execution_id="id2", - language="python", - language_version="3.9", # Different version - runtime_image="python:3.9-slim", - runtime_command=("python",), - runtime_filename="main.py", - timeout_seconds=60, # Different timeout - cpu_limit="200m", - memory_limit="256Mi", - cpu_request="100m", - memory_request="128Mi", - service_name="test-service", - ) - - await handler(event1) - await handler(event2) - - # Should only process once (same script) - assert len(processed_scripts) == 1 - assert processed_scripts[0] == "print('hello')" - - @pytest.mark.asyncio - async def test_invalid_key_strategy(self, manager: IdempotencyManager) -> None: - """Test that invalid key strategy raises error""" - real_event = make_execution_requested_event(execution_id="invalid-strategy-1") - with pytest.raises(ValueError, match="Invalid key strategy"): - await manager.check_and_reserve(real_event, key_strategy="invalid_strategy") # type: ignore[arg-type] # testing invalid use - - @pytest.mark.asyncio - async def test_custom_key_without_custom_key_param(self, manager: IdempotencyManager) -> None: - """Test that custom strategy without custom_key raises error""" - real_event = make_execution_requested_event(execution_id="custom-key-missing-1") - with pytest.raises(ValueError, match="Invalid key strategy"): - await manager.check_and_reserve(real_event, key_strategy=KeyStrategy.CUSTOM) - - @pytest.mark.asyncio - async def test_get_cached_json_existing(self, manager: IdempotencyManager) -> None: - """Test retrieving cached JSON result""" - # First complete with cached result - real_event = make_execution_requested_event(execution_id="cache-exist-1") - await manager.check_and_reserve(real_event, key_strategy=KeyStrategy.EVENT_BASED) - cached_data = json.dumps({"output": "test", "code": 0}) - await manager.mark_completed_with_json(real_event, cached_data, KeyStrategy.EVENT_BASED) - - # Retrieve cached result - retrieved = await manager.get_cached_json(real_event, KeyStrategy.EVENT_BASED, None) - assert retrieved == cached_data - - @pytest.mark.asyncio - async def test_get_cached_json_non_existing(self, manager: IdempotencyManager) -> None: - """Test retrieving non-existing cached result raises assertion""" - real_event = make_execution_requested_event(execution_id="cache-miss-1") - # Trying to get cached result for non-existent key should raise - with pytest.raises(AssertionError, match="cached result must exist"): - await manager.get_cached_json(real_event, KeyStrategy.EVENT_BASED, None) - - @pytest.mark.asyncio - async def test_cleanup_expired_keys(self, manager: IdempotencyManager) -> None: - """Test cleanup of expired keys""" - # Create expired record - expired_key = f"{manager.config.key_prefix}:expired" - expired_record = IdempotencyRecord( - key=expired_key, - status=IdempotencyStatus.COMPLETED, - event_type="test", - event_id="expired-1", - created_at=datetime.now(timezone.utc) - 
timedelta(hours=2), - ttl_seconds=3600, # 1 hour TTL - completed_at=datetime.now(timezone.utc) - timedelta(hours=2) - ) - await manager._repo.insert_processing(expired_record) - - # Cleanup should detect it as expired - # Note: actual cleanup implementation depends on repository - record = await manager._repo.find_by_key(expired_key) - assert record is not None # Still exists until explicit cleanup - - @pytest.mark.asyncio - async def test_metrics_enabled(self, redis_client: redis.Redis, test_settings: Settings) -> None: - """Test manager with metrics enabled""" - config = IdempotencyConfig(key_prefix=f"metrics:{uuid.uuid4().hex[:6]}", enable_metrics=True) - repository = RedisIdempotencyRepository(redis_client, key_prefix=config.key_prefix) - database_metrics = DatabaseMetrics(test_settings) - manager = IdempotencyManager(config, repository, _test_logger, database_metrics=database_metrics) - - # Initialize with metrics - await manager.initialize() - assert manager._stats_update_task is not None - - # Cleanup - await manager.close() - - @pytest.mark.asyncio - async def test_content_hash_with_fields(self, manager: IdempotencyManager) -> None: - """Test content hash with specific fields""" - event1 = make_execution_requested_event( - execution_id="exec-1", - service_name="test-service", - ) - - # Use content hash with only script field - fields = {"script", "language"} - result1 = await manager.check_and_reserve( - event1, - key_strategy=KeyStrategy.CONTENT_HASH, - fields=fields - ) - assert result1.is_duplicate is False - await manager.mark_completed(event1, key_strategy=KeyStrategy.CONTENT_HASH, fields=fields) - - # Event with same script and language but different other fields - event2 = make_execution_requested_event( - execution_id="exec-2", - timeout_seconds=60, - cpu_limit="200m", - memory_limit="256Mi", - cpu_request="100m", - memory_request="128Mi", - service_name="test-service", - ) - - result2 = await manager.check_and_reserve( - event2, - key_strategy=KeyStrategy.CONTENT_HASH, - fields=fields - ) - assert result2.is_duplicate is True # Same script and language diff --git a/backend/tests/e2e/idempotency/test_idempotent_handler.py b/backend/tests/e2e/idempotency/test_idempotent_handler.py deleted file mode 100644 index 63f3fca9..00000000 --- a/backend/tests/e2e/idempotency/test_idempotent_handler.py +++ /dev/null @@ -1,63 +0,0 @@ -import logging - -import pytest -from app.domain.events.typed import DomainEvent -from app.domain.idempotency import KeyStrategy -from app.services.idempotency.idempotency_manager import IdempotencyManager -from app.services.idempotency.middleware import IdempotentEventHandler -from dishka import AsyncContainer - -from tests.conftest import make_execution_requested_event - -pytestmark = [pytest.mark.e2e] - -_test_logger = logging.getLogger("test.idempotency.idempotent_handler") - - -@pytest.mark.asyncio -async def test_idempotent_handler_blocks_duplicates(scope: AsyncContainer) -> None: - manager: IdempotencyManager = await scope.get(IdempotencyManager) - - processed: list[str | None] = [] - - async def _handler(ev: DomainEvent) -> None: - processed.append(ev.event_id) - - handler = IdempotentEventHandler( - handler=_handler, - idempotency_manager=manager, - key_strategy=KeyStrategy.EVENT_BASED, - logger=_test_logger, - ) - - ev = make_execution_requested_event(execution_id="exec-dup-1") - - await handler(ev) - await handler(ev) # duplicate - - assert processed == [ev.event_id] - - -@pytest.mark.asyncio -async def 
test_idempotent_handler_content_hash_blocks_same_content(scope: AsyncContainer) -> None: - manager: IdempotencyManager = await scope.get(IdempotencyManager) - - processed: list[str] = [] - - async def _handler(ev: DomainEvent) -> None: - processed.append(getattr(ev, "execution_id", "")) - - handler = IdempotentEventHandler( - handler=_handler, - idempotency_manager=manager, - key_strategy=KeyStrategy.CONTENT_HASH, - logger=_test_logger, - ) - - e1 = make_execution_requested_event(execution_id="exec-dup-2") - e2 = make_execution_requested_event(execution_id="exec-dup-2") - - await handler(e1) - await handler(e2) - - assert processed == [e1.execution_id] diff --git a/backend/tests/e2e/services/coordinator/test_execution_coordinator.py b/backend/tests/e2e/services/coordinator/test_execution_coordinator.py index 472ebc0e..67660c15 100644 --- a/backend/tests/e2e/services/coordinator/test_execution_coordinator.py +++ b/backend/tests/e2e/services/coordinator/test_execution_coordinator.py @@ -1,4 +1,5 @@ import pytest +from app.domain.enums.execution import QueuePriority from app.services.coordinator.coordinator import ExecutionCoordinator from dishka import AsyncContainer from tests.conftest import make_execution_requested_event @@ -22,7 +23,7 @@ async def test_handle_requested_does_not_raise(self, scope: AsyncContainer) -> N async def test_handle_requested_with_priority(self, scope: AsyncContainer) -> None: """Handler respects execution priority.""" coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) - ev = make_execution_requested_event(execution_id="e-priority-1", priority=10) + ev = make_execution_requested_event(execution_id="e-priority-1", priority=QueuePriority.BACKGROUND) await coord.handle_execution_requested(ev) diff --git a/backend/tests/e2e/services/idempotency/test_redis_repository.py b/backend/tests/e2e/services/idempotency/test_redis_repository.py deleted file mode 100644 index c346c8f6..00000000 --- a/backend/tests/e2e/services/idempotency/test_redis_repository.py +++ /dev/null @@ -1,170 +0,0 @@ -import json -import uuid -from datetime import datetime, timedelta, timezone - -import pytest -import redis.asyncio as redis -from app.domain.idempotency import IdempotencyRecord, IdempotencyStatus -from app.services.idempotency.redis_repository import ( - RedisIdempotencyRepository, - _iso, - _json_default, - _parse_iso_datetime, -) -from pymongo.errors import DuplicateKeyError - -pytestmark = [pytest.mark.e2e, pytest.mark.redis] - - -class TestHelperFunctions: - def test_iso_datetime(self) -> None: - dt = datetime(2025, 1, 15, 10, 30, 45, tzinfo=timezone.utc) - result = _iso(dt) - assert result == "2025-01-15T10:30:45+00:00" - - def test_iso_datetime_with_timezone(self) -> None: - dt = datetime(2025, 1, 15, 10, 30, 45, tzinfo=timezone(timedelta(hours=5))) - result = _iso(dt) - assert result == "2025-01-15T05:30:45+00:00" - - def test_json_default_datetime(self) -> None: - dt = datetime(2025, 1, 15, 10, 30, 45, tzinfo=timezone.utc) - result = _json_default(dt) - assert result == "2025-01-15T10:30:45+00:00" - - def test_json_default_other(self) -> None: - obj = {"key": "value"} - result = _json_default(obj) - assert result == "{'key': 'value'}" - - def test_parse_iso_datetime_variants(self) -> None: - result1 = _parse_iso_datetime("2025-01-15T10:30:45+00:00") - assert result1 is not None and result1.year == 2025 - result2 = _parse_iso_datetime("2025-01-15T10:30:45Z") - assert result2 is not None and result2.tzinfo == timezone.utc - assert _parse_iso_datetime(None) is None 
- assert _parse_iso_datetime("") is None - assert _parse_iso_datetime("not-a-date") is None - - -@pytest.fixture -def repository(redis_client: redis.Redis) -> RedisIdempotencyRepository: - return RedisIdempotencyRepository(redis_client, key_prefix="idempotency") - - -@pytest.fixture -def sample_record() -> IdempotencyRecord: - return IdempotencyRecord( - key="test-key", - status=IdempotencyStatus.PROCESSING, - event_type="test.event", - event_id="event-123", - created_at=datetime(2025, 1, 15, 10, 30, 45, tzinfo=timezone.utc), - ttl_seconds=5, - completed_at=None, - processing_duration_ms=None, - error=None, - result_json=None, - ) - - -def test_full_key_helpers(repository: RedisIdempotencyRepository) -> None: - assert repository._full_key("my") == "idempotency:my" - assert repository._full_key("idempotency:my") == "idempotency:my" - - -def test_doc_record_roundtrip(repository: RedisIdempotencyRepository) -> None: - rec = IdempotencyRecord( - key="k", - status=IdempotencyStatus.COMPLETED, - event_type="e.t", - event_id="e-1", - created_at=datetime(2025, 1, 15, tzinfo=timezone.utc), - ttl_seconds=60, - completed_at=datetime(2025, 1, 15, 0, 1, tzinfo=timezone.utc), - processing_duration_ms=123, - error="err", - result_json='{"ok":true}', - ) - doc = repository._record_to_doc(rec) - back = repository._doc_to_record(doc) - assert back.key == rec.key and back.status == rec.status - - -@pytest.mark.asyncio -async def test_insert_find_update_delete_flow( - repository: RedisIdempotencyRepository, - redis_client: redis.Redis, - sample_record: IdempotencyRecord, -) -> None: - # Insert processing (NX) - await repository.insert_processing(sample_record) - key = repository._full_key(sample_record.key) - ttl = await redis_client.ttl(key) - assert ttl == sample_record.ttl_seconds or ttl > 0 - - # Duplicate insert should raise DuplicateKeyError - with pytest.raises(DuplicateKeyError): - await repository.insert_processing(sample_record) - - # Find returns the record - found = await repository.find_by_key(sample_record.key) - assert found is not None and found.key == sample_record.key - - # Update preserves TTL when present - sample_record.status = IdempotencyStatus.COMPLETED - sample_record.completed_at = datetime.now(timezone.utc) - sample_record.processing_duration_ms = 10 - sample_record.result_json = json.dumps({"result": True}) - updated = await repository.update_record(sample_record) - assert updated == 1 - ttl_after = await redis_client.ttl(key) - assert ttl_after == ttl or ttl_after <= ttl # ttl should not increase - - # Delete - deleted = await repository.delete_key(sample_record.key) - assert deleted == 1 - assert await repository.find_by_key(sample_record.key) is None - - -@pytest.mark.asyncio -async def test_update_record_when_missing( - repository: RedisIdempotencyRepository, sample_record: IdempotencyRecord -) -> None: - # If key missing, update returns 0 - res = await repository.update_record(sample_record) - assert res == 0 - - -@pytest.mark.asyncio -async def test_aggregate_status_counts( - repository: RedisIdempotencyRepository, redis_client: redis.Redis -) -> None: - # Use unique prefix to avoid collision with other tests - prefix = uuid.uuid4().hex[:8] - # Seed few keys directly using repository - statuses = (IdempotencyStatus.PROCESSING, IdempotencyStatus.PROCESSING, IdempotencyStatus.COMPLETED) - for i, status in enumerate(statuses): - rec = IdempotencyRecord( - key=f"{prefix}_k{i}", - status=status, - event_type="t", - event_id=f"{prefix}_e{i}", - 
created_at=datetime.now(timezone.utc), - ttl_seconds=60, - ) - await repository.insert_processing(rec) - if status != IdempotencyStatus.PROCESSING: - rec.status = status - rec.completed_at = datetime.now(timezone.utc) - await repository.update_record(rec) - - counts = await repository.aggregate_status_counts("idempotency") - # Counts include all records in namespace, check we have at least our seeded counts - assert counts[IdempotencyStatus.PROCESSING] >= 2 - assert counts[IdempotencyStatus.COMPLETED] >= 1 - - -@pytest.mark.asyncio -async def test_health_check(repository: RedisIdempotencyRepository) -> None: - await repository.health_check() # should not raise diff --git a/backend/tests/e2e/services/sse/test_redis_bus.py b/backend/tests/e2e/services/sse/test_redis_bus.py index 8d0ac726..1985395b 100644 --- a/backend/tests/e2e/services/sse/test_redis_bus.py +++ b/backend/tests/e2e/services/sse/test_redis_bus.py @@ -61,7 +61,7 @@ def pubsub(self) -> _FakePubSub: def _make_metadata() -> EventMetadata: - return EventMetadata(service_name="test", service_version="1.0") + return EventMetadata(service_name="test", service_version="1.0", user_id="test") @pytest.mark.asyncio diff --git a/backend/tests/e2e/test_dlq_routes.py b/backend/tests/e2e/test_dlq_routes.py index c48857d5..974b5164 100644 --- a/backend/tests/e2e/test_dlq_routes.py +++ b/backend/tests/e2e/test_dlq_routes.py @@ -1,6 +1,7 @@ import pytest from app.dlq.models import DLQMessageStatus, RetryStrategy from app.domain.enums.events import EventType +from app.domain.enums.kafka import KafkaTopic from app.schemas_pydantic.dlq import ( DLQBatchRetryResponse, DLQMessageDetail, @@ -94,7 +95,7 @@ async def test_get_dlq_messages_by_topic( """Filter DLQ messages by topic.""" response = await test_user.get( "/api/v1/dlq/messages", - params={"topic": "execution-events"}, + params={"topic": KafkaTopic.EXECUTION_EVENTS}, ) assert response.status_code == 200 @@ -230,7 +231,7 @@ async def test_set_retry_policy(self, test_user: AsyncClient) -> None: response = await test_user.post( "/api/v1/dlq/retry-policy", json={ - "topic": "execution-events", + "topic": KafkaTopic.EXECUTION_EVENTS, "strategy": RetryStrategy.EXPONENTIAL_BACKOFF, "max_retries": 5, "base_delay_seconds": 60.0, @@ -241,7 +242,7 @@ async def test_set_retry_policy(self, test_user: AsyncClient) -> None: assert response.status_code == 200 result = MessageResponse.model_validate(response.json()) - assert "execution-events" in result.message + assert KafkaTopic.EXECUTION_EVENTS in result.message @pytest.mark.asyncio async def test_set_retry_policy_fixed_strategy( @@ -251,7 +252,7 @@ async def test_set_retry_policy_fixed_strategy( response = await test_user.post( "/api/v1/dlq/retry-policy", json={ - "topic": "test-topic", + "topic": KafkaTopic.POD_EVENTS, "strategy": RetryStrategy.FIXED_INTERVAL, "max_retries": 3, "base_delay_seconds": 30.0, @@ -262,7 +263,7 @@ async def test_set_retry_policy_fixed_strategy( assert response.status_code == 200 result = MessageResponse.model_validate(response.json()) - assert "test-topic" in result.message + assert KafkaTopic.POD_EVENTS in result.message @pytest.mark.asyncio async def test_set_retry_policy_scheduled_strategy( @@ -272,7 +273,7 @@ async def test_set_retry_policy_scheduled_strategy( response = await test_user.post( "/api/v1/dlq/retry-policy", json={ - "topic": "notifications-topic", + "topic": KafkaTopic.USER_EVENTS, "strategy": RetryStrategy.SCHEDULED, "max_retries": 10, "base_delay_seconds": 120.0, @@ -283,7 +284,7 @@ async def 
test_set_retry_policy_scheduled_strategy( assert response.status_code == 200 result = MessageResponse.model_validate(response.json()) - assert "notifications-topic" in result.message + assert KafkaTopic.USER_EVENTS in result.message class TestDiscardDLQMessage: diff --git a/backend/tests/e2e/test_k8s_worker_create_pod.py b/backend/tests/e2e/test_k8s_worker_create_pod.py index d1efcf80..81b3c301 100644 --- a/backend/tests/e2e/test_k8s_worker_create_pod.py +++ b/backend/tests/e2e/test_k8s_worker_create_pod.py @@ -2,6 +2,7 @@ import uuid import pytest +from app.domain.enums.execution import QueuePriority from app.domain.events.typed import CreatePodCommandEvent, EventMetadata from app.services.k8s_worker.worker import KubernetesWorker from app.settings import Settings @@ -40,7 +41,7 @@ async def test_worker_creates_configmap_and_pod( memory_limit="128Mi", cpu_request="50m", memory_request="64Mi", - priority=5, + priority=QueuePriority.NORMAL, metadata=EventMetadata(service_name="tests", service_version="1", user_id="u1"), ) diff --git a/backend/tests/unit/events/test_metadata_model.py b/backend/tests/unit/events/test_metadata_model.py index f237a263..7feaa41b 100644 --- a/backend/tests/unit/events/test_metadata_model.py +++ b/backend/tests/unit/events/test_metadata_model.py @@ -2,9 +2,10 @@ def test_metadata_creation() -> None: - m = EventMetadata(service_name="svc", service_version="1") + m = EventMetadata(service_name="svc", service_version="1", user_id="test-user") assert m.service_name == "svc" assert m.service_version == "1" + assert m.user_id == "test-user" assert m.correlation_id # auto-generated @@ -14,7 +15,7 @@ def test_metadata_with_user() -> None: def test_metadata_copy_with_correlation() -> None: - m = EventMetadata(service_name="svc", service_version="1") + m = EventMetadata(service_name="svc", service_version="1", user_id="test-user") m2 = m.model_copy(update={"correlation_id": "cid"}) assert m2.correlation_id == "cid" assert m2.service_name == m.service_name diff --git a/backend/tests/unit/events/test_schema_registry_manager.py b/backend/tests/unit/events/test_schema_registry_manager.py index 6819237a..2a733e6c 100644 --- a/backend/tests/unit/events/test_schema_registry_manager.py +++ b/backend/tests/unit/events/test_schema_registry_manager.py @@ -25,7 +25,7 @@ def test_deserialize_json_execution_requested(test_settings: Settings) -> None: "cpu_request": "50m", "memory_request": "64Mi", "priority": 5, - "metadata": {"service_name": "t", "service_version": "1.0"}, + "metadata": {"service_name": "t", "service_version": "1.0", "user_id": "test"}, } ev = m.deserialize_json(data) assert isinstance(ev, ExecutionRequestedEvent) diff --git a/backend/tests/unit/services/idempotency/__init__.py b/backend/tests/unit/services/idempotency/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/backend/tests/unit/services/idempotency/test_idempotency_manager.py b/backend/tests/unit/services/idempotency/test_idempotency_manager.py deleted file mode 100644 index aa7fb056..00000000 --- a/backend/tests/unit/services/idempotency/test_idempotency_manager.py +++ /dev/null @@ -1,104 +0,0 @@ -import logging -from unittest.mock import MagicMock - -import pytest -from app.core.metrics import DatabaseMetrics -from app.domain.events.typed import BaseEvent -from app.domain.idempotency import KeyStrategy -from app.services.idempotency.idempotency_manager import ( - IdempotencyConfig, - IdempotencyKeyStrategy, - IdempotencyManager, -) - -pytestmark = pytest.mark.unit - -# Test logger 
-_test_logger = logging.getLogger("test.idempotency_manager") - - -class TestIdempotencyKeyStrategy: - def test_event_based(self) -> None: - event = MagicMock(spec=BaseEvent) - event.event_type = "test.event" - event.event_id = "event-123" - key = IdempotencyKeyStrategy.event_based(event) - assert key == "test.event:event-123" - - def test_content_hash_all_fields(self) -> None: - event = MagicMock(spec=BaseEvent) - event.model_dump.return_value = { - "event_id": "123", - "event_type": "test", - "timestamp": "2025-01-01", - "metadata": {}, - "field1": "value1", - "field2": "value2", - } - key = IdempotencyKeyStrategy.content_hash(event) - assert isinstance(key, str) and len(key) == 64 - - def test_content_hash_specific_fields(self) -> None: - event = MagicMock(spec=BaseEvent) - event.model_dump.return_value = { - "event_id": "123", - "event_type": "test", - "field1": "value1", - "field2": "value2", - "field3": "value3", - } - key = IdempotencyKeyStrategy.content_hash(event, fields={"field1", "field3"}) - assert isinstance(key, str) and len(key) == 64 - - def test_custom(self) -> None: - event = MagicMock(spec=BaseEvent) - event.event_type = "test.event" - key = IdempotencyKeyStrategy.custom(event, "custom-key-123") - assert key == "test.event:custom-key-123" - - -class TestIdempotencyConfig: - def test_default_config(self) -> None: - config = IdempotencyConfig() - assert config.key_prefix == "idempotency" - assert config.default_ttl_seconds == 3600 - assert config.processing_timeout_seconds == 300 - assert config.enable_result_caching is True - assert config.max_result_size_bytes == 1048576 - assert config.enable_metrics is True - assert config.collection_name == "idempotency_keys" - - def test_custom_config(self) -> None: - config = IdempotencyConfig( - key_prefix="custom", - default_ttl_seconds=7200, - processing_timeout_seconds=600, - enable_result_caching=False, - max_result_size_bytes=2048, - enable_metrics=False, - collection_name="custom_keys", - ) - assert config.key_prefix == "custom" - assert config.default_ttl_seconds == 7200 - assert config.processing_timeout_seconds == 600 - assert config.enable_result_caching is False - assert config.max_result_size_bytes == 2048 - assert config.enable_metrics is False - assert config.collection_name == "custom_keys" - - -def test_manager_generate_key_variants(database_metrics: DatabaseMetrics) -> None: - repo = MagicMock() - mgr = IdempotencyManager(IdempotencyConfig(), repo, _test_logger, database_metrics=database_metrics) - ev = MagicMock(spec=BaseEvent) - ev.event_type = "t" - ev.event_id = "e" - ev.model_dump.return_value = {"event_id": "e", "event_type": "t"} - - assert mgr._generate_key(ev, KeyStrategy.EVENT_BASED) == "idempotency:t:e" - ch = mgr._generate_key(ev, KeyStrategy.CONTENT_HASH) - assert ch.startswith("idempotency:") and len(ch.split(":")[-1]) == 64 - assert mgr._generate_key(ev, KeyStrategy.CUSTOM, custom_key="k") == "idempotency:t:k" - with pytest.raises(ValueError): - mgr._generate_key(ev, KeyStrategy.CUSTOM) # CUSTOM requires custom_key - diff --git a/backend/tests/unit/services/idempotency/test_middleware.py b/backend/tests/unit/services/idempotency/test_middleware.py deleted file mode 100644 index 58f92ec3..00000000 --- a/backend/tests/unit/services/idempotency/test_middleware.py +++ /dev/null @@ -1,122 +0,0 @@ -import logging -from unittest.mock import AsyncMock, MagicMock - -import pytest -from app.domain.events.typed import DomainEvent -from app.domain.idempotency import IdempotencyStatus, KeyStrategy -from 
app.services.idempotency.idempotency_manager import IdempotencyManager, IdempotencyResult -from app.services.idempotency.middleware import ( - IdempotentEventHandler, -) - -_test_logger = logging.getLogger("test.services.idempotency.middleware") - - -pytestmark = pytest.mark.unit - - -class TestIdempotentEventHandler: - @pytest.fixture - def mock_idempotency_manager(self) -> AsyncMock: - return AsyncMock(spec=IdempotencyManager) - - @pytest.fixture - def mock_handler(self) -> AsyncMock: - handler = AsyncMock() - handler.__name__ = "test_handler" - return handler - - @pytest.fixture - def event(self) -> MagicMock: - event = MagicMock(spec=DomainEvent) - event.event_type = "test.event" - event.event_id = "event-123" - return event - - @pytest.fixture - def idempotent_event_handler( - self, mock_handler: AsyncMock, mock_idempotency_manager: AsyncMock - ) -> IdempotentEventHandler: - return IdempotentEventHandler( - handler=mock_handler, - idempotency_manager=mock_idempotency_manager, - key_strategy=KeyStrategy.EVENT_BASED, - ttl_seconds=3600, - cache_result=True, - logger=_test_logger - ) - - @pytest.mark.asyncio - async def test_call_with_fields( - self, mock_handler: AsyncMock, mock_idempotency_manager: AsyncMock, event: MagicMock - ) -> None: - # Setup with specific fields - fields = {"field1", "field2"} - - handler = IdempotentEventHandler( - handler=mock_handler, - idempotency_manager=mock_idempotency_manager, - key_strategy=KeyStrategy.CONTENT_HASH, - fields=fields, - logger=_test_logger - ) - - idempotency_result = IdempotencyResult( - is_duplicate=False, - status=IdempotencyStatus.PROCESSING, - created_at=MagicMock(), - key="test-key" - ) - mock_idempotency_manager.check_and_reserve.return_value = idempotency_result - - # Execute - await handler(event) - - # Verify - mock_idempotency_manager.check_and_reserve.assert_called_once_with( - event=event, - key_strategy=KeyStrategy.CONTENT_HASH, - custom_key=None, - ttl_seconds=None, - fields=fields - ) - - @pytest.mark.asyncio - async def test_call_handler_exception( - self, - idempotent_event_handler: IdempotentEventHandler, - mock_idempotency_manager: AsyncMock, - mock_handler: AsyncMock, - event: MagicMock, - ) -> None: - # Setup: Handler raises exception - idempotency_result = IdempotencyResult( - is_duplicate=False, - status=IdempotencyStatus.PROCESSING, - created_at=MagicMock(), - key="test-key" - ) - mock_idempotency_manager.check_and_reserve.return_value = idempotency_result - mock_handler.side_effect = Exception("Handler error") - - # Execute and verify exception is raised - with pytest.raises(Exception, match="Handler error"): - await idempotent_event_handler(event) - - # Verify failure is marked - mock_idempotency_manager.mark_failed.assert_called_once_with( - event=event, - error="Handler error", - key_strategy=KeyStrategy.EVENT_BASED, - custom_key=None, - fields=None - ) - - # Duplicate handler and custom key behavior covered by integration tests - - -class TestIdempotentHandlerDecorator: - pass - -class TestIdempotentConsumerWrapper: - pass diff --git a/backend/tests/unit/services/pod_monitor/test_event_mapper.py b/backend/tests/unit/services/pod_monitor/test_event_mapper.py index 2a42bbdd..0afc1154 100644 --- a/backend/tests/unit/services/pod_monitor/test_event_mapper.py +++ b/backend/tests/unit/services/pod_monitor/test_event_mapper.py @@ -31,7 +31,7 @@ def _ctx(pod: Pod, event_type: str = "ADDED") -> PodContext: return PodContext( pod=pod, execution_id="e1", - metadata=EventMetadata(service_name="t", service_version="1"), + 
metadata=EventMetadata(service_name="t", service_version="1", user_id="test"), phase=pod.status.phase or "", event_type=event_type, ) diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py index 5a233fbd..e9ab717b 100644 --- a/backend/tests/unit/services/pod_monitor/test_monitor.py +++ b/backend/tests/unit/services/pod_monitor/test_monitor.py @@ -10,7 +10,7 @@ from kubernetes import client as k8s_client from app.core.metrics import EventMetrics, KubernetesMetrics -from app.db.repositories.pod_state_repository import PodStateRepository +from app.db.repositories.redis.pod_state_repository import PodStateRepository from app.domain.events.typed import DomainEvent, EventMetadata, ExecutionCompletedEvent from app.domain.events.typed import ResourceUsageDomain from app.events.core import UnifiedProducer @@ -55,7 +55,7 @@ def __init__(self) -> None: async def track_pod( self, pod_name: str, execution_id: str, status: str, - metadata: dict[str, object] | None = None, ttl_seconds: int = 7200, + ttl_seconds: int = 7200, ) -> None: self._tracked.add(pod_name) @@ -251,7 +251,7 @@ async def test_publish_event(event_metrics: EventMetrics, kubernetes_metrics: Ku aggregate_id="exec1", exit_code=0, resource_usage=ResourceUsageDomain(), - metadata=EventMetadata(service_name="test", service_version="1.0"), + metadata=EventMetadata(service_name="test", service_version="1.0", user_id="test"), ) pod = make_pod(name="test-pod", phase="Succeeded", labels={"execution-id": "exec1"}) @@ -275,7 +275,7 @@ def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: # n aggregate_id="e1", exit_code=0, resource_usage=ResourceUsageDomain(), - metadata=EventMetadata(service_name="test", service_version="1.0"), + metadata=EventMetadata(service_name="test", service_version="1.0", user_id="test"), ) ] diff --git a/backend/tests/unit/services/saga/test_saga_step_and_base.py b/backend/tests/unit/services/saga/test_saga_step_and_base.py index d56acab6..dc8e08c6 100644 --- a/backend/tests/unit/services/saga/test_saga_step_and_base.py +++ b/backend/tests/unit/services/saga/test_saga_step_and_base.py @@ -47,7 +47,7 @@ async def test_context_adders() -> None: error_type="test_error", message="test", service_name="test_service", - metadata=EventMetadata(service_name="t", service_version="1"), + metadata=EventMetadata(service_name="t", service_version="1", user_id="test"), ) ctx.add_event(evt) assert len(ctx.events) == 1 diff --git a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py index d2fd5ebd..726f405b 100644 --- a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py +++ b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py @@ -22,7 +22,7 @@ async def publish_event(self, execution_id: str, event: DomainEvent) -> None: def _make_metadata() -> EventMetadata: - return EventMetadata(service_name="test", service_version="1.0") + return EventMetadata(service_name="test", service_version="1.0", user_id="test") @pytest.mark.asyncio diff --git a/backend/tests/unit/services/test_pod_builder.py b/backend/tests/unit/services/test_pod_builder.py index 6742073b..d06bbc41 100644 --- a/backend/tests/unit/services/test_pod_builder.py +++ b/backend/tests/unit/services/test_pod_builder.py @@ -1,6 +1,7 @@ from uuid import uuid4 import pytest +from app.domain.enums.execution import QueuePriority from app.domain.events.typed import CreatePodCommandEvent, EventMetadata from 
app.services.k8s_worker.config import K8sWorkerConfig from app.services.k8s_worker.pod_builder import PodBuilder @@ -38,7 +39,7 @@ def create_pod_command(self) -> CreatePodCommandEvent: memory_request="256Mi", cpu_limit="1000m", memory_limit="1Gi", - priority=5, + priority=QueuePriority.NORMAL, metadata=EventMetadata( user_id=str(uuid4()), correlation_id=str(uuid4()), @@ -152,7 +153,7 @@ def test_container_resources_defaults( memory_request="", cpu_limit="", memory_limit="", - priority=5, + priority=QueuePriority.NORMAL, metadata=EventMetadata( service_name="svc", service_version="1", @@ -285,7 +286,7 @@ def test_pod_timeout_default( memory_request="128Mi", cpu_limit="500m", memory_limit="512Mi", - priority=5, + priority=QueuePriority.NORMAL, metadata=EventMetadata(user_id=str(uuid4()), service_name="t", service_version="1") ) @@ -342,7 +343,7 @@ def test_pod_labels_truncation( memory_limit="128Mi", cpu_request="50m", memory_request="64Mi", - priority=5, + priority=QueuePriority.NORMAL, metadata=EventMetadata( service_name="svc", service_version="1", @@ -399,7 +400,7 @@ def test_different_languages( memory_request="128Mi", cpu_limit="200m", memory_limit="256Mi", - priority=5, + priority=QueuePriority.NORMAL, metadata=EventMetadata(user_id=str(uuid4()), service_name="t", service_version="1") ) diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py index 47b3ec53..9f72f518 100644 --- a/backend/workers/dlq_processor.py +++ b/backend/workers/dlq_processor.py @@ -9,15 +9,16 @@ from app.db.docs import ALL_DOCUMENTS from app.dlq import DLQMessage, RetryPolicy, RetryStrategy from app.dlq.manager import DLQManager +from app.domain.enums.kafka import KafkaTopic from app.settings import Settings from beanie import init_beanie def _configure_retry_policies(manager: DLQManager, logger: logging.Logger) -> None: manager.set_retry_policy( - "execution-requests", + KafkaTopic.EXECUTION_REQUESTS, RetryPolicy( - topic="execution-requests", + topic=KafkaTopic.EXECUTION_REQUESTS, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, max_retries=5, base_delay_seconds=30, @@ -26,9 +27,9 @@ def _configure_retry_policies(manager: DLQManager, logger: logging.Logger) -> No ), ) manager.set_retry_policy( - "pod-events", + KafkaTopic.POD_EVENTS, RetryPolicy( - topic="pod-events", + topic=KafkaTopic.POD_EVENTS, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, max_retries=3, base_delay_seconds=60, @@ -37,17 +38,20 @@ def _configure_retry_policies(manager: DLQManager, logger: logging.Logger) -> No ), ) manager.set_retry_policy( - "resource-allocation", - RetryPolicy(topic="resource-allocation", strategy=RetryStrategy.IMMEDIATE, max_retries=3), + KafkaTopic.RESOURCE_ALLOCATION, + RetryPolicy(topic=KafkaTopic.RESOURCE_ALLOCATION, strategy=RetryStrategy.IMMEDIATE, max_retries=3), ) manager.set_retry_policy( - "websocket-events", + KafkaTopic.WEBSOCKET_EVENTS, RetryPolicy( - topic="websocket-events", strategy=RetryStrategy.FIXED_INTERVAL, max_retries=10, base_delay_seconds=10 + topic=KafkaTopic.WEBSOCKET_EVENTS, + strategy=RetryStrategy.FIXED_INTERVAL, + max_retries=10, + base_delay_seconds=10, ), ) manager.set_default_retry_policy(RetryPolicy( - topic="default", + topic=KafkaTopic.DEAD_LETTER_QUEUE, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, max_retries=4, base_delay_seconds=60, diff --git a/docs/architecture/idempotency.md b/docs/architecture/idempotency.md index 6b036d68..a2e40c84 100644 --- a/docs/architecture/idempotency.md +++ b/docs/architecture/idempotency.md @@ -1,129 +1,93 @@ # Idempotency -The platform 
implements at-least-once event delivery with idempotency protection to prevent duplicate processing. When a -Kafka message is delivered multiple times (due to retries, rebalances, or failures), the idempotency layer ensures the -event handler executes only once. Results can be cached for fast duplicate responses. +The platform implements idempotency protection for HTTP API requests using a simple Redis-based pattern. When a client provides an `Idempotency-Key` header, duplicate requests return the cached result instead of re-executing. ## Architecture ```mermaid flowchart TB - subgraph Kafka Consumer - MSG[Incoming Event] --> CHECK[Check & Reserve Key] + subgraph API Request + REQ[POST /execute] --> KEY{Idempotency-Key?} end - subgraph Idempotency Manager - CHECK --> REDIS[(Redis)] - REDIS --> FOUND{Key Exists?} - FOUND -->|Yes| STATUS{Status?} - STATUS -->|Processing| TIMEOUT{Timed Out?} - STATUS -->|Completed/Failed| DUP[Return Duplicate] - TIMEOUT -->|Yes| RETRY[Allow Retry] - TIMEOUT -->|No| WAIT[Block Duplicate] - FOUND -->|No| RESERVE[Reserve Key] + subgraph Redis + KEY -->|Yes| CHECK[SET NX] + CHECK --> EXISTS{Key Exists?} + EXISTS -->|Yes| GET[Get Cached Result] + EXISTS -->|No| RESERVE[Key Reserved] end - subgraph Handler Execution - RESERVE --> HANDLER[Execute Handler] - RETRY --> HANDLER - HANDLER -->|Success| COMPLETE[Mark Completed] - HANDLER -->|Error| FAIL[Mark Failed] - COMPLETE --> CACHE[Cache Result] + subgraph Handler + RESERVE --> EXEC[Execute Request] + EXEC --> STORE[Store Result] + GET --> RETURN[Return Cached] + STORE --> RETURN end ``` -## Key Strategies +## Usage -The idempotency manager supports three strategies for generating keys from events: +Clients include the `Idempotency-Key` header to enable duplicate protection: -**Event-based** uses the event's unique ID and type. This is the default and works for events where the ID is guaranteed -unique (like UUIDs generated at publish time). - -**Content hash** generates a SHA-256 hash of the event's payload, excluding metadata like timestamps and event IDs. Use -this when the same logical operation might produce different event IDs but identical content. - -**Custom** allows the caller to provide an arbitrary key. Useful when idempotency depends on business logic (e.g., "one -execution per user per minute"). - -```python ---8<-- "backend/app/services/idempotency/idempotency_manager.py:37:57" +```bash +curl -X POST https://api.example.com/api/v1/execute \ + -H "Authorization: Bearer $TOKEN" \ + -H "Idempotency-Key: my-unique-request-id" \ + -d '{"script": "print(1)", "lang": "python", "lang_version": "3.11"}' ``` -## Status Lifecycle +If the same key is sent again within the TTL window (24 hours), the API returns the cached response without re-executing. + +## Implementation -Each idempotency record transitions through defined states: +The idempotency layer uses Redis `SET NX EX` for atomic reservation: ```python ---8<-- "backend/app/domain/idempotency/models.py:11:15" +--8<-- "backend/app/db/repositories/redis/idempotency_repository.py" ``` -When an event arrives, the manager checks for an existing key. If none exists, it creates a record in `PROCESSING` state -and returns control to the handler. On success, the record moves to `COMPLETED`; on error, to `FAILED`. Both terminal -states block duplicate processing for the TTL duration. 
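+A minimal sketch of the reservation flow described above, assuming a hypothetical `IdempotencySketch` class; the real implementation is the repository file included above, so the names and signatures here are illustrative only:
+
+```python
+# Illustrative SET NX EX sketch; class and method names are assumptions,
+# not the actual repository API.
+import redis.asyncio as redis
+
+
+class IdempotencySketch:
+    def __init__(self, client: redis.Redis, ttl_seconds: int = 86400) -> None:
+        self._redis = client
+        self._ttl = ttl_seconds
+
+    def _key(self, user_id: str, idempotency_key: str) -> str:
+        # Matches the documented key format: idempotent:exec:{user_id}:{idempotency_key}
+        return f"idempotent:exec:{user_id}:{idempotency_key}"
+
+    async def reserve(self, user_id: str, idempotency_key: str) -> bool:
+        # SET NX EX: only the first caller gets True; duplicates get False.
+        return bool(
+            await self._redis.set(self._key(user_id, idempotency_key), "", nx=True, ex=self._ttl)
+        )
+
+    async def store_result(self, user_id: str, idempotency_key: str, result_json: str) -> None:
+        # Overwrite the placeholder with the result while keeping the original TTL.
+        await self._redis.set(self._key(user_id, idempotency_key), result_json, keepttl=True)
+
+    async def get_result(self, user_id: str, idempotency_key: str) -> str | None:
+        value = await self._redis.get(self._key(user_id, idempotency_key))
+        return value or None
+```
+
+A caller invokes `reserve()` first: if it returns `False`, the key already exists and the cached result is fetched and returned; otherwise the request is executed and `store_result()` saves the response for later duplicates.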
+### Key Format -If a key is found in `PROCESSING` state but has exceeded the processing timeout (default 5 minutes), the manager assumes -the previous processor crashed and allows a retry. +Keys are namespaced by user to prevent cross-user collisions: -## Middleware Integration - -The `IdempotentEventHandler` wraps Kafka event handlers with automatic duplicate detection: - -```python ---8<-- "backend/app/services/idempotency/middleware.py:39:73" ``` - -For bulk registration, the `IdempotentConsumerWrapper` wraps all handlers in a dispatcher: - -```python ---8<-- "backend/app/services/idempotency/middleware.py:122:141" +idempotent:exec:{user_id}:{idempotency_key} ``` -## Redis Storage - -Idempotency records are stored in Redis with automatic TTL expiration. The `SET NX EX` command provides atomic -reservation—if two processes race to claim the same key, only one succeeds: +### Flow -```python ---8<-- "backend/app/services/idempotency/redis_repository.py:93:100" -``` +1. **Reserve**: `SET NX` attempts to claim the key atomically +2. **Check**: If key exists, fetch cached result +3. **Execute**: If new, process the request +4. **Store**: Save result JSON for future duplicates ## Configuration -| Parameter | Default | Description | -|------------------------------|---------------|--------------------------------------| -| `key_prefix` | `idempotency` | Redis key namespace | -| `default_ttl_seconds` | `3600` | How long completed keys are retained | -| `processing_timeout_seconds` | `300` | When to assume a processor crashed | -| `enable_result_caching` | `true` | Store handler results for duplicates | -| `max_result_size_bytes` | `1048576` | Maximum cached result size (1MB) | - -```python ---8<-- "backend/app/services/idempotency/idempotency_manager.py:27:34" -``` - -## Result Caching +| Parameter | Default | Description | +|---------------|----------|------------------------------------| +| `KEY_PREFIX` | `idempotent` | Redis key namespace | +| `default_ttl` | `86400` | TTL in seconds (24 hours) | -When `enable_result_caching` is true, the manager stores the handler's result JSON alongside the completion status. -Subsequent duplicates can return the cached result without re-executing the handler. This is useful for idempotent -queries where the response should be consistent. +## Why This Design -Results exceeding `max_result_size_bytes` are silently dropped from the cache but the idempotency protection still -applies. +The previous implementation (~600 lines) included: -## Metrics +- Multiple key strategies (event-based, content-hash, custom) +- Processing state tracking with timeouts +- Background stats collection +- Middleware wrappers for Kafka consumers +- Result caching with size limits -The idempotency system exposes several metrics for monitoring: +Analysis showed only HTTP API idempotency was actually used, and Kafka consumer idempotency was handled elsewhere. 
The simplified design: -- `idempotency_cache_hits` - Key lookups that found an existing record -- `idempotency_cache_misses` - Key lookups that created new records -- `idempotency_duplicates_blocked` - Events rejected as duplicates -- `idempotency_keys_active` - Current number of active keys (updated periodically) +- **35 lines** vs ~600 lines +- **3 methods** vs complex state machine +- **No background tasks** +- **No unused abstractions** ## Key Files -| File | Purpose | -|--------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------| -| [`services/idempotency/idempotency_manager.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/services/idempotency/idempotency_manager.py) | Core idempotency logic | -| [`services/idempotency/redis_repository.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/services/idempotency/redis_repository.py) | Redis storage adapter | -| [`services/idempotency/middleware.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/services/idempotency/middleware.py) | Handler wrappers and consumer integration | -| [`domain/idempotency/`](https://github.com/HardMax71/Integr8sCode/tree/main/backend/app/domain/idempotency) | Domain models | +| File | Purpose | +|------|---------| +| [`db/repositories/redis/idempotency_repository.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/db/repositories/redis/idempotency_repository.py) | Redis-based idempotency | +| [`api/routes/execution.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/api/routes/execution.py) | HTTP API usage | diff --git a/docs/architecture/services-overview.md b/docs/architecture/services-overview.md index bce59980..e756f869 100644 --- a/docs/architecture/services-overview.md +++ b/docs/architecture/services-overview.md @@ -4,7 +4,7 @@ This document explains what lives under `backend/app/services/`, what each servi ## High-level architecture -The API (FastAPI) receives user requests for auth, execute, events, scripts, and settings. The Coordinator accepts validated execution requests and enqueues them to Kafka with metadata and idempotency guards. The Saga Orchestrator drives stateful execution via events and publishes commands to the K8s Worker. The K8s Worker builds and creates per-execution pods and supporting ConfigMaps with network isolation enforced at cluster level via Cilium policy. Pod Monitor watches K8s and translates pod phases and logs into domain events. Result Processor consumes completion/failure/timeout events, updates DB, and cleans resources. SSE Router fans execution events out to connected clients. DLQ Processor and Event Replay support reliability and investigations. +The API (FastAPI) receives user requests for auth, execute, events, scripts, and settings. The Coordinator accepts validated execution requests, enforces per-user limits, and fires CreatePodCommandEvent to Kafka for K8s to handle. The Saga Orchestrator drives stateful execution via events and publishes commands to the K8s Worker. The K8s Worker builds and creates per-execution pods and supporting ConfigMaps with network isolation enforced at cluster level via Cilium policy. Pod Monitor watches K8s and translates pod phases and logs into domain events. Result Processor consumes completion/failure/timeout events, updates DB, and cleans resources. 
SSE Router fans execution events out to connected clients. DLQ Processor and Event Replay support reliability and investigations. ## Event streams @@ -12,7 +12,7 @@ EXECUTION_EVENTS carries lifecycle updates like queued, started, running, and ca ## Execution pipeline services -The coordinator/ module contains QueueManager which maintains an in-memory view of pending executions with priorities, aging, and backpressure. It doesn't own metrics for queue depth (that's centralized in coordinator metrics) and doesn't publish commands directly, instead emitting events for the Saga Orchestrator to process. This provides fairness, limits, and stale-job cleanup in one place while preventing double publications. +The coordinator/ module is a stateless service that enforces per-user execution limits and fires CreatePodCommandEvent directly to Kafka. K8s handles all resource allocation and scheduling via ResourceQuotas, LimitRanges, and PriorityClasses. The coordinator tracks active executions per user via Redis and decrements on completion/failure/cancellation. The saga/ module has ExecutionSaga which encodes the multi-step execution flow from receiving a request through creating a pod command, observing pod outcomes, and committing the result. The Saga Orchestrator subscribes to EXECUTION events, reconstructs sagas, and issues SAGA_COMMANDS to the worker with goals of idempotency across restarts, clean compensation on failure, and avoiding duplicate side-effects. @@ -44,7 +44,7 @@ The saved_script_service.py handles CRUD for saved scripts with ownership checks The rate_limit_service.py is a Redis-backed sliding window / token bucket implementation with dynamic configuration per endpoint group, user overrides, and IP fallback. It has a safe failure mode (fail open) with explicit metrics when Redis is unavailable. -The idempotency/ module provides middleware and wrappers to make Kafka consumption idempotent using content-hash or custom keys, used for SAGA_COMMANDS to avoid duplicate pod creation. +The `db/repositories/redis/idempotency_repository.py` provides simple Redis-based idempotency for HTTP API requests using the `Idempotency-Key` header pattern. The saga_service.py provides read-model access for saga state and guardrails like enforcing access control on saga inspection routes. @@ -70,7 +70,7 @@ The DLQ Processor drains and retries dead-lettered messages with backoff and vis The worker refuses to run in the default namespace. Use the setup script to apply the Cilium policy in a dedicated namespace and run the worker there. Apply `backend/k8s/policies/executor-deny-all-cnp.yaml` or use `scripts/setup_k8s.sh `. All executor pods are labeled `app=integr8s, component=executor` and are covered by the static deny-all policy. See [Security Policies](../security/policies.md) for details on network isolation. -Sagas and consumers use content-hash keys by default to avoid duplicates on restarts. Coordinator centralizes queue depth metrics, Result Processor normalizes error types, and Rate Limit service emits rich diagnostics even when disabled. +Coordinator handles per-user execution limits, Result Processor normalizes error types, and Rate Limit service emits rich diagnostics even when disabled. ## Common flows