Skip to content

Convert func to x86 func

convert_func_to_x86_func

ARG_PASSING_REGISTER_INDICES = [7, 6, 2, 1, 8, 9] module-attribute

ABI-specified function argument registers: RDI, RSI, RDX, RCX, R8, R9.

RETURN_PASSING_REGISTER = 0 module-attribute

ABI-specified return register: RAX.

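MAX_REG_PASSING_INPUTS = 6 module-attribute

Maximum number of function arguments passed in registers (one per entry of ARG_PASSING_REGISTER_INDICES); any further arguments are loaded from the stack.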

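STACK_SLOT_SIZE_BYTES = 8 module-attribute

Size in bytes of one stack slot: the k-th stack-carried argument (k starting at 1) is loaded from offset k * STACK_SLOT_SIZE_BYTES relative to the stack pointer.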

LowerFuncOp dataclass

Bases: RewritePattern

Source code in xdsl/backend/x86/lowering/convert_func_to_x86_func.py (lines 50–140)
@dataclass
class LowerFuncOp(RewritePattern):
    """
    Lower a ``func.func`` to an ``x86_func.func``.

    The first parameters (up to ``MAX_REG_PASSING_INPUTS``) are received in the
    registers listed in ``ARG_PASSING_REGISTER_INDICES`` and copied into fresh
    unallocated registers. The stack pointer becomes an extra block argument
    placed right after those register-carried arguments, and any remaining
    parameters are loaded from the stack relative to it. A single result is
    mapped to the register at index ``RETURN_PASSING_REGISTER``.
    """

    arch: Arch

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
        """
        Replace ``op`` with an ``x86_func.FuncOp`` whose arguments are registers.

        Raises:
            DiagnosticException: if ``op`` has no body (external function), or
                if any parameter is a shaped type or a fixed-bitwidth type wider
                than 64 bits.
        """
        if op.body.blocks.first is None:
            raise DiagnosticException(
                "Cannot lower external functions (not implemented)"
            )

        for ty in op.function_type.inputs.data:
            if isinstance(ty, builtin.ShapedType):
                raise DiagnosticException(
                    "Cannot lower shaped function parameters (not implemented)"
                )
            elif isinstance(ty, builtin.FixedBitwidthType) and ty.bitwidth > 64:
                raise DiagnosticException(
                    "Cannot lower function parameters bigger than 64 bits (not implemented)"
                )

        num_inputs = len(op.function_type.inputs.data)
        new_region = rewriter.move_region_contents_to_new_regions(op.body)
        first_block = new_region.blocks.first
        assert isinstance(first_block, Block)

        # One register type per register-carried parameter. zip stops at the
        # shorter sequence, so at most len(ARG_PASSING_REGISTER_INDICES) entries.
        # NOTE(review): register indices are picked purely by argument position;
        # if register_type_for_type ever maps scalar floats to SSE registers this
        # would diverge from the SysV ABI (floats use xmm0-xmm7) — confirm.
        reg_args_types = tuple(
            self.arch.register_type_for_type(arg.type).from_index(register_index)
            for register_index, arg in zip(
                ARG_PASSING_REGISTER_INDICES, first_block.args
            )
        )

        insertion_point = InsertPoint.at_start(first_block)

        # Load the register-carried parameters. The register argument is inserted
        # at index i (shifting the original argument to i + 1), the original's
        # uses are redirected to a cast of a copy of that register, and the
        # original argument is erased.
        for i, register_type in enumerate(reg_args_types):
            arg = first_block.args[i]
            register = first_block.insert_arg(register_type, i)
            mov_op = x86.DS_MovOp(
                source=register, destination=register_type.unallocated()
            )
            cast_op, parameter = builtin.UnrealizedConversionCastOp.cast_one(
                mov_op.destination, arg.type
            )
            rewriter.insert_op([mov_op, cast_op], insertion_point)
            arg.replace_all_uses_with(parameter)
            first_block.erase_arg(arg)

        # The stack pointer goes right after the register-carried arguments; it
        # is the last block argument once the stack-carried ones are consumed.
        sp = first_block.insert_arg(
            x86.registers.RSP, min(num_inputs, MAX_REG_PASSING_INPUTS)
        )

        # If needed, load the stack-carried parameters. Block argument index
        # MAX_REG_PASSING_INPUTS holds sp, so the next stack-carried parameter is
        # always at index MAX_REG_PASSING_INPUTS + 1; erasing it shifts the
        # following one into its place.
        for i in range(num_inputs - MAX_REG_PASSING_INPUTS):
            arg = first_block.args[MAX_REG_PASSING_INPUTS + 1]
            assert sp != arg
            destination_reg = self.arch.register_type_for_type(arg.type).unallocated()
            assert isinstance(destination_reg, GeneralRegisterType)
            # 1-based position of this parameter in the function signature.
            position = i + MAX_REG_PASSING_INPUTS + 1
            # English ordinal suffix: 11th-13th are irregular, otherwise the
            # last digit decides (21st, 22nd, 23rd, 24th, ...).
            if 11 <= position % 100 <= 13:
                suffix = "th"
            else:
                suffix = {1: "st", 2: "nd", 3: "rd"}.get(position % 10, "th")
            mov_op = x86.DM_MovOp(
                memory=sp,
                # Offsets start at one slot above sp, presumably because the
                # slot at sp holds the caller's return address — TODO confirm.
                memory_offset=STACK_SLOT_SIZE_BYTES * (i + 1),
                destination=destination_reg,
                comment=f"Load the {position}{suffix} argument of the function",
            )
            cast_op, parameter = builtin.UnrealizedConversionCastOp.cast_one(
                mov_op.destination, arg.type
            )
            rewriter.insert_op([mov_op, cast_op], insertion_point)
            arg.replace_all_uses_with(parameter)
            first_block.erase_arg(arg)

        # Only a single result is mapped to a register. A function with several
        # results gets no output registers here; LowerReturnOp rejects the
        # matching multi-value func.return.
        outputs_types: list[Attribute] = []
        if len(op.function_type.outputs.data) == 1:
            output_reg = self.arch.register_type_for_type(
                op.function_type.outputs.data[0]
            ).from_index(RETURN_PASSING_REGISTER)
            outputs_types.append(output_reg)

        new_func = x86_func.FuncOp(
            op.sym_name.data,
            new_region,
            (reg_args_types + (x86.registers.RSP,), outputs_types),
            visibility=op.sym_visibility,
        )

        rewriter.replace_op(op, new_func)

arch: Arch instance-attribute

__init__(arch: Arch) -> None

match_and_rewrite(op: func.FuncOp, rewriter: PatternRewriter)

Source code in xdsl/backend/x86/lowering/convert_func_to_x86_func.py (lines 54–140)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
    """
    Replace ``op`` with an ``x86_func.FuncOp`` whose arguments are registers.

    Up to ``MAX_REG_PASSING_INPUTS`` parameters are received in the registers
    listed in ``ARG_PASSING_REGISTER_INDICES``; the stack pointer follows them
    as a block argument, and remaining parameters are loaded from the stack.

    Raises:
        DiagnosticException: if ``op`` has no body (external function), or if
            any parameter is a shaped type or a fixed-bitwidth type wider than
            64 bits.
    """
    if op.body.blocks.first is None:
        raise DiagnosticException(
            "Cannot lower external functions (not implemented)"
        )

    for ty in op.function_type.inputs.data:
        if isinstance(ty, builtin.ShapedType):
            raise DiagnosticException(
                "Cannot lower shaped function parameters (not implemented)"
            )
        elif isinstance(ty, builtin.FixedBitwidthType) and ty.bitwidth > 64:
            raise DiagnosticException(
                "Cannot lower function parameters bigger than 64 bits (not implemented)"
            )

    num_inputs = len(op.function_type.inputs.data)
    new_region = rewriter.move_region_contents_to_new_regions(op.body)
    first_block = new_region.blocks.first
    assert isinstance(first_block, Block)

    # One register type per register-carried parameter. zip stops at the
    # shorter sequence, so at most len(ARG_PASSING_REGISTER_INDICES) entries.
    reg_args_types = tuple(
        self.arch.register_type_for_type(arg.type).from_index(register_index)
        for register_index, arg in zip(
            ARG_PASSING_REGISTER_INDICES, first_block.args
        )
    )

    insertion_point = InsertPoint.at_start(first_block)

    # Load the register-carried parameters. The register argument is inserted
    # at index i (shifting the original argument to i + 1), the original's uses
    # are redirected to a cast of a copy of that register, and the original
    # argument is erased.
    for i, register_type in enumerate(reg_args_types):
        arg = first_block.args[i]
        register = first_block.insert_arg(register_type, i)
        mov_op = x86.DS_MovOp(
            source=register, destination=register_type.unallocated()
        )
        cast_op, parameter = builtin.UnrealizedConversionCastOp.cast_one(
            mov_op.destination, arg.type
        )
        rewriter.insert_op([mov_op, cast_op], insertion_point)
        arg.replace_all_uses_with(parameter)
        first_block.erase_arg(arg)

    # The stack pointer goes right after the register-carried arguments; it is
    # the last block argument once the stack-carried ones are consumed.
    sp = first_block.insert_arg(
        x86.registers.RSP, min(num_inputs, MAX_REG_PASSING_INPUTS)
    )

    # If needed, load the stack-carried parameters. Block argument index
    # MAX_REG_PASSING_INPUTS holds sp, so the next stack-carried parameter is
    # always at index MAX_REG_PASSING_INPUTS + 1; erasing it shifts the
    # following one into its place.
    for i in range(num_inputs - MAX_REG_PASSING_INPUTS):
        arg = first_block.args[MAX_REG_PASSING_INPUTS + 1]
        assert sp != arg
        destination_reg = self.arch.register_type_for_type(arg.type).unallocated()
        assert isinstance(destination_reg, GeneralRegisterType)
        # 1-based position of this parameter in the function signature.
        position = i + MAX_REG_PASSING_INPUTS + 1
        # English ordinal suffix: 11th-13th are irregular, otherwise the last
        # digit decides (21st, 22nd, 23rd, 24th, ...).
        if 11 <= position % 100 <= 13:
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(position % 10, "th")
        mov_op = x86.DM_MovOp(
            memory=sp,
            # Offsets start at one slot above sp, presumably because the slot
            # at sp holds the caller's return address — TODO confirm.
            memory_offset=STACK_SLOT_SIZE_BYTES * (i + 1),
            destination=destination_reg,
            comment=f"Load the {position}{suffix} argument of the function",
        )
        cast_op, parameter = builtin.UnrealizedConversionCastOp.cast_one(
            mov_op.destination, arg.type
        )
        rewriter.insert_op([mov_op, cast_op], insertion_point)
        arg.replace_all_uses_with(parameter)
        first_block.erase_arg(arg)

    # Only a single result is mapped to a register; with several results no
    # output registers are declared here.
    outputs_types: list[Attribute] = []
    if len(op.function_type.outputs.data) == 1:
        output_reg = self.arch.register_type_for_type(
            op.function_type.outputs.data[0]
        ).from_index(RETURN_PASSING_REGISTER)
        outputs_types.append(output_reg)

    new_func = x86_func.FuncOp(
        op.sym_name.data,
        new_region,
        (reg_args_types + (x86.registers.RSP,), outputs_types),
        visibility=op.sym_visibility,
    )

    rewriter.replace_op(op, new_func)

LowerReturnOp dataclass

Bases: RewritePattern

Source code in xdsl/backend/x86/lowering/convert_func_to_x86_func.py (lines 143–180)
@dataclass
class LowerReturnOp(RewritePattern):
    """
    Lower ``func.return`` to ``x86_func.ret``.

    A single returned value is cast to an unallocated register of the type
    chosen by ``arch`` and moved into the register at index
    ``RETURN_PASSING_REGISTER`` before the return.
    """

    arch: Arch

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter):
        """
        Replace ``op`` with the x86 return sequence.

        Raises:
            DiagnosticException: if more than one value is returned, or if the
                returned value is shaped or wider than 64 bits.
        """
        operands = op.arguments
        if len(operands) > 1:
            raise DiagnosticException(
                "Cannot lower func.return with more than 1 argument (not implemented)"
            )
        if not operands:
            # Nothing to return: a bare ret suffices.
            rewriter.replace_op(op, [x86_func.RetOp()])
            return

        (value,) = operands
        value_type = value.type
        if isinstance(value_type, builtin.ShapedType):
            raise DiagnosticException(
                "Cannot lower shaped function output (not implemented)"
            )
        too_wide = (
            isinstance(value_type, builtin.FixedBitwidthType)
            and value_type.bitwidth > 64
        )
        if too_wide:
            raise DiagnosticException(
                "Cannot lower function return values bigger than 64 bits (not implemented)"
            )

        unallocated = self.arch.register_type_for_type(value_type).unallocated()
        to_register = builtin.UnrealizedConversionCastOp.get(
            (value,), (unallocated,)
        )
        into_return_reg = x86.ops.DS_MovOp(
            to_register,
            destination=unallocated.from_index(RETURN_PASSING_REGISTER),
        )
        rewriter.replace_op(op, [to_register, into_return_reg, x86_func.RetOp()])

arch: Arch instance-attribute

__init__(arch: Arch) -> None

match_and_rewrite(op: func.ReturnOp, rewriter: PatternRewriter)

Source code in xdsl/backend/x86/lowering/convert_func_to_x86_func.py (lines 147–180)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter):
    """
    Replace ``op`` with the x86 return sequence.

    A single returned value is cast to an unallocated register and moved into
    the register at index ``RETURN_PASSING_REGISTER`` before the return.

    Raises:
        DiagnosticException: if more than one value is returned, or if the
            returned value is shaped or wider than 64 bits.
    """
    operands = op.arguments
    if len(operands) > 1:
        raise DiagnosticException(
            "Cannot lower func.return with more than 1 argument (not implemented)"
        )
    if not operands:
        # Nothing to return: a bare ret suffices.
        rewriter.replace_op(op, [x86_func.RetOp()])
        return

    (value,) = operands
    value_type = value.type
    if isinstance(value_type, builtin.ShapedType):
        raise DiagnosticException(
            "Cannot lower shaped function output (not implemented)"
        )
    too_wide = (
        isinstance(value_type, builtin.FixedBitwidthType)
        and value_type.bitwidth > 64
    )
    if too_wide:
        raise DiagnosticException(
            "Cannot lower function return values bigger than 64 bits (not implemented)"
        )

    unallocated = self.arch.register_type_for_type(value_type).unallocated()
    to_register = builtin.UnrealizedConversionCastOp.get(
        (value,), (unallocated,)
    )
    into_return_reg = x86.ops.DS_MovOp(
        to_register,
        destination=unallocated.from_index(RETURN_PASSING_REGISTER),
    )
    rewriter.replace_op(op, [to_register, into_return_reg, x86_func.RetOp()])

ConvertFuncToX86FuncPass dataclass

Bases: ModulePass

Source code in xdsl/backend/x86/lowering/convert_func_to_x86_func.py (lines 183–198)
@dataclass(frozen=True)
class ConvertFuncToX86FuncPass(ModulePass):
    """
    Lower ``func.func`` and ``func.return`` operations to the x86_func dialect.
    """

    name = "convert-func-to-x86-func"

    # Architecture name, resolved through ``Arch.arch_for_name``.
    arch: str | None = None

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        target_arch = Arch.arch_for_name(self.arch)
        lowerings = GreedyRewritePatternApplier(
            [LowerFuncOp(target_arch), LowerReturnOp(target_arch)]
        )
        walker = PatternRewriteWalker(lowerings, apply_recursively=False)
        walker.rewrite_module(op)

name = 'convert-func-to-x86-func' class-attribute instance-attribute

arch: str | None = None class-attribute instance-attribute

__init__(arch: str | None = None) -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/backend/x86/lowering/convert_func_to_x86_func.py (lines 188–198)
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Run the func and return lowerings once over every op in ``op``."""
    target_arch = Arch.arch_for_name(self.arch)
    lowerings = GreedyRewritePatternApplier(
        [LowerFuncOp(target_arch), LowerReturnOp(target_arch)]
    )
    walker = PatternRewriteWalker(lowerings, apply_recursively=False)
    walker.rewrite_module(op)