Skip to content

Commit 6fa7770

Browse files
authored
Neuron import hook (#5429)
* Enable Neuron import hook for calling initialization functions if using AWS Neuron * removing copy/paste error * moving aws init call and removing comment
1 parent 5734ab8 commit 6fa7770

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

torch_xla/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,15 @@ def _summarize_fn_tracker():
6060
os.remove(_tmp_fname)
6161

6262

63+
def _aws_ec2_inf_trn_init():
64+
try:
65+
from torch_neuronx import xla
66+
except ImportError:
67+
return
68+
else:
69+
xla.init()
70+
71+
6372
def _setup_tpu_vm_library_path() -> bool:
6473
"""Returns true if $TPU_LIBRARY is set or can be inferred.
6574
@@ -105,6 +114,9 @@ def _setup_tpu_vm_library_path() -> bool:
105114

106115
_found_libtpu = _setup_tpu_vm_library_path()
107116

117+
# Setup Neuron library for AWS EC2 inf/trn instances.
118+
_aws_ec2_inf_trn_init()
119+
108120

109121
def _prepare_to_exit():
110122
_XLAC._prepare_to_exit()

0 commit comments

Comments
 (0)