using System;
using System.Collections.Generic;
using System.Collections.Concurrent;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;
using System.Threading;

namespace FreeSql
{
    /// <summary>
    /// Base class for a FreeSql context: owns the <see cref="IFreeSql"/> instance, an optional
    /// unit of work, the per-entity <c>DbSet</c> instances and a queue of pending commands.
    /// </summary>
    public abstract partial class DbContext : IDisposable
    {
        internal IFreeSql _ormPriv;

        /// <summary>
        /// The configured <see cref="IFreeSql"/> instance.
        /// </summary>
        /// <exception cref="ArgumentNullException">No instance was supplied through the constructor or OnConfiguring.</exception>
        // The two-argument overload is used so the guidance text ends up in Message rather than in ParamName.
        public IFreeSql Orm => _ormPriv ?? throw new ArgumentNullException(nameof(Orm), "请在 OnConfiguring 或 AddFreeDbContext 中配置 UseFreeSql");

        #region Property UnitOfWork
        // Whether a unit-of-work transaction should be created on demand.
        internal bool _isUseUnitOfWork = true;
        IUnitOfWork _uowPriv;

        /// <summary>
        /// The unit of work used by this context. When none has been assigned, one is created
        /// lazily from <see cref="Orm"/>, unless <c>_isUseUnitOfWork</c> is false, in which case null is returned.
        /// </summary>
        public IUnitOfWork UnitOfWork
        {
            set => _uowPriv = value;
            get
            {
                if (_uowPriv != null) return _uowPriv;
                if (_isUseUnitOfWork == false) return null;
                return _uowPriv = new UnitOfWork(Orm);
            }
        }
        #endregion

        #region Property Options
        internal DbContextOptions _optionsPriv;

        /// <summary>
        /// Options for this context. When none were assigned, a new instance is created and, if
        /// <c>FreeSqlDbContextExtensions._dicSetDbContextOptions</c> has an entry for <see cref="Orm"/>,
        /// its <c>EnableAddOrUpdateNavigateList</c> and <c>OnEntityChange</c> values are copied in.
        /// </summary>
        public DbContextOptions Options
        {
            set => _optionsPriv = value;
            get
            {
                if (_optionsPriv == null)
                {
                    _optionsPriv = new DbContextOptions();
                    if (FreeSqlDbContextExtensions._dicSetDbContextOptions.TryGetValue(Orm, out var opt))
                    {
                        _optionsPriv.EnableAddOrUpdateNavigateList = opt.EnableAddOrUpdateNavigateList;
                        _optionsPriv.OnEntityChange = opt.OnEntityChange;
                    }
                }
                return _optionsPriv;
            }
        }

        /// <summary>
        /// Invokes the entity-change callback with <paramref name="report"/>. The unit of work's
        /// <c>EntityChangeReport.OnChange</c> takes precedence over <c>Options.OnEntityChange</c>.
        /// Nothing happens when there is no callback or the report is null or empty.
        /// </summary>
        internal void EmitOnEntityChange(List<EntityChangeReport.ChangeInfo> report)
        {
            var oec = UnitOfWork?.EntityChangeReport?.OnChange ?? Options.OnEntityChange;
            if (oec == null || report == null || report.Count == 0) return;
            oec(report);
        }
        #endregion

        /// <summary>
        /// Creates a context that is configured through <see cref="OnConfiguring"/>.
        /// </summary>
        protected DbContext() : this(null, null) { }

        /// <summary>
        /// Creates a context. When <paramref name="fsql"/> is null, <see cref="OnConfiguring"/> is called and
        /// both the ORM instance and the options are taken from the builder. The DbSet properties are
        /// initialised only when an ORM instance is available afterwards.
        /// </summary>
        protected DbContext(IFreeSql fsql, DbContextOptions options)
        {
            _ormPriv = fsql;
            _optionsPriv = options;

            if (_ormPriv == null)
            {
                var builder = new DbContextOptionsBuilder();
                OnConfiguring(builder);
                _ormPriv = builder._fsql;
                _optionsPriv = builder._options;
            }
            if (_ormPriv != null) InitPropSets();
        }

        /// <summary>
        /// Override to configure the context through <paramref name="builder"/>. The default does nothing.
        /// </summary>
        protected virtual void OnConfiguring(DbContextOptionsBuilder builder) { }

        #region Set
        // Per-context-type cache of the properties whose type is a constructed DbSet<>.
        static ConcurrentDictionary<Type, PropertyInfo[]> _dicGetDbSetProps = new ConcurrentDictionary<Type, PropertyInfo[]>();

        /// <summary>
        /// Assigns a set instance to every public or non-public instance property of type <c>DbSet&lt;T&gt;</c>
        /// declared on the concrete context type, and records each one in <see cref="AllSets"/> by property name.
        /// </summary>
        internal void InitPropSets()
        {
            // Comparing generic type definitions avoids constructing DbSet<arg> for unrelated generic
            // properties, which MakeGenericType rejects when arg violates a type-parameter constraint.
            var props = _dicGetDbSetProps.GetOrAdd(this.GetType(), tp =>
                tp.GetProperties(BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Public)
                    .Where(a => a.PropertyType.IsGenericType &&
                        a.PropertyType.GetGenericTypeDefinition() == typeof(DbSet<>)).ToArray());

            foreach (var prop in props)
            {
                var set = this.Set(prop.PropertyType.GetGenericArguments()[0]);

                prop.SetValue(this, set, null);
                AllSets.Add(prop.Name, set);
            }
        }

        // Every set created by this context, in creation order; disposed in Dispose().
        protected List<IDbSet> _listSet = new List<IDbSet>();
        // Sets keyed by entity type. Sets created for typeof(object) are not stored here.
        protected Dictionary<Type, IDbSet> _dicSet = new Dictionary<Type, IDbSet>();
        internal Dictionary<Type, IDbSet> InternalDicSet => _dicSet;

        /// <summary>
        /// Returns the set for <typeparamref name="TEntity"/>.
        /// </summary>
        public DbSet<TEntity> Set<TEntity>() where TEntity : class => this.Set(typeof(TEntity)) as DbSet<TEntity>;

        /// <summary>
        /// Returns the set for <paramref name="entityType"/>, creating it on first use.
        /// A set requested for <c>typeof(object)</c> is never cached, so each call creates a new one.
        /// </summary>
        public virtual IDbSet Set(Type entityType)
        {
            if (_dicSet.TryGetValue(entityType, out var cached)) return cached;

            var sd = Activator.CreateInstance(typeof(DbContextDbSet<>).MakeGenericType(entityType), this) as IDbSet;
            _listSet.Add(sd);
            if (entityType != typeof(object)) _dicSet.Add(entityType, sd);
            return sd;
        }

        /// <summary>
        /// Sets created for the context's DbSet properties, keyed by property name.
        /// </summary>
        protected Dictionary<string, IDbSet> AllSets { get; } = new Dictionary<string, IDbSet>();
        #endregion

        #region DbSet quick proxies
        /// <summary>
        /// Throws when <c>Orm.CodeFirst.GetTableByEntity</c> returns null for <paramref name="entityType"/>.
        /// </summary>
        /// <exception cref="ArgumentException">No table is resolved for the entity type.</exception>
        void CheckEntityTypeOrThrow(Type entityType)
        {
            if (Orm.CodeFirst.GetTableByEntity(entityType) == null)
                throw new ArgumentException($"参数 data 类型错误 {entityType.FullName} ");
        }

        /// <summary>
        /// Add. Validates the entity type, then forwards to the set for <typeparamref name="TEntity"/>.
        /// </summary>
        /// <typeparam name="TEntity">Entity type.</typeparam>
        /// <param name="data">Entity to add.</param>
        public void Add<TEntity>(TEntity data) where TEntity : class
        {
            CheckEntityTypeOrThrow(typeof(TEntity));
            this.Set<TEntity>().Add(data);
        }

        /// <summary>
        /// Forwards to <c>AddRange</c> on the set for <typeparamref name="TEntity"/> (no entity-type validation here).
        /// </summary>
        public void AddRange<TEntity>(IEnumerable<TEntity> data) where TEntity : class => this.Set<TEntity>().AddRange(data);

        /// <summary>
        /// Update. Validates the entity type, then forwards to the set for <typeparamref name="TEntity"/>.
        /// </summary>
        /// <typeparam name="TEntity">Entity type.</typeparam>
        /// <param name="data">Entity to update.</param>
        public void Update<TEntity>(TEntity data) where TEntity : class
        {
            CheckEntityTypeOrThrow(typeof(TEntity));
            this.Set<TEntity>().Update(data);
        }

        /// <summary>
        /// Forwards to <c>UpdateRange</c> on the set for <typeparamref name="TEntity"/> (no entity-type validation here).
        /// </summary>
        public void UpdateRange<TEntity>(IEnumerable<TEntity> data) where TEntity : class => this.Set<TEntity>().UpdateRange(data);

        /// <summary>
        /// Remove. Validates the entity type, then forwards to the set for <typeparamref name="TEntity"/>.
        /// </summary>
        /// <typeparam name="TEntity">Entity type.</typeparam>
        /// <param name="data">Entity to remove.</param>
        public void Remove<TEntity>(TEntity data) where TEntity : class
        {
            CheckEntityTypeOrThrow(typeof(TEntity));
            this.Set<TEntity>().Remove(data);
        }

        /// <summary>
        /// Forwards to <c>RemoveRange</c> on the set for <typeparamref name="TEntity"/> (no entity-type validation here).
        /// </summary>
        public void RemoveRange<TEntity>(IEnumerable<TEntity> data) where TEntity : class => this.Set<TEntity>().RemoveRange(data);

        /// <summary>
        /// Add or update. Validates the entity type, then forwards to the set for <typeparamref name="TEntity"/>.
        /// </summary>
        /// <typeparam name="TEntity">Entity type.</typeparam>
        /// <param name="data">Entity to add or update.</param>
        public void AddOrUpdate<TEntity>(TEntity data) where TEntity : class
        {
            CheckEntityTypeOrThrow(typeof(TEntity));
            this.Set<TEntity>().AddOrUpdate(data);
        }

        /// <summary>
        /// Saves the specified ManyToMany navigation property of the entity.
        /// </summary>
        /// <param name="data">Entity object.</param>
        /// <param name="propertyName">Property name.</param>
        public void SaveManyToMany<TEntity>(TEntity data, string propertyName) where TEntity : class
        {
            CheckEntityTypeOrThrow(typeof(TEntity));
            this.Set<TEntity>().SaveManyToMany(data, propertyName);
        }

        /// <summary>
        /// Attaches an entity; can be used to update or delete without querying first.
        /// </summary>
        /// <typeparam name="TEntity">Entity type.</typeparam>
        /// <param name="data">Entity to attach.</param>
        public void Attach<TEntity>(TEntity data) where TEntity : class
        {
            CheckEntityTypeOrThrow(typeof(TEntity));
            this.Set<TEntity>().Attach(data);
        }

        /// <summary>
        /// Forwards to <c>AttachRange</c> on the set for <typeparamref name="TEntity"/> (no entity-type validation here).
        /// </summary>
        public void AttachRange<TEntity>(IEnumerable<TEntity> data) where TEntity : class => this.Set<TEntity>().AttachRange(data);

        /// <summary>
        /// Attaches an entity with only its primary key values; can be used to avoid updating fields
        /// whose property values are null or default.
        /// </summary>
        /// <typeparam name="TEntity">Entity type.</typeparam>
        /// <param name="data">Entity to attach.</param>
        /// <returns>This context, to allow chaining.</returns>
        public DbContext AttachOnlyPrimary<TEntity>(TEntity data) where TEntity : class
        {
            CheckEntityTypeOrThrow(typeof(TEntity));
            this.Set<TEntity>().AttachOnlyPrimary(data);
            return this;
        }

#if net40
#else
        /// <summary>
        /// Asynchronous counterpart of Add: validates the entity type, then forwards to the set.
        /// </summary>
        public Task AddAsync<TEntity>(TEntity data) where TEntity : class
        {
            CheckEntityTypeOrThrow(typeof(TEntity));
            return this.Set<TEntity>().AddAsync(data);
        }

        /// <summary>
        /// Forwards to <c>AddRangeAsync</c> on the set for <typeparamref name="TEntity"/> (no entity-type validation here).
        /// </summary>
        public Task AddRangeAsync<TEntity>(IEnumerable<TEntity> data) where TEntity : class => this.Set<TEntity>().AddRangeAsync(data);

        /// <summary>
        /// Asynchronous counterpart of Update: validates the entity type, then forwards to the set.
        /// </summary>
        public Task UpdateAsync<TEntity>(TEntity data) where TEntity : class
        {
            CheckEntityTypeOrThrow(typeof(TEntity));
            return this.Set<TEntity>().UpdateAsync(data);
        }

        /// <summary>
        /// Forwards to <c>UpdateRangeAsync</c> on the set for <typeparamref name="TEntity"/> (no entity-type validation here).
        /// </summary>
        public Task UpdateRangeAsync<TEntity>(IEnumerable<TEntity> data) where TEntity : class => this.Set<TEntity>().UpdateRangeAsync(data);

        /// <summary>
        /// Asynchronous counterpart of AddOrUpdate: validates the entity type, then forwards to the set.
        /// </summary>
        public Task AddOrUpdateAsync<TEntity>(TEntity data) where TEntity : class
        {
            CheckEntityTypeOrThrow(typeof(TEntity));
            return this.Set<TEntity>().AddOrUpdateAsync(data);
        }

        /// <summary>
        /// Asynchronous counterpart of SaveManyToMany: validates the entity type, then forwards to the set.
        /// </summary>
        /// <param name="data">Entity object.</param>
        /// <param name="propertyName">Property name.</param>
        public Task SaveManyToManyAsync<TEntity>(TEntity data, string propertyName) where TEntity : class
        {
            CheckEntityTypeOrThrow(typeof(TEntity));
            return this.Set<TEntity>().SaveManyToManyAsync(data, propertyName);
        }
#endif
        #endregion

        #region Queue Action
        /// <summary>
        /// Holds recorded entity changes together with the callback to notify about them.
        /// </summary>
        public class EntityChangeReport
        {
            /// <summary>
            /// One recorded change: the affected object and the kind of change.
            /// </summary>
            public class ChangeInfo
            {
                public object Object { get; set; }
                public EntityChangeType Type { get; set; }
            }

            /// <summary>
            /// Entity change records.
            /// </summary>
            public List<ChangeInfo> Report { get; } = new List<ChangeInfo>();

            /// <summary>
            /// Entity change event.
            /// </summary>
            public Action<List<ChangeInfo>> OnChange { get; set; }
        }

        // NOTE(review): element type inferred from EmitOnEntityChange's parameter; confirm against
        // the other partial-class files that read and write this list.
        internal List<EntityChangeReport.ChangeInfo> _entityChangeReport = new List<EntityChangeReport.ChangeInfo>();

        public enum EntityChangeType { Insert, Update, Delete, SqlRaw }

        /// <summary>
        /// A queued command: the change kind, the originating set, the state/entity types and the state payload.
        /// </summary>
        internal class ExecCommandInfo
        {
            public EntityChangeType changeType { get; set; }
            public IDbSet dbSet { get; set; }
            public Type stateType { get; set; }
            public Type entityType { get; set; }
            public object state { get; set; }
        }

        Queue<ExecCommandInfo> _actions = new Queue<ExecCommandInfo>();
        internal int _affrows = 0;

        /// <summary>
        /// Appends a command to the pending-action queue.
        /// </summary>
        internal void EnqueueAction(EntityChangeType changeType, IDbSet dbSet, Type stateType, Type entityType, object state) =>
            _actions.Enqueue(new ExecCommandInfo { changeType = changeType, dbSet = dbSet, stateType = stateType, entityType = entityType, state = state });
        #endregion

        ~DbContext() => this.Dispose();

        int _disposeCounter;

        /// <summary>
        /// Clears the pending actions, disposes and forgets every created set, and disposes the unit of work
        /// when <c>_isUseUnitOfWork</c> is true. Only the first call has any effect.
        /// </summary>
        public void Dispose()
        {
            if (Interlocked.Increment(ref _disposeCounter) != 1) return;
            try
            {
                _actions.Clear();

                foreach (var set in _listSet)
                {
                    // Best effort: a failure in one set must not prevent the remaining cleanup.
                    try { set.Dispose(); }
                    catch { }
                }

                _listSet.Clear();
                _dicSet.Clear();
                AllSets.Clear();

                // Use the backing field, not the UnitOfWork property: the getter would create a unit of
                // work just to dispose it, and throws through Orm when no IFreeSql was configured, which
                // would escape from the finalizer.
                if (_isUseUnitOfWork) _uowPriv?.Dispose();
            }
            finally
            {
                GC.SuppressFinalize(this);
            }
        }
    }
}