
Vector

vector

DYNAMIC_INDEX: int = -2 ** 63 module-attribute

Vector = Dialect('vector', [BitcastOp, BroadcastOp, CreateMaskOp, ExtractElementOp, ExtractOp, FMAOp, InsertElementOp, InsertOp, LoadOp, MaskedLoadOp, MaskedStoreOp, PrintOp, ReductionOp, ShuffleOp, StoreOp, TransferReadOp, TransferWriteOp], [CombiningKindAttr]) module-attribute
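
The `Vector` object above bundles the dialect's operations and attributes so they can be registered with an xDSL context before parsing or building IR. A minimal sketch, assuming a recent xDSL where the context class is `Context` in `xdsl.context` (older releases expose `MLContext` instead):

```python
from xdsl.context import Context  # assumption: older xDSL releases use MLContext
from xdsl.dialects.vector import Vector

ctx = Context()
ctx.load_dialect(Vector)  # makes vector.load, vector.store, ... known to the parser
```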

LoadOp

Bases: IRDLOperation

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class LoadOp(IRDLOperation):
    name = "vector.load"
    base = operand_def(MemRefType)
    indices = var_operand_def(IndexType)
    result = result_def(VectorType)
    nontemporal = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))

    irdl_options = (ParsePropInAttrDict(),)
    assembly_format = (
        "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)"
    )

    def __init__(
        self,
        ref: SSAValue | Operation,
        indices: Sequence[SSAValue | Operation],
        result_type: VectorType,
    ):
        super().__init__(
            operands=(ref, indices),
            result_types=(result_type,),
        )

    def verify_(self):
        assert isa(self.base.type, MemRefType)
        assert isa(self.result.type, VectorType[Attribute])

        if self.base.type.element_type != self.result.type.element_type:
            raise VerifyException(
                "MemRef element type should match the Vector element type."
            )

        if self.base.type.get_num_dims() != len(self.indices):
            raise VerifyException("Expected an index for each dimension.")

    @deprecated("Please use vector.LoadOp(ref, indices, result_type)")
    @staticmethod
    def get(
        ref: SSAValue | Operation, indices: Sequence[SSAValue | Operation]
    ) -> LoadOp:
        ref = SSAValue.get(ref, type=MemRefType)
        return LoadOp(ref, indices, VectorType(ref.type.element_type, [1]))

name = 'vector.load' class-attribute instance-attribute

base = operand_def(MemRefType) class-attribute instance-attribute

indices = var_operand_def(IndexType) class-attribute instance-attribute

result = result_def(VectorType) class-attribute instance-attribute

nontemporal = opt_prop_def(BoolAttr, default_value=(BoolAttr.from_bool(False))) class-attribute instance-attribute

irdl_options = (ParsePropInAttrDict(),) class-attribute instance-attribute

assembly_format = '$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)' class-attribute instance-attribute

__init__(ref: SSAValue | Operation, indices: Sequence[SSAValue | Operation], result_type: VectorType)

Source code in xdsl/dialects/vector.py
def __init__(
    self,
    ref: SSAValue | Operation,
    indices: Sequence[SSAValue | Operation],
    result_type: VectorType,
):
    super().__init__(
        operands=(ref, indices),
        result_types=(result_type,),
    )
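
As a usage sketch (not library-provided code), a `vector.load` can be built directly from the constructor above. Block arguments are used purely as stand-in SSA values; the `Block(arg_types=...)` and `MemRefType(element_type, shape)` helpers and import paths are assumptions here:

```python
from xdsl.dialects import vector
from xdsl.dialects.builtin import IndexType, MemRefType, VectorType, f32
from xdsl.ir import Block

# Stand-in SSA values: a memref<16xf32> and an index.
block = Block(arg_types=[MemRefType(f32, [16]), IndexType()])
ref, idx = block.args

# Load a vector<4xf32> from the memref starting at %idx.
load = vector.LoadOp(ref, [idx], VectorType(f32, [4]))
```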

verify_()

Source code in xdsl/dialects/vector.py
def verify_(self):
    assert isa(self.base.type, MemRefType)
    assert isa(self.result.type, VectorType[Attribute])

    if self.base.type.element_type != self.result.type.element_type:
        raise VerifyException(
            "MemRef element type should match the Vector element type."
        )

    if self.base.type.get_num_dims() != len(self.indices):
        raise VerifyException("Expected an index for each dimension.")

get(ref: SSAValue | Operation, indices: Sequence[SSAValue | Operation]) -> LoadOp staticmethod

Source code in xdsl/dialects/vector.py
@deprecated("Please use vector.LoadOp(ref, indices, result_type)")
@staticmethod
def get(
    ref: SSAValue | Operation, indices: Sequence[SSAValue | Operation]
) -> LoadOp:
    ref = SSAValue.get(ref, type=MemRefType)
    return LoadOp(ref, indices, VectorType(ref.type.element_type, [1]))

StoreOp

Bases: IRDLOperation

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class StoreOp(IRDLOperation):
    name = "vector.store"
    vector = operand_def(VectorType)
    base = operand_def(MemRefType)
    indices = var_operand_def(IndexType)
    nontemporal = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))

    irdl_options = (ParsePropInAttrDict(),)
    assembly_format = (
        "$vector `,` $base `[` $indices `]` attr-dict `:` type($base) `,` type($vector)"
    )

    def __init__(
        self,
        vector: SSAValue | Operation,
        base: SSAValue | Operation,
        indices: Sequence[SSAValue | Operation],
        nontemporal: BoolAttr | None = None,
    ):
        super().__init__(
            operands=[vector, base, indices],
            properties={"nontemporal": nontemporal},
        )

    def verify_(self):
        assert isa(self.base.type, MemRefType)
        assert isa(self.vector.type, VectorType[Attribute])

        if self.base.type.element_type != self.vector.type.element_type:
            raise VerifyException(
                "MemRef element type should match the Vector element type."
            )

        if self.base.type.get_num_dims() != len(self.indices):
            raise VerifyException("Expected an index for each dimension.")

    @deprecated("Please use vector.StoreOp(vector, ref, indices)")
    @staticmethod
    def get(
        vector: Operation | SSAValue,
        ref: Operation | SSAValue,
        indices: Sequence[Operation | SSAValue],
    ) -> StoreOp:
        return StoreOp(vector, ref, indices)

name = 'vector.store' class-attribute instance-attribute

vector = operand_def(VectorType) class-attribute instance-attribute

base = operand_def(MemRefType) class-attribute instance-attribute

indices = var_operand_def(IndexType) class-attribute instance-attribute

nontemporal = opt_prop_def(BoolAttr, default_value=(BoolAttr.from_bool(False))) class-attribute instance-attribute

irdl_options = (ParsePropInAttrDict(),) class-attribute instance-attribute

assembly_format = '$vector `,` $base `[` $indices `]` attr-dict `:` type($base) `,` type($vector)' class-attribute instance-attribute

__init__(vector: SSAValue | Operation, base: SSAValue | Operation, indices: Sequence[SSAValue | Operation], nontemporal: BoolAttr | None = None)

Source code in xdsl/dialects/vector.py
def __init__(
    self,
    vector: SSAValue | Operation,
    base: SSAValue | Operation,
    indices: Sequence[SSAValue | Operation],
    nontemporal: BoolAttr | None = None,
):
    super().__init__(
        operands=[vector, base, indices],
        properties={"nontemporal": nontemporal},
    )
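
A matching sketch for `vector.store`, again with hypothetical block arguments standing in for real SSA values (the `Block` and `MemRefType` constructors are assumptions):

```python
from xdsl.dialects import vector
from xdsl.dialects.builtin import IndexType, MemRefType, VectorType, f32
from xdsl.ir import Block

block = Block(arg_types=[VectorType(f32, [4]), MemRefType(f32, [16]), IndexType()])
vec, ref, idx = block.args

# Store the vector<4xf32> into the memref at %idx.
store = vector.StoreOp(vec, ref, [idx])
```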

verify_()

Source code in xdsl/dialects/vector.py
def verify_(self):
    assert isa(self.base.type, MemRefType)
    assert isa(self.vector.type, VectorType[Attribute])

    if self.base.type.element_type != self.vector.type.element_type:
        raise VerifyException(
            "MemRef element type should match the Vector element type."
        )

    if self.base.type.get_num_dims() != len(self.indices):
        raise VerifyException("Expected an index for each dimension.")

get(vector: Operation | SSAValue, ref: Operation | SSAValue, indices: Sequence[Operation | SSAValue]) -> StoreOp staticmethod

Source code in xdsl/dialects/vector.py
@deprecated("Please use vector.StoreOp(vector, ref, indices)")
@staticmethod
def get(
    vector: Operation | SSAValue,
    ref: Operation | SSAValue,
    indices: Sequence[Operation | SSAValue],
) -> StoreOp:
    return StoreOp(vector, ref, indices)

ShuffleResultConstraint dataclass

Bases: AttrConstraint[VectorType]

Source code in xdsl/dialects/vector.py
@dataclass(frozen=True)
class ShuffleResultConstraint(AttrConstraint[VectorType]):
    element_constr: AttrConstraint
    v1_shape_constr: VarConstraint
    v2_shape_constr: VarConstraint
    mask_constraint: VarConstraint

    def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
        # We can only verify the element type here, and not the relations to other shapes
        VectorType.constr(self.element_constr).verify(attr, constraint_context)
        attr = cast(VectorType, attr)
        if not attr.shape.data:
            raise VerifyException("Result vector type must not be 0-D.")

    def can_infer(self, var_constraint_names: AbstractSet[str]) -> bool:
        res = self.element_constr.can_infer(var_constraint_names) and (
            self.v1_shape_constr.name in var_constraint_names
            and self.v2_shape_constr.name in var_constraint_names
            and self.mask_constraint.name in var_constraint_names
        )
        assert res
        return res

    def infer(self, context: ConstraintContext) -> VectorType:
        v1_shape = context.get_variable(self.v1_shape_constr.name)
        v2_shape = context.get_variable(self.v2_shape_constr.name)
        mask = context.get_variable(self.mask_constraint.name)
        assert v1_shape is not None
        assert v2_shape is not None
        assert mask is not None
        assert _IntArrayConstr.verifies(v1_shape)
        assert _IntArrayConstr.verifies(v2_shape)
        assert _MaskConstr.verifies(mask)

        result_trailing: tuple[IntAttr, ...]
        if not v1_shape:
            assert not v2_shape
            result_trailing = ()
        else:
            result_trailing = v1_shape.data[1:]

        element_type = self.element_constr.infer(context)
        result_leading = len(mask)
        shape = (
            (IntAttr(result_leading), *result_trailing)
            if result_leading
            else result_trailing
        )
        return VectorType(element_type, ArrayAttr(shape))

    def mapping_type_vars(
        self, type_var_mapping: Mapping[TypeVar, AttrConstraint | IntConstraint]
    ) -> AttrConstraint[VectorType]:
        return ShuffleResultConstraint(
            self.element_constr.mapping_type_vars(type_var_mapping),
            self.v1_shape_constr.mapping_type_vars(type_var_mapping),
            self.v2_shape_constr.mapping_type_vars(type_var_mapping),
            self.mask_constraint.mapping_type_vars(type_var_mapping),
        )

element_constr: AttrConstraint instance-attribute

v1_shape_constr: VarConstraint instance-attribute

v2_shape_constr: VarConstraint instance-attribute

mask_constraint: VarConstraint instance-attribute

__init__(element_constr: AttrConstraint, v1_shape_constr: VarConstraint, v2_shape_constr: VarConstraint, mask_constraint: VarConstraint) -> None

verify(attr: Attribute, constraint_context: ConstraintContext) -> None

Source code in xdsl/dialects/vector.py
def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
    # We can only verify the element type here, and not the relations to other shapes
    VectorType.constr(self.element_constr).verify(attr, constraint_context)
    attr = cast(VectorType, attr)
    if not attr.shape.data:
        raise VerifyException("Result vector type must not be 0-D.")

can_infer(var_constraint_names: AbstractSet[str]) -> bool

Source code in xdsl/dialects/vector.py
def can_infer(self, var_constraint_names: AbstractSet[str]) -> bool:
    res = self.element_constr.can_infer(var_constraint_names) and (
        self.v1_shape_constr.name in var_constraint_names
        and self.v2_shape_constr.name in var_constraint_names
        and self.mask_constraint.name in var_constraint_names
    )
    assert res
    return res

infer(context: ConstraintContext) -> VectorType

Source code in xdsl/dialects/vector.py
def infer(self, context: ConstraintContext) -> VectorType:
    v1_shape = context.get_variable(self.v1_shape_constr.name)
    v2_shape = context.get_variable(self.v2_shape_constr.name)
    mask = context.get_variable(self.mask_constraint.name)
    assert v1_shape is not None
    assert v2_shape is not None
    assert mask is not None
    assert _IntArrayConstr.verifies(v1_shape)
    assert _IntArrayConstr.verifies(v2_shape)
    assert _MaskConstr.verifies(mask)

    result_trailing: tuple[IntAttr, ...]
    if not v1_shape:
        assert not v2_shape
        result_trailing = ()
    else:
        result_trailing = v1_shape.data[1:]

    element_type = self.element_constr.infer(context)
    result_leading = len(mask)
    shape = (
        (IntAttr(result_leading), *result_trailing)
        if result_leading
        else result_trailing
    )
    return VectorType(element_type, ArrayAttr(shape))

mapping_type_vars(type_var_mapping: Mapping[TypeVar, AttrConstraint | IntConstraint]) -> AttrConstraint[VectorType]

Source code in xdsl/dialects/vector.py
def mapping_type_vars(
    self, type_var_mapping: Mapping[TypeVar, AttrConstraint | IntConstraint]
) -> AttrConstraint[VectorType]:
    return ShuffleResultConstraint(
        self.element_constr.mapping_type_vars(type_var_mapping),
        self.v1_shape_constr.mapping_type_vars(type_var_mapping),
        self.v2_shape_constr.mapping_type_vars(type_var_mapping),
        self.mask_constraint.mapping_type_vars(type_var_mapping),
    )

ShuffleOp

Bases: IRDLOperation

The shuffle operation constructs a permutation (or duplication) of elements from two input vectors, returning a vector with the same element type as the input and a length that is the same as the shuffle mask. The two input vectors must have the same element type, same rank, and trailing dimension sizes, and shuffles their values in the leading dimension (which may differ in size) according to the given mask. The legality rules are:

* the two operands must have the same element type as the result
  - Either, the two operands and the result must have the same rank and trailing dimension sizes, viz. given two k-D operands v1 : <s_1 x s_2 x .. x s_k x type> and v2 : <t_1 x t_2 x .. x t_k x type> we have s_i = t_i for all 1 < i <= k
  - Or, the two operands must be 0-D vectors and the result is a 1-D vector.
* the mask length equals the leading dimension size of the result
* numbering the input vector indices left to right across the operands, all mask values must be within range, viz. given two k-D operands v1 and v2 above, all mask values are in the range [0, s_1+t_1)

Note, scalable vectors are not supported.

Example:

%0 = vector.shuffle %a, %a [0, 3]
            : vector<2xf32>, vector<2xf32>       ; yields vector<2xf32>
%1 = vector.shuffle %c, %b [0, 1, 2]
            : vector<2x16xf32>, vector<1x16xf32> ; yields vector<3x16xf32>
%2 = vector.shuffle %a, %a [3, 2, 1, 0]
             : vector<2xf32>, vector<2xf32>      ; yields vector<4xf32>
%3 = vector.shuffle %d, %d [0, 1]
            : vector<f32>, vector<f32>           ; yields vector<2xf32>

See external documentation.

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class ShuffleOp(IRDLOperation):
    """
    The shuffle operation constructs a permutation (or duplication) of elements
    from two input vectors, returning a vector with the same element type as
    the input and a length that is the same as the shuffle mask. The two input
    vectors must have the same element type, same rank , and trailing dimension
    sizes and shuffles their values in the
    leading dimension (which may differ in size) according to the given mask.
    The legality rules are:
    * the two operands must have the same element type as the result
      - Either, the two operands and the result must have the same
        rank and trailing dimension sizes, viz. given two k-D operands
                v1 : <s_1 x s_2 x .. x s_k x type> and
                v2 : <t_1 x t_2 x .. x t_k x type>
        we have s_i = t_i for all 1 < i <= k
      - Or, the two operands must be 0-D vectors and the result is a 1-D vector.
    * the mask length equals the leading dimension size of the result
    * numbering the input vector indices left to right across the operands, all
      mask values must be within range, viz. given two k-D operands v1 and v2
      above, all mask values are in the range [0,s_1+t_1)

    Note, scalable vectors are not supported.

    Example:

    ```mlir
    %0 = vector.shuffle %a, %a [0, 3]
                : vector<2xf32>, vector<2xf32>       ; yields vector<2xf32>
    %1 = vector.shuffle %c, %b [0, 1, 2]
                : vector<2x16xf32>, vector<1x16xf32> ; yields vector<3x16xf32>
    %2 = vector.shuffle %a, %a [3, 2, 1, 0]
                 : vector<2xf32>, vector<2xf32>      ; yields vector<4xf32>
    %3 = vector.shuffle %d, %d [0, 1]
                : vector<f32>, vector<f32>           ; yields vector<2xf32>
    ```

    See external [documentation](https://mlir.llvm.org/docs/Dialects/Vector/#vectorshuffle-vectorshuffleop).
    """

    name = "vector.shuffle"

    T: ClassVar = VarConstraint("T", AnyAttr())
    V1_SHAPE: ClassVar = VarConstraint("V1_SHAPE", _IntArrayConstr)
    V2_SHAPE: ClassVar = VarConstraint("V2_SHAPE", _IntArrayConstr)
    MASK: ClassVar = VarConstraint("MASK", _MaskConstr)
    RES: ClassVar = ShuffleResultConstraint(T, V1_SHAPE, V2_SHAPE, MASK)

    v1 = operand_def(VectorType.constr(T, shape=V1_SHAPE))
    v2 = operand_def(VectorType.constr(T, shape=V2_SHAPE))
    mask = prop_def(MASK)
    result = result_def(RES)

    irdl_options = (ParsePropInAttrDict(),)
    traits = traits_def(NoMemoryEffect())

    assembly_format = "operands $mask attr-dict `:` type(operands)"

    def __init__(
        self,
        v1: SSAValue,
        v2: SSAValue,
        mask: DenseArrayBase[I64],
        *,
        result_type: VectorType,
    ):
        super().__init__(
            operands=(v1, v2),
            result_types=(result_type,),
            properties={"mask": mask},
        )

    def verify_(self):
        assert isa(self.v1.type, VectorType)
        assert isa(self.v2.type, VectorType)
        assert isa(self.result.type, VectorType)

        v1_shape = self.v1.type.get_shape()
        v2_shape = self.v2.type.get_shape()
        result_shape = self.result.type.get_shape()
        mask = self.mask.get_values()

        result_leading_dim = result_shape[0]

        if len(mask) != result_leading_dim:
            # the mask length equals the leading dimension size of the result
            raise VerifyException(
                f"Length of mask {self.mask} must equal leading dim of result {self.result.type}."
            )

        if not v1_shape or not v2_shape:
            if v1_shape or v2_shape:
                raise VerifyException(
                    "Inputs must either both be non-0-D or both be 0-D"
                )

            if len(result_shape) != 1:
                raise VerifyException("If inputs are 0-D output must be 1-D")

            v1_leading_dim = 1
            v2_leading_dim = 1
        else:
            v1_leading_dim, *v1_trailing = v1_shape
            v2_leading_dim, *v2_trailing = v2_shape

            if v1_trailing != v2_trailing:
                raise VerifyException("Input trailing dimensions must match")

        dim_bound = v1_leading_dim + v2_leading_dim
        for dim in mask:
            if not (-1 <= dim < dim_bound):
                raise VerifyException(
                    f"Mask value {dim} out of range [-1, {dim_bound})"
                )

name = 'vector.shuffle' class-attribute instance-attribute

T: ClassVar = VarConstraint('T', AnyAttr()) class-attribute instance-attribute

V1_SHAPE: ClassVar = VarConstraint('V1_SHAPE', _IntArrayConstr) class-attribute instance-attribute

V2_SHAPE: ClassVar = VarConstraint('V2_SHAPE', _IntArrayConstr) class-attribute instance-attribute

MASK: ClassVar = VarConstraint('MASK', _MaskConstr) class-attribute instance-attribute

RES: ClassVar = ShuffleResultConstraint(T, V1_SHAPE, V2_SHAPE, MASK) class-attribute instance-attribute

v1 = operand_def(VectorType.constr(T, shape=V1_SHAPE)) class-attribute instance-attribute

v2 = operand_def(VectorType.constr(T, shape=V2_SHAPE)) class-attribute instance-attribute

mask = prop_def(MASK) class-attribute instance-attribute

result = result_def(RES) class-attribute instance-attribute

irdl_options = (ParsePropInAttrDict(),) class-attribute instance-attribute

traits = traits_def(NoMemoryEffect()) class-attribute instance-attribute

assembly_format = 'operands $mask attr-dict `:` type(operands)' class-attribute instance-attribute

__init__(v1: SSAValue, v2: SSAValue, mask: DenseArrayBase[I64], *, result_type: VectorType)

Source code in xdsl/dialects/vector.py
def __init__(
    self,
    v1: SSAValue,
    v2: SSAValue,
    mask: DenseArrayBase[I64],
    *,
    result_type: VectorType,
):
    super().__init__(
        operands=(v1, v2),
        result_types=(result_type,),
        properties={"mask": mask},
    )
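
For illustration, the first MLIR example from the docstring (`%0 = vector.shuffle %a, %a [0, 3]`) could be built roughly like this; a block argument stands in for `%a`, and `DenseArrayBase.from_list` is used for the mask as elsewhere in this dialect (the `Block(arg_types=...)` helper is an assumption):

```python
from xdsl.dialects import vector
from xdsl.dialects.builtin import DenseArrayBase, VectorType, f32, i64
from xdsl.ir import Block

block = Block(arg_types=[VectorType(f32, [2])])
(a,) = block.args

# %0 = vector.shuffle %a, %a [0, 3] : vector<2xf32>, vector<2xf32>
mask = DenseArrayBase.from_list(i64, [0, 3])
shuffle = vector.ShuffleOp(a, a, mask, result_type=VectorType(f32, [2]))
```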

verify_()

Source code in xdsl/dialects/vector.py
def verify_(self):
    assert isa(self.v1.type, VectorType)
    assert isa(self.v2.type, VectorType)
    assert isa(self.result.type, VectorType)

    v1_shape = self.v1.type.get_shape()
    v2_shape = self.v2.type.get_shape()
    result_shape = self.result.type.get_shape()
    mask = self.mask.get_values()

    result_leading_dim = result_shape[0]

    if len(mask) != result_leading_dim:
        # the mask length equals the leading dimension size of the result
        raise VerifyException(
            f"Length of mask {self.mask} must equal leading dim of result {self.result.type}."
        )

    if not v1_shape or not v2_shape:
        if v1_shape or v2_shape:
            raise VerifyException(
                "Inputs must either both be non-0-D or both be 0-D"
            )

        if len(result_shape) != 1:
            raise VerifyException("If inputs are 0-D output must be 1-D")

        v1_leading_dim = 1
        v2_leading_dim = 1
    else:
        v1_leading_dim, *v1_trailing = v1_shape
        v2_leading_dim, *v2_trailing = v2_shape

        if v1_trailing != v2_trailing:
            raise VerifyException("Input trailing dimensions must match")

    dim_bound = v1_leading_dim + v2_leading_dim
    for dim in mask:
        if not (-1 <= dim < dim_bound):
            raise VerifyException(
                f"Mask value {dim} out of range [-1, {dim_bound})"
            )

BroadcastOp

Bases: IRDLOperation

See external documentation.

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class BroadcastOp(IRDLOperation):
    """
    See external [documentation](https://mlir.llvm.org/docs/Dialects/Vector/#vectorbroadcast-vectorbroadcastop).
    """

    name = "vector.broadcast"
    source = operand_def()
    vector = result_def(VectorType)
    traits = traits_def(Pure())

    assembly_format = "$source attr-dict `:` type($source) `to` type($vector)"

    def __init__(self, source: Operation | SSAValue, result_type: VectorType):
        super().__init__(operands=(source,), result_types=(result_type,))

    def verify_(self):
        if isa(self.source.type, VectorType):
            element_type = self.source.type.element_type
        else:
            element_type = self.source.type

        if element_type != self.vector.type.element_type:
            raise VerifyException(
                "Source operand and result vector must have the same element type."
            )

    @deprecated("Please use vector.BroadcastOp(source, result_type)")
    @staticmethod
    def get(source: Operation | SSAValue) -> BroadcastOp:
        return BroadcastOp(source, VectorType(SSAValue.get(source).type, [1]))

name = 'vector.broadcast' class-attribute instance-attribute

source = operand_def() class-attribute instance-attribute

vector = result_def(VectorType) class-attribute instance-attribute

traits = traits_def(Pure()) class-attribute instance-attribute

assembly_format = '$source attr-dict `:` type($source) `to` type($vector)' class-attribute instance-attribute

__init__(source: Operation | SSAValue, result_type: VectorType)

Source code in xdsl/dialects/vector.py
def __init__(self, source: Operation | SSAValue, result_type: VectorType):
    super().__init__(operands=(source,), result_types=(result_type,))
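
A minimal construction sketch for `vector.broadcast`, broadcasting a scalar stand-in value to a vector (the block argument is a placeholder, not library-provided code):

```python
from xdsl.dialects import vector
from xdsl.dialects.builtin import VectorType, f32
from xdsl.ir import Block

block = Block(arg_types=[f32])
(scalar,) = block.args

# Broadcast the f32 scalar to vector<4xf32>.
bcast = vector.BroadcastOp(scalar, VectorType(f32, [4]))
```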

verify_()

Source code in xdsl/dialects/vector.py
def verify_(self):
    if isa(self.source.type, VectorType):
        element_type = self.source.type.element_type
    else:
        element_type = self.source.type

    if element_type != self.vector.type.element_type:
        raise VerifyException(
            "Source operand and result vector must have the same element type."
        )

get(source: Operation | SSAValue) -> BroadcastOp staticmethod

Source code in xdsl/dialects/vector.py
@deprecated("Please use vector.BroadcastOp(source, result_type)")
@staticmethod
def get(source: Operation | SSAValue) -> BroadcastOp:
    return BroadcastOp(source, VectorType(SSAValue.get(source).type, [1]))

FMAOp

Bases: IRDLOperation

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class FMAOp(IRDLOperation):
    name = "vector.fma"

    T: ClassVar = VarConstraint("T", VectorType.constr(AnyFloatConstr))

    lhs = operand_def(T)
    rhs = operand_def(T)
    acc = operand_def(T)
    res = result_def(T)
    traits = traits_def(Pure())

    assembly_format = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)"

    def __init__(
        self,
        lhs: Operation | SSAValue,
        rhs: Operation | SSAValue,
        acc: Operation | SSAValue,
    ):
        acc = SSAValue.get(acc)
        super().__init__(operands=(lhs, rhs, acc), result_types=(acc.type,))

    @deprecated("Please use vector.FMAOp(lhs, rhs, acc)")
    @staticmethod
    def get(
        lhs: Operation | SSAValue, rhs: Operation | SSAValue, acc: Operation | SSAValue
    ) -> FMAOp:
        return FMAOp(lhs, rhs, acc)

name = 'vector.fma' class-attribute instance-attribute

T: ClassVar = VarConstraint('T', VectorType.constr(AnyFloatConstr)) class-attribute instance-attribute

lhs = operand_def(T) class-attribute instance-attribute

rhs = operand_def(T) class-attribute instance-attribute

acc = operand_def(T) class-attribute instance-attribute

res = result_def(T) class-attribute instance-attribute

traits = traits_def(Pure()) class-attribute instance-attribute

assembly_format = '$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)' class-attribute instance-attribute

__init__(lhs: Operation | SSAValue, rhs: Operation | SSAValue, acc: Operation | SSAValue)

Source code in xdsl/dialects/vector.py
def __init__(
    self,
    lhs: Operation | SSAValue,
    rhs: Operation | SSAValue,
    acc: Operation | SSAValue,
):
    acc = SSAValue.get(acc)
    super().__init__(operands=(lhs, rhs, acc), result_types=(acc.type,))
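
Since the constructor takes its result type from the accumulator, a sketch of building `vector.fma` only needs the three operands (block arguments used as placeholder SSA values):

```python
from xdsl.dialects import vector
from xdsl.dialects.builtin import VectorType, f32
from xdsl.ir import Block

vec_t = VectorType(f32, [4])
block = Block(arg_types=[vec_t, vec_t, vec_t])
lhs, rhs, acc = block.args

# res = lhs * rhs + acc, elementwise on vector<4xf32>.
fma = vector.FMAOp(lhs, rhs, acc)
```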

get(lhs: Operation | SSAValue, rhs: Operation | SSAValue, acc: Operation | SSAValue) -> FMAOp staticmethod

Source code in xdsl/dialects/vector.py
@deprecated("Please use vector.FMAOp(lhs, rhs, acc)")
@staticmethod
def get(
    lhs: Operation | SSAValue, rhs: Operation | SSAValue, acc: Operation | SSAValue
) -> FMAOp:
    return FMAOp(lhs, rhs, acc)

MaskedLoadOp

Bases: IRDLOperation

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class MaskedLoadOp(IRDLOperation):
    name = "vector.maskedload"
    base = operand_def(MemRefType)
    indices = var_operand_def(IndexType)
    mask = operand_def(VectorBaseTypeAndRankConstraint(i1, 1))
    pass_thru = operand_def(VectorType)
    result = result_def(VectorRankConstraint(1))

    assembly_format = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)"  # noqa: E501

    def __init__(
        self,
        base: SSAValue | Operation,
        indices: Sequence[SSAValue | Operation],
        mask: SSAValue | Operation,
        pass_thru: SSAValue | Operation,
        result_type: VectorType | None = None,
    ):
        pass_thru = SSAValue.get(pass_thru, type=VectorType)
        if result_type is None:
            result_type = pass_thru.type
        super().__init__(
            operands=[base, indices, mask, pass_thru],
            result_types=[result_type],
        )

    def verify_(self):
        memref_type = self.base.type
        assert isa(memref_type, MemRefType)
        memref_element_type = memref_type.element_type

        res_type = self.result.type
        assert isa(res_type, VectorType[Attribute])
        res_element_type = res_type.element_type

        passthrough_type = self.pass_thru.type
        assert isa(passthrough_type, VectorType[Attribute])
        passthrough_element_type = passthrough_type.element_type

        if memref_element_type != res_element_type:
            raise VerifyException(
                "MemRef element type should match the result vector and passthrough vector "
                "element type. Found different element types for memref and result."
            )
        elif memref_element_type != passthrough_element_type:
            raise VerifyException(
                "MemRef element type should match the result vector and passthrough vector "
                "element type. Found different element types for memref and passthrough."
            )

        if memref_type.get_num_dims() != len(self.indices):
            raise VerifyException("Expected an index for each memref dimension.")

    @deprecated(
        "Please use vector.MaskedLoadOp(memref, indices, mask, passthrough, result_type)"
    )
    @staticmethod
    def get(
        memref: SSAValue | Operation,
        indices: Sequence[SSAValue | Operation],
        mask: SSAValue | Operation,
        passthrough: SSAValue | Operation,
    ) -> MaskedLoadOp:
        memref = SSAValue.get(memref, type=MemRefType)

        return MaskedLoadOp.build(
            operands=[memref, indices, mask, passthrough],
            result_types=[VectorType(memref.type.element_type, [1])],
        )

name = 'vector.maskedload' class-attribute instance-attribute

base = operand_def(MemRefType) class-attribute instance-attribute

indices = var_operand_def(IndexType) class-attribute instance-attribute

mask = operand_def(VectorBaseTypeAndRankConstraint(i1, 1)) class-attribute instance-attribute

pass_thru = operand_def(VectorType) class-attribute instance-attribute

result = result_def(VectorRankConstraint(1)) class-attribute instance-attribute

assembly_format = '$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)' class-attribute instance-attribute

__init__(base: SSAValue | Operation, indices: Sequence[SSAValue | Operation], mask: SSAValue | Operation, pass_thru: SSAValue | Operation, result_type: VectorType | None = None)

Source code in xdsl/dialects/vector.py
def __init__(
    self,
    base: SSAValue | Operation,
    indices: Sequence[SSAValue | Operation],
    mask: SSAValue | Operation,
    pass_thru: SSAValue | Operation,
    result_type: VectorType | None = None,
):
    pass_thru = SSAValue.get(pass_thru, type=VectorType)
    if result_type is None:
        result_type = pass_thru.type
    super().__init__(
        operands=[base, indices, mask, pass_thru],
        result_types=[result_type],
    )
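
A sketch of `vector.maskedload` construction; when `result_type` is omitted, the constructor above reuses the pass-through vector type. Block arguments are placeholders, and the `Block`/`MemRefType` constructors are assumptions:

```python
from xdsl.dialects import vector
from xdsl.dialects.builtin import IndexType, MemRefType, VectorType, f32, i1
from xdsl.ir import Block

block = Block(
    arg_types=[MemRefType(f32, [16]), IndexType(), VectorType(i1, [4]), VectorType(f32, [4])]
)
ref, idx, mask, pass_thru = block.args

# Result type defaults to the pass-through type, vector<4xf32>.
masked_load = vector.MaskedLoadOp(ref, [idx], mask, pass_thru)
```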

verify_()

Source code in xdsl/dialects/vector.py
def verify_(self):
    memref_type = self.base.type
    assert isa(memref_type, MemRefType)
    memref_element_type = memref_type.element_type

    res_type = self.result.type
    assert isa(res_type, VectorType[Attribute])
    res_element_type = res_type.element_type

    passthrough_type = self.pass_thru.type
    assert isa(passthrough_type, VectorType[Attribute])
    passthrough_element_type = passthrough_type.element_type

    if memref_element_type != res_element_type:
        raise VerifyException(
            "MemRef element type should match the result vector and passthrough vector "
            "element type. Found different element types for memref and result."
        )
    elif memref_element_type != passthrough_element_type:
        raise VerifyException(
            "MemRef element type should match the result vector and passthrough vector "
            "element type. Found different element types for memref and passthrough."
        )

    if memref_type.get_num_dims() != len(self.indices):
        raise VerifyException("Expected an index for each memref dimension.")

get(memref: SSAValue | Operation, indices: Sequence[SSAValue | Operation], mask: SSAValue | Operation, passthrough: SSAValue | Operation) -> MaskedLoadOp staticmethod

Source code in xdsl/dialects/vector.py
@deprecated(
    "Please use vector.MaskedLoadOp(memref, indices, mask, passthrough, result_type)"
)
@staticmethod
def get(
    memref: SSAValue | Operation,
    indices: Sequence[SSAValue | Operation],
    mask: SSAValue | Operation,
    passthrough: SSAValue | Operation,
) -> MaskedLoadOp:
    memref = SSAValue.get(memref, type=MemRefType)

    return MaskedLoadOp.build(
        operands=[memref, indices, mask, passthrough],
        result_types=[VectorType(memref.type.element_type, [1])],
    )

MaskedStoreOp

Bases: IRDLOperation

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class MaskedStoreOp(IRDLOperation):
    name = "vector.maskedstore"
    base = operand_def(MemRefType)
    indices = var_operand_def(IndexType)
    mask = operand_def(VectorBaseTypeAndRankConstraint(i1, 1))
    value_to_store = operand_def(VectorRankConstraint(1))

    assembly_format = "$base `[` $indices `]` `,` $mask `,` $value_to_store attr-dict `:` type($base) `,` type($mask) `,` type($value_to_store)"  # noqa: E501

    def verify_(self):
        memref_type = self.base.type
        assert isa(memref_type, MemRefType)
        memref_element_type = memref_type.element_type

        value_to_store_type = self.value_to_store.type
        assert isa(value_to_store_type, VectorType[Attribute])

        mask_type = self.mask.type
        assert isa(mask_type, VectorType[Attribute])

        if memref_element_type != value_to_store_type.element_type:
            raise VerifyException(
                "MemRef element type should match the stored vector type. "
                "Obtained types were "
                + str(memref_element_type)
                + " and "
                + str(value_to_store_type.element_type)
                + "."
            )

        if memref_type.get_num_dims() != len(self.indices):
            raise VerifyException("Expected an index for each memref dimension.")

    def __init__(
        self,
        memref: SSAValue | Operation,
        indices: Sequence[SSAValue | Operation],
        mask: SSAValue | Operation,
        value_to_store: SSAValue | Operation,
    ):
        super().__init__(operands=[memref, indices, mask, value_to_store])

    @deprecated(
        "Please use vector.MaskedStoreOp(memref, indices, mask, value_to_store)"
    )
    @staticmethod
    def get(
        memref: SSAValue | Operation,
        indices: Sequence[SSAValue | Operation],
        mask: SSAValue | Operation,
        value_to_store: SSAValue | Operation,
    ) -> MaskedStoreOp:
        return MaskedStoreOp(memref, indices, mask, value_to_store)

name = 'vector.maskedstore' class-attribute instance-attribute

base = operand_def(MemRefType) class-attribute instance-attribute

indices = var_operand_def(IndexType) class-attribute instance-attribute

mask = operand_def(VectorBaseTypeAndRankConstraint(i1, 1)) class-attribute instance-attribute

value_to_store = operand_def(VectorRankConstraint(1)) class-attribute instance-attribute

assembly_format = '$base `[` $indices `]` `,` $mask `,` $value_to_store attr-dict `:` type($base) `,` type($mask) `,` type($value_to_store)' class-attribute instance-attribute

verify_()

Source code in xdsl/dialects/vector.py
def verify_(self):
    memref_type = self.base.type
    assert isa(memref_type, MemRefType)
    memref_element_type = memref_type.element_type

    value_to_store_type = self.value_to_store.type
    assert isa(value_to_store_type, VectorType[Attribute])

    mask_type = self.mask.type
    assert isa(mask_type, VectorType[Attribute])

    if memref_element_type != value_to_store_type.element_type:
        raise VerifyException(
            "MemRef element type should match the stored vector type. "
            "Obtained types were "
            + str(memref_element_type)
            + " and "
            + str(value_to_store_type.element_type)
            + "."
        )

    if memref_type.get_num_dims() != len(self.indices):
        raise VerifyException("Expected an index for each memref dimension.")

__init__(memref: SSAValue | Operation, indices: Sequence[SSAValue | Operation], mask: SSAValue | Operation, value_to_store: SSAValue | Operation)

Source code in xdsl/dialects/vector.py
def __init__(
    self,
    memref: SSAValue | Operation,
    indices: Sequence[SSAValue | Operation],
    mask: SSAValue | Operation,
    value_to_store: SSAValue | Operation,
):
    super().__init__(operands=[memref, indices, mask, value_to_store])
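
The corresponding sketch for `vector.maskedstore`, mirroring the constructor above (placeholder block arguments; `Block`/`MemRefType` constructors assumed):

```python
from xdsl.dialects import vector
from xdsl.dialects.builtin import IndexType, MemRefType, VectorType, f32, i1
from xdsl.ir import Block

block = Block(
    arg_types=[MemRefType(f32, [16]), IndexType(), VectorType(i1, [4]), VectorType(f32, [4])]
)
ref, idx, mask, value = block.args

# Store the masked lanes of the vector<4xf32> into the memref at %idx.
masked_store = vector.MaskedStoreOp(ref, [idx], mask, value)
```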

get(memref: SSAValue | Operation, indices: Sequence[SSAValue | Operation], mask: SSAValue | Operation, value_to_store: SSAValue | Operation) -> MaskedStoreOp staticmethod

Source code in xdsl/dialects/vector.py
@deprecated(
    "Please use vector.MaskedStoreOp(memref, indices, mask, value_to_store)"
)
@staticmethod
def get(
    memref: SSAValue | Operation,
    indices: Sequence[SSAValue | Operation],
    mask: SSAValue | Operation,
    value_to_store: SSAValue | Operation,
) -> MaskedStoreOp:
    return MaskedStoreOp(memref, indices, mask, value_to_store)

PrintOp

Bases: IRDLOperation

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class PrintOp(IRDLOperation):
    name = "vector.print"
    source = operand_def()

    def __init__(self, source: SSAValue | Operation):
        super().__init__(operands=[SSAValue.get(source)])

    @deprecated("Please use vector.PrintOp(source)")
    @staticmethod
    def get(source: Operation | SSAValue) -> PrintOp:
        return PrintOp(source)

name = 'vector.print' class-attribute instance-attribute

source = operand_def() class-attribute instance-attribute

__init__(source: SSAValue | Operation)

Source code in xdsl/dialects/vector.py
def __init__(self, source: SSAValue | Operation):
    super().__init__(operands=[SSAValue.get(source)])

get(source: Operation | SSAValue) -> PrintOp staticmethod

Source code in xdsl/dialects/vector.py
@deprecated("Please use vector.PrintOp(source)")
@staticmethod
def get(source: Operation | SSAValue) -> PrintOp:
    return PrintOp(source)

CreateMaskOp

Bases: IRDLOperation

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class CreateMaskOp(IRDLOperation):
    name = "vector.create_mask"
    mask_dim_sizes = var_operand_def(IndexType)
    mask_vector = result_def(VectorBaseTypeConstraint(i1))

    assembly_format = "$mask_dim_sizes attr-dict `:` type(results)"

    def __init__(
        self, mask_operands: list[Operation | SSAValue], result_type: VectorType
    ):
        super().__init__(operands=(mask_operands,), result_types=(result_type,))

    def verify_(self):
        assert isa(self.mask_vector.type, VectorType[Attribute])
        if self.mask_vector.type.get_num_dims() != len(self.mask_dim_sizes):
            raise VerifyException(
                "Expected an operand value for each dimension of resultant mask."
            )

    @deprecated("Please use vector.CreateMaskOp(mask_operands, result_type)")
    @staticmethod
    def get(mask_operands: list[Operation | SSAValue]) -> CreateMaskOp:
        return CreateMaskOp.build(
            operands=[mask_operands],
            result_types=[VectorType(i1, [1])],
        )

name = 'vector.create_mask' class-attribute instance-attribute

mask_dim_sizes = var_operand_def(IndexType) class-attribute instance-attribute

mask_vector = result_def(VectorBaseTypeConstraint(i1)) class-attribute instance-attribute

assembly_format = '$mask_dim_sizes attr-dict `:` type(results)' class-attribute instance-attribute

__init__(mask_operands: list[Operation | SSAValue], result_type: VectorType)

Source code in xdsl/dialects/vector.py
def __init__(
    self, mask_operands: list[Operation | SSAValue], result_type: VectorType
):
    super().__init__(operands=(mask_operands,), result_types=(result_type,))
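
A sketch of `vector.create_mask`: one index operand per dimension of the resulting `i1` mask vector (placeholder block argument; `Block(arg_types=...)` is an assumption):

```python
from xdsl.dialects import vector
from xdsl.dialects.builtin import IndexType, VectorType, i1
from xdsl.ir import Block

block = Block(arg_types=[IndexType()])
(n,) = block.args

# The first %n lanes of the vector<8xi1> mask are set.
mask = vector.CreateMaskOp([n], VectorType(i1, [8]))
```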

verify_()

Source code in xdsl/dialects/vector.py
def verify_(self):
    assert isa(self.mask_vector.type, VectorType[Attribute])
    if self.mask_vector.type.get_num_dims() != len(self.mask_dim_sizes):
        raise VerifyException(
            "Expected an operand value for each dimension of resultant mask."
        )

get(mask_operands: list[Operation | SSAValue]) -> CreateMaskOp staticmethod

Source code in xdsl/dialects/vector.py
@deprecated("Please use vector.CreateMaskOp(mask_operands, result_type)")
@staticmethod
def get(mask_operands: list[Operation | SSAValue]) -> CreateMaskOp:
    return CreateMaskOp.build(
        operands=[mask_operands],
        result_types=[VectorType(i1, [1])],
    )

ExtractOp

Bases: IRDLOperation

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class ExtractOp(IRDLOperation):
    name = "vector.extract"

    _T: ClassVar = VarConstraint(
        "T", base(IntegerType) | base(IndexType) | AnyFloatConstr
    )
    _V: ClassVar = VarConstraint("V", VectorType.constr(_T))

    static_position = prop_def(DenseArrayBase.constr(i64))

    vector = operand_def(_V)
    dynamic_position = var_operand_def(IndexTypeConstr)

    result = result_def(
        VectorType.constr(
            _T,
            shape=MessageConstraint(
                ArrayAttr.constr(RangeOf(base(IntAttr)).of_length(AtLeast(1))),
                "Cannot extract 0d vector.",
            ),
        )
        | _T
    )

    traits = traits_def(Pure())

    DYNAMIC_INDEX: ClassVar = DYNAMIC_INDEX
    """This value is used to indicate that a position is a dynamic index."""

    assembly_format = (
        "$vector `` custom<DynamicIndexList>($dynamic_position, $static_position)"
        " attr-dict `:` type($result) `from` type($vector)"
    )

    custom_directives = (DynamicIndexList,)

    def get_mixed_position(self) -> list[SSAValue | int]:
        """
        Returns the list of positions, represented as either an SSAValue or an int
        """
        static_positions = self.static_position.get_values()
        return get_dynamic_index_list(
            static_positions,
            self.dynamic_position,
            ExtractOp.DYNAMIC_INDEX,
        )

    def verify_(self):
        # Check that static position attribute and dynamic position operands
        # are compatible.
        static_values = self.static_position.get_values()
        verify_dynamic_index_list(
            static_values,
            self.dynamic_position,
            self.DYNAMIC_INDEX,
        )

        num_indices = len(self.static_position)
        vector_type = self.vector.type
        assert isa(vector_type, VectorType[Attribute])
        # Check that the number of dimensions match
        if isa(self.result.type, VectorType):
            if (
                num_indices + self.result.type.get_num_dims()
                != vector_type.get_num_dims()
            ):
                raise VerifyException(
                    f"Expected position attribute rank ({num_indices}) + result rank "
                    f"({self.result.type.get_num_dims()}) to "
                    f"match source vector rank ({vector_type.get_num_dims()})."
                )
        else:
            if num_indices != vector_type.get_num_dims():
                raise VerifyException(
                    f"Expected position attribute rank ({num_indices}) to match "
                    f"source vector rank ({vector_type.get_num_dims()})."
                )

    def __init__(
        self,
        vector: SSAValue,
        positions: Sequence[SSAValue | int],
        result_type: Attribute,
    ):
        static_positions, dynamic_positions = split_dynamic_index_list(
            positions, ExtractOp.DYNAMIC_INDEX
        )

        super().__init__(
            operands=[vector, dynamic_positions],
            result_types=[result_type],
            properties={
                "static_position": DenseArrayBase.from_list(i64, static_positions)
            },
        )

name = 'vector.extract' class-attribute instance-attribute

static_position = prop_def(DenseArrayBase.constr(i64)) class-attribute instance-attribute

vector = operand_def(_V) class-attribute instance-attribute

dynamic_position = var_operand_def(IndexTypeConstr) class-attribute instance-attribute

result = result_def(VectorType.constr(_T, shape=(MessageConstraint(ArrayAttr.constr(RangeOf(base(IntAttr)).of_length(AtLeast(1))), 'Cannot extract 0d vector.'))) | _T) class-attribute instance-attribute

traits = traits_def(Pure()) class-attribute instance-attribute

DYNAMIC_INDEX: ClassVar = DYNAMIC_INDEX class-attribute instance-attribute

This value is used to indicate that a position is a dynamic index.

assembly_format = '$vector `` custom<DynamicIndexList>($dynamic_position, $static_position) attr-dict `:` type($result) `from` type($vector)' class-attribute instance-attribute

custom_directives = (DynamicIndexList,) class-attribute instance-attribute

get_mixed_position() -> list[SSAValue | int]

Returns the list of positions, represented as either an SSAValue or an int

Source code in xdsl/dialects/vector.py
def get_mixed_position(self) -> list[SSAValue | int]:
    """
    Returns the list of positions, represented as either an SSAValue or an int
    """
    static_positions = self.static_position.get_values()
    return get_dynamic_index_list(
        static_positions,
        self.dynamic_position,
        ExtractOp.DYNAMIC_INDEX,
    )

verify_()

Source code in xdsl/dialects/vector.py
def verify_(self):
    # Check that static position attribute and dynamic position operands
    # are compatible.
    static_values = self.static_position.get_values()
    verify_dynamic_index_list(
        static_values,
        self.dynamic_position,
        self.DYNAMIC_INDEX,
    )

    num_indices = len(self.static_position)
    vector_type = self.vector.type
    assert isa(vector_type, VectorType[Attribute])
    # Check that the number of dimensions match
    if isa(self.result.type, VectorType):
        if (
            num_indices + self.result.type.get_num_dims()
            != vector_type.get_num_dims()
        ):
            raise VerifyException(
                f"Expected position attribute rank ({num_indices}) + result rank "
                f"({self.result.type.get_num_dims()}) to "
                f"match source vector rank ({vector_type.get_num_dims()})."
            )
    else:
        if num_indices != vector_type.get_num_dims():
            raise VerifyException(
                f"Expected position attribute rank ({num_indices}) to match "
                f"source vector rank ({vector_type.get_num_dims()})."
            )

__init__(vector: SSAValue, positions: Sequence[SSAValue | int], result_type: Attribute)

Source code in xdsl/dialects/vector.py
def __init__(
    self,
    vector: SSAValue,
    positions: Sequence[SSAValue | int],
    result_type: Attribute,
):
    static_positions, dynamic_positions = split_dynamic_index_list(
        positions, ExtractOp.DYNAMIC_INDEX
    )

    super().__init__(
        operands=[vector, dynamic_positions],
        result_types=[result_type],
        properties={
            "static_position": DenseArrayBase.from_list(i64, static_positions)
        },
    )
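
As a sketch, positions passed to the constructor may mix Python ints (static) and SSA values (dynamic), and `get_mixed_position` recovers the same mixed list. Block arguments are placeholder SSA values here:

```python
from xdsl.dialects import vector
from xdsl.dialects.builtin import IndexType, VectorType, f32
from xdsl.ir import Block

block = Block(arg_types=[VectorType(f32, [4, 8]), IndexType()])
src, idx = block.args

# Extract the f32 at row 1 (static) and column %idx (dynamic).
extract = vector.ExtractOp(src, [1, idx], f32)
assert extract.get_mixed_position() == [1, idx]
```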

ExtractElementOp

Bases: IRDLOperation

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class ExtractElementOp(IRDLOperation):
    name = "vector.extractelement"
    vector = operand_def(VectorType)
    position = opt_operand_def(IndexTypeConstr | SignlessIntegerConstraint)
    result = result_def(Attribute)
    traits = traits_def(Pure())

    def verify_(self):
        assert isa(self.vector.type, VectorType[Attribute])

        if self.result.type != self.vector.type.element_type:
            raise VerifyException(
                "Expected result type to match element type of vector operand."
            )

        if self.vector.type.get_num_dims() == 0:
            if self.position is not None:
                raise VerifyException("Expected position to be empty with 0-D vector.")
            return
        if self.vector.type.get_num_dims() != 1:
            raise VerifyException("Unexpected >1 vector rank.")
        if self.position is None:
            raise VerifyException("Expected position for 1-D vector.")

    def __init__(
        self,
        vector: SSAValue | Operation,
        position: SSAValue | Operation | None = None,
    ):
        vector = SSAValue.get(vector, type=VectorType)

        result_type = vector.type.element_type

        super().__init__(
            operands=[vector, position],
            result_types=[result_type],
        )

name = 'vector.extractelement' class-attribute instance-attribute

vector = operand_def(VectorType) class-attribute instance-attribute

position = opt_operand_def(IndexTypeConstr | SignlessIntegerConstraint) class-attribute instance-attribute

result = result_def(Attribute) class-attribute instance-attribute

traits = traits_def(Pure()) class-attribute instance-attribute

verify_()

Source code in xdsl/dialects/vector.py
def verify_(self):
    assert isa(self.vector.type, VectorType[Attribute])

    if self.result.type != self.vector.type.element_type:
        raise VerifyException(
            "Expected result type to match element type of vector operand."
        )

    if self.vector.type.get_num_dims() == 0:
        if self.position is not None:
            raise VerifyException("Expected position to be empty with 0-D vector.")
        return
    if self.vector.type.get_num_dims() != 1:
        raise VerifyException("Unexpected >1 vector rank.")
    if self.position is None:
        raise VerifyException("Expected position for 1-D vector.")

__init__(vector: SSAValue | Operation, position: SSAValue | Operation | None = None)

Source code in xdsl/dialects/vector.py
def __init__(
    self,
    vector: SSAValue | Operation,
    position: SSAValue | Operation | None = None,
):
    vector = SSAValue.get(vector, type=VectorType)

    result_type = vector.type.element_type

    super().__init__(
        operands=[vector, position],
        result_types=[result_type],
    )
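
A sketch of `vector.extractelement` on a 1-D vector; the result type is inferred from the vector's element type by the constructor above (placeholder block arguments):

```python
from xdsl.dialects import vector
from xdsl.dialects.builtin import IndexType, VectorType, f32
from xdsl.ir import Block

block = Block(arg_types=[VectorType(f32, [4]), IndexType()])
vec, pos = block.args

# Extract the f32 at position %pos; the result type is f32.
elem = vector.ExtractElementOp(vec, pos)
```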

InsertOp

Bases: IRDLOperation

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class InsertOp(IRDLOperation):
    name = "vector.insert"

    _T: ClassVar = VarConstraint(
        "T", base(IntegerType) | base(IndexType) | AnyFloatConstr
    )
    _V: ClassVar = VarConstraint("V", VectorType.constr(_T))

    static_position = prop_def(DenseArrayBase.constr(i64))

    source = operand_def(
        VectorType.constr(
            _T,
            shape=MessageConstraint(
                ArrayAttr.constr(RangeOf(base(IntAttr)).of_length(AtLeast(1))),
                "Cannot insert 0d vector.",
            ),
        )
        | _T
    )
    dest = operand_def(_V)
    dynamic_position = var_operand_def(IndexTypeConstr)

    result = result_def(_V)

    traits = traits_def(Pure())

    DYNAMIC_INDEX: ClassVar = -(2**63)
    """This value is used to indicate that a position is a dynamic index."""

    assembly_format = (
        "$source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)"
        "attr-dict `:` type($source) `into` type($dest)"
    )

    custom_directives = (DynamicIndexList,)

    def get_mixed_position(self) -> list[SSAValue | int]:
        """
        Returns the list of positions, represented as either an SSAValue or an int.
        """
        static_positions = self.static_position.get_values()
        return get_dynamic_index_list(
            static_positions,
            self.dynamic_position,
            InsertOp.DYNAMIC_INDEX,
        )

    def verify_(self):
        # Check that static position attribute and dynamic position operands
        # are compatible.
        static_values = self.static_position.get_values()
        verify_dynamic_index_list(
            static_values,
            self.dynamic_position,
            self.DYNAMIC_INDEX,
        )

        num_indices = len(self.static_position)
        # Check that the number of dimensions match
        if isa(self.source.type, VectorType):
            if (
                num_indices + self.source.type.get_num_dims()
                != self.result.type.get_num_dims()
            ):
                raise VerifyException(
                    f"Expected position attribute rank ({num_indices}) + source rank "
                    f"({self.source.type.get_num_dims()}) to "
                    f"match dest vector rank ({self.result.type.get_num_dims()})."
                )
        else:
            if num_indices != self.result.type.get_num_dims():
                raise VerifyException(
                    f"Expected position attribute rank ({num_indices}) to match "
                    f"dest vector rank ({self.result.type.get_num_dims()})."
                )

    def __init__(
        self,
        source: SSAValue,
        dest: SSAValue,
        positions: Sequence[SSAValue | int],
        result_type: Attribute | None = None,
    ):
        static_positions, dynamic_positions = split_dynamic_index_list(
            positions, InsertOp.DYNAMIC_INDEX
        )

        if result_type is None:
            result_type = dest.type

        super().__init__(
            operands=[source, dest, dynamic_positions],
            result_types=[result_type],
            properties={
                "static_position": DenseArrayBase.from_list(i64, static_positions)
            },
        )

name = 'vector.insert' class-attribute instance-attribute

static_position = prop_def(DenseArrayBase.constr(i64)) class-attribute instance-attribute

source = operand_def(VectorType.constr(_T, shape=(MessageConstraint(ArrayAttr.constr(RangeOf(base(IntAttr)).of_length(AtLeast(1))), 'Cannot insert 0d vector.'))) | _T) class-attribute instance-attribute

dest = operand_def(_V) class-attribute instance-attribute

dynamic_position = var_operand_def(IndexTypeConstr) class-attribute instance-attribute

result = result_def(_V) class-attribute instance-attribute

traits = traits_def(Pure()) class-attribute instance-attribute

DYNAMIC_INDEX: ClassVar = -2 ** 63 class-attribute instance-attribute

This value is used to indicate that a position is a dynamic index.

assembly_format = '$source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)attr-dict `:` type($source) `into` type($dest)' class-attribute instance-attribute

custom_directives = (DynamicIndexList,) class-attribute instance-attribute

get_mixed_position() -> list[SSAValue | int]

Returns the list of positions, represented as either an SSAValue or an int.

Source code in xdsl/dialects/vector.py
def get_mixed_position(self) -> list[SSAValue | int]:
    """
    Returns the list of positions, represented as either an SSAValue or an int.
    """
    static_positions = self.static_position.get_values()
    return get_dynamic_index_list(
        static_positions,
        self.dynamic_position,
        InsertOp.DYNAMIC_INDEX,
    )

verify_()

Source code in xdsl/dialects/vector.py
def verify_(self):
    # Check that static position attribute and dynamic position operands
    # are compatible.
    static_values = self.static_position.get_values()
    verify_dynamic_index_list(
        static_values,
        self.dynamic_position,
        self.DYNAMIC_INDEX,
    )

    num_indices = len(self.static_position)
    # Check that the number of dimensions match
    if isa(self.source.type, VectorType):
        if (
            num_indices + self.source.type.get_num_dims()
            != self.result.type.get_num_dims()
        ):
            raise VerifyException(
                f"Expected position attribute rank ({num_indices}) + source rank "
                f"({self.source.type.get_num_dims()}) to "
                f"match dest vector rank ({self.result.type.get_num_dims()})."
            )
    else:
        if num_indices != self.result.type.get_num_dims():
            raise VerifyException(
                f"Expected position attribute rank ({num_indices}) to match "
                f"dest vector rank ({self.result.type.get_num_dims()})."
            )

__init__(source: SSAValue, dest: SSAValue, positions: Sequence[SSAValue | int], result_type: Attribute | None = None)

Source code in xdsl/dialects/vector.py
def __init__(
    self,
    source: SSAValue,
    dest: SSAValue,
    positions: Sequence[SSAValue | int],
    result_type: Attribute | None = None,
):
    static_positions, dynamic_positions = split_dynamic_index_list(
        positions, InsertOp.DYNAMIC_INDEX
    )

    if result_type is None:
        result_type = dest.type

    super().__init__(
        operands=[source, dest, dynamic_positions],
        result_types=[result_type],
        properties={
            "static_position": DenseArrayBase.from_list(i64, static_positions)
        },
    )
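
As a usage sketch (not taken from the library's test suite), the constructor accepts a mixed list of static integer positions and dynamic index SSA values; dynamic entries become operands while their static slots hold the DYNAMIC_INDEX sentinel. Block arguments stand in for values produced elsewhere, and the builtin import paths are assumed.

```python
from xdsl.dialects.builtin import IndexType, VectorType, f32
from xdsl.dialects.vector import InsertOp
from xdsl.ir import Block

# Block arguments stand in for SSA values defined elsewhere in a real program.
block = Block(arg_types=[f32, VectorType(f32, [4, 8]), IndexType()])
scalar, dest, i = block.args

# Insert the scalar at position [1, %i]: one static and one dynamic index.
op = InsertOp(scalar, dest, [1, i])

# get_mixed_position reassembles the static/dynamic mix.
assert op.get_mixed_position() == [1, i]
assert op.result.type == dest.type
```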

InsertElementOp

Bases: IRDLOperation

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class InsertElementOp(IRDLOperation):
    name = "vector.insertelement"
    source = operand_def(Attribute)
    dest = operand_def(VectorType)
    position = opt_operand_def(IndexTypeConstr | SignlessIntegerConstraint)
    result = result_def(VectorType)
    traits = traits_def(Pure())

    def verify_(self):
        assert isa(self.dest.type, VectorType[Attribute])

        if self.result.type != self.dest.type:
            raise VerifyException(
                "Expected dest operand and result to have matching types."
            )
        if self.source.type != self.dest.type.element_type:
            raise VerifyException(
                "Expected source operand type to match element type of dest operand."
            )

        if self.dest.type.get_num_dims() == 0:
            if self.position is not None:
                raise VerifyException("Expected position to be empty with 0-D vector.")
            return
        if self.dest.type.get_num_dims() != 1:
            raise VerifyException("Unexpected >1 vector rank.")
        if self.position is None:
            raise VerifyException("Expected position for 1-D vector.")

    def __init__(
        self,
        source: SSAValue | Operation,
        dest: SSAValue | Operation,
        position: SSAValue | Operation | None = None,
    ):
        dest = SSAValue.get(dest, type=VectorType)

        result_type = SSAValue.get(dest).type

        super().__init__(
            operands=[source, dest, position],
            result_types=[result_type],
        )

name = 'vector.insertelement' class-attribute instance-attribute

source = operand_def(Attribute) class-attribute instance-attribute

dest = operand_def(VectorType) class-attribute instance-attribute

position = opt_operand_def(IndexTypeConstr | SignlessIntegerConstraint) class-attribute instance-attribute

result = result_def(VectorType) class-attribute instance-attribute

traits = traits_def(Pure()) class-attribute instance-attribute

verify_()

Source code in xdsl/dialects/vector.py
def verify_(self):
    assert isa(self.dest.type, VectorType[Attribute])

    if self.result.type != self.dest.type:
        raise VerifyException(
            "Expected dest operand and result to have matching types."
        )
    if self.source.type != self.dest.type.element_type:
        raise VerifyException(
            "Expected source operand type to match element type of dest operand."
        )

    if self.dest.type.get_num_dims() == 0:
        if self.position is not None:
            raise VerifyException("Expected position to be empty with 0-D vector.")
        return
    if self.dest.type.get_num_dims() != 1:
        raise VerifyException("Unexpected >1 vector rank.")
    if self.position is None:
        raise VerifyException("Expected position for 1-D vector.")

__init__(source: SSAValue | Operation, dest: SSAValue | Operation, position: SSAValue | Operation | None = None)

Source code in xdsl/dialects/vector.py
def __init__(
    self,
    source: SSAValue | Operation,
    dest: SSAValue | Operation,
    position: SSAValue | Operation | None = None,
):
    dest = SSAValue.get(dest, type=VectorType)

    result_type = SSAValue.get(dest).type

    super().__init__(
        operands=[source, dest, position],
        result_types=[result_type],
    )
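
A similar sketch for `vector.insertelement`, which inserts at a single dynamic position into a 1-D vector (same assumptions: placeholder block arguments and assumed import paths).

```python
from xdsl.dialects.builtin import IndexType, VectorType, f32
from xdsl.dialects.vector import InsertElementOp
from xdsl.ir import Block

block = Block(arg_types=[f32, VectorType(f32, [8]), IndexType()])
value, vec, idx = block.args

# Insert `value` into the 1-D vector at the dynamic index `idx`.
op = InsertElementOp(value, vec, idx)
assert op.result.type == vec.type
```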

VectorTransferOperation dataclass

Bases: IRDLOperation, ABC

Encodes properties of a vector.transfer_read or vector.transfer_write operation. Vector transfer ops have:

  • A shaped value that the op reads from/writes to: a memref or a tensor.
  • A vector, either as a result or as an operand.
  • Indices that describe where the transfer from/to the shaped value starts.
  • An optional mask.
  • An optional in_bounds array to indicate transfer dimensions that are guaranteed to be in-bounds.
  • A permutation map to indicate transposes and broadcasts.

The "vector rank" is the rank of the vector type. E.g.:

// Transfer with shaped value rank 2 and vector (transfer) rank 1.
%0 = vector.transfer_read %arg0[%c3, %c3], %f0
    {permutation_map = affine_map<(d0, d1) -> (d0)>}
    : memref<?x?xf32>, vector<128xf32>

The "vector transfer rank" is the number of dimensions that participate in the transfer and broadcasts, and matches the number of results in the permutation map. In most cases, the vector rank matches the vector transfer rank; the only exception is when a vector is flattened as part of the transfer (see permutation_map).

Mirrors VectorTransferOpInterface from MLIR

Source code in xdsl/dialects/vector.py
class VectorTransferOperation(IRDLOperation, ABC):
    """
    Encodes properties of a `vector.transfer_read` or `vector.transfer_write`
    operation. Vector transfer ops have:

    - A shaped value that the op reads from/writes to: a memref or a tensor.
    - A vector, either as a result or as an operand.
    - Indices that describe where the transfer from/to the shaped value starts.
    - An optional mask.
    - An optional in_bounds array to indicate transfer dimensions that are
      guaranteed to be in-bounds.
    - A permutation map to indicate transposes and broadcasts.

    The "vector rank" is the rank of the vector type. E.g.:
    ```mlir
    // Transfer with shaped value rank 2 and vector (transfer) rank 1.
    %0 = vector.transfer_read %arg0[%c3, %c3], %f0
        {permutation_map = affine_map<(d0, d1) -> (d0)>}
        : memref<?x?xf32>, vector<128xf32>
    ```

    The "vector transfer rank" is the number of dimensions that participate in
    the transfer and broadcasts, and matches the number of results in the
    permutation map. In most cases, the vector rank matches the vector transfer
    rank; the only exception is when a vector is flattened as part of the
    transfer (see `permutation_map`).

    Mirrors VectorTransferOpInterface from [MLIR](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Interfaces/VectorInterfaces.td)
    """

    permutation_map = prop_def(AffineMapAttr)
    """
    The permutation map that describes the mapping of vector
    dimensions to source dimensions, as well as broadcast dimensions.

    The permutation result has one result per vector transfer dimension.
    Each result is either a dim expression, indicating the corresponding
    dimension in the source operand, or a constant "0" expression,
    indicating a broadcast dimension.

    Note: Nested vector dimensions that are flattened by this op are not
    accounted for in the permutation map. E.g.:
    ```mlir
    // Vector type has rank 4, but permutation map has only 2 results. That
    // is because there are only 2 transfer dimensions.
    %0 = vector.transfer_read %arg1[%c3, %c3], %vf0
        {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
        : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
    ```
    """

    in_bounds = prop_def(ArrayAttr[BoolAttr])
    """
    For every vector dimension, the boolean array attribute `in_bounds` specifies if the
    transfer is guaranteed to be within the source bounds. If set to `“false”`, accesses
    (including the starting point) may run out-of-bounds along the respective vector
    dimension as the index increases. Non-vector dimensions must always be in-bounds.
    The `in_bounds` array length has to be equal to the vector rank. This attribute has
    a default value: `false` (i.e. “out-of-bounds”). When skipped in the textual IR, the
    default value is assumed. Similarly, the OP printer will omit this attribute when
    all dimensions are out-of-bounds (i.e. the default value is used).
    """

    @staticmethod
    def infer_transfer_op_mask_type(
        vec_type: VectorType, perm_map: AffineMap
    ) -> VectorType[I1]:
        """
        Given a resulting vector type and a permutation map from the dimensions of the
        shaped type to the vector type dimensions, return the vector type of the mask.
        """
        unused_dims_bit_vector = tuple(
            not dim for dim in perm_map.used_dims_bit_vector()
        )
        inv_perm_map = perm_map.drop_dims(unused_dims_bit_vector).inverse_permutation()
        assert inv_perm_map is not None, "Inverse permutation map couldn't be computed"
        mask_shape = inv_perm_map.eval(vec_type.get_shape(), ())
        scalable_dims = ArrayAttr(
            BoolAttr.from_bool(bool(b))
            for b in inv_perm_map.eval(vec_type.get_scalable_dims(), ())
        )
        res = VectorType(i1, mask_shape, scalable_dims)
        return res

    @staticmethod
    def get_transfer_minor_identity_map(
        shaped_type: TensorType | MemRefType, vector_type: VectorType
    ) -> AffineMap:
        """
        Get the minor identity map for a transfer operation.

        This is a helper function to compute the default permutation map for
        transfer operations when none is specified.
        """
        element_vector_rank = 0
        element_type = shaped_type.element_type
        if isa(element_type, VectorType):
            element_vector_rank += element_type.get_num_dims()

        # 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
        # TODO: replace once we have 0-d vectors.
        if shaped_type.get_num_dims() == 0 and vector_type.get_shape() == (1,):
            return AffineMap.constant_map(0)

        return AffineMap.minor_identity(
            shaped_type.get_num_dims(),
            vector_type.get_num_dims() - element_vector_rank,
        )

    def _print_attrs(self, printer: Printer):
        reserved_attr_names = {"operandSegmentSizes"}
        if self.permutation_map.data.is_minor_identity():
            reserved_attr_names.add("permutation_map")
        if not any(self.in_bounds):
            reserved_attr_names.add("in_bounds")
        printer.print_op_attributes(
            self.attributes | self.properties, reserved_attr_names=reserved_attr_names
        )

    @staticmethod
    def resolve_attrs(
        parser: Parser,
        attributes_dict: dict[str, Attribute],
        shaped_type: TensorType | MemRefType,
        vector_type: VectorType,
        mask_start_pos: Position | None,
        mask_end_pos: Position | None,
        mask: UnresolvedOperand | None,
        types_pos: Position,
    ):
        # Create default permutation_map if not provided in attributes
        permutation_map = None
        if attributes_dict and "permutation_map" in attributes_dict:
            permutation_map = attributes_dict["permutation_map"]
            assert isinstance(permutation_map, AffineMapAttr)
        else:
            # Create identity permutation map for the shaped type's rank
            permutation_map = AffineMapAttr(
                VectorTransferOperation.get_transfer_minor_identity_map(
                    shaped_type, vector_type
                )
            )

        # Create in_bounds attribute if not provided
        in_bounds = None
        if attributes_dict and "in_bounds" in attributes_dict:
            in_bounds = cast(ArrayAttr[BoolAttr], attributes_dict["in_bounds"])
        else:
            # Default: all dimensions are out-of-bounds
            in_bounds = ArrayAttr(
                (BoolAttr.from_bool(False),) * len(permutation_map.data.results)
            )

        if mask is not None:
            if isa(shaped_type.element_type, VectorType):
                assert mask_start_pos is not None
                assert mask_end_pos is not None
                parser.raise_error(
                    "does not support masks with vector element type",
                    at_position=mask_start_pos,
                    end_position=mask_end_pos,
                )
            if vector_type.get_num_dims() != len(permutation_map.data.results):
                parser.raise_error(
                    "expected the same rank for the vector and the "
                    "results of the permutation map",
                    types_pos,
                )
            # Instead of adding the mask type as an op type, compute it based on the
            # vector type and the permutation map (to keep the type signature small).
            mask_type = VectorTransferOperation.infer_transfer_op_mask_type(
                vector_type, permutation_map.data
            )
            resolved_mask = parser.resolve_operand(mask, mask_type)
        else:
            resolved_mask = None

        return resolved_mask, permutation_map, in_bounds

    def has_broadcast_dim(self):
        """
        Return "true" if at least one of the vector dimensions is a broadcasted dimension.
        """
        return any(
            isinstance(expr, AffineConstantExpr) and expr.value == 0
            for expr in self.permutation_map.data.results
        )

    @staticmethod
    def verify_op(
        op: TransferReadOp | TransferWriteOp,
        shaped_type: MemRefType | TensorType,
        vector_type: VectorType,
        mask_type: VectorType[I1] | None,
        inferred_mask_type: VectorType[I1] | None,
        permutation_map: AffineMap,
        in_bounds: ArrayAttr[BoolAttr],
    ):
        """
        This mirrors VectorOps.cpp -> verifyTransferOp from MLIR
        """

        element_type = shaped_type.element_type
        vector_element_type = vector_type.element_type

        if isa(element_type, VectorType):
            # Memref or tensor has vector element type
            # TODO verify vector element type
            pass
        else:
            # Memref or tensor has scalar element type
            if isa(vector_element_type, IndexType):
                if not isa(element_type, IndexType):
                    raise VerifyException(
                        "Element type of source is index, expected element type of vector also to be index"
                    )
            else:
                assert isa(vector_element_type, IntegerType | AnyFloat)
                assert isa(element_type, IntegerType | AnyFloat)

                minor_size = (
                    1
                    if vector_type.get_num_dims() == 0
                    else vector_type.get_shape()[-1]
                )
                result_vec_size = vector_element_type.bitwidth * minor_size
                if result_vec_size % element_type.bitwidth != 0:
                    raise VerifyException(
                        f'"{op.name}" requires the bitwidth of the minor 1-D vector to be '
                        "an integral multiple of the bitwidth of the source element type"
                    )

            # Check that permutation map results match rank of vector type.
            if len(permutation_map.results) != vector_type.get_num_dims():
                raise VerifyException(
                    f'"{op.name}" requires a permutation_map with result dims of the same rank as the vector type'
                )

        if permutation_map.num_symbols != 0:
            raise VerifyException(
                f'"{op.name}" requires permutation_map without symbols'
            )

        if permutation_map.num_dims != shaped_type.get_num_dims():
            raise VerifyException(
                f'"{op.name}" requires a permutation_map with input dims of the same rank as the source type'
            )

        if mask_type:
            if mask_type != inferred_mask_type:
                raise VerifyException(
                    f'"{op.name}" inferred mask type ({inferred_mask_type}) and mask operand type ({mask_type}) don\'t match'
                )

        if len(in_bounds) != len(permutation_map.results):
            raise VerifyException(
                f'"{op.name}" expects the in_bounds attr of same rank as permutation_map results: '
                f"{str(permutation_map)} vs in_bounds of of size {len(in_bounds)}"
            )

    @staticmethod
    def verify_permutation_map(
        op: TransferReadOp | TransferWriteOp,
        permutation_map: AffineMap,
    ):
        """
        This mirrors VectorOps.cpp -> verifyPermutationMap
        """

        seen: list[bool] = [False for _ in range(permutation_map.num_dims)]

        for expr in permutation_map.results:
            if isa(expr, AffineConstantExpr):
                if expr.value != 0:
                    raise VerifyException(
                        f'"{op.name}" requires a projected permutation_map '
                        "(at most one dim or the zero constant can appear in each result)"
                    )
                continue
            if not isa(expr, AffineDimExpr):
                raise VerifyException(
                    f'"{op.name}" requires a projected permutation_map '
                    "(at most one dim or the zero constant can appear in each result)"
                )
            if seen[expr.position]:
                raise VerifyException(
                    f'"{op.name}" requires a permutation_map that is a permutation '
                    "(found one dim used more than once)"
                )
            seen[expr.position] = True

permutation_map = prop_def(AffineMapAttr) class-attribute instance-attribute

The permutation map that describes the mapping of vector dimensions to source dimensions, as well as broadcast dimensions.

The permutation result has one result per vector transfer dimension. Each result is either a dim expression, indicating the corresponding dimension in the source operand, or a constant "0" expression, indicating a broadcast dimension.

Note: Nested vector dimensions that are flattened by this op are not accounted for in the permutation map. E.g.:

// Vector type has rank 4, but permutation map has only 2 results. That
// is because there are only 2 transfer dimensions.
%0 = vector.transfer_read %arg1[%c3, %c3], %vf0
    {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
    : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>

in_bounds = prop_def(ArrayAttr[BoolAttr]) class-attribute instance-attribute

For every vector dimension, the boolean array attribute in_bounds specifies if the transfer is guaranteed to be within the source bounds. If set to “false”, accesses (including the starting point) may run out-of-bounds along the respective vector dimension as the index increases. Non-vector dimensions must always be in-bounds. The in_bounds array length has to be equal to the vector rank. This attribute has a default value: false (i.e. “out-of-bounds”). When skipped in the textual IR, the default value is assumed. Similarly, the OP printer will omit this attribute when all dimensions are out-of-bounds (i.e. the default value is used).

infer_transfer_op_mask_type(vec_type: VectorType, perm_map: AffineMap) -> VectorType[I1] staticmethod

Given a resulting vector type and a permutation map from the dimensions of the shaped type to the vector type dimensions, return the vector type of the mask.

Source code in xdsl/dialects/vector.py
@staticmethod
def infer_transfer_op_mask_type(
    vec_type: VectorType, perm_map: AffineMap
) -> VectorType[I1]:
    """
    Given a resulting vector type and a permutation map from the dimensions of the
    shaped type to the vector type dimensions, return the vector type of the mask.
    """
    unused_dims_bit_vector = tuple(
        not dim for dim in perm_map.used_dims_bit_vector()
    )
    inv_perm_map = perm_map.drop_dims(unused_dims_bit_vector).inverse_permutation()
    assert inv_perm_map is not None, "Inverse permutation map couldn't be computed"
    mask_shape = inv_perm_map.eval(vec_type.get_shape(), ())
    scalable_dims = ArrayAttr(
        BoolAttr.from_bool(bool(b))
        for b in inv_perm_map.eval(vec_type.get_scalable_dims(), ())
    )
    res = VectorType(i1, mask_shape, scalable_dims)
    return res
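
For example, with a minor identity permutation map the inferred mask type simply mirrors the vector shape. A minimal sketch, assuming the usual import paths:

```python
from xdsl.dialects.builtin import VectorType, f32, i1
from xdsl.dialects.vector import VectorTransferOperation
from xdsl.ir.affine import AffineMap

vec_type = VectorType(f32, [4, 8])
perm_map = AffineMap.minor_identity(2, 2)  # (d0, d1) -> (d0, d1)

mask_type = VectorTransferOperation.infer_transfer_op_mask_type(vec_type, perm_map)
assert mask_type.element_type == i1
assert mask_type.get_shape() == (4, 8)  # vector<4x8xi1>
```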

get_transfer_minor_identity_map(shaped_type: TensorType | MemRefType, vector_type: VectorType) -> AffineMap staticmethod

Get the minor identity map for a transfer operation.

This is a helper function to compute the default permutation map for transfer operations when none is specified.

Source code in xdsl/dialects/vector.py
@staticmethod
def get_transfer_minor_identity_map(
    shaped_type: TensorType | MemRefType, vector_type: VectorType
) -> AffineMap:
    """
    Get the minor identity map for a transfer operation.

    This is a helper function to compute the default permutation map for
    transfer operations when none is specified.
    """
    element_vector_rank = 0
    element_type = shaped_type.element_type
    if isa(element_type, VectorType):
        element_vector_rank += element_type.get_num_dims()

    # 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
    # TODO: replace once we have 0-d vectors.
    if shaped_type.get_num_dims() == 0 and vector_type.get_shape() == (1,):
        return AffineMap.constant_map(0)

    return AffineMap.minor_identity(
        shaped_type.get_num_dims(),
        vector_type.get_num_dims() - element_vector_rank,
    )
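
For instance, reading a `vector<8xf32>` from a rank-2 memref defaults to a map that keeps only the innermost dimension (sketch with assumed import paths):

```python
from xdsl.dialects.builtin import MemRefType, VectorType, f32
from xdsl.dialects.vector import VectorTransferOperation

default_map = VectorTransferOperation.get_transfer_minor_identity_map(
    MemRefType(f32, [16, 16]), VectorType(f32, [8])
)
# default_map is affine_map<(d0, d1) -> (d1)>.
```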

resolve_attrs(parser: Parser, attributes_dict: dict[str, Attribute], shaped_type: TensorType | MemRefType, vector_type: VectorType, mask_start_pos: Position | None, mask_end_pos: Position | None, mask: UnresolvedOperand | None, types_pos: Position) staticmethod

Source code in xdsl/dialects/vector.py
@staticmethod
def resolve_attrs(
    parser: Parser,
    attributes_dict: dict[str, Attribute],
    shaped_type: TensorType | MemRefType,
    vector_type: VectorType,
    mask_start_pos: Position | None,
    mask_end_pos: Position | None,
    mask: UnresolvedOperand | None,
    types_pos: Position,
):
    # Create default permutation_map if not provided in attributes
    permutation_map = None
    if attributes_dict and "permutation_map" in attributes_dict:
        permutation_map = attributes_dict["permutation_map"]
        assert isinstance(permutation_map, AffineMapAttr)
    else:
        # Create identity permutation map for the shaped type's rank
        permutation_map = AffineMapAttr(
            VectorTransferOperation.get_transfer_minor_identity_map(
                shaped_type, vector_type
            )
        )

    # Create in_bounds attribute if not provided
    in_bounds = None
    if attributes_dict and "in_bounds" in attributes_dict:
        in_bounds = cast(ArrayAttr[BoolAttr], attributes_dict["in_bounds"])
    else:
        # Default: all dimensions are out-of-bounds
        in_bounds = ArrayAttr(
            (BoolAttr.from_bool(False),) * len(permutation_map.data.results)
        )

    if mask is not None:
        if isa(shaped_type.element_type, VectorType):
            assert mask_start_pos is not None
            assert mask_end_pos is not None
            parser.raise_error(
                "does not support masks with vector element type",
                at_position=mask_start_pos,
                end_position=mask_end_pos,
            )
        if vector_type.get_num_dims() != len(permutation_map.data.results):
            parser.raise_error(
                "expected the same rank for the vector and the "
                "results of the permutation map",
                types_pos,
            )
        # Instead of adding the mask type as an op type, compute it based on the
        # vector type and the permutation map (to keep the type signature small).
        mask_type = VectorTransferOperation.infer_transfer_op_mask_type(
            vector_type, permutation_map.data
        )
        resolved_mask = parser.resolve_operand(mask, mask_type)
    else:
        resolved_mask = None

    return resolved_mask, permutation_map, in_bounds

has_broadcast_dim()

Return "true" if at least one of the vector dimensions is a broadcasted dimension.

Source code in xdsl/dialects/vector.py
def has_broadcast_dim(self):
    """
    Return "true" if at least one of the vector dimensions is a broadcasted dimension.
    """
    return any(
        isinstance(expr, AffineConstantExpr) and expr.value == 0
        for expr in self.permutation_map.data.results
    )

verify_op(op: TransferReadOp | TransferWriteOp, shaped_type: MemRefType | TensorType, vector_type: VectorType, mask_type: VectorType[I1] | None, inferred_mask_type: VectorType[I1] | None, permutation_map: AffineMap, in_bounds: ArrayAttr[BoolAttr]) staticmethod

This mirrors VectorOps.cpp -> verifyTransferOp from MLIR

Source code in xdsl/dialects/vector.py
@staticmethod
def verify_op(
    op: TransferReadOp | TransferWriteOp,
    shaped_type: MemRefType | TensorType,
    vector_type: VectorType,
    mask_type: VectorType[I1] | None,
    inferred_mask_type: VectorType[I1] | None,
    permutation_map: AffineMap,
    in_bounds: ArrayAttr[BoolAttr],
):
    """
    This mirrors VectorOps.cpp -> verifyTransferOp from MLIR
    """

    element_type = shaped_type.element_type
    vector_element_type = vector_type.element_type

    if isa(element_type, VectorType):
        # Memref or tensor has vector element type
        # TODO verify vector element type
        pass
    else:
        # Memref or tensor has scalar element type
        if isa(vector_element_type, IndexType):
            if not isa(element_type, IndexType):
                raise VerifyException(
                    "Element type of source is index, expected element type of vector also to be index"
                )
        else:
            assert isa(vector_element_type, IntegerType | AnyFloat)
            assert isa(element_type, IntegerType | AnyFloat)

            minor_size = (
                1
                if vector_type.get_num_dims() == 0
                else vector_type.get_shape()[-1]
            )
            result_vec_size = vector_element_type.bitwidth * minor_size
            if result_vec_size % element_type.bitwidth != 0:
                raise VerifyException(
                    f'"{op.name}" requires the bitwidth of the minor 1-D vector to be '
                    "an integral multiple of the bitwidth of the source element type"
                )

        # Check that permutation map results match rank of vector type.
        if len(permutation_map.results) != vector_type.get_num_dims():
            raise VerifyException(
                f'"{op.name}" requires a permutation_map with result dims of the same rank as the vector type'
            )

    if permutation_map.num_symbols != 0:
        raise VerifyException(
            f'"{op.name}" requires permutation_map without symbols'
        )

    if permutation_map.num_dims != shaped_type.get_num_dims():
        raise VerifyException(
            f'"{op.name}" requires a permutation_map with input dims of the same rank as the source type'
        )

    if mask_type:
        if mask_type != inferred_mask_type:
            raise VerifyException(
                f'"{op.name}" inferred mask type ({inferred_mask_type}) and mask operand type ({mask_type}) don\'t match'
            )

    if len(in_bounds) != len(permutation_map.results):
        raise VerifyException(
            f'"{op.name}" expects the in_bounds attr of same rank as permutation_map results: '
            f"{str(permutation_map)} vs in_bounds of of size {len(in_bounds)}"
        )

verify_permutation_map(op: TransferReadOp | TransferWriteOp, permutation_map: AffineMap) staticmethod

This mirrors VectorOps.cpp -> verifyPermutationMap

Source code in xdsl/dialects/vector.py
@staticmethod
def verify_permutation_map(
    op: TransferReadOp | TransferWriteOp,
    permutation_map: AffineMap,
):
    """
    This mirrors VectorOps.cpp -> verifyPermutationMap
    """

    seen: list[bool] = [False for _ in range(permutation_map.num_dims)]

    for expr in permutation_map.results:
        if isa(expr, AffineConstantExpr):
            if expr.value != 0:
                raise VerifyException(
                    f'"{op.name}" requires a projected permutation_map '
                    "(at most one dim or the zero constant can appear in each result)"
                )
            continue
        if not isa(expr, AffineDimExpr):
            raise VerifyException(
                f'"{op.name}" requires a projected permutation_map '
                "(at most one dim or the zero constant can appear in each result)"
            )
        if seen[expr.position]:
            raise VerifyException(
                f'"{op.name}" requires a permutation_map that is a permutation '
                "(found one dim used more than once)"
            )
        seen[expr.position] = True

TransferReadOp

Bases: VectorTransferOperation

Reads a supervector from memory into an SSA vector value.

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class TransferReadOp(VectorTransferOperation):
    "Reads a supervector from memory into an SSA vector value."

    name = "vector.transfer_read"

    source = operand_def(TensorType | MemRefType)
    indices = var_operand_def(IndexType)
    padding = operand_def()
    mask = opt_operand_def(VectorType[I1])

    permutation_map = prop_def(AffineMapAttr)

    result = result_def(VectorType)

    irdl_options = (
        AttrSizedOperandSegments(as_property=True),
        ParsePropInAttrDict(),
    )

    def __init__(
        self,
        source: SSAValue | Operation,
        indices: Sequence[SSAValue | Operation],
        padding: SSAValue | Operation,
        result_type: Attribute,
        in_bounds: ArrayAttr[BoolAttr],
        permutation_map: AffineMapAttr,
        mask: SSAValue | Operation | None = None,
    ):
        super().__init__(
            operands=[source, indices, padding, mask],
            result_types=[result_type],
            properties={"in_bounds": in_bounds, "permutation_map": permutation_map},
        )

    def print(self, printer: Printer):
        printer.print_string(" ", indent=0)
        printer.print_ssa_value(self.source)
        printer.print_string("[", indent=0)
        printer.print_list(self.indices, printer.print_ssa_value)
        printer.print_string("], ", indent=0)
        printer.print_ssa_value(self.padding)
        if self.mask is not None:
            printer.print_string(", ", indent=0)
            printer.print_ssa_value(self.mask)
        self._print_attrs(printer)
        printer.print_string(" : ", indent=0)
        printer.print_attribute(self.source.type)
        printer.print_string(", ", indent=0)
        printer.print_attribute(self.result.type)

    @classmethod
    def parse(cls, parser: Parser) -> TransferReadOp:
        source = parser.parse_unresolved_operand()
        indices = parser.parse_comma_separated_list(
            Parser.Delimiter.SQUARE, parser.parse_operand
        )
        parser.parse_punctuation(",")
        padding = parser.parse_operand()
        if parser.parse_optional_punctuation(","):
            mask_start_pos = parser.pos
            mask = parser.parse_unresolved_operand()
            mask_end_pos = parser.pos
        else:
            mask_start_pos = None
            mask = None
            mask_end_pos = None
        attributes_dict = parser.parse_optional_attr_dict()

        types_pos = parser.pos
        parser.parse_punctuation(":")
        shaped_type = parser.parse_type()
        parser.parse_punctuation(",")
        vector_type = parser.parse_type()

        source = parser.resolve_operand(source, shaped_type)

        if not isa(shaped_type, MemRefType | TensorType):
            parser.raise_error(
                "requires memref or ranked tensor type", at_position=types_pos
            )

        if not isa(vector_type, VectorType):
            parser.raise_error("requires vector type", at_position=types_pos)

        mask, permutation_map, in_bounds = VectorTransferOperation.resolve_attrs(
            parser,
            attributes_dict,
            shaped_type,
            vector_type,
            mask_start_pos,
            mask_end_pos,
            mask,
            types_pos,
        )

        # Create and return the TransferReadOp
        return TransferReadOp(
            source=source,
            indices=indices,
            padding=padding,
            mask=mask,
            permutation_map=permutation_map,
            in_bounds=in_bounds,
            result_type=vector_type,
        )

    def verify_(self):
        assert isa(self.source.type, MemRefType | TensorType)
        assert isa(self.result.type, VectorType)
        if self.mask:
            assert isa(self.mask.type, VectorType[I1])
            mask_type = self.mask.type
        else:
            mask_type = None

        if len(self.indices) != self.source.type.get_num_dims():
            raise VerifyException("Expected an index for each memref/tensor dimension.")

        if mask_type:
            inferred_mask_type = VectorTransferOperation.infer_transfer_op_mask_type(
                self.result.type,
                self.permutation_map.data,
            )
        else:
            inferred_mask_type = VectorType(i1, [])

        VectorTransferOperation.verify_op(
            self,
            self.source.type,
            self.result.type,
            mask_type,
            inferred_mask_type,
            self.permutation_map.data,
            self.in_bounds,
        )

        if isa(self.source.type.element_type, VectorType):
            # TODO verify vector element type
            pass
        else:
            # source memref/tensor has scalar element type
            # TODO verify that padding type is a valid element_type for a vector
            if self.source.type.element_type != self.padding.type:
                raise VerifyException(
                    f'"{self.name}" requires formal padding and source of the same elemental type'
                )

        VectorTransferOperation.verify_permutation_map(
            self,
            self.permutation_map.data,
        )

name = 'vector.transfer_read' class-attribute instance-attribute

source = operand_def(TensorType | MemRefType) class-attribute instance-attribute

indices = var_operand_def(IndexType) class-attribute instance-attribute

padding = operand_def() class-attribute instance-attribute

mask = opt_operand_def(VectorType[I1]) class-attribute instance-attribute

permutation_map = prop_def(AffineMapAttr) class-attribute instance-attribute

result = result_def(VectorType) class-attribute instance-attribute

irdl_options = (AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()) class-attribute instance-attribute

__init__(source: SSAValue | Operation, indices: Sequence[SSAValue | Operation], padding: SSAValue | Operation, result_type: Attribute, in_bounds: ArrayAttr[BoolAttr], permutation_map: AffineMapAttr, mask: SSAValue | Operation | None = None)

Source code in xdsl/dialects/vector.py
def __init__(
    self,
    source: SSAValue | Operation,
    indices: Sequence[SSAValue | Operation],
    padding: SSAValue | Operation,
    result_type: Attribute,
    in_bounds: ArrayAttr[BoolAttr],
    permutation_map: AffineMapAttr,
    mask: SSAValue | Operation | None = None,
):
    super().__init__(
        operands=[source, indices, padding, mask],
        result_types=[result_type],
        properties={"in_bounds": in_bounds, "permutation_map": permutation_map},
    )
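
The following sketch builds a transfer_read with the default minor identity permutation map and an out-of-bounds `in_bounds` entry. It is not taken from the library's test suite; block arguments stand in for real SSA values and the builtin import paths are assumed.

```python
from xdsl.dialects.builtin import (
    AffineMapAttr, ArrayAttr, BoolAttr, IndexType, MemRefType, VectorType, f32,
)
from xdsl.dialects.vector import TransferReadOp, VectorTransferOperation
from xdsl.ir import Block

memref_type = MemRefType(f32, [16, 16])
vector_type = VectorType(f32, [8])

# Placeholder values: the source memref, two indices and a padding scalar.
block = Block(arg_types=[memref_type, IndexType(), IndexType(), f32])
source, i, j, padding = block.args

permutation_map = AffineMapAttr(
    VectorTransferOperation.get_transfer_minor_identity_map(memref_type, vector_type)
)
in_bounds = ArrayAttr((BoolAttr.from_bool(False),))  # one entry per vector dimension

read = TransferReadOp(source, [i, j], padding, vector_type, in_bounds, permutation_map)
read.verify()
```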

print(printer: Printer)

Source code in xdsl/dialects/vector.py
def print(self, printer: Printer):
    printer.print_string(" ", indent=0)
    printer.print_ssa_value(self.source)
    printer.print_string("[", indent=0)
    printer.print_list(self.indices, printer.print_ssa_value)
    printer.print_string("], ", indent=0)
    printer.print_ssa_value(self.padding)
    if self.mask is not None:
        printer.print_string(", ", indent=0)
        printer.print_ssa_value(self.mask)
    self._print_attrs(printer)
    printer.print_string(" : ", indent=0)
    printer.print_attribute(self.source.type)
    printer.print_string(", ", indent=0)
    printer.print_attribute(self.result.type)

parse(parser: Parser) -> TransferReadOp classmethod

Source code in xdsl/dialects/vector.py
@classmethod
def parse(cls, parser: Parser) -> TransferReadOp:
    source = parser.parse_unresolved_operand()
    indices = parser.parse_comma_separated_list(
        Parser.Delimiter.SQUARE, parser.parse_operand
    )
    parser.parse_punctuation(",")
    padding = parser.parse_operand()
    if parser.parse_optional_punctuation(","):
        mask_start_pos = parser.pos
        mask = parser.parse_unresolved_operand()
        mask_end_pos = parser.pos
    else:
        mask_start_pos = None
        mask = None
        mask_end_pos = None
    attributes_dict = parser.parse_optional_attr_dict()

    types_pos = parser.pos
    parser.parse_punctuation(":")
    shaped_type = parser.parse_type()
    parser.parse_punctuation(",")
    vector_type = parser.parse_type()

    source = parser.resolve_operand(source, shaped_type)

    if not isa(shaped_type, MemRefType | TensorType):
        parser.raise_error(
            "requires memref or ranked tensor type", at_position=types_pos
        )

    if not isa(vector_type, VectorType):
        parser.raise_error("requires vector type", at_position=types_pos)

    mask, permutation_map, in_bounds = VectorTransferOperation.resolve_attrs(
        parser,
        attributes_dict,
        shaped_type,
        vector_type,
        mask_start_pos,
        mask_end_pos,
        mask,
        types_pos,
    )

    # Create and return the TransferReadOp
    return TransferReadOp(
        source=source,
        indices=indices,
        padding=padding,
        mask=mask,
        permutation_map=permutation_map,
        in_bounds=in_bounds,
        result_type=vector_type,
    )

verify_()

Source code in xdsl/dialects/vector.py
def verify_(self):
    assert isa(self.source.type, MemRefType | TensorType)
    assert isa(self.result.type, VectorType)
    if self.mask:
        assert isa(self.mask.type, VectorType[I1])
        mask_type = self.mask.type
    else:
        mask_type = None

    if len(self.indices) != self.source.type.get_num_dims():
        raise VerifyException("Expected an index for each memref/tensor dimension.")

    if mask_type:
        inferred_mask_type = VectorTransferOperation.infer_transfer_op_mask_type(
            self.result.type,
            self.permutation_map.data,
        )
    else:
        inferred_mask_type = VectorType(i1, [])

    VectorTransferOperation.verify_op(
        self,
        self.source.type,
        self.result.type,
        mask_type,
        inferred_mask_type,
        self.permutation_map.data,
        self.in_bounds,
    )

    if isa(self.source.type.element_type, VectorType):
        # TODO verify vector element type
        pass
    else:
        # source memref/tensor has scalar element type
        # TODO verify that padding type is a valid element_type for a vector
        if self.source.type.element_type != self.padding.type:
            raise VerifyException(
                f'"{self.name}" requires formal padding and source of the same elemental type'
            )

    VectorTransferOperation.verify_permutation_map(
        self,
        self.permutation_map.data,
    )

TransferWriteOp

Bases: VectorTransferOperation

Source code in xdsl/dialects/vector.py
@irdl_op_definition
class TransferWriteOp(VectorTransferOperation):
    name = "vector.transfer_write"

    vector = operand_def(VectorType)
    source = operand_def(TensorType | MemRefType)
    indices = var_operand_def(IndexType)
    mask = opt_operand_def(VectorType[I1])

    permutation_map = prop_def(AffineMapAttr)

    result = opt_result_def(TensorType)

    irdl_options = (
        AttrSizedOperandSegments(as_property=True),
        ParsePropInAttrDict(),
    )

    def __init__(
        self,
        vector: SSAValue | Operation,
        source: SSAValue | Operation,
        indices: Sequence[SSAValue | Operation],
        in_bounds: ArrayAttr[BoolAttr],
        mask: SSAValue | Operation | None = None,
        permutation_map: AffineMapAttr | None = None,
        result_type: TensorType | None = None,
    ):
        super().__init__(
            operands=[vector, source, indices, mask],
            properties={"in_bounds": in_bounds, "permutation_map": permutation_map},
            result_types=[result_type],
        )

    def print(self, printer: Printer):
        printer.print_string(" ", indent=0)
        printer.print_operand(self.vector)
        printer.print_string(", ", indent=0)
        printer.print_operand(self.source)
        printer.print_string("[", indent=0)
        printer.print_list(self.indices, printer.print_operand)
        printer.print_string("]", indent=0)
        if self.mask is not None:
            printer.print_string(", ", indent=0)
            printer.print_ssa_value(self.mask)
        self._print_attrs(printer)
        printer.print_string(" : ", indent=0)
        printer.print_attribute(self.vector.type)
        printer.print_string(", ", indent=0)
        printer.print_attribute(self.source.type)

    @classmethod
    def parse(cls, parser: Parser) -> TransferWriteOp:
        vector = parser.parse_unresolved_operand()
        parser.parse_punctuation(",")
        source = parser.parse_unresolved_operand()
        indices = parser.parse_comma_separated_list(
            Parser.Delimiter.SQUARE, parser.parse_operand
        )
        if parser.parse_optional_punctuation(","):
            mask_start_pos = parser.pos
            mask = parser.parse_unresolved_operand()
            mask_end_pos = parser.pos
        else:
            mask_start_pos = None
            mask = None
            mask_end_pos = None
        attributes_dict = parser.parse_optional_attr_dict()

        types_pos = parser.pos
        parser.parse_punctuation(":")
        vector_type = parser.parse_type()
        parser.parse_punctuation(",")
        shaped_type = parser.parse_type()

        vector = parser.resolve_operand(vector, vector_type)
        source = parser.resolve_operand(source, shaped_type)

        if not isa(shaped_type, MemRefType | TensorType):
            parser.raise_error(
                "requires memref or ranked tensor type", at_position=types_pos
            )

        if not isa(vector_type, VectorType):
            parser.raise_error("requires vector type", at_position=types_pos)

        mask, permutation_map, in_bounds = VectorTransferOperation.resolve_attrs(
            parser,
            attributes_dict,
            shaped_type,
            vector_type,
            mask_start_pos,
            mask_end_pos,
            mask,
            types_pos,
        )

        # Create and return the TransferWriteOp
        return TransferWriteOp(
            vector=vector,
            source=source,
            indices=indices,
            mask=mask,
            permutation_map=permutation_map,
            in_bounds=in_bounds,
            result_type=shaped_type if isinstance(shaped_type, TensorType) else None,
        )

    def verify_(self):
        assert isa(self.source.type, MemRefType | TensorType)
        assert isa(self.vector.type, VectorType)
        if self.mask:
            assert isa(self.mask.type, VectorType[I1])
            mask_type = self.mask.type
        else:
            mask_type = None

        if len(self.indices) != self.source.type.get_num_dims():
            raise VerifyException("Expected an index for each memref/tensor dimension.")

        if self.has_broadcast_dim():
            raise VerifyException(
                f'"{self.name}" should not have broadcast dimensions.'
            )

        if mask_type:
            inferred_mask_type = VectorTransferOperation.infer_transfer_op_mask_type(
                self.vector.type,
                self.permutation_map.data,
            )
        else:
            inferred_mask_type = VectorType(i1, [])

        VectorTransferOperation.verify_op(
            self,
            self.source.type,
            self.vector.type,
            mask_type,
            inferred_mask_type,
            self.permutation_map.data,
            self.in_bounds,
        )

        VectorTransferOperation.verify_permutation_map(
            self,
            self.permutation_map.data,
        )

name = 'vector.transfer_write' class-attribute instance-attribute

vector = operand_def(VectorType) class-attribute instance-attribute

source = operand_def(TensorType | MemRefType) class-attribute instance-attribute

indices = var_operand_def(IndexType) class-attribute instance-attribute

mask = opt_operand_def(VectorType[I1]) class-attribute instance-attribute

permutation_map = prop_def(AffineMapAttr) class-attribute instance-attribute

result = opt_result_def(TensorType) class-attribute instance-attribute

irdl_options = (AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()) class-attribute instance-attribute

__init__(vector: SSAValue | Operation, source: SSAValue | Operation, indices: Sequence[SSAValue | Operation], in_bounds: ArrayAttr[BoolAttr], mask: SSAValue | Operation | None = None, permutation_map: AffineMapAttr | None = None, result_type: TensorType | None = None)

Source code in xdsl/dialects/vector.py
def __init__(
    self,
    vector: SSAValue | Operation,
    source: SSAValue | Operation,
    indices: Sequence[SSAValue | Operation],
    in_bounds: ArrayAttr[BoolAttr],
    mask: SSAValue | Operation | None = None,
    permutation_map: AffineMapAttr | None = None,
    result_type: TensorType | None = None,
):
    super().__init__(
        operands=[vector, source, indices, mask],
        properties={"in_bounds": in_bounds, "permutation_map": permutation_map},
        result_types=[result_type],
    )
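
Mirroring the read example above, a sketch of a transfer_write into a memref (same assumptions: placeholder block arguments and assumed import paths; no result type is needed since the destination is a memref rather than a tensor):

```python
from xdsl.dialects.builtin import (
    AffineMapAttr, ArrayAttr, BoolAttr, IndexType, MemRefType, VectorType, f32,
)
from xdsl.dialects.vector import TransferWriteOp, VectorTransferOperation
from xdsl.ir import Block

memref_type = MemRefType(f32, [16, 16])
vector_type = VectorType(f32, [8])

# Placeholder values: the vector to write, the destination memref and two indices.
block = Block(arg_types=[vector_type, memref_type, IndexType(), IndexType()])
vector, source, i, j = block.args

write = TransferWriteOp(
    vector,
    source,
    [i, j],
    in_bounds=ArrayAttr((BoolAttr.from_bool(False),)),
    permutation_map=AffineMapAttr(
        VectorTransferOperation.get_transfer_minor_identity_map(memref_type, vector_type)
    ),
)
write.verify()
```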

print(printer: Printer)

Source code in xdsl/dialects/vector.py
def print(self, printer: Printer):
    printer.print_string(" ", indent=0)
    printer.print_operand(self.vector)
    printer.print_string(", ", indent=0)
    printer.print_operand(self.source)
    printer.print_string("[", indent=0)
    printer.print_list(self.indices, printer.print_operand)
    printer.print_string("]", indent=0)
    if self.mask is not None:
        printer.print_string(", ", indent=0)
        printer.print_ssa_value(self.mask)
    self._print_attrs(printer)
    printer.print_string(" : ", indent=0)
    printer.print_attribute(self.vector.type)
    printer.print_string(", ", indent=0)
    printer.print_attribute(self.source.type)

parse(parser: Parser) -> TransferWriteOp classmethod

Source code in xdsl/dialects/vector.py
@classmethod
def parse(cls, parser: Parser) -> TransferWriteOp:
    vector = parser.parse_unresolved_operand()
    parser.parse_punctuation(",")
    source = parser.parse_unresolved_operand()
    indices = parser.parse_comma_separated_list(
        Parser.Delimiter.SQUARE, parser.parse_operand
    )
    if parser.parse_optional_punctuation(","):
        mask_start_pos = parser.pos
        mask = parser.parse_unresolved_operand()
        mask_end_pos = parser.pos
    else:
        mask_start_pos = None
        mask = None
        mask_end_pos = None
    attributes_dict = parser.parse_optional_attr_dict()

    types_pos = parser.pos
    parser.parse_punctuation(":")
    vector_type = parser.parse_type()
    parser.parse_punctuation(",")
    shaped_type = parser.parse_type()

    vector = parser.resolve_operand(vector, vector_type)
    source = parser.resolve_operand(source, shaped_type)

    if not isa(shaped_type, MemRefType | TensorType):
        parser.raise_error(
            "requires memref or ranked tensor type", at_position=types_pos
        )

    if not isa(vector_type, VectorType):
        parser.raise_error("requires vector type", at_position=types_pos)

    mask, permutation_map, in_bounds = VectorTransferOperation.resolve_attrs(
        parser,
        attributes_dict,
        shaped_type,
        vector_type,
        mask_start_pos,
        mask_end_pos,
        mask,
        types_pos,
    )

    # Create and return the TransferWriteOp
    return TransferWriteOp(
        vector=vector,
        source=source,
        indices=indices,
        mask=mask,
        permutation_map=permutation_map,
        in_bounds=in_bounds,
        result_type=shaped_type if isinstance(shaped_type, TensorType) else None,
    )

verify_()

Source code in xdsl/dialects/vector.py, lines 1435-1473
def verify_(self):
    assert isa(self.source.type, MemRefType | TensorType)
    assert isa(self.vector.type, VectorType)
    if self.mask:
        assert isa(self.mask.type, VectorType[I1])
        mask_type = self.mask.type
    else:
        mask_type = None

    if len(self.indices) != self.source.type.get_num_dims():
        raise VerifyException("Expected an index for each memref/tensor dimension.")

    if self.has_broadcast_dim():
        raise VerifyException(
            f'"{self.name}" should not have broadcast dimensions.'
        )

    if mask_type:
        inferred_mask_type = VectorTransferOperation.infer_transfer_op_mask_type(
            self.vector.type,
            self.permutation_map.data,
        )
    else:
        inferred_mask_type = VectorType(i1, [])

    VectorTransferOperation.verify_op(
        self,
        self.source.type,
        self.vector.type,
        mask_type,
        inferred_mask_type,
        self.permutation_map.data,
        self.in_bounds,
    )

    VectorTransferOperation.verify_permutation_map(
        self,
        self.permutation_map.data,
    )

CombiningKindFlag

Bases: StrEnum

Values specifying the kind of combining operation.

Source code in xdsl/dialects/vector.py, lines 1476-1493
class CombiningKindFlag(StrEnum):
    """
    Values specifying the kind of combining operation.
    """

    ADD = "add"
    MUL = "mul"
    MINUI = "minui"
    MINSI = "minsi"
    MINNUMF = "minnumf"
    MAXUI = "maxui"
    MAXSI = "maxsi"
    MAXNUMF = "maxnumf"
    AND = "and"
    OR = "or"
    XOR = "xor"
    MAXIMUMF = "maximumf"
    MINIMUMF = "minimumf"

ADD = 'add' class-attribute instance-attribute

MUL = 'mul' class-attribute instance-attribute

MINUI = 'minui' class-attribute instance-attribute

MINSI = 'minsi' class-attribute instance-attribute

MINNUMF = 'minnumf' class-attribute instance-attribute

MAXUI = 'maxui' class-attribute instance-attribute

MAXSI = 'maxsi' class-attribute instance-attribute

MAXNUMF = 'maxnumf' class-attribute instance-attribute

AND = 'and' class-attribute instance-attribute

OR = 'or' class-attribute instance-attribute

XOR = 'xor' class-attribute instance-attribute

MAXIMUMF = 'maximumf' class-attribute instance-attribute

MINIMUMF = 'minimumf' class-attribute instance-attribute

CombiningKindAttr dataclass

Bases: EnumAttribute[CombiningKindFlag]

A mirror of LLVM's vector.kind attribute.

Source code in xdsl/dialects/vector.py, lines 1496-1511
@irdl_attr_definition
class CombiningKindAttr(EnumAttribute[CombiningKindFlag]):
    """
    A mirror of LLVM's vector.kind attribute.
    """

    name = "vector.kind"

    def print_parameter(self, printer: Printer) -> None:
        with printer.in_angle_brackets():
            printer.print_string(self.data)

    @classmethod
    def parse_parameter(cls, parser: AttrParser) -> CombiningKindFlag:
        with parser.in_angle_brackets():
            return CombiningKindFlag(parser.parse_identifier())

name = 'vector.kind' class-attribute instance-attribute

print_parameter(printer: Printer) -> None

Source code in xdsl/dialects/vector.py, lines 1504-1506
def print_parameter(self, printer: Printer) -> None:
    with printer.in_angle_brackets():
        printer.print_string(self.data)

parse_parameter(parser: AttrParser) -> CombiningKindFlag classmethod

Source code in xdsl/dialects/vector.py, lines 1508-1511
@classmethod
def parse_parameter(cls, parser: AttrParser) -> CombiningKindFlag:
    with parser.in_angle_brackets():
        return CombiningKindFlag(parser.parse_identifier())

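Taken together, the attribute name and the angle-bracketed parameter mean the attribute round-trips through text roughly as #vector.kind<add>. Constructing it in Python is a one-liner; this sketch assumes the generic EnumAttribute constructor accepts the flag directly.

from xdsl.dialects.vector import CombiningKindAttr, CombiningKindFlag

kind = CombiningKindAttr(CombiningKindFlag.ADD)  # prints as #vector.kind<add>
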
ReductionOp

Bases: IRDLOperation

Source code in xdsl/dialects/vector.py, lines 1514-1543
@irdl_op_definition
class ReductionOp(IRDLOperation):
    name = "vector.reduction"

    _T: ClassVar = VarConstraint("T", AnyAttr())

    vector = operand_def(VectorType.constr(_T))
    acc = opt_operand_def(_T)
    dest = result_def(_T)
    kind = prop_def(CombiningKindAttr)
    fastmath = prop_def(FastMathFlagsAttr, default_value=FastMathFlagsAttr("none"))

    assembly_format = "$kind `,` $vector (`,` $acc^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($vector) `into` type($dest)"

    def __init__(
        self,
        vector: SSAValue | Operation,
        kind: CombiningKindAttr,
        acc: SSAValue | Operation | None = None,
        fastmath: FastMathFlagsAttr | None = None,
    ):
        vector = SSAValue.get(vector)
        super().__init__(
            operands=[vector, acc],
            result_types=[vector.type],
            properties={
                "kind": kind,
                "fastmath": fastmath,
            },
        )

name = 'vector.reduction' class-attribute instance-attribute

vector = operand_def(VectorType.constr(_T)) class-attribute instance-attribute

acc = opt_operand_def(_T) class-attribute instance-attribute

dest = result_def(_T) class-attribute instance-attribute

kind = prop_def(CombiningKindAttr) class-attribute instance-attribute

fastmath = prop_def(FastMathFlagsAttr, default_value=(FastMathFlagsAttr('none'))) class-attribute instance-attribute

assembly_format = '$kind `,` $vector (`,` $acc^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($vector) `into` type($dest)' class-attribute instance-attribute

__init__(vector: SSAValue | Operation, kind: CombiningKindAttr, acc: SSAValue | Operation | None = None, fastmath: FastMathFlagsAttr | None = None)

Source code in xdsl/dialects/vector.py, lines 1528-1543
def __init__(
    self,
    vector: SSAValue | Operation,
    kind: CombiningKindAttr,
    acc: SSAValue | Operation | None = None,
    fastmath: FastMathFlagsAttr | None = None,
):
    vector = SSAValue.get(vector)
    super().__init__(
        operands=[vector, acc],
        result_types=[vector.type],
        properties={
            "kind": kind,
            "fastmath": fastmath,
        },
    )

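A minimal construction sketch: the vector operand is an assumed pre-existing SSAValue, and the accumulator and fastmath flags are left at their defaults.

from xdsl.dialects.vector import CombiningKindAttr, CombiningKindFlag, ReductionOp

# `vec_val` is assumed to be an SSAValue of some vector type, e.g. vector<16xf32>.
red = ReductionOp(vec_val, CombiningKindAttr(CombiningKindFlag.ADD))
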
BitcastOp

Bases: IRDLOperation

Bitcast between vectors.

See the external documentation: https://mlir.llvm.org/docs/Dialects/Vector/#vectorbitcast-vectorbitcastop

Source code in xdsl/dialects/vector.py, lines 1546-1609
@irdl_op_definition
class BitcastOp(IRDLOperation):
    """
    Bitcast between vectors.

    See [external documentation](https://mlir.llvm.org/docs/Dialects/Vector/#vectorbitcast-vectorbitcastop).
    """

    name = "vector.bitcast"

    source = operand_def(
        VectorType.constr(base(IntegerType) | base(IndexType) | AnyFloatConstr)
    )
    result = result_def(
        VectorType.constr(base(IntegerType) | base(IndexType) | AnyFloatConstr)
    )

    assembly_format = "$source attr-dict `:` type($source) `to` type($result)"

    def __init__(
        self,
        source: SSAValue | Operation,
        result_type: Attribute,
    ):
        super().__init__(
            operands=[source],
            result_types=[result_type],
        )

    def verify_(self) -> None:
        s_t = self.source.type
        r_t = self.result.type

        assert isa(s_t, VectorType)
        assert isa(r_t, VectorType)

        s_elem_t = s_t.get_element_type()
        r_elem_t = r_t.get_element_type()
        s_shape = s_t.get_shape()
        r_shape = r_t.get_shape()

        # technically only support index -> index conversions if sizes unknown,
        # and they must have the same shape
        s_elem_t_sized = isinstance(s_elem_t, FixedBitwidthType)
        r_elem_t_sized = isinstance(r_elem_t, FixedBitwidthType)

        if not s_elem_t_sized or not r_elem_t_sized:
            # if they are both unsized and have the same shape
            if not (s_elem_t_sized ^ r_elem_t_sized) and s_shape == r_shape:
                return

            raise VerifyException(
                "For element types of undefined bitwidth, expect "
                + "both types to have undefined bitwidth and shape to be equal"
            )

        source_size = prod(s_shape) * s_elem_t.bitwidth
        result_size = prod(r_shape) * r_elem_t.bitwidth

        # if sizes are known, they must match perfectly
        if not source_size == result_size:
            raise VerifyException(
                "The source and result types do not have an equal bitwidth"
            )

name = 'vector.bitcast' class-attribute instance-attribute

source = operand_def(VectorType.constr(base(IntegerType) | base(IndexType) | AnyFloatConstr)) class-attribute instance-attribute

result = result_def(VectorType.constr(base(IntegerType) | base(IndexType) | AnyFloatConstr)) class-attribute instance-attribute

assembly_format = '$source attr-dict `:` type($source) `to` type($result)' class-attribute instance-attribute

__init__(source: SSAValue | Operation, result_type: Attribute)

Source code in xdsl/dialects/vector.py, lines 1565-1573
def __init__(
    self,
    source: SSAValue | Operation,
    result_type: Attribute,
):
    super().__init__(
        operands=[source],
        result_types=[result_type],
    )

verify_() -> None

Source code in xdsl/dialects/vector.py, lines 1575-1609
def verify_(self) -> None:
    s_t = self.source.type
    r_t = self.result.type

    assert isa(s_t, VectorType)
    assert isa(r_t, VectorType)

    s_elem_t = s_t.get_element_type()
    r_elem_t = r_t.get_element_type()
    s_shape = s_t.get_shape()
    r_shape = r_t.get_shape()

    # technically only support index -> index conversions if sizes unknown,
    # and they must have the same shape
    s_elem_t_sized = isinstance(s_elem_t, FixedBitwidthType)
    r_elem_t_sized = isinstance(r_elem_t, FixedBitwidthType)

    if not s_elem_t_sized or not r_elem_t_sized:
        # if they are both unsized and have the same shape
        if not (s_elem_t_sized ^ r_elem_t_sized) and s_shape == r_shape:
            return

        raise VerifyException(
            "For element types of undefined bitwidth, expect "
            + "both types to have undefined bitwidth and shape to be equal"
        )

    source_size = prod(s_shape) * s_elem_t.bitwidth
    result_size = prod(r_shape) * r_elem_t.bitwidth

    # if sizes are known, they must match perfectly
    if not source_size == result_size:
        raise VerifyException(
            "The source and result types do not have an equal bitwidth"
        )
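The bitwidth check above is easiest to see with concrete shapes: vector<4xi32> and vector<8xi16> both occupy 128 bits, so a bitcast between them verifies, while vector<4xi32> to vector<4xi16> would not. A construction sketch under that assumption, with the source value as a placeholder SSAValue:

from xdsl.dialects.builtin import IntegerType, VectorType
from xdsl.dialects.vector import BitcastOp

# `src_val` is assumed to be an SSAValue of type vector<4xi32> (4 * 32 = 128 bits).
cast = BitcastOp(src_val, VectorType(IntegerType(16), [8]))  # 8 * 16 = 128 bits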