Skip to content

Snitch allocate registers

snitch_allocate_registers

AllocateSnitchStreamingRegionRegisters

Bases: RewritePattern

Allocates the registers in the body of a `snitch_stream.streaming_region` operation: the i-th stream block argument (input streams first, then output streams) is assigned the floating-point register `FT[i]`.

Source code in xdsl/transforms/snitch_allocate_registers.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class AllocateSnitchStreamingRegionRegisters(RewritePattern):
    """
    Allocates the registers in the body of a `snitch_stream.streaming_region` operation by
    assigning them to the ones specified by the streams.

    The block arguments of the region body are the streams: the first `len(op.inputs)`
    are read streams and the remaining ones are write streams. Stream `i` (counting
    inputs first, then outputs) is given the floating-point register `FT[i]`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: snitch_stream.StreamingRegionOp, rewriter: PatternRewriter, /
    ):
        block = op.body.block
        input_count = len(op.inputs)

        # Only the leading `input_count` block arguments are read streams. Iterating
        # over every block argument here would first retype the output streams as
        # readable, only for the loop below to retype them a second time.
        for index, input_stream in enumerate(block.args[:input_count]):
            rewriter.replace_value_with_new_type(
                input_stream, snitch.ReadableStreamType(riscv.Registers.FT[index])
            )

        # Output streams continue the register numbering right after the inputs.
        # `block.args` is re-read here so the slice reflects the current arguments.
        for index, output_stream in enumerate(
            block.args[input_count:], start=input_count
        ):
            rewriter.replace_value_with_new_type(
                output_stream,
                snitch.WritableStreamType(riscv.Registers.FT[index]),
            )

match_and_rewrite(op: snitch_stream.StreamingRegionOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/snitch_allocate_registers.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: snitch_stream.StreamingRegionOp, rewriter: PatternRewriter, /
):
    """
    Give each block argument of the `streaming_region` body a stream type that
    carries its floating-point register.

    The first `len(op.inputs)` block arguments are read streams and receive
    `ReadableStreamType(FT[i])`. The remaining ones are write streams and receive
    `WritableStreamType(FT[i])`, where `i` counts inputs first, then outputs.
    """
    block = op.body.block
    input_count = len(op.inputs)

    # Only the leading `input_count` block arguments are read streams. Iterating
    # over every block argument here would first retype the output streams as
    # readable, only for the loop below to retype them a second time.
    for index, input_stream in enumerate(block.args[:input_count]):
        rewriter.replace_value_with_new_type(
            input_stream, snitch.ReadableStreamType(riscv.Registers.FT[index])
        )

    # Output streams continue the register numbering right after the inputs.
    # `block.args` is re-read here so the slice reflects the current arguments.
    for index, output_stream in enumerate(block.args[input_count:], start=input_count):
        rewriter.replace_value_with_new_type(
            output_stream,
            snitch.WritableStreamType(riscv.Registers.FT[index]),
        )

AllocateRiscvSnitchReadRegisters

Bases: RewritePattern

Propagates the register allocation done at the stream level to the values read from the streams.

Source code in xdsl/transforms/snitch_allocate_registers.py
43
44
45
46
47
48
49
50
51
52
53
54
class AllocateRiscvSnitchReadRegisters(RewritePattern):
    """
    Carries the register chosen for each stream over to the values read from it.

    The result of every `riscv_snitch.ReadOp` is retyped to the element type of the
    readable stream type it reads from.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: riscv_snitch.ReadOp, rewriter: PatternRewriter, /):
        readable = cast(
            snitch.ReadableStreamType[riscv.FloatRegisterType], op.stream.type
        )
        # `cast` is only a static-typing hint; the register used is the element
        # type carried by the stream's type.
        source_register = readable.element_type
        rewriter.replace_value_with_new_type(op.res, source_register)

match_and_rewrite(op: riscv_snitch.ReadOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/snitch_allocate_registers.py
49
50
51
52
53
54
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv_snitch.ReadOp, rewriter: PatternRewriter, /):
    """Retype the read result to the register carried by the source stream's type."""
    readable = cast(
        snitch.ReadableStreamType[riscv.FloatRegisterType], op.stream.type
    )
    # `cast` performs no runtime check; it only narrows the type for the checker.
    source_register = readable.element_type
    rewriter.replace_value_with_new_type(op.res, source_register)

AllocateRiscvSnitchWriteRegisters

Bases: RewritePattern

Propagates the register allocation done at the stream level to the values written to the streams.

Source code in xdsl/transforms/snitch_allocate_registers.py
57
58
59
60
61
62
63
64
65
66
67
68
class AllocateRiscvSnitchWriteRegisters(RewritePattern):
    """
    Carries the register chosen for each stream over to the values written into it.

    The value operand of every `riscv_snitch.WriteOp` is retyped to the element type
    of the writable stream type it writes to.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: riscv_snitch.WriteOp, rewriter: PatternRewriter, /):
        writable = cast(
            snitch.WritableStreamType[riscv.FloatRegisterType], op.stream.type
        )
        # `cast` is only a static-typing hint; the register used is the element
        # type carried by the stream's type.
        target_register = writable.element_type
        rewriter.replace_value_with_new_type(op.value, target_register)

match_and_rewrite(op: riscv_snitch.WriteOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/snitch_allocate_registers.py
63
64
65
66
67
68
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv_snitch.WriteOp, rewriter: PatternRewriter, /):
    """Retype the written value to the register carried by the target stream's type."""
    writable = cast(
        snitch.WritableStreamType[riscv.FloatRegisterType], op.stream.type
    )
    # `cast` performs no runtime check; it only narrows the type for the checker.
    target_register = writable.element_type
    rewriter.replace_value_with_new_type(op.value, target_register)

SnitchAllocateRegistersPass dataclass

Bases: ModulePass

Allocates unallocated registers for snitch operations.

Source code in xdsl/transforms/snitch_allocate_registers.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
@dataclass(frozen=True)
class SnitchAllocateRegistersPass(ModulePass):
    """
    Assigns registers to snitch stream code in two phases.

    First, the stream arguments of every `snitch_stream.streaming_region` body get
    register-carrying stream types. Then the values read from and written to those
    streams are retyped to match.
    """

    name = "snitch-allocate-registers"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        # Phase 1: fix the stream types of the streaming-region block arguments.
        region_walker = PatternRewriteWalker(
            AllocateSnitchStreamingRegionRegisters(),
            apply_recursively=False,
        )
        region_walker.rewrite_module(op)

        # Phase 2: the read/write patterns look up the stream types set in phase 1,
        # so they must run after it.
        access_patterns = GreedyRewritePatternApplier(
            [
                AllocateRiscvSnitchReadRegisters(),
                AllocateRiscvSnitchWriteRegisters(),
            ]
        )
        access_walker = PatternRewriteWalker(access_patterns, apply_recursively=False)
        access_walker.rewrite_module(op)

name = 'snitch-allocate-registers' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/snitch_allocate_registers.py
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Run stream-argument allocation, then propagate registers to reads and writes."""
    # Phase 1: fix the stream types of the streaming-region block arguments.
    region_walker = PatternRewriteWalker(
        AllocateSnitchStreamingRegionRegisters(),
        apply_recursively=False,
    )
    region_walker.rewrite_module(op)

    # Phase 2: the read/write patterns look up the stream types set in phase 1,
    # so they must run after it.
    access_patterns = GreedyRewritePatternApplier(
        [
            AllocateRiscvSnitchReadRegisters(),
            AllocateRiscvSnitchWriteRegisters(),
        ]
    )
    access_walker = PatternRewriteWalker(access_patterns, apply_recursively=False)
    access_walker.rewrite_module(op)