model = torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
results, labels = predict_function(model, dev_data, version)

> /home/ofsdms/san_mrc/my_utils/data_utils.py(34)predict_squad()
-> phrase, spans, scores = model.predict(batch)
(Pdb) n
AttributeError: 'dict' object has no attribute 'predict'

How do I load a saved checkpoint of a PyTorch model and use it for prediction? The model is saved with a .pt extension.

1 Answer


The checkpoint you save is usually a state_dict: a dictionary containing the values of the trained weights, but not the architecture of the net. The architecture (the computational graph) is described by a Python class derived from nn.Module. That is why calling model.predict on the object returned by torch.load fails: it is a plain dict, not a model.
To use a trained model you need to:

  1. Instantiate a model from the class implementing the computational graph.
  2. Load the saved state_dict into that instance:

    model.load_state_dict(torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu'))
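
The two steps can be sketched end to end as follows. This is only a minimal illustration: TinyNet, save_weights, load_for_inference and the temporary file name are hypothetical stand-ins, not part of the asker's project, and the real class (for example DocReaderModel) would need its own constructor arguments.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real architecture class."""

    def __init__(self, in_features=4, out_features=2):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)


def save_weights(model, path):
    # Save only the weights (the state_dict), not the Python object.
    torch.save(model.state_dict(), path)


def load_for_inference(path):
    # Step 1: rebuild the architecture from its class.
    model = TinyNet()
    # Step 2: load the saved weights into that instance.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    # Switch layers such as dropout and batch norm to inference behaviour.
    model.eval()
    return model


trained = TinyNet()
checkpoint_path = os.path.join(tempfile.mkdtemp(), 'tiny_checkpoint.pt')
save_weights(trained, checkpoint_path)
restored = load_for_inference(checkpoint_path)

# Both instances should produce the same output for the same input.
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.allclose(trained(x), restored(x))
```

Calling model.eval() before prediction matters for layers such as dropout and batch normalization, which behave differently during training.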
    

1 Comment

I used torch.load to load the full model. In my case the model was saved like this: model = DocReaderModel(opt, embedding); model_file_2 = os.path.join(model_dir, 'checkpoint_{}_epoch_{}_full_model.pt'.format(version, epoch)); torch.save(model, model_file_2)
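
Since the traceback in the question shows that torch.load returned a dict, it may help to check what a given .pt file actually contains before calling predict on it. The helper below is a hypothetical sketch (describe_checkpoint is not a PyTorch function, and the 'state_dict' and 'epoch' keys are made up for illustration):

```python
import torch
import torch.nn as nn


def describe_checkpoint(obj):
    """Return a short label for the kind of object torch.load handed back."""
    if isinstance(obj, nn.Module):
        # Saved with torch.save(model, path): usable directly.
        return 'full model'
    if isinstance(obj, dict):
        if obj and all(torch.is_tensor(v) for v in obj.values()):
            # Saved with torch.save(model.state_dict(), path).
            return 'state_dict'
        # Some other dictionary, e.g. weights wrapped together with metadata.
        return 'dict with keys: ' + ', '.join(sorted(map(str, obj)))
    return type(obj).__name__


net = nn.Linear(4, 2)
labels = [
    describe_checkpoint(net),
    describe_checkpoint(net.state_dict()),
    # Hypothetical wrapper layout; real projects choose their own keys.
    describe_checkpoint({'state_dict': net.state_dict(), 'epoch': 3}),
]
```

If the result is a dict with extra keys, the weights usually sit under one of those keys and still have to be passed to load_state_dict on an instantiated model. Note also that loading a fully pickled model requires the class definition to be importable, and that recent PyTorch releases default torch.load to weights_only=True, so a full-model file may need weights_only=False (only for files you trust).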
