Skip to content

Dmp

dmp

A dialect for handling distributed memory parallelism (DMP).

This is xDSL only for now.

This dialect aims to provide the tools necessary to facilitate the creation and lowering of stencil (and other) computations in a manner that makes them run on node clusters.

DIM_X = 0 module-attribute

DIM_Y = 1 module-attribute

DIM_Z = 2 module-attribute

DMP = Dialect('dmp', [SwapOp], [ExchangeDeclarationAttr, ShapeAttr, RankTopoAttr, GridSlice2dAttr, GridSlice3dAttr]) module-attribute

ExchangeDeclarationAttr

Bases: ParametrizedAttribute

This declares a region to be "halo-exchanged". The semantics define that the region specified by offset and size is the received part. To get the section that should be sent, use the source_area() method to get the source area.

  • offset gives the coordinates from the origin of the stencil field.
  • size gives the size of the buffer to be exchanged.
  • source_offset gives a translation (n-d offset) where the data should be read from that is exchanged with the other node.
  • neighbor gives the offset in rank to the node this data is to be exchanged with

Example:

offset = [4, 0]
size   = [10, 1]
source_offset = [0, 1]
neighbor = [0, -1]

To visualize:

    0   4         14
        xxxxxxxxxx    0
        oooooooooo    1

Where x signifies the area that should be received, and o the area that should be read from.

This data will be exchanged with the node of rank (my_rank -1)

Source code in xdsl/dialects/experimental/dmp.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
@irdl_attr_definition
class ExchangeDeclarationAttr(ParametrizedAttribute):
    """
    This declares a region to be "halo-exchanged".
    The semantics define that the region specified by offset and size
    is the *received part*. To get the section that should be sent,
    use the source_area() method to get the source area.

     - offset gives the coordinates from the origin of the stencil field.
     - size gives the size of the buffer to be exchanged.
     - source_offset gives a translation (n-d offset) where the data should be
       read from that is exchanged with the other node.
     - neighbor gives, per dimension, the offset in rank to the node this data
       is to be exchanged with.

    Example:

        offset = [4, 0]
        size   = [10, 1]
        source_offset = [0, 1]
        neighbor = [0, -1]

    To visualize:
    0   4         14
        xxxxxxxxxx    0
        oooooooooo    1

    Where `x` signifies the area that should be received,
    and `o` the area that should be read from.

    This data will be exchanged with the node whose rank is offset by -1
    along dimension 1.
    """

    name = "dmp.exchange"

    offset_: builtin.DenseArrayBase[builtin.I64]
    size_: builtin.DenseArrayBase[builtin.I64]
    source_offset_: builtin.DenseArrayBase[builtin.I64]
    neighbor_: builtin.DenseArrayBase[builtin.I64]

    def __init__(
        self,
        offset: Sequence[int],
        size: Sequence[int],
        source_offset: Sequence[int],
        neighbor: Sequence[int],
    ):
        data_type = builtin.i64
        super().__init__(
            builtin.DenseArrayBase.from_list(data_type, offset),
            builtin.DenseArrayBase.from_list(data_type, size),
            builtin.DenseArrayBase.from_list(data_type, source_offset),
            builtin.DenseArrayBase.from_list(data_type, neighbor),
        )

    @classmethod
    def from_points(
        cls,
        points: Sequence[tuple[int, int]],
        dim: int,
        dir_sign: Literal[1, -1],
        neighbor_offset: int = 1,
    ):
        """
        Build a declaration from one (start, end) pair per dimension.

        - points: per-dimension (start, end) of the region to receive into.
        - dim: the dimension along which the exchange happens.
        - dir_sign: +1 or -1, the direction of the neighbor along `dim`.
        - neighbor_offset: how many ranks away the neighbor is along `dim`.

        The source region lies opposite to the exchange direction, shifted by
        size * neighbor_offset along `dim`; every other dimension is unshifted.
        """
        sizes = tuple(e - s for s, e in points)
        return cls(
            # get starting points
            tuple(s for s, _ in points),
            # calculated sizes
            sizes,
            # source_offset (opposite of exchange direction)
            tuple(
                0 if d != dim else -1 * dir_sign * sizes[dim] * neighbor_offset
                for d in range(len(sizes))
            ),
            # direction
            tuple(
                0 if d != dim else dir_sign * neighbor_offset for d in range(len(sizes))
            ),
        )

    @property
    def offset(self) -> tuple[int, ...]:
        return self.offset_.get_values()

    @property
    def size(self) -> tuple[int, ...]:
        return self.size_.get_values()

    @property
    def source_offset(self) -> tuple[int, ...]:
        return self.source_offset_.get_values()

    @property
    def neighbor(self) -> tuple[int, ...]:
        return self.neighbor_.get_values()

    @property
    def elem_count(self) -> int:
        """Number of elements in the exchanged region (product of `size`)."""
        return prod(self.size)

    @property
    def dims(self) -> int:
        """
        number of dimensions of the grid
        """
        return len(self.size)

    def source_area(self) -> ExchangeDeclarationAttr:
        """
        Since an ExchangeDeclarationAttr by default specifies the area to
        receive into, this method returns the area that should be read from:
        the same region shifted by source_offset.
        """
        # we set source_offset to all zero, so that repeated calls to source_area never
        # return the dest area
        return ExchangeDeclarationAttr(
            offset=tuple(
                val + offs for val, offs in zip(self.offset, self.source_offset)
            ),
            size=self.size,
            source_offset=(0,) * len(self.source_offset),
            neighbor=self.neighbor,
        )

    def print_parameters(self, printer: Printer) -> None:
        # Textual form: <at [..] size [..] source offset [..] to [..]>
        with printer.in_angle_brackets():
            printer.print_string("at ")
            with printer.in_square_brackets():
                printer.print_list(self.offset, printer.print_int)
            printer.print_string(" size ")
            with printer.in_square_brackets():
                printer.print_list(self.size, printer.print_int)
            printer.print_string(" source offset ")
            with printer.in_square_brackets():
                printer.print_list(self.source_offset, printer.print_int)
            printer.print_string(" to ")
            with printer.in_square_brackets():
                printer.print_list(self.neighbor, printer.print_int)

    @classmethod
    def parse_parameters(cls, parser: AttrParser) -> list[Attribute]:
        # Mirror of print_parameters; returns the four i64 arrays in field order.
        parser.parse_characters("<")
        parser.parse_characters("at")
        offset = parser.parse_comma_separated_list(
            parser.Delimiter.SQUARE, parser.parse_integer
        )
        parser.parse_characters("size")
        size = parser.parse_comma_separated_list(
            parser.Delimiter.SQUARE, parser.parse_integer
        )
        parser.parse_characters("source")
        parser.parse_characters("offset")
        source_offset = parser.parse_comma_separated_list(
            parser.Delimiter.SQUARE, parser.parse_integer
        )
        parser.parse_characters("to")
        to = parser.parse_comma_separated_list(
            parser.Delimiter.SQUARE, parser.parse_integer
        )
        parser.parse_characters(">")

        return [
            builtin.DenseArrayBase.from_list(builtin.i64, x)
            for x in (offset, size, source_offset, to)
        ]

name = 'dmp.exchange' class-attribute instance-attribute

offset_: builtin.DenseArrayBase[builtin.I64] instance-attribute

size_: builtin.DenseArrayBase[builtin.I64] instance-attribute

source_offset_: builtin.DenseArrayBase[builtin.I64] instance-attribute

neighbor_: builtin.DenseArrayBase[builtin.I64] instance-attribute

offset: tuple[int, ...] property

size: tuple[int, ...] property

source_offset: tuple[int, ...] property

neighbor: tuple[int, ...] property

elem_count: int property

dims: int property

number of dimensions of the grid

__init__(offset: Sequence[int], size: Sequence[int], source_offset: Sequence[int], neighbor: Sequence[int])

Source code in xdsl/dialects/experimental/dmp.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def __init__(
    self,
    offset: Sequence[int],
    size: Sequence[int],
    source_offset: Sequence[int],
    neighbor: Sequence[int],
):
    """Encode the four integer vectors as i64 dense arrays, in field order."""
    super().__init__(
        *(
            builtin.DenseArrayBase.from_list(builtin.i64, values)
            for values in (offset, size, source_offset, neighbor)
        )
    )

from_points(points: Sequence[tuple[int, int]], dim: int, dir_sign: Literal[1, -1], neighbor_offset: int = 1) classmethod

Source code in xdsl/dialects/experimental/dmp.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
@classmethod
def from_points(
    cls,
    points: Sequence[tuple[int, int]],
    dim: int,
    dir_sign: Literal[1, -1],
    neighbor_offset: int = 1,
):
    """
    Build a declaration from one (start, end) pair per dimension.

    Only dimension `dim` gets a non-zero source offset and neighbor entry.
    The source region sits opposite to the exchange direction, shifted by
    size * neighbor_offset along `dim`.
    """
    extents = tuple(end - start for start, end in points)
    origins = tuple(start for start, _ in points)
    rank = len(extents)

    shifts = []
    steps = []
    for axis in range(rank):
        if axis == dim:
            shifts.append(-dir_sign * extents[dim] * neighbor_offset)
            steps.append(dir_sign * neighbor_offset)
        else:
            shifts.append(0)
            steps.append(0)

    return cls(origins, extents, tuple(shifts), tuple(steps))

source_area() -> ExchangeDeclarationAttr

Since an ExchangeDeclarationAttr by default specifies the area to receive into, this method returns the area that should be read from (the offset shifted by source_offset).

Source code in xdsl/dialects/experimental/dmp.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def source_area(self) -> ExchangeDeclarationAttr:
    """
    Return the region to read (send) from.

    An ExchangeDeclarationAttr describes the region to receive into; the
    source region is the same size, shifted by source_offset.
    """
    shifted_origin = tuple(
        base + shift for base, shift in zip(self.offset, self.source_offset)
    )
    # A zero source_offset means calling source_area() on the result yields
    # the same region again, never the destination area.
    no_shift = (0,) * len(self.source_offset)
    return ExchangeDeclarationAttr(
        offset=shifted_origin,
        size=self.size,
        source_offset=no_shift,
        neighbor=self.neighbor,
    )

print_parameters(printer: Printer) -> None

Source code in xdsl/dialects/experimental/dmp.py
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def print_parameters(self, printer: Printer) -> None:
    """Print as <at [..] size [..] source offset [..] to [..]>."""
    sections = (
        ("at ", self.offset),
        (" size ", self.size),
        (" source offset ", self.source_offset),
        (" to ", self.neighbor),
    )
    with printer.in_angle_brackets():
        for label, values in sections:
            printer.print_string(label)
            with printer.in_square_brackets():
                printer.print_list(values, printer.print_int)

parse_parameters(parser: AttrParser) -> list[Attribute] classmethod

Source code in xdsl/dialects/experimental/dmp.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
@classmethod
def parse_parameters(cls, parser: AttrParser) -> list[Attribute]:
    """Parse <at [..] size [..] source offset [..] to [..]> into four i64 arrays."""

    def read_int_list():
        return parser.parse_comma_separated_list(
            parser.Delimiter.SQUARE, parser.parse_integer
        )

    parser.parse_characters("<")
    vectors = []
    # Each section is one or more keywords followed by a bracketed int list.
    for keywords in (("at",), ("size",), ("source", "offset"), ("to",)):
        for keyword in keywords:
            parser.parse_characters(keyword)
        vectors.append(read_int_list())
    parser.parse_characters(">")

    return [builtin.DenseArrayBase.from_list(builtin.i64, vec) for vec in vectors]

ShapeAttr dataclass

Bases: ParametrizedAttribute

This represents shape information that is attached to halo operations.

On the terminology used:

In each dimension, we are given four points. We abbreviate them in annotations to an, bn, cn, dn, with n being the dimension. In 2d, these create the following pattern, higher dimensional examples can be derived from this:

    a1 b1          c1 d1
    +--+-----------+--+ a0
    |  |           |  |
    +--+-----------+--+ b0
    |  |           |  |
    |  |           |  |
    |  |           |  |
    |  |           |  |
    +--+-----------+--+ c0
    |  |           |  |
    +--+-----------+--+ d0

We can now name these points:

 - a: buffer_start
 - b: core_start
 - c: core_end
 - d: buffer_end

This class provides easy getters for these four.

We can also define some common sizes on this object:

- buff_size(n) = dn - an
- core_size(n) = cn - bn
- halo_size(n, start) = bn - an
- halo_size(n, end  ) = dn - cn
Source code in xdsl/dialects/experimental/dmp.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
@irdl_attr_definition
class ShapeAttr(ParametrizedAttribute):
    """
    Shape information (buffer and core bounds) attached to halo operations.

    Along every dimension n there are four ordered points, abbreviated
    an, bn, cn, dn. In 2d they form the pattern below; higher-dimensional
    cases follow the same scheme:

    a1 b1          c1 d1
    +--+-----------+--+ a0
    |  |           |  |
    +--+-----------+--+ b0
    |  |           |  |
    |  |           |  |
    |  |           |  |
    |  |           |  |
    +--+-----------+--+ c0
    |  |           |  |
    +--+-----------+--+ d0

    The points are named:

         - a: buffer_start
         - b: core_start
         - c: core_end
         - d: buffer_end

    and each has a getter on this class. Derived sizes:

        - buff_size(n) = dn - an
        - core_size(n) = cn - bn
        - halo_size(n, start) = bn - an
        - halo_size(n, end  ) = dn - cn
    """

    name = "dmp.shape_with_halo"

    # Stored as four per-bound arrays; the textual syntax instead groups the
    # four bounds per dimension (see parse_parameters).
    buff_lb_: builtin.DenseArrayBase[builtin.I64]
    buff_ub_: builtin.DenseArrayBase[builtin.I64]
    core_lb_: builtin.DenseArrayBase[builtin.I64]
    core_ub_: builtin.DenseArrayBase[builtin.I64]

    @property
    def buff_lb(self) -> tuple[int, ...]:
        return self.buff_lb_.get_values()

    @property
    def buff_ub(self) -> tuple[int, ...]:
        return self.buff_ub_.get_values()

    @property
    def core_lb(self) -> tuple[int, ...]:
        return self.core_lb_.get_values()

    @property
    def core_ub(self) -> tuple[int, ...]:
        return self.core_ub_.get_values()

    @property
    def dims(self) -> int:
        """
        Number of axes of the data (len(shape))
        """
        return len(self.core_ub)

    @staticmethod
    def from_index_attrs(
        buff_lb: stencil.IndexAttr | Sequence[int],
        core_lb: stencil.IndexAttr | Sequence[int],
        core_ub: stencil.IndexAttr | Sequence[int],
        buff_ub: stencil.IndexAttr | Sequence[int],
    ):
        # Arguments come in spatial order (a, b, c, d); storage order is
        # buff_lb, buff_ub, core_lb, core_ub.
        storage_order = (buff_lb, buff_ub, core_lb, core_ub)
        arrays = [
            builtin.DenseArrayBase.from_list(builtin.i64, tuple(bound))
            for bound in storage_order
        ]
        return ShapeAttr(*arrays)

    def _check_dim(self, dim: int) -> None:
        """Assert that `dim` is below the number of dimensions."""
        assert dim < self.dims, f"The given DimsHelper only has {self.dims} dimensions"

    def buffer_start(self, dim: int) -> int:
        self._check_dim(dim)
        return self.buff_lb[dim]

    def core_start(self, dim: int) -> int:
        self._check_dim(dim)
        return self.core_lb[dim]

    def buffer_end(self, dim: int) -> int:
        self._check_dim(dim)
        return self.buff_ub[dim]

    def core_end(self, dim: int) -> int:
        self._check_dim(dim)
        return self.core_ub[dim]

    # Helpers for specific sizes:

    def buff_size(self, dim: int) -> int:
        self._check_dim(dim)
        upper, lower = self.buff_ub, self.buff_lb
        return upper[dim] - lower[dim]

    def core_size(self, dim: int) -> int:
        self._check_dim(dim)
        upper, lower = self.core_ub, self.core_lb
        return upper[dim] - lower[dim]

    def halo_size(self, dim: int, at_end: bool = False) -> int:
        self._check_dim(dim)
        if not at_end:
            return self.core_lb[dim] - self.buff_lb[dim]
        return self.buff_ub[dim] - self.core_ub[dim]

    # parsing / printing

    def print_parameters(self, printer: Printer) -> None:
        per_dim = zip(self.buff_lb, self.core_lb, self.core_ub, self.buff_ub)
        groups = [str(list(group)) for group in per_dim]
        printer.print_string("<")
        printer.print_string("x".join(groups))
        printer.print_string(">")

    @classmethod
    def parse_parameters(cls, parser: AttrParser) -> list[Attribute]:
        """
        Parses the attribute, the format of it is:

        #dmp.shape_with_halo<[a0,b0,c0,d0]x[a1,b1,c1,d1]x...>

        so different from the way it's stored internally.

        This decision was made to improve readability.
        """
        parser.parse_characters("<")
        groups: list[tuple[int, int, int, int]] = []

        while True:
            parser.parse_characters("[")
            a = parser.parse_integer()
            parser.parse_characters(",")
            b = parser.parse_integer()
            parser.parse_characters(",")
            c = parser.parse_integer()
            parser.parse_characters(",")
            d = parser.parse_integer()
            parser.parse_characters("]")
            groups.append((a, b, c, d))
            if parser.parse_optional_characters("x") is None:
                break
        parser.parse_characters(">")

        # Reorder the per-dimension (a, b, c, d) groups into storage order:
        # buff_lb (a), buff_ub (d), core_lb (b), core_ub (c).
        return [
            builtin.DenseArrayBase.from_list(
                builtin.i64, [group[idx] for group in groups]
            )
            for idx in (0, 3, 1, 2)
        ]

name = 'dmp.shape_with_halo' class-attribute instance-attribute

buff_lb_: builtin.DenseArrayBase[builtin.I64] instance-attribute

buff_ub_: builtin.DenseArrayBase[builtin.I64] instance-attribute

core_lb_: builtin.DenseArrayBase[builtin.I64] instance-attribute

core_ub_: builtin.DenseArrayBase[builtin.I64] instance-attribute

buff_lb: tuple[int, ...] property

buff_ub: tuple[int, ...] property

core_lb: tuple[int, ...] property

core_ub: tuple[int, ...] property

dims: int property

Number of axes of the data (len(shape))

from_index_attrs(buff_lb: stencil.IndexAttr | Sequence[int], core_lb: stencil.IndexAttr | Sequence[int], core_ub: stencil.IndexAttr | Sequence[int], buff_ub: stencil.IndexAttr | Sequence[int]) staticmethod

Source code in xdsl/dialects/experimental/dmp.py
293
294
295
296
297
298
299
300
301
302
303
304
305
306
@staticmethod
def from_index_attrs(
    buff_lb: stencil.IndexAttr | Sequence[int],
    core_lb: stencil.IndexAttr | Sequence[int],
    core_ub: stencil.IndexAttr | Sequence[int],
    buff_ub: stencil.IndexAttr | Sequence[int],
):
    """Build a ShapeAttr from the four bounds given in spatial order a, b, c, d."""
    # Storage order differs from argument order: buffer bounds come first.
    storage_order = (buff_lb, buff_ub, core_lb, core_ub)
    arrays = [
        builtin.DenseArrayBase.from_list(builtin.i64, tuple(bound))
        for bound in storage_order
    ]
    return ShapeAttr(*arrays)

buffer_start(dim: int) -> int

Source code in xdsl/dialects/experimental/dmp.py
308
309
310
def buffer_start(self, dim: int) -> int:
    """Point `a` along `dim`: the lower bound of the whole buffer."""
    n_dims = self.dims
    assert dim < n_dims, f"The given DimsHelper only has {n_dims} dimensions"
    lower_bounds = self.buff_lb
    return lower_bounds[dim]

core_start(dim: int) -> int

Source code in xdsl/dialects/experimental/dmp.py
312
313
314
def core_start(self, dim: int) -> int:
    """Point `b` along `dim`: the lower bound of the core region."""
    n_dims = self.dims
    assert dim < n_dims, f"The given DimsHelper only has {n_dims} dimensions"
    lower_bounds = self.core_lb
    return lower_bounds[dim]

buffer_end(dim: int) -> int

Source code in xdsl/dialects/experimental/dmp.py
316
317
318
def buffer_end(self, dim: int) -> int:
    """Point `d` along `dim`: the upper bound of the whole buffer."""
    n_dims = self.dims
    assert dim < n_dims, f"The given DimsHelper only has {n_dims} dimensions"
    upper_bounds = self.buff_ub
    return upper_bounds[dim]

core_end(dim: int) -> int

Source code in xdsl/dialects/experimental/dmp.py
320
321
322
def core_end(self, dim: int) -> int:
    """Point `c` along `dim`: the upper bound of the core region."""
    n_dims = self.dims
    assert dim < n_dims, f"The given DimsHelper only has {n_dims} dimensions"
    upper_bounds = self.core_ub
    return upper_bounds[dim]

buff_size(dim: int) -> int

Source code in xdsl/dialects/experimental/dmp.py
326
327
328
def buff_size(self, dim: int) -> int:
    """Extent of the whole buffer along `dim` (d - a)."""
    n_dims = self.dims
    assert dim < n_dims, f"The given DimsHelper only has {n_dims} dimensions"
    start, end = self.buff_lb[dim], self.buff_ub[dim]
    return end - start

core_size(dim: int) -> int

Source code in xdsl/dialects/experimental/dmp.py
330
331
332
def core_size(self, dim: int) -> int:
    """Extent of the core region along `dim` (c - b)."""
    n_dims = self.dims
    assert dim < n_dims, f"The given DimsHelper only has {n_dims} dimensions"
    start, end = self.core_lb[dim], self.core_ub[dim]
    return end - start

halo_size(dim: int, at_end: bool = False) -> int

Source code in xdsl/dialects/experimental/dmp.py
334
335
336
337
338
def halo_size(self, dim: int, at_end: bool = False) -> int:
    """
    Halo width along `dim`: b - a at the start, or d - c when `at_end` is set.
    """
    n_dims = self.dims
    assert dim < n_dims, f"The given DimsHelper only has {n_dims} dimensions"
    if not at_end:
        return self.core_lb[dim] - self.buff_lb[dim]
    return self.buff_ub[dim] - self.core_ub[dim]

print_parameters(printer: Printer) -> None

Source code in xdsl/dialects/experimental/dmp.py
342
343
344
345
346
def print_parameters(self, printer: Printer) -> None:
    """Print as <[a0, b0, c0, d0]x[a1, b1, c1, d1]x...>, one group per dimension."""
    groups = [
        str([lo, start, end, hi])
        for lo, start, end, hi in zip(
            self.buff_lb, self.core_lb, self.core_ub, self.buff_ub
        )
    ]
    printer.print_string("<")
    printer.print_string("x".join(groups))
    printer.print_string(">")

parse_parameters(parser: AttrParser) -> list[Attribute] classmethod

Parses the attribute, the format of it is:

`#dmp.shape_with_halo<[a0,b0,c0,d0]x[a1,b1,c1,d1]x...>`

so different from the way it's stored internally.

This decision was made to improve readability.

Source code in xdsl/dialects/experimental/dmp.py
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
@classmethod
def parse_parameters(cls, parser: AttrParser) -> list[Attribute]:
    """
    Parse the textual form

    #dmp.shape_with_halo<[a0,b0,c0,d0]x[a1,b1,c1,d1]x...>

    which groups bounds per dimension for readability, unlike the internal
    per-bound storage (buff_lb, buff_ub, core_lb, core_ub).
    """
    parser.parse_characters("<")
    groups: list[tuple[int, int, int, int]] = []

    while True:
        parser.parse_characters("[")
        a = parser.parse_integer()
        parser.parse_characters(",")
        b = parser.parse_integer()
        parser.parse_characters(",")
        c = parser.parse_integer()
        parser.parse_characters(",")
        d = parser.parse_integer()
        parser.parse_characters("]")
        groups.append((a, b, c, d))
        if parser.parse_optional_characters("x") is None:
            break
    parser.parse_characters(">")

    # Storage order: buff_lb (a), buff_ub (d), core_lb (b), core_ub (c).
    return [
        builtin.DenseArrayBase.from_list(builtin.i64, [group[idx] for group in groups])
        for idx in (0, 3, 1, 2)
    ]

RankTopoAttr

Bases: ParametrizedAttribute

This attribute specifies the node layout used to distribute the computation.

`#dmp.topo<3x3>` means nine ranks organized in a 3x3 grid.

This allows for higher-dimensional grids as well, e.g. `#dmp.topo<3x3x3>` for 3-dimensional data.

Source code in xdsl/dialects/experimental/dmp.py
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
@irdl_attr_definition
class RankTopoAttr(ParametrizedAttribute):
    """
    This attribute specifies the node layout used to distribute the computation.

    #dmp.topo<3x3> means nine ranks organized in a 3x3 grid.

    This allows for higher-dimensional grids as well, e.g. #dmp.topo<3x3x3> for
    3-dimensional data.
    """

    name = "dmp.topo"

    # Number of ranks along each grid dimension.
    shape: builtin.DenseArrayBase[builtin.I64]

    def __init__(self, shape: Sequence[int]):
        """
        Raises:
            ValueError: if `shape` has no dimensions.
        """
        if len(shape) < 1:
            raise ValueError("dmp.grid must have at least one dimension!")
        super().__init__(builtin.DenseArrayBase.from_list(builtin.i64, shape))

    def as_tuple(self) -> tuple[int, ...]:
        """Return the number of ranks along each grid dimension."""
        return self.shape.get_values()

    def node_count(self) -> int:
        """Return the total number of ranks (product of all grid extents)."""
        return prod(self.as_tuple())

    @classmethod
    def parse_parameters(cls, parser: AttrParser) -> list[Attribute]:
        # Parses <N> or <NxMx...>; extents must be non-negative integers.
        parser.parse_characters("<")
        shape: list[int] = [
            parser.parse_integer(allow_negative=False, allow_boolean=False)
        ]

        while parser.parse_optional_punctuation(">") is None:
            parser.parse_shape_delimiter()
            shape.append(
                parser.parse_integer(allow_negative=False, allow_boolean=False)
            )

        return [builtin.DenseArrayBase.from_list(builtin.i64, shape)]

    def print_parameters(self, printer: Printer) -> None:
        # Prints <NxMx...>, matching parse_parameters.
        printer.print_string("<")
        printer.print_string("x".join(str(x) for x in self.as_tuple()))
        printer.print_string(">")

name = 'dmp.topo' class-attribute instance-attribute

shape: builtin.DenseArrayBase[builtin.I64] instance-attribute

__init__(shape: Sequence[int])

Source code in xdsl/dialects/experimental/dmp.py
401
402
403
404
def __init__(self, shape: Sequence[int]):
    """Store `shape` (ranks per grid dimension); it must be non-empty."""
    if len(shape) == 0:
        raise ValueError("dmp.grid must have at least one dimension!")
    super().__init__(builtin.DenseArrayBase.from_list(builtin.i64, shape))

as_tuple() -> tuple[int, ...]

Source code in xdsl/dialects/experimental/dmp.py
406
407
408
def as_tuple(self) -> tuple[int, ...]:
    """Return the grid extents (ranks per dimension) as a plain tuple."""
    return self.shape.get_values()

node_count() -> int

Source code in xdsl/dialects/experimental/dmp.py
410
411
def node_count(self) -> int:
    """Total number of ranks: the product of all grid extents."""
    extents = self.as_tuple()
    return prod(extents)

parse_parameters(parser: AttrParser) -> list[Attribute] classmethod

Source code in xdsl/dialects/experimental/dmp.py
413
414
415
416
417
418
419
420
421
422
423
424
425
426
@classmethod
def parse_parameters(cls, parser: AttrParser) -> list[Attribute]:
    """Parse <N> or <NxMx...> into a single i64 dense array of extents."""

    def read_extent() -> int:
        return parser.parse_integer(allow_negative=False, allow_boolean=False)

    parser.parse_characters("<")
    shape: list[int] = [read_extent()]
    while parser.parse_optional_punctuation(">") is None:
        parser.parse_shape_delimiter()
        shape.append(read_extent())

    return [builtin.DenseArrayBase.from_list(builtin.i64, shape)]

print_parameters(printer: Printer) -> None

Source code in xdsl/dialects/experimental/dmp.py
428
429
430
431
def print_parameters(self, printer: Printer) -> None:
    """Print the extents as <NxMx...>."""
    extents = self.shape.get_values()
    printer.print_string("<")
    printer.print_string("x".join(map(str, extents)))
    printer.print_string(">")

DomainDecompositionStrategy dataclass

Bases: ParametrizedAttribute, ABC

Source code in xdsl/dialects/experimental/dmp.py
434
435
436
437
438
439
440
441
442
class DomainDecompositionStrategy(ParametrizedAttribute, ABC):
    """
    Base class for strategies that distribute a domain across ranks.

    Subclasses override all three methods below; the base versions only
    raise NotImplementedError.
    """

    def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        """Return the local shape each rank holds for a domain of `shape`."""
        raise NotImplementedError("SlicingStrategy must implement calc_resize!")

    def halo_exchange_defs(self, shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]:
        """Return the halo exchanges needed for a buffer of the given shape."""
        raise NotImplementedError("SlicingStrategy must implement halo_exchange_defs!")

    def comm_layout(self) -> RankTopoAttr:
        """Return the rank topology this strategy distributes over."""
        # The message names the method a subclass must override.
        raise NotImplementedError("SlicingStrategy must implement comm_layout!")

calc_resize(shape: tuple[int, ...]) -> tuple[int, ...]

Source code in xdsl/dialects/experimental/dmp.py
435
436
def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Abstract hook: subclasses return the per-rank shape."""
    message = "SlicingStrategy must implement calc_resize!"
    raise NotImplementedError(message)

halo_exchange_defs(shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]

Source code in xdsl/dialects/experimental/dmp.py
438
439
def halo_exchange_defs(self, shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]:
    """Abstract hook: subclasses return the halo exchanges for `shape`."""
    message = "SlicingStrategy must implement halo_exchange_defs!"
    raise NotImplementedError(message)

comm_layout() -> RankTopoAttr

Source code in xdsl/dialects/experimental/dmp.py
441
442
def comm_layout(self) -> RankTopoAttr:
    """Abstract hook: subclasses return the rank topology they distribute over."""
    # The message names the method a subclass must override (was "comm_count").
    raise NotImplementedError("SlicingStrategy must implement comm_layout!")

GridSlice2dAttr

Bases: DomainDecompositionStrategy

Takes a grid with two or more dimensions, slices it along the first two into equally sized segments.

Source code in xdsl/dialects/experimental/dmp.py
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
@irdl_attr_definition
class GridSlice2dAttr(DomainDecompositionStrategy):
    """
    Takes a grid with two or more dimensions, slices it along the first two into equally
    sized segments.

    Each leading data axis is divided by the matching topology extent (the
    first two axes for the usual two-entry topology); remaining axes are kept
    whole, so the resized shape always has the same rank as the input.
    """

    name = "dmp.grid_slice_2d"

    topology: RankTopoAttr

    diagonals: builtin.BoolAttr

    def __init__(self, topo: tuple[int, ...]):
        # Diagonal exchanges are disabled by default.
        super().__init__(RankTopoAttr(topo), builtin.BoolAttr.from_int_and_width(0, 1))

    def _verify(self):
        assert len(self.topology.as_tuple()) >= 2, (
            "GridSlice2d requires at least two dimensions"
        )

    def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        """
        Return the per-rank shape for a global domain of `shape`.

        Raises AssertionError if `shape` has fewer than two dimensions or a
        sliced axis is not evenly divisible by its rank count.
        """
        assert len(shape) >= 2, "GridSlice2d requires at least two dimensions"
        topo = self.topology.as_tuple()
        for size, node_count in zip(shape, topo):
            assert size % node_count == 0, (
                "GridSlice2d requires domain be neatly divisible by shape"
            )
        sliced = tuple(size // node_count for size, node_count in zip(shape, topo))
        # Append exactly the axes that were not sliced; hard-coding shape[2:]
        # would duplicate axes when the topology has more than two entries.
        return (*sliced, *shape[len(sliced) :])

    def halo_exchange_defs(self, shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]:
        yield from _flat_face_exchanges_for_dim(shape, 0)

        yield from _flat_face_exchanges_for_dim(shape, 1)

        if self.diagonals.value.data:
            raise NotImplementedError("Diagonals support not implemented yet")

    def comm_layout(self) -> RankTopoAttr:
        return RankTopoAttr(self.topology.as_tuple())

name = 'dmp.grid_slice_2d' class-attribute instance-attribute

topology: RankTopoAttr instance-attribute

diagonals: builtin.BoolAttr instance-attribute

__init__(topo: tuple[int, ...])

Source code in xdsl/dialects/experimental/dmp.py
458
459
def __init__(self, topo: tuple[int, ...]):
    """Build the strategy from a rank topology tuple, with diagonals disabled."""
    diagonals_off = builtin.BoolAttr.from_int_and_width(0, 1)
    super().__init__(RankTopoAttr(topo), diagonals_off)

calc_resize(shape: tuple[int, ...]) -> tuple[int, ...]

Source code in xdsl/dialects/experimental/dmp.py
466
467
468
469
470
471
472
473
474
475
476
477
478
def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]:
    """
    Return the per-rank shape for a global `shape`.

    The first two dimensions are divided by the first two topology entries;
    all remaining dimensions are kept as-is, so the result always has the
    same length as `shape`.

    Raises AssertionError (when assertions are enabled) if `shape` has fewer
    than two dimensions or a sliced dimension is not evenly divisible.
    """
    assert len(shape) >= 2, "GridSlice2d requires at least two dimensions"
    # Only pair up the two sliced dimensions. Zipping against the full
    # topology would also divide dimension 2 when the topology has more than
    # two entries, and shape[2:] would then be appended a second time.
    sliced = tuple(zip(shape[:2], self.topology.as_tuple()[:2]))
    for size, node_count in sliced:
        assert size % node_count == 0, (
            "GridSlice2d requires domain be neatly divisible by shape"
        )
    return (
        *(size // node_count for size, node_count in sliced),
        *shape[2:],
    )

halo_exchange_defs(shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]

Source code in xdsl/dialects/experimental/dmp.py
480
481
482
483
484
485
486
def halo_exchange_defs(self, shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]:
    """
    Yield the flat-face exchange declarations for dimensions 0 and 1.

    If the diagonals flag is set, NotImplementedError is raised once the
    flat-face declarations have been yielded.
    """
    for dim in (DIM_X, DIM_Y):
        yield from _flat_face_exchanges_for_dim(shape, dim)

    wants_diagonals = self.diagonals.value.data
    if wants_diagonals:
        raise NotImplementedError("Diagonals support not implemented yet")

comm_layout() -> RankTopoAttr

Source code in xdsl/dialects/experimental/dmp.py
488
489
def comm_layout(self) -> RankTopoAttr:
    """Return a new RankTopoAttr carrying this strategy's topology tuple."""
    dims = self.topology.as_tuple()
    return RankTopoAttr(dims)

GridSlice3dAttr

Bases: DomainDecompositionStrategy

Takes a grid with three or more dimensions, slices it along the first three.

Source code in xdsl/dialects/experimental/dmp.py
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
@irdl_attr_definition
class GridSlice3dAttr(DomainDecompositionStrategy):
    """
    Takes a grid with three or more dimensions, slices it along the first three into
    equally sized segments.

    Parameters:
      - topology: number of ranks along each sliced dimension. Only the first three
        entries are used to divide the grid in `calc_resize`; `comm_layout`
        returns the full topology.
      - diagonals: whether diagonal halo exchanges are requested. Not implemented
        yet; `halo_exchange_defs` raises NotImplementedError when it is set.
    """

    name = "dmp.grid_slice_3d"

    topology: RankTopoAttr

    diagonals: builtin.BoolAttr

    def __init__(self, topo: tuple[int, ...]):
        # The diagonals flag is always initialised to a 1-bit zero (disabled).
        super().__init__(RankTopoAttr(topo), builtin.BoolAttr.from_int_and_width(0, 1))

    def _verify(self):
        assert len(self.topology.as_tuple()) >= 3, (
            "GridSlice3d requires at least three dimensions"
        )

    def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        """
        Return the per-rank shape for a global `shape`.

        The first three dimensions are divided by the first three topology
        entries; all remaining dimensions are kept as-is, so the result always
        has the same length as `shape`.

        Raises AssertionError (when assertions are enabled) if `shape` has fewer
        than three dimensions or a sliced dimension is not evenly divisible.
        """
        assert len(shape) >= 3, "GridSlice3d requires at least three dimensions"
        # Only pair up the three sliced dimensions. Zipping against the full
        # topology would also divide dimension 3 when the topology has more
        # than three entries, and shape[3:] would then be appended a second
        # time, yielding a result longer than `shape`.
        sliced = tuple(zip(shape[:3], self.topology.as_tuple()[:3]))
        for size, node_count in sliced:
            assert size % node_count == 0, (
                "GridSlice3d requires domain be neatly divisible by shape"
            )
        return (
            *(size // node_count for size, node_count in sliced),
            *shape[3:],
        )

    def halo_exchange_defs(self, shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]:
        """
        Yield the flat-face exchange declarations for dimensions 0, 1 and 2.

        If the diagonals flag is set, NotImplementedError is raised after the
        flat-face declarations have been yielded.
        """
        for dim in (DIM_X, DIM_Y, DIM_Z):
            yield from _flat_face_exchanges_for_dim(shape, dim)

        if self.diagonals.value.data:
            raise NotImplementedError("Diagonals support not implemented yet")

    def comm_layout(self) -> RankTopoAttr:
        """Return a new RankTopoAttr with this strategy's full topology."""
        return RankTopoAttr(self.topology.as_tuple())

name = 'dmp.grid_slice_3d' class-attribute instance-attribute

topology: RankTopoAttr instance-attribute

diagonals: builtin.BoolAttr instance-attribute

__init__(topo: tuple[int, ...])

Source code in xdsl/dialects/experimental/dmp.py
504
505
def __init__(self, topo: tuple[int, ...]):
    """Build the strategy from a rank topology tuple, with diagonals disabled."""
    diagonals_off = builtin.BoolAttr.from_int_and_width(0, 1)
    super().__init__(RankTopoAttr(topo), diagonals_off)

calc_resize(shape: tuple[int, ...]) -> tuple[int, ...]

Source code in xdsl/dialects/experimental/dmp.py
512
513
514
515
516
517
518
519
520
521
522
523
524
def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]:
    """
    Return the per-rank shape for a global `shape`.

    The first three dimensions are divided by the first three topology
    entries; all remaining dimensions are kept as-is, so the result always
    has the same length as `shape`.

    Raises AssertionError (when assertions are enabled) if `shape` has fewer
    than three dimensions or a sliced dimension is not evenly divisible.
    """
    assert len(shape) >= 3, "GridSlice3d requires at least three dimensions"
    # Only pair up the three sliced dimensions. Zipping against the full
    # topology would also divide dimension 3 when the topology has more than
    # three entries, and shape[3:] would then be appended a second time.
    sliced = tuple(zip(shape[:3], self.topology.as_tuple()[:3]))
    for size, node_count in sliced:
        assert size % node_count == 0, (
            "GridSlice3d requires domain be neatly divisible by shape"
        )
    return (
        *(size // node_count for size, node_count in sliced),
        *shape[3:],
    )

halo_exchange_defs(shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]

Source code in xdsl/dialects/experimental/dmp.py
526
527
528
529
530
531
532
533
534
def halo_exchange_defs(self, shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]:
    """
    Yield the flat-face exchange declarations for dimensions 0, 1 and 2.

    If the diagonals flag is set, NotImplementedError is raised once the
    flat-face declarations have been yielded.
    """
    for dim in (DIM_X, DIM_Y, DIM_Z):
        yield from _flat_face_exchanges_for_dim(shape, dim)

    wants_diagonals = self.diagonals.value.data
    if wants_diagonals:
        raise NotImplementedError("Diagonals support not implemented yet")

comm_layout() -> RankTopoAttr

Source code in xdsl/dialects/experimental/dmp.py
536
537
def comm_layout(self) -> RankTopoAttr:
    """Return a new RankTopoAttr carrying this strategy's topology tuple."""
    dims = self.topology.as_tuple()
    return RankTopoAttr(dims)

SwapOpHasShapeInferencePatterns dataclass

Bases: HasShapeInferencePatternsTrait

Source code in xdsl/dialects/experimental/dmp.py
642
643
644
645
646
647
648
649
650
class SwapOpHasShapeInferencePatterns(HasShapeInferencePatternsTrait):
    """
    Shape-inference trait for dmp.swap, supplying its dmp-specific patterns.
    """

    @classmethod
    def get_shape_inference_patterns(cls):
        # Imported at call time instead of module level; presumably to avoid an
        # import cycle with the transforms package (TODO confirm).
        from xdsl.transforms.shape_inference_patterns.dmp import (
            DmpSwapShapeInference,
            DmpSwapSwapsInference,
        )

        shape_pattern = DmpSwapShapeInference()
        swaps_pattern = DmpSwapSwapsInference()
        return (shape_pattern, swaps_pattern)

get_shape_inference_patterns() classmethod

Source code in xdsl/dialects/experimental/dmp.py
643
644
645
646
647
648
649
650
@classmethod
def get_shape_inference_patterns(cls):
    """Return the (shape, swaps) inference pattern instances for dmp.swap."""
    # Imported at call time instead of module level; presumably to avoid an
    # import cycle with the transforms package (TODO confirm).
    from xdsl.transforms.shape_inference_patterns.dmp import (
        DmpSwapShapeInference,
        DmpSwapSwapsInference,
    )

    shape_pattern = DmpSwapShapeInference()
    swaps_pattern = DmpSwapSwapsInference()
    return (shape_pattern, swaps_pattern)

SwapOpMemoryEffect dataclass

Bases: MemoryEffect

Side effect implementation of dmp.swap.

Source code in xdsl/dialects/experimental/dmp.py
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
class SwapOpMemoryEffect(MemoryEffect):
    """
    Side effect implementation of dmp.swap.

    A swap that produces a result (value semantics) reports no memory effects.
    A swap without a result (reference semantics) reports both a read and a
    write of its input stencil.
    """

    @classmethod
    def get_effects(cls, op: Operation) -> set[EffectInstance]:
        op = cast(SwapOp, op)
        # Value-semantic mode: the optional result is present, so the swapped
        # data is returned as a new SSA value and nothing is modified in place.
        if op.swapped_values is not None:
            return set()
        # Reference-semantic mode: the swap reads from and writes to its field.
        # TODO: consider the empty-`swaps` case (no exchanges declared), which
        # could generically be treated as effect-free. This is not done yet
        # because swaps are relied on here before they are inferred, so an
        # empty list does not yet reliably mean "nothing to exchange".
        return {
            EffectInstance(MemoryEffectKind.WRITE, op.input_stencil),
            EffectInstance(MemoryEffectKind.READ, op.input_stencil),
        }

get_effects(op: Operation) -> set[EffectInstance] classmethod

Source code in xdsl/dialects/experimental/dmp.py
658
659
660
661
662
663
664
665
666
667
668
669
670
671
@classmethod
def get_effects(cls, op: Operation) -> set[EffectInstance]:
    """
    Return the memory effects of a dmp.swap `op`.

    The result is empty when the swap has a result (value semantics).
    Otherwise it is a read and a write of the input stencil.
    """
    op = cast(SwapOp, op)
    # Value-semantic mode: the optional result is present, so the swapped data
    # is returned as a new SSA value and nothing is modified in place.
    if op.swapped_values is not None:
        return set()
    # Reference-semantic mode: the swap reads from and writes to its field.
    # TODO: consider the empty-`swaps` case (no exchanges declared), which could
    # generically be treated as effect-free. This is not done yet because swaps
    # are relied on here before they are inferred, so an empty list does not
    # yet reliably mean "nothing to exchange".
    return {
        EffectInstance(MemoryEffectKind.WRITE, op.input_stencil),
        EffectInstance(MemoryEffectKind.READ, op.input_stencil),
    }

SwapOp dataclass

Bases: IRDLOperation

Declarative swap of memref regions.

Source code in xdsl/dialects/experimental/dmp.py
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
@irdl_op_definition
class SwapOp(IRDLOperation):
    """
    Declarative swap of memref regions.

    The input may be a stencil field or a stencil temporary:
      - field: reference semantics; the op must have no result.
      - temporary: value semantics; the op must produce `swapped_values`.
    `swaps` holds the exchange declarations and `strategy` the domain
    decomposition strategy in use.
    """

    name = "dmp.swap"

    input_stencil = operand_def(stencil.StencilTypeConstr)
    swapped_values = opt_result_def(stencil.TempType[Attribute])

    swaps = attr_def(builtin.ArrayAttr[ExchangeDeclarationAttr])

    strategy = attr_def(DomainDecompositionStrategy)

    traits = traits_def(SwapOpHasShapeInferencePatterns(), SwapOpMemoryEffect())

    def verify_(self) -> None:
        """
        Check that a result is present exactly when the input is a temporary.

        Raises VerifyException on a mismatch.
        """
        if self.swapped_values is not None:
            if isinstance(self.input_stencil.type, stencil.FieldType):
                raise VerifyException(
                    "dmp.swap_op cannot have a result if input is a field"
                )
        elif isinstance(self.input_stencil.type, stencil.TempType):
            raise VerifyException(
                "dmp.swap_op must have a result if input is a temporary"
            )

    @staticmethod
    def get(
        input_stencil: SSAValue | Operation,
        strategy: DomainDecompositionStrategy,
        swaps: builtin.ArrayAttr[ExchangeDeclarationAttr] | None = None,
    ):
        """
        Build a dmp.swap on `input_stencil` using `strategy`.

        A result of the input's type is created only if the input is a stencil
        temporary; otherwise the op has no result. `swaps` defaults to an
        empty array.
        """
        input_type = SSAValue.get(input_stencil).type

        # A single optional result type: None means "no result".
        result_type = (
            input_type if isa(input_type, stencil.TempType[Attribute]) else None
        )

        if swaps is None:
            swaps = builtin.ArrayAttr[ExchangeDeclarationAttr](())

        return SwapOp.build(
            operands=[input_stencil],
            result_types=[result_type],
            attributes={
                "strategy": strategy,
                "swaps": swaps,
            },
        )

name = 'dmp.swap' class-attribute instance-attribute

input_stencil = operand_def(stencil.StencilTypeConstr) class-attribute instance-attribute

swapped_values = opt_result_def(stencil.TempType[Attribute]) class-attribute instance-attribute

swaps = attr_def(builtin.ArrayAttr[ExchangeDeclarationAttr]) class-attribute instance-attribute

strategy = attr_def(DomainDecompositionStrategy) class-attribute instance-attribute

traits = traits_def(SwapOpHasShapeInferencePatterns(), SwapOpMemoryEffect()) class-attribute instance-attribute

verify_() -> None

Source code in xdsl/dialects/experimental/dmp.py
691
692
693
694
695
696
697
698
699
700
701
def verify_(self) -> None:
    """
    Check that a result is present exactly when the input is a temporary.

    Raises VerifyException on a mismatch.
    """
    if self.swapped_values is not None:
        if isinstance(self.input_stencil.type, stencil.FieldType):
            raise VerifyException(
                "dmp.swap_op cannot have a result if input is a field"
            )
    elif isinstance(self.input_stencil.type, stencil.TempType):
        raise VerifyException(
            "dmp.swap_op must have a result if input is a temporary"
        )

get(input_stencil: SSAValue | Operation, strategy: DomainDecompositionStrategy, swaps: builtin.ArrayAttr[ExchangeDeclarationAttr] | None = None) staticmethod

Source code in xdsl/dialects/experimental/dmp.py
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
@staticmethod
def get(
    input_stencil: SSAValue | Operation,
    strategy: DomainDecompositionStrategy,
    swaps: builtin.ArrayAttr[ExchangeDeclarationAttr] | None = None,
):
    """
    Build a dmp.swap on `input_stencil` using `strategy`.

    A result of the input's type is created only for stencil temporaries;
    fields get no result. `swaps` defaults to an empty array.
    """
    stencil_type = SSAValue.get(input_stencil).type

    # Only temporaries (value semantics) get a result; None means "no result".
    result_type = None
    if isa(stencil_type, stencil.TempType[Attribute]):
        result_type = stencil_type

    exchange_list = (
        builtin.ArrayAttr[ExchangeDeclarationAttr](()) if swaps is None else swaps
    )

    return SwapOp.build(
        operands=[input_stencil],
        result_types=[result_type],
        attributes={"strategy": strategy, "swaps": exchange_list},
    )