Skip to content

Reshape ops utils

reshape_ops_utils

Utilities used by reshape ops. See the MLIR counterpart for more details.

ArrayOfIntArrayAttr = ArrayAttr[ArrayAttr[IntegerAttr]] module-attribute

ContiguousArrayOfIntArray dataclass

Bases: AttrConstraint[ArrayOfIntArrayAttr]

Enforce that an ArrayAttr of ArrayAttr[IntegerAttr] contains contiguous integer values across all inner arrays, read in order as one flattened sequence. For example, [[0, 1], [2, 3]] is valid, but [[3, 4], [0, 1]] is not. An empty inner array is considered contiguous.

Source code in xdsl/dialects/utils/reshape_ops_utils.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
@dataclass(frozen=True)
class ContiguousArrayOfIntArray(AttrConstraint[ArrayOfIntArrayAttr]):
    """
    Enforce that an ArrayAttr of ArrayAttr[IntegerAttr] contains contiguous
    integer values when all inner arrays are read in order as one flattened
    sequence.

    For example, [[0, 1], [2, 3]] is valid, but [[3, 4], [0, 1]] is not.
    An empty inner array is considered contiguous (it contributes no values).
    Only consecutive differences are checked; the first value may be anything.
    """

    def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
        # Reject anything that is not an array of integer arrays before
        # inspecting the values.
        _CONTIGUOUS_ARRAY_TYPE_CONSTRAINT.verify(
            attr, constraint_context=constraint_context
        )
        array_attr = cast(ArrayOfIntArrayAttr, attr)

        indices = [
            index_attr.value.data
            for group in array_attr.data
            for index_attr in group.data
        ]
        # Every value must be exactly one more than the value before it.
        if any(nxt != cur + 1 for cur, nxt in zip(indices, indices[1:])):
            raise VerifyException(
                f"All inner arrays must be contiguous: {array_attr}"
            )

    def mapping_type_vars(
        self, type_var_mapping: Mapping[TypeVar, AttrConstraint | IntConstraint]
    ) -> "ContiguousArrayOfIntArray":
        """Return this constraint unchanged; it has no type variables to map."""
        return self

__init__() -> None

verify(attr: Attribute, constraint_context: ConstraintContext) -> None

Source code in xdsl/dialects/utils/reshape_ops_utils.py
53
54
55
56
57
58
59
60
61
62
63
64
def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
    # Reject anything that is not an array of integer arrays before
    # inspecting the values.
    _CONTIGUOUS_ARRAY_TYPE_CONSTRAINT.verify(
        attr, constraint_context=constraint_context
    )
    array_attr = cast(ArrayOfIntArrayAttr, attr)

    indices = [
        index_attr.value.data
        for group in array_attr.data
        for index_attr in group.data
    ]
    # Every value must be exactly one more than the value before it; the
    # first value itself is not constrained.
    if any(nxt != cur + 1 for cur, nxt in zip(indices, indices[1:])):
        raise VerifyException(
            f"All inner arrays must be contiguous: {array_attr}"
        )

mapping_type_vars(type_var_mapping: Mapping[TypeVar, AttrConstraint | IntConstraint]) -> ContiguousArrayOfIntArray

Source code in xdsl/dialects/utils/reshape_ops_utils.py
66
67
68
69
70
def mapping_type_vars(
    self, type_var_mapping: Mapping[TypeVar, AttrConstraint | IntConstraint]
) -> "ContiguousArrayOfIntArray":
    """Return this constraint unchanged; it has no type variables to map."""
    return self

verify_reshape_like_types(collapsed_type: ShapedType, expanded_type: ShapedType, reassociation: ArrayAttr[ArrayAttr[IntegerAttr]])

Verify that the collapsed and expanded types conform to the reassociation mapping.

Source code in xdsl/dialects/utils/reshape_ops_utils.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def verify_reshape_like_types(
    collapsed_type: ShapedType,
    expanded_type: ShapedType,
    reassociation: ArrayAttr[ArrayAttr[IntegerAttr]],
):
    """
    Verify that collapsed and expanded types conform to reassociation mapping.

    The following checks are performed, in order:

    1. the expanded type's rank is at least the collapsed type's rank;
    2. there is exactly one reassociation group per collapsed dimension;
    3. the groups together contain as many indices as the expanded rank;
    4. read in order, the group indices are exactly 0, 1, ..., expanded_rank - 1,
       so every index names an existing expanded dimension;
    5. the shapes are compatible, as checked by
       `verify_reshape_like_shapes_are_compatible`.

    Raises:
        VerifyException: if any of the checks above fails.
    """
    collapsed_shape = collapsed_type.get_shape()
    expanded_shape = expanded_type.get_shape()
    expanded_rank = len(expanded_shape)
    collapsed_rank = len(collapsed_shape)

    if expanded_rank < collapsed_rank:
        raise VerifyException(
            f"expected the expanded type, {expanded_type} to have a higher (or same) rank "
            f"than the collapsed type, {collapsed_type}."
        )

    if collapsed_rank != len(reassociation):
        raise VerifyException(
            f"expected collapsed rank ({collapsed_rank}) to equal the number of "
            f"reassociation maps ({len(reassociation)})."
        )

    # Check that the total reassociation dimensions match the expanded type's rank.
    total_reassociation_dims = sum(len(rm) for rm in reassociation)
    if total_reassociation_dims != expanded_rank:
        raise VerifyException(
            f"expected the total number of reassociation dimensions ({total_reassociation_dims}) "
            f"to equal the expanded type's rank ({expanded_rank})."
        )

    # Pairwise contiguity plus the right count does not anchor the indices at
    # 0: e.g. [[1, 2]] for a rank-2 expanded type passes both, yet refers to a
    # non-existent dimension 2. Require the indices, read in order, to be
    # exactly 0..expanded_rank-1 and report the first offending group.
    next_expected_dim = 0
    for map_idx, rm in enumerate(reassociation):
        for index_attr in rm.data:
            if index_attr.value.data != next_expected_dim:
                raise VerifyException(
                    f"expected reassociation map #{map_idx} to be valid and contiguous."
                )
            next_expected_dim += 1

    verify_reshape_like_shapes_are_compatible(
        collapsed_shape=collapsed_shape,
        expanded_shape=expanded_shape,
        reassociation=reassociation,
    )

verify_reshape_like_shapes_are_compatible(collapsed_shape: tuple[int, ...], expanded_shape: tuple[int, ...], reassociation: ArrayOfIntArrayAttr)

Verify that the collapsed and expanded shapes adhere to the reassociation mapping.

Source code in xdsl/dialects/utils/reshape_ops_utils.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def verify_reshape_like_shapes_are_compatible(
    collapsed_shape: tuple[int, ...],
    expanded_shape: tuple[int, ...],
    reassociation: ArrayOfIntArrayAttr,
):
    """
    Verify that collapsed and expanded shapes adhere to reassociation mapping.

    Reassociation group ``i`` covers the next ``len(group)`` dimensions of
    ``expanded_shape``, consumed left to right. For each group:

    * if any of those expanded dimensions equals ``DYNAMIC_INDEX``, collapsed
      dimension ``i`` must also equal ``DYNAMIC_INDEX``;
    * otherwise collapsed dimension ``i`` must equal the product of those
      expanded dimensions (1 for an empty group).

    Only the group sizes are used here, not the index values inside the groups.

    Raises:
        VerifyException: for the first group that violates the rules above.
    """
    offset = 0
    for group_idx, group in enumerate(reassociation):
        group_size = len(group)
        window = expanded_shape[offset : offset + group_size]
        offset += group_size

        has_dynamic = False
        product = 1
        for extent in window:
            if extent == DYNAMIC_INDEX:
                has_dynamic = True
                continue
            product *= extent

        collapsed_dim = collapsed_shape[group_idx]
        if has_dynamic and collapsed_dim != DYNAMIC_INDEX:
            # A dynamic expanded dimension makes the collapsed one dynamic too.
            raise VerifyException(
                f"expected dimension {group_idx} of collapsed type to be dynamic "
                "since one or more of the corresponding dimensions in the "
                "expanded type is dynamic"
            )
        if not has_dynamic and collapsed_dim != product:
            # Fully static group: sizes must multiply out exactly.
            raise VerifyException(
                f"expected dimension {group_idx} of collapsed type to be static "
                f"value of {product}"
            )