Skip to content

Memref

memref

MemRefFunctions dataclass

Bases: InterpreterFunctions

Source code in xdsl/interpreters/memref.py (lines 19–86)
@register_impls
class MemRefFunctions(InterpreterFunctions):
    """Interpreter implementations for `memref` dialect operations.

    MemRef values are represented as `ShapedArray` instances wrapping a
    `TypedPtr` buffer together with the memref's shape.
    """

    @impl(memref.AllocOp)
    def run_alloc(
        self, interpreter: Interpreter, op: memref.AllocOp, args: PythonValues
    ) -> PythonValues:
        """Allocate a zero-filled buffer for the result memref type.

        Returns:
            A one-element tuple holding the new `ShapedArray`.

        Raises:
            NotImplementedError: if the memref type has a dynamic (`?`)
                dimension, which xDSL reports as a negative extent.
        """
        memref_type = op.memref.type

        shape = memref_type.get_shape()
        # A negative extent marks a dynamic dimension. Left unchecked,
        # `prod(shape)` would be negative, or a wrong positive size when two
        # such extents cancel, silently producing a mis-sized buffer.
        if any(dim < 0 for dim in shape):
            raise NotImplementedError(
                "MemRefs with dynamic dimensions are not implemented"
            )
        size = prod(shape)
        xtype = xtype_for_el_type(
            memref_type.get_element_type(), interpreter.index_bitwidth
        )

        shaped_array = ShapedArray(TypedPtr[Any].zeros(size, xtype=xtype), list(shape))
        return (shaped_array,)

    @impl(memref.DeallocOp)
    def run_dealloc(
        self, interpreter: Interpreter, op: memref.DeallocOp, args: PythonValues
    ) -> PythonValues:
        """Handle `memref.dealloc`: performs no action and yields no results."""
        return ()

    @impl(memref.StoreOp)
    def run_store(
        self, interpreter: Interpreter, op: memref.StoreOp, args: PythonValues
    ) -> PythonValues:
        """Store a value into a memref.

        `args` is unpacked as `(value, memref, *indices)`; the value is
        written at the position given by `indices`. Produces no results.
        """
        # The local is called `shaped_array`, not `memref`, so it does not
        # shadow the `memref` dialect module (and matches `run_load`).
        value, shaped_array, *indices = args

        shaped_array = cast(ShapedArray[Any], shaped_array)

        shaped_array.store(tuple(indices), value)

        return ()

    @impl(memref.LoadOp)
    def run_load(
        self, interpreter: Interpreter, op: memref.LoadOp, args: PythonValues
    ) -> PythonValues:
        """Load a value from a memref.

        `args` is unpacked as `(memref, *indices)`.

        Returns:
            A one-element tuple holding the value read at `indices`.
        """
        shaped_array, *indices = args

        shaped_array = cast(ShapedArray[Any], shaped_array)

        value = shaped_array.load(tuple(indices))

        return (value,)

    @impl(memref.GetGlobalOp)
    def run_get_global(
        self, interpreter: Interpreter, op: memref.GetGlobalOp, args: PythonValues
    ) -> PythonValues:
        """Materialise the initial value of the referenced `memref.global`.

        The symbol `op.name_` is looked up from `op` via the symbol table.

        Returns:
            A one-element tuple holding a `ShapedArray` built from the global's
            dense initial value.

        Raises:
            NotImplementedError: if the initial value is not a
                `DenseIntOrFPElementsAttr`.
        """
        mem = SymbolTable.lookup_symbol(op, op.name_)
        assert isinstance(mem, memref.GlobalOp)
        initial_value = mem.initial_value
        if not isa(initial_value, builtin.DenseIntOrFPElementsAttr):
            raise NotImplementedError(
                "MemRefs that are not dense int or float arrays are not implemented"
            )
        data = initial_value.get_values()
        shape = initial_value.get_shape()
        assert shape is not None
        xtype = xtype_for_el_type(
            initial_value.get_element_type(), interpreter.index_bitwidth
        )
        shaped_array = ShapedArray(TypedPtr[Any].new(data, xtype=xtype), list(shape))
        return (shaped_array,)

run_alloc(interpreter: Interpreter, op: memref.AllocOp, args: PythonValues) -> PythonValues

Source code in xdsl/interpreters/memref.py (lines 21–34)
@impl(memref.AllocOp)
def run_alloc(
    self, interpreter: Interpreter, op: memref.AllocOp, args: PythonValues
) -> PythonValues:
    """Allocate a zero-filled buffer for the result memref type.

    Returns:
        A one-element tuple holding the new `ShapedArray`.

    Raises:
        NotImplementedError: if the memref type has a dynamic (`?`)
            dimension, which xDSL reports as a negative extent.
    """
    memref_type = op.memref.type

    shape = memref_type.get_shape()
    # A negative extent marks a dynamic dimension. Left unchecked,
    # `prod(shape)` would be negative, or a wrong positive size when two
    # such extents cancel, silently producing a mis-sized buffer.
    if any(dim < 0 for dim in shape):
        raise NotImplementedError(
            "MemRefs with dynamic dimensions are not implemented"
        )
    size = prod(shape)
    xtype = xtype_for_el_type(
        memref_type.get_element_type(), interpreter.index_bitwidth
    )

    shaped_array = ShapedArray(TypedPtr[Any].zeros(size, xtype=xtype), list(shape))
    return (shaped_array,)

run_dealloc(interpreter: Interpreter, op: memref.DeallocOp, args: PythonValues) -> PythonValues

Source code in xdsl/interpreters/memref.py (lines 36–40)
@impl(memref.DeallocOp)
def run_dealloc(
    self, interpreter: Interpreter, op: memref.DeallocOp, args: PythonValues
) -> PythonValues:
    """Handle `memref.dealloc`: performs no action and yields no results."""
    no_results: PythonValues = tuple()
    return no_results

run_store(interpreter: Interpreter, op: memref.StoreOp, args: PythonValues) -> PythonValues

Source code in xdsl/interpreters/memref.py (lines 42–53)
@impl(memref.StoreOp)
def run_store(
    self, interpreter: Interpreter, op: memref.StoreOp, args: PythonValues
) -> PythonValues:
    """Store a value into a memref.

    `args` is unpacked as `(value, memref, *indices)`; the value is written
    at the position given by `indices`. Produces no results.
    """
    # The local is called `shaped_array`, not `memref`, so it does not
    # shadow the `memref` dialect module (and matches `run_load`).
    value, shaped_array, *indices = args

    shaped_array = cast(ShapedArray[Any], shaped_array)

    shaped_array.store(tuple(indices), value)

    return ()

run_load(interpreter: Interpreter, op: memref.LoadOp, args: tuple[Any, ...])

Source code in xdsl/interpreters/memref.py (lines 55–66)
@impl(memref.LoadOp)
def run_load(
    self, interpreter: Interpreter, op: memref.LoadOp, args: PythonValues
) -> PythonValues:
    """Load a value from a memref.

    `args` is unpacked as `(memref, *indices)`.

    Returns:
        A one-element tuple holding the value read at `indices`.
    """
    shaped_array, *indices = args

    shaped_array = cast(ShapedArray[Any], shaped_array)

    value = shaped_array.load(tuple(indices))

    return (value,)

run_get_global(interpreter: Interpreter, op: memref.GetGlobalOp, args: PythonValues) -> PythonValues

Source code in xdsl/interpreters/memref.py (lines 68–86)
@impl(memref.GetGlobalOp)
def run_get_global(
    self, interpreter: Interpreter, op: memref.GetGlobalOp, args: PythonValues
) -> PythonValues:
    """Materialise the initial value of the referenced `memref.global`.

    The symbol `op.name_` is resolved from `op` through the symbol table, and
    its dense initial value is copied into a fresh typed buffer.

    Returns:
        A one-element tuple holding the resulting `ShapedArray`.

    Raises:
        NotImplementedError: if the initial value is not a
            `DenseIntOrFPElementsAttr`.
    """
    global_op = SymbolTable.lookup_symbol(op, op.name_)
    assert isinstance(global_op, memref.GlobalOp)

    init = global_op.initial_value
    if not isa(init, builtin.DenseIntOrFPElementsAttr):
        raise NotImplementedError(
            "MemRefs that are not dense int or float arrays are not implemented"
        )

    values = init.get_values()
    dims = init.get_shape()
    assert dims is not None

    element_xtype = xtype_for_el_type(
        init.get_element_type(), interpreter.index_bitwidth
    )
    buffer = TypedPtr[Any].new(values, xtype=element_xtype)
    return (ShapedArray(buffer, list(dims)),)