diff --git a/src/Ocelot/DependencyInjection/ServiceCollectionExtensions.cs b/src/Ocelot/DependencyInjection/ServiceCollectionExtensions.cs index e975dcb6..ef022923 100644 --- a/src/Ocelot/DependencyInjection/ServiceCollectionExtensions.cs +++ b/src/Ocelot/DependencyInjection/ServiceCollectionExtensions.cs @@ -31,6 +31,7 @@ using Ocelot.Middleware; using Ocelot.QueryStrings; using Ocelot.RateLimit; using Ocelot.Request.Builder; +using Ocelot.Request.Mapper; using Ocelot.Requester; using Ocelot.Requester.QoS; using Ocelot.Responder; @@ -160,6 +161,7 @@ namespace Ocelot.DependencyInjection services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddSingleton(); + services.TryAddSingleton(); // see this for why we register this as singleton http://stackoverflow.com/questions/37371264/invalidoperationexception-unable-to-resolve-service-for-type-microsoft-aspnetc // could maybe use a scoped data repository diff --git a/src/Ocelot/Errors/OcelotErrorCode.cs b/src/Ocelot/Errors/OcelotErrorCode.cs index c1c55dbb..de7960c4 100644 --- a/src/Ocelot/Errors/OcelotErrorCode.cs +++ b/src/Ocelot/Errors/OcelotErrorCode.cs @@ -28,6 +28,7 @@ UnableToFindLoadBalancerError, RequestTimedOutError, UnableToFindQoSProviderError, - UnableToSetConfigInConsulError + UnableToSetConfigInConsulError, + UnmappableRequestError } } diff --git a/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs b/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs index 105be8dd..3f98f959 100644 --- a/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs +++ b/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs @@ -53,9 +53,6 @@ namespace Ocelot.Middleware { await CreateAdministrationArea(builder); - // Initialises downstream request - builder.UseDownstreamRequestInitialiser(); - // This is registered to catch any global exceptions that are not handled builder.UseExceptionHandlerMiddleware(); @@ -65,6 +62,9 @@ namespace Ocelot.Middleware // This is registered first so it can catch any errors and issue an appropriate response builder.UseResponderMiddleware(); + // Initialises downstream request + builder.UseDownstreamRequestInitialiser(); + // Then we get the downstream route information builder.UseDownstreamRouteFinderMiddleware(); diff --git a/src/Ocelot/Request/Mapper/IRequestMapper.cs b/src/Ocelot/Request/Mapper/IRequestMapper.cs new file mode 100644 index 00000000..941a24f7 --- /dev/null +++ b/src/Ocelot/Request/Mapper/IRequestMapper.cs @@ -0,0 +1,12 @@ +namespace Ocelot.Request.Mapper +{ + using System.Net.Http; + using System.Threading.Tasks; + using Microsoft.AspNetCore.Http; + using Ocelot.Responses; + + public interface IRequestMapper + { + Task> Map(HttpRequest request); + } +} diff --git a/src/Ocelot/Request/Mapper.cs b/src/Ocelot/Request/Mapper/RequestMapper.cs similarity index 59% rename from src/Ocelot/Request/Mapper.cs rename to src/Ocelot/Request/Mapper/RequestMapper.cs index c0cc3bba..17e2afe5 100644 --- a/src/Ocelot/Request/Mapper.cs +++ b/src/Ocelot/Request/Mapper/RequestMapper.cs @@ -1,33 +1,40 @@ -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Extensions; -using Microsoft.Extensions.Primitives; -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Net.Http; -using System.Threading.Tasks; - -namespace Ocelot.Request +namespace Ocelot.Request.Mapper { - public class Mapper + using System; + using System.Collections.Generic; + using System.IO; + using System.Linq; + using System.Net.Http; + using System.Threading.Tasks; + + using Microsoft.AspNetCore.Http; + using Microsoft.AspNetCore.Http.Extensions; + using Microsoft.Extensions.Primitives; + using Ocelot.Responses; + + public class RequestMapper : IRequestMapper { private readonly string[] _unsupportedHeaders = { "host" }; - public async Task Map(HttpRequest request) + public async Task> Map(HttpRequest request) { - var requestMessage = new HttpRequestMessage() + try { - Content = await MapContent(request), - Method = MapMethod(request), - RequestUri = MapUri(request), - //Properties = null - //Version = null - }; + var requestMessage = new HttpRequestMessage() + { + Content = await MapContent(request), + Method = MapMethod(request), + RequestUri = MapUri(request) + }; - MapHeaders(request, requestMessage); + MapHeaders(request, requestMessage); - return requestMessage; + return new OkResponse(requestMessage); + } + catch (Exception ex) + { + return new ErrorResponse(new UnmappableRequestError(ex)); + } } private async Task MapContent(HttpRequest request) @@ -37,7 +44,6 @@ namespace Ocelot.Request return null; } - return new ByteArrayContent(await ToByteArray(request.Body)); } diff --git a/src/Ocelot/Request/Mapper/UnmappableRequestError.cs b/src/Ocelot/Request/Mapper/UnmappableRequestError.cs new file mode 100644 index 00000000..4a860f5b --- /dev/null +++ b/src/Ocelot/Request/Mapper/UnmappableRequestError.cs @@ -0,0 +1,12 @@ +namespace Ocelot.Request.Mapper +{ + using Ocelot.Errors; + using System; + + public class UnmappableRequestError : Error + { + public UnmappableRequestError(Exception ex) : base($"Error when parsing incoming request, exception: {ex.Message}", OcelotErrorCode.UnmappableRequestError) + { + } + } +} diff --git a/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs b/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs index 39fe3b9b..a2813c25 100644 --- a/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs +++ b/src/Ocelot/Request/Middleware/DownstreamRequestInitialiserMiddleware.cs @@ -1,40 +1,41 @@ -using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; -using Ocelot.Infrastructure.RequestData; -using Ocelot.Logging; -using Ocelot.Middleware; -using Ocelot.Request.Builder; -using Ocelot.Requester.QoS; - namespace Ocelot.Request.Middleware { + using System.Threading.Tasks; + using Microsoft.AspNetCore.Http; + + using Ocelot.Infrastructure.RequestData; + using Ocelot.Logging; + using Ocelot.Middleware; + public class DownstreamRequestInitialiserMiddleware : OcelotMiddleware { private readonly RequestDelegate _next; - private readonly IRequestCreator _requestCreator; private readonly IOcelotLogger _logger; - private readonly IQosProviderHouse _qosProviderHouse; + private readonly Mapper.IRequestMapper _requestMapper; public DownstreamRequestInitialiserMiddleware(RequestDelegate next, IOcelotLoggerFactory loggerFactory, - IRequestScopedDataRepository requestScopedDataRepository, - IRequestCreator requestCreator, - IQosProviderHouse qosProviderHouse) + IRequestScopedDataRepository requestScopedDataRepository, + Mapper.IRequestMapper requestMapper) :base(requestScopedDataRepository) { _next = next; - _requestCreator = requestCreator; - _qosProviderHouse = qosProviderHouse; _logger = loggerFactory.CreateLogger(); + _requestMapper = requestMapper; } public async Task Invoke(HttpContext context) { _logger.LogDebug("started calling request builder middleware"); - var mapper = new Mapper(); + var downstreamRequest = await _requestMapper.Map(context.Request); + if (downstreamRequest.IsError) + { + SetPipelineError(downstreamRequest.Errors); + return; + } - SetDownstreamRequest(await mapper.Map(context.Request)); + SetDownstreamRequest(downstreamRequest.Data); _logger.LogDebug("calling next middleware"); diff --git a/test/Ocelot.UnitTests/Request/DownstreamRequestInitialiserMiddlewareTests.cs b/test/Ocelot.UnitTests/Request/DownstreamRequestInitialiserMiddlewareTests.cs new file mode 100644 index 00000000..91c1d011 --- /dev/null +++ b/test/Ocelot.UnitTests/Request/DownstreamRequestInitialiserMiddlewareTests.cs @@ -0,0 +1,142 @@ +namespace Ocelot.UnitTests.Request +{ + using System.Net.Http; + using Microsoft.AspNetCore.Http; + using Moq; + using Ocelot.Logging; + using Ocelot.Request.Mapper; + using Ocelot.Request.Middleware; + using Ocelot.Infrastructure.RequestData; + using TestStack.BDDfy; + using Xunit; + using Ocelot.Responses; + + public class DownstreamRequestInitialiserMiddlewareTests + { + readonly DownstreamRequestInitialiserMiddleware _middleware; + + readonly Mock _httpContext; + + readonly Mock _httpRequest; + + readonly Mock _next; + + readonly Mock _requestMapper; + + readonly Mock _repo; + + readonly Mock _loggerFactory; + + readonly Mock _logger; + + Response _mappedRequest; + + public DownstreamRequestInitialiserMiddlewareTests() + { + + _httpContext = new Mock(); + _httpRequest = new Mock(); + _requestMapper = new Mock(); + _repo = new Mock(); + _next = new Mock(); + _logger = new Mock(); + + _loggerFactory = new Mock(); + _loggerFactory + .Setup(lf => lf.CreateLogger()) + .Returns(_logger.Object); + + _middleware = new DownstreamRequestInitialiserMiddleware( + _next.Object, + _loggerFactory.Object, + _repo.Object, + _requestMapper.Object); + } + + [Fact] + public void Should_handle_valid_httpRequest() + { + this.Given(_ => GivenTheHttpContextContainsARequest()) + .And(_ => GivenTheMapperWillReturnAMappedRequest()) + .When(_ => WhenTheMiddlewareIsInvoked()) + .Then(_ => ThenTheContexRequestIsMappedToADownstreamRequest()) + .And(_ => ThenTheDownstreamRequestIsStored()) + .And(_ => ThenTheNextMiddlewareIsInvoked()) + .BDDfy(); + } + + [Fact] + public void Should_handle_mapping_failure() + { + this.Given(_ => GivenTheHttpContextContainsARequest()) + .And(_ => GivenTheMapperWillReturnAnError()) + .When(_ => WhenTheMiddlewareIsInvoked()) + .And(_ => ThenTheDownstreamRequestIsNotStored()) + .And(_ => ThenAPipelineErrorIsStored()) + .And(_ => ThenTheNextMiddlewareIsNotInvoked()) + .BDDfy(); + } + + private void GivenTheHttpContextContainsARequest() + { + _httpContext + .Setup(hc => hc.Request) + .Returns(_httpRequest.Object); + } + + private void GivenTheMapperWillReturnAMappedRequest() + { + _mappedRequest = new OkResponse(new HttpRequestMessage()); + + _requestMapper + .Setup(rm => rm.Map(It.IsAny())) + .ReturnsAsync(_mappedRequest); + } + + private void GivenTheMapperWillReturnAnError() + { + _mappedRequest = new ErrorResponse(new UnmappableRequestError(new System.Exception("boooom!"))); + + _requestMapper + .Setup(rm => rm.Map(It.IsAny())) + .ReturnsAsync(_mappedRequest); + } + + private void WhenTheMiddlewareIsInvoked() + { + _middleware.Invoke(_httpContext.Object).GetAwaiter().GetResult(); + } + + private void ThenTheContexRequestIsMappedToADownstreamRequest() + { + _requestMapper.Verify(rm => rm.Map(_httpRequest.Object), Times.Once); + } + + private void ThenTheDownstreamRequestIsStored() + { + _repo.Verify(r => r.Add("DownstreamRequest", _mappedRequest.Data), Times.Once); + } + + private void ThenTheDownstreamRequestIsNotStored() + { + _repo.Verify(r => r.Add("DownstreamRequest", It.IsAny()), Times.Never); + } + + private void ThenAPipelineErrorIsStored() + { + _repo.Verify(r => r.Add("OcelotMiddlewareError", true), Times.Once); + _repo.Verify(r => r.Add("OcelotMiddlewareErrors", _mappedRequest.Errors), Times.Once); + } + + private void ThenTheNextMiddlewareIsInvoked() + { + _next.Verify(n => n(_httpContext.Object), Times.Once); + } + + private void ThenTheNextMiddlewareIsNotInvoked() + { + _next.Verify(n => n(It.IsAny()), Times.Never); + } + + } +} diff --git a/test/Ocelot.UnitTests/Request/Mapper/RequestMapperTests.cs b/test/Ocelot.UnitTests/Request/Mapper/RequestMapperTests.cs new file mode 100644 index 00000000..4334e017 --- /dev/null +++ b/test/Ocelot.UnitTests/Request/Mapper/RequestMapperTests.cs @@ -0,0 +1,258 @@ +namespace Ocelot.UnitTests.Request.Mapper +{ + using System.Collections.Generic; + using System.Linq; + using System.Net.Http; + + using Microsoft.AspNetCore.Http; + using Microsoft.AspNetCore.Http.Internal; + using Microsoft.Extensions.Primitives; + using Ocelot.Request.Mapper; + using Ocelot.Responses; + using TestStack.BDDfy; + using Xunit; + using Shouldly; + using System; + using System.IO; + using System.Text; + + public class RequestMapperTests + { + readonly HttpRequest _inputRequest; + + readonly RequestMapper _requestMapper; + + Response _mappedRequest; + + List> _inputHeaders = null; + + public RequestMapperTests() + { + _inputRequest = new DefaultHttpRequest(new DefaultHttpContext()); + + _requestMapper = new RequestMapper(); + } + + [Theory] + [InlineData("https", "my.url:123", "/abc/DEF", "?a=1&b=2", "https://my.url:123/abc/DEF?a=1&b=2")] + [InlineData("http", "blah.com", "/d ef", "?abc=123", "http://blah.com/d%20ef?abc=123")] // note! the input is encoded when building the input request + [InlineData("http", "myusername:mypassword@abc.co.uk", null, null, "http://myusername:mypassword@abc.co.uk/")] + [InlineData("http", "點看.com", null, null, "http://xn--c1yn36f.com/")] + [InlineData("http", "xn--c1yn36f.com", null, null, "http://xn--c1yn36f.com/")] + public void Should_map_valid_request_uri(string scheme, string host, string path, string queryString, string expectedUri) + { + this.Given(_ => GivenTheInputRequestHasMethod("GET")) + .And(_ => GivenTheInputRequestHasScheme(scheme)) + .And(_ => GivenTheInputRequestHasHost(host)) + .And(_ => GivenTheInputRequestHasPath(path)) + .And(_ => GivenTheInputRequestHasQueryString(queryString)) + .When(_ => WhenMapped()) + .Then(_ => ThenNoErrorIsReturned()) + .And(_ => ThenTheMappedRequestHasUri(expectedUri)) + .BDDfy(); + } + + [Theory] + [InlineData("ftp", "google.com", "/abc/DEF", "?a=1&b=2")] + public void Should_error_on_unsupported_request_uri(string scheme, string host, string path, string queryString) + { + this.Given(_ => GivenTheInputRequestHasMethod("GET")) + .And(_ => GivenTheInputRequestHasScheme(scheme)) + .And(_ => GivenTheInputRequestHasHost(host)) + .And(_ => GivenTheInputRequestHasPath(path)) + .And(_ => GivenTheInputRequestHasQueryString(queryString)) + .When(_ => WhenMapped()) + .Then(_ => ThenAnErrorIsReturned()) + .And(_ => ThenTheMappedRequestIsNull()) + .BDDfy(); + } + + [Theory] + [InlineData("GET")] + [InlineData("POST")] + [InlineData("WHATEVER")] + public void Should_map_method(string method) + { + this.Given(_ => GivenTheInputRequestHasMethod(method)) + .And(_ => GivenTheInputRequestHasAValidUri()) + .When(_ => WhenMapped()) + .Then(_ => ThenNoErrorIsReturned()) + .And(_ => ThenTheMappedRequestHasMethod(method)) + .BDDfy(); + } + + [Fact] + public void Should_map_all_headers() + { + this.Given(_ => GivenTheInputRequestHasHeaders()) + .And(_ => GivenTheInputRequestHasMethod("GET")) + .And(_ => GivenTheInputRequestHasAValidUri()) + .When(_ => WhenMapped()) + .Then(_ => ThenNoErrorIsReturned()) + .And(_ => ThenTheMappedRequestHasEachHeader()) + .BDDfy(); + } + + [Fact] + public void Should_handle_no_headers() + { + this.Given(_ => GivenTheInputRequestHasNoHeaders()) + .And(_ => GivenTheInputRequestHasMethod("GET")) + .And(_ => GivenTheInputRequestHasAValidUri()) + .When(_ => WhenMapped()) + .Then(_ => ThenNoErrorIsReturned()) + .And(_ => ThenTheMappedRequestHasNoHeaders()) + .BDDfy(); + } + + [Fact] + public void Should_map_content() + { + this.Given(_ => GivenTheInputRequestHasContent("This is my content")) + .And(_ => GivenTheInputRequestHasMethod("GET")) + .And(_ => GivenTheInputRequestHasAValidUri()) + .When(_ => WhenMapped()) + .Then(_ => ThenNoErrorIsReturned()) + .And(_ => ThenTheMappedRequestHasContent("This is my content")) + .BDDfy(); + } + + [Fact] + public void Should_handle_no_content() + { + this.Given(_ => GivenTheInputRequestHasNoContent()) + .And(_ => GivenTheInputRequestHasMethod("GET")) + .And(_ => GivenTheInputRequestHasAValidUri()) + .When(_ => WhenMapped()) + .Then(_ => ThenNoErrorIsReturned()) + .And(_ => ThenTheMappedRequestHasNoContent()) + .BDDfy(); + } + + private void GivenTheInputRequestHasMethod(string method) + { + _inputRequest.Method = method; + } + + private void GivenTheInputRequestHasScheme(string scheme) + { + _inputRequest.Scheme = scheme; + } + + private void GivenTheInputRequestHasHost(string host) + { + _inputRequest.Host = new HostString(host); + } + + private void GivenTheInputRequestHasPath(string path) + { + if (path != null) + { + _inputRequest.Path = path; + } + } + + private void GivenTheInputRequestHasQueryString(string querystring) + { + if (querystring != null) + { + _inputRequest.QueryString = new QueryString(querystring); + } + } + + private void GivenTheInputRequestHasAValidUri() + { + GivenTheInputRequestHasScheme("http"); + GivenTheInputRequestHasHost("www.google.com"); + } + + private void GivenTheInputRequestHasHeaders() + { + _inputHeaders = new List>() + { + new KeyValuePair("abc", new StringValues(new string[]{"123","456" })), + new KeyValuePair("def", new StringValues(new string[]{"789","012" })), + }; + + foreach (var inputHeader in _inputHeaders) + { + _inputRequest.Headers.Add(inputHeader); + } + } + + private void GivenTheInputRequestHasNoHeaders() + { + _inputRequest.Headers.Clear(); + } + + private void GivenTheInputRequestHasContent(string content) + { + _inputRequest.Body = new MemoryStream(Encoding.UTF8.GetBytes(content)); + } + + private void GivenTheInputRequestHasNoContent() + { + _inputRequest.Body = null; + } + + private void WhenMapped() + { + _mappedRequest = _requestMapper.Map(_inputRequest).GetAwaiter().GetResult(); + } + + private void ThenNoErrorIsReturned() + { + _mappedRequest.IsError.ShouldBeFalse(); + } + + private void ThenAnErrorIsReturned() + { + _mappedRequest.IsError.ShouldBeTrue(); + } + + private void ThenTheMappedRequestHasUri(string expectedUri) + { + _mappedRequest.Data.RequestUri.OriginalString.ShouldBe(expectedUri); + } + + private void ThenTheMappedRequestHasMethod(string expectedMethod) + { + _mappedRequest.Data.Method.ToString().ShouldBe(expectedMethod); + } + + private void ThenTheMappedRequestHasEachHeader() + { + _mappedRequest.Data.Headers.Count().ShouldBe(_inputHeaders.Count); + foreach(var header in _mappedRequest.Data.Headers) + { + var inputHeader = _inputHeaders.First(h => h.Key == header.Key); + inputHeader.ShouldNotBeNull(); + inputHeader.Value.Count().ShouldBe(header.Value.Count()); + foreach(var inputHeaderValue in inputHeader.Value) + { + header.Value.Any(v => v == inputHeaderValue); + } + } + } + + private void ThenTheMappedRequestHasNoHeaders() + { + _mappedRequest.Data.Headers.Count().ShouldBe(0); + } + + private void ThenTheMappedRequestHasContent(string expectedContent) + { + _mappedRequest.Data.Content.ReadAsStringAsync().GetAwaiter().GetResult().ShouldBe(expectedContent); + } + + private void ThenTheMappedRequestHasNoContent() + { + _mappedRequest.Data.Content.ShouldBeNull(); + } + + private void ThenTheMappedRequestIsNull() + { + _mappedRequest.Data.ShouldBeNull(); + } + } +}