using System;
using System.Collections.Concurrent;
using System.Reflection;
using System.Threading.Tasks;

namespace FreeSql
{
    /// <summary>
    /// A <see cref="DbContext"/> that is driven by an existing repository: it reuses the
    /// repository's <c>IFreeSql</c> instance and unit of work, and builds its entity sets
    /// on top of repository objects.
    /// </summary>
    internal class RepositoryDbContext : DbContext
    {
        /// <summary>The repository this context was created for.</summary>
        protected IBaseRepository _repos;

        /// <summary>
        /// Creates a context bound to <paramref name="repos"/>.
        /// </summary>
        /// <param name="orm">The ORM instance stored into the inherited <c>_orm</c> field.</param>
        /// <param name="repos">The owning repository; its <c>UnitOfWork</c> is copied into <c>_uowPriv</c>.</param>
        public RepositoryDbContext(IFreeSql orm, IBaseRepository repos) : base()
        {
            _orm = orm;
            _repos = repos;
            // NOTE(review): presumably disables the base context's own unit-of-work handling
            // because the repository's unit of work is used instead — confirm against DbContext.
            _isUseUnitOfWork = false;
            _uowPriv = _repos.UnitOfWork;
        }

        // Per-entity-type cache of the non-public "_dbPriv" field of BaseRepository<TEntity, int>,
        // so the reflection lookup is performed once per entity type.
        private static readonly ConcurrentDictionary<Type, FieldInfo> _dicGetRepositoryDbField =
            new ConcurrentDictionary<Type, FieldInfo>();

        /// <summary>
        /// Returns the non-public instance field <c>_dbPriv</c> declared on
        /// <c>BaseRepository&lt;type, int&gt;</c>, caching the result per entity type.
        /// </summary>
        /// <param name="type">The entity type used as the first generic argument.</param>
        private static FieldInfo GetRepositoryDbField(Type type) =>
            _dicGetRepositoryDbField.GetOrAdd(
                type,
                tp => typeof(BaseRepository<,>)
                    .MakeGenericType(tp, typeof(int))
                    .GetField("_dbPriv", BindingFlags.Instance | BindingFlags.NonPublic));

        /// <summary>
        /// Gets the set for <paramref name="entityType"/>, creating and caching it on first use.
        /// </summary>
        /// <param name="entityType">The entity type to obtain a set for.</param>
        /// <returns>
        /// The cached or newly created set, or <c>null</c> when
        /// <c>CodeFirst.GetTableByEntity</c> returns no table for <paramref name="entityType"/>.
        /// </returns>
        public override IDbSet Set(Type entityType)
        {
            // Single dictionary lookup instead of ContainsKey + indexer.
            IDbSet cached;
            if (_dicSet.TryGetValue(entityType, out cached))
            {
                return cached;
            }

            var tb = _orm.CodeFirst.GetTableByEntity(entityType);
            if (tb == null)
            {
                return null;
            }

            // The owning repository is reused when it already targets this entity type;
            // otherwise a DefaultRepository<entityType, int> is created via reflection.
            object repos = _repos;
            if (entityType != _repos.EntityType)
            {
                repos = Activator.CreateInstance(
                    typeof(DefaultRepository<,>).MakeGenericType(entityType, typeof(int)),
                    _repos.Orm);
                ((IBaseRepository)repos).UnitOfWork = _repos.UnitOfWork;

                // Point the new repository's private "_dbPriv" field at this context.
                GetRepositoryDbField(entityType).SetValue(repos, this);

                // Copy the owning repository's data filters onto the new repository.
                // The generic argument is the OWNING repository's entity type, matching
                // the BaseRepository<TEntity> parameter that receives _repos.
                typeof(RepositoryDbContext)
                    .GetMethod(nameof(SetRepositoryDataFilter))
                    .MakeGenericMethod(_repos.EntityType)
                    .Invoke(null, new object[] { repos, _repos });
            }

            var sd = Activator.CreateInstance(
                typeof(RepositoryDbSet<>).MakeGenericType(entityType),
                repos) as IDbSet;

            // Sets for typeof(object) are returned but never cached.
            if (entityType != typeof(object))
            {
                _dicSet.Add(entityType, sd);
            }

            return sd;
        }

        /// <summary>
        /// Applies every filter registered on <paramref name="baseRepo"/> to
        /// <paramref name="repos"/> through <c>DataFilterUtil.SetRepositoryDataFilter</c>.
        /// </summary>
        /// <remarks>
        /// Must remain <c>public static</c>: <see cref="Set(Type)"/> locates it by name with
        /// <c>Type.GetMethod(string)</c> and invokes it with a <c>null</c> target.
        /// </remarks>
        /// <typeparam name="TEntity">Entity type of <paramref name="baseRepo"/>.</typeparam>
        /// <param name="repos">The repository object that receives the filters.</param>
        /// <param name="baseRepo">The repository whose filters are copied.</param>
        public static void SetRepositoryDataFilter<TEntity>(object repos, BaseRepository<TEntity> baseRepo)
            where TEntity : class
        {
            var filter = baseRepo.DataFilter as DataFilter<TEntity>;
            DataFilterUtil.SetRepositoryDataFilter(repos, fl =>
            {
                foreach (var f in filter._filters)
                {
                    fl.Apply(f.Key, f.Value.Expression);
                }
            });
        }

        /// <summary>
        /// Runs <c>ExecCommand</c>, then returns the accumulated <c>_affrows</c> value
        /// and resets it to zero.
        /// </summary>
        /// <returns>The value of <c>_affrows</c> after <c>ExecCommand</c> completed.</returns>
        public override int SaveChanges()
        {
            ExecCommand();
            var ret = _affrows;
            _affrows = 0;
            return ret;
        }

        /// <summary>
        /// Awaits <c>ExecCommandAsync</c>, then returns the accumulated <c>_affrows</c> value
        /// and resets it to zero.
        /// </summary>
        /// <returns>A task producing the value of <c>_affrows</c> after the command completed.</returns>
        public override async Task<int> SaveChangesAsync()
        {
            await ExecCommandAsync();
            var ret = _affrows;
            _affrows = 0;
            return ret;
        }
    }
}