75 changes: 39 additions & 36 deletions TROUBLESHOOTING.md
@@ -56,11 +56,7 @@ report sent to us if you have it.

## PyTorch/XLA Debugging Tool

You can enable the PyTorch/XLA debugging tool by setting `PT_XLA_DEBUG=1`, which provides a couple useful debugging features.

## PyTorch/XLA + Dynamo Debugging Tool

You can enable the PyTorch/XLA + Dynamo debugging tool by setting `XLA_DYNAMO_DEBUG=1`.
You can enable the PyTorch/XLA debugging tool by setting `PT_XLA_DEBUG_LEVEL=2`, which provides a couple of useful debugging features. You can also lower the debug level to `1` to skip the execution analysis.
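
For reference, here is a minimal sketch of enabling the tool from inside a script; the output path and the assumption that the variables are read before `torch_xla` initializes are illustrative, and exporting them in the shell before launching works just as well.

```
# Sketch: enable the debugging tool programmatically (assumed to be read
# before torch_xla initializes, so set the variables before importing it).
import os

os.environ["PT_XLA_DEBUG_LEVEL"] = "2"                     # 1 = skip the execution analysis
os.environ["PT_XLA_DEBUG_FILE"] = "/tmp/pt_xla_debug.txt"  # optional: write the analysis to a file

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t1 = torch.randn(2, 2, device=device)
t2 = t1 * 100
xm.mark_step()  # triggers a compilation + execution, both analyzed by the tool
```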

### Perform An Auto-Metrics Analysis

@@ -79,41 +75,44 @@ The debugging tool will analyze every compilation and execution for your model.
```
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: user mark_step
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: 537d4b0264b029688281412214d252e9
Compilation Analysis: Number of Graph Inputs: 588
Compilation Analysis: Number of Graph Outputs: 320
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: mark_step (/workspaces/dk2/pytorch/xla/torch_xla/core/xla_model.py:840)
Compilation Analysis: broadcast_master_param (/workspaces/dk2/pytorch/xla/torch_xla/core/xla_model.py:1230)
Compilation Analysis: train_imagenet (/workspaces/dk2/pytorch/xla/test/test_train_mp_imagenet.py:261)
Compilation Analysis: _mp_fn (/workspaces/dk2/pytorch/xla/test/test_train_mp_imagenet.py:365)
Compilation Analysis: __call__ (/workspaces/dk2/pytorch/xla/torch_xla/_internal/pjrt.py:176)
Compilation Analysis: _thread_fn (/workspaces/dk2/pytorch/xla/torch_xla/_internal/pjrt.py:70)
Compilation Analysis: run (/usr/local/lib/python3.8/concurrent/futures/thread.py:57)
Compilation Analysis: _worker (/usr/local/lib/python3.8/concurrent/futures/thread.py:80)
Compilation Analysis: ..........
Compilation Analysis: mark_step in parallel loader at step end
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: c74c3b91b855b2b123f833b0d5f86943
Compilation Analysis: Number of Graph Inputs: 35
Compilation Analysis: Number of Graph Outputs: 107
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: mark_step (/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py:1055)
Compilation Analysis: next (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:44)
Compilation Analysis: __next__ (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:32)
Compilation Analysis: train_loop_fn (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:48)
Compilation Analysis: start_training (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:65)
Compilation Analysis: <module> (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:73)
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================

Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 1.548000 GB
Post Compilation Analysis: Graph output size: 7.922460 GB
Post Compilation Analysis: Aliased Input size: 1.547871 GB
Post Compilation Analysis: Intermediate tensor size: 12.124478 GB
Post Compilation Analysis: Compiled program size: 0.028210 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: user mark_step
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 537d4b0264b029688281412214d252e9
Execution Analysis: Number of Graph Inputs: 588
Execution Analysis: Number of Graph Outputs: 320
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: mark_step (/workspaces/dk2/pytorch/xla/torch_xla/core/xla_model.py:840)
Execution Analysis: broadcast_master_param (/workspaces/dk2/pytorch/xla/torch_xla/core/xla_model.py:1230)
Execution Analysis: train_imagenet (/workspaces/dk2/pytorch/xla/test/test_train_mp_imagenet.py:261)
Execution Analysis: _mp_fn (/workspaces/dk2/pytorch/xla/test/test_train_mp_imagenet.py:365)
Execution Analysis: __call__ (/workspaces/dk2/pytorch/xla/torch_xla/_internal/pjrt.py:176)
Execution Analysis: _thread_fn (/workspaces/dk2/pytorch/xla/torch_xla/_internal/pjrt.py:70)
Execution Analysis: run (/usr/local/lib/python3.8/concurrent/futures/thread.py:57)
Execution Analysis: _worker (/usr/local/lib/python3.8/concurrent/futures/thread.py:80)
Execution Analysis: ..........
Execution Analysis: mark_step in parallel loader at step end
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: c74c3b91b855b2b123f833b0d5f86943
Execution Analysis: Number of Graph Inputs: 35
Execution Analysis: Number of Graph Outputs: 107
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: mark_step (/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py:1055)
Execution Analysis: next (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:44)
Execution Analysis: __next__ (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:32)
Execution Analysis: train_loop_fn (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:48)
Execution Analysis: start_training (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:65)
Execution Analysis: <module> (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:73)
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
```
@@ -127,7 +126,7 @@ Some common causes of compilation/execution are

The executions caused by 1-4 are expected, and we want to avoid 5 by either reducing the frequency of accessing tensor values or manually adding a `mark_step` before accessing them.

Users should expect to see these `Compilation Cause` + `Execution Cause` pairs for the first couple of steps. After the model stabilizes, users should expect to see only `Execution Cause`. To use PyTorch/XLA efficiently, we expect the same model code to be run for every step and compilation to happen only once for every graph. If you keep seeing `Compilation Cause`, you should try to dump the IR/HLO following [this section](#common-debugging-environment-variables-combinations), compare the graphs for each step, and understand the source of the differences.
Users should expect to see these `Compilation Cause` + `Execution Cause` pairs for the first couple of steps. After the model stabilizes, users should expect to see only `Execution Cause` (you can disable the execution analysis with `PT_XLA_DEBUG_LEVEL=1`). To use PyTorch/XLA efficiently, we expect the same model code to be run for every step and compilation to happen only once for every graph. If you keep seeing `Compilation Cause`, you should try to dump the IR/HLO following [this section](#common-debugging-environment-variables-combinations), compare the graphs for each step, and understand the source of the differences.
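
As a rough illustration of the kind of dump that section describes (the variable values and path below are assumptions for illustration; treat the linked section as authoritative):

```
# Sketch: save the per-step IR/HLO so graphs from different steps can be diffed.
# XLA_SAVE_TENSORS_FILE / XLA_SAVE_TENSORS_FMT are described in the linked
# section; the path and format chosen here are assumptions.
import os

os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/save1.ir"
os.environ["XLA_SAVE_TENSORS_FMT"] = "text"  # "hlo" or "dot" dump other representations

# ... run a few training steps, then compare the per-step graphs written under /tmp/save1.ir*
```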

The following section will explain how to get and understand a more detailed metrics report.

@@ -192,6 +191,10 @@ import torch_xla.debug.metrics as met
met.clear_all()
```

## PyTorch/XLA + Dynamo Debugging Tool

You can enable the PyTorch/XLA + Dynamo debugging tool by setting `XLA_DYNAMO_DEBUG=1`.
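
A minimal sketch of the kind of program this setting applies to; the `openxla` backend name and the timing of when the variable is read are assumptions about the installed release:

```
# Sketch: a Dynamo + PyTorch/XLA program whose bridge activity XLA_DYNAMO_DEBUG=1
# would log. The backend name may differ across releases.
import os
os.environ["XLA_DYNAMO_DEBUG"] = "1"

import torch
import torch_xla.core.xla_model as xm

@torch.compile(backend="openxla")
def toy_program(t1):
    return t1 * 100

t1 = torch.randn(2, 2, device=xm.xla_device())
print(toy_program(t1))
```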

## Performance Profiling
To profile your workload in depth to understand bottlenecks please check the following resources:
* [Official tutorial](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm)
152 changes: 94 additions & 58 deletions test/debug_tool/test_pt_xla_debug.py
@@ -18,7 +18,11 @@ class PtXLADebugTest(unittest.TestCase):

@classmethod
def setUpClass(cls):
if not check_env_flag('PT_XLA_DEBUG'):
pt_xla_debug_enabled = xu.getenv_as('PT_XLA_DEBUG', bool, False)
cls.debug_level = xu.getenv_as('PT_XLA_DEBUG_LEVEL', int, -1)
cls.debug_level = 100 if (cls.debug_level == -1 and
pt_xla_debug_enabled) else cls.debug_level
if not check_env_flag('PT_XLA_DEBUG') and cls.debug_level == -1:
assert False, "This test should be run with PT_XLA_DEBUG"
cls.debug_file_name = os.getenv('PT_XLA_DEBUG_FILE')
if not cls.debug_file_name:
@@ -37,25 +41,30 @@ def test_user_mark_step(self):
post_compilation_infos = extract_post_compilation_analysis(lines)

self.assertEqual(len(post_compilation_infos), 1)
# test case is too small, size round to 0 MB
self.assertIn('0MB', post_compilation_infos[0].input_size)
self.assertIn('0MB', post_compilation_infos[0].output_size)
self.assertIn('0MB', post_compilation_infos[0].aliased_size)
self.assertIn('0MB', post_compilation_infos[0].intermediate_size)
self.assertIn('0MB', post_compilation_infos[0].program_size)

self.assertEqual(len(executation_causes), 1)
self.assertIn('user mark_step', executation_causes[0])
self.assertIn('GB', post_compilation_infos[0].input_size)
self.assertIn('GB', post_compilation_infos[0].output_size)
self.assertIn('GB', post_compilation_infos[0].aliased_size)
self.assertIn('GB', post_compilation_infos[0].intermediate_size)
self.assertIn('GB', post_compilation_infos[0].program_size)

if self.debug_level > 1:
self.assertEqual(len(executation_causes), 1)
self.assertIn('user mark_step', executation_causes[0])
else:
self.assertEqual(len(executation_causes), 0)

self.assertEqual(len(compilation_causes), 1)
self.assertIn('user mark_step', compilation_causes[0])

self.assertEqual(len(graph_infos), 2)
# one graph info from compilation, one from execution, hash should match
self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
if self.debug_level > 1:
self.assertEqual(len(graph_infos), 2)
# one graph info from compilation, one from execution, hash should match
self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
else:
self.assertEqual(len(graph_infos), 1)
# this graph has one input(random seed) and one output(t1)
self.assertEqual(graph_infos[1].num_input, 1)
self.assertEqual(graph_infos[1].num_output, 1)
self.assertEqual(graph_infos[0].num_input, 1)
self.assertEqual(graph_infos[0].num_output, 1)
open(self.debug_file_name, 'w').close()

def test_step_trace(self):
@@ -68,20 +77,26 @@ def test_step_trace(self):
compilation_causes = extract_compilation_cause(lines)
graph_infos = extract_graph_infos(lines)

self.assertEqual(len(causes), 1)
self.assertIn('mark_step when exiting a profiler StepTrace region',
causes[0])
if self.debug_level > 1:
self.assertEqual(len(causes), 1)
self.assertIn('mark_step when exiting a profiler StepTrace region',
causes[0])
else:
self.assertEqual(len(causes), 0)

self.assertEqual(len(compilation_causes), 1)
self.assertIn('mark_step when exiting a profiler StepTrace region',
compilation_causes[0])

self.assertEqual(len(graph_infos), 2)
# one graph info from compilation, one from execution, hash should match
self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
if self.debug_level > 1:
self.assertEqual(len(graph_infos), 2)
# one graph info from compilation, one from execution, hash should match
self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
else:
self.assertEqual(len(graph_infos), 1)
# this graph has one input(random seed) and one output(t1)
self.assertEqual(graph_infos[1].num_input, 1)
self.assertEqual(graph_infos[1].num_output, 1)
self.assertEqual(graph_infos[0].num_input, 1)
self.assertEqual(graph_infos[0].num_output, 1)
open(self.debug_file_name, 'w').close()

def test_dynamo(self):
@@ -99,29 +114,39 @@ def toy_program(t1):
compilation_causes = extract_compilation_cause(lines)
graph_infos = extract_graph_infos(lines)

self.assertEqual(len(executation_causes), 2)
self.assertIn('mark_step when dynamo processing input graphs',
executation_causes[0])
self.assertIn('dynamo is executing a compiled program',
executation_causes[1])
if self.debug_level > 1:
self.assertEqual(len(executation_causes), 2)
self.assertIn('mark_step when dynamo processing input graphs',
executation_causes[0])
self.assertIn('dynamo is executing a compiled program',
executation_causes[1])
else:
self.assertEqual(len(executation_causes), 0)

self.assertEqual(len(compilation_causes), 2)
self.assertIn('mark_step when dynamo processing input graphs',
compilation_causes[0])
self.assertIn('dynamo is compiling a FX graph to HLO',
compilation_causes[1])

# one graph info from compilation, one from execution, hash should match
self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
if self.debug_level > 1:
# one graph info from compilation, one from execution, hash should match
self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
# this graph has one input(random seed) and one output(t1)
self.assertEqual(graph_infos[1].num_input, 1)
self.assertEqual(graph_infos[1].num_output, 1)

# one graph info from dynamo compilation, one from dynamo execution, hash should match
self.assertEqual(graph_infos[2].hash, graph_infos[3].hash)
# this graph has two input(t1, 100) and one output
self.assertEqual(graph_infos[3].num_input, 2)
self.assertEqual(graph_infos[3].num_output, 1)
self.assertEqual(graph_infos[0].num_input, 1)
self.assertEqual(graph_infos[0].num_output, 1)

if self.debug_level > 1:
# one graph info from dynamo compilation, one from dynamo execution, hash should match
self.assertEqual(graph_infos[2].hash, graph_infos[3].hash)
# this graph has two input(t1, 100) and one output
self.assertEqual(graph_infos[3].num_input, 2)
self.assertEqual(graph_infos[3].num_output, 1)
else:
# this graph has two input(t1, 100) and one output
self.assertEqual(graph_infos[1].num_input, 2)
self.assertEqual(graph_infos[1].num_output, 1)

open(self.debug_file_name, 'w').close()

def test_parallel_loader(self):
@@ -150,22 +175,26 @@ def test_parallel_loader(self):
compilation_causes = extract_compilation_cause(lines)
graph_infos = extract_graph_infos(lines)

self.assertEqual(len(executation_causes), batch_size)
for cause in executation_causes:
self.assertIn('mark_step in parallel loader at step end', cause)
if self.debug_level > 1:
self.assertEqual(len(executation_causes), batch_size)
for cause in executation_causes:
self.assertIn('mark_step in parallel loader at step end', cause)
else:
self.assertEqual(len(executation_causes), 0)

# We should only compile once.
self.assertEqual(len(compilation_causes), 1)
self.assertIn('mark_step in parallel loader at step end',
compilation_causes[0])

self.assertEqual(len(graph_infos), batch_size + 1)
# one graph info from compilation, batch size from execution, hash should match
for i in range(batch_size + 1):
self.assertEqual(graph_infos[0].hash, graph_infos[i].hash)
# this graph has two input(data, 100) and one output(dummy)
self.assertEqual(graph_infos[i].num_input, 2)
self.assertEqual(graph_infos[i].num_output, 1)
if self.debug_level > 1:
self.assertEqual(len(graph_infos), batch_size + 1)
# one graph info from compilation, batch size from execution, hash should match
for i in range(batch_size + 1):
self.assertEqual(graph_infos[0].hash, graph_infos[i].hash)
# this graph has two input(data, 100) and one output(dummy)
self.assertEqual(graph_infos[i].num_input, 2)
self.assertEqual(graph_infos[i].num_output, 1)
open(self.debug_file_name, 'w').close()

def test_print(self):
@@ -178,19 +207,22 @@ def test_print(self):
compilation_causes = extract_compilation_cause(lines)
graph_infos = extract_graph_infos(lines)

self.assertEqual(len(executation_causes), 1)
self.assertIn('user code trying to access tensor value',
executation_causes[0])
if self.debug_level > 1:
self.assertEqual(len(executation_causes), 1)
self.assertIn('user code trying to access tensor value',
executation_causes[0])
# one graph info from compilation, one from execution, hash should match
self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
else:
self.assertEqual(len(executation_causes), 0)

self.assertEqual(len(compilation_causes), 1)
self.assertIn('user code trying to access tensor value',
compilation_causes[0])

# one graph info from compilation, one from execution, hash should match
self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
# this graph has one input(random seed) and one output(t1)
self.assertEqual(graph_infos[1].num_input, 1)
self.assertEqual(graph_infos[1].num_output, 1)
self.assertEqual(graph_infos[0].num_input, 1)
self.assertEqual(graph_infos[0].num_output, 1)
open(self.debug_file_name, 'w').close()

def test_frame(self):
@@ -201,16 +233,20 @@ def test_frame(self):
lines = f.readlines()
frames = extract_python_frames(lines)

# one for compilation, one for execution
self.assertEqual(len(frames), 3)
# one for compilation, one for post-compilation analysis, one for execution
if self.debug_level > 1:
self.assertEqual(len(frames), 3)
else:
self.assertEqual(len(frames), 2)
max_frame = os.getenv('PT_XLA_DEBUG_MAX_FRAME', 8)
# Additional lines are
# 1. Python Frame Triggered Execution:
# 2. ....
# 3. empty line
self.assertEqual(len(frames[0].split('\n')), max_frame + 3)
# second frame will be empty from the post-compilation-analysis
self.assertEqual(len(frames[2].split('\n')), max_frame + 3)
if self.debug_level > 1:
self.assertEqual(len(frames[2].split('\n')), max_frame + 3)
# Check mark_step is the first frame
self.assertIn('mark_step', frames[0].split('\n')[1])

6 changes: 6 additions & 0 deletions test/run_tests.sh
@@ -127,6 +127,11 @@ function run_pt_xla_debug {
PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
}

function run_pt_xla_debug_level1 {
echo "Running in save tensor file mode: $@"
PT_XLA_DEBUG_LEVEL=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
}

function run_torchrun {
if [ -x "$(command -v nvidia-smi)" ] && [ "$XLA_CUDA" != "0" ]; then
echo "Running torchrun test for GPU $@"
@@ -165,6 +170,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_pt_xla_debug "$CDIR/debug_tool/test_pt_xla_debug.py"
run_pt_xla_debug_level1 "$CDIR/debug_tool/test_pt_xla_debug.py"
run_test "$CDIR/test_async_closures.py"
run_test "$CDIR/test_hlo_metadata.py"
run_test "$CDIR/test_profiler.py"