DEV Community

Amjad Abujamous


Dapper: Generic Repository

Motivation

While Entity Framework is a great object-relational mapper for .NET, the queries it generates can sometimes be slow. For instance, at my previous job, EF Core's query performance struggled with tables of over 500,000 rows. The solution turned out to be Dapper, since it lets you write your own optimized SQL. In this blog post, we will explore a clean implementation of the Generic Repository pattern with Dapper.

Code

The source code can be found here.

Credit

This article builds on the brilliant work by Zuraiz Ahmed Shehzad on Medium. This implementation is asynchronous, and its queries are more robust, especially the SELECT queries, which dynamically alias each column name to its property name (for example, category_id AS CategoryId).

Approach

We will have an N-tier ASP.NET Web API with four layers:
API --> Application --> Repository (Infrastructure) --> Database (Domain).

The generic repository implementation will be, you guessed it, in the Repository layer.

Furthermore, we will connect to a PostgreSQL database, since it is easy to set up on any operating system.

Step 1: The Database

We will create a simple API that supports CRUD operations on Products and Categories. The following SQL scripts create the database and its tables:

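-- LC_COLLATE and LC_CTYPE below are Windows locale names; on Linux or macOS use a
-- locale such as 'en_US.UTF-8' (or omit both options to inherit the server default).
-- LOCALE_PROVIDER requires PostgreSQL 15 or later.
CREATE DATABASE "GenericRepoDapperDb"
    WITH
    OWNER = postgres
    ENCODING = 'UTF8'
    LC_COLLATE = 'English_United States.1252'
    LC_CTYPE = 'English_United States.1252'
    LOCALE_PROVIDER = 'libc'
    TABLESPACE = pg_default
    CONNECTION LIMIT = -1
    IS_TEMPLATE = False;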
CREATE TABLE categories (
  id SERIAL PRIMARY KEY,
  name VARCHAR(255) NOT NULL,
  created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
  updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE products (
  id SERIAL PRIMARY KEY,
  name VARCHAR(255) NOT NULL,
  description TEXT,
  category_id INTEGER NOT NULL REFERENCES categories(id),
  created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
  updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
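The tables use snake_case column names, while the C# classes in step 2 use PascalCase properties. The correspondence is mechanical, as this minimal sketch shows. NamingSketch.ToSnakeCase is a hypothetical helper for illustration only, and it assumes plain ASCII names without acronyms such as "URLPath".

```csharp
using System.Text;

// Hypothetical helper, not part of the article's code.
public static class NamingSketch
{
    // Converts a PascalCase property name such as "CategoryId"
    // into the snake_case column name "category_id".
    public static string ToSnakeCase(string propertyName)
    {
        var builder = new StringBuilder(propertyName.Length + 8);
        for (int i = 0; i < propertyName.Length; i++)
        {
            char c = propertyName[i];
            if (char.IsUpper(c))
            {
                // Put an underscore before every uppercase letter except the first one.
                if (i > 0)
                    builder.Append('_');
                builder.Append(char.ToLowerInvariant(c));
            }
            else
            {
                builder.Append(c);
            }
        }
        return builder.ToString();
    }
}
```

For example, ToSnakeCase("CreatedAt") returns "created_at". Dapper can also match underscored column names on its own when Dapper.DefaultTypeMap.MatchNamesWithUnderscores is set to true; this article instead spells each mapping out with [Column] attributes, which keeps it explicit.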

Step 2: The Domain Layer
We will create Plain Old C# Objects (POCOs) that map to those tables:

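using System;
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations.Schema;

public interface IEntity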
{
    int Id { get; set; }
    DateTime CreatedAt { get; set; }
    DateTime UpdatedAt { get; set; }
}

public class Entity : IEntity
{
    [Key]
    [Column("id")]
    public int Id { get; set; }
    [Column("created_at")]
    public DateTime CreatedAt { get; set; }
    [Column("updated_at")]
    public DateTime UpdatedAt { get; set; }
}
[Table("categories")]
public class Category : Entity
{
    [Column("name")]
    public string? Name { get; set; }
}
[Table("products")]
public class Product : Entity
{
    [Column("name")]
    public string? Name { get; set; }
    [Column("description")]
    public string? Description { get; set; }
    [Column("category_id")]
    public int CategoryId { get; set; }
}

Now, we will create the database context, which is really a small connection factory; it will be registered as a singleton in a later step.

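using System.Data;
using Microsoft.Extensions.Configuration;
using Npgsql;

public class ApplicationDbContext
{
    private readonly IConfiguration _configuration;

    public ApplicationDbContext(IConfiguration configuration)
    {
        _configuration = configuration;
    }

    // Returns a new, closed connection. The argument is the *name* of an
    // entry under "ConnectionStrings" in appsettings.json.
    public IDbConnection CreateConnection(string connectionStringName = "DefaultConnection")
    {
        string? connectionString = _configuration.GetConnectionString(connectionStringName);
        return new NpgsqlConnection(connectionString);
    }
}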

Note the use of inheritance to avoid repeating the common fields, and the Table, Key, and Column annotations, which step 3 relies on.
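The generic repository in step 3 reads these attributes back at runtime through reflection. A minimal sketch of that lookup follows; SampleProduct (with its hypothetical Notes property) and MappingSketch are illustrative names, not part of the article's code.

```csharp
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations.Schema;
using System.Linq;
using System.Reflection;

// Hypothetical stand-in that mirrors the article's Product annotations.
[Table("products")]
public class SampleProduct
{
    [Key]
    [Column("id")]
    public int Id { get; set; }

    [Column("category_id")]
    public int CategoryId { get; set; }

    // Hypothetical property without [Column]: its column name falls back to "Notes".
    public string? Notes { get; set; }
}

// Hypothetical helper that reads the annotations back at runtime.
public static class MappingSketch
{
    // The [Table] name, or the class name if the attribute is missing.
    public static string TableName(Type type) =>
        type.GetCustomAttribute<TableAttribute>()?.Name ?? type.Name;

    // Property name -> column name, preferring [Column] when present.
    public static Dictionary<string, string> ColumnMap(Type type) =>
        type.GetProperties().ToDictionary(
            p => p.Name,
            p => p.GetCustomAttribute<ColumnAttribute>()?.Name ?? p.Name);

    // Name of the first property marked [Key], or null if there is none.
    public static string? KeyProperty(Type type) =>
        type.GetProperties()
            .FirstOrDefault(p => p.IsDefined(typeof(KeyAttribute)))?.Name;
}
```

Because Type.GetProperties() also returns public properties declared on a base class, the [Key] and [Column] attributes on the article's Entity base class are found the same way for Product and Category.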

Step 3: The Generic Repository implementation

In this step, we will implement a GenericRepository<T> that works with any entity type T, and expose it through a Unit of Work that we can easily register in the dependency injection container (in a later step).

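using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations.Schema;
using System.Data;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using Dapper;

public interface IGenericRepository<T>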
{
    Task<T> GetById(int id);
    Task<IEnumerable<T>> GetAll();
    Task<int> CountAll();
    Task<int> Add(T entity);
    Task<int> Update(T entity);
    Task<int> Delete(T entity);
}

public class GenericRepository<T> : IGenericRepository<T> where T : class
{
    private readonly IDbConnection _connection;


    public GenericRepository(ApplicationDbContext context)
    {
        _connection = context.CreateConnection();
    }

    public async Task<T> GetById(int id)
    {
        T result;
        try
        {
            string tableName = GetTableName();
            string? keyColumn = GetKeyColumnName();
            // Bind the id as a parameter instead of formatting it into the SQL text.
            string query      = $"SELECT {GetColumnsAsProperties()} FROM {tableName} WHERE {keyColumn} = @Id";

            result = await _connection.QueryFirstOrDefaultAsync<T>(query, new { Id = id });
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Error fetching a record from db: {ex.Message}");
            throw new Exception("Unable to fetch data. Please contact the administrator.", ex);
        }
        finally
        {
            _connection.Close();
        }

        return result;
    }

    public async Task<IEnumerable<T>> GetAll()
    {
        IEnumerable<T> result;
        try
        {
            string tableName = GetTableName();
            string query = $"SELECT {GetColumnsAsProperties()} FROM {tableName}";

            result = await _connection.QueryAsync<T>(query);
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Error fetching records from db: {ex.Message}");
            throw new Exception("Unable to fetch data. Please contact the administrator.", ex);
        }
        finally
        {
            _connection.Close();
        }
        return result;
    }

    public async Task<int> CountAll()
    {
        int result = -1;
        try
        {
            string tableName = GetTableName();
            string query = $"SELECT COUNT(*) FROM {tableName}"; // May need exact column names

            result = await _connection.QueryFirstOrDefaultAsync<int>(query);
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Error counting records in db: {ex.Message}");
            throw new Exception("Unable to count data. Please contact the administrator.", ex);
        }
        finally
        {
            _connection.Close();
        }

        return result;
    }

    public async Task<int> Add(T entity)
    {
        int rowsAffected = 0;
        try
        {
            string tableName  = GetTableName();
            string columns    = GetColumns(excludeKey: true);
            string properties = GetPropertyNames(excludeKey: true);
            string query      = $"INSERT INTO {tableName} ({columns}) VALUES ({properties})";

            // Dapper binds each @Property placeholder from the matching property of entity.
            rowsAffected = await _connection.ExecuteAsync(query, entity);
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Error adding a record to db: {ex.Message}");
            rowsAffected = -1;
        }
        finally
        {
            _connection.Close();
        }

        return rowsAffected;
    }

    public async Task<int> Update(T entity)
    {
        int rowsAffected = 0;
        try
        {
            string? tableName   = GetTableName();
            string? keyColumn   = GetKeyColumnName();
            string? keyProperty = GetKeyPropertyName();

            StringBuilder query = new StringBuilder();
            query.Append($"UPDATE {tableName} SET ");

            // Build "column = @Property" pairs for every non-key property.
            foreach (var property in GetProperties(excludeKey: true))
            {
                var columnAttribute = property.GetCustomAttribute<ColumnAttribute>();

                string propertyName = property.Name;
                // Fall back to the property name when there is no [Column] attribute.
                string columnName   = columnAttribute?.Name ?? property.Name;

                query.Append($"{columnName} = @{propertyName},");
            }

            // Drop the trailing comma.
            query.Remove(query.Length - 1, 1);

            query.Append($" WHERE {keyColumn} = @{keyProperty}");

            rowsAffected = await _connection.ExecuteAsync(query.ToString(), entity);
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Error updating a record in db: {ex.Message}");
            rowsAffected = -1;
        }
        finally
        {
            _connection.Close();
        }

        return rowsAffected;
    }

    public async Task<int> Delete(T entity)
    {
        int rowsAffected = 0;
        try
        {
            string? tableName   = GetTableName();
            string? keyColumn   = GetKeyColumnName();
            string? keyProperty = GetKeyPropertyName();
            string query        = $"DELETE FROM {tableName} WHERE {keyColumn} = @{keyProperty}";

            rowsAffected = await _connection.ExecuteAsync(query, entity);
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Error deleting a record in db: {ex.Message}");
            rowsAffected = -1;
        }
        finally
        {
            _connection.Close();
        }

        return rowsAffected;
    }

    private string GetTableName()
    {
        var type           = typeof(T);
        var tableAttribute = type.GetCustomAttribute<TableAttribute>();
        if (tableAttribute != null)
            return tableAttribute.Name;

        return type.Name;
    }

    private static string? GetKeyColumnName()
    {
        PropertyInfo[] properties = typeof(T).GetProperties();

        foreach (PropertyInfo property in properties)
        {
            object[] keyAttributes = property.GetCustomAttributes(typeof(KeyAttribute), true);

            if (keyAttributes != null && keyAttributes.Length > 0)
            {
                object[] columnAttributes = property.GetCustomAttributes(typeof(ColumnAttribute), true);

                if (columnAttributes != null && columnAttributes.Length > 0)
                {
                    ColumnAttribute columnAttribute = (ColumnAttribute)columnAttributes[0];
                    return columnAttribute?.Name ?? "";
                }
                else
                {
                    return property.Name;
                }
            }
        }

        return null;
    }


    private string GetColumns(bool excludeKey = false)
    {
        var type = typeof(T);
        var columns = string.Join(", ", type.GetProperties()
            .Where(p => !excludeKey || !p.IsDefined(typeof(KeyAttribute)))
            .Select(p =>
            {
                var columnAttribute = p.GetCustomAttribute<ColumnAttribute>();
                return columnAttribute != null ? columnAttribute.Name : p.Name;
            }));

        return columns;
    }

    private string GetColumnsAsProperties(bool excludeKey = false)
    {
        var type = typeof(T);
        var columnsAsProperties = string.Join(", ", type.GetProperties()
            .Where(p => !excludeKey || !p.IsDefined(typeof(KeyAttribute)))
            .Select(p =>
            {
                var columnAttribute = p.GetCustomAttribute<ColumnAttribute>();
                return columnAttribute != null ? $"{columnAttribute.Name} AS {p.Name}" : p.Name;
            }));

        return columnsAsProperties;
    }

    private string GetPropertyNames(bool excludeKey = false)
    {
        var properties = typeof(T).GetProperties()
            .Where(p => !excludeKey || p.GetCustomAttribute<KeyAttribute>() == null);

        var values = string.Join(", ", properties.Select(p => $"@{p.Name}"));

        return values;
    }

    private IEnumerable<PropertyInfo> GetProperties(bool excludeKey = false)
    {
        var properties = typeof(T).GetProperties()
            .Where(p => !excludeKey || p.GetCustomAttribute<KeyAttribute>() == null);

        return properties;
    }

    private string? GetKeyPropertyName()
    {
        var properties = typeof(T).GetProperties()
            .Where(p => p.GetCustomAttribute<KeyAttribute>() != null).ToList();

        if (properties.Any())
            return properties?.FirstOrDefault()?.Name ?? null;

        return null;
    }
}
public interface IUnit
{
    IGenericRepository<T> GetRepository<T>() where T : class, IEntity;
}

public class Unit : IUnit
{
    private readonly ApplicationDbContext _context;

    public Unit(ApplicationDbContext context)
    {
        _context = context;
    }

    // Each call creates a new repository, and therefore a new connection.
    public IGenericRepository<T> GetRepository<T>() where T : class, IEntity
    {
        return new GenericRepository<T>(_context);
    }
}
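To see what the repository's reflection helpers produce without touching a database, here is a condensed, self-contained sketch that builds the same four kinds of statements. SqlSketch<T> and SampleCategory are hypothetical names, and the sketch assumes exactly one [Key] property per type. It only produces SQL text; every value appears as an @Name placeholder for Dapper to bind.

```csharp
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations.Schema;
using System.Linq;
using System.Reflection;

// Hypothetical sample entity mirroring the article's Category.
[Table("categories")]
public class SampleCategory
{
    [Key, Column("id")] public int Id { get; set; }
    [Column("name")] public string? Name { get; set; }
}

// Hypothetical, condensed version of the repository's SQL-building helpers.
public static class SqlSketch<T>
{
    static readonly PropertyInfo[] Props = typeof(T).GetProperties();
    // Assumes exactly one [Key] property; Single throws otherwise.
    static readonly PropertyInfo Key = Props.Single(p => p.IsDefined(typeof(KeyAttribute)));

    static string Table => typeof(T).GetCustomAttribute<TableAttribute>()?.Name ?? typeof(T).Name;
    static string Col(PropertyInfo p) => p.GetCustomAttribute<ColumnAttribute>()?.Name ?? p.Name;
    static PropertyInfo[] NonKey => Props.Where(p => p != Key).ToArray();

    // Aliases each column to its property name so Dapper can map snake_case columns.
    public static string SelectById()
    {
        string columns = string.Join(", ", Props.Select(p => $"{Col(p)} AS {p.Name}"));
        return $"SELECT {columns} FROM {Table} WHERE {Col(Key)} = @{Key.Name}";
    }

    // The key is left out so the SERIAL column generates it.
    public static string Insert()
    {
        string columns = string.Join(", ", NonKey.Select(Col));
        string values  = string.Join(", ", NonKey.Select(p => "@" + p.Name));
        return $"INSERT INTO {Table} ({columns}) VALUES ({values})";
    }

    public static string Update()
    {
        string assignments = string.Join(", ", NonKey.Select(p => $"{Col(p)} = @{p.Name}"));
        return $"UPDATE {Table} SET {assignments} WHERE {Col(Key)} = @{Key.Name}";
    }

    public static string Delete() => $"DELETE FROM {Table} WHERE {Col(Key)} = @{Key.Name}";
}
```

For SampleCategory, Update() yields UPDATE categories SET name = @Name WHERE id = @Id. With Dapper, such a string would be executed as, for example, connection.ExecuteAsync(sql, category), which binds @Name and @Id from the object's properties. Only table and column names are interpolated, and those come from attributes rather than user input.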

Step 4: The Service (Application) layer

In this layer, we simply use the repository to access the database and perform CRUD operations for both Products and Categories (only the Product code is shown below, for brevity).

using AutoMapper;

public interface IProductService
{
    Task<int> Create(ProductDto productDto);
    Task<int> Update(int id, ProductDto productDto);
    Task<IEnumerable<Product>> GetAll();
    Task<int> CountAll();
    Task<Product> GetById(int id);
    Task<bool> Delete(int id);
}

public class ProductService : IProductService
{
    private readonly IUnit _unit;
    private readonly IGenericRepository<Product> _repository;
    private readonly IMapper _mapper;

    public ProductService(IUnit unit, IMapper mapper)
    {
        _unit       = unit;
        _repository = _unit.GetRepository<Product>();
        _mapper     = mapper;
    }

    public async Task<int> Create(ProductDto productDto)
    {
        var product       = _mapper.Map<Product>(productDto);
        product.CreatedAt = DateTime.UtcNow;
        // Set UpdatedAt as well; otherwise default(DateTime) is inserted instead of the column DEFAULT.
        product.UpdatedAt = product.CreatedAt;
        int result        = await _repository.Add(product);
        return result;
    }

    public async Task<int> Update(int id, ProductDto productDto)
    {
        var product         = await GetById(id);
        product.Name        = productDto.Name;
        product.Description = productDto.Description;
        product.CategoryId  = productDto.CategoryId;
        product.UpdatedAt   = DateTime.UtcNow;
        int productsUpdated = await _repository.Update(product);
        return productsUpdated;
    }

    public async Task<IEnumerable<Product>> GetAll()
    {
        var products = await _repository.GetAll();
        return products;
    }

    public async Task<int> CountAll()
    {
        int count = await _repository.CountAll();
        return count;
    }

    public async Task<Product> GetById(int id)
    {
        var product = await _repository.GetById(id);
        if (product == null)
            throw new Exception("Product record does not exist.");
        return product;
    }

    public async Task<bool> Delete(int id)
    {
        var product  = await GetById(id);
        int result   = await _repository.Delete(product); // Could also be done with "isDeleted = true;"
        return (result > 0);
    }
}
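A common pitfall when stamping CreatedAt and UpdatedAt is writing DateTime.SpecifyKind(DateTime.Now, DateTimeKind.Utc). SpecifyKind only changes the Kind flag and does not convert the value, so on a server that is not running in UTC the stored timestamp is off by the local UTC offset. The sketch below contrasts relabeling with an actual conversion; TimestampSketch is a hypothetical name.

```csharp
using System;

// Hypothetical helpers contrasting relabeling with converting.
public static class TimestampSketch
{
    // Changes only the Kind flag; the wall-clock value (Ticks) stays the same.
    public static DateTime RelabelAsUtc(DateTime local) =>
        DateTime.SpecifyKind(local, DateTimeKind.Utc);

    // Converts a local time to the equivalent UTC instant.
    public static DateTime ConvertToUtc(DateTime local) =>
        local.ToUniversalTime();
}
```

The two results only coincide on a machine whose local time zone is UTC. In practice, DateTime.UtcNow gives the current UTC time directly. How Npgsql then writes a UTC DateTime into a timestamp (without time zone) column depends on the Npgsql version, so its date/time mapping documentation is worth checking, or timestamptz columns can be used instead.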

Step 5: The Web API

We will create a Web API with two controllers and a Swagger UI for testing.

Program.cs

using Application.Logic.CategoryService;
using Application.Logic.ProductService;
using Domain.Database;
using GenericRepo_Dapper.Configuration;
using Infrastructure.UnitOfWork;
using Microsoft.OpenApi.Models;

var builder = WebApplication.CreateBuilder(args);

string corsName = "CorsName";
builder.Services.AddCors(options =>
{
    options.AddPolicy(corsName, policyBuilder => policyBuilder
        .WithOrigins("http://localhost", "https://localhost")
        .AllowAnyMethod()
        .AllowAnyHeader());
});

builder.Services.AddSingleton<ApplicationDbContext>();

builder.Services.AddRouting(context => context.LowercaseUrls = true);
builder.Services.AddControllers().AddJsonOptions(options =>
{
    options.JsonSerializerOptions.ReferenceHandler = System.Text.Json.Serialization.ReferenceHandler.IgnoreCycles;
    options.JsonSerializerOptions.WriteIndented    = true;
});

builder.Services.AddAutoMapper(AppDomain.CurrentDomain.GetAssemblies());

// Service
builder.Services.AddScoped<ICategoryService, CategoryService>();
builder.Services.AddScoped<IProductService, ProductService>();

// Repository
builder.Services.AddScoped<IUnit, Unit>();

// Swagger
builder.Services.AddEndpointsApiExplorer();
builder.Services.AddSwaggerGen(option =>
{
    option.SwaggerDoc("v1", info: new OpenApiInfo { Title = "Generic Repository and Dapper API", Version = "v1" });
    option.OperationFilter<HeaderFilter>();
});

var app = builder.Build();

// Register the exception handler first so it wraps the rest of the pipeline.
app.UseExceptionHandler("/Error");

var swaggerConfig = new SwaggerConfig();
app.Configuration.GetSection(nameof(SwaggerConfig)).Bind(swaggerConfig);
app.UseSwagger(option => { option.RouteTemplate = swaggerConfig.JsonRoute; });
app.UseSwaggerUI(option => { option.SwaggerEndpoint(swaggerConfig.UIEndpoint, swaggerConfig.Description); });

app.UseCors(corsName);

app.MapControllers();


app.Run();


Result

[Screenshots: Swagger UI; Create Category; Get Categories; Create Product; Get Products]

Room for improvement

  • Bulk operations (insert, update, delete).
  • Pagination.
  • Filters for the Get functions.
  • Returning a view model instead of the object itself.
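Pagination, the second item above, can be layered onto the same query-building approach. A minimal sketch, assuming PostgreSQL's LIMIT/OFFSET syntax and 1-based page numbers; PagingSketch and its methods are hypothetical names, not part of the repository.

```csharp
using System;

// Hypothetical helpers for offset-based pagination in PostgreSQL.
public static class PagingSketch
{
    // Rows to skip for a 1-based page number.
    public static int Offset(int page, int pageSize)
    {
        if (page < 1) throw new ArgumentOutOfRangeException(nameof(page));
        if (pageSize < 1) throw new ArgumentOutOfRangeException(nameof(pageSize));
        return (page - 1) * pageSize;
    }

    // Pages needed for totalCount rows (e.g. the value returned by CountAll()).
    public static int TotalPages(int totalCount, int pageSize)
    {
        if (pageSize < 1) throw new ArgumentOutOfRangeException(nameof(pageSize));
        return (totalCount + pageSize - 1) / pageSize;
    }

    // LIMIT/OFFSET query; @Limit and @Offset are meant to be bound as parameters.
    // A deterministic ORDER BY keeps rows from shifting between pages.
    public static string PageQuery(string selectList, string table, string keyColumn) =>
        $"SELECT {selectList} FROM {table} ORDER BY {keyColumn} LIMIT @Limit OFFSET @Offset";
}
```

With Dapper, the placeholders could be bound as new { Limit = pageSize, Offset = PagingSketch.Offset(page, pageSize) }. For tables as large as the 500,000-row case from the introduction, keyset pagination (WHERE id > @LastId ORDER BY id LIMIT @Limit) usually scales better than large offsets, because PostgreSQL still has to read past every skipped row.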

End Note

Feel free to leave questions, inquiries, or notes in the comments section below. Happy developing!
