JAX, Static-Shape Programming and Polyhedron
Rectangles are fine. Weird shapes are fun.
JAX wants static shapes.
Your loops, alas, are sometimes not rectangles.
This post tours the pain → coping strategies → a tiny helper I wrote called HedraX, which lets you index arbitrary polyhedral domains in JAX without summoning five GPTs.
This is Part 1: we’ll build intuition with hand-rolled code and end with HedraX’s table indexer.
In Part 2, I’ll show a more “closed-form” approach HedraX can auto-generate for suitable domains.
Rectangles are Boring
In JAX, we often translate a Python loop such as
```python
for i in range(N):
    for j in range(N):
        a[i, j] = f(i, j)
```
into
```python
a = jax.vmap(
    jax.vmap(f, in_axes=(None, 0)),  # inner: map over j (columns)
    in_axes=(0, None),               # outer: map over i (rows), so a[i, j] = f(i, j)
)(jnp.arange(N), jnp.arange(N))
```
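To check the axis order, here is a quick test with a toy f(i, j) = 10 * i + j, a hypothetical stand-in for your per-point function. The outer vmap has to map over i so that the stacked result is indexed as a[i, j], matching the loop; a plain Python double loop serves as the reference.

```python
import jax
import jax.numpy as jnp

N = 4

def f(i, j):
    # Toy per-point function (hypothetical): encodes (i, j) as 10*i + j
    return 10 * i + j

# Outer vmap maps over i (axis 0 of the result), inner vmap over j (axis 1)
a = jax.vmap(
    jax.vmap(f, in_axes=(None, 0)),
    in_axes=(0, None),
)(jnp.arange(N), jnp.arange(N))

# Reference: the Python double loop
ref = [[f(i, j) for j in range(N)] for i in range(N)]
```

Swapping the two in_axes tuples silently transposes the result, which is easy to miss when f happens to be symmetric.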
This translation works fine for rectangular domains. But suppose we want the lower triangle:
```python
for i in range(N):
    for j in range(i + 1):  # j <= i: the diagonal is included
        a[i, j] = f(i, j)
```
At first glance this looks “dynamic” because the upper bound of j depends on i. Can we do this in JAX with static shapes?
The answer is yes.
The Heroic (but Fragile) Closed-Form for Triangles
Although the domain isn’t rectangular, it is statically sized: it has N * (N + 1) // 2 points.
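A quick plain-Python check confirms the count, diagonal included: row i contributes i + 1 points, and 1 + 2 + ... + N = N(N + 1)/2. The helper name count_lower_triangle is only for illustration.

```python
def count_lower_triangle(N):
    # Enumerate {(i, j) : 0 <= j <= i < N}, diagonal included
    return sum(1 for i in range(N) for j in range(i + 1))

# The count matches the closed form for every small N
for N in range(1, 20):
    assert count_lower_triangle(N) == N * (N + 1) // 2
```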
We can biject a linear index k to (i, j) and iterate over k.
We can picture the domain as a triangle and assign each point a linear index k in row-major order: row by row, and left to right within each row.
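The inversion used in the scan body is a short derivation. Row i holds i + 1 points, so it starts at the triangular number \(T_i = i(i+1)/2\), and the row containing k is the largest i with \(T_i \le k\):

\[
\begin{aligned}
T_i \le k < T_{i+1}
&\iff i^2 + i - 2k \le 0 < (i+1)^2 + (i+1) - 2k \\
&\implies i = \left\lfloor \frac{\sqrt{8k+1} - 1}{2} \right\rfloor,
\qquad j = k - T_i \in [0, i].
\end{aligned}
\]

For example, k = 4 gives \(\lfloor (\sqrt{33} - 1)/2 \rfloor = \lfloor 2.37 \rfloor = 2\) and \(j = 4 - T_2 = 1\), i.e. the point (2, 1).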
And here is the JAX code that implements this idea:
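```python
import jax.numpy as jnp
from jax import lax

# Lower triangle: j in [0, i] (including the diagonal)
# k ranges over 0..T_N - 1, where T_m = m(m+1)/2
# Assumes N, f(i, j) and an (N, N) initial array a0 are already defined.
def body(a, k):
    # Solve for row i from k using the quadratic formula
    # (float32 sqrt: fine for moderate K, can misround for very large k)
    i = jnp.floor((jnp.sqrt(8.0 * k + 1.0) - 1.0) / 2.0).astype(jnp.int32)
    Ti = (i * (i + 1)) // 2  # T_i: linear index of the first point in row i
    j = (k - Ti).astype(jnp.int32)  # j in [0, i]
    a = a.at[i, j].set(f(i, j))
    return a, None

K = N * (N + 1) // 2
a, _ = lax.scan(body, a0, jnp.arange(K))
```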
This works and is reasonably fast, but the math is bespoke. You also won’t want to re-derive a closed-form quadratic formula for every odd-shaped loop you meet.
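If you do take the closed-form route, one fragile spot is the float32 square root: for large k, rounding can push the estimated row off by one near a row boundary. A common remedy, sketched below, keeps the float estimate and then corrects it with exact integer comparisons. The name tri_unravel is hypothetical, and the sketch assumes K <= 10**9 so the int32 products cannot overflow.

```python
import jax.numpy as jnp

def tri_unravel(k):
    """Hypothetical helper: linear index k -> (i, j) with 0 <= j <= i.

    Assumes K <= 10**9 so the int32 products below cannot overflow.
    """
    k = jnp.asarray(k, dtype=jnp.int32)
    # Float estimate of the row, as in the quadratic-formula scan body
    i = jnp.floor((jnp.sqrt(8.0 * k + 1.0) - 1.0) / 2.0).astype(jnp.int32)
    # Estimate one too small: row i + 1 already starts at or before k
    i = jnp.where((i + 1) * (i + 2) // 2 <= k, i + 1, i)
    # Estimate one too large: row i starts after k
    i = jnp.where(i * (i + 1) // 2 > k, i - 1, i)
    j = k - i * (i + 1) // 2
    return i, j

# Elementwise, so it works on a whole arange as well as inside lax.scan / vmap
N = 300
i, j = tri_unravel(jnp.arange(N * (N + 1) // 2))
```

The two jnp.where steps cost a few integer ops per point, which is usually cheap next to f itself.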
The “fine, I’ll just precompute it” Route
Another approach: explicitly enumerate the valid lattice points into a table and scan over that table.
```python
import jax
import jax.numpy as jnp
from jax import lax

def build_coords_triangle(N):
    # Lower triangle (including the diagonal)
    # Store linear addresses k = i * N + j
    pts = [i * N + j for i in range(N) for j in range(i + 1)]
    return jnp.asarray(pts, dtype=jnp.int32)

addresses = build_coords_triangle(N)  # shape: (K,)

def body(a, k):
    i, j = k // N, k % N  # decode the linear address back to (i, j)
    a = a.at[i, j].set(f(i, j))
    return a, None

a, _ = lax.scan(body, a0, addresses)
```
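To see the table route end to end, here is a self-contained run with a toy f(i, j) = 10 * i + j and a zero-initialised a0 (both hypothetical stand-ins). It compares the scanned result against a dense reference that evaluates f on the full square and masks it with jnp.tril, which keeps the diagonal.

```python
import jax.numpy as jnp
from jax import lax

N = 6

def f(i, j):
    # Hypothetical per-point function
    return (10 * i + j).astype(jnp.float32)

def build_coords_triangle(N):
    # Linear addresses k = i * N + j for the lower triangle, diagonal included
    pts = [i * N + j for i in range(N) for j in range(i + 1)]
    return jnp.asarray(pts, dtype=jnp.int32)

def body(a, k):
    i, j = k // N, k % N  # decode the linear address
    return a.at[i, j].set(f(i, j)), None

a0 = jnp.zeros((N, N), dtype=jnp.float32)
a, _ = lax.scan(body, a0, build_coords_triangle(N))

# Dense reference: evaluate f everywhere, then zero the strict upper triangle
ii, jj = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij")
ref = jnp.tril(f(ii, jj))
```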
This is conceptually simple but:
- adds an address table (memory),
- adds an extra read per iteration,
- still asks you to hand-enumerate the domain.
What if your domain is… less cozy?
From Triangles to “Whatever”
Consider the polygonal domain \(\mathcal{D} = \{ (i, j) \in \mathbb{Z}^2 \mid \; 5j - i - 8 \ge 0,\; -3i - 6j + 39 \ge 0,\; 4i + j - 10 \ge 0 \}\), which looks like this:
How would we write the equivalent of build_coords_triangle for this domain? It’s not obvious.
A simple approach is to bound the domain by a rectangle and reject points outside the domain, as shown by the dashed rectangle in the figure above.
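For this particular domain, the rejection approach is easy to sketch in plain Python. Intersecting the three constraint lines pairwise gives the vertices (2, 2), (7, 3) and (1, 6), so the box 1 ≤ i ≤ 7, 2 ≤ j ≤ 6 covers \(\mathcal{D}\). The helper build_coords_polygon below is hypothetical: it scans that box and keeps only the points satisfying all three inequalities.

```python
def in_domain(i, j):
    # The three half-planes defining D
    return 5 * j - i - 8 >= 0 and -3 * i - 6 * j + 39 >= 0 and 4 * i + j - 10 >= 0

# Bounding box read off the vertices (2, 2), (7, 3), (1, 6)
I_LO, I_HI, J_LO, J_HI = 1, 7, 2, 6
W = J_HI - J_LO + 1  # box width, used for linear addresses

def build_coords_polygon():
    # Hypothetical helper: reject box points outside D
    pts = [(i, j)
           for i in range(I_LO, I_HI + 1)
           for j in range(J_LO, J_HI + 1)
           if in_domain(i, j)]
    # Linear addresses relative to the box corner, mirroring k = i * N + j
    addresses = [(i - I_LO) * W + (j - J_LO) for i, j in pts]
    return pts, addresses

pts, addresses = build_coords_polygon()
box_size = (I_HI - I_LO + 1) * W
```

Only 14 of the 35 box points survive, and the box itself had to be derived by hand from the vertices.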
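But in higher dimensions:
- bounding boxes get tedious to derive,
- rejection gets expensive, since the fraction of box points that land inside the domain can shrink rapidly with dimension.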
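To put a rough number on “rejection gets expensive”, consider the corner simplex {x ≥ 0, x_1 + ... + x_d ≤ n − 1} inside its bounding box [0, n)^d. The sketch below counts its lattice points by brute force, cross-checks them against the closed-form count C(n − 1 + d, d), and reports the fraction of the box that survives rejection.

```python
import itertools
import math
from fractions import Fraction

def simplex_acceptance(n, d):
    # Fraction of the box [0, n)^d lying in the simplex {x >= 0, sum(x) <= n - 1}
    inside = sum(1 for x in itertools.product(range(n), repeat=d) if sum(x) <= n - 1)
    return Fraction(inside, n ** d)

for d in range(1, 5):
    frac = simplex_acceptance(10, d)
    # Cross-check against the closed-form lattice-point count C(n - 1 + d, d)
    assert frac == Fraction(math.comb(9 + d, d), 10 ** d)
    print(d, float(frac))
# prints "1 1.0", "2 0.55", "3 0.22", "4 0.0715" on separate lines
```

At n = 10 the surviving fraction drops from 55% in 2-D to about 7% in 4-D, and as n grows it approaches 1/d!.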
Introducing HedraX
Happily, the problem of parametric polyhedral enumeration has been studied to death (Verdoolaege et al., 2007; Klöckner, 2014; Verdoolaege, 2010).
It powers polyhedral compilation in systems like LLVM/MLIR.
I wrapped just enough of that machinery into a tiny helper, HedraX, built specifically for static-shape programming in JAX.¹
TL;DR: Tell HedraX your domain; it builds the address table for you and gives you an unravel function to recover multi-indices.
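To make the “address table plus unravel” contract concrete without relying on HedraX internals, here is a hypothetical stand-in. make_table_indexer is not HedraX’s API or implementation: it brute-forces a box, keeps the points accepted by a Python predicate, stores each as one int32 address, and returns an unravel built on jnp.unravel_index. HedraX builds the table from the set description instead, so treat this only as a model of the interface.

```python
import itertools

import jax.numpy as jnp
from jax import lax


def make_table_indexer(predicate, lower, upper):
    """Hypothetical stand-in (not HedraX): tabulate lattice points of the box
    [lower, upper) that satisfy `predicate`, in lexicographic order."""
    shape = tuple(u - l for l, u in zip(lower, upper))

    def ravel(point):
        # Row-major address relative to the box corner
        k = 0
        for x, l, s in zip(point, lower, shape):
            k = k * s + (x - l)
        return k

    ranges = [range(l, u) for l, u in zip(lower, upper)]
    addresses = jnp.asarray(
        [ravel(p) for p in itertools.product(*ranges) if predicate(*p)],
        dtype=jnp.int32,
    )

    def unravel(k):
        # Works on traced integers, so it can be called inside lax.scan
        idx = jnp.unravel_index(k, shape)
        return tuple(x + l for x, l in zip(idx, lower))

    return addresses, unravel


# Usage on the lower triangle with a toy f (hypothetical)
N = 5
addresses, unravel = make_table_indexer(lambda i, j: j <= i, (0, 0), (N, N))

def f(i, j):
    return (10 * i + j).astype(jnp.float32)

def body(a, k):
    i, j = unravel(k)
    return a.at[i, j].set(f(i, j)), None

a, _ = lax.scan(body, jnp.zeros((N, N), jnp.float32), addresses)
```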
The Triangle Example
```python
import hedrax as hdx
from jax import lax

addresses, unravel = hdx.compile_table_indexer(
    "[N] -> { [i, j] : 0 <= j <= i < N }",
    N=10
)

def body(a, k):
    i, j = unravel(k)
    a = a.at[i, j].set(f(i, j))
    return a, None

a, _ = lax.scan(body, a0, addresses)
```
Crazy domain? Just change the set:
```python
addresses, unravel = hdx.compile_table_indexer(
    "[N] -> { [i, j] : 5j - i - 8 >= 0 and -3i - 6j + 39 >= 0 and 4i + j - 10 >= 0 }",
    N=10
)
```
The GPT Unicorn
With HedraX’s table indexer, you can even enumerate unions of polyhedra.
For example, here is a ChatGPT-generated unicorn built as a union of convex polyhedra:
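Whatever the shapes, the main subtlety with a union is that convex pieces can overlap, so shared lattice points must be enumerated only once. Here is a plain-Python sketch with two hypothetical triangular pieces that share the diagonal: enumerate each piece, take the set union, and sort lexicographically so the scan order stays deterministic. enumerate_piece is an illustrative helper, not part of HedraX.

```python
def enumerate_piece(predicate, n):
    # Brute-force the lattice points of one convex piece inside [0, n)^2
    return {(i, j) for i in range(n) for j in range(n) if predicate(i, j)}

n = 4
lower = enumerate_piece(lambda i, j: j <= i, n)  # 0 <= j <= i < n
upper = enumerate_piece(lambda i, j: i <= j, n)  # 0 <= i <= j < n

union = sorted(lower | upper)  # set union drops the duplicated diagonal points
naive_total = len(lower) + len(upper)
```

Concatenating the two tables instead would visit the 4 diagonal points twice, which matters whenever the loop body does more than an idempotent .set.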
What About the Quadratic-Solving Approach?
hdx.compile_table_indexer automates the “precompute the table” route from earlier. It doesn’t produce the neat closed-form mapping of the quadratic-solving approach, but in Part 2 I’ll show how HedraX can derive such closed forms automatically when the domain admits them.
¹ Much of the credit under the hood of HedraX goes to islpy (Klöckner, 2014), a Python binding for isl, a library for manipulating parametric polyhedra.
References
- [1] Verdoolaege, Sven and Seghir, Kazem and Beyls, Kristof and D’Hollander, Erik and Bruynooghe, Maurice. Counting Integer Points in Parametric Polytopes Using Barvinok’s Rational Functions. Algorithmica, 2007.
- [2] Klöckner, Andreas. islpy: Python bindings for isl. 2014.
- [3] Verdoolaege, Sven. isl: an integer set library for the polyhedral model. International Congress on Mathematical Software (ICMS), 2010.