diff --git a/package.json b/package.json index bd60a54c..a3b585ab 100644 --- a/package.json +++ b/package.json @@ -255,6 +255,12 @@ "title": "Search", "category": "Coder", "icon": "$(search)" + }, + { + "command": "coder.debug.listDeployments", + "title": "List Stored Deployments", + "category": "Coder Debug", + "when": "coder.devMode" } ], "menus": { diff --git a/src/api/coderApi.ts b/src/api/coderApi.ts index 04c696be..0f52f585 100644 --- a/src/api/coderApi.ts +++ b/src/api/coderApi.ts @@ -31,7 +31,7 @@ import { HttpClientLogLevel, } from "../logging/types"; import { sizeOf } from "../logging/utils"; -import { HttpStatusCode } from "../websocket/codes"; +import { HttpStatusCode, WebSocketCloseCode } from "../websocket/codes"; import { type UnidirectionalStream, type CloseEvent, @@ -55,7 +55,7 @@ const coderSessionTokenHeader = "Coder-Session-Token"; * Unified API class that includes both REST API methods from the base Api class * and WebSocket methods for real-time functionality. */ -export class CoderApi extends Api { +export class CoderApi extends Api implements vscode.Disposable { private readonly reconnectingSockets = new Set< ReconnectingWebSocket >(); @@ -74,38 +74,63 @@ export class CoderApi extends Api { output: Logger, ): CoderApi { const client = new CoderApi(output); - client.setHost(baseUrl); - if (token) { - client.setSessionToken(token); - } + client.setCredentials(baseUrl, token); setupInterceptors(client, output); return client; } - setSessionToken = (token: string): void => { - const defaultHeaders = this.getAxiosInstance().defaults.headers.common; - const currentToken = defaultHeaders[coderSessionTokenHeader]; - defaultHeaders[coderSessionTokenHeader] = token; + /** + * Set both host and token together. Useful for login/logout/switch to + * avoid triggering multiple reconnection events. + */ + setCredentials = ( + host: string | undefined, + token: string | undefined, + ): void => { + const defaults = this.getAxiosInstance().defaults; + const currentHost = defaults.baseURL; + const currentToken = defaults.headers.common[coderSessionTokenHeader]; - if (currentToken !== token) { + defaults.baseURL = host; + defaults.headers.common[coderSessionTokenHeader] = token; + + const hostChanged = currentHost !== host; + const tokenChanged = currentToken !== token; + + if (hostChanged || tokenChanged) { for (const socket of this.reconnectingSockets) { - socket.reconnect(); + if (host) { + socket.reconnect(); + } else { + socket.suspend(WebSocketCloseCode.NORMAL, "Host cleared"); + } } } }; + setSessionToken = (token: string): void => { + const currentHost = this.getAxiosInstance().defaults.baseURL; + this.setCredentials(currentHost, token); + }; + setHost = (host: string | undefined): void => { - const defaults = this.getAxiosInstance().defaults; - const currentHost = defaults.baseURL; - defaults.baseURL = host; + const currentToken = this.getAxiosInstance().defaults.headers.common[ + coderSessionTokenHeader + ] as string | undefined; + this.setCredentials(host, currentToken); + }; - if (currentHost !== host) { - for (const socket of this.reconnectingSockets) { - socket.reconnect(); - } + /** + * Permanently dispose all WebSocket connections. + * This clears handlers and prevents reconnection. + */ + dispose(): void { + for (const socket of this.reconnectingSockets) { + socket.close(); } - }; + this.reconnectingSockets.clear(); + } watchInboxNotifications = async ( watchTemplates: string[], @@ -125,7 +150,7 @@ export class CoderApi extends Api { }; watchWorkspace = async (workspace: Workspace, options?: ClientOptions) => { - return this.createWebSocketWithFallback({ + return this.createWebSocketWithFallback({ apiRoute: `/api/v2/workspaces/${workspace.id}/watch-ws`, fallbackApiRoute: `/api/v2/workspaces/${workspace.id}/watch`, options, @@ -137,7 +162,7 @@ export class CoderApi extends Api { agentId: WorkspaceAgent["id"], options?: ClientOptions, ) => { - return this.createWebSocketWithFallback({ + return this.createWebSocketWithFallback({ apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`, fallbackApiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata`, options, @@ -198,68 +223,62 @@ export class CoderApi extends Api { throw new Error("No base URL set on REST client"); } - const baseUrl = new URL(baseUrlRaw); - const token = this.getAxiosInstance().defaults.headers.common[ - coderSessionTokenHeader - ] as string | undefined; + return this.createOneWayWebSocket(socketConfigs); + }; - const headersFromCommand = await getHeaders( - baseUrlRaw, - getHeaderCommand(vscode.workspace.getConfiguration()), - this.output, - ); + if (enableRetry) { + return this.createReconnectingSocket(socketFactory, configs.apiRoute); + } + return socketFactory(); + } - const httpAgent = await createHttpAgent( - vscode.workspace.getConfiguration(), - ); + private async createOneWayWebSocket( + configs: Omit, + ): Promise> { + const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; + if (!baseUrlRaw) { + throw new Error("No base URL set on REST client"); + } + const token = this.getAxiosInstance().defaults.headers.common[ + coderSessionTokenHeader + ] as string | undefined; - /** - * Similar to the REST client, we want to prioritize headers in this order (highest to lowest): - * 1. Headers from the header command - * 2. Any headers passed directly to this function - * 3. Coder session token from the Api client (if set) - */ - const headers = { - ...(token ? { [coderSessionTokenHeader]: token } : {}), - ...configs.options?.headers, - ...headersFromCommand, - }; + const headersFromCommand = await getHeaders( + baseUrlRaw, + getHeaderCommand(vscode.workspace.getConfiguration()), + this.output, + ); - const webSocket = new OneWayWebSocket({ - location: baseUrl, - ...socketConfigs, - options: { - ...configs.options, - agent: httpAgent, - followRedirects: true, - headers, - }, - }); + const httpAgent = await createHttpAgent( + vscode.workspace.getConfiguration(), + ); - this.attachStreamLogger(webSocket); - return webSocket; + /** + * Similar to the REST client, we want to prioritize headers in this order (highest to lowest): + * 1. Headers from the header command + * 2. Any headers passed directly to this function + * 3. Coder session token from the Api client (if set) + */ + const headers = { + ...(token ? { [coderSessionTokenHeader]: token } : {}), + ...configs.options?.headers, + ...headersFromCommand, }; - if (enableRetry) { - const reconnectingSocket = await ReconnectingWebSocket.create( - socketFactory, - this.output, - configs.apiRoute, - undefined, - () => - this.reconnectingSockets.delete( - reconnectingSocket as ReconnectingWebSocket, - ), - ); - - this.reconnectingSockets.add( - reconnectingSocket as ReconnectingWebSocket, - ); + const baseUrl = new URL(baseUrlRaw); + const ws = new OneWayWebSocket({ + location: baseUrl, + ...configs, + options: { + ...configs.options, + agent: httpAgent, + followRedirects: true, + headers, + }, + }); - return reconnectingSocket; - } else { - return socketFactory(); - } + this.attachStreamLogger(ws); + return ws; } private attachStreamLogger( @@ -288,44 +307,79 @@ export class CoderApi extends Api { /** * Create a WebSocket connection with SSE fallback on 404. * + * The factory tries WS first, falls back to SSE on 404. Since the factory + * is called on every reconnect. + * * Note: The fallback on SSE ignores all passed client options except the headers. */ - private async createWebSocketWithFallback(configs: { - apiRoute: string; - fallbackApiRoute: string; - searchParams?: Record | URLSearchParams; - options?: ClientOptions; - enableRetry?: boolean; - }): Promise> { - let webSocket: UnidirectionalStream; - try { - webSocket = await this.createWebSocket({ - apiRoute: configs.apiRoute, - searchParams: configs.searchParams, - options: configs.options, - enableRetry: configs.enableRetry, - }); - } catch { - // Failed to create WebSocket, use SSE fallback - return this.createSseFallback( - configs.fallbackApiRoute, - configs.searchParams, - configs.options?.headers, + private async createWebSocketWithFallback( + configs: Omit & { + fallbackApiRoute: string; + enableRetry?: boolean; + }, + ): Promise> { + const { fallbackApiRoute, enableRetry, ...socketConfigs } = configs; + const socketFactory: SocketFactory = async () => { + try { + const ws = + await this.createOneWayWebSocket(socketConfigs); + return await this.waitForOpen(ws); + } catch (error) { + if (this.is404Error(error)) { + this.output.warn( + `WebSocket failed, using SSE fallback: ${socketConfigs.apiRoute}`, + ); + const sse = this.createSseConnection( + fallbackApiRoute, + socketConfigs.searchParams, + socketConfigs.options?.headers, + ); + return await this.waitForOpen(sse); + } + throw error; + } + }; + + if (enableRetry) { + return this.createReconnectingSocket( + socketFactory, + socketConfigs.apiRoute, ); } + return socketFactory(); + } - return this.waitForConnection(webSocket, () => - this.createSseFallback( - configs.fallbackApiRoute, - configs.searchParams, - configs.options?.headers, - ), - ); + /** + * Create an SSE connection without waiting for connection. + */ + private createSseConnection( + apiRoute: string, + searchParams?: Record | URLSearchParams, + optionsHeaders?: Record, + ): SseConnection { + const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; + if (!baseUrlRaw) { + throw new Error("No base URL set on REST client"); + } + const url = new URL(baseUrlRaw); + const sse = new SseConnection({ + location: url, + apiRoute, + searchParams, + axiosInstance: this.getAxiosInstance(), + optionsHeaders, + logger: this.output, + }); + + this.attachStreamLogger(sse); + return sse; } - private waitForConnection( + /** + * Wait for a connection to open. Rejects on error. + */ + private waitForOpen( connection: UnidirectionalStream, - onNotFound?: () => Promise>, ): Promise> { return new Promise((resolve, reject) => { const cleanup = () => { @@ -340,16 +394,8 @@ export class CoderApi extends Api { const handleError = (event: ErrorEvent) => { cleanup(); - const is404 = - event.message?.includes(String(HttpStatusCode.NOT_FOUND)) || - event.error?.message?.includes(String(HttpStatusCode.NOT_FOUND)); - - if (is404 && onNotFound) { - connection.close(); - onNotFound().then(resolve).catch(reject); - } else { - reject(event.error || new Error(event.message)); - } + connection.close(); + reject(event.error || new Error(event.message)); }; connection.addEventListener("open", handleOpen); @@ -358,32 +404,36 @@ export class CoderApi extends Api { } /** - * Create SSE fallback connection + * Check if an error is a 404 Not Found error. */ - private async createSseFallback( - apiRoute: string, - searchParams?: Record | URLSearchParams, - optionsHeaders?: Record, - ): Promise> { - this.output.warn(`WebSocket failed, using SSE fallback: ${apiRoute}`); - - const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; - if (!baseUrlRaw) { - throw new Error("No base URL set on REST client"); - } + private is404Error(error: unknown): boolean { + const msg = error instanceof Error ? error.message : String(error); + return msg.includes(String(HttpStatusCode.NOT_FOUND)); + } - const baseUrl = new URL(baseUrlRaw); - const sseConnection = new SseConnection({ - location: baseUrl, + /** + * Create a ReconnectingWebSocket and track it for lifecycle management. + */ + private async createReconnectingSocket( + socketFactory: SocketFactory, + apiRoute: string, + ): Promise> { + const reconnectingSocket = await ReconnectingWebSocket.create( + socketFactory, + this.output, apiRoute, - searchParams, - axiosInstance: this.getAxiosInstance(), - optionsHeaders: optionsHeaders, - logger: this.output, - }); + undefined, + () => + this.reconnectingSockets.delete( + reconnectingSocket as ReconnectingWebSocket, + ), + ); + + this.reconnectingSockets.add( + reconnectingSocket as ReconnectingWebSocket, + ); - this.attachStreamLogger(sseConnection); - return this.waitForConnection(sseConnection); + return reconnectingSocket; } } @@ -457,7 +507,7 @@ function addLoggingInterceptors(client: AxiosInstance, logger: Logger) { }, (error: unknown) => { logError(logger, error, getLogLevel()); - return Promise.reject(error); + throw error; }, ); @@ -468,7 +518,7 @@ function addLoggingInterceptors(client: AxiosInstance, logger: Logger) { }, (error: unknown) => { logError(logger, error, getLogLevel()); - return Promise.reject(error); + throw error; }, ); } diff --git a/src/api/oauthInterceptors.ts b/src/api/oauthInterceptors.ts new file mode 100644 index 00000000..b80e1d96 --- /dev/null +++ b/src/api/oauthInterceptors.ts @@ -0,0 +1,116 @@ +import { type AxiosError, isAxiosError } from "axios"; + +import { type Logger } from "../logging/logger"; +import { type RequestConfigWithMeta } from "../logging/types"; +import { parseOAuthError, requiresReAuthentication } from "../oauth/errors"; +import { type OAuthSessionManager } from "../oauth/sessionManager"; + +import { type CoderApi } from "./coderApi"; + +const coderSessionTokenHeader = "Coder-Session-Token"; + +/** + * Attach OAuth token refresh interceptors to a CoderApi instance. + * This should be called after creating the CoderApi when OAuth authentication is being used. + * + * Success interceptor: proactively refreshes token when approaching expiry. + * Error interceptor: reactively refreshes token on 401 responses. + */ +export function attachOAuthInterceptors( + client: CoderApi, + logger: Logger, + oauthSessionManager: OAuthSessionManager, +): void { + client.getAxiosInstance().interceptors.response.use( + // Success response interceptor: proactive token refresh + (response) => { + // Fire-and-forget: don't await, don't block response + oauthSessionManager.refreshIfAlmostExpired().catch((error) => { + logger.warn("Proactive background token refresh failed:", error); + }); + + return response; + }, + // Error response interceptor: reactive token refresh on 401 + async (error: unknown) => { + if (!isAxiosError(error)) { + throw error; + } + + if (error.config) { + const config = error.config as { + _oauthRetryAttempted?: boolean; + }; + if (config._oauthRetryAttempted) { + throw error; + } + } + + const status = error.response?.status; + + // These could indicate permanent auth failures that won't be fixed by token refresh + if (status === 400 || status === 403) { + handlePossibleOAuthError(error, logger, oauthSessionManager); + throw error; + } else if (status === 401) { + return handle401Error(error, client, logger, oauthSessionManager); + } + + throw error; + }, + ); +} + +function handlePossibleOAuthError( + error: unknown, + logger: Logger, + oauthSessionManager: OAuthSessionManager, +): void { + const oauthError = parseOAuthError(error); + if (oauthError && requiresReAuthentication(oauthError)) { + logger.error( + `OAuth error requires re-authentication: ${oauthError.errorCode}`, + ); + + oauthSessionManager.showReAuthenticationModal(oauthError).catch((err) => { + logger.error("Failed to show re-auth modal:", err); + }); + } +} + +async function handle401Error( + error: AxiosError, + client: CoderApi, + logger: Logger, + oauthSessionManager: OAuthSessionManager, +): Promise { + if (!oauthSessionManager.isLoggedInWithOAuth()) { + throw error; + } + + logger.info("Received 401 response, attempting token refresh"); + + try { + const newTokens = await oauthSessionManager.refreshToken(); + client.setSessionToken(newTokens.access_token); + + logger.info("Token refresh successful, retrying request"); + + // Retry the original request with the new token + if (error.config) { + const config = error.config as RequestConfigWithMeta & { + _oauthRetryAttempted?: boolean; + }; + config._oauthRetryAttempted = true; + config.headers[coderSessionTokenHeader] = newTokens.access_token; + return client.getAxiosInstance().request(config); + } + + throw error; + } catch (refreshError) { + logger.error("Token refresh failed:", refreshError); + + handlePossibleOAuthError(refreshError, logger, oauthSessionManager); + throw error; + } +} diff --git a/src/commands.ts b/src/commands.ts index 384b4d79..00cb2ee0 100644 --- a/src/commands.ts +++ b/src/commands.ts @@ -1,15 +1,11 @@ -import { type Api } from "coder/site/src/api/api"; -import { getErrorMessage } from "coder/site/src/api/errors"; import { - type User, type Workspace, type WorkspaceAgent, } from "coder/site/src/api/typesGenerated"; import * as vscode from "vscode"; import { createWorkspaceIdentifier, extractAgents } from "./api/api-helper"; -import { CoderApi } from "./api/coderApi"; -import { needToken } from "./api/utils"; +import { type CoderApi } from "./api/coderApi"; import { type CliManager } from "./core/cliManager"; import { type ServiceContainer } from "./core/container"; import { type ContextManager } from "./core/contextManager"; @@ -19,6 +15,8 @@ import { type SecretsManager } from "./core/secretsManager"; import { CertificateError } from "./error"; import { getGlobalFlags } from "./globalFlags"; import { type Logger } from "./logging/logger"; +import { type LoginCoordinator } from "./login/loginCoordinator"; +import { type OAuthSessionManager } from "./oauth/sessionManager"; import { maybeAskAgent, maybeAskUrl } from "./promptUtils"; import { escapeCommandArg, toRemoteAuthority, toSafeHost } from "./util"; import { @@ -35,6 +33,8 @@ export class Commands { private readonly secretsManager: SecretsManager; private readonly cliManager: CliManager; private readonly contextManager: ContextManager; + private readonly loginCoordinator: LoginCoordinator; + // These will only be populated when actively connected to a workspace and are // used in commands. Because commands can be executed by the user, it is not // possible to pass in arguments, so we have to store the current workspace @@ -44,11 +44,12 @@ export class Commands { // if you use multiple deployments). public workspace?: Workspace; public workspaceLogPath?: string; - public workspaceRestClient?: Api; + public remoteWorkspaceClient?: CoderApi; public constructor( serviceContainer: ServiceContainer, - private readonly restClient: Api, + private readonly extensionClient: CoderApi, + private readonly oauthSessionManager: OAuthSessionManager, ) { this.vscodeProposed = serviceContainer.getVsCodeProposed(); this.logger = serviceContainer.getLogger(); @@ -57,6 +58,18 @@ export class Commands { this.secretsManager = serviceContainer.getSecretsManager(); this.cliManager = serviceContainer.getCliManager(); this.contextManager = serviceContainer.getContextManager(); + this.loginCoordinator = serviceContainer.getLoginCoordinator(); + } + + /** + * Get the current deployment, throwing if not logged in. + */ + private requireExtensionBaseUrl(): string { + const url = this.extensionClient.getAxiosInstance().defaults.baseURL; + if (!url) { + throw new Error("You are not logged in"); + } + return url; } /** @@ -66,7 +79,6 @@ export class Commands { */ public async login(args?: { url?: string; - token?: string; label?: string; autoLogin?: boolean; }): Promise { @@ -75,44 +87,49 @@ export class Commands { } this.logger.info("Logging in"); - const url = await maybeAskUrl(this.mementoManager, args?.url); + const currentDeployment = await this.secretsManager.getCurrentDeployment(); + const url = await maybeAskUrl( + this.mementoManager, + args?.url, + currentDeployment?.url, + ); if (!url) { - return; // The user aborted. + return; } // It is possible that we are trying to log into an old-style host, in which // case we want to write with the provided blank label instead of generating // a host label. - const label = args?.label === undefined ? toSafeHost(url) : args.label; + const label = args?.label ?? toSafeHost(url); + this.logger.info("Using deployment label", label); + + const result = await this.loginCoordinator.promptForLogin({ + label, + url, + autoLogin: args?.autoLogin, + oauthSessionManager: this.oauthSessionManager, + }); - // Try to get a token from the user, if we need one, and their user. - const autoLogin = args?.autoLogin === true; - const res = await this.maybeAskToken(url, args?.token, autoLogin); - if (!res) { - return; // The user aborted, or unable to auth. + if (!result.success || !result.user || !result.token) { + return; } - // The URL is good and the token is either good or not required; authorize - // the global client. - this.restClient.setHost(url); - this.restClient.setSessionToken(res.token); - - // Store these to be used in later sessions. - await this.mementoManager.setUrl(url); - await this.secretsManager.setSessionToken(res.token); + // Set client immediately so subsequent operations in this function have the correct host/token. + // The cross-window listener will also update the client, but that's async. + this.extensionClient.setCredentials(url, result.token); - // Store on disk to be used by the cli. - await this.cliManager.configure(label, url, res.token); + // Set as current deployment + await this.secretsManager.setCurrentDeployment({ url, label }); - // These contexts control various menu items and the sidebar. + // Update contexts this.contextManager.set("coder.authenticated", true); - if (res.user.roles.find((role) => role.name === "owner")) { + if (result.user.roles.some((role) => role.name === "owner")) { this.contextManager.set("coder.isOwner", true); } vscode.window .showInformationMessage( - `Welcome to Coder, ${res.user.username}!`, + `Welcome to Coder, ${result.user.username}!`, { detail: "You can now use the Coder extension to manage your Coder instance.", @@ -124,101 +141,6 @@ export class Commands { vscode.commands.executeCommand("coder.open"); } }); - - await this.secretsManager.triggerLoginStateChange("login"); - // Fetch workspaces for the new deployment. - vscode.commands.executeCommand("coder.refreshWorkspaces"); - } - - /** - * If necessary, ask for a token, and keep asking until the token has been - * validated. Return the token and user that was fetched to validate the - * token. Null means the user aborted or we were unable to authenticate with - * mTLS (in the latter case, an error notification will have been displayed). - */ - private async maybeAskToken( - url: string, - token: string | undefined, - isAutoLogin: boolean, - ): Promise<{ user: User; token: string } | null> { - const client = CoderApi.create(url, token, this.logger); - const needsToken = needToken(vscode.workspace.getConfiguration()); - if (!needsToken || token) { - try { - const user = await client.getAuthenticatedUser(); - // For non-token auth, we write a blank token since the `vscodessh` - // command currently always requires a token file. - // For token auth, we have valid access so we can just return the user here - return { token: needsToken && token ? token : "", user }; - } catch (err) { - const message = getErrorMessage(err, "no response from the server"); - if (isAutoLogin) { - this.logger.warn("Failed to log in to Coder server:", message); - } else { - this.vscodeProposed.window.showErrorMessage( - "Failed to log in to Coder server", - { - detail: message, - modal: true, - useCustom: true, - }, - ); - } - // Invalid certificate, most likely. - return null; - } - } - - // This prompt is for convenience; do not error if they close it since - // they may already have a token or already have the page opened. - await vscode.env.openExternal(vscode.Uri.parse(`${url}/cli-auth`)); - - // For token auth, start with the existing token in the prompt or the last - // used token. Once submitted, if there is a failure we will keep asking - // the user for a new token until they quit. - let user: User | undefined; - const validatedToken = await vscode.window.showInputBox({ - title: "Coder API Key", - password: true, - placeHolder: "Paste your API key.", - value: token || (await this.secretsManager.getSessionToken()), - ignoreFocusOut: true, - validateInput: async (value) => { - if (!value) { - return null; - } - client.setSessionToken(value); - try { - user = await client.getAuthenticatedUser(); - } catch (err) { - // For certificate errors show both a notification and add to the - // text under the input box, since users sometimes miss the - // notification. - if (err instanceof CertificateError) { - err.showNotification(); - - return { - message: err.x509Err || err.message, - severity: vscode.InputBoxValidationSeverity.Error, - }; - } - // This could be something like the header command erroring or an - // invalid session token. - const message = getErrorMessage(err, "no response from the server"); - return { - message: "Failed to authenticate: " + message, - severity: vscode.InputBoxValidationSeverity.Error, - }; - } - }, - }); - - if (validatedToken && user) { - return { token: validatedToken, user }; - } - - // User aborted. - return null; } /** @@ -250,27 +172,25 @@ export class Commands { * Log out from the currently logged-in deployment. */ public async logout(): Promise { - const url = this.mementoManager.getUrl(); - if (!url) { - // Sanity check; command should not be available if no url. - throw new Error("You are not logged in"); - } - await this.forceLogout(); - } - - public async forceLogout(): Promise { + const baseUrl = this.requireExtensionBaseUrl(); if (!this.contextManager.get("coder.authenticated")) { return; } - this.logger.info("Logging out"); + + const label = toSafeHost(baseUrl); + this.logger.info(`Logging out of deployment: ${label}`); + + // Fire and forget OAuth logout + this.oauthSessionManager.logout().catch((error) => { + this.logger.warn("OAuth logout failed, continuing with cleanup:", error); + }); + // Clear from the REST client. An empty url will indicate to other parts of // the code that we are logged out. - this.restClient.setHost(""); - this.restClient.setSessionToken(""); + this.extensionClient.setCredentials(undefined, undefined); - // Clear from memory. - await this.mementoManager.setUrl(undefined); - await this.secretsManager.setSessionToken(undefined); + // Clear current deployment (triggers cross-window sync) + await this.secretsManager.setCurrentDeployment(undefined); this.contextManager.set("coder.authenticated", false); vscode.window @@ -280,10 +200,6 @@ export class Commands { this.login(); } }); - - await this.secretsManager.triggerLoginStateChange("logout"); - // This will result in clearing the workspace list. - vscode.commands.executeCommand("coder.refreshWorkspaces"); } /** @@ -292,7 +208,8 @@ export class Commands { * Must only be called if currently logged in. */ public async createWorkspace(): Promise { - const uri = this.mementoManager.getUrl() + "/templates"; + const baseUrl = this.requireExtensionBaseUrl(); + const uri = baseUrl + "/templates"; await vscode.commands.executeCommand("vscode.open", uri); } @@ -306,12 +223,13 @@ export class Commands { */ public async navigateToWorkspace(item: OpenableTreeItem) { if (item) { + const baseUrl = this.requireExtensionBaseUrl(); const workspaceId = createWorkspaceIdentifier(item.workspace); - const uri = this.mementoManager.getUrl() + `/@${workspaceId}`; + const uri = baseUrl + `/@${workspaceId}`; await vscode.commands.executeCommand("vscode.open", uri); - } else if (this.workspace && this.workspaceRestClient) { + } else if (this.workspace && this.remoteWorkspaceClient) { const baseUrl = - this.workspaceRestClient.getAxiosInstance().defaults.baseURL; + this.remoteWorkspaceClient.getAxiosInstance().defaults.baseURL; const uri = `${baseUrl}/@${createWorkspaceIdentifier(this.workspace)}`; await vscode.commands.executeCommand("vscode.open", uri); } else { @@ -329,12 +247,13 @@ export class Commands { */ public async navigateToWorkspaceSettings(item: OpenableTreeItem) { if (item) { + const baseUrl = this.requireExtensionBaseUrl(); const workspaceId = createWorkspaceIdentifier(item.workspace); - const uri = this.mementoManager.getUrl() + `/@${workspaceId}/settings`; + const uri = baseUrl + `/@${workspaceId}/settings`; await vscode.commands.executeCommand("vscode.open", uri); - } else if (this.workspace && this.workspaceRestClient) { + } else if (this.workspace && this.remoteWorkspaceClient) { const baseUrl = - this.workspaceRestClient.getAxiosInstance().defaults.baseURL; + this.remoteWorkspaceClient.getAxiosInstance().defaults.baseURL; const uri = `${baseUrl}/@${createWorkspaceIdentifier(this.workspace)}/settings`; await vscode.commands.executeCommand("vscode.open", uri); } else { @@ -352,7 +271,7 @@ export class Commands { */ public async openFromSidebar(item: OpenableTreeItem) { if (item) { - const baseUrl = this.restClient.getAxiosInstance().defaults.baseURL; + const baseUrl = this.extensionClient.getAxiosInstance().defaults.baseURL; if (!baseUrl) { throw new Error("You are not logged in"); } @@ -379,7 +298,7 @@ export class Commands { true, ); } else { - throw new Error("Unable to open unknown sidebar item"); + throw new TypeError("Unable to open unknown sidebar item"); } } else { // If there is no tree item, then the user manually ran this command. @@ -407,25 +326,20 @@ export class Commands { const terminal = vscode.window.createTerminal(app.name); // If workspace_name is provided, run coder ssh before the command - - const url = this.mementoManager.getUrl(); - if (!url) { - throw new Error("No coder url found for sidebar"); - } + const baseUrl = this.requireExtensionBaseUrl(); + const label = toSafeHost(baseUrl); const binary = await this.cliManager.fetchBinary( - this.restClient, - toSafeHost(url), + this.extensionClient, + label, ); - const configDir = this.pathResolver.getGlobalConfigDir( - toSafeHost(url), - ); + const configDir = this.pathResolver.getGlobalConfigDir(label); const globalFlags = getGlobalFlags( vscode.workspace.getConfiguration(), configDir, ); terminal.sendText( - `${escapeCommandArg(binary)}${` ${globalFlags.join(" ")}`} ssh ${app.workspace_name}`, + `${escapeCommandArg(binary)} ${globalFlags.join(" ")} ssh ${app.workspace_name}`, ); await new Promise((resolve) => setTimeout(resolve, 5000)); terminal.sendText(app.command ?? ""); @@ -433,19 +347,6 @@ export class Commands { }, ); } - // Check if app has a URL to open - if (app.url) { - return vscode.window.withProgress( - { - location: vscode.ProgressLocation.Notification, - title: `Opening ${app.name || "application"} in browser...`, - cancellable: false, - }, - async () => { - await vscode.env.openExternal(vscode.Uri.parse(app.url!)); - }, - ); - } // If no URL or command, show information about the app status vscode.window.showInformationMessage(`${app.name}`, { @@ -469,14 +370,14 @@ export class Commands { folderPath?: string, openRecent?: boolean, ): Promise { - const baseUrl = this.restClient.getAxiosInstance().defaults.baseURL; + const baseUrl = this.extensionClient.getAxiosInstance().defaults.baseURL; if (!baseUrl) { throw new Error("You are not logged in"); } let workspace: Workspace | undefined; if (workspaceOwner && workspaceName) { - workspace = await this.restClient.getWorkspaceByOwnerAndName( + workspace = await this.extensionClient.getWorkspaceByOwnerAndName( workspaceOwner, workspaceName, ); @@ -512,7 +413,7 @@ export class Commands { localWorkspaceFolder: string = "", localConfigFile: string = "", ): Promise { - const baseUrl = this.restClient.getAxiosInstance().defaults.baseURL; + const baseUrl = this.extensionClient.getAxiosInstance().defaults.baseURL; if (!baseUrl) { throw new Error("You are not logged in"); } @@ -524,7 +425,7 @@ export class Commands { workspaceAgent, ); - const hostPath = localWorkspaceFolder ? localWorkspaceFolder : undefined; + const hostPath = localWorkspaceFolder || undefined; const configFile = hostPath && localConfigFile ? { @@ -568,7 +469,7 @@ export class Commands { * this is a no-op. */ public async updateWorkspace(): Promise { - if (!this.workspace || !this.workspaceRestClient) { + if (!this.workspace || !this.remoteWorkspaceClient) { return; } const action = await this.vscodeProposed.window.showWarningMessage( @@ -581,7 +482,7 @@ export class Commands { "Update", ); if (action === "Update") { - await this.workspaceRestClient.updateWorkspaceVersion(this.workspace); + await this.remoteWorkspaceClient.updateWorkspaceVersion(this.workspace); } } @@ -596,7 +497,7 @@ export class Commands { let lastWorkspaces: readonly Workspace[]; quickPick.onDidChangeValue((value) => { quickPick.busy = true; - this.restClient + this.extensionClient .getWorkspaces({ q: value, }) @@ -625,7 +526,6 @@ export class Commands { if (ex instanceof CertificateError) { ex.showNotification(); } - return; }); }); quickPick.show(); @@ -660,7 +560,7 @@ export class Commands { // we need to fetch the agents through the resources API, as the // workspaces query does not include agents when off. this.logger.info("Fetching agents from template version"); - const resources = await this.restClient.getTemplateVersionResources( + const resources = await this.extensionClient.getTemplateVersionResources( workspace.latest_build.template_version_id, ); return extractAgents(resources); diff --git a/src/core/cliManager.ts b/src/core/cliManager.ts index 5e0b3d26..64ebc99e 100644 --- a/src/core/cliManager.ts +++ b/src/core/cliManager.ts @@ -721,8 +721,7 @@ export class CliManager { ): Promise { if (url) { const urlPath = this.pathResolver.getUrlPath(label); - await fs.mkdir(path.dirname(urlPath), { recursive: true }); - await fs.writeFile(urlPath, url); + await this.atomicWriteFile(urlPath, url); } } @@ -739,30 +738,27 @@ export class CliManager { ) { if (token !== null) { const tokenPath = this.pathResolver.getSessionTokenPath(label); - await fs.mkdir(path.dirname(tokenPath), { recursive: true }); - await fs.writeFile(tokenPath, token ?? ""); + await this.atomicWriteFile(tokenPath, token ?? ""); } } /** - * Read the CLI config for a deployment with the provided label. - * - * IF a config file does not exist, return an empty string. - * - * If the label is empty, read the old deployment-unaware config. + * Atomically write content to a file by writing to a temporary file first, + * then renaming it. */ - public async readConfig( - label: string, - ): Promise<{ url: string; token: string }> { - const urlPath = this.pathResolver.getUrlPath(label); - const tokenPath = this.pathResolver.getSessionTokenPath(label); - const [url, token] = await Promise.allSettled([ - fs.readFile(urlPath, "utf8"), - fs.readFile(tokenPath, "utf8"), - ]); - return { - url: url.status === "fulfilled" ? url.value.trim() : "", - token: token.status === "fulfilled" ? token.value.trim() : "", - }; + private async atomicWriteFile( + filePath: string, + content: string, + ): Promise { + await fs.mkdir(path.dirname(filePath), { recursive: true }); + const tempPath = + filePath + ".temp-" + Math.random().toString(36).substring(8); + try { + await fs.writeFile(tempPath, content); + await fs.rename(tempPath, filePath); + } catch (err) { + await fs.rm(tempPath, { force: true }).catch(() => {}); + throw err; + } } } diff --git a/src/core/container.ts b/src/core/container.ts index a8f938ea..f140f628 100644 --- a/src/core/container.ts +++ b/src/core/container.ts @@ -1,6 +1,7 @@ import * as vscode from "vscode"; import { type Logger } from "../logging/logger"; +import { LoginCoordinator } from "../login/loginCoordinator"; import { CliManager } from "./cliManager"; import { ContextManager } from "./contextManager"; @@ -19,6 +20,7 @@ export class ServiceContainer implements vscode.Disposable { private readonly secretsManager: SecretsManager; private readonly cliManager: CliManager; private readonly contextManager: ContextManager; + private readonly loginCoordinator: LoginCoordinator; constructor( context: vscode.ExtensionContext, @@ -30,13 +32,22 @@ export class ServiceContainer implements vscode.Disposable { context.logUri.fsPath, ); this.mementoManager = new MementoManager(context.globalState); - this.secretsManager = new SecretsManager(context.secrets); + this.secretsManager = new SecretsManager( + context.secrets, + context.globalState, + ); this.cliManager = new CliManager( this.vscodeProposed, this.logger, this.pathResolver, ); - this.contextManager = new ContextManager(); + this.contextManager = new ContextManager(context); + this.loginCoordinator = new LoginCoordinator( + this.secretsManager, + this.mementoManager, + this.vscodeProposed, + this.logger, + ); } getVsCodeProposed(): typeof vscode { @@ -67,6 +78,10 @@ export class ServiceContainer implements vscode.Disposable { return this.contextManager; } + getLoginCoordinator(): LoginCoordinator { + return this.loginCoordinator; + } + /** * Dispose of all services and clean up resources. */ diff --git a/src/core/contextManager.ts b/src/core/contextManager.ts index a5a18397..9a0f3d00 100644 --- a/src/core/contextManager.ts +++ b/src/core/contextManager.ts @@ -5,6 +5,7 @@ const CONTEXT_DEFAULTS = { "coder.isOwner": false, "coder.loaded": false, "coder.workspace.updatable": false, + "coder.devMode": false, } as const; type CoderContext = keyof typeof CONTEXT_DEFAULTS; @@ -12,10 +13,14 @@ type CoderContext = keyof typeof CONTEXT_DEFAULTS; export class ContextManager implements vscode.Disposable { private readonly context = new Map(); - public constructor() { - (Object.keys(CONTEXT_DEFAULTS) as CoderContext[]).forEach((key) => { + public constructor(extensionContext: vscode.ExtensionContext) { + for (const key of Object.keys(CONTEXT_DEFAULTS) as CoderContext[]) { this.set(key, CONTEXT_DEFAULTS[key]); - }); + } + this.set( + "coder.devMode", + extensionContext.extensionMode === vscode.ExtensionMode.Development, + ); } public set(key: CoderContext, value: boolean): void { diff --git a/src/core/deployment.ts b/src/core/deployment.ts new file mode 100644 index 00000000..a29c07ae --- /dev/null +++ b/src/core/deployment.ts @@ -0,0 +1,9 @@ +/** + * Represents a Coder deployment with its URL and label. + * The label is used as a unique identifier for storing credentials and configuration. + * It may be derived from the URL hostname (via toSafeHost) or come from SSH host parsing. + */ +export interface Deployment { + readonly url: string; + readonly label: string; +} diff --git a/src/core/mementoManager.ts b/src/core/mementoManager.ts index f79be46c..3cf4478e 100644 --- a/src/core/mementoManager.ts +++ b/src/core/mementoManager.ts @@ -7,27 +7,16 @@ export class MementoManager { constructor(private readonly memento: Memento) {} /** - * Add the URL to the list of recently accessed URLs in global storage, then - * set it as the last used URL. - * - * If the URL is falsey, then remove it as the last used URL and do not touch - * the history. + * Add a URL to the history of recently accessed URLs. + * Used by the URL picker to show recent deployments. */ - public async setUrl(url?: string): Promise { - await this.memento.update("url", url); + public async addToUrlHistory(url: string): Promise { if (url) { const history = this.withUrlHistory(url); await this.memento.update("urlHistory", history); } } - /** - * Get the last used URL. - */ - public getUrl(): string | undefined { - return this.memento.get("url"); - } - /** * Get the most recently accessed URLs (oldest to newest) with the provided * values appended. Duplicates will be removed. diff --git a/src/core/secretsManager.ts b/src/core/secretsManager.ts index 94827b15..e1e9411b 100644 --- a/src/core/secretsManager.ts +++ b/src/core/secretsManager.ts @@ -1,73 +1,338 @@ -import type { SecretStorage, Disposable } from "vscode"; +import { toSafeHost } from "../util"; -const SESSION_TOKEN_KEY = "sessionToken"; +import type { Memento, SecretStorage, Disposable } from "vscode"; -const LOGIN_STATE_KEY = "loginState"; +import type { TokenResponse, ClientRegistrationResponse } from "../oauth/types"; -export enum AuthAction { - LOGIN, - LOGOUT, - INVALID, +import type { Deployment } from "./deployment"; + +const SESSION_KEY_PREFIX = "coder.session."; +const OAUTH_TOKENS_PREFIX = "coder.oauth.tokens."; +const OAUTH_CLIENT_PREFIX = "coder.oauth.client."; + +const CURRENT_DEPLOYMENT_KEY = "coder.currentDeployment"; +const OAUTH_CALLBACK_KEY = "coder.oauthCallback"; + +const DEPLOYMENT_USAGE_KEY = "coder.deploymentUsage"; +const DEFAULT_MAX_DEPLOYMENTS = 10; + +const LEGACY_SESSION_TOKEN_KEY = "sessionToken"; + +export interface DeploymentUsage { + label: string; + lastAccessedAt: string; +} + +export type StoredOAuthTokens = Omit & { + expiry_timestamp: number; + deployment_url: string; +}; + +export interface SessionAuth { + url: string; + token: string; +} + +interface OAuthCallbackData { + state: string; + code: string | null; + error: string | null; +} + +export interface CurrentDeploymentState { + deployment: Deployment | null; } export class SecretsManager { - constructor(private readonly secrets: SecretStorage) {} + constructor( + private readonly secrets: SecretStorage, + private readonly memento: Memento, + ) {} /** - * Set or unset the last used token. + * Sets the current deployment and triggers a cross-window sync event. + * This is the single source of truth for which deployment is currently active. */ - public async setSessionToken(sessionToken?: string): Promise { - if (!sessionToken) { - await this.secrets.delete(SESSION_TOKEN_KEY); - } else { - await this.secrets.store(SESSION_TOKEN_KEY, sessionToken); - } + public async setCurrentDeployment( + deployment: Deployment | undefined, + ): Promise { + const state = { + deployment: deployment ?? null, + timestamp: new Date().toISOString(), + }; + await this.secrets.store(CURRENT_DEPLOYMENT_KEY, JSON.stringify(state)); } /** - * Get the last used token. + * Gets the current deployment from storage. */ - public async getSessionToken(): Promise { + public async getCurrentDeployment(): Promise { try { - return await this.secrets.get(SESSION_TOKEN_KEY); + const data = await this.secrets.get(CURRENT_DEPLOYMENT_KEY); + if (!data) { + return undefined; + } + const parsed = JSON.parse(data) as { deployment: Deployment | null }; + return parsed.deployment ?? undefined; } catch { - // The VS Code session store has become corrupt before, and - // will fail to get the session token... return undefined; } } /** - * Triggers a login/logout event that propagates across all VS Code windows. - * Uses the secrets storage onDidChange event as a cross-window communication mechanism. - * Appends a timestamp to ensure the value always changes, guaranteeing the event fires. + * Listens for deployment changes from any VS Code window. + * Fires when login, logout, or deployment switch occurs. */ - public async triggerLoginStateChange( - action: "login" | "logout", - ): Promise { - const date = new Date().toISOString(); - await this.secrets.store(LOGIN_STATE_KEY, `${action}-${date}`); + public onDidChangeCurrentDeployment( + listener: (state: CurrentDeploymentState) => void | Promise, + ): Disposable { + return this.secrets.onDidChange(async (e) => { + if (e.key !== CURRENT_DEPLOYMENT_KEY) { + return; + } + + try { + const data = await this.secrets.get(CURRENT_DEPLOYMENT_KEY); + if (data) { + const parsed = JSON.parse(data) as { + deployment: Deployment | null; + }; + await listener({ deployment: parsed.deployment }); + } + } catch { + // Ignore parse errors + } + }); } /** - * Listens for login/logout events from any VS Code window. - * The secrets storage onDidChange event fires across all windows, enabling cross-window sync. + * Write an OAuth callback result to secrets storage. + * Used for cross-window communication when OAuth callback arrives in a different window. */ - public onDidChangeLoginState( - listener: (state: AuthAction) => Promise, + public async setOAuthCallback(data: OAuthCallbackData): Promise { + await this.secrets.store(OAUTH_CALLBACK_KEY, JSON.stringify(data)); + } + + /** + * Listen for OAuth callback results from any VS Code window. + * The listener receives the state parameter, code (if success), and error (if failed). + */ + public onDidChangeOAuthCallback( + listener: (data: OAuthCallbackData) => void, ): Disposable { return this.secrets.onDidChange(async (e) => { - if (e.key === LOGIN_STATE_KEY) { - const state = await this.secrets.get(LOGIN_STATE_KEY); - if (state?.startsWith("login")) { - listener(AuthAction.LOGIN); - } else if (state?.startsWith("logout")) { - listener(AuthAction.LOGOUT); - } else { - // Secret was deleted or is invalid - listener(AuthAction.INVALID); + if (e.key !== OAUTH_CALLBACK_KEY) { + return; + } + + try { + const data = await this.secrets.get(OAUTH_CALLBACK_KEY); + if (data) { + const parsed = JSON.parse(data) as OAuthCallbackData; + listener(parsed); } + } catch { + // Ignore parse errors } }); } + + /** + * Listen for changes to a specific deployment's session auth. + */ + public onDidChangeSessionAuth( + label: string, + listener: (auth: SessionAuth | undefined) => void | Promise, + ): Disposable { + const key = `${SESSION_KEY_PREFIX}${label}`; + return this.secrets.onDidChange(async (e) => { + if (e.key !== key) { + return; + } + const auth = await this.getSessionAuth(label); + await listener(auth); + }); + } + + public async getSessionAuth(label: string): Promise { + if (!label) { + return undefined; + } + + try { + const data = await this.secrets.get(`${SESSION_KEY_PREFIX}${label}`); + if (!data) { + return undefined; + } + return JSON.parse(data) as SessionAuth; + } catch { + return undefined; + } + } + + public async getSessionToken(label: string): Promise { + const auth = await this.getSessionAuth(label); + return auth?.token; + } + + public async getUrl(label: string): Promise { + const auth = await this.getSessionAuth(label); + return auth?.url; + } + + public async setSessionAuth(label: string, auth: SessionAuth): Promise { + await this.secrets.store( + `${SESSION_KEY_PREFIX}${label}`, + JSON.stringify(auth), + ); + await this.recordDeploymentAccess(label); + } + + public async clearSessionAuth(label: string): Promise { + await this.secrets.delete(`${SESSION_KEY_PREFIX}${label}`); + } + + public async getOAuthTokens( + label: string, + ): Promise { + try { + const data = await this.secrets.get(`${OAUTH_TOKENS_PREFIX}${label}`); + if (!data) { + return undefined; + } + return JSON.parse(data) as StoredOAuthTokens; + } catch { + return undefined; + } + } + + public async setOAuthTokens( + label: string, + tokens: StoredOAuthTokens, + ): Promise { + await this.secrets.store( + `${OAUTH_TOKENS_PREFIX}${label}`, + JSON.stringify(tokens), + ); + await this.recordDeploymentAccess(label); + } + + public async clearOAuthTokens(label: string): Promise { + await this.secrets.delete(`${OAUTH_TOKENS_PREFIX}${label}`); + } + + public async getOAuthClientRegistration( + label: string, + ): Promise { + try { + const data = await this.secrets.get(`${OAUTH_CLIENT_PREFIX}${label}`); + if (!data) { + return undefined; + } + return JSON.parse(data) as ClientRegistrationResponse; + } catch { + return undefined; + } + } + + public async setOAuthClientRegistration( + label: string, + registration: ClientRegistrationResponse, + ): Promise { + await this.secrets.store( + `${OAUTH_CLIENT_PREFIX}${label}`, + JSON.stringify(registration), + ); + await this.recordDeploymentAccess(label); + } + + public async clearOAuthClientRegistration(label: string): Promise { + await this.secrets.delete(`${OAUTH_CLIENT_PREFIX}${label}`); + } + + public async clearOAuthData(label: string): Promise { + await Promise.all([ + this.clearOAuthTokens(label), + this.clearOAuthClientRegistration(label), + ]); + } + + /** + * Record that a deployment was accessed, moving it to the front of the LRU list. + * Prunes deployments beyond maxCount, clearing their auth data. + */ + public async recordDeploymentAccess( + label: string, + maxCount = DEFAULT_MAX_DEPLOYMENTS, + ): Promise { + const usage = this.getDeploymentUsage(); + const filtered = usage.filter((u) => u.label !== label); + filtered.unshift({ label, lastAccessedAt: new Date().toISOString() }); + + const toKeep = filtered.slice(0, maxCount); + const toRemove = filtered.slice(maxCount); + + await Promise.all(toRemove.map((u) => this.clearAllAuthData(u.label))); + await this.memento.update(DEPLOYMENT_USAGE_KEY, toKeep); + } + + /** + * Clear all auth data for a deployment and remove it from the usage list. + */ + public async clearAllAuthData(label: string): Promise { + await Promise.all([ + this.clearSessionAuth(label), + this.clearOAuthData(label), + ]); + const usage = this.getDeploymentUsage().filter((u) => u.label !== label); + await this.memento.update(DEPLOYMENT_USAGE_KEY, usage); + } + + /** + * Get all known deployment labels, ordered by most recently accessed. + */ + public getKnownLabels(): string[] { + return this.getDeploymentUsage().map((u) => u.label); + } + + /** + * Get the full deployment usage list with access timestamps. + */ + private getDeploymentUsage(): DeploymentUsage[] { + return this.memento.get(DEPLOYMENT_USAGE_KEY) ?? []; + } + + /** + * Migrate from legacy flat sessionToken storage to new format. + * Also sets the current deployment if none exists. + */ + public async migrateFromLegacyStorage(): Promise { + const legacyUrl = this.memento.get("url"); + if (!legacyUrl) { + return undefined; + } + + const label = toSafeHost(legacyUrl); + + const existing = await this.getSessionAuth(label); + if (existing) { + return undefined; + } + + const oldToken = await this.secrets.get(LEGACY_SESSION_TOKEN_KEY); + if (!oldToken) { + return undefined; + } + + await this.setSessionAuth(label, { url: legacyUrl, token: oldToken }); + await this.secrets.delete(LEGACY_SESSION_TOKEN_KEY); + + // Also set as current deployment if none exists + const currentDeployment = await this.getCurrentDeployment(); + if (!currentDeployment) { + await this.setCurrentDeployment({ url: legacyUrl, label }); + } + + return label; + } } diff --git a/src/extension.ts b/src/extension.ts index 974cbe7d..093fc5e7 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -8,11 +8,15 @@ import * as vscode from "vscode"; import { errToStr } from "./api/api-helper"; import { CoderApi } from "./api/coderApi"; +import { attachOAuthInterceptors } from "./api/oauthInterceptors"; import { needToken } from "./api/utils"; import { Commands } from "./commands"; import { ServiceContainer } from "./core/container"; -import { AuthAction } from "./core/secretsManager"; +import { type Deployment } from "./core/deployment"; +import { type SecretsManager } from "./core/secretsManager"; import { CertificateError, getErrorDetail } from "./error"; +import { OAuthSessionManager } from "./oauth/sessionManager"; +import { CALLBACK_PATH } from "./oauth/utils"; import { maybeAskUrl } from "./promptUtils"; import { Remote } from "./remote/remote"; import { getRemoteSshExtension } from "./remote/sshExtension"; @@ -60,18 +64,32 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { const secretsManager = serviceContainer.getSecretsManager(); const contextManager = serviceContainer.getContextManager(); + // Migrate auth storage from old flat format to new label-based format + await migrateAuthStorage(serviceContainer); + // Try to clear this flag ASAP const isFirstConnect = await mementoManager.getAndClearFirstConnect(); + const deployment = await secretsManager.getCurrentDeployment(); + + // Create OAuth session manager with login coordinator + const oauthSessionManager = await OAuthSessionManager.create( + deployment, + serviceContainer, + ctx.extension.id, + ); + ctx.subscriptions.push(oauthSessionManager); + // This client tracks the current login and will be used through the life of // the plugin to poll workspaces for the current login, as well as being used // in commands that operate on the current login. - const url = mementoManager.getUrl(); const client = CoderApi.create( - url || "", - await secretsManager.getSessionToken(), + deployment?.url || "", + await secretsManager.getSessionToken(deployment?.label ?? ""), output, ); + ctx.subscriptions.push(client); + attachOAuthInterceptors(client, output, oauthSessionManager); const myWorkspacesProvider = new WorkspaceProvider( WorkspaceQuery.Mine, @@ -116,11 +134,95 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { ctx.subscriptions, ); + // Listen for deployment auth changes (token updates) for the current deployment + // This listener is re-registered when the user logs into a different deployment + let authChangeDisposable: vscode.Disposable | undefined; + const registerAuthListener = (deploymentLabel: string | undefined) => { + authChangeDisposable?.dispose(); + + if (!deploymentLabel) { + return; + } + + output.debug("Registering auth listener for deployment", deploymentLabel); + authChangeDisposable = secretsManager.onDidChangeSessionAuth( + deploymentLabel, + (auth) => { + client.setCredentials(auth?.url, auth?.token); + + // Update authentication context for current deployment + // TODO(ehab) this might never even happen :thinking: + contextManager.set("coder.authenticated", auth !== undefined); + }, + ); + }; + + // Initialize auth listener for current deployment + registerAuthListener(deployment?.label); + ctx.subscriptions.push({ dispose: () => authChangeDisposable?.dispose() }); + + const changeDeployment = async ( + deployment: Deployment | null, + sessionToken?: string, + ) => { + // Update client + if (deployment) { + const token = + sessionToken || + (await secretsManager.getSessionToken(deployment.label)); + client.setCredentials(deployment.url, token); + await oauthSessionManager.setDeployment(deployment); + } else { + client.setCredentials(undefined, undefined); + oauthSessionManager.clearDeployment(); + } + registerAuthListener(deployment?.label); + + // Update context + contextManager.set("coder.authenticated", Boolean(deployment)); + + // Refresh workspaces + myWorkspacesProvider.fetchAndRefresh(); + allWorkspacesProvider.fetchAndRefresh(); + }; + + const changeDeploymentAndPersist = async ( + deployment: Deployment | null, + sessionToken?: string, + ) => { + await changeDeployment(deployment, sessionToken); + // Persist and sync deployment across windows + await secretsManager.setCurrentDeployment(deployment ?? undefined); + await mementoManager.addToUrlHistory(deployment?.url ?? ""); + }; + + // Listen for deployment changes from other windows (cross-window sync) + ctx.subscriptions.push( + secretsManager.onDidChangeCurrentDeployment(async ({ deployment }) => { + const isLoggedIn = contextManager.get("coder.authenticated"); + if (isLoggedIn) { + // We keep whatever deployment we have if we're logged in + return; + } + + output.info("Deployment changed from another window"); + return changeDeployment(deployment); + }), + ); + // Handle vscode:// URIs. const uriHandler = vscode.window.registerUriHandler({ handleUri: async (uri) => { - const cliManager = serviceContainer.getCliManager(); const params = new URLSearchParams(uri.query); + + if (uri.path === CALLBACK_PATH) { + const code = params.get("code"); + const state = params.get("state"); + const error = params.get("error"); + await oauthSessionManager.handleCallback(code, state, error); + return; + } + if (uri.path === "/open") { const owner = params.get("owner"); const workspace = params.get("workspace"); @@ -137,42 +239,11 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { throw new Error("workspace must be specified as a query parameter"); } - // We are not guaranteed that the URL we currently have is for the URL - // this workspace belongs to, or that we even have a URL at all (the - // queries will default to localhost) so ask for it if missing. - // Pre-populate in case we do have the right URL so the user can just - // hit enter and move on. - const url = await maybeAskUrl( - mementoManager, - params.get("url"), - mementoManager.getUrl(), + await setupDeploymentFromUri( + params, + serviceContainer, + changeDeploymentAndPersist, ); - if (url) { - client.setHost(url); - await mementoManager.setUrl(url); - } else { - throw new Error( - "url must be provided or specified as a query parameter", - ); - } - - // If the token is missing we will get a 401 later and the user will be - // prompted to sign in again, so we do not need to ensure it is set now. - // For non-token auth, we write a blank token since the `vscodessh` - // command currently always requires a token file. However, if there is - // a query parameter for non-token auth go ahead and use it anyway; all - // that really matters is the file is created. - const token = needToken(vscode.workspace.getConfiguration()) - ? params.get("token") - : (params.get("token") ?? ""); - - if (token) { - client.setSessionToken(token); - await secretsManager.setSessionToken(token); - } - - // Store on disk to be used by the cli. - await cliManager.configure(toSafeHost(url), url, token); vscode.commands.executeCommand( "coder.open", @@ -221,37 +292,11 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { ); } - // We are not guaranteed that the URL we currently have is for the URL - // this workspace belongs to, or that we even have a URL at all (the - // queries will default to localhost) so ask for it if missing. - // Pre-populate in case we do have the right URL so the user can just - // hit enter and move on. - const url = await maybeAskUrl( - mementoManager, - params.get("url"), - mementoManager.getUrl(), + await setupDeploymentFromUri( + params, + serviceContainer, + changeDeploymentAndPersist, ); - if (url) { - client.setHost(url); - await mementoManager.setUrl(url); - } else { - throw new Error( - "url must be provided or specified as a query parameter", - ); - } - - // If the token is missing we will get a 401 later and the user will be - // prompted to sign in again, so we do not need to ensure it is set now. - // For non-token auth, we write a blank token since the `vscodessh` - // command currently always requires a token file. However, if there is - // a query parameter for non-token auth go ahead and use it anyway; all - // that really matters is the file is created. - const token = needToken(vscode.workspace.getConfiguration()) - ? params.get("token") - : (params.get("token") ?? ""); - - // Store on disk to be used by the cli. - await cliManager.configure(toSafeHost(url), url, token); vscode.commands.executeCommand( "coder.openDevContainer", @@ -272,7 +317,7 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { // Register globally available commands. Many of these have visibility // controlled by contexts, see `when` in the package.json. - const commands = new Commands(serviceContainer, client); + const commands = new Commands(serviceContainer, client, oauthSessionManager); ctx.subscriptions.push( vscode.commands.registerCommand( "coder.login", @@ -325,30 +370,12 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { vscode.commands.registerCommand("coder.searchAllWorkspaces", async () => showTreeViewSearch(ALL_WORKSPACES_TREE_ID), ), + vscode.commands.registerCommand("coder.debug.listDeployments", () => + listStoredDeployments(secretsManager), + ), ); - const remote = new Remote(serviceContainer, commands, ctx.extensionMode); - - ctx.subscriptions.push( - secretsManager.onDidChangeLoginState(async (state) => { - switch (state) { - case AuthAction.LOGIN: { - const token = await secretsManager.getSessionToken(); - const url = mementoManager.getUrl(); - // Should login the user directly if the URL+Token are valid - await commands.login({ url, token }); - // Resolve any pending login detection promises - remote.resolveLoginDetected(); - break; - } - case AuthAction.LOGOUT: - await commands.forceLogout(); - break; - case AuthAction.INVALID: - break; - } - }), - ); + const remote = new Remote(serviceContainer, commands, ctx); // Since the "onResolveRemoteAuthority:ssh-remote" activation event exists // in package.json we're able to perform actions before the authority is @@ -368,10 +395,11 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { ); if (details) { ctx.subscriptions.push(details); - // Authenticate the plugin client which is used in the sidebar to display - // workspaces belonging to this deployment. - client.setHost(details.url); - client.setSessionToken(details.token); + + await changeDeploymentAndPersist( + { label: details.label, url: details.url }, + details.token, + ); } } catch (ex) { if (ex instanceof CertificateError) { @@ -460,7 +488,119 @@ export async function activate(ctx: vscode.ExtensionContext): Promise { } } +/** + * Migrates old flat storage (sessionToken) to new label-based map storage. + * This is a one-time operation that runs on extension activation. + */ +async function migrateAuthStorage( + serviceContainer: ServiceContainer, +): Promise { + const secretsManager = serviceContainer.getSecretsManager(); + const output = serviceContainer.getLogger(); + + try { + const migratedLabel = await secretsManager.migrateFromLegacyStorage(); + + if (migratedLabel) { + output.info( + `Successfully migrated auth storage to label-based format (label: ${migratedLabel})`, + ); + } + } catch (error) { + output.error( + `Auth storage migration failed: ${error}. You may need to log in again.`, + ); + } +} + async function showTreeViewSearch(id: string): Promise { await vscode.commands.executeCommand(`${id}.focus`); await vscode.commands.executeCommand("list.find"); } + +/** + * Sets up deployment from URI parameters. Handles URL prompting, client setup, + * and token storage. Throws if user cancels URL input. + * + * Updates the client host/token, auth listener, OAuth manager, context, etc. + * through the `changeDeploymentAndPersist` callback. + */ +async function setupDeploymentFromUri( + params: URLSearchParams, + serviceContainer: ServiceContainer, + changeDeploymentAndPersist: ( + deployment: Deployment | null, + sessionToken?: string, + ) => Promise, +): Promise { + const secretsManager = serviceContainer.getSecretsManager(); + const mementoManager = serviceContainer.getMementoManager(); + const currentDeployment = await secretsManager.getCurrentDeployment(); + + // We are not guaranteed that the URL we currently have is for the URL + // this workspace belongs to, or that we even have a URL at all (the + // queries will default to localhost) so ask for it if missing. + // Pre-populate in case we do have the right URL so the user can just + // hit enter and move on. + const url = await maybeAskUrl( + mementoManager, + params.get("url"), + currentDeployment?.url, + ); + if (!url) { + throw new Error("url must be provided or specified as a query parameter"); + } + + const label = toSafeHost(url); + + // If the token is missing we will get a 401 later and the user will be + // prompted to sign in again, so we do not need to ensure it is set now. + // For non-token auth, we write a blank token since the `vscodessh` + // command currently always requires a token file. However, if there is + // a query parameter for non-token auth go ahead and use it anyway; + const token = await getToken(params, label, secretsManager); + if (token) { + await secretsManager.setSessionAuth(label, { url, token }); + } + + await changeDeploymentAndPersist({ label, url }, token); +} + +async function getToken( + params: URLSearchParams, + label: string, + secretsManager: SecretsManager, +): Promise { + const paramsToken = params.get("token"); + if (paramsToken !== null) { + // Always prefer the passed token if set + return paramsToken; + } + + if (needToken(vscode.workspace.getConfiguration())) { + return await secretsManager.getSessionToken(label); + } + return ""; +} + +async function listStoredDeployments( + secretsManager: SecretsManager, +): Promise { + const labels = secretsManager.getKnownLabels(); + if (labels.length === 0) { + vscode.window.showInformationMessage("No deployments stored."); + return; + } + + const selected = await vscode.window.showQuickPick( + labels.map((label) => ({ label, description: "Click to forget" })), + { placeHolder: "Select a deployment to forget" }, + ); + + if (selected) { + await secretsManager.clearAllAuthData(selected.label); + vscode.window.showInformationMessage( + `Cleared auth data for ${selected.label}`, + ); + } +} diff --git a/src/login/loginCoordinator.ts b/src/login/loginCoordinator.ts new file mode 100644 index 00000000..f170baeb --- /dev/null +++ b/src/login/loginCoordinator.ts @@ -0,0 +1,332 @@ +import { getErrorMessage } from "coder/site/src/api/errors"; +import * as vscode from "vscode"; + +import { CoderApi } from "../api/coderApi"; +import { needToken } from "../api/utils"; +import { type Deployment } from "../core/deployment"; +import { type MementoManager } from "../core/mementoManager"; +import { type SecretsManager } from "../core/secretsManager"; +import { CertificateError } from "../error"; +import { type Logger } from "../logging/logger"; +import { maybeAskAuthMethod, maybeAskUrl } from "../promptUtils"; + +import type { User } from "coder/site/src/api/typesGenerated"; + +import type { OAuthSessionManager } from "../oauth/sessionManager"; + +interface LoginResult { + success: boolean; + user?: User; + token?: string; +} + +interface LoginOptions { + label: string; + url: string | undefined; + oauthSessionManager: OAuthSessionManager; + autoLogin?: boolean; +} + +/** + * Coordinates login prompts across windows and prevents duplicate dialogs. + */ +export class LoginCoordinator { + private readonly inProgressLogins = new Map>(); + + constructor( + private readonly secretsManager: SecretsManager, + private readonly mementoManager: MementoManager, + private readonly vscodeProposed: typeof vscode, + private readonly logger: Logger, + ) {} + + /** + * Direct login - for user-initiated login via commands. + * Stores session auth and URL history on success. + */ + public async promptForLogin( + options: LoginOptions & { url: string }, + ): Promise { + const { label, url, oauthSessionManager } = options; + return this.executeWithGuard(label, async () => { + const result = await this.attemptLogin( + { label, url }, + options.autoLogin ?? false, + oauthSessionManager, + ); + + await this.persistSessionAuth(result, label, url); + + return result; + }); + } + + /** + * Shows dialog then login - for system-initiated auth (remote, OAuth refresh). + */ + public async promptForLoginWithDialog( + options: LoginOptions & { message?: string; detailPrefix?: string }, + ): Promise { + const { label, url, detailPrefix, message, oauthSessionManager } = options; + return this.executeWithGuard(label, () => { + // Show dialog promise + const dialogPromise = this.vscodeProposed.window + .showErrorMessage( + message || "Authentication Required", + { + modal: true, + useCustom: true, + detail: + (detailPrefix || `Authentication needed for ${label}.`) + + "\n\nIf you've already logged in, you may close this dialog.", + }, + "Login", + ) + .then(async (action) => { + if (action === "Login") { + // Proceed with the login flow, handling logging in from another window + const storedUrl = await this.secretsManager.getUrl(label); + const newUrl = await maybeAskUrl( + this.mementoManager, + url, + storedUrl, + ); + if (!newUrl) { + throw new Error("URL must be provided"); + } + + const result = await this.attemptLogin( + { url: newUrl, label }, + false, + oauthSessionManager, + ); + + await this.persistSessionAuth(result, label, newUrl); + + return result; + } else { + // User cancelled + return { success: false }; + } + }); + + // Race between user clicking login and cross-window detection + return Promise.race([dialogPromise, this.waitForCrossWindowLogin(label)]); + }); + } + + private async persistSessionAuth( + result: LoginResult, + label: string, + url: string, + ): Promise { + if (result.success && result.token) { + await this.secretsManager.setSessionAuth(label, { + url, + token: result.token, + }); + await this.mementoManager.addToUrlHistory(url); + } + } + + /** + * Same-window guard wrapper. + */ + private async executeWithGuard( + label: string, + executeFn: () => Promise, + ): Promise { + const existingLogin = this.inProgressLogins.get(label); + if (existingLogin) { + return existingLogin; + } + + const loginPromise = executeFn(); + this.inProgressLogins.set(label, loginPromise); + + try { + return await loginPromise; + } finally { + this.inProgressLogins.delete(label); + } + } + + /** + * Waits for login detected from another window. + */ + private async waitForCrossWindowLogin(label: string): Promise { + return new Promise((resolve) => { + const disposable = this.secretsManager.onDidChangeSessionAuth( + label, + (auth) => { + if (auth?.token) { + disposable.dispose(); + resolve({ success: true, token: auth.token }); + } + }, + ); + }); + } + + /** + * Attempt to authenticate using OAuth, token, or mTLS. If necessary, prompts + * for authentication method and credentials. Returns the token and user upon + * successful authentication. Null means the user aborted or authentication + * failed (in which case an error notification will have been displayed). + */ + private async attemptLogin( + deployment: Deployment, + isAutoLogin: boolean, + oauthSessionManager: OAuthSessionManager, + ): Promise { + const needsToken = needToken(vscode.workspace.getConfiguration()); + const client = CoderApi.create(deployment.url, "", this.logger); + + let storedToken: string | undefined; + if (needsToken) { + storedToken = await this.secretsManager.getSessionToken(deployment.label); + if (storedToken) { + client.setSessionToken(storedToken); + } + } + + // Attempt authentication with current credentials (token or mTLS) + try { + const user = await client.getAuthenticatedUser(); + // Return the token that was used (empty string for mTLS since + // the `vscodessh` command currently always requires a token file) + return { success: true, token: storedToken ?? "", user }; + } catch (err) { + if (needsToken) { + // For token auth: silently continue to prompt for new credentials + } else { + // For mTLS: show error and abort (no credentials to prompt for) + const message = getErrorMessage(err, "no response from the server"); + if (isAutoLogin) { + this.logger.warn("Failed to log in to Coder server:", message); + } else { + this.vscodeProposed.window.showErrorMessage( + "Failed to log in to Coder server", + { + detail: message, + modal: true, + useCustom: true, + }, + ); + } + return { success: false }; + } + } + + const authMethod = await maybeAskAuthMethod(client); + switch (authMethod) { + case "oauth": + return this.loginWithOAuth(client, oauthSessionManager, deployment); + case "legacy": + return this.loginWithToken(client); + case undefined: + return { success: false }; // User aborted + } + } + + /** + * Session token authentication flow. + */ + private async loginWithToken(client: CoderApi): Promise { + const url = client.getAxiosInstance().defaults.baseURL; + if (!url) { + throw new Error("No base URL set on REST client"); + } + // This prompt is for convenience; do not error if they close it since + // they may already have a token or already have the page opened. + await vscode.env.openExternal(vscode.Uri.parse(`${url}/cli-auth`)); + + // For token auth, start with the existing token in the prompt or the last + // used token. Once submitted, if there is a failure we will keep asking + // the user for a new token until they quit. + let user: User | undefined; + const validatedToken = await vscode.window.showInputBox({ + title: "Coder API Key", + password: true, + placeHolder: "Paste your API key.", + ignoreFocusOut: true, + validateInput: async (value) => { + if (!value) { + return null; + } + client.setSessionToken(value); + try { + user = await client.getAuthenticatedUser(); + } catch (err) { + // For certificate errors show both a notification and add to the + // text under the input box, since users sometimes miss the + // notification. + if (err instanceof CertificateError) { + err.showNotification(); + return { + message: err.x509Err || err.message, + severity: vscode.InputBoxValidationSeverity.Error, + }; + } + // This could be something like the header command erroring or an + // invalid session token. + const message = getErrorMessage(err, "no response from the server"); + return { + message: "Failed to authenticate: " + message, + severity: vscode.InputBoxValidationSeverity.Error, + }; + } + }, + }); + + if (user === undefined || validatedToken === undefined) { + return { success: false }; + } + + return { success: true, user, token: validatedToken }; + } + + /** + * OAuth authentication flow. + */ + private async loginWithOAuth( + client: CoderApi, + oauthSessionManager: OAuthSessionManager, + deployment: Deployment, + ): Promise { + try { + this.logger.info("Starting OAuth authentication"); + + const tokenResponse = await vscode.window.withProgress( + { + location: vscode.ProgressLocation.Notification, + title: "Authenticating", + cancellable: false, + }, + async (progress) => + await oauthSessionManager.login(client, deployment, progress), + ); + + // Validate token by fetching user + client.setSessionToken(tokenResponse.access_token); + const user = await client.getAuthenticatedUser(); + + return { + success: true, + token: tokenResponse.access_token, + user, + }; + } catch (error) { + const title = "OAuth authentication failed"; + this.logger.error(title, error); + if (error instanceof CertificateError) { + error.showNotification(title); + } else { + vscode.window.showErrorMessage( + `${title}: ${getErrorMessage(error, "Unknown error")}`, + ); + } + return { success: false }; + } + } +} diff --git a/src/oauth/errors.ts b/src/oauth/errors.ts new file mode 100644 index 00000000..9b7ee3ac --- /dev/null +++ b/src/oauth/errors.ts @@ -0,0 +1,166 @@ +import { isAxiosError } from "axios"; + +import type { OAuthErrorResponse } from "./types"; + +/** + * Base class for OAuth errors + */ +export class OAuthError extends Error { + constructor( + message: string, + public readonly errorCode: string, + public readonly description?: string, + public readonly errorUri?: string, + ) { + super(message); + this.name = "OAuthError"; + } +} + +/** + * Refresh token is invalid, expired, or revoked. Requires re-authentication. + */ +export class InvalidGrantError extends OAuthError { + constructor(description?: string, errorUri?: string) { + super( + "OAuth refresh token is invalid, expired, or revoked", + "invalid_grant", + description, + errorUri, + ); + this.name = "InvalidGrantError"; + } +} + +/** + * Client credentials are invalid. Requires re-registration. + */ +export class InvalidClientError extends OAuthError { + constructor(description?: string, errorUri?: string) { + super( + "OAuth client credentials are invalid", + "invalid_client", + description, + errorUri, + ); + this.name = "InvalidClientError"; + } +} + +/** + * Invalid request error - malformed OAuth request + */ +export class InvalidRequestError extends OAuthError { + constructor(description?: string, errorUri?: string) { + super( + "OAuth request is malformed or invalid", + "invalid_request", + description, + errorUri, + ); + this.name = "InvalidRequestError"; + } +} + +/** + * Client is not authorized for this grant type. + */ +export class UnauthorizedClientError extends OAuthError { + constructor(description?: string, errorUri?: string) { + super( + "OAuth client is not authorized for this grant type", + "unauthorized_client", + description, + errorUri, + ); + this.name = "UnauthorizedClientError"; + } +} + +/** + * Unsupported grant type error. + */ +export class UnsupportedGrantTypeError extends OAuthError { + constructor(description?: string, errorUri?: string) { + super( + "OAuth grant type is not supported", + "unsupported_grant_type", + description, + errorUri, + ); + this.name = "UnsupportedGrantTypeError"; + } +} + +/** + * Invalid scope error. + */ +export class InvalidScopeError extends OAuthError { + constructor(description?: string, errorUri?: string) { + super( + "OAuth scope is invalid, unknown, malformed, or exceeds the scope granted by the resource owner", + "invalid_scope", + description, + errorUri, + ); + this.name = "InvalidScopeError"; + } +} + +/** + * Parses an axios error to extract OAuth error information + * Returns an OAuthError instance if the error is OAuth-related, otherwise returns null + */ +export function parseOAuthError(error: unknown): OAuthError | null { + if (!isAxiosError(error)) { + return null; + } + + const data = error.response?.data; + + if (!isOAuthErrorResponse(data)) { + return null; + } + + const { error: errorCode, error_description, error_uri } = data; + + switch (errorCode) { + case "invalid_grant": + return new InvalidGrantError(error_description, error_uri); + case "invalid_client": + return new InvalidClientError(error_description, error_uri); + case "invalid_request": + return new InvalidRequestError(error_description, error_uri); + case "unauthorized_client": + return new UnauthorizedClientError(error_description, error_uri); + case "unsupported_grant_type": + return new UnsupportedGrantTypeError(error_description, error_uri); + case "invalid_scope": + return new InvalidScopeError(error_description, error_uri); + default: + return new OAuthError( + `OAuth error: ${errorCode}`, + errorCode, + error_description, + error_uri, + ); + } +} + +function isOAuthErrorResponse(data: unknown): data is OAuthErrorResponse { + return ( + data !== null && + typeof data === "object" && + "error" in data && + typeof data.error === "string" + ); +} + +/** + * Checks if an error requires re-authentication + */ +export function requiresReAuthentication(error: OAuthError): boolean { + return ( + error instanceof InvalidGrantError || error instanceof InvalidClientError + ); +} diff --git a/src/oauth/metadataClient.ts b/src/oauth/metadataClient.ts new file mode 100644 index 00000000..149d64fa --- /dev/null +++ b/src/oauth/metadataClient.ts @@ -0,0 +1,137 @@ +import type { AxiosInstance } from "axios"; + +import type { Logger } from "../logging/logger"; + +import type { OAuthServerMetadata } from "./types"; + +const OAUTH_DISCOVERY_ENDPOINT = "/.well-known/oauth-authorization-server"; + +const AUTH_GRANT_TYPE = "authorization_code" as const; +const REFRESH_GRANT_TYPE = "refresh_token" as const; +const RESPONSE_TYPE = "code" as const; +const OAUTH_METHOD = "client_secret_post" as const; +const PKCE_CHALLENGE_METHOD = "S256" as const; + +const REQUIRED_GRANT_TYPES = [AUTH_GRANT_TYPE, REFRESH_GRANT_TYPE] as const; + +/** + * Client for discovering and validating OAuth server metadata. + */ +export class OAuthMetadataClient { + constructor( + private readonly axiosInstance: AxiosInstance, + private readonly logger: Logger, + ) {} + + /** + * Check if a server supports OAuth by attempting to fetch the well-known endpoint. + */ + public static async checkOAuthSupport( + axiosInstance: AxiosInstance, + ): Promise { + try { + await axiosInstance.get(OAUTH_DISCOVERY_ENDPOINT); + return true; + } catch { + return false; + } + } + + /** + * Fetch and validate OAuth server metadata. + * Throws detailed errors if server doesn't meet OAuth 2.1 requirements. + */ + async getMetadata(): Promise { + this.logger.debug("Discovering OAuth endpoints..."); + + const response = await this.axiosInstance.get( + OAUTH_DISCOVERY_ENDPOINT, + ); + + const metadata = response.data; + + this.validateRequiredEndpoints(metadata); + this.validateGrantTypes(metadata); + this.validateResponseTypes(metadata); + this.validateAuthMethods(metadata); + this.validatePKCEMethods(metadata); + + this.logger.debug("OAuth endpoints discovered:", { + authorization: metadata.authorization_endpoint, + token: metadata.token_endpoint, + registration: metadata.registration_endpoint, + revocation: metadata.revocation_endpoint, + }); + + return metadata; + } + + private validateRequiredEndpoints(metadata: OAuthServerMetadata): void { + if ( + !metadata.authorization_endpoint || + !metadata.token_endpoint || + !metadata.issuer + ) { + throw new Error( + "OAuth server metadata missing required endpoints: " + + JSON.stringify(metadata), + ); + } + } + + private validateGrantTypes(metadata: OAuthServerMetadata): void { + if ( + !includesAllTypes(metadata.grant_types_supported, REQUIRED_GRANT_TYPES) + ) { + throw new Error( + `Server does not support required grant types: ${REQUIRED_GRANT_TYPES.join(", ")}. Supported: ${metadata.grant_types_supported?.join(", ") || "none"}`, + ); + } + } + + private validateResponseTypes(metadata: OAuthServerMetadata): void { + if (!includesAllTypes(metadata.response_types_supported, [RESPONSE_TYPE])) { + throw new Error( + `Server does not support required response type: ${RESPONSE_TYPE}. Supported: ${metadata.response_types_supported?.join(", ") || "none"}`, + ); + } + } + + private validateAuthMethods(metadata: OAuthServerMetadata): void { + if ( + !includesAllTypes(metadata.token_endpoint_auth_methods_supported, [ + OAUTH_METHOD, + ]) + ) { + throw new Error( + `Server does not support required auth method: ${OAUTH_METHOD}. Supported: ${metadata.token_endpoint_auth_methods_supported?.join(", ") || "none"}`, + ); + } + } + + private validatePKCEMethods(metadata: OAuthServerMetadata): void { + if ( + !includesAllTypes(metadata.code_challenge_methods_supported, [ + PKCE_CHALLENGE_METHOD, + ]) + ) { + throw new Error( + `Server does not support required PKCE method: ${PKCE_CHALLENGE_METHOD}. Supported: ${metadata.code_challenge_methods_supported?.join(", ") || "none"}`, + ); + } + } +} + +/** + * Check if an array includes all required types. + * If the array is undefined, returns true (server didn't specify, assume all allowed). + */ +function includesAllTypes( + arr: string[] | undefined, + requiredTypes: readonly string[], +): boolean { + if (arr === undefined) { + return true; + } + return requiredTypes.every((type) => arr.includes(type)); +} diff --git a/src/oauth/sessionManager.ts b/src/oauth/sessionManager.ts new file mode 100644 index 00000000..482065af --- /dev/null +++ b/src/oauth/sessionManager.ts @@ -0,0 +1,808 @@ +import { type AxiosInstance } from "axios"; +import * as vscode from "vscode"; + +import { CoderApi } from "../api/coderApi"; +import { type ServiceContainer } from "../core/container"; +import { type Deployment } from "../core/deployment"; +import { type LoginCoordinator } from "../login/loginCoordinator"; + +import { OAuthMetadataClient } from "./metadataClient"; +import { + CALLBACK_PATH, + generatePKCE, + generateState, + toUrlSearchParams, +} from "./utils"; + +import type { SecretsManager, StoredOAuthTokens } from "../core/secretsManager"; +import type { Logger } from "../logging/logger"; + +import type { OAuthError } from "./errors"; +import type { + ClientRegistrationRequest, + ClientRegistrationResponse, + OAuthServerMetadata, + RefreshTokenRequestParams, + TokenRequestParams, + TokenResponse, + TokenRevocationRequest, +} from "./types"; + +const AUTH_GRANT_TYPE = "authorization_code" as const; +const REFRESH_GRANT_TYPE = "refresh_token" as const; +const RESPONSE_TYPE = "code" as const; +const PKCE_CHALLENGE_METHOD = "S256" as const; + +/** + * Token refresh threshold: refresh when token expires in less than this time. + */ +const TOKEN_REFRESH_THRESHOLD_MS = 10 * 60 * 1000; + +/** + * Default expiry time for OAuth access tokens when the server doesn't provide one. + */ +const ACCESS_TOKEN_DEFAULT_EXPIRY_MS = 60 * 60 * 1000; + +/** + * Minimum time between refresh attempts to prevent thrashing. + */ +const REFRESH_THROTTLE_MS = 30 * 1000; + +/** + * Background token refresh check interval. + */ +const BACKGROUND_REFRESH_INTERVAL_MS = 5 * 60 * 1000; + +/** + * Minimal scopes required by the VS Code extension. + */ +const DEFAULT_OAUTH_SCOPES = [ + "workspace:read", + "workspace:update", + "workspace:start", + "workspace:ssh", + "workspace:application_connect", + "template:read", + "user:read_personal", +].join(" "); + +/** + * Manages OAuth session lifecycle for a Coder deployment. + * Coordinates authorization flow, token management, and automatic refresh. + */ +export class OAuthSessionManager implements vscode.Disposable { + private storedTokens: StoredOAuthTokens | undefined; + private refreshPromise: Promise | null = null; + private lastRefreshAttempt = 0; + private refreshTimer: NodeJS.Timeout | undefined; + + private pendingAuthReject: ((reason: Error) => void) | undefined; + + /** + * Create and initialize a new OAuth session manager. + */ + public static async create( + deployment: Deployment | undefined, + container: ServiceContainer, + extensionId: string, + ): Promise { + const manager = new OAuthSessionManager( + deployment, + container.getSecretsManager(), + container.getLogger(), + container.getLoginCoordinator(), + extensionId, + ); + await manager.loadTokens(); + manager.scheduleBackgroundRefresh(); + return manager; + } + + private constructor( + private deployment: Deployment | undefined, + private readonly secretsManager: SecretsManager, + private readonly logger: Logger, + private readonly loginCoordinator: LoginCoordinator, + private readonly extensionId: string, + ) {} + + /** + * Get current deployment, throwing if not set. + * Use this in methods that require a deployment to be configured. + */ + private requireDeployment(): Deployment { + if (!this.deployment) { + throw new Error("No deployment configured for OAuth session manager"); + } + return this.deployment; + } + + /** + * Load stored tokens from storage. + * No-op if deployment is not set. + * Validates that tokens belong to the current deployment URL. + */ + private async loadTokens(): Promise { + if (!this.deployment) { + return; + } + + const tokens = await this.secretsManager.getOAuthTokens( + this.deployment.label, + ); + if (!tokens) { + return; + } + + if (tokens.deployment_url !== this.deployment.url) { + this.logger.warn("Stored tokens for different deployment, clearing", { + stored: tokens.deployment_url, + current: this.deployment.url, + }); + this.clearInMemoryTokens(); + await this.secretsManager.clearOAuthData(this.deployment.label); + return; + } + + if (!this.hasRequiredScopes(tokens.scope)) { + this.logger.warn( + "Stored token missing required scopes, clearing tokens", + { + stored_scope: tokens.scope, + required_scopes: DEFAULT_OAUTH_SCOPES, + }, + ); + this.clearInMemoryTokens(); + await this.secretsManager.clearOAuthTokens(this.deployment.label); + return; + } + + this.storedTokens = tokens; + this.logger.info(`Loaded stored OAuth tokens for ${this.deployment.label}`); + } + + private clearInMemoryTokens(): void { + this.storedTokens = undefined; + this.refreshPromise = null; + this.lastRefreshAttempt = 0; + } + + /** + * Schedule the next background token refresh check. + * Only schedules the next check after the current one completes. + */ + private scheduleBackgroundRefresh(): void { + if (this.refreshTimer) { + clearTimeout(this.refreshTimer); + } + + this.refreshTimer = setTimeout(async () => { + try { + await this.refreshIfAlmostExpired(); + } catch (error) { + this.logger.warn("Background token refresh failed:", error); + } + this.scheduleBackgroundRefresh(); + }, BACKGROUND_REFRESH_INTERVAL_MS); + } + + /** + * Check if granted scopes cover all required scopes. + * Supports wildcard scopes like "workspace:*". + */ + private hasRequiredScopes(grantedScope: string | undefined): boolean { + if (!grantedScope) { + // TODO server always returns empty scopes + return true; + } + + const grantedScopes = new Set(grantedScope.split(" ")); + const requiredScopes = DEFAULT_OAUTH_SCOPES.split(" "); + + for (const required of requiredScopes) { + if (grantedScopes.has(required)) { + continue; + } + + // Check wildcard match (e.g., "workspace:*" grants "workspace:read") + const colonIndex = required.indexOf(":"); + if (colonIndex !== -1) { + const prefix = required.substring(0, colonIndex); + const wildcard = `${prefix}:*`; + if (grantedScopes.has(wildcard)) { + continue; + } + } + + return false; + } + + return true; + } + + /** + * Get the redirect URI for OAuth callbacks. + */ + private getRedirectUri(): string { + return `${vscode.env.uriScheme}://${this.extensionId}${CALLBACK_PATH}`; + } + + /** + * Prepare common OAuth operation setup: client, metadata, and registration. + * Used by refresh and revoke operations to reduce duplication. + */ + private async prepareOAuthOperation(token?: string): Promise<{ + axiosInstance: AxiosInstance; + metadata: OAuthServerMetadata; + registration: ClientRegistrationResponse; + }> { + const deployment = this.requireDeployment(); + const client = CoderApi.create(deployment.url, token, this.logger); + const axiosInstance = client.getAxiosInstance(); + + const metadataClient = new OAuthMetadataClient(axiosInstance, this.logger); + const metadata = await metadataClient.getMetadata(); + + const registration = await this.secretsManager.getOAuthClientRegistration( + deployment.label, + ); + if (!registration) { + throw new Error("No client registration found"); + } + + return { axiosInstance, metadata, registration }; + } + + /** + * Register OAuth client or return existing if still valid. + * Re-registers if redirect URI has changed. + */ + private async registerClient( + axiosInstance: AxiosInstance, + metadata: OAuthServerMetadata, + ): Promise { + const deployment = this.requireDeployment(); + const redirectUri = this.getRedirectUri(); + + const existing = await this.secretsManager.getOAuthClientRegistration( + deployment.label, + ); + if (existing?.client_id) { + if (existing.redirect_uris.includes(redirectUri)) { + this.logger.info( + "Using existing client registration:", + existing.client_id, + ); + return existing; + } + this.logger.info("Redirect URI changed, re-registering client"); + } + + if (!metadata.registration_endpoint) { + throw new Error("Server does not support dynamic client registration"); + } + + const registrationRequest: ClientRegistrationRequest = { + redirect_uris: [redirectUri], + application_type: "web", + grant_types: ["authorization_code"], + response_types: ["code"], + client_name: "VS Code Coder Extension", + token_endpoint_auth_method: "client_secret_post", + }; + + const response = await axiosInstance.post( + metadata.registration_endpoint, + registrationRequest, + ); + + await this.secretsManager.setOAuthClientRegistration( + deployment.label, + response.data, + ); + this.logger.info( + "Saved OAuth client registration:", + response.data.client_id, + ); + + return response.data; + } + + public async setDeployment(deployment: Deployment): Promise { + if ( + this.deployment && + deployment.label === this.deployment.label && + deployment.url === this.deployment.url + ) { + return; + } + this.logger.debug("Switching OAuth deployment", deployment); + this.deployment = deployment; + this.clearInMemoryTokens(); + await this.loadTokens(); + } + + public clearDeployment(): void { + this.logger.debug("Clearing OAuth deployment state"); + this.deployment = undefined; + this.clearInMemoryTokens(); + } + + /** + * OAuth login flow that handles the entire process. + * Fetches metadata, registers client, starts authorization, and exchanges tokens. + * + * @returns TokenResponse containing access token and optional refresh token + */ + public async login( + client: CoderApi, + deployment: Deployment, + progress: vscode.Progress<{ message?: string; increment?: number }>, + ): Promise { + const baseUrl = client.getAxiosInstance().defaults.baseURL; + if (!baseUrl) { + throw new Error("Client has no base URL set"); + } + if (baseUrl !== deployment.url) { + throw new Error( + `Client base URL (${baseUrl}) does not match deployment URL (${deployment.url})`, + ); + } + + // Update deployment if changed + if ( + !this.deployment || + this.deployment.url !== deployment.url || + this.deployment.label !== deployment.label + ) { + this.logger.info("Deployment changed, clearing cached state", { + old: this.deployment, + new: deployment, + }); + this.clearInMemoryTokens(); + this.deployment = deployment; + } + + const axiosInstance = client.getAxiosInstance(); + const metadataClient = new OAuthMetadataClient(axiosInstance, this.logger); + const metadata = await metadataClient.getMetadata(); + + // Only register the client on login + progress.report({ message: "registering client...", increment: 10 }); + const registration = await this.registerClient(axiosInstance, metadata); + + progress.report({ message: "waiting for authorization...", increment: 30 }); + const { code, verifier } = await this.startAuthorization( + metadata, + registration, + ); + + progress.report({ message: "exchanging token...", increment: 30 }); + const tokenResponse = await this.exchangeToken( + code, + verifier, + axiosInstance, + metadata, + registration, + ); + + progress.report({ increment: 30 }); + this.logger.info("OAuth login flow completed successfully"); + + return tokenResponse; + } + + /** + * Build authorization URL with all required OAuth 2.1 parameters. + */ + private buildAuthorizationUrl( + metadata: OAuthServerMetadata, + clientId: string, + state: string, + challenge: string, + ): string { + if (metadata.scopes_supported) { + const requestedScopes = DEFAULT_OAUTH_SCOPES.split(" "); + const unsupportedScopes = requestedScopes.filter( + (s) => !metadata.scopes_supported?.includes(s), + ); + if (unsupportedScopes.length > 0) { + this.logger.warn( + `Requested scopes not in server's supported scopes: ${unsupportedScopes.join(", ")}. Server may still accept them.`, + { supported_scopes: metadata.scopes_supported }, + ); + } + } + + const params = new URLSearchParams({ + client_id: clientId, + response_type: RESPONSE_TYPE, + redirect_uri: this.getRedirectUri(), + scope: DEFAULT_OAUTH_SCOPES, + state, + code_challenge: challenge, + code_challenge_method: PKCE_CHALLENGE_METHOD, + }); + + const url = `${metadata.authorization_endpoint}?${params.toString()}`; + + this.logger.debug("Built OAuth authorization URL:", { + client_id: clientId, + redirect_uri: this.getRedirectUri(), + scope: DEFAULT_OAUTH_SCOPES, + }); + + return url; + } + + /** + * Start OAuth authorization flow. + * Opens browser for user authentication and waits for callback. + * Returns authorization code and PKCE verifier on success. + */ + private async startAuthorization( + metadata: OAuthServerMetadata, + registration: ClientRegistrationResponse, + ): Promise<{ code: string; verifier: string }> { + const state = generateState(); + const { verifier, challenge } = generatePKCE(); + + const authUrl = this.buildAuthorizationUrl( + metadata, + registration.client_id, + state, + challenge, + ); + + const callbackPromise = new Promise<{ code: string; verifier: string }>( + (resolve, reject) => { + const timeoutMins = 5; + const timeoutHandle = setTimeout( + () => { + cleanup(); + reject( + new Error(`OAuth flow timed out after ${timeoutMins} minutes`), + ); + }, + timeoutMins * 60 * 1000, + ); + + const listener = this.secretsManager.onDidChangeOAuthCallback( + ({ state: callbackState, code, error }) => { + if (callbackState !== state) { + return; + } + + cleanup(); + + if (error) { + reject(new Error(`OAuth error: ${error}`)); + } else if (code) { + resolve({ code, verifier }); + } else { + reject(new Error("No authorization code received")); + } + }, + ); + + const cleanup = () => { + clearTimeout(timeoutHandle); + listener.dispose(); + }; + + this.pendingAuthReject = (error) => { + cleanup(); + reject(error); + }; + }, + ); + + try { + await vscode.env.openExternal(vscode.Uri.parse(authUrl)); + } catch (error) { + throw error instanceof Error + ? error + : new Error("Failed to open browser"); + } + + return callbackPromise; + } + + /** + * Handle OAuth callback from browser redirect. + * Writes the callback result to secrets storage, triggering the waiting window to proceed. + */ + public async handleCallback( + code: string | null, + state: string | null, + error: string | null, + ): Promise { + if (!state) { + this.logger.warn("Received OAuth callback with no state parameter"); + return; + } + + try { + await this.secretsManager.setOAuthCallback({ state, code, error }); + this.logger.debug("OAuth callback processed successfully"); + } catch (err) { + this.logger.error("Failed to process OAuth callback:", err); + } + } + + /** + * Exchange authorization code for access token. + */ + private async exchangeToken( + code: string, + verifier: string, + axiosInstance: AxiosInstance, + metadata: OAuthServerMetadata, + registration: ClientRegistrationResponse, + ): Promise { + this.logger.info("Exchanging authorization code for token"); + + const params: TokenRequestParams = { + grant_type: AUTH_GRANT_TYPE, + code, + redirect_uri: this.getRedirectUri(), + client_id: registration.client_id, + client_secret: registration.client_secret, + code_verifier: verifier, + }; + + const tokenRequest = toUrlSearchParams(params); + + const response = await axiosInstance.post( + metadata.token_endpoint, + tokenRequest, + { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + }, + ); + + this.logger.info("Token exchange successful"); + + await this.saveTokens(response.data); + + return response.data; + } + + /** + * Refresh the access token using the stored refresh token. + * Uses a shared promise to handle concurrent refresh attempts. + */ + public async refreshToken(): Promise { + // If a refresh is already in progress, return the existing promise + if (this.refreshPromise) { + this.logger.debug( + "Token refresh already in progress, waiting for result", + ); + return this.refreshPromise; + } + + if (!this.storedTokens?.refresh_token) { + throw new Error("No refresh token available"); + } + + const refreshToken = this.storedTokens.refresh_token; + const accessToken = this.storedTokens.access_token; + + this.lastRefreshAttempt = Date.now(); + + // Create and store the refresh promise + this.refreshPromise = (async () => { + try { + const { axiosInstance, metadata, registration } = + await this.prepareOAuthOperation(accessToken); + + this.logger.debug("Refreshing access token"); + + const params: RefreshTokenRequestParams = { + grant_type: REFRESH_GRANT_TYPE, + refresh_token: refreshToken, + client_id: registration.client_id, + client_secret: registration.client_secret, + }; + + const tokenRequest = toUrlSearchParams(params); + + const response = await axiosInstance.post( + metadata.token_endpoint, + tokenRequest, + { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + }, + ); + + this.logger.debug("Token refresh successful"); + + await this.saveTokens(response.data); + + return response.data; + } finally { + this.refreshPromise = null; + } + })(); + + return this.refreshPromise; + } + + /** + * Save token response to storage. + * Also triggers event via secretsManager to update global client. + */ + private async saveTokens(tokenResponse: TokenResponse): Promise { + const deployment = this.requireDeployment(); + const expiryTimestamp = tokenResponse.expires_in + ? Date.now() + tokenResponse.expires_in * 1000 + : Date.now() + ACCESS_TOKEN_DEFAULT_EXPIRY_MS; + + const tokens: StoredOAuthTokens = { + ...tokenResponse, + deployment_url: deployment.url, + expiry_timestamp: expiryTimestamp, + }; + + this.storedTokens = tokens; + await this.secretsManager.setOAuthTokens(deployment.label, tokens); + await this.secretsManager.setSessionAuth(deployment.label, { + url: deployment.url, + token: tokenResponse.access_token, + }); + + this.logger.info("Tokens saved", { + expires_at: new Date(expiryTimestamp).toISOString(), + deployment: deployment.url, + }); + } + + /** + * Refreshes the token if it is approaching expiry. + */ + public async refreshIfAlmostExpired(): Promise { + if (this.shouldRefreshToken()) { + this.logger.debug("Token approaching expiry, triggering refresh"); + await this.refreshToken(); + } + } + + /** + * Check if token should be refreshed. + * Returns true if: + * 1. Token expires in less than TOKEN_REFRESH_THRESHOLD_MS + * 2. Last refresh attempt was more than REFRESH_THROTTLE_MS ago + * 3. No refresh is currently in progress + */ + private shouldRefreshToken(): boolean { + if ( + !this.isLoggedInWithOAuth() || + !this.storedTokens?.refresh_token || + this.refreshPromise !== null + ) { + return false; + } + + const now = Date.now(); + if (now - this.lastRefreshAttempt < REFRESH_THROTTLE_MS) { + return false; + } + + const timeUntilExpiry = this.storedTokens.expiry_timestamp - now; + return timeUntilExpiry < TOKEN_REFRESH_THRESHOLD_MS; + } + + /** + * Revoke a token using the OAuth server's revocation endpoint. + */ + private async revokeToken( + token: string, + tokenTypeHint: "access_token" | "refresh_token" = "refresh_token", + ): Promise { + const { axiosInstance, metadata, registration } = + await this.prepareOAuthOperation(this.storedTokens?.access_token); + + const revocationEndpoint = + metadata.revocation_endpoint || `${metadata.issuer}/oauth2/revoke`; + + this.logger.info("Revoking refresh token"); + + const params: TokenRevocationRequest = { + token, + client_id: registration.client_id, + client_secret: registration.client_secret, + token_type_hint: tokenTypeHint, + }; + + const revocationRequest = toUrlSearchParams(params); + + try { + await axiosInstance.post(revocationEndpoint, revocationRequest, { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + }); + + this.logger.info("Token revocation successful"); + } catch (error) { + this.logger.error("Token revocation failed:", error); + throw error; + } + } + + /** + * Logout by revoking tokens and clearing all OAuth data. + */ + public async logout(): Promise { + if (!this.isLoggedInWithOAuth()) { + return; + } + + // Revoke refresh token (which also invalidates access token per RFC 7009) + if (this.storedTokens?.refresh_token) { + try { + // TODO what if other windows are using this? + // We should only revoke if we are clearing the OAuth data + await this.revokeToken(this.storedTokens.refresh_token); + } catch (error) { + this.logger.warn("Token revocation failed during logout:", error); + } + } + + this.clearInMemoryTokens(); + this.deployment = undefined; + + this.logger.info("OAuth logout complete"); + } + + /** + * Returns true if (valid or invalid) OAuth tokens exist for the current deployment. + */ + public isLoggedInWithOAuth(): boolean { + return this.storedTokens !== undefined; + } + + /** + * Show a modal dialog to the user when OAuth re-authentication is required. + * This is called when the refresh token is invalid or the client credentials are invalid. + * Clears tokens directly and lets listeners handle updates. + */ + public async showReAuthenticationModal(error: OAuthError): Promise { + const deployment = this.requireDeployment(); + const errorMessage = + error.description || + "Your session is no longer valid. This could be due to token expiration or revocation."; + + // Clear invalid tokens - listeners will handle updates automatically + this.clearInMemoryTokens(); + await this.secretsManager.clearAllAuthData(deployment.label); + + await this.loginCoordinator.promptForLoginWithDialog({ + label: deployment.label, + url: deployment.url, + detailPrefix: errorMessage, + oauthSessionManager: this, + }); + } + + /** + * Clears all in-memory state. + */ + public dispose(): void { + if (this.refreshTimer) { + clearTimeout(this.refreshTimer); + this.refreshTimer = undefined; + } + if (this.pendingAuthReject) { + this.pendingAuthReject(new Error("OAuth session manager disposed")); + } + this.pendingAuthReject = undefined; + this.clearInMemoryTokens(); + + this.logger.debug("OAuth session manager disposed"); + } +} diff --git a/src/oauth/types.ts b/src/oauth/types.ts new file mode 100644 index 00000000..6ecaa0ff --- /dev/null +++ b/src/oauth/types.ts @@ -0,0 +1,163 @@ +// OAuth 2.1 Grant Types +export type GrantType = + | "authorization_code" + | "refresh_token" + | "client_credentials"; + +// OAuth 2.1 Response Types +export type ResponseType = "code"; + +// Token Endpoint Authentication Methods +export type TokenEndpointAuthMethod = + | "client_secret_post" + | "client_secret_basic" + | "none"; + +// Application Types +export type ApplicationType = "native" | "web"; + +// PKCE Code Challenge Methods (OAuth 2.1 requires S256) +export type CodeChallengeMethod = "S256"; + +// Token Types +export type TokenType = "Bearer" | "DPoP"; + +// Client Registration Request (RFC 7591 + OAuth 2.1) +export interface ClientRegistrationRequest { + redirect_uris: string[]; + token_endpoint_auth_method: TokenEndpointAuthMethod; + application_type: ApplicationType; + grant_types: GrantType[]; + response_types: ResponseType[]; + client_name?: string; + client_uri?: string; + logo_uri?: string; + scope?: string; + contacts?: string[]; + tos_uri?: string; + policy_uri?: string; + jwks_uri?: string; + software_id?: string; + software_version?: string; +} + +// Client Registration Response (RFC 7591) +export interface ClientRegistrationResponse { + client_id: string; + client_secret?: string; + client_id_issued_at?: number; + client_secret_expires_at?: number; + redirect_uris: string[]; + token_endpoint_auth_method: TokenEndpointAuthMethod; + application_type?: ApplicationType; + grant_types: GrantType[]; + response_types: ResponseType[]; + client_name?: string; + client_uri?: string; + logo_uri?: string; + scope?: string; + contacts?: string[]; + tos_uri?: string; + policy_uri?: string; + jwks_uri?: string; + software_id?: string; + software_version?: string; + registration_client_uri?: string; + registration_access_token?: string; +} + +// OAuth 2.1 Authorization Server Metadata (RFC 8414) +export interface OAuthServerMetadata { + issuer: string; + authorization_endpoint: string; + token_endpoint: string; + registration_endpoint?: string; + jwks_uri?: string; + response_types_supported: ResponseType[]; + grant_types_supported?: GrantType[]; + code_challenge_methods_supported: CodeChallengeMethod[]; + scopes_supported?: string[]; + token_endpoint_auth_methods_supported?: TokenEndpointAuthMethod[]; + revocation_endpoint?: string; + revocation_endpoint_auth_methods_supported?: TokenEndpointAuthMethod[]; + introspection_endpoint?: string; + introspection_endpoint_auth_methods_supported?: TokenEndpointAuthMethod[]; + service_documentation?: string; + ui_locales_supported?: string[]; +} + +// Token Response (RFC 6749 Section 5.1) +export interface TokenResponse { + access_token: string; + token_type: TokenType; + expires_in?: number; + refresh_token?: string; + scope?: string; +} + +// Authorization Request Parameters (OAuth 2.1) +export interface AuthorizationRequestParams { + client_id: string; + response_type: ResponseType; + redirect_uri: string; + scope?: string; + state: string; + code_challenge: string; + code_challenge_method: CodeChallengeMethod; +} + +// Token Request Parameters - Authorization Code Grant (OAuth 2.1) +export interface TokenRequestParams { + grant_type: "authorization_code"; + code: string; + redirect_uri: string; + client_id: string; + code_verifier: string; + client_secret?: string; +} + +// Token Request Parameters - Refresh Token Grant +export interface RefreshTokenRequestParams { + grant_type: "refresh_token"; + refresh_token: string; + client_id: string; + client_secret?: string; + scope?: string; +} + +// Token Request Parameters - Client Credentials Grant +export interface ClientCredentialsRequestParams { + grant_type: "client_credentials"; + client_id: string; + client_secret: string; + scope?: string; +} + +// Union type for all token request types +export type TokenRequestParamsUnion = + | TokenRequestParams + | RefreshTokenRequestParams + | ClientCredentialsRequestParams; + +// Token Revocation Request (RFC 7009) +export interface TokenRevocationRequest { + token: string; + token_type_hint?: "access_token" | "refresh_token"; + client_id: string; + client_secret?: string; +} + +// Error Response (RFC 6749 Section 5.2) +export interface OAuthErrorResponse { + error: + | "invalid_request" + | "invalid_client" + | "invalid_grant" + | "unauthorized_client" + | "unsupported_grant_type" + | "invalid_scope" + | "server_error" + | "temporarily_unavailable"; + error_description?: string; + error_uri?: string; +} diff --git a/src/oauth/utils.ts b/src/oauth/utils.ts new file mode 100644 index 00000000..61beeb50 --- /dev/null +++ b/src/oauth/utils.ts @@ -0,0 +1,42 @@ +import { createHash, randomBytes } from "node:crypto"; + +/** + * OAuth callback path for handling authorization responses (RFC 6749). + */ +export const CALLBACK_PATH = "/oauth/callback"; + +export interface PKCEChallenge { + verifier: string; + challenge: string; +} + +/** + * Generates a PKCE challenge pair (RFC 7636). + * Creates a code verifier and its SHA256 challenge for secure OAuth flows. + */ +export function generatePKCE(): PKCEChallenge { + const verifier = randomBytes(32).toString("base64url"); + const challenge = createHash("sha256").update(verifier).digest("base64url"); + return { verifier, challenge }; +} + +/** + * Generates a cryptographically secure state parameter to prevent CSRF attacks (RFC 6749). + */ +export function generateState(): string { + return randomBytes(16).toString("base64url"); +} + +/** + * Converts an object with string properties to URLSearchParams, + * filtering out undefined values for use with OAuth requests. + */ +export function toUrlSearchParams(obj: object): URLSearchParams { + const params = Object.fromEntries( + Object.entries(obj).filter( + ([, value]) => value !== undefined && typeof value === "string", + ), + ) as Record; + + return new URLSearchParams(params); +} diff --git a/src/promptUtils.ts b/src/promptUtils.ts index 4d058f12..9e3d8895 100644 --- a/src/promptUtils.ts +++ b/src/promptUtils.ts @@ -1,7 +1,11 @@ import { type WorkspaceAgent } from "coder/site/src/api/typesGenerated"; import * as vscode from "vscode"; +import { type CoderApi } from "./api/coderApi"; import { type MementoManager } from "./core/mementoManager"; +import { OAuthMetadataClient } from "./oauth/metadataClient"; + +type AuthMethod = "oauth" | "legacy"; /** * Find the requested agent if specified, otherwise return the agent if there @@ -61,15 +65,16 @@ export async function maybeAskAgent( */ async function askURL( mementoManager: MementoManager, - selection?: string, + prePopulateUrl: string | undefined, ): Promise { const defaultURL = vscode.workspace .getConfiguration() .get("coder.defaultUrl") ?.trim(); const quickPick = vscode.window.createQuickPick(); + quickPick.ignoreFocusOut = true; quickPick.value = - selection || defaultURL || process.env.CODER_URL?.trim() || ""; + prePopulateUrl || defaultURL || process.env.CODER_URL?.trim() || ""; quickPick.placeholder = "https://example.coder.com"; quickPick.title = "Enter the URL of your Coder deployment."; @@ -111,9 +116,9 @@ async function askURL( export async function maybeAskUrl( mementoManager: MementoManager, providedUrl: string | undefined | null, - lastUsedUrl?: string, + prePopulateUrl?: string, ): Promise { - let url = providedUrl || (await askURL(mementoManager, lastUsedUrl)); + let url = providedUrl || (await askURL(mementoManager, prePopulateUrl)); if (!url) { // User aborted. return undefined; @@ -129,3 +134,54 @@ export async function maybeAskUrl( } return url; } + +export async function maybeAskAuthMethod( + client: CoderApi, +): Promise { + // Check if server supports OAuth with progress indication + const supportsOAuth = await vscode.window.withProgress( + { + location: vscode.ProgressLocation.Notification, + title: "Checking authentication methods", + cancellable: false, + }, + async () => { + return await OAuthMetadataClient.checkOAuthSupport( + client.getAxiosInstance(), + ); + }, + ); + + if (supportsOAuth) { + return await askAuthMethod(); + } else { + return "legacy"; + } +} + +/** + * Ask user to choose between OAuth and legacy API token authentication. + */ +async function askAuthMethod(): Promise { + const choice = await vscode.window.showQuickPick( + [ + { + label: "OAuth (Recommended)", + description: "Secure authentication with automatic token refresh", + value: "oauth" as const, + }, + { + label: "Session Token (Legacy)", + description: "Generate and paste a session token manually", + value: "legacy" as const, + }, + ], + { + title: "Select authentication method", + placeHolder: "How would you like to authenticate?", + ignoreFocusOut: true, + }, + ); + + return choice?.value; +} diff --git a/src/remote/remote.ts b/src/remote/remote.ts index 4193e46d..f992c18e 100644 --- a/src/remote/remote.ts +++ b/src/remote/remote.ts @@ -19,6 +19,7 @@ import { } from "../api/agentMetadataHelper"; import { extractAgents } from "../api/api-helper"; import { CoderApi } from "../api/coderApi"; +import { attachOAuthInterceptors } from "../api/oauthInterceptors"; import { needToken } from "../api/utils"; import { type Commands } from "../commands"; import { type CliManager } from "../core/cliManager"; @@ -26,10 +27,13 @@ import * as cliUtils from "../core/cliUtils"; import { type ServiceContainer } from "../core/container"; import { type ContextManager } from "../core/contextManager"; import { type PathResolver } from "../core/pathResolver"; +import { type SecretsManager } from "../core/secretsManager"; import { featureSetForVersion, type FeatureSet } from "../featureSet"; import { getGlobalFlags } from "../globalFlags"; import { Inbox } from "../inbox"; import { type Logger } from "../logging/logger"; +import { type LoginCoordinator } from "../login/loginCoordinator"; +import { OAuthSessionManager } from "../oauth/sessionManager"; import { AuthorityPrefix, escapeCommandArg, @@ -44,6 +48,7 @@ import { computeSSHProperties, sshSupportsSetEnv } from "./sshSupport"; import { WorkspaceStateMachine } from "./workspaceStateMachine"; export interface RemoteDetails extends vscode.Disposable { + label: string; url: string; token: string; } @@ -55,48 +60,21 @@ export class Remote { private readonly pathResolver: PathResolver; private readonly cliManager: CliManager; private readonly contextManager: ContextManager; - - // Used to race between the login dialog and logging in from a different window - private loginDetectedResolver: (() => void) | undefined; - private loginDetectedRejector: ((reason?: Error) => void) | undefined; - private loginDetectedPromise: Promise = Promise.resolve(); + private readonly secretsManager: SecretsManager; + private readonly loginCoordinator: LoginCoordinator; public constructor( - serviceContainer: ServiceContainer, + private readonly serviceContainer: ServiceContainer, private readonly commands: Commands, - private readonly mode: vscode.ExtensionMode, + private readonly extensionContext: vscode.ExtensionContext, ) { this.vscodeProposed = serviceContainer.getVsCodeProposed(); this.logger = serviceContainer.getLogger(); this.pathResolver = serviceContainer.getPathResolver(); this.cliManager = serviceContainer.getCliManager(); this.contextManager = serviceContainer.getContextManager(); - } - - /** - * Creates a new promise that will be resolved when login is detected in another window. - */ - private createLoginDetectionPromise(): void { - if (this.loginDetectedRejector) { - this.loginDetectedRejector( - new Error("Login detection cancelled - new login attempt started"), - ); - } - this.loginDetectedPromise = new Promise((resolve, reject) => { - this.loginDetectedResolver = resolve; - this.loginDetectedRejector = reject; - }); - } - - /** - * Resolves the current login detection promise if one exists. - */ - public resolveLoginDetected(): void { - if (this.loginDetectedResolver) { - this.loginDetectedResolver(); - this.loginDetectedResolver = undefined; - this.loginDetectedRejector = undefined; - } + this.secretsManager = serviceContainer.getSecretsManager(); + this.loginCoordinator = serviceContainer.getLoginCoordinator(); } /** @@ -121,157 +99,193 @@ export class Remote { await this.migrateSessionToken(parts.label); // Get the URL and token belonging to this host. - const { url: baseUrlRaw, token } = await this.cliManager.readConfig( - parts.label, - ); + const baseUrlRaw = (await this.secretsManager.getUrl(parts.label)) ?? ""; + const token = + (await this.secretsManager.getSessionToken(parts.label)) ?? ""; + if (baseUrlRaw && token) { + await this.cliManager.configure(parts.label, baseUrlRaw, token); + } - const showLoginDialog = async (message: string) => { - this.createLoginDetectionPromise(); - const dialogPromise = this.vscodeProposed.window.showInformationMessage( - message, - { - useCustom: true, - modal: true, - detail: `You must log in to access ${workspaceName}. If you've already logged in, you may close this dialog.`, - }, - "Log In", + const disposables: vscode.Disposable[] = []; + + try { + disposables.push( + this.secretsManager.onDidChangeSessionAuth( + parts.label, + async (auth) => { + if (auth?.token && auth.url) { + // Update CLI config with new token + await this.cliManager.configure( + parts.label, + auth.url, + auth.token, + ); + this.logger.info( + "Updated CLI config with new token for remote deployment", + ); + } + }, + ), ); - // Race between dialog and login detection - const result = await Promise.race([ - this.loginDetectedPromise.then(() => ({ type: "login" as const })), - dialogPromise.then((userChoice) => ({ - type: "dialog" as const, - userChoice, - })), - ]); - - if (result.type === "login") { - return this.setup(remoteAuthority, firstConnect, remoteSshExtensionId); - } else if (!result.userChoice) { - // User declined to log in. - await this.closeRemote(); - return; - } else { - // Log in then try again. - await this.commands.login({ url: baseUrlRaw, label: parts.label }); - return this.setup(remoteAuthority, firstConnect, remoteSshExtensionId); - } - }; + // Create OAuth session manager for this remote deployment + const remoteOAuthManager = await OAuthSessionManager.create( + { url: baseUrlRaw, label: parts.label }, + this.serviceContainer, + this.extensionContext.extension.id, + ); + disposables.push(remoteOAuthManager); + + const promptForLoginAndRetry = async ( + message: string, + url: string | undefined, + ) => { + const result = await this.loginCoordinator.promptForLoginWithDialog({ + label: parts.label, + url, + message, + detailPrefix: `You must log in to access ${workspaceName}.`, + oauthSessionManager: remoteOAuthManager, + }); - // It could be that the cli config was deleted. If so, ask for the url. - if ( - !baseUrlRaw || - (!token && needToken(vscode.workspace.getConfiguration())) - ) { - return showLoginDialog("You are not logged in..."); - } + if (result.success) { + // Login successful, retry setup + return this.setup( + remoteAuthority, + firstConnect, + remoteSshExtensionId, + ); + } else { + // User cancelled or login failed + await this.closeRemote(); + } + }; - this.logger.info("Using deployment URL", baseUrlRaw); - this.logger.info("Using deployment label", parts.label || "n/a"); - - // We could use the plugin client, but it is possible for the user to log - // out or log into a different deployment while still connected, which would - // break this connection. We could force close the remote session or - // disallow logging out/in altogether, but for now just use a separate - // client to remain unaffected by whatever the plugin is doing. - const workspaceClient = CoderApi.create(baseUrlRaw, token, this.logger); - // Store for use in commands. - this.commands.workspaceRestClient = workspaceClient; - - let binaryPath: string | undefined; - if (this.mode === vscode.ExtensionMode.Production) { - binaryPath = await this.cliManager.fetchBinary( - workspaceClient, - parts.label, + // It could be that the cli config was deleted. If so, ask for the url. + if ( + !baseUrlRaw || + (!token && needToken(vscode.workspace.getConfiguration())) + ) { + return promptForLoginAndRetry("You are not logged in...", baseUrlRaw); + } + + this.logger.info("Using deployment URL", baseUrlRaw); + this.logger.info("Using deployment label", parts.label || "n/a"); + + // We could use the plugin client, but it is possible for the user to log + // out or log into a different deployment while still connected, which would + // break this connection. We could force close the remote session or + // disallow logging out/in altogether, but for now just use a separate + // client to remain unaffected by whatever the plugin is doing. + const workspaceClient = CoderApi.create(baseUrlRaw, token, this.logger); + disposables.push(workspaceClient); + attachOAuthInterceptors(workspaceClient, this.logger, remoteOAuthManager); + // Store for use in commands. + this.commands.remoteWorkspaceClient = workspaceClient; + + // Listen for token changes for this deployment + disposables.push( + this.secretsManager.onDidChangeSessionAuth(parts.label, (auth) => { + workspaceClient.setCredentials(auth?.url, auth?.token); + }), ); - } else { - try { - // In development, try to use `/tmp/coder` as the binary path. - // This is useful for debugging with a custom bin! - binaryPath = path.join(os.tmpdir(), "coder"); - await fs.stat(binaryPath); - } catch { + + let binaryPath: string | undefined; + if ( + this.extensionContext.extensionMode === vscode.ExtensionMode.Production + ) { binaryPath = await this.cliManager.fetchBinary( workspaceClient, parts.label, ); + } else { + try { + // In development, try to use `/tmp/coder` as the binary path. + // This is useful for debugging with a custom bin! + binaryPath = path.join(os.tmpdir(), "coder"); + await fs.stat(binaryPath); + } catch { + binaryPath = await this.cliManager.fetchBinary( + workspaceClient, + parts.label, + ); + } } - } - // First thing is to check the version. - const buildInfo = await workspaceClient.getBuildInfo(); + // First thing is to check the version. + const buildInfo = await workspaceClient.getBuildInfo(); - let version: semver.SemVer | null = null; - try { - version = semver.parse(await cliUtils.version(binaryPath)); - } catch { - version = semver.parse(buildInfo.version); - } + let version: semver.SemVer | null = null; + try { + version = semver.parse(await cliUtils.version(binaryPath)); + } catch { + version = semver.parse(buildInfo.version); + } - const featureSet = featureSetForVersion(version); + const featureSet = featureSetForVersion(version); - // Server versions before v0.14.1 don't support the vscodessh command! - if (!featureSet.vscodessh) { - await this.vscodeProposed.window.showErrorMessage( - "Incompatible Server", - { - detail: - "Your Coder server is too old to support the Coder extension! Please upgrade to v0.14.1 or newer.", - modal: true, - useCustom: true, - }, - "Close Remote", - ); - await this.closeRemote(); - return; - } - - // Next is to find the workspace from the URI scheme provided. - let workspace: Workspace; - try { - this.logger.info(`Looking for workspace ${workspaceName}...`); - workspace = await workspaceClient.getWorkspaceByOwnerAndName( - parts.username, - parts.workspace, - ); - this.logger.info( - `Found workspace ${workspaceName} with status`, - workspace.latest_build.status, - ); - this.commands.workspace = workspace; - } catch (error) { - if (!isAxiosError(error)) { - throw error; + // Server versions before v0.14.1 don't support the vscodessh command! + if (!featureSet.vscodessh) { + await this.vscodeProposed.window.showErrorMessage( + "Incompatible Server", + { + detail: + "Your Coder server is too old to support the Coder extension! Please upgrade to v0.14.1 or newer.", + modal: true, + useCustom: true, + }, + "Close Remote", + ); + await this.closeRemote(); + return; } - switch (error.response?.status) { - case 404: { - const result = - await this.vscodeProposed.window.showInformationMessage( - `That workspace doesn't exist!`, - { - modal: true, - detail: `${workspaceName} cannot be found on ${baseUrlRaw}. Maybe it was deleted...`, - useCustom: true, - }, - "Open Workspace", + + // Next is to find the workspace from the URI scheme provided. + let workspace: Workspace; + try { + this.logger.info(`Looking for workspace ${workspaceName}...`); + workspace = await workspaceClient.getWorkspaceByOwnerAndName( + parts.username, + parts.workspace, + ); + this.logger.info( + `Found workspace ${workspaceName} with status`, + workspace.latest_build.status, + ); + this.commands.workspace = workspace; + } catch (error) { + if (!isAxiosError(error)) { + throw error; + } + switch (error.response?.status) { + case 404: { + const result = + await this.vscodeProposed.window.showInformationMessage( + `That workspace doesn't exist!`, + { + modal: true, + detail: `${workspaceName} cannot be found on ${baseUrlRaw}. Maybe it was deleted...`, + useCustom: true, + }, + "Open Workspace", + ); + if (!result) { + await this.closeRemote(); + } + await vscode.commands.executeCommand("coder.open"); + return; + } + case 401: { + return promptForLoginAndRetry( + "Your session expired...", + baseUrlRaw, ); - if (!result) { - await this.closeRemote(); } - await vscode.commands.executeCommand("coder.open"); - return; - } - case 401: { - return showLoginDialog("Your session expired..."); + default: + throw error; } - default: - throw error; } - } - const disposables: vscode.Disposable[] = []; - try { // Register before connection so the label still displays! let labelFormatterDisposable = this.registerLabelFormatter( remoteAuthority, @@ -529,6 +543,7 @@ export class Remote { // deployment in the sidebar. We use our own client in here for reasons // explained above. return { + label: parts.label, url: baseUrlRaw, token, dispose: () => { diff --git a/src/websocket/codes.ts b/src/websocket/codes.ts index ac8eccf7..f3fd95cd 100644 --- a/src/websocket/codes.ts +++ b/src/websocket/codes.ts @@ -19,7 +19,9 @@ export const WebSocketCloseCode = { /** HTTP status codes used for socket creation and connection logic */ export const HttpStatusCode = { - /** Authentication or permission denied */ + /** Authentication required */ + UNAUTHORIZED: 401, + /** Permission denied */ FORBIDDEN: 403, /** Endpoint not found */ NOT_FOUND: 404, @@ -43,7 +45,9 @@ export const UNRECOVERABLE_WS_CLOSE_CODES = new Set([ * These appear during socket creation and should stop reconnection attempts. */ export const UNRECOVERABLE_HTTP_CODES = new Set([ + HttpStatusCode.UNAUTHORIZED, HttpStatusCode.FORBIDDEN, + HttpStatusCode.NOT_FOUND, HttpStatusCode.GONE, HttpStatusCode.UPGRADE_REQUIRED, ]); diff --git a/src/websocket/reconnectingWebSocket.ts b/src/websocket/reconnectingWebSocket.ts index 2ced9351..55f72988 100644 --- a/src/websocket/reconnectingWebSocket.ts +++ b/src/websocket/reconnectingWebSocket.ts @@ -41,7 +41,8 @@ export class ReconnectingWebSocket #currentSocket: UnidirectionalStream | null = null; #backoffMs: number; #reconnectTimeoutId: NodeJS.Timeout | null = null; - #isDisposed = false; + #isSuspended = false; // Temporary pause, can be resumed via reconnect() + #isDisposed = false; // Permanent disposal, cannot be resumed #isConnecting = false; #pendingReconnect = false; readonly #onDispose?: () => void; @@ -102,6 +103,11 @@ export class ReconnectingWebSocket } reconnect(): void { + if (this.#isSuspended) { + this.#isSuspended = false; + this.#backoffMs = this.#options.initialBackoffMs; + } + if (this.#isDisposed) { return; } @@ -121,6 +127,18 @@ export class ReconnectingWebSocket this.connect().catch((error) => this.handleConnectionError(error)); } + /** + * Temporarily suspend the socket. Can be resumed via reconnect(). + */ + suspend(code?: number, reason?: string): void { + if (this.#isDisposed || this.#isSuspended) { + return; + } + + this.#isSuspended = true; + this.clearCurrentSocket(code, reason); + } + close(code?: number, reason?: string): void { if (this.#isDisposed) { return; @@ -139,7 +157,7 @@ export class ReconnectingWebSocket } private async connect(): Promise { - if (this.#isDisposed || this.#isConnecting) { + if (this.#isDisposed || this.#isSuspended || this.#isConnecting) { return; } @@ -168,10 +186,21 @@ export class ReconnectingWebSocket socket.addEventListener("error", (event) => { this.executeHandlers("error", event); + + // Check for unrecoverable HTTP errors in the error event + // HTTP errors during handshake fire 'error' then 'close' with 1006 + // We need to suspend here to prevent infinite reconnect loops + const errorMessage = event.error?.message ?? event.message ?? ""; + if (this.isUnrecoverableHttpError(errorMessage)) { + this.#logger.error( + `Unrecoverable HTTP error for ${this.#apiRoute}: ${errorMessage}`, + ); + this.suspend(); + } }); socket.addEventListener("close", (event) => { - if (this.#isDisposed) { + if (this.#isDisposed || this.#isSuspended) { return; } @@ -181,7 +210,8 @@ export class ReconnectingWebSocket this.#logger.error( `WebSocket connection closed with unrecoverable error code ${event.code}`, ); - this.dispose(); + // Suspend instead of dispose - allows recovery when credentials change + this.suspend(); return; } @@ -204,7 +234,11 @@ export class ReconnectingWebSocket } private scheduleReconnect(): void { - if (this.#isDisposed || this.#reconnectTimeoutId !== null) { + if ( + this.#isDisposed || + this.#isSuspended || + this.#reconnectTimeoutId !== null + ) { return; } @@ -241,11 +275,11 @@ export class ReconnectingWebSocket } /** - * Checks if the error is unrecoverable and disposes the connection, + * Checks if the error is unrecoverable and suspends the connection, * otherwise schedules a reconnect. */ private handleConnectionError(error: unknown): void { - if (this.#isDisposed) { + if (this.#isDisposed || this.#isSuspended) { return; } @@ -254,7 +288,7 @@ export class ReconnectingWebSocket `Unrecoverable HTTP error during connection for ${this.#apiRoute}`, error, ); - this.dispose(); + this.suspend(); return; } @@ -266,12 +300,12 @@ export class ReconnectingWebSocket } /** - * Check if an error contains an unrecoverable HTTP status code. + * Check if an error message contains an unrecoverable HTTP status code. */ private isUnrecoverableHttpError(error: unknown): boolean { - const errorMessage = error instanceof Error ? error.message : String(error); + const message = (error as { message?: string }).message || String(error); for (const code of UNRECOVERABLE_HTTP_CODES) { - if (errorMessage.includes(String(code))) { + if (message.includes(String(code))) { return true; } } @@ -284,6 +318,18 @@ export class ReconnectingWebSocket } this.#isDisposed = true; + this.clearCurrentSocket(code, reason); + + for (const set of Object.values(this.#eventHandlers)) { + set.clear(); + } + + this.#onDispose?.(); + } + + private clearCurrentSocket(code?: number, reason?: string): void { + // Clear pending reconnect to prevent resume + this.#pendingReconnect = false; if (this.#reconnectTimeoutId !== null) { clearTimeout(this.#reconnectTimeoutId); @@ -294,11 +340,5 @@ export class ReconnectingWebSocket this.#currentSocket.close(code, reason); this.#currentSocket = null; } - - for (const set of Object.values(this.#eventHandlers)) { - set.clear(); - } - - this.#onDispose?.(); } } diff --git a/test/unit/api/coderApi.test.ts b/test/unit/api/coderApi.test.ts index 4f90f33e..7a7758c0 100644 --- a/test/unit/api/coderApi.test.ts +++ b/test/unit/api/coderApi.test.ts @@ -11,7 +11,6 @@ import { CertificateError } from "@/error"; import { getHeaders } from "@/headers"; import { type RequestConfigWithMeta } from "@/logging/types"; import { ReconnectingWebSocket } from "@/websocket/reconnectingWebSocket"; -import { SseConnection } from "@/websocket/sseConnection"; import { createMockLogger, @@ -336,18 +335,20 @@ describe("CoderApi", () => { expect(EventSource).not.toHaveBeenCalled(); }); - it("falls back to SSE when WebSocket creation fails", async () => { + it("falls back to SSE when WebSocket creation fails with 404", async () => { + // Only 404 errors trigger SSE fallback - other errors are thrown vi.mocked(Ws).mockImplementation(() => { - throw new Error("WebSocket creation failed"); + throw new Error("Unexpected server response: 404"); }); const connection = await api.watchAgentMetadata(AGENT_ID); - expect(connection).toBeInstanceOf(SseConnection); + // Returns ReconnectingWebSocket (which wraps SSE internally) + expect(connection).toBeInstanceOf(ReconnectingWebSocket); expect(EventSource).toHaveBeenCalled(); }); - it("falls back to SSE on 404 error from WebSocket", async () => { + it("falls back to SSE on 404 error from WebSocket open", async () => { const mockWs = createMockWebSocket( `wss://${CODER_URL.replace("https://", "")}/api/v2/test`, { @@ -368,9 +369,64 @@ describe("CoderApi", () => { const connection = await api.watchAgentMetadata(AGENT_ID); - expect(connection).toBeInstanceOf(SseConnection); + // Returns ReconnectingWebSocket (which wraps SSE internally after WS 404) + expect(connection).toBeInstanceOf(ReconnectingWebSocket); expect(EventSource).toHaveBeenCalled(); }); + + it("throws non-404 errors without SSE fallback", async () => { + vi.mocked(Ws).mockImplementation(() => { + throw new Error("Network error"); + }); + + await expect(api.watchAgentMetadata(AGENT_ID)).rejects.toThrow( + "Network error", + ); + expect(EventSource).not.toHaveBeenCalled(); + }); + + describe("reconnection after fallback", () => { + beforeEach(() => vi.useFakeTimers({ shouldAdvanceTime: true })); + afterEach(() => vi.useRealTimers()); + + it("reconnects after SSE fallback and retries WS on each reconnect", async () => { + let wsAttempts = 0; + const mockEventSources: MockEventSource[] = []; + + vi.mocked(Ws).mockImplementation(() => { + wsAttempts++; + const mockWs = createMockWebSocket("wss://test", { + on: vi.fn((event: string, handler: (e: unknown) => void) => { + if (event === "error") { + setImmediate(() => + handler({ error: new Error("Something 404") }), + ); + } + return mockWs as Ws; + }), + }); + return mockWs as Ws; + }); + + vi.mocked(EventSource).mockImplementation(() => { + const es = createMockEventSource(`${CODER_URL}/api/v2/test`); + mockEventSources.push(es); + return es as unknown as EventSource; + }); + + const connection = await api.watchAgentMetadata(AGENT_ID); + expect(wsAttempts).toBe(1); + expect(EventSource).toHaveBeenCalledTimes(1); + + mockEventSources[0].fireError(); + await vi.advanceTimersByTimeAsync(300); + + expect(wsAttempts).toBe(2); + expect(EventSource).toHaveBeenCalledTimes(2); + + connection.close(); + }); + }); }); describe("Reconnection on Host/Token Changes", () => { @@ -413,6 +469,7 @@ describe("CoderApi", () => { expect(wsWrap.url).toContain(CODER_URL.replace("http", "ws")); api.setHost("https://new-coder.example.com"); + // Wait for the async reconnect to complete (factory is async) await new Promise((resolve) => setImmediate(resolve)); expect(sockets[0].close).toHaveBeenCalledWith( @@ -420,7 +477,8 @@ describe("CoderApi", () => { "Replacing connection", ); expect(sockets).toHaveLength(2); - expect(wsWrap.url).toContain("wss://new-coder.example.com"); + // Verify the new socket was created with the correct URL + expect(sockets[1].url).toContain("wss://new-coder.example.com"); }); it("does not reconnect when token or host are unchanged", async () => { @@ -435,6 +493,58 @@ describe("CoderApi", () => { expect(sockets[0].close).not.toHaveBeenCalled(); expect(sockets).toHaveLength(1); }); + + it("suspends sockets when host is set to empty string (logout)", async () => { + const sockets = setupAutoOpeningWebSocket(); + api = createApi(CODER_URL, AXIOS_TOKEN); + await api.watchAgentMetadata(AGENT_ID); + + // Setting host to empty string (logout) should suspend (not permanently close) + api.setHost(""); + await new Promise((resolve) => setImmediate(resolve)); + + expect(sockets[0].close).toHaveBeenCalledWith(1000, "Host cleared"); + expect(sockets).toHaveLength(1); + }); + + it("does not reconnect when setting token after clearing host", async () => { + const sockets = setupAutoOpeningWebSocket(); + api = createApi(CODER_URL, AXIOS_TOKEN); + await api.watchAgentMetadata(AGENT_ID); + + api.setHost(""); + api.setSessionToken("new-token"); + await new Promise((resolve) => setImmediate(resolve)); + + // Should only have the initial socket - no reconnection after token change + expect(sockets).toHaveLength(1); + expect(sockets[0].close).toHaveBeenCalledWith(1000, "Host cleared"); + }); + + it("setCredentials sets both host and token together", async () => { + const sockets = setupAutoOpeningWebSocket(); + api = createApi(CODER_URL, AXIOS_TOKEN); + await api.watchAgentMetadata(AGENT_ID); + + api.setCredentials("https://new-coder.example.com", "new-token"); + await new Promise((resolve) => setImmediate(resolve)); + + // Should reconnect only once despite both values changing + expect(sockets).toHaveLength(2); + expect(sockets[1].url).toContain("wss://new-coder.example.com"); + }); + + it("setCredentials suspends when host is cleared", async () => { + const sockets = setupAutoOpeningWebSocket(); + api = createApi(CODER_URL, AXIOS_TOKEN); + await api.watchAgentMetadata(AGENT_ID); + + api.setCredentials(undefined, undefined); + await new Promise((resolve) => setImmediate(resolve)); + + expect(sockets).toHaveLength(1); + expect(sockets[0].close).toHaveBeenCalledWith(1000, "Host cleared"); + }); }); describe("Error Handling", () => { @@ -472,18 +582,32 @@ function createMockWebSocket( }; } -function createMockEventSource(url: string): Partial { - return { +type MockEventSource = Partial & { + readyState: number; + fireOpen: () => void; + fireError: () => void; +}; + +function createMockEventSource(url: string): MockEventSource { + const handlers: Record void) | undefined> = {}; + const mock: MockEventSource = { url, readyState: EventSource.CONNECTING, - addEventListener: vi.fn((event, handler) => { + addEventListener: vi.fn((event: string, handler: (e: Event) => void) => { + handlers[event] = handler; if (event === "open") { setImmediate(() => handler(new Event("open"))); } }), removeEventListener: vi.fn(), close: vi.fn(), + fireOpen: () => handlers.open?.(new Event("open")), + fireError: () => { + mock.readyState = EventSource.CLOSED; + handlers.error?.(new Event("error")); + }, }; + return mock; } function setupWebSocketMock(ws: Partial): void { diff --git a/test/unit/core/cliManager.test.ts b/test/unit/core/cliManager.test.ts index 95755d31..6cdcdb07 100644 --- a/test/unit/core/cliManager.test.ts +++ b/test/unit/core/cliManager.test.ts @@ -165,52 +165,6 @@ describe("CliManager", () => { }); }); - describe("Read CLI Configuration", () => { - it("should read and trim stored configuration", async () => { - // Create directories and write files with whitespace - vol.mkdirSync("/path/base/deployment", { recursive: true }); - memfs.writeFileSync( - "/path/base/deployment/url", - " https://coder.example.com \n", - ); - memfs.writeFileSync( - "/path/base/deployment/session", - "\t test-token \r\n", - ); - - const result = await manager.readConfig("deployment"); - - expect(result).toEqual({ - url: "https://coder.example.com", - token: "test-token", - }); - }); - - it("should return empty strings for missing files", async () => { - const result = await manager.readConfig("deployment"); - - expect(result).toEqual({ - url: "", - token: "", - }); - }); - - it("should handle partial configuration", async () => { - vol.mkdirSync("/path/base/deployment", { recursive: true }); - memfs.writeFileSync( - "/path/base/deployment/url", - "https://coder.example.com", - ); - - const result = await manager.readConfig("deployment"); - - expect(result).toEqual({ - url: "https://coder.example.com", - token: "", - }); - }); - }); - describe("Binary Version Validation", () => { it("rejects invalid server versions", async () => { mockApi.getBuildInfo = vi.fn().mockResolvedValue({ version: "invalid" }); diff --git a/test/unit/core/mementoManager.test.ts b/test/unit/core/mementoManager.test.ts index 54289a65..791f7602 100644 --- a/test/unit/core/mementoManager.test.ts +++ b/test/unit/core/mementoManager.test.ts @@ -13,28 +13,22 @@ describe("MementoManager", () => { mementoManager = new MementoManager(memento); }); - describe("setUrl", () => { - it("should store URL and add to history", async () => { - await mementoManager.setUrl("https://coder.example.com"); + describe("addToUrlHistory", () => { + it("should add URL to history", async () => { + await mementoManager.addToUrlHistory("https://coder.example.com"); - expect(mementoManager.getUrl()).toBe("https://coder.example.com"); expect(memento.get("urlHistory")).toEqual(["https://coder.example.com"]); }); it("should not update history for falsy values", async () => { - await mementoManager.setUrl(undefined); - expect(mementoManager.getUrl()).toBeUndefined(); - expect(memento.get("urlHistory")).toBeUndefined(); - - await mementoManager.setUrl(""); - expect(mementoManager.getUrl()).toBe(""); + await mementoManager.addToUrlHistory(""); expect(memento.get("urlHistory")).toBeUndefined(); }); it("should deduplicate URLs in history", async () => { - await mementoManager.setUrl("url1"); - await mementoManager.setUrl("url2"); - await mementoManager.setUrl("url1"); // Re-add first URL + await mementoManager.addToUrlHistory("url1"); + await mementoManager.addToUrlHistory("url2"); + await mementoManager.addToUrlHistory("url1"); // Re-add first URL expect(memento.get("urlHistory")).toEqual(["url2", "url1"]); }); diff --git a/test/unit/core/secretsManager.test.ts b/test/unit/core/secretsManager.test.ts index bfe8c713..f4456fa5 100644 --- a/test/unit/core/secretsManager.test.ts +++ b/test/unit/core/secretsManager.test.ts @@ -1,82 +1,194 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; -import { AuthAction, SecretsManager } from "@/core/secretsManager"; +import { + type CurrentDeploymentState, + SecretsManager, +} from "@/core/secretsManager"; -import { InMemorySecretStorage } from "../../mocks/testHelpers"; +import { + InMemoryMemento, + InMemorySecretStorage, +} from "../../mocks/testHelpers"; describe("SecretsManager", () => { let secretStorage: InMemorySecretStorage; + let memento: InMemoryMemento; let secretsManager: SecretsManager; beforeEach(() => { + vi.useRealTimers(); secretStorage = new InMemorySecretStorage(); - secretsManager = new SecretsManager(secretStorage); + memento = new InMemoryMemento(); + secretsManager = new SecretsManager(secretStorage, memento); }); - describe("session token", () => { - it("should store and retrieve tokens", async () => { - await secretsManager.setSessionToken("test-token"); - expect(await secretsManager.getSessionToken()).toBe("test-token"); - - await secretsManager.setSessionToken("new-token"); - expect(await secretsManager.getSessionToken()).toBe("new-token"); + describe("session auth", () => { + it("should store and retrieve session auth", async () => { + await secretsManager.setSessionAuth("example.com", { + url: "https://example.com", + token: "test-token", + }); + expect(await secretsManager.getSessionToken("example.com")).toBe( + "test-token", + ); + expect(await secretsManager.getUrl("example.com")).toBe( + "https://example.com", + ); + + await secretsManager.setSessionAuth("example.com", { + url: "https://example.com", + token: "new-token", + }); + expect(await secretsManager.getSessionToken("example.com")).toBe( + "new-token", + ); }); - it("should delete token when empty or undefined", async () => { - await secretsManager.setSessionToken("test-token"); - await secretsManager.setSessionToken(""); - expect(await secretsManager.getSessionToken()).toBeUndefined(); - - await secretsManager.setSessionToken("test-token"); - await secretsManager.setSessionToken(undefined); - expect(await secretsManager.getSessionToken()).toBeUndefined(); + it("should clear session auth", async () => { + await secretsManager.setSessionAuth("example.com", { + url: "https://example.com", + token: "test-token", + }); + await secretsManager.clearSessionAuth("example.com"); + expect( + await secretsManager.getSessionToken("example.com"), + ).toBeUndefined(); }); it("should return undefined for corrupted storage", async () => { - await secretStorage.store("sessionToken", "valid-token"); + await secretStorage.store( + "coder.session.example.com", + JSON.stringify({ + url: "https://example.com", + token: "valid-token", + }), + ); secretStorage.corruptStorage(); - expect(await secretsManager.getSessionToken()).toBeUndefined(); + expect( + await secretsManager.getSessionToken("example.com"), + ).toBeUndefined(); }); - }); - describe("login state", () => { - it("should trigger login events", async () => { - const events: Array = []; - secretsManager.onDidChangeLoginState((state) => { - events.push(state); - return Promise.resolve(); + it("should track known labels", async () => { + expect(secretsManager.getKnownLabels()).toEqual([]); + + await secretsManager.setSessionAuth("example.com", { + url: "https://example.com", + token: "test-token", }); + expect(secretsManager.getKnownLabels()).toContain("example.com"); - await secretsManager.triggerLoginStateChange("login"); - expect(events).toEqual([AuthAction.LOGIN]); + await secretsManager.setSessionAuth("other-com", { + url: "https://other.com", + token: "other-token", + }); + expect(secretsManager.getKnownLabels()).toContain("example.com"); + expect(secretsManager.getKnownLabels()).toContain("other-com"); }); - it("should trigger logout events", async () => { - const events: Array = []; - secretsManager.onDidChangeLoginState((state) => { - events.push(state); - return Promise.resolve(); + it("should remove label on clearAllAuthData", async () => { + await secretsManager.setSessionAuth("example.com", { + url: "https://example.com", + token: "test-token", + }); + expect(secretsManager.getKnownLabels()).toContain("example.com"); + + await secretsManager.clearAllAuthData("example.com"); + expect(secretsManager.getKnownLabels()).not.toContain("example.com"); + }); + + it("should order labels by most recently accessed", async () => { + await secretsManager.setSessionAuth("first.com", { + url: "https://first.com", + token: "token1", + }); + await secretsManager.setSessionAuth("second.com", { + url: "https://second.com", + token: "token2", + }); + await secretsManager.setSessionAuth("first.com", { + url: "https://first.com", + token: "token1-updated", }); - await secretsManager.triggerLoginStateChange("logout"); - expect(events).toEqual([AuthAction.LOGOUT]); + expect(secretsManager.getKnownLabels()).toEqual([ + "first.com", + "second.com", + ]); }); - it("should fire same event twice in a row", async () => { + it("should prune old deployments when exceeding maxCount", async () => { + for (let i = 1; i <= 5; i++) { + await secretsManager.setSessionAuth(`host${i}.com`, { + url: `https://host${i}.com`, + token: `token${i}`, + }); + } + + await secretsManager.recordDeploymentAccess("new.com", 3); + + expect(secretsManager.getKnownLabels()).toEqual([ + "new.com", + "host5.com", + "host4.com", + ]); + expect(await secretsManager.getSessionToken("host1.com")).toBeUndefined(); + expect(await secretsManager.getSessionToken("host2.com")).toBeUndefined(); + }); + }); + + describe("current deployment", () => { + it("should store and retrieve current deployment", async () => { + const deployment = { url: "https://example.com", label: "example.com" }; + await secretsManager.setCurrentDeployment(deployment); + + const result = await secretsManager.getCurrentDeployment(); + expect(result).toEqual(deployment); + }); + + it("should clear current deployment with undefined", async () => { + const deployment = { url: "https://example.com", label: "example.com" }; + await secretsManager.setCurrentDeployment(deployment); + await secretsManager.setCurrentDeployment(undefined); + + const result = await secretsManager.getCurrentDeployment(); + expect(result).toBeUndefined(); + }); + + it("should return undefined when no deployment set", async () => { + const result = await secretsManager.getCurrentDeployment(); + expect(result).toBeUndefined(); + }); + + it("should notify listeners on deployment change", async () => { vi.useFakeTimers(); - const events: Array = []; - secretsManager.onDidChangeLoginState((state) => { + const events: Array = []; + secretsManager.onDidChangeCurrentDeployment((state) => { events.push(state); - return Promise.resolve(); }); - await secretsManager.triggerLoginStateChange("login"); + const deployments = [ + { url: "https://example.com", label: "example.com" }, + { url: "https://another.org", label: "another.org" }, + { url: "https://another.org", label: "another.org" }, + ]; + await secretsManager.setCurrentDeployment(deployments[0]); + vi.advanceTimersByTime(5); + await secretsManager.setCurrentDeployment(deployments[1]); vi.advanceTimersByTime(5); - await secretsManager.triggerLoginStateChange("login"); + await secretsManager.setCurrentDeployment(deployments[2]); + vi.advanceTimersByTime(5); + + // Trigger an event even if the deployment did not change + expect(events).toEqual(deployments.map((deployment) => ({ deployment }))); + }); + + it("should handle corrupted storage gracefully", async () => { + await secretStorage.store("coder.currentDeployment", "invalid-json{"); - expect(events).toEqual([AuthAction.LOGIN, AuthAction.LOGIN]); - vi.useRealTimers(); + const result = await secretsManager.getCurrentDeployment(); + expect(result).toBeUndefined(); }); }); }); diff --git a/test/unit/websocket/reconnectingWebSocket.test.ts b/test/unit/websocket/reconnectingWebSocket.test.ts index cdf08949..1434b6a6 100644 --- a/test/unit/websocket/reconnectingWebSocket.test.ts +++ b/test/unit/websocket/reconnectingWebSocket.test.ts @@ -104,6 +104,32 @@ describe("ReconnectingWebSocket", () => { }, ); + it.each([ + HttpStatusCode.UNAUTHORIZED, + HttpStatusCode.FORBIDDEN, + HttpStatusCode.GONE, + ])( + "does not reconnect on unrecoverable HTTP error via error event: %i", + async (statusCode) => { + // HTTP errors during handshake fire 'error' event, then 'close' with 1006 + const { ws, sockets } = await createReconnectingWebSocket(); + + sockets[0].fireError( + new Error(`Unexpected server response: ${statusCode}`), + ); + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Connection failed", + }); + + // Should not reconnect - unrecoverable HTTP error + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(1); + + ws.close(); + }, + ); + it("reconnect() connects immediately and cancels pending reconnections", async () => { const { ws, sockets } = await createReconnectingWebSocket(); @@ -125,26 +151,8 @@ describe("ReconnectingWebSocket", () => { }); it("queues reconnect() calls made during connection", async () => { - const sockets: MockSocket[] = []; - let pendingResolve: ((socket: MockSocket) => void) | null = null; - - const factory = vi.fn(() => { - const socket = createMockSocket(); - sockets.push(socket); - - // First call resolves immediately, other calls wait for manual resolve - if (sockets.length === 1) { - return Promise.resolve(socket); - } else { - return new Promise((resolve) => { - pendingResolve = resolve; - }); - } - }); - - const ws = await fromFactory(factory); - sockets[0].fireOpen(); - expect(sockets).toHaveLength(1); + const { ws, sockets, completeConnection } = + await createBlockingReconnectingWebSocket(); // Start first reconnect (will block on factory promise) ws.reconnect(); @@ -154,17 +162,33 @@ describe("ReconnectingWebSocket", () => { // Still only 2 sockets (queued reconnect hasn't started) expect(sockets).toHaveLength(2); - // Complete the first reconnect - pendingResolve!(sockets[1]); - sockets[1].fireOpen(); - - // Wait a tick for the queued reconnect to execute + completeConnection(); await Promise.resolve(); // Now queued reconnect should have executed, creating third socket expect(sockets).toHaveLength(3); ws.close(); }); + + it("suspend() cancels pending reconnect queued during connection", async () => { + const { ws, sockets, failConnection } = + await createBlockingReconnectingWebSocket(); + + ws.reconnect(); + ws.reconnect(); // queued + expect(sockets).toHaveLength(2); + + // This should cancel the queued request + ws.suspend(); + failConnection(new Error("No base URL")); + await Promise.resolve(); + + expect(sockets).toHaveLength(2); + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(2); + + ws.close(); + }); }); describe("Event Handlers", () => { @@ -216,6 +240,48 @@ describe("ReconnectingWebSocket", () => { ws.close(); }); + + it("preserves event handlers after suspend() and reconnect()", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + sockets[0].fireOpen(); + + const handler = vi.fn(); + ws.addEventListener("message", handler); + sockets[0].fireMessage({ test: 1 }); + expect(handler).toHaveBeenCalledTimes(1); + + // Suspend the socket + ws.suspend(); + + // Reconnect (async operation) + ws.reconnect(); + await Promise.resolve(); // Wait for async connect() + expect(sockets).toHaveLength(2); + sockets[1].fireOpen(); + + // Handler should still work after suspend/reconnect + sockets[1].fireMessage({ test: 2 }); + expect(handler).toHaveBeenCalledTimes(2); + + ws.close(); + }); + + it("clears event handlers after close()", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + sockets[0].fireOpen(); + + const handler = vi.fn(); + ws.addEventListener("message", handler); + sockets[0].fireMessage({ test: 1 }); + expect(handler).toHaveBeenCalledTimes(1); + + // Close permanently + ws.close(); + + // Even if we could reconnect (we can't), handlers would be cleared + // Verify handler was removed by checking it's no longer in the set + // We can't easily test this without exposing internals, but close() clears handlers + }); }); describe("close() and Disposal", () => { @@ -258,9 +324,9 @@ describe("ReconnectingWebSocket", () => { expect(disposeCount).toBe(1); }); - it("calls onDispose callback on unrecoverable WebSocket close code", async () => { + it("suspends (not disposes) on unrecoverable WebSocket close code", async () => { let disposeCount = 0; - const { sockets } = await createReconnectingWebSocket( + const { ws, sockets } = await createReconnectingWebSocket( () => ++disposeCount, ); @@ -270,7 +336,14 @@ describe("ReconnectingWebSocket", () => { reason: "Protocol error", }); - expect(disposeCount).toBe(1); + // Should suspend, not dispose - allows recovery when credentials change + expect(disposeCount).toBe(0); + + // Should be able to reconnect after suspension + ws.reconnect(); + expect(sockets).toHaveLength(2); + + ws.close(); }); it("does not call onDispose callback during reconnection", async () => { @@ -291,6 +364,41 @@ describe("ReconnectingWebSocket", () => { ws.close(); expect(disposeCount).toBe(1); }); + + it("reconnect() resumes suspended socket after HTTP 403 error", async () => { + const { ws, sockets, setFactoryError } = + await createReconnectingWebSocketWithErrorControl(); + sockets[0].fireOpen(); + + // Trigger reconnect that will fail with 403 + setFactoryError( + new Error(`Unexpected server response: ${HttpStatusCode.FORBIDDEN}`), + ); + ws.reconnect(); + await Promise.resolve(); + + // Socket should be suspended - no automatic reconnection + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(1); + + // reconnect() should resume the suspended socket + setFactoryError(null); + ws.reconnect(); + await Promise.resolve(); + expect(sockets).toHaveLength(2); + + ws.close(); + }); + + it("reconnect() does nothing after close()", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + sockets[0].fireOpen(); + + ws.close(); + ws.reconnect(); + + expect(sockets).toHaveLength(1); + }); }); describe("Backoff Strategy", () => { @@ -454,6 +562,35 @@ async function createReconnectingWebSocket(onDispose?: () => void): Promise<{ return { ws, sockets }; } +async function createReconnectingWebSocketWithErrorControl(): Promise<{ + ws: ReconnectingWebSocket; + sockets: MockSocket[]; + setFactoryError: (error: Error | null) => void; +}> { + const sockets: MockSocket[] = []; + let factoryError: Error | null = null; + + const factory = vi.fn(() => { + if (factoryError) { + return Promise.reject(factoryError); + } + const socket = createMockSocket(); + sockets.push(socket); + return Promise.resolve(socket); + }); + + const ws = await fromFactory(factory); + expect(sockets).toHaveLength(1); + + return { + ws, + sockets, + setFactoryError: (error: Error | null) => { + factoryError = error; + }, + }; +} + async function fromFactory( factory: SocketFactory, onDispose?: () => void, @@ -466,3 +603,40 @@ async function fromFactory( onDispose, ); } + +async function createBlockingReconnectingWebSocket(): Promise<{ + ws: ReconnectingWebSocket; + sockets: MockSocket[]; + completeConnection: () => void; + failConnection: (error: Error) => void; +}> { + const sockets: MockSocket[] = []; + let pendingResolve: ((socket: MockSocket) => void) | null = null; + let pendingReject: ((error: Error) => void) | null = null; + + const factory = vi.fn(() => { + const socket = createMockSocket(); + sockets.push(socket); + if (sockets.length === 1) { + return Promise.resolve(socket); + } + return new Promise((resolve, reject) => { + pendingResolve = resolve; + pendingReject = reject; + }); + }); + + const ws = await fromFactory(factory); + sockets[0].fireOpen(); + + return { + ws, + sockets, + completeConnection: () => { + const socket = sockets.at(-1)!; + pendingResolve?.(socket); + socket.fireOpen(); + }, + failConnection: (error: Error) => pendingReject?.(error), + }; +}