Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 143 additions & 22 deletions src/dstack/_internal/cli/commands/login.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,34 @@
import argparse
import queue
import sys
import threading
import urllib.parse
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Optional
from typing import Any, Optional

import questionary
from rich.prompt import Prompt as RichPrompt
from rich.text import Text

from dstack._internal.cli.commands import BaseCommand
from dstack._internal.cli.commands.project import select_default_project
from dstack._internal.cli.utils.common import console, resolve_url
from dstack._internal.core.errors import ClientError, CLIError
from dstack._internal.core.models.users import UserWithCreds
from dstack._internal.utils.logging import get_logger
from dstack.api._public.runs import ConfigManager
from dstack.api.server import APIClient

logger = get_logger(__name__)

is_project_menu_supported = sys.stdin.isatty()


class UrlPrompt(RichPrompt):
def render_default(self, default: Any) -> Text:
return Text(f"({default})", style="bold orange1")


class LoginCommand(BaseCommand):
NAME = "login"
Expand All @@ -23,7 +39,7 @@ def _register(self):
self._parser.add_argument(
"--url",
help="The server URL, e.g. https://sky.dstack.ai",
required=True,
required=not is_project_menu_supported,
)
self._parser.add_argument(
"-p",
Expand All @@ -33,10 +49,25 @@ def _register(self):
" Selected automatically if the server supports only one provider."
),
)
self._parser.add_argument(
"-y",
"--yes",
help="Don't ask for confirmation (e.g. set first project as default)",
action="store_true",
)
self._parser.add_argument(
"-n",
"--no",
help="Don't ask for confirmation (e.g. do not change default project)",
action="store_true",
)

def _command(self, args: argparse.Namespace):
super()._command(args)
base_url = _normalize_url_or_error(args.url)
url = args.url
if url is None:
url = self._prompt_url()
base_url = _normalize_url_or_error(url)
api_client = APIClient(base_url=base_url)
provider = self._select_provider_or_error(api_client=api_client, provider=args.provider)
server = _LoginServer(api_client=api_client, provider=provider)
Expand All @@ -56,9 +87,9 @@ def _command(self, args: argparse.Namespace):
server.shutdown()
if user is None:
raise CLIError("CLI authentication failed")
console.print(f"Logged in as [code]{user.username}[/].")
console.print(f"Logged in as [code]{user.username}[/]")
api_client = APIClient(base_url=base_url, token=user.creds.token)
self._configure_projects(api_client=api_client, user=user)
self._configure_projects(api_client=api_client, user=user, args=args)

def _select_provider_or_error(self, api_client: APIClient, provider: Optional[str]) -> str:
providers = api_client.auth.list_providers()
Expand All @@ -67,6 +98,8 @@ def _select_provider_or_error(self, api_client: APIClient, provider: Optional[st
raise CLIError("No SSO providers configured on the server.")
if provider is None:
if len(available_providers) > 1:
if is_project_menu_supported:
return self._prompt_provider(available_providers)
raise CLIError(
"Specify -p/--provider to choose SSO provider"
f" Available providers: {', '.join(available_providers)}"
Expand All @@ -79,7 +112,37 @@ def _select_provider_or_error(self, api_client: APIClient, provider: Optional[st
)
return provider

def _configure_projects(self, api_client: APIClient, user: UserWithCreds):
def _prompt_url(self) -> str:
try:
url = UrlPrompt.ask(
"Enter the server URL",
default="https://sky.dstack.ai",
console=console,
)
except KeyboardInterrupt:
console.print("\nCancelled by user")
raise SystemExit(1)
if url is None:
raise CLIError("URL is required")
return url

def _prompt_provider(self, available_providers: list[str]) -> str:
choices = [
questionary.Choice(title=provider, value=provider) for provider in available_providers
]
selected_provider = questionary.select(
message="Select SSO provider:",
choices=choices,
qmark="",
instruction="(↑↓ Enter)",
).ask()
if selected_provider is None:
raise SystemExit(1)
return selected_provider

def _configure_projects(
self, api_client: APIClient, user: UserWithCreds, args: argparse.Namespace
):
projects = api_client.projects.list(include_not_joined=False)
if len(projects) == 0:
console.print(
Expand All @@ -89,30 +152,88 @@ def _configure_projects(self, api_client: APIClient, user: UserWithCreds):
return
config_manager = ConfigManager()
default_project = config_manager.get_project_config()
new_default_project = None
for i, project in enumerate(projects):
set_as_default = (
default_project is None
and i == 0
or default_project is not None
and default_project.name == project.project_name
)
if set_as_default:
new_default_project = project
for project in projects:
config_manager.configure_project(
name=project.project_name,
url=api_client.base_url,
token=user.creds.token,
default=set_as_default,
default=False,
)
config_manager.save()
project_names = ", ".join(f"[code]{p.project_name}[/]" for p in projects)
console.print(
f"Configured projects: {', '.join(f'[code]{p.project_name}[/]' for p in projects)}."
f"Added {project_names} project{'' if len(projects) == 1 else 's'} at {config_manager.config_filepath}"
)
if new_default_project:
console.print(
f"Set project [code]{new_default_project.project_name}[/] as default project."
)

project_configs = config_manager.list_project_configs()

if args.no:
return

if args.yes:
if len(projects) > 0:
first_project_from_server = projects[0]
first_project_config = next(
(
pc
for pc in project_configs
if pc.name == first_project_from_server.project_name
),
None,
)
if first_project_config is not None:
config_manager.configure_project(
name=first_project_config.name,
url=first_project_config.url,
token=first_project_config.token,
default=True,
)
config_manager.save()
console.print(
f"Set [code]{first_project_config.name}[/] project as default at {config_manager.config_filepath}"
)
return

if len(project_configs) == 1 or not is_project_menu_supported:
selected_project = None
if len(project_configs) == 1:
selected_project = project_configs[0]
else:
for i, project in enumerate(projects):
set_as_default = (
default_project is None
and i == 0
or default_project is not None
and default_project.name == project.project_name
)
if set_as_default:
selected_project = next(
(pc for pc in project_configs if pc.name == project.project_name),
None,
)
break
if selected_project is not None:
config_manager.configure_project(
name=selected_project.name,
url=selected_project.url,
token=selected_project.token,
default=True,
)
config_manager.save()
console.print(
f"Set [code]{selected_project.name}[/] project as default at {config_manager.config_filepath}"
)
else:
console.print()
selected_project = select_default_project(project_configs, default_project)
if selected_project is not None:
config_manager.configure_project(
name=selected_project.name,
url=selected_project.url,
token=selected_project.token,
default=True,
)
config_manager.save()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing confirmation message after interactive project selection

Low Severity

The interactive project selection path (the else branch) successfully configures and saves the selected default project but does not print a confirmation message. All other code paths that set a default project (the --yes flag at lines 193-195 and the auto-default case at lines 224-226) print a "Set [project] project as default at [path]" message. This inconsistency means users who interactively select a project receive no feedback that their selection was applied.

Fix in Cursor Fix in Web



class _BadRequestError(Exception):
Expand Down
14 changes: 5 additions & 9 deletions src/dstack/_internal/cli/commands/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,10 @@
import sys
from typing import Any, Optional, Union

import questionary
from requests import HTTPError
from rich.table import Table

try:
import questionary

is_project_menu_supported = sys.stdin.isatty()
except (ImportError, NotImplementedError, AttributeError):
is_project_menu_supported = False

import dstack.api.server
from dstack._internal.cli.commands import BaseCommand
from dstack._internal.cli.utils.common import add_row_from_dict, confirm_ask, console
Expand All @@ -22,6 +16,8 @@

logger = get_logger(__name__)

is_project_menu_supported = sys.stdin.isatty()


def select_default_project(
project_configs: list[ProjectConfig], default_project: Optional[ProjectConfig]
Expand Down Expand Up @@ -57,9 +53,9 @@ def select_default_project(
default_index = i
menu_entries.append((entry, i))

choices = [questionary.Choice(title=entry, value=index) for entry, index in menu_entries] # pyright: ignore[reportPossiblyUnboundVariable]
choices = [questionary.Choice(title=entry, value=index) for entry, index in menu_entries]
default_value = default_index
selected_index = questionary.select( # pyright: ignore[reportPossiblyUnboundVariable]
selected_index = questionary.select(
message="Select the default project:",
choices=choices,
default=default_value, # pyright: ignore[reportArgumentType]
Expand Down
6 changes: 5 additions & 1 deletion src/dstack/_internal/cli/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ def configure_logging():

def confirm_ask(prompt, **kwargs) -> bool:
kwargs["console"] = console
return Confirm.ask(prompt=prompt, **kwargs)
try:
return Confirm.ask(prompt=prompt, **kwargs)
except KeyboardInterrupt:
console.print("\nCancelled by user")
raise SystemExit(1)


def add_row_from_dict(table: Table, data: Dict[Union[str, int], Any], **kwargs):
Expand Down
Loading