Skip to content

Varith transformations

varith_transformations

ARITH_TO_VARITH_TYPE_MAP: dict[type[arith.SignlessIntegerBinaryOperation | arith.FloatingPointLikeBinaryOperation], type[varith.VarithOp]] = {arith.AddiOp: varith.VarithAddOp, arith.AddfOp: varith.VarithAddOp, arith.MuliOp: varith.VarithMulOp, arith.MulfOp: varith.VarithMulOp} module-attribute

ARITH_TYPES: dict[tuple[Literal['float', 'int'], Literal['add', 'mul']], type[arith.SignlessIntegerBinaryOperation | arith.FloatingPointLikeBinaryOperation]] = {('int', 'add'): arith.AddiOp, ('int', 'mul'): arith.MuliOp, ('float', 'add'): arith.AddfOp, ('float', 'mul'): arith.MulfOp} module-attribute

ArithToVarithPattern

Bases: RewritePattern

Merges two arith operations into a varith operation.

Source code in xdsl/transforms/varith_transformations.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class ArithToVarithPattern(RewritePattern):
    """
    Merges two arith operations into a varith operation.

    The rewrite fires when the result of an `arith.{add,mul}{i,f}` op has exactly
    one user, and that user is either an op of the same arith class or the
    varith op this arith class maps to. The producer and its user are replaced
    by a single varith op whose operands are the producer's `lhs`/`rhs` followed
    by the user's other operands.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self,
        op: arith.AddiOp | arith.AddfOp | arith.MuliOp | arith.MulfOp,
        rewriter: PatternRewriter,
        /,
    ):
        varith_type = ARITH_TO_VARITH_TYPE_MAP[type(op)]
        user = op.result.get_user_of_unique_use()

        # Only fold into a user of the same arith class or of the target varith
        # class; when there is no unique user (None) this check fails too.
        if type(user) not in (type(op), varith_type):
            return
        # pyright does not understand that `user` cannot be None here
        user = cast(Operation, user)

        # Keep every operand of the user except the value produced by `op`.
        remaining = [operand for operand in user.operands if operand != op.result]
        merged = varith_type(op.lhs, op.rhs, *remaining)
        rewriter.replace_op(user, merged)
        rewriter.erase_op(op)

match_and_rewrite(op: arith.AddiOp | arith.AddfOp | arith.MuliOp | arith.MulfOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/varith_transformations.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@op_type_rewrite_pattern
def match_and_rewrite(
    self,
    op: arith.AddiOp | arith.AddfOp | arith.MuliOp | arith.MulfOp,
    rewriter: PatternRewriter,
    /,
):
    """
    Fold `op` into its unique user when that user is an op of the same arith
    class or the varith op that this arith class maps to.
    """
    varith_type = ARITH_TO_VARITH_TYPE_MAP[type(op)]
    user = op.result.get_user_of_unique_use()

    # Only fold into a user of the same arith class or of the target varith
    # class; when there is no unique user (None) this check fails too.
    if type(user) not in (type(op), varith_type):
        return
    # pyright does not understand that `user` cannot be None here
    user = cast(Operation, user)

    # Keep every operand of the user except the value produced by `op`.
    remaining = [operand for operand in user.operands if operand != op.result]
    merged = varith_type(op.lhs, op.rhs, *remaining)
    rewriter.replace_op(user, merged)
    rewriter.erase_op(op)

VarithToArithPattern

Bases: RewritePattern

Splits a varith operation into a sequence of arith operations.

Source code in xdsl/transforms/varith_transformations.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
class VarithToArithPattern(RewritePattern):
    """
    Splits a varith operation into a sequence of arith operations.

    The operands are combined left to right, `((a0 op a1) op a2) op ...`, using
    the arith op matching the varith kind (add/mul) and the element family
    (int/float) of the result type. A single-operand varith op is replaced by
    its operand directly.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /):
        # Scalar family of the arith ops to emit (float|int).
        type_name: Literal["float", "int"] = "float"
        if is_integer_like_type(op.res.type):
            type_name = "int"
        # Operation kind of the arith ops to emit (add|mul).
        kind: Literal["add", "mul"] = "mul"
        if isinstance(op, varith.VarithAddOp):
            kind = "add"

        # The concrete arith class, e.g. addi or mulf.
        arith_type = ARITH_TYPES[(type_name, kind)]

        operands = list(op.operands)
        acc = operands[0]

        # Nothing to combine: forward the sole operand as the result.
        if len(operands) == 1:
            rewriter.replace_op(op, [], new_results=[acc])
            return

        chain: list[Operation] = []
        for operand in operands[1:]:
            step = arith_type(acc, operand)
            chain.append(step)
            acc = step.result

        rewriter.replace_op(op, chain)

match_and_rewrite(op: varith.VarithOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/varith_transformations.py
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
@op_type_rewrite_pattern
def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /):
    """
    Replace `op` with a left-to-right chain of arith ops over its operands, or
    with its sole operand when it has only one.
    """
    # Scalar family of the arith ops to emit (float|int).
    type_name: Literal["float", "int"] = "float"
    if is_integer_like_type(op.res.type):
        type_name = "int"
    # Operation kind of the arith ops to emit (add|mul).
    kind: Literal["add", "mul"] = "mul"
    if isinstance(op, varith.VarithAddOp):
        kind = "add"

    # The concrete arith class, e.g. addi or mulf.
    arith_type = ARITH_TYPES[(type_name, kind)]

    operands = list(op.operands)
    acc = operands[0]

    # Nothing to combine: forward the sole operand as the result.
    if len(operands) == 1:
        rewriter.replace_op(op, [], new_results=[acc])
        return

    chain: list[Operation] = []
    for operand in operands[1:]:
        step = arith_type(acc, operand)
        chain.append(step)
        acc = step.result

    rewriter.replace_op(op, chain)

MergeVarithOpsPattern

Bases: RewritePattern

Looks at every operand to a varith op, and merges them into it if possible.

For example, varith.add(arith.addi(1, 2), varith.add(3, 4, 5), 6) becomes varith.add(1, 2, 3, 4, 5, 6).

Source code in xdsl/transforms/varith_transformations.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
class MergeVarithOpsPattern(RewritePattern):
    """
    Looks at every operand to a varith op, and merges them into it if possible.

    e.g. a
        varith.add(arith.addi(1,2), varith.add(3, 4, 5), 6)
    becomes a
        varith.add(1,2,3,4,5,6)

    An operand is merged when it is produced either by the arith op matching
    this varith op's kind (add/mul) and element family (int/float), or by a
    varith op of exactly the same class. A merged producer is erased after the
    rewrite only if its result has no remaining uses.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /):
        # get the type kind (float|int)
        type_name: Literal["float", "int"] = (
            "int" if is_integer_like_type(op.res.type) else "float"
        )
        # get the operation kind (add|mul)
        kind: Literal["add", "mul"] = (
            "add" if isinstance(op, varith.VarithAddOp) else "mul"
        )

        # grab the corresponding arith type (e.g. addi/mulf)
        target_arith_type = ARITH_TYPES[(type_name, kind)]

        # operands of the replacement varith op
        new_operands: list[SSAValue] = []
        # Producers that could be erased if their results become unused.
        # A dict serves as an insertion-ordered set: the same producer may feed
        # several operands (e.g. `varith.add(%a, %a)`), and it must not be
        # erased more than once.
        possibly_erased_ops: dict[Operation, None] = {}

        for inp in op.operands:
            owner = inp.owner
            if isinstance(owner, target_arith_type):
                # fuse the operands of the matching arith op into the new op
                new_operands.append(owner.lhs)
                new_operands.append(owner.rhs)
                possibly_erased_ops[owner] = None
            elif isinstance(owner, type(op)):
                # inline all operands of a varith op of the same class
                new_operands.extend(owner.operands)
                possibly_erased_ops[owner] = None
            else:
                # otherwise keep the input unchanged
                new_operands.append(inp)

        # nothing to do if no producer was merged
        if not possibly_erased_ops:
            return

        # instantiate a new varith op of the same type as the old op:
        # we can ignore the type error as we know that all VarithOps are instantiated
        # with an *arg of their operands
        rewriter.replace_op(op, type(op)(*new_operands))

        # erase merged producers whose result is now unused; producers that
        # still have other users are kept
        for old_op in possibly_erased_ops:
            if not old_op.results[-1].uses:
                rewriter.erase_op(old_op)

match_and_rewrite(op: varith.VarithOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/varith_transformations.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
@op_type_rewrite_pattern
def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /):
    """
    Inline operands of `op` that come from a matching arith op or from a
    varith op of the same class, then erase producers left without uses.
    """
    # get the type kind (float|int)
    type_name: Literal["float", "int"] = (
        "int" if is_integer_like_type(op.res.type) else "float"
    )
    # get the operation kind (add|mul)
    kind: Literal["add", "mul"] = (
        "add" if isinstance(op, varith.VarithAddOp) else "mul"
    )

    # grab the corresponding arith type (e.g. addi/mulf)
    target_arith_type = ARITH_TYPES[(type_name, kind)]

    # operands of the replacement varith op
    new_operands: list[SSAValue] = []
    # Producers that could be erased if their results become unused.
    # A dict serves as an insertion-ordered set: the same producer may feed
    # several operands (e.g. `varith.add(%a, %a)`), and it must not be
    # erased more than once.
    possibly_erased_ops: dict[Operation, None] = {}

    for inp in op.operands:
        owner = inp.owner
        if isinstance(owner, target_arith_type):
            # fuse the operands of the matching arith op into the new op
            new_operands.append(owner.lhs)
            new_operands.append(owner.rhs)
            possibly_erased_ops[owner] = None
        elif isinstance(owner, type(op)):
            # inline all operands of a varith op of the same class
            new_operands.extend(owner.operands)
            possibly_erased_ops[owner] = None
        else:
            # otherwise keep the input unchanged
            new_operands.append(inp)

    # nothing to do if no producer was merged
    if not possibly_erased_ops:
        return

    # instantiate a new varith op of the same type as the old op:
    # we can ignore the type error as we know that all VarithOps are instantiated
    # with an *arg of their operands
    rewriter.replace_op(op, type(op)(*new_operands))

    # erase merged producers whose result is now unused; producers that
    # still have other users are kept
    for old_op in possibly_erased_ops:
        if not old_op.results[-1].uses:
            rewriter.erase_op(old_op)

FuseRepeatedAddArgsPattern dataclass

Bases: RewritePattern

Prefer operand * count(operand) over repeated addition of operand.

The minimum count to trigger this rewrite can be specified in min_reps.

Source code in xdsl/transforms/varith_transformations.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
@dataclass
class FuseRepeatedAddArgsPattern(RewritePattern):
    """
    Prefer `operand * count(operand)` over repeated addition of `operand`.

    The minimum count to trigger this rewrite can be specified in `min_reps`.
    Each operand of a `varith.add` that occurs at least `min_reps` times is
    replaced by a single `constant(count) * operand` product.
    """

    min_reps: int
    """Minimum repetitions of operand to trigger fusion."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: varith.VarithAddOp, rewriter: PatternRewriter, /):
        element_type = get_element_type_or_self(op.res.type)

        assert isa(
            element_type, builtin.IntegerType | builtin.IndexType | builtin.AnyFloat
        )

        constants: list[arith.ConstantOp] = []
        products: list[Operation] = []
        operands: list[Operation | SSAValue] = []
        # Counter keeps first-occurrence order, so the new operand list follows
        # the order in which distinct operands first appear in `op.args`.
        for value, occurrences in collections.Counter(op.args).items():
            if occurrences < self.min_reps:
                operands.append(value)
                continue
            const_op, mul_op = self.fuse(value, occurrences, element_type)
            constants.append(const_op)
            products.append(mul_op)
            operands.append(mul_op)

        if not products:
            return
        rewriter.insert_op([*constants, *products], InsertPoint.before(op))
        rewriter.replace_op(op, varith.VarithAddOp(*operands))

    @staticmethod
    def fuse(
        arg: SSAValue,
        count: int,
        t: builtin.IntegerType | builtin.IndexType | builtin.AnyFloat,
    ):
        """Return `(constant, product)` ops that compute `count * arg` in type `t`."""
        if not isinstance(t, builtin.IntegerType | builtin.IndexType):
            constant = arith.ConstantOp(builtin.FloatAttr(count, t))
            return constant, arith.MulfOp(constant, arg)
        constant = arith.ConstantOp(builtin.IntegerAttr(count, t))
        return constant, arith.MuliOp(constant, arg)

min_reps: int instance-attribute

Minimum repetitions of operand to trigger fusion.

__init__(min_reps: int) -> None

match_and_rewrite(op: varith.VarithAddOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/varith_transformations.py
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
@op_type_rewrite_pattern
def match_and_rewrite(self, op: varith.VarithAddOp, rewriter: PatternRewriter, /):
    """
    Replace every operand of `op` repeated at least `self.min_reps` times with a
    single multiplication of that operand by its repetition count.
    """
    element_type = get_element_type_or_self(op.res.type)

    assert isa(
        element_type, builtin.IntegerType | builtin.IndexType | builtin.AnyFloat
    )

    constants: list[arith.ConstantOp] = []
    products: list[Operation] = []
    operands: list[Operation | SSAValue] = []
    # Counter keeps first-occurrence order, so the new operand list follows
    # the order in which distinct operands first appear in `op.args`.
    for value, occurrences in collections.Counter(op.args).items():
        if occurrences < self.min_reps:
            operands.append(value)
            continue
        const_op, mul_op = self.fuse(value, occurrences, element_type)
        constants.append(const_op)
        products.append(mul_op)
        operands.append(mul_op)

    if not products:
        return
    rewriter.insert_op([*constants, *products], InsertPoint.before(op))
    rewriter.replace_op(op, varith.VarithAddOp(*operands))

fuse(arg: SSAValue, count: int, t: builtin.IntegerType | builtin.IndexType | builtin.AnyFloat) staticmethod

Source code in xdsl/transforms/varith_transformations.py
214
215
216
217
218
219
220
221
222
223
224
225
226
@staticmethod
def fuse(
    arg: SSAValue,
    count: int,
    t: builtin.IntegerType | builtin.IndexType | builtin.AnyFloat,
):
    """Return `(constant, product)` ops that compute `count * arg` in type `t`."""
    if not isinstance(t, builtin.IntegerType | builtin.IndexType):
        constant = arith.ConstantOp(builtin.FloatAttr(count, t))
        return constant, arith.MulfOp(constant, arg)
    constant = arith.ConstantOp(builtin.IntegerAttr(count, t))
    return constant, arith.MuliOp(constant, arg)

ConvertArithToVarithPass dataclass

Bases: ModulePass

Convert chains of arith.{add|mul}{i,f} operations into a single long variadic add or mul operation.

Used for simplifying arithmetic operations for rewrites that need to either change the order or completely "cut an equation in half".

Source code in xdsl/transforms/varith_transformations.py
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
class ConvertArithToVarithPass(ModulePass):
    """
    Convert chains of arith.{add|mul}{i,f} operations into a single long variadic add or mul operation.

    Used for simplifying arithmetic operations for rewrites that need to either change the order or
    completely "cut an equation in half".
    """

    name = "convert-arith-to-varith"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # Fold arith ops into varith ops, then flatten nested varith ops, with
        # the walker visiting operations in reverse order.
        patterns = GreedyRewritePatternApplier(
            [ArithToVarithPattern(), MergeVarithOpsPattern()]
        )
        walker = PatternRewriteWalker(patterns, walk_reverse=True)
        walker.rewrite_module(op)

name = 'convert-arith-to-varith' class-attribute instance-attribute

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/varith_transformations.py
239
240
241
242
243
244
245
246
247
248
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Rewrite `op` with the arith-to-varith and varith-merge patterns."""
    # The walker visits operations in reverse order.
    patterns = GreedyRewritePatternApplier(
        [ArithToVarithPattern(), MergeVarithOpsPattern()]
    )
    walker = PatternRewriteWalker(patterns, walk_reverse=True)
    walker.rewrite_module(op)

ConvertVarithToArithPass dataclass

Bases: ModulePass

Convert a single long variadic add or mul operation into a chain of arith.{add|mul}{i,f} operations. Reverses ConvertArithToVarithPass.

Source code in xdsl/transforms/varith_transformations.py
251
252
253
254
255
256
257
258
259
260
261
262
263
264
class ConvertVarithToArithPass(ModulePass):
    """
    Convert a single long variadic add or mul operation into a chain of arith.{add|mul}{i,f} operations.
    Reverses ConvertArithToVarithPass.
    """

    name = "convert-varith-to-arith"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # Run the splitting pattern without re-applying it to newly created ops.
        walker = PatternRewriteWalker(VarithToArithPattern(), apply_recursively=False)
        walker.rewrite_module(op)

name = 'convert-varith-to-arith' class-attribute instance-attribute

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/varith_transformations.py
260
261
262
263
264
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Split every varith op in `op` into a chain of arith ops."""
    # Run the splitting pattern without re-applying it to newly created ops.
    walker = PatternRewriteWalker(VarithToArithPattern(), apply_recursively=False)
    walker.rewrite_module(op)

VarithFuseRepeatedOperandsPass dataclass

Bases: ModulePass

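Fuses repeated occurrences of the same operand in a varith.add into a single multiplication of that operand by its occurrence count (see FuseRepeatedAddArgsPattern).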

Source code in xdsl/transforms/varith_transformations.py
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
class VarithFuseRepeatedOperandsPass(ModulePass):
    """
    Fuses several occurrences of the same operand into one.

    Repeated operands of `varith.add` ops are replaced by a multiplication of
    the operand with its repetition count.
    """

    name = "varith-fuse-repeated-operands"

    min_reps: int = 2
    """The minimum number of times an operand needs to be repeated before being fused."""

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # One non-recursive sweep with the fusion pattern.
        pattern = FuseRepeatedAddArgsPattern(self.min_reps)
        walker = PatternRewriteWalker(pattern, apply_recursively=False)
        walker.rewrite_module(op)

name = 'varith-fuse-repeated-operands' class-attribute instance-attribute

min_reps: int = 2 class-attribute instance-attribute

The minimum number of times an operand needs to be repeated before being fused.

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/varith_transformations.py
277
278
279
280
281
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Fuse repeated `varith.add` operands in `op` using `self.min_reps`."""
    # One non-recursive sweep with the fusion pattern.
    pattern = FuseRepeatedAddArgsPattern(self.min_reps)
    walker = PatternRewriteWalker(pattern, apply_recursively=False)
    walker.rewrite_module(op)

is_integer_like_type(t: Attribute) -> bool

Returns true if t is an integer/container of integers/container of container of integers ...

Source code in xdsl/transforms/varith_transformations.py
173
174
175
176
177
178
179
def is_integer_like_type(t: Attribute) -> bool:
    """
    Return True when `t`, or the element type obtained from it via
    `get_element_type_or_self`, is an integer or index type.
    """
    element = get_element_type_or_self(t)
    return isinstance(element, builtin.IntegerType | builtin.IndexType)