Hi Habana team!
I am trying to run a PyTorch model on HPU, but I get an error when encoding the text input. The same script below works fine on CPU/CUDA. Do I need to change anything else in the model to make it work? I couldn’t find anything relevant in the PyTorch porting guide.
I have included the debugger file and the hl-smi_log found in the ~/.habana_logs directory.
I am using the Deep Learning AMI Habana PyTorch 1.7.1 SynapseAI 0.15.4 (Ubuntu 18.04) 20211025 image (ami-061d5e0b81dfa2121).
Any help is very much appreciated =)
import os
import torch
import habana_frameworks.torch.core
from habana_frameworks.torch.utils.library_loader import load_habana_module
load_habana_module()  # loads the Habana libraries and registers the HPU device
from models import loader
text_model = loader.load_model('BERT-Distil-40')
_ = text_model.to('hpu')  # works on cpu/cuda

with torch.no_grad():
    print(text_model('hello'))
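In case it helps narrow things down, this is how I would try preparing the input with explicit device placement — just a sketch, assuming the loader wraps a Hugging Face DistilBERT model and tokenizer (the tokenizer name and the forward signature below are my guesses, not what loader.load_model actually exposes):

from transformers import DistilBertTokenizer

# Assumption: the checkpoint behind 'BERT-Distil-40' is a Hugging Face DistilBERT variant.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Tokenize on CPU, then move the input_ids to the HPU device before the forward pass.
encoded = tokenizer('hello', return_tensors='pt')
input_ids = encoded['input_ids'].to('hpu')

with torch.no_grad():
    output = text_model(input_ids)  # hypothetical: depends on the model's forward signature
print(output)

If the model is supposed to handle tokenization and device placement internally, please ignore this part.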
Error:
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 756, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/ubuntu/.local/lib/python3.7/site-packages/transformers/models/distilbert/modeling_distilbert.py", line 550, in forward
inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 756, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/ubuntu/.local/lib/python3.7/site-packages/transformers/models/distilbert/modeling_distilbert.py", line 130, in forward
word_embeddings = self.word_embeddings(input_ids) # (bs, max_seq_length, dim)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 756, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/sparse.py", line 126, in forward
self.norm_type, self.scale_grad_by_freq, self.sparse)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 1852, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: array::at: __n (which is 18446744073709551615) >= _Nm (which is 8)