I tried to find the answer but couldn't.

I made a custom deep learning model using PyTorch, for example:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.nn_layers = nn.ModuleList()
        self.layer = nn.Linear(2,3).double()
        torch.nn.init.xavier_normal_(self.layer.weight)

        self.bias = torch.nn.Parameter(torch.randn(3))

        self.nn_layers.append(self.layer)

    def forward(self, x):
        activation = torch.tanh
        output = activation(self.layer(x)) + self.bias

        return output

If I print the parameters with

model = Net()
print(list(model.parameters()))

the output does not contain model.bias, so optimizer = torch.optim.Adam(model.parameters()) does not update model.bias. How can I get around this? Thanks!
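One way to check what the model actually registered, and whether Adam changes bias, is sketched below. It copies the Net from the question so the snippet runs on its own; the random float64 input batch and the summed output used as a loss are assumptions made only for this check. If bias appears in names and bias_updated is true, the parameter is registered and being trained.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    # Same structure as the model in the question.
    def __init__(self):
        super().__init__()
        self.nn_layers = nn.ModuleList()
        self.layer = nn.Linear(2, 3).double()
        torch.nn.init.xavier_normal_(self.layer.weight)
        self.bias = torch.nn.Parameter(torch.randn(3))
        self.nn_layers.append(self.layer)

    def forward(self, x):
        return torch.tanh(self.layer(x)) + self.bias


model = Net()

# named_parameters() lists every registered parameter by name;
# the Linear layer is shared with nn_layers, so it is listed only once.
names = [name for name, _ in model.named_parameters()]
print(names)

# Take one optimizer step and check whether bias changed.
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
before = model.bias.detach().clone()

x = torch.randn(4, 2, dtype=torch.float64)  # hypothetical input batch
loss = model(x).sum()  # dummy loss, only used for this check
optimizer.zero_grad()
loss.backward()
optimizer.step()

bias_updated = not torch.equal(before, model.bias.detach())
print("bias updated:", bias_updated)
```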

If you are subclassing nn.Module, see the solution posted here: discuss.pytorch.org/t/…. Test with print(list(model.parameters())). Commented Jan 17, 2022 at 6:06

1 Answer

You need to register your parameters:

self.register_parameter(name='bias', param=torch.nn.Parameter(torch.randn(3)))
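A minimal sketch of that call inside a full module, assuming a simplified, hypothetical NetWithRegisteredBias (single Linear layer, no ModuleList) rather than the exact model from the question. After registration the tensor is reachable as self.bias and shows up in named_parameters():

```python
import torch
import torch.nn as nn


class NetWithRegisteredBias(nn.Module):
    # Hypothetical, simplified variant of Net that registers the bias explicitly.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 3)
        # register_parameter stores the tensor under the given name,
        # so it is reachable as self.bias afterwards.
        self.register_parameter(name="bias", param=nn.Parameter(torch.randn(3)))

    def forward(self, x):
        return torch.tanh(self.layer(x)) + self.bias


model = NetWithRegisteredBias()
param_names = [name for name, _ in model.named_parameters()]
print(param_names)
```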

Update:
In more recent versions of PyTorch, you no longer need to call register_parameter explicitly; it is enough to assign an nn.Parameter to an attribute of your nn.Module to "notify" PyTorch that this variable should be treated as a trainable parameter:

self.bias = torch.nn.Parameter(torch.randn(3))
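The sketch below contrasts the two kinds of attribute in a hypothetical AttrDemo module: the nn.Parameter is registered automatically on assignment, while a plain tensor (even one with requires_grad=True) is stored as an ordinary attribute and is invisible to parameters() and state_dict():

```python
import torch
import torch.nn as nn


class AttrDemo(nn.Module):
    # Hypothetical module contrasting a Parameter with a plain tensor.
    def __init__(self):
        super().__init__()
        # nn.Module.__setattr__ sees an nn.Parameter and registers it.
        self.bias = nn.Parameter(torch.randn(3))
        # A plain tensor attribute is NOT registered as a parameter,
        # so an optimizer built from parameters() never sees it.
        self.offset = torch.randn(3, requires_grad=True)


demo = AttrDemo()
names = [name for name, _ in demo.named_parameters()]
offset_saved = "offset" in demo.state_dict()
print(names, offset_saved)
```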

Please note that if you want more complex data structures of parameters (e.g., lists), you should use dedicated containers like torch.nn.ParameterList or torch.nn.ParameterDict.
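A small sketch of the difference, using a hypothetical ListDemo module: parameters kept in a plain Python list are not registered, whereas the ones inside nn.ParameterList and nn.ParameterDict are, under index and key names respectively:

```python
import torch
import torch.nn as nn


class ListDemo(nn.Module):
    # Hypothetical module comparing a plain list with the dedicated containers.
    def __init__(self):
        super().__init__()
        # Parameters hidden inside a plain Python list are not registered.
        self.plain = [nn.Parameter(torch.randn(3)) for _ in range(2)]
        # nn.ParameterList registers each element under its index.
        self.tracked = nn.ParameterList(
            [nn.Parameter(torch.randn(3)) for _ in range(2)]
        )
        # nn.ParameterDict registers each element under its string key.
        self.named = nn.ParameterDict({"scale": nn.Parameter(torch.ones(3))})


demo = ListDemo()
names = [name for name, _ in demo.named_parameters()]
print(names)
```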

4 Comments

Thank you! I think self.register_parameter(name='bias', param=torch.nn.Parameter(self.bias)) does not work. Is there a way to get rid of the 'name' part? I think it will be troublesome when saving and loading the model.
@CSH By using name='bias' when registering the parameter, you implicitly create self.bias to be used in your Net class. You do not need to explicitly assign self.bias for it to exist.
What a convenient function! Thank you for your help!
It should be torch.nn.parameter.Parameter now.
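On the save/load concern raised in the comments: the name passed to register_parameter becomes the key in state_dict(), so the parameter is saved and restored like any other. A minimal sketch, assuming a hypothetical Biased module and an in-memory io.BytesIO buffer in place of a file on disk:

```python
import io

import torch
import torch.nn as nn


class Biased(nn.Module):
    # Hypothetical module; "bias" is both the attribute name (self.bias)
    # and the key under which the tensor appears in state_dict().
    def __init__(self):
        super().__init__()
        self.register_parameter(name="bias", param=nn.Parameter(torch.randn(3)))

    def forward(self, x):
        return x + self.bias


src = Biased()
buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)  # save only the parameters
buffer.seek(0)

dst = Biased()  # fresh instance with different random values
dst.load_state_dict(torch.load(buffer))
keys = list(src.state_dict().keys())
restored = torch.equal(src.bias, dst.bias)
print(keys, restored)
```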