Skip to content

Riscv scf loop range folding

riscv_scf_loop_range_folding

HoistIndexTimesConstantOp

Bases: RewritePattern

Source code in xdsl/transforms/riscv_scf_loop_range_folding.py
(lines 13–64)
class HoistIndexTimesConstantOp(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter) -> None:
        index = op.body.block.args[0]

        # Fold until a fixed point is reached
        while True:
            if not index.has_one_use():
                # If the induction variable is used more than once, we can't fold its
                # arith ops into the loop range
                return

            user = next(iter(index.uses)).operation

            if not isinstance(user, riscv.AddOp | riscv.MulOp):
                return

            if user.rs1 is index:
                if (imm := get_constant_value(user.rs2)) is None:
                    return
            else:
                if (imm := get_constant_value(user.rs1)) is None:
                    return

            constant = imm.value.data

            match user:
                case riscv.AddOp():
                    # All the uses are multiplications by a constant, we can fold
                    rewriter.insert_op(
                        [
                            shift := riscv.LiOp(constant),
                            new_lb := riscv.AddOp(op.lb, shift),
                            new_ub := riscv.AddOp(op.ub, shift),
                        ]
                    )
                case riscv.MulOp():
                    # All the uses are multiplications by a constant, we can fold
                    rewriter.insert_op(
                        [
                            factor := riscv.LiOp(constant),
                            new_lb := riscv.MulOp(op.lb, factor),
                            new_ub := riscv.MulOp(op.ub, factor),
                            new_step := riscv.MulOp(op.step, factor),
                        ]
                    )

                    op.operands[2] = new_step.rd

            op.operands[0] = new_lb.rd
            op.operands[1] = new_ub.rd
            rewriter.replace_op(user, [], [index])

match_and_rewrite(op: riscv_scf.ForOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/riscv_scf_loop_range_folding.py
(lines 14–64)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter) -> None:
    index = op.body.block.args[0]

    # Fold until a fixed point is reached
    while True:
        if not index.has_one_use():
            # If the induction variable is used more than once, we can't fold its
            # arith ops into the loop range
            return

        user = next(iter(index.uses)).operation

        if not isinstance(user, riscv.AddOp | riscv.MulOp):
            return

        if user.rs1 is index:
            if (imm := get_constant_value(user.rs2)) is None:
                return
        else:
            if (imm := get_constant_value(user.rs1)) is None:
                return

        constant = imm.value.data

        match user:
            case riscv.AddOp():
                # All the uses are multiplications by a constant, we can fold
                rewriter.insert_op(
                    [
                        shift := riscv.LiOp(constant),
                        new_lb := riscv.AddOp(op.lb, shift),
                        new_ub := riscv.AddOp(op.ub, shift),
                    ]
                )
            case riscv.MulOp():
                # All the uses are multiplications by a constant, we can fold
                rewriter.insert_op(
                    [
                        factor := riscv.LiOp(constant),
                        new_lb := riscv.MulOp(op.lb, factor),
                        new_ub := riscv.MulOp(op.ub, factor),
                        new_step := riscv.MulOp(op.step, factor),
                    ]
                )

                op.operands[2] = new_step.rd

        op.operands[0] = new_lb.rd
        op.operands[1] = new_ub.rd
        rewriter.replace_op(user, [], [index])

RiscvScfLoopRangeFoldingPass dataclass

Bases: ModulePass

Similar to scf-loop-range-folding in MLIR, this pass folds additions and multiplications of a loop induction variable by a constant into the loop range computation when possible.

Source code in xdsl/transforms/riscv_scf_loop_range_folding.py
(lines 67–79)
class RiscvScfLoopRangeFoldingPass(ModulePass):
    """
    Similar to scf-loop-range-folding in MLIR, folds additions and multiplications of
    a loop induction variable by a constant into the loop range computation when
    possible.
    """

    name = "riscv-scf-loop-range-folding"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # HoistIndexTimesConstantOp already loops to a fixed point for each loop it
        # matches, so the walker is configured with `apply_recursively=False`.
        PatternRewriteWalker(
            HoistIndexTimesConstantOp(),
            apply_recursively=False,
        ).rewrite_module(op)

name = 'riscv-scf-loop-range-folding' (class attribute, instance attribute)

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/riscv_scf_loop_range_folding.py
(lines 75–79)
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Rewrite `op` with `HoistIndexTimesConstantOp`, using `apply_recursively=False`."""
    pattern = HoistIndexTimesConstantOp()
    walker = PatternRewriteWalker(pattern, apply_recursively=False)
    walker.rewrite_module(op)