KV Cache

  • Background: Self-Attention Complexity

    • For the standard self-attention mechanism, the time complexity is O(n^2·d), where:
      • n represents the number of input tokens (or sequence length).
      • d denotes the dimensionality of the vector representations.
    • This quadratic complexity arises because self-attention computes pairwise interactions between all tokens, as the sketch below illustrates.
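
To see where the quadratic term comes from, here is a minimal single-head sketch in PyTorch (random tensors, no learned weights, purely illustrative): the query-key product alone produces an n × n score matrix, and each entry costs d multiply-adds.

```python
import torch

n, d = 1024, 64                       # sequence length, head dimension
Q = torch.randn(n, d)                 # queries for all n tokens
K = torch.randn(n, d)                 # keys for all n tokens
V = torch.randn(n, d)                 # values for all n tokens

scores = Q @ K.T / d**0.5             # (n, n) matrix -> n^2 * d multiply-adds
weights = torch.softmax(scores, dim=-1)
out = weights @ V                     # another n^2 * d multiply-adds
print(scores.shape)                   # torch.Size([1024, 1024])
```

Doubling n quadruples the size of the score matrix, which is exactly the cost that KV caching targets at inference time.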
  • What Is KV Caching?

    • The KV cache is a technique used to reduce the time complexity of self-attention during inference.
    • It stores the key and value vectors for each token in the input sequence.
    • KV caching takes place across successive token-generation steps and is used in the decoder part of LLMs.
    • Specifically, it happens in decoder-only models like GPT or in the decoder portion of encoder-decoder models like T5.
    • The purpose of KV caching is to avoid recomputing the key (K) and value (V) tensors of past tokens at every generation step.
    • These tensors are cached (stored) in GPU memory as they are computed during the generation process; a minimal sketch of this logic follows the figure below.

Figure: KV caching during decoder inference. Reference: Transformers KV Caching Explained
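
For intuition, the following is a minimal single-head PyTorch sketch of the caching logic; the projection matrices `W_q`, `W_k`, `W_v` are random stand-ins for a real model's learned weights. At each decode step only the new token's key and value are computed and appended to the cache, and attention for the new token is evaluated against all cached entries.

```python
import torch

d = 64
W_q = torch.randn(d, d)               # hypothetical projection weights
W_k = torch.randn(d, d)
W_v = torch.randn(d, d)

k_cache, v_cache = [], []             # grows by one entry per generated token

def decode_step(x_new):
    """x_new: (1, d) embedding of the newest token only."""
    q = x_new @ W_q                               # (1, d) query for the new token
    k_cache.append(x_new @ W_k)                   # cache K for this token
    v_cache.append(x_new @ W_v)                   # cache V for this token
    K = torch.cat(k_cache, dim=0)                 # (t, d): all cached keys
    V = torch.cat(v_cache, dim=0)                 # (t, d): all cached values
    attn = torch.softmax(q @ K.T / d**0.5, dim=-1)
    return attn @ V                               # (1, d) output for the new token

for _ in range(5):                                # generate 5 tokens
    out = decode_step(torch.randn(1, d))
print(len(k_cache), out.shape)                    # 5 torch.Size([1, 64])
```

Note that each step only ever projects one token; the per-step cost grows linearly with the number of cached tokens instead of quadratically with the full sequence.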

  • Why Is It Important?
    • In an auto-regressive generation process (like GPT-2 text generation), the decoder predicts the next token from all previously generated tokens.
    • Since the decoder is causal (each token attends only to preceding tokens), the keys and values of earlier tokens never change, yet without caching they would be recomputed at every generation step.
    • KV caching allows us to focus on calculating attention only for the new token, improving efficiency.
    • By caching the previous keys and values, we avoid redundant computations.
    • Although it requires more GPU VRAM (or CPU RAM if no GPU is used), it significantly speeds up the matrix multiplications at each step.
  • Comparison with and without KV Caching:
    • When KV caching is used, each step operates on much smaller matrices (a single query row instead of the full sequence), so the matrix operations are faster.
    • Without KV caching, the keys and values of all previous tokens are needlessly recomputed at every step; the timing sketch below makes the difference visible.
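
The difference is easy to observe with the Hugging Face transformers library, whose `generate()` method accepts a `use_cache` flag. The sketch below uses `gpt2` purely as a small example model; exact timings depend on hardware and model choice.

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
inputs = tokenizer("The KV cache stores", return_tensors="pt")

for use_cache in (True, False):
    start = time.time()
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=100,
                       do_sample=False, use_cache=use_cache)
    print(f"use_cache={use_cache}: {time.time() - start:.2f}s")
```

The gap widens as more tokens are generated, since the uncached run re-projects the entire prefix at every step.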
  • Memory Consumption and Challenges:
    • The KV cache grows linearly with the batch size and, more importantly, with the total sequence length.
    • Since the total sequence length is not known in advance, managing KV cache memory requirements becomes particularly challenging.
  • Let’s look at some numbers for popular MHA (multi-head attention) models:
    • For example, Meta's Llama-2 (7B) model has a KV cache footprint of approximately 0.5 MB per token, assuming half-precision (FP16 or BF16) storage; the calculation below reproduces this figure.
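
The per-token figure follows from Llama-2 7B's published architecture (32 layers, 32 attention heads, head dimension 128) with a back-of-the-envelope calculation; the helper below is an illustrative sketch, not a profiler measurement.

```python
def kv_cache_bytes_per_token(n_layers, n_kv_heads, d_head, bytes_per_elem=2):
    # Factor of 2 accounts for storing both the K and the V tensor.
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_elem

# Llama-2 7B: 32 layers, 32 KV heads (standard MHA), head dim 128, FP16 (2 bytes)
per_token = kv_cache_bytes_per_token(n_layers=32, n_kv_heads=32, d_head=128)
print(per_token, "bytes per token")                    # 524288 bytes ~= 0.5 MB/token

# Total cache grows linearly with batch size and sequence length:
print(per_token * 1 * 4096 / 2**30, "GiB for batch 1, 4096 tokens")   # ~2 GiB
```

Multiplying by batch size and sequence length makes it clear why long contexts and large batches quickly dominate GPU memory.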
  • Trade-Off: Memory vs. Compute:
    • KV caching is a trade-off between memory and compute. By storing precomputed tensors, it reduces the need for recomputation but increases memory usage.
    • The KV cache size is determined by the number of tokens, the number of layers and attention heads, the head dimension, and the numerical precision used.