85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
@dataclass(frozen=True)
class PipelineGenericPattern(RewritePattern):
    """
    Interleaves one parallel dimension of a `memref_stream.generic` operation that
    has at least one reduction iterator.

    The rewrite:
    - divides the bound of the chosen dimension by the interleave factor;
    - appends a new trailing `interleaved` iterator whose bound is the factor;
    - rewrites every indexing map so that the chosen dimension `d` becomes
      `d * factor + d_new`;
    - replicates the body `factor` times.

    Operations that already have an interleaved iterator, or that have no
    reduction iterator, are left untouched.
    """

    # Forwarded to `IndexAndFactor.choose` when the dimension and factor are
    # selected automatically.
    pipeline_depth: int
    # Optional explicit dimension to interleave. Must be set together with
    # `unroll_factor` (both set, or both None).
    iterator_index: int | None = field(default=None)
    # Optional explicit interleave factor. Must be set together with
    # `iterator_index`.
    unroll_factor: int | None = field(default=None)

    @staticmethod
    def indices_and_factors(
        op: memref_stream.GenericOp,
    ) -> tuple[IndexAndFactor, ...]:
        """
        Given a `memref_stream.generic` operation, returns all the possible options for
        unrolling.

        Each option pairs the index of a parallel iterator with one of the values
        returned by `factors(bound)` for that iterator's bound. Returns an empty
        tuple if the operation is already interleaved.
        """
        if memref_stream.IteratorTypeAttr.interleaved() in op.iterator_types:
            # Already interleaved
            return ()
        parallel_indices = tuple(
            index
            for index, iterator_type in enumerate(op.iterator_types)
            if iterator_type == memref_stream.IteratorTypeAttr.parallel()
        )
        parallel_bounds = tuple(
            op.bounds.data[index].value.data for index in parallel_indices
        )
        return tuple(
            IndexAndFactor(index, factor)
            for index, bound in zip(parallel_indices, parallel_bounds)
            for factor in factors(bound)
        )

    @staticmethod
    def _is_valid_interleaving(
        op: memref_stream.GenericOp, index: int, factor: int
    ) -> bool:
        """
        Returns whether dimension `index` of `op` can be interleaved by `factor`.

        The dimension must exist, must be a parallel iterator, and its bound must
        be divisible by `factor`. Otherwise the rewrite would drop iterations or
        produce incorrect indexing maps.
        """
        if not 0 <= index < len(op.iterator_types.data):
            return False
        if op.iterator_types.data[index] != memref_stream.IteratorTypeAttr.parallel():
            return False
        return op.bounds.data[index].value.data % factor == 0

    @staticmethod
    def _interleaved_indexing_map(
        map_attr: AffineMapAttr, index: int, factor: int
    ) -> AffineMapAttr:
        """
        Returns `map_attr` with one extra dimension appended.

        Dimension `index` is replaced by `d_index * factor + d_new`, where `d_new`
        is the newly appended last dimension. All other dimensions are unchanged.
        """
        affine_map = map_attr.data
        num_dims = affine_map.num_dims
        # Exactly one replacement per existing dimension.
        dim_replacements = (
            *(AffineExpr.dimension(i) for i in range(index)),
            AffineExpr.dimension(index) * factor + AffineExpr.dimension(num_dims),
            *(AffineExpr.dimension(i) for i in range(index + 1, num_dims)),
        )
        # NOTE(review): symbols are replaced with nothing and the new map has 0
        # symbols, as before; this assumes the incoming maps have no symbols.
        return AffineMapAttr(
            affine_map.replace_dims_and_symbols(
                dim_replacements, (), num_dims + 1, 0
            )
        )

    @staticmethod
    def _interleaved_region(old_block: Block, factor: int) -> Region:
        """
        Builds a single-block region that replicates `old_block` `factor` times.

        Block arguments are laid out argument-major: all `factor` copies of the
        first old argument, then all copies of the second, and so on. Each
        non-yield operation is cloned once per replica. The single yield returns
        the values of replica 0, then replica 1, and so on.
        """
        new_region = Region(
            Block(
                arg_types=(
                    arg_type
                    for arg in old_block.args
                    for arg_type in repeat(arg.type, factor)
                )
            )
        )
        with ImplicitBuilder(new_region) as new_args:
            # For each replica, a mapping from old values to new values.
            value_maps: tuple[dict[SSAValue, SSAValue], ...] = tuple(
                {} for _ in range(factor)
            )
            for arg_index, new_arg in enumerate(new_args):
                old_arg = old_block.args[arg_index // factor]
                value_maps[arg_index % factor][old_arg] = new_arg
                new_arg.name_hint = old_arg.name_hint
            for block_op in old_block.ops:
                if isinstance(block_op, memref_stream.YieldOp):
                    # Values defined outside the body have no mapping. Keep them
                    # as-is, mirroring how `clone(value_mapper=...)` treats
                    # unmapped operands.
                    memref_stream.YieldOp(
                        *(
                            value_map.get(value, value)
                            for value_map in value_maps
                            for value in block_op.arguments
                        )
                    )
                else:
                    for value_map in value_maps:
                        block_op.clone(value_mapper=value_map)
        return new_region

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: memref_stream.GenericOp, rewriter: PatternRewriter
    ) -> None:
        """
        Replaces `op` with an interleaved copy.

        If `iterator_index` and `unroll_factor` are both set, they select the
        dimension and factor. Otherwise both are chosen from
        `indices_and_factors(op)` via `IndexAndFactor.choose`.

        No rewrite happens if:
        - the op is already interleaved;
        - the op has no reduction;
        - the factor is 1 or less;
        - the chosen dimension is not a parallel dimension whose bound is
          divisible by the factor.
        """
        if memref_stream.IteratorTypeAttr.interleaved() in op.iterator_types:
            # Already interleaved
            return
        if memref_stream.IteratorTypeAttr.reduction() not in op.iterator_types:
            # No reduction
            return

        assert (self.iterator_index is None) == (self.unroll_factor is None)
        if self.iterator_index is not None and self.unroll_factor is not None:
            interleave_bound_index = self.iterator_index
            interleave_factor = self.unroll_factor
        else:
            indices_and_factors = self.indices_and_factors(op)
            if not indices_and_factors:
                return
            choice = IndexAndFactor.choose(indices_and_factors, self.pipeline_depth)
            if choice is None:
                return
            interleave_bound_index, interleave_factor = choice

        if interleave_factor <= 1:
            # A factor of 1 makes the rewrite a no-op; smaller values are invalid.
            return
        if not self._is_valid_interleaving(
            op, interleave_bound_index, interleave_factor
        ):
            # Explicitly requested options may not be valid for this op; bail out
            # rather than emit an incorrect rewrite.
            return

        new_region = self._interleaved_region(op.body.block, interleave_factor)

        new_indexing_maps = ArrayAttr(
            self._interleaved_indexing_map(m, interleave_bound_index, interleave_factor)
            for m in op.indexing_maps
        )

        # Same bounds, except the interleaved dimension shrinks by the factor and
        # one new bound (the factor itself) is appended.
        interleave_bound = op.bounds.data[interleave_bound_index].value.data
        new_bounds = list(op.bounds)
        new_bounds[interleave_bound_index] = IntegerAttr.from_index_int_value(
            interleave_bound // interleave_factor
        )
        new_bounds.append(IntegerAttr.from_index_int_value(interleave_factor))

        rewriter.replace_op(
            op,
            memref_stream.GenericOp(
                op.inputs,
                op.outputs,
                op.inits,
                new_region,
                new_indexing_maps,
                ArrayAttr(
                    op.iterator_types.data
                    + (memref_stream.IteratorTypeAttr.interleaved(),)
                ),
                ArrayAttr(new_bounds),
                op.init_indices,
                op.doc,
                op.library_call,
            ),
        )
|