torch.argmax on an all -inf tensor returns a different result on HPU than on CPU and GPU

Hi,

I found that torch.argmax returns a different result on HPU than on CPU and GPU when the input tensor is filled with -inf.
In that case, HPU returns the size of the reduced dimension (10 in the reproducer below), which is an out-of-range index. CPU and GPU return 0.
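For reference, the torch.argmax documentation states that when there are multiple maximal values, the index of the first one is returned. A row filled entirely with -inf consists only of ties, so the expected index is 0, and any valid index must lie in [0, dim_size - 1]. The minimal pure-Python sketch below spells out that rule. `reference_argmax` is a hypothetical helper for illustration, not a PyTorch API, and it ignores NaN handling.

```python
import math


def reference_argmax(row):
    """Return the index of the first maximal element of a non-empty list.

    Sketch of the documented torch.argmax tie-breaking rule
    (first maximal value wins); NaN handling is deliberately ignored.
    """
    best_index = 0
    for i, value in enumerate(row):
        # Strict '>' keeps the earliest index when values tie,
        # and -inf > -inf is False, so an all -inf row stays at 0.
        if value > row[best_index]:
            best_index = i
    return best_index


row = [-math.inf] * 10
idx = reference_argmax(row)
print(idx)                  # 0
print(0 <= idx < len(row))  # True; the HPU result of 10 fails this range check
```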

Environment

  • HL-SMI version: hl-1.15.1-fw-49.0.0.0
  • Driver version: 1.15.0-a596ef0
  • PyTorch version: 2.2.0a0+git8964477

Reproducer

HPU

import torch
import habana_frameworks.torch as htorch  # loads the Habana PyTorch bridge (HPU device support)

for device in ["cpu", "hpu"]:
    # Rows filled with the most negative finite float32 value (built on CPU, then moved)
    outputs_finfo = torch.full((2, 10), torch.finfo(torch.float32).min).to(device)
    # Rows filled entirely with -inf (built on CPU, then moved)
    outputs_inf = torch.full((2, 10), float("-inf")).to(device)
    samples = torch.argmax(outputs_finfo, dim=-1)
    print(samples.device)
    print(samples)
    samples = torch.argmax(outputs_inf, dim=-1)
    print(samples)

Result

cpu
tensor([0, 0])
tensor([0, 0])
============================= HABANA PT BRIDGE CONFIGURATION =========================== 
 PT_HPU_LAZY_MODE = 1
 PT_RECIPE_CACHE_PATH = 
 PT_CACHE_FOLDER_DELETE = 0
 PT_HPU_RECIPE_CACHE_CONFIG = 
 PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
 PT_HPU_LAZY_ACC_PAR_MODE = 1
 PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
---------------------------: System Configuration :---------------------------
Num CPU Cores : 160
CPU RAM       : 1056375276 KB
------------------------------------------------------------------------------
hpu:0
tensor([0, 0], device='hpu:0')
tensor([10, 10], device='hpu:0')

GPU

import torch

for device in ["cpu", "cuda"]:
    # Rows filled with the most negative finite float32 value (built on CPU, then moved)
    outputs_finfo = torch.full((2, 10), torch.finfo(torch.float32).min).to(device)
    # Rows filled entirely with -inf (built on CPU, then moved)
    outputs_inf = torch.full((2, 10), float("-inf")).to(device)
    samples = torch.argmax(outputs_finfo, dim=-1)
    print(samples.device)
    print(samples)
    samples = torch.argmax(outputs_inf, dim=-1)
    print(samples)

Result

cpu
tensor([0, 0])
tensor([0, 0])
cuda:0
tensor([0, 0], device='cuda:0')
tensor([0, 0], device='cuda:0')
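Since the torch.finfo(torch.float32).min input gives tensor([0, 0]) on HPU in the results above, one possible workaround until this is fixed is to replace -inf with the dtype's most negative finite value before calling argmax. This is an untested sketch: `safe_argmax` is a hypothetical helper, and it has only been reasoned about on CPU, not verified on HPU.

```python
import torch


def safe_argmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Hypothetical helper (not a PyTorch or Habana API): map -inf to the most
    # negative finite value of x's floating-point dtype so that argmax never
    # sees a row made up only of -inf.
    finite_min = torch.finfo(x.dtype).min
    x = x.masked_fill(torch.isneginf(x), finite_min)
    return torch.argmax(x, dim=dim)


x = torch.full((2, 10), float("-inf"))
print(safe_argmax(x))  # tensor([0, 0]) on CPU
```

Note that this changes tie-breaking for rows that mix -inf with the finite minimum itself (for example [-inf, finfo.min] would return 0 instead of 1), so it is only a stopgap.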

Thanks for the post. I can reproduce the issue and will post back here after we investigate it.
