Merge pull request #24 from TomPallister/develop

mergeNewestcode
This commit is contained in:
geffzhang 2017-04-24 07:16:26 +08:00 committed by GitHub
commit c76fa1e5fb
43 changed files with 1054 additions and 922 deletions

View File

@ -27,14 +27,14 @@ namespace Ocelot.Cache.Middleware
public async Task Invoke(HttpContext context) public async Task Invoke(HttpContext context)
{ {
var downstreamUrlKey = DownstreamUrl;
if (!DownstreamRoute.ReRoute.IsCached) if (!DownstreamRoute.ReRoute.IsCached)
{ {
await _next.Invoke(context); await _next.Invoke(context);
return; return;
} }
var downstreamUrlKey = DownstreamRequest.RequestUri.OriginalString;
_logger.LogDebug("started checking cache for {downstreamUrlKey}", downstreamUrlKey); _logger.LogDebug("started checking cache for {downstreamUrlKey}", downstreamUrlKey);
var cached = _outputCache.Get(downstreamUrlKey); var cached = _outputCache.Get(downstreamUrlKey);

View File

@ -31,6 +31,7 @@ using Ocelot.Middleware;
using Ocelot.QueryStrings; using Ocelot.QueryStrings;
using Ocelot.RateLimit; using Ocelot.RateLimit;
using Ocelot.Request.Builder; using Ocelot.Request.Builder;
using Ocelot.Request.Mapper;
using Ocelot.Requester; using Ocelot.Requester;
using Ocelot.Requester.QoS; using Ocelot.Requester.QoS;
using Ocelot.Responder; using Ocelot.Responder;
@ -160,6 +161,7 @@ namespace Ocelot.DependencyInjection
services.TryAddSingleton<IAuthenticationHandlerCreator, AuthenticationHandlerCreator>(); services.TryAddSingleton<IAuthenticationHandlerCreator, AuthenticationHandlerCreator>();
services.TryAddSingleton<IRateLimitCounterHandler, MemoryCacheRateLimitCounterHandler>(); services.TryAddSingleton<IRateLimitCounterHandler, MemoryCacheRateLimitCounterHandler>();
services.TryAddSingleton<IHttpClientCache, MemoryHttpClientCache>(); services.TryAddSingleton<IHttpClientCache, MemoryHttpClientCache>();
services.TryAddSingleton<IRequestMapper, RequestMapper>();
// see this for why we register this as singleton http://stackoverflow.com/questions/37371264/invalidoperationexception-unable-to-resolve-service-for-type-microsoft-aspnetc // see this for why we register this as singleton http://stackoverflow.com/questions/37371264/invalidoperationexception-unable-to-resolve-service-for-type-microsoft-aspnetc
// could maybe use a scoped data repository // could maybe use a scoped data repository

View File

@ -19,7 +19,7 @@ namespace Ocelot.DownstreamRouteFinder.UrlMatcher
{ {
var variableName = GetPlaceholderVariableName(upstreamUrlPathTemplate, counterForTemplate); var variableName = GetPlaceholderVariableName(upstreamUrlPathTemplate, counterForTemplate);
var variableValue = GetPlaceholderVariableValue(upstreamUrlPath, counterForUrl); var variableValue = GetPlaceholderVariableValue(upstreamUrlPathTemplate, variableName, upstreamUrlPath, counterForUrl);
var templateVariableNameAndValue = new UrlPathPlaceholderNameAndValue(variableName, variableValue); var templateVariableNameAndValue = new UrlPathPlaceholderNameAndValue(variableName, variableValue);
@ -40,11 +40,11 @@ namespace Ocelot.DownstreamRouteFinder.UrlMatcher
return new OkResponse<List<UrlPathPlaceholderNameAndValue>>(templateKeysAndValues); return new OkResponse<List<UrlPathPlaceholderNameAndValue>>(templateKeysAndValues);
} }
private string GetPlaceholderVariableValue(string urlPath, int counterForUrl) private string GetPlaceholderVariableValue(string urlPathTemplate, string variableName, string urlPath, int counterForUrl)
{ {
var positionOfNextSlash = urlPath.IndexOf('/', counterForUrl); var positionOfNextSlash = urlPath.IndexOf('/', counterForUrl);
if(positionOfNextSlash == -1) if (positionOfNextSlash == -1 || urlPathTemplate.Trim('/').EndsWith(variableName))
{ {
positionOfNextSlash = urlPath.Length; positionOfNextSlash = urlPath.Length;
} }

View File

@ -4,6 +4,7 @@ using Ocelot.DownstreamUrlCreator.UrlTemplateReplacer;
using Ocelot.Infrastructure.RequestData; using Ocelot.Infrastructure.RequestData;
using Ocelot.Logging; using Ocelot.Logging;
using Ocelot.Middleware; using Ocelot.Middleware;
using System;
namespace Ocelot.DownstreamUrlCreator.Middleware namespace Ocelot.DownstreamUrlCreator.Middleware
{ {
@ -42,23 +43,15 @@ namespace Ocelot.DownstreamUrlCreator.Middleware
return; return;
} }
var dsScheme = DownstreamRoute.ReRoute.DownstreamScheme; var uriBuilder = new UriBuilder(DownstreamRequest.RequestUri)
var dsHostAndPort = HostAndPort;
var dsUrl = _urlBuilder.Build(dsPath.Data.Value, dsScheme, dsHostAndPort);
if (dsUrl.IsError)
{ {
_logger.LogDebug("IUrlBuilder returned an error, setting pipeline error"); Path = dsPath.Data.Value,
Scheme = DownstreamRoute.ReRoute.DownstreamScheme
};
SetPipelineError(dsUrl.Errors); DownstreamRequest.RequestUri = uriBuilder.Uri;
return;
}
_logger.LogDebug("downstream url is {downstreamUrl.Data.Value}", dsUrl.Data.Value); _logger.LogDebug("downstream url is {downstreamUrl.Data.Value}", DownstreamRequest.RequestUri);
SetDownstreamUrlForThisRequest(dsUrl.Data.Value);
_logger.LogDebug("calling next middleware"); _logger.LogDebug("calling next middleware");

View File

@ -25,6 +25,7 @@ namespace Ocelot.DownstreamUrlCreator
return new ErrorResponse<DownstreamUrl>(new List<Error> { new DownstreamHostNullOrEmptyError() }); return new ErrorResponse<DownstreamUrl>(new List<Error> { new DownstreamHostNullOrEmptyError() });
} }
var builder = new UriBuilder var builder = new UriBuilder
{ {
Host = downstreamHostAndPort.DownstreamHost, Host = downstreamHostAndPort.DownstreamHost,

View File

@ -28,6 +28,7 @@
UnableToFindLoadBalancerError, UnableToFindLoadBalancerError,
RequestTimedOutError, RequestTimedOutError,
UnableToFindQoSProviderError, UnableToFindQoSProviderError,
UnableToSetConfigInConsulError UnableToSetConfigInConsulError,
UnmappableRequestError
} }
} }

View File

@ -1,10 +1,9 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives;
using Ocelot.Configuration; using Ocelot.Configuration;
using Ocelot.Infrastructure.Claims.Parser; using Ocelot.Infrastructure.Claims.Parser;
using Ocelot.Responses; using Ocelot.Responses;
using System.Net.Http;
namespace Ocelot.Headers namespace Ocelot.Headers
{ {
@ -17,25 +16,25 @@ namespace Ocelot.Headers
_claimsParser = claimsParser; _claimsParser = claimsParser;
} }
public Response SetHeadersOnContext(List<ClaimToThing> claimsToThings, HttpContext context) public Response SetHeadersOnDownstreamRequest(List<ClaimToThing> claimsToThings, IEnumerable<System.Security.Claims.Claim> claims, HttpRequestMessage downstreamRequest)
{ {
foreach (var config in claimsToThings) foreach (var config in claimsToThings)
{ {
var value = _claimsParser.GetValue(context.User.Claims, config.NewKey, config.Delimiter, config.Index); var value = _claimsParser.GetValue(claims, config.NewKey, config.Delimiter, config.Index);
if (value.IsError) if (value.IsError)
{ {
return new ErrorResponse(value.Errors); return new ErrorResponse(value.Errors);
} }
var exists = context.Request.Headers.FirstOrDefault(x => x.Key == config.ExistingKey); var exists = downstreamRequest.Headers.FirstOrDefault(x => x.Key == config.ExistingKey);
if (!string.IsNullOrEmpty(exists.Key)) if (!string.IsNullOrEmpty(exists.Key))
{ {
context.Request.Headers.Remove(exists); downstreamRequest.Headers.Remove(exists.Key);
} }
context.Request.Headers.Add(config.ExistingKey, new StringValues(value.Data)); downstreamRequest.Headers.Add(config.ExistingKey, value.Data);
} }
return new OkResponse(); return new OkResponse();

View File

@ -1,13 +1,13 @@
using System.Collections.Generic; namespace Ocelot.Headers
using Microsoft.AspNetCore.Http; {
using System.Collections.Generic;
using System.Net.Http;
using Ocelot.Configuration; using Ocelot.Configuration;
using Ocelot.Responses; using Ocelot.Responses;
namespace Ocelot.Headers
{
public interface IAddHeadersToRequest public interface IAddHeadersToRequest
{ {
Response SetHeadersOnContext(List<ClaimToThing> claimsToThings, Response SetHeadersOnDownstreamRequest(List<ClaimToThing> claimsToThings, IEnumerable<System.Security.Claims.Claim> claims, HttpRequestMessage downstreamRequest);
HttpContext context);
} }
} }

View File

@ -32,7 +32,7 @@ namespace Ocelot.Headers.Middleware
{ {
_logger.LogDebug("this route has instructions to convert claims to headers"); _logger.LogDebug("this route has instructions to convert claims to headers");
var response = _addHeadersToRequest.SetHeadersOnContext(DownstreamRoute.ReRoute.ClaimsToHeaders, context); var response = _addHeadersToRequest.SetHeadersOnDownstreamRequest(DownstreamRoute.ReRoute.ClaimsToHeaders, context.User.Claims, DownstreamRequest);
if (response.IsError) if (response.IsError)
{ {

View File

@ -44,7 +44,13 @@ namespace Ocelot.LoadBalancer.Middleware
return; return;
} }
SetHostAndPortForThisRequest(hostAndPort.Data); var uriBuilder = new UriBuilder(DownstreamRequest.RequestUri);
uriBuilder.Host = hostAndPort.Data.DownstreamHost;
if (hostAndPort.Data.DownstreamPort > 0)
{
uriBuilder.Port = hostAndPort.Data.DownstreamPort;
}
DownstreamRequest.RequestUri = uriBuilder.Uri;
_logger.LogDebug("calling next middleware"); _logger.LogDebug("calling next middleware");

View File

@ -3,7 +3,6 @@ using System.Net.Http;
using Ocelot.DownstreamRouteFinder; using Ocelot.DownstreamRouteFinder;
using Ocelot.Errors; using Ocelot.Errors;
using Ocelot.Infrastructure.RequestData; using Ocelot.Infrastructure.RequestData;
using Ocelot.Values;
namespace Ocelot.Middleware namespace Ocelot.Middleware
{ {
@ -19,89 +18,33 @@ namespace Ocelot.Middleware
public string MiddlwareName { get; } public string MiddlwareName { get; }
public bool PipelineError public bool PipelineError => _requestScopedDataRepository.Get<bool>("OcelotMiddlewareError").Data;
{
get
{
var response = _requestScopedDataRepository.Get<bool>("OcelotMiddlewareError");
return response.Data;
}
}
public List<Error> PipelineErrors public List<Error> PipelineErrors => _requestScopedDataRepository.Get<List<Error>>("OcelotMiddlewareErrors").Data;
{
get
{
var response = _requestScopedDataRepository.Get<List<Error>>("OcelotMiddlewareErrors");
return response.Data;
}
}
public DownstreamRoute DownstreamRoute public DownstreamRoute DownstreamRoute => _requestScopedDataRepository.Get<DownstreamRoute>("DownstreamRoute").Data;
{
get
{
var downstreamRoute = _requestScopedDataRepository.Get<DownstreamRoute>("DownstreamRoute");
return downstreamRoute.Data;
}
}
public string DownstreamUrl public Request.Request Request => _requestScopedDataRepository.Get<Request.Request>("Request").Data;
{
get
{
var downstreamUrl = _requestScopedDataRepository.Get<string>("DownstreamUrl");
return downstreamUrl.Data;
}
}
public Request.Request Request public HttpRequestMessage DownstreamRequest => _requestScopedDataRepository.Get<HttpRequestMessage>("DownstreamRequest").Data;
{
get
{
var request = _requestScopedDataRepository.Get<Request.Request>("Request");
return request.Data;
}
}
public HttpResponseMessage HttpResponseMessage public HttpResponseMessage HttpResponseMessage => _requestScopedDataRepository.Get<HttpResponseMessage>("HttpResponseMessage").Data;
{
get
{
var request = _requestScopedDataRepository.Get<HttpResponseMessage>("HttpResponseMessage");
return request.Data;
}
}
public HostAndPort HostAndPort
{
get
{
var hostAndPort = _requestScopedDataRepository.Get<HostAndPort>("HostAndPort");
return hostAndPort.Data;
}
}
public void SetHostAndPortForThisRequest(HostAndPort hostAndPort)
{
_requestScopedDataRepository.Add("HostAndPort", hostAndPort);
}
public void SetDownstreamRouteForThisRequest(DownstreamRoute downstreamRoute) public void SetDownstreamRouteForThisRequest(DownstreamRoute downstreamRoute)
{ {
_requestScopedDataRepository.Add("DownstreamRoute", downstreamRoute); _requestScopedDataRepository.Add("DownstreamRoute", downstreamRoute);
} }
public void SetDownstreamUrlForThisRequest(string downstreamUrl)
{
_requestScopedDataRepository.Add("DownstreamUrl", downstreamUrl);
}
public void SetUpstreamRequestForThisRequest(Request.Request request) public void SetUpstreamRequestForThisRequest(Request.Request request)
{ {
_requestScopedDataRepository.Add("Request", request); _requestScopedDataRepository.Add("Request", request);
} }
public void SetDownstreamRequest(HttpRequestMessage request)
{
_requestScopedDataRepository.Add("DownstreamRequest", request);
}
public void SetHttpResponseMessageThisRequest(HttpResponseMessage responseMessage) public void SetHttpResponseMessageThisRequest(HttpResponseMessage responseMessage)
{ {
_requestScopedDataRepository.Add("HttpResponseMessage", responseMessage); _requestScopedDataRepository.Add("HttpResponseMessage", responseMessage);

View File

@ -62,6 +62,9 @@ namespace Ocelot.Middleware
// This is registered first so it can catch any errors and issue an appropriate response // This is registered first so it can catch any errors and issue an appropriate response
builder.UseResponderMiddleware(); builder.UseResponderMiddleware();
// Initialises downstream request
builder.UseDownstreamRequestInitialiser();
// Then we get the downstream route information // Then we get the downstream route information
builder.UseDownstreamRouteFinderMiddleware(); builder.UseDownstreamRouteFinderMiddleware();

View File

@ -4,6 +4,9 @@ using Microsoft.AspNetCore.Http;
using Ocelot.Configuration; using Ocelot.Configuration;
using Ocelot.Infrastructure.Claims.Parser; using Ocelot.Infrastructure.Claims.Parser;
using Ocelot.Responses; using Ocelot.Responses;
using System.Security.Claims;
using System.Net.Http;
using System;
namespace Ocelot.QueryStrings namespace Ocelot.QueryStrings
{ {
@ -16,13 +19,13 @@ namespace Ocelot.QueryStrings
_claimsParser = claimsParser; _claimsParser = claimsParser;
} }
public Response SetQueriesOnContext(List<ClaimToThing> claimsToThings, HttpContext context) public Response SetQueriesOnDownstreamRequest(List<ClaimToThing> claimsToThings, IEnumerable<Claim> claims, HttpRequestMessage downstreamRequest)
{ {
var queryDictionary = ConvertQueryStringToDictionary(context); var queryDictionary = ConvertQueryStringToDictionary(downstreamRequest.RequestUri.Query);
foreach (var config in claimsToThings) foreach (var config in claimsToThings)
{ {
var value = _claimsParser.GetValue(context.User.Claims, config.NewKey, config.Delimiter, config.Index); var value = _claimsParser.GetValue(claims, config.NewKey, config.Delimiter, config.Index);
if (value.IsError) if (value.IsError)
{ {
@ -41,22 +44,24 @@ namespace Ocelot.QueryStrings
} }
} }
context.Request.QueryString = ConvertDictionaryToQueryString(queryDictionary); var uriBuilder = new UriBuilder(downstreamRequest.RequestUri);
uriBuilder.Query = ConvertDictionaryToQueryString(queryDictionary);
downstreamRequest.RequestUri = uriBuilder.Uri;
return new OkResponse(); return new OkResponse();
} }
private Dictionary<string, string> ConvertQueryStringToDictionary(HttpContext context) private Dictionary<string, string> ConvertQueryStringToDictionary(string queryString)
{ {
return Microsoft.AspNetCore.WebUtilities.QueryHelpers.ParseQuery(context.Request.QueryString.Value) return Microsoft.AspNetCore.WebUtilities.QueryHelpers
.ParseQuery(queryString)
.ToDictionary(q => q.Key, q => q.Value.FirstOrDefault() ?? string.Empty); .ToDictionary(q => q.Key, q => q.Value.FirstOrDefault() ?? string.Empty);
} }
private Microsoft.AspNetCore.Http.QueryString ConvertDictionaryToQueryString(Dictionary<string, string> queryDictionary) private string ConvertDictionaryToQueryString(Dictionary<string, string> queryDictionary)
{ {
var newQueryString = Microsoft.AspNetCore.WebUtilities.QueryHelpers.AddQueryString("", queryDictionary); return Microsoft.AspNetCore.WebUtilities.QueryHelpers.AddQueryString("", queryDictionary);
return new Microsoft.AspNetCore.Http.QueryString(newQueryString);
} }
} }
} }

View File

@ -2,12 +2,13 @@
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Ocelot.Configuration; using Ocelot.Configuration;
using Ocelot.Responses; using Ocelot.Responses;
using System.Net.Http;
using System.Security.Claims;
namespace Ocelot.QueryStrings namespace Ocelot.QueryStrings
{ {
public interface IAddQueriesToRequest public interface IAddQueriesToRequest
{ {
Response SetQueriesOnContext(List<ClaimToThing> claimsToThings, Response SetQueriesOnDownstreamRequest(List<ClaimToThing> claimsToThings, IEnumerable<Claim> claims, HttpRequestMessage downstreamRequest);
HttpContext context);
} }
} }

View File

@ -32,7 +32,7 @@ namespace Ocelot.QueryStrings.Middleware
{ {
_logger.LogDebug("this route has instructions to convert claims to queries"); _logger.LogDebug("this route has instructions to convert claims to queries");
var response = _addQueriesToRequest.SetQueriesOnContext(DownstreamRoute.ReRoute.ClaimsToQueries, context); var response = _addQueriesToRequest.SetQueriesOnDownstreamRequest(DownstreamRoute.ReRoute.ClaimsToQueries, context.User.Claims, DownstreamRequest);
if (response.IsError) if (response.IsError)
{ {

View File

@ -111,11 +111,7 @@ namespace Ocelot.RateLimit.Middleware
public bool IsWhitelisted(ClientRequestIdentity requestIdentity, RateLimitOptions option) public bool IsWhitelisted(ClientRequestIdentity requestIdentity, RateLimitOptions option)
{ {
if (option.ClientWhitelist.Contains(requestIdentity.ClientId)) return option.ClientWhitelist.Contains(requestIdentity.ClientId);
{
return true;
}
return false;
} }
public virtual void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule) public virtual void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule)

View File

@ -1,38 +1,18 @@
using System.IO; using System.Threading.Tasks;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Ocelot.Responses; using Ocelot.Responses;
using Ocelot.Configuration;
using Ocelot.Requester.QoS; using Ocelot.Requester.QoS;
using System.Net.Http;
namespace Ocelot.Request.Builder namespace Ocelot.Request.Builder
{ {
public sealed class HttpRequestCreator : IRequestCreator public sealed class HttpRequestCreator : IRequestCreator
{ {
public async Task<Response<Request>> Build( public async Task<Response<Request>> Build(
string httpMethod, HttpRequestMessage httpRequestMessage,
string downstreamUrl,
Stream content,
IHeaderDictionary headers,
QueryString queryString,
string contentType,
RequestId.RequestId requestId,
bool isQos, bool isQos,
IQoSProvider qosProvider) IQoSProvider qosProvider)
{ {
var request = await new RequestBuilder() return new OkResponse<Request>(new Request(httpRequestMessage, isQos, qosProvider));
.WithHttpMethod(httpMethod)
.WithDownstreamUrl(downstreamUrl)
.WithQueryString(queryString)
.WithContent(content)
.WithContentType(contentType)
.WithHeaders(headers)
.WithRequestId(requestId)
.WithIsQos(isQos)
.WithQos(qosProvider)
.Build();
return new OkResponse<Request>(request);
} }
} }
} }

View File

@ -1,20 +1,15 @@
using System.IO; namespace Ocelot.Request.Builder
{
using System.Net.Http;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Ocelot.Requester.QoS; using Ocelot.Requester.QoS;
using Ocelot.Responses; using Ocelot.Responses;
namespace Ocelot.Request.Builder
{
public interface IRequestCreator public interface IRequestCreator
{ {
Task<Response<Request>> Build(string httpMethod, Task<Response<Request>> Build(
string downstreamUrl, HttpRequestMessage httpRequestMessage,
Stream content,
IHeaderDictionary headers,
QueryString queryString,
string contentType,
RequestId.RequestId requestId,
bool isQos, bool isQos,
IQoSProvider qosProvider); IQoSProvider qosProvider);
} }

View File

@ -1,177 +0,0 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives;
using Ocelot.Requester.QoS;
namespace Ocelot.Request.Builder
{
internal sealed class RequestBuilder
{
private HttpMethod _method;
private string _downstreamUrl;
private QueryString _queryString;
private Stream _content;
private string _contentType;
private IHeaderDictionary _headers;
private RequestId.RequestId _requestId;
private readonly string[] _unsupportedHeaders = {"host"};
private bool _isQos;
private IQoSProvider _qoSProvider;
public RequestBuilder WithHttpMethod(string httpMethod)
{
_method = new HttpMethod(httpMethod);
return this;
}
public RequestBuilder WithDownstreamUrl(string downstreamUrl)
{
_downstreamUrl = downstreamUrl;
return this;
}
public RequestBuilder WithQueryString(QueryString queryString)
{
_queryString = queryString;
return this;
}
public RequestBuilder WithContent(Stream content)
{
_content = content;
return this;
}
public RequestBuilder WithContentType(string contentType)
{
_contentType = contentType;
return this;
}
public RequestBuilder WithHeaders(IHeaderDictionary headers)
{
_headers = headers;
return this;
}
public RequestBuilder WithRequestId(RequestId.RequestId requestId)
{
_requestId = requestId;
return this;
}
public RequestBuilder WithIsQos(bool isqos)
{
_isQos = isqos;
return this;
}
public RequestBuilder WithQos(IQoSProvider qoSProvider)
{
_qoSProvider = qoSProvider;
return this;
}
public async Task<Request> Build()
{
var uri = CreateUri();
var httpRequestMessage = new HttpRequestMessage(_method, uri);
await AddContentToRequest(httpRequestMessage);
AddContentTypeToRequest(httpRequestMessage);
AddHeadersToRequest(httpRequestMessage);
if (ShouldAddRequestId(_requestId, httpRequestMessage.Headers))
{
AddRequestIdHeader(_requestId, httpRequestMessage);
}
return new Request(httpRequestMessage,_isQos, _qoSProvider);
}
private Uri CreateUri()
{
var uri = new Uri(string.Format("{0}{1}", _downstreamUrl, _queryString.ToUriComponent()));
return uri;
}
private async Task AddContentToRequest(HttpRequestMessage httpRequestMessage)
{
if (_content != null)
{
httpRequestMessage.Content = new ByteArrayContent(await ToByteArray(_content));
}
}
private void AddContentTypeToRequest(HttpRequestMessage httpRequestMessage)
{
if (!string.IsNullOrEmpty(_contentType))
{
httpRequestMessage.Content.Headers.Remove("Content-Type");
httpRequestMessage.Content.Headers.TryAddWithoutValidation("Content-Type", _contentType);
}
}
private void AddHeadersToRequest(HttpRequestMessage httpRequestMessage)
{
if (_headers != null)
{
_headers.Remove("Content-Type");
foreach (var header in _headers)
{
//todo get rid of if..
if (IsSupportedHeader(header))
{
httpRequestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
}
}
}
}
private bool IsSupportedHeader(KeyValuePair<string, StringValues> header)
{
return !_unsupportedHeaders.Contains(header.Key.ToLower());
}
private void AddRequestIdHeader(RequestId.RequestId requestId, HttpRequestMessage httpRequestMessage)
{
httpRequestMessage.Headers.Add(requestId.RequestIdKey, requestId.RequestIdValue);
}
private bool RequestIdInHeaders(RequestId.RequestId requestId, HttpRequestHeaders headers)
{
IEnumerable<string> value;
return headers.TryGetValues(requestId.RequestIdKey, out value);
}
private bool ShouldAddRequestId(RequestId.RequestId requestId, HttpRequestHeaders headers)
{
return !string.IsNullOrEmpty(requestId?.RequestIdKey)
&& !string.IsNullOrEmpty(requestId.RequestIdValue)
&& !RequestIdInHeaders(requestId, headers);
}
private async Task<byte[]> ToByteArray(Stream stream)
{
using (stream)
{
using (var memStream = new MemoryStream())
{
await stream.CopyToAsync(memStream);
return memStream.ToArray();
}
}
}
}
}

View File

@ -0,0 +1,12 @@
namespace Ocelot.Request.Mapper
{
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Ocelot.Responses;
public interface IRequestMapper
{
Task<Response<HttpRequestMessage>> Map(HttpRequest request);
}
}

View File

@ -0,0 +1,90 @@
namespace Ocelot.Request.Mapper
{
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Extensions;
using Microsoft.Extensions.Primitives;
using Ocelot.Responses;
public class RequestMapper : IRequestMapper
{
private readonly string[] _unsupportedHeaders = { "host" };
public async Task<Response<HttpRequestMessage>> Map(HttpRequest request)
{
try
{
var requestMessage = new HttpRequestMessage()
{
Content = await MapContent(request),
Method = MapMethod(request),
RequestUri = MapUri(request)
};
MapHeaders(request, requestMessage);
return new OkResponse<HttpRequestMessage>(requestMessage);
}
catch (Exception ex)
{
return new ErrorResponse<HttpRequestMessage>(new UnmappableRequestError(ex));
}
}
private async Task<HttpContent> MapContent(HttpRequest request)
{
if (request.Body == null)
{
return null;
}
return new ByteArrayContent(await ToByteArray(request.Body));
}
private HttpMethod MapMethod(HttpRequest request)
{
return new HttpMethod(request.Method);
}
private Uri MapUri(HttpRequest request)
{
return new Uri(request.GetEncodedUrl());
}
private void MapHeaders(HttpRequest request, HttpRequestMessage requestMessage)
{
foreach (var header in request.Headers)
{
//todo get rid of if..
if (IsSupportedHeader(header))
{
requestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
}
}
}
private async Task<byte[]> ToByteArray(Stream stream)
{
using (stream)
{
using (var memStream = new MemoryStream())
{
await stream.CopyToAsync(memStream);
return memStream.ToArray();
}
}
}
private bool IsSupportedHeader(KeyValuePair<string, StringValues> header)
{
return !_unsupportedHeaders.Contains(header.Key.ToLower());
}
}
}

View File

@ -0,0 +1,12 @@
namespace Ocelot.Request.Mapper
{
using Ocelot.Errors;
using System;
public class UnmappableRequestError : Error
{
public UnmappableRequestError(Exception ex) : base($"Error when parsing incoming request, exception: {ex.Message}", OcelotErrorCode.UnmappableRequestError)
{
}
}
}

View File

@ -0,0 +1,47 @@
namespace Ocelot.Request.Middleware
{
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Ocelot.Infrastructure.RequestData;
using Ocelot.Logging;
using Ocelot.Middleware;
public class DownstreamRequestInitialiserMiddleware : OcelotMiddleware
{
private readonly RequestDelegate _next;
private readonly IOcelotLogger _logger;
private readonly Mapper.IRequestMapper _requestMapper;
public DownstreamRequestInitialiserMiddleware(RequestDelegate next,
IOcelotLoggerFactory loggerFactory,
IRequestScopedDataRepository requestScopedDataRepository,
Mapper.IRequestMapper requestMapper)
:base(requestScopedDataRepository)
{
_next = next;
_logger = loggerFactory.CreateLogger<DownstreamRequestInitialiserMiddleware>();
_requestMapper = requestMapper;
}
public async Task Invoke(HttpContext context)
{
_logger.LogDebug("started calling request builder middleware");
var downstreamRequest = await _requestMapper.Map(context.Request);
if (downstreamRequest.IsError)
{
SetPipelineError(downstreamRequest.Errors);
return;
}
SetDownstreamRequest(downstreamRequest.Data);
_logger.LogDebug("calling next middleware");
await _next.Invoke(context);
_logger.LogDebug("succesfully called next middleware");
}
}
}

View File

@ -43,14 +43,8 @@ namespace Ocelot.Request.Middleware
return; return;
} }
var buildResult = await _requestCreator var buildResult = await _requestCreator.Build(
.Build(context.Request.Method, DownstreamRequest,
DownstreamUrl,
context.Request.Body,
context.Request.Headers,
context.Request.QueryString,
context.Request.ContentType,
new RequestId.RequestId(DownstreamRoute?.ReRoute?.RequestIdKey, context.TraceIdentifier),
DownstreamRoute.ReRoute.IsQos, DownstreamRoute.ReRoute.IsQos,
qosProvider.Data); qosProvider.Data);

View File

@ -8,5 +8,10 @@ namespace Ocelot.Request.Middleware
{ {
return builder.UseMiddleware<HttpRequestBuilderMiddleware>(); return builder.UseMiddleware<HttpRequestBuilderMiddleware>();
} }
public static IApplicationBuilder UseDownstreamRequestInitialiser(this IApplicationBuilder builder)
{
return builder.UseMiddleware<DownstreamRequestInitialiserMiddleware>();
}
} }
} }

View File

@ -5,6 +5,9 @@ using Microsoft.Extensions.Primitives;
using Ocelot.Infrastructure.RequestData; using Ocelot.Infrastructure.RequestData;
using Ocelot.Logging; using Ocelot.Logging;
using Ocelot.Middleware; using Ocelot.Middleware;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Collections.Generic;
namespace Ocelot.RequestId.Middleware namespace Ocelot.RequestId.Middleware
{ {
@ -30,8 +33,6 @@ namespace Ocelot.RequestId.Middleware
SetOcelotRequestId(context); SetOcelotRequestId(context);
_logger.LogDebug("set requestId");
_logger.TraceInvokeNext(); _logger.TraceInvokeNext();
await _next.Invoke(context); await _next.Invoke(context);
_logger.TraceInvokeNextCompleted(); _logger.TraceInvokeNextCompleted();
@ -40,21 +41,40 @@ namespace Ocelot.RequestId.Middleware
private void SetOcelotRequestId(HttpContext context) private void SetOcelotRequestId(HttpContext context)
{ {
var key = DefaultRequestIdKey.Value; // if get request ID is set on upstream request then retrieve it
var key = DownstreamRoute.ReRoute.RequestIdKey ?? DefaultRequestIdKey.Value;
if (DownstreamRoute.ReRoute.RequestIdKey != null) StringValues upstreamRequestIds;
if (context.Request.Headers.TryGetValue(key, out upstreamRequestIds))
{ {
key = DownstreamRoute.ReRoute.RequestIdKey; context.TraceIdentifier = upstreamRequestIds.First();
} }
StringValues requestId; // set request ID on downstream request, if required
var requestId = new RequestId(DownstreamRoute?.ReRoute?.RequestIdKey, context.TraceIdentifier);
if (context.Request.Headers.TryGetValue(key, out requestId)) if (ShouldAddRequestId(requestId, DownstreamRequest.Headers))
{ {
_requestScopedDataRepository.Add("RequestId", requestId.First()); AddRequestIdHeader(requestId, DownstreamRequest);
}
}
context.TraceIdentifier = requestId; private bool ShouldAddRequestId(RequestId requestId, HttpRequestHeaders headers)
} {
return !string.IsNullOrEmpty(requestId?.RequestIdKey)
&& !string.IsNullOrEmpty(requestId.RequestIdValue)
&& !RequestIdInHeaders(requestId, headers);
}
private bool RequestIdInHeaders(RequestId requestId, HttpRequestHeaders headers)
{
IEnumerable<string> value;
return headers.TryGetValues(requestId.RequestIdKey, out value);
}
private void AddRequestIdHeader(RequestId requestId, HttpRequestMessage httpRequestMessage)
{
httpRequestMessage.Headers.Add(requestId.RequestIdKey, requestId.RequestIdValue);
} }
} }
} }

View File

@ -46,7 +46,6 @@ namespace Ocelot.AcceptanceTests
DownstreamPort = 51879, DownstreamPort = 51879,
UpstreamPathTemplate = "/", UpstreamPathTemplate = "/",
UpstreamHttpMethod = "Get", UpstreamHttpMethod = "Get",
} }
} }
}; };
@ -292,6 +291,34 @@ namespace Ocelot.AcceptanceTests
.BDDfy(); .BDDfy();
} }
[Fact]
public void should_return_response_200_with_placeholder_for_final_url_path()
{
var configuration = new FileConfiguration
{
ReRoutes = new List<FileReRoute>
{
new FileReRoute
{
DownstreamPathTemplate = "/api/{urlPath}",
DownstreamScheme = "http",
DownstreamHost = "localhost",
DownstreamPort = 51879,
UpstreamPathTemplate = "/myApp1Name/api/{urlPath}",
UpstreamHttpMethod = "Get",
}
}
};
this.Given(x => x.GivenThereIsAServiceRunningOn("http://localhost:51879/myApp1Name/api/products/1", 200, "Some Product"))
.And(x => _steps.GivenThereIsAConfiguration(configuration))
.And(x => _steps.GivenOcelotIsRunning())
.When(x => _steps.WhenIGetUrlOnTheApiGateway("/myApp1Name/api/products/1"))
.Then(x => _steps.ThenTheStatusCodeShouldBe(HttpStatusCode.OK))
.And(x => _steps.ThenTheResponseBodyShouldBe("Some Product"))
.BDDfy();
}
private void GivenThereIsAServiceRunningOn(string url, int statusCode, string responseBody) private void GivenThereIsAServiceRunningOn(string url, int statusCode, string responseBody)
{ {
_builder = new WebHostBuilder() _builder = new WebHostBuilder()

View File

@ -2,11 +2,9 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Net.Http; using System.Net.Http;
using CacheManager.Core;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Moq; using Moq;
using Ocelot.Cache; using Ocelot.Cache;
using Ocelot.Cache.Middleware; using Ocelot.Cache.Middleware;
@ -37,7 +35,6 @@ namespace Ocelot.UnitTests.Cache
_cacheManager = new Mock<IOcelotCache<HttpResponseMessage>>(); _cacheManager = new Mock<IOcelotCache<HttpResponseMessage>>();
_scopedRepo = new Mock<IRequestScopedDataRepository>(); _scopedRepo = new Mock<IRequestScopedDataRepository>();
_url = "http://localhost:51879"; _url = "http://localhost:51879";
var builder = new WebHostBuilder() var builder = new WebHostBuilder()
.ConfigureServices(x => .ConfigureServices(x =>
@ -57,6 +54,10 @@ namespace Ocelot.UnitTests.Cache
app.UseOutputCacheMiddleware(); app.UseOutputCacheMiddleware();
}); });
_scopedRepo
.Setup(sr => sr.Get<HttpRequestMessage>("DownstreamRequest"))
.Returns(new OkResponse<HttpRequestMessage>(new HttpRequestMessage(HttpMethod.Get, "https://some.url/blah?abcd=123")));
_server = new TestServer(builder); _server = new TestServer(builder);
_client = _server.CreateClient(); _client = _server.CreateClient();
} }

View File

@ -140,6 +140,21 @@ namespace Ocelot.UnitTests.DownstreamRouteFinder.UrlMatcher
.BDDfy(); .BDDfy();
} }
[Fact]
public void can_match_down_stream_url_with_downstream_template_with_place_holder_to_final_url_path()
{
var expectedTemplates = new List<UrlPathPlaceholderNameAndValue>
{
new UrlPathPlaceholderNameAndValue("{finalUrlPath}", "product/products/categories/"),
};
this.Given(x => x.GivenIHaveAUpstreamPath("api/product/products/categories/"))
.And(x => x.GivenIHaveAnUpstreamUrlTemplate("api/{finalUrlPath}/"))
.When(x => x.WhenIFindTheUrlVariableNamesAndValues())
.And(x => x.ThenTheTemplatesVariablesAre(expectedTemplates))
.BDDfy();
}
private void ThenTheTemplatesVariablesAre(List<UrlPathPlaceholderNameAndValue> expectedResults) private void ThenTheTemplatesVariablesAre(List<UrlPathPlaceholderNameAndValue> expectedResults)
{ {
foreach (var expectedResult in expectedResults) foreach (var expectedResult in expectedResults)

View File

@ -5,12 +5,9 @@ using System.Net.Http;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Moq; using Moq;
using Ocelot.Configuration;
using Ocelot.Configuration.Builder; using Ocelot.Configuration.Builder;
using Ocelot.DownstreamRouteFinder; using Ocelot.DownstreamRouteFinder;
using Ocelot.DownstreamRouteFinder.Middleware;
using Ocelot.DownstreamRouteFinder.UrlMatcher; using Ocelot.DownstreamRouteFinder.UrlMatcher;
using Ocelot.DownstreamUrlCreator; using Ocelot.DownstreamUrlCreator;
using Ocelot.DownstreamUrlCreator.Middleware; using Ocelot.DownstreamUrlCreator.Middleware;
@ -21,6 +18,7 @@ using Ocelot.Responses;
using Ocelot.Values; using Ocelot.Values;
using TestStack.BDDfy; using TestStack.BDDfy;
using Xunit; using Xunit;
using Shouldly;
namespace Ocelot.UnitTests.DownstreamUrlCreator namespace Ocelot.UnitTests.DownstreamUrlCreator
{ {
@ -33,10 +31,9 @@ namespace Ocelot.UnitTests.DownstreamUrlCreator
private readonly TestServer _server; private readonly TestServer _server;
private readonly HttpClient _client; private readonly HttpClient _client;
private Response<DownstreamRoute> _downstreamRoute; private Response<DownstreamRoute> _downstreamRoute;
private HttpResponseMessage _result;
private OkResponse<DownstreamPath> _downstreamPath; private OkResponse<DownstreamPath> _downstreamPath;
private OkResponse<DownstreamUrl> _downstreamUrl; private HttpRequestMessage _downstreamRequest;
private HostAndPort _hostAndPort; private HttpResponseMessage _result;
public DownstreamUrlCreatorMiddlewareTests() public DownstreamUrlCreatorMiddlewareTests()
{ {
@ -63,65 +60,34 @@ namespace Ocelot.UnitTests.DownstreamUrlCreator
app.UseDownstreamUrlCreatorMiddleware(); app.UseDownstreamUrlCreatorMiddleware();
}); });
_downstreamRequest = new HttpRequestMessage(HttpMethod.Get, "https://my.url/abc/?q=123");
_scopedRepository
.Setup(sr => sr.Get<HttpRequestMessage>("DownstreamRequest"))
.Returns(new OkResponse<HttpRequestMessage>(_downstreamRequest));
_server = new TestServer(builder); _server = new TestServer(builder);
_client = _server.CreateClient(); _client = _server.CreateClient();
} }
[Fact] [Fact]
public void should_call_dependencies_correctly() public void should_replace_scheme_and_path()
{ {
var hostAndPort = new HostAndPort("127.0.0.1", 80);
this.Given(x => x.GivenTheDownStreamRouteIs( this.Given(x => x.GivenTheDownStreamRouteIs(
new DownstreamRoute( new DownstreamRoute(
new List<UrlPathPlaceholderNameAndValue>(), new List<UrlPathPlaceholderNameAndValue>(),
new ReRouteBuilder() new ReRouteBuilder()
.WithDownstreamPathTemplate("any old string") .WithDownstreamPathTemplate("any old string")
.WithUpstreamHttpMethod("Get") .WithUpstreamHttpMethod("Get")
.WithDownstreamScheme("https")
.Build()))) .Build())))
.And(x => x.GivenTheHostAndPortIs(hostAndPort)) .And(x => x.GivenTheDownstreamRequestUriIs("http://my.url/abc?q=123"))
.And(x => x.TheUrlReplacerReturns("/api/products/1")) .And(x => x.GivenTheUrlReplacerWillReturn("/api/products/1"))
.And(x => x.TheUrlBuilderReturns("http://127.0.0.1:80/api/products/1"))
.When(x => x.WhenICallTheMiddleware()) .When(x => x.WhenICallTheMiddleware())
.Then(x => x.ThenTheScopedDataRepositoryIsCalledCorrectly()) .Then(x => x.ThenTheDownstreamRequestUriIs("https://my.url:80/api/products/1?q=123"))
.BDDfy(); .BDDfy();
} }
private void GivenTheHostAndPortIs(HostAndPort hostAndPort)
{
_hostAndPort = hostAndPort;
_scopedRepository
.Setup(x => x.Get<HostAndPort>("HostAndPort"))
.Returns(new OkResponse<HostAndPort>(_hostAndPort));
}
private void TheUrlBuilderReturns(string dsUrl)
{
_downstreamUrl = new OkResponse<DownstreamUrl>(new DownstreamUrl(dsUrl));
_urlBuilder
.Setup(x => x.Build(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<HostAndPort>()))
.Returns(_downstreamUrl);
}
private void TheUrlReplacerReturns(string downstreamUrl)
{
_downstreamPath = new OkResponse<DownstreamPath>(new DownstreamPath(downstreamUrl));
_downstreamUrlTemplateVariableReplacer
.Setup(x => x.Replace(It.IsAny<PathTemplate>(), It.IsAny<List<UrlPathPlaceholderNameAndValue>>()))
.Returns(_downstreamPath);
}
private void ThenTheScopedDataRepositoryIsCalledCorrectly()
{
_scopedRepository
.Verify(x => x.Add("DownstreamUrl", _downstreamUrl.Data.Value), Times.Once());
}
private void WhenICallTheMiddleware()
{
_result = _client.GetAsync(_url).Result;
}
private void GivenTheDownStreamRouteIs(DownstreamRoute downstreamRoute) private void GivenTheDownStreamRouteIs(DownstreamRoute downstreamRoute)
{ {
_downstreamRoute = new OkResponse<DownstreamRoute>(downstreamRoute); _downstreamRoute = new OkResponse<DownstreamRoute>(downstreamRoute);
@ -130,6 +96,29 @@ namespace Ocelot.UnitTests.DownstreamUrlCreator
.Returns(_downstreamRoute); .Returns(_downstreamRoute);
} }
private void GivenTheDownstreamRequestUriIs(string uri)
{
_downstreamRequest.RequestUri = new Uri(uri);
}
private void GivenTheUrlReplacerWillReturn(string path)
{
_downstreamPath = new OkResponse<DownstreamPath>(new DownstreamPath(path));
_downstreamUrlTemplateVariableReplacer
.Setup(x => x.Replace(It.IsAny<PathTemplate>(), It.IsAny<List<UrlPathPlaceholderNameAndValue>>()))
.Returns(_downstreamPath);
}
private void WhenICallTheMiddleware()
{
_result = _client.GetAsync(_url).Result;
}
private void ThenTheDownstreamRequestUriIs(string expectedUri)
{
_downstreamRequest.RequestUri.OriginalString.ShouldBe(expectedUri);
}
public void Dispose() public void Dispose()
{ {
_client.Dispose(); _client.Dispose();

View File

@ -1,7 +1,5 @@
using System; using System;
using Ocelot.Configuration;
using Ocelot.DownstreamUrlCreator; using Ocelot.DownstreamUrlCreator;
using Ocelot.DownstreamUrlCreator.UrlTemplateReplacer;
using Ocelot.Responses; using Ocelot.Responses;
using Ocelot.Values; using Ocelot.Values;
using Shouldly; using Shouldly;

View File

@ -1,8 +1,6 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Security.Claims; using System.Security.Claims;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives;
using Moq; using Moq;
using Ocelot.Configuration; using Ocelot.Configuration;
using Ocelot.Errors; using Ocelot.Errors;
@ -12,6 +10,7 @@ using Ocelot.Responses;
using Shouldly; using Shouldly;
using TestStack.BDDfy; using TestStack.BDDfy;
using Xunit; using Xunit;
using System.Net.Http;
namespace Ocelot.UnitTests.Headers namespace Ocelot.UnitTests.Headers
{ {
@ -19,8 +18,9 @@ namespace Ocelot.UnitTests.Headers
{ {
private readonly AddHeadersToRequest _addHeadersToRequest; private readonly AddHeadersToRequest _addHeadersToRequest;
private readonly Mock<IClaimsParser> _parser; private readonly Mock<IClaimsParser> _parser;
private readonly HttpRequestMessage _downstreamRequest;
private List<Claim> _claims;
private List<ClaimToThing> _configuration; private List<ClaimToThing> _configuration;
private HttpContext _context;
private Response _result; private Response _result;
private Response<string> _claimValue; private Response<string> _claimValue;
@ -28,17 +28,15 @@ namespace Ocelot.UnitTests.Headers
{ {
_parser = new Mock<IClaimsParser>(); _parser = new Mock<IClaimsParser>();
_addHeadersToRequest = new AddHeadersToRequest(_parser.Object); _addHeadersToRequest = new AddHeadersToRequest(_parser.Object);
_downstreamRequest = new HttpRequestMessage();
} }
[Fact] [Fact]
public void should_add_headers_to_context() public void should_add_headers_to_downstreamRequest()
{ {
var context = new DefaultHttpContext var claims = new List<Claim>
{
User = new ClaimsPrincipal(new ClaimsIdentity(new List<Claim>
{ {
new Claim("test", "data") new Claim("test", "data")
}))
}; };
this.Given( this.Given(
@ -46,7 +44,7 @@ namespace Ocelot.UnitTests.Headers
{ {
new ClaimToThing("header-key", "", "", 0) new ClaimToThing("header-key", "", "", 0)
})) }))
.Given(x => x.GivenHttpContext(context)) .Given(x => x.GivenClaims(claims))
.And(x => x.GivenTheClaimParserReturns(new OkResponse<string>("value"))) .And(x => x.GivenTheClaimParserReturns(new OkResponse<string>("value")))
.When(x => x.WhenIAddHeadersToTheRequest()) .When(x => x.WhenIAddHeadersToTheRequest())
.Then(x => x.ThenTheResultIsSuccess()) .Then(x => x.ThenTheResultIsSuccess())
@ -55,25 +53,19 @@ namespace Ocelot.UnitTests.Headers
} }
[Fact] [Fact]
public void if_header_exists_should_replace_it() public void should_replace_existing_headers_on_request()
{ {
var context = new DefaultHttpContext
{
User = new ClaimsPrincipal(new ClaimsIdentity(new List<Claim>
{
new Claim("test", "data")
})),
};
context.Request.Headers.Add("header-key", new StringValues("initial"));
this.Given( this.Given(
x => x.GivenConfigurationHeaderExtractorProperties(new List<ClaimToThing> x => x.GivenConfigurationHeaderExtractorProperties(new List<ClaimToThing>
{ {
new ClaimToThing("header-key", "", "", 0) new ClaimToThing("header-key", "", "", 0)
})) }))
.Given(x => x.GivenHttpContext(context)) .Given(x => x.GivenClaims(new List<Claim>
{
new Claim("test", "data")
}))
.And(x => x.GivenTheClaimParserReturns(new OkResponse<string>("value"))) .And(x => x.GivenTheClaimParserReturns(new OkResponse<string>("value")))
.And(x => x.GivenThatTheRequestContainsHeader("header-key", "initial"))
.When(x => x.WhenIAddHeadersToTheRequest()) .When(x => x.WhenIAddHeadersToTheRequest())
.Then(x => x.ThenTheResultIsSuccess()) .Then(x => x.ThenTheResultIsSuccess())
.And(x => x.ThenTheHeaderIsAdded()) .And(x => x.ThenTheHeaderIsAdded())
@ -88,7 +80,7 @@ namespace Ocelot.UnitTests.Headers
{ {
new ClaimToThing("", "", "", 0) new ClaimToThing("", "", "", 0)
})) }))
.Given(x => x.GivenHttpContext(new DefaultHttpContext())) .Given(x => x.GivenClaims(new List<Claim>()))
.And(x => x.GivenTheClaimParserReturns(new ErrorResponse<string>(new List<Error> .And(x => x.GivenTheClaimParserReturns(new ErrorResponse<string>(new List<Error>
{ {
new AnyError() new AnyError()
@ -98,10 +90,9 @@ namespace Ocelot.UnitTests.Headers
.BDDfy(); .BDDfy();
} }
private void ThenTheHeaderIsAdded() private void GivenClaims(List<Claim> claims)
{ {
var header = _context.Request.Headers.First(x => x.Key == "header-key"); _claims = claims;
header.Value.First().ShouldBe(_claimValue.Data);
} }
private void GivenConfigurationHeaderExtractorProperties(List<ClaimToThing> configuration) private void GivenConfigurationHeaderExtractorProperties(List<ClaimToThing> configuration)
@ -109,9 +100,9 @@ namespace Ocelot.UnitTests.Headers
_configuration = configuration; _configuration = configuration;
} }
private void GivenHttpContext(HttpContext context) private void GivenThatTheRequestContainsHeader(string key, string value)
{ {
_context = context; _downstreamRequest.Headers.Add(key, value);
} }
private void GivenTheClaimParserReturns(Response<string> claimValue) private void GivenTheClaimParserReturns(Response<string> claimValue)
@ -129,7 +120,7 @@ namespace Ocelot.UnitTests.Headers
private void WhenIAddHeadersToTheRequest() private void WhenIAddHeadersToTheRequest()
{ {
_result = _addHeadersToRequest.SetHeadersOnContext(_configuration, _context); _result = _addHeadersToRequest.SetHeadersOnDownstreamRequest(_configuration, _claims, _downstreamRequest);
} }
private void ThenTheResultIsSuccess() private void ThenTheResultIsSuccess()
@ -143,6 +134,12 @@ namespace Ocelot.UnitTests.Headers
_result.IsError.ShouldBe(true); _result.IsError.ShouldBe(true);
} }
private void ThenTheHeaderIsAdded()
{
var header = _downstreamRequest.Headers.First(x => x.Key == "header-key");
header.Value.First().ShouldBe(_claimValue.Data);
}
class AnyError : Error class AnyError : Error
{ {
public AnyError() public AnyError()

View File

@ -3,16 +3,13 @@ using System.Collections.Generic;
using System.IO; using System.IO;
using System.Net.Http; using System.Net.Http;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Moq; using Moq;
using Ocelot.Configuration; using Ocelot.Configuration;
using Ocelot.Configuration.Builder; using Ocelot.Configuration.Builder;
using Ocelot.DownstreamRouteFinder; using Ocelot.DownstreamRouteFinder;
using Ocelot.DownstreamRouteFinder.UrlMatcher; using Ocelot.DownstreamRouteFinder.UrlMatcher;
using Ocelot.DownstreamUrlCreator.Middleware;
using Ocelot.Headers; using Ocelot.Headers;
using Ocelot.Headers.Middleware; using Ocelot.Headers.Middleware;
using Ocelot.Infrastructure.RequestData; using Ocelot.Infrastructure.RequestData;
@ -27,6 +24,7 @@ namespace Ocelot.UnitTests.Headers
{ {
private readonly Mock<IRequestScopedDataRepository> _scopedRepository; private readonly Mock<IRequestScopedDataRepository> _scopedRepository;
private readonly Mock<IAddHeadersToRequest> _addHeaders; private readonly Mock<IAddHeadersToRequest> _addHeaders;
private readonly HttpRequestMessage _downstreamRequest;
private readonly string _url; private readonly string _url;
private readonly TestServer _server; private readonly TestServer _server;
private readonly HttpClient _client; private readonly HttpClient _client;
@ -58,6 +56,12 @@ namespace Ocelot.UnitTests.Headers
app.UseHttpRequestHeadersBuilderMiddleware(); app.UseHttpRequestHeadersBuilderMiddleware();
}); });
_downstreamRequest = new HttpRequestMessage();
_scopedRepository
.Setup(sr => sr.Get<HttpRequestMessage>("DownstreamRequest"))
.Returns(new OkResponse<HttpRequestMessage>(_downstreamRequest));
_server = new TestServer(builder); _server = new TestServer(builder);
_client = _server.CreateClient(); _client = _server.CreateClient();
} }
@ -76,25 +80,29 @@ namespace Ocelot.UnitTests.Headers
.Build()); .Build());
this.Given(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) this.Given(x => x.GivenTheDownStreamRouteIs(downstreamRoute))
.And(x => x.GivenTheAddHeadersToRequestReturns()) .And(x => x.GivenTheAddHeadersToDownstreamRequestReturnsOk())
.When(x => x.WhenICallTheMiddleware()) .When(x => x.WhenICallTheMiddleware())
.Then(x => x.ThenTheAddHeadersToRequestIsCalledCorrectly()) .Then(x => x.ThenTheAddHeadersToRequestIsCalledCorrectly())
.BDDfy(); .BDDfy();
} }
private void GivenTheAddHeadersToRequestReturns() private void GivenTheAddHeadersToDownstreamRequestReturnsOk()
{ {
_addHeaders _addHeaders
.Setup(x => x.SetHeadersOnContext(It.IsAny<List<ClaimToThing>>(), .Setup(x => x.SetHeadersOnDownstreamRequest(
It.IsAny<HttpContext>())) It.IsAny<List<ClaimToThing>>(),
It.IsAny<IEnumerable<System.Security.Claims.Claim>>(),
It.IsAny<HttpRequestMessage>()))
.Returns(new OkResponse()); .Returns(new OkResponse());
} }
private void ThenTheAddHeadersToRequestIsCalledCorrectly() private void ThenTheAddHeadersToRequestIsCalledCorrectly()
{ {
_addHeaders _addHeaders
.Verify(x => x.SetHeadersOnContext(It.IsAny<List<ClaimToThing>>(), .Verify(x => x.SetHeadersOnDownstreamRequest(
It.IsAny<HttpContext>()), Times.Once); It.IsAny<List<ClaimToThing>>(),
It.IsAny<IEnumerable<System.Security.Claims.Claim>>(),
_downstreamRequest), Times.Once);
} }
private void WhenICallTheMiddleware() private void WhenICallTheMiddleware()

View File

@ -16,6 +16,7 @@ using Ocelot.Responses;
using Ocelot.Values; using Ocelot.Values;
using TestStack.BDDfy; using TestStack.BDDfy;
using Xunit; using Xunit;
using Shouldly;
namespace Ocelot.UnitTests.LoadBalancer namespace Ocelot.UnitTests.LoadBalancer
{ {
@ -29,10 +30,10 @@ namespace Ocelot.UnitTests.LoadBalancer
private readonly HttpClient _client; private readonly HttpClient _client;
private HttpResponseMessage _result; private HttpResponseMessage _result;
private HostAndPort _hostAndPort; private HostAndPort _hostAndPort;
private OkResponse<string> _downstreamUrl;
private OkResponse<DownstreamRoute> _downstreamRoute; private OkResponse<DownstreamRoute> _downstreamRoute;
private ErrorResponse<ILoadBalancer> _getLoadBalancerHouseError; private ErrorResponse<ILoadBalancer> _getLoadBalancerHouseError;
private ErrorResponse<HostAndPort> _getHostAndPortError; private ErrorResponse<HostAndPort> _getHostAndPortError;
private HttpRequestMessage _downstreamRequest;
public LoadBalancerMiddlewareTests() public LoadBalancerMiddlewareTests()
{ {
@ -59,6 +60,10 @@ namespace Ocelot.UnitTests.LoadBalancer
app.UseLoadBalancingMiddleware(); app.UseLoadBalancingMiddleware();
}); });
_downstreamRequest = new HttpRequestMessage(HttpMethod.Get, "");
_scopedRepository
.Setup(sr => sr.Get<HttpRequestMessage>("DownstreamRequest"))
.Returns(new OkResponse<HttpRequestMessage>(_downstreamRequest));
_server = new TestServer(builder); _server = new TestServer(builder);
_client = _server.CreateClient(); _client = _server.CreateClient();
} }
@ -71,12 +76,12 @@ namespace Ocelot.UnitTests.LoadBalancer
.WithUpstreamHttpMethod("Get") .WithUpstreamHttpMethod("Get")
.Build()); .Build());
this.Given(x => x.GivenTheDownStreamUrlIs("any old string")) this.Given(x => x.GivenTheDownStreamUrlIs("http://my.url/abc?q=123"))
.And(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) .And(x => x.GivenTheDownStreamRouteIs(downstreamRoute))
.And(x => x.GivenTheLoadBalancerHouseReturns()) .And(x => x.GivenTheLoadBalancerHouseReturns())
.And(x => x.GivenTheLoadBalancerReturns()) .And(x => x.GivenTheLoadBalancerReturns())
.When(x => x.WhenICallTheMiddleware()) .When(x => x.WhenICallTheMiddleware())
.Then(x => x.ThenTheScopedDataRepositoryIsCalledCorrectly()) .Then(x => x.ThenTheDownstreamUrlIsReplacedWith("http://127.0.0.1:80/abc?q=123"))
.BDDfy(); .BDDfy();
} }
@ -88,7 +93,7 @@ namespace Ocelot.UnitTests.LoadBalancer
.WithUpstreamHttpMethod("Get") .WithUpstreamHttpMethod("Get")
.Build()); .Build());
this.Given(x => x.GivenTheDownStreamUrlIs("any old string")) this.Given(x => x.GivenTheDownStreamUrlIs("http://my.url/abc?q=123"))
.And(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) .And(x => x.GivenTheDownStreamRouteIs(downstreamRoute))
.And(x => x.GivenTheLoadBalancerHouseReturnsAnError()) .And(x => x.GivenTheLoadBalancerHouseReturnsAnError())
.When(x => x.WhenICallTheMiddleware()) .When(x => x.WhenICallTheMiddleware())
@ -104,7 +109,7 @@ namespace Ocelot.UnitTests.LoadBalancer
.WithUpstreamHttpMethod("Get") .WithUpstreamHttpMethod("Get")
.Build()); .Build());
this.Given(x => x.GivenTheDownStreamUrlIs("any old string")) this.Given(x => x.GivenTheDownStreamUrlIs("http://my.url/abc?q=123"))
.And(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) .And(x => x.GivenTheDownStreamRouteIs(downstreamRoute))
.And(x => x.GivenTheLoadBalancerHouseReturns()) .And(x => x.GivenTheLoadBalancerHouseReturns())
.And(x => x.GivenTheLoadBalancerReturnsAnError()) .And(x => x.GivenTheLoadBalancerReturnsAnError())
@ -113,6 +118,11 @@ namespace Ocelot.UnitTests.LoadBalancer
.BDDfy(); .BDDfy();
} }
private void GivenTheDownStreamUrlIs(string downstreamUrl)
{
_downstreamRequest.RequestUri = new System.Uri(downstreamUrl);
}
private void GivenTheLoadBalancerReturnsAnError() private void GivenTheLoadBalancerReturnsAnError()
{ {
_getHostAndPortError = new ErrorResponse<HostAndPort>(new List<Error>() { new ServicesAreNullError($"services were null for bah") }); _getHostAndPortError = new ErrorResponse<HostAndPort>(new List<Error>() { new ServicesAreNullError($"services were null for bah") });
@ -157,10 +167,9 @@ namespace Ocelot.UnitTests.LoadBalancer
.Returns(_getLoadBalancerHouseError); .Returns(_getLoadBalancerHouseError);
} }
private void ThenTheScopedDataRepositoryIsCalledCorrectly() private void WhenICallTheMiddleware()
{ {
_scopedRepository _result = _client.GetAsync(_url).Result;
.Verify(x => x.Add("HostAndPort", _hostAndPort), Times.Once());
} }
private void ThenAnErrorStatingLoadBalancerCouldNotBeFoundIsSetOnPipeline() private void ThenAnErrorStatingLoadBalancerCouldNotBeFoundIsSetOnPipeline()
@ -190,17 +199,11 @@ namespace Ocelot.UnitTests.LoadBalancer
.Verify(x => x.Add("OcelotMiddlewareErrors", _getHostAndPortError.Errors), Times.Once); .Verify(x => x.Add("OcelotMiddlewareErrors", _getHostAndPortError.Errors), Times.Once);
} }
private void WhenICallTheMiddleware()
{
_result = _client.GetAsync(_url).Result;
}
private void GivenTheDownStreamUrlIs(string downstreamUrl)
private void ThenTheDownstreamUrlIsReplacedWith(string expectedUri)
{ {
_downstreamUrl = new OkResponse<string>(downstreamUrl); _downstreamRequest.RequestUri.OriginalString.ShouldBe(expectedUri);
_scopedRepository
.Setup(x => x.Get<string>(It.IsAny<string>()))
.Returns(_downstreamUrl);
} }
public void Dispose() public void Dispose()

View File

@ -1,7 +1,6 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Security.Claims; using System.Security.Claims;
using Microsoft.AspNetCore.Http;
using Moq; using Moq;
using Ocelot.Configuration; using Ocelot.Configuration;
using Ocelot.Errors; using Ocelot.Errors;
@ -11,15 +10,18 @@ using Ocelot.Responses;
using Shouldly; using Shouldly;
using TestStack.BDDfy; using TestStack.BDDfy;
using Xunit; using Xunit;
using System.Net.Http;
using System;
namespace Ocelot.UnitTests.QueryStrings namespace Ocelot.UnitTests.QueryStrings
{ {
public class AddQueriesToRequestTests public class AddQueriesToRequestTests
{ {
private readonly AddQueriesToRequest _addQueriesToRequest; private readonly AddQueriesToRequest _addQueriesToRequest;
private readonly HttpRequestMessage _downstreamRequest;
private readonly Mock<IClaimsParser> _parser; private readonly Mock<IClaimsParser> _parser;
private List<ClaimToThing> _configuration; private List<ClaimToThing> _configuration;
private HttpContext _context; private List<Claim> _claims;
private Response _result; private Response _result;
private Response<string> _claimValue; private Response<string> _claimValue;
@ -27,17 +29,15 @@ namespace Ocelot.UnitTests.QueryStrings
{ {
_parser = new Mock<IClaimsParser>(); _parser = new Mock<IClaimsParser>();
_addQueriesToRequest = new AddQueriesToRequest(_parser.Object); _addQueriesToRequest = new AddQueriesToRequest(_parser.Object);
_downstreamRequest = new HttpRequestMessage(HttpMethod.Post, "http://my.url/abc?q=123");
} }
[Fact] [Fact]
public void should_add_queries_to_context() public void should_add_new_queries_to_downstream_request()
{ {
var context = new DefaultHttpContext var claims = new List<Claim>
{
User = new ClaimsPrincipal(new ClaimsIdentity(new List<Claim>
{ {
new Claim("test", "data") new Claim("test", "data")
}))
}; };
this.Given( this.Given(
@ -45,7 +45,7 @@ namespace Ocelot.UnitTests.QueryStrings
{ {
new ClaimToThing("query-key", "", "", 0) new ClaimToThing("query-key", "", "", 0)
})) }))
.Given(x => x.GivenHttpContext(context)) .Given(x => x.GivenClaims(claims))
.And(x => x.GivenTheClaimParserReturns(new OkResponse<string>("value"))) .And(x => x.GivenTheClaimParserReturns(new OkResponse<string>("value")))
.When(x => x.WhenIAddQueriesToTheRequest()) .When(x => x.WhenIAddQueriesToTheRequest())
.Then(x => x.ThenTheResultIsSuccess()) .Then(x => x.ThenTheResultIsSuccess())
@ -54,24 +54,20 @@ namespace Ocelot.UnitTests.QueryStrings
} }
[Fact] [Fact]
public void if_query_exists_should_replace_it() public void should_replace_existing_queries_on_downstream_request()
{ {
var context = new DefaultHttpContext var claims = new List<Claim>
{
User = new ClaimsPrincipal(new ClaimsIdentity(new List<Claim>
{ {
new Claim("test", "data") new Claim("test", "data")
})),
}; };
context.Request.QueryString = context.Request.QueryString.Add("query-key", "initial");
this.Given( this.Given(
x => x.GivenAClaimToThing(new List<ClaimToThing> x => x.GivenAClaimToThing(new List<ClaimToThing>
{ {
new ClaimToThing("query-key", "", "", 0) new ClaimToThing("query-key", "", "", 0)
})) }))
.Given(x => x.GivenHttpContext(context)) .And(x => x.GivenClaims(claims))
.And(x => x.GivenTheDownstreamRequestHasQueryString("query-key", "initial"))
.And(x => x.GivenTheClaimParserReturns(new OkResponse<string>("value"))) .And(x => x.GivenTheClaimParserReturns(new OkResponse<string>("value")))
.When(x => x.WhenIAddQueriesToTheRequest()) .When(x => x.WhenIAddQueriesToTheRequest())
.Then(x => x.ThenTheResultIsSuccess()) .Then(x => x.ThenTheResultIsSuccess())
@ -87,7 +83,7 @@ namespace Ocelot.UnitTests.QueryStrings
{ {
new ClaimToThing("", "", "", 0) new ClaimToThing("", "", "", 0)
})) }))
.Given(x => x.GivenHttpContext(new DefaultHttpContext())) .Given(x => x.GivenClaims(new List<Claim>()))
.And(x => x.GivenTheClaimParserReturns(new ErrorResponse<string>(new List<Error> .And(x => x.GivenTheClaimParserReturns(new ErrorResponse<string>(new List<Error>
{ {
new AnyError() new AnyError()
@ -99,7 +95,8 @@ namespace Ocelot.UnitTests.QueryStrings
private void ThenTheQueryIsAdded() private void ThenTheQueryIsAdded()
{ {
var query = _context.Request.Query.First(x => x.Key == "query-key"); var queries = Microsoft.AspNetCore.WebUtilities.QueryHelpers.ParseQuery(_downstreamRequest.RequestUri.OriginalString);
var query = queries.First(x => x.Key == "query-key");
query.Value.First().ShouldBe(_claimValue.Data); query.Value.First().ShouldBe(_claimValue.Data);
} }
@ -108,9 +105,17 @@ namespace Ocelot.UnitTests.QueryStrings
_configuration = configuration; _configuration = configuration;
} }
private void GivenHttpContext(HttpContext context) private void GivenClaims(List<Claim> claims)
{ {
_context = context; _claims = claims;
}
private void GivenTheDownstreamRequestHasQueryString(string key, string value)
{
var newUri = Microsoft.AspNetCore.WebUtilities.QueryHelpers
.AddQueryString(_downstreamRequest.RequestUri.OriginalString, key, value);
_downstreamRequest.RequestUri = new Uri(newUri);
} }
private void GivenTheClaimParserReturns(Response<string> claimValue) private void GivenTheClaimParserReturns(Response<string> claimValue)
@ -128,7 +133,7 @@ namespace Ocelot.UnitTests.QueryStrings
private void WhenIAddQueriesToTheRequest() private void WhenIAddQueriesToTheRequest()
{ {
_result = _addQueriesToRequest.SetQueriesOnContext(_configuration, _context); _result = _addQueriesToRequest.SetQueriesOnDownstreamRequest(_configuration, _claims, _downstreamRequest);
} }
private void ThenTheResultIsSuccess() private void ThenTheResultIsSuccess()
@ -138,7 +143,6 @@ namespace Ocelot.UnitTests.QueryStrings
private void ThenTheResultIsError() private void ThenTheResultIsError()
{ {
_result.IsError.ShouldBe(true); _result.IsError.ShouldBe(true);
} }

View File

@ -3,16 +3,13 @@ using System.Collections.Generic;
using System.IO; using System.IO;
using System.Net.Http; using System.Net.Http;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Moq; using Moq;
using Ocelot.Configuration; using Ocelot.Configuration;
using Ocelot.Configuration.Builder; using Ocelot.Configuration.Builder;
using Ocelot.DownstreamRouteFinder; using Ocelot.DownstreamRouteFinder;
using Ocelot.DownstreamRouteFinder.UrlMatcher; using Ocelot.DownstreamRouteFinder.UrlMatcher;
using Ocelot.Headers.Middleware;
using Ocelot.Infrastructure.RequestData; using Ocelot.Infrastructure.RequestData;
using Ocelot.Logging; using Ocelot.Logging;
using Ocelot.QueryStrings; using Ocelot.QueryStrings;
@ -20,6 +17,7 @@ using Ocelot.QueryStrings.Middleware;
using Ocelot.Responses; using Ocelot.Responses;
using TestStack.BDDfy; using TestStack.BDDfy;
using Xunit; using Xunit;
using System.Security.Claims;
namespace Ocelot.UnitTests.QueryStrings namespace Ocelot.UnitTests.QueryStrings
{ {
@ -30,6 +28,7 @@ namespace Ocelot.UnitTests.QueryStrings
private readonly string _url; private readonly string _url;
private readonly TestServer _server; private readonly TestServer _server;
private readonly HttpClient _client; private readonly HttpClient _client;
private readonly HttpRequestMessage _downstreamRequest;
private Response<DownstreamRoute> _downstreamRoute; private Response<DownstreamRoute> _downstreamRoute;
private HttpResponseMessage _result; private HttpResponseMessage _result;
@ -56,6 +55,11 @@ namespace Ocelot.UnitTests.QueryStrings
app.UseQueryStringBuilderMiddleware(); app.UseQueryStringBuilderMiddleware();
}); });
_downstreamRequest = new HttpRequestMessage();
_scopedRepository.Setup(sr => sr.Get<HttpRequestMessage>("DownstreamRequest"))
.Returns(new OkResponse<HttpRequestMessage>(_downstreamRequest));
_server = new TestServer(builder); _server = new TestServer(builder);
_client = _server.CreateClient(); _client = _server.CreateClient();
} }
@ -74,25 +78,29 @@ namespace Ocelot.UnitTests.QueryStrings
.Build()); .Build());
this.Given(x => x.GivenTheDownStreamRouteIs(downstreamRoute)) this.Given(x => x.GivenTheDownStreamRouteIs(downstreamRoute))
.And(x => x.GivenTheAddHeadersToRequestReturns()) .And(x => x.GivenTheAddHeadersToRequestReturnsOk())
.When(x => x.WhenICallTheMiddleware()) .When(x => x.WhenICallTheMiddleware())
.Then(x => x.ThenTheAddQueriesToRequestIsCalledCorrectly()) .Then(x => x.ThenTheAddQueriesToRequestIsCalledCorrectly())
.BDDfy(); .BDDfy();
} }
private void GivenTheAddHeadersToRequestReturns() private void GivenTheAddHeadersToRequestReturnsOk()
{ {
_addQueries _addQueries
.Setup(x => x.SetQueriesOnContext(It.IsAny<List<ClaimToThing>>(), .Setup(x => x.SetQueriesOnDownstreamRequest(
It.IsAny<HttpContext>())) It.IsAny<List<ClaimToThing>>(),
It.IsAny<IEnumerable<Claim>>(),
It.IsAny<HttpRequestMessage>()))
.Returns(new OkResponse()); .Returns(new OkResponse());
} }
private void ThenTheAddQueriesToRequestIsCalledCorrectly() private void ThenTheAddQueriesToRequestIsCalledCorrectly()
{ {
_addQueries _addQueries
.Verify(x => x.SetQueriesOnContext(It.IsAny<List<ClaimToThing>>(), .Verify(x => x.SetQueriesOnDownstreamRequest(
It.IsAny<HttpContext>()), Times.Once); It.IsAny<List<ClaimToThing>>(),
It.IsAny<IEnumerable<Claim>>(),
_downstreamRequest), Times.Once);
} }
private void WhenICallTheMiddleware() private void WhenICallTheMiddleware()

View File

@ -0,0 +1,142 @@
namespace Ocelot.UnitTests.Request
{
using System.Net.Http;
using Microsoft.AspNetCore.Http;
using Moq;
using Ocelot.Logging;
using Ocelot.Request.Mapper;
using Ocelot.Request.Middleware;
using Ocelot.Infrastructure.RequestData;
using TestStack.BDDfy;
using Xunit;
using Ocelot.Responses;
public class DownstreamRequestInitialiserMiddlewareTests
{
readonly DownstreamRequestInitialiserMiddleware _middleware;
readonly Mock<HttpContext> _httpContext;
readonly Mock<HttpRequest> _httpRequest;
readonly Mock<RequestDelegate> _next;
readonly Mock<IRequestMapper> _requestMapper;
readonly Mock<IRequestScopedDataRepository> _repo;
readonly Mock<IOcelotLoggerFactory> _loggerFactory;
readonly Mock<IOcelotLogger> _logger;
Response<HttpRequestMessage> _mappedRequest;
public DownstreamRequestInitialiserMiddlewareTests()
{
_httpContext = new Mock<HttpContext>();
_httpRequest = new Mock<HttpRequest>();
_requestMapper = new Mock<IRequestMapper>();
_repo = new Mock<IRequestScopedDataRepository>();
_next = new Mock<RequestDelegate>();
_logger = new Mock<IOcelotLogger>();
_loggerFactory = new Mock<IOcelotLoggerFactory>();
_loggerFactory
.Setup(lf => lf.CreateLogger<DownstreamRequestInitialiserMiddleware>())
.Returns(_logger.Object);
_middleware = new DownstreamRequestInitialiserMiddleware(
_next.Object,
_loggerFactory.Object,
_repo.Object,
_requestMapper.Object);
}
[Fact]
public void Should_handle_valid_httpRequest()
{
this.Given(_ => GivenTheHttpContextContainsARequest())
.And(_ => GivenTheMapperWillReturnAMappedRequest())
.When(_ => WhenTheMiddlewareIsInvoked())
.Then(_ => ThenTheContexRequestIsMappedToADownstreamRequest())
.And(_ => ThenTheDownstreamRequestIsStored())
.And(_ => ThenTheNextMiddlewareIsInvoked())
.BDDfy();
}
[Fact]
public void Should_handle_mapping_failure()
{
this.Given(_ => GivenTheHttpContextContainsARequest())
.And(_ => GivenTheMapperWillReturnAnError())
.When(_ => WhenTheMiddlewareIsInvoked())
.And(_ => ThenTheDownstreamRequestIsNotStored())
.And(_ => ThenAPipelineErrorIsStored())
.And(_ => ThenTheNextMiddlewareIsNotInvoked())
.BDDfy();
}
private void GivenTheHttpContextContainsARequest()
{
_httpContext
.Setup(hc => hc.Request)
.Returns(_httpRequest.Object);
}
private void GivenTheMapperWillReturnAMappedRequest()
{
_mappedRequest = new OkResponse<HttpRequestMessage>(new HttpRequestMessage());
_requestMapper
.Setup(rm => rm.Map(It.IsAny<HttpRequest>()))
.ReturnsAsync(_mappedRequest);
}
private void GivenTheMapperWillReturnAnError()
{
_mappedRequest = new ErrorResponse<HttpRequestMessage>(new UnmappableRequestError(new System.Exception("boooom!")));
_requestMapper
.Setup(rm => rm.Map(It.IsAny<HttpRequest>()))
.ReturnsAsync(_mappedRequest);
}
private void WhenTheMiddlewareIsInvoked()
{
_middleware.Invoke(_httpContext.Object).GetAwaiter().GetResult();
}
private void ThenTheContexRequestIsMappedToADownstreamRequest()
{
_requestMapper.Verify(rm => rm.Map(_httpRequest.Object), Times.Once);
}
private void ThenTheDownstreamRequestIsStored()
{
_repo.Verify(r => r.Add("DownstreamRequest", _mappedRequest.Data), Times.Once);
}
private void ThenTheDownstreamRequestIsNotStored()
{
_repo.Verify(r => r.Add("DownstreamRequest", It.IsAny<HttpRequestMessage>()), Times.Never);
}
private void ThenAPipelineErrorIsStored()
{
_repo.Verify(r => r.Add("OcelotMiddlewareError", true), Times.Once);
_repo.Verify(r => r.Add("OcelotMiddlewareErrors", _mappedRequest.Errors), Times.Once);
}
private void ThenTheNextMiddlewareIsInvoked()
{
_next.Verify(n => n(_httpContext.Object), Times.Once);
}
private void ThenTheNextMiddlewareIsNotInvoked()
{
_next.Verify(n => n(It.IsAny<HttpContext>()), Times.Never);
}
}
}

View File

@ -1,13 +1,10 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Net;
using System.Net.Http; using System.Net.Http;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Moq; using Moq;
using Ocelot.Configuration.Builder; using Ocelot.Configuration.Builder;
using Ocelot.DownstreamRouteFinder; using Ocelot.DownstreamRouteFinder;
@ -19,7 +16,6 @@ using Ocelot.Request.Middleware;
using Ocelot.Responses; using Ocelot.Responses;
using TestStack.BDDfy; using TestStack.BDDfy;
using Xunit; using Xunit;
using Ocelot.Configuration;
using Ocelot.Requester.QoS; using Ocelot.Requester.QoS;
namespace Ocelot.UnitTests.Request namespace Ocelot.UnitTests.Request
@ -29,6 +25,7 @@ namespace Ocelot.UnitTests.Request
private readonly Mock<IRequestCreator> _requestBuilder; private readonly Mock<IRequestCreator> _requestBuilder;
private readonly Mock<IRequestScopedDataRepository> _scopedRepository; private readonly Mock<IRequestScopedDataRepository> _scopedRepository;
private readonly Mock<IQosProviderHouse> _qosProviderHouse; private readonly Mock<IQosProviderHouse> _qosProviderHouse;
private readonly HttpRequestMessage _downstreamRequest;
private readonly string _url; private readonly string _url;
private readonly TestServer _server; private readonly TestServer _server;
private readonly HttpClient _client; private readonly HttpClient _client;
@ -62,6 +59,12 @@ namespace Ocelot.UnitTests.Request
app.UseHttpRequestBuilderMiddleware(); app.UseHttpRequestBuilderMiddleware();
}); });
_downstreamRequest = new HttpRequestMessage();
_scopedRepository
.Setup(sr => sr.Get<HttpRequestMessage>("DownstreamRequest"))
.Returns(new OkResponse<HttpRequestMessage>(_downstreamRequest));
_server = new TestServer(builder); _server = new TestServer(builder);
_client = _server.CreateClient(); _client = _server.CreateClient();
} }
@ -103,9 +106,9 @@ namespace Ocelot.UnitTests.Request
private void GivenTheRequestBuilderReturns(Ocelot.Request.Request request) private void GivenTheRequestBuilderReturns(Ocelot.Request.Request request)
{ {
_request = new OkResponse<Ocelot.Request.Request>(request); _request = new OkResponse<Ocelot.Request.Request>(request);
_requestBuilder _requestBuilder
.Setup(x => x.Build(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<Stream>(), It.IsAny<IHeaderDictionary>(), .Setup(x => x.Build(It.IsAny<HttpRequestMessage>(), It.IsAny<bool>(), It.IsAny<IQoSProvider>()))
It.IsAny<QueryString>(), It.IsAny<string>(), It.IsAny<Ocelot.RequestId.RequestId>(),It.IsAny<bool>(), It.IsAny<IQoSProvider>()))
.ReturnsAsync(_request); .ReturnsAsync(_request);
} }

View File

@ -0,0 +1,56 @@
namespace Ocelot.UnitTests.Request
{
using System.Net.Http;
using Ocelot.Request.Builder;
using Ocelot.Requester.QoS;
using Ocelot.Responses;
using Shouldly;
using TestStack.BDDfy;
using Xunit;
public class HttpRequestCreatorTests
{
private readonly IRequestCreator _requestCreator;
private readonly bool _isQos;
private readonly IQoSProvider _qoSProvider;
private readonly HttpRequestMessage _requestMessage;
private Response<Ocelot.Request.Request> _response;
public HttpRequestCreatorTests()
{
_requestCreator = new HttpRequestCreator();
_isQos = true;
_qoSProvider = new NoQoSProvider();
_requestMessage = new HttpRequestMessage();
}
[Fact]
public void ShouldBuildRequest()
{
this.When(x => x.WhenIBuildARequest())
.Then(x => x.ThenTheRequestContainsTheRequestMessage())
.BDDfy();
}
private void WhenIBuildARequest()
{
_response = _requestCreator.Build(_requestMessage, _isQos, _qoSProvider).GetAwaiter().GetResult();
}
private void ThenTheRequestContainsTheRequestMessage()
{
_response.Data.HttpRequestMessage.ShouldBe(_requestMessage);
}
private void ThenTheRequestContainsTheIsQos()
{
_response.Data.IsQos.ShouldBe(_isQos);
}
private void ThenTheRequestContainsTheQosProvider()
{
_response.Data.QosProvider.ShouldBe(_qoSProvider);
}
}
}

View File

@ -0,0 +1,258 @@
namespace Ocelot.UnitTests.Request.Mapper
{
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Internal;
using Microsoft.Extensions.Primitives;
using Ocelot.Request.Mapper;
using Ocelot.Responses;
using TestStack.BDDfy;
using Xunit;
using Shouldly;
using System;
using System.IO;
using System.Text;
public class RequestMapperTests
{
readonly HttpRequest _inputRequest;
readonly RequestMapper _requestMapper;
Response<HttpRequestMessage> _mappedRequest;
List<KeyValuePair<string, StringValues>> _inputHeaders = null;
public RequestMapperTests()
{
_inputRequest = new DefaultHttpRequest(new DefaultHttpContext());
_requestMapper = new RequestMapper();
}
[Theory]
[InlineData("https", "my.url:123", "/abc/DEF", "?a=1&b=2", "https://my.url:123/abc/DEF?a=1&b=2")]
[InlineData("http", "blah.com", "/d ef", "?abc=123", "http://blah.com/d%20ef?abc=123")] // note! the input is encoded when building the input request
[InlineData("http", "myusername:mypassword@abc.co.uk", null, null, "http://myusername:mypassword@abc.co.uk/")]
[InlineData("http", "點看.com", null, null, "http://xn--c1yn36f.com/")]
[InlineData("http", "xn--c1yn36f.com", null, null, "http://xn--c1yn36f.com/")]
public void Should_map_valid_request_uri(string scheme, string host, string path, string queryString, string expectedUri)
{
this.Given(_ => GivenTheInputRequestHasMethod("GET"))
.And(_ => GivenTheInputRequestHasScheme(scheme))
.And(_ => GivenTheInputRequestHasHost(host))
.And(_ => GivenTheInputRequestHasPath(path))
.And(_ => GivenTheInputRequestHasQueryString(queryString))
.When(_ => WhenMapped())
.Then(_ => ThenNoErrorIsReturned())
.And(_ => ThenTheMappedRequestHasUri(expectedUri))
.BDDfy();
}
[Theory]
[InlineData("ftp", "google.com", "/abc/DEF", "?a=1&b=2")]
public void Should_error_on_unsupported_request_uri(string scheme, string host, string path, string queryString)
{
this.Given(_ => GivenTheInputRequestHasMethod("GET"))
.And(_ => GivenTheInputRequestHasScheme(scheme))
.And(_ => GivenTheInputRequestHasHost(host))
.And(_ => GivenTheInputRequestHasPath(path))
.And(_ => GivenTheInputRequestHasQueryString(queryString))
.When(_ => WhenMapped())
.Then(_ => ThenAnErrorIsReturned())
.And(_ => ThenTheMappedRequestIsNull())
.BDDfy();
}
[Theory]
[InlineData("GET")]
[InlineData("POST")]
[InlineData("WHATEVER")]
public void Should_map_method(string method)
{
this.Given(_ => GivenTheInputRequestHasMethod(method))
.And(_ => GivenTheInputRequestHasAValidUri())
.When(_ => WhenMapped())
.Then(_ => ThenNoErrorIsReturned())
.And(_ => ThenTheMappedRequestHasMethod(method))
.BDDfy();
}
[Fact]
public void Should_map_all_headers()
{
this.Given(_ => GivenTheInputRequestHasHeaders())
.And(_ => GivenTheInputRequestHasMethod("GET"))
.And(_ => GivenTheInputRequestHasAValidUri())
.When(_ => WhenMapped())
.Then(_ => ThenNoErrorIsReturned())
.And(_ => ThenTheMappedRequestHasEachHeader())
.BDDfy();
}
[Fact]
public void Should_handle_no_headers()
{
this.Given(_ => GivenTheInputRequestHasNoHeaders())
.And(_ => GivenTheInputRequestHasMethod("GET"))
.And(_ => GivenTheInputRequestHasAValidUri())
.When(_ => WhenMapped())
.Then(_ => ThenNoErrorIsReturned())
.And(_ => ThenTheMappedRequestHasNoHeaders())
.BDDfy();
}
[Fact]
public void Should_map_content()
{
this.Given(_ => GivenTheInputRequestHasContent("This is my content"))
.And(_ => GivenTheInputRequestHasMethod("GET"))
.And(_ => GivenTheInputRequestHasAValidUri())
.When(_ => WhenMapped())
.Then(_ => ThenNoErrorIsReturned())
.And(_ => ThenTheMappedRequestHasContent("This is my content"))
.BDDfy();
}
[Fact]
public void Should_handle_no_content()
{
this.Given(_ => GivenTheInputRequestHasNoContent())
.And(_ => GivenTheInputRequestHasMethod("GET"))
.And(_ => GivenTheInputRequestHasAValidUri())
.When(_ => WhenMapped())
.Then(_ => ThenNoErrorIsReturned())
.And(_ => ThenTheMappedRequestHasNoContent())
.BDDfy();
}
private void GivenTheInputRequestHasMethod(string method)
{
_inputRequest.Method = method;
}
private void GivenTheInputRequestHasScheme(string scheme)
{
_inputRequest.Scheme = scheme;
}
private void GivenTheInputRequestHasHost(string host)
{
_inputRequest.Host = new HostString(host);
}
private void GivenTheInputRequestHasPath(string path)
{
if (path != null)
{
_inputRequest.Path = path;
}
}
private void GivenTheInputRequestHasQueryString(string querystring)
{
if (querystring != null)
{
_inputRequest.QueryString = new QueryString(querystring);
}
}
private void GivenTheInputRequestHasAValidUri()
{
GivenTheInputRequestHasScheme("http");
GivenTheInputRequestHasHost("www.google.com");
}
private void GivenTheInputRequestHasHeaders()
{
_inputHeaders = new List<KeyValuePair<string, StringValues>>()
{
new KeyValuePair<string, StringValues>("abc", new StringValues(new string[]{"123","456" })),
new KeyValuePair<string, StringValues>("def", new StringValues(new string[]{"789","012" })),
};
foreach (var inputHeader in _inputHeaders)
{
_inputRequest.Headers.Add(inputHeader);
}
}
private void GivenTheInputRequestHasNoHeaders()
{
_inputRequest.Headers.Clear();
}
private void GivenTheInputRequestHasContent(string content)
{
_inputRequest.Body = new MemoryStream(Encoding.UTF8.GetBytes(content));
}
private void GivenTheInputRequestHasNoContent()
{
_inputRequest.Body = null;
}
private void WhenMapped()
{
_mappedRequest = _requestMapper.Map(_inputRequest).GetAwaiter().GetResult();
}
private void ThenNoErrorIsReturned()
{
_mappedRequest.IsError.ShouldBeFalse();
}
private void ThenAnErrorIsReturned()
{
_mappedRequest.IsError.ShouldBeTrue();
}
private void ThenTheMappedRequestHasUri(string expectedUri)
{
_mappedRequest.Data.RequestUri.OriginalString.ShouldBe(expectedUri);
}
private void ThenTheMappedRequestHasMethod(string expectedMethod)
{
_mappedRequest.Data.Method.ToString().ShouldBe(expectedMethod);
}
private void ThenTheMappedRequestHasEachHeader()
{
_mappedRequest.Data.Headers.Count().ShouldBe(_inputHeaders.Count);
foreach(var header in _mappedRequest.Data.Headers)
{
var inputHeader = _inputHeaders.First(h => h.Key == header.Key);
inputHeader.ShouldNotBeNull();
inputHeader.Value.Count().ShouldBe(header.Value.Count());
foreach(var inputHeaderValue in inputHeader.Value)
{
header.Value.Any(v => v == inputHeaderValue);
}
}
}
private void ThenTheMappedRequestHasNoHeaders()
{
_mappedRequest.Data.Headers.Count().ShouldBe(0);
}
private void ThenTheMappedRequestHasContent(string expectedContent)
{
_mappedRequest.Data.Content.ReadAsStringAsync().GetAwaiter().GetResult().ShouldBe(expectedContent);
}
private void ThenTheMappedRequestHasNoContent()
{
_mappedRequest.Data.Content.ShouldBeNull();
}
private void ThenTheMappedRequestIsNull()
{
_mappedRequest.Data.ShouldBeNull();
}
}
}

View File

@ -1,310 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Internal;
using Ocelot.Request.Builder;
using Ocelot.Responses;
using Shouldly;
using TestStack.BDDfy;
using Xunit;
using Ocelot.Configuration;
using Ocelot.Requester.QoS;
namespace Ocelot.UnitTests.Request
{
public class RequestBuilderTests
{
private string _httpMethod;
private string _downstreamUrl;
private HttpContent _content;
private IHeaderDictionary _headers;
private IRequestCookieCollection _cookies;
private QueryString _query;
private string _contentType;
private readonly IRequestCreator _requestCreator;
private Response<Ocelot.Request.Request> _result;
private Ocelot.RequestId.RequestId _requestId;
private bool _isQos;
private IQoSProvider _qoSProvider;
public RequestBuilderTests()
{
_content = new StringContent(string.Empty);
_requestCreator = new HttpRequestCreator();
}
[Fact]
public void should_user_downstream_url()
{
this.Given(x => x.GivenIHaveHttpMethod("GET"))
.And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk"))
.And(x=> x.GivenTheQos(true, new NoQoSProvider()))
.When(x => x.WhenICreateARequest())
.And(x => x.ThenTheCorrectDownstreamUrlIsUsed("http://www.bbc.co.uk/"))
.BDDfy();
}
[Fact]
public void should_use_http_method()
{
this.Given(x => x.GivenIHaveHttpMethod("POST"))
.And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk"))
.And(x => x.GivenTheQos(true, new NoQoSProvider()))
.When(x => x.WhenICreateARequest())
.And(x => x.ThenTheCorrectHttpMethodIsUsed(HttpMethod.Post))
.BDDfy();
}
[Fact]
public void should_use_http_content()
{
this.Given(x => x.GivenIHaveHttpMethod("POST"))
.And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk"))
.And(x => x.GivenIHaveTheHttpContent(new StringContent("Hi from Tom")))
.And(x => x.GivenTheContentTypeIs("application/json"))
.And(x => x.GivenTheQos(true, new NoQoSProvider()))
.When(x => x.WhenICreateARequest())
.And(x => x.ThenTheCorrectContentIsUsed(new StringContent("Hi from Tom")))
.BDDfy();
}
[Fact]
public void should_use_http_content_headers()
{
this.Given(x => x.GivenIHaveHttpMethod("POST"))
.And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk"))
.And(x => x.GivenIHaveTheHttpContent(new StringContent("Hi from Tom")))
.And(x => x.GivenTheContentTypeIs("application/json"))
.And(x => x.GivenTheQos(true, new NoQoSProvider()))
.When(x => x.WhenICreateARequest())
.And(x => x.ThenTheCorrectContentHeadersAreUsed(new HeaderDictionary
{
{
"Content-Type", "application/json"
}
}))
.BDDfy();
}
[Fact]
public void should_use_unvalidated_http_content_headers()
{
this.Given(x => x.GivenIHaveHttpMethod("POST"))
.And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk"))
.And(x => x.GivenIHaveTheHttpContent(new StringContent("Hi from Tom")))
.And(x => x.GivenTheContentTypeIs("application/json; charset=utf-8"))
.And(x => x.GivenTheQos(true, new NoQoSProvider()))
.When(x => x.WhenICreateARequest())
.And(x => x.ThenTheCorrectContentHeadersAreUsed(new HeaderDictionary
{
{
"Content-Type", "application/json; charset=utf-8"
}
}))
.BDDfy();
}
[Fact]
public void should_use_headers()
{
this.Given(x => x.GivenIHaveHttpMethod("GET"))
.And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk"))
.And(x => x.GivenTheHttpHeadersAre(new HeaderDictionary
{
{"ChopSticks", "Bubbles" }
}))
.And(x => x.GivenTheQos(true, new NoQoSProvider()))
.When(x => x.WhenICreateARequest())
.And(x => x.ThenTheCorrectHeadersAreUsed(new HeaderDictionary
{
{"ChopSticks", "Bubbles" }
}))
.BDDfy();
}
[Fact]
public void should_use_request_id()
{
var requestId = Guid.NewGuid().ToString();
this.Given(x => x.GivenIHaveHttpMethod("GET"))
.And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk"))
.And(x => x.GivenTheHttpHeadersAre(new HeaderDictionary()))
.And(x => x.GivenTheRequestIdIs(new Ocelot.RequestId.RequestId("RequestId", requestId)))
.And(x => x.GivenTheQos(true, new NoQoSProvider()))
.When(x => x.WhenICreateARequest())
.And(x => x.ThenTheCorrectHeadersAreUsed(new HeaderDictionary
{
{"RequestId", requestId }
}))
.BDDfy();
}
[Fact]
public void should_not_use_request_if_if_already_in_headers()
{
this.Given(x => x.GivenIHaveHttpMethod("GET"))
.And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk"))
.And(x => x.GivenTheHttpHeadersAre(new HeaderDictionary
{
{"RequestId", "534534gv54gv45g" }
}))
.And(x => x.GivenTheRequestIdIs(new Ocelot.RequestId.RequestId("RequestId", Guid.NewGuid().ToString())))
.And(x => x.GivenTheQos(true, new NoQoSProvider()))
.When(x => x.WhenICreateARequest())
.And(x => x.ThenTheCorrectHeadersAreUsed(new HeaderDictionary
{
{"RequestId", "534534gv54gv45g" }
}))
.BDDfy();
}
[Theory]
[InlineData(null, "blahh")]
[InlineData("", "blahh")]
[InlineData("RequestId", "")]
[InlineData("RequestId", null)]
public void should_not_use_request_id(string requestIdKey, string requestIdValue)
{
this.Given(x => x.GivenIHaveHttpMethod("GET"))
.And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk"))
.And(x => x.GivenTheHttpHeadersAre(new HeaderDictionary()))
.And(x => x.GivenTheRequestIdIs(new Ocelot.RequestId.RequestId(requestIdKey, requestIdValue)))
.And(x => x.GivenTheQos(true, new NoQoSProvider()))
.When(x => x.WhenICreateARequest())
.And(x => x.ThenTheRequestIdIsNotInTheHeaders())
.BDDfy();
}
private void GivenTheRequestIdIs(Ocelot.RequestId.RequestId requestId)
{
_requestId = requestId;
}
private void GivenTheQos(bool isQos, IQoSProvider qoSProvider)
{
_isQos = isQos;
_qoSProvider = qoSProvider;
}
[Fact]
public void should_user_query_string()
{
this.Given(x => x.GivenIHaveHttpMethod("POST"))
.And(x => x.GivenIHaveDownstreamUrl("http://www.bbc.co.uk"))
.And(x => x.GivenTheQueryStringIs(new QueryString("?jeff=1&geoff=2")))
.When(x => x.WhenICreateARequest())
.And(x => x.ThenTheCorrectQueryStringIsUsed("?jeff=1&geoff=2"))
.BDDfy();
}
private void GivenTheContentTypeIs(string contentType)
{
_contentType = contentType;
}
private void ThenTheCorrectQueryStringIsUsed(string expected)
{
_result.Data.HttpRequestMessage.RequestUri.Query.ShouldBe(expected);
}
private void GivenTheQueryStringIs(QueryString query)
{
_query = query;
}
private void ThenTheCorrectCookiesAreUsed(IRequestCookieCollection expected)
{
/* var resultCookies = _result.Data.CookieContainer.GetCookies(new Uri(_downstreamUrl + _query));
var resultDictionary = resultCookies.Cast<Cookie>().ToDictionary(cook => cook.Name, cook => cook.Value);
foreach (var expectedCookie in expected)
{
var resultCookie = resultDictionary[expectedCookie.Key];
resultCookie.ShouldBe(expectedCookie.Value);
}*/
}
private void GivenTheCookiesAre(IRequestCookieCollection cookies)
{
_cookies = cookies;
}
private void ThenTheRequestIdIsNotInTheHeaders()
{
_result.Data.HttpRequestMessage.Headers.ShouldNotContain(x => x.Key == "RequestId");
}
private void ThenTheCorrectHeadersAreUsed(IHeaderDictionary expected)
{
var expectedHeaders = expected.Select(x => new KeyValuePair<string, string[]>(x.Key, x.Value));
foreach (var expectedHeader in expectedHeaders)
{
_result.Data.HttpRequestMessage.Headers.ShouldContain(x => x.Key == expectedHeader.Key && x.Value.First() == expectedHeader.Value[0]);
}
}
private void ThenTheCorrectContentHeadersAreUsed(IHeaderDictionary expected)
{
var expectedHeaders = expected.Select(x => new KeyValuePair<string, string[]>(x.Key, x.Value));
foreach (var expectedHeader in expectedHeaders)
{
_result.Data.HttpRequestMessage.Content.Headers.ShouldContain(x => x.Key == expectedHeader.Key
&& x.Value.First() == expectedHeader.Value[0]
);
}
}
private void GivenTheHttpHeadersAre(IHeaderDictionary headers)
{
_headers = headers;
}
private void GivenIHaveTheHttpContent(HttpContent content)
{
_content = content;
}
private void GivenIHaveHttpMethod(string httpMethod)
{
_httpMethod = httpMethod;
}
private void GivenIHaveDownstreamUrl(string downstreamUrl)
{
_downstreamUrl = downstreamUrl;
}
private void WhenICreateARequest()
{
_result = _requestCreator.Build(_httpMethod, _downstreamUrl, _content?.ReadAsStreamAsync().Result, _headers,
_query, _contentType, _requestId,_isQos,_qoSProvider).Result;
}
private void ThenTheCorrectDownstreamUrlIsUsed(string expected)
{
_result.Data.HttpRequestMessage.RequestUri.AbsoluteUri.ShouldBe(expected);
}
private void ThenTheCorrectHttpMethodIsUsed(HttpMethod expected)
{
_result.Data.HttpRequestMessage.Method.Method.ShouldBe(expected.Method);
}
private void ThenTheCorrectContentIsUsed(HttpContent expected)
{
_result.Data.HttpRequestMessage.Content.ReadAsStringAsync().Result.ShouldBe(expected.ReadAsStringAsync().Result);
}
}
}

View File

@ -8,14 +8,12 @@ using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Moq; using Moq;
using Ocelot.Configuration.Builder; using Ocelot.Configuration.Builder;
using Ocelot.DownstreamRouteFinder; using Ocelot.DownstreamRouteFinder;
using Ocelot.DownstreamRouteFinder.UrlMatcher; using Ocelot.DownstreamRouteFinder.UrlMatcher;
using Ocelot.Infrastructure.RequestData; using Ocelot.Infrastructure.RequestData;
using Ocelot.Logging; using Ocelot.Logging;
using Ocelot.Request.Middleware;
using Ocelot.RequestId.Middleware; using Ocelot.RequestId.Middleware;
using Ocelot.Responses; using Ocelot.Responses;
using Shouldly; using Shouldly;
@ -27,6 +25,7 @@ namespace Ocelot.UnitTests.RequestId
public class RequestIdMiddlewareTests public class RequestIdMiddlewareTests
{ {
private readonly Mock<IRequestScopedDataRepository> _scopedRepository; private readonly Mock<IRequestScopedDataRepository> _scopedRepository;
private readonly HttpRequestMessage _downstreamRequest;
private readonly string _url; private readonly string _url;
private readonly TestServer _server; private readonly TestServer _server;
private readonly HttpClient _client; private readonly HttpClient _client;
@ -64,10 +63,16 @@ namespace Ocelot.UnitTests.RequestId
_server = new TestServer(builder); _server = new TestServer(builder);
_client = _server.CreateClient(); _client = _server.CreateClient();
_downstreamRequest = new HttpRequestMessage();
_scopedRepository
.Setup(sr => sr.Get<HttpRequestMessage>("DownstreamRequest"))
.Returns(new OkResponse<HttpRequestMessage>(_downstreamRequest));
} }
[Fact] [Fact]
public void should_add_request_id_to_repository() public void should_pass_down_request_id_from_upstream_request()
{ {
var downstreamRoute = new DownstreamRoute(new List<UrlPathPlaceholderNameAndValue>(), var downstreamRoute = new DownstreamRoute(new List<UrlPathPlaceholderNameAndValue>(),
new ReRouteBuilder() new ReRouteBuilder()
@ -86,7 +91,7 @@ namespace Ocelot.UnitTests.RequestId
} }
[Fact] [Fact]
public void should_add_trace_indentifier_to_repository() public void should_add_request_id_when_not_on_upstream_request()
{ {
var downstreamRoute = new DownstreamRoute(new List<UrlPathPlaceholderNameAndValue>(), var downstreamRoute = new DownstreamRoute(new List<UrlPathPlaceholderNameAndValue>(),
new ReRouteBuilder() new ReRouteBuilder()
@ -101,14 +106,12 @@ namespace Ocelot.UnitTests.RequestId
.BDDfy(); .BDDfy();
} }
private void ThenTheTraceIdIsAnything() private void GivenTheDownStreamRouteIs(DownstreamRoute downstreamRoute)
{ {
_result.Headers.GetValues("LSRequestId").First().ShouldNotBeNullOrEmpty(); _downstreamRoute = new OkResponse<DownstreamRoute>(downstreamRoute);
} _scopedRepository
.Setup(x => x.Get<DownstreamRoute>(It.IsAny<string>()))
private void ThenTheTraceIdIs(string expected) .Returns(_downstreamRoute);
{
_result.Headers.GetValues("LSRequestId").First().ShouldBe(expected);
} }
private void GivenTheRequestIdIsAddedToTheRequest(string key, string value) private void GivenTheRequestIdIsAddedToTheRequest(string key, string value)
@ -123,12 +126,14 @@ namespace Ocelot.UnitTests.RequestId
_result = _client.GetAsync(_url).Result; _result = _client.GetAsync(_url).Result;
} }
private void GivenTheDownStreamRouteIs(DownstreamRoute downstreamRoute) private void ThenTheTraceIdIsAnything()
{ {
_downstreamRoute = new OkResponse<DownstreamRoute>(downstreamRoute); _result.Headers.GetValues("LSRequestId").First().ShouldNotBeNullOrEmpty();
_scopedRepository }
.Setup(x => x.Get<DownstreamRoute>(It.IsAny<string>()))
.Returns(_downstreamRoute); private void ThenTheTraceIdIs(string expected)
{
_result.Headers.GetValues("LSRequestId").First().ShouldBe(expected);
} }
public void Dispose() public void Dispose()