Skip to content

Prologue epilogue insertion

prologue_epilogue_insertion

PrologueEpilogueInsertion dataclass

Bases: ModulePass

Pass inserting a prologue and epilogue according to the RISC-V ABI. The prologues and epilogues are responsible for saving and restoring any callee-preserved registers. In RISC-V these are 's0' to 's11' and 'fs0' to 'fs11'. The stack pointer 'sp' must also be restored to its original value.

This pass should be run late in the pipeline, after register allocation. It does not itself require register allocation, nor does it invalidate the result of the register allocator.

Source code in xdsl/backend/riscv/prologue_epilogue_insertion.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
@dataclass(frozen=True)
class PrologueEpilogueInsertion(ModulePass):
    """
    Pass inserting a prologue and epilogue according to the RISC-V ABI.
    The prologues and epilogues are responsible for saving and restoring any
    callee-preserved registers.
    In RISC-V these are 's0' to 's11' and 'fs0' to 'fs11'.
    The stack pointer 'sp' must also be restored to its original value.

    Each saved register is placed in a slot aligned to its own size, so that an
    8-byte 'fsd'/'fld' never targets an address that is only 4-byte aligned.
    The total frame size is rounded up to `stack_alignment` bytes. This defaults
    to 16, the stack-pointer alignment required by the standard RISC-V psABI.

    This pass should be run late in the pipeline, after register allocation.
    It does not itself require register allocation, nor does it invalidate the
    result of the register allocator.
    """

    name = "riscv-prologue-epilogue-insertion"
    # Size in bytes of the save slot for an integer register (saved with sw/lw).
    xlen: int = field(default=4)
    # Size in bytes of the save slot for a float register (saved with fsd/fld).
    flen: int = field(default=8)
    # Alignment in bytes of the frame allocated by the prologue; values < 1 are
    # treated as 1 (no rounding).
    stack_alignment: int = field(default=16)

    def _process_function(self, func: riscv_func.FuncOp) -> None:
        """
        Insert a prologue that saves, and epilogues that restore, every
        callee-preserved register written to inside `func`.

        The prologue is placed at the start of the first block. An epilogue is
        placed before each block-terminating `riscv_func.ReturnOp`. Functions
        that clobber no callee-preserved register are left untouched.
        """
        # Find all callee-preserved registers that are clobbered. We define clobbered
        # as it being the result of some operation and therefore written to.
        used_callee_preserved_registers = OrderedSet(
            res.type
            for op in func.walk()
            if not isinstance(op, riscv.GetRegisterOp | riscv.GetFloatRegisterOp)
            for res in op.results
            if isinstance(res.type, IntRegisterType | FloatRegisterType)
            if res.type in Registers.S or res.type in Registers.FS
        )

        if not used_callee_preserved_registers:
            return

        def get_register_size(r: RISCVRegisterType) -> int:
            if isinstance(r, IntRegisterType):
                return self.xlen
            return self.flen

        def align_up(value: int, alignment: int) -> int:
            # Round `value` up to the next multiple of `alignment` (clamped to >= 1
            # so a zero/negative setting cannot cause a division by zero).
            alignment = max(alignment, 1)
            return (value + alignment - 1) // alignment * alignment

        # Lay out the save area once. Registers keep their discovery order, and
        # each slot is aligned to the register's own size. Prologue and epilogues
        # share the same layout.
        slots: list[tuple[IntRegisterType | FloatRegisterType, int]] = []
        offset = 0
        for reg in used_callee_preserved_registers:
            size = get_register_size(reg)
            offset = align_up(offset, size)
            slots.append((reg, offset))
            offset += size
        stack_size = align_up(offset, self.stack_alignment)

        # Build the prologue at the beginning of the function.
        builder = Builder(InsertPoint.at_start(func.body.blocks[0]))
        sp_register = builder.insert(riscv.GetRegisterOp(Registers.SP))
        builder.insert(riscv.AddiOp(sp_register, -stack_size, rd=Registers.SP))
        # NOTE(review): the stores use `sp_register` (the SSA value read before the
        # decrement). Both it and the addi result are assigned to 'sp', so the
        # emitted code presumably addresses the new frame — confirm if printing
        # ever stops being register-based.
        for reg, slot_offset in slots:
            if isinstance(reg, IntRegisterType):
                reg_op = builder.insert(riscv.GetRegisterOp(reg))
                op = riscv.SwOp(rs1=sp_register, rs2=reg_op, immediate=slot_offset)
            else:
                reg_op = builder.insert(riscv.GetFloatRegisterOp(reg))
                op = riscv.FSdOp(rs1=sp_register, rs2=reg_op, immediate=slot_offset)
            builder.insert(op)

        # Now build the epilogue right before every return operation.
        for block in func.body.blocks:
            ret_op = block.last_op
            if not isinstance(ret_op, riscv_func.ReturnOp):
                continue

            builder = Builder(InsertPoint.before(ret_op))
            for reg, slot_offset in slots:
                if isinstance(reg, IntRegisterType):
                    op = riscv.LwOp(rs1=sp_register, rd=reg, immediate=slot_offset)
                else:
                    op = riscv.FLdOp(rs1=sp_register, rd=reg, immediate=slot_offset)
                builder.insert(op)

            # Release the frame, restoring 'sp' to its value on entry.
            builder.insert(riscv.AddiOp(sp_register, stack_size, rd=Registers.SP))

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """Run the insertion on every `riscv_func.FuncOp` in `op` that has a body."""
        # Collect the functions up front so the IR walk is not interleaved with
        # the insertions performed by `_process_function`.
        funcs = [f for f in op.walk() if isinstance(f, riscv_func.FuncOp)]
        for func in funcs:
            # A function with no blocks (presumably an external declaration) has
            # nothing to save or restore.
            if len(func.body.blocks) == 0:
                continue

            self._process_function(func)

name = 'riscv-prologue-epilogue-insertion' class-attribute instance-attribute

xlen: int = field(default=4) class-attribute instance-attribute

flen: int = field(default=8) class-attribute instance-attribute

__init__(xlen: int = 4, flen: int = 8) -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/backend/riscv/prologue_epilogue_insertion.py
90
91
92
93
94
95
96
97
98
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """
    Walk `op` and hand each `riscv_func.FuncOp` that owns at least one block
    to `self._process_function`. Functions with an empty body are skipped.
    """
    # Lazily filter the walk so functions are processed as they are reached.
    riscv_funcs = (
        candidate
        for candidate in op.walk()
        if isinstance(candidate, riscv_func.FuncOp)
    )
    for riscv_function in riscv_funcs:
        has_blocks = len(riscv_function.body.blocks) != 0
        if has_blocks:
            self._process_function(riscv_function)