跳到主要内容

改造原生DI实现属性注入

我见过的依赖注入主要有两种

  • .net使用的是 构造函数注入, 将依赖显式地写在构造函数中
  • spring使用的是 属性注入, 将依赖使用 @Autowired 标记起来

由于.net原生的DI不支持 属性注入, 所以打算尝试自己实现一下

思考1

正常来说, 我希望的 属性注入C#中使用起来大致如下

[ApiController]
[Route("[controller]")]
public class DemoController : ControllerBase
{
[AutoInject]
public DemoService Demo { get; set; }
}

将需要自动进行注入的依赖做个标记, 然后在创建 DemoController 对象时自动对 Demo 属性进行赋值

由于没有使用 构造函数注入, 所以我们必须将 Demo 设为 public, 不然无法访问到对应的属性

原生DI使用 ServiceProvider 创建对象, 所以第一步我们 至少需要自定义一个 MyServiceProvider

自定义 MyServiceProvider

为了降低实现难度, 我们首先参考装饰器模式扩展出一个 MyServiceProvider, 复用已有的功能先从容器中获取对象 obj

public class MyServiceProvider : IServiceProvider
{
private IServiceProvider _provider;
public MyServiceProvider(IServiceProvider provider) => _provider = provider;
public object GetService(Type serviceType)
{
var obj = _provider.GetService(serviceType);
return obj;
}
}

其次我们需要获取到对象中被 [AutoInject] 标记的需要我们进行自动注入的属性

public class MyServiceProvider : IServiceProvider
{
private IServiceProvider _provider;
public MyServiceProvider(IServiceProvider provider)
{
_provider = provider;
}
public object GetService(Type serviceType)
{
var obj = _provider.GetService(serviceType);
if (obj == null)
return null;
var type = obj.GetType();
var props = type.GetProperties()
.Where(item => item.CustomAttributes.Any(attr => attr.AttributeType == typeof(AutoInjectAttribute)));
foreach (var prop in props)
prop.SetValue(obj, GetService(prop.PropertyType));
return obj;
}
}
提示

最好通过使用 GetType() 获取到的 type 获取属性

传入的 serviceType 参数很有可能只是接口的类型, 而不是具体的实现类型

通过判断属性的 CustomAttributes 有无我们自定义的 AutoInjectAttribute 筛选被标记的属性

然后通过筛选出来的属性对 obj 进行赋值

思考2

.net原生DI使用一个 ServiceProviderFactory 创建 ServiceProvider 对象

所以我们还需要自定义一个 MyServiceProviderFactory 用于创建我们自己的 MyServiceProvider

自定义 MyServiceProviderFactory

public class MyServiceProviderFactory : IServiceProviderFactory<IServiceCollection>
{
public IServiceCollection CreateBuilder(IServiceCollection services) => services ?? new ServiceCollection();
public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder)
{
return new MyServiceProvider(containerBuilder.BuildServiceProvider());
}
}

MyServiceProviderFactory 的实现非常简单, 只要返回我们之前自定义的 MyServiceProvider 对象即可

使用 MyServiceProviderFactory

使用以下代码即可将DI使用的工厂替换为我们自定义的实现

var builder = WebApplication.CreateBuilder(args);
builder.Host.UseServiceProviderFactory(new MyServiceProviderFactory());

发现问题

但即使替换了工厂实现, 我们可能依然无法正常进行 属性注入, 尤其是在使用 Controller 的时候

这主要是因为在asp.net coreControllerControllerActivator 创建

所以我们还需要实现自定义的 MyControllerActivator

实现 MyControllerActivator

public class MyControllerActivator : IControllerActivator
{
public virtual object Create(ControllerContext context)
{
var provider = context.HttpContext.RequestServices;
provider = provider is MyServiceProvider ? provider : new MyServiceProvider(provider);
return provider.GetService(context.ActionDescriptor.ControllerTypeInfo.AsType());
}
public virtual void Release(ControllerContext context, object controller)
{
if (controller is IDisposable disposeable) disposeable.Dispose();
}
}

Create 创建 Controller 时将原来的 provider 替换为自定义的 MyServiceProvider

替换 IControllerActivator

有了自己的 MyControllerActivator 之后, 就需要将原来注册的 IControllerActivator 替换掉, 并且需要将 Controller 注册为 Service

builder.Services.AddControllers().AddControllersAsServices();
builder.Services.Replace(ServiceDescriptor.Transient<IControllerActivator, MyControllerActivator>());

替换完成后即可在创建 Controller 时使用自定义的 MyServiceProvider

完整代码

using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.Controllers;
using Microsoft.Extensions.DependencyInjection.Extensions;

var builder = WebApplication.CreateBuilder(args);
// 使用自定义的 MyServiceProviderFactory
builder.Host.UseServiceProviderFactory(new MyServiceProviderFactory());
// 需要将 Controller 注册为 Service
builder.Services.AddControllers().AddControllersAsServices();
// 替换 MyControllerActivator
builder.Services.Replace(ServiceDescriptor.Transient<IControllerActivator, MyControllerActivator>());
// 注册 DemoService
builder.Services.AddScoped<DemoService>();

var app = builder.Build();
app.MapControllers();
app.Run();

/// <summary>
/// 只是用于标记的注解
/// </summary>
[AttributeUsage(AttributeTargets.All)]
public sealed class AutoInjectAttribute : Attribute { }

/// <summary>
/// 一个简单的演示用 Service
/// </summary>
public class DemoService
{
public string Get() => "Get";
}

/// <summary>
/// 在创建 Controller 时使用 自定义的 MyServiceProvider 获取对象
/// </summary>
public class MyControllerActivator : IControllerActivator
{
public virtual object Create(ControllerContext context)
{
var provider = context.HttpContext.RequestServices;
provider = provider is MyServiceProvider ? provider : new MyServiceProvider(provider);
return provider.GetService(context.ActionDescriptor.ControllerTypeInfo.AsType());
}
public virtual void Release(ControllerContext context, object controller)
{
if (controller is IDisposable disposeable) disposeable.Dispose();
}
}

/// <summary>
/// 自定义 ServiceProvider, 添加对属性注入的处理
/// </summary>
public class MyServiceProvider : IServiceProvider
{
private IServiceProvider _provider;
public MyServiceProvider(IServiceProvider provider)
{
_provider = provider;
}
public object GetService(Type serviceType)
{
var obj = _provider.GetService(serviceType);
if (obj == null)
return obj;
var type = obj.GetType();
var props = type.GetProperties()
.Where(item => item.CustomAttributes.Any(attr => attr.AttributeType == typeof(AutoInjectAttribute)));
foreach (var prop in props)
prop.SetValue(obj, GetService(prop.PropertyType));
return obj;
}
}
/// <summary>
/// 替换工厂
/// </summary>
public class MyServiceProviderFactory : IServiceProviderFactory<IServiceCollection>
{
public IServiceCollection CreateBuilder(IServiceCollection services) => services ?? new ServiceCollection();
public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder)
{
return new MyServiceProvider(containerBuilder.BuildServiceProvider());
}
}
/// <summary>
/// 简单的controller
/// </summary>
[ApiController]
[Route("[controller]")]
public class DemoController : ControllerBase
{
[AutoInject]
public DemoService Demo { get; set; }

[HttpGet]
public string Get()
{
return Demo?.Get();
}
}