91 lines
2.8 KiB
C#
91 lines
2.8 KiB
C#
using System.Reflection;
|
|
using DBConnection.Contexts;
|
|
using DBConnection.Repositories.Interfaces;
|
|
using Microsoft.EntityFrameworkCore;
|
|
using NuGet.Configuration;
|
|
using Treestar.Shared.Models.DBDomain;
|
|
|
|
namespace DBConnection.Repositories;
|
|
|
|
/// <summary>
/// Base class for repositories over a single entity type backed by <see cref="AppDbContext"/>.
/// Provides delete, single/batch upsert, and "included" read helpers. Derived classes supply the
/// base query for the read helpers by implementing <see cref="GetAllIncludedQueryable"/>.
/// </summary>
/// <typeparam name="TEntityType">The entity type managed by this repository.</typeparam>
/// <remarks>
/// The read helpers take <see cref="Func{T, TResult}"/> predicates rather than expression trees,
/// so filtering cannot be translated to SQL: rows from <see cref="GetAllIncludedQueryable"/> are
/// streamed from the database and tested in memory.
/// </remarks>
public abstract class BaseRepository<TEntityType> : IRepository<TEntityType> where TEntityType : BaseEntity
{
    /// <summary>The context used for every query, change-tracking operation and save in this repository.</summary>
    protected readonly AppDbContext DbContext;

    /// <summary>Creates the repository over <paramref name="dbContext"/>.</summary>
    /// <param name="dbContext">The context used for all queries and saves.</param>
    /// <exception cref="ArgumentNullException"><paramref name="dbContext"/> is <c>null</c>.</exception>
    /// <remarks>Protected because the type is abstract and only derived repositories can call it (CA1012).</remarks>
    protected BaseRepository(AppDbContext dbContext)
    {
        DbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext));
    }

    /// <summary>
    /// Returns the query that every "Included" read method in this class starts from.
    /// Implemented by derived repositories; presumably this is where related data is
    /// eager-loaded (the name suggests <c>Include</c> calls) — confirm in the implementations.
    /// </summary>
    protected abstract IQueryable<TEntityType> GetAllIncludedQueryable();

    /// <summary>
    /// Marks <paramref name="entity"/> for removal in the context's change tracker.
    /// This method does not save; call <c>SaveChangesAsync</c> on the context to persist the delete.
    /// </summary>
    /// <param name="entity">The entity to remove.</param>
    /// <returns>The same <paramref name="entity"/> instance.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    public virtual TEntityType Delete(TEntityType entity)
    {
        ArgumentNullException.ThrowIfNull(entity);

        DbContext.Set<TEntityType>().Remove(entity);
        return entity;
    }

    /// <summary>
    /// Upserts each entity in <paramref name="entities"/> (see <see cref="Upsert"/>), optionally
    /// saving once for the whole batch.
    /// </summary>
    /// <param name="entities">The entities to insert or update, processed in enumeration order.</param>
    /// <param name="saveAfter">When <c>true</c>, calls <c>SaveChangesAsync</c> once after all entities are processed.</param>
    /// <returns>
    /// For each input, in order: the input itself if it was added, or the already-tracked database
    /// instance it was copied onto if it existed.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="entities"/> is <c>null</c>.</exception>
    public virtual async Task<IEnumerable<TEntityType>> UpsertMany(IEnumerable<TEntityType> entities, bool saveAfter = true)
    {
        ArgumentNullException.ThrowIfNull(entities);

        var results = new List<TEntityType>();
        foreach (var entity in entities)
        {
            // Individual upserts never save, so the batch is written by at most one SaveChangesAsync call.
            results.Add(await Upsert(entity, saveAfter: false));
        }

        if (saveAfter)
        {
            await DbContext.SaveChangesAsync();
        }

        return results;
    }

    /// <summary>
    /// Inserts <paramref name="entity"/> if no stored row matches it; otherwise copies its property
    /// values onto the stored (tracked) instance while keeping the stored <c>DateCreated</c>.
    /// </summary>
    /// <param name="entity">The entity to insert or update.</param>
    /// <param name="saveAfter">When <c>true</c>, calls <c>SaveChangesAsync</c> before returning.</param>
    /// <returns>
    /// <paramref name="entity"/> when it was added; otherwise the tracked instance loaded through
    /// <see cref="GetIncluded(TEntityType)"/>, which now carries the incoming values.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    public virtual async Task<TEntityType> Upsert(TEntityType entity, bool saveAfter = true)
    {
        ArgumentNullException.ThrowIfNull(entity);

        var set = DbContext.Set<TEntityType>();

        // The database-side existence check is cheap; the included lookup below streams rows into
        // memory, so it only runs for entities that already exist.
        TEntityType? dbEntry = null;
        if (await set.ContainsAsync(entity))
        {
            dbEntry = await GetIncluded(entity);
        }

        if (dbEntry is null)
        {
            // New entity - or the row could not be matched after the existence check (e.g. deleted
            // in between). Inserting is the upsert behaviour in both cases, and avoids the
            // NullReferenceException the unguarded dereference used to throw.
            set.Add(entity);
        }
        else
        {
            // Preserve the original creation timestamp, then copy every incoming value onto the
            // tracked instance so EF Core records the changes against it.
            entity.DateCreated = dbEntry.DateCreated;
            DbContext.Entry(dbEntry).CurrentValues.SetValues(entity);
            entity = dbEntry;
        }

        if (saveAfter)
        {
            await DbContext.SaveChangesAsync();
        }

        return entity;
    }

    /// <summary>Returns every entity produced by <see cref="GetAllIncludedQueryable"/>, fully loaded into memory.</summary>
    public virtual async Task<IEnumerable<TEntityType>> GetAllIncluded()
    {
        return await GetWhereIncluded(_ => true);
    }

    /// <summary>
    /// Finds the stored entity whose primary-key values equal those of <paramref name="entity"/>.
    /// </summary>
    /// <param name="entity">An entity whose key properties identify the row to load.</param>
    /// <returns>The matching entity from <see cref="GetAllIncludedQueryable"/>, or <c>null</c> if none matches.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">The model defines no primary key for <typeparamref name="TEntityType"/>.</exception>
    public virtual async Task<TEntityType?> GetIncluded(TEntityType entity)
    {
        ArgumentNullException.ThrowIfNull(entity);

        // Read the target key once, not once per scanned row, and fail clearly on a keyless type
        // instead of letting SequenceEqual throw on a null sequence mid-scan.
        var targetKey = GetPrimaryKey(entity)
            ?? throw new InvalidOperationException(
                $"Entity type '{typeof(TEntityType).Name}' has no primary key in the model.");

        return await GetIncluded(candidate => GetPrimaryKey(candidate)?.SequenceEqual(targetKey) == true);
    }

    /// <summary>
    /// Returns the first entity from <see cref="GetAllIncludedQueryable"/> that satisfies <paramref name="predicate"/>.
    /// </summary>
    /// <param name="predicate">
    /// A delegate (not an expression tree), so it is evaluated in memory: rows are streamed from the
    /// database and tested one by one until the first match.
    /// </param>
    /// <returns>The first match, or <c>null</c> if no row satisfies <paramref name="predicate"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="predicate"/> is <c>null</c>.</exception>
    /// <remarks>The work is synchronous; the returned task is already completed.</remarks>
    public virtual Task<TEntityType?> GetIncluded(Func<TEntityType, bool> predicate)
    {
        ArgumentNullException.ThrowIfNull(predicate);

        // AsEnumerable makes the client-side evaluation explicit (a Func already binds to
        // Enumerable.FirstOrDefault); there is nothing to await, hence no async modifier.
        var match = GetAllIncludedQueryable().AsEnumerable().FirstOrDefault(predicate);
        return Task.FromResult(match);
    }

    /// <summary>
    /// Returns every entity from <see cref="GetAllIncludedQueryable"/> that satisfies <paramref name="predicate"/>.
    /// </summary>
    /// <param name="predicate">
    /// A delegate (not an expression tree), so it is evaluated in memory against every row of the base query.
    /// </param>
    /// <returns>A materialized list of the matches.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="predicate"/> is <c>null</c>.</exception>
    /// <remarks>The work is synchronous; the returned task is already completed.</remarks>
    public virtual Task<IEnumerable<TEntityType>> GetWhereIncluded(Func<TEntityType, bool> predicate)
    {
        ArgumentNullException.ThrowIfNull(predicate);

        // Materialize now: a deferred query would hit the database again on every enumeration and
        // could be enumerated after the context has been disposed.
        IEnumerable<TEntityType> matches = GetAllIncludedQueryable().AsEnumerable().Where(predicate).ToList();
        return Task.FromResult(matches);
    }

    /// <summary>
    /// Reads the primary-key values of <paramref name="entity"/>, in the order the model lists the
    /// key properties.
    /// </summary>
    /// <returns>
    /// The key values, or <c>null</c> when the model has no entity type or no primary key for
    /// <typeparamref name="TEntityType"/>. A key property without a CLR property contributes a
    /// <c>null</c> element.
    /// </returns>
    private object?[]? GetPrimaryKey(TEntityType entity)
    {
        var keyProperties = DbContext.Model.FindEntityType(typeof(TEntityType))?.FindPrimaryKey()?.Properties;

        // Use the model's PropertyInfo directly rather than a by-name lookup on the runtime type,
        // which is slower and throws AmbiguousMatchException for key properties hidden with 'new'.
        return keyProperties?.Select(p => p.PropertyInfo?.GetValue(entity)).ToArray();
    }
}