Skip to content

Convert riscv scf for to frep

convert_riscv_scf_for_to_frep

ScfForLowering

Bases: RewritePattern

Source code in xdsl/transforms/convert_riscv_scf_for_to_frep.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class ScfForLowering(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter) -> None:
        body_block = op.body.block
        indvar = body_block.args[0]
        # 1. Induction variable is not used
        if indvar.uses:
            return

        # 2. Step is 1
        match op.step:
            case IntegerAttr(value=IntAttr(1)):
                pass
            case step_ssa if (
                isinstance(step_ssa, SSAValue)
                and isinstance(step_op := step_ssa.owner, rv32.LiOp)
                and isinstance(step_op.immediate, builtin.IntegerAttr)
                and step_op.immediate.value.data == 1
            ):
                pass
            case _:
                return

        # 3. All operations in the loop operate on float registers
        if not all(
            isinstance(
                value.type,
                riscv.FloatRegisterType
                | snitch.ReadableStreamType
                | snitch.WritableStreamType,
            )
            for o in body_block.ops
            for value in chain(o.operands, o.results)
        ):
            return

        # 4. All operations are pure or one of
        #     a) riscv_snitch.read
        #     b) riscv_snitch.write
        #     c) builtin.unrealized_conversion_cast
        if not all(
            isinstance(o, riscv_scf.YieldOp) or riscv_snitch.is_valid_frep_body_op(o)
            for o in body_block.ops
        ):
            return

        rewriter.erase_block_argument(indvar)
        rewriter.replace_op(
            op,
            (
                iter_count := riscv.SubOp(op.ub, op.lb),
                iter_count_minus_one := riscv.AddiOp(iter_count, -1),
                riscv_snitch.FrepOuterOp(
                    iter_count_minus_one,
                    rewriter.move_region_contents_to_new_regions(op.body),
                    op.iter_args,
                ),
            ),
        )

match_and_rewrite(op: riscv_scf.ForOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_riscv_scf_for_to_frep.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter) -> None:
    body_block = op.body.block
    indvar = body_block.args[0]
    # 1. Induction variable is not used
    if indvar.uses:
        return

    # 2. Step is 1
    match op.step:
        case IntegerAttr(value=IntAttr(1)):
            pass
        case step_ssa if (
            isinstance(step_ssa, SSAValue)
            and isinstance(step_op := step_ssa.owner, rv32.LiOp)
            and isinstance(step_op.immediate, builtin.IntegerAttr)
            and step_op.immediate.value.data == 1
        ):
            pass
        case _:
            return

    # 3. All operations in the loop operate on float registers
    if not all(
        isinstance(
            value.type,
            riscv.FloatRegisterType
            | snitch.ReadableStreamType
            | snitch.WritableStreamType,
        )
        for o in body_block.ops
        for value in chain(o.operands, o.results)
    ):
        return

    # 4. All operations are pure or one of
    #     a) riscv_snitch.read
    #     b) riscv_snitch.write
    #     c) builtin.unrealized_conversion_cast
    if not all(
        isinstance(o, riscv_scf.YieldOp) or riscv_snitch.is_valid_frep_body_op(o)
        for o in body_block.ops
    ):
        return

    rewriter.erase_block_argument(indvar)
    rewriter.replace_op(
        op,
        (
            iter_count := riscv.SubOp(op.ub, op.lb),
            iter_count_minus_one := riscv.AddiOp(iter_count, -1),
            riscv_snitch.FrepOuterOp(
                iter_count_minus_one,
                rewriter.move_region_contents_to_new_regions(op.body),
                op.iter_args,
            ),
        ),
    )

ScfYieldLowering

Bases: RewritePattern

Source code in xdsl/transforms/convert_riscv_scf_for_to_frep.py
78
79
80
81
82
83
84
class ScfYieldLowering(RewritePattern):
    """
    Rewrite a `riscv_scf.yield` into a `riscv_snitch.frep_yield` when its parent
    operation is a `riscv_snitch.FRepOperation`; other yields are left alone.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: riscv_scf.YieldOp, rewriter: PatternRewriter
    ) -> None:
        enclosing = op.parent_op()
        # Yields belonging to anything other than an frep loop are not ours.
        if not isinstance(enclosing, riscv_snitch.FRepOperation):
            return
        yielded = op.operands
        rewriter.replace_op(op, riscv_snitch.FrepYieldOp(*yielded))

match_and_rewrite(op: riscv_scf.YieldOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_riscv_scf_for_to_frep.py
79
80
81
82
83
84
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: riscv_scf.YieldOp, rewriter: PatternRewriter
) -> None:
    """Swap a yield inside an frep loop for `riscv_snitch.frep_yield`, same operands."""
    owner_op = op.parent_op()
    if not isinstance(owner_op, riscv_snitch.FRepOperation):
        return
    replacement = riscv_snitch.FrepYieldOp(*op.operands)
    rewriter.replace_op(op, replacement)

ConvertRiscvScfForToFrepPass dataclass

Bases: ModulePass

Converts all riscv_scf.for loops to riscv_snitch.frep_outer loops, if the loops pass the riscv_snitch.frep_outer verification criteria:

  1. The induction variable is not used
  2. Step is 1
  3. All operations in the loop operate on float registers (or Snitch stream values)
  4. All operations are pure or one of: a) riscv_snitch.read, b) riscv_snitch.write, c) builtin.unrealized_conversion_cast
Source code in xdsl/transforms/convert_riscv_scf_for_to_frep.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
class ConvertRiscvScfForToFrepPass(ModulePass):
    """
    Converts all riscv_scf.for loops to riscv_snitch.frep_outer loops, if the loops pass
    the riscv_snitch.frep_outer verification criteria:

    1. The induction variable is not used
    2. Step is 1
    3. All operations in the loop operate on float registers (or Snitch streams)
    4. All operations are pure or one of
        a) riscv_snitch.read
        b) riscv_snitch.write
        c) builtin.unrealized_conversion_cast

    """

    name = "convert-riscv-scf-for-to-frep"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # One non-recursive greedy walk applying both the yield and loop rewrites.
        patterns = GreedyRewritePatternApplier([ScfYieldLowering(), ScfForLowering()])
        walker = PatternRewriteWalker(patterns, apply_recursively=False)
        walker.rewrite_module(op)

name = 'convert-riscv-scf-for-to-frep' class-attribute instance-attribute

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/convert_riscv_scf_for_to_frep.py
104
105
106
107
108
109
110
111
112
113
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Run the yield and for-loop frep rewrites once (non-recursively) over `op`."""
    rewrites = [ScfYieldLowering(), ScfForLowering()]
    applier = GreedyRewritePatternApplier(rewrites)
    PatternRewriteWalker(applier, apply_recursively=False).rewrite_module(op)