diff --git a/tensorflow_probability/python/layers/distribution_layer.py b/tensorflow_probability/python/layers/distribution_layer.py index f536417d1b..e856d14f2c 100644 --- a/tensorflow_probability/python/layers/distribution_layer.py +++ b/tensorflow_probability/python/layers/distribution_layer.py @@ -68,8 +68,11 @@ 'VariationalGaussianProcess', ] - -tf.keras.__internal__.utils.register_symbolic_tensor_type(dtc._TensorCoercible) # pylint: disable=protected-access +try: + k_u = tf.keras.__internal__.utils +except: + from keras.utils import tf_utils as k_u +k_u.register_symbolic_tensor_type(dtc._TensorCoercible) # pylint: disable=protected-access def _event_size(event_shape, name=None):