Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions src/dstack/_internal/server/background/tasks/process_gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
GatewayConnection,
create_gateway_compute,
gateway_connections_pool,
switch_gateway_status,
)
from dstack._internal.server.services.locking import advisory_lock_ctx, get_locker
from dstack._internal.server.services.logging import fmt
Expand Down Expand Up @@ -60,14 +61,6 @@ async def process_gateways():
logger.error(
"%s: unexpected gateway status %r", fmt(gateway_model), initial_status.upper()
)
if gateway_model.status != initial_status:
logger.info(
"%s: gateway status has changed %s -> %s%s",
fmt(gateway_model),
initial_status.upper(),
gateway_model.status.upper(),
f": {gateway_model.status_message}" if gateway_model.status_message else "",
)
gateway_model.last_processed_at = get_current_datetime()
await session.commit()
finally:
Expand Down Expand Up @@ -128,8 +121,8 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew
project=gateway_model.project, backend_type=configuration.backend
)
except BackendNotAvailable:
gateway_model.status = GatewayStatus.FAILED
gateway_model.status_message = "Backend not available"
switch_gateway_status(session, gateway_model, GatewayStatus.FAILED)
return

try:
Expand All @@ -140,18 +133,17 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew
backend_id=backend_model.id,
)
session.add(gateway_model)
gateway_model.status = GatewayStatus.PROVISIONING
switch_gateway_status(session, gateway_model, GatewayStatus.PROVISIONING)
except BackendError as e:
logger.info("%s: failed to create gateway compute: %r", fmt(gateway_model), e)
gateway_model.status = GatewayStatus.FAILED
status_message = f"Backend error: {repr(e)}"
if len(e.args) > 0:
status_message = str(e.args[0])
gateway_model.status_message = status_message
switch_gateway_status(session, gateway_model, GatewayStatus.FAILED)
except Exception as e:
logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model))
gateway_model.status = GatewayStatus.FAILED
gateway_model.status_message = f"Unexpected error: {repr(e)}"
switch_gateway_status(session, gateway_model, GatewayStatus.FAILED)


async def _process_provisioning_gateway(
Expand Down Expand Up @@ -179,18 +171,18 @@ async def _process_provisioning_gateway(
gateway_model.gateway_compute
)
if connection is None:
gateway_model.status = GatewayStatus.FAILED
gateway_model.status_message = "Failed to connect to gateway"
switch_gateway_status(session, gateway_model, GatewayStatus.FAILED)
gateway_model.gateway_compute.deleted = True
return
try:
await gateways_services.configure_gateway(connection)
except Exception:
logger.exception("%s: failed to configure gateway", fmt(gateway_model))
gateway_model.status = GatewayStatus.FAILED
gateway_model.status_message = "Failed to configure gateway"
switch_gateway_status(session, gateway_model, GatewayStatus.FAILED)
await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address)
gateway_model.gateway_compute.active = False
return

gateway_model.status = GatewayStatus.RUNNING
switch_gateway_status(session, gateway_model, GatewayStatus.RUNNING)
15 changes: 10 additions & 5 deletions src/dstack/_internal/server/routers/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ async def delete_gateways(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
):
_, project = user_project
user, project = user_project
await gateways.delete_gateways(
session=session,
project=project,
gateways_names=body.names,
user=user,
)


Expand All @@ -86,8 +87,8 @@ async def set_default_gateway(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
):
_, project = user_project
await gateways.set_default_gateway(session=session, project=project, name=body.name)
user, project = user_project
await gateways.set_default_gateway(session=session, project=project, name=body.name, user=user)


@router.post("/set_wildcard_domain", response_model=models.Gateway)
Expand All @@ -96,9 +97,13 @@ async def set_gateway_wildcard_domain(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
):
_, project = user_project
user, project = user_project
return CustomORJSONResponse(
await gateways.set_gateway_wildcard_domain(
session=session, project=project, name=body.name, wildcard_domain=body.wildcard_domain
session=session,
project=project,
name=body.name,
wildcard_domain=body.wildcard_domain,
user=user,
)
)
141 changes: 111 additions & 30 deletions src/dstack/_internal/server/services/gateways/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import datetime
import uuid
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import timedelta
from functools import partial
from typing import List, Optional, Sequence
Expand Down Expand Up @@ -45,6 +47,7 @@
ProjectModel,
UserModel,
)
from dstack._internal.server.services import events
from dstack._internal.server.services.backends import (
check_backend_type_available,
get_project_backend_by_type_or_error,
Expand All @@ -66,6 +69,24 @@
logger = get_logger(__name__)


def switch_gateway_status(
session: AsyncSession,
gateway_model: GatewayModel,
new_status: GatewayStatus,
actor: events.AnyActor = events.SystemActor(),
):
old_status = gateway_model.status
if old_status == new_status:
return

gateway_model.status = new_status

msg = f"Gateway status changed {old_status.upper()} -> {new_status.upper()}"
if gateway_model.status_message is not None:
msg += f" ({gateway_model.status_message})"
events.emit(session, msg, actor=actor, targets=[events.Target.from_model(gateway_model)])


GATEWAY_CONNECT_ATTEMPTS = 30
GATEWAY_CONNECT_DELAY = 10
GATEWAY_CONFIGURE_ATTEMPTS = 50
Expand Down Expand Up @@ -163,6 +184,7 @@ async def create_gateway(
configuration.name = await generate_gateway_name(session=session, project=project)

gateway = GatewayModel(
id=uuid.uuid4(),
name=configuration.name,
region=configuration.region,
project_id=project.id,
Expand All @@ -173,11 +195,19 @@ async def create_gateway(
last_processed_at=get_current_datetime(),
)
session.add(gateway)
events.emit(
session,
f"Gateway created. Status: {gateway.status.upper()}",
actor=events.UserActor.from_user(user),
targets=[events.Target.from_model(gateway)],
)
await session.commit()

default_gateway = await get_project_default_gateway_model(session=session, project=project)
if default_gateway is None or configuration.default:
await set_default_gateway(session=session, project=project, name=configuration.name)
await set_default_gateway(
session=session, project=project, name=configuration.name, user=user
)
return gateway_model_to_gateway(gateway)


Expand Down Expand Up @@ -214,6 +244,7 @@ async def delete_gateways(
session: AsyncSession,
project: ProjectModel,
gateways_names: List[str],
user: UserModel,
):
res = await session.execute(
select(GatewayModel).where(
Expand Down Expand Up @@ -273,46 +304,51 @@ async def delete_gateways(
gateway_model.gateway_compute.deleted = True
session.add(gateway_model.gateway_compute)
await session.delete(gateway_model)
events.emit(
session,
"Gateway deleted",
actor=events.UserActor.from_user(user),
targets=[events.Target.from_model(gateway_model)],
)
await session.commit()


async def set_gateway_wildcard_domain(
session: AsyncSession, project: ProjectModel, name: str, wildcard_domain: Optional[str]
session: AsyncSession,
project: ProjectModel,
name: str,
wildcard_domain: Optional[str],
user: UserModel,
) -> Gateway:
gateway = await get_project_gateway_model_by_name(
session=session,
project=project,
name=name,
)
if gateway is None:
raise ResourceNotExistsError()
if gateway.backend.type == BackendType.DSTACK:
raise ServerClientError("Custom domains for dstack Sky gateway are not supported")
await session.execute(
update(GatewayModel)
.where(
GatewayModel.project_id == project.id,
GatewayModel.name == name,
)
.values(
wildcard_domain=wildcard_domain,
)
)
await session.commit()
gateway = await get_project_gateway_model_by_name(
session=session,
project=project,
name=name,
)
if gateway is None:
raise ResourceNotExistsError()
async with get_project_gateway_model_by_name_for_update(
session=session, project=project, name=name
) as gateway:
if gateway is None:
raise ResourceNotExistsError()
if gateway.backend.type == BackendType.DSTACK:
raise ServerClientError("Custom domains for dstack Sky gateway are not supported")
old_domain = gateway.wildcard_domain
if old_domain != wildcard_domain:
gateway.wildcard_domain = wildcard_domain
events.emit(
session,
f"Gateway wildcard domain changed {old_domain!r} -> {gateway.wildcard_domain!r}",
actor=events.UserActor.from_user(user),
targets=[events.Target.from_model(gateway)],
)
await session.commit()
return gateway_model_to_gateway(gateway)


async def set_default_gateway(session: AsyncSession, project: ProjectModel, name: str):
async def set_default_gateway(
session: AsyncSession, project: ProjectModel, name: str, user: Optional[UserModel]
):
gateway = await get_project_gateway_model_by_name(session=session, project=project, name=name)
if gateway is None:
raise ResourceNotExistsError()
if project.default_gateway_id == gateway.id:
return
previous_gateway = await get_project_default_gateway_model(session, project)
await session.execute(
update(ProjectModel)
.where(
Expand All @@ -322,6 +358,19 @@ async def set_default_gateway(session: AsyncSession, project: ProjectModel, name
default_gateway_id=gateway.id,
)
)
if previous_gateway is not None:
events.emit(
session,
"Gateway unset as default",
actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(),
targets=[events.Target.from_model(previous_gateway)],
)
events.emit(
session,
"Gateway set as default",
actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(),
targets=[events.Target.from_model(gateway)],
)
await session.commit()


Expand All @@ -343,6 +392,38 @@ async def get_project_gateway_model_by_name(
return res.scalar()


@asynccontextmanager
async def get_project_gateway_model_by_name_for_update(
session: AsyncSession, project: ProjectModel, name: str
) -> AsyncGenerator[Optional[GatewayModel], None]:
"""
Fetch the gateway from the database and lock it for update.

**NOTE**: commit changes to the database before exiting from this context manager,
so that in-memory locks are only released after commit.
"""

filters = [
GatewayModel.project_id == project.id,
GatewayModel.name == name,
]
res = await session.execute(select(GatewayModel.id).where(*filters))
gateway_id = res.scalar_one_or_none()
if gateway_id is None:
yield None
else:
async with get_locker(get_db().dialect_name).lock_ctx(
GatewayModel.__tablename__, [gateway_id]
):
# Refetch after lock
res = await session.execute(
select(GatewayModel)
.where(GatewayModel.id.in_([gateway_id]), *filters)
.with_for_update(key_share=True, of=GatewayModel)
)
yield res.scalar_one_or_none()


async def get_project_default_gateway_model(
session: AsyncSession, project: ProjectModel
) -> Optional[GatewayModel]:
Expand Down
15 changes: 12 additions & 3 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from uuid import UUID

import gpuhunt
from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from dstack._internal.core.backends.base.compute import (
Compute,
Expand Down Expand Up @@ -1114,8 +1115,16 @@ async def create_secret(


async def list_events(session: AsyncSession) -> list[EventModel]:
res = await session.execute(select(EventModel).order_by(EventModel.recorded_at, EventModel.id))
return list(res.scalars().all())
res = await session.execute(
select(EventModel)
.order_by(EventModel.recorded_at, EventModel.id)
.options(joinedload(EventModel.targets))
)
return list(res.scalars().unique().all())


async def clear_events(session: AsyncSession) -> None:
await session.execute(delete(EventModel))


def get_private_key_string() -> str:
Expand Down
Loading