x is the image, y is the label, and metadata holds dates, times, etc.
for x, y_true, metadata in train_loader:
    print(x.shape)
This prints:
torch.Size([16, 3, 448, 448])
How do I go about displaying x as an image? Do I use plt?
Your x is not a single image, but rather a batch of 16 different images, all of size 448x448 pixels.
You can use torchvision.utils.make_grid to convert x into a grid of 4x4 images, and then plot it:
import matplotlib.pyplot as plt
import torch
import torchvision

with torch.no_grad():  # no need for gradients here
    grid = torchvision.utils.make_grid(x, nrow=4)  # you might consider normalize=True
    # convert the grid into a numpy array suitable for plt
    grid_np = grid.cpu().numpy().transpose(1, 2, 0)  # channel dim should be last
plt.matshow(grid_np)
plt.show()
You can use matplotlib.pyplot as you say, but keep in mind that x is a 16x3xHxW tensor: you have to iterate over the first dimension and move the channel dimension last for each image. torch.transpose swaps two dimensions at a time, so it takes two calls to get HxWx3 (a single call swapping the first and last dimensions gives WxHx3, which is a transposed image). Alternatively, you can convert the tensor to numpy, whose transpose reorders all dimensions in one call.
for idx, im in enumerate(x):
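A minimal sketch of that per-image loop, not the answerer's original code. It assumes x is an (N, 3, H, W) float tensor with values in [0, 1]; to_hwc and show_batch are hypothetical helper names, and the random tensor stands in for one batch from train_loader. It uses permute, which reorders all dimensions in one call and is equivalent to the two transposes described above.

```python
import matplotlib.pyplot as plt
import torch


def to_hwc(img: torch.Tensor) -> torch.Tensor:
    """Reorder one image from (C, H, W) to (H, W, C), as matplotlib expects."""
    return img.detach().cpu().permute(1, 2, 0)


def show_batch(x: torch.Tensor, ncols: int = 4):
    """Draw every image of an (N, C, H, W) batch in its own subplot."""
    n = x.shape[0]
    nrows = (n + ncols - 1) // ncols  # ceiling division
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False
    )
    for ax in axes.flat:
        ax.axis("off")  # hide ticks, including on unused subplots
    for idx, im in enumerate(x):
        axes.flat[idx].imshow(to_hwc(im).numpy())
        axes.flat[idx].set_title(str(idx))
    return fig


# Hypothetical stand-in for one batch from train_loader, values in [0, 1].
x = torch.rand(16, 3, 448, 448)
fig = show_batch(x)
```

Call plt.show() afterwards to display the figure, or fig.savefig(...) to write it to disk.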