I am using pretrained torchvision models in PyTorch and transfer learning to classify my own data set. That is working fine, but I think I could further improve my classification performance. My images come in different dimensions, and all of them are resized to fit the input of my model (e.g. to 224x224 pixels).

However, the original image size often says a lot about the class the image belongs to. So I thought it might help the model to add the original image dimensions as a second input to the model.

Currently I build my model in PyTorch like this:

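import torch.nn as nn
from torchvision.models import resnet50

model = resnet50(pretrained=True)  # Could be another base model as well
# Freeze the parameters of every BatchNorm2d layer
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        for param in module.parameters():
            param.requires_grad = False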
model.fc = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, num_classes),
            )

Now how would I add another (two-dimensional?) input to that model so that I can feed x and y dimensions of the original image to the model? Also, where does that make most sense - directly into the "beginning" of the model, or better somewhere "in between"?

1 Answer

One way to inject the extra data into the model is to feed it directly into the linear layers.

The drawback is that the extra input does not influence the conv layers.

Note that I injected it into the final layer, but it can go into any of the linear layers.

model.start = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
            )

model.end = nn.Sequential(   
                nn.Linear(256 + 2, num_classes),
            )

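Your forward pass should then look something like this (pseudocode):

def forward(x, extra_2d_data):
    # x: backbone features of shape (batch, 2048); extra_2d_data: shape (batch, 2)
    x1 = model.start(x)
    mid = torch.cat([x1, extra_2d_data], dim=1)  # shape (batch, 256 + 2)
    x2 = model.end(mid)
    return x2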
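
Putting the pieces together, a minimal self-contained sketch could look like the following. The class name `ImageWithSizeClassifier`, the `toy_backbone` stand-in, and the division by 2000 to scale the pixel sizes are assumptions made for illustration and are not part of the original answer; the only requirement is that the backbone returns a flat feature vector per image.

```python
import torch
import torch.nn as nn


class ImageWithSizeClassifier(nn.Module):
    """Hypothetical wrapper: image features plus the original (width, height)."""

    def __init__(self, backbone, feature_dim, num_classes, extra_dim=2):
        super().__init__()
        # backbone must map an image batch to (batch, feature_dim) features
        self.backbone = backbone
        self.start = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
        )
        # The extra values are concatenated right before the last layer
        self.end = nn.Linear(256 + extra_dim, num_classes)

    def forward(self, image, extra_2d_data):
        features = self.backbone(image)               # (batch, feature_dim)
        x1 = self.start(features)                     # (batch, 256)
        mid = torch.cat([x1, extra_2d_data], dim=1)   # (batch, 256 + extra_dim)
        return self.end(mid)                          # (batch, num_classes)


# Tiny stand-in backbone so the sketch runs without downloading weights
toy_backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
model = ImageWithSizeClassifier(toy_backbone, feature_dim=8, num_classes=5)

images = torch.randn(4, 3, 224, 224)
# Original (width, height) in pixels, scaled down by an assumed constant
sizes = torch.tensor(
    [[640.0, 480.0], [1024.0, 768.0], [300.0, 300.0], [1920.0, 1080.0]]
) / 2000.0
logits = model(images, sizes)
```

To use this with the ResNet-50 from the question, one option is to set `model.fc = nn.Identity()` on the pretrained network and pass it in as `backbone` with `feature_dim=2048`. Scaling the raw pixel sizes to a small range is probably a good idea, since values in the hundreds or thousands would otherwise dominate the 256 learned features at the concatenation point.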

See also this
