Skip to content

SCF for-loop range folding

scf_for_loop_range_folding

ScfForLoopRangeFolding

Bases: RewritePattern

Source code in xdsl/transforms/scf_for_loop_range_folding.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class ScfForLoopRangeFolding(RewritePattern):
    """
    Fold `arith.addi` / `arith.muli` ops applied to an `scf.for` induction
    variable into the loop's range.

    When the induction variable's only user adds (or multiplies) it by a value
    defined outside the loop, the same operation is applied to the lower and
    upper bound (and, for multiplication, also to the step). The user is then
    replaced by the induction variable itself. This repeats until no further
    fold applies.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: scf.ForOp, rewriter: PatternRewriter) -> None:
        iv = op.body.block.args[0]

        # Each pass of this loop absorbs one addi/muli into the range; it stops
        # as soon as no further fold is possible (fixed point).
        while iv.has_one_use():
            # With a second use, the unfolded induction variable would still be
            # needed, so only a single user can be absorbed into the range.
            user = next(iter(iv.uses)).operation
            if not isinstance(user, arith.AddiOp | arith.MuliOp):
                return

            # Pick the operand that is not the induction variable; it must be
            # defined outside the loop so it can be used before the loop.
            lhs, rhs = user.operands[0], user.operands[1]
            operand = rhs if lhs is iv else lhs
            if not is_foldable(operand, op):
                return

            # The new range computations are inserted by the rewriter
            # (ahead of the loop, at its default insertion point).
            if isinstance(user, arith.AddiOp):
                # (iv + c) over [lb, ub) becomes iv' over [lb + c, ub + c).
                lb_op = arith.AddiOp(op.lb, operand)
                ub_op = arith.AddiOp(op.ub, operand)
                rewriter.insert_op([lb_op, ub_op])
            else:
                # (iv * c) over [lb, ub) step s becomes iv' over
                # [lb * c, ub * c) step s * c.
                # NOTE(review): nothing here checks that c is positive; a zero
                # or negative c yields a non-positive step -- confirm callers
                # only rely on positive multipliers.
                lb_op = arith.MuliOp(op.lb, operand)
                ub_op = arith.MuliOp(op.ub, operand)
                step_op = arith.MuliOp(op.step, operand)
                rewriter.insert_op([lb_op, ub_op, step_op])
                op.operands[2] = step_op.result

            op.operands[0] = lb_op.result
            op.operands[1] = ub_op.result

            # Former uses of the user's result now read the induction variable
            # directly, which lets the next iteration try to fold that use.
            rewriter.replace_op(user, [], [iv])

match_and_rewrite(op: scf.ForOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/scf_for_loop_range_folding.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
@op_type_rewrite_pattern
def match_and_rewrite(self, op: scf.ForOp, rewriter: PatternRewriter) -> None:
    """
    Repeatedly fold the induction variable's single addi/muli user into the
    bounds (and, for muli, the step) of `op`, until no fold applies.
    """
    iv = op.body.block.args[0]

    # Each pass of this loop absorbs one addi/muli into the range; it stops as
    # soon as no further fold is possible (fixed point).
    while iv.has_one_use():
        # With a second use, the unfolded induction variable would still be
        # needed, so only a single user can be absorbed into the range.
        user = next(iter(iv.uses)).operation
        if not isinstance(user, arith.AddiOp | arith.MuliOp):
            return

        # Pick the operand that is not the induction variable; it must be
        # defined outside the loop so it can be used before the loop.
        lhs, rhs = user.operands[0], user.operands[1]
        operand = rhs if lhs is iv else lhs
        if not is_foldable(operand, op):
            return

        if isinstance(user, arith.AddiOp):
            # (iv + c) over [lb, ub) becomes iv' over [lb + c, ub + c).
            lb_op = arith.AddiOp(op.lb, operand)
            ub_op = arith.AddiOp(op.ub, operand)
            rewriter.insert_op([lb_op, ub_op])
        else:
            # (iv * c) over [lb, ub) step s becomes iv' over
            # [lb * c, ub * c) step s * c.
            # NOTE(review): nothing here checks that c is positive; a zero or
            # negative c yields a non-positive step -- confirm this is intended.
            lb_op = arith.MuliOp(op.lb, operand)
            ub_op = arith.MuliOp(op.ub, operand)
            step_op = arith.MuliOp(op.step, operand)
            rewriter.insert_op([lb_op, ub_op, step_op])
            op.operands[2] = step_op.result

        op.operands[0] = lb_op.result
        op.operands[1] = ub_op.result

        # Former uses of the user's result now read the induction variable
        # directly, which lets the next iteration try to fold that use.
        rewriter.replace_op(user, [], [iv])

ScfForLoopRangeFoldingPass dataclass

Bases: ModulePass

xDSL implementation of the pass of the same name.

Source code in xdsl/transforms/scf_for_loop_range_folding.py
67
68
69
70
71
72
73
74
75
76
77
class ScfForLoopRangeFoldingPass(ModulePass):
    """
    Module pass that applies the `ScfForLoopRangeFolding` pattern to every
    `scf.for` in the module (inner regions are walked first, each op is
    visited once).
    """

    name = "scf-for-loop-range-folding"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # Regions first, so nested loops are folded before their parents;
        # no re-application after the single walk.
        walker = PatternRewriteWalker(
            ScfForLoopRangeFolding(),
            walk_regions_first=True,
            apply_recursively=False,
        )
        walker.rewrite_module(op)

name = 'scf-for-loop-range-folding' class-attribute instance-attribute

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/scf_for_loop_range_folding.py
74
75
76
77
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Run the range-folding pattern once over `op`, inner regions first."""
    walker = PatternRewriteWalker(
        ScfForLoopRangeFolding(),
        walk_regions_first=True,
        apply_recursively=False,
    )
    walker.rewrite_module(op)

is_foldable(val: SSAValue, for_op: scf.ForOp)

Source code in xdsl/transforms/scf_for_loop_range_folding.py
13
14
def is_foldable(val: SSAValue, for_op: scf.ForOp) -> bool:
    """
    Return True when `val` is not owned by anything nested inside `for_op`,
    i.e. it can be referenced when rebuilding the loop's range.
    """
    owned_by_loop = for_op.is_ancestor(val.owner)
    return not owned_by_loop