Skip to content

Memref stream infer fill

memref_stream_infer_fill

InferFillPattern

Bases: RewritePattern

Source code in xdsl/transforms/memref_stream_infer_fill.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class InferFillPattern(RewritePattern):
    """
    Replace a `memref_stream.generic` whose body only yields its first block
    argument with an equivalent `memref_stream.fill`.

    The generic is rewritten only if it has exactly one input and one output,
    no inits, exclusively parallel iterators, an output memref whose shape is
    equal to the op's bounds, and an input whose type is equal to the output's
    element type.

    NOTE(review): the op's indexing maps are not inspected by this pattern;
    confirm that fill semantics hold for any output map that can reach here.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: memref_stream.GenericOp, rewriter: PatternRewriter
    ) -> None:
        """
        Rewrite `op` into `memref_stream.fill(output, input)` if it matches.

        If any structural check fails, return without touching the IR.
        """
        # Exactly one value in, one buffer out, and no initial values.
        if len(op.inputs) != 1 or len(op.outputs) != 1 or op.inits:
            return

        # Every loop dimension must be parallel.
        parallel = memref_stream.IteratorType.PARALLEL
        if not all(kind.data == parallel for kind in op.iterator_types.data):
            return

        fill_value = op.inputs[0]
        destination = op.outputs[0]

        # The destination must be a memref. The cast presumably exists to give
        # the type checker a concrete MemRefType; it is a no-op at runtime.
        if not isinstance(destination.type, memref.MemRefType):
            return
        destination_type = cast(memref.MemRefType, destination.type)

        # The iteration space must coincide with the memref's static shape.
        shape = destination_type.get_shape()
        if shape != tuple(bound.value.data for bound in op.bounds):
            return

        # The value being written must have the memref's element type.
        if fill_value.type != destination_type.element_type:
            return

        # The body must consist of a single yield of block argument 0.
        body = op.body.block
        body_ops = tuple(body.ops)
        if len(body_ops) != 1:
            return
        (terminator,) = body_ops
        if not isinstance(terminator, memref_stream.YieldOp):
            return
        yielded = tuple(terminator.operands)
        if len(yielded) != 1 or yielded[0] is not body.args[0]:
            return

        rewriter.replace_op(op, memref_stream.FillOp(destination, fill_value))

match_and_rewrite(op: memref_stream.GenericOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/memref_stream_infer_fill.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: memref_stream.GenericOp, rewriter: PatternRewriter
) -> None:
    """
    Rewrite `op` into `memref_stream.fill(output, input)` if it matches.

    If any structural check fails, return without touching the IR.
    NOTE(review): the op's indexing maps are not inspected here.
    """
    # Exactly one value in, one buffer out, and no initial values.
    if len(op.inputs) != 1 or len(op.outputs) != 1 or op.inits:
        return

    # Every loop dimension must be parallel.
    parallel = memref_stream.IteratorType.PARALLEL
    if not all(kind.data == parallel for kind in op.iterator_types.data):
        return

    fill_value = op.inputs[0]
    destination = op.outputs[0]

    # The destination must be a memref. The cast presumably exists to give
    # the type checker a concrete MemRefType; it is a no-op at runtime.
    if not isinstance(destination.type, memref.MemRefType):
        return
    destination_type = cast(memref.MemRefType, destination.type)

    # The iteration space must coincide with the memref's static shape.
    shape = destination_type.get_shape()
    if shape != tuple(bound.value.data for bound in op.bounds):
        return

    # The value being written must have the memref's element type.
    if fill_value.type != destination_type.element_type:
        return

    # The body must consist of a single yield of block argument 0.
    body = op.body.block
    body_ops = tuple(body.ops)
    if len(body_ops) != 1:
        return
    (terminator,) = body_ops
    if not isinstance(terminator, memref_stream.YieldOp):
        return
    yielded = tuple(terminator.operands)
    if len(yielded) != 1 or yielded[0] is not body.args[0]:
        return

    rewriter.replace_op(op, memref_stream.FillOp(destination, fill_value))

MemRefStreamInferFillPass dataclass

Bases: ModulePass

Detects `memref_stream.generic` operations that can be represented as `memref_stream.fill` ops.

Source code in xdsl/transforms/memref_stream_infer_fill.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
@dataclass(frozen=True)
class MemRefStreamInferFillPass(ModulePass):
    """
    Detects `memref_stream.generic` operations that can be represented as
    `memref_stream.fill` ops, and rewrites them using `InferFillPattern`.
    """

    name = "memref-stream-infer-fill"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """Walk `op` once with `InferFillPattern` (`apply_recursively=False`)."""
        walker = PatternRewriteWalker(InferFillPattern(), apply_recursively=False)
        walker.rewrite_module(op)

name = 'memref-stream-infer-fill' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/memref_stream_infer_fill.py
80
81
82
83
84
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Walk `op` once with `InferFillPattern` (`apply_recursively=False`)."""
    walker = PatternRewriteWalker(InferFillPattern(), apply_recursively=False)
    walker.rewrite_module(op)