Skip to content

Csl stencil

csl_stencil

RedundantAccumulatorInitialisation

Bases: RewritePattern

Removes redundant allocations of empty tensors that have no uses other than being passed as iter_arg to csl_stencil.apply. Where possible, the accumulator of an earlier csl_stencil.apply is re-used instead.

Source code in xdsl/transforms/canonicalization_patterns/csl_stencil.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class RedundantAccumulatorInitialisation(RewritePattern):
    """
    Eliminates `tensor.empty` allocations whose only use is as the accumulator
    of a `csl_stencil.apply`, by re-using the accumulator of an earlier
    `csl_stencil.apply` that has the same type.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter
    ) -> None:
        """
        Starting from `op`, walk the operations that follow it (via `next_op`)
        and replace the freshly allocated accumulator of every later
        `csl_stencil.apply` with the accumulator of `op`.

        Nothing is rewritten when the accumulator of `op` already has more than
        one use at the time the pattern is invoked.
        """
        reusable = op.accumulator
        if reusable.has_more_than_one_use():
            return

        cursor = op
        while (cursor := cursor.next_op) is not None:
            # Only later apply ops are candidates for accumulator re-use.
            if not isinstance(cursor, csl_stencil.ApplyOp):
                continue

            candidate = cursor.accumulator
            # The candidate must be used by this apply only ...
            if not candidate.has_one_use():
                continue
            # ... and must be the result of a `tensor.empty` allocation.
            if not (
                isinstance(candidate, OpResult)
                and isinstance(candidate.op, tensor.EmptyOp)
            ):
                continue

            if reusable.type == candidate.type:
                # Erase the allocation; its single use now reads `reusable`.
                rewriter.replace_op(candidate.op, [], [reusable])

match_and_rewrite(op: csl_stencil.ApplyOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/canonicalization_patterns/csl_stencil.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter
) -> None:
    """
    Re-use the accumulator of `op` for later `csl_stencil.apply` operations.

    Every operation reachable from `op` through `next_op` is inspected. When it
    is a `csl_stencil.apply` whose accumulator is the single-use result of a
    `tensor.empty` with the same type as the accumulator of `op`, that
    `tensor.empty` is replaced by the accumulator of `op`.

    The pattern does nothing if the accumulator of `op` has more than one use
    when it is invoked.
    """
    if op.accumulator.has_more_than_one_use():
        return

    following = op.next_op
    while following is not None:
        if isinstance(following, csl_stencil.ApplyOp):
            acc = following.accumulator
            allocated_only_for_apply = (
                acc.has_one_use()
                and isinstance(acc, OpResult)
                and isinstance(acc.op, tensor.EmptyOp)
            )
            if allocated_only_for_apply and op.accumulator.type == acc.type:
                # Drop the allocation and feed the earlier accumulator instead.
                rewriter.replace_op(acc.op, [], [op.accumulator])
        following = following.next_op