
Conversation

@ymwangg (Contributor) commented Nov 29, 2021

This PR fixes an f64 div issue that occurs when an f32 tensor is divided by an f64 scalar. The issue significantly slows down some Hugging Face models on the NVIDIA T4 GPU, which has poor f64 performance.

Minimal code to reproduce this issue:

import torch, torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Dividing an f32 tensor by a Python float should keep the f32 dtype.
a = torch.rand(10, 10, dtype=torch.float32).to(device)
b = a / 2.0
print(torch_xla._XLAC._get_xla_tensors_text([b]))

Before this fix:

IR {
  %0 = f64[] xla::device_data(), device=GPU:0
  %1 = f32[10,10]{1,0} xla::device_data(), device=GPU:0
  %2 = f64[10,10]{1,0} aten::div(%1, %0)
  %3 = f32[10,10]{1,0} xla::cast(%2), type=f32, ROOT=0
}

After this fix:

IR {
  %0 = f32[] xla::device_data(), device=GPU:0
  %1 = f32[10,10]{1,0} xla::device_data(), device=GPU:0
  %2 = f32[10,10]{1,0} aten::div(%1, %0), ROOT=0
}
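
For reference, this matches PyTorch's eager-mode type promotion: Python scalars are treated as weakly typed and do not widen the tensor's dtype. A minimal check in eager mode (no XLA required):

import torch

a = torch.rand(10, 10, dtype=torch.float32)
# A Python float participates weakly in type promotion,
# so the result keeps the tensor's f32 dtype.
print(torch.result_type(a, 2.0))  # torch.float32
print((a / 2.0).dtype)            # torch.float32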

@JackCaoG (Collaborator) commented

Thanks @ymwangg, I will take a look.

@JackCaoG (Collaborator) left a review comment

Thanks!

  XLA_FN_COUNTER("xla::");
  at::ScalarType dtype = at::result_type(self, other);
- auto operands = GetBinaryOperands(self, other);
+ auto operands = GetBinaryOperands(self, UnwrapNumber(other, dtype));
@JackCaoG (Collaborator) commented on this change
Ah, nice catch. All of the binary ops should have this UnwrapNumber logic baked in. However, for div, now that we have rounding_mode in its signature, we don't call DoBinaryOp.
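
As an illustration of what the UnwrapNumber call achieves, here is a hypothetical Python sketch; the real helper lives in torch_xla's C++ code, and unwrap_number below is an assumed name, not the actual API:

import torch

def unwrap_number(other, dtype):
    # Hypothetical sketch: recast a 0-dim "wrapped number" tensor to
    # the promoted dtype up front, so the traced XLA division never
    # sees an f64 operand.
    if other.dim() == 0:
        return other.to(dtype)
    return other

a = torch.rand(10, 10, dtype=torch.float32)
wrapped = torch.tensor(2.0, dtype=torch.float64)  # a Python scalar lifted to f64
dtype = torch.result_type(a, 2.0)                 # torch.float32
print(unwrap_number(wrapped, dtype).dtype)        # torch.float32: the div never sees f64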

@JackCaoG (Collaborator) commented Dec 1, 2021

@ymwangg Can you rebase to get past the build error?

Add test case
@JackCaoG JackCaoG merged commit 807e25a into pytorch:master Dec 2, 2021
@ymwangg ymwangg deleted the fix_div branch August 25, 2022 18:40