use a stream rather than byte array in responder (#519)

This commit is contained in:
Tom Pallister 2018-07-31 19:21:12 +01:00 committed by GitHub
parent eb4b996c99
commit b854ca63ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 334 additions and 335 deletions

View File

@ -1,16 +1,15 @@
using System;
using System.Linq;
using System.Threading.Tasks;
using Ocelot.Configuration.Repository;
using Ocelot.Infrastructure.Extensions;
using Ocelot.Infrastructure.RequestData;
using Ocelot.Logging;
using Ocelot.Middleware;
namespace Ocelot.Errors.Middleware namespace Ocelot.Errors.Middleware
{ {
using Configuration; using Configuration;
using System;
using System.Linq;
using System.Threading.Tasks;
using Ocelot.Configuration.Repository;
using Ocelot.Infrastructure.Extensions;
using Ocelot.Infrastructure.RequestData;
using Ocelot.Logging;
using Ocelot.Middleware;
/// <summary> /// <summary>
/// Catches all unhandled exceptions thrown by middleware, logs and returns a 500 /// Catches all unhandled exceptions thrown by middleware, logs and returns a 500
/// </summary> /// </summary>

View File

@ -1,74 +1,74 @@
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Net; using System.Net;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives; using Microsoft.Extensions.Primitives;
using Ocelot.Headers; using Ocelot.Headers;
using Ocelot.Middleware; using Ocelot.Middleware;
namespace Ocelot.Responder namespace Ocelot.Responder
{ {
/// <summary> /// <summary>
/// Cannot unit test things in this class due to methods not being implemented /// Cannot unit test things in this class due to methods not being implemented
/// on .net concretes used for testing /// on .net concretes used for testing
/// </summary> /// </summary>
public class HttpContextResponder : IHttpResponder public class HttpContextResponder : IHttpResponder
{ {
private readonly IRemoveOutputHeaders _removeOutputHeaders; private readonly IRemoveOutputHeaders _removeOutputHeaders;
public HttpContextResponder(IRemoveOutputHeaders removeOutputHeaders) public HttpContextResponder(IRemoveOutputHeaders removeOutputHeaders)
{
_removeOutputHeaders = removeOutputHeaders;
}
public async Task SetResponseOnHttpContext(HttpContext context, DownstreamResponse response)
{
_removeOutputHeaders.Remove(response.Headers);
foreach (var httpResponseHeader in response.Headers)
{
AddHeaderIfDoesntExist(context, httpResponseHeader);
}
foreach (var httpResponseHeader in response.Content.Headers)
{
AddHeaderIfDoesntExist(context, new Header(httpResponseHeader.Key, httpResponseHeader.Value));
}
var content = await response.Content.ReadAsByteArrayAsync();
AddHeaderIfDoesntExist(context, new Header("Content-Length", new []{ content.Length.ToString() }) );
context.Response.OnStarting(state =>
{
var httpContext = (HttpContext)state;
httpContext.Response.StatusCode = (int)response.StatusCode;
return Task.CompletedTask;
}, context);
using (Stream stream = new MemoryStream(content))
{
if (response.StatusCode != HttpStatusCode.NotModified && context.Response.ContentLength != 0)
{
await stream.CopyToAsync(context.Response.Body);
}
}
}
public void SetErrorResponseOnContext(HttpContext context, int statusCode)
{ {
context.Response.StatusCode = statusCode; _removeOutputHeaders = removeOutputHeaders;
} }
private static void AddHeaderIfDoesntExist(HttpContext context, Header httpResponseHeader) public async Task SetResponseOnHttpContext(HttpContext context, DownstreamResponse response)
{ {
if (!context.Response.Headers.ContainsKey(httpResponseHeader.Key)) _removeOutputHeaders.Remove(response.Headers);
{
context.Response.Headers.Add(httpResponseHeader.Key, new StringValues(httpResponseHeader.Values.ToArray())); foreach (var httpResponseHeader in response.Headers)
} {
} AddHeaderIfDoesntExist(context, httpResponseHeader);
} }
}
foreach (var httpResponseHeader in response.Content.Headers)
{
AddHeaderIfDoesntExist(context, new Header(httpResponseHeader.Key, httpResponseHeader.Value));
}
var content = await response.Content.ReadAsStreamAsync();
AddHeaderIfDoesntExist(context, new Header("Content-Length", new []{ content.Length.ToString() }) );
context.Response.OnStarting(state =>
{
var httpContext = (HttpContext)state;
httpContext.Response.StatusCode = (int)response.StatusCode;
return Task.CompletedTask;
}, context);
using(content)
{
if (response.StatusCode != HttpStatusCode.NotModified && context.Response.ContentLength != 0)
{
await content.CopyToAsync(context.Response.Body);
}
}
}
public void SetErrorResponseOnContext(HttpContext context, int statusCode)
{
context.Response.StatusCode = statusCode;
}
private static void AddHeaderIfDoesntExist(HttpContext context, Header httpResponseHeader)
{
if (!context.Response.Headers.ContainsKey(httpResponseHeader.Key))
{
context.Response.Headers.Add(httpResponseHeader.Key, new StringValues(httpResponseHeader.Values.ToArray()));
}
}
}
}

View File

@ -1,55 +1,55 @@
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Ocelot.Errors; using Ocelot.Errors;
using Ocelot.Logging; using Ocelot.Logging;
using Ocelot.Middleware; using Ocelot.Middleware;
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading.Tasks; using System.Threading.Tasks;
using Ocelot.Infrastructure.Extensions; using Ocelot.Infrastructure.Extensions;
namespace Ocelot.Responder.Middleware namespace Ocelot.Responder.Middleware
{ {
/// <summary> /// <summary>
/// Completes and returns the request and request body, if any pipeline errors occured then sets the appropriate HTTP status code instead. /// Completes and returns the request and request body, if any pipeline errors occured then sets the appropriate HTTP status code instead.
/// </summary> /// </summary>
public class ResponderMiddleware : OcelotMiddleware public class ResponderMiddleware : OcelotMiddleware
{ {
private readonly OcelotRequestDelegate _next; private readonly OcelotRequestDelegate _next;
private readonly IHttpResponder _responder; private readonly IHttpResponder _responder;
private readonly IErrorsToHttpStatusCodeMapper _codeMapper; private readonly IErrorsToHttpStatusCodeMapper _codeMapper;
public ResponderMiddleware(OcelotRequestDelegate next, public ResponderMiddleware(OcelotRequestDelegate next,
IHttpResponder responder, IHttpResponder responder,
IOcelotLoggerFactory loggerFactory, IOcelotLoggerFactory loggerFactory,
IErrorsToHttpStatusCodeMapper codeMapper IErrorsToHttpStatusCodeMapper codeMapper
) )
:base(loggerFactory.CreateLogger<ResponderMiddleware>()) :base(loggerFactory.CreateLogger<ResponderMiddleware>())
{ {
_next = next; _next = next;
_responder = responder; _responder = responder;
_codeMapper = codeMapper; _codeMapper = codeMapper;
} }
public async Task Invoke(DownstreamContext context) public async Task Invoke(DownstreamContext context)
{ {
await _next.Invoke(context); await _next.Invoke(context);
if (context.IsError) if (context.IsError)
{ {
Logger.LogWarning($"{context.Errors.ToErrorString()} errors found in {MiddlewareName}. Setting error response for request path:{context.HttpContext.Request.Path}, request method: {context.HttpContext.Request.Method}"); Logger.LogWarning($"{context.Errors.ToErrorString()} errors found in {MiddlewareName}. Setting error response for request path:{context.HttpContext.Request.Path}, request method: {context.HttpContext.Request.Method}");
SetErrorResponse(context.HttpContext, context.Errors); SetErrorResponse(context.HttpContext, context.Errors);
} }
else else
{ {
Logger.LogDebug("no pipeline errors, setting and returning completed response"); Logger.LogDebug("no pipeline errors, setting and returning completed response");
await _responder.SetResponseOnHttpContext(context.HttpContext, context.DownstreamResponse); await _responder.SetResponseOnHttpContext(context.HttpContext, context.DownstreamResponse);
} }
} }
private void SetErrorResponse(HttpContext context, List<Error> errors) private void SetErrorResponse(HttpContext context, List<Error> errors)
{ {
var statusCode = _codeMapper.Map(errors); var statusCode = _codeMapper.Map(errors);
_responder.SetErrorResponseOnContext(context, statusCode); _responder.SetErrorResponseOnContext(context, statusCode);
} }
} }
} }

View File

@ -1,197 +1,197 @@
namespace Ocelot.UnitTests.Errors namespace Ocelot.UnitTests.Errors
{ {
using System; using System;
using System.Net; using System.Net;
using System.Threading.Tasks; using System.Threading.Tasks;
using Ocelot.Errors.Middleware; using Ocelot.Errors.Middleware;
using Ocelot.Logging; using Ocelot.Logging;
using Shouldly; using Shouldly;
using TestStack.BDDfy; using TestStack.BDDfy;
using Xunit; using Xunit;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Moq; using Moq;
using Ocelot.Configuration; using Ocelot.Configuration;
using Ocelot.Errors; using Ocelot.Errors;
using Ocelot.Infrastructure.RequestData; using Ocelot.Infrastructure.RequestData;
using Ocelot.Middleware; using Ocelot.Middleware;
using Ocelot.Configuration.Repository; using Ocelot.Configuration.Repository;
public class ExceptionHandlerMiddlewareTests public class ExceptionHandlerMiddlewareTests
{ {
bool _shouldThrowAnException; bool _shouldThrowAnException;
private readonly Mock<IInternalConfigurationRepository> _configRepo; private readonly Mock<IInternalConfigurationRepository> _configRepo;
private readonly Mock<IRequestScopedDataRepository> _repo; private readonly Mock<IRequestScopedDataRepository> _repo;
private Mock<IOcelotLoggerFactory> _loggerFactory; private Mock<IOcelotLoggerFactory> _loggerFactory;
private Mock<IOcelotLogger> _logger; private Mock<IOcelotLogger> _logger;
private readonly ExceptionHandlerMiddleware _middleware; private readonly ExceptionHandlerMiddleware _middleware;
private readonly DownstreamContext _downstreamContext; private readonly DownstreamContext _downstreamContext;
private OcelotRequestDelegate _next; private OcelotRequestDelegate _next;
public ExceptionHandlerMiddlewareTests() public ExceptionHandlerMiddlewareTests()
{ {
_configRepo = new Mock<IInternalConfigurationRepository>(); _configRepo = new Mock<IInternalConfigurationRepository>();
_repo = new Mock<IRequestScopedDataRepository>(); _repo = new Mock<IRequestScopedDataRepository>();
_downstreamContext = new DownstreamContext(new DefaultHttpContext()); _downstreamContext = new DownstreamContext(new DefaultHttpContext());
_loggerFactory = new Mock<IOcelotLoggerFactory>(); _loggerFactory = new Mock<IOcelotLoggerFactory>();
_logger = new Mock<IOcelotLogger>(); _logger = new Mock<IOcelotLogger>();
_loggerFactory.Setup(x => x.CreateLogger<ExceptionHandlerMiddleware>()).Returns(_logger.Object); _loggerFactory.Setup(x => x.CreateLogger<ExceptionHandlerMiddleware>()).Returns(_logger.Object);
_next = async context => { _next = async context => {
await Task.CompletedTask; await Task.CompletedTask;
if (_shouldThrowAnException) if (_shouldThrowAnException)
{ {
throw new Exception("BOOM"); throw new Exception("BOOM");
} }
context.HttpContext.Response.StatusCode = (int)HttpStatusCode.OK; context.HttpContext.Response.StatusCode = (int)HttpStatusCode.OK;
}; };
_middleware = new ExceptionHandlerMiddleware(_next, _loggerFactory.Object, _configRepo.Object, _repo.Object); _middleware = new ExceptionHandlerMiddleware(_next, _loggerFactory.Object, _configRepo.Object, _repo.Object);
} }
[Fact] [Fact]
public void NoDownstreamException() public void NoDownstreamException()
{ {
var config = new InternalConfiguration(null, null, null, null, null, null, null, null); var config = new InternalConfiguration(null, null, null, null, null, null, null, null);
this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream())
.And(_ => GivenTheConfigurationIs(config)) .And(_ => GivenTheConfigurationIs(config))
.When(_ => WhenICallTheMiddleware()) .When(_ => WhenICallTheMiddleware())
.Then(_ => ThenTheResponseIsOk()) .Then(_ => ThenTheResponseIsOk())
.And(_ => TheAspDotnetRequestIdIsSet()) .And(_ => TheAspDotnetRequestIdIsSet())
.BDDfy(); .BDDfy();
} }
[Fact] [Fact]
public void DownstreamException() public void DownstreamException()
{ {
var config = new InternalConfiguration(null, null, null, null, null, null, null, null); var config = new InternalConfiguration(null, null, null, null, null, null, null, null);
this.Given(_ => GivenAnExceptionWillBeThrownDownstream()) this.Given(_ => GivenAnExceptionWillBeThrownDownstream())
.And(_ => GivenTheConfigurationIs(config)) .And(_ => GivenTheConfigurationIs(config))
.When(_ => WhenICallTheMiddleware()) .When(_ => WhenICallTheMiddleware())
.Then(_ => ThenTheResponseIsError()) .Then(_ => ThenTheResponseIsError())
.BDDfy(); .BDDfy();
} }
[Fact] [Fact]
public void ShouldSetRequestId() public void ShouldSetRequestId()
{ {
var config = new InternalConfiguration(null, null, null, "requestidkey", null, null, null, null); var config = new InternalConfiguration(null, null, null, "requestidkey", null, null, null, null);
this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream())
.And(_ => GivenTheConfigurationIs(config)) .And(_ => GivenTheConfigurationIs(config))
.When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234")) .When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234"))
.Then(_ => ThenTheResponseIsOk()) .Then(_ => ThenTheResponseIsOk())
.And(_ => TheRequestIdIsSet("RequestId", "1234")) .And(_ => TheRequestIdIsSet("RequestId", "1234"))
.BDDfy(); .BDDfy();
} }
[Fact] [Fact]
public void ShouldSetAspDotNetRequestId() public void ShouldSetAspDotNetRequestId()
{ {
var config = new InternalConfiguration(null, null, null, null, null, null, null, null); var config = new InternalConfiguration(null, null, null, null, null, null, null, null);
this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream())
.And(_ => GivenTheConfigurationIs(config)) .And(_ => GivenTheConfigurationIs(config))
.When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234")) .When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234"))
.Then(_ => ThenTheResponseIsOk()) .Then(_ => ThenTheResponseIsOk())
.And(_ => TheAspDotnetRequestIdIsSet()) .And(_ => TheAspDotnetRequestIdIsSet())
.BDDfy(); .BDDfy();
} }
[Fact] [Fact]
public void should_throw_exception_if_config_provider_returns_error() public void should_throw_exception_if_config_provider_returns_error()
{ {
this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream())
.And(_ => GivenTheConfigReturnsError()) .And(_ => GivenTheConfigReturnsError())
.When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234")) .When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234"))
.Then(_ => ThenAnExceptionIsThrown()) .Then(_ => ThenAnExceptionIsThrown())
.BDDfy(); .BDDfy();
} }
[Fact] [Fact]
public void should_throw_exception_if_config_provider_throws() public void should_throw_exception_if_config_provider_throws()
{ {
this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream())
.And(_ => GivenTheConfigThrows()) .And(_ => GivenTheConfigThrows())
.When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234")) .When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234"))
.Then(_ => ThenAnExceptionIsThrown()) .Then(_ => ThenAnExceptionIsThrown())
.BDDfy(); .BDDfy();
} }
private void WhenICallTheMiddlewareWithTheRequestIdKey(string key, string value) private void WhenICallTheMiddlewareWithTheRequestIdKey(string key, string value)
{ {
_downstreamContext.HttpContext.Request.Headers.Add(key, value); _downstreamContext.HttpContext.Request.Headers.Add(key, value);
_middleware.Invoke(_downstreamContext).GetAwaiter().GetResult(); _middleware.Invoke(_downstreamContext).GetAwaiter().GetResult();
} }
private void WhenICallTheMiddleware() private void WhenICallTheMiddleware()
{ {
_middleware.Invoke(_downstreamContext).GetAwaiter().GetResult(); _middleware.Invoke(_downstreamContext).GetAwaiter().GetResult();
} }
private void GivenTheConfigThrows() private void GivenTheConfigThrows()
{ {
var ex = new Exception("outer", new Exception("inner")); var ex = new Exception("outer", new Exception("inner"));
_configRepo _configRepo
.Setup(x => x.Get()).Throws(ex); .Setup(x => x.Get()).Throws(ex);
} }
private void ThenAnExceptionIsThrown() private void ThenAnExceptionIsThrown()
{ {
_downstreamContext.HttpContext.Response.StatusCode.ShouldBe(500); _downstreamContext.HttpContext.Response.StatusCode.ShouldBe(500);
} }
private void GivenTheConfigReturnsError() private void GivenTheConfigReturnsError()
{ {
var response = new Responses.ErrorResponse<IInternalConfiguration>(new FakeError()); var response = new Responses.ErrorResponse<IInternalConfiguration>(new FakeError());
_configRepo _configRepo
.Setup(x => x.Get()).Returns(response); .Setup(x => x.Get()).Returns(response);
} }
private void TheRequestIdIsSet(string key, string value) private void TheRequestIdIsSet(string key, string value)
{ {
_repo.Verify(x => x.Add(key, value), Times.Once); _repo.Verify(x => x.Add(key, value), Times.Once);
} }
private void GivenTheConfigurationIs(IInternalConfiguration config) private void GivenTheConfigurationIs(IInternalConfiguration config)
{ {
var response = new Responses.OkResponse<IInternalConfiguration>(config); var response = new Responses.OkResponse<IInternalConfiguration>(config);
_configRepo _configRepo
.Setup(x => x.Get()).Returns(response); .Setup(x => x.Get()).Returns(response);
} }
private void GivenAnExceptionWillNotBeThrownDownstream() private void GivenAnExceptionWillNotBeThrownDownstream()
{ {
_shouldThrowAnException = false; _shouldThrowAnException = false;
} }
private void GivenAnExceptionWillBeThrownDownstream() private void GivenAnExceptionWillBeThrownDownstream()
{ {
_shouldThrowAnException = true; _shouldThrowAnException = true;
} }
private void ThenTheResponseIsOk() private void ThenTheResponseIsOk()
{ {
_downstreamContext.HttpContext.Response.StatusCode.ShouldBe(200); _downstreamContext.HttpContext.Response.StatusCode.ShouldBe(200);
} }
private void ThenTheResponseIsError() private void ThenTheResponseIsError()
{ {
_downstreamContext.HttpContext.Response.StatusCode.ShouldBe(500); _downstreamContext.HttpContext.Response.StatusCode.ShouldBe(500);
} }
private void TheAspDotnetRequestIdIsSet() private void TheAspDotnetRequestIdIsSet()
{ {
_repo.Verify(x => x.Add(It.IsAny<string>(), It.IsAny<string>()), Times.Once); _repo.Verify(x => x.Add(It.IsAny<string>(), It.IsAny<string>()), Times.Once);
} }
class FakeError : Error class FakeError : Error
{ {
internal FakeError() internal FakeError()
: base("meh", OcelotErrorCode.CannotAddDataError) : base("meh", OcelotErrorCode.CannotAddDataError)
{ {
} }
} }
} }
} }