
Commit 1fb09bc

Add fp8e4m3fn support (#7842)
1 parent 41bf6da commit 1fb09bc

3 files changed: +103 −38 lines changed

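With this change, torch.float8_e4m3fn joins torch.float8_e5m2 as an fp8 dtype that torch_xla can move between host and XLA device and lower through ops such as matmul. A minimal usage sketch mirroring the updated tests (a sketch only; assumes a working torch_xla install with an available XLA device):

import torch
import torch_xla

device = torch_xla.device()

# Build fp8e4m3fn operands on the host and move them to the XLA device.
t = torch.rand(3, 2).to(torch.float8_e4m3fn).to(device)
w = torch.rand(2, 5).to(torch.float8_e4m3fn).to(device)

# The matmul lowers to an XLA dot over f8e4m3fn operands.
out = torch.matmul(t, w)
print(torch_xla._XLAC._get_xla_tensors_hlo([out]))  # HLO text mentions f8e4m3fn

# Transferring back to the host preserves the dtype.
print(out.cpu().dtype)  # torch.float8_e4m3fn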

test/test_fp8.py

Lines changed: 45 additions & 38 deletions
@@ -4,46 +4,53 @@
 import torch
 import torch_xla
 import unittest
-
-
-class Fp8Test(unittest.TestCase):
-
-  def test_fp8(self):
-    device = torch_xla.device()
-    fp8_types = [torch.float8_e5m2]
-    for dtype in fp8_types:
-      t = torch.rand(2, 2).to(dtype)
-      xla_t = t.to(device)
-      torch_t = xla_t.cpu()
-      self.assertEqual(xla_t.dtype, dtype)
-      self.assertEqual(torch_t.dtype, dtype)
-      # Need to cast to float32 since allclose doesn't work with fp8.
-      self.assertTrue(
-          torch.allclose(t.to(torch.float32), torch_t.to(torch.float32)))
-
-  def test_fp8_matmul(self):
-    device = torch_xla.device()
-    fp8_types = [torch.float8_e5m2]
-    for dtype in fp8_types:
-      t = torch.rand(3, 2).to(dtype)
-      w = torch.rand(2, 5).to(dtype)
-      torch_matmul = torch.matmul(t, w)
-      xla_t = t.to(device)
-      xla_w = w.to(device)
-      xla_matmul = torch.matmul(xla_t, xla_w)
-      xla_matmul = xla_matmul.cpu()
-      # Need to cast to float32 since allclose doesn't work with fp8.
-      self.assertTrue(
-          torch.allclose(
-              xla_matmul.to(torch.float32), torch_matmul.to(torch.float32)))
-
-  def test_fp8_hlo(self):
-    device = torch_xla.device()
-    x = torch.randn((3, 5)).to(torch.float8_e5m2).to(device)
-    w = torch.randn((5, 8)).to(torch.float8_e5m2).to(device)
+from absl.testing import parameterized
+
+device = torch_xla.device()
+
+dtype_parameters = [
+    torch.float8_e5m2,
+    torch.float8_e4m3fn,
+]
+
+
+class Fp8Test(parameterized.TestCase):
+
+  @parameterized.parameters(*dtype_parameters)
+  def test_fp8(self, dtype):
+    t = torch.rand(2, 2).to(dtype)
+    xla_t = t.to(device)
+    torch_t = xla_t.cpu()
+    self.assertEqual(xla_t.dtype, dtype)
+    self.assertEqual(torch_t.dtype, dtype)
+    # Need to cast to float32 since allclose doesn't work with fp8.
+    self.assertTrue(
+        torch.allclose(t.to(torch.float32), torch_t.to(torch.float32)))
+
+  @parameterized.parameters(*dtype_parameters)
+  def test_fp8_matmul(self, dtype):
+    t = torch.rand(3, 2).to(dtype)
+    w = torch.rand(2, 5).to(dtype)
+    torch_matmul = torch.matmul(t, w)
+    xla_t = t.to(device)
+    xla_w = w.to(device)
+    xla_matmul = torch.matmul(xla_t, xla_w)
+    xla_matmul = xla_matmul.cpu()
+    # Need to cast to float32 since allclose doesn't work with fp8.
+    self.assertTrue(
+        torch.allclose(
+            xla_matmul.to(torch.float32), torch_matmul.to(torch.float32)))
+
+  @parameterized.parameters(*dtype_parameters)
+  def test_fp8_hlo(self, dtype):
+    x = torch.randn((3, 5)).to(dtype).to(device)
+    w = torch.randn((5, 8)).to(dtype).to(device)
     output = torch.matmul(x, w)
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
-    self.assertTrue(re.search(r'f8e5m2.*dot.*f8e5m2.*f8e5m2', hlo) is not None)
+    exmy_str = str(dtype).split('_')[-1]
+    self.assertTrue(
+        re.search(rf'f8{exmy_str}.*dot.*f8{exmy_str}.*f8{exmy_str}', hlo)
+        is not None)
 
 
 if __name__ == '__main__':
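The updated HLO check derives the XLA element-type name from the string form of the torch dtype rather than hard-coding f8e5m2. A quick illustration of that string mapping (plain Python, no XLA device needed):

import torch

for dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
  exmy_str = str(dtype).split('_')[-1]  # 'e5m2' and 'e4m3fn'
  print(f'f8{exmy_str}')  # 'f8e5m2' and 'f8e4m3fn', the names that appear in HLO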

torch_xla/csrc/dtype.cpp

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,8 @@ at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
   switch (xla_type) {
     case xla::PrimitiveType::BF16:
      return at::ScalarType::BFloat16;
+    case xla::PrimitiveType::F8E4M3FN:
+      return at::ScalarType::Float8_e4m3fn;
     case xla::PrimitiveType::F8E5M2:
       return at::ScalarType::Float8_e5m2;
     case xla::PrimitiveType::F16:
@@ -51,6 +53,8 @@ xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) {
       return xla::PrimitiveType::BF16;
     case at::ScalarType::Half:
      return xla::PrimitiveType::F16;
+    case at::ScalarType::Float8_e4m3fn:
+      return xla::PrimitiveType::F8E4M3FN;
     case at::ScalarType::Float8_e5m2:
       return xla::PrimitiveType::F8E5M2;
     case at::ScalarType::Bool:
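These two switch entries make the mapping bidirectional: Float8_e4m3fn tensors sent to the device become F8E4M3FN, and F8E4M3FN results coming back materialize as Float8_e4m3fn. Observed from Python, the dtype now survives a device round trip (a sketch under the same assumptions as above):

import torch
import torch_xla

device = torch_xla.device()
t = torch.rand(2, 2).to(torch.float8_e4m3fn)
xla_t = t.to(device)
assert xla_t.dtype == torch.float8_e4m3fn        # XLA tensor reports fp8e4m3fn
assert xla_t.cpu().dtype == torch.float8_e4m3fn  # and so does the host copy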

torch_xla/csrc/tensor_util.cpp

Lines changed: 54 additions & 0 deletions
@@ -63,6 +63,22 @@ struct Caster<tsl::bfloat16> {
   }
 };
 
+template <>
+struct Caster<at::Float8_e4m3fn> {
+  template <typename D>
+  D cast(const at::Float8_e4m3fn& value) const {
+    return static_cast<D>(static_cast<float>(value));
+  }
+};
+
+template <>
+struct Caster<tsl::float8_e4m3fn> {
+  template <typename D>
+  D cast(const tsl::float8_e4m3fn& value) const {
+    return static_cast<D>(static_cast<float>(value));
+  }
+};
+
 template <>
 struct Caster<at::Float8_e5m2> {
   template <typename D>
@@ -201,6 +217,15 @@ template <>
 struct NeedCast<at::BFloat16> {
   static constexpr bool value = true;
 };
+
+template <>
+struct NeedCast<tsl::float8_e4m3fn> {
+  static constexpr bool value = true;
+};
+template <>
+struct NeedCast<at::Float8_e4m3fn> {
+  static constexpr bool value = true;
+};
 template <>
 struct NeedCast<tsl::float8_e5m2> {
   static constexpr bool value = true;
@@ -274,6 +299,18 @@ void CopyData<tsl::bfloat16, at::BFloat16>(tsl::bfloat16* dest,
   CheckedMemcpy<tsl::bfloat16, at::BFloat16>(dest, source, n);
 }
 template <>
+void CopyData<at::Float8_e4m3fn, tsl::float8_e4m3fn>(
+    at::Float8_e4m3fn* dest, const tsl::float8_e4m3fn* source, int64_t n,
+    const CopyCasted&) {
+  CheckedMemcpy<at::Float8_e4m3fn, tsl::float8_e4m3fn>(dest, source, n);
+}
+template <>
+void CopyData<tsl::float8_e4m3fn, at::Float8_e4m3fn>(
+    tsl::float8_e4m3fn* dest, const at::Float8_e4m3fn* source, int64_t n,
+    const CopyCasted&) {
+  CheckedMemcpy<tsl::float8_e4m3fn, at::Float8_e4m3fn>(dest, source, n);
+}
+template <>
 void CopyData<at::Float8_e5m2, tsl::float8_e5m2>(at::Float8_e5m2* dest,
                                                  const tsl::float8_e5m2* source,
                                                  int64_t n, const CopyCasted&) {
@@ -451,6 +488,10 @@ void TensorToBufferSType(const at::Tensor& tensor, const xla::Shape& dest_shape,
       TensorToBuffer<SType, double>(tensor, dest_shape, dest_buffer,
                                     dest_buffer_size, device);
       break;
+    case xla::PrimitiveType::F8E4M3FN:
+      TensorToBuffer<SType, tsl::float8_e4m3fn>(tensor, dest_shape, dest_buffer,
+                                                dest_buffer_size, device);
+      break;
     case xla::PrimitiveType::F8E5M2:
       TensorToBuffer<SType, tsl::float8_e5m2>(tensor, dest_shape, dest_buffer,
                                               dest_buffer_size, device);
@@ -578,6 +619,10 @@ at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
                                                  dest_element_type);
     case at::ScalarType::Half:
       return XlaLiteralToTensor<SType, at::Half>(literal, dest_element_type);
+    case at::ScalarType::Float8_e4m3fn:
+      return XlaLiteralToTensor<SType, at::Float8_e4m3fn>(literal,
+                                                          dest_element_type);
+
     case at::ScalarType::Float8_e5m2:
       return XlaLiteralToTensor<SType, at::Float8_e5m2>(literal,
                                                         dest_element_type);
@@ -611,6 +656,11 @@ void PopulateTensorBuffer(const at::Tensor& tensor,
       TensorToBufferSType<at::BFloat16>(tensor, dest_shape, dest_buffer,
                                         dest_buffer_size, device);
       break;
+    case at::ScalarType::Float8_e4m3fn:
+      TensorToBufferSType<at::Float8_e4m3fn>(tensor, dest_shape, dest_buffer,
+                                             dest_buffer_size, device);
+      break;
+
     case at::ScalarType::Float8_e5m2:
       TensorToBufferSType<at::Float8_e5m2>(tensor, dest_shape, dest_buffer,
                                            dest_buffer_size, device);
@@ -674,6 +724,10 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
     case xla::PrimitiveType::BF16:
      return XlaLiteralToTensorHelper<tsl::bfloat16>(literal,
                                                     dest_element_type);
+    case xla::PrimitiveType::F8E4M3FN:
+      return XlaLiteralToTensorHelper<tsl::float8_e4m3fn>(literal,
+                                                          dest_element_type);
+
     case xla::PrimitiveType::F8E5M2:
       return XlaLiteralToTensorHelper<tsl::float8_e5m2>(literal,
                                                         dest_element_type);
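The Caster, NeedCast, and CopyData specializations above handle the host-side buffer conversions used when fp8e4m3fn data is uploaded to or downloaded from the device, so values as well as dtypes round-trip. A value-level check in the style of the tests, which upcast to float32 because torch.allclose does not support fp8 (a sketch, same assumptions as above):

import torch
import torch_xla

device = torch_xla.device()
t = torch.rand(2, 2).to(torch.float8_e4m3fn)
back = t.to(device).cpu()
# allclose doesn't work with fp8, so compare after upcasting to float32.
assert torch.allclose(t.to(torch.float32), back.to(torch.float32))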
