Files
WebNovelPortal/DBConnection/Repositories/BaseRepository.cs

101 lines
3.1 KiB
C#

using System.Reflection;
using DBConnection.Contexts;
using DBConnection.Repositories.Interfaces;
using Microsoft.EntityFrameworkCore;
using NuGet.Configuration;
using Treestar.Shared.Models.DBDomain;
namespace DBConnection.Repositories;
/// <summary>
/// Generic repository base over <see cref="AppDbContext"/>. Derived repositories supply
/// <see cref="GetAllIncludedQueryable"/> (presumably the entity set plus its navigation Includes —
/// confirm per subclass); every <c>GetIncluded</c>/<c>GetWhereIncluded</c> lookup reads from it.
/// </summary>
/// <typeparam name="TEntityType">Entity type managed by this repository.</typeparam>
public abstract class BaseRepository<TEntityType> : IRepository<TEntityType> where TEntityType : BaseEntity
{
    /// <summary>The context every operation in this repository runs against.</summary>
    protected readonly AppDbContext DbContext;

    /// <summary>Creates the repository over <paramref name="dbContext"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="dbContext"/> is <c>null</c>.</exception>
    public BaseRepository(AppDbContext dbContext)
    {
        ArgumentNullException.ThrowIfNull(dbContext);
        DbContext = dbContext;
    }

    /// <summary>
    /// Base query for this entity type used by all "Included" lookups in this class.
    /// </summary>
    protected abstract IQueryable<TEntityType> GetAllIncludedQueryable();

    /// <summary>
    /// Names of the primary-key properties of <typeparamref name="TEntityType"/> in the EF Core model,
    /// or <c>null</c> when the type is not in the model or has no primary key.
    /// </summary>
    private string[]? GetPrimaryKeyNames() =>
        DbContext.Model.FindEntityType(typeof(TEntityType))?.FindPrimaryKey()?.Properties.Select(p => p.Name).ToArray();

    /// <summary>
    /// Reads the values of <paramref name="keyNames"/> from <paramref name="entity"/> by reflection,
    /// in key order. A name without a matching CLR property yields <c>null</c> at that position.
    /// </summary>
    private static object?[] ReadKeyValues(TEntityType entity, IReadOnlyList<string> keyNames)
    {
        var runtimeType = entity.GetType();
        return keyNames.Select(name => runtimeType.GetProperty(name)?.GetValue(entity, null)).ToArray();
    }

    /// <summary>
    /// Enumerates <see cref="GetAllIncludedQueryable"/>. Streams asynchronously when the query
    /// implements <see cref="IAsyncEnumerable{T}"/> (EF Core queries do), so database reads do not
    /// block the calling thread; otherwise falls back to synchronous enumeration.
    /// </summary>
    private async IAsyncEnumerable<TEntityType> EnumerateIncludedAsync()
    {
        var query = GetAllIncludedQueryable();
        if (query is IAsyncEnumerable<TEntityType> asyncQuery)
        {
            await foreach (var item in asyncQuery)
            {
                yield return item;
            }
            yield break;
        }

        foreach (var item in query)
        {
            yield return item;
        }
    }

    /// <summary>
    /// Marks <paramref name="entity"/> for removal. Nothing is written until changes are saved
    /// (e.g. via <see cref="PersistChanges"/>).
    /// </summary>
    /// <returns>The same <paramref name="entity"/> instance.</returns>
    public virtual TEntityType Delete(TEntityType entity)
    {
        DbContext.Set<TEntityType>().Remove(entity);
        return entity;
    }

    /// <summary>
    /// Upserts each entity in order (see <see cref="Upsert"/>) and optionally saves once at the end.
    /// </summary>
    /// <param name="entities">Entities to insert or update.</param>
    /// <param name="saveAfter">When <c>true</c>, calls <c>SaveChangesAsync</c> after all upserts.</param>
    /// <returns>The resulting entities, in input order (tracked instances for updated rows).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entities"/> is <c>null</c>.</exception>
    public virtual async Task<IEnumerable<TEntityType>> UpsertMany(IEnumerable<TEntityType> entities, bool saveAfter = true)
    {
        ArgumentNullException.ThrowIfNull(entities);

        var upserted = new List<TEntityType>();
        // Sequential on purpose: a DbContext does not support concurrent operations.
        foreach (var entity in entities)
        {
            upserted.Add(await Upsert(entity, false));
        }

        if (saveAfter)
        {
            await DbContext.SaveChangesAsync();
        }
        return upserted;
    }

    /// <summary>
    /// Inserts <paramref name="entity"/> if no row with its key exists; otherwise copies its property
    /// values onto the existing (included) entity while keeping the stored <c>DateCreated</c>.
    /// </summary>
    /// <param name="entity">Entity to insert or update.</param>
    /// <param name="saveAfter">When <c>true</c>, calls <c>SaveChangesAsync</c> afterwards.</param>
    /// <returns>
    /// <paramref name="entity"/> itself when inserted, or the existing tracked entity when updated.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    public virtual async Task<TEntityType> Upsert(TEntityType entity, bool saveAfter = true)
    {
        ArgumentNullException.ThrowIfNull(entity);

        var dbSet = DbContext.Set<TEntityType>();
        // Check existence first so new entities skip the client-side-filtered included lookup.
        // A null lookup result (e.g. the row was deleted in between) is treated as "insert".
        var existing = await dbSet.ContainsAsync(entity) ? await GetIncluded(entity) : default;
        if (existing is null)
        {
            dbSet.Add(entity);
        }
        else
        {
            // Preserve the original creation timestamp; SetValues would otherwise overwrite it.
            entity.DateCreated = existing.DateCreated;
            DbContext.Entry(existing).CurrentValues.SetValues(entity);
            entity = existing;
        }

        if (saveAfter)
        {
            await DbContext.SaveChangesAsync();
        }
        return entity;
    }

    /// <summary>Returns every entity produced by <see cref="GetAllIncludedQueryable"/>.</summary>
    public virtual async Task<IEnumerable<TEntityType>> GetAllIncluded()
    {
        return await GetWhereIncluded(i => true);
    }

    /// <summary>
    /// Finds the included entity whose primary-key values equal those of <paramref name="entity"/>.
    /// </summary>
    /// <returns>The matching entity, or <c>null</c> if none matches.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">
    /// <typeparamref name="TEntityType"/> has no primary key in the EF Core model.
    /// </exception>
    public virtual async Task<TEntityType?> GetIncluded(TEntityType entity)
    {
        ArgumentNullException.ThrowIfNull(entity);

        var keyNames = GetPrimaryKeyNames()
            ?? throw new InvalidOperationException(
                $"Entity type '{typeof(TEntityType).Name}' has no primary key in the EF Core model.");
        // Resolve the target key once instead of once per candidate row.
        var targetKey = ReadKeyValues(entity, keyNames);
        return await GetIncluded(dbEntity => ReadKeyValues(dbEntity, keyNames).SequenceEqual(targetKey));
    }

    /// <summary>
    /// Returns the first included entity matching <paramref name="predicate"/>, or <c>null</c>.
    /// The predicate is a delegate, not an expression, so it is evaluated in memory: rows of
    /// <see cref="GetAllIncludedQueryable"/> are materialized until a match is found.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="predicate"/> is <c>null</c>.</exception>
    public virtual async Task<TEntityType?> GetIncluded(Func<TEntityType, bool> predicate)
    {
        ArgumentNullException.ThrowIfNull(predicate);

        await foreach (var candidate in EnumerateIncludedAsync())
        {
            if (predicate(candidate))
            {
                return candidate;
            }
        }
        return default;
    }

    /// <summary>
    /// Returns the included entities that are equal (per their equality semantics) to any of
    /// <paramref name="entities"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="entities"/> is <c>null</c>.</exception>
    public virtual async Task<IEnumerable<TEntityType?>> GetWhereIncluded(IEnumerable<TEntityType> entities)
    {
        ArgumentNullException.ThrowIfNull(entities);

        // The membership test runs once per row: materialize into a set once rather than re-enumerating
        // (or linearly scanning) the input every time. An incoming ISet keeps its own comparer.
        ISet<TEntityType> wanted = entities as ISet<TEntityType> ?? entities.ToHashSet();
        return await GetWhereIncluded(wanted.Contains);
    }

    /// <summary>
    /// Returns all included entities matching <paramref name="predicate"/>, evaluated in memory.
    /// The result is fully materialized before returning, so the query is not re-run on each
    /// enumeration and does not depend on the context still being alive when the caller enumerates.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="predicate"/> is <c>null</c>.</exception>
    public virtual async Task<IEnumerable<TEntityType>> GetWhereIncluded(Func<TEntityType, bool> predicate)
    {
        ArgumentNullException.ThrowIfNull(predicate);

        var matches = new List<TEntityType>();
        await foreach (var candidate in EnumerateIncludedAsync())
        {
            if (predicate(candidate))
            {
                matches.Add(candidate);
            }
        }
        return matches;
    }

    /// <summary>Saves all pending changes tracked by <see cref="DbContext"/>.</summary>
    public virtual async Task PersistChanges()
    {
        await DbContext.SaveChangesAsync();
    }
}