using System.Reflection;
using DBConnection.Contexts;
using DBConnection.Repositories.Interfaces;
using Microsoft.EntityFrameworkCore;
using NuGet.Configuration;
using Treestar.Shared.Models.DBDomain;

namespace DBConnection.Repositories;

/// <summary>
/// Base class for repositories of <typeparamref name="TEntityType"/> backed by <see cref="AppDbContext"/>.
/// Provides delete, upsert and "included" lookups; subclasses supply the query that the
/// lookups run against via <see cref="GetAllIncludedQueryable"/>.
/// </summary>
/// <typeparam name="TEntityType">The entity type handled by this repository.</typeparam>
public abstract class BaseRepository<TEntityType> : IRepository<TEntityType> where TEntityType : BaseEntity
{
    protected readonly AppDbContext DbContext;

    /// <summary>Creates the repository over the given context.</summary>
    /// <param name="dbContext">The context every operation in this repository uses.</param>
    protected BaseRepository(AppDbContext dbContext)
    {
        DbContext = dbContext;
    }

    /// <summary>
    /// Returns the query that the "Included" lookups enumerate.
    /// </summary>
    /// <remarks>
    /// The lookups in this class enumerate it with EF Core's <c>AsAsyncEnumerable</c>, so it must be
    /// an EF Core query (e.g. built from <see cref="DbContext"/>), not an in-memory queryable.
    /// </remarks>
    protected abstract IQueryable<TEntityType> GetAllIncludedQueryable();

    /// <summary>
    /// Marks <paramref name="entity"/> for deletion. Nothing is written until the context is saved.
    /// </summary>
    /// <param name="entity">The entity to remove.</param>
    /// <returns>The same <paramref name="entity"/> instance.</returns>
    public virtual TEntityType Delete(TEntityType entity)
    {
        DbContext.Set<TEntityType>().Remove(entity);
        return entity;
    }

    /// <summary>
    /// Upserts each entity in order and, optionally, saves the context once for the whole batch.
    /// </summary>
    /// <param name="entities">The entities to insert or update.</param>
    /// <param name="saveAfter">When true, calls <c>SaveChangesAsync</c> once after all entities are processed.</param>
    /// <returns>The upserted instances, in input order (see <see cref="Upsert"/> for which instance is returned).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
    public virtual async Task<IEnumerable<TEntityType>> UpsertMany(IEnumerable<TEntityType> entities, bool saveAfter = true)
    {
        ArgumentNullException.ThrowIfNull(entities);

        var upserted = new List<TEntityType>();
        foreach (var entity in entities)
        {
            // Defer saving so the batch is persisted by a single SaveChangesAsync call below.
            upserted.Add(await Upsert(entity, false));
        }

        if (saveAfter)
        {
            await DbContext.SaveChangesAsync();
        }

        return upserted;
    }

    /// <summary>
    /// Adds <paramref name="entity"/> if the set does not contain it; otherwise copies its values onto
    /// the existing row loaded through <see cref="GetIncluded(TEntityType)"/>, keeping the stored
    /// <c>DateCreated</c>.
    /// </summary>
    /// <param name="entity">The entity to insert or update. On update its <c>DateCreated</c> is overwritten with the stored value.</param>
    /// <param name="saveAfter">When true, calls <c>SaveChangesAsync</c> before returning.</param>
    /// <returns>The passed entity when it was added, or the loaded existing instance when it was updated.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    /// <exception cref="InvalidOperationException">
    /// The set contains the entity but <see cref="GetAllIncludedQueryable"/> does not return it.
    /// </exception>
    public virtual async Task<TEntityType> Upsert(TEntityType entity, bool saveAfter = true)
    {
        ArgumentNullException.ThrowIfNull(entity);

        // Existence is checked against the raw DbSet, not against GetAllIncludedQueryable().
        var exists = await DbContext.Set<TEntityType>().ContainsAsync(entity);
        if (!exists)
        {
            DbContext.Set<TEntityType>().Add(entity);
        }
        else
        {
            var dbEntry = await GetIncluded(entity);
            if (dbEntry is null)
            {
                // The raw set has the row but the included query did not return it
                // (presumably it applies extra filtering), so there is no instance to update.
                throw new InvalidOperationException(
                    $"An existing {typeof(TEntityType).Name} was found in the set but could not be loaded through GetAllIncludedQueryable().");
            }

            // Preserve the stored creation timestamp, then copy the incoming values onto the loaded instance.
            entity.DateCreated = dbEntry.DateCreated;
            DbContext.Entry(dbEntry).CurrentValues.SetValues(entity);
            entity = dbEntry;
        }

        if (saveAfter)
        {
            await DbContext.SaveChangesAsync();
        }

        return entity;
    }

    /// <summary>Loads every entity returned by <see cref="GetAllIncludedQueryable"/>.</summary>
    /// <returns>A fully materialized collection of all entities.</returns>
    public virtual async Task<IEnumerable<TEntityType>> GetAllIncluded()
    {
        return await GetWhereIncluded(_ => true);
    }

    /// <summary>
    /// Finds the entity from <see cref="GetAllIncludedQueryable"/> whose primary-key values equal those of
    /// <paramref name="entity"/>.
    /// </summary>
    /// <param name="entity">An entity carrying the key values to look up.</param>
    /// <returns>The matching entity, or null when none matches.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The EF Core model has no primary key for <typeparamref name="TEntityType"/>.</exception>
    public virtual async Task<TEntityType> GetIncluded(TEntityType entity)
    {
        ArgumentNullException.ThrowIfNull(entity);

        var keyNames = GetPrimaryKeyNames();
        if (keyNames.Length == 0)
        {
            throw new InvalidOperationException(
                $"Entity type {typeof(TEntityType).Name} has no primary key configured in the EF Core model.");
        }

        // Resolve the target key once instead of once per scanned row.
        var targetKey = GetKeyValues(entity, keyNames);
        // NOTE(review): the key comparison runs client-side, so this scans the included query until a match.
        return await GetIncluded(dbEntity => GetKeyValues(dbEntity, keyNames).SequenceEqual(targetKey));
    }

    /// <summary>
    /// Returns the first entity from <see cref="GetAllIncludedQueryable"/> that satisfies <paramref name="predicate"/>.
    /// </summary>
    /// <param name="predicate">A client-side filter; it is evaluated in memory, not translated to SQL.</param>
    /// <returns>The first match, or null when nothing matches.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="predicate"/> is null.</exception>
    public virtual async Task<TEntityType> GetIncluded(Func<TEntityType, bool> predicate)
    {
        ArgumentNullException.ThrowIfNull(predicate);

        // A Func cannot be translated to SQL, so rows are streamed asynchronously and
        // tested in memory; enumeration stops at the first match.
        await foreach (var candidate in GetAllIncludedQueryable().AsAsyncEnumerable())
        {
            if (predicate(candidate))
            {
                return candidate;
            }
        }

        return default!;
    }

    /// <summary>
    /// Returns every entity from <see cref="GetAllIncludedQueryable"/> that satisfies <paramref name="predicate"/>.
    /// </summary>
    /// <param name="predicate">A client-side filter; it is evaluated in memory, not translated to SQL.</param>
    /// <returns>
    /// A fully materialized collection, so enumerating it later neither re-queries the database nor
    /// depends on the context still being usable.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="predicate"/> is null.</exception>
    public virtual async Task<IEnumerable<TEntityType>> GetWhereIncluded(Func<TEntityType, bool> predicate)
    {
        ArgumentNullException.ThrowIfNull(predicate);

        var matches = new List<TEntityType>();
        await foreach (var candidate in GetAllIncludedQueryable().AsAsyncEnumerable())
        {
            if (predicate(candidate))
            {
                matches.Add(candidate);
            }
        }

        return matches;
    }

    /// <summary>
    /// Returns the names of the primary-key properties configured in the EF Core model for
    /// <typeparamref name="TEntityType"/>, or an empty array when there is no mapping or no key.
    /// </summary>
    private string[] GetPrimaryKeyNames()
    {
        return DbContext.Model.FindEntityType(typeof(TEntityType))
                   ?.FindPrimaryKey()
                   ?.Properties.Select(p => p.Name).ToArray()
               ?? Array.Empty<string>();
    }

    /// <summary>
    /// Reads the values of the named properties from <paramref name="entity"/> via reflection, in the order of
    /// <paramref name="keyNames"/>. A name with no matching public property on the runtime type yields null.
    /// </summary>
    private static object?[] GetKeyValues(TEntityType entity, IReadOnlyList<string> keyNames)
    {
        var runtimeType = entity.GetType();
        return keyNames.Select(name => runtimeType.GetProperty(name)?.GetValue(entity, null)).ToArray();
    }
}