Skip to content

Convert riscv scf for to frep

convert_riscv_scf_for_to_frep

ALLOWED_FREP_OP_LOWERING_TYPES = (*(riscv_snitch.ALLOWED_FREP_OP_TYPES), riscv_scf.YieldOp) module-attribute

ScfForLowering

Bases: RewritePattern

Source code in xdsl/transforms/convert_riscv_scf_for_to_frep.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class ScfForLowering(RewritePattern):
    """
    Rewrite a `riscv_scf.for` loop into a `riscv_snitch.frep_outer` loop when the
    loop satisfies the frep constraints; any other loop is left untouched.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter) -> None:
        """
        Replace `op` with `riscv.sub(ub, lb)`, `riscv.addi(.., -1)` and a
        `riscv_snitch.frep_outer` that takes over the loop body and iter_args.

        The rewrite only fires when:
          1. the induction variable (first body block argument) has no uses,
          2. the step is produced by a `riscv.li` holding the integer 1,
          3. every operand and result of every body op is a float register or
             a snitch readable/writable stream value,
          4. every body op is pure or an instance of
             ALLOWED_FREP_OP_LOWERING_TYPES.
        """
        loop_body = op.body.block
        induction_var = loop_body.args[0]

        # 1. frep_outer has no induction variable, so it must be dead.
        if induction_var.uses:
            return

        # 2. Only a constant unit step (riscv.li 1) is accepted.
        step_producer = op.step.owner
        if not isinstance(step_producer, riscv.LiOp):
            return
        step_imm = step_producer.immediate
        if not isinstance(step_imm, builtin.IntegerAttr):
            return
        if step_imm.value.data != 1:
            return

        # 3. Every value touched in the body must be a float register or a stream.
        accepted_types = (
            riscv.FloatRegisterType,
            snitch.ReadableStreamType,
            snitch.WritableStreamType,
        )
        for body_op in loop_body.ops:
            for value in chain(body_op.operands, body_op.results):
                if not isinstance(value.type, accepted_types):
                    return

        # 4. Every body op is either explicitly allowed or side-effect free.
        for body_op in loop_body.ops:
            if isinstance(body_op, ALLOWED_FREP_OP_LOWERING_TYPES):
                continue
            if not body_op.has_trait(Pure):
                return

        rewriter.erase_block_argument(induction_var)
        # frep_outer receives the iteration count minus one (ub - lb - 1).
        # NOTE(review): if ub <= lb at runtime the original loop runs zero
        # times; presumably frep still executes its body, so confirm this
        # pattern is only applied to loops known to iterate at least once.
        iter_count = riscv.SubOp(op.ub, op.lb)
        max_rep = riscv.AddiOp(iter_count, -1)
        frep = riscv_snitch.FrepOuterOp(
            max_rep,
            rewriter.move_region_contents_to_new_regions(op.body),
            op.iter_args,
        )
        rewriter.replace_op(op, (iter_count, max_rep, frep))

match_and_rewrite(op: riscv_scf.ForOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_riscv_scf_for_to_frep.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter) -> None:
    """
    Replace `op` with `riscv.sub(ub, lb)`, `riscv.addi(.., -1)` and a
    `riscv_snitch.frep_outer` that takes over the loop body and iter_args.

    The rewrite only fires when:
      1. the induction variable (first body block argument) has no uses,
      2. the step is produced by a `riscv.li` holding the integer 1,
      3. every operand and result of every body op is a float register or
         a snitch readable/writable stream value,
      4. every body op is pure or an instance of
         ALLOWED_FREP_OP_LOWERING_TYPES.
    """
    loop_body = op.body.block
    induction_var = loop_body.args[0]

    # 1. frep_outer has no induction variable, so it must be dead.
    if induction_var.uses:
        return

    # 2. Only a constant unit step (riscv.li 1) is accepted.
    step_producer = op.step.owner
    if not isinstance(step_producer, riscv.LiOp):
        return
    step_imm = step_producer.immediate
    if not isinstance(step_imm, builtin.IntegerAttr):
        return
    if step_imm.value.data != 1:
        return

    # 3. Every value touched in the body must be a float register or a stream.
    accepted_types = (
        riscv.FloatRegisterType,
        snitch.ReadableStreamType,
        snitch.WritableStreamType,
    )
    for body_op in loop_body.ops:
        for value in chain(body_op.operands, body_op.results):
            if not isinstance(value.type, accepted_types):
                return

    # 4. Every body op is either explicitly allowed or side-effect free.
    for body_op in loop_body.ops:
        if isinstance(body_op, ALLOWED_FREP_OP_LOWERING_TYPES):
            continue
        if not body_op.has_trait(Pure):
            return

    rewriter.erase_block_argument(induction_var)
    # frep_outer receives the iteration count minus one (ub - lb - 1).
    # NOTE(review): if ub <= lb at runtime the original loop runs zero
    # times; presumably frep still executes its body, so confirm this
    # pattern is only applied to loops known to iterate at least once.
    iter_count = riscv.SubOp(op.ub, op.lb)
    max_rep = riscv.AddiOp(iter_count, -1)
    frep = riscv_snitch.FrepOuterOp(
        max_rep,
        rewriter.move_region_contents_to_new_regions(op.body),
        op.iter_args,
    )
    rewriter.replace_op(op, (iter_count, max_rep, frep))

ScfYieldLowering

Bases: RewritePattern

Source code in xdsl/transforms/convert_riscv_scf_for_to_frep.py
76
77
78
79
80
81
82
class ScfYieldLowering(RewritePattern):
    """
    Turn a `riscv_scf.yield` whose parent op is an frep operation into a
    `riscv_snitch.frep_yield` carrying the same operands.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: riscv_scf.YieldOp, rewriter: PatternRewriter
    ) -> None:
        # Yields whose parent is not an frep (e.g. a riscv_scf.for) stay as-is.
        if not isinstance(op.parent_op(), riscv_snitch.FRepOperation):
            return
        frep_yield = riscv_snitch.FrepYieldOp(*op.operands)
        rewriter.replace_op(op, frep_yield)

match_and_rewrite(op: riscv_scf.YieldOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_riscv_scf_for_to_frep.py
77
78
79
80
81
82
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: riscv_scf.YieldOp, rewriter: PatternRewriter
) -> None:
    # Yields whose parent is not an frep (e.g. a riscv_scf.for) stay as-is.
    if not isinstance(op.parent_op(), riscv_snitch.FRepOperation):
        return
    frep_yield = riscv_snitch.FrepYieldOp(*op.operands)
    rewriter.replace_op(op, frep_yield)

ConvertRiscvScfForToFrepPass dataclass

Bases: ModulePass

Converts all riscv_scf.for loops to riscv_snitch.frep_outer loops, if the loops pass the riscv_snitch.frep_outer verification criteria:

  1. The induction variable is not used
  2. Step is 1
  3. All operations in the loop operate only on float registers or snitch stream values
  4. All operations are pure or one of: a) riscv_snitch.read, b) riscv_snitch.write, c) builtin.unrealized_conversion_cast
Source code in xdsl/transforms/convert_riscv_scf_for_to_frep.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
class ConvertRiscvScfForToFrepPass(ModulePass):
    """
    Converts all riscv_scf.for loops to riscv_snitch.frep_outer loops, if the loops pass
    the riscv_snitch.frep_outer verification criteria:

    1. The induction variable is not used
    2. Step is 1
    3. All operations in the loop operate only on float registers or snitch
       stream values
    4. All operations are pure or one of
        a) riscv_snitch.read
        b) riscv_snitch.write
        c) builtin.unrealized_conversion_cast

    """

    name = "convert-riscv-scf-for-to-frep"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # The yield pattern is registered next to the loop pattern because the
        # walk is not recursive: yields moved into a new frep_outer are
        # presumably converted later in the same walk (confirm against
        # PatternRewriteWalker's traversal order).
        applier = GreedyRewritePatternApplier([ScfYieldLowering(), ScfForLowering()])
        walker = PatternRewriteWalker(applier, apply_recursively=False)
        walker.rewrite_module(op)

name = 'convert-riscv-scf-for-to-frep' class-attribute instance-attribute

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/convert_riscv_scf_for_to_frep.py
102
103
104
105
106
107
108
109
110
111
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    # The yield pattern is registered next to the loop pattern because the
    # walk is not recursive: yields moved into a new frep_outer are
    # presumably converted later in the same walk (confirm against
    # PatternRewriteWalker's traversal order).
    applier = GreedyRewritePatternApplier([ScfYieldLowering(), ScfForLowering()])
    walker = PatternRewriteWalker(applier, apply_recursively=False)
    walker.rewrite_module(op)