
First, I define the network function:

import torch

def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition

# W and b are global tensors defined outside the function
def net(X):
    return softmax(torch.matmul(X.reshape(-1, W.shape[0]), W) + b)

Then I update the parameters by training:

train(net,train_iter,test_iter,cross_entropy,num_epoches,updater)

Finally, I save the function, load it back, and use it for prediction:

PATH='./net.pth'
torch.save(net,PATH)
saved_net=torch.load(PATH)
predict(saved_net,test_iter,6)

The prediction results show that the trained parameters W and b were not saved and loaded. What is the correct way to save a custom function together with its updated parameters?
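Why this happens can be sketched with plain `pickle`, which `torch.save` uses for serialization by default: a module-level function is pickled only as a reference (its module and name), not as its code or the globals it reads. In the sketch below, the global `scale` is a hypothetical stand-in for `W` and `b`; the "restored" function is just the same object looked up by name, so it sees whatever value the global has at call time.

```python
import pickle

# hypothetical stand-in for the global tensors W and b
scale = 1.0


def net(x):
    # reads the global `scale` at call time; its value is not stored with the function
    return x * scale


blob = pickle.dumps(net)       # records only a reference to net (module + name)
scale = 100.0                  # change the "parameter" after saving
restored = pickle.loads(blob)  # looks net up by name again

print(restored is net)  # True
print(restored(2.0))    # 200.0
```

In a fresh Python process, loading would fail unless `net` is defined again, and it would then use that process's `W` and `b`, so the trained values never end up in the file.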

Answer

The correct way is to implement your own nn.Module and then use the provided utilities to save and load the model's state (its weights, exposed through state_dict()) on demand.

You must define two methods:

  • __init__: the class initializer logic where you define your model's parameters.

  • forward: the function which implements the model's forward pass.

A minimal example would be of the form:

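import torch
from torch import nn

class LinearSoftmax(nn.Module):
    def __init__(self, in_feat, out_feat):
        super().__init__()
        # nn.Parameter registers W and b with the module, so they show up in
        # parameters() (for the optimizer) and in state_dict() (for saving)
        self.W = nn.Parameter(torch.rand(in_feat, out_feat))
        self.b = nn.Parameter(torch.rand(out_feat))

    @staticmethod
    def softmax(X):
        X_exp = torch.exp(X)
        partition = X_exp.sum(1, keepdim=True)
        return X_exp / partition

    def forward(self, X):
        # flatten each example, apply the linear map, then the softmax
        return self.softmax(torch.matmul(X.reshape(-1, self.W.shape[0]), self.W) + self.b)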

You can initialize a new model by doing:

>>> model = LinearSoftmax(10, 3)

You can then save and load the weights W and b of a given instance with torch.save(model.state_dict(), PATH) and model.load_state_dict(torch.load(PATH)).
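A minimal sketch of that save/load round trip, assuming W and b are registered as `nn.Parameter` and using the standard `state_dict()` / `load_state_dict()` workflow. The class is repeated so the snippet runs on its own; for brevity it uses the built-in `torch.softmax` in place of the hand-written one, and a temporary directory replaces the question's `./net.pth` path.

```python
import os
import tempfile

import torch
from torch import nn


class LinearSoftmax(nn.Module):
    def __init__(self, in_feat, out_feat):
        super().__init__()
        # registered parameters are what state_dict() returns
        self.W = nn.Parameter(torch.rand(in_feat, out_feat))
        self.b = nn.Parameter(torch.rand(out_feat))

    def forward(self, X):
        logits = torch.matmul(X.reshape(-1, self.W.shape[0]), self.W) + self.b
        return torch.softmax(logits, dim=1)


model = LinearSoftmax(10, 3)
# ... train model here ...

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "net.pth")
    # save only the learned tensors (an OrderedDict with keys "W" and "b")
    torch.save(model.state_dict(), path)

    # rebuild the architecture in code, then copy the saved weights into it
    restored = LinearSoftmax(10, 3)
    restored.load_state_dict(torch.load(path))

X = torch.rand(4, 10)
same = torch.allclose(model(X), restored(X))
print(list(restored.state_dict().keys()))  # ['W', 'b']
print(same)  # True
```

The file holds only tensors; the code that uses them (the class) lives in your source, which is why the model is re-instantiated before loading.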


Comments

Thank you very much for your answer! So torch.save only works for preserving tensors (or dicts) or classes inherited from nn.Module, not for saving the parameters of a custom function.
@BuDiu Consider accepting the answer as valid if you found it helped you solve your question.
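If you would rather keep the question's plain-function `net`, the same idea applies: save the tensors, not the function. A minimal sketch, assuming global `W` and `b` as in the question; the shapes are hypothetical and the random values stand in for trained ones. `torch.save` stores a dict of tensors, and after loading you rebind the globals that `net` reads.

```python
import os
import tempfile

import torch


def softmax(X):
    X_exp = torch.exp(X)
    return X_exp / X_exp.sum(1, keepdim=True)


def net(X):
    # reads the module-level W and b at call time
    return softmax(torch.matmul(X.reshape(-1, W.shape[0]), W) + b)


# hypothetical shapes; random values stand in for trained parameters
W = torch.normal(0, 0.01, size=(784, 10))
b = torch.rand(10)

X = torch.rand(2, 784)
before = net(X)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "params.pth")
    torch.save({"W": W, "b": b}, path)  # save the tensors themselves

    # simulate a fresh session with untrained globals, then restore them
    W, b = torch.zeros_like(W), torch.zeros_like(b)
    params = torch.load(path)
    W, b = params["W"], params["b"]

after = net(X)
print(torch.allclose(before, after))  # True
```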
