Description
Describe the bug
When a model uses the default Keras MultiHeadAttention layer, calling quantize_apply raises an error even when the attention layers themselves are not annotated for quantisation.
System information
TensorFlow version: 2.6.0.dev20210330
TensorFlow Model Optimization version: 0.5.0.dev20210409
Python version: 3.7.10
Describe the expected behavior
The model is quantised successfully.
Describe the current behavior
During quantisation, an error is thrown about incompatible variable shapes.
Code to reproduce the issue
import tensorflow as tf
import tensorflow_model_optimization as tfmot

def get_model():
    inp = tf.keras.layers.Input(shape=(16, 8), batch_size=32)
    x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)(query=inp, value=inp)
    out = tf.keras.layers.Dense(units=64)(x)
    return tf.keras.Model(inputs=inp, outputs=out)

def apply_quantization_to_dense(layer):
    if isinstance(layer, tf.keras.layers.Dense):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer

base_model = get_model()
annotated_model = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_quantization_to_dense,
)
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)  # Error
Additional context
This error is not specific to the MultiHeadAttention layer: if a custom layer has multiple weights whose names end with the same string (e.g. kernel:0), the same type of error occurs.
In model_transformer.py, the _weight_name() method extracts only the last component of a weight's name, so when several weights share that suffix (as they do in the MultiHeadAttention layer) the weights are mapped to each other incorrectly.
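The collision can be sketched in isolation. The helper below mirrors the described behaviour of _weight_name() (keeping only the part after the last '/'); the weight names are illustrative, not the exact ones Keras generates:

```python
def _weight_name(name):
    # Mirrors the behaviour described above: keep only the last
    # path component of a fully qualified weight name.
    return name.split('/')[-1]

# Illustrative MultiHeadAttention-style weight names (hypothetical,
# not the exact names Keras produces).
weight_names = [
    'multi_head_attention/query/kernel:0',
    'multi_head_attention/key/kernel:0',
    'multi_head_attention/value/kernel:0',
]

stripped = [_weight_name(n) for n in weight_names]
# All three distinct weights collapse to the same key, 'kernel:0'.

# Keying a mapping on the stripped name then silently merges entries,
# so weights end up matched to the wrong variables (and hence to
# variables of incompatible shapes).
mapping = {_weight_name(n): n for n in weight_names}
# len(mapping) is 1 instead of 3.
```

Matching on the full name (or on the name relative to the layer, rather than the final path component) would avoid the collision.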