DEV Community

Sadiul Hakim

Spring Boot Rest Custom Jwt Security

To secure a Spring Boot REST API using JWT, you'll need to implement several key components, including UserDetails, UserDetailsService, filters, and a JWT helper class. The process involves authenticating a user to generate a JWT, and then using that JWT for subsequent requests to authorize access to protected resources.

Dependencies


<dependencies>
    <!-- Spring Boot Security -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-security</artifactId>
    </dependency>

    <!-- Spring Boot Web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>

    <dependency>
        <groupId>io.jsonwebtoken</groupId>
        <artifactId>jjwt-api</artifactId>
        <version>0.11.5</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/io.jsonwebtoken/jjwt-impl -->
    <dependency>
        <groupId>io.jsonwebtoken</groupId>
        <artifactId>jjwt-impl</artifactId>
        <version>0.11.5</version>
        <scope>runtime</scope>
    </dependency>
    <!-- https://mvnrepository.com/artifact/io.jsonwebtoken/jjwt-jackson -->
    <dependency>
        <groupId>io.jsonwebtoken</groupId>
        <artifactId>jjwt-jackson</artifactId>
        <version>0.11.5</version>
        <scope>runtime</scope>
    </dependency>
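    <!-- Lombok: generates the code behind @RequiredArgsConstructor and @Slf4j used below (version managed by Spring Boot) -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
</dependencies>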

1. UserDetails

The UserDetails interface in Spring Security represents a principal, or a user, and provides the core user information required for the framework to perform authentication and authorization. It acts as an adapter, allowing your custom user model to be used by Spring Security's authentication mechanisms. You don't usually implement this interface directly; instead, you'll often use org.springframework.security.core.userdetails.User, which is a concrete implementation provided by Spring Security.

| Method | Purpose |
| --- | --- |
| getAuthorities() | Returns the authorities (roles or permissions) granted to the user. |
| getPassword() | Returns the password (normally the encoded hash) used to authenticate the user. |
| getUsername() | Returns the username used to authenticate the user. |
| isAccountNonExpired() | Returns true if the user's account has not expired. |
| isAccountNonLocked() | Returns true if the user's account is not locked. |
| isCredentialsNonExpired() | Returns true if the user's credentials (password) have not expired. |
| isEnabled() | Returns true if the user is enabled. |

2. UserDetailsService

The UserDetailsService interface is a critical component for loading user-specific data during authentication. Its primary responsibility is to find a user by their username and return a UserDetails object that contains all the necessary user data, including roles, password, and account status. You'll need to create a custom class that implements this interface. Spring Security's DaoAuthenticationProvider uses a UserDetailsService to look up the username and compare the provided password with the one loaded from the service.

import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.stereotype.Service;

@Service
public class CustomUserDetailsService implements UserDetailsService {

    // In a real application, inject your user repository here and load users from the database:
    // private final UserRepository userRepository;

    // Demo only: hash the in-memory passwords once when the service is created. BCrypt uses a
    // random salt, so a valid hash can't be typed in by hand; in production the stored hash comes
    // from the database. A local encoder is used instead of injecting the PasswordEncoder bean to
    // avoid a circular dependency: SecurityConfig depends on this service via CustomAuthorizationFilter.
    private static final BCryptPasswordEncoder ENCODER = new BCryptPasswordEncoder();
    private final String adminPasswordHash = ENCODER.encode("admin");
    private final String userPasswordHash = ENCODER.encode("user");

    @Override
    public UserDetails loadUserByUsername(String username) throws UsernameNotFoundException {
        // With a database you would do something like:
        // var user = userRepository.findByUsername(username)
        //     .orElseThrow(() -> new UsernameNotFoundException("User not found: " + username));

        if ("admin".equals(username)) {
            // Spring's built-in User class implements UserDetails; roles("ADMIN") grants "ROLE_ADMIN".
            return User.withUsername("admin")
                    .password(adminPasswordHash)
                    .roles("ADMIN")
                    .build();
        } else if ("user".equals(username)) {
            // Return a UserDetails object for the 'user' user.
            return User.withUsername("user")
                    .password(userPasswordHash)
                    .roles("USER")
                    .build();
        }
        throw new UsernameNotFoundException("User not found: " + username);
    }
}

3. Authorities

In Spring Security, authorities (represented by the GrantedAuthority interface) are the permissions granted to a principal (user); they define what the user is allowed to do. Authorities are often roles such as "ROLE_ADMIN", "ROLE_USER", or "ROLE_EDITOR", but they can also be finer-grained permissions. Calling roles("ADMIN") on the User builder stores the authority "ROLE_ADMIN"; hasRole("ADMIN") checks for that prefixed value, while hasAuthority("ROLE_ADMIN") compares the exact string. Once a user is authenticated, their authorities are stored in the security context, and Spring uses them to make authorization decisions for the request.
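To make the ROLE_ prefix concrete, here is a plain-Java sketch (no Spring on the classpath) of how role names map to authority strings and how a hasRole-style check differs from a hasAuthority-style check. RoleAuthorityDemo and its methods are hypothetical names that only model the convention; in a real application Spring's User builder and its authorization expressions do this for you.

```java
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class RoleAuthorityDemo {
    // Models User.withUsername(..).roles(..): each role is stored as "ROLE_" + role.
    static Set<String> rolesToAuthorities(List<String> roles) {
        return roles.stream()
                .map(role -> "ROLE_" + role)
                .collect(Collectors.toSet());
    }

    // Models a hasRole("ADMIN") check: the prefix is added before comparing.
    static boolean hasRole(Set<String> authorities, String role) {
        return authorities.contains("ROLE_" + role);
    }

    // Models a hasAuthority("ROLE_ADMIN") check: exact string match, no prefix added.
    static boolean hasAuthority(Set<String> authorities, String authority) {
        return authorities.contains(authority);
    }
}
```

For example, rolesToAuthorities(List.of("ADMIN")) yields the set {"ROLE_ADMIN"}, so hasRole(..., "ADMIN") is true while hasAuthority(..., "ADMIN") is false. Spring's User builder also rejects role names that already start with "ROLE_", because it adds the prefix itself.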


4. JwtHelper

The JwtHelper class is a utility class for generating, validating, and extracting information from JSON Web Tokens. It encapsulates all the logic related to JWTs, keeping your security configuration clean and modular.

import io.jsonwebtoken.*;
import io.jsonwebtoken.io.Decoders;
import io.jsonwebtoken.security.Keys;
import io.jsonwebtoken.security.SignatureException;
import org.springframework.security.core.userdetails.UserDetails;

import java.security.Key;
import java.util.Date;
import java.util.Map;
import java.util.function.Function;

public class JwtHelper {
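    // Base64-encoded HMAC key; HS256 needs at least 256 bits (this one decodes to 384 bits).
    // Demo value only: load the real key from configuration or an environment variable and never commit it.
    private static final String SECRET = "VxRfBGJFviiO62cg/M0YY5WypcyvtUUjfkI5aDJgwt4dLz6BQKuaKChKyn+Ulhz+";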

    // Generates a JWT for the user with extra claims and a lifetime given in milliseconds.
    public static String generateToken(UserDetails userDetails, Map<String, Object> extraClaims, long expirationMillis) {
        long now = System.currentTimeMillis();
        return Jwts.builder()
                .setClaims(extraClaims) // Custom claims such as roles; call this first, since it replaces all existing claims.
                .setSubject(userDetails.getUsername()) // Set the subject (username).
                .setIssuedAt(new Date(now)) // Set the issue date.
                .setExpiration(new Date(now + expirationMillis)) // Expiration = issue time + lifetime.
                .signWith(getSecretKey(), SignatureAlgorithm.HS256) // Sign with the secret key using HMAC-SHA256.
                .compact(); // Serialize to the compact, URL-safe header.payload.signature string.
    }

    // Validates a token against a user's details. Parsing (via extractUsername) already throws if the
    // signature is invalid or the token has expired; this adds an exact username match.
    public static boolean isValidToken(String token, UserDetails details) {
        return extractUsername(token).equals(details.getUsername()) && !isExpired(token);
    }

    // Checks if the token has expired.
    private static boolean isExpired(String token) {
        return extractExpiration(token).before(new Date());
    }

    // Extracts the expiration date from the token's claims.
    private static Date extractExpiration(String token) throws MalformedJwtException {
        return parseSingleClaim(token, Claims::getExpiration);
    }

    // Extracts the username (subject) from the token's claims.
    public static String extractUsername(String token) throws ExpiredJwtException, UnsupportedJwtException,
            MalformedJwtException, SignatureException, IllegalArgumentException {
        return parseSingleClaim(token, Claims::getSubject);
    }

    // Extracts a specific claim from the token.
    public static Object extractClaim(String token, String claim) throws MalformedJwtException {
        return parseSingleClaim(token, claims -> claims.get(claim, Object.class));
    }

    // A generic method to parse a single claim from the token.
    private static <T> T parseSingleClaim(String token, Function<Claims, T> resolver) throws ExpiredJwtException,
            UnsupportedJwtException, MalformedJwtException, SignatureException, IllegalArgumentException {
        Claims claims = extractAllClaims(token);
        return resolver.apply(claims);
    }

    // Extracts all claims from the token by parsing and verifying it.
    private static Claims extractAllClaims(String token) throws ExpiredJwtException, UnsupportedJwtException,
            MalformedJwtException, SignatureException, IllegalArgumentException {
        JwtParser parser = Jwts.parserBuilder()
                .setSigningKey(getSecretKey()).build();
        return parser.parseClaimsJws(token).getBody();
    }

    // Decodes the secret key from Base64 to a Key object.
    private static Key getSecretKey() {
        byte[] bytes = Decoders.BASE64.decode(SECRET);
        return Keys.hmacShaKeyFor(bytes);
    }
}
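JwtHelper delegates the actual encoding and signing to jjwt. To show what an HS256 token contains, here is a minimal JDK-only sketch, not jjwt's implementation: the token is base64url(header).base64url(payload).base64url(signature), where the signature is HMAC-SHA256 over the first two parts. MiniJwt is a hypothetical class; it takes the payload as a ready-made JSON string and skips the exp/iat checks that jjwt performs while parsing.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.util.Base64;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

public class MiniJwt {
    private static final Base64.Encoder B64URL = Base64.getUrlEncoder().withoutPadding();
    private static final String HEADER_JSON = "{\"alg\":\"HS256\",\"typ\":\"JWT\"}";

    // Builds base64url(header) + "." + base64url(payload) + "." + base64url(HMAC-SHA256 signature).
    static String sign(String payloadJson, byte[] key) {
        String header = B64URL.encodeToString(HEADER_JSON.getBytes(StandardCharsets.UTF_8));
        String payload = B64URL.encodeToString(payloadJson.getBytes(StandardCharsets.UTF_8));
        String signingInput = header + "." + payload;
        return signingInput + "." + B64URL.encodeToString(hmacSha256(signingInput, key));
    }

    // Recomputes the signature over "header.payload" and compares it in constant time.
    static boolean verify(String token, byte[] key) {
        String[] parts = token.split("\\.");
        if (parts.length != 3) {
            return false;
        }
        byte[] expected = hmacSha256(parts[0] + "." + parts[1], key);
        byte[] actual;
        try {
            actual = Base64.getUrlDecoder().decode(parts[2]);
        } catch (IllegalArgumentException e) {
            return false; // The signature part is not valid base64url.
        }
        return MessageDigest.isEqual(expected, actual);
    }

    // Decodes the payload part; it is only encoded, not encrypted, so anyone can read it.
    static String payloadJson(String token) {
        return new String(Base64.getUrlDecoder().decode(token.split("\\.")[1]), StandardCharsets.UTF_8);
    }

    private static byte[] hmacSha256(String data, byte[] key) {
        try {
            Mac mac = Mac.getInstance("HmacSHA256");
            mac.init(new SecretKeySpec(key, "HmacSHA256"));
            return mac.doFinal(data.getBytes(StandardCharsets.US_ASCII));
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("HMAC-SHA256 unavailable", e);
        }
    }
}
```

Changing the payload, or verifying with a different key, makes verify return false; this is the check that makes jjwt's parser in extractAllClaims reject a tampered token with a signature exception. Since payloadJson can read the claims without any key, avoid putting secrets in them.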

5. Filters

In Spring Security, filters intercept incoming requests and outgoing responses. They form a chain, where each filter performs a specific task. We use them to implement cross-cutting concerns like authentication, authorization, and logging before the request reaches the controller.
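The chain idea is easiest to see in a toy model. The sketch below is plain Java with hypothetical Request, Response, Filter, and Chain types (not the Servlet API): each filter either passes the request on by calling chain.proceed, as FilterChain.doFilter does in a real servlet filter, or stops the chain by not calling it.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class MiniFilterChainDemo {
    // Hypothetical minimal request/response types, not jakarta.servlet classes.
    record Request(String path, Map<String, String> headers) {}

    static final class Response {
        int status = 200;
        final List<String> log = new ArrayList<>();
    }

    interface Filter { void doFilter(Request req, Response res, Chain chain); }
    interface Chain { void proceed(Request req, Response res); }

    // Runs the filters in order; once every filter has passed the request on, the "controller" handles it.
    static void run(List<Filter> filters, Request req, Response res) {
        new Chain() {
            int index = 0;

            public void proceed(Request r, Response s) {
                if (index < filters.size()) {
                    filters.get(index++).doFilter(r, s, this);
                } else {
                    s.log.add("controller");
                }
            }
        }.proceed(req, res);
    }

    // Cross-cutting concern: record the path, then always continue.
    static final Filter LOGGING = (req, res, chain) -> {
        res.log.add("log " + req.path());
        chain.proceed(req, res);
    };

    // Short-circuits (never calls chain.proceed) when no bearer token is present.
    static final Filter AUTH = (req, res, chain) -> {
        String auth = req.headers().getOrDefault("Authorization", "");
        if (!auth.startsWith("Bearer ")) {
            res.status = 401;
            return;
        }
        chain.proceed(req, res);
    };
}
```

With LOGGING followed by AUTH, a request without a bearer token ends with status 401 and never reaches the controller, while a request with one passes through both filters. A servlet filter that forgets to call filterChain.doFilter has the same effect as AUTH's early return: nothing after it runs.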

CustomAuthenticationFilter

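This filter extends UsernamePasswordAuthenticationFilter and handles the login step. By default, that base filter only processes POST requests to /login. It reads the username and password request parameters (for example, from a form-encoded body), wraps them in a UsernamePasswordAuthenticationToken, and asks the AuthenticationProvider to authenticate it. If authentication succeeds, it generates an access token and a refresh token and writes them to the response. ResponseUtility.commitResponse is a helper from the author's project that writes the given map and status code to the response; its code isn't shown in this post.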

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.massmanagement.util.ResponseUtility;
import org.massmanagement.util.JwtHelper;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

public class CustomAuthenticationFilter extends UsernamePasswordAuthenticationFilter {
    private final AuthenticationProvider authenticationProvider;

    public CustomAuthenticationFilter(AuthenticationProvider authenticationProvider) {
        this.authenticationProvider = authenticationProvider;
    }

    @Override
    public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
            throws AuthenticationException {
        // Extract username and password from the request.
        String username = request.getParameter("username");
        String password = request.getParameter("password");

        // Create an authentication token to be authenticated by the provider.
        UsernamePasswordAuthenticationToken authenticationToken = new UsernamePasswordAuthenticationToken(username, password);

        // Authenticate the user and return the authenticated token.
        return authenticationProvider.authenticate(authenticationToken);
    }

    @Override
    protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, Authentication authentication) throws IOException, ServletException {
        // Extract the authenticated user details.
        var user = (User) authentication.getPrincipal();

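        // Add the user's roles to the claims as plain strings (e.g. "ROLE_ADMIN"), so the token
        // carries a simple JSON array instead of serialized GrantedAuthority objects.
        Map<String, Object> extraClaims = new HashMap<>();
        extraClaims.put("roles", user.getAuthorities().stream().map(a -> a.getAuthority()).toList());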

        // Generate the access and refresh tokens using the JwtHelper.
        String accessToken = JwtHelper.generateToken(user, extraClaims, (1000L * 60 * 60 * 24 * 7)); // 7 days expiration
        String refreshToken = JwtHelper.generateToken(user, extraClaims, (1000L * 60 * 60 * 24 * 30)); // 30 days expiration

        // Create a map to hold the tokens.
        Map<String, String> tokenMap = new HashMap<>();
        tokenMap.put("accessToken", accessToken);
        tokenMap.put("refreshToken", refreshToken);

        // Send the tokens back in the response.
        ResponseUtility.commitResponse(response, tokenMap, 200);
    }
}
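On the client side, the login call has to match what attemptAuthentication reads: username and password request parameters, which a form-encoded POST body supplies. The sketch below builds such a request with the JDK's java.net.http client; LoginRequestDemo and the base URL are assumptions for illustration, and the request is only built here, not sent.

```java
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpRequest;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;

public class LoginRequestDemo {
    // Encodes fields as application/x-www-form-urlencoded, which HttpServletRequest.getParameter can read.
    static String formBody(Map<String, String> fields) {
        return fields.entrySet().stream()
                .map(e -> URLEncoder.encode(e.getKey(), StandardCharsets.UTF_8) + "="
                        + URLEncoder.encode(e.getValue(), StandardCharsets.UTF_8))
                .collect(Collectors.joining("&"));
    }

    // Builds POST {baseUrl}/login, the path UsernamePasswordAuthenticationFilter handles by default.
    static HttpRequest loginRequest(String baseUrl, String username, String password) {
        Map<String, String> fields = new LinkedHashMap<>();
        fields.put("username", username);
        fields.put("password", password);
        return HttpRequest.newBuilder(URI.create(baseUrl + "/login"))
                .header("Content-Type", "application/x-www-form-urlencoded")
                .POST(HttpRequest.BodyPublishers.ofString(formBody(fields)))
                .build();
    }
}
```

To send it, pass the request to HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString()). On success the body contains the token map written by ResponseUtility.commitResponse, so its exact format depends on that helper.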

CustomAuthorizationFilter

This filter extends OncePerRequestFilter, so it runs at most once per request. It handles the requests made after login: it reads the JWT from the Authorization header, validates it, and, if the token is valid, stores an authenticated token in the SecurityContextHolder for the current request. The authorization rules in SecurityConfig then decide whether that user may access the requested endpoint.

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.massmanagement.service.CustomUserDetailsService;
import org.massmanagement.util.JwtHelper;
import org.massmanagement.util.ResponseUtility;
import org.springframework.http.HttpHeaders;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

@Slf4j
@Component
@RequiredArgsConstructor
public class CustomAuthorizationFilter extends OncePerRequestFilter {
    private final CustomUserDetailsService userDetailsService;

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
            throws ServletException, IOException {
        // Skip token processing for public paths such as login and token validation.
        String path = request.getServletPath();
        if (path.equalsIgnoreCase("/login") || path.endsWith("/validate-token")) {
            filterChain.doFilter(request, response);
            return;
        }

        String authorization = request.getHeader(HttpHeaders.AUTHORIZATION);
        // Only process requests that carry an "Authorization: Bearer <token>" header.
        if (authorization != null && authorization.startsWith("Bearer ")) {
            try {
                String token = authorization.substring("Bearer ".length());
                // Parsing verifies the signature and expiration and throws if either check fails.
                String username = JwtHelper.extractUsername(token);

                // Only authenticate if nothing has set an authentication for this request yet.
                if (SecurityContextHolder.getContext().getAuthentication() == null) {
                    UserDetails userDetails = userDetailsService.loadUserByUsername(username);
                    if (JwtHelper.isValidToken(token, userDetails)) {
                        UsernamePasswordAuthenticationToken authenticationToken = new UsernamePasswordAuthenticationToken(
                                userDetails,
                                null,
                                userDetails.getAuthorities() // Pass the authorities for authorization checks.
                        );
                        SecurityContextHolder.getContext().setAuthentication(authenticationToken);
                    }
                }
            } catch (Exception ex) {
                log.error("Error occurred in CustomAuthorizationFilter. Cause: {}", ex.getMessage());

                // An invalid or expired token (or unknown user) is a client error: answer 401 and stop here.
                Map<String, String> errorMap = new HashMap<>();
                errorMap.put("error", ex.getMessage());
                ResponseUtility.commitResponse(response, errorMap, 401);
                return;
            }
        }
        // Continue the chain outside the try block, so exceptions thrown by later
        // filters or controllers are not rewritten as token errors.
        filterChain.doFilter(request, response);
    }
}
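The header parsing in CustomAuthorizationFilter can be pulled into a small, testable helper. The sketch below is plain Java, and BearerTokenDemo is a hypothetical name. Unlike the filter's startsWith("Bearer ") check, it matches the scheme name case-insensitively, since HTTP authentication scheme names are case-insensitive (RFC 7235), and it treats an empty token as missing.

```java
import java.util.Optional;

public class BearerTokenDemo {
    private static final String PREFIX = "Bearer ";

    // Returns the token from "Authorization: Bearer <token>", or empty if the header is
    // missing, uses another scheme (e.g. Basic), or carries no token.
    static Optional<String> extractBearerToken(String authorizationHeader) {
        if (authorizationHeader == null
                || !authorizationHeader.regionMatches(true, 0, PREFIX, 0, PREFIX.length())) {
            return Optional.empty();
        }
        String token = authorizationHeader.substring(PREFIX.length()).trim();
        return token.isEmpty() ? Optional.empty() : Optional.of(token);
    }
}
```

In the filter, the result would replace the manual substring call: an empty Optional means "no token, continue unauthenticated", and a present value goes on to JwtHelper.extractUsername.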

6. Password Encoder

A password encoder is a component used to securely hash and store passwords. Spring Security recommends using a one-way, irreversible hashing algorithm like BCrypt, Argon2, or Scrypt to prevent passwords from being stored in plain text. BCryptPasswordEncoder is a robust and widely used implementation.

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;

@Configuration
public class SecurityConfig {
    // Other beans...

    @Bean
    public PasswordEncoder passwordEncoder() {
        return new BCryptPasswordEncoder();
    }
}
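The JDK has no built-in BCrypt, so the sketch below uses a different algorithm, PBKDF2WithHmacSHA256 from javax.crypto, to show the ideas BCryptPasswordEncoder also relies on: a random per-password salt, a deliberately slow derivation, the salt stored next to the hash, and a constant-time comparison. SaltedHashDemo, its salt:hash storage format, and the iteration count are illustrative assumptions; in the Spring application, keep using the PasswordEncoder bean.

```java
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Base64;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;

public class SaltedHashDemo {
    private static final int ITERATIONS = 100_000; // Illustrative work factor; tune per current guidance.
    private static final int KEY_BITS = 256;
    private static final SecureRandom RANDOM = new SecureRandom();

    // Returns "salt:hash" (both Base64) using a fresh random 16-byte salt.
    static String hash(String password) {
        byte[] salt = new byte[16];
        RANDOM.nextBytes(salt);
        Base64.Encoder b64 = Base64.getEncoder();
        return b64.encodeToString(salt) + ":" + b64.encodeToString(derive(password, salt));
    }

    // Re-derives the hash with the stored salt and compares in constant time.
    static boolean matches(String password, String stored) {
        String[] parts = stored.split(":");
        byte[] salt = Base64.getDecoder().decode(parts[0]);
        byte[] expected = Base64.getDecoder().decode(parts[1]);
        return MessageDigest.isEqual(expected, derive(password, salt));
    }

    private static byte[] derive(String password, byte[] salt) {
        try {
            PBEKeySpec spec = new PBEKeySpec(password.toCharArray(), salt, ITERATIONS, KEY_BITS);
            return SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256").generateSecret(spec).getEncoded();
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("PBKDF2 unavailable", e);
        }
    }
}
```

Because the salt is random, hashing "admin" twice gives two different strings, and both still verify. BCrypt behaves the same way (its $2a$10$... output embeds the cost and the salt), which is why a valid BCrypt hash has to be produced by the encoder rather than written by hand.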

7. AuthenticationProvider

The AuthenticationProvider is the core of the authentication process. It takes an Authentication object (like our UsernamePasswordAuthenticationToken) and authenticates it. Spring's DaoAuthenticationProvider is a common choice, which uses a UserDetailsService to load a UserDetails object and a PasswordEncoder to verify the password.

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.authentication.dao.DaoAuthenticationProvider;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.crypto.password.PasswordEncoder;

@Configuration
public class SecurityConfig {

    // Other beans...

    @Bean
    public DaoAuthenticationProvider authenticationProvider(PasswordEncoder passwordEncoder,
                                                            UserDetailsService userDetailsService) {
        var provider = new DaoAuthenticationProvider();
        provider.setPasswordEncoder(passwordEncoder);
        provider.setUserDetailsService(userDetailsService);
        return provider;
    }
}
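To show roughly what DaoAuthenticationProvider does with these two collaborators, here is a simplified plain-Java model. MiniDaoProvider, StoredUser, and the two exception classes are hypothetical stand-ins: a Function plays the UserDetailsService and a BiPredicate plays PasswordEncoder.matches. The ordering follows Spring's provider (load the user, run account-status checks, then compare the password), and, as Spring does by default, an unknown username is reported as bad credentials so callers can't probe which usernames exist.

```java
import java.util.Optional;
import java.util.function.BiPredicate;
import java.util.function.Function;

public class MiniDaoProvider {
    // Hypothetical stand-in for a loaded UserDetails.
    public record StoredUser(String username, String passwordHash, boolean enabled) {}

    // Hypothetical stand-ins for Spring's BadCredentialsException and DisabledException.
    public static class BadCredentials extends RuntimeException {
        public BadCredentials() { super("Bad credentials"); }
    }

    public static class Disabled extends RuntimeException {
        public Disabled() { super("User is disabled"); }
    }

    private final Function<String, Optional<StoredUser>> loadUser; // plays UserDetailsService
    private final BiPredicate<String, String> passwordMatches;     // plays PasswordEncoder.matches(raw, encoded)

    public MiniDaoProvider(Function<String, Optional<StoredUser>> loadUser,
                           BiPredicate<String, String> passwordMatches) {
        this.loadUser = loadUser;
        this.passwordMatches = passwordMatches;
    }

    public StoredUser authenticate(String username, String rawPassword) {
        // 1. Load the user; an unknown username yields the same error as a wrong password.
        StoredUser user = loadUser.apply(username).orElseThrow(BadCredentials::new);
        // 2. Account-status checks run before the password comparison.
        if (!user.enabled()) {
            throw new Disabled();
        }
        // 3. Compare the raw password with the stored hash through the encoder.
        if (!passwordMatches.test(rawPassword, user.passwordHash())) {
            throw new BadCredentials();
        }
        return user;
    }
}
```

The real provider also checks locked and expired accounts and expired credentials, and returns a fully populated Authentication rather than the user record.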

8. Security Configuration

The SecurityConfig class below ties all the security components together by defining a SecurityFilterChain bean. The chain defines which requests are secured, which session management policy to use (stateless for JWT), and where the custom filters sit in the chain.

import lombok.RequiredArgsConstructor;
import org.massmanagement.security.CustomAuthenticationFilter;
import org.massmanagement.security.CustomAuthorizationFilter;
import org.massmanagement.service.CustomUserDetailsService;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;

import java.util.List;

@Configuration
@RequiredArgsConstructor
@EnableWebSecurity
class SecurityConfig {
    private final CustomAuthorizationFilter customAuthorizationFilter;

    @Value("${frontend.uri}")
    private String frontendUri;

    @Bean
    public SecurityFilterChain config(HttpSecurity http, AuthenticationProvider authenticationProvider) throws Exception {
        return http
                .csrf(AbstractHttpConfigurer::disable) // Disable CSRF for stateless REST APIs.
                .cors(c -> { // Configure CORS.
                    CorsConfigurationSource source = e -> {
                        CorsConfiguration config = new CorsConfiguration();
                        config.setAllowedOrigins(List.of(frontendUri)); // Allow requests from a specific origin.
                        config.setAllowedMethods(List.of("GET", "POST", "PUT", "DELETE")); // Allow specific methods.
                        config.setAllowedHeaders(List.of("*")); // Allow all headers.
                        return config;
                    };
                    c.configurationSource(source);
                })
                .authorizeHttpRequests(auth -> auth
                        .requestMatchers("/login", "/security/v1/validate-token").permitAll() // Allow these paths without authentication.
                        .anyRequest().authenticated() // All other requests require authentication.
                )
                .authenticationProvider(authenticationProvider) // Set the authentication provider.
                .sessionManagement(session -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS)) // Use a stateless session for JWT.
                .addFilterBefore(customAuthorizationFilter, UsernamePasswordAuthenticationFilter.class) // Add the authorization filter before the default one.
                .addFilter(new CustomAuthenticationFilter(authenticationProvider)) // Add the authentication filter.
                .build();
    }

    // Other beans for PasswordEncoder, UserDetailsService, AuthenticationProvider
    // are defined here as in the previous sections.
}

9. What if CSRF?

CSRF (Cross-Site Request Forgery) is a security vulnerability where an attacker tricks a user into submitting a malicious request to a web application in which they are authenticated. In traditional, stateful web applications, a unique CSRF token is sent with each request to ensure it originated from the legitimate web application and not from an external, malicious site.

For a stateless REST API that uses JWTs, CSRF protection is generally unnecessary as long as the client sends the token explicitly in the Authorization header. CSRF works because the browser automatically attaches cookies to cross-site requests; it never adds an Authorization header on its own, and a malicious page on another origin cannot read a token kept in the legitimate application's storage. Without the token, a forged request is simply unauthenticated. If you instead store the JWT in a cookie and have the server read it from there, the browser sends it automatically and CSRF protection becomes necessary again.

The provided SecurityConfig class disables CSRF protection using csrf(AbstractHttpConfigurer::disable).


10. What is CORS?

CORS (Cross-Origin Resource Sharing) is the browser mechanism that relaxes the same-origin policy, which by default stops a web page's scripts from reading responses from an origin (scheme, host, and port) other than the one that served the page. For example, if your front-end application is hosted at http://localhost:3000 and your Spring Boot API is at http://localhost:8080, the two differ in port, so the browser blocks the front-end's calls to the API unless the API opts in.

To let the two communicate, the API must send CORS response headers that allow the front-end's origin. The provided SecurityConfig class does this through http.cors(), passing an inline CorsConfigurationSource that allows the frontendUri origin, the GET, POST, PUT, and DELETE methods, and all request headers. The browser checks these headers, including on the preflight OPTIONS request it sends before calls that carry an Authorization header, and only then lets the front-end read the response.
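The allow-list check behind this can be modeled in a few lines. The sketch below is a simplified plain-Java model, not Spring's CORS implementation: given the Origin and Access-Control-Request-Method of a preflight, it returns either the CORS headers to send back or nothing, in which case the browser blocks the call. CorsCheckDemo is hypothetical, and frontendUri is assumed to be http://localhost:3000.

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class CorsCheckDemo {
    // Assumed values mirroring the SecurityConfig settings (frontendUri = http://localhost:3000).
    static final List<String> ALLOWED_ORIGINS = List.of("http://localhost:3000");
    static final List<String> ALLOWED_METHODS = List.of("GET", "POST", "PUT", "DELETE");

    // Returns the response headers for an allowed preflight, or empty to signal "block".
    static Optional<Map<String, String>> preflight(String origin, String requestedMethod) {
        // Origins are compared as whole strings: scheme, host, and port must all match.
        if (!ALLOWED_ORIGINS.contains(origin) || !ALLOWED_METHODS.contains(requestedMethod)) {
            return Optional.empty();
        }
        return Optional.of(Map.of(
                "Access-Control-Allow-Origin", origin,
                "Access-Control-Allow-Methods", String.join(",", ALLOWED_METHODS)));
    }
}
```

A preflight from http://localhost:3001 is refused even though only the port differs, and so is a PATCH request, because PATCH is not in the configured method list.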


11. Session

A session is a way for a server to maintain a user's state across multiple requests. In traditional, stateful web applications, the server creates a session and stores a session ID in a cookie on the client. For subsequent requests, the client sends this cookie, and the server uses the session ID to retrieve the user's information from the server-side session.

In a JWT-based security system, we use a stateless session management policy: the server keeps no session state between requests. With each request, the client sends the JWT in the Authorization header, and the server validates the token to authenticate and authorize that single request. In this example, the token identifies the user, and CustomAuthorizationFilter reloads the user's authorities through CustomUserDetailsService on every request rather than trusting the roles claim. Because no server-side session has to be stored or shared, any instance of the API can serve any request, which makes the approach well suited to horizontally scaled REST APIs. The provided SecurityConfig class explicitly sets the session creation policy to SessionCreationPolicy.STATELESS.
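From the client's point of view, statelessness means attaching the access token to every request. A minimal JDK sketch follows; AuthenticatedRequestDemo and the /api/profile URL are hypothetical.

```java
import java.net.URI;
import java.net.http.HttpRequest;

public class AuthenticatedRequestDemo {
    // Builds a GET request that carries the access token in the Authorization header.
    static HttpRequest withBearer(String url, String accessToken) {
        return HttpRequest.newBuilder(URI.create(url))
                .header("Authorization", "Bearer " + accessToken)
                .GET()
                .build();
    }
}
```

Send it with HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString()). No cookie or session ID is involved: if the header is missing, the request reaches Spring Security unauthenticated and the authorization rules reject it.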
