JAX, Static-Shape Programming, and Polyhedra

Rectangles are fine. Weird shapes are fun.

JAX wants static shapes.
Your loops, alas, are sometimes not rectangles.

This post tours the pain → coping strategies → a tiny helper I wrote called HedraX, which lets you index arbitrary polyhedral domains in JAX without summoning five GPTs.

This is Part 1: we’ll build intuition with hand-rolled code and end with HedraX’s table indexer.
In Part 2, I’ll show a more “closed-form” approach HedraX can auto-generate for suitable domains.


Rectangles are Boring

In JAX, we often translate a Python-style loop such as

for i in range(N):
    for j in range(N):
        a[i, j] = f(i, j)

into

a = jax.vmap(
      jax.vmap(f, in_axes=(None, 0)),  # inner: map over j (output axis 1)
      in_axes=(0, None)                # outer: map over i (output axis 0)
    )(jnp.arange(N), jnp.arange(N))
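
The order of the in_axes matters: the outer vmap owns output axis 0 and the inner one owns output axis 1. A quick way to check which index ends up where is to compare against the plain loop on a small N. This is only a sketch; the f below is a toy function I made up, chosen to be asymmetric in i and j so that a transposed result would show up.

```python
import jax
import jax.numpy as jnp
import numpy as np

N = 4

def f(i, j):
    # Toy function (hypothetical), asymmetric so a transposed result is detectable.
    return 10 * i + j

# Outer vmap maps over i (output axis 0), inner vmap maps over j (output axis 1).
grid = jax.vmap(
    jax.vmap(f, in_axes=(None, 0)),
    in_axes=(0, None),
)(jnp.arange(N), jnp.arange(N))

# Reference: the plain Python double loop.
ref = np.zeros((N, N), dtype=np.int32)
for i in range(N):
    for j in range(N):
        ref[i, j] = f(i, j)

matches_loop = bool((np.asarray(grid) == ref).all())
```

If matches_loop comes out False, the two in_axes tuples are most likely swapped.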

This translation works fine for rectangular domains. But suppose we want the lower triangle, diagonal included:

for i in range(N):
    for j in range(i + 1):
        a[i, j] = f(i, j)

At first glance this looks “dynamic” because the bound of the inner loop over j depends on i. Can we do this in JAX with static shapes?

The answer is yes.


The Heroic (but Fragile) Closed-Form for Triangles

Although the domain isn’t rectangular, it is statically sized: it has N * (N + 1) // 2 points.

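We can biject a linear index k to (i, j) and iterate over k.

Picture the domain as a triangle and number its points row by row: row i holds i + 1 points and starts at k = T_i = i(i + 1)/2. Recovering i from k then means finding the largest i with T_i <= k, which the quadratic formula gives as i = floor((sqrt(8k + 1) - 1) / 2), and then j = k - T_i.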

And here is the JAX code that implements this idea:

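import jax.numpy as jnp
from jax import lax

# Assumes N, f, and an (N, N) array a0 are already defined.
# Lower triangle: j in [0, i] (including the diagonal)
# k ranges over 0..T_N - 1, where T_m = m(m+1)/2 counts the points in rows 0..m-1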
def body(a, k):
    # Solve for row i from k using the quadratic formula
    i = jnp.floor((jnp.sqrt(8.0 * k + 1.0) - 1.0) / 2.0).astype(jnp.int32)
    Ti = (i * (i + 1)) // 2        # T_i
    j = (k - Ti).astype(jnp.int32) # j in [0, i]
    a = a.at[i, j].set(f(i, j))
    return a, None

K = N * (N + 1) // 2
a, _ = lax.scan(body, a0, jnp.arange(K))

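This works and is reasonably fast, but the math is bespoke, and the floating-point sqrt can land on the wrong side of an integer once k grows large enough that 8k + 1 is no longer exact in float32. You also won’t want to re-derive a closed-form quadratic formula for every odd-shaped loop you meet.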
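
One way to make the sqrt-based inversion less fragile is to treat the float result as an estimate and fix it up with integer comparisons. The sketch below is my own illustration, not something HedraX does; the name unrank_triangle is made up, and it assumes the float estimate is off by at most one row and that K is small enough for the int32 products not to overflow.

```python
import jax
import jax.numpy as jnp

def unrank_triangle(k):
    # Map linear index k to (i, j) with 0 <= j <= i, rows enumerated top to bottom.
    k = jnp.asarray(k, dtype=jnp.int32)
    # Float estimate of the row; may be off by one when sqrt rounds.
    est = (jnp.sqrt(8.0 * k.astype(jnp.float32) + 1.0) - 1.0) / 2.0
    i = jnp.floor(est).astype(jnp.int32)
    # Integer correction: nudge i so that T_i <= k < T_{i+1}.
    i = jnp.where(i * (i + 1) // 2 > k, i - 1, i)
    i = jnp.where((i + 1) * (i + 2) // 2 <= k, i + 1, i)
    j = k - i * (i + 1) // 2
    return i, j

N = 6
K = N * (N + 1) // 2
i_idx, j_idx = jax.vmap(unrank_triangle)(jnp.arange(K))

# Compare against the loop order the scan is meant to reproduce.
pairs = list(zip(i_idx.tolist(), j_idx.tolist()))
expected = [(i, j) for i in range(N) for j in range(i + 1)]
```

The two jnp.where lines cost a couple of integer multiplies per index, which is usually cheap next to the sqrt itself.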


The “fine, I’ll just precompute it” Route

Another approach: explicitly enumerate the valid lattice points into a table and scan over that table.

import jax
import jax.numpy as jnp
from jax import lax

def build_coords_triangle(N):
    # Lower triangle (including the diagonal)
    # Store linear addresses k = i * N + j
    pts = [i * N + j for i in range(N) for j in range(i + 1)]
    return jnp.asarray(pts, dtype=jnp.int32)

addresses = build_coords_triangle(N)  # shape: (K,)

def body(a, k):
    # Decode the linear address back into (i, j)
    i, j = k // N, k % N
    a = a.at[i, j].set(f(i, j))
    return a, None

a, _ = lax.scan(body, a0, addresses)

This is conceptually simple but:

  • adds an address table (memory),
  • adds an extra read per iteration,
  • still asks you to hand-enumerate the domain.
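
When the iterations are independent of each other (no value carried from one point to the next), a coordinate table can also drive one vectorized scatter instead of a scan. A minimal sketch, assuming a toy f of my own and using jnp.tril_indices to get the triangle's coordinates:

```python
import jax
import jax.numpy as jnp

N = 5

def f(i, j):
    # Toy per-point function (hypothetical stand-in for the post's f).
    return (i * i + j).astype(jnp.float32)

# Static coordinate table for the lower triangle, diagonal included.
rows, cols = jnp.tril_indices(N)

# Evaluate f on every valid point at once, then scatter into the output.
vals = jax.vmap(f)(rows, cols)
a = jnp.zeros((N, N), dtype=jnp.float32).at[rows, cols].set(vals)
```

This still pays for the table, and it does not replace the scan when the loop body really depends on earlier iterations.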

What if your domain is… less cozy?


From Triangles to “Whatever”

Consider the polygonal domain \(\mathcal{D} = \{ (i, j) \in \mathbb{Z}^2 \mid \; 5j - i - 8 \ge 0,\; -3i - 6j + 39 \ge 0,\; 4i + j - 10 \ge 0 \}\), a triangle with vertices (2, 2), (7, 3), and (1, 6), that looks like this:

How would we write the equivalent of build_coords_triangle for this domain? It’s not obvious.

A simple approach is to bound the domain by a rectangle and reject points outside the domain, as shown by the dashed rectangle in the figure above.

But, in higher dimensions:

  • bounding boxes get tedious,
  • rejection gets expensive.
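
For the 2D domain above, the bounding-box-and-reject idea can be sketched as follows. The box limits are read off the triangle's vertices (2, 2), (7, 3), and (1, 6); f is a toy function of my own, and the enumeration happens on the host with NumPy before anything is traced.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def in_domain(i, j):
    # The three half-planes that define D.
    return (5 * j - i - 8 >= 0) & (-3 * i - 6 * j + 39 >= 0) & (4 * i + j - 10 >= 0)

# Bounding box of the triangle: i in [1, 7], j in [2, 6].
I_MIN, I_MAX, J_MIN, J_MAX = 1, 7, 2, 6

# Enumerate the box on the host and keep only the points inside D.
ii, jj = np.meshgrid(
    np.arange(I_MIN, I_MAX + 1), np.arange(J_MIN, J_MAX + 1), indexing="ij"
)
mask = in_domain(ii, jj)
coords = jnp.asarray(np.stack([ii[mask], jj[mask]], axis=1), dtype=jnp.int32)  # (K, 2)

def f(i, j):
    # Toy per-point function (hypothetical).
    return (i + 100 * j).astype(jnp.float32)

def body(a, ij):
    i, j = ij[0], ij[1]
    return a.at[i, j].set(f(i, j)), None

a0 = jnp.zeros((I_MAX + 1, J_MAX + 1), dtype=jnp.float32)
a, _ = lax.scan(body, a0, coords)
```

Here 14 of the 35 box points survive the test. In 2D that waste is tolerable; the trouble is that the surviving fraction tends to shrink as dimensions are added.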

Introducing HedraX

Happily, the problem of parametric polyhedral enumeration has been studied to death (Verdoolaege et al., 2007; Klöckner, 2014; Verdoolaege, 2010).
It powers polyhedral compilation in systems like LLVM/MLIR.

I wrapped just enough of that machinery into a tiny helper: HedraX, specifically built for the use case of static-shape programming in JAX (see footnote 1).

TL;DR: Tell HedraX your domain; it builds the address table for you and gives you an unravel to recover multi-indices.

The Triangle Example

import hedrax as hdx
from jax import lax

addresses, unravel = hdx.compile_table_indexer(
    "[N] -> { [i, j] : 0 <= j <= i < N }",
    N=10
)

def body(a, k):
    i, j = unravel(k)
    a = a.at[i, j].set(f(i, j))
    return a, None

a, _ = lax.scan(body, a0, addresses)

Crazy domain? Just change the set:

addresses, unravel = hdx.compile_table_indexer(
    "[N] -> { [i, j] : 5j - i - 8 >= 0 and -3i - 6j + 39 >= 0 and 4i + j - 10 >= 0 }",
    N=10
)

The GPT Unicorn

With the table indexer in HedraX, you can even do unions of polyhedra.

For example, here is a ChatGPT-generated unicorn built as a union of convex polyhedra:

Voilà!

What About the Quadratic-Solving Approach?

hdx.compile_table_indexer automates the “fine, I’ll just precompute it” route from earlier in this post.
It doesn’t produce the same neat closed-form mapping as the quadratic-formula approach we used for the triangle, but in Part 2 I’ll show how HedraX can derive those closed forms automatically when the domain admits them.


  1. A lot of the credit for what happens under the hood of HedraX goes to islpy (Klöckner, 2014), a Python binding for the isl library for manipulating parametric polyhedra.

References

  1. Verdoolaege, Sven; Seghir, Kazem; Beyls, Kristof; D’Hollander, Erik; Bruynooghe, Maurice. Counting Integer Points in Parametric Polytopes Using Barvinok’s Rational Functions. Algorithmica, 2007.
  2. Klöckner, Andreas. islpy: Python bindings for isl. 2014.
  3. Verdoolaege, Sven. isl: An Integer Set Library for the Polyhedral Model. International Congress on Mathematical Software, 2010.


