Skip to content

Convert vector to x86

convert_vector_to_x86

VectorBroadcastToX86 dataclass

Bases: RewritePattern

Source code in xdsl/backend/x86/lowering/convert_vector_to_x86.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@dataclass
class VectorBroadcastToX86(RewritePattern):
    """
    Rewrite `vector.BroadcastOp` into an x86 broadcast op.

    The scalar operand is cast to an unallocated general register, broadcast
    with `DS_VpbroadcastdOp` (32-bit elements) or `DS_VpbroadcastqOp` (64-bit
    elements), and the resulting register is cast back to the original vector
    type.
    """

    # Target architecture, used to pick the register type for the vector result.
    arch: Arch

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: vector.BroadcastOp, rewriter: PatternRewriter):
        """
        Replace `op` with: source cast, x86 broadcast, result cast.

        Raises `DiagnosticException` when the source bit width is 16 (not
        implemented yet) or anything other than 32 or 64.
        """
        # View the scalar operand as an unallocated general register.
        scalar_cast, scalar_reg = UnrealizedConversionCastOp.cast_one(
            op.source, x86.registers.UNALLOCATED_GENERAL
        )

        # Choose the broadcast instruction from the scalar's bit width.
        scalar_type = op.source.type
        assert isinstance(scalar_type, FixedBitwidthType)
        bitwidth = scalar_type.bitwidth
        if bitwidth == 16:
            raise DiagnosticException(
                "Half-precision vector broadcast is not implemented yet."
            )
        elif bitwidth == 32:
            broadcast_cls = x86.ops.DS_VpbroadcastdOp
        elif bitwidth == 64:
            broadcast_cls = x86.ops.DS_VpbroadcastqOp
        else:
            raise DiagnosticException(
                "Float precision must be half, single or double."
            )

        # Emit the broadcast into a fresh, unallocated vector register.
        dest_register = self.arch.register_type_for_type(op.vector.type).unallocated()
        x86_broadcast = broadcast_cls(source=scalar_reg, destination=dest_register)

        # Cast the register back to the abstract vector type of the result.
        result_cast, _ = UnrealizedConversionCastOp.cast_one(
            x86_broadcast.destination, op.vector.type
        )

        rewriter.replace_op(op, [scalar_cast, x86_broadcast, result_cast])

arch: Arch instance-attribute

__init__(arch: Arch) -> None

match_and_rewrite(op: vector.BroadcastOp, rewriter: PatternRewriter)

Source code in xdsl/backend/x86/lowering/convert_vector_to_x86.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
@op_type_rewrite_pattern
def match_and_rewrite(self, op: vector.BroadcastOp, rewriter: PatternRewriter):
    """
    Replace a `vector.BroadcastOp` with: source cast, x86 broadcast, result cast.

    Raises `DiagnosticException` when the source bit width is 16 (not
    implemented yet) or anything other than 32 or 64.
    """
    # View the scalar operand as an unallocated general register.
    scalar_cast, scalar_reg = UnrealizedConversionCastOp.cast_one(
        op.source, x86.registers.UNALLOCATED_GENERAL
    )

    # Choose the broadcast instruction from the scalar's bit width.
    scalar_type = op.source.type
    assert isinstance(scalar_type, FixedBitwidthType)
    bitwidth = scalar_type.bitwidth
    if bitwidth == 16:
        raise DiagnosticException(
            "Half-precision vector broadcast is not implemented yet."
        )
    elif bitwidth == 32:
        broadcast_cls = x86.ops.DS_VpbroadcastdOp
    elif bitwidth == 64:
        broadcast_cls = x86.ops.DS_VpbroadcastqOp
    else:
        raise DiagnosticException(
            "Float precision must be half, single or double."
        )

    # Emit the broadcast into a fresh, unallocated vector register.
    dest_register = self.arch.register_type_for_type(op.vector.type).unallocated()
    x86_broadcast = broadcast_cls(source=scalar_reg, destination=dest_register)

    # Cast the register back to the abstract vector type of the result.
    result_cast, _ = UnrealizedConversionCastOp.cast_one(
        x86_broadcast.destination, op.vector.type
    )

    rewriter.replace_op(op, [scalar_cast, x86_broadcast, result_cast])

VectorFMAToX86 dataclass

Bases: RewritePattern

Source code in xdsl/backend/x86/lowering/convert_vector_to_x86.py
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
@dataclass
class VectorFMAToX86(RewritePattern):
    arch: Arch

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: vector.FMAOp, rewriter: PatternRewriter):
        vect_type = cast(VectorType, op.acc.type)
        x86_vect_type = self.arch.register_type_for_type(vect_type).unallocated()
        # Pointer casts
        lhs_cast_op, lhs_new = UnrealizedConversionCastOp.cast_one(
            op.lhs, x86_vect_type
        )
        rhs_cast_op, rhs_new = UnrealizedConversionCastOp.cast_one(
            op.rhs, x86_vect_type
        )
        acc_cast_op, acc_new = UnrealizedConversionCastOp.cast_one(
            op.acc, x86_vect_type
        )
        # Instruction selection
        element_size = cast(FixedBitwidthType, vect_type.get_element_type()).bitwidth
        match element_size:
            case 16:
                raise DiagnosticException(
                    "Half-precision vector load is not implemented yet."
                )
            case 32:
                fma = x86.ops.RSS_Vfmadd231psOp
            case 64:
                fma = x86.ops.RSS_Vfmadd231pdOp
            case _:
                raise DiagnosticException(
                    "Float precision must be half, single or double."
                )
        fma_op = fma(acc_new, lhs_new, rhs_new)

        res_cast_op = UnrealizedConversionCastOp.get(
            (fma_op.register_out,), (vect_type,)
        )
        rewriter.replace_op(
            op, [lhs_cast_op, rhs_cast_op, acc_cast_op, fma_op, res_cast_op]
        )

arch: Arch instance-attribute

__init__(arch: Arch) -> None

match_and_rewrite(op: vector.FMAOp, rewriter: PatternRewriter)

Source code in xdsl/backend/x86/lowering/convert_vector_to_x86.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
@op_type_rewrite_pattern
def match_and_rewrite(self, op: vector.FMAOp, rewriter: PatternRewriter):
    vect_type = cast(VectorType, op.acc.type)
    x86_vect_type = self.arch.register_type_for_type(vect_type).unallocated()
    # Pointer casts
    lhs_cast_op, lhs_new = UnrealizedConversionCastOp.cast_one(
        op.lhs, x86_vect_type
    )
    rhs_cast_op, rhs_new = UnrealizedConversionCastOp.cast_one(
        op.rhs, x86_vect_type
    )
    acc_cast_op, acc_new = UnrealizedConversionCastOp.cast_one(
        op.acc, x86_vect_type
    )
    # Instruction selection
    element_size = cast(FixedBitwidthType, vect_type.get_element_type()).bitwidth
    match element_size:
        case 16:
            raise DiagnosticException(
                "Half-precision vector load is not implemented yet."
            )
        case 32:
            fma = x86.ops.RSS_Vfmadd231psOp
        case 64:
            fma = x86.ops.RSS_Vfmadd231pdOp
        case _:
            raise DiagnosticException(
                "Float precision must be half, single or double."
            )
    fma_op = fma(acc_new, lhs_new, rhs_new)

    res_cast_op = UnrealizedConversionCastOp.get(
        (fma_op.register_out,), (vect_type,)
    )
    rewriter.replace_op(
        op, [lhs_cast_op, rhs_cast_op, acc_cast_op, fma_op, res_cast_op]
    )

ConvertVectorToX86Pass dataclass

Bases: ModulePass

Source code in xdsl/backend/x86/lowering/convert_vector_to_x86.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
@dataclass(frozen=True)
class ConvertVectorToX86Pass(ModulePass):
    """
    Module pass that applies the vector-to-x86 rewrite patterns
    (`VectorFMAToX86` and `VectorBroadcastToX86`) to a module.
    """

    name = "convert-vector-to-x86"

    # Architecture name; resolved through `Arch.arch_for_name` in `apply`.
    arch: str

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """Rewrite `op` in place for the architecture named by `self.arch`."""
        target_arch = Arch.arch_for_name(self.arch)

        # Both patterns share the same resolved architecture.
        patterns: list[RewritePattern] = [
            VectorFMAToX86(target_arch),
            VectorBroadcastToX86(target_arch),
        ]
        applier = GreedyRewritePatternApplier(patterns, dce_enabled=False)

        walker = PatternRewriteWalker(applier, apply_recursively=False)
        walker.rewrite_module(op)

name = 'convert-vector-to-x86' class-attribute instance-attribute

arch: str instance-attribute

__init__(arch: str) -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/backend/x86/lowering/convert_vector_to_x86.py
111
112
113
114
115
116
117
118
119
120
121
122
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Rewrite `op` in place for the architecture named by `self.arch`."""
    target_arch = Arch.arch_for_name(self.arch)

    # Both patterns share the same resolved architecture.
    patterns: list[RewritePattern] = [
        VectorFMAToX86(target_arch),
        VectorBroadcastToX86(target_arch),
    ]
    applier = GreedyRewritePatternApplier(patterns, dce_enabled=False)

    walker = PatternRewriteWalker(applier, apply_recursively=False)
    walker.rewrite_module(op)