I would like to define a network that comprises many templates. Below, under Network definitions, is a simplified example where the first network definition is used as a template in the second one. This doesn't work: when I initialise my optimiser, it says that the network parameters are empty! How should I do this properly? The network that I ultimately want is very complicated.

Main Function

import torch.optim as optim

if __name__ == "__main__":
    myNet     = Network(nNets=3).cuda().train()  # nNets is required; 3 is just an example value
    optimizer = optim.SGD(myNet.parameters(), lr=0.01, momentum=0.9)

Network definitions:

import torch.nn as nn

class NetworkTemplate(nn.Module):

    def __init__(self):
        super(NetworkTemplate, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)

        return x

class Network(nn.Module):

    def __init__(self, nNets):
        super(Network, self).__init__()

        self.nets = []
        for curNet in range(nNets):
            self.nets.append(NetworkTemplate())

    def forward(self, x):

        for curNet in self.nets:
            x = curNet(x)

        return x

1 Answer

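Just use torch.nn.Sequential: after you have populated self.nets, wrap it with self.nets = torch.nn.Sequential(*self.nets), and then call return self.nets(x) in your forward function. The reason your version fails is that nn.Module only registers submodules that are assigned as attributes directly or held in a module container such as nn.Sequential or nn.ModuleList. Modules stored in a plain Python list are invisible to it, so myNet.parameters() comes back empty (and .cuda() does not move them either).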
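A minimal sketch of the nn.Sequential approach, under a few assumptions. As posted, NetworkTemplate maps 1 input channel to 3 output channels, so chaining two or more of them would fail with a channel mismatch in forward; the in_channels/out_channels constructor arguments below are an assumption added only to make the chain consistent. .cuda() is left out so the sketch also runs on a CPU-only machine.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class NetworkTemplate(nn.Module):
    # Channel counts are constructor arguments here (an assumption) so that
    # several templates can be chained: each block must accept the number of
    # channels the previous block produces.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn1(self.conv1(x))


class Network(nn.Module):
    def __init__(self, nNets):
        super().__init__()
        # First block maps 1 -> 3 channels, every later block keeps 3 channels.
        blocks = [NetworkTemplate(1 if i == 0 else 3, 3) for i in range(nNets)]
        # nn.Sequential registers every block as a submodule, so their
        # parameters show up in self.parameters(). Note the * (see comments).
        self.nets = nn.Sequential(*blocks)

    def forward(self, x):
        # nn.Sequential calls the blocks one after another.
        return self.nets(x)


if __name__ == "__main__":
    myNet = Network(nNets=3).train()
    optimizer = optim.SGD(myNet.parameters(), lr=0.01, momentum=0.9)
    out = myNet(torch.randn(2, 1, 8, 8))
    print(out.shape)  # torch.Size([2, 3, 8, 8])
```

Each block contributes three parameter tensors (the conv weight plus the BatchNorm weight and bias), so Network(nNets=3) exposes nine, and the optimiser no longer complains about an empty parameter list.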

If you want to do something more complicated, put all the networks into a torch.nn.ModuleList instead. A ModuleList registers its contents just like nn.Sequential does, but it has no forward of its own, so you have to call each network yourself in your forward method. In exchange, the wiring can be anything you like, not just a sequential chain.
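A sketch of the ModuleList variant, under stated assumptions: Block is a hypothetical stand-in for NetworkTemplate with equal input and output channels, and the residual connection and list of intermediate outputs in forward are invented examples of wiring that nn.Sequential cannot express. BrokenNetwork reproduces the original plain-list problem for comparison.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    # Hypothetical stand-in for NetworkTemplate with a fixed channel count.
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        return self.bn(self.conv(x))


class Network(nn.Module):
    def __init__(self, nNets, channels=3):
        super().__init__()
        # nn.ModuleList registers its contents like nn.Sequential does,
        # but leaves the wiring to forward().
        self.nets = nn.ModuleList([Block(channels) for _ in range(nNets)])

    def forward(self, x):
        # Example of non-sequential wiring (an assumption): a residual
        # connection around each block, plus every block's output collected.
        outputs = []
        for net in self.nets:
            x = x + net(x)
            outputs.append(x)
        return x, outputs


class BrokenNetwork(nn.Module):
    # Same blocks held in a plain Python list: nn.Module does not look inside
    # lists, so none of these parameters are registered.
    def __init__(self, nNets, channels=3):
        super().__init__()
        self.nets = [Block(channels) for _ in range(nNets)]


print(len(list(Network(2).parameters())))        # 6
print(len(list(BrokenNetwork(2).parameters())))  # 0
```

Calling Network(2)(torch.randn(2, 3, 8, 8)) returns the final tensor together with a list holding one output per block.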

4 Comments

Changing it to nn.Sequential now gives the error: list is not a Module subclass
My bad, I'd forgotten the * to unwrap self.nets when passing it to nn.Sequential. I've corrected the answer.
This works! Thank you very much indeed. I will up-vote your answer. Can you please explain the * operator? Is this a Python thing or a PyTorch thing?
That's plain Python, not PyTorch. Say you have a function my_function(a, b, c) with three arguments, and your arguments happen to be in a list my_list = ['hello', 'world', '!']. You could pass them as my_function(my_list[0], my_list[1], my_list[2]), unpacking the list yourself and passing each element individually. Or you use *: my_function(*my_list) does the same thing and is much easier!
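A small runnable version of the example from the comment above. my_function and my_list are the names used in the comment; the body of my_function (joining its arguments with spaces) is made up for illustration. The same mechanism is at work in nn.Sequential(*self.nets): the constructor receives each module as a separate argument rather than a single list, which is why leaving out the * produced the "list is not a Module subclass" error.

```python
def my_function(a, b, c):
    # Illustrative body: joins its three arguments with spaces.
    return " ".join([a, b, c])


my_list = ["hello", "world", "!"]

# Passing each element by hand...
by_hand = my_function(my_list[0], my_list[1], my_list[2])
# ...is equivalent to unpacking the list into positional arguments with *.
unpacked = my_function(*my_list)

print(unpacked)             # hello world !
print(by_hand == unpacked)  # True
```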
