diff --git a/src/Ocelot/QueryStrings/AddQueriesToRequest.cs b/src/Ocelot/QueryStrings/AddQueriesToRequest.cs index 25e9772a..74cc7696 100644 --- a/src/Ocelot/QueryStrings/AddQueriesToRequest.cs +++ b/src/Ocelot/QueryStrings/AddQueriesToRequest.cs @@ -7,6 +7,8 @@ using Ocelot.Responses; using System.Security.Claims; using System.Net.Http; using System; +using Microsoft.Extensions.Primitives; +using System.Text; namespace Ocelot.QueryStrings { @@ -45,6 +47,7 @@ namespace Ocelot.QueryStrings } var uriBuilder = new UriBuilder(downstreamRequest.RequestUri); + uriBuilder.Query = ConvertDictionaryToQueryString(queryDictionary); downstreamRequest.RequestUri = uriBuilder.Uri; @@ -52,16 +55,43 @@ namespace Ocelot.QueryStrings return new OkResponse(); } - private Dictionary ConvertQueryStringToDictionary(string queryString) + private Dictionary ConvertQueryStringToDictionary(string queryString) { - return Microsoft.AspNetCore.WebUtilities.QueryHelpers - .ParseQuery(queryString) - .ToDictionary(q => q.Key, q => q.Value.FirstOrDefault() ?? string.Empty); + var query = Microsoft.AspNetCore.WebUtilities.QueryHelpers + .ParseQuery(queryString); + + return query; } - private string ConvertDictionaryToQueryString(Dictionary queryDictionary) + private string ConvertDictionaryToQueryString(Dictionary queryDictionary) { - return Microsoft.AspNetCore.WebUtilities.QueryHelpers.AddQueryString("", queryDictionary); + var builder = new StringBuilder(); + + builder.Append("?"); + + int outerCount = 0; + + foreach (var query in queryDictionary) + { + for (int innerCount = 0; innerCount < query.Value.Count; innerCount++) + { + builder.Append($"{query.Key}={query.Value[innerCount]}"); + + if(innerCount < (query.Value.Count - 1)) + { + builder.Append("&"); + } + } + + if(outerCount < (queryDictionary.Count - 1)) + { + builder.Append("&"); + } + + outerCount++; + } + + return builder.ToString(); } } } \ No newline at end of file diff --git a/test/Ocelot.AcceptanceTests/ClaimsToQueryStringForwardingTests.cs b/test/Ocelot.AcceptanceTests/ClaimsToQueryStringForwardingTests.cs index f73a360e..1b9469e1 100644 --- a/test/Ocelot.AcceptanceTests/ClaimsToQueryStringForwardingTests.cs +++ b/test/Ocelot.AcceptanceTests/ClaimsToQueryStringForwardingTests.cs @@ -19,6 +19,7 @@ namespace Ocelot.AcceptanceTests { using IdentityServer4; using IdentityServer4.Test; + using Shouldly; public class ClaimsToQueryStringForwardingTests : IDisposable { @@ -27,6 +28,7 @@ namespace Ocelot.AcceptanceTests private readonly Steps _steps; private Action _options; private string _identityServerRootUrl = "http://localhost:57888"; + private string _downstreamQueryString; public ClaimsToQueryStringForwardingTests() { @@ -105,6 +107,71 @@ namespace Ocelot.AcceptanceTests .BDDfy(); } + [Fact] + public void should_return_response_200_and_foward_claim_as_query_string_and_preserve_original_string() + { + var user = new TestUser() + { + Username = "test", + Password = "test", + SubjectId = "registered|1231231", + Claims = new List + { + new Claim("CustomerId", "123"), + new Claim("LocationId", "1") + } + }; + + var configuration = new FileConfiguration + { + ReRoutes = new List + { + new FileReRoute + { + DownstreamPathTemplate = "/", + DownstreamHostAndPorts = new List + { + new FileHostAndPort + { + Host = "localhost", + Port = 57876, + } + }, + DownstreamScheme = "http", + UpstreamPathTemplate = "/", + UpstreamHttpMethod = new List { "Get" }, + AuthenticationOptions = new FileAuthenticationOptions + { + AuthenticationProviderKey = "Test", + AllowedScopes = new List + { + "openid", "offline_access", "api" + }, + }, + AddQueriesToRequest = + { + {"CustomerId", "Claims[CustomerId] > value"}, + {"LocationId", "Claims[LocationId] > value"}, + {"UserType", "Claims[sub] > value[0] > |"}, + {"UserId", "Claims[sub] > value[1] > |"} + } + } + } + }; + + this.Given(x => x.GivenThereIsAnIdentityServerOn("http://localhost:57888", "api", AccessTokenType.Jwt, user)) + .And(x => x.GivenThereIsAServiceRunningOn("http://localhost:57876", 200)) + .And(x => _steps.GivenIHaveAToken("http://localhost:57888")) + .And(x => _steps.GivenThereIsAConfiguration(configuration)) + .And(x => _steps.GivenOcelotIsRunning(_options, "Test")) + .And(x => _steps.GivenIHaveAddedATokenToMyRequest()) + .When(x => _steps.WhenIGetUrlOnTheApiGateway("/?test=1&test=2")) + .Then(x => _steps.ThenTheStatusCodeShouldBe(HttpStatusCode.OK)) + .And(x => _steps.ThenTheResponseBodyShouldBe("CustomerId: 123 LocationId: 1 UserType: registered UserId: 1231231")) + .And(_ => _downstreamQueryString.ShouldBe("?test=1&test=2&CustomerId=123&LocationId=1&UserId=1231231&UserType=registered")) + .BDDfy(); + } + private void GivenThereIsAServiceRunningOn(string url, int statusCode) { _servicebuilder = new WebHostBuilder() @@ -117,6 +184,8 @@ namespace Ocelot.AcceptanceTests { app.Run(async context => { + _downstreamQueryString = context.Request.QueryString.Value; + StringValues customerId; context.Request.Query.TryGetValue("CustomerId", out customerId); diff --git a/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs b/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs index d12d4294..bebe4307 100644 --- a/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs +++ b/test/Ocelot.UnitTests/QueryStrings/AddQueriesToRequestTests.cs @@ -18,7 +18,7 @@ namespace Ocelot.UnitTests.QueryStrings public class AddQueriesToRequestTests { private readonly AddQueriesToRequest _addQueriesToRequest; - private readonly HttpRequestMessage _downstreamRequest; + private HttpRequestMessage _downstreamRequest; private readonly Mock _parser; private List _configuration; private List _claims; @@ -53,6 +53,34 @@ namespace Ocelot.UnitTests.QueryStrings .BDDfy(); } + [Fact] + public void should_add_new_queries_to_downstream_request_and_preserve_other_queries() + { + var claims = new List + { + new Claim("test", "data") + }; + + this.Given( + x => x.GivenAClaimToThing(new List + { + new ClaimToThing("query-key", "", "", 0) + })) + .Given(x => x.GivenClaims(claims)) + .And(x => GivenTheDownstreamRequestHasQueryString("?test=1&test=2")) + .And(x => x.GivenTheClaimParserReturns(new OkResponse("value"))) + .When(x => x.WhenIAddQueriesToTheRequest()) + .Then(x => x.ThenTheResultIsSuccess()) + .And(x => x.ThenTheQueryIsAdded()) + .And(x => TheTheQueryStringIs("?test=1&test=2&query-key=value")) + .BDDfy(); + } + + private void TheTheQueryStringIs(string expected) + { + _downstreamRequest.RequestUri.Query.ShouldBe(expected); + } + [Fact] public void should_replace_existing_queries_on_downstream_request() { @@ -110,6 +138,11 @@ namespace Ocelot.UnitTests.QueryStrings _claims = claims; } + private void GivenTheDownstreamRequestHasQueryString(string queryString) + { + _downstreamRequest = new HttpRequestMessage(HttpMethod.Post, $"http://my.url/abc{queryString}"); + } + private void GivenTheDownstreamRequestHasQueryString(string key, string value) { var newUri = Microsoft.AspNetCore.WebUtilities.QueryHelpers