Here are a couple of workarounds you can try:
- Perform bernoulli_ of StochasticDepth on CPU
Around this line here,
# Sample on CPU when the input is on HPU; otherwise keep the input's device.
dev = 'cpu' if 'hpu' in input.device.type else input.device
#noise = torch.empty(size, dtype=input.dtype, device=input.device)
noise = torch.empty(size, dtype=input.dtype, device=dev)
noise = noise.bernoulli_(survival_rate)
if 'hpu' in input.device.type:
    noise = noise.to(input.device)  # move the mask back to the HPU
- Disable inplace Dropout
Around here
Replace nn.Dropout(p=dropout, inplace=True),
with nn.Dropout(p=dropout),
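For reference, here is a minimal self-contained sketch of the first workaround. It is modeled on how torchvision's stochastic depth samples its mask, but the function name `stochastic_depth_cpu_noise` and its exact signature are made up for illustration, so treat it as an assumption rather than the actual model code. The only idea it shows is sampling the Bernoulli mask on CPU when the input is on HPU and then moving the mask back.

```python
import torch


def stochastic_depth_cpu_noise(input, p, mode="row", training=True):
    """Stochastic depth with the Bernoulli mask sampled on CPU for HPU inputs.

    Hypothetical helper for illustration; not part of torchvision or Habana.
    """
    if not training or p == 0.0:
        return input

    survival_rate = 1.0 - p
    # "row" drops whole samples in the batch; anything else drops the full batch.
    if mode == "row":
        size = [input.shape[0]] + [1] * (input.ndim - 1)
    else:
        size = [1] * input.ndim

    # Sample on CPU only when the input lives on an HPU device.
    on_hpu = "hpu" in input.device.type
    dev = torch.device("cpu") if on_hpu else input.device

    noise = torch.empty(size, dtype=input.dtype, device=dev)
    noise.bernoulli_(survival_rate)
    if survival_rate > 0.0:
        # Rescale so the expected value of the output matches the input.
        noise.div_(survival_rate)

    # Move the mask back to the input's device (a no-op when already there).
    noise = noise.to(input.device)
    return input * noise
```

On a non-HPU device this behaves like the unmodified version, so it should be safe to keep the same code path for CPU and GPU runs.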
Please let me know if you see speedups with these two changes.
Thanks
Sayantan