Update to support Subprotocols. Solves #639 (#642)

Ocelot websocket middleware did not work for STOMP over websocket. After investigation i found out that the issue was with subprotocol and headers that are send and filtered. 

I the end i used ASP.Net core proxy as a reference to solve the issue here:

3015029f51/src/Microsoft.AspNetCore.Proxy/ProxyAdvancedExtensions.cs

So i modified the code to use the way ASP.Net proxy handles this.
This commit is contained in:
vasicvuk 2018-09-30 10:17:09 +02:00 committed by Tom Pallister
parent 65b4115e90
commit b58b3810d8

View File

@ -1,4 +1,5 @@
using System; using System;
using System.Linq;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -10,7 +11,10 @@ namespace Ocelot.WebSockets.Middleware
{ {
public class WebSocketsProxyMiddleware : OcelotMiddleware public class WebSocketsProxyMiddleware : OcelotMiddleware
{ {
private OcelotRequestDelegate _next; private static readonly string[] NotForwardedWebSocketHeaders = new[] { "Connection", "Host", "Upgrade", "Sec-WebSocket-Accept", "Sec-WebSocket-Protocol", "Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions" };
private const int DefaultWebSocketBufferSize = 4096;
private const int StreamCopyBufferSize = 81920;
private readonly OcelotRequestDelegate _next;
public WebSocketsProxyMiddleware(OcelotRequestDelegate next, public WebSocketsProxyMiddleware(OcelotRequestDelegate next,
IOcelotLoggerFactory loggerFactory) IOcelotLoggerFactory loggerFactory)
@ -19,6 +23,37 @@ namespace Ocelot.WebSockets.Middleware
_next = next; _next = next;
} }
private static async Task PumpWebSocket(WebSocket source, WebSocket destination, int bufferSize, CancellationToken cancellationToken)
{
if (bufferSize <= 0)
{
throw new ArgumentOutOfRangeException(nameof(bufferSize));
}
var buffer = new byte[bufferSize];
while (true)
{
WebSocketReceiveResult result;
try
{
result = await source.ReceiveAsync(new ArraySegment<byte>(buffer), cancellationToken);
}
catch (OperationCanceledException)
{
await destination.CloseOutputAsync(WebSocketCloseStatus.EndpointUnavailable, null, cancellationToken);
return;
}
if (result.MessageType == WebSocketMessageType.Close)
{
await destination.CloseOutputAsync(source.CloseStatus.Value, source.CloseStatusDescription, cancellationToken);
return;
}
await destination.SendAsync(new ArraySegment<byte>(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, cancellationToken);
}
}
public async Task Invoke(DownstreamContext context) public async Task Invoke(DownstreamContext context)
{ {
await Proxy(context.HttpContext, context.DownstreamRequest.ToUri()); await Proxy(context.HttpContext, context.DownstreamRequest.ToUri());
@ -26,88 +61,42 @@ namespace Ocelot.WebSockets.Middleware
private async Task Proxy(HttpContext context, string serverEndpoint) private async Task Proxy(HttpContext context, string serverEndpoint)
{ {
var wsToUpstreamClient = await context.WebSockets.AcceptWebSocketAsync(); if (context == null)
var wsToDownstreamService = new ClientWebSocket();
foreach (var requestHeader in context.Request.Headers)
{ {
// Do not copy the Sec-Websocket headers because it is specified by the own connection it will fail when you copy this one. throw new ArgumentNullException(nameof(context));
if (requestHeader.Key.StartsWith("Sec-WebSocket"))
{
continue;
}
wsToDownstreamService.Options.SetRequestHeader(requestHeader.Key, requestHeader.Value);
} }
var uri = new Uri(serverEndpoint); if (serverEndpoint == null)
await wsToDownstreamService.ConnectAsync(uri, CancellationToken.None);
var receiveFromUpstreamSendToDownstream = Task.Run(async () =>
{ {
var buffer = new byte[1024 * 4]; throw new ArgumentNullException(nameof(serverEndpoint));
var receiveSegment = new ArraySegment<byte>(buffer);
while (wsToUpstreamClient.State == WebSocketState.Open || wsToUpstreamClient.State == WebSocketState.CloseSent)
{
var result = await wsToUpstreamClient.ReceiveAsync(receiveSegment, CancellationToken.None);
var sendSegment = new ArraySegment<byte>(buffer, 0, result.Count);
if(result.MessageType == WebSocketMessageType.Close)
{
await wsToUpstreamClient.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "",
CancellationToken.None);
await wsToDownstreamService.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "",
CancellationToken.None);
break;
} }
await wsToDownstreamService.SendAsync(sendSegment, result.MessageType, result.EndOfMessage, if (!context.WebSockets.IsWebSocketRequest)
CancellationToken.None);
if (wsToUpstreamClient.State != WebSocketState.Open)
{ {
await wsToDownstreamService.CloseAsync(WebSocketCloseStatus.Empty, "", throw new InvalidOperationException();
CancellationToken.None);
break;
}
}
});
var receiveFromDownstreamAndSendToUpstream = Task.Run(async () =>
{
var buffer = new byte[1024 * 4];
while (wsToDownstreamService.State == WebSocketState.Open || wsToDownstreamService.State == WebSocketState.CloseSent)
{
if (wsToUpstreamClient.State != WebSocketState.Open)
{
break;
}
else
{
var receiveSegment = new ArraySegment<byte>(buffer);
var result = await wsToDownstreamService.ReceiveAsync(receiveSegment, CancellationToken.None);
if (result.MessageType == WebSocketMessageType.Close)
{
break;
} }
var sendSegment = new ArraySegment<byte>(buffer, 0, result.Count); var client = new ClientWebSocket();
foreach (var protocol in context.WebSockets.WebSocketRequestedProtocols)
{
client.Options.AddSubProtocol(protocol);
}
//send to upstream client foreach (var headerEntry in context.Request.Headers)
await wsToUpstreamClient.SendAsync(sendSegment, result.MessageType, result.EndOfMessage, {
CancellationToken.None); if (!NotForwardedWebSocketHeaders.Contains(headerEntry.Key, StringComparer.OrdinalIgnoreCase))
{
client.Options.SetRequestHeader(headerEntry.Key, headerEntry.Value);
} }
} }
});
await Task.WhenAll(receiveFromDownstreamAndSendToUpstream, receiveFromUpstreamSendToDownstream); var destinationUri = new Uri(serverEndpoint);
await client.ConnectAsync(destinationUri, context.RequestAborted);
using (var server = await context.WebSockets.AcceptWebSocketAsync(client.SubProtocol))
{
var bufferSize = DefaultWebSocketBufferSize;
await Task.WhenAll(PumpWebSocket(client, server, bufferSize, context.RequestAborted), PumpWebSocket(server, client, bufferSize, context.RequestAborted));
}
} }
} }
} }