java spring-boot spring-websocket

Error 401 with Websockets and Spring Boot Security


I am using Spring Boot 3.1.0 and I am trying to implement WebSockets in my application. I am following a tutorial from the internet to implement them, and I am writing some unit tests to make sure they work.

The code is very similar to the tutorial: a configuration class and a controller.

@Configuration
@EnableWebSocketMessageBroker
public class WebSocketConfiguration implements WebSocketMessageBrokerConfigurer {
    //Prefix of the destinations handled by @MessageMapping methods.
    public static final String SOCKET_RECEIVE_PREFIX = "/app";

    //Prefix of the destinations served by the simple broker (where messages are sent to subscribers).
    public static final String SOCKET_SEND_PREFIX = "/topic";

    //URL of the STOMP endpoint the client must connect to.
    public static final String SOCKETS_ROOT_URL = "/ws-endpoint";

    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        registry.addEndpoint(SOCKETS_ROOT_URL)
                .setAllowedOrigins("*")
                .withSockJS();
    }

    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
        registry.setApplicationDestinationPrefixes(SOCKET_RECEIVE_PREFIX)
                .enableSimpleBroker(SOCKET_SEND_PREFIX);
    }
}
@Controller
public class WebSocketController {
    @MessageMapping("/welcome")
    @SendTo("/topic/greetings")
    public String greeting(String payload) {
        System.out.println("Generating new greeting message for " + payload);
        return "Hello, " + payload + "!";
    }

    @SubscribeMapping("/chat")
    public MessageContent sendWelcomeMessageOnSubscription() {
        return new MessageContent(String.class.getSimpleName(), "Testing");
    }
}
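Since the endpoint above is registered with .withSockJS(), it speaks the SockJS protocol, so the WebSocket connection is not opened at /ws-endpoint itself. Under that protocol a client first requests <base>/info and then opens the WebSocket transport at <base>/<server-id>/<session-id>/websocket. The stdlib-only sketch below (hypothetical class SockJsUrls, made-up host and port) only builds those URL shapes to make the difference visible; it is not Spring's SockJsClient.

```java
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;

public class SockJsUrls {

    // Removes one trailing slash so the joined URL never contains "//".
    private static String stripTrailingSlash(String url) {
        return url.endsWith("/") ? url.substring(0, url.length() - 1) : url;
    }

    // Builds <base>/info, the URL a SockJS client requests before choosing a transport.
    static String infoUrl(String baseUrl) {
        return stripTrailingSlash(baseUrl) + "/info";
    }

    // Builds <base>/<server-id>/<session-id>/websocket, the SockJS WebSocket transport URL shape.
    static String webSocketTransportUrl(String baseUrl, String serverId, String sessionId) {
        return stripTrailingSlash(baseUrl) + "/" + serverId + "/" + sessionId + "/websocket";
    }

    public static void main(String[] args) {
        // Hypothetical base URL; host, port and path are illustrative only.
        String base = "http://127.0.0.1:8080/ws-endpoint";
        // Illustrative random ids; a real SockJS client generates its own.
        String serverId = String.valueOf(ThreadLocalRandom.current().nextInt(1000));
        String sessionId = UUID.randomUUID().toString().replace("-", "");
        System.out.println(infoUrl(base));
        // The WebSocket transport uses the ws scheme instead of http.
        System.out.println(webSocketTransportUrl(base.replaceFirst("^http", "ws"), serverId, sessionId));
    }
}
```

A plain WebSocket client pointed at the base URL skips both of these steps, which is why the answer below switches the test to a SockJS-capable client.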

What is not in the tutorial is the Security configuration:

@Configuration
@EnableWebSecurity
public class WebSecurityConfig {
    private static final String[] AUTH_WHITELIST = {
            // -- Swagger
            "/v3/api-docs/**", "/swagger-ui/**",
            // Own
            "/",
            "/info/**",
            "/auth/public/**",
            //Websockets
            WebSocketConfiguration.SOCKETS_ROOT_URL,
            WebSocketConfiguration.SOCKET_RECEIVE_PREFIX,
            WebSocketConfiguration.SOCKET_SEND_PREFIX,
            WebSocketConfiguration.SOCKETS_ROOT_URL + "/**",
            WebSocketConfiguration.SOCKET_RECEIVE_PREFIX + "/**",
            WebSocketConfiguration.SOCKET_SEND_PREFIX + "/**"
    };

    private final JwtTokenFilter jwtTokenFilter;

    @Value("${server.cors.domains:null}")
    private List<String> serverCorsDomains;

    @Autowired
    public WebSecurityConfig(JwtTokenFilter jwtTokenFilter) {
        this.jwtTokenFilter = jwtTokenFilter;
    }

    @Bean
    public PasswordEncoder passwordEncoder() {
        return new BCryptPasswordEncoder();
    }


    @Bean
    public AuthenticationManager authenticationManager(AuthenticationConfiguration authenticationConfiguration) throws Exception {
        //Will use the bean KendoUserDetailsService.
        return authenticationConfiguration.getAuthenticationManager();
    }


    @Bean
    public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
        http
                //Configure CORS from the allowed domains
                .cors(cors -> cors.configurationSource(generateCorsConfigurationSource()))
                //Disable csrf protection
                .csrf(AbstractHttpConfigurer::disable)
                //Sessions should be stateless
                .sessionManagement(httpSecuritySessionManagementConfigurer ->
                        httpSecuritySessionManagementConfigurer.sessionCreationPolicy(SessionCreationPolicy.STATELESS))
                .exceptionHandling(httpSecurityExceptionHandlingConfigurer ->
                        httpSecurityExceptionHandlingConfigurer
                                .authenticationEntryPoint((request, response, ex) -> {
                                    RestServerLogger.severe(this.getClass().getName(), ex.getMessage());
                                    response.sendError(HttpServletResponse.SC_UNAUTHORIZED, ex.getMessage());
                                })
                                .accessDeniedHandler((request, response, ex) -> {
                                    RestServerLogger.severe(this.getClass().getName(), ex.getMessage());
                                    response.sendError(HttpServletResponse.SC_FORBIDDEN, ex.getMessage());
                                })
                )
                .addFilterBefore(jwtTokenFilter, UsernamePasswordAuthenticationFilter.class)
                .authorizeHttpRequests((requests) -> requests
                        .requestMatchers(AUTH_WHITELIST).permitAll()
                        .anyRequest().authenticated());
        return http.build();
    }


    private CorsConfigurationSource generateCorsConfigurationSource() {
        final CorsConfiguration configuration = new CorsConfiguration();
        if (serverCorsDomains == null || serverCorsDomains.contains("*")) {
            configuration.setAllowedOriginPatterns(Collections.singletonList("*"));
        } else {
            configuration.setAllowedOrigins(serverCorsDomains);
            configuration.setAllowCredentials(true);
        }
        configuration.addAllowedHeader("*");
        configuration.addAllowedMethod("*");
        configuration.addExposedHeader(HttpHeaders.AUTHORIZATION);
        final UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource();
        source.registerCorsConfiguration("/**", configuration);
        return source;
    }
}
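The /** entries in AUTH_WHITELIST are meant to cover the extra paths a SockJS client requests under the endpoint. Below is a deliberately simplified, stdlib-only stand-in for the path matching (hypothetical class WhitelistCheck; it only handles exact patterns and a trailing /**, and is not Spring Security's matcher) showing which of those paths the list permits. The paths are relative to the servlet context, and a path outside the list, such as /error, falls through to anyRequest().authenticated(). The /app and /topic entries are STOMP destinations carried inside WebSocket frames, so the HTTP filter chain never sees them as request paths.

```java
import java.util.List;

public class WhitelistCheck {

    // Simplified matcher: exact patterns, or patterns ending in "/**"
    // (which also match the bare prefix itself). Not Spring's implementation.
    static boolean matches(String pattern, String path) {
        if (pattern.endsWith("/**")) {
            String prefix = pattern.substring(0, pattern.length() - 3);
            return path.equals(prefix) || path.startsWith(prefix + "/");
        }
        return pattern.equals(path);
    }

    // True if any whitelist pattern matches, i.e. the path would be permitAll().
    static boolean permitted(List<String> whitelist, String path) {
        return whitelist.stream().anyMatch(pattern -> matches(pattern, path));
    }

    public static void main(String[] args) {
        List<String> whitelist = List.of("/ws-endpoint", "/ws-endpoint/**", "/app/**", "/topic/**");
        // Example request paths (the SockJS ids are made up).
        for (String path : List.of("/ws-endpoint/info", "/ws-endpoint/123/abc/websocket", "/error")) {
            System.out.println(path + " -> " + (permitted(whitelist, path) ? "permitAll" : "authenticated"));
        }
    }
}
```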

And the test is:

@SpringBootTest(webEnvironment = RANDOM_PORT)
@Test(groups = "websockets")
@AutoConfigureMockMvc(addFilters = false)
public class BasicWebsocketsTests extends AbstractTestNGSpringContextTests {
    @BeforeClass
    public void authentication() {
        //Generates a user for authentication on the test.
        AuthenticatedUser authenticatedUser = authenticatedUserController.createUser(null, USER_NAME, USER_FIRST_NAME, USER_LAST_NAME, USER_PASSWORD, USER_ROLES);

        headers = new WebSocketHttpHeaders();
        headers.set("Authorization", "Bearer " + jwtTokenUtil.generateAccessToken(authenticatedUser, "127.0.0.1"));
    }


    @BeforeMethod
    public void setup() throws ExecutionException, InterruptedException, TimeoutException {
        WebSocketClient webSocketClient = new StandardWebSocketClient();
        this.webSocketStompClient = new WebSocketStompClient(webSocketClient);
        this.webSocketStompClient.setMessageConverter(new MappingJackson2MessageConverter());
    }


    @Test
    public void echoTest() throws ExecutionException, InterruptedException, TimeoutException {
        BlockingQueue<String> blockingQueue = new ArrayBlockingQueue<>(1);

        StompSession session = webSocketStompClient.connectAsync(getWsPath(), this.headers,
                new StompSessionHandlerAdapter() {
                }).get(1, TimeUnit.SECONDS);

        session.subscribe("/topic/greetings", new StompFrameHandler() {

            @Override
            public Type getPayloadType(StompHeaders headers) {
                return String.class;
            }

            @Override
            public void handleFrame(StompHeaders headers, Object payload) {
                blockingQueue.add((String) payload);
            }
        });

        session.send("/app/welcome", TESTING_MESSAGE);

        await().atMost(1, TimeUnit.SECONDS)
                .untilAsserted(() -> Assert.assertEquals("Hello, Mike!", blockingQueue.poll()));
    }

The result obtained is:

ERROR 2024-01-05 20:58:55.068 GMT+0100 c.s.k.l.RestServerLogger [http-nio-auto-1-exec-1] - com.softwaremagico.kt.rest.security.WebSecurityConfig$$SpringCGLIB$$0: Full authentication is required to access this resource

java.util.concurrent.ExecutionException: jakarta.websocket.DeploymentException: Failed to handle HTTP response code [401]. Missing [WWW-Authenticate] header in response.

This is not new on StackOverflow, and I have reviewed suggestions from other questions. However, those solutions differ slightly because they target earlier versions of Spring Boot, and the Security configuration has changed across versions.

If I enable the WebSocket loggers:

logging.level.org.springframework.messaging=trace
logging.level.org.springframework.web.socket=trace

I can see that the Authorization header is being sent:

TRACE 2024-01-06 08:25:54.109 GMT+0100 o.s.w.s.c.s.StandardWebSocketClient [SimpleAsyncTaskExecutor-1] - Handshake request headers: {Authorization=[Bearer eyJhbGciOiJIUzUxMiJ9.eyJzdWIiOiIxLFRlc3QuVXNlciwxMjcuMC4wLjEsIiwiaXNzIjoiY29tLnNvZnR3YXJlbWFnaWNvIiwiaWF0IjoxNzA0NTI1OTUzLCJleHAiOjE3MDUxMzA3NTN9.-ploHleAF6IpUmP4IPzLV1nYNHnigpamYgS9e3Gp183SLri-37QZA2TDKIbE6iTDCunF0JRYry7xSsq_Op1UgQ], Sec-WebSocket-Key=[cxyjc2DjRRfm/elvG0261A==], Connection=[upgrade], Sec-WebSocket-Version=[13], Host=[127.0.0.1:38373], Upgrade=[websocket]}

Note: security works fine for the REST endpoints. I have similar tests where I create a user with some roles and then authenticate with it. The code of the test is available here, and it can be run without any special configuration.

What I have tried:

I have also tried creating a plain (non-STOMP) WebSocket with the jakarta.websocket package, and the result is exactly the same: a 401 error.

It seems to me that the JWT header is not being included correctly in the test, but I cannot find the problem in the code.


Solution

Well, you have several problems:

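    1. You should correct the WebSocket path in the test: it must include the /kendo-tournament-backend prefix the application is served under, followed by the SOCKETS_ROOT_URL constant.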
    private String getWsPath() {
            // SOCKETS_ROOT_URL already starts with "/", so no extra slash is placed before %s.
            return String.format("ws://127.0.0.1:%d/kendo-tournament-backend%s", port,
                    WebSocketConfiguration.SOCKETS_ROOT_URL);
    }
    
    2. You should use SockJsClient, as I mentioned in a comment, because the endpoint is registered with .withSockJS(). Otherwise you will receive a 400 error:
    this.webSocketStompClient = new WebSocketStompClient(new SockJsClient(Arrays.asList(new WebSocketTransport(new StandardWebSocketClient()))));
    
    3. You should use StringMessageConverter, because the controller returns a String:
    this.webSocketStompClient.setMessageConverter(new StringMessageConverter());
    
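    4. You should fix the assertion so that it checks the right value: the controller replies with "Hello, " + payload + "!", so the expected string must be built from TESTING_MESSAGE instead of being hard-coded as "Hello, Mike!". Also note that TestNG's Assert.assertEquals takes the actual value first and the expected value second: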
    await().atMost(3, TimeUnit.SECONDS)
           .untilAsserted(() -> Assert.assertEquals(blockingQueue.poll(), String.format("Hello, %s!", TESTING_MESSAGE)));
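Because SOCKETS_ROOT_URL already starts with "/", putting another "/" in front of it when formatting the URL yields "//". One way to avoid hand-placing slashes is to join the segments explicitly. Below is a stdlib-only sketch with a hypothetical helper wsUrl; the port is simply the one visible in the trace log of the question, and in the test it would be the random server port.

```java
public class WsPath {

    // Joins host, port and path segments with exactly one "/" between segments,
    // skipping empty segments (e.g. an empty context path).
    static String wsUrl(String host, int port, String... segments) {
        StringBuilder url = new StringBuilder("ws://").append(host).append(':').append(port);
        for (String segment : segments) {
            String trimmed = trimSlashes(segment);
            if (!trimmed.isEmpty()) {
                url.append('/').append(trimmed);
            }
        }
        return url.toString();
    }

    // Removes leading and trailing slashes from one path segment.
    private static String trimSlashes(String segment) {
        int start = 0;
        int end = segment.length();
        while (start < end && segment.charAt(start) == '/') {
            start++;
        }
        while (end > start && segment.charAt(end - 1) == '/') {
            end--;
        }
        return segment.substring(start, end);
    }

    public static void main(String[] args) {
        // Prints ws://127.0.0.1:38373/kendo-tournament-backend/ws-endpoint
        System.out.println(wsUrl("127.0.0.1", 38373, "/kendo-tournament-backend", "/ws-endpoint"));
    }
}
```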