Skip to content

Control flow hoist

control_flow_hoist

AffineIfHoistPattern

Bases: RewritePattern

Hoist everything out of a pure affine.if.

Source code in xdsl/transforms/control_flow_hoist.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class AffineIfHoistPattern(RewritePattern):
    """
    Rewrite pattern that hoists the contents of a pure `affine.if` in front of it.

    An `affine.if` is only handled when both `is_speculatable` and
    `is_side_effect_free` accept it; any other op is left untouched.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: affine.IfOp, rewriter: PatternRewriter):
        """
        Hoist the ops of both regions of `op` to just before `op`.

        If the rewriter reports that something was done, `cse` is then run on
        the block containing `op`.
        """
        # Conservative guard: give up unless both purity checks pass.
        if not is_speculatable(op) or not is_side_effect_free(op):
            return

        # Ops of the "then" region are visited first, followed by the "else" region.
        branch_ops = chain(op.then_region.ops, op.else_region.ops)
        hoist_all(rewriter, branch_ops, InsertPoint.before(op))

        # Cleanup only makes sense if the rewriter recorded an action and `op`
        # still has a parent block to run `cse` on.
        if rewriter.has_done_action and (parent_block := op.parent):
            cse(parent_block, rewriter)

match_and_rewrite(op: affine.IfOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/control_flow_hoist.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
@op_type_rewrite_pattern
def match_and_rewrite(self, op: affine.IfOp, rewriter: PatternRewriter):
    """
    Hoist the ops of both regions of `op` to just before `op`.

    If the rewriter reports that something was done, `cse` is then run on
    the block containing `op`.
    """
    # Conservative guard: give up unless both purity checks pass.
    if not is_speculatable(op) or not is_side_effect_free(op):
        return

    # Ops of the "then" region are visited first, followed by the "else" region.
    branch_ops = chain(op.then_region.ops, op.else_region.ops)
    hoist_all(rewriter, branch_ops, InsertPoint.before(op))

    # Cleanup only makes sense if the rewriter recorded an action and `op`
    # still has a parent block to run `cse` on.
    if rewriter.has_done_action and (parent_block := op.parent):
        cse(parent_block, rewriter)

SCFIfHoistPattern

Bases: RewritePattern

Hoist everything out of a pure scf.if.

Source code in xdsl/transforms/control_flow_hoist.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class SCFIfHoistPattern(RewritePattern):
    """
    Rewrite pattern that hoists the contents of a pure `scf.if` in front of it.

    An `scf.if` is only handled when both `is_speculatable` and
    `is_side_effect_free` accept it; any other op is left untouched.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: scf.IfOp, rewriter: PatternRewriter):
        """
        Hoist the ops of both regions of `op` to just before `op`.

        If the rewriter reports that something was done, `cse` is then run on
        the block containing `op`.
        """
        # Conservative guard: give up unless both purity checks pass.
        if not is_speculatable(op) or not is_side_effect_free(op):
            return

        # Ops of the true region are visited first, followed by the false region.
        branch_ops = chain(op.true_region.ops, op.false_region.ops)
        hoist_all(rewriter, branch_ops, InsertPoint.before(op))

        # Performance-minded cleanup, skipped when nothing was hoisted:
        # running CSE on the enclosing block avoids pushing the same
        # duplicated ops further upward on later rewrites.
        if rewriter.has_done_action and (parent_block := op.parent):
            cse(parent_block, rewriter)

match_and_rewrite(op: scf.IfOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/control_flow_hoist.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@op_type_rewrite_pattern
def match_and_rewrite(self, op: scf.IfOp, rewriter: PatternRewriter):
    """
    Hoist the ops of both regions of `op` to just before `op`.

    If the rewriter reports that something was done, `cse` is then run on
    the block containing `op`.
    """
    # Conservative guard: give up unless both purity checks pass.
    if not is_speculatable(op) or not is_side_effect_free(op):
        return

    # Ops of the true region are visited first, followed by the false region.
    branch_ops = chain(op.true_region.ops, op.false_region.ops)
    hoist_all(rewriter, branch_ops, InsertPoint.before(op))

    # Performance-minded cleanup, skipped when nothing was hoisted:
    # running CSE on the enclosing block avoids pushing the same
    # duplicated ops further upward on later rewrites.
    if rewriter.has_done_action and (parent_block := op.parent):
        cse(parent_block, rewriter)

ControlFlowHoistPass dataclass

Bases: ModulePass

Hoist all hoistable ops from control flow ops.

Source code in xdsl/transforms/control_flow_hoist.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
class ControlFlowHoistPass(ModulePass):
    """
    Module pass that applies the `affine.if` and `scf.if` hoisting patterns.
    """

    name = "control-flow-hoist"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """
        Rewrite the module `op` in place with both hoisting patterns.

        `ctx` is not used by this pass.
        """
        hoisting_patterns = [AffineIfHoistPattern(), SCFIfHoistPattern()]
        applier = GreedyRewritePatternApplier(hoisting_patterns)
        # NOTE(review): walk_regions_first=True presumably makes nested ops get
        # rewritten before their parents - confirm against PatternRewriteWalker.
        walker = PatternRewriteWalker(applier, walk_regions_first=True)
        walker.rewrite_module(op)

name = 'control-flow-hoist' class-attribute instance-attribute

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/control_flow_hoist.py
 97
 98
 99
100
101
102
103
104
105
106
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """
    Rewrite the module `op` in place with both hoisting patterns.

    `ctx` is not used by this pass.
    """
    hoisting_patterns = [AffineIfHoistPattern(), SCFIfHoistPattern()]
    applier = GreedyRewritePatternApplier(hoisting_patterns)
    # NOTE(review): walk_regions_first=True presumably makes nested ops get
    # rewritten before their parents - confirm against PatternRewriteWalker.
    walker = PatternRewriteWalker(applier, walk_regions_first=True)
    walker.rewrite_module(op)

hoist_all(rewriter: PatternRewriter, ops: Iterable[Operation], at: InsertPoint, value_mapper: dict[SSAValue, SSAValue] | None = None)

Source code in xdsl/transforms/control_flow_hoist.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def hoist_all(
    rewriter: PatternRewriter,
    ops: Iterable[Operation],
    at: InsertPoint,
    value_mapper: dict[SSAValue, SSAValue] | None = None,
):
    """
    Clone each non-terminator op of `ops` at `at` and replace the original by its clone.

    Args:
        rewriter: Rewriter used to insert the clones and replace the originals.
        ops: Operations to hoist; ops carrying the `IsTerminator` trait are skipped.
        at: Insertion point that receives every clone.
        value_mapper: Optional mapping from original to new SSA values, passed
            to `clone`. It is updated in place with the results of every
            hoisted op; a fresh dict is used when it is omitted.
    """
    mapping = {} if value_mapper is None else value_mapper
    for original in ops:
        # Terminators stay where they are.
        if original.has_trait(IsTerminator, value_if_unregistered=False):
            continue
        hoisted = original.clone(value_mapper=mapping)
        # Record old -> new results so that later clones use the hoisted values.
        mapping.update(dict(zip(original.results, hoisted.results, strict=True)))
        rewriter.insert_op(hoisted, at)
        # No replacement ops are supplied: the original's results are simply
        # redirected to the results of the already-inserted clone.
        rewriter.replace_op(original, [], hoisted.results)