Skip to content

Memref to dsd

memref_to_dsd

LowerAllocOpPass

Bases: RewritePattern

Lowers memref.alloc to csl.zeros.

Source code in xdsl/transforms/memref_to_dsd.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class LowerAllocOpPass(RewritePattern):
    """Lowers `memref.alloc` to `csl.zeros`.

    Each allocation becomes a `csl.ZerosOp` buffer, one 16-bit `arith.constant`
    per dimension of the memref shape, and a `csl.GetMemDsdOp` over that buffer
    (a 1d DSD for rank-1 memrefs, otherwise a 4d DSD). The DSD op is the last op
    handed to `replace_op`, so its result is what replaces the allocated memref.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.AllocOp, rewriter: PatternRewriter, /):
        # Bind the type outside the assert: asserts are removed under
        # `python -O`, so a binding made inside one would be lost and
        # `memref_type` would be undefined for the code below.
        memref_type = op.memref.type
        assert (
            MemRefType[ZerosOpAttr]
            .constr(csl.ZerosOpAttrConstr)
            .verifies(memref_type)
        )
        zeros_op = csl.ZerosOp(memref_type)

        dsd_t = csl.DsdType(
            csl.DsdKind.mem1d_dsd
            if len(memref_type.shape) == 1
            else csl.DsdKind.mem4d_dsd
        )

        # A static offset in a strided layout is carried over to the DSD as a
        # 16-bit `offsets` property; any other layout leaves it unset.
        offsets = None
        if isinstance(memref_type.layout, StridedLayoutAttr) and isinstance(
            memref_type.layout.offset, IntAttr
        ):
            offsets = ArrayAttr([IntegerAttr(memref_type.layout.offset, 16)])

        # One 16-bit constant per dimension, used as the DSD sizes.
        shape = [arith.ConstantOp(IntegerAttr(d, 16)) for d in memref_type.shape]
        dsd_op = csl.GetMemDsdOp.build(
            operands=[zeros_op, shape],
            result_types=[dsd_t],
            properties={
                "offsets": offsets,
            },
        )

        # Propagate the allocation's name hint to the newly created values.
        if op.memref.name_hint:
            zeros_op.result.name_hint = op.memref.name_hint
            dsd_op.result.name_hint = f"{op.memref.name_hint}_dsd"
            for s in shape:
                s.result.name_hint = f"{op.memref.name_hint}_size"

        rewriter.replace_op(op, [zeros_op, *shape, dsd_op])

match_and_rewrite(op: memref.AllocOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/memref_to_dsd.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.AllocOp, rewriter: PatternRewriter, /):
    """Replace `memref.alloc` with a zeroed buffer plus a DSD over it.

    Emits a `csl.ZerosOp`, one 16-bit `arith.constant` per dimension, and a
    `csl.GetMemDsdOp` (1d DSD for rank-1 memrefs, otherwise 4d). The DSD op is
    the last op passed to `replace_op`, so its result replaces the memref.
    """
    # Bind the type outside the assert: asserts are removed under
    # `python -O`, so a binding made inside one would be lost and
    # `memref_type` would be undefined for the code below.
    memref_type = op.memref.type
    assert (
        MemRefType[ZerosOpAttr]
        .constr(csl.ZerosOpAttrConstr)
        .verifies(memref_type)
    )
    zeros_op = csl.ZerosOp(memref_type)

    dsd_t = csl.DsdType(
        csl.DsdKind.mem1d_dsd
        if len(memref_type.shape) == 1
        else csl.DsdKind.mem4d_dsd
    )

    # A static offset in a strided layout is carried over to the DSD as a
    # 16-bit `offsets` property; any other layout leaves it unset.
    offsets = None
    if isinstance(memref_type.layout, StridedLayoutAttr) and isinstance(
        memref_type.layout.offset, IntAttr
    ):
        offsets = ArrayAttr([IntegerAttr(memref_type.layout.offset, 16)])

    # One 16-bit constant per dimension, used as the DSD sizes.
    shape = [arith.ConstantOp(IntegerAttr(d, 16)) for d in memref_type.shape]
    dsd_op = csl.GetMemDsdOp.build(
        operands=[zeros_op, shape],
        result_types=[dsd_t],
        properties={
            "offsets": offsets,
        },
    )

    # Propagate the allocation's name hint to the newly created values.
    if op.memref.name_hint:
        zeros_op.result.name_hint = op.memref.name_hint
        dsd_op.result.name_hint = f"{op.memref.name_hint}_dsd"
        for s in shape:
            s.result.name_hint = f"{op.memref.name_hint}_size"

    rewriter.replace_op(op, [zeros_op, *shape, dsd_op])

FixGetDsdOnGetDsd

Bases: RewritePattern

This rewrite pattern fixes a GetMemDsdOp that is applied to another GetMemDsdOp instead of to the underlying buffer. This situation is a side effect of LowerAllocOpPass when the program already contains GetMemDsdOp ops that were created outside of this pass.

Source code in xdsl/transforms/memref_to_dsd.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
class FixGetDsdOnGetDsd(RewritePattern):
    """
    Collapse a GetMemDsdOp whose base address is itself a DSD.

    `LowerAllocOpPass` turns allocations into DSDs. A GetMemDsdOp that already
    existed in the program, and that consumed such an allocation, is then left
    consuming a DSD instead of the underlying buffer. This pattern re-targets it
    to the buffer behind the producing GetMemDsdOp. It raises ValueError when the
    DSD operand does not come from a GetMemDsdOp.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter, /):
        base = op.base_addr
        if not isinstance(base.type, csl.DsdType):
            # Already built on a buffer; nothing to fix.
            return
        if not (isinstance(base, OpResult) and isinstance(base.op, csl.GetMemDsdOp)):
            raise ValueError("Failed to resolve GetMemDsdOp called on dsd type")
        # Same sizes, properties, attributes and result types; only the base
        # address is swapped for the producer's own base address.
        rebuilt = csl.GetMemDsdOp.build(
            operands=[base.op.base_addr, op.sizes],
            properties=op.properties,
            attributes=op.attributes,
            result_types=op.result_types,
        )
        rewriter.replace_op(op, rebuilt)

match_and_rewrite(op: csl.GetMemDsdOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/memref_to_dsd.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter, /):
    """Re-target a GetMemDsdOp built on another GetMemDsdOp to that op's buffer.

    Raises ValueError when the DSD base address is not produced by a GetMemDsdOp.
    """
    base = op.base_addr
    if not isinstance(base.type, csl.DsdType):
        # Already built on a buffer; nothing to fix.
        return
    if not (isinstance(base, OpResult) and isinstance(base.op, csl.GetMemDsdOp)):
        raise ValueError("Failed to resolve GetMemDsdOp called on dsd type")
    # Keep everything but the base address, which becomes the producer's own.
    rebuilt = csl.GetMemDsdOp.build(
        operands=[base.op.base_addr, op.sizes],
        properties=op.properties,
        attributes=op.attributes,
        result_types=op.result_types,
    )
    rewriter.replace_op(op, rebuilt)

FixMemRefLoadOnGetDsd

Bases: RewritePattern

MemRef load ops should load from the underlying memref, not from the dsd.

Source code in xdsl/transforms/memref_to_dsd.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
class FixMemRefLoadOnGetDsd(RewritePattern):
    """
    Make a `memref.load` that reads from a DSD read from the memref behind it.

    Only DSDs produced by a GetMemDsdOp can be resolved; for any other DSD
    source a ValueError is raised.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /):
        source = op.memref
        if not isinstance(source.type, csl.DsdType):
            return
        if isinstance(source, OpResult) and isinstance(source.op, csl.GetMemDsdOp):
            new_load = memref.LoadOp.get(source.op.base_addr, op.indices)
            rewriter.replace_op(op, new_load)
            return
        raise ValueError("Failed to resolve memref.load called on dsd type")

match_and_rewrite(op: memref.LoadOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/memref_to_dsd.py
108
109
110
111
112
113
114
115
116
117
118
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /):
    """Redirect a load from a GetMemDsdOp result to that op's base address.

    Raises ValueError when the DSD being loaded from has another producer.
    """
    source = op.memref
    if not isinstance(source.type, csl.DsdType):
        return
    if isinstance(source, OpResult) and isinstance(source.op, csl.GetMemDsdOp):
        new_load = memref.LoadOp.get(source.op.base_addr, op.indices)
        rewriter.replace_op(op, new_load)
        return
    raise ValueError("Failed to resolve memref.load called on dsd type")

LowerSubviewOpPass

Bases: RewritePattern

Lowers memref.subview to dsd ops

Source code in xdsl/transforms/memref_to_dsd.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
class LowerSubviewOpPass(RewritePattern):
    """Lowers memref.subview to dsd ops.

    Two shapes of subview are handled:

    * A 1d subview of an nd memref is lowered to a `csl.GetMemDsdOp` whose
      `tensor_access` affine map walks the single dimension with size > 1,
      optionally followed by an offset increment when dynamic offsets exist.
    * A subview with a single size, offset and stride is lowered to a chain of
      `SetDsdLengthOp`, `SetDsdStrideOp` and `IncrementDsdOffsetOp` applied to
      the source, emitting only the updates that are needed.

    Any other subview fails a "not implemented" assertion.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
        assert isa(op.source.type, MemRefType)
        assert isa(op.result.type, MemRefType)

        if len(op.result.type.get_shape()) == 1 and len(op.source.type.get_shape()) > 1:
            # 1d subview onto a nd memref
            sizes = op.static_sizes.get_values()
            # Unit sizes are ignored; exactly one distinct non-unit size may
            # remain, and it must occur exactly once.
            counter_sizes = collections.Counter(sizes)
            counter_sizes.pop(1, None)
            assert len(counter_sizes) == 1, (
                "1d access into nd memref must specify one size > 1"
            )
            size, size_count = counter_sizes.most_common()[0]

            assert size_count == 1, (
                "1d access into nd memref can only specify one size > 1, which can occur only once"
            )
            assert all(stride == 1 for stride in op.static_strides.get_values()), (
                "All strides must equal 1"
            )

            # One affine result per source dimension: its static offset (a
            # dynamic offset contributes 0 here), plus d0 on the walked
            # dimension, i.e. the position of the non-unit size.
            amap: list[AffineExpr] = [
                AffineConstantExpr(o if o != DYNAMIC_INDEX else 0)
                for o in op.static_offsets.get_values()
            ]
            amap[sizes.index(size)] += AffineDimExpr(0)

            size_op = arith.ConstantOp.from_int_and_width(size, 16)
            dsd_op = csl.GetMemDsdOp(
                operands=[op.source, [size_op]],
                properties={
                    "tensor_access": AffineMapAttr(AffineMap(1, 0, tuple(amap)))
                },
                result_types=[csl.DsdType(csl.DsdKind.mem1d_dsd)],
            )
            # Dynamic offsets are applied as a runtime offset increment on the DSD.
            # NOTE(review): `_update_offsets` only uses `op.offsets[0]`; confirm
            # that subviews with several dynamic offsets cannot reach this path.
            offset_ops = self._update_offsets(op, dsd_op) if op.offsets else []
            rewriter.replace_op(op, [size_op, dsd_op, *offset_ops])
            return

        # General path: only single-dimension subviews are supported.
        assert len(op.static_sizes) == 1, "not implemented"
        assert len(op.static_offsets) == 1, "not implemented"
        assert len(op.static_strides) == 1, "not implemented"

        # Thread the value through the length, stride and offset updates. Each
        # stage starts from the last op emitted so far, or the source if none.
        last_op = op.source
        size_ops = self._update_sizes(op, last_op)

        last_op = size_ops[-1] if len(size_ops) > 0 else last_op
        stride_ops = self._update_strides(op, last_op)

        last_op = stride_ops[-1] if len(stride_ops) > 0 else last_op
        offset_ops = self._update_offsets(op, last_op)

        new_ops = [*size_ops, *stride_ops, *offset_ops]
        if new_ops:
            # The final op in the chain presumably provides the replacement
            # value (replace_op takes the last op's results by default).
            rewriter.replace_op(op, [*size_ops, *stride_ops, *offset_ops])
        else:
            # subview has no effect (todo: this could be canonicalized away)
            rewriter.replace_op(op, [], new_results=[op.source])

    @staticmethod
    def _update_sizes(
        subview: memref.SubviewOp, curr_op: SSAValue | Operation
    ) -> list[Operation]:
        """Return ops that set the length of `curr_op` to the subview's size.

        A dynamic size is cast to i16 at runtime. A static size is emitted as an
        i16 constant only when the subview's sizes differ from the source shape.
        An empty list means no length update is needed.
        """
        assert isa(subview.source.type, MemRefType)
        ops = list[Operation]()

        static_sizes = subview.static_sizes.get_values()

        if static_sizes[0] == DYNAMIC_INDEX:
            ops.append(cast_op := arith.IndexCastOp(subview.sizes[0], i16))
            # Note: the `curr_op :=` rebinding is not read again in this method.
            ops.append(
                curr_op := csl.SetDsdLengthOp.build(
                    operands=[curr_op, cast_op], result_types=[subview.source.type]
                )
            )
        elif static_sizes != subview.source.type.get_shape():
            # update sizes only if they differ from op.source.type
            ops.append(
                len_op := arith.ConstantOp(
                    IntegerAttr(
                        static_sizes[0],
                        i16,
                    )
                )
            )
            ops.append(
                curr_op := csl.SetDsdLengthOp.build(
                    operands=[curr_op, len_op], result_types=[subview.source.type]
                )
            )
        return ops

    @staticmethod
    def _update_strides(
        subview: memref.SubviewOp, curr_op: SSAValue | Operation
    ) -> list[Operation]:
        """Return ops that set the stride of `curr_op` to the subview's stride.

        A dynamic stride is cast to i8 at runtime. A static stride is emitted as
        an i8 constant only when the subview's strides differ from the source's
        strides. An empty list means no stride update is needed.
        """
        assert isa(subview.source.type, MemRefType)
        ops = list[Operation]()

        static_strides = subview.static_strides.get_values()

        if static_strides[0] == DYNAMIC_INDEX:
            ops.append(cast_op := arith.IndexCastOp(subview.strides[0], i8))
            ops.append(
                csl.SetDsdStrideOp.build(
                    operands=[curr_op, cast_op], result_types=[subview.source.type]
                )
            )
        elif static_strides != subview.source.type.get_strides():
            # update strides only if they differ from op.source.type
            ops.append(
                stride_op := arith.ConstantOp(
                    IntegerAttr(
                        static_strides[0],
                        i8,
                    )
                )
            )
            ops.append(
                csl.SetDsdStrideOp.build(
                    operands=[curr_op, stride_op], result_types=[subview.source.type]
                )
            )
        return ops

    @staticmethod
    def _update_offsets(
        subview: memref.SubviewOp, curr_op: SSAValue | Operation
    ) -> list[Operation]:
        """Return ops that increment the offset of `curr_op` by the subview's offset.

        If the subview has dynamic offsets, the first one is cast to i16 at
        runtime. Otherwise the first static offset is emitted as an i16 constant,
        but only when it differs from the source layout's offset (strided layout,
        with a missing offset treated as 0) or is non-zero (no layout). An empty
        list means no offset update is needed.
        """
        assert isa(subview.source.type, MemRefType)
        ops = list[Operation]()

        static_offsets = subview.static_offsets.get_values()

        if subview.offsets:
            ops.append(cast_op := arith.IndexCastOp(subview.offsets[0], i16))
            ops.append(
                csl.IncrementDsdOffsetOp.build(
                    operands=[curr_op, cast_op],
                    properties={"elem_type": subview.source.type.get_element_type()},
                    result_types=[subview.source.type],
                )
            )
        # Parsed as `(strided and differs) or (no layout and non-zero)`.
        # NOTE(review): a static offset equal to the source layout's offset emits
        # no increment; confirm this matches the intended offset semantics.
        elif (
            isinstance(subview.source.type.layout, StridedLayoutAttr)
            and static_offsets[0] != (subview.source.type.layout.get_offset() or 0)
            or isinstance(subview.source.type.layout, NoneAttr)
            and static_offsets[0] != 0
        ):
            # update offsets only if they differ from op.source.type
            ops.append(
                offset_op := arith.ConstantOp(
                    IntegerAttr(
                        static_offsets[0],
                        i16,
                    )
                )
            )
            ops.append(
                csl.IncrementDsdOffsetOp.build(
                    operands=[curr_op, offset_op],
                    properties={"elem_type": subview.source.type.get_element_type()},
                    result_types=[subview.source.type],
                )
            )
        return ops

match_and_rewrite(op: memref.SubviewOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/memref_to_dsd.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
    """Lower a `memref.subview` to DSD ops.

    A 1d subview of an nd memref becomes a GetMemDsdOp with a `tensor_access`
    affine map. A subview with one size/offset/stride becomes a chain of length,
    stride and offset updates on the source. Anything else fails an assertion.
    """
    assert isa(op.source.type, MemRefType)
    assert isa(op.result.type, MemRefType)

    if len(op.result.type.get_shape()) == 1 and len(op.source.type.get_shape()) > 1:
        # 1d subview onto a nd memref
        sizes = op.static_sizes.get_values()
        # Unit sizes are ignored; exactly one distinct non-unit size may
        # remain, and it must occur exactly once.
        counter_sizes = collections.Counter(sizes)
        counter_sizes.pop(1, None)
        assert len(counter_sizes) == 1, (
            "1d access into nd memref must specify one size > 1"
        )
        size, size_count = counter_sizes.most_common()[0]

        assert size_count == 1, (
            "1d access into nd memref can only specify one size > 1, which can occur only once"
        )
        assert all(stride == 1 for stride in op.static_strides.get_values()), (
            "All strides must equal 1"
        )

        # One affine result per source dimension: its static offset (a dynamic
        # offset contributes 0 here), plus d0 on the walked dimension, i.e. the
        # position of the non-unit size.
        amap: list[AffineExpr] = [
            AffineConstantExpr(o if o != DYNAMIC_INDEX else 0)
            for o in op.static_offsets.get_values()
        ]
        amap[sizes.index(size)] += AffineDimExpr(0)

        size_op = arith.ConstantOp.from_int_and_width(size, 16)
        dsd_op = csl.GetMemDsdOp(
            operands=[op.source, [size_op]],
            properties={
                "tensor_access": AffineMapAttr(AffineMap(1, 0, tuple(amap)))
            },
            result_types=[csl.DsdType(csl.DsdKind.mem1d_dsd)],
        )
        # Dynamic offsets are applied as a runtime offset increment on the DSD.
        # NOTE(review): only the first dynamic offset is used by
        # `_update_offsets`; confirm multiple dynamic offsets cannot occur here.
        offset_ops = self._update_offsets(op, dsd_op) if op.offsets else []
        rewriter.replace_op(op, [size_op, dsd_op, *offset_ops])
        return

    # General path: only single-dimension subviews are supported.
    assert len(op.static_sizes) == 1, "not implemented"
    assert len(op.static_offsets) == 1, "not implemented"
    assert len(op.static_strides) == 1, "not implemented"

    # Thread the value through the length, stride and offset updates. Each
    # stage starts from the last op emitted so far, or the source if none.
    last_op = op.source
    size_ops = self._update_sizes(op, last_op)

    last_op = size_ops[-1] if len(size_ops) > 0 else last_op
    stride_ops = self._update_strides(op, last_op)

    last_op = stride_ops[-1] if len(stride_ops) > 0 else last_op
    offset_ops = self._update_offsets(op, last_op)

    new_ops = [*size_ops, *stride_ops, *offset_ops]
    if new_ops:
        # The final op in the chain presumably provides the replacement value
        # (replace_op takes the last op's results by default).
        rewriter.replace_op(op, [*size_ops, *stride_ops, *offset_ops])
    else:
        # subview has no effect (todo: this could be canonicalized away)
        rewriter.replace_op(op, [], new_results=[op.source])

LowerCopyOpPass

Bases: RewritePattern

Lowers memref.copy to CSL move operations.

Source code in xdsl/transforms/memref_to_dsd.py
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
class LowerCopyOpPass(RewritePattern):
    """Lowers memref.copy to CSL move operations.

    The move op is chosen from the source element type: f16 -> FmovhOp,
    f32 -> FmovsOp, i16 -> Mov16Op, i32 -> Mov32Op. Any other element type
    raises ValueError.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.CopyOp, rewriter: PatternRewriter, /):
        assert isa(op.source.type, MemRefType)
        elem_t = op.source.type.get_element_type()

        if isinstance(elem_t, Float16Type):
            mov_op_cls = csl.FmovhOp
        elif isinstance(elem_t, Float32Type):
            mov_op_cls = csl.FmovsOp
        elif elem_t == builtin.i16:
            mov_op_cls = csl.Mov16Op
        elif elem_t == builtin.i32:
            mov_op_cls = csl.Mov32Op
        else:
            raise ValueError("unsupported value")

        # Operands are given as one group: destination first, then source.
        rewriter.replace_op(op, mov_op_cls(operands=[[op.destination, op.source]]))

match_and_rewrite(op: memref.CopyOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/memref_to_dsd.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.CopyOp, rewriter: PatternRewriter, /):
    """Replace `memref.copy` with the CSL move op matching its element type.

    Raises ValueError for element types other than f16, f32, i16 and i32.
    """
    assert isa(op.source.type, MemRefType)
    elem_t = op.source.type.get_element_type()

    if isinstance(elem_t, Float16Type):
        mov_op_cls = csl.FmovhOp
    elif isinstance(elem_t, Float32Type):
        mov_op_cls = csl.FmovsOp
    elif elem_t == builtin.i16:
        mov_op_cls = csl.Mov16Op
    elif elem_t == builtin.i32:
        mov_op_cls = csl.Mov32Op
    else:
        raise ValueError("unsupported value")

    # Operands are given as one group: destination first, then source.
    rewriter.replace_op(op, mov_op_cls(operands=[[op.destination, op.source]]))

LowerUnrealizedConversionCastOpPass

Bases: RewritePattern

Removes unrealized conversion casts from DSD to memref types, which are no longer necessary after this pass.

Source code in xdsl/transforms/memref_to_dsd.py
315
316
317
318
319
320
321
322
323
324
325
326
327
class LowerUnrealizedConversionCastOpPass(RewritePattern):
    """
    Conversions from dsd to memref are no longer necessary after this pass.

    A cast whose inputs are all DSDs and whose outputs are all memrefs is
    erased, and its uses are rewired to the cast's DSD inputs.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: UnrealizedConversionCastOp, rewriter: PatternRewriter, /
    ):
        if not all(isa(t, csl.DsdType) for t in op.inputs.types):
            return
        if not all(isa(t, MemRefType) for t in op.outputs.types):
            return
        rewriter.replace_op(op, [], new_results=op.inputs)

match_and_rewrite(op: UnrealizedConversionCastOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/memref_to_dsd.py
320
321
322
323
324
325
326
327
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: UnrealizedConversionCastOp, rewriter: PatternRewriter, /
):
    """Erase a DSD -> memref cast, forwarding its DSD inputs to its users."""
    if not all(isa(t, csl.DsdType) for t in op.inputs.types):
        return
    if not all(isa(t, MemRefType) for t in op.outputs.types):
        return
    rewriter.replace_op(op, [], new_results=op.inputs)

DsdOpUpdateType

Bases: RewritePattern

Rebuilds DSD ops so that their result types change from memref types to DSD types.

Source code in xdsl/transforms/memref_to_dsd.py
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
class DsdOpUpdateType(RewritePattern):
    """Rebuild DSD ops from memref to DSD types.

    `IncrementDsdOffsetOp`, `SetDsdStrideOp` and `SetDsdLengthOp` may carry a
    memref result type (for example, `LowerSubviewOpPass` builds them with the
    subview's source memref type). Once their `op` operand has become a DSD,
    this pattern rebuilds them so that the result type equals the operand type.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self,
        op: csl.IncrementDsdOffsetOp | csl.SetDsdStrideOp | csl.SetDsdLengthOp,
        rewriter: PatternRewriter,
        /,
    ):
        # Skip ops whose result type already matches the operand type. Without
        # this guard the op is rebuilt on every visit, including the freshly
        # inserted replacement, and the rewrite always reports a change.
        if tuple(op.result_types) == (op.op.type,):
            return
        rewriter.replace_op(
            op,
            type(op).build(
                operands=op.operands,
                properties=op.properties,
                attributes=op.attributes,
                result_types=[op.op.type],
            ),
        )

match_and_rewrite(op: csl.IncrementDsdOffsetOp | csl.SetDsdStrideOp | csl.SetDsdLengthOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/memref_to_dsd.py
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
@op_type_rewrite_pattern
def match_and_rewrite(
    self,
    op: csl.IncrementDsdOffsetOp | csl.SetDsdStrideOp | csl.SetDsdLengthOp,
    rewriter: PatternRewriter,
    /,
):
    """Rebuild a DSD update op so its result type equals its `op` operand's type."""
    # Skip ops whose result type already matches the operand type. Without
    # this guard the op is rebuilt on every visit, including the freshly
    # inserted replacement, and the rewrite always reports a change.
    if tuple(op.result_types) == (op.op.type,):
        return
    rewriter.replace_op(
        op,
        type(op).build(
            operands=op.operands,
            properties=op.properties,
            attributes=op.attributes,
            result_types=[op.op.type],
        ),
    )

RetainAddressOfOpPass

Bases: RewritePattern

Ensure we don't export DSD but the underlying memref.

Source code in xdsl/transforms/memref_to_dsd.py
351
352
353
354
355
356
357
358
359
360
361
362
363
364
class RetainAddressOfOpPass(RewritePattern):
    """Ensure we don't export DSD but the underlying memref.

    An `AddressOfOp` taken on a DSD produced by a `GetMemDsdOp` is rebuilt to
    take the address of that GetMemDsdOp's base address instead.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: csl.AddressOfOp, rewriter: PatternRewriter, /):
        target = op.value
        if not isinstance(target.type, csl.DsdType):
            return
        producer = target.owner
        if not isinstance(producer, csl.GetMemDsdOp):
            return
        rewriter.replace_op(
            op,
            csl.AddressOfOp.build(
                operands=[producer.base_addr], result_types=op.result_types
            ),
        )

match_and_rewrite(op: csl.AddressOfOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/memref_to_dsd.py
354
355
356
357
358
359
360
361
362
363
364
@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl.AddressOfOp, rewriter: PatternRewriter, /):
    """Take the address of a GetMemDsdOp's base buffer rather than of its DSD."""
    target = op.value
    if not isinstance(target.type, csl.DsdType):
        return
    producer = target.owner
    if not isinstance(producer, csl.GetMemDsdOp):
        return
    rewriter.replace_op(
        op,
        csl.AddressOfOp.build(
            operands=[producer.base_addr], result_types=op.result_types
        ),
    )

CslVarUpdate

Bases: RewritePattern

Update CSL Variable Definitions.

Source code in xdsl/transforms/memref_to_dsd.py
367
368
369
370
371
372
373
374
375
376
377
class CslVarUpdate(RewritePattern):
    """Update CSL Variable Definitions.

    A `VariableOp` with no default whose element type is a memref is replaced
    by a variable of DSD type: a 1d DSD for rank-1 memrefs, otherwise a 4d DSD.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: csl.VariableOp, rewriter: PatternRewriter, /):
        elem_t = op.res.type.get_element_type()
        if not isa(elem_t, MemRefType):
            return
        if op.default:
            return
        if len(elem_t.shape) == 1:
            kind = csl.DsdKind.mem1d_dsd
        else:
            kind = csl.DsdKind.mem4d_dsd
        rewriter.replace_op(op, csl.VariableOp.from_type(csl.DsdType(kind)))

match_and_rewrite(op: csl.VariableOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/memref_to_dsd.py
370
371
372
373
374
375
376
377
@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl.VariableOp, rewriter: PatternRewriter, /):
    """Retype a default-less memref variable as a 1d or 4d DSD variable."""
    elem_t = op.res.type.get_element_type()
    if not isa(elem_t, MemRefType):
        return
    if op.default:
        return
    if len(elem_t.shape) == 1:
        kind = csl.DsdKind.mem1d_dsd
    else:
        kind = csl.DsdKind.mem4d_dsd
    rewriter.replace_op(op, csl.VariableOp.from_type(csl.DsdType(kind)))

CslVarLoad

Bases: RewritePattern

Update CSL Load Variables.

Source code in xdsl/transforms/memref_to_dsd.py
380
381
382
383
384
385
386
387
388
389
390
391
class CslVarLoad(RewritePattern):
    """Update CSL Load Variables.

    A `LoadVarOp` that still yields a memref, while reading a variable whose
    element type is already a DSD, is rebuilt from the variable alone. The
    rebuilt op presumably derives its result type from the variable.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: csl.LoadVarOp, rewriter: PatternRewriter, /):
        var_t = op.var.type
        needs_update = (
            isa(op.res.type, MemRefType)
            and isinstance(var_t, csl.VarType)
            and isa(var_t.get_element_type(), csl.DsdType)
        )
        if needs_update:
            rewriter.replace_op(op, csl.LoadVarOp(op.var))

match_and_rewrite(op: csl.LoadVarOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/memref_to_dsd.py
383
384
385
386
387
388
389
390
391
@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl.LoadVarOp, rewriter: PatternRewriter, /):
    """Rebuild a memref-typed load from a DSD variable so it yields the DSD."""
    var_t = op.var.type
    needs_update = (
        isa(op.res.type, MemRefType)
        and isinstance(var_t, csl.VarType)
        and isa(var_t.get_element_type(), csl.DsdType)
    )
    if needs_update:
        rewriter.replace_op(op, csl.LoadVarOp(op.var))

MemRefToDsdPass dataclass

Bases: ModulePass

Lowers memref ops to CSL DSDs.

Note that CSL uses memref types in some places.

This performs a backward pass that translates memref-consuming ops to dsd-consuming ops while all memref type information is still known. A second, forward pass then translates memref-generating ops to dsd-generating ops.

Source code in xdsl/transforms/memref_to_dsd.py
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
@dataclass(frozen=True)
class MemRefToDsdPass(ModulePass):
    """
    Lowers memref ops to CSL DSDs.

    Note that CSL uses memref types in some places.

    A backward walk first lowers memref-consuming ops (subview, copy,
    DSD-to-memref casts) while all memref type information is still available.
    A forward walk then lowers memref-producing ops (allocations, variables and
    their loads) to DSD-producing ops and repairs ops left consuming DSDs.
    """

    name = "memref-to-dsd"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        # Backward walk over memref consumers.
        backward_patterns = GreedyRewritePatternApplier(
            [
                LowerSubviewOpPass(),
                LowerCopyOpPass(),
                LowerUnrealizedConversionCastOpPass(),
            ]
        )
        backward_walker = PatternRewriteWalker(
            backward_patterns, walk_reverse=True, apply_recursively=False
        )
        backward_walker.rewrite_module(op)

        # Forward walk over memref producers, followed by type/operand fix-ups.
        forward_patterns = GreedyRewritePatternApplier(
            [
                CslVarUpdate(),
                CslVarLoad(),
                LowerAllocOpPass(),
                DsdOpUpdateType(),
                RetainAddressOfOpPass(),
                FixMemRefLoadOnGetDsd(),
                FixGetDsdOnGetDsd(),
            ]
        )
        forward_walker = PatternRewriteWalker(
            forward_patterns, apply_recursively=False
        )
        forward_walker.rewrite_module(op)

name = 'memref-to-dsd' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/memref_to_dsd.py
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Run the backward consumer-lowering walk, then the forward producer walk."""
    # Backward walk over memref consumers.
    backward_patterns = GreedyRewritePatternApplier(
        [
            LowerSubviewOpPass(),
            LowerCopyOpPass(),
            LowerUnrealizedConversionCastOpPass(),
        ]
    )
    backward_walker = PatternRewriteWalker(
        backward_patterns, walk_reverse=True, apply_recursively=False
    )
    backward_walker.rewrite_module(op)

    # Forward walk over memref producers, followed by type/operand fix-ups.
    forward_patterns = GreedyRewritePatternApplier(
        [
            CslVarUpdate(),
            CslVarLoad(),
            LowerAllocOpPass(),
            DsdOpUpdateType(),
            RetainAddressOfOpPass(),
            FixMemRefLoadOnGetDsd(),
            FixGetDsdOnGetDsd(),
        ]
    )
    forward_walker = PatternRewriteWalker(forward_patterns, apply_recursively=False)
    forward_walker.rewrite_module(op)