Learning to Write Halide Algorithms by Examples¶
In the last tutorial we learned how to construct Halide functions, which describe computation on infinite multi-dimensional grids. In this tutorial, we will learn how to use these Halide functions to construct more complex Halide algorithms, including how to connect different Halide functions and how to write loops and conditions inside a Halide function.
Note
Remember that Halide separates code into algorithms and schedules. In this tutorial we focus on the algorithms and use the default schedules. This means that if you execute the programs listed here they will be slow. Read the subsequent tutorials on how to make them fast.
Matrix multiplication¶
ImageParam A(Float(32), 2, "A"), B(Float(32), 2, "B");
Func C("C");
Var i("i"), j("j");
C(i, j) = 0.f;
RDom k(0, A.dim(1).extent(), "k"); // inner loop
C(i, j) += A(i, k) * B(k, j);
A, B = hl.ImageParam(hl.Float(32), 2, "A"), hl.ImageParam(hl.Float(32), 2, "B")
C = hl.Func("C")
i, j = hl.Var("i"), hl.Var("j")
C[i, j] = 0.0
k = hl.RDom(0, A.dim(1).extent(), "k") # inner loop
C[i, j] += A[i, k] * B[k, j]
The first example we will look at is a matrix multiplication Halide program:
given two matrices A
and B
we want to compute C = A * B
.
The Halide code above roughly translates to the following pseudo code:
for i = 0 to A.rows()
for j = 0 to B.cols()
C(i, j) = 0.f
for k = 0 to A.cols()
C(i, j) += A(i, k) * B(k, j)
We have previously introduced Halide variables Var
. When we write C(i, j) = 0.f
we can think of it as Halide implicitly constructing loops that traverse over C
‘s
elements. How do we introduce the inner loop that sum over entries from A
and B
then? We declare a RDom
(which stands for reduction domain) construct k
to loop
over A
‘s columns. Also note that we defined C
twice. In the first time we initialize
it to zero over the whole domain, and the second time we update C
to be the sum over
A
‘s rows and B
‘s columns.
Image convolution¶
ImageParam input(Float(32), 2, "in"), kernel(Float(32), 2, "k");
Var x("x"), y("y");
Func bounded_input("input");
bounded_input(x, y) = input(clamp(x, 0, input.dim(0).extent()),
clamp(y, 0, input.dim(1).extent()));
Func output("output");
RDom r(0, kernel.dim(0).extent(), 0, kernel.dim(1).extent());
output(x, y) += bounded_input(x - r.x + kernel.dim(0).extent() / 2,
y - r.y + kernel.dim(1).extent() / 2) *
kernel(r.x, r.y);
input = hl.ImageParam(hl.Float(32), 2, 'in')
kernel = hl.ImageParam(hl.Float(32), 2, 'k')
x, y = hl.Var('x'), hl.Var('y')
bounded_input = hl.Func('input')
bounded_input[x, y] = input[hl.clamp(x, 0, input.dim(0).extent()),
hl.clamp(y, 0, input.dim(1).extent())];
output = hl.Func('output');
r = hl.RDom(0, kernel.dim(0).extent(), 0, kernel.dim(1).extent());
output[x, y] += bounded_input[x - r.x + kernel.dim(0).extent() / 2,
y - r.y + kernel.dim(1).extent() / 2] *
kernel[r.x, r.y];
Image convolution can also be described using reduction domains. There are two crucial differences compared to the matrix multiplication.
Firstly, the reduction domain becomes 2D. We can use r.x
and r.y
to access different
dimensions of the reduction domains. Alternatively we can use r[0]
or r[1]
.
Secondly, the convolution filter can read from outside the bounds of the input. Therefore we need
to define a boundary condition for our input. This is done through the bounded_input
function.
We define the out-of-bound access to the closest pixel based on the coordinate, and this is achieved
through clamping the coordinates. Since this pattern is very common, Halide provides a syntatic sugar
for this:
Func bounded_input = BoundaryConditions::repeat_edge(input);
bounded_input = hl.BoundaryConditions.repeat_edge(input)
Check out Halide’s documentation for different kinds of boundary conditions.
Also note that we omit the initialization of output
. Halide will automatically
initialize it to zero.
Histogram¶
Param<int> num_bins;
Param<float> hist_min, hist_max;
ImageParam input(Float(32), 1, "in");
Var x("x");
Func hist("hist");
RDom r(0, input.dim(0).extent());
Expr hist_index = cast<int>(num_bins * ((input(r) - hist_min) / (hist_max - hist_min)));
hist(hist_index) += 1;
num_bins = hl.Param(hl.Int(32))
hist_min, hist_max = hl.Param(hl.Float(32)), hl.Param(hl.Float(32))
input = hl.ImageParam(hl.Float(32), 1, 'in')
x = hl.Var('x')
hist = hl.Func('hist')
r = hl.RDom(0, input.dim(0).extent())
hist_index = hl.cast(hl.Int(32), (input(r) - hist_min) / (hist_max - hist_min);
hist(hist_index) += 1;
Reduction variables can also be used at the left-hand side of updates, like in the histogram
example above. If an expression is too long, it can be temporarily stored in an Expr
.
Todo
Examples for select
Todo
Examples for rdom.where
Todo
Examples where we metaprogram a huge expression