Changelog:
- 6/24/25: Wrote entire first draft of the lecture.
- 6/30/25: Fix: moved this changelog to inside content div.
Lecture 12: Adapting Pre-Trained Models to New Tasks
In the last lecture, we saw how to use sampling techniques to generate programs from a language model, subject to some constraints. In general, such techniques are very useful when the program we are looking for is not too unlikely under the model. However, if the program is very unlikely, or even has zero probability, then we need to go beyond sampling and instead adapt the model itself to the task at hand. In this lecture, we will look at a few different ways to adapt a pre-trained model to a specific task, by changing either its prompt or its weights.
Finetuning
The most straightforward way to adapt a pre-trained model to a specific task is to finetune it on a dataset that is relevant to the task. Finetuning simply means that we train the model on a new dataset, but starting from the pre-trained weights instead of from scratch. This often lets us get away with training on relatively little data, since the base model already has many general capabilities. Unlike the sampling techniques we saw in the last lecture, finetuning fundamentally changes the model itself. This means that even if the desired program had zero probability under the original model $p_\theta$, it may have high probability under the finetuned model $p_{\theta'}$.
To see why this is useful, suppose you are a hardware researcher who has come up with a new hardware description language, like Bluespec (TODO cite). You want to use a language model to help computer architects design new hardware in this language, but the model has never seen any examples of it. If you were to use the model as is, it would likely generate programs that are not valid in the new language; even if you used sampling techniques to guide the model towards valid programs, it would still be unlikely to make full use of the new language's features. If you instead finetuned the model on a dataset of programs written in the new language, you would fundamentally increase its ability to design hardware in this language, making it much more useful.
While finetuning is a powerful technique, it does have some drawbacks. First, it requires access to a dataset that is relevant to the task at hand, which may be difficult to construct in cases such as the one above. Second, as a form of continual learning, finetuning can lead to catastrophic forgetting, where the model forgets how to perform tasks it was previously able to do.
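The mechanics of finetuning, and of catastrophic forgetting, can be seen in a deliberately tiny toy (a single weight $w$ in the model $y = wx$ fit by gradient descent, not a language model; the data, learning rate, and step count are made-up choices for illustration). We first "pre-train" on one task, then finetune on a conflicting task starting from the pre-trained weight rather than from scratch, and measure the loss on both tasks before and after.

```python
# Toy illustration (not a language model): a one-parameter linear model y = w * x
# trained with plain gradient descent on mean squared error.


def mse(w, data):
    """Mean squared error of the model y = w * x on a list of (x, y) pairs."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)


def train(w, data, steps=200, lr=0.05):
    """Run gradient descent on the MSE, starting from the given weight w."""
    for _ in range(steps):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w


xs = [0.5, 1.0, 1.5, 2.0]
task_a = [(x, 2.0 * x) for x in xs]   # "pre-training" task: y = 2x
task_b = [(x, -1.0 * x) for x in xs]  # new task: y = -x

w_pre = train(0.0, task_a)     # pre-training, starting from scratch (w = 0)
w_fine = train(w_pre, task_b)  # finetuning, starting from the pre-trained weight

print(f"pre-trained w = {w_pre:.3f}, finetuned w = {w_fine:.3f}")
print(f"task A loss: {mse(w_pre, task_a):.4f} -> {mse(w_fine, task_a):.4f}")
print(f"task B loss: {mse(w_pre, task_b):.4f} -> {mse(w_fine, task_b):.4f}")
```

Because the two toy tasks directly conflict, finetuning moves the weight from about $2$ to about $-1$, and the loss on the original task rises from roughly $0$ to about $16.9$: an extreme instance of catastrophic forgetting.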
In the worst case, catastrophic forgetting can mean that the model loses some of its general capabilities that are actually useful in the domain, such as understanding natural-language feedback from the user. Finally, finetuning can be computationally expensive, as it requires updating every parameter of the model. While the first two of these issues can be mitigated to some extent by careful dataset construction and training, the last one is a fundamental limitation of finetuning. As a result, several parameter-efficient variations have been developed, the most common of which is low-rank adaptation (LoRA).
Low-Rank Adaptation (LoRA)
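To make the construction described below concrete, here is a small pure-Python sketch (no ML framework; the sizes and random values are arbitrary choices for illustration) of one LoRA-adapted projection. It uses the notation of this section: a frozen $d \times d$ weight $W$ and two trainable $d \times r$ factors $A$ and $B$, with the adapted projection computed as $xW + (xA)B^T$ so that the $d \times d$ update $AB^T$ never has to be formed. It also counts the trainable parameters for the query, key, and value projections.

```python
import random


def matmul(X, Y):
    """Multiply an (n x m) matrix X by an (m x p) matrix Y (lists of rows)."""
    return [
        [sum(X[i][t] * Y[t][j] for t in range(len(Y))) for j in range(len(Y[0]))]
        for i in range(len(X))
    ]


def transpose(X):
    return [list(col) for col in zip(*X)]


def add(X, Y):
    return [[a + b for a, b in zip(row_x, row_y)] for row_x, row_y in zip(X, Y)]


def lora_projection(x, W, A, B):
    """Compute x (W + A B^T) as x W + (x A) B^T.

    W (d x d) is frozen; only A and B (both d x r) would be trained.
    Multiplying by A first means the d x d update A B^T is never materialized.
    """
    frozen_part = matmul(x, W)                          # (1 x d)
    low_rank_part = matmul(matmul(x, A), transpose(B))  # (1 x r) -> (1 x d)
    return add(frozen_part, low_rank_part)


def qkv_param_counts(d, r):
    """Trainable parameters for W_q, W_k, W_v: (full finetuning, LoRA)."""
    return 3 * d * d, 6 * d * r


def random_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d, r = 6, 2  # tiny, arbitrary sizes for illustration
W, A, B = random_matrix(d, d), random_matrix(d, r), random_matrix(d, r)
x = random_matrix(1, d)  # a single input row vector

# Merged form x (W + A B^T) and factored form x W + (x A) B^T agree up to rounding.
merged = matmul(x, add(W, matmul(A, transpose(B))))
factored = lora_projection(x, W, A, B)
max_diff = max(abs(a - b) for a, b in zip(merged[0], factored[0]))
print(f"max difference: {max_diff:.2e}")

print(qkv_param_counts(4096, 8))  # (50331648, 196608)
```

For $d = 4096$ and $r = 8$, LoRA trains $196{,}608$ parameters instead of $50{,}331{,}648$, a ratio of $2r/d = 1/256$. In the original LoRA formulation, one of the two factors is initialized to zero, so that training starts exactly from the pre-trained model.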
In LoRA, instead of finetuning the entire model, we keep the model's weights fixed and learn a small set of adapters. Each adapter consists of two tall, thin matrices whose product is a low-rank matrix that is added to one of the model's weight matrices (equivalently, an extra low-rank term is added to that layer's output). Consider, for example, a (single-head) self-attention layer with weights $W_q, W_k, W_v$ for the query, key, and value projections, respectively. Each of these weights is a $d \times d$ matrix, where $d$ is the layer's hidden dimension; thus, if we were to finetune the model, we would need to update $3d^2$ parameters. In LoRA, we instead learn six matrices $A_q, B_q, A_k, B_k, A_v, B_v \in \mathbb{R}^{d \times r}$, where the rank $r$ is typically much smaller than $d$. We then replace the original query, key, and value projections with the following: \[ \begin{aligned} Q &= x(W_q + A_q B_q^T) = x W_q + x A_q B_q^T \\ K &= x(W_k + A_k B_k^T) = x W_k + x A_k B_k^T \\ V &= x(W_v + A_v B_v^T) = x W_v + x A_v B_v^T \end{aligned} \] where $x$ is the input to the self-attention layer. This means that we only need to update $6dr$ parameters, which is smaller than $3d^2$ as long as $r < d/2$. In practice, LoRA is often applied to the fully-connected layers of the model as well, which typically have hidden dimensions several times larger than the self-attention layers; this is where the largest savings in parameters come from.
LoRA is a very powerful technique, and has come to dominate the field of parameter-efficient finetuning. Since far fewer parameters are updated than in standard finetuning, LoRA is typically faster, requires less data, and is less prone to catastrophic forgetting. Intuitively, however, a low-rank update can only change the model in limited ways, so LoRA can only adapt the model to tasks that are similar to the tasks it was pre-trained on.
Retrieval-Augmented Generation
Another way to adapt a pre-trained model to a new task is retrieval-augmented generation (RAG). In RAG, the model's weights are not adapted to the new task; instead, the adaptation is done by pairing the model with a retrieval model (or, more generally, a vector-similarity database). Formally, RAG assumes that we have a pre-trained language model $p_\theta$ and a retrieval model $r_\phi$. We begin by processing our dataset, which consists of a set of documents $D = \{d_1, d_2, \ldots, d_n\}$, into a set of embeddings $E = \{e_i \triangleq r_\phi(d_i) \mid i = 1, \ldots, n\}$. These can then be stored in a key-value store, where the keys are the embeddings $e_i$ and the values are the documents $d_i$. When we want to generate a response to a query $q$, we embed the query using the retrieval model, $e_q \triangleq r_\phi(q)$, and retrieve the top $k$ documents that are most similar to the query embedding. The notion of similarity can be defined in many ways, but the most common choice is cosine similarity: \[ \text{similarity}(e_q, e_i) = \frac{\langle e_q, e_i\rangle}{\|e_q\| \|e_i\|}. \] The top $k$ documents are then concatenated, together with the query, into a prompt which is passed to the pre-trained language model in order to generate a response.
For a more concrete example of how one would construct such a system, suppose again that we are trying to teach a language model to write programs in a new hardware description language, as in the finetuning example above. We may find that simply providing the model with a few examples of programs written in this language is not enough to get it to generate valid programs.
And even if we have a large dataset of programs written in this language, it may not be feasible for us to finetune the model: for example, we may not have access to the model's weights, we may not have enough computational resources to train the model, or the engineering effort involved in balancing learning the new dataset against retaining the model's general capabilities may simply be too high. In this case, RAG offers a very powerful alternative. Instead of finetuning the model, we can simply pair it with an off-the-shelf retrieval model such as BM25 or a BERT-based retriever, and then follow the steps outlined above to retrieve the relevant examples from our dataset of programs. In the ideal case, the model will then be equipped with a number of examples of programs that are similar to the one we want to generate, hopefully allowing it to find the desired program. RAG is therefore very similar to in-context learning, but with two key differences:
- Historically, the context windows of transformers have been limited to a few thousand tokens, making it difficult to put all relevant information in the prompt. This has been mitigated to some extent by the development of long-context transformers, but these remain limited to perhaps 64k or 128k tokens. In RAG, the amount of relevant information that can be drawn on is much larger, since the database itself is not limited by the model's context window (only the $k$ retrieved documents need to fit in the prompt); instead, the main bottleneck is the speed at which the retrieval system can return the relevant documents.
- With RAG, the retrieval model also serves as a filter. This can reduce the amount of irrelevant information that is passed to the model, which can have a significant impact on performance.
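The retrieval pipeline described at the start of this section (embed every document with $r_\phi$, score each against the query embedding by cosine similarity, keep the top $k$, and concatenate them with the query into a prompt) can be sketched as follows. This is a minimal sketch under stated assumptions: a toy bag-of-words `embed` function stands in for a learned retrieval model $r_\phi$, the "key-value store" is just a Python list, the prompt format in `build_prompt` is hypothetical, and the three documents are made-up English descriptions rather than real hardware code.

```python
import math
from collections import Counter


def embed(text, vocab):
    """Toy stand-in for r_phi (hypothetical): bag-of-words counts over `vocab`.
    Words outside the vocabulary are ignored."""
    counts = Counter(text.lower().split())
    return [float(counts[word]) for word in vocab]


def cosine_similarity(u, v):
    """<u, v> / (||u|| ||v||); returns 0.0 if either vector is all zeros (our convention)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def build_index(documents, vocab):
    """The 'key-value store': a list of (embedding e_i, document d_i) pairs."""
    return [(embed(doc, vocab), doc) for doc in documents]


def retrieve(query, index, vocab, k):
    """Return the k documents whose embeddings are most similar to the query's."""
    e_q = embed(query, vocab)
    ranked = sorted(index, key=lambda pair: cosine_similarity(e_q, pair[0]), reverse=True)
    return [doc for _, doc in ranked[:k]]


def build_prompt(query, retrieved_docs):
    """Concatenate the retrieved documents, then the query (assumed format)."""
    return "\n\n".join(retrieved_docs + [query])


docs = [
    "counter module increments a register every cycle",
    "adder module adds two input values",
    "fifo module buffers values between rules",
]
vocab = sorted({word for doc in docs for word in doc.lower().split()})
index = build_index(docs, vocab)

query = "a module that increments a counter register"
top = retrieve(query, index, vocab, k=1)
print(top)  # ['counter module increments a register every cycle']
print(build_prompt(query, top))
```

Sorting every stored embedding is fine for a handful of documents; a real vector database would instead use an (approximate) nearest-neighbor index.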
Test-Time Training
The final technique we will look at is test-time training (TTT). Test-time training adapts the model to the task at hand by training it on the fly, during inference. While it was first proposed in the context of computer vision, it has recently shown promise in language modeling as well, particularly for abstract reasoning. It seeks to replicate the power of finetuning without the need for large datasets, instead adapting the model to the specific task at hand (rather than the entire domain) using only a few examples. Unlike finetuning, it also combines very naturally with RAG.
Suppose again that we are trying to teach the model to write programs in our new hardware description language. We may find that we simply do not have enough examples to finetune the model, but also that RAG alone is not enough to get the model to generate valid programs. In this case, we could explore using TTT as follows:
- Given a query $q$, we first retrieve a set $D_q$ consisting of $k$ relevant examples from our retrieval database. (Alternatively, these examples could be provided by the user.)
- Next, we construct a test-time training dataset $D_{TTT}$ by considering all $k!$ permutations of the examples in $D_q$, and encoding them all as prompts. (That is, each prompt is the concatenation of the same $k$ examples; the dataset only differs in the order in which they appear.)
- Finally, we use LoRA to fit the model to $D_{TTT}$, obtaining adapted parameters $\theta'$, and then sample $\rho \sim p_{\theta'}(\cdot ; q)$ from the adapted model.
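The three steps above can be sketched as follows. Only `build_ttt_dataset` is concrete; `retrieve`, `lora_finetune`, and `sample` are hypothetical placeholders passed in as functions, standing in for a retrieval system, a LoRA training routine that returns the adapted model $p_{\theta'}$, and a sampler. Joining examples with a blank line is likewise an assumed prompt format.

```python
from itertools import permutations


def build_ttt_dataset(examples, separator="\n\n"):
    """Build D_TTT: one prompt per ordering of the k examples in D_q (k! prompts).

    Every prompt contains the same examples; only their order differs.
    """
    return [separator.join(order) for order in permutations(examples)]


def ttt_generate(query, retrieve, lora_finetune, sample, k=3):
    """Hypothetical TTT loop; the three callables are placeholders (see lead-in)."""
    d_q = retrieve(query, k)               # step 1: D_q, k relevant examples
    d_ttt = build_ttt_dataset(d_q)         # step 2: all k! orderings as prompts
    adapted_model = lora_finetune(d_ttt)   # step 3a: fit LoRA adapters -> theta'
    return sample(adapted_model, query)    # step 3b: rho ~ p_{theta'}(. ; q)


# Usage with trivial stand-ins, just to exercise the control flow.
examples = ["example A", "example B", "example C"]
d_ttt = build_ttt_dataset(examples)
print(len(d_ttt))  # 6

result = ttt_generate(
    "write a counter module",
    retrieve=lambda q, k: examples[:k],
    lora_finetune=lambda data: f"model adapted on {len(data)} prompts",
    sample=lambda model, q: f"<program from {model}>",
)
print(result)  # <program from model adapted on 6 prompts>
```

Note that the dataset grows factorially in $k$: $k = 5$ gives $120$ prompts, and $k = 8$ already gives $40{,}320$.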