@@ -63,6 +63,22 @@ struct Caster<tsl::bfloat16> {
6363 }
6464};
6565
66+ template <>
67+ struct Caster <at::Float8_e4m3fn> {
68+ template <typename D>
69+ D cast (const at::Float8_e4m3fn& value) const {
70+ return static_cast <D>(static_cast <float >(value));
71+ }
72+ };
73+
74+ template <>
75+ struct Caster <tsl::float8_e4m3fn> {
76+ template <typename D>
77+ D cast (const tsl::float8_e4m3fn& value) const {
78+ return static_cast <D>(static_cast <float >(value));
79+ }
80+ };
81+
6682template <>
6783struct Caster <at::Float8_e5m2> {
6884 template <typename D>
@@ -201,6 +217,15 @@ template <>
201217struct NeedCast <at::BFloat16> {
202218 static constexpr bool value = true ;
203219};
220+
221+ template <>
222+ struct NeedCast <tsl::float8_e4m3fn> {
223+ static constexpr bool value = true ;
224+ };
225+ template <>
226+ struct NeedCast <at::Float8_e4m3fn> {
227+ static constexpr bool value = true ;
228+ };
204229template <>
205230struct NeedCast <tsl::float8_e5m2> {
206231 static constexpr bool value = true ;
@@ -274,6 +299,18 @@ void CopyData<tsl::bfloat16, at::BFloat16>(tsl::bfloat16* dest,
274299 CheckedMemcpy<tsl::bfloat16, at::BFloat16>(dest, source, n);
275300}
276301template <>
302+ void CopyData<at::Float8_e4m3fn, tsl::float8_e4m3fn>(
303+ at::Float8_e4m3fn* dest, const tsl::float8_e4m3fn* source, int64_t n,
304+ const CopyCasted&) {
305+ CheckedMemcpy<at::Float8_e4m3fn, tsl::float8_e4m3fn>(dest, source, n);
306+ }
307+ template <>
308+ void CopyData<tsl::float8_e4m3fn, at::Float8_e4m3fn>(
309+ tsl::float8_e4m3fn* dest, const at::Float8_e4m3fn* source, int64_t n,
310+ const CopyCasted&) {
311+ CheckedMemcpy<tsl::float8_e4m3fn, at::Float8_e4m3fn>(dest, source, n);
312+ }
313+ template <>
277314void CopyData<at::Float8_e5m2, tsl::float8_e5m2>(at::Float8_e5m2* dest,
278315 const tsl::float8_e5m2* source,
279316 int64_t n, const CopyCasted&) {
@@ -451,6 +488,10 @@ void TensorToBufferSType(const at::Tensor& tensor, const xla::Shape& dest_shape,
451488 TensorToBuffer<SType, double >(tensor, dest_shape, dest_buffer,
452489 dest_buffer_size, device);
453490 break ;
491+ case xla::PrimitiveType::F8E4M3FN:
492+ TensorToBuffer<SType, tsl::float8_e4m3fn>(tensor, dest_shape, dest_buffer,
493+ dest_buffer_size, device);
494+ break ;
454495 case xla::PrimitiveType::F8E5M2:
455496 TensorToBuffer<SType, tsl::float8_e5m2>(tensor, dest_shape, dest_buffer,
456497 dest_buffer_size, device);
@@ -578,6 +619,10 @@ at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
578619 dest_element_type);
579620 case at::ScalarType::Half:
580621 return XlaLiteralToTensor<SType, at::Half>(literal, dest_element_type);
622+ case at::ScalarType::Float8_e4m3fn:
623+ return XlaLiteralToTensor<SType, at::Float8_e4m3fn>(literal,
624+ dest_element_type);
625+
581626 case at::ScalarType::Float8_e5m2:
582627 return XlaLiteralToTensor<SType, at::Float8_e5m2>(literal,
583628 dest_element_type);
@@ -611,6 +656,11 @@ void PopulateTensorBuffer(const at::Tensor& tensor,
611656 TensorToBufferSType<at::BFloat16>(tensor, dest_shape, dest_buffer,
612657 dest_buffer_size, device);
613658 break ;
659+ case at::ScalarType::Float8_e4m3fn:
660+ TensorToBufferSType<at::Float8_e4m3fn>(tensor, dest_shape, dest_buffer,
661+ dest_buffer_size, device);
662+ break ;
663+
614664 case at::ScalarType::Float8_e5m2:
615665 TensorToBufferSType<at::Float8_e5m2>(tensor, dest_shape, dest_buffer,
616666 dest_buffer_size, device);
@@ -674,6 +724,10 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
674724 case xla::PrimitiveType::BF16:
675725 return XlaLiteralToTensorHelper<tsl::bfloat16>(literal,
676726 dest_element_type);
727+ case xla::PrimitiveType::F8E4M3FN:
728+ return XlaLiteralToTensorHelper<tsl::float8_e4m3fn>(literal,
729+ dest_element_type);
730+
677731 case xla::PrimitiveType::F8E5M2:
678732 return XlaLiteralToTensorHelper<tsl::float8_e5m2>(literal,
679733 dest_element_type);
0 commit comments