DEV Community

TECH SCHOOL

Write Go unit tests for db CRUD with random data

Hi guys, welcome back!

In the previous lecture, we learned how to generate Golang CRUD code to talk to the database. Today we will learn how to write unit tests for those CRUD operations.

Test Create Account

Let’s start with the CreateAccount() function. I’m gonna create a new file account_test.go inside the db/sqlc folder.

In Golang, the convention is to put the test file in the same folder as the code, and the name of the test file must end with the _test.go suffix.

The package name of this test file will be db, the same package that our CRUD code is in. Now let’s define the function TestCreateAccount().

func TestCreateAccount(t *testing.T) {
    ...
}

Every unit test function in Go must start with the Test prefix (with an uppercase T) and take a *testing.T object as input. We will use this T object to manage the test state.

The CreateAccount() function is defined as a method of the Queries object, and it requires a database connection to talk to the database. So in order to write the test, we have to set up the connection and the Queries object first. The right place to do that is a separate main_test.go file in the same folder.

I will define a testQueries object as a global variable because we’re gonna use it extensively in all of our unit tests.

var testQueries *Queries

The Queries object is defined in the db.go file that sqlc generated. It contains a DBTX, which can be either a db connection or a transaction:

type Queries struct {
    db DBTX
}

In our case, we’re gonna build a db connection and use it to create the Queries object.

I’m gonna declare a special function called TestMain(), which takes a testing.M object as input.

func TestMain(m *testing.M) {
    ...
}

If a package defines a TestMain() function, it becomes the main entry point of all unit tests inside that specific Go package, which in this case is package db.

Keep in mind that unit tests in Golang are run separately for each package. So if you have multiple packages in your project, you can have multiple main_test.go files, each with its own TestMain() entry point.

OK, now to create a new connection to the database, we use sql.Open() function, and pass in the db driver and db source string. For now, I’m just gonna declare them as constants. In the future, we will learn how to load them from environment variables instead.

The db driver should be postgres. And the db source, we can copy from the migrate command that we’ve written in the previous lecture.

package db

import (
    "database/sql"
    "log"
    "os"
    "testing"
)

const (
    dbDriver = "postgres"
    dbSource = "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable"
)

var testQueries *Queries

func TestMain(m *testing.M) {
    conn, err := sql.Open(dbDriver, dbSource)
    if err != nil {
        log.Fatal("cannot connect to db:", err)
    }

    testQueries = New(conn)

    os.Exit(m.Run())
}

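The sql.Open() function returns a *sql.DB handle (a pool of connections) and an error. If the error is not nil, we just write a fatal log saying we cannot connect to the database. Note that sql.Open() only validates its arguments and looks up the driver. It doesn't actually dial the database until the first query, or until you call conn.Ping().

Otherwise, we use the connection to create the new testQueries object. The New() function is defined in the db.go file that sqlc has generated for us.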

Now that testQueries is ready, all we have to do is call m.Run() to start running the unit tests. This function returns an exit code, which tells us whether the tests passed or failed. We then report it back to the test runner via os.Exit().
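One detail is worth knowing about this pattern. os.Exit() terminates the program immediately without running deferred calls, so a defer conn.Close() placed inside TestMain() would never execute. A common workaround is to move the work into a helper that returns the exit code. The sketch below is one way to do it. runWithCleanup, runner and runFunc are hypothetical names, and the runner interface simply matches the Run() int method of *testing.M.

```go
package main

// runner matches the Run method of *testing.M, so the real TestMain
// can pass m straight in. (Hypothetical helper type, not from sqlc or testify.)
type runner interface {
	Run() int
}

// runWithCleanup runs the tests, then the cleanup, and returns the exit code.
// The deferred cleanup runs before we return, i.e. before os.Exit is called.
func runWithCleanup(m runner, cleanup func()) int {
	defer cleanup()
	return m.Run()
}

// runFunc adapts a plain function to the runner interface (useful for demos).
type runFunc func() int

func (f runFunc) Run() int { return f() }
```

Inside TestMain(), the last line could then become os.Exit(runWithCleanup(m, func() { conn.Close() })), so the connection pool is closed before the process exits.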

OK let’s try to run it!

We’ve got an error: cannot connect to db: unknown driver "postgres".

This is because the database/sql package just provides a generic interface around SQL databases. It must be used together with a database driver in order to talk to a specific database engine.

We’re using Postgres, so I’m gonna use the lib/pq driver. Let’s open its GitHub page and copy the go get command. Run it in the terminal to install the package:

go get github.com/lib/pq

Now if we open the go.mod file, we can see lib/pq is added.

Here it says "indirect" because we haven’t imported and used it in our code yet. So let’s go back to the main_test.go file and import the lib/pq driver:

import "github.com/lib/pq"

This is a very special import because we don’t actually call any function of lib/pq directly in the code. Instead, the package registers itself as the "postgres" driver in its init() function, and the underlying code of database/sql looks the driver up by that name.

So if we just import it like this, the Go compiler will reject it as an unused import, and editor tools such as goimports will remove it automatically when we save the file. To keep it, we must use the blank identifier by adding an underscore before the import path:

import (
    "database/sql"
    "log"
    "os"
    "testing"

    _ "github.com/lib/pq"
)

Now if we run TestMain() again, there are no more errors.

And if we open the terminal and run go mod tidy to clean up the dependencies, we can see that the lib/pq requirement in the go.mod file is no longer marked indirect, since we have imported it in our code.

Alright, now that the setup is done, we can start writing our first unit test for the CreateAccount() function.

First we declare the arguments: a CreateAccountParams object. Let’s say the owner’s name is tom, the account balance is 100, and the currency is USD.

Then we call testQueries.CreateAccount(), pass in a background context, and the arguments. This testQueries object is the one we declared in the main_test.go file before.

func TestCreateAccount(t *testing.T) {
    arg := CreateAccountParams{
        Owner:    "tom",
        Balance:  100,
        Currency: "USD",
    }

    account, err := testQueries.CreateAccount(context.Background(), arg)

    ...
}

The CreateAccount() function returns an account object or an error as result.

To check the test result, I recommend using the testify package. It’s more concise than just using the standard if else statements. Let’s run this go get command in the terminal to install the package:

go get github.com/stretchr/testify

Alright, now to use this package, we need to import it first. Testify contains several sub-packages, but I’m just gonna use one of them, which is the require package.

import "github.com/stretchr/testify/require"

With this import, we can now call require.NoError(), pass in the testing.T object and the error returned by the CreateAccount() function.

func TestCreateAccount(t *testing.T) {
    ...

    account, err := testQueries.CreateAccount(context.Background(), arg)

    require.NoError(t, err)
    require.NotEmpty(t, account)
}

Basically, this function checks that the error is nil and automatically fails the test if it’s not.

Next, we require that the returned account should not be an empty object using require.NotEmpty() function.

After that, we want to check that the account owner, balance and currency match the input arguments.

So we call require.Equal(), pass in t, the expected input owner, and the actual account.Owner.

func TestCreateAccount(t *testing.T) {
    ...

    require.Equal(t, arg.Owner, account.Owner)
    require.Equal(t, arg.Balance, account.Balance)
    require.Equal(t, arg.Currency, account.Currency)
}

Similarly, we require arg.Balance to be equal to account.Balance, and arg.Currency to be equal to account.Currency.

We also want to check that the account ID is automatically generated by Postgres. So here we require account.ID to be not zero.

func TestCreateAccount(t *testing.T) {
    ...

    require.NotZero(t, account.ID)
    require.NotZero(t, account.CreatedAt)
}

Finally, the created_at column should also be filled with the current timestamp. The NotZero() function will assert that a value must not be a zero value of its type.

That’s it! The unit test is completed. Let’s click this button to run it.

We see an ok here, so it passed. Let’s open the simple_bank database with TablePlus to make sure that a record has been inserted.

Here it is, we have 1 account with id 1. The owner, balance and currency values are the same as we set in the test. And the created_at field is also filled with the current timestamp. Excellent!

We can also click Run package tests to run all the unit tests in this package. For now it has only 1 test, so it doesn’t matter.

But the nice thing is the code coverage is also reported. At the moment, our unit tests cover only 6.5% of the statements, which is very low.

If we look at the account.sql.go file, we can see the CreateAccount() function is now marked with green, which means it is covered by the unit tests.

All other functions are still red, which means they’re not covered. We will write more unit tests to cover them in a moment.

But before that, I’m gonna show you a better way to generate test data instead of filling them manually as we’re doing for the create-account arguments.

Generate random data

By generating random data, we save a lot of time figuring out which values to use, and the code becomes more concise and easier to understand.

And because the data is random, it helps us avoid conflicts between multiple unit tests. This is especially important if, for example, we have a column with a unique constraint in the database.

Alright, let’s create a new folder util, and add a new file random.go inside it. The package name is util, same as the folder containing it.

First we need to write a special function: init(). Go calls this function automatically when the package is initialized, before the package is first used.

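package util

import (
    "math/rand"
    "time"
)

func init() {
    // Seed the global generator with the current time so that
    // each run produces different values.
    rand.Seed(time.Now().UnixNano())
}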

In this function, we set the seed value for the random generator by calling rand.Seed(). The seed is usually set to the current time.

As rand.Seed() expects an int64 as input, we convert the time to Unix nanoseconds before passing it to the function.

This makes sure that every time we run the code, the generated values are different. If we don’t call rand.Seed(), the random generator behaves as if it were seeded with 1, so the generated values are the same on every run. Note that this applies to Go versions before 1.20. Since Go 1.20, the top-level math/rand functions are seeded randomly at program startup and rand.Seed() is deprecated, so this init() function is only needed on older Go versions.
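The effect of the seed is easy to demonstrate with explicit generators created by rand.New(rand.NewSource(seed)). In this sketch (sample is a hypothetical helper), two generators with the same seed produce identical sequences, and a different seed produces a different one. A local generator like this is also a way to make a test deliberately reproducible.

```go
package main

import "math/rand"

// sample draws n numbers in [0, 1000] from a new generator with the given seed.
// The same seed always yields the same sequence, which is exactly why a
// generator that is never reseeded repeats its values on every run.
func sample(seed int64, n int) []int64 {
	r := rand.New(rand.NewSource(seed))
	out := make([]int64, n)
	for i := range out {
		out[i] = r.Int63n(1001)
	}
	return out
}
```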

Now we will write a function to generate a random integer:

func RandomInt(min, max int64) int64 {
    return min + rand.Int63n(max-min+1)
}

This RandomInt() function takes 2 int64 numbers: min and max as input. And it returns a random int64 number between min and max.

Basically the rand.Int63n(n) function returns a random integer between 0 and n-1. So rand.Int63n(max - min + 1) will return a random integer between 0 and max - min.

Thus, when we add min to this expression, the final result will be a random integer between min and max.

Next, let’s write a function to generate a random string of n characters. For this, we need to declare an alphabet that contains all supported characters. To keep it simple, here I just use the 26 lowercase English letters.

const alphabet = "abcdefghijklmnopqrstuvwxyz"

func RandomString(n int) string {
    var sb strings.Builder
    k := len(alphabet)

    for i := 0; i < n; i++ {
        c := alphabet[rand.Intn(k)]
        sb.WriteByte(c)
    }

    return sb.String()
}

In the RandomString() function, we declare a new string builder object sb, get the total number of characters in the alphabet and assign it to k.

Then we will use a simple for loop to generate n random characters. We use rand.Intn(k) to get a random position from 0 to k-1, and take the corresponding character at that position in the alphabet, assign it to variable c.

We call sb.WriteByte() to write that character c to the string builder. Finally, we just return sb.String() to the caller.

And the RandomString() function is done. We can now use it to generate a random owner name.

Let’s define a new RandomOwner() function for this purpose. And inside, we just return a random string of 6 letters. I think that’s long enough to avoid duplication.

func RandomOwner() string {
    return RandomString(6)
}
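So how long is "long enough"? Six lowercase letters give 26^6 = 308,915,776 possible names. The standard birthday-problem approximation estimates the chance that at least two of k random names collide as 1 - exp(-k(k-1)/(2N)), where N = 26^6. The hypothetical helper below computes this. For 1,000 accounts the result is roughly 0.16%, and for 10,000 accounts it is roughly 15%. It would only matter if owner names ever had to be unique.

```go
package main

import "math"

// collisionProbability estimates the chance that at least two of k random
// strings of length n over an alphabet of size a are equal, using the
// birthday-problem approximation 1 - exp(-k(k-1) / (2N)) with N = a^n.
func collisionProbability(k, a, n int) float64 {
	space := math.Pow(float64(a), float64(n))
	pairs := float64(k) * float64(k-1) / 2
	return 1 - math.Exp(-pairs/space)
}
```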

Similarly, I’m gonna define another RandomMoney() function to generate a random amount of money. Let’s say it’s gonna be a random integer between 0 and 1000.

func RandomMoney() int64 {
    return RandomInt(0, 1000)
}

We need one more function to generate a random currency as well.

func RandomCurrency() string {
    currencies := []string{"EUR", "USD", "CAD"}
    n := len(currencies)
    return currencies[rand.Intn(n)]
}

This RandomCurrency() function will return one of the currencies in the list. Here I just use 3 currencies: EUR, USD and CAD. You can add more values if you want.

Similar to what we’ve done to generate a random character from the alphabet, here we compute the length of the currency list and assign it to n.

Then we use rand.Intn(n) function to generate a random index between 0 and n-1, and return the currency at that index from the list.
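RandomString() and RandomCurrency() share the same "pick a random index into a list" pattern, so Go 1.18+ generics can capture it once. The sketch below is an optional refactoring and not part of the original code. randomElement and randomCurrency are hypothetical names. Like rand.Intn(0), it panics if the slice is empty.

```go
package main

import "math/rand"

// randomElement returns a random element of items, picking an index in
// [0, len(items)-1]. It panics on an empty slice, as rand.Intn(0) does.
func randomElement[T any](items []T) T {
	return items[rand.Intn(len(items))]
}

// randomCurrency applies the helper to the currency list used above.
func randomCurrency() string {
	return randomElement([]string{"EUR", "USD", "CAD"})
}
```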

Alright, now let’s get back to the account_test.go file. In the CreateAccountParams, we can replace the specific owner name with util.RandomOwner(), the balance with util.RandomMoney(), and USD with util.RandomCurrency().

func TestCreateAccount(t *testing.T) {
    arg := CreateAccountParams{
        Owner:    util.RandomOwner(),
        Balance:  util.RandomMoney(),
        Currency: util.RandomCurrency(),
    }

    ...
}

And that’s it!

Now if we rerun the unit test and refresh TablePlus, we can see a new record id = 3 with random values.

The first 2 records have fixed values because we ran the test twice before switching to the random functions.

So it works!

Now I’m gonna add a new test command to the Makefile so that we can easily run unit tests in the terminal.

The command is simple. We just call go test, use -v option to print verbose logs, and -cover option to measure code coverage.

test:
    go test -v -cover ./...

As our project is gonna have multiple packages, we use this ./... argument to run unit tests in all of them.

Now if we run make test in the terminal, we can see it prints out verbose logs whenever a test is run or finished.

It also reports the code coverage of the unit tests for each package. Cool!

Let’s refresh TablePlus to see the new record:

Its values are completely different from the previous record, so the random generator is working well.

Next I will show you how to write unit tests for the rest of the CRUD operations: Delete, Get, List, and Update.

Test Get Account

Let’s start with the GetAccount() function.

You know, to test all the other CRUD operations, we always need to create an account first.

Note that when writing unit tests, we should make sure that they are independent from each other.

Why? Because it would be very hard to maintain hundreds of tests that depend on each other. Believe me, the last thing you want is for a simple change in one test to affect the results of other ones.

For this reason, each test should create its own account records. To avoid code duplication, let’s write a separate function to create a random account and paste in the code that we’ve written in the TestCreateAccount() function:

func createRandomAccount(t *testing.T) Account {
    arg := CreateAccountParams{
        Owner:    util.RandomOwner(),
        Balance:  util.RandomMoney(),
        Currency: util.RandomCurrency(),
    }

    account, err := testQueries.CreateAccount(context.Background(), arg)
    require.NoError(t, err)
    require.NotEmpty(t, account)

    require.Equal(t, arg.Owner, account.Owner)
    require.Equal(t, arg.Balance, account.Balance)
    require.Equal(t, arg.Currency, account.Currency)

    require.NotZero(t, account.ID)
    require.NotZero(t, account.CreatedAt)

    return account
}

Then for the TestCreateAccount(), we just need to call createRandomAccount() and pass in the testing.T object like this:

func TestCreateAccount(t *testing.T) {
    createRandomAccount(t)
}

Note that the createRandomAccount() function doesn’t have the Test prefix, so it won’t be run as a unit test. Instead, it should return the created Account record, so that other unit tests can have enough data to perform their own operation.

Now with this function in hand, we can write a test for the GetAccount() function.

First we call createRandomAccount() and save the created record to account1. Then we call testQueries.GetAccount() with a background context and the ID of account1. The result is account2 or an error.

func TestGetAccount(t *testing.T) {
    account1 := createRandomAccount(t)
    account2, err := testQueries.GetAccount(context.Background(), account1.ID)

    require.NoError(t, err)
    require.NotEmpty(t, account2)

    ...
}

We check that error should be nil using the require.NoError() function. Then we require account2 to be not empty.

All the data fields of account2 should equal those of account1. We use the require.Equal() function to compare them: first the ID, then the account owner, the balance, and the currency.

func TestGetAccount(t *testing.T) {
    ...

    require.Equal(t, account1.ID, account2.ID)
    require.Equal(t, account1.Owner, account2.Owner)
    require.Equal(t, account1.Balance, account2.Balance)
    require.Equal(t, account1.Currency, account2.Currency)
    require.WithinDuration(t, account1.CreatedAt, account2.CreatedAt, time.Second)

    ...
}

For timestamp fields like created_at, besides require.Equal(), you can also use require.WithinDuration() to check that two timestamps differ by at most some delta duration. In this case, I choose delta to be 1 second.

And that’s it! The unit test for GetAccount() operation is done. Let’s run it:

It passed!

Test Update Account

Now let’s write a test for the UpdateAccount() function. The first step is to create a new account1.

Then we declare the arguments, which is an UpdateAccountParams object, where ID is the created account’s ID, and balance is a random amount of money.

func TestUpdateAccount(t *testing.T) {
    account1 := createRandomAccount(t)

    arg := UpdateAccountParams{
        ID:      account1.ID,
        Balance: util.RandomMoney(),
    }

    ...
}

Now we call testQueries.UpdateAccount(), pass in a background context and the update arguments.

Then we require no errors to be returned. The updated account2 object should not be empty.

func TestUpdateAccount(t *testing.T) {
    ...

    account2, err := testQueries.UpdateAccount(context.Background(), arg)

    require.NoError(t, err)
    require.NotEmpty(t, account2)
}

And we compare each individual field of account2 to account1. Almost all of them should be the same, except for the balance, which should be changed to arg.Balance:

func TestUpdateAccount(t *testing.T) {
    ...

    require.Equal(t, account1.ID, account2.ID)
    require.Equal(t, account1.Owner, account2.Owner)
    require.Equal(t, arg.Balance, account2.Balance)
    require.Equal(t, account1.Currency, account2.Currency)
    require.WithinDuration(t, account1.CreatedAt, account2.CreatedAt, time.Second)
}

Alright, let’s run this test.

It passed!

Test Delete Account

TestDeleteAccount() can be implemented easily in a similar fashion.

First we create a new account1. Then we call testQueries.DeleteAccount(), and pass in the background context as well as the ID of the created account1. We require no errors to be returned.

func TestDeleteAccount(t *testing.T) {
    account1 := createRandomAccount(t)
    err := testQueries.DeleteAccount(context.Background(), account1.ID)
    require.NoError(t, err)

    ...
}

Then to make sure that the account is really deleted, we call testQueries.GetAccount() to find it in the database. In this case, the call should return an error. So we use require.Error() here.

func TestDeleteAccount(t *testing.T) {
    ...

    account2, err := testQueries.GetAccount(context.Background(), account1.ID)
    require.Error(t, err)
    require.EqualError(t, err, sql.ErrNoRows.Error())
    require.Empty(t, account2)
}

To be more precise, we use require.EqualError() function to check that the error should be sql.ErrNoRows. And finally check that the account2 object should be empty.

Now let’s run the test.

It passed! Excellent!

Test List Accounts

The last operation we want to test is ListAccounts(). It’s a bit different from the other functions because it selects multiple records.

So to test it, we need to create several accounts. Here I just use a simple for loop to create 10 random accounts.

func TestListAccounts(t *testing.T) {
    for i := 0; i < 10; i++ {
        createRandomAccount(t)
    }

    arg := ListAccountsParams{
        Limit:  5,
        Offset: 5,
    }

    ...
}

Then we declare the list-accounts parameters. Let’s say the limit is 5, and offset is 5, which means skip the first 5 records, and return the next 5.

When we run the tests, there will be at least 10 accounts in the database, so with these parameters, we expect to get 5 records.
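To see why 5 is the expected count, we can model the LIMIT/OFFSET semantics on an in-memory slice. The page() helper below is hypothetical and assumes non-negative arguments. It skips offset rows and returns at most limit of the rest. With 10 or more rows it returns exactly 5, but with only 7 rows it would return 2, which is why the test creates 10 accounts first. Also keep in mind that SQL only guarantees a stable order across pages when the query has an ORDER BY clause.

```go
package main

// page models LIMIT/OFFSET on a slice: skip offset rows, then return at
// most limit rows. Assumes limit >= 0 and offset >= 0.
func page(rows []int, limit, offset int) []int {
	if offset >= len(rows) {
		return nil
	}
	end := offset + limit
	if end > len(rows) {
		end = len(rows)
	}
	return rows[offset:end]
}
```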

Now we call testQueries.ListAccounts() with a background context and the parameters.

func TestListAccounts(t *testing.T) {
    ...

    accounts, err := testQueries.ListAccounts(context.Background(), arg)
    require.NoError(t, err)
    require.Len(t, accounts, 5)

    for _, account := range accounts {
        require.NotEmpty(t, account)
    }
}

We require no errors, and the length of the returned accounts slice should be 5.

We also iterate through the list of the accounts and require each of them to be not empty.

That’s it! Let’s run this test.

It passed! Now let’s run all unit tests in this package.

All passed.

If we look at the account.sql.go file, we can see that all Account CRUD functions are covered.

But why is the total coverage of this package only 33.8%?

That’s because we haven’t written any tests for the CRUD operations of Entry and Transfer tables. I leave it as an exercise for you to practice.

I hope this article is useful for you. Thanks for reading and see you in the next lecture.


If you like the article, please subscribe to our Youtube channel and follow us on Twitter for more tutorials in the future.


If you want to join me on my current amazing team at Voodoo, check out our job openings here. Remote or onsite in Paris/Amsterdam/London/Berlin/Barcelona with visa sponsorship.

Top comments (5)

Chris James

These look like integration tests to me, not unit tests.

yangvw

I agree, unit tests don't connect to external components such as real databases.

Mikhail Kalinin

Thanks for the great article! Just one question: am I right that the tests create entities and do not remove them after the test is finished? That may cause side effects. Better to use setup/teardown or tx+rollback :)

TECH SCHOOL

Hi Mikhail, I don't have to setup/teardown in my tests because I use random data. So they won't conflict with each other.

shayan amirshahkarami

the course is great, I love it.
I have a question about these test parts: if we didn't want our DB to be filled with this random data (data that we don't want to use), what can we do then?
TNX a lot my friend ;)