Mask Sugar
In Finch, expressions like `i == j` are treated as syntactic sugar for mask tensors, which can be used to encode fancy iteration patterns. For example, the expression `i == j` is converted to a diagonal boolean mask tensor `DiagMask()[i, j]`, which allows an expression like
```julia
@finch begin
    for i in _, j in _
        if i == j
            s[] += A[i, j]
        end
    end
end
```

to compile to something like

```julia
for i in 1:n
    s[] += A[i, i]
end
```

There are several mask tensors and syntaxes available, summarized in the following table, where `i` and `j` are indices:
| Expression | Transformed Expression |
|---|---|
| `i < j` | `UpTriMask()[i, j - 1]` |
| `i <= j` | `UpTriMask()[i, j]` |
| `i > j` | `LoTriMask()[i, j + 1]` |
| `i >= j` | `LoTriMask()[i, j]` |
| `i == j` | `DiagMask()[i, j]` |
| `i != j` | `!(DiagMask()[i, j])` |
Note that either i or j may be expressions, so long as the expression is constant with respect to the loop over the index.
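The table's rewrites can be checked mechanically. The sketch below is plain Julia and does not call Finch: it models each mask as an ordinary boolean function of its indices (`uptri_mask`, `lotri_mask`, and `diag_mask` are hypothetical stand-ins for `UpTriMask()`, `LoTriMask()`, and `DiagMask()`), confirms on a small grid that each comparison agrees with its transformed expression, and compares the masked diagonal loop from the `i == j` example with the single diagonal loop it is said to compile to.

```julia
# Hypothetical plain-Julia models of the mask predicates (not Finch's API).
uptri_mask(i, j) = i <= j   # stands in for UpTriMask()[i, j]
lotri_mask(i, j) = i >= j   # stands in for LoTriMask()[i, j]
diag_mask(i, j)  = i == j   # stands in for DiagMask()[i, j]

# Each comparison paired with its transformed expression from the table.
rules = [
    ((i, j) -> i <  j, (i, j) -> uptri_mask(i, j - 1)),
    ((i, j) -> i <= j, (i, j) -> uptri_mask(i, j)),
    ((i, j) -> i >  j, (i, j) -> lotri_mask(i, j + 1)),
    ((i, j) -> i >= j, (i, j) -> lotri_mask(i, j)),
    ((i, j) -> i == j, (i, j) -> diag_mask(i, j)),
    ((i, j) -> i != j, (i, j) -> !diag_mask(i, j)),
]

# Every rewrite should agree with its original comparison on a small grid.
n = 6
table_holds = all(cond(i, j) == masked(i, j) for (cond, masked) in rules, i in 1:n, j in 1:n)

# The masked loop from the example: visit every (i, j), accumulate only where i == j.
function masked_diag_sum(A)
    s = zero(eltype(A))
    for i in axes(A, 1), j in axes(A, 2)
        if diag_mask(i, j)
            s += A[i, j]
        end
    end
    return s
end

# The loop the example says it compiles to: walk the diagonal directly.
function direct_diag_sum(A)
    s = zero(eltype(A))
    for i in axes(A, 1)
        s += A[i, i]
    end
    return s
end

A = reshape(collect(1.0:16.0), 4, 4)
```

Shifting the second index by one is what turns the strict comparisons into the non-strict masks: `i < j` holds exactly when `i <= j - 1`, and `i > j` exactly when `i >= j + 1`.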
The mask tensors are described below:
`Finch.uptrimask` (Constant): A mask for an upper triangular tensor, `uptrimask[i, j] = i <= j`. Note that this specializes each column for the cases where `i <= j` and `i > j`.
`Finch.lotrimask` (Constant): A mask for a lower triangular tensor, `lotrimask[i, j] = i >= j`. Note that this specializes each column for the cases where `i < j` and `i >= j`.
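The phrase "specializes each column" can be read as splitting the loop over `i`, for a fixed column `j`, into contiguous ranges on which the mask value is constant, so that no per-element test is needed inside each range. The plain-Julia sketch below illustrates that reading for the two triangular masks; it is an assumption about the idea, not Finch's generated code, and `uptri_phases`, `lotri_phases`, and the other helpers are hypothetical names.

```julia
# Hypothetical helpers (not Finch's API). For column j of an n-row mask
# (assuming 1 <= j <= n), return (rows, value) pairs that tile 1:n and on
# which the mask value is constant.
uptri_phases(n, j) = [(1:j, true), (j+1:n, false)]    # i <= j, then i > j
lotri_phases(n, j) = [(1:j-1, false), (j:n, true)]    # i < j, then i >= j

# Check that the phases tile 1:n and agree with the predicate in every column.
function phases_match(phases, pred, n)
    for j in 1:n
        covered = Int[]
        for (rows, value) in phases(n, j)
            append!(covered, rows)
            all(pred(i, j) == value for i in rows) || return false
        end
        covered == collect(1:n) || return false
    end
    return true
end

# Upper-triangle sum with a per-element test, as the mask expression reads.
function uptri_sum_masked(A)
    s = zero(eltype(A))
    for j in axes(A, 2), i in axes(A, 1)
        if i <= j
            s += A[i, j]
        end
    end
    return s
end

# Same sum, visiting only the phase where the mask is true (no per-element test).
# Assumes A is square.
function uptri_sum_phased(A)
    n = size(A, 1)
    s = zero(eltype(A))
    for j in 1:n
        for (rows, value) in uptri_phases(n, j)
            value || continue
            for i in rows
                s += A[i, j]
            end
        end
    end
    return s
end

A = reshape(collect(1.0:16.0), 4, 4)
```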
`Finch.diagmask` (Constant): A mask for a diagonal tensor, `diagmask[i, j] = i == j`. Note that this specializes each column for the cases where `i < j`, `i == j`, and `i > j`.
`Finch.bandmask` (Constant): A mask for a banded tensor, `bandmask[i, j, k] = j <= i <= k`. Note that this specializes each column for the cases where `i < j`, `j <= i <= k`, and `k < i`.
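The same per-column reading extends to the three-case masks. In this plain-Julia sketch (hypothetical helper names, not Finch's code), `diag_phases` splits column `j` of a diagonal mask into `i < j`, `i == j`, and `i > j`, and `band_phases` splits the range of `i` for `bandmask[i, j, k]` into `i < j`, `j <= i <= k`, and `k < i`; a banded sum then only needs to visit the middle phase.

```julia
# Hypothetical helpers (not Finch's API), assuming 1 <= j <= k <= n.
# Diagonal mask, column j: the phases i < j, i == j, i > j.
diag_phases(n, j) = [(1:j-1, false), (j:j, true), (j+1:n, false)]

# bandmask[i, j, k] = j <= i <= k: the phases i < j, j <= i <= k, k < i.
band_pred(i, j, k) = j <= i <= k
band_phases(n, j, k) = [(1:j-1, false), (j:k, true), (k+1:n, false)]

# Check that a list of phases tiles 1:n and agrees with a predicate of i.
function tiles_and_matches(phases, pred, n)
    covered = Int[]
    for (rows, value) in phases
        append!(covered, rows)
        all(pred(i) == value for i in rows) || return false
    end
    return covered == collect(1:n)
end

# Sum x[i] over the band j <= i <= k with a per-element test...
function band_sum_masked(x, j, k)
    s = zero(eltype(x))
    for i in eachindex(x)
        if band_pred(i, j, k)
            s += x[i]
        end
    end
    return s
end

# ...and by visiting only the middle phase, where the mask is known to be true.
function band_sum_phased(x, j, k)
    s = zero(eltype(x))
    for (rows, value) in band_phases(length(x), j, k)
        value || continue
        for i in rows
            s += x[i]
        end
    end
    return s
end

x = collect(1:10)
```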
`Finch.splitmask(n, P)` (Function): A mask to evenly divide `n` indices into `P` regions. If `M = splitmask(n, P)`, then `M[i, j] = fld(n * (j - 1), P) < i <= fld(n * j, P)`.
```julia
julia> splitmask(10, 3)
10×3 Finch.SplitMask{Int64}:
 1  0  0
 1  0  0
 1  0  0
 0  1  0
 0  1  0
 0  1  0
 0  0  1
 0  0  1
 0  0  1
 0  0  1
```
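A plain-Julia model of `splitmask` can reproduce the printed 10×3 pattern. This is a hypothetical sketch, not Finch's implementation: `split_pred` and `split_matrix` are illustrative names, and the bounds `fld(n * (j - 1), P) < i <= fld(n * j, P)` are the form that matches the output above for 1-based row indices.

```julia
# Hypothetical model of splitmask(n, P) (not Finch's implementation).
# Region j holds the rows i with fld(n * (j - 1), P) < i <= fld(n * j, P).
split_pred(n, P, i, j) = fld(n * (j - 1), P) < i <= fld(n * j, P)

# Materialize the n×P mask as a dense Bool matrix.
split_matrix(n, P) = [split_pred(n, P, i, j) for i in 1:n, j in 1:P]

M = split_matrix(10, 3)

# The 10×3 pattern printed for splitmask(10, 3).
expected = Bool[1 0 0; 1 0 0; 1 0 0; 0 1 0; 0 1 0; 0 1 0; 0 0 1; 0 0 1; 0 0 1; 0 0 1]
```

Because the region boundaries `fld(n * j, P)` run from `0` at `j = 0` to `n` at `j = P` without decreasing, every row lands in exactly one region, and region sizes differ by at most one.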
`Finch.chunkmask(n, b)` (Function): A mask to divide `n` indices into consecutive regions of size `b`, the last of which may be shorter. If `m = chunkmask(n, b)`, then `m[i, j] = b * (j - 1) < i <= b * j`. Note that this specializes for the cleanup case at the end of the range.
```julia
julia> chunkmask(10, 3)
10×4 Finch.ChunkMask{Int64}:
 1  0  0  0
 1  0  0  0
 1  0  0  0
 0  1  0  0
 0  1  0  0
 0  1  0  0
 0  0  1  0
 0  0  1  0
 0  0  1  0
 0  0  0  1
```
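Similarly, a plain-Julia model of `chunkmask` reproduces the 10×4 output, including the short final chunk that the cleanup case handles. `chunk_pred`, `chunk_matrix`, and `chunk_rows` are hypothetical names for this sketch, not Finch functions.

```julia
# Hypothetical model of chunkmask(n, b) (not Finch's implementation).
# Chunk j holds the rows i with b * (j - 1) < i <= b * j.
chunk_pred(b, i, j) = b * (j - 1) < i <= b * j

# cld(n, b) chunks cover 1:n; only rows up to n exist, so the last chunk may be short.
chunk_matrix(n, b) = [chunk_pred(b, i, j) for i in 1:n, j in 1:cld(n, b)]

# Row range of chunk j, clipped at n; the last chunk is the "cleanup" case.
chunk_rows(n, b, j) = (b * (j - 1) + 1):min(b * j, n)

m = chunk_matrix(10, 3)
```

Here `cld(10, 3) == 4` gives the four columns, and `chunk_rows(10, 3, 4)` is the single-row range `10:10`.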