
Orololuwa


Setting up a Database Driver, Repository and Implementation of a transaction function for your Go App

Overview

Databases are the foundation of most apps because they store the data the rest of the architecture depends on. Setting up the database layer well helps the app perform well and makes tests easier to write. In this article, I will take you step by step through setting up a database connection and a repository, and then implementing a transaction function.

Prerequisite

  • Basic understanding of the Go programming language

Getting Started

Backend developers often opt for an ORM library because it provides an abstraction between the app and the database, so there is little or no need to write raw queries and migrations, which is nice. However, if you want to get better at writing queries (SQL, for example), you need to learn how to build your repositories without an ORM. To open a database handle, you can either use the database driver directly or go through database/sql with the driver registered to it. I will be opening the connection with database/sql together with pgx, a driver and toolkit for PostgreSQL. Walk with me.

First things first, we start by installing pgx with the command below.

go get github.com/jackc/pgx/v5

In a driver.go file, put the following code:

package driver

import (
    "database/sql"

    _ "github.com/jackc/pgx/v5/stdlib"
)

type DB struct {
    SQL *sql.DB
}

var dbConn = &DB{}

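func ConnectSQL(dsn string) (*DB, error) {
    // sql.Open prepares the handle; it may not connect to the database yet.
    db, err := sql.Open("pgx", dsn)
    if err != nil {
        return nil, err
    }

    // Ping forces a real connection, so an unreachable database fails here.
    if err := db.Ping(); err != nil {
        db.Close()
        return nil, err
    }

    dbConn.SQL = db

    return dbConn, nil
}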

In the above snippet, we import database/sql and do a blank import of github.com/jackc/pgx/v5/stdlib. The blank import runs the package's init function, which registers the driver with database/sql under the name "pgx"; that is the name we pass to sql.Open. sql.Open returns a pointer to sql.DB, so we first create a struct DB with a field SQL that can hold it and store an instance of that struct in dbConn. The point of having a struct is that it can hold a variety of connections: for example, if you were to open another connection to a MongoDB database, you could store it in a field called Mongo in the DB struct. In the ConnectSQL function, we open the handle, verify that the database is reachable by pinging it, and assign the handle to dbConn.
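To make the registration step concrete, here is a minimal sketch that registers a do-nothing driver the same way a real driver package does. The stubDriver type and the "stub" name are made up for illustration; they are not part of pgx or database/sql.

```go
package main

import (
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
)

// stubDriver is a hypothetical, do-nothing driver used only to show
// how registration works.
type stubDriver struct{}

// Open satisfies driver.Driver. A real driver would dial the database here.
func (stubDriver) Open(name string) (driver.Conn, error) {
	return nil, errors.New("stub driver: connections are not implemented")
}

// init runs when the package is loaded, which is why a blank import of a
// driver package is enough to make its name usable in sql.Open.
func init() {
	sql.Register("stub", stubDriver{})
}

// isRegistered reports whether a driver name is known to database/sql.
func isRegistered(name string) bool {
	for _, d := range sql.Drivers() {
		if d == name {
			return true
		}
	}
	return false
}

// openErr returns the error, if any, from sql.Open for the given driver name.
func openErr(name string) error {
	db, err := sql.Open(name, "")
	if err != nil {
		return err
	}
	return db.Close()
}

func main() {
	fmt.Println("stub registered:", isRegistered("stub"))
	fmt.Println("open stub:", openErr("stub"))
	fmt.Println("open unregistered:", openErr("no-such-driver"))
}
```

Opening a name that was never registered fails immediately, which is the error you would see if the blank import were left out.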

Then, in our main.go file, we call the ConnectSQL function like this:

package main

import (
    "fmt"
    "log"
    "net/http"

    "github.com/orololuwa/crispy-octo-guacamole/driver"
)
const portNumber = ":8080"

func main(){
    db, err := run()
    if (err != nil){
        log.Fatal(err)
    }
    defer db.SQL.Close()

    fmt.Printf("Starting application on port %s\n", portNumber)


    srv := &http.Server{
        Addr: portNumber,
        Handler: nil,
    }

    err = srv.ListenAndServe()
    if err != nil {
        log.Fatal(err)
    }
}

func run()(*driver.DB, error){
    dbHost := "localhost"
    dbPort := "5432"
    dbName := "your_db_name"
    dbUser := "your_user"
    dbPassword := ""
    dbSSL := "disable"

    // Connect to the database
    log.Println("Connecting to database")
    connectionString := fmt.Sprintf("host=%s port=%s dbname=%s user=%s password=%s sslmode=%s", dbHost, dbPort, dbName, dbUser, dbPassword, dbSSL)

    db, err := driver.ConnectSQL(connectionString)
    if err != nil {
        log.Fatal("Cannot connect to database: ", err)
    }
    log.Println("Connected to database")

    return db, nil
}

In the main.go file above, the run function calls ConnectSQL to open a connection to the database and returns it to the main function, which defers db.SQL.Close() so the connection is closed when main finishes executing.
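Hard-coded settings are fine for a walkthrough, but you may prefer to read them from the environment. Below is a minimal sketch of that idea. The DB_* variable names and the buildDSN helper are assumptions made for this example, and empty values are skipped so the resulting string never contains a bare key such as password= with nothing after it.

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

// setting maps a connection-string key to a hypothetical environment
// variable and the value to use when that variable is unset.
type setting struct{ key, env, fallback string }

// buildDSN assembles a key=value connection string in the same format as
// the one in run(). getenv is injected so the function is easy to test.
// Values containing spaces would need quoting, which this sketch omits.
func buildDSN(getenv func(string) string) string {
	settings := []setting{
		{"host", "DB_HOST", "localhost"},
		{"port", "DB_PORT", "5432"},
		{"dbname", "DB_NAME", ""},
		{"user", "DB_USER", ""},
		{"password", "DB_PASSWORD", ""},
		{"sslmode", "DB_SSLMODE", "disable"},
	}

	var parts []string
	for _, s := range settings {
		v := getenv(s.env)
		if v == "" {
			v = s.fallback
		}
		if v == "" {
			continue // leave out settings that have no value at all
		}
		parts = append(parts, s.key+"="+v)
	}
	return strings.Join(parts, " ")
}

func main() {
	// os.Getenv has the right signature to be passed in directly.
	fmt.Println(buildDSN(os.Getenv))
}
```

The returned string can be passed to driver.ConnectSQL in place of the connectionString built with fmt.Sprintf.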

Next, we create two files, repository.go and users.go, in a folder (package) named repository. Before that, we have to define the user model, so in a models.go file we define it with the following fields:

package models

import "time"

type User struct {
    ID int
    FirstName string
    LastName  string
    Email     string
    Password string
    CreatedAt time.Time
    UpdatedAt time.Time
}

I'm guessing you know how to create migrations to set up your database tables. If not, you can check out an amazing package I use called fizz.

Next, in the repository.go file, we put the following code:

package repository

import (
    "github.com/orololuwa/crispy-octo-guacamole/models"
)

type UserRepo interface {
    CreateAUser(user models.User) (int, error)
    GetAUser(id int) (models.User, error)
    GetAllUser() ([]models.User, error)
    UpdateAUsersName(id int, firstName, lastName string)(error)
    DeleteUserByID(id int) error
}

In the snippet above, the UserRepo is an interface that defines all the functions that a user repo struct should implement.
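One payoff of coding against an interface is that tests can swap in an implementation that never touches a database. The sketch below assumes a trimmed copy of UserRepo and a local User type so that it is self-contained; memoryRepo is a hypothetical name and is not safe for concurrent use.

```go
package main

import (
	"errors"
	"fmt"
)

// User mirrors a subset of models.User so the sketch is self-contained.
type User struct {
	ID        int
	FirstName string
	LastName  string
	Email     string
}

// UserRepo is a trimmed copy of the interface from repository.go.
type UserRepo interface {
	CreateAUser(user User) (int, error)
	GetAUser(id int) (User, error)
	DeleteUserByID(id int) error
}

var errNotFound = errors.New("user not found")

// memoryRepo is a hypothetical in-memory implementation for tests.
type memoryRepo struct {
	nextID int
	users  map[int]User
}

// newMemoryRepo returns the interface type, so the compiler verifies that
// memoryRepo implements every method of UserRepo.
func newMemoryRepo() UserRepo {
	return &memoryRepo{nextID: 1, users: make(map[int]User)}
}

func (m *memoryRepo) CreateAUser(user User) (int, error) {
	user.ID = m.nextID
	m.users[user.ID] = user
	m.nextID++
	return user.ID, nil
}

func (m *memoryRepo) GetAUser(id int) (User, error) {
	u, ok := m.users[id]
	if !ok {
		return User{}, errNotFound
	}
	return u, nil
}

func (m *memoryRepo) DeleteUserByID(id int) error {
	if _, ok := m.users[id]; !ok {
		return errNotFound
	}
	delete(m.users, id)
	return nil
}

func main() {
	repo := newMemoryRepo()
	id, _ := repo.CreateAUser(User{FirstName: "Ada", Email: "ada@example.com"})
	u, err := repo.GetAUser(id)
	fmt.Println(id, u.FirstName, err)
}
```

Any code that accepts a UserRepo can be handed this type in a test and the PostgreSQL-backed one in production.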

Next, in the users.go file, we create a user repo struct, a New-like function to initialize it, and methods on the struct that satisfy the UserRepo interface:

package repository

import (
    "context"
    "database/sql"
    "time"

    "github.com/orololuwa/crispy-octo-guacamole/models"
)

type user struct {
    DB *sql.DB
}

func NewUserRepo(conn *sql.DB) UserRepo {
    return &user{
        DB: conn,
    }
}

func (m *user) CreateAUser(user models.User) (int, error){
    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
    defer cancel()

    var newId int

    query := `
            INSERT into users 
                (first_name, last_name, email, password, created_at, updated_at)
            values 
                ($1, $2, $3, $4, $5, $6)
            returning id`

    err := m.DB.QueryRowContext(ctx, query, 
        user.FirstName, 
        user.LastName, 
        user.Email, 
        user.Password,
        time.Now(),
        time.Now(),
    ).Scan(&newId)

    if err != nil {
        return 0, err
    }

    return newId, nil
}

func (m *user) GetAUser(id int) (models.User, error){
    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
    defer cancel()

    var user models.User

    query := `
            SELECT id, first_name, last_name, email, password, created_at, updated_at
            from users
            WHERE
            id=$1
    `

    err := m.DB.QueryRowContext(ctx, query, id).Scan(
        &user.ID,
        &user.FirstName,
        &user.LastName,
        &user.Email,
        &user.Password,
        &user.CreatedAt,
        &user.UpdatedAt,
    )

    if err != nil {
        return user, err
    }

    return user, nil
}

func (m *user) GetAllUser() ([]models.User, error){
    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
    defer cancel()

    var users = make([]models.User, 0)

    query := `
        SELECT id, first_name, last_name, email, password, created_at, updated_at
        from users
    `

    rows, err := m.DB.QueryContext(ctx, query)
    if err != nil {
        return users, err
    }
    // Release the underlying connection on every return path.
    defer rows.Close()

    for rows.Next(){
        var user models.User
        err := rows.Scan(
            &user.ID,
            &user.FirstName,
            &user.LastName,
            &user.Email,
            &user.Password,
            &user.CreatedAt,
            &user.UpdatedAt,
        )
        if err != nil {
            return users, err
        }
        users = append(users, user)
    }

    if err = rows.Err(); err != nil {
        return users, err
    }

    return users, nil
}

func (m *user) UpdateAUsersName(id int, firstName, lastName string)(error){
    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
    defer cancel()

    query := `
        UPDATE 
            users set (first_name, last_name) = ($1, $2)
        WHERE
            id = $3
    `

    _, err := m.DB.ExecContext(ctx, query, firstName, lastName, id)
    if err != nil{
        return  err
    }

    return nil
}

func (m *user) DeleteUserByID(id int) error {
    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
    defer cancel()

    query := "DELETE FROM users WHERE id = $1"

    _, err := m.DB.ExecContext(ctx, query, id)
    if err != nil {
        return err
    }

    return nil
}

All the repositories follow the same initialization pattern: a New-like function takes a database connection and stores it in a field of an unexported struct. Some refer to this convention as the Repository Pattern. The function NewUserRepo takes in a connection and initializes a new repo from the user struct. Because its return type is the UserRepo interface, the compiler checks that the struct implements every function in that interface.
Without going too deep, the functions create, read, update, and delete a user in the database. You'd notice that the values are passed into the query through placeholders such as $1 and $2 rather than being concatenated into the SQL string; this is good practice because the values are treated as data rather than as part of the SQL text, which prevents SQL injection.
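To see why placeholders matter, compare what happens to the SQL text in each approach. This is a small illustration with made-up helper names (unsafeQuery, safeQuery); it only builds strings and does not talk to a database.

```go
package main

import "fmt"

// unsafeQuery builds the SQL text by concatenation. Do not do this:
// the input becomes part of the statement itself.
func unsafeQuery(email string) string {
	return "SELECT id FROM users WHERE email = '" + email + "'"
}

// safeQuery keeps the SQL text constant and returns the input as a separate
// argument, the shape that QueryRowContext(ctx, query, args...) expects.
func safeQuery(email string) (string, []any) {
	return "SELECT id FROM users WHERE email = $1", []any{email}
}

func main() {
	input := "x' OR '1'='1"

	// The quote in the input closes the string literal early and the rest
	// is read as SQL, so the WHERE clause now matches every row.
	fmt.Println(unsafeQuery(input))

	// The statement is unchanged no matter what the input contains.
	q, args := safeQuery(input)
	fmt.Println(q, args)
}
```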

Next, in the main.go file, I will create a handler that calls one of the UserRepo functions to test it. The handler is registered on a chi router, which you can install with go get github.com/go-chi/chi/v5.

package main

import (
+   "encoding/json"
    "fmt"
    "log"
    "net/http"

+   "github.com/go-chi/chi/v5"
    "github.com/orololuwa/crispy-octo-guacamole/driver"
+   "github.com/orololuwa/crispy-octo-guacamole/models"
+   "github.com/orololuwa/crispy-octo-guacamole/repository"
)
const portNumber = ":8080"

func main(){
-   db, err := run()
+   db, route, err := run()
    if (err != nil){
        log.Fatal(err)
    }
    defer db.SQL.Close()

    fmt.Printf("Starting application on port %s\n", portNumber)


    srv := &http.Server{
        Addr: portNumber,
-       Handler: nil,
+       Handler: route,
    }

    err = srv.ListenAndServe()
    if err != nil {
        log.Fatal(err)
    }
}

-func run()(*driver.DB, error){
+func run()(*driver.DB, *chi.Mux, error){
    dbHost := "localhost"
    dbPort := "5432"
    dbName := "bookings"
    dbUser := "orololuwa"
    dbPassword := ""
    dbSSL := "disable"

    // Connect to the database
    log.Println("Connecting to database")
    connectionString := fmt.Sprintf("host=%s port=%s dbname=%s user=%s password=%s sslmode=%s", dbHost, dbPort, dbName, dbUser, dbPassword, dbSSL)

    db, err := driver.ConnectSQL(connectionString)
    if err != nil {
        log.Fatal("Cannot connect to database: ", err)
    }
    log.Println("Connected to database")
+   userRepo := repository.NewUserRepo(db.SQL)
+   router := chi.NewRouter()
+
+   router.Post("/user", func(w http.ResponseWriter, r *http.Request) {
+       type userBody struct {
+           FirstName string `json:"firstName"`
+           LastName string `json:"lastName"`
+           Email string `json:"email"`
+           Password string `json:"password"`
+       }
+
+       var body userBody
+       
+       err := json.NewDecoder(r.Body).Decode(&body)
+       if err != nil {
+           // A body that cannot be decoded is the client's mistake.
+           w.Header().Set("Content-Type", "application/json")
+           w.WriteHeader(http.StatusBadRequest)
+           return
+       }
+
+       user := models.User{
+           FirstName: body.FirstName,
+           LastName: body.LastName,
+           Email: body.Email,
+           Password: body.Password,
+       }
+       
+       id, err := userRepo.CreateAUser(user)
+       if err != nil {
+           w.Header().Set("Content-Type", "application/json")
+           w.WriteHeader(http.StatusInternalServerError)
+           return
+       }
+
+       response := map[string]interface{}{"message": "user created successfully", "data": id}
+       jsonResponse, err := json.Marshal(response)
+       if err != nil {
+           http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
+           return
+       }
+       w.Header().Set("Content-Type", "application/json")
+       w.WriteHeader(http.StatusOK)
+       w.Write(jsonResponse)
+   })
+
-   return db, nil
+   return db, router, nil
}

In the code above, we create a new router (the chi router is used here), define a POST handler, and assign it to the /user route. In the handler, we decode the data from the request's body and pass it into the CreateAUser function, which is available to us through userRepo. You can run your application and test it with Postman.
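If you would rather check the handler from code than from Postman, the standard library's httptest package can call it in memory. The sketch below is an assumption-laden miniature of the handler above: userCreator, fixedID, and the local User type are hypothetical stand-ins so the example runs without chi or a database.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// User stands in for models.User so the sketch is self-contained.
type User struct {
	FirstName string `json:"firstName"`
	LastName  string `json:"lastName"`
	Email     string `json:"email"`
	Password  string `json:"password"`
}

// userCreator is a hypothetical one-method slice of UserRepo.
type userCreator interface {
	CreateAUser(user User) (int, error)
}

// fixedID is a stand-in repository that always returns the same ID.
type fixedID int

func (f fixedID) CreateAUser(User) (int, error) { return int(f), nil }

// createUserHandler decodes the JSON body and passes it to the repository.
func createUserHandler(repo userCreator) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var body User
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			http.Error(w, "invalid JSON body", http.StatusBadRequest)
			return
		}
		id, err := repo.CreateAUser(body)
		if err != nil {
			http.Error(w, "could not create user", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]any{
			"message": "user created successfully",
			"data":    id,
		})
	}
}

// post sends body to the handler in memory and returns the status code
// and the response body.
func post(body string) (int, string) {
	req := httptest.NewRequest(http.MethodPost, "/user", strings.NewReader(body))
	rec := httptest.NewRecorder()
	createUserHandler(fixedID(7)).ServeHTTP(rec, req)
	return rec.Code, rec.Body.String()
}

func main() {
	fmt.Println(post(`{"firstName":"Ada","email":"ada@example.com"}`))
	fmt.Println(post(`not json`))
}
```

Because the handler depends only on an interface, the same function can be given the real repository when it is registered on the router.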

Image: a Postman screenshot showing a successful POST request to the /user route

Part 2: Implementing a transaction function

There are some queries you want to either completely succeed or completely fail. For example, say you have an e-commerce system with a wallet feature: when a user buys a product, you want to debit the wallet and reduce the product count in the inventory. These two actions in the database have to either both succeed or both fail.
Transactions are crucial to maintaining data consistency. A transaction has a state, and every query executed within it is tied to that state; afterwards you can commit the results of those queries or roll them all back if any one of them fails.
To start, we create a new file under the repository package. You can name it whatever you want; I will name mine db-repo.go. In the file, we use the repository pattern to initialize a new instance, but before then we need an interface that declares the transaction function to be implemented. So, we update the repository.go file with this:

package repository

import (
+   "context"
+   "database/sql"
+
    "github.com/orololuwa/crispy-octo-guacamole/models"
)

type UserRepo interface {
    CreateAUser(user models.User) (int, error)
    GetAUser(id int) (models.User, error)
    GetAllUser() ([]models.User, error)
    UpdateAUsersName(id int, firstName, lastName string)(error)
    DeleteUserByID(id int) error
}

+type DBRepo interface {
+   Transaction(ctx context.Context, operation func(context.Context, *sql.Tx) error) error 
+}
Enter fullscreen mode Exit fullscreen mode

Then, in the db-repo.go file, we go ahead and put the following code.

package repository

import (
    "context"
    "database/sql"
)

type dbRepo struct {
    DB *sql.DB
}

func NewDBRepo(conn *sql.DB) DBRepo {
    return &dbRepo{
        DB: conn,
    }
}

func (m *dbRepo) Transaction(ctx context.Context, operation func(context.Context, *sql.Tx) error) error {
    tx, err := m.DB.BeginTx(ctx, nil)
    if err != nil {
        return err
    }

    // Run the caller's queries; undo all of them if any one fails.
    if err := operation(ctx, tx); err != nil {
        // The operation's error is the one the caller needs, so the
        // rollback error is deliberately ignored.
        _ = tx.Rollback()
        return err
    }

    // Everything succeeded: make the changes permanent and report
    // any commit failure to the caller.
    return tx.Commit()
}

In the code above, we declare a dbRepo struct and a function to initialize it, and we implement the DBRepo interface. The Transaction function takes in a context and an operation: a function that receives the context and the transaction started inside Transaction, and uses them to execute queries.
The BeginTx function from *sql.DB takes two arguments: the context and a *sql.TxOptions value where you can specify the isolation level of the transaction; passing nil uses the driver's defaults. Once the operation has run, the transaction is rolled back if it returned an error, and the changes are committed otherwise.
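The commit-or-rollback behaviour can be exercised without a real database. The sketch below plugs a tiny fake driver into database/sql that only counts how transactions end; every type named fake* and the withTx helper are test doubles written for this example, not part of pgx or of the article's repository.

```go
package main

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
)

// counts records how transactions ended.
type counts struct{ commits, rollbacks int }

type fakeConnector struct{ c *counts }
type fakeConn struct{ c *counts }
type fakeTx struct{ c *counts }

func (f fakeConnector) Connect(context.Context) (driver.Conn, error) { return &fakeConn{f.c}, nil }
func (f fakeConnector) Driver() driver.Driver                        { return nil } // unused in this sketch

func (c *fakeConn) Prepare(string) (driver.Stmt, error) {
	return nil, errors.New("fake: queries are not supported")
}
func (c *fakeConn) Close() error              { return nil }
func (c *fakeConn) Begin() (driver.Tx, error) { return &fakeTx{c.c}, nil }

func (t *fakeTx) Commit() error   { t.c.commits++; return nil }
func (t *fakeTx) Rollback() error { t.c.rollbacks++; return nil }

// withTx runs op inside a transaction: rollback on error, commit otherwise.
func withTx(ctx context.Context, db *sql.DB, op func(context.Context, *sql.Tx) error) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	if err := op(ctx, tx); err != nil {
		_ = tx.Rollback()
		return err
	}
	return tx.Commit()
}

// runDemo runs one successful and one failing operation and reports
// how many commits and rollbacks the fake driver saw.
func runDemo() (commits, rollbacks int, failed error) {
	c := &counts{}
	db := sql.OpenDB(fakeConnector{c})
	defer db.Close()
	ctx := context.Background()

	// The first operation succeeds, so its transaction is committed.
	if err := withTx(ctx, db, func(context.Context, *sql.Tx) error { return nil }); err != nil {
		return 0, 0, err
	}
	// The second operation fails, so its transaction is rolled back.
	failed = withTx(ctx, db, func(context.Context, *sql.Tx) error {
		return errors.New("insufficient balance")
	})
	return c.commits, c.rollbacks, failed
}

func main() {
	commits, rollbacks, err := runDemo()
	fmt.Println("commits:", commits, "rollbacks:", rollbacks, "error:", err)
}
```

The important detail is that the error returned by the operation is the value that decides between Rollback and Commit, and that a failed Commit is returned to the caller rather than dropped.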
Now we would have to refactor the UserRepo functions to accept the context and the transaction reference.

package repository

import (
    "context"
    "database/sql"

    "github.com/orololuwa/crispy-octo-guacamole/models"
)

type UserRepo interface {
-   CreateAUser(user models.User) (int, error)
-   GetAUser(id int) (models.User, error)
-   GetAllUser() ([]models.User, error)
-   UpdateAUsersName(id int, firstName, lastName string)(error)
-   DeleteUserByID(id int) error
+   CreateAUser(ctx context.Context, tx *sql.Tx, user models.User) (int, error)
+   GetAUser(ctx context.Context, tx *sql.Tx, id int) (models.User, error)
+   GetAllUser(ctx context.Context, tx *sql.Tx) ([]models.User, error)
+   UpdateAUsersName(ctx context.Context, tx *sql.Tx, id int, firstName, lastName string)(error)
+   DeleteUserByID(ctx context.Context, tx *sql.Tx, id int) error
}

type DBRepo interface {
    Transaction(ctx context.Context, operation func(context.Context, *sql.Tx) error) error 
}

Using the CreateAUser function to highlight the difference in the users.go file, we have:

-func (m *user) CreateAUser(user models.User) (int, error){
-   ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+func (m *user) CreateAUser(ctx context.Context, tx *sql.Tx, user models.User) (int, error){
+   ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
    defer cancel()

    var newId int

    query := `
            INSERT into users 
                (first_name, last_name, email, password, created_at, updated_at)
            values 
                ($1, $2, $3, $4, $5, $6)
            returning id`

-   err := m.DB.QueryRowContext(ctx, query, 
+   var err error
+   if tx != nil {
+       err = tx.QueryRowContext(ctx, query, 
+           user.FirstName, 
+           user.LastName, 
+           user.Email, 
+           user.Password,
+           time.Now(),
+           time.Now(),
+       ).Scan(&newId)
+   }else{
        err = m.DB.QueryRowContext(ctx, query, 
            user.FirstName, 
            user.LastName, 
            user.Email, 
            user.Password,
            time.Now(),
            time.Now(),
        ).Scan(&newId)
+   }
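Switching from context.Background() to the ctx parameter in the snippet above means the 3-second timeout is now derived from the caller's context. A derived context can only shorten a deadline, never extend it, as this small sketch with made-up helper names shows.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// deadlines derives a child context from a parent, the way the repository
// methods do, and returns both deadlines.
func deadlines(parentTimeout, childTimeout time.Duration) (parent, child time.Time) {
	pctx, cancelParent := context.WithTimeout(context.Background(), parentTimeout)
	defer cancelParent()

	cctx, cancelChild := context.WithTimeout(pctx, childTimeout)
	defer cancelChild()

	parent, _ = pctx.Deadline()
	child, _ = cctx.Deadline()
	return parent, child
}

// inheritsEarlierDeadline: a 1s parent with a 3s child keeps the 1s deadline.
func inheritsEarlierDeadline() bool {
	p, c := deadlines(1*time.Second, 3*time.Second)
	return c.Equal(p)
}

// shortensLongerDeadline: a 10s parent with a 3s child expires after 3s.
func shortensLongerDeadline() bool {
	p, c := deadlines(10*time.Second, 3*time.Second)
	return c.Before(p)
}

func main() {
	fmt.Println("child keeps an earlier parent deadline:", inheritsEarlierDeadline())
	fmt.Println("child shortens a longer parent deadline:", shortensLongerDeadline())
}
```

In practice this means that if the caller cancels its context, for example because the transaction was abandoned, the query is cancelled as well.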

We now derive the timeout from the context passed into the function and, depending on whether a transaction is passed, use either the tx variable or the plain connection to execute our queries. The full updated code for all the UserRepo functions:

package repository

import (
    "context"
    "database/sql"
    "time"

    "github.com/orololuwa/crispy-octo-guacamole/models"
)

type user struct {
    DB *sql.DB
}

func NewUserRepo(conn *sql.DB) UserRepo {
    return &user{
        DB: conn,
    }
}

func (m *user) CreateAUser(ctx context.Context, tx *sql.Tx, user models.User) (int, error){
    ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
    defer cancel()

    var newId int

    query := `
            INSERT into users 
                (first_name, last_name, email, password, created_at, updated_at)
            values 
                ($1, $2, $3, $4, $5, $6)
            returning id`

    var err error
    if tx != nil {
        err = tx.QueryRowContext(ctx, query, 
            user.FirstName, 
            user.LastName, 
            user.Email, 
            user.Password,
            time.Now(),
            time.Now(),
        ).Scan(&newId)
    }else{
        err = m.DB.QueryRowContext(ctx, query, 
            user.FirstName, 
            user.LastName, 
            user.Email, 
            user.Password,
            time.Now(),
            time.Now(),
        ).Scan(&newId)
    }

    if err != nil {
        return 0, err
    }

    return newId, nil
}

func (m *user) GetAUser(ctx context.Context, tx *sql.Tx, id int) (models.User, error){
    ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
    defer cancel()

    var user models.User

    query := `
            SELECT id, first_name, last_name, email, password, created_at, updated_at
            from users
            WHERE
            id=$1
    `

    var err error
    if tx != nil {
        err = tx.QueryRowContext(ctx, query, id).Scan(
            &user.ID,
            &user.FirstName,
            &user.LastName,
            &user.Email,
            &user.Password,
            &user.CreatedAt,
            &user.UpdatedAt,
        )
    }else{
        err = m.DB.QueryRowContext(ctx, query, id).Scan(
            &user.ID,
            &user.FirstName,
            &user.LastName,
            &user.Email,
            &user.Password,
            &user.CreatedAt,
            &user.UpdatedAt,
        )
    }

    if err != nil {
        return user, err
    }

    return user, nil
}

func (m *user) GetAllUser(ctx context.Context, tx *sql.Tx) ([]models.User, error){
    ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
    defer cancel()

    var users = make([]models.User, 0)

    query := `
        SELECT id, first_name, last_name, email, password, created_at, updated_at
        from users
    `

    var rows *sql.Rows
    var err error

    if tx != nil {
        rows, err = tx.QueryContext(ctx, query)
    }else{
        rows, err = m.DB.QueryContext(ctx, query)
    }
    if err != nil {
        return users, err
    }
    // Release the underlying connection on every return path.
    defer rows.Close()

    for rows.Next(){
        var user models.User
        err := rows.Scan(
            &user.ID,
            &user.FirstName,
            &user.LastName,
            &user.Email,
            &user.Password,
            &user.CreatedAt,
            &user.UpdatedAt,
        )
        if err != nil {
            return users, err
        }
        users = append(users, user)
    }

    if err = rows.Err(); err != nil {
        return users, err
    }

    return users, nil
}

func (m *user) UpdateAUsersName(ctx context.Context, tx *sql.Tx, id int, firstName, lastName string)(error){
    ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
    defer cancel()

    query := `
        UPDATE 
            users set (first_name, last_name) = ($1, $2)
        WHERE
            id = $3
    `

    var err error
    if tx != nil{
        _, err = tx.ExecContext(ctx, query, firstName, lastName, id)
    }else{
        _, err = m.DB.ExecContext(ctx, query, firstName, lastName, id)
    }

    if err != nil{
        return  err
    }

    return nil
}

func (m *user) DeleteUserByID(ctx context.Context, tx *sql.Tx, id int) error {
    ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
    defer cancel()

    query := "DELETE FROM users WHERE id = $1"

    var err error 

    if tx != nil {
        _, err = tx.ExecContext(ctx, query, id)
    }else{
        _, err = m.DB.ExecContext(ctx, query, id)
    }

    if err != nil {
        return err
    }

    return nil
}
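Each method above repeats the same call twice, once for tx and once for m.DB. One possible way to remove the duplication, offered here as an alternative rather than as what the article's repository does, is a small interface that both *sql.DB and *sql.Tx already satisfy. The names querier, pick, describe, and pickedWith are hypothetical.

```go
package main

import (
	"context"
	"database/sql"
	"fmt"
)

// querier lists the methods the repository uses. The signatures are the
// ones shared by *sql.DB and *sql.Tx.
type querier interface {
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

// Compile-time checks that both types satisfy the interface.
var (
	_ querier = (*sql.DB)(nil)
	_ querier = (*sql.Tx)(nil)
)

// pick returns tx when one was passed and db otherwise. The nil check happens
// before the conversion so a nil *sql.Tx never ends up inside the interface.
func pick(db *sql.DB, tx *sql.Tx) querier {
	if tx != nil {
		return tx
	}
	return db
}

// describe reports which concrete type pick selected.
func describe(q querier) string {
	switch q.(type) {
	case *sql.Tx:
		return "transaction"
	case *sql.DB:
		return "database handle"
	default:
		return "unknown"
	}
}

// pickedWith exercises pick with and without a transaction. The zero values
// are used only for their types and are never queried.
func pickedWith(hasTx bool) string {
	db := new(sql.DB)
	var tx *sql.Tx
	if hasTx {
		tx = new(sql.Tx)
	}
	return describe(pick(db, tx))
}

func main() {
	fmt.Println(pickedWith(false))
	fmt.Println(pickedWith(true))
}
```

With a helper like this, a repository method could call pick(m.DB, tx).QueryRowContext(ctx, query, id) once instead of branching on tx.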

Then, in your handler, you can call the transaction function like this:

package main

import (
+   "context"
+   "database/sql"
    "encoding/json"
    "fmt"
    "log"
    "net/http"

    "github.com/go-chi/chi/v5"
    "github.com/orololuwa/crispy-octo-guacamole/driver"
    "github.com/orololuwa/crispy-octo-guacamole/models"
    "github.com/orololuwa/crispy-octo-guacamole/repository"
)
const portNumber = ":8080"

func main(){
    db, route, err := run()
    if (err != nil){
        log.Fatal(err)
    }
    defer db.SQL.Close()

    fmt.Printf("Starting application on port %s\n", portNumber)


    srv := &http.Server{
        Addr: portNumber,
        Handler: route,
    }

    err = srv.ListenAndServe()
    if err != nil {
        log.Fatal(err)
    }
}

func run()(*driver.DB, *chi.Mux, error){
    dbHost := "localhost"
    dbPort := "5432"
    dbName := "bookings"
    dbUser := "orololuwa"
    dbPassword := ""
    dbSSL := "disable"

    // Connect to the database
    log.Println("Connecting to database")
    connectionString := fmt.Sprintf("host=%s port=%s dbname=%s user=%s password=%s sslmode=%s", dbHost, dbPort, dbName, dbUser, dbPassword, dbSSL)

    db, err := driver.ConnectSQL(connectionString)
    if err != nil {
        log.Fatal("Cannot connect to database: ", err)
    }
    log.Println("Connected to database")

    userRepo := repository.NewUserRepo(db.SQL)
    dbRepo := repository.NewDBRepo(db.SQL)
    router := chi.NewRouter()

    router.Post("/user", func(w http.ResponseWriter, r *http.Request) {
        type userBody struct {
            FirstName string `json:"firstName"`
            LastName string `json:"lastName"`
            Email string `json:"email"`
            Password string `json:"password"`
        }

        var body userBody

        err := json.NewDecoder(r.Body).Decode(&body)
        if err != nil {
            w.Header().Set("Content-Type", "application/json")
            w.WriteHeader(http.StatusBadRequest)
            return
        }

        user := models.User{
            FirstName: body.FirstName,
            LastName: body.LastName,
            Email: body.Email,
            Password: body.Password,
        }

-       id, err := userRepo.CreateAUser(user)
+       ctx := context.Background()
+       var id int

+       err = dbRepo.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
+           id, err = userRepo.CreateAUser(ctx, tx, user)
+           if err != nil {
+               return err
+           }
+
+           err = userRepo.UpdateAUsersName(ctx, tx, id, body.FirstName, "test")
+           if err != nil {
+               return err
+           }
+
+           return nil
+       })
+       
        if err != nil {
            w.Header().Set("Content-Type", "application/json")
            w.WriteHeader(http.StatusInternalServerError)
            return
        }

        response := map[string]interface{}{"message": "user created successfully", "data": id}
        jsonResponse, err := json.Marshal(response)
        if err != nil {
            http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
            return
        }
        w.Header().Set("Content-Type", "application/json")
        w.WriteHeader(http.StatusOK)
        w.Write(jsonResponse)
    })

    return db, router, nil
}

The full code is available here on GitHub

What next?

You'd probably want to learn how to properly set up and test handlers. Stay tuned 🧏🏽
