
Commit fad1ed2

improve einsum
1 parent 6344cf5 commit fad1ed2

File tree

6 files changed: 12 additions & 118 deletions


keras/src/layers/core/dense.py

Lines changed: 1 addition & 16 deletions
@@ -9,6 +9,7 @@
 from keras.src import quantizers
 from keras.src import regularizers
 from keras.src.api_export import keras_export
+from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
 from keras.src.quantizers.quantization_config import QuantizationConfig
@@ -384,8 +385,6 @@ def quantized_build(self, kernel_shape, mode, config=None):
         self._is_quantized = True

     def _int8_build(self, kernel_shape, config=None):
-        from keras.src.quantizers.quantization_config import QuantizationConfig
-
         # Per-channel int8 quantizer for the last axis (features).
         self.inputs_quantizer = (
             QuantizationConfig.activation_quantizer_or_default(
@@ -500,8 +499,6 @@ def _int4_build(self, kernel_shape, config=None):
             int8 byte.
         """
         # Per-channel int8 quantizer for the last axis (features).
-        from keras.src.quantizers.quantization_config import QuantizationConfig
-
         self.inputs_quantizer = (
             QuantizationConfig.activation_quantizer_or_default(
                 config, quantizers.AbsMaxQuantizer(axis=-1)
@@ -529,8 +526,6 @@ def _int4_build(self, kernel_shape, config=None):
         self._orig_input_dim = input_dim

     def _float8_build(self):
-        from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
-
         # If `self.dtype_policy` is not QuantizedFloat8DTypePolicy, then set
         # `amax_history_length` to its default value.
         amax_history_length = getattr(
@@ -781,16 +776,6 @@ def quantize(self, mode=None, type_check=True, config=None):

         kernel_shape = self._kernel.shape
         if mode == "int8":
-            # Handle activation quantization
-            if config.activation_quantizer:
-                self.inputs_quantizer = config.activation_quantizer
-            elif config.activation_quantizer is None:
-                # Weight-only quantization
-                pass
-            else:
-                # Default behavior
-                self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
-
             # Handle weight quantization
             # Quantize `self._kernel` to int8 and compute corresponding scale
             weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
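
Note on the removed branching: the inline `if config.activation_quantizer: ... elif ... else ...` logic deleted from `Dense.quantize` above is the behavior that `QuantizationConfig.activation_quantizer_or_default(config, default)` now centralizes (the helper lives in keras/src/quantizers/quantization_config.py). A minimal sketch of the presumed semantics, for orientation only and not the actual implementation:

    # Sketch only: presumed behavior of the helper called as
    # QuantizationConfig.activation_quantizer_or_default(config, default).
    def activation_quantizer_or_default(config, default):
        if config is None:
            # No config supplied: use the layer's default quantizer.
            return default
        quantizer = config.activation_quantizer
        if quantizer is None:
            # Explicit None requests weight-only quantization.
            return None
        if quantizer == "default":
            # Sentinel: fall back to the layer's default,
            # e.g. quantizers.AbsMaxQuantizer(axis=-1) for Dense.
            return default
        return quantizer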

keras/src/layers/core/einsum_dense.py

Lines changed: 4 additions & 92 deletions
@@ -13,10 +13,10 @@
 from keras.src import quantizers
 from keras.src import regularizers
 from keras.src.api_export import keras_export
+from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
 from keras.src.quantizers.quantization_config import QuantizationConfig
-from keras.src.quantizers.quantization_config import validate_and_resolve_config
 from keras.src.quantizers.quantizers import dequantize_with_sz_map


@@ -457,91 +457,8 @@ def quantized_build(self, kernel_shape, mode, config=None):
             raise self._quantization_mode_error(mode)
         self._is_quantized = True

-    def quantize(self, mode=None, type_check=True, config=None):
-        # Prevent quantization of the subclasses
-        if type_check and (type(self) is not EinsumDense):
-            raise self._not_implemented_error(self.quantize)
-
-        config = validate_and_resolve_config(mode, config)
-        mode = config.mode
-
-        kernel_shape = self._kernel.shape
-        if mode == "int8":
-            # Handle activation quantization
-            if config.activation_quantizer:
-                self.inputs_quantizer = config.activation_quantizer
-            elif config.activation_quantizer is None:
-                # Weight-only quantization
-                pass
-            else:
-                # Default behavior
-                self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-                    axis=self._input_reduced_axes
-                )
-
-            # Handle weight quantization
-            # Quantize `self._kernel` to int8 and compute corresponding scale
-            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
-                config, quantizers.AbsMaxQuantizer(axis=0)
-            )
-            self._kernel, self.kernel_scale = weight_quantizer(
-                self._kernel, to_numpy=True
-            )
-            self.quantized_build(kernel_shape, mode, config)
-
-        elif mode == "int4":
-            # Handle activation quantization
-            if config.activation_quantizer:
-                self.inputs_quantizer = config.activation_quantizer
-            elif config.activation_quantizer is None:
-                # Weight-only quantization
-                pass
-            else:
-                # Default behavior
-                self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-                    axis=self._input_reduced_axes
-                )
-
-            # Handle weight quantization
-            # 1. Quantize to int4 values (stored in int8 dtype, range [-8, 7])
-            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
-                config,
-                quantizers.AbsMaxQuantizer(
-                    axis=0,
-                    value_range=(-8, 7),
-                    output_dtype="int8",
-                ),
-            )
-            self._kernel, self.kernel_scale = weight_quantizer(
-                self._kernel, to_numpy=True
-            )
-            # 2. Pack two int4 values into a single int8 byte.
-            # Choose the axis to perform int4 packing - use the first reduced
-            # axis for the kernel (analogous to the input dimension of a Dense
-            # layer).
-            self._int4_pack_axis = (
-                self._kernel_reduced_axes[0] if self._kernel_reduced_axes else 0
-            )
-            self._kernel, _, _ = quantizers.pack_int4(
-                self._kernel, axis=self._int4_pack_axis
-            )
-            self.quantized_build(kernel_shape, mode, config)
-
-        elif mode == "float8":
-            self.quantized_build(kernel_shape, mode)
-
-        elif mode == "gptq":
-            self.quantized_build(kernel_shape, mode, config)
-
-        # Set new dtype policy.
-        if self.dtype_policy.quantization_mode is None:
-            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
-            self.dtype_policy = policy
-
     def _int8_build(self, kernel_shape, config=None):
         self._set_quantization_info()
-        from keras.src.quantizers.quantization_config import QuantizationConfig
-
         self.inputs_quantizer = (
             QuantizationConfig.activation_quantizer_or_default(
                 config,
@@ -691,8 +608,6 @@ def _int4_build(self, kernel_shape, config=None):
         self._set_quantization_info()

         # Quantizer for the inputs (per the reduced axes)
-        from keras.src.quantizers.quantization_config import QuantizationConfig
-
         self.inputs_quantizer = (
             QuantizationConfig.activation_quantizer_or_default(
                 config,
@@ -736,8 +651,6 @@ def _int4_build(self, kernel_shape, config=None):
         )

     def _float8_build(self):
-        from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
-
         # If `self.dtype_policy` is not QuantizedFloat8DTypePolicy, then set
         # `amax_history_length` to its default value.
         amax_history_length = getattr(
@@ -903,7 +816,8 @@ def grad_fn(*args, upstream=None):
         # Quantize inputs per `self.inputs_quantizer`.
         if self.inputs_quantizer:
             inputs_q, inputs_scale = self.inputs_quantizer(inputs)
-            # Align `inputs_scale` axes with the output for correct broadcasting
+            # Align `inputs_scale` axes with the output
+            # for correct broadcasting
             inputs_scale = self._adjust_scale_for_quant(
                 inputs_scale, "input"
             )
@@ -1036,10 +950,8 @@ def quantize(self, mode, type_check=True, config=None):
            raise self._not_implemented_error(self.quantize)

        kernel_shape = self._kernel.shape
-       if mode in ("int8", "int4", "gptq"):
-           self._set_quantization_info()

-       from keras.src.quantizers.quantization_config import QuantizationConfig
+       self._set_quantization_info()

        if mode == "int8":
            # Quantize `self._kernel` to int8 and compute corresponding scale
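
The int4 path that remains in the consolidated `quantize` stores two 4-bit values per int8 byte with `quantizers.pack_int4`, packed along the kernel's first reduced axis. A rough NumPy illustration of that packing idea follows; it is not the actual Keras implementation, and the nibble order and padding choice are assumptions:

    import numpy as np

    def pack_int4_sketch(values, axis=0):
        # `values` holds int4 numbers in [-8, 7] stored in an int8 array.
        values = np.moveaxis(values, axis, 0)
        if values.shape[0] % 2:
            # Pad the packing axis to an even length.
            pad = np.zeros((1,) + values.shape[1:], dtype=values.dtype)
            values = np.concatenate([values, pad], axis=0)
        low = values[0::2].astype(np.uint8) & 0x0F
        high = values[1::2].astype(np.uint8) & 0x0F
        # One value per nibble: even indices in the low nibble, odd in the high.
        packed = ((high << 4) | low).astype(np.int8)
        return np.moveaxis(packed, 0, axis)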

keras/src/layers/core/reversible_embedding.py

Lines changed: 0 additions & 2 deletions
@@ -187,7 +187,6 @@ def _int8_build(self, embeddings_shape, config=None):
         if embeddings_shape is None:
             embeddings_shape = (self.input_dim, self.output_dim)
         super()._int8_build(embeddings_shape=embeddings_shape)
-        from keras.src.quantizers.quantization_config import QuantizationConfig

         self.inputs_quantizer = (
             QuantizationConfig.activation_quantizer_or_default(
@@ -213,7 +212,6 @@ def _int4_build(self, embeddings_shape, config=None):
         if embeddings_shape is None:
             embeddings_shape = (self.input_dim, self.output_dim)
         super()._int4_build(embeddings_shape=embeddings_shape, config=config)
-        from keras.src.quantizers.quantization_config import QuantizationConfig

         self.inputs_quantizer = (
             QuantizationConfig.activation_quantizer_or_default(

keras/src/models/model.py

Lines changed: 1 addition & 3 deletions
@@ -9,6 +9,7 @@
 from keras.src.layers.layer import Layer
 from keras.src.models.variable_mapping import map_saveable_variables
 from keras.src.quantizers.gptq_core import gptq_quantize
+from keras.src.quantizers.quantization_config import validate_and_resolve_config
 from keras.src.saving import saving_api
 from keras.src.trainers import trainer as base_trainer
 from keras.src.utils import summary_utils
@@ -433,9 +434,6 @@ def quantize(self, mode=None, config=None, **kwargs):
                 time.
             config: The configuration of the quantization.
         """
-        from keras.src.quantizers.quantization_config import (
-            validate_and_resolve_config,
-        )

         # Validate inputs.
         type_check = kwargs.pop("type_check", True)
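
For orientation, the public entry point touched here is `Model.quantize(mode=None, config=None, **kwargs)`, whose config path goes through `validate_and_resolve_config`. A hedged usage sketch (the model architecture is illustrative, and the config call assumes the `keras.quantizers.Int8QuantizationConfig` export shown in the quantization_config.py diff below):

    import keras

    # Build a small model so the Dense kernels exist before quantization.
    model = keras.Sequential([
        keras.layers.Input(shape=(16,)),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(8),
    ])

    # Mode-only call: each layer falls back to its default quantizers.
    model.quantize("int8")

    # Alternative (on a freshly built model): pass an explicit config object,
    # which validate_and_resolve_config turns into the effective mode/quantizers.
    # config = keras.quantizers.Int8QuantizationConfig(
    #     weight_quantizer=None,       # None -> layer default weight quantizer
    #     activation_quantizer=None,   # None -> weight-only quantization
    # )
    # model.quantize(config=config)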

keras/src/quantizers/gptq_test.py

Lines changed: 2 additions & 1 deletion
@@ -617,7 +617,8 @@ def test_quantize_gptq_combinations(self, dataset, config):
             "mode": "gptq",
             "config": {"weight_bits": 4},
             "expected_exception": ValueError,
-            "error_msg": "Argument `config` must be an instance of `QuantizationConfig`",
+            "error_msg": "Argument `config` must be an instance of "
+            "`QuantizationConfig`",
         },
         {
             "testcase_name": "gptq_with_none_config",

keras/src/quantizers/quantization_config.py

Lines changed: 4 additions & 4 deletions
@@ -54,9 +54,9 @@ def activation_quantizer_or_default(config, default):
 @keras_export("keras.quantizers.Int8QuantizationConfig")
 class Int8QuantizationConfig(QuantizationConfig):
     def __init__(self, weight_quantizer=None, activation_quantizer="default"):
-        if activation_quantizer == "default":
-            from keras.src.quantizers.quantizers import AbsMaxQuantizer
+        from keras.src.quantizers.quantizers import AbsMaxQuantizer

+        if activation_quantizer == "default":
             activation_quantizer = AbsMaxQuantizer(axis=-1)
         super().__init__(weight_quantizer, activation_quantizer)
         if self.weight_quantizer:
@@ -76,9 +76,9 @@ def mode(self):
 @keras_export("keras.quantizers.Int4QuantizationConfig")
 class Int4QuantizationConfig(QuantizationConfig):
     def __init__(self, weight_quantizer=None, activation_quantizer="default"):
-        if activation_quantizer == "default":
-            from keras.src.quantizers.quantizers import AbsMaxQuantizer
+        from keras.src.quantizers.quantizers import AbsMaxQuantizer

+        if activation_quantizer == "default":
             activation_quantizer = AbsMaxQuantizer(axis=-1)
         super().__init__(weight_quantizer, activation_quantizer)
         if self.weight_quantizer:
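
The net effect of both `__init__` changes is only to hoist the AbsMaxQuantizer import out of the `if` branch; observable behavior is unchanged. A short illustration of that behavior, assuming the `keras.quantizers` export paths implied by the decorators above (the attribute access mirrors the code shown in the diff):

    from keras import quantizers

    # Default: the "default" sentinel resolves to a per-feature abs-max quantizer.
    cfg = quantizers.Int8QuantizationConfig()
    assert isinstance(cfg.activation_quantizer, quantizers.AbsMaxQuantizer)

    # Weight-only: passing None keeps activations unquantized.
    weight_only = quantizers.Int8QuantizationConfig(activation_quantizer=None)
    assert weight_only.activation_quantizer is None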
