Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP Feature: Use Microsoft KeyedServiceProvider #1075

Merged
Changes from 5 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
4a5123c
Use KeyedServiceProvide instead of ContractDictionary for Splat.Micro…
OleksandrTsvirkun Feb 16, 2024
221eea8
Merge branch 'main' into feature/update-microsoft-di-keyed-service-pr…
ChrisPulman Feb 17, 2024
a0f94b3
Merge branch 'main' into feature/update-microsoft-di-keyed-service-pr…
ChrisPulman Feb 17, 2024
a7b8f27
refactor service contract check
dpvreony Feb 23, 2024
400641e
add another use of refactor, update xmldoc
dpvreony Feb 23, 2024
5c80516
Merge branch 'main' into feature/update-microsoft-di-keyed-service-pr…
ChrisPulman Feb 26, 2024
0ba3592
Merge branch 'main' into feature/update-microsoft-di-keyed-service-pr…
ChrisPulman Feb 29, 2024
5de1f56
Merge branch 'main' into feature/update-microsoft-di-keyed-service-pr…
dpvreony Mar 5, 2024
5365add
Merge branch 'main' into feature/update-microsoft-di-keyed-service-pr…
dpvreony Mar 18, 2024
437c9a5
add net8 as test framework
dpvreony Mar 18, 2024
a9b8a55
Merge branch 'main' into feature/update-microsoft-di-keyed-service-pr…
ChrisPulman Apr 13, 2024
9eaa2e3
Merge branch 'main' into feature/update-microsoft-di-keyed-service-pr…
dpvreony Apr 26, 2024
77555d6
Merge branch 'main' into feature/update-microsoft-di-keyed-service-pr…
ChrisPulman May 1, 2024
83a2927
Merge branch 'main' into feature/update-microsoft-di-keyed-service-pr…
ChrisPulman May 1, 2024
e327002
restore contract dictionary
dpvreony May 8, 2024
d08a6a4
Merge branch 'main' into feature/update-microsoft-di-keyed-service-pr…
dpvreony May 8, 2024
7f8d1b0
remove contract dictionary, not actually needed
dpvreony May 9, 2024
17fe3a1
Merge branch 'main' into feature/update-microsoft-di-keyed-service-pr…
dpvreony May 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Concurrent;
using System.Data;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Microsoft.Extensions.DependencyInjection;

namespace Splat.Microsoft.Extensions.DependencyInjection;
Expand All @@ -17,7 +18,6 @@ namespace Splat.Microsoft.Extensions.DependencyInjection;
public class MicrosoftDependencyResolver : IDependencyResolver
{
private const string ImmutableExceptionMessage = "This container has already been built and cannot be modified.";
private static readonly Type _dictionaryType = typeof(ContractDictionary<>);
private readonly object _syncLock = new();
private IServiceCollection? _serviceCollection;
private bool _isImmutable;
Expand Down Expand Up @@ -91,29 +91,27 @@ public virtual IEnumerable<object> GetServices(Type? serviceType, string? contra
var isNull = serviceType is null;
serviceType ??= typeof(NullServiceType);

IEnumerable<object> services;
IEnumerable<object> services = Enumerable.Empty<object>();

if (contract is null || string.IsNullOrWhiteSpace(contract))
{
// this is to deal with CS8613 that GetServices returns IEnumerable<object?>?
services = ServiceProvider.GetServices(serviceType)
.Where(a => a is not null)
.Select(a => a!);

if (isNull)
{
services = services
.Cast<NullServiceType>()
.Select(nst => nst.Factory()!);
}
}
else
else if (ServiceProvider is IKeyedServiceProvider serviceProvider)
dpvreony marked this conversation as resolved.
Show resolved Hide resolved
{
services = serviceProvider.GetKeyedServices(serviceType, contract)
.Where(a => a is not null)
.Select(a => a!);
}

if (isNull)
{
var dic = GetContractDictionary(serviceType, false);
services = dic?
.GetFactories(contract)
.Select(f => f()!)
?? Array.Empty<object>();
services = services
.Cast<NullServiceType>()
.Select(nst => nst.Factory()!);
}

return services;
Expand Down Expand Up @@ -142,9 +140,10 @@ public virtual void Register(Func<object?> factory, Type? serviceType, string? c
}
else
{
var dic = GetContractDictionary(serviceType, true);

dic?.AddFactory(contract, factory);
_serviceCollection?.AddKeyedTransient(serviceType, contract, (_, __) =>
isNull
? new NullServiceType(factory)
: factory()!);
}

// required so that it gets rebuilt if not injected externally.
Expand All @@ -166,22 +165,18 @@ public virtual void UnregisterCurrent(Type? serviceType, string? contract = null
{
if (contract is null || string.IsNullOrWhiteSpace(contract))
{
var sd = _serviceCollection?.LastOrDefault(s => s.ServiceType == serviceType);
var sd = _serviceCollection?.LastOrDefault(s => !s.IsKeyedService && s.ServiceType == serviceType);
if (sd is not null)
{
_serviceCollection?.Remove(sd);
}
}
else
{
var dic = GetContractDictionary(serviceType, false);
if (dic is not null)
var sd = _serviceCollection?.LastOrDefault(sd => MatchesKeyedContract(serviceType, contract, sd));
if (sd is not null)
{
dic.RemoveLastFactory(contract);
if (dic.IsEmpty)
{
RemoveContractService(serviceType);
}
_serviceCollection?.Remove(sd);
}
}

Expand All @@ -196,7 +191,7 @@ public virtual void UnregisterCurrent(Type? serviceType, string? contract = null
/// ignoring the <paramref name="serviceType"/> argument.
/// </summary>
/// <param name="serviceType">The service type to unregister.</param>
/// <param name="contract">This parameter is ignored. Service will be removed from all contracts.</param>
/// <param name="contract">A optional value which will remove only an object registered with the same contract.</param>
dpvreony marked this conversation as resolved.
Show resolved Hide resolved
public virtual void UnregisterAll(Type? serviceType, string? contract = null)
{
if (_isImmutable)
Expand All @@ -208,34 +203,28 @@ public virtual void UnregisterAll(Type? serviceType, string? contract = null)

lock (_syncLock)
{
switch (contract)
if (_serviceCollection is null)
dpvreony marked this conversation as resolved.
Show resolved Hide resolved
{
// required so that it gets rebuilt if not injected externally.
_serviceProvider = null;
return;
}

IEnumerable<ServiceDescriptor> sds = Enumerable.Empty<ServiceDescriptor>();

if (contract is null || string.IsNullOrWhiteSpace(contract))
{
case null when _serviceCollection is not null:
{
var sds = _serviceCollection
.Where(s => s.ServiceType == serviceType)
.ToList();

foreach (var sd in sds)
{
_serviceCollection.Remove(sd);
}

break;
}

case null:
throw new ArgumentException("There must be a valid contract if there is no service collection.", nameof(contract));
default:
{
var dic = GetContractDictionary(serviceType, false);
if (dic?.TryRemoveContract(contract) == true && dic.IsEmpty)
{
RemoveContractService(serviceType);
}

break;
}
sds = _serviceCollection.Where(s => !s.IsKeyedService && s.ServiceType == serviceType);
}
else
{
sds = _serviceCollection
.Where(sd => MatchesKeyedContract(serviceType, contract, sd));
}

foreach (var sd in sds.ToList())
{
_serviceCollection.Remove(sd);
}

// required so that it gets rebuilt if not injected externally.
Expand All @@ -255,16 +244,10 @@ public virtual bool HasRegistration(Type? serviceType, string? contract = null)
{
if (contract is null || string.IsNullOrWhiteSpace(contract))
{
return _serviceCollection?.Any(sd => sd.ServiceType == serviceType) == true;
return _serviceCollection?.Any(sd => !sd.IsKeyedService && sd.ServiceType == serviceType) == true;
}

var dictionary = (ContractDictionary?)_serviceCollection?.FirstOrDefault(sd => sd.ServiceType == GetDictionaryType(serviceType))?.ImplementationInstance;

return dictionary switch
{
null => false,
_ => dictionary.GetFactories(contract).Select(f => f()).Any()
};
return _serviceCollection?.Any(sd => MatchesKeyedContract(serviceType, contract, sd)) == true;
}

if (contract is null)
Expand All @@ -273,8 +256,12 @@ public virtual bool HasRegistration(Type? serviceType, string? contract = null)
return service is not null;
}

var dic = GetContractDictionary(serviceType, false);
return dic?.IsEmpty == false;
if (_serviceProvider is IKeyedServiceProvider keyedServiceProvider)
{
return keyedServiceProvider.GetKeyedService(serviceType, contract) is not null;
}

return false;
}

/// <inheritdoc />
Expand All @@ -292,103 +279,9 @@ protected virtual void Dispose(bool disposing)
{
}

private static Type GetDictionaryType(Type serviceType) => _dictionaryType.MakeGenericType(serviceType);
dpvreony marked this conversation as resolved.
Show resolved Hide resolved

private void RemoveContractService(Type serviceType)
{
var dicType = GetDictionaryType(serviceType);
var sd = _serviceCollection?.SingleOrDefault(s => s.ServiceType == serviceType);

if (sd is not null)
{
_serviceCollection?.Remove(sd);
}
}

[SuppressMessage("Naming Rules", "SA1300", Justification = "Intentional")]
private ContractDictionary? GetContractDictionary(Type serviceType, bool createIfNotExists)
{
var dicType = GetDictionaryType(serviceType);

if (ServiceProvider is null)
{
throw new InvalidOperationException("The ServiceProvider is null.");
}

if (_isImmutable)
{
return (ContractDictionary?)ServiceProvider.GetService(dicType);
}

var dic = getDictionary();
if (createIfNotExists && dic is null)
{
lock (_syncLock)
{
if (createIfNotExists)
{
dic = (ContractDictionary?)Activator.CreateInstance(dicType);

if (dic is not null)
{
_serviceCollection?.AddSingleton(dicType, dic);
}
}
}
}

return dic;

ContractDictionary? getDictionary() => _serviceCollection?
.Where(sd => sd.ServiceType == dicType)
.Select(sd => sd.ImplementationInstance)
.Cast<ContractDictionary>()
.SingleOrDefault();
}

private class ContractDictionary
{
private readonly ConcurrentDictionary<string, List<Func<object?>>> _dictionary = new();

public bool IsEmpty => _dictionary.IsEmpty;

public bool TryRemoveContract(string contract) =>
_dictionary.TryRemove(contract, out var _);

public Func<object?>? GetFactory(string contract) =>
GetFactories(contract)
.LastOrDefault();

public IEnumerable<Func<object?>> GetFactories(string contract) =>
_dictionary.TryGetValue(contract, out var collection)
? collection ?? Enumerable.Empty<Func<object?>>()
: Array.Empty<Func<object?>>();

public void AddFactory(string contract, Func<object?> factory) =>
_dictionary.AddOrUpdate(contract, _ => new() { factory }, (_, list) =>
{
(list ??= []).Add(factory);
return list;
});

public void RemoveLastFactory(string contract) =>
_dictionary.AddOrUpdate(contract, [], (_, list) =>
{
var lastIndex = list.Count - 1;
if (lastIndex > 0)
{
list.RemoveAt(lastIndex);
}

// TODO if list empty remove contract entirely
// need to find how to atomically update or remove
// https://github.com/dotnet/corefx/issues/24246
return list;
});
}

[SuppressMessage("Design", "CA1812: Unused class.", Justification = "Used in reflection.")]
private sealed class ContractDictionary<T> : ContractDictionary
{
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool MatchesKeyedContract(Type? serviceType, string contract, ServiceDescriptor sd) =>
sd.ServiceType == serviceType
&& sd is { IsKeyedService: true, ServiceKey: string serviceKey }
&& serviceKey == contract;
}