Skip to content

Memref

memref

MemRefSubviewOfSubviewFolding

Bases: RewritePattern

Source code in xdsl/transforms/canonicalization_patterns/memref.py (lines 10–69)
class MemRefSubviewOfSubviewFolding(RewritePattern):
    """
    Fold a `memref.subview` whose source is another `memref.subview` into a
    single `memref.subview` of the inner subview's source.

    The fold is attempted only when:
      * every static stride of both subviews is 1,
      * both subviews carry the same number of static offsets,
      * neither subview has dynamic (operand) offsets, sizes or strides, and
      * the folded subview's result type equals the original result type.

    The folded subview starts at the element-wise sum of the two offset lists
    and reuses the outer subview's sizes and strides.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
        source_subview = op.source.owner
        if not isinstance(source_subview, memref.SubviewOp):
            return

        # With unit strides, the two views' offsets compose by plain addition.
        current_strides = op.static_strides.get_values()
        if not all(stride == 1 for stride in current_strides):
            return
        if not all(
            stride == 1 for stride in source_subview.static_strides.iter_values()
        ):
            return

        # Offsets are summed pairwise below, so both lists must line up.
        if len(op.static_offsets) != len(source_subview.static_offsets):
            return

        assert isa(source_subview.source.type, memref.MemRefType)
        assert isa(op.result.type, memref.MemRefType)

        # Treat the fold as rank-reducing when the original result rank differs
        # from the rank of the memref the folded subview will read from.
        reduce_rank = len(source_subview.source.type.shape) != len(
            op.result.type.shape
        )

        # Operand-provided (dynamic) offsets, sizes or strides cannot be folded.
        if len(op.offsets) > 0 or len(source_subview.offsets) > 0:
            return
        if len(op.sizes) > 0 or len(source_subview.sizes) > 0:
            return
        if len(op.strides) > 0 or len(source_subview.strides) > 0:
            return

        new_offsets = [
            outer_offset + inner_offset
            for outer_offset, inner_offset in zip(
                op.static_offsets.iter_values(),
                source_subview.static_offsets.iter_values(),
                strict=True,
            )
        ]

        new_op = memref.SubviewOp.from_static_parameters(
            source_subview.source,
            source_subview.source.type,
            new_offsets,
            op.static_sizes.get_values(),
            current_strides,
            reduce_rank=reduce_rank,
        )

        # `from_static_parameters` computes its own result type. Only replace
        # when it matches the original exactly (in both the rank-reducing and
        # the non-rank-reducing case); otherwise users of `op` would be handed
        # a value of a different type, e.g. one with a different layout.
        if new_op.result.type != op.result.type:
            return

        rewriter.replace_op(op, new_op)

match_and_rewrite(op: memref.SubviewOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/memref.py (lines 11–69)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
    """
    Replace `op` (a subview of a subview) with one subview of the inner
    subview's source.

    Bails out unless:
      * all static strides of both subviews are 1,
      * both have the same number of static offsets,
      * neither has dynamic (operand) offsets, sizes or strides, and
      * the folded subview's result type equals `op`'s result type.
    """
    source_subview = op.source.owner
    if not isinstance(source_subview, memref.SubviewOp):
        return

    # With unit strides, the two views' offsets compose by plain addition.
    current_strides = op.static_strides.get_values()
    if not all(stride == 1 for stride in current_strides):
        return
    if not all(
        stride == 1 for stride in source_subview.static_strides.iter_values()
    ):
        return

    # Offsets are summed pairwise below, so both lists must line up.
    if len(op.static_offsets) != len(source_subview.static_offsets):
        return

    assert isa(source_subview.source.type, memref.MemRefType)
    assert isa(op.result.type, memref.MemRefType)

    # Treat the fold as rank-reducing when the original result rank differs
    # from the rank of the memref the folded subview will read from.
    reduce_rank = len(source_subview.source.type.shape) != len(op.result.type.shape)

    # Operand-provided (dynamic) offsets, sizes or strides cannot be folded.
    if len(op.offsets) > 0 or len(source_subview.offsets) > 0:
        return
    if len(op.sizes) > 0 or len(source_subview.sizes) > 0:
        return
    if len(op.strides) > 0 or len(source_subview.strides) > 0:
        return

    new_offsets = [
        outer_offset + inner_offset
        for outer_offset, inner_offset in zip(
            op.static_offsets.iter_values(),
            source_subview.static_offsets.iter_values(),
            strict=True,
        )
    ]

    new_op = memref.SubviewOp.from_static_parameters(
        source_subview.source,
        source_subview.source.type,
        new_offsets,
        op.static_sizes.get_values(),
        current_strides,
        reduce_rank=reduce_rank,
    )

    # `from_static_parameters` computes its own result type. Only replace when
    # it matches the original exactly (in both the rank-reducing and the
    # non-rank-reducing case); otherwise users of `op` would be handed a value
    # of a different type, e.g. one with a different layout.
    if new_op.result.type != op.result.type:
        return

    rewriter.replace_op(op, new_op)

ElideUnusedAlloc

Bases: RewritePattern

Source code in xdsl/transforms/canonicalization_patterns/memref.py (lines 72–77)
class ElideUnusedAlloc(RewritePattern):
    """
    Remove a `memref.alloc` together with its `memref.dealloc` when that
    dealloc is the allocated memref's only user.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.AllocOp, rewriter: PatternRewriter, /):
        sole_user = op.memref.get_user_of_unique_use()
        if not isinstance(sole_user, memref.DeallocOp):
            return

        # Erase the user first so the allocation's result has no remaining
        # uses by the time the allocation itself is erased.
        rewriter.erase_op(sole_user)
        rewriter.erase_op(op)

match_and_rewrite(op: memref.AllocOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/memref.py (lines 73–77)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.AllocOp, rewriter: PatternRewriter, /):
    """
    Erase `op` and its `memref.dealloc` when that dealloc is the only user of
    the allocated memref; otherwise leave the IR untouched.
    """
    sole_user = op.memref.get_user_of_unique_use()
    if not isinstance(sole_user, memref.DeallocOp):
        return

    # Erase the user first so the allocation's result has no remaining uses
    # by the time the allocation itself is erased.
    rewriter.erase_op(sole_user)
    rewriter.erase_op(op)