Skip to content

Commit a800581

Browse files
committed
Add WebSocket suspension
1 parent 7f7cb74 commit a800581

File tree

7 files changed

+211
-35
lines changed

7 files changed

+211
-35
lines changed

src/api/coderApi.ts

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import {
3131
HttpClientLogLevel,
3232
} from "../logging/types";
3333
import { sizeOf } from "../logging/utils";
34-
import { HttpStatusCode } from "../websocket/codes";
34+
import { HttpStatusCode, WebSocketCloseCode } from "../websocket/codes";
3535
import {
3636
type UnidirectionalStream,
3737
type CloseEvent,
@@ -55,7 +55,7 @@ const coderSessionTokenHeader = "Coder-Session-Token";
5555
* Unified API class that includes both REST API methods from the base Api class
5656
* and WebSocket methods for real-time functionality.
5757
*/
58-
export class CoderApi extends Api {
58+
export class CoderApi extends Api implements vscode.Disposable {
5959
private readonly reconnectingSockets = new Set<
6060
ReconnectingWebSocket<unknown>
6161
>();
@@ -102,11 +102,27 @@ export class CoderApi extends Api {
102102

103103
if (currentHost !== host) {
104104
for (const socket of this.reconnectingSockets) {
105-
socket.reconnect();
105+
if (host) {
106+
socket.reconnect();
107+
} else {
108+
// No host means logout - suspend sockets (can resume when host is set again)
109+
socket.suspend(WebSocketCloseCode.NORMAL, "Host cleared");
110+
}
106111
}
107112
}
108113
};
109114

115+
/**
116+
* Permanently dispose all WebSocket connections.
117+
* This clears handlers and prevents reconnection.
118+
*/
119+
dispose(): void {
120+
for (const socket of this.reconnectingSockets) {
121+
socket.close();
122+
}
123+
this.reconnectingSockets.clear();
124+
}
125+
110126
watchInboxNotifications = async (
111127
watchTemplates: string[],
112128
watchTargets: string[],
@@ -125,7 +141,7 @@ export class CoderApi extends Api {
125141
};
126142

127143
watchWorkspace = async (workspace: Workspace, options?: ClientOptions) => {
128-
return this.createWebSocketWithFallback<ServerSentEvent>({
144+
return this.createWebSocketWithFallback({
129145
apiRoute: `/api/v2/workspaces/${workspace.id}/watch-ws`,
130146
fallbackApiRoute: `/api/v2/workspaces/${workspace.id}/watch`,
131147
options,
@@ -137,7 +153,7 @@ export class CoderApi extends Api {
137153
agentId: WorkspaceAgent["id"],
138154
options?: ClientOptions,
139155
) => {
140-
return this.createWebSocketWithFallback<ServerSentEvent>({
156+
return this.createWebSocketWithFallback({
141157
apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`,
142158
fallbackApiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata`,
143159
options,
@@ -290,43 +306,43 @@ export class CoderApi extends Api {
290306
*
291307
* Note: The fallback on SSE ignores all passed client options except the headers.
292308
*/
293-
private async createWebSocketWithFallback<TData = unknown>(configs: {
309+
private async createWebSocketWithFallback(configs: {
294310
apiRoute: string;
295311
fallbackApiRoute: string;
296312
searchParams?: Record<string, string> | URLSearchParams;
297313
options?: ClientOptions;
298314
enableRetry?: boolean;
299-
}): Promise<UnidirectionalStream<TData>> {
300-
let webSocket: UnidirectionalStream<TData>;
315+
}): Promise<UnidirectionalStream<ServerSentEvent>> {
316+
let webSocket: UnidirectionalStream<ServerSentEvent>;
301317
try {
302-
webSocket = await this.createWebSocket<TData>({
318+
webSocket = await this.createWebSocket<ServerSentEvent>({
303319
apiRoute: configs.apiRoute,
304320
searchParams: configs.searchParams,
305321
options: configs.options,
306322
enableRetry: configs.enableRetry,
307323
});
308324
} catch {
309325
// Failed to create WebSocket, use SSE fallback
310-
return this.createSseFallback<TData>(
326+
return this.createSseFallback(
311327
configs.fallbackApiRoute,
312328
configs.searchParams,
313329
configs.options?.headers,
314330
);
315331
}
316332

317333
return this.waitForConnection(webSocket, () =>
318-
this.createSseFallback<TData>(
334+
this.createSseFallback(
319335
configs.fallbackApiRoute,
320336
configs.searchParams,
321337
configs.options?.headers,
322338
),
323339
);
324340
}
325341

326-
private waitForConnection<TData>(
327-
connection: UnidirectionalStream<TData>,
328-
onNotFound?: () => Promise<UnidirectionalStream<TData>>,
329-
): Promise<UnidirectionalStream<TData>> {
342+
private waitForConnection(
343+
connection: UnidirectionalStream<ServerSentEvent>,
344+
onNotFound?: () => Promise<UnidirectionalStream<ServerSentEvent>>,
345+
): Promise<UnidirectionalStream<ServerSentEvent>> {
330346
return new Promise((resolve, reject) => {
331347
const cleanup = () => {
332348
connection.removeEventListener("open", handleOpen);
@@ -345,7 +361,12 @@ export class CoderApi extends Api {
345361
event.error?.message?.includes(String(HttpStatusCode.NOT_FOUND));
346362

347363
if (is404 && onNotFound) {
348-
connection.close();
364+
if (connection instanceof ReconnectingWebSocket) {
365+
// We can attempt this again if we change the host
366+
connection.suspend();
367+
} else {
368+
connection.close();
369+
}
349370
onNotFound().then(resolve).catch(reject);
350371
} else {
351372
reject(event.error || new Error(event.message));
@@ -360,11 +381,11 @@ export class CoderApi extends Api {
360381
/**
361382
* Create SSE fallback connection
362383
*/
363-
private async createSseFallback<TData = unknown>(
384+
private async createSseFallback(
364385
apiRoute: string,
365386
searchParams?: Record<string, string> | URLSearchParams,
366387
optionsHeaders?: Record<string, string>,
367-
): Promise<UnidirectionalStream<TData>> {
388+
): Promise<UnidirectionalStream<ServerSentEvent>> {
368389
this.output.warn(`WebSocket failed, using SSE fallback: ${apiRoute}`);
369390

370391
const baseUrlRaw = this.getAxiosInstance().defaults.baseURL;

src/extension.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ export async function activate(ctx: vscode.ExtensionContext): Promise<void> {
9393
await secretsManager.getSessionToken(deployment?.label ?? ""),
9494
output,
9595
);
96+
ctx.subscriptions.push(client);
9697
attachOAuthInterceptors(client, output, oauthSessionManager);
9798

9899
const myWorkspacesProvider = new WorkspaceProvider(

src/remote/remote.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ export class Remote {
175175
// disallow logging out/in altogether, but for now just use a separate
176176
// client to remain unaffected by whatever the plugin is doing.
177177
const workspaceClient = CoderApi.create(baseUrlRaw, token, this.logger);
178+
disposables.push(workspaceClient);
178179
attachOAuthInterceptors(workspaceClient, this.logger, remoteOAuthManager);
179180
// Store for use in commands.
180181
this.commands.workspaceRestClient = workspaceClient;

src/websocket/codes.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ export const UNRECOVERABLE_WS_CLOSE_CODES = new Set<number>([
4444
*/
4545
export const UNRECOVERABLE_HTTP_CODES = new Set<number>([
4646
HttpStatusCode.FORBIDDEN,
47+
HttpStatusCode.NOT_FOUND,
4748
HttpStatusCode.GONE,
4849
HttpStatusCode.UPGRADE_REQUIRED,
4950
]);

src/websocket/reconnectingWebSocket.ts

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ export class ReconnectingWebSocket<TData = unknown>
4141
#currentSocket: UnidirectionalStream<TData> | null = null;
4242
#backoffMs: number;
4343
#reconnectTimeoutId: NodeJS.Timeout | null = null;
44-
#isDisposed = false;
44+
#isSuspended = false; // Temporary pause, can be resumed via reconnect()
45+
#isDisposed = false; // Permanent disposal, cannot be resumed
4546
#isConnecting = false;
4647
#pendingReconnect = false;
4748
readonly #onDispose?: () => void;
@@ -102,6 +103,11 @@ export class ReconnectingWebSocket<TData = unknown>
102103
}
103104

104105
reconnect(): void {
106+
if (this.#isSuspended) {
107+
this.#isSuspended = false;
108+
this.#backoffMs = this.#options.initialBackoffMs;
109+
}
110+
105111
if (this.#isDisposed) {
106112
return;
107113
}
@@ -121,6 +127,18 @@ export class ReconnectingWebSocket<TData = unknown>
121127
this.connect().catch((error) => this.handleConnectionError(error));
122128
}
123129

130+
/**
131+
* Temporarily suspend the socket. Can be resumed via reconnect().
132+
*/
133+
suspend(code?: number, reason?: string): void {
134+
if (this.#isDisposed || this.#isSuspended) {
135+
return;
136+
}
137+
138+
this.#isSuspended = true;
139+
this.clearCurrentSocket(code, reason);
140+
}
141+
124142
close(code?: number, reason?: string): void {
125143
if (this.#isDisposed) {
126144
return;
@@ -139,7 +157,7 @@ export class ReconnectingWebSocket<TData = unknown>
139157
}
140158

141159
private async connect(): Promise<void> {
142-
if (this.#isDisposed || this.#isConnecting) {
160+
if (this.#isDisposed || this.#isSuspended || this.#isConnecting) {
143161
return;
144162
}
145163

@@ -171,7 +189,7 @@ export class ReconnectingWebSocket<TData = unknown>
171189
});
172190

173191
socket.addEventListener("close", (event) => {
174-
if (this.#isDisposed) {
192+
if (this.#isDisposed || this.#isSuspended) {
175193
return;
176194
}
177195

@@ -181,7 +199,8 @@ export class ReconnectingWebSocket<TData = unknown>
181199
this.#logger.error(
182200
`WebSocket connection closed with unrecoverable error code ${event.code}`,
183201
);
184-
this.dispose();
202+
// Suspend instead of dispose - allows recovery when credentials change
203+
this.suspend();
185204
return;
186205
}
187206

@@ -204,7 +223,11 @@ export class ReconnectingWebSocket<TData = unknown>
204223
}
205224

206225
private scheduleReconnect(): void {
207-
if (this.#isDisposed || this.#reconnectTimeoutId !== null) {
226+
if (
227+
this.#isDisposed ||
228+
this.#isSuspended ||
229+
this.#reconnectTimeoutId !== null
230+
) {
208231
return;
209232
}
210233

@@ -241,11 +264,11 @@ export class ReconnectingWebSocket<TData = unknown>
241264
}
242265

243266
/**
244-
* Checks if the error is unrecoverable and disposes the connection,
267+
* Checks if the error is unrecoverable and suspends the connection,
245268
* otherwise schedules a reconnect.
246269
*/
247270
private handleConnectionError(error: unknown): void {
248-
if (this.#isDisposed) {
271+
if (this.#isDisposed || this.#isSuspended) {
249272
return;
250273
}
251274

@@ -254,7 +277,7 @@ export class ReconnectingWebSocket<TData = unknown>
254277
`Unrecoverable HTTP error during connection for ${this.#apiRoute}`,
255278
error,
256279
);
257-
this.dispose();
280+
this.suspend();
258281
return;
259282
}
260283

@@ -284,7 +307,16 @@ export class ReconnectingWebSocket<TData = unknown>
284307
}
285308

286309
this.#isDisposed = true;
310+
this.clearCurrentSocket(code, reason);
287311

312+
for (const set of Object.values(this.#eventHandlers)) {
313+
set.clear();
314+
}
315+
316+
this.#onDispose?.();
317+
}
318+
319+
private clearCurrentSocket(code?: number, reason?: string): void {
288320
if (this.#reconnectTimeoutId !== null) {
289321
clearTimeout(this.#reconnectTimeoutId);
290322
this.#reconnectTimeoutId = null;
@@ -294,11 +326,5 @@ export class ReconnectingWebSocket<TData = unknown>
294326
this.#currentSocket.close(code, reason);
295327
this.#currentSocket = null;
296328
}
297-
298-
for (const set of Object.values(this.#eventHandlers)) {
299-
set.clear();
300-
}
301-
302-
this.#onDispose?.();
303329
}
304330
}

test/unit/api/coderApi.test.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,19 @@ describe("CoderApi", () => {
435435
expect(sockets[0].close).not.toHaveBeenCalled();
436436
expect(sockets).toHaveLength(1);
437437
});
438+
439+
it("suspends sockets when host is set to empty string (logout)", async () => {
440+
const sockets = setupAutoOpeningWebSocket();
441+
api = createApi(CODER_URL, AXIOS_TOKEN);
442+
await api.watchAgentMetadata(AGENT_ID);
443+
444+
// Setting host to empty string (logout) should suspend (not permanently close)
445+
api.setHost("");
446+
await new Promise((resolve) => setImmediate(resolve));
447+
448+
expect(sockets[0].close).toHaveBeenCalledWith(1000, "Host cleared");
449+
expect(sockets).toHaveLength(1);
450+
});
438451
});
439452

440453
describe("Error Handling", () => {

0 commit comments

Comments
 (0)