-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathgen_dtype_enum.go
More file actions
372 lines (320 loc) · 12 KB
/
gen_dtype_enum.go
File metadata and controls
372 lines (320 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
package dtypes
/***** File generated by ./cmd/dtypes_codegen, don't edit it directly. *****/
import "github.com/gomlx/gopjrt/internal/protos/xla_data"
// DType identifies the data type of a buffer or of a scalar. Its constants cover
// every type supported by XLA/PJRT.
//
// Types with a natural Go counterpart use Go-style names (Int8, Float32,
// Complex64, ...), and their C enum names (S8, F32, C64, ...) are also available
// as aliases. The remaining types (F8E5M2, S4, TOKEN, ...) keep the names of the
// corresponding C constants.
//
// Unfortunately, the data type enums of XLA/PJRT (which DType mirrors) and of the
// C++ XlaBuilder (and other parts of XLA) do not agree. gopjrt uses the PJRT
// numbering everywhere and converts only when calling into C++ code: see
// DType.PrimitiveType and FromPrimitiveType.
type DType int32

// The DType values. Each one is numerically identical to the PJRT_Buffer_Type_*
// enum value for the same type in pjrt_c_api.h (a 1:1 mapping), so the numbers
// are spelled out explicitly to keep that correspondence visible.
const (
	// InvalidDType corresponds to PJRT_Buffer_Type_INVALID: an invalid type, used as the default.
	InvalidDType = DType(0)

	// Bool corresponds to PJRT_Buffer_Type_PRED: predicates, i.e. two-state booleans.
	Bool = DType(1)

	// Int8 corresponds to PJRT_Buffer_Type_S8: fixed-width 8-bit signed integers.
	Int8 = DType(2)

	// Int16 corresponds to PJRT_Buffer_Type_S16: fixed-width 16-bit signed integers.
	Int16 = DType(3)

	// Int32 corresponds to PJRT_Buffer_Type_S32: fixed-width 32-bit signed integers.
	Int32 = DType(4)

	// Int64 corresponds to PJRT_Buffer_Type_S64: fixed-width 64-bit signed integers.
	Int64 = DType(5)

	// Uint8 corresponds to PJRT_Buffer_Type_U8: fixed-width 8-bit unsigned integers.
	Uint8 = DType(6)

	// Uint16 corresponds to PJRT_Buffer_Type_U16: fixed-width 16-bit unsigned integers.
	Uint16 = DType(7)

	// Uint32 corresponds to PJRT_Buffer_Type_U32: fixed-width 32-bit unsigned integers.
	Uint32 = DType(8)

	// Uint64 corresponds to PJRT_Buffer_Type_U64: fixed-width 64-bit unsigned integers.
	Uint64 = DType(9)

	// Float16 corresponds to PJRT_Buffer_Type_F16: fixed-width 16-bit floating-point values.
	Float16 = DType(10)

	// Float32 corresponds to PJRT_Buffer_Type_F32: fixed-width 32-bit floating-point values.
	Float32 = DType(11)

	// Float64 corresponds to PJRT_Buffer_Type_F64: fixed-width 64-bit floating-point values.
	Float64 = DType(12)

	// BFloat16 corresponds to PJRT_Buffer_Type_BF16: a truncated 16-bit
	// floating-point format. It resembles IEEE's 16-bit format but uses 1 bit
	// for the sign, 8 bits for the exponent and 7 bits for the mantissa.
	BFloat16 = DType(13)

	// Complex64 corresponds to PJRT_Buffer_Type_C64: fixed-width complex values
	// made of a pair of F32 (real, imag), as in std::complex<float>.
	Complex64 = DType(14)

	// Complex128 corresponds to PJRT_Buffer_Type_C128: complex values made of a
	// pair of F64 (real, imag), as in std::complex<double>.
	Complex128 = DType(15)

	// F8E5M2 corresponds to PJRT_Buffer_Type_F8E5M2: a truncated 8-bit floating-point format.
	F8E5M2 = DType(16)

	// F8E4M3FN corresponds to PJRT_Buffer_Type_F8E4M3FN: a truncated 8-bit floating-point format.
	F8E4M3FN = DType(17)

	// F8E4M3B11FNUZ corresponds to PJRT_Buffer_Type_F8E4M3B11FNUZ: a truncated 8-bit floating-point format.
	F8E4M3B11FNUZ = DType(18)

	// F8E5M2FNUZ corresponds to PJRT_Buffer_Type_F8E5M2FNUZ: a truncated 8-bit floating-point format.
	F8E5M2FNUZ = DType(19)

	// F8E4M3FNUZ corresponds to PJRT_Buffer_Type_F8E4M3FNUZ: a truncated 8-bit floating-point format.
	F8E4M3FNUZ = DType(20)

	// S4 corresponds to PJRT_Buffer_Type_S4: 4-bit signed integers.
	S4 = DType(21)

	// U4 corresponds to PJRT_Buffer_Type_U4: 4-bit unsigned integers.
	U4 = DType(22)

	// TOKEN corresponds to PJRT_Buffer_Type_TOKEN.
	TOKEN = DType(23)

	// S2 corresponds to PJRT_Buffer_Type_S2: 2-bit signed integers.
	S2 = DType(24)

	// U2 corresponds to PJRT_Buffer_Type_U2: 2-bit unsigned integers.
	U2 = DType(25)

	// F8E4M3 corresponds to PJRT_Buffer_Type_F8E4M3: another truncated 8-bit floating-point format.
	F8E4M3 = DType(26)

	// F8E3M4 corresponds to PJRT_Buffer_Type_F8E3M4: another truncated 8-bit floating-point format.
	F8E3M4 = DType(27)

	// F8E8M0FNU corresponds to PJRT_Buffer_Type_F8E8M0FNU: another truncated 8-bit floating-point format.
	F8E8M0FNU = DType(28)

	// F4E2M1FN corresponds to PJRT_Buffer_Type_F4E2M1FN: a 4-bit MX floating-point format.
	F4E2M1FN = DType(29)
)
// Aliases from PJRT C API.
const (
// INVALID (or PJRT_Buffer_Type_INVALID) is the C enum name for InvalidDType.
INVALID = InvalidDType
// PRED (or PJRT_Buffer_Type_PRED) is the C enum name for Bool.
PRED = Bool
// S8 (or PJRT_Buffer_Type_S8) is the C enum name for Int8.
S8 = Int8
// S16 (or PJRT_Buffer_Type_S16) is the C enum name for Int16.
S16 = Int16
// S32 (or PJRT_Buffer_Type_S32) is the C enum name for Int32.
S32 = Int32
// S64 (or PJRT_Buffer_Type_S64) is the C enum name for Int64.
S64 = Int64
// U8 (or PJRT_Buffer_Type_U8) is the C enum name for Uint8.
U8 = Uint8
// U16 (or PJRT_Buffer_Type_U16) is the C enum name for Uint16.
U16 = Uint16
// U32 (or PJRT_Buffer_Type_U32) is the C enum name for Uint32.
U32 = Uint32
// U64 (or PJRT_Buffer_Type_U64) is the C enum name for Uint64.
U64 = Uint64
// F16 (or PJRT_Buffer_Type_F16) is the C enum name for Float16.
F16 = Float16
// F32 (or PJRT_Buffer_Type_F32) is the C enum name for Float32.
F32 = Float32
// F64 (or PJRT_Buffer_Type_F64) is the C enum name for Float64.
F64 = Float64
// BF16 (or PJRT_Buffer_Type_BF16) is the C enum name for BFloat16.
BF16 = BFloat16
// C64 (or PJRT_Buffer_Type_C64) is the C enum name for Complex64.
C64 = Complex64
// C128 (or PJRT_Buffer_Type_C128) is the C enum name for Complex128.
C128 = Complex128
)
// MapOfNames maps DType names to their values. It accepts both the Go-style
// names (e.g. "Float32") and the PJRT C enum names (e.g. "F32"). Per the package's
// design it is later extended during initialization to also contain the
// lower-case version of each name.
var MapOfNames = map[string]DType{
	// Go-style names.
	"InvalidDType": InvalidDType,
	"Bool":         Bool,
	"Int8":         Int8,
	"Int16":        Int16,
	"Int32":        Int32,
	"Int64":        Int64,
	"Uint8":        Uint8,
	"Uint16":       Uint16,
	"Uint32":       Uint32,
	"Uint64":       Uint64,
	"Float16":      Float16,
	"Float32":      Float32,
	"Float64":      Float64,
	"BFloat16":     BFloat16,
	"Complex64":    Complex64,
	"Complex128":   Complex128,

	// PJRT C enum names of the Go-named types above.
	"INVALID": InvalidDType,
	"PRED":    Bool,
	"S8":      Int8,
	"S16":     Int16,
	"S32":     Int32,
	"S64":     Int64,
	"U8":      Uint8,
	"U16":     Uint16,
	"U32":     Uint32,
	"U64":     Uint64,
	"F16":     Float16,
	"F32":     Float32,
	"F64":     Float64,
	"BF16":    BFloat16,
	"C64":     Complex64,
	"C128":    Complex128,

	// Types that are only known by their PJRT C enum name.
	"F8E5M2":        F8E5M2,
	"F8E4M3FN":      F8E4M3FN,
	"F8E4M3B11FNUZ": F8E4M3B11FNUZ,
	"F8E5M2FNUZ":    F8E5M2FNUZ,
	"F8E4M3FNUZ":    F8E4M3FNUZ,
	"S4":            S4,
	"U4":            U4,
	"TOKEN":         TOKEN,
	"S2":            S2,
	"U2":            U2,
	"F8E4M3":        F8E4M3,
	"F8E3M4":        F8E3M4,
	"F8E8M0FNU":     F8E8M0FNU,
	"F4E2M1FN":      F4E2M1FN,
}
// PrimitiveType returns the DType equivalent used in C++ XlaBuilder.
// For internal use only.
//
// It is unfortunate, but the data types enums used in PJRT (which DType is modeled after)
// and C++ XlaBuilder (and other parts of XLA) don't match.
func (dtype DType) PrimitiveType() xla_data.PrimitiveType {
switch dtype {
case InvalidDType:
return xla_data.PrimitiveType_PRIMITIVE_TYPE_INVALID
case Bool:
return xla_data.PrimitiveType_PRED
case Int8:
return xla_data.PrimitiveType_S8
case Int16:
return xla_data.PrimitiveType_S16
case Int32:
return xla_data.PrimitiveType_S32
case Int64:
return xla_data.PrimitiveType_S64
case Uint8:
return xla_data.PrimitiveType_U8
case Uint16:
return xla_data.PrimitiveType_U16
case Uint32:
return xla_data.PrimitiveType_U32
case Uint64:
return xla_data.PrimitiveType_U64
case Float16:
return xla_data.PrimitiveType_F16
case Float32:
return xla_data.PrimitiveType_F32
case Float64:
return xla_data.PrimitiveType_F64
case BFloat16:
return xla_data.PrimitiveType_BF16
case Complex64:
return xla_data.PrimitiveType_C64
case Complex128:
return xla_data.PrimitiveType_C128
case F8E5M2:
return xla_data.PrimitiveType_F8E5M2
case F8E4M3FN:
return xla_data.PrimitiveType_F8E4M3FN
case F8E4M3B11FNUZ:
return xla_data.PrimitiveType_F8E4M3B11FNUZ
case F8E5M2FNUZ:
return xla_data.PrimitiveType_F8E5M2FNUZ
case F8E4M3FNUZ:
return xla_data.PrimitiveType_F8E4M3FNUZ
case S4:
return xla_data.PrimitiveType_S4
case U4:
return xla_data.PrimitiveType_U4
case TOKEN:
return xla_data.PrimitiveType_TOKEN
case S2:
return xla_data.PrimitiveType_S2
case U2:
return xla_data.PrimitiveType_U2
case F8E4M3:
return xla_data.PrimitiveType_F8E4M3
case F8E3M4:
return xla_data.PrimitiveType_F8E3M4
case F8E8M0FNU:
return xla_data.PrimitiveType_F8E8M0FNU
case F4E2M1FN:
return xla_data.PrimitiveType_F4E2M1FN
default:
return xla_data.PrimitiveType_PRIMITIVE_TYPE_INVALID
}
}
// FromPrimitiveType returns the equivalent DType.
// For internal use only.
//
// It is unfortunate, but the data types enums used in PJRT (which DType is modeled after)
// and C++ XlaBuilder (and other parts of XLA) don't match.
func FromPrimitiveType(primitiveType xla_data.PrimitiveType) DType {
switch primitiveType {
case xla_data.PrimitiveType_PRIMITIVE_TYPE_INVALID:
return InvalidDType
case xla_data.PrimitiveType_PRED:
return Bool
case xla_data.PrimitiveType_S8:
return Int8
case xla_data.PrimitiveType_S16:
return Int16
case xla_data.PrimitiveType_S32:
return Int32
case xla_data.PrimitiveType_S64:
return Int64
case xla_data.PrimitiveType_U8:
return Uint8
case xla_data.PrimitiveType_U16:
return Uint16
case xla_data.PrimitiveType_U32:
return Uint32
case xla_data.PrimitiveType_U64:
return Uint64
case xla_data.PrimitiveType_F16:
return Float16
case xla_data.PrimitiveType_F32:
return Float32
case xla_data.PrimitiveType_F64:
return Float64
case xla_data.PrimitiveType_BF16:
return BFloat16
case xla_data.PrimitiveType_C64:
return Complex64
case xla_data.PrimitiveType_C128:
return Complex128
case xla_data.PrimitiveType_F8E5M2:
return F8E5M2
case xla_data.PrimitiveType_F8E4M3FN:
return F8E4M3FN
case xla_data.PrimitiveType_F8E4M3B11FNUZ:
return F8E4M3B11FNUZ
case xla_data.PrimitiveType_F8E5M2FNUZ:
return F8E5M2FNUZ
case xla_data.PrimitiveType_F8E4M3FNUZ:
return F8E4M3FNUZ
case xla_data.PrimitiveType_S4:
return S4
case xla_data.PrimitiveType_U4:
return U4
case xla_data.PrimitiveType_TOKEN:
return TOKEN
case xla_data.PrimitiveType_S2:
return S2
case xla_data.PrimitiveType_U2:
return U2
case xla_data.PrimitiveType_F8E4M3:
return F8E4M3
case xla_data.PrimitiveType_F8E3M4:
return F8E3M4
case xla_data.PrimitiveType_F8E8M0FNU:
return F8E8M0FNU
case xla_data.PrimitiveType_F4E2M1FN:
return F4E2M1FN
default:
return InvalidDType
}
}