Introduction
Recently, many efforts have emerged aiming to expand the context length of Large Language Models (LLMs), all while keeping their perplexity stable. Among these initiatives, a noteworthy paper titled LongLoRA stands out. It impressively manages to extend the context size of the Llama2-70B model to 32k and the more compact Llama2-7B model to a whopping 100k! And guess what? It accomplishes this through a surprisingly straightforward approach (and saying this doesn’t undermine the paper’s value – in fact, I’m a fan of simple and effective solutions!). In this article, we’ll take a close look at LongLoRA, breaking down how it operates and reviewing its practical applications.
Why is it hard to expand the context size?
Understanding context size in Large Language Models (LLMs) is crucial before we delve deeper into LongLoRA’s mechanisms. In simple terms, the “context size” refers to the amount of text or data the model can consider or “read” at once while generating responses. It’s like the model’s short-term memory, influencing the quality and relevance of its output.
Traditionally, LLMs have been limited in their context sizes due to computational constraints. The foundational architecture of LLMs relies on the self-attention mechanism to incorporate contextual information, and the self-attention operation scales quadratically with sequence length. Bigger context sizes mean the model requires more computation and memory, making it more challenging and expensive to train or even fine-tune. If we expand the context size from 2,048 to 8,192, it would require roughly 16x the computational cost (since every token needs to attend to all the other tokens). This limitation restricts the model’s ability to understand and generate text that is coherent and relevant over longer passages, which is crucial for tasks like detailed document summarization, long-form question answering, and more. Most existing methods that attempt to expand the context size of LLMs focus on extrapolating positional embeddings, but this work tackles the problem at the attention mechanism.
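To make that 16x figure concrete, here is the arithmetic, writing the self-attention cost as a function of sequence length $L$:

$\text{cost}(L) \propto L^2 \;\Rightarrow\; \dfrac{\text{cost}(8192)}{\text{cost}(2048)} = \left(\dfrac{8192}{2048}\right)^2 = 4^2 = 16$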
LongLoRA
How do we maintain the computational cost while expanding the context length?
Shift Short Attention ($S^2-Attn$)
The concept behind $S^2-Attn$ is quite straightforward. Imagine we want to make the context length longer but don’t want to spend more on computational costs. One way to do this is by dividing the tokens into several smaller groups and then applying self-attention to each group separately.
Let’s visualize it with an example. Suppose we aim to extend the context size to 8,192. Instead of applying self-attention to all 8,192 tokens at once, which would be computationally expensive, we divide them into four smaller groups, each containing 2,048 tokens. Now, we apply self-attention within each of these smaller groups. Each group handles 2,048 tokens, and the self-attention process occurs within these manageable chunks. This method allows us to handle more tokens overall (a total of 8,192) without the quadratic blow-up in computational cost, since self-attention is still operating on smaller segments (groups of 2,048 tokens) at a time. This way, $S^2-Attn$ cleverly enables a longer context without demanding much more computational power.
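Compared with full attention over the extended context, the grouped variant trades quadratic growth for linear growth:

$\underbrace{8192^2}_{\text{full attention}} = 16 \cdot 2048^2 \qquad \text{vs.} \qquad \underbrace{4 \cdot 2048^2}_{\text{four groups of 2,048}}$

So grouping is four times cheaper here, and in general the cost scales as $(\text{number of groups}) \times (\text{group size})^2$, i.e., linearly in the sequence length for a fixed group size.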
While efficient, it faces a challenge: information gets disconnected between the various groups. To tackle this, the authors introduce a clever technique called the shift pattern in the attention mechanism.
Let’s break this down. First, the attention heads are divided into two halves. With the first half, self-attention is applied within each group of tokens. For example, it operates on tokens [1…2,048], then [2,049…4,096], and so on, up to the last group [6,145…8,192].
However, with the second half of the attention heads, things get a bit different. Here, a shift pattern is implemented. Self-attention begins with tokens [1,025…3,072], then moves on to [3,073…5,120], followed by [5,121…7,168], and finally [7,169…8,192…1,024]. Notice something special about the shifted pattern? The first and last 1,024 tokens of the context are in the same group. This smart shifting reconnects the information that was previously disjointed between groups in the first half of the attention heads, and it does this without requiring extra computational cost. In this manner, the shift pattern seamlessly bridges the information gap, ensuring a smooth flow of understanding across all tokens.
Above is a simple visualization of how the shifting is performed (we take only a small section from head 7).
Once the self-attention computations are complete, the outputs are shifted back to their original positions. Now, each token has a representation that attends not only to its immediate neighbors but also to tokens in adjacent groups.
Note: The reason for shifting back is to realign the token representations with their original positions in the sequence. This realignment is crucial as it helps maintain the coherence and structure of the output, either as it progresses through subsequent layers or is utilized in loss computation during the training process.
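To see the shift pattern and the wrap-around group in miniature, here is a small toy sketch of my own (not the paper’s code) with 8 tokens, 2 groups, and a shift of half a group:

import torch

# Toy sketch: token ids stand in for the real hidden states so the movement is easy to see.
tokens = torch.arange(1, 9)            # [1, 2, 3, 4, 5, 6, 7, 8]
group_size = 4
shift = group_size // 2                # shift by half a group, as in the shifted heads

# "Second half of the heads": roll left by half a group before grouping.
shifted = tokens.roll(-shift)          # [3, 4, 5, 6, 7, 8, 1, 2]
groups = shifted.reshape(-1, group_size)
print(groups)
# tensor([[3, 4, 5, 6],
#         [7, 8, 1, 2]])  <- the last group mixes the tail and the head of the sequence

# After attention, roll right by the same amount to restore the original order.
restored = groups.reshape(-1).roll(shift)
print(torch.equal(restored, tokens))   # True

The same logic, with a group size of 2,048 and a shift of 1,024, is what the implementation section below walks through at full scale.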
LoRA
Many LLMs are pre-trained with a fixed context length. If we try to increase the context length directly without fine-tuning, the performance tends to deteriorate. The table below, extracted from the paper, illustrates this point well: in its first row, you can see the results of an experiment where the context length was expanded without any additional training. This experiment, conducted on the PG19 dataset and evaluated with perplexity, shows a clear decline in performance as the context length grows.

However, fine-tuning LLMs is no small feat. It demands a considerable amount of compute and memory. To get around this, most existing works employ the LoRA approach during the fine-tuning phase. Not only does this method keep memory usage from skyrocketing, it has also become a cornerstone of LLM fine-tuning. For those keen on understanding how LoRA works, feel free to dive into my earlier blog post, where I provide a detailed explanation and walk through its implementation.
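For a quick intuition, here is a minimal, self-contained sketch of the LoRA idea in PyTorch. This is my own simplification (the dimensions, rank, and scaling are illustrative, not the paper’s settings): the pre-trained weight is frozen, and only a pair of low-rank matrices is trained.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA sketch: y = W x + (alpha / r) * B(A(x)), with W frozen."""
    def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad_(False)                 # frozen pre-trained weight
        self.lora_A = nn.Linear(in_features, r, bias=False)    # low-rank down-projection
        self.lora_B = nn.Linear(r, out_features, bias=False)   # low-rank up-projection
        nn.init.zeros_(self.lora_B.weight)                     # adapter starts as a no-op
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))

layer = LoRALinear(4096, 4096, r=8)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable: {trainable:,} / total: {total:,}")   # only the low-rank matrices train

Because only the small A and B matrices receive gradients, the optimizer state and gradient memory shrink dramatically compared to updating the full weight matrix.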
Additional “Tricks”
While the researchers have introduced a new attention pattern and integrated it with LoRA for the fine-tuning process, there still appears to be a performance gap. Models trained using this combined approach don’t quite match the performance of models that have been fully fine-tuned, even when larger ranks are utilized for LoRA.
Note: When we mention “full fine-tuning” here, we’re talking about models that have been fine-tuned using $S^2-Attn$ but without incorporating LoRA into the process. While increasing the LoRA rank does improve matters somewhat, the perplexity values are still noticeably higher than those observed with full fine-tuning.
Interestingly, the researchers discovered a solution to narrow this performance gap: making the normalization and embedding layers trainable. By allowing these layers to learn and adjust during training, the models’ performance edged closer to that achieved through full fine-tuning. Even better, this tweak doesn’t demand much more memory: despite their critical role, the normalization and embedding layers make up only a small fraction of the model’s total parameters, so making them trainable doesn’t significantly increase memory requirements.
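As a rough sketch of what this looks like in code (my own illustration, assuming a Hugging Face style Llama checkpoint where the relevant parameters contain "embed" and "norm" in their names; the substrings may need adjusting for other model families):

def unfreeze_embed_and_norm(model):
    # Sketch: mark embedding and normalization parameters as trainable,
    # alongside whatever LoRA adapters are already being trained.
    extra, total = 0, 0
    for name, param in model.named_parameters():
        total += param.numel()
        if "embed" in name or "norm" in name:
            param.requires_grad_(True)
            extra += param.numel()
    print(f"extra trainable params: {extra:,} of {total:,} ({100 * extra / total:.2f}%)")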
Note: Importantly, this approach is not only effective but also versatile. It can be integrated with other efficiency techniques, such as Flash-Attention 2.
Implementation
Now, let’s implement $S^2-Attn$ for a better understanding. The implementation is heavily based on the official repo, so feel free to check it out. Let’s randomly initialize a qkv tensor with integer values.
import torch

# Define the size of the tensor:
# (batch_size, context_length, 3 for q/k/v, num_heads, head_dim)
size = (1, 8192, 3, 8, 64)
qkv = torch.randint(low=0, high=10, size=size, dtype=torch.int32)
# keep a copy to compare the tensor before and after shifting
qkv_ori = qkv.clone()
Let’s define the number of groups (i.e., how many groups we want to split the tokens into). In this case, we follow the original implementation and split the tokens into four groups of 2,048 tokens each.
group_size_ratio = 1/4
bsz, q_len, _, num_heads, head_dim = qkv.shape
group_size = int(q_len * group_size_ratio)
if q_len % group_size > 0:
    raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size))
num_group = q_len // group_size
This is the core of $S^2-Attn$. They perform a left shift along the sequence length (context length) dimension on the second half of the attention heads.
# Perform shifting on the second half of the attention heads
qkv[:, :, :, num_heads//2:] = qkv[:, :, :, num_heads//2:].roll(-group_size//2, dims=1)
qkv_temp = qkv
# After a left shift of 1,024, position 7,168 of the shifted tensor should hold
# what was originally at position 0
is_equal = torch.equal(qkv_ori[:, :, :, num_heads//2:][0, 0], qkv_temp[:, :, :, num_heads//2:][0, 7168])
print(f"Original tensor at position 0 and shifted tensor at position 7168 are {is_equal}")
# Fold the groups into the batch dimension so attention runs within each group
qkv = qkv.reshape(bsz*num_group, group_size, 3, num_heads, head_dim)
print(qkv.shape)
# continue to use qkv and perform self-attention
...
and we received the following output:
Original tensor at position 0 and shifted tensor at position 7168 are True
torch.Size([4, 2048, 3, 8, 64])
This aligns with our earlier explanation that, with a context length of 8,192, the last shifted group [7,169…8,192…1,024] contains the wrap-around tokens (the first and the last 1,024 tokens of the sequence end up in the same group). Lastly, we shift the output of self-attention back to restore the original input sequence order.
attn_output[:, :, num_heads//2:] = attn_output[:, :, num_heads//2:].roll(group_size//2, dims=1)
Note: they perform a left shift in the shifted pattern, and from their GitHub it seems they haven’t tested using a right shift.
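To tie the pieces together, below is a toy end-to-end sketch of the whole shift, group, attend, and shift-back pass. This is my own simplified stand-in for the elided self-attention step (it uses PyTorch 2.x’s scaled_dot_product_attention, omits the causal mask and projection layers, and is not the official implementation):

import torch
import torch.nn.functional as F

def toy_s2_attention(qkv: torch.Tensor, group_size_ratio: float = 1/4) -> torch.Tensor:
    """Toy pass: shift -> group -> attention -> ungroup -> shift back."""
    bsz, q_len, _, num_heads, head_dim = qkv.shape
    group_size = int(q_len * group_size_ratio)
    num_group = q_len // group_size

    # Shift the second half of the heads left by half a group.
    qkv = qkv.clone()
    qkv[:, :, :, num_heads // 2:] = qkv[:, :, :, num_heads // 2:].roll(-group_size // 2, dims=1)

    # Fold groups into the batch dimension and split q, k, v.
    qkv = qkv.reshape(bsz * num_group, group_size, 3, num_heads, head_dim)
    q, k, v = qkv.unbind(dim=2)         # each: (bsz*num_group, group_size, heads, dim)
    q = q.transpose(1, 2)               # -> (bsz*num_group, heads, group_size, dim)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    # Self-attention within each group only (causal masking omitted for brevity).
    out = F.scaled_dot_product_attention(q, k, v)

    # Ungroup back to (bsz, q_len, heads, dim) and undo the shift on the second half.
    out = out.transpose(1, 2).reshape(bsz, q_len, num_heads, head_dim)
    out[:, :, num_heads // 2:] = out[:, :, num_heads // 2:].roll(group_size // 2, dims=1)
    return out

attn_output = toy_s2_attention(torch.randn(1, 8192, 3, 8, 64))
print(attn_output.shape)   # torch.Size([1, 8192, 8, 64])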
Conclusion
Having a longer context length can be a game-changer, especially for applications built on retrieval-augmented generation (RAG) with LLMs. When you can feed more information into an LLM while still getting precise answers out, that’s a big win! With an extended context length, we might eventually reach a point where retrieval performance doesn’t matter as much, because the LLM can compensate: sift through (maybe) the top 100 retrieved contexts and still generate accurate responses smoothly.
In reflecting on the broader landscape, it’s evident that the open-source community is contributing significantly to the rapid development of LLMs. While powerhouse models like GPT-4 or Google’s Gemini (not sure when they are going to release it, though) are leading the pack, open-source LLMs aren’t far behind, thanks to their fast-paced improvement and the ability to try out different ideas on them. The unique advantage of open-source models is their customizability, allowing for quick adaptation and experimentation.
Moreover, when it comes to practical applications in various industries, the need isn’t necessarily for a jack-of-all-trades LLM. Instead, there’s often a preference for models that excel at specific tasks. This isn’t to downplay the importance of developing versatile LLMs – indeed, crafting LLMs that are as well-rounded as possible is crucial for the eventual realization of AGI (still hoping someday I will have AI that can make my meal).
I hope this article has shed some light on LongLoRA and how it successfully extends the context length of Llama2-7B to an impressive 100k. I look forward to engaging with you more on this fascinating line of research in future articles. Until next time!