Skip to content

Lower affine

lower_affine

LowerAffineStore

Bases: RewritePattern

Source code in xdsl/transforms/lower_affine.py
85
86
87
88
89
90
91
92
93
class LowerAffineStore(RewritePattern):
    """
    Rewrites an `affine.StoreOp` into a `memref.StoreOp`, emitting the operations
    that compute the store indices from the op's affine map.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: affine.StoreOp, rewriter: PatternRewriter):
        # Materialise the index computation first; the access map is evaluated
        # with the op's indices as dims and no symbols.
        index_ops, index_vals = insert_affine_map_ops(op.map, op.indices, [])
        rewriter.insert_op(index_ops)

        # TODO: add nontemporal=false once that's added to memref
        # https://github.com/xdslproject/xdsl/issues/1482
        lowered_store = memref.StoreOp.get(op.value, op.memref, index_vals)
        rewriter.replace_op(op, lowered_store)

match_and_rewrite(op: affine.StoreOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_affine.py
86
87
88
89
90
91
92
93
@op_type_rewrite_pattern
def match_and_rewrite(self, op: affine.StoreOp, rewriter: PatternRewriter):
    """
    Replaces the affine store with a `memref.StoreOp` whose indices are computed
    by newly inserted operations.
    """
    # Materialise the index computation first; the access map is evaluated with
    # the op's indices as dims and no symbols.
    index_ops, index_vals = insert_affine_map_ops(op.map, op.indices, [])
    rewriter.insert_op(index_ops)

    # TODO: add nontemporal=false once that's added to memref
    # https://github.com/xdslproject/xdsl/issues/1482
    lowered_store = memref.StoreOp.get(op.value, op.memref, index_vals)
    rewriter.replace_op(op, lowered_store)

LowerAffineLoad

Bases: RewritePattern

Source code in xdsl/transforms/lower_affine.py
 96
 97
 98
 99
100
101
102
103
104
class LowerAffineLoad(RewritePattern):
    """
    Rewrites an `affine.LoadOp` into a `memref.LoadOp`, emitting the operations
    that compute the load indices from the op's affine map.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: affine.LoadOp, rewriter: PatternRewriter):
        # Materialise the index computation first; the access map is evaluated
        # with the op's indices as dims and no symbols.
        index_ops, index_vals = insert_affine_map_ops(op.map, op.indices, [])
        rewriter.insert_op(index_ops)

        # TODO: add nontemporal=false once that's added to memref
        # https://github.com/xdslproject/xdsl/issues/1482
        lowered_load = memref.LoadOp.get(op.memref, index_vals)
        rewriter.replace_op(op, lowered_load)

match_and_rewrite(op: affine.LoadOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_affine.py
 97
 98
 99
100
101
102
103
104
@op_type_rewrite_pattern
def match_and_rewrite(self, op: affine.LoadOp, rewriter: PatternRewriter):
    """
    Replaces the affine load with a `memref.LoadOp` whose indices are computed
    by newly inserted operations.
    """
    # Materialise the index computation first; the access map is evaluated with
    # the op's indices as dims and no symbols.
    index_ops, index_vals = insert_affine_map_ops(op.map, op.indices, [])
    rewriter.insert_op(index_ops)

    # TODO: add nontemporal=false once that's added to memref
    # https://github.com/xdslproject/xdsl/issues/1482
    lowered_load = memref.LoadOp.get(op.memref, index_vals)
    rewriter.replace_op(op, lowered_load)

LowerAffineFor

Bases: RewritePattern

Source code in xdsl/transforms/lower_affine.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
class LowerAffineFor(RewritePattern):
    """
    Rewrites an `affine.ForOp` into an `scf.ForOp`.

    Both bound maps must have exactly one result. The bound expressions are
    evaluated with empty dim and symbol lists, and the step is materialised as a
    constant.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: affine.ForOp, rewriter: PatternRewriter):
        lo_map = op.lowerBoundMap.data
        hi_map = op.upperBoundMap.data
        assert len(lo_map.results) == 1, "Affine for lower_bound must have one result"
        assert len(hi_map.results) == 1, "Affine for upper_bound must have one result"

        # Emit the bound computations and the step constant, in that order.
        lo_ops, lo_val = affine_expr_ops(lo_map.results[0], [], [])
        rewriter.insert_op(lo_ops)
        hi_ops, hi_val = affine_expr_ops(hi_map.results[0], [], [])
        rewriter.insert_op(hi_ops)
        step_const = arith.ConstantOp(op.step)
        rewriter.insert_op(step_const)

        # The loop body is moved, not copied, into the new loop.
        iter_inits = op.inits
        loop_regions = rewriter.move_region_contents_to_new_regions(op.body)
        scf_loop = scf.ForOp(
            lo_val,
            hi_val,
            step_const.result,
            iter_inits,
            loop_regions,
        )
        rewriter.replace_op(op, scf_loop)

match_and_rewrite(op: affine.ForOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_affine.py
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
@op_type_rewrite_pattern
def match_and_rewrite(self, op: affine.ForOp, rewriter: PatternRewriter):
    """
    Replaces the affine loop with an `scf.ForOp`.

    Both bound maps must have exactly one result. The bound expressions are
    evaluated with empty dim and symbol lists, and the step is materialised as a
    constant.
    """
    lo_map = op.lowerBoundMap.data
    hi_map = op.upperBoundMap.data
    assert len(lo_map.results) == 1, "Affine for lower_bound must have one result"
    assert len(hi_map.results) == 1, "Affine for upper_bound must have one result"

    # Emit the bound computations and the step constant, in that order.
    lo_ops, lo_val = affine_expr_ops(lo_map.results[0], [], [])
    rewriter.insert_op(lo_ops)
    hi_ops, hi_val = affine_expr_ops(hi_map.results[0], [], [])
    rewriter.insert_op(hi_ops)
    step_const = arith.ConstantOp(op.step)
    rewriter.insert_op(step_const)

    # The loop body is moved, not copied, into the new loop.
    iter_inits = op.inits
    loop_regions = rewriter.move_region_contents_to_new_regions(op.body)
    scf_loop = scf.ForOp(
        lo_val,
        hi_val,
        step_const.result,
        iter_inits,
        loop_regions,
    )
    rewriter.replace_op(op, scf_loop)

LowerAffineYield

Bases: RewritePattern

Source code in xdsl/transforms/lower_affine.py
132
133
134
135
class LowerAffineYield(RewritePattern):
    """
    Rewrites an `affine.YieldOp` into an `scf.YieldOp` with the same arguments.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: affine.YieldOp, rewriter: PatternRewriter, /):
        scf_yield = scf.YieldOp(*op.arguments)
        rewriter.replace_op(op, scf_yield)

match_and_rewrite(op: affine.YieldOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_affine.py
133
134
135
@op_type_rewrite_pattern
def match_and_rewrite(self, op: affine.YieldOp, rewriter: PatternRewriter, /):
    """
    Replaces the affine yield with an `scf.YieldOp` with the same arguments.
    """
    scf_yield = scf.YieldOp(*op.arguments)
    rewriter.replace_op(op, scf_yield)

LowerAffineApply

Bases: RewritePattern

Source code in xdsl/transforms/lower_affine.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
class LowerAffineApply(RewritePattern):
    """
    Rewrites an `affine.ApplyOp` into the operations that evaluate its
    single-result affine map.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: affine.ApplyOp, rewriter: PatternRewriter, /):
        apply_map = op.map.data
        assert len(apply_map.results) == 1

        map_operands = op.mapOperands
        n_dims = apply_map.num_dims
        assert n_dims + apply_map.num_symbols == len(map_operands)

        # The operands are the dims followed by the symbols.
        expr_ops, expr_val = affine_expr_ops(
            apply_map.results[0],
            map_operands[:n_dims],
            map_operands[n_dims:],
        )
        rewriter.replace_op(op, expr_ops, [expr_val])

match_and_rewrite(op: affine.ApplyOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_affine.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
@op_type_rewrite_pattern
def match_and_rewrite(self, op: affine.ApplyOp, rewriter: PatternRewriter, /):
    """
    Replaces the affine apply with the operations that evaluate its
    single-result affine map.
    """
    apply_map = op.map.data
    assert len(apply_map.results) == 1

    map_operands = op.mapOperands
    n_dims = apply_map.num_dims
    assert n_dims + apply_map.num_symbols == len(map_operands)

    # The operands are the dims followed by the symbols.
    expr_ops, expr_val = affine_expr_ops(
        apply_map.results[0],
        map_operands[:n_dims],
        map_operands[n_dims:],
    )
    rewriter.replace_op(op, expr_ops, [expr_val])

LowerAffinePass dataclass

Bases: ModulePass

Source code in xdsl/transforms/lower_affine.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
class LowerAffinePass(ModulePass):
    """
    Module pass that applies the affine lowering patterns (store, load, for,
    yield and apply) to a module.
    """

    name = "lower-affine"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        patterns: list[RewritePattern] = [
            LowerAffineStore(),
            LowerAffineLoad(),
            LowerAffineFor(),
            LowerAffineYield(),
            LowerAffineApply(),
        ]
        walker = PatternRewriteWalker(GreedyRewritePatternApplier(patterns))
        walker.rewrite_module(op)

name = 'lower-affine' class-attribute instance-attribute

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/lower_affine.py
162
163
164
165
166
167
168
169
170
171
172
173
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """
    Applies the affine lowering patterns (store, load, for, yield and apply) to
    the given module.
    """
    patterns: list[RewritePattern] = [
        LowerAffineStore(),
        LowerAffineLoad(),
        LowerAffineFor(),
        LowerAffineYield(),
        LowerAffineApply(),
    ]
    walker = PatternRewriteWalker(GreedyRewritePatternApplier(patterns))
    walker.rewrite_module(op)

affine_expr_ops(expr: affine.AffineExpr, dims: Sequence[SSAValue], symbols: Sequence[SSAValue]) -> tuple[list[Operation], SSAValue]

Returns the operations that evaluate the affine expression when given input SSA values, along with the SSAValue representing the result.

Source code in xdsl/transforms/lower_affine.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def affine_expr_ops(
    expr: affine.AffineExpr,
    dims: Sequence[SSAValue],
    symbols: Sequence[SSAValue],
) -> tuple[list[Operation], SSAValue]:
    """
    Returns the operations that evaluate the affine expression when given input SSA
    values, along with the SSAValue representing the result.
    """
    match expr:
        case AffineConstantExpr():
            constant = arith.ConstantOp(
                builtin.IntegerAttr.from_index_int_value(expr.value)
            )
            return [constant], constant.result
        case AffineDimExpr():
            return [], dims[expr.position]
        case AffineSymExpr():
            return [], symbols[expr.position]
        case AffineBinaryOpExpr():
            lhs_ops, lhs_val = affine_expr_ops(expr.lhs, dims, symbols)
            rhs_ops, rhs_val = affine_expr_ops(expr.rhs, dims, symbols)

            match expr.kind:
                case AffineBinaryOpKind.Add:
                    op = arith.AddiOp(lhs_val, rhs_val)
                case AffineBinaryOpKind.Mul:
                    op = arith.MuliOp(lhs_val, rhs_val)
                case AffineBinaryOpKind.Mod:
                    op = arith.RemSIOp(lhs_val, rhs_val)
                case AffineBinaryOpKind.FloorDiv:
                    op = arith.FloorDivSIOp(lhs_val, rhs_val)
                case AffineBinaryOpKind.CeilDiv:
                    op = arith.CeilDivSIOp(lhs_val, rhs_val)

            return [*lhs_ops, *rhs_ops, op], op.result
        case _:
            raise ValueError(f"Unexpected affine expr: {expr}")

insert_affine_map_ops(map: affine.AffineMapAttr | None, dims: Sequence[SSAValue], symbols: list[SSAValue]) -> tuple[list[Operation], list[SSAValue]]

Returns the operations that evaluate the affine map on the given input SSA values, together with the resulting indices.

Source code in xdsl/transforms/lower_affine.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def insert_affine_map_ops(
    map: affine.AffineMapAttr | None,
    dims: Sequence[SSAValue],
    symbols: list[SSAValue],
) -> tuple[list[Operation], list[SSAValue]]:
    """
    Returns the operations that evaluate the affine map on the given input SSA
    values, together with the resulting indices.

    When `map` is None, no operations are created and the indices are a copy of
    `dims`. Otherwise each result expression of the map is evaluated in order with
    `dims` and `symbols` as inputs, yielding one index per result.
    """
    if map is None:
        return [], list(dims)

    ops: list[Operation] = []
    indices: list[SSAValue] = []
    for expr in map.data.results:
        # Forward `symbols` so that symbol expressions in the map resolve to the
        # caller's values instead of indexing into an empty list.
        new_ops, val = affine_expr_ops(expr, dims, symbols)
        ops.extend(new_ops)
        indices.append(val)

    return ops, indices