Skip to content

Memref stream unnest out parameters

memref_stream_unnest_out_parameters

UnnestOutParametersPattern

Bases: RewritePattern

Source code in xdsl/transforms/memref_stream_unnest_out_parameters.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class UnnestOutParametersPattern(RewritePattern):
    """
    Rewrites the indexing maps of the outputs of a `memref_stream.generic` so that
    they no longer take the reduction dimensions as inputs.

    Input indexing maps are kept as they are. Ops that are already imperfectly
    nested, have no outputs, or have only parallel iterators are left untouched.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: memref_stream.GenericOp, rewriter: PatternRewriter
    ) -> None:
        # Nothing to do if the op is already unnested or has no out parameters.
        if op.is_imperfectly_nested or len(op.outputs) == 0:
            return

        parallel = memref_stream.IteratorTypeAttr.parallel()
        reduction = memref_stream.IteratorTypeAttr.reduction()

        kinds = tuple(op.iterator_types)
        parallel_count = sum(1 for kind in kinds if kind == parallel)
        reduction_count = sum(1 for kind in kinds if kind == reduction)

        # Without any non-parallel iterator the output maps have nothing to drop.
        if parallel_count == len(kinds):
            return

        # Positional mask: the first `parallel_count` dims are kept and the
        # following `reduction_count` dims are dropped.
        # NOTE(review): this presumes parallel iterators precede reduction ones
        # and that no other iterator kind is present - confirm against the op
        # verifier.
        drop_mask = tuple(
            position >= parallel_count
            for position in range(parallel_count + reduction_count)
        )

        input_count = len(op.inputs)
        all_maps = op.indexing_maps.data
        input_maps = all_maps[:input_count]
        output_maps = all_maps[input_count:]

        unnested_maps = tuple(
            AffineMapAttr(output_map.data.drop_dims(drop_mask))
            for output_map in output_maps
        )

        # NOTE(review): the op is updated in place and `rewriter` is not used;
        # confirm whether the walker needs to be told about the modification.
        op.indexing_maps = ArrayAttr((*input_maps, *unnested_maps))

match_and_rewrite(op: memref_stream.GenericOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/memref_stream_unnest_out_parameters.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: memref_stream.GenericOp, rewriter: PatternRewriter
) -> None:
    """
    Replace the output indexing maps of `op` with versions that no longer take
    the reduction dimensions, leaving the input indexing maps unchanged.

    Returns without modifying `op` if it is already imperfectly nested, has no
    outputs, or has only parallel iterators.
    """
    # Nothing to do if the op is already unnested or has no out parameters.
    if op.is_imperfectly_nested or len(op.outputs) == 0:
        return

    parallel = memref_stream.IteratorTypeAttr.parallel()
    reduction = memref_stream.IteratorTypeAttr.reduction()

    kinds = tuple(op.iterator_types)
    parallel_count = sum(1 for kind in kinds if kind == parallel)
    reduction_count = sum(1 for kind in kinds if kind == reduction)

    # Without any non-parallel iterator the output maps have nothing to drop.
    if parallel_count == len(kinds):
        return

    # Positional mask: the first `parallel_count` dims are kept and the
    # following `reduction_count` dims are dropped.
    # NOTE(review): this presumes parallel iterators precede reduction ones
    # and that no other iterator kind is present - confirm against the op
    # verifier.
    drop_mask = tuple(
        position >= parallel_count
        for position in range(parallel_count + reduction_count)
    )

    input_count = len(op.inputs)
    all_maps = op.indexing_maps.data
    input_maps = all_maps[:input_count]
    output_maps = all_maps[input_count:]

    unnested_maps = tuple(
        AffineMapAttr(output_map.data.drop_dims(drop_mask))
        for output_map in output_maps
    )

    # NOTE(review): the op is updated in place and `rewriter` is not used;
    # confirm whether the walker needs to be told about the modification.
    op.indexing_maps = ArrayAttr((*input_maps, *unnested_maps))

MemRefStreamUnnestOutParametersPass dataclass

Bases: ModulePass

Converts the affine maps of memref_stream.generic out parameters from taking all the indices to only taking "parallel" ones.

Source code in xdsl/transforms/memref_stream_unnest_out_parameters.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
@dataclass(frozen=True)
class MemRefStreamUnnestOutParametersPass(ModulePass):
    """
    Rewrites the out-parameter affine maps of every `memref_stream.generic` so that
    they take only the "parallel" indices instead of all of the indices.
    """

    name = "memref-stream-unnest-out-parameters"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """Run `UnnestOutParametersPattern` over `op` in a single, non-recursive walk."""
        pattern = UnnestOutParametersPattern()
        walker = PatternRewriteWalker(pattern, apply_recursively=False)
        walker.rewrite_module(op)

name = 'memref-stream-unnest-out-parameters' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/memref_stream_unnest_out_parameters.py
65
66
67
68
69
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Run `UnnestOutParametersPattern` over `op` in a single, non-recursive walk."""
    pattern = UnnestOutParametersPattern()
    walker = PatternRewriteWalker(pattern, apply_recursively=False)
    walker.rewrite_module(op)