Skip to content

Mesh

mesh

MeshAxis: TypeAlias = I16 module-attribute

The type used to represent numbers on a mesh axis.

See the MLIR definition.

MeshAxesAttr: TypeAlias = DenseArrayBase[MeshAxis] module-attribute

The type used to represent a list of mesh axes.

See the MLIR definition.

Mesh = Dialect('mesh', [BroadcastOp, GatherOp, RecvOp, SendOp, ScatterOp, ShiftOp, MeshOp, ShardingOp, ShardOp], [ReductionKindAttr, ShardingType, MeshAxesArrayAttr]) module-attribute

MeshAxesArrayAttr dataclass

Bases: ParametrizedAttribute, OpaqueSyntaxAttribute

MeshAxesArrayAttr attribute for representing multiple mesh axes.

Reflects the MLIR attribute.

Source code in xdsl/dialects/mesh.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
@irdl_attr_definition
class MeshAxesArrayAttr(ParametrizedAttribute, OpaqueSyntaxAttribute):
    """
    MeshAxesArrayAttr attribute for representing multiple mesh axes.

    The single parameter `axes` is an `ArrayAttr` of `MeshAxesAttr` lists,
    written in opaque syntax as a nested list such as `[[1, 2, 3], [], [4, 5]]`.

    Reflects [the MLIR attribute](https://github.com/llvm/llvm-project/blob/6146a88f60492b520a36f8f8f3231e15f3cc6082/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td#L83).
    """

    name = "mesh.axisarray"

    axes: ArrayAttr[MeshAxesAttr]

    @classmethod
    def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
        """
        Parse a square-bracketed, comma-separated list of axis lists, e.g.:

        [[1, 2, 3], [], [4, 5]]
        """

        def parse_one_sublist():
            # Each element is parsed by the module-level helper.
            return _parse_mesh_axes_attr(parser)

        parsed_sublists = parser.parse_comma_separated_list(
            parser.Delimiter.SQUARE, parse_one_sublist
        )
        return (ArrayAttr(parsed_sublists),)

    def print_parameters(self, printer: Printer) -> None:
        """
        Print the axis lists as a square-bracketed, comma-separated list, e.g.:

        [[1, 2, 3], [], [4, 5]]
        """

        def print_one_sublist(sublist):
            return _print_sublist(printer, sublist)

        with printer.in_square_brackets():
            printer.print_list(self.axes.data, print_one_sublist)

name = 'mesh.axisarray' class-attribute instance-attribute

axes: ArrayAttr[MeshAxesAttr] instance-attribute

parse_parameters(parser: AttrParser) -> Sequence[Attribute] classmethod

Parses a MeshAxesArrayAttr, which has the syntax of a list of lists, e.g.:

[[1, 2, 3], [], [4, 5]]

Source code in xdsl/dialects/mesh.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
@classmethod
def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
    """
    Parse a square-bracketed, comma-separated list of axis lists, e.g.:

    [[1, 2, 3], [], [4, 5]]
    """

    def parse_one_sublist():
        # Each element is parsed by the module-level helper.
        return _parse_mesh_axes_attr(parser)

    parsed_sublists = parser.parse_comma_separated_list(
        parser.Delimiter.SQUARE, parse_one_sublist
    )
    return (ArrayAttr(parsed_sublists),)

print_parameters(printer: Printer) -> None

Prints a MeshAxesArrayAttr, which has the syntax of a list of lists, e.g.:

[[1, 2, 3], [], [4, 5]]

Source code in xdsl/dialects/mesh.py
119
120
121
122
123
124
125
126
127
128
129
130
def print_parameters(self, printer: Printer) -> None:
    """
    Print the axis lists as a square-bracketed, comma-separated list, e.g.:

    [[1, 2, 3], [], [4, 5]]
    """

    def print_one_sublist(sublist):
        return _print_sublist(printer, sublist)

    with printer.in_square_brackets():
        printer.print_list(self.axes.data, print_one_sublist)

ReductionKind

Bases: StrEnum

Reduction kind for mesh dialect

Source code in xdsl/dialects/mesh.py
133
134
135
136
137
138
139
140
141
142
143
144
class ReductionKind(StrEnum):
    "Reduction kind for mesh dialect"

    SUM = auto()
    MAX = auto()
    MIN = auto()
    PRODUCT = auto()
    AVERAGE = auto()
    BITWISE_AND = auto()
    BITWISE_OR = auto()
    BITWISE_XOR = auto()
    GENERIC = auto()

SUM = auto() class-attribute instance-attribute

MAX = auto() class-attribute instance-attribute

MIN = auto() class-attribute instance-attribute

PRODUCT = auto() class-attribute instance-attribute

AVERAGE = auto() class-attribute instance-attribute

BITWISE_AND = auto() class-attribute instance-attribute

BITWISE_OR = auto() class-attribute instance-attribute

BITWISE_XOR = auto() class-attribute instance-attribute

GENERIC = auto() class-attribute instance-attribute

ReductionKindAttr dataclass

Bases: EnumAttribute[ReductionKind], SpacedOpaqueSyntaxAttribute

Source code in xdsl/dialects/mesh.py
147
148
149
150
151
@irdl_attr_definition
class ReductionKindAttr(EnumAttribute[ReductionKind], SpacedOpaqueSyntaxAttribute):
    """
    Enum attribute `mesh.partial` carrying a `ReductionKind` value.

    The assembly format consists of the enum value alone.
    """

    name = "mesh.partial"
    assembly_format = "$value"

name = 'mesh.partial' class-attribute instance-attribute

assembly_format = '$value' class-attribute instance-attribute

ShardingType dataclass

Bases: ParametrizedAttribute, TypeAttribute

Source code in xdsl/dialects/mesh.py
154
155
156
@irdl_attr_definition
class ShardingType(ParametrizedAttribute, TypeAttribute):
    """
    Parameterless `mesh.sharding` type.

    In this dialect it is the result type of `ShardingOp` and the type of
    `ShardOp`'s `sharding` operand.
    """

    name = "mesh.sharding"

name = 'mesh.sharding' class-attribute instance-attribute

CollectiveCommunicationOp dataclass

Bases: IRDLOperation, ABC

Base class for collective communication ops.

Source code in xdsl/dialects/mesh.py
164
165
166
167
168
169
170
class CollectiveCommunicationOp(IRDLOperation, ABC):
    """
    Base class for collective communication ops.

    Subclasses inherit two properties:
      * `mesh` - a flat symbol reference to the mesh.
      * `mesh_axes` - the mesh axes involved, defaulting to an empty i16 array.
    """

    mesh = prop_def(FlatSymbolRefAttr)
    # The default is an empty MeshAxesAttr (no bytes of i16 data).
    mesh_axes = prop_def(
        MeshAxesAttr,
        default_value=MeshAxesAttr(i16, BytesAttr(b"")),
    )

mesh = prop_def(FlatSymbolRefAttr) class-attribute instance-attribute

mesh_axes = prop_def(MeshAxesAttr, default_value=(MeshAxesAttr(i16, BytesAttr(b'')))) class-attribute instance-attribute

BroadcastOp dataclass

Bases: CollectiveCommunicationOp

Broadcast tensor from one device to many devices.

See external documentation.

Source code in xdsl/dialects/mesh.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
@irdl_op_definition
class BroadcastOp(CollectiveCommunicationOp):
    """
    Broadcast tensor from one device to many devices.

    The `root` device index is written with `custom<DynamicIndexList>`: static
    entries are stored in the `root` property and dynamic ones are taken from
    the `root_dynamic` index operands.

    See [external documentation](https://mlir.llvm.org/docs/Dialects/Shard/#shardbroadcast-shardbroadcastop).
    """

    name = "mesh.broadcast"

    # NOTE: operand definition order determines operand order; keep it.
    input = operand_def(TensorType)
    root = prop_def(DenseArrayBase[I64])
    root_dynamic = var_operand_def(IndexType)

    result = result_def(TensorType)

    traits = traits_def(Pure())

    assembly_format = (
        "$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? "
        "`root` `=` custom<DynamicIndexList>($root_dynamic, $root) "
        "attr-dict `:` functional-type(operands, results)"
    )

    custom_directives = (DynamicIndexList,)

name = 'mesh.broadcast' class-attribute instance-attribute

input = operand_def(TensorType) class-attribute instance-attribute

root = prop_def(DenseArrayBase[I64]) class-attribute instance-attribute

root_dynamic = var_operand_def(IndexType) class-attribute instance-attribute

result = result_def(TensorType) class-attribute instance-attribute

traits = traits_def(Pure()) class-attribute instance-attribute

assembly_format = '$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? ' + '`root` `=` custom<DynamicIndexList>($root_dynamic, $root) ' + 'attr-dict `:` functional-type(operands, results)' class-attribute instance-attribute

custom_directives = (DynamicIndexList,) class-attribute instance-attribute

GatherOp dataclass

Bases: CollectiveCommunicationOp

Gather tensor shards from many devices to a single device.

See external documentation.

Source code in xdsl/dialects/mesh.py
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
@irdl_op_definition
class GatherOp(CollectiveCommunicationOp):
    """
    Gather tensor shards from many devices to a single device.

    `gather_axis` is an index-typed integer property. The `root` device index
    mixes static entries (the `root` property) and dynamic ones (the
    `root_dynamic` index operands) via `custom<DynamicIndexList>`.

    See [external documentation](https://mlir.llvm.org/docs/Dialects/Shard/#shardgather-shardgatherop).
    """

    name = "mesh.gather"

    # NOTE: operand definition order determines operand order; keep it.
    input = operand_def(TensorType)
    gather_axis = prop_def(IntegerAttr.constr(IndexTypeConstr))
    root = prop_def(DenseArrayBase[I64])
    root_dynamic = var_operand_def(IndexType)

    result = result_def(TensorType)

    traits = traits_def(Pure())

    assembly_format = (
        "$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? "
        "`gather_axis` `=` $gather_axis "
        "`root` `=` custom<DynamicIndexList>($root_dynamic, $root) "
        "attr-dict `:` functional-type(operands, results)"
    )

    custom_directives = (DynamicIndexList,)

name = 'mesh.gather' class-attribute instance-attribute

input = operand_def(TensorType) class-attribute instance-attribute

gather_axis = prop_def(IntegerAttr.constr(IndexTypeConstr)) class-attribute instance-attribute

root = prop_def(DenseArrayBase[I64]) class-attribute instance-attribute

root_dynamic = var_operand_def(IndexType) class-attribute instance-attribute

result = result_def(TensorType) class-attribute instance-attribute

traits = traits_def(Pure()) class-attribute instance-attribute

assembly_format = '$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? ' + '`gather_axis` `=` $gather_axis ' + '`root` `=` custom<DynamicIndexList>($root_dynamic, $root) ' + 'attr-dict `:` functional-type(operands, results)' class-attribute instance-attribute

custom_directives = (DynamicIndexList,) class-attribute instance-attribute

ScatterOp dataclass

Bases: CollectiveCommunicationOp

Scatter tensor over a device mesh.

For each device group, split the input tensor on the root device along axis scatter_axis and scatter the parts across the group's devices.

See external documentation.

Source code in xdsl/dialects/mesh.py
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
@irdl_op_definition
class ScatterOp(CollectiveCommunicationOp):
    """
    Scatter tensor over a device mesh.

    For each device group, split the input tensor on the `root` device along
    axis `scatter_axis` and scatter the parts across the group's devices.

    See [external documentation](https://mlir.llvm.org/docs/Dialects/Shard/#shardscatter-shardscatterop).
    """

    name = "mesh.scatter"

    # NOTE: operand definition order determines operand order; keep it.
    input = operand_def(TensorType)
    scatter_axis = prop_def(IntegerAttr.constr(IndexTypeConstr))

    # Static root coordinates plus dynamic index operands (DynamicIndexList).
    root = prop_def(DenseArrayBase[I64])
    root_dynamic = var_operand_def(IndexType)

    result = result_def(TensorType)

    traits = traits_def(Pure())

    assembly_format = (
        "$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? "
        "`scatter_axis` `=` $scatter_axis "
        "`root` `=` custom<DynamicIndexList>($root_dynamic, $root) "
        "attr-dict `:` functional-type(operands, results)"
    )

    custom_directives = (DynamicIndexList,)

name = 'mesh.scatter' class-attribute instance-attribute

input = operand_def(TensorType) class-attribute instance-attribute

scatter_axis = prop_def(IntegerAttr.constr(IndexTypeConstr)) class-attribute instance-attribute

root = prop_def(DenseArrayBase[I64]) class-attribute instance-attribute

root_dynamic = var_operand_def(IndexType) class-attribute instance-attribute

result = result_def(TensorType) class-attribute instance-attribute

traits = traits_def(Pure()) class-attribute instance-attribute

assembly_format = '$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? ' + '`scatter_axis` `=` $scatter_axis ' + '`root` `=` custom<DynamicIndexList>($root_dynamic, $root) ' + 'attr-dict `:` functional-type(operands, results)' class-attribute instance-attribute

custom_directives = (DynamicIndexList,) class-attribute instance-attribute

RecvOp dataclass

Bases: CollectiveCommunicationOp

Receive from a device within a device group.

Source code in xdsl/dialects/mesh.py
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
@irdl_op_definition
class RecvOp(CollectiveCommunicationOp):
    """
    Receive from a device within a device group.

    `source` is optional. When it is given, it is written with
    `custom<DynamicIndexList>`: static entries go into the `source` property
    and dynamic ones come from the `source_dynamic` index operands.
    """

    name = "mesh.recv"

    # NOTE: operand definition order determines operand order; keep it.
    input = operand_def(TensorType)
    source = opt_prop_def(DenseArrayBase[I64])
    source_dynamic = var_operand_def(IndexType)

    result = result_def(TensorType)

    assembly_format = (
        "$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? "
        "(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)? "
        "attr-dict `:` functional-type(operands, results)"
    )

    custom_directives = (DynamicIndexList,)

name = 'mesh.recv' class-attribute instance-attribute

input = operand_def(TensorType) class-attribute instance-attribute

source = opt_prop_def(DenseArrayBase[I64]) class-attribute instance-attribute

source_dynamic = var_operand_def(IndexType) class-attribute instance-attribute

result = result_def(TensorType) class-attribute instance-attribute

assembly_format = '$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? ' + '(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)? ' + 'attr-dict `:` functional-type(operands, results)' class-attribute instance-attribute

custom_directives = (DynamicIndexList,) class-attribute instance-attribute

SendOp dataclass

Bases: CollectiveCommunicationOp

Send from one device to another within a device group.

Source code in xdsl/dialects/mesh.py
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
@irdl_op_definition
class SendOp(CollectiveCommunicationOp):
    """
    Send from one device to another within a device group.

    The `destination` index is required. Static entries are stored in the
    `destination` property and dynamic ones come from the
    `destination_dynamic` index operands (`custom<DynamicIndexList>`).
    """

    name = "mesh.send"

    # NOTE: operand definition order determines operand order; keep it.
    input = operand_def(TensorType)

    destination = prop_def(DenseArrayBase[I64])
    destination_dynamic = var_operand_def(IndexType)

    result = result_def(TensorType)

    assembly_format = (
        "$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? "
        "`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination) "
        "attr-dict `:` functional-type(operands, results)"
    )

    custom_directives = (DynamicIndexList,)

name = 'mesh.send' class-attribute instance-attribute

input = operand_def(TensorType) class-attribute instance-attribute

destination = prop_def(DenseArrayBase[I64]) class-attribute instance-attribute

destination_dynamic = var_operand_def(IndexType) class-attribute instance-attribute

result = result_def(TensorType) class-attribute instance-attribute

assembly_format = '$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? ' + '`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination) ' + 'attr-dict `:` functional-type(operands, results)' class-attribute instance-attribute

custom_directives = (DynamicIndexList,) class-attribute instance-attribute

ShiftOp dataclass

Bases: CollectiveCommunicationOp

Shift over a device mesh.

Within each device group, shift along shift_axis by offset. If the rotate flag is present, a rotation is performed instead of a shift.

Source code in xdsl/dialects/mesh.py
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
@irdl_op_definition
class ShiftOp(CollectiveCommunicationOp):
    """
    Shift over a device mesh.

    Within each device group, shift along `shift_axis` by `offset`. If the
    `rotate` flag is present, a rotation is performed instead of a shift.

    `rotate` is an optional unit property. It is absent for a plain shift,
    which is why it anchors the optional `(`rotate` $rotate^)?` group in the
    assembly format.
    """

    name = "mesh.shift"

    input = operand_def(TensorType)

    shift_axis = prop_def(IntegerAttr.constr(IndexTypeConstr))
    offset = prop_def(IntegerAttr[I64])
    # Must be optional (like `ShardOp.annotate_for_users`): a required
    # UnitAttr would make every non-rotating shift fail verification.
    rotate = opt_prop_def(UnitAttr)

    result = result_def(TensorType)

    traits = traits_def(
        Pure(),
    )

    assembly_format = (
        "$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? "
        + "`shift_axis` `=` $shift_axis "
        + "`offset` `=` $offset "
        + "(`rotate` $rotate^)? "
        + "attr-dict `:` type($input) `->` type($result)"
    )

name = 'mesh.shift' class-attribute instance-attribute

input = operand_def(TensorType) class-attribute instance-attribute

shift_axis = prop_def(IntegerAttr.constr(IndexTypeConstr)) class-attribute instance-attribute

offset = prop_def(IntegerAttr[I64]) class-attribute instance-attribute

rotate = prop_def(UnitAttr) class-attribute instance-attribute

result = result_def(TensorType) class-attribute instance-attribute

traits = traits_def(Pure()) class-attribute instance-attribute

assembly_format = '$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? ' + '`shift_axis` `=` $shift_axis ' + '`offset` `=` $offset ' + '(`rotate` $rotate^)? ' + 'attr-dict `:` type($input) `->` type($result)' class-attribute instance-attribute

MeshOp dataclass

Bases: IRDLOperation

Source code in xdsl/dialects/mesh.py
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
@irdl_op_definition
class MeshOp(IRDLOperation):
    """
    Symbol op declaring a device mesh named `sym_name` with a static `shape`.

    `verify_` rejects an empty shape, so a mesh must have at least one
    dimension.
    """

    name = "mesh.mesh"

    sym_name = prop_def(SymbolNameConstraint())
    shape = prop_def(DenseArrayBase[I64])

    traits = traits_def(SymbolOpInterface())

    assembly_format = (
        "$sym_name `(` `shape` `=` custom<DimensionList>($shape) `)` attr-dict"
    )

    custom_directives = (DimensionList,)

    def verify_(self):
        """Raise VerifyException when `shape` holds no dimensions."""
        mesh_dims = self.shape.get_values()
        if mesh_dims:
            return
        raise VerifyException(
            "'mesh.mesh' op rank of mesh is expected to be a positive integer"
        )

name = 'mesh.mesh' class-attribute instance-attribute

sym_name = prop_def(SymbolNameConstraint()) class-attribute instance-attribute

shape = prop_def(DenseArrayBase[I64]) class-attribute instance-attribute

traits = traits_def(SymbolOpInterface()) class-attribute instance-attribute

assembly_format = '$sym_name `(` `shape` `=` custom<DimensionList>($shape) `)` attr-dict' class-attribute instance-attribute

custom_directives = (DimensionList,) class-attribute instance-attribute

verify_()

Source code in xdsl/dialects/mesh.py
363
364
365
366
367
def verify_(self):
    """Raise VerifyException when `shape` holds no dimensions."""
    mesh_dims = self.shape.get_values()
    if mesh_dims:
        return
    raise VerifyException(
        "'mesh.mesh' op rank of mesh is expected to be a positive integer"
    )

ShardingOp dataclass

Bases: IRDLOperation

Mesh dialect sharding operation.

Note: halo_sizes and sharded_dims_offsets are mutually exclusive.

See external documentation

Source code in xdsl/dialects/mesh.py
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
@irdl_op_definition
class ShardingOp(IRDLOperation):
    """
    Mesh dialect sharding operation.

    Note: `halo_sizes` and `sharded_dims_offsets` are mutually exclusive;
    `verify_` rejects ops that specify both.

    See [external documentation](https://mlir.llvm.org/docs/Dialects/Shard/#shardsharding-shardshardingop)
    """

    name = "mesh.sharding"

    mesh = prop_def(FlatSymbolRefAttr)
    split_axes = prop_def(MeshAxesArrayAttr)
    partial_axes = opt_prop_def(MeshAxesAttr)
    partial_type = opt_prop_def(ReductionKindAttr)
    # Each index list has a static part (property, empty i64 array by
    # default) and a dynamic part (variadic i64 operands).
    # NOTE: operand definition order determines operand segments; keep it.
    static_sharded_dims_offsets = prop_def(
        DenseArrayBase[I64], default_value=DenseArrayBase[I64](i64, BytesAttr(b""))
    )
    dynamic_sharded_dims_offsets = var_operand_def(I64)
    static_halo_sizes = prop_def(
        DenseArrayBase[I64], default_value=DenseArrayBase[I64](i64, BytesAttr(b""))
    )
    dynamic_halo_sizes = var_operand_def(I64)

    result = result_def(ShardingType)

    # Two variadic operand groups, so segment sizes are stored as a property.
    irdl_options = (AttrSizedOperandSegments(as_property=True),)

    traits = traits_def(Pure())

    assembly_format = (
        "$mesh `split_axes` "
        "`=` $split_axes (`partial` `=` $partial_type $partial_axes^)? "
        "(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)? "
        "(`sharded_dims_offsets` `=` "
        "custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)? "
        "attr-dict `:` type($result)"
    )

    custom_directives = (DynamicIndexList,)

    def verify_(self) -> None:
        """Reject ops that specify both sharded-dims offsets and halo sizes."""
        # A list counts as "used" if its static or dynamic part is non-empty.
        has_offsets = bool(self.static_sharded_dims_offsets) or bool(
            self.dynamic_sharded_dims_offsets
        )
        has_halos = bool(self.static_halo_sizes) or bool(self.dynamic_halo_sizes)
        if not (has_offsets and has_halos):
            return
        raise VerifyException(
            "'mesh.sharding' cannot use both `halo_sizes` and `sharded_dims_offsets`"
        )

    def verify_(self) -> None:
        dims_offsets = (
            self.static_sharded_dims_offsets or self.dynamic_sharded_dims_offsets
        )
        halo_sizes = self.static_halo_sizes or self.dynamic_halo_sizes

        if dims_offsets and halo_sizes:
            raise VerifyException(
                "'mesh.sharding' cannot use both `halo_sizes` and `sharded_dims_offsets`"
            )

name = 'mesh.sharding' class-attribute instance-attribute

mesh = prop_def(FlatSymbolRefAttr) class-attribute instance-attribute

split_axes = prop_def(MeshAxesArrayAttr) class-attribute instance-attribute

partial_axes = opt_prop_def(MeshAxesAttr) class-attribute instance-attribute

partial_type = opt_prop_def(ReductionKindAttr) class-attribute instance-attribute

static_sharded_dims_offsets = prop_def(DenseArrayBase[I64], default_value=(DenseArrayBase[I64](i64, BytesAttr(b'')))) class-attribute instance-attribute

dynamic_sharded_dims_offsets = var_operand_def(I64) class-attribute instance-attribute

static_halo_sizes = prop_def(DenseArrayBase[I64], default_value=(DenseArrayBase[I64](i64, BytesAttr(b'')))) class-attribute instance-attribute

dynamic_halo_sizes = var_operand_def(I64) class-attribute instance-attribute

result = result_def(ShardingType) class-attribute instance-attribute

irdl_options = (AttrSizedOperandSegments(as_property=True),) class-attribute instance-attribute

traits = traits_def(Pure()) class-attribute instance-attribute

assembly_format = '$mesh `split_axes` ' + '`=` $split_axes (`partial` `=` $partial_type $partial_axes^)? ' + '(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)? ' + '(`sharded_dims_offsets` `=` ' + 'custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)? ' + 'attr-dict `:` type($result)' class-attribute instance-attribute

custom_directives = (DynamicIndexList,) class-attribute instance-attribute

verify_() -> None

Source code in xdsl/dialects/mesh.py
419
420
421
422
423
424
425
426
427
428
def verify_(self) -> None:
    """Reject ops that specify both sharded-dims offsets and halo sizes."""
    # A list counts as "used" if its static or dynamic part is non-empty.
    has_offsets = bool(self.static_sharded_dims_offsets) or bool(
        self.dynamic_sharded_dims_offsets
    )
    has_halos = bool(self.static_halo_sizes) or bool(self.dynamic_halo_sizes)
    if not (has_offsets and has_halos):
        return
    raise VerifyException(
        "'mesh.sharding' cannot use both `halo_sizes` and `sharded_dims_offsets`"
    )

ShardOp

Bases: IRDLOperation

Annotate how a tensor is sharded across a mesh.

See external documentation.

Source code in xdsl/dialects/mesh.py
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
@irdl_op_definition
class ShardOp(IRDLOperation):
    """
    Annotate how a tensor is sharded across a mesh.

    `src` and `result` are both bound to the type variable `T`, so the result
    has exactly the tensor type of `src`.

    See [external documentation](https://mlir.llvm.org/docs/Dialects/Shard/#shardshard-shardshardop).
    """

    name = "mesh.shard"

    T: ClassVar = VarConstraint("T", TensorType.constr())

    src = operand_def(T)
    sharding = operand_def(ShardingType)
    annotate_for_users = opt_prop_def(UnitAttr)

    result = result_def(T)

    traits = traits_def(
        Pure(),
    )

    assembly_format = "$src `to` $sharding (`annotate_for_users` $annotate_for_users^)? attr-dict `:` type($result)"

    def __init__(
        self,
        src: SSAValue,
        sharding: SSAValue,
        annotate_for_users: UnitAttr | None = None,
    ):
        """
        Build a `mesh.shard` op.

        Args:
            src: tensor value being annotated; its type is reused as the
                result type.
            sharding: value of `ShardingType` describing the sharding.
            annotate_for_users: `UnitAttr()` to set the flag. Defaults to
                `None`, which is forwarded as-is for the optional property.
        """
        # `__init__` must not return a value; just delegate to the base class.
        super().__init__(
            operands=[src, sharding],
            result_types=[src.type],
            properties={
                "annotate_for_users": annotate_for_users,
            },
        )

name = 'mesh.shard' class-attribute instance-attribute

T: ClassVar = VarConstraint('T', TensorType.constr()) class-attribute instance-attribute

src = operand_def(T) class-attribute instance-attribute

sharding = operand_def(ShardingType) class-attribute instance-attribute

annotate_for_users = opt_prop_def(UnitAttr) class-attribute instance-attribute

result = result_def(T) class-attribute instance-attribute

traits = traits_def(Pure()) class-attribute instance-attribute

assembly_format = '$src `to` $sharding (`annotate_for_users` $annotate_for_users^)? attr-dict `:` type($result)' class-attribute instance-attribute

__init__(src: SSAValue, sharding: SSAValue, annotate_for_users: UnitAttr | None)

Source code in xdsl/dialects/mesh.py
455
456
457
458
459
460
461
462
463
464
465
466
467
def __init__(
    self,
    src: SSAValue,
    sharding: SSAValue,
    annotate_for_users: UnitAttr | None = None,
):
    """
    Build a `mesh.shard` op.

    Args:
        src: tensor value being annotated; its type is reused as the
            result type.
        sharding: value of `ShardingType` describing the sharding.
        annotate_for_users: `UnitAttr()` to set the flag. Defaults to
            `None`, which is forwarded as-is for the optional property.
    """
    # `__init__` must not return a value; just delegate to the base class.
    super().__init__(
        operands=[src, sharding],
        result_types=[src.type],
        properties={
            "annotate_for_users": annotate_for_users,
        },
    )