Bases: InterpreterFunctions
Source code in xdsl/interpreters/snitch_stream.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
@register_impls
class SnitchStreamFunctions(InterpreterFunctions):
    """Interpreter implementations for operations of the `snitch_stream` dialect."""

    @impl(snitch_stream.StreamingRegionOp)
    def run_streaming_region(
        self,
        interpreter: Interpreter,
        op: snitch_stream.StreamingRegionOp,
        args: tuple[Any, ...],
    ) -> PythonValues:
        """
        Run the body of a `StreamingRegionOp` over strided pointer streams.

        `args` holds the operand values:
        - the first `len(op.inputs)` entries are the input base pointers;
        - the next `len(op.outputs)` entries are the output base pointers.

        Each pointer is paired with a stride pattern. If `op.stride_patterns`
        holds exactly one pattern, every stream uses it. Otherwise the first
        `len(op.inputs)` patterns go to the inputs and the remaining ones to the
        outputs.

        The body region is then executed with the input streams followed by the
        output streams as its arguments.

        Raises:
            ValueError: if the number of stride patterns assigned to the inputs
                or outputs does not match the number of corresponding pointers.
                This is raised by `zip(..., strict=True)`.

        Returns:
            An empty tuple.
        """
        input_stream_count = len(op.inputs)
        output_stream_count = len(op.outputs)
        input_pointers: tuple[ptr.RawPtr, ...] = args[:input_stream_count]
        output_pointers: tuple[ptr.RawPtr, ...] = args[
            input_stream_count : input_stream_count + output_stream_count
        ]

        if len(op.stride_patterns) == 1:
            # A single pattern is broadcast to every input and output stream.
            pattern = op.stride_patterns.data[0]
            input_stride_patterns = (pattern,) * input_stream_count
            output_stride_patterns = (pattern,) * output_stream_count
        else:
            input_stride_patterns = op.stride_patterns.data[:input_stream_count]
            output_stride_patterns = op.stride_patterns.data[input_stream_count:]

        # `offset_iter()` is called once per stream, so streams that share a
        # pattern each get their own iterator object. The loop variable is named
        # `pointer` (not `ptr`) so it does not shadow the `ptr` module used in the
        # annotations above.
        input_streams = tuple(
            StridedPointerInputStream(pat.offset_iter(), pointer)
            for pat, pointer in zip(input_stride_patterns, input_pointers, strict=True)
        )
        output_streams = tuple(
            StridedPointerOutputStream(pat.offset_iter(), pointer)
            for pat, pointer in zip(
                output_stride_patterns, output_pointers, strict=True
            )
        )

        interpreter.run_ssacfg_region(
            op.body, (*input_streams, *output_streams), "streaming_region"
        )
        return ()
|
run_streaming_region(interpreter: Interpreter, op: snitch_stream.StreamingRegionOp, args: tuple[Any, ...]) -> PythonValues
Source code in xdsl/interpreters/snitch_stream.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
@impl(snitch_stream.StreamingRegionOp)
def run_streaming_region(
    self,
    interpreter: Interpreter,
    op: snitch_stream.StreamingRegionOp,
    args: tuple[Any, ...],
) -> PythonValues:
    """
    Run the body of a `StreamingRegionOp` over strided pointer streams.

    `args` holds the operand values:
    - the first `len(op.inputs)` entries are the input base pointers;
    - the next `len(op.outputs)` entries are the output base pointers.

    Each pointer is paired with a stride pattern. If `op.stride_patterns` holds
    exactly one pattern, every stream uses it. Otherwise the first
    `len(op.inputs)` patterns go to the inputs and the remaining ones to the
    outputs.

    The body region is then executed with the input streams followed by the
    output streams as its arguments.

    Raises:
        ValueError: if the number of stride patterns assigned to the inputs or
            outputs does not match the number of corresponding pointers. This is
            raised by `zip(..., strict=True)`.

    Returns:
        An empty tuple.
    """
    input_stream_count = len(op.inputs)
    output_stream_count = len(op.outputs)
    input_pointers: tuple[ptr.RawPtr, ...] = args[:input_stream_count]
    output_pointers: tuple[ptr.RawPtr, ...] = args[
        input_stream_count : input_stream_count + output_stream_count
    ]

    if len(op.stride_patterns) == 1:
        # A single pattern is broadcast to every input and output stream.
        pattern = op.stride_patterns.data[0]
        input_stride_patterns = (pattern,) * input_stream_count
        output_stride_patterns = (pattern,) * output_stream_count
    else:
        input_stride_patterns = op.stride_patterns.data[:input_stream_count]
        output_stride_patterns = op.stride_patterns.data[input_stream_count:]

    # `offset_iter()` is called once per stream, so streams that share a
    # pattern each get their own iterator object. The loop variable is named
    # `pointer` (not `ptr`) so it does not shadow the `ptr` module used in the
    # annotations above.
    input_streams = tuple(
        StridedPointerInputStream(pat.offset_iter(), pointer)
        for pat, pointer in zip(input_stride_patterns, input_pointers, strict=True)
    )
    output_streams = tuple(
        StridedPointerOutputStream(pat.offset_iter(), pointer)
        for pat, pointer in zip(output_stride_patterns, output_pointers, strict=True)
    )

    interpreter.run_ssacfg_region(
        op.body, (*input_streams, *output_streams), "streaming_region"
    )
    return ()
|