Skip to content

Convert linalg to memref stream

convert_linalg_to_memref_stream

ConvertGenericOpPattern

Bases: RewritePattern

Source code in xdsl/transforms/convert_linalg_to_memref_stream.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class ConvertGenericOpPattern(RewritePattern):
    """
    Rewrite a `linalg.generic` that has no results into a `memref_stream.generic`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: linalg.GenericOp, rewriter: PatternRewriter
    ) -> None:
        """
        Replace `op` with a `memref_stream.generic`.

        The new op reuses the inputs, outputs, indexing maps, `doc` and
        `library_call` of `op`, takes over the contents of its body region,
        and stores the static loop bounds explicitly.

        Raises:
            NotImplementedError: if `op` produces results.
        """
        if op.res:
            raise NotImplementedError(
                "converting linalg.generic with results not supported"
            )

        # memref_stream.generic operands may be streams, which carry no shape,
        # whereas linalg.generic derives its loop nest from operand shapes.
        # Capture those bounds now, as index-typed integer attributes, because
        # they may no longer be recoverable after the conversion.
        upper_bounds = op.get_static_loop_ranges()
        index_type = IndexType()
        bounds = ArrayAttr(
            tuple(IntegerAttr(IntAttr(bound), index_type) for bound in upper_bounds)
        )

        # Translate each linalg iterator kind into its memref_stream form.
        iterator_types = ArrayAttr(
            tuple(iterator_type_attr(kind) for kind in op.iterator_types)
        )

        inputs, outputs = op.inputs, op.outputs
        # Done last so that any failure above leaves `op` untouched.
        new_body = rewriter.move_region_contents_to_new_regions(op.body)
        stream_generic = memref_stream.GenericOp(
            inputs,
            outputs,
            (),
            new_body,
            op.indexing_maps,
            iterator_types,
            bounds,
            ArrayAttr(()),
            op.doc,
            op.library_call,
        )
        rewriter.replace_op(op, stream_generic)

match_and_rewrite(op: linalg.GenericOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_linalg_to_memref_stream.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: linalg.GenericOp, rewriter: PatternRewriter
) -> None:
    """
    Replace `op` with a `memref_stream.generic`.

    The new op reuses the inputs, outputs, indexing maps, `doc` and
    `library_call` of `op`, takes over the contents of its body region,
    and stores the static loop bounds explicitly.

    Raises:
        NotImplementedError: if `op` produces results.
    """
    if op.res:
        raise NotImplementedError(
            "converting linalg.generic with results not supported"
        )

    # memref_stream.generic operands may be streams, which carry no shape,
    # whereas linalg.generic derives its loop nest from operand shapes.
    # Capture those bounds now, as index-typed integer attributes, because
    # they may no longer be recoverable after the conversion.
    upper_bounds = op.get_static_loop_ranges()
    index_type = IndexType()
    bounds = ArrayAttr(
        tuple(IntegerAttr(IntAttr(bound), index_type) for bound in upper_bounds)
    )

    # Translate each linalg iterator kind into its memref_stream form.
    iterator_types = ArrayAttr(
        tuple(iterator_type_attr(kind) for kind in op.iterator_types)
    )

    inputs, outputs = op.inputs, op.outputs
    # Done last so that any failure above leaves `op` untouched.
    new_body = rewriter.move_region_contents_to_new_regions(op.body)
    stream_generic = memref_stream.GenericOp(
        inputs,
        outputs,
        (),
        new_body,
        op.indexing_maps,
        iterator_types,
        bounds,
        ArrayAttr(()),
        op.doc,
        op.library_call,
    )
    rewriter.replace_op(op, stream_generic)

ConvertYieldOpPattern

Bases: RewritePattern

Source code in xdsl/transforms/convert_linalg_to_memref_stream.py
61
62
63
64
class ConvertYieldOpPattern(RewritePattern):
    """Rewrite `linalg.yield` into a `memref_stream.yield` with the same operands."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: linalg.YieldOp, rewriter: PatternRewriter) -> None:
        """Replace `op` with a `memref_stream.yield` that forwards all of its operands."""
        yielded_values = op.operands
        replacement = memref_stream.YieldOp(*yielded_values)
        rewriter.replace_op(op, replacement)

match_and_rewrite(op: linalg.YieldOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_linalg_to_memref_stream.py
62
63
64
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg.YieldOp, rewriter: PatternRewriter) -> None:
    """Replace `op` with a `memref_stream.yield` that forwards all of its operands."""
    yielded_values = op.operands
    replacement = memref_stream.YieldOp(*yielded_values)
    rewriter.replace_op(op, replacement)

ConvertLinalgToMemRefStreamPass dataclass

Bases: ModulePass

Source code in xdsl/transforms/convert_linalg_to_memref_stream.py
67
68
69
70
71
72
73
74
75
76
class ConvertLinalgToMemRefStreamPass(ModulePass):
    """
    Module pass that rewrites `linalg.generic` and `linalg.yield` ops into
    their `memref_stream` counterparts.
    """

    name = "convert-linalg-to-memref-stream"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """Apply the generic-op and yield-op conversion patterns to `op`."""
        conversion_patterns = GreedyRewritePatternApplier(
            [ConvertGenericOpPattern(), ConvertYieldOpPattern()]
        )
        # NOTE(review): apply_recursively=False presumably means a single walk
        # is enough, since the produced ops are not matched by these patterns.
        walker = PatternRewriteWalker(conversion_patterns, apply_recursively=False)
        walker.rewrite_module(op)

name = 'convert-linalg-to-memref-stream' class-attribute instance-attribute

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/convert_linalg_to_memref_stream.py
70
71
72
73
74
75
76
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Apply the generic-op and yield-op conversion patterns to `op`."""
    conversion_patterns = GreedyRewritePatternApplier(
        [ConvertGenericOpPattern(), ConvertYieldOpPattern()]
    )
    # NOTE(review): apply_recursively=False presumably means a single walk
    # is enough, since the produced ops are not matched by these patterns.
    walker = PatternRewriteWalker(conversion_patterns, apply_recursively=False)
    walker.rewrite_module(op)

iterator_type_attr(t: linalg.IteratorTypeAttr) -> memref_stream.IteratorTypeAttr

Source code in xdsl/transforms/convert_linalg_to_memref_stream.py
14
15
16
17
18
19
20
21
def iterator_type_attr(t: linalg.IteratorTypeAttr) -> memref_stream.IteratorTypeAttr:
    match t.data:
        case linalg.IteratorType.PARALLEL:
            return memref_stream.IteratorTypeAttr.parallel()
        case linalg.IteratorType.REDUCTION:
            return memref_stream.IteratorTypeAttr.reduction()
        case linalg.IteratorType.WINDOW:
            raise NotImplementedError("Cannot convert window iterator type")