Skip to content

Memref stream legalize

memref_stream_legalize

`StreamingVectorLegalizationType: TypeAlias = VectorType[Float64Type | Float32Type | Float16Type]` (module attribute)

`StreamingAlreadyLegalType: TypeAlias = Float64Type` (module attribute)

`MemRefStreamGenericLegalize` (dataclass)

Bases: RewritePattern

Source code in xdsl/transforms/memref_stream_legalize.py
(lines 118–191)
@dataclass(frozen=True)
class MemRefStreamGenericLegalize(RewritePattern):
    """
    Rewrite pattern that legalizes the block arguments of a
    `memref_stream.generic` for streaming.

    Every block argument whose type `_legalize_attr` maps to something other
    than `StreamingAlreadyLegalType` is retyped to that (vector) type. The
    innermost iteration bound is divided by the vector lane count, and the
    innermost dimension of every indexing map is scaled by the same count.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: memref_stream.GenericOp, rewriter: PatternRewriter
    ) -> None:
        """
        Replace `op` with a copy whose payload, bounds and indexing maps have
        been adjusted to the legal vector types of its block arguments.

        Returns without rewriting when no block argument needs legalization.

        Raises:
            DiagnosticException: the innermost iterator is not 'parallel'.
            ValueError: a lane count does not divide the innermost bound.
            NotImplementedError: arguments need different lane counts.
        """
        # Block-argument index -> legal type, only for arguments that are
        # not already legal.
        legalizations: dict[int, StreamingVectorLegalizationType] = {}
        for arg_index, block_arg in enumerate(op.body.block.args):
            candidate = _legalize_attr(block_arg.type)
            if isinstance(candidate, StreamingAlreadyLegalType):
                continue
            legalizations[arg_index] = candidate

        if not legalizations:
            # Nothing to rewrite.
            return

        # Only the innermost dimension is vectorized, so it must be parallel.
        if op.iterator_types.data[-1].data != IteratorType.PARALLEL:
            raise DiagnosticException(
                "iterators other than 'parallel' are not supported yet"
            )

        # Every lane count must evenly divide the innermost bound.
        innermost_bound = op.bounds.data[-1].value.data
        lane_counts: set[int] = set()
        for i, vector_type in legalizations.items():
            n_lanes: int = vector_type.get_shape()[0]
            if innermost_bound % n_lanes:
                raise ValueError(
                    f"no. of vector lanes ({n_lanes}) introduced to legalize argument #{i} "
                    f"is not a divisor for the innermost dimension's bound ({innermost_bound})"
                )
            lane_counts.add(n_lanes)
        if len(lane_counts) != 1:
            # FIXME we should deal with heterogeneous generic ops
            raise NotImplementedError(
                "cannot legalize heterogeneous block arguments yet"
            )
        (vlen,) = lane_counts

        # The innermost loop now steps over whole vectors.
        kept_bounds = list(op.bounds)[:-1]
        new_bounds = kept_bounds + [
            IntegerAttr.from_index_int_value(innermost_bound // vlen)
        ]

        # Substitute d_last -> d_last * vlen in every indexing map so that
        # each vectorized iteration addresses the first element of its vector.
        n_dims = len(op.bounds)
        scaled_dims = [AffineExpr.dimension(d) for d in range(n_dims - 1)]
        scaled_dims.append(AffineExpr.dimension(n_dims - 1) * vlen)
        new_maps = tuple(
            AffineMapAttr(
                indexing_map.data.replace_dims_and_symbols(
                    scaled_dims, (), n_dims, 0
                )
            )
            for indexing_map in op.indexing_maps
        )

        # Retype the selected arguments of a cloned body. The users of the
        # retyped arguments seed the payload legalization.
        new_body = op.body.clone()
        seed_ops: set[Operation] = set()
        for arg_index, block_arg in enumerate(new_body.block.args):
            if arg_index in legalizations:
                retyped = rewriter.replace_value_with_new_type(
                    block_arg, legalizations[arg_index]
                )
                seed_ops.update(use.operation for use in retyped.uses)
        _legalize_block(new_body.block, seed_ops, rewriter)

        legalized_op = memref_stream.GenericOp(
            op.inputs,
            op.outputs,
            op.inits,
            new_body,
            ArrayAttr(new_maps),
            op.iterator_types,
            ArrayAttr(new_bounds),
            op.init_indices,
        )
        rewriter.replace_op(op, legalized_op)

__init__() -> None

match_and_rewrite(op: memref_stream.GenericOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/memref_stream_legalize.py
(lines 120–191)
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: memref_stream.GenericOp, rewriter: PatternRewriter
) -> None:
    """
    Replace `op` with a copy whose payload, bounds and indexing maps have
    been adjusted to the legal vector types of its block arguments.

    Returns without rewriting when no block argument needs legalization.

    Raises:
        DiagnosticException: the innermost iterator is not 'parallel'.
        ValueError: a lane count does not divide the innermost bound.
        NotImplementedError: arguments need different lane counts.
    """
    # Block-argument index -> legal type, only for arguments that are
    # not already legal.
    legalizations: dict[int, StreamingVectorLegalizationType] = {}
    for arg_index, block_arg in enumerate(op.body.block.args):
        candidate = _legalize_attr(block_arg.type)
        if isinstance(candidate, StreamingAlreadyLegalType):
            continue
        legalizations[arg_index] = candidate

    if not legalizations:
        # Nothing to rewrite.
        return

    # Only the innermost dimension is vectorized, so it must be parallel.
    if op.iterator_types.data[-1].data != IteratorType.PARALLEL:
        raise DiagnosticException(
            "iterators other than 'parallel' are not supported yet"
        )

    # Every lane count must evenly divide the innermost bound.
    innermost_bound = op.bounds.data[-1].value.data
    lane_counts: set[int] = set()
    for i, vector_type in legalizations.items():
        n_lanes: int = vector_type.get_shape()[0]
        if innermost_bound % n_lanes:
            raise ValueError(
                f"no. of vector lanes ({n_lanes}) introduced to legalize argument #{i} "
                f"is not a divisor for the innermost dimension's bound ({innermost_bound})"
            )
        lane_counts.add(n_lanes)
    if len(lane_counts) != 1:
        # FIXME we should deal with heterogeneous generic ops
        raise NotImplementedError(
            "cannot legalize heterogeneous block arguments yet"
        )
    (vlen,) = lane_counts

    # The innermost loop now steps over whole vectors.
    kept_bounds = list(op.bounds)[:-1]
    new_bounds = kept_bounds + [
        IntegerAttr.from_index_int_value(innermost_bound // vlen)
    ]

    # Substitute d_last -> d_last * vlen in every indexing map so that each
    # vectorized iteration addresses the first element of its vector.
    n_dims = len(op.bounds)
    scaled_dims = [AffineExpr.dimension(d) for d in range(n_dims - 1)]
    scaled_dims.append(AffineExpr.dimension(n_dims - 1) * vlen)
    new_maps = tuple(
        AffineMapAttr(
            indexing_map.data.replace_dims_and_symbols(scaled_dims, (), n_dims, 0)
        )
        for indexing_map in op.indexing_maps
    )

    # Retype the selected arguments of a cloned body. The users of the
    # retyped arguments seed the payload legalization.
    new_body = op.body.clone()
    seed_ops: set[Operation] = set()
    for arg_index, block_arg in enumerate(new_body.block.args):
        if arg_index in legalizations:
            retyped = rewriter.replace_value_with_new_type(
                block_arg, legalizations[arg_index]
            )
            seed_ops.update(use.operation for use in retyped.uses)
    _legalize_block(new_body.block, seed_ops, rewriter)

    legalized_op = memref_stream.GenericOp(
        op.inputs,
        op.outputs,
        op.inits,
        new_body,
        ArrayAttr(new_maps),
        op.iterator_types,
        ArrayAttr(new_bounds),
        op.init_indices,
    )
    rewriter.replace_op(op, legalized_op)

`MemRefStreamLegalizePass` (dataclass)

Bases: ModulePass

Legalize memref_stream.generic payload and bounds for streaming.

Source code in xdsl/transforms/memref_stream_legalize.py
(lines 194–206)
@dataclass(frozen=True)
class MemRefStreamLegalizePass(ModulePass):
    """
    Legalize memref_stream.generic payload and bounds for streaming.

    Walks the module once with `MemRefStreamGenericLegalize`, passing
    `apply_recursively=False` to the walker.
    """

    name = "memref-stream-legalize"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        # A greedy applier holding the single legalization pattern.
        applier = GreedyRewritePatternApplier([MemRefStreamGenericLegalize()])
        walker = PatternRewriteWalker(applier, apply_recursively=False)
        walker.rewrite_module(op)

`name = 'memref-stream-legalize'` (class attribute, instance attribute)

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/memref_stream_legalize.py
(lines 202–206)
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Run `MemRefStreamGenericLegalize` over `op` with a single walk."""
    # A greedy applier holding the single legalization pattern.
    applier = GreedyRewritePatternApplier([MemRefStreamGenericLegalize()])
    walker = PatternRewriteWalker(applier, apply_recursively=False)
    walker.rewrite_module(op)