diff --git a/README.md b/README.md index ab3d89a5..f8ad7b22 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![Windows Build history](https://buildstats.info/appveyor/chart/TomPallister/ocelot-fcfpb?branch=develop&includeBuildsFromPullRequest=false)](https://ci.appveyor.com/project/TomPallister/ocelot-fcfpb/history?branch=develop) -[![Coverage Status](https://coveralls.io/repos/github/TomPallister/Ocelot/badge.svg?branch=develop)](https://coveralls.io/github/TomPallister/Ocelot?branch=develop) +[![Coverage Status](https://coveralls.io/repos/github/ThreeMammals/Ocelot/badge.svg?branch=develop)](https://coveralls.io/github/ThreeMammals/Ocelot?branch=develop) # Ocelot @@ -41,6 +41,7 @@ A quick list of Ocelot's capabilities for more information see the [documentatio * Request Aggregation * Service Discovery with Consul * Service Fabric +* WebSockets * Authentication * Authorisation * Rate Limiting diff --git a/docs/features/logging.rst b/docs/features/logging.rst index 313f62ff..b09a26cf 100644 --- a/docs/features/logging.rst +++ b/docs/features/logging.rst @@ -11,4 +11,10 @@ Finally if logging is set to trace level Ocelot will log starting, finishing and The reason for not just using bog standard framework logging is that I could not work out how to override the request id that get's logged when setting IncludeScopes -to true for logging settings. Nicely onto the next feature. \ No newline at end of file +to true for logging settings. Nicely onto the next feature. + +Warning +^^^^^^^ + +If you are logging to Console you will get terrible performance. I have had so many issues about performance issues with Ocelot +and it is always logging level Debug, logging to Console :) Make sure you are logging to something proper in production :) diff --git a/docs/features/websockets.rst b/docs/features/websockets.rst new file mode 100644 index 00000000..828d1051 --- /dev/null +++ b/docs/features/websockets.rst @@ -0,0 +1,68 @@ +Websockets +========== + +Ocelot supports proxying websockets with some extra bits. This functionality was requested in `Issue 212 `_. + +In order to get websocket proxying working with Ocelot you need to do the following. + +In your Configure method you need to tell your application to use WebSockets. + +.. code-block:: csharp + + Configure(app => + { + app.UseWebSockets(); + app.UseOcelot().Wait(); + }) + +Then in your configuration.json add the following to proxy a ReRoute using websockets. + +.. code-block:: json + + { + "DownstreamPathTemplate": "/ws", + "UpstreamPathTemplate": "/", + "DownstreamScheme": "ws", + "DownstreamHostAndPorts": [ + { + "Host": "localhost", + "Port": 5001 + } + ], + } + +With this configuration set Ocelot will match any websocket traffic that comes in on / and proxy it to localhost:5001/ws. To make this clearer +Ocelot will receive messages from the upstream client, proxy these to the downstream service, receive messages from the downstream service and +proxy these to the upstream client. + +Supported +^^^^^^^^^ + +1. Load Balancer +2. Routing +3. Service Discovery + +This means that you can set up your downstream services running websockets and either have multiple DownstreamHostAndPorts in your ReRoute +config or hook your ReRoute into a service discovery provider and then load balance requests...Which I think is pretty cool :) + +Not Supported +^^^^^^^^^^^^^ + +Unfortunately a lot of Ocelot's features are non websocket specific such as header and http client stuff. I've listed what won't work below. + +1. Tracing +2. RequestId +3. Request Aggregation +4. Rate Limiting +5. Quality of Service +6. Middleware Injection +7. Header Transformation +8. Delegating Handlers +9. Claims Transformation +10. Caching +11. Authentication - If anyone requests it we might be able to do something with basic authentication. +12. Authorisation + +I'm not 100% sure what will happen with this feature when it get's into the wild so please make sure you test thoroughly! + + diff --git a/docs/index.rst b/docs/index.rst index 43a6c436..78b305b2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -25,6 +25,7 @@ Thanks for taking a look at the Ocelot documentation. Please use the left hand n features/servicefabric features/authentication features/authorisation + features/websockets features/administration features/ratelimiting features/caching diff --git a/docs/introduction/bigpicture.rst b/docs/introduction/bigpicture.rst index 989c0f5c..b047abb4 100644 --- a/docs/introduction/bigpicture.rst +++ b/docs/introduction/bigpicture.rst @@ -1,7 +1,7 @@ Big Picture =========== -Ocleot is aimed at people using .NET running +Ocelot is aimed at people using .NET running a micro services / service orientated architecture that need a unified point of entry into their system. diff --git a/docs/introduction/gettingstarted.rst b/docs/introduction/gettingstarted.rst index 8f635fdc..7f81cb51 100644 --- a/docs/introduction/gettingstarted.rst +++ b/docs/introduction/gettingstarted.rst @@ -9,7 +9,7 @@ built to netcoreapp2.0 `this ();; + _logger = factory.CreateLogger(); _placeholders = placeholders; } diff --git a/src/Ocelot/Configuration/Creator/HeaderTransformations.cs b/src/Ocelot/Configuration/Creator/HeaderTransformations.cs index 72d307e5..461e5c35 100644 --- a/src/Ocelot/Configuration/Creator/HeaderTransformations.cs +++ b/src/Ocelot/Configuration/Creator/HeaderTransformations.cs @@ -14,21 +14,10 @@ namespace Ocelot.Configuration.Creator Downstream = downstream; } - public List Upstream { get; private set; } + public List Upstream { get; } - public List Downstream { get; private set; } - public List AddHeadersToDownstream {get;private set;} - } + public List Downstream { get; } - public class AddHeader - { - public AddHeader(string key, string value) - { - this.Key = key; - this.Value = value; - - } - public string Key { get; private set; } - public string Value { get; private set; } + public List AddHeadersToDownstream { get; } } } diff --git a/src/Ocelot/Configuration/Repository/ConsulFileConfigurationRepository.cs b/src/Ocelot/Configuration/Repository/ConsulFileConfigurationRepository.cs index 97bf7c97..10e99c10 100644 --- a/src/Ocelot/Configuration/Repository/ConsulFileConfigurationRepository.cs +++ b/src/Ocelot/Configuration/Repository/ConsulFileConfigurationRepository.cs @@ -6,6 +6,7 @@ using Newtonsoft.Json; using Ocelot.Configuration.File; using Ocelot.Responses; using Ocelot.ServiceDiscovery; +using Ocelot.ServiceDiscovery.Configuration; namespace Ocelot.Configuration.Repository { diff --git a/src/Ocelot/Configuration/Repository/IConsulPollerConfiguration.cs b/src/Ocelot/Configuration/Repository/IConsulPollerConfiguration.cs index 93003087..d1f1430d 100644 --- a/src/Ocelot/Configuration/Repository/IConsulPollerConfiguration.cs +++ b/src/Ocelot/Configuration/Repository/IConsulPollerConfiguration.cs @@ -4,5 +4,4 @@ { int Delay { get; } } - - } +} diff --git a/src/Ocelot/DownstreamUrlCreator/DownstreamHostNullOrEmptyError.cs b/src/Ocelot/DownstreamUrlCreator/DownstreamHostNullOrEmptyError.cs deleted file mode 100644 index d56532a0..00000000 --- a/src/Ocelot/DownstreamUrlCreator/DownstreamHostNullOrEmptyError.cs +++ /dev/null @@ -1,12 +0,0 @@ -using Ocelot.Errors; - -namespace Ocelot.DownstreamUrlCreator -{ - public class DownstreamHostNullOrEmptyError : Error - { - public DownstreamHostNullOrEmptyError() - : base("downstream host was null or empty", OcelotErrorCode.DownstreamHostNullOrEmptyError) - { - } - } -} \ No newline at end of file diff --git a/src/Ocelot/DownstreamUrlCreator/DownstreamPathNullOrEmptyError.cs b/src/Ocelot/DownstreamUrlCreator/DownstreamPathNullOrEmptyError.cs deleted file mode 100644 index 69528d43..00000000 --- a/src/Ocelot/DownstreamUrlCreator/DownstreamPathNullOrEmptyError.cs +++ /dev/null @@ -1,12 +0,0 @@ -using Ocelot.Errors; - -namespace Ocelot.DownstreamUrlCreator -{ - public class DownstreamPathNullOrEmptyError : Error - { - public DownstreamPathNullOrEmptyError() - : base("downstream path was null or empty", OcelotErrorCode.DownstreamPathNullOrEmptyError) - { - } - } -} \ No newline at end of file diff --git a/src/Ocelot/DownstreamUrlCreator/DownstreamSchemeNullOrEmptyError.cs b/src/Ocelot/DownstreamUrlCreator/DownstreamSchemeNullOrEmptyError.cs deleted file mode 100644 index 9f83bfee..00000000 --- a/src/Ocelot/DownstreamUrlCreator/DownstreamSchemeNullOrEmptyError.cs +++ /dev/null @@ -1,12 +0,0 @@ -using Ocelot.Errors; - -namespace Ocelot.DownstreamUrlCreator -{ - public class DownstreamSchemeNullOrEmptyError : Error - { - public DownstreamSchemeNullOrEmptyError() - : base("downstream scheme was null or empty", OcelotErrorCode.DownstreamSchemeNullOrEmptyError) - { - } - } -} \ No newline at end of file diff --git a/src/Ocelot/DownstreamUrlCreator/Middleware/DownstreamUrlCreatorMiddleware.cs b/src/Ocelot/DownstreamUrlCreator/Middleware/DownstreamUrlCreatorMiddleware.cs index 308daea5..43944050 100644 --- a/src/Ocelot/DownstreamUrlCreator/Middleware/DownstreamUrlCreatorMiddleware.cs +++ b/src/Ocelot/DownstreamUrlCreator/Middleware/DownstreamUrlCreatorMiddleware.cs @@ -40,49 +40,36 @@ namespace Ocelot.DownstreamUrlCreator.Middleware return; } - UriBuilder uriBuilder; - + context.DownstreamRequest.Scheme = context.DownstreamReRoute.DownstreamScheme; + if (ServiceFabricRequest(context)) { - uriBuilder = CreateServiceFabricUri(context, dsPath); + var pathAndQuery = CreateServiceFabricUri(context, dsPath); + context.DownstreamRequest.AbsolutePath = pathAndQuery.path; + context.DownstreamRequest.Query = pathAndQuery.query; } else { - uriBuilder = new UriBuilder(context.DownstreamRequest.RequestUri) - { - Path = dsPath.Data.Value, - Scheme = context.DownstreamReRoute.DownstreamScheme - }; + context.DownstreamRequest.AbsolutePath = dsPath.Data.Value; } - context.DownstreamRequest.RequestUri = uriBuilder.Uri; - - _logger.LogDebug("downstream url is {downstreamUrl.Data.Value}", context.DownstreamRequest.RequestUri); + _logger.LogDebug("downstream url is {context.DownstreamRequest}", context.DownstreamRequest); await _next.Invoke(context); } - private UriBuilder CreateServiceFabricUri(DownstreamContext context, Response dsPath) + private (string path, string query) CreateServiceFabricUri(DownstreamContext context, Response dsPath) { - var query = context.DownstreamRequest.RequestUri.Query; - var scheme = context.DownstreamReRoute.DownstreamScheme; - var host = context.DownstreamRequest.RequestUri.Host; - var port = context.DownstreamRequest.RequestUri.Port; + var query = context.DownstreamRequest.Query; var serviceFabricPath = $"/{context.DownstreamReRoute.ServiceName + dsPath.Data.Value}"; - Uri uri; - if (RequestForStatefullService(query)) { - uri = new Uri($"{scheme}://{host}:{port}{serviceFabricPath}{query}"); - } - else - { - var split = string.IsNullOrEmpty(query) ? "?" : "&"; - uri = new Uri($"{scheme}://{host}:{port}{serviceFabricPath}{query}{split}cmd=instance"); + return (serviceFabricPath, query); } - return new UriBuilder(uri); + var split = string.IsNullOrEmpty(query) ? "?" : "&"; + return (serviceFabricPath, $"{query}{split}cmd=instance"); } private static bool ServiceFabricRequest(DownstreamContext context) diff --git a/src/Ocelot/Headers/AddHeadersToRequest.cs b/src/Ocelot/Headers/AddHeadersToRequest.cs index f61b1ac3..1a71935a 100644 --- a/src/Ocelot/Headers/AddHeadersToRequest.cs +++ b/src/Ocelot/Headers/AddHeadersToRequest.cs @@ -5,6 +5,7 @@ using Ocelot.Infrastructure.Claims.Parser; using Ocelot.Responses; using System.Net.Http; using Ocelot.Configuration.Creator; +using Ocelot.Request.Middleware; namespace Ocelot.Headers { @@ -17,7 +18,7 @@ namespace Ocelot.Headers _claimsParser = claimsParser; } - public Response SetHeadersOnDownstreamRequest(List claimsToThings, IEnumerable claims, HttpRequestMessage downstreamRequest) + public Response SetHeadersOnDownstreamRequest(List claimsToThings, IEnumerable claims, DownstreamRequest downstreamRequest) { foreach (var config in claimsToThings) { diff --git a/src/Ocelot/Headers/AddHeadersToResponse.cs b/src/Ocelot/Headers/AddHeadersToResponse.cs index 4c6f48ce..40d829ee 100644 --- a/src/Ocelot/Headers/AddHeadersToResponse.cs +++ b/src/Ocelot/Headers/AddHeadersToResponse.cs @@ -1,23 +1,22 @@ namespace Ocelot.Headers { - using System; using System.Collections.Generic; using System.Net.Http; using Ocelot.Configuration.Creator; using Ocelot.Infrastructure; - using Ocelot.Infrastructure.RequestData; using Ocelot.Logging; public class AddHeadersToResponse : IAddHeadersToResponse { - private IPlaceholders _placeholders; - private IOcelotLogger _logger; + private readonly IPlaceholders _placeholders; + private readonly IOcelotLogger _logger; public AddHeadersToResponse(IPlaceholders placeholders, IOcelotLoggerFactory factory) { _logger = factory.CreateLogger(); _placeholders = placeholders; } + public void Add(List addHeaders, HttpResponseMessage response) { foreach(var add in addHeaders) diff --git a/src/Ocelot/Headers/HttpResponseHeaderReplacer.cs b/src/Ocelot/Headers/HttpResponseHeaderReplacer.cs index fc3e8e4c..c4763746 100644 --- a/src/Ocelot/Headers/HttpResponseHeaderReplacer.cs +++ b/src/Ocelot/Headers/HttpResponseHeaderReplacer.cs @@ -5,6 +5,7 @@ using System.Net.Http; using Ocelot.Configuration; using Ocelot.Infrastructure; using Ocelot.Infrastructure.Extensions; +using Ocelot.Request.Middleware; using Ocelot.Responses; namespace Ocelot.Headers @@ -18,7 +19,7 @@ namespace Ocelot.Headers _placeholders = placeholders; } - public Response Replace(HttpResponseMessage response, List fAndRs, HttpRequestMessage request) + public Response Replace(HttpResponseMessage response, List fAndRs, DownstreamRequest request) { foreach (var f in fAndRs) { diff --git a/src/Ocelot/Headers/IAddHeadersToRequest.cs b/src/Ocelot/Headers/IAddHeadersToRequest.cs index fed2407c..c8b1c967 100644 --- a/src/Ocelot/Headers/IAddHeadersToRequest.cs +++ b/src/Ocelot/Headers/IAddHeadersToRequest.cs @@ -6,10 +6,11 @@ using Ocelot.Configuration; using Ocelot.Configuration.Creator; using Ocelot.Infrastructure.RequestData; + using Ocelot.Request.Middleware; using Ocelot.Responses; public interface IAddHeadersToRequest { - Response SetHeadersOnDownstreamRequest(List claimsToThings, IEnumerable claims, HttpRequestMessage downstreamRequest); + Response SetHeadersOnDownstreamRequest(List claimsToThings, IEnumerable claims, DownstreamRequest downstreamRequest); } } diff --git a/src/Ocelot/Headers/IHttpResponseHeaderReplacer.cs b/src/Ocelot/Headers/IHttpResponseHeaderReplacer.cs index 8e74d111..6c805b86 100644 --- a/src/Ocelot/Headers/IHttpResponseHeaderReplacer.cs +++ b/src/Ocelot/Headers/IHttpResponseHeaderReplacer.cs @@ -1,12 +1,13 @@ using System.Collections.Generic; using System.Net.Http; using Ocelot.Configuration; +using Ocelot.Request.Middleware; using Ocelot.Responses; namespace Ocelot.Headers { public interface IHttpResponseHeaderReplacer { - Response Replace(HttpResponseMessage response, List fAndRs, HttpRequestMessage httpRequestMessage); + Response Replace(HttpResponseMessage response, List fAndRs, DownstreamRequest httpRequestMessage); } } \ No newline at end of file diff --git a/src/Ocelot/Infrastructure/IPlaceholders.cs b/src/Ocelot/Infrastructure/IPlaceholders.cs index f95fb8b8..1d2bbfa5 100644 --- a/src/Ocelot/Infrastructure/IPlaceholders.cs +++ b/src/Ocelot/Infrastructure/IPlaceholders.cs @@ -1,4 +1,5 @@ using System.Net.Http; +using Ocelot.Request.Middleware; using Ocelot.Responses; namespace Ocelot.Infrastructure @@ -6,6 +7,6 @@ namespace Ocelot.Infrastructure public interface IPlaceholders { Response Get(string key); - Response Get(string key, HttpRequestMessage request); + Response Get(string key, DownstreamRequest request); } } \ No newline at end of file diff --git a/src/Ocelot/Infrastructure/Placeholders.cs b/src/Ocelot/Infrastructure/Placeholders.cs index b00f54fa..c43dea14 100644 --- a/src/Ocelot/Infrastructure/Placeholders.cs +++ b/src/Ocelot/Infrastructure/Placeholders.cs @@ -3,14 +3,15 @@ using System.Collections.Generic; using System.Net.Http; using Ocelot.Infrastructure.RequestData; using Ocelot.Middleware; +using Ocelot.Request.Middleware; using Ocelot.Responses; namespace Ocelot.Infrastructure { public class Placeholders : IPlaceholders { - private Dictionary>> _placeholders; - private Dictionary> _requestPlaceholders; + private readonly Dictionary>> _placeholders; + private readonly Dictionary> _requestPlaceholders; private readonly IBaseUrlFinder _finder; private readonly IRequestScopedDataRepository _repo; @@ -30,13 +31,13 @@ namespace Ocelot.Infrastructure return new OkResponse(traceId.Data); }); - _requestPlaceholders = new Dictionary>(); + _requestPlaceholders = new Dictionary>(); _requestPlaceholders.Add("{DownstreamBaseUrl}", x => { - var downstreamUrl = $"{x.RequestUri.Scheme}://{x.RequestUri.Host}"; + var downstreamUrl = $"{x.Scheme}://{x.Host}"; - if(x.RequestUri.Port != 80 && x.RequestUri.Port != 443) + if(x.Port != 80 && x.Port != 443) { - downstreamUrl = $"{downstreamUrl}:{x.RequestUri.Port}"; + downstreamUrl = $"{downstreamUrl}:{x.Port}"; } return $"{downstreamUrl}/"; @@ -57,7 +58,7 @@ namespace Ocelot.Infrastructure return new ErrorResponse(new CouldNotFindPlaceholderError(key)); } - public Response Get(string key, HttpRequestMessage request) + public Response Get(string key, DownstreamRequest request) { if(_requestPlaceholders.ContainsKey(key)) { @@ -67,4 +68,4 @@ namespace Ocelot.Infrastructure return new ErrorResponse(new CouldNotFindPlaceholderError(key)); } } -} \ No newline at end of file +} diff --git a/src/Ocelot/LoadBalancer/Middleware/LoadBalancingMiddleware.cs b/src/Ocelot/LoadBalancer/Middleware/LoadBalancingMiddleware.cs index 8c2e963a..82f792f9 100644 --- a/src/Ocelot/LoadBalancer/Middleware/LoadBalancingMiddleware.cs +++ b/src/Ocelot/LoadBalancer/Middleware/LoadBalancingMiddleware.cs @@ -1,12 +1,8 @@ using System; using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; -using Ocelot.DownstreamRouteFinder.Middleware; -using Ocelot.Infrastructure.RequestData; using Ocelot.LoadBalancer.LoadBalancers; using Ocelot.Logging; using Ocelot.Middleware; -using Ocelot.QueryStrings.Middleware; namespace Ocelot.LoadBalancer.Middleware { @@ -43,17 +39,13 @@ namespace Ocelot.LoadBalancer.Middleware return; } - var uriBuilder = new UriBuilder(context.DownstreamRequest.RequestUri); - - uriBuilder.Host = hostAndPort.Data.DownstreamHost; + context.DownstreamRequest.Host = hostAndPort.Data.DownstreamHost; if (hostAndPort.Data.DownstreamPort > 0) { - uriBuilder.Port = hostAndPort.Data.DownstreamPort; + context.DownstreamRequest.Port = hostAndPort.Data.DownstreamPort; } - context.DownstreamRequest.RequestUri = uriBuilder.Uri; - try { await _next.Invoke(context); diff --git a/src/Ocelot/Middleware/DownstreamContext.cs b/src/Ocelot/Middleware/DownstreamContext.cs index 1c805deb..9e054eb5 100644 --- a/src/Ocelot/Middleware/DownstreamContext.cs +++ b/src/Ocelot/Middleware/DownstreamContext.cs @@ -1,9 +1,11 @@ +using System; using System.Collections.Generic; using System.Net.Http; using Microsoft.AspNetCore.Http; using Ocelot.Configuration; using Ocelot.DownstreamRouteFinder.UrlMatcher; using Ocelot.Errors; +using Ocelot.Request.Middleware; namespace Ocelot.Middleware { @@ -19,7 +21,7 @@ namespace Ocelot.Middleware public ServiceProviderConfiguration ServiceProviderConfiguration {get; set;} public HttpContext HttpContext { get; private set; } public DownstreamReRoute DownstreamReRoute { get; set; } - public HttpRequestMessage DownstreamRequest { get; set; } + public DownstreamRequest DownstreamRequest { get; set; } public HttpResponseMessage DownstreamResponse { get; set; } public List Errors { get;set; } public bool IsError => Errors.Count > 0; diff --git a/src/Ocelot/Middleware/Pipeline/IOcelotPipelineBuilder.cs b/src/Ocelot/Middleware/Pipeline/IOcelotPipelineBuilder.cs index 3bc0d6b0..9cb0db56 100644 --- a/src/Ocelot/Middleware/Pipeline/IOcelotPipelineBuilder.cs +++ b/src/Ocelot/Middleware/Pipeline/IOcelotPipelineBuilder.cs @@ -11,5 +11,6 @@ namespace Ocelot.Middleware.Pipeline IServiceProvider ApplicationServices { get; } OcelotPipelineBuilder Use(Func middleware); OcelotRequestDelegate Build(); + IOcelotPipelineBuilder New(); } } diff --git a/src/Ocelot/Middleware/Pipeline/MapWhenMiddleware.cs b/src/Ocelot/Middleware/Pipeline/MapWhenMiddleware.cs new file mode 100644 index 00000000..f05c35e4 --- /dev/null +++ b/src/Ocelot/Middleware/Pipeline/MapWhenMiddleware.cs @@ -0,0 +1,44 @@ +using System; +using System.Threading.Tasks; + +namespace Ocelot.Middleware.Pipeline +{ + public class MapWhenMiddleware + { + private readonly OcelotRequestDelegate _next; + private readonly MapWhenOptions _options; + + public MapWhenMiddleware(OcelotRequestDelegate next, MapWhenOptions options) + { + if (next == null) + { + throw new ArgumentNullException(nameof(next)); + } + + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + _next = next; + _options = options; + } + + public async Task Invoke(DownstreamContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + if (_options.Predicate(context)) + { + await _options.Branch(context); + } + else + { + await _next(context); + } + } + } +} diff --git a/src/Ocelot/Middleware/Pipeline/MapWhenOptions.cs b/src/Ocelot/Middleware/Pipeline/MapWhenOptions.cs new file mode 100644 index 00000000..912688c3 --- /dev/null +++ b/src/Ocelot/Middleware/Pipeline/MapWhenOptions.cs @@ -0,0 +1,28 @@ +using System; + +namespace Ocelot.Middleware.Pipeline +{ + public class MapWhenOptions + { + private Func _predicate; + + public Func Predicate + { + get + { + return _predicate; + } + set + { + if (value == null) + { + throw new ArgumentNullException(nameof(value)); + } + + _predicate = value; + } + } + + public OcelotRequestDelegate Branch { get; set; } + } +} diff --git a/src/Ocelot/Middleware/Pipeline/OcelotPipelineBuilder.cs b/src/Ocelot/Middleware/Pipeline/OcelotPipelineBuilder.cs index 1e37514c..5877ab62 100644 --- a/src/Ocelot/Middleware/Pipeline/OcelotPipelineBuilder.cs +++ b/src/Ocelot/Middleware/Pipeline/OcelotPipelineBuilder.cs @@ -19,6 +19,12 @@ namespace Ocelot.Middleware.Pipeline _middlewares = new List>(); } + public OcelotPipelineBuilder(IOcelotPipelineBuilder builder) + { + ApplicationServices = builder.ApplicationServices; + _middlewares = new List>(); + } + public IServiceProvider ApplicationServices { get; } public OcelotPipelineBuilder Use(Func middleware) @@ -42,5 +48,10 @@ namespace Ocelot.Middleware.Pipeline return app; } + + public IOcelotPipelineBuilder New() + { + return new OcelotPipelineBuilder(this); + } } } diff --git a/src/Ocelot/Middleware/Pipeline/OcelotPipelineBuilderExtensions.cs b/src/Ocelot/Middleware/Pipeline/OcelotPipelineBuilderExtensions.cs index 2469bf7c..968b7720 100644 --- a/src/Ocelot/Middleware/Pipeline/OcelotPipelineBuilderExtensions.cs +++ b/src/Ocelot/Middleware/Pipeline/OcelotPipelineBuilderExtensions.cs @@ -12,6 +12,8 @@ using Microsoft.Extensions.DependencyInjection; namespace Ocelot.Middleware.Pipeline { + using Predicate = Func; + public static class OcelotPipelineBuilderExtensions { internal const string InvokeMethodName = "Invoke"; @@ -91,6 +93,35 @@ namespace Ocelot.Middleware.Pipeline }); } + public static IOcelotPipelineBuilder MapWhen(this IOcelotPipelineBuilder app, Predicate predicate, Action configuration) + { + if (app == null) + { + throw new ArgumentNullException(nameof(app)); + } + + if (predicate == null) + { + throw new ArgumentNullException(nameof(predicate)); + } + + if (configuration == null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + var branchBuilder = app.New(); + configuration(branchBuilder); + var branch = branchBuilder.Build(); + + var options = new MapWhenOptions + { + Predicate = predicate, + Branch = branch, + }; + return app.Use(next => new MapWhenMiddleware(next, options).Invoke); + } + private static Func Compile(MethodInfo methodinfo, ParameterInfo[] parameters) { var middleware = typeof(T); diff --git a/src/Ocelot/Middleware/Pipeline/OcelotPipelineExtensions.cs b/src/Ocelot/Middleware/Pipeline/OcelotPipelineExtensions.cs index 374b867f..f9a74235 100644 --- a/src/Ocelot/Middleware/Pipeline/OcelotPipelineExtensions.cs +++ b/src/Ocelot/Middleware/Pipeline/OcelotPipelineExtensions.cs @@ -15,18 +15,30 @@ using Ocelot.Request.Middleware; using Ocelot.Requester.Middleware; using Ocelot.RequestId.Middleware; using Ocelot.Responder.Middleware; +using Ocelot.WebSockets.Middleware; namespace Ocelot.Middleware.Pipeline { public static class OcelotPipelineExtensions { public static OcelotRequestDelegate BuildOcelotPipeline(this IOcelotPipelineBuilder builder, - OcelotPipelineConfiguration pipelineConfiguration = null) + OcelotPipelineConfiguration pipelineConfiguration) { // This is registered to catch any global exceptions that are not handled // It also sets the Request Id if anything is set globally builder.UseExceptionHandlerMiddleware(); + // If the request is for websockets upgrade we fork into a different pipeline + builder.MapWhen(context => context.HttpContext.WebSockets.IsWebSocketRequest, + app => + { + app.UseDownstreamRouteFinderMiddleware(); + app.UseDownstreamRequestInitialiser(); + app.UseLoadBalancingMiddleware(); + app.UseDownstreamUrlCreatorMiddleware(); + app.UseWebSocketsProxyMiddleware(); + }); + // Allow the user to respond with absolutely anything they want. builder.UseIfNotNull(pipelineConfiguration.PreErrorResponderMiddleware); diff --git a/src/Ocelot/QueryStrings/AddQueriesToRequest.cs b/src/Ocelot/QueryStrings/AddQueriesToRequest.cs index 74cc7696..3cc2abdf 100644 --- a/src/Ocelot/QueryStrings/AddQueriesToRequest.cs +++ b/src/Ocelot/QueryStrings/AddQueriesToRequest.cs @@ -7,6 +7,7 @@ using Ocelot.Responses; using System.Security.Claims; using System.Net.Http; using System; +using Ocelot.Request.Middleware; using Microsoft.Extensions.Primitives; using System.Text; @@ -21,9 +22,9 @@ namespace Ocelot.QueryStrings _claimsParser = claimsParser; } - public Response SetQueriesOnDownstreamRequest(List claimsToThings, IEnumerable claims, HttpRequestMessage downstreamRequest) + public Response SetQueriesOnDownstreamRequest(List claimsToThings, IEnumerable claims, DownstreamRequest downstreamRequest) { - var queryDictionary = ConvertQueryStringToDictionary(downstreamRequest.RequestUri.Query); + var queryDictionary = ConvertQueryStringToDictionary(downstreamRequest.Query); foreach (var config in claimsToThings) { @@ -46,11 +47,7 @@ namespace Ocelot.QueryStrings } } - var uriBuilder = new UriBuilder(downstreamRequest.RequestUri); - - uriBuilder.Query = ConvertDictionaryToQueryString(queryDictionary); - - downstreamRequest.RequestUri = uriBuilder.Uri; + downstreamRequest.Query = ConvertDictionaryToQueryString(queryDictionary); return new OkResponse(); } @@ -94,4 +91,4 @@ namespace Ocelot.QueryStrings return builder.ToString(); } } -} \ No newline at end of file +} diff --git a/src/Ocelot/QueryStrings/IAddQueriesToRequest.cs b/src/Ocelot/QueryStrings/IAddQueriesToRequest.cs index 34a6c2f5..bc017936 100644 --- a/src/Ocelot/QueryStrings/IAddQueriesToRequest.cs +++ b/src/Ocelot/QueryStrings/IAddQueriesToRequest.cs @@ -4,11 +4,12 @@ using Ocelot.Configuration; using Ocelot.Responses; using System.Net.Http; using System.Security.Claims; +using Ocelot.Request.Middleware; namespace Ocelot.QueryStrings { public interface IAddQueriesToRequest { - Response SetQueriesOnDownstreamRequest(List claimsToThings, IEnumerable claims, HttpRequestMessage downstreamRequest); + Response SetQueriesOnDownstreamRequest(List claimsToThings, IEnumerable claims, DownstreamRequest downstreamRequest); } } diff --git a/src/Ocelot/Request/Middleware/DownstreamRequest.cs b/src/Ocelot/Request/Middleware/DownstreamRequest.cs new file mode 100644 index 00000000..449b33cc --- /dev/null +++ b/src/Ocelot/Request/Middleware/DownstreamRequest.cs @@ -0,0 +1,69 @@ +namespace Ocelot.Request.Middleware +{ + using System; + using System.Net.Http; + using System.Net.Http.Headers; + + public class DownstreamRequest + { + private readonly HttpRequestMessage _request; + + public DownstreamRequest(HttpRequestMessage request) + { + _request = request; + Method = _request.Method.Method; + OriginalString = _request.RequestUri.OriginalString; + Scheme = _request.RequestUri.Scheme; + Host = _request.RequestUri.Host; + Port = _request.RequestUri.Port; + Headers = _request.Headers; + AbsolutePath = _request.RequestUri.AbsolutePath; + Query = _request.RequestUri.Query; + } + + public HttpRequestHeaders Headers { get; } + + public string Method { get; } + + public string OriginalString { get; } + + public string Scheme { get; set; } + + public string Host { get; set; } + + public int Port { get; set; } + + public string AbsolutePath { get; set; } + + public string Query { get; set; } + + public HttpRequestMessage ToHttpRequestMessage() + { + var uriBuilder = new UriBuilder + { + Port = Port, + Host = Host, + Path = AbsolutePath, + Query = Query, + Scheme = Scheme + }; + + _request.RequestUri = uriBuilder.Uri; + return _request; + } + + public string ToUri() + { + var uriBuilder = new UriBuilder + { + Port = Port, + Host = Host, + Path = AbsolutePath, + Query = Query, + Scheme = Scheme + }; + + return uriBuilder.Uri.AbsoluteUri; + } + } +} diff --git a/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs b/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs index f14c1394..ecdb134a 100644 --- a/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs +++ b/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs @@ -1,5 +1,6 @@ namespace Ocelot.Request.Middleware { + using System.Net.Http; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Ocelot.DownstreamRouteFinder.Middleware; @@ -31,9 +32,9 @@ namespace Ocelot.Request.Middleware return; } - context.DownstreamRequest = downstreamRequest.Data; + context.DownstreamRequest = new DownstreamRequest(downstreamRequest.Data); await _next.Invoke(context); } } -} \ No newline at end of file +} diff --git a/src/Ocelot/RequestId/Middleware/ReRouteRequestIdMiddleware.cs b/src/Ocelot/RequestId/Middleware/ReRouteRequestIdMiddleware.cs index 32ee6c68..7ae416b0 100644 --- a/src/Ocelot/RequestId/Middleware/ReRouteRequestIdMiddleware.cs +++ b/src/Ocelot/RequestId/Middleware/ReRouteRequestIdMiddleware.cs @@ -9,6 +9,7 @@ using System.Net.Http; using System.Net.Http.Headers; using System.Collections.Generic; using Ocelot.DownstreamRouteFinder.Middleware; +using Ocelot.Request.Middleware; namespace Ocelot.RequestId.Middleware { @@ -82,7 +83,7 @@ namespace Ocelot.RequestId.Middleware return headers.TryGetValues(requestId.RequestIdKey, out value); } - private void AddRequestIdHeader(RequestId requestId, HttpRequestMessage httpRequestMessage) + private void AddRequestIdHeader(RequestId requestId, DownstreamRequest httpRequestMessage) { httpRequestMessage.Headers.Add(requestId.RequestIdKey, requestId.RequestIdValue); } diff --git a/src/Ocelot/Requester/HttpClientBuilder.cs b/src/Ocelot/Requester/HttpClientBuilder.cs index 6cbb3aec..348bb207 100644 --- a/src/Ocelot/Requester/HttpClientBuilder.cs +++ b/src/Ocelot/Requester/HttpClientBuilder.cs @@ -78,9 +78,7 @@ namespace Ocelot.Requester private string GetCacheKey(DownstreamContext request) { - var baseUrl = $"{request.DownstreamRequest.RequestUri.Scheme}://{request.DownstreamRequest.RequestUri.Authority}{request.DownstreamRequest.RequestUri.AbsolutePath}"; - - return baseUrl; + return request.DownstreamRequest.OriginalString; } } } diff --git a/src/Ocelot/Requester/HttpClientHttpRequester.cs b/src/Ocelot/Requester/HttpClientHttpRequester.cs index 4202f611..ddbcf119 100644 --- a/src/Ocelot/Requester/HttpClientHttpRequester.cs +++ b/src/Ocelot/Requester/HttpClientHttpRequester.cs @@ -32,7 +32,7 @@ namespace Ocelot.Requester try { - var response = await httpClient.SendAsync(context.DownstreamRequest); + var response = await httpClient.SendAsync(context.DownstreamRequest.ToHttpRequestMessage()); return new OkResponse(response); } catch (TimeoutRejectedException exception) diff --git a/src/Ocelot/ServiceDiscovery/ConsulRegistryConfiguration.cs b/src/Ocelot/ServiceDiscovery/Configuration/ConsulRegistryConfiguration.cs similarity index 90% rename from src/Ocelot/ServiceDiscovery/ConsulRegistryConfiguration.cs rename to src/Ocelot/ServiceDiscovery/Configuration/ConsulRegistryConfiguration.cs index ba389c05..9a9e5de8 100644 --- a/src/Ocelot/ServiceDiscovery/ConsulRegistryConfiguration.cs +++ b/src/Ocelot/ServiceDiscovery/Configuration/ConsulRegistryConfiguration.cs @@ -1,16 +1,16 @@ -namespace Ocelot.ServiceDiscovery -{ - public class ConsulRegistryConfiguration - { - public ConsulRegistryConfiguration(string hostName, int port, string keyOfServiceInConsul) - { - HostName = hostName; - Port = port; - KeyOfServiceInConsul = keyOfServiceInConsul; - } - - public string KeyOfServiceInConsul { get; private set; } - public string HostName { get; private set; } - public int Port { get; private set; } - } -} \ No newline at end of file +namespace Ocelot.ServiceDiscovery.Configuration +{ + public class ConsulRegistryConfiguration + { + public ConsulRegistryConfiguration(string hostName, int port, string keyOfServiceInConsul) + { + HostName = hostName; + Port = port; + KeyOfServiceInConsul = keyOfServiceInConsul; + } + + public string KeyOfServiceInConsul { get; private set; } + public string HostName { get; private set; } + public int Port { get; private set; } + } +} diff --git a/src/Ocelot/ServiceDiscovery/ServiceFabricConfiguration.cs b/src/Ocelot/ServiceDiscovery/Configuration/ServiceFabricConfiguration.cs similarity index 58% rename from src/Ocelot/ServiceDiscovery/ServiceFabricConfiguration.cs rename to src/Ocelot/ServiceDiscovery/Configuration/ServiceFabricConfiguration.cs index 7522a1e0..73211a5b 100644 --- a/src/Ocelot/ServiceDiscovery/ServiceFabricConfiguration.cs +++ b/src/Ocelot/ServiceDiscovery/Configuration/ServiceFabricConfiguration.cs @@ -1,4 +1,4 @@ -namespace Ocelot.ServiceDiscovery +namespace Ocelot.ServiceDiscovery.Configuration { public class ServiceFabricConfiguration { @@ -9,8 +9,10 @@ ServiceName = serviceName; } - public string ServiceName { get; private set; } - public string HostName { get; private set; } - public int Port { get; private set; } + public string ServiceName { get; } + + public string HostName { get; } + + public int Port { get; } } } diff --git a/src/Ocelot/ServiceDiscovery/UnableToFindServiceDiscoveryProviderError.cs b/src/Ocelot/ServiceDiscovery/Errors/UnableToFindServiceDiscoveryProviderError.cs similarity index 86% rename from src/Ocelot/ServiceDiscovery/UnableToFindServiceDiscoveryProviderError.cs rename to src/Ocelot/ServiceDiscovery/Errors/UnableToFindServiceDiscoveryProviderError.cs index 639e4659..a31ed2ee 100644 --- a/src/Ocelot/ServiceDiscovery/UnableToFindServiceDiscoveryProviderError.cs +++ b/src/Ocelot/ServiceDiscovery/Errors/UnableToFindServiceDiscoveryProviderError.cs @@ -1,12 +1,12 @@ -using Ocelot.Errors; - -namespace Ocelot.ServiceDiscovery -{ - public class UnableToFindServiceDiscoveryProviderError : Error - { - public UnableToFindServiceDiscoveryProviderError(string message) - : base(message, OcelotErrorCode.UnableToFindServiceDiscoveryProviderError) - { - } - } -} \ No newline at end of file +using Ocelot.Errors; + +namespace Ocelot.ServiceDiscovery.Errors +{ + public class UnableToFindServiceDiscoveryProviderError : Error + { + public UnableToFindServiceDiscoveryProviderError(string message) + : base(message, OcelotErrorCode.UnableToFindServiceDiscoveryProviderError) + { + } + } +} diff --git a/src/Ocelot/ServiceDiscovery/IServiceDiscoveryProviderFactory.cs b/src/Ocelot/ServiceDiscovery/IServiceDiscoveryProviderFactory.cs index 9f0bc93a..91e9c700 100644 --- a/src/Ocelot/ServiceDiscovery/IServiceDiscoveryProviderFactory.cs +++ b/src/Ocelot/ServiceDiscovery/IServiceDiscoveryProviderFactory.cs @@ -1,4 +1,5 @@ using Ocelot.Configuration; +using Ocelot.ServiceDiscovery.Providers; namespace Ocelot.ServiceDiscovery { @@ -6,4 +7,4 @@ namespace Ocelot.ServiceDiscovery { IServiceDiscoveryProvider Get(ServiceProviderConfiguration serviceConfig, DownstreamReRoute reRoute); } -} \ No newline at end of file +} diff --git a/src/Ocelot/ServiceDiscovery/ConfigurationServiceProvider.cs b/src/Ocelot/ServiceDiscovery/Providers/ConfigurationServiceProvider.cs similarity index 88% rename from src/Ocelot/ServiceDiscovery/ConfigurationServiceProvider.cs rename to src/Ocelot/ServiceDiscovery/Providers/ConfigurationServiceProvider.cs index 28a296c5..04369996 100644 --- a/src/Ocelot/ServiceDiscovery/ConfigurationServiceProvider.cs +++ b/src/Ocelot/ServiceDiscovery/Providers/ConfigurationServiceProvider.cs @@ -1,21 +1,21 @@ -using System.Collections.Generic; -using System.Threading.Tasks; -using Ocelot.Values; - -namespace Ocelot.ServiceDiscovery -{ - public class ConfigurationServiceProvider : IServiceDiscoveryProvider - { - private readonly List _services; - - public ConfigurationServiceProvider(List services) - { - _services = services; - } - - public async Task> Get() - { - return await Task.FromResult(_services); - } - } -} \ No newline at end of file +using System.Collections.Generic; +using System.Threading.Tasks; +using Ocelot.Values; + +namespace Ocelot.ServiceDiscovery.Providers +{ + public class ConfigurationServiceProvider : IServiceDiscoveryProvider + { + private readonly List _services; + + public ConfigurationServiceProvider(List services) + { + _services = services; + } + + public async Task> Get() + { + return await Task.FromResult(_services); + } + } +} diff --git a/src/Ocelot/ServiceDiscovery/ConsulServiceDiscoveryProvider.cs b/src/Ocelot/ServiceDiscovery/Providers/ConsulServiceDiscoveryProvider.cs similarity index 96% rename from src/Ocelot/ServiceDiscovery/ConsulServiceDiscoveryProvider.cs rename to src/Ocelot/ServiceDiscovery/Providers/ConsulServiceDiscoveryProvider.cs index ef882a50..775033fd 100644 --- a/src/Ocelot/ServiceDiscovery/ConsulServiceDiscoveryProvider.cs +++ b/src/Ocelot/ServiceDiscovery/Providers/ConsulServiceDiscoveryProvider.cs @@ -1,83 +1,84 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -using Consul; -using Ocelot.Infrastructure.Extensions; -using Ocelot.Logging; -using Ocelot.Values; - -namespace Ocelot.ServiceDiscovery -{ - public class ConsulServiceDiscoveryProvider : IServiceDiscoveryProvider - { - private readonly ConsulRegistryConfiguration _consulConfig; - private readonly IOcelotLogger _logger; - private readonly ConsulClient _consul; - private const string VersionPrefix = "version-"; - - public ConsulServiceDiscoveryProvider(ConsulRegistryConfiguration consulRegistryConfiguration, IOcelotLoggerFactory factory) - {; - _logger = factory.CreateLogger(); - - var consulHost = string.IsNullOrEmpty(consulRegistryConfiguration?.HostName) ? "localhost" : consulRegistryConfiguration.HostName; - - var consulPort = consulRegistryConfiguration?.Port ?? 8500; - - _consulConfig = new ConsulRegistryConfiguration(consulHost, consulPort, consulRegistryConfiguration?.KeyOfServiceInConsul); - - _consul = new ConsulClient(config => - { - config.Address = new Uri($"http://{_consulConfig.HostName}:{_consulConfig.Port}"); - }); - } - - public async Task> Get() - { - var queryResult = await _consul.Health.Service(_consulConfig.KeyOfServiceInConsul, string.Empty, true); - - var services = new List(); - - foreach (var serviceEntry in queryResult.Response) - { - if (IsValid(serviceEntry)) - { - services.Add(BuildService(serviceEntry)); - } - else - { - _logger.LogError($"Unable to use service Address: {serviceEntry.Service.Address} and Port: {serviceEntry.Service.Port} as it is invalid. Address must contain host only e.g. localhost and port must be greater than 0"); - } - } - - return services.ToList(); - } - - private Service BuildService(ServiceEntry serviceEntry) - { - return new Service( - serviceEntry.Service.Service, - new ServiceHostAndPort(serviceEntry.Service.Address, serviceEntry.Service.Port), - serviceEntry.Service.ID, - GetVersionFromStrings(serviceEntry.Service.Tags), - serviceEntry.Service.Tags ?? Enumerable.Empty()); - } - - private bool IsValid(ServiceEntry serviceEntry) - { - if (serviceEntry.Service.Address.Contains("http://") || serviceEntry.Service.Address.Contains("https://") || serviceEntry.Service.Port <= 0) - { - return false; - } - - return true; - } - - private string GetVersionFromStrings(IEnumerable strings) - { - return strings - ?.FirstOrDefault(x => x.StartsWith(VersionPrefix, StringComparison.Ordinal)) - .TrimStart(VersionPrefix); - } - } -} +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Consul; +using Ocelot.Infrastructure.Extensions; +using Ocelot.Logging; +using Ocelot.ServiceDiscovery.Configuration; +using Ocelot.Values; + +namespace Ocelot.ServiceDiscovery.Providers +{ + public class ConsulServiceDiscoveryProvider : IServiceDiscoveryProvider + { + private readonly ConsulRegistryConfiguration _consulConfig; + private readonly IOcelotLogger _logger; + private readonly ConsulClient _consul; + private const string VersionPrefix = "version-"; + + public ConsulServiceDiscoveryProvider(ConsulRegistryConfiguration consulRegistryConfiguration, IOcelotLoggerFactory factory) + {; + _logger = factory.CreateLogger(); + + var consulHost = string.IsNullOrEmpty(consulRegistryConfiguration?.HostName) ? "localhost" : consulRegistryConfiguration.HostName; + + var consulPort = consulRegistryConfiguration?.Port ?? 8500; + + _consulConfig = new ConsulRegistryConfiguration(consulHost, consulPort, consulRegistryConfiguration?.KeyOfServiceInConsul); + + _consul = new ConsulClient(config => + { + config.Address = new Uri($"http://{_consulConfig.HostName}:{_consulConfig.Port}"); + }); + } + + public async Task> Get() + { + var queryResult = await _consul.Health.Service(_consulConfig.KeyOfServiceInConsul, string.Empty, true); + + var services = new List(); + + foreach (var serviceEntry in queryResult.Response) + { + if (IsValid(serviceEntry)) + { + services.Add(BuildService(serviceEntry)); + } + else + { + _logger.LogError($"Unable to use service Address: {serviceEntry.Service.Address} and Port: {serviceEntry.Service.Port} as it is invalid. Address must contain host only e.g. localhost and port must be greater than 0"); + } + } + + return services.ToList(); + } + + private Service BuildService(ServiceEntry serviceEntry) + { + return new Service( + serviceEntry.Service.Service, + new ServiceHostAndPort(serviceEntry.Service.Address, serviceEntry.Service.Port), + serviceEntry.Service.ID, + GetVersionFromStrings(serviceEntry.Service.Tags), + serviceEntry.Service.Tags ?? Enumerable.Empty()); + } + + private bool IsValid(ServiceEntry serviceEntry) + { + if (serviceEntry.Service.Address.Contains("http://") || serviceEntry.Service.Address.Contains("https://") || serviceEntry.Service.Port <= 0) + { + return false; + } + + return true; + } + + private string GetVersionFromStrings(IEnumerable strings) + { + return strings + ?.FirstOrDefault(x => x.StartsWith(VersionPrefix, StringComparison.Ordinal)) + .TrimStart(VersionPrefix); + } + } +} diff --git a/src/Ocelot/ServiceDiscovery/IServiceDiscoveryProvider.cs b/src/Ocelot/ServiceDiscovery/Providers/IServiceDiscoveryProvider.cs similarity index 79% rename from src/Ocelot/ServiceDiscovery/IServiceDiscoveryProvider.cs rename to src/Ocelot/ServiceDiscovery/Providers/IServiceDiscoveryProvider.cs index 3d9887f3..ef28375e 100644 --- a/src/Ocelot/ServiceDiscovery/IServiceDiscoveryProvider.cs +++ b/src/Ocelot/ServiceDiscovery/Providers/IServiceDiscoveryProvider.cs @@ -1,11 +1,11 @@ -using System.Collections.Generic; -using System.Threading.Tasks; -using Ocelot.Values; - -namespace Ocelot.ServiceDiscovery -{ - public interface IServiceDiscoveryProvider - { - Task> Get(); - } -} \ No newline at end of file +using System.Collections.Generic; +using System.Threading.Tasks; +using Ocelot.Values; + +namespace Ocelot.ServiceDiscovery.Providers +{ + public interface IServiceDiscoveryProvider + { + Task> Get(); + } +} diff --git a/src/Ocelot/ServiceDiscovery/ServiceFabricServiceDiscoveryProvider.cs b/src/Ocelot/ServiceDiscovery/Providers/ServiceFabricServiceDiscoveryProvider.cs similarity index 90% rename from src/Ocelot/ServiceDiscovery/ServiceFabricServiceDiscoveryProvider.cs rename to src/Ocelot/ServiceDiscovery/Providers/ServiceFabricServiceDiscoveryProvider.cs index 257298b7..bc7ebf6e 100644 --- a/src/Ocelot/ServiceDiscovery/ServiceFabricServiceDiscoveryProvider.cs +++ b/src/Ocelot/ServiceDiscovery/Providers/ServiceFabricServiceDiscoveryProvider.cs @@ -1,8 +1,9 @@ using System.Collections.Generic; using System.Threading.Tasks; +using Ocelot.ServiceDiscovery.Configuration; using Ocelot.Values; -namespace Ocelot.ServiceDiscovery +namespace Ocelot.ServiceDiscovery.Providers { public class ServiceFabricServiceDiscoveryProvider : IServiceDiscoveryProvider { diff --git a/src/Ocelot/ServiceDiscovery/ServiceDiscoveryProviderFactory.cs b/src/Ocelot/ServiceDiscovery/ServiceDiscoveryProviderFactory.cs index 86b28a72..95201dfb 100644 --- a/src/Ocelot/ServiceDiscovery/ServiceDiscoveryProviderFactory.cs +++ b/src/Ocelot/ServiceDiscovery/ServiceDiscoveryProviderFactory.cs @@ -1,6 +1,8 @@ using System.Collections.Generic; using Ocelot.Configuration; using Ocelot.Logging; +using Ocelot.ServiceDiscovery.Configuration; +using Ocelot.ServiceDiscovery.Providers; using Ocelot.Values; namespace Ocelot.ServiceDiscovery diff --git a/src/Ocelot/Values/DownstreamPath.cs b/src/Ocelot/Values/DownstreamPath.cs index b4dd346b..f0821118 100644 --- a/src/Ocelot/Values/DownstreamPath.cs +++ b/src/Ocelot/Values/DownstreamPath.cs @@ -7,6 +7,6 @@ Value = value; } - public string Value { get; private set; } + public string Value { get; } } } diff --git a/src/Ocelot/Values/DownstreamUrl.cs b/src/Ocelot/Values/DownstreamUrl.cs deleted file mode 100644 index 48644595..00000000 --- a/src/Ocelot/Values/DownstreamUrl.cs +++ /dev/null @@ -1,12 +0,0 @@ -namespace Ocelot.Values -{ - public class DownstreamUrl - { - public DownstreamUrl(string value) - { - Value = value; - } - - public string Value { get; private set; } - } -} \ No newline at end of file diff --git a/src/Ocelot/Values/PathTemplate.cs b/src/Ocelot/Values/PathTemplate.cs index 6d7221a6..584f80ac 100644 --- a/src/Ocelot/Values/PathTemplate.cs +++ b/src/Ocelot/Values/PathTemplate.cs @@ -7,6 +7,6 @@ Value = value; } - public string Value { get; private set; } + public string Value { get; } } } diff --git a/src/Ocelot/Values/Service.cs b/src/Ocelot/Values/Service.cs index abf5449b..234b248b 100644 --- a/src/Ocelot/Values/Service.cs +++ b/src/Ocelot/Values/Service.cs @@ -5,9 +5,9 @@ namespace Ocelot.Values public class Service { public Service(string name, - ServiceHostAndPort hostAndPort, - string id, - string version, + ServiceHostAndPort hostAndPort, + string id, + string version, IEnumerable tags) { Name = name; @@ -17,14 +17,14 @@ namespace Ocelot.Values Tags = tags; } - public string Id { get; private set; } + public string Id { get; } - public string Name { get; private set; } + public string Name { get; } - public string Version { get; private set; } + public string Version { get; } - public IEnumerable Tags { get; private set; } + public IEnumerable Tags { get; } - public ServiceHostAndPort HostAndPort { get; private set; } + public ServiceHostAndPort HostAndPort { get; } } } diff --git a/src/Ocelot/Values/ServiceHostAndPort.cs b/src/Ocelot/Values/ServiceHostAndPort.cs index 135944b1..4e4271b8 100644 --- a/src/Ocelot/Values/ServiceHostAndPort.cs +++ b/src/Ocelot/Values/ServiceHostAndPort.cs @@ -8,7 +8,8 @@ DownstreamPort = downstreamPort; } - public string DownstreamHost { get; private set; } - public int DownstreamPort { get; private set; } + public string DownstreamHost { get; } + + public int DownstreamPort { get; } } } diff --git a/src/Ocelot/Values/UpstreamPathTemplate.cs b/src/Ocelot/Values/UpstreamPathTemplate.cs index 0cd44bc5..4b33a230 100644 --- a/src/Ocelot/Values/UpstreamPathTemplate.cs +++ b/src/Ocelot/Values/UpstreamPathTemplate.cs @@ -8,7 +8,8 @@ namespace Ocelot.Values Priority = priority; } - public string Template {get;} - public int Priority {get;} + public string Template { get; } + + public int Priority { get; } } -} \ No newline at end of file +} diff --git a/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddleware.cs b/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddleware.cs new file mode 100644 index 00000000..7a9e0ad0 --- /dev/null +++ b/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddleware.cs @@ -0,0 +1,103 @@ +using System; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Ocelot.Logging; +using Ocelot.Middleware; + +namespace Ocelot.WebSockets.Middleware +{ + public class WebSocketsProxyMiddleware : OcelotMiddleware + { + private OcelotRequestDelegate _next; + private IOcelotLogger _logger; + + public WebSocketsProxyMiddleware(OcelotRequestDelegate next, + IOcelotLoggerFactory loggerFactory) + { + _next = next; + _logger = loggerFactory.CreateLogger(); + } + + public async Task Invoke(DownstreamContext context) + { + await Proxy(context.HttpContext, context.DownstreamRequest.ToUri()); + } + + private async Task Proxy(HttpContext context, string serverEndpoint) + { + var wsToUpstreamClient = await context.WebSockets.AcceptWebSocketAsync(); + + var wsToDownstreamService = new ClientWebSocket(); + var uri = new Uri(serverEndpoint); + await wsToDownstreamService.ConnectAsync(uri, CancellationToken.None); + + var receiveFromUpstreamSendToDownstream = Task.Run(async () => + { + var buffer = new byte[1024 * 4]; + + var receiveSegment = new ArraySegment(buffer); + + while (wsToUpstreamClient.State == WebSocketState.Open || wsToUpstreamClient.State == WebSocketState.CloseSent) + { + var result = await wsToUpstreamClient.ReceiveAsync(receiveSegment, CancellationToken.None); + + var sendSegment = new ArraySegment(buffer, 0, result.Count); + + if(result.MessageType == WebSocketMessageType.Close) + { + await wsToUpstreamClient.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", + CancellationToken.None); + + await wsToDownstreamService.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", + CancellationToken.None); + + break; + } + + await wsToDownstreamService.SendAsync(sendSegment, result.MessageType, result.EndOfMessage, + CancellationToken.None); + + if (wsToUpstreamClient.State != WebSocketState.Open) + { + await wsToDownstreamService.CloseAsync(WebSocketCloseStatus.Empty, "", + CancellationToken.None); + break; + } + } + }); + + var receiveFromDownstreamAndSendToUpstream = Task.Run(async () => + { + var buffer = new byte[1024 * 4]; + + while (wsToDownstreamService.State == WebSocketState.Open || wsToDownstreamService.State == WebSocketState.CloseSent) + { + if (wsToUpstreamClient.State != WebSocketState.Open) + { + break; + } + else + { + var receiveSegment = new ArraySegment(buffer); + var result = await wsToDownstreamService.ReceiveAsync(receiveSegment, CancellationToken.None); + + if (result.MessageType == WebSocketMessageType.Close) + { + break; + } + + var sendSegment = new ArraySegment(buffer, 0, result.Count); + + //send to upstream client + await wsToUpstreamClient.SendAsync(sendSegment, result.MessageType, result.EndOfMessage, + CancellationToken.None); + } + } + }); + + await Task.WhenAll(receiveFromDownstreamAndSendToUpstream, receiveFromUpstreamSendToDownstream); + } + } +} diff --git a/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddlewareExtensions.cs b/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddlewareExtensions.cs new file mode 100644 index 00000000..e973dfc3 --- /dev/null +++ b/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddlewareExtensions.cs @@ -0,0 +1,12 @@ +using Ocelot.Middleware.Pipeline; + +namespace Ocelot.WebSockets.Middleware +{ + public static class WebSocketsProxyMiddlewareExtensions + { + public static IOcelotPipelineBuilder UseWebSocketsProxyMiddleware(this IOcelotPipelineBuilder builder) + { + return builder.UseMiddleware(); + } + } +} diff --git a/test/Ocelot.AcceptanceTests/LoadBalancerTests.cs b/test/Ocelot.AcceptanceTests/LoadBalancerTests.cs index f60008bd..14c69f9a 100644 --- a/test/Ocelot.AcceptanceTests/LoadBalancerTests.cs +++ b/test/Ocelot.AcceptanceTests/LoadBalancerTests.cs @@ -1,8 +1,6 @@ using System; using System.Collections.Generic; using System.IO; -using System.Net; -using Consul; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; @@ -28,7 +26,7 @@ namespace Ocelot.AcceptanceTests } [Fact] - public void should_use_service_discovery_and_load_balance_request() + public void should_load_balance_request() { var downstreamServiceOneUrl = "http://localhost:50881"; var downstreamServiceTwoUrl = "http://localhost:50892"; @@ -74,18 +72,6 @@ namespace Ocelot.AcceptanceTests .BDDfy(); } - private void ThenOnlyOneServiceHasBeenCalled() - { - _counterOne.ShouldBe(10); - _counterTwo.ShouldBe(0); - } - - private void GivenIResetCounters() - { - _counterOne = 0; - _counterTwo = 0; - } - private void ThenBothServicesCalledRealisticAmountOfTimes(int bottom, int top) { _counterOne.ShouldBeInRange(bottom, top); @@ -121,7 +107,7 @@ namespace Ocelot.AcceptanceTests context.Response.StatusCode = statusCode; await context.Response.WriteAsync(response); } - catch (System.Exception exception) + catch (Exception exception) { await context.Response.WriteAsync(exception.StackTrace); } diff --git a/test/Ocelot.AcceptanceTests/Steps.cs b/test/Ocelot.AcceptanceTests/Steps.cs index 1ee6f9bc..cca5ac18 100644 --- a/test/Ocelot.AcceptanceTests/Steps.cs +++ b/test/Ocelot.AcceptanceTests/Steps.cs @@ -40,12 +40,49 @@ namespace Ocelot.AcceptanceTests public string RequestIdKey = "OcRequestId"; private readonly Random _random; private IWebHostBuilder _webHostBuilder; + private WebHostBuilder _ocelotBuilder; + private IWebHost _ocelotHost; public Steps() { _random = new Random(); } + public async Task StartFakeOcelotWithWebSockets() + { + _ocelotBuilder = new WebHostBuilder(); + _ocelotBuilder.ConfigureServices(s => + { + s.AddSingleton(_ocelotBuilder); + s.AddOcelot(); + }); + _ocelotBuilder.UseKestrel() + .UseUrls("http://localhost:5000") + .UseContentRoot(Directory.GetCurrentDirectory()) + .ConfigureAppConfiguration((hostingContext, config) => + { + config.SetBasePath(hostingContext.HostingEnvironment.ContentRootPath); + var env = hostingContext.HostingEnvironment; + config.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true, reloadOnChange: true); + config.AddJsonFile("configuration.json"); + config.AddEnvironmentVariables(); + }) + .ConfigureLogging((hostingContext, logging) => + { + logging.AddConfiguration(hostingContext.Configuration.GetSection("Logging")); + logging.AddConsole(); + }) + .Configure(app => + { + app.UseWebSockets(); + app.UseOcelot().Wait(); + }) + .UseIISIntegration(); + _ocelotHost = _ocelotBuilder.Build(); + await _ocelotHost.StartAsync(); + } + public void GivenThereIsAConfiguration(FileConfiguration fileConfiguration) { var configurationPath = TestConfiguration.ConfigurationPath; @@ -698,6 +735,7 @@ namespace Ocelot.AcceptanceTests { _ocelotClient?.Dispose(); _ocelotServer?.Dispose(); + _ocelotHost?.Dispose(); } public void ThenTheRequestIdIsReturned() diff --git a/test/Ocelot.AcceptanceTests/WebSocketTests.cs b/test/Ocelot.AcceptanceTests/WebSocketTests.cs new file mode 100644 index 00000000..8137649a --- /dev/null +++ b/test/Ocelot.AcceptanceTests/WebSocketTests.cs @@ -0,0 +1,487 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Consul; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Ocelot.Configuration.File; +using Shouldly; +using TestStack.BDDfy; +using Xunit; + +namespace Ocelot.AcceptanceTests +{ + public class WebSocketTests : IDisposable + { + private IWebHost _firstDownstreamHost; + private IWebHost _secondDownstreamHost; + private readonly List _secondRecieved; + private readonly List _firstRecieved; + private readonly List _serviceEntries; + private readonly Steps _steps; + private IWebHost _fakeConsulBuilder; + + public WebSocketTests() + { + _steps = new Steps(); + _firstRecieved = new List(); + _secondRecieved = new List(); + _serviceEntries = new List(); + } + + [Fact] + public async Task should_proxy_websocket_input_to_downstream_service() + { + var downstreamPort = 5001; + var downstreamHost = "localhost"; + + var config = new FileConfiguration + { + ReRoutes = new List + { + new FileReRoute + { + UpstreamPathTemplate = "/", + DownstreamPathTemplate = "/ws", + DownstreamScheme = "ws", + DownstreamHostAndPorts = new List + { + new FileHostAndPort + { + Host = downstreamHost, + Port = downstreamPort + } + } + } + } + }; + + this.Given(_ => _steps.GivenThereIsAConfiguration(config)) + .And(_ => _steps.StartFakeOcelotWithWebSockets()) + .And(_ => StartFakeDownstreamService($"http://{downstreamHost}:{downstreamPort}", "/ws")) + .When(_ => StartClient("ws://localhost:5000/")) + .Then(_ => _firstRecieved.Count.ShouldBe(10)) + .BDDfy(); + } + + [Fact] + public async Task should_proxy_websocket_input_to_downstream_service_and_use_load_balancer() + { + var downstreamPort = 5005; + var downstreamHost = "localhost"; + var secondDownstreamPort = 5006; + var secondDownstreamHost = "localhost"; + + var config = new FileConfiguration + { + ReRoutes = new List + { + new FileReRoute + { + UpstreamPathTemplate = "/", + DownstreamPathTemplate = "/ws", + DownstreamScheme = "ws", + DownstreamHostAndPorts = new List + { + new FileHostAndPort + { + Host = downstreamHost, + Port = downstreamPort + }, + new FileHostAndPort + { + Host = secondDownstreamHost, + Port = secondDownstreamPort + } + }, + LoadBalancer = "RoundRobin" + } + } + }; + + this.Given(_ => _steps.GivenThereIsAConfiguration(config)) + .And(_ => _steps.StartFakeOcelotWithWebSockets()) + .And(_ => StartFakeDownstreamService($"http://{downstreamHost}:{downstreamPort}", "/ws")) + .And(_ => StartSecondFakeDownstreamService($"http://{secondDownstreamHost}:{secondDownstreamPort}","/ws")) + .When(_ => WhenIStartTheClients()) + .Then(_ => ThenBothDownstreamServicesAreCalled()) + .BDDfy(); + } + + [Fact] + public async Task should_proxy_websocket_input_to_downstream_service_and_use_service_discovery_and_load_balancer() + { + var downstreamPort = 5007; + var downstreamHost = "localhost"; + + var secondDownstreamPort = 5008; + var secondDownstreamHost = "localhost"; + + var serviceName = "websockets"; + var consulPort = 8509; + var fakeConsulServiceDiscoveryUrl = $"http://localhost:{consulPort}"; + var serviceEntryOne = new ServiceEntry() + { + Service = new AgentService() + { + Service = serviceName, + Address = downstreamHost, + Port = downstreamPort, + ID = Guid.NewGuid().ToString(), + Tags = new string[0] + }, + }; + var serviceEntryTwo = new ServiceEntry() + { + Service = new AgentService() + { + Service = serviceName, + Address = secondDownstreamHost, + Port = secondDownstreamPort, + ID = Guid.NewGuid().ToString(), + Tags = new string[0] + }, + }; + + var config = new FileConfiguration + { + ReRoutes = new List + { + new FileReRoute + { + UpstreamPathTemplate = "/", + DownstreamPathTemplate = "/ws", + DownstreamScheme = "ws", + LoadBalancer = "RoundRobin", + ServiceName = serviceName, + UseServiceDiscovery = true + } + }, + GlobalConfiguration = new FileGlobalConfiguration + { + ServiceDiscoveryProvider = new FileServiceDiscoveryProvider + { + Host = "localhost", + Port = consulPort, + Type = "consul" + } + } + }; + + this.Given(_ => _steps.GivenThereIsAConfiguration(config)) + .And(_ => _steps.StartFakeOcelotWithWebSockets()) + .And(_ => GivenThereIsAFakeConsulServiceDiscoveryProvider(fakeConsulServiceDiscoveryUrl, serviceName)) + .And(_ => GivenTheServicesAreRegisteredWithConsul(serviceEntryOne, serviceEntryTwo)) + .And(_ => StartFakeDownstreamService($"http://{downstreamHost}:{downstreamPort}", "/ws")) + .And(_ => StartSecondFakeDownstreamService($"http://{secondDownstreamHost}:{secondDownstreamPort}", "/ws")) + .When(_ => WhenIStartTheClients()) + .Then(_ => ThenBothDownstreamServicesAreCalled()) + .BDDfy(); + } + + private void ThenBothDownstreamServicesAreCalled() + { + _firstRecieved.Count.ShouldBe(10); + _firstRecieved.ForEach(x => + { + x.ShouldBe("test"); + }); + + _secondRecieved.Count.ShouldBe(10); + _secondRecieved.ForEach(x => + { + x.ShouldBe("chocolate"); + }); + } + + private void GivenTheServicesAreRegisteredWithConsul(params ServiceEntry[] serviceEntries) + { + foreach (var serviceEntry in serviceEntries) + { + _serviceEntries.Add(serviceEntry); + } + } + + private void GivenThereIsAFakeConsulServiceDiscoveryProvider(string url, string serviceName) + { + _fakeConsulBuilder = new WebHostBuilder() + .UseUrls(url) + .UseKestrel() + .UseContentRoot(Directory.GetCurrentDirectory()) + .UseIISIntegration() + .UseUrls(url) + .Configure(app => + { + app.Run(async context => + { + if (context.Request.Path.Value == $"/v1/health/service/{serviceName}") + { + await context.Response.WriteJsonAsync(_serviceEntries); + } + }); + }) + .Build(); + + _fakeConsulBuilder.Start(); + } + + private async Task WhenIStartTheClients() + { + var firstClient = StartClient("ws://localhost:5000/"); + + var secondClient = StartSecondClient("ws://localhost:5000/"); + + await Task.WhenAll(firstClient, secondClient); + } + + private async Task StartClient(string url) + { + var client = new ClientWebSocket(); + + await client.ConnectAsync(new Uri(url), CancellationToken.None); + + var sending = Task.Run(async () => + { + string line = "test"; + for (int i = 0; i < 10; i++) + { + var bytes = Encoding.UTF8.GetBytes(line); + + await client.SendAsync(new ArraySegment(bytes), WebSocketMessageType.Text, true, + CancellationToken.None); + await Task.Delay(10); + } + + await client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + }); + + var receiving = Task.Run(async () => + { + var buffer = new byte[1024 * 4]; + + while (true) + { + var result = await client.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + + if (result.MessageType == WebSocketMessageType.Text) + { + _firstRecieved.Add(Encoding.UTF8.GetString(buffer, 0, result.Count)); + } + + else if (result.MessageType == WebSocketMessageType.Close) + { + await client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + break; + } + } + }); + + await Task.WhenAll(sending, receiving); + } + + private async Task StartSecondClient(string url) + { + await Task.Delay(500); + + var client = new ClientWebSocket(); + + await client.ConnectAsync(new Uri(url), CancellationToken.None); + + var sending = Task.Run(async () => + { + string line = "test"; + for (int i = 0; i < 10; i++) + { + var bytes = Encoding.UTF8.GetBytes(line); + + await client.SendAsync(new ArraySegment(bytes), WebSocketMessageType.Text, true, + CancellationToken.None); + await Task.Delay(10); + } + + await client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + }); + + var receiving = Task.Run(async () => + { + var buffer = new byte[1024 * 4]; + + while (true) + { + var result = await client.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + + if (result.MessageType == WebSocketMessageType.Text) + { + _secondRecieved.Add(Encoding.UTF8.GetString(buffer, 0, result.Count)); + } + + else if (result.MessageType == WebSocketMessageType.Close) + { + await client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + break; + } + } + }); + + await Task.WhenAll(sending, receiving); + } + + + private async Task StartFakeDownstreamService(string url, string path) + { + _firstDownstreamHost = new WebHostBuilder() + .ConfigureServices(s => { }).UseKestrel() + .UseUrls(url) + .UseContentRoot(Directory.GetCurrentDirectory()) + .ConfigureAppConfiguration((hostingContext, config) => + { + config.SetBasePath(hostingContext.HostingEnvironment.ContentRootPath); + var env = hostingContext.HostingEnvironment; + config.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true, reloadOnChange: true); + config.AddEnvironmentVariables(); + }) + .ConfigureLogging((hostingContext, logging) => + { + logging.AddConfiguration(hostingContext.Configuration.GetSection("Logging")); + logging.AddConsole(); + }) + .Configure(app => + { + app.UseWebSockets(); + app.Use(async (context, next) => + { + if (context.Request.Path == path) + { + if (context.WebSockets.IsWebSocketRequest) + { + WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync(); + await Echo(webSocket); + } + else + { + context.Response.StatusCode = 400; + } + } + else + { + await next(); + } + }); + }) + .UseIISIntegration().Build(); + await _firstDownstreamHost.StartAsync(); + } + + + private async Task StartSecondFakeDownstreamService(string url, string path) + { + _secondDownstreamHost = new WebHostBuilder() + .ConfigureServices(s => { }).UseKestrel() + .UseUrls(url) + .UseContentRoot(Directory.GetCurrentDirectory()) + .ConfigureAppConfiguration((hostingContext, config) => + { + config.SetBasePath(hostingContext.HostingEnvironment.ContentRootPath); + var env = hostingContext.HostingEnvironment; + config.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true, reloadOnChange: true); + config.AddEnvironmentVariables(); + }) + .ConfigureLogging((hostingContext, logging) => + { + logging.AddConfiguration(hostingContext.Configuration.GetSection("Logging")); + logging.AddConsole(); + }) + .Configure(app => + { + app.UseWebSockets(); + app.Use(async (context, next) => + { + if (context.Request.Path == path) + { + if (context.WebSockets.IsWebSocketRequest) + { + WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync(); + await Message(webSocket); + } + else + { + context.Response.StatusCode = 400; + } + } + else + { + await next(); + } + }); + }) + .UseIISIntegration().Build(); + await _secondDownstreamHost.StartAsync(); + } + + + private async Task Echo(WebSocket webSocket) + { + try + { + var buffer = new byte[1024 * 4]; + + var result = await webSocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + + while (!result.CloseStatus.HasValue) + { + await webSocket.SendAsync(new ArraySegment(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, CancellationToken.None); + + result = await webSocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + } + + await webSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None); + } + catch (Exception e) + { + Console.WriteLine(e); + } + } + + private async Task Message(WebSocket webSocket) + { + try + { + var buffer = new byte[1024 * 4]; + + var bytes = Encoding.UTF8.GetBytes("chocolate"); + + var result = await webSocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + + while (!result.CloseStatus.HasValue) + { + await webSocket.SendAsync(new ArraySegment(bytes), result.MessageType, result.EndOfMessage, CancellationToken.None); + + result = await webSocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + } + + await webSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None); + } + catch (Exception e) + { + Console.WriteLine(e); + } + } + + public void Dispose() + { + _steps.Dispose(); + _firstDownstreamHost?.Dispose(); + _secondDownstreamHost?.Dispose(); + _fakeConsulBuilder?.Dispose(); + } + } +} diff --git a/test/Ocelot.UnitTests/Cache/OutputCacheMiddlewareRealCacheTests.cs b/test/Ocelot.UnitTests/Cache/OutputCacheMiddlewareRealCacheTests.cs index 210b1cb5..c1bc9b9d 100644 --- a/test/Ocelot.UnitTests/Cache/OutputCacheMiddlewareRealCacheTests.cs +++ b/test/Ocelot.UnitTests/Cache/OutputCacheMiddlewareRealCacheTests.cs @@ -43,7 +43,7 @@ namespace Ocelot.UnitTests.Cache }); _cacheManager = new OcelotCacheManagerCache(cacheManagerOutputCache); _downstreamContext = new DownstreamContext(new DefaultHttpContext()); - _downstreamContext.DownstreamRequest = new HttpRequestMessage(HttpMethod.Get, "https://some.url/blah?abcd=123"); + _downstreamContext.DownstreamRequest = new Ocelot.Request.Middleware.DownstreamRequest(new HttpRequestMessage(HttpMethod.Get, "https://some.url/blah?abcd=123")); _next = context => Task.CompletedTask; _middleware = new OutputCacheMiddleware(_next, _loggerFactory.Object, _cacheManager, _regionCreator); } diff --git a/test/Ocelot.UnitTests/Cache/OutputCacheMiddlewareTests.cs b/test/Ocelot.UnitTests/Cache/OutputCacheMiddlewareTests.cs index 78911787..83a4744d 100644 --- a/test/Ocelot.UnitTests/Cache/OutputCacheMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/Cache/OutputCacheMiddlewareTests.cs @@ -40,7 +40,7 @@ namespace Ocelot.UnitTests.Cache _logger = new Mock(); _loggerFactory.Setup(x => x.CreateLogger()).Returns(_logger.Object); _next = context => Task.CompletedTask; - _downstreamContext.DownstreamRequest = new HttpRequestMessage(HttpMethod.Get, "https://some.url/blah?abcd=123"); + _downstreamContext.DownstreamRequest = new Ocelot.Request.Middleware.DownstreamRequest(new HttpRequestMessage(HttpMethod.Get, "https://some.url/blah?abcd=123")); } [Fact] diff --git a/test/Ocelot.UnitTests/DownstreamUrlCreator/DownstreamUrlCreatorMiddlewareTests.cs b/test/Ocelot.UnitTests/DownstreamUrlCreator/DownstreamUrlCreatorMiddlewareTests.cs index c226f245..57177126 100644 --- a/test/Ocelot.UnitTests/DownstreamUrlCreator/DownstreamUrlCreatorMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/DownstreamUrlCreator/DownstreamUrlCreatorMiddlewareTests.cs @@ -20,6 +20,7 @@ namespace Ocelot.UnitTests.DownstreamUrlCreator using Xunit; using Shouldly; using Microsoft.AspNetCore.Http; + using Ocelot.Request.Middleware; public class DownstreamUrlCreatorMiddlewareTests { @@ -30,6 +31,7 @@ namespace Ocelot.UnitTests.DownstreamUrlCreator private DownstreamUrlCreatorMiddleware _middleware; private DownstreamContext _downstreamContext; private OcelotRequestDelegate _next; + private HttpRequestMessage _request; public DownstreamUrlCreatorMiddlewareTests() { @@ -38,7 +40,8 @@ namespace Ocelot.UnitTests.DownstreamUrlCreator _logger = new Mock(); _loggerFactory.Setup(x => x.CreateLogger()).Returns(_logger.Object); _downstreamUrlTemplateVariableReplacer = new Mock(); - _downstreamContext.DownstreamRequest = new HttpRequestMessage(HttpMethod.Get, "https://my.url/abc/?q=123"); + _request = new HttpRequestMessage(HttpMethod.Get, "https://my.url/abc/?q=123"); + _downstreamContext.DownstreamRequest = new DownstreamRequest(_request); _next = context => Task.CompletedTask; } @@ -208,7 +211,9 @@ namespace Ocelot.UnitTests.DownstreamUrlCreator private void GivenTheDownstreamRequestUriIs(string uri) { - _downstreamContext.DownstreamRequest.RequestUri = new Uri(uri); + _request.RequestUri = new Uri(uri); + //todo - not sure if needed + _downstreamContext.DownstreamRequest = new DownstreamRequest(_request); } private void GivenTheUrlReplacerWillReturn(string path) @@ -221,7 +226,7 @@ namespace Ocelot.UnitTests.DownstreamUrlCreator private void ThenTheDownstreamRequestUriIs(string expectedUri) { - _downstreamContext.DownstreamRequest.RequestUri.OriginalString.ShouldBe(expectedUri); + _downstreamContext.DownstreamRequest.ToHttpRequestMessage().RequestUri.OriginalString.ShouldBe(expectedUri); } } } diff --git a/test/Ocelot.UnitTests/Headers/AddHeadersToRequestTests.cs b/test/Ocelot.UnitTests/Headers/AddHeadersToRequestTests.cs index 0a8290bc..8e323bb2 100644 --- a/test/Ocelot.UnitTests/Headers/AddHeadersToRequestTests.cs +++ b/test/Ocelot.UnitTests/Headers/AddHeadersToRequestTests.cs @@ -11,6 +11,7 @@ using Shouldly; using TestStack.BDDfy; using Xunit; using System.Net.Http; +using Ocelot.Request.Middleware; namespace Ocelot.UnitTests.Headers { @@ -18,7 +19,7 @@ namespace Ocelot.UnitTests.Headers { private readonly AddHeadersToRequest _addHeadersToRequest; private readonly Mock _parser; - private readonly HttpRequestMessage _downstreamRequest; + private readonly DownstreamRequest _downstreamRequest; private List _claims; private List _configuration; private Response _result; @@ -28,7 +29,7 @@ namespace Ocelot.UnitTests.Headers { _parser = new Mock(); _addHeadersToRequest = new AddHeadersToRequest(_parser.Object); - _downstreamRequest = new HttpRequestMessage(); + _downstreamRequest = new DownstreamRequest(new HttpRequestMessage(HttpMethod.Get, "http://test.com")); } [Fact] diff --git a/test/Ocelot.UnitTests/Headers/HttpHeadersTransformationMiddlewareTests.cs b/test/Ocelot.UnitTests/Headers/HttpHeadersTransformationMiddlewareTests.cs index c49bf172..4992479b 100644 --- a/test/Ocelot.UnitTests/Headers/HttpHeadersTransformationMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/Headers/HttpHeadersTransformationMiddlewareTests.cs @@ -16,6 +16,7 @@ using Ocelot.Middleware; namespace Ocelot.UnitTests.Headers { using System.Threading.Tasks; + using Ocelot.Request.Middleware; public class HttpHeadersTransformationMiddlewareTests { @@ -68,7 +69,7 @@ namespace Ocelot.UnitTests.Headers private void GivenTheDownstreamRequestIs() { - _downstreamContext.DownstreamRequest = new HttpRequestMessage(); + _downstreamContext.DownstreamRequest = new DownstreamRequest(new HttpRequestMessage(HttpMethod.Get, "http://test.com")); } private void GivenTheHttpResponseMessageIs() @@ -97,7 +98,7 @@ namespace Ocelot.UnitTests.Headers private void ThenTheIHttpResponseHeaderReplacerIsCalledCorrectly() { - _postReplacer.Verify(x => x.Replace(It.IsAny(), It.IsAny>(), It.IsAny()), Times.Once); + _postReplacer.Verify(x => x.Replace(It.IsAny(), It.IsAny>(), It.IsAny()), Times.Once); } private void GivenTheFollowingRequest() diff --git a/test/Ocelot.UnitTests/Headers/HttpRequestHeadersBuilderMiddlewareTests.cs b/test/Ocelot.UnitTests/Headers/HttpRequestHeadersBuilderMiddlewareTests.cs index 7f96b247..7a448118 100644 --- a/test/Ocelot.UnitTests/Headers/HttpRequestHeadersBuilderMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/Headers/HttpRequestHeadersBuilderMiddlewareTests.cs @@ -14,6 +14,7 @@ namespace Ocelot.UnitTests.Headers using Ocelot.Headers; using Ocelot.Headers.Middleware; using Ocelot.Logging; + using Ocelot.Request.Middleware; using Ocelot.Responses; using TestStack.BDDfy; using Xunit; @@ -37,7 +38,7 @@ namespace Ocelot.UnitTests.Headers _loggerFactory.Setup(x => x.CreateLogger()).Returns(_logger.Object); _next = context => Task.CompletedTask; _middleware = new HttpRequestHeadersBuilderMiddleware(_next, _loggerFactory.Object, _addHeaders.Object); - _downstreamContext.DownstreamRequest = new HttpRequestMessage(); + _downstreamContext.DownstreamRequest = new DownstreamRequest(new HttpRequestMessage(HttpMethod.Get, "http://test.com")); } [Fact] @@ -81,7 +82,7 @@ namespace Ocelot.UnitTests.Headers .Setup(x => x.SetHeadersOnDownstreamRequest( It.IsAny>(), It.IsAny>(), - It.IsAny())) + It.IsAny())) .Returns(new OkResponse()); } diff --git a/test/Ocelot.UnitTests/Headers/HttpResponseHeaderReplacerTests.cs b/test/Ocelot.UnitTests/Headers/HttpResponseHeaderReplacerTests.cs index ca97de4f..6eefdc4c 100644 --- a/test/Ocelot.UnitTests/Headers/HttpResponseHeaderReplacerTests.cs +++ b/test/Ocelot.UnitTests/Headers/HttpResponseHeaderReplacerTests.cs @@ -11,6 +11,7 @@ using Moq; using Ocelot.Infrastructure; using Ocelot.Middleware; using Ocelot.Infrastructure.RequestData; +using Ocelot.Request.Middleware; namespace Ocelot.UnitTests.Headers { @@ -21,7 +22,7 @@ namespace Ocelot.UnitTests.Headers private HttpResponseHeaderReplacer _replacer; private List _headerFindAndReplaces; private Response _result; - private HttpRequestMessage _request; + private DownstreamRequest _request; private Mock _finder; private Mock _repo; @@ -69,7 +70,7 @@ namespace Ocelot.UnitTests.Headers { var downstreamUrl = "http://downstream.com/"; - var request = new HttpRequestMessage(); + var request = new HttpRequestMessage(HttpMethod.Get, "http://test.com"); request.RequestUri = new System.Uri(downstreamUrl); var response = new HttpResponseMessage(); @@ -91,7 +92,7 @@ namespace Ocelot.UnitTests.Headers { var downstreamUrl = "http://downstream.com/"; - var request = new HttpRequestMessage(); + var request = new HttpRequestMessage(HttpMethod.Get, "http://test.com"); request.RequestUri = new System.Uri(downstreamUrl); var response = new HttpResponseMessage(); @@ -113,7 +114,7 @@ namespace Ocelot.UnitTests.Headers { var downstreamUrl = "http://downstream.com/test/product"; - var request = new HttpRequestMessage(); + var request = new HttpRequestMessage(HttpMethod.Get, "http://test.com"); request.RequestUri = new System.Uri(downstreamUrl); var response = new HttpResponseMessage(); @@ -135,7 +136,7 @@ namespace Ocelot.UnitTests.Headers { var downstreamUrl = "http://downstream.com/test/product"; - var request = new HttpRequestMessage(); + var request = new HttpRequestMessage(HttpMethod.Get, "http://test.com"); request.RequestUri = new System.Uri(downstreamUrl); var response = new HttpResponseMessage(); @@ -157,7 +158,7 @@ namespace Ocelot.UnitTests.Headers { var downstreamUrl = "http://downstream.com:123/test/product"; - var request = new HttpRequestMessage(); + var request = new HttpRequestMessage(HttpMethod.Get, "http://test.com"); request.RequestUri = new System.Uri(downstreamUrl); var response = new HttpResponseMessage(); @@ -179,7 +180,7 @@ namespace Ocelot.UnitTests.Headers { var downstreamUrl = "http://downstream.com:123/test/product"; - var request = new HttpRequestMessage(); + var request = new HttpRequestMessage(HttpMethod.Get, "http://test.com"); request.RequestUri = new System.Uri(downstreamUrl); var response = new HttpResponseMessage(); @@ -198,7 +199,7 @@ namespace Ocelot.UnitTests.Headers private void GivenTheRequestIs(HttpRequestMessage request) { - _request = request; + _request = new DownstreamRequest(request); } private void ThenTheHeadersAreNotReplaced() diff --git a/test/Ocelot.UnitTests/Infrastructure/PlaceholdersTests.cs b/test/Ocelot.UnitTests/Infrastructure/PlaceholdersTests.cs index 6acd3655..51cbffef 100644 --- a/test/Ocelot.UnitTests/Infrastructure/PlaceholdersTests.cs +++ b/test/Ocelot.UnitTests/Infrastructure/PlaceholdersTests.cs @@ -4,6 +4,7 @@ using Moq; using Ocelot.Infrastructure; using Ocelot.Infrastructure.RequestData; using Ocelot.Middleware; +using Ocelot.Request.Middleware; using Ocelot.Responses; using Shouldly; using Xunit; @@ -43,8 +44,9 @@ namespace Ocelot.UnitTests.Infrastructure [Fact] public void should_return_downstream_base_url_when_port_is_not_80_or_443() { - var request = new HttpRequestMessage(); - request.RequestUri = new Uri("http://www.bbc.co.uk"); + var httpRequest = new HttpRequestMessage(); + httpRequest.RequestUri = new Uri("http://www.bbc.co.uk"); + var request = new DownstreamRequest(httpRequest); var result = _placeholders.Get("{DownstreamBaseUrl}", request); result.Data.ShouldBe("http://www.bbc.co.uk/"); } @@ -53,8 +55,9 @@ namespace Ocelot.UnitTests.Infrastructure [Fact] public void should_return_downstream_base_url_when_port_is_80_or_443() { - var request = new HttpRequestMessage(); - request.RequestUri = new Uri("http://www.bbc.co.uk:123"); + var httpRequest = new HttpRequestMessage(); + httpRequest.RequestUri = new Uri("http://www.bbc.co.uk:123"); + var request = new DownstreamRequest(httpRequest); var result = _placeholders.Get("{DownstreamBaseUrl}", request); result.Data.ShouldBe("http://www.bbc.co.uk:123/"); } @@ -62,7 +65,8 @@ namespace Ocelot.UnitTests.Infrastructure [Fact] public void should_return_key_does_not_exist_for_http_request_message() { - var result = _placeholders.Get("{Test}", new System.Net.Http.HttpRequestMessage()); + var request = new DownstreamRequest(new HttpRequestMessage(HttpMethod.Get, "http://west.com")); + var result = _placeholders.Get("{Test}", request); result.IsError.ShouldBeTrue(); result.Errors[0].Message.ShouldBe("Unable to find placeholder called {Test}"); } diff --git a/test/Ocelot.UnitTests/Infrastructure/StringExtensionsTests.cs b/test/Ocelot.UnitTests/Infrastructure/StringExtensionsTests.cs new file mode 100644 index 00000000..88632675 --- /dev/null +++ b/test/Ocelot.UnitTests/Infrastructure/StringExtensionsTests.cs @@ -0,0 +1,29 @@ +using Xunit; +using Ocelot.Infrastructure.Extensions; +using Shouldly; + +namespace Ocelot.UnitTests.Infrastructure +{ + public class StringExtensionsTests + { + [Fact] + public void should_trim_start() + { + var test = "/string"; + + test = test.TrimStart("/"); + + test.ShouldBe("string"); + } + + [Fact] + public void should_return_source() + { + var test = "string"; + + test = test.LastCharAsForwardSlash(); + + test.ShouldBe("string/"); + } + } +} diff --git a/test/Ocelot.UnitTests/LoadBalancer/LoadBalancerFactoryTests.cs b/test/Ocelot.UnitTests/LoadBalancer/LoadBalancerFactoryTests.cs index 5d7fa8b6..da439976 100644 --- a/test/Ocelot.UnitTests/LoadBalancer/LoadBalancerFactoryTests.cs +++ b/test/Ocelot.UnitTests/LoadBalancer/LoadBalancerFactoryTests.cs @@ -5,6 +5,7 @@ using Ocelot.LoadBalancer.LoadBalancers; using Ocelot.ServiceDiscovery; using Shouldly; using System.Collections.Generic; +using Ocelot.ServiceDiscovery.Providers; using TestStack.BDDfy; using Xunit; diff --git a/test/Ocelot.UnitTests/LoadBalancer/LoadBalancerMiddlewareTests.cs b/test/Ocelot.UnitTests/LoadBalancer/LoadBalancerMiddlewareTests.cs index e03fe6e1..a51b5cff 100644 --- a/test/Ocelot.UnitTests/LoadBalancer/LoadBalancerMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/LoadBalancer/LoadBalancerMiddlewareTests.cs @@ -13,6 +13,7 @@ namespace Ocelot.UnitTests.LoadBalancer using Ocelot.LoadBalancer.LoadBalancers; using Ocelot.LoadBalancer.Middleware; using Ocelot.Logging; + using Ocelot.Request.Middleware; using Ocelot.Responses; using Ocelot.Values; using Shouldly; @@ -39,13 +40,13 @@ namespace Ocelot.UnitTests.LoadBalancer _loadBalancerHouse = new Mock(); _loadBalancer = new Mock(); _loadBalancerHouse = new Mock(); - _downstreamRequest = new HttpRequestMessage(HttpMethod.Get, ""); + _downstreamRequest = new HttpRequestMessage(HttpMethod.Get, "http://test.com/"); _downstreamContext = new DownstreamContext(new DefaultHttpContext()); _loggerFactory = new Mock(); _logger = new Mock(); _loggerFactory.Setup(x => x.CreateLogger()).Returns(_logger.Object); _next = context => Task.CompletedTask; - _downstreamContext.DownstreamRequest = _downstreamRequest; + _downstreamContext.DownstreamRequest = new DownstreamRequest(_downstreamRequest); } [Fact] @@ -122,6 +123,7 @@ namespace Ocelot.UnitTests.LoadBalancer private void GivenTheDownStreamUrlIs(string downstreamUrl) { _downstreamRequest.RequestUri = new System.Uri(downstreamUrl); + _downstreamContext.DownstreamRequest = new DownstreamRequest(_downstreamRequest); } private void GivenTheLoadBalancerReturnsAnError() @@ -185,7 +187,7 @@ namespace Ocelot.UnitTests.LoadBalancer private void ThenTheDownstreamUrlIsReplacedWith(string expectedUri) { - _downstreamContext.DownstreamRequest.RequestUri.OriginalString.ShouldBe(expectedUri); + _downstreamContext.DownstreamRequest.ToHttpRequestMessage().RequestUri.OriginalString.ShouldBe(expectedUri); } } } diff --git a/test/Ocelot.UnitTests/Middleware/SimpleJsonResponseAggregatorTests.cs b/test/Ocelot.UnitTests/Middleware/SimpleJsonResponseAggregatorTests.cs index 27d1b7e0..f0a2b224 100644 --- a/test/Ocelot.UnitTests/Middleware/SimpleJsonResponseAggregatorTests.cs +++ b/test/Ocelot.UnitTests/Middleware/SimpleJsonResponseAggregatorTests.cs @@ -9,6 +9,7 @@ using Ocelot.Configuration.Builder; using Ocelot.Errors; using Ocelot.Middleware; using Ocelot.Middleware.Multiplexer; +using Ocelot.Request.Middleware; using Ocelot.UnitTests.Responder; using Shouldly; using TestStack.BDDfy; @@ -48,7 +49,7 @@ namespace Ocelot.UnitTests.Middleware new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent("Bill says hi") }, DownstreamReRoute = billDownstreamReRoute, Errors = new List { new AnyError() }, - DownstreamRequest = new HttpRequestMessage(HttpMethod.Get, new Uri("http://www.bbc.co.uk")), + DownstreamRequest = new DownstreamRequest(new HttpRequestMessage(HttpMethod.Get, new Uri("http://www.bbc.co.uk"))), }; var downstreamContexts = new List { billDownstreamContext }; diff --git a/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs b/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs index bebe4307..83c486f2 100644 --- a/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs +++ b/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs @@ -12,24 +12,27 @@ using TestStack.BDDfy; using Xunit; using System.Net.Http; using System; +using Ocelot.Request.Middleware; namespace Ocelot.UnitTests.QueryStrings { public class AddQueriesToRequestTests { private readonly AddQueriesToRequest _addQueriesToRequest; - private HttpRequestMessage _downstreamRequest; + private DownstreamRequest _downstreamRequest; private readonly Mock _parser; private List _configuration; private List _claims; private Response _result; private Response _claimValue; + private HttpRequestMessage _request; public AddQueriesToRequestTests() { + _request = new HttpRequestMessage(HttpMethod.Post, "http://my.url/abc?q=123"); _parser = new Mock(); _addQueriesToRequest = new AddQueriesToRequest(_parser.Object); - _downstreamRequest = new HttpRequestMessage(HttpMethod.Post, "http://my.url/abc?q=123"); + _downstreamRequest = new DownstreamRequest(_request); } [Fact] @@ -78,7 +81,7 @@ namespace Ocelot.UnitTests.QueryStrings private void TheTheQueryStringIs(string expected) { - _downstreamRequest.RequestUri.Query.ShouldBe(expected); + _downstreamRequest.Query.ShouldBe(expected); } [Fact] @@ -123,7 +126,7 @@ namespace Ocelot.UnitTests.QueryStrings private void ThenTheQueryIsAdded() { - var queries = Microsoft.AspNetCore.WebUtilities.QueryHelpers.ParseQuery(_downstreamRequest.RequestUri.OriginalString); + var queries = Microsoft.AspNetCore.WebUtilities.QueryHelpers.ParseQuery(_downstreamRequest.ToHttpRequestMessage().RequestUri.OriginalString); var query = queries.First(x => x.Key == "query-key"); query.Value.First().ShouldBe(_claimValue.Data); } @@ -140,15 +143,18 @@ namespace Ocelot.UnitTests.QueryStrings private void GivenTheDownstreamRequestHasQueryString(string queryString) { - _downstreamRequest = new HttpRequestMessage(HttpMethod.Post, $"http://my.url/abc{queryString}"); + _request = new HttpRequestMessage(HttpMethod.Post, $"http://my.url/abc{queryString}"); + _downstreamRequest = new DownstreamRequest(_request); } private void GivenTheDownstreamRequestHasQueryString(string key, string value) { var newUri = Microsoft.AspNetCore.WebUtilities.QueryHelpers - .AddQueryString(_downstreamRequest.RequestUri.OriginalString, key, value); + .AddQueryString(_downstreamRequest.ToHttpRequestMessage().RequestUri.OriginalString, key, value); - _downstreamRequest.RequestUri = new Uri(newUri); + _request.RequestUri = new Uri(newUri); + //todo - might not need to instanciate + _downstreamRequest = new DownstreamRequest(_request); } private void GivenTheClaimParserReturns(Response claimValue) diff --git a/test/Ocelot.UnitTests/QueryStrings/QueryStringBuilderMiddlewareTests.cs b/test/Ocelot.UnitTests/QueryStrings/QueryStringBuilderMiddlewareTests.cs index 163a0411..e4d95028 100644 --- a/test/Ocelot.UnitTests/QueryStrings/QueryStringBuilderMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/QueryStrings/QueryStringBuilderMiddlewareTests.cs @@ -18,6 +18,7 @@ namespace Ocelot.UnitTests.QueryStrings using System.Security.Claims; using Microsoft.AspNetCore.Http; using System.Threading.Tasks; + using Ocelot.Request.Middleware; public class QueryStringBuilderMiddlewareTests { @@ -36,7 +37,7 @@ namespace Ocelot.UnitTests.QueryStrings _loggerFactory.Setup(x => x.CreateLogger()).Returns(_logger.Object); _next = context => Task.CompletedTask; _addQueries = new Mock(); - _downstreamContext.DownstreamRequest = new HttpRequestMessage(); + _downstreamContext.DownstreamRequest = new DownstreamRequest(new HttpRequestMessage(HttpMethod.Get, "http://test.com")); _middleware = new QueryStringBuilderMiddleware(_next, _loggerFactory.Object, _addQueries.Object); } @@ -74,7 +75,7 @@ namespace Ocelot.UnitTests.QueryStrings .Setup(x => x.SetQueriesOnDownstreamRequest( It.IsAny>(), It.IsAny>(), - It.IsAny())) + It.IsAny())) .Returns(new OkResponse()); } diff --git a/test/Ocelot.UnitTests/RateLimit/ClientRateLimitMiddlewareTests.cs b/test/Ocelot.UnitTests/RateLimit/ClientRateLimitMiddlewareTests.cs index 819e38dd..7d3cac1a 100644 --- a/test/Ocelot.UnitTests/RateLimit/ClientRateLimitMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/RateLimit/ClientRateLimitMiddlewareTests.cs @@ -18,6 +18,7 @@ namespace Ocelot.UnitTests.RateLimit using Microsoft.Extensions.Caching.Memory; using System.IO; using System.Threading.Tasks; + using Ocelot.Request.Middleware; public class ClientRateLimitMiddlewareTests { @@ -100,7 +101,7 @@ namespace Ocelot.UnitTests.RateLimit { var request = new HttpRequestMessage(new HttpMethod("GET"), _url); request.Headers.Add("ClientId", clientId); - _downstreamContext.DownstreamRequest = request; + _downstreamContext.DownstreamRequest = new DownstreamRequest(request); _middleware.Invoke(_downstreamContext).GetAwaiter().GetResult(); _responseStatusCode = (int)_downstreamContext.HttpContext.Response.StatusCode; @@ -115,7 +116,7 @@ namespace Ocelot.UnitTests.RateLimit { var request = new HttpRequestMessage(new HttpMethod("GET"), _url); request.Headers.Add("ClientId", clientId); - _downstreamContext.DownstreamRequest = request; + _downstreamContext.DownstreamRequest = new DownstreamRequest(request); _downstreamContext.HttpContext.Request.Headers.TryAdd("ClientId", clientId); _middleware.Invoke(_downstreamContext).GetAwaiter().GetResult(); diff --git a/test/Ocelot.UnitTests/Request/DownstreamRequestInitialiserMiddlewareTests.cs b/test/Ocelot.UnitTests/Request/DownstreamRequestInitialiserMiddlewareTests.cs index 2bd26a4b..d37efc42 100644 --- a/test/Ocelot.UnitTests/Request/DownstreamRequestInitialiserMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/Request/DownstreamRequestInitialiserMiddlewareTests.cs @@ -88,7 +88,7 @@ namespace Ocelot.UnitTests.Request private void GivenTheMapperWillReturnAMappedRequest() { - _mappedRequest = new OkResponse(new HttpRequestMessage()); + _mappedRequest = new OkResponse(new HttpRequestMessage(HttpMethod.Get, "http://www.bbc.co.uk")); _requestMapper .Setup(rm => rm.Map(It.IsAny())) diff --git a/test/Ocelot.UnitTests/RequestId/ReRouteRequestIdMiddlewareTests.cs b/test/Ocelot.UnitTests/RequestId/ReRouteRequestIdMiddlewareTests.cs index e27fc50f..91c1e15c 100644 --- a/test/Ocelot.UnitTests/RequestId/ReRouteRequestIdMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/RequestId/ReRouteRequestIdMiddlewareTests.cs @@ -20,6 +20,7 @@ namespace Ocelot.UnitTests.RequestId using Shouldly; using TestStack.BDDfy; using Xunit; + using Ocelot.Request.Middleware; public class ReRouteRequestIdMiddlewareTests { @@ -35,7 +36,7 @@ namespace Ocelot.UnitTests.RequestId public ReRouteRequestIdMiddlewareTests() { - _downstreamRequest = new HttpRequestMessage(); + _downstreamRequest = new HttpRequestMessage(HttpMethod.Get, "http://test.com"); _repo = new Mock(); _downstreamContext = new DownstreamContext(new DefaultHttpContext()); _loggerFactory = new Mock(); @@ -47,7 +48,7 @@ namespace Ocelot.UnitTests.RequestId return Task.CompletedTask; }; _middleware = new ReRouteRequestIdMiddleware(_next, _loggerFactory.Object, _repo.Object); - _downstreamContext.DownstreamRequest = _downstreamRequest; + _downstreamContext.DownstreamRequest = new DownstreamRequest(_downstreamRequest); } [Fact] diff --git a/test/Ocelot.UnitTests/Requester/HttpClientBuilderTests.cs b/test/Ocelot.UnitTests/Requester/HttpClientBuilderTests.cs index 7368fac8..c5895c1d 100644 --- a/test/Ocelot.UnitTests/Requester/HttpClientBuilderTests.cs +++ b/test/Ocelot.UnitTests/Requester/HttpClientBuilderTests.cs @@ -14,6 +14,7 @@ using Ocelot.Configuration; using Ocelot.Configuration.Builder; using Ocelot.Logging; using Ocelot.Middleware; +using Ocelot.Request.Middleware; using Ocelot.Requester; using Ocelot.Responses; using Shouldly; @@ -170,7 +171,7 @@ namespace Ocelot.UnitTests.Requester var context = new DownstreamContext(new DefaultHttpContext()) { DownstreamReRoute = downstream, - DownstreamRequest = new HttpRequestMessage() { RequestUri = new Uri("http://localhost:5003") }, + DownstreamRequest = new DownstreamRequest(new HttpRequestMessage() { RequestUri = new Uri("http://localhost:5003") }), }; _context = context; diff --git a/test/Ocelot.UnitTests/Requester/HttpClientHttpRequesterTest.cs b/test/Ocelot.UnitTests/Requester/HttpClientHttpRequesterTest.cs index bbf59692..c80ea391 100644 --- a/test/Ocelot.UnitTests/Requester/HttpClientHttpRequesterTest.cs +++ b/test/Ocelot.UnitTests/Requester/HttpClientHttpRequesterTest.cs @@ -12,13 +12,16 @@ using Ocelot.Middleware; using TestStack.BDDfy; using Xunit; using Shouldly; +using Ocelot.Request.Middleware; +using System.Threading.Tasks; +using System.Threading; namespace Ocelot.UnitTests.Requester { public class HttpClientHttpRequesterTest { private readonly Mock _cacheHandlers; - private Mock _house; + private Mock _factory; private Response _response; private readonly HttpClientHttpRequester _httpClientRequester; private DownstreamContext _request; @@ -27,8 +30,8 @@ namespace Ocelot.UnitTests.Requester public HttpClientHttpRequesterTest() { - _house = new Mock(); - _house.Setup(x => x.Get(It.IsAny())).Returns(new OkResponse>>(new List>())); + _factory = new Mock(); + _factory.Setup(x => x.Get(It.IsAny())).Returns(new OkResponse>>(new List>())); _logger = new Mock(); _loggerFactory = new Mock(); _loggerFactory @@ -38,7 +41,7 @@ namespace Ocelot.UnitTests.Requester _httpClientRequester = new HttpClientHttpRequester( _loggerFactory.Object, _cacheHandlers.Object, - _house.Object); + _factory.Object); } [Fact] @@ -50,10 +53,11 @@ namespace Ocelot.UnitTests.Requester var context = new DownstreamContext(new DefaultHttpContext()) { DownstreamReRoute = reRoute, - DownstreamRequest = new HttpRequestMessage() { RequestUri = new Uri("http://www.bbc.co.uk") }, + DownstreamRequest = new DownstreamRequest(new HttpRequestMessage() { RequestUri = new Uri("http://www.bbc.co.uk") }), }; this.Given(x=>x.GivenTheRequestIs(context)) + .And(x => GivenTheHouseReturnsOkHandler()) .When(x=>x.WhenIGetResponse()) .Then(x => x.ThenTheResponseIsCalledCorrectly()) .BDDfy(); @@ -68,7 +72,7 @@ namespace Ocelot.UnitTests.Requester var context = new DownstreamContext(new DefaultHttpContext()) { DownstreamReRoute = reRoute, - DownstreamRequest = new HttpRequestMessage() { RequestUri = new Uri("http://localhost:60080") }, + DownstreamRequest = new DownstreamRequest(new HttpRequestMessage() { RequestUri = new Uri("http://localhost:60080") }), }; this.Given(x => x.GivenTheRequestIs(context)) @@ -96,5 +100,23 @@ namespace Ocelot.UnitTests.Requester { _response.IsError.ShouldBeTrue(); } + + private void GivenTheHouseReturnsOkHandler() + { + var handlers = new List> + { + () => new OkDelegatingHandler() + }; + + _factory.Setup(x => x.Get(It.IsAny())).Returns(new OkResponse>>(handlers)); + } + + class OkDelegatingHandler : DelegatingHandler + { + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + return Task.FromResult(new HttpResponseMessage()); + } + } } } diff --git a/test/Ocelot.UnitTests/ServiceDiscovery/ConfigurationServiceProviderTests.cs b/test/Ocelot.UnitTests/ServiceDiscovery/ConfigurationServiceProviderTests.cs index 08b67820..60555435 100644 --- a/test/Ocelot.UnitTests/ServiceDiscovery/ConfigurationServiceProviderTests.cs +++ b/test/Ocelot.UnitTests/ServiceDiscovery/ConfigurationServiceProviderTests.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using Ocelot.ServiceDiscovery; +using Ocelot.ServiceDiscovery.Providers; using Ocelot.Values; using Shouldly; using TestStack.BDDfy; diff --git a/test/Ocelot.UnitTests/ServiceDiscovery/ConsulServiceDiscoveryProviderTests.cs b/test/Ocelot.UnitTests/ServiceDiscovery/ConsulServiceDiscoveryProviderTests.cs index a272d1b4..fc67ab8e 100644 --- a/test/Ocelot.UnitTests/ServiceDiscovery/ConsulServiceDiscoveryProviderTests.cs +++ b/test/Ocelot.UnitTests/ServiceDiscovery/ConsulServiceDiscoveryProviderTests.cs @@ -9,6 +9,8 @@ using Microsoft.AspNetCore.Http; using Moq; using Ocelot.Logging; using Ocelot.ServiceDiscovery; +using Ocelot.ServiceDiscovery.Configuration; +using Ocelot.ServiceDiscovery.Providers; using Ocelot.Values; using Xunit; using TestStack.BDDfy; diff --git a/test/Ocelot.UnitTests/ServiceDiscovery/ServiceFabricServiceDiscoveryProviderTests.cs b/test/Ocelot.UnitTests/ServiceDiscovery/ServiceFabricServiceDiscoveryProviderTests.cs index 39deb681..eabe72d7 100644 --- a/test/Ocelot.UnitTests/ServiceDiscovery/ServiceFabricServiceDiscoveryProviderTests.cs +++ b/test/Ocelot.UnitTests/ServiceDiscovery/ServiceFabricServiceDiscoveryProviderTests.cs @@ -1,4 +1,7 @@ -namespace Ocelot.UnitTests.ServiceDiscovery +using Ocelot.ServiceDiscovery.Configuration; +using Ocelot.ServiceDiscovery.Providers; + +namespace Ocelot.UnitTests.ServiceDiscovery { using System; using System.Collections.Generic; diff --git a/test/Ocelot.UnitTests/ServiceDiscovery/ServiceProviderFactoryTests.cs b/test/Ocelot.UnitTests/ServiceDiscovery/ServiceProviderFactoryTests.cs index 021b9efb..2eef9b78 100644 --- a/test/Ocelot.UnitTests/ServiceDiscovery/ServiceProviderFactoryTests.cs +++ b/test/Ocelot.UnitTests/ServiceDiscovery/ServiceProviderFactoryTests.cs @@ -5,6 +5,7 @@ using Ocelot.Configuration; using Ocelot.Configuration.Builder; using Ocelot.Logging; using Ocelot.ServiceDiscovery; +using Ocelot.ServiceDiscovery.Providers; using Shouldly; using TestStack.BDDfy; using Xunit; diff --git a/test/Ocelot.UnitTests/WebSockets/WebSocketsProxyMiddlewareTests.cs b/test/Ocelot.UnitTests/WebSockets/WebSocketsProxyMiddlewareTests.cs new file mode 100644 index 00000000..c3b087d4 --- /dev/null +++ b/test/Ocelot.UnitTests/WebSockets/WebSocketsProxyMiddlewareTests.cs @@ -0,0 +1,239 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Consul; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json; +using Ocelot.Configuration.File; +using Ocelot.DependencyInjection; +using Ocelot.Middleware; +using Shouldly; +using TestStack.BDDfy; +using Xunit; + +namespace Ocelot.UnitTests.Websockets +{ + public class WebSocketsProxyMiddlewareTests : IDisposable + { + private IWebHost _firstDownstreamHost; + private readonly List _firstRecieved; + private WebHostBuilder _ocelotBuilder; + private IWebHost _ocelotHost; + + public WebSocketsProxyMiddlewareTests() + { + _firstRecieved = new List(); + } + + [Fact] + public async Task should_proxy_websocket_input_to_downstream_service() + { + var downstreamPort = 5001; + var downstreamHost = "localhost"; + + var config = new FileConfiguration + { + ReRoutes = new List + { + new FileReRoute + { + UpstreamPathTemplate = "/", + DownstreamPathTemplate = "/ws", + DownstreamScheme = "ws", + DownstreamHostAndPorts = new List + { + new FileHostAndPort + { + Host = downstreamHost, + Port = downstreamPort + } + } + } + } + }; + + this.Given(_ => GivenThereIsAConfiguration(config)) + .And(_ => StartFakeOcelotWithWebSockets()) + .And(_ => StartFakeDownstreamService($"http://{downstreamHost}:{downstreamPort}", "/ws")) + .When(_ => StartClient("ws://localhost:5000/")) + .Then(_ => _firstRecieved.Count.ShouldBe(10)) + .BDDfy(); + } + + public void Dispose() + { + _firstDownstreamHost?.Dispose(); + } + + public async Task StartFakeOcelotWithWebSockets() + { + _ocelotBuilder = new WebHostBuilder(); + _ocelotBuilder.ConfigureServices(s => + { + s.AddSingleton(_ocelotBuilder); + s.AddOcelot(); + }); + _ocelotBuilder.UseKestrel() + .UseUrls("http://localhost:5000") + .UseContentRoot(Directory.GetCurrentDirectory()) + .ConfigureAppConfiguration((hostingContext, config) => + { + config.SetBasePath(hostingContext.HostingEnvironment.ContentRootPath); + var env = hostingContext.HostingEnvironment; + config.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true, reloadOnChange: true); + config.AddJsonFile("configuration.json"); + config.AddEnvironmentVariables(); + }) + .ConfigureLogging((hostingContext, logging) => + { + logging.AddConfiguration(hostingContext.Configuration.GetSection("Logging")); + logging.AddConsole(); + }) + .Configure(app => + { + app.UseWebSockets(); + app.UseOcelot().Wait(); + }) + .UseIISIntegration(); + _ocelotHost = _ocelotBuilder.Build(); + await _ocelotHost.StartAsync(); + } + + public void GivenThereIsAConfiguration(FileConfiguration fileConfiguration) + { + var configurationPath = Path.Combine(AppContext.BaseDirectory, "configuration.json"); + + var jsonConfiguration = JsonConvert.SerializeObject(fileConfiguration); + + if (File.Exists(configurationPath)) + { + File.Delete(configurationPath); + } + + File.WriteAllText(configurationPath, jsonConfiguration); + } + + private async Task StartFakeDownstreamService(string url, string path) + { + _firstDownstreamHost = new WebHostBuilder() + .ConfigureServices(s => { }).UseKestrel() + .UseUrls(url) + .UseContentRoot(Directory.GetCurrentDirectory()) + .ConfigureAppConfiguration((hostingContext, config) => + { + config.SetBasePath(hostingContext.HostingEnvironment.ContentRootPath); + var env = hostingContext.HostingEnvironment; + config.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true, reloadOnChange: true); + config.AddEnvironmentVariables(); + }) + .ConfigureLogging((hostingContext, logging) => + { + logging.AddConfiguration(hostingContext.Configuration.GetSection("Logging")); + logging.AddConsole(); + }) + .Configure(app => + { + app.UseWebSockets(); + app.Use(async (context, next) => + { + if (context.Request.Path == path) + { + if (context.WebSockets.IsWebSocketRequest) + { + WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync(); + await Echo(webSocket); + } + else + { + context.Response.StatusCode = 400; + } + } + else + { + await next(); + } + }); + }) + .UseIISIntegration().Build(); + await _firstDownstreamHost.StartAsync(); + } + + private async Task StartClient(string url) + { + var client = new ClientWebSocket(); + + await client.ConnectAsync(new Uri(url), CancellationToken.None); + + var sending = Task.Run(async () => + { + string line = "test"; + for (int i = 0; i < 10; i++) + { + var bytes = Encoding.UTF8.GetBytes(line); + + await client.SendAsync(new ArraySegment(bytes), WebSocketMessageType.Text, true, + CancellationToken.None); + await Task.Delay(10); + } + + await client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + }); + + var receiving = Task.Run(async () => + { + var buffer = new byte[1024 * 4]; + + while (true) + { + var result = await client.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + + if (result.MessageType == WebSocketMessageType.Text) + { + _firstRecieved.Add(Encoding.UTF8.GetString(buffer, 0, result.Count)); + } + + else if (result.MessageType == WebSocketMessageType.Close) + { + await client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + break; + } + } + }); + + await Task.WhenAll(sending, receiving); + } + + private async Task Echo(WebSocket webSocket) + { + try + { + var buffer = new byte[1024 * 4]; + + var result = await webSocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + + while (!result.CloseStatus.HasValue) + { + await webSocket.SendAsync(new ArraySegment(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, CancellationToken.None); + + result = await webSocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + } + + await webSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None); + } + catch (Exception e) + { + Console.WriteLine(e); + } + } + } +}