Skip to content

Utils

utils

register_type_for_type(attr: Attribute) -> type[riscv.IntRegisterType] | type[riscv.FloatRegisterType]

Returns the appropriate register type for a given input type.

Source code in xdsl/backend/riscv/lowering/utils.py
15
16
17
18
19
20
21
22
23
24
25
def register_type_for_type(
    attr: Attribute,
) -> type[riscv.IntRegisterType] | type[riscv.FloatRegisterType]:
    """
    Returns the register type class that should hold a value of type `attr`.

    Register types map to their own class, floating-point types map to
    `riscv.FloatRegisterType`, and every other type falls back to
    `riscv.IntRegisterType`.
    """
    is_register = isinstance(attr, riscv.IntRegisterType | riscv.FloatRegisterType)
    if is_register:
        return type(attr)
    is_float = isinstance(attr, builtin.AnyFloat)
    return riscv.FloatRegisterType if is_float else riscv.IntRegisterType

move_ops_for_value(value: SSAValue, value_type: Attribute, rd: riscv.RISCVRegisterType) -> tuple[Operation, SSAValue]

Returns the operation that moves the value from the input to a new register, together with the moved value. To disambiguate which floating-point move should be used (fmv.s vs fmv.d), the floating-point type in question must be passed as value_type.

Source code in xdsl/backend/riscv/lowering/utils.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def move_ops_for_value(
    value: SSAValue, value_type: Attribute, rd: riscv.RISCVRegisterType
) -> tuple[Operation, SSAValue]:
    """
    Returns the operation that moves the value from the input to a new register.
    In order to disambiguate which floating point move should be used (fmv.s vs fmv.d),
    the floating point type in question must be passed
    """

    if isinstance(rd, riscv.IntRegisterType):
        mv_op = riscv.MVOp(value, rd=rd)
        return mv_op, mv_op.rd
    elif isinstance(rd, riscv.FloatRegisterType):
        match value_type:
            case builtin.Float64Type():
                mv_op = riscv.FMvDOp(value, rd=rd)
            case builtin.Float32Type():
                mv_op = riscv.FMVOp(value, rd=rd)
            case _:
                raise NotImplementedError(
                    f"Move operation for float register containing value of type {value.type} is not implemented"
                )
        return mv_op, mv_op.rd
    else:
        raise NotImplementedError(f"Unsupported register type for move op: {rd}")

move_to_regs(values: Iterable[SSAValue], value_types: Iterable[Attribute], reg_types: Iterable[riscv.RISCVRegisterType]) -> tuple[list[Operation], list[SSAValue]]

Return move operations to the registers given by reg_types (for example a0, a1, ... | fa0, fa1, ...).

Source code in xdsl/backend/riscv/lowering/utils.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def move_to_regs(
    values: Iterable[SSAValue],
    value_types: Iterable[Attribute],
    reg_types: Iterable[riscv.RISCVRegisterType],
) -> tuple[list[Operation], list[SSAValue]]:
    """
    Build a single `riscv.ParallelMovOp` that moves each of `values` into the
    corresponding register of `reg_types` (e.g. a0, a1, ... | fa0, fa1, ...).

    `value_types` supplies the per-move bit width: a float type contributes its
    bitwidth, and any other type is recorded as 32.

    Returns the list of created operations and the list of moved values.
    """

    def _move_width(attr: Attribute) -> int:
        # Only float bitwidths matter for now; non-floats default to 32.
        return attr.bitwidth if isinstance(attr, builtin.AnyFloat) else 32

    widths = tuple(map(_move_width, value_types))
    width_attr = builtin.DenseArrayBase.from_list(builtin.i32, widths)
    parallel_mov = riscv.ParallelMovOp(tuple(values), tuple(reg_types), width_attr)
    return [parallel_mov], list(parallel_mov.results)

a_regs_for_types(types: Iterable[Attribute]) -> Iterator[riscv.RISCVRegisterType]

Returns the "a" registers in which to store types, i.e. fa0, fa1, etc for floating-point values and a0, a1, etc for integer values and pointers. The register index is separate for integer and floating-point registers according to the RISC-V ABI.

Source code in xdsl/backend/riscv/lowering/utils.py
78
79
80
81
82
83
84
85
86
87
88
89
90
def a_regs_for_types(types: Iterable[Attribute]) -> Iterator[riscv.RISCVRegisterType]:
    """
    Lazily yield one "a" register per entry of `types`: `fa0`, `fa1`, ... for
    floating-point values and `a0`, `a1`, ... for integer values and pointers.
    Integer and floating-point registers are numbered independently, as the
    RISC-V ABI requires.
    """
    next_index = Counter[type[riscv.RISCVRegisterType]]()
    for reg_cls in map(register_type_for_type, types):
        yield reg_cls.a_register(next_index[reg_cls])
        next_index[reg_cls] += 1

a_regs(values: Iterable[SSAValue]) -> Iterator[riscv.RISCVRegisterType]

Source code in xdsl/backend/riscv/lowering/utils.py
93
94
def a_regs(values: Iterable[SSAValue]) -> Iterator[riscv.RISCVRegisterType]:
    """
    Yield the "a" registers for `values`, chosen from each value's current type
    via `a_regs_for_types`.
    """
    value_types = (v.type for v in values)
    return a_regs_for_types(value_types)

move_to_a_regs(values: Iterable[SSAValue], value_types: Iterable[Attribute]) -> tuple[list[Operation], list[SSAValue]]

Return move operations to "a" registers (a0, a1, ... | fa0, fa1, ...).

Source code in xdsl/backend/riscv/lowering/utils.py
 97
 98
 99
100
101
102
103
104
def move_to_a_regs(
    values: Iterable[SSAValue],
    value_types: Iterable[Attribute],
) -> tuple[list[Operation], list[SSAValue]]:
    """
    Return move operations to `a` registers (a0, a1, ... | fa0, fa1, ...).

    The destination registers are picked from the current types of `values`
    (via `a_regs`). `value_types` only determines the move widths (see
    `move_to_regs`).

    Returns the list of created operations and the list of moved values.
    """
    # `values` is traversed twice: once for the moved operands and once to pick
    # the destination registers. Materialize it so one-shot iterators (e.g.
    # generators) do not yield an empty register list on the second pass.
    values = tuple(values)
    return move_to_regs(values, value_types, a_regs(values))

move_to_unallocated_regs(values: Iterable[SSAValue], value_types: Iterable[Attribute]) -> tuple[list[Operation], list[SSAValue]]

Return move operations to unallocated registers.

Source code in xdsl/backend/riscv/lowering/utils.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def move_to_unallocated_regs(
    values: Iterable[SSAValue],
    value_types: Iterable[Attribute],
) -> tuple[list[Operation], list[SSAValue]]:
    """
    Return move operations to unallocated registers.

    Each value is moved into an unallocated register of the kind chosen by
    `register_type_for_type` for its current type. `value_types` only determines
    the move widths (see `move_to_regs`).

    Returns the list of created operations and the list of moved values.
    """
    # `values` is traversed twice (destination registers and moved operands), so
    # materialize it to support one-shot iterators such as generators.
    values = tuple(values)
    outputs = [register_type_for_type(value.type).unallocated() for value in values]
    # Reuse the shared width computation and ParallelMovOp construction.
    return move_to_regs(values, value_types, outputs)

cast_operands_to_regs(rewriter: PatternRewriter, operation: Operation | None = None) -> list[SSAValue]

Add cast operations just before the targeted operation if the operands were not already registers.

Source code in xdsl/backend/riscv/lowering/utils.py
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def cast_operands_to_regs(
    rewriter: PatternRewriter, operation: Operation | None = None
) -> list[SSAValue]:
    """
    Cast the operands of `operation` to registers by calling `cast_to_regs`, which
    inserts the casts just before the operation. The register kind for each operand
    is chosen by `register_type_for_type`.

    Args:
        rewriter: the rewriter passed to `cast_to_regs` to insert the casts.
        operation: the operation whose operands are cast. Omitting it is deprecated;
            `rewriter.current_operation` is used instead.

    Returns:
        The values returned by `cast_to_regs` for the operands.
    """
    if operation is None:
        warnings.warn(
            "Please use `cast_operands_to_regs(rewriter, rewriter.current_operation)`",
            DeprecationWarning,
            # Attribute the warning to the caller that omitted `operation`,
            # not to this helper.
            stacklevel=2,
        )
        operation = rewriter.current_operation
    return cast_to_regs(operation.operands, register_type_for_type, rewriter)

cast_matched_op_results(rewriter: PatternRewriter) -> list[SSAValue]

Add cast operations just after the matched operation, to preserve the type validity of arguments of uses of results.

Source code in xdsl/backend/riscv/lowering/utils.py
148
149
150
151
152
153
154
@deprecated("Please use `cast_op_results(rewriter, rewriter.current_operation)`")
def cast_matched_op_results(rewriter: PatternRewriter) -> list[SSAValue]:
    """
    Deprecated wrapper around `cast_op_results` that targets the rewriter's current
    operation. It inserts casts right after that operation so that users of its
    results keep seeing values of the original types.
    """
    matched_op = rewriter.current_operation
    return cast_op_results(rewriter, matched_op)

cast_op_results(builder: Builder, op: Operation) -> list[SSAValue]

Add cast operations just after the provided operation, to preserve the type validity of arguments of uses of results.

Source code in xdsl/backend/riscv/lowering/utils.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def cast_op_results(builder: Builder, op: Operation) -> list[SSAValue]:
    """
    Insert one `UnrealizedConversionCastOp` per result of `op`, right after `op`, and
    reroute every existing use of each result through its cast. This keeps the
    operand types seen by the users valid. Returns the cast results in the same order
    as `op.results`.
    """
    casts: list[builtin.UnrealizedConversionCastOp] = []
    for produced in op.results:
        casts.append(builtin.UnrealizedConversionCastOp.get((produced,), (produced.type,)))

    for produced, cast in zip(op.results, casts):
        # Iterate over a copy of the uses, because reassigning operands below may
        # change the use list. Skip the cast's own use so it still reads `produced`.
        for use in set(produced.uses):
            if use.operation != cast:
                use.operation.operands[use.index] = cast.results[0]

    builder.insert_op(casts, InsertPoint.after(op))
    return [cast.results[0] for cast in casts]

cast_block_args_from_a_regs(block: Block, rewriter: PatternRewriter)

Change the type of the block arguments to "a" registers and add cast operations just after the block entry. Use fa0, fa1, etc for floating-point values and a0, a1, etc for integer values and pointers. The register index is separate for integer and floating-point registers according to the RISC-V ABI.

Source code in xdsl/backend/riscv/lowering/utils.py
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def cast_block_args_from_a_regs(block: Block, rewriter: PatternRewriter):
    """
    Retype each block argument to an "a" register and, at the start of the block,
    move it into an unallocated register and cast it back to its original type.
    Floating-point arguments use `fa0`, `fa1`, ...; integer values and pointers use
    `a0`, `a1`, .... Integer and floating-point registers are numbered independently,
    as the RISC-V ABI requires.
    """

    per_kind_index = Counter[type[riscv.RISCVRegisterType]]()
    prologue: list[Operation] = []

    for block_arg in block.args:
        original_type = block_arg.type
        reg_cls = register_type_for_type(original_type)

        mv, moved = move_ops_for_value(block_arg, original_type, reg_cls.unallocated())
        restore = builtin.UnrealizedConversionCastOp.get((moved,), (original_type,))
        prologue.extend((mv, restore))

        abi_reg = reg_cls.a_register(per_kind_index[reg_cls])
        per_kind_index[reg_cls] += 1

        # Every user except the move itself now reads the cast-back value.
        rewriter.replace_uses_with_if(
            block_arg, restore.results[0], lambda use: use.operation != mv
        )
        rewriter.replace_value_with_new_type(block_arg, abi_reg)

    rewriter.insert_op(prologue, InsertPoint.at_start(block))

cast_block_args_to_regs(block: Block, rewriter: PatternRewriter)

Change the type of the block arguments to registers and add cast operations just after the block entry.

Source code in xdsl/backend/riscv/lowering/utils.py
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def cast_block_args_to_regs(block: Block, rewriter: PatternRewriter):
    """
    Retype every block argument to an unallocated register. For each argument, insert
    a cast at the start of the block that converts it back to its original type, and
    redirect all other users of the argument to that cast.
    """

    for block_arg in block.args:
        original_type = block_arg.type
        cast = builtin.UnrealizedConversionCastOp(
            operands=[block_arg], result_types=[original_type]
        )
        rewriter.insert_op(cast, InsertPoint.at_start(block))

        unallocated = register_type_for_type(original_type).unallocated()
        # Keep the cast itself reading the raw argument; everyone else reads the cast.
        rewriter.replace_uses_with_if(
            block_arg, cast.results[0], lambda use: use.operation != cast
        )
        rewriter.replace_value_with_new_type(block_arg, unallocated)