Different expressions for different outputs in Halide


I'm new to Halide, so I wasn't sure how to phrase this question. Let me explain. Suppose I have a simple Halide generator like this:

class Blur : public Generator<Blur>{
public:
    Input<Buffer<float>> in_func{"in_func", 2};
    Output<Buffer<float>> forward{"forward", 2};

    Var x, y, n;
    void generate(){

        Expr m1 = in_func(x+1, y+2)+in_func(x+2, y+1);
        Expr m2 = in_func(x+1, y+2)-in_func(x+2, y+1);
        Expr m3 = in_func(x+2, y+1)+in_func(x+1, y+1);
        Expr m4 = in_func(x+2, y+1)-in_func(x+1, y+1);
        Expr w0010_2 = -in_func(x+2, y+2)+in_func(x, y+2);
        Expr w0111_2 = -in_func(x+3, y+2)+in_func(x+1, y+2);

        forward(0,0) = w0010_2+m4+m3+m2+m1;
        forward(1,0) = -w0111_2+m4+m3-m2-m1; 
        forward(0,1) = w0010_2-m4+m3-m2+m1;
        forward(1,1) = w0111_2-m4+m3+m2-m1;
    }
};

What I want is for the output at index (0,0) to be the result of one expression, for example m1 + m2, and the output at index (1,0) to be the result of a different expression, for example m1 - m2. I would be really grateful for any help.


Solution

  • What I want to achieve is to define that output at index (0,0) should be the result of m1 + m2 but output at index (1,0) should be the result of different expression, for example, m1 - m2. [...] I want result[0][0] = expression1, result[0][1] = expression2, result[1][0] = expression3 and result[1][1] = expression4. But also result[0][2], result[0][4] and so on = expression1

    Compute x % 2 and y % 2 and test their values in a select:

    forward(x, y) = select(
      x % 2 == 0 && y % 2 == 0, m1 + m2,
      x % 2 == 1 && y % 2 == 0, m1 - m2,
      x % 2 == 0 && y % 2 == 1, expr3,
      /* otherwise, */ expr4
    );
    

    select is a pure if-then-else: it evaluates all of its arguments and then returns the value paired with the first true predicate. Every output point therefore pays for all four expressions, not just the one it keeps. If the expressions all read nearby points of in_func, this might not be too slow.
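To make the semantics concrete, here is a minimal scalar sketch in plain C++ (not Halide) of what the select above computes at a single point. It assumes the four expressions are the ones from the question's generator, assigned by the parity of (x, y). The Image struct and the names forward_ref, e00, e10, e01 and e11 are hypothetical helpers introduced only for this illustration.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical plain C++ stand-in for the Halide input buffer (row-major).
struct Image {
    int width, height;
    std::vector<float> data;
    Image(int w, int h)
        : width(w), height(h), data(static_cast<std::size_t>(w) * h, 0.0f) {}
    float& at(int x, int y) { return data[static_cast<std::size_t>(y) * width + x]; }
    float at(int x, int y) const { return data[static_cast<std::size_t>(y) * width + x]; }
};

// Scalar reference for forward(x, y) = select(...), valid for x, y >= 0.
// The caller must ensure x + 3 < in.width and y + 2 < in.height.
float forward_ref(const Image& in, int x, int y) {
    // Shared sub-expressions, copied from the question.
    const float m1 = in.at(x + 1, y + 2) + in.at(x + 2, y + 1);
    const float m2 = in.at(x + 1, y + 2) - in.at(x + 2, y + 1);
    const float m3 = in.at(x + 2, y + 1) + in.at(x + 1, y + 1);
    const float m4 = in.at(x + 2, y + 1) - in.at(x + 1, y + 1);
    const float w0010_2 = -in.at(x + 2, y + 2) + in.at(x, y + 2);
    const float w0111_2 = -in.at(x + 3, y + 2) + in.at(x + 1, y + 2);

    // Like select, evaluate every candidate value first...
    const float e00 = w0010_2 + m4 + m3 + m2 + m1;   // x even, y even
    const float e10 = -w0111_2 + m4 + m3 - m2 - m1;  // x odd,  y even
    const float e01 = w0010_2 - m4 + m3 - m2 + m1;   // x even, y odd
    const float e11 = w0111_2 - m4 + m3 + m2 - m1;   // x odd,  y odd

    // ...then return the one paired with the first true predicate.
    const bool x_even = (x % 2 == 0);
    const bool y_even = (y % 2 == 0);
    if (x_even && y_even) return e00;
    if (!x_even && y_even) return e10;
    if (x_even && !y_even) return e01;
    return e11;  // otherwise
}
```

A function like this can serve as a ground truth: fill a small input, run the Halide pipeline on the same data, and compare the two outputs point by point.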

    If you find that performance suffers, I'd try creating four Funcs, one for each of the four expressions, and have the select load from those. If that's still too slow, you might be able to optimize the indexing so that no extra points are computed. If you show all four expressions, I might be able to help you do that.
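As a rough picture of what "not computing any extra points" means, the following plain C++ sketch (again, not Halide) walks the output in 2x2 tiles and evaluates only the one expression that belongs to each point. It assumes the four expressions from the question's generator and an even output width and height. The function name forward_tiled and its parameters are hypothetical, chosen just for this example.

```cpp
#include <cstddef>
#include <vector>

// Fills a width x height output (row-major), visiting it in 2x2 tiles so that
// each point evaluates only the expression assigned to its parity.
// `in` is row-major with row stride `in_width` and must cover at least
// (width + 3) x (height + 2) samples. width and height are assumed even.
std::vector<float> forward_tiled(const std::vector<float>& in, int in_width,
                                 int width, int height) {
    auto at = [&](int x, int y) {
        return in[static_cast<std::size_t>(y) * in_width + x];
    };
    std::vector<float> out(static_cast<std::size_t>(width) * height);
    auto put = [&](int x, int y, float v) {
        out[static_cast<std::size_t>(y) * width + x] = v;
    };

    // Sub-expressions from the question, evaluated at (x, y).
    auto m1 = [&](int x, int y) { return at(x + 1, y + 2) + at(x + 2, y + 1); };
    auto m2 = [&](int x, int y) { return at(x + 1, y + 2) - at(x + 2, y + 1); };
    auto m3 = [&](int x, int y) { return at(x + 2, y + 1) + at(x + 1, y + 1); };
    auto m4 = [&](int x, int y) { return at(x + 2, y + 1) - at(x + 1, y + 1); };
    auto w0010_2 = [&](int x, int y) { return -at(x + 2, y + 2) + at(x, y + 2); };
    auto w0111_2 = [&](int x, int y) { return -at(x + 3, y + 2) + at(x + 1, y + 2); };

    for (int ty = 0; ty < height; ty += 2) {
        for (int tx = 0; tx < width; tx += 2) {
            int x = tx, y = ty;          // x even, y even
            put(x, y, w0010_2(x, y) + m4(x, y) + m3(x, y) + m2(x, y) + m1(x, y));
            x = tx + 1;                  // x odd, y even
            put(x, y, -w0111_2(x, y) + m4(x, y) + m3(x, y) - m2(x, y) - m1(x, y));
            x = tx; y = ty + 1;          // x even, y odd
            put(x, y, w0010_2(x, y) - m4(x, y) + m3(x, y) - m2(x, y) + m1(x, y));
            x = tx + 1;                  // x odd, y odd
            put(x, y, w0111_2(x, y) - m4(x, y) + m3(x, y) + m2(x, y) - m1(x, y));
        }
    }
    return out;
}
```

Compared with the select form, each output point here costs one expression instead of four. Whether a Halide schedule can reach the same amount of work depends on the actual expressions, which is why seeing all four matters.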