Skip to content

Riscv cf

riscv_cf

ElideConstantBranches

Bases: RewritePattern

Source code in `xdsl/transforms/canonicalization_patterns/riscv_cf.py` (lines 7–39):
class ElideConstantBranches(RewritePattern):
    """
    Fold a RISC-V conditional branch whose operands `rs1` and `rs2` are both
    constants into an unconditional control-flow operation.

    If the branch condition evaluates to true, the op is replaced by a `JOp`
    to the "then" block. Otherwise it is replaced by a `BranchOp` to the "else"
    block. Both replacements carry a "Constant folded <op name>" comment.
    """

    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
        # Only conditional branches are candidates for folding.
        if not isinstance(op, ConditionalBranchOperation):
            return

        lhs = get_constant_value(op.rs1)
        rhs = get_constant_value(op.rs2)
        # Give up unless both register operands resolved to constants.
        if lhs is None or rhs is None:
            return

        # TODO: take bitwidth into account (a 32-bit width is assumed for now)
        always_taken = op.const_evaluate(lhs.value.data, rhs.value.data, 32)
        note = f"Constant folded {op.name}"

        if always_taken:
            # Condition is statically true: jump straight to the "then" target.
            replacement = JOp(op.then_arguments, op.then_block, comment=note)
        else:
            # Condition is statically false: "fall through" to the "else" target.
            replacement = BranchOp(op.else_arguments, op.else_block, comment=note)

        rewriter.replace_op(op, replacement)

match_and_rewrite(op: Operation, rewriter: PatternRewriter)

Source code in `xdsl/transforms/canonicalization_patterns/riscv_cf.py` (lines 8–39):
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
    """
    Replace a conditional branch with constant operands by an unconditional one.

    Does nothing unless `op` is a `ConditionalBranchOperation` and both
    `op.rs1` and `op.rs2` resolve to constants. Otherwise the branch is
    evaluated at rewrite time. If it is taken, `op` becomes a `JOp` to
    `op.then_block`; if not, it becomes a `BranchOp` to `op.else_block`.
    """
    # Only conditional branches are candidates for folding.
    if not isinstance(op, ConditionalBranchOperation):
        return

    lhs = get_constant_value(op.rs1)
    rhs = get_constant_value(op.rs2)
    # Give up unless both register operands resolved to constants.
    if lhs is None or rhs is None:
        return

    # TODO: take bitwidth into account (a 32-bit width is assumed for now)
    always_taken = op.const_evaluate(lhs.value.data, rhs.value.data, 32)
    note = f"Constant folded {op.name}"

    if always_taken:
        # Condition is statically true: jump straight to the "then" target.
        replacement = JOp(op.then_arguments, op.then_block, comment=note)
    else:
        # Condition is statically false: "fall through" to the "else" target.
        replacement = BranchOp(op.else_arguments, op.else_block, comment=note)

    rewriter.replace_op(op, replacement)