diff --git a/cmd/server/main-server.go b/cmd/server/main-server.go
index 5eb247c75c..e91e556092 100644
--- a/cmd/server/main-server.go
+++ b/cmd/server/main-server.go
@@ -28,6 +28,7 @@ import (
 	"github.com/wavetermdev/waveterm/pkg/service"
 	"github.com/wavetermdev/waveterm/pkg/telemetry"
 	"github.com/wavetermdev/waveterm/pkg/telemetry/telemetrydata"
+	"github.com/wavetermdev/waveterm/pkg/util/envutil"
 	"github.com/wavetermdev/waveterm/pkg/util/shellutil"
 	"github.com/wavetermdev/waveterm/pkg/util/sigutil"
 	"github.com/wavetermdev/waveterm/pkg/util/utilfn"
@@ -79,7 +80,7 @@ func doShutdown(reason string) {
 	log.Printf("shutting down: %s\n", reason)
 	ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancelFn()
-	go blockcontroller.StopAllBlockControllers()
+	go blockcontroller.StopAllBlockControllersForShutdown()
 	shutdownActivityUpdate()
 	sendTelemetryWrapper()
 	// TODO deal with flush in progress
@@ -162,15 +163,9 @@ func sendDiagnosticPing() bool {
 	if err != nil || !isOnline {
 		return false
 	}
-	clientData, err := wstore.DBGetSingleton[*waveobj.Client](ctx)
-	if err != nil {
-		return false
-	}
-	if clientData == nil {
-		return false
-	}
+	clientId := wstore.GetClientId()
 	usageTelemetry := telemetry.IsTelemetryEnabled()
-	wcloud.SendDiagnosticPing(ctx, clientData.OID, usageTelemetry)
+	wcloud.SendDiagnosticPing(ctx, clientId, usageTelemetry)
 	return true
 }
@@ -226,12 +221,8 @@ func sendTelemetryWrapper() {
 	ctx, cancelFn := context.WithTimeout(context.Background(), 15*time.Second)
 	defer cancelFn()
 	beforeSendActivityUpdate(ctx)
-	client, err := wstore.DBGetSingleton[*waveobj.Client](ctx)
-	if err != nil {
-		log.Printf("[error] getting client data for telemetry: %v\n", err)
-		return
-	}
-	err = wcloud.SendAllTelemetry(client.OID)
+	clientId := wstore.GetClientId()
+	err := wcloud.SendAllTelemetry(clientId)
 	if err != nil {
 		log.Printf("[error] sending telemetry: %v\n", err)
 	}
@@ -392,7 +383,10 @@ func createMainWshClient() {
 	wshfs.RpcClient = rpc
 	wshutil.DefaultRouter.RegisterTrustedLeaf(rpc, wshutil.DefaultRoute)
 	wps.Broker.SetClient(wshutil.DefaultRouter)
-	localConnWsh := wshutil.MakeWshRpc(wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, wshremote.MakeRemoteRpcServerImpl(nil, wshutil.DefaultRouter, wshclient.GetBareRpcClient(), true), "conn:local")
+	localInitialEnv := envutil.PruneInitialEnv(envutil.SliceToMap(os.Environ()))
+	sockName := wavebase.GetDomainSocketName()
+	remoteImpl := wshremote.MakeRemoteRpcServerImpl(nil, wshutil.DefaultRouter, wshclient.GetBareRpcClient(), true, localInitialEnv, sockName)
+	localConnWsh := wshutil.MakeWshRpc(wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, remoteImpl, "conn:local")
 	go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName)
 	wshutil.DefaultRouter.RegisterTrustedLeaf(localConnWsh, wshutil.MakeConnectionRouteId(wshrpc.LocalConnName))
 }
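The new MakeRemoteRpcServerImpl signature above threads a pruned copy of the server's initial environment through envutil.PruneInitialEnv(envutil.SliceToMap(os.Environ())). The envutil helpers themselves are not part of this diff; the following is a minimal sketch of the slice-to-map conversion they imply (the exact behavior and names are assumptions, not the actual implementation):

package envutil

import "strings"

// SliceToMap converts os.Environ()-style "KEY=VALUE" entries into a map.
// Sketch only; the real envutil implementation is not shown in this diff.
func SliceToMap(envSlice []string) map[string]string {
	envMap := make(map[string]string, len(envSlice))
	for _, entry := range envSlice {
		key, val, found := strings.Cut(entry, "=")
		if !found {
			continue // skip malformed entries without "="
		}
		envMap[key] = val
	}
	return envMap
}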
diff --git a/cmd/wsh/cmd/wshcmd-connserver.go b/cmd/wsh/cmd/wshcmd-connserver.go
index 6ec0d5e4d7..ea46a8cc6e 100644
--- a/cmd/wsh/cmd/wshcmd-connserver.go
+++ b/cmd/wsh/cmd/wshcmd-connserver.go
@@ -19,6 +19,7 @@ import (
 	"github.com/wavetermdev/waveterm/pkg/baseds"
 	"github.com/wavetermdev/waveterm/pkg/panichandler"
 	"github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs"
+	"github.com/wavetermdev/waveterm/pkg/util/envutil"
 	"github.com/wavetermdev/waveterm/pkg/util/packetparser"
 	"github.com/wavetermdev/waveterm/pkg/util/sigutil"
 	"github.com/wavetermdev/waveterm/pkg/wavebase"
@@ -42,6 +43,7 @@ var connServerRouterDomainSocket bool
 var connServerConnName string
 var connServerDev bool
 var ConnServerWshRouter *wshutil.WshRouter
+var connServerInitialEnv map[string]string
 
 func init() {
 	serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode (stdio upstream)")
@@ -120,18 +122,18 @@ func runListener(listener net.Listener, router *wshutil.WshRouter) {
 	}
 }
 
-func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.WshRpc, error) {
+func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter, sockName string) (*wshutil.WshRpc, error) {
 	routeId := wshutil.MakeConnectionRouteId(connServerConnName)
 	rpcCtx := wshrpc.RpcContext{
 		RouteId: routeId,
 		Conn:    connServerConnName,
 	}
-
+	
 	bareRouteId := wshutil.MakeRandomProcRouteId()
 	bareClient := wshutil.MakeWshRpc(wshrpc.RpcContext{}, &wshclient.WshServer{}, bareRouteId)
 	router.RegisterTrustedLeaf(bareClient, bareRouteId)
-
+	
 	connServerClient := wshutil.MakeWshRpc(rpcCtx, wshremote.MakeRemoteRpcServerImpl(os.Stdout, router, bareClient, false, connServerInitialEnv, sockName), routeId)
 	router.RegisterTrustedLeaf(connServerClient, routeId)
 	return connServerClient, nil
 }
@@ -170,8 +172,10 @@ func serverRunRouter() error {
 	}()
 	router.RegisterUpstream(termProxy)
 
+	sockName := getRemoteDomainSocketName()
+
 	// setup the connserver rpc client first
-	client, err := setupConnServerRpcClientWithRouter(router)
+	client, err := setupConnServerRpcClientWithRouter(router, sockName)
 	if err != nil {
 		return fmt.Errorf("error setting up connserver rpc client: %v", err)
 	}
@@ -267,15 +271,11 @@ func serverRunRouterDomainSocket(jwtToken string) error {
 	// register the domain socket connection as upstream
 	router.RegisterUpstream(upstreamProxy)
 
-	// setup the connserver rpc client (leaf)
-	client, err := setupConnServerRpcClientWithRouter(router)
-	if err != nil {
-		return fmt.Errorf("error setting up connserver rpc client: %v", err)
-	}
-	wshfs.RpcClient = client
+	// use the router's control RPC to authenticate with upstream
+	controlRpc := router.GetControlRpc()
 
 	// authenticate with the upstream router using the JWT
-	_, err = wshclient.AuthenticateCommand(client, jwtToken, &wshrpc.RpcOpts{Route: wshutil.ControlRoute})
+	_, err = wshclient.AuthenticateCommand(controlRpc, jwtToken, &wshrpc.RpcOpts{Route: wshutil.ControlRootRoute})
 	if err != nil {
 		return fmt.Errorf("error authenticating with upstream: %v", err)
 	}
@@ -283,7 +283,7 @@ func serverRunRouterDomainSocket(jwtToken string) error {
 
 	// fetch and set JWT public key
 	log.Printf("trying to get JWT public key")
-	jwtPublicKeyB64, err := wshclient.GetJwtPublicKeyCommand(client, nil)
+	jwtPublicKeyB64, err := wshclient.GetJwtPublicKeyCommand(controlRpc, nil)
 	if err != nil {
 		return fmt.Errorf("error getting jwt public key: %v", err)
 	}
@@ -297,6 +297,13 @@ func serverRunRouterDomainSocket(jwtToken string) error {
 	}
 	log.Printf("got JWT public key")
 
+	// now setup the connserver rpc client
+	client, err := setupConnServerRpcClientWithRouter(router, sockName)
+	if err != nil {
+		return fmt.Errorf("error setting up connserver rpc client: %v", err)
+	}
+	wshfs.RpcClient = client
+
 	// set up the local domain socket listener for local wsh commands
 	unixListener, err := MakeRemoteUnixListener()
 	if err != nil {
@@ -323,7 +330,11 @@ func serverRunRouterDomainSocket(jwtToken string) error {
 }
 
 func serverRunNormal(jwtToken string) error {
-	err := setupRpcClient(wshremote.MakeRemoteRpcServerImpl(os.Stdout, nil, nil, false), jwtToken)
+	sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
+	if err != nil {
+		return fmt.Errorf("error extracting socket name from JWT: %v", err)
+	}
+	err = setupRpcClient(wshremote.MakeRemoteRpcServerImpl(os.Stdout, nil, nil, false, connServerInitialEnv, sockName), jwtToken)
 	if err != nil {
 		return err
 	}
@@ -359,6 +370,8 @@ func askForJwtToken() (string, error) {
 }
 
 func serverRun(cmd *cobra.Command, args []string) error {
+	connServerInitialEnv = envutil.PruneInitialEnv(envutil.SliceToMap(os.Environ()))
+
 	var logFile *os.File
 	if connServerDev {
 		var err error
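serverRunNormal above pulls the domain socket name out of the JWT before the token is verified (verification still happens in the subsequent authenticate step). wshutil.ExtractUnverifiedSocketName is not shown in this diff; below is a sketch of that kind of unverified claim read, assuming the golang-jwt/v5 library and a "sockname" claim matching RpcContext's JSON tag — both assumptions, not the actual implementation:

package main

import (
	"fmt"

	"github.com/golang-jwt/jwt/v5"
)

// extractUnverifiedSocketName decodes the JWT payload without checking the
// signature and returns the socket-name claim. Hypothetical stand-in for
// wshutil.ExtractUnverifiedSocketName.
func extractUnverifiedSocketName(tokenStr string) (string, error) {
	claims := jwt.MapClaims{}
	// ParseUnverified decodes the payload only; the token is still verified
	// later by the AuthenticateCommand call.
	_, _, err := jwt.NewParser().ParseUnverified(tokenStr, claims)
	if err != nil {
		return "", fmt.Errorf("error parsing jwt: %w", err)
	}
	sockName, ok := claims["sockname"].(string) // claim name assumed
	if !ok || sockName == "" {
		return "", fmt.Errorf("no sockname claim in token")
	}
	return sockName, nil
}

func main() {
	name, err := extractUnverifiedSocketName("not-a-token")
	fmt.Println(name, err) // prints the parse error for a malformed token
}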
diff --git a/cmd/wsh/cmd/wshcmd-jobdebug.go b/cmd/wsh/cmd/wshcmd-jobdebug.go
index 5ae68b7051..2be81b8a96 100644
--- a/cmd/wsh/cmd/wshcmd-jobdebug.go
+++ b/cmd/wsh/cmd/wshcmd-jobdebug.go
@@ -44,10 +44,10 @@ var jobDebugPruneCmd = &cobra.Command{
 	RunE:  jobDebugPruneRun,
 }
 
-var jobDebugExitCmd = &cobra.Command{
-	Use:   "exit",
-	Short: "exit a job manager",
-	RunE:  jobDebugExitRun,
+var jobDebugTerminateCmd = &cobra.Command{
+	Use:   "terminate",
+	Short: "terminate a job manager",
+	RunE:  jobDebugTerminateRun,
 }
 
 var jobDebugDisconnectCmd = &cobra.Command{
@@ -96,7 +96,7 @@ var jobDebugDetachJobCmd = &cobra.Command{
 var jobIdFlag string
 var jobDebugJsonFlag bool
 var jobConnFlag string
-var exitJobIdFlag string
+var terminateJobIdFlag string
 var disconnectJobIdFlag string
 var reconnectJobIdFlag string
 var reconnectConnNameFlag string
@@ -110,7 +110,7 @@ func init() {
 	jobDebugCmd.AddCommand(jobDebugDeleteCmd)
 	jobDebugCmd.AddCommand(jobDebugDeleteAllCmd)
 	jobDebugCmd.AddCommand(jobDebugPruneCmd)
-	jobDebugCmd.AddCommand(jobDebugExitCmd)
+	jobDebugCmd.AddCommand(jobDebugTerminateCmd)
 	jobDebugCmd.AddCommand(jobDebugDisconnectCmd)
 	jobDebugCmd.AddCommand(jobDebugReconnectCmd)
 	jobDebugCmd.AddCommand(jobDebugReconnectConnCmd)
@@ -124,8 +124,8 @@ func init() {
 	jobDebugDeleteCmd.Flags().StringVar(&jobIdFlag, "jobid", "", "job id to delete (required)")
 	jobDebugDeleteCmd.MarkFlagRequired("jobid")
 
-	jobDebugExitCmd.Flags().StringVar(&exitJobIdFlag, "jobid", "", "job id to exit (required)")
-	jobDebugExitCmd.MarkFlagRequired("jobid")
+	jobDebugTerminateCmd.Flags().StringVar(&terminateJobIdFlag, "jobid", "", "job id to terminate (required)")
+	jobDebugTerminateCmd.MarkFlagRequired("jobid")
 
 	jobDebugDisconnectCmd.Flags().StringVar(&disconnectJobIdFlag, "jobid", "", "job id to disconnect (required)")
 	jobDebugDisconnectCmd.MarkFlagRequired("jobid")
@@ -176,12 +176,15 @@ func jobDebugListRun(cmd *cobra.Command, args []string) error {
 		return nil
 	}
 
-	fmt.Printf("%-36s %-20s %-9s %-10s %-30s %-8s %-10s\n", "OID", "Connection", "Connected", "Manager", "Cmd", "ExitCode", "Stream")
+	fmt.Printf("%-36s %-20s %-9s %-10s %-6s %-30s %-8s %-10s %-8s\n", "OID", "Connection", "Connected", "Manager", "Reason", "Cmd", "ExitCode", "Stream", "Attached")
 	for _, job := range rtnData {
 		connectedStatus := "no"
 		if connectedMap[job.OID] {
 			connectedStatus = "yes"
 		}
+		if job.TerminateOnReconnect {
+			connectedStatus += "*"
+		}
 
 		streamStatus := "-"
 		if job.StreamDone {
@@ -203,8 +206,26 @@ func jobDebugListRun(cmd *cobra.Command, args []string) error {
 			}
 		}
 
-		fmt.Printf("%-36s %-20s %-9s %-10s %-30s %-8s %-10s\n",
-			job.OID, job.Connection, connectedStatus, job.JobManagerStatus, job.Cmd, exitCode, streamStatus)
+		doneReason := "-"
+		if job.JobManagerDoneReason == "startuperror" {
+			doneReason = "serr"
+		} else if job.JobManagerDoneReason == "gone" {
+			doneReason = "gone"
+		} else if job.JobManagerDoneReason == "terminated" {
+			doneReason = "term"
+		}
+
+		attachedBlock := "-"
+		if job.AttachedBlockId != "" {
+			if len(job.AttachedBlockId) >= 8 {
+				attachedBlock = job.AttachedBlockId[:8]
+			} else {
+				attachedBlock = job.AttachedBlockId
+			}
+		}
+
+		fmt.Printf("%-36s %-20s %-9s %-10s %-6s %-30s %-8s %-10s %-8s\n",
+			job.OID, job.Connection, connectedStatus, job.JobManagerStatus, doneReason, job.Cmd, exitCode, streamStatus, attachedBlock)
 	}
 	return nil
 }
@@ -275,13 +296,13 @@ func jobDebugPruneRun(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
-func jobDebugExitRun(cmd *cobra.Command, args []string) error {
-	err := wshclient.JobControllerExitJobCommand(RpcClient, exitJobIdFlag, nil)
+func jobDebugTerminateRun(cmd *cobra.Command, args []string) error {
+	err := wshclient.JobControllerExitJobCommand(RpcClient, terminateJobIdFlag, nil)
 	if err != nil {
-		return fmt.Errorf("exiting job manager: %w", err)
+		return fmt.Errorf("terminating job manager: %w", err)
 	}
-	fmt.Printf("Job manager for %s exited successfully\n", exitJobIdFlag)
+	fmt.Printf("Job manager for %s terminated successfully\n", terminateJobIdFlag)
 	return nil
 }
diff --git a/frontend/app/store/wshclientapi.ts b/frontend/app/store/wshclientapi.ts
index 3caeb0f201..26f97e062c 100644
--- a/frontend/app/store/wshclientapi.ts
+++ b/frontend/app/store/wshclientapi.ts
@@ -97,6 +97,11 @@ class RpcApiType {
         return client.wshRpcCall("connreinstallwsh", data, opts);
     }
 
+    // command "connserverinit" [call]
+    ConnServerInitCommand(client: WshClient, data: CommandConnServerInitData, opts?: RpcOpts): Promise<void> {
+        return client.wshRpcCall("connserverinit", data, opts);
+    }
+
     // command "connstatus" [call]
     ConnStatusCommand(client: WshClient, opts?: RpcOpts): Promise<ConnStatus[]> {
         return client.wshRpcCall("connstatus", null, opts);
@@ -112,6 +117,11 @@ class RpcApiType {
         return client.wshRpcCall("controllerappendoutput", data, opts);
     }
 
+    // command "controllerdestroy" [call]
+    ControllerDestroyCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
+        return client.wshRpcCall("controllerdestroy", data, opts);
+    }
+
     // command "controllerinput" [call]
     ControllerInputCommand(client: WshClient, data: CommandBlockInputData, opts?: RpcOpts): Promise<void> {
         return client.wshRpcCall("controllerinput", data, opts);
@@ -122,11 +132,6 @@ class RpcApiType {
         return client.wshRpcCall("controllerresync", data, opts);
     }
 
-    // command "controllerstop" [call]
-    ControllerStopCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
-        return client.wshRpcCall("controllerstop", data, opts);
-    }
-
     // command "createblock" [call]
     CreateBlockCommand(client: WshClient, data: CommandCreateBlockData, opts?: RpcOpts): Promise<ORef> {
         return client.wshRpcCall("createblock", data, opts);
diff --git a/frontend/app/view/term/term.tsx b/frontend/app/view/term/term.tsx
index d1ca981c97..10fd0fb112 100644
--- a/frontend/app/view/term/term.tsx
+++ b/frontend/app/view/term/term.tsx
@@ -298,7 +298,6 @@ const TerminalView = ({ blockId, model }: ViewComponentProps<TermViewModel>) =>
             useWebGl: !termSettings?.["term:disablewebgl"],
             sendDataHandler: model.sendDataToController.bind(model),
             nodeModel: model.nodeModel,
-            jobId: blockData?.jobid,
         }
     );
     (window as any).term = termWrap;
diff --git a/frontend/app/view/term/termwrap.ts b/frontend/app/view/term/termwrap.ts
index 60743db584..37aef29839 100644
--- a/frontend/app/view/term/termwrap.ts
+++ b/frontend/app/view/term/termwrap.ts
@@ -49,7 +49,6 @@ type TermWrapOptions = {
     useWebGl?: boolean;
    sendDataHandler?: (data: string) => void;
     nodeModel?: BlockNodeModel;
-    jobId?: string;
 };
 
 // for xterm OSC handlers, we return true always because we "own" the OSC number.
@@ -375,7 +374,6 @@ function handleOsc16162Command(data: string, blockId: string, loaded: boolean, t
 export class TermWrap {
     tabId: string;
     blockId: string;
-    jobId: string;
     ptyOffset: number;
     dataBytesProcessed: number;
     terminal: Terminal;
@@ -423,7 +421,6 @@ export class TermWrap {
         this.loaded = false;
         this.tabId = tabId;
         this.blockId = blockId;
-        this.jobId = waveOptions.jobId;
         this.sendDataHandler = waveOptions.sendDataHandler;
         this.nodeModel = waveOptions.nodeModel;
         this.ptyOffset = 0;
@@ -498,7 +495,7 @@ export class TermWrap {
     }
 
     getZoneId(): string {
-        return this.jobId ?? this.blockId;
+        return this.blockId;
     }
 
     resetCompositionState() {
diff --git a/frontend/app/view/tsunami/tsunami.tsx b/frontend/app/view/tsunami/tsunami.tsx
index dbebb824b3..615d2573fd 100644
--- a/frontend/app/view/tsunami/tsunami.tsx
+++ b/frontend/app/view/tsunami/tsunami.tsx
@@ -118,9 +118,9 @@ class TsunamiViewModel extends WebViewModel {
         this.doControllerResync(false, "resync", false);
     }
 
-    stopController() {
-        const prtn = RpcApi.ControllerStopCommand(TabRpcClient, this.blockId);
-        prtn.catch((e) => console.log("error stopping controller", e));
+    destroyController() {
+        const prtn = RpcApi.ControllerDestroyCommand(TabRpcClient, this.blockId);
+        prtn.catch((e) => console.log("error destroying controller", e));
     }
 
     async restartController() {
@@ -130,7 +130,7 @@ class TsunamiViewModel extends WebViewModel {
         this.triggerRestartAtom();
         try {
             // Stop the controller first
-            await RpcApi.ControllerStopCommand(TabRpcClient, this.blockId);
+            await RpcApi.ControllerDestroyCommand(TabRpcClient, this.blockId);
             // Wait a bit for the controller to fully stop
             await new Promise((resolve) => setTimeout(resolve, 300));
             // Then resync to restart it
@@ -202,7 +202,7 @@ class TsunamiViewModel extends WebViewModel {
         const tsunamiItems: ContextMenuItem[] = [
             {
                 label: "Stop WaveApp",
-                click: () => this.stopController(),
+                click: () => this.destroyController(),
             },
             {
                 label: "Restart WaveApp",
diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts
index a4ec175c1f..d84fee4d3a 100644
--- a/frontend/types/gotypes.d.ts
+++ b/frontend/types/gotypes.d.ts
@@ -197,6 +197,7 @@ declare global {
 
     // wshrpc.CommandAuthenticateRtnData
     type CommandAuthenticateRtnData = {
+        routeid: string;
         env?: {[key: string]: string};
         initscripttext?: string;
         rpccontext?: RpcContext;
@@ -233,6 +234,11 @@ declare global {
         errorstring?: string;
     };
 
+    // wshrpc.CommandConnServerInitData
+    type CommandConnServerInitData = {
+        clientid: string;
+    };
+
     // wshrpc.CommandControllerAppendOutputData
     type CommandControllerAppendOutputData = {
         blockid: string;
@@ -387,6 +393,8 @@ declare global {
     // wshrpc.CommandJobInputData
     type CommandJobInputData = {
         jobid: string;
+        inputsessionid?: string;
+        seqnum?: number;
         inputdata64?: string;
         signame?: string;
         termsize?: TermSize;
@@ -396,6 +404,7 @@ declare global {
     type CommandJobPrepareConnectData = {
         streammeta: StreamMeta;
         seq: number;
+        termsize: TermSize;
     };
 
     // wshrpc.CommandJobStartStreamData
@@ -998,6 +1007,7 @@ declare global {
         cmd?: string;
         "cmd:interactive"?: boolean;
         "cmd:login"?: boolean;
+        "cmd:persistent"?: boolean;
         "cmd:runonstart"?: boolean;
         "cmd:clearonstart"?: boolean;
         "cmd:runonce"?: boolean;
@@ -1160,6 +1170,7 @@ declare global {
         clientos: string;
         clientversion: string;
         shell: string;
+        homedir: string;
     };
 
     // wshrpc.RestartBuilderAndWaitResult
@@ -1173,6 +1184,7 @@ declare global {
     type RpcContext = {
         sockname?: string;
         routeid: string;
+        procroute?: boolean;
         blockid?: string;
         conn?: string;
         isrouter?: boolean;
diff --git a/package-lock.json b/package-lock.json
index a016966011..27fec467a2 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -102,7 +102,6 @@
         "@types/sprintf-js": "^1",
         "@types/throttle-debounce": "^5",
         "@types/tinycolor2": "^1",
-        "@types/uuid": "^11.0.0",
         "@types/ws": "^8",
         "@vitejs/plugin-react-swc": "4.2.2",
         "@vitest/coverage-istanbul": "^3.0.9",
@@ -10049,17 +10048,6 @@
       "integrity": "sha512-zFDAD+tlpf2r4asuHEj0XH6pY6i0g5NeAHPn+15wk3BV6JA69eERFXC1gyGThDkVa1zCyKr5jox1+2LbV/AMLg==",
       "license": "MIT"
    },
-    "node_modules/@types/uuid": {
-      "version": "11.0.0",
-      "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-11.0.0.tgz",
-      "integrity": "sha512-HVyk8nj2m+jcFRNazzqyVKiZezyhDKrGUA3jlEcg/nZ6Ms+qHwocba1Y/AaVaznJTAM9xpdFSh+ptbNrhOGvZA==",
-      "deprecated": "This is a stub types definition. uuid provides its own type definitions, so you do not need this installed.",
-      "dev": true,
-      "license": "MIT",
-      "dependencies": {
-        "uuid": "*"
-      }
-    },
     "node_modules/@types/verror": {
       "version": "1.10.11",
       "resolved": "https://registry.npmjs.org/@types/verror/-/verror-1.10.11.tgz",
diff --git a/package.json b/package.json
index 38846952cd..7fe1990cf2 100644
--- a/package.json
+++ b/package.json
@@ -45,7 +45,6 @@
         "@types/sprintf-js": "^1",
         "@types/throttle-debounce": "^5",
         "@types/tinycolor2": "^1",
-        "@types/uuid": "^11.0.0",
         "@types/ws": "^8",
         "@vitejs/plugin-react-swc": "4.2.2",
         "@vitest/coverage-istanbul": "^3.0.9",
diff --git a/pkg/aiusechat/usechat.go b/pkg/aiusechat/usechat.go
index 8d8fcf6446..7e8badf258 100644
--- a/pkg/aiusechat/usechat.go
+++ b/pkg/aiusechat/usechat.go
@@ -676,17 +676,10 @@ func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// Get client ID from database
-	client, err := wstore.DBGetSingleton[*waveobj.Client](r.Context())
-	if err != nil {
-		http.Error(w, fmt.Sprintf("Failed to get client: %v", err), http.StatusInternalServerError)
-		return
-	}
-
 	// Call the core WaveAIPostMessage function
 	chatOpts := uctypes.WaveChatOpts{
 		ChatId:               req.ChatID,
-		ClientId:             client.OID,
+		ClientId:             wstore.GetClientId(),
 		Config:               *aiOpts,
 		WidgetAccess:         req.WidgetAccess,
 		AllowNativeWebSearch: true,
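Several call sites in this diff (telemetry, diagnostic ping, job start, and the AI handler above) replace a wstore.DBGetSingleton[*waveobj.Client] lookup plus client.OID with a bare wstore.GetClientId() call. The accessor's implementation is not in the diff; the call shape implies a cached, error-free getter, roughly like this sketch (the loader is a hypothetical stand-in):

package main

import (
	"fmt"
	"sync"

	"github.com/google/uuid"
)

var (
	clientIdOnce   sync.Once
	cachedClientId string
)

// getClientId resolves the client id once and serves it from cache afterwards,
// so hot paths no longer need a DB roundtrip or error handling per call.
func getClientId() string {
	clientIdOnce.Do(func() {
		// hypothetical stand-in for the one-time client singleton lookup
		cachedClientId = uuid.New().String()
	})
	return cachedClientId
}

func main() {
	fmt.Println(getClientId() == getClientId()) // true: same cached value
}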
diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go
index a255c680a4..903081976f 100644
--- a/pkg/blockcontroller/blockcontroller.go
+++ b/pkg/blockcontroller/blockcontroller.go
@@ -13,10 +13,12 @@ import (
 	"sync"
 	"time"
 
+	"github.com/google/uuid"
 	"github.com/wavetermdev/waveterm/pkg/blocklogger"
 	"github.com/wavetermdev/waveterm/pkg/filestore"
 	"github.com/wavetermdev/waveterm/pkg/remote"
 	"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
+	"github.com/wavetermdev/waveterm/pkg/util/shellutil"
 	"github.com/wavetermdev/waveterm/pkg/wavebase"
 	"github.com/wavetermdev/waveterm/pkg/waveobj"
 	"github.com/wavetermdev/waveterm/pkg/wps"
@@ -53,7 +55,7 @@ type BlockInputUnion struct {
 
 type BlockControllerRuntimeStatus struct {
 	BlockId           string `json:"blockid"`
-	Version           int    `json:"version"`
+	Version           int64  `json:"version"`
 	ShellProcStatus   string `json:"shellprocstatus,omitempty"`
 	ShellProcConnName string `json:"shellprocconnname,omitempty"`
 	ShellProcExitCode int    `json:"shellprocexitcode"`
@@ -63,8 +65,8 @@ type BlockControllerRuntimeStatus struct {
 // Controller interface that all block controllers must implement
 type Controller interface {
 	Start(ctx context.Context, blockMeta waveobj.MetaMapType, rtOpts *waveobj.RuntimeOpts, force bool) error
-	Stop(graceful bool, newStatus string) error
-	GetRuntimeStatus() *BlockControllerRuntimeStatus
+	Stop(graceful bool, newStatus string, destroy bool) error
+	GetRuntimeStatus() *BlockControllerRuntimeStatus // does not return nil
 	SendInput(input *BlockInputUnion) error
 }
@@ -93,7 +95,7 @@ func registerController(blockId string, controller Controller) {
 	registryLock.Unlock()
 
 	if existingController != nil {
-		existingController.Stop(false, Status_Done)
+		existingController.Stop(false, Status_Done, true)
 		wstore.DeleteRTInfo(waveobj.MakeORef(waveobj.OType_Block, blockId))
 	}
 }
@@ -135,22 +137,32 @@ func ResyncController(ctx context.Context, tabId string, blockId string, rtOpts
 	// If no controller needed, stop existing if present
 	if controllerName == "" {
 		if existing != nil {
-			StopBlockController(blockId)
-			deleteController(blockId)
+			DestroyBlockController(blockId)
 		}
 		return nil
 	}
 
+	// Determine if we should use ShellJobController vs ShellController
+	isPersistent := blockData.Meta.GetBool(waveobj.MetaKey_CmdPersistent, false)
+	connName := blockData.Meta.GetString(waveobj.MetaKey_Connection, "")
+	isRemote := !conncontroller.IsLocalConnName(connName)
+	shouldUseShellJobController := isPersistent && isRemote && (controllerName == BlockController_Shell || controllerName == BlockController_Cmd)
+
 	// Check if we need to morph controller type
 	if existing != nil {
 		existingStatus := existing.GetRuntimeStatus()
 		needsReplace := false
-		// Determine if existing controller type matches what we need
 		switch existing.(type) {
 		case *ShellController:
 			if controllerName != BlockController_Shell && controllerName != BlockController_Cmd {
 				needsReplace = true
+			} else if shouldUseShellJobController {
+				needsReplace = true
+			}
+		case *ShellJobController:
+			if !shouldUseShellJobController {
+				needsReplace = true
 			}
 		case *TsunamiController:
 			if controllerName != BlockController_Tsunami {
@@ -160,31 +172,41 @@ func ResyncController(ctx context.Context, tabId string, blockId string, rtOpts
 
 		if needsReplace {
 			log.Printf("stopping blockcontroller %s due to controller type change\n", blockId)
-			StopBlockController(blockId)
+			DestroyBlockController(blockId)
 			time.Sleep(100 * time.Millisecond)
-			deleteController(blockId)
 			existing = nil
 		}
 
-		// For shell/cmd, check if connection changed
+		// For shell/cmd, check if connection changed (but not for job controller)
 		if !needsReplace && (controllerName == BlockController_Shell || controllerName == BlockController_Cmd) {
-			connName := blockData.Meta.GetString(waveobj.MetaKey_Connection, "")
-			// Check if connection changed, including between different local connections
-			if existingStatus.ShellProcStatus == Status_Running && existingStatus.ShellProcConnName != connName {
-				log.Printf("stopping blockcontroller %s due to conn change (from %q to %q)\n", blockId, existingStatus.ShellProcConnName, connName)
-				StopBlockControllerAndSetStatus(blockId, Status_Init)
-				time.Sleep(100 * time.Millisecond)
-				// Don't delete, will reuse same controller type
-				existing = getController(blockId)
+			if _, isShellController := existing.(*ShellController); isShellController {
+				// Check if connection changed, including between different local connections
+				if existingStatus.ShellProcStatus == Status_Running && existingStatus.ShellProcConnName != connName {
+					log.Printf("stopping blockcontroller %s due to conn change (from %q to %q)\n", blockId, existingStatus.ShellProcConnName, connName)
+					DestroyBlockController(blockId)
+					time.Sleep(100 * time.Millisecond)
+					existing = nil
+				}
 			}
 		}
 	}
 
 	// Force restart if requested
 	if force && existing != nil {
-		StopBlockController(blockId)
+		DestroyBlockController(blockId)
 		time.Sleep(100 * time.Millisecond)
-		existing = getController(blockId)
+		existing = nil
+	}
+
+	// Destroy done controllers before restarting
+	if existing != nil {
+		status := existing.GetRuntimeStatus()
+		if status.ShellProcStatus == Status_Done {
+			log.Printf("destroying blockcontroller %s with done status before restart\n", blockId)
+			DestroyBlockController(blockId)
+			time.Sleep(100 * time.Millisecond)
+			existing = nil
+		}
 	}
 
 	// Create or restart controller
@@ -195,7 +217,11 @@ func ResyncController(ctx context.Context, tabId string, blockId string, rtOpts
 	// Create new controller based on type
 	switch controllerName {
 	case BlockController_Shell, BlockController_Cmd:
-		controller = MakeShellController(tabId, blockId, controllerName)
+		if shouldUseShellJobController {
+			controller = MakeShellJobController(tabId, blockId, controllerName)
+		} else {
+			controller = MakeShellController(tabId, blockId, controllerName)
+		}
 		registerController(blockId, controller)
 
 	case BlockController_Tsunami:
@@ -209,7 +235,7 @@ func ResyncController(ctx context.Context, tabId string, blockId string, rtOpts
 
 	// Check if we need to start/restart
 	status := controller.GetRuntimeStatus()
-	if status.ShellProcStatus == Status_Init || status.ShellProcStatus == Status_Done {
+	if status.ShellProcStatus == Status_Init {
 		// For shell/cmd, check connection status first (for non-local connections)
 		if controllerName == BlockController_Shell || controllerName == BlockController_Cmd {
 			connName := blockData.Meta.GetString(waveobj.MetaKey_Connection, "")
@@ -239,22 +265,14 @@ func GetBlockControllerRuntimeStatus(blockId string) *BlockControllerRuntimeStat
 	return controller.GetRuntimeStatus()
 }
 
-func StopBlockController(blockId string) {
-	controller := getController(blockId)
-	if controller == nil {
-		return
-	}
-	controller.Stop(true, Status_Done)
-	wstore.DeleteRTInfo(waveobj.MakeORef(waveobj.OType_Block, blockId))
-}
-
-func StopBlockControllerAndSetStatus(blockId string, newStatus string) {
+func DestroyBlockController(blockId string) {
 	controller := getController(blockId)
 	if controller == nil {
 		return
 	}
-	controller.Stop(true, newStatus)
+	controller.Stop(true, Status_Done, true)
 	wstore.DeleteRTInfo(waveobj.MakeORef(waveobj.OType_Block, blockId))
+	deleteController(blockId)
 }
 
 func SendInput(blockId string, inputUnion *BlockInputUnion) error {
@@ -265,13 +283,14 @@ func SendInput(blockId string, inputUnion *BlockInputUnion) error {
 	return controller.SendInput(inputUnion)
 }
 
-func StopAllBlockControllers() {
+// only call this on shutdown
+func StopAllBlockControllersForShutdown() {
 	controllers := getAllControllers()
 	for blockId, controller := range controllers {
 		status := controller.GetRuntimeStatus()
 		if status != nil && status.ShellProcStatus == Status_Running {
 			go func(id string, c Controller) {
-				c.Stop(true, Status_Done)
+				c.Stop(true, Status_Done, false)
 				wstore.DeleteRTInfo(waveobj.MakeORef(waveobj.OType_Block, id))
 			}(blockId, controller)
 		}
@@ -386,3 +405,40 @@ func CheckConnStatus(blockId string) error {
 	}
 	return nil
 }
+
+func makeSwapToken(ctx context.Context, logCtx context.Context, blockId string, blockMeta waveobj.MetaMapType, remoteName string, shellType string) *shellutil.TokenSwapEntry {
+	token := &shellutil.TokenSwapEntry{
+		Token: uuid.New().String(),
+		Env:   make(map[string]string),
+		Exp:   time.Now().Add(5 * time.Minute),
+	}
+	token.Env["TERM_PROGRAM"] = "waveterm"
+	token.Env["WAVETERM_BLOCKID"] = blockId
+	token.Env["WAVETERM_VERSION"] = wavebase.WaveVersion
+	token.Env["WAVETERM"] = "1"
+	tabId, err := wstore.DBFindTabForBlockId(ctx, blockId)
+	if err != nil {
+		log.Printf("error finding tab for block: %v\n", err)
+	} else {
+		token.Env["WAVETERM_TABID"] = tabId
+	}
+	if tabId != "" {
+		wsId, err := wstore.DBFindWorkspaceForTabId(ctx, tabId)
+		if err != nil {
+			log.Printf("error finding workspace for tab: %v\n", err)
+		} else {
+			token.Env["WAVETERM_WORKSPACEID"] = wsId
+		}
+	}
+	token.Env["WAVETERM_CLIENTID"] = wstore.GetClientId()
+	token.Env["WAVETERM_CONN"] = remoteName
+	envMap, err := resolveEnvMap(blockId, blockMeta, remoteName)
+	if err != nil {
+		log.Printf("error resolving env map: %v\n", err)
+	}
+	for k, v := range envMap {
+		token.Env[k] = v
+	}
+	token.ScriptText = getCustomInitScript(logCtx, blockMeta, remoteName, shellType)
+	return token
+}
diff --git a/pkg/blockcontroller/shellcontroller.go b/pkg/blockcontroller/shellcontroller.go
index 040a245745..d6f128963b 100644
--- a/pkg/blockcontroller/shellcontroller.go
+++ b/pkg/blockcontroller/shellcontroller.go
@@ -17,7 +17,6 @@ import (
 	"sync/atomic"
 	"time"
 
-	"github.com/google/uuid"
 	"github.com/wavetermdev/waveterm/pkg/blocklogger"
 	"github.com/wavetermdev/waveterm/pkg/filestore"
 	"github.com/wavetermdev/waveterm/pkg/panichandler"
@@ -28,6 +27,7 @@ import (
 	"github.com/wavetermdev/waveterm/pkg/util/fileutil"
 	"github.com/wavetermdev/waveterm/pkg/util/shellutil"
 	"github.com/wavetermdev/waveterm/pkg/util/utilfn"
+	"github.com/wavetermdev/waveterm/pkg/utilds"
 	"github.com/wavetermdev/waveterm/pkg/wavebase"
 	"github.com/wavetermdev/waveterm/pkg/waveobj"
 	"github.com/wavetermdev/waveterm/pkg/wconfig"
@@ -60,7 +60,7 @@ type ShellController struct {
 	RunLock      *atomic.Bool
 	ProcStatus   string
 	ProcExitCode int
-	StatusVersion int
+	VersionTs    utilds.VersionTs
 
 	// for shell/cmd
 	ShellProc *shellexec.ShellProc
@@ -93,7 +93,7 @@ func (sc *ShellController) Start(ctx context.Context, blockMeta waveobj.MetaMapT
 	return nil
 }
 
-func (sc *ShellController) Stop(graceful bool, newStatus string) error {
+func (sc *ShellController) Stop(graceful bool, newStatus string, destroy bool) error {
 	sc.Lock.Lock()
 	defer sc.Lock.Unlock()
@@ -121,8 +121,7 @@
 
 func (sc *ShellController) getRuntimeStatus_nolock() BlockControllerRuntimeStatus {
 	var rtn BlockControllerRuntimeStatus
-	sc.StatusVersion++
-	rtn.Version = sc.StatusVersion
+	rtn.Version = sc.VersionTs.GetVersionTs()
 	rtn.BlockId = sc.BlockId
 	rtn.ShellProcStatus = sc.ProcStatus
 	if sc.ShellProc != nil {
@@ -316,6 +315,7 @@ type ConnUnion struct {
 	ShellPath string
 	ShellOpts []string
 	ShellType string
+	HomeDir   string
 }
 
 func (bc *ShellController) getConnUnion(logCtx context.Context, remoteName string, blockMeta waveobj.MetaMapType) (ConnUnion, error) {
@@ -408,7 +408,7 @@ func (bc *ShellController) setupAndStartShellProcess(logCtx context.Context, rc
 		return nil, fmt.Errorf("unknown controller type %q", bc.ControllerType)
 	}
 	var shellProc *shellexec.ShellProc
-	swapToken := bc.makeSwapToken(ctx, logCtx, blockMeta, remoteName, connUnion.ShellType)
+	swapToken := makeSwapToken(ctx, logCtx, bc.BlockId, blockMeta, remoteName, connUnion.ShellType)
 	cmdOpts.SwapToken = swapToken
 	blocklogger.Debugf(logCtx, "[conndebug] created swaptoken: %s\n", swapToken.Token)
 	if connUnion.ConnType == ConnType_Wsl {
@@ -421,10 +421,10 @@ func (bc *ShellController) setupAndStartShellProcess(logCtx context.Context, rc
 		} else {
 			sockName := wslConn.GetDomainSocketName()
 			rpcContext := wshrpc.RpcContext{
-				RouteId:  wshutil.MakeRandomProcRouteId(),
-				SockName: sockName,
-				BlockId:  bc.BlockId,
-				Conn:     wslConn.GetName(),
+				ProcRoute: true,
+				SockName:  sockName,
+				BlockId:   bc.BlockId,
+				Conn:      wslConn.GetName(),
 			}
 			jwtStr, err := wshutil.MakeClientJWTToken(rpcContext)
 			if err != nil {
@@ -454,10 +454,10 @@ func (bc *ShellController) setupAndStartShellProcess(logCtx context.Context, rc
 		} else {
 			sockName := conn.GetDomainSocketName()
 			rpcContext := wshrpc.RpcContext{
-				RouteId:  wshutil.MakeRandomProcRouteId(),
-				SockName: sockName,
-				BlockId:  bc.BlockId,
-				Conn:     conn.Opts.String(),
+				ProcRoute: true,
+				SockName:  sockName,
+				BlockId:   bc.BlockId,
+				Conn:      conn.Opts.String(),
 			}
 			jwtStr, err := wshutil.MakeClientJWTToken(rpcContext)
 			if err != nil {
@@ -481,9 +481,9 @@ func (bc *ShellController) setupAndStartShellProcess(logCtx context.Context, rc
 		if connUnion.WshEnabled {
 			sockName := wavebase.GetDomainSocketName()
 			rpcContext := wshrpc.RpcContext{
-				RouteId:  wshutil.MakeRandomProcRouteId(),
-				SockName: sockName,
-				BlockId:  bc.BlockId,
+				ProcRoute: true,
+				SockName:  sockName,
+				BlockId:   bc.BlockId,
 			}
 			jwtStr, err := wshutil.MakeClientJWTToken(rpcContext)
 			if err != nil {
@@ -606,12 +606,14 @@ func (union *ConnUnion) getRemoteInfoAndShellType(blockMeta waveobj.MetaMapType)
 		}
 		// TODO allow overriding remote shell path
 		union.ShellPath = remoteInfo.Shell
+		union.HomeDir = remoteInfo.HomeDir
 	} else {
 		shellPath, err := getLocalShellPath(blockMeta)
 		if err != nil {
 			return err
 		}
 		union.ShellPath = shellPath
+		union.HomeDir = wavebase.GetHomeDir()
 	}
 	union.ShellType = shellutil.GetShellTypeFromShellPath(union.ShellPath)
 	return nil
 }
@@ -715,48 +717,6 @@ func createCmdStrAndOpts(blockId string, blockMeta waveobj.MetaMapType, connName
 	return cmdStr, &cmdOpts, nil
 }
 
-func (bc *ShellController) makeSwapToken(ctx context.Context, logCtx context.Context, blockMeta waveobj.MetaMapType, remoteName string, shellType string) *shellutil.TokenSwapEntry {
-	token := &shellutil.TokenSwapEntry{
-		Token: uuid.New().String(),
-		Env:   make(map[string]string),
-		Exp:   time.Now().Add(5 * time.Minute),
-	}
-	token.Env["TERM_PROGRAM"] = "waveterm"
-	token.Env["WAVETERM_BLOCKID"] = bc.BlockId
-	token.Env["WAVETERM_VERSION"] = wavebase.WaveVersion
-	token.Env["WAVETERM"] = "1"
-	tabId, err := wstore.DBFindTabForBlockId(ctx, bc.BlockId)
-	if err != nil {
-		log.Printf("error finding tab for block: %v\n", err)
-	} else {
-		token.Env["WAVETERM_TABID"] = tabId
-	}
-	if tabId != "" {
-		wsId, err := wstore.DBFindWorkspaceForTabId(ctx, tabId)
-		if err != nil {
-			log.Printf("error finding workspace for tab: %v\n", err)
-		} else {
-			token.Env["WAVETERM_WORKSPACEID"] = wsId
-		}
-	}
-	clientData, err := wstore.DBGetSingleton[*waveobj.Client](ctx)
-	if err != nil {
-		log.Printf("error getting client data: %v\n", err)
-	} else {
-		token.Env["WAVETERM_CLIENTID"] = clientData.OID
-	}
-	token.Env["WAVETERM_CONN"] = remoteName
-	envMap, err := resolveEnvMap(bc.BlockId, blockMeta, remoteName)
-	if err != nil {
-		log.Printf("error resolving env map: %v\n", err)
-	}
-	for k, v := range envMap {
-		token.Env[k] = v
-	}
-	token.ScriptText = getCustomInitScript(logCtx, blockMeta, remoteName, shellType)
-	return token
-}
-
 func (bc *ShellController) getBlockData_noErr() *waveobj.Block {
 	ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeout)
 	defer cancelFn()
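Both shell controllers in this diff drop their per-instance StatusVersion counters for utilds.VersionTs, and BlockControllerRuntimeStatus.Version widens to int64 to hold it. VersionTs is not shown in the diff; a timestamp-derived version has the useful property that it stays comparable even after a controller is destroyed and re-created, whereas a fresh counter would restart at zero and look stale. A minimal sketch under that assumption:

package main

import (
	"fmt"
	"sync"
	"time"
)

// VersionTs sketch (assumed semantics; the utilds implementation is not in
// this diff): nanosecond-timestamp versions, bumped when two reads land on
// the same nanosecond so values stay strictly monotonic.
type VersionTs struct {
	lock sync.Mutex
	last int64
}

func (v *VersionTs) GetVersionTs() int64 {
	v.lock.Lock()
	defer v.lock.Unlock()
	now := time.Now().UnixNano()
	if now <= v.last {
		now = v.last + 1 // force strict monotonicity
	}
	v.last = now
	return now
}

func main() {
	var v VersionTs
	a, b := v.GetVersionTs(), v.GetVersionTs()
	fmt.Println(b > a) // true
}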
diff --git a/pkg/blockcontroller/shelljobcontroller.go b/pkg/blockcontroller/shelljobcontroller.go
new file mode 100644
index 0000000000..82366f94f6
--- /dev/null
+++ b/pkg/blockcontroller/shelljobcontroller.go
@@ -0,0 +1,306 @@
+// Copyright 2025, Command Line Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package blockcontroller
+
+import (
+	"context"
+	"encoding/base64"
+	"fmt"
+	"io/fs"
+	"log"
+	"sync"
+
+	"github.com/google/uuid"
+	"github.com/wavetermdev/waveterm/pkg/blocklogger"
+	"github.com/wavetermdev/waveterm/pkg/filestore"
+	"github.com/wavetermdev/waveterm/pkg/jobcontroller"
+	"github.com/wavetermdev/waveterm/pkg/remote"
+	"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
+	"github.com/wavetermdev/waveterm/pkg/shellexec"
+	"github.com/wavetermdev/waveterm/pkg/util/shellutil"
+	"github.com/wavetermdev/waveterm/pkg/utilds"
+	"github.com/wavetermdev/waveterm/pkg/wavebase"
+	"github.com/wavetermdev/waveterm/pkg/waveobj"
+	"github.com/wavetermdev/waveterm/pkg/wps"
+	"github.com/wavetermdev/waveterm/pkg/wshrpc"
+	"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
+	"github.com/wavetermdev/waveterm/pkg/wshutil"
+	"github.com/wavetermdev/waveterm/pkg/wstore"
+)
+
+type ShellJobController struct {
+	Lock *sync.Mutex
+
+	ControllerType string
+	TabId          string
+	BlockId        string
+	BlockDef       *waveobj.BlockDef
+	VersionTs      utilds.VersionTs
+
+	InputSessionId string // random uuid
+	inputSeqNum    int    // monotonic sequence number for inputs, starts at 1
+
+	JobId           string
+	LastKnownStatus string
+}
+
+func MakeShellJobController(tabId string, blockId string, controllerType string) Controller {
+	return &ShellJobController{
+		Lock:            &sync.Mutex{},
+		ControllerType:  controllerType,
+		TabId:           tabId,
+		BlockId:         blockId,
+		LastKnownStatus: Status_Init,
+		InputSessionId:  uuid.New().String(),
+	}
+}
+
+func (sjc *ShellJobController) WithLock(f func()) {
+	sjc.Lock.Lock()
+	defer sjc.Lock.Unlock()
+	f()
+}
+
+func (sjc *ShellJobController) getJobId() string {
+	sjc.Lock.Lock()
+	defer sjc.Lock.Unlock()
+	return sjc.JobId
+}
+
+func (sjc *ShellJobController) getNextInputSeq() (string, int) {
+	sjc.Lock.Lock()
+	defer sjc.Lock.Unlock()
+	sjc.inputSeqNum++
+	return sjc.InputSessionId, sjc.inputSeqNum
+}
+
+func (sjc *ShellJobController) getJobStatus_withlock() string {
+	if sjc.JobId == "" {
+		sjc.LastKnownStatus = Status_Init
+		return Status_Init
+	}
+	status, err := jobcontroller.GetJobManagerStatus(context.Background(), sjc.JobId)
+	if err != nil {
+		log.Printf("error getting job status for %s: %v, using last known status: %s", sjc.JobId, err, sjc.LastKnownStatus)
+		return sjc.LastKnownStatus
+	}
+	sjc.LastKnownStatus = status
+	return status
+}
+
+func (sjc *ShellJobController) getRuntimeStatus_withlock() BlockControllerRuntimeStatus {
+	var rtn BlockControllerRuntimeStatus
+	rtn.Version = sjc.VersionTs.GetVersionTs()
+	rtn.BlockId = sjc.BlockId
+	rtn.ShellProcStatus = sjc.getJobStatus_withlock()
+	return rtn
+}
+
+func (sjc *ShellJobController) GetRuntimeStatus() *BlockControllerRuntimeStatus {
+	var rtn BlockControllerRuntimeStatus
+	sjc.WithLock(func() {
+		rtn = sjc.getRuntimeStatus_withlock()
+	})
+	return &rtn
+}
+
+func (sjc *ShellJobController) sendUpdate_withlock() {
+	rtStatus := sjc.getRuntimeStatus_withlock()
+	log.Printf("sending blockcontroller update %#v\n", rtStatus)
+	wps.Broker.Publish(wps.WaveEvent{
+		Event: wps.Event_ControllerStatus,
+		Scopes: []string{
+			waveobj.MakeORef(waveobj.OType_Tab, sjc.TabId).String(),
+			waveobj.MakeORef(waveobj.OType_Block, sjc.BlockId).String(),
+		},
+		Data: rtStatus,
+	})
+}
+
+// Start initializes or reconnects to a shell job for the block.
+// Logic:
+//   - If block has no existing jobId: starts a new job and attaches it
+//   - If block has existing jobId with running job manager: reconnects to existing job
+//   - If block has existing jobId with non-running job manager:
+//     - force=true: detaches old job and starts new one
+//     - force=false: returns without starting (leaves block unstarted)
+//
+// After establishing jobId, ensures job connection is active (reconnects if needed)
+func (sjc *ShellJobController) Start(ctx context.Context, blockMeta waveobj.MetaMapType, rtOpts *waveobj.RuntimeOpts, force bool) error {
+	blockData, err := wstore.DBMustGet[*waveobj.Block](ctx, sjc.BlockId)
+	if err != nil {
+		return fmt.Errorf("error getting block: %w", err)
+	}
+
+	connName := blockMeta.GetString(waveobj.MetaKey_Connection, "")
+	if conncontroller.IsLocalConnName(connName) {
+		return fmt.Errorf("shell job controller requires a remote connection")
+	}
+
+	var jobId string
+	if blockData.JobId != "" {
+		status, err := jobcontroller.GetJobManagerStatus(ctx, blockData.JobId)
+		if err != nil {
+			return fmt.Errorf("error getting job manager status: %w", err)
+		}
+		if status != jobcontroller.JobStatus_Running {
+			if force {
+				log.Printf("block %q has jobId %s but manager is not running (status: %s), detaching (force=true)\n", sjc.BlockId, blockData.JobId, status)
+				jobcontroller.DetachJobFromBlock(ctx, blockData.JobId, false)
+			} else {
+				log.Printf("block %q has jobId %s but manager is not running (status: %s), not starting (force=false)\n", sjc.BlockId, blockData.JobId, status)
+				return nil
+			}
+		} else {
+			jobId = blockData.JobId
+		}
+	}
+
+	if jobId == "" {
+		log.Printf("block %q starting new shell job\n", sjc.BlockId)
+		newJobId, err := sjc.startNewJob(ctx, blockMeta, connName)
+		if err != nil {
+			return fmt.Errorf("failed to start new job: %w", err)
+		}
+		jobId = newJobId
+
+		err = jobcontroller.AttachJobToBlock(ctx, jobId, sjc.BlockId)
+		if err != nil {
+			log.Printf("error attaching job to block: %v\n", err)
+		}
+	}
+
+	sjc.WithLock(func() {
+		sjc.JobId = jobId
+		sjc.sendUpdate_withlock()
+	})
+
+	_, err = jobcontroller.CheckJobConnected(ctx, jobId)
+	if err != nil {
+		log.Printf("job %s is not connected, attempting reconnect: %v\n", jobId, err)
+		err = jobcontroller.ReconnectJob(ctx, jobId, rtOpts)
+		if err != nil {
+			return fmt.Errorf("failed to reconnect to job: %w", err)
+		}
+	}
+
+	return nil
+}
+
+func (sjc *ShellJobController) Stop(graceful bool, newStatus string, destroy bool) error {
+	if !destroy {
+		return nil
+	}
+	jobId := sjc.getJobId()
+	if jobId == "" {
+		return nil
+	}
+	ctx := context.Background()
+	jobcontroller.DetachJobFromBlock(ctx, jobId, false)
+	return jobcontroller.TerminateJobManager(ctx, jobId)
+}
+
+func (sjc *ShellJobController) SendInput(inputUnion *BlockInputUnion) error {
+	if inputUnion == nil {
+		return nil
+	}
+	jobId := sjc.getJobId()
+	if jobId == "" {
+		return fmt.Errorf("no job attached to controller")
+	}
+	inputSessionId, seqNum := sjc.getNextInputSeq()
+	data := wshrpc.CommandJobInputData{
+		JobId:          jobId,
+		InputSessionId: inputSessionId,
+		SeqNum:         seqNum,
+		TermSize:       inputUnion.TermSize,
+		SigName:        inputUnion.SigName,
+	}
+	if len(inputUnion.InputData) > 0 {
+		data.InputData64 = base64.StdEncoding.EncodeToString(inputUnion.InputData)
+	}
+	return jobcontroller.SendInput(context.Background(), data)
+}
+
+func (sjc *ShellJobController) startNewJob(ctx context.Context, blockMeta waveobj.MetaMapType, connName string) (string, error) {
+	termSize := waveobj.TermSize{
+		Rows: shellutil.DefaultTermRows,
+		Cols: shellutil.DefaultTermCols,
+	}
+	cmdStr := blockMeta.GetString(waveobj.MetaKey_Cmd, "")
+	cwd := blockMeta.GetString(waveobj.MetaKey_CmdCwd, "")
+	opts, err := remote.ParseOpts(connName)
+	if err != nil {
+		return "", fmt.Errorf("invalid ssh remote name (%s): %w", connName, err)
+	}
+	conn := conncontroller.GetConn(opts)
+	if conn == nil {
+		return "", fmt.Errorf("connection %q not found", connName)
+	}
+	connRoute := wshutil.MakeConnectionRouteId(connName)
+	remoteInfo, err := wshclient.RemoteGetInfoCommand(wshclient.GetBareRpcClient(), &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000})
+	if err != nil {
+		return "", fmt.Errorf("unable to obtain remote info from connserver: %w", err)
+	}
+	shellType := shellutil.GetShellTypeFromShellPath(remoteInfo.Shell)
+	swapToken := makeSwapToken(ctx, ctx, sjc.BlockId, blockMeta, connName, shellType)
+	sockName := wavebase.GetPersistentRemoteSockName(wstore.GetClientId())
+	rpcContext := wshrpc.RpcContext{
+		ProcRoute: true,
+		SockName:  sockName,
+		BlockId:   sjc.BlockId,
+		Conn:      connName,
+	}
+	jwtStr, err := wshutil.MakeClientJWTToken(rpcContext)
+	if err != nil {
+		return "", fmt.Errorf("error making jwt token: %w", err)
+	}
+	swapToken.RpcContext = &rpcContext
+	swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr
+	cmdOpts := shellexec.CommandOptsType{
+		Interactive: true,
+		Login:       true,
+		Cwd:         cwd,
+		SwapToken:   swapToken,
+		ForceJwt:    false,
+	}
+	jobId, err := shellexec.StartRemoteShellJob(ctx, ctx, termSize, cmdStr, cmdOpts, conn)
+	if err != nil {
+		return "", fmt.Errorf("failed to start remote shell job: %w", err)
+	}
+	return jobId, nil
+}
+
+func (sjc *ShellJobController) resetTerminalState(logCtx context.Context) {
+	ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeout)
+	defer cancelFn()
+
+	jobId := ""
+	sjc.WithLock(func() {
+		jobId = sjc.JobId
+	})
+	if jobId == "" {
+		return
+	}
+
+	wfile, statErr := filestore.WFS.Stat(ctx, jobId, jobcontroller.JobOutputFileName)
+	if statErr == fs.ErrNotExist || wfile.Size == 0 {
+		return
+	}
+
+	blocklogger.Debugf(logCtx, "[conndebug] resetTerminalState: resetting terminal state for job\n")
+
+	resetSeq := "\x1b[0m"                       // reset attributes
+	resetSeq += "\x1b[?25h"                     // show cursor
+	resetSeq += "\x1b[?1000l"                   // disable mouse tracking
+	resetSeq += "\x1b[?1007l"                   // disable alternate scroll mode
+	resetSeq += "\x1b[?2004l"                   // disable bracketed paste mode
+	resetSeq += shellutil.FormatOSC(16162, "R") // disable alternate screen mode
+	resetSeq += "\r\n\r\n"
+
+	err := filestore.WFS.AppendData(ctx, jobId, jobcontroller.JobOutputFileName, []byte(resetSeq))
+	if err != nil {
+		log.Printf("error appending terminal reset to job file: %v\n", err)
+	}
+}
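ShellJobController above stamps every input with an InputSessionId plus a monotonically increasing SeqNum, and the job manager (later in this diff) feeds inputs through a utilds.QuickReorderQueue so keystrokes that arrive out of order over the RPC layer are applied in sequence. The queue's real implementation (with its 1000-entry bound and 100ms wait) is not shown; the core reordering idea looks roughly like this illustration:

package main

import "fmt"

// reorderBuffer releases items strictly in seq order, holding out-of-order
// arrivals until the gap fills. Illustration only, not the utilds code.
type reorderBuffer struct {
	nextSeq int
	pending map[int]string
}

func newReorderBuffer() *reorderBuffer {
	return &reorderBuffer{nextSeq: 1, pending: map[int]string{}}
}

// add accepts an input stamped with seq and returns every input that is now
// deliverable in order.
func (rb *reorderBuffer) add(seq int, data string) []string {
	rb.pending[seq] = data
	var ready []string
	for {
		next, ok := rb.pending[rb.nextSeq]
		if !ok {
			return ready
		}
		delete(rb.pending, rb.nextSeq)
		ready = append(ready, next)
		rb.nextSeq++
	}
}

func main() {
	rb := newReorderBuffer()
	fmt.Println(rb.add(2, "b")) // [] (held until seq 1 arrives)
	fmt.Println(rb.add(1, "a")) // [a b]
}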
diff --git a/pkg/blockcontroller/tsunamicontroller.go b/pkg/blockcontroller/tsunamicontroller.go
index 0452ae4661..483eddf808 100644
--- a/pkg/blockcontroller/tsunamicontroller.go
+++ b/pkg/blockcontroller/tsunamicontroller.go
@@ -37,15 +37,15 @@ type TsunamiAppProc struct {
 }
 
 type TsunamiController struct {
-	blockId       string
-	tabId         string
-	runLock       sync.Mutex
-	tsunamiProc   *TsunamiAppProc
-	statusLock    sync.Mutex
-	status        string
-	statusVersion int
-	exitCode      int
-	port          int
+	blockId     string
+	tabId       string
+	runLock     sync.Mutex
+	tsunamiProc *TsunamiAppProc
+	statusLock  sync.Mutex
+	status      string
+	versionTs   utilds.VersionTs
+	exitCode    int
+	port        int
 }
 
 func (c *TsunamiController) setManifestMetadata(appId string) {
@@ -235,7 +235,7 @@ func (c *TsunamiController) Start(ctx context.Context, blockMeta waveobj.MetaMap
 	return nil
 }
 
-func (c *TsunamiController) Stop(graceful bool, newStatus string) error {
+func (c *TsunamiController) Stop(graceful bool, newStatus string, destroy bool) error {
 	log.Printf("TsunamiController.Stop called for block %s (graceful: %t, newStatus: %s)", c.blockId, graceful, newStatus)
 	c.runLock.Lock()
 	defer c.runLock.Unlock()
@@ -268,10 +268,9 @@
 func (c *TsunamiController) GetRuntimeStatus() *BlockControllerRuntimeStatus {
 	var rtn *BlockControllerRuntimeStatus
 	c.WithStatusLock(func() {
-		c.statusVersion++
 		rtn = &BlockControllerRuntimeStatus{
 			BlockId:           c.blockId,
-			Version:           c.statusVersion,
+			Version:           c.versionTs.GetVersionTs(),
 			ShellProcStatus:   c.status,
 			ShellProcExitCode: c.exitCode,
 		}
diff --git a/pkg/jobcontroller/jobcontroller.go b/pkg/jobcontroller/jobcontroller.go
index 66f928d920..ff7d7c2421 100644
--- a/pkg/jobcontroller/jobcontroller.go
+++ b/pkg/jobcontroller/jobcontroller.go
@@ -18,6 +18,7 @@ import (
 	"github.com/wavetermdev/waveterm/pkg/panichandler"
 	"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
 	"github.com/wavetermdev/waveterm/pkg/streamclient"
+	"github.com/wavetermdev/waveterm/pkg/util/envutil"
 	"github.com/wavetermdev/waveterm/pkg/util/utilfn"
 	"github.com/wavetermdev/waveterm/pkg/wavejwt"
 	"github.com/wavetermdev/waveterm/pkg/waveobj"
@@ -54,6 +55,17 @@ func isJobManagerRunning(job *waveobj.Job) bool {
 	return job.JobManagerStatus == JobStatus_Running
 }
 
+func GetJobManagerStatus(ctx context.Context, jobId string) (string, error) {
+	job, err := wstore.DBGet[*waveobj.Job](ctx, jobId)
+	if err != nil {
+		return "", fmt.Errorf("failed to get job: %w", err)
+	}
+	if job == nil {
+		return JobStatus_Done, nil
+	}
+	return job.JobManagerStatus, nil
+}
+
 var (
 	jobConnStates     = make(map[string]string)
 	jobConnStatesLock sync.Mutex
@@ -77,6 +89,7 @@ func InitJobController() {
 	rpcClient := wshclient.GetBareRpcClient()
 	rpcClient.EventListener.On(wps.Event_RouteUp, handleRouteUpEvent)
 	rpcClient.EventListener.On(wps.Event_RouteDown, handleRouteDownEvent)
+	rpcClient.EventListener.On(wps.Event_ConnChange, handleConnChangeEvent)
 	wshclient.EventSubCommand(rpcClient, wps.SubscriptionRequest{
 		Event:     wps.Event_RouteUp,
 		AllScopes: true,
@@ -85,6 +98,10 @@ func InitJobController() {
 		Event:     wps.Event_RouteDown,
 		AllScopes: true,
 	}, nil)
+	wshclient.EventSubCommand(rpcClient, wps.SubscriptionRequest{
+		Event:     wps.Event_ConnChange,
+		AllScopes: true,
+	}, nil)
 }
 
 func handleRouteUpEvent(event *wps.WaveEvent) {
@@ -105,6 +122,34 @@ func handleRouteEvent(event *wps.WaveEvent, newStatus string) {
 	}
 }
 
+func handleConnChangeEvent(event *wps.WaveEvent) {
+	var connStatus wshrpc.ConnStatus
+	err := utilfn.ReUnmarshal(&connStatus, event.Data)
+	if err != nil {
+		log.Printf("[connchange] error unmarshaling ConnStatus: %v", err)
+		return
+	}
+
+	if !connStatus.Connected {
+		return
+	}
+
+	for _, scope := range event.Scopes {
+		if strings.HasPrefix(scope, "connection:") {
+			connName := strings.TrimPrefix(scope, "connection:")
+			log.Printf("[conn:%s] connection became connected, terminating jobs with TerminateOnReconnect", connName)
+			go func() {
+				defer func() {
+					panichandler.PanicHandler("jobcontroller:handleConnChangeEvent", recover())
+				}()
+				ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
+				defer cancelFn()
+				TerminateJobsOnConn(ctx, connName)
+			}()
+		}
+	}
+}
+
 func GetJobConnStatus(jobId string) string {
 	jobConnStatesLock.Lock()
 	defer jobConnStatesLock.Unlock()
@@ -137,7 +182,7 @@ func GetConnectedJobIds() []string {
 	return connectedJobIds
 }
 
-func ensureJobConnected(ctx context.Context, jobId string) (*waveobj.Job, error) {
+func CheckJobConnected(ctx context.Context, jobId string) (*waveobj.Job, error) {
 	job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get job: %w", err)
@@ -233,24 +278,20 @@ func StartJob(ctx context.Context, params StartJobParams) (string, error) {
 		return "", fmt.Errorf("failed to create WaveFS file: %w", err)
 	}
 
-	clientId, err := wstore.DBGetSingleton[*waveobj.Client](ctx)
-	if err != nil || clientId == nil {
-		return "", fmt.Errorf("failed to get client: %w", err)
-	}
-
+	clientId := wstore.GetClientId()
 	publicKey := wavejwt.GetPublicKey()
 	publicKeyBase64 := base64.StdEncoding.EncodeToString(publicKey)
-
+	jobEnv := envutil.CopyAndAddToEnvMap(params.Env, "WAVETERM_JOBID", jobId)
 	startJobData := wshrpc.CommandRemoteStartJobData{
 		Cmd:                params.Cmd,
 		Args:               params.Args,
-		Env:                params.Env,
+		Env:                jobEnv,
 		TermSize:           *params.TermSize,
 		StreamMeta:         streamMeta,
 		JobAuthToken:       jobAuthToken,
 		JobId:              jobId,
 		MainServerJwtToken: jobAccessToken,
-		ClientId:           clientId.OID,
+		ClientId:           clientId,
 		PublicKeyBase64:    publicKeyBase64,
 	}
 
@@ -259,7 +300,8 @@ func StartJob(ctx context.Context, params StartJobParams) (string, error) {
 		Timeout: 30000,
 	}
 
-	log.Printf("[job:%s] sending RemoteStartJobCommand to connection %s", jobId, params.ConnName)
+	log.Printf("[job:%s] sending RemoteStartJobCommand to connection %s, cmd=%q, args=%v", jobId, params.ConnName, params.Cmd, params.Args)
+	log.Printf("[job:%s] env=%v", jobId, params.Env)
 	rtnData, err := wshclient.RemoteStartJobCommand(bareRpc, startJobData, rpcOpts)
 	if err != nil {
 		log.Printf("[job:%s] RemoteStartJobCommand failed: %v", jobId, err)
@@ -296,18 +338,18 @@ func StartJob(ctx context.Context, params StartJobParams) (string, error) {
 	return jobId, nil
 }
 
-func handleAppendJobFile(ctx context.Context, jobId string, fileName string, data []byte) error {
-	err := filestore.WFS.AppendData(ctx, jobId, fileName, data)
+func doWFSAppend(ctx context.Context, oref waveobj.ORef, fileName string, data []byte) error {
+	err := filestore.WFS.AppendData(ctx, oref.OID, fileName, data)
 	if err != nil {
-		return fmt.Errorf("error appending to job file: %w", err)
+		return err
 	}
 	wps.Broker.Publish(wps.WaveEvent{
 		Event: wps.Event_BlockFile,
 		Scopes: []string{
-			waveobj.MakeORef(waveobj.OType_Job, jobId).String(),
+			oref.String(),
 		},
 		Data: &wps.WSFileEventData{
-			ZoneId:   jobId,
+			ZoneId:   oref.OID,
 			FileName: fileName,
 			FileOp:   wps.FileOp_Append,
 			Data64:   base64.StdEncoding.EncodeToString(data),
@@ -316,6 +358,26 @@ func handleAppendJobFile(ctx context.Context, jobId string, fileName string, dat
 	return nil
 }
 
+func handleAppendJobFile(ctx context.Context, jobId string, fileName string, data []byte) error {
+	err := doWFSAppend(ctx, waveobj.MakeORef(waveobj.OType_Job, jobId), fileName, data)
+	if err != nil {
+		return fmt.Errorf("error appending to job file: %w", err)
+	}
+
+	job, err := wstore.DBGet[*waveobj.Job](ctx, jobId)
+	if err != nil {
+		return fmt.Errorf("error getting job: %w", err)
+	}
+	if job != nil && job.AttachedBlockId != "" {
+		err = doWFSAppend(ctx, waveobj.MakeORef(waveobj.OType_Block, job.AttachedBlockId), fileName, data)
+		if err != nil {
+			return fmt.Errorf("error appending to block file: %w", err)
+		}
+	}
+
+	return nil
+}
+
 func runOutputLoop(ctx context.Context, jobId string, reader *streamclient.Reader) {
 	defer func() {
 		log.Printf("[job:%s] output loop finished", jobId)
@@ -404,6 +466,13 @@ func tryTerminateJobManager(ctx context.Context, jobId string) {
 }
 
 func TerminateJobManager(ctx context.Context, jobId string) error {
+	err := wstore.DBUpdateFn[*waveobj.Job](ctx, jobId, func(job *waveobj.Job) {
+		job.TerminateOnReconnect = true
+	})
+	if err != nil {
+		return fmt.Errorf("failed to set TerminateOnReconnect: %w", err)
+	}
+
 	job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
 	if err != nil {
 		return fmt.Errorf("failed to get job: %w", err)
@@ -475,7 +544,7 @@ func remoteTerminateJobManager(ctx context.Context, job *waveobj.Job) error {
 	return nil
 }
 
-func ReconnectJob(ctx context.Context, jobId string) error {
+func ReconnectJob(ctx context.Context, jobId string, rtOpts *waveobj.RuntimeOpts) error {
 	job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
 	if err != nil {
 		return fmt.Errorf("failed to get job: %w", err)
@@ -549,7 +618,36 @@ func ReconnectJob(ctx context.Context, jobId string) error {
 	}
 
 	log.Printf("[job:%s] route established, restarting streaming", jobId)
-	return RestartStreaming(ctx, jobId, true)
+	return RestartStreaming(ctx, jobId, true, rtOpts)
+}
+
+func TerminateJobsOnConn(ctx context.Context, connName string) {
+	allJobs, err := wstore.DBGetAllObjsByType[*waveobj.Job](ctx, waveobj.OType_Job)
+	if err != nil {
+		log.Printf("[conn:%s] failed to get jobs for termination: %v", connName, err)
+		return
+	}
+
+	var jobsToTerminate []*waveobj.Job
+	for _, job := range allJobs {
+		if job.Connection == connName && job.TerminateOnReconnect {
+			jobsToTerminate = append(jobsToTerminate, job)
+		}
+	}
+
+	log.Printf("[conn:%s] found %d jobs to terminate", connName, len(jobsToTerminate))
+
+	successCount := 0
+	for _, job := range jobsToTerminate {
+		err = remoteTerminateJobManager(ctx, job)
+		if err != nil {
+			log.Printf("[job:%s] error terminating: %v", job.OID, err)
+		} else {
+			successCount++
+		}
+	}
+
+	log.Printf("[conn:%s] finished terminating jobs: %d/%d successful", connName, successCount, len(jobsToTerminate))
 }
 
 func ReconnectJobsForConn(ctx context.Context, connName string) error {
@@ -576,7 +674,7 @@ func ReconnectJobsForConn(ctx context.Context, connName string) error {
 	log.Printf("[conn:%s] found %d jobs to reconnect", connName, len(jobsToReconnect))
 
 	for _, job := range jobsToReconnect {
-		err = ReconnectJob(ctx, job.OID)
+		err = ReconnectJob(ctx, job.OID, nil)
 		if err != nil {
 			log.Printf("[job:%s] error reconnecting: %v", job.OID, err)
 		}
@@ -585,12 +683,23 @@ func ReconnectJobsForConn(ctx context.Context, connName string) error {
 	return nil
 }
 
-func RestartStreaming(ctx context.Context, jobId string, knownConnected bool) error {
+func RestartStreaming(ctx context.Context, jobId string, knownConnected bool, rtOpts *waveobj.RuntimeOpts) error {
 	job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
 	if err != nil {
 		return fmt.Errorf("failed to get job: %w", err)
 	}
 
+	termSize := job.CmdTermSize
+	if rtOpts != nil && rtOpts.TermSize.Rows > 0 && rtOpts.TermSize.Cols > 0 {
+		termSize = rtOpts.TermSize
+		err = wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
+			job.CmdTermSize = termSize
+		})
+		if err != nil {
+			log.Printf("[job:%s] warning: failed to update termsize in DB: %v", jobId, err)
+		}
+	}
+
 	if !knownConnected {
 		isConnected, err := conncontroller.IsConnected(job.Connection)
 		if err != nil {
@@ -625,6 +734,7 @@ func RestartStreaming(ctx context.Context, jobId string, knownConnected bool) er
 	prepareData := wshrpc.CommandJobPrepareConnectData{
 		StreamMeta: *streamMeta,
 		Seq:        currentSeq,
+		TermSize:   termSize,
 	}
 
 	rpcOpts := &wshrpc.RpcOpts{
@@ -823,7 +933,17 @@ func DetachJobFromBlock(ctx context.Context, jobId string, updateBlock bool) err
 
 func SendInput(ctx context.Context, data wshrpc.CommandJobInputData) error {
 	jobId := data.JobId
-	_, err := ensureJobConnected(ctx, jobId)
+
+	if data.TermSize != nil {
+		err := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
+			job.CmdTermSize = *data.TermSize
+		})
+		if err != nil {
+			log.Printf("[job:%s] warning: failed to update termsize in DB: %v", jobId, err)
+		}
+	}
+
+	_, err := CheckJobConnected(ctx, jobId)
 	if err != nil {
 		return err
 	}
@@ -840,14 +960,5 @@ func SendInput(ctx context.Context, data wshrpc.CommandJobInputData) error {
 		return fmt.Errorf("failed to send input to job: %w", err)
 	}
 
-	if data.TermSize != nil {
-		err = wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
-			job.CmdTermSize = *data.TermSize
-		})
-		if err != nil {
-			log.Printf("[job:%s] warning: failed to update termsize in DB: %v", jobId, err)
-		}
-	}
-
 	return nil
 }
diff --git a/pkg/jobmanager/jobcmd.go b/pkg/jobmanager/jobcmd.go
index 2349e69b35..7102fd77fd 100644
--- a/pkg/jobmanager/jobcmd.go
+++ b/pkg/jobmanager/jobcmd.go
@@ -7,7 +7,6 @@ import (
 	"encoding/base64"
 	"fmt"
 	"log"
-	"os"
 	"os/exec"
 	"sync"
 	"syscall"
@@ -31,6 +30,7 @@ type JobCmd struct {
 	cmd           *exec.Cmd
 	cmdPty        pty.Pty
 	ptsName       string
+	termSize      waveobj.TermSize
 	cleanedUp     bool
 	ptyClosed     bool
 	processExited bool
@@ -53,7 +53,7 @@ func MakeJobCmd(jobId string, cmdDef CmdDef) (*JobCmd, error) {
 	}
 	ecmd := exec.Command(cmdDef.Cmd, cmdDef.Args...)
 	if len(cmdDef.Env) > 0 {
-		ecmd.Env = os.Environ()
+		ecmd.Env = make([]string, 0, len(cmdDef.Env))
 		for key, val := range cmdDef.Env {
 			ecmd.Env = append(ecmd.Env, fmt.Sprintf("%s=%s", key, val))
 		}
@@ -66,6 +66,7 @@ func MakeJobCmd(jobId string, cmdDef CmdDef) (*JobCmd, error) {
 	jm.cmd = ecmd
 	jm.cmdPty = cmdPty
 	jm.ptsName = jm.cmdPty.Name()
+	jm.termSize = cmdDef.TermSize
 	go jm.waitForProcess()
 	return jm, nil
 }
@@ -150,6 +151,30 @@ func (jm *JobCmd) GetExitInfo() (bool, *wshrpc.CommandJobCmdExitedData) {
 	return true, exitData
 }
 
+func (jm *JobCmd) setTermSize_withlock(termSize waveobj.TermSize) error {
+	if jm.cmdPty == nil {
+		return fmt.Errorf("no active pty")
+	}
+	if jm.termSize.Rows == termSize.Rows && jm.termSize.Cols == termSize.Cols {
+		return nil
+	}
+	err := pty.Setsize(jm.cmdPty, &pty.Winsize{
+		Rows: uint16(termSize.Rows),
+		Cols: uint16(termSize.Cols),
+	})
+	if err != nil {
+		return fmt.Errorf("error setting terminal size: %w", err)
+	}
+	jm.termSize = termSize
+	return nil
+}
+
+func (jm *JobCmd) SetTermSize(termSize waveobj.TermSize) error {
+	jm.lock.Lock()
+	defer jm.lock.Unlock()
+	return jm.setTermSize_withlock(termSize)
+}
+
 // TODO set up a single input handler loop + queue so we dont need to hold the lock but still get synchronized in-order execution
 func (jm *JobCmd) HandleInput(data wshrpc.CommandJobInputData) error {
 	jm.lock.Lock()
@@ -182,12 +207,9 @@ func (jm *JobCmd) HandleInput(data wshrpc.CommandJobInputData) error {
 	}
 
 	if data.TermSize != nil {
-		err := pty.Setsize(jm.cmdPty, &pty.Winsize{
-			Rows: uint16(data.TermSize.Rows),
-			Cols: uint16(data.TermSize.Cols),
-		})
+		err := jm.setTermSize_withlock(*data.TermSize)
 		if err != nil {
-			return fmt.Errorf("error setting terminal size: %w", err)
+			return err
 		}
 	}
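StartJob above builds the job environment with envutil.CopyAndAddToEnvMap(params.Env, "WAVETERM_JOBID", jobId) rather than mutating params.Env in place, which matters because params.Env is still logged separately afterwards. The helper is not part of this diff; its name and usage imply a copy-on-write shape like this sketch:

package main

import "fmt"

// copyAndAddToEnvMap returns a copy of env with key=val added, leaving the
// caller's map untouched. Sketch of the assumed envutil helper semantics.
func copyAndAddToEnvMap(env map[string]string, key string, val string) map[string]string {
	newEnv := make(map[string]string, len(env)+1)
	for k, v := range env {
		newEnv[k] = v
	}
	newEnv[key] = val
	return newEnv
}

func main() {
	base := map[string]string{"TERM": "xterm-256color"}
	jobEnv := copyAndAddToEnvMap(base, "WAVETERM_JOBID", "job-1234")
	fmt.Println(len(base), len(jobEnv)) // 1 2
}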
"github.com/shirou/gopsutil/v4/process" "github.com/wavetermdev/waveterm/pkg/baseds" "github.com/wavetermdev/waveterm/pkg/panichandler" + "github.com/wavetermdev/waveterm/pkg/utilds" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wavejwt" "github.com/wavetermdev/waveterm/pkg/wshrpc" @@ -23,6 +26,8 @@ import ( const JobAccessTokenLabel = "Wave-JobAccessToken" const JobManagerStartLabel = "Wave-JobManagerStart" +const JobInputQueueTimeout = 100 * time.Millisecond +const JobInputQueueSize = 1000 var WshCmdJobManager JobManager @@ -33,6 +38,7 @@ type JobManager struct { JwtPublicKey []byte JobAuthToken string StreamManager *StreamManager + InputQueue *utilds.QuickReorderQueue[wshrpc.CommandJobInputData] lock sync.Mutex attachedClient *MainServerConn connectedStreamClient *MainServerConn @@ -48,6 +54,7 @@ func SetupJobManager(clientId string, jobId string, publicKeyBytes []byte, jobAu WshCmdJobManager.JwtPublicKey = publicKeyBytes WshCmdJobManager.JobAuthToken = jobAuthToken WshCmdJobManager.StreamManager = MakeStreamManager() + WshCmdJobManager.InputQueue = utilds.MakeQuickReorderQueue[wshrpc.CommandJobInputData](JobInputQueueSize, JobInputQueueTimeout) err := wavejwt.SetPublicKey(publicKeyBytes) if err != nil { return fmt.Errorf("failed to set public key: %w", err) @@ -56,6 +63,14 @@ func SetupJobManager(clientId string, jobId string, publicKeyBytes []byte, jobAu if err != nil { return err } + + go func() { + defer func() { + panichandler.PanicHandler("JobManager:processInputQueue", recover()) + }() + WshCmdJobManager.processInputQueue() + }() + fmt.Fprintf(readyFile, JobManagerStartLabel+"\n") readyFile.Close() @@ -67,6 +82,24 @@ func SetupJobManager(clientId string, jobId string, publicKeyBytes []byte, jobAu return nil } +func (jm *JobManager) processInputQueue() { + for data := range jm.InputQueue.C() { + jm.lock.Lock() + cmd := jm.Cmd + jm.lock.Unlock() + + if cmd == nil { + log.Printf("processInputQueue: skipping input, job not started\n") + continue + } + + err := cmd.HandleInput(data) + if err != nil { + log.Printf("processInputQueue: error handling input: %v\n", err) + } + } +} + func (jm *JobManager) GetCmd() *JobCmd { jm.lock.Lock() defer jm.lock.Unlock() @@ -164,6 +197,175 @@ func (jm *JobManager) disconnectFromStreamHelper(mainServerConn *MainServerConn) jm.connectedStreamClient = nil } +func (jm *JobManager) SetAttachedClient(msc *MainServerConn) { + jm.lock.Lock() + defer jm.lock.Unlock() + + if jm.attachedClient != nil { + log.Printf("SetAttachedClient: kicking out existing client\n") + jm.attachedClient.Close() + } + jm.attachedClient = msc +} + +func (jm *JobManager) StartJob(msc *MainServerConn, data wshrpc.CommandStartJobData) (*wshrpc.CommandStartJobRtnData, error) { + jm.lock.Lock() + defer jm.lock.Unlock() + + if jm.Cmd != nil { + log.Printf("StartJob: job already started") + return nil, fmt.Errorf("job already started") + } + + cmdDef := CmdDef{ + Cmd: data.Cmd, + Args: data.Args, + Env: data.Env, + TermSize: data.TermSize, + } + log.Printf("StartJob: creating job cmd for jobid=%s", jm.JobId) + jobCmd, err := MakeJobCmd(jm.JobId, cmdDef) + if err != nil { + log.Printf("StartJob: failed to make job cmd: %v", err) + return nil, fmt.Errorf("failed to start job: %w", err) + } + jm.Cmd = jobCmd + log.Printf("StartJob: job cmd created successfully") + + if data.StreamMeta != nil { + serverSeq, err := jm.connectToStreamHelper_withlock(msc, *data.StreamMeta, 0) + if err != nil { + return nil, fmt.Errorf("failed to connect stream: %w", err) + 
} + err = msc.WshRpc.StreamBroker.AttachStreamWriter(data.StreamMeta, jm.StreamManager) + if err != nil { + return nil, fmt.Errorf("failed to attach stream writer: %w", err) + } + log.Printf("StartJob: connected stream streamid=%s serverSeq=%d\n", data.StreamMeta.Id, serverSeq) + } + + _, cmdPty := jobCmd.GetCmd() + if cmdPty != nil { + log.Printf("StartJob: attaching pty reader to stream manager") + err = jm.StreamManager.AttachReader(cmdPty) + if err != nil { + log.Printf("StartJob: failed to attach reader: %v", err) + return nil, fmt.Errorf("failed to attach reader to stream manager: %w", err) + } + log.Printf("StartJob: pty reader attached successfully") + } else { + log.Printf("StartJob: no pty to attach") + } + + cmd, _ := jobCmd.GetCmd() + if cmd == nil || cmd.Process == nil { + log.Printf("StartJob: cmd or process is nil") + return nil, fmt.Errorf("cmd or process is nil") + } + cmdPid := cmd.Process.Pid + cmdProc, err := process.NewProcess(int32(cmdPid)) + if err != nil { + log.Printf("StartJob: failed to get cmd process: %v", err) + return nil, fmt.Errorf("failed to get cmd process: %w", err) + } + cmdStartTs, err := cmdProc.CreateTime() + if err != nil { + log.Printf("StartJob: failed to get cmd start time: %v", err) + return nil, fmt.Errorf("failed to get cmd start time: %w", err) + } + + jobManagerPid := os.Getpid() + jobManagerProc, err := process.NewProcess(int32(jobManagerPid)) + if err != nil { + log.Printf("StartJob: failed to get job manager process: %v", err) + return nil, fmt.Errorf("failed to get job manager process: %w", err) + } + jobManagerStartTs, err := jobManagerProc.CreateTime() + if err != nil { + log.Printf("StartJob: failed to get job manager start time: %v", err) + return nil, fmt.Errorf("failed to get job manager start time: %w", err) + } + + log.Printf("StartJob: job started successfully cmdPid=%d cmdStartTs=%d jobManagerPid=%d jobManagerStartTs=%d", cmdPid, cmdStartTs, jobManagerPid, jobManagerStartTs) + return &wshrpc.CommandStartJobRtnData{ + CmdPid: cmdPid, + CmdStartTs: cmdStartTs, + JobManagerPid: jobManagerPid, + JobManagerStartTs: jobManagerStartTs, + }, nil +} + +func (jm *JobManager) PrepareConnect(msc *MainServerConn, data wshrpc.CommandJobPrepareConnectData) (*wshrpc.CommandJobConnectRtnData, error) { + jm.lock.Lock() + defer jm.lock.Unlock() + + if jm.Cmd == nil { + return nil, fmt.Errorf("job not started") + } + + err := jm.Cmd.SetTermSize(data.TermSize) + if err != nil { + log.Printf("PrepareConnect: failed to set term size: %v\n", err) + } + + rtnData := &wshrpc.CommandJobConnectRtnData{} + streamDone, streamError := jm.StreamManager.GetStreamDoneInfo() + + if streamDone { + log.Printf("PrepareConnect: stream already done, skipping connection streamError=%q\n", streamError) + rtnData.Seq = data.Seq + rtnData.StreamDone = true + rtnData.StreamError = streamError + } else { + corkedStreamMeta := data.StreamMeta + corkedStreamMeta.RWnd = 0 + serverSeq, err := jm.connectToStreamHelper_withlock(msc, corkedStreamMeta, data.Seq) + if err != nil { + return nil, err + } + jm.pendingStreamMeta = &data.StreamMeta + rtnData.Seq = serverSeq + rtnData.StreamDone = false + } + + hasExited, exitData := jm.Cmd.GetExitInfo() + if hasExited && exitData != nil { + rtnData.HasExited = true + rtnData.ExitCode = exitData.ExitCode + rtnData.ExitSignal = exitData.ExitSignal + rtnData.ExitErr = exitData.ExitErr + } + + log.Printf("PrepareConnect: streamid=%s clientSeq=%d serverSeq=%d streamDone=%v streamError=%q hasExited=%v\n", data.StreamMeta.Id, data.Seq, 
rtnData.Seq, rtnData.StreamDone, rtnData.StreamError, hasExited) + return rtnData, nil +} + +func (jm *JobManager) StartStream(msc *MainServerConn) error { + jm.lock.Lock() + defer jm.lock.Unlock() + + if jm.Cmd == nil { + return fmt.Errorf("job not started") + } + if jm.pendingStreamMeta == nil { + return fmt.Errorf("no pending stream (call PrepareConnect first)") + } + + err := msc.WshRpc.StreamBroker.AttachStreamWriter(jm.pendingStreamMeta, jm.StreamManager) + if err != nil { + return fmt.Errorf("failed to attach stream writer: %w", err) + } + + err = jm.StreamManager.SetRwndSize(int(jm.pendingStreamMeta.RWnd)) + if err != nil { + return fmt.Errorf("failed to set rwnd size: %w", err) + } + + log.Printf("StartStream: streamid=%s rwnd=%d streaming started\n", jm.pendingStreamMeta.Id, jm.pendingStreamMeta.RWnd) + jm.pendingStreamMeta = nil + return nil +} + func GetJobSocketPath(jobId string) string { socketDir := filepath.Join("/tmp", fmt.Sprintf("waveterm-%d", os.Getuid())) return filepath.Join(socketDir, fmt.Sprintf("%s.sock", jobId)) diff --git a/pkg/jobmanager/jobmanager_unix.go b/pkg/jobmanager/jobmanager_unix.go index bddbea3987..46cd0deec5 100644 --- a/pkg/jobmanager/jobmanager_unix.go +++ b/pkg/jobmanager/jobmanager_unix.go @@ -10,6 +10,7 @@ import ( "log" "os" "os/signal" + "path/filepath" "strings" "syscall" @@ -69,6 +70,12 @@ func daemonize(clientId string, jobId string) error { devNull.Close() logPath := GetJobFilePath(clientId, jobId, "log") + logDir := filepath.Dir(logPath) + err = os.MkdirAll(logDir, 0700) + if err != nil { + return fmt.Errorf("failed to create log directory: %w", err) + } + logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) if err != nil { return fmt.Errorf("failed to open log file: %w", err) diff --git a/pkg/jobmanager/mainserverconn.go b/pkg/jobmanager/mainserverconn.go index 8f10eed20c..de5965128d 100644 --- a/pkg/jobmanager/mainserverconn.go +++ b/pkg/jobmanager/mainserverconn.go @@ -8,11 +8,9 @@ import ( "fmt" "log" "net" - "os" "sync" "sync/atomic" - "github.com/shirou/gopsutil/v4/process" "github.com/wavetermdev/waveterm/pkg/baseds" "github.com/wavetermdev/waveterm/pkg/wavejwt" "github.com/wavetermdev/waveterm/pkg/wshrpc" @@ -93,14 +91,7 @@ func (msc *MainServerConn) AuthenticateToJobManagerCommand(ctx context.Context, return err } - WshCmdJobManager.lock.Lock() - defer WshCmdJobManager.lock.Unlock() - - if WshCmdJobManager.attachedClient != nil { - log.Printf("AuthenticateToJobManager: kicking out existing client\n") - WshCmdJobManager.attachedClient.Close() - } - WshCmdJobManager.attachedClient = msc + WshCmdJobManager.SetAttachedClient(msc) return nil } @@ -110,183 +101,35 @@ func (msc *MainServerConn) StartJobCommand(ctx context.Context, data wshrpc.Comm log.Printf("StartJobCommand: not authenticated") return nil, fmt.Errorf("not authenticated") } - if WshCmdJobManager.IsJobStarted() { - log.Printf("StartJobCommand: job already started") - return nil, fmt.Errorf("job already started") - } - - WshCmdJobManager.lock.Lock() - defer WshCmdJobManager.lock.Unlock() - - if WshCmdJobManager.Cmd != nil { - log.Printf("StartJobCommand: job already started (double check)") - return nil, fmt.Errorf("job already started") - } - - cmdDef := CmdDef{ - Cmd: data.Cmd, - Args: data.Args, - Env: data.Env, - TermSize: data.TermSize, - } - log.Printf("StartJobCommand: creating job cmd for jobid=%s", WshCmdJobManager.JobId) - jobCmd, err := MakeJobCmd(WshCmdJobManager.JobId, cmdDef) - if err != nil { - log.Printf("StartJobCommand: failed 
to make job cmd: %v", err) - return nil, fmt.Errorf("failed to start job: %w", err) - } - WshCmdJobManager.Cmd = jobCmd - log.Printf("StartJobCommand: job cmd created successfully") - - if data.StreamMeta != nil { - serverSeq, err := WshCmdJobManager.connectToStreamHelper_withlock(msc, *data.StreamMeta, 0) - if err != nil { - return nil, fmt.Errorf("failed to connect stream: %w", err) - } - err = msc.WshRpc.StreamBroker.AttachStreamWriter(data.StreamMeta, WshCmdJobManager.StreamManager) - if err != nil { - return nil, fmt.Errorf("failed to attach stream writer: %w", err) - } - log.Printf("StartJob: connected stream streamid=%s serverSeq=%d\n", data.StreamMeta.Id, serverSeq) - } - - _, cmdPty := jobCmd.GetCmd() - if cmdPty != nil { - log.Printf("StartJobCommand: attaching pty reader to stream manager") - err = WshCmdJobManager.StreamManager.AttachReader(cmdPty) - if err != nil { - log.Printf("StartJobCommand: failed to attach reader: %v", err) - return nil, fmt.Errorf("failed to attach reader to stream manager: %w", err) - } - log.Printf("StartJobCommand: pty reader attached successfully") - } else { - log.Printf("StartJobCommand: no pty to attach") - } - - cmd, _ := jobCmd.GetCmd() - if cmd == nil || cmd.Process == nil { - log.Printf("StartJobCommand: cmd or process is nil") - return nil, fmt.Errorf("cmd or process is nil") - } - cmdPid := cmd.Process.Pid - cmdProc, err := process.NewProcess(int32(cmdPid)) - if err != nil { - log.Printf("StartJobCommand: failed to get cmd process: %v", err) - return nil, fmt.Errorf("failed to get cmd process: %w", err) - } - cmdStartTs, err := cmdProc.CreateTime() - if err != nil { - log.Printf("StartJobCommand: failed to get cmd start time: %v", err) - return nil, fmt.Errorf("failed to get cmd start time: %w", err) - } - - jobManagerPid := os.Getpid() - jobManagerProc, err := process.NewProcess(int32(jobManagerPid)) - if err != nil { - log.Printf("StartJobCommand: failed to get job manager process: %v", err) - return nil, fmt.Errorf("failed to get job manager process: %w", err) - } - jobManagerStartTs, err := jobManagerProc.CreateTime() - if err != nil { - log.Printf("StartJobCommand: failed to get job manager start time: %v", err) - return nil, fmt.Errorf("failed to get job manager start time: %w", err) - } - - log.Printf("StartJobCommand: job started successfully cmdPid=%d cmdStartTs=%d jobManagerPid=%d jobManagerStartTs=%d", cmdPid, cmdStartTs, jobManagerPid, jobManagerStartTs) - return &wshrpc.CommandStartJobRtnData{ - CmdPid: cmdPid, - CmdStartTs: cmdStartTs, - JobManagerPid: jobManagerPid, - JobManagerStartTs: jobManagerStartTs, - }, nil + return WshCmdJobManager.StartJob(msc, data) } func (msc *MainServerConn) JobPrepareConnectCommand(ctx context.Context, data wshrpc.CommandJobPrepareConnectData) (*wshrpc.CommandJobConnectRtnData, error) { - WshCmdJobManager.lock.Lock() - defer WshCmdJobManager.lock.Unlock() - if !msc.PeerAuthenticated.Load() { return nil, fmt.Errorf("peer not authenticated") } if !msc.SelfAuthenticated.Load() { return nil, fmt.Errorf("not authenticated to server") } - if WshCmdJobManager.Cmd == nil { - return nil, fmt.Errorf("job not started") - } - - rtnData := &wshrpc.CommandJobConnectRtnData{} - streamDone, streamError := WshCmdJobManager.StreamManager.GetStreamDoneInfo() - - if streamDone { - log.Printf("JobPrepareConnect: stream already done, skipping connection streamError=%q\n", streamError) - rtnData.Seq = data.Seq - rtnData.StreamDone = true - rtnData.StreamError = streamError - } else { - corkedStreamMeta := 
data.StreamMeta - corkedStreamMeta.RWnd = 0 - serverSeq, err := WshCmdJobManager.connectToStreamHelper_withlock(msc, corkedStreamMeta, data.Seq) - if err != nil { - return nil, err - } - WshCmdJobManager.pendingStreamMeta = &data.StreamMeta - rtnData.Seq = serverSeq - rtnData.StreamDone = false - } - - hasExited, exitData := WshCmdJobManager.Cmd.GetExitInfo() - if hasExited && exitData != nil { - rtnData.HasExited = true - rtnData.ExitCode = exitData.ExitCode - rtnData.ExitSignal = exitData.ExitSignal - rtnData.ExitErr = exitData.ExitErr - } - - log.Printf("JobPrepareConnect: streamid=%s clientSeq=%d serverSeq=%d streamDone=%v streamError=%q hasExited=%v\n", data.StreamMeta.Id, data.Seq, rtnData.Seq, rtnData.StreamDone, rtnData.StreamError, hasExited) - return rtnData, nil + return WshCmdJobManager.PrepareConnect(msc, data) } func (msc *MainServerConn) JobStartStreamCommand(ctx context.Context, data wshrpc.CommandJobStartStreamData) error { - WshCmdJobManager.lock.Lock() - defer WshCmdJobManager.lock.Unlock() - if !msc.PeerAuthenticated.Load() { return fmt.Errorf("not authenticated") } - if WshCmdJobManager.Cmd == nil { - return fmt.Errorf("job not started") - } - if WshCmdJobManager.pendingStreamMeta == nil { - return fmt.Errorf("no pending stream (call JobPrepareConnect first)") - } - - err := msc.WshRpc.StreamBroker.AttachStreamWriter(WshCmdJobManager.pendingStreamMeta, WshCmdJobManager.StreamManager) - if err != nil { - return fmt.Errorf("failed to attach stream writer: %w", err) - } - - err = WshCmdJobManager.StreamManager.SetRwndSize(int(WshCmdJobManager.pendingStreamMeta.RWnd)) - if err != nil { - return fmt.Errorf("failed to set rwnd size: %w", err) - } - - log.Printf("JobStartStream: streamid=%s rwnd=%d streaming started\n", WshCmdJobManager.pendingStreamMeta.Id, WshCmdJobManager.pendingStreamMeta.RWnd) - WshCmdJobManager.pendingStreamMeta = nil - return nil + return WshCmdJobManager.StartStream(msc) } func (msc *MainServerConn) JobInputCommand(ctx context.Context, data wshrpc.CommandJobInputData) error { - WshCmdJobManager.lock.Lock() - defer WshCmdJobManager.lock.Unlock() - if !msc.PeerAuthenticated.Load() { return fmt.Errorf("not authenticated") } - if WshCmdJobManager.Cmd == nil { + if !WshCmdJobManager.IsJobStarted() { return fmt.Errorf("job not started") } - return WshCmdJobManager.Cmd.HandleInput(data) + WshCmdJobManager.InputQueue.QueueItem(data.InputSessionId, data.SeqNum, data) + return nil } diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index b042eb9693..ae31f2e555 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -27,6 +27,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/telemetry" "github.com/wavetermdev/waveterm/pkg/telemetry/telemetrydata" "github.com/wavetermdev/waveterm/pkg/userinput" + "github.com/wavetermdev/waveterm/pkg/util/envutil" "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" @@ -34,7 +35,9 @@ import ( "github.com/wavetermdev/waveterm/pkg/wconfig" "github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" "github.com/wavetermdev/waveterm/pkg/wshutil" + "github.com/wavetermdev/waveterm/pkg/wstore" "golang.org/x/crypto/ssh" "golang.org/x/mod/semver" ) @@ -65,7 +68,9 @@ var clientControllerMap = make(map[remote.SSHOpts]*SSHConn) var activeConnCounter = 
&atomic.Int32{} type SSHConn struct { - Lock *sync.Mutex + lock *sync.Mutex // this lock protects the fields in the struct from concurrent access + lifecycleLock *sync.Mutex // this protects the lifecycle from concurrent calls + Status string WshEnabled *atomic.Bool Opts *remote.SSHOpts @@ -77,7 +82,6 @@ type SSHConn struct { WshError string NoWshReason string WshVersion string - HasWaiter *atomic.Bool LastConnectTime int64 ActiveConnNum int } @@ -117,8 +121,8 @@ func GetNumSSHHasConnected() int { } func (conn *SSHConn) DeriveConnStatus() wshrpc.ConnStatus { - conn.Lock.Lock() - defer conn.Lock.Unlock() + conn.lock.Lock() + defer conn.lock.Unlock() return wshrpc.ConnStatus{ Status: conn.Status, Connected: conn.Status == Status_Connected, @@ -156,51 +160,78 @@ func (conn *SSHConn) FireConnChangeEvent() { } func (conn *SSHConn) Close() error { + conn.lifecycleLock.Lock() + defer conn.lifecycleLock.Unlock() + defer conn.FireConnChangeEvent() conn.WithLock(func() { if conn.Status == Status_Connected || conn.Status == Status_Connecting { // if status is init, disconnected, or error don't change it conn.Status = Status_Disconnected } - conn.close_nolock() }) - // we must wait for the waiter to complete - startTime := time.Now() - for conn.HasWaiter.Load() { - time.Sleep(10 * time.Millisecond) - if time.Since(startTime) > 2*time.Second { - return fmt.Errorf("timeout waiting for waiter to complete") - } - } + conn.closeInternal_withlifecyclelock() return nil } -func (conn *SSHConn) close_nolock() { +func (conn *SSHConn) closeInternal_withlifecyclelock() { // does not set status (that should happen at another level) - if conn.DomainSockListener != nil { - conn.DomainSockListener.Close() - conn.DomainSockListener = nil - conn.DomainSockName = "" + client := WithLockRtn(conn, func() *ssh.Client { + return conn.Client + }) + if client != nil { + // this MUST go first to force close the connection. 
+ // the DomainSockListener.Close() sends SSH protocol packets which can block on a dead network conn + startTime := time.Now() + client.Close() + duration := time.Since(startTime).Milliseconds() + if duration > 100 { + log.Printf("[conncontroller] conn:%s Client.Close() took %d ms", conn.GetName(), duration) + } + conn.WithLock(func() { + conn.Client = nil + }) } - if conn.ConnController != nil { - conn.ConnController.Close() - conn.ConnController = nil + listener := WithLockRtn(conn, func() net.Listener { + return conn.DomainSockListener + }) + if listener != nil { + startTime := time.Now() + listener.Close() + duration := time.Since(startTime).Milliseconds() + if duration > 100 { + log.Printf("[conncontroller] conn:%s DomainSockListener.Close() took %d ms", conn.GetName(), duration) + } + conn.WithLock(func() { + conn.DomainSockListener = nil + conn.DomainSockName = "" + }) } - if conn.Client != nil { - conn.Client.Close() - conn.Client = nil + controller := WithLockRtn(conn, func() *ssh.Session { + return conn.ConnController + }) + if controller != nil { + startTime := time.Now() + controller.Close() + duration := time.Since(startTime).Milliseconds() + if duration > 100 { + log.Printf("[conncontroller] conn:%s ConnController.Close() took %d ms", conn.GetName(), duration) + } + conn.WithLock(func() { + conn.ConnController = nil + }) } } func (conn *SSHConn) GetDomainSocketName() string { - conn.Lock.Lock() - defer conn.Lock.Unlock() + conn.lock.Lock() + defer conn.lock.Unlock() return conn.DomainSockName } func (conn *SSHConn) GetStatus() string { - conn.Lock.Lock() - defer conn.Lock.Unlock() + conn.lock.Lock() + defer conn.lock.Unlock() return conn.Status } @@ -266,6 +297,70 @@ func IsWshVersionUpToDate(logCtx context.Context, wshVersionLine string) (bool, return true, clientVersion, "", nil } +// for testing only -- trying to determine the env difference when attaching or not attaching a pty to an ssh session +func (conn *SSHConn) GetEnvironmentMaps(ctx context.Context) (map[string]string, map[string]string, error) { + client := conn.GetClient() + if client == nil { + return nil, nil, fmt.Errorf("ssh client is not connected") + } + + noPtyEnv, err := conn.getEnvironmentNoPty(ctx, client) + if err != nil { + return nil, nil, fmt.Errorf("error getting environment without PTY: %w", err) + } + + ptyEnv, err := conn.getEnvironmentWithPty(ctx, client) + if err != nil { + return nil, nil, fmt.Errorf("error getting environment with PTY: %w", err) + } + + return noPtyEnv, ptyEnv, nil +} + +func (conn *SSHConn) getEnvironmentNoPty(ctx context.Context, client *ssh.Client) (map[string]string, error) { + session, err := client.NewSession() + if err != nil { + return nil, fmt.Errorf("unable to create ssh session: %w", err) + } + defer session.Close() + + outputBuf := &strings.Builder{} + session.Stdout = outputBuf + session.Stderr = outputBuf + + err = session.Run("env -0") + if err != nil { + return nil, fmt.Errorf("error running env command: %w", err) + } + + return envutil.EnvToMap(outputBuf.String()), nil +} + +func (conn *SSHConn) getEnvironmentWithPty(ctx context.Context, client *ssh.Client) (map[string]string, error) { + session, err := client.NewSession() + if err != nil { + return nil, fmt.Errorf("unable to create ssh session: %w", err) + } + defer session.Close() + + termSize := waveobj.TermSize{Rows: 24, Cols: 80} + err = session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil) + if err != nil { + return nil, fmt.Errorf("unable to request PTY: %w", err) + } + + outputBuf 
:= &strings.Builder{} + session.Stdout = outputBuf + session.Stderr = outputBuf + + err = session.Run("env -0") + if err != nil { + return nil, fmt.Errorf("error running env command: %w", err) + } + + return envutil.EnvToMap(outputBuf.String()), nil +} + func (conn *SSHConn) getWshPath() string { config, ok := conn.getConnectionConfig() if ok && config.ConnWshPath != "" { @@ -422,11 +517,17 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool, useR conn.Infof(ctx, "connserver started, waiting for route to be registered\n") regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) defer cancelFn() - err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn)) + connRoute := wshutil.MakeConnectionRouteId(rpcCtx.Conn) + err = wshutil.DefaultRouter.WaitForRegister(regCtx, connRoute) if err != nil { return false, clientVersion, "", fmt.Errorf("timeout waiting for connserver to register") } time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready") + wshclient.ConnServerInitCommand( + wshclient.GetBareRpcClient(), + wshrpc.CommandConnServerInitData{ClientId: wstore.GetClientId()}, + &wshrpc.RpcOpts{Route: connRoute}, + ) conn.Infof(ctx, "connserver is registered and ready\n") return false, clientVersion, "", nil } @@ -534,8 +635,8 @@ func (conn *SSHConn) InstallWsh(ctx context.Context, osArchStr string) error { } func (conn *SSHConn) GetClient() *ssh.Client { - conn.Lock.Lock() - defer conn.Lock.Unlock() + conn.lock.Lock() + defer conn.lock.Unlock() return conn.Client } @@ -565,6 +666,9 @@ func (conn *SSHConn) WaitForConnect(ctx context.Context) error { // does not return an error since that error is stored inside of SSHConn func (conn *SSHConn) Connect(ctx context.Context, connFlags *wconfig.ConnKeywords) error { + conn.lifecycleLock.Lock() + defer conn.lifecycleLock.Unlock() + blocklogger.Infof(ctx, "\n") var connectAllowed bool conn.WithLock(func() { @@ -583,39 +687,41 @@ func (conn *SSHConn) Connect(ctx context.Context, connFlags *wconfig.ConnKeyword conn.Infof(ctx, "trying to connect to %q...\n", conn.GetName()) conn.FireConnChangeEvent() err := conn.connectInternal(ctx, connFlags) - conn.WithLock(func() { - if err != nil { - conn.Infof(ctx, "ERROR %v\n\n", err) + if err != nil { + conn.Infof(ctx, "ERROR %v\n\n", err) + conn.WithLock(func() { conn.Status = Status_Error conn.Error = err.Error() - conn.close_nolock() - telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{ - Conn: map[string]int{"ssh:connecterror": 1}, - }, "ssh-connconnect") - telemetry.GoRecordTEventWrap(&telemetrydata.TEvent{ - Event: "conn:connecterror", - Props: telemetrydata.TEventProps{ - ConnType: "ssh", - }, - }) - } else { - conn.Infof(ctx, "successfully connected (wsh:%v)\n\n", conn.WshEnabled.Load()) + }) + conn.closeInternal_withlifecyclelock() + telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{ + Conn: map[string]int{"ssh:connecterror": 1}, + }, "ssh-connconnect") + telemetry.GoRecordTEventWrap(&telemetrydata.TEvent{ + Event: "conn:connecterror", + Props: telemetrydata.TEventProps{ + ConnType: "ssh", + }, + }) + } else { + conn.Infof(ctx, "successfully connected (wsh:%v)\n\n", conn.WshEnabled.Load()) + conn.WithLock(func() { conn.Status = Status_Connected conn.LastConnectTime = time.Now().UnixMilli() if conn.ActiveConnNum == 0 { conn.ActiveConnNum = int(activeConnCounter.Add(1)) } - telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{ - Conn: map[string]int{"ssh:connect": 1}, - 
}, "ssh-connconnect") - telemetry.GoRecordTEventWrap(&telemetrydata.TEvent{ - Event: "conn:connect", - Props: telemetrydata.TEventProps{ - ConnType: "ssh", - }, - }) - } - }) + }) + telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{ + Conn: map[string]int{"ssh:connect": 1}, + }, "ssh-connconnect") + telemetry.GoRecordTEventWrap(&telemetrydata.TEvent{ + Event: "conn:connect", + Props: telemetrydata.TEventProps{ + ConnType: "ssh", + }, + }) + } conn.FireConnChangeEvent() if err != nil { return err @@ -652,14 +758,14 @@ func (conn *SSHConn) Connect(ctx context.Context, connFlags *wconfig.ConnKeyword } func (conn *SSHConn) WithLock(fn func()) { - conn.Lock.Lock() - defer conn.Lock.Unlock() + conn.lock.Lock() + defer conn.lock.Unlock() fn() } func WithLockRtn[T any](conn *SSHConn, fn func() T) T { - conn.Lock.Lock() - defer conn.Lock.Unlock() + conn.lock.Lock() + defer conn.lock.Unlock() return fn() } @@ -716,7 +822,7 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string) err = fmt.Errorf("error opening domain socket listener: %w", err) return WshCheckResult{NoWshReason: "error opening domain socket", NoWshCode: NoWshCode_DomainSocketError, WshError: err} } - needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false, false) + needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false, true) if err != nil { conn.Infof(ctx, "ERROR starting conn server: %v\n", err) err = fmt.Errorf("error starting conn server: %w", err) @@ -730,7 +836,7 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string) err = fmt.Errorf("error installing wsh: %w", err) return WshCheckResult{NoWshReason: "error installing wsh/connserver", NoWshCode: NoWshCode_InstallError, WshError: err} } - needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true, false) + needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true, true) if err != nil { conn.Infof(ctx, "ERROR starting conn server (after install): %v\n", err) err = fmt.Errorf("error starting conn server (after install): %w", err) @@ -820,12 +926,13 @@ func (conn *SSHConn) connectInternal(ctx context.Context, connFlags *wconfig.Con func (conn *SSHConn) waitForDisconnect() { defer conn.FireConnChangeEvent() - defer conn.HasWaiter.Store(false) client := conn.GetClient() if client == nil { return } err := client.Wait() + conn.lifecycleLock.Lock() + defer conn.lifecycleLock.Unlock() conn.WithLock(func() { // disconnects happen for a variety of reasons (like network, etc. 
and are typically transient) // so we just set the status to "disconnected" here (not error) @@ -836,8 +943,8 @@ func (conn *SSHConn) waitForDisconnect() { if conn.Status != Status_Error { conn.Status = Status_Disconnected } - conn.close_nolock() }) + conn.closeInternal_withlifecyclelock() } func (conn *SSHConn) SetWshError(err error) { @@ -861,7 +968,13 @@ func getConnInternal(opts *remote.SSHOpts, createIfNotExists bool) *SSHConn { defer globalLock.Unlock() rtn := clientControllerMap[*opts] if rtn == nil && createIfNotExists { - rtn = &SSHConn{Lock: &sync.Mutex{}, Status: Status_Init, WshEnabled: &atomic.Bool{}, Opts: opts, HasWaiter: &atomic.Bool{}} + rtn = &SSHConn{ + lock: &sync.Mutex{}, + lifecycleLock: &sync.Mutex{}, + Status: Status_Init, + WshEnabled: &atomic.Bool{}, + Opts: opts, + } clientControllerMap[*opts] = rtn } return rtn diff --git a/pkg/service/workspaceservice/workspaceservice.go b/pkg/service/workspaceservice/workspaceservice.go index 152aa09eb6..c0d5072a48 100644 --- a/pkg/service/workspaceservice/workspaceservice.go +++ b/pkg/service/workspaceservice/workspaceservice.go @@ -240,7 +240,7 @@ func (svc *WorkspaceService) CloseTab(ctx context.Context, workspaceId string, t if err == nil && tab != nil { go func() { for _, blockId := range tab.BlockIds { - blockcontroller.StopBlockController(blockId) + blockcontroller.DestroyBlockController(blockId) } }() } diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go index d08250a4d8..7f9d6498c1 100644 --- a/pkg/shellexec/shellexec.go +++ b/pkg/shellexec/shellexec.go @@ -11,6 +11,7 @@ import ( "log" "os" "os/exec" + "path/filepath" "runtime" "strings" "sync" @@ -21,6 +22,7 @@ import ( "github.com/creack/pty" "github.com/wavetermdev/waveterm/pkg/blocklogger" + "github.com/wavetermdev/waveterm/pkg/jobcontroller" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/remote/conncontroller" "github.com/wavetermdev/waveterm/pkg/util/pamparse" @@ -459,6 +461,113 @@ func StartRemoteShellProc(ctx context.Context, logCtx context.Context, termSize return &ShellProc{Cmd: sessionWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil } +func StartRemoteShellJob(ctx context.Context, logCtx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (string, error) { + connRoute := wshutil.MakeConnectionRouteId(conn.GetName()) + rpcClient := wshclient.GetBareRpcClient() + remoteInfo, err := wshclient.RemoteGetInfoCommand(rpcClient, &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000}) + if err != nil { + return "", fmt.Errorf("unable to obtain client info: %w", err) + } + log.Printf("client info collected: %+#v", remoteInfo) + var shellPath string + if cmdOpts.ShellPath != "" { + conn.Infof(logCtx, "using shell path from command opts: %s\n", cmdOpts.ShellPath) + shellPath = cmdOpts.ShellPath + } + configShellPath := conn.GetConfigShellPath() + if shellPath == "" && configShellPath != "" { + conn.Infof(logCtx, "using shell path from config (conn:shellpath): %s\n", configShellPath) + shellPath = configShellPath + } + if shellPath == "" && remoteInfo.Shell != "" { + conn.Infof(logCtx, "using shell path detected on remote machine: %s\n", remoteInfo.Shell) + shellPath = remoteInfo.Shell + } + if shellPath == "" { + conn.Infof(logCtx, "no shell path detected, using default (/bin/bash)\n") + shellPath = "/bin/bash" + } + var shellOpts []string + log.Printf("detected shell %q for conn %q\n", shellPath, conn.GetName()) + 
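// shellPath is resolved above with a fixed precedence: cmdOpts.ShellPath, then the conn:shellpath config value, then the shell detected on the remote machine, then /bin/bash as a final fallback.
+ // shellOpts is assembled below based on the detected shell type: --rcfile for bash, -C "source wave.fish" for fish, wavepwsh.ps1 via -File for pwsh, or plain -l/-i flags otherwise; when cmdStr is set, it is just -c cmdStr.
+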
shellOpts = append(shellOpts, cmdOpts.ShellOpts...) + shellType := shellutil.GetShellTypeFromShellPath(shellPath) + conn.Infof(logCtx, "detected shell type: %s\n", shellType) + conn.Debugf(logCtx, "cmdStr: %q\n", cmdStr) + + if cmdStr == "" { + if shellType == shellutil.ShellType_bash { + bashPath := fmt.Sprintf("~/.waveterm/%s/.bashrc", shellutil.BashIntegrationDir) + shellOpts = append(shellOpts, "--rcfile", bashPath) + } else if shellType == shellutil.ShellType_fish { + if cmdOpts.Login { + shellOpts = append(shellOpts, "-l") + } + waveFishPath := fmt.Sprintf("~/.waveterm/%s/wave.fish", shellutil.FishIntegrationDir) + carg := fmt.Sprintf(`"source %s"`, waveFishPath) + shellOpts = append(shellOpts, "-C", carg) + } else if shellType == shellutil.ShellType_pwsh { + pwshPath := fmt.Sprintf("~/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir) + shellPath = "& " + shellPath + shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", pwshPath) + } else { + if cmdOpts.Login { + shellOpts = append(shellOpts, "-l") + } + if cmdOpts.Interactive { + shellOpts = append(shellOpts, "-i") + } + } + } else { + shellOpts = append(shellOpts, "-c", cmdStr) + } + conn.Infof(logCtx, "starting shell job, using command: %s %s\n", shellPath, strings.Join(shellOpts, " ")) + + if termSize.Rows == 0 || termSize.Cols == 0 { + termSize.Rows = shellutil.DefaultTermRows + termSize.Cols = shellutil.DefaultTermCols + } + if termSize.Rows <= 0 || termSize.Cols <= 0 { + return "", fmt.Errorf("invalid term size: %v", termSize) + } + + env := make(map[string]string) + env["TERM"] = shellutil.DefaultTermType + if shellType == shellutil.ShellType_zsh { + zshDir := filepath.Join(remoteInfo.HomeDir, ".waveterm", shellutil.ZshIntegrationDir) + conn.Infof(logCtx, "setting ZDOTDIR to %s\n", zshDir) + env["ZDOTDIR"] = zshDir + } + if cmdOpts.SwapToken != nil { + packedToken, err := cmdOpts.SwapToken.PackForClient() + if err != nil { + conn.Infof(logCtx, "error packing swap token: %v", err) + } else { + conn.Debugf(logCtx, "packed swaptoken %s\n", packedToken) + env[wavebase.WaveSwapTokenVarName] = packedToken + } + jwtToken := cmdOpts.SwapToken.Env[wavebase.WaveJwtTokenVarName] + if jwtToken != "" && cmdOpts.ForceJwt { + conn.Debugf(logCtx, "adding JWT token to environment\n") + env[wavebase.WaveJwtTokenVarName] = jwtToken + } + shellutil.AddTokenSwapEntry(cmdOpts.SwapToken) + } + + jobParams := jobcontroller.StartJobParams{ + ConnName: conn.GetName(), + Cmd: shellPath, + Args: shellOpts, + Env: env, + TermSize: &termSize, + } + jobId, err := jobcontroller.StartJob(ctx, jobParams) + if err != nil { + return "", fmt.Errorf("failed to start job: %w", err) + } + conn.Infof(logCtx, "started job: %s\n", jobId) + return jobId, nil +} + func StartLocalShellProc(logCtx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, connName string) (*ShellProc, error) { shellutil.InitCustomShellStartupFiles() var ecmd *exec.Cmd diff --git a/pkg/util/envutil/envutil.go b/pkg/util/envutil/envutil.go index dff40c1842..61dab43ba0 100644 --- a/pkg/util/envutil/envutil.go +++ b/pkg/util/envutil/envutil.go @@ -91,3 +91,39 @@ func EnvToSlice(envStr string) []string { } return result } + +func SliceToMap(env []string) map[string]string { + envMap := make(map[string]string) + for _, envVar := range env { + parts := strings.SplitN(envVar, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + return envMap +} + +func CopyAndAddToEnvMap(envMap map[string]string, key string, val string) 
map[string]string { + newMap := make(map[string]string, len(envMap)+1) + for k, v := range envMap { + newMap[k] = v + } + newMap[key] = val + return newMap +} + +func PruneInitialEnv(envMap map[string]string) map[string]string { + pruned := make(map[string]string) + for key, value := range envMap { + if strings.HasPrefix(key, "WAVETERM_") || strings.HasPrefix(key, "BASH_FUNC_") { + continue + } + if key == "XDG_SESSION_ID" || key == "SHLVL" || key == "S_COLORS" || + key == "SSH_CONNECTION" || key == "SSH_CLIENT" || key == "LESSOPEN" || + key == "which_declare" { + continue + } + pruned[key] = value + } + return pruned +} diff --git a/pkg/utilds/quickreorderqueue.go b/pkg/utilds/quickreorderqueue.go new file mode 100644 index 0000000000..b1da9c0afb --- /dev/null +++ b/pkg/utilds/quickreorderqueue.go @@ -0,0 +1,260 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package utilds + +import ( + "fmt" + "sort" + "sync" + "time" +) + +// the quick reorder queue implements reordering of items within a certain timeframe (the timeout passed in) +// if an item is queued in order, it gets processed immediately +// if it comes in out of order, it gets buffered for up to the timeout while we wait for the correct next seq to come in +// if we still haven't received the "correct" next seq within the timeout, the out-of-order event is flushed. +// "old" events (less than the current nextseq) are flushed immediately +// +// we also implement a "session" system. each session is assigned a virtual order based on the timestamp +// it was first seen. so all events of a session are either "before" or "after" all the events of a different session. +// the assumption is that sessions will always be separated by an amount of time greater than the timeout of the reorder queue (e.g. a system reboot, or main server restart) +// +// enqueuing without a sessionid, or with a seqNum of 0, bypasses the reorder queue and just flushes the event + +type queuedItem[T any] struct { + sessionId string + seqNum int + data T + timestamp time.Time +} + +type QuickReorderQueue[T any] struct { + lock sync.Mutex + sessionOrder map[string]int64 // sessionId -> timestamp millis when first seen + currentSessionId string + nextSeqNum int + buffer []queuedItem[T] + outCh chan T + timeout time.Duration + timer *time.Timer + closed bool +} + +func MakeQuickReorderQueue[T any](bufSize int, timeout time.Duration) *QuickReorderQueue[T] { + return &QuickReorderQueue[T]{ + sessionOrder: make(map[string]int64), + nextSeqNum: 1, + outCh: make(chan T, bufSize), + timeout: timeout, + } +} + +func (q *QuickReorderQueue[T]) C() <-chan T { + return q.outCh +} + +func (q *QuickReorderQueue[T]) SetNextSeqNum(seqNum int) { + q.lock.Lock() + defer q.lock.Unlock() + q.nextSeqNum = seqNum +} + +func (q *QuickReorderQueue[T]) ensureSessionTs_withlock(sessionId string) { + if sessionId == "" { + return + } + if _, ok := q.sessionOrder[sessionId]; ok { + return + } + ts := time.Now().UnixMilli() + q.sessionOrder[sessionId] = ts + q.flushBuffer_withlock() + q.currentSessionId = sessionId + q.nextSeqNum = 1 +} + +func (q *QuickReorderQueue[T]) cmpSessionSeq_withlock(session1 string, seq1 int, session2 string, seq2 int) int { + ts1 := q.sessionOrder[session1] + ts2 := q.sessionOrder[session2] + if ts1 < ts2 { + return -1 + } + if ts1 > ts2 { + return 1 + } + if seq1 < seq2 { + return -1 + } + if seq1 > seq2 { + return 1 + } + return 0 +} + +func (q *QuickReorderQueue[T]) sortBuffer_withlock() { + sort.Slice(q.buffer, func(i, j int) bool { + return q.cmpSessionSeq_withlock(q.buffer[i].sessionId, q.buffer[i].seqNum, q.buffer[j].sessionId, q.buffer[j].seqNum) < 0 + }) +} + +func (q *QuickReorderQueue[T]) flushBuffer_withlock() { + if len(q.buffer) == 0 { + return + } + q.sortBuffer_withlock() + for _, item := range q.buffer { + q.outCh <- item.data + } + q.buffer = nil + if q.timer != nil { + q.timer.Stop() + q.timer = nil + } +} + +func (q *QuickReorderQueue[T]) QueueItem(sessionId string, seqNum int, data T) error { + q.lock.Lock() + defer q.lock.Unlock() + + if q.closed { + return fmt.Errorf("ReorderQueue is closed, cannot queue new item") + } + + if len(q.buffer)+len(q.outCh) >= cap(q.outCh) { + return fmt.Errorf("queue is full, cannot accept new items, cap: %d", cap(q.outCh)) + } + + q.ensureSessionTs_withlock(sessionId) + + cmp := q.cmpSessionSeq_withlock(sessionId, seqNum, q.currentSessionId, q.nextSeqNum) + + if cmp < 0 || seqNum == 0 || sessionId == "" { + q.outCh <- data + return nil + } + + if cmp == 0 { + q.outCh <- data + q.nextSeqNum++ + q.processBuffer_withlock() + return nil + } + + q.buffer = append(q.buffer, queuedItem[T]{ + sessionId: sessionId, + seqNum: seqNum, + data: data, + timestamp: time.Now(), + }) + q.ensureTimer_withlock() + return nil +} + +func (q *QuickReorderQueue[T]) processBuffer_withlock() { + if len(q.buffer) == 0 { + return + } + + q.sortBuffer_withlock() + + enqueued := 0 + for i, item := range q.buffer { + if item.sessionId == q.currentSessionId && item.seqNum == q.nextSeqNum { + q.outCh <- item.data + q.nextSeqNum++ + enqueued = i + 1 + } else { + break + } + } + + if enqueued > 0 { + q.buffer = q.buffer[enqueued:] + } +} + +func (q *QuickReorderQueue[T]) ensureTimer_withlock() { + if q.timer != nil { + return + } + q.timer = time.AfterFunc(q.timeout, func() { + q.onTimeout() + })
+} + +func (q *QuickReorderQueue[T]) onTimeout() { + q.lock.Lock() + defer q.lock.Unlock() + + if q.closed { + return + } + + q.timer = nil + + if len(q.buffer) == 0 { + return + } + + now := time.Now() + + q.sortBuffer_withlock() + + highestTimedOutSeq := -1 + for _, item := range q.buffer { + if now.Sub(item.timestamp) >= q.timeout { + highestTimedOutSeq = item.seqNum + } + } + + if highestTimedOutSeq == -1 { + return + } + + enqueued := 0 + for i, item := range q.buffer { + if item.seqNum <= highestTimedOutSeq { + q.outCh <- item.data + enqueued = i + 1 + if item.seqNum >= q.nextSeqNum { + q.nextSeqNum = item.seqNum + 1 + } + } else { + break + } + } + + q.buffer = q.buffer[enqueued:] + + if len(q.buffer) > 0 { + oldestTime := q.buffer[0].timestamp + for _, item := range q.buffer[1:] { + if item.timestamp.Before(oldestTime) { + oldestTime = item.timestamp + } + } + nextTimeout := q.timeout - now.Sub(oldestTime) + if nextTimeout < 0 { + nextTimeout = 0 + } + q.timer = time.AfterFunc(nextTimeout, func() { + q.onTimeout() + }) + } +} + +func (q *QuickReorderQueue[T]) Close() { + q.lock.Lock() + defer q.lock.Unlock() + + if q.closed { + return + } + q.closed = true + if q.timer != nil { + q.timer.Stop() + q.timer = nil + } + close(q.outCh) +} diff --git a/pkg/utilds/quickreorderqueue_test.go b/pkg/utilds/quickreorderqueue_test.go new file mode 100644 index 0000000000..1fb5131854 --- /dev/null +++ b/pkg/utilds/quickreorderqueue_test.go @@ -0,0 +1,443 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package utilds + +import ( + "testing" + "time" +) + +func collectItems[T any](ch <-chan T, count int, timeout time.Duration) []T { + result := make([]T, 0, count) + timer := time.NewTimer(timeout) + defer timer.Stop() + + for i := 0; i < count; i++ { + select { + case item := <-ch: + result = append(result, item) + case <-timer.C: + return result + } + } + return result +} + +func TestQuickReorderQueue_InOrder(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 100*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, "item1") + q.QueueItem("session1", 2, "item2") + q.QueueItem("session1", 3, "item3") + + items := collectItems(q.C(), 3, 500*time.Millisecond) + + if len(items) != 3 { + t.Fatalf("expected 3 items, got %d", len(items)) + } + if items[0] != "item1" || items[1] != "item2" || items[2] != "item3" { + t.Errorf("expected [item1, item2, item3], got %v", items) + } +} + +func TestQuickReorderQueue_OutOfOrder(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 200*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, "item1") + q.QueueItem("session1", 3, "item3") + q.QueueItem("session1", 2, "item2") + + items := collectItems(q.C(), 3, 500*time.Millisecond) + + if len(items) != 3 { + t.Fatalf("expected 3 items, got %d", len(items)) + } + if items[0] != "item1" || items[1] != "item2" || items[2] != "item3" { + t.Errorf("expected [item1, item2, item3], got %v", items) + } +} + +func TestQuickReorderQueue_MultipleOutOfOrder(t *testing.T) { + q := MakeQuickReorderQueue[int](10, 200*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, 1) + q.QueueItem("session1", 5, 5) + q.QueueItem("session1", 3, 3) + q.QueueItem("session1", 2, 2) + q.QueueItem("session1", 4, 4) + + items := collectItems(q.C(), 5, 500*time.Millisecond) + + if len(items) != 5 { + t.Fatalf("expected 5 items, got %d", len(items)) + } + for i := 0; i < 5; i++ { + if items[i] != i+1 { + t.Errorf("expected item %d at position %d, got %d", i+1, i, 
items[i]) + } + } +} + +func TestQuickReorderQueue_TwoSessions_StrongSeparation(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 200*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, "s1-1") + q.QueueItem("session1", 2, "s1-2") + q.QueueItem("session1", 3, "s1-3") + + time.Sleep(500 * time.Millisecond) + + q.QueueItem("session2", 1, "s2-1") + q.QueueItem("session2", 2, "s2-2") + + items := collectItems(q.C(), 5, 500*time.Millisecond) + + if len(items) != 5 { + t.Fatalf("expected 5 items, got %d", len(items)) + } + + expected := []string{"s1-1", "s1-2", "s1-3", "s2-1", "s2-2"} + for i, exp := range expected { + if items[i] != exp { + t.Errorf("expected %s at position %d, got %s", exp, i, items[i]) + } + } +} + +func TestQuickReorderQueue_TwoSessions_OutOfOrder(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 200*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, "s1-1") + q.QueueItem("session1", 3, "s1-3") + + time.Sleep(500 * time.Millisecond) + + q.QueueItem("session2", 1, "s2-1") + q.QueueItem("session1", 2, "s1-2") + q.QueueItem("session2", 3, "s2-3") + q.QueueItem("session2", 2, "s2-2") + + items := collectItems(q.C(), 6, 500*time.Millisecond) + + if len(items) != 6 { + t.Fatalf("expected 6 items, got %d", len(items)) + } + + expected := []string{"s1-1", "s1-3", "s2-1", "s1-2", "s2-2", "s2-3"} + for i, exp := range expected { + if items[i] != exp { + t.Errorf("expected %s at position %d, got %s", exp, i, items[i]) + } + } +} + +func TestQuickReorderQueue_ThreeSessions_Sequential(t *testing.T) { + q := MakeQuickReorderQueue[string](20, 200*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, "s1-1") + q.QueueItem("session1", 2, "s1-2") + + time.Sleep(500 * time.Millisecond) + + q.QueueItem("session2", 1, "s2-1") + q.QueueItem("session2", 2, "s2-2") + + time.Sleep(500 * time.Millisecond) + + q.QueueItem("session3", 1, "s3-1") + q.QueueItem("session3", 2, "s3-2") + + items := collectItems(q.C(), 6, 1*time.Second) + + if len(items) != 6 { + t.Fatalf("expected 6 items, got %d", len(items)) + } + + expected := []string{"s1-1", "s1-2", "s2-1", "s2-2", "s3-1", "s3-2"} + for i, exp := range expected { + if items[i] != exp { + t.Errorf("expected %s at position %d, got %s", exp, i, items[i]) + } + } +} + +func TestQuickReorderQueue_SimpleTimeout(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 50*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, "item1") + q.QueueItem("session1", 3, "item3") + + time.Sleep(100 * time.Millisecond) + + items := collectItems(q.C(), 2, 100*time.Millisecond) + + if len(items) != 2 { + t.Fatalf("expected 2 items after timeout, got %d", len(items)) + } + if items[0] != "item1" { + t.Errorf("expected item1 first, got %s", items[0]) + } + if items[1] != "item3" { + t.Errorf("expected item3 second (due to timeout), got %s", items[1]) + } + + q.QueueItem("session1", 5, "item5") + q.QueueItem("session1", 4, "item4") + + time.Sleep(100 * time.Millisecond) + + items2 := collectItems(q.C(), 2, 100*time.Millisecond) + + if len(items2) != 2 { + t.Fatalf("expected 2 more items after second timeout, got %d", len(items2)) + } + if items2[0] != "item4" || items2[1] != "item5" { + t.Errorf("expected [item4, item5] after reordering, got %v", items2) + } +} + +func TestQuickReorderQueue_RollingTimeout(t *testing.T) { + q := MakeQuickReorderQueue[string](20, 50*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, "item1") + time.Sleep(10 * time.Millisecond) + + q.QueueItem("session1", 5, 
"item5") + time.Sleep(10 * time.Millisecond) + + q.QueueItem("session1", 3, "item3") + time.Sleep(10 * time.Millisecond) + + q.QueueItem("session1", 2, "item2") + time.Sleep(10 * time.Millisecond) + + q.QueueItem("session1", 4, "item4") + time.Sleep(10 * time.Millisecond) + + q.QueueItem("session1", 7, "item7") + time.Sleep(10 * time.Millisecond) + + q.QueueItem("session1", 6, "item6") + + time.Sleep(100 * time.Millisecond) + + items := collectItems(q.C(), 7, 200*time.Millisecond) + + if len(items) != 7 { + t.Fatalf("expected 7 items, got %d: %v", len(items), items) + } + + expected := []string{"item1", "item2", "item3", "item4", "item5", "item6", "item7"} + for i, exp := range expected { + if items[i] != exp { + t.Errorf("expected %s at position %d, got %s. Full output: %v", exp, i, items[i], items) + } + } +} + +func TestQuickReorderQueue_Timeout(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 150*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, "item1") + q.QueueItem("session1", 3, "item3") + + time.Sleep(200 * time.Millisecond) + + items := collectItems(q.C(), 2, 100*time.Millisecond) + + if len(items) != 2 { + t.Fatalf("expected 2 items after timeout, got %d", len(items)) + } + if items[0] != "item1" { + t.Errorf("expected item1 first, got %s", items[0]) + } + if items[1] != "item3" { + t.Errorf("expected item3 second (due to timeout), got %s", items[1]) + } +} + +func TestQuickReorderQueue_TimeoutWithLateArrival(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 100*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, "item1") + q.QueueItem("session1", 3, "item3") + + time.Sleep(150 * time.Millisecond) + + items := collectItems(q.C(), 2, 100*time.Millisecond) + + if len(items) != 2 { + t.Fatalf("expected 2 items after timeout, got %d", len(items)) + } + + q.QueueItem("session1", 2, "item2") + + lateItem := collectItems(q.C(), 1, 100*time.Millisecond) + if len(lateItem) != 1 { + t.Fatalf("expected 1 late item, got %d", len(lateItem)) + } + if lateItem[0] != "item2" { + t.Errorf("expected item2, got %s", lateItem[0]) + } +} + +func TestQuickReorderQueue_SessionOverlap_SmallWindow(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 200*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, "s1-1") + q.QueueItem("session1", 2, "s1-2") + q.QueueItem("session1", 3, "s1-3") + + time.Sleep(500 * time.Millisecond) + + q.QueueItem("session2", 1, "s2-1") + + time.Sleep(50 * time.Millisecond) + + q.QueueItem("session1", 4, "s1-4") + q.QueueItem("session2", 2, "s2-2") + + items := collectItems(q.C(), 6, 500*time.Millisecond) + + if len(items) != 6 { + t.Fatalf("expected 6 items, got %d", len(items)) + } + + expected := []string{"s1-1", "s1-2", "s1-3", "s2-1", "s1-4", "s2-2"} + for i, exp := range expected { + if items[i] != exp { + t.Errorf("expected %s at position %d, got %s", exp, i, items[i]) + } + } +} + +func TestQuickReorderQueue_DuplicateSequence(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 200*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, "item1-first") + q.QueueItem("session1", 2, "item2") + q.QueueItem("session1", 1, "item1-duplicate") + + items := collectItems(q.C(), 3, 500*time.Millisecond) + + if len(items) != 3 { + t.Fatalf("expected 3 items, got %d", len(items)) + } + if items[0] != "item1-first" || items[1] != "item2" || items[2] != "item1-duplicate" { + t.Errorf("got %v", items) + } +} + +func TestQuickReorderQueue_SetNextSeqNum(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 
200*time.Millisecond) + defer q.Close() + + q.SetNextSeqNum(5) + + q.QueueItem("session1", 5, "item5") + q.QueueItem("session1", 6, "item6") + q.QueueItem("session1", 7, "item7") + + items := collectItems(q.C(), 3, 500*time.Millisecond) + + if len(items) != 3 { + t.Fatalf("expected 3 items, got %d", len(items)) + } + if items[0] != "item5" || items[1] != "item6" || items[2] != "item7" { + t.Errorf("expected [item5, item6, item7], got %v", items) + } +} + +func TestQuickReorderQueue_EmptyBuffer(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 200*time.Millisecond) + defer q.Close() + + select { + case <-q.C(): + t.Error("should not have any items") + case <-time.After(50 * time.Millisecond): + } +} + +func TestQuickReorderQueue_Close(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 200*time.Millisecond) + + q.QueueItem("session1", 1, "item1") + + q.Close() + + _, ok := <-q.C() + if !ok { + t.Error("expected to read item1 before close") + } + + _, ok = <-q.C() + if ok { + t.Error("channel should be closed") + } +} + +func TestQuickReorderQueue_CloseWithBufferedItems(t *testing.T) { + q := MakeQuickReorderQueue[string](10, 200*time.Millisecond) + + q.QueueItem("session1", 1, "item1") + q.QueueItem("session1", 3, "item3") + + q.Close() + + item, ok := <-q.C() + if !ok || item != "item1" { + t.Errorf("expected item1, got %s (ok=%v)", item, ok) + } + + _, ok = <-q.C() + if ok { + t.Error("channel should be closed, item3 should be dropped as buffered") + } +} + +func TestQuickReorderQueue_MultiSessionComplexReordering(t *testing.T) { + q := MakeQuickReorderQueue[string](20, 300*time.Millisecond) + defer q.Close() + + q.QueueItem("session1", 1, "s1-1") + q.QueueItem("session1", 4, "s1-4") + q.QueueItem("session1", 2, "s1-2") + + time.Sleep(500 * time.Millisecond) + + q.QueueItem("session2", 2, "s2-2") + q.QueueItem("session2", 1, "s2-1") + q.QueueItem("session1", 3, "s1-3") + + time.Sleep(500 * time.Millisecond) + + q.QueueItem("session3", 1, "s3-1") + q.QueueItem("session2", 3, "s2-3") + + items := collectItems(q.C(), 8, 1*time.Second) + + if len(items) != 8 { + t.Fatalf("expected 8 items, got %d", len(items)) + } + + expected := []string{"s1-1", "s1-2", "s1-4", "s2-1", "s2-2", "s1-3", "s3-1", "s2-3"} + for i, exp := range expected { + if items[i] != exp { + t.Errorf("expected %s at position %d, got %s", exp, i, items[i]) + } + } +} diff --git a/pkg/utilds/versionts.go b/pkg/utilds/versionts.go new file mode 100644 index 0000000000..fafb1da520 --- /dev/null +++ b/pkg/utilds/versionts.go @@ -0,0 +1,24 @@ +package utilds + +import ( + "sync" + "time" +) + +type VersionTs struct { + lock sync.Mutex + lastVersion int64 +} + +func (v *VersionTs) GetVersionTs() int64 { + v.lock.Lock() + defer v.lock.Unlock() + + nowMs := time.Now().UnixMilli() + if nowMs <= v.lastVersion { + v.lastVersion++ + return v.lastVersion + } + v.lastVersion = nowMs + return v.lastVersion +} diff --git a/pkg/waveapp/waveapp.go b/pkg/waveapp/waveapp.go index 339e09403b..9e9b5067f2 100644 --- a/pkg/waveapp/waveapp.go +++ b/pkg/waveapp/waveapp.go @@ -181,11 +181,11 @@ func (client *Client) Connect() error { return fmt.Errorf("error setting up domain socket rpc client: %v", err) } client.RpcClient = rpcClient - _, err = wshclient.AuthenticateCommand(client.RpcClient, jwtToken, &wshrpc.RpcOpts{Route: wshutil.ControlRoute}) + authRtnData, err := wshclient.AuthenticateCommand(client.RpcClient, jwtToken, &wshrpc.RpcOpts{Route: wshutil.ControlRoute}) if err != nil { return fmt.Errorf("error authenticating rpc connection: 
%v", err) } - client.RouteId = rpcCtx.RouteId + client.RouteId = authRtnData.RouteId return nil } diff --git a/pkg/wavebase/wavebase.go b/pkg/wavebase/wavebase.go index bdaff9b137..4043b0c378 100644 --- a/pkg/wavebase/wavebase.go +++ b/pkg/wavebase/wavebase.go @@ -181,6 +181,10 @@ func GetDomainSocketName() string { return filepath.Join(GetWaveDataDir(), DomainSocketBaseName) } +func GetPersistentRemoteSockName(clientId string) string { + return filepath.Join("~", ".waveterm", "client", clientId, "waveterm.sock") +} + func EnsureWaveDataDir() error { return CacheEnsureDir(GetWaveDataDir(), "wavehome", 0700, "wave home directory") } diff --git a/pkg/wavejwt/wavejwt.go b/pkg/wavejwt/wavejwt.go index 9e91003c58..c406f1a170 100644 --- a/pkg/wavejwt/wavejwt.go +++ b/pkg/wavejwt/wavejwt.go @@ -29,6 +29,7 @@ type WaveJwtClaims struct { MainServer bool `json:"mainserver,omitempty"` Sock string `json:"sock,omitempty"` RouteId string `json:"routeid,omitempty"` + ProcRoute bool `json:"procroute,omitempty"` BlockId string `json:"blockid,omitempty"` JobId string `json:"jobid,omitempty"` Conn string `json:"conn,omitempty"` diff --git a/pkg/waveobj/metaconsts.go b/pkg/waveobj/metaconsts.go index 801929541d..039265c29f 100644 --- a/pkg/waveobj/metaconsts.go +++ b/pkg/waveobj/metaconsts.go @@ -41,6 +41,7 @@ const ( MetaKey_Cmd = "cmd" MetaKey_CmdInteractive = "cmd:interactive" MetaKey_CmdLogin = "cmd:login" + MetaKey_CmdPersistent = "cmd:persistent" MetaKey_CmdRunOnStart = "cmd:runonstart" MetaKey_CmdClearOnStart = "cmd:clearonstart" MetaKey_CmdRunOnce = "cmd:runonce" diff --git a/pkg/waveobj/wtypemeta.go b/pkg/waveobj/wtypemeta.go index efe0a79f18..9525149ef6 100644 --- a/pkg/waveobj/wtypemeta.go +++ b/pkg/waveobj/wtypemeta.go @@ -40,6 +40,7 @@ type MetaTSType struct { Cmd string `json:"cmd,omitempty"` CmdInteractive bool `json:"cmd:interactive,omitempty"` CmdLogin bool `json:"cmd:login,omitempty"` + CmdPersistent bool `json:"cmd:persistent,omitempty"` CmdRunOnStart bool `json:"cmd:runonstart,omitempty"` CmdClearOnStart bool `json:"cmd:clearonstart,omitempty"` CmdRunOnce bool `json:"cmd:runonce,omitempty"` diff --git a/pkg/wcore/block.go b/pkg/wcore/block.go index fc66232a32..47864c022c 100644 --- a/pkg/wcore/block.go +++ b/pkg/wcore/block.go @@ -201,7 +201,7 @@ func DeleteBlock(ctx context.Context, blockId string, recursive bool) error { } SendActiveTabUpdate(ctx, parentWorkspaceId, newActiveTabId) } - go blockcontroller.StopBlockController(blockId) + go blockcontroller.DestroyBlockController(blockId) sendBlockCloseEvent(blockId) return nil } diff --git a/pkg/wcore/wcore.go b/pkg/wcore/wcore.go index 0c3fb905eb..d82aa67d3a 100644 --- a/pkg/wcore/wcore.go +++ b/pkg/wcore/wcore.go @@ -59,6 +59,7 @@ func EnsureInitialData() (bool, error) { } } log.Printf("clientid: %s\n", client.OID) + wstore.SetClientId(client.OID) if len(client.WindowIds) == 1 { log.Println("client has one window") CheckAndFixWindow(ctx, client.WindowIds[0]) @@ -151,16 +152,8 @@ func GoSendNoTelemetryUpdate(telemetryEnabled bool) { }() ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) defer cancelFn() - clientData, err := wstore.DBGetSingleton[*waveobj.Client](ctx) - if err != nil { - log.Printf("telemetry update: error getting client data: %v\n", err) - return - } - if clientData == nil { - log.Printf("telemetry update: client data is nil\n") - return - } - err = wcloud.SendNoTelemetryUpdate(ctx, clientData.OID, !telemetryEnabled) + clientId := wstore.GetClientId() + err := wcloud.SendNoTelemetryUpdate(ctx, 
diff --git a/pkg/web/ws.go b/pkg/web/ws.go
index 719753ba3c..ba1c9eb0ef 100644
--- a/pkg/web/ws.go
+++ b/pkg/web/ws.go
@@ -163,7 +163,7 @@ func ReadLoop(conn *websocket.Conn, outputCh chan any, closeCh chan any, rpcInpu
 			outputCh <- pongMessage
 			continue
 		}
-		go processMessage(jmsg, outputCh, rpcInputCh)
+		processMessage(jmsg, outputCh, rpcInputCh)
 	}
 }
diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go
index 62cd66d90c..a8a3cf51b3 100644
--- a/pkg/wshrpc/wshclient/wshclient.go
+++ b/pkg/wshrpc/wshclient/wshclient.go
@@ -125,6 +125,12 @@ func ConnReinstallWshCommand(w *wshutil.WshRpc, data wshrpc.ConnExtData, opts *w
 	return err
 }
 
+// command "connserverinit", wshserver.ConnServerInitCommand
+func ConnServerInitCommand(w *wshutil.WshRpc, data wshrpc.CommandConnServerInitData, opts *wshrpc.RpcOpts) error {
+	_, err := sendRpcRequestCallHelper[any](w, "connserverinit", data, opts)
+	return err
+}
+
 // command "connstatus", wshserver.ConnStatusCommand
 func ConnStatusCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]wshrpc.ConnStatus, error) {
 	resp, err := sendRpcRequestCallHelper[[]wshrpc.ConnStatus](w, "connstatus", nil, opts)
@@ -143,6 +149,12 @@ func ControllerAppendOutputCommand(w *wshutil.WshRpc, data wshrpc.CommandControl
 	return err
 }
 
+// command "controllerdestroy", wshserver.ControllerDestroyCommand
+func ControllerDestroyCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
+	_, err := sendRpcRequestCallHelper[any](w, "controllerdestroy", data, opts)
+	return err
+}
+
 // command "controllerinput", wshserver.ControllerInputCommand
 func ControllerInputCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockInputData, opts *wshrpc.RpcOpts) error {
 	_, err := sendRpcRequestCallHelper[any](w, "controllerinput", data, opts)
@@ -155,12 +167,6 @@ func ControllerResyncCommand(w *wshutil.WshRpc, data wshrpc.CommandControllerRes
 	return err
 }
 
-// command "controllerstop", wshserver.ControllerStopCommand
-func ControllerStopCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
-	_, err := sendRpcRequestCallHelper[any](w, "controllerstop", data, opts)
-	return err
-}
-
 // command "createblock", wshserver.CreateBlockCommand
 func CreateBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandCreateBlockData, opts *wshrpc.RpcOpts) (waveobj.ORef, error) {
 	resp, err := sendRpcRequestCallHelper[waveobj.ORef](w, "createblock", data, opts)
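
NOTE: the generated stubs follow the existing sendRpcRequestCallHelper pattern, so the new command is invoked like any other. A hedged sketch of how a caller might ask a remote connserver to register its persistent socket (the wrapper function and the nil RpcOpts are assumptions, not part of this diff):

	package example

	import (
		"github.com/wavetermdev/waveterm/pkg/wshrpc"
		"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
		"github.com/wavetermdev/waveterm/pkg/wshutil"
	)

	// initPersistentSock asks a remote connserver to symlink its domain socket
	// to the per-client persistent path (handled by ConnServerInitCommand).
	func initPersistentSock(client *wshutil.WshRpc, clientId string) error {
		data := wshrpc.CommandConnServerInitData{ClientId: clientId}
		return wshclient.ConnServerInitCommand(client, data, nil)
	}
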
diff --git a/pkg/wshrpc/wshremote/wshremote.go b/pkg/wshrpc/wshremote/wshremote.go
index f0cfbb145c..e8dc392c34 100644
--- a/pkg/wshrpc/wshremote/wshremote.go
+++ b/pkg/wshrpc/wshremote/wshremote.go
@@ -9,6 +9,7 @@ import (
 	"io"
 	"log"
 	"net"
+	"os"
 	"path/filepath"
 	"sync"
 
@@ -30,17 +31,21 @@ type ServerImpl struct {
 	Router        *wshutil.WshRouter
 	RpcClient     *wshutil.WshRpc
 	IsLocal       bool
+	InitialEnv    map[string]string
 	JobManagerMap map[string]*JobManagerConnection
+	SockName      string
 	Lock          sync.Mutex
 }
 
-func MakeRemoteRpcServerImpl(logWriter io.Writer, router *wshutil.WshRouter, rpcClient *wshutil.WshRpc, isLocal bool) *ServerImpl {
+func MakeRemoteRpcServerImpl(logWriter io.Writer, router *wshutil.WshRouter, rpcClient *wshutil.WshRpc, isLocal bool, initialEnv map[string]string, sockName string) *ServerImpl {
 	return &ServerImpl{
 		LogWriter:     logWriter,
 		Router:        router,
 		RpcClient:     rpcClient,
 		IsLocal:       isLocal,
+		InitialEnv:    initialEnv,
 		JobManagerMap: make(map[string]*JobManagerConnection),
+		SockName:      sockName,
 	}
 }
 
@@ -92,6 +97,32 @@ func (*ServerImpl) DisposeSuggestionsCommand(ctx context.Context, widgetId strin
 	return nil
 }
 
+func (impl *ServerImpl) ConnServerInitCommand(ctx context.Context, data wshrpc.CommandConnServerInitData) error {
+	if data.ClientId == "" {
+		return fmt.Errorf("clientid is required")
+	}
+	if impl.SockName == "" {
+		return fmt.Errorf("sockname not set in server impl")
+	}
+	symlinkPath, err := wavebase.ExpandHomeDir(wavebase.GetPersistentRemoteSockName(data.ClientId))
+	if err != nil {
+		return fmt.Errorf("cannot expand symlink path: %w", err)
+	}
+	symlinkDir := filepath.Dir(symlinkPath)
+
+	if err := os.MkdirAll(symlinkDir, 0700); err != nil {
+		impl.Log("warning: could not create client directory %s: %v\n", symlinkDir, err)
+		return nil
+	}
+	os.Remove(symlinkPath)
+	if err := os.Symlink(impl.SockName, symlinkPath); err != nil {
+		impl.Log("warning: could not create symlink %s -> %s: %v\n", symlinkPath, impl.SockName, err)
+		return nil
+	}
+	impl.Log("created symlink %s -> %s\n", symlinkPath, impl.SockName)
+	return nil
+}
+
 func (impl *ServerImpl) getWshPath() (string, error) {
 	if impl.IsLocal {
 		return filepath.Join(wavebase.GetWaveDataDir(), "bin", "wsh"), nil
diff --git a/pkg/wshrpc/wshremote/wshremote_job.go b/pkg/wshrpc/wshremote/wshremote_job.go
index 12545c9cc1..df6d1d4a0a 100644
--- a/pkg/wshrpc/wshremote/wshremote_job.go
+++ b/pkg/wshrpc/wshremote/wshremote_job.go
@@ -254,10 +254,17 @@ func (impl *ServerImpl) RemoteStartJobCommand(ctx context.Context, data wshrpc.C
 		return nil, err
 	}
 
+	combinedEnv := make(map[string]string)
+	for k, v := range impl.InitialEnv {
+		combinedEnv[k] = v
+	}
+	for k, v := range data.Env {
+		combinedEnv[k] = v
+	}
 	startJobData := wshrpc.CommandStartJobData{
 		Cmd:        data.Cmd,
 		Args:       data.Args,
-		Env:        data.Env,
+		Env:        combinedEnv,
 		TermSize:   data.TermSize,
 		StreamMeta: data.StreamMeta,
 	}
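
NOTE: the merge above gives per-job Env entries precedence over the captured InitialEnv, because later map writes win. A standalone sketch of the same overlay pattern (names are illustrative):

	package main

	import "fmt"

	// overlay returns a new map where entries in over take precedence over base.
	func overlay(base, over map[string]string) map[string]string {
		merged := make(map[string]string, len(base)+len(over))
		for k, v := range base {
			merged[k] = v
		}
		for k, v := range over {
			merged[k] = v // later writes win
		}
		return merged
	}

	func main() {
		initialEnv := map[string]string{"PATH": "/usr/bin", "LANG": "C"}
		jobEnv := map[string]string{"LANG": "en_US.UTF-8"}
		fmt.Println(overlay(initialEnv, jobEnv)["LANG"]) // en_US.UTF-8
	}
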
diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go
index c0d8d1214b..ed89dc0892 100644
--- a/pkg/wshrpc/wshrpctypes.go
+++ b/pkg/wshrpc/wshrpctypes.go
@@ -9,6 +9,7 @@ import (
 	"context"
 	"encoding/json"
 
+	"github.com/google/uuid"
 	"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
 	"github.com/wavetermdev/waveterm/pkg/telemetry/telemetrydata"
 	"github.com/wavetermdev/waveterm/pkg/vdom"
@@ -44,7 +45,7 @@ type WshRpcInterface interface {
 	GetMetaCommand(ctx context.Context, data CommandGetMetaData) (waveobj.MetaMapType, error)
 	SetMetaCommand(ctx context.Context, data CommandSetMetaData) error
 	ControllerInputCommand(ctx context.Context, data CommandBlockInputData) error
-	ControllerStopCommand(ctx context.Context, blockId string) error
+	ControllerDestroyCommand(ctx context.Context, blockId string) error
 	ControllerResyncCommand(ctx context.Context, data CommandControllerResyncData) error
 	ControllerAppendOutputCommand(ctx context.Context, data CommandControllerAppendOutputData) error
 	ResolveIdsCommand(ctx context.Context, data CommandResolveIdsData) (CommandResolveIdsRtnData, error)
@@ -99,6 +100,7 @@ type WshRpcInterface interface {
 	DismissWshFailCommand(ctx context.Context, connName string) error
 	ConnUpdateWshCommand(ctx context.Context, remoteInfo RemoteInfo) (bool, error)
 	FindGitBashCommand(ctx context.Context, rescan bool) (string, error)
+	ConnServerInitCommand(ctx context.Context, data CommandConnServerInitData) error
 
 	// eventrecv is special, it's handled internally by WshRpc with EventListener
 	EventRecvCommand(ctx context.Context, data wps.WaveEvent) error
@@ -203,14 +205,24 @@ type RpcOpts struct {
 }
 
 type RpcContext struct {
-	SockName string `json:"sockname,omitempty"` // the domain socket name
-	RouteId  string `json:"routeid"`            // the routeid from the jwt
-	BlockId  string `json:"blockid,omitempty"`  // blockid for this rpc
-	Conn     string `json:"conn,omitempty"`     // the conn name
-	IsRouter bool   `json:"isrouter,omitempty"` // if this is for a sub-router
+	SockName  string `json:"sockname,omitempty"`  // the domain socket name
+	RouteId   string `json:"routeid"`             // the routeid from the jwt
+	ProcRoute bool   `json:"procroute,omitempty"` // use a random procid for route
+	BlockId   string `json:"blockid,omitempty"`   // blockid for this rpc
+	Conn      string `json:"conn,omitempty"`      // the conn name
+	IsRouter  bool   `json:"isrouter,omitempty"`  // if this is for a sub-router
+}
+
+func (rc RpcContext) GenerateRouteId() string {
+	if rc.RouteId != "" {
+		return rc.RouteId
+	}
+	return "proc:" + uuid.New().String()
 }
 
 type CommandAuthenticateRtnData struct {
+	RouteId string `json:"routeid"`
+
 	// these fields are only set when doing a token swap
 	Env            map[string]string `json:"env,omitempty"`
 	InitScriptText string            `json:"initscripttext,omitempty"`
@@ -284,10 +296,12 @@ type CommandBlockInputData struct {
 }
 
 type CommandJobInputData struct {
-	JobId       string            `json:"jobid"`
-	InputData64 string            `json:"inputdata64,omitempty"`
-	SigName     string            `json:"signame,omitempty"`
-	TermSize    *waveobj.TermSize `json:"termsize,omitempty"`
+	JobId          string            `json:"jobid"`
+	InputSessionId string            `json:"inputsessionid,omitempty"`
+	SeqNum         int               `json:"seqnum,omitempty"`
+	InputData64    string            `json:"inputdata64,omitempty"`
+	SigName        string            `json:"signame,omitempty"`
+	TermSize       *waveobj.TermSize `json:"termsize,omitempty"`
 }
 
 type CommandWaitForRouteData struct {
@@ -382,6 +396,7 @@ type RemoteInfo struct {
 	ClientOs      string `json:"clientos"`
 	ClientVersion string `json:"clientversion"`
 	Shell         string `json:"shell"`
+	HomeDir       string `json:"homedir"`
 }
 
 const (
@@ -599,6 +614,10 @@ type ConnExtData struct {
 	LogBlockId string `json:"logblockid,omitempty"`
 }
 
+type CommandConnServerInitData struct {
+	ClientId string `json:"clientid"`
+}
+
 type FetchSuggestionsData struct {
 	SuggestionType string `json:"suggestiontype"`
 	Query          string `json:"query"`
@@ -764,8 +783,9 @@ type CommandStartJobRtnData struct {
 }
 
 type CommandJobPrepareConnectData struct {
-	StreamMeta StreamMeta `json:"streammeta"`
-	Seq        int64      `json:"seq"`
+	StreamMeta StreamMeta       `json:"streammeta"`
+	Seq        int64            `json:"seq"`
+	TermSize   waveobj.TermSize `json:"termsize"`
 }
 
 type CommandJobStartStreamData struct {
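
NOTE: GenerateRouteId keeps an explicit RouteId when one was set (e.g. from the JWT) and otherwise mints a fresh "proc:" route, which is what the new ProcRoute flag enables for ephemeral processes. A quick illustrative use (values are examples):

	package main

	import (
		"fmt"

		"github.com/wavetermdev/waveterm/pkg/wshrpc"
	)

	func main() {
		fixed := wshrpc.RpcContext{RouteId: "conn:local"}
		fmt.Println(fixed.GenerateRouteId()) // conn:local (explicit id wins)

		proc := wshrpc.RpcContext{ProcRoute: true}
		fmt.Println(proc.GenerateRouteId()) // proc:<random-uuid>, unique per call
	}
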
diff --git a/pkg/wshrpc/wshserver/resolvers.go b/pkg/wshrpc/wshserver/resolvers.go
index 552f3d3bca..3a6a33375c 100644
--- a/pkg/wshrpc/wshserver/resolvers.go
+++ b/pkg/wshrpc/wshserver/resolvers.go
@@ -106,11 +106,8 @@ func resolveThis(ctx context.Context, data wshrpc.CommandResolveIdsData, value s
 		return &waveobj.ORef{OType: waveobj.OType_Workspace, OID: wsId}, nil
 	}
 	if value == SimpleId_Client || value == SimpleId_Global {
-		client, err := wstore.DBGetSingleton[*waveobj.Client](ctx)
-		if err != nil {
-			return nil, fmt.Errorf("error getting client: %v", err)
-		}
-		return &waveobj.ORef{OType: waveobj.OType_Client, OID: client.OID}, nil
+		clientId := wstore.GetClientId()
+		return &waveobj.ORef{OType: waveobj.OType_Client, OID: clientId}, nil
 	}
 	if value == SimpleId_Temp {
 		client, err := wstore.DBGetSingleton[*waveobj.Client](ctx)
diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go
index 6446b5ed25..54a5197cb7 100644
--- a/pkg/wshrpc/wshserver/wshserver.go
+++ b/pkg/wshrpc/wshserver/wshserver.go
@@ -283,8 +283,8 @@ func (ws *WshServer) CreateSubBlockCommand(ctx context.Context, data wshrpc.Comm
 	return blockRef, nil
 }
 
-func (ws *WshServer) ControllerStopCommand(ctx context.Context, blockId string) error {
-	blockcontroller.StopBlockController(blockId)
+func (ws *WshServer) ControllerDestroyCommand(ctx context.Context, blockId string) error {
+	blockcontroller.DestroyBlockController(blockId)
 	return nil
 }
 
@@ -295,21 +295,6 @@ func (ws *WshServer) ControllerResyncCommand(ctx context.Context, data wshrpc.Co
 }
 
 func (ws *WshServer) ControllerInputCommand(ctx context.Context, data wshrpc.CommandBlockInputData) error {
-	block, err := wstore.DBMustGet[*waveobj.Block](ctx, data.BlockId)
-	if err != nil {
-		return fmt.Errorf("error getting block: %w", err)
-	}
-
-	if block.JobId != "" {
-		jobInputData := wshrpc.CommandJobInputData{
-			JobId:       block.JobId,
-			InputData64: data.InputData64,
-			SigName:     data.SigName,
-			TermSize:    data.TermSize,
-		}
-		return jobcontroller.SendInput(ctx, jobInputData)
-	}
-
 	inputUnion := &blockcontroller.BlockInputUnion{
 		SigName:  data.SigName,
 		TermSize: data.TermSize,
@@ -865,13 +850,9 @@ func (ws *WshServer) BlockInfoCommand(ctx context.Context, blockId string) (*wsh
 }
 
 func (ws *WshServer) WaveInfoCommand(ctx context.Context) (*wshrpc.WaveInfoData, error) {
-	client, err := wstore.DBGetSingleton[*waveobj.Client](ctx)
-	if err != nil {
-		return nil, fmt.Errorf("error getting client: %w", err)
-	}
 	return &wshrpc.WaveInfoData{
 		Version:   wavebase.WaveVersion,
-		ClientId:  client.OID,
+		ClientId:  wstore.GetClientId(),
 		BuildTime: wavebase.BuildTime,
 		ConfigDir: wavebase.GetWaveConfigDir(),
 		DataDir:   wavebase.GetWaveDataDir(),
@@ -1198,11 +1179,7 @@ func (ws *WshServer) RecordTEventCommand(ctx context.Context, data telemetrydata
 }
 
 func (ws WshServer) SendTelemetryCommand(ctx context.Context) error {
-	client, err := wstore.DBGetSingleton[*waveobj.Client](ctx)
-	if err != nil {
-		return fmt.Errorf("getting client data for telemetry: %v", err)
-	}
-	return wcloud.SendAllTelemetry(client.OID)
+	return wcloud.SendAllTelemetry(wstore.GetClientId())
 }
 
 func (ws *WshServer) WaveAIEnableTelemetryCommand(ctx context.Context) error {
@@ -1215,12 +1192,6 @@ func (ws *WshServer) WaveAIEnableTelemetryCommand(ctx context.Context) error {
 		return fmt.Errorf("error setting telemetry enabled: %w", err)
 	}
 
-	// Get client for telemetry operations
-	client, err := wstore.DBGetSingleton[*waveobj.Client](ctx)
-	if err != nil {
-		return fmt.Errorf("getting client data for telemetry: %v", err)
-	}
-
 	// Record the telemetry event
 	event := telemetrydata.MakeTEvent("waveai:enabletelemetry", telemetrydata.TEventProps{})
 	err = telemetry.RecordTEvent(ctx, event)
@@ -1229,7 +1200,7 @@ func (ws *WshServer) WaveAIEnableTelemetryCommand(ctx context.Context) error {
 	}
 
 	// Immediately send telemetry to cloud
-	err = wcloud.SendAllTelemetry(client.OID)
+	err = wcloud.SendAllTelemetry(wstore.GetClientId())
 	if err != nil {
 		log.Printf("error sending telemetry after enabling: %v", err)
 	}
@@ -1479,7 +1450,7 @@ func (ws *WshServer) JobControllerDisconnectJobCommand(ctx context.Context, jobI
 }
 
 func (ws *WshServer) JobControllerReconnectJobCommand(ctx context.Context, jobId string) error {
-	return jobcontroller.ReconnectJob(ctx, jobId)
+	return jobcontroller.ReconnectJob(ctx, jobId, nil)
 }
 
 func (ws *WshServer) JobControllerReconnectJobsForConnCommand(ctx context.Context, connName string) error {
diff --git a/pkg/wshutil/wshproxy.go b/pkg/wshutil/wshproxy.go
index 4a8126789c..c20b485d22 100644
--- a/pkg/wshutil/wshproxy.go
+++ b/pkg/wshutil/wshproxy.go
@@ -39,8 +39,7 @@ func (p *WshRpcProxy) SetPeerInfo(peerInfo string) {
 	p.PeerInfo = peerInfo
 }
 
-// TODO: Figure out who is sending to closed routes and why we're not catching it
-func (p *WshRpcProxy) SendRpcMessage(msg []byte, ingressLinkId baseds.LinkId, debugStr string) {
+func (p *WshRpcProxy) SendRpcMessage(msg []byte, ingressLinkId baseds.LinkId, debugStr string) bool {
 	defer func() {
 		panicCtx := "WshRpcProxy.SendRpcMessage"
 		if debugStr != "" {
@@ -48,7 +47,12 @@ func (p *WshRpcProxy) SendRpcMessage(msg []byte, ingressLinkId baseds.LinkId, de
 		}
 		panichandler.PanicHandler(panicCtx, recover())
 	}()
-	p.ToRemoteCh <- msg
+	select {
+	case p.ToRemoteCh <- msg:
+		return true
+	default:
+		return false
+	}
 }
 
 func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) {
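
NOTE: SendRpcMessage now reports delivery instead of blocking: the select/default either hands the message to the channel or returns false so the router can queue it (see the backlog machinery below). A tiny standalone sketch of the pattern (names are illustrative):

	package main

	import "fmt"

	// trySend is a non-blocking send: it returns false instead of blocking
	// when the channel's buffer is full and no receiver is waiting.
	func trySend(ch chan []byte, msg []byte) bool {
		select {
		case ch <- msg:
			return true
		default:
			return false
		}
	}

	func main() {
		ch := make(chan []byte, 1)
		fmt.Println(trySend(ch, []byte("a"))) // true, buffer has room
		fmt.Println(trySend(ch, []byte("b"))) // false, would block
	}
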
diff --git a/pkg/wshutil/wshrouter.go b/pkg/wshutil/wshrouter.go
index 32c889756b..375772b924 100644
--- a/pkg/wshutil/wshrouter.go
+++ b/pkg/wshutil/wshrouter.go
@@ -40,6 +40,10 @@ const (
 	RoutePrefix_Bare = "bare:"
 )
 
+const RouterInputChQueueSize = 100
+
+var BacklogLogThresholds = map[int]bool{1: true, 5: true, 10: true, 20: true, 30: true, 40: true, 50: true, 100: true, 200: true, 500: true, 1000: true}
+
 // this works like a network switch
 
 // TODO maybe move the wps integration here instead of in wshserver
@@ -76,6 +80,12 @@ type messageWrap struct {
 	debugStr string
 }
 
+type backlogMessageWrap struct {
+	msgBytes      []byte
+	ingressLinkId baseds.LinkId
+	debugStr      string
+}
+
 type WshRouter struct {
 	lock         *sync.Mutex
 	isRootRouter bool
@@ -91,6 +101,10 @@ type WshRouter struct {
 	upstreamBuf         []messageWrap
 	upstreamLoopStarted bool
 
+	linkBacklogCond      *sync.Cond
+	linkMsgBacklog       map[baseds.LinkId][]backlogMessageWrap
+	backlogHighWaterMark map[baseds.LinkId]int
+
 	controlRpc *WshRpc
 }
 
@@ -134,17 +148,21 @@ var DefaultRouter *WshRouter
 
 func NewWshRouter() *WshRouter {
 	rtn := &WshRouter{
-		lock:           &sync.Mutex{},
-		nextLinkId:     0,
-		upstreamLinkId: baseds.NoLinkId,
-		inputCh:        make(chan baseds.RpcInputChType),
-		rpcMap:         make(map[string]rpcRoutingInfo),
-		linkMap:        make(map[baseds.LinkId]*linkMeta),
-		routeMap:       make(map[string]baseds.LinkId),
+		lock:                 &sync.Mutex{},
+		nextLinkId:           0,
+		upstreamLinkId:       baseds.NoLinkId,
+		inputCh:              make(chan baseds.RpcInputChType, RouterInputChQueueSize),
+		rpcMap:               make(map[string]rpcRoutingInfo),
+		linkMap:              make(map[baseds.LinkId]*linkMeta),
+		routeMap:             make(map[string]baseds.LinkId),
+		linkMsgBacklog:       make(map[baseds.LinkId][]backlogMessageWrap),
+		backlogHighWaterMark: make(map[baseds.LinkId]int),
 	}
 	rtn.upstreamBufCond = sync.NewCond(&rtn.upstreamBufLock)
+	rtn.linkBacklogCond = sync.NewCond(rtn.lock)
 	rtn.registerControlPlane()
 	go rtn.runServer()
+	go rtn.processBacklog()
 	return rtn
 }
 
@@ -154,6 +172,10 @@ func (router *WshRouter) IsRootRouter() bool {
 	return router.isRootRouter
 }
 
+func (router *WshRouter) GetControlRpc() *WshRpc {
+	return router.controlRpc
+}
+
 func (router *WshRouter) SetAsRootRouter() {
 	router.lock.Lock()
 	defer router.lock.Unlock()
@@ -192,7 +214,7 @@ func (router *WshRouter) SendEvent(routeId string, event wps.WaveEvent) {
 		// nothing to do
 		return
 	}
-	lm.client.SendRpcMessage(msgBytes, baseds.NoLinkId, "eventrecv")
+	router.sendRpcMessageToLink(lm.linkId, lm.client, msgBytes, baseds.NoLinkId, "eventrecv")
 }
 
 func (router *WshRouter) handleNoRoute(msg RpcMessage, ingressLinkId baseds.LinkId) {
@@ -214,7 +236,7 @@ func (router *WshRouter) handleNoRoute(msg RpcMessage, ingressLinkId baseds.Link
 			Data:  wshrpc.CommandMessageData{Message: nrErr.Error()},
 		}
 		respBytes, _ := json.Marshal(respMsg)
-		lm.client.SendRpcMessage(respBytes, baseds.NoLinkId, "no-route-err")
+		router.sendRpcMessageToLink(lm.linkId, lm.client, respBytes, baseds.NoLinkId, "no-route-err")
 		return
 	}
 	// send error response
@@ -266,12 +288,12 @@ func (router *WshRouter) sendRoutedMessage(msgBytes []byte, routeId string, comm
 	}
 	lm := router.getLinkForRoute(routeId)
 	if lm != nil {
-		lm.client.SendRpcMessage(msgBytes, ingressLinkId, "route")
+		router.sendRpcMessageToLink(lm.linkId, lm.client, msgBytes, ingressLinkId, "route")
 		return true
 	}
-	upstream := router.getUpstreamClient()
+	upstreamLinkId, upstream := router.getUpstreamClient()
 	if upstream != nil {
-		upstream.SendRpcMessage(msgBytes, ingressLinkId, "route-upstream")
+		router.sendRpcMessageToLink(upstreamLinkId, upstream, msgBytes, ingressLinkId, "route-upstream")
 		return true
 	}
 	if commandName != "" {
@@ -287,10 +309,45 @@ func (router *WshRouter) sendMessageToLink(msgBytes []byte, linkId baseds.LinkId
 	if lm == nil {
 		return false
 	}
-	lm.client.SendRpcMessage(msgBytes, ingressLinkId, "link")
+	router.sendRpcMessageToLink(lm.linkId, lm.client, msgBytes, ingressLinkId, "link")
 	return true
 }
 
+func (router *WshRouter) addToBacklog_withlock(linkId baseds.LinkId, msgBytes []byte, ingressLinkId baseds.LinkId, debugStr string) {
+	mapWasEmpty := len(router.linkMsgBacklog) == 0
+	backlog := router.linkMsgBacklog[linkId]
+	backlog = append(backlog, backlogMessageWrap{msgBytes: msgBytes, ingressLinkId: ingressLinkId, debugStr: debugStr})
+	router.linkMsgBacklog[linkId] = backlog
+
+	newLen := len(backlog)
+	highWater := router.backlogHighWaterMark[linkId]
+
+	if BacklogLogThresholds[newLen] && highWater < newLen {
+		log.Printf("[router] backlog for linkid=%d reached %d messages\n", linkId, newLen)
+	}
+
+	if newLen > highWater {
+		router.backlogHighWaterMark[linkId] = newLen
+	}
+
+	if mapWasEmpty {
+		router.linkBacklogCond.Signal()
+	}
+}
+
+func (router *WshRouter) sendRpcMessageToLink(linkId baseds.LinkId, client AbstractRpcClient, msgBytes []byte, ingressLinkId baseds.LinkId, debugStr string) {
+	router.lock.Lock()
+	defer router.lock.Unlock()
+	sent := false
+	backlog := router.linkMsgBacklog[linkId]
+	if len(backlog) == 0 {
+		sent = client.SendRpcMessage(msgBytes, ingressLinkId, debugStr)
+	}
+	if !sent {
+		router.addToBacklog_withlock(linkId, msgBytes, ingressLinkId, debugStr)
+	}
+}
+
 func (router *WshRouter) runServer() {
 	for input := range router.inputCh {
 		msgBytes := input.MsgBytes
@@ -355,7 +412,8 @@ func (router *WshRouter) WaitForRegister(ctx context.Context, routeId string) er
 
 // this will never block, can be called while holding router.Lock
 func (router *WshRouter) queueUpstreamMessage(msgBytes []byte, debugStr string) {
-	if router.getUpstreamClient() == nil {
+	_, upstream := router.getUpstreamClient()
+	if upstream == nil {
 		return
 	}
 	router.upstreamBufLock.Lock()
@@ -381,13 +439,68 @@ func (router *WshRouter) runUpstreamBufferLoop() {
 		router.upstreamBuf = router.upstreamBuf[1:]
 		router.upstreamBufLock.Unlock()
 
-		upstream := router.getUpstreamClient()
+		upstreamLinkId, upstream := router.getUpstreamClient()
 		if upstream != nil {
-			upstream.SendRpcMessage(msg.msgBytes, baseds.NoLinkId, msg.debugStr)
+			router.sendRpcMessageToLink(upstreamLinkId, upstream, msg.msgBytes, baseds.NoLinkId, msg.debugStr)
 		}
 	}
 }
 
+func (router *WshRouter) drainLinkBacklog_withLock(linkId baseds.LinkId, lm *linkMeta, backlog []backlogMessageWrap) []backlogMessageWrap {
+	for len(backlog) > 0 {
+		msg := backlog[0]
+		sent := lm.client.SendRpcMessage(msg.msgBytes, msg.ingressLinkId, msg.debugStr)
+		if !sent {
+			return backlog
+		}
+		backlog = backlog[1:]
+	}
+	return backlog
+}
+
+func (router *WshRouter) processOneBacklogRound() {
+	router.lock.Lock()
+	defer router.lock.Unlock()
+	for linkId, backlog := range router.linkMsgBacklog {
+		lm := router.linkMap[linkId]
+		if lm == nil {
+			highWater := router.backlogHighWaterMark[linkId]
+			if highWater > 0 {
+				log.Printf("[router] backlog for linkid=%d cleared, link gone (highwater mark was %d messages)\n", linkId, highWater)
+			}
+			delete(router.linkMsgBacklog, linkId)
+			delete(router.backlogHighWaterMark, linkId)
+			continue
+		}
+		newBacklog := router.drainLinkBacklog_withLock(linkId, lm, backlog)
+		if len(newBacklog) == 0 {
+			highWater := router.backlogHighWaterMark[linkId]
+			if highWater > 0 {
+				log.Printf("[router] backlog for linkid=%d cleared (highwater mark was %d messages)\n", linkId, highWater)
+			}
+			delete(router.linkMsgBacklog, linkId)
+			delete(router.backlogHighWaterMark, linkId)
+			continue
+		}
+		router.linkMsgBacklog[linkId] = newBacklog
+	}
+}
+
+func (router *WshRouter) processBacklog() {
+	defer func() {
+		panichandler.PanicHandler("WshRouter:processBacklog", recover())
+	}()
+	for {
+		router.lock.Lock()
+		for len(router.linkMsgBacklog) == 0 {
+			router.linkBacklogCond.Wait()
+		}
+		router.lock.Unlock()
+		router.processOneBacklogRound()
+		time.Sleep(50 * time.Millisecond)
+	}
+}
+
 func (router *WshRouter) RegisterUntrustedLink(client AbstractRpcClient) baseds.LinkId {
 	router.lock.Lock()
 	defer router.lock.Unlock()
@@ -459,7 +572,7 @@ func (router *WshRouter) runLinkClientRecvLoop(linkId baseds.LinkId, client Abst
 		isControlRoute := rpcMsg.Route == ControlRoute || rpcMsg.Route == ControlRootRoute
 		if !lm.trusted {
 			if !isControlRoute {
-				sendControlUnauthenticatedErrorResponse(rpcMsg, *lm)
+				sendControlUnauthenticatedErrorResponse(rpcMsg, *lm, router)
 				continue
 			}
 			log.Printf("wshrouter control-msg route=%s link=%s command=%s source=%s", rpcMsg.Route, lm.Name(), rpcMsg.Command, rpcMsg.Source)
@@ -703,17 +816,17 @@ func (router *WshRouter) bindRoute(linkId baseds.LinkId, routeId string, isSourc
 	return nil
 }
 
-func (router *WshRouter) getUpstreamClient() AbstractRpcClient {
+func (router *WshRouter) getUpstreamClient() (baseds.LinkId, AbstractRpcClient) {
 	router.lock.Lock()
 	defer router.lock.Unlock()
 	if router.upstreamLinkId == baseds.NoLinkId {
-		return nil
+		return baseds.NoLinkId, nil
 	}
 	lm := router.linkMap[router.upstreamLinkId]
 	if lm == nil {
-		return nil
+		return baseds.NoLinkId, nil
 	}
-	return lm.client
+	return router.upstreamLinkId, lm.client
 }
 
 func (router *WshRouter) publishRouteToBroker(routeId string) {
@@ -731,7 +844,7 @@ func (router *WshRouter) unsubscribeFromBroker(routeId string) {
 	wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_RouteDown, Scopes: []string{routeId}})
 }
 
-func sendControlUnauthenticatedErrorResponse(cmdMsg RpcMessage, linkMeta linkMeta) {
+func sendControlUnauthenticatedErrorResponse(cmdMsg RpcMessage, linkMeta linkMeta, router *WshRouter) {
 	if cmdMsg.ReqId == "" {
 		return
 	}
@@ -741,5 +854,5 @@ func sendControlUnauthenticatedErrorResponse(cmdMsg RpcMessage, linkMet
 		Error: fmt.Sprintf("link is unauthenticated (%s), cannot call %q", linkMeta.Name(), cmdMsg.Command),
 	}
 	rtnBytes, _ := json.Marshal(rtnMsg)
-	linkMeta.client.SendRpcMessage(rtnBytes, baseds.NoLinkId, "unauthenticated")
+	router.sendRpcMessageToLink(linkMeta.linkId, linkMeta.client, rtnBytes, baseds.NoLinkId, "unauthenticated")
 }
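
NOTE: the backlog drain combines a sync.Cond (to sleep while every backlog is empty) with a paced retry loop (to re-attempt full links every 50ms). Order is preserved because sendRpcMessageToLink never bypasses a non-empty backlog. A reduced sketch of the wait/signal shape, independent of the router types (names illustrative):

	package main

	import (
		"fmt"
		"sync"
		"time"
	)

	type drainQueue struct {
		lock    sync.Mutex
		cond    *sync.Cond
		pending []string
	}

	func newDrainQueue() *drainQueue {
		q := &drainQueue{}
		q.cond = sync.NewCond(&q.lock) // cond shares the data lock, like linkBacklogCond
		return q
	}

	func (q *drainQueue) add(item string) {
		q.lock.Lock()
		defer q.lock.Unlock()
		wasEmpty := len(q.pending) == 0
		q.pending = append(q.pending, item)
		if wasEmpty {
			q.cond.Signal() // wake the drain loop only on the empty->non-empty edge
		}
	}

	func (q *drainQueue) drainLoop() {
		for {
			q.lock.Lock()
			for len(q.pending) == 0 {
				q.cond.Wait() // sleeps with the lock released
			}
			item := q.pending[0]
			q.pending = q.pending[1:]
			q.lock.Unlock()
			fmt.Println("drained:", item)
			time.Sleep(50 * time.Millisecond) // paced retry, as in processBacklog
		}
	}

	func main() {
		q := newDrainQueue()
		go q.drainLoop()
		q.add("msg-1")
		q.add("msg-2")
		time.Sleep(200 * time.Millisecond)
	}
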
diff --git a/pkg/wshutil/wshrouter_controlimpl.go b/pkg/wshutil/wshrouter_controlimpl.go
index 0cc29ca2f9..d0ebc034c4 100644
--- a/pkg/wshutil/wshrouter_controlimpl.go
+++ b/pkg/wshutil/wshrouter_controlimpl.go
@@ -91,7 +91,7 @@ func (impl *WshRouterControlImpl) AuthenticateCommand(ctx context.Context, data
 		return wshrpc.CommandAuthenticateRtnData{}, err
 	}
 
-	rtnData := wshrpc.CommandAuthenticateRtnData{}
+	rtnData := wshrpc.CommandAuthenticateRtnData{RouteId: routeId}
 	if newCtx.IsRouter {
 		log.Printf("wshrouter authenticate success linkid=%d (router)", linkId)
 		impl.Router.trustLink(linkId, LinkKind_Router)
@@ -116,10 +116,12 @@ func extractTokenData(token string) (wshrpc.CommandAuthenticateRtnData, error) {
 	if entry.RpcContext.IsRouter {
 		return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("cannot auth router via token")
 	}
-	if entry.RpcContext.RouteId == "" {
+	routeId := entry.RpcContext.GenerateRouteId()
+	if routeId == "" {
 		return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no routeid")
 	}
 	return wshrpc.CommandAuthenticateRtnData{
+		RouteId:        routeId,
 		Env:            entry.Env,
 		InitScriptText: entry.ScriptText,
 		RpcContext:     entry.RpcContext,
@@ -140,7 +142,7 @@ func (impl *WshRouterControlImpl) AuthenticateTokenVerifyCommand(ctx context.Con
 		return wshrpc.CommandAuthenticateRtnData{}, err
 	}
 
-	log.Printf("wshrouter authenticate-token-verify success routeid=%q", rtnData.RpcContext.RouteId)
+	log.Printf("wshrouter authenticate-token-verify success routeid=%q", rtnData.RouteId)
 	return rtnData, nil
 }
 
@@ -186,9 +188,12 @@ func (impl *WshRouterControlImpl) AuthenticateTokenCommand(ctx context.Context,
 	if rtnData.RpcContext == nil {
 		return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no rpccontext in token response")
 	}
-	log.Printf("wshrouter authenticate-token success linkid=%d routeid=%q", linkId, rtnData.RpcContext.RouteId)
+	if rtnData.RouteId == "" {
+		return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no routeid in token response")
+	}
+	log.Printf("wshrouter authenticate-token success linkid=%d routeid=%q", linkId, rtnData.RouteId)
 	impl.Router.trustLink(linkId, LinkKind_Leaf)
-	impl.Router.bindRoute(linkId, rtnData.RpcContext.RouteId, true)
+	impl.Router.bindRoute(linkId, rtnData.RouteId, true)
 	return rtnData, nil
 }
 
@@ -275,11 +280,14 @@ func validateRpcContextFromAuth(newCtx *wshrpc.RpcContext) (string, error) {
 	if newCtx.IsRouter && newCtx.RouteId != "" {
 		return "", fmt.Errorf("invalid context, router cannot have a routeid")
 	}
-	if !newCtx.IsRouter && newCtx.RouteId == "" {
+	if newCtx.IsRouter && newCtx.ProcRoute {
+		return "", fmt.Errorf("invalid context, router cannot have a proc-route")
+	}
+	if !newCtx.IsRouter && newCtx.RouteId == "" && !newCtx.ProcRoute {
 		return "", fmt.Errorf("invalid context, must have a routeid")
 	}
 	if newCtx.IsRouter {
 		return "", nil
 	}
-	return newCtx.RouteId, nil
+	return newCtx.GenerateRouteId(), nil
 }
diff --git a/pkg/wshutil/wshrpc.go b/pkg/wshutil/wshrpc.go
index eb2903c1f7..8a29e204bd 100644
--- a/pkg/wshutil/wshrpc.go
+++ b/pkg/wshutil/wshrpc.go
@@ -43,7 +43,7 @@ type ServerImpl interface {
 
 type AbstractRpcClient interface {
 	GetPeerInfo() string
-	SendRpcMessage(msg []byte, ingressLinkId baseds.LinkId, debugStr string)
+	SendRpcMessage(msg []byte, ingressLinkId baseds.LinkId, debugStr string) bool
 	RecvRpcMessage() ([]byte, bool) // blocking
 }
 
@@ -110,8 +110,13 @@ func (w *WshRpc) GetPeerInfo() string {
 	return w.DebugName
 }
 
-func (w *WshRpc) SendRpcMessage(msg []byte, ingressLinkId baseds.LinkId, debugStr string) {
-	w.InputCh <- baseds.RpcInputChType{MsgBytes: msg, IngressLinkId: ingressLinkId}
+func (w *WshRpc) SendRpcMessage(msg []byte, ingressLinkId baseds.LinkId, debugStr string) bool {
+	select {
+	case w.InputCh <- baseds.RpcInputChType{MsgBytes: msg, IngressLinkId: ingressLinkId}:
+		return true
+	default:
+		return false
+	}
 }
 
 func (w *WshRpc) RecvRpcMessage() ([]byte, bool) {
diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go
index b5e938839c..d979e7652f 100644
--- a/pkg/wshutil/wshutil.go
+++ b/pkg/wshutil/wshutil.go
@@ -187,6 +187,13 @@ func tryTcpSocket(sockName string) (net.Conn, error) {
 
 func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl, debugName string) (*WshRpc, error) {
 	sockName = wavebase.ExpandHomeDirSafe(sockName)
+	resolvedPath, err := filepath.EvalSymlinks(sockName)
+	if err == nil {
+		sockName = resolvedPath
+	}
+	if !filepath.IsAbs(sockName) {
+		return nil, fmt.Errorf("socket path must be absolute: %s", sockName)
+	}
 	conn, tcpErr := tryTcpSocket(sockName)
 	var unixErr error
 	if tcpErr != nil {
@@ -211,30 +218,32 @@ func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl, debugNam
 
 func MakeClientJWTToken(rpcCtx wshrpc.RpcContext) (string, error) {
 	if wavebase.IsDevMode() {
-		if rpcCtx.IsRouter && rpcCtx.RouteId != "" {
+		if rpcCtx.IsRouter && (rpcCtx.RouteId != "" || rpcCtx.ProcRoute) {
 			panic("Invalid RpcCtx, router w/ routeid")
 		}
-		if !rpcCtx.IsRouter && rpcCtx.RouteId == "" {
+		if !rpcCtx.IsRouter && (rpcCtx.RouteId == "" && !rpcCtx.ProcRoute) {
 			panic("Invalid RpcCtx, no routeid")
 		}
 	}
 	claims := &wavejwt.WaveJwtClaims{
-		Sock:    rpcCtx.SockName,
-		RouteId: rpcCtx.RouteId,
-		BlockId: rpcCtx.BlockId,
-		Conn:    rpcCtx.Conn,
-		Router:  rpcCtx.IsRouter,
+		Sock:      rpcCtx.SockName,
+		RouteId:   rpcCtx.RouteId,
+		ProcRoute: rpcCtx.ProcRoute,
+		BlockId:   rpcCtx.BlockId,
+		Conn:      rpcCtx.Conn,
+		Router:    rpcCtx.IsRouter,
 	}
 	return wavejwt.Sign(claims)
 }
 
 func claimsToRpcCtx(claims *wavejwt.WaveJwtClaims) *wshrpc.RpcContext {
 	return &wshrpc.RpcContext{
-		SockName: claims.Sock,
-		RouteId:  claims.RouteId,
-		BlockId:  claims.BlockId,
-		Conn:     claims.Conn,
-		IsRouter: claims.Router,
+		SockName:  claims.Sock,
+		RouteId:   claims.RouteId,
+		ProcRoute: claims.ProcRoute,
+		BlockId:   claims.BlockId,
+		Conn:      claims.Conn,
+		IsRouter:  claims.Router,
 	}
 }
 
@@ -395,8 +404,8 @@ func GetInfo() wshrpc.RemoteInfo {
 		ClientOs:      runtime.GOOS,
 		ClientVersion: wavebase.WaveVersion,
 		Shell:         getShell(),
+		HomeDir:       wavebase.GetHomeDir(),
 	}
-
 }
 
 func InstallRcFiles() error {
diff --git a/pkg/wstore/wstore.go b/pkg/wstore/wstore.go
index 91675729cb..4ce08d233b 100644
--- a/pkg/wstore/wstore.go
+++ b/pkg/wstore/wstore.go
@@ -6,6 +6,7 @@ package wstore
 import (
 	"context"
 	"fmt"
+	"sync"
 
 	"github.com/wavetermdev/waveterm/pkg/util/utilfn"
 	"github.com/wavetermdev/waveterm/pkg/waveobj"
@@ -17,6 +18,23 @@ func init() {
 	}
 }
 
+var (
+	clientIdLock   sync.Mutex
+	cachedClientId string
+)
+
+func SetClientId(clientId string) {
+	clientIdLock.Lock()
+	defer clientIdLock.Unlock()
+	cachedClientId = clientId
+}
+
+func GetClientId() string {
+	clientIdLock.Lock()
+	defer clientIdLock.Unlock()
+	return cachedClientId
+}
+
 func UpdateTabName(ctx context.Context, tabId, name string) error {
	return WithTx(ctx, func(tx *TxWrap) error {
		tab, _ := DBGet[*waveobj.Tab](tx.Context(), tabId)