diff --git a/package.json b/package.json index b827cbac..81a3a2e7 100644 --- a/package.json +++ b/package.json @@ -265,6 +265,12 @@ "title": "Search", "category": "Coder", "icon": "$(search)" + }, + { + "command": "coder.debug.listDeployments", + "title": "List Stored Deployments", + "category": "Coder Debug", + "when": "coder.devMode" } ], "menus": { diff --git a/src/api/coderApi.ts b/src/api/coderApi.ts index 04c696be..f6a13355 100644 --- a/src/api/coderApi.ts +++ b/src/api/coderApi.ts @@ -31,7 +31,7 @@ import { HttpClientLogLevel, } from "../logging/types"; import { sizeOf } from "../logging/utils"; -import { HttpStatusCode } from "../websocket/codes"; +import { HttpStatusCode, WebSocketCloseCode } from "../websocket/codes"; import { type UnidirectionalStream, type CloseEvent, @@ -55,9 +55,9 @@ const coderSessionTokenHeader = "Coder-Session-Token"; * Unified API class that includes both REST API methods from the base Api class * and WebSocket methods for real-time functionality. */ -export class CoderApi extends Api { +export class CoderApi extends Api implements vscode.Disposable { private readonly reconnectingSockets = new Set< - ReconnectingWebSocket + ReconnectingWebSocket >(); private constructor(private readonly output: Logger) { @@ -74,75 +74,110 @@ export class CoderApi extends Api { output: Logger, ): CoderApi { const client = new CoderApi(output); - client.setHost(baseUrl); - if (token) { - client.setSessionToken(token); - } + client.setCredentials(baseUrl, token); setupInterceptors(client, output); return client; } - setSessionToken = (token: string): void => { - const defaultHeaders = this.getAxiosInstance().defaults.headers.common; - const currentToken = defaultHeaders[coderSessionTokenHeader]; - defaultHeaders[coderSessionTokenHeader] = token; + getHost(): string | undefined { + return this.getAxiosInstance().defaults.baseURL; + } - if (currentToken !== token) { - for (const socket of this.reconnectingSockets) { - socket.reconnect(); - } - } - }; + getSessionToken(): string | undefined { + return this.getAxiosInstance().defaults.headers.common[ + coderSessionTokenHeader + ] as string | undefined; + } - setHost = (host: string | undefined): void => { + /** + * Set both host and token together. Useful for login/logout/switch to + * avoid triggering multiple reconnection events. + */ + setCredentials = ( + host: string | undefined, + token: string | undefined, + ): void => { + const currentHost = this.getHost(); + const currentToken = this.getSessionToken(); + + // We cannot use the super.setHost/setSessionToken methods because they are shadowed here const defaults = this.getAxiosInstance().defaults; - const currentHost = defaults.baseURL; defaults.baseURL = host; + defaults.headers.common[coderSessionTokenHeader] = token; - if (currentHost !== host) { + const hostChanged = (currentHost || "") !== (host || ""); + const tokenChanged = currentToken !== token; + + if (hostChanged || tokenChanged) { for (const socket of this.reconnectingSockets) { - socket.reconnect(); + if (host) { + socket.reconnect(); + } else { + socket.disconnect(WebSocketCloseCode.NORMAL, "Host cleared"); + } } } }; + override setSessionToken = (token: string): void => { + this.setCredentials(this.getHost(), token); + }; + + override setHost = (host: string | undefined): void => { + this.setCredentials(host, this.getSessionToken()); + }; + + /** + * Permanently dispose all WebSocket connections. + * This clears handlers and prevents reconnection. + */ + dispose(): void { + for (const socket of this.reconnectingSockets) { + socket.close(); + } + this.reconnectingSockets.clear(); + } + watchInboxNotifications = async ( watchTemplates: string[], watchTargets: string[], options?: ClientOptions, ) => { - return this.createWebSocket({ - apiRoute: "/api/v2/notifications/inbox/watch", - searchParams: { - format: "plaintext", - templates: watchTemplates.join(","), - targets: watchTargets.join(","), - }, - options, - enableRetry: true, - }); + return this.createReconnectingSocket(() => + this.createOneWayWebSocket({ + apiRoute: "/api/v2/notifications/inbox/watch", + searchParams: { + format: "plaintext", + templates: watchTemplates.join(","), + targets: watchTargets.join(","), + }, + options, + }), + ); }; watchWorkspace = async (workspace: Workspace, options?: ClientOptions) => { - return this.createWebSocketWithFallback({ - apiRoute: `/api/v2/workspaces/${workspace.id}/watch-ws`, - fallbackApiRoute: `/api/v2/workspaces/${workspace.id}/watch`, - options, - enableRetry: true, - }); + return this.createReconnectingSocket(() => + this.createStreamWithSseFallback({ + apiRoute: `/api/v2/workspaces/${workspace.id}/watch-ws`, + fallbackApiRoute: `/api/v2/workspaces/${workspace.id}/watch`, + options, + }), + ); }; watchAgentMetadata = async ( agentId: WorkspaceAgent["id"], options?: ClientOptions, ) => { - return this.createWebSocketWithFallback({ - apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`, - fallbackApiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata`, - options, - enableRetry: true, - }); + return this.createReconnectingSocket(() => + this.createStreamWithSseFallback({ + apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`, + fallbackApiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata`, + options, + }), + ); }; watchBuildLogsByBuildId = async ( @@ -180,86 +215,60 @@ export class CoderApi extends Api { searchParams.append("after", lastLog.id.toString()); } - return this.createWebSocket({ + return this.createOneWayWebSocket({ apiRoute, searchParams, options, }); } - private async createWebSocket( - configs: Omit & { enableRetry?: boolean }, - ): Promise> { - const { enableRetry, ...socketConfigs } = configs; - - const socketFactory: SocketFactory = async () => { - const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; - if (!baseUrlRaw) { - throw new Error("No base URL set on REST client"); - } - - const baseUrl = new URL(baseUrlRaw); - const token = this.getAxiosInstance().defaults.headers.common[ - coderSessionTokenHeader - ] as string | undefined; - - const headersFromCommand = await getHeaders( - baseUrlRaw, - getHeaderCommand(vscode.workspace.getConfiguration()), - this.output, - ); - - const httpAgent = await createHttpAgent( - vscode.workspace.getConfiguration(), - ); + private async createOneWayWebSocket( + configs: Omit, + ): Promise> { + const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; + if (!baseUrlRaw) { + throw new Error("No base URL set on REST client"); + } + const token = this.getAxiosInstance().defaults.headers.common[ + coderSessionTokenHeader + ] as string | undefined; - /** - * Similar to the REST client, we want to prioritize headers in this order (highest to lowest): - * 1. Headers from the header command - * 2. Any headers passed directly to this function - * 3. Coder session token from the Api client (if set) - */ - const headers = { - ...(token ? { [coderSessionTokenHeader]: token } : {}), - ...configs.options?.headers, - ...headersFromCommand, - }; + const headersFromCommand = await getHeaders( + baseUrlRaw, + getHeaderCommand(vscode.workspace.getConfiguration()), + this.output, + ); - const webSocket = new OneWayWebSocket({ - location: baseUrl, - ...socketConfigs, - options: { - ...configs.options, - agent: httpAgent, - followRedirects: true, - headers, - }, - }); + const httpAgent = await createHttpAgent( + vscode.workspace.getConfiguration(), + ); - this.attachStreamLogger(webSocket); - return webSocket; + /** + * Similar to the REST client, we want to prioritize headers in this order (highest to lowest): + * 1. Headers from the header command + * 2. Any headers passed directly to this function + * 3. Coder session token from the Api client (if set) + */ + const headers = { + ...(token ? { [coderSessionTokenHeader]: token } : {}), + ...configs.options?.headers, + ...headersFromCommand, }; - if (enableRetry) { - const reconnectingSocket = await ReconnectingWebSocket.create( - socketFactory, - this.output, - configs.apiRoute, - undefined, - () => - this.reconnectingSockets.delete( - reconnectingSocket as ReconnectingWebSocket, - ), - ); - - this.reconnectingSockets.add( - reconnectingSocket as ReconnectingWebSocket, - ); + const baseUrl = new URL(baseUrlRaw); + const ws = new OneWayWebSocket({ + location: baseUrl, + ...configs, + options: { + ...configs.options, + agent: httpAgent, + followRedirects: true, + headers, + }, + }); - return reconnectingSocket; - } else { - return socketFactory(); - } + this.attachStreamLogger(ws); + return ws; } private attachStreamLogger( @@ -288,44 +297,67 @@ export class CoderApi extends Api { /** * Create a WebSocket connection with SSE fallback on 404. * + * Tries WS first, falls back to SSE on 404. + * * Note: The fallback on SSE ignores all passed client options except the headers. */ - private async createWebSocketWithFallback(configs: { - apiRoute: string; - fallbackApiRoute: string; - searchParams?: Record | URLSearchParams; - options?: ClientOptions; - enableRetry?: boolean; - }): Promise> { - let webSocket: UnidirectionalStream; + private async createStreamWithSseFallback( + configs: Omit & { + fallbackApiRoute: string; + }, + ): Promise> { + const { fallbackApiRoute, ...socketConfigs } = configs; try { - webSocket = await this.createWebSocket({ - apiRoute: configs.apiRoute, - searchParams: configs.searchParams, - options: configs.options, - enableRetry: configs.enableRetry, - }); - } catch { - // Failed to create WebSocket, use SSE fallback - return this.createSseFallback( - configs.fallbackApiRoute, - configs.searchParams, - configs.options?.headers, - ); + const ws = + await this.createOneWayWebSocket(socketConfigs); + return await this.waitForOpen(ws); + } catch (error) { + if (this.is404Error(error)) { + this.output.warn( + `WebSocket failed, using SSE fallback: ${socketConfigs.apiRoute}`, + ); + const sse = this.createSseConnection( + fallbackApiRoute, + socketConfigs.searchParams, + socketConfigs.options?.headers, + ); + return await this.waitForOpen(sse); + } + throw error; } + } - return this.waitForConnection(webSocket, () => - this.createSseFallback( - configs.fallbackApiRoute, - configs.searchParams, - configs.options?.headers, - ), - ); + /** + * Create an SSE connection without waiting for connection. + */ + private createSseConnection( + apiRoute: string, + searchParams?: Record | URLSearchParams, + optionsHeaders?: Record, + ): SseConnection { + const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; + if (!baseUrlRaw) { + throw new Error("No base URL set on REST client"); + } + const url = new URL(baseUrlRaw); + const sse = new SseConnection({ + location: url, + apiRoute, + searchParams, + axiosInstance: this.getAxiosInstance(), + optionsHeaders, + logger: this.output, + }); + + this.attachStreamLogger(sse); + return sse; } - private waitForConnection( + /** + * Wait for a connection to open. Rejects on error. + */ + private waitForOpen( connection: UnidirectionalStream, - onNotFound?: () => Promise>, ): Promise> { return new Promise((resolve, reject) => { const cleanup = () => { @@ -340,16 +372,8 @@ export class CoderApi extends Api { const handleError = (event: ErrorEvent) => { cleanup(); - const is404 = - event.message?.includes(String(HttpStatusCode.NOT_FOUND)) || - event.error?.message?.includes(String(HttpStatusCode.NOT_FOUND)); - - if (is404 && onNotFound) { - connection.close(); - onNotFound().then(resolve).catch(reject); - } else { - reject(event.error || new Error(event.message)); - } + connection.close(); + reject(event.error || new Error(event.message)); }; connection.addEventListener("open", handleOpen); @@ -358,32 +382,29 @@ export class CoderApi extends Api { } /** - * Create SSE fallback connection + * Check if an error is a 404 Not Found error. */ - private async createSseFallback( - apiRoute: string, - searchParams?: Record | URLSearchParams, - optionsHeaders?: Record, - ): Promise> { - this.output.warn(`WebSocket failed, using SSE fallback: ${apiRoute}`); + private is404Error(error: unknown): boolean { + const msg = error instanceof Error ? error.message : String(error); + return msg.includes(String(HttpStatusCode.NOT_FOUND)); + } - const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; - if (!baseUrlRaw) { - throw new Error("No base URL set on REST client"); - } + /** + * Create a ReconnectingWebSocket and track it for lifecycle management. + */ + private async createReconnectingSocket( + socketFactory: SocketFactory, + ): Promise> { + const reconnectingSocket = await ReconnectingWebSocket.create( + socketFactory, + this.output, + undefined, + () => this.reconnectingSockets.delete(reconnectingSocket), + ); - const baseUrl = new URL(baseUrlRaw); - const sseConnection = new SseConnection({ - location: baseUrl, - apiRoute, - searchParams, - axiosInstance: this.getAxiosInstance(), - optionsHeaders: optionsHeaders, - logger: this.output, - }); + this.reconnectingSockets.add(reconnectingSocket); - this.attachStreamLogger(sseConnection); - return this.waitForConnection(sseConnection); + return reconnectingSocket; } } @@ -457,7 +478,7 @@ function addLoggingInterceptors(client: AxiosInstance, logger: Logger) { }, (error: unknown) => { logError(logger, error, getLogLevel()); - return Promise.reject(error); + throw error; }, ); @@ -468,7 +489,7 @@ function addLoggingInterceptors(client: AxiosInstance, logger: Logger) { }, (error: unknown) => { logError(logger, error, getLogLevel()); - return Promise.reject(error); + throw error; }, ); } diff --git a/src/commands.ts b/src/commands.ts index 9bb2ed54..000e0bcc 100644 --- a/src/commands.ts +++ b/src/commands.ts @@ -1,15 +1,11 @@ -import { type Api } from "coder/site/src/api/api"; -import { getErrorMessage } from "coder/site/src/api/errors"; import { - type User, type Workspace, type WorkspaceAgent, } from "coder/site/src/api/typesGenerated"; import * as vscode from "vscode"; import { createWorkspaceIdentifier, extractAgents } from "./api/api-helper"; -import { CoderApi } from "./api/coderApi"; -import { needToken } from "./api/utils"; +import { type CoderApi } from "./api/coderApi"; import { getGlobalFlags } from "./cliConfig"; import { type CliManager } from "./core/cliManager"; import { type ServiceContainer } from "./core/container"; @@ -17,8 +13,10 @@ import { type ContextManager } from "./core/contextManager"; import { type MementoManager } from "./core/mementoManager"; import { type PathResolver } from "./core/pathResolver"; import { type SecretsManager } from "./core/secretsManager"; +import { type DeploymentManager } from "./deployment/deploymentManager"; import { CertificateError } from "./error"; import { type Logger } from "./logging/logger"; +import { type LoginCoordinator } from "./login/loginCoordinator"; import { maybeAskAgent, maybeAskUrl } from "./promptUtils"; import { escapeCommandArg, toRemoteAuthority, toSafeHost } from "./util"; import { @@ -35,6 +33,8 @@ export class Commands { private readonly secretsManager: SecretsManager; private readonly cliManager: CliManager; private readonly contextManager: ContextManager; + private readonly loginCoordinator: LoginCoordinator; + // These will only be populated when actively connected to a workspace and are // used in commands. Because commands can be executed by the user, it is not // possible to pass in arguments, so we have to store the current workspace @@ -44,11 +44,12 @@ export class Commands { // if you use multiple deployments). public workspace?: Workspace; public workspaceLogPath?: string; - public workspaceRestClient?: Api; + public remoteWorkspaceClient?: CoderApi; public constructor( serviceContainer: ServiceContainer, - private readonly restClient: Api, + private readonly extensionClient: CoderApi, + private readonly deploymentManager: DeploymentManager, ) { this.vscodeProposed = serviceContainer.getVsCodeProposed(); this.logger = serviceContainer.getLogger(); @@ -57,6 +58,18 @@ export class Commands { this.secretsManager = serviceContainer.getSecretsManager(); this.cliManager = serviceContainer.getCliManager(); this.contextManager = serviceContainer.getContextManager(); + this.loginCoordinator = serviceContainer.getLoginCoordinator(); + } + + /** + * Get the current deployment, throwing if not logged in. + */ + private requireExtensionBaseUrl(): string { + const url = this.extensionClient.getAxiosInstance().defaults.baseURL; + if (!url) { + throw new Error("You are not logged in"); + } + return url; } /** @@ -66,53 +79,51 @@ export class Commands { */ public async login(args?: { url?: string; - token?: string; - label?: string; autoLogin?: boolean; }): Promise { - if (this.contextManager.get("coder.authenticated")) { + if (this.deploymentManager.isAuthenticated()) { return; } this.logger.info("Logging in"); - const url = await maybeAskUrl(this.mementoManager, args?.url); + const currentDeployment = await this.secretsManager.getCurrentDeployment(); + const url = await maybeAskUrl( + this.mementoManager, + args?.url, + currentDeployment?.url, + ); if (!url) { return; // The user aborted. } - // It is possible that we are trying to log into an old-style host, in which - // case we want to write with the provided blank label instead of generating - // a host label. - const label = args?.label === undefined ? toSafeHost(url) : args.label; - - // Try to get a token from the user, if we need one, and their user. - const autoLogin = args?.autoLogin === true; - const res = await this.maybeAskToken(url, args?.token, autoLogin); - if (!res) { - return; // The user aborted, or unable to auth. - } + const safeHostname = toSafeHost(url); + this.logger.info("Using hostname", safeHostname); - // The URL is good and the token is either good or not required; authorize - // the global client. - this.restClient.setHost(url); - this.restClient.setSessionToken(res.token); - - // Store these to be used in later sessions. - await this.mementoManager.setUrl(url); - await this.secretsManager.setSessionToken(res.token); + const result = await this.loginCoordinator.ensureLoggedIn({ + safeHostname, + url, + autoLogin: args?.autoLogin, + }); - // Store on disk to be used by the cli. - await this.cliManager.configure(label, url, res.token); + if (!result.success) { + return; + } - // These contexts control various menu items and the sidebar. - this.contextManager.set("coder.authenticated", true); - if (res.user.roles.find((role) => role.name === "owner")) { - this.contextManager.set("coder.isOwner", true); + if (!result.user) { + // Login might have happened in another process/window so we do not have the user yet. + result.user = await this.extensionClient.getAuthenticatedUser(); } + await this.deploymentManager.changeDeployment({ + url, + safeHostname, + token: result.token, + user: result.user, + }); + vscode.window .showInformationMessage( - `Welcome to Coder, ${res.user.username}!`, + `Welcome to Coder, ${result.user.username}!`, { detail: "You can now use the Coder extension to manage your Coder instance.", @@ -124,101 +135,6 @@ export class Commands { vscode.commands.executeCommand("coder.open"); } }); - - await this.secretsManager.triggerLoginStateChange("login"); - // Fetch workspaces for the new deployment. - vscode.commands.executeCommand("coder.refreshWorkspaces"); - } - - /** - * If necessary, ask for a token, and keep asking until the token has been - * validated. Return the token and user that was fetched to validate the - * token. Null means the user aborted or we were unable to authenticate with - * mTLS (in the latter case, an error notification will have been displayed). - */ - private async maybeAskToken( - url: string, - token: string | undefined, - isAutoLogin: boolean, - ): Promise<{ user: User; token: string } | null> { - const client = CoderApi.create(url, token, this.logger); - const needsToken = needToken(vscode.workspace.getConfiguration()); - if (!needsToken || token) { - try { - const user = await client.getAuthenticatedUser(); - // For non-token auth, we write a blank token since the `vscodessh` - // command currently always requires a token file. - // For token auth, we have valid access so we can just return the user here - return { token: needsToken && token ? token : "", user }; - } catch (err) { - const message = getErrorMessage(err, "no response from the server"); - if (isAutoLogin) { - this.logger.warn("Failed to log in to Coder server:", message); - } else { - this.vscodeProposed.window.showErrorMessage( - "Failed to log in to Coder server", - { - detail: message, - modal: true, - useCustom: true, - }, - ); - } - // Invalid certificate, most likely. - return null; - } - } - - // This prompt is for convenience; do not error if they close it since - // they may already have a token or already have the page opened. - await vscode.env.openExternal(vscode.Uri.parse(`${url}/cli-auth`)); - - // For token auth, start with the existing token in the prompt or the last - // used token. Once submitted, if there is a failure we will keep asking - // the user for a new token until they quit. - let user: User | undefined; - const validatedToken = await vscode.window.showInputBox({ - title: "Coder API Key", - password: true, - placeHolder: "Paste your API key.", - value: token || (await this.secretsManager.getSessionToken()), - ignoreFocusOut: true, - validateInput: async (value) => { - if (!value) { - return null; - } - client.setSessionToken(value); - try { - user = await client.getAuthenticatedUser(); - } catch (err) { - // For certificate errors show both a notification and add to the - // text under the input box, since users sometimes miss the - // notification. - if (err instanceof CertificateError) { - err.showNotification(); - - return { - message: err.x509Err || err.message, - severity: vscode.InputBoxValidationSeverity.Error, - }; - } - // This could be something like the header command erroring or an - // invalid session token. - const message = getErrorMessage(err, "no response from the server"); - return { - message: "Failed to authenticate: " + message, - severity: vscode.InputBoxValidationSeverity.Error, - }; - } - }, - }); - - if (validatedToken && user) { - return { token: validatedToken, user }; - } - - // User aborted. - return null; } /** @@ -250,29 +166,14 @@ export class Commands { * Log out from the currently logged-in deployment. */ public async logout(): Promise { - const url = this.mementoManager.getUrl(); - if (!url) { - // Sanity check; command should not be available if no url. - throw new Error("You are not logged in"); - } - await this.forceLogout(); - } - - public async forceLogout(): Promise { - if (!this.contextManager.get("coder.authenticated")) { + if (!this.deploymentManager.isAuthenticated()) { return; } + this.logger.info("Logging out"); - // Clear from the REST client. An empty url will indicate to other parts of - // the code that we are logged out. - this.restClient.setHost(""); - this.restClient.setSessionToken(""); - // Clear from memory. - await this.mementoManager.setUrl(undefined); - await this.secretsManager.setSessionToken(undefined); + await this.deploymentManager.clearDeployment(); - this.contextManager.set("coder.authenticated", false); vscode.window .showInformationMessage("You've been logged out of Coder!", "Login") .then((action) => { @@ -280,10 +181,6 @@ export class Commands { this.login(); } }); - - await this.secretsManager.triggerLoginStateChange("logout"); - // This will result in clearing the workspace list. - vscode.commands.executeCommand("coder.refreshWorkspaces"); } /** @@ -292,7 +189,8 @@ export class Commands { * Must only be called if currently logged in. */ public async createWorkspace(): Promise { - const uri = this.mementoManager.getUrl() + "/templates"; + const baseUrl = this.requireExtensionBaseUrl(); + const uri = baseUrl + "/templates"; await vscode.commands.executeCommand("vscode.open", uri); } @@ -306,12 +204,13 @@ export class Commands { */ public async navigateToWorkspace(item: OpenableTreeItem) { if (item) { + const baseUrl = this.requireExtensionBaseUrl(); const workspaceId = createWorkspaceIdentifier(item.workspace); - const uri = this.mementoManager.getUrl() + `/@${workspaceId}`; + const uri = baseUrl + `/@${workspaceId}`; await vscode.commands.executeCommand("vscode.open", uri); - } else if (this.workspace && this.workspaceRestClient) { + } else if (this.workspace && this.remoteWorkspaceClient) { const baseUrl = - this.workspaceRestClient.getAxiosInstance().defaults.baseURL; + this.remoteWorkspaceClient.getAxiosInstance().defaults.baseURL; const uri = `${baseUrl}/@${createWorkspaceIdentifier(this.workspace)}`; await vscode.commands.executeCommand("vscode.open", uri); } else { @@ -329,12 +228,13 @@ export class Commands { */ public async navigateToWorkspaceSettings(item: OpenableTreeItem) { if (item) { + const baseUrl = this.requireExtensionBaseUrl(); const workspaceId = createWorkspaceIdentifier(item.workspace); - const uri = this.mementoManager.getUrl() + `/@${workspaceId}/settings`; + const uri = baseUrl + `/@${workspaceId}/settings`; await vscode.commands.executeCommand("vscode.open", uri); - } else if (this.workspace && this.workspaceRestClient) { + } else if (this.workspace && this.remoteWorkspaceClient) { const baseUrl = - this.workspaceRestClient.getAxiosInstance().defaults.baseURL; + this.remoteWorkspaceClient.getAxiosInstance().defaults.baseURL; const uri = `${baseUrl}/@${createWorkspaceIdentifier(this.workspace)}/settings`; await vscode.commands.executeCommand("vscode.open", uri); } else { @@ -352,7 +252,7 @@ export class Commands { */ public async openFromSidebar(item: OpenableTreeItem) { if (item) { - const baseUrl = this.restClient.getAxiosInstance().defaults.baseURL; + const baseUrl = this.extensionClient.getAxiosInstance().defaults.baseURL; if (!baseUrl) { throw new Error("You are not logged in"); } @@ -379,7 +279,7 @@ export class Commands { true, ); } else { - throw new Error("Unable to open unknown sidebar item"); + throw new TypeError("Unable to open unknown sidebar item"); } } else { // If there is no tree item, then the user manually ran this command. @@ -407,25 +307,20 @@ export class Commands { const terminal = vscode.window.createTerminal(app.name); // If workspace_name is provided, run coder ssh before the command - - const url = this.mementoManager.getUrl(); - if (!url) { - throw new Error("No coder url found for sidebar"); - } + const baseUrl = this.requireExtensionBaseUrl(); + const safeHost = toSafeHost(baseUrl); const binary = await this.cliManager.fetchBinary( - this.restClient, - toSafeHost(url), + this.extensionClient, + safeHost, ); - const configDir = this.pathResolver.getGlobalConfigDir( - toSafeHost(url), - ); + const configDir = this.pathResolver.getGlobalConfigDir(safeHost); const globalFlags = getGlobalFlags( vscode.workspace.getConfiguration(), configDir, ); terminal.sendText( - `${escapeCommandArg(binary)}${` ${globalFlags.join(" ")}`} ssh ${app.workspace_name}`, + `${escapeCommandArg(binary)} ${globalFlags.join(" ")} ssh ${app.workspace_name}`, ); await new Promise((resolve) => setTimeout(resolve, 5000)); terminal.sendText(app.command ?? ""); @@ -433,19 +328,6 @@ export class Commands { }, ); } - // Check if app has a URL to open - if (app.url) { - return vscode.window.withProgress( - { - location: vscode.ProgressLocation.Notification, - title: `Opening ${app.name || "application"} in browser...`, - cancellable: false, - }, - async () => { - await vscode.env.openExternal(vscode.Uri.parse(app.url!)); - }, - ); - } // If no URL or command, show information about the app status vscode.window.showInformationMessage(`${app.name}`, { @@ -469,14 +351,14 @@ export class Commands { folderPath?: string, openRecent?: boolean, ): Promise { - const baseUrl = this.restClient.getAxiosInstance().defaults.baseURL; + const baseUrl = this.extensionClient.getAxiosInstance().defaults.baseURL; if (!baseUrl) { throw new Error("You are not logged in"); } let workspace: Workspace | undefined; if (workspaceOwner && workspaceName) { - workspace = await this.restClient.getWorkspaceByOwnerAndName( + workspace = await this.extensionClient.getWorkspaceByOwnerAndName( workspaceOwner, workspaceName, ); @@ -512,7 +394,7 @@ export class Commands { localWorkspaceFolder: string = "", localConfigFile: string = "", ): Promise { - const baseUrl = this.restClient.getAxiosInstance().defaults.baseURL; + const baseUrl = this.extensionClient.getAxiosInstance().defaults.baseURL; if (!baseUrl) { throw new Error("You are not logged in"); } @@ -524,7 +406,7 @@ export class Commands { workspaceAgent, ); - const hostPath = localWorkspaceFolder ? localWorkspaceFolder : undefined; + const hostPath = localWorkspaceFolder || undefined; const configFile = hostPath && localConfigFile ? { @@ -568,7 +450,7 @@ export class Commands { * this is a no-op. */ public async updateWorkspace(): Promise { - if (!this.workspace || !this.workspaceRestClient) { + if (!this.workspace || !this.remoteWorkspaceClient) { return; } const action = await this.vscodeProposed.window.showWarningMessage( @@ -581,7 +463,7 @@ export class Commands { "Update", ); if (action === "Update") { - await this.workspaceRestClient.updateWorkspaceVersion(this.workspace); + await this.remoteWorkspaceClient.updateWorkspaceVersion(this.workspace); } } @@ -596,7 +478,7 @@ export class Commands { let lastWorkspaces: readonly Workspace[]; quickPick.onDidChangeValue((value) => { quickPick.busy = true; - this.restClient + this.extensionClient .getWorkspaces({ q: value, }) @@ -625,7 +507,6 @@ export class Commands { if (ex instanceof CertificateError) { ex.showNotification(); } - return; }); }); quickPick.show(); @@ -660,7 +541,7 @@ export class Commands { // we need to fetch the agents through the resources API, as the // workspaces query does not include agents when off. this.logger.info("Fetching agents from template version"); - const resources = await this.restClient.getTemplateVersionResources( + const resources = await this.extensionClient.getTemplateVersionResources( workspace.latest_build.template_version_id, ); return extractAgents(resources); diff --git a/src/core/cliManager.ts b/src/core/cliManager.ts index 5e0b3d26..6f1fdf8f 100644 --- a/src/core/cliManager.ts +++ b/src/core/cliManager.ts @@ -33,8 +33,8 @@ export class CliManager { /** * Download and return the path to a working binary for the deployment with - * the provided label using the provided client. If the label is empty, use - * the old deployment-unaware path instead. + * the provided hostname using the provided client. If the hostname is empty, + * use the old deployment-unaware path instead. * * If there is already a working binary and it matches the server version, * return that, skipping the download. If it does not match but downloads are @@ -42,7 +42,10 @@ export class CliManager { * unable to download a working binary, whether because of network issues or * downloads being disabled. */ - public async fetchBinary(restClient: Api, label: string): Promise { + public async fetchBinary( + restClient: Api, + safeHostname: string, + ): Promise { const cfg = vscode.workspace.getConfiguration("coder"); // Settings can be undefined when set to their defaults (true in this case), // so explicitly check against false. @@ -64,7 +67,7 @@ export class CliManager { // is valid and matches the server, or if it does not match the server but // downloads are disabled, we can return early. const binPath = path.join( - this.pathResolver.getBinaryCachePath(label), + this.pathResolver.getBinaryCachePath(safeHostname), cliUtils.name(), ); this.output.info("Using binary path", binPath); @@ -693,76 +696,71 @@ export class CliManager { } /** - * Configure the CLI for the deployment with the provided label. + * Configure the CLI for the deployment with the provided hostname. * * Falsey URLs and null tokens are a no-op; we avoid unconfiguring the CLI to * avoid breaking existing connections. */ public async configure( - label: string, + safeHostname: string, url: string | undefined, token: string | null, ) { await Promise.all([ - this.updateUrlForCli(label, url), - this.updateTokenForCli(label, token), + this.updateUrlForCli(safeHostname, url), + this.updateTokenForCli(safeHostname, token), ]); } /** - * Update the URL for the deployment with the provided label on disk which can - * be used by the CLI via --url-file. If the URL is falsey, do nothing. + * Update the URL for the deployment with the provided hostname on disk which + * can be used by the CLI via --url-file. If the URL is falsey, do nothing. * - * If the label is empty, read the old deployment-unaware config instead. + * If the hostname is empty, read the old deployment-unaware config instead. */ private async updateUrlForCli( - label: string, + safeHostname: string, url: string | undefined, ): Promise { if (url) { - const urlPath = this.pathResolver.getUrlPath(label); - await fs.mkdir(path.dirname(urlPath), { recursive: true }); - await fs.writeFile(urlPath, url); + const urlPath = this.pathResolver.getUrlPath(safeHostname); + await this.atomicWriteFile(urlPath, url); } } /** - * Update the session token for a deployment with the provided label on disk - * which can be used by the CLI via --session-token-file. If the token is - * null, do nothing. + * Update the session token for a deployment with the provided hostname on + * disk which can be used by the CLI via --session-token-file. If the token + * is null, do nothing. * - * If the label is empty, read the old deployment-unaware config instead. + * If the hostname is empty, read the old deployment-unaware config instead. */ - private async updateTokenForCli( - label: string, - token: string | undefined | null, - ) { + private async updateTokenForCli(safeHostname: string, token: string | null) { if (token !== null) { - const tokenPath = this.pathResolver.getSessionTokenPath(label); - await fs.mkdir(path.dirname(tokenPath), { recursive: true }); - await fs.writeFile(tokenPath, token ?? ""); + const tokenPath = this.pathResolver.getSessionTokenPath(safeHostname); + await this.atomicWriteFile(tokenPath, token); } } /** - * Read the CLI config for a deployment with the provided label. - * - * IF a config file does not exist, return an empty string. - * - * If the label is empty, read the old deployment-unaware config. + * Atomically write content to a file by writing to a temporary file first, + * then renaming it. */ - public async readConfig( - label: string, - ): Promise<{ url: string; token: string }> { - const urlPath = this.pathResolver.getUrlPath(label); - const tokenPath = this.pathResolver.getSessionTokenPath(label); - const [url, token] = await Promise.allSettled([ - fs.readFile(urlPath, "utf8"), - fs.readFile(tokenPath, "utf8"), - ]); - return { - url: url.status === "fulfilled" ? url.value.trim() : "", - token: token.status === "fulfilled" ? token.value.trim() : "", - }; + private async atomicWriteFile( + filePath: string, + content: string, + ): Promise { + await fs.mkdir(path.dirname(filePath), { recursive: true }); + const tempPath = + filePath + ".temp-" + Math.random().toString(36).substring(8); + try { + await fs.writeFile(tempPath, content); + await fs.rename(tempPath, filePath); + } catch (err) { + await fs.rm(tempPath, { force: true }).catch((rmErr) => { + this.output.warn("Failed to delete temp file", tempPath, rmErr); + }); + throw err; + } } } diff --git a/src/core/container.ts b/src/core/container.ts index a8f938ea..acf2d854 100644 --- a/src/core/container.ts +++ b/src/core/container.ts @@ -1,6 +1,7 @@ import * as vscode from "vscode"; import { type Logger } from "../logging/logger"; +import { LoginCoordinator } from "../login/loginCoordinator"; import { CliManager } from "./cliManager"; import { ContextManager } from "./contextManager"; @@ -19,6 +20,7 @@ export class ServiceContainer implements vscode.Disposable { private readonly secretsManager: SecretsManager; private readonly cliManager: CliManager; private readonly contextManager: ContextManager; + private readonly loginCoordinator: LoginCoordinator; constructor( context: vscode.ExtensionContext, @@ -30,13 +32,23 @@ export class ServiceContainer implements vscode.Disposable { context.logUri.fsPath, ); this.mementoManager = new MementoManager(context.globalState); - this.secretsManager = new SecretsManager(context.secrets); + this.secretsManager = new SecretsManager( + context.secrets, + context.globalState, + this.logger, + ); this.cliManager = new CliManager( this.vscodeProposed, this.logger, this.pathResolver, ); - this.contextManager = new ContextManager(); + this.contextManager = new ContextManager(context); + this.loginCoordinator = new LoginCoordinator( + this.secretsManager, + this.mementoManager, + this.vscodeProposed, + this.logger, + ); } getVsCodeProposed(): typeof vscode { @@ -67,6 +79,10 @@ export class ServiceContainer implements vscode.Disposable { return this.contextManager; } + getLoginCoordinator(): LoginCoordinator { + return this.loginCoordinator; + } + /** * Dispose of all services and clean up resources. */ diff --git a/src/core/contextManager.ts b/src/core/contextManager.ts index a5a18397..405850a2 100644 --- a/src/core/contextManager.ts +++ b/src/core/contextManager.ts @@ -12,10 +12,19 @@ type CoderContext = keyof typeof CONTEXT_DEFAULTS; export class ContextManager implements vscode.Disposable { private readonly context = new Map(); - public constructor() { - (Object.keys(CONTEXT_DEFAULTS) as CoderContext[]).forEach((key) => { + public constructor(extensionContext: vscode.ExtensionContext) { + for (const key of Object.keys(CONTEXT_DEFAULTS) as CoderContext[]) { this.set(key, CONTEXT_DEFAULTS[key]); - }); + } + this.setInternalContexts(extensionContext); + } + + private setInternalContexts(extensionContext: vscode.ExtensionContext): void { + vscode.commands.executeCommand( + "setContext", + "coder.devMode", + extensionContext.extensionMode === vscode.ExtensionMode.Development, + ); } public set(key: CoderContext, value: boolean): void { diff --git a/src/core/mementoManager.ts b/src/core/mementoManager.ts index f79be46c..3cf4478e 100644 --- a/src/core/mementoManager.ts +++ b/src/core/mementoManager.ts @@ -7,27 +7,16 @@ export class MementoManager { constructor(private readonly memento: Memento) {} /** - * Add the URL to the list of recently accessed URLs in global storage, then - * set it as the last used URL. - * - * If the URL is falsey, then remove it as the last used URL and do not touch - * the history. + * Add a URL to the history of recently accessed URLs. + * Used by the URL picker to show recent deployments. */ - public async setUrl(url?: string): Promise { - await this.memento.update("url", url); + public async addToUrlHistory(url: string): Promise { if (url) { const history = this.withUrlHistory(url); await this.memento.update("urlHistory", history); } } - /** - * Get the last used URL. - */ - public getUrl(): string | undefined { - return this.memento.get("url"); - } - /** * Get the most recently accessed URLs (oldest to newest) with the provided * values appended. Duplicates will be removed. diff --git a/src/core/pathResolver.ts b/src/core/pathResolver.ts index 514e64fb..8c320088 100644 --- a/src/core/pathResolver.ts +++ b/src/core/pathResolver.ts @@ -1,4 +1,4 @@ -import * as path from "path"; +import * as path from "node:path"; import * as vscode from "vscode"; export class PathResolver { @@ -8,26 +8,28 @@ export class PathResolver { ) {} /** - * Return the directory for the deployment with the provided label to where - * the global Coder configs are stored. + * Return the directory for the deployment with the provided hostname to + * where the global Coder configs are stored. * - * If the label is empty, read the old deployment-unaware config instead. + * If the hostname is empty, read the old deployment-unaware config instead. * * The caller must ensure this directory exists before use. */ - public getGlobalConfigDir(label: string): string { - return label ? path.join(this.basePath, label) : this.basePath; + public getGlobalConfigDir(safeHostname: string): string { + return safeHostname + ? path.join(this.basePath, safeHostname) + : this.basePath; } /** - * Return the directory for a deployment with the provided label to where its - * binary is cached. + * Return the directory for a deployment with the provided hostname to where + * its binary is cached. * - * If the label is empty, read the old deployment-unaware config instead. + * If the hostname is empty, read the old deployment-unaware config instead. * * The caller must ensure this directory exists before use. */ - public getBinaryCachePath(label: string): string { + public getBinaryCachePath(safeHostname: string): string { const settingPath = vscode.workspace .getConfiguration() .get("coder.binaryDestination") @@ -36,7 +38,7 @@ export class PathResolver { settingPath || process.env.CODER_BINARY_DESTINATION?.trim(); return binaryPath ? path.normalize(binaryPath) - : path.join(this.getGlobalConfigDir(label), "bin"); + : path.join(this.getGlobalConfigDir(safeHostname), "bin"); } /** @@ -69,39 +71,39 @@ export class PathResolver { } /** - * Return the directory for the deployment with the provided label to where - * its session token is stored. + * Return the directory for the deployment with the provided hostname to + * where its session token is stored. * - * If the label is empty, read the old deployment-unaware config instead. + * If the hostname is empty, read the old deployment-unaware config instead. * * The caller must ensure this directory exists before use. */ - public getSessionTokenPath(label: string): string { - return path.join(this.getGlobalConfigDir(label), "session"); + public getSessionTokenPath(safeHostname: string): string { + return path.join(this.getGlobalConfigDir(safeHostname), "session"); } /** - * Return the directory for the deployment with the provided label to where - * its session token was stored by older code. + * Return the directory for the deployment with the provided hostname to + * where its session token was stored by older code. * - * If the label is empty, read the old deployment-unaware config instead. + * If the hostname is empty, read the old deployment-unaware config instead. * * The caller must ensure this directory exists before use. */ - public getLegacySessionTokenPath(label: string): string { - return path.join(this.getGlobalConfigDir(label), "session_token"); + public getLegacySessionTokenPath(safeHostname: string): string { + return path.join(this.getGlobalConfigDir(safeHostname), "session_token"); } /** - * Return the directory for the deployment with the provided label to where - * its url is stored. + * Return the directory for the deployment with the provided hostname to + * where its url is stored. * - * If the label is empty, read the old deployment-unaware config instead. + * If the hostname is empty, read the old deployment-unaware config instead. * * The caller must ensure this directory exists before use. */ - public getUrlPath(label: string): string { - return path.join(this.getGlobalConfigDir(label), "url"); + public getUrlPath(safeHostname: string): string { + return path.join(this.getGlobalConfigDir(safeHostname), "url"); } /** diff --git a/src/core/secretsManager.ts b/src/core/secretsManager.ts index 94827b15..e6558299 100644 --- a/src/core/secretsManager.ts +++ b/src/core/secretsManager.ts @@ -1,73 +1,237 @@ -import type { SecretStorage, Disposable } from "vscode"; +import { type Logger } from "../logging/logger"; +import { toSafeHost } from "../util"; -const SESSION_TOKEN_KEY = "sessionToken"; +import type { Memento, SecretStorage, Disposable } from "vscode"; -const LOGIN_STATE_KEY = "loginState"; +import type { Deployment } from "../deployment/types"; -export enum AuthAction { - LOGIN, - LOGOUT, - INVALID, +// Each deployment has its own key to ensure atomic operations (multiple windows +// writing to a shared key could drop data) and to receive proper VS Code events. +const SESSION_KEY_PREFIX = "coder.session."; + +const CURRENT_DEPLOYMENT_KEY = "coder.currentDeployment"; + +const DEPLOYMENT_USAGE_KEY = "coder.deploymentUsage"; +const DEFAULT_MAX_DEPLOYMENTS = 10; + +const LEGACY_SESSION_TOKEN_KEY = "sessionToken"; + +export interface CurrentDeploymentState { + deployment: Deployment | null; +} + +export interface SessionAuth { + url: string; + token: string; +} + +// Tracks when a deployment was last accessed for LRU pruning. +interface DeploymentUsage { + safeHostname: string; + lastAccessedAt: string; } export class SecretsManager { - constructor(private readonly secrets: SecretStorage) {} + constructor( + private readonly secrets: SecretStorage, + private readonly memento: Memento, + private readonly logger: Logger, + ) {} /** - * Set or unset the last used token. + * Sets the current deployment and triggers a cross-window sync event. */ - public async setSessionToken(sessionToken?: string): Promise { - if (!sessionToken) { - await this.secrets.delete(SESSION_TOKEN_KEY); - } else { - await this.secrets.store(SESSION_TOKEN_KEY, sessionToken); - } + public async setCurrentDeployment( + deployment: Deployment | undefined, + ): Promise { + const state: CurrentDeploymentState & { timestamp: string } = { + // Extract the necessary fields before serializing + deployment: deployment + ? { + url: deployment?.url, + safeHostname: deployment?.safeHostname, + } + : null, + timestamp: new Date().toISOString(), + }; + await this.secrets.store(CURRENT_DEPLOYMENT_KEY, JSON.stringify(state)); } /** - * Get the last used token. + * Gets the current deployment from storage. */ - public async getSessionToken(): Promise { + public async getCurrentDeployment(): Promise { try { - return await this.secrets.get(SESSION_TOKEN_KEY); + const data = await this.secrets.get(CURRENT_DEPLOYMENT_KEY); + if (!data) { + return null; + } + const parsed = JSON.parse(data) as CurrentDeploymentState; + return parsed.deployment; } catch { - // The VS Code session store has become corrupt before, and - // will fail to get the session token... - return undefined; + return null; } } /** - * Triggers a login/logout event that propagates across all VS Code windows. - * Uses the secrets storage onDidChange event as a cross-window communication mechanism. - * Appends a timestamp to ensure the value always changes, guaranteeing the event fires. + * Listens for deployment changes from any VS Code window. + * Fires when login, logout, or deployment switch occurs. */ - public async triggerLoginStateChange( - action: "login" | "logout", - ): Promise { - const date = new Date().toISOString(); - await this.secrets.store(LOGIN_STATE_KEY, `${action}-${date}`); + public onDidChangeCurrentDeployment( + listener: (state: CurrentDeploymentState) => void | Promise, + ): Disposable { + return this.secrets.onDidChange(async (e) => { + if (e.key !== CURRENT_DEPLOYMENT_KEY) { + return; + } + + const deployment = await this.getCurrentDeployment(); + try { + await listener({ deployment }); + } catch (err) { + this.logger.error( + "Error in onDidChangeCurrentDeployment listener", + err, + ); + } + }); } /** - * Listens for login/logout events from any VS Code window. - * The secrets storage onDidChange event fires across all windows, enabling cross-window sync. + * Listen for changes to a specific deployment's session auth. */ - public onDidChangeLoginState( - listener: (state: AuthAction) => Promise, + public onDidChangeSessionAuth( + safeHostname: string, + listener: (auth: SessionAuth | undefined) => void | Promise, ): Disposable { + const sessionKey = this.getSessionKey(safeHostname); return this.secrets.onDidChange(async (e) => { - if (e.key === LOGIN_STATE_KEY) { - const state = await this.secrets.get(LOGIN_STATE_KEY); - if (state?.startsWith("login")) { - listener(AuthAction.LOGIN); - } else if (state?.startsWith("logout")) { - listener(AuthAction.LOGOUT); - } else { - // Secret was deleted or is invalid - listener(AuthAction.INVALID); - } + if (e.key !== sessionKey) { + return; + } + const auth = await this.getSessionAuth(safeHostname); + try { + await listener(auth); + } catch (err) { + this.logger.error("Error in onDidChangeSessionAuth listener", err); } }); } + + public async getSessionAuth( + safeHostname: string, + ): Promise { + const sessionKey = this.getSessionKey(safeHostname); + try { + const data = await this.secrets.get(sessionKey); + if (!data) { + return undefined; + } + return JSON.parse(data) as SessionAuth; + } catch { + return undefined; + } + } + + public async setSessionAuth( + safeHostname: string, + auth: SessionAuth, + ): Promise { + const sessionKey = this.getSessionKey(safeHostname); + // Extract only url and token before serializing + const state: SessionAuth = { url: auth.url, token: auth.token }; + await this.secrets.store(sessionKey, JSON.stringify(state)); + await this.recordDeploymentAccess(safeHostname); + } + + private async clearSessionAuth(safeHostname: string): Promise { + const sessionKey = this.getSessionKey(safeHostname); + await this.secrets.delete(sessionKey); + } + + private getSessionKey(safeHostname: string): string { + return `${SESSION_KEY_PREFIX}${safeHostname || ""}`; + } + + /** + * Record that a deployment was accessed, moving it to the front of the LRU list. + * Prunes deployments beyond maxCount, clearing their auth data. + */ + public async recordDeploymentAccess( + safeHostname: string, + maxCount = DEFAULT_MAX_DEPLOYMENTS, + ): Promise { + const usage = this.getDeploymentUsage(); + const filtered = usage.filter((u) => u.safeHostname !== safeHostname); + filtered.unshift({ + safeHostname, + lastAccessedAt: new Date().toISOString(), + }); + + const toKeep = filtered.slice(0, maxCount); + const toRemove = filtered.slice(maxCount); + + await Promise.all( + toRemove.map((u) => this.clearAllAuthData(u.safeHostname)), + ); + await this.memento.update(DEPLOYMENT_USAGE_KEY, toKeep); + } + + /** + * Clear all auth data for a deployment and remove it from the usage list. + */ + public async clearAllAuthData(safeHostname: string): Promise { + await this.clearSessionAuth(safeHostname); + const usage = this.getDeploymentUsage().filter( + (u) => u.safeHostname !== safeHostname, + ); + await this.memento.update(DEPLOYMENT_USAGE_KEY, usage); + } + + /** + * Get all known hostnames, ordered by most recently accessed. + */ + public getKnownSafeHostnames(): string[] { + return this.getDeploymentUsage().map((u) => u.safeHostname); + } + + /** + * Get the full deployment usage list with access timestamps. + */ + private getDeploymentUsage(): DeploymentUsage[] { + return this.memento.get(DEPLOYMENT_USAGE_KEY) ?? []; + } + + /** + * Migrate from legacy flat sessionToken storage to new format. + * Also sets the current deployment if none exists. + */ + public async migrateFromLegacyStorage(): Promise { + const legacyUrl = this.memento.get("url"); + if (!legacyUrl) { + return undefined; + } + + const oldToken = await this.secrets.get(LEGACY_SESSION_TOKEN_KEY); + + await this.secrets.delete(LEGACY_SESSION_TOKEN_KEY); + await this.memento.update("url", undefined); + + const safeHostname = toSafeHost(legacyUrl); + const existing = await this.getSessionAuth(safeHostname); + if (!existing) { + await this.setSessionAuth(safeHostname, { + url: legacyUrl, + token: oldToken ?? "", + }); + } + + // Also set as current deployment if none exists + const currentDeployment = await this.getCurrentDeployment(); + if (!currentDeployment) { + await this.setCurrentDeployment({ url: legacyUrl, safeHostname }); + } + + return safeHostname; + } } diff --git a/src/deployment/deploymentManager.ts b/src/deployment/deploymentManager.ts new file mode 100644 index 00000000..44d173c1 --- /dev/null +++ b/src/deployment/deploymentManager.ts @@ -0,0 +1,270 @@ +import type { User } from "coder/site/src/api/typesGenerated"; +import type * as vscode from "vscode"; + +import type { CoderApi } from "../api/coderApi"; +import type { ServiceContainer } from "../core/container"; +import type { ContextManager } from "../core/contextManager"; +import type { MementoManager } from "../core/mementoManager"; +import type { SecretsManager } from "../core/secretsManager"; +import type { Logger } from "../logging/logger"; +import type { WorkspaceProvider } from "../workspace/workspacesProvider"; + +import type { Deployment, DeploymentWithAuth } from "./types"; + +/** + * Internal state type that allows mutation of user property. + */ +type DeploymentWithUser = Deployment & { user?: User }; + +/** + * Manages deployment state for the extension. + * + * Centralizes: + * - In-memory deployment state (url, label, token, user) + * - Client credential updates + * - Auth listener registration + * - Context updates (coder.authenticated, coder.isOwner) + * - Workspace provider refresh + * - Cross-window sync handling + */ +export class DeploymentManager implements vscode.Disposable { + private readonly secretsManager: SecretsManager; + private readonly mementoManager: MementoManager; + private readonly contextManager: ContextManager; + private readonly logger: Logger; + + #deployment: DeploymentWithUser | null = null; + #authListenerDisposable: vscode.Disposable | undefined; + #crossWindowSyncDisposable: vscode.Disposable | undefined; + + private constructor( + serviceContainer: ServiceContainer, + private readonly client: CoderApi, + private readonly workspaceProviders: WorkspaceProvider[], + ) { + this.secretsManager = serviceContainer.getSecretsManager(); + this.mementoManager = serviceContainer.getMementoManager(); + this.contextManager = serviceContainer.getContextManager(); + this.logger = serviceContainer.getLogger(); + } + + public static create( + serviceContainer: ServiceContainer, + client: CoderApi, + workspaceProviders: WorkspaceProvider[], + ): DeploymentManager { + const manager = new DeploymentManager( + serviceContainer, + client, + workspaceProviders, + ); + manager.subscribeToCrossWindowChanges(); + return manager; + } + + /** + * Get the current deployment state. + */ + public getCurrentDeployment(): Deployment | null { + return this.#deployment; + } + + /** + * Check if we have an authenticated deployment (with a valid user). + */ + public isAuthenticated(): boolean { + return this.#deployment?.user !== undefined; + } + + /** + * Change to a fully authenticated deployment (with user). + * Use this when you already have the user from a successful login. + */ + public async changeDeployment( + deployment: DeploymentWithAuth & { user: User }, + ): Promise { + this.setDeploymentInternal(deployment); + await this.persistDeployment(deployment); + } + + /** + * Set deployment without requiring authentication. + * Immediately tries to fetch user and upgrade to authenticated state. + * Use this for startup or when you don't have the user yet. + */ + public async setDeploymentAndValidate( + deployment: Deployment & { token?: string }, + ): Promise { + this.setDeploymentInternal(deployment); + await this.tryFetchAndUpgradeUser(); + } + + /** + * Clears the current deployment. + */ + public async clearDeployment(): Promise { + this.#authListenerDisposable?.dispose(); + this.#authListenerDisposable = undefined; + this.#deployment = null; + + this.client.setCredentials(undefined, undefined); + this.updateAuthContexts(); + this.refreshWorkspaces(); + + await this.secretsManager.setCurrentDeployment(undefined); + } + + public dispose(): void { + this.#authListenerDisposable?.dispose(); + this.#crossWindowSyncDisposable?.dispose(); + } + + /** + * Internal method to set deployment state with all side effects. + * - Updates client credentials + * - Re-registers auth listener if hostname changed + * - Updates auth contexts + * - Refreshes workspaces + */ + private setDeploymentInternal(deployment: DeploymentWithAuth): void { + this.#deployment = { ...deployment }; + + // Update client credentials + if (deployment.token !== undefined) { + this.client.setCredentials(deployment.url, deployment.token); + } else { + this.client.setHost(deployment.url); + } + + this.registerAuthListener(); + this.updateAuthContexts(); + this.refreshWorkspaces(); + } + + /** + * Upgrade the current deployment with a user. + * Use this when the user has been fetched after initial deployment setup. + */ + private upgradeWithUser(user: User): void { + if (!this.#deployment) { + return; + } + + this.#deployment.user = user; + this.updateAuthContexts(); + this.refreshWorkspaces(); + } + + /** + * Register auth listener for the current deployment. + * Updates credentials when they change (token refresh, cross-window sync). + */ + private registerAuthListener(): void { + if (!this.#deployment) { + return; + } + + // Capture hostname at registration time for the guard clause + const safeHostname = this.#deployment.safeHostname; + + this.#authListenerDisposable?.dispose(); + this.logger.debug("Registering auth listener for hostname", safeHostname); + this.#authListenerDisposable = this.secretsManager.onDidChangeSessionAuth( + safeHostname, + async (auth) => { + if (this.#deployment?.safeHostname !== safeHostname) { + return; + } + + if (auth) { + this.client.setCredentials(auth.url, auth.token); + if (!this.isAuthenticated()) { + await this.tryFetchAndUpgradeUser(); + } + } else { + await this.clearDeployment(); + } + }, + ); + } + + private subscribeToCrossWindowChanges(): void { + this.#crossWindowSyncDisposable = + this.secretsManager.onDidChangeCurrentDeployment( + async ({ deployment }) => { + if (this.isAuthenticated()) { + // Ignore if we are already authenticated + return; + } + + if (deployment) { + this.logger.info("Deployment changed from another window"); + const auth = await this.secretsManager.getSessionAuth( + deployment.safeHostname, + ); + await this.setDeploymentAndValidate({ + ...deployment, + token: auth?.token, + }); + } + }, + ); + } + + /** + * Try to fetch the authenticated user and upgrade the deployment state. + */ + private async tryFetchAndUpgradeUser(): Promise { + if (!this.#deployment || this.isAuthenticated()) { + return; + } + + const safeHostname = this.#deployment.safeHostname; + + try { + const user = await this.client.getAuthenticatedUser(); + + // Re-validate deployment hasn't changed during await + if (this.#deployment?.safeHostname !== safeHostname) { + this.logger.debug( + "Deployment changed during user fetch, discarding result", + ); + return; + } + + this.upgradeWithUser(user); + + // Persist with user + await this.persistDeployment(this.#deployment); + } catch (e) { + this.logger.warn("Failed to fetch user:", e); + } + } + + /** + * Update authentication-related contexts. + */ + private updateAuthContexts(): void { + const user = this.#deployment?.user; + this.contextManager.set("coder.authenticated", Boolean(user)); + const isOwner = user?.roles.some((r) => r.name === "owner") ?? false; + this.contextManager.set("coder.isOwner", isOwner); + } + + /** + * Refresh all workspace providers asynchronously. + */ + private refreshWorkspaces(): void { + for (const provider of this.workspaceProviders) { + provider.fetchAndRefresh(); + } + } + + /** + * Persist deployment to storage for cross-window sync. + */ + private async persistDeployment(deployment: Deployment): Promise { + await this.secretsManager.setCurrentDeployment(deployment); + await this.mementoManager.addToUrlHistory(deployment.url); + } +} diff --git a/src/deployment/types.ts b/src/deployment/types.ts new file mode 100644 index 00000000..9200defb --- /dev/null +++ b/src/deployment/types.ts @@ -0,0 +1,23 @@ +import { type User } from "coder/site/src/api/typesGenerated"; + +/** + * Represents a Coder deployment with its URL and hostname. + * The safeHostname is used as a unique identifier for storing credentials and configuration. + * It is derived from the URL hostname (via toSafeHost) or from SSH host parsing. + */ +export interface Deployment { + readonly url: string; + readonly safeHostname: string; +} + +/** + * Deployment info with authentication credentials. + * Used when logging in or changing to a new deployment. + * + * - Undefined token means that we should not override the existing token (if any). + * - Undefined user means the deployment is set but not authenticated yet. + */ +export type DeploymentWithAuth = Deployment & { + readonly token?: string; + readonly user?: User; +}; diff --git a/src/extension.ts b/src/extension.ts index 974cbe7d..876dede6 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -11,7 +11,8 @@ import { CoderApi } from "./api/coderApi"; import { needToken } from "./api/utils"; import { Commands } from "./commands"; import { ServiceContainer } from "./core/container"; -import { AuthAction } from "./core/secretsManager"; +import { type SecretsManager } from "./core/secretsManager"; +import { DeploymentManager } from "./deployment/deploymentManager"; import { CertificateError, getErrorDetail } from "./error"; import { maybeAskUrl } from "./promptUtils"; import { Remote } from "./remote/remote"; @@ -60,18 +61,24 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { const secretsManager = serviceContainer.getSecretsManager(); const contextManager = serviceContainer.getContextManager(); + // Migrate auth storage from old flat format to new label-based format + await migrateAuthStorage(serviceContainer); + // Try to clear this flag ASAP const isFirstConnect = await mementoManager.getAndClearFirstConnect(); + const deployment = await secretsManager.getCurrentDeployment(); + // This client tracks the current login and will be used through the life of // the plugin to poll workspaces for the current login, as well as being used // in commands that operate on the current login. - const url = mementoManager.getUrl(); const client = CoderApi.create( - url || "", - await secretsManager.getSessionToken(), + deployment?.url || "", + (await secretsManager.getSessionAuth(deployment?.safeHostname ?? "")) + ?.token, output, ); + ctx.subscriptions.push(client); const myWorkspacesProvider = new WorkspaceProvider( WorkspaceQuery.Mine, @@ -116,11 +123,18 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { ctx.subscriptions, ); + // Create deployment manager to centralize deployment state management + const deploymentManager = DeploymentManager.create(serviceContainer, client, [ + myWorkspacesProvider, + allWorkspacesProvider, + ]); + ctx.subscriptions.push(deploymentManager); + // Handle vscode:// URIs. const uriHandler = vscode.window.registerUriHandler({ handleUri: async (uri) => { - const cliManager = serviceContainer.getCliManager(); const params = new URLSearchParams(uri.query); + if (uri.path === "/open") { const owner = params.get("owner"); const workspace = params.get("workspace"); @@ -137,49 +151,17 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { throw new Error("workspace must be specified as a query parameter"); } - // We are not guaranteed that the URL we currently have is for the URL - // this workspace belongs to, or that we even have a URL at all (the - // queries will default to localhost) so ask for it if missing. - // Pre-populate in case we do have the right URL so the user can just - // hit enter and move on. - const url = await maybeAskUrl( - mementoManager, - params.get("url"), - mementoManager.getUrl(), + await setupDeploymentFromUri( + params, + serviceContainer, + deploymentManager, ); - if (url) { - client.setHost(url); - await mementoManager.setUrl(url); - } else { - throw new Error( - "url must be provided or specified as a query parameter", - ); - } - // If the token is missing we will get a 401 later and the user will be - // prompted to sign in again, so we do not need to ensure it is set now. - // For non-token auth, we write a blank token since the `vscodessh` - // command currently always requires a token file. However, if there is - // a query parameter for non-token auth go ahead and use it anyway; all - // that really matters is the file is created. - const token = needToken(vscode.workspace.getConfiguration()) - ? params.get("token") - : (params.get("token") ?? ""); - - if (token) { - client.setSessionToken(token); - await secretsManager.setSessionToken(token); - } - - // Store on disk to be used by the cli. - await cliManager.configure(toSafeHost(url), url, token); - - vscode.commands.executeCommand( - "coder.open", + await commands.open( owner, workspace, - agent, - folder, + agent ?? undefined, + folder ?? undefined, openRecent, ); } else if (uri.path === "/openDevContainer") { @@ -203,6 +185,12 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { ); } + if (!workspaceAgent) { + throw new Error( + "workspace agent must be specified as a query parameter", + ); + } + if (!devContainerName) { throw new Error( "dev container name must be specified as a query parameter", @@ -221,47 +209,20 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { ); } - // We are not guaranteed that the URL we currently have is for the URL - // this workspace belongs to, or that we even have a URL at all (the - // queries will default to localhost) so ask for it if missing. - // Pre-populate in case we do have the right URL so the user can just - // hit enter and move on. - const url = await maybeAskUrl( - mementoManager, - params.get("url"), - mementoManager.getUrl(), + await setupDeploymentFromUri( + params, + serviceContainer, + deploymentManager, ); - if (url) { - client.setHost(url); - await mementoManager.setUrl(url); - } else { - throw new Error( - "url must be provided or specified as a query parameter", - ); - } - // If the token is missing we will get a 401 later and the user will be - // prompted to sign in again, so we do not need to ensure it is set now. - // For non-token auth, we write a blank token since the `vscodessh` - // command currently always requires a token file. However, if there is - // a query parameter for non-token auth go ahead and use it anyway; all - // that really matters is the file is created. - const token = needToken(vscode.workspace.getConfiguration()) - ? params.get("token") - : (params.get("token") ?? ""); - - // Store on disk to be used by the cli. - await cliManager.configure(toSafeHost(url), url, token); - - vscode.commands.executeCommand( - "coder.openDevContainer", + await commands.openDevContainer( workspaceOwner, workspaceName, workspaceAgent, devContainerName, devContainerFolder, - localWorkspaceFolder, - localConfigFile, + localWorkspaceFolder ?? "", + localConfigFile ?? "", ); } else { throw new Error(`Unknown path ${uri.path}`); @@ -272,7 +233,7 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { // Register globally available commands. Many of these have visibility // controlled by contexts, see `when` in the package.json. - const commands = new Commands(serviceContainer, client); + const commands = new Commands(serviceContainer, client, deploymentManager); ctx.subscriptions.push( vscode.commands.registerCommand( "coder.login", @@ -325,30 +286,12 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { vscode.commands.registerCommand("coder.searchAllWorkspaces", async () => showTreeViewSearch(ALL_WORKSPACES_TREE_ID), ), + vscode.commands.registerCommand("coder.debug.listDeployments", () => + listStoredDeployments(secretsManager), + ), ); - const remote = new Remote(serviceContainer, commands, ctx.extensionMode); - - ctx.subscriptions.push( - secretsManager.onDidChangeLoginState(async (state) => { - switch (state) { - case AuthAction.LOGIN: { - const token = await secretsManager.getSessionToken(); - const url = mementoManager.getUrl(); - // Should login the user directly if the URL+Token are valid - await commands.login({ url, token }); - // Resolve any pending login detection promises - remote.resolveLoginDetected(); - break; - } - case AuthAction.LOGOUT: - await commands.forceLogout(); - break; - case AuthAction.INVALID: - break; - } - }), - ); + const remote = new Remote(serviceContainer, commands, ctx); // Since the "onResolveRemoteAuthority:ssh-remote" activation event exists // in package.json we're able to perform actions before the authority is @@ -368,10 +311,13 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { ); if (details) { ctx.subscriptions.push(details); - // Authenticate the plugin client which is used in the sidebar to display - // workspaces belonging to this deployment. - client.setHost(details.url); - client.setSessionToken(details.token); + + // Will automatically fetch the user and upgrade the deployment + await deploymentManager.setDeploymentAndValidate({ + safeHostname: details.safeHostname, + url: details.url, + token: details.token, + }); } } catch (ex) { if (ex instanceof CertificateError) { @@ -411,31 +357,24 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { } } - // See if the plugin client is authenticated. - const baseUrl = client.getAxiosInstance().defaults.baseURL; - if (baseUrl) { - output.info(`Logged in to ${baseUrl}; checking credentials`); - client - .getAuthenticatedUser() - .then((user) => { - if (user && user.roles) { + // Initialize deployment manager with stored deployment (if any). + // Skip if already set by remote.setup above. + if (deploymentManager.getCurrentDeployment()) { + contextManager.set("coder.loaded", true); + } else if (deployment) { + output.info(`Initializing deployment: ${deployment.url}`); + const auth = await secretsManager.getSessionAuth(deployment.safeHostname); + deploymentManager + .setDeploymentAndValidate({ ...deployment, token: auth?.token }) + .then(() => { + if (deploymentManager.isAuthenticated()) { output.info("Credentials are valid"); - contextManager.set("coder.authenticated", true); - if (user.roles.find((role) => role.name === "owner")) { - contextManager.set("coder.isOwner", true); - } - - // Fetch and monitor workspaces, now that we know the client is good. - myWorkspacesProvider.fetchAndRefresh(); - allWorkspacesProvider.fetchAndRefresh(); } else { - output.warn("No error, but got unexpected response", user); + output.info("Deployment set but not authenticated"); } }) .catch((error) => { - // This should be a failure to make the request, like the header command - // errored. - output.warn("Failed to check user authentication", error); + output.warn("Failed to initialize deployment", error); vscode.window.showErrorMessage( `Failed to check user authentication: ${error.message}`, ); @@ -460,7 +399,113 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { } } +/** + * Migrates old flat storage (sessionToken) to new label-based map storage. + * This is a one-time operation that runs on extension activation. + */ +async function migrateAuthStorage( + serviceContainer: ServiceContainer, +): Promise { + const secretsManager = serviceContainer.getSecretsManager(); + const output = serviceContainer.getLogger(); + + try { + const migratedHostname = await secretsManager.migrateFromLegacyStorage(); + + if (migratedHostname) { + output.info( + `Successfully migrated auth storage (hostname: ${migratedHostname})`, + ); + } + } catch (error) { + output.error( + `Auth storage migration failed: ${error}. You may need to log in again.`, + ); + } +} + async function showTreeViewSearch(id: string): Promise { await vscode.commands.executeCommand(`${id}.focus`); await vscode.commands.executeCommand("list.find"); } + +/** + * Sets up deployment from URI parameters. Handles URL prompting, client setup, + * and token storage. Throws if user cancels URL input. + * + * If authentication succeeds, uses changeDeployment with the user. + * If authentication fails, uses setDeploymentWithoutAuth to let remote.setup handle 401. + */ +async function setupDeploymentFromUri( + params: URLSearchParams, + serviceContainer: ServiceContainer, + deploymentManager: DeploymentManager, +): Promise { + const secretsManager = serviceContainer.getSecretsManager(); + const mementoManager = serviceContainer.getMementoManager(); + const currentDeployment = await secretsManager.getCurrentDeployment(); + + // We are not guaranteed that the URL we currently have is for the URL + // this workspace belongs to, or that we even have a URL at all (the + // queries will default to localhost) so ask for it if missing. + // Pre-populate in case we do have the right URL so the user can just + // hit enter and move on. + const url = await maybeAskUrl( + mementoManager, + params.get("url"), + currentDeployment?.url, + ); + if (!url) { + throw new Error("url must be provided or specified as a query parameter"); + } + + const safeHost = toSafeHost(url); + + // If the token is missing we will get a 401 later and the user will be + // prompted to sign in again, so we do not need to ensure it is set now. + // For non-token auth, we write a blank token since the `vscodessh` + // command currently always requires a token file. However, if there is + // a query parameter for non-token auth go ahead and use it anyway; + let token: string | undefined = params.get("token") ?? undefined; + if (token === undefined) { + if (needToken(vscode.workspace.getConfiguration())) { + token = (await secretsManager.getSessionAuth(safeHost))?.token; + } else { + token = ""; + } + } else { + await secretsManager.setSessionAuth(safeHost, { url, token }); + } + + // Will automatically fetch the user and upgrade the deployment + await deploymentManager.setDeploymentAndValidate({ + safeHostname: safeHost, + url, + token, + }); +} + +async function listStoredDeployments( + secretsManager: SecretsManager, +): Promise { + const hostnames = secretsManager.getKnownSafeHostnames(); + if (hostnames.length === 0) { + vscode.window.showInformationMessage("No deployments stored."); + return; + } + + const selected = await vscode.window.showQuickPick( + hostnames.map((hostname) => ({ + label: hostname, + description: "Click to forget", + })), + { placeHolder: "Select a deployment to forget" }, + ); + + if (selected) { + await secretsManager.clearAllAuthData(selected.label); + vscode.window.showInformationMessage( + `Cleared auth data for ${selected.label}`, + ); + } +} diff --git a/src/login/loginCoordinator.ts b/src/login/loginCoordinator.ts new file mode 100644 index 00000000..8b5d7b07 --- /dev/null +++ b/src/login/loginCoordinator.ts @@ -0,0 +1,300 @@ +import { isAxiosError } from "axios"; +import { getErrorMessage } from "coder/site/src/api/errors"; +import * as vscode from "vscode"; + +import { CoderApi } from "../api/coderApi"; +import { needToken } from "../api/utils"; +import { CertificateError } from "../error"; +import { maybeAskUrl } from "../promptUtils"; + +import type { User } from "coder/site/src/api/typesGenerated"; + +import type { MementoManager } from "../core/mementoManager"; +import type { SecretsManager } from "../core/secretsManager"; +import type { Deployment } from "../deployment/types"; +import type { Logger } from "../logging/logger"; + +type LoginResult = + | { success: false } + | { success: true; user?: User; token: string }; + +interface LoginOptions { + safeHostname: string; + url: string | undefined; + autoLogin?: boolean; +} + +/** + * Coordinates login prompts across windows and prevents duplicate dialogs. + */ +export class LoginCoordinator { + private readonly inProgressLogins = new Map>(); + + constructor( + private readonly secretsManager: SecretsManager, + private readonly mementoManager: MementoManager, + private readonly vscodeProposed: typeof vscode, + private readonly logger: Logger, + ) {} + + /** + * Direct login - for user-initiated login via commands. + * Stores session auth and URL history on success. + */ + public async ensureLoggedIn( + options: LoginOptions & { url: string }, + ): Promise { + const { safeHostname, url } = options; + return this.executeWithGuard(safeHostname, async () => { + const result = await this.attemptLogin( + { safeHostname, url }, + options.autoLogin ?? false, + ); + + await this.persistSessionAuth(result, safeHostname, url); + + return result; + }); + } + + /** + * Shows dialog then login - for system-initiated auth (remote). + */ + public async ensureLoggedInWithDialog( + options: LoginOptions & { message?: string; detailPrefix?: string }, + ): Promise { + const { safeHostname, url, detailPrefix, message } = options; + return this.executeWithGuard(safeHostname, async () => { + // Show dialog promise + const dialogPromise = this.vscodeProposed.window + .showErrorMessage( + message || "Authentication Required", + { + modal: true, + useCustom: true, + detail: + (detailPrefix || `Authentication needed for ${safeHostname}.`) + + "\n\nIf you've already logged in, you may close this dialog.", + }, + "Login", + ) + .then(async (action) => { + if (action === "Login") { + // Proceed with the login flow, handling logging in from another window + const storedAuth = + await this.secretsManager.getSessionAuth(safeHostname); + const newUrl = await maybeAskUrl( + this.mementoManager, + url, + storedAuth?.url, + ); + if (!newUrl) { + throw new Error("URL must be provided"); + } + + const result = await this.attemptLogin( + { url: newUrl, safeHostname }, + false, + ); + + await this.persistSessionAuth(result, safeHostname, newUrl); + + return result; + } else { + // User cancelled + return { success: false } as const; + } + }); + + // Race between user clicking login and cross-window detection + const { + promise: crossWindowPromise, + dispose: disposeCrossWindowListener, + } = this.waitForCrossWindowLogin(safeHostname); + try { + return await Promise.race([dialogPromise, crossWindowPromise]); + } finally { + disposeCrossWindowListener(); + } + }); + } + + private async persistSessionAuth( + result: LoginResult, + safeHostname: string, + url: string, + ): Promise { + // Empty token is valid for mTLS + if (result.success) { + await this.secretsManager.setSessionAuth(safeHostname, { + url, + token: result.token, + }); + await this.mementoManager.addToUrlHistory(url); + } + } + + /** + * Same-window guard wrapper. + */ + private async executeWithGuard( + safeHostname: string, + executeFn: () => Promise, + ): Promise { + const existingLogin = this.inProgressLogins.get(safeHostname); + if (existingLogin) { + return existingLogin; + } + + const loginPromise = executeFn(); + this.inProgressLogins.set(safeHostname, loginPromise); + + try { + return await loginPromise; + } finally { + this.inProgressLogins.delete(safeHostname); + } + } + + /** + * Waits for login detected from another window. + * Returns a promise and a dispose function to clean up the listener. + */ + private waitForCrossWindowLogin(safeHostname: string): { + promise: Promise; + dispose: () => void; + } { + let disposable: vscode.Disposable | undefined; + const promise = new Promise((resolve) => { + disposable = this.secretsManager.onDidChangeSessionAuth( + safeHostname, + (auth) => { + if (auth?.token) { + disposable?.dispose(); + resolve({ success: true, token: auth.token }); + } + }, + ); + }); + return { + promise, + dispose: () => disposable?.dispose(), + }; + } + + /** + * Attempt to authenticate using token, or mTLS. If necessary, prompts + * for authentication method and credentials. Returns the token and user upon + * successful authentication. Null means the user aborted or authentication + * failed (in which case an error notification will have been displayed). + */ + private async attemptLogin( + deployment: Deployment, + isAutoLogin: boolean, + ): Promise { + const needsToken = needToken(vscode.workspace.getConfiguration()); + const client = CoderApi.create(deployment.url, "", this.logger); + + let storedToken: string | undefined; + if (needsToken) { + const auth = await this.secretsManager.getSessionAuth( + deployment.safeHostname, + ); + storedToken = auth?.token; + if (storedToken) { + client.setSessionToken(storedToken); + } + } + + // Attempt authentication with current credentials (token or mTLS) + try { + if (!needsToken || storedToken) { + const user = await client.getAuthenticatedUser(); + // Return the token that was used (empty string for mTLS since + // the `vscodessh` command currently always requires a token file) + return { success: true, token: storedToken ?? "", user }; + } + } catch (err) { + const is401 = isAxiosError(err) && err.response?.status === 401; + if (needsToken && is401) { + // For token auth with 401: silently continue to prompt for new credentials + } else { + // For mTLS or non-401 errors: show error and abort + const message = getErrorMessage(err, "no response from the server"); + if (isAutoLogin) { + this.logger.warn("Failed to log in to Coder server:", message); + } else { + this.vscodeProposed.window.showErrorMessage( + "Failed to log in to Coder server", + { + detail: message, + modal: true, + useCustom: true, + }, + ); + } + return { success: false }; + } + } + + const result = await this.loginWithToken(client); + return result; + } + + /** + * Session token authentication flow. + */ + private async loginWithToken(client: CoderApi): Promise { + const url = client.getAxiosInstance().defaults.baseURL; + if (!url) { + throw new Error("No base URL set on REST client"); + } + // This prompt is for convenience; do not error if they close it since + // they may already have a token or already have the page opened. + await vscode.env.openExternal(vscode.Uri.parse(`${url}/cli-auth`)); + + // For token auth, start with the existing token in the prompt or the last + // used token. Once submitted, if there is a failure we will keep asking + // the user for a new token until they quit. + let user: User | undefined; + const validatedToken = await vscode.window.showInputBox({ + title: "Coder API Key", + password: true, + placeHolder: "Paste your API key.", + ignoreFocusOut: true, + validateInput: async (value) => { + if (!value) { + return null; + } + client.setSessionToken(value); + try { + user = await client.getAuthenticatedUser(); + } catch (err) { + // For certificate errors show both a notification and add to the + // text under the input box, since users sometimes miss the + // notification. + if (err instanceof CertificateError) { + err.showNotification(); + return { + message: err.x509Err || err.message, + severity: vscode.InputBoxValidationSeverity.Error, + }; + } + // This could be something like the header command erroring or an + // invalid session token. + const message = getErrorMessage(err, "no response from the server"); + return { + message: "Failed to authenticate: " + message, + severity: vscode.InputBoxValidationSeverity.Error, + }; + } + }, + }); + + if (user) { + return { success: true, user, token: validatedToken ?? "" }; + } + + return { success: false }; + } +} diff --git a/src/promptUtils.ts b/src/promptUtils.ts index 4d058f12..3fb31475 100644 --- a/src/promptUtils.ts +++ b/src/promptUtils.ts @@ -61,15 +61,16 @@ export async function maybeAskAgent( */ async function askURL( mementoManager: MementoManager, - selection?: string, + prePopulateUrl: string | undefined, ): Promise { const defaultURL = vscode.workspace .getConfiguration() .get("coder.defaultUrl") ?.trim(); const quickPick = vscode.window.createQuickPick(); + quickPick.ignoreFocusOut = true; quickPick.value = - selection || defaultURL || process.env.CODER_URL?.trim() || ""; + prePopulateUrl || defaultURL || process.env.CODER_URL?.trim() || ""; quickPick.placeholder = "https://example.coder.com"; quickPick.title = "Enter the URL of your Coder deployment."; @@ -111,9 +112,9 @@ async function askURL( export async function maybeAskUrl( mementoManager: MementoManager, providedUrl: string | undefined | null, - lastUsedUrl?: string, + prePopulateUrl?: string, ): Promise { - let url = providedUrl || (await askURL(mementoManager, lastUsedUrl)); + let url = providedUrl || (await askURL(mementoManager, prePopulateUrl)); if (!url) { // User aborted. return undefined; diff --git a/src/remote/remote.ts b/src/remote/remote.ts index 27a0477e..4c27106b 100644 --- a/src/remote/remote.ts +++ b/src/remote/remote.ts @@ -27,9 +27,11 @@ import * as cliUtils from "../core/cliUtils"; import { type ServiceContainer } from "../core/container"; import { type ContextManager } from "../core/contextManager"; import { type PathResolver } from "../core/pathResolver"; +import { type SecretsManager } from "../core/secretsManager"; import { featureSetForVersion, type FeatureSet } from "../featureSet"; import { Inbox } from "../inbox"; import { type Logger } from "../logging/logger"; +import { type LoginCoordinator } from "../login/loginCoordinator"; import { AuthorityPrefix, escapeCommandArg, @@ -44,6 +46,7 @@ import { computeSSHProperties, sshSupportsSetEnv } from "./sshSupport"; import { WorkspaceStateMachine } from "./workspaceStateMachine"; export interface RemoteDetails extends vscode.Disposable { + safeHostname: string; url: string; token: string; } @@ -55,48 +58,21 @@ export class Remote { private readonly pathResolver: PathResolver; private readonly cliManager: CliManager; private readonly contextManager: ContextManager; - - // Used to race between the login dialog and logging in from a different window - private loginDetectedResolver: (() => void) | undefined; - private loginDetectedRejector: ((reason?: Error) => void) | undefined; - private loginDetectedPromise: Promise = Promise.resolve(); + private readonly secretsManager: SecretsManager; + private readonly loginCoordinator: LoginCoordinator; public constructor( serviceContainer: ServiceContainer, private readonly commands: Commands, - private readonly mode: vscode.ExtensionMode, + private readonly extensionContext: vscode.ExtensionContext, ) { this.vscodeProposed = serviceContainer.getVsCodeProposed(); this.logger = serviceContainer.getLogger(); this.pathResolver = serviceContainer.getPathResolver(); this.cliManager = serviceContainer.getCliManager(); this.contextManager = serviceContainer.getContextManager(); - } - - /** - * Creates a new promise that will be resolved when login is detected in another window. - */ - private createLoginDetectionPromise(): void { - if (this.loginDetectedRejector) { - this.loginDetectedRejector( - new Error("Login detection cancelled - new login attempt started"), - ); - } - this.loginDetectedPromise = new Promise((resolve, reject) => { - this.loginDetectedResolver = resolve; - this.loginDetectedRejector = reject; - }); - } - - /** - * Resolves the current login detection promise if one exists. - */ - public resolveLoginDetected(): void { - if (this.loginDetectedResolver) { - this.loginDetectedResolver(); - this.loginDetectedResolver = undefined; - this.loginDetectedRejector = undefined; - } + this.secretsManager = serviceContainer.getSecretsManager(); + this.loginCoordinator = serviceContainer.getLoginCoordinator(); } /** @@ -117,161 +93,187 @@ export class Remote { const workspaceName = `${parts.username}/${parts.workspace}`; - // Migrate "session_token" file to "session", if needed. - await this.migrateSessionToken(parts.label); + // Migrate existing legacy file-based auth to secrets storage. + await this.migrateToSecretsStorage(parts.safeHostname); // Get the URL and token belonging to this host. - const { url: baseUrlRaw, token } = await this.cliManager.readConfig( - parts.label, - ); + const auth = await this.secretsManager.getSessionAuth(parts.safeHostname); + const baseUrlRaw = auth?.url ?? ""; + const token = auth?.token; + // Empty token is valid for mTLS + if (baseUrlRaw && token !== undefined) { + await this.cliManager.configure(parts.safeHostname, baseUrlRaw, token); + } - const showLoginDialog = async (message: string) => { - this.createLoginDetectionPromise(); - const dialogPromise = this.vscodeProposed.window.showInformationMessage( - message, - { - useCustom: true, - modal: true, - detail: `You must log in to access ${workspaceName}. If you've already logged in, you may close this dialog.`, - }, - "Log In", - ); + const disposables: vscode.Disposable[] = []; - // Race between dialog and login detection - const result = await Promise.race([ - this.loginDetectedPromise.then(() => ({ type: "login" as const })), - dialogPromise.then((userChoice) => ({ - type: "dialog" as const, - userChoice, - })), - ]); - - if (result.type === "login") { - return this.setup(remoteAuthority, firstConnect, remoteSshExtensionId); - } else if (!result.userChoice) { - // User declined to log in. - await this.closeRemote(); - return; - } else { - // Log in then try again. - await this.commands.login({ url: baseUrlRaw, label: parts.label }); - return this.setup(remoteAuthority, firstConnect, remoteSshExtensionId); + try { + const ensureLoggedInAndRetry = async ( + message: string, + url: string | undefined, + ) => { + const result = await this.loginCoordinator.ensureLoggedInWithDialog({ + safeHostname: parts.safeHostname, + url, + message, + detailPrefix: `You must log in to access ${workspaceName}.`, + }); + + // Dispose before retrying since setup will create new disposables + disposables.forEach((d) => d.dispose()); + if (result.success) { + // Login successful, retry setup + return this.setup( + remoteAuthority, + firstConnect, + remoteSshExtensionId, + ); + } else { + // User cancelled or login failed + await this.closeRemote(); + } + }; + + // It could be that the cli config was deleted. If so, ask for the url. + if ( + !baseUrlRaw || + (!token && needToken(vscode.workspace.getConfiguration())) + ) { + return ensureLoggedInAndRetry("You are not logged in...", baseUrlRaw); } - }; - // It could be that the cli config was deleted. If so, ask for the url. - if ( - !baseUrlRaw || - (!token && needToken(vscode.workspace.getConfiguration())) - ) { - return showLoginDialog("You are not logged in..."); - } + this.logger.info("Using deployment URL", baseUrlRaw); + this.logger.info("Using hostname", parts.safeHostname || "n/a"); - this.logger.info("Using deployment URL", baseUrlRaw); - this.logger.info("Using deployment label", parts.label || "n/a"); - - // We could use the plugin client, but it is possible for the user to log - // out or log into a different deployment while still connected, which would - // break this connection. We could force close the remote session or - // disallow logging out/in altogether, but for now just use a separate - // client to remain unaffected by whatever the plugin is doing. - const workspaceClient = CoderApi.create(baseUrlRaw, token, this.logger); - // Store for use in commands. - this.commands.workspaceRestClient = workspaceClient; - - let binaryPath: string | undefined; - if (this.mode === vscode.ExtensionMode.Production) { - binaryPath = await this.cliManager.fetchBinary( - workspaceClient, - parts.label, + // We could use the plugin client, but it is possible for the user to log + // out or log into a different deployment while still connected, which would + // break this connection. We could force close the remote session or + // disallow logging out/in altogether, but for now just use a separate + // client to remain unaffected by whatever the plugin is doing. + const workspaceClient = CoderApi.create(baseUrlRaw, token, this.logger); + disposables.push(workspaceClient); + // Store for use in commands. + this.commands.remoteWorkspaceClient = workspaceClient; + + // Listen for token changes for this deployment + disposables.push( + this.secretsManager.onDidChangeSessionAuth( + parts.safeHostname, + async (auth) => { + workspaceClient.setCredentials(auth?.url, auth?.token); + if (auth?.url) { + await this.cliManager.configure( + parts.safeHostname, + auth.url, + auth.token, + ); + this.logger.info( + "Updated CLI config with new token for remote deployment", + ); + } + }, + ), ); - } else { - try { - // In development, try to use `/tmp/coder` as the binary path. - // This is useful for debugging with a custom bin! - binaryPath = path.join(os.tmpdir(), "coder"); - await fs.stat(binaryPath); - } catch { + + let binaryPath: string | undefined; + if ( + this.extensionContext.extensionMode === vscode.ExtensionMode.Production + ) { binaryPath = await this.cliManager.fetchBinary( workspaceClient, - parts.label, + parts.safeHostname, ); + } else { + try { + // In development, try to use `/tmp/coder` as the binary path. + // This is useful for debugging with a custom bin! + binaryPath = path.join(os.tmpdir(), "coder"); + await fs.stat(binaryPath); + } catch { + binaryPath = await this.cliManager.fetchBinary( + workspaceClient, + parts.safeHostname, + ); + } } - } - // First thing is to check the version. - const buildInfo = await workspaceClient.getBuildInfo(); + // First thing is to check the version. + const buildInfo = await workspaceClient.getBuildInfo(); - let version: semver.SemVer | null = null; - try { - version = semver.parse(await cliUtils.version(binaryPath)); - } catch { - version = semver.parse(buildInfo.version); - } - - const featureSet = featureSetForVersion(version); + let version: semver.SemVer | null = null; + try { + version = semver.parse(await cliUtils.version(binaryPath)); + } catch { + version = semver.parse(buildInfo.version); + } - // Server versions before v0.14.1 don't support the vscodessh command! - if (!featureSet.vscodessh) { - await this.vscodeProposed.window.showErrorMessage( - "Incompatible Server", - { - detail: - "Your Coder server is too old to support the Coder extension! Please upgrade to v0.14.1 or newer.", - modal: true, - useCustom: true, - }, - "Close Remote", - ); - await this.closeRemote(); - return; - } + const featureSet = featureSetForVersion(version); - // Next is to find the workspace from the URI scheme provided. - let workspace: Workspace; - try { - this.logger.info(`Looking for workspace ${workspaceName}...`); - workspace = await workspaceClient.getWorkspaceByOwnerAndName( - parts.username, - parts.workspace, - ); - this.logger.info( - `Found workspace ${workspaceName} with status`, - workspace.latest_build.status, - ); - this.commands.workspace = workspace; - } catch (error) { - if (!isAxiosError(error)) { - throw error; + // Server versions before v0.14.1 don't support the vscodessh command! + if (!featureSet.vscodessh) { + await this.vscodeProposed.window.showErrorMessage( + "Incompatible Server", + { + detail: + "Your Coder server is too old to support the Coder extension! Please upgrade to v0.14.1 or newer.", + modal: true, + useCustom: true, + }, + "Close Remote", + ); + disposables.forEach((d) => d.dispose()); + await this.closeRemote(); + return; } - switch (error.response?.status) { - case 404: { - const result = - await this.vscodeProposed.window.showInformationMessage( - `That workspace doesn't exist!`, - { - modal: true, - detail: `${workspaceName} cannot be found on ${baseUrlRaw}. Maybe it was deleted...`, - useCustom: true, - }, - "Open Workspace", + + // Next is to find the workspace from the URI scheme provided. + let workspace: Workspace; + try { + this.logger.info(`Looking for workspace ${workspaceName}...`); + workspace = await workspaceClient.getWorkspaceByOwnerAndName( + parts.username, + parts.workspace, + ); + this.logger.info( + `Found workspace ${workspaceName} with status`, + workspace.latest_build.status, + ); + this.commands.workspace = workspace; + } catch (error) { + if (!isAxiosError(error)) { + throw error; + } + switch (error.response?.status) { + case 404: { + const result = + await this.vscodeProposed.window.showInformationMessage( + `That workspace doesn't exist!`, + { + modal: true, + detail: `${workspaceName} cannot be found on ${baseUrlRaw}. Maybe it was deleted...`, + useCustom: true, + }, + "Open Workspace", + ); + disposables.forEach((d) => d.dispose()); + if (!result) { + await this.closeRemote(); + } + await vscode.commands.executeCommand("coder.open"); + return; + } + case 401: { + disposables.forEach((d) => d.dispose()); + return ensureLoggedInAndRetry( + "Your session expired...", + baseUrlRaw, ); - if (!result) { - await this.closeRemote(); } - await vscode.commands.executeCommand("coder.open"); - return; - } - case 401: { - return showLoginDialog("Your session expired..."); + default: + throw error; } - default: - throw error; } - } - const disposables: vscode.Disposable[] = []; - try { // Register before connection so the label still displays! let labelFormatterDisposable = this.registerLabelFormatter( remoteAuthority, @@ -410,10 +412,10 @@ export class Remote { // the user for the platform. let mungedPlatforms = false; if ( - !remotePlatforms[parts.host] || - remotePlatforms[parts.host] !== agent.operating_system + !remotePlatforms[parts.sshHost] || + remotePlatforms[parts.sshHost] !== agent.operating_system ) { - remotePlatforms[parts.host] = agent.operating_system; + remotePlatforms[parts.sshHost] = agent.operating_system; settingsContent = jsonc.applyEdits( settingsContent, jsonc.modify( @@ -473,8 +475,8 @@ export class Remote { this.logger.info("Updating SSH config..."); await this.updateSSHConfig( workspaceClient, - parts.label, - parts.host, + parts.safeHostname, + parts.sshHost, binaryPath, logDir, featureSet, @@ -486,7 +488,7 @@ export class Remote { // Monitor SSH process and display network status const sshMonitor = SshProcessMonitor.start({ - sshHost: parts.host, + sshHost: parts.sshHost, networkInfoPath: this.pathResolver.getNetworkInfoPath(), proxyLogDir: logDir || undefined, logger: this.logger, @@ -539,8 +541,9 @@ export class Remote { // deployment in the sidebar. We use our own client in here for reasons // explained above. return { + safeHostname: parts.safeHostname, url: baseUrlRaw, - token, + token: token ?? "", dispose: () => { disposables.forEach((d) => d.dispose()); }, @@ -548,18 +551,51 @@ export class Remote { } /** - * Migrate the session token file from "session_token" to "session", if needed. + * Migrate legacy file-based auth to secrets storage. */ - private async migrateSessionToken(label: string) { - const oldTokenPath = this.pathResolver.getLegacySessionTokenPath(label); - const newTokenPath = this.pathResolver.getSessionTokenPath(label); + private async migrateToSecretsStorage(safeHostname: string) { + await this.migrateSessionTokenFile(safeHostname); + await this.migrateSessionAuthFromFiles(safeHostname); + } + + /** + * Migrate the session token file from "session_token" to "session". + */ + private async migrateSessionTokenFile(safeHostname: string) { + const oldTokenPath = + this.pathResolver.getLegacySessionTokenPath(safeHostname); + const newTokenPath = this.pathResolver.getSessionTokenPath(safeHostname); try { await fs.rename(oldTokenPath, newTokenPath); } catch (error) { - if ((error as NodeJS.ErrnoException)?.code === "ENOENT") { - return; + if ((error as NodeJS.ErrnoException)?.code !== "ENOENT") { + throw error; } - throw error; + } + } + + /** + * Migrate URL and session token from files to the mutli-deployment secrets storage. + */ + private async migrateSessionAuthFromFiles(safeHostname: string) { + const existingAuth = await this.secretsManager.getSessionAuth(safeHostname); + if (existingAuth) { + return; + } + + const urlPath = this.pathResolver.getUrlPath(safeHostname); + const tokenPath = this.pathResolver.getSessionTokenPath(safeHostname); + const [url, token] = await Promise.allSettled([ + fs.readFile(urlPath, "utf8"), + fs.readFile(tokenPath, "utf8"), + ]); + + if (url.status === "fulfilled" && token.status === "fulfilled") { + this.logger.info("Migrating session auth from files for", safeHostname); + await this.secretsManager.setSessionAuth(safeHostname, { + url: url.value.trim(), + token: token.value.trim(), + }); } } @@ -662,7 +698,7 @@ export class Remote { // all Coder entries. private async updateSSHConfig( restClient: Api, - label: string, + safeHostname: string, hostName: string, binaryPath: string, logDir: string, @@ -737,13 +773,13 @@ export class Remote { const sshConfig = new SSHConfig(sshConfigFile); await sshConfig.load(); - const hostPrefix = label - ? `${AuthorityPrefix}.${label}--` + const hostPrefix = safeHostname + ? `${AuthorityPrefix}.${safeHostname}--` : `${AuthorityPrefix}--`; const proxyCommand = await this.buildProxyCommand( binaryPath, - label, + safeHostname, hostPrefix, logDir, featureSet.wildcardSSH, @@ -763,7 +799,7 @@ export class Remote { sshValues.SetEnv = " CODER_SSH_SESSION_TYPE=vscode"; } - await sshConfig.update(label, sshValues, sshConfigOverrides); + await sshConfig.update(safeHostname, sshValues, sshConfigOverrides); // A user can provide a "Host *" entry in their SSH config to add options // to all hosts. We need to ensure that the options we set are not @@ -795,6 +831,7 @@ export class Remote { await this.reloadWindow(); } await this.closeRemote(); + throw new Error("SSH config mismatch, closing remote"); } return sshConfig.getRaw(); diff --git a/src/remote/sshConfig.ts b/src/remote/sshConfig.ts index f5fea264..668ce092 100644 --- a/src/remote/sshConfig.ts +++ b/src/remote/sshConfig.ts @@ -85,18 +85,18 @@ export function mergeSSHConfigValues( } export class SSHConfig { - private filePath: string; - private fileSystem: FileSystem; + private readonly filePath: string; + private readonly fileSystem: FileSystem; private raw: string | undefined; - private startBlockComment(label: string): string { - return label - ? `# --- START CODER VSCODE ${label} ---` + private startBlockComment(safeHostname: string): string { + return safeHostname + ? `# --- START CODER VSCODE ${safeHostname} ---` : `# --- START CODER VSCODE ---`; } - private endBlockComment(label: string): string { - return label - ? `# --- END CODER VSCODE ${label} ---` + private endBlockComment(safeHostname: string): string { + return safeHostname + ? `# --- END CODER VSCODE ${safeHostname} ---` : `# --- END CODER VSCODE ---`; } @@ -115,15 +115,15 @@ export class SSHConfig { } /** - * Update the block for the deployment with the provided label. + * Update the block for the deployment with the provided hostname. */ async update( - label: string, + safeHostname: string, values: SSHValues, overrides?: Record, ) { - const block = this.getBlock(label); - const newBlock = this.buildBlock(label, values, overrides); + const block = this.getBlock(safeHostname); + const newBlock = this.buildBlock(safeHostname, values, overrides); if (block) { this.replaceBlock(block, newBlock); } else { @@ -133,24 +133,24 @@ export class SSHConfig { } /** - * Get the block for the deployment with the provided label. + * Get the block for the deployment with the provided hostname. */ - private getBlock(label: string): Block | undefined { + private getBlock(safeHostname: string): Block | undefined { const raw = this.getRaw(); - const startBlock = this.startBlockComment(label); - const endBlock = this.endBlockComment(label); + const startBlock = this.startBlockComment(safeHostname); + const endBlock = this.endBlockComment(safeHostname); const startBlockCount = countSubstring(startBlock, raw); const endBlockCount = countSubstring(endBlock, raw); if (startBlockCount !== endBlockCount) { throw new SSHConfigBadFormat( - `Malformed config: ${this.filePath} has an unterminated START CODER VSCODE ${label ? label + " " : ""}block. Each START block must have an END block.`, + `Malformed config: ${this.filePath} has an unterminated START CODER VSCODE ${safeHostname ? safeHostname + " " : ""}block. Each START block must have an END block.`, ); } if (startBlockCount > 1 || endBlockCount > 1) { throw new SSHConfigBadFormat( - `Malformed config: ${this.filePath} has ${startBlockCount} START CODER VSCODE ${label ? label + " " : ""}sections. Please remove all but one.`, + `Malformed config: ${this.filePath} has ${startBlockCount} START CODER VSCODE ${safeHostname ? safeHostname + " " : ""}sections. Please remove all but one.`, ); } @@ -185,22 +185,22 @@ export class SSHConfig { * the keys is determinstic based on the input. Expected values are always in * a consistent order followed by any additional overrides in sorted order. * - * @param label - The label for the deployment (like the encoded URL). - * @param values - The expected SSH values for using ssh with Coder. - * @param overrides - Overrides typically come from the deployment api and are - * used to override the default values. The overrides are - * given as key:value pairs where the key is the ssh config - * file key. If the key matches an expected value, the - * expected value is overridden. If it does not match an - * expected value, it is appended to the end of the block. + * @param safeHostname - The hostname for the deployment. + * @param values - The expected SSH values for using ssh with Coder. + * @param overrides - Overrides typically come from the deployment api and are + * used to override the default values. The overrides are + * given as key:value pairs where the key is the ssh config + * file key. If the key matches an expected value, the + * expected value is overridden. If it does not match an + * expected value, it is appended to the end of the block. */ private buildBlock( - label: string, + safeHostname: string, values: SSHValues, overrides?: Record, ) { const { Host, ...otherValues } = values; - const lines = [this.startBlockComment(label), `Host ${Host}`]; + const lines = [this.startBlockComment(safeHostname), `Host ${Host}`]; // configValues is the merged values of the defaults and the overrides. const configValues = mergeSSHConfigValues(otherValues, overrides || {}); @@ -216,7 +216,7 @@ export class SSHConfig { } }); - lines.push(this.endBlockComment(label)); + lines.push(this.endBlockComment(safeHostname)); return { raw: lines.join("\n"), }; diff --git a/src/remote/workspaceStateMachine.ts b/src/remote/workspaceStateMachine.ts index 340ec960..d797ae5c 100644 --- a/src/remote/workspaceStateMachine.ts +++ b/src/remote/workspaceStateMachine.ts @@ -76,7 +76,7 @@ export class WorkspaceStateMachine implements vscode.Disposable { progress.report({ message: `starting ${workspaceName}...` }); this.logger.info(`Starting ${workspaceName}`); const globalConfigDir = this.pathResolver.getGlobalConfigDir( - this.parts.label, + this.parts.safeHostname, ); await startWorkspaceIfStoppedOrFailed( this.workspaceClient, diff --git a/src/util.ts b/src/util.ts index 776ba1db..35492eea 100644 --- a/src/util.ts +++ b/src/util.ts @@ -3,8 +3,8 @@ import url from "node:url"; export interface AuthorityParts { agent: string | undefined; - host: string; - label: string; + sshHost: string; + safeHostname: string; username: string; workspace: string; } @@ -93,8 +93,8 @@ export function parseRemoteAuthority(authority: string): AuthorityParts | null { return { agent: agent, - host: authorityParts[1], - label: parts[0].replace(/^coder-vscode\.?/, ""), + sshHost: authorityParts[1], + safeHostname: parts[0].replace(/^coder-vscode\.?/, ""), username: parts[1], workspace: workspace, }; diff --git a/src/websocket/codes.ts b/src/websocket/codes.ts index ac8eccf7..f3fd95cd 100644 --- a/src/websocket/codes.ts +++ b/src/websocket/codes.ts @@ -19,7 +19,9 @@ export const WebSocketCloseCode = { /** HTTP status codes used for socket creation and connection logic */ export const HttpStatusCode = { - /** Authentication or permission denied */ + /** Authentication required */ + UNAUTHORIZED: 401, + /** Permission denied */ FORBIDDEN: 403, /** Endpoint not found */ NOT_FOUND: 404, @@ -43,7 +45,9 @@ export const UNRECOVERABLE_WS_CLOSE_CODES = new Set([ * These appear during socket creation and should stop reconnection attempts. */ export const UNRECOVERABLE_HTTP_CODES = new Set([ + HttpStatusCode.UNAUTHORIZED, HttpStatusCode.FORBIDDEN, + HttpStatusCode.NOT_FOUND, HttpStatusCode.GONE, HttpStatusCode.UPGRADE_REQUIRED, ]); diff --git a/src/websocket/reconnectingWebSocket.ts b/src/websocket/reconnectingWebSocket.ts index 2ced9351..9bd96d4a 100644 --- a/src/websocket/reconnectingWebSocket.ts +++ b/src/websocket/reconnectingWebSocket.ts @@ -27,7 +27,6 @@ export class ReconnectingWebSocket { readonly #socketFactory: SocketFactory; readonly #logger: Logger; - readonly #apiRoute: string; readonly #options: Required; readonly #eventHandlers: { [K in WebSocketEventType]: Set>; @@ -39,9 +38,11 @@ export class ReconnectingWebSocket }; #currentSocket: UnidirectionalStream | null = null; + #lastRoute = "unknown"; // Cached route for logging when socket is closed #backoffMs: number; #reconnectTimeoutId: NodeJS.Timeout | null = null; - #isDisposed = false; + isDisconnected = false; // Temporary pause, can be resumed via reconnect() + #isDisposed = false; // Permanent disposal, cannot be resumed #isConnecting = false; #pendingReconnect = false; readonly #onDispose?: () => void; @@ -49,13 +50,11 @@ export class ReconnectingWebSocket private constructor( socketFactory: SocketFactory, logger: Logger, - apiRoute: string, options: ReconnectingWebSocketOptions = {}, onDispose?: () => void, ) { this.#socketFactory = socketFactory; this.#logger = logger; - this.#apiRoute = apiRoute; this.#options = { initialBackoffMs: options.initialBackoffMs ?? 250, maxBackoffMs: options.maxBackoffMs ?? 30000, @@ -68,14 +67,12 @@ export class ReconnectingWebSocket static async create( socketFactory: SocketFactory, logger: Logger, - apiRoute: string, options: ReconnectingWebSocketOptions = {}, onDispose?: () => void, ): Promise> { const instance = new ReconnectingWebSocket( socketFactory, logger, - apiRoute, options, onDispose, ); @@ -87,6 +84,19 @@ export class ReconnectingWebSocket return this.#currentSocket?.url ?? ""; } + /** + * Extract the route (pathname + search) from the current socket URL for logging. + * Falls back to the last known route when the socket is closed. + */ + get #route(): string { + const socketUrl = this.#currentSocket?.url; + if (!socketUrl) { + return this.#lastRoute; + } + const url = new URL(socketUrl); + return url.pathname + url.search; + } + addEventListener( event: TEvent, callback: EventHandler, @@ -101,7 +111,16 @@ export class ReconnectingWebSocket this.#eventHandlers[event].delete(callback); } + /** + * Force an immediate reconnection attempt. + * Resumes the socket if previously disconnected via disconnect(). + */ reconnect(): void { + if (this.isDisconnected) { + this.isDisconnected = false; + this.#backoffMs = this.#options.initialBackoffMs; + } + if (this.#isDisposed) { return; } @@ -121,6 +140,18 @@ export class ReconnectingWebSocket this.connect().catch((error) => this.handleConnectionError(error)); } + /** + * Temporarily disconnect the socket. Can be resumed via reconnect(). + */ + disconnect(code?: number, reason?: string): void { + if (this.#isDisposed || this.isDisconnected) { + return; + } + + this.isDisconnected = true; + this.clearCurrentSocket(code, reason); + } + close(code?: number, reason?: string): void { if (this.#isDisposed) { return; @@ -139,7 +170,7 @@ export class ReconnectingWebSocket } private async connect(): Promise { - if (this.#isDisposed || this.#isConnecting) { + if (this.#isDisposed || this.isDisconnected || this.#isConnecting) { return; } @@ -156,6 +187,7 @@ export class ReconnectingWebSocket const socket = await this.#socketFactory(); this.#currentSocket = socket; + this.#lastRoute = this.#route; socket.addEventListener("open", (event) => { this.#backoffMs = this.#options.initialBackoffMs; @@ -168,10 +200,21 @@ export class ReconnectingWebSocket socket.addEventListener("error", (event) => { this.executeHandlers("error", event); + + // Check for unrecoverable HTTP errors in the error event + // HTTP errors during handshake fire 'error' then 'close' with 1006 + // We need to suspend here to prevent infinite reconnect loops + const errorMessage = event.error?.message ?? event.message ?? ""; + if (this.isUnrecoverableHttpError(errorMessage)) { + this.#logger.error( + `Unrecoverable HTTP error for ${this.#route}: ${errorMessage}`, + ); + this.disconnect(); + } }); socket.addEventListener("close", (event) => { - if (this.#isDisposed) { + if (this.#isDisposed || this.isDisconnected) { return; } @@ -181,7 +224,8 @@ export class ReconnectingWebSocket this.#logger.error( `WebSocket connection closed with unrecoverable error code ${event.code}`, ); - this.dispose(); + // Suspend instead of dispose - allows recovery when credentials change + this.disconnect(); return; } @@ -204,7 +248,11 @@ export class ReconnectingWebSocket } private scheduleReconnect(): void { - if (this.#isDisposed || this.#reconnectTimeoutId !== null) { + if ( + this.#isDisposed || + this.isDisconnected || + this.#reconnectTimeoutId !== null + ) { return; } @@ -213,7 +261,7 @@ export class ReconnectingWebSocket const delayMs = Math.max(0, this.#backoffMs + jitter); this.#logger.debug( - `Reconnecting WebSocket in ${Math.round(delayMs)}ms for ${this.#apiRoute}`, + `Reconnecting WebSocket in ${Math.round(delayMs)}ms for ${this.#route}`, ); this.#reconnectTimeoutId = setTimeout(() => { @@ -233,7 +281,7 @@ export class ReconnectingWebSocket handler(eventData); } catch (error) { this.#logger.error( - `Error in ${event} handler for ${this.#apiRoute}`, + `Error in ${event} handler for ${this.#route}`, error, ); } @@ -241,37 +289,34 @@ export class ReconnectingWebSocket } /** - * Checks if the error is unrecoverable and disposes the connection, + * Checks if the error is unrecoverable and suspends the connection, * otherwise schedules a reconnect. */ private handleConnectionError(error: unknown): void { - if (this.#isDisposed) { + if (this.#isDisposed || this.isDisconnected) { return; } if (this.isUnrecoverableHttpError(error)) { this.#logger.error( - `Unrecoverable HTTP error during connection for ${this.#apiRoute}`, + `Unrecoverable HTTP error during connection for ${this.#route}`, error, ); - this.dispose(); + this.disconnect(); return; } - this.#logger.warn( - `WebSocket connection failed for ${this.#apiRoute}`, - error, - ); + this.#logger.warn(`WebSocket connection failed for ${this.#route}`, error); this.scheduleReconnect(); } /** - * Check if an error contains an unrecoverable HTTP status code. + * Check if an error message contains an unrecoverable HTTP status code. */ private isUnrecoverableHttpError(error: unknown): boolean { - const errorMessage = error instanceof Error ? error.message : String(error); + const message = (error as { message?: string }).message || String(error); for (const code of UNRECOVERABLE_HTTP_CODES) { - if (errorMessage.includes(String(code))) { + if (message.includes(String(code))) { return true; } } @@ -284,6 +329,18 @@ export class ReconnectingWebSocket } this.#isDisposed = true; + this.clearCurrentSocket(code, reason); + + for (const set of Object.values(this.#eventHandlers)) { + set.clear(); + } + + this.#onDispose?.(); + } + + private clearCurrentSocket(code?: number, reason?: string): void { + // Clear pending reconnect to prevent resume + this.#pendingReconnect = false; if (this.#reconnectTimeoutId !== null) { clearTimeout(this.#reconnectTimeoutId); @@ -294,11 +351,5 @@ export class ReconnectingWebSocket this.#currentSocket.close(code, reason); this.#currentSocket = null; } - - for (const set of Object.values(this.#eventHandlers)) { - set.clear(); - } - - this.#onDispose?.(); } } diff --git a/test/mocks/testHelpers.ts b/test/mocks/testHelpers.ts index 5678cd48..f799d6e3 100644 --- a/test/mocks/testHelpers.ts +++ b/test/mocks/testHelpers.ts @@ -1,8 +1,11 @@ -import { type IncomingMessage } from "node:http"; import { vi } from "vitest"; import * as vscode from "vscode"; -import { type Logger } from "@/logging/logger"; +import type { User } from "coder/site/src/api/typesGenerated"; +import type { IncomingMessage } from "node:http"; + +import type { CoderApi } from "@/api/coderApi"; +import type { Logger } from "@/logging/logger"; /** * Mock configuration provider that integrates with the vscode workspace configuration mock. @@ -137,11 +140,13 @@ export class MockProgressReporter { } /** - * Mock user interaction that integrates with vscode.window message dialogs. + * Mock user interaction that integrates with vscode.window message dialogs and input boxes. * Use this to control user responses in tests. */ export class MockUserInteraction { private readonly responses = new Map(); + private inputBoxValue: string | undefined; + private inputBoxValidateInput: ((value: string) => Promise) | undefined; private externalUrls: string[] = []; constructor() { @@ -149,12 +154,28 @@ export class MockUserInteraction { } /** - * Set a response for a specific message + * Set a response for a specific message dialog */ setResponse(message: string, response: string | undefined): void { this.responses.set(message, response); } + /** + * Set the value to return from showInputBox. + * Pass undefined to simulate user cancelling. + */ + setInputBoxValue(value: string | undefined): void { + this.inputBoxValue = value; + } + + /** + * Set a custom validateInput handler for showInputBox. + * This allows tests to simulate the validation callback behavior. + */ + setInputBoxValidateInput(fn: (value: string) => Promise): void { + this.inputBoxValidateInput = fn; + } + /** * Get all URLs that were opened externally */ @@ -170,10 +191,13 @@ export class MockUserInteraction { } /** - * Clear all responses + * Clear all responses and input box values */ - clearResponses(): void { + clear(): void { this.responses.clear(); + this.inputBoxValue = undefined; + this.inputBoxValidateInput = undefined; + this.externalUrls = []; } /** @@ -206,6 +230,32 @@ export class MockUserInteraction { return Promise.resolve(true); }, ); + + vi.mocked(vscode.window.showInputBox).mockImplementation( + async (options?: vscode.InputBoxOptions) => { + const value = this.inputBoxValue; + if (value === undefined) { + return undefined; // User cancelled + } + + if (options?.validateInput) { + const validationResult = await options.validateInput(value); + if (validationResult) { + // Validation failed - in real VS Code this would show error + // For tests, we can use the custom handler or return undefined + if (this.inputBoxValidateInput) { + await this.inputBoxValidateInput(value); + } + return undefined; + } + } else if (this.inputBoxValidateInput) { + // Run custom validation handler even without options.validateInput + await this.inputBoxValidateInput(value); + } + + return value; + }, + ); } } @@ -399,3 +449,93 @@ export class MockStatusBar { ); } } + +/** + * Mock CoderApi for testing. Tracks method calls and allows controlling responses. + */ +export class MockCoderApi + implements + Pick< + CoderApi, + | "setHost" + | "setSessionToken" + | "setCredentials" + | "getAuthenticatedUser" + | "dispose" + > +{ + private _host: string | undefined; + private _token: string | undefined; + private authenticatedUser: User | Error | undefined; + + readonly setHost = vi.fn((host: string | undefined) => { + this._host = host; + }); + + readonly setSessionToken = vi.fn((token: string) => { + this._token = token; + }); + + readonly setCredentials = vi.fn( + (host: string | undefined, token: string | undefined) => { + this._host = host; + this._token = token; + }, + ); + + readonly getAuthenticatedUser = vi.fn((): Promise => { + if (this.authenticatedUser instanceof Error) { + return Promise.reject(this.authenticatedUser); + } + if (!this.authenticatedUser) { + return Promise.reject(new Error("Not authenticated")); + } + return Promise.resolve(this.authenticatedUser); + }); + + readonly dispose = vi.fn(); + + /** + * Get current host (for assertions) + */ + get host(): string | undefined { + return this._host; + } + + /** + * Get current token (for assertions) + */ + get token(): string | undefined { + return this._token; + } + + /** + * Set the authenticated user that will be returned by getAuthenticatedUser. + * Pass an Error to make getAuthenticatedUser reject. + */ + setAuthenticatedUserResponse(user: User | Error | undefined): void { + this.authenticatedUser = user; + } +} + +/** + * Create a mock User for testing. + */ +export function createMockUser(overrides: Partial = {}): User { + return { + id: "user-123", + username: "testuser", + email: "test@example.com", + name: "Test User", + created_at: new Date().toISOString(), + updated_at: new Date().toISOString(), + last_seen_at: new Date().toISOString(), + status: "active", + organization_ids: [], + roles: [], + avatar_url: "", + login_type: "password", + theme_preference: "", + ...overrides, + }; +} diff --git a/test/mocks/vscode.runtime.ts b/test/mocks/vscode.runtime.ts index 4da3796f..ba282f40 100644 --- a/test/mocks/vscode.runtime.ts +++ b/test/mocks/vscode.runtime.ts @@ -28,6 +28,11 @@ export const TreeItemCollapsibleState = E({ export const StatusBarAlignment = E({ Left: 1, Right: 2 }); export const ExtensionMode = E({ Production: 1, Development: 2, Test: 3 }); export const UIKind = E({ Desktop: 1, Web: 2 }); +export const InputBoxValidationSeverity = E({ + Info: 1, + Warning: 2, + Error: 3, +}); export class Uri { constructor( @@ -142,6 +147,7 @@ const vscode = { StatusBarAlignment, ExtensionMode, UIKind, + InputBoxValidationSeverity, Uri, EventEmitter, window, diff --git a/test/unit/api/coderApi.test.ts b/test/unit/api/coderApi.test.ts index 4f90f33e..877ef5fc 100644 --- a/test/unit/api/coderApi.test.ts +++ b/test/unit/api/coderApi.test.ts @@ -11,7 +11,6 @@ import { CertificateError } from "@/error"; import { getHeaders } from "@/headers"; import { type RequestConfigWithMeta } from "@/logging/types"; import { ReconnectingWebSocket } from "@/websocket/reconnectingWebSocket"; -import { SseConnection } from "@/websocket/sseConnection"; import { createMockLogger, @@ -336,18 +335,20 @@ describe("CoderApi", () => { expect(EventSource).not.toHaveBeenCalled(); }); - it("falls back to SSE when WebSocket creation fails", async () => { + it("falls back to SSE when WebSocket creation fails with 404", async () => { + // Only 404 errors trigger SSE fallback - other errors are thrown vi.mocked(Ws).mockImplementation(() => { - throw new Error("WebSocket creation failed"); + throw new Error("Unexpected server response: 404"); }); const connection = await api.watchAgentMetadata(AGENT_ID); - expect(connection).toBeInstanceOf(SseConnection); + // Returns ReconnectingWebSocket (which wraps SSE internally) + expect(connection).toBeInstanceOf(ReconnectingWebSocket); expect(EventSource).toHaveBeenCalled(); }); - it("falls back to SSE on 404 error from WebSocket", async () => { + it("falls back to SSE on 404 error from WebSocket open", async () => { const mockWs = createMockWebSocket( `wss://${CODER_URL.replace("https://", "")}/api/v2/test`, { @@ -368,9 +369,64 @@ describe("CoderApi", () => { const connection = await api.watchAgentMetadata(AGENT_ID); - expect(connection).toBeInstanceOf(SseConnection); + // Returns ReconnectingWebSocket (which wraps SSE internally after WS 404) + expect(connection).toBeInstanceOf(ReconnectingWebSocket); expect(EventSource).toHaveBeenCalled(); }); + + it("throws non-404 errors without SSE fallback", async () => { + vi.mocked(Ws).mockImplementation(() => { + throw new Error("Network error"); + }); + + await expect(api.watchAgentMetadata(AGENT_ID)).rejects.toThrow( + "Network error", + ); + expect(EventSource).not.toHaveBeenCalled(); + }); + + describe("reconnection after fallback", () => { + beforeEach(() => vi.useFakeTimers({ shouldAdvanceTime: true })); + afterEach(() => vi.useRealTimers()); + + it("reconnects after SSE fallback and retries WS on each reconnect", async () => { + let wsAttempts = 0; + const mockEventSources: MockEventSource[] = []; + + vi.mocked(Ws).mockImplementation(() => { + wsAttempts++; + const mockWs = createMockWebSocket("wss://test", { + on: vi.fn((event: string, handler: (e: unknown) => void) => { + if (event === "error") { + setImmediate(() => + handler({ error: new Error("Something 404") }), + ); + } + return mockWs as Ws; + }), + }); + return mockWs as Ws; + }); + + vi.mocked(EventSource).mockImplementation(() => { + const es = createMockEventSource(`${CODER_URL}/api/v2/test`); + mockEventSources.push(es); + return es as unknown as EventSource; + }); + + const connection = await api.watchAgentMetadata(AGENT_ID); + expect(wsAttempts).toBe(1); + expect(EventSource).toHaveBeenCalledTimes(1); + + mockEventSources[0].fireError(); + await vi.advanceTimersByTimeAsync(300); + + expect(wsAttempts).toBe(2); + expect(EventSource).toHaveBeenCalledTimes(2); + + connection.close(); + }); + }); }); describe("Reconnection on Host/Token Changes", () => { @@ -413,6 +469,7 @@ describe("CoderApi", () => { expect(wsWrap.url).toContain(CODER_URL.replace("http", "ws")); api.setHost("https://new-coder.example.com"); + // Wait for the async reconnect to complete (factory is async) await new Promise((resolve) => setImmediate(resolve)); expect(sockets[0].close).toHaveBeenCalledWith( @@ -420,7 +477,8 @@ describe("CoderApi", () => { "Replacing connection", ); expect(sockets).toHaveLength(2); - expect(wsWrap.url).toContain("wss://new-coder.example.com"); + // Verify the new socket was created with the correct URL + expect(sockets[1].url).toContain("wss://new-coder.example.com"); }); it("does not reconnect when token or host are unchanged", async () => { @@ -435,6 +493,58 @@ describe("CoderApi", () => { expect(sockets[0].close).not.toHaveBeenCalled(); expect(sockets).toHaveLength(1); }); + + it("suspends sockets when host is set to empty string (logout)", async () => { + const sockets = setupAutoOpeningWebSocket(); + api = createApi(CODER_URL, AXIOS_TOKEN); + await api.watchAgentMetadata(AGENT_ID); + + // Setting host to empty string (logout) should suspend (not permanently close) + api.setHost(""); + await new Promise((resolve) => setImmediate(resolve)); + + expect(sockets[0].close).toHaveBeenCalledWith(1000, "Host cleared"); + expect(sockets).toHaveLength(1); + }); + + it("does not reconnect when setting token after clearing host", async () => { + const sockets = setupAutoOpeningWebSocket(); + api = createApi(CODER_URL, AXIOS_TOKEN); + await api.watchAgentMetadata(AGENT_ID); + + api.setHost(""); + api.setSessionToken("new-token"); + await new Promise((resolve) => setImmediate(resolve)); + + // Should only have the initial socket - no reconnection after token change + expect(sockets).toHaveLength(1); + expect(sockets[0].close).toHaveBeenCalledWith(1000, "Host cleared"); + }); + + it("setCredentials sets both host and token together", async () => { + const sockets = setupAutoOpeningWebSocket(); + api = createApi(CODER_URL, AXIOS_TOKEN); + await api.watchAgentMetadata(AGENT_ID); + + api.setCredentials("https://new-coder.example.com", "new-token"); + await new Promise((resolve) => setImmediate(resolve)); + + // Should reconnect only once despite both values changing + expect(sockets).toHaveLength(2); + expect(sockets[1].url).toContain("wss://new-coder.example.com"); + }); + + it("setCredentials suspends when host is cleared", async () => { + const sockets = setupAutoOpeningWebSocket(); + api = createApi(CODER_URL, AXIOS_TOKEN); + await api.watchAgentMetadata(AGENT_ID); + + api.setCredentials(undefined, undefined); + await new Promise((resolve) => setImmediate(resolve)); + + expect(sockets).toHaveLength(1); + expect(sockets[0].close).toHaveBeenCalledWith(1000, "Host cleared"); + }); }); describe("Error Handling", () => { @@ -447,6 +557,42 @@ describe("CoderApi", () => { ); }); }); + + describe("getHost/getSessionToken", () => { + it("returns current host and token", () => { + const api = createApi(CODER_URL, AXIOS_TOKEN); + + expect(api.getHost()).toBe(CODER_URL); + expect(api.getSessionToken()).toBe(AXIOS_TOKEN); + }); + }); + + describe("dispose", () => { + it("disposes all tracked reconnecting sockets", async () => { + const sockets: Array> = []; + vi.mocked(Ws).mockImplementation((url: string | URL) => { + const mockWs = createMockWebSocket(String(url), { + on: vi.fn((event, handler) => { + if (event === "open") { + setImmediate(() => handler()); + } + return mockWs as Ws; + }), + }); + sockets.push(mockWs); + return mockWs as Ws; + }); + + api = createApi(CODER_URL, AXIOS_TOKEN); + await api.watchAgentMetadata(AGENT_ID); + expect(sockets).toHaveLength(1); + + api.dispose(); + + // Socket should be closed + expect(sockets[0].close).toHaveBeenCalled(); + }); + }); }); const mockAdapterImpl = vi.hoisted(() => (config: Record) => { @@ -472,18 +618,32 @@ function createMockWebSocket( }; } -function createMockEventSource(url: string): Partial { - return { +type MockEventSource = Partial & { + readyState: number; + fireOpen: () => void; + fireError: () => void; +}; + +function createMockEventSource(url: string): MockEventSource { + const handlers: Record void) | undefined> = {}; + const mock: MockEventSource = { url, readyState: EventSource.CONNECTING, - addEventListener: vi.fn((event, handler) => { + addEventListener: vi.fn((event: string, handler: (e: Event) => void) => { + handlers[event] = handler; if (event === "open") { setImmediate(() => handler(new Event("open"))); } }), removeEventListener: vi.fn(), close: vi.fn(), + fireOpen: () => handlers.open?.(new Event("open")), + fireError: () => { + mock.readyState = EventSource.CLOSED; + handlers.error?.(new Event("error")); + }, }; + return mock; } function setupWebSocketMock(ws: Partial): void { diff --git a/test/unit/core/cliManager.test.ts b/test/unit/core/cliManager.test.ts index 95755d31..6cdcdb07 100644 --- a/test/unit/core/cliManager.test.ts +++ b/test/unit/core/cliManager.test.ts @@ -165,52 +165,6 @@ describe("CliManager", () => { }); }); - describe("Read CLI Configuration", () => { - it("should read and trim stored configuration", async () => { - // Create directories and write files with whitespace - vol.mkdirSync("/path/base/deployment", { recursive: true }); - memfs.writeFileSync( - "/path/base/deployment/url", - " https://coder.example.com \n", - ); - memfs.writeFileSync( - "/path/base/deployment/session", - "\t test-token \r\n", - ); - - const result = await manager.readConfig("deployment"); - - expect(result).toEqual({ - url: "https://coder.example.com", - token: "test-token", - }); - }); - - it("should return empty strings for missing files", async () => { - const result = await manager.readConfig("deployment"); - - expect(result).toEqual({ - url: "", - token: "", - }); - }); - - it("should handle partial configuration", async () => { - vol.mkdirSync("/path/base/deployment", { recursive: true }); - memfs.writeFileSync( - "/path/base/deployment/url", - "https://coder.example.com", - ); - - const result = await manager.readConfig("deployment"); - - expect(result).toEqual({ - url: "https://coder.example.com", - token: "", - }); - }); - }); - describe("Binary Version Validation", () => { it("rejects invalid server versions", async () => { mockApi.getBuildInfo = vi.fn().mockResolvedValue({ version: "invalid" }); diff --git a/test/unit/core/mementoManager.test.ts b/test/unit/core/mementoManager.test.ts index 54289a65..791f7602 100644 --- a/test/unit/core/mementoManager.test.ts +++ b/test/unit/core/mementoManager.test.ts @@ -13,28 +13,22 @@ describe("MementoManager", () => { mementoManager = new MementoManager(memento); }); - describe("setUrl", () => { - it("should store URL and add to history", async () => { - await mementoManager.setUrl("https://coder.example.com"); + describe("addToUrlHistory", () => { + it("should add URL to history", async () => { + await mementoManager.addToUrlHistory("https://coder.example.com"); - expect(mementoManager.getUrl()).toBe("https://coder.example.com"); expect(memento.get("urlHistory")).toEqual(["https://coder.example.com"]); }); it("should not update history for falsy values", async () => { - await mementoManager.setUrl(undefined); - expect(mementoManager.getUrl()).toBeUndefined(); - expect(memento.get("urlHistory")).toBeUndefined(); - - await mementoManager.setUrl(""); - expect(mementoManager.getUrl()).toBe(""); + await mementoManager.addToUrlHistory(""); expect(memento.get("urlHistory")).toBeUndefined(); }); it("should deduplicate URLs in history", async () => { - await mementoManager.setUrl("url1"); - await mementoManager.setUrl("url2"); - await mementoManager.setUrl("url1"); // Re-add first URL + await mementoManager.addToUrlHistory("url1"); + await mementoManager.addToUrlHistory("url2"); + await mementoManager.addToUrlHistory("url1"); // Re-add first URL expect(memento.get("urlHistory")).toEqual(["url2", "url1"]); }); diff --git a/test/unit/core/secretsManager.test.ts b/test/unit/core/secretsManager.test.ts index bfe8c713..e5d2efc3 100644 --- a/test/unit/core/secretsManager.test.ts +++ b/test/unit/core/secretsManager.test.ts @@ -1,82 +1,277 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; -import { AuthAction, SecretsManager } from "@/core/secretsManager"; +import { + type CurrentDeploymentState, + SecretsManager, +} from "@/core/secretsManager"; -import { InMemorySecretStorage } from "../../mocks/testHelpers"; +import { + InMemoryMemento, + InMemorySecretStorage, + createMockLogger, +} from "../../mocks/testHelpers"; describe("SecretsManager", () => { let secretStorage: InMemorySecretStorage; + let memento: InMemoryMemento; let secretsManager: SecretsManager; beforeEach(() => { + vi.useRealTimers(); secretStorage = new InMemorySecretStorage(); - secretsManager = new SecretsManager(secretStorage); + memento = new InMemoryMemento(); + secretsManager = new SecretsManager( + secretStorage, + memento, + createMockLogger(), + ); }); - describe("session token", () => { - it("should store and retrieve tokens", async () => { - await secretsManager.setSessionToken("test-token"); - expect(await secretsManager.getSessionToken()).toBe("test-token"); + describe("session auth", () => { + it("should store and retrieve session auth", async () => { + await secretsManager.setSessionAuth("example.com", { + url: "https://example.com", + token: "test-token", + }); + const auth = await secretsManager.getSessionAuth("example.com"); + expect(auth?.token).toBe("test-token"); + expect(auth?.url).toBe("https://example.com"); - await secretsManager.setSessionToken("new-token"); - expect(await secretsManager.getSessionToken()).toBe("new-token"); + await secretsManager.setSessionAuth("example.com", { + url: "https://example.com", + token: "new-token", + }); + const newAuth = await secretsManager.getSessionAuth("example.com"); + expect(newAuth?.token).toBe("new-token"); }); - it("should delete token when empty or undefined", async () => { - await secretsManager.setSessionToken("test-token"); - await secretsManager.setSessionToken(""); - expect(await secretsManager.getSessionToken()).toBeUndefined(); - - await secretsManager.setSessionToken("test-token"); - await secretsManager.setSessionToken(undefined); - expect(await secretsManager.getSessionToken()).toBeUndefined(); + it("should clear session auth", async () => { + await secretsManager.setSessionAuth("example.com", { + url: "https://example.com", + token: "test-token", + }); + await secretsManager.clearAllAuthData("example.com"); + expect( + await secretsManager.getSessionAuth("example.com"), + ).toBeUndefined(); }); it("should return undefined for corrupted storage", async () => { - await secretStorage.store("sessionToken", "valid-token"); + await secretStorage.store( + "coder.session.example.com", + JSON.stringify({ + url: "https://example.com", + token: "valid-token", + }), + ); secretStorage.corruptStorage(); - expect(await secretsManager.getSessionToken()).toBeUndefined(); + expect( + await secretsManager.getSessionAuth("example.com"), + ).toBeUndefined(); }); - }); - describe("login state", () => { - it("should trigger login events", async () => { - const events: Array = []; - secretsManager.onDidChangeLoginState((state) => { - events.push(state); - return Promise.resolve(); + it("should track known safe hostnames", async () => { + expect(secretsManager.getKnownSafeHostnames()).toEqual([]); + + await secretsManager.setSessionAuth("example.com", { + url: "https://example.com", + token: "test-token", }); + expect(secretsManager.getKnownSafeHostnames()).toContain("example.com"); - await secretsManager.triggerLoginStateChange("login"); - expect(events).toEqual([AuthAction.LOGIN]); + await secretsManager.setSessionAuth("other-com", { + url: "https://other.com", + token: "other-token", + }); + expect(secretsManager.getKnownSafeHostnames()).toContain("example.com"); + expect(secretsManager.getKnownSafeHostnames()).toContain("other-com"); }); - it("should trigger logout events", async () => { - const events: Array = []; - secretsManager.onDidChangeLoginState((state) => { - events.push(state); - return Promise.resolve(); + it("should remove safe hostname on clearAllAuthData", async () => { + await secretsManager.setSessionAuth("example.com", { + url: "https://example.com", + token: "test-token", + }); + expect(secretsManager.getKnownSafeHostnames()).toContain("example.com"); + + await secretsManager.clearAllAuthData("example.com"); + expect(secretsManager.getKnownSafeHostnames()).not.toContain( + "example.com", + ); + }); + + it("should order safe hostnames by most recently accessed", async () => { + await secretsManager.setSessionAuth("first.com", { + url: "https://first.com", + token: "token1", + }); + await secretsManager.setSessionAuth("second.com", { + url: "https://second.com", + token: "token2", + }); + await secretsManager.setSessionAuth("first.com", { + url: "https://first.com", + token: "token1-updated", }); - await secretsManager.triggerLoginStateChange("logout"); - expect(events).toEqual([AuthAction.LOGOUT]); + expect(secretsManager.getKnownSafeHostnames()).toEqual([ + "first.com", + "second.com", + ]); }); - it("should fire same event twice in a row", async () => { + it("should prune old deployments when exceeding maxCount", async () => { + for (let i = 1; i <= 5; i++) { + await secretsManager.setSessionAuth(`host${i}.com`, { + url: `https://host${i}.com`, + token: `token${i}`, + }); + } + + await secretsManager.recordDeploymentAccess("new.com", 3); + + expect(secretsManager.getKnownSafeHostnames()).toEqual([ + "new.com", + "host5.com", + "host4.com", + ]); + expect(await secretsManager.getSessionAuth("host1.com")).toBeUndefined(); + expect(await secretsManager.getSessionAuth("host2.com")).toBeUndefined(); + }); + }); + + describe("current deployment", () => { + it("should store and retrieve current deployment", async () => { + const deployment = { + url: "https://example.com", + safeHostname: "example.com", + }; + await secretsManager.setCurrentDeployment(deployment); + + const result = await secretsManager.getCurrentDeployment(); + expect(result).toEqual(deployment); + }); + + it("should clear current deployment with undefined", async () => { + const deployment = { + url: "https://example.com", + safeHostname: "example.com", + }; + await secretsManager.setCurrentDeployment(deployment); + await secretsManager.setCurrentDeployment(undefined); + + const result = await secretsManager.getCurrentDeployment(); + expect(result).toBeNull(); + }); + + it("should return null when no deployment set", async () => { + const result = await secretsManager.getCurrentDeployment(); + expect(result).toBeNull(); + }); + + it("should notify listeners on deployment change", async () => { vi.useFakeTimers(); - const events: Array = []; - secretsManager.onDidChangeLoginState((state) => { + const events: Array = []; + secretsManager.onDidChangeCurrentDeployment((state) => { events.push(state); - return Promise.resolve(); }); - await secretsManager.triggerLoginStateChange("login"); + const deployments = [ + { url: "https://example.com", safeHostname: "example.com" }, + { url: "https://another.org", safeHostname: "another.org" }, + { url: "https://another.org", safeHostname: "another.org" }, + ]; + await secretsManager.setCurrentDeployment(deployments[0]); + vi.advanceTimersByTime(5); + await secretsManager.setCurrentDeployment(deployments[1]); vi.advanceTimersByTime(5); - await secretsManager.triggerLoginStateChange("login"); + await secretsManager.setCurrentDeployment(deployments[2]); + vi.advanceTimersByTime(5); + + // Trigger an event even if the deployment did not change + expect(events).toEqual(deployments.map((deployment) => ({ deployment }))); + }); + + it("should handle corrupted storage gracefully", async () => { + await secretStorage.store("coder.currentDeployment", "invalid-json{"); + + const result = await secretsManager.getCurrentDeployment(); + expect(result).toBeNull(); + }); + }); + + describe("migrateFromLegacyStorage", () => { + it("migrates legacy url/token to new format and sets current deployment", async () => { + // Set up legacy storage + await memento.update("url", "https://legacy.coder.com"); + await secretStorage.store("sessionToken", "legacy-token"); + + const result = await secretsManager.migrateFromLegacyStorage(); + + // Should return the migrated hostname + expect(result).toBe("legacy.coder.com"); + + // Should have migrated to new format + const auth = await secretsManager.getSessionAuth("legacy.coder.com"); + expect(auth?.url).toBe("https://legacy.coder.com"); + expect(auth?.token).toBe("legacy-token"); + + // Should have set current deployment + const deployment = await secretsManager.getCurrentDeployment(); + expect(deployment?.url).toBe("https://legacy.coder.com"); + expect(deployment?.safeHostname).toBe("legacy.coder.com"); + + // Legacy keys should be cleared + expect(memento.get("url")).toBeUndefined(); + expect(await secretStorage.get("sessionToken")).toBeUndefined(); + }); + + it("does not overwrite existing session auth", async () => { + // Set up existing auth + await secretsManager.setSessionAuth("existing.coder.com", { + url: "https://existing.coder.com", + token: "existing-token", + }); + + // Set up legacy storage with same hostname + await memento.update("url", "https://existing.coder.com"); + await secretStorage.store("sessionToken", "legacy-token"); + + await secretsManager.migrateFromLegacyStorage(); + + // Existing auth should not be overwritten + const auth = await secretsManager.getSessionAuth("existing.coder.com"); + expect(auth?.token).toBe("existing-token"); + }); + + it("returns undefined when no legacy data exists", async () => { + const result = await secretsManager.migrateFromLegacyStorage(); + expect(result).toBeUndefined(); + }); + + it("migrates with empty token when only URL exists (mTLS)", async () => { + await memento.update("url", "https://legacy.coder.com"); + + const result = await secretsManager.migrateFromLegacyStorage(); + expect(result).toBe("legacy.coder.com"); + + const auth = await secretsManager.getSessionAuth("legacy.coder.com"); + expect(auth?.url).toBe("https://legacy.coder.com"); + expect(auth?.token).toBe(""); + }); + }); + + describe("session auth - empty token handling (mTLS)", () => { + it("stores and retrieves empty string token", async () => { + await secretsManager.setSessionAuth("mtls.coder.com", { + url: "https://mtls.coder.com", + token: "", + }); - expect(events).toEqual([AuthAction.LOGIN, AuthAction.LOGIN]); - vi.useRealTimers(); + const auth = await secretsManager.getSessionAuth("mtls.coder.com"); + expect(auth?.token).toBe(""); + expect(auth?.url).toBe("https://mtls.coder.com"); }); }); }); diff --git a/test/unit/deployment/deploymentManager.test.ts b/test/unit/deployment/deploymentManager.test.ts new file mode 100644 index 00000000..63161eb5 --- /dev/null +++ b/test/unit/deployment/deploymentManager.test.ts @@ -0,0 +1,353 @@ +import { describe, expect, it, vi } from "vitest"; + +import { MementoManager } from "@/core/mementoManager"; +import { SecretsManager } from "@/core/secretsManager"; +import { DeploymentManager } from "@/deployment/deploymentManager"; + +import { + createMockLogger, + createMockUser, + InMemoryMemento, + InMemorySecretStorage, + MockCoderApi, +} from "../../mocks/testHelpers"; + +import type { CoderApi } from "@/api/coderApi"; +import type { ServiceContainer } from "@/core/container"; +import type { ContextManager } from "@/core/contextManager"; +import type { WorkspaceProvider } from "@/workspace/workspacesProvider"; + +/** + * Mock ContextManager for deployment tests. + */ +class MockContextManager { + private readonly contexts = new Map(); + + readonly set = vi.fn((key: string, value: boolean) => { + this.contexts.set(key, value); + }); + + get(key: string): boolean | undefined { + return this.contexts.get(key); + } +} + +/** + * Mock WorkspaceProvider for deployment tests. + */ +class MockWorkspaceProvider { + readonly fetchAndRefresh = vi.fn(); +} + +const TEST_URL = "https://coder.example.com"; +const TEST_HOSTNAME = "coder.example.com"; + +/** + * Creates a fresh test context with all dependencies. + */ +function createTestContext() { + vi.resetAllMocks(); + + const mockClient = new MockCoderApi(); + const mockWorkspaceProvider = new MockWorkspaceProvider(); + const secretStorage = new InMemorySecretStorage(); + const memento = new InMemoryMemento(); + const logger = createMockLogger(); + const secretsManager = new SecretsManager(secretStorage, memento, logger); + const mementoManager = new MementoManager(memento); + const contextManager = new MockContextManager(); + + const container = { + getSecretsManager: () => secretsManager, + getMementoManager: () => mementoManager, + getContextManager: () => contextManager as unknown as ContextManager, + getLogger: () => logger, + }; + + const manager = DeploymentManager.create( + container as unknown as ServiceContainer, + mockClient as unknown as CoderApi, + [mockWorkspaceProvider as unknown as WorkspaceProvider], + ); + + return { + mockClient, + secretsManager, + contextManager, + manager, + }; +} + +describe("DeploymentManager", () => { + describe("deployment state", () => { + it("returns null and isAuthenticated=false with no deployment", () => { + const { manager } = createTestContext(); + + expect(manager.getCurrentDeployment()).toBeNull(); + expect(manager.isAuthenticated()).toBe(false); + }); + + it("returns deployment and isAuthenticated=true after changeDeployment", async () => { + const { manager } = createTestContext(); + const user = createMockUser(); + + await manager.changeDeployment({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + token: "test-token", + user, + }); + + expect(manager.getCurrentDeployment()).toMatchObject({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }); + expect(manager.isAuthenticated()).toBe(true); + }); + + it("clears state after logout", async () => { + const { manager } = createTestContext(); + const user = createMockUser(); + + await manager.changeDeployment({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + token: "test-token", + user, + }); + + await manager.clearDeployment(); + + expect(manager.getCurrentDeployment()).toBeNull(); + expect(manager.isAuthenticated()).toBe(false); + }); + }); + + describe("changeDeployment", () => { + it("sets credentials, refreshes workspaces, persists deployment", async () => { + const { mockClient, secretsManager, contextManager, manager } = + createTestContext(); + const user = createMockUser(); + + await manager.changeDeployment({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + token: "test-token", + user, + }); + + expect(mockClient.host).toBe(TEST_URL); + expect(mockClient.token).toBe("test-token"); + expect(contextManager.get("coder.authenticated")).toBe(true); + expect(contextManager.get("coder.isOwner")).toBe(false); + + const persisted = await secretsManager.getCurrentDeployment(); + expect(persisted?.url).toBe(TEST_URL); + }); + + it("sets isOwner context when user has owner role", async () => { + const { contextManager, manager } = createTestContext(); + const ownerUser = createMockUser({ + roles: [{ name: "owner", display_name: "Owner", organization_id: "" }], + }); + + await manager.changeDeployment({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + token: "test-token", + user: ownerUser, + }); + + expect(contextManager.get("coder.isOwner")).toBe(true); + }); + }); + + describe("setDeploymentWithoutAuth", () => { + it("fetches user and upgrades to authenticated on success", async () => { + const { mockClient, manager } = createTestContext(); + const user = createMockUser(); + mockClient.setAuthenticatedUserResponse(user); + + await manager.setDeploymentAndValidate({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + token: "test-token", + }); + + expect(mockClient.host).toBe(TEST_URL); + expect(mockClient.token).toBe("test-token"); + expect(manager.isAuthenticated()).toBe(true); + }); + + it("remains unauthenticated on user fetch failure", async () => { + const { mockClient, manager } = createTestContext(); + mockClient.setAuthenticatedUserResponse(new Error("Auth failed")); + + await manager.setDeploymentAndValidate({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + token: "test-token", + }); + + expect(manager.getCurrentDeployment()).not.toBeNull(); + expect(manager.isAuthenticated()).toBe(false); + }); + + it("handles empty string token (mTLS) correctly (token='' is valid)", async () => { + const { mockClient, manager } = createTestContext(); + const user = createMockUser(); + mockClient.setAuthenticatedUserResponse(user); + + await manager.setDeploymentAndValidate({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + token: "", + }); + + // Empty string token should be set + expect(mockClient.host).toBe(TEST_URL); + expect(mockClient.token).toBe(""); + expect(manager.isAuthenticated()).toBe(true); + }); + + it("sets host without token when token is undefined", async () => { + const { mockClient, manager } = createTestContext(); + mockClient.setAuthenticatedUserResponse(new Error("Auth failed")); + + await manager.setDeploymentAndValidate({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }); + + // Host should be set, token should remain undefined + expect(mockClient.host).toBe(TEST_URL); + expect(mockClient.token).toBeUndefined(); + }); + }); + + describe("cross-window sync", () => { + it("ignores changes when already authenticated", async () => { + const { mockClient, secretsManager, manager } = createTestContext(); + const user = createMockUser(); + + await manager.changeDeployment({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + token: "test-token", + user, + }); + + // Simulate cross-window change by directly updating secrets + await secretsManager.setCurrentDeployment({ + url: "https://other.example.com", + safeHostname: "other.example.com", + }); + + // Should still have original credentials + expect(mockClient.host).toBe(TEST_URL); + expect(mockClient.token).toBe("test-token"); + }); + + it("picks up deployment when not authenticated", async () => { + const { mockClient, secretsManager } = createTestContext(); + const user = createMockUser(); + mockClient.setAuthenticatedUserResponse(user); + + // Set up auth in secrets before triggering cross-window sync + await secretsManager.setSessionAuth(TEST_HOSTNAME, { + url: TEST_URL, + token: "synced-token", + }); + + // Simulate cross-window change + await secretsManager.setCurrentDeployment({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }); + + // Wait for async handler + await new Promise((resolve) => setImmediate(resolve)); + + expect(mockClient.host).toBe(TEST_URL); + expect(mockClient.token).toBe("synced-token"); + }); + + it("handles mTLS deployment (empty token) from other window", async () => { + const { mockClient, secretsManager } = createTestContext(); + const user = createMockUser(); + mockClient.setAuthenticatedUserResponse(user); + + await secretsManager.setSessionAuth(TEST_HOSTNAME, { + url: TEST_URL, + token: "", + }); + + await secretsManager.setCurrentDeployment({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }); + + // Wait for async handler + await new Promise((resolve) => setImmediate(resolve)); + + expect(mockClient.host).toBe(TEST_URL); + expect(mockClient.token).toBe(""); + }); + }); + + describe("auth listener", () => { + it("updates credentials on token change and authenticates user", async () => { + const { mockClient, secretsManager, manager } = createTestContext(); + const user = createMockUser(); + + // Initially fail auth (no valid token yet) + mockClient.setAuthenticatedUserResponse(new Error("Auth failed")); + + // Set up initial deployment without user (will fail to authenticate) + await manager.setDeploymentAndValidate({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + token: "initial-token", + }); + + expect(mockClient.token).toBe("initial-token"); + expect(manager.isAuthenticated()).toBe(false); + + // Now auth succeeds with the new token + mockClient.setAuthenticatedUserResponse(user); + + // Simulate token refresh via secrets change + await secretsManager.setSessionAuth(TEST_HOSTNAME, { + url: TEST_URL, + token: "refreshed-token", + }); + + // Wait for async handler + await new Promise((resolve) => setImmediate(resolve)); + + expect(mockClient.token).toBe("refreshed-token"); + expect(manager.isAuthenticated()).toBe(true); + }); + }); + + describe("logout", () => { + it("clears credentials and updates contexts", async () => { + const { mockClient, contextManager, manager } = createTestContext(); + const user = createMockUser(); + + await manager.changeDeployment({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + token: "test-token", + user, + }); + + await manager.clearDeployment(); + + expect(mockClient.host).toBeUndefined(); + expect(mockClient.token).toBeUndefined(); + expect(contextManager.get("coder.authenticated")).toBe(false); + expect(contextManager.get("coder.isOwner")).toBe(false); + }); + }); +}); diff --git a/test/unit/login/loginCoordinator.test.ts b/test/unit/login/loginCoordinator.test.ts new file mode 100644 index 00000000..fda88ada --- /dev/null +++ b/test/unit/login/loginCoordinator.test.ts @@ -0,0 +1,334 @@ +import axios from "axios"; +import { describe, expect, it, vi } from "vitest"; +import * as vscode from "vscode"; + +import { MementoManager } from "@/core/mementoManager"; +import { SecretsManager } from "@/core/secretsManager"; +import { getHeaders } from "@/headers"; +import { LoginCoordinator } from "@/login/loginCoordinator"; + +import { + createMockLogger, + createMockUser, + InMemoryMemento, + InMemorySecretStorage, + MockConfigurationProvider, + MockUserInteraction, +} from "../../mocks/testHelpers"; + +// Hoisted mock adapter implementation +const mockAxiosAdapterImpl = vi.hoisted( + () => (config: Record) => + Promise.resolve({ + data: config.data || "{}", + status: 200, + statusText: "OK", + headers: {}, + config, + }), +); + +vi.mock("axios", async () => { + const actual = await vi.importActual("axios"); + const mockAdapter = vi.fn(); + return { + ...actual, + default: { + ...actual.default, + create: vi.fn((config) => + actual.default.create({ ...config, adapter: mockAdapter }), + ), + __mockAdapter: mockAdapter, + }, + }; +}); + +vi.mock("@/headers", () => ({ + getHeaders: vi.fn().mockResolvedValue({}), + getHeaderCommand: vi.fn(), +})); + +vi.mock("@/api/utils", async () => { + const actual = + await vi.importActual("@/api/utils"); + return { ...actual, createHttpAgent: vi.fn() }; +}); + +vi.mock("@/api/streamingFetchAdapter", () => ({ + createStreamingFetchAdapter: vi.fn(() => fetch), +})); + +vi.mock("@/promptUtils"); + +// Type for axios with our mock adapter +type MockedAxios = typeof axios & { __mockAdapter: ReturnType }; + +const TEST_URL = "https://coder.example.com"; +const TEST_HOSTNAME = "coder.example.com"; + +/** + * Creates a fresh test context with all dependencies. + */ +function createTestContext() { + vi.resetAllMocks(); + + const mockAdapter = (axios as MockedAxios).__mockAdapter; + mockAdapter.mockImplementation(mockAxiosAdapterImpl); + vi.mocked(getHeaders).mockResolvedValue({}); + + // MockConfigurationProvider sets sensible defaults (httpClientLogLevel, tlsCertFile, tlsKeyFile) + const mockConfig = new MockConfigurationProvider(); + // MockUserInteraction sets up vscode.window dialogs and input boxes + const userInteraction = new MockUserInteraction(); + + const secretStorage = new InMemorySecretStorage(); + const memento = new InMemoryMemento(); + const logger = createMockLogger(); + const secretsManager = new SecretsManager(secretStorage, memento, logger); + const mementoManager = new MementoManager(memento); + + const coordinator = new LoginCoordinator( + secretsManager, + mementoManager, + vscode, + logger, + ); + + const mockSuccessfulAuth = (user = createMockUser()) => { + mockAdapter.mockResolvedValue({ + data: user, + status: 200, + statusText: "OK", + headers: {}, + config: {}, + }); + return user; + }; + + const mockAuthFailure = (message = "Unauthorized") => { + mockAdapter.mockRejectedValue({ + response: { status: 401, data: { message } }, + message, + }); + }; + + return { + mockAdapter, + mockConfig, + userInteraction, + secretsManager, + mementoManager, + coordinator, + mockSuccessfulAuth, + mockAuthFailure, + }; +} + +describe("LoginCoordinator", () => { + describe("token authentication", () => { + it("authenticates with stored token on success", async () => { + const { secretsManager, coordinator, mockSuccessfulAuth } = + createTestContext(); + const user = mockSuccessfulAuth(); + + // Pre-store a token + await secretsManager.setSessionAuth(TEST_HOSTNAME, { + url: TEST_URL, + token: "stored-token", + }); + + const result = await coordinator.ensureLoggedIn({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }); + + expect(result).toEqual({ success: true, user, token: "stored-token" }); + + const auth = await secretsManager.getSessionAuth(TEST_HOSTNAME); + expect(auth?.token).toBe("stored-token"); + }); + + it("prompts for token when no stored auth exists", async () => { + const { mockAdapter, userInteraction, secretsManager, coordinator } = + createTestContext(); + const user = createMockUser(); + + // No stored token, so goes directly to input box flow + // Mock succeeds when validateInput calls getAuthenticatedUser + mockAdapter.mockResolvedValueOnce({ + data: user, + status: 200, + statusText: "OK", + headers: {}, + config: {}, + }); + + // User enters a new token in the input box + userInteraction.setInputBoxValue("new-token"); + + const result = await coordinator.ensureLoggedIn({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }); + + expect(result).toEqual({ success: true, user, token: "new-token" }); + + // Verify new token was persisted + const auth = await secretsManager.getSessionAuth(TEST_HOSTNAME); + expect(auth?.token).toBe("new-token"); + }); + + it("returns success false when user cancels input", async () => { + const { userInteraction, coordinator, mockAuthFailure } = + createTestContext(); + mockAuthFailure(); + userInteraction.setInputBoxValue(undefined); + + const result = await coordinator.ensureLoggedIn({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }); + + expect(result.success).toBe(false); + }); + }); + + describe("same-window guard", () => { + it("prevents duplicate login calls for same hostname", async () => { + const { mockAdapter, userInteraction, coordinator } = createTestContext(); + const user = createMockUser(); + + // User enters a token in the input box + userInteraction.setInputBoxValue("new-token"); + + let resolveAuth: (value: unknown) => void; + mockAdapter.mockReturnValue( + new Promise((resolve) => { + resolveAuth = resolve; + }), + ); + + // Start first login + const login1 = coordinator.ensureLoggedIn({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }); + + // Start second login immediately (same hostname) + const login2 = coordinator.ensureLoggedIn({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }); + + // Resolve the auth (this validates the token from input box) + resolveAuth!({ + data: user, + status: 200, + statusText: "OK", + headers: {}, + config: {}, + }); + + // Both should complete with the same result + const [result1, result2] = await Promise.all([login1, login2]); + expect(result1.success).toBe(true); + expect(result1).toEqual(result2); + + // Input box should only be shown once (guard prevents duplicate prompts) + expect(vscode.window.showInputBox).toHaveBeenCalledTimes(1); + }); + }); + + describe("mTLS authentication", () => { + it("succeeds without prompt and returns token=''", async () => { + const { mockConfig, secretsManager, coordinator, mockSuccessfulAuth } = + createTestContext(); + // Configure mTLS via certs (no token needed) + mockConfig.set("coder.tlsCertFile", "/path/to/cert.pem"); + mockConfig.set("coder.tlsKeyFile", "/path/to/key.pem"); + + const user = mockSuccessfulAuth(); + + const result = await coordinator.ensureLoggedIn({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }); + + expect(result).toEqual({ success: true, user, token: "" }); + + // Verify empty string token was persisted + const auth = await secretsManager.getSessionAuth(TEST_HOSTNAME); + expect(auth?.token).toBe(""); + + // Should NOT prompt for token + expect(vscode.window.showInputBox).not.toHaveBeenCalled(); + }); + + it("shows error and returns failure when mTLS fails", async () => { + const { mockConfig, coordinator, mockAuthFailure } = createTestContext(); + mockConfig.set("coder.tlsCertFile", "/path/to/cert.pem"); + mockConfig.set("coder.tlsKeyFile", "/path/to/key.pem"); + mockAuthFailure("Certificate error"); + + const result = await coordinator.ensureLoggedIn({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }); + + expect(result.success).toBe(false); + expect(vscode.window.showErrorMessage).toHaveBeenCalledWith( + "Failed to log in to Coder server", + expect.objectContaining({ modal: true }), + ); + + // Should NOT prompt for token since it's mTLS + expect(vscode.window.showInputBox).not.toHaveBeenCalled(); + }); + + it("logs warning instead of showing dialog for autoLogin", async () => { + const { mockConfig, secretsManager, mementoManager, mockAuthFailure } = + createTestContext(); + mockConfig.set("coder.tlsCertFile", "/path/to/cert.pem"); + mockConfig.set("coder.tlsKeyFile", "/path/to/key.pem"); + + const logger = createMockLogger(); + const coordinator = new LoginCoordinator( + secretsManager, + mementoManager, + vscode, + logger, + ); + + mockAuthFailure("Certificate error"); + + const result = await coordinator.ensureLoggedIn({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + autoLogin: true, + }); + + expect(result.success).toBe(false); + expect(logger.warn).toHaveBeenCalled(); + expect(vscode.window.showErrorMessage).not.toHaveBeenCalled(); + }); + }); + + describe("ensureLoggedInWithDialog", () => { + it("returns success false when user dismisses dialog", async () => { + const { mockConfig, userInteraction, coordinator } = createTestContext(); + // Use mTLS for simpler dialog test + mockConfig.set("coder.tlsCertFile", "/path/to/cert.pem"); + mockConfig.set("coder.tlsKeyFile", "/path/to/key.pem"); + + // User dismisses dialog (returns undefined instead of "Login") + userInteraction.setResponse("Authentication Required", undefined); + + const result = await coordinator.ensureLoggedInWithDialog({ + url: TEST_URL, + safeHostname: TEST_HOSTNAME, + }); + + expect(result.success).toBe(false); + }); + }); +}); diff --git a/test/unit/util.test.ts b/test/unit/util.test.ts index 3015a47d..441ecd8c 100644 --- a/test/unit/util.test.ts +++ b/test/unit/util.test.ts @@ -42,8 +42,8 @@ describe("parseRemoteAuthority", () => { parseRemoteAuthority("vscode://ssh-remote+coder-vscode--foo--bar"), ).toStrictEqual({ agent: "", - host: "coder-vscode--foo--bar", - label: "", + sshHost: "coder-vscode--foo--bar", + safeHostname: "", username: "foo", workspace: "bar", }); @@ -51,8 +51,8 @@ describe("parseRemoteAuthority", () => { parseRemoteAuthority("vscode://ssh-remote+coder-vscode--foo--bar--baz"), ).toStrictEqual({ agent: "baz", - host: "coder-vscode--foo--bar--baz", - label: "", + sshHost: "coder-vscode--foo--bar--baz", + safeHostname: "", username: "foo", workspace: "bar", }); @@ -62,8 +62,8 @@ describe("parseRemoteAuthority", () => { ), ).toStrictEqual({ agent: "", - host: "coder-vscode.dev.coder.com--foo--bar", - label: "dev.coder.com", + sshHost: "coder-vscode.dev.coder.com--foo--bar", + safeHostname: "dev.coder.com", username: "foo", workspace: "bar", }); @@ -73,8 +73,8 @@ describe("parseRemoteAuthority", () => { ), ).toStrictEqual({ agent: "baz", - host: "coder-vscode.dev.coder.com--foo--bar--baz", - label: "dev.coder.com", + sshHost: "coder-vscode.dev.coder.com--foo--bar--baz", + safeHostname: "dev.coder.com", username: "foo", workspace: "bar", }); @@ -84,8 +84,8 @@ describe("parseRemoteAuthority", () => { ), ).toStrictEqual({ agent: "baz", - host: "coder-vscode.dev.coder.com--foo--bar.baz", - label: "dev.coder.com", + sshHost: "coder-vscode.dev.coder.com--foo--bar.baz", + safeHostname: "dev.coder.com", username: "foo", workspace: "bar", }); diff --git a/test/unit/websocket/reconnectingWebSocket.test.ts b/test/unit/websocket/reconnectingWebSocket.test.ts index cdf08949..468a53a4 100644 --- a/test/unit/websocket/reconnectingWebSocket.test.ts +++ b/test/unit/websocket/reconnectingWebSocket.test.ts @@ -91,11 +91,7 @@ describe("ReconnectingWebSocket", () => { }); await expect( - ReconnectingWebSocket.create( - factory, - createMockLogger(), - "/api/test", - ), + ReconnectingWebSocket.create(factory, createMockLogger()), ).rejects.toThrow(`Unexpected server response: ${statusCode}`); // Should not retry after unrecoverable HTTP error @@ -104,6 +100,32 @@ describe("ReconnectingWebSocket", () => { }, ); + it.each([ + HttpStatusCode.UNAUTHORIZED, + HttpStatusCode.FORBIDDEN, + HttpStatusCode.GONE, + ])( + "does not reconnect on unrecoverable HTTP error via error event: %i", + async (statusCode) => { + // HTTP errors during handshake fire 'error' event, then 'close' with 1006 + const { ws, sockets } = await createReconnectingWebSocket(); + + sockets[0].fireError( + new Error(`Unexpected server response: ${statusCode}`), + ); + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Connection failed", + }); + + // Should not reconnect - unrecoverable HTTP error + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(1); + + ws.close(); + }, + ); + it("reconnect() connects immediately and cancels pending reconnections", async () => { const { ws, sockets } = await createReconnectingWebSocket(); @@ -125,26 +147,8 @@ describe("ReconnectingWebSocket", () => { }); it("queues reconnect() calls made during connection", async () => { - const sockets: MockSocket[] = []; - let pendingResolve: ((socket: MockSocket) => void) | null = null; - - const factory = vi.fn(() => { - const socket = createMockSocket(); - sockets.push(socket); - - // First call resolves immediately, other calls wait for manual resolve - if (sockets.length === 1) { - return Promise.resolve(socket); - } else { - return new Promise((resolve) => { - pendingResolve = resolve; - }); - } - }); - - const ws = await fromFactory(factory); - sockets[0].fireOpen(); - expect(sockets).toHaveLength(1); + const { ws, sockets, completeConnection } = + await createBlockingReconnectingWebSocket(); // Start first reconnect (will block on factory promise) ws.reconnect(); @@ -154,17 +158,33 @@ describe("ReconnectingWebSocket", () => { // Still only 2 sockets (queued reconnect hasn't started) expect(sockets).toHaveLength(2); - // Complete the first reconnect - pendingResolve!(sockets[1]); - sockets[1].fireOpen(); - - // Wait a tick for the queued reconnect to execute + completeConnection(); await Promise.resolve(); // Now queued reconnect should have executed, creating third socket expect(sockets).toHaveLength(3); ws.close(); }); + + it("suspend() cancels pending reconnect queued during connection", async () => { + const { ws, sockets, failConnection } = + await createBlockingReconnectingWebSocket(); + + ws.reconnect(); + ws.reconnect(); // queued + expect(sockets).toHaveLength(2); + + // This should cancel the queued request + ws.disconnect(); + failConnection(new Error("No base URL")); + await Promise.resolve(); + + expect(sockets).toHaveLength(2); + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(2); + + ws.close(); + }); }); describe("Event Handlers", () => { @@ -216,6 +236,48 @@ describe("ReconnectingWebSocket", () => { ws.close(); }); + + it("preserves event handlers after suspend() and reconnect()", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + sockets[0].fireOpen(); + + const handler = vi.fn(); + ws.addEventListener("message", handler); + sockets[0].fireMessage({ test: 1 }); + expect(handler).toHaveBeenCalledTimes(1); + + // Suspend the socket + ws.disconnect(); + + // Reconnect (async operation) + ws.reconnect(); + await Promise.resolve(); // Wait for async connect() + expect(sockets).toHaveLength(2); + sockets[1].fireOpen(); + + // Handler should still work after suspend/reconnect + sockets[1].fireMessage({ test: 2 }); + expect(handler).toHaveBeenCalledTimes(2); + + ws.close(); + }); + + it("clears event handlers after close()", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + sockets[0].fireOpen(); + + const handler = vi.fn(); + ws.addEventListener("message", handler); + sockets[0].fireMessage({ test: 1 }); + expect(handler).toHaveBeenCalledTimes(1); + + // Close permanently + ws.close(); + + // Even if we could reconnect (we can't), handlers would be cleared + // Verify handler was removed by checking it's no longer in the set + // We can't easily test this without exposing internals, but close() clears handlers + }); }); describe("close() and Disposal", () => { @@ -258,9 +320,9 @@ describe("ReconnectingWebSocket", () => { expect(disposeCount).toBe(1); }); - it("calls onDispose callback on unrecoverable WebSocket close code", async () => { + it("suspends (not disposes) on unrecoverable WebSocket close code", async () => { let disposeCount = 0; - const { sockets } = await createReconnectingWebSocket( + const { ws, sockets } = await createReconnectingWebSocket( () => ++disposeCount, ); @@ -270,7 +332,14 @@ describe("ReconnectingWebSocket", () => { reason: "Protocol error", }); - expect(disposeCount).toBe(1); + // Should suspend, not dispose - allows recovery when credentials change + expect(disposeCount).toBe(0); + + // Should be able to reconnect after suspension + ws.reconnect(); + expect(sockets).toHaveLength(2); + + ws.close(); }); it("does not call onDispose callback during reconnection", async () => { @@ -291,6 +360,41 @@ describe("ReconnectingWebSocket", () => { ws.close(); expect(disposeCount).toBe(1); }); + + it("reconnect() resumes suspended socket after HTTP 403 error", async () => { + const { ws, sockets, setFactoryError } = + await createReconnectingWebSocketWithErrorControl(); + sockets[0].fireOpen(); + + // Trigger reconnect that will fail with 403 + setFactoryError( + new Error(`Unexpected server response: ${HttpStatusCode.FORBIDDEN}`), + ); + ws.reconnect(); + await Promise.resolve(); + + // Socket should be suspended - no automatic reconnection + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(1); + + // reconnect() should resume the suspended socket + setFactoryError(null); + ws.reconnect(); + await Promise.resolve(); + expect(sockets).toHaveLength(2); + + ws.close(); + }); + + it("reconnect() does nothing after close()", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + sockets[0].fireOpen(); + + ws.close(); + ws.reconnect(); + + expect(sockets).toHaveLength(1); + }); }); describe("Backoff Strategy", () => { @@ -454,6 +558,35 @@ async function createReconnectingWebSocket(onDispose?: () => void): Promise<{ return { ws, sockets }; } +async function createReconnectingWebSocketWithErrorControl(): Promise<{ + ws: ReconnectingWebSocket; + sockets: MockSocket[]; + setFactoryError: (error: Error | null) => void; +}> { + const sockets: MockSocket[] = []; + let factoryError: Error | null = null; + + const factory = vi.fn(() => { + if (factoryError) { + return Promise.reject(factoryError); + } + const socket = createMockSocket(); + sockets.push(socket); + return Promise.resolve(socket); + }); + + const ws = await fromFactory(factory); + expect(sockets).toHaveLength(1); + + return { + ws, + sockets, + setFactoryError: (error: Error | null) => { + factoryError = error; + }, + }; +} + async function fromFactory( factory: SocketFactory, onDispose?: () => void, @@ -461,8 +594,44 @@ async function fromFactory( return await ReconnectingWebSocket.create( factory, createMockLogger(), - "/random/api", undefined, onDispose, ); } + +async function createBlockingReconnectingWebSocket(): Promise<{ + ws: ReconnectingWebSocket; + sockets: MockSocket[]; + completeConnection: () => void; + failConnection: (error: Error) => void; +}> { + const sockets: MockSocket[] = []; + let pendingResolve: ((socket: MockSocket) => void) | null = null; + let pendingReject: ((error: Error) => void) | null = null; + + const factory = vi.fn(() => { + const socket = createMockSocket(); + sockets.push(socket); + if (sockets.length === 1) { + return Promise.resolve(socket); + } + return new Promise((resolve, reject) => { + pendingResolve = resolve; + pendingReject = reject; + }); + }); + + const ws = await fromFactory(factory); + sockets[0].fireOpen(); + + return { + ws, + sockets, + completeConnection: () => { + const socket = sockets.at(-1)!; + pendingResolve?.(socket); + socket.fireOpen(); + }, + failConnection: (error: Error) => pendingReject?.(error), + }; +}