diff --git a/src/libraries/Common/tests/System/Net/Http/Http2Frames.cs b/src/libraries/Common/tests/System/Net/Http/Http2Frames.cs index 214dab635061a6..5f0ef36311049a 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2Frames.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2Frames.cs @@ -44,7 +44,8 @@ public enum SettingId : ushort MaxConcurrentStreams = 0x3, InitialWindowSize = 0x4, MaxFrameSize = 0x5, - MaxHeaderListSize = 0x6 + MaxHeaderListSize = 0x6, + EnableConnect = 0x8 } public class Frame diff --git a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs index edbefefb6ac1f8..9d17c93f0d0cb5 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs @@ -44,6 +44,10 @@ public override Uri Address localEndPoint.Address.ToString(); string scheme = _options.UseSsl ? "https" : "http"; + if (_options.WebSocketEndpoint) + { + scheme = _options.UseSsl ? "wss" : "ws"; + } _uri = new Uri($"{scheme}://{host}:{localEndPoint.Port}/"); @@ -177,6 +181,7 @@ public static async Task CreateClientAndServerAsync(Func clientFunc, public class Http2Options : GenericLoopbackOptions { + public bool WebSocketEndpoint { get; set; } = false; public bool ClientCertificateRequired { get; set; } public bool EnableTransparentPingResponse { get; set; } = true; diff --git a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs index 87fac367760f2d..1592265729f804 100644 --- a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs +++ b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs @@ -1031,7 +1031,7 @@ await LoopbackServerFactory.CreateClientAndServerAsync(async uri => Assert.Equal(PlatformDetection.IsBrowser && !enableWasmStreaming, responseStream.CanSeek); // Not supported operations - Assert.Throws(() => responseStream.BeginWrite(new byte[1], 0, 1, null, null)); + await Assert.ThrowsAsync(async () => await Task.Factory.FromAsync(responseStream.BeginWrite, responseStream.EndWrite, new byte[1], 0, 1, null)); if (!responseStream.CanSeek) { Assert.Throws(() => responseStream.Length); diff --git a/src/libraries/System.Net.Http/ref/System.Net.Http.cs b/src/libraries/System.Net.Http/ref/System.Net.Http.cs index 65288819174af6..d5ba8339edd54a 100644 --- a/src/libraries/System.Net.Http/ref/System.Net.Http.cs +++ b/src/libraries/System.Net.Http/ref/System.Net.Http.cs @@ -231,6 +231,7 @@ public HttpMethod(string method) { } public static System.Net.Http.HttpMethod Post { get { throw null; } } public static System.Net.Http.HttpMethod Put { get { throw null; } } public static System.Net.Http.HttpMethod Trace { get { throw null; } } + public static System.Net.Http.HttpMethod Connect { get { throw null; } } public bool Equals([System.Diagnostics.CodeAnalysis.NotNullWhen(true)] System.Net.Http.HttpMethod? other) { throw null; } public override bool Equals([System.Diagnostics.CodeAnalysis.NotNullWhen(true)] object? obj) { throw null; } public override int GetHashCode() { throw null; } @@ -661,6 +662,7 @@ internal HttpRequestHeaders() { } public System.DateTimeOffset? IfUnmodifiedSince { get { throw null; } set { } } public int? MaxForwards { get { throw null; } set { } } public System.Net.Http.Headers.HttpHeaderValueCollection Pragma { get { throw null; } } + public string? Protocol { get { throw null; } set { } } public System.Net.Http.Headers.AuthenticationHeaderValue? ProxyAuthorization { get { throw null; } set { } } public System.Net.Http.Headers.RangeHeaderValue? Range { get { throw null; } set { } } public System.Uri? Referrer { get { throw null; } set { } } diff --git a/src/libraries/System.Net.Http/src/Resources/Strings.resx b/src/libraries/System.Net.Http/src/Resources/Strings.resx index 20a30c23e9a9f8..3ac957f5d8107c 100644 --- a/src/libraries/System.Net.Http/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Http/src/Resources/Strings.resx @@ -168,6 +168,9 @@ The stream does not support writing. + + The stream does not support reading. + The character set provided in ContentType is invalid. Cannot read content as string using an invalid character set. @@ -564,4 +567,7 @@ The HTTP/1.1 response chunk was too large. - \ No newline at end of file + + Failed to establish web socket connection over HTTP/2 because extended CONNECT is not supported. Try to downgrade the request version to HTTP/1.1. + + diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpRequestHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpRequestHeaders.cs index 3a223a3a309443..ed287aaec6d5a1 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpRequestHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpRequestHeaders.cs @@ -21,6 +21,7 @@ public sealed class HttpRequestHeaders : HttpHeaders private HttpGeneralHeaders? _generalHeaders; private HttpHeaderValueCollection? _expect; private bool _expectContinueSet; + private string? _protocol; #region Request Headers @@ -159,6 +160,15 @@ public int? MaxForwards set { SetOrRemoveParsedValue(KnownHeaders.MaxForwards.Descriptor, value); } } + public string? Protocol + { + get => _protocol; + set + { + CheckContainsNewLine(value); + _protocol = value; + } + } public AuthenticationHeaderValue? ProxyAuthorization { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/HttpMethod.cs b/src/libraries/System.Net.Http/src/System/Net/Http/HttpMethod.cs index 6fddc9dcb964fc..5743ff3ec8bef0 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/HttpMethod.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/HttpMethod.cs @@ -67,7 +67,7 @@ public static HttpMethod Patch // Don't expose CONNECT as static property, since it's used by the transport to connect to a proxy. // CONNECT is not used by users directly. - internal static HttpMethod Connect + public static HttpMethod Connect { get { return s_connectMethod; } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/HttpRequestMessage.cs b/src/libraries/System.Net.Http/src/System/Net/Http/HttpRequestMessage.cs index 4a9470d7b85e37..f258c5ca8ae800 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/HttpRequestMessage.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/HttpRequestMessage.cs @@ -168,6 +168,8 @@ public override string ToString() internal bool WasRedirected() => (_sendStatus & MessageIsRedirect) != 0; + internal bool IsWebSocketH2Request() => _version.Major == 2 && Method == HttpMethod.Connect && HasHeaders && string.Equals(Headers.Protocol, "websocket", StringComparison.OrdinalIgnoreCase); + #region IDisposable Members protected virtual void Dispose(bool disposing) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index 680ad58e3f7978..30c8761de4eeb4 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -8,7 +8,6 @@ using System.IO; using System.Net.Http.Headers; using System.Net.Http.HPack; -using System.Net.Security; using System.Runtime.CompilerServices; using System.Text; using System.Threading; @@ -19,6 +18,13 @@ namespace System.Net.Http { internal sealed partial class Http2Connection : HttpConnectionBase { + // Equivalent to the bytes returned from HPackEncoder.EncodeLiteralHeaderFieldWithoutIndexingNewNameToAllocatedArray(":protocol") + private static ReadOnlySpan ProtocolLiteralHeaderBytes => new byte[] { 0x0, 0x9, 0x3a, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c }; + + private static readonly TaskCompletionSourceWithCancellation s_settingsReceivedSingleton = CreateSuccessfullyCompletedTcs(); + + private TaskCompletionSourceWithCancellation? _initialSettingsReceived; + private readonly HttpConnectionPool _pool; private readonly Stream _stream; @@ -174,6 +180,13 @@ static long TimeSpanToMs(TimeSpan value) { private object SyncObject => _httpStreams; + internal TaskCompletionSourceWithCancellation InitialSettingsReceived => + _initialSettingsReceived ?? + Interlocked.CompareExchange(ref _initialSettingsReceived, new(), null) ?? + _initialSettingsReceived; + + internal bool IsConnectEnabled { get; private set; } + public async ValueTask SetupAsync(CancellationToken cancellationToken) { try @@ -832,6 +845,20 @@ private void ProcessSettingsFrame(FrameHeader frameHeader, bool initialFrame = f // We don't actually store this value; we always send frames of the minimum size (16K). break; + case SettingId.EnableConnect: + if (settingValue == 1) + { + IsConnectEnabled = true; + } + else if (settingValue == 0 && IsConnectEnabled) + { + // Accroding to RFC: a sender MUST NOT send a SETTINGS_ENABLE_CONNECT_PROTOCOL parameter + // with the value of 0 after previously sending a value of 1. + // https://datatracker.ietf.org/doc/html/rfc8441#section-3 + ThrowProtocolError(); + } + break; + default: // All others are ignored because we don't care about them. // Note, per RFC, unknown settings IDs should be ignored. @@ -839,10 +866,20 @@ private void ProcessSettingsFrame(FrameHeader frameHeader, bool initialFrame = f } } - if (initialFrame && !maxConcurrentStreamsReceived) + if (initialFrame) { - // Set to 'infinite' because MaxConcurrentStreams was not set on the initial SETTINGS frame. - ChangeMaxConcurrentStreams(int.MaxValue); + if (!maxConcurrentStreamsReceived) + { + // Set to 'infinite' because MaxConcurrentStreams was not set on the initial SETTINGS frame. + ChangeMaxConcurrentStreams(int.MaxValue); + } + + if (_initialSettingsReceived is null) + { + Interlocked.CompareExchange(ref _initialSettingsReceived, s_settingsReceivedSingleton, null); + } + // Set result in case if CompareExchange lost the race + InitialSettingsReceived.TrySetResult(true); } _incomingBuffer.Discard(frameHeader.PayloadLength); @@ -1455,6 +1492,13 @@ private void WriteHeaders(HttpRequestMessage request, ref ArrayBuffer headerBuff if (request.HasHeaders) { + if (request.Headers.Protocol != null) + { + WriteBytes(ProtocolLiteralHeaderBytes, ref headerBuffer); + Encoding? protocolEncoding = _pool.Settings._requestHeaderEncodingSelector?.Invoke(":protocol", request); + WriteLiteralHeaderValue(request.Headers.Protocol, protocolEncoding, ref headerBuffer); + } + WriteHeaderCollection(request, request.Headers, ref headerBuffer); } @@ -1895,7 +1939,15 @@ private enum SettingId : ushort MaxConcurrentStreams = 0x3, InitialWindowSize = 0x4, MaxFrameSize = 0x5, - MaxHeaderListSize = 0x6 + MaxHeaderListSize = 0x6, + EnableConnect = 0x8 + } + + private static TaskCompletionSourceWithCancellation CreateSuccessfullyCompletedTcs() + { + var tcs = new TaskCompletionSourceWithCancellation(); + tcs.TrySetResult(true); + return tcs; } // Note that this is safe to be called concurrently by multiple threads. diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index 172b192d4eb16e..b3f82779a715db 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -11,6 +11,7 @@ using System.Runtime.ExceptionServices; using System.Text; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using System.Threading.Tasks.Sources; @@ -44,6 +45,7 @@ private sealed class Http2Stream : IValueTaskSource, IHttpStreamHeadersHandler, private StreamCompletionState _requestCompletionState; private StreamCompletionState _responseCompletionState; private ResponseProtocolState _responseProtocolState; + private bool _webSocketEstablished; // If this is not null, then we have received a reset from the server // (i.e. RST_STREAM or general IO error processing the connection) @@ -106,6 +108,10 @@ public Http2Stream(HttpRequestMessage request, Http2Connection connection) if (_request.Content == null) { _requestCompletionState = StreamCompletionState.Completed; + if (_request.IsWebSocketH2Request()) + { + _requestBodyCancellationSource = new CancellationTokenSource(); + } } else { @@ -629,6 +635,11 @@ private void OnStatus(int statusCode) } else { + if (statusCode == 200 && _response.RequestMessage!.IsWebSocketH2Request()) + { + _webSocketEstablished = true; + } + _responseProtocolState = ResponseProtocolState.ExpectingHeaders; // If we are waiting for a 100-continue response, signal the waiter now. @@ -1036,6 +1047,10 @@ public async Task ReadResponseHeadersAsync(CancellationToken cancellationToken) MoveTrailersToResponseMessage(_response); responseContent.SetStream(EmptyReadStream.Instance); } + else if (_webSocketEstablished) + { + responseContent.SetStream(new Http2ReadWriteStream(this)); + } else { responseContent.SetStream(new Http2ReadStream(this)); @@ -1417,12 +1432,61 @@ private enum StreamCompletionState : byte Failed } - private sealed class Http2ReadStream : HttpBaseStream + private sealed class Http2ReadStream : Http2ReadWriteStream + { + public Http2ReadStream(Http2Stream http2Stream) : base(http2Stream) + { + base.CloseResponseBodyOnDispose = true; + } + + public override bool CanWrite => false; + + public override void Write(ReadOnlySpan buffer) => throw new NotSupportedException(SR.net_http_content_readonly_stream); + + public override ValueTask WriteAsync(ReadOnlyMemory destination, CancellationToken cancellationToken) => ValueTask.FromException(new NotSupportedException(SR.net_http_content_readonly_stream)); + } + + private sealed class Http2WriteStream : Http2ReadWriteStream + { + public long BytesWritten { get; private set; } + + public long ContentLength { get; } + + public Http2WriteStream(Http2Stream http2Stream, long contentLength) : base(http2Stream) + { + Debug.Assert(contentLength >= -1); + ContentLength = contentLength; + } + + public override bool CanRead => false; + + public override int Read(Span buffer) => throw new NotSupportedException(SR.net_http_content_writeonly_stream); + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) => ValueTask.FromException(new NotSupportedException(SR.net_http_content_writeonly_stream)); + + public override void CopyTo(Stream destination, int bufferSize) => throw new NotSupportedException(SR.net_http_content_writeonly_stream); + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) => Task.FromException(new NotSupportedException(SR.net_http_content_writeonly_stream)); + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) + { + BytesWritten += buffer.Length; + + if ((ulong)BytesWritten > (ulong)ContentLength) // If ContentLength == -1, this will always be false + { + return ValueTask.FromException(new HttpRequestException(SR.net_http_content_write_larger_than_content_length)); + } + + return base.WriteAsync(buffer, cancellationToken); + } + } + + public class Http2ReadWriteStream : HttpBaseStream { private Http2Stream? _http2Stream; private readonly HttpResponseMessage _responseMessage; - public Http2ReadStream(Http2Stream http2Stream) + public Http2ReadWriteStream(Http2Stream http2Stream) { Debug.Assert(http2Stream != null); Debug.Assert(http2Stream._response != null); @@ -1430,7 +1494,7 @@ public Http2ReadStream(Http2Stream http2Stream) _responseMessage = _http2Stream._response; } - ~Http2ReadStream() + ~Http2ReadWriteStream() { if (NetEventSource.Log.IsEnabled()) _http2Stream?.Trace(""); try @@ -1443,6 +1507,8 @@ public Http2ReadStream(Http2Stream http2Stream) } } + protected bool CloseResponseBodyOnDispose { get; set; } + protected override void Dispose(bool disposing) { Http2Stream? http2Stream = Interlocked.Exchange(ref _http2Stream, null); @@ -1456,14 +1522,16 @@ protected override void Dispose(bool disposing) // protocol, we have little choice: if someone drops the Http2ReadStream without // disposing of it, we need to a) signal to the server that the stream is being // canceled, and b) clean up the associated state in the Http2Connection. - - http2Stream.CloseResponseBody(); + if (CloseResponseBodyOnDispose) + { + http2Stream.CloseResponseBody(); + } base.Dispose(disposing); } public override bool CanRead => _http2Stream != null; - public override bool CanWrite => false; + public override bool CanWrite => _http2Stream != null; public override int Read(Span destination) { @@ -1507,53 +1575,8 @@ public override Task CopyToAsync(Stream destination, int bufferSize, Cancellatio http2Stream.CopyToAsync(_responseMessage, destination, bufferSize, cancellationToken); } - public override void Write(ReadOnlySpan buffer) => throw new NotSupportedException(SR.net_http_content_readonly_stream); - - public override ValueTask WriteAsync(ReadOnlyMemory destination, CancellationToken cancellationToken) => throw new NotSupportedException(); - } - - private sealed class Http2WriteStream : HttpBaseStream - { - private Http2Stream? _http2Stream; - - public long BytesWritten { get; private set; } - - public long ContentLength { get; private set; } - - public Http2WriteStream(Http2Stream http2Stream, long contentLength) - { - Debug.Assert(http2Stream != null); - Debug.Assert(contentLength >= -1); - _http2Stream = http2Stream; - ContentLength = contentLength; - } - - protected override void Dispose(bool disposing) - { - Http2Stream? http2Stream = Interlocked.Exchange(ref _http2Stream, null); - if (http2Stream == null) - { - return; - } - - base.Dispose(disposing); - } - - public override bool CanRead => false; - public override bool CanWrite => _http2Stream != null; - - public override int Read(Span buffer) => throw new NotSupportedException(); - - public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) => throw new NotSupportedException(); - public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) { - BytesWritten += buffer.Length; - - if ((ulong)BytesWritten > (ulong)ContentLength) // If ContentLength == -1, this will always be false - { - return ValueTask.FromException(new HttpRequestException(SR.net_http_content_write_larger_than_content_length)); - } Http2Stream? http2Stream = _http2Stream; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs index dfa7f096c5cf96..45cc1a6e1e0571 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs @@ -995,6 +995,7 @@ public async ValueTask SendWithVersionDetectionAndRetryAsyn // Use HTTP/3 if possible. if (IsHttp3Supported() && // guard to enable trimming HTTP/3 support _http3Enabled && + !request.IsWebSocketH2Request() && (request.Version.Major >= 3 || (request.VersionPolicy == HttpVersionPolicy.RequestVersionOrHigher && IsSecure))) { Debug.Assert(async); @@ -1018,6 +1019,17 @@ public async ValueTask SendWithVersionDetectionAndRetryAsyn Debug.Assert(connection is not null || !_http2Enabled); if (connection is not null) { + if (request.IsWebSocketH2Request()) + { + await connection.InitialSettingsReceived.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false); + if (!connection.IsConnectEnabled) + { + HttpRequestException exception = new(SR.net_unsupported_extended_connect); + exception.Data["SETTINGS_ENABLE_CONNECT_PROTOCOL"] = false; + throw exception; + } + } + response = await connection.SendAsync(request, async, cancellationToken).ConfigureAwait(false); } } @@ -1075,7 +1087,12 @@ public async ValueTask SendWithVersionDetectionAndRetryAsyn // Throw if fallback is not allowed by the version policy. if (request.VersionPolicy != HttpVersionPolicy.RequestVersionOrLower) { - throw new HttpRequestException(SR.Format(SR.net_http_requested_version_server_refused, request.Version, request.VersionPolicy), e); + HttpRequestException exception = new HttpRequestException(SR.Format(SR.net_http_requested_version_server_refused, request.Version, request.VersionPolicy), e); + if (request.IsWebSocketH2Request()) + { + exception.Data["HTTP2_ENABLED"] = false; + } + throw exception; } if (NetEventSource.Log.IsEnabled()) diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs index 98537366ec43d8..dfc3625aaa7b12 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs @@ -2548,6 +2548,70 @@ public async Task PostAsyncDuplex_ServerSendsEndStream_Success() } } + [Fact] + public async Task ConnectAsync_ReadWriteWebSocketStream() + { + var clientMessage = new byte[] { 1, 2, 3 }; + var serverMessage = new byte[] { 4, 5, 6, 7 }; + + using Http2LoopbackServer server = Http2LoopbackServer.CreateServer(); + Http2LoopbackConnection connection = null; + + Task serverTask = Task.Run(async () => + { + connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }); + + // read request headers + (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); + + // send response headers + await connection.SendResponseHeadersAsync(streamId, endStream: false).ConfigureAwait(false); + + // send reply + await connection.SendResponseDataAsync(streamId, serverMessage, endStream: false); + + // send server EOS + await connection.SendResponseDataAsync(streamId, Array.Empty(), endStream: true); + }); + + StreamingHttpContent requestContent = new StreamingHttpContent(); + + using var handler = new SocketsHttpHandler(); + handler.SslOptions.RemoteCertificateValidationCallback = delegate { return true; }; + + using HttpClient client = new HttpClient(handler); + + HttpRequestMessage request = new(HttpMethod.Connect, server.Address); + request.Version = HttpVersion.Version20; + request.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + request.Headers.Protocol = "websocket"; + + // initiate request + var responseTask = client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); + + using HttpResponseMessage response = await responseTask.WaitAsync(TimeSpan.FromSeconds(10)); + + await serverTask.WaitAsync(TimeSpan.FromSeconds(60)); + + var responseStream = await response.Content.ReadAsStreamAsync(); + + // receive data + var readBuffer = new byte[10]; + int bytesRead = await responseStream.ReadAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10)); + Assert.Equal(bytesRead, serverMessage.Length); + Assert.Equal(serverMessage, readBuffer[..bytesRead]); + + await responseStream.WriteAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10)); + + // Send client's EOS + requestContent.CompleteStream(); + // Receive server's EOS + Assert.Equal(0, await responseStream.ReadAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10))); + + Assert.NotNull(connection); + connection.Dispose(); + } + [Fact] [ActiveIssue("https://github.com/dotnet/runtime/issues/69870", TestPlatforms.Android)] public async Task PostAsyncDuplex_RequestContentException_ResetsStream() diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs index 96cecd9e30f471..15cc579688698f 100644 --- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs +++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs @@ -3,7 +3,6 @@ // ------------------------------------------------------------------------------ // Changes to this file must follow the https://aka.ms/api-review process. // ------------------------------------------------------------------------------ - namespace System.Net.WebSockets { public sealed partial class ClientWebSocket : System.Net.WebSockets.WebSocket @@ -18,6 +17,7 @@ public override void Abort() { } public override System.Threading.Tasks.Task CloseAsync(System.Net.WebSockets.WebSocketCloseStatus closeStatus, string? statusDescription, System.Threading.CancellationToken cancellationToken) { throw null; } public override System.Threading.Tasks.Task CloseOutputAsync(System.Net.WebSockets.WebSocketCloseStatus closeStatus, string? statusDescription, System.Threading.CancellationToken cancellationToken) { throw null; } public System.Threading.Tasks.Task ConnectAsync(System.Uri uri, System.Threading.CancellationToken cancellationToken) { throw null; } + public System.Threading.Tasks.Task ConnectAsync(System.Uri uri, System.Net.Http.HttpMessageInvoker? invoker, System.Threading.CancellationToken cancellationToken) { throw null; } public override void Dispose() { } public override System.Threading.Tasks.Task ReceiveAsync(System.ArraySegment buffer, System.Threading.CancellationToken cancellationToken) { throw null; } public override System.Threading.Tasks.ValueTask ReceiveAsync(System.Memory buffer, System.Threading.CancellationToken cancellationToken) { throw null; } @@ -43,6 +43,8 @@ internal ClientWebSocketOptions() { } public System.Net.Security.RemoteCertificateValidationCallback? RemoteCertificateValidationCallback { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public bool UseDefaultCredentials { get { throw null; } set { } } + public System.Version HttpVersion { get { throw null; } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] set { } } + public System.Net.Http.HttpVersionPolicy HttpVersionPolicy { get { throw null; } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] set { } } public void AddSubProtocol(string subProtocol) { } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public void SetBuffer(int receiveBufferSize, int sendBufferSize) { } diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.csproj b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.csproj index 57de4cf647d84c..deca422762934e 100644 --- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.csproj +++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.csproj @@ -8,6 +8,7 @@ + diff --git a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx index 360d9f9a337e12..401be8dc707616 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx @@ -75,8 +75,8 @@ The argument must be a value greater than {0}. - - The server returned status code '{0}' when status code '101' was expected. + + The server returned status code '{0}' when status code '{1}' was expected. The server's response was missing the required header '{0}'. @@ -129,4 +129,4 @@ The WebSocket failed to negotiate max client window bits. The client requested {0} but the server responded with {1}. - \ No newline at end of file + diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs index 79dd04229b9c33..e01b4fcf46a876 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics; +using System.Net.Http; using System.Runtime.Versioning; using System.Security.Cryptography.X509Certificates; @@ -32,6 +33,20 @@ public bool UseDefaultCredentials set => throw new PlatformNotSupportedException(); } + public Version HttpVersion + { + get => Net.HttpVersion.Version11; + [UnsupportedOSPlatform("browser")] + set => throw new PlatformNotSupportedException(); + } + + public System.Net.Http.HttpVersionPolicy HttpVersionPolicy + { + get => HttpVersionPolicy.RequestVersionOrLower; + [UnsupportedOSPlatform("browser")] + set => throw new PlatformNotSupportedException(); + } + [UnsupportedOSPlatform("browser")] public System.Net.ICredentials Credentials { diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs index 85ac3a70e1942a..2a32dee3b1ee9f 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocket.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Net.Http; using System.Threading; using System.Threading.Tasks; @@ -51,6 +52,11 @@ public override WebSocketState State } public Task ConnectAsync(Uri uri, CancellationToken cancellationToken) + { + return ConnectAsync(uri, null, cancellationToken); + } + + public Task ConnectAsync(Uri uri, HttpMessageInvoker? invoker, CancellationToken cancellationToken) { ArgumentNullException.ThrowIfNull(uri); @@ -77,16 +83,16 @@ public Task ConnectAsync(Uri uri, CancellationToken cancellationToken) } Options.SetToReadOnly(); - return ConnectAsyncCore(uri, cancellationToken); + return ConnectAsyncCore(uri, invoker, cancellationToken); } - private async Task ConnectAsyncCore(Uri uri, CancellationToken cancellationToken) + private async Task ConnectAsyncCore(Uri uri, HttpMessageInvoker? invoker, CancellationToken cancellationToken) { _innerWebSocket = new WebSocketHandle(); try { - await _innerWebSocket.ConnectAsync(uri, cancellationToken, Options).ConfigureAwait(false); + await _innerWebSocket.ConnectAsync(uri, invoker, cancellationToken, Options).ConfigureAwait(false); } catch { diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index d58cda99112eec..5f8027abda7bb2 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics; +using System.Net.Http; using System.Net.Security; using System.Runtime.Versioning; using System.Security.Cryptography.X509Certificates; @@ -25,11 +26,36 @@ public sealed class ClientWebSocketOptions internal X509CertificateCollection? _clientCertificates; internal WebHeaderCollection? _requestHeaders; internal List? _requestedSubProtocols; + private Version _version = Net.HttpVersion.Version11; + private HttpVersionPolicy _versionPolicy = HttpVersionPolicy.RequestVersionOrLower; internal ClientWebSocketOptions() { } // prevent external instantiation #region HTTP Settings + public Version HttpVersion + { + get => _version; + [UnsupportedOSPlatform("browser")] + set + { + ThrowIfReadOnly(); + ArgumentNullException.ThrowIfNull(value); + _version = value; + } + } + + public HttpVersionPolicy HttpVersionPolicy + { + get => _versionPolicy; + [UnsupportedOSPlatform("browser")] + set + { + ThrowIfReadOnly(); + _versionPolicy = value; + } + } + [UnsupportedOSPlatform("browser")] // Note that some headers are restricted like Host. public void SetRequestHeader(string headerName, string? headerValue) diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Browser.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Browser.cs index 3a948cc64ab34f..2addc85ea5aed1 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Browser.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Browser.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Net.Http; using System.Threading; using System.Threading.Tasks; @@ -27,7 +28,7 @@ public void Abort() WebSocket?.Abort(); } - public Task ConnectAsync(Uri uri, CancellationToken cancellationToken, ClientWebSocketOptions options) + public Task ConnectAsync(Uri uri, HttpMessageInvoker? invoker, CancellationToken cancellationToken, ClientWebSocketOptions options) { cancellationToken.ThrowIfCancellationRequested(); diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index 9f5beb3874e657..480ea91ce1e3e3 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -42,118 +42,84 @@ public void Abort() WebSocket?.Abort(); } - public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, ClientWebSocketOptions options) + public async Task ConnectAsync(Uri uri, HttpMessageInvoker? invoker, CancellationToken cancellationToken, ClientWebSocketOptions options) { + bool disposeHandler = false; + invoker ??= new HttpMessageInvoker(SetupHandler(options, out disposeHandler)); HttpResponseMessage? response = null; - SocketsHttpHandler? handler = null; - bool disposeHandler = true; + + bool tryDowngrade = false; try { - var request = new HttpRequestMessage(HttpMethod.Get, uri); - if (options._requestHeaders?.Count > 0) // use field to avoid lazily initializing the collection - { - foreach (string key in options.RequestHeaders) - { - request.Headers.TryAddWithoutValidation(key, options.RequestHeaders[key]); - } - } - // Create the security key and expected response, then build all of the request headers - KeyValuePair secKeyAndSecWebSocketAccept = CreateSecKeyAndSecWebSocketAccept(); - AddWebSocketHeaders(request, secKeyAndSecWebSocketAccept.Key, options); - - // Create the handler for this request and populate it with all of the options. - // Try to use a shared handler rather than creating a new one just for this request, if - // the options are compatible. - if (options.Credentials == null && - !options.UseDefaultCredentials && - options.Proxy == null && - options.Cookies == null && - options.RemoteCertificateValidationCallback == null && - options._clientCertificates?.Count == 0) + while (true) { - disposeHandler = false; - handler = s_defaultHandler; - if (handler == null) + try { - handler = new SocketsHttpHandler() + HttpRequestMessage request; + if (!tryDowngrade && options.HttpVersion >= HttpVersion.Version20 + || (options.HttpVersion == HttpVersion.Version11 && options.HttpVersionPolicy == HttpVersionPolicy.RequestVersionOrHigher)) { - PooledConnectionLifetime = TimeSpan.Zero, - UseProxy = false, - UseCookies = false, - }; - if (Interlocked.CompareExchange(ref s_defaultHandler, handler, null) != null) + if (options.HttpVersion > HttpVersion.Version20 && options.HttpVersionPolicy != HttpVersionPolicy.RequestVersionOrLower) + { + throw new WebSocketException(WebSocketError.UnsupportedProtocol); + } + request = new HttpRequestMessage(HttpMethod.Connect, uri) { Version = HttpVersion.Version20 }; + tryDowngrade = true; + } + else if (tryDowngrade || options.HttpVersion == HttpVersion.Version11) { - handler.Dispose(); - handler = s_defaultHandler; + request = new HttpRequestMessage(HttpMethod.Get, uri) { Version = HttpVersion.Version11 }; + tryDowngrade = false; + } + else + { + throw new WebSocketException(WebSocketError.UnsupportedProtocol); } - } - } - else - { - handler = new SocketsHttpHandler(); - handler.PooledConnectionLifetime = TimeSpan.Zero; - handler.CookieContainer = options.Cookies; - handler.UseCookies = options.Cookies != null; - handler.SslOptions.RemoteCertificateValidationCallback = options.RemoteCertificateValidationCallback; - if (options.UseDefaultCredentials) - { - handler.Credentials = CredentialCache.DefaultCredentials; - } - else - { - handler.Credentials = options.Credentials; - } + if (options._requestHeaders?.Count > 0) // use field to avoid lazily initializing the collection + { + foreach (string key in options.RequestHeaders) + { + request.Headers.TryAddWithoutValidation(key, options.RequestHeaders[key]); + } + } - if (options.Proxy == null) - { - handler.UseProxy = false; - } - else if (options.Proxy != DefaultWebProxy.Instance) - { - handler.Proxy = options.Proxy; - } + string? secValue = AddWebSocketHeaders(request, options); - if (options._clientCertificates?.Count > 0) // use field to avoid lazily initializing the collection - { - Debug.Assert(handler.SslOptions.ClientCertificates == null); - handler.SslOptions.ClientCertificates = new X509Certificate2Collection(); - handler.SslOptions.ClientCertificates.AddRange(options.ClientCertificates); - } - } + // Issue the request. + CancellationTokenSource? linkedCancellation; + CancellationTokenSource externalAndAbortCancellation; + if (cancellationToken.CanBeCanceled) // avoid allocating linked source if external token is not cancelable + { + linkedCancellation = + externalAndAbortCancellation = + CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _abortSource.Token); + } + else + { + linkedCancellation = null; + externalAndAbortCancellation = _abortSource; + } - // Issue the request. The response must be status code 101. - CancellationTokenSource? linkedCancellation; - CancellationTokenSource externalAndAbortCancellation; - if (cancellationToken.CanBeCanceled) // avoid allocating linked source if external token is not cancelable - { - linkedCancellation = - externalAndAbortCancellation = - CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _abortSource.Token); - } - else - { - linkedCancellation = null; - externalAndAbortCancellation = _abortSource; - } + using (linkedCancellation) + { + response = await invoker.SendAsync(request, externalAndAbortCancellation.Token).ConfigureAwait(false); + externalAndAbortCancellation.Token.ThrowIfCancellationRequested(); // poll in case sends/receives in request/response didn't observe cancellation + } - using (linkedCancellation) - { - response = await new HttpMessageInvoker(handler).SendAsync(request, externalAndAbortCancellation.Token).ConfigureAwait(false); - externalAndAbortCancellation.Token.ThrowIfCancellationRequested(); // poll in case sends/receives in request/response didn't observe cancellation - } + ValidateResponse(response, secValue, options); + break; + } + catch (HttpRequestException ex) when + ((ex.Data.Contains("SETTINGS_ENABLE_CONNECT_PROTOCOL") || ex.Data.Contains("HTTP2_ENABLED")) + && tryDowngrade + && (options.HttpVersion == HttpVersion.Version11 || options.HttpVersionPolicy == HttpVersionPolicy.RequestVersionOrLower)) + { + } - if (response.StatusCode != HttpStatusCode.SwitchingProtocols) - { - throw new WebSocketException(WebSocketError.NotAWebSocket, SR.Format(SR.net_WebSockets_Connect101Expected, (int)response.StatusCode)); } - // The Connection, Upgrade, and SecWebSocketAccept headers are required and with specific values. - ValidateHeader(response.Headers, HttpKnownHeaderNames.Connection, "Upgrade"); - ValidateHeader(response.Headers, HttpKnownHeaderNames.Upgrade, "websocket"); - ValidateHeader(response.Headers, HttpKnownHeaderNames.SecWebSocketAccept, secKeyAndSecWebSocketAccept.Value); - // The SecWebSocketProtocol header is optional. We should only get it with a non-empty value if we requested subprotocols, // and then it must only be one of the ones we requested. If we got a subprotocol other than one we requested (or if we // already got one in a previous header), fail. Otherwise, track which one we got. @@ -200,11 +166,6 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } } - if (response.Content is null) - { - throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); - } - // Get the response stream and wrap it in a web socket. Stream connectedStream = response.Content.ReadAsStream(); Debug.Assert(connectedStream.CanWrite); @@ -241,11 +202,75 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli // Disposing the handler will not affect any active stream wrapped in the WebSocket. if (disposeHandler) { - handler?.Dispose(); + invoker?.Dispose(); } } } + private static SocketsHttpHandler SetupHandler(ClientWebSocketOptions options, out bool disposeHandler) + { + SocketsHttpHandler? handler; + + // Create the handler for this request and populate it with all of the options. + // Try to use a shared handler rather than creating a new one just for this request, if + // the options are compatible. + if (options.Credentials == null && + !options.UseDefaultCredentials && + options.Proxy == null && + options.Cookies == null && + options.RemoteCertificateValidationCallback == null && + (options._clientCertificates?.Count ?? 0) == 0) + { + disposeHandler = false; + handler = s_defaultHandler; + if (handler == null) + { + handler = new SocketsHttpHandler() + { + PooledConnectionLifetime = TimeSpan.Zero, + UseProxy = false, + UseCookies = false, + }; + if (Interlocked.CompareExchange(ref s_defaultHandler, handler, null) != null) + { + handler.Dispose(); + handler = s_defaultHandler; + } + } + } + else + { + disposeHandler = true; + handler = new SocketsHttpHandler(); + handler.PooledConnectionLifetime = TimeSpan.Zero; + handler.CookieContainer = options.Cookies; + handler.UseCookies = options.Cookies != null; + handler.SslOptions.RemoteCertificateValidationCallback = options.RemoteCertificateValidationCallback; + + handler.Credentials = options.UseDefaultCredentials ? + CredentialCache.DefaultCredentials : + options.Credentials; + + if (options.Proxy == null) + { + handler.UseProxy = false; + } + else if (options.Proxy != DefaultWebProxy.Instance) + { + handler.Proxy = options.Proxy; + } + + if (options._clientCertificates?.Count > 0) // use field to avoid lazily initializing the collection + { + Debug.Assert(handler.SslOptions.ClientCertificates == null); + handler.SslOptions.ClientCertificates = new X509Certificate2Collection(); + handler.SslOptions.ClientCertificates.AddRange(options.ClientCertificates); + } + } + + return handler; + } + private static WebSocketDeflateOptions ParseDeflateOptions(ReadOnlySpan extension, WebSocketDeflateOptions original) { var options = new WebSocketDeflateOptions(); @@ -315,14 +340,29 @@ static int ParseWindowBits(ReadOnlySpan value) /// Adds the necessary headers for the web socket request. /// The request to which the headers should be added. - /// The generated security key to send in the Sec-WebSocket-Key header. /// The options controlling the request. - private static void AddWebSocketHeaders(HttpRequestMessage request, string secKey, ClientWebSocketOptions options) + private static string? AddWebSocketHeaders(HttpRequestMessage request, ClientWebSocketOptions options) { - request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.Connection, HttpKnownHeaderNames.Upgrade); - request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.Upgrade, "websocket"); + // always exact because we handle downgrade here + request.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + string? secValue = null; + + if (request.Version == HttpVersion.Version11) + { + // Create the security key and expected response, then build all of the request headers + KeyValuePair secKeyAndSecWebSocketAccept = CreateSecKeyAndSecWebSocketAccept(); + secValue = secKeyAndSecWebSocketAccept.Value; + request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.Connection, HttpKnownHeaderNames.Upgrade); + request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.Upgrade, "websocket"); + request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketKey, secKeyAndSecWebSocketAccept.Key); + } + else if (request.Version == HttpVersion.Version20) + { + request.Headers.Protocol = "websocket"; + } + request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketVersion, "13"); - request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketKey, secKey); + if (options._requestedSubProtocols?.Count > 0) { request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketProtocol, string.Join(", ", options.RequestedSubProtocols)); @@ -365,6 +405,39 @@ static string GetDeflateOptions(WebSocketDeflateOptions options) return builder.ToString(); } } + return secValue; + } + + private static void ValidateResponse(HttpResponseMessage response, string? secValue, ClientWebSocketOptions options) + { + Debug.Assert(response.Version == HttpVersion.Version11 || response.Version == HttpVersion.Version20); + + if (response.Version == HttpVersion.Version11) + { + if (response.StatusCode != HttpStatusCode.SwitchingProtocols) + { + throw new WebSocketException(WebSocketError.NotAWebSocket, SR.Format(SR.net_WebSockets_ConnectStatusExpected, (int)response.StatusCode, (int)HttpStatusCode.SwitchingProtocols)); + } + + Debug.Assert(secValue != null); + + // The Connection, Upgrade, and SecWebSocketAccept headers are required and with specific values. + ValidateHeader(response.Headers, HttpKnownHeaderNames.Connection, "Upgrade"); + ValidateHeader(response.Headers, HttpKnownHeaderNames.Upgrade, "websocket"); + ValidateHeader(response.Headers, HttpKnownHeaderNames.SecWebSocketAccept, secValue); + } + else if (response.Version == HttpVersion.Version20) + { + if (response.StatusCode != HttpStatusCode.OK) + { + throw new WebSocketException(WebSocketError.NotAWebSocket, SR.Format(SR.net_WebSockets_ConnectStatusExpected, (int)response.StatusCode, (int)HttpStatusCode.OK)); + } + } + + if (response.Content is null) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); + } } /// diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs index 91e5f8cc681e26..a03eff4c18c899 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs @@ -53,7 +53,7 @@ public static IEnumerable UnavailableWebSocketServers { server = System.Net.Test.Common.Configuration.Http.RemoteEchoServer; var ub = new UriBuilder("ws", server.Host, server.Port, server.PathAndQuery); - exceptionMessage = ResourceHelper.GetExceptionMessage("net_WebSockets_Connect101Expected", (int) HttpStatusCode.OK); + exceptionMessage = ResourceHelper.GetExceptionMessage("net_WebSockets_ConnectStatusExpected", (int) HttpStatusCode.OK, (int) HttpStatusCode.SwitchingProtocols); yield return new object[] { ub.Uri, exceptionMessage, WebSocketError.NotAWebSocket }; } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.Http2.cs b/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.Http2.cs new file mode 100644 index 00000000000000..23d260c87e2ffe --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/ConnectTest.Http2.cs @@ -0,0 +1,71 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.IO; +using System.Net.Http; +using System.Net.Test.Common; +using System.Threading; +using System.Threading.Tasks; + +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.WebSockets.Client.Tests +{ + public class ConnectTest_Http2 : ClientWebSocketTestBase + { + public ConnectTest_Http2(ITestOutputHelper output) : base(output) { } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/69870", TestPlatforms.Browser)] + public async Task ConnectAsync_VersionNotSupported_Throws() + { + await Http2LoopbackServer.CreateClientAndServerAsync(async uri => + { + using (var clientSocket = new ClientWebSocket()) + using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) + { + clientSocket.Options.HttpVersion = HttpVersion.Version20; + clientSocket.Options.HttpVersionPolicy = Http.HttpVersionPolicy.RequestVersionExact; + using var handler = new SocketsHttpHandler(); + handler.SslOptions.RemoteCertificateValidationCallback = delegate { return true; }; + Task t = clientSocket.ConnectAsync(uri, new HttpMessageInvoker(handler), cts.Token); + var ex = await Assert.ThrowsAnyAsync(() => t); + Assert.IsType(ex.InnerException); + Assert.True(ex.InnerException.Data.Contains("SETTINGS_ENABLE_CONNECT_PROTOCOL")); + } + }, + async server => + { + Http2LoopbackConnection connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 0 }); + }, new Http2Options() { WebSocketEndpoint = true } + ); + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/69870", TestPlatforms.Browser)] + public async Task ConnectAsync_VersionSupported_Success() + { + await Http2LoopbackServer.CreateClientAndServerAsync(async uri => + { + using (var cws = new ClientWebSocket()) + using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) + { + cws.Options.HttpVersion = HttpVersion.Version20; + cws.Options.HttpVersionPolicy = Http.HttpVersionPolicy.RequestVersionExact; + using var handler = new SocketsHttpHandler(); + handler.SslOptions.RemoteCertificateValidationCallback = delegate { return true; }; + await cws.ConnectAsync(uri, new HttpMessageInvoker(handler), cts.Token); + } + }, + async server => + { + Http2LoopbackConnection connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }); + (int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false); + await connection.SendResponseHeadersAsync(streamId, endStream: false, HttpStatusCode.OK); + }, new Http2Options() { WebSocketEndpoint = true } + ); + } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj index 27bd2a0f5c1daa..56c8ad1f7f194b 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj +++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj @@ -26,10 +26,9 @@ - - - + + + @@ -37,36 +36,33 @@ - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + +