Skip to content

Memref stream generalize fill

memref_stream_generalize_fill

GeneralizeFillPattern

Bases: RewritePattern

Source code in xdsl/transforms/memref_stream_generalize_fill.py (lines 26–61):
class GeneralizeFillPattern(RewritePattern):
    """
    Rewrite `memref_stream.fill` into an equivalent `memref_stream.generic`.

    The generated generic op takes the fill value as its only input and the
    filled memref as its only output. Its body yields the first block
    argument, every iteration dimension is marked parallel, and one
    index-typed upper bound is emitted per dimension of the memref's shape.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: memref_stream.FillOp, rewriter: PatternRewriter
    ) -> None:
        # Generic body: two block arguments, both of the fill value's type;
        # yielding the first one writes the fill value to each output element.
        block = Block(arg_types=(op.value.type, op.value.type))

        with ImplicitBuilder(block) as (arg0, _):
            memref_stream.YieldOp(arg0)

        # Bind the type *outside* the assert: an assignment expression inside
        # an assert is compiled out together with the assert under `python -O`,
        # which would leave `memref_type` unbound on the following line.
        memref_type = op.memref.type
        assert isinstance(memref_type, memref.MemRefType)
        # No-op at runtime; only narrows the type for static type checkers.
        memref_type = cast(MemRefType, memref_type)

        shape = memref_type.get_shape()
        # NOTE(review): every shape entry becomes a bound verbatim; if dynamic
        # dimensions can reach this pattern they are not filtered here —
        # confirm FillOp guarantees a static shape.
        index = IndexType()
        ubs = ArrayAttr(IntegerAttr(ub, index) for ub in shape)
        rank = len(shape)

        rewriter.replace_op(
            op,
            memref_stream.GenericOp(
                (op.value,),
                (op.memref,),
                (),
                Region((block,)),
                ArrayAttr(
                    (
                        # `rank` dims, no symbols, no results: the scalar fill
                        # value is not indexed by the loop variables.
                        AffineMapAttr(AffineMap(rank, 0, ())),
                        # The output memref is accessed at the loop indices.
                        AffineMapAttr(AffineMap.identity(rank)),
                    )
                ),
                ArrayAttr((memref_stream.IteratorTypeAttr.parallel(),) * rank),
                ubs,
                ArrayAttr(()),
            ),
        )

match_and_rewrite(op: memref_stream.FillOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/memref_stream_generalize_fill.py (lines 27–61):
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: memref_stream.FillOp, rewriter: PatternRewriter
) -> None:
    """
    Replace `op` (a `memref_stream.fill`) with an equivalent
    `memref_stream.generic` whose body yields the fill value, whose iteration
    dimensions are all parallel, and whose bounds come from the memref shape.
    """
    # Generic body: two block arguments, both of the fill value's type;
    # yielding the first one writes the fill value to each output element.
    block = Block(arg_types=(op.value.type, op.value.type))

    with ImplicitBuilder(block) as (arg0, _):
        memref_stream.YieldOp(arg0)

    # Bind the type *outside* the assert: an assignment expression inside an
    # assert is compiled out together with the assert under `python -O`,
    # which would leave `memref_type` unbound on the following line.
    memref_type = op.memref.type
    assert isinstance(memref_type, memref.MemRefType)
    # No-op at runtime; only narrows the type for static type checkers.
    memref_type = cast(MemRefType, memref_type)

    shape = memref_type.get_shape()
    # NOTE(review): every shape entry becomes a bound verbatim; confirm that
    # dynamic dimensions cannot reach this pattern.
    index = IndexType()
    ubs = ArrayAttr(IntegerAttr(ub, index) for ub in shape)
    rank = len(shape)

    rewriter.replace_op(
        op,
        memref_stream.GenericOp(
            (op.value,),
            (op.memref,),
            (),
            Region((block,)),
            ArrayAttr(
                (
                    # `rank` dims, no symbols, no results: the scalar fill
                    # value is not indexed by the loop variables.
                    AffineMapAttr(AffineMap(rank, 0, ())),
                    # The output memref is accessed at the loop indices.
                    AffineMapAttr(AffineMap.identity(rank)),
                )
            ),
            ArrayAttr((memref_stream.IteratorTypeAttr.parallel(),) * rank),
            ubs,
            ArrayAttr(()),
        ),
    )

MemRefStreamGeneralizeFillPass dataclass

Bases: ModulePass

Generalizes memref_stream.fill ops.

Source code in xdsl/transforms/memref_stream_generalize_fill.py (lines 64–76):
@dataclass(frozen=True)
class MemRefStreamGeneralizeFillPass(ModulePass):
    """
    Generalizes memref_stream.fill ops.

    Runs `GeneralizeFillPattern` over the module through a
    `PatternRewriteWalker` constructed with `apply_recursively=False`.
    """

    name = "memref-stream-generalize-fill"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        # `ctx` is part of the ModulePass interface; it is not needed here.
        walker = PatternRewriteWalker(
            GeneralizeFillPattern(), apply_recursively=False
        )
        walker.rewrite_module(op)

name = 'memref-stream-generalize-fill' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/memref_stream_generalize_fill.py (lines 72–76):
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Rewrite every `memref_stream.fill` in `op` using `GeneralizeFillPattern`."""
    # `ctx` is required by the ModulePass interface but unused here.
    walker = PatternRewriteWalker(
        GeneralizeFillPattern(), apply_recursively=False
    )
    walker.rewrite_module(op)