I was just looking at the DDP Tutorial:
https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
According to this:
It’s common to use torch.save and torch.load to checkpoint modules during training and recover from checkpoints. See SAVING AND LOADING MODELS for more details. When using DDP, one optimization is to save the model in only one process and then load it on all processes, reducing write overhead. This is correct because all processes start from the same parameters and gradients are synchronized in backward passes, and hence optimizers should keep setting parameters to the same values. If you use this optimization, make sure no process starts loading before the saving is finished. Besides, when loading the module, you need to provide an appropriate map_location argument to prevent a process from stepping into others’ devices. If map_location is missing, torch.load will first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. For more advanced failure recovery and elasticity support, please refer to TorchElastic.
I don't understand what this means. Shouldn't only one process (e.g. rank 0 / the first GPU) be saving the model? And is saving and loading the mechanism by which weights are shared across the processes/GPUs, or is that handled separately by DDP's gradient synchronization? My rough reading of the pattern the tutorial describes is sketched below; am I understanding it correctly?
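This is just a minimal sketch of how I read that paragraph, not working code I've verified; it assumes two GPUs (one per rank), and the toy nn.Linear model, temp-file path, port, and nccl backend are my own placeholders rather than anything from the tutorial:

import os
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def demo_checkpoint(rank, world_size):
    # Every process joins the same process group (assumes one GPU per rank).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    model = nn.Linear(10, 10).to(rank)        # toy model, my own placeholder
    ddp_model = DDP(model, device_ids=[rank])

    checkpoint_path = os.path.join(tempfile.gettempdir(), "ddp_checkpoint.pt")

    # 1) Only rank 0 writes the checkpoint -- all ranks hold identical weights,
    #    so saving from every process would just be redundant disk writes.
    if rank == 0:
        torch.save(ddp_model.state_dict(), checkpoint_path)

    # 2) Barrier so no rank starts reading before rank 0 has finished writing.
    dist.barrier()

    # 3) Every rank loads the same file, remapping tensors that were saved from
    #    cuda:0 onto its own GPU instead of piling them all onto device 0.
    map_location = {"cuda:0": f"cuda:{rank}"}
    ddp_model.load_state_dict(
        torch.load(checkpoint_path, map_location=map_location)
    )

    dist.barrier()  # keep rank 0 from removing the file while others may still be reading
    if rank == 0:
        os.remove(checkpoint_path)
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(demo_checkpoint, args=(world_size,), nprocs=world_size, join=True)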