diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index eb325ad156..640a3932dd 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -7,6 +7,7 @@ from uuid import UUID import gpuhunt +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.backends.base.compute import ( @@ -90,6 +91,7 @@ BackendModel, ComputeGroupModel, DecryptedString, + EventModel, FileArchiveModel, FleetModel, GatewayComputeModel, @@ -1111,6 +1113,11 @@ async def create_secret( return secret_model +async def list_events(session: AsyncSession) -> list[EventModel]: + res = await session.execute(select(EventModel).order_by(EventModel.recorded_at, EventModel.id)) + return list(res.scalars().all()) + + def get_private_key_string() -> str: return """ -----BEGIN RSA PRIVATE KEY----- diff --git a/src/tests/_internal/server/background/tasks/test_process_events.py b/src/tests/_internal/server/background/tasks/test_process_events.py index 899f2946e8..21043e0bae 100644 --- a/src/tests/_internal/server/background/tasks/test_process_events.py +++ b/src/tests/_internal/server/background/tasks/test_process_events.py @@ -3,14 +3,12 @@ import pytest from freezegun import freeze_time -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.server import settings from dstack._internal.server.background.tasks.process_events import delete_events -from dstack._internal.server.models import EventModel from dstack._internal.server.services import events -from dstack._internal.server.testing.common import create_user +from dstack._internal.server.testing.common import create_user, list_events @pytest.mark.asyncio @@ -27,8 +25,7 @@ async def test_deletes_old_events(test_db, session: AsyncSession) -> None: ) await session.commit() - res = await session.execute(select(EventModel)) - all_events = res.scalars().all() + all_events = await list_events(session) assert len(all_events) == 10 with ( @@ -37,8 +34,7 @@ async def test_deletes_old_events(test_db, session: AsyncSession) -> None: ): await delete_events() - res = await session.execute(select(EventModel).order_by(EventModel.recorded_at)) - remaining_events = res.scalars().all() + remaining_events = await list_events(session) assert len(remaining_events) == 5 assert [e.message for e in remaining_events] == [ "Event 5", diff --git a/src/tests/_internal/server/services/test_instances.py b/src/tests/_internal/server/services/test_instances.py index 9e4cb02e3a..ca6432d61e 100644 --- a/src/tests/_internal/server/services/test_instances.py +++ b/src/tests/_internal/server/services/test_instances.py @@ -1,7 +1,6 @@ import uuid import pytest -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession import dstack._internal.server.services.instances as instances_services @@ -15,13 +14,14 @@ Resources, ) from dstack._internal.core.models.profiles import Profile -from dstack._internal.server.models import EventModel, InstanceModel +from dstack._internal.server.models import InstanceModel from dstack._internal.server.testing.common import ( create_instance, create_project, create_user, get_volume, get_volume_configuration, + list_events, ) from dstack._internal.utils.common import get_current_datetime @@ -41,8 +41,7 @@ async def test_includes_termination_reason_in_event_messages_only_once( instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATING) instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATED) - res = await session.execute(select(EventModel)) - events = res.scalars().all() + events = await list_events(session) assert len(events) == 2 assert {e.message for e in events} == { "Instance status changed PENDING -> TERMINATING. Termination reason: ERROR (Some err)", @@ -63,8 +62,7 @@ async def test_includes_termination_reason_in_event_message_when_switching_direc instance.termination_reason_message = "Some err" instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATED) - res = await session.execute(select(EventModel)) - events = res.scalars().all() + events = await list_events(session) assert len(events) == 1 assert events[0].message == ( "Instance status changed PENDING -> TERMINATED. Termination reason: ERROR (Some err)"