Transferring kNN results from CPU to HPU breaks backpropagation

In our GNN training code, we use the kNN algorithm to build graphs and then apply graph convolution on the resulting kNN graphs.

  • The kNN algorithm runs on the CPU
  • We then transfer the generated edges to the HPU
  • Finally, we use these edges for graph convolution on the HPU
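For reference, here is a minimal sketch of this pipeline in plain PyTorch. It is not the torch_geometric code from the linked repository and does not reproduce the error. The helpers `knn_edges_cpu` and `MeanGraphConv` are hypothetical, and the device selection assumes that importing `habana_frameworks.torch.core` registers the "hpu" device, falling back to CPU otherwise. It shows that the kNN edge indices are integers and carry no gradient, so gradients should flow only through the node features and layer weights that stay on the device.

```python
import torch

# Assumption: use "hpu" when the Habana bridge is importable, otherwise CPU,
# so the sketch also runs on machines without Gaudi.
try:
    import habana_frameworks.torch.core  # noqa: F401  (registers "hpu")
    DEVICE = torch.device("hpu")
except ImportError:
    DEVICE = torch.device("cpu")


def knn_edges_cpu(pos: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: build a kNN edge_index of shape [2, N*k] on the CPU."""
    pos_cpu = pos.detach().to("cpu")                   # integer edges need no grad
    dist = torch.cdist(pos_cpu, pos_cpu)               # [N, N] pairwise distances
    dist.fill_diagonal_(float("inf"))                  # exclude self-loops
    nbr = dist.topk(k, dim=1, largest=False).indices   # [N, k] nearest neighbours
    dst = torch.arange(pos_cpu.size(0)).repeat_interleave(k)  # target of each edge
    src = nbr.reshape(-1)                              # neighbour (source) node
    return torch.stack([src, dst], dim=0)              # int64, shape [2, N*k]


class MeanGraphConv(torch.nn.Module):
    """Hypothetical mean-aggregation graph convolution (not a PyG layer)."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.lin = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index
        # Sum neighbour features into each target node (out-of-place index_add).
        agg = torch.zeros_like(x).index_add(0, dst, x[src])
        ones = torch.ones_like(dst, dtype=x.dtype)
        deg = torch.zeros(x.size(0), device=x.device, dtype=x.dtype).index_add(0, dst, ones)
        agg = agg / deg.clamp(min=1).unsqueeze(-1)     # mean over neighbours
        return self.lin(x + agg)


torch.manual_seed(0)
N_NODES, IN_DIM, K = 32, 8, 4
x = torch.randn(N_NODES, IN_DIM, device=DEVICE, requires_grad=True)
pos = torch.randn(N_NODES, 3)                          # coordinates used for kNN

edge_index = knn_edges_cpu(pos, K).to(DEVICE)          # CPU kNN, then move edges
conv = MeanGraphConv(IN_DIM, 16).to(DEVICE)
loss = conv(x, edge_index).pow(2).mean()
loss.backward()                                        # grads reach x and conv weights
```

If a sketch like this backpropagates on the HPU but the real model does not, the difference likely lies in how the kNN output or the features are moved between devices rather than in the edges themselves.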

However, this appears to break backpropagation, and training fails with the error RuntimeError: [Rank:0] FATAL ERROR :: MODULE:PT_BRIDGE Exception in Lowering thread. See similar errors here.

See the details of the kNN algorithm here and the code here.

  • I use the Docker image vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest and then install torch_geometric with pip
  • Specifically, torch-geometric==2.6.1