Skip to content

Convert linalg to loops

convert_linalg_to_loops

LowerLinalgStructuredOpPattern

Bases: RewritePattern

Source code in xdsl/transforms/convert_linalg_to_loops.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class LowerLinalgStructuredOpPattern(RewritePattern):
    """
    Lowers a linalg structured operation with buffer (memref) semantics to an
    explicit loop nest.

    Operations producing tensor results are not supported and raise
    ``NotImplementedError``.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self,
        op: linalg.abstract_ops.LinalgStructuredOperation,
        rewriter: PatternRewriter,
    ) -> None:
        if op.res:
            raise NotImplementedError(
                f"lowering for {op.name} with tensor results not yet supported"
            )

        # The indexing maps do not change during this rewrite; fetch them once
        # instead of on every load/store callback invocation.
        indexing_maps = op.get_indexing_maps().data

        def insert_load(
            value_index: int,
            ind_vars: Sequence[SSAValue],
            rewriter: PatternRewriter,
            insertion_target: InsertPoint,
        ) -> SSAValue:
            """
            Emit a load of operand `value_index` at the indices derived from
            `ind_vars` via its indexing map. Non-memref operands (e.g.
            scalars) are passed through unchanged.
            """
            value = op.operands[value_index]
            affine_map_attr = indexing_maps[value_index]
            if isinstance(value.type, MemRefType):
                indices = indices_for_map(
                    rewriter, insertion_target, affine_map_attr.data, ind_vars
                )
                load_op = memref.LoadOp.get(value, indices)
                rewriter.insert_op(load_op, insertion_target)
                return load_op.res
            else:
                # Not a memref: nothing to load, use the SSA value directly.
                return value

        ins_count = len(op.inputs)

        def insert_store(
            output_index: int,
            value: SSAValue,
            ind_vars: Sequence[SSAValue],
            rewriter: PatternRewriter,
            insertion_target: InsertPoint,
        ):
            """
            Emit a store of `value` into output `output_index` at the indices
            derived from `ind_vars` via its indexing map.
            """
            # Outputs follow the inputs in the flat operand list.
            value_index = ins_count + output_index
            destination = op.operands[value_index]
            affine_map_attr = indexing_maps[value_index]
            indices = indices_for_map(
                rewriter, insertion_target, affine_map_attr.data, ind_vars
            )
            store_op = memref.StoreOp.get(value, destination, indices)
            rewriter.insert_op(store_op, insertion_target)
            return store_op

        insertion_point = InsertPoint.before(op)
        rewrite_linalg_structured_to_loops(
            rewriter,
            insertion_point,
            create_loop_bounds(rewriter, insertion_point, op),
            indexing_maps,
            # The trailing maps correspond to the outputs.
            indexing_maps[-len(op.outputs) :],
            op.operands,
            op.outputs,
            op.body.block,
            insert_load,
            insert_store,
        )

match_and_rewrite(op: linalg.abstract_ops.LinalgStructuredOperation, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_linalg_to_loops.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
@op_type_rewrite_pattern
def match_and_rewrite(
    self,
    op: linalg.abstract_ops.LinalgStructuredOperation,
    rewriter: PatternRewriter,
) -> None:
    """
    Lower a linalg structured operation with buffer (memref) semantics to an
    explicit loop nest.

    Raises NotImplementedError for operations with tensor results.
    """
    if op.res:
        raise NotImplementedError(
            f"lowering for {op.name} with tensor results not yet supported"
        )

    def insert_load(
        value_index: int,
        ind_vars: Sequence[SSAValue],
        rewriter: PatternRewriter,
        insertion_target: InsertPoint,
    ) -> SSAValue:
        # Emit a load of operand `value_index` at the indices derived from
        # `ind_vars` via its indexing map; non-memref operands (e.g. scalars)
        # are passed through unchanged.
        value = op.operands[value_index]
        affine_map_attr = op.get_indexing_maps().data[value_index]
        if isinstance(value.type, MemRefType):
            indices = indices_for_map(
                rewriter, insertion_target, affine_map_attr.data, ind_vars
            )
            load_op = memref.LoadOp.get(value, indices)
            rewriter.insert_op(load_op, insertion_target)
            return load_op.res
        else:
            # Not a memref: nothing to load, use the SSA value directly.
            return value

    ins_count = len(op.inputs)

    def insert_store(
        output_index: int,
        value: SSAValue,
        ind_vars: Sequence[SSAValue],
        rewriter: PatternRewriter,
        insertion_target: InsertPoint,
    ):
        # Emit a store of `value` into output `output_index` at the indices
        # derived from `ind_vars` via its indexing map. Outputs follow the
        # inputs in the flat operand list, hence the `ins_count` offset.
        value_index = ins_count + output_index
        destination = op.operands[value_index]
        affine_map_attr = op.get_indexing_maps().data[value_index]
        indices = indices_for_map(
            rewriter, insertion_target, affine_map_attr.data, ind_vars
        )
        store_op = memref.StoreOp.get(value, destination, indices)
        rewriter.insert_op(store_op, insertion_target)
        return store_op

    insertion_point = InsertPoint.before(op)
    rewrite_linalg_structured_to_loops(
        rewriter,
        insertion_point,
        create_loop_bounds(rewriter, insertion_point, op),
        op.get_indexing_maps().data,
        # The trailing maps correspond to the outputs.
        op.get_indexing_maps().data[-len(op.outputs) :],
        op.operands,
        op.outputs,
        op.body.block,
        insert_load,
        insert_store,
    )

ConvertLinalgToLoopsPass dataclass

Bases: ModulePass

Converts linalg structured ops to perfectly nested loops.

Source code in xdsl/transforms/convert_linalg_to_loops.py
153
154
155
156
157
158
159
160
161
162
163
164
class ConvertLinalgToLoopsPass(ModulePass):
    """
    Converts linalg structured ops to perfectly nested loops.
    """

    name = "convert-linalg-to-loops"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """Walk the module and lower every linalg structured op to loops."""
        # apply_recursively=False: the loop nests created by the rewrite must
        # not themselves be re-visited by the pattern.
        PatternRewriteWalker(
            LowerLinalgStructuredOpPattern(),
            apply_recursively=False,
        ).rewrite_module(op)

name = 'convert-linalg-to-loops' class-attribute instance-attribute

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/convert_linalg_to_loops.py
160
161
162
163
164
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Walk the module and lower every linalg structured op to loops."""
    # apply_recursively=False: the loop nests created by the rewrite must
    # not themselves be re-visited by the pattern.
    PatternRewriteWalker(
        LowerLinalgStructuredOpPattern(),
        apply_recursively=False,
    ).rewrite_module(op)

materialize_loop_bound(rewriter: PatternRewriter, insertion_point: InsertPoint, operand: SSAValue[MemRefType], dim_index: int, dim_size: int) -> SSAValue

Create the value used as one loop upper bound.

If the memref dimension is dynamic, use memref.dim. If it is static, use an index constant.

Source code in xdsl/transforms/convert_linalg_to_loops.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def materialize_loop_bound(
    rewriter: PatternRewriter,
    insertion_point: InsertPoint,
    operand: SSAValue[MemRefType],
    dim_index: int,
    dim_size: int,
) -> SSAValue:
    """
    Materialize an SSA value serving as one loop upper bound.

    A statically-known dimension becomes an index constant; a dynamic
    dimension is queried at runtime with `memref.dim`.
    """

    if dim_size != DYNAMIC_INDEX:
        # Static extent: a plain index constant suffices.
        bound_const = arith.ConstantOp.from_int_and_width(dim_size, IndexType())
        rewriter.insert_op(bound_const, insertion_point)
        return bound_const.result

    # Dynamic extent: ask the memref for its size along `dim_index`.
    index_const = arith.ConstantOp.from_int_and_width(dim_index, IndexType())
    rewriter.insert_op(index_const, insertion_point)

    dim_query = memref.DimOp.from_source_and_index(operand, index_const.result)
    rewriter.insert_op(dim_query, insertion_point)
    return dim_query.result

create_loop_bounds(rewriter: PatternRewriter, insertion_point: InsertPoint, op: linalg.abstract_ops.LinalgStructuredOperation) -> Sequence[SSAValue]

Build loop upper bounds for a linalg structured operation.

This lowering only supports buffer semantics, so bound sources must come from memref operands.

Source code in xdsl/transforms/convert_linalg_to_loops.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def create_loop_bounds(
    rewriter: PatternRewriter,
    insertion_point: InsertPoint,
    op: linalg.abstract_ops.LinalgStructuredOperation,
) -> Sequence[SSAValue]:
    """
    Materialize the upper bounds of the loop nest for a linalg structured
    operation.

    Only buffer semantics are supported: each loop-bound source must be a
    memref operand, otherwise a `PassFailedException` is raised.
    """
    upper_bounds: list[SSAValue] = []

    for source, dim_index, dim_size in op.get_loop_bound_sources():
        # Bound sources coming from tensors cannot be lowered here.
        if not isa(source, SSAValue[MemRefType]):
            raise PassFailedException(
                "convert-linalg-to-loops requires buffer semantics; "
                "tensor operands must be bufferized to memrefs before lowering"
            )

        bound = materialize_loop_bound(
            rewriter, insertion_point, source, dim_index, dim_size
        )
        upper_bounds.append(bound)

    return upper_bounds