Mask Sugar
In Finch, expressions like `i == j` are treated as sugar for mask tensors, which can be used to encode fancy iteration patterns. For example, the expression `i == j` is converted to a diagonal boolean mask tensor `DiagMask()[i, j]`, which allows an expression like
```julia
@finch begin
    for i in _, j in _
        if i == j
            s[] += A[i, j]
        end
    end
end
```
to compile to something like
```julia
for i in 1:n
    s[] += A[i, i]
end
```
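The rewrite is valid because the mask is false everywhere off the diagonal. The plain-Julia sketch below uses no Finch at all; the matrix `A` and the helper names `masked_sum` and `diagonal_sum` are hypothetical, and a local accumulator stands in for the scalar `s[]`. It checks that the masked double loop and the single diagonal loop compute the same sum:

```julia
# Plain Julia, no Finch: compare the masked double loop with the diagonal loop.
function masked_sum(A)
    s = zero(eltype(A))
    for j in axes(A, 2), i in axes(A, 1)
        if i == j                # the condition Finch turns into a diagonal mask
            s += A[i, j]
        end
    end
    return s
end

function diagonal_sum(A)
    s = zero(eltype(A))
    for i in 1:minimum(size(A))  # visit only the diagonal entries
        s += A[i, i]
    end
    return s
end

A = [1 2 3; 4 5 6; 7 8 9]
println(masked_sum(A), " ", diagonal_sum(A))   # prints "15 15"
```

The first version touches all `n^2` entries and discards most of them; the second touches only `n`, which is the saving the mask sugar is meant to expose.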
There are several mask tensors and syntaxes available, summarized in the following table, where `i` and `j` are indices:
| Expression | Transformed Expression |
|---|---|
| `i < j` | `UpTriMask()[i, j - 1]` |
| `i <= j` | `UpTriMask()[i, j]` |
| `i > j` | `LoTriMask()[i, j + 1]` |
| `i >= j` | `LoTriMask()[i, j]` |
| `i == j` | `DiagMask()[i, j]` |
| `i != j` | `!(DiagMask()[i, j])` |
Note that either `i` or `j` may be an expression, as long as that expression is constant with respect to the loop over the index.
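The index offsets in the table follow from integer arithmetic: for integers, `i < j` holds exactly when `i <= j - 1`, and `i > j` exactly when `i >= j + 1`. The sketch below models each mask as a plain Julia predicate (the names `uptri_pred`, `lotri_pred`, and `diag_pred` are hypothetical stand-ins, not Finch's mask types) and checks every row of the table over a small grid of indices:

```julia
# Hypothetical predicate stand-ins for the mask tensors (not Finch's API).
uptri_pred(i, j) = i <= j   # models UpTriMask()[i, j]
lotri_pred(i, j) = i >= j   # models LoTriMask()[i, j]
diag_pred(i, j)  = i == j   # models DiagMask()[i, j]

# Check each row of the table for every index pair in the range r.
function table_holds(r)
    for i in r, j in r
        uptri_pred(i, j - 1) == (i < j)  || return false
        uptri_pred(i, j)     == (i <= j) || return false
        lotri_pred(i, j + 1) == (i > j)  || return false
        lotri_pred(i, j)     == (i >= j) || return false
        diag_pred(i, j)      == (i == j) || return false
        !diag_pred(i, j)     == (i != j) || return false
    end
    return true
end

println(table_holds(-5:5))   # prints "true"
```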
The mask tensors are described below:
Finch.uptrimask (Constant)

A mask for an upper triangular tensor, `uptrimask[i, j] = i <= j`. Note that this specializes each column for the cases where `i <= j` and `i > j`.
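Specializing a column means splitting its row range at the point where the predicate changes value. As a plain-Julia sketch (the helpers `uptri_column_split` and `uptri_sum` are hypothetical, not Finch internals), column `j` of an `n`-row upper triangular mask is true on rows `1:j` and false on rows `j+1:n`, so a masked sum only needs to visit the first range of each column:

```julia
# Hypothetical helper: split rows 1:n of column j by the predicate i <= j.
function uptri_column_split(j, n)
    hi = min(j, n)
    return (1:hi, hi+1:n)   # (rows where i <= j, rows where i > j)
end

# Sum the upper triangle of A by visiting only the "true" range of each column.
function uptri_sum(A)
    n = size(A, 1)
    s = zero(eltype(A))
    for j in axes(A, 2)
        true_rows, _ = uptri_column_split(j, n)
        for i in true_rows
            s += A[i, j]
        end
    end
    return s
end

A = [1 2 3; 4 5 6; 7 8 9]
println(uptri_sum(A))   # 1 + 2 + 5 + 3 + 6 + 9 = 26
```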
Finch.lotrimask (Constant)

A mask for a lower triangular tensor, `lotrimask[i, j] = i >= j`. Note that this specializes each column for the cases where `i < j` and `i >= j`.
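The lower triangular mask is the transpose of the upper one, `lotrimask[i, j] == uptrimask[j, i]`, and its column split falls one row earlier: column `j` is false on rows `1:j-1` and true on rows `j:n`. A plain-Julia sketch, using dense `Bool` matrices built from the two predicates and a hypothetical `lotri_column_split` helper:

```julia
# Hypothetical helper: split rows 1:n of column j by the predicate i >= j.
function lotri_column_split(j, n)
    lo = min(j, n + 1)
    return (1:lo-1, lo:n)   # (rows where i < j, rows where i >= j)
end

# Dense Bool matrices built from the two predicates, for comparison only.
n = 4
lotri = [i >= j for i in 1:n, j in 1:n]
uptri = [i <= j for i in 1:n, j in 1:n]
println(lotri == permutedims(uptri))   # prints "true"
println(lotri_column_split(2, n))      # prints "(1:1, 2:4)"
```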
Finch.diagmask (Constant)

A mask for a diagonal tensor, `diagmask[i, j] = i == j`. Note that this specializes each column for the cases where `i < j`, `i == j`, and `i > j`.
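For the diagonal mask, the column split has three pieces, and only the middle one, the single row `i == j`, is true. This is why a loop guarded by `i == j` can collapse to one access per column. A plain-Julia sketch, with `diag_column_split` and `diag_sum` as hypothetical helpers:

```julia
# Hypothetical helper: split rows 1:n of column j (assumes 1 <= j <= n)
# into the cases i < j, i == j, and i > j.
diag_column_split(j, n) = (1:j-1, j:j, j+1:n)

function diag_sum(A)
    n = size(A, 1)
    s = zero(eltype(A))
    for j in 1:min(n, size(A, 2))
        _, on_diag, _ = diag_column_split(j, n)
        for i in on_diag          # exactly one row per column
            s += A[i, j]
        end
    end
    return s
end

A = [1 2 3; 4 5 6; 7 8 9]
println(diag_column_split(2, 3))   # prints "(1:1, 2:2, 3:3)"
println(diag_sum(A))               # prints "15"
```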
Finch.bandmask (Constant)

A mask for a banded tensor, `bandmask[i, j, k] = j <= i <= k`. Note that this specializes each column for the cases where `i < j`, `j <= i <= k`, and `k < i`.
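Here `j` and `k` act as the lower and upper bounds of a band of rows. The plain-Julia sketch below (the names `band_pred` and `band_split` are hypothetical) models the predicate and the three ranges it specializes into, then checks that summing a vector through the predicate agrees with summing only the middle range:

```julia
# Hypothetical stand-in for bandmask[i, j, k] = j <= i <= k.
band_pred(i, j, k) = j <= i <= k

# Split rows 1:n into i < j, j <= i <= k, and k < i (assumes 1 <= j <= k <= n).
band_split(j, k, n) = (1:j-1, j:k, k+1:n)

x = [10, 20, 30, 40, 50, 60]
j, k = 2, 4
masked = sum(x[i] for i in eachindex(x) if band_pred(i, j, k))
_, inside, _ = band_split(j, k, length(x))
println(masked, " ", sum(x[inside]))   # prints "90 90"
```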
Finch.splitmask (Function)

`splitmask(n, P)`

A mask to divide `n` indices into `P` contiguous, nearly equal regions. If `M = splitmask(n, P)`, then `M[i, j] = fld(n * (j - 1), P) < i <= fld(n * j, P)`.
```julia
julia> splitmask(10, 3)
10×3 Finch.SplitMask{Int64}:
 1  0  0
 1  0  0
 1  0  0
 0  1  0
 0  1  0
 0  1  0
 0  0  1
 0  0  1
 0  0  1
 0  0  1
```
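The pattern in the output above can be reproduced with a plain-Julia model of the predicate `fld(n * (j - 1), P) < i <= fld(n * j, P)`. The sketch below (the names `split_pred` and `split_range` are hypothetical, not Finch's implementation) rebuilds the `10×3` matrix and derives the row range of each region, `fld(n * (j - 1), P) + 1 : fld(n * j, P)`, which is the form a loop partitioned across `P` workers would iterate over:

```julia
# Hypothetical model of splitmask(n, P): row i belongs to region j
# when fld(n * (j - 1), P) < i <= fld(n * j, P).
split_pred(i, j, n, P) = fld(n * (j - 1), P) < i <= fld(n * j, P)

# Contiguous row range covered by region j.
split_range(j, n, P) = fld(n * (j - 1), P) + 1 : fld(n * j, P)

n, P = 10, 3
M = [split_pred(i, j, n, P) for i in 1:n, j in 1:P]
println(Int.(M[:, 3]))                        # column 3 is true on rows 7 through 10
println([split_range(j, n, P) for j in 1:P])  # ranges 1:3, 4:6, and 7:10
```

Region sizes differ by at most one, and every row lands in exactly one region.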
Finch.chunkmask (Function)

`chunkmask(n, b)`

A mask to divide `n` indices into consecutive regions of size `b`, with a shorter final region when `b` does not divide `n`. If `m = chunkmask(n, b)`, then `m[i, j] = b * (j - 1) < i <= b * j`. Note that this specializes for the cleanup case at the end of the range.
```julia
julia> chunkmask(10, 3)
10×4 Finch.ChunkMask{Int64}:
 1  0  0  0
 1  0  0  0
 1  0  0  0
 0  1  0  0
 0  1  0  0
 0  1  0  0
 0  0  1  0
 0  0  1  0
 0  0  1  0
 0  0  0  1
```
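A plain-Julia model of the same predicate reproduces the `10×4` output. In the sketch below (the names `chunk_pred`, `nchunks`, and `chunk_range` are hypothetical, not Finch's implementation), the number of columns is `cld(n, b)` and the last chunk is clipped at `n`, which is the cleanup case described above:

```julia
# Hypothetical model of chunkmask(n, b): row i belongs to chunk j
# when b * (j - 1) < i <= b * j.
chunk_pred(i, j, b) = b * (j - 1) < i <= b * j

# Number of chunks (columns) and the row range of chunk j, clipped at n.
nchunks(n, b) = cld(n, b)
chunk_range(j, n, b) = b * (j - 1) + 1 : min(b * j, n)

n, b = 10, 3
m = [chunk_pred(i, j, b) for i in 1:n, j in 1:nchunks(n, b)]
println(size(m))                             # prints "(10, 4)"
println(chunk_range(nchunks(n, b), n, b))    # prints "10:10", the cleanup chunk
```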