Stencil global to local

stencil_global_to_local

ChangeStoreOpSizes dataclass

Bases: RewritePattern

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
@dataclass
class ChangeStoreOpSizes(RewritePattern):
    strategy: dmp.DomainDecompositionStrategy

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: stencil.StoreOp, rewriter: PatternRewriter, /):
        assert all(
            integer_attr.data == 0 for integer_attr in op.bounds.lb.array.data
        ), "lb must be 0"
        shape: tuple[int, ...] = tuple(
            integer_attr.data for integer_attr in op.bounds.ub.array.data
        )
        new_shape = self.strategy.calc_resize(shape)
        op.bounds = stencil.StencilBoundsAttr.new(
            [
                stencil.IndexAttr.get(*(len(new_shape) * [0])),
                stencil.IndexAttr.get(*new_shape),
            ]
        )

strategy: dmp.DomainDecompositionStrategy instance-attribute

__init__(strategy: dmp.DomainDecompositionStrategy) -> None

match_and_rewrite(op: stencil.StoreOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.StoreOp, rewriter: PatternRewriter, /):
    assert all(
        integer_attr.data == 0 for integer_attr in op.bounds.lb.array.data
    ), "lb must be 0"
    shape: tuple[int, ...] = tuple(
        integer_attr.data for integer_attr in op.bounds.ub.array.data
    )
    new_shape = self.strategy.calc_resize(shape)
    op.bounds = stencil.StencilBoundsAttr.new(
        [
            stencil.IndexAttr.get(*(len(new_shape) * [0])),
            stencil.IndexAttr.get(*new_shape),
        ]
    )
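
For intuition, a minimal sketch of the resize this pattern performs (hedged: it assumes the 2d-grid strategy used by DistributeStencilPass below splits each axis evenly across ranks):

from xdsl.dialects.experimental import dmp

# Illustrative values: a 64x64 global store split across a 2x2 rank
# grid is expected to shrink to a 32x32 local store region.
strategy = dmp.GridSlice2dAttr((2, 2))
assert strategy.calc_resize((64, 64)) == (32, 32)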

AddHaloExchangeOps dataclass

Bases: RewritePattern

This rewrite adds a halo exchange (a dmp.swap op) after each stencil.load op

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
@dataclass
class AddHaloExchangeOps(RewritePattern):
    """
    This rewrite adds a `stencil.halo_exchange` after each `stencil.load` op
    """

    strategy: dmp.DomainDecompositionStrategy

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: stencil.LoadOp, rewriter: PatternRewriter, /):
        swap_op = dmp.SwapOp.get(op.res, self.strategy)
        assert swap_op.swapped_values
        rewriter.insert_op(swap_op, InsertPoint.after(op))
        for use in tuple(op.res.uses):
            if use.operation is swap_op:
                continue
            use.operation.operands[use.index] = swap_op.swapped_values
            rewriter.handle_operation_modification(use.operation)

strategy: dmp.DomainDecompositionStrategy instance-attribute

__init__(strategy: dmp.DomainDecompositionStrategy) -> None

match_and_rewrite(op: stencil.LoadOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.LoadOp, rewriter: PatternRewriter, /):
    swap_op = dmp.SwapOp.get(op.res, self.strategy)
    assert swap_op.swapped_values
    rewriter.insert_op(swap_op, InsertPoint.after(op))
    for use in tuple(op.res.uses):
        if use.operation is swap_op:
            continue
        use.operation.operands[use.index] = swap_op.swapped_values
        rewriter.handle_operation_modification(use.operation)
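
Schematically, the rewrite has the following effect (hand-written pseudo-IR for illustration; the exact dmp.swap assembly syntax is elided):

# before:
#   %t = stencil.load %field
#   ... ops using %t ...
#
# after:
#   %t  = stencil.load %field
#   %t2 = dmp.swap %t              # inserted halo exchange
#   ... the same ops, now using %t2 ...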

LowerHaloExchangeToMpi dataclass

Bases: RewritePattern

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
@dataclass
class LowerHaloExchangeToMpi(RewritePattern):
    init: bool
    debug_prints: bool = False

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
        exchanges = list(op.swaps)

        input_type = cast(ContainerType, op.input_stencil.type)

        rewriter.replace_op(
            op,
            list(
                generate_mpi_calls_for(
                    op.input_stencil,
                    exchanges,
                    input_type.get_element_type(),
                    op.strategy.comm_layout(),
                    emit_init=self.init,
                    emit_debug=self.debug_prints,
                )
            ),
            [],
        )

init: bool instance-attribute

debug_prints: bool = False class-attribute instance-attribute

__init__(init: bool, debug_prints: bool = False) -> None

match_and_rewrite(op: dmp.SwapOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
    exchanges = list(op.swaps)

    input_type = cast(ContainerType, op.input_stencil.type)

    rewriter.replace_op(
        op,
        list(
            generate_mpi_calls_for(
                op.input_stencil,
                exchanges,
                input_type.get_element_type(),
                op.strategy.comm_layout(),
                emit_init=self.init,
                emit_debug=self.debug_prints,
            )
        ),
        [],
    )

MpiLoopInvariantCodeMotion

THIS IS NOT A REWRITE PATTERN!

This is a two-stage rewrite that modifies operations in a manner that is incompatible with the PatternRewriter!

It implements a custom rewrite_module() method directly on the class.

This rewrite takes all memref.alloc, mpi.comm.rank, mpi.allocate and mpi.unwrap_memref ops and moves them "up" until it hits a func.func, then places them before the op they appear in.

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
class MpiLoopInvariantCodeMotion:
    """
    THIS IS NOT A REWRITE PATTERN!

    This is a two-stage rewrite that modifies operations in a manner
    that is incompatible with the PatternRewriter!

    It implements a custom rewrite_module() method directly
    on the class.

    This rewrite takes all memref.alloc, mpi.comm.rank, mpi.allocate
    and mpi.unwrap_memref ops and moves them "up" until it hits a
    func.func, then places them *before* the op they appear in.
    """

    seen_ops: set[Operation]
    has_init: set[func.FuncOp]

    def __init__(self):
        self.seen_ops = set()
        self.has_init = set()

    def rewrite(
        self,
        op: (
            memref.AllocOp
            | mpi.CommRankOp
            | mpi.AllocateTypeOp
            | mpi.UnwrapMemRefOp
            | mpi.InitOp
        ),
        rewriter: Rewriter,
        /,
    ):
        if op in self.seen_ops:
            return
        self.seen_ops.add(op)

        # memref unwraps can always be moved to their allocation
        if isinstance(op, mpi.UnwrapMemRefOp) and isinstance(
            op.ref.owner, memref.AllocOp
        ):
            op.detach()
            rewriter.insert_op(op, InsertPoint.after(op.ref.owner))
            return

        base = op
        parent = op.parent_op()
        # walk upwards until we hit a function
        while parent is not None and not isinstance(parent, func.FuncOp):
            base = parent
            parent = base.parent_op()

        # check that we did not run into "nowhere"
        assert parent is not None, "Expected MPI to be inside a func.FuncOp!"
        assert isinstance(parent, func.FuncOp)  # this must be true now

        # check that we "ascended"
        if base == op:
            return

        if not can_loop_invariant_code_move(op):
            return

        # if we move an mpi.init, generate a finalize()!
        if isinstance(op, mpi.InitOp):
            # ignore multiple inits
            if parent in self.has_init:
                rewriter.erase_op(op)
                return
            self.has_init.add(parent)
            # add a finalize() call to the end of the function
            block = parent.regions[0].blocks[-1]
            return_op = block.last_op
            assert return_op is not None
            rewriter.insert_op(mpi.FinalizeOp(), InsertPoint.before(return_op))

        ops = list(collect_args_recursive(op))
        for found_op in ops:
            found_op.detach()
            rewriter.insert_op(found_op, InsertPoint.before(base))

    def get_matcher(
        self,
        worklist: list[
            memref.AllocOp
            | mpi.CommRankOp
            | mpi.AllocateTypeOp
            | mpi.UnwrapMemRefOp
            | mpi.InitOp
        ],
    ) -> Callable[[Operation], None]:
        """
        Returns a match() function that adds operations to a worklist
        if they satisfy some criteria.
        """

        def match(op: Operation):
            if isinstance(
                op,
                memref.AllocOp
                | mpi.CommRankOp
                | mpi.AllocateTypeOp
                | mpi.UnwrapMemRefOp
                | mpi.InitOp,
            ):
                worklist.append(op)

        return match

    def rewrite_module(self, op: builtin.ModuleOp):
        """
        Apply the rewrite to a module.

        We do a two-stage rewrite because we are modifying
        the operations while we loop over them, which would throw off `op.walk`.
        """
        # collect all ops that should be rewritten
        worklist: list[
            memref.AllocOp
            | mpi.CommRankOp
            | mpi.AllocateTypeOp
            | mpi.UnwrapMemRefOp
            | mpi.InitOp
        ] = list()
        matcher = self.get_matcher(worklist)
        for o in op.walk():
            matcher(o)

        # rewrite ops
        rewriter = Rewriter()
        for matched_op in worklist:
            self.rewrite(matched_op, rewriter)

seen_ops: set[Operation] = set() instance-attribute

has_init: set[func.FuncOp] = set() instance-attribute

__init__()

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
def __init__(self):
    self.seen_ops = set()
    self.has_init = set()

rewrite(op: memref.AllocOp | mpi.CommRankOp | mpi.AllocateTypeOp | mpi.UnwrapMemRefOp | mpi.InitOp, rewriter: Rewriter)

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
def rewrite(
    self,
    op: (
        memref.AllocOp
        | mpi.CommRankOp
        | mpi.AllocateTypeOp
        | mpi.UnwrapMemRefOp
        | mpi.InitOp
    ),
    rewriter: Rewriter,
    /,
):
    if op in self.seen_ops:
        return
    self.seen_ops.add(op)

    # memref unwraps can always be moved to their allocation
    if isinstance(op, mpi.UnwrapMemRefOp) and isinstance(
        op.ref.owner, memref.AllocOp
    ):
        op.detach()
        rewriter.insert_op(op, InsertPoint.after(op.ref.owner))
        return

    base = op
    parent = op.parent_op()
    # walk upwards until we hit a function
    while parent is not None and not isinstance(parent, func.FuncOp):
        base = parent
        parent = base.parent_op()

    # check that we did not run into "nowhere"
    assert parent is not None, "Expected MPI to be inside a func.FuncOp!"
    assert isinstance(parent, func.FuncOp)  # this must be true now

    # check that we "ascended"
    if base == op:
        return

    if not can_loop_invariant_code_move(op):
        return

    # if we move an mpi.init, generate a finalize()!
    if isinstance(op, mpi.InitOp):
        # ignore multiple inits
        if parent in self.has_init:
            rewriter.erase_op(op)
            return
        self.has_init.add(parent)
        # add a finalize() call to the end of the function
        block = parent.regions[0].blocks[-1]
        return_op = block.last_op
        assert return_op is not None
        rewriter.insert_op(mpi.FinalizeOp(), InsertPoint.before(return_op))

    ops = list(collect_args_recursive(op))
    for found_op in ops:
        found_op.detach()
        rewriter.insert_op(found_op, InsertPoint.before(base))

get_matcher(worklist: list[memref.AllocOp | mpi.CommRankOp | mpi.AllocateTypeOp | mpi.UnwrapMemRefOp | mpi.InitOp]) -> Callable[[Operation], None]

Returns a match() function that adds operations to a worklist if they satisfy some criteria.

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
def get_matcher(
    self,
    worklist: list[
        memref.AllocOp
        | mpi.CommRankOp
        | mpi.AllocateTypeOp
        | mpi.UnwrapMemRefOp
        | mpi.InitOp
    ],
) -> Callable[[Operation], None]:
    """
    Returns a match() function that adds operations to a worklist
    if they satisfy some criteria.
    """

    def match(op: Operation):
        if isinstance(
            op,
            memref.AllocOp
            | mpi.CommRankOp
            | mpi.AllocateTypeOp
            | mpi.UnwrapMemRefOp
            | mpi.InitOp,
        ):
            worklist.append(op)

    return match

rewrite_module(op: builtin.ModuleOp)

Apply the rewrite to a module.

We do a two-stage rewrite because we are modifying the operations while we loop over them, which would throw off op.walk.

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
def rewrite_module(self, op: builtin.ModuleOp):
    """
    Apply the rewrite to a module.

    We do a two-stage rewrite because we are modifying
    the operations while we loop over them, which would throw off `op.walk`.
    """
    # collect all ops that should be rewritten
    worklist: list[
        memref.AllocOp
        | mpi.CommRankOp
        | mpi.AllocateTypeOp
        | mpi.UnwrapMemRefOp
        | mpi.InitOp
    ] = list()
    matcher = self.get_matcher(worklist)
    for o in op.walk():
        matcher(o)

    # rewrite ops
    rewriter = Rewriter()
    for matched_op in worklist:
        self.rewrite(matched_op, rewriter)
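
Because this is not a RewritePattern, it is driven directly rather than through a PatternRewriteWalker. A minimal sketch (assuming module is an existing builtin.ModuleOp; this is exactly what DmpToMpiPass does below):

MpiLoopInvariantCodeMotion().rewrite_module(module)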

DmpDecompositionPass dataclass

Bases: ModulePass, ABC

Represents a pass that takes a strategy as input

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
@dataclass(frozen=True)
class DmpDecompositionPass(ModulePass, ABC):
    """
    Represents a pass that takes a strategy as input
    """

__init__() -> None

DistributeStencilPass dataclass

Bases: DmpDecompositionPass

Decompose a stencil to apply to a local domain.

This pass applies stencil shape inference!

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
@dataclass(frozen=True)
class DistributeStencilPass(DmpDecompositionPass):
    """
    Decompose a stencil to apply to a local domain.

    This pass applies stencil shape inference!
    """

    name = "distribute-stencil"

    STRATEGIES: ClassVar[dict[str, type[dmp.GridSlice2dAttr | dmp.GridSlice3dAttr]]] = {
        "2d-grid": dmp.GridSlice2dAttr,
        "3d-grid": dmp.GridSlice3dAttr,
    }

    slices: tuple[int, ...]
    """
    Number of slices to decompose the input into
    """

    strategy: str
    """
    Name of the decomposition strategy to use; see the STRATEGIES class attribute for options
    """

    restrict_domain: bool = True
    """
    Apply the domain restriction, i.e. change the stencil.apply to operate on the
    local domain. If false, the generated code is assumed to be already local.
    """

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        if self.strategy not in self.STRATEGIES:
            raise ValueError(f"Unknown strategy: {self.strategy}")
        strategy = self.STRATEGIES[self.strategy](self.slices)

        rewrites: list[RewritePattern] = [
            AddHaloExchangeOps(strategy),
        ]

        if self.restrict_domain:
            rewrites.append(ChangeStoreOpSizes(strategy))

        PatternRewriteWalker(
            GreedyRewritePatternApplier(rewrites),
            apply_recursively=False,
        ).rewrite_module(op)

name = 'distribute-stencil' class-attribute instance-attribute

STRATEGIES: dict[str, type[dmp.GridSlice2dAttr | dmp.GridSlice3dAttr]] = {'2d-grid': dmp.GridSlice2dAttr, '3d-grid': dmp.GridSlice3dAttr} class-attribute

slices: tuple[int, ...] instance-attribute

Number of slices to decompose the input into

strategy: str instance-attribute

Name of the decomposition strategy to use; see the STRATEGIES class attribute for options

restrict_domain: bool = True class-attribute instance-attribute

Apply the domain restriction, i.e. change the stencil.apply to operate on the local domain. If false, the generated code is assumed to be already local.

__init__(slices: tuple[int, ...], strategy: str, restrict_domain: bool = True) -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    if self.strategy not in self.STRATEGIES:
        raise ValueError(f"Unknown strategy: {self.strategy}")
    strategy = self.STRATEGIES[self.strategy](self.slices)

    rewrites: list[RewritePattern] = [
        AddHaloExchangeOps(strategy),
    ]

    if self.restrict_domain:
        rewrites.append(ChangeStoreOpSizes(strategy))

    PatternRewriteWalker(
        GreedyRewritePatternApplier(rewrites),
        apply_recursively=False,
    ).rewrite_module(op)
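
As a usage sketch, the pass can be invoked from the command line (assuming xdsl-opt's usual pass-pipeline syntax; input.mlir is a placeholder file name):

xdsl-opt -p 'distribute-stencil{strategy=2d-grid slices=2,2}' input.mlir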

DmpToMpiPass dataclass

Bases: ModulePass

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
@dataclass(frozen=True)
class DmpToMpiPass(ModulePass):
    name = "dmp-to-mpi"

    mpi_init: bool = True

    generate_debug_prints: bool = False

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        PatternRewriteWalker(
            GreedyRewritePatternApplier(
                [
                    LowerHaloExchangeToMpi(
                        self.mpi_init,
                        self.generate_debug_prints,
                    ),
                ]
            )
        ).rewrite_module(op)
        MpiLoopInvariantCodeMotion().rewrite_module(op)

name = 'dmp-to-mpi' class-attribute instance-attribute

mpi_init: bool = True class-attribute instance-attribute

generate_debug_prints: bool = False class-attribute instance-attribute

__init__(mpi_init: bool = True, generate_debug_prints: bool = False) -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    PatternRewriteWalker(
        GreedyRewritePatternApplier(
            [
                LowerHaloExchangeToMpi(
                    self.mpi_init,
                    self.generate_debug_prints,
                ),
            ]
        )
    ).rewrite_module(op)
    MpiLoopInvariantCodeMotion().rewrite_module(op)
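
A minimal sketch of composing both passes programmatically (assuming ctx is an existing Context and module a builtin.ModuleOp with the relevant dialects registered):

from xdsl.transforms.experimental.dmp.stencil_global_to_local import (
    DistributeStencilPass,
    DmpToMpiPass,
)

# Distribute across a 2x2 rank grid, then lower the resulting
# dmp.swap ops to MPI calls and hoist the invariant setup code.
DistributeStencilPass(slices=(2, 2), strategy="2d-grid").apply(ctx, module)
DmpToMpiPass(mpi_init=True, generate_debug_prints=False).apply(ctx, module)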

generate_mpi_calls_for(source: SSAValue, exchanges: list[dmp.ExchangeDeclarationAttr], dtype: Attribute, grid: dmp.RankTopoAttr, emit_init: bool = True, emit_debug: bool = False) -> Iterable[Operation]

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
def generate_mpi_calls_for(
    source: SSAValue,
    exchanges: list[dmp.ExchangeDeclarationAttr],
    dtype: Attribute,
    grid: dmp.RankTopoAttr,
    emit_init: bool = True,
    emit_debug: bool = False,
) -> Iterable[Operation]:
    # call mpi init (this will be hoisted to function level)
    if emit_init:
        yield mpi.InitOp()
    # allocate request array
    # we need two request objects per exchange
    # one for the send, one for the recv
    req_cnt = arith.ConstantOp.from_int_and_width(len(exchanges) * 2, builtin.i32)
    reqs = mpi.AllocateTypeOp(mpi.RequestType, req_cnt)
    # get comm rank
    rank = mpi.CommRankOp()
    # define static tag of 0
    tag = arith.ConstantOp.from_int_and_width(0, builtin.i32)

    yield from (req_cnt, reqs, rank, tag)

    recv_buffers: list[
        tuple[dmp.ExchangeDeclarationAttr, memref.AllocOp, SSAValue]
    ] = []

    for i, ex in enumerate(exchanges):
        # generate a temp buffer to store the data in
        reduced_size = [i for i in ex.size if i != 1]
        alloc_outbound = memref.AllocOp.get(dtype, 64, reduced_size)
        alloc_outbound.memref.name_hint = f"send_buff_ex{i}"
        alloc_inbound = memref.AllocOp.get(dtype, 64, reduced_size)
        alloc_inbound.memref.name_hint = f"recv_buff_ex{i}"
        yield from (alloc_outbound, alloc_inbound)

        # calc dest rank and check if it's in-bounds
        ops, dest_rank, is_in_bounds = _generate_dest_rank_computation(
            rank.rank, ex.neighbor, grid
        )
        yield from ops

        recv_buffers.append((ex, alloc_inbound, is_in_bounds))

        # get two unique indices
        cst_i = arith.ConstantOp.from_int_and_width(i, builtin.i32)
        cst_in = arith.ConstantOp.from_int_and_width(i + len(exchanges), builtin.i32)
        yield from (cst_i, cst_in)
        # from these indices, get request objects
        req_send = mpi.VectorGetOp(reqs, cst_i)
        req_recv = mpi.VectorGetOp(reqs, cst_in)
        yield from (req_send, req_recv)

        def then() -> Iterable[Operation]:
            # copy source area to outbound buffer
            yield from generate_memcpy(source, ex.source_area(), alloc_outbound.memref)
            # get ptr, count, dtype
            unwrap_out = mpi.UnwrapMemRefOp(alloc_outbound)
            unwrap_out.ptr.name_hint = f"send_buff_ex{i}_ptr"
            yield unwrap_out

            if emit_debug:
                yield printf.PrintFormatOp(
                    f"Rank {{}}: sending {ex.source_area()} -> {{}}\n", rank, dest_rank
                )

            # isend call
            yield mpi.IsendOp(
                unwrap_out.ptr,
                unwrap_out.len,
                unwrap_out.type,
                dest_rank,
                tag,
                req_send,
            )

            # get ptr for receive buffer
            unwrap_in = mpi.UnwrapMemRefOp(alloc_inbound)
            unwrap_in.ptr.name_hint = f"recv_buff_ex{i}_ptr"
            yield unwrap_in
            # Irecv call
            yield mpi.IrecvOp(
                unwrap_in.ptr,
                unwrap_in.len,
                unwrap_in.type,
                dest_rank,
                tag,
                req_recv,
            )
            yield scf.YieldOp()

        def else_() -> Iterable[Operation]:
            # set the request object to MPI_REQUEST_NULL s.t. they are ignored
            # in the waitall call
            yield mpi.NullRequestOp(req_send)
            yield mpi.NullRequestOp(req_recv)
            yield scf.YieldOp()

        yield scf.IfOp(
            is_in_bounds,
            [],
            Region([Block(then())]),
            Region([Block(else_())]),
        )

    # wait for all calls to complete
    yield mpi.WaitallOp(reqs.result, req_cnt.result)

    # start shuffling data into the main memref again
    for ex, buffer, cond_val in recv_buffers:
        yield scf.IfOp(
            cond_val,
            [],
            Region(
                [
                    Block(
                        list(
                            generate_memcpy(
                                source,
                                ex,
                                buffer.memref,
                                receive=True,
                            )
                        )
                        + [
                            printf.PrintFormatOp(
                                f"Rank {{}} receiving from {ex.neighbor}\n",
                                rank,
                            )
                        ]
                        * (1 if emit_debug else 0)
                        + [scf.YieldOp()]
                    )
                ]
            ),
            Region([Block([scf.YieldOp()])]),
        )
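
In outline, the IR emitted by this function has the following shape (a schematic summary of the code above, not verbatim output):

# mpi.init                                   (only if emit_init)
# allocate 2 * len(exchanges) request objects; get comm rank; tag = 0
# for each exchange:
#     allocate send and recv buffers
#     compute the neighbor's rank and an in-bounds flag
#     scf.if in-bounds:
#         copy the source area into the send buffer; mpi.isend; mpi.irecv
#     else:
#         set both request objects to MPI_REQUEST_NULL
# mpi.waitall over all requests
# for each exchange, scf.if in-bounds:
#     copy the recv buffer back into the field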

generate_memcpy(field: SSAValue, ex: dmp.ExchangeDeclarationAttr, buffer: SSAValue, receive: bool = False) -> list[Operation]

This function generates a memcpy routine that copies the parts of field specified by ex from field into buffer.

If receive=True, it instead copies from buffer into the parts of field specified by ex.

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
def generate_memcpy(
    field: SSAValue,
    ex: dmp.ExchangeDeclarationAttr,
    buffer: SSAValue,
    receive: bool = False,
) -> list[Operation]:
    """
    This function generates a memcpy routine to copy over the parts
    specified by the `field` from `field` into `buffer`.

    If receive=True, it instead copy from `buffer` into the parts of
    `field` as specified by `ex`

    """
    field_type = cast(stencil.FieldType[Attribute], field.type)
    assert isinstance(field_type.bounds, stencil.StencilBoundsAttr)
    memref_type = StencilToMemRefType(field_type)

    uc = builtin.UnrealizedConversionCastOp.get([field], result_type=[memref_type])

    memref_val = uc.results[0]

    offset = stencil.IndexAttr.get(*ex.offset) - field_type.bounds.lb

    subview = memref.SubviewOp.from_static_parameters(
        memref_val,
        memref_type,
        tuple(offset),
        ex.size,
        [1] * len(ex.offset),
        reduce_rank=True,
    )
    if receive:
        copy = memref.CopyOp(buffer, subview)
    else:
        copy = memref.CopyOp(subview, buffer)

    return [
        uc,
        subview,
        copy,
    ]
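
For intuition, the three returned ops amount to the following (illustrative pseudo-IR; types and static offsets elided):

# %m = builtin.unrealized_conversion_cast %field : !stencil.field<...> to memref<...>
# %s = memref.subview %m[offset] [ex.size] [1, ...]    (rank-reduced)
# memref.copy %s, %buffer     (or memref.copy %buffer, %s when receive=True)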

can_loop_invariant_code_move(op: Operation)

This function walks the def-use chain up to see if all the args are "constant enough" to move outside the loop.

This check is very conservative, but that means it definitely works!

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
def can_loop_invariant_code_move(op: Operation):
    """
    This function walks the def-use chain up to see if all the args are
    "constant enough" to move outside the loop.

    This check is very conservative, but that means it definitely works!
    """

    for arg in op.operands:
        if not isinstance(arg, OpResult):
            print(f"{arg} is not opresult")
            return False
        if not isinstance(arg.owner, _LOOP_INVARIANT_OPS):
            print(f"{arg} is not loop invariant")
            return False
        if not can_loop_invariant_code_move(arg.owner):
            return False
    return True
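
For example, an op consuming an scf.for induction variable can never be moved: a block argument is not an OpResult, so the first isinstance check fails immediately. This is the sense in which the check is conservative.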

collect_args_recursive(op: Operation) -> Iterable[Operation]

Collect the def-use chain "upwards" of an operation. Check with can_loop_invariant_code_move prior to using this!

Source code in xdsl/transforms/experimental/dmp/stencil_global_to_local.py
def collect_args_recursive(op: Operation) -> Iterable[Operation]:
    """
    Collect the def-use chain "upwards" of an operation.
    Check with can_loop_invariant_code_move prior to using this!
    """
    for arg in op.operands:
        assert isinstance(arg, OpResult)
        yield from collect_args_recursive(arg.owner)
    yield op
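
Note that operands are recursed into before the op itself is yielded, so the chain comes out in def-before-use order; inserting the ops in that order (as rewrite() above does) preserves dominance.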