Skip to content

Lower riscv func

lower_riscv_func

LowerSyscallOp

Bases: RewritePattern

Lower the SSA version of a syscall, reading the optional result back from a0.

Different platforms have different calling conventions. This lowering assumes that the inputs are stored in a0-a6 and the syscall number is stored in a7. Upon return, a0 contains the result value. Some kernels use a different convention.

In the future, this pass should take the compilation target as a parameter to guide the rewrites.

Issue tracking this: https://github.com/xdslproject/xdsl/issues/952

Source code in xdsl/transforms/lower_riscv_func.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class LowerSyscallOp(RewritePattern):
    """
    Lower the SSA form of `riscv_func` syscalls to explicit register moves followed
    by an `ecall`, reading the optional result back from a0.

    Calling conventions differ between platforms. This lowering assumes that the
    inputs are passed in a0-a6, that the syscall number is passed in a7, and that
    a0 holds the result on return. Some kernels do not follow this convention.

    In the future, this pass should take the compilation target as a parameter to
    guide the rewrites.

    Issue tracking this: https://github.com/xdslproject/xdsl/issues/952
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: riscv_func.SyscallOp, rewriter: PatternRewriter):
        # Argument i is copied into argument register a<i>.
        lowered: list[Operation] = [
            riscv.MVOp(value, rd=riscv.IntRegisterType.a_register(index))
            for index, value in enumerate(op.args)
        ]
        # The syscall number goes in a7, then the call itself is issued.
        lowered.append(riscv.LiOp(immediate=op.syscall_num, rd=riscv.Registers.A7))
        lowered.append(riscv.EcallOp())

        if op.result is None:
            rewriter.replace_op(op, lowered, new_results=())
            return

        # The result is read from a0 and moved into a register of the original
        # result's type, which replaces the syscall's SSA result.
        read_a0 = riscv.GetRegisterOp(riscv.Registers.A0)
        result_move = riscv.MVOp(read_a0.res, rd=op.result.type)
        lowered += [read_a0, result_move]
        rewriter.replace_op(op, lowered, new_results=result_move.results)

match_and_rewrite(op: riscv_func.SyscallOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_riscv_func.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv_func.SyscallOp, rewriter: PatternRewriter):
    """
    Replace a syscall op with argument moves into a0.., a `li` of the syscall
    number into a7, an `ecall`, and, if the op has a result, a read of a0 moved
    into a register of the result's type.
    """
    # Argument i is copied into argument register a<i>.
    lowered: list[Operation] = [
        riscv.MVOp(value, rd=riscv.IntRegisterType.a_register(index))
        for index, value in enumerate(op.args)
    ]
    # The syscall number goes in a7, then the call itself is issued.
    lowered.append(riscv.LiOp(immediate=op.syscall_num, rd=riscv.Registers.A7))
    lowered.append(riscv.EcallOp())

    if op.result is None:
        rewriter.replace_op(op, lowered, new_results=())
        return

    # The result is read from a0 and moved into a register of the original
    # result's type, which replaces the syscall's SSA result.
    read_a0 = riscv.GetRegisterOp(riscv.Registers.A0)
    result_move = riscv.MVOp(read_a0.res, rd=op.result.type)
    lowered += [read_a0, result_move]
    rewriter.replace_op(op, lowered, new_results=result_move.results)

InsertExitSyscallOp

Bases: RewritePattern

Source code in xdsl/transforms/lower_riscv_func.py
63
64
65
66
67
68
69
70
71
72
73
74
class InsertExitSyscallOp(RewritePattern):
    """
    For a `riscv_func.return` whose parent op is the `riscv_func.func` named
    "main", insert a syscall with number 93 (exit) through the rewriter.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: riscv_func.ReturnOp, rewriter: PatternRewriter):
        enclosing = op.parent_op()
        # sym_name is only inspected once we know the parent is a FuncOp.
        in_main = (
            isinstance(enclosing, riscv_func.FuncOp)
            and enclosing.sym_name.data == "main"
        )
        if in_main:
            exit_syscall_num = 93
            # NOTE(review): uses the rewriter's default insertion point,
            # presumably just before the matched return.
            rewriter.insert_op(riscv_func.SyscallOp(exit_syscall_num))

match_and_rewrite(op: riscv_func.ReturnOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lower_riscv_func.py
64
65
66
67
68
69
70
71
72
73
74
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv_func.ReturnOp, rewriter: PatternRewriter):
    """
    Insert a syscall with number 93 (exit) when the matched return's parent op
    is the `riscv_func.func` named "main"; otherwise do nothing.
    """
    enclosing = op.parent_op()
    # sym_name is only inspected once we know the parent is a FuncOp.
    in_main = (
        isinstance(enclosing, riscv_func.FuncOp)
        and enclosing.sym_name.data == "main"
    )
    if in_main:
        exit_syscall_num = 93
        # NOTE(review): uses the rewriter's default insertion point,
        # presumably just before the matched return.
        rewriter.insert_op(riscv_func.SyscallOp(exit_syscall_num))

LowerRISCVFunc dataclass

Bases: ModulePass

Source code in xdsl/transforms/lower_riscv_func.py
77
78
79
80
81
82
83
84
85
86
87
88
@dataclass(frozen=True)
class LowerRISCVFunc(ModulePass):
    """
    Lower `riscv_func` syscalls in a module to RISC-V register moves and `ecall`s.

    When `insert_exit_syscall` is set, exit syscalls are first inserted before
    returns of `main`, so they are lowered together with all other syscalls.
    """

    name = "lower-riscv-func"

    # If True, run InsertExitSyscallOp (non-recursively) before lowering syscalls.
    insert_exit_syscall: bool = field(default=False)

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        if self.insert_exit_syscall:
            exit_walker = PatternRewriteWalker(
                InsertExitSyscallOp(), apply_recursively=False
            )
            exit_walker.rewrite_module(op)
        syscall_walker = PatternRewriteWalker(LowerSyscallOp())
        syscall_walker.rewrite_module(op)

name = 'lower-riscv-func' class-attribute instance-attribute

insert_exit_syscall: bool = field(default=False) class-attribute instance-attribute

__init__(insert_exit_syscall: bool = False) -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/lower_riscv_func.py
83
84
85
86
87
88
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """
    Optionally insert exit syscalls (non-recursive walk), then lower every
    syscall in `op` with LowerSyscallOp.
    """
    if self.insert_exit_syscall:
        exit_walker = PatternRewriteWalker(
            InsertExitSyscallOp(), apply_recursively=False
        )
        exit_walker.rewrite_module(op)
    syscall_walker = PatternRewriteWalker(LowerSyscallOp())
    syscall_walker.rewrite_module(op)