Skip to content

Convert memref stream to snitch stream

convert_memref_stream_to_snitch_stream

ReadOpLowering

Bases: RewritePattern

Source code in xdsl/transforms/convert_memref_stream_to_snitch_stream.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class ReadOpLowering(RewritePattern):
    """
    Rewrites a `memref_stream.ReadOp` into a `riscv_snitch.ReadOp`.

    The stream operand and the read result are bridged to and from the
    unallocated float register type with `UnrealizedConversionCastOp`s.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: memref_stream.ReadOp, rewriter: PatternRewriter
    ) -> None:
        """
        Replace `op` with a cast of its stream, a snitch read, optional move
        ops, and a cast of the read value back to the original element type.

        Raises `DiagnosticException` if the stream's element type is rejected
        by `snitch_stream_element_type_is_valid`.
        """
        source_type = op.stream.type
        assert isinstance(source_type, memref_stream.ReadableStreamType)
        value_type = cast(
            memref_stream.ReadableStreamType[Attribute], source_type
        ).element_type
        if not snitch_stream_element_type_is_valid(value_type):
            raise DiagnosticException(
                f"Invalid snitch stream element type {value_type}"
            )

        float_reg = riscv.Registers.UNALLOCATED_FLOAT
        stream_cast = UnrealizedConversionCastOp.get(
            (op.stream,), (snitch.ReadableStreamType(float_reg),)
        )
        read_op = riscv_snitch.ReadOp(stream_cast.results[0])

        # With exactly one use the value read is forwarded as is; in every
        # other case it first goes through `move_to_unallocated_regs`.
        move_ops = ()
        read_values = (read_op.res,)
        if not op.res.has_one_use():
            move_ops, read_values = move_to_unallocated_regs(
                (read_op.res,), (value_type,)
            )

        # Cast back so that existing users still see the original element type.
        result_cast = UnrealizedConversionCastOp.get(read_values, (value_type,))

        rewriter.replace_op(op, (stream_cast, read_op, *move_ops, result_cast))

match_and_rewrite(op: memref_stream.ReadOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_memref_stream_to_snitch_stream.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: memref_stream.ReadOp, rewriter: PatternRewriter
) -> None:
    """
    Replace a `memref_stream.ReadOp` with a cast of its stream, a
    `riscv_snitch.ReadOp`, optional move ops, and a cast of the read value
    back to the original element type.

    Raises `DiagnosticException` if the stream's element type is rejected by
    `snitch_stream_element_type_is_valid`.
    """
    source_type = op.stream.type
    assert isinstance(source_type, memref_stream.ReadableStreamType)
    value_type = cast(
        memref_stream.ReadableStreamType[Attribute], source_type
    ).element_type
    if not snitch_stream_element_type_is_valid(value_type):
        raise DiagnosticException(
            f"Invalid snitch stream element type {value_type}"
        )

    float_reg = riscv.Registers.UNALLOCATED_FLOAT
    stream_cast = UnrealizedConversionCastOp.get(
        (op.stream,), (snitch.ReadableStreamType(float_reg),)
    )
    read_op = riscv_snitch.ReadOp(stream_cast.results[0])

    # With exactly one use the value read is forwarded as is; in every other
    # case it first goes through `move_to_unallocated_regs`.
    move_ops = ()
    read_values = (read_op.res,)
    if not op.res.has_one_use():
        move_ops, read_values = move_to_unallocated_regs(
            (read_op.res,), (value_type,)
        )

    # Cast back so that existing users still see the original element type.
    result_cast = UnrealizedConversionCastOp.get(read_values, (value_type,))

    rewriter.replace_op(op, (stream_cast, read_op, *move_ops, result_cast))

WriteOpLowering

Bases: RewritePattern

Source code in xdsl/transforms/convert_memref_stream_to_snitch_stream.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class WriteOpLowering(RewritePattern):
    """
    Rewrites a `memref_stream.WriteOp` into a `riscv_snitch.WriteOp`.

    Both the stream and the written value are cast to the unallocated float
    register type with `UnrealizedConversionCastOp`s.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: memref_stream.WriteOp, rewriter: PatternRewriter
    ) -> None:
        """
        Replace `op` with casts of its stream and value, an optional
        `riscv.FMvDOp`, and the snitch write.

        Raises `DiagnosticException` if the stream's element type is rejected
        by `snitch_stream_element_type_is_valid`.
        """
        target_type = op.stream.type
        assert isinstance(target_type, memref_stream.WritableStreamType)
        value_type = cast(
            memref_stream.WritableStreamType[Attribute], target_type
        ).element_type
        if not snitch_stream_element_type_is_valid(value_type):
            raise DiagnosticException(
                f"Invalid snitch stream element type {value_type}"
            )

        float_reg = riscv.Registers.UNALLOCATED_FLOAT
        stream_cast = UnrealizedConversionCastOp.get(
            (op.stream,), (snitch.WritableStreamType(float_reg),)
        )
        value_cast = UnrealizedConversionCastOp.get((op.value,), (float_reg,))

        # The value is written without a move only when it is produced by an
        # operation in the same region as the write, and that operation is
        # not itself a stream read.
        producer = op.value.owner
        written_directly = (
            isinstance(producer, Operation)
            and producer.parent_region() is op.parent_region()
            and not isinstance(producer, memref_stream.ReadOp)
        )

        if written_directly:
            leading_ops = (stream_cast, value_cast)
            written_value = value_cast.results[0]
        else:
            move_op = riscv.FMvDOp(
                value_cast.results[0], rd=riscv.Registers.UNALLOCATED_FLOAT
            )
            leading_ops = (stream_cast, value_cast, move_op)
            written_value = move_op.results[0]

        write_op = riscv_snitch.WriteOp(written_value, stream_cast.results[0])
        rewriter.replace_op(op, (*leading_ops, write_op))

match_and_rewrite(op: memref_stream.WriteOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_memref_stream_to_snitch_stream.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: memref_stream.WriteOp, rewriter: PatternRewriter
) -> None:
    """
    Replace a `memref_stream.WriteOp` with casts of its stream and value, an
    optional `riscv.FMvDOp`, and a `riscv_snitch.WriteOp`.

    Raises `DiagnosticException` if the stream's element type is rejected by
    `snitch_stream_element_type_is_valid`.
    """
    target_type = op.stream.type
    assert isinstance(target_type, memref_stream.WritableStreamType)
    value_type = cast(
        memref_stream.WritableStreamType[Attribute], target_type
    ).element_type
    if not snitch_stream_element_type_is_valid(value_type):
        raise DiagnosticException(
            f"Invalid snitch stream element type {value_type}"
        )

    float_reg = riscv.Registers.UNALLOCATED_FLOAT
    stream_cast = UnrealizedConversionCastOp.get(
        (op.stream,), (snitch.WritableStreamType(float_reg),)
    )
    value_cast = UnrealizedConversionCastOp.get((op.value,), (float_reg,))

    # The value is written without a move only when it is produced by an
    # operation in the same region as the write, and that operation is not
    # itself a stream read.
    producer = op.value.owner
    written_directly = (
        isinstance(producer, Operation)
        and producer.parent_region() is op.parent_region()
        and not isinstance(producer, memref_stream.ReadOp)
    )

    if written_directly:
        leading_ops = (stream_cast, value_cast)
        written_value = value_cast.results[0]
    else:
        move_op = riscv.FMvDOp(
            value_cast.results[0], rd=riscv.Registers.UNALLOCATED_FLOAT
        )
        leading_ops = (stream_cast, value_cast, move_op)
        written_value = move_op.results[0]

    write_op = riscv_snitch.WriteOp(written_value, stream_cast.results[0])
    rewriter.replace_op(op, (*leading_ops, write_op))

StreamOpLowering

Bases: RewritePattern

Source code in xdsl/transforms/convert_memref_stream_to_snitch_stream.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
class StreamOpLowering(RewritePattern):
    """
    Rewrites a `memref_stream.StreamingRegionOp` into a
    `snitch_stream.StreamingRegionOp`, converting its access patterns to stride
    patterns and retyping the body's block arguments to snitch stream types.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: memref_stream.StreamingRegionOp, rewriter: PatternRewriter
    ) -> None:
        # Types of the operands that are memrefs; operands of any other type
        # are left out.
        operand_types = tuple(
            cast(memref.MemRefType, value_type)
            for value in op.operands
            if isinstance(value_type := value.type, memref.MemRefType)
        )
        # One stride pattern per (pattern, memref type) pair: the upper bounds
        # come from the pattern, the strides from `strides_for_affine_map`
        # applied to the pattern's index map and the memref type.
        # `strict=True` makes `zip` raise if the two sequences differ in length.
        stride_patterns = tuple(
            snitch_stream.StridePattern(
                ArrayAttr(ub.value for ub in pattern.ub),
                ArrayAttr(
                    IntAttr(stride)
                    for stride in strides_for_affine_map(
                        pattern.index_map.data, memref_type
                    )
                ),
            ).simplified()
            for pattern, memref_type in zip(op.patterns, operand_types, strict=True)
        )
        # When all patterns are equal, keep a single copy of it.
        if len(set(stride_patterns)) == 1:
            stride_patterns = (stride_patterns[0],)
        # NOTE(review): `cast_operands_to_regs` is defined outside this block;
        # presumably it yields register-typed replacements for the operands of
        # `op`, in operand order - confirm against its definition.
        new_operands = cast_operands_to_regs(rewriter, op)
        # The first `len(op.inputs)` values are the inputs, the rest the outputs.
        new_inputs = new_operands[: len(op.inputs)]
        new_outputs = new_operands[len(op.inputs) :]
        freg = riscv.Registers.UNALLOCATED_FLOAT

        # The body of the original op is moved into the new op.
        rewriter.replace_op(
            op,
            new_op := snitch_stream.StreamingRegionOp(
                new_inputs,
                new_outputs,
                ArrayAttr(stride_patterns),
                rewriter.move_region_contents_to_new_regions(op.body),
            ),
        )

        new_body = new_op.body.block

        # Target block argument types: readable streams for the inputs followed
        # by writable streams for the outputs.
        input_stream_types = (snitch.ReadableStreamType(freg),) * len(op.inputs)
        output_stream_types = (snitch.WritableStreamType(freg),) * len(op.outputs)
        stream_types = input_stream_types + output_stream_types
        # Every cast is inserted at the start of the block, so walking the
        # arguments in reverse leaves the casts in argument order.
        for i in reversed(range(len(stream_types))):
            arg = new_body.args[i]
            stream_type = stream_types[i]
            # Cast from the argument to the type it has right now, so existing
            # users keep seeing that type once the argument is retyped below.
            rewriter.insert_op(
                cast_op := builtin.UnrealizedConversionCastOp.get((arg,), (arg.type,)),
                InsertPoint.at_start(new_body),
            )
            # Redirect every use of the argument to the cast result, except the
            # use by the cast itself.
            rewriter.replace_uses_with_if(
                arg, cast_op.results[0], lambda use: use.operation is not cast_op
            )
            # Finally give the block argument its snitch stream type.
            rewriter.replace_value_with_new_type(arg, stream_type)

match_and_rewrite(op: memref_stream.StreamingRegionOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_memref_stream_to_snitch_stream.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: memref_stream.StreamingRegionOp, rewriter: PatternRewriter
) -> None:
    """
    Replace a `memref_stream.StreamingRegionOp` with a
    `snitch_stream.StreamingRegionOp`, converting its access patterns to stride
    patterns and retyping the body's block arguments to snitch stream types.
    """
    # Types of the operands that are memrefs; operands of any other type are
    # left out.
    operand_types = tuple(
        cast(memref.MemRefType, value_type)
        for value in op.operands
        if isinstance(value_type := value.type, memref.MemRefType)
    )
    # One stride pattern per (pattern, memref type) pair: the upper bounds come
    # from the pattern, the strides from `strides_for_affine_map` applied to the
    # pattern's index map and the memref type.
    # `strict=True` makes `zip` raise if the two sequences differ in length.
    stride_patterns = tuple(
        snitch_stream.StridePattern(
            ArrayAttr(ub.value for ub in pattern.ub),
            ArrayAttr(
                IntAttr(stride)
                for stride in strides_for_affine_map(
                    pattern.index_map.data, memref_type
                )
            ),
        ).simplified()
        for pattern, memref_type in zip(op.patterns, operand_types, strict=True)
    )
    # When all patterns are equal, keep a single copy of it.
    if len(set(stride_patterns)) == 1:
        stride_patterns = (stride_patterns[0],)
    # NOTE(review): `cast_operands_to_regs` is defined outside this block;
    # presumably it yields register-typed replacements for the operands of
    # `op`, in operand order - confirm against its definition.
    new_operands = cast_operands_to_regs(rewriter, op)
    # The first `len(op.inputs)` values are the inputs, the rest the outputs.
    new_inputs = new_operands[: len(op.inputs)]
    new_outputs = new_operands[len(op.inputs) :]
    freg = riscv.Registers.UNALLOCATED_FLOAT

    # The body of the original op is moved into the new op.
    rewriter.replace_op(
        op,
        new_op := snitch_stream.StreamingRegionOp(
            new_inputs,
            new_outputs,
            ArrayAttr(stride_patterns),
            rewriter.move_region_contents_to_new_regions(op.body),
        ),
    )

    new_body = new_op.body.block

    # Target block argument types: readable streams for the inputs followed by
    # writable streams for the outputs.
    input_stream_types = (snitch.ReadableStreamType(freg),) * len(op.inputs)
    output_stream_types = (snitch.WritableStreamType(freg),) * len(op.outputs)
    stream_types = input_stream_types + output_stream_types
    # Every cast is inserted at the start of the block, so walking the arguments
    # in reverse leaves the casts in argument order.
    for i in reversed(range(len(stream_types))):
        arg = new_body.args[i]
        stream_type = stream_types[i]
        # Cast from the argument to the type it has right now, so existing users
        # keep seeing that type once the argument is retyped below.
        rewriter.insert_op(
            cast_op := builtin.UnrealizedConversionCastOp.get((arg,), (arg.type,)),
            InsertPoint.at_start(new_body),
        )
        # Redirect every use of the argument to the cast result, except the use
        # by the cast itself.
        rewriter.replace_uses_with_if(
            arg, cast_op.results[0], lambda use: use.operation is not cast_op
        )
        # Finally give the block argument its snitch stream type.
        rewriter.replace_value_with_new_type(arg, stream_type)

ConvertMemRefStreamToSnitchStreamPass dataclass

Bases: ModulePass

Converts memref_stream read and write operations to the snitch_stream equivalents.

Care needs to be taken to preserve the semantics of the program. In assembly, the reads and writes are implicit, by using a register. In IR, they are modeled by read and write ops, which are not printed at the assembly level.

To preserve semantics, additional move ops are inserted in the following cases:

- reading from a stream: if the value read has multiple uses,
- writing to a stream: if the value is defined by an operation outside of the streaming region, or if the defining operation is a stream read.

Source code in xdsl/transforms/convert_memref_stream_to_snitch_stream.py
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
class ConvertMemRefStreamToSnitchStreamPass(ModulePass):
    """
    Lowers memref_stream `read` and `write` operations, and the streaming
    regions containing them, to their snitch_stream equivalents.

    Reads and writes are implicit at the assembly level: they happen by using
    a register. The IR models them with explicit `read` and `write` ops, which
    are not printed in assembly, so the lowering has to take care not to change
    the meaning of the program.

    Extra move ops are therefore inserted:
     - when reading from a stream, if the value read has multiple uses;
     - when writing to a stream, if the value is defined by an operation outside
       of the streaming region, or if the defining operation is a stream read.
    """

    name = "convert-memref-stream-to-snitch-stream"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """Run the three lowering patterns over `op` in a single reverse walk."""
        lowerings = GreedyRewritePatternApplier(
            [
                ReadOpLowering(),
                WriteOpLowering(),
                StreamOpLowering(),
            ]
        )
        # NOTE(review): the reverse walk presumably lets writes be lowered while
        # the `memref_stream.ReadOp` defining their value is still in place,
        # which `WriteOpLowering` checks for - confirm.
        walker = PatternRewriteWalker(
            lowerings,
            apply_recursively=False,
            walk_reverse=True,
        )
        walker.rewrite_module(op)

name = 'convert-memref-stream-to-snitch-stream' class-attribute instance-attribute

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/convert_memref_stream_to_snitch_stream.py
251
252
253
254
255
256
257
258
259
260
261
262
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Run the three lowering patterns over `op` in a single reverse walk."""
    lowerings = GreedyRewritePatternApplier(
        [
            ReadOpLowering(),
            WriteOpLowering(),
            StreamOpLowering(),
        ]
    )
    # Ops are visited in reverse order, and rewritten ops are not revisited.
    walker = PatternRewriteWalker(
        lowerings,
        apply_recursively=False,
        walk_reverse=True,
    )
    walker.rewrite_module(op)

snitch_stream_element_type_is_valid(attr: Attribute) -> bool

An override of the helper to account for Snitch packed SIMD.

Source code in xdsl/transforms/convert_memref_stream_to_snitch_stream.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def snitch_stream_element_type_is_valid(attr: Attribute) -> bool:
    """
    Returns whether `attr` may be used as a Snitch stream element type, taking
    Snitch packed SIMD into account.

    A non-vector type is valid only if it is a `Float64Type`. A vector type is
    valid if it holds one `Float64Type`, two `Float32Type` or four `Float16Type`
    elements.
    """
    if not isinstance(attr, VectorType):
        return isinstance(attr, Float64Type)

    vector = cast(VectorType[Any], attr)
    element_type = vector.element_type
    lane_count = vector.element_count()

    # Accepted (element type, element count) combinations.
    # TODO: handle fp8
    accepted = ((Float64Type, 1), (Float32Type, 2), (Float16Type, 4))
    return any(
        isinstance(element_type, float_type) and lane_count == lanes
        for float_type, lanes in accepted
    )

strides_for_affine_map(affine_map: AffineMap, memref_type: MemRefType[AttributeCovT]) -> list[int]

Given an iteration space represented as an affine map (for indexing) and a shape (for bounds), returns the corresponding iteration strides for each dimension.

The affine map must not have symbols.

Source code in xdsl/transforms/convert_memref_stream_to_snitch_stream.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
def strides_for_affine_map(
    affine_map: AffineMap, memref_type: MemRefType[AttributeCovT]
) -> list[int]:
    """
    Returns one stride per dimension of the iteration space described by
    `affine_map` (used for indexing) over a memref of type `memref_type` (used
    for bounds and layout).

    The stride of a dimension is the difference between the memref's
    offset-in-bytes map, composed with `affine_map`, evaluated at the unit
    vector of that dimension and at the origin.

    Raises `ValueError` if `affine_map` has symbols, or if the memref shape is
    not fully static.
    """
    if affine_map.num_symbols:
        raise ValueError("Cannot create strides for affine map with symbols")

    # A dimension of -1 is treated as dynamic, which is not supported for now.
    if any(dim == -1 for dim in memref_type.get_shape()):
        raise ValueError("Cannot create strides for a memref with dynamic shapes")

    byte_offset_map = memref_type.get_affine_map_in_bytes().compose(affine_map)
    num_dims = byte_offset_map.num_dims
    # The composed map can have symbols for a dynamic offset; evaluate them at 0.
    symbol_values = [0] * byte_offset_map.num_symbols

    def offset_at(point: list[int]) -> int:
        return byte_offset_map.eval(point, symbol_values)[0]

    # The static offset is subtracted so only the per-dimension step remains.
    base_offset = offset_at([0] * num_dims)

    return [
        offset_at([int(axis == dim) for axis in range(num_dims)]) - base_offset
        for dim in range(num_dims)
    ]