Commit 2870e93

[benchmarks] Set matrix multiplication precision. (#7748)
1 parent 6280701 commit 2870e93

10 files changed (+26, -25 lines)

benchmarks/experiment_runner.py

Lines changed: 13 additions & 1 deletion
@@ -29,6 +29,7 @@
 from benchmark_experiment import ExperimentLoader, BenchmarkExperiment
 from util import cleanup, move_to_device, randomize_input, reset_rng_state, us_to_s, ns_to_s, StrOrBool
 
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.profiler as xp
 
@@ -939,6 +940,11 @@ def __str__(self):
       action="store_true",
       help="Whether to enable fast F32 multiplication in PyTorch.",
   )
+  parser.add_argument(
+      "--matmul-precision",
+      choices=["default", "high", "highest"],
+      help="Set matrix multiplication for both PyTorch and PyTorch/XLA.",
+  )
   parser.add_argument(
       "--experiment-config",
       type=str,
@@ -1009,9 +1015,15 @@ def main():
   logging.basicConfig(level=args.log_level.value, force=True)
   logger.debug(f"Parsed args: {args}")
 
+  precision = 'highest'
+  if args.matmul_precision is not None:
+    precision = args.matmul_precision
+  # --disable-tf32 flag may overwrite precision settings for BC reasons.
   if not args.disable_tf32:
     logger.warning('Enabling fast F32 multiplication for PyTorch')
-    torch.set_float32_matmul_precision('high')
+    precision = 'high'
+  torch.set_float32_matmul_precision(precision)
+  torch_xla._XLAC._xla_set_mat_mul_precision(precision)
 
   if args.profile_xla:
     logger.info(
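
Note: the runner now threads a single precision value through both frameworks. Below is a minimal sketch of the resulting behavior, assuming a working torch/torch_xla install; the `args.*` names refer to the flags parsed above.

```python
# Sketch of the precision resolution in main() after this change.
import torch
import torch_xla

precision = 'highest'                 # default when --matmul-precision is absent
# precision = args.matmul_precision   # taken from the flag when provided
# precision = 'high'                  # forced when --disable-tf32 is not passed (BC path)

torch.set_float32_matmul_precision(precision)          # PyTorch side
torch_xla._XLAC._xla_set_mat_mul_precision(precision)  # PyTorch/XLA side
```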

test/bench.py

Lines changed: 1 addition & 2 deletions
@@ -129,6 +129,5 @@ def run_benchmarks(args):
   args.benchs = benchs
 
   torch.set_default_dtype(torch.float32)
-  torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
-      use_full_mat_mul_precision=True)
+  torch_xla._XLAC._xla_set_mat_mul_precision('highest')
   run_benchmarks(args)
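
Note: the test files below all receive this same mechanical replacement. As a migration sketch, the old binding took a boolean and could only select HIGHEST or DEFAULT; it is removed by this commit, so the "before" call is shown commented out for comparison only.

```python
import torch_xla

# Before: boolean toggle, only HIGHEST or DEFAULT reachable.
# torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
#     use_full_mat_mul_precision=True)

# After: explicit precision string; 'high' is now expressible too.
torch_xla._XLAC._xla_set_mat_mul_precision('highest')
```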

test/pytorch_test_base.py

Lines changed: 1 addition & 2 deletions
@@ -632,8 +632,7 @@ def get_primary_device(cls):
   def setUpClass(cls):
     # Sets the primary test device to the xla_device (CPU or TPU)
     cls.primary_device = str(xm.xla_device())
-    torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
-        use_full_mat_mul_precision=True)
+    torch_xla._XLAC._xla_set_mat_mul_precision('highest')
 
   def setUp(self):
     super().setUp()

test/test_gmm.py

Lines changed: 1 addition & 2 deletions
@@ -465,7 +465,6 @@ def test_gmm_backward_3(self):
   logging.getLogger().setLevel(logging.INFO)
   torch.set_default_dtype(torch.float32)
   torch.manual_seed(42)
-  torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
-      use_full_mat_mul_precision=True)
+  torch_xla._XLAC._xla_set_mat_mul_precision('highest')
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)

test/test_mp_distributed_mm.py

Lines changed: 1 addition & 2 deletions
@@ -12,8 +12,7 @@ def _mp_fn(index):
 
   if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
     world_size = xr.world_size()
-    torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
-        use_full_mat_mul_precision=True)
+    torch_xla._XLAC._xla_set_mat_mul_precision('highest')
     torch.manual_seed(11)
     xm.set_rng_state(11)
 
test/test_operations.py

Lines changed: 1 addition & 2 deletions
@@ -3104,8 +3104,7 @@ def test_repeat_special(self):
 if __name__ == '__main__':
   torch.set_default_dtype(torch.float32)
   torch.manual_seed(42)
-  torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
-      use_full_mat_mul_precision=True)
+  torch_xla._XLAC._xla_set_mat_mul_precision('highest')
   test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
   if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
     print(met.metrics_report())

test/test_operations_hlo.py

Lines changed: 1 addition & 2 deletions
@@ -71,8 +71,7 @@ def test_dropout_by_u8_mask(self):
 if __name__ == '__main__':
   torch.set_default_dtype(torch.float32)
   torch.manual_seed(42)
-  torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
-      use_full_mat_mul_precision=True)
+  torch_xla._XLAC._xla_set_mat_mul_precision('highest')
   test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
   if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
     print(met.metrics_report())

test/test_pallas.py

Lines changed: 1 addition & 2 deletions
@@ -925,7 +925,6 @@ def test_flash_attention_sm_scale_backward(self):
   logging.getLogger().setLevel(logging.INFO)
   torch.set_default_dtype(torch.float32)
   torch.manual_seed(42)
-  torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
-      use_full_mat_mul_precision=True)
+  torch_xla._XLAC._xla_set_mat_mul_precision('highest')
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)

test/test_pallas_spmd.py

Lines changed: 1 addition & 2 deletions
@@ -103,8 +103,7 @@ def test_flash_attention_backward_spmd_data_parallel(self):
   logging.getLogger().setLevel(logging.INFO)
   torch.set_default_dtype(torch.float32)
   torch.manual_seed(42)
-  torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
-      use_full_mat_mul_precision=True)
+  torch_xla._XLAC._xla_set_mat_mul_precision('highest')
   xr.use_spmd()
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 5 additions & 8 deletions
@@ -1939,14 +1939,11 @@ void InitXlaModuleBindings(py::module m) {
         py::arg("nodes_threshold") = 100, py::arg("device") = "");
   m.def("_xla_memory_info",
         [](const std::string& device) { return GetMemoryInfo(device); });
-  m.def(
-      "_xla_set_use_full_mat_mul_precision",
-      [](bool use_full_mat_mul_precision) {
-        XlaHelpers::set_mat_mul_precision(use_full_mat_mul_precision
-                                              ? xla::PrecisionConfig::HIGHEST
-                                              : xla::PrecisionConfig::DEFAULT);
-      },
-      py::arg("use_full_mat_mul_precision") = true);
+  m.def("_xla_set_mat_mul_precision", [](const std::string& mat_mul_precision) {
+    xla::PrecisionConfig::Precision precision =
+        ConsumeValue(xla::StringToPrecision(mat_mul_precision));
+    XlaHelpers::set_mat_mul_precision(precision);
+  });
 
   py::class_<xla::XlaBuilder, op_builder::BuilderPtr>(m, "XlaBuilder");
   py::class_<op_builder::Op, op_builder::OpPtr>(m, "XlaOp");
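
Note: the new binding parses the string with xla::StringToPrecision before applying it, so the accepted spellings match the lowercase choices exposed by the --matmul-precision flag. A small usage sketch from Python:

```python
import torch_xla

# Any of the three precisions can now be selected directly.
for precision in ('default', 'high', 'highest'):
  torch_xla._XLAC._xla_set_mat_mul_precision(precision)

# A string outside this set is expected to fail inside
# ConsumeValue(StringToPrecision(...)) rather than be silently ignored.
```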
