Tags: sql, go, reflection, refactoring, generalization

Generalizing *sql.Rows Scan in Go


I am developing a web API using Go and there is a lot of redundant database query scan code.

// WorkQuestions runs queries.WORK_QUESTIONS with cid as its only argument
// (presumably a contract ID) and returns one models.WorkQuestion per row.
//
// Each row is scanned, in column order, into ContractStateID, QuestionID,
// Question, ID, Answer and Compulsory. A nil slice is returned when the
// query yields no rows. On any query, scan or iteration error the slice is
// nil and the error is returned unchanged.
func (m *ContractModel) WorkQuestions(cid int) ([]models.WorkQuestion, error) {
    results, err := m.DB.Query(queries.WORK_QUESTIONS, cid)
    if err != nil {
        return nil, err
    }
    // Release the rows (and their connection) on every return path,
    // including the early return on a Scan error.
    defer results.Close()

    var workQuestions []models.WorkQuestion
    for results.Next() {
        var wq models.WorkQuestion
        if err := results.Scan(&wq.ContractStateID, &wq.QuestionID, &wq.Question, &wq.ID, &wq.Answer, &wq.Compulsory); err != nil {
            return nil, err
        }
        workQuestions = append(workQuestions, wq)
    }
    // Next returns false on an iteration error as well as at the end of the
    // result set; Err distinguishes the two.
    if err := results.Err(); err != nil {
        return nil, err
    }

    return workQuestions, nil
}

// Questions runs queries.QUESTIONS with cid as its only argument
// (presumably a contract ID) and returns one models.Question per row.
//
// Each row is scanned, in column order, into Question and Answer. A nil
// slice is returned when the query yields no rows. On any query, scan or
// iteration error the slice is nil and the error is returned unchanged.
func (m *ContractModel) Questions(cid int) ([]models.Question, error) {
    results, err := m.DB.Query(queries.QUESTIONS, cid)
    if err != nil {
        return nil, err
    }
    // Release the rows (and their connection) on every return path.
    defer results.Close()

    var questions []models.Question
    for results.Next() {
        var q models.Question
        if err := results.Scan(&q.Question, &q.Answer); err != nil {
            return nil, err
        }
        questions = append(questions, q)
    }
    // Next returns false on an iteration error as well as at the end of the
    // result set; Err distinguishes the two.
    if err := results.Err(); err != nil {
        return nil, err
    }

    return questions, nil
}

// Documents runs queries.DOCUMENTS with cid as its only argument
// (presumably a contract ID) and returns one models.Document per row.
//
// Each row is scanned, in column order, into Document, S3Region, S3Bucket
// and Source. A nil slice is returned when the query yields no rows. On any
// query, scan or iteration error the slice is nil and the error is returned
// unchanged.
func (m *ContractModel) Documents(cid int) ([]models.Document, error) {
    results, err := m.DB.Query(queries.DOCUMENTS, cid)
    if err != nil {
        return nil, err
    }
    // Release the rows (and their connection) on every return path.
    defer results.Close()

    var documents []models.Document
    for results.Next() {
        var d models.Document
        if err := results.Scan(&d.Document, &d.S3Region, &d.S3Bucket, &d.Source); err != nil {
            return nil, err
        }
        documents = append(documents, d)
    }
    // Next returns false on an iteration error as well as at the end of the
    // result set; Err distinguishes the two.
    if err := results.Err(); err != nil {
        return nil, err
    }

    return documents, nil
}

I need to generalize this code so that I can pass the resulting *sql.Rows to a function and get back a slice of structs containing the scanned rows. I know that the sqlx package has a StructScan method, but I cannot use it because a significant amount of my code is already written against the standard database/sql package.

Using the reflect package, I can create a generic StructScan function, but I cannot find a way to use reflect to create a slice of structs from a value passed as interface{}. What I need to achieve is something like the following:

// RowsToStructs scans every remaining row of rows into a new value of the
// struct type described by model and returns those values, in row order.
// Each element of the returned []interface{} holds a struct value (not a
// pointer).
//
// model may be a struct value or a pointer to a struct; only its type is
// used. The struct must have exactly one exported field per result column,
// declared in column order, with types that rows.Scan can convert to. The
// function panics if model is not a struct or a pointer to a struct, or if
// the struct has unexported fields. rows is not closed; the caller owns it.
// On a scan or iteration error the returned slice is nil.
func RowsToStructs(rows *sql.Rows, model interface{}) ([]interface{}, error) {
    // 1. Create a slice of structs from the passed struct type of model.
    //    Only the type is needed; a pointer to a struct is dereferenced.
    structType := reflect.TypeOf(model)
    if structType.Kind() == reflect.Ptr {
        structType = structType.Elem()
    }
    var structs []interface{}

    // Scan destinations, allocated once and refilled for every row.
    args := make([]interface{}, structType.NumField())

    // 2. Loop through each row.
    for rows.Next() {
        // 3. Create a struct of the type passed as model. reflect.New yields
        //    a pointer, so the struct it points to is addressable and its
        //    fields can be scanned into by pointer.
        rowv := reflect.New(structType).Elem()

        // 4./5. Scan the row directly into the struct fields so database/sql
        //       converts the driver values to the field types.
        for i := range args {
            args[i] = rowv.Field(i).Addr().Interface()
        }
        if err := rows.Scan(args...); err != nil {
            return nil, err
        }

        // 6. Add the struct created in step 3 to the result slice.
        structs = append(structs, rowv.Interface())
    }
    // Next also returns false when iteration fails; Err reports that case.
    if err := rows.Err(); err != nil {
        return nil, err
    }

    // 7. Return the struct slice.
    return structs, nil
}

I cannot find a way to use the reflect package to take the struct passed as the model parameter and create a slice of that struct type. Is there a workaround for this, or am I approaching the problem the wrong way?

The struct fields match the columns returned in the result: the same number of fields as columns, in the same order.


Solution

  • You can avoid using a type assertion in the calling function by passing a pointer to the destination slice as an argument. Here's RowsToStructs with that modification:

    // RowsToStructs scans every remaining row of rows into the slice pointed
    // to by dest, appending one element per row.
    //
    // dest must be a non-nil pointer to a slice whose elements are either
    // structs (*[]T) or pointers to structs (*[]*T). The struct must have one
    // exported field per column in the result set, declared in column order,
    // with types that rows.Scan can convert to.
    //
    // The function panics if dest is not as described above. It returns the
    // first error from rows.Scan, otherwise the result of rows.Err; on error,
    // dest may already contain the rows scanned before the failure. It does
    // not close rows; the caller remains responsible for that.
    func RowsToStructs(rows *sql.Rows, dest interface{}) error {
        // Elem() dereferences the pointer so the slice can be grown and
        // stored back through it. To build a fresh slice instead, use
        // reflect.MakeSlice(reflect.SliceOf(t), 0, 0) for the element type t.
        destv := reflect.ValueOf(dest).Elem()

        // The slice element type is either the struct itself or a pointer
        // to it; structType is the struct in both cases.
        elemType := destv.Type().Elem()
        elemIsPtr := elemType.Kind() == reflect.Ptr
        structType := elemType
        if elemIsPtr {
            structType = elemType.Elem()
        }

        // Allocate the Scan argument slice once; it is refilled with
        // pointers to the new struct's fields on every row.
        args := make([]interface{}, structType.NumField())

        for rows.Next() {
            // reflect.New returns a pointer, so the struct it points to is
            // addressable and its fields can be handed to Scan.
            rowp := reflect.New(structType)
            rowv := rowp.Elem()

            // Scan directly into the struct fields so database/sql handles
            // the conversion from driver values to the Go field types.
            for i := range args {
                args[i] = rowv.Field(i).Addr().Interface()
            }
            if err := rows.Scan(args...); err != nil {
                return err
            }

            if elemIsPtr {
                destv.Set(reflect.Append(destv, rowp))
            } else {
                destv.Set(reflect.Append(destv, rowv))
            }
        }
        // Next returns false both at the end of the result set and when
        // iteration fails; Err tells the two apart.
        return rows.Err()
    }
    

    Call it like this:

    // Documents runs queries.DOCUMENTS with cid as its only argument and
    // scans the rows into a slice of *models.Document via RowsToStructs.
    // On any error the returned slice is nil.
    func (m *ContractModel) Documents(cid int) ([]*models.Document, error) {
        results, err := m.DB.Query(queries.DOCUMENTS, cid)
        if err != nil {
            return nil, err
        }
        defer results.Close()

        var documents []*models.Document
        // Scope a new err to the if: the outer err is already declared, and
        // any partially scanned documents are discarded on failure.
        if err := RowsToStructs(results, &documents); err != nil {
            return nil, err
        }
        return documents, nil
    }
    

    Eliminate more boilerplate by moving the query to a helper function:

    // QueryToStructs executes q with args on db and scans the whole result
    // set into the slice pointed to by dest using RowsToStructs. The rows
    // are closed before it returns, whether or not scanning succeeded.
    func QueryToStructs(dest interface{}, db *sql.DB, q string, args ...interface{}) error {
        rs, qerr := db.Query(q, args...)
        if qerr != nil {
            return qerr
        }
        defer rs.Close() // release the connection even if scanning fails
        return RowsToStructs(rs, dest)
    }
    

    Call it like this:

    // Documents runs queries.DOCUMENTS with cid as its only argument and
    // returns the rows as a slice of *models.Document, using QueryToStructs
    // to execute the query and scan the result. On any error the returned
    // slice is nil.
    func (m *ContractModel) Documents(cid int) ([]*models.Document, error) {
        var documents []*models.Document
        if err := QueryToStructs(&documents, m.DB, queries.DOCUMENTS, cid); err != nil {
            return nil, err
        }
        return documents, nil
    }