Skip to content

Convert func to x86 func

convert_func_to_x86_func

arg_passing_registers = [x86.registers.RDI, x86.registers.RSI, x86.registers.RDX, x86.registers.RCX, x86.registers.R8, x86.registers.R9] module-attribute

return_passing_register = x86.registers.RAX module-attribute

MAX_REG_PASSING_INPUTS = 6 module-attribute

STACK_SLOT_SIZE_BYTES = 8 module-attribute

LowerFuncOp

Bases: RewritePattern

Source code in xdsl/backend/x86/lowering/convert_func_to_x86_func.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
class LowerFuncOp(RewritePattern):
    """
    Rewrites a `func.func` that has a body into an `x86_func.FuncOp`.

    The leading parameters (at most `MAX_REG_PASSING_INPUTS`) are received in
    the registers listed in `arg_passing_registers`; any further parameter is
    loaded from memory relative to the stack pointer, which becomes the last
    argument of the entry block.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
        """
        Replace `op` by an `x86_func.FuncOp` owning the (rewritten) body.

        Raises a `DiagnosticException` when the function has no body, or when
        one of its parameters is shaped or wider than 64 bits.
        """
        # A function without any block is a declaration only.
        if op.body.blocks.first is None:
            raise DiagnosticException(
                "Cannot lower external functions (not implemented)"
            )

        input_types = op.function_type.inputs.data
        for input_type in input_types:
            if isinstance(input_type, builtin.ShapedType):
                raise DiagnosticException(
                    "Cannot lower shaped function parameters (not implemented)"
                )
            if (
                isinstance(input_type, builtin.FixedBitwidthType)
                and input_type.bitwidth > 64
            ):
                raise DiagnosticException(
                    "Cannot lower function parameters bigger than 64 bits (not implemented)"
                )

        # Public symbols get a `.global` directive inserted via the rewriter.
        if op.sym_visibility == StringAttr("public"):
            rewriter.insert_op(x86.DirectiveOp(".global", op.sym_name))

        num_inputs = len(input_types)
        num_reg_inputs = min(num_inputs, MAX_REG_PASSING_INPUTS)
        reg_args_types = arg_passing_registers[:num_reg_inputs]

        new_region = rewriter.move_region_contents_to_new_regions(op.body)
        entry_block = new_region.blocks.first
        assert isinstance(entry_block, Block)

        # Every parameter load is inserted at this single point.
        prologue = InsertPoint.at_start(entry_block)

        # Register-carried parameters: swap each original block argument for
        # a register-typed one, then move + cast it back to the original type
        # so that existing uses keep a value of the type they expect.
        for index, register_type in enumerate(reg_args_types):
            old_arg = entry_block.args[index]
            incoming = entry_block.insert_arg(register_type, index)
            load_op = x86.DS_MovOp(
                source=incoming,
                destination=x86.registers.UNALLOCATED_GENERAL,
            )
            cast_op, typed_value = builtin.UnrealizedConversionCastOp.cast_one(
                load_op.destination, old_arg.type
            )
            rewriter.insert_op([load_op, cast_op], prologue)
            old_arg.replace_all_uses_with(typed_value)
            entry_block.erase_arg(old_arg)

        # The stack pointer goes right after the register-carried arguments.
        stack_pointer = entry_block.insert_arg(x86.registers.RSP, num_reg_inputs)

        # Stack-carried parameters: whenever there are any, the stack pointer
        # sits at index MAX_REG_PASSING_INPUTS, so the next parameter to load
        # is always the block argument just after it. Erasing that argument
        # shifts the following one into the same position.
        next_stack_arg_index = MAX_REG_PASSING_INPUTS + 1
        for i in range(num_inputs - MAX_REG_PASSING_INPUTS):
            stack_arg = entry_block.args[next_stack_arg_index]
            assert stack_pointer != stack_arg
            # The first stack parameter is read one slot above the stack
            # pointer (presumably skipping the return address at offset 0;
            # confirm against the calling convention).
            load_op = x86.DM_MovOp(
                memory=stack_pointer,
                memory_offset=STACK_SLOT_SIZE_BYTES * (i + 1),
                destination=x86.registers.UNALLOCATED_GENERAL,
                comment=f"Load the {i + MAX_REG_PASSING_INPUTS + 1}th argument of the function",
            )
            cast_op = builtin.UnrealizedConversionCastOp.get(
                (load_op.destination,), (stack_arg.type,)
            )
            rewriter.insert_op([load_op, cast_op], prologue)
            stack_arg.replace_all_uses_with(cast_op.results[0])
            entry_block.erase_arg(stack_arg)

        # Only a single result is reflected in the signature; it is returned
        # in `return_passing_register`.
        has_single_output = len(op.function_type.outputs.data) == 1
        outputs_types: list[Attribute] = (
            [return_passing_register] if has_single_output else []
        )

        signature = ([*reg_args_types, x86.registers.RSP], outputs_types)
        rewriter.replace_op(
            op,
            x86_func.FuncOp(
                op.sym_name.data,
                new_region,
                signature,
                visibility=op.sym_visibility,
            ),
        )

match_and_rewrite(op: func.FuncOp, rewriter: PatternRewriter)

Source code in xdsl/backend/x86/lowering/convert_func_to_x86_func.py
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
    """
    Replace `op` by an `x86_func.FuncOp` owning the (rewritten) body.

    The leading parameters (at most `MAX_REG_PASSING_INPUTS`) are received in
    the registers listed in `arg_passing_registers`; any further parameter is
    loaded from memory relative to the stack pointer, which becomes the last
    argument of the entry block.

    Raises a `DiagnosticException` when the function has no body, or when
    one of its parameters is shaped or wider than 64 bits.
    """
    # A function without any block is a declaration only.
    if op.body.blocks.first is None:
        raise DiagnosticException(
            "Cannot lower external functions (not implemented)"
        )

    input_types = op.function_type.inputs.data
    for input_type in input_types:
        if isinstance(input_type, builtin.ShapedType):
            raise DiagnosticException(
                "Cannot lower shaped function parameters (not implemented)"
            )
        if (
            isinstance(input_type, builtin.FixedBitwidthType)
            and input_type.bitwidth > 64
        ):
            raise DiagnosticException(
                "Cannot lower function parameters bigger than 64 bits (not implemented)"
            )

    # Public symbols get a `.global` directive inserted via the rewriter.
    if op.sym_visibility == StringAttr("public"):
        rewriter.insert_op(x86.DirectiveOp(".global", op.sym_name))

    num_inputs = len(input_types)
    num_reg_inputs = min(num_inputs, MAX_REG_PASSING_INPUTS)
    reg_args_types = arg_passing_registers[:num_reg_inputs]

    new_region = rewriter.move_region_contents_to_new_regions(op.body)
    entry_block = new_region.blocks.first
    assert isinstance(entry_block, Block)

    # Every parameter load is inserted at this single point.
    prologue = InsertPoint.at_start(entry_block)

    # Register-carried parameters: swap each original block argument for
    # a register-typed one, then move + cast it back to the original type
    # so that existing uses keep a value of the type they expect.
    for index, register_type in enumerate(reg_args_types):
        old_arg = entry_block.args[index]
        incoming = entry_block.insert_arg(register_type, index)
        load_op = x86.DS_MovOp(
            source=incoming,
            destination=x86.registers.UNALLOCATED_GENERAL,
        )
        cast_op, typed_value = builtin.UnrealizedConversionCastOp.cast_one(
            load_op.destination, old_arg.type
        )
        rewriter.insert_op([load_op, cast_op], prologue)
        old_arg.replace_all_uses_with(typed_value)
        entry_block.erase_arg(old_arg)

    # The stack pointer goes right after the register-carried arguments.
    stack_pointer = entry_block.insert_arg(x86.registers.RSP, num_reg_inputs)

    # Stack-carried parameters: whenever there are any, the stack pointer
    # sits at index MAX_REG_PASSING_INPUTS, so the next parameter to load
    # is always the block argument just after it. Erasing that argument
    # shifts the following one into the same position.
    next_stack_arg_index = MAX_REG_PASSING_INPUTS + 1
    for i in range(num_inputs - MAX_REG_PASSING_INPUTS):
        stack_arg = entry_block.args[next_stack_arg_index]
        assert stack_pointer != stack_arg
        # The first stack parameter is read one slot above the stack
        # pointer (presumably skipping the return address at offset 0;
        # confirm against the calling convention).
        load_op = x86.DM_MovOp(
            memory=stack_pointer,
            memory_offset=STACK_SLOT_SIZE_BYTES * (i + 1),
            destination=x86.registers.UNALLOCATED_GENERAL,
            comment=f"Load the {i + MAX_REG_PASSING_INPUTS + 1}th argument of the function",
        )
        cast_op = builtin.UnrealizedConversionCastOp.get(
            (load_op.destination,), (stack_arg.type,)
        )
        rewriter.insert_op([load_op, cast_op], prologue)
        stack_arg.replace_all_uses_with(cast_op.results[0])
        entry_block.erase_arg(stack_arg)

    # Only a single result is reflected in the signature; it is returned
    # in `return_passing_register`.
    has_single_output = len(op.function_type.outputs.data) == 1
    outputs_types: list[Attribute] = (
        [return_passing_register] if has_single_output else []
    )

    signature = ([*reg_args_types, x86.registers.RSP], outputs_types)
    rewriter.replace_op(
        op,
        x86_func.FuncOp(
            op.sym_name.data,
            new_region,
            signature,
            visibility=op.sym_visibility,
        ),
    )

LowerReturnOp

Bases: RewritePattern

Source code in xdsl/backend/x86/lowering/convert_func_to_x86_func.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class LowerReturnOp(RewritePattern):
    """
    Rewrites a `func.return` into an `x86_func.RetOp`, first moving the
    returned value (if any) into `return_passing_register`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter):
        """
        Replace `op` by a `ret`, preceded by a cast and a move when a value
        is returned.

        Raises a `DiagnosticException` when more than one value is returned,
        or when the returned value is shaped or wider than 64 bits.
        """
        returned = op.arguments

        if len(returned) > 1:
            raise DiagnosticException(
                "Cannot lower func.return with more than 1 argument (not implemented)"
            )

        # Nothing to hand back: a bare `ret` is enough.
        if not returned:
            rewriter.replace_op(op, [x86_func.RetOp()])
            return

        value = returned[0]
        value_type = value.type

        if isinstance(value_type, builtin.ShapedType):
            raise DiagnosticException(
                "Cannot lower shaped function output (not implemented)"
            )
        if (
            isinstance(value_type, builtin.FixedBitwidthType)
            and value_type.bitwidth > 64
        ):
            raise DiagnosticException(
                "Cannot lower function return values bigger than 64 bits (not implemented)"
            )

        # View the value as living in a (not yet allocated) general register,
        # then move it into the register that carries the return value.
        to_register = builtin.UnrealizedConversionCastOp.get(
            (value,), (x86.registers.UNALLOCATED_GENERAL,)
        )
        result_mov = x86.ops.DS_MovOp(
            to_register, destination=return_passing_register
        )

        rewriter.replace_op(op, [to_register, result_mov, x86_func.RetOp()])

match_and_rewrite(op: func.ReturnOp, rewriter: PatternRewriter)

Source code in xdsl/backend/x86/lowering/convert_func_to_x86_func.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter):
    """
    Replace `op` by an `x86_func.RetOp`, preceded by a cast and a move into
    `return_passing_register` when a value is returned.

    Raises a `DiagnosticException` when more than one value is returned,
    or when the returned value is shaped or wider than 64 bits.
    """
    returned = op.arguments

    if len(returned) > 1:
        raise DiagnosticException(
            "Cannot lower func.return with more than 1 argument (not implemented)"
        )

    # Nothing to hand back: a bare `ret` is enough.
    if not returned:
        rewriter.replace_op(op, [x86_func.RetOp()])
        return

    value = returned[0]
    value_type = value.type

    if isinstance(value_type, builtin.ShapedType):
        raise DiagnosticException(
            "Cannot lower shaped function output (not implemented)"
        )
    if (
        isinstance(value_type, builtin.FixedBitwidthType)
        and value_type.bitwidth > 64
    ):
        raise DiagnosticException(
            "Cannot lower function return values bigger than 64 bits (not implemented)"
        )

    # View the value as living in a (not yet allocated) general register,
    # then move it into the register that carries the return value.
    to_register = builtin.UnrealizedConversionCastOp.get(
        (value,), (x86.registers.UNALLOCATED_GENERAL,)
    )
    result_mov = x86.ops.DS_MovOp(
        to_register, destination=return_passing_register
    )

    rewriter.replace_op(op, [to_register, result_mov, x86_func.RetOp()])

ConvertFuncToX86FuncPass dataclass

Bases: ModulePass

Source code in xdsl/backend/x86/lowering/convert_func_to_x86_func.py
160
161
162
163
164
165
166
167
168
169
170
171
172
class ConvertFuncToX86FuncPass(ModulePass):
    """
    Module pass applying `LowerFuncOp` and `LowerReturnOp` to a module.
    """

    name = "convert-func-to-x86-func"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """Run both lowering patterns over `op` in one non-recursive walk."""
        patterns = GreedyRewritePatternApplier(
            [
                LowerFuncOp(),
                LowerReturnOp(),
            ]
        )
        walker = PatternRewriteWalker(patterns, apply_recursively=False)
        walker.rewrite_module(op)

name = 'convert-func-to-x86-func' class-attribute instance-attribute

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/backend/x86/lowering/convert_func_to_x86_func.py
163
164
165
166
167
168
169
170
171
172
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Run `LowerFuncOp` and `LowerReturnOp` over `op` in one non-recursive walk."""
    patterns = GreedyRewritePatternApplier(
        [
            LowerFuncOp(),
            LowerReturnOp(),
        ]
    )
    walker = PatternRewriteWalker(patterns, apply_recursively=False)
    walker.rewrite_module(op)