
Commit 7afe6c8

fix(coderd): ensure agent WebSocket conn is cleaned up (coder#19711) (coder#20094)
Co-authored-by: Danielle Maywood <danielle@themaywoods.com>
1 parent 5369204 commit 7afe6c8
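The gist of the fix: the `/containers/watch` handler did not reliably cancel its context when the client's WebSocket went away, so the agent-side container watcher could be left running (see coder#19711). The diff below derives an explicit cancelable context for the request and hands its cancel function to the heartbeat, so losing the client tears the watcher down. A minimal, illustrative sketch of that pattern follows; it is not Coder's code, and the function name, interval, and callbacks (`ping` standing in for the WebSocket ping, `watch` for the `WatchContainers` loop) are assumptions for illustration only.

```go
package main

import (
	"context"
	"time"
)

// watchWithCleanup sketches the cleanup pattern: derive a cancelable context
// for the watch, and have a heartbeat goroutine cancel it as soon as the
// client stops responding, so the agent-side watcher is always released.
func watchWithCleanup(ctx context.Context, ping func(context.Context) error, watch func(context.Context)) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	go func() {
		defer cancel() // stopping the heartbeat stops the watch too

		ticker := time.NewTicker(15 * time.Second) // interval is an assumption
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				if err := ping(ctx); err != nil {
					return // client gone; the deferred cancel ends the watch
				}
			}
		}
	}()

	watch(ctx) // must return once ctx is canceled, releasing its resources
}
```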

File tree: 2 files changed (+144, −6 lines)

coderd/workspaceagents.go

Lines changed: 8 additions & 5 deletions
@@ -817,12 +817,13 @@ func (api *API) watchWorkspaceAgentContainers(rw http.ResponseWriter, r *http.Re
 	var (
 		ctx            = r.Context()
 		workspaceAgent = httpmw.WorkspaceAgentParam(r)
+		logger         = api.Logger.Named("agent_container_watcher").With(slog.F("agent_id", workspaceAgent.ID))
 	)

 	// If the agent is unreachable, the request will hang. Assume that if we
 	// don't get a response after 30s that the agent is unreachable.
-	dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
-	defer cancel()
+	dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
+	defer dialCancel()
 	apiAgent, err := db2sdk.WorkspaceAgent(
 		api.DERPMap(),
 		*api.TailnetCoordinator.Load(),
@@ -857,8 +858,7 @@ func (api *API) watchWorkspaceAgentContainers(rw http.ResponseWriter, r *http.Re
 	}
 	defer release()

-	watcherLogger := api.Logger.Named("agent_container_watcher").With(slog.F("agent_id", workspaceAgent.ID))
-	containersCh, closer, err := agentConn.WatchContainers(ctx, watcherLogger)
+	containersCh, closer, err := agentConn.WatchContainers(ctx, logger)
 	if err != nil {
 		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
 			Message: "Internal error watching agent's containers.",
@@ -877,14 +877,17 @@ func (api *API) watchWorkspaceAgentContainers(rw http.ResponseWriter, r *http.Re
 		return
 	}

+	ctx, cancel := context.WithCancel(r.Context())
+	defer cancel()
+
 	// Here we close the websocket for reading, so that the websocket library will handle pings and
 	// close frames.
 	_ = conn.CloseRead(context.Background())

 	ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
 	defer wsNetConn.Close()

-	go httpapi.Heartbeat(ctx, conn)
+	go httpapi.HeartbeatClose(ctx, logger, cancel, conn)

 	encoder := json.NewEncoder(wsNetConn)
coderd/workspaceagents_internal_test.go

Lines changed: 136 additions & 1 deletion
@@ -59,10 +59,145 @@ func (fakeAgentProvider) Close() error {
 	return nil
 }

+type channelCloser struct {
+	closeFn func()
+}
+
+func (c *channelCloser) Close() error {
+	c.closeFn()
+	return nil
+}
+
 func TestWatchAgentContainers(t *testing.T) {
 	t.Parallel()

-	t.Run("WebSocketClosesProperly", func(t *testing.T) {
+	t.Run("CoderdWebSocketCanHandleClientClosing", func(t *testing.T) {
+		t.Parallel()
+
+		// This test ensures that the agent containers `/watch` websocket can gracefully
+		// handle the client websocket closing. This test was created in
+		// response to this issue: https://github.com/coder/coder/issues/19449
+
+		var (
+			ctx    = testutil.Context(t, testutil.WaitLong)
+			logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug).Named("coderd")
+
+			mCtrl        = gomock.NewController(t)
+			mDB          = dbmock.NewMockStore(mCtrl)
+			mCoordinator = tailnettest.NewMockCoordinator(mCtrl)
+			mAgentConn   = agentconnmock.NewMockAgentConn(mCtrl)
+
+			fAgentProvider = fakeAgentProvider{
+				agentConn: func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) {
+					return mAgentConn, func() {}, nil
+				},
+			}
+
+			workspaceID = uuid.New()
+			agentID     = uuid.New()
+			resourceID  = uuid.New()
+			jobID       = uuid.New()
+			buildID     = uuid.New()
+
+			containersCh = make(chan codersdk.WorkspaceAgentListContainersResponse)
+
+			r = chi.NewMux()
+
+			api = API{
+				ctx: ctx,
+				Options: &Options{
+					AgentInactiveDisconnectTimeout: testutil.WaitShort,
+					Database:                       mDB,
+					Logger:                         logger,
+					DeploymentValues:               &codersdk.DeploymentValues{},
+					TailnetCoordinator:             tailnettest.NewFakeCoordinator(),
+				},
+			}
+		)
+
+		var tailnetCoordinator tailnet.Coordinator = mCoordinator
+		api.TailnetCoordinator.Store(&tailnetCoordinator)
+		api.agentProvider = fAgentProvider
+
+		// Setup: Allow `ExtractWorkspaceAgentParams` to complete.
+		mDB.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(database.WorkspaceAgent{
+			ID:               agentID,
+			ResourceID:       resourceID,
+			LifecycleState:   database.WorkspaceAgentLifecycleStateReady,
+			FirstConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
+			LastConnectedAt:  sql.NullTime{Valid: true, Time: dbtime.Now()},
+		}, nil)
+		mDB.EXPECT().GetWorkspaceResourceByID(gomock.Any(), resourceID).Return(database.WorkspaceResource{
+			ID:    resourceID,
+			JobID: jobID,
+		}, nil)
+		mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), jobID).Return(database.ProvisionerJob{
+			ID:   jobID,
+			Type: database.ProvisionerJobTypeWorkspaceBuild,
+		}, nil)
+		mDB.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), jobID).Return(database.WorkspaceBuild{
+			WorkspaceID: workspaceID,
+			ID:          buildID,
+		}, nil)
+
+		// And: Allow `db2sdk.WorkspaceAgent` to complete.
+		mCoordinator.EXPECT().Node(gomock.Any()).Return(nil)
+
+		// And: Allow `WatchContainers` to be called, returning our `containersCh` channel.
+		mAgentConn.EXPECT().WatchContainers(gomock.Any(), gomock.Any()).
+			DoAndReturn(func(_ context.Context, _ slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) {
+				return containersCh, &channelCloser{closeFn: func() {
+					close(containersCh)
+				}}, nil
+			})
+
+		// And: We mount the HTTP Handler
+		r.With(httpmw.ExtractWorkspaceAgentParam(mDB)).
+			Get("/workspaceagents/{workspaceagent}/containers/watch", api.watchWorkspaceAgentContainers)
+
+		// Given: We create the HTTP server
+		srv := httptest.NewServer(r)
+		defer srv.Close()
+
+		// And: Dial the WebSocket
+		wsURL := strings.Replace(srv.URL, "http://", "ws://", 1)
+		conn, resp, err := websocket.Dial(ctx, fmt.Sprintf("%s/workspaceagents/%s/containers/watch", wsURL, agentID), nil)
+		require.NoError(t, err)
+		if resp.Body != nil {
+			defer resp.Body.Close()
+		}
+
+		// And: Create a streaming decoder
+		decoder := wsjson.NewDecoder[codersdk.WorkspaceAgentListContainersResponse](conn, websocket.MessageText, logger)
+		defer decoder.Close()
+		decodeCh := decoder.Chan()
+
+		// And: We can successfully send through the channel.
+		testutil.RequireSend(ctx, t, containersCh, codersdk.WorkspaceAgentListContainersResponse{
+			Containers: []codersdk.WorkspaceAgentContainer{{
+				ID: "test-container-id",
+			}},
+		})
+
+		// And: Receive the data.
+		containerResp := testutil.RequireReceive(ctx, t, decodeCh)
+		require.Len(t, containerResp.Containers, 1)
+		require.Equal(t, "test-container-id", containerResp.Containers[0].ID)
+
+		// When: We close the WebSocket
+		conn.Close(websocket.StatusNormalClosure, "test closing connection")
+
+		// Then: We expect `containersCh` to be closed.
+		select {
+		case <-ctx.Done():
+			t.Fail()
+
+		case _, ok := <-containersCh:
+			require.False(t, ok, "channel is expected to be closed")
+		}
+	})
+
+	t.Run("CoderdWebSocketCanHandleAgentClosing", func(t *testing.T) {
 		t.Parallel()

 		// This test ensures that the agent containers `/watch` websocket can gracefully