
Convert stencil to csl stencil

convert_stencil_to_csl_stencil

ConvertAccessOpPattern dataclass

Bases: RewritePattern

Rebuilds stencil.access as csl_stencil.access, which operates on prefetched buffers.

stencil.access operates on stencil.temp types found at arg_index; csl_stencil.access operates on memref< num_neighbors x tensor< buf_size x data_type >> found at the last arg index.

Note: This is intended to be called in a nested pattern rewriter, such that the above precondition is met.
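
A minimal sketch of driving this pattern from a walker, mirroring how ConvertStencilToCslStencilPass below applies it (m is an assumed builtin.ModuleOp whose apply ops already carry the prefetched buffer as their last block argument):

from xdsl.pattern_rewriter import PatternRewriteWalker
from xdsl.transforms.convert_stencil_to_csl_stencil import ConvertAccessOpPattern

# walk the module once, rewriting each matched stencil.access in place
walker = PatternRewriteWalker(ConvertAccessOpPattern(), apply_recursively=False)
walker.rewrite_module(m)  # `m` is an assumed builtin.ModuleOp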

Source code in xdsl/transforms/convert_stencil_to_csl_stencil.py
@dataclass(frozen=True)
class ConvertAccessOpPattern(RewritePattern):
    """
    Rebuilds stencil.access as csl_stencil.access, which operates on prefetched buffers.

    stencil.access operates on stencil.temp types found at arg_index;
    csl_stencil.access operates on memref< num_neighbors x tensor< buf_size x data_type >> found at the last arg index.

    Note: This is intended to be called in a nested pattern rewriter, such that the above precondition is met.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: stencil.AccessOp, rewriter: PatternRewriter, /):
        assert len(op.offset) == 2
        if isa(op.temp.type, AnyTensorType):
            res_type = TensorType(
                op.temp.type.get_element_type(), op.temp.type.get_shape()[1:]
            )
        else:
            assert isa(op.res.type, AnyTensorType)
            res_type = op.res.type
        rewriter.replace_op(
            op,
            new_access_op := csl_stencil.AccessOp(
                op=op.temp,
                offset=op.offset,
                offset_mapping=op.offset_mapping,
                result_type=res_type,
            ),
        )

        # The stencil-tensorize-z-dimension pass inserts tensor.ExtractSliceOps after stencil.access to remove ghost cells.
        # Since ghost cells are not prefetched, these ops can be removed again. Check if the ExtractSliceOp
        # has no other effect and if so, remove both.
        if (
            isinstance(
                use := new_access_op.result.get_user_of_unique_use(),
                tensor.ExtractSliceOp,
            )
            and use.static_sizes.get_values() == res_type.get_shape()
            and len(use.offsets) == 0
            and len(use.sizes) == 0
            and len(use.strides) == 0
        ):
            rewriter.replace_op(use, [], new_results=[new_access_op.result])

__init__() -> None

match_and_rewrite(op: stencil.AccessOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/convert_stencil_to_csl_stencil.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.AccessOp, rewriter: PatternRewriter, /):
    assert len(op.offset) == 2
    if isa(op.temp.type, AnyTensorType):
        res_type = TensorType(
            op.temp.type.get_element_type(), op.temp.type.get_shape()[1:]
        )
    else:
        assert isa(op.res.type, AnyTensorType)
        res_type = op.res.type
    rewriter.replace_op(
        op,
        new_access_op := csl_stencil.AccessOp(
            op=op.temp,
            offset=op.offset,
            offset_mapping=op.offset_mapping,
            result_type=res_type,
        ),
    )

    # The stencil-tensorize-z-dimension pass inserts tensor.ExtractSliceOps after stencil.access to remove ghost cells.
    # Since ghost cells are not prefetched, these ops can be removed again. Check if the ExtractSliceOp
    # has no other effect and if so, remove both.
    if (
        isinstance(
            use := new_access_op.result.get_user_of_unique_use(),
            tensor.ExtractSliceOp,
        )
        and use.static_sizes.get_values() == res_type.get_shape()
        and len(use.offsets) == 0
        and len(use.sizes) == 0
        and len(use.strides) == 0
    ):
        rewriter.replace_op(use, [], new_results=[new_access_op.result])

ConvertSwapToPrefetchPattern dataclass

Bases: RewritePattern

Translates dmp.swap to csl_stencil.prefetch
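
As a hedged illustration of the shape translation performed here: after dmp decomposition, each swap exchanges a (1, 1, z) slice, and the prefetch result aggregates one z-sized buffer per neighbour. With four f32 swaps and z = 512 (assumed example values):

from xdsl.dialects.builtin import Float32Type, TensorType

num_neighbours, z = 4, 512  # assumed example values
# matches TensorType(t_type.get_element_type(), (len(op.swaps), uniform_size)) below
prefetch_result_t = TensorType(Float32Type(), (num_neighbours, z))  # tensor<4x512xf32>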

Source code in xdsl/transforms/convert_stencil_to_csl_stencil.py
@dataclass
class ConvertSwapToPrefetchPattern(RewritePattern):
    """
    Translates dmp.swap to csl_stencil.prefetch
    """

    num_chunks: int = 1

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
        # remove op if it contains no swaps
        if len(op.swaps) == 0:
            rewriter.erase_op(op, safe_erase=False)
            return

        assert all(len(swap.size) == 3 for swap in op.swaps), (
            "currently only 3-dimensional stencils are supported"
        )

        assert all(swap.size[:2] == (1, 1) for swap in op.swaps), (
            "invoke dmp to decompose from (x,y,z) to (1,1,z)"
        )

        # check that size is uniform
        uniform_size = op.swaps.data[0].size[2]
        assert all(swap.size[2] == uniform_size for swap in op.swaps), (
            "all swaps need to be of uniform size"
        )

        assert (MemRefType.constr() | stencil.StencilTypeConstr).verifies(
            op.input_stencil.type
        )
        assert isa(
            t_type := op.input_stencil.type.get_element_type(), TensorType[Attribute]
        )
        assert op.strategy.comm_layout() is not None, (
            f"topology on {type(op)} is not given"
        )

        # when translating swaps, remove third dimension
        prefetch_op = csl_stencil.PrefetchOp(
            input_stencil=op.input_stencil,
            topo=op.strategy.comm_layout(),
            num_chunks=IntegerAttr(self.num_chunks, 64),
            swaps=[
                csl_stencil.ExchangeDeclarationAttr(swap.neighbor[:2])
                for swap in op.swaps
            ],
            result_type=TensorType(
                t_type.get_element_type(),
                (len(op.swaps), uniform_size),
            ),
        )

        # if the rewriter needs a result, use `input_stencil` as a drop-in replacement
        # prefetch_op produces a result that needs to be handled separately
        # note that only un-bufferized dmp.swaps produce a result
        rewriter.replace_op(
            op, prefetch_op, new_results=[op.input_stencil] if op.swapped_values else []
        )

        # uses have to be retrieved *before* the loop because of the rewriting happening inside the loop
        uses = list(op.input_stencil.uses)

        # csl_stencil.prefetch, unlike dmp.swap, has a return value. This is added as the last arg
        # to stencil.apply, before rebuilding the op and replacing stencil.access ops by csl_stencil.access ops
        # that reference the prefetched buffers (note, this is only done for neighbor accesses)
        for use in uses:
            if not isinstance(use.operation, stencil.ApplyOp):
                continue
            apply_op = use.operation

            # arg_idx points to the stencil.temp type whose data is prefetched in a separate buffer
            arg_idx = apply_op.args.index(op.input_stencil)
            field_block_arg = apply_op.region.block.args[arg_idx]

            # add the prefetched buffer as the last arg to stencil.access
            prefetch_block_arg = apply_op.region.block.insert_arg(
                prefetch_op.result.type, len(apply_op.args)
            )
            rewriter.replace_uses_with_if(
                field_block_arg,
                prefetch_block_arg,
                lambda use: isinstance(use.operation, stencil.AccessOp)
                and tuple(use.operation.offset) != (0, 0),
            )

            # rebuild stencil.apply op
            r_types = apply_op.result_types
            assert isa(r_types, Sequence[stencil.TempType[Attribute]])
            new_apply_op = stencil.ApplyOp.build(
                operands=[[*apply_op.args, prefetch_op.result], apply_op.dest],
                regions=[apply_op.detach_region(apply_op.region)],
                result_types=[r_types],
                properties=apply_op.properties,
                attributes=apply_op.attributes,
            )
            rewriter.replace_op(apply_op, new_apply_op)

num_chunks: int = 1 class-attribute instance-attribute

__init__(num_chunks: int = 1) -> None

match_and_rewrite(op: dmp.SwapOp, rewriter: PatternRewriter)


SplitVarithOpPattern dataclass

Bases: RewritePattern

Splits a varith op into two, depending on whether its operands hold stencil accesses to buf (only) or any other accesses.

This pass is intended to be run with buf set to the block arg indicating data received from neighbours.
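
A usage sketch, mirroring how ConvertApplyOpPattern below drives this pattern over a single apply region (apply_op, prefetch_idx, and rewriter are assumed to be in scope):

from xdsl.pattern_rewriter import PatternRewriteWalker

# split varith ops inside this apply's region only, so that neighbour-data
# summands can later be moved into the receive_chunk region
PatternRewriteWalker(
    SplitVarithOpPattern(apply_op.region.block.args[prefetch_idx]),
    apply_recursively=False,
    listener=rewriter,
).rewrite_region(apply_op.region)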

Source code in xdsl/transforms/convert_stencil_to_csl_stencil.py
@dataclass(frozen=True)
class SplitVarithOpPattern(RewritePattern):
    """
    Splits a varith op into two, depending on whether its operands hold stencil accesses to `buf` (only)
    or any other accesses.

    This pass is intended to be run with `buf` set to the block arg indicating data received from neighbours.
    """

    buf: BlockArgument

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /):
        if not (apply := _get_apply_op(op)) or not (
            buf_idx := _get_prefetch_buf_idx(apply)
        ):
            return
        buf = apply.region.block.args[buf_idx]
        buf_accesses, others = list[SSAValue](), list[SSAValue]()

        for arg in op.args:
            accs = get_stencil_access_operands(arg)
            (others, buf_accesses)[buf in accs and len(accs) == 1].append(arg)

        if len(others) > 0 and len(buf_accesses) > 0:
            rewriter.replace_op(
                op,
                [
                    n_op := type(op)(*buf_accesses),
                    type(op)(n_op, *others),
                ],
            )

buf: BlockArgument instance-attribute

__init__(buf: BlockArgument) -> None

match_and_rewrite(op: varith.VarithOp, rewriter: PatternRewriter)


ConvertApplyOpPattern dataclass

Bases: RewritePattern

Fuses a csl_stencil.prefetch and a stencil.apply to build a csl_stencil.apply.

If there are several candidate prefetch ops, the one with the largest result buffer size is selected. The selection is greedy, and could in the future be expanded into a more global selection optimising for minimal prefetch overhead across multiple apply ops.

Source code in xdsl/transforms/convert_stencil_to_csl_stencil.py
@dataclass(frozen=True)
class ConvertApplyOpPattern(RewritePattern):
    """
    Fuses a `csl_stencil.prefetch` and a `stencil.apply` to build a `csl_stencil.apply`.

    If there are several candidate prefetch ops, the one with the largest result buffer
    size is selected.
    The selection is greedy, and could in the future be expanded into a more global
    selection optimising for minimal prefetch overhead across multiple apply ops.
    """

    num_chunks: int = 1
    """
    Number of chunks into which communication and computation should be split.
    Effectively, the number of times `csl_stencil.apply.receive_chunk` is executed
    and the size of the tensor slices it handles.
    Higher values may increase compute overhead but reduce the size of communication
    buffers when lowered.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /):
        if not (prefetch_idx := _get_prefetch_buf_idx(op)):
            return

        # select the prefetch with the biggest communication overhead to be fused with matched stencil.apply
        prefetch = op.operands[prefetch_idx]
        assert isinstance(prefetch, OpResult)
        assert isinstance(prefetch.op, csl_stencil.PrefetchOp)
        field_idx = op.operands.index(prefetch.op.input_stencil)
        assert isinstance(prefetch.op, csl_stencil.PrefetchOp)
        assert isa(prefetch.type, TensorType[Attribute])
        field_op_arg = prefetch.op.input_stencil

        # add empty tensor before op to be used as `accumulator`
        # this could potentially be re-used if we have one of the same size lying around
        accumulator = tensor.EmptyOp(
            (),
            TensorType(prefetch.type.get_element_type(), prefetch.type.get_shape()[1:]),
        )
        rewriter.insert_op(accumulator, InsertPoint.before(op))

        # run pass (on this apply's region only) to consume data from `prefetch` accesses first
        # find varith ops and split according to neighbour data
        PatternRewriteWalker(
            SplitVarithOpPattern(op.region.block.args[prefetch_idx]),
            apply_recursively=False,
            listener=rewriter,
        ).rewrite_region(op.region)

        # determine how ops should be split across the two regions
        chunk_region_ops, done_exchange_ops = split_ops(
            list(op.region.block.ops), op.region.block.args[prefetch_idx], rewriter
        )

        # fetch the value that receive_chunk computes
        if isinstance(chunk_region_ops[-1], stencil.ReturnOp):
            chunk_res = chunk_region_ops[-1].operands[0]
        else:
            chunk_res = chunk_region_ops[-1].results[0]

        # after the region split, check which block args (from the old ops block) are accessed in each of the new regions;
        # ignore accesses to block args that are already part of the region's required signature
        chunk_region_used_block_args = sorted(
            set(
                x
                for o in chunk_region_ops
                for x in o.operands
                if isinstance(x, BlockArgument) and x.index != prefetch_idx
            ),
            key=lambda b: b.index,
        )
        done_exchange_used_block_args = sorted(
            set(
                x
                for o in done_exchange_ops
                for x in o.operands
                if isinstance(x, BlockArgument) and x.index != field_idx
            ),
            key=lambda b: b.index,
        )

        # set up region signatures, comprising fixed and optional args - see docs on `csl_stencil.apply` for details
        chunk_region_args = [
            # required arg 0: slice of type(%prefetch)
            TensorType(
                prefetch.type.get_element_type(),
                (
                    len(prefetch.op.swaps),
                    prefetch.type.get_shape()[1] // self.num_chunks,
                ),
            ),
            # required arg 1: %offset
            IndexType(),
            # required arg 2: %accumulator
            accumulator.tensor.type,
            # optional args: as needed by the ops
            *[a.type for a in chunk_region_used_block_args],
        ]
        done_exchange_args = [
            # required arg 0: stencil.temp to access own data
            field_op_arg.type,
            # required arg 1: %accumulator
            accumulator.tensor.type,
            # optional args: as needed by the ops
            *[a.type for a in done_exchange_used_block_args],
        ]

        # set up two regions
        receive_chunk = Region(Block(arg_types=chunk_region_args))
        done_exchange = Region(Block(arg_types=done_exchange_args))

        # translate old to new block arg index for optional args
        chunk_region_oprnd_table = dict[Operand, Operand](
            (old, receive_chunk.block.args[idx])
            for idx, old in enumerate(chunk_region_used_block_args, start=3)
        )
        done_exchange_oprnd_table = dict[Operand, Operand](
            (old, done_exchange.block.args[idx])
            for idx, old in enumerate(done_exchange_used_block_args, start=2)
        )

        # add translation from old to new arg index for non-optional args - note that
        # access to the accumulator must be handled separately below
        chunk_region_oprnd_table[op.region.block.args[prefetch_idx]] = (
            receive_chunk.block.args[0]
        )
        done_exchange_oprnd_table[op.region.block.args[field_idx]] = (
            done_exchange.block.args[0]
        )
        done_exchange_oprnd_table[chunk_res] = done_exchange.block.args[1]

        # detach ops from old region
        for o in op.region.block.ops:
            op.region.block.detach_op(o)

        # add operations from list to receive_chunk, use translation table to rebuild operands
        for o in chunk_region_ops:
            if isinstance(o, stencil.ReturnOp | csl_stencil.YieldOp):
                break
            o.operands = [chunk_region_oprnd_table.get(x, x) for x in o.operands]
            rewriter.insert_op(o, InsertPoint.at_end(receive_chunk.block))

        # put `chunk_res` into `accumulator` (using tensor.insert_slice) and yield the result
        rewriter.insert_op(
            [
                insert_slice_op := tensor.InsertSliceOp.get(
                    source=chunk_res,
                    dest=receive_chunk.block.args[2],
                    offsets=(receive_chunk.block.args[1],),
                    static_sizes=(prefetch.type.get_shape()[1] // self.num_chunks,),
                ),
                csl_stencil.YieldOp(insert_slice_op.result),
            ],
            InsertPoint.at_end(receive_chunk.block),
        )

        # add operations from list to done_exchange, use translation table to rebuild operands
        for o in done_exchange_ops:
            o.operands = [done_exchange_oprnd_table.get(x, x) for x in o.operands]
            rewriter.insert_op(o, InsertPoint.at_end(done_exchange.block))
            if isinstance(o, stencil.ReturnOp):
                rewriter.replace_op(o, csl_stencil.YieldOp(*o.operands))

        rewriter.replace_op(
            op,
            csl_stencil.ApplyOp(
                operands=[
                    field_op_arg,
                    accumulator,
                    [op.operands[a.index] for a in chunk_region_used_block_args],
                    [op.operands[a.index] for a in done_exchange_used_block_args],
                    op.dest,
                ],
                properties={
                    "swaps": prefetch.op.swaps,
                    "topo": prefetch.op.topo,
                    "num_chunks": IntegerAttr(self.num_chunks, IntegerType(64)),
                    "bounds": op.bounds,
                },
                regions=[
                    receive_chunk,
                    done_exchange,
                ],
                result_types=[op.result_types],
            ),
        )

        if not prefetch.uses:
            rewriter.erase_op(prefetch.op)

num_chunks: int = 1 class-attribute instance-attribute

Number of chunks into which communication and computation should be split. Effectively, the number of times csl_stencil.apply.receive_chunk is executed and the size of the tensor slices it handles. Higher values may increase compute overhead but reduce the size of communication buffers when lowered.
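
For example (values assumed), a prefetched buffer of shape (num_neighbours, buf_size) is consumed in num_chunks invocations of receive_chunk, each over a slice of buf_size // num_chunks elements per neighbour:

num_neighbours, buf_size, num_chunks = 4, 512, 2  # assumed example values
chunk_shape = (num_neighbours, buf_size // num_chunks)
assert chunk_shape == (4, 256)  # shape of receive_chunk's required arg 0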

__init__(num_chunks: int = 1) -> None

match_and_rewrite(op: stencil.ApplyOp, rewriter: PatternRewriter)


PromoteCoefficients

Bases: RewritePattern

Promotes constant coefficients to attributes. When a csl_stencil.access is immediately multiplied by an arith.constant as the sole use of the accessed data, the constant is promoted to a coefficient property in the csl_stencil.apply op.
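
A hedged before/after sketch of the IR shape this pattern matches (names and types are illustrative assumptions, shown as comments):

# before: the access result is used exactly once, by a mulf with a splat constant
#   %v = csl_stencil.access %chunk[1, 0]
#   %c = arith.constant dense<2.500000e-01> : tensor<510xf32>
#   %m = arith.mulf %v, %c
# after: %m is erased and its uses read %v directly, while the enclosing
# csl_stencil.apply records the coefficient via apply.add_coeff([1, 0], 0.25)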

Source code in xdsl/transforms/convert_stencil_to_csl_stencil.py
class PromoteCoefficients(RewritePattern):
    """
    Promotes constant coefficients to attributes. When a `csl_stencil.access` is immediately multiplied by
    an `arith.constant` as the sole use of the accessed data, the constant is promoted to a coefficient property
    in the `csl_stencil.apply` op.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: csl_stencil.AccessOp, rewriter: PatternRewriter, /):
        if (
            not isinstance(apply := op.get_apply(), csl_stencil.ApplyOp)
            or not op.op == apply.receive_chunk.block.args[0]
            or not isinstance(mulf := op.result.get_user_of_unique_use(), arith.MulfOp)
        ):
            return

        coeff = mulf.lhs if op.result == mulf.rhs else mulf.rhs

        if (
            not isinstance(cnst := coeff.owner, arith.ConstantOp)
            or not isinstance(dense := cnst.value, DenseIntOrFPElementsAttr)
            or not dense.is_splat()
        ):
            return

        val = dense.get_attrs()[0]
        assert isinstance(val, FloatAttr)
        apply.add_coeff(op.offset, val)
        rewriter.replace_op(mulf, [], new_results=[op.result])

match_and_rewrite(op: csl_stencil.AccessOp, rewriter: PatternRewriter)


TransformPrefetch

Bases: RewritePattern

Rewrites a prefetch into a communicate-only csl_stencil.apply

Source code in xdsl/transforms/convert_stencil_to_csl_stencil.py
class TransformPrefetch(RewritePattern):
    """
    Rewrites a prefetch into a communicate-only csl_stencil.apply
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: csl_stencil.PrefetchOp, rewriter: PatternRewriter, /
    ):
        a_buf = tensor.EmptyOp((), op.result.type)
        # because we are building a set of offsets, we are not retaining offset mappings
        offsets = [swap.neighbor for swap in op.swaps]

        assert isa(op.result.type, AnyTensorType)
        chunk_buf_t = TensorType(
            op.result.type.get_element_type(),
            (
                len(op.swaps),
                op.result.type.get_shape()[1] // op.num_chunks.value.data,
            ),
        )
        chunk_t = TensorType(chunk_buf_t.element_type, chunk_buf_t.get_shape()[1:])

        block = Block(arg_types=[chunk_buf_t, builtin.IndexType(), op.result.type])
        block2 = Block(arg_types=[op.input_stencil.type, op.result.type])
        block2.add_op(csl_stencil.YieldOp())

        with ImplicitBuilder(block) as (buf, offset, acc):
            dest = acc
            for i, acc_offset in enumerate(offsets):
                ac_op = csl_stencil.AccessOp(
                    buf, stencil.IndexAttr.get(*acc_offset), chunk_t
                )
                assert isa(ac_op.result.type, AnyTensorType)
                # insert one 1d slice (see static_sizes) into the 2d accumulator at offset (i, `offset`)
                # (see static_offsets), where the latter offset is provided dynamically (see offsets)
                dest = tensor.InsertSliceOp.get(
                    source=ac_op.result,
                    dest=dest,
                    static_sizes=[1, *ac_op.result.type.get_shape()],
                    static_offsets=[i, DYNAMIC_INDEX],
                    offsets=[offset],
                ).result
            csl_stencil.YieldOp(dest)

        apply_op = csl_stencil.ApplyOp(
            operands=[op.input_stencil, a_buf, [], [], []],
            regions=[Region(block), Region(block2)],
            properties={
                "swaps": op.swaps,
                "topo": op.topo,
                "num_chunks": op.num_chunks,
            },
            result_types=[[]],
        )

        rewriter.replace_op(op, [a_buf, apply_op], new_results=[a_buf.tensor])

match_and_rewrite(op: csl_stencil.PrefetchOp, rewriter: PatternRewriter)


ConvertStencilToCslStencilPass dataclass

Bases: ModulePass
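
A minimal usage sketch (ctx and module are assumed to be an existing Context and builtin.ModuleOp):

from xdsl.transforms.convert_stencil_to_csl_stencil import (
    ConvertStencilToCslStencilPass,
)

# slice communication into two chunks; num_chunks=1 is the default
ConvertStencilToCslStencilPass(num_chunks=2).apply(ctx, module)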

Source code in xdsl/transforms/convert_stencil_to_csl_stencil.py
@dataclass(frozen=True)
class ConvertStencilToCslStencilPass(ModulePass):
    name = "convert-stencil-to-csl-stencil"

    # chunks into which to slice communication
    num_chunks: int = 1

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        ConvertArithToVarithPass().apply(ctx, op)
        module_pass = PatternRewriteWalker(
            GreedyRewritePatternApplier(
                [
                    ConvertSwapToPrefetchPattern(num_chunks=self.num_chunks),
                    ConvertAccessOpPattern(),
                ]
            ),
            walk_reverse=False,
            apply_recursively=False,
        )
        module_pass.rewrite_module(op)
        PatternRewriteWalker(
            GreedyRewritePatternApplier(
                [
                    ConvertApplyOpPattern(num_chunks=self.num_chunks),
                    PromoteCoefficients(),
                    TransformPrefetch(),
                ]
            ),
            apply_recursively=False,
            walk_reverse=True,
        ).rewrite_module(op)

        ConvertVarithToArithPass().apply(ctx, op)

        if self.num_chunks > 1:
            BackpropagateStencilShapes().apply(ctx, op)

name = 'convert-stencil-to-csl-stencil' class-attribute instance-attribute

num_chunks: int = 1 class-attribute instance-attribute

__init__(num_chunks: int = 1) -> None

apply(ctx: Context, op: ModuleOp) -> None


get_stencil_access_operands(op: Operand) -> set[Operand]

Returns the operands accessed by all stencil accesses performed by op and its dependencies.
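
A short usage sketch, mirroring the neighbour-data check in SplitVarithOpPattern (value is an assumed SSA value and buf the prefetch block argument):

accs = get_stencil_access_operands(value)
# `value` depends exclusively on neighbour data iff `buf` is the only
# buffer accessed along its def-use chain
neighbour_only = buf in accs and len(accs) == 1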

Source code in xdsl/transforms/convert_stencil_to_csl_stencil.py
def get_stencil_access_operands(op: Operand) -> set[Operand]:
    """
    Returns the operands accessed by all stencil accesses performed by `op` and its dependencies.
    """
    res: set[Operand] = set()
    frontier: set[Operand] = {op}
    next: set[Operand] = set()
    while len(frontier) > 0:
        for o in frontier:
            if isinstance(o, OpResult):
                if isinstance(o.op, csl_stencil.AccessOp):
                    res.add(o.op.op)
                else:
                    next.update(o.op.operands)
        frontier = next
        next = set()

    return res

split_ops(ops: Sequence[Operation], buf: BlockArgument, rewriter: PatternRewriter) -> tuple[Sequence[Operation], Sequence[Operation]]

Returns a split of ops into an (a,b) tuple, such that:

  • a contains neighbour accesses to buf plus the minimum set of operations needed to reduce the accessed data to a single value
  • b contains everything else

If no valid split can be found, fall back to placing every op except the final stencil.return in a and only the return in b.

This function does not attempt to arithmetically re-structure the computation to obtain a good split. To do this, RestructureSymmetricReductionPattern() may be executed first.
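
A usage sketch matching the call site in ConvertApplyOpPattern (op is an assumed stencil.apply whose prefetched buffer sits at block-arg index prefetch_idx):

chunk_region_ops, done_exchange_ops = split_ops(
    list(op.region.block.ops), op.region.block.args[prefetch_idx], rewriter
)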

Source code in xdsl/transforms/convert_stencil_to_csl_stencil.py
def split_ops(
    ops: Sequence[Operation], buf: BlockArgument, rewriter: PatternRewriter
) -> tuple[Sequence[Operation], Sequence[Operation]]:
    """
    Returns a split of `ops` into an `(a,b)` tuple, such that:

    - `a` contains neighbour accesses to `buf` plus the minimum set of operations needed to reduce the accessed data to a single value
    - `b` contains everything else

    If no valid split can be found, fall back to placing every op except the final `stencil.return` in `a` and only the return in `b`.

    This function does not attempt to arithmetically re-structure the computation to obtain a good split. To do this,
    `RestructureSymmetricReductionPattern()` may be executed first.
    """
    a: list[Operation] = []
    b: list[Operation] = []
    rem: list[Operation] = []
    for op in ops:
        if isinstance(op, csl_stencil.AccessOp):
            (b, a)[op.op == buf].append(op)
        elif isinstance(op, arith.ConstantOp):
            a.append(op)
        else:
            rem.append(op)

    # loop until we can make no more changes, or until only 1 thing computed in `a` is used outside of it
    has_changes = True
    while (
        len(
            # ops in `a` whose results are used outside of `a`
            a_exports := set(
                op
                for op in a
                for result in op.results
                for use in result.uses
                if use.operation not in a
            )
        )
        > 1
        and has_changes
    ):
        has_changes = False

        # find ops that directly depend on `a` but are not themselves in `a`
        for exp in a_exports:
            for result in exp.results:
                for use in result.uses:
                    # op is only movable if *all* operands are already in `a` (and it hasn't been moved yet)
                    if (op := use.operation) in rem and all(
                        x.op in a for x in op.operands if isinstance(x, OpResult)
                    ):
                        has_changes = True
                        a.append(use.operation)
                        rem.remove(use.operation)

    # find constants in `a` needed outside of `a`
    cnst_exports = tuple(
        cnst for cnst in a_exports if isinstance(cnst, arith.ConstantOp)
    )

    # `a` exports one value plus any number of constants - duplicate exported constants and return op split
    if len(a_exports) == 1 + len(cnst_exports):
        recv_chunk_ops, done_exch_ops = list[Operation](), list[Operation]()
        for op in ops:
            if op in a:
                recv_chunk_ops.append(op)
                if op in cnst_exports:
                    assert isinstance(op, arith.ConstantOp)
                    # create a copy of the constant in the second region
                    done_exch_ops.append(cln := op.clone())
                    # rewire ops of the second region to use the copied constant
                    rewriter.replace_uses_with_if(
                        op.result,
                        cln.result,
                        lambda use: use.operation in b or use.operation in rem,
                    )
            else:
                done_exch_ops.append(op)

        return recv_chunk_ops, done_exch_ops

    # fallback
    # always place `stencil.return` in second block
    return ops[:-1], [ops[-1]]