using System; using System.Collections.Generic; using System.IO; using System.Net.WebSockets; using System.Text; using System.Threading; using System.Threading.Tasks; using Consul; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Newtonsoft.Json; using Ocelot.Configuration.File; using Ocelot.DependencyInjection; using Ocelot.Middleware; using Shouldly; using TestStack.BDDfy; using Xunit; namespace Ocelot.UnitTests.Websockets { public class WebSocketsProxyMiddlewareTests : IDisposable { private IWebHost _firstDownstreamHost; private readonly List _firstRecieved; private WebHostBuilder _ocelotBuilder; private IWebHost _ocelotHost; public WebSocketsProxyMiddlewareTests() { _firstRecieved = new List(); } [Fact] public async Task should_proxy_websocket_input_to_downstream_service() { var downstreamPort = 5001; var downstreamHost = "localhost"; var config = new FileConfiguration { ReRoutes = new List { new FileReRoute { UpstreamPathTemplate = "/", DownstreamPathTemplate = "/ws", DownstreamScheme = "ws", DownstreamHostAndPorts = new List { new FileHostAndPort { Host = downstreamHost, Port = downstreamPort } } } } }; this.Given(_ => GivenThereIsAConfiguration(config)) .And(_ => StartFakeOcelotWithWebSockets()) .And(_ => StartFakeDownstreamService($"http://{downstreamHost}:{downstreamPort}", "/ws")) .When(_ => StartClient("ws://localhost:5000/")) .Then(_ => _firstRecieved.Count.ShouldBe(10)) .BDDfy(); } public void Dispose() { _firstDownstreamHost?.Dispose(); } public async Task StartFakeOcelotWithWebSockets() { _ocelotBuilder = new WebHostBuilder(); _ocelotBuilder.ConfigureServices(s => { s.AddSingleton(_ocelotBuilder); s.AddOcelot(); }); _ocelotBuilder.UseKestrel() .UseUrls("http://localhost:5000") .UseContentRoot(Directory.GetCurrentDirectory()) .ConfigureAppConfiguration((hostingContext, config) => { config.SetBasePath(hostingContext.HostingEnvironment.ContentRootPath); var env = hostingContext.HostingEnvironment; config.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true, reloadOnChange: true); config.AddJsonFile("configuration.json"); config.AddEnvironmentVariables(); }) .ConfigureLogging((hostingContext, logging) => { logging.AddConfiguration(hostingContext.Configuration.GetSection("Logging")); logging.AddConsole(); }) .Configure(app => { app.UseWebSockets(); app.UseOcelot().Wait(); }) .UseIISIntegration(); _ocelotHost = _ocelotBuilder.Build(); await _ocelotHost.StartAsync(); } public void GivenThereIsAConfiguration(FileConfiguration fileConfiguration) { var configurationPath = Path.Combine(AppContext.BaseDirectory, "configuration.json"); var jsonConfiguration = JsonConvert.SerializeObject(fileConfiguration); if (File.Exists(configurationPath)) { File.Delete(configurationPath); } File.WriteAllText(configurationPath, jsonConfiguration); } private async Task StartFakeDownstreamService(string url, string path) { _firstDownstreamHost = new WebHostBuilder() .ConfigureServices(s => { }).UseKestrel() .UseUrls(url) .UseContentRoot(Directory.GetCurrentDirectory()) .ConfigureAppConfiguration((hostingContext, config) => { config.SetBasePath(hostingContext.HostingEnvironment.ContentRootPath); var env = hostingContext.HostingEnvironment; config.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true, reloadOnChange: true); config.AddEnvironmentVariables(); }) .ConfigureLogging((hostingContext, logging) => { logging.AddConfiguration(hostingContext.Configuration.GetSection("Logging")); logging.AddConsole(); }) .Configure(app => { app.UseWebSockets(); app.Use(async (context, next) => { if (context.Request.Path == path) { if (context.WebSockets.IsWebSocketRequest) { WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync(); await Echo(webSocket); } else { context.Response.StatusCode = 400; } } else { await next(); } }); }) .UseIISIntegration().Build(); await _firstDownstreamHost.StartAsync(); } private async Task StartClient(string url) { var client = new ClientWebSocket(); await client.ConnectAsync(new Uri(url), CancellationToken.None); var sending = Task.Run(async () => { string line = "test"; for (int i = 0; i < 10; i++) { var bytes = Encoding.UTF8.GetBytes(line); await client.SendAsync(new ArraySegment(bytes), WebSocketMessageType.Text, true, CancellationToken.None); await Task.Delay(10); } await client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); }); var receiving = Task.Run(async () => { var buffer = new byte[1024 * 4]; while (true) { var result = await client.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); if (result.MessageType == WebSocketMessageType.Text) { _firstRecieved.Add(Encoding.UTF8.GetString(buffer, 0, result.Count)); } else if (result.MessageType == WebSocketMessageType.Close) { await client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); break; } } }); await Task.WhenAll(sending, receiving); } private async Task Echo(WebSocket webSocket) { try { var buffer = new byte[1024 * 4]; var result = await webSocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); while (!result.CloseStatus.HasValue) { await webSocket.SendAsync(new ArraySegment(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, CancellationToken.None); result = await webSocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); } await webSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None); } catch (Exception e) { Console.WriteLine(e); } } } }