Recently I have been studying your code. However, it seems to me that your implementation does not expand the KV cache during the decoding phase. The following code is excerpted from the function `_concatenate_to_cache` in `llama.py`:
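```python
if query.shape[1] == 1:
    # Decoding phase: exactly one new token per step.
    mesh = LLaMAConfig.get_jax_mesh(self.config.mesh_dim)
    def fn(cached_key, cached_value, key, value, cur_index):
        assert key.shape[1] == 1 and value.shape[1] == 1, (key.shape, value.shape)
        # Number of cache slots held by each shard along the 'sp' mesh axis.
        sp_size = max_length // mesh.shape['sp']
        axis_index = jax.lax.axis_index('sp')
        # Translate the global write position into this shard's local index.
        cur_index = cur_index - axis_index * sp_size
        # Only the shard that owns the slot writes the new key/value;
        # every other shard returns its cache unchanged.
        key, value = jax.lax.cond(
            jnp.logical_and(cur_index >= 0, cur_index < sp_size),
            lambda: (
                cached_key.at[:, cur_index].set(key[:, -1]),
                cached_value.at[:, cur_index].set(value[:, -1]),
            ),
            lambda: (cached_key, cached_value),
        )
        return key, value
```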
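To check my reading of the index arithmetic, here is a minimal sketch (not code from the repository) that simulates the sequence-sharded write without a device mesh. The helper names `write_local_shard` and `sharded_write` are hypothetical; `axis_index` is passed explicitly instead of calling `jax.lax.axis_index('sp')`, and the cache is shaped `(batch, max_length, dim)` for brevity, which is an assumption rather than the repo's layout. It shows that the shards together perform exactly one in-place write at the global `cur_index`, so the cache keeps its fixed length.

```python
import jax
import jax.numpy as jnp


def write_local_shard(cache_shard, new_kv, cur_index, axis_index, sp_size):
    """Hypothetical helper: write new_kv into this shard only if it owns cur_index."""
    # Convert the global position into a position local to this shard.
    local_index = cur_index - axis_index * sp_size
    owns_slot = jnp.logical_and(local_index >= 0, local_index < sp_size)
    # The guard also prevents a negative local_index from wrapping around.
    return jax.lax.cond(
        owns_slot,
        lambda: cache_shard.at[:, local_index].set(new_kv[:, -1]),
        lambda: cache_shard,
    )


def sharded_write(cache, new_kv, cur_index, num_shards):
    """Hypothetical helper: split the cache along the sequence axis, update, rejoin."""
    max_length = cache.shape[1]
    sp_size = max_length // num_shards
    shards = [cache[:, i * sp_size:(i + 1) * sp_size] for i in range(num_shards)]
    updated = [
        write_local_shard(shard, new_kv, cur_index, i, sp_size)
        for i, shard in enumerate(shards)
    ]
    return jnp.concatenate(updated, axis=1)


max_length, num_shards = 8, 4
cache = jnp.zeros((1, max_length, 2))
new_kv = jnp.ones((1, 1, 2))  # one new token, as asserted in the excerpt
cur_index = jnp.array(5)

result = sharded_write(cache, new_kv, cur_index, num_shards)
# Same effect as a single unsharded write at the global position.
expected = cache.at[:, 5].set(new_kv[:, -1])
```

With 8 slots over 4 shards, only shard 2 (global slots 4 and 5) sees a local index in range, so the result matches the unsharded write and the cache length stays at `max_length`.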
In this function, during the decoding phase the newly generated key/value are only written into an existing slot of `cached_key` and `cached_value` (at position `cur_index`), instead of being appended to them. However, it seems to me that a correct KV cache implementation should grow longer as each new token is generated.
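For comparison, here is a minimal single-head sketch (not code from the repository) of the two strategies: a cache that grows by concatenation, and a preallocated fixed-length buffer that is written at slot `t` while slots beyond `t` are masked out of the softmax. The masking step is an assumption on my part; I have not confirmed where, or whether, the repo applies such a mask. The `attend` helper is hypothetical.

```python
import jax
import jax.numpy as jnp


def attend(q, k, v, mask=None):
    """Hypothetical helper: single-head scaled dot-product attention for one query."""
    scores = k @ q / jnp.sqrt(q.shape[-1])  # one score per cache slot
    if mask is not None:
        # Masked slots get -inf, so they receive zero softmax weight.
        scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores)
    return weights @ v


max_length, dim, steps = 8, 4, 5
ks, vs, qs = jax.random.normal(jax.random.PRNGKey(0), (3, steps, dim))

grow_k = jnp.zeros((0, dim))
grow_v = jnp.zeros((0, dim))
fixed_k = jnp.zeros((max_length, dim))
fixed_v = jnp.zeros((max_length, dim))

grow_out, fixed_out = [], []
for t in range(steps):
    # Strategy 1: append, so the cache length grows by one per step.
    grow_k = jnp.concatenate([grow_k, ks[t:t + 1]], axis=0)
    grow_v = jnp.concatenate([grow_v, vs[t:t + 1]], axis=0)
    grow_out.append(attend(qs[t], grow_k, grow_v))

    # Strategy 2: fixed-length buffer, write slot t, mask slots after t.
    fixed_k = fixed_k.at[t].set(ks[t])
    fixed_v = fixed_v.at[t].set(vs[t])
    mask = jnp.arange(max_length) <= t
    fixed_out.append(attend(qs[t], fixed_k, fixed_v, mask))
```

Under this assumption both strategies give the same attention output at every step; the difference is that the second keeps array shapes constant, which matters under `jax.jit` because a new input shape triggers retracing and recompilation.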
Perhaps I do not fully understand your code; I look forward to your reply.