Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make auth infer opt in #2460

Merged
merged 2 commits into from Jul 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -4,6 +4,7 @@
using Microsoft.OpenApi.Models;
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using Swashbuckle.AspNetCore.SwaggerGen;
using Microsoft.AspNetCore.Authentication;

namespace Microsoft.Extensions.DependencyInjection
{
Expand Down Expand Up @@ -308,6 +309,22 @@ public static void SupportNonNullableReferenceTypes(this SwaggerGenOptions swagg
swaggerGenOptions.SchemaGeneratorOptions.SupportNonNullableReferenceTypes = true;
}

/// <summary>
/// Automatically infer security schemes from authentication/authorization state in ASP.NET Core.
/// </summary>
/// <param name="swaggerGenOptions"></param>
/// <param name="securitySchemesSelector">
/// Provide alternative implementation that maps ASP.NET Core Authentication schemes to Open API security schemes
/// </param>
/// <remarks>Currently only supports JWT Bearer authentication</remarks>
public static void InferSecuritySchemes(
this SwaggerGenOptions swaggerGenOptions,
Func<IEnumerable<AuthenticationScheme>, IDictionary<string, OpenApiSecurityScheme>> securitySchemesSelector = null)
{
swaggerGenOptions.SwaggerGeneratorOptions.InferSecuritySchemes = true;
swaggerGenOptions.SwaggerGeneratorOptions.SecuritySchemesSelector = securitySchemesSelector;
}

/// <summary>
/// Extend the Swagger Generator with "filters" that can modify Schemas after they're initially generated
/// </summary>
Expand Down
Expand Up @@ -34,26 +34,48 @@ public class SwaggerGenerator : ISwaggerProvider, IAsyncSwaggerProvider
SwaggerGeneratorOptions options,
IApiDescriptionGroupCollectionProvider apiDescriptionsProvider,
ISchemaGenerator schemaGenerator,
IAuthenticationSchemeProvider authentiationSchemeProvider) : this(options, apiDescriptionsProvider, schemaGenerator)
IAuthenticationSchemeProvider authenticationSchemeProvider) : this(options, apiDescriptionsProvider, schemaGenerator)
{
_authenticationSchemeProvider = authentiationSchemeProvider;
_authenticationSchemeProvider = authenticationSchemeProvider;
}

public async Task<OpenApiDocument> GetSwaggerAsync(string documentName, string host = null, string basePath = null)
{
var (applicableApiDescriptions, swaggerDoc, schemaRepository) = GetSwaggerDocument(documentName, host, basePath);
var (applicableApiDescriptions, swaggerDoc, schemaRepository) = GetSwaggerDocumentWithoutFilters(documentName, host, basePath);

swaggerDoc.Components.SecuritySchemes = await GetSecuritySchemes();

// NOTE: Filter processing moved here so they may effect generated security schemes
var filterContext = new DocumentFilterContext(applicableApiDescriptions, _schemaGenerator, schemaRepository);
foreach (var filter in _options.DocumentFilters)
{
filter.Apply(swaggerDoc, filterContext);
}

swaggerDoc.Components.Schemas = new SortedDictionary<string, OpenApiSchema>(swaggerDoc.Components.Schemas, _options.SchemaComparer);

return swaggerDoc;
}

public OpenApiDocument GetSwagger(string documentName, string host = null, string basePath = null)
{
var (applicableApiDescriptions, swaggerDoc, schemaRepository) = GetSwaggerDocument(documentName, host, basePath);
var (applicableApiDescriptions, swaggerDoc, schemaRepository) = GetSwaggerDocumentWithoutFilters(documentName, host, basePath);

swaggerDoc.Components.SecuritySchemes = GetSecuritySchemes().Result;

// NOTE: Filter processing moved here so they may effect generated security schemes
var filterContext = new DocumentFilterContext(applicableApiDescriptions, _schemaGenerator, schemaRepository);
foreach (var filter in _options.DocumentFilters)
{
filter.Apply(swaggerDoc, filterContext);
}

swaggerDoc.Components.Schemas = new SortedDictionary<string, OpenApiSchema>(swaggerDoc.Components.Schemas, _options.SchemaComparer);

return swaggerDoc;
}

private (IEnumerable<ApiDescription>, OpenApiDocument, SchemaRepository) GetSwaggerDocument(string documentName, string host = null, string basePath = null)
private (IEnumerable<ApiDescription>, OpenApiDocument, SchemaRepository) GetSwaggerDocumentWithoutFilters(string documentName, string host = null, string basePath = null)
{
if (!_options.SwaggerDocs.TryGetValue(documentName, out OpenApiInfo info))
throw new UnknownSwaggerDocument(documentName, _options.SwaggerDocs.Select(d => d.Key));
Expand All @@ -77,38 +99,37 @@ public OpenApiDocument GetSwagger(string documentName, string host = null, strin
SecurityRequirements = new List<OpenApiSecurityRequirement>(_options.SecurityRequirements)
};

var filterContext = new DocumentFilterContext(applicableApiDescriptions, _schemaGenerator, schemaRepository);
foreach (var filter in _options.DocumentFilters)
{
filter.Apply(swaggerDoc, filterContext);
}

swaggerDoc.Components.Schemas = new SortedDictionary<string, OpenApiSchema>(swaggerDoc.Components.Schemas, _options.SchemaComparer);

return (applicableApiDescriptions, swaggerDoc, schemaRepository);
}

private async Task<Dictionary<string, OpenApiSecurityScheme>> GetSecuritySchemes()
private async Task<IDictionary<string, OpenApiSecurityScheme>> GetSecuritySchemes()
{
var securitySchemes = new Dictionary<string, OpenApiSecurityScheme>(_options.SecuritySchemes);
var authenticationSchemes = Enumerable.Empty<AuthenticationScheme>();
if (_authenticationSchemeProvider is not null)
if (!_options.InferSecuritySchemes)
{
authenticationSchemes = await _authenticationSchemeProvider.GetAllSchemesAsync();
return new Dictionary<string, OpenApiSecurityScheme>(_options.SecuritySchemes);
}
var securitySchemesFromSelector = _options.SecuritySchemesSelector(authenticationSchemes);
// Favor security schemes set via options over those generated
// from the selector. For the default selector, this effectively
// ends up favoring `Bearer` authentication types explicitly set
// by the user over those derived by the selector.
foreach (var securityScheme in securitySchemesFromSelector)

var authenticationSchemes = (_authenticationSchemeProvider is not null)
? await _authenticationSchemeProvider.GetAllSchemesAsync()
: Enumerable.Empty<AuthenticationScheme>();

if (_options.SecuritySchemesSelector != null)
{
if (!securitySchemes.ContainsKey(securityScheme.Key))
{
securitySchemes.Add(securityScheme.Key, securityScheme.Value);
}
return _options.SecuritySchemesSelector(authenticationSchemes);
}
return securitySchemes;

// Default implementation, currently only supports JWT Bearer scheme
return authenticationSchemes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize that this commit removed the code that would've prevented the Bearer auth scheme from being generated if there was already one provided with options.SecuritySchemes. Perhaps we should add that back here?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combining the use of options.SecuritySchemes with the infer behaviour (either default or via SecuritySchemesSelector), with trumping rules etc, feels like a recipe for disaster. So, for this PR, I've simplified the approach so that you either set the security schemes explicitly (via options.SecuritySchemes) OR you "opt-in" to the infer behavior (via InferSecuritySchemes()). So, you can use one approach or the other not a combination of both.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And of course, in both cases schema filters may be applied for any final tweaks

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems fair! I think the schema filters ordering is the most important bit here.

Will this be shipped in a patch or minor release?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will ship as a minor version release

.Where(authScheme => authScheme.Name == "Bearer")
.ToDictionary(
(authScheme) => authScheme.Name,
(authScheme) => new OpenApiSecurityScheme
{
Type = SecuritySchemeType.Http,
Scheme = "bearer", // "bearer" refers to the header name here
In = ParameterLocation.Header,
BearerFormat = "Json Web Token"
});
}

private IList<OpenApiServer> GenerateServers(string host, string basePath)
Expand Down
Expand Up @@ -20,7 +20,7 @@ public SwaggerGeneratorOptions()
OperationIdSelector = DefaultOperationIdSelector;
TagsSelector = DefaultTagsSelector;
SortKeySelector = DefaultSortKeySelector;
SecuritySchemesSelector = DefaultSecuritySchemeSelector;
SecuritySchemesSelector = null;
SchemaComparer = StringComparer.Ordinal;
Servers = new List<OpenApiServer>();
SecuritySchemes = new Dictionary<string, OpenApiSecurityScheme>();
Expand All @@ -45,6 +45,10 @@ public SwaggerGeneratorOptions()

public Func<ApiDescription, string> SortKeySelector { get; set; }

public bool InferSecuritySchemes { get; set; }

public Func<IEnumerable<AuthenticationScheme>, IDictionary<string, OpenApiSecurityScheme>> SecuritySchemesSelector { get; set;}

public bool DescribeAllParametersInCamelCase { get; set; }

public List<OpenApiServer> Servers { get; set; }
Expand All @@ -63,8 +67,6 @@ public SwaggerGeneratorOptions()

public IList<IDocumentFilter> DocumentFilters { get; set; }

public Func<IEnumerable<AuthenticationScheme>, Dictionary<string, OpenApiSecurityScheme>> SecuritySchemesSelector { get; set;}

private bool DefaultDocInclusionPredicate(string documentName, ApiDescription apiDescription)
{
return apiDescription.GroupName == null || apiDescription.GroupName == documentName;
Expand Down Expand Up @@ -106,26 +108,5 @@ private string DefaultSortKeySelector(ApiDescription apiDescription)
{
return TagsSelector(apiDescription).First();
}

private Dictionary<string, OpenApiSecurityScheme> DefaultSecuritySchemeSelector(IEnumerable<AuthenticationScheme> schemes)
{
Dictionary<string, OpenApiSecurityScheme> securitySchemes = new();
#if (NET6_0_OR_GREATER)
foreach (var scheme in schemes)
{
if (scheme.Name == "Bearer")
{
securitySchemes[scheme.Name] = new OpenApiSecurityScheme
{
Type = SecuritySchemeType.Http,
Scheme = "bearer", // "bearer" refers to the header name here
In = ParameterLocation.Header,
BearerFormat = "Json Web Token"
};
}
}
#endif
return securitySchemes;
}
}
}
@@ -0,0 +1,47 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;

namespace Swashbuckle.AspNetCore.SwaggerGen.Test
{
public class FakeAuthenticationSchemeProvider : IAuthenticationSchemeProvider
{
private readonly IEnumerable<AuthenticationScheme> _authenticationSchemes;

public FakeAuthenticationSchemeProvider(IEnumerable<AuthenticationScheme> authenticationSchemes)
{
_authenticationSchemes = authenticationSchemes;
}

public void AddScheme(AuthenticationScheme scheme)
=> throw new NotImplementedException();
public Task<IEnumerable<AuthenticationScheme>> GetAllSchemesAsync()
=> Task.FromResult(_authenticationSchemes);

public Task<AuthenticationScheme> GetDefaultAuthenticateSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultChallengeSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultForbidSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultSignInSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultSignOutSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<IEnumerable<AuthenticationScheme>> GetRequestHandlerSchemesAsync()
=> throw new NotImplementedException();

public Task<AuthenticationScheme> GetSchemeAsync(string name)
=> Task.FromResult(_authenticationSchemes.First());

public void RemoveScheme(string name)
=> throw new NotImplementedException();
}
}