Skip to content

Tensor

tensor

T = TypeVar('T') module-attribute

TensorFunctions dataclass

Bases: InterpreterFunctions

Source code in xdsl/interpreters/tensor.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
@register_impls
class TensorFunctions(InterpreterFunctions):
    """Interpreter implementations for operations of the `tensor` dialect.

    Tensor operands and results are handled at runtime as `ShapedArray`
    values; indices and dimension sizes are plain ints.
    """

    @impl(tensor.EmptyOp)
    def run_empty(
        self, interpreter: Interpreter, op: tensor.EmptyOp, args: tuple[int, ...]
    ) -> tuple[ShapedArray[Any]]:
        """Allocate a zero-filled array for `tensor.empty`.

        Args:
            interpreter: The running interpreter; its `index_bitwidth` is passed
                to `xtype_for_el_type` to choose the storage type.
            op: The `tensor.empty` op; its result type supplies the shape and
                element type.
            args: One runtime size per `DYNAMIC_INDEX` dimension of the result
                type, consumed in dimension order.

        Returns:
            A 1-tuple holding the new `ShapedArray`.
        """
        result_type = op.tensor.type
        assert isinstance(result_type, TensorType)

        # Replace each dynamic extent, in order, by the next runtime operand.
        dynamic_dims = iter(args)
        result_shape = [
            next(dynamic_dims) if dim == DYNAMIC_INDEX else dim
            for dim in result_type.get_shape()
        ]
        # Every dynamic-size operand must have been consumed.
        assert next(dynamic_dims, None) is None

        xtype = xtype_for_el_type(
            result_type.element_type, interpreter.index_bitwidth
        )
        return (
            ShapedArray(
                TypedPtr[Any].new((0,) * math.prod(result_shape), xtype=xtype),
                result_shape,
            ),
        )

    @impl(tensor.ReshapeOp)
    def run_reshape(
        self,
        interpreter: Interpreter,
        op: tensor.ReshapeOp,
        args: tuple[ShapedArray[T], ShapedArray[int]],
    ) -> tuple[ShapedArray[T]]:
        """Reinterpret the source array with the shape held in the shape operand.

        The result reuses the source's `data_ptr`, so no element data is copied.

        Args:
            interpreter: The running interpreter (unused here).
            op: The `tensor.reshape` op; its result type fixes the rank and any
                static extents.
            args: `(source, new_shape)`, where `new_shape.data` lists one extent
                per result dimension.

        Returns:
            A 1-tuple holding the reshaped `ShapedArray`.

        Raises:
            InterpretationError: If `new_shape` disagrees with the rank or with a
                static extent of the result type, or if it implies a different
                element count than the source has.
        """
        source, new_shape = args
        assert isinstance(source, ShapedArray)
        result_type = op.result.type
        assert isinstance(result_type, TensorType)
        static_shape = list(result_type.get_shape())
        # Fresh list so the result's shape is independent of the operand object.
        requested_shape = list(new_shape.data)

        # A DYNAMIC_INDEX extent in the result type takes its value from
        # `new_shape`; every static extent must match exactly.
        if len(static_shape) != len(requested_shape) or any(
            expected != DYNAMIC_INDEX and expected != actual
            for expected, actual in zip(static_shape, requested_shape)
        ):
            raise InterpretationError("Mismatch between static shape and new shape")
        # The result aliases the source storage, so the sizes must agree.
        if math.prod(requested_shape) != math.prod(source.shape):
            raise InterpretationError(
                "Mismatch between source element count and new shape"
            )
        return (ShapedArray(source.data_ptr, requested_shape),)

    @impl(tensor.InsertOp)
    def run_insert(
        self,
        interpreter: Interpreter,
        op: tensor.InsertOp,
        args: tuple[T | ShapedArray[T] | int, ...],
    ) -> tuple[ShapedArray[T]]:
        """Return a copy of the destination with one element overwritten.

        Args:
            interpreter: The running interpreter (unused here).
            op: The `tensor.insert` op (unused here).
            args: `(value, dest, *indices)` with one index per dimension of
                `dest`.

        Returns:
            A 1-tuple holding the updated copy; `dest` itself is not mutated.
        """
        value = cast(T, args[0])
        dest = cast(ShapedArray[T], args[1])
        indices = cast(Sequence[int], args[2:])

        assert isinstance(dest, ShapedArray)
        assert len(indices) == len(dest.shape)

        # Store into a copy so the input operand is left unchanged.
        result = dest.copy()
        result.store(indices, value)

        return (result,)

    @impl(tensor.ExtractOp)
    def run_extract(
        self,
        interpreter: Interpreter,
        op: tensor.ExtractOp,
        args: tuple[ShapedArray[T] | int, ...],
    ) -> tuple[T]:
        """Read a single element from a tensor.

        Args:
            interpreter: The running interpreter (unused here).
            op: The `tensor.extract` op (unused here).
            args: `(source, *indices)` with one index per dimension of `source`.

        Returns:
            A 1-tuple holding the element at `indices`.
        """
        # Named `source` so the `tensor` dialect module is not shadowed.
        source = cast(ShapedArray[T], args[0])
        indices = cast(Sequence[int], args[1:])

        assert isinstance(source, ShapedArray)
        assert len(indices) == len(source.shape)

        return (source.load(indices),)

    @impl(tensor.DimOp)
    def run_dim(
        self,
        interpreter: Interpreter,
        op: tensor.DimOp,
        args: tuple[ShapedArray[T] | int, ...],
    ) -> tuple[int]:
        """Return the runtime size of one dimension of a tensor.

        Args:
            interpreter: The running interpreter (unused here).
            op: The `tensor.dim` op (unused here).
            args: `(source, dim)` where `dim` indexes into `source.shape`.

        Returns:
            A 1-tuple holding `source.shape[dim]`.
        """
        # Named `source` so the `tensor` dialect module is not shadowed.
        source = cast(ShapedArray[T], args[0])
        dim = cast(int, args[1])

        assert isinstance(source, ShapedArray)

        return (source.shape[dim],)

run_empty(interpreter: Interpreter, op: tensor.EmptyOp, args: tuple[int, ...]) -> tuple[ShapedArray[Any]]

Source code in xdsl/interpreters/tensor.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
@impl(tensor.EmptyOp)
def run_empty(
    self, interpreter: Interpreter, op: tensor.EmptyOp, args: tuple[int, ...]
) -> tuple[ShapedArray[Any]]:
    """Allocate a zero-filled array for `tensor.empty`.

    Each `DYNAMIC_INDEX` extent of the result type is filled, in dimension
    order, from `args`; all of `args` must be consumed.

    Returns:
        A 1-tuple holding the new `ShapedArray`.
    """
    tensor_type = op.tensor.type
    assert isinstance(tensor_type, TensorType)

    remaining_sizes = iter(args)
    shape = [
        next(remaining_sizes) if extent == DYNAMIC_INDEX else extent
        for extent in tensor_type.get_shape()
    ]
    # No surplus dynamic-size operands are allowed.
    assert next(remaining_sizes, None) is None

    element_xtype = xtype_for_el_type(
        tensor_type.element_type, interpreter.index_bitwidth
    )
    zeros = (0,) * math.prod(shape)
    storage = TypedPtr[Any].new(zeros, xtype=element_xtype)
    return (ShapedArray(storage, shape),)

run_reshape(interpreter: Interpreter, op: tensor.ReshapeOp, args: tuple[ShapedArray[T], ShapedArray[int]]) -> tuple[ShapedArray[T]]

Source code in xdsl/interpreters/tensor.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
@impl(tensor.ReshapeOp)
def run_reshape(
    self,
    interpreter: Interpreter,
    op: tensor.ReshapeOp,
    args: tuple[ShapedArray[T], ShapedArray[int]],
) -> tuple[ShapedArray[T]]:
    """Reinterpret the source array with the shape held in the shape operand.

    The result reuses the source's `data_ptr`, so no element data is copied.

    Args:
        interpreter: The running interpreter (unused here).
        op: The `tensor.reshape` op; its result type fixes the rank and any
            static extents.
        args: `(source, new_shape)`, where `new_shape.data` lists one extent
            per result dimension.

    Returns:
        A 1-tuple holding the reshaped `ShapedArray`.

    Raises:
        InterpretationError: If `new_shape` disagrees with the rank or with a
            static extent of the result type, or if it implies a different
            element count than the source has.
    """
    source, new_shape = args
    assert isinstance(source, ShapedArray)
    result_type = op.result.type
    assert isinstance(result_type, TensorType)
    static_shape = list(result_type.get_shape())
    # Fresh list so the result's shape is independent of the operand object.
    requested_shape = list(new_shape.data)

    # A DYNAMIC_INDEX extent in the result type takes its value from
    # `new_shape`; every static extent must match exactly.
    if len(static_shape) != len(requested_shape) or any(
        expected != DYNAMIC_INDEX and expected != actual
        for expected, actual in zip(static_shape, requested_shape)
    ):
        raise InterpretationError("Mismatch between static shape and new shape")
    # The result aliases the source storage, so the sizes must agree.
    if math.prod(requested_shape) != math.prod(source.shape):
        raise InterpretationError(
            "Mismatch between source element count and new shape"
        )
    return (ShapedArray(source.data_ptr, requested_shape),)

run_insert(interpreter: Interpreter, op: tensor.InsertOp, args: tuple[T | ShapedArray[T] | int, ...]) -> tuple[ShapedArray[T]]

Source code in xdsl/interpreters/tensor.py
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
@impl(tensor.InsertOp)
def run_insert(
    self,
    interpreter: Interpreter,
    op: tensor.InsertOp,
    args: tuple[T | ShapedArray[T] | int, ...],
) -> tuple[ShapedArray[T]]:
    """Return a copy of the destination array with one element replaced.

    `args` is `(value, destination, *position)`, with one index per dimension
    of the destination. The destination operand itself is not mutated.
    """
    scalar = cast(T, args[0])
    destination = cast(ShapedArray[T], args[1])
    position = cast(Sequence[int], args[2:])

    assert isinstance(destination, ShapedArray)
    assert len(position) == len(destination.shape)

    # Write into a copy so the incoming operand stays unchanged.
    updated = destination.copy()
    updated.store(position, scalar)
    return (updated,)

run_extract(interpreter: Interpreter, op: tensor.ExtractOp, args: tuple[ShapedArray[T] | int, ...]) -> tuple[T]

Source code in xdsl/interpreters/tensor.py
84
85
86
87
88
89
90
91
92
93
94
95
96
97
@impl(tensor.ExtractOp)
def run_extract(
    self,
    interpreter: Interpreter,
    op: tensor.ExtractOp,
    args: tuple[ShapedArray[T] | int, ...],
) -> tuple[T]:
    """Read the element at the given position of a tensor.

    `args` is `(source, *position)`, with one index per dimension of `source`.
    Returns a 1-tuple holding the loaded element.
    """
    # Named `source` rather than `tensor` to avoid shadowing the dialect module.
    source = cast(ShapedArray[T], args[0])
    position = cast(Sequence[int], args[1:])

    assert isinstance(source, ShapedArray)
    assert len(position) == len(source.shape)

    element = source.load(position)
    return (element,)

run_dim(interpreter: Interpreter, op: tensor.DimOp, args: tuple[ShapedArray[T] | int, ...]) -> tuple[int]

Source code in xdsl/interpreters/tensor.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
@impl(tensor.DimOp)
def run_dim(
    self,
    interpreter: Interpreter,
    op: tensor.DimOp,
    args: tuple[ShapedArray[T] | int, ...],
) -> tuple[int]:
    """Return the runtime extent of one dimension of a tensor.

    `args` is `(source, axis)`; the result is a 1-tuple holding
    `source.shape[axis]`.
    """
    # Named `source` rather than `tensor` to avoid shadowing the dialect module.
    source = cast(ShapedArray[T], args[0])
    axis = cast(int, args[1])

    assert isinstance(source, ShapedArray)

    extent = source.shape[axis]
    return (extent,)