Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 75 additions & 80 deletions src/bun.js/api/server/NodeHTTPResponse.zig
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,37 @@ pub fn getServerSocketValue(this: *NodeHTTPResponse) jsc.JSValue {

pub fn pauseSocket(this: *NodeHTTPResponse) void {
log("pauseSocket", .{});
if (this.flags.socket_closed or this.flags.upgraded) {
return;
}
this.raw_response.pause();
}

pub fn resumeSocket(this: *NodeHTTPResponse) void {
log("resumeSocket", .{});
if (this.flags.socket_closed or this.flags.upgraded) {
return;
}
this.raw_response.@"resume"();
}

const OnBeforeOpen = struct {
this: *NodeHTTPResponse,
socketValue: jsc.JSValue,
globalObject: *jsc.JSGlobalObject,

pub fn onBeforeOpen(ctx: *OnBeforeOpen, js_websocket: JSValue, socket: *uws.RawWebSocket) void {
Bun__setNodeHTTPServerSocketUsSocketValue(ctx.socketValue, socket.asSocket());
ServerWebSocket.js.gc.socket.set(js_websocket, ctx.globalObject, ctx.socketValue);
ctx.this.flags.upgraded = true;
defer ctx.this.js_ref.unref(ctx.globalObject.bunVM());
switch (ctx.this.raw_response) {
.SSL => ctx.this.raw_response = uws.AnyResponse.init(uws.NewApp(true).Response.castRes(@alignCast(@ptrCast(socket)))),
.TCP => ctx.this.raw_response = uws.AnyResponse.init(uws.NewApp(false).Response.castRes(@alignCast(@ptrCast(socket)))),
}
}
};

pub fn upgrade(this: *NodeHTTPResponse, data_value: JSValue, sec_websocket_protocol: ZigString, sec_websocket_extensions: ZigString) bool {
const upgrade_ctx = this.upgrade_context.context orelse return false;
const ws_handler = this.server.webSocketHandler() orelse return false;
Expand All @@ -149,97 +173,67 @@ pub fn upgrade(this: *NodeHTTPResponse, data_value: JSValue, sec_websocket_proto
.this_value = data_value,
});

var new_socket: ?*uws.Socket = null;
defer if (new_socket) |socket| {
this.flags.upgraded = true;
Bun__setNodeHTTPServerSocketUsSocketValue(socketValue, socket);
ServerWebSocket.js.socketSetCached(ws.getThisValue(), ws_handler.globalObject, socketValue);
defer this.js_ref.unref(jsc.VirtualMachine.get());
switch (this.raw_response) {
.SSL => this.raw_response = uws.AnyResponse.init(uws.NewApp(true).Response.castRes(@alignCast(@ptrCast(socket)))),
.TCP => this.raw_response = uws.AnyResponse.init(uws.NewApp(false).Response.castRes(@alignCast(@ptrCast(socket)))),
}
};

if (this.upgrade_context.request) |request| {
this.upgrade_context = .{};

var sec_websocket_protocol_str: ?ZigString.Slice = null;
var sec_websocket_extensions_str: ?ZigString.Slice = null;

const sec_websocket_protocol_value = brk: {
if (sec_websocket_protocol.isEmpty()) {
break :brk request.header("sec-websocket-protocol") orelse "";
}
sec_websocket_protocol_str = sec_websocket_protocol.toSlice(bun.default_allocator);
break :brk sec_websocket_protocol_str.?.slice();
};

const sec_websocket_extensions_value = brk: {
if (sec_websocket_extensions.isEmpty()) {
break :brk request.header("sec-websocket-extensions") orelse "";
}
sec_websocket_extensions_str = sec_websocket_protocol.toSlice(bun.default_allocator);
break :brk sec_websocket_extensions_str.?.slice();
};
defer {
if (sec_websocket_protocol_str) |str| str.deinit();
if (sec_websocket_extensions_str) |str| str.deinit();
}

new_socket = this.raw_response.upgrade(
*ServerWebSocket,
ws,
request.header("sec-websocket-key") orelse "",
sec_websocket_protocol_value,
sec_websocket_extensions_value,
upgrade_ctx,
);
return true;
}

var sec_websocket_protocol_str: ?ZigString.Slice = null;
defer if (sec_websocket_protocol_str) |*str| str.deinit();
var sec_websocket_extensions_str: ?ZigString.Slice = null;
defer if (sec_websocket_extensions_str) |*str| str.deinit();

const sec_websocket_protocol_value = brk: {
if (sec_websocket_protocol.isEmpty()) {
break :brk this.upgrade_context.sec_websocket_protocol;
if (this.upgrade_context.request) |request| {
break :brk request.header("sec-websocket-protocol") orelse "";
} else {
break :brk this.upgrade_context.sec_websocket_protocol;
}
}
sec_websocket_protocol_str = sec_websocket_protocol.toSlice(bun.default_allocator);
break :brk sec_websocket_protocol_str.?.slice();
};

const sec_websocket_extensions_value = brk: {
if (sec_websocket_extensions.isEmpty()) {
break :brk this.upgrade_context.sec_websocket_extensions;
if (this.upgrade_context.request) |request| {
break :brk request.header("sec-websocket-extensions") orelse "";
} else {
break :brk this.upgrade_context.sec_websocket_extensions;
}
}
sec_websocket_extensions_str = sec_websocket_protocol.toSlice(bun.default_allocator);
sec_websocket_extensions_str = sec_websocket_extensions.toSlice(bun.default_allocator);
break :brk sec_websocket_extensions_str.?.slice();
};
defer {
if (sec_websocket_protocol_str) |str| str.deinit();
if (sec_websocket_extensions_str) |str| str.deinit();
}

new_socket = this.raw_response.upgrade(
*ServerWebSocket,
ws,
this.upgrade_context.sec_websocket_key,
sec_websocket_protocol_value,
sec_websocket_extensions_value,
upgrade_ctx,
);

const websocket_key = if (this.upgrade_context.request) |request|
request.header("sec-websocket-key") orelse ""
else
this.upgrade_context.sec_websocket_key;

var on_before_open = OnBeforeOpen{
.this = this,
.socketValue = socketValue,
.globalObject = this.server.globalThis(),
};
var on_before_open_ptr = WebSocketServerContext.Handler.OnBeforeOpen{
.ctx = &on_before_open,
.callback = @ptrCast(&OnBeforeOpen.onBeforeOpen),
};

this.server.webSocketHandler().?.onBeforeOpen = &on_before_open_ptr;
_ = this.raw_response.upgrade(*ServerWebSocket, ws, websocket_key, sec_websocket_protocol_value, sec_websocket_extensions_value, upgrade_ctx);

return true;
}
pub fn maybeStopReadingBody(this: *NodeHTTPResponse, vm: *jsc.VirtualMachine, thisValue: jsc.JSValue) void {
this.upgrade_context.deinit(); // we can discard the upgrade context now

if ((this.flags.socket_closed or this.flags.ended) and
if ((this.flags.upgraded or this.flags.socket_closed or this.flags.ended) and
(this.body_read_ref.has or this.body_read_state == .pending) and
(!this.flags.hasCustomOnData or js.onDataGetCached(thisValue) == null))
{
const had_ref = this.body_read_ref.has;
this.raw_response.clearOnData();
if (!this.flags.upgraded and !this.flags.socket_closed) {
this.raw_response.clearOnData();
}

this.body_read_ref.unref(vm);
this.body_read_state = .done;

Expand Down Expand Up @@ -578,7 +572,7 @@ pub fn onTimeout(this: *NodeHTTPResponse, _: uws.AnyResponse) void {

pub fn doPause(this: *NodeHTTPResponse, _: *jsc.JSGlobalObject, _: *jsc.CallFrame, thisValue: jsc.JSValue) bun.JSError!jsc.JSValue {
log("doPause", .{});
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended) {
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended or this.flags.upgraded) {
return .false;
}
if (this.body_read_ref.has and js.onDataGetCached(thisValue) == null) {
Expand Down Expand Up @@ -608,7 +602,7 @@ fn drainBufferedRequestBodyFromPause(this: *NodeHTTPResponse, globalObject: *jsc

pub fn doResume(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, _: *jsc.CallFrame) jsc.JSValue {
log("doResume", .{});
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended) {
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended or this.flags.upgraded) {
return .false;
}

Expand Down Expand Up @@ -671,7 +665,7 @@ pub export fn Bun__NodeHTTPRequest__onReject(globalObject: *jsc.JSGlobalObject,

defer this.deref();

if (!this.flags.request_has_completed and !this.flags.socket_closed) {
if (!this.flags.request_has_completed and !this.flags.socket_closed and !this.flags.upgraded) {
const this_value = this.getThisValue();
if (this_value != .zero) {
js.onAbortedSetCached(this_value, globalObject, .zero);
Expand Down Expand Up @@ -787,7 +781,7 @@ fn onDrain(this: *NodeHTTPResponse, offset: u64, response: uws.AnyResponse) bool
this.ref();
defer this.deref();
response.clearOnWritable();
if (this.flags.socket_closed or this.flags.request_has_completed) {
if (this.flags.socket_closed or this.flags.request_has_completed or this.flags.upgraded) {
// return false means we don't have anything to drain
return false;
}
Expand Down Expand Up @@ -963,14 +957,14 @@ pub fn getOnWritable(_: *NodeHTTPResponse, thisValue: jsc.JSValue, _: *jsc.JSGlo
}

pub fn getOnAbort(this: *NodeHTTPResponse, thisValue: jsc.JSValue, _: *jsc.JSGlobalObject) jsc.JSValue {
if (this.flags.socket_closed) {
if (this.flags.socket_closed or this.flags.upgraded) {
return .js_undefined;
}
return js.onAbortedGetCached(thisValue) orelse .js_undefined;
}

pub fn setOnAbort(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject, value: JSValue) void {
if (this.flags.socket_closed) {
if (this.flags.socket_closed or this.flags.upgraded) {
return;
}

Expand Down Expand Up @@ -1002,7 +996,7 @@ fn clearOnDataCallback(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalOb
if (thisValue != .zero) {
js.onDataSetCached(thisValue, globalObject, .js_undefined);
}
if (!this.flags.socket_closed)
if (!this.flags.socket_closed and !this.flags.upgraded)
this.raw_response.clearOnData();
if (this.body_read_state != .done) {
this.body_read_state = .done;
Expand All @@ -1011,7 +1005,7 @@ fn clearOnDataCallback(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalOb
}

pub fn setOnData(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject, value: JSValue) void {
if (value.isUndefined() or this.flags.ended or this.flags.socket_closed or this.body_read_state == .none or this.flags.is_data_buffered_during_pause_last) {
if (value.isUndefined() or this.flags.ended or this.flags.socket_closed or this.body_read_state == .none or this.flags.is_data_buffered_during_pause_last or this.flags.upgraded) {
js.onDataSetCached(thisValue, globalObject, .js_undefined);
defer {
if (this.body_read_ref.has) {
Expand All @@ -1020,7 +1014,7 @@ pub fn setOnData(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalObject:
}
switch (this.body_read_state) {
.pending, .done => {
if (!this.flags.request_has_completed and !this.flags.socket_closed) {
if (!this.flags.request_has_completed and !this.flags.socket_closed and !this.flags.upgraded) {
this.raw_response.clearOnData();
}
this.body_read_state = .done;
Expand Down Expand Up @@ -1048,7 +1042,7 @@ pub fn write(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, callfra
}

pub fn flushHeaders(this: *NodeHTTPResponse, _: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!jsc.JSValue {
if (!this.flags.socket_closed)
if (!this.flags.socket_closed and !this.flags.upgraded)
this.raw_response.flushHeaders();

return .js_undefined;
Expand All @@ -1074,7 +1068,7 @@ fn handleCorked(globalObject: *jsc.JSGlobalObject, function: jsc.JSValue, result
}

pub fn setTimeout(this: *NodeHTTPResponse, seconds: u8) void {
if (this.flags.request_has_completed or this.flags.socket_closed) {
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.upgraded) {
return;
}

Expand All @@ -1087,7 +1081,7 @@ export fn NodeHTTPResponse__setTimeout(this: *NodeHTTPResponse, seconds: jsc.JSV
return false;
}

if (this.flags.request_has_completed or this.flags.socket_closed) {
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.upgraded) {
return false;
}

Expand All @@ -1105,7 +1099,7 @@ pub fn cork(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, callfram
return globalObject.throwInvalidArgumentTypeValue("cork", "function", arguments[0]);
}

if (this.flags.request_has_completed or this.flags.socket_closed) {
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.upgraded) {
return globalObject.ERR(.STREAM_ALREADY_FINISHED, "Stream is already ended", .{}).throw();
}

Expand Down Expand Up @@ -1163,6 +1157,7 @@ pub export fn Bun__NodeHTTPResponse_setClosed(response: *NodeHTTPResponse) void

const string = []const u8;

const WebSocketServerContext = @import("./WebSocketServerContext.zig");
const std = @import("std");

const bun = @import("bun");
Expand Down
13 changes: 12 additions & 1 deletion src/bun.js/api/server/ServerWebSocket.zig
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,20 @@ pub fn onOpen(this: *ServerWebSocket, ws: uws.AnyWebSocket) void {
js.dataSetCached(current_this, globalObject, value_to_cache);
}

if (onOpenHandler.isEmptyOrUndefinedOrNull()) return;
if (onOpenHandler.isEmptyOrUndefinedOrNull()) {
if (bun.take(&this.handler.onBeforeOpen)) |on_before_open| {
// Only create the "this" value if needed.
const this_value = this.getThisValue();
on_before_open.callback(on_before_open.ctx, this_value, ws.raw());
}
return;
}

const this_value = this.getThisValue();
var args = [_]JSValue{this_value};
if (bun.take(&this.handler.onBeforeOpen)) |on_before_open| {
on_before_open.callback(on_before_open.ctx, this_value, ws.raw());
}

const loop = vm.eventLoop();
loop.enter();
Expand Down
15 changes: 14 additions & 1 deletion src/bun.js/api/server/WebSocketServerContext.zig
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,25 @@ pub const Handler = struct {
globalObject: *jsc.JSGlobalObject = undefined,
active_connections: usize = 0,

/// Only used by NodeHTTPResponse.
///
/// Before we call into JavaScript and after the WebSocket is upgraded, we need to call a function in NodeHTTPResponse.
///
/// This is per-ServerWebSocket data, so it needs to be null'd on usage.
onBeforeOpen: ?*OnBeforeOpen = null,

/// used by publish()
flags: packed struct(u2) {
flags: packed struct(u8) {
ssl: bool = false,
publish_to_self: bool = false,
_: u6 = 0,
} = .{},

pub const OnBeforeOpen = struct {
ctx: *anyopaque,
callback: *const fn (*anyopaque, this_value: jsc.JSValue, socket: *uws.RawWebSocket) void,
};

pub fn runErrorCallback(this: *const Handler, vm: *jsc.VirtualMachine, globalObject: *jsc.JSGlobalObject, error_value: jsc.JSValue) void {
const onError = this.onError;
if (!onError.isEmptyOrUndefinedOrNull()) {
Expand Down
9 changes: 9 additions & 0 deletions src/deps/uws/WebSocket.zig
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ pub const RawWebSocket = opaque {
pub fn memoryCost(this: *RawWebSocket, ssl_flag: i32) usize {
return c.uws_ws_memory_cost(ssl_flag, this);
}

/// They're the same memory address.
///
/// Equivalent to:
///
/// (struct us_socket_t *)socket
pub fn asSocket(this: *RawWebSocket) *uws.Socket {
return @as(*uws.Socket, @ptrCast(this));
}
};

pub const AnyWebSocket = union(enum) {
Expand Down
2 changes: 1 addition & 1 deletion src/js/thirdparty/ws.js
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,7 @@ class WebSocketServer extends EventEmitter {
*/
handleUpgrade(req, socket, head, cb) {
// socket is actually fake so we use internal http_res
const response = socket._httpMessage;
const response = socket._httpMessage || socket[kBunInternals];

// socket.on("error", socketOnError);

Expand Down
Loading