Skip to content

Snitch stream

snitch_stream

StridedPointerInputStream dataclass

Bases: ReadableStream[float]

Source code in xdsl/interpreters/snitch_stream.py (lines 18–27)
@dataclass
class StridedPointerInputStream(ReadableStream[float]):
    """Readable stream of `f64` values gathered from memory.

    Every `read` pulls the next offset from `offset_iter` and loads the `f64`
    located at `pointer + offset`.
    """

    # Supplies the offset added to `pointer` for each successive read.
    offset_iter: Iterator[int]
    # Base pointer of the buffer being read from.
    pointer: ptr.RawPtr
    # Position of the most recent read; -1 until the first read.
    # Deliberately unannotated so it is not a dataclass field.
    index = -1

    def read(self) -> float:
        """Advance the stream and return the next `f64` value."""
        self.index += 1
        address = self.pointer + next(self.offset_iter)
        return ptr.TypedPtr(address, xtype=f64)[0]

offset_iter: Iterator[int] instance-attribute

pointer: ptr.RawPtr instance-attribute

index = -1 class-attribute instance-attribute

__init__(offset_iter: Iterator[int], pointer: ptr.RawPtr) -> None

read() -> float

Source code in xdsl/interpreters/snitch_stream.py (lines 24–27)
def read(self) -> float:
    """Advance the stream and return the `f64` stored at the next offset."""
    self.index += 1
    address = self.pointer + next(self.offset_iter)
    return ptr.TypedPtr(address, xtype=f64)[0]

StridedPointerOutputStream dataclass

Bases: WritableStream[float]

Source code in xdsl/interpreters/snitch_stream.py (lines 30–39)
@dataclass
class StridedPointerOutputStream(WritableStream[float]):
    """Writable stream that scatters `f64` values into memory.

    Every `write` pulls the next offset from `offset_iter` and stores the
    value as an `f64` at `pointer + offset`.
    """

    # Supplies the offset added to `pointer` for each successive write.
    offset_iter: Iterator[int]
    # Base pointer of the buffer being written to.
    pointer: ptr.RawPtr
    # Position of the most recent write; -1 until the first write.
    # Deliberately unannotated so it is not a dataclass field.
    index = -1

    def write(self, value: float) -> None:
        """Advance the stream and store `value` at the next offset."""
        self.index += 1
        address = self.pointer + next(self.offset_iter)
        ptr.TypedPtr(address, xtype=f64)[0] = value

index = -1 class-attribute instance-attribute

offset_iter: Iterator[int] instance-attribute

pointer: ptr.RawPtr instance-attribute

__init__(offset_iter: Iterator[int], pointer: ptr.RawPtr) -> None

write(value: float) -> None

Source code in xdsl/interpreters/snitch_stream.py (lines 36–39)
def write(self, value: float) -> None:
    """Advance the stream and store `value` as an `f64` at the next offset."""
    self.index += 1
    address = self.pointer + next(self.offset_iter)
    ptr.TypedPtr(address, xtype=f64)[0] = value

SnitchStreamFunctions dataclass

Bases: InterpreterFunctions

Source code in xdsl/interpreters/snitch_stream.py (lines 42–80)
@register_impls
class SnitchStreamFunctions(InterpreterFunctions):
    """Interpreter implementations for `snitch_stream` operations."""

    @impl(snitch_stream.StreamingRegionOp)
    def run_streaming_region(
        self,
        interpreter: Interpreter,
        op: snitch_stream.StreamingRegionOp,
        args: tuple[Any, ...],
    ) -> PythonValues:
        """Run a streaming region's body with strided pointer streams.

        `args` starts with one pointer per input, followed by one pointer per
        output. Each pointer is paired with a stride pattern. If the op
        carries exactly one pattern, every stream shares it. Otherwise the
        patterns are taken in order: inputs first, then outputs.

        The body is run via `run_ssacfg_region`, with the input streams
        followed by the output streams as its arguments.

        Returns:
            An empty tuple; no values are passed back to the interpreter.

        Raises:
            ValueError: if the number of stride patterns does not match the
                number of pointers (raised by the strict `zip`).
        """
        input_stream_count = len(op.inputs)
        output_stream_count = len(op.outputs)
        input_pointers: tuple[ptr.RawPtr, ...] = args[:input_stream_count]
        output_pointers: tuple[ptr.RawPtr, ...] = args[
            input_stream_count : input_stream_count + output_stream_count
        ]

        if len(op.stride_patterns) == 1:
            # A single pattern is shared by all input and output streams.
            pattern = op.stride_patterns.data[0]
            input_stride_patterns = (pattern,) * input_stream_count
            output_stride_patterns = (pattern,) * output_stream_count
        else:
            input_stride_patterns = op.stride_patterns.data[:input_stream_count]
            output_stride_patterns = op.stride_patterns.data[input_stream_count:]

        # Loop variable is `pointer`, not `ptr`, to avoid shadowing the `ptr`
        # module used elsewhere in this function.
        input_streams = tuple(
            StridedPointerInputStream(pat.offset_iter(), pointer)
            for pat, pointer in zip(
                input_stride_patterns, input_pointers, strict=True
            )
        )

        output_streams = tuple(
            StridedPointerOutputStream(pat.offset_iter(), pointer)
            for pat, pointer in zip(
                output_stride_patterns, output_pointers, strict=True
            )
        )

        interpreter.run_ssacfg_region(
            op.body, (*input_streams, *output_streams), "streaming_region"
        )

        return ()

run_streaming_region(interpreter: Interpreter, op: snitch_stream.StreamingRegionOp, args: tuple[Any, ...]) -> PythonValues

Source code in xdsl/interpreters/snitch_stream.py (lines 44–80)
@impl(snitch_stream.StreamingRegionOp)
def run_streaming_region(
    self,
    interpreter: Interpreter,
    op: snitch_stream.StreamingRegionOp,
    args: tuple[Any, ...],
) -> PythonValues:
    """Run a streaming region's body with strided pointer streams.

    `args` starts with one pointer per input, followed by one pointer per
    output. Each pointer is paired with a stride pattern. If the op carries
    exactly one pattern, every stream shares it. Otherwise the patterns are
    taken in order: inputs first, then outputs.

    The body is run via `run_ssacfg_region`, with the input streams followed
    by the output streams as its arguments.

    Returns:
        An empty tuple; no values are passed back to the interpreter.

    Raises:
        ValueError: if the number of stride patterns does not match the
            number of pointers (raised by the strict `zip`).
    """
    input_stream_count = len(op.inputs)
    output_stream_count = len(op.outputs)
    input_pointers: tuple[ptr.RawPtr, ...] = args[:input_stream_count]
    output_pointers: tuple[ptr.RawPtr, ...] = args[
        input_stream_count : input_stream_count + output_stream_count
    ]

    if len(op.stride_patterns) == 1:
        # A single pattern is shared by all input and output streams.
        pattern = op.stride_patterns.data[0]
        input_stride_patterns = (pattern,) * input_stream_count
        output_stride_patterns = (pattern,) * output_stream_count
    else:
        input_stride_patterns = op.stride_patterns.data[:input_stream_count]
        output_stride_patterns = op.stride_patterns.data[input_stream_count:]

    # Loop variable is `pointer`, not `ptr`, to avoid shadowing the `ptr`
    # module used elsewhere in this function.
    input_streams = tuple(
        StridedPointerInputStream(pat.offset_iter(), pointer)
        for pat, pointer in zip(input_stride_patterns, input_pointers, strict=True)
    )

    output_streams = tuple(
        StridedPointerOutputStream(pat.offset_iter(), pointer)
        for pat, pointer in zip(
            output_stride_patterns, output_pointers, strict=True
        )
    )

    interpreter.run_ssacfg_region(
        op.body, (*input_streams, *output_streams), "streaming_region"
    )

    return ()