Introduction to Program Synthesis

© Theo X. Olausson. 2025. All rights reserved.

Changelog:

Lecture 12: Adapting Pre-Trained Models to New Tasks

In the last lecture, we saw how to use sampling techniques to generate programs from a language model, subject to some constraints. In general, such techniques can be very useful when the program we are looking for is not too unlikely to be generated by the model. However, if the program is very unlikely, or even has zero probability, then we will need to go beyond sampling techniques and instead adapt the model itself to the task at hand. In this lecture, we will look at a few different ways to adapt a pre-trained model to a specific task, using prompts and weights.

Finetuning

The most straightforward way to adapt a pre-trained model to a specific task is to finetune it on a dataset that is relevant to the task. Finetuning really just means that we train the model on a new dataset, but starting from the pre-trained weights instead of from scratch. Oftentimes, this lets us get away with training on relatively little data, since the base model already has a lot of general capabilities.
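To make the idea concrete, here is a deliberately tiny sketch, not drawn from any real training stack, that treats finetuning as ordinary gradient descent which starts from a pretrained weight rather than a random one. The one-parameter linear model and the data are invented purely for illustration:

```python
import numpy as np

def finetune(w_pretrained, xs, ys, lr=0.1, steps=100):
    """Finetune a one-parameter linear model y = w * x on a small dataset,
    starting from the pretrained weight instead of from scratch."""
    w = float(w_pretrained)  # initialize from the pretrained weight
    for _ in range(steps):
        # Gradient of the mean squared error 0.5 * mean((w*x - y)^2)
        grad = np.mean((w * xs - ys) * xs)
        w -= lr * grad
    return w

# Pretend "pretraining" left us with w = 1.0; the new task wants y = 3x.
xs = np.array([1.0, 2.0, 3.0])
ys = 3.0 * xs
w_new = finetune(1.0, xs, ys)
```

Because the model starts near a sensible solution rather than at random, relatively few steps (and data points) suffice, which mirrors the data-efficiency argument above.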

Unlike the sampling techniques we saw in the last lecture, finetuning fundamentally changes the model itself. This means that even if the desired program had zero probability under the original model $p_\theta$, it may have high likelihood under the finetuned model $p_{\theta'}$.

To see why this would be useful, suppose you are a hardware researcher who has come up with a new hardware description language, like BlueSpec (TODO cite). You want to use a language model to help computer architects design hardware in this language, but the model has never seen any examples of it before. If you were to use the model as-is, it would likely generate programs that are not valid in the new language; even if you used sampling techniques to guide the model towards valid programs, it would still be unlikely to make full use of the new language's features. If you instead finetuned the model on a dataset of programs written in the new language, you would fundamentally increase its ability to design hardware in that language, making it much more useful.

While finetuning is a powerful technique, it does have some drawbacks. First, it requires access to a dataset that is relevant to the task at hand, which may be difficult to construct in cases such as the above. Second, as a form of continual learning, finetuning can lead to catastrophic forgetting, where the model forgets how to perform tasks it was previously able to do. In the worst case, this means the model loses general capabilities that are actually useful in the domain, such as understanding natural language feedback from the user. Finally, finetuning can be computationally expensive, as it requires training the entire model. While the first two issues can be mitigated to some extent by careful dataset construction and training, the last is a fundamental limitation of finetuning. As a result, several parameter-efficient variations have been developed, the most common of which is low-rank adaptation (LoRA).

Low-Rank Adaptation (LoRA)

In LoRA, instead of finetuning the entire model, we keep the model's weights fixed and learn a small set of adapters. Each adapter consists of two low-rank (tall-and-skinny) matrices that are multiplied together and added to one of the model's frozen weight matrices.

Consider for example a (single-head) self-attention layer, with weights $W_q, W_k, W_v$ for the query, key, and value projections respectively. Each of these weights is a matrix of size $d \times d$, where $d$ is the layer's hidden dimension; thus, if we were to finetune the model, we would need to update $3d^2$ parameters. In LoRA, we instead learn six low-rank matrices $A_q, B_q, A_k, B_k, A_v, B_v \in \mathbb{R}^{d \times r}$, where $r$ is a small rank (typically much smaller than $d$). We then replace the original query, key, and value projections with the following: \begin{align*} Q &= x(W_q + A_q B_q^T) = x W_q + x A_q B_q^T \\ K &= x(W_k + A_k B_k^T) = x W_k + x A_k B_k^T \\ V &= x(W_v + A_v B_v^T) = x W_v + x A_v B_v^T \end{align*} where $x$ is the input to the self-attention layer. This means that we only need to update $6dr$ parameters, which is smaller than $3d^2$ as long as $r < d/2$.
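The adapted projection can be sketched in a few lines of numpy. The dimensions below are hypothetical, and a single matrix $W$ stands in for one of the projections such as $W_q$; following common practice (as in the original LoRA paper), $B$ is initialized to zero so that the adapted layer initially computes exactly the same function as the frozen one:

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 64, 4  # hidden dimension and (much smaller) LoRA rank

# Frozen pretrained projection weight, standing in for W_q.
W = rng.standard_normal((d, d))

# Trainable low-rank adapters: A is small random, B starts at zero,
# so W + A B^T == W at initialization.
A = rng.standard_normal((d, r)) * 0.01
B = np.zeros((d, r))

def lora_forward(x):
    # x (W + A B^T) = x W + x A B^T, computed without ever
    # materializing the d x d product A B^T.
    return x @ W + (x @ A) @ B.T

x = rng.standard_normal((1, d))
assert np.allclose(lora_forward(x), x @ W)  # a no-op at initialization

# Trainable parameters: 2*d*r for the adapter vs d*d for full finetuning.
print(2 * d * r, d * d)  # 512 vs 4096
```

Note that the low-rank update is applied as two skinny matrix products rather than by forming $A B^T$ explicitly; this keeps both the parameter count and the extra compute proportional to $r$.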

In practice, LoRA is often applied to the fully-connected layers of the model as well, which typically have much larger hidden dimensions than the self-attention layers; this is where the largest savings in parameters come from.

LoRA is a very powerful technique, and has come to dominate the field of parameter-efficient finetuning. Since the number of parameters being updated is much smaller, LoRA is faster, requires less data, and is less prone to catastrophic forgetting than standard finetuning. However, intuitively, LoRA can only adapt the model to tasks that are similar to those it was pre-trained on.

Retrieval-Augmented Generation

Another way to adapt a pre-trained model to a new task is retrieval-augmented generation (RAG). In RAG, the model weights are not adapted to the new task; instead, the adaptation is done by pairing the model with a retrieval model (or, more generally, a vector-similarity database).

Formally, RAG assumes that we have a pre-trained language model $p_\theta$ and a retrieval model $r_\phi$. We then begin by processing our dataset, which consists of a set of documents $D = \{d_1, d_2, \ldots, d_n\}$, into a set of embeddings $E = \{e_i \triangleq r_\phi(d_i) \mid i = 1, \ldots, n\}$. These can then be stored in a key-value store, where the keys are the embeddings $e_i$ and the values are the documents $d_i$. When we want to generate a response to a query $q$, we embed the query using the retrieval model, $e_q \triangleq r_\phi(q)$, and then retrieve the top $k$ documents that are most similar to the query embedding. The notion of similarity can be defined in many ways, but the most common is to use cosine similarity: \[ \text{similarity}(e_q, e_i) = \frac{\langle e_q, e_i\rangle}{\|e_q\| \|e_i\|}. \] The top $k$ documents are then concatenated into a prompt which is passed to the pre-trained language model in order to generate a response.
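The embed-store-retrieve pipeline above can be sketched as follows. The bag-of-words `embed` function is only a stand-in for a real retrieval model $r_\phi$ (such as BM25 or a BERT-based encoder), and the tiny vocabulary and documents are invented for illustration:

```python
import numpy as np

VOCAB = ["adder", "mux", "fifo", "register", "counter", "shift"]

def embed(text):
    # Stand-in for a real retrieval model r_phi: a bag-of-words
    # count vector over a tiny fixed vocabulary.
    words = text.lower().split()
    return np.array([float(words.count(w)) for w in VOCAB])

def top_k(query, docs, k=2):
    e_q = embed(query)
    def cos(e):
        # Cosine similarity; guard against zero-norm embeddings.
        denom = np.linalg.norm(e_q) * np.linalg.norm(e) or 1.0
        return float(e_q @ e) / denom
    # Score every stored document against the query embedding
    # and keep the k most similar ones.
    scored = sorted(docs, key=lambda d: cos(embed(d)), reverse=True)
    return scored[:k]

docs = [
    "module implementing a fifo queue",
    "module implementing a mux",
    "module implementing an adder and a counter",
]
print(top_k("example of a fifo", docs, k=1))
```

A production system would of course precompute and index the document embeddings (the key-value store described above) rather than re-embedding every document per query; the linear scan here is just the simplest correct version.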

For a more concrete example of how one would construct such a system, suppose again that we are trying to teach a language model to write programs in a new hardware description language, as in the finetuning example above. We may find that simply providing the model with a few examples of programs written in this language is not enough to get it to generate valid programs. And even if we have a large dataset of programs written in this language, it may not be feasible for us to finetune the model; for example, we may not have access to the model's weights, or we may not have enough computational resources to train the model, or the engineering effort involved in balancing learning the new dataset versus retaining the model's general capabilities may simply be too high. In this case, RAG offers a very powerful alternative. Instead of finetuning the model, we can simply pair the model with an off-the-shelf retrieval model such as BM25 or a BERT-based retriever, and then follow the steps outlined above to retrieve the relevant examples from our dataset of programs. Thus, in the ideal case, the model will be equipped with a number of examples of programs that are similar to the one we want to generate, hopefully allowing it to find the desired program.

Thus, RAG is very similar to in-context learning, but with two key differences: first, the examples are selected automatically by the retrieval model rather than hand-picked by the user; and second, the pool of candidate examples can be far larger than what would fit in the model's context window, since only the top $k$ retrieved documents are placed in the prompt.

While it may appear to be somewhat brutish of a solution, RAG has proven to be very effective in practice. In addition to its simplicity, RAG benefits from being trivial to extend to more data in an online fashion, as we can simply add more documents to the retrieval database without needing to retrain the model.

Test-time training

The final technique we will look at is test-time training (TTT). Test-time training seeks to adapt the model to the task at hand by training it on the fly, during inference. While it was first proposed in the context of computer vision, it has recently shown promise in language modeling as well, particularly in the context of abstract reasoning. It seeks to replicate the power of finetuning without the need for large datasets, instead relying on being able to adapt the model to the specific task at hand (instead of the entire domain) using only a few examples. Unlike finetuning, it also plays very nicely with RAG.

Suppose again that we are trying to teach the model to write programs in our new hardware description language. We may find that we simply do not have enough examples to finetune the model, but that RAG is not enough to get the model to generate valid programs. In this case, we could explore using TTT as follows:

  1. Given a query $q$, we first retrieve a set $D_q$ consisting of $k$ relevant examples from our retrieval database. (Alternatively, these examples could be provided by the user.)
  2. Next, we construct a test-time training dataset $D_{TTT}$ by considering all $k!$ permutations of the examples in $D_q$, and encoding them all as prompts. (That is, each prompt is the concatenation of the same $k$ examples; the dataset only differs in the order in which they appear.)
  3. Finally, we use LoRA to fit adapted parameters $\theta'$ on $D_{TTT}$, and then sample $\rho \sim p_{\theta'}(\cdot ; q)$.

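Step 2 above, building $D_{TTT}$ from all $k!$ orderings of the retrieved examples, can be sketched as follows, assuming the examples are plain strings; a real system would then run the LoRA update of step 3 over these prompts:

```python
from itertools import permutations

def build_ttt_dataset(examples):
    """Build the test-time training set D_TTT: one prompt per permutation
    of the k retrieved examples, each prompt containing the same examples
    but in a different order."""
    return ["\n\n".join(order) for order in permutations(examples)]

retrieved = ["ex1: ...", "ex2: ...", "ex3: ..."]  # k = 3 examples from RAG
d_ttt = build_ttt_dataset(retrieved)
print(len(d_ttt))  # k! = 6 prompts
```

Since the prompts differ only in ordering, this is a cheap form of data augmentation: it multiplies the number of training sequences by $k!$ without requiring any new examples (though for large $k$ one would sample a subset of permutations rather than enumerate them all).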
It is not immediately obvious why this would work better than simply doing RAG. One reason is that TTT allows us to perform simple forms of data augmentation on the examples, such as permuting their order; in more specialized domains, we may even be able to come up with more sophisticated augmentations, such as replacing certain variables with synonyms. Another reason is that updating the model's parameters in response to the examples may enforce a stricter adherence to them than merely using them as context, especially if the model has to learn some new linguistic structure it has not seen before (e.g., the syntax of the new hardware description language).

So far, TTT has been shown to be particularly effective in the context of abstract reasoning tasks, such as the ARC dataset (Chollet, 2019). However, it is still an active area of research, and it remains to be seen how well it will generalize to other domains such as program synthesis. It is worth noting that TTT is not a replacement for finetuning or RAG, but rather a complementary technique that can be used in conjunction with them. For example, the model could be finetuned on an initial dataset of programs, and we could then use RAG combined with TTT to further augment it as we gather more data from our users.