DEV Community

Cover image for DevPill #8 - How to set up jwt authentication for your Go REST API
Raul Paes Silva
Raul Paes Silva

Posted on

DevPill #8 - How to set up jwt authentication for your Go REST API

1. Create a login handler that checks the user's credentials and generates a new JSON Web Token.

Example:

// Login authenticates a user and responds with a signed JSON Web Token.
//
// It decodes a dto.LoginRequest from the request body, hands the email and
// password to h.Service.Login, and on success writes {"token": "<jwt>"} as
// JSON. The token is signed with HS256 using h.JWTSecret; its "sub" claim is
// the user's ID and its "exp" claim is the current time plus h.JWTExpiry.
//
// Error responses:
//   - 400 "invalid body" when the body cannot be decoded (or exceeds the cap)
//   - 401 "invalid credentials" when h.Service.Login returns an error
//   - 500 "could not create token" when signing fails
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
    // The body comes from an unauthenticated client, so cap how much of it
    // we are willing to read. 1 MiB is far more than a login payload needs.
    r.Body = http.MaxBytesReader(w, r.Body, 1<<20)

    var req dto.LoginRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        http.Error(w, "invalid body", http.StatusBadRequest)
        return
    }

    // Validate user from database
    user, err := h.Service.Login(r.Context(), req.Email, req.Password)
    if err != nil {
        // Deliberately generic: do not reveal whether the email or the
        // password was the part that failed.
        http.Error(w, "invalid credentials", http.StatusUnauthorized)
        return
    }

    // Create JWT token
    token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
        "sub": user.ID,
        "exp": time.Now().Add(h.JWTExpiry).Unix(),
    })

    tokenString, err := token.SignedString([]byte(h.JWTSecret))
    if err != nil {
        http.Error(w, "could not create token", http.StatusInternalServerError)
        return
    }

    // Without an explicit Content-Type, net/http sniffs the payload and
    // labels JSON as text/plain.
    w.Header().Set("Content-Type", "application/json")

    // Once encoding starts the status and headers are already committed, so
    // a write failure (typically a dropped connection) cannot be reported
    // to the client; the error is intentionally discarded.
    _ = json.NewEncoder(w).Encode(map[string]string{
        "token": tokenString,
    })
}
Enter fullscreen mode Exit fullscreen mode

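The important part here is jwt.NewWithClaims(), where the JWT is created with the user's ID as the subject ("sub") claim and a configured expiration time ("exp") claim. The token is then signed with the secret by SignedString().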

    // Create JWT token
    // The token uses the HS256 signing method and carries two claims:
    //   "sub" - the ID of the user that just logged in
    //   "exp" - the current time plus h.JWTExpiry, as Unix seconds
    token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
        "sub": user.ID,
        "exp": time.Now().Add(h.JWTExpiry).Unix(),
    })

    // Sign the token with the configured secret to obtain the string that is
    // returned to the client.
    tokenString, err := token.SignedString([]byte(h.JWTSecret))
    if err != nil {
        // Signing failed: report a server-side error and stop.
        http.Error(w, "could not create token", http.StatusInternalServerError)
        return
    }
Enter fullscreen mode Exit fullscreen mode

2. Create a new middleware to check Authorization header

Here, you read the "Authorization" header and check that it starts with the "Bearer " prefix; the rest of the header value is the JSON Web Token to validate.
Example:

// JWTAuth returns middleware that only forwards requests carrying a valid
// HMAC-signed JWT in an "Authorization: Bearer <token>" header.
//
// secret is the shared key used to verify the token signature. When the token
// is accepted, the value of its "sub" claim is stored in the request context
// under the key "userID" and the request is passed to the wrapped handler.
//
// Responses:
//   - 401 "unauthorized" when the header is missing or is not a Bearer header
//   - 401 "invalid token" when parsing/validation fails, the token is signed
//     with a non-HMAC method, or it has no "sub" claim
func JWTAuth(secret string) func(http.Handler) http.Handler {
    // Convert once instead of on every request.
    key := []byte(secret)

    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            // CutPrefix reports false for both a missing header and a header
            // that does not start with "Bearer ".
            tokenStr, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ")
            if !ok {
                http.Error(w, "unauthorized", http.StatusUnauthorized)
                return
            }

            token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (any, error) {
                // Refuse tokens whose header asks for anything other than an
                // HMAC method, so the secret is never used with an algorithm
                // chosen by the client.
                if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
                    return nil, fmt.Errorf("invalid signing method")
                }
                return key, nil
            })
            if err != nil || !token.Valid {
                http.Error(w, "invalid token", http.StatusUnauthorized)
                return
            }

            // Checked assertion: never panic inside a request handler.
            claims, ok := token.Claims.(jwt.MapClaims)
            if !ok {
                http.Error(w, "invalid token", http.StatusUnauthorized)
                return
            }

            // A token without a subject cannot identify a user; reject it
            // rather than handing a nil userID to downstream handlers.
            sub, ok := claims["sub"]
            if !ok || sub == nil {
                http.Error(w, "invalid token", http.StatusUnauthorized)
                return
            }

            // NOTE(review): a plain string context key is flagged by
            // staticcheck (SA1029). It is kept as-is because code outside this
            // block presumably looks the value up with the same "userID"
            // literal; consider an unexported key type plus an accessor.
            ctx := context.WithValue(r.Context(), "userID", sub)

            next.ServeHTTP(w, r.WithContext(ctx))
        })
    }
}
Enter fullscreen mode Exit fullscreen mode

The middleware also verifies that the token is valid and was signed with the configured secret; otherwise the request is rejected with 401 Unauthorized.

3. Make sure you use the middleware

Add the middleware handler to all the endpoints you want to protect with JWT authentication.
Example:

    // Public routes: reachable without a token.
    publicMux := http.NewServeMux()
    publicMux.HandleFunc("POST /users/signup", userHandler.CreateUser)
    publicMux.HandleFunc("POST /users/login", authHandler.Login)

    // RequestID is the outermost wrapper, so it receives the request before
    // Logger, which in turn wraps the mux.
    var public http.Handler = middleware.RequestID(middleware.Logger(publicMux))

    // Protected routes: every request must pass JWTAuth first.
    protectedMux := http.NewServeMux()
    protectedMux.HandleFunc("POST /devices", devHandler.CreateDevice)
    protectedMux.HandleFunc("PUT /devices/{id}", devHandler.UpdateDevice)

    // Same chain as the public one, with JWTAuth placed directly around the
    // mux: RequestID -> Logger -> JWTAuth -> protectedMux.
    var protected http.Handler = middleware.RequestID(
        middleware.Logger(
            middleware.JWTAuth(cfg.JWTSecret)(protectedMux),
        ),
    )

    // The root mux only decides which wrapped handler owns a path; method
    // and wildcard matching happen in the inner muxes above.
    root := http.NewServeMux()
    for _, pattern := range []string{"/users/signup", "/users/login"} {
        root.Handle(pattern, public)
    }
    // "/devices/" covers the subtree (e.g. /devices/{id}); "/devices" covers
    // the exact path.
    for _, pattern := range []string{"/devices/", "/devices"} {
        root.Handle(pattern, protected)
    }

    server := http.Server{
        Handler: root,
        Addr:    fmt.Sprintf(":%d", cfg.WebServerPort),
    }

Enter fullscreen mode Exit fullscreen mode

Top comments (0)