Skip to content

Commit 4a04fd9

Browse files
author
Kamil Zakiev
committed
added open generic support for decorator func
1 parent 16b281d commit 4a04fd9

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

src/Scrutor/ServiceCollectionExtensions.Decoration.cs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ public static IServiceCollection Decorate(this IServiceCollection services, Type
170170
Preconditions.NotNull(serviceType, nameof(serviceType));
171171
Preconditions.NotNull(decorator, nameof(decorator));
172172

173+
if (serviceType.IsOpenGeneric())
174+
{
175+
return services.DecorateOpenGeneric(serviceType, decorator);
176+
}
177+
173178
return services.DecorateDescriptors(serviceType, x => x.Decorate(decorator));
174179
}
175180

@@ -243,10 +248,28 @@ private static bool IsSameGenericType(Type t1, Type t2)
243248
return t1.IsGenericType && t2.IsGenericType && t1.GetGenericTypeDefinition() == t2.GetGenericTypeDefinition();
244249
}
245250

246-
private static bool TryDecorateOpenGeneric(this IServiceCollection services, Type serviceType, Type decoratorType)
251+
private static IServiceCollection DecorateOpenGeneric(this IServiceCollection services, Type serviceType, Func<object, IServiceProvider, object> decorator)
247252
{
248253
bool TryDecorate(Type[] typeArguments)
249254
{
255+
var closedServiceType = serviceType.MakeGenericType(typeArguments);
256+
return services.TryDecorateDescriptors(closedServiceType, x => x.Decorate(decorator));
257+
}
258+
259+
if (services.TryDecorateOpenGeneric(serviceType, openTypeDecorator: TryDecorate))
260+
{
261+
return services;
262+
}
263+
264+
throw new MissingTypeRegistrationException(serviceType);
265+
}
266+
267+
private static bool TryDecorateOpenGeneric(this IServiceCollection services, Type serviceType, Type decoratorType = null, Func<Type[], bool> openTypeDecorator = null)
268+
{
269+
bool TryDecorate(Type[] typeArguments)
270+
{
271+
Preconditions.NotNull(decoratorType, nameof(decoratorType));
272+
250273
var closedServiceType = serviceType.MakeGenericType(typeArguments);
251274
var closedDecoratorType = decoratorType.MakeGenericType(typeArguments);
252275

@@ -263,7 +286,8 @@ bool TryDecorate(Type[] typeArguments)
263286
return false;
264287
}
265288

266-
return arguments.Aggregate(true, (result, args) => result && TryDecorate(args));
289+
var tryDecorate = openTypeDecorator ?? TryDecorate;
290+
return arguments.Aggregate(true, (result, args) => result && tryDecorate(args));
267291
}
268292

269293
private static IServiceCollection DecorateDescriptors(this IServiceCollection services, Type serviceType, Func<ServiceDescriptor, ServiceDescriptor> decorator)

test/Scrutor.Tests/OpenGenericDecorationTests.cs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,28 @@ public void CanDecorateOpenGenericTypeBasedOnInterface()
4141
Assert.IsType<MyQueryHandler>(loggingDecorator.Inner);
4242
}
4343

44+
[Fact]
45+
public void CanDecorateOpenGenericTypeBasedOnInterfaceByDecoratorFunc()
46+
{
47+
var provider = ConfigureProvider(services =>
48+
{
49+
services.AddSingleton<IQueryHandler<MyQuery, MyResult>, MySpecialQueryHandler>();
50+
services.Decorate(typeof(IQueryHandler<,>), (handlerObj, serviceProvider) =>
51+
{
52+
if (handlerObj is ISpecialInterface specialInterface)
53+
{
54+
specialInterface.InitSomeField();
55+
}
56+
57+
return handlerObj;
58+
});
59+
});
60+
61+
var instance = provider.GetRequiredService<IQueryHandler<MyQuery, MyResult>>();
62+
var myQueryHandler = Assert.IsType<MySpecialQueryHandler>(instance);
63+
Assert.True(myQueryHandler.GetSomeField());
64+
}
65+
4466
[Fact]
4567
public void DecoratingNonRegisteredOpenGenericServiceThrows()
4668
{
@@ -79,6 +101,22 @@ public void DecoratingOpenGenericTypeBasedOnGrandparentInterfaceDoesNotDecorateP
79101
}
80102
}
81103

104+
public interface ISpecialInterface
105+
{
106+
void InitSomeField();
107+
}
108+
109+
public class MySpecialQueryHandler : QueryHandler<MyQuery, MyResult>, ISpecialInterface
110+
{
111+
private bool _someField = false;
112+
public void InitSomeField()
113+
{
114+
_someField = true;
115+
}
116+
117+
public bool GetSomeField() => _someField;
118+
}
119+
82120
public class MyQuery { }
83121

84122
public class MyResult { }

0 commit comments

Comments
 (0)