
Lower mpi

lower_mpi

MpiLibraryInfo dataclass

This object is meant to capture characteristics of a specific MPI implementation.

It holds magic values, sizes of structs, field offsets and much more.

We need these because we currently cannot load the library headers into the programs we want to lower. Instead, we generate our own external stubs and load magic values directly.

This approach is inherently fragile, but we don't know of a better one. We plan to include a C file that automagically extracts all of this information from the MPI headers.

You can see the current C file in this PR: https://github.com/xdslproject/xdsl/pull/526

You can see the status of OpenMPI support here: https://github.com/xdslproject/xdsl/issues/523

These defaults were extracted from MPICH 3.3a2. We strongly suggest running the mpi-info.c file yourself against your version of the library!
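The magic numbers are easier to audit once you know how MPICH builds them. As far as we understand MPICH's handle encoding, bits 30-31 of a handle give the handle kind (1 for builtin objects), bits 26-29 give the object kind, and for builtin datatypes bits 8-15 hold the element size in bytes. The sketch below decodes a few of the defaults under that assumption. The helper names and the `OBJECT_KINDS` table are ours, not part of xDSL or MPICH, so verify the layout against your own mpi.h before relying on it.

```python
# Sketch: decode MPICH-style integer handles. The bit layout is our reading of
# MPICH and is an assumption; other MPI libraries (e.g. OpenMPI) use pointers.

# Object-kind codes (bits 26-29) we expect for the handles in MpiLibraryInfo.
OBJECT_KINDS = {0x1: "comm", 0x3: "datatype", 0x6: "op", 0xB: "request"}


def handle_kind(handle: int) -> int:
    """Bits 30-31: 1 marks a builtin (predefined) handle."""
    return (handle >> 30) & 0x3


def object_kind(handle: int) -> str:
    """Bits 26-29: which kind of MPI object the handle refers to."""
    return OBJECT_KINDS.get((handle >> 26) & 0xF, "unknown")


def builtin_datatype_size(handle: int) -> int:
    """Bits 8-15 of a builtin datatype handle: element size in bytes."""
    return (handle >> 8) & 0xFF


MPI_INT = 0x4C000405
MPI_DOUBLE = 0x4C00080B
MPI_SUM = 0x58000003
MPI_COMM_WORLD = 0x44000000

print(object_kind(MPI_INT), object_kind(MPI_SUM), object_kind(MPI_COMM_WORLD))
print(builtin_datatype_size(MPI_INT), builtin_datatype_size(MPI_DOUBLE))  # 4 8
```

The same checks can be run on values produced by mpi-info.c for your own MPICH build.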

Source code in xdsl/transforms/lower_mpi.py
@dataclass(frozen=True)
class MpiLibraryInfo:
    """
    This object is meant to capture characteristics of a specific MPI implementation.

    It holds magic values, sizes of structs, field offsets and much more.

    We need these as we currently cannot load these library headers into the programs we want to lower,
    therefore we need to generate our own external stubs and load magic values directly.

    This way of doing it is inherently fragile, but we don't know of any better way.
    We plan to include a C file that automagically extracts all this information from MPI headers.
    You can see the current C file used in this PR: https://github.com/xdslproject/xdsl/pull/526
    You can see the status of OpenMPI support here: https://github.com/xdslproject/xdsl/issues/523

    These defaults have been extracted from MPICH 3.3a2. We would highly suggest
    running the mpi-info.c file yourself with your version of the library!
    """

    # MPI_Datatype
    MPI_Datatype_size: int = 4
    MPI_CHAR: int = 0x4C000101
    MPI_SIGNED_CHAR: int = 0x4C000118
    MPI_UNSIGNED_CHAR: int = 0x4C000102
    MPI_BYTE: int = 0x4C00010D
    MPI_WCHAR: int = 0x4C00040E
    MPI_SHORT: int = 0x4C000203
    MPI_UNSIGNED_SHORT: int = 0x4C000204
    MPI_INT: int = 0x4C000405
    MPI_UNSIGNED: int = 0x4C000406
    MPI_LONG: int = 0x4C000807
    MPI_UNSIGNED_LONG: int = 0x4C000808
    MPI_FLOAT: int = 0x4C00040A
    MPI_DOUBLE: int = 0x4C00080B
    MPI_LONG_DOUBLE: int = 0x4C00100C
    MPI_LONG_LONG_INT: int = 0x4C000809
    MPI_UNSIGNED_LONG_LONG: int = 0x4C000819
    MPI_LONG_LONG: int = 0x4C000809

    # MPI_Op
    MPI_Op_size: int = 4
    MPI_MAX: int = 0x58000001
    MPI_MIN: int = 0x58000002
    MPI_SUM: int = 0x58000003
    MPI_PROD: int = 0x58000004
    MPI_LAND: int = 0x58000005
    MPI_BAND: int = 0x58000006
    MPI_LOR: int = 0x58000007
    MPI_BOR: int = 0x58000008
    MPI_LXOR: int = 0x58000009
    MPI_BXOR: int = 0x5800000A
    MPI_MINLOC: int = 0x5800000B
    MPI_MAXLOC: int = 0x5800000C
    MPI_REPLACE: int = 0x5800000D
    MPI_NO_OP: int = 0x5800000E

    # MPI_Comm
    MPI_Comm_size: int = 4
    MPI_COMM_WORLD: int = 0x44000000
    MPI_COMM_SELF: int = 0x44000001

    # MPI_Request
    MPI_Request_size: int = 4
    MPI_REQUEST_NULL = 0x2C000000

    # MPI_Status
    MPI_Status_size: int = 20
    MPI_STATUS_IGNORE: int = 0x00000001
    MPI_STATUSES_IGNORE: int = 0x00000001
    MPI_Status_field_MPI_SOURCE: int = (
        8  # offset of field MPI_SOURCE in struct MPI_Status
    )
    MPI_Status_field_MPI_TAG: int = 12  # offset of field MPI_TAG in struct MPI_Status
    MPI_Status_field_MPI_ERROR: int = (
        16  # offset of field MPI_ERROR in struct MPI_Status
    )

    # In place MPI All reduce
    MPI_IN_PLACE: int = -1

MPI_Datatype_size: int = 4 class-attribute instance-attribute

MPI_CHAR: int = 1275068673 class-attribute instance-attribute

MPI_SIGNED_CHAR: int = 1275068696 class-attribute instance-attribute

MPI_UNSIGNED_CHAR: int = 1275068674 class-attribute instance-attribute

MPI_BYTE: int = 1275068685 class-attribute instance-attribute

MPI_WCHAR: int = 1275069454 class-attribute instance-attribute

MPI_SHORT: int = 1275068931 class-attribute instance-attribute

MPI_UNSIGNED_SHORT: int = 1275068932 class-attribute instance-attribute

MPI_INT: int = 1275069445 class-attribute instance-attribute

MPI_UNSIGNED: int = 1275069446 class-attribute instance-attribute

MPI_LONG: int = 1275070471 class-attribute instance-attribute

MPI_UNSIGNED_LONG: int = 1275070472 class-attribute instance-attribute

MPI_FLOAT: int = 1275069450 class-attribute instance-attribute

MPI_DOUBLE: int = 1275070475 class-attribute instance-attribute

MPI_LONG_DOUBLE: int = 1275072524 class-attribute instance-attribute

MPI_LONG_LONG_INT: int = 1275070473 class-attribute instance-attribute

MPI_UNSIGNED_LONG_LONG: int = 1275070489 class-attribute instance-attribute

MPI_LONG_LONG: int = 1275070473 class-attribute instance-attribute

MPI_Op_size: int = 4 class-attribute instance-attribute

MPI_MAX: int = 1476395009 class-attribute instance-attribute

MPI_MIN: int = 1476395010 class-attribute instance-attribute

MPI_SUM: int = 1476395011 class-attribute instance-attribute

MPI_PROD: int = 1476395012 class-attribute instance-attribute

MPI_LAND: int = 1476395013 class-attribute instance-attribute

MPI_BAND: int = 1476395014 class-attribute instance-attribute

MPI_LOR: int = 1476395015 class-attribute instance-attribute

MPI_BOR: int = 1476395016 class-attribute instance-attribute

MPI_LXOR: int = 1476395017 class-attribute instance-attribute

MPI_BXOR: int = 1476395018 class-attribute instance-attribute

MPI_MINLOC: int = 1476395019 class-attribute instance-attribute

MPI_MAXLOC: int = 1476395020 class-attribute instance-attribute

MPI_REPLACE: int = 1476395021 class-attribute instance-attribute

MPI_NO_OP: int = 1476395022 class-attribute instance-attribute

MPI_Comm_size: int = 4 class-attribute instance-attribute

MPI_COMM_WORLD: int = 1140850688 class-attribute instance-attribute

MPI_COMM_SELF: int = 1140850689 class-attribute instance-attribute

MPI_Request_size: int = 4 class-attribute instance-attribute

MPI_REQUEST_NULL = 738197504 class-attribute instance-attribute

MPI_Status_size: int = 20 class-attribute instance-attribute

MPI_STATUS_IGNORE: int = 1 class-attribute instance-attribute

MPI_STATUSES_IGNORE: int = 1 class-attribute instance-attribute

MPI_Status_field_MPI_SOURCE: int = 8 class-attribute instance-attribute

MPI_Status_field_MPI_TAG: int = 12 class-attribute instance-attribute

MPI_Status_field_MPI_ERROR: int = 16 class-attribute instance-attribute

MPI_IN_PLACE: int = -1 class-attribute instance-attribute
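The three `MPI_Status_field_*` offsets and the 20-byte `MPI_Status_size` are consistent with a status struct made of five 32-bit `int` fields, the last three being `MPI_SOURCE`, `MPI_TAG` and `MPI_ERROR`. A ctypes mirror is a quick way to double-check numbers like these. The field layout (and the names of the first two fields) is our assumption about MPICH, not something read from this module.

```python
import ctypes


class MpichStatus(ctypes.Structure):
    # Assumed MPICH layout: five C ints. The first two names are illustrative;
    # only their size and position matter for the offsets below.
    _fields_ = [
        ("count_lo", ctypes.c_int),
        ("count_hi_and_cancelled", ctypes.c_int),
        ("MPI_SOURCE", ctypes.c_int),
        ("MPI_TAG", ctypes.c_int),
        ("MPI_ERROR", ctypes.c_int),
    ]


# Compare against MPI_Status_size and the MPI_Status_field_* defaults.
print(ctypes.sizeof(MpichStatus))  # 20 on platforms where int is 4 bytes
print(MpichStatus.MPI_SOURCE.offset, MpichStatus.MPI_TAG.offset, MpichStatus.MPI_ERROR.offset)
```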

__init__(MPI_Datatype_size: int = 4, MPI_CHAR: int = 1275068673, MPI_SIGNED_CHAR: int = 1275068696, MPI_UNSIGNED_CHAR: int = 1275068674, MPI_BYTE: int = 1275068685, MPI_WCHAR: int = 1275069454, MPI_SHORT: int = 1275068931, MPI_UNSIGNED_SHORT: int = 1275068932, MPI_INT: int = 1275069445, MPI_UNSIGNED: int = 1275069446, MPI_LONG: int = 1275070471, MPI_UNSIGNED_LONG: int = 1275070472, MPI_FLOAT: int = 1275069450, MPI_DOUBLE: int = 1275070475, MPI_LONG_DOUBLE: int = 1275072524, MPI_LONG_LONG_INT: int = 1275070473, MPI_UNSIGNED_LONG_LONG: int = 1275070489, MPI_LONG_LONG: int = 1275070473, MPI_Op_size: int = 4, MPI_MAX: int = 1476395009, MPI_MIN: int = 1476395010, MPI_SUM: int = 1476395011, MPI_PROD: int = 1476395012, MPI_LAND: int = 1476395013, MPI_BAND: int = 1476395014, MPI_LOR: int = 1476395015, MPI_BOR: int = 1476395016, MPI_LXOR: int = 1476395017, MPI_BXOR: int = 1476395018, MPI_MINLOC: int = 1476395019, MPI_MAXLOC: int = 1476395020, MPI_REPLACE: int = 1476395021, MPI_NO_OP: int = 1476395022, MPI_Comm_size: int = 4, MPI_COMM_WORLD: int = 1140850688, MPI_COMM_SELF: int = 1140850689, MPI_Request_size: int = 4, MPI_Status_size: int = 20, MPI_STATUS_IGNORE: int = 1, MPI_STATUSES_IGNORE: int = 1, MPI_Status_field_MPI_SOURCE: int = 8, MPI_Status_field_MPI_TAG: int = 12, MPI_Status_field_MPI_ERROR: int = 16, MPI_IN_PLACE: int = -1) -> None
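Since `MpiLibraryInfo` is a frozen dataclass, values for a different library are supplied as keyword arguments to its constructor, or via `dataclasses.replace` on an existing instance. The attribute list above shows the defaults in decimal while the source uses hex; both denote the same numbers (for example 1275069445 == 0x4C000405). Note also that `MPI_REQUEST_NULL` has no type annotation in the source, so it is a plain class attribute rather than a dataclass field, which is why it does not appear in the `__init__` signature. The sketch uses a hypothetical `MpiInfoSubset` stand-in so it runs without xDSL.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class MpiInfoSubset:
    """Hypothetical stand-in for MpiLibraryInfo with only a few of its fields."""

    MPI_Status_size: int = 20
    MPI_COMM_WORLD: int = 0x44000000
    MPI_INT: int = 0x4C000405
    MPI_REQUEST_NULL = 0x2C000000  # no annotation: class attribute, not a field


defaults = MpiInfoSubset()
# Made-up values standing in for numbers extracted from another MPI build.
custom = dataclasses.replace(defaults, MPI_Status_size=24, MPI_COMM_WORLD=0x44000002)
same_build = MpiInfoSubset(MPI_INT=1275069445)  # decimal and hex name the same value

print(custom.MPI_Status_size, hex(custom.MPI_COMM_WORLD))  # 24 0x44000002
print(same_build == defaults)  # True
print([f.name for f in dataclasses.fields(MpiInfoSubset)])  # MPI_REQUEST_NULL absent

try:
    custom.MPI_INT = 0  # frozen dataclass: no in-place mutation
except dataclasses.FrozenInstanceError:
    print("frozen")
```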

LowerMpiInit dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiInit(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.InitOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(self, op: mpi.InitOp) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        We currently don't model any argument passing to `MPI_Init()` and pass two nullptrs.
        """
        return [
            nullptr := llvm.ZeroOp(result_types=[llvm.LLVMPointerType()]),
            func.CallOp(self._mpi_name(op), [nullptr, nullptr], [i32]),
        ], []
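Every pattern on this page has the same shape: `lower()` returns a pair of (new operations, replacement values for the original op's results), and `match_and_rewrite` unpacks that pair into `rewriter.replace_op`. The toy below mimics this contract in plain Python so it can run without xDSL. `ToyOp`, `toy_lower_init` and `toy_replace_op` are hypothetical, and the result-count check reflects our understanding of `replace_op` rather than its documented behavior.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyOp:
    """Hypothetical stand-in for an IR operation (not an xDSL class)."""

    name: str
    num_results: int = 0
    operands: list[str] = field(default_factory=list)


def toy_lower_init(op: ToyOp) -> tuple[list[ToyOp], list[str | None]]:
    # Same shape as LowerMpiInit.lower: a null pointer, then a call using it twice.
    nullptr = ToyOp("llvm.mlir.zero", num_results=1)
    call = ToyOp("func.call @MPI_Init", num_results=1, operands=["%nullptr", "%nullptr"])
    # mpi.init produces no values, so there is nothing to map: empty list.
    return [nullptr, call], []


def toy_replace_op(op: ToyOp, new_ops: list[ToyOp], new_results: list[str | None]) -> list[ToyOp]:
    # Each result of the old op needs exactly one replacement value (or None).
    if len(new_results) != op.num_results:
        raise ValueError(f"{op.name}: expected {op.num_results} results, got {len(new_results)}")
    return new_ops


# Analogous to: rewriter.replace_op(op, *self.lower(op))
init = ToyOp("mpi.init")
print([o.name for o in toy_replace_op(init, *toy_lower_init(init))])
```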

match_and_rewrite(op: mpi.InitOp, rewriter: PatternRewriter)

lower(op: mpi.InitOp) -> tuple[list[Operation], list[SSAValue | None]]

We currently don't model any argument passing to MPI_Init() and pass two null pointers instead.
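The MPI standard allows C programs to call `MPI_Init(NULL, NULL)`, which is what the two null pointers correspond to. The sketch declares the C prototype `int MPI_Init(int *argc, char ***argv)` with ctypes and calls a Python stand-in (`fake_mpi_init`, hypothetical) instead of a real MPI library, only to show that both arguments arrive as NULL.

```python
import ctypes

# int MPI_Init(int *argc, char ***argv)
MPI_Init_t = ctypes.CFUNCTYPE(
    ctypes.c_int,
    ctypes.POINTER(ctypes.c_int),                     # int *argc
    ctypes.POINTER(ctypes.POINTER(ctypes.c_char_p)),  # char ***argv
)

seen = []


def fake_mpi_init(argc, argv):
    # NULL ctypes pointers are falsy, so this records (False, False) for NULL, NULL.
    seen.append((bool(argc), bool(argv)))
    return 0  # MPI_SUCCESS


mpi_init = MPI_Init_t(fake_mpi_init)
rc = mpi_init(None, None)  # None is converted to a NULL pointer
print(rc, seen)  # 0 [(False, False)]
```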

LowerMpiFinalize dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiFinalize(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.FinalizeOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(
        self, op: mpi.FinalizeOp
    ) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        Relatively straightforward lowering of the mpi.finalize operation.
        """
        return [
            func.CallOp(self._mpi_name(op), [], [i32]),
        ], []

match_and_rewrite(op: mpi.FinalizeOp, rewriter: PatternRewriter)

lower(op: mpi.FinalizeOp) -> tuple[list[Operation], list[SSAValue | None]]

Relatively straightforward lowering of the mpi.finalize operation.

LowerMpiWait dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiWait(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.WaitOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(self, op: mpi.WaitOp) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        Relatively straightforward lowering of the mpi.wait operation.
        """
        ops, new_results, res = self._emit_mpi_status_objs(len(op.results))
        return [
            *ops,
            func.CallOp(self._mpi_name(op), [op.request, res], [i32]),
        ], new_results

match_and_rewrite(op: mpi.WaitOp, rewriter: PatternRewriter)

lower(op: mpi.WaitOp) -> tuple[list[Operation], list[SSAValue | None]]

Relatively straightforward lowering of the mpi.wait operation.
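`_emit_mpi_status_objs` is not shown on this page. Judging from its call site (it receives `len(op.results)`) and from `MPI_STATUS_IGNORE` being part of `MpiLibraryInfo`, we assume it passes the `MPI_STATUS_IGNORE` sentinel when no status is requested and allocates an `MPI_Status`-sized buffer otherwise. The following ctypes model of that decision is a hypothetical sketch, not the helper's actual code.

```python
import ctypes

MPI_STATUS_IGNORE = 0x00000001  # MpiLibraryInfo default
MPI_STATUS_SIZE = 20            # MpiLibraryInfo.MPI_Status_size


def status_argument(num_results: int):
    """Hypothetical model of the status pointer handed to MPI_Wait.

    No status requested: pass the MPI_STATUS_IGNORE sentinel as a pointer value.
    Otherwise: reserve MPI_Status_size bytes for the library to fill in.
    """
    if num_results == 0:
        return ctypes.c_void_p(MPI_STATUS_IGNORE)
    return ctypes.create_string_buffer(MPI_STATUS_SIZE)


print(status_argument(0).value, ctypes.sizeof(status_argument(1)))  # 1 20
```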

LowerMpiWaitall dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiWaitall(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.WaitallOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(self, op: mpi.WaitallOp) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        Relatively straightforward lowering of the mpi.waitall operation.
        """

        ops, new_results, res = self._emit_mpi_status_objs(len(op.results))
        return [
            *ops,
            func.CallOp(self._mpi_name(op), [op.count, op.requests, res], [i32]),
        ], new_results

match_and_rewrite(op: mpi.WaitallOp, rewriter: PatternRewriter)

lower(op: mpi.WaitallOp) -> tuple[list[Operation], list[SSAValue | None]]

Relatively straightforward lowering of the mpi.waitall operation.
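`MPI_Waitall(int count, MPI_Request array_of_requests[], MPI_Status array_of_statuses[])` takes arrays, so any status buffer passed to it (other than `MPI_STATUSES_IGNORE`) must hold one `MPI_Status` per request. A sketch of the buffer sizes, assuming the MPICH defaults of 4-byte requests and 20-byte statuses; `waitall_buffers` is a hypothetical helper.

```python
import ctypes

MPI_REQUEST_NULL = 0x2C000000  # MpiLibraryInfo default
MPI_STATUS_SIZE = 20           # MpiLibraryInfo.MPI_Status_size


def waitall_buffers(count: int):
    """Allocate a request array and a matching status array for MPI_Waitall."""
    # MPICH requests are 32-bit integer handles (MPI_Request_size = 4).
    requests = (ctypes.c_int32 * count)(*([MPI_REQUEST_NULL] * count))
    # One MPI_Status per request.
    statuses = ctypes.create_string_buffer(count * MPI_STATUS_SIZE)
    return requests, statuses


requests, statuses = waitall_buffers(3)
print(ctypes.sizeof(requests), ctypes.sizeof(statuses))  # 12 60
```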

LowerMpiReduce dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiReduce(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.ReduceOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(self, op: mpi.ReduceOp) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        Lowers the MPI Reduce operation
        """

        return [
            comm_global := arith.ConstantOp.from_int_and_width(
                self.info.MPI_COMM_WORLD, i32
            ),
            mpi_op := self._emit_mpi_operation_load(op.operationtype),
            func.CallOp(
                self._mpi_name(op),
                [
                    op.send_buffer,
                    op.recv_buffer,
                    op.count,
                    op.datatype,
                    mpi_op,
                    op.root,
                    comm_global,
                ],
                [],
            ),
        ], []

match_and_rewrite(op: mpi.ReduceOp, rewriter: PatternRewriter)

lower(op: mpi.ReduceOp) -> tuple[list[Operation], list[SSAValue | None]]

Lowers the MPI Reduce operation
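The call passes its arguments in `MPI_Reduce` order: send buffer, receive buffer, count, datatype, reduction op, root and communicator (always `MPI_COMM_WORLD` here). The single-process simulation below shows what the reduction computes, element by element across ranks, with only the root receiving the result. The mapping from op handles to Python functions and the `simulate_reduce` helper are illustrative.

```python
import operator
from functools import reduce

# A few MpiLibraryInfo MPI_Op defaults mapped to Python equivalents (illustrative).
MPI_OPS = {
    0x58000001: max,           # MPI_MAX
    0x58000002: min,           # MPI_MIN
    0x58000003: operator.add,  # MPI_SUM
    0x58000004: operator.mul,  # MPI_PROD
}


def simulate_reduce(send_buffers, op_handle, root):
    """Return each rank's recv buffer after MPI_Reduce; non-root ranks get None."""
    fn = MPI_OPS[op_handle]
    # zip(*...) groups element i of every rank's send buffer together.
    combined = [reduce(fn, column) for column in zip(*send_buffers)]
    return [combined if rank == root else None for rank in range(len(send_buffers))]


print(simulate_reduce([[1, 2], [3, 4], [5, 6]], 0x58000003, root=0))  # [[9, 12], None, None]
```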

LowerMpiAllreduce dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiAllreduce(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.AllreduceOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(
        self, op: mpi.AllreduceOp
    ) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        Lowers the MPI Allreduce operation
        """

        # Send buffer is optional (if not provided then call using MPI_IN_PLACE)
        has_send_buffer = op.send_buffer is not None

        comm_global = arith.ConstantOp.from_int_and_width(self.info.MPI_COMM_WORLD, i32)
        mpi_op = self._emit_mpi_operation_load(op.operationtype)

        operations = [comm_global, mpi_op]

        send_buffer_op: SSAValue | Operation
        if has_send_buffer:
            assert op.send_buffer is not None
            send_buffer_op = op.send_buffer
        else:
            send_buffer_op = arith.ConstantOp.from_int_and_width(
                self.info.MPI_IN_PLACE, i64
            )
            operations.append(send_buffer_op)

        return [
            *operations,
            func.CallOp(
                self._mpi_name(op),
                [
                    send_buffer_op,
                    op.recv_buffer,
                    op.count,
                    op.datatype,
                    mpi_op,
                    comm_global,
                ],
                [],
            ),
        ], []

match_and_rewrite(op: mpi.AllreduceOp, rewriter: PatternRewriter)

lower(op: mpi.AllreduceOp) -> tuple[list[Operation], list[SSAValue | None]]

Lowers the MPI Allreduce operation
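When `op.send_buffer` is absent, the lowering passes `MPI_IN_PLACE` (an i64 constant -1, matching MPICH's definition of `MPI_IN_PLACE` as `(void *) -1`), which tells MPI to take each rank's input from its receive buffer. The hypothetical `simulate_allreduce` below models that behavior in a single process; unlike `MPI_Reduce`, every rank ends up with the combined result.

```python
import operator
from functools import reduce

MPI_IN_PLACE = -1  # MpiLibraryInfo default sentinel


def simulate_allreduce(send_buffers, recv_buffers, fn=operator.add):
    """Single-process model of MPI_Allreduce over all ranks.

    A send buffer equal to MPI_IN_PLACE means "use this rank's recv buffer as
    input", mirroring the lowering's choice when op.send_buffer is absent.
    """
    inputs = [
        recv if send == MPI_IN_PLACE else send
        for send, recv in zip(send_buffers, recv_buffers)
    ]
    # Compute the full result before overwriting any recv buffer.
    combined = [reduce(fn, column) for column in zip(*inputs)]
    for recv in recv_buffers:
        recv[:] = combined  # every rank receives the result
    return recv_buffers


print(simulate_allreduce([MPI_IN_PLACE, MPI_IN_PLACE], [[1, 2], [3, 4]]))  # [[4, 6], [4, 6]]
```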

LowerMpiBcast dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiBcast(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.BcastOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(self, op: mpi.BcastOp) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        Lowers the MPI Bcast operation
        """

        return [
            comm_global := arith.ConstantOp.from_int_and_width(
                self.info.MPI_COMM_WORLD, i32
            ),
            func.CallOp(
                self._mpi_name(op),
                [op.buffer, op.count, op.datatype, op.root, comm_global],
                [],
            ),
        ], []

match_and_rewrite(op: mpi.BcastOp, rewriter: PatternRewriter)

lower(op: mpi.BcastOp) -> tuple[list[Operation], list[SSAValue | None]]

Lowers the MPI Bcast operation
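`MPI_Bcast` copies `count` elements from the root's buffer into the same buffer on every other rank. A single-process model, with `simulate_bcast` as a hypothetical helper:

```python
def simulate_bcast(buffers, count, root):
    """Single-process model of MPI_Bcast: copy root's first `count` elements everywhere."""
    data = buffers[root][:count]
    for buf in buffers:
        buf[:count] = data  # elements past `count` are left untouched
    return buffers


print(simulate_bcast([[0, 0, 9], [7, 8, 5], [1, 1, 1]], 2, root=1))
# [[7, 8, 9], [7, 8, 5], [7, 8, 1]]
```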

LowerMpiIsend dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiIsend(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.IsendOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(self, op: mpi.IsendOp) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        This method lowers mpi.isend

        int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest,
              int tag, MPI_Comm comm, MPI_Request *request)
        """

        return [
            comm_global := arith.ConstantOp.from_int_and_width(
                self.info.MPI_COMM_WORLD, i32
            ),
            func.CallOp(
                self._mpi_name(op),
                [
                    op.buffer,
                    op.count,
                    op.datatype,
                    op.dest,
                    op.tag,
                    comm_global,
                    op.request,
                ],
                [i32],
            ),
        ], []

match_and_rewrite(op: mpi.IsendOp, rewriter: PatternRewriter)

lower(op: mpi.IsendOp) -> tuple[list[Operation], list[SSAValue | None]]

This method lowers mpi.isend

int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request)
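Under MPICH's ABI, `MPI_Datatype`, `MPI_Comm` and `MPI_Request` are 32-bit integers (consistent with the `*_size` values of 4 in `MpiLibraryInfo` and with the lowering passing `MPI_COMM_WORLD` as an i32 constant). A ctypes prototype for `MPI_Isend` under that assumption looks like this; it is only a declaration and is not bound to any library.

```python
import ctypes

# Argument types for
#   int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest,
#                 int tag, MPI_Comm comm, MPI_Request *request)
# assuming MPICH-style 32-bit integer handles.
ISEND_ARGTYPES = (
    ctypes.c_void_p,                 # const void *buf
    ctypes.c_int,                    # int count
    ctypes.c_int32,                  # MPI_Datatype datatype
    ctypes.c_int,                    # int dest
    ctypes.c_int,                    # int tag
    ctypes.c_int32,                  # MPI_Comm comm
    ctypes.POINTER(ctypes.c_int32),  # MPI_Request *request (out-parameter)
)
MPI_Isend_t = ctypes.CFUNCTYPE(ctypes.c_int, *ISEND_ARGTYPES)

print(len(ISEND_ARGTYPES), ctypes.sizeof(ctypes.c_int32))  # 7 4
```

With a real library, the same types would be assigned to `argtypes`/`restype` of the function loaded through `ctypes.CDLL`. Because the send is nonblocking, `buf` must stay alive and unmodified until the matching `MPI_Wait` completes.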

LowerMpiIrecv dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiIrecv(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.IrecvOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(self, op: mpi.IrecvOp) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        This method lowers mpi.irecv operations

        int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
              MPI_Comm comm, MPI_Request *request)
        """

        return [
            comm_global := arith.ConstantOp.from_int_and_width(
                self.info.MPI_COMM_WORLD, i32
            ),
            func.CallOp(
                self._mpi_name(op),
                [
                    op.buffer,
                    op.count,
                    op.datatype,
                    op.source,
                    op.tag,
                    comm_global,
                    op.request,
                ],
                [i32],
            ),
        ], []

match_and_rewrite(op: mpi.IrecvOp, rewriter: PatternRewriter)

lower(op: mpi.IrecvOp) -> tuple[list[Operation], list[SSAValue | None]]

This method lowers mpi.irecv operations

int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request)
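The request is an out-parameter: the caller passes a pointer and the library writes a new handle through it. The sketch below shows that mechanism with ctypes, calling a Python stand-in (`fake_irecv`, hypothetical) in place of a real `MPI_Irecv`; the handle it writes is a made-up value.

```python
import ctypes

MPI_INT = 0x4C000405           # MpiLibraryInfo defaults
MPI_COMM_WORLD = 0x44000000
MPI_REQUEST_NULL = 0x2C000000

# int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source,
#               int tag, MPI_Comm comm, MPI_Request *request)
MPI_Irecv_t = ctypes.CFUNCTYPE(
    ctypes.c_int, ctypes.c_void_p, ctypes.c_int, ctypes.c_int32,
    ctypes.c_int, ctypes.c_int, ctypes.c_int32, ctypes.POINTER(ctypes.c_int32),
)


def fake_irecv(buf, count, datatype, source, tag, comm, request):
    # A real library would post the receive and return immediately;
    # here we only write a made-up request handle through the out-pointer.
    request[0] = 0x2C000001
    return 0  # MPI_SUCCESS


irecv = MPI_Irecv_t(fake_irecv)
req = ctypes.c_int32(MPI_REQUEST_NULL)
rc = irecv(None, 4, MPI_INT, 0, 7, MPI_COMM_WORLD, ctypes.byref(req))
print(rc, hex(req.value))  # 0 0x2c000001
```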

LowerMpiSend dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiSend(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.SendOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(self, op: mpi.SendOp) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        This method lowers mpi.send operations

        MPI_Send signature:

        int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest,
                 int tag, MPI_Comm comm)
        """

        return [
            comm_global := arith.ConstantOp.from_int_and_width(
                self.info.MPI_COMM_WORLD, i32
            ),
            func.CallOp(
                self._mpi_name(op),
                [op.buffer, op.count, op.datatype, op.dest, op.tag, comm_global],
                [i32],
            ),
        ], []

match_and_rewrite(op: mpi.SendOp, rewriter: PatternRewriter)

lower(op: mpi.SendOp) -> tuple[list[Operation], list[SSAValue | None]]

This method lowers mpi.send operations

MPI_Send signature:

int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm)
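`count` is measured in elements of `datatype`, not in bytes. Assuming MPICH's encoding of the element size in bits 8-15 of builtin datatype handles, the number of bytes a send covers can be computed from the defaults. `message_bytes` is a hypothetical helper and does not handle derived datatypes.

```python
def message_bytes(count: int, datatype: int) -> int:
    """Bytes covered by an MPI_Send of `count` elements (MPICH builtin types only).

    Assumes MPICH stores the element size in bits 8-15 of builtin datatype handles.
    """
    return count * ((datatype >> 8) & 0xFF)


MPI_DOUBLE = 0x4C00080B  # MpiLibraryInfo defaults
MPI_CHAR = 0x4C000101
print(message_bytes(10, MPI_DOUBLE), message_bytes(5, MPI_CHAR))  # 80 5
```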

LowerMpiRecv dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiRecv(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.RecvOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(self, op: mpi.RecvOp) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        This method lowers mpi.recv operations

        MPI_Recv signature:

        int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
             MPI_Comm comm, MPI_Status *status)
        """

        mpi_status_ops, new_results, status = self._emit_mpi_status_objs(
            len(op.results)
        )

        return [
            *mpi_status_ops,
            comm_global := arith.ConstantOp.from_int_and_width(
                self.info.MPI_COMM_WORLD, i32
            ),
            func.CallOp(
                self._mpi_name(op),
                [
                    op.buffer,
                    op.count,
                    op.datatype,
                    op.source,
                    op.tag,
                    comm_global,
                    status,
                ],
                [i32],
            ),
        ], new_results

match_and_rewrite(op: mpi.RecvOp, rewriter: PatternRewriter)

lower(op: mpi.RecvOp) -> tuple[list[Operation], list[SSAValue | None]]

This method lowers mpi.recv operations

MPI_Recv signature:

int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Status *status)
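After `MPI_Recv` returns, the fields of the status object can be read at the offsets recorded in `MpiLibraryInfo`. The sketch decodes `MPI_SOURCE`, `MPI_TAG` and `MPI_ERROR` from a raw 20-byte buffer with `struct`; the buffer contents are invented for the example and `read_status` is a hypothetical helper.

```python
import struct

# MpiLibraryInfo defaults (MPICH 3.3a2).
MPI_STATUS_SIZE = 20
FIELD_MPI_SOURCE = 8
FIELD_MPI_TAG = 12
FIELD_MPI_ERROR = 16


def read_status(status: bytes) -> dict[str, int]:
    """Decode the three public MPI_Status fields from a raw buffer ("=i": native 4-byte int)."""
    return {
        "MPI_SOURCE": struct.unpack_from("=i", status, FIELD_MPI_SOURCE)[0],
        "MPI_TAG": struct.unpack_from("=i", status, FIELD_MPI_TAG)[0],
        "MPI_ERROR": struct.unpack_from("=i", status, FIELD_MPI_ERROR)[0],
    }


# Pretend the library filled the status after a receive from rank 3 with tag 42.
buf = bytearray(MPI_STATUS_SIZE)
struct.pack_into("=i", buf, FIELD_MPI_SOURCE, 3)
struct.pack_into("=i", buf, FIELD_MPI_TAG, 42)
print(read_status(bytes(buf)))  # {'MPI_SOURCE': 3, 'MPI_TAG': 42, 'MPI_ERROR': 0}
```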

LowerMpiUnwrapMemRefOp dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiUnwrapMemRefOp(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.UnwrapMemRefOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(
        self, op: mpi.UnwrapMemRefOp
    ) -> tuple[list[Operation], list[SSAValue | None]]:
        count_ops, count_ssa_val = self._emit_memref_counts(op.ref)
        extract_ptr_ops, ptr = self._memref_get_llvm_ptr(op.ref)

        elem_type = cast(MemRefType[mpi.AnyNumericType], op.ref.type).element_type

        return [
            *extract_ptr_ops,
            *count_ops,
            dtype := mpi.GetDtypeOp(elem_type),
        ], [ptr.results[0], count_ssa_val, dtype.result]

match_and_rewrite(op: mpi.UnwrapMemRefOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_mpi.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: mpi.UnwrapMemRefOp, rewriter: PatternRewriter, /):
    rewriter.replace_op(op, *self.lower(op))

lower(op: mpi.UnwrapMemRefOp) -> tuple[list[Operation], list[SSAValue | None]]

Source code in xdsl/transforms/lower_mpi.py
def lower(
    self, op: mpi.UnwrapMemRefOp
) -> tuple[list[Operation], list[SSAValue | None]]:
    count_ops, count_ssa_val = self._emit_memref_counts(op.ref)
    extract_ptr_ops, ptr = self._memref_get_llvm_ptr(op.ref)

    elem_type = cast(MemRefType[mpi.AnyNumericType], op.ref.type).element_type

    return [
        *extract_ptr_ops,
        *count_ops,
        dtype := mpi.GetDtypeOp(elem_type),
    ], [ptr.results[0], count_ssa_val, dtype.result]

LowerMpiGetDtype dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiGetDtype(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.GetDtypeOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(
        self, op: mpi.GetDtypeOp
    ) -> tuple[list[Operation], list[SSAValue | None]]:
        return [
            dtype := self._emit_mpi_type_load(op.dtype),
        ], [dtype.results[0]]

match_and_rewrite(op: mpi.GetDtypeOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_mpi.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: mpi.GetDtypeOp, rewriter: PatternRewriter, /):
    rewriter.replace_op(op, *self.lower(op))

lower(op: mpi.GetDtypeOp) -> tuple[list[Operation], list[SSAValue | None]]

Source code in xdsl/transforms/lower_mpi.py
def lower(
    self, op: mpi.GetDtypeOp
) -> tuple[list[Operation], list[SSAValue | None]]:
    return [
        dtype := self._emit_mpi_type_load(op.dtype),
    ], [dtype.results[0]]

LowerMpiAllocateType dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiAllocateType(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.AllocateTypeOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(
        self, op: mpi.AllocateTypeOp
    ) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        Allocation operation, allocates the required memory as an LLVM pointer
        """
        datatype_size = self._get_mpi_dtype_size(op.dtype)
        return [
            request := llvm.AllocaOp(op.count, builtin.IntegerType(8 * datatype_size)),
        ], [request.results[0]]

match_and_rewrite(op: mpi.AllocateTypeOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_mpi.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: mpi.AllocateTypeOp, rewriter: PatternRewriter, /):
    rewriter.replace_op(op, *self.lower(op))

lower(op: mpi.AllocateTypeOp) -> tuple[list[Operation], list[SSAValue | None]]

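Lowers the allocation operation: it allocates count elements of the given MPI datatype on the stack with llvm.alloca, typing each element as an integer of the datatype's width (8 * size bits), and returns the resulting LLVM pointer.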
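The element type passed to llvm.alloca is an integer as wide as the MPI datatype, so it encodes only the size, not whether the data is integer or floating point. The short Python sketch below reproduces that arithmetic; alloca_element_type and alloca_total_bytes are hypothetical helpers, and the sizes used are illustrative.

```python
def alloca_element_type(datatype_size_bytes: int) -> str:
    """Name of the LLVM integer type used per element: i{8 * size}."""
    return f"i{8 * datatype_size_bytes}"


def alloca_total_bytes(count: int, datatype_size_bytes: int) -> int:
    """Bytes reserved on the stack for `count` elements of that size."""
    return count * datatype_size_bytes


# Illustrative 4-byte and 8-byte datatypes.
print(alloca_element_type(4), alloca_total_bytes(3, 4))  # i32 12
print(alloca_element_type(8), alloca_total_bytes(2, 8))  # i64 16
```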

Source code in xdsl/transforms/lower_mpi.py
def lower(
    self, op: mpi.AllocateTypeOp
) -> tuple[list[Operation], list[SSAValue | None]]:
    """
    Allocation operation, allocates the required memory as an LLVM pointer
    """
    datatype_size = self._get_mpi_dtype_size(op.dtype)
    return [
        request := llvm.AllocaOp(op.count, builtin.IntegerType(8 * datatype_size)),
    ], [request.results[0]]

LowerMpiVectorGet dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiVectorGet(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.VectorGetOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(
        self, op: mpi.VectorGetOp
    ) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        This lowers the get array at index MPI operation in the dialect. Converts
        the pointer to an integer and then increments this to find the correct
        location before going back to a pointer and setting this as the result
        """

        assert mpi.VectorWrappableConstr.verifies(op.result.type)
        assert isa(op.vect.type, llvm.LLVMPointerType)
        datatype_size = self._get_mpi_dtype_size(op.result.type)

        return [
            ptr_int := llvm.PtrToIntOp(op.vect, i64),
            lit1 := arith.ConstantOp.from_int_and_width(datatype_size, 64),
            idx_cast1 := arith.IndexCastOp(op.element, IndexType()),
            idx_cast2 := arith.IndexCastOp(idx_cast1, i64),
            mul := arith.MuliOp(lit1, idx_cast2),
            add := arith.AddiOp(mul, ptr_int),
            out_ptr := llvm.IntToPtrOp(add),
        ], [out_ptr.results[0]]

match_and_rewrite(op: mpi.VectorGetOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_mpi.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: mpi.VectorGetOp, rewriter: PatternRewriter, /):
    rewriter.replace_op(op, *self.lower(op))

lower(op: mpi.VectorGetOp) -> tuple[list[Operation], list[SSAValue | None]]

This lowers the MPI dialect's get-array-element-at-index operation (mpi.VectorGetOp). It converts the vector pointer to an i64 integer, adds index * datatype size to find the element's address, and converts that address back to a pointer, which becomes the operation's result.
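The address arithmetic is ordinary base-plus-offset indexing: address = base + index * element_size. The ctypes sketch below performs the same three steps on a real buffer (pointer to integer, multiply-and-add, integer back to pointer). element_ptr is a hypothetical helper, and the buffer of C doubles is just an example datatype.

```python
import ctypes

# A contiguous buffer of four C doubles.
values = (ctypes.c_double * 4)(1.5, 2.5, 3.5, 4.5)
elem_size = ctypes.sizeof(ctypes.c_double)


def element_ptr(base, index: int, size: int):
    """Mimic llvm.ptrtoint -> arith.muli/arith.addi -> llvm.inttoptr."""
    base_int = ctypes.addressof(base)  # pointer to integer
    addr = size * index + base_int  # offset by index * element size
    return ctypes.cast(addr, ctypes.POINTER(ctypes.c_double))  # integer to pointer


p = element_ptr(values, 2, elem_size)
print(p.contents.value)  # 3.5
```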

Source code in xdsl/transforms/lower_mpi.py
def lower(
    self, op: mpi.VectorGetOp
) -> tuple[list[Operation], list[SSAValue | None]]:
    """
    This lowers the get array at index MPI operation in the dialect. Converts
    the pointer to an integer and then increments this to find the correct
    location before going back to a pointer and setting this as the result
    """

    assert mpi.VectorWrappableConstr.verifies(op.result.type)
    assert isa(op.vect.type, llvm.LLVMPointerType)
    datatype_size = self._get_mpi_dtype_size(op.result.type)

    return [
        ptr_int := llvm.PtrToIntOp(op.vect, i64),
        lit1 := arith.ConstantOp.from_int_and_width(datatype_size, 64),
        idx_cast1 := arith.IndexCastOp(op.element, IndexType()),
        idx_cast2 := arith.IndexCastOp(idx_cast1, i64),
        mul := arith.MuliOp(lit1, idx_cast2),
        add := arith.AddiOp(mul, ptr_int),
        out_ptr := llvm.IntToPtrOp(add),
    ], [out_ptr.results[0]]

LowerMpiCommRank dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiCommRank(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.CommRankOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(
        self, op: mpi.CommRankOp
    ) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        This method lowers mpi.comm.rank operation

        int MPI_Comm_rank(MPI_Comm comm, int *rank)
        """
        return [
            comm_global := arith.ConstantOp.from_int_and_width(
                self.info.MPI_COMM_WORLD, i32
            ),
            lit1 := arith.ConstantOp.from_int_and_width(1, 64),
            int_ptr := llvm.AllocaOp(lit1, i32),
            func.CallOp(self._mpi_name(op), [comm_global, int_ptr], [i32]),
            rank := llvm.LoadOp(int_ptr, IntegerType(32)),
        ], [rank.dereferenced_value]

match_and_rewrite(op: mpi.CommRankOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_mpi.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: mpi.CommRankOp, rewriter: PatternRewriter, /):
    rewriter.replace_op(op, *self.lower(op))

lower(op: mpi.CommRankOp) -> tuple[list[Operation], list[SSAValue | None]]

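This method lowers the mpi.comm.rank operation into a call to MPI_Comm_rank on MPI_COMM_WORLD, passing a stack-allocated i32 as the out-parameter and loading the rank from it afterwards.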

int MPI_Comm_rank(MPI_Comm comm, int *rank)
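MPI_Comm_rank returns the rank through an int * out-parameter, which is why the lowering emits an llvm.alloca of one i32, passes that pointer to the call, and then loads the value back; the mpi.comm.size lowering further down follows the same pattern with MPI_Comm_size. The pattern can be sketched with Python's ctypes. fake_mpi_comm_rank is a hypothetical stand-in that always reports rank 3, so the example runs without an MPI library.

```python
import ctypes


def fake_mpi_comm_rank(comm: int, rank_ptr) -> int:
    """Hypothetical stand-in for int MPI_Comm_rank(MPI_Comm comm, int *rank)."""
    rank_ptr[0] = 3  # write the result through the out-parameter
    return 0  # MPI_SUCCESS is 0 in the MPI standard


def comm_rank(comm_world: int) -> int:
    rank = ctypes.c_int32()  # like llvm.alloca of a single i32
    fake_mpi_comm_rank(comm_world, ctypes.pointer(rank))  # like func.call
    return rank.value  # like llvm.load of the i32


print(comm_rank(0))  # 3
```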

Source code in xdsl/transforms/lower_mpi.py
def lower(
    self, op: mpi.CommRankOp
) -> tuple[list[Operation], list[SSAValue | None]]:
    """
    This method lowers mpi.comm.rank operation

    int MPI_Comm_rank(MPI_Comm comm, int *rank)
    """
    return [
        comm_global := arith.ConstantOp.from_int_and_width(
            self.info.MPI_COMM_WORLD, i32
        ),
        lit1 := arith.ConstantOp.from_int_and_width(1, 64),
        int_ptr := llvm.AllocaOp(lit1, i32),
        func.CallOp(self._mpi_name(op), [comm_global, int_ptr], [i32]),
        rank := llvm.LoadOp(int_ptr, IntegerType(32)),
    ], [rank.dereferenced_value]

LowerMpiCommSize dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiCommSize(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.CommSizeOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(
        self, op: mpi.CommSizeOp
    ) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        This method lowers mpi.comm.size operation

        int MPI_Comm_size(MPI_Comm comm, int *size)
        """
        return [
            comm_global := arith.ConstantOp.from_int_and_width(
                self.info.MPI_COMM_WORLD, i32
            ),
            lit1 := arith.ConstantOp.from_int_and_width(1, 64),
            int_ptr := llvm.AllocaOp(lit1, i32),
            func.CallOp(self._mpi_name(op), [comm_global, int_ptr], [i32]),
            size := llvm.LoadOp(int_ptr, IntegerType(32)),
        ], [size.dereferenced_value]

match_and_rewrite(op: mpi.CommSizeOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_mpi.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: mpi.CommSizeOp, rewriter: PatternRewriter, /):
    rewriter.replace_op(op, *self.lower(op))

lower(op: mpi.CommSizeOp) -> tuple[list[Operation], list[SSAValue | None]]

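This method lowers the mpi.comm.size operation into a call to MPI_Comm_size on MPI_COMM_WORLD, passing a stack-allocated i32 as the out-parameter and loading the size from it afterwards.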

int MPI_Comm_size(MPI_Comm comm, int *size)

Source code in xdsl/transforms/lower_mpi.py
def lower(
    self, op: mpi.CommSizeOp
) -> tuple[list[Operation], list[SSAValue | None]]:
    """
    This method lowers mpi.comm.size operation

    int MPI_Comm_size(MPI_Comm comm, int *size)
    """
    return [
        comm_global := arith.ConstantOp.from_int_and_width(
            self.info.MPI_COMM_WORLD, i32
        ),
        lit1 := arith.ConstantOp.from_int_and_width(1, 64),
        int_ptr := llvm.AllocaOp(lit1, i32),
        func.CallOp(self._mpi_name(op), [comm_global, int_ptr], [i32]),
        size := llvm.LoadOp(int_ptr, IntegerType(32)),
    ], [size.dereferenced_value]

LowerNullRequestOp dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerNullRequestOp(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.NullRequestOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(
        self, op: mpi.NullRequestOp
    ) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        This method lowers the mpi.NullRequestOp operation by storing the
        MPI_REQUEST_NULL magic value into the given request pointer.
        """
        assert isa(op.request.type, llvm.LLVMPointerType)
        return [
            val := arith.ConstantOp.from_int_and_width(self.info.MPI_REQUEST_NULL, i32),
            llvm.StoreOp(val, op.request),
        ], []

match_and_rewrite(op: mpi.NullRequestOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_mpi.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: mpi.NullRequestOp, rewriter: PatternRewriter, /):
    rewriter.replace_op(op, *self.lower(op))

lower(op: mpi.NullRequestOp) -> tuple[list[Operation], list[SSAValue | None]]

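This method lowers the mpi.NullRequestOp operation by storing the MPI_REQUEST_NULL magic value from MpiLibraryInfo, as an i32, into the given request pointer. It emits no MPI call and produces no results.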
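Unlike most of the other lowerings here, this one emits no MPI call: it only writes the null-request handle into memory the caller already owns. In MPI, waiting on MPI_REQUEST_NULL returns immediately, so this is a way to mark a request slot as inactive. The ctypes sketch assumes a 32-bit request handle, matching the i32 constant the lowering stores; set_null_request and PLACEHOLDER_REQUEST_NULL (an MPICH-style value) are hypothetical.

```python
import ctypes

# Hypothetical placeholder: the lowering uses MpiLibraryInfo.MPI_REQUEST_NULL,
# which is implementation specific.
PLACEHOLDER_REQUEST_NULL = 0x2C000000


def set_null_request(request_ptr) -> None:
    """Model of llvm.store of the i32 MPI_REQUEST_NULL value into request_ptr."""
    request_ptr[0] = PLACEHOLDER_REQUEST_NULL


# Assumes a 32-bit request handle, as implied by the i32 constant in the lowering.
request = ctypes.c_int32(7)
set_null_request(ctypes.pointer(request))
print(hex(request.value))  # 0x2c000000
```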

Source code in xdsl/transforms/lower_mpi.py
def lower(
    self, op: mpi.NullRequestOp
) -> tuple[list[Operation], list[SSAValue | None]]:
    """
    This method lowers the mpi.NullRequestOp operation by storing the
    MPI_REQUEST_NULL magic value into the given request pointer.
    """
    assert isa(op.request.type, llvm.LLVMPointerType)
    return [
        val := arith.ConstantOp.from_int_and_width(self.info.MPI_REQUEST_NULL, i32),
        llvm.StoreOp(val, op.request),
    ], []

LowerMpiGatherOp dataclass

Bases: _MPIToLLVMRewriteBase

Source code in xdsl/transforms/lower_mpi.py
class LowerMpiGatherOp(_MPIToLLVMRewriteBase):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: mpi.GatherOp, rewriter: PatternRewriter, /):
        rewriter.replace_op(op, *self.lower(op))

    def lower(self, op: mpi.GatherOp) -> tuple[list[Operation], list[SSAValue | None]]:
        """
        This method lowers mpi.gather operation.


        int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                       void *recvbuf, int recvcount, MPI_Datatype recvtype,
                       int root,
                       MPI_Comm comm)
        """
        return [
            comm_global := arith.ConstantOp.from_int_and_width(
                self.info.MPI_COMM_WORLD, i32
            ),
            func.CallOp(
                self._mpi_name(op),
                [
                    op.sendbuf,
                    op.sendcount,
                    op.sendtype,
                    op.recvbuf,
                    op.recvcount,
                    op.recvtype,
                    op.root,
                    comm_global,
                ],
                [i32],
            ),
        ], []

match_and_rewrite(op: mpi.GatherOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_mpi.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: mpi.GatherOp, rewriter: PatternRewriter, /):
    rewriter.replace_op(op, *self.lower(op))

lower(op: mpi.GatherOp) -> tuple[list[Operation], list[SSAValue | None]]

This method lowers the mpi.gather operation.

int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
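For reference, MPI_Gather collects sendcount elements from every rank into the root's recvbuf, ordered by rank; recvcount is the number of elements received from each rank, not the total. The pure-Python model below simulates only those semantics for a fixed set of ranks, with no MPI involved; simulate_gather is a hypothetical helper. The lowering itself returns no results and, like the rest of this pass, always passes MPI_COMM_WORLD as the communicator.

```python
from typing import List, Optional


def simulate_gather(send_buffers: List[List[int]], root: int) -> List[Optional[List[int]]]:
    """Return each rank's receive buffer after a simulated MPI_Gather.

    send_buffers[r] is what rank r passes as sendbuf. Only the root's
    recvbuf is written; it holds the blocks concatenated in rank order.
    """
    gathered = [x for block in send_buffers for x in block]
    return [gathered if rank == root else None for rank in range(len(send_buffers))]


result = simulate_gather([[0, 0], [1, 1], [2, 2]], root=0)
print(result[0])  # [0, 0, 1, 1, 2, 2]
print(result[1])  # None
```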

Source code in xdsl/transforms/lower_mpi.py
def lower(self, op: mpi.GatherOp) -> tuple[list[Operation], list[SSAValue | None]]:
    """
    This method lowers mpi.gather operation.


    int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                   void *recvbuf, int recvcount, MPI_Datatype recvtype,
                   int root,
                   MPI_Comm comm)
    """
    return [
        comm_global := arith.ConstantOp.from_int_and_width(
            self.info.MPI_COMM_WORLD, i32
        ),
        func.CallOp(
            self._mpi_name(op),
            [
                op.sendbuf,
                op.sendcount,
                op.sendtype,
                op.recvbuf,
                op.recvcount,
                op.recvtype,
                op.root,
                comm_global,
            ],
            [i32],
        ),
    ], []

LowerMPIPass dataclass

Bases: ModulePass

Source code in xdsl/transforms/lower_mpi.py
@dataclass(frozen=True)
class LowerMPIPass(ModulePass):
    name = "lower-mpi"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # TODO: how to get the lib info in here?
        lib_info = MpiLibraryInfo()
        walker1 = PatternRewriteWalker(
            GreedyRewritePatternApplier(
                [
                    LowerMpiInit(lib_info),
                    LowerMpiFinalize(lib_info),
                    LowerMpiWait(lib_info),
                    LowerMpiWaitall(lib_info),
                    LowerMpiCommRank(lib_info),
                    LowerMpiCommSize(lib_info),
                    LowerMpiIsend(lib_info),
                    LowerMpiIrecv(lib_info),
                    LowerMpiSend(lib_info),
                    LowerMpiRecv(lib_info),
                    LowerMpiReduce(lib_info),
                    LowerMpiAllreduce(lib_info),
                    LowerMpiBcast(lib_info),
                    LowerMpiUnwrapMemRefOp(lib_info),
                    LowerMpiGetDtype(lib_info),
                    LowerMpiAllocateType(lib_info),
                    LowerNullRequestOp(lib_info),
                    LowerMpiVectorGet(lib_info),
                    LowerMpiGatherOp(lib_info),
                ]
            ),
            apply_recursively=True,
        )

        walker1.rewrite_module(op)

        # add func.func to declare external functions
        add_external_func_defs(op)

name = 'lower-mpi' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/lower_mpi.py
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    # TODO: how to get the lib info in here?
    lib_info = MpiLibraryInfo()
    walker1 = PatternRewriteWalker(
        GreedyRewritePatternApplier(
            [
                LowerMpiInit(lib_info),
                LowerMpiFinalize(lib_info),
                LowerMpiWait(lib_info),
                LowerMpiWaitall(lib_info),
                LowerMpiCommRank(lib_info),
                LowerMpiCommSize(lib_info),
                LowerMpiIsend(lib_info),
                LowerMpiIrecv(lib_info),
                LowerMpiSend(lib_info),
                LowerMpiRecv(lib_info),
                LowerMpiReduce(lib_info),
                LowerMpiAllreduce(lib_info),
                LowerMpiBcast(lib_info),
                LowerMpiUnwrapMemRefOp(lib_info),
                LowerMpiGetDtype(lib_info),
                LowerMpiAllocateType(lib_info),
                LowerNullRequestOp(lib_info),
                LowerMpiVectorGet(lib_info),
                LowerMpiGatherOp(lib_info),
            ]
        ),
        apply_recursively=True,
    )

    walker1.rewrite_module(op)

    # add func.func to declare external functions
    add_external_func_defs(op)

add_external_func_defs(module: builtin.ModuleOp)

This function adds the external function definitions for all MPI calls to the module.

It first walks the whole module to discover calls to MPI_ functions, recording the argument and result types of each. It then inserts a func.FuncOp.external() declaration with those types for every function found, using SymbolTable.insert_or_update.

Make sure to apply this in a separate pass after the lowerings; otherwise it runs first and finds no inserted MPI calls.
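The collection step can be modelled in plain Python to show why ordering matters: the function only sees func.call operations that already exist in the module, so it has to run after the lowerings have inserted them. In this sketch, Call is a hypothetical stand-in for func.CallOp, and collect_external_decls mirrors the dictionary logic of the real function but returns the name-to-types mapping instead of creating func.FuncOp.external() declarations. The type names are illustrative strings.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, Set, Tuple

Types = Tuple[str, ...]


@dataclass(frozen=True)
class Call:
    """Hypothetical stand-in for a func.CallOp: callee name plus types."""

    callee: str
    arg_types: Types
    result_types: Types


def collect_external_decls(ops: Iterable[object], mpi_names: Set[str]) -> Dict[str, Tuple[Types, Types]]:
    """Map each MPI callee to the (argument types, result types) of its calls.

    Non-call ops and calls to non-MPI functions are skipped. If a callee is
    called several times, the last call's types win, as with a dict update.
    """
    decls: Dict[str, Tuple[Types, Types]] = {}
    for op in ops:
        if not isinstance(op, Call) or op.callee not in mpi_names:
            continue
        decls[op.callee] = (op.arg_types, op.result_types)
    return decls


module_ops = [
    "some.other_op",
    Call("MPI_Comm_rank", ("i32", "!llvm.ptr"), ("i32",)),
    Call("helper", ("i32",), ()),
    Call("MPI_Comm_size", ("i32", "!llvm.ptr"), ("i32",)),
    Call("MPI_Comm_rank", ("i32", "!llvm.ptr"), ("i32",)),
]
decls = collect_external_decls(module_ops, {"MPI_Comm_rank", "MPI_Comm_size"})
print(sorted(decls))  # ['MPI_Comm_rank', 'MPI_Comm_size']
```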

Source code in xdsl/transforms/lower_mpi.py
def add_external_func_defs(module: builtin.ModuleOp):
    """
    This rewriter adds all external function definitions for MPI calls to the module.

    It does so by first walking the whole module to discover MPI_ calls. Then
    it inserts a `func.FuncOp.external()` op with the correct types at the end of the module.

    Make sure to apply this *in a separate pass after the lowerings*, otherwise
    this will match first and find no inserted MPI calls.
    """

    mpi_func_call_names = set(_MPIToLLVMRewriteBase.MPI_SYMBOL_NAMES.values())

    # collect all func calls to MPI functions
    funcs_to_emit: dict[str, tuple[Sequence[Attribute], Sequence[Attribute]]] = dict()

    for op in module.walk():
        if not isinstance(op, func.CallOp):
            continue
        if op.callee.string_value() not in mpi_func_call_names:
            continue
        funcs_to_emit[op.callee.string_value()] = (
            op.arguments.types,
            op.result_types,
        )

    # for each func found, add a FuncOp to the top of the module.
    for name, types in funcs_to_emit.items():
        SymbolTable.insert_or_update(module, func.FuncOp.external(name, *types))