Golang create a mock database with handler and call to database using interfaces

I am trying to implement unit testing on my SignUp handler and its call to the database. However, it panics on the database call in the SignUp handler. It is a simple SignUp handler that receives a JSON body with username, password, and email. Inside the handler itself, I then use a SELECT statement to check whether this username is already taken.

This all works when I send a POST request to the running server. However, when I run the unit test, it fails with the panic below. I suspect this is because the database isn't initialized in the test environment, but I am not sure how to mock the database without using third-party frameworks.

error message

panic: runtime error: invalid memory address or nil pointer dereference [recovered]
        panic: runtime error: invalid memory address or nil pointer dereference

signup.go

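package handler

import (
    "context"
    "encoding/json"
    "io"
    "log"
    "net/http"
    "time"
    // plus the project's own packages (utils, database, ...), whose import paths are not shown in the question
)
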
type SignUpJson struct {
    Username string `json:"username"`
    Password string `json:"password"`
    Email    string `json:"email"`
}

func SignUp(w http.ResponseWriter, r *http.Request) {
    // Set Headers
    w.Header().Set("Content-Type", "application/json")
    var newUser SignUpJson

    // Read the request body and unmarshal it into the SignUpJson struct
    bs, _ := io.ReadAll(r.Body)
    if err := json.Unmarshal(bs, &newUser); err != nil {
        utils.ResponseJson(w, http.StatusInternalServerError, "Internal Server Error")
        log.Println("Internal Server Error in UnMarshal JSON body in SignUp route:", err)
        return
    }

    ctx := context.Background()
    ctx, cancel := context.WithTimeout(ctx, time.Minute*2)
    defer cancel()

    // Check if username already exists in database (duplicates not allowed)
    isExistingUsername := database.GetUsername(ctx, newUser.Username) // panics here when testing
    if isExistingUsername {
        utils.ResponseJson(w, http.StatusBadRequest, "Username has already been taken. Please try again.")
        return
    }

    // other code logic...
}

sqlquery.go

package database

var SQL_SELECT_FROM_USERS = "SELECT %s FROM users WHERE %s = $1;"

func GetUsername(ctx context.Context, username string) bool {
    row := conn.QueryRow(ctx, fmt.Sprintf(SQL_SELECT_FROM_USERS, "username", "username"), username)
    return row.Scan() != pgx.ErrNoRows
}

SignUp_test.go

package handler

import (
    "encoding/json"
    "net/http"
    "net/http/httptest"
    "strings"
    "testing"
)

func Test_SignUp(t *testing.T) {

    var tests = []struct {
        name               string
        postedData         SignUpJson
        expectedStatusCode int
    }{
        {
            name: "valid login",
            postedData: SignUpJson{
                Username: "testusername",
                Password: "testpassword",
                Email:    "[email protected]",
            },
            expectedStatusCode: 200,
        },
    }

    for _, e := range tests {
        jsonStr, err := json.Marshal(e.postedData)
        if err != nil {
            t.Fatal(err)
        }

        // Setting a request for testing
        req, _ := http.NewRequest(http.MethodPost, "/signup", strings.NewReader(string(jsonStr)))
        req.Header.Set("Content-Type", "application/json")

        // Setting and recording the response
        res := httptest.NewRecorder()
        handler := http.HandlerFunc(SignUp)

        handler.ServeHTTP(res, req)

        if res.Code != e.expectedStatusCode {
            t.Errorf("%s: returned wrong status code; expected %d but got %d", e.name, e.expectedStatusCode, res.Code)
        }
    }
}

setup_test.go

package handler

import (
    "os"
    "testing"
)

func TestMain(m *testing.M) {

    os.Exit(m.Run())
}

I have seen a similar question, but I am not sure it uses the right approach, as there was no response and the answer was confusing: How to write an unit test for a handler that invokes a function that interacts with db in Golang using pgx driver?

Jessica asked Oct 19 '25 05:10


1 Answer

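Let me try to help you figure this out. I refactored your code a little, but the general idea stays the same. The main change is that the query now runs through database/sql's *sql.DB rather than a pgx connection, which is what makes it mockable with sqlmock later on (pgx can still serve as the driver through its github.com/jackc/pgx/v5/stdlib package). First, I'm going to share the production code, which is split into two files: handlers/handlers.go and repo/repo.go.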

handlers/handlers.go file

package handlers

import (
    "context"
    "database/sql"
    "encoding/json"
    "io"
    "net/http"
    "time"

    "handlertest/repo"
)

type SignUpJson struct {
    Username string `json:"username"`
    Password string `json:"password"`
    Email    string `json:"email"`
}

func SignUp(w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Content-Type", "application/json")

    var newUser SignUpJson
    bs, _ := io.ReadAll(r.Body)
    if err := json.Unmarshal(bs, &newUser); err != nil {
        w.WriteHeader(http.StatusBadRequest)
        w.Write([]byte(err.Error()))
        return
    }

    ctx, cancel := context.WithTimeout(r.Context(), time.Minute*2)
    defer cancel()

    db, _ := ctx.Value("DB").(*sql.DB)
    if isExistingUserName := repo.GetUserName(ctx, db, newUser.Username); isExistingUserName {
        w.WriteHeader(http.StatusBadRequest)
        w.Write([]byte("username already present"))
        return
    }
    w.WriteHeader(http.StatusOK)
}

Here, there are two main differences:

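  1. The context. You don't need to create a new root context with context.Background(); derive the timeout from the one that comes with the http.Request (r.Context()), so the query is also cancelled when the client's connection closes.
  2. The SQL client. Here it is handed to the handler through the context.Context, and the repository function simply takes the *sql.DB as a parameter. For this scenario, you don't have to build any structs or use any interface: just write a function that expects an *sql.DB. Remember, functions are first-class citizens in Go. (Note that the context package documentation recommends using context values only for request-scoped data, so injecting the *sql.DB through a struct field or a closure is a common alternative.)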

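There is certainly room for refactoring: the "DB" key should be a constant (ideally of an unexported custom type rather than a plain string, as the context package documentation recommends), and we have to check that this entry actually exists in the context values. For the sake of brevity, I omitted these checks.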

repo/repo.go file

package repo

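import (
    "context"
    "database/sql"
    "errors"

    "github.com/jackc/pgx/v5"
)

func GetUserName(ctx context.Context, db *sql.DB, username string) bool {
    row := db.QueryRowContext(ctx, "SELECT username FROM users WHERE username = $1", username)

    // Scan into a variable: database/sql reports an error when the number of
    // destinations does not match the number of selected columns.
    var found string
    err := row.Scan(&found)

    // database/sql returns sql.ErrNoRows (not pgx.ErrNoRows) when nothing matches;
    // pgx.ErrNoRows is also accepted because the test below injects it as the query error.
    // Any other result, including a nil error, is treated as "username taken".
    return !errors.Is(err, sql.ErrNoRows) && !errors.Is(err, pgx.ErrNoRows)
}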

Here, the code is pretty similar to yours except for these two small things:

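  1. database/sql has a dedicated method, QueryRowContext, to use when the query should honor the context (its deadline and cancellation).
  2. Keep the SQL as a fixed string and pass values through placeholders such as $1, rather than building the query with fmt.Sprintf, for two reasons: security (queries built from strings invite SQL injection) and testability (a fixed query string can be matched exactly by sqlmock).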

Now, let's look at the test code.

handlers/handlers_test.go file

package handlers

import (
    "context"
    "net/http"
    "net/http/httptest"
    "strings"
    "testing"

    "github.com/DATA-DOG/go-sqlmock"
    "github.com/jackc/pgx/v5"
    "github.com/stretchr/testify/assert"
)

func TestSignUp(t *testing.T) {
    db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
    if err != nil {
        t.Fatalf("err not expected while opening a mock db: %v", err)
    }
    defer db.Close()
    t.Run("NewUser", func(t *testing.T) {
        mock.ExpectQuery("SELECT username FROM users WHERE username = $1").WithArgs("[email protected]").WillReturnError(pgx.ErrNoRows)

        w := httptest.NewRecorder()
        r := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(`{"username": "[email protected]", "password": "1234", "email": "[email protected]"}`))

        ctx := context.WithValue(r.Context(), "DB", db)
        r = r.WithContext(ctx)

        SignUp(w, r)

        assert.Equal(t, http.StatusOK, w.Code)
        if err := mock.ExpectationsWereMet(); err != nil {
            t.Errorf("not all expectations were met: %v", err)
        }
    })

    t.Run("AlreadyExistentUser", func(t *testing.T) {
        rows := sqlmock.NewRows([]string{"username"}).AddRow("[email protected]")
        mock.ExpectQuery("SELECT username FROM users WHERE username = $1").WithArgs("[email protected]").WillReturnRows(rows)

        w := httptest.NewRecorder()
        r := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(`{"username": "[email protected]", "password": "1234", "email": "[email protected]"}`))

        ctx := context.WithValue(r.Context(), "DB", db)
        r = r.WithContext(ctx)

        SignUp(w, r)

        assert.Equal(t, http.StatusBadRequest, w.Code)
        if err := mock.ExpectationsWereMet(); err != nil {
            t.Errorf("not all expectations were met: %v", err)
        }
    })
}

Here, there are a lot of changes compared to your version. Let me quickly recap them:

  • Use subtests (t.Run) to give a hierarchical structure to the tests.
  • Use the net/http/httptest package, which provides helpers for building requests (httptest.NewRequest) and recording responses (httptest.NewRecorder).
  • Use the sqlmock package (github.com/DATA-DOG/go-sqlmock), a widely used library for mocking database/sql connections. It only works with code that talks to the database through database/sql, which is why the repository now takes an *sql.DB.
  • Pass the SQL client to the handler through the request's context.
  • Write assertions with the github.com/stretchr/testify/assert package.

The same applies here: there is room for refactoring (e.g. you can rewrite the subtests as table-driven tests).

Outro

This can be considered an idiomatic way to write Go code. I know this can be very challenging, especially at the beginning. If you need further details on any part, just let me know and I'll be happy to help. Thanks!

ossan answered Oct 21 '25 20:10