
I created a PyTorch model to classify images. I saved it twice, once as a state_dict and once as the entire model, like this:

torch.save(model.state_dict(), "model1_statedict")
torch.save(model, "model1_complete")
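The two files are loaded back differently. Below is a minimal sketch. It assumes a small stand-in classifier called `TinyNet` (hypothetical, in place of the real model class) and uses temporary file paths. The complete-model file is restored directly with `torch.load`. The state_dict file only holds tensors, so the class has to be instantiated first and the weights copied in with `load_state_dict`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real image classifier."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # (N, 8, 1, 1) regardless of input size
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x).flatten(1)  # (N, 8)
        return self.classifier(x)        # raw logits, shape (N, num_classes)


model = TinyNet()
tmp_dir = tempfile.mkdtemp()
statedict_path = os.path.join(tmp_dir, "model1_statedict")
complete_path = os.path.join(tmp_dir, "model1_complete")

# Save both ways, as in the question.
torch.save(model.state_dict(), statedict_path)
torch.save(model, complete_path)

# 1) Complete model: the pickled module comes back directly. The class
#    definition must still be importable here, and recent PyTorch versions
#    need weights_only=False to unpickle a full module.
model_a = torch.load(complete_path, weights_only=False)

# 2) state_dict: build the architecture first, then copy the weights in.
model_b = TinyNet()
model_b.load_state_dict(torch.load(statedict_path))

# Either way, switch to inference mode before predicting.
model_a.eval()
model_b.eval()

x = torch.rand(1, 3, 32, 32)  # one random RGB "image", batch size 1
with torch.no_grad():
    out_a = model_a(x)
    out_b = model_b(x)
```

The state_dict route is generally the more portable of the two, since it does not depend on pickling the class itself. If the weights were saved on a GPU and are loaded on a CPU-only machine, `torch.load` also accepts `map_location="cpu"`.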

How can I use these saved models? I'd like to test them on some images to see whether they're any good.

I am loading the model with:

model = torch.load(path_model)
model.eval()

This works, but I have no idea how to use the loaded model to predict on a new picture.
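An image classifier usually expects a batched float tensor of shape (N, C, H, W), so a single picture has to be converted first. Below is a minimal sketch. It assumes the picture is already an RGB `uint8` NumPy array of shape (H, W, 3), for example from `np.asarray(PIL.Image.open(path).convert("RGB"))`. The normalization constants are the commonly used ImageNet statistics, which is an assumption here: the real values must match whatever preprocessing was used during training. A tiny `nn.Sequential` stands in for the loaded model.

```python
import numpy as np
import torch
import torch.nn as nn

# Assumed per-channel statistics; replace with the ones from your training pipeline.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def preprocess(image: np.ndarray) -> torch.Tensor:
    """Convert an (H, W, 3) uint8 RGB array into a normalized (1, 3, H, W) float tensor."""
    x = torch.from_numpy(image).float() / 255.0  # scale pixel values to [0, 1]
    x = x.permute(2, 0, 1)                       # HWC -> CHW
    x = (x - MEAN) / STD                         # per-channel normalization
    return x.unsqueeze(0)                        # add the batch dimension


def predict_class(model: nn.Module, image: np.ndarray) -> int:
    """Return the index of the highest-scoring class for one image."""
    model.eval()                 # dropout / batch norm in inference mode
    with torch.no_grad():        # no gradient bookkeeping needed
        logits = model(preprocess(image))
    return int(logits.argmax(dim=1).item())


# Hypothetical stand-in model: 3 input channels -> 4 class scores.
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
image = np.random.randint(0, 256, size=(64, 48, 3), dtype=np.uint8)
print(preprocess(image).shape)
print(predict_class(model, image))
```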

  • I edited your question because asking for resources, such as tutorials, is unfortunately not allowed here. Commented Apr 5, 2021 at 12:21
  • OK, sorry, I didn't know that. Thanks for editing! Commented Apr 5, 2021 at 14:41
  • not allowed? dumb Commented Jan 23, 2023 at 14:47

2 Answers

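import numpy as np
import torch

def predict(self, test_images):
    # self is the model (an instance of the VGG class below)
    self.eval()  # put dropout / batch norm layers into inference mode

    result_np = []
    # test_images: NumPy array of shape (N, H, W, C); device is defined
    # elsewhere, e.g. torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():  # gradients are not needed for prediction
        for idx in range(test_images.shape[0]):
            img = test_images[idx]                       # (H, W, C)
            img = np.expand_dims(img, axis=0)            # (1, H, W, C)
            img = torch.as_tensor(img, dtype=torch.float32)
            img = img.permute(0, 3, 1, 2).to(device)     # (1, C, H, W)
            pred = self(img)
            result_np.extend(pred.cpu().numpy())         # one score vector per image
    return result_np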
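Looping image by image works. A common alternative is to wrap the whole array in a `TensorDataset`, let a `DataLoader` feed it in batches, and disable gradients with `torch.no_grad()`. Below is a minimal sketch. Like the answer, it assumes `test_images` is a float NumPy array of shape (N, H, W, C). The standalone function name `predict_batched`, the batch size, and the stand-in model are illustrative choices, not part of the original code.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def predict_batched(model: nn.Module, test_images: np.ndarray,
                    batch_size: int = 32, device: str = "cpu") -> np.ndarray:
    """Run the model over (N, H, W, C) images and return an (N, num_classes) array."""
    model.eval()
    model.to(device)
    # NHWC -> NCHW once for the whole array, since conv layers expect channels first.
    tensors = torch.as_tensor(test_images, dtype=torch.float32).permute(0, 3, 1, 2)
    loader = DataLoader(TensorDataset(tensors), batch_size=batch_size, shuffle=False)

    outputs = []
    with torch.no_grad():
        for (batch,) in loader:  # each item is a 1-tuple: the dataset wraps one tensor
            outputs.append(model(batch.to(device)).cpu())
    return torch.cat(outputs).numpy()


# Hypothetical stand-in model producing 5 class scores.
model = nn.Sequential(
    nn.Conv2d(3, 5, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
test_images = np.random.rand(70, 16, 16, 3).astype(np.float32)
scores = predict_batched(model, test_images, batch_size=32)
print(scores.shape)
```

Batching lets the model process many images per forward pass, which is usually much faster on a GPU than one image at a time.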

The network is VGG-19. For reference, my source code is structured like this:

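import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader


class VGG(nn.Module):  # nn.Module is required for self.eval() and self(img)
    def __init__(self):
        super().__init__()
        ...  # VGG-19 layers

    def forward(self, x):
        ...

    # Named fit instead of train: nn.Module.eval() calls self.train(False),
    # so a train() with a different signature would break eval().
    def fit(self, train_dataset, valid_dataset):
        # The arguments are Dataset objects (e.g. TensorDataset(images, labels));
        # torch.utils.data.Dataset itself is an abstract base class.
        trainloader = DataLoader(train_dataset)
        validloader = DataLoader(valid_dataset)

        self.optimizer = Adam(...)
        self.criterion = nn.CrossEntropyLoss(...)

        for epoch in range(0, epochs):
            self.train()  # back to training mode after each evaluation
            ...
            self.evaluate(validloader, model=self, criterion=self.criterion)
    ...

    def evaluate(self, dataloader, model, criterion):
        model.eval()
        with torch.no_grad():
            for i, sample in enumerate(dataloader):
                ...

    def predict(self, test_images):
        ...

if __name__ == "__main__":
    network = VGG()
    trainset, validset = get_dataset()    # placeholder loaders, not shown
    testset = get_test_dataset()

    network.fit(trainset, validset)

    result = network.predict(testset)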
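Calling `self.eval()` and `self(img)` only works on a `torch.nn.Module`, so a class like this has to subclass it. Once it does, the method name matters. `nn.Module.eval()` is implemented as `self.train(False)`, so a custom `train(self, train_images, valid_images)` method shadows the built-in and breaks `eval()`. The small sketch below demonstrates the failure. The class names are hypothetical.

```python
import torch.nn as nn


class Shadowing(nn.Module):
    # Overrides nn.Module.train with an unrelated signature.
    def train(self, train_images, valid_images):
        pass


class Renamed(nn.Module):
    # Uses a different name, so nn.Module.train stays intact.
    def fit(self, train_images, valid_images):
        pass


def eval_works(module: nn.Module) -> bool:
    """Return True if module.eval() runs without raising."""
    try:
        module.eval()  # internally calls self.train(False)
        return True
    except TypeError:
        return False


print(eval_works(Shadowing()))  # False: train(False) is missing valid_images
print(eval_works(Renamed()))    # True
```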

A PyTorch model is a function: you give it input of the form it was trained on (for an image classifier, usually a float tensor of shape (N, C, H, W) on the same device as the model), and it returns an output. If you just want to inspect the output for a specific input image, simply call it:

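model.eval()
with torch.no_grad():  # no gradients needed for inference
    output = model(example_image)  # example_image: batched tensor, e.g. shape (1, C, H, W)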
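For a classifier, `output` is typically a tensor of raw scores (logits) with shape (batch_size, num_classes). Below is a minimal sketch of turning those scores into probabilities and class labels. It uses hypothetical class names and made-up logits in place of a real model output. If the model already ends in a softmax layer, skip the `torch.softmax` step.

```python
import torch

class_names = ["cat", "dog", "bird"]  # hypothetical labels, in training index order

# Stand-in for `output = model(example_image)`: logits for 2 images, 3 classes.
output = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])

probs = torch.softmax(output, dim=1)   # each row now sums to 1
conf, idx = probs.max(dim=1)           # best probability and its class index per image
for p, i in zip(conf.tolist(), idx.tolist()):
    print(f"{class_names[i]}: {p:.3f}")

top2 = torch.topk(probs, k=2, dim=1)   # the two most likely classes per image
```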
