|
13 | 13 | from keras.src import quantizers |
14 | 14 | from keras.src import regularizers |
15 | 15 | from keras.src.api_export import keras_export |
| 16 | +from keras.src.dtype_policies import QuantizedFloat8DTypePolicy |
16 | 17 | from keras.src.layers.input_spec import InputSpec |
17 | 18 | from keras.src.layers.layer import Layer |
18 | 19 | from keras.src.quantizers.quantization_config import QuantizationConfig |
19 | | -from keras.src.quantizers.quantization_config import validate_and_resolve_config |
20 | 20 | from keras.src.quantizers.quantizers import dequantize_with_sz_map |
21 | 21 |
|
22 | 22 |
|
@@ -457,91 +457,8 @@ def quantized_build(self, kernel_shape, mode, config=None): |
457 | 457 | raise self._quantization_mode_error(mode) |
458 | 458 | self._is_quantized = True |
459 | 459 |
|
460 | | - def quantize(self, mode=None, type_check=True, config=None): |
461 | | - # Prevent quantization of the subclasses |
462 | | - if type_check and (type(self) is not EinsumDense): |
463 | | - raise self._not_implemented_error(self.quantize) |
464 | | - |
465 | | - config = validate_and_resolve_config(mode, config) |
466 | | - mode = config.mode |
467 | | - |
468 | | - kernel_shape = self._kernel.shape |
469 | | - if mode == "int8": |
470 | | - # Handle activation quantization |
471 | | - if config.activation_quantizer: |
472 | | - self.inputs_quantizer = config.activation_quantizer |
473 | | - elif config.activation_quantizer is None: |
474 | | - # Weight-only quantization |
475 | | - pass |
476 | | - else: |
477 | | - # Default behavior |
478 | | - self.inputs_quantizer = quantizers.AbsMaxQuantizer( |
479 | | - axis=self._input_reduced_axes |
480 | | - ) |
481 | | - |
482 | | - # Handle weight quantization |
483 | | - # Quantize `self._kernel` to int8 and compute corresponding scale |
484 | | - weight_quantizer = QuantizationConfig.weight_quantizer_or_default( |
485 | | - config, quantizers.AbsMaxQuantizer(axis=0) |
486 | | - ) |
487 | | - self._kernel, self.kernel_scale = weight_quantizer( |
488 | | - self._kernel, to_numpy=True |
489 | | - ) |
490 | | - self.quantized_build(kernel_shape, mode, config) |
491 | | - |
492 | | - elif mode == "int4": |
493 | | - # Handle activation quantization |
494 | | - if config.activation_quantizer: |
495 | | - self.inputs_quantizer = config.activation_quantizer |
496 | | - elif config.activation_quantizer is None: |
497 | | - # Weight-only quantization |
498 | | - pass |
499 | | - else: |
500 | | - # Default behavior |
501 | | - self.inputs_quantizer = quantizers.AbsMaxQuantizer( |
502 | | - axis=self._input_reduced_axes |
503 | | - ) |
504 | | - |
505 | | - # Handle weight quantization |
506 | | - # 1. Quantize to int4 values (stored in int8 dtype, range [-8, 7]) |
507 | | - weight_quantizer = QuantizationConfig.weight_quantizer_or_default( |
508 | | - config, |
509 | | - quantizers.AbsMaxQuantizer( |
510 | | - axis=0, |
511 | | - value_range=(-8, 7), |
512 | | - output_dtype="int8", |
513 | | - ), |
514 | | - ) |
515 | | - self._kernel, self.kernel_scale = weight_quantizer( |
516 | | - self._kernel, to_numpy=True |
517 | | - ) |
518 | | - # 2. Pack two int4 values into a single int8 byte. |
519 | | - # Choose the axis to perform int4 packing - use the first reduced |
520 | | - # axis for the kernel (analogous to the input dimension of a Dense |
521 | | - # layer). |
522 | | - self._int4_pack_axis = ( |
523 | | - self._kernel_reduced_axes[0] if self._kernel_reduced_axes else 0 |
524 | | - ) |
525 | | - self._kernel, _, _ = quantizers.pack_int4( |
526 | | - self._kernel, axis=self._int4_pack_axis |
527 | | - ) |
528 | | - self.quantized_build(kernel_shape, mode, config) |
529 | | - |
530 | | - elif mode == "float8": |
531 | | - self.quantized_build(kernel_shape, mode) |
532 | | - |
533 | | - elif mode == "gptq": |
534 | | - self.quantized_build(kernel_shape, mode, config) |
535 | | - |
536 | | - # Set new dtype policy. |
537 | | - if self.dtype_policy.quantization_mode is None: |
538 | | - policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}") |
539 | | - self.dtype_policy = policy |
540 | | - |
541 | 460 | def _int8_build(self, kernel_shape, config=None): |
542 | 461 | self._set_quantization_info() |
543 | | - from keras.src.quantizers.quantization_config import QuantizationConfig |
544 | | - |
545 | 462 | self.inputs_quantizer = ( |
546 | 463 | QuantizationConfig.activation_quantizer_or_default( |
547 | 464 | config, |
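
Both the deleted `quantize` body and the surviving `_int8_build` path default to `quantizers.AbsMaxQuantizer`. For readers new to the scheme, here is a minimal NumPy sketch of abs-max int8 quantization; it illustrates the idea only and is not Keras's `AbsMaxQuantizer` (which, among other differences, may store the reciprocal scale):

    import numpy as np

    def absmax_quantize_int8(x, axis=0):
        # Per-axis scale: the largest magnitude along `axis` maps to the
        # int8 limit 127. keepdims=True so the scale broadcasts back.
        amax = np.max(np.abs(x), axis=axis, keepdims=True)
        scale = amax / 127.0
        safe_scale = np.where(scale == 0.0, 1.0, scale)  # avoid div-by-zero
        x_q = np.clip(np.round(x / safe_scale), -127, 127).astype(np.int8)
        return x_q, scale

    kernel = np.random.randn(8, 4).astype("float32")
    kernel_q, kernel_scale = absmax_quantize_int8(kernel, axis=0)
    # Round-trip check: dequantized values sit within one step of the input.
    assert np.max(np.abs(kernel_q * kernel_scale - kernel)) <= np.max(kernel_scale)
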
@@ -691,8 +608,6 @@ def _int4_build(self, kernel_shape, config=None): |
691 | 608 | self._set_quantization_info() |
692 | 609 |
|
693 | 610 | # Quantizer for the inputs (per the reduced axes) |
694 | | - from keras.src.quantizers.quantization_config import QuantizationConfig |
695 | | - |
696 | 611 | self.inputs_quantizer = ( |
697 | 612 | QuantizationConfig.activation_quantizer_or_default( |
698 | 613 | config, |
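
`_int4_build` (and the deleted `quantize` body) rely on `quantizers.pack_int4` to halve storage by fitting two int4 values, each in [-8, 7], into one int8 byte along the first reduced axis. A rough sketch of the mechanics, assuming a split-halves nibble layout; the actual byte layout of `keras.src.quantizers.pack_int4` may differ:

    import numpy as np

    def pack_int4_sketch(x, axis=0):
        # x holds int4 values in [-8, 7] stored in an int8 array. Pad the
        # pack axis to an even length, then fit two values per byte.
        if x.shape[axis] % 2:
            pad = [(0, 0)] * x.ndim
            pad[axis] = (0, 1)
            x = np.pad(x, pad)
        lo, hi = np.split(x.astype(np.uint8) & 0x0F, 2, axis=axis)
        return ((hi << 4) | lo).astype(np.int8)

    def unpack_int4_sketch(packed, axis=0):
        b = packed.astype(np.uint8)
        lo, hi = (b & 0x0F).astype(np.int8), (b >> 4).astype(np.int8)
        # Sign-extend 4-bit two's complement: nibbles >= 8 are negative.
        lo, hi = (np.where(n >= 8, n - 16, n) for n in (lo, hi))
        return np.concatenate([lo, hi], axis=axis)

    values = np.random.randint(-8, 8, size=(6, 3), dtype=np.int8)
    assert (unpack_int4_sketch(pack_int4_sketch(values)) == values).all()
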
@@ -736,8 +651,6 @@ def _int4_build(self, kernel_shape, config=None): |
736 | 651 | ) |
737 | 652 |
|
738 | 653 | def _float8_build(self): |
739 | | - from keras.src.dtype_policies import QuantizedFloat8DTypePolicy |
740 | | - |
741 | 654 | # If `self.dtype_policy` is not QuantizedFloat8DTypePolicy, then set |
742 | 655 | # `amax_history_length` to its default value. |
743 | 656 | amax_history_length = getattr( |
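
The `amax_history_length` read here sizes the rolling window used by float8 delayed scaling: each step records the tensor's current amax, and the scale for the next step is derived from the window's maximum. A rough NumPy sketch of that bookkeeping, assuming the common e4m3 format for forward tensors; the layer's actual float8 call differs in detail:

    import numpy as np

    E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3

    def update_amax_history(history, x):
        # Shift the window left and record the newest observed amax.
        history = np.roll(history, -1)
        history[-1] = np.max(np.abs(x))
        return history

    def scale_from_history(history):
        # Choose the scale so the largest recently seen value maps to
        # the float8 maximum (guarding the all-zero startup state).
        return np.maximum(np.max(history), 1e-12) / E4M3_MAX

    history = np.zeros(1024, dtype="float32")  # default history length
    x = np.random.randn(4, 8).astype("float32")
    history = update_amax_history(history, x)
    scale = scale_from_history(history)
    x_f8_emulated = np.clip(x / scale, -E4M3_MAX, E4M3_MAX)
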
@@ -903,7 +816,8 @@ def grad_fn(*args, upstream=None): |
903 | 816 | # Quantize inputs per `self.inputs_quantizer`. |
904 | 817 | if self.inputs_quantizer: |
905 | 818 | inputs_q, inputs_scale = self.inputs_quantizer(inputs) |
906 | | - # Align `inputs_scale` axes with the output for correct broadcasting |
| 819 | + # Align `inputs_scale` axes with the output |
| 820 | + # for correct broadcasting |
907 | 821 | inputs_scale = self._adjust_scale_for_quant( |
908 | 822 | inputs_scale, "input" |
909 | 823 | ) |
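
The alignment that `_adjust_scale_for_quant` performs matters because the per-axis input scale must broadcast against the einsum output when the int32 accumulator is dequantized. A toy example for the Dense-like equation "ab,bc->ac" with hypothetical shapes (the real layer handles arbitrary equations):

    import numpy as np

    # Inputs quantized per row (reduced axis 1) -> scale shape (4, 1),
    # kernel quantized per output column -> scale shape (1, 3).
    inputs = np.random.randn(4, 8).astype("float32")
    amax = np.maximum(np.max(np.abs(inputs), axis=1, keepdims=True), 1e-12)
    inputs_scale = amax / 127.0                      # (4, 1)
    inputs_q = np.round(inputs / inputs_scale).astype(np.int8)

    kernel_q = np.random.randint(-127, 128, size=(8, 3), dtype=np.int8)
    kernel_scale = np.full((1, 3), 0.02, dtype="float32")

    # Accumulate in int32, then dequantize: because inputs_scale is kept
    # as (4, 1) rather than (4,), both scales broadcast cleanly against
    # the (4, 3) output.
    acc = np.einsum("ab,bc->ac", inputs_q.astype(np.int32),
                    kernel_q.astype(np.int32))
    outputs = acc.astype("float32") * inputs_scale * kernel_scale
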
@@ -1036,10 +950,8 @@ def quantize(self, mode, type_check=True, config=None): |
1036 | 950 | raise self._not_implemented_error(self.quantize) |
1037 | 951 |
|
1038 | 952 | kernel_shape = self._kernel.shape |
1039 | | - if mode in ("int8", "int4", "gptq"): |
1040 | | - self._set_quantization_info() |
1041 | 953 |
|
1042 | | - from keras.src.quantizers.quantization_config import QuantizationConfig |
| 954 | + self._set_quantization_info() |
1043 | 955 |
|
1044 | 956 | if mode == "int8": |
1045 | 957 | # Quantize `self._kernel` to int8 and compute corresponding scale |
|
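
Finally, both `quantize` implementations end the same way: if the layer is not already on a quantized policy, they swap in one named `<mode>_from_<original policy name>`. For example (assuming a float32 layer quantized to int8):

    from keras import dtype_policies

    # Quantized policy names encode the mode and the original policy,
    # e.g. "int8_from_float32" or "int4_from_mixed_bfloat16".
    policy = dtype_policies.get("int8_from_float32")
    print(policy.quantization_mode)  # "int8"
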