I'm trying to understand the MPSGraph
api. Why does the following code failing to change the value of var
after the second run call?
#import <Foundation/Foundation.h>
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
void run(void);
void run(void) {
MPSGraph *graph = [MPSGraph new];
float test1 = 2.0;
float test2 = 3.0;
MPSGraphTensor *constant = [graph constantWithScalar:test1
dataType:MPSDataTypeFloat32];
MPSGraphTensor *var = [graph variableWithData:[NSData dataWithBytes:&test2 length:sizeof(test2)]
shape:@[@1]
dataType:MPSDataTypeFloat32
name:@"var"];
MPSGraphTensorDataDictionary *result = [graph runWithFeeds:@{}
targetTensors:@[constant, var]
targetOperations:NULL];
float test3;
NSInteger temp = sizeof(test3);
[result[var].mpsndarray readBytes:&test3
strideBytes:&temp];
NSLog(@"%f", test3);
[result[constant].mpsndarray readBytes:&test3
strideBytes:&temp];
NSLog(@"%f", test3);
MPSGraphOperation *op = [graph assignVariable:var
withValueOfTensor:constant
name:NULL];
result = [graph runWithFeeds:@{}
targetTensors:@[var]
targetOperations:@[op]];
[result[var].mpsndarray readBytes:&test3
strideBytes:&temp];
NSLog(@"%f", test3);
}
int main(int argc, const char * argv[]) {
@autoreleasepool {
run();
}
return 0;
}
Output:
2022-09-27 13:26:35.708641-0400 MPSGraphExample[9643:2669233] Metal API Validation Enabled
2022-09-27 13:26:35.732058-0400 MPSGraphExample[9643:2669233] 3.000000
2022-09-27 13:26:35.732097-0400 MPSGraphExample[9643:2669233] 2.000000
2022-09-27 13:26:35.733821-0400 MPSGraphExample[9643:2669233] 3.000000
Program ended with exit code: 0
It turns out that the operation is applied after the tensor is evaluated. Changing it to
[graph runWithFeeds:@{}
targetTensors:@[]
targetOperations:@[op]];
result = [graph runWithFeeds:@{}
targetTensors:@[var]
targetOperations:@[]];
now shows the updated values.