We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5734ab8 commit 6fa7770Copy full SHA for 6fa7770
torch_xla/__init__.py
@@ -60,6 +60,15 @@ def _summarize_fn_tracker():
60
os.remove(_tmp_fname)
61
62
63
+def _aws_ec2_inf_trn_init():
64
+ try:
65
+ from torch_neuronx import xla
66
+ except ImportError:
67
+ return
68
+ else:
69
+ xla.init()
70
+
71
72
def _setup_tpu_vm_library_path() -> bool:
73
"""Returns true if $TPU_LIBRARY is set or can be inferred.
74
@@ -105,6 +114,9 @@ def _setup_tpu_vm_library_path() -> bool:
105
114
106
115
_found_libtpu = _setup_tpu_vm_library_path()
107
116
117
+# Setup Neuron library for AWS EC2 inf/trn instances.
118
+_aws_ec2_inf_trn_init()
119
108
120
109
121
def _prepare_to_exit():
110
122
_XLAC._prepare_to_exit()
0 commit comments