Skip to content

Arith

arith

SignlessIntegerBinaryOperationZeroOrUnitRight

Bases: RewritePattern

Source code in xdsl/transforms/canonicalization_patterns/arith.py
17
18
19
20
21
22
23
24
25
26
27
class SignlessIntegerBinaryOperationZeroOrUnitRight(RewritePattern):
    """
    Simplify a signless integer binary op whose right operand is a constant
    that is a right zero or a right unit of the operation.

    A right-zero constant replaces the op with the right operand (the constant
    itself); a right-unit constant replaces the op with the left operand.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.SignlessIntegerBinaryOperation, rewriter: PatternRewriter, /
    ):
        rhs_attr = const_evaluate_operand_attribute(op.rhs)
        if rhs_attr is None:
            # Right operand is not a known constant; nothing to simplify.
            return

        if op.is_right_zero(rhs_attr):
            # x <op> zero -> zero: forward the constant operand.
            replacement = op.rhs
        elif op.is_right_unit(rhs_attr):
            # x <op> unit -> x: forward the left operand.
            replacement = op.lhs
        else:
            return

        rewriter.replace_op(op, (), (replacement,))

match_and_rewrite(op: arith.SignlessIntegerBinaryOperation, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/arith.py
18
19
20
21
22
23
24
25
26
27
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.SignlessIntegerBinaryOperation, rewriter: PatternRewriter, /
):
    """
    Replace `op` by its right operand when that constant is a right zero, or by
    its left operand when that constant is a right unit.
    """
    rhs_attr = const_evaluate_operand_attribute(op.rhs)
    if rhs_attr is None:
        # Right operand is not a known constant; nothing to simplify.
        return

    if op.is_right_zero(rhs_attr):
        # x <op> zero -> zero: forward the constant operand.
        replacement = op.rhs
    elif op.is_right_unit(rhs_attr):
        # x <op> unit -> x: forward the left operand.
        replacement = op.lhs
    else:
        return

    rewriter.replace_op(op, (), (replacement,))

SignlessIntegerBinaryOperationConstantProp

Bases: RewritePattern

Source code in xdsl/transforms/canonicalization_patterns/arith.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class SignlessIntegerBinaryOperationConstantProp(RewritePattern):
    """
    Constant-fold a signless integer binary op whose operands are both constants.

    If only the left operand is constant and the op is commutative, the
    operands are swapped so that the constant ends up on the right.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.SignlessIntegerBinaryOperation, rewriter: PatternRewriter, /
    ):
        lhs_value = const_evaluate_operand(op.lhs)
        if lhs_value is None:
            return

        rhs_value = const_evaluate_operand(op.rhs)
        if rhs_value is None:
            # Canonicalize `const <op> var` into `var <op> const` when legal.
            if op.has_trait(Commutative):
                swapped = op.__class__(op.rhs, op.lhs)
                rewriter.replace_op(op, swapped)
            return

        folded = op.py_operation(lhs_value, rhs_value)
        if folded is None:
            # py_operation produced no result for these values; leave op as is.
            return

        result_type = op.result.type
        assert isinstance(result_type, IntegerType | IndexType)
        # truncate_bits=True: the Python-int result may not fit the result
        # type's width, so the constructor is asked to truncate it.
        new_const = arith.ConstantOp.from_int_and_width(
            folded, result_type, truncate_bits=True
        )
        rewriter.replace_op(op, new_const)

match_and_rewrite(op: arith.SignlessIntegerBinaryOperation, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/arith.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.SignlessIntegerBinaryOperation, rewriter: PatternRewriter, /
):
    """
    Fold `op` to an `arith.constant` when both operands are constants; if only
    the left operand is constant and `op` is commutative, swap the operands.
    """
    lhs_value = const_evaluate_operand(op.lhs)
    if lhs_value is None:
        return

    rhs_value = const_evaluate_operand(op.rhs)
    if rhs_value is None:
        # Canonicalize `const <op> var` into `var <op> const` when legal.
        if op.has_trait(Commutative):
            swapped = op.__class__(op.rhs, op.lhs)
            rewriter.replace_op(op, swapped)
        return

    folded = op.py_operation(lhs_value, rhs_value)
    if folded is None:
        # py_operation produced no result for these values; leave op as is.
        return

    result_type = op.result.type
    assert isinstance(result_type, IntegerType | IndexType)
    # truncate_bits=True: the Python-int result may not fit the result
    # type's width, so the constructor is asked to truncate it.
    new_const = arith.ConstantOp.from_int_and_width(
        folded, result_type, truncate_bits=True
    )
    rewriter.replace_op(op, new_const)

FoldConstConstOp

Bases: RewritePattern

Folds a floating point binary op whose operands are both arith.constants.

Source code in xdsl/transforms/canonicalization_patterns/arith.py
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
class FoldConstConstOp(RewritePattern):
    """
    Fold a floating-point binary op whose two operands are both produced by
    `arith.constant` ops holding `FloatAttr` values.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.FloatingPointLikeBinaryOperation, rewriter: PatternRewriter, /
    ):
        lhs_def = op.lhs.owner
        if not isinstance(lhs_def, arith.ConstantOp):
            return
        rhs_def = op.rhs.owner
        if not isinstance(rhs_def, arith.ConstantOp):
            return

        lhs_attr = lhs_def.value
        if not isa(lhs_attr, builtin.FloatAttr):
            return
        rhs_attr = rhs_def.value
        if not isa(rhs_attr, builtin.FloatAttr):
            return

        folded = _fold_const_operation(type(op), lhs_attr, rhs_attr)
        if folded:
            rewriter.replace_op(op, folded)

match_and_rewrite(op: arith.FloatingPointLikeBinaryOperation, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/arith.py
88
89
90
91
92
93
94
95
96
97
98
99
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.FloatingPointLikeBinaryOperation, rewriter: PatternRewriter, /
):
    """
    Replace `op` with a folded constant when both operands come from
    `arith.constant` ops holding `FloatAttr` values.
    """
    lhs_def = op.lhs.owner
    if not isinstance(lhs_def, arith.ConstantOp):
        return
    rhs_def = op.rhs.owner
    if not isinstance(rhs_def, arith.ConstantOp):
        return

    lhs_attr = lhs_def.value
    if not isa(lhs_attr, builtin.FloatAttr):
        return
    rhs_attr = rhs_def.value
    if not isa(rhs_attr, builtin.FloatAttr):
        return

    folded = _fold_const_operation(type(op), lhs_attr, rhs_attr)
    if folded:
        rewriter.replace_op(op, folded)

FoldConstsByReassociation

Bases: RewritePattern

Rewrites a chain of `(const1 <op> var) <op> const2` as `folded_const <op> var`, where `folded_const` is `const1 <op> const2`.

The op must be associative and have the fastmath<reassoc> flag set.

Source code in xdsl/transforms/canonicalization_patterns/arith.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
class FoldConstsByReassociation(RewritePattern):
    """
    Rewrites a chain of
        `(const1 <op> var) <op> const2`
    as
        `folded_const <op> var`
    where `folded_const` is `const1 <op> const2`. Each constant may sit on
    either side of its op.

    Both ops must be of the same kind and both must carry the
    `fastmath<reassoc>` flag.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.AddfOp | arith.MulfOp, rewriter: PatternRewriter, /
    ):
        # Identify the inner constant; the remaining operand is `var`.
        lhs_def = op.lhs.owner
        if isinstance(lhs_def, arith.ConstantOp):
            inner_const, var = lhs_def, op.rhs
        else:
            inner_const, var = op.rhs.owner, op.lhs
        if not isinstance(inner_const, arith.ConstantOp):
            return

        # The inner result must have a single use, by an op of the same kind.
        outer = op.result.get_user_of_unique_use()
        if not isinstance(outer, type(op)):
            return

        # The outer op's other operand must be a constant too.
        other_operand = outer.lhs if outer.rhs == op.result else outer.rhs
        outer_const = other_operand.owner
        if not isinstance(outer_const, arith.ConstantOp):
            return

        # Reassociation is only allowed when both ops permit it.
        reassoc = arith.FastMathFlag.REASSOC
        if reassoc not in op.fastmath.flags:
            return
        if reassoc not in outer.fastmath.flags:
            return

        inner_attr = inner_const.value
        if not isa(inner_attr, builtin.FloatAttr):
            return
        outer_attr = outer_const.value
        if not isa(outer_attr, builtin.FloatAttr):
            return

        folded = _fold_const_operation(type(op), inner_attr, outer_attr)
        if not folded:
            return

        merged_flags = arith.FastMathFlagsAttr(
            list(op.fastmath.flags | outer.fastmath.flags)
        )
        combined = type(op)(folded, var, merged_flags)
        # Put the folded constant and the combined op in place of the inner
        # op, then forward the outer op's uses to the combined result.
        rewriter.replace_op(op, [folded, combined])
        rewriter.replace_op(outer, [], [combined.result])

match_and_rewrite(op: arith.AddfOp | arith.MulfOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/arith.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.AddfOp | arith.MulfOp, rewriter: PatternRewriter, /
):
    """
    Rewrite `(const1 <op> var) <op> const2` as `folded_const <op> var` when
    both ops are of the same kind and both carry `fastmath<reassoc>`.
    """
    # Identify the inner constant; the remaining operand is `var`.
    lhs_def = op.lhs.owner
    if isinstance(lhs_def, arith.ConstantOp):
        inner_const, var = lhs_def, op.rhs
    else:
        inner_const, var = op.rhs.owner, op.lhs
    if not isinstance(inner_const, arith.ConstantOp):
        return

    # The inner result must have a single use, by an op of the same kind.
    outer = op.result.get_user_of_unique_use()
    if not isinstance(outer, type(op)):
        return

    # The outer op's other operand must be a constant too.
    other_operand = outer.lhs if outer.rhs == op.result else outer.rhs
    outer_const = other_operand.owner
    if not isinstance(outer_const, arith.ConstantOp):
        return

    # Reassociation is only allowed when both ops permit it.
    reassoc = arith.FastMathFlag.REASSOC
    if reassoc not in op.fastmath.flags:
        return
    if reassoc not in outer.fastmath.flags:
        return

    inner_attr = inner_const.value
    if not isa(inner_attr, builtin.FloatAttr):
        return
    outer_attr = outer_const.value
    if not isa(outer_attr, builtin.FloatAttr):
        return

    folded = _fold_const_operation(type(op), inner_attr, outer_attr)
    if not folded:
        return

    merged_flags = arith.FastMathFlagsAttr(
        list(op.fastmath.flags | outer.fastmath.flags)
    )
    combined = type(op)(folded, var, merged_flags)
    # Put the folded constant and the combined op in place of the inner
    # op, then forward the outer op's uses to the combined result.
    rewriter.replace_op(op, [folded, combined])
    rewriter.replace_op(outer, [], [combined.result])

SelectConstPattern

Bases: RewritePattern

`arith.select %true %x %y` = `%x`; `arith.select %false %x %y` = `%y`.

Source code in xdsl/transforms/canonicalization_patterns/arith.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
class SelectConstPattern(RewritePattern):
    """
    arith.select %true %x %y = %x
    arith.select %false %x %y = %y
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.SelectOp, rewriter: PatternRewriter):
        cond_value = const_evaluate_operand(op.cond)
        if cond_value is None:
            # Condition is not a known constant; nothing to fold.
            return
        # Forward whichever operand the constant condition selects.
        chosen = op.lhs if cond_value else op.rhs
        rewriter.replace_op(op, (), (chosen,))

match_and_rewrite(op: arith.SelectOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/arith.py
148
149
150
151
152
153
154
155
156
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.SelectOp, rewriter: PatternRewriter):
    """Replace a select on a constant condition with the chosen operand."""
    cond_value = const_evaluate_operand(op.cond)
    if cond_value is None:
        # Condition is not a known constant; nothing to fold.
        return
    # Forward whichever operand the constant condition selects.
    chosen = op.lhs if cond_value else op.rhs
    rewriter.replace_op(op, (), (chosen,))

SelectTrueFalsePattern

Bases: RewritePattern

`arith.select %x %true %false` = `%x`; `arith.select %x %false %true` = `%x xor 1`.

Source code in xdsl/transforms/canonicalization_patterns/arith.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
class SelectTrueFalsePattern(RewritePattern):
    """
    arith.select %x %true %false = %x
    arith.select %x %false %true = %x xor 1
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.SelectOp, rewriter: PatternRewriter):
        # Only i1 selects are handled.
        if op.result.type != IntegerType(1):
            return

        true_value = const_evaluate_operand(op.lhs)
        if true_value is None:
            return
        false_value = const_evaluate_operand(op.rhs)
        if false_value is None:
            return

        if true_value and not false_value:
            # select %x, true, false -> %x
            rewriter.replace_op(op, (), (op.cond,))
        elif not true_value and false_value:
            # select %x, false, true -> %x xor op.rhs (op.rhs is the true constant)
            rewriter.replace_op(op, arith.XOrIOp(op.cond, op.rhs))

match_and_rewrite(op: arith.SelectOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/arith.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.SelectOp, rewriter: PatternRewriter):
    """
    Fold an i1 select between constant true/false operands into its condition
    or the condition's negation (xor with the true constant).
    """
    # Only i1 selects are handled.
    if op.result.type != IntegerType(1):
        return

    true_value = const_evaluate_operand(op.lhs)
    if true_value is None:
        return
    false_value = const_evaluate_operand(op.rhs)
    if false_value is None:
        return

    if true_value and not false_value:
        # select %x, true, false -> %x
        rewriter.replace_op(op, (), (op.cond,))
    elif not true_value and false_value:
        # select %x, false, true -> %x xor op.rhs (op.rhs is the true constant)
        rewriter.replace_op(op, arith.XOrIOp(op.cond, op.rhs))

SelectSamePattern

Bases: RewritePattern

arith.select %x %y %y = %y

Source code in xdsl/transforms/canonicalization_patterns/arith.py
182
183
184
185
186
187
188
189
190
class SelectSamePattern(RewritePattern):
    """
    arith.select %x %y %y = %y
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.SelectOp, rewriter: PatternRewriter):
        same_operand = op.lhs == op.rhs
        if not same_operand:
            # The two branches differ, so the condition matters.
            return
        rewriter.replace_op(op, (), (op.lhs,))

match_and_rewrite(op: arith.SelectOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/arith.py
187
188
189
190
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.SelectOp, rewriter: PatternRewriter):
    """Replace a select whose two value operands are identical with that value."""
    same_operand = op.lhs == op.rhs
    if not same_operand:
        # The two branches differ, so the condition matters.
        return
    rewriter.replace_op(op, (), (op.lhs,))

SelectFoldCmpfPattern

Bases: RewritePattern

`%1 = arith.cmpf ogt, %0, %cst fastmath<nnan> : f64` followed by `%2 = arith.select %1, %0, %cst : f64`

Source code in xdsl/transforms/canonicalization_patterns/arith.py
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
class SelectFoldCmpfPattern(RewritePattern):
    """
    %1 = arith.cmpf  ogt, %0, %cst fastmath<nnan> : f64
    %2 = arith.select %1, %0, %cst : f64
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.SelectOp, rewriter: PatternRewriter):
        if not isinstance(op.cond, OpResult) or not isinstance(
            cmpf := op.cond.op, arith.CmpfOp
        ):
            return
        if (
            arith.FastMathFlag.NO_NANS not in cmpf.fastmath.flags
            or arith.FastMathFlag.NO_SIGNED_ZEROS not in cmpf.fastmath.flags
        ):
            return
        if not (op.lhs == cmpf.lhs and op.rhs == cmpf.rhs):
            return

        target = None
        match cmpf.predicate.value.data:
            case 2 | 3 | 9 | 10:
                # ogt | oge | ugt | uge
                target = arith.MaximumfOp
            case 4 | 5 | 11 | 12:
                # olt | ole | ult | ule
                target = arith.MinimumfOp
            case _:
                return
        rewriter.replace_op(op, target(op.lhs, op.rhs, cmpf.fastmath))

match_and_rewrite(op: arith.SelectOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/arith.py
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.SelectOp, rewriter: PatternRewriter):
    if not isinstance(op.cond, OpResult) or not isinstance(
        cmpf := op.cond.op, arith.CmpfOp
    ):
        return
    if (
        arith.FastMathFlag.NO_NANS not in cmpf.fastmath.flags
        or arith.FastMathFlag.NO_SIGNED_ZEROS not in cmpf.fastmath.flags
    ):
        return
    if not (op.lhs == cmpf.lhs and op.rhs == cmpf.rhs):
        return

    target = None
    match cmpf.predicate.value.data:
        case 2 | 3 | 9 | 10:
            # ogt | oge | ugt | uge
            target = arith.MaximumfOp
        case 4 | 5 | 11 | 12:
            # olt | ole | ult | ule
            target = arith.MinimumfOp
        case _:
            return
    rewriter.replace_op(op, target(op.lhs, op.rhs, cmpf.fastmath))

ApplyCmpiPredicateToEqualOperands

Bases: RewritePattern

Source code in xdsl/transforms/canonicalization_patterns/arith.py
226
227
228
229
230
231
232
class ApplyCmpiPredicateToEqualOperands(RewritePattern):
    """
    Fold an `arith.cmpi` whose two operands are the same SSA value into a
    boolean `arith.constant`.

    The result is true for predicate codes 0, 3, 5, 7 and 9 and false
    otherwise. In MLIR's `arith.cmpi` predicate enumeration these are the
    reflexive predicates eq, sle, sge, ule and uge.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.CmpiOp, rewriter: PatternRewriter):
        if op.lhs != op.rhs:
            return
        # Predicates that hold when a value is compared with itself.
        holds = op.predicate.value.data in {0, 3, 5, 7, 9}
        rewriter.replace_op(op, arith.ConstantOp(BoolAttr.from_bool(holds)))

match_and_rewrite(op: arith.CmpiOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/arith.py
227
228
229
230
231
232
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.CmpiOp, rewriter: PatternRewriter):
    """
    Replace `cmpi(pred, x, x)` with a boolean constant: true for predicate codes
    0, 3, 5, 7, 9 (eq, sle, sge, ule, uge in MLIR's numbering), false otherwise.
    """
    if op.lhs != op.rhs:
        return
    # Predicates that hold when a value is compared with itself.
    holds = op.predicate.value.data in {0, 3, 5, 7, 9}
    rewriter.replace_op(op, arith.ConstantOp(BoolAttr.from_bool(holds)))