Tags: go, protocol-buffers, grpc, grpc-gateway

Extend grpc-gateway generated functions with custom decode logic in proto.Message


At work we needed to apply strings.TrimSpace to every field of protobuf type string.

We decided to write a plugin that generates two methods for each message: trimAll, which trims every field of type string, and UnmarshalJSON, which we assumed grpc-gateway calls when translating http.Request.Body into a proto.Message.

func (m *GetLogbookCall_Request) trimAll() {
    if m == nil {
        return
    }
    m.Id = strings.TrimSpace(m.Id)
}

func (m *GetLogbookCall_Request) UnmarshalJSON(data []byte) error {
    err := proto.Unmarshal(data, m)
    if err != nil {
        return err
    }

    m.trimAll()

    return nil
}

But this did not work.

The question is: Is there any way to add custom logic when translating http.Request.Body to proto.Message?

P.S. Here is one of the functions generated by the grpc-gateway plugin:

func request_AccountService_UpdatePassword_0(ctx context.Context, marshaler runtime.Marshaler, client AccountServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
    var protoReq account.UpdatePasswordCall_Request
    var metadata runtime.ServerMetadata

    if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && err != io.EOF {
        return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
    }

    msg, err := client.UpdatePassword(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
    return msg, metadata, err
}

Solution

  • I found a solution to my problem: write a custom Marshaler whose decoder runs the extra logic after decoding, and use it when creating the gateway's ServeMux (the proxy server).

    type Marshaler struct {
        runtime.JSONPb
    }
    
    func (m *Marshaler) NewDecoder(r io.Reader) runtime.Decoder {
        d := json.NewDecoder(r)
        return Decoder{
            decoder: &runtime.DecoderWrapper{
                Decoder:          d,
                UnmarshalOptions: m.UnmarshalOptions,
            },
        }
    }
    
    type Decoder struct {
        decoder *runtime.DecoderWrapper
    }
    
    func (d Decoder) Decode(v interface{}) error {
        err := d.decoder.Decode(v)
        if err != nil {
            return err
        }
        
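        // Custom post-decode logic goes here. In my case it calls TrimAll.
        // The method must be exported (TrimAll, not trimAll) if the messages
        // live in a different package from this decoder.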
    
        type Trimmable interface {
            TrimAll()
        }
    
        if v, ok := v.(Trimmable); ok {
            v.TrimAll()
        }
    
        return nil
    }
    
    
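    // Register the custom marshaler when creating the gateway mux.
    mux := runtime.NewServeMux(
        runtime.WithMetadata(func(ctx context.Context, r *http.Request) metadata.MD {
            // Body omitted in the original post; return the metadata to attach here.
            return nil
        }),
        runtime.WithMarshalerOption(runtime.MIMEWildcard, &Marshaler{
            JSONPb: runtime.JSONPb{
                MarshalOptions: protojson.MarshalOptions{
                    EmitUnpopulated: true,
                },
            },
        }),
    )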
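
To try the pattern without a running gateway, the same "decode, then post-process" wrapper can be sketched with the standard library alone. This is only an illustration under assumptions: encoding/json stands in for the gateway's JSONPb decoder, the local Decoder interface merely mirrors the one-method shape of runtime.Decoder, and request, trimmingDecoder, newTrimmingDecoder and decodeBody are hypothetical names that do not appear in the original.

```go
package main

import (
	"encoding/json"
	"fmt"
	"io"
	"strings"
)

// Decoder mirrors the single-method shape of grpc-gateway's runtime.Decoder
// (declared locally so the sketch has no external dependencies).
type Decoder interface {
	Decode(v interface{}) error
}

// Trimmable is the hook a message implements to opt in to trimming.
type Trimmable interface {
	TrimAll()
}

// trimmingDecoder wraps another Decoder and calls TrimAll after a
// successful decode, if the target implements Trimmable.
type trimmingDecoder struct {
	inner Decoder
}

func (d trimmingDecoder) Decode(v interface{}) error {
	if err := d.inner.Decode(v); err != nil {
		return err
	}
	if t, ok := v.(Trimmable); ok {
		t.TrimAll()
	}
	return nil
}

// newTrimmingDecoder uses encoding/json as a stand-in for the gateway decoder.
func newTrimmingDecoder(r io.Reader) Decoder {
	return trimmingDecoder{inner: json.NewDecoder(r)}
}

// request is a hypothetical stand-in for a generated proto message.
type request struct {
	ID string `json:"id"`
}

// TrimAll trims every string field of the message.
func (m *request) TrimAll() {
	if m == nil {
		return
	}
	m.ID = strings.TrimSpace(m.ID)
}

// decodeBody decodes a JSON body through the trimming decoder.
func decodeBody(body string, v interface{}) error {
	return newTrimmingDecoder(strings.NewReader(body)).Decode(v)
}

func main() {
	var req request
	if err := decodeBody(`{"id": "  42  "}`, &req); err != nil {
		fmt.Println("decode error:", err)
		return
	}
	fmt.Printf("%q\n", req.ID) // prints "42"
}
```

Because the wrapper only checks for the Trimmable interface, messages without a TrimAll method pass through unchanged, so the same decoder can serve every endpoint.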