metal metal-performance-shaders

MPSGraph assign variable operation is not updating value


I'm trying to understand the MPSGraph API. Why does the following code fail to change the value of var after the second run call?

#import <Foundation/Foundation.h>
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

void run(void);

// Minimal reproduction for the question: builds an MPSGraph containing a scalar
// constant (2.0) and a variable initialised to 3.0, runs the graph and logs both
// values, then adds an assign operation (var <- constant), runs the graph again
// and logs the value read back for var.
void run(void) {
    MPSGraph *graph = [MPSGraph new];

    // test1 becomes the constant tensor's value; test2 is the variable's
    // initial contents.
    float test1 = 2.0;
    float test2 = 3.0;
    MPSGraphTensor *constant = [graph constantWithScalar:test1
                                                dataType:MPSDataTypeFloat32];
    // The variable is created from the raw bytes of test2 as a one-element
    // Float32 tensor.
    MPSGraphTensor *var = [graph variableWithData:[NSData dataWithBytes:&test2 length:sizeof(test2)]
                                            shape:@[@1]
                                         dataType:MPSDataTypeFloat32
                                             name:@"var"];

    // First run: no feeds, no target operations; just fetch both tensors.
    MPSGraphTensorDataDictionary *result = [graph runWithFeeds:@{}
                                                 targetTensors:@[constant, var]
                                              targetOperations:NULL];

    // test3 is reused as the read-back buffer for every result below.
    // temp holds sizeof(float) and is passed by address as the strideBytes argument.
    // NOTE(review): strideBytes is presumably a per-dimension stride array, so a
    // single entry would match the one-element shape - confirm against the
    // MPSNDArray documentation.
    float test3;
    NSInteger temp = sizeof(test3);
    [result[var].mpsndarray readBytes:&test3
                          strideBytes:&temp];
    NSLog(@"%f", test3);

    [result[constant].mpsndarray readBytes:&test3
                               strideBytes:&temp];
    NSLog(@"%f", test3);
    
    // Add an operation that assigns the constant's value to the variable.
    MPSGraphOperation *op = [graph assignVariable:var
                                withValueOfTensor:constant
                                             name:NULL];
    
    // Second run: var is requested as a target tensor and the assign op as a
    // target operation in the SAME call. This is the call the question is about:
    // per the logged output below, the value read back for var is still 3.0.
    result = [graph runWithFeeds:@{}
                   targetTensors:@[var]
                targetOperations:@[op]];

    [result[var].mpsndarray readBytes:&test3
                          strideBytes:&temp];
    
    NSLog(@"%f", test3);
}

// Program entry point: executes the MPSGraph example once and exits with status 0.
int main(int argc, const char * argv[]) {
    // Command-line arguments are not used by this example.
    (void)argc;
    (void)argv;

    // Run the example inside an autorelease pool.
    @autoreleasepool {
        run();
    }

    return 0;
}

Output:

2022-09-27 13:26:35.708641-0400 MPSGraphExample[9643:2669233] Metal API Validation Enabled
2022-09-27 13:26:35.732058-0400 MPSGraphExample[9643:2669233] 3.000000
2022-09-27 13:26:35.732097-0400 MPSGraphExample[9643:2669233] 2.000000
2022-09-27 13:26:35.733821-0400 MPSGraphExample[9643:2669233] 3.000000
Program ended with exit code: 0

Solution

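  • It turns out that, within a single run call, the assign operation is applied after the target tensor is evaluated, so the value fetched for var in that same call is still the old one. Splitting the second run into two calls (first run only the operation, then fetch var), i.e. changing it to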

    [graph runWithFeeds:@{}
          targetTensors:@[]
       targetOperations:@[op]];
    
    result = [graph runWithFeeds:@{}
                   targetTensors:@[var]
                targetOperations:@[]];
    

    now shows the updated values.