diff --git a/src/dstack/_internal/server/background/tasks/process_gateways.py b/src/dstack/_internal/server/background/tasks/process_gateways.py index a54cb9e31..2566a4f4d 100644 --- a/src/dstack/_internal/server/background/tasks/process_gateways.py +++ b/src/dstack/_internal/server/background/tasks/process_gateways.py @@ -14,6 +14,7 @@ GatewayConnection, create_gateway_compute, gateway_connections_pool, + switch_gateway_status, ) from dstack._internal.server.services.locking import advisory_lock_ctx, get_locker from dstack._internal.server.services.logging import fmt @@ -60,14 +61,6 @@ async def process_gateways(): logger.error( "%s: unexpected gateway status %r", fmt(gateway_model), initial_status.upper() ) - if gateway_model.status != initial_status: - logger.info( - "%s: gateway status has changed %s -> %s%s", - fmt(gateway_model), - initial_status.upper(), - gateway_model.status.upper(), - f": {gateway_model.status_message}" if gateway_model.status_message else "", - ) gateway_model.last_processed_at = get_current_datetime() await session.commit() finally: @@ -128,8 +121,8 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew project=gateway_model.project, backend_type=configuration.backend ) except BackendNotAvailable: - gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = "Backend not available" + switch_gateway_status(session, gateway_model, GatewayStatus.FAILED) return try: @@ -140,18 +133,17 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew backend_id=backend_model.id, ) session.add(gateway_model) - gateway_model.status = GatewayStatus.PROVISIONING + switch_gateway_status(session, gateway_model, GatewayStatus.PROVISIONING) except BackendError as e: - logger.info("%s: failed to create gateway compute: %r", fmt(gateway_model), e) - gateway_model.status = GatewayStatus.FAILED status_message = f"Backend error: {repr(e)}" if len(e.args) > 0: status_message = str(e.args[0]) gateway_model.status_message = status_message + switch_gateway_status(session, gateway_model, GatewayStatus.FAILED) except Exception as e: logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model)) - gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = f"Unexpected error: {repr(e)}" + switch_gateway_status(session, gateway_model, GatewayStatus.FAILED) async def _process_provisioning_gateway( @@ -179,18 +171,18 @@ async def _process_provisioning_gateway( gateway_model.gateway_compute ) if connection is None: - gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = "Failed to connect to gateway" + switch_gateway_status(session, gateway_model, GatewayStatus.FAILED) gateway_model.gateway_compute.deleted = True return try: await gateways_services.configure_gateway(connection) except Exception: logger.exception("%s: failed to configure gateway", fmt(gateway_model)) - gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = "Failed to configure gateway" + switch_gateway_status(session, gateway_model, GatewayStatus.FAILED) await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address) gateway_model.gateway_compute.active = False return - gateway_model.status = GatewayStatus.RUNNING + switch_gateway_status(session, gateway_model, GatewayStatus.RUNNING) diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py index fb03a3d69..0f89e5db4 100644 --- a/src/dstack/_internal/server/routers/gateways.py +++ b/src/dstack/_internal/server/routers/gateways.py @@ -72,11 +72,12 @@ async def delete_gateways( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ): - _, project = user_project + user, project = user_project await gateways.delete_gateways( session=session, project=project, gateways_names=body.names, + user=user, ) @@ -86,8 +87,8 @@ async def set_default_gateway( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ): - _, project = user_project - await gateways.set_default_gateway(session=session, project=project, name=body.name) + user, project = user_project + await gateways.set_default_gateway(session=session, project=project, name=body.name, user=user) @router.post("/set_wildcard_domain", response_model=models.Gateway) @@ -96,9 +97,13 @@ async def set_gateway_wildcard_domain( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), ): - _, project = user_project + user, project = user_project return CustomORJSONResponse( await gateways.set_gateway_wildcard_domain( - session=session, project=project, name=body.name, wildcard_domain=body.wildcard_domain + session=session, + project=project, + name=body.name, + wildcard_domain=body.wildcard_domain, + user=user, ) ) diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index cf41b5397..bff20466a 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -1,6 +1,8 @@ import asyncio import datetime import uuid +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from datetime import timedelta from functools import partial from typing import List, Optional, Sequence @@ -45,6 +47,7 @@ ProjectModel, UserModel, ) +from dstack._internal.server.services import events from dstack._internal.server.services.backends import ( check_backend_type_available, get_project_backend_by_type_or_error, @@ -66,6 +69,24 @@ logger = get_logger(__name__) +def switch_gateway_status( + session: AsyncSession, + gateway_model: GatewayModel, + new_status: GatewayStatus, + actor: events.AnyActor = events.SystemActor(), +): + old_status = gateway_model.status + if old_status == new_status: + return + + gateway_model.status = new_status + + msg = f"Gateway status changed {old_status.upper()} -> {new_status.upper()}" + if gateway_model.status_message is not None: + msg += f" ({gateway_model.status_message})" + events.emit(session, msg, actor=actor, targets=[events.Target.from_model(gateway_model)]) + + GATEWAY_CONNECT_ATTEMPTS = 30 GATEWAY_CONNECT_DELAY = 10 GATEWAY_CONFIGURE_ATTEMPTS = 50 @@ -163,6 +184,7 @@ async def create_gateway( configuration.name = await generate_gateway_name(session=session, project=project) gateway = GatewayModel( + id=uuid.uuid4(), name=configuration.name, region=configuration.region, project_id=project.id, @@ -173,11 +195,19 @@ async def create_gateway( last_processed_at=get_current_datetime(), ) session.add(gateway) + events.emit( + session, + f"Gateway created. Status: {gateway.status.upper()}", + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(gateway)], + ) await session.commit() default_gateway = await get_project_default_gateway_model(session=session, project=project) if default_gateway is None or configuration.default: - await set_default_gateway(session=session, project=project, name=configuration.name) + await set_default_gateway( + session=session, project=project, name=configuration.name, user=user + ) return gateway_model_to_gateway(gateway) @@ -214,6 +244,7 @@ async def delete_gateways( session: AsyncSession, project: ProjectModel, gateways_names: List[str], + user: UserModel, ): res = await session.execute( select(GatewayModel).where( @@ -273,46 +304,51 @@ async def delete_gateways( gateway_model.gateway_compute.deleted = True session.add(gateway_model.gateway_compute) await session.delete(gateway_model) + events.emit( + session, + "Gateway deleted", + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(gateway_model)], + ) await session.commit() async def set_gateway_wildcard_domain( - session: AsyncSession, project: ProjectModel, name: str, wildcard_domain: Optional[str] + session: AsyncSession, + project: ProjectModel, + name: str, + wildcard_domain: Optional[str], + user: UserModel, ) -> Gateway: - gateway = await get_project_gateway_model_by_name( - session=session, - project=project, - name=name, - ) - if gateway is None: - raise ResourceNotExistsError() - if gateway.backend.type == BackendType.DSTACK: - raise ServerClientError("Custom domains for dstack Sky gateway are not supported") - await session.execute( - update(GatewayModel) - .where( - GatewayModel.project_id == project.id, - GatewayModel.name == name, - ) - .values( - wildcard_domain=wildcard_domain, - ) - ) - await session.commit() - gateway = await get_project_gateway_model_by_name( - session=session, - project=project, - name=name, - ) - if gateway is None: - raise ResourceNotExistsError() + async with get_project_gateway_model_by_name_for_update( + session=session, project=project, name=name + ) as gateway: + if gateway is None: + raise ResourceNotExistsError() + if gateway.backend.type == BackendType.DSTACK: + raise ServerClientError("Custom domains for dstack Sky gateway are not supported") + old_domain = gateway.wildcard_domain + if old_domain != wildcard_domain: + gateway.wildcard_domain = wildcard_domain + events.emit( + session, + f"Gateway wildcard domain changed {old_domain!r} -> {gateway.wildcard_domain!r}", + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(gateway)], + ) + await session.commit() return gateway_model_to_gateway(gateway) -async def set_default_gateway(session: AsyncSession, project: ProjectModel, name: str): +async def set_default_gateway( + session: AsyncSession, project: ProjectModel, name: str, user: Optional[UserModel] +): gateway = await get_project_gateway_model_by_name(session=session, project=project, name=name) if gateway is None: raise ResourceNotExistsError() + if project.default_gateway_id == gateway.id: + return + previous_gateway = await get_project_default_gateway_model(session, project) await session.execute( update(ProjectModel) .where( @@ -322,6 +358,19 @@ async def set_default_gateway(session: AsyncSession, project: ProjectModel, name default_gateway_id=gateway.id, ) ) + if previous_gateway is not None: + events.emit( + session, + "Gateway unset as default", + actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(), + targets=[events.Target.from_model(previous_gateway)], + ) + events.emit( + session, + "Gateway set as default", + actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(), + targets=[events.Target.from_model(gateway)], + ) await session.commit() @@ -343,6 +392,38 @@ async def get_project_gateway_model_by_name( return res.scalar() +@asynccontextmanager +async def get_project_gateway_model_by_name_for_update( + session: AsyncSession, project: ProjectModel, name: str +) -> AsyncGenerator[Optional[GatewayModel], None]: + """ + Fetch the gateway from the database and lock it for update. + + **NOTE**: commit changes to the database before exiting from this context manager, + so that in-memory locks are only released after commit. + """ + + filters = [ + GatewayModel.project_id == project.id, + GatewayModel.name == name, + ] + res = await session.execute(select(GatewayModel.id).where(*filters)) + gateway_id = res.scalar_one_or_none() + if gateway_id is None: + yield None + else: + async with get_locker(get_db().dialect_name).lock_ctx( + GatewayModel.__tablename__, [gateway_id] + ): + # Refetch after lock + res = await session.execute( + select(GatewayModel) + .where(GatewayModel.id.in_([gateway_id]), *filters) + .with_for_update(key_share=True, of=GatewayModel) + ) + yield res.scalar_one_or_none() + + async def get_project_default_gateway_model( session: AsyncSession, project: ProjectModel ) -> Optional[GatewayModel]: diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 640a3932d..cca521257 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -7,8 +7,9 @@ from uuid import UUID import gpuhunt -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload from dstack._internal.core.backends.base.compute import ( Compute, @@ -1114,8 +1115,16 @@ async def create_secret( async def list_events(session: AsyncSession) -> list[EventModel]: - res = await session.execute(select(EventModel).order_by(EventModel.recorded_at, EventModel.id)) - return list(res.scalars().all()) + res = await session.execute( + select(EventModel) + .order_by(EventModel.recorded_at, EventModel.id) + .options(joinedload(EventModel.targets)) + ) + return list(res.scalars().unique().all()) + + +async def clear_events(session: AsyncSession) -> None: + await session.execute(delete(EventModel)) def get_private_key_string() -> str: diff --git a/src/tests/_internal/server/background/tasks/test_process_gateways.py b/src/tests/_internal/server/background/tasks/test_process_gateways.py index 3460f18cb..b280b8948 100644 --- a/src/tests/_internal/server/background/tasks/test_process_gateways.py +++ b/src/tests/_internal/server/background/tasks/test_process_gateways.py @@ -13,6 +13,7 @@ create_gateway, create_gateway_compute, create_project, + list_events, ) @@ -46,6 +47,9 @@ async def test_submitted_to_provisioning(self, test_db, session: AsyncSession): assert gateway.status == GatewayStatus.PROVISIONING assert gateway.gateway_compute is not None assert gateway.gateway_compute.ip_address == "2.2.2.2" + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway status changed SUBMITTED -> PROVISIONING" async def test_marks_gateway_as_failed_if_gateway_creation_errors( self, test_db, session: AsyncSession @@ -71,6 +75,9 @@ async def test_marks_gateway_as_failed_if_gateway_creation_errors( await session.refresh(gateway) assert gateway.status == GatewayStatus.FAILED assert gateway.status_message == "Some error" + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway status changed SUBMITTED -> FAILED (Some error)" @pytest.mark.asyncio @@ -96,6 +103,9 @@ async def test_provisioning_to_running(self, test_db, session: AsyncSession): pool_add.assert_called_once() await session.refresh(gateway) assert gateway.status == GatewayStatus.RUNNING + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway status changed PROVISIONING -> RUNNING" async def test_marks_gateway_as_failed_if_fails_to_connect( self, test_db, session: AsyncSession @@ -119,3 +129,9 @@ async def test_marks_gateway_as_failed_if_fails_to_connect( await session.refresh(gateway) assert gateway.status == GatewayStatus.FAILED assert gateway.status_message == "Failed to connect to gateway" + events = await list_events(session) + assert len(events) == 1 + assert ( + events[0].message + == "Gateway status changed PROVISIONING -> FAILED (Failed to connect to gateway)" + ) diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 70f6b22b7..f80537a1b 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -10,12 +10,14 @@ from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( ComputeMockSpec, + clear_events, create_backend, create_gateway, create_gateway_compute, create_project, create_user, get_auth_headers, + list_events, ) from dstack._internal.server.testing.matchers import SomeUUID4Str @@ -218,6 +220,8 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn "tags": None, }, } + events = await list_events(session) + assert events[0].message == "Gateway created. Status: SUBMITTED" @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -273,6 +277,8 @@ async def test_create_gateway_without_name( "tags": None, }, } + events = await list_events(session) + assert events[0].message == "Gateway created. Status: SUBMITTED" @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -337,6 +343,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: project_id=project.id, backend_id=backend.id, gateway_compute_id=gateway_compute.id, + name="first_gateway", ) response = await client.post( f"/api/project/{project.name}/gateways/set_default", @@ -378,6 +385,40 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: "tags": None, }, } + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway set as default" + + second_gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + ) + second_gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=second_gateway_compute.id, + name="second_gateway", + ) + await clear_events(session) + response = await client.post( + f"/api/project/{project.name}/gateways/set_default", + json={"name": second_gateway.name}, + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + events = await list_events(session) + assert len(events) == 2 + actual_events = [(e.targets[0].entity_name, e.message) for e in events] + expected_events = [ + ("first_gateway", "Gateway unset as default"), + ("second_gateway", "Gateway set as default"), + ] + assert ( + actual_events == expected_events + # in case events are emitted exactly at the same time + or actual_events == expected_events[::-1] + ) @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -504,6 +545,10 @@ def get_backend(project, backend_type): }, } ] + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway deleted" + assert events[0].targets[0].entity_name == "gateway-aws" class TestUpdateGateway: @@ -541,10 +586,11 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: project_id=project.id, backend_id=backend.id, gateway_compute_id=gateway_compute.id, + wildcard_domain="old.example", ) response = await client.post( f"/api/project/{project.name}/gateways/set_wildcard_domain", - json={"name": gateway.name, "wildcard_domain": "test.com"}, + json={"name": gateway.name, "wildcard_domain": "new.example"}, headers=get_auth_headers(user.token), ) assert response.status_code == 200 @@ -560,7 +606,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "hostname": gateway_compute.ip_address, "name": gateway.name, "region": gateway.region, - "wildcard_domain": "test.com", + "wildcard_domain": "new.example", "configuration": { "type": "gateway", "name": gateway.name, @@ -568,13 +614,18 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "region": gateway.region, "instance_type": None, "router": None, - "domain": "test.com", + "domain": "new.example", "default": False, "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, }, } + events = await list_events(session) + assert len(events) == 1 + assert ( + events[0].message == "Gateway wildcard domain changed 'old.example' -> 'new.example'" + ) @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)