Bases: `InterpreterFunctions`
Source code in `xdsl/interpreters/ml_program.py` (lines 18–37)
@register_impls
class MLProgramFunctions(InterpreterFunctions):
    """Interpreter function implementations for ``ml_program`` dialect operations."""

    @impl(ml_program.GlobalLoadConstantOp)
    def run_global_load_constant(
        self,
        interpreter: Interpreter,
        op: ml_program.GlobalLoadConstantOp,
        args: tuple[Any, ...],
    ) -> tuple[Any, ...]:
        """
        Produce the constant stored in the global that ``op`` refers to.

        The symbol named by ``op.global_attr`` is resolved relative to ``op``
        with ``SymbolTable.lookup_symbol``. It must be an ``ml_program.GlobalOp``
        whose ``value`` is a ``DenseIntOrFPElementsAttr``. The attribute's values
        are packed into a new ``TypedPtr``. Its xtype comes from the attribute's
        element type and the interpreter's ``index_bitwidth``. The pointer is
        then wrapped in a ``ShapedArray`` carrying the attribute's shape.

        Args:
            interpreter: The running interpreter; only ``index_bitwidth`` is read.
            op: The ``GlobalLoadConstantOp`` being interpreted.
            args: Operand values; not used by this implementation.

        Returns:
            A one-element tuple holding the resulting ``ShapedArray``.
        """
        referenced = SymbolTable.lookup_symbol(op, op.global_attr)
        # Both type checks are plain asserts, so they vanish under ``python -O``.
        assert isinstance(referenced, ml_program.GlobalOp)
        initializer = referenced.value
        assert isa(initializer, DenseIntOrFPElementsAttr)

        dims = initializer.get_shape()
        element_xtype = xtype_for_el_type(
            initializer.get_element_type(), interpreter.index_bitwidth
        )
        buffer = TypedPtr[Any].new(initializer.get_values(), xtype=element_xtype)
        return (ShapedArray(buffer, [*dims]),)
`run_global_load_constant(interpreter: Interpreter, op: ml_program.GlobalLoadConstantOp, args: tuple[Any, ...]) -> tuple[Any, ...]`
Source code in `xdsl/interpreters/ml_program.py` (lines 20–37)
@impl(ml_program.GlobalLoadConstantOp)
def run_global_load_constant(
    self,
    interpreter: Interpreter,
    op: ml_program.GlobalLoadConstantOp,
    args: tuple[Any, ...],
) -> tuple[Any, ...]:
    """
    Produce the constant stored in the global that ``op`` refers to.

    The symbol named by ``op.global_attr`` is resolved relative to ``op`` with
    ``SymbolTable.lookup_symbol``. It must be an ``ml_program.GlobalOp`` whose
    ``value`` is a ``DenseIntOrFPElementsAttr``. The attribute's values are
    packed into a new ``TypedPtr``. Its xtype comes from the attribute's element
    type and the interpreter's ``index_bitwidth``. The pointer is then wrapped
    in a ``ShapedArray`` carrying the attribute's shape.

    Args:
        interpreter: The running interpreter; only ``index_bitwidth`` is read.
        op: The ``GlobalLoadConstantOp`` being interpreted.
        args: Operand values; not used by this implementation.

    Returns:
        A one-element tuple holding the resulting ``ShapedArray``.
    """
    referenced = SymbolTable.lookup_symbol(op, op.global_attr)
    # Both type checks are plain asserts, so they vanish under ``python -O``.
    assert isinstance(referenced, ml_program.GlobalOp)
    initializer = referenced.value
    assert isa(initializer, DenseIntOrFPElementsAttr)

    dims = initializer.get_shape()
    element_xtype = xtype_for_el_type(
        initializer.get_element_type(), interpreter.index_bitwidth
    )
    buffer = TypedPtr[Any].new(initializer.get_values(), xtype=element_xtype)
    return (ShapedArray(buffer, [*dims]),)
|