godefer-keyword

Can I create a function that must only be used with defer?


For example:

package package

// CleanUp does the tear down.
// Dear user, CleanUp must only be used with defer: defer CleanUp()
func CleanUp() {
    // some logic to check if call was deferred
    // (the desired behaviour is to panic when it was not)
    // do tear down
}

And in userland code:

func main() {
    // Misuse: CleanUp is called directly instead of through defer.
    package.CleanUp() // PANIC, CleanUp must be deferred!
}

But all should be fine if user runs:

func main() {
   // Intended usage: CleanUp is scheduled with defer.
   defer package.CleanUp() // good job, no panic
}

Things I already tried:

// DeferCleanUp is a first attempt: it defers the tear down internally
// instead of requiring the caller to write defer.
func DeferCleanUp() {
    // The tear down runs when DeferCleanUp itself returns.
    defer func() { /* do tear down */ }()
    // But then I realized this was exactly the opposite of what I needed
    // user doesn't need to call defer CleanUp anymore but...
}
// Now, if the API is misused, it can cause problems too:
defer DeferCleanUp() // a defer inception xD, question remains.

Solution

  • Alright, per the OP's request and just for laughs, I'm posting this hacky approach, which solves this by looking at the call stack and applying some heuristics.

    DISCLAIMER: Do not use this in real code. I don't think checking whether a call was deferred is even a good idea.

    Also note: this approach will only work if the executable and the source are on the same machine, because it reads the caller's source file at runtime.

    Link to gist: https://gist.github.com/dvirsky/dfdfd4066c70e8391dc5 (this doesn't work in the playground because you can't read the source file there)

    package main
    
    import(
        "fmt"
        "runtime"
        "io/ioutil"
        "bytes"
        "strings"
    )
    
    
    
    
    // isDeferred reports whether the function that called it (for example
    // CleanUp) looks like it was invoked through a defer statement.
    //
    // This is a heuristic, not a guarantee: it locates the source line of the
    // caller's caller via runtime.Caller, reads that source file from disk and
    // applies a few textual rules to the line. It panics if the immediate
    // caller cannot be determined, if the source file cannot be read, or if the
    // reported line number does not exist in the file that was read. It returns
    // false when there is no frame two levels up.
    //
    // NOTE: the runtime.Caller skip counts below assume this function is called
    // directly by the function being checked; do not wrap it in helpers.
    func isDeferred() bool {

        // Let's get the caller's name first
        fn, _, _, ok := runtime.Caller(1)
        if !ok {
            panic("No caller")
        }
        caller := function(fn)

        // Let's peek 2 levels above this - the first level is this function,
        // The second is CleanUp()
        // The one we want is who called CleanUp()
        _, file, line, ok := runtime.Caller(2)
        if !ok {
            // not ok - means we were not called from at least 3 levels deep
            return false
        }

        // now we actually need to read the source file
        // This should be cached of course to avoid terrible performance
        // I copied this from runtime/debug, so it's a legitimate thing to do :)
        data, err := ioutil.ReadFile(file)
        if err != nil {
            panic("Could not read file")
        }

        // now let's read the exact line of the caller.
        // The file on disk may differ from the one that was compiled, so make
        // sure the reported line actually exists before indexing into it.
        lines := bytes.Split(data, []byte{'\n'})
        if line < 1 || line > len(lines) {
            panic(fmt.Sprintf("Line %d not found in %s", line, file))
        }
        lineText := strings.TrimSpace(string(lines[line-1]))
        fmt.Printf("Line text: '%s'\n", lineText)

        // Now let's apply some ugly rules of thumb. This is the fragile part
        // It can be improved with regex or actual AST parsing, but dude...
        return lineText == "}" || // on simple defer this is what we get
            !strings.Contains(lineText, caller) || // this handles the case of defer func() { CleanUp() }()
            strings.Contains(lineText, "defer ")
    }
    
    // CleanUp is intended to be run only via defer. It asks isDeferred whether
    // the current call looks deferred and panics with "Not Deferred!" when it
    // does not. Apart from that check it currently does no work.
    func CleanUp() {
        if isDeferred() {
            return
        }
        panic("Not Deferred!")
    }
    
    // fine defers CleanUp directly and then prints "Fine!".
    // This should not panic
    func fine() {
        // NOTE: isDeferred inspects the text of the source line it is given for
        // this function, so keep this statement and the closing brace on their
        // own lines.
        defer CleanUp() 
        
        fmt.Println("Fine!")
    }
    
    
    // alsoFine calls CleanUp from inside a deferred closure and then prints
    // "Also Fine!".
    // this should not panic as well
    func alsoFine() {
        // NOTE: keep the closure on a single line. isDeferred looks at the text
        // of the calling source line, and one of its rules accepts any line
        // that contains "defer ".
        defer func() { CleanUp() }()
        
        fmt.Println("Also Fine!")
    }
    
    // notFine calls CleanUp directly, without defer.
    // this should panic
    func notFine() {
        // A plain call: the line names CleanUp and has no "defer ", which is
        // the case isDeferred is meant to reject.
        CleanUp() 
        
        // Only reached if CleanUp did not panic.
        fmt.Println("Not Fine!")
    }
    
    // function maps a program counter to a short function name.
    // (Adapted from the standard library's runtime/debug.)
    //
    // It returns "" when the runtime has no function for pc. Otherwise it takes
    // the runtime's name for the function, drops everything up to and including
    // the last "/", then everything up to and including the first "." that
    // remains, and finally replaces every "·" (center dot) with ".".
    func function(pc uintptr) string {
        f := runtime.FuncForPC(pc)
        if f == nil {
            return ""
        }

        short := f.Name()
        // Strip the leading import-path directories, if any.
        if i := strings.LastIndex(short, "/"); i != -1 {
            short = short[i+1:]
        }
        // Strip the part before the first remaining period.
        if i := strings.Index(short, "."); i != -1 {
            short = short[i+1:]
        }
        return strings.Replace(short, "·", ".", -1)
    }
    
    // main runs the three demo functions in order: fine, alsoFine, notFine.
    func main() {
        demos := []func(){fine, alsoFine, notFine}
        for _, demo := range demos {
            demo()
        }
    }