diff --git a/src/Ocelot/Authentication/Middleware/AuthenticationMiddleware.cs b/src/Ocelot/Authentication/Middleware/AuthenticationMiddleware.cs index a6eb9cc2..f29161ce 100644 --- a/src/Ocelot/Authentication/Middleware/AuthenticationMiddleware.cs +++ b/src/Ocelot/Authentication/Middleware/AuthenticationMiddleware.cs @@ -1,97 +1,97 @@ -using System.Collections.Generic; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Http; -using Ocelot.Authentication.Handler.Factory; -using Ocelot.Configuration; -using Ocelot.Errors; -using Ocelot.Infrastructure.Extensions; -using Ocelot.Infrastructure.RequestData; -using Ocelot.Logging; -using Ocelot.Middleware; - -namespace Ocelot.Authentication.Middleware -{ - public class AuthenticationMiddleware : OcelotMiddleware - { - private readonly RequestDelegate _next; - private readonly IApplicationBuilder _app; - private readonly IAuthenticationHandlerFactory _authHandlerFactory; - private readonly IOcelotLogger _logger; - - public AuthenticationMiddleware(RequestDelegate next, - IApplicationBuilder app, - IRequestScopedDataRepository requestScopedDataRepository, - IAuthenticationHandlerFactory authHandlerFactory, - IOcelotLoggerFactory loggerFactory) - : base(requestScopedDataRepository) - { - _next = next; - _authHandlerFactory = authHandlerFactory; - _app = app; - _logger = loggerFactory.CreateLogger(); - } - - public async Task Invoke(HttpContext context) - { - _logger.TraceMiddlewareEntry(); - - if (IsAuthenticatedRoute(DownstreamRoute.ReRoute)) - { - _logger.LogDebug($"{context.Request.Path} is an authenticated route. {MiddlwareName} checking if client is authenticated"); - - var authenticationHandler = _authHandlerFactory.Get(_app, DownstreamRoute.ReRoute.AuthenticationOptions); - - if (authenticationHandler.IsError) +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Ocelot.Authentication.Handler.Factory; +using Ocelot.Configuration; +using Ocelot.Errors; +using Ocelot.Infrastructure.Extensions; +using Ocelot.Infrastructure.RequestData; +using Ocelot.Logging; +using Ocelot.Middleware; + +namespace Ocelot.Authentication.Middleware +{ + public class AuthenticationMiddleware : OcelotMiddleware + { + private readonly RequestDelegate _next; + private readonly IApplicationBuilder _app; + private readonly IAuthenticationHandlerFactory _authHandlerFactory; + private readonly IOcelotLogger _logger; + + public AuthenticationMiddleware(RequestDelegate next, + IApplicationBuilder app, + IRequestScopedDataRepository requestScopedDataRepository, + IAuthenticationHandlerFactory authHandlerFactory, + IOcelotLoggerFactory loggerFactory) + : base(requestScopedDataRepository) + { + _next = next; + _authHandlerFactory = authHandlerFactory; + _app = app; + _logger = loggerFactory.CreateLogger(); + } + + public async Task Invoke(HttpContext context) + { + _logger.TraceMiddlewareEntry(); + + if (IsAuthenticatedRoute(DownstreamRoute.ReRoute)) + { + _logger.LogDebug($"{context.Request.Path} is an authenticated route. {MiddlwareName} checking if client is authenticated"); + + var authenticationHandler = _authHandlerFactory.Get(_app, DownstreamRoute.ReRoute.AuthenticationOptions); + + if (authenticationHandler.IsError) { - _logger.LogError($"Error getting authentication handler for {context.Request.Path}. {authenticationHandler.Errors.ToErrorString()}"); - SetPipelineError(authenticationHandler.Errors); - _logger.TraceMiddlewareCompleted(); - return; + _logger.LogError($"Error getting authentication handler for {context.Request.Path}. {authenticationHandler.Errors.ToErrorString()}"); + SetPipelineError(authenticationHandler.Errors); + _logger.TraceMiddlewareCompleted(); + return; } - await authenticationHandler.Data.Handler.Handle(context); - - - if (context.User.Identity.IsAuthenticated) - { - _logger.LogDebug($"Client has been authenticated for {context.Request.Path}"); - - _logger.TraceInvokeNext(); - await _next.Invoke(context); - _logger.TraceInvokeNextCompleted(); - _logger.TraceMiddlewareCompleted(); - } - else - { - var error = new List - { - new UnauthenticatedError( - $"Request for authenticated route {context.Request.Path} by {context.User.Identity.Name} was unauthenticated") - }; - - _logger.LogError($"Client has NOT been authenticated for {context.Request.Path} and pipeline error set. {error.ToErrorString()}"); - SetPipelineError(error); - - _logger.TraceMiddlewareCompleted(); - return; - } - } - else + await authenticationHandler.Data.Handler.Handle(context); + + + if (context.User.Identity.IsAuthenticated) + { + _logger.LogDebug($"Client has been authenticated for {context.Request.Path}"); + + _logger.TraceInvokeNext(); + await _next.Invoke(context); + _logger.TraceInvokeNextCompleted(); + _logger.TraceMiddlewareCompleted(); + } + else + { + var error = new List + { + new UnauthenticatedError( + $"Request for authenticated route {context.Request.Path} by {context.User.Identity.Name} was unauthenticated") + }; + + _logger.LogError($"Client has NOT been authenticated for {context.Request.Path} and pipeline error set. {error.ToErrorString()}"); + SetPipelineError(error); + + _logger.TraceMiddlewareCompleted(); + return; + } + } + else { _logger.LogTrace($"No authentication needed for {context.Request.Path}"); _logger.TraceInvokeNext(); await _next.Invoke(context); _logger.TraceInvokeNextCompleted(); - _logger.TraceMiddlewareCompleted(); - } - } - - private static bool IsAuthenticatedRoute(ReRoute reRoute) - { - return reRoute.IsAuthenticated; - } - } -} - + _logger.TraceMiddlewareCompleted(); + } + } + + private static bool IsAuthenticatedRoute(ReRoute reRoute) + { + return reRoute.IsAuthenticated; + } + } +} + diff --git a/src/Ocelot/Cache/Middleware/OutputCacheMiddleware.cs b/src/Ocelot/Cache/Middleware/OutputCacheMiddleware.cs index 948a7397..d671f182 100644 --- a/src/Ocelot/Cache/Middleware/OutputCacheMiddleware.cs +++ b/src/Ocelot/Cache/Middleware/OutputCacheMiddleware.cs @@ -27,7 +27,7 @@ namespace Ocelot.Cache.Middleware public async Task Invoke(HttpContext context) { - var downstreamUrlKey = DownstreamUrl; + var downstreamUrlKey = DownstreamRequest.RequestUri.OriginalString; if (!DownstreamRoute.ReRoute.IsCached) { diff --git a/src/Ocelot/DownstreamUrlCreator/Middleware/DownstreamUrlCreatorMiddleware.cs b/src/Ocelot/DownstreamUrlCreator/Middleware/DownstreamUrlCreatorMiddleware.cs index 631e278a..7cabd305 100644 --- a/src/Ocelot/DownstreamUrlCreator/Middleware/DownstreamUrlCreatorMiddleware.cs +++ b/src/Ocelot/DownstreamUrlCreator/Middleware/DownstreamUrlCreatorMiddleware.cs @@ -4,6 +4,7 @@ using Ocelot.DownstreamUrlCreator.UrlTemplateReplacer; using Ocelot.Infrastructure.RequestData; using Ocelot.Logging; using Ocelot.Middleware; +using System; namespace Ocelot.DownstreamUrlCreator.Middleware { @@ -42,23 +43,31 @@ namespace Ocelot.DownstreamUrlCreator.Middleware return; } - var dsScheme = DownstreamRoute.ReRoute.DownstreamScheme; - - var dsHostAndPort = HostAndPort; + //var dsScheme = DownstreamRoute.ReRoute.DownstreamScheme; - var dsUrl = _urlBuilder.Build(dsPath.Data.Value, dsScheme, dsHostAndPort); + //var dsHostAndPort = HostAndPort; - if (dsUrl.IsError) + //var dsUrl = _urlBuilder.Build(dsPath.Data.Value, dsScheme, dsHostAndPort); + + //if (dsUrl.IsError) + //{ + // _logger.LogDebug("IUrlBuilder returned an error, setting pipeline error"); + + // SetPipelineError(dsUrl.Errors); + // return; + //} + + var uriBuilder = new UriBuilder(DownstreamRequest.RequestUri) { - _logger.LogDebug("IUrlBuilder returned an error, setting pipeline error"); + Path = dsPath.Data.Value, + Scheme = DownstreamRoute.ReRoute.DownstreamScheme + }; - SetPipelineError(dsUrl.Errors); - return; - } + DownstreamRequest.RequestUri = uriBuilder.Uri; - _logger.LogDebug("downstream url is {downstreamUrl.Data.Value}", dsUrl.Data.Value); + _logger.LogDebug("downstream url is {downstreamUrl.Data.Value}", DownstreamRequest.RequestUri); - SetDownstreamUrlForThisRequest(dsUrl.Data.Value); + //SetDownstreamUrlForThisRequest(dsUrl.Data.Value); _logger.LogDebug("calling next middleware"); diff --git a/src/Ocelot/DownstreamUrlCreator/UrlBuilder.cs b/src/Ocelot/DownstreamUrlCreator/UrlBuilder.cs index 43a63715..eed7b8d7 100644 --- a/src/Ocelot/DownstreamUrlCreator/UrlBuilder.cs +++ b/src/Ocelot/DownstreamUrlCreator/UrlBuilder.cs @@ -25,6 +25,7 @@ namespace Ocelot.DownstreamUrlCreator return new ErrorResponse(new List { new DownstreamHostNullOrEmptyError() }); } + var builder = new UriBuilder { Host = downstreamHostAndPort.DownstreamHost, diff --git a/src/Ocelot/Headers/AddHeadersToRequest.cs b/src/Ocelot/Headers/AddHeadersToRequest.cs index 97cd3e69..791860b5 100644 --- a/src/Ocelot/Headers/AddHeadersToRequest.cs +++ b/src/Ocelot/Headers/AddHeadersToRequest.cs @@ -1,10 +1,9 @@ using System.Collections.Generic; using System.Linq; -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Primitives; using Ocelot.Configuration; using Ocelot.Infrastructure.Claims.Parser; using Ocelot.Responses; +using System.Net.Http; namespace Ocelot.Headers { @@ -17,25 +16,49 @@ namespace Ocelot.Headers _claimsParser = claimsParser; } - public Response SetHeadersOnContext(List claimsToThings, HttpContext context) + //public Response SetHeadersOnContext(List claimsToThings, HttpContext context) + //{ + // foreach (var config in claimsToThings) + // { + // var value = _claimsParser.GetValue(context.User.Claims, config.NewKey, config.Delimiter, config.Index); + + // if (value.IsError) + // { + // return new ErrorResponse(value.Errors); + // } + + // var exists = context.Request.Headers.FirstOrDefault(x => x.Key == config.ExistingKey); + + // if (!string.IsNullOrEmpty(exists.Key)) + // { + // context.Request.Headers.Remove(exists); + // } + + // context.Request.Headers.Add(config.ExistingKey, new StringValues(value.Data)); + // } + + // return new OkResponse(); + //} + + public Response SetHeadersOnDownstreamRequest(List claimsToThings, IEnumerable claims, HttpRequestMessage downstreamRequest) { foreach (var config in claimsToThings) { - var value = _claimsParser.GetValue(context.User.Claims, config.NewKey, config.Delimiter, config.Index); + var value = _claimsParser.GetValue(claims, config.NewKey, config.Delimiter, config.Index); if (value.IsError) { return new ErrorResponse(value.Errors); } - var exists = context.Request.Headers.FirstOrDefault(x => x.Key == config.ExistingKey); + var exists = downstreamRequest.Headers.FirstOrDefault(x => x.Key == config.ExistingKey); if (!string.IsNullOrEmpty(exists.Key)) { - context.Request.Headers.Remove(exists); + downstreamRequest.Headers.Remove(exists.Key); } - context.Request.Headers.Add(config.ExistingKey, new StringValues(value.Data)); + downstreamRequest.Headers.Add(config.ExistingKey, value.Data); } return new OkResponse(); diff --git a/src/Ocelot/Headers/IAddHeadersToRequest.cs b/src/Ocelot/Headers/IAddHeadersToRequest.cs index 3bf786a4..a819bbf1 100644 --- a/src/Ocelot/Headers/IAddHeadersToRequest.cs +++ b/src/Ocelot/Headers/IAddHeadersToRequest.cs @@ -2,12 +2,15 @@ using Microsoft.AspNetCore.Http; using Ocelot.Configuration; using Ocelot.Responses; +using System.Net.Http; namespace Ocelot.Headers { public interface IAddHeadersToRequest { - Response SetHeadersOnContext(List claimsToThings, - HttpContext context); + //Response SetHeadersOnContext(List claimsToThings, + // HttpContext context); + + Response SetHeadersOnDownstreamRequest(List claimsToThings, IEnumerable claims, HttpRequestMessage downstreamRequest); } } diff --git a/src/Ocelot/Headers/Middleware/HttpRequestHeadersBuilderMiddleware.cs b/src/Ocelot/Headers/Middleware/HttpRequestHeadersBuilderMiddleware.cs index a89d2ec2..380fcc7d 100644 --- a/src/Ocelot/Headers/Middleware/HttpRequestHeadersBuilderMiddleware.cs +++ b/src/Ocelot/Headers/Middleware/HttpRequestHeadersBuilderMiddleware.cs @@ -32,7 +32,8 @@ namespace Ocelot.Headers.Middleware { _logger.LogDebug("this route has instructions to convert claims to headers"); - var response = _addHeadersToRequest.SetHeadersOnContext(DownstreamRoute.ReRoute.ClaimsToHeaders, context); + //var response = _addHeadersToRequest.SetHeadersOnContext(DownstreamRoute.ReRoute.ClaimsToHeaders, context); + var response = _addHeadersToRequest.SetHeadersOnDownstreamRequest(DownstreamRoute.ReRoute.ClaimsToHeaders, context.User.Claims, DownstreamRequest); if (response.IsError) { diff --git a/src/Ocelot/LoadBalancer/Middleware/LoadBalancingMiddleware.cs b/src/Ocelot/LoadBalancer/Middleware/LoadBalancingMiddleware.cs index 8e26cbf2..d74759e9 100644 --- a/src/Ocelot/LoadBalancer/Middleware/LoadBalancingMiddleware.cs +++ b/src/Ocelot/LoadBalancer/Middleware/LoadBalancingMiddleware.cs @@ -44,7 +44,14 @@ namespace Ocelot.LoadBalancer.Middleware return; } - SetHostAndPortForThisRequest(hostAndPort.Data); + //SetHostAndPortForThisRequest(hostAndPort.Data); + var uriBuilder = new UriBuilder(DownstreamRequest.RequestUri); + uriBuilder.Host = hostAndPort.Data.DownstreamHost; + if (hostAndPort.Data.DownstreamPort > 0) + { + uriBuilder.Port = hostAndPort.Data.DownstreamPort; + } + DownstreamRequest.RequestUri = uriBuilder.Uri; _logger.LogDebug("calling next middleware"); diff --git a/src/Ocelot/Middleware/OcelotMiddleware.cs b/src/Ocelot/Middleware/OcelotMiddleware.cs index 2926e543..373f1472 100644 --- a/src/Ocelot/Middleware/OcelotMiddleware.cs +++ b/src/Ocelot/Middleware/OcelotMiddleware.cs @@ -46,14 +46,16 @@ namespace Ocelot.Middleware } } - public string DownstreamUrl - { - get - { - var downstreamUrl = _requestScopedDataRepository.Get("DownstreamUrl"); - return downstreamUrl.Data; - } - } + //public string DownstreamUrl + //{ + // get + // { + // var downstreamUrl = _requestScopedDataRepository.Get("DownstreamUrl"); + // return downstreamUrl.Data; + // } + //} + + public HttpRequestMessage DownstreamRequest => _requestScopedDataRepository.Get("DownstreamRequest").Data; public Request.Request Request { @@ -73,18 +75,23 @@ namespace Ocelot.Middleware } } - public HostAndPort HostAndPort - { - get - { - var hostAndPort = _requestScopedDataRepository.Get("HostAndPort"); - return hostAndPort.Data; - } - } + //public HostAndPort HostAndPort + //{ + // get + // { + // var hostAndPort = _requestScopedDataRepository.Get("HostAndPort"); + // return hostAndPort.Data; + // } + //} - public void SetHostAndPortForThisRequest(HostAndPort hostAndPort) + //public void SetHostAndPortForThisRequest(HostAndPort hostAndPort) + //{ + // _requestScopedDataRepository.Add("HostAndPort", hostAndPort); + //} + + public void SetDownstreamRequest(HttpRequestMessage request) { - _requestScopedDataRepository.Add("HostAndPort", hostAndPort); + _requestScopedDataRepository.Add("DownstreamRequest", request); } public void SetDownstreamRouteForThisRequest(DownstreamRoute downstreamRoute) @@ -92,10 +99,10 @@ namespace Ocelot.Middleware _requestScopedDataRepository.Add("DownstreamRoute", downstreamRoute); } - public void SetDownstreamUrlForThisRequest(string downstreamUrl) - { - _requestScopedDataRepository.Add("DownstreamUrl", downstreamUrl); - } + //public void SetDownstreamUrlForThisRequest(string downstreamUrl) + //{ + // _requestScopedDataRepository.Add("DownstreamUrl", downstreamUrl); + //} public void SetUpstreamRequestForThisRequest(Request.Request request) { diff --git a/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs b/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs index 457c2448..105be8dd 100644 --- a/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs +++ b/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs @@ -53,6 +53,9 @@ namespace Ocelot.Middleware { await CreateAdministrationArea(builder); + // Initialises downstream request + builder.UseDownstreamRequestInitialiser(); + // This is registered to catch any global exceptions that are not handled builder.UseExceptionHandlerMiddleware(); diff --git a/src/Ocelot/QueryStrings/AddQueriesToRequest.cs b/src/Ocelot/QueryStrings/AddQueriesToRequest.cs index 02fcb63d..ecf07715 100644 --- a/src/Ocelot/QueryStrings/AddQueriesToRequest.cs +++ b/src/Ocelot/QueryStrings/AddQueriesToRequest.cs @@ -4,6 +4,9 @@ using Microsoft.AspNetCore.Http; using Ocelot.Configuration; using Ocelot.Infrastructure.Claims.Parser; using Ocelot.Responses; +using System.Security.Claims; +using System.Net.Http; +using System; namespace Ocelot.QueryStrings { @@ -16,13 +19,44 @@ namespace Ocelot.QueryStrings _claimsParser = claimsParser; } - public Response SetQueriesOnContext(List claimsToThings, HttpContext context) + //public Response SetQueriesOnContext(List claimsToThings, HttpContext context) + //{ + // var queryDictionary = ConvertQueryStringToDictionary(context.Request.QueryString); + + // foreach (var config in claimsToThings) + // { + // var value = _claimsParser.GetValue(context.User.Claims, config.NewKey, config.Delimiter, config.Index); + + // if (value.IsError) + // { + // return new ErrorResponse(value.Errors); + // } + + // var exists = queryDictionary.FirstOrDefault(x => x.Key == config.ExistingKey); + + // if (!string.IsNullOrEmpty(exists.Key)) + // { + // queryDictionary[exists.Key] = value.Data; + // } + // else + // { + // queryDictionary.Add(config.ExistingKey, value.Data); + // } + // } + + // context.Request.QueryString = ConvertDictionaryToQueryString(queryDictionary); + + // return new OkResponse(); + //} + + public Response SetQueriesOnDownstreamRequest(List claimsToThings, IEnumerable claims, HttpRequestMessage downstreamRequest) { - var queryDictionary = ConvertQueryStringToDictionary(context); + + var queryDictionary = ConvertQueryStringToDictionary(downstreamRequest.RequestUri.Query); foreach (var config in claimsToThings) { - var value = _claimsParser.GetValue(context.User.Claims, config.NewKey, config.Delimiter, config.Index); + var value = _claimsParser.GetValue(claims, config.NewKey, config.Delimiter, config.Index); if (value.IsError) { @@ -41,22 +75,37 @@ namespace Ocelot.QueryStrings } } - context.Request.QueryString = ConvertDictionaryToQueryString(queryDictionary); + var uriBuilder = new UriBuilder(downstreamRequest.RequestUri); + uriBuilder.Query = ConvertDictionaryToQueryString(queryDictionary); + + downstreamRequest.RequestUri = uriBuilder.Uri; return new OkResponse(); } - private Dictionary ConvertQueryStringToDictionary(HttpContext context) + //private Dictionary ConvertQueryStringToDictionary(HttpContext context) + //{ + // return Microsoft.AspNetCore.WebUtilities.QueryHelpers.ParseQuery(context.Request.QueryString.Value) + // .ToDictionary(q => q.Key, q => q.Value.FirstOrDefault() ?? string.Empty); + //} + + private Dictionary ConvertQueryStringToDictionary(string queryString) { - return Microsoft.AspNetCore.WebUtilities.QueryHelpers.ParseQuery(context.Request.QueryString.Value) + return Microsoft.AspNetCore.WebUtilities.QueryHelpers + .ParseQuery(queryString) .ToDictionary(q => q.Key, q => q.Value.FirstOrDefault() ?? string.Empty); } - private Microsoft.AspNetCore.Http.QueryString ConvertDictionaryToQueryString(Dictionary queryDictionary) - { - var newQueryString = Microsoft.AspNetCore.WebUtilities.QueryHelpers.AddQueryString("", queryDictionary); + //private Microsoft.AspNetCore.Http.QueryString ConvertDictionaryToQueryString(Dictionary queryDictionary) + //{ + // var newQueryString = Microsoft.AspNetCore.WebUtilities.QueryHelpers.AddQueryString("", queryDictionary); - return new Microsoft.AspNetCore.Http.QueryString(newQueryString); + // return new Microsoft.AspNetCore.Http.QueryString(newQueryString); + //} + + private string ConvertDictionaryToQueryString(Dictionary queryDictionary) + { + return Microsoft.AspNetCore.WebUtilities.QueryHelpers.AddQueryString("", queryDictionary); } } } \ No newline at end of file diff --git a/src/Ocelot/QueryStrings/IAddQueriesToRequest.cs b/src/Ocelot/QueryStrings/IAddQueriesToRequest.cs index 6fa1b8da..606732c6 100644 --- a/src/Ocelot/QueryStrings/IAddQueriesToRequest.cs +++ b/src/Ocelot/QueryStrings/IAddQueriesToRequest.cs @@ -2,12 +2,16 @@ using Microsoft.AspNetCore.Http; using Ocelot.Configuration; using Ocelot.Responses; +using System.Net.Http; +using System.Security.Claims; namespace Ocelot.QueryStrings { public interface IAddQueriesToRequest { - Response SetQueriesOnContext(List claimsToThings, - HttpContext context); + //Response SetQueriesOnContext(List claimsToThings, + // HttpContext context); + + Response SetQueriesOnDownstreamRequest(List claimsToThings, IEnumerable claims, HttpRequestMessage downstreamRequest); } } diff --git a/src/Ocelot/QueryStrings/Middleware/QueryStringBuilderMiddleware.cs b/src/Ocelot/QueryStrings/Middleware/QueryStringBuilderMiddleware.cs index edeee51c..80db18b4 100644 --- a/src/Ocelot/QueryStrings/Middleware/QueryStringBuilderMiddleware.cs +++ b/src/Ocelot/QueryStrings/Middleware/QueryStringBuilderMiddleware.cs @@ -32,7 +32,8 @@ namespace Ocelot.QueryStrings.Middleware { _logger.LogDebug("this route has instructions to convert claims to queries"); - var response = _addQueriesToRequest.SetQueriesOnContext(DownstreamRoute.ReRoute.ClaimsToQueries, context); + //var response = _addQueriesToRequest.SetQueriesOnContext(DownstreamRoute.ReRoute.ClaimsToQueries, context); + var response = _addQueriesToRequest.SetQueriesOnDownstreamRequest(DownstreamRoute.ReRoute.ClaimsToQueries, context.User.Claims, DownstreamRequest); if (response.IsError) { diff --git a/src/Ocelot/RateLimit/Middleware/ClientRateLimitMiddleware.cs b/src/Ocelot/RateLimit/Middleware/ClientRateLimitMiddleware.cs index dffc6448..f29dcb82 100644 --- a/src/Ocelot/RateLimit/Middleware/ClientRateLimitMiddleware.cs +++ b/src/Ocelot/RateLimit/Middleware/ClientRateLimitMiddleware.cs @@ -111,11 +111,7 @@ namespace Ocelot.RateLimit.Middleware public bool IsWhitelisted(ClientRequestIdentity requestIdentity, RateLimitOptions option) { - if (option.ClientWhitelist.Contains(requestIdentity.ClientId)) - { - return true; - } - return false; + return option.ClientWhitelist.Contains(requestIdentity.ClientId); } public virtual void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule) diff --git a/src/Ocelot/Request/Builder/HttpRequestCreator.cs b/src/Ocelot/Request/Builder/HttpRequestCreator.cs index de030f83..b4aa57aa 100644 --- a/src/Ocelot/Request/Builder/HttpRequestCreator.cs +++ b/src/Ocelot/Request/Builder/HttpRequestCreator.cs @@ -4,35 +4,44 @@ using Microsoft.AspNetCore.Http; using Ocelot.Responses; using Ocelot.Configuration; using Ocelot.Requester.QoS; +using System.Net.Http; namespace Ocelot.Request.Builder { public sealed class HttpRequestCreator : IRequestCreator { + //public async Task> Build( + // string httpMethod, + // string downstreamUrl, + // Stream content, + // IHeaderDictionary headers, + // QueryString queryString, + // string contentType, + // RequestId.RequestId requestId, + // bool isQos, + // IQoSProvider qosProvider) + //{ + // var request = await new RequestBuilder() + // .WithHttpMethod(httpMethod) + // .WithDownstreamUrl(downstreamUrl) + // .WithQueryString(queryString) + // .WithContent(content) + // .WithContentType(contentType) + // .WithHeaders(headers) + // .WithRequestId(requestId) + // .WithIsQos(isQos) + // .WithQos(qosProvider) + // .Build(); + + // return new OkResponse(request); + //} + public async Task> Build( - string httpMethod, - string downstreamUrl, - Stream content, - IHeaderDictionary headers, - QueryString queryString, - string contentType, - RequestId.RequestId requestId, + HttpRequestMessage httpRequestMessage, bool isQos, IQoSProvider qosProvider) { - var request = await new RequestBuilder() - .WithHttpMethod(httpMethod) - .WithDownstreamUrl(downstreamUrl) - .WithQueryString(queryString) - .WithContent(content) - .WithContentType(contentType) - .WithHeaders(headers) - .WithRequestId(requestId) - .WithIsQos(isQos) - .WithQos(qosProvider) - .Build(); - - return new OkResponse(request); + return new OkResponse(new Request(httpRequestMessage, isQos, qosProvider)); } } } \ No newline at end of file diff --git a/src/Ocelot/Request/Builder/IRequestCreator.cs b/src/Ocelot/Request/Builder/IRequestCreator.cs index 379f0aac..d4eae34f 100644 --- a/src/Ocelot/Request/Builder/IRequestCreator.cs +++ b/src/Ocelot/Request/Builder/IRequestCreator.cs @@ -3,18 +3,24 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Ocelot.Requester.QoS; using Ocelot.Responses; +using System.Net.Http; namespace Ocelot.Request.Builder { public interface IRequestCreator { - Task> Build(string httpMethod, - string downstreamUrl, - Stream content, - IHeaderDictionary headers, - QueryString queryString, - string contentType, - RequestId.RequestId requestId, + //Task> Build(string httpMethod, + // string downstreamUrl, + // Stream content, + // IHeaderDictionary headers, + // QueryString queryString, + // string contentType, + // RequestId.RequestId requestId, + // bool isQos, + // IQoSProvider qosProvider); + + Task> Build( + HttpRequestMessage httpRequestMessage, bool isQos, IQoSProvider qosProvider); } diff --git a/src/Ocelot/Request/Builder/RequestBuilder.cs b/src/Ocelot/Request/Builder/RequestBuilder.cs index e47eea1f..694cb71e 100644 --- a/src/Ocelot/Request/Builder/RequestBuilder.cs +++ b/src/Ocelot/Request/Builder/RequestBuilder.cs @@ -1,177 +1,177 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Net; -using System.Net.Http; -using System.Net.Http.Headers; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Primitives; -using Ocelot.Requester.QoS; +//using System; +//using System.Collections.Generic; +//using System.IO; +//using System.Linq; +//using System.Net; +//using System.Net.Http; +//using System.Net.Http.Headers; +//using System.Threading.Tasks; +//using Microsoft.AspNetCore.Http; +//using Microsoft.Extensions.Primitives; +//using Ocelot.Requester.QoS; -namespace Ocelot.Request.Builder -{ - internal sealed class RequestBuilder - { - private HttpMethod _method; - private string _downstreamUrl; - private QueryString _queryString; - private Stream _content; - private string _contentType; - private IHeaderDictionary _headers; - private RequestId.RequestId _requestId; - private readonly string[] _unsupportedHeaders = {"host"}; - private bool _isQos; - private IQoSProvider _qoSProvider; +//namespace Ocelot.Request.Builder +//{ +// internal sealed class RequestBuilder +// { +// private HttpMethod _method; +// private string _downstreamUrl; +// private QueryString _queryString; +// private Stream _content; +// private string _contentType; +// private IHeaderDictionary _headers; +// private RequestId.RequestId _requestId; +// private readonly string[] _unsupportedHeaders = {"host"}; +// private bool _isQos; +// private IQoSProvider _qoSProvider; - public RequestBuilder WithHttpMethod(string httpMethod) - { - _method = new HttpMethod(httpMethod); - return this; - } +// public RequestBuilder WithHttpMethod(string httpMethod) +// { +// _method = new HttpMethod(httpMethod); +// return this; +// } - public RequestBuilder WithDownstreamUrl(string downstreamUrl) - { - _downstreamUrl = downstreamUrl; - return this; - } +// public RequestBuilder WithDownstreamUrl(string downstreamUrl) +// { +// _downstreamUrl = downstreamUrl; +// return this; +// } - public RequestBuilder WithQueryString(QueryString queryString) - { - _queryString = queryString; - return this; - } +// public RequestBuilder WithQueryString(QueryString queryString) +// { +// _queryString = queryString; +// return this; +// } - public RequestBuilder WithContent(Stream content) - { - _content = content; - return this; - } +// public RequestBuilder WithContent(Stream content) +// { +// _content = content; +// return this; +// } - public RequestBuilder WithContentType(string contentType) - { - _contentType = contentType; - return this; - } +// public RequestBuilder WithContentType(string contentType) +// { +// _contentType = contentType; +// return this; +// } - public RequestBuilder WithHeaders(IHeaderDictionary headers) - { - _headers = headers; - return this; - } +// public RequestBuilder WithHeaders(IHeaderDictionary headers) +// { +// _headers = headers; +// return this; +// } - public RequestBuilder WithRequestId(RequestId.RequestId requestId) - { - _requestId = requestId; - return this; - } +// public RequestBuilder WithRequestId(RequestId.RequestId requestId) +// { +// _requestId = requestId; +// return this; +// } - public RequestBuilder WithIsQos(bool isqos) - { - _isQos = isqos; - return this; - } +// public RequestBuilder WithIsQos(bool isqos) +// { +// _isQos = isqos; +// return this; +// } - public RequestBuilder WithQos(IQoSProvider qoSProvider) - { - _qoSProvider = qoSProvider; - return this; - } +// public RequestBuilder WithQos(IQoSProvider qoSProvider) +// { +// _qoSProvider = qoSProvider; +// return this; +// } - public async Task Build() - { - var uri = CreateUri(); +// public async Task Build() +// { +// var uri = CreateUri(); - var httpRequestMessage = new HttpRequestMessage(_method, uri); +// var httpRequestMessage = new HttpRequestMessage(_method, uri); - await AddContentToRequest(httpRequestMessage); +// await AddContentToRequest(httpRequestMessage); - AddContentTypeToRequest(httpRequestMessage); +// AddContentTypeToRequest(httpRequestMessage); - AddHeadersToRequest(httpRequestMessage); +// AddHeadersToRequest(httpRequestMessage); - if (ShouldAddRequestId(_requestId, httpRequestMessage.Headers)) - { - AddRequestIdHeader(_requestId, httpRequestMessage); - } +// if (ShouldAddRequestId(_requestId, httpRequestMessage.Headers)) +// { +// AddRequestIdHeader(_requestId, httpRequestMessage); +// } - return new Request(httpRequestMessage,_isQos, _qoSProvider); - } +// return new Request(httpRequestMessage,_isQos, _qoSProvider); +// } - private Uri CreateUri() - { - var uri = new Uri(string.Format("{0}{1}", _downstreamUrl, _queryString.ToUriComponent())); - return uri; - } +// private Uri CreateUri() +// { +// var uri = new Uri(string.Format("{0}{1}", _downstreamUrl, _queryString.ToUriComponent())); +// return uri; +// } - private async Task AddContentToRequest(HttpRequestMessage httpRequestMessage) - { - if (_content != null) - { - httpRequestMessage.Content = new ByteArrayContent(await ToByteArray(_content)); - } - } +// private async Task AddContentToRequest(HttpRequestMessage httpRequestMessage) +// { +// if (_content != null) +// { +// httpRequestMessage.Content = new ByteArrayContent(await ToByteArray(_content)); +// } +// } - private void AddContentTypeToRequest(HttpRequestMessage httpRequestMessage) - { - if (!string.IsNullOrEmpty(_contentType)) - { - httpRequestMessage.Content.Headers.Remove("Content-Type"); - httpRequestMessage.Content.Headers.TryAddWithoutValidation("Content-Type", _contentType); - } - } +// private void AddContentTypeToRequest(HttpRequestMessage httpRequestMessage) +// { +// if (!string.IsNullOrEmpty(_contentType)) +// { +// httpRequestMessage.Content.Headers.Remove("Content-Type"); +// httpRequestMessage.Content.Headers.TryAddWithoutValidation("Content-Type", _contentType); +// } +// } - private void AddHeadersToRequest(HttpRequestMessage httpRequestMessage) - { - if (_headers != null) - { - _headers.Remove("Content-Type"); +// private void AddHeadersToRequest(HttpRequestMessage httpRequestMessage) +// { +// if (_headers != null) +// { +// _headers.Remove("Content-Type"); - foreach (var header in _headers) - { - //todo get rid of if.. - if (IsSupportedHeader(header)) - { - httpRequestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray()); - } - } - } - } +// foreach (var header in _headers) +// { +// //todo get rid of if.. +// if (IsSupportedHeader(header)) +// { +// httpRequestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray()); +// } +// } +// } +// } - private bool IsSupportedHeader(KeyValuePair header) - { - return !_unsupportedHeaders.Contains(header.Key.ToLower()); - } +// private bool IsSupportedHeader(KeyValuePair header) +// { +// return !_unsupportedHeaders.Contains(header.Key.ToLower()); +// } - private void AddRequestIdHeader(RequestId.RequestId requestId, HttpRequestMessage httpRequestMessage) - { - httpRequestMessage.Headers.Add(requestId.RequestIdKey, requestId.RequestIdValue); - } +// private void AddRequestIdHeader(RequestId.RequestId requestId, HttpRequestMessage httpRequestMessage) +// { +// httpRequestMessage.Headers.Add(requestId.RequestIdKey, requestId.RequestIdValue); +// } - private bool RequestIdInHeaders(RequestId.RequestId requestId, HttpRequestHeaders headers) - { - IEnumerable value; - return headers.TryGetValues(requestId.RequestIdKey, out value); - } +// private bool RequestIdInHeaders(RequestId.RequestId requestId, HttpRequestHeaders headers) +// { +// IEnumerable value; +// return headers.TryGetValues(requestId.RequestIdKey, out value); +// } - private bool ShouldAddRequestId(RequestId.RequestId requestId, HttpRequestHeaders headers) - { - return !string.IsNullOrEmpty(requestId?.RequestIdKey) - && !string.IsNullOrEmpty(requestId.RequestIdValue) - && !RequestIdInHeaders(requestId, headers); - } +// private bool ShouldAddRequestId(RequestId.RequestId requestId, HttpRequestHeaders headers) +// { +// return !string.IsNullOrEmpty(requestId?.RequestIdKey) +// && !string.IsNullOrEmpty(requestId.RequestIdValue) +// && !RequestIdInHeaders(requestId, headers); +// } - private async Task ToByteArray(Stream stream) - { - using (stream) - { - using (var memStream = new MemoryStream()) - { - await stream.CopyToAsync(memStream); - return memStream.ToArray(); - } - } - } - } -} +// private async Task ToByteArray(Stream stream) +// { +// using (stream) +// { +// using (var memStream = new MemoryStream()) +// { +// await stream.CopyToAsync(memStream); +// return memStream.ToArray(); +// } +// } +// } +// } +//} diff --git a/src/Ocelot/Request/Mapper.cs b/src/Ocelot/Request/Mapper.cs new file mode 100644 index 00000000..30ad3820 --- /dev/null +++ b/src/Ocelot/Request/Mapper.cs @@ -0,0 +1,37 @@ +using System.IO; +using System.Net.Http; +using System.Threading.Tasks; + +namespace Ocelot.Request +{ + public class Mapper + { + public async Task Map(Microsoft.AspNetCore.Http.HttpRequest request) + { + var requestMessage = new HttpRequestMessage() + { + Content = new ByteArrayContent(await ToByteArray(request.Body)), + //Headers = request.Headers, + //Method = request.Method, + //Properties = request.P, + //RequestUri = request., + //Version = null + }; + + return requestMessage; + } + + private async Task ToByteArray(Stream stream) + { + using (stream) + { + using (var memStream = new MemoryStream()) + { + await stream.CopyToAsync(memStream); + return memStream.ToArray(); + } + } + } + } +} + diff --git a/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs b/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs new file mode 100644 index 00000000..39fe3b9b --- /dev/null +++ b/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs @@ -0,0 +1,46 @@ +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Ocelot.Infrastructure.RequestData; +using Ocelot.Logging; +using Ocelot.Middleware; +using Ocelot.Request.Builder; +using Ocelot.Requester.QoS; + +namespace Ocelot.Request.Middleware +{ + public class DownstreamRequestInitialiserMiddleware : OcelotMiddleware + { + private readonly RequestDelegate _next; + private readonly IRequestCreator _requestCreator; + private readonly IOcelotLogger _logger; + private readonly IQosProviderHouse _qosProviderHouse; + + public DownstreamRequestInitialiserMiddleware(RequestDelegate next, + IOcelotLoggerFactory loggerFactory, + IRequestScopedDataRepository requestScopedDataRepository, + IRequestCreator requestCreator, + IQosProviderHouse qosProviderHouse) + :base(requestScopedDataRepository) + { + _next = next; + _requestCreator = requestCreator; + _qosProviderHouse = qosProviderHouse; + _logger = loggerFactory.CreateLogger(); + } + + public async Task Invoke(HttpContext context) + { + _logger.LogDebug("started calling request builder middleware"); + + var mapper = new Mapper(); + + SetDownstreamRequest(await mapper.Map(context.Request)); + + _logger.LogDebug("calling next middleware"); + + await _next.Invoke(context); + + _logger.LogDebug("succesfully called next middleware"); + } + } +} \ No newline at end of file diff --git a/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddleware.cs b/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddleware.cs index 0701e089..fa92c8e1 100644 --- a/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddleware.cs +++ b/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddleware.cs @@ -43,14 +43,19 @@ namespace Ocelot.Request.Middleware return; } - var buildResult = await _requestCreator - .Build(context.Request.Method, - DownstreamUrl, - context.Request.Body, - context.Request.Headers, - context.Request.QueryString, - context.Request.ContentType, - new RequestId.RequestId(DownstreamRoute?.ReRoute?.RequestIdKey, context.TraceIdentifier), + //var buildResult = await _requestCreator + // .Build(context.Request.Method, + // DownstreamUrl, + // context.Request.Body, + // context.Request.Headers, + // context.Request.QueryString, + // context.Request.ContentType, + // new RequestId.RequestId(DownstreamRoute?.ReRoute?.RequestIdKey, context.TraceIdentifier), + // DownstreamRoute.ReRoute.IsQos, + // qosProvider.Data); + + var buildResult = await _requestCreator.Build( + DownstreamRequest, DownstreamRoute.ReRoute.IsQos, qosProvider.Data); diff --git a/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddlewareExtensions.cs b/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddlewareExtensions.cs index 4c08afec..20bb9164 100644 --- a/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddlewareExtensions.cs +++ b/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddlewareExtensions.cs @@ -8,5 +8,10 @@ namespace Ocelot.Request.Middleware { return builder.UseMiddleware(); } + + public static IApplicationBuilder UseDownstreamRequestInitialiser(this IApplicationBuilder builder) + { + return builder.UseMiddleware(); + } } } \ No newline at end of file diff --git a/src/Ocelot/RequestId/Middleware/RequestIdMiddleware.cs b/src/Ocelot/RequestId/Middleware/RequestIdMiddleware.cs index 1e1a955c..168c377a 100644 --- a/src/Ocelot/RequestId/Middleware/RequestIdMiddleware.cs +++ b/src/Ocelot/RequestId/Middleware/RequestIdMiddleware.cs @@ -5,6 +5,9 @@ using Microsoft.Extensions.Primitives; using Ocelot.Infrastructure.RequestData; using Ocelot.Logging; using Ocelot.Middleware; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Collections.Generic; namespace Ocelot.RequestId.Middleware { @@ -30,8 +33,6 @@ namespace Ocelot.RequestId.Middleware SetOcelotRequestId(context); - _logger.LogDebug("set requestId"); - _logger.TraceInvokeNext(); await _next.Invoke(context); _logger.TraceInvokeNextCompleted(); @@ -46,15 +47,28 @@ namespace Ocelot.RequestId.Middleware { key = DownstreamRoute.ReRoute.RequestIdKey; } + + StringValues requestIds; - StringValues requestId; - - if (context.Request.Headers.TryGetValue(key, out requestId)) + if (context.Request.Headers.TryGetValue(key, out requestIds)) { - _requestScopedDataRepository.Add("RequestId", requestId.First()); + var requestId = requestIds.First(); + var downstreamRequestHeaders = DownstreamRequest.Headers; + + if (!string.IsNullOrEmpty(requestId) && + !HeaderExists(key, downstreamRequestHeaders)) + { + downstreamRequestHeaders.Add(key, requestId); + } context.TraceIdentifier = requestId; } } + + private bool HeaderExists(string headerKey, HttpRequestHeaders headers) + { + IEnumerable value; + return headers.TryGetValues(headerKey, out value); + } } } \ No newline at end of file diff --git a/test/Ocelot.UnitTests/Headers/AddHeadersToRequestTests.cs b/test/Ocelot.UnitTests/Headers/AddHeadersToRequestTests.cs index 47951859..950dc72d 100644 --- a/test/Ocelot.UnitTests/Headers/AddHeadersToRequestTests.cs +++ b/test/Ocelot.UnitTests/Headers/AddHeadersToRequestTests.cs @@ -129,7 +129,9 @@ namespace Ocelot.UnitTests.Headers private void WhenIAddHeadersToTheRequest() { - _result = _addHeadersToRequest.SetHeadersOnContext(_configuration, _context); + //_result = _addHeadersToRequest.SetHeadersOnContext(_configuration, _context); + //TODO: pass in DownstreamRequest + _result = _addHeadersToRequest.SetHeadersOnDownstreamRequest(_configuration, _context.User.Claims, null); } private void ThenTheResultIsSuccess() diff --git a/test/Ocelot.UnitTests/Headers/HttpRequestHeadersBuilderMiddlewareTests.cs b/test/Ocelot.UnitTests/Headers/HttpRequestHeadersBuilderMiddlewareTests.cs index 032a76e9..aff337fc 100644 --- a/test/Ocelot.UnitTests/Headers/HttpRequestHeadersBuilderMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/Headers/HttpRequestHeadersBuilderMiddlewareTests.cs @@ -84,17 +84,28 @@ namespace Ocelot.UnitTests.Headers private void GivenTheAddHeadersToRequestReturns() { + //_addHeaders + // .Setup(x => x.SetHeadersOnContext(It.IsAny>(), + // It.IsAny())) + // .Returns(new OkResponse()); _addHeaders - .Setup(x => x.SetHeadersOnContext(It.IsAny>(), - It.IsAny())) + .Setup(x => x.SetHeadersOnDownstreamRequest( + It.IsAny>(), + It.IsAny>(), + It.IsAny())) .Returns(new OkResponse()); } private void ThenTheAddHeadersToRequestIsCalledCorrectly() { + //_addHeaders + // .Verify(x => x.SetHeadersOnContext(It.IsAny>(), + // It.IsAny()), Times.Once); _addHeaders - .Verify(x => x.SetHeadersOnContext(It.IsAny>(), - It.IsAny()), Times.Once); + .Verify(x => x.SetHeadersOnDownstreamRequest( + It.IsAny>(), + It.IsAny>(), + It.IsAny()), Times.Once); } private void WhenICallTheMiddleware() diff --git a/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs b/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs index 99cf68b8..cf713017 100644 --- a/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs +++ b/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs @@ -128,7 +128,9 @@ namespace Ocelot.UnitTests.QueryStrings private void WhenIAddQueriesToTheRequest() { - _result = _addQueriesToRequest.SetQueriesOnContext(_configuration, _context); + //_result = _addQueriesToRequest.SetQueriesOnContext(_configuration, _context); + //TODO: set downstreamRequest + _result = _addQueriesToRequest.SetQueriesOnDownstreamRequest(_configuration, _context.User.Claims, null); } private void ThenTheResultIsSuccess() diff --git a/test/Ocelot.UnitTests/QueryStrings/QueryStringBuilderMiddlewareTests.cs b/test/Ocelot.UnitTests/QueryStrings/QueryStringBuilderMiddlewareTests.cs index f381ff1b..3b701554 100644 --- a/test/Ocelot.UnitTests/QueryStrings/QueryStringBuilderMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/QueryStrings/QueryStringBuilderMiddlewareTests.cs @@ -20,6 +20,7 @@ using Ocelot.QueryStrings.Middleware; using Ocelot.Responses; using TestStack.BDDfy; using Xunit; +using System.Security.Claims; namespace Ocelot.UnitTests.QueryStrings { @@ -82,17 +83,28 @@ namespace Ocelot.UnitTests.QueryStrings private void GivenTheAddHeadersToRequestReturns() { + //_addQueries + // .Setup(x => x.SetQueriesOnContext(It.IsAny>(), + // It.IsAny())) + // .Returns(new OkResponse()); _addQueries - .Setup(x => x.SetQueriesOnContext(It.IsAny>(), - It.IsAny())) + .Setup(x => x.SetQueriesOnDownstreamRequest( + It.IsAny>(), + It.IsAny>(), + It.IsAny())) .Returns(new OkResponse()); } private void ThenTheAddQueriesToRequestIsCalledCorrectly() { + //_addQueries + // .Verify(x => x.SetQueriesOnContext(It.IsAny>(), + // It.IsAny()), Times.Once); _addQueries - .Verify(x => x.SetQueriesOnContext(It.IsAny>(), - It.IsAny()), Times.Once); + .Verify(x => x.SetQueriesOnDownstreamRequest( + It.IsAny>(), + It.IsAny>(), + It.IsAny()), Times.Once); } private void WhenICallTheMiddleware() diff --git a/test/Ocelot.UnitTests/Request/HttpRequestBuilderMiddlewareTests.cs b/test/Ocelot.UnitTests/Request/HttpRequestBuilderMiddlewareTests.cs index 002e4e02..4bae0151 100644 --- a/test/Ocelot.UnitTests/Request/HttpRequestBuilderMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/Request/HttpRequestBuilderMiddlewareTests.cs @@ -103,9 +103,12 @@ namespace Ocelot.UnitTests.Request private void GivenTheRequestBuilderReturns(Ocelot.Request.Request request) { _request = new OkResponse(request); + //_requestBuilder + // .Setup(x => x.Build(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), + // It.IsAny(), It.IsAny(), It.IsAny(),It.IsAny(), It.IsAny())) + // .ReturnsAsync(_request); _requestBuilder - .Setup(x => x.Build(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), - It.IsAny(), It.IsAny(), It.IsAny(),It.IsAny(), It.IsAny())) + .Setup(x => x.Build(It.IsAny(), It.IsAny(), It.IsAny())) .ReturnsAsync(_request); } diff --git a/test/Ocelot.UnitTests/Request/RequestBuilderTests.cs b/test/Ocelot.UnitTests/Request/RequestBuilderTests.cs index 667c80b0..c182e47b 100644 --- a/test/Ocelot.UnitTests/Request/RequestBuilderTests.cs +++ b/test/Ocelot.UnitTests/Request/RequestBuilderTests.cs @@ -287,8 +287,11 @@ namespace Ocelot.UnitTests.Request private void WhenICreateARequest() { - _result = _requestCreator.Build(_httpMethod, _downstreamUrl, _content?.ReadAsStreamAsync().Result, _headers, - _query, _contentType, _requestId,_isQos,_qoSProvider).Result; + //_result = _requestCreator.Build(_httpMethod, _downstreamUrl, _content?.ReadAsStreamAsync().Result, _headers, + // _query, _contentType, _requestId,_isQos,_qoSProvider).Result; + + //todo: add httprequestmessage + _result = _requestCreator.Build(null, _isQos, _qoSProvider).Result; }