python tensorflow pytorch array-broadcasting

broadcasting tensor matmul over batches


How can I compute the dot product of each batch's response with that same batch's X data?

y_yhat_allBatches_matmulX_allBatches = torch.matmul(yTrue_yHat_allBatches_tensorSub, interceptXY_data_allBatches[:, :, :-1])

The expected shape of y_yhat_allBatches_matmulX_allBatches is 2 by 5, where each row belongs to one batch.

yTrue_yHat_allBatches_tensorSub.shape = torch.Size([2, 15]), where each row is a batch (1 and 2) and the columns are the 15 response residuals (true response minus prediction).

interceptXY_data_allBatches[:, :, :-1].shape = torch.Size([2, 15, 5]), i.e. 15 observations by 5 features for each of the 2 batches.
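Written out, the quantity being asked for is the following, where r stands for yTrue_yHat_allBatches_tensorSub and X for interceptXY_data_allBatches[:, :, :-1] (these short symbols are introduced here only for readability):

$$
Z_{b,j} = \sum_{i=1}^{15} r_{b,i}\, X_{b,i,j}, \qquad b = 1, 2, \quad j = 1, \dots, 5
$$

So each row is Z_b = r_b^T X_b, a (1 by 15) times (15 by 5) product computed separately for each batch, which gives the 2 by 5 result.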

Full reproducible code:

import torch

#define dataset
nFeatures_withIntercept = 5
NObservations = 15
miniBatches = 2
interceptXY_data_allBatches = torch.randn(miniBatches, NObservations, nFeatures_withIntercept+1) #+1 Y(response variable)

#random assign beta to work with
beta_holder = torch.rand(nFeatures_withIntercept)

#y_predicted for each mini-batch
y_predBatchAllBatches = torch.matmul(interceptXY_data_allBatches[:, :, :-1], beta_holder)

#y_true - y_predicted for each mini-batch
yTrue_yHat_allBatches_tensorSub = torch.sub(interceptXY_data_allBatches[..., -1], y_predBatchAllBatches)
y_yhat_allBatches_matmulX_allBatches = torch.matmul(yTrue_yHat_allBatches_tensorSub, interceptXY_data_allBatches[:, :, :-1])
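To see what that last line actually computes, here is a small sketch with the question's sizes (the seed and the short names data, X, r, wrong and wanted are illustrative, not from the question). Because the residual tensor is 2-D, torch.matmul treats it as one (2, 15) matrix and broadcasts it against both batches of the (2, 15, 5) tensor, so every batch's residuals get multiplied with every batch's X. The per-batch products being asked for sit on the diagonal of the first two dimensions.

```python
import torch

torch.manual_seed(0)  # illustrative seed, not from the question

n_features, n_obs, n_batches = 5, 15, 2
data = torch.randn(n_batches, n_obs, n_features + 1)
X = data[:, :, :-1]                      # (2, 15, 5)
beta = torch.rand(n_features)
r = data[..., -1] - X @ beta             # residuals, (2, 15)

wrong = torch.matmul(r, X)               # the 2-D r is broadcast against both batches
print(wrong.shape)                       # torch.Size([2, 2, 5])

# wrong[b, a] == r[a] @ X[b]; the wanted rows are the entries with a == b
wanted = torch.diagonal(wrong, dim1=0, dim2=1).T   # diagonal gives (5, 2); transpose to (2, 5)
print(wanted.shape)                      # torch.Size([2, 5])
```

Extracting the diagonal works but wastes the off-diagonal cross-batch products, which is why a batched multiply is the better fix.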

Solution

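  • It looks like you have:

      • yTrue_yHat_allBatches_tensorSub with shape (2, 15)
      • interceptXY_data_allBatches[:, :, :-1] with shape (2, 15, 5)

    To multiply them into a result of shape (2, 5), first reshape the residual tensor to (2, 1, 15) with .unsqueeze(dim=1). Then torch.bmm() or the @ operator multiplies each (1, 15) row by the matching (15, 5) matrix, so (2, 1, 15) times (2, 15, 5) yields (2, 1, 5). Finally, .squeeze(dim=1) drops the singleton dimension and leaves (2, 5). Passing dim=1 explicitly keeps the batch dimension intact even when there is only one batch.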

    y_yhat_allBatches_matmulX_allBatches =\
        torch.bmm(yTrue_yHat_allBatches_tensorSub.unsqueeze(dim=1),
                  interceptXY_data_allBatches[:, :, :-1]
                 ).squeeze(dim=1)
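One way to check the bmm version is to compare it with an explicit Python loop over the batches. This is a sketch on freshly generated random data with the question's sizes; the names data, X, r, batched and looped are illustrative.

```python
import torch

torch.manual_seed(0)
data = torch.randn(2, 15, 6)              # 2 batches, 15 observations, 5 features + response
X = data[:, :, :-1]                       # (2, 15, 5)
r = data[..., -1] - X @ torch.rand(5)     # residuals, (2, 15)

# Batched version: (2, 1, 15) @ (2, 15, 5) -> (2, 1, 5) -> (2, 5)
batched = torch.bmm(r.unsqueeze(dim=1), X).squeeze(dim=1)

# Reference: one vector-matrix product per batch, stacked into rows
looped = torch.stack([r[b] @ X[b] for b in range(X.shape[0])])

print(batched.shape)                               # torch.Size([2, 5])
print(torch.allclose(batched, looped, atol=1e-5))  # compare up to float rounding
```

The small atol allows for the two code paths summing in a different order; the results should agree to float32 precision.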
    

    More compact notation using the @ operator:

    y_yhat_allBatches_matmulX_allBatches =\
        (yTrue_yHat_allBatches_tensorSub.unsqueeze(dim=1) @ interceptXY_data_allBatches[:, :, :-1]).squeeze(dim=1)
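The same (2, 5) result can also be obtained without adding and then removing a singleton dimension. The sketch below shows two alternatives (not what the answer above uses): torch.einsum with the subscripts "bi,bij->bj", which sums over the observation index i within each batch b, and plain broadcasting, which multiplies the residuals reshaped to (2, 15, 1) elementwise with X and sums over the observations. The data and names are illustrative.

```python
import torch

torch.manual_seed(0)
data = torch.randn(2, 15, 6)              # same layout as in the question
X = data[:, :, :-1]                       # (2, 15, 5)
r = data[..., -1] - X @ torch.rand(5)     # residuals, (2, 15)

# The answer's approach, for reference
via_matmul = (r.unsqueeze(dim=1) @ X).squeeze(dim=1)

# einsum: for each batch b and feature j, sum r[b, i] * X[b, i, j] over i
via_einsum = torch.einsum("bi,bij->bj", r, X)

# broadcasting: (2, 15, 1) * (2, 15, 5) -> (2, 15, 5), then sum over observations
via_broadcast = (r.unsqueeze(-1) * X).sum(dim=1)

print(via_einsum.shape, via_broadcast.shape)   # torch.Size([2, 5]) torch.Size([2, 5])
```

einsum spells out the contraction explicitly; the broadcasting version materializes an intermediate (2, 15, 5) tensor, which is harmless at this size but wasteful for large batches.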