diff --git a/src/Ocelot/Configuration/Builder/ReRouteBuilder.cs b/src/Ocelot/Configuration/Builder/ReRouteBuilder.cs index cf780650..d47675de 100644 --- a/src/Ocelot/Configuration/Builder/ReRouteBuilder.cs +++ b/src/Ocelot/Configuration/Builder/ReRouteBuilder.cs @@ -20,6 +20,7 @@ namespace Ocelot.Configuration.Builder private Dictionary _routeClaimRequirement; private bool _isAuthorised; private List _claimToQueries; + private string _requestIdHeaderKey; public ReRouteBuilder() { @@ -96,6 +97,12 @@ namespace Ocelot.Configuration.Builder return this; } + public ReRouteBuilder WithRequestIdKey(string input) + { + _requestIdHeaderKey = input; + return this; + } + public ReRouteBuilder WithClaimsToHeaders(List input) { _configHeaderExtractorProperties = input; @@ -122,7 +129,7 @@ namespace Ocelot.Configuration.Builder public ReRoute Build() { - return new ReRoute(_downstreamTemplate, _upstreamTemplate, _upstreamHttpMethod, _upstreamTemplatePattern, _isAuthenticated, new AuthenticationOptions(_authenticationProvider, _authenticationProviderUrl, _scopeName, _requireHttps, _additionalScopes, _scopeSecret), _configHeaderExtractorProperties, _claimToClaims, _routeClaimRequirement, _isAuthorised, _claimToQueries); + return new ReRoute(_downstreamTemplate, _upstreamTemplate, _upstreamHttpMethod, _upstreamTemplatePattern, _isAuthenticated, new AuthenticationOptions(_authenticationProvider, _authenticationProviderUrl, _scopeName, _requireHttps, _additionalScopes, _scopeSecret), _configHeaderExtractorProperties, _claimToClaims, _routeClaimRequirement, _isAuthorised, _claimToQueries, _requestIdHeaderKey); } } } diff --git a/src/Ocelot/Configuration/Creator/YamlOcelotConfigurationCreator.cs b/src/Ocelot/Configuration/Creator/YamlOcelotConfigurationCreator.cs index 2dac5cc9..d35a7321 100644 --- a/src/Ocelot/Configuration/Creator/YamlOcelotConfigurationCreator.cs +++ b/src/Ocelot/Configuration/Creator/YamlOcelotConfigurationCreator.cs @@ -113,12 +113,17 @@ namespace Ocelot.Configuration.Creator return new ReRoute(reRoute.DownstreamTemplate, reRoute.UpstreamTemplate, reRoute.UpstreamHttpMethod, upstreamTemplate, isAuthenticated, - authOptionsForRoute, claimsToHeaders, claimsToClaims, reRoute.RouteClaimsRequirement, isAuthorised, claimsToQueries + authOptionsForRoute, claimsToHeaders, claimsToClaims, + reRoute.RouteClaimsRequirement, isAuthorised, claimsToQueries, + reRoute.RequestIdKey ); } - return new ReRoute(reRoute.DownstreamTemplate, reRoute.UpstreamTemplate, reRoute.UpstreamHttpMethod, - upstreamTemplate, isAuthenticated, null, new List(), new List(), reRoute.RouteClaimsRequirement, isAuthorised, new List()); + return new ReRoute(reRoute.DownstreamTemplate, reRoute.UpstreamTemplate, + reRoute.UpstreamHttpMethod, upstreamTemplate, isAuthenticated, + null, new List(), new List(), + reRoute.RouteClaimsRequirement, isAuthorised, new List(), + reRoute.RequestIdKey); } private List GetAddThingsToRequest(Dictionary thingBeingAdded) diff --git a/src/Ocelot/Configuration/ReRoute.cs b/src/Ocelot/Configuration/ReRoute.cs index 7d0a1600..0caf972a 100644 --- a/src/Ocelot/Configuration/ReRoute.cs +++ b/src/Ocelot/Configuration/ReRoute.cs @@ -4,7 +4,7 @@ namespace Ocelot.Configuration { public class ReRoute { - public ReRoute(string downstreamTemplate, string upstreamTemplate, string upstreamHttpMethod, string upstreamTemplatePattern, bool isAuthenticated, AuthenticationOptions authenticationOptions, List configurationHeaderExtractorProperties, List claimsToClaims, Dictionary routeClaimsRequirement, bool isAuthorised, List claimsToQueries) + public ReRoute(string downstreamTemplate, string upstreamTemplate, string upstreamHttpMethod, string upstreamTemplatePattern, bool isAuthenticated, AuthenticationOptions authenticationOptions, List configurationHeaderExtractorProperties, List claimsToClaims, Dictionary routeClaimsRequirement, bool isAuthorised, List claimsToQueries, string requestIdKey) { DownstreamTemplate = downstreamTemplate; UpstreamTemplate = upstreamTemplate; @@ -14,6 +14,7 @@ namespace Ocelot.Configuration AuthenticationOptions = authenticationOptions; RouteClaimsRequirement = routeClaimsRequirement; IsAuthorised = isAuthorised; + RequestIdKey = requestIdKey; ClaimsToQueries = claimsToQueries ?? new List(); ClaimsToClaims = claimsToClaims @@ -33,6 +34,6 @@ namespace Ocelot.Configuration public List ClaimsToHeaders { get; private set; } public List ClaimsToClaims { get; private set; } public Dictionary RouteClaimsRequirement { get; private set; } - + public string RequestIdKey { get; private set; } } } \ No newline at end of file diff --git a/src/Ocelot/Configuration/Yaml/YamlReRoute.cs b/src/Ocelot/Configuration/Yaml/YamlReRoute.cs index 8e5bbd70..03dbe593 100644 --- a/src/Ocelot/Configuration/Yaml/YamlReRoute.cs +++ b/src/Ocelot/Configuration/Yaml/YamlReRoute.cs @@ -19,6 +19,7 @@ namespace Ocelot.Configuration.Yaml public Dictionary AddHeadersToRequest { get; set; } public Dictionary AddClaimsToRequest { get; set; } public Dictionary RouteClaimsRequirement { get; set; } - public Dictionary AddQueriesToRequest { get; set; } + public Dictionary AddQueriesToRequest { get; set; } + public string RequestIdKey { get; set; } } } \ No newline at end of file diff --git a/src/Ocelot/DependencyInjection/ServiceCollectionExtensions.cs b/src/Ocelot/DependencyInjection/ServiceCollectionExtensions.cs index cca55afe..db761130 100644 --- a/src/Ocelot/DependencyInjection/ServiceCollectionExtensions.cs +++ b/src/Ocelot/DependencyInjection/ServiceCollectionExtensions.cs @@ -1,5 +1,4 @@ -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Ocelot.Authentication.Handler.Creator; @@ -16,6 +15,7 @@ using Ocelot.DownstreamRouteFinder.Finder; using Ocelot.DownstreamRouteFinder.UrlMatcher; using Ocelot.DownstreamUrlCreator.UrlTemplateReplacer; using Ocelot.Headers; +using Ocelot.Infrastructure.Claims.Parser; using Ocelot.Infrastructure.RequestData; using Ocelot.QueryStrings; using Ocelot.Request.Builder; @@ -24,8 +24,6 @@ using Ocelot.Responder; namespace Ocelot.DependencyInjection { - using Infrastructure.Claims.Parser; - public static class ServiceCollectionExtensions { public static IServiceCollection AddOcelotYamlConfiguration(this IServiceCollection services, IConfigurationRoot configurationRoot) @@ -48,6 +46,7 @@ namespace Ocelot.DependencyInjection services.AddLogging(); // ocelot services. + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/Ocelot/DownstreamRouteFinder/Middleware/DownstreamRouteFinderMiddleware.cs b/src/Ocelot/DownstreamRouteFinder/Middleware/DownstreamRouteFinderMiddleware.cs index 04f3ff2e..7c8b12ed 100644 --- a/src/Ocelot/DownstreamRouteFinder/Middleware/DownstreamRouteFinderMiddleware.cs +++ b/src/Ocelot/DownstreamRouteFinder/Middleware/DownstreamRouteFinderMiddleware.cs @@ -10,7 +10,6 @@ namespace Ocelot.DownstreamRouteFinder.Middleware { private readonly RequestDelegate _next; private readonly IDownstreamRouteFinder _downstreamRouteFinder; - private readonly IRequestScopedDataRepository _requestScopedDataRepository; public DownstreamRouteFinderMiddleware(RequestDelegate next, IDownstreamRouteFinder downstreamRouteFinder, @@ -19,7 +18,6 @@ namespace Ocelot.DownstreamRouteFinder.Middleware { _next = next; _downstreamRouteFinder = downstreamRouteFinder; - _requestScopedDataRepository = requestScopedDataRepository; } public async Task Invoke(HttpContext context) diff --git a/src/Ocelot/Headers/IRemoveHeaders.cs b/src/Ocelot/Headers/IRemoveHeaders.cs new file mode 100644 index 00000000..d22f4306 --- /dev/null +++ b/src/Ocelot/Headers/IRemoveHeaders.cs @@ -0,0 +1,10 @@ +using System.Net.Http.Headers; +using Ocelot.Responses; + +namespace Ocelot.Headers +{ + public interface IRemoveHeaders + { + Response Remove(HttpResponseHeaders headers); + } +} diff --git a/src/Ocelot/Headers/RemoveHeaders.cs b/src/Ocelot/Headers/RemoveHeaders.cs new file mode 100644 index 00000000..46ce7ba1 --- /dev/null +++ b/src/Ocelot/Headers/RemoveHeaders.cs @@ -0,0 +1,30 @@ +using System.Collections.Generic; +using System.Net.Http.Headers; +using Ocelot.Responses; + +namespace Ocelot.Headers +{ + public class RemoveHeaders : IRemoveHeaders + { + /// + /// Some webservers return headers that cannot be forwarded to the client + /// in a given context such as transfer encoding chunked when ASP.NET is not + /// returning the response in this manner + /// + private readonly List _unsupportedHeaders = new List + { + "Transfer-Encoding" + }; + + + public Response Remove(HttpResponseHeaders headers) + { + foreach (var unsupported in _unsupportedHeaders) + { + headers.Remove(unsupported); + } + + return new OkResponse(); + } + } +} \ No newline at end of file diff --git a/src/Ocelot/Middleware/ExceptionHandlerMiddleware.cs b/src/Ocelot/Middleware/ExceptionHandlerMiddleware.cs index c53256c9..1ee0c93d 100644 --- a/src/Ocelot/Middleware/ExceptionHandlerMiddleware.cs +++ b/src/Ocelot/Middleware/ExceptionHandlerMiddleware.cs @@ -10,7 +10,8 @@ namespace Ocelot.Middleware private readonly RequestDelegate _next; private readonly ILogger _logger; - public ExceptionHandlerMiddleware(RequestDelegate next, ILoggerFactory loggerFactory) + public ExceptionHandlerMiddleware(RequestDelegate next, + ILoggerFactory loggerFactory) { _next = next; _logger = loggerFactory.CreateLogger(); @@ -24,20 +25,25 @@ namespace Ocelot.Middleware } catch (Exception e) { - var message = - $"Exception caught in global error handler, exception message: {e.Message}, exception stack: {e.StackTrace}"; - - if (e.InnerException != null) - { - message = $"{message}, inner exception message {e.InnerException.Message}, inner exception stack {e.InnerException.StackTrace}"; - } - + var message = CreateMessage(context, e); _logger.LogError(new EventId(1, "Ocelot Global Error"), message, e); - context.Response.StatusCode = 500; context.Response.ContentType = "application/json"; await context.Response.WriteAsync("Internal Server Error"); } } + + private static string CreateMessage(HttpContext context, Exception e) + { + var message = + $"RequestId: {context.TraceIdentifier}, Exception caught in global error handler, exception message: {e.Message}, exception stack: {e.StackTrace}"; + + if (e.InnerException != null) + { + message = + $"{message}, inner exception message {e.InnerException.Message}, inner exception stack {e.InnerException.StackTrace}"; + } + return message; + } } } diff --git a/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs b/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs index 30813a5c..dbbe79d3 100644 --- a/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs +++ b/src/Ocelot/Middleware/OcelotMiddlewareExtensions.cs @@ -7,6 +7,7 @@ using Ocelot.Headers.Middleware; using Ocelot.QueryStrings.Middleware; using Ocelot.Request.Middleware; using Ocelot.Requester.Middleware; +using Ocelot.RequestId.Middleware; using Ocelot.Responder.Middleware; namespace Ocelot.Middleware @@ -49,6 +50,9 @@ namespace Ocelot.Middleware // Then we get the downstream route information builder.UseDownstreamRouteFinderMiddleware(); + // Now we can look for the requestId + builder.UseRequestIdMiddleware(); + // Allow pre authentication logic. The idea being people might want to run something custom before what is built in. builder.UseIfNotNull(middlewareConfiguration.PreAuthenticationMiddleware); diff --git a/src/Ocelot/Request/Builder/HttpRequestBuilder.cs b/src/Ocelot/Request/Builder/HttpRequestBuilder.cs index a1b437f1..6d2f29cc 100644 --- a/src/Ocelot/Request/Builder/HttpRequestBuilder.cs +++ b/src/Ocelot/Request/Builder/HttpRequestBuilder.cs @@ -1,7 +1,9 @@ using System; +using System.Collections.Generic; using System.IO; using System.Net; using System.Net.Http; +using System.Net.Http.Headers; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Ocelot.Responses; @@ -10,8 +12,15 @@ namespace Ocelot.Request.Builder { public class HttpRequestBuilder : IRequestBuilder { - public async Task> Build(string httpMethod, string downstreamUrl, Stream content, IHeaderDictionary headers, - IRequestCookieCollection cookies, Microsoft.AspNetCore.Http.QueryString queryString, string contentType) + public async Task> Build( + string httpMethod, + string downstreamUrl, + Stream content, + IHeaderDictionary headers, + IRequestCookieCollection cookies, + QueryString queryString, + string contentType, + RequestId.RequestId requestId) { var method = new HttpMethod(httpMethod); @@ -21,7 +30,7 @@ namespace Ocelot.Request.Builder if (content != null) { - httpRequestMessage.Content = new ByteArrayContent(ToByteArray(content)); + httpRequestMessage.Content = new ByteArrayContent(await ToByteArray(content)); } if (!string.IsNullOrEmpty(contentType)) @@ -45,6 +54,11 @@ namespace Ocelot.Request.Builder } } + if (RequestKeyIsNotNull(requestId) && !RequestIdInHeaders(requestId, httpRequestMessage.Headers)) + { + ForwardRequestIdToDownstreamService(requestId, httpRequestMessage); + } + var cookieContainer = new CookieContainer(); //todo get rid of if @@ -59,13 +73,34 @@ namespace Ocelot.Request.Builder return new OkResponse(new Request(httpRequestMessage, cookieContainer)); } - private byte[] ToByteArray(Stream stream) + private void ForwardRequestIdToDownstreamService(RequestId.RequestId requestId, HttpRequestMessage httpRequestMessage) + { + httpRequestMessage.Headers.Add(requestId.RequestIdKey, requestId.RequestIdValue); + } + + private bool RequestIdInHeaders(RequestId.RequestId requestId, HttpRequestHeaders headers) + { + IEnumerable value; + if (headers.TryGetValues(requestId.RequestIdKey, out value)) + { + return true; + } + + return false; + } + + private bool RequestKeyIsNotNull(RequestId.RequestId requestId) + { + return !string.IsNullOrEmpty(requestId?.RequestIdKey) && !string.IsNullOrEmpty(requestId.RequestIdValue); + } + + private async Task ToByteArray(Stream stream) { using (stream) { - using (MemoryStream memStream = new MemoryStream()) + using (var memStream = new MemoryStream()) { - stream.CopyTo(memStream); + await stream.CopyToAsync(memStream); return memStream.ToArray(); } } diff --git a/src/Ocelot/Request/Builder/IRequestBuilder.cs b/src/Ocelot/Request/Builder/IRequestBuilder.cs index efb8f31a..cefa3bcf 100644 --- a/src/Ocelot/Request/Builder/IRequestBuilder.cs +++ b/src/Ocelot/Request/Builder/IRequestBuilder.cs @@ -12,7 +12,8 @@ namespace Ocelot.Request.Builder Stream content, IHeaderDictionary headers, IRequestCookieCollection cookies, - Microsoft.AspNetCore.Http.QueryString queryString, - string contentType); + QueryString queryString, + string contentType, + RequestId.RequestId requestId); } } diff --git a/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddleware.cs b/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddleware.cs index d4e75bf3..82f01452 100644 --- a/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddleware.cs +++ b/src/Ocelot/Request/Middleware/HttpRequestBuilderMiddleware.cs @@ -22,17 +22,18 @@ namespace Ocelot.Request.Middleware public async Task Invoke(HttpContext context) { - var request = await _requestBuilder - .Build(context.Request.Method, DownstreamUrl, context.Request.Body, - context.Request.Headers, context.Request.Cookies, context.Request.QueryString, context.Request.ContentType); + var buildResult = await _requestBuilder + .Build(context.Request.Method, DownstreamUrl, context.Request.Body, + context.Request.Headers, context.Request.Cookies, context.Request.QueryString, + context.Request.ContentType, new RequestId.RequestId(DownstreamRoute?.ReRoute?.RequestIdKey, context.TraceIdentifier)); - if (request.IsError) + if (buildResult.IsError) { - SetPipelineError(request.Errors); + SetPipelineError(buildResult.Errors); return; } - SetUpstreamRequestForThisRequest(request.Data); + SetUpstreamRequestForThisRequest(buildResult.Data); await _next.Invoke(context); } diff --git a/src/Ocelot/RequestId/DefaultRequestIdKey.cs b/src/Ocelot/RequestId/DefaultRequestIdKey.cs new file mode 100644 index 00000000..94c0d82d --- /dev/null +++ b/src/Ocelot/RequestId/DefaultRequestIdKey.cs @@ -0,0 +1,9 @@ +namespace Ocelot.RequestId +{ + public static class DefaultRequestIdKey + { + // This is set incase anyone isnt doing this specifically with there requests. + // It will not be forwarded on to downstream services unless specfied in the config. + public const string Value = "RequestId"; + } +} diff --git a/src/Ocelot/RequestId/Middleware/RequestIdMiddleware.cs b/src/Ocelot/RequestId/Middleware/RequestIdMiddleware.cs new file mode 100644 index 00000000..e6d02497 --- /dev/null +++ b/src/Ocelot/RequestId/Middleware/RequestIdMiddleware.cs @@ -0,0 +1,45 @@ +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Primitives; +using Ocelot.Infrastructure.RequestData; +using Ocelot.Middleware; + +namespace Ocelot.RequestId.Middleware +{ + public class RequestIdMiddleware : OcelotMiddleware + { + private readonly RequestDelegate _next; + + public RequestIdMiddleware(RequestDelegate next, + IRequestScopedDataRepository requestScopedDataRepository) + :base(requestScopedDataRepository) + { + _next = next; + } + + public async Task Invoke(HttpContext context) + { + SetTraceIdentifier(context); + + await _next.Invoke(context); + } + + private void SetTraceIdentifier(HttpContext context) + { + var key = DefaultRequestIdKey.Value; + + if (DownstreamRoute.ReRoute.RequestIdKey != null) + { + key = DownstreamRoute.ReRoute.RequestIdKey; + } + + StringValues requestId; + + if (context.Request.Headers.TryGetValue(key, out requestId)) + { + context.TraceIdentifier = requestId; + } + } + } +} \ No newline at end of file diff --git a/src/Ocelot/RequestId/Middleware/RequestIdMiddlewareExtensions.cs b/src/Ocelot/RequestId/Middleware/RequestIdMiddlewareExtensions.cs new file mode 100644 index 00000000..dc29afde --- /dev/null +++ b/src/Ocelot/RequestId/Middleware/RequestIdMiddlewareExtensions.cs @@ -0,0 +1,12 @@ +using Microsoft.AspNetCore.Builder; + +namespace Ocelot.RequestId.Middleware +{ + public static class RequestIdMiddlewareExtensions + { + public static IApplicationBuilder UseRequestIdMiddleware(this IApplicationBuilder builder) + { + return builder.UseMiddleware(); + } + } +} \ No newline at end of file diff --git a/src/Ocelot/RequestId/RequestId.cs b/src/Ocelot/RequestId/RequestId.cs new file mode 100644 index 00000000..998d33aa --- /dev/null +++ b/src/Ocelot/RequestId/RequestId.cs @@ -0,0 +1,14 @@ +namespace Ocelot.RequestId +{ + public class RequestId + { + public RequestId(string requestIdKey, string requestIdValue) + { + RequestIdKey = requestIdKey; + RequestIdValue = requestIdValue; + } + + public string RequestIdKey { get; private set; } + public string RequestIdValue { get; private set; } + } +} diff --git a/src/Ocelot/Requester/Middleware/HttpRequesterMiddleware.cs b/src/Ocelot/Requester/Middleware/HttpRequesterMiddleware.cs index 10422819..6dbb28bf 100644 --- a/src/Ocelot/Requester/Middleware/HttpRequesterMiddleware.cs +++ b/src/Ocelot/Requester/Middleware/HttpRequesterMiddleware.cs @@ -1,5 +1,7 @@ +using System.Diagnostics; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; using Ocelot.Infrastructure.RequestData; using Ocelot.Middleware; using Ocelot.Responder; @@ -25,7 +27,6 @@ namespace Ocelot.Requester.Middleware public async Task Invoke(HttpContext context) { - var response = await _requester.GetResponse(Request); if (response.IsError) diff --git a/src/Ocelot/Responder/HttpContextResponder.cs b/src/Ocelot/Responder/HttpContextResponder.cs index 2f381ea2..5c81225f 100644 --- a/src/Ocelot/Responder/HttpContextResponder.cs +++ b/src/Ocelot/Responder/HttpContextResponder.cs @@ -1,6 +1,10 @@ -using System.Net.Http; +using System.IO; +using System.Linq; +using System.Net.Http; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Primitives; +using Ocelot.Headers; using Ocelot.Responses; namespace Ocelot.Responder @@ -11,15 +15,42 @@ namespace Ocelot.Responder /// public class HttpContextResponder : IHttpResponder { + private readonly IRemoveHeaders _removeHeaders; + + public HttpContextResponder(IRemoveHeaders removeHeaders) + { + _removeHeaders = removeHeaders; + } + public async Task SetResponseOnHttpContext(HttpContext context, HttpResponseMessage response) { - context.Response.OnStarting(x => + _removeHeaders.Remove(response.Headers); + + foreach (var httpResponseHeader in response.Headers) { - context.Response.StatusCode = (int)response.StatusCode; + context.Response.Headers.Add(httpResponseHeader.Key, new StringValues(httpResponseHeader.Value.ToArray())); + } + + var content = await response.Content.ReadAsStreamAsync(); + + context.Response.Headers.Add("Content-Length", new[] { content.Length.ToString() }); + + context.Response.OnStarting(state => + { + var httpContext = (HttpContext)state; + + httpContext.Response.StatusCode = (int)response.StatusCode; + return Task.CompletedTask; + }, context); - await context.Response.WriteAsync(await response.Content.ReadAsStringAsync()); + using (var reader = new StreamReader(content)) + { + var responseContent = reader.ReadToEnd(); + await context.Response.WriteAsync(responseContent); + } + return new OkResponse(); } diff --git a/test/Ocelot.AcceptanceTests/RequestIdTests.cs b/test/Ocelot.AcceptanceTests/RequestIdTests.cs new file mode 100644 index 00000000..0c3407b5 --- /dev/null +++ b/test/Ocelot.AcceptanceTests/RequestIdTests.cs @@ -0,0 +1,108 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Primitives; +using Ocelot.Configuration.Yaml; +using TestStack.BDDfy; +using Xunit; + +namespace Ocelot.AcceptanceTests +{ + public class RequestIdTests : IDisposable + { + private IWebHost _builder; + private readonly Steps _steps; + + public RequestIdTests() + { + _steps = new Steps(); + } + + [Fact] + public void should_use_default_request_id_and_forward() + { + var yamlConfiguration = new YamlConfiguration + { + ReRoutes = new List + { + new YamlReRoute + { + DownstreamTemplate = "http://localhost:51879/", + UpstreamTemplate = "/", + UpstreamHttpMethod = "Get", + RequestIdKey = _steps.RequestIdKey + } + } + }; + + this.Given(x => x.GivenThereIsAServiceRunningOn("http://localhost:51879")) + .And(x => _steps.GivenThereIsAConfiguration(yamlConfiguration)) + .And(x => _steps.GivenOcelotIsRunning()) + .When(x => _steps.WhenIGetUrlOnTheApiGateway("/")) + .Then(x => _steps.ThenTheRequestIdIsReturned()) + .BDDfy(); + } + + [Fact] + public void should_use_request_id_and_forward() + { + var yamlConfiguration = new YamlConfiguration + { + ReRoutes = new List + { + new YamlReRoute + { + DownstreamTemplate = "http://localhost:51879/", + UpstreamTemplate = "/", + UpstreamHttpMethod = "Get", + RequestIdKey = _steps.RequestIdKey + } + } + }; + + var requestId = Guid.NewGuid().ToString(); + + this.Given(x => x.GivenThereIsAServiceRunningOn("http://localhost:51879")) + .And(x => _steps.GivenThereIsAConfiguration(yamlConfiguration)) + .And(x => _steps.GivenOcelotIsRunning()) + .When(x => _steps.WhenIGetUrlOnTheApiGateway("/", requestId)) + .Then(x => _steps.ThenTheRequestIdIsReturned(requestId)) + .BDDfy(); + } + + private void GivenThereIsAServiceRunningOn(string url) + { + _builder = new WebHostBuilder() + .UseUrls(url) + .UseKestrel() + .UseContentRoot(Directory.GetCurrentDirectory()) + .UseIISIntegration() + .UseUrls(url) + .Configure(app => + { + app.Run(context => + { + StringValues requestId; + context.Request.Headers.TryGetValue(_steps.RequestIdKey, out requestId); + context.Response.Headers.Add(_steps.RequestIdKey, requestId.First()); + return Task.CompletedTask; + }); + }) + .Build(); + + _builder.Start(); + } + + public void Dispose() + { + _builder?.Dispose(); + _steps.Dispose(); + } + } +} diff --git a/test/Ocelot.AcceptanceTests/Steps.cs b/test/Ocelot.AcceptanceTests/Steps.cs index b9f94f17..23a7cdbc 100644 --- a/test/Ocelot.AcceptanceTests/Steps.cs +++ b/test/Ocelot.AcceptanceTests/Steps.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Net; using System.Net.Http; using System.Net.Http.Headers; @@ -26,6 +27,7 @@ namespace Ocelot.AcceptanceTests private HttpContent _postContent; private BearerToken _token; public HttpClient OcelotClient => _ocelotClient; + public string RequestIdKey = "OcRequestId"; public void GivenThereIsAConfiguration(YamlConfiguration yamlConfiguration) { @@ -146,6 +148,13 @@ namespace Ocelot.AcceptanceTests _response = _ocelotClient.GetAsync(url).Result; } + public void WhenIGetUrlOnTheApiGateway(string url, string requestId) + { + _ocelotClient.DefaultRequestHeaders.TryAddWithoutValidation(RequestIdKey, requestId); + + _response = _ocelotClient.GetAsync(url).Result; + } + public void WhenIPostUrlOnTheApiGateway(string url) { _response = _ocelotClient.PostAsync(url, _postContent).Result; @@ -171,5 +180,15 @@ namespace Ocelot.AcceptanceTests _ocelotClient?.Dispose(); _ocelotServer?.Dispose(); } + + public void ThenTheRequestIdIsReturned() + { + _response.Headers.GetValues(RequestIdKey).First().ShouldNotBeNullOrEmpty(); + } + + public void ThenTheRequestIdIsReturned(string expected) + { + _response.Headers.GetValues(RequestIdKey).First().ShouldBe(expected); + } } } diff --git a/test/Ocelot.ManualTest/Properties/launchSettings.json b/test/Ocelot.ManualTest/Properties/launchSettings.json index e2e8ad9a..bfc47fdf 100644 --- a/test/Ocelot.ManualTest/Properties/launchSettings.json +++ b/test/Ocelot.ManualTest/Properties/launchSettings.json @@ -17,7 +17,6 @@ }, "Ocelot.ManualTest": { "commandName": "Project", - "launchBrowser": true, "launchUrl": "http://localhost:5000", "environmentVariables": { "ASPNETCORE_ENVIRONMENT": "Development" diff --git a/test/Ocelot.ManualTest/configuration.yaml b/test/Ocelot.ManualTest/configuration.yaml index 74bdcac3..2d2aab11 100644 --- a/test/Ocelot.ManualTest/configuration.yaml +++ b/test/Ocelot.ManualTest/configuration.yaml @@ -50,6 +50,10 @@ ReRoutes: # the value must be registered RouteClaimsRequirement: UserType: registered +# This tells Ocelot to look for a header and use its value as a request/correlation id. +# If it is set here then the id will be forwarded to the downstream service. If it +# does not then it will not be forwarded + RequestIdKey: OcRequestId # The next re route... - DownstreamTemplate: http://jsonplaceholder.typicode.com/posts UpstreamTemplate: /posts diff --git a/test/Ocelot.UnitTests/Errors/GobalErrorHandlerTests.cs b/test/Ocelot.UnitTests/Errors/GobalErrorHandlerTests.cs new file mode 100644 index 00000000..9a75e38a --- /dev/null +++ b/test/Ocelot.UnitTests/Errors/GobalErrorHandlerTests.cs @@ -0,0 +1,83 @@ +/* +using System; +using System.IO; +using System.Net.Http; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Moq; +using Ocelot.Middleware; +using Ocelot.RequestId.Provider; +using TestStack.BDDfy; +using Xunit; + +namespace Ocelot.UnitTests.Errors +{ + public class GobalErrorHandlerTests + { + private readonly Mock _loggerFactory; + private readonly Mock> _logger; + private readonly Mock _requestIdProvider; + private readonly string _url; + private readonly TestServer _server; + private readonly HttpClient _client; + private HttpResponseMessage _result; + + public GobalErrorHandlerTests() + { + _url = "http://localhost:51879"; + _logger = new Mock>(); + _loggerFactory = new Mock(); + _requestIdProvider = new Mock(); + var builder = new WebHostBuilder() + .ConfigureServices(x => + { + x.AddSingleton(_requestIdProvider.Object); + x.AddSingleton(_loggerFactory.Object); + }) + .UseUrls(_url) + .UseKestrel() + .UseContentRoot(Directory.GetCurrentDirectory()) + .UseIISIntegration() + .UseUrls(_url) + .Configure(app => + { + app.UseExceptionHandlerMiddleware(); + + app.Run(x => + { + throw new Exception("BLAM"); + }); + }); + + _loggerFactory + .Setup(x => x.CreateLogger()) + .Returns(_logger.Object); + + _server = new TestServer(builder); + _client = _server.CreateClient(); + } + + [Fact] + public void should_catch_exception_and_log() + { + this.When(x => x.WhenICallTheMiddleware()) + .And(x => x.TheLoggerIsCalledCorrectly()) + .BDDfy(); + } + + private void TheLoggerIsCalledCorrectly() + { + _logger + .Verify(x => x.LogError(It.IsAny(), It.IsAny(), It.IsAny()), Times.Once); + } + + private void WhenICallTheMiddleware() + { + _result = _client.GetAsync(_url).Result; + } + } +} +*/ diff --git a/test/Ocelot.UnitTests/Headers/RemoveHeaders.cs b/test/Ocelot.UnitTests/Headers/RemoveHeaders.cs new file mode 100644 index 00000000..774dd56b --- /dev/null +++ b/test/Ocelot.UnitTests/Headers/RemoveHeaders.cs @@ -0,0 +1,52 @@ +using System.Net.Http; +using System.Net.Http.Headers; +using Ocelot.Responses; +using Shouldly; +using TestStack.BDDfy; +using Xunit; + +namespace Ocelot.UnitTests.Headers +{ + public class RemoveHeaders + { + private HttpResponseHeaders _headers; + private readonly Ocelot.Headers.RemoveHeaders _removeHeaders; + private Response _result; + + public RemoveHeaders() + { + _removeHeaders = new Ocelot.Headers.RemoveHeaders(); + } + + [Fact] + public void should_remove_header() + { + var httpResponse = new HttpResponseMessage() + { + Headers = {{ "Transfer-Encoding", "chunked"}} + }; + + this.Given(x => x.GivenAHttpContext(httpResponse.Headers)) + .When(x => x.WhenIRemoveTheHeaders()) + .Then(x => x.TheHeaderIsNoLongerInTheContext()) + .BDDfy(); + } + + private void GivenAHttpContext(HttpResponseHeaders headers) + { + _headers = headers; + } + + private void WhenIRemoveTheHeaders() + { + _result = _removeHeaders.Remove(_headers); + } + + private void TheHeaderIsNoLongerInTheContext() + { + _result.IsError.ShouldBeFalse(); + _headers.ShouldNotContain(x => x.Key == "Transfer-Encoding"); + _headers.ShouldNotContain(x => x.Key == "transfer-encoding"); + } + } +} diff --git a/test/Ocelot.UnitTests/Infrastructure/ClaimParserTests.cs b/test/Ocelot.UnitTests/Infrastructure/ClaimParserTests.cs index 5683d636..ddd383c3 100644 --- a/test/Ocelot.UnitTests/Infrastructure/ClaimParserTests.cs +++ b/test/Ocelot.UnitTests/Infrastructure/ClaimParserTests.cs @@ -1,4 +1,6 @@ -namespace Ocelot.UnitTests.Infrastructure +using Ocelot.Errors; + +namespace Ocelot.UnitTests.Infrastructure { using System.Collections.Generic; using System.Security.Claims; diff --git a/test/Ocelot.UnitTests/Request/HttpRequestBuilderMiddlewareTests.cs b/test/Ocelot.UnitTests/Request/HttpRequestBuilderMiddlewareTests.cs index f894b63f..ccc55e9e 100644 --- a/test/Ocelot.UnitTests/Request/HttpRequestBuilderMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/Request/HttpRequestBuilderMiddlewareTests.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.IO; using System.Net; using System.Net.Http; @@ -7,6 +8,9 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; using Moq; +using Ocelot.Configuration.Builder; +using Ocelot.DownstreamRouteFinder; +using Ocelot.DownstreamRouteFinder.UrlMatcher; using Ocelot.Infrastructure.RequestData; using Ocelot.Request.Builder; using Ocelot.Request.Middleware; @@ -26,6 +30,7 @@ namespace Ocelot.UnitTests.Request private HttpResponseMessage _result; private OkResponse _request; private OkResponse _downstreamUrl; + private OkResponse _downstreamRoute; public HttpRequestBuilderMiddlewareTests() { @@ -56,19 +61,34 @@ namespace Ocelot.UnitTests.Request [Fact] public void happy_path() { + + var downstreamRoute = new DownstreamRoute(new List(), + new ReRouteBuilder() + .WithRequestIdKey("LSRequestId").Build()); + + this.Given(x => x.GivenTheDownStreamUrlIs("any old string")) + .And(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) .And(x => x.GivenTheRequestBuilderReturns(new Ocelot.Request.Request(new HttpRequestMessage(), new CookieContainer()))) .When(x => x.WhenICallTheMiddleware()) .Then(x => x.ThenTheScopedDataRepositoryIsCalledCorrectly()) .BDDfy(); } + private void GivenTheDownStreamRouteIs(DownstreamRoute downstreamRoute) + { + _downstreamRoute = new OkResponse(downstreamRoute); + _scopedRepository + .Setup(x => x.Get(It.IsAny())) + .Returns(_downstreamRoute); + } + private void GivenTheRequestBuilderReturns(Ocelot.Request.Request request) { _request = new OkResponse(request); _requestBuilder .Setup(x => x.Build(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), - It.IsAny(), It.IsAny(), It.IsAny())) + It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .ReturnsAsync(_request); } diff --git a/test/Ocelot.UnitTests/Request/RequestBuilderTests.cs b/test/Ocelot.UnitTests/Request/RequestBuilderTests.cs index 7b6a1688..5e82b86c 100644 --- a/test/Ocelot.UnitTests/Request/RequestBuilderTests.cs +++ b/test/Ocelot.UnitTests/Request/RequestBuilderTests.cs @@ -24,6 +24,7 @@ namespace Ocelot.UnitTests.Request private string _contentType; private readonly IRequestBuilder _requestBuilder; private Response _result; + private Ocelot.RequestId.RequestId _requestId; public RequestBuilderTests() { @@ -114,6 +115,62 @@ namespace Ocelot.UnitTests.Request .BDDfy(); } + [Fact] + public void should_use_request_id() + { + var requestId = Guid.NewGuid().ToString(); + + this.Given(x => x.GivenIHaveHttpMethod("GET")) + .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) + .And(x => x.GivenTheHttpHeadersAre(new HeaderDictionary())) + .And(x => x.GivenTheRequestIdIs(new Ocelot.RequestId.RequestId("RequestId", requestId))) + .When(x => x.WhenICreateARequest()) + .And(x => x.ThenTheCorrectHeadersAreUsed(new HeaderDictionary + { + {"RequestId", requestId } + })) + .BDDfy(); + } + + [Fact] + public void should_not_use_request_if_if_already_in_headers() + { + this.Given(x => x.GivenIHaveHttpMethod("GET")) + .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) + .And(x => x.GivenTheHttpHeadersAre(new HeaderDictionary + { + {"RequestId", "534534gv54gv45g" } + })) + .And(x => x.GivenTheRequestIdIs(new Ocelot.RequestId.RequestId("RequestId", Guid.NewGuid().ToString()))) + .When(x => x.WhenICreateARequest()) + .And(x => x.ThenTheCorrectHeadersAreUsed(new HeaderDictionary + { + {"RequestId", "534534gv54gv45g" } + })) + .BDDfy(); + } + + [Theory] + [InlineData(null, "blahh")] + [InlineData("", "blahh")] + [InlineData("RequestId", "")] + [InlineData("RequestId", null)] + public void should_not_use_request_id(string requestIdKey, string requestIdValue) + { + this.Given(x => x.GivenIHaveHttpMethod("GET")) + .And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk")) + .And(x => x.GivenTheHttpHeadersAre(new HeaderDictionary())) + .And(x => x.GivenTheRequestIdIs(new Ocelot.RequestId.RequestId(requestIdKey, requestIdValue))) + .When(x => x.WhenICreateARequest()) + .And(x => x.ThenTheRequestIdIsNotInTheHeaders()) + .BDDfy(); + } + + private void GivenTheRequestIdIs(Ocelot.RequestId.RequestId requestId) + { + _requestId = requestId; + } + [Fact] public void should_use_cookies() { @@ -174,6 +231,11 @@ namespace Ocelot.UnitTests.Request _cookies = cookies; } + private void ThenTheRequestIdIsNotInTheHeaders() + { + _result.Data.HttpRequestMessage.Headers.ShouldNotContain(x => x.Key == "RequestId"); + } + private void ThenTheCorrectHeadersAreUsed(IHeaderDictionary expected) { var expectedHeaders = expected.Select(x => new KeyValuePair(x.Key, x.Value)); @@ -219,7 +281,7 @@ namespace Ocelot.UnitTests.Request private void WhenICreateARequest() { _result = _requestBuilder.Build(_httpMethod, _downstreamUrl, _content?.ReadAsStreamAsync().Result, _headers, - _cookies, _query, _contentType).Result; + _cookies, _query, _contentType, _requestId).Result; } diff --git a/test/Ocelot.UnitTests/RequestId/RequestIdMiddlewareTests.cs b/test/Ocelot.UnitTests/RequestId/RequestIdMiddlewareTests.cs new file mode 100644 index 00000000..6edaabc5 --- /dev/null +++ b/test/Ocelot.UnitTests/RequestId/RequestIdMiddlewareTests.cs @@ -0,0 +1,132 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Moq; +using Ocelot.Configuration.Builder; +using Ocelot.DownstreamRouteFinder; +using Ocelot.DownstreamRouteFinder.UrlMatcher; +using Ocelot.Infrastructure.RequestData; +using Ocelot.RequestId.Middleware; +using Ocelot.Responses; +using Shouldly; +using TestStack.BDDfy; +using Xunit; + +namespace Ocelot.UnitTests.RequestId +{ + public class RequestIdMiddlewareTests + { + private readonly Mock _scopedRepository; + private readonly string _url; + private readonly TestServer _server; + private readonly HttpClient _client; + private Response _downstreamRoute; + private HttpResponseMessage _result; + private string _value; + private string _key; + + public RequestIdMiddlewareTests() + { + _url = "http://localhost:51879"; + _scopedRepository = new Mock(); + + var builder = new WebHostBuilder() + .ConfigureServices(x => + { + x.AddSingleton(_scopedRepository.Object); + }) + .UseUrls(_url) + .UseKestrel() + .UseContentRoot(Directory.GetCurrentDirectory()) + .UseIISIntegration() + .UseUrls(_url) + .Configure(app => + { + app.UseRequestIdMiddleware(); + + app.Run(x => + { + x.Response.Headers.Add("LSRequestId", x.TraceIdentifier); + return Task.CompletedTask; + }); + }); + + _server = new TestServer(builder); + _client = _server.CreateClient(); + } + + [Fact] + public void should_add_request_id_to_repository() + { + var downstreamRoute = new DownstreamRoute(new List(), + new ReRouteBuilder() + .WithDownstreamTemplate("any old string") + .WithRequestIdKey("LSRequestId").Build()); + + var requestId = Guid.NewGuid().ToString(); + + this.Given(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) + .And(x => x.GivenTheRequestIdIsAddedToTheRequest("LSRequestId", requestId)) + .When(x => x.WhenICallTheMiddleware()) + .Then(x => x.ThenTheTraceIdIs(requestId)) + .BDDfy(); + } + + [Fact] + public void should_add_trace_indentifier_to_repository() + { + var downstreamRoute = new DownstreamRoute(new List(), + new ReRouteBuilder() + .WithDownstreamTemplate("any old string") + .WithRequestIdKey("LSRequestId").Build()); + + this.Given(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) + .When(x => x.WhenICallTheMiddleware()) + .Then(x => x.ThenTheTraceIdIsAnything()) + .BDDfy(); + } + + private void ThenTheTraceIdIsAnything() + { + _result.Headers.GetValues("LSRequestId").First().ShouldNotBeNullOrEmpty(); + } + + private void ThenTheTraceIdIs(string expected) + { + _result.Headers.GetValues("LSRequestId").First().ShouldBe(expected); + } + + private void GivenTheRequestIdIsAddedToTheRequest(string key, string value) + { + _key = key; + _value = value; + _client.DefaultRequestHeaders.TryAddWithoutValidation(_key, _value); + } + + private void WhenICallTheMiddleware() + { + _result = _client.GetAsync(_url).Result; + } + + private void GivenTheDownStreamRouteIs(DownstreamRoute downstreamRoute) + { + _downstreamRoute = new OkResponse(downstreamRoute); + _scopedRepository + .Setup(x => x.Get(It.IsAny())) + .Returns(_downstreamRoute); + } + + public void Dispose() + { + _client.Dispose(); + _server.Dispose(); + } + } +}