Skip to content

Commit

Permalink
Merge pull request #2460 from domaindrivendev/make-auth-infer-opt-in
Browse files Browse the repository at this point in the history
Make auth infer opt in
  • Loading branch information
domaindrivendev committed Jul 19, 2022
2 parents 71ed7d3 + ddd0627 commit 426c5bb
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 130 deletions.
Expand Up @@ -4,6 +4,7 @@
using Microsoft.OpenApi.Models;
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using Swashbuckle.AspNetCore.SwaggerGen;
using Microsoft.AspNetCore.Authentication;

namespace Microsoft.Extensions.DependencyInjection
{
Expand Down Expand Up @@ -308,6 +309,22 @@ public static void SupportNonNullableReferenceTypes(this SwaggerGenOptions swagg
swaggerGenOptions.SchemaGeneratorOptions.SupportNonNullableReferenceTypes = true;
}

/// <summary>
/// Automatically infer security schemes from authentication/authorization state in ASP.NET Core.
/// </summary>
/// <param name="swaggerGenOptions"></param>
/// <param name="securitySchemesSelector">
/// Provide alternative implementation that maps ASP.NET Core Authentication schemes to Open API security schemes
/// </param>
/// <remarks>Currently only supports JWT Bearer authentication</remarks>
public static void InferSecuritySchemes(
this SwaggerGenOptions swaggerGenOptions,
Func<IEnumerable<AuthenticationScheme>, IDictionary<string, OpenApiSecurityScheme>> securitySchemesSelector = null)
{
swaggerGenOptions.SwaggerGeneratorOptions.InferSecuritySchemes = true;
swaggerGenOptions.SwaggerGeneratorOptions.SecuritySchemesSelector = securitySchemesSelector;
}

/// <summary>
/// Extend the Swagger Generator with "filters" that can modify Schemas after they're initially generated
/// </summary>
Expand Down
Expand Up @@ -34,26 +34,48 @@ public class SwaggerGenerator : ISwaggerProvider, IAsyncSwaggerProvider
SwaggerGeneratorOptions options,
IApiDescriptionGroupCollectionProvider apiDescriptionsProvider,
ISchemaGenerator schemaGenerator,
IAuthenticationSchemeProvider authentiationSchemeProvider) : this(options, apiDescriptionsProvider, schemaGenerator)
IAuthenticationSchemeProvider authenticationSchemeProvider) : this(options, apiDescriptionsProvider, schemaGenerator)
{
_authenticationSchemeProvider = authentiationSchemeProvider;
_authenticationSchemeProvider = authenticationSchemeProvider;
}

public async Task<OpenApiDocument> GetSwaggerAsync(string documentName, string host = null, string basePath = null)
{
var (applicableApiDescriptions, swaggerDoc, schemaRepository) = GetSwaggerDocument(documentName, host, basePath);
var (applicableApiDescriptions, swaggerDoc, schemaRepository) = GetSwaggerDocumentWithoutFilters(documentName, host, basePath);

swaggerDoc.Components.SecuritySchemes = await GetSecuritySchemes();

// NOTE: Filter processing moved here so they may effect generated security schemes
var filterContext = new DocumentFilterContext(applicableApiDescriptions, _schemaGenerator, schemaRepository);
foreach (var filter in _options.DocumentFilters)
{
filter.Apply(swaggerDoc, filterContext);
}

swaggerDoc.Components.Schemas = new SortedDictionary<string, OpenApiSchema>(swaggerDoc.Components.Schemas, _options.SchemaComparer);

return swaggerDoc;
}

public OpenApiDocument GetSwagger(string documentName, string host = null, string basePath = null)
{
var (applicableApiDescriptions, swaggerDoc, schemaRepository) = GetSwaggerDocument(documentName, host, basePath);
var (applicableApiDescriptions, swaggerDoc, schemaRepository) = GetSwaggerDocumentWithoutFilters(documentName, host, basePath);

swaggerDoc.Components.SecuritySchemes = GetSecuritySchemes().Result;

// NOTE: Filter processing moved here so they may effect generated security schemes
var filterContext = new DocumentFilterContext(applicableApiDescriptions, _schemaGenerator, schemaRepository);
foreach (var filter in _options.DocumentFilters)
{
filter.Apply(swaggerDoc, filterContext);
}

swaggerDoc.Components.Schemas = new SortedDictionary<string, OpenApiSchema>(swaggerDoc.Components.Schemas, _options.SchemaComparer);

return swaggerDoc;
}

private (IEnumerable<ApiDescription>, OpenApiDocument, SchemaRepository) GetSwaggerDocument(string documentName, string host = null, string basePath = null)
private (IEnumerable<ApiDescription>, OpenApiDocument, SchemaRepository) GetSwaggerDocumentWithoutFilters(string documentName, string host = null, string basePath = null)
{
if (!_options.SwaggerDocs.TryGetValue(documentName, out OpenApiInfo info))
throw new UnknownSwaggerDocument(documentName, _options.SwaggerDocs.Select(d => d.Key));
Expand All @@ -77,38 +99,37 @@ public OpenApiDocument GetSwagger(string documentName, string host = null, strin
SecurityRequirements = new List<OpenApiSecurityRequirement>(_options.SecurityRequirements)
};

var filterContext = new DocumentFilterContext(applicableApiDescriptions, _schemaGenerator, schemaRepository);
foreach (var filter in _options.DocumentFilters)
{
filter.Apply(swaggerDoc, filterContext);
}

swaggerDoc.Components.Schemas = new SortedDictionary<string, OpenApiSchema>(swaggerDoc.Components.Schemas, _options.SchemaComparer);

return (applicableApiDescriptions, swaggerDoc, schemaRepository);
}

private async Task<Dictionary<string, OpenApiSecurityScheme>> GetSecuritySchemes()
private async Task<IDictionary<string, OpenApiSecurityScheme>> GetSecuritySchemes()
{
var securitySchemes = new Dictionary<string, OpenApiSecurityScheme>(_options.SecuritySchemes);
var authenticationSchemes = Enumerable.Empty<AuthenticationScheme>();
if (_authenticationSchemeProvider is not null)
if (!_options.InferSecuritySchemes)
{
authenticationSchemes = await _authenticationSchemeProvider.GetAllSchemesAsync();
return new Dictionary<string, OpenApiSecurityScheme>(_options.SecuritySchemes);
}
var securitySchemesFromSelector = _options.SecuritySchemesSelector(authenticationSchemes);
// Favor security schemes set via options over those generated
// from the selector. For the default selector, this effectively
// ends up favoring `Bearer` authentication types explicitly set
// by the user over those derived by the selector.
foreach (var securityScheme in securitySchemesFromSelector)

var authenticationSchemes = (_authenticationSchemeProvider is not null)
? await _authenticationSchemeProvider.GetAllSchemesAsync()
: Enumerable.Empty<AuthenticationScheme>();

if (_options.SecuritySchemesSelector != null)
{
if (!securitySchemes.ContainsKey(securityScheme.Key))
{
securitySchemes.Add(securityScheme.Key, securityScheme.Value);
}
return _options.SecuritySchemesSelector(authenticationSchemes);
}
return securitySchemes;

// Default implementation, currently only supports JWT Bearer scheme
return authenticationSchemes
.Where(authScheme => authScheme.Name == "Bearer")
.ToDictionary(
(authScheme) => authScheme.Name,
(authScheme) => new OpenApiSecurityScheme
{
Type = SecuritySchemeType.Http,
Scheme = "bearer", // "bearer" refers to the header name here
In = ParameterLocation.Header,
BearerFormat = "Json Web Token"
});
}

private IList<OpenApiServer> GenerateServers(string host, string basePath)
Expand Down
Expand Up @@ -20,7 +20,7 @@ public SwaggerGeneratorOptions()
OperationIdSelector = DefaultOperationIdSelector;
TagsSelector = DefaultTagsSelector;
SortKeySelector = DefaultSortKeySelector;
SecuritySchemesSelector = DefaultSecuritySchemeSelector;
SecuritySchemesSelector = null;
SchemaComparer = StringComparer.Ordinal;
Servers = new List<OpenApiServer>();
SecuritySchemes = new Dictionary<string, OpenApiSecurityScheme>();
Expand All @@ -45,6 +45,10 @@ public SwaggerGeneratorOptions()

public Func<ApiDescription, string> SortKeySelector { get; set; }

public bool InferSecuritySchemes { get; set; }

public Func<IEnumerable<AuthenticationScheme>, IDictionary<string, OpenApiSecurityScheme>> SecuritySchemesSelector { get; set;}

public bool DescribeAllParametersInCamelCase { get; set; }

public List<OpenApiServer> Servers { get; set; }
Expand All @@ -63,8 +67,6 @@ public SwaggerGeneratorOptions()

public IList<IDocumentFilter> DocumentFilters { get; set; }

public Func<IEnumerable<AuthenticationScheme>, Dictionary<string, OpenApiSecurityScheme>> SecuritySchemesSelector { get; set;}

private bool DefaultDocInclusionPredicate(string documentName, ApiDescription apiDescription)
{
return apiDescription.GroupName == null || apiDescription.GroupName == documentName;
Expand Down Expand Up @@ -106,26 +108,5 @@ private string DefaultSortKeySelector(ApiDescription apiDescription)
{
return TagsSelector(apiDescription).First();
}

private Dictionary<string, OpenApiSecurityScheme> DefaultSecuritySchemeSelector(IEnumerable<AuthenticationScheme> schemes)
{
Dictionary<string, OpenApiSecurityScheme> securitySchemes = new();
#if (NET6_0_OR_GREATER)
foreach (var scheme in schemes)
{
if (scheme.Name == "Bearer")
{
securitySchemes[scheme.Name] = new OpenApiSecurityScheme
{
Type = SecuritySchemeType.Http,
Scheme = "bearer", // "bearer" refers to the header name here
In = ParameterLocation.Header,
BearerFormat = "Json Web Token"
};
}
}
#endif
return securitySchemes;
}
}
}
@@ -0,0 +1,47 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;

namespace Swashbuckle.AspNetCore.SwaggerGen.Test
{
public class FakeAuthenticationSchemeProvider : IAuthenticationSchemeProvider
{
private readonly IEnumerable<AuthenticationScheme> _authenticationSchemes;

public FakeAuthenticationSchemeProvider(IEnumerable<AuthenticationScheme> authenticationSchemes)
{
_authenticationSchemes = authenticationSchemes;
}

public void AddScheme(AuthenticationScheme scheme)
=> throw new NotImplementedException();
public Task<IEnumerable<AuthenticationScheme>> GetAllSchemesAsync()
=> Task.FromResult(_authenticationSchemes);

public Task<AuthenticationScheme> GetDefaultAuthenticateSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultChallengeSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultForbidSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultSignInSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultSignOutSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<IEnumerable<AuthenticationScheme>> GetRequestHandlerSchemesAsync()
=> throw new NotImplementedException();

public Task<AuthenticationScheme> GetSchemeAsync(string name)
=> Task.FromResult(_authenticationSchemes.First());

public void RemoveScheme(string name)
=> throw new NotImplementedException();
}
}

0 comments on commit 426c5bb

Please sign in to comment.