Skip to content

Stencil tensorize z dimension

stencil_tensorize_z_dimension

StencilTypeConversion dataclass

Bases: TypeConversionPattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
122
123
124
125
class StencilTypeConversion(TypeConversionPattern):
    """
    Type-conversion pattern that maps every `stencil.field` type it meets to
    the form returned by `stencil_field_to_tensor`.
    """

    @attr_type_rewrite_pattern
    def convert_type(self, typ: FieldType[Attribute]) -> FieldType[Attribute]:
        converted = stencil_field_to_tensor(typ)
        return converted

convert_type(typ: FieldType[Attribute]) -> FieldType[Attribute]

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
123
124
125
@attr_type_rewrite_pattern
def convert_type(self, typ: FieldType[Attribute]) -> FieldType[Attribute]:
    """Map a `stencil.field` type through `stencil_field_to_tensor`."""
    converted = stencil_field_to_tensor(typ)
    return converted

AccessOpTensorize dataclass

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
@dataclass(frozen=True)
class AccessOpTensorize(RewritePattern):
    """
    Splits a 3-d `stencil.access` into a tensorized temp into a 2-d (x/y)
    `stencil.access` followed by an `ExtractSliceOp` starting at the z offset.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /):
        if not is_tensorized(op.operands[0].type) or len(op.offset) != 3:
            return
        all_offsets = tuple(op.offset)
        xy_offsets = all_offsets[:-1]
        z_offset = all_offsets[-1]
        xy_access = AccessOp.get(op.temp, xy_offsets)
        # NOTE: an ExtractSliceOp is emitted for every access. A disabled
        # alternative replaced `op` with the plain x/y access whenever
        # xy_offsets[0] != 0 or xy_offsets[1] != 0, skipping the slice there.
        assert isa(op.temp.type, TempType[Attribute])
        element_t = op.temp.type.get_element_type()
        assert is_tensor(element_t)
        # Slice the accessed tensor element starting at the z offset, with the
        # full element shape as slice size.
        z_slice = ExtractSliceOp.from_static_parameters(
            xy_access, [z_offset], element_t.get_shape()
        )
        rewriter.insert_op(xy_access, InsertPoint.before(op))
        rewriter.replace_op(op, z_slice)

__init__() -> None

match_and_rewrite(op: AccessOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
@op_type_rewrite_pattern
def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /):
    """
    Split a 3-d access into a tensorized temp into a 2-d x/y access plus an
    `ExtractSliceOp` starting at the z offset.
    """
    if not is_tensorized(op.operands[0].type) or len(op.offset) != 3:
        return
    all_offsets = tuple(op.offset)
    xy_offsets = all_offsets[:-1]
    z_offset = all_offsets[-1]
    xy_access = AccessOp.get(op.temp, xy_offsets)
    # NOTE: an ExtractSliceOp is emitted for every access. A disabled
    # alternative replaced `op` with the plain x/y access whenever
    # xy_offsets[0] != 0 or xy_offsets[1] != 0, skipping the slice there.
    assert isa(op.temp.type, TempType[Attribute])
    element_t = op.temp.type.get_element_type()
    assert is_tensor(element_t)
    z_slice = ExtractSliceOp.from_static_parameters(
        xy_access, [z_offset], element_t.get_shape()
    )
    rewriter.insert_op(xy_access, InsertPoint.before(op))
    rewriter.replace_op(op, z_slice)

ArithOpTensorize

Bases: RewritePattern

Tensorises arith binary ops. If both operands are tensor types, rebuilds the op with the matching result type. If one operand is a scalar defined by an arith.constant, creates a tensor constant directly. If one operand is a scalar not defined by an arith.constant, creates an empty tensor and fills it with the scalar value.

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
class ArithOpTensorize(RewritePattern):
    """
    Tensorises floating-point arith binary ops.

    * tensor op tensor: rebuild the op with the lhs tensor type as result type.
    * tensor op scalar / scalar op tensor: first turn the scalar into a tensor
      of the other operand's type, then rebuild the op with that type. A scalar
      defined by an `arith.constant` becomes a tensor constant. Any other
      scalar becomes an empty tensor that is `linalg.fill`-ed with the value.

    Ops whose result is already a tensor are left untouched.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: FloatingPointLikeBinaryOperation, rewriter: PatternRewriter, /
    ):
        make_op = type(op)
        if is_tensor(op.result.type):
            return
        lhs, rhs = op.lhs, op.rhs
        lhs_t, rhs_t = lhs.type, rhs.type
        if is_tensor(lhs_t) and is_tensor(rhs_t):
            new_op = make_op(lhs, rhs, flags=None, result_type=lhs_t)
        elif isa(lhs_t, TensorType[AnyFloat]) and is_scalar(rhs_t):
            tensor_rhs = ArithOpTensorize._rewrite_scalar_operand(
                rhs, lhs_t, op, rewriter
            )
            new_op = make_op(lhs, tensor_rhs, flags=None, result_type=lhs_t)
        elif is_scalar(lhs_t) and isa(rhs_t, TensorType[AnyFloat]):
            tensor_lhs = ArithOpTensorize._rewrite_scalar_operand(
                lhs, rhs_t, op, rewriter
            )
            new_op = make_op(tensor_lhs, rhs, flags=None, result_type=rhs_t)
        else:
            return
        rewriter.replace_op(op, new_op)

    @staticmethod
    def _rewrite_scalar_operand(
        scalar_op: SSAValue,
        dest_typ: TensorType[AnyFloat],
        op: FloatingPointLikeBinaryOperation,
        rewriter: PatternRewriter,
    ) -> SSAValue:
        """
        Produce a tensor of type `dest_typ` that carries the scalar `scalar_op`.

        If the scalar is the result of a `ConstantOp` holding a `FloatAttr`, a
        dense tensor constant is built from that single value and inserted
        before the scalar constant. Otherwise an `EmptyOp` and a `FillOp` are
        inserted before `op`, and the fill result is returned.
        """
        defining_op = scalar_op.op if isinstance(scalar_op, OpResult) else None
        if isinstance(defining_op, ConstantOp):
            float_attr = defining_op.value
            assert isinstance(float_attr, FloatAttr)
            dense = DenseIntOrFPElementsAttr.from_list(
                dest_typ, [float_attr.value.data]
            )
            tensor_const = ConstantOp(dense)
            rewriter.insert_op(tensor_const, InsertPoint.before(defining_op))
            return tensor_const.result
        empty = EmptyOp((), dest_typ)
        fill = FillOp((scalar_op,), (empty.tensor,), (dest_typ,))
        for new_op in (empty, fill):
            rewriter.insert_op(new_op, InsertPoint.before(op))
        return fill.res[0]

match_and_rewrite(op: FloatingPointLikeBinaryOperation, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: FloatingPointLikeBinaryOperation, rewriter: PatternRewriter, /
):
    """
    Rebuild a float binary arith op on tensors. A scalar operand is first
    converted to a tensor of the other operand's type.
    """
    make_op = type(op)
    if is_tensor(op.result.type):
        return
    lhs, rhs = op.lhs, op.rhs
    lhs_t, rhs_t = lhs.type, rhs.type
    if is_tensor(lhs_t) and is_tensor(rhs_t):
        new_op = make_op(lhs, rhs, flags=None, result_type=lhs_t)
    elif isa(lhs_t, TensorType[AnyFloat]) and is_scalar(rhs_t):
        tensor_rhs = ArithOpTensorize._rewrite_scalar_operand(
            rhs, lhs_t, op, rewriter
        )
        new_op = make_op(lhs, tensor_rhs, flags=None, result_type=lhs_t)
    elif is_scalar(lhs_t) and isa(rhs_t, TensorType[AnyFloat]):
        tensor_lhs = ArithOpTensorize._rewrite_scalar_operand(
            lhs, rhs_t, op, rewriter
        )
        new_op = make_op(tensor_lhs, rhs, flags=None, result_type=rhs_t)
    else:
        return
    rewriter.replace_op(op, new_op)

ApplyOpTensorize dataclass

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
@dataclass(frozen=True)
class ApplyOpTensorize(RewritePattern):
    """
    Rebuilds a `stencil.apply` whose operands are all tensorized.

    Nested `stencil.access` ops get their z offset shifted by the negated
    first entry of `halo_in_axis(2)` for the accessed block argument
    (presumably its lowest z offset — confirm). The body is then moved into a
    fresh block, and the apply is re-created with result types converted by
    `stencil_temp_to_tensor`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
        if not all(is_tensorized(arg.type) for arg in op.args):
            return
        # Block argument -> access pattern observed for it inside the body.
        patterns = dict(zip(op.region.block.args, op.get_accesses()))
        for nested in op.region.walk():
            if not isinstance(nested, AccessOp):
                continue
            z_shift = -patterns[nested.temp].halo_in_axis(2)[0]
            *xy_entries, z_entry = nested.offset.array.data
            nested.offset = IndexAttr.get(*xy_entries, z_entry.data + z_shift)

        new_body = Block(arg_types=op.operand_types)
        rewriter.inline_block(
            op.region.block, InsertPoint.at_start(new_body), new_body.args
        )
        new_result_types = [stencil_temp_to_tensor(res.type) for res in op.res]
        rewriter.replace_op(op, ApplyOp.get(op.args, new_body, new_result_types))

__init__() -> None

match_and_rewrite(op: ApplyOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
    """
    Shift z offsets of nested accesses, then rebuild a fully tensorized
    `stencil.apply` with z-tensorized result types.
    """
    if not all(is_tensorized(arg.type) for arg in op.args):
        return
    patterns = dict(zip(op.region.block.args, op.get_accesses()))
    for nested in op.region.walk():
        if not isinstance(nested, AccessOp):
            continue
        z_shift = -patterns[nested.temp].halo_in_axis(2)[0]
        *xy_entries, z_entry = nested.offset.array.data
        nested.offset = IndexAttr.get(*xy_entries, z_entry.data + z_shift)

    new_body = Block(arg_types=op.operand_types)
    rewriter.inline_block(
        op.region.block, InsertPoint.at_start(new_body), new_body.args
    )
    new_result_types = [stencil_temp_to_tensor(res.type) for res in op.res]
    rewriter.replace_op(op, ApplyOp.get(op.args, new_body, new_result_types))

FuncOpTensorize

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
247
248
249
250
251
252
253
254
255
class FuncOpTensorize(RewritePattern):
    """
    Retypes the `stencil.field` arguments of a function with a body using
    `stencil_field_to_tensor`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /):
        if op.is_declaration:
            return
        for arg in op.args:
            if not isa(arg.type, FieldType[Attribute]):
                continue
            op.replace_argument_type(
                arg, stencil_field_to_tensor(arg.type), rewriter
            )

match_and_rewrite(op: FuncOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
248
249
250
251
252
253
254
255
@op_type_rewrite_pattern
def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /):
    """Retype `stencil.field` arguments of a non-declaration function."""
    if op.is_declaration:
        return
    for arg in op.args:
        if not isa(arg.type, FieldType[Attribute]):
            continue
        op.replace_argument_type(
            arg, stencil_field_to_tensor(arg.type), rewriter
        )

LoadOpTensorize

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
276
277
278
279
280
281
282
283
284
285
286
287
288
class LoadOpTensorize(RewritePattern):
    """
    Rebuilds a 3-d `stencil.load` with its z bound dropped, so it carries 2-d
    x/y bounds only.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: LoadOp, rewriter: PatternRewriter, /):
        assert isa(op.res.type, TempType[Attribute])
        assert isinstance(bounds := op.res.type.bounds, StencilBoundsAttr)
        # Only 3-d loads are tensorized. This mirrors `stencil_field_to_tensor`
        # and the 3-d check in AccessOpTensorize. It also makes the pattern
        # idempotent: a load already reduced to 2-d bounds (or a non-3-d load)
        # no longer has another dimension chopped off.
        if len(bounds.lb) != 3:
            return
        rewriter.replace_op(
            op,
            LoadOp.get(
                op.field,
                IndexAttr.get(*list(bounds.lb)[:-1]),
                IndexAttr.get(*list(bounds.ub)[:-1]),
            ),
        )

match_and_rewrite(op: LoadOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
277
278
279
280
281
282
283
284
285
286
287
288
@op_type_rewrite_pattern
def match_and_rewrite(self, op: LoadOp, rewriter: PatternRewriter, /):
    """Rebuild a 3-d `stencil.load` with its z bound dropped."""
    assert isa(op.res.type, TempType[Attribute])
    assert isinstance(bounds := op.res.type.bounds, StencilBoundsAttr)
    # Only 3-d loads are tensorized. This keeps the pattern idempotent and
    # leaves non-3-d loads untouched, instead of dropping a bound unconditionally.
    if len(bounds.lb) != 3:
        return
    rewriter.replace_op(
        op,
        LoadOp.get(
            op.field,
            IndexAttr.get(*list(bounds.lb)[:-1]),
            IndexAttr.get(*list(bounds.ub)[:-1]),
        ),
    )

DmpSwapOpTensorize

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
291
292
293
294
295
296
297
298
299
300
301
302
class DmpSwapOpTensorize(RewritePattern):
    """
    Re-creates a `dmp.swap` whose input is tensorized but whose swapped result
    is not. The input, strategy and swaps are kept; the result type is
    presumably re-derived by `dmp.SwapOp.get` from the tensorized input.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
        result_is_stale = (
            is_tensorized(op.input_stencil.type)
            and op.swapped_values
            and not is_tensorized(op.swapped_values.type)
        )
        if not result_is_stale:
            return
        swaps = ArrayAttr(op.swaps.data)
        rewriter.replace_op(
            op, dmp.SwapOp.get(op.input_stencil, op.strategy, swaps)
        )

match_and_rewrite(op: dmp.SwapOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
292
293
294
295
296
297
298
299
300
301
302
@op_type_rewrite_pattern
def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
    """Re-create a swap whose input is tensorized but whose result is not."""
    result_is_stale = (
        is_tensorized(op.input_stencil.type)
        and op.swapped_values
        and not is_tensorized(op.swapped_values.type)
    )
    if not result_is_stale:
        return
    swaps = ArrayAttr(op.swaps.data)
    rewriter.replace_op(
        op, dmp.SwapOp.get(op.input_stencil, op.strategy, swaps)
    )

StoreOpTensorize

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
class StoreOpTensorize(RewritePattern):
    """
    Drops the last (z) bound of a `stencil.store` into a tensorized field when
    the store's bounds have a different rank than the field's shape.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter, /):
        if (
            is_tensorized(op.field.type)
            and isinstance(op.field.type, ShapedType)
            and len(op.bounds.lb) != len(op.field.type.get_shape())
        ):
            # Iterate lb and ub the same way (previously lb went through
            # `.array` while ub did not), as `stencil_field_to_tensor` does.
            rewriter.replace_op(
                op,
                StoreOp.get(
                    op.temp,
                    op.field,
                    StencilBoundsAttr(
                        zip(list(op.bounds.lb)[:-1], list(op.bounds.ub)[:-1])
                    ),
                ),
            )

match_and_rewrite(op: StoreOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
@op_type_rewrite_pattern
def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter, /):
    """Drop the z bound of a store into a tensorized field of lower rank."""
    if (
        is_tensorized(op.field.type)
        and isinstance(op.field.type, ShapedType)
        and len(op.bounds.lb) != len(op.field.type.get_shape())
    ):
        # lb and ub are iterated identically, matching `stencil_field_to_tensor`.
        rewriter.replace_op(
            op,
            StoreOp.get(
                op.temp,
                op.field,
                StencilBoundsAttr(
                    zip(list(op.bounds.lb)[:-1], list(op.bounds.ub)[:-1])
                ),
            ),
        )

AccessOpUpdateShape

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
325
326
327
328
329
330
331
332
333
334
335
class AccessOpUpdateShape(RewritePattern):
    """
    Re-creates a `stencil.access` with the tensor result type its users
    require (per `get_required_result_type`) when the current shape differs.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /):
        if typ := get_required_result_type(op):
            if needs_update_shape(op.res.type, typ):
                # Copy properties as well as attributes. If `offset` /
                # `offset_mapping` are stored as properties (likely in current
                # xDSL — confirm), copying `attributes` alone would drop them.
                rewriter.replace_op(
                    op,
                    AccessOp.build(
                        operands=[op.temp],
                        attributes=dict(op.attributes),
                        properties=dict(op.properties),
                        result_types=[typ],
                    ),
                )

match_and_rewrite(op: AccessOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
326
327
328
329
330
331
332
333
334
335
@op_type_rewrite_pattern
def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /):
    """Re-create a `stencil.access` with the result type its users require."""
    if typ := get_required_result_type(op):
        if needs_update_shape(op.res.type, typ):
            # Copy properties too, so property-stored offsets are not lost.
            rewriter.replace_op(
                op,
                AccessOp.build(
                    operands=[op.temp],
                    attributes=dict(op.attributes),
                    properties=dict(op.properties),
                    result_types=[typ],
                ),
            )

CslStencilAccessOpUpdateShape

Bases: RewritePattern

Updates the result type of a tensorized `csl_stencil.access` op.

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
class CslStencilAccessOpUpdateShape(RewritePattern):
    """
    Updates the result type of a tensorized `csl_stencil.access` op.

    When the accessed value is a temp of tensors, the new result type is the
    temp's element type. When it is a plain tensor, the new result type is the
    one required by the op's users.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: csl_stencil.AccessOp, rewriter: PatternRewriter, /):
        typ = get_required_result_type(op)
        if not typ or not needs_update_shape(op.result.type, typ):
            return
        accessed_type = op.op.type
        if isa(accessed_type, TempType[TensorType[Attribute]]):
            result_type = accessed_type.get_element_type()
        elif isa(accessed_type, TensorType[Attribute]):
            result_type = typ
        else:
            return
        rewriter.replace_op(
            op,
            csl_stencil.AccessOp(op.op, op.offset, result_type, op.offset_mapping),
        )

match_and_rewrite(op: csl_stencil.AccessOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl_stencil.AccessOp, rewriter: PatternRewriter, /):
    """Re-create a tensorized `csl_stencil.access` with an updated result type."""
    typ = get_required_result_type(op)
    if not typ or not needs_update_shape(op.result.type, typ):
        return
    accessed_type = op.op.type
    if isa(accessed_type, TempType[TensorType[Attribute]]):
        result_type = accessed_type.get_element_type()
    elif isa(accessed_type, TensorType[Attribute]):
        result_type = typ
    else:
        return
    rewriter.replace_op(
        op,
        csl_stencil.AccessOp(op.op, op.offset, result_type, op.offset_mapping),
    )

ExtractSliceOpUpdateShape

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
365
366
367
368
369
370
371
372
373
374
375
class ExtractSliceOpUpdateShape(RewritePattern):
    """
    Re-creates an `ExtractSliceOp` whose result shape differs from the one its
    users require. The static offsets are kept and the slice sizes are set to
    the required shape.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ExtractSliceOp, rewriter: PatternRewriter, /):
        required = get_required_result_type(op)
        if not required or not needs_update_shape(op.result.type, required):
            return
        offsets = op.static_offsets.get_values()
        resized = ExtractSliceOp.from_static_parameters(
            op.source, offsets, required.get_shape()
        )
        rewriter.replace_op(op, resized)

match_and_rewrite(op: ExtractSliceOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
366
367
368
369
370
371
372
373
374
375
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ExtractSliceOp, rewriter: PatternRewriter, /):
    """Resize an extract_slice to the shape its users require."""
    required = get_required_result_type(op)
    if not required or not needs_update_shape(op.result.type, required):
        return
    offsets = op.static_offsets.get_values()
    resized = ExtractSliceOp.from_static_parameters(
        op.source, offsets, required.get_shape()
    )
    rewriter.replace_op(op, resized)

ArithOpUpdateShape

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
389
390
391
392
393
394
class ArithOpUpdateShape(RewritePattern):
    """
    Pattern wrapper around `arithBinaryOpUpdateShape` for floating-point
    binary arith ops.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: FloatingPointLikeBinaryOperation, rewriter: PatternRewriter, /
    ):
        # All logic lives in the shared module-level helper.
        arithBinaryOpUpdateShape(op, rewriter)

match_and_rewrite(op: FloatingPointLikeBinaryOperation, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
390
391
392
393
394
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: FloatingPointLikeBinaryOperation, rewriter: PatternRewriter, /
):
    """Delegate to `arithBinaryOpUpdateShape`."""
    arithBinaryOpUpdateShape(op, rewriter)

VarithOpUpdateShape

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
397
398
399
400
401
402
403
404
405
class VarithOpUpdateShape(RewritePattern):
    """
    Re-creates a `varith` op, with the same operands, when its first result
    type's shape differs from what its users require.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /):
        required = get_required_result_type(op)
        if not required or not needs_update_shape(op.result_types[0], required):
            return
        rebuilt = type(op).build(operands=[op.args], result_types=[required])
        rewriter.replace_op(op, rebuilt)

match_and_rewrite(op: varith.VarithOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
398
399
400
401
402
403
404
405
@op_type_rewrite_pattern
def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /):
    """Rebuild a varith op with the result type its users require."""
    required = get_required_result_type(op)
    if not required or not needs_update_shape(op.result_types[0], required):
        return
    rebuilt = type(op).build(operands=[op.args], result_types=[required])
    rewriter.replace_op(op, rebuilt)

EmptyOpUpdateShape

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
408
409
410
411
412
413
class EmptyOpUpdateShape(RewritePattern):
    """
    Replaces an `EmptyOp` with a new one (no dynamic sizes) of the type its
    users require when the shapes differ.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: EmptyOp, rewriter: PatternRewriter, /):
        required = get_required_result_type(op)
        if not required or not needs_update_shape(op.results[0].type, required):
            return
        rewriter.replace_op(op, EmptyOp((), required))

match_and_rewrite(op: EmptyOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
409
410
411
412
413
@op_type_rewrite_pattern
def match_and_rewrite(self, op: EmptyOp, rewriter: PatternRewriter, /):
    """Replace an empty tensor with one of the shape its users require."""
    required = get_required_result_type(op)
    if not required or not needs_update_shape(op.results[0].type, required):
        return
    rewriter.replace_op(op, EmptyOp((), required))

FillOpUpdateShape

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
416
417
418
419
420
421
422
423
class FillOpUpdateShape(RewritePattern):
    """
    Re-creates a `FillOp` with every result typed as the type its users
    require, when the first result's shape differs.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: FillOp, rewriter: PatternRewriter, /):
        required = get_required_result_type(op)
        if not required or not needs_update_shape(op.results[0].type, required):
            return
        result_types = [required for _ in op.outputs]
        rewriter.replace_op(op, FillOp(op.inputs, op.outputs, result_types))

match_and_rewrite(op: FillOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
417
418
419
420
421
422
423
@op_type_rewrite_pattern
def match_and_rewrite(self, op: FillOp, rewriter: PatternRewriter, /):
    """Rebuild a fill op with the result type its users require."""
    required = get_required_result_type(op)
    if not required or not needs_update_shape(op.results[0].type, required):
        return
    result_types = [required for _ in op.outputs]
    rewriter.replace_op(op, FillOp(op.inputs, op.outputs, result_types))

ConstOpUpdateShape

Bases: RewritePattern

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
426
427
428
429
430
431
432
433
434
435
class ConstOpUpdateShape(RewritePattern):
    """
    Re-creates a tensor-typed `ConstantOp` with the type its users require,
    reusing the existing dense data.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ConstantOp, rewriter: PatternRewriter, /):
        # Scalar constants are never reshaped.
        if not is_tensor(op.result.type):
            return
        required = get_required_result_type(op)
        if not required or not needs_update_shape(op.result.type, required):
            return
        dense = op.value
        assert isinstance(dense, DenseIntOrFPElementsAttr)
        rewriter.replace_op(
            op, ConstantOp(DenseIntOrFPElementsAttr(required, dense.data))
        )

match_and_rewrite(op: ConstantOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
427
428
429
430
431
432
433
434
435
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ConstantOp, rewriter: PatternRewriter, /):
    """Reshape a tensor constant to the type its users require."""
    if not is_tensor(op.result.type):
        return
    required = get_required_result_type(op)
    if not required or not needs_update_shape(op.result.type, required):
        return
    dense = op.value
    assert isinstance(dense, DenseIntOrFPElementsAttr)
    rewriter.replace_op(
        op, ConstantOp(DenseIntOrFPElementsAttr(required, dense.data))
    )

BackpropagateStencilShapes dataclass

Bases: ModulePass

Greedily back-propagates the result types of tensorized ops. Use after creating/modifying tensorization.

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
@dataclass(frozen=True)
class BackpropagateStencilShapes(ModulePass):
    """
    Greedily back-propagates the result types of tensorized ops.
    Use after creating/modifying tensorization.
    """

    name = "backpropagate-stencil-shapes"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        update_patterns = [
            CslStencilAccessOpUpdateShape(),
            ExtractSliceOpUpdateShape(),
            EmptyOpUpdateShape(),
            FillOpUpdateShape(),
            ArithOpUpdateShape(),
            VarithOpUpdateShape(),
            ConstOpUpdateShape(),
        ]
        # A reverse walk visits users (later in a block) before the ops that
        # define their operands, so required types flow backwards.
        walker = PatternRewriteWalker(
            GreedyRewritePatternApplier(update_patterns),
            walk_reverse=True,
            apply_recursively=False,
        )
        walker.rewrite_module(op)

name = 'backpropagate-stencil-shapes' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Run the shape-update patterns once over the module in reverse order."""
    update_patterns = [
        CslStencilAccessOpUpdateShape(),
        ExtractSliceOpUpdateShape(),
        EmptyOpUpdateShape(),
        FillOpUpdateShape(),
        ArithOpUpdateShape(),
        VarithOpUpdateShape(),
        ConstOpUpdateShape(),
    ]
    walker = PatternRewriteWalker(
        GreedyRewritePatternApplier(update_patterns),
        walk_reverse=True,
        apply_recursively=False,
    )
    walker.rewrite_module(op)

StencilTensorizeZDimension dataclass

Bases: ModulePass

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
@dataclass(frozen=True)
class StencilTensorizeZDimension(ModulePass):
    """
    Tensorizes the z dimension of 3-d stencil programs. Fields and temps
    become 2-d types whose elements are 1-d tensors over z.

    Runs in three phases: retype structural ops, tensorize accesses and
    arithmetic, then back-propagate result shapes.
    """

    name = "stencil-tensorize-z-dimension"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        # Phase 1. FuncOpTensorize must precede StencilTypeConversion.
        # AccessOpTensorize does not work in this phase, so it runs in phase 2.
        structural_patterns = [
            FuncOpTensorize(),
            StencilTypeConversion(),
            LoadOpTensorize(),
            ApplyOpTensorize(),
            StoreOpTensorize(),
            DmpSwapOpTensorize(),
        ]
        # Phase 2: accesses and arithmetic inside the (now tensorized) applies.
        body_patterns = [
            AccessOpTensorize(),
            ArithOpTensorize(),
        ]
        for phase in (structural_patterns, body_patterns):
            PatternRewriteWalker(
                GreedyRewritePatternApplier(phase),
                walk_reverse=False,
                apply_recursively=False,
            ).rewrite_module(op)
        # Phase 3: fix up result shapes of the newly created tensor ops.
        BackpropagateStencilShapes().apply(ctx=ctx, op=op)

name = 'stencil-tensorize-z-dimension' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Tensorize structural ops, then accesses/arith, then fix up shapes."""
    # FuncOpTensorize must precede StencilTypeConversion. AccessOpTensorize
    # does not work in this first phase and runs in the second one instead.
    structural_patterns = [
        FuncOpTensorize(),
        StencilTypeConversion(),
        LoadOpTensorize(),
        ApplyOpTensorize(),
        StoreOpTensorize(),
        DmpSwapOpTensorize(),
    ]
    body_patterns = [
        AccessOpTensorize(),
        ArithOpTensorize(),
    ]
    for phase in (structural_patterns, body_patterns):
        PatternRewriteWalker(
            GreedyRewritePatternApplier(phase),
            walk_reverse=False,
            apply_recursively=False,
        ).rewrite_module(op)
    BackpropagateStencilShapes().apply(ctx=ctx, op=op)

get_required_result_type(op: Operation) -> TensorType[Any] | None

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def get_required_result_type(op: Operation) -> TensorType[Any] | None:
    """
    Determine the tensor type that the users of `op`'s results expect, if any.

    Uses of each result are inspected in order, and the first one that yields
    an answer wins:

    * A `ReturnOp` user that has a parent op: return the element type of the
      first parent result that is tensorized and is a `TempType` whose element
      is a `TensorType`. If there is no such result, return None immediately;
      no further uses are inspected.
    * An `InsertSliceOp` user whose result is a tensor: return a tensor with
      the insert's element type and the insert's static sizes, minus as many
      leading dimensions as the insert result outranks the inserted source.
    * Any other user: return the first `TensorType` among the user's results,
      or move on to the next use if it has none.

    Returns None when no use yields a type.

    Note: `is_tensorized` asserts that its argument is a shaped container
    type, so every parent result inspected in the `ReturnOp` case must be one.
    """
    for result in op.results:
        for use in result.uses:
            if (
                isinstance(use.operation, ReturnOp)
                and (p_op := use.operation.parent_op()) is not None
            ):
                # Returned values take the element type of the enclosing op's
                # tensorized temp result.
                for ret in p_op.results:
                    if is_tensorized(ret.type):
                        if isa(ret.type, TempType) and isa(
                            r_type := ret.type.get_element_type(), TensorType
                        ):
                            return r_type
                # abort when encountering an un-tensorized ReturnOp successor
                return None
            if isinstance(use.operation, InsertSliceOp) and is_tensor(
                use.operation.result.type
            ):
                static_sizes = use.operation.static_sizes.get_values()
                assert is_tensor(use.operation.source.type)
                # inserting an (n-1)d tensor into an (n)d tensor should not require the input tensor to also be (n)d
                # instead, drop the first `dimdiff` dimensions
                dimdiff = len(static_sizes) - len(use.operation.source.type.shape)
                return TensorType(
                    use.operation.result.type.get_element_type(),
                    static_sizes[dimdiff:],
                )
            # Generic user: adopt its first tensor-typed result, if any.
            for ret in use.operation.results:
                if isa(r_type := ret.type, TensorType):
                    return r_type
    # Falling off the end returns None implicitly: no use constrained the type.

needs_update_shape(op_type: Attribute, succ_req_type: TensorType[Attribute]) -> bool

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
93
94
95
96
97
def needs_update_shape(
    op_type: Attribute, succ_req_type: TensorType[Attribute]
) -> bool:
    """
    Return True when tensor type `op_type` has a different shape from the
    tensor type `succ_req_type` required by a successor.

    Asserts that `op_type` is a `TensorType`.
    """
    assert isa(op_type, TensorType)
    current_shape = op_type.get_shape()
    required_shape = succ_req_type.get_shape()
    return current_shape != required_shape

stencil_field_to_tensor(field: FieldType[Attribute]) -> FieldType[Attribute]

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
100
101
102
103
104
105
106
107
108
def stencil_field_to_tensor(field: FieldType[Attribute]) -> FieldType[Attribute]:
    """
    Fold the last dimension of a 3-d `stencil.field` into its element type.

    The result is a field over the first two bound pairs, whose element is a
    1-d tensor sized by the original last dimension. Fields that are not 3-d
    are returned unchanged.
    """
    if field.get_num_dims() != 3:
        return field
    scalar_t = field.get_element_type()
    z_extent = field.get_shape()[-1]
    column_t = TensorType(scalar_t, [z_extent])
    bounds = field.bounds
    assert isinstance(bounds, StencilBoundsAttr)
    assert isinstance(bounds.lb, IndexAttr)
    assert isinstance(bounds.ub, IndexAttr)
    xy_bounds = list(zip(bounds.lb, bounds.ub))[:-1]
    return FieldType[Attribute](xy_bounds, column_t)

stencil_temp_to_tensor(field: TempType[Attribute]) -> TempType[Attribute]

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
111
112
113
114
115
116
117
118
119
def stencil_temp_to_tensor(field: TempType[Attribute]) -> TempType[Attribute]:
    """
    Fold the last dimension of a 3-d `stencil.temp` into its element type.

    The result is a temp over the first two bound pairs, whose element is a
    1-d tensor sized by the original last dimension. Temps that are not 3-d
    are returned unchanged.
    """
    if field.get_num_dims() != 3:
        return field
    scalar_t = field.get_element_type()
    z_extent = field.get_shape()[-1]
    column_t = TensorType(scalar_t, [z_extent])
    bounds = field.bounds
    assert isinstance(bounds, StencilBoundsAttr)
    assert isinstance(bounds.lb, IndexAttr)
    assert isinstance(bounds.ub, IndexAttr)
    xy_bounds = list(zip(bounds.lb, bounds.ub))[:-1]
    return TempType[Attribute](xy_bounds, column_t)

is_tensorized(typ: Attribute)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
258
259
260
261
262
263
def is_tensorized(
    typ: Attribute,
):
    """
    Return whether `typ` is a 2-d shaped container whose element type is a
    `TensorType` (the form produced by `stencil_field_to_tensor` and
    `stencil_temp_to_tensor`).

    Asserts that `typ` is both a `ShapedType` and a `ContainerType`.
    """
    assert isinstance(typ, ShapedType)
    assert isinstance(typ, ContainerType)
    if len(typ.get_shape()) != 2:
        return False
    return isinstance(typ.get_element_type(), TensorType)

is_tensor(typ: Attribute) -> TypeGuard[TensorType[IndexType | IntegerType | AnyFloat]]

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
266
267
268
269
def is_tensor(
    typ: Attribute,
) -> TypeGuard[TensorType[IndexType | IntegerType | AnyFloat]]:
    """Type guard: True iff `typ` is a builtin `TensorType`."""
    matches = isinstance(typ, TensorType)
    return matches

is_scalar(typ: Attribute) -> TypeGuard[AnyFloat]

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
272
273
def is_scalar(typ: Attribute) -> TypeGuard[AnyFloat]:
    """Type guard: True iff `typ` is one of the float types in `AnyFloat`."""
    matches = isinstance(typ, AnyFloat)
    return matches

arithBinaryOpUpdateShape(op: FloatingPointLikeBinaryOperation, rewriter: PatternRewriter)

Source code in xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
378
379
380
381
382
383
384
385
386
def arithBinaryOpUpdateShape(
    op: FloatingPointLikeBinaryOperation,
    rewriter: PatternRewriter,
    /,
):
    """
    Re-create a floating-point binary arith op with the tensor result type its
    users require (per `get_required_result_type`), if the shapes differ.

    The op's fastmath flags are carried over to the replacement instead of
    being dropped.
    """
    type_constructor = type(op)
    if typ := get_required_result_type(op):
        if needs_update_shape(op.result.type, typ):
            rewriter.replace_op(
                op,
                type_constructor(
                    op.lhs, op.rhs, flags=op.fastmath, result_type=typ
                ),
            )