
Hi, I'm facing an issue gathering all the losses and predictions in a multi-GPU scenario. I'm using PyTorch Lightning 2.0.4 with DeepSpeed (distributed strategy: deepspeed_stage_2).

I'm adding my skeleton code here for reference.

    def __init__(self):
        super().__init__()
        self.batch_train_preds = []
        self.batch_train_losses = []


    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        train_labels = batch['labels']

        # Model Step
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=train_labels)

        train_preds = torch.argmax(outputs.logits, dim=-1)

        return {'loss': outputs[0],
                'train_preds': train_preds}

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # aggregate metrics or outputs at batch level
        train_batch_loss = outputs["loss"].mean()
        train_batch_preds = outputs["train_preds"]  # already a single tensor for this batch

        self.batch_train_preds.append(train_batch_preds)
        self.batch_train_losses.append(train_batch_loss.item())

        return {'train_batch_loss': train_batch_loss,
                'train_batch_preds': train_batch_preds
                }

    def on_train_epoch_end(self) -> None:
        # Aggregate epoch level training metrics

        epoch_train_preds = torch.cat(self.batch_train_preds)
        epoch_train_loss = np.mean(self.batch_train_losses)

        self.logger.log_metrics({"epoch_train_loss": epoch_train_loss})

In the above code block, I'm trying to combine all the predictions into a single tensor at the end of the epoch by tracking each batch in a global list (defined in __init__). But in multi-GPU training I ran into an error during the concatenation, because each GPU handles its batch on its own device and I couldn't combine the results into a single global list.

Question:

What should I be doing in on_train_batch_end, on_train_epoch_end, or training_step to combine the results across all the GPUs into the lists created in my __init__? I want to calculate some additional metrics (precision, recall, etc.) in the on_*_epoch_end() hooks for train, validation, and test.

(My validation and test hooks are exactly analogous to the three training functions above, i.e. they combine losses and predictions in the same way.)

I have come across all_gather; it does combine the results the way I want, but it is called on every device (GPU).

So the question is: how do I use the output of all_gather on only one of the devices? A code snippet would be very helpful.

1 Answer


The Lightning documentation suggests using all_gather. Moreover, you do not need to manually accumulate the loss; just log it with self.log(..., on_epoch=True) and Lightning will accumulate and log it correctly:

import torch
from lightning.pytorch import LightningModule


class MyLightningModule(LightningModule):

    def __init__(self):
        super().__init__()
        self.batch_train_preds = []

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        # Model Step
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )

        loss = outputs[0]

        train_preds = torch.argmax(outputs.logits, dim=-1)
        self.batch_train_preds.append(train_preds)

        self.log('train/loss', loss, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self) -> None:

        # Aggregate epoch level training metrics
        epoch_train_preds = torch.cat(self.batch_train_preds, dim=0)

        # all_gather stacks the predictions from all the distributed processes
        # on a new leading dimension: (world_size, num_local_preds, *other_dims)
        epoch_train_preds = self.all_gather(epoch_train_preds)

        # flatten the process dimension to get (dataset_size, *other_dims)
        epoch_train_preds = epoch_train_preds.reshape(-1, *epoch_train_preds.shape[2:])

        # compute here your metrics over `epoch_train_preds`

        self.batch_train_preds.clear()  # free memory 

If you want to compute the metric only on a single process, protect the metric computation with if self.trainer.global_rank == 0:.
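For example, here is a minimal sketch of that guard applied to the on_train_epoch_end hook above (compute_my_metrics is a placeholder for your own metric code; note that all_gather itself must still run on every process, only the metric computation and logging are restricted to rank 0):

    def on_train_epoch_end(self) -> None:
        epoch_train_preds = torch.cat(self.batch_train_preds, dim=0)

        # the collective call must be executed by every process
        epoch_train_preds = self.all_gather(epoch_train_preds)
        epoch_train_preds = epoch_train_preds.reshape(-1, *epoch_train_preds.shape[2:])

        if self.trainer.global_rank == 0:
            # compute precision, recall, etc. on rank 0 only
            metrics = compute_my_metrics(epoch_train_preds)  # placeholder helper
            self.logger.log_metrics(metrics)

        self.batch_train_preds.clear()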

I also suggest taking a look at torchmetrics, which enables automatic synchronisation of metrics in a distributed setting with a few lines of code.

Additionally, I've written a framework for easy training and testing of several Transformer models for NLP.

Additional example using torchmetrics

import torch
from torchmetrics.classification import BinaryAccuracy
from lightning.pytorch import LightningModule


class MyLightningModule(LightningModule):

    def __init__(self):
        super().__init__()
        self.train_accuracy = BinaryAccuracy()

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        # Model Step
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )

        loss = outputs[0]

        train_preds = torch.argmax(outputs.logits, dim=-1)
        self.train_accuracy(train_preds, labels)  # updates the metric internal state with predictions and labels

        self.log('train/loss', loss, on_step=True, on_epoch=True, sync_dist=True)
        self.log('train/acc', self.train_accuracy, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self) -> None:
        pass  # no need to reset the metric as lightning will take care of that after each epoch
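If you also need precision, recall, and F1 as mentioned in the question, the same pattern works with a MetricCollection. Here is a minimal sketch, assuming a multiclass task where num_classes is passed in by you (the constructor argument is illustrative):

import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import (
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)
from lightning.pytorch import LightningModule


class MyLightningModule(LightningModule):

    def __init__(self, num_classes: int):
        super().__init__()
        metrics = MetricCollection({
            'precision': MulticlassPrecision(num_classes=num_classes),
            'recall': MulticlassRecall(num_classes=num_classes),
            'f1': MulticlassF1Score(num_classes=num_classes),
        })
        # clone the collection per stage (train/val/test) so their states stay separate
        self.train_metrics = metrics.clone(prefix='train/')

    def training_step(self, batch, batch_idx):
        outputs = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'],
        )
        preds = torch.argmax(outputs.logits, dim=-1)

        # update the collection; logging the metric objects lets lightning
        # compute, sync and reset them at the end of each epoch
        self.train_metrics(preds, batch['labels'])
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True)
        return outputs[0]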

4 Comments

Thank you very much for your answer; I've used the all_gather approach. I have a doubt about your answer: you appended the outputs at the step level. I can also append at the batch_end level, right? Something like this: def on_train_batch_end(self, outputs, batch, batch_idx): combined_outputs = self.all_gather(outputs) if self.trainer.is_global_zero: self.batch_train_losses.append(combined_outputs["loss"]) self.batch_train_preds.append(combined_outputs['train_preds']) self.batch_train_preds.append(combined_outputs['train_labels'])
Yes, you can, but I do not see any reason for doing that. (A sketch of that variant is included after these comments.)
Im sorry you're right. I could not see any difference. Also, would you please provide me some reference in using torchmetrics in distributed setting? Im more interested towards using it in my case in calculating scores like acc, f1, precision recall etc. I'll look into your transformer framework as well for future reference. Thanks for sharing
@sastaengineer I added the example with torchmetrics
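
For reference, a minimal sketch of the per-batch variant discussed in the comments above, assuming training_step returns the dict from the question's original skeleton extended with a 'train_labels' entry, and that a separate self.batch_train_labels list exists in __init__; note that the all_gather call still has to run on every process, only the bookkeeping is guarded:

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # self.all_gather also works on (nested) collections of tensors,
        # so the whole dict returned by training_step can be gathered at once
        combined_outputs = self.all_gather(outputs)

        if self.trainer.is_global_zero:
            self.batch_train_losses.append(combined_outputs['loss'].mean().item())
            self.batch_train_preds.append(combined_outputs['train_preds'])
            self.batch_train_labels.append(combined_outputs['train_labels'])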
