Skip to content

Stencil bufferize

stencil_bufferize

ApplyBufferizePattern

Bases: RewritePattern

Naive partial stencil.apply bufferization.

Simply replace every temp argument with the field result of a stencil.buffer applied to it, meaning "the buffer those values are allocated to".

Example:

%out = stencil.apply(%0 = %in : !stencil.temp<[0,32]xf64>) -> (!stencil.temp<[0,32]xf64>) {
    // [...]
}

yields:

%in_buf = stencil.buffer %in : !stencil.temp<[0,32]xf64> -> !stencil.field<[0,32]xf64>
stencil.apply(%0 = %in_buf : !stencil.field<[0,32]xf64>) outs (%out_buf : !stencil.field<[0,32]xf64>) {
    // [...]
}
Source code in xdsl/transforms/stencil_bufferize.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
class ApplyBufferizePattern(RewritePattern):
    """
    Naive, partial bufferization of `stencil.apply`.

    Each temp-typed argument is swapped for the field result of a `stencil.buffer`
    taken on it, meaning "the buffer those values are allocated to".

    Example:
    ```mlir
    %out = stencil.apply(%0 = %in : !stencil.temp<[0,32]xf64>) -> (!stencil.temp<[0,32]xf64>) {
        // [...]
    }
    ```
    yields:
    ```mlir
    %in_buf = stencil.buffer %in : !stencil.temp<[0,32]xf64> -> !stencil.field<[0,32]xf64>
    stencil.apply(%0 = %in_buf : !stencil.field<[0,32]xf64>) outs (%out_buf : !stencil.field<[0,32]xf64>) {
        // [...]
    }
    ```
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter):
        """Wrap every temp argument of `op` in a `BufferOp` and rebuild the apply."""
        # Nothing to do when no argument is a temp.
        if not any(isinstance(arg.type, TempType) for arg in op.args):
            return

        bounds = op.get_bounds()

        # `buffers` collects the newly created ops; `new_args` is the operand list
        # of the rebuilt apply, in the original argument order.
        buffers: list[BufferOp] = []
        new_args: list[SSAValue | Operation] = []
        for arg in op.args:
            if isa(arg.type, TempType[Attribute]):
                buffer = BufferOp.create(
                    operands=[arg],
                    result_types=[field_from_temp(arg.type)],
                )
                buffers.append(buffer)
                new_args.append(buffer)
            else:
                new_args.append(arg)

        bufferized = ApplyOp(
            operands=[new_args, op.dest],
            regions=[op.detach_region(0)],
            result_types=[op.res.types],
            properties={"bounds": bounds},
        )

        # The buffers are listed before the apply that consumes them.
        rewriter.replace_op(op, [*buffers, bufferized])

match_and_rewrite(op: ApplyOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/stencil_bufferize.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter):
    """Wrap every temp argument of `op` in a `BufferOp` and rebuild the apply."""
    # Nothing to do when no argument is a temp.
    if not any(isinstance(arg.type, TempType) for arg in op.args):
        return

    bounds = op.get_bounds()

    # `buffers` collects the newly created ops; `new_args` is the operand list
    # of the rebuilt apply, in the original argument order.
    buffers: list[BufferOp] = []
    new_args: list[SSAValue | Operation] = []
    for arg in op.args:
        if isa(arg.type, TempType[Attribute]):
            buffer = BufferOp.create(
                operands=[arg],
                result_types=[field_from_temp(arg.type)],
            )
            buffers.append(buffer)
            new_args.append(buffer)
        else:
            new_args.append(arg)

    bufferized = ApplyOp(
        operands=[new_args, op.dest],
        regions=[op.detach_region(0)],
        result_types=[op.res.types],
        properties={"bounds": bounds},
    )

    # The buffers are listed before the apply that consumes them.
    rewriter.replace_op(op, [*buffers, bufferized])

LoadBufferFoldPattern

Bases: RewritePattern

Fold a reference-semantic stencil.buffer of a stencil.load to the underlying field if safe.

Example:

%temp = stencil.load %field : !stencil.field<[-2,34]> -> !stencil.temp<[0,32]>
// [... No changes on %field]
%temp_f = stencil.buffer %temp : !stencil.temp<[0,32]> -> !stencil.field<[0,32]>
// [... No changes on %field]
// Last use of temp_f

yields:

// Will be simplified away or folded again
%temp = stencil.load %field : !stencil.field<[-2,34]> -> !stencil.temp<[0,32]>
// [... %temp_f replaced by %field]

Source code in xdsl/transforms/stencil_bufferize.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
class LoadBufferFoldPattern(RewritePattern):
    """
    Fold a reference-semantic `stencil.buffer` of a `stencil.load` to the underlying
    field if safe.

    Example:
    ```mlir
    %temp = stencil.load %field : !stencil.field<[-2,34]> -> !stencil.temp<[0,32]>
    // [... No changes on %field]
    %temp_f = stencil.buffer %temp : !stencil.temp<[0,32]> -> !stencil.field<[0,32]>
    // [... No changes on %field]
    // Last use of temp_f
    ```
    yields:
    ```mlir
    // Will be simplified away or folded again
    %temp = stencil.load %field : !stencil.field<[-2,34]> -> !stencil.temp<[0,32]>
    // [... %temp_f replaced by %field]
    ```
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter):
        """
        Replace the result of `op` by the field its input was loaded from, provided
        nothing that might write to that field sits between the load and the last
        user of the buffer (an in-place `stencil.apply` as the last writer is
        tolerated).

        The pattern leaves `op` untouched when it is value-semantic, when its input
        is not produced by a `LoadOp`, when it has no users, or when a user lives in
        another block.
        """
        # If this is a value-semantic buffer, we can't fold it
        if not isinstance(op.res.type, FieldType):
            return

        temp = op.temp
        load = temp.owner
        # We are interested in folding buffers of loaded values
        if not isinstance(load, LoadOp):
            return

        underlying = load.field

        uses = tuple(op.res.uses)
        # A buffer without users has no "last user" to analyse up to; bail out
        # instead of letting max() raise ValueError on the empty sequence below.
        if not uses:
            return

        # TODO: further analysis
        # For now, only handle usages in the same block
        block = op.parent
        if not block or any(use.operation.parent is not block for use in uses):
            return
        last_user = max(
            uses, key=lambda u: block.get_operation_index(u.operation)
        ).operation

        # Every op from the load to the last user (inclusive) that might write to
        # the underlying field.
        effecting = [
            o
            for o in walk_from_to(load, last_user, inclusive=True)
            if might_effect(o, {MemoryEffectKind.WRITE}, underlying)
        ]

        # If the last effecting op is a stencil, handle the safe inplace case
        if (
            effecting
            and isinstance(effecting[-1], ApplyOp)
            and is_inplace(effecting[-1], op.res)
        ):
            effecting.pop()
        # Any remaining potential writer makes the fold unsafe.
        if effecting:
            return

        rewriter.replace_op(op, new_ops=[], new_results=[underlying])

match_and_rewrite(op: BufferOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/stencil_bufferize.py
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
@op_type_rewrite_pattern
def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter):
    """
    Replace the result of `op` by the field its input was loaded from, provided
    nothing that might write to that field sits between the load and the last
    user of the buffer (an in-place `stencil.apply` as the last writer is
    tolerated).

    The pattern leaves `op` untouched when it is value-semantic, when its input
    is not produced by a `LoadOp`, when it has no users, or when a user lives in
    another block.
    """
    # If this is a value-semantic buffer, we can't fold it
    if not isinstance(op.res.type, FieldType):
        return

    temp = op.temp
    load = temp.owner
    # We are interested in folding buffers of loaded values
    if not isinstance(load, LoadOp):
        return

    underlying = load.field

    uses = tuple(op.res.uses)
    # A buffer without users has no "last user" to analyse up to; bail out
    # instead of letting max() raise ValueError on the empty sequence below.
    if not uses:
        return

    # TODO: further analysis
    # For now, only handle usages in the same block
    block = op.parent
    if not block or any(use.operation.parent is not block for use in uses):
        return
    last_user = max(
        uses, key=lambda u: block.get_operation_index(u.operation)
    ).operation

    # Every op from the load to the last user (inclusive) that might write to
    # the underlying field.
    effecting = [
        o
        for o in walk_from_to(load, last_user, inclusive=True)
        if might_effect(o, {MemoryEffectKind.WRITE}, underlying)
    ]

    # If the last effecting op is a stencil, handle the safe inplace case
    if (
        effecting
        and isinstance(effecting[-1], ApplyOp)
        and is_inplace(effecting[-1], op.res)
    ):
        effecting.pop()
    # Any remaining potential writer makes the fold unsafe.
    if effecting:
        return

    rewriter.replace_op(op, new_ops=[], new_results=[underlying])

ApplyStoreFoldPattern

Bases: RewritePattern

Fold stores of an apply's results into the apply itself.

Example:

%temp = stencil.apply() -> (!stencil.temp<[0,32]>) {
    // [...]
}
// [... %dest not read]
stencil.store %temp to %dest (<[0], [32]>) : !stencil.temp<[0,32]> to !stencil.field<[-2,34]>

yields:

// Outputs on dest directly
stencil.apply() outs (%dest : !stencil.field<[-2,34]>) {
    // [...]
}
Source code in xdsl/transforms/stencil_bufferize.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
class ApplyStoreFoldPattern(RewritePattern):
    """
    Fold stores of applys result

    Example:
    ```mlir
    %temp = stencil.apply() -> (!stencil.temp<[0,32]>) {
        // [...]
    }
    // [... %dest not read]
    stencil.store %temp to %dest (<[0], [32]>) : !stencil.temp<[0,32]> to !stencil.field<[-2,34]>
    ```
    yields:
    ```mlir
    // Outputs on dest directly
    stencil.apply() outs (%dest : !stencil.field<[-2,34]>) {
        // [...]
    }
    ```
    """

    @staticmethod
    def is_dest_safe(apply: ApplyOp, store: StoreOp) -> bool:
        """
        Return True when no operation yielded by `walk_from_to(apply, store)` might
        read or write the field targeted by `store`, as reported by `might_effect`.
        """
        # Check that the destination is not used between the apply and store.
        dest = store.field
        effecting = [
            o
            for o in walk_from_to(apply, store)
            if might_effect(o, {MemoryEffectKind.READ, MemoryEffectKind.WRITE}, dest)
        ]
        return not effecting

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter):
        """
        Fold the first result of `op` that has at least one safe `StoreOp` user.

        The apply is rebuilt without that result and with the fields of those stores
        appended to its destinations; the stores are erased, and the old result is
        replaced by a load from the first store's field. Only one result is handled
        per invocation.

        Raises:
            ValueError: if the apply's bounds are not a `StencilBoundsAttr`.
        """
        apply = op
        for temp_index, stored in enumerate(op.res):
            # We are looking for a result that is stored and foldable
            stores = [
                use.operation
                for use in stored.uses
                if isinstance(use.operation, StoreOp)
                and self.is_dest_safe(apply, use.operation)
            ]
            if not stores:
                continue

            bounds = apply.get_bounds()
            if not isinstance(bounds, StencilBoundsAttr):
                raise ValueError(
                    "Stencil shape inference must be ran before bufferization."
                )

            new_apply = ApplyOp.build(
                # We add new destinations for each store of the removed result
                operands=[
                    apply.args,
                    (*apply.dest, *(store.field for store in stores)),
                ],
                # We only remove the considered result
                result_types=[
                    [
                        r.type
                        for r in apply.results[:temp_index]
                        + apply.results[temp_index + 1 :]
                    ]
                ],
                properties=apply.properties.copy() | {"bounds": bounds},
                attributes=apply.attributes.copy(),
                # The block signature is the same
                regions=[
                    Region(Block(arg_types=[SSAValue.get(a).type for a in apply.args])),
                ],
            )

            # The body is the same
            rewriter.inline_block(
                apply.region.block,
                InsertPoint.at_start(new_apply.region.block),
                new_apply.region.block.args,
            )

            # We swap the return's operand order, to make sure the order still matches destinations
            # after bufferization
            old_return = new_apply.region.block.last_op
            assert isinstance(old_return, ReturnOp)
            # NOTE(review): the slicing below treats each result as owning `uf`
            # consecutive return operands - confirm against ReturnOp.
            uf = old_return.unroll_factor
            # Operands of the other results keep their relative order; the folded
            # result's operands go last, repeated once per folded store.
            new_return_args = list(
                old_return.arg[: uf * temp_index]
                + old_return.arg[uf * (temp_index + 1) :]
                + old_return.arg[uf * temp_index : uf * (temp_index + 1)] * len(stores)
            )
            new_return = ReturnOp.create(
                operands=new_return_args,
                properties=old_return.properties.copy(),
                attributes=old_return.attributes.copy(),
            )
            rewriter.replace_op(old_return, new_return)

            # Create a load of a destination, for any other user of the result
            load = LoadOp.get(stores[0].field, bounds.lb, bounds.ub)

            # The load's result takes the slot of the removed result so that the
            # replacement values line up with the old apply's results.
            rewriter.replace_op(
                op,
                [new_apply, load],
                new_apply.results[:temp_index]
                + (load.res,)
                + new_apply.results[temp_index:],
            )
            for store in stores:
                rewriter.erase_op(store)
            # One result per invocation; `op` has been replaced at this point.
            return

is_dest_safe(apply: ApplyOp, store: StoreOp) -> bool staticmethod

Source code in xdsl/transforms/stencil_bufferize.py
247
248
249
250
251
252
253
254
255
256
@staticmethod
def is_dest_safe(apply: ApplyOp, store: StoreOp) -> bool:
    """
    Tell whether the field written by `store` is left alone between `apply` and
    `store`: returns False as soon as an operation yielded by
    `walk_from_to(apply, store)` might read or write it, True otherwise.
    """
    target = store.field
    for candidate in walk_from_to(apply, store):
        if might_effect(
            candidate, {MemoryEffectKind.READ, MemoryEffectKind.WRITE}, target
        ):
            return False
    return True

match_and_rewrite(op: ApplyOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/stencil_bufferize.py
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter):
    """
    Fold the first result of `op` that has at least one safe `StoreOp` user.

    The apply is rebuilt without that result and with the fields of those stores
    appended to its destinations; the stores are erased, and the old result is
    replaced by a load from the first store's field. Only one result is handled
    per invocation.

    Raises:
        ValueError: if the apply's bounds are not a `StencilBoundsAttr`.
    """
    apply = op
    for temp_index, stored in enumerate(op.res):
        # We are looking for a result that is stored and foldable
        stores = [
            use.operation
            for use in stored.uses
            if isinstance(use.operation, StoreOp)
            and self.is_dest_safe(apply, use.operation)
        ]
        if not stores:
            continue

        bounds = apply.get_bounds()
        if not isinstance(bounds, StencilBoundsAttr):
            raise ValueError(
                "Stencil shape inference must be ran before bufferization."
            )

        new_apply = ApplyOp.build(
            # We add new destinations for each store of the removed result
            operands=[
                apply.args,
                (*apply.dest, *(store.field for store in stores)),
            ],
            # We only remove the considered result
            result_types=[
                [
                    r.type
                    for r in apply.results[:temp_index]
                    + apply.results[temp_index + 1 :]
                ]
            ],
            properties=apply.properties.copy() | {"bounds": bounds},
            attributes=apply.attributes.copy(),
            # The block signature is the same
            regions=[
                Region(Block(arg_types=[SSAValue.get(a).type for a in apply.args])),
            ],
        )

        # The body is the same
        rewriter.inline_block(
            apply.region.block,
            InsertPoint.at_start(new_apply.region.block),
            new_apply.region.block.args,
        )

        # We swap the return's operand order, to make sure the order still matches destinations
        # after bufferization
        old_return = new_apply.region.block.last_op
        assert isinstance(old_return, ReturnOp)
        # NOTE(review): the slicing below treats each result as owning `uf`
        # consecutive return operands - confirm against ReturnOp.
        uf = old_return.unroll_factor
        # Operands of the other results keep their relative order; the folded
        # result's operands go last, repeated once per folded store.
        new_return_args = list(
            old_return.arg[: uf * temp_index]
            + old_return.arg[uf * (temp_index + 1) :]
            + old_return.arg[uf * temp_index : uf * (temp_index + 1)] * len(stores)
        )
        new_return = ReturnOp.create(
            operands=new_return_args,
            properties=old_return.properties.copy(),
            attributes=old_return.attributes.copy(),
        )
        rewriter.replace_op(old_return, new_return)

        # Create a load of a destination, for any other user of the result
        load = LoadOp.get(stores[0].field, bounds.lb, bounds.ub)

        # The load's result takes the slot of the removed result so that the
        # replacement values line up with the old apply's results.
        rewriter.replace_op(
            op,
            [new_apply, load],
            new_apply.results[:temp_index]
            + (load.res,)
            + new_apply.results[temp_index:],
        )
        for store in stores:
            rewriter.erase_op(store)
        # One result per invocation; `op` has been replaced at this point.
        return

UpdateApplyArgs dataclass

Bases: RewritePattern

Stencil bufferization will often replace a temporary apply's argument with a wider one. This pattern simply updates block arguments accordingly.

Source code in xdsl/transforms/stencil_bufferize.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
@dataclass(frozen=True)
class UpdateApplyArgs(RewritePattern):
    """
    Keep an apply's block arguments in sync with its operands.

    Stencil bufferization will often replace a temporary apply's argument with a
    wider one; this pattern rebuilds the apply around a block whose argument types
    are the current operand types.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter):
        """Rebuild `op` when its block argument types differ from `op.args.types`."""
        operand_types = op.args.types
        stale_block = op.region.block
        # Already consistent: nothing to rewrite.
        if operand_types == stale_block.arg_types:
            return

        fresh_block = Block(arg_types=operand_types)
        fresh_region = Region(fresh_block)
        rebuilt = ApplyOp.create(
            operands=op.operands,
            result_types=op.result_types,
            properties=op.properties.copy(),
            attributes=op.attributes.copy(),
            regions=[fresh_region],
        )

        # Move the old body into the new block, substituting the new block
        # arguments for the old ones.
        rewriter.inline_block(
            stale_block,
            InsertPoint.at_start(fresh_block),
            fresh_block.args,
        )

        rewriter.replace_op(op, rebuilt)

__init__() -> None

match_and_rewrite(op: ApplyOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/stencil_bufferize.py
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter):
    """Rebuild `op` when its block argument types differ from `op.args.types`."""
    operand_types = op.args.types
    stale_block = op.region.block
    # Already consistent: nothing to rewrite.
    if operand_types == stale_block.arg_types:
        return

    fresh_block = Block(arg_types=operand_types)
    fresh_region = Region(fresh_block)
    rebuilt = ApplyOp.create(
        operands=op.operands,
        result_types=op.result_types,
        properties=op.properties.copy(),
        attributes=op.attributes.copy(),
        regions=[fresh_region],
    )

    # Move the old body into the new block, substituting the new block
    # arguments for the old ones.
    rewriter.inline_block(
        stale_block,
        InsertPoint.at_start(fresh_block),
        fresh_block.args,
    )

    rewriter.replace_op(op, rebuilt)

BufferAlloc dataclass

Bases: RewritePattern

Replace a value semantic stencil.buffer by a load from an allocated field, after a store of the input values on it.

This matches the original dialect's lowering for this operation.

Example:

// [...]
%forward = stencil.buffer %in : !stencil.temp<[0,32]> -> !stencil.temp<[0,32]>
// [...]

yields:

%alloc = stencil.alloc : !stencil.field<[0,32]>xf64
// [...]
// This should be folded in the above computation
stencil.store %in to %alloc (<[0], [32]>) : !stencil.temp<[0,32]> to !stencil.field<[0,32]>
// This should be folded in the below computation
%forward = stencil.load %alloc : !stencil.field<[0,32]>xf64 -> !stencil.temp<[0,32]>
// [...]
Source code in xdsl/transforms/stencil_bufferize.py
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
@dataclass(frozen=True)
class BufferAlloc(RewritePattern):
    """
    Lower a value-semantic `stencil.buffer`: allocate a field, store the buffered
    values to it, and load them back from it.

    This matches the original dialect's lowering for this operation.

    Example:
    ```mlir
    // [...]
    %forward = stencil.buffer %in : !stencil.temp<[0,32]> -> !stencil.temp<[0,32]>
    // [...]
    ```
    yields:
    ```mlir
    %alloc = stencil.alloc : !stencil.field<[0,32]>xf64
    // [...]
    // This should be folded in the above computation
    stencil.store %in to %alloc (<[0], [32]>) : !stencil.temp<[0,32]> to !stencil.field<[0,32]>
    // This should be folded in the below computation
    %forward = stencil.load %alloc : !stencil.field<[0,32]>xf64 -> !stencil.temp<[0,32]>
    // [...]
    ```
    """  # noqa: E501

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter):
        """
        Replace a temp-producing `op` by a store to, then a load from, a fresh
        `AllocOp` inserted at the start of the enclosing block.

        Raises:
            ValueError: if the input temp's bounds are not a `StencilBoundsAttr`.
        """
        # Buffers that do not produce a temp are left to other patterns.
        if not isinstance(op.res.type, TempType):
            return

        source = op.temp
        temp_t = cast(TempType[Attribute], source.type)
        bounds = temp_t.bounds
        if not isinstance(bounds, StencilBoundsAttr):
            raise ValueError(
                "Stencil shape inference must be ran before bufferization."
            )

        # The allocation goes to the very start of the block holding the buffer.
        alloc = AllocOp(result_types=[field_from_temp(temp_t)])
        enclosing_block = cast(Block, op.parent)
        rewriter.insert_op(alloc, InsertPoint.at_start(enclosing_block))

        write_back = StoreOp.get(source, alloc.field, bounds)
        read_back = LoadOp.get(alloc.field, bounds.lb, bounds.ub)
        rewriter.replace_op(op, new_ops=[write_back, read_back])

__init__() -> None

match_and_rewrite(op: BufferOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/stencil_bufferize.py
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
@op_type_rewrite_pattern
def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter):
    """
    Replace a temp-producing `op` by a store to, then a load from, a fresh
    `AllocOp` inserted at the start of the enclosing block.

    Raises:
        ValueError: if the input temp's bounds are not a `StencilBoundsAttr`.
    """
    # Buffers that do not produce a temp are left to other patterns.
    if not isinstance(op.res.type, TempType):
        return

    source = op.temp
    temp_t = cast(TempType[Attribute], source.type)
    bounds = temp_t.bounds
    if not isinstance(bounds, StencilBoundsAttr):
        raise ValueError(
            "Stencil shape inference must be ran before bufferization."
        )

    # The allocation goes to the very start of the block holding the buffer.
    alloc = AllocOp(result_types=[field_from_temp(temp_t)])
    enclosing_block = cast(Block, op.parent)
    rewriter.insert_op(alloc, InsertPoint.at_start(enclosing_block))

    write_back = StoreOp.get(source, alloc.field, bounds)
    read_back = LoadOp.get(alloc.field, bounds.lb, bounds.ub)
    rewriter.replace_op(op, new_ops=[write_back, read_back])

CombineStoreFold dataclass

Bases: RewritePattern

A stored combine result is folded into stores of the matching operand in the destination field.

Example:

%res1, %res2 = stencil.combine 1 at 11 lower = (%0 : !stencil.temp<[0,16]xf64>) upper = (%1 : !stencil.temp<[16,32]xf64>) lowerext = (%2 : !stencil.temp<[0,16]xf64>): !stencil.temp<[0,32]xf64>, !stencil.temp<[0,32]xf64>
stencil.store %res1 to %dest1 (<[0], [32]>) : !stencil.temp<[0,32]xf64> to !stencil.field<[-2,34]xf64>
stencil.store %res2 to %dest2 (<[0], [32]>) : !stencil.temp<[0,32]xf64> to !stencil.field<[-2,34]xf64>

yields:

stencil.store %0 to %dest1 (<[0], [16]>) : !stencil.temp<[0,16]xf64> to !stencil.field<[-2,34]xf64>
stencil.store %1 to %dest1 (<[16], [32]>) : !stencil.temp<[16,32]xf64> to !stencil.field<[-2,34]xf64>
stencil.store %2 to %dest2 (<[0], [16]>) : !stencil.temp<[0,16]xf64> to !stencil.field<[-2,34]xf64>
Source code in xdsl/transforms/stencil_bufferize.py
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
@dataclass(frozen=True)
class CombineStoreFold(RewritePattern):
    """
    A stored combine result is folded into stores of the matching operand in the
    destination field.

    Example:
    ```mlir
    %res1, %res2 = stencil.combine 1 at 11 lower = (%0 : !stencil.temp<[0,16]xf64>) upper = (%1 : !stencil.temp<[16,32]xf64>) lowerext = (%2 : !stencil.temp<[0,16]xf64>): !stencil.temp<[0,32]xf64>, !stencil.temp<[0,32]xf64>
    stencil.store %res1 to %dest1 (<[0], [32]>) : !stencil.temp<[0,32]xf64> to !stencil.field<[-2,34]xf64>
    stencil.store %res2 to %dest2 (<[0], [32]>) : !stencil.temp<[0,32]xf64> to !stencil.field<[-2,34]xf64>
    ```
    yields:
    ```mlir
    stencil.store %0 to %dest1 (<[0], [16]>) : !stencil.temp<[0,16]xf64> to !stencil.field<[-2,34]xf64>
    stencil.store %1 to %dest1 (<[16], [32]>) : !stencil.temp<[16,32]xf64> to !stencil.field<[-2,34]xf64>
    stencil.store %2 to %dest2 (<[0], [16]>) : !stencil.temp<[0,16]xf64> to !stencil.field<[-2,34]xf64>
    ```
    """  # noqa: E501

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: CombineOp, rewriter: PatternRewriter):
        """
        Fold the first result of `op` whose only user is a `StoreOp`.

        The store is erased and replaced by store(s) of the operand(s) that make up
        that result, inserted before `op`:

        - a lower/upper result becomes a store of its lower operand on the lower
          half and a store of its upper operand on the upper half;
        - a lowerext result becomes a single store of its lowerext operand on the
          lower half;
        - an upperext result becomes a single store of its upperext operand on the
          upper half.

        The combine is then rebuilt without the folded result and its operand(s).
        Only one result is handled per invocation.
        """
        n_lower = len(op.lower)
        n_lowerext = len(op.lowerext)

        for i, r in enumerate(op.results):
            if not r.has_one_use():
                continue
            store = next(iter(r.uses)).operation
            if not isinstance(store, StoreOp):
                continue

            new_lower = op.lower
            new_upper = op.upper
            new_lowerext = op.lowerext
            new_upperext = op.upperext
            new_results_types = list(op.result_types)
            new_results_types.pop(i)

            # Split the result's bounds at `index` along `dim`.
            dim = op.dim.value.data
            split = op.index.value.data
            bounds = cast(StencilBoundsAttr, cast(TempType[Attribute], r.type).bounds)
            newub = list(bounds.ub)
            newub[dim] = split
            lower_bounds = StencilBoundsAttr.new((bounds.lb, IndexAttr.get(*newub)))
            newlb = list(bounds.lb)
            newlb[dim] = split
            upper_bounds = StencilBoundsAttr.new((IndexAttr.get(*newlb), bounds.ub))

            # Grab the destination before the store goes away.
            dest = store.field
            rewriter.erase_op(store)

            # If it corresponds to a lower/upper result
            if i < n_lower:
                new_lower = op.lower[:i] + op.lower[i + 1 :]
                new_upper = op.upper[:i] + op.upper[i + 1 :]
                new_stores = (
                    StoreOp.get(op.lower[i], dest, lower_bounds),
                    StoreOp.get(op.upper[i], dest, upper_bounds),
                )
            # If it corresponds to a lowerext result
            elif i < n_lower + n_lowerext:
                # Index within `lowerext`; `op.lower[i]` would be out of range here.
                ext = i - n_lower
                new_lowerext = op.lowerext[:ext] + op.lowerext[ext + 1 :]
                new_stores = (StoreOp.get(op.lowerext[ext], dest, lower_bounds),)
            # Otherwise it corresponds to an upperext result
            else:
                # Index within `upperext`; `op.upper[i]` would be out of range here.
                ext = i - n_lower - n_lowerext
                new_upperext = op.upperext[:ext] + op.upperext[ext + 1 :]
                new_stores = (StoreOp.get(op.upperext[ext], dest, upper_bounds),)

            rewriter.insert_op(new_stores, InsertPoint.before(op))

            new_combine = CombineOp(
                operands=[new_lower, new_upper, new_lowerext, new_upperext],
                result_types=[new_results_types],
                attributes=op.attributes.copy(),
                properties=op.properties.copy(),
            )
            # The folded result has no replacement value (None); the others map
            # one-to-one onto the new combine's results.
            rewriter.replace_op(
                op,
                new_combine,
                new_results=new_combine.results[:i] + (None,) + new_combine.results[i:],
            )
            return

__init__() -> None

match_and_rewrite(op: CombineOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/stencil_bufferize.py
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
@op_type_rewrite_pattern
def match_and_rewrite(self, op: CombineOp, rewriter: PatternRewriter):
    """
    Fold the first result of `op` whose only user is a `StoreOp`.

    The store is erased and replaced by store(s) of the operand(s) that make up
    that result, inserted before `op`:

    - a lower/upper result becomes a store of its lower operand on the lower
      half and a store of its upper operand on the upper half;
    - a lowerext result becomes a single store of its lowerext operand on the
      lower half;
    - an upperext result becomes a single store of its upperext operand on the
      upper half.

    The combine is then rebuilt without the folded result and its operand(s).
    Only one result is handled per invocation.
    """
    n_lower = len(op.lower)
    n_lowerext = len(op.lowerext)

    for i, r in enumerate(op.results):
        if not r.has_one_use():
            continue
        store = next(iter(r.uses)).operation
        if not isinstance(store, StoreOp):
            continue

        new_lower = op.lower
        new_upper = op.upper
        new_lowerext = op.lowerext
        new_upperext = op.upperext
        new_results_types = list(op.result_types)
        new_results_types.pop(i)

        # Split the result's bounds at `index` along `dim`.
        dim = op.dim.value.data
        split = op.index.value.data
        bounds = cast(StencilBoundsAttr, cast(TempType[Attribute], r.type).bounds)
        newub = list(bounds.ub)
        newub[dim] = split
        lower_bounds = StencilBoundsAttr.new((bounds.lb, IndexAttr.get(*newub)))
        newlb = list(bounds.lb)
        newlb[dim] = split
        upper_bounds = StencilBoundsAttr.new((IndexAttr.get(*newlb), bounds.ub))

        # Grab the destination before the store goes away.
        dest = store.field
        rewriter.erase_op(store)

        # If it corresponds to a lower/upper result
        if i < n_lower:
            new_lower = op.lower[:i] + op.lower[i + 1 :]
            new_upper = op.upper[:i] + op.upper[i + 1 :]
            new_stores = (
                StoreOp.get(op.lower[i], dest, lower_bounds),
                StoreOp.get(op.upper[i], dest, upper_bounds),
            )
        # If it corresponds to a lowerext result
        elif i < n_lower + n_lowerext:
            # Index within `lowerext`; `op.lower[i]` would be out of range here.
            ext = i - n_lower
            new_lowerext = op.lowerext[:ext] + op.lowerext[ext + 1 :]
            new_stores = (StoreOp.get(op.lowerext[ext], dest, lower_bounds),)
        # Otherwise it corresponds to an upperext result
        else:
            # Index within `upperext`; `op.upper[i]` would be out of range here.
            ext = i - n_lower - n_lowerext
            new_upperext = op.upperext[:ext] + op.upperext[ext + 1 :]
            new_stores = (StoreOp.get(op.upperext[ext], dest, upper_bounds),)

        rewriter.insert_op(new_stores, InsertPoint.before(op))

        new_combine = CombineOp(
            operands=[new_lower, new_upper, new_lowerext, new_upperext],
            result_types=[new_results_types],
            attributes=op.attributes.copy(),
            properties=op.properties.copy(),
        )
        # The folded result has no replacement value (None); the others map
        # one-to-one onto the new combine's results.
        rewriter.replace_op(
            op,
            new_combine,
            new_results=new_combine.results[:i] + (None,) + new_combine.results[i:],
        )
        return

SwapBufferize

Bases: RewritePattern

Bufferize a dmp.swap operation.

NB: This should most likely consider a shared pass following canonicalize and shape-inference.

Source code in xdsl/transforms/stencil_bufferize.py
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
class SwapBufferize(RewritePattern):
    """
    Rewrite a `dmp.swap` acting on a loaded temp so that it acts on a buffer instead.

    The swap's input is wrapped in a `stencil.buffer`, the swap is re-created on that
    buffer, and a `stencil.load` of the buffer takes the place of the original swap.

    NB: This should most likely consider a shared pass following canonicalize and
    shape-inference.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: SwapOp, rewriter: PatternRewriter):
        source = op.input_stencil
        source_type = source.type

        # Only swaps whose input is a temp are candidates.
        if not isa(source_type, TempType[Attribute]):
            return

        # ... and only when that temp is produced directly by a load.
        if not isinstance(source.owner, LoadOp):
            return

        buffer_op = BufferOp.create(
            operands=[source],
            result_types=[field_from_temp(source_type)],
        )

        # Re-create the swap on the buffer, carrying over the original exchanges.
        bufferized_swap = SwapOp.get(buffer_op.res, op.strategy)
        bufferized_swap.swaps = op.swaps

        # Load the buffer back with the original temp type.
        reload_op = LoadOp(operands=[buffer_op.res], result_types=[source_type])

        rewriter.replace_op(op, new_ops=[buffer_op, bufferized_swap, reload_op])

match_and_rewrite(op: SwapOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/stencil_bufferize.py
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
@op_type_rewrite_pattern
def match_and_rewrite(self, op: SwapOp, rewriter: PatternRewriter):
    """
    Replace a swap on a loaded temp by a buffer, a swap on that buffer, and a load of
    the buffer back to the original temp type.
    """
    source = op.input_stencil
    source_type = source.type

    # Only swaps whose input is a temp are candidates.
    if not isa(source_type, TempType[Attribute]):
        return

    # ... and only when that temp is produced directly by a load.
    if not isinstance(source.owner, LoadOp):
        return

    buffer_op = BufferOp.create(
        operands=[source],
        result_types=[field_from_temp(source_type)],
    )

    # Re-create the swap on the buffer, carrying over the original exchanges.
    bufferized_swap = SwapOp.get(buffer_op.res, op.strategy)
    bufferized_swap.swaps = op.swaps

    # Load the buffer back with the original temp type.
    reload_op = LoadOp(operands=[buffer_op.res], result_types=[source_type])

    rewriter.replace_op(op, new_ops=[buffer_op, bufferized_swap, reload_op])

StencilBufferize dataclass

Bases: ModulePass

Bufferize the stencil dialect, i.e., try to fold all loads, stores, buffers and combines, and output stencils working directly on buffers (fields) with hopefully few allocations.

Source code in xdsl/transforms/stencil_bufferize.py
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
@dataclass(frozen=True)
class StencilBufferize(ModulePass):
    """
    Bufferize the stencil dialect: try to fold all loads, stores, buffers and
    combines, so that the resulting stencils work directly on buffers (fields) with
    hopefully few allocations.
    """

    name = "stencil-bufferize"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # The patterns are handed to the greedy applier in exactly this order.
        patterns: list[RewritePattern] = [
            UpdateApplyArgs(),
            ApplyBufferizePattern(),
            BufferAlloc(),
            CombineStoreFold(),
            LoadBufferFoldPattern(),
            ApplyStoreFoldPattern(),
            RemoveUnusedOperations(),
            ApplyUnusedResults(),
            SwapBufferize(),
        ]
        applier = GreedyRewritePatternApplier(patterns)
        PatternRewriteWalker(applier, apply_recursively=True).rewrite_module(op)

name = 'stencil-bufferize' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/stencil_bufferize.py
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """
    Run all bufferization rewrite patterns over the module `op`, recursively.
    """
    # The patterns are handed to the greedy applier in exactly this order.
    patterns: list[RewritePattern] = [
        UpdateApplyArgs(),
        ApplyBufferizePattern(),
        BufferAlloc(),
        CombineStoreFold(),
        LoadBufferFoldPattern(),
        ApplyStoreFoldPattern(),
        RemoveUnusedOperations(),
        ApplyUnusedResults(),
        SwapBufferize(),
    ]
    applier = GreedyRewritePatternApplier(patterns)
    PatternRewriteWalker(applier, apply_recursively=True).rewrite_module(op)

field_from_temp(temp: TempType[_TypeElement]) -> FieldType[_TypeElement]

Source code in xdsl/transforms/stencil_bufferize.py
50
51
def field_from_temp(temp: TempType[_TypeElement]) -> FieldType[_TypeElement]:
    """
    Build a field type from the parameters of the given temp type.
    """
    temp_parameters = temp.parameters
    return FieldType[_TypeElement].new(temp_parameters)

might_effect(operation: Operation, effects: set[MemoryEffectKind], value: SSAValue) -> bool

Return True if the operation might have any of the given effects on the given value.

Source code in xdsl/transforms/stencil_bufferize.py
54
55
56
57
58
59
60
61
62
63
def might_effect(
    operation: Operation, effects: set[MemoryEffectKind], value: SSAValue
) -> bool:
    """
    Tell whether `operation` might have one of the given `effects` on `value`.

    The answer is True when `get_effects` reports nothing (None) for the operation, or
    when one of the reported effects has a kind in `effects` and targets either no
    specific value (None) or `value` itself.
    """
    reported = get_effects(operation)
    if reported is None:
        # No effect information is available for this operation.
        return True

    for effect in reported:
        if effect.kind in effects and effect.value in (None, value):
            return True
    return False

walk_from(a: Operation) -> Generator[Operation, Any, None]

Walk through all operations recursively inside a or its block.

Source code in xdsl/transforms/stencil_bufferize.py
117
118
119
120
121
122
123
124
125
def walk_from(a: Operation) -> Generator[Operation, Any, None]:
    """
    Yield everything produced by walking `a`, then by walking each operation that
    follows it through `next_op`, until there is no next operation.
    """
    current = a
    while current is not None:
        yield from current.walk()
        # Move on to the sibling that follows; None ends the traversal.
        current = current.next_op

walk_from_to(a: Operation, b: Operation, *, inclusive: bool = False)

Walk through all operations recursively inside a or its block, until b is met, if ever.

Source code in xdsl/transforms/stencil_bufferize.py
128
129
130
131
132
133
134
135
136
137
138
def walk_from_to(a: Operation, b: Operation, *, inclusive: bool = False):
    """
    Yield the operations produced by `walk_from(a)` up to the point where `b` is met,
    if it ever is. `b` itself is yielded only when `inclusive` is set.
    """
    for current in walk_from(a):
        reached_end = current == b
        # Everything before `b` is yielded; `b` only on request.
        if not reached_end or inclusive:
            yield current
        if reached_end:
            return

is_inplace(apply: ApplyOp, field: SSAValue)

Return True if the passed stencil.apply has no non-zero offset access to the passed stencil.field, and False otherwise.

Source code in xdsl/transforms/stencil_bufferize.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def is_inplace(apply: ApplyOp, field: SSAValue):
    """
    Return True when the passed `stencil.apply` contains no non-zero offset access to
    the passed `stencil.field`, and False as soon as such an access is found.
    """
    # Block arguments of the apply whose corresponding operand is this very field.
    bound_args = {
        apply.region.block.args[index]
        for index, operand in enumerate(apply.args)
        if operand is field
    }

    for nested in apply.walk():
        # Static access: any non-zero component of the offset disqualifies.
        if (
            isinstance(nested, AccessOp)
            and nested.temp in bound_args
            and any(o != 0 for o in nested.offset)
        ):
            return False
        # Dynamic access: any non-zero lower or upper bound component disqualifies.
        if (
            isinstance(nested, DynAccessOp)
            and nested.temp in bound_args
            and any(o != 0 for o in chain(nested.lb, nested.ub))
        ):
            return False
    return True