63 lines
2.2 KiB
C#

using System;
using System.Collections.Concurrent;
using System.Reflection;
using System.Threading.Tasks;
namespace FreeSql {
/// <summary>
/// <see cref="DbContext"/> bound to a single repository. It adopts the repository's
/// <c>UnitOfWork</c> and serves <see cref="IDbSet"/> instances backed by repositories.
/// Sets for entity types other than the owning repository's get a fresh
/// <c>DefaultRepository&lt;TEntity, int&gt;</c>. That repository shares this context,
/// the owning repository's unit of work, and the owning repository's data filters.
/// </summary>
internal class RepositoryDbContext : DbContext {

    /// <summary>The repository that owns this context.</summary>
    protected IBaseRepository _repos;

    /// <summary>Creates a context bound to <paramref name="repos"/>.</summary>
    /// <param name="orm">ORM instance used to resolve table metadata in <see cref="Set(Type)"/>.</param>
    /// <param name="repos">Owning repository; its <c>UnitOfWork</c> is adopted by this context.</param>
    public RepositoryDbContext(IFreeSql orm, IBaseRepository repos) : base() {
        _orm = orm;
        _repos = repos;
        // NOTE(review): presumably turns off the context's own unit-of-work handling, because the
        // repository's UnitOfWork is adopted on the next line instead. Confirm against DbContext.
        _isUseUnitOfWork = false;
        _uowPriv = _repos.UnitOfWork;
    }

    // Per-entity-type cache of the non-public "_dbPriv" field on BaseRepository<TEntity, int>.
    // Set() writes this context into that field so the child repository does not create its own.
    static readonly ConcurrentDictionary<Type, FieldInfo> _dicGetRepositoryDbField = new ConcurrentDictionary<Type, FieldInfo>();
    static FieldInfo GetRepositoryDbField(Type type) => _dicGetRepositoryDbField.GetOrAdd(type, tp => typeof(BaseRepository<,>).MakeGenericType(tp, typeof(int)).GetField("_dbPriv", BindingFlags.Instance | BindingFlags.NonPublic));

    // Per-entity-type cache of the closed SetRepositoryDataFilter<TEntity> method.
    // This avoids a reflection lookup and MakeGenericMethod call on every Set() call.
    static readonly ConcurrentDictionary<Type, MethodInfo> _dicSetRepositoryDataFilterMethod = new ConcurrentDictionary<Type, MethodInfo>();
    static MethodInfo GetSetRepositoryDataFilterMethod(Type entityType) => _dicSetRepositoryDataFilterMethod.GetOrAdd(entityType, tp => typeof(RepositoryDbContext).GetMethod(nameof(SetRepositoryDataFilter)).MakeGenericMethod(tp));

    /// <summary>
    /// Returns the set for <paramref name="entityType"/>, creating it on first use.
    /// </summary>
    /// <param name="entityType">Entity type of the requested set.</param>
    /// <returns>
    /// The set, or <c>null</c> when the ORM has no table metadata for <paramref name="entityType"/>.
    /// Sets are cached per type, except for <see cref="object"/>.
    /// </returns>
    public override IDbSet Set(Type entityType) {
        if (_dicSet.TryGetValue(entityType, out var cached)) return cached;

        var tb = _orm.CodeFirst.GetTableByEntity(entityType);
        if (tb == null) return null;

        object repos = _repos;
        if (entityType != _repos.EntityType) {
            // A foreign entity type gets its own repository, wired to share this context's state.
            repos = Activator.CreateInstance(typeof(DefaultRepository<,>).MakeGenericType(entityType, typeof(int)), _repos.Orm);
            ((IBaseRepository)repos).UnitOfWork = _repos.UnitOfWork;
            GetRepositoryDbField(entityType).SetValue(repos, this);

            // Copy the owning repository's data filters onto the new repository.
            GetSetRepositoryDataFilterMethod(_repos.EntityType).Invoke(null, new object[] { repos, _repos });
        }

        // Direct cast: a type that does not implement IDbSet fails here instead of caching null.
        var sd = (IDbSet)Activator.CreateInstance(typeof(RepositoryDbSet<>).MakeGenericType(entityType), repos);
        if (entityType != typeof(object)) _dicSet.Add(entityType, sd);
        return sd;
    }

    /// <summary>
    /// Re-applies every filter registered on <paramref name="baseRepo"/>'s data filter to
    /// <paramref name="repos"/> via <c>DataFilterUtil.SetRepositoryDataFilter</c>.
    /// It is invoked through reflection from <see cref="Set(Type)"/>, so it must stay public and non-overloaded.
    /// </summary>
    /// <typeparam name="TEntity">Entity type of <paramref name="baseRepo"/>.</typeparam>
    /// <param name="repos">Repository that receives the filters.</param>
    /// <param name="baseRepo">Repository whose filters are copied.</param>
    public static void SetRepositoryDataFilter<TEntity>(object repos, BaseRepository<TEntity> baseRepo) where TEntity : class {
        // NOTE(review): if DataFilter is not a DataFilter<TEntity>, 'filter' is null and the callback
        // below throws when it runs. Dropping filters silently could be unsafe, so it is left as-is.
        var filter = baseRepo.DataFilter as DataFilter<TEntity>;
        DataFilterUtil.SetRepositoryDataFilter(repos, fl => {
            foreach (var f in filter._filters)
                fl.Apply<TEntity>(f.Key, f.Value.Expression);
        });
    }

    /// <summary>Runs pending commands and returns the affected-row count, resetting it afterwards.</summary>
    public override int SaveChanges() {
        ExecCommand();
        return TakeAffrows();
    }

    /// <summary>Asynchronously runs pending commands and returns the affected-row count, resetting it afterwards.</summary>
    public override async Task<int> SaveChangesAsync() {
        // Library code: the continuation only reads and resets a counter, so it needs no captured context.
        await ExecCommandAsync().ConfigureAwait(false);
        return TakeAffrows();
    }

    // Returns the count currently held in _affrows and resets it to 0 for the next save.
    // (The count is presumably populated by ExecCommand/ExecCommandAsync.)
    int TakeAffrows() {
        var ret = _affrows;
        _affrows = 0;
        return ret;
    }
}
}