Skip to content

Commit

Permalink
Merge client- and request-level cookies in the header. Also, dispose …
Browse files Browse the repository at this point in the history
…the request if not downloading data. (#2056)
  • Loading branch information
alexeyzimarev committed Apr 9, 2023
1 parent e147e1c commit 794348e
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 22 deletions.
31 changes: 31 additions & 0 deletions src/RestSharp/Extensions/CookieContainerExtensions.cs
@@ -0,0 +1,31 @@
// Copyright (c) .NET Foundation and Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

using System.Net;

namespace RestSharp.Extensions;

static class CookieContainerExtensions {
public static void AddCookies(this CookieContainer cookieContainer, Uri uri, IEnumerable<string> cookiesHeader) {
foreach (var header in cookiesHeader) {
try {
cookieContainer.SetCookies(uri, header);
}
catch (CookieException) {
// Do not fail request if we cannot parse a cookie
}
}
}
}
17 changes: 14 additions & 3 deletions src/RestSharp/Request/RequestHeaders.cs
Expand Up @@ -39,13 +39,24 @@ class RequestHeaders {
}

// Add Cookie header from the cookie container
public RequestHeaders AddCookieHeaders(CookieContainer cookieContainer, Uri uri) {
public RequestHeaders AddCookieHeaders(Uri uri, CookieContainer? cookieContainer) {
if (cookieContainer == null) return this;

var cookies = cookieContainer.GetCookieHeader(uri);

if (cookies.Length > 0) {
Parameters.AddParameter(new HeaderParameter(KnownHeaders.Cookie, cookies));
if (string.IsNullOrWhiteSpace(cookies)) return this;

var newCookies = SplitHeader(cookies);
var existing = Parameters.GetParameters<HeaderParameter>().FirstOrDefault(x => x.Name == KnownHeaders.Cookie);

if (existing?.Value != null) {
newCookies = newCookies.Union(SplitHeader(existing.Value.ToString()!));
}

Parameters.AddParameter(new HeaderParameter(KnownHeaders.Cookie, string.Join("; ", newCookies)));

return this;

IEnumerable<string> SplitHeader(string header) => header.Split(';').Select(x => x.Trim());
}
}
28 changes: 12 additions & 16 deletions src/RestSharp/RestClient.Async.cs
Expand Up @@ -21,7 +21,7 @@ namespace RestSharp;
public partial class RestClient {
/// <inheritdoc />
public async Task<RestResponse> ExecuteAsync(RestRequest request, CancellationToken cancellationToken = default) {
var internalResponse = await ExecuteRequestAsync(request, cancellationToken).ConfigureAwait(false);
using var internalResponse = await ExecuteRequestAsync(request, cancellationToken).ConfigureAwait(false);

var response = internalResponse.Exception == null
? await RestResponse.FromHttpResponse(
Expand Down Expand Up @@ -85,7 +85,8 @@ public partial class RestClient {

var httpMethod = AsHttpMethod(request.Method);
var url = this.BuildUri(request);
var message = new HttpRequestMessage(httpMethod, url) { Content = requestContent.BuildContent() };

using var message = new HttpRequestMessage(httpMethod, url) { Content = requestContent.BuildContent() };
message.Headers.Host = Options.BaseHost;
message.Headers.CacheControl = request.CachePolicy ?? Options.CachePolicy;

Expand All @@ -102,11 +103,8 @@ public partial class RestClient {
.AddHeaders(request.Parameters)
.AddHeaders(DefaultParameters)
.AddAcceptHeader(AcceptedContentTypes)
.AddCookieHeaders(cookieContainer, url);

if (Options.CookieContainer != null) {
headers.AddCookieHeaders(Options.CookieContainer, url);
}
.AddCookieHeaders(url, cookieContainer)
.AddCookieHeaders(url, Options.CookieContainer);

message.AddHeaders(headers);

Expand All @@ -116,14 +114,10 @@ public partial class RestClient {

// Parse all the cookies from the response and update the cookie jar with cookies
if (responseMessage.Headers.TryGetValues(KnownHeaders.SetCookie, out var cookiesHeader)) {
foreach (var header in cookiesHeader) {
try {
cookieContainer.SetCookies(url, header);
}
catch (CookieException) {
// Do not fail request if we cannot parse a cookie
}
}
// ReSharper disable once PossibleMultipleEnumeration
cookieContainer.AddCookies(url, cookiesHeader);
// ReSharper disable once PossibleMultipleEnumeration
Options.CookieContainer?.AddCookies(url, cookiesHeader);
}

if (request.OnAfterRequest != null) await request.OnAfterRequest(responseMessage).ConfigureAwait(false);
Expand All @@ -141,7 +135,9 @@ record HttpResponse(
CookieContainer? CookieContainer,
Exception? Exception,
CancellationToken TimeoutToken
);
) : IDisposable {
public void Dispose() => ResponseMessage?.Dispose();
}

static HttpMethod AsHttpMethod(Method method)
=> method switch {
Expand Down
24 changes: 21 additions & 3 deletions test/RestSharp.Tests.Integrated/CookieTests.cs
Expand Up @@ -9,7 +9,10 @@ public class CookieTests {
readonly string _host;

public CookieTests(TestServerFixture fixture) {
_client = new RestClient(fixture.Server.Url);
var options = new RestClientOptions(fixture.Server.Url) {
CookieContainer = new CookieContainer()
};
_client = new RestClient(options);
_host = _client.Options.BaseUrl!.Host;
}

Expand All @@ -24,6 +27,21 @@ public class CookieTests {
response.Content.Should().Be("[\"cookie=value\",\"cookie2=value2\"]");
}

[Fact]
public async Task Can_Perform_GET_Async_With_Request_And_Client_Cookies() {
_client.Options.CookieContainer!.Add(new Cookie("clientCookie", "clientCookieValue", null, _host));

var request = new RestRequest("get-cookies") {
CookieContainer = new CookieContainer()
};
request.CookieContainer.Add(new Cookie("cookie", "value", null, _host));
request.CookieContainer.Add(new Cookie("cookie2", "value2", null, _host));
var response = await _client.ExecuteAsync<string[]>(request);

var expected = new[] { "cookie=value", "cookie2=value2", "clientCookie=clientCookieValue" };
response.Data.Should().BeEquivalentTo(expected);
}

[Fact]
public async Task Can_Perform_GET_Async_With_Response_Cookies() {
var request = new RestRequest("set-cookies");
Expand All @@ -37,7 +55,7 @@ public class CookieTests {
FindCookie("cookie5").Should().BeNull("Cookie 5 should vanish as the request is not SSL");
AssertCookie("cookie6", "value6", x => x == DateTime.MinValue, true);

Cookie? FindCookie(string name) =>response!.Cookies!.FirstOrDefault(p => p.Name == name);
Cookie? FindCookie(string name) => response.Cookies!.FirstOrDefault(p => p.Name == name);

void AssertCookie(string name, string value, Func<DateTime, bool> checkExpiration, bool httpOnly = false) {
var c = FindCookie(name)!;
Expand All @@ -62,7 +80,7 @@ public class CookieTests {
.SingleOrDefault(h => h.Name == KnownHeaders.SetCookie && ((string)h.Value!).StartsWith("cookie_empty_domain"));
emptyDomainCookieHeader.Should().NotBeNull();
((string)emptyDomainCookieHeader!.Value!).Should().Contain("domain=;");

Cookie? FindCookie(string name) => response!.Cookies!.FirstOrDefault(p => p.Name == name);
}
}

0 comments on commit 794348e

Please sign in to comment.