Skip to content

Convert stencil to ll mlir

convert_stencil_to_ll_mlir

CastOpToMemRef dataclass

Bases: RewritePattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
73
74
75
76
77
78
79
80
81
82
83
84
@dataclass
class CastOpToMemRef(RewritePattern):
    """Rewrite a stencil `CastOp` into a `memref.CastOp` on the same operand."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: CastOp, rewriter: PatternRewriter, /):
        """Replace `op` with a memref cast to the memref form of its result type."""
        field_type = op.result.type
        assert isa(field_type, FieldType[Attribute])
        assert isinstance(field_type.bounds, StencilBoundsAttr)

        # The cast target is the memref type derived from the stencil field type.
        memref_cast = memref.CastOp.get(op.field, StencilToMemRefType(field_type))
        rewriter.replace_op(op, memref_cast)

__init__() -> None

match_and_rewrite(op: CastOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
75
76
77
78
79
80
81
82
83
84
@op_type_rewrite_pattern
def match_and_rewrite(self, op: CastOp, rewriter: PatternRewriter, /):
    """Replace `op` with a memref cast to the memref form of its result type."""
    field_type = op.result.type
    assert isa(field_type, FieldType[Attribute])
    assert isinstance(field_type.bounds, StencilBoundsAttr)

    # The cast target is the memref type derived from the stencil field type.
    memref_cast = memref.CastOp.get(op.field, StencilToMemRefType(field_type))
    rewriter.replace_op(op, memref_cast)

ReturnOpToMemRef dataclass

Bases: RewritePattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
@dataclass
class ReturnOpToMemRef(RewritePattern):
    """
    Lower a stencil `ReturnOp` into `memref` stores on the targets of its parent
    `ApplyOp`, then erase the return op.
    """

    # Store targets per apply, indexed by result position. A `None` entry means no
    # store is emitted for that result.
    return_target: dict[ApplyOp, list[SSAValue | None]]

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ReturnOp, rewriter: PatternRewriter, /):
        """
        Emit the index computations and stores for every value returned by `op`.

        The parent of `op` must be an `ApplyOp` (asserted). When the apply has
        results, targets are read from `self.return_target`; otherwise a subview of
        each destination operand is created right before the apply.
        """
        # `op.arg` holds `unroll_factor` consecutive values per result.
        unroll_factor = op.unroll_factor
        n_res = len(op.arg) // unroll_factor

        # Ops queued here are inserted in one go at the end of the rewrite.
        store_list: list[Operation] = []
        apply = op.parent_op()
        assert isinstance(apply, ApplyOp)
        if op.unroll is not None:
            # Keep the unroll factors on the apply; ApplyOpToParallel reads the
            # `__unroll__` attribute to pick its loop steps.
            apply.attributes["__unroll__"] = op.unroll

        n_dims = apply.get_rank()

        for j in range(n_res):
            if len(apply.res) > 0:
                target = self.return_target[apply][j]
                # Insert pending subview if it was created but not yet inserted
                # (StencilStoreToSubview defers insertion to avoid dead code elimination)
                if (
                    target is not None
                    and isinstance(target.owner, memref.SubviewOp)
                    and target.owner.parent_block() is None
                ):
                    subview_op = target.owner
                    field = subview_op.source
                    # Place the subview right after the op defining its source, or at
                    # the start of the owning block when the source is not an op result.
                    if isinstance(field.owner, Operation):
                        rewriter.insert_op(subview_op, InsertPoint.after(field.owner))
                    else:
                        rewriter.insert_op(
                            subview_op, InsertPoint.at_start(field.owner)
                        )
            else:
                # Apply without results: store into the j-th destination operand.
                # NOTE(review): `field_subview` is defined elsewhere; presumably it
                # builds a subview op of the field - confirm.
                target = apply.dest[j]
                rewriter.insert_op(
                    subview := field_subview(target), InsertPoint.before(apply)
                )
                target = subview

            unroll = op.unroll
            if unroll is None:
                # No explicit unrolling: a factor of 1 in every dimension, i.e. a
                # single position per result.
                unroll = IndexAttr.get(*([1] * n_dims))

            # Enumerate every position of the unrolled block; `k` selects the operand
            # that belongs to this position.
            for k, offset in enumerate(product(*(range(u) for u in unroll))):
                arg = op.arg[j * unroll_factor + k]
                # One IndexOp per dimension, each with an all-zero offset.
                index_ops: list[Operation] = list(
                    IndexOp(
                        attributes={
                            "dim": builtin.IntegerAttr.from_index_int_value(i),
                            "offset": IndexAttr.get(*([0] * n_dims)),
                        },
                        result_types=[builtin.IndexType()],
                    )
                    for i in range(n_dims)
                )
                store_list += index_ops

                # Shift the index by the position inside the unrolled block wherever
                # that position is non-zero.
                for i in range(n_dims):
                    if offset[i] != 0:
                        constant_op = arith.ConstantOp.from_int_and_width(
                            offset[i], builtin.IndexType()
                        )
                        add_op = arith.AddiOp(index_ops[i], constant_op)
                        index_ops[i] = add_op
                        store_list.append(constant_op)
                        store_list.append(add_op)

                if isinstance(arg.type, ResultType):
                    # NOTE(review): `_find_result_store` is defined elsewhere;
                    # presumably it yields the ops producing this stencil result -
                    # confirm.
                    result_owner = _find_result_store(arg)
                    for owner in result_owner:
                        if owner.arg:
                            # The producer carries a value: replace it with a store
                            # (or with no ops when there is no target) and forward
                            # the stored value as its result.
                            if target is not None:
                                store = memref.StoreOp.get(owner.arg, target, index_ops)
                            else:
                                store = list[Operation]()
                            rewriter.replace_op(
                                owner,
                                store,
                                new_results=[owner.arg],
                            )
                        else:
                            # No value to store: replace the producer with an
                            # operand-less cast yielding the element type.
                            dummy = UnrealizedConversionCastOp.get([], [arg.type.elem])
                            rewriter.replace_op(owner, dummy)

                else:
                    # Plain value: queue a store to the target, if there is one.
                    if target is not None:
                        store = memref.StoreOp.get(arg, target, index_ops)
                        store_list.append(store)

        # Insert the queued ops at the rewriter's default insertion point, then drop
        # the return op.
        rewriter.insert_op(store_list)
        rewriter.erase_op(op)

return_target: dict[ApplyOp, list[SSAValue | None]] instance-attribute

__init__(return_target: dict[ApplyOp, list[SSAValue | None]]) -> None

match_and_rewrite(op: ReturnOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ReturnOp, rewriter: PatternRewriter, /):
    """
    Emit the index computations and stores for every value returned by `op`.

    The parent of `op` must be an `ApplyOp` (asserted). When the apply has
    results, targets are read from `self.return_target`; otherwise a subview of
    each destination operand is created right before the apply.
    """
    # `op.arg` holds `unroll_factor` consecutive values per result.
    unroll_factor = op.unroll_factor
    n_res = len(op.arg) // unroll_factor

    # Ops queued here are inserted in one go at the end of the rewrite.
    store_list: list[Operation] = []
    apply = op.parent_op()
    assert isinstance(apply, ApplyOp)
    if op.unroll is not None:
        # Keep the unroll factors on the apply; ApplyOpToParallel reads the
        # `__unroll__` attribute to pick its loop steps.
        apply.attributes["__unroll__"] = op.unroll

    n_dims = apply.get_rank()

    for j in range(n_res):
        if len(apply.res) > 0:
            target = self.return_target[apply][j]
            # Insert pending subview if it was created but not yet inserted
            # (StencilStoreToSubview defers insertion to avoid dead code elimination)
            if (
                target is not None
                and isinstance(target.owner, memref.SubviewOp)
                and target.owner.parent_block() is None
            ):
                subview_op = target.owner
                field = subview_op.source
                # Place the subview right after the op defining its source, or at
                # the start of the owning block when the source is not an op result.
                if isinstance(field.owner, Operation):
                    rewriter.insert_op(subview_op, InsertPoint.after(field.owner))
                else:
                    rewriter.insert_op(
                        subview_op, InsertPoint.at_start(field.owner)
                    )
        else:
            # Apply without results: store into the j-th destination operand.
            # NOTE(review): `field_subview` is defined elsewhere; presumably it
            # builds a subview op of the field - confirm.
            target = apply.dest[j]
            rewriter.insert_op(
                subview := field_subview(target), InsertPoint.before(apply)
            )
            target = subview

        unroll = op.unroll
        if unroll is None:
            # No explicit unrolling: a factor of 1 in every dimension, i.e. a
            # single position per result.
            unroll = IndexAttr.get(*([1] * n_dims))

        # Enumerate every position of the unrolled block; `k` selects the operand
        # that belongs to this position.
        for k, offset in enumerate(product(*(range(u) for u in unroll))):
            arg = op.arg[j * unroll_factor + k]
            # One IndexOp per dimension, each with an all-zero offset.
            index_ops: list[Operation] = list(
                IndexOp(
                    attributes={
                        "dim": builtin.IntegerAttr.from_index_int_value(i),
                        "offset": IndexAttr.get(*([0] * n_dims)),
                    },
                    result_types=[builtin.IndexType()],
                )
                for i in range(n_dims)
            )
            store_list += index_ops

            # Shift the index by the position inside the unrolled block wherever
            # that position is non-zero.
            for i in range(n_dims):
                if offset[i] != 0:
                    constant_op = arith.ConstantOp.from_int_and_width(
                        offset[i], builtin.IndexType()
                    )
                    add_op = arith.AddiOp(index_ops[i], constant_op)
                    index_ops[i] = add_op
                    store_list.append(constant_op)
                    store_list.append(add_op)

            if isinstance(arg.type, ResultType):
                # NOTE(review): `_find_result_store` is defined elsewhere;
                # presumably it yields the ops producing this stencil result -
                # confirm.
                result_owner = _find_result_store(arg)
                for owner in result_owner:
                    if owner.arg:
                        # The producer carries a value: replace it with a store
                        # (or with no ops when there is no target) and forward
                        # the stored value as its result.
                        if target is not None:
                            store = memref.StoreOp.get(owner.arg, target, index_ops)
                        else:
                            store = list[Operation]()
                        rewriter.replace_op(
                            owner,
                            store,
                            new_results=[owner.arg],
                        )
                    else:
                        # No value to store: replace the producer with an
                        # operand-less cast yielding the element type.
                        dummy = UnrealizedConversionCastOp.get([], [arg.type.elem])
                        rewriter.replace_op(owner, dummy)

            else:
                # Plain value: queue a store to the target, if there is one.
                if target is not None:
                    store = memref.StoreOp.get(arg, target, index_ops)
                    store_list.append(store)

    # Insert the queued ops at the rewriter's default insertion point, then drop
    # the return op.
    rewriter.insert_op(store_list)
    rewriter.erase_op(op)

LoadOpToMemRef

Bases: RewritePattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
class LoadOpToMemRef(RewritePattern):
    """Rewrite a stencil `LoadOp` into a static `memref.SubviewOp` of its field."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: LoadOp, rewriter: PatternRewriter, /):
        """
        Replace `op` with a unit-stride subview of the loaded field.

        Raises:
            VerifyException: if the loaded field is also used by a stencil `StoreOp`.
        """
        if any(isa(use.operation, StoreOp) for use in op.field.uses):
            raise VerifyException(
                "Cannot lower directly if loading and storing the same field! Try running `stencil-bufferize` before."
            )

        field_type = op.field.type
        assert isa(field_type, FieldType[Attribute])
        assert isa(field_type.bounds, StencilBoundsAttr)
        temp_type = op.res.type
        assert isa(temp_type, TempType[Attribute])
        assert isa(temp_type.bounds, StencilBoundsAttr)

        assert_subset(field_type, temp_type)

        # The view starts at the negated lower bound of the field and spans the
        # shape of the loaded temp, with unit stride in every dimension.
        view_offsets = list(-field_type.bounds.lb)
        view_sizes = list(temp_type.get_shape())
        view = memref.SubviewOp.from_static_parameters(
            op.field,
            StencilToMemRefType(field_type),
            view_offsets,
            view_sizes,
            [1] * len(view_sizes),
        )

        rewriter.replace_op(op, view)

        # Name the view after its source; without a source hint the view's hint is
        # reset to None.
        source_hint = view.source.name_hint
        view.result.name_hint = source_hint + "_loadview" if source_hint else None

match_and_rewrite(op: LoadOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
@op_type_rewrite_pattern
def match_and_rewrite(self, op: LoadOp, rewriter: PatternRewriter, /):
    """
    Replace `op` with a unit-stride subview of the loaded field.

    Raises:
        VerifyException: if the loaded field is also used by a stencil `StoreOp`.
    """
    if any(isa(use.operation, StoreOp) for use in op.field.uses):
        raise VerifyException(
            "Cannot lower directly if loading and storing the same field! Try running `stencil-bufferize` before."
        )

    field_type = op.field.type
    assert isa(field_type, FieldType[Attribute])
    assert isa(field_type.bounds, StencilBoundsAttr)
    temp_type = op.res.type
    assert isa(temp_type, TempType[Attribute])
    assert isa(temp_type.bounds, StencilBoundsAttr)

    assert_subset(field_type, temp_type)

    # The view starts at the negated lower bound of the field and spans the
    # shape of the loaded temp, with unit stride in every dimension.
    view_offsets = list(-field_type.bounds.lb)
    view_sizes = list(temp_type.get_shape())
    view = memref.SubviewOp.from_static_parameters(
        op.field,
        StencilToMemRefType(field_type),
        view_offsets,
        view_sizes,
        [1] * len(view_sizes),
    )

    rewriter.replace_op(op, view)

    # Name the view after its source; without a source hint the view's hint is
    # reset to None.
    source_hint = view.source.name_hint
    view.result.name_hint = source_hint + "_loadview" if source_hint else None

BufferOpToMemRef dataclass

Bases: RewritePattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
@dataclass
class BufferOpToMemRef(RewritePattern):
    """Lower a stencil `BufferOp` to a `memref.alloc` / `memref.dealloc` pair."""

    # Store targets per apply; updated so that the buffered temp maps to the
    # freshly allocated memref.
    return_targets: dict[ApplyOp, list[SSAValue | None]]

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter, /):
        """Allocate a memref for the buffered temp and substitute it for `op`."""
        # Strategy: allocate at the start of the enclosing block and deallocate at
        # its end. Aiming precisely at the first and last use would be possible, but
        # it needs more code once stencil.combine is in the mix. It is also unclear
        # that being that smart pays off: the result is iterated over either way, so
        # (de)allocations are probably meant to be hoisted out of the loop anyway.
        buffered_type = op.temp.type
        assert isa(buffered_type, TempType[Attribute])
        bounds = buffered_type.bounds
        assert isa(bounds, StencilBoundsAttr)

        enclosing_block = op.parent_block()
        assert enclosing_block is not None
        block_head = enclosing_block.first_op
        assert block_head is not None
        block_tail = enclosing_block.last_op
        assert block_tail is not None

        # Row-major strides, plus a layout offset equal to the negated dot product
        # of the lower bounds with those strides.
        shape = buffered_type.get_shape()
        strides = [prod(shape[dim:]) for dim in range(1, len(shape) + 1)]
        offset = -sum(
            low * stride for low, stride in zip(bounds.lb, strides, strict=True)
        )

        alloc = memref.AllocOp.get(
            buffered_type.get_element_type(),
            shape=shape,
            layout=memref.StridedLayoutAttr(strides, offset),
        )
        assert isa(alloc.memref.type, MemRefType)

        rewriter.insert_op(alloc, InsertPoint.before(block_head))

        update_return_target(self.return_targets, op.temp, alloc.memref)

        dealloc = memref.DeallocOp.get(alloc.memref)

        if op.res.uses:
            # The buffer is consumed later: free it just before the last op of the
            # block and forward the allocation as the buffer's value.
            rewriter.insert_op(dealloc, InsertPoint.before(block_tail))
            rewriter.replace_op(op, [], [alloc.memref])
        else:
            # Nobody reads the buffer: free it right where the buffer op stood.
            rewriter.insert_op(dealloc, InsertPoint.after(op))
            rewriter.erase_op(op)

return_targets: dict[ApplyOp, list[SSAValue | None]] instance-attribute

__init__(return_targets: dict[ApplyOp, list[SSAValue | None]]) -> None

match_and_rewrite(op: BufferOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
@op_type_rewrite_pattern
def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter, /):
    """Allocate a memref for the buffered temp and substitute it for `op`."""
    # Strategy: allocate at the start of the enclosing block and deallocate at
    # its end. Aiming precisely at the first and last use would be possible, but
    # it needs more code once stencil.combine is in the mix. It is also unclear
    # that being that smart pays off: the result is iterated over either way, so
    # (de)allocations are probably meant to be hoisted out of the loop anyway.
    buffered_type = op.temp.type
    assert isa(buffered_type, TempType[Attribute])
    bounds = buffered_type.bounds
    assert isa(bounds, StencilBoundsAttr)

    enclosing_block = op.parent_block()
    assert enclosing_block is not None
    block_head = enclosing_block.first_op
    assert block_head is not None
    block_tail = enclosing_block.last_op
    assert block_tail is not None

    # Row-major strides, plus a layout offset equal to the negated dot product
    # of the lower bounds with those strides.
    shape = buffered_type.get_shape()
    strides = [prod(shape[dim:]) for dim in range(1, len(shape) + 1)]
    offset = -sum(
        low * stride for low, stride in zip(bounds.lb, strides, strict=True)
    )

    alloc = memref.AllocOp.get(
        buffered_type.get_element_type(),
        shape=shape,
        layout=memref.StridedLayoutAttr(strides, offset),
    )
    assert isa(alloc.memref.type, MemRefType)

    rewriter.insert_op(alloc, InsertPoint.before(block_head))

    update_return_target(self.return_targets, op.temp, alloc.memref)

    dealloc = memref.DeallocOp.get(alloc.memref)

    if op.res.uses:
        # The buffer is consumed later: free it just before the last op of the
        # block and forward the allocation as the buffer's value.
        rewriter.insert_op(dealloc, InsertPoint.before(block_tail))
        rewriter.replace_op(op, [], [alloc.memref])
    else:
        # Nobody reads the buffer: free it right where the buffer op stood.
        rewriter.insert_op(dealloc, InsertPoint.after(op))
        rewriter.erase_op(op)

AllocOpToMemRef

Bases: RewritePattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
358
359
360
361
362
363
364
class AllocOpToMemRef(RewritePattern):
    """Rewrite a stencil `AllocOp` into a `memref.AllocOp` of the matching type."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: AllocOp, rewriter: PatternRewriter, /):
        """Replace `op` with a memref allocation built with two empty operand lists."""
        # The cast only informs the type checker; nothing is verified at runtime.
        stencil_type = cast(StencilType[Attribute], op.field.type)
        memref_type = StencilToMemRefType(stencil_type)
        rewriter.replace_op(op, memref.AllocOp([], [], memref_type))

match_and_rewrite(op: AllocOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
359
360
361
362
363
364
@op_type_rewrite_pattern
def match_and_rewrite(self, op: AllocOp, rewriter: PatternRewriter, /):
    """Replace `op` with a memref allocation built with two empty operand lists."""
    # The cast only informs the type checker; nothing is verified at runtime.
    stencil_type = cast(StencilType[Attribute], op.field.type)
    memref_type = StencilToMemRefType(stencil_type)
    rewriter.replace_op(op, memref.AllocOp([], [], memref_type))

ApplyOpFieldSubviews dataclass

Bases: RewritePattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
@dataclass
class ApplyOpFieldSubviews(RewritePattern):
    """Feed field-typed `ApplyOp` arguments through `field_subview`."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
        """
        Rebuild `op` so that every field-typed argument is replaced by the result of
        `field_subview`; applies without such arguments are left untouched.
        """
        rewritten_args = [
            field_subview(operand) if isinstance(operand.type, FieldType) else operand
            for operand in op.args
        ]
        # Nothing was substituted, so there is nothing to rewrite.
        if rewritten_args == list(op.args):
            return

        new_operands = [SSAValue.get(operand) for operand in rewritten_args]
        new_operands += list(op.dest)
        result_types = [result.type for result in op.res]
        # The body region is moved over to the new apply rather than copied.
        body_region = op.detach_region(0)
        new_apply = ApplyOp.create(
            operands=new_operands,
            result_types=result_types,
            regions=[body_region],
            attributes=op.attributes,
            properties=op.properties,
        )

        # The newly created ops must precede the apply that consumes them.
        created_ops = [
            operand for operand in rewritten_args if isinstance(operand, Operation)
        ]
        rewriter.replace_op(op, [*created_ops, new_apply])

__init__() -> None

match_and_rewrite(op: ApplyOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
    """
    Rebuild `op` so that every field-typed argument is replaced by the result of
    `field_subview`; applies without such arguments are left untouched.
    """
    rewritten_args = [
        field_subview(operand) if isinstance(operand.type, FieldType) else operand
        for operand in op.args
    ]
    # Nothing was substituted, so there is nothing to rewrite.
    if rewritten_args == list(op.args):
        return

    new_operands = [SSAValue.get(operand) for operand in rewritten_args]
    new_operands += list(op.dest)
    result_types = [result.type for result in op.res]
    # The body region is moved over to the new apply rather than copied.
    body_region = op.detach_region(0)
    new_apply = ApplyOp.create(
        operands=new_operands,
        result_types=result_types,
        regions=[body_region],
        attributes=op.attributes,
        properties=op.properties,
    )

    # The newly created ops must precede the apply that consumes them.
    created_ops = [
        operand for operand in rewritten_args if isinstance(operand, Operation)
    ]
    rewriter.replace_op(op, [*created_ops, new_apply])

ApplyOpToParallel dataclass

Bases: RewritePattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
@dataclass
class ApplyOpToParallel(RewritePattern):
    """
    Lower a stencil `ApplyOp` to a single `scf.parallel` op.

    The iteration space comes from the bounds of the first result's type, or from
    the apply's own `bounds` when it has no results. The step in each dimension is
    the unroll factor stored in the `__unroll__` attribute (1 when absent).
    """

    # Values substituted for each apply's results once it has been lowered.
    return_targets: dict[ApplyOp, list[SSAValue | None]]

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
        """
        Replace `op` with index constants and an `scf.parallel` wrapping its body.

        Every stencil `IndexOp` found in the body is replaced by the loop block
        argument of its dimension, plus a constant when its offset in that
        dimension is non-zero.
        """
        if len(op.res) > 0:
            res_type = op.res[0].type
            assert isa(res_type, TempType[Attribute])
            assert isinstance(res_type.bounds, StencilBoundsAttr)
            lb = res_type.bounds.lb
            ub = res_type.bounds.ub
        else:
            assert op.bounds is not None
            lb = op.bounds.lb
            ub = op.bounds.ub

        # Unroll factors, when present, were recorded on the apply as `__unroll__`.
        unroll_attr = op.attributes.get("__unroll__")
        if unroll_attr is not None:
            assert isinstance(unroll_attr, IndexAttr)

        rank = op.get_rank()
        body = prepare_apply_body(op)
        unroll = [1] * rank if unroll_attr is None else list(unroll_attr)

        # Constants feeding the scf.parallel: lower bounds, steps, upper bounds.
        lower_bounds = [
            arith.ConstantOp.from_int_and_width(x, builtin.IndexType()) for x in lb
        ]
        steps = [
            arith.ConstantOp.from_int_and_width(x, builtin.IndexType()) for x in unroll
        ]
        upper_bounds = [
            arith.ConstantOp.from_int_and_width(x, builtin.IndexType()) for x in ub
        ]
        boilerplate_ops: list[Operation] = [*lower_bounds, *steps, *upper_bounds]

        p = scf.ParallelOp(
            lower_bounds=lower_bounds,
            upper_bounds=upper_bounds,
            steps=steps,
            body=Region(body),
        )

        # Resolve stencil index ops against the loop's block arguments.
        for index in body.walk():
            if not isinstance(index, IndexOp):
                continue
            dim = index.dim.value.data
            shift = list(index.offset)[dim]
            induction_var = body.args[dim]
            if shift != 0:
                cst = arith.ConstantOp.from_int_and_width(shift, builtin.IndexType())
                add = arith.AddiOp(induction_var, cst)
                rewriter.replace_op(index, [cst, add], [add.result])
            else:
                rewriter.replace_op(index, [], [induction_var])

        # Results may have been redirected to other values by earlier rewrites.
        new_results = self.return_targets.get(op, [])
        # Emit the constants and the loop, then substitute the apply's results.
        rewriter.insert_op([*boilerplate_ops, p])
        rewriter.replace_op(op, [], new_results)

return_targets: dict[ApplyOp, list[SSAValue | None]] instance-attribute

__init__(return_targets: dict[ApplyOp, list[SSAValue | None]]) -> None

match_and_rewrite(op: ApplyOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
    """
    Replace `op` with index constants and an `scf.parallel` wrapping its body.

    The iteration space comes from the bounds of the first result's type, or from
    the apply's own `bounds` when it has no results. The step in each dimension is
    the unroll factor stored in the `__unroll__` attribute (1 when absent). Every
    stencil `IndexOp` found in the body is replaced by the loop block argument of
    its dimension, plus a constant when its offset in that dimension is non-zero.
    """
    if len(op.res) > 0:
        res_type = op.res[0].type
        assert isa(res_type, TempType[Attribute])
        assert isinstance(res_type.bounds, StencilBoundsAttr)
        lb = res_type.bounds.lb
        ub = res_type.bounds.ub
    else:
        assert op.bounds is not None
        lb = op.bounds.lb
        ub = op.bounds.ub

    # Unroll factors, when present, were recorded on the apply as `__unroll__`.
    unroll_attr = op.attributes.get("__unroll__")
    if unroll_attr is not None:
        assert isinstance(unroll_attr, IndexAttr)

    rank = op.get_rank()
    body = prepare_apply_body(op)
    unroll = [1] * rank if unroll_attr is None else list(unroll_attr)

    # Constants feeding the scf.parallel: lower bounds, steps, upper bounds.
    lower_bounds = [
        arith.ConstantOp.from_int_and_width(x, builtin.IndexType()) for x in lb
    ]
    steps = [
        arith.ConstantOp.from_int_and_width(x, builtin.IndexType()) for x in unroll
    ]
    upper_bounds = [
        arith.ConstantOp.from_int_and_width(x, builtin.IndexType()) for x in ub
    ]
    boilerplate_ops: list[Operation] = [*lower_bounds, *steps, *upper_bounds]

    p = scf.ParallelOp(
        lower_bounds=lower_bounds,
        upper_bounds=upper_bounds,
        steps=steps,
        body=Region(body),
    )

    # Resolve stencil index ops against the loop's block arguments.
    for index in body.walk():
        if not isinstance(index, IndexOp):
            continue
        dim = index.dim.value.data
        shift = list(index.offset)[dim]
        induction_var = body.args[dim]
        if shift != 0:
            cst = arith.ConstantOp.from_int_and_width(shift, builtin.IndexType())
            add = arith.AddiOp(induction_var, cst)
            rewriter.replace_op(index, [cst, add], [add.result])
        else:
            rewriter.replace_op(index, [], [induction_var])

    # Results may have been redirected to other values by earlier rewrites.
    new_results = self.return_targets.get(op, [])
    # Emit the constants and the loop, then substitute the apply's results.
    rewriter.insert_op([*boilerplate_ops, p])
    rewriter.replace_op(op, [], new_results)

AccessOpToMemRef dataclass

Bases: RewritePattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
@dataclass
class AccessOpToMemRef(RewritePattern):
    """Rewrite a stencil `AccessOp` into a `memref.LoadOp` at an offset index."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /):
        """
        Replace `op` with a load from `op.temp`, indexed by one stencil `IndexOp`
        per accessed dimension plus the access offset where it is non-zero.
        """
        temp = op.temp.type
        assert StencilTypeConstr.verifies(temp)
        assert isinstance(temp.bounds, StencilBoundsAttr)

        access_offset = op.offset

        # Without an explicit mapping, the k-th offset applies to dimension k.
        if op.offset_mapping is not None:
            dims = op.offset_mapping
        else:
            dims = range(len(access_offset))

        index_ops = [
            IndexOp(
                attributes={
                    "dim": builtin.IntegerAttr.from_index_int_value(dim),
                    "offset": IndexAttr.get(*([0] * op.get_apply().get_rank())),
                },
                result_types=[builtin.IndexType()],
            )
            for dim in dims
        ]

        offset_ops: list[Operation] = []
        load_indices: list[BlockArgument | OpResult] = []

        for index_op, shift in zip(index_ops, access_offset):
            if shift == 0:
                # Zero offset: the index value is used as-is.
                load_indices.append(index_op.idx)
                continue
            # Non-zero offset: add it to the index as a constant.
            shift_cst = arith.ConstantOp.from_int_and_width(shift, builtin.IndexType())
            shifted = arith.AddiOp(index_op, shift_cst)
            load_indices.append(shifted.results[0])
            offset_ops += [shift_cst, shifted]

        load = memref.LoadOp(
            operands=[op.temp, load_indices], result_types=[temp.element_type]
        )

        # Index ops are inserted first so the offset arithmetic and the load that
        # follow can refer to them.
        rewriter.insert_op(index_ops)
        rewriter.replace_op(op, [*offset_ops, load], [load.res])

__init__() -> None

match_and_rewrite(op: AccessOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
@op_type_rewrite_pattern
def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /):
    """
    Replace `op` with a load from `op.temp`, indexed by one stencil `IndexOp`
    per accessed dimension plus the access offset where it is non-zero.
    """
    temp = op.temp.type
    assert StencilTypeConstr.verifies(temp)
    assert isinstance(temp.bounds, StencilBoundsAttr)

    access_offset = op.offset

    # Without an explicit mapping, the k-th offset applies to dimension k.
    if op.offset_mapping is not None:
        dims = op.offset_mapping
    else:
        dims = range(len(access_offset))

    index_ops = [
        IndexOp(
            attributes={
                "dim": builtin.IntegerAttr.from_index_int_value(dim),
                "offset": IndexAttr.get(*([0] * op.get_apply().get_rank())),
            },
            result_types=[builtin.IndexType()],
        )
        for dim in dims
    ]

    offset_ops: list[Operation] = []
    load_indices: list[BlockArgument | OpResult] = []

    for index_op, shift in zip(index_ops, access_offset):
        if shift == 0:
            # Zero offset: the index value is used as-is.
            load_indices.append(index_op.idx)
            continue
        # Non-zero offset: add it to the index as a constant.
        shift_cst = arith.ConstantOp.from_int_and_width(shift, builtin.IndexType())
        shifted = arith.AddiOp(index_op, shift_cst)
        load_indices.append(shifted.results[0])
        offset_ops += [shift_cst, shifted]

    load = memref.LoadOp(
        operands=[op.temp, load_indices], result_types=[temp.element_type]
    )

    # Index ops are inserted first so the offset arithmetic and the load that
    # follow can refer to them.
    rewriter.insert_op(index_ops)
    rewriter.replace_op(op, [*offset_ops, load], [load.res])

StencilStoreToSubview dataclass

Bases: RewritePattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
@dataclass
class StencilStoreToSubview(RewritePattern):
    """
    Erase a stencil `StoreOp` and register a subview of the stored-to field as the
    return target in its place.
    """

    # Store targets per apply; updated with the subview built for the field.
    return_targets: dict[ApplyOp, list[SSAValue | None]]

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter, /):
        """
        Build (without inserting) a subview of `op.field`, erase `op`, and record
        the subview result through `update_return_target`.

        Raises:
            VerifyException: if the field is also loaded, or stored to by another
                stencil `StoreOp`.
        """
        for use in op.field.uses:
            user = use.operation
            if isa(user, LoadOp):
                raise VerifyException(
                    "Cannot lower directly if loading and storing the same field! "
                    "Try running `stencil-bufferize` before."
                )
            if user is not op and isa(user, StoreOp):
                raise VerifyException(
                    "Cannot lower directly if storing to the same field multiple "
                    "times! Try running `stencil-bufferize` before."
                )

        field = op.field
        assert isa(field.type, FieldType[Attribute])
        assert isa(field.type.bounds, StencilBoundsAttr)
        temp = op.temp
        assert isa(temp.type, TempType[Attribute])

        # The view starts at the negated lower bound of the field and spans the
        # shape of the stored temp, with unit stride in every dimension.
        view_offsets = list(-field.type.bounds.lb)
        view_sizes = list(temp.type.get_shape())
        view = memref.SubviewOp.from_static_parameters(
            field,
            StencilToMemRefType(field.type),
            view_offsets,
            view_sizes,
            [1] * len(view_sizes),
        )
        source_hint = view.source.name_hint
        view.result.name_hint = source_hint + "_storeview" if source_hint else None

        # The subview is deliberately left uninserted: it has no users yet and could
        # be removed as dead code. ReturnOpToMemRef inserts it when it processes the
        # return op that needs this target.
        rewriter.erase_op(op)

        update_return_target(self.return_targets, field, view.result)

return_targets: dict[ApplyOp, list[SSAValue | None]] instance-attribute

__init__(return_targets: dict[ApplyOp, list[SSAValue | None]]) -> None

match_and_rewrite(op: StoreOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
@op_type_rewrite_pattern
def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter, /):
    """
    Erase the matched store and register a subview of its field as return target.

    Raises:
        VerifyException: if the field is also loaded from, or stored to by
            another store.
    """
    field = op.field

    # Direct lowering requires that the field is neither loaded from nor stored
    # to by any other store.
    for use in field.uses:
        user = use.operation
        if isa(user, LoadOp):
            raise VerifyException(
                "Cannot lower directly if loading and storing the same field! "
                "Try running `stencil-bufferize` before."
            )
        if isa(user, StoreOp) and user is not op:
            raise VerifyException(
                "Cannot lower directly if storing to the same field multiple "
                "times! Try running `stencil-bufferize` before."
            )

    field_type = field.type
    assert isa(field_type, FieldType[Attribute])
    assert isa(field_type.bounds, StencilBoundsAttr)
    stored = op.temp
    assert isa(stored.type, TempType[Attribute])

    # The view starts at the negated lower bound of the field and spans the
    # shape of the stored temp, with unit strides.
    view_offsets = list(-field_type.bounds.lb)
    view_sizes = list(stored.type.get_shape())
    subview = memref.SubviewOp.from_static_parameters(
        field,
        StencilToMemRefType(field_type),
        view_offsets,
        view_sizes,
        [1] * len(view_sizes),
    )

    # Derive the result name from the source name when one is set.
    source_hint = subview.source.name_hint
    subview.result.name_hint = (source_hint + "_storeview") if source_hint else None

    # The subview is deliberately not inserted here: insertion is deferred to
    # ReturnOpToMemRef so that the subview is not eliminated as dead code
    # before it gets users. It is inserted once ReturnOpToMemRef processes the
    # corresponding return operation and needs this target.

    rewriter.erase_op(op)

    update_return_target(self.return_targets, field, subview.result)

TrivialExternalLoadOpCleanup

Bases: RewritePattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
572
573
574
575
576
577
578
579
580
581
582
class TrivialExternalLoadOpCleanup(RewritePattern):
    """
    Retype the result of an `ExternalLoadOp` to a memref type, and replace the op
    by its `field` operand when both then have the same type.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ExternalLoadOp, rewriter: PatternRewriter, /):
        loaded = op.result
        assert isa(loaded.type, FieldType[Attribute])
        memref_type = StencilToMemRefType(loaded.type)
        rewriter.replace_value_with_new_type(loaded, memref_type)

        # `op.result` is read again on purpose: the comparison must see the type
        # after the retyping above.
        if op.field.type == op.result.type:
            rewriter.replace_op(op, [], [op.field])

match_and_rewrite(op: ExternalLoadOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
573
574
575
576
577
578
579
580
581
582
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ExternalLoadOp, rewriter: PatternRewriter, /):
    """
    Retype the op result to a memref type, and replace the op by its `field`
    operand when both then have the same type.
    """
    loaded = op.result
    assert isa(loaded.type, FieldType[Attribute])
    memref_type = StencilToMemRefType(loaded.type)
    rewriter.replace_value_with_new_type(loaded, memref_type)

    # `op.result` is read again on purpose: the comparison must see the type
    # after the retyping above.
    if op.field.type == op.result.type:
        rewriter.replace_op(op, [], [op.field])

TrivialExternalStoreOpCleanup

Bases: RewritePattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
585
586
587
588
class TrivialExternalStoreOpCleanup(RewritePattern):
    """
    Erase every matched `ExternalStoreOp`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ExternalStoreOp, rewriter: PatternRewriter, /):
        # The op is erased without providing any replacement.
        rewriter.erase_op(op)

match_and_rewrite(op: ExternalStoreOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
586
587
588
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ExternalStoreOp, rewriter: PatternRewriter, /):
    """
    Erase the matched `ExternalStoreOp` without providing any replacement.
    """
    rewriter.erase_op(op)

CombineOpCleanup

Bases: RewritePattern

Remove `stencil.combine` ops, as they are only used for return target analysis.

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
591
592
593
594
595
596
597
598
class CombineOpCleanup(RewritePattern):
    """
    Remove `stencil.combine` ops, as they are only used for return target analysis.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: CombineOp, rewriter: PatternRewriter, /):
        # The op is erased without providing any replacement.
        rewriter.erase_op(op)

match_and_rewrite(op: CombineOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
596
597
598
@op_type_rewrite_pattern
def match_and_rewrite(self, op: CombineOp, rewriter: PatternRewriter, /):
    """
    Erase the matched `CombineOp` without providing any replacement.
    """
    rewriter.erase_op(op)

StencilTypeConversion dataclass

Bases: TypeConversionPattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
673
674
675
676
class StencilTypeConversion(TypeConversionPattern):
    """
    Convert types matching `StencilTypeConstr` via `StencilToMemRefType`.
    """

    @attr_constr_rewrite_pattern(StencilTypeConstr)
    def convert_type(self, typ: StencilType[Attribute]) -> MemRefType:
        return StencilToMemRefType(typ)

convert_type(typ: StencilType[Attribute]) -> MemRefType

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
674
675
676
@attr_constr_rewrite_pattern(StencilTypeConstr)
def convert_type(self, typ: StencilType[Attribute]) -> MemRefType:
    """
    Return the result of `StencilToMemRefType` for the given stencil type.
    """
    return StencilToMemRefType(typ)

ResultTypeConversion dataclass

Bases: TypeConversionPattern

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
679
680
681
682
class ResultTypeConversion(TypeConversionPattern):
    """
    Convert a `ResultType` to its `elem` attribute.
    """

    @attr_type_rewrite_pattern
    def convert_type(self, typ: ResultType) -> Attribute:
        return typ.elem

convert_type(typ: ResultType) -> Attribute

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
680
681
682
@attr_type_rewrite_pattern
def convert_type(self, typ: ResultType) -> Attribute:
    """
    Return the `elem` attribute of the given `ResultType`.
    """
    return typ.elem

ConvertStencilToLLMLIRPass dataclass

Bases: ModulePass

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
@dataclass(frozen=True)
class ConvertStencilToLLMLIRPass(ModulePass):
    """
    Module pass that applies the stencil lowering patterns and then the type
    conversion patterns to a module.
    """

    name = "convert-stencil-to-ll-mlir"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # The return targets are computed once up front and shared by every
        # pattern that receives them.
        return_targets: dict[ApplyOp, list[SSAValue | None]] = return_target_analysis(
            op
        )

        lowering_patterns = GreedyRewritePatternApplier(
            [
                ApplyOpFieldSubviews(),
                ApplyOpToParallel(return_targets),
                BufferOpToMemRef(return_targets),
                StencilStoreToSubview(return_targets),
                CastOpToMemRef(),
                LoadOpToMemRef(),
                AccessOpToMemRef(),
                ReturnOpToMemRef(return_targets),
                TrivialExternalLoadOpCleanup(),
                TrivialExternalStoreOpCleanup(),
                AllocOpToMemRef(),
            ]
        )
        PatternRewriteWalker(
            lowering_patterns,
            apply_recursively=True,
            walk_reverse=True,
            walk_regions_first=True,
        ).rewrite_module(op)

        # The type conversions run as a second walk, after the lowering walk.
        type_patterns = GreedyRewritePatternApplier(
            [
                CombineOpCleanup(),
                StencilTypeConversion(recursive=True),
                ResultTypeConversion(recursive=True),
            ]
        )
        PatternRewriteWalker(type_patterns, walk_reverse=True).rewrite_module(op)

name = 'convert-stencil-to-ll-mlir' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """
    Apply the stencil lowering patterns, then the type conversion patterns, to `op`.
    """
    # The return targets are computed once up front and shared by every pattern
    # that receives them.
    return_targets: dict[ApplyOp, list[SSAValue | None]] = return_target_analysis(
        op
    )

    lowering_patterns = GreedyRewritePatternApplier(
        [
            ApplyOpFieldSubviews(),
            ApplyOpToParallel(return_targets),
            BufferOpToMemRef(return_targets),
            StencilStoreToSubview(return_targets),
            CastOpToMemRef(),
            LoadOpToMemRef(),
            AccessOpToMemRef(),
            ReturnOpToMemRef(return_targets),
            TrivialExternalLoadOpCleanup(),
            TrivialExternalStoreOpCleanup(),
            AllocOpToMemRef(),
        ]
    )
    PatternRewriteWalker(
        lowering_patterns,
        apply_recursively=True,
        walk_reverse=True,
        walk_regions_first=True,
    ).rewrite_module(op)

    # The type conversions run as a second walk, after the lowering walk.
    type_patterns = GreedyRewritePatternApplier(
        [
            CombineOpCleanup(),
            StencilTypeConversion(recursive=True),
            ResultTypeConversion(recursive=True),
        ]
    )
    PatternRewriteWalker(type_patterns, walk_reverse=True).rewrite_module(op)

StencilToMemRefType(input_type: StencilType[_TypeElement]) -> MemRefType[_TypeElement]

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
67
68
69
70
def StencilToMemRefType(
    input_type: StencilType[_TypeElement],
) -> MemRefType[_TypeElement]:
    """
    Build the `MemRefType` that has the element type and the shape of `input_type`.
    """
    return MemRefType(input_type.element_type, input_type.get_shape())

collectBlockArguments(number: int, block: Block)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def collectBlockArguments(number: int, block: Block):
    """
    Gather up to `number` block arguments, starting at `block` and walking up
    through its parent blocks.

    Each visited block contributes its leading arguments, at most as many as are
    still missing; arguments of outer blocks are placed before those of inner
    blocks. Fewer than `number` arguments are returned when the chain of parent
    blocks runs out first.
    """
    collected = []
    current = block

    while current is not None and len(collected) < number:
        missing = number - len(collected)
        collected = list(current.args[0:missing]) + collected
        current = current.parent_block()

    return collected

update_return_target(return_targets: dict[ApplyOp, list[SSAValue | None]], old_target: SSAValue, new_target: SSAValue | None)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
104
105
106
107
108
109
110
111
112
def update_return_target(
    return_targets: dict[ApplyOp, list[SSAValue | None]],
    old_target: SSAValue,
    new_target: SSAValue | None,
):
    for targets in return_targets.values():
        for i, target in enumerate(targets):
            if target == old_target:
                targets[i] = new_target

assert_subset(field: FieldType[Attribute], temp: TempType[Attribute])

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
233
234
235
236
237
238
239
240
241
242
243
244
245
def assert_subset(field: FieldType[Attribute], temp: TempType[Attribute]):
    """
    Check that the bounds of `temp` lie within the bounds of `field`.

    Raises:
        VerifyException: if the lower bound of `temp` is below the one of `field`,
            or the upper bound of `temp` is above the one of `field`.
    """
    field_bounds = field.bounds
    temp_bounds = temp.bounds
    assert isinstance(field_bounds, StencilBoundsAttr)
    assert isinstance(temp_bounds, StencilBoundsAttr)

    # NOTE(review): "at least" in the lower-bound message reads inverted relative
    # to the check (the field must start at or below the temp). Confirm whether
    # anything matches on the message before rewording it.
    field_lb, temp_lb = field_bounds.lb, temp_bounds.lb
    if temp_lb < field_lb:
        raise VerifyException(
            "The stencil computation requires a field with lower bound at least "
            f"{temp_lb}, got {field_lb}, min: {min(field_lb, temp_lb)}"
        )

    field_ub, temp_ub = field_bounds.ub, temp_bounds.ub
    if temp_ub > field_ub:
        raise VerifyException(
            "The stencil computation requires a field with upper bound at least "
            f"{temp_ub}, got {field_ub}, max: {max(field_ub, temp_ub)}"
        )

prepare_apply_body(op: ApplyOp)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
280
281
282
283
284
285
286
287
288
289
290
291
292
293
def prepare_apply_body(op: ApplyOp):
    """
    Rework the entry block of `op` and return it detached from the op's region.

    The existing block arguments are replaced by the matching operands of `op`
    and erased, an `scf.ReduceOp` is added to the block, and `op.get_rank()`
    arguments of index type are inserted.
    """
    # First replace all current arguments by their definition
    # and erase them from the block. (We are changing the op
    # to a loop, which has access to them either way)
    entry = op.region.block

    # NOTE(review): arguments are erased while zipping over `entry.args`; this
    # presumably relies on `entry.args` yielding a snapshot — confirm before
    # restructuring this loop.
    for operand, arg in zip(op.operands, entry.args):
        arg.replace_all_uses_with(operand)
        entry.erase_arg(arg)
    entry.add_op(scf.ReduceOp())
    # One index-typed argument per dimension of the apply, each inserted at
    # position 0.
    for _ in range(op.get_rank()):
        entry.insert_arg(builtin.IndexType(), 0)

    return op.region.detach_block(entry)

field_subview(field: SSAValue)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
346
347
348
349
350
351
352
353
354
355
def field_subview(field: SSAValue):
    """
    Build a `memref.SubviewOp` over `field` from static parameters.

    The offsets are the negated lower bounds of the field type, the sizes are its
    shape, and all strides are 1. The op is returned without being inserted.
    """
    # The bindings are kept outside the asserts: names bound by a walrus inside an
    # `assert` do not exist when Python runs with -O (asserts are stripped),
    # which would turn the lines below into NameErrors.
    field_type = field.type
    assert isa(field_type, FieldType[Attribute])
    bounds = field_type.bounds
    assert isinstance(bounds, StencilBoundsAttr)

    offsets = list(-bounds.lb)
    sizes = list(field_type.get_shape())
    strides = [1] * len(sizes)

    return memref.SubviewOp.from_static_parameters(
        field, StencilToMemRefType(field_type), offsets, sizes, strides
    )

return_target_analysis(module: builtin.ModuleOp)

Source code in xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
def return_target_analysis(module: builtin.ModuleOp):
    """
    Map every `ApplyOp` that contains a `ReturnOp` to the targets of its results.

    For each result of the apply, the uses by a `StoreOp`, `BufferOp` or
    `CombineOp` are collected. A result without such a use gets `None`, a result
    with exactly one gets the value returned by `_get_use_target` for that use.
    A result with several such uses triggers a warning and gets no entry at all.
    """
    targets_by_apply: dict[ApplyOp, list[SSAValue | None]] = {}

    for candidate in module.walk():
        if not isinstance(candidate, ReturnOp):
            continue

        apply = candidate.parent_op()
        assert isinstance(apply, ApplyOp)

        targets: list[SSAValue | None] = []
        targets_by_apply[apply] = targets

        for res in list(apply.res):
            consumers = [
                use
                for use in list(res.uses)
                if isinstance(use.operation, StoreOp | BufferOp | CombineOp)
            ]

            if len(consumers) > 1:
                # NOTE(review): skipping the append leaves `targets` shorter than
                # `apply.res`, so later entries no longer line up with result
                # positions — confirm that consumers of this mapping tolerate it.
                warn("Each stencil result should be stored only once.")
                continue

            targets.append(_get_use_target(consumers[0]) if consumers else None)

    return targets_by_apply