Skip to content

Convert linalg to loops

convert_linalg_to_loops

LowerGenericOpPattern

Bases: RewritePattern

Source code in xdsl/transforms/convert_linalg_to_loops.py (lines 22–80):
class LowerGenericOpPattern(RewritePattern):
    """
    Rewrite a `linalg.generic` op into explicit loops.

    Loop construction itself is delegated to `rewrite_generic_to_loops`; this
    pattern supplies the callbacks that load operands and store outputs at
    each iteration point, using the op's indexing maps.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: linalg.GenericOp, rewriter: PatternRewriter
    ) -> None:
        """
        Lower `op` to loops, inserting the generated code before `op`.

        Raises:
            NotImplementedError: if `op` has SSA results; only generic ops
                without results are handled.
        """
        if op.res:
            raise NotImplementedError(
                "lowering for linalg.generic with results not yet supported"
            )

        def insert_load(
            value_index: int,
            ind_vars: Sequence[SSAValue],
            rewriter: PatternRewriter,
            insertion_target: InsertPoint,
        ) -> SSAValue:
            """
            Return the value of operand `value_index` for the current loop
            indices `ind_vars`.

            Memref operands are read with a `memref.load` at the indices
            produced by that operand's indexing map; any other operand is
            returned unchanged.
            """
            value = op.operands[value_index]
            affine_map_attr = op.indexing_maps.data[value_index]
            if isinstance(value.type, MemRefType):
                indices = indices_for_map(
                    rewriter, insertion_target, affine_map_attr.data, ind_vars
                )
                load_op = memref.LoadOp.get(value, indices)
                rewriter.insert_op(load_op, insertion_target)
                return load_op.res
            else:
                return value

        # Output operands (and their indexing maps) follow the inputs.
        ins_count = len(op.inputs)

        def insert_store(
            output_index: int,
            value: SSAValue,
            ind_vars: Sequence[SSAValue],
            rewriter: PatternRewriter,
            insertion_target: InsertPoint,
        ) -> memref.StoreOp:
            """
            Store `value` into output `output_index` at the indices produced
            by that output's indexing map, and return the inserted store.
            """
            value_index = ins_count + output_index
            destination = op.operands[value_index]
            affine_map_attr = op.indexing_maps.data[value_index]
            indices = indices_for_map(
                rewriter, insertion_target, affine_map_attr.data, ind_vars
            )
            store_op = memref.StoreOp.get(value, destination, indices)
            rewriter.insert_op(store_op, insertion_target)
            return store_op

        # Slice the output maps from `ins_count`, the same offset that
        # `insert_store` uses. The previous `[-len(op.outputs):]` selects
        # *every* map when there are no outputs, because `[-0:]` == `[0:]`.
        output_indexing_maps = op.indexing_maps.data[ins_count:]

        rewrite_generic_to_loops(
            rewriter,
            InsertPoint.before(op),
            op.get_static_loop_ranges(),
            op.indexing_maps.data,
            output_indexing_maps,
            op.operands,
            op.outputs,
            op.body.block,
            insert_load,
            insert_store,
        )

match_and_rewrite(op: linalg.GenericOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_linalg_to_loops.py (lines 23–80):
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: linalg.GenericOp, rewriter: PatternRewriter
) -> None:
    """
    Lower the `linalg.generic` op `op` to loops inserted before it.

    Loop construction is delegated to `rewrite_generic_to_loops`; the local
    callbacks load operands and store outputs using the op's indexing maps.

    Raises:
        NotImplementedError: if `op` has SSA results; only generic ops
            without results are handled.
    """
    if op.res:
        raise NotImplementedError(
            "lowering for linalg.generic with results not yet supported"
        )

    def insert_load(
        value_index: int,
        ind_vars: Sequence[SSAValue],
        rewriter: PatternRewriter,
        insertion_target: InsertPoint,
    ) -> SSAValue:
        """
        Return the value of operand `value_index` for the current loop
        indices `ind_vars`.

        Memref operands are read with a `memref.load` at the indices produced
        by that operand's indexing map; any other operand is returned
        unchanged.
        """
        value = op.operands[value_index]
        affine_map_attr = op.indexing_maps.data[value_index]
        if isinstance(value.type, MemRefType):
            indices = indices_for_map(
                rewriter, insertion_target, affine_map_attr.data, ind_vars
            )
            load_op = memref.LoadOp.get(value, indices)
            rewriter.insert_op(load_op, insertion_target)
            return load_op.res
        else:
            return value

    # Output operands (and their indexing maps) follow the inputs.
    ins_count = len(op.inputs)

    def insert_store(
        output_index: int,
        value: SSAValue,
        ind_vars: Sequence[SSAValue],
        rewriter: PatternRewriter,
        insertion_target: InsertPoint,
    ) -> memref.StoreOp:
        """
        Store `value` into output `output_index` at the indices produced by
        that output's indexing map, and return the inserted store.
        """
        value_index = ins_count + output_index
        destination = op.operands[value_index]
        affine_map_attr = op.indexing_maps.data[value_index]
        indices = indices_for_map(
            rewriter, insertion_target, affine_map_attr.data, ind_vars
        )
        store_op = memref.StoreOp.get(value, destination, indices)
        rewriter.insert_op(store_op, insertion_target)
        return store_op

    # Slice the output maps from `ins_count`, the same offset that
    # `insert_store` uses. The previous `[-len(op.outputs):]` selects *every*
    # map when there are no outputs, because `[-0:]` == `[0:]`.
    output_indexing_maps = op.indexing_maps.data[ins_count:]

    rewrite_generic_to_loops(
        rewriter,
        InsertPoint.before(op),
        op.get_static_loop_ranges(),
        op.indexing_maps.data,
        output_indexing_maps,
        op.operands,
        op.outputs,
        op.body.block,
        insert_load,
        insert_store,
    )

ConvertLinalgToLoopsPass dataclass

Bases: ModulePass

Converts `linalg.generic` operations into perfectly nested loops.

Source code in xdsl/transforms/convert_linalg_to_loops.py (lines 83–94):
class ConvertLinalgToLoopsPass(ModulePass):
    """
    Module pass that lowers `linalg.generic` ops to perfectly nested loops.

    The work is done by `LowerGenericOpPattern`, driven over the module by a
    `PatternRewriteWalker`.
    """

    name = "convert-linalg-to-loops"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """Run `LowerGenericOpPattern` over the module `op`; `ctx` is unused."""
        patterns = GreedyRewritePatternApplier([LowerGenericOpPattern()])
        # NOTE(review): apply_recursively=False presumably means the walker
        # does not revisit ops after a rewrite — confirm against
        # PatternRewriteWalker.
        walker = PatternRewriteWalker(patterns, apply_recursively=False)
        walker.rewrite_module(op)

name = 'convert-linalg-to-loops' (class attribute, instance attribute)

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/convert_linalg_to_loops.py (lines 90–94):
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Run `LowerGenericOpPattern` over the module `op`; `ctx` is unused."""
    patterns = GreedyRewritePatternApplier([LowerGenericOpPattern()])
    # NOTE(review): apply_recursively=False presumably means the walker does
    # not revisit ops after a rewrite — confirm against PatternRewriteWalker.
    walker = PatternRewriteWalker(patterns, apply_recursively=False)
    walker.rewrite_module(op)