From 528c54fcfa479b6a060d5bd678638e009fa3fe58 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Mon, 26 Jan 2026 16:39:33 +0100 Subject: [PATCH 1/4] Support gateway events in API, CLI, and UI --- .../List/hooks/useColumnDefinitions.tsx | 13 +++++++++++++ .../src/pages/Events/List/hooks/useFilters.ts | 9 +++++++++ frontend/src/types/event.d.ts | 5 +++-- src/dstack/_internal/cli/commands/event.py | 19 +++++++++++++++++++ src/dstack/_internal/cli/services/events.py | 1 + src/dstack/_internal/core/models/events.py | 1 + src/dstack/_internal/core/models/gateways.py | 4 ++++ src/dstack/_internal/server/routers/events.py | 1 + src/dstack/_internal/server/schemas/events.py | 11 +++++++++++ .../_internal/server/services/events.py | 17 +++++++++++++++++ .../server/services/gateways/__init__.py | 1 + src/dstack/api/server/_events.py | 2 ++ .../_internal/server/routers/test_gateways.py | 8 ++++++++ 13 files changed, 90 insertions(+), 2 deletions(-) diff --git a/frontend/src/pages/Events/List/hooks/useColumnDefinitions.tsx b/frontend/src/pages/Events/List/hooks/useColumnDefinitions.tsx index ad337cf5f1..d6e5b846ea 100644 --- a/frontend/src/pages/Events/List/hooks/useColumnDefinitions.tsx +++ b/frontend/src/pages/Events/List/hooks/useColumnDefinitions.tsx @@ -125,6 +125,19 @@ export const useColumnsDefinitions = () => { ); + case 'gateway': + return ( +
+ Gateway{' '} + {target.project_name && ( + + {target.project_name} + + )} + /{target.name} +
+ ); + default: return '---'; } diff --git a/frontend/src/pages/Events/List/hooks/useFilters.ts b/frontend/src/pages/Events/List/hooks/useFilters.ts index a3d510718f..d463770b30 100644 --- a/frontend/src/pages/Events/List/hooks/useFilters.ts +++ b/frontend/src/pages/Events/List/hooks/useFilters.ts @@ -18,6 +18,7 @@ type RequestParamsKeys = keyof Pick< | 'target_runs' | 'target_jobs' | 'target_volumes' + | 'target_gateways' | 'within_projects' | 'within_fleets' | 'within_runs' @@ -33,6 +34,7 @@ const filterKeys: Record = { TARGET_RUNS: 'target_runs', TARGET_JOBS: 'target_jobs', TARGET_VOLUMES: 'target_volumes', + TARGET_GATEWAYS: 'target_gateways', WITHIN_PROJECTS: 'within_projects', WITHIN_FLEETS: 'within_fleets', WITHIN_RUNS: 'within_runs', @@ -50,6 +52,7 @@ const multipleChoiseKeys: RequestParamsKeys[] = [ 'target_runs', 'target_jobs', 'target_volumes', + 'target_gateways', 'within_projects', 'within_fleets', 'within_runs', @@ -65,6 +68,7 @@ const targetTypes = [ { label: 'Run', value: 'run' }, { label: 'Job', value: 'job' }, { label: 'Volume', value: 'volume' }, + { label: 'Gateway', value: 'gateway' }, ]; export const useFilters = () => { @@ -162,6 +166,11 @@ export const useFilters = () => { operators: ['='], propertyLabel: 'Target volumes', }, + { + key: filterKeys.TARGET_GATEWAYS, + operators: ['='], + propertyLabel: 'Target gateways', + }, { key: filterKeys.WITHIN_PROJECTS, diff --git a/frontend/src/types/event.d.ts b/frontend/src/types/event.d.ts index 3aadfa1f31..618ea6673f 100644 --- a/frontend/src/types/event.d.ts +++ b/frontend/src/types/event.d.ts @@ -1,4 +1,4 @@ -declare type TEventTargetType = 'project' | 'user' | 'fleet' | 'instance' | 'run' | 'job' | 'volume'; +declare type TEventTargetType = 'project' | 'user' | 'fleet' | 'instance' | 'run' | 'job' | 'volume' | 'gateway'; declare type TEventListRequestParams = Omit & { prev_recorded_at?: string; @@ -9,6 +9,7 @@ declare type TEventListRequestParams = Omit EventListFilters: api.client.volumes.get(project_name=api.project, name=name).id for name in args.target_volumes ] + elif args.target_gateways: + filters.target_gateways = [] + for name in args.target_gateways: + id = api.client.gateways.get(api.project, name).id + if id is None: + # TODO(0.21): Remove this check once `Gateway.id` is required. + raise CLIError( + "Cannot determine gateway ID, most likely due to an outdated dstack server." + " Update the server to 0.20.7 or higher or remove --target-gateway." + ) + filters.target_gateways.append(id) if args.within_fleets: filters.within_fleets = [ diff --git a/src/dstack/_internal/cli/services/events.py b/src/dstack/_internal/cli/services/events.py index c2903065c9..11f764bd15 100644 --- a/src/dstack/_internal/cli/services/events.py +++ b/src/dstack/_internal/cli/services/events.py @@ -17,6 +17,7 @@ class EventListFilters: target_fleets: Optional[list[uuid.UUID]] = None target_runs: Optional[list[uuid.UUID]] = None target_volumes: Optional[list[uuid.UUID]] = None + target_gateways: Optional[list[uuid.UUID]] = None within_projects: Optional[list[uuid.UUID]] = None within_fleets: Optional[list[uuid.UUID]] = None within_runs: Optional[list[uuid.UUID]] = None diff --git a/src/dstack/_internal/core/models/events.py b/src/dstack/_internal/core/models/events.py index 6dae2dc178..289c4fc674 100644 --- a/src/dstack/_internal/core/models/events.py +++ b/src/dstack/_internal/core/models/events.py @@ -17,6 +17,7 @@ class EventTargetType(str, Enum): RUN = "run" JOB = "job" VOLUME = "volume" + GATEWAY = "gateway" class EventTarget(CoreModel): diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 2dfeb5b181..b342c0a73b 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -1,4 +1,5 @@ import datetime +import uuid from enum import Enum from typing import Dict, Optional, Union @@ -93,6 +94,9 @@ class GatewaySpec(CoreModel): class Gateway(CoreModel): + # ID is only optional on the client side for compatibility with pre-0.20.7 servers. + # TODO(0.21): Make required. + id: Optional[uuid.UUID] = None name: str configuration: GatewayConfiguration created_at: datetime.datetime diff --git a/src/dstack/_internal/server/routers/events.py b/src/dstack/_internal/server/routers/events.py index be75cccbb4..4250eb4d7a 100644 --- a/src/dstack/_internal/server/routers/events.py +++ b/src/dstack/_internal/server/routers/events.py @@ -45,6 +45,7 @@ async def list_events( target_runs=body.target_runs, target_jobs=body.target_jobs, target_volumes=body.target_volumes, + target_gateways=body.target_gateways, within_projects=body.within_projects, within_fleets=body.within_fleets, within_runs=body.within_runs, diff --git a/src/dstack/_internal/server/schemas/events.py b/src/dstack/_internal/server/schemas/events.py index 66ea2e3404..30f7fe3244 100644 --- a/src/dstack/_internal/server/schemas/events.py +++ b/src/dstack/_internal/server/schemas/events.py @@ -91,6 +91,17 @@ class ListEventsRequest(CoreModel): max_items=MAX_FILTER_ITEMS, ), ] = None + target_gateways: Annotated[ + Optional[list[uuid.UUID]], + Field( + description=( + "List of gateway IDs." + " The response will only include events that target the specified gateways" + ), + min_items=MIN_FILTER_ITEMS, + max_items=MAX_FILTER_ITEMS, + ), + ] = None within_projects: Annotated[ Optional[list[uuid.UUID]], Field( diff --git a/src/dstack/_internal/server/services/events.py b/src/dstack/_internal/server/services/events.py index 80d81734b5..c6d35a4577 100644 --- a/src/dstack/_internal/server/services/events.py +++ b/src/dstack/_internal/server/services/events.py @@ -14,6 +14,7 @@ EventModel, EventTargetModel, FleetModel, + GatewayModel, InstanceModel, JobModel, MemberModel, @@ -87,6 +88,7 @@ def __post_init__(self): def from_model( model: Union[ FleetModel, + GatewayModel, InstanceModel, JobModel, ProjectModel, @@ -102,6 +104,13 @@ def from_model( id=model.id, name=model.name, ) + if isinstance(model, GatewayModel): + return Target( + type=EventTargetType.GATEWAY, + project_id=model.project_id or model.project.id, + id=model.id, + name=model.name, + ) if isinstance(model, InstanceModel): return Target( type=EventTargetType.INSTANCE, @@ -222,6 +231,7 @@ async def list_events( target_runs: Optional[list[uuid.UUID]], target_jobs: Optional[list[uuid.UUID]], target_volumes: Optional[list[uuid.UUID]], + target_gateways: Optional[list[uuid.UUID]], within_projects: Optional[list[uuid.UUID]], within_fleets: Optional[list[uuid.UUID]], within_runs: Optional[list[uuid.UUID]], @@ -298,6 +308,13 @@ async def list_events( EventTargetModel.entity_id.in_(target_volumes), ) ) + if target_gateways is not None: + target_filters.append( + and_( + EventTargetModel.entity_type == EventTargetType.GATEWAY, + EventTargetModel.entity_id.in_(target_gateways), + ) + ) if within_projects is not None: target_filters.append(EventTargetModel.entity_project_id.in_(within_projects)) if within_fleets is not None: diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 4ab80a8331..cf41b53973 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -558,6 +558,7 @@ def gateway_model_to_gateway(gateway_model: GatewayModel) -> Gateway: configuration = get_gateway_configuration(gateway_model) configuration.default = gateway_model.project.default_gateway_id == gateway_model.id return Gateway( + id=gateway_model.id, name=gateway_model.name, ip_address=ip_address, instance_id=instance_id, diff --git a/src/dstack/api/server/_events.py b/src/dstack/api/server/_events.py index d9bf828394..d403fb2427 100644 --- a/src/dstack/api/server/_events.py +++ b/src/dstack/api/server/_events.py @@ -30,6 +30,7 @@ def list( *, # NOTE: New parameters go here. Avoid positional parameters, they can break compatibility. target_volumes: Optional[list[UUID]] = None, + target_gateways: Optional[list[UUID]] = None, ) -> list[Event]: if prev_recorded_at is not None: # Time zones other than UTC are misinterpreted by the server: @@ -43,6 +44,7 @@ def list( target_runs=target_runs, target_jobs=target_jobs, target_volumes=target_volumes, + target_gateways=target_gateways, within_projects=within_projects, within_fleets=within_fleets, within_runs=within_runs, diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index b909c7d729..70f6b22b7e 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -17,6 +17,7 @@ create_user, get_auth_headers, ) +from dstack._internal.server.testing.matchers import SomeUUID4Str class TestListAndGetGateways: @@ -54,6 +55,7 @@ async def test_list(self, test_db, session: AsyncSession, client: AsyncClient): assert response.status_code == 200 assert response.json() == [ { + "id": SomeUUID4Str(), "backend": backend.type.value, "created_at": response.json()[0]["created_at"], "default": False, @@ -107,6 +109,7 @@ async def test_get(self, test_db, session: AsyncSession, client: AsyncClient): ) assert response.status_code == 200 assert response.json() == { + "id": SomeUUID4Str(), "backend": backend.type.value, "created_at": response.json()["created_at"], "default": False, @@ -189,6 +192,7 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn ) assert response.status_code == 200 assert response.json() == { + "id": SomeUUID4Str(), "name": "test", "backend": "aws", "region": "us", @@ -243,6 +247,7 @@ async def test_create_gateway_without_name( g.assert_called_once() assert response.status_code == 200 assert response.json() == { + "id": SomeUUID4Str(), "name": "random-name", "backend": "aws", "region": "us", @@ -347,6 +352,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: ) assert response.status_code == 200 assert response.json() == { + "id": SomeUUID4Str(), "backend": backend.type.value, "created_at": response.json()["created_at"], "default": True, @@ -471,6 +477,7 @@ def get_backend(project, backend_type): assert response.status_code == 200 assert response.json() == [ { + "id": str(gateway_gcp.id), "backend": backend_gcp.type.value, "created_at": response.json()[0]["created_at"], "default": False, @@ -542,6 +549,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: ) assert response.status_code == 200 assert response.json() == { + "id": SomeUUID4Str(), "backend": backend.type.value, "created_at": response.json()["created_at"], "status": "submitted", From bbcce8fd287a69981d38d894afb84ac52c1d2600 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Mon, 26 Jan 2026 20:58:23 +0100 Subject: [PATCH 2/4] Add gateway lifecycle events - Gateway created - Gateway status changed - Gateway deleted - Gateway set as default - Gateway unset as default - Gateway wildcard domain changed --- .../background/tasks/process_gateways.py | 24 +-- .../_internal/server/routers/gateways.py | 15 +- .../server/services/gateways/__init__.py | 141 ++++++++++++++---- src/dstack/_internal/server/testing/common.py | 15 +- .../background/tasks/test_process_gateways.py | 16 ++ .../_internal/server/routers/test_gateways.py | 57 ++++++- 6 files changed, 211 insertions(+), 57 deletions(-) diff --git a/src/dstack/_internal/server/background/tasks/process_gateways.py b/src/dstack/_internal/server/background/tasks/process_gateways.py index a54cb9e319..2566a4f4d8 100644 --- a/src/dstack/_internal/server/background/tasks/process_gateways.py +++ b/src/dstack/_internal/server/background/tasks/process_gateways.py @@ -14,6 +14,7 @@ GatewayConnection, create_gateway_compute, gateway_connections_pool, + switch_gateway_status, ) from dstack._internal.server.services.locking import advisory_lock_ctx, get_locker from dstack._internal.server.services.logging import fmt @@ -60,14 +61,6 @@ async def process_gateways(): logger.error( "%s: unexpected gateway status %r", fmt(gateway_model), initial_status.upper() ) - if gateway_model.status != initial_status: - logger.info( - "%s: gateway status has changed %s -> %s%s", - fmt(gateway_model), - initial_status.upper(), - gateway_model.status.upper(), - f": {gateway_model.status_message}" if gateway_model.status_message else "", - ) gateway_model.last_processed_at = get_current_datetime() await session.commit() finally: @@ -128,8 +121,8 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew project=gateway_model.project, backend_type=configuration.backend ) except BackendNotAvailable: - gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = "Backend not available" + switch_gateway_status(session, gateway_model, GatewayStatus.FAILED) return try: @@ -140,18 +133,17 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew backend_id=backend_model.id, ) session.add(gateway_model) - gateway_model.status = GatewayStatus.PROVISIONING + switch_gateway_status(session, gateway_model, GatewayStatus.PROVISIONING) except BackendError as e: - logger.info("%s: failed to create gateway compute: %r", fmt(gateway_model), e) - gateway_model.status = GatewayStatus.FAILED status_message = f"Backend error: {repr(e)}" if len(e.args) > 0: status_message = str(e.args[0]) gateway_model.status_message = status_message + switch_gateway_status(session, gateway_model, GatewayStatus.FAILED) except Exception as e: logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model)) - gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = f"Unexpected error: {repr(e)}" + switch_gateway_status(session, gateway_model, GatewayStatus.FAILED) async def _process_provisioning_gateway( @@ -179,18 +171,18 @@ async def _process_provisioning_gateway( gateway_model.gateway_compute ) if connection is None: - gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = "Failed to connect to gateway" + switch_gateway_status(session, gateway_model, GatewayStatus.FAILED) gateway_model.gateway_compute.deleted = True return try: await gateways_services.configure_gateway(connection) except Exception: logger.exception("%s: failed to configure gateway", fmt(gateway_model)) - gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = "Failed to configure gateway" + switch_gateway_status(session, gateway_model, GatewayStatus.FAILED) await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address) gateway_model.gateway_compute.active = False return - gateway_model.status = GatewayStatus.RUNNING + switch_gateway_status(session, gateway_model, GatewayStatus.RUNNING) diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py index fb03a3d69c..0f89e5db45 100644 --- a/src/dstack/_internal/server/routers/gateways.py +++ b/src/dstack/_internal/server/routers/gateways.py @@ -72,11 +72,12 @@ async def delete_gateways( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ): - _, project = user_project + user, project = user_project await gateways.delete_gateways( session=session, project=project, gateways_names=body.names, + user=user, ) @@ -86,8 +87,8 @@ async def set_default_gateway( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ): - _, project = user_project - await gateways.set_default_gateway(session=session, project=project, name=body.name) + user, project = user_project + await gateways.set_default_gateway(session=session, project=project, name=body.name, user=user) @router.post("/set_wildcard_domain", response_model=models.Gateway) @@ -96,9 +97,13 @@ async def set_gateway_wildcard_domain( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ): - _, project = user_project + user, project = user_project return CustomORJSONResponse( await gateways.set_gateway_wildcard_domain( - session=session, project=project, name=body.name, wildcard_domain=body.wildcard_domain + session=session, + project=project, + name=body.name, + wildcard_domain=body.wildcard_domain, + user=user, ) ) diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index cf41b53973..2f3447ca79 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -1,6 +1,8 @@ import asyncio import datetime import uuid +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from datetime import timedelta from functools import partial from typing import List, Optional, Sequence @@ -45,6 +47,7 @@ ProjectModel, UserModel, ) +from dstack._internal.server.services import events from dstack._internal.server.services.backends import ( check_backend_type_available, get_project_backend_by_type_or_error, @@ -66,6 +69,24 @@ logger = get_logger(__name__) +def switch_gateway_status( + session: AsyncSession, + gateway_model: GatewayModel, + new_status: GatewayStatus, + actor: events.AnyActor = events.SystemActor(), +): + old_status = gateway_model.status + if old_status == new_status: + return + + gateway_model.status = new_status + + msg = f"Gateway status changed {old_status.upper()} -> {new_status.upper()}" + if gateway_model.status_message is not None: + msg += f" ({gateway_model.status_message})" + events.emit(session, msg, actor=actor, targets=[events.Target.from_model(gateway_model)]) + + GATEWAY_CONNECT_ATTEMPTS = 30 GATEWAY_CONNECT_DELAY = 10 GATEWAY_CONFIGURE_ATTEMPTS = 50 @@ -163,6 +184,7 @@ async def create_gateway( configuration.name = await generate_gateway_name(session=session, project=project) gateway = GatewayModel( + id=uuid.uuid4(), name=configuration.name, region=configuration.region, project_id=project.id, @@ -173,11 +195,19 @@ async def create_gateway( last_processed_at=get_current_datetime(), ) session.add(gateway) + events.emit( + session, + f"Gateway created. Status: {gateway.status.upper()}", + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(gateway)], + ) await session.commit() default_gateway = await get_project_default_gateway_model(session=session, project=project) if default_gateway is None or configuration.default: - await set_default_gateway(session=session, project=project, name=configuration.name) + await set_default_gateway( + session=session, project=project, name=configuration.name, user=user + ) return gateway_model_to_gateway(gateway) @@ -214,6 +244,7 @@ async def delete_gateways( session: AsyncSession, project: ProjectModel, gateways_names: List[str], + user: UserModel, ): res = await session.execute( select(GatewayModel).where( @@ -273,46 +304,51 @@ async def delete_gateways( gateway_model.gateway_compute.deleted = True session.add(gateway_model.gateway_compute) await session.delete(gateway_model) + events.emit( + session, + "Gateway deleted", + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(gateway_model)], + ) await session.commit() async def set_gateway_wildcard_domain( - session: AsyncSession, project: ProjectModel, name: str, wildcard_domain: Optional[str] + session: AsyncSession, + project: ProjectModel, + name: str, + wildcard_domain: Optional[str], + user: UserModel, ) -> Gateway: - gateway = await get_project_gateway_model_by_name( - session=session, - project=project, - name=name, - ) - if gateway is None: - raise ResourceNotExistsError() - if gateway.backend.type == BackendType.DSTACK: - raise ServerClientError("Custom domains for dstack Sky gateway are not supported") - await session.execute( - update(GatewayModel) - .where( - GatewayModel.project_id == project.id, - GatewayModel.name == name, - ) - .values( - wildcard_domain=wildcard_domain, - ) - ) - await session.commit() - gateway = await get_project_gateway_model_by_name( - session=session, - project=project, - name=name, - ) - if gateway is None: - raise ResourceNotExistsError() + async with get_project_gateway_model_by_name_for_update( + session=session, project=project, name=name + ) as gateway: + if gateway is None: + raise ResourceNotExistsError() + if gateway.backend.type == BackendType.DSTACK: + raise ServerClientError("Custom domains for dstack Sky gateway are not supported") + old_domain = gateway.wildcard_domain + if old_domain != wildcard_domain: + gateway.wildcard_domain = wildcard_domain + events.emit( + session, + f"Gateway wildcard domain changed {old_domain!r} -> {gateway.wildcard_domain!r}", + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(gateway)], + ) + await session.commit() return gateway_model_to_gateway(gateway) -async def set_default_gateway(session: AsyncSession, project: ProjectModel, name: str): +async def set_default_gateway( + session: AsyncSession, project: ProjectModel, name: str, user: Optional[UserModel] +): gateway = await get_project_gateway_model_by_name(session=session, project=project, name=name) if gateway is None: raise ResourceNotExistsError() + if project.default_gateway_id == gateway.id: + return + previous_gateway = await get_project_default_gateway_model(session, project) await session.execute( update(ProjectModel) .where( @@ -322,6 +358,19 @@ async def set_default_gateway(session: AsyncSession, project: ProjectModel, name default_gateway_id=gateway.id, ) ) + if previous_gateway is not None: + events.emit( + session, + "Gateway unset as default", + actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(), + targets=[events.Target.from_model(previous_gateway)], + ) + events.emit( + session, + "Gateway set as default", + actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(), + targets=[events.Target.from_model(gateway)], + ) await session.commit() @@ -343,6 +392,38 @@ async def get_project_gateway_model_by_name( return res.scalar() +@asynccontextmanager +async def get_project_gateway_model_by_name_for_update( + session: AsyncSession, project: ProjectModel, name: str +) -> AsyncGenerator[Optional[GatewayModel], None]: + """ + Fetch the gateway from the database and lock it for update. + + **NOTE**: commit changes to the database before exiting from this context manager, + so that in-memory locks are only released after commit. + """ + + filters = [ + GatewayModel.project_id == project.id, + GatewayModel.name == name, + ] + res = await session.execute(select(GatewayModel.id).where(*filters)) + gateway_id = res.scalar_one_or_none() + if gateway_id is None: + yield None + else: + async with get_locker(get_db().dialect_name).lock_ctx( + GatewayModel.__tablename__, [gateway_id] + ): + # Refetch after lock + res = await session.execute( + select(GatewayModel) + .where(GatewayModel.id.in_([gateway_id]), *filters) + .with_for_update(key_share=True) + ) + yield res.scalar_one_or_none() + + async def get_project_default_gateway_model( session: AsyncSession, project: ProjectModel ) -> Optional[GatewayModel]: diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 640a3932dd..cca5212576 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -7,8 +7,9 @@ from uuid import UUID import gpuhunt -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload from dstack._internal.core.backends.base.compute import ( Compute, @@ -1114,8 +1115,16 @@ async def create_secret( async def list_events(session: AsyncSession) -> list[EventModel]: - res = await session.execute(select(EventModel).order_by(EventModel.recorded_at, EventModel.id)) - return list(res.scalars().all()) + res = await session.execute( + select(EventModel) + .order_by(EventModel.recorded_at, EventModel.id) + .options(joinedload(EventModel.targets)) + ) + return list(res.scalars().unique().all()) + + +async def clear_events(session: AsyncSession) -> None: + await session.execute(delete(EventModel)) def get_private_key_string() -> str: diff --git a/src/tests/_internal/server/background/tasks/test_process_gateways.py b/src/tests/_internal/server/background/tasks/test_process_gateways.py index 3460f18cb9..b280b8948d 100644 --- a/src/tests/_internal/server/background/tasks/test_process_gateways.py +++ b/src/tests/_internal/server/background/tasks/test_process_gateways.py @@ -13,6 +13,7 @@ create_gateway, create_gateway_compute, create_project, + list_events, ) @@ -46,6 +47,9 @@ async def test_submitted_to_provisioning(self, test_db, session: AsyncSession): assert gateway.status == GatewayStatus.PROVISIONING assert gateway.gateway_compute is not None assert gateway.gateway_compute.ip_address == "2.2.2.2" + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway status changed SUBMITTED -> PROVISIONING" async def test_marks_gateway_as_failed_if_gateway_creation_errors( self, test_db, session: AsyncSession @@ -71,6 +75,9 @@ async def test_marks_gateway_as_failed_if_gateway_creation_errors( await session.refresh(gateway) assert gateway.status == GatewayStatus.FAILED assert gateway.status_message == "Some error" + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway status changed SUBMITTED -> FAILED (Some error)" @pytest.mark.asyncio @@ -96,6 +103,9 @@ async def test_provisioning_to_running(self, test_db, session: AsyncSession): pool_add.assert_called_once() await session.refresh(gateway) assert gateway.status == GatewayStatus.RUNNING + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway status changed PROVISIONING -> RUNNING" async def test_marks_gateway_as_failed_if_fails_to_connect( self, test_db, session: AsyncSession @@ -119,3 +129,9 @@ async def test_marks_gateway_as_failed_if_fails_to_connect( await session.refresh(gateway) assert gateway.status == GatewayStatus.FAILED assert gateway.status_message == "Failed to connect to gateway" + events = await list_events(session) + assert len(events) == 1 + assert ( + events[0].message + == "Gateway status changed PROVISIONING -> FAILED (Failed to connect to gateway)" + ) diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 70f6b22b7e..f80537a1b1 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -10,12 +10,14 @@ from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( ComputeMockSpec, + clear_events, create_backend, create_gateway, create_gateway_compute, create_project, create_user, get_auth_headers, + list_events, ) from dstack._internal.server.testing.matchers import SomeUUID4Str @@ -218,6 +220,8 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn "tags": None, }, } + events = await list_events(session) + assert events[0].message == "Gateway created. Status: SUBMITTED" @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -273,6 +277,8 @@ async def test_create_gateway_without_name( "tags": None, }, } + events = await list_events(session) + assert events[0].message == "Gateway created. Status: SUBMITTED" @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -337,6 +343,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: project_id=project.id, backend_id=backend.id, gateway_compute_id=gateway_compute.id, + name="first_gateway", ) response = await client.post( f"/api/project/{project.name}/gateways/set_default", @@ -378,6 +385,40 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: "tags": None, }, } + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway set as default" + + second_gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + ) + second_gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=second_gateway_compute.id, + name="second_gateway", + ) + await clear_events(session) + response = await client.post( + f"/api/project/{project.name}/gateways/set_default", + json={"name": second_gateway.name}, + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + events = await list_events(session) + assert len(events) == 2 + actual_events = [(e.targets[0].entity_name, e.message) for e in events] + expected_events = [ + ("first_gateway", "Gateway unset as default"), + ("second_gateway", "Gateway set as default"), + ] + assert ( + actual_events == expected_events + # in case events are emitted exactly at the same time + or actual_events == expected_events[::-1] + ) @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -504,6 +545,10 @@ def get_backend(project, backend_type): }, } ] + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway deleted" + assert events[0].targets[0].entity_name == "gateway-aws" class TestUpdateGateway: @@ -541,10 +586,11 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: project_id=project.id, backend_id=backend.id, gateway_compute_id=gateway_compute.id, + wildcard_domain="old.example", ) response = await client.post( f"/api/project/{project.name}/gateways/set_wildcard_domain", - json={"name": gateway.name, "wildcard_domain": "test.com"}, + json={"name": gateway.name, "wildcard_domain": "new.example"}, headers=get_auth_headers(user.token), ) assert response.status_code == 200 @@ -560,7 +606,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "hostname": gateway_compute.ip_address, "name": gateway.name, "region": gateway.region, - "wildcard_domain": "test.com", + "wildcard_domain": "new.example", "configuration": { "type": "gateway", "name": gateway.name, @@ -568,13 +614,18 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "region": gateway.region, "instance_type": None, "router": None, - "domain": "test.com", + "domain": "new.example", "default": False, "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, }, } + events = await list_events(session) + assert len(events) == 1 + assert ( + events[0].message == "Gateway wildcard domain changed 'old.example' -> 'new.example'" + ) @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) From 54f389ed29d84230e6595438ccfa95c4f5845262 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Mon, 26 Jan 2026 22:17:09 +0100 Subject: [PATCH 3/4] Fix `test_set_wildcard_domain[postgres]` --- src/dstack/_internal/server/services/gateways/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 2f3447ca79..89a88cbae5 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -419,6 +419,7 @@ async def get_project_gateway_model_by_name_for_update( res = await session.execute( select(GatewayModel) .where(GatewayModel.id.in_([gateway_id]), *filters) + .options(selectinload(GatewayModel.gateway_compute)) .with_for_update(key_share=True) ) yield res.scalar_one_or_none() From 5c66bdcbc5c108e9067baf2299c497e6d23b2fab Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 27 Jan 2026 09:32:19 +0100 Subject: [PATCH 4/4] Avoid `selectinload` --- src/dstack/_internal/server/services/gateways/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 89a88cbae5..bff20466a8 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -419,8 +419,7 @@ async def get_project_gateway_model_by_name_for_update( res = await session.execute( select(GatewayModel) .where(GatewayModel.id.in_([gateway_id]), *filters) - .options(selectinload(GatewayModel.gateway_compute)) - .with_for_update(key_share=True) + .with_for_update(key_share=True, of=GatewayModel) ) yield res.scalar_one_or_none()