0

I'm trying to loop through my pre-trained CNN using the following code, which is slightly modified from PyTorch's example:

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` for ``num_epochs`` epochs and return it with its best weights.

    Every epoch runs a 'train' phase followed by a 'val' phase. The weights
    from the epoch with the highest validation accuracy are remembered and
    loaded back into ``model`` before it is returned.

    Besides its parameters, the function reads three names from the enclosing
    scope: ``loaders`` (indexed by phase name), ``dataset_sizes`` (indexed by
    phase name, used as the divisor for the epoch averages) and ``device``.

    Args:
        model: Network to train; switched between ``train()`` and ``eval()``
            mode per phase.
        criterion: Callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose gradients are zeroed before every batch and
            which is stepped after each training batch.
        scheduler: Stepped once per epoch, after the training phase.
        num_epochs: Number of epochs to run.

    Returns:
        The same ``model`` object, with the best validation weights loaded.
    """
    since = time.time()

    # Start from a snapshot of the incoming weights so there is always
    # something to restore, even if no epoch beats an accuracy of 0.0.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            # Per-phase accumulators, reset at the start of every phase.
            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            # NOTE(review): this loop is the source of the reported TypeError.
            # Without enumerate(), every item produced by the loader is itself
            # unpacked into `i, batch`. If that item is a dict with the keys
            # "image" and "label" (the shape GetDataLabel.__getitem__ returns;
            # presumably kept by the loader's batching -- confirm), `batch`
            # ends up being the key string "label", and indexing a str with
            # "image" raises "string indices must be integers".
            for i, batch in loaders[phase]:
                inputs = batch["image"].float().to(device)   # <---- error happens here
                # NOTE(review): labels are cast to float; criteria that take
                # class indices usually expect an integer dtype -- confirm
                # against the criterion that is passed in.
                labels = batch["label"].float().to(device) 

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    # Index of the largest output along dim 1 is taken as the
                    # prediction; the maximum values themselves are discarded.
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                # Scale the batch loss by the batch size so that dividing by
                # dataset_sizes[phase] below gives a per-sample figure
                # (assumes criterion returns a batch mean -- TODO confirm).
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # NOTE(review): '{:4f}' sets a minimum field width of 4, not four decimal
    # places like the '{:.4f}' used for the per-phase output -- possibly a typo.
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

However I get the error:

Epoch 0/24
----------
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-53-79684c739f29> in <module>()
----> 1 model_ft = train_model(resnet_cnn, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)

<ipython-input-49-55bb790e99a0> in train_model(model, criterion, optimizer, scheduler, num_epochs)
     21             # Iterate over data.
     22             for i, batch in loaders[phase]:
---> 23                 inputs = batch["image"].float().to(device)
     24                 labels = batch["label"].float().to(device)
     25 

TypeError: string indices must be integers

The loaders variable is:

# Keyed by the phase names that train_model iterates over ('train', 'val').
loaders = {"train":train_loader, "val":valid_loader}

The Dataset class I'm using for my train_loader and valid_loader is shown below; it explains why I'm indexing the batch with strings in my training function:

class GetDataLabel(Dataset):
    """Dataset that serves image/label pairs described by a table.

    Each sample is looked up positionally in ``df``: column 0 is joined onto
    ``root`` to form the image path and column 1 is used as the label.
    """

    def __init__(self, df, root, transform=None):
        """Store the table, the image root directory and an optional transform."""
        self.df, self.root, self.transform = df, root, transform

    def __len__(self):
        """Return the number of rows in the table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"image": ..., "label": ...}`` for the row at ``idx``.

        A tensor index is converted with ``tolist()`` before the lookup. The
        transform, when one was supplied, is applied to the opened image only.
        """
        position = idx.tolist() if torch.is_tensor(idx) else idx

        image = Image.open(os.path.join(self.root, self.df.iloc[position, 0]))
        target = self.df.iloc[position, 1]

        if self.transform:
            image = self.transform(image)

        return {"image": image, "label": target}

Thank you in advance.

1 Answer 1

1

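The loop is missing a call to `enumerate`. Iterating over the loader directly yields only the batch, so `for i, batch in ...` unpacks the batch itself; with a dict batch that means its keys, which leaves `batch` as a plain string and causes the "string indices must be integers" error. Wrapping the loader in `enumerate` gives you the index and the batch dict: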

# enumerate() pairs every item coming out of the loader with a running index,
# so `i` is the batch counter and `batch` is the item itself.
for i, batch in enumerate(loaders[phase]):  # <--- here
    inputs = batch["image"].float().to(device)
    labels = batch["label"].float().to(device)
Sign up to request clarification or add additional context in comments.

Comments

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.