Skip to content

Scf

scf

RehoistConstInLoops

Bases: RewritePattern

Hoist constant definitions out of loops. In the future this will probably be done by the pattern rewriter itself, as MLIR's applyPatternsAndFoldGreedily does.

Source code in xdsl/transforms/canonicalization_patterns/scf.py (lines 17–30)
class RehoistConstInLoops(RewritePattern):
    """
    Hoist constant-like operations out of the body of an `scf.for` loop.

    Only operations sitting directly in the loop body are visited; constants that
    are nested inside further regions of the body are left where they are. Each
    hoisted constant is cloned, the clone is inserted through the rewriter, and
    all uses of the original are redirected to the clone.

    In the future this will probably be done by the pattern rewriter itself, as
    MLIR's applyPatternsAndFoldGreedily does.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: scf.ForOp, rewriter: PatternRewriter) -> None:
        # Lazily filter the top-level body ops, so the trait check and the rewrite
        # stay interleaved one op at a time.
        constants = (
            body_op for body_op in op.body.ops if body_op.has_trait(ConstantLike)
        )
        for const_op in constants:
            hoisted = const_op.clone()
            # NOTE(review): relies on the rewriter's default insertion point, which
            # presumably is just before the matched loop — confirm against
            # PatternRewriter.insert_op.
            rewriter.insert_op(hoisted)
            rewriter.replace_op(const_op, (), hoisted.results)

match_and_rewrite(op: scf.ForOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/canonicalization_patterns/scf.py (lines 24–30)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: scf.ForOp, rewriter: PatternRewriter) -> None:
    """
    Clone every constant-like op found directly in the loop body, insert the
    clone through the rewriter, and redirect uses of the original to the clone.

    Constants nested inside inner regions of the body are not visited.
    """
    # Lazily filter the top-level body ops, so the trait check and the rewrite
    # stay interleaved one op at a time.
    constants = (
        body_op for body_op in op.body.ops if body_op.has_trait(ConstantLike)
    )
    for const_op in constants:
        hoisted = const_op.clone()
        # NOTE(review): relies on the rewriter's default insertion point, which
        # presumably is just before the matched loop — confirm against
        # PatternRewriter.insert_op.
        rewriter.insert_op(hoisted)
        rewriter.replace_op(const_op, (), hoisted.results)

SimplifyTrivialLoops

Bases: RewritePattern

Rewriting pattern that erases loops that are known not to iterate and replaces single-iteration loops with their bodies. (The docstring also mentions removing empty loops that iterate at least once and only return values defined outside of the loop, but the code below does not implement that case.)

Source code in xdsl/transforms/canonicalization_patterns/scf.py (lines 33–70)
class SimplifyTrivialLoops(RewritePattern):
    """
    Simplify `scf.for` loops whose bounds are compile-time constants.

    - If the upper bound does not exceed the lower bound, the loop never runs and
      is replaced by its initial iteration arguments.
    - If the constant step is at least the span between the bounds, the loop runs
      exactly once and is replaced by its body, with the induction variable bound
      to the lower bound.

    NOTE(review): removing empty loops that iterate at least once and only yield
    values defined outside the loop is not implemented by this pattern.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: scf.ForOp, rewriter: PatternRewriter) -> None:
        lower = const_evaluate_operand(op.lb)
        if lower is None:
            return
        upper = const_evaluate_operand(op.ub)
        if upper is None:
            return

        span = upper - lower
        if span <= 0:
            # Zero iterations (this includes lower == upper): the results are just
            # the initial values of the loop-carried arguments.
            rewriter.replace_op(op, (), op.iter_args)
            return

        step = const_evaluate_operand(op.step)
        if step is None:
            return

        # TODO: handle signless values
        if step < span:
            # More than one iteration; nothing to simplify here.
            return

        # Exactly one iteration: inline the body, feeding the lower bound as the
        # induction variable and the initial iteration arguments as the rest.
        replace_op_with_region(rewriter, op, op.body, (op.lb, *op.iter_args))

match_and_rewrite(op: scf.ForOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/canonicalization_patterns/scf.py (lines 40–70)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: scf.ForOp, rewriter: PatternRewriter) -> None:
    """
    Remove a constant-bound `scf.for` that never runs, or inline one that runs
    exactly once.
    """
    lower = const_evaluate_operand(op.lb)
    if lower is None:
        return
    upper = const_evaluate_operand(op.ub)
    if upper is None:
        return

    span = upper - lower
    if span <= 0:
        # Zero iterations (this includes lower == upper): the results are just the
        # initial values of the loop-carried arguments.
        rewriter.replace_op(op, (), op.iter_args)
        return

    step = const_evaluate_operand(op.step)
    if step is None:
        return

    # TODO: handle signless values
    if step < span:
        # More than one iteration; nothing to simplify here.
        return

    # Exactly one iteration: inline the body, feeding the lower bound as the
    # induction variable and the initial iteration arguments as the rest.
    replace_op_with_region(rewriter, op, op.body, (op.lb, *op.iter_args))

IfPropagateConstantCondition

Bases: RewritePattern

Source code in xdsl/transforms/canonicalization_patterns/scf.py (lines 82–92)
class IfPropagateConstantCondition(RewritePattern):
    """
    Fold an `scf.if` whose condition is a known constant by keeping only the branch
    that would be taken.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: scf.IfOp, rewriter: PatternRewriter) -> None:
        cond = const_evaluate_operand(op.cond)
        if cond is None:
            return

        if cond:
            taken = op.true_region
        elif op.false_region.blocks:
            taken = op.false_region
        else:
            # False condition and no else-region at all: nothing executes, so the
            # op is simply dropped. The helper cannot be used because the region
            # has no block to inline.
            rewriter.erase_op(op)
            return

        replace_op_with_region(rewriter, op, taken)

match_and_rewrite(op: scf.IfOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/canonicalization_patterns/scf.py (lines 83–92)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: scf.IfOp, rewriter: PatternRewriter) -> None:
    """Replace an `scf.if` with a constant condition by its taken branch."""
    cond = const_evaluate_operand(op.cond)
    if cond is None:
        return

    if cond:
        taken = op.true_region
    elif op.false_region.blocks:
        taken = op.false_region
    else:
        # False condition and no else-region at all: nothing executes, so the op
        # is simply dropped. The helper cannot be used because the region has no
        # block to inline.
        rewriter.erase_op(op)
        return

    replace_op_with_region(rewriter, op, taken)

SingleBlockExecuteInliner

Bases: RewritePattern

Source code in xdsl/transforms/canonicalization_patterns/scf.py (lines 95–103)
class SingleBlockExecuteInliner(RewritePattern):
    """
    Inline the region of an `scf.execute_region` op when that region consists of
    exactly one block.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: scf.ExecuteRegionOp, rewriter: PatternRewriter
    ) -> None:
        body = op.region
        head = body.first_block
        assert head is not None
        # replace_op_with_region only handles single-block regions.
        if head is body.last_block:
            replace_op_with_region(rewriter, op, body)

match_and_rewrite(op: scf.ExecuteRegionOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/canonicalization_patterns/scf.py (lines 96–103)
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: scf.ExecuteRegionOp, rewriter: PatternRewriter
) -> None:
    """Inline a single-block `scf.execute_region` into its parent block."""
    body = op.region
    head = body.first_block
    assert head is not None
    # replace_op_with_region only handles single-block regions.
    if head is body.last_block:
        replace_op_with_region(rewriter, op, body)

replace_op_with_region(rewriter: PatternRewriter, op: Operation, region: Region, args: Sequence[SSAValue] = ())

Replaces the given op with the contents of the given single-block region, using the operands of the block terminator to replace operation results.

:raises ValueError: if the region does not have a single block.

Source code in xdsl/transforms/canonicalization_patterns/scf.py (lines 106–124)
def replace_op_with_region(
    rewriter: PatternRewriter,
    op: Operation,
    region: Region,
    args: Sequence[SSAValue] = (),
):
    """
    Splice the single block of `region` in front of `op`, then substitute `op`'s
    results with the operands of that block's terminator.

    :param rewriter: rewriter through which every IR mutation is performed.
    :param op: operation being replaced.
    :param region: single-block region whose operations take `op`'s place.
    :param args: values substituted for the block's arguments, in order.
    :raises ValueError: if the region does not have a single block.
    """
    only_block = region.block
    block_terminator = only_block.last_op
    assert block_terminator is not None
    rewriter.inline_block(only_block, InsertPoint.before(op), args)
    # The terminator's operands stand in for op's results; the terminator itself
    # is dropped once the block contents have been inlined.
    rewriter.replace_op(op, (), block_terminator.operands)
    rewriter.erase_op(block_terminator)