Skip to content

Commit

Permalink
Add CacheControl (#2053)
Browse files Browse the repository at this point in the history
Cleanup response writer code
  • Loading branch information
alexeyzimarev committed Apr 9, 2023
1 parent edce8c9 commit e147e1c
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 101 deletions.
53 changes: 49 additions & 4 deletions src/RestSharp/Extensions/HttpResponseExtensions.cs
Expand Up @@ -13,15 +13,60 @@
// limitations under the License.
//

using System.Text;

namespace RestSharp.Extensions;

public static class HttpResponseExtensions {
internal static Exception? MaybeException(this HttpResponseMessage httpResponse)
static class HttpResponseExtensions {
public static Exception? MaybeException(this HttpResponseMessage httpResponse)
=> httpResponse.IsSuccessStatusCode
? null
#if NETSTANDARD || NETFRAMEWORK
#if NET
: new HttpRequestException($"Request failed with status code {httpResponse.StatusCode}", null, httpResponse.StatusCode);
#else
: new HttpRequestException($"Request failed with status code {httpResponse.StatusCode}");
#endif

public static string GetResponseString(this HttpResponseMessage response, byte[] bytes, Encoding clientEncoding) {
var encodingString = response.Content.Headers.ContentType?.CharSet;
var encoding = encodingString != null ? TryGetEncoding(encodingString) : clientEncoding;

using var reader = new StreamReader(new MemoryStream(bytes), encoding);
return reader.ReadToEnd();

Encoding TryGetEncoding(string es) {
try {
return Encoding.GetEncoding(es);
}
catch {
return Encoding.Default;
}
}
}

public static Task<Stream?> ReadResponseStream(
this HttpResponseMessage httpResponse,
Func<Stream, Stream?>? writer,
CancellationToken cancellationToken = default
) {
var readTask = writer == null ? ReadResponse() : ReadAndConvertResponse();
return readTask;

Task<Stream?> ReadResponse() {
#if NET
return httpResponse.Content.ReadAsStreamAsync(cancellationToken)!;
# else
return httpResponse.Content.ReadAsStreamAsync();
#endif
}

async Task<Stream?> ReadAndConvertResponse() {
#if NET
await using var original = await ReadResponse().ConfigureAwait(false);
#else
: new HttpRequestException($"Request failed with status code {httpResponse.StatusCode}", null, httpResponse.StatusCode);
using var original = await ReadResponse().ConfigureAwait(false);
#endif
return writer!(original!);
}
}
}
6 changes: 3 additions & 3 deletions src/RestSharp/Extensions/StreamExtensions.cs
Expand Up @@ -30,10 +30,10 @@ static class StreamExtensions {
using var ms = new MemoryStream();

int read;
#if NETSTANDARD || NETFRAMEWORK
while ((read = await input.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false)) > 0)
#else
#if NET
while ((read = await input.ReadAsync(buffer, cancellationToken).ConfigureAwait(false)) > 0)
#else
while ((read = await input.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false)) > 0)
#endif
ms.Write(buffer, 0, read);

Expand Down
2 changes: 1 addition & 1 deletion src/RestSharp/Properties/IsExternalInit.cs
@@ -1,4 +1,4 @@
#if NETSTANDARD || NETFRAMEWORK
#if !NET
using System.ComponentModel;

// ReSharper disable once CheckNamespace
Expand Down
14 changes: 10 additions & 4 deletions src/RestSharp/Request/RestRequest.cs
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

using System.Net;
using System.Net.Http.Headers;
using RestSharp.Authenticators;
using RestSharp.Extensions;

Expand All @@ -25,8 +26,8 @@ namespace RestSharp;
/// Container for data used to make requests
/// </summary>
public class RestRequest {
readonly Func<HttpResponseMessage, RestRequest, RestResponse>? _advancedResponseHandler;
readonly Func<Stream, Stream?>? _responseWriter;
Func<HttpResponseMessage, RestRequest, RestResponse>? _advancedResponseHandler;
Func<Stream, Stream?>? _responseWriter;

/// <summary>
/// Default constructor
Expand Down Expand Up @@ -186,12 +187,17 @@ public RestRequest(Uri resource, Method method = Method.Get)
/// </summary>
public HttpCompletionOption CompletionOption { get; set; } = HttpCompletionOption.ResponseContentRead;

/// <summary>
/// Cache policy to be used for requests using <seealso cref="CacheControlHeaderValue"/>
/// </summary>
public CacheControlHeaderValue? CachePolicy { get; set; }

/// <summary>
/// Set this to write response to Stream rather than reading into memory.
/// </summary>
public Func<Stream, Stream?>? ResponseWriter {
get => _responseWriter;
init {
set {
if (AdvancedResponseWriter != null)
throw new ArgumentException(
"AdvancedResponseWriter is not null. Only one response writer can be used."
Expand All @@ -206,7 +212,7 @@ public RestRequest(Uri resource, Method method = Method.Get)
/// </summary>
public Func<HttpResponseMessage, RestRequest, RestResponse>? AdvancedResponseWriter {
get => _advancedResponseHandler;
init {
set {
if (ResponseWriter != null) throw new ArgumentException("ResponseWriter is not null. Only one response writer can be used.");

_advancedResponseHandler = value;
Expand Down
45 changes: 0 additions & 45 deletions src/RestSharp/Response/ResponseHandling.cs

This file was deleted.

20 changes: 4 additions & 16 deletions src/RestSharp/Response/RestResponse.cs
Expand Up @@ -72,14 +72,13 @@ CancellationToken cancellationToken
return request.AdvancedResponseWriter?.Invoke(httpResponse, request) ?? await GetDefaultResponse().ConfigureAwait(false);

async Task<RestResponse> GetDefaultResponse() {
var readTask = request.ResponseWriter == null ? ReadResponse() : ReadAndConvertResponse();
#if NETSTANDARD || NETFRAMEWORK
using var stream = await readTask.ConfigureAwait(false);
#if NET
await using var stream = await httpResponse.ReadResponseStream(request.ResponseWriter, cancellationToken).ConfigureAwait(false);
#else
await using var stream = await readTask.ConfigureAwait(false);
using var stream = await httpResponse.ReadResponseStream(request.ResponseWriter, cancellationToken).ConfigureAwait(false);
#endif

var bytes = request.ResponseWriter != null || stream == null ? null : await stream.ReadAsBytes(cancellationToken).ConfigureAwait(false);
var bytes = stream == null ? null : await stream.ReadAsBytes(cancellationToken).ConfigureAwait(false);
var content = bytes == null ? null : httpResponse.GetResponseString(bytes, encoding);

return new RestResponse(request) {
Expand All @@ -101,17 +100,6 @@ CancellationToken cancellationToken
Cookies = cookieCollection,
RootElement = request.RootElement
};

Task<Stream?> ReadResponse() => httpResponse.ReadResponse(cancellationToken);

async Task<Stream?> ReadAndConvertResponse() {
#if NETSTANDARD || NETFRAMEWORK
using var original = await ReadResponse().ConfigureAwait(false);
#else
await using var original = await ReadResponse().ConfigureAwait(false);
#endif
return request.ResponseWriter!(original!);
}
}
}

Expand Down
14 changes: 3 additions & 11 deletions src/RestSharp/RestClient.Async.cs
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

using System.Net;
using System.Net.Http.Headers;
using RestSharp.Extensions;

namespace RestSharp;
Expand Down Expand Up @@ -52,16 +53,7 @@ public partial class RestClient {

if (response.ResponseMessage == null) return null;

if (request.ResponseWriter != null) {
#if NETSTANDARD || NETFRAMEWORK
using var stream = await response.ResponseMessage.ReadResponse(cancellationToken).ConfigureAwait(false);
#else
await using var stream = await response.ResponseMessage.ReadResponse(cancellationToken).ConfigureAwait(false);
#endif
return request.ResponseWriter(stream!);
}

return await response.ResponseMessage.ReadResponse(cancellationToken).ConfigureAwait(false);
return await response.ResponseMessage.ReadResponseStream(request.ResponseWriter, cancellationToken).ConfigureAwait(false);
}

static RestResponse GetErrorResponse(RestRequest request, Exception exception, CancellationToken timeoutToken) {
Expand Down Expand Up @@ -95,7 +87,7 @@ public partial class RestClient {
var url = this.BuildUri(request);
var message = new HttpRequestMessage(httpMethod, url) { Content = requestContent.BuildContent() };
message.Headers.Host = Options.BaseHost;
message.Headers.CacheControl = Options.CachePolicy;
message.Headers.CacheControl = request.CachePolicy ?? Options.CachePolicy;

using var timeoutCts = new CancellationTokenSource(request.Timeout > 0 ? request.Timeout : int.MaxValue);
using var cts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken);
Expand Down
6 changes: 3 additions & 3 deletions src/RestSharp/RestClient.Extensions.cs
Expand Up @@ -294,10 +294,10 @@ public static RestResponse Post(this IRestClient client, RestRequest request)
/// <returns>The downloaded file.</returns>
[PublicAPI]
public static async Task<byte[]?> DownloadDataAsync(this IRestClient client, RestRequest request, CancellationToken cancellationToken = default) {
#if NETSTANDARD || NETFRAMEWORK
using var stream = await client.DownloadStreamAsync(request, cancellationToken).ConfigureAwait(false);
#else
#if NET
await using var stream = await client.DownloadStreamAsync(request, cancellationToken).ConfigureAwait(false);
#else
using var stream = await client.DownloadStreamAsync(request, cancellationToken).ConfigureAwait(false);
#endif
return stream == null ? null : await stream.ReadAsBytes(cancellationToken).ConfigureAwait(false);
}
Expand Down
31 changes: 17 additions & 14 deletions test/RestSharp.Tests.Integrated/DownloadFileTests.cs
Expand Up @@ -34,13 +34,15 @@ public sealed class DownloadFileTests : IDisposable {
public async Task AdvancedResponseWriter_without_ResponseWriter_reads_stream() {
var tag = string.Empty;

var rr = new RestRequest("Assets/Koala.jpg") {
AdvancedResponseWriter = (response, request) => {
var buf = new byte[16];
response.Content.ReadAsStream().Read(buf, 0, buf.Length);
tag = Encoding.ASCII.GetString(buf, 6, 4);
return new RestResponse(request);
}
// ReSharper disable once UseObjectOrCollectionInitializer
var rr = new RestRequest("Assets/Koala.jpg");

rr.AdvancedResponseWriter = (response, request) => {
var buf = new byte[16];
// ReSharper disable once MustUseReturnValue
response.Content.ReadAsStream().Read(buf, 0, buf.Length);
tag = Encoding.ASCII.GetString(buf, 6, 4);
return new RestResponse(request);
};

await _client.ExecuteAsync(rr);
Expand All @@ -50,7 +52,7 @@ public sealed class DownloadFileTests : IDisposable {
[Fact]
public async Task Handles_File_Download_Failure() {
var request = new RestRequest("Assets/Koala1.jpg");
var task = () => _client.DownloadDataAsync(request);
var task = () => _client.DownloadDataAsync(request);
await task.Should().ThrowAsync<HttpRequestException>().WithMessage("Request failed with status code NotFound");
}

Expand All @@ -67,13 +69,14 @@ public sealed class DownloadFileTests : IDisposable {
public async Task Writes_Response_To_Stream() {
var tempFile = Path.GetTempFileName();

var request = new RestRequest("Assets/Koala.jpg") {
ResponseWriter = responseStream => {
using var writer = File.OpenWrite(tempFile);
// ReSharper disable once UseObjectOrCollectionInitializer
var request = new RestRequest("Assets/Koala.jpg");

responseStream.CopyTo(writer);
return null;
}
request.ResponseWriter = responseStream => {
using var writer = File.OpenWrite(tempFile);
responseStream.CopyTo(writer);
return null;
};
var response = await _client.DownloadDataAsync(request);

Expand Down
46 changes: 46 additions & 0 deletions test/RestSharp.Tests.Integrated/RedirectTests.cs
@@ -0,0 +1,46 @@
// Copyright (c) .NET Foundation and Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

using System.Net;
using RestSharp.Tests.Integrated.Server;

namespace RestSharp.Tests.Integrated;

[Collection(nameof(TestServerCollection))]
public class RedirectTests {
readonly RestClient _client;

public RedirectTests(TestServerFixture fixture, ITestOutputHelper output) {
var options = new RestClientOptions(fixture.Server.Url) {
FollowRedirects = true
};
_client = new RestClient(options);
}

[Fact]
public async Task Can_Perform_GET_Async_With_Redirect() {
const string val = "Works!";

var request = new RestRequest("redirect");

var response = await _client.ExecuteAsync<Response>(request);
response.StatusCode.Should().Be(HttpStatusCode.OK);
response.Data!.Message.Should().Be(val);
}

class Response {
public string? Message { get; set; }
}
}
1 change: 1 addition & 0 deletions test/RestSharp.Tests.Integrated/Server/TestServer.cs
Expand Up @@ -40,6 +40,7 @@ public sealed class HttpServer {
// Cookies
_app.MapGet("get-cookies", CookieHandlers.HandleCookies);
_app.MapGet("set-cookies", CookieHandlers.HandleSetCookies);
_app.MapGet("redirect", () => Results.Redirect("/success", false, true));

// PUT
_app.MapPut(
Expand Down
16 changes: 16 additions & 0 deletions test/RestSharp.Tests/OptionsTests.cs
@@ -0,0 +1,16 @@
namespace RestSharp.Tests;

public class OptionsTests {
[Fact]
public void Ensure_follow_redirect() {
var value = false;
var options = new RestClientOptions { FollowRedirects = true, ConfigureMessageHandler = Configure};
var _ = new RestClient(options);
value.Should().BeTrue();

HttpMessageHandler Configure(HttpMessageHandler handler) {
value = ((handler as HttpClientHandler)!).AllowAutoRedirect;
return handler;
}
}
}

0 comments on commit e147e1c

Please sign in to comment.