Skip to content

Convert func to riscv func

convert_func_to_riscv_func

LowerFuncOp

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py (lines 23–55)
class LowerFuncOp(RewritePattern):
    """Rewrite a `func.func` into an equivalent `riscv_func.func`.

    The function may take at most 8 inputs and return at most 2 outputs;
    anything larger raises `ValueError` before the IR is touched.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
        signature = op.function_type
        if len(signature.inputs.data) > 8:
            raise ValueError("Cannot lower func.func with more than 8 inputs")
        if len(signature.outputs.data) > 2:
            raise ValueError("Cannot lower func.func with more than 2 outputs")

        entry_block = op.body.blocks.first
        if entry_block is None:
            # Body region has no blocks: derive the argument register types
            # straight from the declared signature.
            arg_regs = tuple(a_regs_for_types(signature.inputs.data))
        else:
            # The helper (defined elsewhere) presumably retypes the entry block
            # arguments to a-registers; the resulting block types become the
            # new function's input types.
            cast_block_args_from_a_regs(entry_block, rewriter)
            arg_regs = entry_block.arg_types
        ret_regs = list(a_regs_for_types(signature.outputs.data))

        visibility = op.sym_visibility
        if visibility is None:
            # C-like: default is public
            visibility = StringAttr("public")

        # TODO we should ask the target for alignment, this works for rv32
        lowered = riscv_func.FuncOp(
            op.sym_name.data,
            rewriter.move_region_contents_to_new_regions(op.body),
            (arg_regs, ret_regs),
            visibility,
            p2align=2,
        )
        rewriter.replace_op(op, lowered)

match_and_rewrite(op: func.FuncOp, rewriter: PatternRewriter)

Source code in xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py (lines 24–55)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
    """Replace `op` (a `func.func`) with an equivalent `riscv_func.func`.

    Raises:
        ValueError: if the signature has more than 8 inputs or more than
            2 outputs.
    """
    signature = op.function_type
    if len(signature.inputs.data) > 8:
        raise ValueError("Cannot lower func.func with more than 8 inputs")
    if len(signature.outputs.data) > 2:
        raise ValueError("Cannot lower func.func with more than 2 outputs")

    entry_block = op.body.blocks.first
    if entry_block is None:
        # Body region has no blocks: take register types from the signature.
        arg_regs = tuple(a_regs_for_types(signature.inputs.data))
    else:
        # Presumably retypes the entry block arguments to a-registers
        # (helper defined elsewhere); reuse the resulting block types.
        cast_block_args_from_a_regs(entry_block, rewriter)
        arg_regs = entry_block.arg_types
    ret_regs = list(a_regs_for_types(signature.outputs.data))

    visibility = op.sym_visibility
    if visibility is None:
        # C-like: default is public
        visibility = StringAttr("public")

    # TODO we should ask the target for alignment, this works for rv32
    lowered = riscv_func.FuncOp(
        op.sym_name.data,
        rewriter.move_region_contents_to_new_regions(op.body),
        (arg_regs, ret_regs),
        visibility,
        p2align=2,
    )
    rewriter.replace_op(op, lowered)

LowerFuncCallOp

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py (lines 58–98)
class LowerFuncCallOp(RewritePattern):
    """Rewrite a `func.call` into a `riscv_func.call`.

    Operands are cast to register types and moved into a-registers before
    the call. Results are moved out of the call's a-registers and cast back
    to the original result types, so existing uses keep their types.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter) -> None:
        if len(op.arguments) > 8:
            raise ValueError("Cannot lower func.call with more than 8 operands")
        if len(op.res) > 2:
            raise ValueError("Cannot lower func.call with more than 2 results")

        if len(op.results) == 1:
            # Carry the single result's name hint over to the rewriter.
            rewriter.name_hint = op.results[0].name_hint

        register_operands = cast_to_regs(op.arguments, register_type_for_type, rewriter)
        operand_types = op.arguments.types
        move_operand_ops, moved_operands = move_to_a_regs(
            register_operands, operand_types
        )

        new_result_types = list(a_regs(op.results))
        new_op = riscv_func.CallOp(op.callee, moved_operands, new_result_types)

        move_result_ops, moved_results = move_to_unallocated_regs(
            new_op.results, op.result_types
        )
        cast_result_ops = [
            UnrealizedConversionCastOp.get((moved_result,), (old_result.type,))
            for moved_result, old_result in zip(moved_results, op.results)
        ]
        # Distinct names here on purpose: reusing `op` as a loop variable
        # shadowed the matched call op and made the result list misleading.
        rewriter.replace_op(
            op,
            [*move_operand_ops, new_op, *move_result_ops, *cast_result_ops],
            # Each cast is built with a single output type, so it has exactly
            # one result; that result replaces the matching call result.
            [cast_op.results[-1] for cast_op in cast_result_ops],
        )

match_and_rewrite(op: func.CallOp, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py (lines 59–98)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter) -> None:
    """Replace `op` (a `func.call`) with a `riscv_func.call` plus register moves.

    Raises:
        ValueError: if the call has more than 8 operands or more than
            2 results.
    """
    if len(op.arguments) > 8:
        raise ValueError("Cannot lower func.call with more than 8 operands")
    if len(op.res) > 2:
        raise ValueError("Cannot lower func.call with more than 2 results")

    if len(op.results) == 1:
        # Carry the single result's name hint over to the rewriter.
        rewriter.name_hint = op.results[0].name_hint

    register_operands = cast_to_regs(op.arguments, register_type_for_type, rewriter)
    operand_types = op.arguments.types
    move_operand_ops, moved_operands = move_to_a_regs(
        register_operands, operand_types
    )

    new_result_types = list(a_regs(op.results))
    new_op = riscv_func.CallOp(op.callee, moved_operands, new_result_types)

    move_result_ops, moved_results = move_to_unallocated_regs(
        new_op.results, op.result_types
    )
    cast_result_ops = [
        UnrealizedConversionCastOp.get((moved_result,), (old_result.type,))
        for moved_result, old_result in zip(moved_results, op.results)
    ]
    # Distinct names here on purpose: reusing `op` as a loop variable
    # shadowed the matched call op and made the result list misleading.
    rewriter.replace_op(
        op,
        [*move_operand_ops, new_op, *move_result_ops, *cast_result_ops],
        # Each cast is built with a single output type, so it has exactly
        # one result; that result replaces the matching call result.
        [cast_op.results[-1] for cast_op in cast_result_ops],
    )

LowerReturnOp

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py (lines 101–112)
class LowerReturnOp(RewritePattern):
    """Rewrite a `func.return` into a `riscv_func.return`.

    At most 2 values may be returned. They are cast to register types and
    moved into a-registers before the new return op.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter):
        if len(op.arguments) > 2:
            raise ValueError("Cannot lower func.return with more than 2 arguments")

        as_registers = cast_to_regs(op.arguments, register_type_for_type, rewriter)
        moves, in_a_regs = move_to_a_regs(as_registers, op.arguments.types)

        # The moves are inserted first, then the return itself is swapped out.
        rewriter.insert_op(moves)
        lowered_return = riscv_func.ReturnOp(*in_a_regs)
        rewriter.replace_op(op, lowered_return)

match_and_rewrite(op: func.ReturnOp, rewriter: PatternRewriter)

Source code in xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py (lines 102–112)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter):
    """Replace `op` (a `func.return`) with a `riscv_func.return`.

    Raises:
        ValueError: if more than 2 values are returned.
    """
    if len(op.arguments) > 2:
        raise ValueError("Cannot lower func.return with more than 2 arguments")

    as_registers = cast_to_regs(op.arguments, register_type_for_type, rewriter)
    moves, in_a_regs = move_to_a_regs(as_registers, op.arguments.types)

    # The moves are inserted first, then the return itself is swapped out.
    rewriter.insert_op(moves)
    lowered_return = riscv_func.ReturnOp(*in_a_regs)
    rewriter.replace_op(op, lowered_return)

ConvertFuncToRiscvFuncPass dataclass

Bases: ModulePass

Source code in xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py (lines 115–128)
class ConvertFuncToRiscvFuncPass(ModulePass):
    """Lower `func.func`, `func.call` and `func.return` to the `riscv_func` dialect."""

    name = "convert-func-to-riscv-func"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        lowering_patterns = GreedyRewritePatternApplier(
            [LowerFuncOp(), LowerFuncCallOp(), LowerReturnOp()]
        )
        # Single walk over the module; rewritten ops are not revisited.
        walker = PatternRewriteWalker(lowering_patterns, apply_recursively=False)
        walker.rewrite_module(op)

name = 'convert-func-to-riscv-func' class-attribute instance-attribute

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py (lines 118–128)
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Run the func -> riscv_func lowering patterns over module `op`."""
    lowering_patterns = GreedyRewritePatternApplier(
        [LowerFuncOp(), LowerFuncCallOp(), LowerReturnOp()]
    )
    # Single walk over the module; rewritten ops are not revisited.
    walker = PatternRewriteWalker(lowering_patterns, apply_recursively=False)
    walker.rewrite_module(op)