Skip to content

Convert arith to riscv

convert_arith_to_riscv

RdRsRsIntegerOperation = riscv.RdRsRsOperation[riscv.IntRegisterType, riscv.IntRegisterType, riscv.IntRegisterType] module-attribute

RdRsRsFloatOperation = riscv.RdRsRsOperation[riscv.FloatRegisterType, riscv.FloatRegisterType, riscv.FloatRegisterType] module-attribute

lower_arith_addi = LowerBinaryIntegerOp(arith.AddiOp, riscv.AddOp) module-attribute

lower_arith_subi = LowerBinaryIntegerOp(arith.SubiOp, riscv.SubOp) module-attribute

lower_arith_muli = LowerBinaryIntegerOp(arith.MuliOp, riscv.MulOp) module-attribute

lower_arith_divui = LowerBinaryIntegerOp(arith.DivUIOp, riscv.DivuOp) module-attribute

lower_arith_divsi = LowerBinaryIntegerOp(arith.DivSIOp, riscv.DivOp) module-attribute

lower_arith_remui = LowerBinaryIntegerOp(arith.RemUIOp, riscv.RemuOp) module-attribute

lower_arith_remsi = LowerBinaryIntegerOp(arith.RemSIOp, riscv.RemOp) module-attribute

lower_arith_andi = LowerBinaryIntegerOp(arith.AndIOp, riscv.AndOp) module-attribute

lower_arith_ori = LowerBinaryIntegerOp(arith.OrIOp, riscv.OrOp) module-attribute

lower_arith_xori = LowerBinaryIntegerOp(arith.XOrIOp, riscv.XorOp) module-attribute

lower_arith_shli = LowerBinaryIntegerOp(arith.ShLIOp, riscv.SllOp) module-attribute

lower_arith_shrui = LowerBinaryIntegerOp(arith.ShRUIOp, riscv.SrlOp) module-attribute

lower_arith_shrsi = LowerBinaryIntegerOp(arith.ShRSIOp, riscv.SraOp) module-attribute

lower_arith_addf = LowerBinaryFloatOp(arith.AddfOp, riscv.FAddSOp, riscv.FAddDOp) module-attribute

lower_arith_subf = LowerBinaryFloatOp(arith.SubfOp, riscv.FSubSOp, riscv.FSubDOp) module-attribute

lower_arith_mulf = LowerBinaryFloatOp(arith.MulfOp, riscv.FMulSOp, riscv.FMulDOp) module-attribute

lower_arith_divf = LowerBinaryFloatOp(arith.DivfOp, riscv.FDivSOp, riscv.FDivDOp) module-attribute

lower_arith_minf = LowerBinaryFloatOp(arith.MinimumfOp, riscv.FMinSOp, riscv.FMinDOp) module-attribute

lower_arith_maxf = LowerBinaryFloatOp(arith.MaximumfOp, riscv.FMaxSOp, riscv.FMaxDOp) module-attribute

LowerArithConstant

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
class LowerArithConstant(RewritePattern):
    """
    Lower `arith.constant` to RISC-V ops that materialize the value, followed by
    an `UnrealizedConversionCastOp` back to the original result type.

    Handled cases:
    - integers of width <= 32 and `index` constants: a single `li`;
    - f32: `li` of the IEEE bit pattern, then `FMvWXOp` into a float register;
    - f64 holding a signed 32-bit integer (excluding -0.0): `li` + `FCvtDWOp`;
    - any other f64, and 8-byte dense non-integer constants: the two 32-bit
      halves are stored below `sp` and reloaded as one double with `FLdOp`.
    Everything else raises NotImplementedError (or PassFailedException for
    8-byte integer dense constants).
    """

    @staticmethod
    def _is_exact_i32(value: float) -> bool:
        """
        Return True if `value` is an integer in the signed 32-bit range that
        `li` + `fcvt.d.w` reproduces bit-for-bit.

        Negative zero is rejected: `int(-0.0)` is 0 and converting 0 back gives
        +0.0, which would silently drop the sign bit.
        """
        import math

        if not value.is_integer():  # also False for inf / nan
            return False
        if value == 0 and math.copysign(1.0, value) < 0:
            return False
        return signed_lower_bound(32) <= int(value) < signed_upper_bound(32)

    @staticmethod
    def _load_f64_bits_via_stack(
        lower: int, upper: int, result_type
    ) -> list[Operation]:
        """
        Build the ops that place a 64-bit pattern, given as two signed 32-bit
        words, in a float register and cast it to `result_type`.

        This lowering assumes xlen is 32 and flen is 64. The upper word is
        stored at sp-4 and the lower word at sp-8, then the double is reloaded
        with a 64-bit load from sp-8.

        NOTE(review): sp is not adjusted around these stores. The RISC-V psABI
        defines no red zone, so confirm nothing can clobber this area between
        the stores and the load.
        """
        sp = rv32.GetRegisterOp(riscv.Registers.SP)
        li_upper = rv32.LiOp(upper)
        store_upper = riscv.SwOp(sp, li_upper, -4)
        li_lower = rv32.LiOp(lower)
        store_lower = riscv.SwOp(sp, li_lower, -8)
        fld = riscv.FLdOp(sp, -8, rd=_FLOAT_REGISTER_TYPE)
        return [
            sp,
            li_upper,
            store_upper,
            li_lower,
            store_lower,
            fld,
            UnrealizedConversionCastOp.get(fld.results, (result_type,)),
        ]

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.ConstantOp, rewriter: PatternRewriter
    ) -> None:
        op_result_type = op.result.type
        if isa(op_result_type, IntegerType) and isinstance(
            op_val := op.value, IntegerAttr
        ):
            if op_result_type.width.data <= 32:
                rewriter.replace_op(
                    op,
                    [
                        constant := rv32.LiOp(op_val.value.data),
                        UnrealizedConversionCastOp.get(
                            constant.results, (op_result_type,)
                        ),
                    ],
                )
            else:
                raise NotImplementedError("Only 32 bit integers are supported for now")
        elif isinstance(op_val := op.value, FloatAttr):
            if isinstance(op_result_type, Float32Type):
                rewriter.replace_op(
                    op,
                    [
                        lui := rv32.LiOp(
                            convert_f32_to_u32(op_val.value.data),
                            rd=_INT_REGISTER_TYPE,
                        ),
                        fmv := riscv.FMvWXOp(lui.rd),
                        UnrealizedConversionCastOp.get(fmv.results, (op_result_type,)),
                    ],
                )
            elif isinstance(op_result_type, Float64Type):
                # There is no way to load an immediate value to a float
                # register directly.
                val_data = op_val.value.data
                if self._is_exact_i32(val_data):
                    # Cheap path: materialize the integer, then convert it.
                    rewriter.replace_op(
                        op,
                        [
                            lui := rv32.LiOp(
                                int(val_data),
                                rd=_INT_REGISTER_TYPE,
                            ),
                            fcvtdw := riscv.FCvtDWOp(lui.rd, rd=_FLOAT_REGISTER_TYPE),
                            UnrealizedConversionCastOp.get(
                                fcvtdw.results, (op_result_type,)
                            ),
                        ],
                    )
                else:
                    # TODO: check the xlen in this lowering.
                    lower, upper = struct.unpack("<ii", struct.pack("<d", val_data))
                    rewriter.replace_op(
                        op, self._load_f64_bits_via_stack(lower, upper, op_result_type)
                    )
            else:
                raise NotImplementedError("Only 32 or 64 bit floats are supported")
        elif (
            isinstance(op_val, DenseIntOrFPElementsAttr) and len(op_val.data.data) == 8
        ):
            if isinstance(op_val.get_element_type(), IntegerType):
                raise PassFailedException(
                    "Integer vector constants cannot be lowered to float registers"
                )
            # TODO: check the xlen in this lowering.
            lower, upper = struct.unpack("<ii", op_val.data.data)
            rewriter.replace_op(
                op, self._load_f64_bits_via_stack(lower, upper, op_result_type)
            )
        elif isinstance(op_result_type, IndexType) and isinstance(
            op_val := op.value, IntegerAttr
        ):
            rewriter.replace_op(
                op,
                [
                    constant := rv32.LiOp(op_val.value.data),
                    UnrealizedConversionCastOp.get(constant.results, (op_result_type,)),
                ],
            )
        else:
            raise NotImplementedError(
                f"Unsupported constant type {op_val} of type {type(op_val)}"
            )

match_and_rewrite(op: arith.ConstantOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.ConstantOp, rewriter: PatternRewriter
) -> None:
    """
    Lower an `arith.constant` to RISC-V ops that materialize its value, followed
    by an `UnrealizedConversionCastOp` back to the original result type.

    Raises NotImplementedError for unsupported widths/types, and
    PassFailedException for 8-byte integer dense constants.
    """
    import math

    op_result_type = op.result.type
    if isa(op_result_type, IntegerType) and isinstance(
        op_val := op.value, IntegerAttr
    ):
        if op_result_type.width.data <= 32:
            rewriter.replace_op(
                op,
                [
                    constant := rv32.LiOp(op_val.value.data),
                    UnrealizedConversionCastOp.get(
                        constant.results, (op_result_type,)
                    ),
                ],
            )
        else:
            raise NotImplementedError("Only 32 bit integers are supported for now")
    elif isinstance(op_val := op.value, FloatAttr):
        if isinstance(op_result_type, Float32Type):
            rewriter.replace_op(
                op,
                [
                    lui := rv32.LiOp(
                        convert_f32_to_u32(op_val.value.data),
                        rd=_INT_REGISTER_TYPE,
                    ),
                    fld := riscv.FMvWXOp(lui.rd),
                    UnrealizedConversionCastOp.get(fld.results, (op_result_type,)),
                ],
            )
        elif isinstance(op_result_type, Float64Type):
            # There is no way to load an immediate value to a float register directly.

            s32_min = signed_lower_bound(32)
            s32_max = signed_upper_bound(32)
            val_data = op_val.value.data
            # If the value is an integer that fits in s32, then convert.
            # Negative zero must take the bit-exact path: int(-0.0) == 0, and
            # converting 0 back yields +0.0, which would drop the sign bit.
            is_negative_zero = val_data == 0 and math.copysign(1.0, val_data) < 0
            if (
                val_data.is_integer()
                and not is_negative_zero
                and s32_min <= (int_val := int(val_data)) < s32_max
            ):
                rewriter.replace_op(
                    op,
                    [
                        lui := rv32.LiOp(
                            int_val,
                            rd=_INT_REGISTER_TYPE,
                        ),
                        fcvtdw := riscv.FCvtDWOp(lui.rd, rd=_FLOAT_REGISTER_TYPE),
                        UnrealizedConversionCastOp.get(
                            fcvtdw.results, (op_result_type,)
                        ),
                    ],
                )
            else:
                # We have to load the bits into an integer register, store them on the
                # stack, and load again.

                # TODO: check the xlen in this lowering.

                # This lowering assumes that xlen is 32 and flen is 64

                # NOTE(review): sp is not adjusted before the stores below;
                # confirm nothing can clobber sp-8..sp before the fld.
                lower, upper = struct.unpack("<ii", struct.pack("<d", val_data))
                rewriter.replace_op(
                    op,
                    [
                        sp := rv32.GetRegisterOp(riscv.Registers.SP),
                        li_upper := rv32.LiOp(upper),
                        riscv.SwOp(sp, li_upper, -4),
                        li_lower := rv32.LiOp(lower),
                        riscv.SwOp(sp, li_lower, -8),
                        fld := riscv.FLdOp(sp, -8, rd=_FLOAT_REGISTER_TYPE),
                        UnrealizedConversionCastOp.get(
                            fld.results, (op_result_type,)
                        ),
                    ],
                )
        else:
            raise NotImplementedError("Only 32 or 64 bit floats are supported")
    elif (
        isinstance(op_val, DenseIntOrFPElementsAttr) and len(op_val.data.data) == 8
    ):
        if isinstance(op_val.get_element_type(), IntegerType):
            raise PassFailedException(
                "Integer vector constants cannot be lowered to float registers"
            )

        # We have to load the bits into an integer register, store them on the
        # stack, and load again.

        # TODO: check the xlen in this lowering.

        # This lowering assumes that xlen is 32 and flen is 64

        lower, upper = struct.unpack("<ii", op_val.data.data)
        rewriter.replace_op(
            op,
            [
                sp := rv32.GetRegisterOp(riscv.Registers.SP),
                li_upper := rv32.LiOp(upper),
                riscv.SwOp(sp, li_upper, -4),
                li_lower := rv32.LiOp(lower),
                riscv.SwOp(sp, li_lower, -8),
                fld := riscv.FLdOp(sp, -8, rd=_FLOAT_REGISTER_TYPE),
                UnrealizedConversionCastOp.get(fld.results, (op_result_type,)),
            ],
        )
    elif isinstance(op_result_type, IndexType) and isinstance(
        op_val := op.value, IntegerAttr
    ):
        rewriter.replace_op(
            op,
            [
                constant := rv32.LiOp(op_val.value.data),
                UnrealizedConversionCastOp.get(constant.results, (op_result_type,)),
            ],
        )
    else:
        raise NotImplementedError(
            f"Unsupported constant type {op_val} of type {type(op_val)}"
        )

LowerArithIndexCast

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
166
167
168
169
170
171
172
173
174
175
176
177
class LowerArithIndexCast(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.IndexCastOp, rewriter: PatternRewriter
    ) -> None:
        """
        Replace `arith.index_cast` with a bare unrealized conversion cast.

        On an RV32 triple the index type is 32 bits wide, so no instruction is
        needed. NOTE(review): no width check is performed here, so a cast from
        a wider integer is passed through unchanged as well.
        """
        source = op.input
        target_type = op.result.type
        passthrough = UnrealizedConversionCastOp.get((source,), (target_type,))
        rewriter.replace_op(op, passthrough)

match_and_rewrite(op: arith.IndexCastOp, rewriter: PatternRewriter) -> None

On a RV32 triple, the index type is 32 bits, so we can just drop the cast.

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
167
168
169
170
171
172
173
174
175
176
177
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.IndexCastOp, rewriter: PatternRewriter
) -> None:
    """
    Replace `arith.index_cast` with a bare unrealized conversion cast, since on
    an RV32 triple the index type is already 32 bits wide.
    """
    source = op.input
    target_type = op.result.type
    passthrough = UnrealizedConversionCastOp.get((source,), (target_type,))
    rewriter.replace_op(op, passthrough)

LowerBinaryIntegerOp dataclass

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
@dataclass
class LowerBinaryIntegerOp(RewritePattern):
    """
    Generic pattern lowering a two-operand integer `arith` op to one RISC-V
    register-register instruction.

    Operands are cast into integer registers, the instruction `riscv_op_cls` is
    applied, and its destination register is cast back to the original type.
    """

    # arith op class this pattern matches
    arith_op_cls: type[arith.SignlessIntegerBinaryOperation]
    # RISC-V instruction emitted in its place
    riscv_op_cls: type[RdRsRsIntegerOperation]

    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
        if isinstance(op, self.arith_op_cls):
            lhs_reg = UnrealizedConversionCastOp.get((op.lhs,), (_INT_REGISTER_TYPE,))
            rhs_reg = UnrealizedConversionCastOp.get((op.rhs,), (_INT_REGISTER_TYPE,))
            lowered = self.riscv_op_cls(lhs_reg, rhs_reg, rd=_INT_REGISTER_TYPE)
            result = UnrealizedConversionCastOp.get((lowered.rd,), (op.result.type,))
            rewriter.replace_op(op, (lhs_reg, rhs_reg, lowered, result))

arith_op_cls: type[arith.SignlessIntegerBinaryOperation] instance-attribute

riscv_op_cls: type[RdRsRsIntegerOperation] instance-attribute

__init__(arith_op_cls: type[arith.SignlessIntegerBinaryOperation], riscv_op_cls: type[RdRsRsIntegerOperation]) -> None

match_and_rewrite(op: Operation, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
194
195
196
197
198
199
200
201
202
203
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
    """Lower a matching binary integer `arith` op to `self.riscv_op_cls`."""
    if isinstance(op, self.arith_op_cls):
        lhs_reg = UnrealizedConversionCastOp.get((op.lhs,), (_INT_REGISTER_TYPE,))
        rhs_reg = UnrealizedConversionCastOp.get((op.rhs,), (_INT_REGISTER_TYPE,))
        lowered = self.riscv_op_cls(lhs_reg, rhs_reg, rd=_INT_REGISTER_TYPE)
        result = UnrealizedConversionCastOp.get((lowered.rd,), (op.result.type,))
        rewriter.replace_op(op, (lhs_reg, rhs_reg, lowered, result))

LowerBinaryFloatOp dataclass

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
@dataclass
class LowerBinaryFloatOp(RewritePattern):
    arith_op_cls: type[arith.FloatingPointLikeBinaryOperation]
    riscv_f_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
    riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]

    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
        if not isinstance(op, self.arith_op_cls):
            return

        lhs = UnrealizedConversionCastOp.get((op.lhs,), (_FLOAT_REGISTER_TYPE,))
        rhs = UnrealizedConversionCastOp.get((op.rhs,), (_FLOAT_REGISTER_TYPE,))
        match op.lhs.type:
            case Float32Type():
                cls = self.riscv_f_op_cls
            case Float64Type():
                cls = self.riscv_d_op_cls
            case _:
                raise ValueError(f"Unexpected float type {op.lhs.type}")

        rv_flags = riscv.FastMathFlagsAttr(op.fastmath.data)

        new_op = cls(lhs, rhs, rd=_FLOAT_REGISTER_TYPE, fastmath=rv_flags)
        cast = UnrealizedConversionCastOp.get((new_op.rd,), (op.result.type,))

        rewriter.replace_op(op, (lhs, rhs, new_op, cast))

arith_op_cls: type[arith.FloatingPointLikeBinaryOperation] instance-attribute

riscv_f_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath] instance-attribute

riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath] instance-attribute

__init__(arith_op_cls: type[arith.FloatingPointLikeBinaryOperation], riscv_f_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath], riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]) -> None

match_and_rewrite(op: Operation, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
    if not isinstance(op, self.arith_op_cls):
        return

    lhs = UnrealizedConversionCastOp.get((op.lhs,), (_FLOAT_REGISTER_TYPE,))
    rhs = UnrealizedConversionCastOp.get((op.rhs,), (_FLOAT_REGISTER_TYPE,))
    match op.lhs.type:
        case Float32Type():
            cls = self.riscv_f_op_cls
        case Float64Type():
            cls = self.riscv_d_op_cls
        case _:
            raise ValueError(f"Unexpected float type {op.lhs.type}")

    rv_flags = riscv.FastMathFlagsAttr(op.fastmath.data)

    new_op = cls(lhs, rhs, rd=_FLOAT_REGISTER_TYPE, fastmath=rv_flags)
    cast = UnrealizedConversionCastOp.get((new_op.rd,), (op.result.type,))

    rewriter.replace_op(op, (lhs, rhs, new_op, cast))

LowerArithFloorDivSI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
241
242
243
244
245
246
class LowerArithFloorDivSI(RewritePattern):
    """Reject `arith.floordivsi`: this lowering has no RISC-V translation for it."""

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.FloorDivSIOp, rewriter: PatternRewriter
    ) -> None:
        # Matching alone aborts the rewrite; no ops are emitted.
        reason = "FloorDivSI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.FloorDivSIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
242
243
244
245
246
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.FloorDivSIOp, rewriter: PatternRewriter
) -> None:
    """Raise NotImplementedError for any matched `arith.floordivsi`."""
    reason = "FloorDivSI is not supported"
    raise NotImplementedError(reason)

LowerArithCeilDivSI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
249
250
251
252
253
254
class LowerArithCeilDivSI(RewritePattern):
    """Reject `arith.ceildivsi`: this lowering has no RISC-V translation for it."""

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.CeilDivSIOp, rewriter: PatternRewriter
    ) -> None:
        # Matching alone aborts the rewrite; no ops are emitted.
        reason = "CeilDivSI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.CeilDivSIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
250
251
252
253
254
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.CeilDivSIOp, rewriter: PatternRewriter
) -> None:
    """Raise NotImplementedError for any matched `arith.ceildivsi`."""
    reason = "CeilDivSI is not supported"
    raise NotImplementedError(reason)

LowerArithCeilDivUI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
257
258
259
260
261
262
class LowerArithCeilDivUI(RewritePattern):
    """Reject `arith.ceildivui`: this lowering has no RISC-V translation for it."""

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.CeilDivUIOp, rewriter: PatternRewriter
    ) -> None:
        # Matching alone aborts the rewrite; no ops are emitted.
        reason = "CeilDivUI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.CeilDivUIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
258
259
260
261
262
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.CeilDivUIOp, rewriter: PatternRewriter
) -> None:
    """Raise NotImplementedError for any matched `arith.ceildivui`."""
    reason = "CeilDivUI is not supported"
    raise NotImplementedError(reason)

LowerArithMinSI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
269
270
271
272
class LowerArithMinSI(RewritePattern):
    """Reject `arith.minsi`: this lowering has no RISC-V translation for it."""

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.MinSIOp, rewriter: PatternRewriter
    ) -> None:
        # Matching alone aborts the rewrite; no ops are emitted.
        reason = "MinSI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.MinSIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
270
271
272
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.MinSIOp, rewriter: PatternRewriter
) -> None:
    """Raise NotImplementedError for any matched `arith.minsi`."""
    reason = "MinSI is not supported"
    raise NotImplementedError(reason)

LowerArithMaxSI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
275
276
277
278
class LowerArithMaxSI(RewritePattern):
    """Reject `arith.maxsi`: this lowering has no RISC-V translation for it."""

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.MaxSIOp, rewriter: PatternRewriter
    ) -> None:
        # Matching alone aborts the rewrite; no ops are emitted.
        reason = "MaxSI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.MaxSIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
276
277
278
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.MaxSIOp, rewriter: PatternRewriter
) -> None:
    """Raise NotImplementedError for any matched `arith.maxsi`."""
    reason = "MaxSI is not supported"
    raise NotImplementedError(reason)

LowerArithMinUI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
281
282
283
284
class LowerArithMinUI(RewritePattern):
    """Reject `arith.minui`: this lowering has no RISC-V translation for it."""

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.MinUIOp, rewriter: PatternRewriter
    ) -> None:
        # Matching alone aborts the rewrite; no ops are emitted.
        reason = "MinUI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.MinUIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
282
283
284
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.MinUIOp, rewriter: PatternRewriter
) -> None:
    """Raise NotImplementedError for any matched `arith.minui`."""
    reason = "MinUI is not supported"
    raise NotImplementedError(reason)

LowerArithMaxUI

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
287
288
289
290
class LowerArithMaxUI(RewritePattern):
    """Reject `arith.maxui`: this lowering has no RISC-V translation for it."""

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.MaxUIOp, rewriter: PatternRewriter
    ) -> None:
        # Matching alone aborts the rewrite; no ops are emitted.
        reason = "MaxUI is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.MaxUIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
288
289
290
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.MaxUIOp, rewriter: PatternRewriter
) -> None:
    """Raise NotImplementedError for any matched `arith.maxui`."""
    reason = "MaxUI is not supported"
    raise NotImplementedError(reason)

LowerArithCmpi

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
class LowerArithCmpi(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.CmpiOp, rewriter: PatternRewriter) -> None:
        # based on https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/RISCV/i32-icmp.ll
        lhs, rhs = cast_operands_to_regs(rewriter, op)
        cast_op_results(rewriter, op)

        match op.predicate.value.data:
            # eq
            case 0:
                xor_op = riscv.XorOp(lhs, rhs)
                seqz_op = riscv.SltiuOp(xor_op, 1)
                rewriter.replace_op(op, [xor_op, seqz_op])
            # ne
            case 1:
                zero = rv32.GetRegisterOp(riscv.Registers.ZERO)
                xor_op = riscv.XorOp(lhs, rhs)
                snez_op = riscv.SltuOp(zero, xor_op)
                rewriter.replace_op(op, [zero, xor_op, snez_op])
            # slt
            case 2:
                rewriter.replace_op(op, [riscv.SltOp(lhs, rhs)])
            # sle
            case 3:
                slt = riscv.SltOp(lhs, rhs)
                xori = riscv.XoriOp(slt, 1)
                rewriter.replace_op(op, [slt, xori])
            # ult
            case 4:
                rewriter.replace_op(op, [riscv.SltuOp(lhs, rhs)])
            # ule
            case 5:
                sltu = riscv.SltuOp(lhs, rhs)
                xori = riscv.XoriOp(sltu, 1)
                rewriter.replace_op(op, [sltu, xori])
            # ugt
            case 6:
                rewriter.replace_op(op, [riscv.SltuOp(rhs, lhs)])
            # uge
            case 7:
                sltu = riscv.SltuOp(rhs, lhs)
                xori = riscv.XoriOp(sltu, 1)
                rewriter.replace_op(op, [sltu, xori])
            case _:
                raise NotImplementedError("Cmpi predicate not supported")

match_and_rewrite(op: arith.CmpiOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.CmpiOp, rewriter: PatternRewriter) -> None:
    # based on https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/RISCV/i32-icmp.ll
    lhs, rhs = cast_operands_to_regs(rewriter, op)
    cast_op_results(rewriter, op)

    match op.predicate.value.data:
        # eq
        case 0:
            xor_op = riscv.XorOp(lhs, rhs)
            seqz_op = riscv.SltiuOp(xor_op, 1)
            rewriter.replace_op(op, [xor_op, seqz_op])
        # ne
        case 1:
            zero = rv32.GetRegisterOp(riscv.Registers.ZERO)
            xor_op = riscv.XorOp(lhs, rhs)
            snez_op = riscv.SltuOp(zero, xor_op)
            rewriter.replace_op(op, [zero, xor_op, snez_op])
        # slt
        case 2:
            rewriter.replace_op(op, [riscv.SltOp(lhs, rhs)])
        # sle
        case 3:
            slt = riscv.SltOp(lhs, rhs)
            xori = riscv.XoriOp(slt, 1)
            rewriter.replace_op(op, [slt, xori])
        # ult
        case 4:
            rewriter.replace_op(op, [riscv.SltuOp(lhs, rhs)])
        # ule
        case 5:
            sltu = riscv.SltuOp(lhs, rhs)
            xori = riscv.XoriOp(sltu, 1)
            rewriter.replace_op(op, [sltu, xori])
        # ugt
        case 6:
            rewriter.replace_op(op, [riscv.SltuOp(rhs, lhs)])
        # uge
        case 7:
            sltu = riscv.SltuOp(rhs, lhs)
            xori = riscv.XoriOp(sltu, 1)
            rewriter.replace_op(op, [sltu, xori])
        case _:
            raise NotImplementedError("Cmpi predicate not supported")

LowerArithSelect

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
340
341
342
343
class LowerArithSelect(RewritePattern):
    """Reject `arith.select`: this lowering has no RISC-V translation for it."""

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: arith.SelectOp, rewriter: PatternRewriter
    ) -> None:
        # Matching alone aborts the rewrite; no ops are emitted.
        reason = "Select is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.SelectOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
341
342
343
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: arith.SelectOp, rewriter: PatternRewriter
) -> None:
    """Raise NotImplementedError for any matched `arith.select`."""
    reason = "Select is not supported"
    raise NotImplementedError(reason)

LowerArithNegf

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
362
363
364
365
366
367
368
369
370
371
372
373
374
class LowerArithNegf(RewritePattern):
    """
    Lower `arith.negf` via sign injection: `fsgnjn rd, x, x` yields x with its
    sign bit flipped.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.NegfOp, rewriter: PatternRewriter) -> None:
        # NOTE(review): the single-precision FSgnJNSOp is emitted regardless of
        # the operand's float type; confirm f64 operands are handled correctly.
        as_reg = UnrealizedConversionCastOp.get(
            (op.operand,), (_FLOAT_REGISTER_TYPE,)
        )
        negated = riscv.FSgnJNSOp(as_reg, as_reg)
        back = UnrealizedConversionCastOp.get((negated.rd,), (op.result.type,))
        rewriter.replace_op(op, (as_reg, negated, back))

match_and_rewrite(op: arith.NegfOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
363
364
365
366
367
368
369
370
371
372
373
374
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.NegfOp, rewriter: PatternRewriter) -> None:
    """
    Lower `arith.negf` via sign injection (`fsgnjn rd, x, x`).

    NOTE(review): the single-precision FSgnJNSOp is used for every operand
    type; confirm f64 operands are handled correctly.
    """
    as_reg = UnrealizedConversionCastOp.get(
        (op.operand,), (_FLOAT_REGISTER_TYPE,)
    )
    negated = riscv.FSgnJNSOp(as_reg, as_reg)
    back = UnrealizedConversionCastOp.get((negated.rd,), (op.result.type,))
    rewriter.replace_op(op, (as_reg, negated, back))

LowerArithCmpf

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
class LowerArithCmpf(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.CmpfOp, rewriter: PatternRewriter) -> None:
        # https://llvm.org/docs/LangRef.html#id309
        lhs, rhs = cast_operands_to_regs(rewriter, op)
        cast_op_results(rewriter, op)

        fastmath = riscv.FastMathFlagsAttr(op.fastmath.data)

        match op.predicate.value.data:
            # false
            case 0:
                rewriter.replace_op(op, [rv32.LiOp(0)])
            # oeq
            case 1:
                rewriter.replace_op(op, [riscv.FeqSOp(lhs, rhs, fastmath=fastmath)])
            # ogt
            case 2:
                rewriter.replace_op(op, [riscv.FltSOp(rhs, lhs, fastmath=fastmath)])
            # oge
            case 3:
                rewriter.replace_op(op, [riscv.FleSOp(rhs, lhs, fastmath=fastmath)])
            # olt
            case 4:
                rewriter.replace_op(op, [riscv.FltSOp(lhs, rhs, fastmath=fastmath)])
            # ole
            case 5:
                rewriter.replace_op(op, [riscv.FleSOp(lhs, rhs, fastmath=fastmath)])
            # one
            case 6:
                flt1 = riscv.FltSOp(lhs, rhs, fastmath=fastmath)
                flt2 = riscv.FltSOp(rhs, lhs, fastmath=fastmath)
                rewriter.replace_op(
                    op,
                    [
                        flt1,
                        flt2,
                        riscv.OrOp(flt2, flt1),
                    ],
                )
            # ord
            case 7:
                feq1 = riscv.FeqSOp(lhs, lhs, fastmath=fastmath)
                feq2 = riscv.FeqSOp(rhs, rhs, fastmath=fastmath)
                rewriter.replace_op(
                    op,
                    [
                        feq1,
                        feq2,
                        riscv.AndOp(feq2, feq1),
                    ],
                )
            # ueq
            case 8:
                flt1 = riscv.FltSOp(lhs, rhs, fastmath=fastmath)
                flt2 = riscv.FltSOp(rhs, lhs, fastmath=fastmath)
                or_ = riscv.OrOp(flt2, flt1)
                rewriter.replace_op(op, [flt1, flt2, or_, riscv.XoriOp(or_, 1)])
            # ugt
            case 9:
                fle = riscv.FleSOp(lhs, rhs, fastmath=fastmath)
                rewriter.replace_op(op, [fle, riscv.XoriOp(fle, 1)])
            # uge
            case 10:
                fle = riscv.FltSOp(lhs, rhs, fastmath=fastmath)
                rewriter.replace_op(op, [fle, riscv.XoriOp(fle, 1)])
            # ult
            case 11:
                fle = riscv.FleSOp(rhs, lhs, fastmath=fastmath)
                rewriter.replace_op(op, [fle, riscv.XoriOp(fle, 1)])
            # ule
            case 12:
                flt = riscv.FltSOp(rhs, lhs, fastmath=fastmath)
                rewriter.replace_op(op, [flt, riscv.XoriOp(flt, 1)])
            # une
            case 13:
                feq = riscv.FeqSOp(lhs, rhs, fastmath=fastmath)
                rewriter.replace_op(op, [feq, riscv.XoriOp(feq, 1)])
            # uno
            case 14:
                feq1 = riscv.FeqSOp(lhs, lhs, fastmath=fastmath)
                feq2 = riscv.FeqSOp(rhs, rhs, fastmath=fastmath)
                and_ = riscv.AndOp(feq2, feq1)
                rewriter.replace_op(op, [feq1, feq2, and_, riscv.XoriOp(and_, 1)])
            # true
            case 15:
                rewriter.replace_op(op, [rv32.LiOp(1)])
            case _:
                raise NotImplementedError("Cmpf predicate not supported")

match_and_rewrite(op: arith.CmpfOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.CmpfOp, rewriter: PatternRewriter) -> None:
    # https://llvm.org/docs/LangRef.html#id309
    lhs, rhs = cast_operands_to_regs(rewriter, op)
    cast_op_results(rewriter, op)

    fastmath = riscv.FastMathFlagsAttr(op.fastmath.data)

    match op.predicate.value.data:
        # false
        case 0:
            rewriter.replace_op(op, [rv32.LiOp(0)])
        # oeq
        case 1:
            rewriter.replace_op(op, [riscv.FeqSOp(lhs, rhs, fastmath=fastmath)])
        # ogt
        case 2:
            rewriter.replace_op(op, [riscv.FltSOp(rhs, lhs, fastmath=fastmath)])
        # oge
        case 3:
            rewriter.replace_op(op, [riscv.FleSOp(rhs, lhs, fastmath=fastmath)])
        # olt
        case 4:
            rewriter.replace_op(op, [riscv.FltSOp(lhs, rhs, fastmath=fastmath)])
        # ole
        case 5:
            rewriter.replace_op(op, [riscv.FleSOp(lhs, rhs, fastmath=fastmath)])
        # one
        case 6:
            flt1 = riscv.FltSOp(lhs, rhs, fastmath=fastmath)
            flt2 = riscv.FltSOp(rhs, lhs, fastmath=fastmath)
            rewriter.replace_op(
                op,
                [
                    flt1,
                    flt2,
                    riscv.OrOp(flt2, flt1),
                ],
            )
        # ord
        case 7:
            feq1 = riscv.FeqSOp(lhs, lhs, fastmath=fastmath)
            feq2 = riscv.FeqSOp(rhs, rhs, fastmath=fastmath)
            rewriter.replace_op(
                op,
                [
                    feq1,
                    feq2,
                    riscv.AndOp(feq2, feq1),
                ],
            )
        # ueq
        case 8:
            flt1 = riscv.FltSOp(lhs, rhs, fastmath=fastmath)
            flt2 = riscv.FltSOp(rhs, lhs, fastmath=fastmath)
            or_ = riscv.OrOp(flt2, flt1)
            rewriter.replace_op(op, [flt1, flt2, or_, riscv.XoriOp(or_, 1)])
        # ugt
        case 9:
            fle = riscv.FleSOp(lhs, rhs, fastmath=fastmath)
            rewriter.replace_op(op, [fle, riscv.XoriOp(fle, 1)])
        # uge
        case 10:
            fle = riscv.FltSOp(lhs, rhs, fastmath=fastmath)
            rewriter.replace_op(op, [fle, riscv.XoriOp(fle, 1)])
        # ult
        case 11:
            fle = riscv.FleSOp(rhs, lhs, fastmath=fastmath)
            rewriter.replace_op(op, [fle, riscv.XoriOp(fle, 1)])
        # ule
        case 12:
            flt = riscv.FltSOp(rhs, lhs, fastmath=fastmath)
            rewriter.replace_op(op, [flt, riscv.XoriOp(flt, 1)])
        # une
        case 13:
            feq = riscv.FeqSOp(lhs, rhs, fastmath=fastmath)
            rewriter.replace_op(op, [feq, riscv.XoriOp(feq, 1)])
        # uno
        case 14:
            feq1 = riscv.FeqSOp(lhs, lhs, fastmath=fastmath)
            feq2 = riscv.FeqSOp(rhs, rhs, fastmath=fastmath)
            and_ = riscv.AndOp(feq2, feq1)
            rewriter.replace_op(op, [feq1, feq2, and_, riscv.XoriOp(and_, 1)])
        # true
        case 15:
            rewriter.replace_op(op, [rv32.LiOp(1)])
        case _:
            raise NotImplementedError("Cmpf predicate not supported")

LowerArithSIToFPOp

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
class LowerArithSIToFPOp(RewritePattern):
    """Lower ``arith.sitofp`` to ``fcvt.s.w`` (f32) or ``fcvt.d.w`` (f64).

    The integer input is cast into an integer register. The conversion picked
    from the result float type is applied, and its result is cast back to the
    arith result type.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.SIToFPOp, rewriter: PatternRewriter) -> None:
        target = op.result.type
        if isinstance(target, Float32Type):
            convert = riscv.FCvtSWOp
        elif isinstance(target, Float64Type):
            convert = riscv.FCvtDWOp
        else:
            raise ValueError(f"Unexpected float type {op.result.type}")

        int_input = UnrealizedConversionCastOp.get(
            (op.input,), (_INT_REGISTER_TYPE,)
        )
        converted = convert(int_input.results[0])
        float_result = UnrealizedConversionCastOp.get(
            (converted.rd,), (op.result.type,)
        )
        rewriter.replace_op(op, (int_input, converted, float_result))

match_and_rewrite(op: arith.SIToFPOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.SIToFPOp, rewriter: PatternRewriter) -> None:
    """Replace ``arith.sitofp`` with cast -> ``fcvt.{s,d}.w`` -> cast back.

    Raises:
        ValueError: if the result type is neither f32 nor f64.
    """
    target = op.result.type
    if isinstance(target, Float32Type):
        convert = riscv.FCvtSWOp
    elif isinstance(target, Float64Type):
        convert = riscv.FCvtDWOp
    else:
        raise ValueError(f"Unexpected float type {op.result.type}")

    int_input = UnrealizedConversionCastOp.get((op.input,), (_INT_REGISTER_TYPE,))
    converted = convert(int_input.results[0])
    float_result = UnrealizedConversionCastOp.get(
        (converted.rd,), (op.result.type,)
    )
    rewriter.replace_op(op, (int_input, converted, float_result))

LowerArithFPToSIOp

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
491
492
493
494
495
496
497
498
499
500
501
502
503
class LowerArithFPToSIOp(RewritePattern):
    """Lower ``arith.fptosi`` to a RISC-V ``fcvt.w.s`` conversion.

    The float input is cast into a float register, converted with
    ``fcvt.w.s``, and the integer result is cast back to the arith result type.

    NOTE(review): ``FCvtWSOp`` (single-precision source) is used regardless of
    the input float width — confirm f64 inputs are not expected here.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.FPToSIOp, rewriter: PatternRewriter) -> None:
        float_input = UnrealizedConversionCastOp.get(
            (op.input,), (_FLOAT_REGISTER_TYPE,)
        )
        to_int = riscv.FCvtWSOp(float_input.results[0])
        int_result = UnrealizedConversionCastOp.get(
            (to_int.rd,), (op.result.type,)
        )
        rewriter.replace_op(op, (float_input, to_int, int_result))

match_and_rewrite(op: arith.FPToSIOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
492
493
494
495
496
497
498
499
500
501
502
503
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.FPToSIOp, rewriter: PatternRewriter) -> None:
    """Replace ``arith.fptosi`` with cast -> ``fcvt.w.s`` -> cast back.

    NOTE(review): always uses the single-precision-source ``FCvtWSOp`` — confirm
    f64 inputs are not expected.
    """
    float_input = UnrealizedConversionCastOp.get(
        (op.input,), (_FLOAT_REGISTER_TYPE,)
    )
    to_int = riscv.FCvtWSOp(float_input.results[0])
    int_result = UnrealizedConversionCastOp.get((to_int.rd,), (op.result.type,))
    rewriter.replace_op(op, (float_input, to_int, int_result))

LowerArithExtFOp

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
506
507
508
509
class LowerArithExtFOp(RewritePattern):
    """Pattern for ``arith.extf`` that has no RISC-V lowering yet.

    It raises ``NotImplementedError`` for every ``arith.extf`` it matches.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.ExtFOp, rewriter: PatternRewriter) -> None:
        reason = "ExtF is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.ExtFOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
507
508
509
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.ExtFOp, rewriter: PatternRewriter) -> None:
    """Reject ``arith.extf``; no RISC-V lowering exists for it.

    Raises:
        NotImplementedError: for every matched ``arith.extf``.
    """
    reason = "ExtF is not supported"
    raise NotImplementedError(reason)

LowerArithTruncFOp

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
512
513
514
515
class LowerArithTruncFOp(RewritePattern):
    """Pattern for ``arith.truncf`` that has no RISC-V lowering yet.

    It raises ``NotImplementedError`` for every ``arith.truncf`` it matches.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.TruncFOp, rewriter: PatternRewriter) -> None:
        reason = "TruncF is not supported"
        raise NotImplementedError(reason)

match_and_rewrite(op: arith.TruncFOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
513
514
515
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.TruncFOp, rewriter: PatternRewriter) -> None:
    """Reject ``arith.truncf``; no RISC-V lowering exists for it.

    Raises:
        NotImplementedError: for every matched ``arith.truncf``.
    """
    reason = "TruncF is not supported"
    raise NotImplementedError(reason)

ConvertArithToRiscvPass dataclass

Bases: ModulePass

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
class ConvertArithToRiscvPass(ModulePass):
    """Module pass that lowers ``arith`` dialect ops to the ``riscv`` dialect.

    All lowering patterns of this module are combined in a greedy applier, in
    the order listed in ``apply``. The module is rewritten with a single
    ``PatternRewriteWalker`` configured with ``apply_recursively=False``.
    """

    name = "convert-arith-to-riscv"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        # Registration order is kept stable; the greedy applier tries the
        # patterns in this order.
        patterns: list[RewritePattern] = [
            LowerArithConstant(),
            LowerArithIndexCast(),
            LowerArithSIToFPOp(),
            LowerArithFPToSIOp(),
            lower_arith_addi,
            lower_arith_subi,
            lower_arith_muli,
            lower_arith_divui,
            lower_arith_divsi,
            LowerArithFloorDivSI(),
            lower_arith_remsi,
            LowerArithCmpi(),
            lower_arith_addf,
            lower_arith_subf,
            lower_arith_divf,
            LowerArithNegf(),
            lower_arith_mulf,
            LowerArithCmpf(),
            lower_arith_remui,
            lower_arith_andi,
            lower_arith_ori,
            lower_arith_xori,
            lower_arith_shli,
            lower_arith_shrui,
            lower_arith_shrsi,
            LowerArithCeilDivSI(),
            LowerArithCeilDivUI(),
            LowerArithMinSI(),
            LowerArithMaxSI(),
            LowerArithMinUI(),
            LowerArithMaxUI(),
            LowerArithSelect(),
            LowerArithExtFOp(),
            LowerArithTruncFOp(),
            lower_arith_minf,
            lower_arith_maxf,
        ]
        PatternRewriteWalker(
            GreedyRewritePatternApplier(patterns),
            apply_recursively=False,
        ).rewrite_module(op)

name = 'convert-arith-to-riscv' class-attribute instance-attribute

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Rewrite ``op`` in place, lowering ``arith`` ops to ``riscv`` ops.

    The patterns are tried greedily in the order listed below, by a single
    ``PatternRewriteWalker`` configured with ``apply_recursively=False``.
    ``ctx`` is not used.
    """
    # Registration order is kept stable; the greedy applier tries the
    # patterns in this order.
    patterns: list[RewritePattern] = [
        LowerArithConstant(),
        LowerArithIndexCast(),
        LowerArithSIToFPOp(),
        LowerArithFPToSIOp(),
        lower_arith_addi,
        lower_arith_subi,
        lower_arith_muli,
        lower_arith_divui,
        lower_arith_divsi,
        LowerArithFloorDivSI(),
        lower_arith_remsi,
        LowerArithCmpi(),
        lower_arith_addf,
        lower_arith_subf,
        lower_arith_divf,
        LowerArithNegf(),
        lower_arith_mulf,
        LowerArithCmpf(),
        lower_arith_remui,
        lower_arith_andi,
        lower_arith_ori,
        lower_arith_xori,
        lower_arith_shli,
        lower_arith_shrui,
        lower_arith_shrsi,
        LowerArithCeilDivSI(),
        LowerArithCeilDivUI(),
        LowerArithMinSI(),
        LowerArithMaxSI(),
        LowerArithMinUI(),
        LowerArithMaxUI(),
        LowerArithSelect(),
        LowerArithExtFOp(),
        LowerArithTruncFOp(),
        lower_arith_minf,
        lower_arith_maxf,
    ]
    PatternRewriteWalker(
        GreedyRewritePatternApplier(patterns),
        apply_recursively=False,
    ).rewrite_module(op)