diff --git a/src/bun.js/api/server/NodeHTTPResponse.zig b/src/bun.js/api/server/NodeHTTPResponse.zig index 408693cae2f00b..1769ba630a14ef 100644 --- a/src/bun.js/api/server/NodeHTTPResponse.zig +++ b/src/bun.js/api/server/NodeHTTPResponse.zig @@ -122,13 +122,37 @@ pub fn getServerSocketValue(this: *NodeHTTPResponse) jsc.JSValue { pub fn pauseSocket(this: *NodeHTTPResponse) void { log("pauseSocket", .{}); + if (this.flags.socket_closed or this.flags.upgraded) { + return; + } this.raw_response.pause(); } pub fn resumeSocket(this: *NodeHTTPResponse) void { log("resumeSocket", .{}); + if (this.flags.socket_closed or this.flags.upgraded) { + return; + } this.raw_response.@"resume"(); } + +const OnBeforeOpen = struct { + this: *NodeHTTPResponse, + socketValue: jsc.JSValue, + globalObject: *jsc.JSGlobalObject, + + pub fn onBeforeOpen(ctx: *OnBeforeOpen, js_websocket: JSValue, socket: *uws.RawWebSocket) void { + Bun__setNodeHTTPServerSocketUsSocketValue(ctx.socketValue, socket.asSocket()); + ServerWebSocket.js.gc.socket.set(js_websocket, ctx.globalObject, ctx.socketValue); + ctx.this.flags.upgraded = true; + defer ctx.this.js_ref.unref(ctx.globalObject.bunVM()); + switch (ctx.this.raw_response) { + .SSL => ctx.this.raw_response = uws.AnyResponse.init(uws.NewApp(true).Response.castRes(@alignCast(@ptrCast(socket)))), + .TCP => ctx.this.raw_response = uws.AnyResponse.init(uws.NewApp(false).Response.castRes(@alignCast(@ptrCast(socket)))), + } + } +}; + pub fn upgrade(this: *NodeHTTPResponse, data_value: JSValue, sec_websocket_protocol: ZigString, sec_websocket_extensions: ZigString) bool { const upgrade_ctx = this.upgrade_context.context orelse return false; const ws_handler = this.server.webSocketHandler() orelse return false; @@ -149,61 +173,18 @@ pub fn upgrade(this: *NodeHTTPResponse, data_value: JSValue, sec_websocket_proto .this_value = data_value, }); - var new_socket: ?*uws.Socket = null; - defer if (new_socket) |socket| { - this.flags.upgraded = true; - Bun__setNodeHTTPServerSocketUsSocketValue(socketValue, socket); - ServerWebSocket.js.socketSetCached(ws.getThisValue(), ws_handler.globalObject, socketValue); - defer this.js_ref.unref(jsc.VirtualMachine.get()); - switch (this.raw_response) { - .SSL => this.raw_response = uws.AnyResponse.init(uws.NewApp(true).Response.castRes(@alignCast(@ptrCast(socket)))), - .TCP => this.raw_response = uws.AnyResponse.init(uws.NewApp(false).Response.castRes(@alignCast(@ptrCast(socket)))), - } - }; - - if (this.upgrade_context.request) |request| { - this.upgrade_context = .{}; - - var sec_websocket_protocol_str: ?ZigString.Slice = null; - var sec_websocket_extensions_str: ?ZigString.Slice = null; - - const sec_websocket_protocol_value = brk: { - if (sec_websocket_protocol.isEmpty()) { - break :brk request.header("sec-websocket-protocol") orelse ""; - } - sec_websocket_protocol_str = sec_websocket_protocol.toSlice(bun.default_allocator); - break :brk sec_websocket_protocol_str.?.slice(); - }; - - const sec_websocket_extensions_value = brk: { - if (sec_websocket_extensions.isEmpty()) { - break :brk request.header("sec-websocket-extensions") orelse ""; - } - sec_websocket_extensions_str = sec_websocket_protocol.toSlice(bun.default_allocator); - break :brk sec_websocket_extensions_str.?.slice(); - }; - defer { - if (sec_websocket_protocol_str) |str| str.deinit(); - if (sec_websocket_extensions_str) |str| str.deinit(); - } - - new_socket = this.raw_response.upgrade( - *ServerWebSocket, - ws, - request.header("sec-websocket-key") orelse "", - sec_websocket_protocol_value, - sec_websocket_extensions_value, - upgrade_ctx, - ); - return true; - } - var sec_websocket_protocol_str: ?ZigString.Slice = null; + defer if (sec_websocket_protocol_str) |*str| str.deinit(); var sec_websocket_extensions_str: ?ZigString.Slice = null; + defer if (sec_websocket_extensions_str) |*str| str.deinit(); const sec_websocket_protocol_value = brk: { if (sec_websocket_protocol.isEmpty()) { - break :brk this.upgrade_context.sec_websocket_protocol; + if (this.upgrade_context.request) |request| { + break :brk request.header("sec-websocket-protocol") orelse ""; + } else { + break :brk this.upgrade_context.sec_websocket_protocol; + } } sec_websocket_protocol_str = sec_websocket_protocol.toSlice(bun.default_allocator); break :brk sec_websocket_protocol_str.?.slice(); @@ -211,35 +192,48 @@ pub fn upgrade(this: *NodeHTTPResponse, data_value: JSValue, sec_websocket_proto const sec_websocket_extensions_value = brk: { if (sec_websocket_extensions.isEmpty()) { - break :brk this.upgrade_context.sec_websocket_extensions; + if (this.upgrade_context.request) |request| { + break :brk request.header("sec-websocket-extensions") orelse ""; + } else { + break :brk this.upgrade_context.sec_websocket_extensions; + } } - sec_websocket_extensions_str = sec_websocket_protocol.toSlice(bun.default_allocator); + sec_websocket_extensions_str = sec_websocket_extensions.toSlice(bun.default_allocator); break :brk sec_websocket_extensions_str.?.slice(); }; - defer { - if (sec_websocket_protocol_str) |str| str.deinit(); - if (sec_websocket_extensions_str) |str| str.deinit(); - } - - new_socket = this.raw_response.upgrade( - *ServerWebSocket, - ws, - this.upgrade_context.sec_websocket_key, - sec_websocket_protocol_value, - sec_websocket_extensions_value, - upgrade_ctx, - ); + + const websocket_key = if (this.upgrade_context.request) |request| + request.header("sec-websocket-key") orelse "" + else + this.upgrade_context.sec_websocket_key; + + var on_before_open = OnBeforeOpen{ + .this = this, + .socketValue = socketValue, + .globalObject = this.server.globalThis(), + }; + var on_before_open_ptr = WebSocketServerContext.Handler.OnBeforeOpen{ + .ctx = &on_before_open, + .callback = @ptrCast(&OnBeforeOpen.onBeforeOpen), + }; + + this.server.webSocketHandler().?.onBeforeOpen = &on_before_open_ptr; + _ = this.raw_response.upgrade(*ServerWebSocket, ws, websocket_key, sec_websocket_protocol_value, sec_websocket_extensions_value, upgrade_ctx); + return true; } pub fn maybeStopReadingBody(this: *NodeHTTPResponse, vm: *jsc.VirtualMachine, thisValue: jsc.JSValue) void { this.upgrade_context.deinit(); // we can discard the upgrade context now - if ((this.flags.socket_closed or this.flags.ended) and + if ((this.flags.upgraded or this.flags.socket_closed or this.flags.ended) and (this.body_read_ref.has or this.body_read_state == .pending) and (!this.flags.hasCustomOnData or js.onDataGetCached(thisValue) == null)) { const had_ref = this.body_read_ref.has; - this.raw_response.clearOnData(); + if (!this.flags.upgraded and !this.flags.socket_closed) { + this.raw_response.clearOnData(); + } + this.body_read_ref.unref(vm); this.body_read_state = .done; @@ -578,7 +572,7 @@ pub fn onTimeout(this: *NodeHTTPResponse, _: uws.AnyResponse) void { pub fn doPause(this: *NodeHTTPResponse, _: *jsc.JSGlobalObject, _: *jsc.CallFrame, thisValue: jsc.JSValue) bun.JSError!jsc.JSValue { log("doPause", .{}); - if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended) { + if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended or this.flags.upgraded) { return .false; } if (this.body_read_ref.has and js.onDataGetCached(thisValue) == null) { @@ -608,7 +602,7 @@ fn drainBufferedRequestBodyFromPause(this: *NodeHTTPResponse, globalObject: *jsc pub fn doResume(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, _: *jsc.CallFrame) jsc.JSValue { log("doResume", .{}); - if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended) { + if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended or this.flags.upgraded) { return .false; } @@ -671,7 +665,7 @@ pub export fn Bun__NodeHTTPRequest__onReject(globalObject: *jsc.JSGlobalObject, defer this.deref(); - if (!this.flags.request_has_completed and !this.flags.socket_closed) { + if (!this.flags.request_has_completed and !this.flags.socket_closed and !this.flags.upgraded) { const this_value = this.getThisValue(); if (this_value != .zero) { js.onAbortedSetCached(this_value, globalObject, .zero); @@ -787,7 +781,7 @@ fn onDrain(this: *NodeHTTPResponse, offset: u64, response: uws.AnyResponse) bool this.ref(); defer this.deref(); response.clearOnWritable(); - if (this.flags.socket_closed or this.flags.request_has_completed) { + if (this.flags.socket_closed or this.flags.request_has_completed or this.flags.upgraded) { // return false means we don't have anything to drain return false; } @@ -963,14 +957,14 @@ pub fn getOnWritable(_: *NodeHTTPResponse, thisValue: jsc.JSValue, _: *jsc.JSGlo } pub fn getOnAbort(this: *NodeHTTPResponse, thisValue: jsc.JSValue, _: *jsc.JSGlobalObject) jsc.JSValue { - if (this.flags.socket_closed) { + if (this.flags.socket_closed or this.flags.upgraded) { return .js_undefined; } return js.onAbortedGetCached(thisValue) orelse .js_undefined; } pub fn setOnAbort(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject, value: JSValue) void { - if (this.flags.socket_closed) { + if (this.flags.socket_closed or this.flags.upgraded) { return; } @@ -1002,7 +996,7 @@ fn clearOnDataCallback(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalOb if (thisValue != .zero) { js.onDataSetCached(thisValue, globalObject, .js_undefined); } - if (!this.flags.socket_closed) + if (!this.flags.socket_closed and !this.flags.upgraded) this.raw_response.clearOnData(); if (this.body_read_state != .done) { this.body_read_state = .done; @@ -1011,7 +1005,7 @@ fn clearOnDataCallback(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalOb } pub fn setOnData(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject, value: JSValue) void { - if (value.isUndefined() or this.flags.ended or this.flags.socket_closed or this.body_read_state == .none or this.flags.is_data_buffered_during_pause_last) { + if (value.isUndefined() or this.flags.ended or this.flags.socket_closed or this.body_read_state == .none or this.flags.is_data_buffered_during_pause_last or this.flags.upgraded) { js.onDataSetCached(thisValue, globalObject, .js_undefined); defer { if (this.body_read_ref.has) { @@ -1020,7 +1014,7 @@ pub fn setOnData(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalObject: } switch (this.body_read_state) { .pending, .done => { - if (!this.flags.request_has_completed and !this.flags.socket_closed) { + if (!this.flags.request_has_completed and !this.flags.socket_closed and !this.flags.upgraded) { this.raw_response.clearOnData(); } this.body_read_state = .done; @@ -1048,7 +1042,7 @@ pub fn write(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, callfra } pub fn flushHeaders(this: *NodeHTTPResponse, _: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!jsc.JSValue { - if (!this.flags.socket_closed) + if (!this.flags.socket_closed and !this.flags.upgraded) this.raw_response.flushHeaders(); return .js_undefined; @@ -1074,7 +1068,7 @@ fn handleCorked(globalObject: *jsc.JSGlobalObject, function: jsc.JSValue, result } pub fn setTimeout(this: *NodeHTTPResponse, seconds: u8) void { - if (this.flags.request_has_completed or this.flags.socket_closed) { + if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.upgraded) { return; } @@ -1087,7 +1081,7 @@ export fn NodeHTTPResponse__setTimeout(this: *NodeHTTPResponse, seconds: jsc.JSV return false; } - if (this.flags.request_has_completed or this.flags.socket_closed) { + if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.upgraded) { return false; } @@ -1105,7 +1099,7 @@ pub fn cork(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, callfram return globalObject.throwInvalidArgumentTypeValue("cork", "function", arguments[0]); } - if (this.flags.request_has_completed or this.flags.socket_closed) { + if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.upgraded) { return globalObject.ERR(.STREAM_ALREADY_FINISHED, "Stream is already ended", .{}).throw(); } @@ -1163,6 +1157,7 @@ pub export fn Bun__NodeHTTPResponse_setClosed(response: *NodeHTTPResponse) void const string = []const u8; +const WebSocketServerContext = @import("./WebSocketServerContext.zig"); const std = @import("std"); const bun = @import("bun"); diff --git a/src/bun.js/api/server/ServerWebSocket.zig b/src/bun.js/api/server/ServerWebSocket.zig index d2b3160fa932bc..18c3f682968e85 100644 --- a/src/bun.js/api/server/ServerWebSocket.zig +++ b/src/bun.js/api/server/ServerWebSocket.zig @@ -73,9 +73,20 @@ pub fn onOpen(this: *ServerWebSocket, ws: uws.AnyWebSocket) void { js.dataSetCached(current_this, globalObject, value_to_cache); } - if (onOpenHandler.isEmptyOrUndefinedOrNull()) return; + if (onOpenHandler.isEmptyOrUndefinedOrNull()) { + if (bun.take(&this.handler.onBeforeOpen)) |on_before_open| { + // Only create the "this" value if needed. + const this_value = this.getThisValue(); + on_before_open.callback(on_before_open.ctx, this_value, ws.raw()); + } + return; + } + const this_value = this.getThisValue(); var args = [_]JSValue{this_value}; + if (bun.take(&this.handler.onBeforeOpen)) |on_before_open| { + on_before_open.callback(on_before_open.ctx, this_value, ws.raw()); + } const loop = vm.eventLoop(); loop.enter(); diff --git a/src/bun.js/api/server/WebSocketServerContext.zig b/src/bun.js/api/server/WebSocketServerContext.zig index 8b92eddf68b29c..694715916fe92f 100644 --- a/src/bun.js/api/server/WebSocketServerContext.zig +++ b/src/bun.js/api/server/WebSocketServerContext.zig @@ -28,12 +28,25 @@ pub const Handler = struct { globalObject: *jsc.JSGlobalObject = undefined, active_connections: usize = 0, + /// Only used by NodeHTTPResponse. + /// + /// Before we call into JavaScript and after the WebSocket is upgraded, we need to call a function in NodeHTTPResponse. + /// + /// This is per-ServerWebSocket data, so it needs to be null'd on usage. + onBeforeOpen: ?*OnBeforeOpen = null, + /// used by publish() - flags: packed struct(u2) { + flags: packed struct(u8) { ssl: bool = false, publish_to_self: bool = false, + _: u6 = 0, } = .{}, + pub const OnBeforeOpen = struct { + ctx: *anyopaque, + callback: *const fn (*anyopaque, this_value: jsc.JSValue, socket: *uws.RawWebSocket) void, + }; + pub fn runErrorCallback(this: *const Handler, vm: *jsc.VirtualMachine, globalObject: *jsc.JSGlobalObject, error_value: jsc.JSValue) void { const onError = this.onError; if (!onError.isEmptyOrUndefinedOrNull()) { diff --git a/src/deps/uws/WebSocket.zig b/src/deps/uws/WebSocket.zig index 9a3c976f526196..31ea914e75e3fe 100644 --- a/src/deps/uws/WebSocket.zig +++ b/src/deps/uws/WebSocket.zig @@ -72,6 +72,15 @@ pub const RawWebSocket = opaque { pub fn memoryCost(this: *RawWebSocket, ssl_flag: i32) usize { return c.uws_ws_memory_cost(ssl_flag, this); } + + /// They're the same memory address. + /// + /// Equivalent to: + /// + /// (struct us_socket_t *)socket + pub fn asSocket(this: *RawWebSocket) *uws.Socket { + return @as(*uws.Socket, @ptrCast(this)); + } }; pub const AnyWebSocket = union(enum) { diff --git a/src/js/thirdparty/ws.js b/src/js/thirdparty/ws.js index 5cef4c4c2266e7..024a2f5d3d7489 100644 --- a/src/js/thirdparty/ws.js +++ b/src/js/thirdparty/ws.js @@ -1288,7 +1288,7 @@ class WebSocketServer extends EventEmitter { */ handleUpgrade(req, socket, head, cb) { // socket is actually fake so we use internal http_res - const response = socket._httpMessage; + const response = socket._httpMessage || socket[kBunInternals]; // socket.on("error", socketOnError); diff --git a/test/js/first_party/ws/ws.test.ts b/test/js/first_party/ws/ws.test.ts index a9a244ceb41151..679ae04aa505c0 100644 --- a/test/js/first_party/ws/ws.test.ts +++ b/test/js/first_party/ws/ws.test.ts @@ -553,6 +553,88 @@ it("WebSocketServer should handle backpressure", async () => { } }); +it("should abort incorrect WebSocket handshake", async () => { + const { promise, resolve, reject } = Promise.withResolvers(); + const wss = new WebSocketServer({ port: 0 }); + let connectionAttempted = false; + let testResolved = false; + + wss.on("connection", () => { + connectionAttempted = true; + if (!testResolved) { + testResolved = true; + reject(new Error("Connection should not have been established")); + } + }); + + wss.on("error", error => { + // Server errors are expected for invalid handshakes + console.log("Server error (expected):", error.message); + }); + + try { + const net = require("node:net"); + const port = (wss.address() as any).port; + const socket = net.createConnection(port, "localhost"); + + socket.on("connect", () => { + // Send an invalid WebSocket handshake request (invalid Sec-WebSocket-Key) + const invalidRequest = [ + "GET / HTTP/1.1", + "Host: localhost", + "Connection: Upgrade", + "Upgrade: websocket", + "Sec-WebSocket-Key: invalid-key", // Invalid key format + "Sec-WebSocket-Version: 13", + "", + "", + ].join("\r\n"); + + socket.write(invalidRequest); + }); + + let responseReceived = false; + socket.on("data", data => { + const response = data.toString(); + responseReceived = true; + + // Should receive a 400 Bad Request response for invalid handshake + if (response.includes("400") && !testResolved) { + testResolved = true; + resolve(); + } else if (!testResolved) { + testResolved = true; + reject(new Error(`Expected 400 response, got: ${response}`)); + } + socket.end(); + }); + + socket.on("error", error => { + // Connection errors are also acceptable as the server may close the connection + if (!testResolved) { + testResolved = true; + resolve(); + } + }); + + socket.on("close", () => { + // If we reach here without getting a proper response and connection wasn't attempted, + // the server properly rejected the invalid handshake + if (!responseReceived && !connectionAttempted && !testResolved) { + testResolved = true; + resolve(); + } + }); + + await promise; + } finally { + wss.close(); + } + + expect(connectionAttempted).toBeFalse(); + expect(testResolved).toBeTrue(); +}); + it("Server should be able to send empty pings", async () => { // WebSocket frame creation function with masking function createWebSocketFrame(message: string) {