Skip to content

Memref stream

memref_stream

MemRefStreamFunctions dataclass

Bases: InterpreterFunctions

Source code in xdsl/interpreters/memref_stream.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
@register_impls
class MemRefStreamFunctions(InterpreterFunctions):
    """Interpreter implementations for operations of the `memref_stream` dialect."""

    @impl(memref_stream.GenericOp)
    def run_generic(
        self,
        interpreter: Interpreter,
        op: memref_stream.GenericOp,
        args: tuple[Any, ...],
    ) -> PythonValues:
        """
        Interpret `memref_stream.generic` by running its body over the static
        iteration space reported by `op.get_static_loop_ranges()`.

        `args` holds the operands as `(*inputs, *outputs, *init_values)`.
        Operands that are `ShapedArray`s are loaded through their indexing map;
        any other operand is passed to the body as-is. Values yielded by the
        body are stored back into the output arrays. Returns no values.

        Raises:
            NotImplementedError: if any iterator type is `interleaved`.
        """
        if memref_stream.IteratorTypeAttr.interleaved() in op.iterator_types:
            raise NotImplementedError(
                "Interpreter for interleaved operations not yet implemented"
            )

        n_in = len(op.inputs)
        n_out = len(op.outputs)
        n_operands = n_in + n_out

        outputs: tuple[ShapedArray[int | float], ...] = args[n_in:n_operands]
        init_values: tuple[int | float, ...] = args[n_operands:]

        maps = tuple(attr.data for attr in op.indexing_maps)
        out_maps = maps[n_in:]

        outer_ubs, inner_ubs = op.get_static_loop_ranges()

        # One slot per output; `None` means "start from what the output holds".
        inits: list[None | int | float] = [None for _ in range(n_out)]
        for slot, value in zip(op.init_indices, init_values, strict=True):
            inits[slot.data] = value

        def read(operand: Any, indexing_map: Any, point: tuple[int, ...]) -> Any:
            # Non-array operands are forwarded untouched; arrays are loaded at
            # the index the map produces for this iteration point.
            if not isinstance(operand, ShapedArray):
                return operand
            return cast(ShapedArray[Any], operand).load(indexing_map.eval(point, ()))

        if not inner_ubs:
            # No inner bounds: every operand (including outputs) is read at each
            # point, the body runs once, and its results are stored immediately.
            for point in product(*(range(ub) for ub in outer_ubs)):
                body_args = tuple(
                    read(operand, m, point)
                    for operand, m in zip(args, maps, strict=True)
                )
                results = interpreter.run_ssacfg_region(
                    op.body, body_args, "for_loop"
                )
                for value, m, dest in zip(results, out_maps, outputs, strict=True):
                    dest.store(m.eval(point, ()), value)
            return ()

        in_operands: tuple[ShapedArray[float] | float, ...] = args[:n_in]
        in_maps = maps[:n_in]
        for outer in product(*(range(ub) for ub in outer_ubs)):
            # Carried values start from the init value when one was given,
            # otherwise from the output's current contents.
            carried = tuple(
                dest.load(m.eval(outer, ())) if start is None else start
                for dest, m, start in zip(outputs, out_maps, inits, strict=True)
            )
            for inner in product(*(range(ub) for ub in inner_ubs)):
                point = outer + inner
                body_inputs = tuple(
                    read(operand, m, point)
                    for operand, m in zip(in_operands, in_maps, strict=True)
                )
                # The body's results become the trailing arguments of the next
                # inner iteration.
                carried = interpreter.run_ssacfg_region(
                    op.body, body_inputs + carried, "for_loop"
                )
            # Only the final carried values are written back, once per outer point.
            for value, m, dest in zip(carried, out_maps, outputs, strict=True):
                dest.store(m.eval(outer, ()), value)

        return ()

    @impl_terminator(memref_stream.YieldOp)
    def run_yield(
        self, interpreter: Interpreter, op: memref_stream.YieldOp, args: tuple[Any, ...]
    ):
        """Terminate the body region by handing the yielded operands back."""
        returned = ReturnedValues(args)
        no_results: tuple[()] = ()
        return returned, no_results

run_generic(interpreter: Interpreter, op: memref_stream.GenericOp, args: tuple[Any, ...]) -> PythonValues

Source code in xdsl/interpreters/memref_stream.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
@impl(memref_stream.GenericOp)
def run_generic(
    self,
    interpreter: Interpreter,
    op: memref_stream.GenericOp,
    args: tuple[Any, ...],
) -> PythonValues:
    """
    Interpret `memref_stream.generic` by running its body over the static
    iteration space reported by `op.get_static_loop_ranges()`.

    `args` holds the operands as `(*inputs, *outputs, *init_values)`.
    Operands that are `ShapedArray`s are loaded through their indexing map;
    any other operand is passed to the body as-is. Values yielded by the
    body are stored back into the output arrays. Returns no values.

    Raises:
        NotImplementedError: if any iterator type is `interleaved`.
    """
    if memref_stream.IteratorTypeAttr.interleaved() in op.iterator_types:
        raise NotImplementedError(
            "Interpreter for interleaved operations not yet implemented"
        )

    n_in = len(op.inputs)
    n_out = len(op.outputs)
    n_operands = n_in + n_out

    outputs: tuple[ShapedArray[int | float], ...] = args[n_in:n_operands]
    init_values: tuple[int | float, ...] = args[n_operands:]

    maps = tuple(attr.data for attr in op.indexing_maps)
    out_maps = maps[n_in:]

    outer_ubs, inner_ubs = op.get_static_loop_ranges()

    # One slot per output; `None` means "start from what the output holds".
    inits: list[None | int | float] = [None for _ in range(n_out)]
    for slot, value in zip(op.init_indices, init_values, strict=True):
        inits[slot.data] = value

    def read(operand: Any, indexing_map: Any, point: tuple[int, ...]) -> Any:
        # Non-array operands are forwarded untouched; arrays are loaded at
        # the index the map produces for this iteration point.
        if not isinstance(operand, ShapedArray):
            return operand
        return cast(ShapedArray[Any], operand).load(indexing_map.eval(point, ()))

    if not inner_ubs:
        # No inner bounds: every operand (including outputs) is read at each
        # point, the body runs once, and its results are stored immediately.
        for point in product(*(range(ub) for ub in outer_ubs)):
            body_args = tuple(
                read(operand, m, point)
                for operand, m in zip(args, maps, strict=True)
            )
            results = interpreter.run_ssacfg_region(op.body, body_args, "for_loop")
            for value, m, dest in zip(results, out_maps, outputs, strict=True):
                dest.store(m.eval(point, ()), value)
        return ()

    in_operands: tuple[ShapedArray[float] | float, ...] = args[:n_in]
    in_maps = maps[:n_in]
    for outer in product(*(range(ub) for ub in outer_ubs)):
        # Carried values start from the init value when one was given,
        # otherwise from the output's current contents.
        carried = tuple(
            dest.load(m.eval(outer, ())) if start is None else start
            for dest, m, start in zip(outputs, out_maps, inits, strict=True)
        )
        for inner in product(*(range(ub) for ub in inner_ubs)):
            point = outer + inner
            body_inputs = tuple(
                read(operand, m, point)
                for operand, m in zip(in_operands, in_maps, strict=True)
            )
            # The body's results become the trailing arguments of the next
            # inner iteration.
            carried = interpreter.run_ssacfg_region(
                op.body, body_inputs + carried, "for_loop"
            )
        # Only the final carried values are written back, once per outer point.
        for value, m, dest in zip(carried, out_maps, outputs, strict=True):
            dest.store(m.eval(outer, ()), value)

    return ()

run_yield(interpreter: Interpreter, op: memref_stream.YieldOp, args: tuple[Any, ...])

Source code in xdsl/interpreters/memref_stream.py
111
112
113
114
115
@impl_terminator(memref_stream.YieldOp)
def run_yield(
    self, interpreter: Interpreter, op: memref_stream.YieldOp, args: tuple[Any, ...]
):
    """Terminate the body region by handing the yielded operands back."""
    returned = ReturnedValues(args)
    no_results: tuple[()] = ()
    return returned, no_results