0

As shown in the PyTorch seq2seq tutorial, the code for the encoder RNN model is like this:

class EncoderRNN(nn.Module):
    """Embed one token id and advance a single-layer GRU by one step."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Maps a token id to a dense vector of length hidden_size.
        self.embedding = nn.Embedding(input_size, hidden_size)
        # GRU whose input and hidden vectors share the same size.
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        # Reshape the (1, hidden_size) embedding into the 3-D layout
        # (seq_len=1, batch=1, input_size) that nn.GRU expects.
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def initHidden(self):
        # Zero-filled initial state, shape (num_layers=1, batch=1, hidden_size).
        return torch.zeros(1, 1, self.hidden_size, device=device)

My question is: what is the reason for the use of the view function on the output of the embedding layer?

1 Answer 1

2

The view function adds an extra dimension to the given input so that its shape matches the input shape the GRU expects. In the function initHidden, the hidden state is initialized with shape (1, 1, 256).

# Quoted from EncoderRNN above: builds the initial hidden state with an
# explicit 3-D shape (num_layers=1, batch=1, hidden_size) — the same
# rank the GRU requires for its input. (`device` is a module-level
# global defined elsewhere in the tutorial.)
def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

Based on the documentation, the GRU input must have 3 dimensions: input of shape (seq_len, batch, input_size).

https://pytorch.org/docs/stable/generated/torch.nn.GRU.html

The shape of self.embedding(input) is (1, 256) and a sample output is,

tensor([[ 0.1421,  0.4135, -1.0619,  0.0149,  0.0673, -0.3770,  0.4231,  2.2803,
         -1.6939, -0.0071,  1.1131, -1.0019,  0.6593,  0.1366,  1.1033, -0.8804,
          1.3676,  0.4115, -0.5671,  0.3314, -0.2599, -0.3082,  1.3644,  0.5788,
         -0.1929, -2.0505,  0.4518,  0.8757, -0.2360, -0.4099, -0.5697, -1.5973,
         -0.6638, -1.1523,  1.4425,  1.3651,  1.9371,  0.5698, -0.3541, -1.3883,
         -0.0195, -1.0757, -1.4324, -1.6226, -2.4267,  0.3874, -0.7529,  1.4938,
         -2.5773, -1.1962,  0.3759, -0.6143, -1.0444, -0.6443, -0.8130, -1.7283,
          1.4167,  1.3945, -1.2695,  0.7289,  0.7777, -0.0094, -1.8108,  0.2126,
         -0.2018, -0.4055, -0.7779, -0.8523,  0.0162,  0.2463,  0.5588, -0.7250,
         -0.0128,  0.6272, -0.7729,  0.4259,  0.7596, -1.9500,  0.5853,  0.3764,
         -0.1112,  0.7274, -2.8535, -0.0445,  0.4225,  1.2179,  0.2219, -0.7064,
         -0.9654,  1.0501,  1.7142,  0.5312, -0.8180, -1.5697,  1.3062, -0.9321,
         -0.1652, -1.5298, -0.3575, -1.2046, -0.6571, -0.7689, -0.7032,  1.0727,
         -1.3259,  0.1200,  1.9357, -0.2519, -0.3717,  0.8054,  0.1180, -0.6921,
          1.0245, -1.5500, -0.5280, -0.7462,  0.7924,  2.2701, -1.5094, -0.1973,
         -1.5919,  0.4869,  0.6739, -0.5242,  0.2559, -0.0149, -0.5332, -1.8313,
          0.3598,  0.0804, -0.0780, -0.2930, -0.2844, -0.4752, -0.9919,  0.1809,
          0.7622, -2.5069, -0.7724, -0.9441,  1.6101,  0.6461, -0.8932,  0.0600,
          0.6911,  0.5191, -0.1719, -0.5829, -0.9168,  1.5282,  1.4399,  0.3264,
         -0.8894,  0.2880, -0.0697,  0.8977, -0.5004,  0.3844,  0.0925,  0.5592,
         -0.1664,  0.8575, -1.0348,  0.7326, -0.2124,  0.7533,  0.6270, -0.9559,
         -1.4159,  0.6788,  0.6163, -0.5951, -0.1403, -1.6088, -0.7731,  0.3876,
          1.0429, -2.0960,  0.1726,  1.7446, -0.3963,  0.0785, -0.4701,  1.0074,
          0.3319, -2.2675, -1.6163, -0.4003, -0.5468,  0.0452, -2.5586,  0.4747,
         -0.0271, -1.2161,  1.2121,  1.8738, -1.2207, -0.9218, -0.1430,  0.2512,
         -0.5236, -0.2544, -0.5868, -0.7086, -1.3328, -0.0243,  0.4759,  1.4125,
          0.4947,  0.5054,  1.6253,  0.4198, -0.9150,  0.6374,  0.4581,  1.1527,
          1.4440, -0.0590, -0.4601,  0.2490, -0.5739,  0.6798, -0.2156, -1.1386,
         -0.5011, -0.7411,  0.2825, -0.2595,  0.8070,  0.5270,  0.2595, -0.1089,
          0.4221, -0.7851,  0.7112, -0.3038,  0.6169, -0.1513, -0.5872,  0.3974,
          0.2431,  0.4934, -0.9406, -0.9372,  1.4525,  0.1376,  0.2558,  0.0661,
          0.3509,  2.1667,  2.8428,  0.9429, -0.6143, -1.0969,  0.0955,  0.0914]],
       device='cuda:0', grad_fn=<EmbeddingBackward>)

The shape of self.embedding(input).view(1, 1, -1) is (1, 1, 256) and a sample output is,

tensor([[[ 0.1421,  0.4135, -1.0619,  0.0149,  0.0673, -0.3770,  0.4231,
           2.2803, -1.6939, -0.0071,  1.1131, -1.0019,  0.6593,  0.1366,
           1.1033, -0.8804,  1.3676,  0.4115, -0.5671,  0.3314, -0.2599,
          -0.3082,  1.3644,  0.5788, -0.1929, -2.0505,  0.4518,  0.8757,
          -0.2360, -0.4099, -0.5697, -1.5973, -0.6638, -1.1523,  1.4425,
           1.3651,  1.9371,  0.5698, -0.3541, -1.3883, -0.0195, -1.0757,
          -1.4324, -1.6226, -2.4267,  0.3874, -0.7529,  1.4938, -2.5773,
          -1.1962,  0.3759, -0.6143, -1.0444, -0.6443, -0.8130, -1.7283,
           1.4167,  1.3945, -1.2695,  0.7289,  0.7777, -0.0094, -1.8108,
           0.2126, -0.2018, -0.4055, -0.7779, -0.8523,  0.0162,  0.2463,
           0.5588, -0.7250, -0.0128,  0.6272, -0.7729,  0.4259,  0.7596,
          -1.9500,  0.5853,  0.3764, -0.1112,  0.7274, -2.8535, -0.0445,
           0.4225,  1.2179,  0.2219, -0.7064, -0.9654,  1.0501,  1.7142,
           0.5312, -0.8180, -1.5697,  1.3062, -0.9321, -0.1652, -1.5298,
          -0.3575, -1.2046, -0.6571, -0.7689, -0.7032,  1.0727, -1.3259,
           0.1200,  1.9357, -0.2519, -0.3717,  0.8054,  0.1180, -0.6921,
           1.0245, -1.5500, -0.5280, -0.7462,  0.7924,  2.2701, -1.5094,
          -0.1973, -1.5919,  0.4869,  0.6739, -0.5242,  0.2559, -0.0149,
          -0.5332, -1.8313,  0.3598,  0.0804, -0.0780, -0.2930, -0.2844,
          -0.4752, -0.9919,  0.1809,  0.7622, -2.5069, -0.7724, -0.9441,
           1.6101,  0.6461, -0.8932,  0.0600,  0.6911,  0.5191, -0.1719,
          -0.5829, -0.9168,  1.5282,  1.4399,  0.3264, -0.8894,  0.2880,
          -0.0697,  0.8977, -0.5004,  0.3844,  0.0925,  0.5592, -0.1664,
           0.8575, -1.0348,  0.7326, -0.2124,  0.7533,  0.6270, -0.9559,
          -1.4159,  0.6788,  0.6163, -0.5951, -0.1403, -1.6088, -0.7731,
           0.3876,  1.0429, -2.0960,  0.1726,  1.7446, -0.3963,  0.0785,
          -0.4701,  1.0074,  0.3319, -2.2675, -1.6163, -0.4003, -0.5468,
           0.0452, -2.5586,  0.4747, -0.0271, -1.2161,  1.2121,  1.8738,
          -1.2207, -0.9218, -0.1430,  0.2512, -0.5236, -0.2544, -0.5868,
          -0.7086, -1.3328, -0.0243,  0.4759,  1.4125,  0.4947,  0.5054,
           1.6253,  0.4198, -0.9150,  0.6374,  0.4581,  1.1527,  1.4440,
          -0.0590, -0.4601,  0.2490, -0.5739,  0.6798, -0.2156, -1.1386,
          -0.5011, -0.7411,  0.2825, -0.2595,  0.8070,  0.5270,  0.2595,
          -0.1089,  0.4221, -0.7851,  0.7112, -0.3038,  0.6169, -0.1513,
          -0.5872,  0.3974,  0.2431,  0.4934, -0.9406, -0.9372,  1.4525,
           0.1376,  0.2558,  0.0661,  0.3509,  2.1667,  2.8428,  0.9429,
          -0.6143, -1.0969,  0.0955,  0.0914]]], device='cuda:0',
       grad_fn=<ViewBackward>)

Code

This code works,

# One-layer GRU: input_size=256, hidden_size=128.
rnn1 = nn.GRU(input_size=256, hidden_size=128, num_layers=1)
# Batched 3-D input: (seq_len=100, batch=2, input_size=256).
input1 = torch.randn(100, 2, 256)
# Initial hidden state: (num_layers=1, batch=2, hidden_size=128).
h01 = torch.randn(1, 2, 128)
output1, hn1 = rnn1(input1, h01)
print(input1.shape, h01.shape)
print(output1.shape, hn1.shape)

Output

torch.Size([100, 2, 256]) torch.Size([1, 2, 128])
torch.Size([100, 2, 128]) torch.Size([1, 2, 128])

Code

This code also works,

# GRU whose input and hidden sizes are both 256, mirroring EncoderRNN.
rnn1 = nn.GRU(input_size=256, hidden_size=256)
# Minimal valid 3-D shapes: one timestep, one batch element.
input1 = torch.randn(1, 1, 256)
h01 = torch.randn(1, 1, 256)
output1, hn1 = rnn1(input1, h01)
print(input1.shape, h01.shape)
print(output1.shape, hn1.shape)

Output

torch.Size([1, 1, 256]) torch.Size([1, 1, 256])
torch.Size([1, 1, 256]) torch.Size([1, 1, 256])

Code

This does not work,

# Deliberate failure case: input1 is 2-D, but together with the 3-D h01 the
# call does not satisfy the GRU's expected (seq_len, batch, input_size) /
# (num_layers, batch, hidden_size) shapes, so the forward call raises a
# RuntimeError and the prints are never reached. Uncommenting the view
# line reshapes input1 to (1, 1, 256) and fixes it.
# NOTE(review): the exact error message varies across PyTorch versions
# (newer releases accept 2-D unbatched input but then require a 2-D hidden
# state) — confirm against the installed version.
rnn1 = nn.GRU(256, 256)
input1 = torch.randn(1, 256)
#input1 = input1.view(1, 1, -1)
h01 = torch.randn(1, 1, 256)
output1, hn1 = rnn1(input1, h01)
print(input1.shape, h01.shape)
print(output1.shape, hn1.shape)

Output

RuntimeError: input must have 3 dimensions, got 2
Sign up to request clarification or add additional context in comments.

Comments

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.