@dataclass(frozen=True)
class MemRefStreamGenericLegalize(RewritePattern):
    """
    Rewrite pattern that legalizes the body of a `memref_stream.GenericOp`.

    Block arguments whose type `_legalize_attr` maps to something other than
    a `StreamingAlreadyLegalType` are retyped to the legalized type. The
    innermost iteration bound is divided by the number of lanes of that type,
    the innermost dimension of every indexing map is scaled by the same
    factor, and the payload is legalized through `_legalize_block`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: memref_stream.GenericOp, rewriter: PatternRewriter
    ) -> None:
        """
        Replace `op` with a legalized copy, or leave it untouched when none of
        its block arguments needs legalization.

        Raises:
            DiagnosticException: the innermost iterator type is not
                `IteratorType.PARALLEL`.
            ValueError: the lane count of a legalized argument does not divide
                the innermost bound.
            NotImplementedError: the legalized arguments do not all share one
                lane count.
        """
        # Map block-argument index -> legalized type, keeping only the
        # arguments that are not already legal.
        legalizations: dict[int, StreamingVectorLegalizationType] = {
            idx: legal
            for idx, block_arg in enumerate(op.body.block.args)
            if not isinstance(
                legal := _legalize_attr(block_arg.type), StreamingAlreadyLegalType
            )
        }
        if not legalizations:
            # Nothing to legalize: the op is left as it is.
            return

        innermost_iterator = op.iterator_types.data[-1].data
        if innermost_iterator != IteratorType.PARALLEL:
            raise DiagnosticException(
                "iterators other than 'parallel' are not supported yet"
            )

        # Every lane count introduced by a legalization has to divide the
        # innermost bound; the distinct lane counts are gathered on the way.
        innermost_bound = op.bounds.data[-1].value.data
        vector_lengths: set[int] = set()
        for i, legal_type in legalizations.items():
            n_lanes: int = legal_type.get_shape()[0]
            if innermost_bound % n_lanes != 0:
                raise ValueError(
                    f"no. of vector lanes ({n_lanes}) introduced to legalize argument #{i} "
                    f"is not a divisor for the innermost dimension's bound ({innermost_bound})"
                )
            vector_lengths.add(n_lanes)

        if len(vector_lengths) != 1:
            # FIXME we should deal with heterogeneous generic ops
            raise NotImplementedError(
                "cannot legalize heterogeneous block arguments yet"
            )
        # Exactly one element is left at this point.
        (vlen,) = vector_lengths

        # The innermost bound shrinks by the vector length; the outer bounds
        # are carried over unchanged.
        new_bounds = list(op.bounds)
        new_bounds[-1] = IntegerAttr.from_index_int_value(innermost_bound // vlen)

        # In the access maps, the innermost dimension is scaled by the vector
        # length while the outer dimensions map to themselves.
        rank = len(op.bounds)
        dim_exprs = [AffineExpr.dimension(d) for d in range(rank - 1)]
        dim_exprs.append(AffineExpr.dimension(rank - 1) * vlen)
        new_maps = tuple(
            AffineMapAttr(
                indexing_map.data.replace_dims_and_symbols(dim_exprs, (), rank, 0)
            )
            for indexing_map in op.indexing_maps
        )

        # Work on a clone of the body so that `op` itself stays intact until
        # it is replaced.
        new_body = op.body.clone()

        # Retype the block arguments and remember the operations using them:
        # those are the starting point of the payload legalization.
        to_be_legalized: set[Operation] = set()
        cloned_args = new_body.block.args
        for idx, legal_type in legalizations.items():
            retyped_arg = rewriter.replace_value_with_new_type(
                cloned_args[idx], legal_type
            )
            for use in retyped_arg.uses:
                to_be_legalized.add(use.operation)

        # Legalize the payload
        _legalize_block(new_body.block, to_be_legalized, rewriter)

        legalized_op = memref_stream.GenericOp(
            op.inputs,
            op.outputs,
            op.inits,
            new_body,
            ArrayAttr(new_maps),
            op.iterator_types,
            ArrayAttr(new_bounds),
            op.init_indices,
        )
        rewriter.replace_op(op, legalized_op)