diff --git a/src/Ocelot/Authentication/Middleware/AuthenticationMiddleware.cs b/src/Ocelot/Authentication/Middleware/AuthenticationMiddleware.cs index a6eb9cc2..f29161ce 100644 --- a/src/Ocelot/Authentication/Middleware/AuthenticationMiddleware.cs +++ b/src/Ocelot/Authentication/Middleware/AuthenticationMiddleware.cs @@ -1,97 +1,97 @@ -using System.Collections.Generic; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Http; -using Ocelot.Authentication.Handler.Factory; -using Ocelot.Configuration; -using Ocelot.Errors; -using Ocelot.Infrastructure.Extensions; -using Ocelot.Infrastructure.RequestData; -using Ocelot.Logging; -using Ocelot.Middleware; - -namespace Ocelot.Authentication.Middleware -{ - public class AuthenticationMiddleware : OcelotMiddleware - { - private readonly RequestDelegate _next; - private readonly IApplicationBuilder _app; - private readonly IAuthenticationHandlerFactory _authHandlerFactory; - private readonly IOcelotLogger _logger; - - public AuthenticationMiddleware(RequestDelegate next, - IApplicationBuilder app, - IRequestScopedDataRepository requestScopedDataRepository, - IAuthenticationHandlerFactory authHandlerFactory, - IOcelotLoggerFactory loggerFactory) - : base(requestScopedDataRepository) - { - _next = next; - _authHandlerFactory = authHandlerFactory; - _app = app; - _logger = loggerFactory.CreateLogger(); - } - - public async Task Invoke(HttpContext context) - { - _logger.TraceMiddlewareEntry(); - - if (IsAuthenticatedRoute(DownstreamRoute.ReRoute)) - { - _logger.LogDebug($"{context.Request.Path} is an authenticated route. {MiddlwareName} checking if client is authenticated"); - - var authenticationHandler = _authHandlerFactory.Get(_app, DownstreamRoute.ReRoute.AuthenticationOptions); - - if (authenticationHandler.IsError) +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Ocelot.Authentication.Handler.Factory; +using Ocelot.Configuration; +using Ocelot.Errors; +using Ocelot.Infrastructure.Extensions; +using Ocelot.Infrastructure.RequestData; +using Ocelot.Logging; +using Ocelot.Middleware; + +namespace Ocelot.Authentication.Middleware +{ + public class AuthenticationMiddleware : OcelotMiddleware + { + private readonly RequestDelegate _next; + private readonly IApplicationBuilder _app; + private readonly IAuthenticationHandlerFactory _authHandlerFactory; + private readonly IOcelotLogger _logger; + + public AuthenticationMiddleware(RequestDelegate next, + IApplicationBuilder app, + IRequestScopedDataRepository requestScopedDataRepository, + IAuthenticationHandlerFactory authHandlerFactory, + IOcelotLoggerFactory loggerFactory) + : base(requestScopedDataRepository) + { + _next = next; + _authHandlerFactory = authHandlerFactory; + _app = app; + _logger = loggerFactory.CreateLogger(); + } + + public async Task Invoke(HttpContext context) + { + _logger.TraceMiddlewareEntry(); + + if (IsAuthenticatedRoute(DownstreamRoute.ReRoute)) + { + _logger.LogDebug($"{context.Request.Path} is an authenticated route. {MiddlwareName} checking if client is authenticated"); + + var authenticationHandler = _authHandlerFactory.Get(_app, DownstreamRoute.ReRoute.AuthenticationOptions); + + if (authenticationHandler.IsError) { - _logger.LogError($"Error getting authentication handler for {context.Request.Path}. {authenticationHandler.Errors.ToErrorString()}"); - SetPipelineError(authenticationHandler.Errors); - _logger.TraceMiddlewareCompleted(); - return; + _logger.LogError($"Error getting authentication handler for {context.Request.Path}. {authenticationHandler.Errors.ToErrorString()}"); + SetPipelineError(authenticationHandler.Errors); + _logger.TraceMiddlewareCompleted(); + return; } - await authenticationHandler.Data.Handler.Handle(context); - - - if (context.User.Identity.IsAuthenticated) - { - _logger.LogDebug($"Client has been authenticated for {context.Request.Path}"); - - _logger.TraceInvokeNext(); - await _next.Invoke(context); - _logger.TraceInvokeNextCompleted(); - _logger.TraceMiddlewareCompleted(); - } - else - { - var error = new List - { - new UnauthenticatedError( - $"Request for authenticated route {context.Request.Path} by {context.User.Identity.Name} was unauthenticated") - }; - - _logger.LogError($"Client has NOT been authenticated for {context.Request.Path} and pipeline error set. {error.ToErrorString()}"); - SetPipelineError(error); - - _logger.TraceMiddlewareCompleted(); - return; - } - } - else + await authenticationHandler.Data.Handler.Handle(context); + + + if (context.User.Identity.IsAuthenticated) + { + _logger.LogDebug($"Client has been authenticated for {context.Request.Path}"); + + _logger.TraceInvokeNext(); + await _next.Invoke(context); + _logger.TraceInvokeNextCompleted(); + _logger.TraceMiddlewareCompleted(); + } + else + { + var error = new List + { + new UnauthenticatedError( + $"Request for authenticated route {context.Request.Path} by {context.User.Identity.Name} was unauthenticated") + }; + + _logger.LogError($"Client has NOT been authenticated for {context.Request.Path} and pipeline error set. {error.ToErrorString()}"); + SetPipelineError(error); + + _logger.TraceMiddlewareCompleted(); + return; + } + } + else { _logger.LogTrace($"No authentication needed for {context.Request.Path}"); _logger.TraceInvokeNext(); await _next.Invoke(context); _logger.TraceInvokeNextCompleted(); - _logger.TraceMiddlewareCompleted(); - } - } - - private static bool IsAuthenticatedRoute(ReRoute reRoute) - { - return reRoute.IsAuthenticated; - } - } -} - + _logger.TraceMiddlewareCompleted(); + } + } + + private static bool IsAuthenticatedRoute(ReRoute reRoute) + { + return reRoute.IsAuthenticated; + } + } +} + diff --git a/src/Ocelot/Cache/Middleware/OutputCacheMiddleware.cs b/src/Ocelot/Cache/Middleware/OutputCacheMiddleware.cs index 948a7397..85a73e70 100644 --- a/src/Ocelot/Cache/Middleware/OutputCacheMiddleware.cs +++ b/src/Ocelot/Cache/Middleware/OutputCacheMiddleware.cs @@ -27,14 +27,14 @@ namespace Ocelot.Cache.Middleware public async Task Invoke(HttpContext context) { - var downstreamUrlKey = DownstreamUrl; - if (!DownstreamRoute.ReRoute.IsCached) { await _next.Invoke(context); return; } + var downstreamUrlKey = DownstreamRequest.RequestUri.OriginalString; + _logger.LogDebug("started checking cache for {downstreamUrlKey}", downstreamUrlKey); var cached = _outputCache.Get(downstreamUrlKey); diff --git a/src/Ocelot/DependencyInjection/ServiceCollectionExtensions.cs b/src/Ocelot/DependencyInjection/ServiceCollectionExtensions.cs index e975dcb6..ef022923 100644 --- a/src/Ocelot/DependencyInjection/ServiceCollectionExtensions.cs +++ b/src/Ocelot/DependencyInjection/ServiceCollectionExtensions.cs @@ -31,6 +31,7 @@ using Ocelot.Middleware; using Ocelot.QueryStrings; using Ocelot.RateLimit; using Ocelot.Request.Builder; +using Ocelot.Request.Mapper; using Ocelot.Requester; using Ocelot.Requester.QoS; using Ocelot.Responder; @@ -160,6 +161,7 @@ namespace Ocelot.DependencyInjection services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddSingleton(); + services.TryAddSingleton(); // see this for why we register this as singleton http://stackoverflow.com/questions/37371264/invalidoperationexception-unable-to-resolve-service-for-type-microsoft-aspnetc // could maybe use a scoped data repository diff --git a/src/Ocelot/DownstreamUrlCreator/Middleware/DownstreamUrlCreatorMiddleware.cs b/src/Ocelot/DownstreamUrlCreator/Middleware/DownstreamUrlCreatorMiddleware.cs index 631e278a..c91e06da 100644 --- a/src/Ocelot/DownstreamUrlCreator/Middleware/DownstreamUrlCreatorMiddleware.cs +++ b/src/Ocelot/DownstreamUrlCreator/Middleware/DownstreamUrlCreatorMiddleware.cs @@ -4,6 +4,7 @@ using Ocelot.DownstreamUrlCreator.UrlTemplateReplacer; using Ocelot.Infrastructure.RequestData; using Ocelot.Logging; using Ocelot.Middleware; +using System; namespace Ocelot.DownstreamUrlCreator.Middleware { @@ -42,23 +43,15 @@ namespace Ocelot.DownstreamUrlCreator.Middleware return; } - var dsScheme = DownstreamRoute.ReRoute.DownstreamScheme; - - var dsHostAndPort = HostAndPort; - - var dsUrl = _urlBuilder.Build(dsPath.Data.Value, dsScheme, dsHostAndPort); - - if (dsUrl.IsError) + var uriBuilder = new UriBuilder(DownstreamRequest.RequestUri) { - _logger.LogDebug("IUrlBuilder returned an error, setting pipeline error"); + Path = dsPath.Data.Value, + Scheme = DownstreamRoute.ReRoute.DownstreamScheme + }; - SetPipelineError(dsUrl.Errors); - return; - } + DownstreamRequest.RequestUri = uriBuilder.Uri; - _logger.LogDebug("downstream url is {downstreamUrl.Data.Value}", dsUrl.Data.Value); - - SetDownstreamUrlForThisRequest(dsUrl.Data.Value); + _logger.LogDebug("downstream url is {downstreamUrl.Data.Value}", DownstreamRequest.RequestUri); _logger.LogDebug("calling next middleware"); diff --git a/src/Ocelot/DownstreamUrlCreator/UrlBuilder.cs b/src/Ocelot/DownstreamUrlCreator/UrlBuilder.cs index 43a63715..eed7b8d7 100644 --- a/src/Ocelot/DownstreamUrlCreator/UrlBuilder.cs +++ b/src/Ocelot/DownstreamUrlCreator/UrlBuilder.cs @@ -25,6 +25,7 @@ namespace Ocelot.DownstreamUrlCreator return new ErrorResponse(new List { new DownstreamHostNullOrEmptyError() }); } + var builder = new UriBuilder { Host = downstreamHostAndPort.DownstreamHost, diff --git a/src/Ocelot/Errors/OcelotErrorCode.cs b/src/Ocelot/Errors/OcelotErrorCode.cs index c1c55dbb..de7960c4 100644 --- a/src/Ocelot/Errors/OcelotErrorCode.cs +++ b/src/Ocelot/Errors/OcelotErrorCode.cs @@ -28,6 +28,7 @@ UnableToFindLoadBalancerError, RequestTimedOutError, UnableToFindQoSProviderError, - UnableToSetConfigInConsulError + UnableToSetConfigInConsulError, + UnmappableRequestError } } diff --git a/src/Ocelot/Headers/AddHeadersToRequest.cs b/src/Ocelot/Headers/AddHeadersToRequest.cs index 97cd3e69..3c818bcc 100644 --- a/src/Ocelot/Headers/AddHeadersToRequest.cs +++ b/src/Ocelot/Headers/AddHeadersToRequest.cs @@ -1,10 +1,9 @@ using System.Collections.Generic; using System.Linq; -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Primitives; using Ocelot.Configuration; using Ocelot.Infrastructure.Claims.Parser; using Ocelot.Responses; +using System.Net.Http; namespace Ocelot.Headers { @@ -17,25 +16,25 @@ namespace Ocelot.Headers _claimsParser = claimsParser; } - public Response SetHeadersOnContext(List claimsToThings, HttpContext context) + public Response SetHeadersOnDownstreamRequest(List claimsToThings, IEnumerable claims, HttpRequestMessage downstreamRequest) { foreach (var config in claimsToThings) { - var value = _claimsParser.GetValue(context.User.Claims, config.NewKey, config.Delimiter, config.Index); + var value = _claimsParser.GetValue(claims, config.NewKey, config.Delimiter, config.Index); if (value.IsError) { return new ErrorResponse(value.Errors); } - var exists = context.Request.Headers.FirstOrDefault(x => x.Key == config.ExistingKey); + var exists = downstreamRequest.Headers.FirstOrDefault(x => x.Key == config.ExistingKey); if (!string.IsNullOrEmpty(exists.Key)) { - context.Request.Headers.Remove(exists); + downstreamRequest.Headers.Remove(exists.Key); } - context.Request.Headers.Add(config.ExistingKey, new StringValues(value.Data)); + downstreamRequest.Headers.Add(config.ExistingKey, value.Data); } return new OkResponse(); diff --git a/src/Ocelot/Headers/IAddHeadersToRequest.cs b/src/Ocelot/Headers/IAddHeadersToRequest.cs index 3bf786a4..cc4e0383 100644 --- a/src/Ocelot/Headers/IAddHeadersToRequest.cs +++ b/src/Ocelot/Headers/IAddHeadersToRequest.cs @@ -1,13 +1,13 @@ -using System.Collections.Generic; -using Microsoft.AspNetCore.Http; -using Ocelot.Configuration; -using Ocelot.Responses; - -namespace Ocelot.Headers +namespace Ocelot.Headers { + using System.Collections.Generic; + using System.Net.Http; + + using Ocelot.Configuration; + using Ocelot.Responses; + public interface IAddHeadersToRequest { - Response SetHeadersOnContext(List claimsToThings, - HttpContext context); + Response SetHeadersOnDownstreamRequest(List claimsToThings, IEnumerable claims, HttpRequestMessage downstreamRequest); } } diff --git a/src/Ocelot/Headers/Middleware/HttpRequestHeadersBuilderMiddleware.cs b/src/Ocelot/Headers/Middleware/HttpRequestHeadersBuilderMiddleware.cs index a89d2ec2..cb0fdc98 100644 --- a/src/Ocelot/Headers/Middleware/HttpRequestHeadersBuilderMiddleware.cs +++ b/src/Ocelot/Headers/Middleware/HttpRequestHeadersBuilderMiddleware.cs @@ -32,7 +32,7 @@ namespace Ocelot.Headers.Middleware { _logger.LogDebug("this route has instructions to convert claims to headers"); - var response = _addHeadersToRequest.SetHeadersOnContext(DownstreamRoute.ReRoute.ClaimsToHeaders, context); + var response = _addHeadersToRequest.SetHeadersOnDownstreamRequest(DownstreamRoute.ReRoute.ClaimsToHeaders, context.User.Claims, DownstreamRequest); if (response.IsError) { diff --git a/src/Ocelot/LoadBalancer/Middleware/LoadBalancingMiddleware.cs b/src/Ocelot/LoadBalancer/Middleware/LoadBalancingMiddleware.cs index 8e26cbf2..be5a34d1 100644 --- a/src/Ocelot/LoadBalancer/Middleware/LoadBalancingMiddleware.cs +++ b/src/Ocelot/LoadBalancer/Middleware/LoadBalancingMiddleware.cs @@ -44,7 +44,13 @@ namespace Ocelot.LoadBalancer.Middleware return; } - SetHostAndPortForThisRequest(hostAndPort.Data); + var uriBuilder = new UriBuilder(DownstreamRequest.RequestUri); + uriBuilder.Host = hostAndPort.Data.DownstreamHost; + if (hostAndPort.Data.DownstreamPort > 0) + { + uriBuilder.Port = hostAndPort.Data.DownstreamPort; + } + DownstreamRequest.RequestUri = uriBuilder.Uri; _logger.LogDebug("calling next middleware"); diff --git a/src/Ocelot/Middleware/OcelotMiddleware.cs b/src/Ocelot/Middleware/OcelotMiddleware.cs index 2926e543..af56891e 100644 --- a/src/Ocelot/Middleware/OcelotMiddleware.cs +++ b/src/Ocelot/Middleware/OcelotMiddleware.cs @@ -3,7 +3,6 @@ using System.Net.Http; using Ocelot.DownstreamRouteFinder; using Ocelot.Errors; using Ocelot.Infrastructure.RequestData; -using Ocelot.Values; namespace Ocelot.Middleware { @@ -17,91 +16,35 @@ namespace Ocelot.Middleware MiddlwareName = this.GetType().Name; } - public string MiddlwareName { get; } + public string MiddlwareName { get; } - public bool PipelineError - { - get - { - var response = _requestScopedDataRepository.Get("OcelotMiddlewareError"); - return response.Data; - } - } + public bool PipelineError => _requestScopedDataRepository.Get("OcelotMiddlewareError").Data; - public List PipelineErrors - { - get - { - var response = _requestScopedDataRepository.Get>("OcelotMiddlewareErrors"); - return response.Data; - } - } + public List PipelineErrors => _requestScopedDataRepository.Get>("OcelotMiddlewareErrors").Data; - public DownstreamRoute DownstreamRoute - { - get - { - var downstreamRoute = _requestScopedDataRepository.Get("DownstreamRoute"); - return downstreamRoute.Data; - } - } + public DownstreamRoute DownstreamRoute => _requestScopedDataRepository.Get("DownstreamRoute").Data; - public string DownstreamUrl - { - get - { - var downstreamUrl = _requestScopedDataRepository.Get("DownstreamUrl"); - return downstreamUrl.Data; - } - } + public Request.Request Request => _requestScopedDataRepository.Get("Request").Data; - public Request.Request Request - { - get - { - var request = _requestScopedDataRepository.Get("Request"); - return request.Data; - } - } + public HttpRequestMessage DownstreamRequest => _requestScopedDataRepository.Get("DownstreamRequest").Data; - public HttpResponseMessage HttpResponseMessage - { - get - { - var request = _requestScopedDataRepository.Get("HttpResponseMessage"); - return request.Data; - } - } - - public HostAndPort HostAndPort - { - get - { - var hostAndPort = _requestScopedDataRepository.Get("HostAndPort"); - return hostAndPort.Data; - } - } - - public void SetHostAndPortForThisRequest(HostAndPort hostAndPort) - { - _requestScopedDataRepository.Add("HostAndPort", hostAndPort); - } + public HttpResponseMessage HttpResponseMessage => _requestScopedDataRepository.Get("HttpResponseMessage").Data; public void SetDownstreamRouteForThisRequest(DownstreamRoute downstreamRoute) { _requestScopedDataRepository.Add("DownstreamRoute", downstreamRoute); } - public void SetDownstreamUrlForThisRequest(string downstreamUrl) - { - _requestScopedDataRepository.Add("DownstreamUrl", downstreamUrl); - } - public void SetUpstreamRequestForThisRequest(Request.Request request) { _requestScopedDataRepository.Add("Request", request); } + public void SetDownstreamRequest(HttpRequestMessage request) + { + _requestScopedDataRepository.Add("DownstreamRequest", request); + } + public void SetHttpResponseMessageThisRequest(HttpResponseMessage responseMessage) { _requestScopedDataRepository.Add("HttpResponseMessage", responseMessage); diff --git a/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs b/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs index 457c2448..3f98f959 100644 --- a/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs +++ b/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs @@ -62,6 +62,9 @@ namespace Ocelot.Middleware // This is registered first so it can catch any errors and issue an appropriate response builder.UseResponderMiddleware(); + // Initialises downstream request + builder.UseDownstreamRequestInitialiser(); + // Then we get the downstream route information builder.UseDownstreamRouteFinderMiddleware(); diff --git a/src/Ocelot/QueryStrings/AddQueriesToRequest.cs b/src/Ocelot/QueryStrings/AddQueriesToRequest.cs index 02fcb63d..839da9bc 100644 --- a/src/Ocelot/QueryStrings/AddQueriesToRequest.cs +++ b/src/Ocelot/QueryStrings/AddQueriesToRequest.cs @@ -4,6 +4,9 @@ using Microsoft.AspNetCore.Http; using Ocelot.Configuration; using Ocelot.Infrastructure.Claims.Parser; using Ocelot.Responses; +using System.Security.Claims; +using System.Net.Http; +using System; namespace Ocelot.QueryStrings { @@ -16,13 +19,13 @@ namespace Ocelot.QueryStrings _claimsParser = claimsParser; } - public Response SetQueriesOnContext(List claimsToThings, HttpContext context) + public Response SetQueriesOnDownstreamRequest(List claimsToThings, IEnumerable claims, HttpRequestMessage downstreamRequest) { - var queryDictionary = ConvertQueryStringToDictionary(context); + var queryDictionary = ConvertQueryStringToDictionary(downstreamRequest.RequestUri.Query); foreach (var config in claimsToThings) { - var value = _claimsParser.GetValue(context.User.Claims, config.NewKey, config.Delimiter, config.Index); + var value = _claimsParser.GetValue(claims, config.NewKey, config.Delimiter, config.Index); if (value.IsError) { @@ -41,22 +44,24 @@ namespace Ocelot.QueryStrings } } - context.Request.QueryString = ConvertDictionaryToQueryString(queryDictionary); + var uriBuilder = new UriBuilder(downstreamRequest.RequestUri); + uriBuilder.Query = ConvertDictionaryToQueryString(queryDictionary); + + downstreamRequest.RequestUri = uriBuilder.Uri; return new OkResponse(); } - private Dictionary ConvertQueryStringToDictionary(HttpContext context) + private Dictionary ConvertQueryStringToDictionary(string queryString) { - return Microsoft.AspNetCore.WebUtilities.QueryHelpers.ParseQuery(context.Request.QueryString.Value) + return Microsoft.AspNetCore.WebUtilities.QueryHelpers + .ParseQuery(queryString) .ToDictionary(q => q.Key, q => q.Value.FirstOrDefault() ?? string.Empty); } - private Microsoft.AspNetCore.Http.QueryString ConvertDictionaryToQueryString(Dictionary queryDictionary) + private string ConvertDictionaryToQueryString(Dictionary queryDictionary) { - var newQueryString = Microsoft.AspNetCore.WebUtilities.QueryHelpers.AddQueryString("", queryDictionary); - - return new Microsoft.AspNetCore.Http.QueryString(newQueryString); + return Microsoft.AspNetCore.WebUtilities.QueryHelpers.AddQueryString("", queryDictionary); } } } \ No newline at end of file diff --git a/src/Ocelot/QueryStrings/IAddQueriesToRequest.cs b/src/Ocelot/QueryStrings/IAddQueriesToRequest.cs index 6fa1b8da..762b8a00 100644 --- a/src/Ocelot/QueryStrings/IAddQueriesToRequest.cs +++ b/src/Ocelot/QueryStrings/IAddQueriesToRequest.cs @@ -2,12 +2,13 @@ using Microsoft.AspNetCore.Http; using Ocelot.Configuration; using Ocelot.Responses; +using System.Net.Http; +using System.Security.Claims; namespace Ocelot.QueryStrings { public interface IAddQueriesToRequest { - Response SetQueriesOnContext(List claimsToThings, - HttpContext context); + Response SetQueriesOnDownstreamRequest(List claimsToThings, IEnumerable claims, HttpRequestMessage downstreamRequest); } } diff --git a/src/Ocelot/QueryStrings/Middleware/QueryStringBuilderMiddleware.cs b/src/Ocelot/QueryStrings/Middleware/QueryStringBuilderMiddleware.cs index edeee51c..f3001e87 100644 --- a/src/Ocelot/QueryStrings/Middleware/QueryStringBuilderMiddleware.cs +++ b/src/Ocelot/QueryStrings/Middleware/QueryStringBuilderMiddleware.cs @@ -32,7 +32,7 @@ namespace Ocelot.QueryStrings.Middleware { _logger.LogDebug("this route has instructions to convert claims to queries"); - var response = _addQueriesToRequest.SetQueriesOnContext(DownstreamRoute.ReRoute.ClaimsToQueries, context); + var response = _addQueriesToRequest.SetQueriesOnDownstreamRequest(DownstreamRoute.ReRoute.ClaimsToQueries, context.User.Claims, DownstreamRequest); if (response.IsError) { diff --git a/src/Ocelot/RateLimit/Middleware/ClientRateLimitMiddleware.cs b/src/Ocelot/RateLimit/Middleware/ClientRateLimitMiddleware.cs index dffc6448..f29dcb82 100644 --- a/src/Ocelot/RateLimit/Middleware/ClientRateLimitMiddleware.cs +++ b/src/Ocelot/RateLimit/Middleware/ClientRateLimitMiddleware.cs @@ -111,11 +111,7 @@ namespace Ocelot.RateLimit.Middleware public bool IsWhitelisted(ClientRequestIdentity requestIdentity, RateLimitOptions option) { - if (option.ClientWhitelist.Contains(requestIdentity.ClientId)) - { - return true; - } - return false; + return option.ClientWhitelist.Contains(requestIdentity.ClientId); } public virtual void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule) diff --git a/src/Ocelot/Request/Builder/HttpRequestCreator.cs b/src/Ocelot/Request/Builder/HttpRequestCreator.cs index de030f83..f95db4b9 100644 --- a/src/Ocelot/Request/Builder/HttpRequestCreator.cs +++ b/src/Ocelot/Request/Builder/HttpRequestCreator.cs @@ -1,38 +1,18 @@ -using System.IO; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; +using System.Threading.Tasks; using Ocelot.Responses; -using Ocelot.Configuration; using Ocelot.Requester.QoS; +using System.Net.Http; namespace Ocelot.Request.Builder { public sealed class HttpRequestCreator : IRequestCreator { public async Task> Build( - string httpMethod, - string downstreamUrl, - Stream content, - IHeaderDictionary headers, - QueryString queryString, - string contentType, - RequestId.RequestId requestId, + HttpRequestMessage httpRequestMessage, bool isQos, IQoSProvider qosProvider) { - var request = await new RequestBuilder() - .WithHttpMethod(httpMethod) - .WithDownstreamUrl(downstreamUrl) - .WithQueryString(queryString) - .WithContent(content) - .WithContentType(contentType) - .WithHeaders(headers) - .WithRequestId(requestId) - .WithIsQos(isQos) - .WithQos(qosProvider) - .Build(); - - return new OkResponse(request); + return new OkResponse(new Request(httpRequestMessage, isQos, qosProvider)); } } } \ No newline at end of file diff --git a/src/Ocelot/Request/Builder/IRequestCreator.cs b/src/Ocelot/Request/Builder/IRequestCreator.cs index 379f0aac..85fbfa8d 100644 --- a/src/Ocelot/Request/Builder/IRequestCreator.cs +++ b/src/Ocelot/Request/Builder/IRequestCreator.cs @@ -1,20 +1,15 @@ -using System.IO; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; -using Ocelot.Requester.QoS; -using Ocelot.Responses; - -namespace Ocelot.Request.Builder +namespace Ocelot.Request.Builder { + using System.Net.Http; + using System.Threading.Tasks; + + using Ocelot.Requester.QoS; + using Ocelot.Responses; + public interface IRequestCreator { - Task> Build(string httpMethod, - string downstreamUrl, - Stream content, - IHeaderDictionary headers, - QueryString queryString, - string contentType, - RequestId.RequestId requestId, + Task> Build( + HttpRequestMessage httpRequestMessage, bool isQos, IQoSProvider qosProvider); } diff --git a/src/Ocelot/Request/Builder/RequestBuilder.cs b/src/Ocelot/Request/Builder/RequestBuilder.cs deleted file mode 100644 index e47eea1f..00000000 --- a/src/Ocelot/Request/Builder/RequestBuilder.cs +++ /dev/null @@ -1,177 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Net; -using System.Net.Http; -using System.Net.Http.Headers; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Primitives; -using Ocelot.Requester.QoS; - -namespace Ocelot.Request.Builder -{ - internal sealed class RequestBuilder - { - private HttpMethod _method; - private string _downstreamUrl; - private QueryString _queryString; - private Stream _content; - private string _contentType; - private IHeaderDictionary _headers; - private RequestId.RequestId _requestId; - private readonly string[] _unsupportedHeaders = {"host"}; - private bool _isQos; - private IQoSProvider _qoSProvider; - - public RequestBuilder WithHttpMethod(string httpMethod) - { - _method = new HttpMethod(httpMethod); - return this; - } - - public RequestBuilder WithDownstreamUrl(string downstreamUrl) - { - _downstreamUrl = downstreamUrl; - return this; - } - - public RequestBuilder WithQueryString(QueryString queryString) - { - _queryString = queryString; - return this; - } - - public RequestBuilder WithContent(Stream content) - { - _content = content; - return this; - } - - public RequestBuilder WithContentType(string contentType) - { - _contentType = contentType; - return this; - } - - public RequestBuilder WithHeaders(IHeaderDictionary headers) - { - _headers = headers; - return this; - } - - public RequestBuilder WithRequestId(RequestId.RequestId requestId) - { - _requestId = requestId; - return this; - } - - public RequestBuilder WithIsQos(bool isqos) - { - _isQos = isqos; - return this; - } - - public RequestBuilder WithQos(IQoSProvider qoSProvider) - { - _qoSProvider = qoSProvider; - return this; - } - - public async Task Build() - { - var uri = CreateUri(); - - var httpRequestMessage = new HttpRequestMessage(_method, uri); - - await AddContentToRequest(httpRequestMessage); - - AddContentTypeToRequest(httpRequestMessage); - - AddHeadersToRequest(httpRequestMessage); - - if (ShouldAddRequestId(_requestId, httpRequestMessage.Headers)) - { - AddRequestIdHeader(_requestId, httpRequestMessage); - } - - return new Request(httpRequestMessage,_isQos, _qoSProvider); - } - - private Uri CreateUri() - { - var uri = new Uri(string.Format("{0}{1}", _downstreamUrl, _queryString.ToUriComponent())); - return uri; - } - - private async Task AddContentToRequest(HttpRequestMessage httpRequestMessage) - { - if (_content != null) - { - httpRequestMessage.Content = new ByteArrayContent(await ToByteArray(_content)); - } - } - - private void AddContentTypeToRequest(HttpRequestMessage httpRequestMessage) - { - if (!string.IsNullOrEmpty(_contentType)) - { - httpRequestMessage.Content.Headers.Remove("Content-Type"); - httpRequestMessage.Content.Headers.TryAddWithoutValidation("Content-Type", _contentType); - } - } - - private void AddHeadersToRequest(HttpRequestMessage httpRequestMessage) - { - if (_headers != null) - { - _headers.Remove("Content-Type"); - - foreach (var header in _headers) - { - //todo get rid of if.. - if (IsSupportedHeader(header)) - { - httpRequestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray()); - } - } - } - } - - private bool IsSupportedHeader(KeyValuePair header) - { - return !_unsupportedHeaders.Contains(header.Key.ToLower()); - } - - private void AddRequestIdHeader(RequestId.RequestId requestId, HttpRequestMessage httpRequestMessage) - { - httpRequestMessage.Headers.Add(requestId.RequestIdKey, requestId.RequestIdValue); - } - - private bool RequestIdInHeaders(RequestId.RequestId requestId, HttpRequestHeaders headers) - { - IEnumerable value; - return headers.TryGetValues(requestId.RequestIdKey, out value); - } - - private bool ShouldAddRequestId(RequestId.RequestId requestId, HttpRequestHeaders headers) - { - return !string.IsNullOrEmpty(requestId?.RequestIdKey) - && !string.IsNullOrEmpty(requestId.RequestIdValue) - && !RequestIdInHeaders(requestId, headers); - } - - private async Task ToByteArray(Stream stream) - { - using (stream) - { - using (var memStream = new MemoryStream()) - { - await stream.CopyToAsync(memStream); - return memStream.ToArray(); - } - } - } - } -} diff --git a/src/Ocelot/Request/Mapper/IRequestMapper.cs b/src/Ocelot/Request/Mapper/IRequestMapper.cs new file mode 100644 index 00000000..941a24f7 --- /dev/null +++ b/src/Ocelot/Request/Mapper/IRequestMapper.cs @@ -0,0 +1,12 @@ +namespace Ocelot.Request.Mapper +{ + using System.Net.Http; + using System.Threading.Tasks; + using Microsoft.AspNetCore.Http; + using Ocelot.Responses; + + public interface IRequestMapper + { + Task> Map(HttpRequest request); + } +} diff --git a/src/Ocelot/Request/Mapper/RequestMapper.cs b/src/Ocelot/Request/Mapper/RequestMapper.cs new file mode 100644 index 00000000..17e2afe5 --- /dev/null +++ b/src/Ocelot/Request/Mapper/RequestMapper.cs @@ -0,0 +1,90 @@ +namespace Ocelot.Request.Mapper +{ + using System; + using System.Collections.Generic; + using System.IO; + using System.Linq; + using System.Net.Http; + using System.Threading.Tasks; + + using Microsoft.AspNetCore.Http; + using Microsoft.AspNetCore.Http.Extensions; + using Microsoft.Extensions.Primitives; + using Ocelot.Responses; + + public class RequestMapper : IRequestMapper + { + private readonly string[] _unsupportedHeaders = { "host" }; + + public async Task> Map(HttpRequest request) + { + try + { + var requestMessage = new HttpRequestMessage() + { + Content = await MapContent(request), + Method = MapMethod(request), + RequestUri = MapUri(request) + }; + + MapHeaders(request, requestMessage); + + return new OkResponse(requestMessage); + } + catch (Exception ex) + { + return new ErrorResponse(new UnmappableRequestError(ex)); + } + } + + private async Task MapContent(HttpRequest request) + { + if (request.Body == null) + { + return null; + } + + return new ByteArrayContent(await ToByteArray(request.Body)); + } + + private HttpMethod MapMethod(HttpRequest request) + { + return new HttpMethod(request.Method); + } + + private Uri MapUri(HttpRequest request) + { + return new Uri(request.GetEncodedUrl()); + } + + private void MapHeaders(HttpRequest request, HttpRequestMessage requestMessage) + { + foreach (var header in request.Headers) + { + //todo get rid of if.. + if (IsSupportedHeader(header)) + { + requestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray()); + } + } + } + + private async Task ToByteArray(Stream stream) + { + using (stream) + { + using (var memStream = new MemoryStream()) + { + await stream.CopyToAsync(memStream); + return memStream.ToArray(); + } + } + } + + private bool IsSupportedHeader(KeyValuePair header) + { + return !_unsupportedHeaders.Contains(header.Key.ToLower()); + } + } +} + diff --git a/src/Ocelot/Request/Mapper/UnmappableRequestError.cs b/src/Ocelot/Request/Mapper/UnmappableRequestError.cs new file mode 100644 index 00000000..4a860f5b --- /dev/null +++ b/src/Ocelot/Request/Mapper/UnmappableRequestError.cs @@ -0,0 +1,12 @@ +namespace Ocelot.Request.Mapper +{ + using Ocelot.Errors; + using System; + + public class UnmappableRequestError : Error + { + public UnmappableRequestError(Exception ex) : base($"Error when parsing incoming request, exception: {ex.Message}", OcelotErrorCode.UnmappableRequestError) + { + } + } +} diff --git a/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs b/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs new file mode 100644 index 00000000..a2813c25 --- /dev/null +++ b/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs @@ -0,0 +1,47 @@ +namespace Ocelot.Request.Middleware +{ + using System.Threading.Tasks; + using Microsoft.AspNetCore.Http; + + using Ocelot.Infrastructure.RequestData; + using Ocelot.Logging; + using Ocelot.Middleware; + + public class DownstreamRequestInitialiserMiddleware : OcelotMiddleware + { + private readonly RequestDelegate _next; + private readonly IOcelotLogger _logger; + private readonly Mapper.IRequestMapper _requestMapper; + + public DownstreamRequestInitialiserMiddleware(RequestDelegate next, + IOcelotLoggerFactory loggerFactory, + IRequestScopedDataRepository requestScopedDataRepository, + Mapper.IRequestMapper requestMapper) + :base(requestScopedDataRepository) + { + _next = next; + _logger = loggerFactory.CreateLogger(); + _requestMapper = requestMapper; + } + + public async Task Invoke(HttpContext context) + { + _logger.LogDebug("started calling request builder middleware"); + + var downstreamRequest = await _requestMapper.Map(context.Request); + if (downstreamRequest.IsError) + { + SetPipelineError(downstreamRequest.Errors); + return; + } + + SetDownstreamRequest(downstreamRequest.Data); + + _logger.LogDebug("calling next middleware"); + + await _next.Invoke(context); + + _logger.LogDebug("succesfully called next middleware"); + } + } +} \ No newline at end of file diff --git a/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddleware.cs b/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddleware.cs index 0701e089..c59af7a6 100644 --- a/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddleware.cs +++ b/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddleware.cs @@ -43,14 +43,8 @@ namespace Ocelot.Request.Middleware return; } - var buildResult = await _requestCreator - .Build(context.Request.Method, - DownstreamUrl, - context.Request.Body, - context.Request.Headers, - context.Request.QueryString, - context.Request.ContentType, - new RequestId.RequestId(DownstreamRoute?.ReRoute?.RequestIdKey, context.TraceIdentifier), + var buildResult = await _requestCreator.Build( + DownstreamRequest, DownstreamRoute.ReRoute.IsQos, qosProvider.Data); diff --git a/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddlewareExtensions.cs b/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddlewareExtensions.cs index 4c08afec..20bb9164 100644 --- a/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddlewareExtensions.cs +++ b/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddlewareExtensions.cs @@ -8,5 +8,10 @@ namespace Ocelot.Request.Middleware { return builder.UseMiddleware(); } + + public static IApplicationBuilder UseDownstreamRequestInitialiser(this IApplicationBuilder builder) + { + return builder.UseMiddleware(); + } } } \ No newline at end of file diff --git a/src/Ocelot/RequestId/Middleware/RequestIdMiddleware.cs b/src/Ocelot/RequestId/Middleware/RequestIdMiddleware.cs index 1e1a955c..08222d9d 100644 --- a/src/Ocelot/RequestId/Middleware/RequestIdMiddleware.cs +++ b/src/Ocelot/RequestId/Middleware/RequestIdMiddleware.cs @@ -5,6 +5,9 @@ using Microsoft.Extensions.Primitives; using Ocelot.Infrastructure.RequestData; using Ocelot.Logging; using Ocelot.Middleware; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Collections.Generic; namespace Ocelot.RequestId.Middleware { @@ -30,8 +33,6 @@ namespace Ocelot.RequestId.Middleware SetOcelotRequestId(context); - _logger.LogDebug("set requestId"); - _logger.TraceInvokeNext(); await _next.Invoke(context); _logger.TraceInvokeNextCompleted(); @@ -40,21 +41,40 @@ namespace Ocelot.RequestId.Middleware private void SetOcelotRequestId(HttpContext context) { - var key = DefaultRequestIdKey.Value; - - if (DownstreamRoute.ReRoute.RequestIdKey != null) + // if get request ID is set on upstream request then retrieve it + var key = DownstreamRoute.ReRoute.RequestIdKey ?? DefaultRequestIdKey.Value; + + StringValues upstreamRequestIds; + if (context.Request.Headers.TryGetValue(key, out upstreamRequestIds)) { - key = DownstreamRoute.ReRoute.RequestIdKey; + context.TraceIdentifier = upstreamRequestIds.First(); } - StringValues requestId; + // set request ID on downstream request, if required + var requestId = new RequestId(DownstreamRoute?.ReRoute?.RequestIdKey, context.TraceIdentifier); - if (context.Request.Headers.TryGetValue(key, out requestId)) + if (ShouldAddRequestId(requestId, DownstreamRequest.Headers)) { - _requestScopedDataRepository.Add("RequestId", requestId.First()); - - context.TraceIdentifier = requestId; + AddRequestIdHeader(requestId, DownstreamRequest); } } + + private bool ShouldAddRequestId(RequestId requestId, HttpRequestHeaders headers) + { + return !string.IsNullOrEmpty(requestId?.RequestIdKey) + && !string.IsNullOrEmpty(requestId.RequestIdValue) + && !RequestIdInHeaders(requestId, headers); + } + + private bool RequestIdInHeaders(RequestId requestId, HttpRequestHeaders headers) + { + IEnumerable value; + return headers.TryGetValues(requestId.RequestIdKey, out value); + } + + private void AddRequestIdHeader(RequestId requestId, HttpRequestMessage httpRequestMessage) + { + httpRequestMessage.Headers.Add(requestId.RequestIdKey, requestId.RequestIdValue); + } } } \ No newline at end of file diff --git a/test/Ocelot.UnitTests/Cache/OutputCacheMiddlewareTests.cs b/test/Ocelot.UnitTests/Cache/OutputCacheMiddlewareTests.cs index 66532a18..6537c684 100644 --- a/test/Ocelot.UnitTests/Cache/OutputCacheMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/Cache/OutputCacheMiddlewareTests.cs @@ -2,11 +2,9 @@ using System.Collections.Generic; using System.IO; using System.Net.Http; -using CacheManager.Core; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using Moq; using Ocelot.Cache; using Ocelot.Cache.Middleware; @@ -37,7 +35,6 @@ namespace Ocelot.UnitTests.Cache _cacheManager = new Mock>(); _scopedRepo = new Mock(); - _url = "http://localhost:51879"; var builder = new WebHostBuilder() .ConfigureServices(x => @@ -57,6 +54,10 @@ namespace Ocelot.UnitTests.Cache app.UseOutputCacheMiddleware(); }); + _scopedRepo + .Setup(sr => sr.Get("DownstreamRequest")) + .Returns(new OkResponse(new HttpRequestMessage(HttpMethod.Get, "https://some.url/blah?abcd=123"))); + _server = new TestServer(builder); _client = _server.CreateClient(); } diff --git a/test/Ocelot.UnitTests/DownstreamUrlCreator/DownstreamUrlCreatorMiddlewareTests.cs b/test/Ocelot.UnitTests/DownstreamUrlCreator/DownstreamUrlCreatorMiddlewareTests.cs index 438710f6..aea1780f 100644 --- a/test/Ocelot.UnitTests/DownstreamUrlCreator/DownstreamUrlCreatorMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/DownstreamUrlCreator/DownstreamUrlCreatorMiddlewareTests.cs @@ -5,12 +5,9 @@ using System.Net.Http; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using Moq; -using Ocelot.Configuration; using Ocelot.Configuration.Builder; using Ocelot.DownstreamRouteFinder; -using Ocelot.DownstreamRouteFinder.Middleware; using Ocelot.DownstreamRouteFinder.UrlMatcher; using Ocelot.DownstreamUrlCreator; using Ocelot.DownstreamUrlCreator.Middleware; @@ -21,6 +18,7 @@ using Ocelot.Responses; using Ocelot.Values; using TestStack.BDDfy; using Xunit; +using Shouldly; namespace Ocelot.UnitTests.DownstreamUrlCreator { @@ -33,10 +31,9 @@ namespace Ocelot.UnitTests.DownstreamUrlCreator private readonly TestServer _server; private readonly HttpClient _client; private Response _downstreamRoute; - private HttpResponseMessage _result; private OkResponse _downstreamPath; - private OkResponse _downstreamUrl; - private HostAndPort _hostAndPort; + private HttpRequestMessage _downstreamRequest; + private HttpResponseMessage _result; public DownstreamUrlCreatorMiddlewareTests() { @@ -63,65 +60,34 @@ namespace Ocelot.UnitTests.DownstreamUrlCreator app.UseDownstreamUrlCreatorMiddleware(); }); + _downstreamRequest = new HttpRequestMessage(HttpMethod.Get, "https://my.url/abc/?q=123"); + + _scopedRepository + .Setup(sr => sr.Get("DownstreamRequest")) + .Returns(new OkResponse(_downstreamRequest)); + _server = new TestServer(builder); _client = _server.CreateClient(); } [Fact] - public void should_call_dependencies_correctly() + public void should_replace_scheme_and_path() { - var hostAndPort = new HostAndPort("127.0.0.1", 80); - this.Given(x => x.GivenTheDownStreamRouteIs( new DownstreamRoute( new List(), new ReRouteBuilder() .WithDownstreamPathTemplate("any old string") .WithUpstreamHttpMethod("Get") + .WithDownstreamScheme("https") .Build()))) - .And(x => x.GivenTheHostAndPortIs(hostAndPort)) - .And(x => x.TheUrlReplacerReturns("/api/products/1")) - .And(x => x.TheUrlBuilderReturns("http://127.0.0.1:80/api/products/1")) + .And(x => x.GivenTheDownstreamRequestUriIs("http://my.url/abc?q=123")) + .And(x => x.GivenTheUrlReplacerWillReturn("/api/products/1")) .When(x => x.WhenICallTheMiddleware()) - .Then(x => x.ThenTheScopedDataRepositoryIsCalledCorrectly()) + .Then(x => x.ThenTheDownstreamRequestUriIs("https://my.url:80/api/products/1?q=123")) .BDDfy(); } - private void GivenTheHostAndPortIs(HostAndPort hostAndPort) - { - _hostAndPort = hostAndPort; - _scopedRepository - .Setup(x => x.Get("HostAndPort")) - .Returns(new OkResponse(_hostAndPort)); - } - - private void TheUrlBuilderReturns(string dsUrl) - { - _downstreamUrl = new OkResponse(new DownstreamUrl(dsUrl)); - _urlBuilder - .Setup(x => x.Build(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(_downstreamUrl); - } - - private void TheUrlReplacerReturns(string downstreamUrl) - { - _downstreamPath = new OkResponse(new DownstreamPath(downstreamUrl)); - _downstreamUrlTemplateVariableReplacer - .Setup(x => x.Replace(It.IsAny(), It.IsAny>())) - .Returns(_downstreamPath); - } - - private void ThenTheScopedDataRepositoryIsCalledCorrectly() - { - _scopedRepository - .Verify(x => x.Add("DownstreamUrl", _downstreamUrl.Data.Value), Times.Once()); - } - - private void WhenICallTheMiddleware() - { - _result = _client.GetAsync(_url).Result; - } - private void GivenTheDownStreamRouteIs(DownstreamRoute downstreamRoute) { _downstreamRoute = new OkResponse(downstreamRoute); @@ -130,6 +96,29 @@ namespace Ocelot.UnitTests.DownstreamUrlCreator .Returns(_downstreamRoute); } + private void GivenTheDownstreamRequestUriIs(string uri) + { + _downstreamRequest.RequestUri = new Uri(uri); + } + + private void GivenTheUrlReplacerWillReturn(string path) + { + _downstreamPath = new OkResponse(new DownstreamPath(path)); + _downstreamUrlTemplateVariableReplacer + .Setup(x => x.Replace(It.IsAny(), It.IsAny>())) + .Returns(_downstreamPath); + } + + private void WhenICallTheMiddleware() + { + _result = _client.GetAsync(_url).Result; + } + + private void ThenTheDownstreamRequestUriIs(string expectedUri) + { + _downstreamRequest.RequestUri.OriginalString.ShouldBe(expectedUri); + } + public void Dispose() { _client.Dispose(); diff --git a/test/Ocelot.UnitTests/DownstreamUrlCreator/UrlBuilderTests.cs b/test/Ocelot.UnitTests/DownstreamUrlCreator/UrlBuilderTests.cs index 7e512798..0414bbb3 100644 --- a/test/Ocelot.UnitTests/DownstreamUrlCreator/UrlBuilderTests.cs +++ b/test/Ocelot.UnitTests/DownstreamUrlCreator/UrlBuilderTests.cs @@ -1,7 +1,5 @@ using System; -using Ocelot.Configuration; using Ocelot.DownstreamUrlCreator; -using Ocelot.DownstreamUrlCreator.UrlTemplateReplacer; using Ocelot.Responses; using Ocelot.Values; using Shouldly; diff --git a/test/Ocelot.UnitTests/Headers/AddHeadersToRequestTests.cs b/test/Ocelot.UnitTests/Headers/AddHeadersToRequestTests.cs index 47951859..d276fd24 100644 --- a/test/Ocelot.UnitTests/Headers/AddHeadersToRequestTests.cs +++ b/test/Ocelot.UnitTests/Headers/AddHeadersToRequestTests.cs @@ -1,8 +1,6 @@ using System.Collections.Generic; using System.Linq; using System.Security.Claims; -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Primitives; using Moq; using Ocelot.Configuration; using Ocelot.Errors; @@ -12,6 +10,7 @@ using Ocelot.Responses; using Shouldly; using TestStack.BDDfy; using Xunit; +using System.Net.Http; namespace Ocelot.UnitTests.Headers { @@ -19,8 +18,9 @@ namespace Ocelot.UnitTests.Headers { private readonly AddHeadersToRequest _addHeadersToRequest; private readonly Mock _parser; + private readonly HttpRequestMessage _downstreamRequest; + private List _claims; private List _configuration; - private HttpContext _context; private Response _result; private Response _claimValue; @@ -28,17 +28,15 @@ namespace Ocelot.UnitTests.Headers { _parser = new Mock(); _addHeadersToRequest = new AddHeadersToRequest(_parser.Object); + _downstreamRequest = new HttpRequestMessage(); } [Fact] - public void should_add_headers_to_context() + public void should_add_headers_to_downstreamRequest() { - var context = new DefaultHttpContext + var claims = new List { - User = new ClaimsPrincipal(new ClaimsIdentity(new List - { - new Claim("test", "data") - })) + new Claim("test", "data") }; this.Given( @@ -46,7 +44,7 @@ namespace Ocelot.UnitTests.Headers { new ClaimToThing("header-key", "", "", 0) })) - .Given(x => x.GivenHttpContext(context)) + .Given(x => x.GivenClaims(claims)) .And(x => x.GivenTheClaimParserReturns(new OkResponse("value"))) .When(x => x.WhenIAddHeadersToTheRequest()) .Then(x => x.ThenTheResultIsSuccess()) @@ -55,25 +53,19 @@ namespace Ocelot.UnitTests.Headers } [Fact] - public void if_header_exists_should_replace_it() + public void should_replace_existing_headers_on_request() { - var context = new DefaultHttpContext - { - User = new ClaimsPrincipal(new ClaimsIdentity(new List - { - new Claim("test", "data") - })), - }; - - context.Request.Headers.Add("header-key", new StringValues("initial")); - this.Given( x => x.GivenConfigurationHeaderExtractorProperties(new List { new ClaimToThing("header-key", "", "", 0) })) - .Given(x => x.GivenHttpContext(context)) + .Given(x => x.GivenClaims(new List + { + new Claim("test", "data") + })) .And(x => x.GivenTheClaimParserReturns(new OkResponse("value"))) + .And(x => x.GivenThatTheRequestContainsHeader("header-key", "initial")) .When(x => x.WhenIAddHeadersToTheRequest()) .Then(x => x.ThenTheResultIsSuccess()) .And(x => x.ThenTheHeaderIsAdded()) @@ -88,7 +80,7 @@ namespace Ocelot.UnitTests.Headers { new ClaimToThing("", "", "", 0) })) - .Given(x => x.GivenHttpContext(new DefaultHttpContext())) + .Given(x => x.GivenClaims(new List())) .And(x => x.GivenTheClaimParserReturns(new ErrorResponse(new List { new AnyError() @@ -98,10 +90,9 @@ namespace Ocelot.UnitTests.Headers .BDDfy(); } - private void ThenTheHeaderIsAdded() + private void GivenClaims(List claims) { - var header = _context.Request.Headers.First(x => x.Key == "header-key"); - header.Value.First().ShouldBe(_claimValue.Data); + _claims = claims; } private void GivenConfigurationHeaderExtractorProperties(List configuration) @@ -109,9 +100,9 @@ namespace Ocelot.UnitTests.Headers _configuration = configuration; } - private void GivenHttpContext(HttpContext context) + private void GivenThatTheRequestContainsHeader(string key, string value) { - _context = context; + _downstreamRequest.Headers.Add(key, value); } private void GivenTheClaimParserReturns(Response claimValue) @@ -129,7 +120,7 @@ namespace Ocelot.UnitTests.Headers private void WhenIAddHeadersToTheRequest() { - _result = _addHeadersToRequest.SetHeadersOnContext(_configuration, _context); + _result = _addHeadersToRequest.SetHeadersOnDownstreamRequest(_configuration, _claims, _downstreamRequest); } private void ThenTheResultIsSuccess() @@ -143,6 +134,12 @@ namespace Ocelot.UnitTests.Headers _result.IsError.ShouldBe(true); } + private void ThenTheHeaderIsAdded() + { + var header = _downstreamRequest.Headers.First(x => x.Key == "header-key"); + header.Value.First().ShouldBe(_claimValue.Data); + } + class AnyError : Error { public AnyError() diff --git a/test/Ocelot.UnitTests/Headers/HttpRequestHeadersBuilderMiddlewareTests.cs b/test/Ocelot.UnitTests/Headers/HttpRequestHeadersBuilderMiddlewareTests.cs index 032a76e9..95989688 100644 --- a/test/Ocelot.UnitTests/Headers/HttpRequestHeadersBuilderMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/Headers/HttpRequestHeadersBuilderMiddlewareTests.cs @@ -3,16 +3,13 @@ using System.Collections.Generic; using System.IO; using System.Net.Http; using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using Moq; using Ocelot.Configuration; using Ocelot.Configuration.Builder; using Ocelot.DownstreamRouteFinder; using Ocelot.DownstreamRouteFinder.UrlMatcher; -using Ocelot.DownstreamUrlCreator.Middleware; using Ocelot.Headers; using Ocelot.Headers.Middleware; using Ocelot.Infrastructure.RequestData; @@ -27,6 +24,7 @@ namespace Ocelot.UnitTests.Headers { private readonly Mock _scopedRepository; private readonly Mock _addHeaders; + private readonly HttpRequestMessage _downstreamRequest; private readonly string _url; private readonly TestServer _server; private readonly HttpClient _client; @@ -58,6 +56,12 @@ namespace Ocelot.UnitTests.Headers app.UseHttpRequestHeadersBuilderMiddleware(); }); + _downstreamRequest = new HttpRequestMessage(); + + _scopedRepository + .Setup(sr => sr.Get("DownstreamRequest")) + .Returns(new OkResponse(_downstreamRequest)); + _server = new TestServer(builder); _client = _server.CreateClient(); } @@ -76,25 +80,29 @@ namespace Ocelot.UnitTests.Headers .Build()); this.Given(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) - .And(x => x.GivenTheAddHeadersToRequestReturns()) + .And(x => x.GivenTheAddHeadersToDownstreamRequestReturnsOk()) .When(x => x.WhenICallTheMiddleware()) .Then(x => x.ThenTheAddHeadersToRequestIsCalledCorrectly()) .BDDfy(); } - private void GivenTheAddHeadersToRequestReturns() + private void GivenTheAddHeadersToDownstreamRequestReturnsOk() { _addHeaders - .Setup(x => x.SetHeadersOnContext(It.IsAny>(), - It.IsAny())) + .Setup(x => x.SetHeadersOnDownstreamRequest( + It.IsAny>(), + It.IsAny>(), + It.IsAny())) .Returns(new OkResponse()); } private void ThenTheAddHeadersToRequestIsCalledCorrectly() { _addHeaders - .Verify(x => x.SetHeadersOnContext(It.IsAny>(), - It.IsAny()), Times.Once); + .Verify(x => x.SetHeadersOnDownstreamRequest( + It.IsAny>(), + It.IsAny>(), + _downstreamRequest), Times.Once); } private void WhenICallTheMiddleware() diff --git a/test/Ocelot.UnitTests/LoadBalancer/LoadBalancerMiddlewareTests.cs b/test/Ocelot.UnitTests/LoadBalancer/LoadBalancerMiddlewareTests.cs index e76f3c2f..a7dfd1ac 100644 --- a/test/Ocelot.UnitTests/LoadBalancer/LoadBalancerMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/LoadBalancer/LoadBalancerMiddlewareTests.cs @@ -16,6 +16,7 @@ using Ocelot.Responses; using Ocelot.Values; using TestStack.BDDfy; using Xunit; +using Shouldly; namespace Ocelot.UnitTests.LoadBalancer { @@ -29,10 +30,10 @@ namespace Ocelot.UnitTests.LoadBalancer private readonly HttpClient _client; private HttpResponseMessage _result; private HostAndPort _hostAndPort; - private OkResponse _downstreamUrl; private OkResponse _downstreamRoute; private ErrorResponse _getLoadBalancerHouseError; private ErrorResponse _getHostAndPortError; + private HttpRequestMessage _downstreamRequest; public LoadBalancerMiddlewareTests() { @@ -59,6 +60,10 @@ namespace Ocelot.UnitTests.LoadBalancer app.UseLoadBalancingMiddleware(); }); + _downstreamRequest = new HttpRequestMessage(HttpMethod.Get, ""); + _scopedRepository + .Setup(sr => sr.Get("DownstreamRequest")) + .Returns(new OkResponse(_downstreamRequest)); _server = new TestServer(builder); _client = _server.CreateClient(); } @@ -71,12 +76,12 @@ namespace Ocelot.UnitTests.LoadBalancer .WithUpstreamHttpMethod("Get") .Build()); - this.Given(x => x.GivenTheDownStreamUrlIs("any old string")) + this.Given(x => x.GivenTheDownStreamUrlIs("http://my.url/abc?q=123")) .And(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) .And(x => x.GivenTheLoadBalancerHouseReturns()) .And(x => x.GivenTheLoadBalancerReturns()) .When(x => x.WhenICallTheMiddleware()) - .Then(x => x.ThenTheScopedDataRepositoryIsCalledCorrectly()) + .Then(x => x.ThenTheDownstreamUrlIsReplacedWith("http://127.0.0.1:80/abc?q=123")) .BDDfy(); } @@ -88,7 +93,7 @@ namespace Ocelot.UnitTests.LoadBalancer .WithUpstreamHttpMethod("Get") .Build()); - this.Given(x => x.GivenTheDownStreamUrlIs("any old string")) + this.Given(x => x.GivenTheDownStreamUrlIs("http://my.url/abc?q=123")) .And(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) .And(x => x.GivenTheLoadBalancerHouseReturnsAnError()) .When(x => x.WhenICallTheMiddleware()) @@ -104,7 +109,7 @@ namespace Ocelot.UnitTests.LoadBalancer .WithUpstreamHttpMethod("Get") .Build()); - this.Given(x => x.GivenTheDownStreamUrlIs("any old string")) + this.Given(x => x.GivenTheDownStreamUrlIs("http://my.url/abc?q=123")) .And(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) .And(x => x.GivenTheLoadBalancerHouseReturns()) .And(x => x.GivenTheLoadBalancerReturnsAnError()) @@ -113,6 +118,11 @@ namespace Ocelot.UnitTests.LoadBalancer .BDDfy(); } + private void GivenTheDownStreamUrlIs(string downstreamUrl) + { + _downstreamRequest.RequestUri = new System.Uri(downstreamUrl); + } + private void GivenTheLoadBalancerReturnsAnError() { _getHostAndPortError = new ErrorResponse(new List() { new ServicesAreNullError($"services were null for bah") }); @@ -157,10 +167,9 @@ namespace Ocelot.UnitTests.LoadBalancer .Returns(_getLoadBalancerHouseError); } - private void ThenTheScopedDataRepositoryIsCalledCorrectly() + private void WhenICallTheMiddleware() { - _scopedRepository - .Verify(x => x.Add("HostAndPort", _hostAndPort), Times.Once()); + _result = _client.GetAsync(_url).Result; } private void ThenAnErrorStatingLoadBalancerCouldNotBeFoundIsSetOnPipeline() @@ -190,17 +199,11 @@ namespace Ocelot.UnitTests.LoadBalancer .Verify(x => x.Add("OcelotMiddlewareErrors", _getHostAndPortError.Errors), Times.Once); } - private void WhenICallTheMiddleware() - { - _result = _client.GetAsync(_url).Result; - } - private void GivenTheDownStreamUrlIs(string downstreamUrl) + + private void ThenTheDownstreamUrlIsReplacedWith(string expectedUri) { - _downstreamUrl = new OkResponse(downstreamUrl); - _scopedRepository - .Setup(x => x.Get(It.IsAny())) - .Returns(_downstreamUrl); + _downstreamRequest.RequestUri.OriginalString.ShouldBe(expectedUri); } public void Dispose() diff --git a/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs b/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs index 99cf68b8..ee7769a5 100644 --- a/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs +++ b/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs @@ -1,7 +1,6 @@ using System.Collections.Generic; using System.Linq; using System.Security.Claims; -using Microsoft.AspNetCore.Http; using Moq; using Ocelot.Configuration; using Ocelot.Errors; @@ -11,15 +10,18 @@ using Ocelot.Responses; using Shouldly; using TestStack.BDDfy; using Xunit; +using System.Net.Http; +using System; namespace Ocelot.UnitTests.QueryStrings { public class AddQueriesToRequestTests { private readonly AddQueriesToRequest _addQueriesToRequest; + private readonly HttpRequestMessage _downstreamRequest; private readonly Mock _parser; private List _configuration; - private HttpContext _context; + private List _claims; private Response _result; private Response _claimValue; @@ -27,17 +29,15 @@ namespace Ocelot.UnitTests.QueryStrings { _parser = new Mock(); _addQueriesToRequest = new AddQueriesToRequest(_parser.Object); + _downstreamRequest = new HttpRequestMessage(HttpMethod.Post, "http://my.url/abc?q=123"); } [Fact] - public void should_add_queries_to_context() + public void should_add_new_queries_to_downstream_request() { - var context = new DefaultHttpContext + var claims = new List { - User = new ClaimsPrincipal(new ClaimsIdentity(new List - { - new Claim("test", "data") - })) + new Claim("test", "data") }; this.Given( @@ -45,7 +45,7 @@ namespace Ocelot.UnitTests.QueryStrings { new ClaimToThing("query-key", "", "", 0) })) - .Given(x => x.GivenHttpContext(context)) + .Given(x => x.GivenClaims(claims)) .And(x => x.GivenTheClaimParserReturns(new OkResponse("value"))) .When(x => x.WhenIAddQueriesToTheRequest()) .Then(x => x.ThenTheResultIsSuccess()) @@ -54,24 +54,20 @@ namespace Ocelot.UnitTests.QueryStrings } [Fact] - public void if_query_exists_should_replace_it() + public void should_replace_existing_queries_on_downstream_request() { - var context = new DefaultHttpContext + var claims = new List { - User = new ClaimsPrincipal(new ClaimsIdentity(new List - { - new Claim("test", "data") - })), + new Claim("test", "data") }; - context.Request.QueryString = context.Request.QueryString.Add("query-key", "initial"); - this.Given( x => x.GivenAClaimToThing(new List { new ClaimToThing("query-key", "", "", 0) })) - .Given(x => x.GivenHttpContext(context)) + .And(x => x.GivenClaims(claims)) + .And(x => x.GivenTheDownstreamRequestHasQueryString("query-key", "initial")) .And(x => x.GivenTheClaimParserReturns(new OkResponse("value"))) .When(x => x.WhenIAddQueriesToTheRequest()) .Then(x => x.ThenTheResultIsSuccess()) @@ -87,7 +83,7 @@ namespace Ocelot.UnitTests.QueryStrings { new ClaimToThing("", "", "", 0) })) - .Given(x => x.GivenHttpContext(new DefaultHttpContext())) + .Given(x => x.GivenClaims(new List())) .And(x => x.GivenTheClaimParserReturns(new ErrorResponse(new List { new AnyError() @@ -99,7 +95,8 @@ namespace Ocelot.UnitTests.QueryStrings private void ThenTheQueryIsAdded() { - var query = _context.Request.Query.First(x => x.Key == "query-key"); + var queries = Microsoft.AspNetCore.WebUtilities.QueryHelpers.ParseQuery(_downstreamRequest.RequestUri.OriginalString); + var query = queries.First(x => x.Key == "query-key"); query.Value.First().ShouldBe(_claimValue.Data); } @@ -108,9 +105,17 @@ namespace Ocelot.UnitTests.QueryStrings _configuration = configuration; } - private void GivenHttpContext(HttpContext context) + private void GivenClaims(List claims) { - _context = context; + _claims = claims; + } + + private void GivenTheDownstreamRequestHasQueryString(string key, string value) + { + var newUri = Microsoft.AspNetCore.WebUtilities.QueryHelpers + .AddQueryString(_downstreamRequest.RequestUri.OriginalString, key, value); + + _downstreamRequest.RequestUri = new Uri(newUri); } private void GivenTheClaimParserReturns(Response claimValue) @@ -128,7 +133,7 @@ namespace Ocelot.UnitTests.QueryStrings private void WhenIAddQueriesToTheRequest() { - _result = _addQueriesToRequest.SetQueriesOnContext(_configuration, _context); + _result = _addQueriesToRequest.SetQueriesOnDownstreamRequest(_configuration, _claims, _downstreamRequest); } private void ThenTheResultIsSuccess() @@ -138,7 +143,6 @@ namespace Ocelot.UnitTests.QueryStrings private void ThenTheResultIsError() { - _result.IsError.ShouldBe(true); } diff --git a/test/Ocelot.UnitTests/QueryStrings/QueryStringBuilderMiddlewareTests.cs b/test/Ocelot.UnitTests/QueryStrings/QueryStringBuilderMiddlewareTests.cs index f381ff1b..898563f7 100644 --- a/test/Ocelot.UnitTests/QueryStrings/QueryStringBuilderMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/QueryStrings/QueryStringBuilderMiddlewareTests.cs @@ -3,16 +3,13 @@ using System.Collections.Generic; using System.IO; using System.Net.Http; using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using Moq; using Ocelot.Configuration; using Ocelot.Configuration.Builder; using Ocelot.DownstreamRouteFinder; using Ocelot.DownstreamRouteFinder.UrlMatcher; -using Ocelot.Headers.Middleware; using Ocelot.Infrastructure.RequestData; using Ocelot.Logging; using Ocelot.QueryStrings; @@ -20,6 +17,7 @@ using Ocelot.QueryStrings.Middleware; using Ocelot.Responses; using TestStack.BDDfy; using Xunit; +using System.Security.Claims; namespace Ocelot.UnitTests.QueryStrings { @@ -30,6 +28,7 @@ namespace Ocelot.UnitTests.QueryStrings private readonly string _url; private readonly TestServer _server; private readonly HttpClient _client; + private readonly HttpRequestMessage _downstreamRequest; private Response _downstreamRoute; private HttpResponseMessage _result; @@ -56,6 +55,11 @@ namespace Ocelot.UnitTests.QueryStrings app.UseQueryStringBuilderMiddleware(); }); + _downstreamRequest = new HttpRequestMessage(); + + _scopedRepository.Setup(sr => sr.Get("DownstreamRequest")) + .Returns(new OkResponse(_downstreamRequest)); + _server = new TestServer(builder); _client = _server.CreateClient(); } @@ -74,25 +78,29 @@ namespace Ocelot.UnitTests.QueryStrings .Build()); this.Given(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) - .And(x => x.GivenTheAddHeadersToRequestReturns()) + .And(x => x.GivenTheAddHeadersToRequestReturnsOk()) .When(x => x.WhenICallTheMiddleware()) .Then(x => x.ThenTheAddQueriesToRequestIsCalledCorrectly()) .BDDfy(); } - private void GivenTheAddHeadersToRequestReturns() + private void GivenTheAddHeadersToRequestReturnsOk() { _addQueries - .Setup(x => x.SetQueriesOnContext(It.IsAny>(), - It.IsAny())) + .Setup(x => x.SetQueriesOnDownstreamRequest( + It.IsAny>(), + It.IsAny>(), + It.IsAny())) .Returns(new OkResponse()); } private void ThenTheAddQueriesToRequestIsCalledCorrectly() { _addQueries - .Verify(x => x.SetQueriesOnContext(It.IsAny>(), - It.IsAny()), Times.Once); + .Verify(x => x.SetQueriesOnDownstreamRequest( + It.IsAny>(), + It.IsAny>(), + _downstreamRequest), Times.Once); } private void WhenICallTheMiddleware() diff --git a/test/Ocelot.UnitTests/Request/DownstreamRequestInitialiserMiddlewareTests.cs b/test/Ocelot.UnitTests/Request/DownstreamRequestInitialiserMiddlewareTests.cs new file mode 100644 index 00000000..91c1d011 --- /dev/null +++ b/test/Ocelot.UnitTests/Request/DownstreamRequestInitialiserMiddlewareTests.cs @@ -0,0 +1,142 @@ +namespace Ocelot.UnitTests.Request +{ + using System.Net.Http; + using Microsoft.AspNetCore.Http; + using Moq; + using Ocelot.Logging; + using Ocelot.Request.Mapper; + using Ocelot.Request.Middleware; + using Ocelot.Infrastructure.RequestData; + using TestStack.BDDfy; + using Xunit; + using Ocelot.Responses; + + public class DownstreamRequestInitialiserMiddlewareTests + { + readonly DownstreamRequestInitialiserMiddleware _middleware; + + readonly Mock _httpContext; + + readonly Mock _httpRequest; + + readonly Mock _next; + + readonly Mock _requestMapper; + + readonly Mock _repo; + + readonly Mock _loggerFactory; + + readonly Mock _logger; + + Response _mappedRequest; + + public DownstreamRequestInitialiserMiddlewareTests() + { + + _httpContext = new Mock(); + _httpRequest = new Mock(); + _requestMapper = new Mock(); + _repo = new Mock(); + _next = new Mock(); + _logger = new Mock(); + + _loggerFactory = new Mock(); + _loggerFactory + .Setup(lf => lf.CreateLogger()) + .Returns(_logger.Object); + + _middleware = new DownstreamRequestInitialiserMiddleware( + _next.Object, + _loggerFactory.Object, + _repo.Object, + _requestMapper.Object); + } + + [Fact] + public void Should_handle_valid_httpRequest() + { + this.Given(_ => GivenTheHttpContextContainsARequest()) + .And(_ => GivenTheMapperWillReturnAMappedRequest()) + .When(_ => WhenTheMiddlewareIsInvoked()) + .Then(_ => ThenTheContexRequestIsMappedToADownstreamRequest()) + .And(_ => ThenTheDownstreamRequestIsStored()) + .And(_ => ThenTheNextMiddlewareIsInvoked()) + .BDDfy(); + } + + [Fact] + public void Should_handle_mapping_failure() + { + this.Given(_ => GivenTheHttpContextContainsARequest()) + .And(_ => GivenTheMapperWillReturnAnError()) + .When(_ => WhenTheMiddlewareIsInvoked()) + .And(_ => ThenTheDownstreamRequestIsNotStored()) + .And(_ => ThenAPipelineErrorIsStored()) + .And(_ => ThenTheNextMiddlewareIsNotInvoked()) + .BDDfy(); + } + + private void GivenTheHttpContextContainsARequest() + { + _httpContext + .Setup(hc => hc.Request) + .Returns(_httpRequest.Object); + } + + private void GivenTheMapperWillReturnAMappedRequest() + { + _mappedRequest = new OkResponse(new HttpRequestMessage()); + + _requestMapper + .Setup(rm => rm.Map(It.IsAny())) + .ReturnsAsync(_mappedRequest); + } + + private void GivenTheMapperWillReturnAnError() + { + _mappedRequest = new ErrorResponse(new UnmappableRequestError(new System.Exception("boooom!"))); + + _requestMapper + .Setup(rm => rm.Map(It.IsAny())) + .ReturnsAsync(_mappedRequest); + } + + private void WhenTheMiddlewareIsInvoked() + { + _middleware.Invoke(_httpContext.Object).GetAwaiter().GetResult(); + } + + private void ThenTheContexRequestIsMappedToADownstreamRequest() + { + _requestMapper.Verify(rm => rm.Map(_httpRequest.Object), Times.Once); + } + + private void ThenTheDownstreamRequestIsStored() + { + _repo.Verify(r => r.Add("DownstreamRequest", _mappedRequest.Data), Times.Once); + } + + private void ThenTheDownstreamRequestIsNotStored() + { + _repo.Verify(r => r.Add("DownstreamRequest", It.IsAny()), Times.Never); + } + + private void ThenAPipelineErrorIsStored() + { + _repo.Verify(r => r.Add("OcelotMiddlewareError", true), Times.Once); + _repo.Verify(r => r.Add("OcelotMiddlewareErrors", _mappedRequest.Errors), Times.Once); + } + + private void ThenTheNextMiddlewareIsInvoked() + { + _next.Verify(n => n(_httpContext.Object), Times.Once); + } + + private void ThenTheNextMiddlewareIsNotInvoked() + { + _next.Verify(n => n(It.IsAny()), Times.Never); + } + + } +} diff --git a/test/Ocelot.UnitTests/Request/HttpRequestBuilderMiddlewareTests.cs b/test/Ocelot.UnitTests/Request/HttpRequestBuilderMiddlewareTests.cs index 002e4e02..8cec477e 100644 --- a/test/Ocelot.UnitTests/Request/HttpRequestBuilderMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/Request/HttpRequestBuilderMiddlewareTests.cs @@ -1,13 +1,10 @@ using System; using System.Collections.Generic; using System.IO; -using System.Net; using System.Net.Http; using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using Moq; using Ocelot.Configuration.Builder; using Ocelot.DownstreamRouteFinder; @@ -19,7 +16,6 @@ using Ocelot.Request.Middleware; using Ocelot.Responses; using TestStack.BDDfy; using Xunit; -using Ocelot.Configuration; using Ocelot.Requester.QoS; namespace Ocelot.UnitTests.Request @@ -29,6 +25,7 @@ namespace Ocelot.UnitTests.Request private readonly Mock _requestBuilder; private readonly Mock _scopedRepository; private readonly Mock _qosProviderHouse; + private readonly HttpRequestMessage _downstreamRequest; private readonly string _url; private readonly TestServer _server; private readonly HttpClient _client; @@ -62,6 +59,12 @@ namespace Ocelot.UnitTests.Request app.UseHttpRequestBuilderMiddleware(); }); + _downstreamRequest = new HttpRequestMessage(); + + _scopedRepository + .Setup(sr => sr.Get("DownstreamRequest")) + .Returns(new OkResponse(_downstreamRequest)); + _server = new TestServer(builder); _client = _server.CreateClient(); } @@ -103,9 +106,9 @@ namespace Ocelot.UnitTests.Request private void GivenTheRequestBuilderReturns(Ocelot.Request.Request request) { _request = new OkResponse(request); + _requestBuilder - .Setup(x => x.Build(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), - It.IsAny(), It.IsAny(), It.IsAny(),It.IsAny(), It.IsAny())) + .Setup(x => x.Build(It.IsAny(), It.IsAny(), It.IsAny())) .ReturnsAsync(_request); } diff --git a/test/Ocelot.UnitTests/Request/HttpRequestCreatorTests.cs b/test/Ocelot.UnitTests/Request/HttpRequestCreatorTests.cs new file mode 100644 index 00000000..d831aed9 --- /dev/null +++ b/test/Ocelot.UnitTests/Request/HttpRequestCreatorTests.cs @@ -0,0 +1,56 @@ +namespace Ocelot.UnitTests.Request +{ + using System.Net.Http; + + using Ocelot.Request.Builder; + using Ocelot.Requester.QoS; + using Ocelot.Responses; + using Shouldly; + using TestStack.BDDfy; + using Xunit; + + public class HttpRequestCreatorTests + { + private readonly IRequestCreator _requestCreator; + private readonly bool _isQos; + private readonly IQoSProvider _qoSProvider; + private readonly HttpRequestMessage _requestMessage; + private Response _response; + + public HttpRequestCreatorTests() + { + _requestCreator = new HttpRequestCreator(); + _isQos = true; + _qoSProvider = new NoQoSProvider(); + _requestMessage = new HttpRequestMessage(); + } + + [Fact] + public void ShouldBuildRequest() + { + this.When(x => x.WhenIBuildARequest()) + .Then(x => x.ThenTheRequestContainsTheRequestMessage()) + .BDDfy(); + } + + private void WhenIBuildARequest() + { + _response = _requestCreator.Build(_requestMessage, _isQos, _qoSProvider).GetAwaiter().GetResult(); + } + + private void ThenTheRequestContainsTheRequestMessage() + { + _response.Data.HttpRequestMessage.ShouldBe(_requestMessage); + } + + private void ThenTheRequestContainsTheIsQos() + { + _response.Data.IsQos.ShouldBe(_isQos); + } + + private void ThenTheRequestContainsTheQosProvider() + { + _response.Data.QosProvider.ShouldBe(_qoSProvider); + } + } +} diff --git a/test/Ocelot.UnitTests/Request/Mapper/RequestMapperTests.cs b/test/Ocelot.UnitTests/Request/Mapper/RequestMapperTests.cs new file mode 100644 index 00000000..4334e017 --- /dev/null +++ b/test/Ocelot.UnitTests/Request/Mapper/RequestMapperTests.cs @@ -0,0 +1,258 @@ +namespace Ocelot.UnitTests.Request.Mapper +{ + using System.Collections.Generic; + using System.Linq; + using System.Net.Http; + + using Microsoft.AspNetCore.Http; + using Microsoft.AspNetCore.Http.Internal; + using Microsoft.Extensions.Primitives; + using Ocelot.Request.Mapper; + using Ocelot.Responses; + using TestStack.BDDfy; + using Xunit; + using Shouldly; + using System; + using System.IO; + using System.Text; + + public class RequestMapperTests + { + readonly HttpRequest _inputRequest; + + readonly RequestMapper _requestMapper; + + Response _mappedRequest; + + List> _inputHeaders = null; + + public RequestMapperTests() + { + _inputRequest = new DefaultHttpRequest(new DefaultHttpContext()); + + _requestMapper = new RequestMapper(); + } + + [Theory] + [InlineData("https", "my.url:123", "/abc/DEF", "?a=1&b=2", "https://my.url:123/abc/DEF?a=1&b=2")] + [InlineData("http", "blah.com", "/d ef", "?abc=123", "http://blah.com/d%20ef?abc=123")] // note! the input is encoded when building the input request + [InlineData("http", "myusername:mypassword@abc.co.uk", null, null, "http://myusername:mypassword@abc.co.uk/")] + [InlineData("http", "點看.com", null, null, "http://xn--c1yn36f.com/")] + [InlineData("http", "xn--c1yn36f.com", null, null, "http://xn--c1yn36f.com/")] + public void Should_map_valid_request_uri(string scheme, string host, string path, string queryString, string expectedUri) + { + this.Given(_ => GivenTheInputRequestHasMethod("GET")) + .And(_ => GivenTheInputRequestHasScheme(scheme)) + .And(_ => GivenTheInputRequestHasHost(host)) + .And(_ => GivenTheInputRequestHasPath(path)) + .And(_ => GivenTheInputRequestHasQueryString(queryString)) + .When(_ => WhenMapped()) + .Then(_ => ThenNoErrorIsReturned()) + .And(_ => ThenTheMappedRequestHasUri(expectedUri)) + .BDDfy(); + } + + [Theory] + [InlineData("ftp", "google.com", "/abc/DEF", "?a=1&b=2")] + public void Should_error_on_unsupported_request_uri(string scheme, string host, string path, string queryString) + { + this.Given(_ => GivenTheInputRequestHasMethod("GET")) + .And(_ => GivenTheInputRequestHasScheme(scheme)) + .And(_ => GivenTheInputRequestHasHost(host)) + .And(_ => GivenTheInputRequestHasPath(path)) + .And(_ => GivenTheInputRequestHasQueryString(queryString)) + .When(_ => WhenMapped()) + .Then(_ => ThenAnErrorIsReturned()) + .And(_ => ThenTheMappedRequestIsNull()) + .BDDfy(); + } + + [Theory] + [InlineData("GET")] + [InlineData("POST")] + [InlineData("WHATEVER")] + public void Should_map_method(string method) + { + this.Given(_ => GivenTheInputRequestHasMethod(method)) + .And(_ => GivenTheInputRequestHasAValidUri()) + .When(_ => WhenMapped()) + .Then(_ => ThenNoErrorIsReturned()) + .And(_ => ThenTheMappedRequestHasMethod(method)) + .BDDfy(); + } + + [Fact] + public void Should_map_all_headers() + { + this.Given(_ => GivenTheInputRequestHasHeaders()) + .And(_ => GivenTheInputRequestHasMethod("GET")) + .And(_ => GivenTheInputRequestHasAValidUri()) + .When(_ => WhenMapped()) + .Then(_ => ThenNoErrorIsReturned()) + .And(_ => ThenTheMappedRequestHasEachHeader()) + .BDDfy(); + } + + [Fact] + public void Should_handle_no_headers() + { + this.Given(_ => GivenTheInputRequestHasNoHeaders()) + .And(_ => GivenTheInputRequestHasMethod("GET")) + .And(_ => GivenTheInputRequestHasAValidUri()) + .When(_ => WhenMapped()) + .Then(_ => ThenNoErrorIsReturned()) + .And(_ => ThenTheMappedRequestHasNoHeaders()) + .BDDfy(); + } + + [Fact] + public void Should_map_content() + { + this.Given(_ => GivenTheInputRequestHasContent("This is my content")) + .And(_ => GivenTheInputRequestHasMethod("GET")) + .And(_ => GivenTheInputRequestHasAValidUri()) + .When(_ => WhenMapped()) + .Then(_ => ThenNoErrorIsReturned()) + .And(_ => ThenTheMappedRequestHasContent("This is my content")) + .BDDfy(); + } + + [Fact] + public void Should_handle_no_content() + { + this.Given(_ => GivenTheInputRequestHasNoContent()) + .And(_ => GivenTheInputRequestHasMethod("GET")) + .And(_ => GivenTheInputRequestHasAValidUri()) + .When(_ => WhenMapped()) + .Then(_ => ThenNoErrorIsReturned()) + .And(_ => ThenTheMappedRequestHasNoContent()) + .BDDfy(); + } + + private void GivenTheInputRequestHasMethod(string method) + { + _inputRequest.Method = method; + } + + private void GivenTheInputRequestHasScheme(string scheme) + { + _inputRequest.Scheme = scheme; + } + + private void GivenTheInputRequestHasHost(string host) + { + _inputRequest.Host = new HostString(host); + } + + private void GivenTheInputRequestHasPath(string path) + { + if (path != null) + { + _inputRequest.Path = path; + } + } + + private void GivenTheInputRequestHasQueryString(string querystring) + { + if (querystring != null) + { + _inputRequest.QueryString = new QueryString(querystring); + } + } + + private void GivenTheInputRequestHasAValidUri() + { + GivenTheInputRequestHasScheme("http"); + GivenTheInputRequestHasHost("www.google.com"); + } + + private void GivenTheInputRequestHasHeaders() + { + _inputHeaders = new List>() + { + new KeyValuePair("abc", new StringValues(new string[]{"123","456" })), + new KeyValuePair("def", new StringValues(new string[]{"789","012" })), + }; + + foreach (var inputHeader in _inputHeaders) + { + _inputRequest.Headers.Add(inputHeader); + } + } + + private void GivenTheInputRequestHasNoHeaders() + { + _inputRequest.Headers.Clear(); + } + + private void GivenTheInputRequestHasContent(string content) + { + _inputRequest.Body = new MemoryStream(Encoding.UTF8.GetBytes(content)); + } + + private void GivenTheInputRequestHasNoContent() + { + _inputRequest.Body = null; + } + + private void WhenMapped() + { + _mappedRequest = _requestMapper.Map(_inputRequest).GetAwaiter().GetResult(); + } + + private void ThenNoErrorIsReturned() + { + _mappedRequest.IsError.ShouldBeFalse(); + } + + private void ThenAnErrorIsReturned() + { + _mappedRequest.IsError.ShouldBeTrue(); + } + + private void ThenTheMappedRequestHasUri(string expectedUri) + { + _mappedRequest.Data.RequestUri.OriginalString.ShouldBe(expectedUri); + } + + private void ThenTheMappedRequestHasMethod(string expectedMethod) + { + _mappedRequest.Data.Method.ToString().ShouldBe(expectedMethod); + } + + private void ThenTheMappedRequestHasEachHeader() + { + _mappedRequest.Data.Headers.Count().ShouldBe(_inputHeaders.Count); + foreach(var header in _mappedRequest.Data.Headers) + { + var inputHeader = _inputHeaders.First(h => h.Key == header.Key); + inputHeader.ShouldNotBeNull(); + inputHeader.Value.Count().ShouldBe(header.Value.Count()); + foreach(var inputHeaderValue in inputHeader.Value) + { + header.Value.Any(v => v == inputHeaderValue); + } + } + } + + private void ThenTheMappedRequestHasNoHeaders() + { + _mappedRequest.Data.Headers.Count().ShouldBe(0); + } + + private void ThenTheMappedRequestHasContent(string expectedContent) + { + _mappedRequest.Data.Content.ReadAsStringAsync().GetAwaiter().GetResult().ShouldBe(expectedContent); + } + + private void ThenTheMappedRequestHasNoContent() + { + _mappedRequest.Data.Content.ShouldBeNull(); + } + + private void ThenTheMappedRequestIsNull() + { + _mappedRequest.Data.ShouldBeNull(); + } + } +} diff --git a/test/Ocelot.UnitTests/Request/RequestBuilderTests.cs b/test/Ocelot.UnitTests/Request/RequestBuilderTests.cs deleted file mode 100644 index 667c80b0..00000000 --- a/test/Ocelot.UnitTests/Request/RequestBuilderTests.cs +++ /dev/null @@ -1,310 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Net; -using System.Net.Http; -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Internal; -using Ocelot.Request.Builder; -using Ocelot.Responses; -using Shouldly; -using TestStack.BDDfy; -using Xunit; -using Ocelot.Configuration; -using Ocelot.Requester.QoS; - -namespace Ocelot.UnitTests.Request -{ - public class RequestBuilderTests - { - private string _httpMethod; - private string _downstreamUrl; - private HttpContent _content; - private IHeaderDictionary _headers; - private IRequestCookieCollection _cookies; - private QueryString _query; - private string _contentType; - private readonly IRequestCreator _requestCreator; - private Response _result; - private Ocelot.RequestId.RequestId _requestId; - private bool _isQos; - private IQoSProvider _qoSProvider; - - public RequestBuilderTests() - { - _content = new StringContent(string.Empty); - _requestCreator = new HttpRequestCreator(); - } - - [Fact] - public void should_user_downstream_url() - { - this.Given(x => x.GivenIHaveHttpMethod("GET")) - .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) - .And(x=> x.GivenTheQos(true, new NoQoSProvider())) - .When(x => x.WhenICreateARequest()) - .And(x => x.ThenTheCorrectDownstreamUrlIsUsed("http://www.bbc.co.uk/")) - .BDDfy(); - } - - [Fact] - public void should_use_http_method() - { - this.Given(x => x.GivenIHaveHttpMethod("POST")) - .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) - .And(x => x.GivenTheQos(true, new NoQoSProvider())) - - .When(x => x.WhenICreateARequest()) - .And(x => x.ThenTheCorrectHttpMethodIsUsed(HttpMethod.Post)) - .BDDfy(); - } - - [Fact] - public void should_use_http_content() - { - this.Given(x => x.GivenIHaveHttpMethod("POST")) - .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) - .And(x => x.GivenIHaveTheHttpContent(new StringContent("Hi from Tom"))) - .And(x => x.GivenTheContentTypeIs("application/json")) - .And(x => x.GivenTheQos(true, new NoQoSProvider())) - - .When(x => x.WhenICreateARequest()) - .And(x => x.ThenTheCorrectContentIsUsed(new StringContent("Hi from Tom"))) - .BDDfy(); - } - - [Fact] - public void should_use_http_content_headers() - { - this.Given(x => x.GivenIHaveHttpMethod("POST")) - .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) - .And(x => x.GivenIHaveTheHttpContent(new StringContent("Hi from Tom"))) - .And(x => x.GivenTheContentTypeIs("application/json")) - .And(x => x.GivenTheQos(true, new NoQoSProvider())) - - .When(x => x.WhenICreateARequest()) - .And(x => x.ThenTheCorrectContentHeadersAreUsed(new HeaderDictionary - { - { - "Content-Type", "application/json" - } - })) - .BDDfy(); - } - - [Fact] - public void should_use_unvalidated_http_content_headers() - { - this.Given(x => x.GivenIHaveHttpMethod("POST")) - .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) - .And(x => x.GivenIHaveTheHttpContent(new StringContent("Hi from Tom"))) - .And(x => x.GivenTheContentTypeIs("application/json; charset=utf-8")) - .And(x => x.GivenTheQos(true, new NoQoSProvider())) - - .When(x => x.WhenICreateARequest()) - .And(x => x.ThenTheCorrectContentHeadersAreUsed(new HeaderDictionary - { - { - "Content-Type", "application/json; charset=utf-8" - } - })) - .BDDfy(); - } - - [Fact] - public void should_use_headers() - { - this.Given(x => x.GivenIHaveHttpMethod("GET")) - .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) - .And(x => x.GivenTheHttpHeadersAre(new HeaderDictionary - { - {"ChopSticks", "Bubbles" } - })) - .And(x => x.GivenTheQos(true, new NoQoSProvider())) - - .When(x => x.WhenICreateARequest()) - .And(x => x.ThenTheCorrectHeadersAreUsed(new HeaderDictionary - { - {"ChopSticks", "Bubbles" } - })) - .BDDfy(); - } - - [Fact] - public void should_use_request_id() - { - var requestId = Guid.NewGuid().ToString(); - - this.Given(x => x.GivenIHaveHttpMethod("GET")) - .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) - .And(x => x.GivenTheHttpHeadersAre(new HeaderDictionary())) - .And(x => x.GivenTheRequestIdIs(new Ocelot.RequestId.RequestId("RequestId", requestId))) - .And(x => x.GivenTheQos(true, new NoQoSProvider())) - .When(x => x.WhenICreateARequest()) - .And(x => x.ThenTheCorrectHeadersAreUsed(new HeaderDictionary - { - {"RequestId", requestId } - })) - .BDDfy(); - } - - [Fact] - public void should_not_use_request_if_if_already_in_headers() - { - this.Given(x => x.GivenIHaveHttpMethod("GET")) - .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) - .And(x => x.GivenTheHttpHeadersAre(new HeaderDictionary - { - {"RequestId", "534534gv54gv45g" } - })) - .And(x => x.GivenTheRequestIdIs(new Ocelot.RequestId.RequestId("RequestId", Guid.NewGuid().ToString()))) - .And(x => x.GivenTheQos(true, new NoQoSProvider())) - .When(x => x.WhenICreateARequest()) - .And(x => x.ThenTheCorrectHeadersAreUsed(new HeaderDictionary - { - {"RequestId", "534534gv54gv45g" } - })) - .BDDfy(); - } - - [Theory] - [InlineData(null, "blahh")] - [InlineData("", "blahh")] - [InlineData("RequestId", "")] - [InlineData("RequestId", null)] - public void should_not_use_request_id(string requestIdKey, string requestIdValue) - { - this.Given(x => x.GivenIHaveHttpMethod("GET")) - .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) - .And(x => x.GivenTheHttpHeadersAre(new HeaderDictionary())) - .And(x => x.GivenTheRequestIdIs(new Ocelot.RequestId.RequestId(requestIdKey, requestIdValue))) - .And(x => x.GivenTheQos(true, new NoQoSProvider())) - .When(x => x.WhenICreateARequest()) - .And(x => x.ThenTheRequestIdIsNotInTheHeaders()) - .BDDfy(); - } - - private void GivenTheRequestIdIs(Ocelot.RequestId.RequestId requestId) - { - _requestId = requestId; - } - - private void GivenTheQos(bool isQos, IQoSProvider qoSProvider) - { - _isQos = isQos; - _qoSProvider = qoSProvider; - } - - [Fact] - public void should_user_query_string() - { - this.Given(x => x.GivenIHaveHttpMethod("POST")) - .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) - .And(x => x.GivenTheQueryStringIs(new QueryString("?jeff=1&geoff=2"))) - .When(x => x.WhenICreateARequest()) - .And(x => x.ThenTheCorrectQueryStringIsUsed("?jeff=1&geoff=2")) - .BDDfy(); - } - - private void GivenTheContentTypeIs(string contentType) - { - _contentType = contentType; - } - - private void ThenTheCorrectQueryStringIsUsed(string expected) - { - _result.Data.HttpRequestMessage.RequestUri.Query.ShouldBe(expected); - } - - private void GivenTheQueryStringIs(QueryString query) - { - _query = query; - } - - private void ThenTheCorrectCookiesAreUsed(IRequestCookieCollection expected) - { - /* var resultCookies = _result.Data.CookieContainer.GetCookies(new Uri(_downstreamUrl + _query)); - var resultDictionary = resultCookies.Cast().ToDictionary(cook => cook.Name, cook => cook.Value); - - foreach (var expectedCookie in expected) - { - var resultCookie = resultDictionary[expectedCookie.Key]; - resultCookie.ShouldBe(expectedCookie.Value); - }*/ - } - - private void GivenTheCookiesAre(IRequestCookieCollection cookies) - { - _cookies = cookies; - } - - private void ThenTheRequestIdIsNotInTheHeaders() - { - _result.Data.HttpRequestMessage.Headers.ShouldNotContain(x => x.Key == "RequestId"); - } - - private void ThenTheCorrectHeadersAreUsed(IHeaderDictionary expected) - { - var expectedHeaders = expected.Select(x => new KeyValuePair(x.Key, x.Value)); - - foreach (var expectedHeader in expectedHeaders) - { - _result.Data.HttpRequestMessage.Headers.ShouldContain(x => x.Key == expectedHeader.Key && x.Value.First() == expectedHeader.Value[0]); - } - } - - private void ThenTheCorrectContentHeadersAreUsed(IHeaderDictionary expected) - { - var expectedHeaders = expected.Select(x => new KeyValuePair(x.Key, x.Value)); - - foreach (var expectedHeader in expectedHeaders) - { - _result.Data.HttpRequestMessage.Content.Headers.ShouldContain(x => x.Key == expectedHeader.Key - && x.Value.First() == expectedHeader.Value[0] - ); - } - } - - private void GivenTheHttpHeadersAre(IHeaderDictionary headers) - { - _headers = headers; - } - - private void GivenIHaveTheHttpContent(HttpContent content) - { - _content = content; - } - - private void GivenIHaveHttpMethod(string httpMethod) - { - _httpMethod = httpMethod; - } - - private void GivenIHaveDownstreamUrl(string downstreamUrl) - { - _downstreamUrl = downstreamUrl; - } - - private void WhenICreateARequest() - { - _result = _requestCreator.Build(_httpMethod, _downstreamUrl, _content?.ReadAsStreamAsync().Result, _headers, - _query, _contentType, _requestId,_isQos,_qoSProvider).Result; - } - - - private void ThenTheCorrectDownstreamUrlIsUsed(string expected) - { - _result.Data.HttpRequestMessage.RequestUri.AbsoluteUri.ShouldBe(expected); - } - - private void ThenTheCorrectHttpMethodIsUsed(HttpMethod expected) - { - _result.Data.HttpRequestMessage.Method.Method.ShouldBe(expected.Method); - } - - private void ThenTheCorrectContentIsUsed(HttpContent expected) - { - _result.Data.HttpRequestMessage.Content.ReadAsStringAsync().Result.ShouldBe(expected.ReadAsStringAsync().Result); - } - } -} diff --git a/test/Ocelot.UnitTests/RequestId/RequestIdMiddlewareTests.cs b/test/Ocelot.UnitTests/RequestId/RequestIdMiddlewareTests.cs index 26450ab8..4521e315 100644 --- a/test/Ocelot.UnitTests/RequestId/RequestIdMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/RequestId/RequestIdMiddlewareTests.cs @@ -8,14 +8,12 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using Moq; using Ocelot.Configuration.Builder; using Ocelot.DownstreamRouteFinder; using Ocelot.DownstreamRouteFinder.UrlMatcher; using Ocelot.Infrastructure.RequestData; using Ocelot.Logging; -using Ocelot.Request.Middleware; using Ocelot.RequestId.Middleware; using Ocelot.Responses; using Shouldly; @@ -27,6 +25,7 @@ namespace Ocelot.UnitTests.RequestId public class RequestIdMiddlewareTests { private readonly Mock _scopedRepository; + private readonly HttpRequestMessage _downstreamRequest; private readonly string _url; private readonly TestServer _server; private readonly HttpClient _client; @@ -64,10 +63,16 @@ namespace Ocelot.UnitTests.RequestId _server = new TestServer(builder); _client = _server.CreateClient(); + + _downstreamRequest = new HttpRequestMessage(); + + _scopedRepository + .Setup(sr => sr.Get("DownstreamRequest")) + .Returns(new OkResponse(_downstreamRequest)); } [Fact] - public void should_add_request_id_to_repository() + public void should_pass_down_request_id_from_upstream_request() { var downstreamRoute = new DownstreamRoute(new List(), new ReRouteBuilder() @@ -86,7 +91,7 @@ namespace Ocelot.UnitTests.RequestId } [Fact] - public void should_add_trace_indentifier_to_repository() + public void should_add_request_id_when_not_on_upstream_request() { var downstreamRoute = new DownstreamRoute(new List(), new ReRouteBuilder() @@ -101,14 +106,12 @@ namespace Ocelot.UnitTests.RequestId .BDDfy(); } - private void ThenTheTraceIdIsAnything() + private void GivenTheDownStreamRouteIs(DownstreamRoute downstreamRoute) { - _result.Headers.GetValues("LSRequestId").First().ShouldNotBeNullOrEmpty(); - } - - private void ThenTheTraceIdIs(string expected) - { - _result.Headers.GetValues("LSRequestId").First().ShouldBe(expected); + _downstreamRoute = new OkResponse(downstreamRoute); + _scopedRepository + .Setup(x => x.Get(It.IsAny())) + .Returns(_downstreamRoute); } private void GivenTheRequestIdIsAddedToTheRequest(string key, string value) @@ -123,12 +126,14 @@ namespace Ocelot.UnitTests.RequestId _result = _client.GetAsync(_url).Result; } - private void GivenTheDownStreamRouteIs(DownstreamRoute downstreamRoute) + private void ThenTheTraceIdIsAnything() { - _downstreamRoute = new OkResponse(downstreamRoute); - _scopedRepository - .Setup(x => x.Get(It.IsAny())) - .Returns(_downstreamRoute); + _result.Headers.GetValues("LSRequestId").First().ShouldNotBeNullOrEmpty(); + } + + private void ThenTheTraceIdIs(string expected) + { + _result.Headers.GetValues("LSRequestId").First().ShouldBe(expected); } public void Dispose()