Tags: go, go-chi

How can I change the response code of an http request using middleware?


I'm trying to develop an API gateway using go-chi. If a request includes "_always200=true" in its query string, I want to set the response status code to 200. Here is what I tried, in custom_response_writer.go:

// CustomResponseWriter (the question's original attempt) wraps an
// http.ResponseWriter, records the status code passed to WriteHeader and
// captures body bytes passed to Write.
type CustomResponseWriter struct {
    http.ResponseWriter
    // Buf collects everything passed to Write.
    Buf         *bytes.Buffer
    // StatusCode is the last code passed to WriteHeader (0 if never called).
    StatusCode  int
    // WroteHeader is declared but never set by the methods below.
    WroteHeader bool
}

// NewCustomResponseWriter wraps w with an empty body buffer; StatusCode
// starts at its zero value.
func NewCustomResponseWriter(w http.ResponseWriter) *CustomResponseWriter {
    return &CustomResponseWriter{ResponseWriter: w, Buf: new(bytes.Buffer)}
}

// WriteHeader records code and forwards it to the wrapped writer immediately.
// Because it forwards at once, if the handler has already called WriteHeader,
// any later call on the same response (such as an override made after the
// handler returns) is a second WriteHeader, which net/http reports as
// superfluous and ignores.
func (c *CustomResponseWriter) WriteHeader(code int) {
    c.StatusCode = code
    c.ResponseWriter.WriteHeader(code)
}

// Write only appends b to Buf. Nothing in this code copies Buf to the wrapped
// ResponseWriter, so the handler's body never reaches the client.
func (c *CustomResponseWriter) Write(b []byte) (int, error) {
    return c.Buf.Write(b)
}

middleware:

// HeaderFilterMiddleware (the question's original attempt) is meant to force a
// 200 status when the request asks for it.
func HeaderFilterMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // Reads the flag from the request context under always200QueryKey
        // (defined outside this snippet). net/http does not put query
        // parameters into the context, so unless some earlier middleware (not
        // shown) stores the flag there, this is false; the query string itself
        // is available via r.URL.Query().
        always200, _ := r.Context().Value(always200QueryKey).(bool)
        crw := NewCustomResponseWriter(w)
        next.ServeHTTP(crw, r)

        // This runs after the handler has finished. If the handler already
        // called WriteHeader (forwarded immediately by CustomResponseWriter),
        // this override is a superfluous second call and has no effect.
        if always200 {
            crw.WriteHeader(http.StatusOK)
        }
    })
}

It always returns 404. Is there any way to fix that?


Solution

  • There are at least two issues in the code.

    1. always200, _ := r.Context().Value(always200QueryKey).(bool) is not the correct way to get the query parameter: query parameters are not stored in the request context. Read it with r.URL.Query().Get("_always200") and compare the value to the string "true".

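    2. It's too late to call crw.WriteHeader(http.StatusOK) after next.ServeHTTP(crw, r): by then the handler has usually already sent its status code (CustomResponseWriter.WriteHeader forwards it immediately), so the second call is ignored. When I tested the code, net/http printed this warning: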

      http: superfluous response.WriteHeader call from main.(*CustomResponseWriter).WriteHeader
      

    Here is an amended example that changes the status code of a response:

    package main
    
    import (
        "bytes"
        "log"
        "net/http"
        "strings"
    
        "github.com/go-chi/chi/v5"
    )
    
    // CustomResponseWriter wraps an http.ResponseWriter so a middleware can
    // override the status code the client receives while still recording the
    // status the wrapped handler asked for. Body bytes are forwarded to the
    // client as they are written.
    type CustomResponseWriter struct {
        http.ResponseWriter
        // Buf receives a copy of every body byte successfully forwarded to the
        // client. It may be set to nil to avoid retaining the body in memory.
        Buf                *bytes.Buffer
        // OriginalStatusCode is the final (non-1xx) status the handler
        // requested, explicitly via WriteHeader or implicitly (200) via Write.
        OriginalStatusCode int
        // WroteHeader reports whether a final status has been sent downstream.
        WroteHeader        bool
        // always200 forces every final status sent downstream to 200 OK.
        always200          bool
    }
    
    // NewCustomResponseWriter wraps w. When always200 is true, the final status
    // sent to the client is rewritten to 200 OK.
    func NewCustomResponseWriter(w http.ResponseWriter, always200 bool) *CustomResponseWriter {
        return &CustomResponseWriter{
            ResponseWriter: w,
            Buf:            new(bytes.Buffer),
            always200:      always200,
        }
    }
    
    // WriteHeader records code as OriginalStatusCode and forwards it to the
    // wrapped writer, substituting 200 OK when always200 is set. Only the first
    // final status is forwarded; later calls are dropped, so net/http never sees
    // a superfluous WriteHeader. Informational 1xx codes (except 101) are passed
    // through unchanged, because net/http allows any number of them before the
    // final status.
    func (c *CustomResponseWriter) WriteHeader(code int) {
        if c.WroteHeader {
            return
        }
        if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols {
            c.ResponseWriter.WriteHeader(code)
            return
        }
        c.WroteHeader = true
        c.OriginalStatusCode = code
        if c.always200 {
            code = http.StatusOK
        }
        c.ResponseWriter.WriteHeader(code)
    }
    
    // Write forwards b to the client, sending an implicit 200 OK first if no
    // final status has been written yet (mirroring net/http), and mirrors the
    // bytes actually written into Buf when Buf is non-nil.
    func (c *CustomResponseWriter) Write(b []byte) (int, error) {
        if !c.WroteHeader {
            c.WriteHeader(http.StatusOK)
        }
        n, err := c.ResponseWriter.Write(b)
        if c.Buf != nil {
            // bytes.Buffer.Write never returns a non-nil error.
            c.Buf.Write(b[:n])
        }
        return n, err
    }
    
    // Unwrap returns the wrapped writer so http.ResponseController can reach
    // optional interfaces (Flush, Hijack, deadlines) that this type hides.
    func (c *CustomResponseWriter) Unwrap() http.ResponseWriter {
        return c.ResponseWriter
    }
    
    // HeaderFilterMiddleware wraps next so that requests whose query string
    // carries _always200=true (compared case-insensitively) are served through a
    // CustomResponseWriter that substitutes 200 OK for the handler's status,
    // while the handler's own status is still recorded on that writer.
    func HeaderFilterMiddleware(next http.Handler) http.Handler {
        serve := func(w http.ResponseWriter, r *http.Request) {
            force := strings.EqualFold(r.URL.Query().Get("_always200"), "true")
            next.ServeHTTP(NewCustomResponseWriter(w, force), r)
        }
        return http.HandlerFunc(serve)
    }
    
    // main serves a demo route on :3000 behind HeaderFilterMiddleware; requests
    // carrying _always200=true have their final status rewritten to 200.
    func main() {
        // Named router (not r) so it does not share a name with the handler's
        // *http.Request parameter below.
        router := chi.NewRouter()
    
        router.Use(HeaderFilterMiddleware)
    
        router.Get("/", func(w http.ResponseWriter, r *http.Request) {
            w.Write([]byte("welcome"))
        })
    
        // ListenAndServe always returns a non-nil error (e.g. the port is in
        // use); report it instead of exiting silently.
        log.Fatal(http.ListenAndServe(":3000", router))
    }