Skip to content

Linalg to csl

linalg_to_csl

ConvertLinalgGenericFMAPass

Bases: RewritePattern

Lowers linalg.generic fused multiply-adds to csl builtin ops.

Source code in xdsl/transforms/linalg_to_csl.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
class ConvertLinalgGenericFMAPass(RewritePattern):
    """Lowers `linalg.generic` fused multiply-adds to csl builtin ops.

    Matches generics whose body computes `out = in0 * in1 + in2` (see `is_fma`) and
    where one of the two factors `in0`/`in1` is a scalar constant, and replaces them
    with `csl.FmachOp` (f16) or `csl.FmacsOp` (f32).
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: linalg.GenericOp, rewriter: PatternRewriter, /):
        if not self.is_fma(op) or not isa(op.outputs.types[0], MemRefType):
            return

        # one of the factors must be a scalar const, which the csl function signatures
        # require; the other factor becomes the collection operand `x`
        if (scalar_const := get_scalar_const(op.inputs[0])) is not None:
            x = op.inputs[1]
        elif (scalar_const := get_scalar_const(op.inputs[1])) is not None:
            x = op.inputs[0]
        else:
            # if neither factor is a scalar, the pattern does not apply
            return

        # fetch the csl op to build depending on the precision; this raises for
        # unsupported element types, so resolve it before mutating the IR
        csl_op = match_op_for_precision(
            op.outputs.types[0].get_element_type(), f16=csl.FmachOp, f32=csl.FmacsOp
        )

        rewriter.insert_op(
            a := arith.ConstantOp(scalar_const), InsertPoint.before(op)
        )

        r = op.outputs[0]
        y = op.inputs[2]

        # builds `r = a * x + y`
        rewriter.replace_op(op, csl_op(operands=[[r, y, x, a.result]]))

    @staticmethod
    def is_fma(op: linalg.GenericOp) -> bool:
        """Returns whether a given `generic` op is a fused multiply-add.

        The body must consist of exactly `mulf(arg0, arg1)`, then
        `addf(mul, arg2)`, then a `yield` of the sum, with three inputs and one output.
        """
        return (
            len(op.inputs) == 3
            and len(op.outputs) == 1
            and len((block := op.body.block).args) == 4
            and len(block.ops) == 3
            and isinstance(mul := block.first_op, arith.MulfOp)
            and mul.lhs == block.args[0]
            and mul.rhs == block.args[1]
            and isinstance(add := mul.next_op, arith.AddfOp)
            and add.lhs == mul.result
            and add.rhs == block.args[2]
            and isinstance(yld := add.next_op, linalg.YieldOp)
            and yld.operands[0] == add.result
        )

match_and_rewrite(op: linalg.GenericOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/linalg_to_csl.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg.GenericOp, rewriter: PatternRewriter, /):
    """Rewrite an FMA-shaped `linalg.generic` with a scalar-constant factor to csl."""
    if not self.is_fma(op) or not isa(op.outputs.types[0], MemRefType):
        return

    # one of the factors must be a scalar const, which the csl function signatures
    # require; the other factor becomes the collection operand `x`
    if (scalar_const := get_scalar_const(op.inputs[0])) is not None:
        x = op.inputs[1]
    elif (scalar_const := get_scalar_const(op.inputs[1])) is not None:
        x = op.inputs[0]
    else:
        # if neither factor is a scalar, the pattern does not apply
        return

    # fetch the csl op to build depending on the precision; this raises for
    # unsupported element types, so resolve it before mutating the IR
    csl_op = match_op_for_precision(
        op.outputs.types[0].get_element_type(), f16=csl.FmachOp, f32=csl.FmacsOp
    )

    rewriter.insert_op(
        a := arith.ConstantOp(scalar_const), InsertPoint.before(op)
    )

    r = op.outputs[0]
    y = op.inputs[2]

    # builds `r = a * x + y`
    rewriter.replace_op(op, csl_op(operands=[[r, y, x, a.result]]))

is_fma(op: linalg.GenericOp) -> bool staticmethod

Returns whether a given generic op is a fused multiply-add.

Source code in xdsl/transforms/linalg_to_csl.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
@staticmethod
def is_fma(op: linalg.GenericOp) -> bool:
    """
    Check whether `op` is a `generic` computing `out = in0 * in1 + in2`.

    Requires three inputs, one output, and a body of exactly three ops:
    `mulf(arg0, arg1)`, then `addf(<product>, arg2)`, then a `yield` of the sum.
    """
    if len(op.inputs) != 3 or len(op.outputs) != 1:
        return False

    body = op.body.block
    if len(body.args) != 4 or len(body.ops) != 3:
        return False

    product = body.first_op
    if not isinstance(product, arith.MulfOp):
        return False
    if not (product.lhs == body.args[0] and product.rhs == body.args[1]):
        return False

    total = product.next_op
    if not isinstance(total, arith.AddfOp):
        return False
    if not (total.lhs == product.result and total.rhs == body.args[2]):
        return False

    terminator = total.next_op
    if not isinstance(terminator, linalg.YieldOp):
        return False
    return terminator.operands[0] == total.result

ConvertLinalgMinPass

Bases: RewritePattern

Lowers the linalg.min op to csl using the identity min(a, b) = -max(-a, -b). The operands are negated in place, max is computed, and then the operands are negated again (restoring them) together with the result.

TODO: scalar operands are currently not supported.

Source code in xdsl/transforms/linalg_to_csl.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
class ConvertLinalgMinPass(RewritePattern):
    """
    Lowers the `linalg.min` op to csl by negating the operands, performing max, and
    again negating the operands as well as the result.

    This relies on `min(a, b) == -max(-a, -b)`: the inputs are negated in place,
    `max` is written into the output, then the inputs are negated back and the
    output is negated to yield the minimum.

    todo: scalar operands are currently not supported
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: linalg.MinOp, rewriter: PatternRewriter, /):
        if not isa(op.outputs.types[0], MemRefType):
            return

        # operands to be negated before and after, de-duplicated (the same value may
        # appear several times, e.g. an output that aliases an input) while keeping
        # operand order; a `set` would order the inserted ops by object hash and make
        # the emitted IR differ from run to run
        negate_before = list(dict.fromkeys(op.inputs))
        negate_after = list(dict.fromkeys((*op.inputs, *op.outputs)))

        # builtin op for negating
        neg_op = match_op_for_precision(
            op.outputs.types[0].get_element_type(), f16=csl.FneghOp, f32=csl.FnegsOp
        )

        # constructing in-place negate ops (dest == src) before and after
        before_ops = [neg_op(operands=[(o, o)]) for o in negate_before]
        after_ops = [neg_op(operands=[(o, o)]) for o in negate_after]

        rewriter.insert_op(before_ops, InsertPoint.before(op))
        rewriter.insert_op(after_ops, InsertPoint.after(op))
        transform_op(op, rewriter, f16=csl.FmaxhOp, f32=csl.FmaxsOp)

match_and_rewrite(op: linalg.MinOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/linalg_to_csl.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg.MinOp, rewriter: PatternRewriter, /):
    """Lower `linalg.min` as `-max(-a, -b)` using in-place csl negations."""
    if not isa(op.outputs.types[0], MemRefType):
        return

    # operands to be negated before and after, de-duplicated (the same value may
    # appear several times, e.g. an output that aliases an input) while keeping
    # operand order; a `set` would order the inserted ops by object hash and make
    # the emitted IR differ from run to run
    negate_before = list(dict.fromkeys(op.inputs))
    negate_after = list(dict.fromkeys((*op.inputs, *op.outputs)))

    # builtin op for negating
    neg_op = match_op_for_precision(
        op.outputs.types[0].get_element_type(), f16=csl.FneghOp, f32=csl.FnegsOp
    )

    # constructing in-place negate ops (dest == src) before and after
    before_ops = [neg_op(operands=[(o, o)]) for o in negate_before]
    after_ops = [neg_op(operands=[(o, o)]) for o in negate_after]

    rewriter.insert_op(before_ops, InsertPoint.before(op))
    rewriter.insert_op(after_ops, InsertPoint.after(op))
    transform_op(op, rewriter, f16=csl.FmaxhOp, f32=csl.FmaxsOp)

ConvertLinalgAddPass

Bases: RewritePattern

Source code in xdsl/transforms/linalg_to_csl.py
165
166
167
168
class ConvertLinalgAddPass(RewritePattern):
    """Lowers `linalg.add` to `csl.FaddhOp` (f16) or `csl.FaddsOp` (f32)."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: linalg.AddOp, rewriter: PatternRewriter, /):
        # precision-specific selection and replacement are handled by `transform_op`
        transform_op(op, rewriter, f32=csl.FaddsOp, f16=csl.FaddhOp)

match_and_rewrite(op: linalg.AddOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/linalg_to_csl.py
166
167
168
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg.AddOp, rewriter: PatternRewriter, /):
    """Rewrite `op` into `csl.FaddhOp` (f16) or `csl.FaddsOp` (f32)."""
    transform_op(op, rewriter, f32=csl.FaddsOp, f16=csl.FaddhOp)

ConvertLinalgSubPass

Bases: RewritePattern

Source code in xdsl/transforms/linalg_to_csl.py
171
172
173
174
class ConvertLinalgSubPass(RewritePattern):
    """Lowers `linalg.sub` to `csl.FsubhOp` (f16) or `csl.FsubsOp` (f32)."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: linalg.SubOp, rewriter: PatternRewriter, /):
        # precision-specific selection and replacement are handled by `transform_op`
        transform_op(op, rewriter, f32=csl.FsubsOp, f16=csl.FsubhOp)

match_and_rewrite(op: linalg.SubOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/linalg_to_csl.py
172
173
174
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg.SubOp, rewriter: PatternRewriter, /):
    """Rewrite `op` into `csl.FsubhOp` (f16) or `csl.FsubsOp` (f32)."""
    transform_op(op, rewriter, f32=csl.FsubsOp, f16=csl.FsubhOp)

ConvertLinalgMulPass

Bases: RewritePattern

Source code in xdsl/transforms/linalg_to_csl.py
177
178
179
180
class ConvertLinalgMulPass(RewritePattern):
    """Lowers `linalg.mul` to `csl.FmulhOp` (f16) or `csl.FmulsOp` (f32)."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: linalg.MulOp, rewriter: PatternRewriter, /):
        # precision-specific selection and replacement are handled by `transform_op`
        transform_op(op, rewriter, f32=csl.FmulsOp, f16=csl.FmulhOp)

match_and_rewrite(op: linalg.MulOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/linalg_to_csl.py
178
179
180
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg.MulOp, rewriter: PatternRewriter, /):
    """Rewrite `op` into `csl.FmulhOp` (f16) or `csl.FmulsOp` (f32)."""
    transform_op(op, rewriter, f32=csl.FmulsOp, f16=csl.FmulhOp)

ConvertLinalgMaxPass

Bases: RewritePattern

Source code in xdsl/transforms/linalg_to_csl.py
183
184
185
186
class ConvertLinalgMaxPass(RewritePattern):
    """Lowers `linalg.max` to `csl.FmaxhOp` (f16) or `csl.FmaxsOp` (f32)."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: linalg.MaxOp, rewriter: PatternRewriter, /):
        # precision-specific selection and replacement are handled by `transform_op`
        transform_op(op, rewriter, f32=csl.FmaxsOp, f16=csl.FmaxhOp)

match_and_rewrite(op: linalg.MaxOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/linalg_to_csl.py
184
185
186
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg.MaxOp, rewriter: PatternRewriter, /):
    """Rewrite `op` into `csl.FmaxhOp` (f16) or `csl.FmaxsOp` (f32)."""
    transform_op(op, rewriter, f32=csl.FmaxsOp, f16=csl.FmaxhOp)

LinalgToCsl dataclass

Bases: ModulePass

Convert linalg ops to csl ops.

The linalg ops are required to be in 'memref mode', i.e., after bufferization has been applied.

Source code in xdsl/transforms/linalg_to_csl.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
@dataclass(frozen=True)
class LinalgToCsl(ModulePass):
    """
    Convert linalg ops to csl ops.

    The linalg ops are required to be in 'memref mode', i.e., after bufferization has been applied.
    """

    name = "linalg-to-csl"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        # the patterns are handed to the greedy applier in this exact order
        patterns = [
            ConvertLinalgGenericFMAPass(),
            ConvertLinalgAddPass(),
            ConvertLinalgSubPass(),
            ConvertLinalgMulPass(),
            ConvertLinalgMaxPass(),
            ConvertLinalgMinPass(),
        ]
        walker = PatternRewriteWalker(GreedyRewritePatternApplier(patterns))
        walker.rewrite_module(op)

name = 'linalg-to-csl' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/linalg_to_csl.py
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Greedily apply all linalg-to-csl rewrite patterns to `op`."""
    # the patterns are handed to the greedy applier in this exact order
    patterns = [
        ConvertLinalgGenericFMAPass(),
        ConvertLinalgAddPass(),
        ConvertLinalgSubPass(),
        ConvertLinalgMulPass(),
        ConvertLinalgMaxPass(),
        ConvertLinalgMinPass(),
    ]
    walker = PatternRewriteWalker(GreedyRewritePatternApplier(patterns))
    walker.rewrite_module(op)

match_op_for_precision(prec: Attribute, f16: type[csl.BuiltinDsdOp], f32: type[csl.BuiltinDsdOp]) -> type[csl.BuiltinDsdOp]

Returns the op type matching a given precision.

Source code in xdsl/transforms/linalg_to_csl.py
26
27
28
29
30
31
32
33
34
35
36
37
def match_op_for_precision(
    prec: Attribute, f16: type[csl.BuiltinDsdOp], f32: type[csl.BuiltinDsdOp]
) -> type[csl.BuiltinDsdOp]:
    """
    Select the op type for the element type `prec`.

    Returns `f16` when `prec` is `builtin.f16` and `f32` when it is `builtin.f32`;
    any other element type raises `ValueError`.
    """
    # todo support mixed-precision
    if prec == builtin.f16:
        return f16
    if prec == builtin.f32:
        return f32
    raise ValueError(f"Unsupported element type {prec}")

get_scalar_const(op: SSAValue) -> FloatAttr | IntegerAttr | None

Returns the value of a scalar (splat) arith.constant, or None if the value is not a constant or not scalar.

Source code in xdsl/transforms/linalg_to_csl.py
40
41
42
43
44
45
46
47
48
def get_scalar_const(op: SSAValue) -> FloatAttr | IntegerAttr | None:
    """
    Returns the value of a scalar `arith.constant`, or None if `op` is not the
    result of an `arith.constant` or the constant is not scalar.

    A constant counts as scalar here when its value is a splat
    `DenseIntOrFPElementsAttr`; the first element attribute of that splat is returned.
    """
    if (
        isinstance(op, OpResult)
        and isinstance(op.op, arith.ConstantOp)
        and isa(val := op.op.value, DenseIntOrFPElementsAttr)
        and val.is_splat()
    ):
        return val.get_attrs()[0]
    return None

transform_op(op: linalg.NamedOperation, rewriter: PatternRewriter, f16: type[csl.BuiltinDsdOp], f32: type[csl.BuiltinDsdOp])

Source code in xdsl/transforms/linalg_to_csl.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def transform_op(
    op: linalg.NamedOperation,
    rewriter: PatternRewriter,
    f16: type[csl.BuiltinDsdOp],
    f32: type[csl.BuiltinDsdOp],
):
    """
    Replace a binary linalg named op with the csl builtin matching its precision.

    The replacement is `f16` or `f32` (chosen by the output element type), built with
    operands `(output, lhs, rhs)`. Ops whose output is not a memref are left untouched.
    If `lhs` (or, failing that, `rhs`) is a scalar constant, a fresh `arith.constant`
    holding that scalar is inserted before `op` and used in its place.
    """
    if not isa(target_t := op.outputs.types[0], MemRefType):
        return

    # named `csl_op` rather than `builtin` so the module-level `builtin` dialect
    # name is not shadowed inside this function
    csl_op = match_op_for_precision(target_t.get_element_type(), f16, f32)

    lhs = op.inputs[0]
    rhs = op.inputs[1]

    # binary functions translated here support mixing scalar and collection operands
    # may need revisiting if more functions are translated
    if (scalar_const := get_scalar_const(lhs)) is not None:
        rewriter.insert_op(
            const_op := arith.ConstantOp(scalar_const), InsertPoint.before(op)
        )
        lhs = const_op.result
    elif (scalar_const := get_scalar_const(rhs)) is not None:
        rewriter.insert_op(
            const_op := arith.ConstantOp(scalar_const), InsertPoint.before(op)
        )
        rhs = const_op.result

    rewriter.replace_op(op, csl_op(operands=[[op.outputs[0], lhs, rhs]]))