Skip to content

Commit ea7b658

Browse files
committed
WebSocket/SSE reconnection improvements
1 parent a800581 commit ea7b658

File tree

5 files changed

+270
-138
lines changed

5 files changed

+270
-138
lines changed

src/api/coderApi.ts

Lines changed: 144 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -214,68 +214,62 @@ export class CoderApi extends Api implements vscode.Disposable {
214214
throw new Error("No base URL set on REST client");
215215
}
216216

217-
const baseUrl = new URL(baseUrlRaw);
218-
const token = this.getAxiosInstance().defaults.headers.common[
219-
coderSessionTokenHeader
220-
] as string | undefined;
221-
222-
const headersFromCommand = await getHeaders(
223-
baseUrlRaw,
224-
getHeaderCommand(vscode.workspace.getConfiguration()),
225-
this.output,
226-
);
217+
return this.createOneWayWebSocket<TData>(socketConfigs);
218+
};
227219

228-
const httpAgent = await createHttpAgent(
229-
vscode.workspace.getConfiguration(),
230-
);
220+
if (enableRetry) {
221+
return this.createReconnectingSocket(socketFactory, configs.apiRoute);
222+
}
223+
return socketFactory();
224+
}
231225

232-
/**
233-
* Similar to the REST client, we want to prioritize headers in this order (highest to lowest):
234-
* 1. Headers from the header command
235-
* 2. Any headers passed directly to this function
236-
* 3. Coder session token from the Api client (if set)
237-
*/
238-
const headers = {
239-
...(token ? { [coderSessionTokenHeader]: token } : {}),
240-
...configs.options?.headers,
241-
...headersFromCommand,
242-
};
226+
private async createOneWayWebSocket<TData>(
227+
configs: Omit<OneWayWebSocketInit, "location">,
228+
): Promise<OneWayWebSocket<TData>> {
229+
const baseUrlRaw = this.getAxiosInstance().defaults.baseURL;
230+
if (!baseUrlRaw) {
231+
throw new Error("No base URL set on REST client");
232+
}
233+
const token = this.getAxiosInstance().defaults.headers.common[
234+
coderSessionTokenHeader
235+
] as string | undefined;
243236

244-
const webSocket = new OneWayWebSocket<TData>({
245-
location: baseUrl,
246-
...socketConfigs,
247-
options: {
248-
...configs.options,
249-
agent: httpAgent,
250-
followRedirects: true,
251-
headers,
252-
},
253-
});
237+
const headersFromCommand = await getHeaders(
238+
baseUrlRaw,
239+
getHeaderCommand(vscode.workspace.getConfiguration()),
240+
this.output,
241+
);
254242

255-
this.attachStreamLogger(webSocket);
256-
return webSocket;
257-
};
243+
const httpAgent = await createHttpAgent(
244+
vscode.workspace.getConfiguration(),
245+
);
258246

259-
if (enableRetry) {
260-
const reconnectingSocket = await ReconnectingWebSocket.create<TData>(
261-
socketFactory,
262-
this.output,
263-
configs.apiRoute,
264-
undefined,
265-
() =>
266-
this.reconnectingSockets.delete(
267-
reconnectingSocket as ReconnectingWebSocket<unknown>,
268-
),
269-
);
247+
/**
248+
* Similar to the REST client, we want to prioritize headers in this order (highest to lowest):
249+
* 1. Headers from the header command
250+
* 2. Any headers passed directly to this function
251+
* 3. Coder session token from the Api client (if set)
252+
*/
253+
const headers = {
254+
...(token ? { [coderSessionTokenHeader]: token } : {}),
255+
...configs.options?.headers,
256+
...headersFromCommand,
257+
};
270258

271-
this.reconnectingSockets.add(
272-
reconnectingSocket as ReconnectingWebSocket<unknown>,
273-
);
259+
const baseUrl = new URL(baseUrlRaw);
260+
const ws = new OneWayWebSocket<TData>({
261+
location: baseUrl,
262+
...configs,
263+
options: {
264+
...configs.options,
265+
agent: httpAgent,
266+
followRedirects: true,
267+
headers,
268+
},
269+
});
274270

275-
return reconnectingSocket;
276-
} else {
277-
return socketFactory();
278-
}
271+
this.attachStreamLogger(ws);
272+
return ws;
279273
}
280274

281275
private attachStreamLogger<TData>(
@@ -304,45 +298,80 @@ export class CoderApi extends Api implements vscode.Disposable {
304298
/**
305299
* Create a WebSocket connection with SSE fallback on 404.
306300
*
301+
* The factory tries WS first, falls back to SSE on 404. Since the factory
302+
* is called on every reconnect.
303+
*
307304
* Note: The fallback on SSE ignores all passed client options except the headers.
308305
*/
309-
private async createWebSocketWithFallback(configs: {
310-
apiRoute: string;
311-
fallbackApiRoute: string;
312-
searchParams?: Record<string, string> | URLSearchParams;
313-
options?: ClientOptions;
314-
enableRetry?: boolean;
315-
}): Promise<UnidirectionalStream<ServerSentEvent>> {
316-
let webSocket: UnidirectionalStream<ServerSentEvent>;
317-
try {
318-
webSocket = await this.createWebSocket<ServerSentEvent>({
319-
apiRoute: configs.apiRoute,
320-
searchParams: configs.searchParams,
321-
options: configs.options,
322-
enableRetry: configs.enableRetry,
323-
});
324-
} catch {
325-
// Failed to create WebSocket, use SSE fallback
326-
return this.createSseFallback(
327-
configs.fallbackApiRoute,
328-
configs.searchParams,
329-
configs.options?.headers,
306+
private async createWebSocketWithFallback(
307+
configs: Omit<OneWayWebSocketInit, "location"> & {
308+
fallbackApiRoute: string;
309+
enableRetry?: boolean;
310+
},
311+
): Promise<UnidirectionalStream<ServerSentEvent>> {
312+
const { fallbackApiRoute, enableRetry, ...socketConfigs } = configs;
313+
const socketFactory: SocketFactory<ServerSentEvent> = async () => {
314+
try {
315+
const ws =
316+
await this.createOneWayWebSocket<ServerSentEvent>(socketConfigs);
317+
return await this.waitForOpen(ws);
318+
} catch (error) {
319+
if (this.is404Error(error)) {
320+
this.output.warn(
321+
`WebSocket failed, using SSE fallback: ${socketConfigs.apiRoute}`,
322+
);
323+
const sse = this.createSseConnection(
324+
fallbackApiRoute,
325+
socketConfigs.searchParams,
326+
socketConfigs.options?.headers,
327+
);
328+
return await this.waitForOpen(sse);
329+
}
330+
throw error;
331+
}
332+
};
333+
334+
if (enableRetry) {
335+
return this.createReconnectingSocket(
336+
socketFactory,
337+
socketConfigs.apiRoute,
330338
);
331339
}
340+
return socketFactory();
341+
}
332342

333-
return this.waitForConnection(webSocket, () =>
334-
this.createSseFallback(
335-
configs.fallbackApiRoute,
336-
configs.searchParams,
337-
configs.options?.headers,
338-
),
339-
);
343+
/**
344+
* Create an SSE connection without waiting for connection.
345+
*/
346+
private createSseConnection(
347+
apiRoute: string,
348+
searchParams?: Record<string, string> | URLSearchParams,
349+
optionsHeaders?: Record<string, string>,
350+
): SseConnection {
351+
const baseUrlRaw = this.getAxiosInstance().defaults.baseURL;
352+
if (!baseUrlRaw) {
353+
throw new Error("No base URL set on REST client");
354+
}
355+
const url = new URL(baseUrlRaw);
356+
const sse = new SseConnection({
357+
location: url,
358+
apiRoute,
359+
searchParams,
360+
axiosInstance: this.getAxiosInstance(),
361+
optionsHeaders,
362+
logger: this.output,
363+
});
364+
365+
this.attachStreamLogger(sse);
366+
return sse;
340367
}
341368

342-
private waitForConnection(
343-
connection: UnidirectionalStream<ServerSentEvent>,
344-
onNotFound?: () => Promise<UnidirectionalStream<ServerSentEvent>>,
345-
): Promise<UnidirectionalStream<ServerSentEvent>> {
369+
/**
370+
* Wait for a connection to open. Rejects on error.
371+
*/
372+
private waitForOpen<TData>(
373+
connection: UnidirectionalStream<TData>,
374+
): Promise<UnidirectionalStream<TData>> {
346375
return new Promise((resolve, reject) => {
347376
const cleanup = () => {
348377
connection.removeEventListener("open", handleOpen);
@@ -356,21 +385,8 @@ export class CoderApi extends Api implements vscode.Disposable {
356385

357386
const handleError = (event: ErrorEvent) => {
358387
cleanup();
359-
const is404 =
360-
event.message?.includes(String(HttpStatusCode.NOT_FOUND)) ||
361-
event.error?.message?.includes(String(HttpStatusCode.NOT_FOUND));
362-
363-
if (is404 && onNotFound) {
364-
if (connection instanceof ReconnectingWebSocket) {
365-
// We can attempt this again if we change the host
366-
connection.suspend();
367-
} else {
368-
connection.close();
369-
}
370-
onNotFound().then(resolve).catch(reject);
371-
} else {
372-
reject(event.error || new Error(event.message));
373-
}
388+
connection.close();
389+
reject(event.error || new Error(event.message));
374390
};
375391

376392
connection.addEventListener("open", handleOpen);
@@ -379,32 +395,36 @@ export class CoderApi extends Api implements vscode.Disposable {
379395
}
380396

381397
/**
382-
* Create SSE fallback connection
398+
* Check if an error is a 404 Not Found error.
383399
*/
384-
private async createSseFallback(
385-
apiRoute: string,
386-
searchParams?: Record<string, string> | URLSearchParams,
387-
optionsHeaders?: Record<string, string>,
388-
): Promise<UnidirectionalStream<ServerSentEvent>> {
389-
this.output.warn(`WebSocket failed, using SSE fallback: ${apiRoute}`);
390-
391-
const baseUrlRaw = this.getAxiosInstance().defaults.baseURL;
392-
if (!baseUrlRaw) {
393-
throw new Error("No base URL set on REST client");
394-
}
400+
private is404Error(error: unknown): boolean {
401+
const msg = error instanceof Error ? error.message : String(error);
402+
return msg.includes(String(HttpStatusCode.NOT_FOUND));
403+
}
395404

396-
const baseUrl = new URL(baseUrlRaw);
397-
const sseConnection = new SseConnection({
398-
location: baseUrl,
405+
/**
406+
* Create a ReconnectingWebSocket and track it for lifecycle management.
407+
*/
408+
private async createReconnectingSocket<TData>(
409+
socketFactory: SocketFactory<TData>,
410+
apiRoute: string,
411+
): Promise<ReconnectingWebSocket<TData>> {
412+
const reconnectingSocket = await ReconnectingWebSocket.create<TData>(
413+
socketFactory,
414+
this.output,
399415
apiRoute,
400-
searchParams,
401-
axiosInstance: this.getAxiosInstance(),
402-
optionsHeaders: optionsHeaders,
403-
logger: this.output,
404-
});
416+
undefined,
417+
() =>
418+
this.reconnectingSockets.delete(
419+
reconnectingSocket as ReconnectingWebSocket<unknown>,
420+
),
421+
);
422+
423+
this.reconnectingSockets.add(
424+
reconnectingSocket as ReconnectingWebSocket<unknown>,
425+
);
405426

406-
this.attachStreamLogger(sseConnection);
407-
return this.waitForConnection(sseConnection);
427+
return reconnectingSocket;
408428
}
409429
}
410430

src/websocket/codes.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ export const WebSocketCloseCode = {
1919

2020
/** HTTP status codes used for socket creation and connection logic */
2121
export const HttpStatusCode = {
22-
/** Authentication or permission denied */
22+
/** Authentication required */
23+
UNAUTHORIZED: 401,
24+
/** Permission denied */
2325
FORBIDDEN: 403,
2426
/** Endpoint not found */
2527
NOT_FOUND: 404,
@@ -43,6 +45,7 @@ export const UNRECOVERABLE_WS_CLOSE_CODES = new Set<number>([
4345
* These appear during socket creation and should stop reconnection attempts.
4446
*/
4547
export const UNRECOVERABLE_HTTP_CODES = new Set<number>([
48+
HttpStatusCode.UNAUTHORIZED,
4649
HttpStatusCode.FORBIDDEN,
4750
HttpStatusCode.NOT_FOUND,
4851
HttpStatusCode.GONE,

src/websocket/reconnectingWebSocket.ts

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,17 @@ export class ReconnectingWebSocket<TData = unknown>
186186

187187
socket.addEventListener("error", (event) => {
188188
this.executeHandlers("error", event);
189+
190+
// Check for unrecoverable HTTP errors in the error event
191+
// HTTP errors during handshake fire 'error' then 'close' with 1006
192+
// We need to suspend here to prevent infinite reconnect loops
193+
const errorMessage = event.error?.message ?? event.message ?? "";
194+
if (this.isUnrecoverableHttpError(errorMessage)) {
195+
this.#logger.error(
196+
`Unrecoverable HTTP error for ${this.#apiRoute}: ${errorMessage}`,
197+
);
198+
this.suspend();
199+
}
189200
});
190201

191202
socket.addEventListener("close", (event) => {
@@ -289,12 +300,12 @@ export class ReconnectingWebSocket<TData = unknown>
289300
}
290301

291302
/**
292-
* Check if an error contains an unrecoverable HTTP status code.
303+
* Check if an error message contains an unrecoverable HTTP status code.
293304
*/
294305
private isUnrecoverableHttpError(error: unknown): boolean {
295-
const errorMessage = error instanceof Error ? error.message : String(error);
306+
const message = (error as { message?: string }).message || String(error);
296307
for (const code of UNRECOVERABLE_HTTP_CODES) {
297-
if (errorMessage.includes(String(code))) {
308+
if (message.includes(String(code))) {
298309
return true;
299310
}
300311
}

0 commit comments

Comments
 (0)