diff --git a/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddleware.cs b/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddleware.cs index 033653ca..301bfef7 100644 --- a/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddleware.cs +++ b/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddleware.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; @@ -10,7 +11,10 @@ namespace Ocelot.WebSockets.Middleware { public class WebSocketsProxyMiddleware : OcelotMiddleware { - private OcelotRequestDelegate _next; + private static readonly string[] NotForwardedWebSocketHeaders = new[] { "Connection", "Host", "Upgrade", "Sec-WebSocket-Accept", "Sec-WebSocket-Protocol", "Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions" }; + private const int DefaultWebSocketBufferSize = 4096; + private const int StreamCopyBufferSize = 81920; + private readonly OcelotRequestDelegate _next; public WebSocketsProxyMiddleware(OcelotRequestDelegate next, IOcelotLoggerFactory loggerFactory) @@ -19,6 +23,37 @@ namespace Ocelot.WebSockets.Middleware _next = next; } + private static async Task PumpWebSocket(WebSocket source, WebSocket destination, int bufferSize, CancellationToken cancellationToken) + { + if (bufferSize <= 0) + { + throw new ArgumentOutOfRangeException(nameof(bufferSize)); + } + + var buffer = new byte[bufferSize]; + while (true) + { + WebSocketReceiveResult result; + try + { + result = await source.ReceiveAsync(new ArraySegment(buffer), cancellationToken); + } + catch (OperationCanceledException) + { + await destination.CloseOutputAsync(WebSocketCloseStatus.EndpointUnavailable, null, cancellationToken); + return; + } + + if (result.MessageType == WebSocketMessageType.Close) + { + await destination.CloseOutputAsync(source.CloseStatus.Value, source.CloseStatusDescription, cancellationToken); + return; + } + + await destination.SendAsync(new ArraySegment(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, cancellationToken); + } + } + public async Task Invoke(DownstreamContext context) { await Proxy(context.HttpContext, context.DownstreamRequest.ToUri()); @@ -26,88 +61,42 @@ namespace Ocelot.WebSockets.Middleware private async Task Proxy(HttpContext context, string serverEndpoint) { - var wsToUpstreamClient = await context.WebSockets.AcceptWebSocketAsync(); - - var wsToDownstreamService = new ClientWebSocket(); - - foreach (var requestHeader in context.Request.Headers) + if (context == null) { - // Do not copy the Sec-Websocket headers because it is specified by the own connection it will fail when you copy this one. - if (requestHeader.Key.StartsWith("Sec-WebSocket")) - { - continue; - } - wsToDownstreamService.Options.SetRequestHeader(requestHeader.Key, requestHeader.Value); + throw new ArgumentNullException(nameof(context)); } - var uri = new Uri(serverEndpoint); - await wsToDownstreamService.ConnectAsync(uri, CancellationToken.None); - - var receiveFromUpstreamSendToDownstream = Task.Run(async () => + if (serverEndpoint == null) { - var buffer = new byte[1024 * 4]; + throw new ArgumentNullException(nameof(serverEndpoint)); + } - var receiveSegment = new ArraySegment(buffer); - - while (wsToUpstreamClient.State == WebSocketState.Open || wsToUpstreamClient.State == WebSocketState.CloseSent) - { - var result = await wsToUpstreamClient.ReceiveAsync(receiveSegment, CancellationToken.None); - - var sendSegment = new ArraySegment(buffer, 0, result.Count); - - if(result.MessageType == WebSocketMessageType.Close) - { - await wsToUpstreamClient.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", - CancellationToken.None); - - await wsToDownstreamService.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", - CancellationToken.None); - - break; - } - - await wsToDownstreamService.SendAsync(sendSegment, result.MessageType, result.EndOfMessage, - CancellationToken.None); - - if (wsToUpstreamClient.State != WebSocketState.Open) - { - await wsToDownstreamService.CloseAsync(WebSocketCloseStatus.Empty, "", - CancellationToken.None); - break; - } - } - }); - - var receiveFromDownstreamAndSendToUpstream = Task.Run(async () => + if (!context.WebSockets.IsWebSocketRequest) { - var buffer = new byte[1024 * 4]; + throw new InvalidOperationException(); + } - while (wsToDownstreamService.State == WebSocketState.Open || wsToDownstreamService.State == WebSocketState.CloseSent) + var client = new ClientWebSocket(); + foreach (var protocol in context.WebSockets.WebSocketRequestedProtocols) + { + client.Options.AddSubProtocol(protocol); + } + + foreach (var headerEntry in context.Request.Headers) + { + if (!NotForwardedWebSocketHeaders.Contains(headerEntry.Key, StringComparer.OrdinalIgnoreCase)) { - if (wsToUpstreamClient.State != WebSocketState.Open) - { - break; - } - else - { - var receiveSegment = new ArraySegment(buffer); - var result = await wsToDownstreamService.ReceiveAsync(receiveSegment, CancellationToken.None); - - if (result.MessageType == WebSocketMessageType.Close) - { - break; - } - - var sendSegment = new ArraySegment(buffer, 0, result.Count); - - //send to upstream client - await wsToUpstreamClient.SendAsync(sendSegment, result.MessageType, result.EndOfMessage, - CancellationToken.None); - } + client.Options.SetRequestHeader(headerEntry.Key, headerEntry.Value); } - }); + } - await Task.WhenAll(receiveFromDownstreamAndSendToUpstream, receiveFromUpstreamSendToDownstream); + var destinationUri = new Uri(serverEndpoint); + await client.ConnectAsync(destinationUri, context.RequestAborted); + using (var server = await context.WebSockets.AcceptWebSocketAsync(client.SubProtocol)) + { + var bufferSize = DefaultWebSocketBufferSize; + await Task.WhenAll(PumpWebSocket(client, server, bufferSize, context.RequestAborted), PumpWebSocket(server, client, bufferSize, context.RequestAborted)); + } } } }