Changelog:
- 6/24/25: Wrote entire first draft of the lecture.
- 6/30/25: Fix: moved this changelog to inside content div.
Lecture 11: Sampling from Large Language Models
In the last lecture, we saw how to construct (large) language models. Fundamentally, these models parametrize distributions $p_\theta(\overline{t})$ over sequences of tokens $\overline{t} = t_1, \ldots, t_n$, with the key property that, because of the way the models are trained, they learn not just the joint distribution over sequences of tokens, but also the conditional distribution over the next token given the previous ones. Thus, we can sample from the model by sampling the first token $t_1$ from $p_\theta(t_1)$, then sampling the second token $t_2$ from $p_\theta(t_2 | t_1)$, and so on, until we reach the end of the sequence; this strategy is typically called autoregression. It also means that it is very easy to prompt the models (i.e., condition on a prefix), which has turned out to be a surprisingly effective and flexible way to elicit all sorts of behaviors from them.

But what if we want to condition the model on something other than a prefix? Or what if the condition could be encoded as a prefix, but the learned distribution does not perfectly match the distribution we want to sample from? This is a common problem in program synthesis, where we often want to sample programs that satisfy some hard constraint, such as satisfying a logic formula or passing a set of unit tests. In this situation it is not reasonable to expect the learned distribution to perfectly match the distribution of programs that satisfy the constraint, and yet we need to be able to obtain such programs reliably. In this lecture, we will see how to combine language models with advanced sampling techniques to obtain programs that we know to satisfy our (hard) constraints.

Notation. We will use $p_\theta(\cdot)$ to denote the distribution over strings induced by a language model with parameters $\theta$.
For hard constraints $\mathcal{C}$, we will use the notation $p_\theta(\cdot | \mathcal{C})$ to denote the distribution over strings that satisfy the constraint $\mathcal{C}$, and we will use $1(s \models \mathcal{C})$ to denote the indicator function that returns 1 if the string $s$ satisfies the constraint $\mathcal{C}$ and 0 otherwise. With some abuse of notation, when we write $p_\theta(\cdot ; p)$ we will mean the distribution over continuations of the prompt $p$. We will also use the notation $++$ to denote concatenation of prompts.

Rejection sampling
The simplest way to obtain a sample which adheres to some hard constraint $\mathcal{C}$ is to use rejection sampling. Algorithmically, we can sample from the distribution $p_\theta(\rho | \mathcal{C})$ as follows:
- Sample $\rho \sim p_\theta(\cdot)$.
- If $1(\rho \models \mathcal{C}) = 1$, return $\rho$; else restart from step 1.
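As a minimal sketch, the loop above might look as follows. Here `sample_program` and `satisfies` are hypothetical stand-ins: in practice the former would be a language-model call and the latter a real constraint check (e.g., running unit tests).

```python
import random

random.seed(0)

def sample_program():
    """Toy stand-in for sampling rho ~ p_theta(.): a tiny arithmetic program."""
    op = random.choice(["+", "-", "*"])
    return f"lambda x: x {op} 1"

def satisfies(program):
    """Hard constraint C, here a single unit test: f(3) must equal 4."""
    f = eval(program)
    return f(3) == 4

def rejection_sample(max_tries=1000):
    """Sample from p_theta(. | C) by resampling until C holds."""
    for _ in range(max_tries):
        rho = sample_program()
        if satisfies(rho):       # 1(rho |= C) == 1
            return rho
    raise RuntimeError("no sample satisfied the constraint")

print(rejection_sample())        # prints: lambda x: x + 1
```

Note that the expected number of iterations is the inverse of the probability mass that $p_\theta$ places on constraint-satisfying programs, which is why rejection sampling alone breaks down when that mass is small.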
Filtering based on observational equivalence
An interesting early example of a system which went one step beyond rejection sampling is AlphaCode. AlphaCode sought to solve problems from programming competitions; at the time, generalist models were not yet very competent programmers, so AlphaCode began by taking a language model which had been pre-trained on a large set of programs from GitHub, and then finetuned it on a dataset of programming competition problems. Despite this specialization, performance was still not very good, so AlphaCode introduced a novel sampling strategy to improve the quality of the solutions. First, a very large number of candidate programs (in its most aggressive mode, 1 million samples per task) were sampled from the model. Then, AlphaCode performed rejection sampling against a set of public unit tests, which were provided as part of the problem description. This turned out not to be enough, however: many of the sampled programs that passed the public tests would still fail on the hidden test cases, which were not provided as part of the problem description. As a result, AlphaCode introduced a second filtering step: it took the set of programs that passed the public tests, clustered them based on observational equivalence on synthetic input data, and returned only one program from each cluster. Thus, the final output of AlphaCode was a small set of programs that passed all the public tests and were distinct from each other in terms of their behavior on the synthetic inputs, increasing the chances that at least one of them would pass the hidden tests.

Self-repair
As language models became more powerful, it became possible to use them not just to generate continuations of programs, but also to repair programs that were known to be incorrect. This raised the question of whether the quality of the sampled programs could be increased by using the model to iteratively improve them. A simple implementation of this idea is as follows:
- Sample a program $\rho \sim p_\theta(\cdot ; p_\mathcal{C})$.
- If $1(\rho \models \mathcal{C}) = 1$, return $\rho$; else, let $e = \mathcal{C}(\rho)$ be the error of $\rho$ with respect to the constraint $\mathcal{C}$.
- Sample a repaired program $\rho \sim p_\theta(\cdot ; p_\mathcal{C} ++ \rho ++ e)$.
- Repeat from step 2.
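A toy rendering of this loop is sketched below. Both `model` and `check` are hypothetical stand-ins: in practice `model` would be a language-model call conditioned on the growing prompt, and `check` a real test harness returning error feedback.

```python
def check(program):
    """Constraint check C(rho): return None if satisfied, else an error string."""
    try:
        f = eval(program)
        return None if f(3) == 4 else f"expected f(3) == 4, got {f(3)}"
    except Exception as exc:
        return f"crash: {exc}"

def model(prompt):
    """Toy stand-in for p_theta(. ; prompt): proposes a fix once the
    prompt contains an error message from a previous attempt."""
    return "lambda x: x + 1" if "error" in prompt else "lambda x: x * 2"

def self_repair(prompt_c, max_rounds=5):
    rho = model(prompt_c)                        # step 1: initial sample
    for _ in range(max_rounds):
        e = check(rho)                           # step 2: evaluate C
        if e is None:
            return rho
        # step 3: resample from p_theta(. ; p_C ++ rho ++ e)
        rho = model(prompt_c + "\n" + rho + "\nerror: " + e)
    return None

print(self_repair("write f with f(3) == 4"))     # prints: lambda x: x + 1
```

Unlike plain rejection sampling, each retry here conditions on the failed program and its error, so the model can exploit the feedback rather than starting from scratch.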
Constrained Decoding with Synchromesh
Sometimes, we may not have any hard constraints on the semantics of the program we want to sample, but we may still want to ensure that it satisfies some other type of hard constraint. In particular, if we know which programming language we would like the solution to be in, we may as a first step at least want to ensure that the sampled program is syntactically valid in that language. In this special case, we can use a technique called constrained decoding.

The high-level idea is simple. Let $\mathcal{L}$ be the set of all strings which are valid programs in the target programming language. Note that $\mathcal{L}$ may be infinite, but that is no concern; since we have a grammar for the language, it is easy to check whether a given string $s$ belongs to $\mathcal{L}$ by parsing it. Thus, we already have a procedure for checking whether a string is a valid program, and we could use it to filter out invalid programs with rejection sampling. But that is not very efficient, as we do not know whether a string is a valid program until we have sampled it in its entirety. Can we do better?

As first observed by Poesia et al. in the Synchromesh project [poesia2022synchromesh], the answer is yes. By using the prefix closure $\mathcal{L}^c = \{u \mid \exists v.\ uv \in \mathcal{L}\}$ of $\mathcal{L}$, i.e., the set of all strings that can be extended to a valid program, we can ensure at each step of the sampling process that the string sampled so far is a valid prefix of a program in $\mathcal{L}$. It is not immediately obvious, but answering whether a string $s$ is in $\mathcal{L}^c$ can be done fairly efficiently; Poesia et al. found that their implementation could guarantee syntactically valid programs with only an 8% overhead compared to sampling from the model without any constraints.

Steering with Sequential Monte Carlo
One limitation of the style of constrained decoding employed in Synchromesh is that it is quite inflexible. Poesia et al. do present an extension which allows some degree of semantic steering, but its reach is limited and its implementation is quite complex. Recently, a more flexible approach to constrained decoding, based on Sequential Monte Carlo (SMC), has been developed by a group of researchers here at MIT. The key idea is that the constraints can be encoded as part of the importance weights of the SMC particles. For example, to replicate the behavior of Synchromesh, we can simply set the importance weights to 0 for all particles that do not belong to the prefix closure $\mathcal{L}^c$. What is neat about this approach is that it gives a very flexible framework for constrained decoding, in which the importance weights can encode any kind of constraint we want; it gives a sort of algebra through which to compose constraints.

Another important benefit of this approach is that, unlike the local steering employed in Synchromesh, it actually samples from the distribution $p_\theta(\cdot | \mathcal{C})$. To see why, suppose we are interested in sampling programs that do not use any variable names longer than 10 characters. We could do this in the style of Synchromesh by checking, when sampling the token $t_k \sim p_\theta(t_k | t_1, \ldots, t_{k-1})$, whether appending $t_k$ to the current program $t_1 ++ t_2 ++ \ldots ++ t_{k-1}$ would create a variable name longer than 10 characters; we would then filter out all such tokens and sample only from the remaining ones. However, suppose the current program is `def check_balance(account_`. We would then need to filter out the tokens `num`, `name`, `number`, etc., all of which would lead to the argument having a name longer than 10 characters. Perhaps we would end up having to settle for a completion like `def check_balance(account_1)`, which is not very legible; formally speaking, while it does satisfy the constraint, it is not a representative sample from the distribution $p_\theta(\cdot | \mathcal{C})$, because it has low probability under the prior. Beam search, in which we keep track of the $k$ most likely continuations at each step, alleviates this problem somewhat, but such degenerate samples can still occur with local constraints. In contrast, the SMC approach guarantees that (in the limit of many particles) we obtain representative samples from $p_\theta(\cdot | \mathcal{C})$, because the model is essentially able to backtrack to the program `def check_balance(acc` and sample the more legible completion `_num)` instead.
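To make the mechanics concrete, here is a minimal sketch of constraint-as-importance-weight SMC over a toy character-level model. Everything here is a hypothetical stand-in (`p_next` for the language model, a three-symbol vocabulary, and "at most two consecutive `a`'s" as the prefix-checkable constraint $\mathcal{C}$), not the actual MIT implementation; the point is only the shape of the loop: extend each particle, weight it by the constraint indicator, and resample.

```python
import random

random.seed(0)

VOCAB = {"a": 0.7, "b": 0.2, ".": 0.1}   # "." ends a sequence

def p_next(prefix):
    """Toy stand-in for p_theta(t_k | t_1, ..., t_{k-1}); heavily favors 'a'."""
    return VOCAB

def ok_prefix(s):
    """Prefix-checkable hard constraint C: at most two consecutive 'a's."""
    return "aaa" not in s

def smc(n_particles=200, max_len=8):
    particles = [""] * n_particles
    for _ in range(max_len):
        proposals, weights = [], []
        for s in particles:
            if s.endswith("."):              # finished: carry forward unchanged
                proposals.append(s)
                weights.append(1.0)
                continue
            probs = p_next(s)
            t = random.choices(list(probs), weights=probs.values())[0]
            proposals.append(s + t)
            # Encode C in the importance weight: 0 for invalid prefixes.
            weights.append(1.0 if ok_prefix(s + t) else 0.0)
        # Resample in proportion to the weights; zero-weight particles die,
        # and their slots are refilled by copies of surviving particles.
        particles = random.choices(proposals, weights=weights, k=n_particles)
    return particles

samples = smc()
print(samples[:3])   # every surviving sample avoids "aaa" by construction
```

Because violating particles are resampled away rather than having individual tokens masked, probability mass shifts toward globally plausible sequences instead of locally patched ones, which is exactly the backtracking behavior described above.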