I am confused about how arguments are passed when chaining tensor methods.
The code is shown below:
```python
def gather_feature(fmap, index):
    # fmap.shape: (B, k1, 1), index.shape: (B, k2)
    dim = fmap.size(-1)
    index = index.unsqueeze(len(index.shape)).expand(*index.shape, dim)  # this works
    fmap = fmap.gather(dim=1, index=index)  # out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    return fmap
```

Splitting the unsqueeze/expand line into two statements raises an error:

```python
def gather_feature(fmap, index):
    # fmap.shape: (B, k1, 1), index.shape: (B, k2)
    dim = fmap.size(-1)
    index = index.unsqueeze(len(index.shape))
    index = index.expand(*index.shape, dim)  # raises an error
    fmap = fmap.gather(dim=1, index=index)  # out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    return fmap
```
Once index.unsqueeze() has run, the shape of index becomes (B, k2, 1).
Unpacking that shape into expand() passes the sizes (B, k2, 1, dim), and an error is raised.
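A simplified, pure-Python sketch of the size rule that the PyTorch documentation describes for Tensor.expand (extra sizes add new dimensions at the front, existing dimensions are aligned from the right, only size-1 dimensions can grow, and -1 keeps a size) shows why the two sets of sizes behave differently. The helper expanded_shape is hypothetical, not part of PyTorch, and only computes shapes; B=2, k2=3 and dim=4 are arbitrary example values.

```python
def expanded_shape(shape, sizes):
    """Hypothetical, simplified re-implementation of the size rule documented
    for torch.Tensor.expand. Returns the resulting shape or raises ValueError."""
    shape, sizes = tuple(shape), tuple(sizes)
    if len(sizes) < len(shape):
        raise ValueError("number of sizes must be >= number of tensor dims")
    n_new = len(sizes) - len(shape)  # extra sizes become new leading dims
    result = []
    for i, size in enumerate(sizes):
        if i < n_new:
            # a new leading dimension: -1 is not allowed here
            if size == -1:
                raise ValueError("-1 is not allowed for a new leading dimension")
            result.append(size)
            continue
        old = shape[i - n_new]  # existing dims are matched from the right
        if size == -1:
            result.append(old)  # -1 keeps the existing size
        elif size == old or old == 1:
            result.append(size)  # equal size, or a singleton dim that grows
        else:
            raise ValueError(
                f"expanded size {size} must match existing size {old} "
                f"at non-singleton dimension {i}"
            )
    return tuple(result)


B, k2, dim = 2, 3, 4
one_line_sizes = (B, k2, dim)     # sizes expand() receives in the one-line version
two_line_sizes = (B, k2, 1, dim)  # sizes expand() receives in the two-line version
print(expanded_shape((B, k2, 1), one_line_sizes))
try:
    expanded_shape((B, k2, 1), two_line_sizes)
except ValueError as err:
    print("error:", err)
```

With four sizes for a three-dimensional tensor, B becomes a new leading dimension and k2 is compared against the existing B, so the check fails unless the sizes happen to line up.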
However, if the two methods are chained on one line, namely index.unsqueeze().expand(), the index.shape passed to expand() seems to be (B, k2).
Was index.shape computed and stored before .unsqueeze() ran, so that .unsqueeze() does not affect the index.shape passed to .expand()?
That is my guess, but I cannot think of another explanation.
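The guess can be checked without PyTorch. This sketch uses a hypothetical ShapeOnly class (not a PyTorch API) that stands in for a tensor: it holds only a shape and logs each time its shape is read or unsqueeze/expand is called. Returning a new object from unsqueeze mirrors how torch.Tensor.unsqueeze returns a new tensor rather than modifying index in place. B=2, k2=3 and dim=1 are arbitrary example values.

```python
class ShapeOnly:
    """Hypothetical stand-in for a tensor: holds only a shape tuple and
    records every shape read and method call, so evaluation order is visible."""
    calls = []  # log of (operation, shape of the object, arguments)

    def __init__(self, shape):
        self._shape = tuple(shape)

    @property
    def shape(self):
        ShapeOnly.calls.append(("shape", self._shape))
        return self._shape

    def unsqueeze(self, d):
        ShapeOnly.calls.append(("unsqueeze", self._shape, d))
        new_shape = list(self._shape)
        new_shape.insert(d, 1)
        return ShapeOnly(new_shape)  # a NEW object; self is left unchanged

    def expand(self, *sizes):
        ShapeOnly.calls.append(("expand", self._shape, sizes))
        return ShapeOnly(sizes)  # no size checking in this sketch


def run_one_line(B, k2, dim):
    ShapeOnly.calls = []
    index = ShapeOnly((B, k2))
    index = index.unsqueeze(len(index.shape)).expand(*index.shape, dim)
    return ShapeOnly.calls


def run_two_lines(B, k2, dim):
    ShapeOnly.calls = []
    index = ShapeOnly((B, k2))
    index = index.unsqueeze(len(index.shape))
    index = index.expand(*index.shape, dim)
    return ShapeOnly.calls


for entry in run_one_line(2, 3, 1):
    print("one line: ", entry)
for entry in run_two_lines(2, 3, 1):
    print("two lines:", entry)
```

In this sketch, the second read of index.shape in the one-line version happens after unsqueeze has been called but still returns (2, 3), because it is read from the original object: the name index is only rebound after the whole right-hand side has been evaluated.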
Thank you for your time.