SyncBatchNorm Error

I am experiencing issues when running SyncBatchNorm on HPU devices. In particular, I get the following error, which is raised by this torch line:

ValueError: SyncBatchNorm expected input tensor to be on GPU or privateuseone

One workaround is to simply remove the torch line I shared, and the code then works, but I am not sure whether that might hurt performance or cause unexpected behaviour. Could you please share your thoughts?
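For context, another workaround I considered, instead of patching torch, is to apply the conversion only on devices where SyncBatchNorm is supported and keep the plain BatchNorm layers elsewhere. A minimal sketch, where maybe_convert_sync_batchnorm is a helper of my own (not a library API):

import torch
import torch.nn as nn

def maybe_convert_sync_batchnorm(model: nn.Module, device: str) -> nn.Module:
    # SyncBatchNorm's forward rejects inputs that are not on a CUDA
    # (or privateuseone) device, so on HPU we keep the regular
    # BatchNorm layers, which use per-device statistics only.
    if device.startswith("cuda") and torch.cuda.is_available():
        return nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model  # e.g. device == "hpu": leave nn.BatchNorm2d as-is

With this, each HPU worker normalizes with its local batch statistics only, which is exactly the behaviour I am unsure about. In the repro below I keep the unconditional conversion so the error still triggers.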

I am sharing a minimal example to reproduce the issue. I am using the image vault.habana.ai/gaudi-docker/1.13.0/ubuntu20.04/habanalabs/pytorch-installer-2.1.0:latest.

main.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from unet import UNet  # Assume UNet is defined in unet.py
import habana_frameworks.torch as htorch
import habana_frameworks.torch.core as htcore
import torch.distributed as dist


class SyntheticDataset(torch.utils.data.Dataset):
    def __init__(self, num_samples, image_size=(256, 256), num_classes=3):
        self.num_samples = num_samples
        self.image_size = image_size
        self.num_classes = num_classes

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generate random image
        image = torch.rand(3, *self.image_size)  # Random RGB image

        # Generate a random label mask with shape (H, W), as expected by CrossEntropyLoss
        label = torch.randint(0, self.num_classes, self.image_size, dtype=torch.long)

        return image, label

device = "hpu"

### INIT DIST MODE
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
world_size, rank, local_rank = initialize_distributed_hpu()
dist.init_process_group(
    backend="hccl",
    init_method="env://",
    world_size=world_size,
    rank=rank,
)
dist.barrier()
### INIT DIST MODE

# Define transformation for input images and masks
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Define dataset and dataloader
# (single-card repro; multi-card runs would normally add a DistributedSampler)
train_dataset = SyntheticDataset(num_samples=100, image_size=(256, 256), num_classes=3)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# Initialize U-Net model
model = UNet(in_channels=3, out_channels=3).to(device)  # out_channels matches the 3 classes in the synthetic dataset
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = nn.parallel.DistributedDataParallel(model)  # DDP, since the script is launched with torchrun and a process group

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        
        # Forward pass
        outputs = model(images)
        
        # Calculate loss
        loss = criterion(outputs, masks)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        htcore.mark_step()
        optimizer.step()
        htcore.mark_step()
        
        running_loss += loss.item()
        htorch.hpu.synchronize()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")

print('Finished Training')

unet.py

import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class UpBlock(nn.Module):
    def __init__(self, in_channels_f1, in_channels_f2):
        super().__init__()
        self.trans_conv = nn.ConvTranspose2d(in_channels_f1, in_channels_f2, 2, stride=2)
        self.conv = DoubleConv(in_channels_f2*2, in_channels_f2)

    def forward(self, f1, f2):
        f1 = self.trans_conv(f1)
        f = torch.cat([f1, f2], dim=1)
        f = self.conv(f)
        return f

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.enc1 = nn.Sequential(DoubleConv(in_channels, 64), nn.MaxPool2d(2))
        self.enc2 = nn.Sequential(DoubleConv(64, 128), nn.MaxPool2d(2))
        self.enc3 = nn.Sequential(DoubleConv(128, 256), nn.MaxPool2d(2))
        self.enc4 = nn.Sequential(DoubleConv(256, 512), nn.MaxPool2d(2))

        self.up1 = UpBlock(in_channels_f1=512, in_channels_f2=256)
        self.up2 = UpBlock(in_channels_f1=256, in_channels_f2=128)
        self.up3 = UpBlock(in_channels_f1=128, in_channels_f2=64)

        self.last_up = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.linear = nn.Conv2d(32, out_channels, 1)

    def forward(self, x):
        encoder_outputs = []
        
        # Encoder
        x = self.enc1(x)
        encoder_outputs.append(x)
        x = self.enc2(x)
        encoder_outputs.append(x)
        x = self.enc3(x)
        encoder_outputs.append(x)
        x = self.enc4(x)
        encoder_outputs.append(x)

        # Decoder
        up = self.up1(encoder_outputs[-1], encoder_outputs[-2])
        up = self.up2(up, encoder_outputs[-3])
        up = self.up3(up, encoder_outputs[-4])

        up = self.last_up(up)
        up = self.linear(up)
        return up

You can run the script as:

torchrun --nproc_per_node=1 main.py

I can reproduce the error. Looking into it.

SyncBatchNorm is a GPU-specific op, and we currently do not support it. We will update here if it is supported in the future.

Is there an alternative for using Batch Normalization in a distributed environment? For example, would something like the sketch below be a reasonable interim approach?
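The idea I had in mind is to keep plain BatchNorm during the forward pass and only average the running statistics across ranks, say once per epoch. A rough sketch (all names are mine, untested on HPU); note this is not equivalent to SyncBatchNorm, which synchronizes the per-batch statistics inside forward — it only keeps the eval-time statistics consistent between workers:

import torch
import torch.nn as nn
import torch.distributed as dist

@torch.no_grad()
def all_reduce_bn_stats(model: nn.Module) -> None:
    # Average the running_mean / running_var buffers of every BatchNorm
    # layer across all ranks. Training-time (per-batch) statistics
    # remain local to each rank.
    world_size = dist.get_world_size()
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            dist.all_reduce(m.running_mean, op=dist.ReduceOp.SUM)
            dist.all_reduce(m.running_var, op=dist.ReduceOp.SUM)
            m.running_mean /= world_size
            m.running_var /= world_size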

We are looking into that. Please stay tuned; we will update here.