PReLU RuntimeError for inputs with more than 1 dimension

Code to reproduce:

import os
os.environ['PT_HPU_LAZY_MODE'] = '1'
os.environ['LOG_LEVEL_PT_FALLBACK'] = '1'
os.environ['PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES'] = '1'
os.environ['LOG_LEVEL_ALL'] = '3'
os.environ['ENABLE_CONSOLE'] = 'true'
import habana_frameworks.torch.core as htcore
import torch
with torch.no_grad():
    prelu = torch.nn.PReLU().to('hpu')
    input = torch.randn(1,2).to('hpu')
    output = prelu(input)
output

Output

[18:43:00.462428][SYN_API       ][info ][tid:6622E] + ---------------------------------------------------------------------- +
[18:43:00.462446][SYN_API       ][info ][tid:6622E] | Version:            1.16.2                                             |
[18:43:00.462461][SYN_API       ][info ][tid:6622E] | Synapse:            323df60                                            |
[18:43:00.462463][SYN_API       ][info ][tid:6622E] | HCL:                97cbe6a                                            |
[18:43:00.462464][SYN_API       ][info ][tid:6622E] | MME:                b7ec966                                            |
[18:43:00.462470][SYN_API       ][info ][tid:6622E] | SCAL:               0ecf6e1                                            |
[18:43:00.462472][SYN_API       ][info ][tid:6622E] | Description:        HabanaLabs Runtime and GraphCompiler               |
[18:43:00.462494][SYN_API       ][info ][tid:6622E] | Time:               2024-07-22 18:43:00.462473                         |
[18:43:00.462497][SYN_API       ][info ][tid:6622E] + ---------------------------------------------------------------------- +
[18:43:00.482429][KERNEL_DB             ][warning][tid:6622E] Failed loading version number from libTPCFuser.so
[18:43:00.486237][KERNEL_DB             ][warning][tid:6622E] Failed loading version number from libTPCFuser.so
[18:43:00.488345][KERNEL_DB             ][warning][tid:6622E] Failed loading version number from libTPCFuser.so
[18:43:01.528550][TPC_NODE              ][warning][tid:6622E] Can't access halReader, setting maxNumOfTPCS to 24 
[18:43:01.531598][TPC_NODE              ][warning][tid:6622E] Can't access halReader, setting maxNumOfTPCS to 64 
[18:43:01.546686][SCAL][info ][tid:6622E] +-------------------------------------------------+
[18:43:01.546723][SCAL][info ][tid:6622E] SCAL Commit SHA1 = 0ecf6e1
[18:43:01.546728][SCAL][info ][tid:6622E] SCAL Build Time = Mon Jun 24 02:52:40 AM IDT 2024
[18:43:01.546733][SCAL][info ][tid:6622E] SCAL loading config from :/gaudi2/default_edma_v3.json
[18:43:01.546741][SCAL][info ][tid:6622E] SCAL config Hash = 0x368e943541a6e68c
[18:43:01.546744][SCAL][info ][tid:6622E] +-------------------------------------------------+
[18:43:03.238187][HCL       ][info ][tid:6622E] Version:	1.16.2-97cbe6a
[18:43:03.238426][HL_GCFG][warning][tid:6622E] setValue: override BOX_TYPE_ID value that already set from observation
[18:43:04.051460][PT_SYNHELPER    ][warning][tid:6622E] /npu-stack/pytorch-integration/backend/synapse_helpers/mem_hlml.cpp:117	Allocated hlml shared memory0x7fea08394000
[18:43:04.052040][PT_DYNAMIC_SHAPE][warning][tid:6622E] MallocExtension_ReleaseFreeMemory was not linked
[18:43:04.087853][HABANA_NODE           ][error][tid:6622E] Output tensor and input tensor of Reshape4 doesn't match in elements' count ( 2 , 1 )
[18:43:04.087891][HABANA_NODE           ][error][tid:6622E] Node Validation Failed. Cannot create node Reshape4.
[18:43:04.087986][SYN_API       ][error][tid:6622E] _createGenericNode: Can not create reshape generic node ()
[18:43:04.088045][PT_BRIDGE       ][error][tid:6622E] /npu-stack/pytorch-integration/backend/synapse_helpers/graph.cpp: 481synNodeCreateWithId failed for node: reshape with synStatus 1 [Invalid argument]. .add_node
[18:43:04.096653][PT_BRIDGE       ][error][tid:6622E] backtrace (up to 30)
[18:43:04.096671][PT_BRIDGE       ][error][tid:6622E] /usr/lib/habanalabs/libhl_logger.so(hl_logger::v1_0::logStackTrace(std::shared_ptr<hl_logger::Logger> const&, int)+0x5c) [0x7feaf4057d0c]
[18:43:04.096676][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(void hl_logger::v1_7_inline_fmt_compile::logStacktrace<HlLogger::LoggerType>(HlLogger::LoggerType, int)+0x61) [0x7fea3fa77bb1]
[18:43:04.096685][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_backend.so(synapse_helpers::graph::add_node(std::vector<internalTensor*, std::allocator<internalTensor*> >&&, std::vector<internalTensor*, std::allocator<internalTensor*> >&&, void*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned long*, char const**, char const**, bool)+0x1ba5) [0x7fea3e81fbb5]
[18:43:04.096690][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_backend.so(habana::OpBackend::BuildNode(habana::OpBackend*, synapse_helpers::graph&, habana::NodeAttr&&)+0x7b5) [0x7fea3ea0b275]
[18:43:04.096695][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_backend.so(habana::OpBackend::BuildReshape(habana::OpBackend*, synapse_helpers::graph&, internalTensor*, c10::ArrayRef<long>, c10::ScalarType, std::optional<int>, std::optional<unsigned int>)+0x2fa) [0x7fea3ea1220a]
[18:43:04.096699][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_backend.so(habana::OpBackend::ReshapeHelper(synapse_helpers::graph&, internalTensor*, c10::ArrayRef<long>, c10::ScalarType, std::optional<int>, std::optional<unsigned int>)+0x2f) [0x7fea3ea1235f]
[18:43:04.096703][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_backend.so(habana::Prelu::AddNode(synapse_helpers::graph&, std::vector<c10::IValue, std::allocator<c10::IValue> > const&)+0x1e6) [0x7fea3ea2ed76]
[18:43:04.096716][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_backend.so(habana::HabanaLaunchOpPT::BuildSynapseGraph(std::shared_ptr<synapse_helpers::graph>&, habana::SynBuildCache&, bool)+0x15f3) [0x7fea3ef888e3]
[18:43:04.096730][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_backend.so(habana::HabanaLaunchOpPT::run(std::vector<c10::IValue, std::allocator<c10::IValue> >&, std::optional<std::vector<at::Tensor, std::allocator<at::Tensor> > >, std::optional<std::vector<std::vector<long, std::allocator<long> >, std::allocator<std::vector<long, std::allocator<long> > > > >, bool, habana::HabanaLaunchOpPipeline::PipelineCallBase&)+0x99e) [0x7fea3ef9e38e]
[18:43:04.096738][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(+0xf16e65) [0x7fea40447e65]
[18:43:04.096742][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(habana_lazy::exec::HlExec::Launch(std::vector<c10::IValue, std::allocator<c10::IValue> >&, c10::hpu::HPUStream const&, bool)+0x932) [0x7fea4044a6d2]
[18:43:04.096745][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(LaunchSyncTensorsGraph(LaunchTensorsInfo&&, LaunchEagerInfo&&, LaunchStreamInfo&&)+0x607) [0x7fea40424277]
[18:43:04.096751][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(habana_lazy::HbLazyTensor::SyncTensorsGraphInternal(std::vector<habana_lazy::HbLazyTensor, std::allocator<habana_lazy::HbLazyTensor> >*, std::shared_ptr<habana_lazy::HbLazyFrontEndInfoToBackend>, bool, bool)+0x2019) [0x7fea404272d9]
[18:43:04.096759][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(habana_lazy::HbLazyTensor::SyncTensorsGraph(std::vector<habana_lazy::HbLazyTensor, std::allocator<habana_lazy::HbLazyTensor> >*, std::shared_ptr<habana_lazy::HbLazyFrontEndInfoToBackend>, bool, bool)+0x43d) [0x7fea4042886d]
[18:43:04.096765][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(habana_lazy::HbLazyTensor::SyncLiveTensorsGraph(c10::Device const*, std::shared_ptr<habana_lazy::HbLazyFrontEndInfoToBackend>, std::vector<habana_lazy::HbLazyTensor, std::allocator<habana_lazy::HbLazyTensor> >, bool, bool, std::vector<habana_lazy::HbLazyTensor, std::allocator<habana_lazy::HbLazyTensor> >, std::set<long, std::less<long>, std::allocator<long> >)+0x3a8) [0x7fea40429248]
[18:43:04.096772][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(habana_lazy::HbLazyTensor::StepMarker(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::shared_ptr<habana_lazy::HbLazyFrontEndInfoToBackend>, std::vector<habana_lazy::HbLazyTensor, std::allocator<habana_lazy::HbLazyTensor> >, bool, bool, std::vector<habana_lazy::HbLazyTensor, std::allocator<habana_lazy::HbLazyTensor> >, std::set<long, std::less<long>, std::allocator<long> >)+0x95b) [0x7fea4042a26b]
[18:43:04.096776][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(habana_lazy::_local_scalar_dense_hpu(at::Tensor const&)+0x4b9) [0x7fea40266c59]
[18:43:04.096779][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(habana::_local_scalar_dense(at::Tensor const&)+0x2e4) [0x7fea3fd7f814]
[18:43:04.096791][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<c10::Scalar (*)(at::Tensor const&), c10::Scalar, c10::guts::typelist::typelist<at::Tensor const&> >, c10::Scalar (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&)+0x27) [0x7fea3fdd2947]
[18:43:04.096796][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so(at::_ops::_local_scalar_dense::call(at::Tensor const&)+0x143) [0x7feadf86b253]
[18:43:04.096798][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so(at::native::item(at::Tensor const&)+0x94) [0x7feadeed7084]
[18:43:04.096800][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so(+0x2b3d115) [0x7feadff5f115]
[18:43:04.096803][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so(at::_ops::item::call(at::Tensor const&)+0x143) [0x7feadf6b51f3]
[18:43:04.096806][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so(long at::Tensor::item<long>() const+0x2d) [0x7feae04e559d]
[18:43:04.096809][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(habana_lazy::nonzero_hpu_lazy(at::Tensor const&)+0x603) [0x7fea4029a2c3]
[18:43:04.096812][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(habana_lazy::masked_select_hpu_lazy(at::Tensor const&, at::Tensor const&)+0x35d) [0x7fea4029b36d]
[18:43:04.096815][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(hpu_wrap::masked_select(at::Tensor const&, at::Tensor const&)+0x5e2) [0x7fea40201bd2]
[18:43:04.096821][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/habana_frameworks/torch/lib/libhabana_pytorch_plugin.so(c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, at::Tensor const&), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)+0x2a) [0x7fea3fb0492a]
[18:43:04.096825][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so(at::_ops::masked_select::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)+0x8f) [0x7feadf7b83ef]
[18:43:04.096832][PT_BRIDGE       ][error][tid:6622E] /home/robinysh/SpeechLLM/.venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so(+0x4108a10) [0x7feae152aa10]
============================= HABANA PT BRIDGE CONFIGURATION =========================== 
 PT_HPU_LAZY_MODE = 1
 PT_RECIPE_CACHE_PATH = 
 PT_CACHE_FOLDER_DELETE = 0
 PT_HPU_RECIPE_CACHE_CONFIG = 
 PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
 PT_HPU_LAZY_ACC_PAR_MODE = 1
 PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 1
---------------------------: System Configuration :---------------------------
Num CPU Cores : 160
CPU RAM       : 1056375276 KB
------------------------------------------------------------------------------
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~/SpeechLLM/.venv/lib/python3.10/site-packages/IPython/core/formatters.py:711, in PlainTextFormatter.__call__(self, obj)
    704 stream = StringIO()
    705 printer = pretty.RepresentationPrinter(stream, self.verbose,
    706     self.max_width, self.newline,
    707     max_seq_length=self.max_seq_length,
    708     singleton_pprinters=self.singleton_printers,
    709     type_pprinters=self.type_printers,
    710     deferred_pprinters=self.deferred_printers)
--> 711 printer.pretty(obj)
    712 printer.flush()
    713 return stream.getvalue()

File ~/SpeechLLM/.venv/lib/python3.10/site-packages/IPython/lib/pretty.py:419, in RepresentationPrinter.pretty(self, obj)
    408                         return meth(obj, self, cycle)
    409                 if (
    410                     cls is not object
    411                     # check if cls defines __repr__
   (...)
    417                     and callable(_safe_getattr(cls, "__repr__", None))
    418                 ):
--> 419                     return _repr_pprint(obj, self, cycle)
    421     return _default_pprint(obj, self, cycle)
    422 finally:

File ~/SpeechLLM/.venv/lib/python3.10/site-packages/IPython/lib/pretty.py:787, in _repr_pprint(obj, p, cycle)
    785 """A pprint that just redirects to the normal repr function."""
    786 # Find newlines and replace them with p.break_()
--> 787 output = repr(obj)
    788 lines = output.splitlines()
    789 with p.group():

File ~/SpeechLLM/.venv/lib/python3.10/site-packages/torch/_tensor.py:471, in Tensor.__repr__(self, tensor_contents)
    467     return handle_torch_function(
    468         Tensor.__repr__, (self,), self, tensor_contents=tensor_contents
    469     )
    470 # All strings are unicode in Python 3.
--> 471 return torch._tensor_str._str(self, tensor_contents=tensor_contents)

File ~/SpeechLLM/.venv/lib/python3.10/site-packages/torch/_tensor_str.py:677, in _str(self, tensor_contents)
    675 with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
    676     guard = torch._C._DisableFuncTorch()
--> 677     return _str_intern(self, tensor_contents=tensor_contents)

File ~/SpeechLLM/.venv/lib/python3.10/site-packages/torch/_tensor_str.py:597, in _str_intern(inp, tensor_contents)
    595                     tensor_str = _tensor_str(self.to_dense(), indent)
    596                 else:
--> 597                     tensor_str = _tensor_str(self, indent)
    599 if self.layout != torch.strided:
    600     suffixes.append("layout=" + str(self.layout))

File ~/SpeechLLM/.venv/lib/python3.10/site-packages/torch/_tensor_str.py:349, in _tensor_str(self, indent)
    345     return _tensor_str_with_formatter(
    346         self, indent, summarize, real_formatter, imag_formatter
    347     )
    348 else:
--> 349     formatter = _Formatter(get_summarized_data(self) if summarize else self)
    350     return _tensor_str_with_formatter(self, indent, summarize, formatter)

File ~/SpeechLLM/.venv/lib/python3.10/site-packages/torch/_tensor_str.py:137, in _Formatter.__init__(self, tensor)
    134         self.max_width = max(self.max_width, len(value_str))
    136 else:
--> 137     nonzero_finite_vals = torch.masked_select(
    138         tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
    139     )
    141     if nonzero_finite_vals.numel() == 0:
    142         # no valid number, do nothing
    143         return

RuntimeError: synNodeCreateWithId failed for node: reshape with synStatus 1 [Invalid argument]. .

System info:

$HABANALABS_VIRTUAL_DIR='.venv' habanalabs-installer.sh validate
================================================================================
Environment
================================================================================
Device                               gaudi2 
OS                                   ubuntu 
OS version                           22.04  
Log file                             /var/log/habana_logs/install-2024-07-22-18-52-46.log
Release version                      1.16.2-2
Habanalabs server                    vault.habana.ai
Rewrite installer config             no     
Install type                         validate
Python repo URL                      https://vault.habana.ai/artifactory/api/pypi/gaudi-python/simple
Habanalabs software                  [OK]   
habanalabs-container-runtime=1.16.2-2
habanalabs-dkms=1.16.2-2
habanalabs-firmware=1.16.2-2
habanalabs-firmware-odm=1.16.2-2
habanalabs-firmware-tools=1.16.2-2
habanalabs-graph=1.16.2-2
habanalabs-qual=1.16.2-2
habanalabs-qual-workloads=1.16.2-2
habanalabs-rdma-core=1.16.2-2
habanalabs-thunk=1.16.2-2
================================================================================
System
================================================================================
CPU: 160
Model name: Intel(R) Xeon(R) Platinum 8380 CPU @ 2.30GHz
MemTotal: 1056375276 kB
Hugepagesize: 2048 kB
================================================================================
OS environment
================================================================================
Package manager                      [apt]  
The required sudo privileges are     [FAILED]
Python 3.10                          [OK]   
================================================================================
Basic dependencies
================================================================================
gcc                                  [OK]   
cmake                                [OK]   
lsof                                 [OK]   
curl                                 [OK]   
wget                                 [OK]   
linux-headers-5.15.0-92-generic      [OK]   
ethtool                              [OK]   
libelf-dev                           [OK]   
libbz2-dev                           [OK]   
liblzma-dev                          [OK]   
libibverbs-dev                       [OK]   
librdmacm-dev                        [OK]   
dkms                                 [OK]   
linux-modules-extra-5.15.0-92-generic  [OK]   
================================================================================
PyTorch dependencies
================================================================================
gcc                                  [OK]   
cmake                                [OK]   
lsof                                 [OK]   
curl                                 [OK]   
wget                                 [OK]   
unzip                                [OK]   
libcurl4                             [OK]   
moreutils                            [OK]   
iproute2                             [OK]   
libcairo2-dev                        [OK]   
libglib2.0-dev                       [OK]   
libselinux1-dev                      [OK]   
libnuma-dev                          [OK]   
libpcre2-dev                         [OK]   
libatlas-base-dev                    [OK]   
libjpeg-dev                          [OK]   
liblapack-dev                        [OK]   
libnuma-dev                          [OK]   
google-perftools                     [OK]   
numactl                              [OK]   
libopenblas-dev                      [OK]   
The installed version is: 4.1.5
================================================================================
Installed Habanalabs software
================================================================================
habanalabs-container-runtime=1.16.2-2
habanalabs-dkms=1.16.2-2
habanalabs-firmware=1.16.2-2
habanalabs-firmware-odm=1.16.2-2
habanalabs-firmware-tools=1.16.2-2
habanalabs-graph=1.16.2-2
habanalabs-qual=1.16.2-2
habanalabs-qual-workloads=1.16.2-2
habanalabs-rdma-core=1.16.2-2
habanalabs-thunk=1.16.2-2
================================================================================
Full install log: /var/log/habana_logs/install-2024-07-22-18-52-46.log
================================================================================

Thanks for bringing the bug to our notice. This should get fixed in release 1.17.

Given that you have torch.no_grad(), I assume you have an inference situation (as opposed to training).

You can do something like the following to work around the error:

import habana_frameworks.torch.core as htcore
import torch

with torch.no_grad():
    # prelu = torch.nn.PReLU().to('hpu')  # fails on HPU, see the error above
    prelu_cpu = torch.nn.PReLU()
    # Extract the single learned slope as a Python float
    a = next(prelu_cpu.parameters()).item()

    def prelu_fn(x):
        zero = torch.zeros_like(x, device=x.device)
        return torch.max(x, zero) + a * torch.min(x, zero)

    input = torch.randn(1, 2).to('hpu')
    output = prelu_fn(input)
    print(output)
    print('***')

    # Sanity check against the CPU PReLU
    output_cpu = prelu_cpu(input.to('cpu'))
    print(output_cpu)

Great, thank you. I added torch.no_grad() to isolate the bug from any backward operations.

If you need the workaround in training mode as well, you would probably have to write an nn.Module with an nn.Parameter() so that the slope is trainable, and internally the forward would compute max(0, x) + a * min(0, x).

Something like this: torch.nn.modules.activation — PyTorch 2.3 documentation

Except with the forward modified to max(0, x) + a * min(0, x).
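To make the suggestion concrete, here is a minimal sketch of such a trainable drop-in module. It assumes the single-slope case (PReLU's default num_parameters=1) and initializes the slope to 0.25 to match torch.nn.PReLU; the class name PReLUWorkaround is just an illustrative choice, not anything from the Habana or PyTorch API.

```python
import torch
import torch.nn as nn

class PReLUWorkaround(nn.Module):
    """Trainable PReLU replacement that avoids the internal reshape,
    computing max(0, x) + a * min(0, x) with a learnable slope a."""

    def __init__(self, init: float = 0.25):
        super().__init__()
        # Learnable slope; 0.25 matches torch.nn.PReLU's default init.
        self.weight = nn.Parameter(torch.full((1,), init))

    def forward(self, x):
        zero = torch.zeros_like(x)
        # Elementwise: positive part passes through, negative part is scaled.
        return torch.max(x, zero) + self.weight * torch.min(x, zero)
```

Because self.weight is an nn.Parameter, gradients flow through the slope during backward, so the module should be usable in training as well as inference.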

But the original issue should get fixed in the next release.