25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131 | class Arch(StrEnum):
UNKNOWN = "unknown"
AVX2 = "avx2"
AVX512 = "avx512"
@staticmethod
def arch_for_name(name: str | None) -> Arch:
if name is None:
return Arch.UNKNOWN
try:
return _ARCH_BY_NAME[name]
except KeyError:
raise DiagnosticException(f"Unsupported arch {name}")
def _register_type_for_vector_type(
self, value_type: VectorType
) -> type[X86VectorRegisterType]:
"""
Given any vector type, returns the appropriate register type.
The vector type must fit exactly into a full bitwidth vector supported by the
ISA, otherwise a `DiagnosticException` is raised.
"""
vector_num_elements = value_type.element_count()
element_type = cast(FixedBitwidthType, value_type.get_element_type())
element_size = element_type.bitwidth
vector_size = vector_num_elements * element_size
match self, vector_size:
case ((Arch.AVX2 | Arch.AVX512), 256):
return x86.registers.AVX2RegisterType
case Arch.AVX512, 512:
return x86.registers.AVX512RegisterType
case _, 128:
return x86.registers.SSERegisterType
case _:
raise DiagnosticException(
f"The vector size ({vector_size} bits) and target architecture `{self}` are inconsistent."
)
def _scalar_type_for_type(self, value_type: Attribute) -> type[X86RegisterType]:
assert not isinstance(value_type, ShapedType)
if (
(isinstance(value_type, FixedBitwidthType) and value_type.bitwidth <= 64)
or isinstance(value_type, IndexType)
or isinstance(value_type, ptr.PtrType)
):
return x86.registers.GeneralRegisterType
else:
raise DiagnosticException("Not implemented for bitwidth larger than 64.")
@overload
def register_type_for_type(
self, value_type: VectorType
) -> type[X86VectorRegisterType]: ...
@overload
def register_type_for_type(
self, value_type: Attribute
) -> type[X86RegisterType]: ...
def register_type_for_type(self, value_type: Attribute) -> type[X86RegisterType]:
if isinstance(value_type, X86RegisterType):
return type(value_type)
if isa(value_type, VectorType):
return self._register_type_for_vector_type(value_type)
return self._scalar_type_for_type(value_type)
def cast_to_regs(
self, values: Sequence[SSAValue], builder: Builder
) -> list[SSAValue]:
return cast_to_regs(values, self.register_type_for_type, builder)
@deprecated("Please use `arch.cast_to_regs(values, rewriter)`")
def cast_operands_to_regs(self, rewriter: PatternRewriter) -> list[SSAValue]:
new_operands = self.cast_to_regs(rewriter.current_operation.operands, rewriter)
return new_operands
def move_value_to_unallocated(
self, value: SSAValue, value_type: Attribute, builder: Builder
) -> SSAValue:
if isa(value_type, VectorType[FixedBitwidthType]):
if not isinstance(reg_type := value.type, X86VectorRegisterType):
raise ValueError(f"Invalid type for move {value_type}")
# Choose the x86 vector instruction according to the
# abstract vector element size
match value_type.get_element_type().bitwidth:
case 16:
raise DiagnosticException(
"Half-precision floating point vector move is not implemented yet."
)
case 32:
raise DiagnosticException(
"Half-precision floating point vector move is not implemented yet."
)
case 64:
mov_op = x86.ops.DS_VmovapdOp(
value, destination=type(reg_type).unallocated()
)
case _:
raise DiagnosticException(
"Float precision must be half, single or double."
)
else:
if not isinstance(reg_type := value.type, X86RegisterType):
raise ValueError(f"Invalid type for move {value_type}")
mov_op = x86.DS_MovOp(value, destination=type(reg_type).unallocated())
return builder.insert_op(mov_op).results[0]
|