Skip to content

Convert arith to riscv

convert_arith_to_riscv

RdRsRsIntegerOperation = riscv.RdRsRsOperation[riscv.IntRegisterType, riscv.IntRegisterType, riscv.IntRegisterType] module-attribute

RdRsRsFloatOperation = riscv.RdRsRsOperation[riscv.FloatRegisterType, riscv.FloatRegisterType, riscv.FloatRegisterType] module-attribute

lower_arith_addi = LowerBinaryIntegerOp(arith.AddiOp, riscv.AddOp) module-attribute

lower_arith_subi = LowerBinaryIntegerOp(arith.SubiOp, riscv.SubOp) module-attribute

lower_arith_muli = LowerBinaryIntegerOp(arith.MuliOp, riscv.MulOp) module-attribute

lower_arith_divui = LowerBinaryIntegerOp(arith.DivUIOp, riscv.DivuOp) module-attribute

lower_arith_divsi = LowerBinaryIntegerOp(arith.DivSIOp, riscv.DivOp) module-attribute

lower_arith_remui = LowerBinaryIntegerOp(arith.RemUIOp, riscv.RemuOp) module-attribute

lower_arith_remsi = LowerBinaryIntegerOp(arith.RemSIOp, riscv.RemOp) module-attribute

lower_arith_andi = LowerBinaryIntegerOp(arith.AndIOp, riscv.AndOp) module-attribute

lower_arith_ori = LowerBinaryIntegerOp(arith.OrIOp, riscv.OrOp) module-attribute

lower_arith_xori = LowerBinaryIntegerOp(arith.XOrIOp, riscv.XorOp) module-attribute

lower_arith_shli = LowerBinaryIntegerOp(arith.ShLIOp, riscv.SllOp) module-attribute

lower_arith_shrui = LowerBinaryIntegerOp(arith.ShRUIOp, riscv.SrlOp) module-attribute

lower_arith_shrsi = LowerBinaryIntegerOp(arith.ShRSIOp, riscv.SraOp) module-attribute

lower_arith_addf = LowerBinaryFloatOp(arith.AddfOp, riscv.FAddSOp, riscv.FAddDOp) module-attribute

lower_arith_subf = LowerBinaryFloatOp(arith.SubfOp, riscv.FSubSOp, riscv.FSubDOp) module-attribute

lower_arith_mulf = LowerBinaryFloatOp(arith.MulfOp, riscv.FMulSOp, riscv.FMulDOp) module-attribute

lower_arith_divf = LowerBinaryFloatOp(arith.DivfOp, riscv.FDivSOp, riscv.FDivDOp) module-attribute

lower_arith_minf = LowerBinaryFloatOp(arith.MinimumfOp, riscv.FMinSOp, riscv.FMinDOp) module-attribute

lower_arith_maxf = LowerBinaryFloatOp(arith.MaximumfOp, riscv.FMaxSOp, riscv.FMaxDOp) module-attribute

LowerArithConstant

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class LowerArithConstant(RewritePattern):
    """
    Lower `arith.constant` to RISC-V instructions that materialize the value.

    - Integer constants (at most 32 bits wide) and index constants are loaded with
      `li`.
    - f32 constants have their bit pattern loaded with `li` and moved into a float
      register with `fmv.w.x`.
    - f64 constants that are integral and fit in a signed 32-bit integer are loaded
      with `li` and converted with `fcvt.d.w`. All other f64 values are written to
      the stack as two 32-bit words and reloaded with `fld`.

    Every result is cast back to the original type with an
    `unrealized_conversion_cast`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.ConstantOp, rewriter: PatternRewriter
    ) -> None:
        import math

        op_result_type = op.result.type
        if isa(op_result_type, IntegerType) and isinstance(
            op_val := op.value, IntegerAttr
        ):
            if op_result_type.width.data <= 32:
                rewriter.replace_op(
                    op,
                    [
                        constant := riscv.LiOp(op_val.value.data),
                        UnrealizedConversionCastOp.get(
                            constant.results, (op_result_type,)
                        ),
                    ],
                )
            else:
                raise NotImplementedError("Only 32 bit integers are supported for now")
        elif isinstance(op_val := op.value, FloatAttr):
            if isinstance(op_result_type, Float32Type):
                # Load the raw f32 bit pattern into an integer register, then move
                # the bits unchanged into a float register.
                rewriter.replace_op(
                    op,
                    [
                        lui := riscv.LiOp(
                            convert_f32_to_u32(op_val.value.data),
                            rd=_INT_REGISTER_TYPE,
                        ),
                        fmv := riscv.FMvWXOp(lui.rd),
                        UnrealizedConversionCastOp.get(fmv.results, (op_result_type,)),
                    ],
                )
            elif isinstance(op_result_type, Float64Type):
                # There is no way to load an immediate value to a float register directly.
                val_data = op_val.value.data
                s32_min = signed_lower_bound(32)
                # Used as an exclusive bound by the `<` comparison below.
                s32_max = signed_upper_bound(32)
                # Integer 0 has no sign, so fcvt.d.w can only produce +0.0.
                # -0.0 must take the bit-exact stack path, or its sign bit is lost.
                is_negative_zero = val_data == 0 and math.copysign(1.0, val_data) < 0
                # If the value is an integer that fits in s32, convert it from an
                # integer register. The `and` short-circuits, so int() only runs on
                # integral (hence finite) values.
                if (
                    val_data.is_integer()
                    and not is_negative_zero
                    and s32_min <= (int_val := int(val_data)) < s32_max
                ):
                    rewriter.replace_op(
                        op,
                        [
                            lui := riscv.LiOp(
                                int_val,
                                rd=_INT_REGISTER_TYPE,
                            ),
                            fcvtdw := riscv.FCvtDWOp(lui.rd, rd=_FLOAT_REGISTER_TYPE),
                            UnrealizedConversionCastOp.get(
                                fcvtdw.results, (op_result_type,)
                            ),
                        ],
                    )
                else:
                    # We have to load the bits into an integer register, store them on the
                    # stack, and load again.

                    # TODO: check the xlen in this lowering.

                    # This lowering assumes that xlen is 32 and flen is 64.
                    # "<ii" splits the little-endian double into (low word, high word).
                    # The low word is stored at the lower address (sp - 8), which
                    # assumes a little-endian target.
                    # NOTE(review): the words are stored below sp without adjusting
                    # sp. Confirm nothing (e.g. a trap handler) can clobber that area
                    # between the stores and the fld.
                    lower, upper = struct.unpack(
                        "<ii", struct.pack("<d", val_data)
                    )
                    rewriter.replace_op(
                        op,
                        [
                            sp := riscv.GetRegisterOp(riscv.Registers.SP),
                            li_upper := riscv.LiOp(upper),
                            riscv.SwOp(sp, li_upper, -4),
                            li_lower := riscv.LiOp(lower),
                            riscv.SwOp(sp, li_lower, -8),
                            fld := riscv.FLdOp(sp, -8, rd=_FLOAT_REGISTER_TYPE),
                            UnrealizedConversionCastOp.get(
                                fld.results, (op_result_type,)
                            ),
                        ],
                    )
            else:
                raise NotImplementedError("Only 32 or 64 bit floats are supported")
        elif isinstance(op_result_type, IndexType) and isinstance(
            op_val := op.value, IntegerAttr
        ):
            rewriter.replace_op(
                op,
                [
                    constant := riscv.LiOp(op_val.value.data),
                    UnrealizedConversionCastOp.get(constant.results, (op_result_type,)),
                ],
            )
        else:
            raise NotImplementedError(
                f"Unsupported constant type {op_val} of type {type(op_val)}"
            )

match_and_rewrite(op: arith.ConstantOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.ConstantOp, rewriter: PatternRewriter
) -> None:
    """
    Lower `arith.constant` to RISC-V instructions that materialize the value.

    Integer (at most 32 bits) and index constants use `li`. f32 constants use
    `li` + `fmv.w.x`. f64 constants use `li` + `fcvt.d.w` when integral and
    within signed 32-bit range; otherwise they go through two stack stores and
    an `fld`.
    """
    import math

    op_result_type = op.result.type
    if isa(op_result_type, IntegerType) and isinstance(
        op_val := op.value, IntegerAttr
    ):
        if op_result_type.width.data <= 32:
            rewriter.replace_op(
                op,
                [
                    constant := riscv.LiOp(op_val.value.data),
                    UnrealizedConversionCastOp.get(
                        constant.results, (op_result_type,)
                    ),
                ],
            )
        else:
            raise NotImplementedError("Only 32 bit integers are supported for now")
    elif isinstance(op_val := op.value, FloatAttr):
        if isinstance(op_result_type, Float32Type):
            # Load the raw f32 bit pattern into an integer register, then move
            # the bits unchanged into a float register.
            rewriter.replace_op(
                op,
                [
                    lui := riscv.LiOp(
                        convert_f32_to_u32(op_val.value.data),
                        rd=_INT_REGISTER_TYPE,
                    ),
                    fmv := riscv.FMvWXOp(lui.rd),
                    UnrealizedConversionCastOp.get(fmv.results, (op_result_type,)),
                ],
            )
        elif isinstance(op_result_type, Float64Type):
            # There is no way to load an immediate value to a float register directly.
            val_data = op_val.value.data
            s32_min = signed_lower_bound(32)
            # Used as an exclusive bound by the `<` comparison below.
            s32_max = signed_upper_bound(32)
            # Integer 0 has no sign, so fcvt.d.w can only produce +0.0.
            # -0.0 must take the bit-exact stack path, or its sign bit is lost.
            is_negative_zero = val_data == 0 and math.copysign(1.0, val_data) < 0
            # If the value is an integer that fits in s32, convert it from an
            # integer register. The `and` short-circuits, so int() only runs on
            # integral (hence finite) values.
            if (
                val_data.is_integer()
                and not is_negative_zero
                and s32_min <= (int_val := int(val_data)) < s32_max
            ):
                rewriter.replace_op(
                    op,
                    [
                        lui := riscv.LiOp(
                            int_val,
                            rd=_INT_REGISTER_TYPE,
                        ),
                        fcvtdw := riscv.FCvtDWOp(lui.rd, rd=_FLOAT_REGISTER_TYPE),
                        UnrealizedConversionCastOp.get(
                            fcvtdw.results, (op_result_type,)
                        ),
                    ],
                )
            else:
                # We have to load the bits into an integer register, store them on the
                # stack, and load again.

                # TODO: check the xlen in this lowering.

                # This lowering assumes that xlen is 32 and flen is 64.
                # "<ii" splits the little-endian double into (low word, high word).
                # The low word is stored at the lower address (sp - 8), which
                # assumes a little-endian target.
                # NOTE(review): the words are stored below sp without adjusting sp.
                # Confirm nothing (e.g. a trap handler) can clobber that area
                # between the stores and the fld.
                lower, upper = struct.unpack(
                    "<ii", struct.pack("<d", val_data)
                )
                rewriter.replace_op(
                    op,
                    [
                        sp := riscv.GetRegisterOp(riscv.Registers.SP),
                        li_upper := riscv.LiOp(upper),
                        riscv.SwOp(sp, li_upper, -4),
                        li_lower := riscv.LiOp(lower),
                        riscv.SwOp(sp, li_lower, -8),
                        fld := riscv.FLdOp(sp, -8, rd=_FLOAT_REGISTER_TYPE),
                        UnrealizedConversionCastOp.get(
                            fld.results, (op_result_type,)
                        ),
                    ],
                )
        else:
            raise NotImplementedError("Only 32 or 64 bit floats are supported")
    elif isinstance(op_result_type, IndexType) and isinstance(
        op_val := op.value, IntegerAttr
    ):
        rewriter.replace_op(
            op,
            [
                constant := riscv.LiOp(op_val.value.data),
                UnrealizedConversionCastOp.get(constant.results, (op_result_type,)),
            ],
        )
    else:
        raise NotImplementedError(
            f"Unsupported constant type {op_val} of type {type(op_val)}"
        )

LowerArithIndexCast

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
136
137
138
139
140
141
142
143
144
145
146
147
class LowerArithIndexCast(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.IndexCastOp, rewriter: PatternRewriter
    ) -> None:
        """
        Replace `arith.index_cast` with a plain unrealized conversion cast.

        On an RV32 target the index type is 32 bits wide, so no instruction is
        emitted; only the SSA type of the value changes.
        """
        # NOTE(review): the width of the integer side of the cast is not checked
        # here; casts to or from non-32-bit integers are dropped the same way.
        source_values = (op.input,)
        target_types = (op.result.type,)
        replacement = UnrealizedConversionCastOp.get(source_values, target_types)
        rewriter.replace_op(op, replacement)

match_and_rewrite(op: arith.IndexCastOp, rewriter: PatternRewriter) -> None

On an RV32 target, the index type is 32 bits wide, so the cast can simply be dropped.

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
137
138
139
140
141
142
143
144
145
146
147
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.IndexCastOp, rewriter: PatternRewriter
) -> None:
    """
    Replace `arith.index_cast` with a plain unrealized conversion cast.

    On an RV32 target the index type is 32 bits wide, so no instruction is
    emitted; only the SSA type of the value changes.
    """
    # NOTE(review): the width of the integer side of the cast is not checked here.
    source_values = (op.input,)
    target_types = (op.result.type,)
    replacement = UnrealizedConversionCastOp.get(source_values, target_types)
    rewriter.replace_op(op, replacement)

LowerBinaryIntegerOp dataclass

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
@dataclass
class LowerBinaryIntegerOp(RewritePattern):
    """
    Lower a two-operand arith integer op to a single RISC-V `rd, rs1, rs2` op.

    Both operands are cast into integer registers, the RISC-V op computes the
    result, and that result is cast back to the original arith result type.
    """

    # The arith operation class this pattern matches.
    arith_op_cls: type[arith.SignlessIntegerBinaryOperation]
    # The RISC-V operation class emitted in its place.
    riscv_op_cls: type[RdRsRsIntegerOperation]

    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
        if isinstance(op, self.arith_op_cls):
            # The generator is consumed in order, so the lhs cast is built first.
            lhs_reg, rhs_reg = (
                UnrealizedConversionCastOp.get((operand,), (_INT_REGISTER_TYPE,))
                for operand in (op.lhs, op.rhs)
            )
            riscv_op = self.riscv_op_cls(lhs_reg, rhs_reg, rd=_INT_REGISTER_TYPE)
            back_cast = UnrealizedConversionCastOp.get(
                (riscv_op.rd,), (op.result.type,)
            )
            rewriter.replace_op(op, (lhs_reg, rhs_reg, riscv_op, back_cast))

arith_op_cls: type[arith.SignlessIntegerBinaryOperation] instance-attribute

riscv_op_cls: type[RdRsRsIntegerOperation] instance-attribute

__init__(arith_op_cls: type[arith.SignlessIntegerBinaryOperation], riscv_op_cls: type[RdRsRsIntegerOperation]) -> None

match_and_rewrite(op: Operation, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
164
165
166
167
168
169
170
171
172
173
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
    """
    Rewrite `op` into `self.riscv_op_cls` on integer registers when it is an
    instance of `self.arith_op_cls`; otherwise leave it untouched.
    """
    if isinstance(op, self.arith_op_cls):
        # The generator is consumed in order, so the lhs cast is built first.
        lhs_reg, rhs_reg = (
            UnrealizedConversionCastOp.get((operand,), (_INT_REGISTER_TYPE,))
            for operand in (op.lhs, op.rhs)
        )
        riscv_op = self.riscv_op_cls(lhs_reg, rhs_reg, rd=_INT_REGISTER_TYPE)
        back_cast = UnrealizedConversionCastOp.get((riscv_op.rd,), (op.result.type,))
        rewriter.replace_op(op, (lhs_reg, rhs_reg, riscv_op, back_cast))

LowerBinaryFloatOp dataclass

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
@dataclass
class LowerBinaryFloatOp(RewritePattern):
    arith_op_cls: type[arith.FloatingPointLikeBinaryOperation]
    riscv_f_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
    riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]

    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
        if not isinstance(op, self.arith_op_cls):
            return

        lhs = UnrealizedConversionCastOp.get((op.lhs,), (_FLOAT_REGISTER_TYPE,))
        rhs = UnrealizedConversionCastOp.get((op.rhs,), (_FLOAT_REGISTER_TYPE,))
        match op.lhs.type:
            case Float32Type():
                cls = self.riscv_f_op_cls
            case Float64Type():
                cls = self.riscv_d_op_cls
            case _:
                raise ValueError(f"Unexpected float type {op.lhs.type}")

        rv_flags = riscv.FastMathFlagsAttr(op.fastmath.data)

        new_op = cls(lhs, rhs, rd=_FLOAT_REGISTER_TYPE, fastmath=rv_flags)
        cast = UnrealizedConversionCastOp.get((new_op.rd,), (op.result.type,))

        rewriter.replace_op(op, (lhs, rhs, new_op, cast))

arith_op_cls: type[arith.FloatingPointLikeBinaryOperation] instance-attribute

riscv_f_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath] instance-attribute

riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath] instance-attribute

__init__(arith_op_cls: type[arith.FloatingPointLikeBinaryOperation], riscv_f_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath], riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]) -> None

match_and_rewrite(op: Operation, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
    if not isinstance(op, self.arith_op_cls):
        return

    lhs = UnrealizedConversionCastOp.get((op.lhs,), (_FLOAT_REGISTER_TYPE,))
    rhs = UnrealizedConversionCastOp.get((op.rhs,), (_FLOAT_REGISTER_TYPE,))
    match op.lhs.type:
        case Float32Type():
            cls = self.riscv_f_op_cls
        case Float64Type():
            cls = self.riscv_d_op_cls
        case _:
            raise ValueError(f"Unexpected float type {op.lhs.type}")

    rv_flags = riscv.FastMathFlagsAttr(op.fastmath.data)

    new_op = cls(lhs, rhs, rd=_FLOAT_REGISTER_TYPE, fastmath=rv_flags)
    cast = UnrealizedConversionCastOp.get((new_op.rd,), (op.result.type,))

    rewriter.replace_op(op, (lhs, rhs, new_op, cast))

LowerArithFloorDivSI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
211
212
213
214
215
216
class LowerArithFloorDivSI(RewritePattern):
    """Placeholder pattern: `arith.floordivsi` has no RISC-V lowering yet."""

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.FloorDivSIOp, rewriter: PatternRewriter
    ) -> None:
        reason = "FloorDivSI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.FloorDivSIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
212
213
214
215
216
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.FloorDivSIOp, rewriter: PatternRewriter
) -> None:
    """Always raise: `arith.floordivsi` has no RISC-V lowering yet."""
    reason = "FloorDivSI is not supported"
    raise NotImplementedError(reason)

LowerArithCeilDivSI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
219
220
221
222
223
224
class LowerArithCeilDivSI(RewritePattern):
    """Placeholder pattern: `arith.ceildivsi` has no RISC-V lowering yet."""

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.CeilDivSIOp, rewriter: PatternRewriter
    ) -> None:
        reason = "CeilDivSI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.CeilDivSIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
220
221
222
223
224
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.CeilDivSIOp, rewriter: PatternRewriter
) -> None:
    """Always raise: `arith.ceildivsi` has no RISC-V lowering yet."""
    reason = "CeilDivSI is not supported"
    raise NotImplementedError(reason)

LowerArithCeilDivUI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
227
228
229
230
231
232
class LowerArithCeilDivUI(RewritePattern):
    """Placeholder pattern: `arith.ceildivui` has no RISC-V lowering yet."""

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.CeilDivUIOp, rewriter: PatternRewriter
    ) -> None:
        reason = "CeilDivUI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.CeilDivUIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
228
229
230
231
232
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.CeilDivUIOp, rewriter: PatternRewriter
) -> None:
    """Always raise: `arith.ceildivui` has no RISC-V lowering yet."""
    reason = "CeilDivUI is not supported"
    raise NotImplementedError(reason)

LowerArithMinSI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
239
240
241
242
class LowerArithMinSI(RewritePattern):
    """Placeholder pattern: `arith.minsi` has no RISC-V lowering yet."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.MinSIOp, rewriter: PatternRewriter) -> None:
        reason = "MinSI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.MinSIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
240
241
242
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.MinSIOp, rewriter: PatternRewriter) -> None:
    """Always raise: `arith.minsi` has no RISC-V lowering yet."""
    reason = "MinSI is not supported"
    raise NotImplementedError(reason)

LowerArithMaxSI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
245
246
247
248
class LowerArithMaxSI(RewritePattern):
    """Placeholder pattern: `arith.maxsi` has no RISC-V lowering yet."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.MaxSIOp, rewriter: PatternRewriter) -> None:
        reason = "MaxSI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.MaxSIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
246
247
248
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.MaxSIOp, rewriter: PatternRewriter) -> None:
    """Always raise: `arith.maxsi` has no RISC-V lowering yet."""
    reason = "MaxSI is not supported"
    raise NotImplementedError(reason)

LowerArithMinUI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
251
252
253
254
class LowerArithMinUI(RewritePattern):
    """Placeholder pattern: `arith.minui` has no RISC-V lowering yet."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.MinUIOp, rewriter: PatternRewriter) -> None:
        reason = "MinUI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.MinUIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
252
253
254
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.MinUIOp, rewriter: PatternRewriter) -> None:
    """Always raise: `arith.minui` has no RISC-V lowering yet."""
    reason = "MinUI is not supported"
    raise NotImplementedError(reason)

LowerArithMaxUI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
257
258
259
260
class LowerArithMaxUI(RewritePattern):
    """Placeholder pattern: `arith.maxui` has no RISC-V lowering yet."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.MaxUIOp, rewriter: PatternRewriter) -> None:
        reason = "MaxUI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.MaxUIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
258
259
260
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.MaxUIOp, rewriter: PatternRewriter) -> None:
    """Always raise: `arith.maxui` has no RISC-V lowering yet."""
    reason = "MaxUI is not supported"
    raise NotImplementedError(reason)

LowerArithCmpi

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
class LowerArithCmpi(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.CmpiOp, rewriter: PatternRewriter) -> None:
        # based on https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/RISCV/i32-icmp.ll
        lhs, rhs = cast_operands_to_regs(rewriter, op)
        cast_op_results(rewriter, op)

        match op.predicate.value.data:
            # eq
            case 0:
                xor_op = riscv.XorOp(lhs, rhs)
                seqz_op = riscv.SltiuOp(xor_op, 1)
                rewriter.replace_op(op, [xor_op, seqz_op])
            # ne
            case 1:
                zero = riscv.GetRegisterOp(riscv.Registers.ZERO)
                xor_op = riscv.XorOp(lhs, rhs)
                snez_op = riscv.SltuOp(zero, xor_op)
                rewriter.replace_op(op, [zero, xor_op, snez_op])
            # slt
            case 2:
                rewriter.replace_op(op, [riscv.SltOp(lhs, rhs)])
            # sle
            case 3:
                slt = riscv.SltOp(lhs, rhs)
                xori = riscv.XoriOp(slt, 1)
                rewriter.replace_op(op, [slt, xori])
            # ult
            case 4:
                rewriter.replace_op(op, [riscv.SltuOp(lhs, rhs)])
            # ule
            case 5:
                sltu = riscv.SltuOp(lhs, rhs)
                xori = riscv.XoriOp(sltu, 1)
                rewriter.replace_op(op, [sltu, xori])
            # ugt
            case 6:
                rewriter.replace_op(op, [riscv.SltuOp(rhs, lhs)])
            # uge
            case 7:
                sltu = riscv.SltuOp(rhs, lhs)
                xori = riscv.XoriOp(sltu, 1)
                rewriter.replace_op(op, [sltu, xori])
            case _:
                raise NotImplementedError("Cmpi predicate not supported")

match_and_rewrite(op: arith.CmpiOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.CmpiOp, rewriter: PatternRewriter) -> None:
    # based on https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/RISCV/i32-icmp.ll
    lhs, rhs = cast_operands_to_regs(rewriter, op)
    cast_op_results(rewriter, op)

    match op.predicate.value.data:
        # eq
        case 0:
            xor_op = riscv.XorOp(lhs, rhs)
            seqz_op = riscv.SltiuOp(xor_op, 1)
            rewriter.replace_op(op, [xor_op, seqz_op])
        # ne
        case 1:
            zero = riscv.GetRegisterOp(riscv.Registers.ZERO)
            xor_op = riscv.XorOp(lhs, rhs)
            snez_op = riscv.SltuOp(zero, xor_op)
            rewriter.replace_op(op, [zero, xor_op, snez_op])
        # slt
        case 2:
            rewriter.replace_op(op, [riscv.SltOp(lhs, rhs)])
        # sle
        case 3:
            slt = riscv.SltOp(lhs, rhs)
            xori = riscv.XoriOp(slt, 1)
            rewriter.replace_op(op, [slt, xori])
        # ult
        case 4:
            rewriter.replace_op(op, [riscv.SltuOp(lhs, rhs)])
        # ule
        case 5:
            sltu = riscv.SltuOp(lhs, rhs)
            xori = riscv.XoriOp(sltu, 1)
            rewriter.replace_op(op, [sltu, xori])
        # ugt
        case 6:
            rewriter.replace_op(op, [riscv.SltuOp(rhs, lhs)])
        # uge
        case 7:
            sltu = riscv.SltuOp(rhs, lhs)
            xori = riscv.XoriOp(sltu, 1)
            rewriter.replace_op(op, [sltu, xori])
        case _:
            raise NotImplementedError("Cmpi predicate not supported")

LowerArithSelect

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
310
311
312
313
class LowerArithSelect(RewritePattern):
    """Placeholder pattern: `arith.select` has no RISC-V lowering yet."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.SelectOp, rewriter: PatternRewriter) -> None:
        reason = "Select is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.SelectOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
311
312
313
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.SelectOp, rewriter: PatternRewriter) -> None:
    """Always raise: `arith.select` has no RISC-V lowering yet."""
    reason = "Select is not supported"
    raise NotImplementedError(reason)

LowerArithNegf

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
332
333
334
335
336
337
338
339
340
341
342
343
344
class LowerArithNegf(RewritePattern):
    """
    Lower `arith.negf` to `fsgnjn.s rd, rs, rs`, which copies the operand with
    its sign bit flipped.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.NegfOp, rewriter: PatternRewriter) -> None:
        # NOTE(review): the single-precision FSgnJNSOp is used for every operand
        # type; confirm whether f64 operands reach this pattern.
        src = UnrealizedConversionCastOp.get((op.operand,), (_FLOAT_REGISTER_TYPE,))
        negated = riscv.FSgnJNSOp(src, src)
        back_cast = UnrealizedConversionCastOp.get((negated.rd,), (op.result.type,))
        rewriter.replace_op(op, (src, negated, back_cast))

match_and_rewrite(op: arith.NegfOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
333
334
335
336
337
338
339
340
341
342
343
344
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.NegfOp, rewriter: PatternRewriter) -> None:
    """
    Replace `arith.negf` with `fsgnjn.s rd, rs, rs` (the operand with its sign
    bit flipped), wrapped in conversion casts.
    """
    # NOTE(review): the single-precision FSgnJNSOp is used for every operand type.
    src = UnrealizedConversionCastOp.get((op.operand,), (_FLOAT_REGISTER_TYPE,))
    negated = riscv.FSgnJNSOp(src, src)
    back_cast = UnrealizedConversionCastOp.get((negated.rd,), (op.result.type,))
    rewriter.replace_op(op, (src, negated, back_cast))

LowerArithCmpf

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
class LowerArithCmpf(RewritePattern):
    """
    Lower `arith.cmpf` to RISC-V float comparisons.

    Ordered predicates map onto feq.s / flt.s / fle.s; the "greater" forms swap
    the operands. Each unordered predicate is the negation (`xori 1`) of the
    complementary ordered predicate. `false` / `true` become `li 0` / `li 1`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.CmpfOp, rewriter: PatternRewriter) -> None:
        # https://llvm.org/docs/LangRef.html#id309
        # (LLVM LangRef, `fcmp` instruction: semantics of each predicate.)
        # NOTE(review): only the single-precision compares (FeqSOp / FltSOp /
        # FleSOp) are emitted, whatever the operand type. Confirm that f64
        # operands never reach this pattern, or that they need the
        # double-precision variants.
        lhs, rhs = cast_operands_to_regs(rewriter, op)
        cast_op_results(rewriter, op)

        # The arith fast-math flags are forwarded to every emitted compare.
        fastmath = riscv.FastMathFlagsAttr(op.fastmath.data)

        match op.predicate.value.data:
            # false
            case 0:
                rewriter.replace_op(op, [riscv.LiOp(0)])
            # oeq: lhs == rhs
            case 1:
                rewriter.replace_op(op, [riscv.FeqSOp(lhs, rhs, fastmath=fastmath)])
            # ogt: rhs < lhs
            case 2:
                rewriter.replace_op(op, [riscv.FltSOp(rhs, lhs, fastmath=fastmath)])
            # oge: rhs <= lhs
            case 3:
                rewriter.replace_op(op, [riscv.FleSOp(rhs, lhs, fastmath=fastmath)])
            # olt: lhs < rhs
            case 4:
                rewriter.replace_op(op, [riscv.FltSOp(lhs, rhs, fastmath=fastmath)])
            # ole: lhs <= rhs
            case 5:
                rewriter.replace_op(op, [riscv.FleSOp(lhs, rhs, fastmath=fastmath)])
            # one: (lhs < rhs) | (rhs < lhs)
            case 6:
                flt1 = riscv.FltSOp(lhs, rhs, fastmath=fastmath)
                flt2 = riscv.FltSOp(rhs, lhs, fastmath=fastmath)
                rewriter.replace_op(
                    op,
                    [
                        flt1,
                        flt2,
                        riscv.OrOp(flt2, flt1),
                    ],
                )
            # ord: (lhs == lhs) & (rhs == rhs), i.e. neither operand is NaN
            case 7:
                feq1 = riscv.FeqSOp(lhs, lhs, fastmath=fastmath)
                feq2 = riscv.FeqSOp(rhs, rhs, fastmath=fastmath)
                rewriter.replace_op(
                    op,
                    [
                        feq1,
                        feq2,
                        riscv.AndOp(feq2, feq1),
                    ],
                )
            # ueq: !one
            case 8:
                flt1 = riscv.FltSOp(lhs, rhs, fastmath=fastmath)
                flt2 = riscv.FltSOp(rhs, lhs, fastmath=fastmath)
                or_ = riscv.OrOp(flt2, flt1)
                rewriter.replace_op(op, [flt1, flt2, or_, riscv.XoriOp(or_, 1)])
            # ugt: !ole
            case 9:
                fle = riscv.FleSOp(lhs, rhs, fastmath=fastmath)
                rewriter.replace_op(op, [fle, riscv.XoriOp(fle, 1)])
            # uge: !olt (the local named `fle` here holds an flt.s result)
            case 10:
                fle = riscv.FltSOp(lhs, rhs, fastmath=fastmath)
                rewriter.replace_op(op, [fle, riscv.XoriOp(fle, 1)])
            # ult: !oge
            case 11:
                fle = riscv.FleSOp(rhs, lhs, fastmath=fastmath)
                rewriter.replace_op(op, [fle, riscv.XoriOp(fle, 1)])
            # ule: !ogt
            case 12:
                flt = riscv.FltSOp(rhs, lhs, fastmath=fastmath)
                rewriter.replace_op(op, [flt, riscv.XoriOp(flt, 1)])
            # une: !oeq
            case 13:
                feq = riscv.FeqSOp(lhs, rhs, fastmath=fastmath)
                rewriter.replace_op(op, [feq, riscv.XoriOp(feq, 1)])
            # uno: !ord
            case 14:
                feq1 = riscv.FeqSOp(lhs, lhs, fastmath=fastmath)
                feq2 = riscv.FeqSOp(rhs, rhs, fastmath=fastmath)
                and_ = riscv.AndOp(feq2, feq1)
                rewriter.replace_op(op, [feq1, feq2, and_, riscv.XoriOp(and_, 1)])
            # true
            case 15:
                rewriter.replace_op(op, [riscv.LiOp(1)])
            case _:
                raise NotImplementedError("Cmpf predicate not supported")

match_and_rewrite(op: arith.CmpfOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.CmpfOp, rewriter: PatternRewriter) -> None:
    # https://llvm.org/docs/LangRef.html#id309
    lhs, rhs = cast_operands_to_regs(rewriter, op)
    cast_op_results(rewriter, op)

    fastmath = riscv.FastMathFlagsAttr(op.fastmath.data)

    match op.predicate.value.data:
        # false
        case 0:
            rewriter.replace_op(op, [riscv.LiOp(0)])
        # oeq
        case 1:
            rewriter.replace_op(op, [riscv.FeqSOp(lhs, rhs, fastmath=fastmath)])
        # ogt
        case 2:
            rewriter.replace_op(op, [riscv.FltSOp(rhs, lhs, fastmath=fastmath)])
        # oge
        case 3:
            rewriter.replace_op(op, [riscv.FleSOp(rhs, lhs, fastmath=fastmath)])
        # olt
        case 4:
            rewriter.replace_op(op, [riscv.FltSOp(lhs, rhs, fastmath=fastmath)])
        # ole
        case 5:
            rewriter.replace_op(op, [riscv.FleSOp(lhs, rhs, fastmath=fastmath)])
        # one
        case 6:
            flt1 = riscv.FltSOp(lhs, rhs, fastmath=fastmath)
            flt2 = riscv.FltSOp(rhs, lhs, fastmath=fastmath)
            rewriter.replace_op(
                op,
                [
                    flt1,
                    flt2,
                    riscv.OrOp(flt2, flt1),
                ],
            )
        # ord
        case 7:
            feq1 = riscv.FeqSOp(lhs, lhs, fastmath=fastmath)
            feq2 = riscv.FeqSOp(rhs, rhs, fastmath=fastmath)
            rewriter.replace_op(
                op,
                [
                    feq1,
                    feq2,
                    riscv.AndOp(feq2, feq1),
                ],
            )
        # ueq
        case 8:
            flt1 = riscv.FltSOp(lhs, rhs, fastmath=fastmath)
            flt2 = riscv.FltSOp(rhs, lhs, fastmath=fastmath)
            or_ = riscv.OrOp(flt2, flt1)
            rewriter.replace_op(op, [flt1, flt2, or_, riscv.XoriOp(or_, 1)])
        # ugt
        case 9:
            fle = riscv.FleSOp(lhs, rhs, fastmath=fastmath)
            rewriter.replace_op(op, [fle, riscv.XoriOp(fle, 1)])
        # uge
        case 10:
            fle = riscv.FltSOp(lhs, rhs, fastmath=fastmath)
            rewriter.replace_op(op, [fle, riscv.XoriOp(fle, 1)])
        # ult
        case 11:
            fle = riscv.FleSOp(rhs, lhs, fastmath=fastmath)
            rewriter.replace_op(op, [fle, riscv.XoriOp(fle, 1)])
        # ule
        case 12:
            flt = riscv.FltSOp(rhs, lhs, fastmath=fastmath)
            rewriter.replace_op(op, [flt, riscv.XoriOp(flt, 1)])
        # une
        case 13:
            feq = riscv.FeqSOp(lhs, rhs, fastmath=fastmath)
            rewriter.replace_op(op, [feq, riscv.XoriOp(feq, 1)])
        # uno
        case 14:
            feq1 = riscv.FeqSOp(lhs, lhs, fastmath=fastmath)
            feq2 = riscv.FeqSOp(rhs, rhs, fastmath=fastmath)
            and_ = riscv.AndOp(feq2, feq1)
            rewriter.replace_op(op, [feq1, feq2, and_, riscv.XoriOp(and_, 1)])
        # true
        case 15:
            rewriter.replace_op(op, [riscv.LiOp(1)])
        case _:
            raise NotImplementedError("Cmpf predicate not supported")

LowerArithSIToFPOp

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py, lines 438–458
class LowerArithSIToFPOp(RewritePattern):
    """Lower `arith.sitofp` to a RISC-V `FCvtSWOp` / `FCvtDWOp` conversion."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.SIToFPOp, rewriter: PatternRewriter) -> None:
        """
        Replace `arith.sitofp` with three ops: a cast into an integer register,
        the int-to-float conversion, and a cast back to the original float type.

        Raises:
            ValueError: if the result type is neither f32 nor f64.
        """
        result_type = op.result.type
        # The conversion op is chosen by destination precision.
        if isinstance(result_type, Float32Type):
            convert_cls = riscv.FCvtSWOp
        elif isinstance(result_type, Float64Type):
            convert_cls = riscv.FCvtDWOp
        else:
            raise ValueError(f"Unexpected float type {op.result.type}")

        to_int_reg = UnrealizedConversionCastOp.get(
            (op.input,), (_INT_REGISTER_TYPE,)
        )
        convert = convert_cls(to_int_reg.results[0])
        to_result = UnrealizedConversionCastOp.get((convert.rd,), (result_type,))
        rewriter.replace_op(op, (to_int_reg, convert, to_result))

match_and_rewrite(op: arith.SIToFPOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py, lines 439–458
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.SIToFPOp, rewriter: PatternRewriter) -> None:
    """
    Replace `arith.sitofp` with three ops: a cast into an integer register,
    a RISC-V int-to-float conversion, and a cast back to the original float type.

    Raises:
        ValueError: if the result type is neither f32 nor f64.
    """
    result_type = op.result.type
    # The conversion op is chosen by destination precision.
    if isinstance(result_type, Float32Type):
        convert_cls = riscv.FCvtSWOp
    elif isinstance(result_type, Float64Type):
        convert_cls = riscv.FCvtDWOp
    else:
        raise ValueError(f"Unexpected float type {op.result.type}")

    to_int_reg = UnrealizedConversionCastOp.get((op.input,), (_INT_REGISTER_TYPE,))
    convert = convert_cls(to_int_reg.results[0])
    to_result = UnrealizedConversionCastOp.get((convert.rd,), (result_type,))
    rewriter.replace_op(op, (to_int_reg, convert, to_result))

LowerArithFPToSIOp

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py, lines 461–473
class LowerArithFPToSIOp(RewritePattern):
    """Lower `arith.fptosi` to a RISC-V `FCvtWSOp` conversion."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.FPToSIOp, rewriter: PatternRewriter) -> None:
        """
        Replace `arith.fptosi` with three ops: a cast into a float register,
        `FCvtWSOp`, and a cast back to the original integer result type.
        """
        # NOTE(review): FCvtWSOp is emitted unconditionally, whatever the width
        # of the float input or the integer result. Confirm that f64 inputs and
        # i64 results never reach this pattern. Also confirm that the rounding
        # mode FCvtWSOp uses matches fptosi's round-toward-zero semantics.
        to_float_reg = UnrealizedConversionCastOp.get(
            (op.input,), (_FLOAT_REGISTER_TYPE,)
        )
        convert = riscv.FCvtWSOp(to_float_reg.results[0])
        to_result = UnrealizedConversionCastOp.get((convert.rd,), (op.result.type,))
        rewriter.replace_op(op, (to_float_reg, convert, to_result))

match_and_rewrite(op: arith.FPToSIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py, lines 462–473
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.FPToSIOp, rewriter: PatternRewriter) -> None:
    """
    Replace `arith.fptosi` with three ops: a cast into a float register,
    `FCvtWSOp`, and a cast back to the original integer result type.
    """
    # NOTE(review): FCvtWSOp is emitted unconditionally, whatever the width of
    # the float input or the integer result. Confirm that f64 inputs and i64
    # results never reach this pattern, and that FCvtWSOp's rounding mode matches
    # fptosi's round-toward-zero semantics.
    to_float_reg = UnrealizedConversionCastOp.get(
        (op.input,), (_FLOAT_REGISTER_TYPE,)
    )
    convert = riscv.FCvtWSOp(to_float_reg.results[0])
    to_result = UnrealizedConversionCastOp.get((convert.rd,), (op.result.type,))
    rewriter.replace_op(op, (to_float_reg, convert, to_result))

LowerArithExtFOp

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py, lines 476–479
class LowerArithExtFOp(RewritePattern):
    """Pattern that rejects `arith.extf`, which has no RISC-V lowering here."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.ExtFOp, rewriter: PatternRewriter) -> None:
        """Always raise NotImplementedError for a matched `arith.extf`."""
        message = "ExtF is not supported"
        raise NotImplementedError(message)

match_and_rewrite(op: arith.ExtFOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py, lines 477–479
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.ExtFOp, rewriter: PatternRewriter) -> None:
    """Always raise NotImplementedError for a matched `arith.extf`."""
    message = "ExtF is not supported"
    raise NotImplementedError(message)

LowerArithTruncFOp

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py, lines 482–485
class LowerArithTruncFOp(RewritePattern):
    """Pattern that rejects `arith.truncf`, which has no RISC-V lowering here."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.TruncFOp, rewriter: PatternRewriter) -> None:
        """Always raise NotImplementedError for a matched `arith.truncf`."""
        message = "TruncF is not supported"
        raise NotImplementedError(message)

match_and_rewrite(op: arith.TruncFOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py, lines 483–485
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.TruncFOp, rewriter: PatternRewriter) -> None:
    """Always raise NotImplementedError for a matched `arith.truncf`."""
    message = "TruncF is not supported"
    raise NotImplementedError(message)

ConvertArithToRiscvPass dataclass

Bases: ModulePass

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py, lines 488–536
class ConvertArithToRiscvPass(ModulePass):
    """Module pass that rewrites `arith` operations into the `riscv` dialect."""

    name = "convert-arith-to-riscv"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """
        Walk `op` once and apply every arith-to-riscv lowering pattern.

        Dead-code elimination in the applier is disabled, and the walker does
        not re-apply patterns recursively to the ops it produces.
        """
        # Listing order is kept identical to the original registration order.
        patterns = [
            LowerArithConstant(),
            LowerArithIndexCast(),
            LowerArithSIToFPOp(),
            LowerArithFPToSIOp(),
            lower_arith_addi,
            lower_arith_subi,
            lower_arith_muli,
            lower_arith_divui,
            lower_arith_divsi,
            LowerArithFloorDivSI(),
            lower_arith_remsi,
            LowerArithCmpi(),
            lower_arith_addf,
            lower_arith_subf,
            lower_arith_divf,
            LowerArithNegf(),
            lower_arith_mulf,
            LowerArithCmpf(),
            lower_arith_remui,
            lower_arith_andi,
            lower_arith_ori,
            lower_arith_xori,
            lower_arith_shli,
            lower_arith_shrui,
            lower_arith_shrsi,
            LowerArithCeilDivSI(),
            LowerArithCeilDivUI(),
            LowerArithMinSI(),
            LowerArithMaxSI(),
            LowerArithMinUI(),
            LowerArithMaxUI(),
            LowerArithSelect(),
            LowerArithExtFOp(),
            LowerArithTruncFOp(),
            lower_arith_minf,
            lower_arith_maxf,
        ]
        applier = GreedyRewritePatternApplier(patterns, dce_enabled=False)
        PatternRewriteWalker(applier, apply_recursively=False).rewrite_module(op)

name = 'convert-arith-to-riscv' class-attribute instance-attribute

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py, lines 491–536
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """
    Walk `op` once and apply every arith-to-riscv lowering pattern.

    Dead-code elimination in the applier is disabled, and the walker does not
    re-apply patterns recursively to the ops it produces.
    """
    # Listing order is kept identical to the original registration order.
    patterns = [
        LowerArithConstant(),
        LowerArithIndexCast(),
        LowerArithSIToFPOp(),
        LowerArithFPToSIOp(),
        lower_arith_addi,
        lower_arith_subi,
        lower_arith_muli,
        lower_arith_divui,
        lower_arith_divsi,
        LowerArithFloorDivSI(),
        lower_arith_remsi,
        LowerArithCmpi(),
        lower_arith_addf,
        lower_arith_subf,
        lower_arith_divf,
        LowerArithNegf(),
        lower_arith_mulf,
        LowerArithCmpf(),
        lower_arith_remui,
        lower_arith_andi,
        lower_arith_ori,
        lower_arith_xori,
        lower_arith_shli,
        lower_arith_shrui,
        lower_arith_shrsi,
        LowerArithCeilDivSI(),
        LowerArithCeilDivUI(),
        LowerArithMinSI(),
        LowerArithMaxSI(),
        LowerArithMinUI(),
        LowerArithMaxUI(),
        LowerArithSelect(),
        LowerArithExtFOp(),
        LowerArithTruncFOp(),
        lower_arith_minf,
        lower_arith_maxf,
    ]
    applier = GreedyRewritePatternApplier(patterns, dce_enabled=False)
    PatternRewriteWalker(applier, apply_recursively=False).rewrite_module(op)