diff --git a/src/durable/notifications-hub.ts b/src/durable/notifications-hub.ts index b969e84..786f290 100644 --- a/src/durable/notifications-hub.ts +++ b/src/durable/notifications-hub.ts @@ -1,3 +1,4 @@ +import { DurableObject } from 'cloudflare:workers'; import type { Env } from '../types'; const SIGNALR_RECORD_SEPARATOR = 0x1e; @@ -6,11 +7,11 @@ const SIGNALR_UPDATE_TYPE_SYNC_VAULT = 5; const SIGNALR_UPDATE_TYPE_LOG_OUT = 11; const SIGNALR_UPDATE_TYPE_DEVICE_STATUS = 12; const SIGNALR_UPDATE_TYPE_BACKUP_RESTORE_PROGRESS = 13; -const SIGNALR_PING_INTERVAL_MS = 15_000; type HubProtocol = 'json' | 'messagepack'; -interface ConnectionState { +interface WsAttachment { + userId: string; handshakeComplete: boolean; protocol: HubProtocol; deviceIdentifier: string | null; @@ -145,10 +146,6 @@ function buildSignalRJsonInvocation( }) + String.fromCharCode(SIGNALR_RECORD_SEPARATOR); } -function buildSignalRJsonPing(): string { - return JSON.stringify({ type: 6 }) + String.fromCharCode(SIGNALR_RECORD_SEPARATOR); -} - function buildSignalRMessagePackInvocation( updateType: number, messagePayload: Record, @@ -172,24 +169,15 @@ function buildSignalRMessagePackInvocation( return frameSignalRBinary(encodedPayload); } -function buildSignalRMessagePackPing(): Uint8Array { - return frameSignalRBinary(encodeMsgPack([6])); -} - -function decodeIncomingMessage(data: string | ArrayBuffer | ArrayBufferView): string { - if (typeof data === 'string') return data; - if (data instanceof ArrayBuffer) return new TextDecoder().decode(new Uint8Array(data)); - return new TextDecoder().decode(new Uint8Array(data.buffer, data.byteOffset, data.byteLength)); -} - -export class NotificationsHub { - private readonly connections = new Map(); - private userId = ''; - private pingTimer: ReturnType | null = null; - - constructor(private readonly state: DurableObjectState, private readonly env: Env) { - void this.state; - void this.env; +export class NotificationsHub extends DurableObject { + constructor(ctx: DurableObjectState, env: Env) { + super(ctx, env); + this.ctx.setWebSocketAutoResponse( + new WebSocketRequestResponsePair( + JSON.stringify({ type: 6 }) + String.fromCharCode(SIGNALR_RECORD_SEPARATOR), + JSON.stringify({ type: 6 }) + String.fromCharCode(SIGNALR_RECORD_SEPARATOR) + ) + ); } async fetch(request: Request): Promise { @@ -205,14 +193,14 @@ export class NotificationsHub { payload?: Record | null; } | null; const revisionDate = String(body?.revisionDate || '').trim() || new Date().toISOString(); - this.userId = String(request.headers.get('X-NodeWarden-UserId') || body?.userId || this.userId).trim(); + const userId = String(request.headers.get('X-NodeWarden-UserId') || body?.userId || '').trim(); const contextId = String(body?.contextId || '').trim() || null; const updateType = Number(body?.updateType || SIGNALR_UPDATE_TYPE_SYNC_VAULT) || SIGNALR_UPDATE_TYPE_SYNC_VAULT; const targetDeviceIdentifier = String(body?.targetDeviceIdentifier || '').trim() || null; const payload = body?.payload && typeof body.payload === 'object' ? body.payload : { - UserId: this.userId, + UserId: userId, Date: revisionDate, }; this.broadcastMessage(updateType, payload, contextId, targetDeviceIdentifier); @@ -238,46 +226,27 @@ export class NotificationsHub { const requestUserId = String(url.searchParams.get('nw_uid') || '').trim(); const requestDeviceIdentifier = String(url.searchParams.get('nw_did') || '').trim() || null; - if (requestUserId) { - this.userId = requestUserId; - } - if (!this.userId) { + if (!requestUserId) { return new Response('Unauthorized', { status: 401 }); } const pair = new WebSocketPair(); const client = pair[0]; const server = pair[1]; - server.accept(); - this.connections.set(server, { + const tags: string[] = []; + if (requestDeviceIdentifier) { + tags.push(`device:${requestDeviceIdentifier}`); + } + this.ctx.acceptWebSocket(server, tags); + + server.serializeAttachment({ + userId: requestUserId, handshakeComplete: false, protocol: 'messagepack', deviceIdentifier: requestDeviceIdentifier, - }); - this.ensurePingLoop(); - - server.addEventListener('message', (event) => { - void this.handleSocketMessage(server, event.data); - }); - server.addEventListener('close', () => { - const shouldBroadcast = !!this.connections.get(server)?.handshakeComplete; - this.connections.delete(server); - this.stopPingLoopIfIdle(); - if (shouldBroadcast) this.broadcastDeviceStatus(); - }); - server.addEventListener('error', () => { - const shouldBroadcast = !!this.connections.get(server)?.handshakeComplete; - this.connections.delete(server); - this.stopPingLoopIfIdle(); - if (shouldBroadcast) this.broadcastDeviceStatus(); - try { - server.close(1011, 'Socket error'); - } catch { - // ignore close races - } - }); + } satisfies WsAttachment); return new Response(null, { status: 101, @@ -285,21 +254,23 @@ export class NotificationsHub { }); } - private async handleSocketMessage(socket: WebSocket, rawData: string | ArrayBuffer | ArrayBufferView): Promise { - const connection = this.connections.get(socket); - if (!connection) return; + async webSocketMessage(ws: WebSocket, message: string | ArrayBuffer): Promise { + const attachment = ws.deserializeAttachment() as WsAttachment | null; + if (!attachment) return; - if (!connection.handshakeComplete) { - const text = decodeIncomingMessage(rawData); + if (!attachment.handshakeComplete) { + const text = typeof message === 'string' + ? message + : new TextDecoder().decode(new Uint8Array(message)); const frames = text.split(String.fromCharCode(SIGNALR_RECORD_SEPARATOR)).filter(Boolean); for (const frame of frames) { try { const handshake = JSON.parse(frame) as { protocol?: string }; - const protocol = handshake.protocol === 'json' ? 'json' : 'messagepack'; - connection.protocol = protocol; - connection.handshakeComplete = true; - socket.send(SIGNALR_HANDSHAKE_ACK); - this.broadcastDeviceStatus(); + attachment.protocol = handshake.protocol === 'json' ? 'json' : 'messagepack'; + attachment.handshakeComplete = true; + ws.serializeAttachment(attachment); + ws.send(SIGNALR_HANDSHAKE_ACK); + this.broadcastDeviceStatus(attachment.userId); return; } catch { // Ignore malformed pre-handshake payloads. @@ -307,53 +278,48 @@ export class NotificationsHub { } return; } - } - private ensurePingLoop(): void { - if (this.pingTimer !== null) return; - this.pingTimer = setInterval(() => { - this.broadcastPing(); - }, SIGNALR_PING_INTERVAL_MS); - } - - private stopPingLoopIfIdle(): void { - if (this.connections.size > 0 || this.pingTimer === null) return; - clearInterval(this.pingTimer); - this.pingTimer = null; - } - - private broadcastPing(): void { - if (this.connections.size === 0) { - this.stopPingLoopIfIdle(); - return; - } - - for (const [socket, connection] of this.connections) { - if (!connection.handshakeComplete) continue; + if (message instanceof ArrayBuffer) { try { - if (connection.protocol === 'json') { - socket.send(buildSignalRJsonPing()); - } else { - socket.send(buildSignalRMessagePackPing()); - } + ws.send(message); } catch { - this.connections.delete(socket); - try { - socket.close(1011, 'Ping send failed'); - } catch { - // ignore close races - } + // ignore send errors on echo } } + } - this.stopPingLoopIfIdle(); + async webSocketClose(ws: WebSocket, code: number, reason: string, wasClean: boolean): Promise { + const attachment = ws.deserializeAttachment() as WsAttachment | null; + const shouldBroadcast = !!attachment?.handshakeComplete; + try { + ws.close(code, 'Durable Object is closing WebSocket'); + } catch { + // ignore close races + } + if (shouldBroadcast && attachment?.userId) { + this.broadcastDeviceStatus(attachment.userId); + } + } + + async webSocketError(ws: WebSocket, error: unknown): Promise { + const attachment = ws.deserializeAttachment() as WsAttachment | null; + const shouldBroadcast = !!attachment?.handshakeComplete; + try { + ws.close(1011, 'Socket error'); + } catch { + // ignore close races + } + if (shouldBroadcast && attachment?.userId) { + this.broadcastDeviceStatus(attachment.userId); + } } private getOnlineDeviceIdentifiers(): string[] { const out = new Set(); - for (const connection of this.connections.values()) { - if (!connection.handshakeComplete || !connection.deviceIdentifier) continue; - out.add(connection.deviceIdentifier); + for (const ws of this.ctx.getWebSockets()) { + const attachment = ws.deserializeAttachment() as WsAttachment | null; + if (!attachment?.handshakeComplete || !attachment.deviceIdentifier) continue; + out.add(attachment.deviceIdentifier); } return Array.from(out); } @@ -364,35 +330,36 @@ export class NotificationsHub { contextId: string | null, targetDeviceIdentifier: string | null ): void { - if (!this.userId || this.connections.size === 0) return; + const sockets = targetDeviceIdentifier + ? this.ctx.getWebSockets(`device:${targetDeviceIdentifier}`) + : this.ctx.getWebSockets(); - for (const [socket, connection] of this.connections) { - if (!connection.handshakeComplete) continue; - if (targetDeviceIdentifier && connection.deviceIdentifier !== targetDeviceIdentifier) continue; + if (sockets.length === 0) return; + + for (const ws of sockets) { + const attachment = ws.deserializeAttachment() as WsAttachment | null; + if (!attachment?.handshakeComplete) continue; try { - if (connection.protocol === 'json') { - socket.send(buildSignalRJsonInvocation(updateType, payload, contextId)); + if (attachment.protocol === 'json') { + ws.send(buildSignalRJsonInvocation(updateType, payload, contextId)); } else { - socket.send(buildSignalRMessagePackInvocation(updateType, payload, contextId)); + ws.send(buildSignalRMessagePackInvocation(updateType, payload, contextId)); } } catch { - this.connections.delete(socket); try { - socket.close(1011, 'Notification send failed'); + ws.close(1011, 'Notification send failed'); } catch { // ignore close races } } } - - this.stopPingLoopIfIdle(); } - private broadcastDeviceStatus(): void { + private broadcastDeviceStatus(userId: string): void { this.broadcastMessage( SIGNALR_UPDATE_TYPE_DEVICE_STATUS, { - UserId: this.userId, + UserId: userId, Date: new Date().toISOString(), }, null, diff --git a/webapp/src/App.tsx b/webapp/src/App.tsx index 7abf002..d10d6ea 100644 --- a/webapp/src/App.tsx +++ b/webapp/src/App.tsx @@ -884,6 +884,15 @@ export default function App() { return; } + let pingTimer: number | null = null; + + const clearPingTimer = () => { + if (pingTimer !== null) { + window.clearInterval(pingTimer); + pingTimer = null; + } + }; + socket.addEventListener('open', () => { reconnectAttempts = 0; void refreshAuthorizedDevicesRef.current(); @@ -891,7 +900,16 @@ export default function App() { socket?.send(`{"protocol":"json","version":1}${SIGNALR_RECORD_SEPARATOR}`); } catch { socket?.close(); + return; } + clearPingTimer(); + pingTimer = window.setInterval(() => { + try { + socket?.send(`{"type":6}${SIGNALR_RECORD_SEPARATOR}`); + } catch { + // send failure will trigger close event + } + }, 15_000); }); socket.addEventListener('message', (event) => { @@ -934,6 +952,7 @@ export default function App() { socket.addEventListener('close', () => { socket = null; + clearPingTimer(); void refreshAuthorizedDevicesRef.current(); scheduleReconnect(); }); @@ -952,9 +971,11 @@ export default function App() { return () => { disposed = true; clearReconnectTimer(); - if (socket && socket.readyState === WebSocket.OPEN) { + if (socket) { + const s = socket; + socket = null; try { - socket.close(); + s.close(); } catch { // ignore close races }