Skip to content

ML program

ml_program

MLProgramFunctions dataclass

Bases: InterpreterFunctions

Source code in xdsl/interpreters/ml_program.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
@register_impls
class MLProgramFunctions(InterpreterFunctions):
    """Interpreter implementations for operations of the `ml_program` dialect."""

    @impl(ml_program.GlobalLoadConstantOp)
    def run_global_load_constant(
        self,
        interpreter: Interpreter,
        op: ml_program.GlobalLoadConstantOp,
        args: tuple[Any, ...],
    ) -> tuple[Any, ...]:
        """
        Interpret a `GlobalLoadConstantOp` by materialising the referenced global.

        The symbol named by `op.global_attr` is looked up through `SymbolTable`
        and must be a `ml_program.GlobalOp` whose value is a
        `DenseIntOrFPElementsAttr`. `args` is not used by this implementation.

        Returns a one-element tuple holding a `ShapedArray` built from the
        attribute's values and shape.
        """
        # Resolve the symbol reference carried by the op to its defining op.
        referenced = SymbolTable.lookup_symbol(op, op.global_attr)
        assert isinstance(referenced, ml_program.GlobalOp)

        dense_attr = referenced.value
        assert isa(dense_attr, DenseIntOrFPElementsAttr)

        dims = list(dense_attr.get_shape())
        # The element type is mapped to an xtype using the interpreter's
        # configured index bitwidth.
        element_xtype = xtype_for_el_type(
            dense_attr.get_element_type(), interpreter.index_bitwidth
        )
        buffer = TypedPtr[Any].new(dense_attr.get_values(), xtype=element_xtype)
        return (ShapedArray(buffer, dims),)

run_global_load_constant(interpreter: Interpreter, op: ml_program.GlobalLoadConstantOp, args: tuple[Any, ...]) -> tuple[Any, ...]

Source code in xdsl/interpreters/ml_program.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
@impl(ml_program.GlobalLoadConstantOp)
def run_global_load_constant(
    self,
    interpreter: Interpreter,
    op: ml_program.GlobalLoadConstantOp,
    args: tuple[Any, ...],
) -> tuple[Any, ...]:
    """
    Load the constant global referenced by `op` as a `ShapedArray`.

    Looks up `op.global_attr` with `SymbolTable.lookup_symbol`, asserts that
    the result is a `ml_program.GlobalOp` carrying a
    `DenseIntOrFPElementsAttr`, and wraps that attribute's values and shape.
    `args` is not used by this implementation.

    Returns a single-element tuple containing the resulting `ShapedArray`.
    """
    target_op = SymbolTable.lookup_symbol(op, op.global_attr)
    assert isinstance(target_op, ml_program.GlobalOp)

    elements = target_op.value
    assert isa(elements, DenseIntOrFPElementsAttr)

    extents = list(elements.get_shape())
    # The xtype depends on the element type and the interpreter's index bitwidth.
    el_xtype = xtype_for_el_type(
        elements.get_element_type(), interpreter.index_bitwidth
    )
    storage = TypedPtr[Any].new(elements.get_values(), xtype=el_xtype)
    result = ShapedArray(storage, extents)
    return (result,)