Optimization Tips for Finch

It's easy to ask Finch to run the same operation in several different ways, but those ways can differ widely in performance, and the best choice depends on your particular situation. Here is a collection of general guidelines that help Finch generate faster code in most cases.

Concordant Iteration

By default, Finch stores arrays in column-major order (first index fastest). When the storage order of an array in a Finch expression matches the loop order, we call this concordant iteration. For example, the following expression is a concordant traversal of a sparse matrix, since the outer loops access the higher levels of the tensor tree:

A = Tensor(
    Dense(SparseList(Element(0.0))),
    fsparse([2, 3, 4, 1, 3], [1, 1, 1, 3, 3], [1.1, 2.2, 3.3, 4.4, 5.5], (4, 3)),
)
s = Scalar(0.0)
@finch for j in _, i in _
    s[] += A[i, j]
end

# output

NamedTuple()

We can investigate the generated code with @finch_code. This code iterates over only the nonzeros in order. If our matrix is m × n with nnz nonzeros, this takes O(n + nnz) time.

@finch_code for j in _, i in _
    s[] += A[i, j]
end

# output

quote
    s_data = (ex.bodies[1]).body.body.lhs.tns.bind
    s_val = s_data.val
    A_lvl = (ex.bodies[1]).body.body.rhs.tns.bind.lvl
    A_lvl_stop = A_lvl.shape
    A_lvl_2 = A_lvl.lvl
    A_lvl_2_ptr = A_lvl_2.ptr
    A_lvl_2_idx = A_lvl_2.idx
    A_lvl_2_stop = A_lvl_2.shape
    A_lvl_3 = A_lvl_2.lvl
    A_lvl_3_val = A_lvl_3.val
    for j_3 = 1:A_lvl_stop
        A_lvl_q = (1 - 1) * A_lvl_stop + j_3
        A_lvl_2_q = A_lvl_2_ptr[A_lvl_q]
        A_lvl_2_q_stop = A_lvl_2_ptr[A_lvl_q + 1]
        if A_lvl_2_q < A_lvl_2_q_stop
            A_lvl_2_i1 = A_lvl_2_idx[A_lvl_2_q_stop - 1]
        else
            A_lvl_2_i1 = 0
        end
        phase_stop = min(A_lvl_2_i1, A_lvl_2_stop)
        if phase_stop >= 1
            if A_lvl_2_idx[A_lvl_2_q] < 1
                A_lvl_2_q = Finch.scansearch(A_lvl_2_idx, 1, A_lvl_2_q, A_lvl_2_q_stop - 1)
            end
            while true
                A_lvl_2_i = A_lvl_2_idx[A_lvl_2_q]
                if A_lvl_2_i < phase_stop
                    A_lvl_3_val_2 = A_lvl_3_val[A_lvl_2_q]
                    s_val = A_lvl_3_val_2 + s_val
                    A_lvl_2_q += 1
                else
                    phase_stop_3 = min(phase_stop, A_lvl_2_i)
                    if A_lvl_2_i == phase_stop_3
                        A_lvl_3_val_2 = A_lvl_3_val[A_lvl_2_q]
                        s_val += A_lvl_3_val_2
                        A_lvl_2_q += 1
                    end
                    break
                end
            end
        end
    end
    result = ()
    s_data.val = s_val
    result
end

When the loop order does not match the storage order, we call this discordant iteration. For example, if we swap the loop order in the example above, then Finch must randomly access each sparse column for every row i. We end up searching for each (i, j) pair, because we don't know whether it is zero until we look it up. In all, this takes O(n * m * log(nnz)) time, which is much less efficient! We shouldn't randomly access sparse arrays unless we really need to and the format supports it efficiently.

Note the doubly nested for loop in the following generated code:

@finch_code for i in _, j in _
    s[] += A[i, j]
end # DISCORDANT, DO NOT DO THIS

# output

quote
    s_data = (ex.bodies[1]).body.body.lhs.tns.bind
    s_val = s_data.val
    A_lvl = (ex.bodies[1]).body.body.rhs.tns.bind.lvl
    A_lvl_stop = A_lvl.shape
    A_lvl_2 = A_lvl.lvl
    A_lvl_2_ptr = A_lvl_2.ptr
    A_lvl_2_idx = A_lvl_2.idx
    A_lvl_2_stop = A_lvl_2.shape
    A_lvl_3 = A_lvl_2.lvl
    A_lvl_3_val = A_lvl_3.val
    @warn "Performance Warning: non-concordant traversal of A[i, j] (hint: most arrays prefer column major or first index fast, run in fast mode to ignore this warning)"
    for i_3 = 1:A_lvl_2_stop
        for j_3 = 1:A_lvl_stop
            A_lvl_q = (1 - 1) * A_lvl_stop + j_3
            A_lvl_2_q = A_lvl_2_ptr[A_lvl_q]
            A_lvl_2_q_stop = A_lvl_2_ptr[A_lvl_q + 1]
            if A_lvl_2_q < A_lvl_2_q_stop
                A_lvl_2_i1 = A_lvl_2_idx[A_lvl_2_q_stop - 1]
            else
                A_lvl_2_i1 = 0
            end
            phase_stop = min(i_3, A_lvl_2_i1)
            if phase_stop >= i_3
                if A_lvl_2_idx[A_lvl_2_q] < i_3
                    A_lvl_2_q = Finch.scansearch(A_lvl_2_idx, i_3, A_lvl_2_q, A_lvl_2_q_stop - 1)
                end
                while true
                    A_lvl_2_i = A_lvl_2_idx[A_lvl_2_q]
                    if A_lvl_2_i < phase_stop
                        A_lvl_3_val_2 = A_lvl_3_val[A_lvl_2_q]
                        s_val = A_lvl_3_val_2 + s_val
                        A_lvl_2_q += 1
                    else
                        phase_stop_3 = min(phase_stop, A_lvl_2_i)
                        if A_lvl_2_i == phase_stop_3
                            A_lvl_3_val_2 = A_lvl_3_val[A_lvl_2_q]
                            s_val += A_lvl_3_val_2
                            A_lvl_2_q += 1
                        end
                        break
                    end
                end
            end
        end
    end
    result = ()
    s_data.val = s_val
    result
end

TL;DR: As a quick heuristic, if your array indices appear in alphabetical order (e.g. A[i, j]), then the loop indices should be listed in reverse alphabetical order (for j in _, i in _), so that the first, fastest-varying array index is the innermost loop.
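
For instance, here is a small sketch of the heuristic applied to a hypothetical three-dimensional tensor A3 stored in a default (column-major) format: the access A3[i, j, k] lists its indices alphabetically, so the loops are written k, j, i, keeping the first (fastest) index innermost.

A3 = Tensor(Dense(Dense(SparseList(Element(0.0)))), zeros(2, 2, 2)) # hypothetical 3-D tensor
s = Scalar(0.0)
@finch for k in _, j in _, i in _  # reverse alphabetical loop order
    s[] += A3[i, j, k]             # alphabetical index order in the access
end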

Appropriate Fill Values

The @finch macro requires the user to specify an output format. This is the most flexible approach, but it can sometimes lead to densification unless the output fill value is appropriate for the computation.

For example, if A is m × n with nnz nonzeros, the following Finch kernel will densify B, filling it with m * n stored values:

A = Tensor(
    Dense(SparseList(Element(0.0))),
    fsparse([2, 3, 4, 1, 3], [1, 1, 1, 3, 3], [1.1, 2.2, 3.3, 4.4, 5.5], (4, 3)),
)
B = Tensor(Dense(SparseList(Element(0.0)))) #DO NOT DO THIS, B has the wrong fill value
@finch begin
    B .= 0
    for j in _, i in _
        B[i, j] = A[i, j] + 1
    end
    return B
end
countstored(B)

# output

12

Since A is filled with 0.0, adding 1 to the fill value produces 1.0. However, B can only represent a fill value of 0.0, so every entry must be stored explicitly. Instead, we should give B a fill value of 1.0.

A = Tensor(
    Dense(SparseList(Element(0.0))),
    fsparse([2, 3, 4, 1, 3], [1, 1, 1, 3, 3], [1.1, 2.2, 3.3, 4.4, 5.5], (4, 3)),
)
B = Tensor(Dense(SparseList(Element(1.0))))
@finch begin
    B .= 1
    for j in _, i in _
        B[i, j] = A[i, j] + 1
    end
    return B
end
countstored(B)

# output

5

Static Versus Dynamic Values

In order to skip some computations, Finch must be able to determine the values of program variables when it compiles the kernel. Continuing the example above, if we hide the value 1 behind a variable x, Finch can only determine that x has type Int, not that it equals 1, so it can no longer recognize that A[i, j] + x matches B's fill value in the zero regions of A.

A = Tensor(
    Dense(SparseList(Element(0.0))),
    fsparse([2, 3, 4, 1, 3], [1, 1, 1, 3, 3], [1.1, 2.2, 3.3, 4.4, 5.5], (4, 3)),
)
B = Tensor(Dense(SparseList(Element(1.0))))
x = 1 #DO NOT DO THIS, Finch cannot see the value of x anymore
@finch begin
    B .= 1
    for j in _, i in _
        B[i, j] = A[i, j] + x
    end
    return B
end
countstored(B)

# output

12

However, there are some situations where you may want a value to be dynamic. For example, consider the function saxpy(x, a, y) = x .* a .+ y. Because we do not know the value of a until we run the function, we should treat it as dynamic, and the following implementation is reasonable:

function saxpy(x, a, y)
    z = Tensor(SparseList(Element(0.0)))
    @finch begin
        z .= 0
        for i in _
            z[i] = a * x[i] + y[i] # a is dynamic: its value is only known at run time
        end
    end
    return z
end
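
A hedged usage sketch (these inputs are illustrative, not from the original doc): x and y are sparse vectors, and the scaling factor a = 2.0 remains dynamic, known only when saxpy is called.

x = Tensor(SparseList(Element(0.0)), [0.0, 1.0, 0.0, 2.0, 0.0])
y = Tensor(SparseList(Element(0.0)), [1.0, 0.0, 0.0, 0.0, 3.0])
z = saxpy(x, 2.0, y) # z is a sparse vector holding 2 .* x .+ y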

Use Known Functions

Unless you declare the properties of your functions using Finch's User-Defined Functions interface, Finch doesn't know how they work. For example, wrapping * in our own function f below obscures its meaning:

A = Tensor(
    Dense(SparseList(Element(0.0))),
    fsparse([2, 3, 4, 1, 3], [1, 1, 1, 3, 3], [1.1, 2.2, 3.3, 4.4, 5.5], (4, 3)),
)
B = ones(4, 3)
C = Scalar(0.0)
f(x, y) = x * y # DO NOT DO THIS, it obscures *
@finch begin
    C .= 0
    for j in _, i in _
        C[] += f(A[i, j], B[i, j])
    end
    return C
end

# output

(C = Scalar{0.0, Float64}(16.5),)

Checking the generated code, we see that it is indeed densifying (notice the inner for-loops, which repeatedly evaluate f(0.0, B[i, j]) over the zero regions of A).

@finch_code begin
    C .= 0
    for j in _, i in _
        C[] += f(A[i, j], B[i, j])
    end
    return C
end

# output

quote
    C_data = ((ex.bodies[1]).bodies[1]).tns.bind
    A_lvl = (((ex.bodies[1]).bodies[2]).body.body.rhs.args[1]).tns.bind.lvl
    A_lvl_stop = A_lvl.shape
    A_lvl_2 = A_lvl.lvl
    A_lvl_2_ptr = A_lvl_2.ptr
    A_lvl_2_idx = A_lvl_2.idx
    A_lvl_2_stop = A_lvl_2.shape
    A_lvl_3 = A_lvl_2.lvl
    A_lvl_3_val = A_lvl_3.val
    B_data = (((ex.bodies[1]).bodies[2]).body.body.rhs.args[2]).tns.bind
    sugar_1 = size((((ex.bodies[1]).bodies[2]).body.body.rhs.args[2]).tns.bind)
    B_mode1_stop = sugar_1[1]
    B_mode2_stop = sugar_1[2]
    B_mode1_stop == A_lvl_2_stop || throw(DimensionMismatch("mismatched dimension limits ($(B_mode1_stop) != $(A_lvl_2_stop))"))
    B_mode2_stop == A_lvl_stop || throw(DimensionMismatch("mismatched dimension limits ($(B_mode2_stop) != $(A_lvl_stop))"))
    C_val = 0
    for j_4 = 1:B_mode2_stop
        A_lvl_q = (1 - 1) * A_lvl_stop + j_4
        A_lvl_2_q = A_lvl_2_ptr[A_lvl_q]
        A_lvl_2_q_stop = A_lvl_2_ptr[A_lvl_q + 1]
        if A_lvl_2_q < A_lvl_2_q_stop
            A_lvl_2_i1 = A_lvl_2_idx[A_lvl_2_q_stop - 1]
        else
            A_lvl_2_i1 = 0
        end
        phase_stop = min(B_mode1_stop, A_lvl_2_i1)
        if phase_stop >= 1
            i = 1
            if A_lvl_2_idx[A_lvl_2_q] < 1
                A_lvl_2_q = Finch.scansearch(A_lvl_2_idx, 1, A_lvl_2_q, A_lvl_2_q_stop - 1)
            end
            while true
                A_lvl_2_i = A_lvl_2_idx[A_lvl_2_q]
                if A_lvl_2_i < phase_stop
                    for i_6 = i:-1 + A_lvl_2_i
                        val = B_data[i_6, j_4]
                        C_val = (Main).f(0.0, val) + C_val
                    end
                    A_lvl_3_val_2 = A_lvl_3_val[A_lvl_2_q]
                    val_2 = B_data[A_lvl_2_i, j_4]
                    C_val += (Main).f(A_lvl_3_val_2, val_2)
                    A_lvl_2_q += 1
                    i = A_lvl_2_i + 1
                else
                    phase_stop_3 = min(phase_stop, A_lvl_2_i)
                    if A_lvl_2_i == phase_stop_3
                        for i_8 = i:-1 + phase_stop_3
                            val_3 = B_data[i_8, j_4]
                            C_val += (Main).f(0.0, val_3)
                        end
                        A_lvl_3_val_2 = A_lvl_3_val[A_lvl_2_q]
                        val_4 = B_data[phase_stop_3, j_4]
                        C_val += (Main).f(A_lvl_3_val_2, val_4)
                        A_lvl_2_q += 1
                    else
                        for i_10 = i:phase_stop_3
                            val_5 = B_data[i_10, j_4]
                            C_val += (Main).f(0.0, val_5)
                        end
                    end
                    i = phase_stop_3 + 1
                    break
                end
            end
        end
        phase_start_3 = max(1, 1 + A_lvl_2_i1)
        if B_mode1_stop >= phase_start_3
            for i_12 = phase_start_3:B_mode1_stop
                val_6 = B_data[i_12, j_4]
                C_val += (Main).f(0.0, val_6)
            end
        end
    end
    C_data.val = C_val
    (C = C_data,)
end
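
For comparison, here is a sketch of the same reduction written with the built-in * directly. Because Finch knows that multiplying by the fill value 0.0 yields 0.0, the generated code can skip the zero regions of A entirely:

C = Scalar(0.0)
@finch begin
    C .= 0
    for j in _, i in _
        C[] += A[i, j] * B[i, j] # * is a known function, so the zeros of A contribute nothing
    end
    return C
end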

Type Stability

Julia code runs fastest when the compiler can infer the types of all intermediate values. Finch does not check that the code it generates is type-stable. In situations where tensors have nonuniform index or element types, or the computation itself involves multiple types, you should check that the generated kernel is type-stable with @code_warntype, for example by inspecting a function generated with @finch_kernel or a function that wraps a @finch call.
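
Here is a minimal sketch of such a check, using a hypothetical wrapper function sum_all around a plain @finch call (the same approach applies to functions generated with @finch_kernel):

using InteractiveUtils # provides @code_warntype outside the REPL

function sum_all(A) # hypothetical wrapper; the @finch-generated code is inlined into it
    s = Scalar(0.0)
    @finch for j in _, i in _
        s[] += A[i, j]
    end
    return s[]
end

@code_warntype sum_all(A) # intermediate values should all have concrete inferred types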

Dense Arrays

Finch is currently optimized for sparse code and does not implement traditional dense optimizations. We are working on adding these features, but if you need dense performance today, you may want to look at JuliaGPU instead.