Skip to content

Memref stream interleave

memref_stream_interleave

IndexAndFactor

Bases: NamedTuple

Helper data structure holding an option for which index of a memref_stream.generic operation to interleave and with which factor.

Source code in xdsl/transforms/memref_stream_interleave.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
class IndexAndFactor(NamedTuple):
    """
    Helper data structure holding an option for which index of a `memref_stream.generic`
    operation to interleave and with which factor.
    """

    # Index of the iterator (loop dimension) to interleave.
    iterator_index: int
    # Unroll factor to apply at that iterator.
    factor: int

    @staticmethod
    def choose(
        indices_and_factors: Sequence[IndexAndFactor], pipeline_depth: int
    ) -> IndexAndFactor | None:
        """
        A heuristic to choose the interleave index and factor automatically given the
        pipeline depth of floating-point operations on the processor.
        The higher the factor chosen, the higher the instruction-level parallelism, but
        also the more registers need to be used at the same time, potentially leading to
        spilling.

        Returns `None` if no option satisfies the heuristic.
        """
        if not indices_and_factors:
            return None
        # Filter for innermost parallel index
        max_index = max(index for index, _ in indices_and_factors)
        indices_and_factors = tuple(
            t for t in indices_and_factors if t.iterator_index == max_index
        )

        # Want the biggest number for maximal instruction-level parallelism, less than
        # 2 * pipeline depth as a heuristic to limit register pressure.
        indices_and_factors = tuple(
            t for t in indices_and_factors if t.factor < pipeline_depth * 2
        )
        if not indices_and_factors:
            return None

        # Greatest factor below double the pipeline depth; a single O(n) max
        # pass instead of sorting just to take the last element.
        return max(indices_and_factors, key=lambda t: t.factor)

iterator_index: int instance-attribute

factor: int instance-attribute

choose(indices_and_factors: Sequence[IndexAndFactor], pipeline_depth: int) -> IndexAndFactor | None staticmethod

A heuristic to choose the interleave index and factor automatically given the pipeline depth of floating-point operations on the processor. The higher the factor chosen, the higher the instruction-level parallelism, but also the more registers need to be used at the same time, potentially leading to spilling.

Source code in xdsl/transforms/memref_stream_interleave.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
@staticmethod
def choose(
    indices_and_factors: Sequence[IndexAndFactor], pipeline_depth: int
) -> IndexAndFactor | None:
    """
    A heuristic to choose the interleave index and factor automatically given the
    pipeline depth of floating-point operations on the processor.
    The higher the factor chosen, the higher the instruction-level parallelism, but
    also the more registers need to be used at the same time, potentially leading to
    spilling.
    """
    if not indices_and_factors:
        return None
    # Filter for innermost parallel index
    max_index = max(index for index, _ in indices_and_factors)
    indices_and_factors = tuple(
        t for t in indices_and_factors if t.iterator_index == max_index
    )

    # Want the biggest number for maximal instruction-level parallelism, less than
    # 2 * pipeline depth as a heuristic to limit register pressure.
    indices_and_factors = tuple(
        t for t in indices_and_factors if t.factor < pipeline_depth * 2
    )
    if not indices_and_factors:
        return None

    sorted_indices_and_factors = sorted(indices_and_factors, key=lambda x: x[1])

    # Greatest number less than double of pipeline depth.
    return sorted_indices_and_factors[-1]

PipelineGenericPattern dataclass

Bases: RewritePattern

Source code in xdsl/transforms/memref_stream_interleave.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
@dataclass(frozen=True)
class PipelineGenericPattern(RewritePattern):
    """
    Rewrites a `memref_stream.generic` with a reduction by interleaving one of
    its parallel dimensions: the body is replicated `interleave_factor` times
    and an extra `interleaved` iterator dimension is appended.
    """

    # Pipeline depth used by the automatic-choice heuristic.
    pipeline_depth: int
    # Explicit iterator to interleave; must be set together with unroll_factor.
    iterator_index: int | None = field(default=None)
    # Explicit interleave factor; must be set together with iterator_index.
    unroll_factor: int | None = field(default=None)

    @staticmethod
    def indices_and_factors(
        op: memref_stream.GenericOp,
    ) -> tuple[IndexAndFactor, ...]:
        """
        Given a `memref_stream.generic` operation, returns all the possible options for
        unrolling: one entry per (parallel iterator index, factor of its bound).
        Returns an empty tuple for ops that are already interleaved.
        """
        if memref_stream.IteratorTypeAttr.interleaved() in op.iterator_types:
            # Already interleaved
            return ()
        parallel_indices = tuple(
            index
            for index, iterator_type in enumerate(op.iterator_types)
            if iterator_type == memref_stream.IteratorTypeAttr.parallel()
        )
        parallel_bounds = tuple(
            op.bounds.data[index].value.data for index in parallel_indices
        )
        # Every divisor of a parallel bound is a legal interleave factor.
        return tuple(
            IndexAndFactor(index, factor)
            for index, bound in zip(parallel_indices, parallel_bounds)
            for factor in factors(bound)
        )

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: memref_stream.GenericOp, rewriter: PatternRewriter
    ) -> None:
        """
        Replace `op` with an equivalent generic op whose body contains
        `interleave_factor` replicas and whose iterator types gain a trailing
        `interleaved` dimension. No-op for already-interleaved ops, ops
        without a reduction, or a factor of 1.
        """
        if memref_stream.IteratorTypeAttr.interleaved() in op.iterator_types:
            # Already interleaved
            return

        if memref_stream.IteratorTypeAttr.reduction() not in op.iterator_types:
            # No reduction
            return

        # iterator_index and unroll_factor must be supplied together (or not at all).
        assert (self.iterator_index is None) == (self.unroll_factor is None)

        if self.iterator_index is not None and self.unroll_factor is not None:
            interleave_bound_index = self.iterator_index
            interleave_factor = self.unroll_factor
        else:
            # Nothing specified: pick an option with the heuristic.
            indices_and_factors = self.indices_and_factors(op)
            if not indices_and_factors:
                return

            t = IndexAndFactor.choose(indices_and_factors, self.pipeline_depth)
            if t is None:
                return

            interleave_bound_index, interleave_factor = t

        if interleave_factor == 1:
            # If unroll factor is 1, rewrite is a no-op
            return

        old_block = op.body.block
        # Each original block argument expands to `interleave_factor`
        # consecutive arguments, one per body replica.
        new_region = Region(
            Block(
                arg_types=(
                    t
                    for arg in old_block.args
                    for t in repeat(arg.type, interleave_factor)
                )
            )
        )
        with ImplicitBuilder(new_region) as args:
            # For each interleaved block replica, a mapping from old values to new values
            value_map: tuple[dict[SSAValue, SSAValue], ...] = tuple(
                {} for _ in range(interleave_factor)
            )
            for arg_index, new_arg in enumerate(args):
                # Replicas of one old arg are contiguous, so the old arg is
                # arg_index // factor and the replica is arg_index % factor.
                old_arg = old_block.args[arg_index // interleave_factor]
                value_map[arg_index % interleave_factor][old_arg] = new_arg
                new_arg.name_hint = old_arg.name_hint
            for block_op in old_block.ops:
                if isinstance(block_op, memref_stream.YieldOp):
                    # Single yield of all replicas' results, grouped by replica.
                    memref_stream.YieldOp(
                        *([vm[arg] for vm in value_map for arg in block_op.arguments])
                    )
                else:
                    # Clone every non-yield op once per replica, remapping operands.
                    for i in range(interleave_factor):
                        block_op.clone(value_mapper=value_map[i])

        # New maps are the same, except that they have one more dimension and the
        # dimension that is interleaved is updated to
        # `dim * interleave_factor + new_dim`.
        # NOTE(review): the replacement tuple below covers indices up to
        # num_dims + 1, more entries than the map has dimensions; the extras
        # appear unused — confirm against replace_dims_and_symbols semantics.
        new_indexing_maps = ArrayAttr(
            AffineMapAttr(
                m.data.replace_dims_and_symbols(
                    (
                        tuple(
                            AffineExpr.dimension(i)
                            for i in range(interleave_bound_index)
                        )
                        + (
                            AffineExpr.dimension(interleave_bound_index)
                            * interleave_factor
                            + AffineExpr.dimension(m.data.num_dims),
                        )
                        + tuple(
                            AffineExpr.dimension(i)
                            for i in range(
                                interleave_bound_index + 1, m.data.num_dims + 2
                            )
                        )
                    ),
                    (),
                    m.data.num_dims + 1,
                    0,
                )
            )
            for m in op.indexing_maps
        )

        # The new bounds are the same, except there is one more bound
        new_bounds = list(op.bounds)
        new_bounds.append(IntegerAttr.from_index_int_value(interleave_factor))
        interleave_bound = op.bounds.data[interleave_bound_index].value.data
        # NOTE(review): floor division assumes the factor evenly divides the
        # bound — guaranteed by factors() for the heuristic path, but confirm
        # for explicitly supplied unroll factors.
        new_bounds[interleave_bound_index] = IntegerAttr.from_index_int_value(
            interleave_bound // interleave_factor
        )

        rewriter.replace_op(
            op,
            memref_stream.GenericOp(
                op.inputs,
                op.outputs,
                op.inits,
                new_region,
                new_indexing_maps,
                ArrayAttr(
                    op.iterator_types.data
                    + (memref_stream.IteratorTypeAttr.interleaved(),)
                ),
                ArrayAttr(new_bounds),
                op.init_indices,
                op.doc,
                op.library_call,
            ),
        )

pipeline_depth: int instance-attribute

iterator_index: int | None = field(default=None) class-attribute instance-attribute

unroll_factor: int | None = field(default=None) class-attribute instance-attribute

__init__(pipeline_depth: int, iterator_index: int | None = None, unroll_factor: int | None = None) -> None

indices_and_factors(op: memref_stream.GenericOp) -> tuple[IndexAndFactor, ...] staticmethod

Given a memref_stream.generic operation, returns all the possible options for unrolling.

Source code in xdsl/transforms/memref_stream_interleave.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
@staticmethod
def indices_and_factors(
    op: memref_stream.GenericOp,
) -> tuple[IndexAndFactor, ...]:
    """
    Enumerate every (parallel iterator index, factor) pair along which the
    given `memref_stream.generic` operation could be unrolled. Returns an
    empty tuple when the op has already been interleaved.
    """
    if memref_stream.IteratorTypeAttr.interleaved() in op.iterator_types:
        # Nothing to enumerate once the op is already interleaved.
        return ()
    parallel = memref_stream.IteratorTypeAttr.parallel()
    options: list[IndexAndFactor] = []
    for index, iterator_type in enumerate(op.iterator_types):
        if iterator_type != parallel:
            continue
        bound = op.bounds.data[index].value.data
        # Every divisor of the bound is a legal unroll factor.
        for factor in factors(bound):
            options.append(IndexAndFactor(index, factor))
    return tuple(options)

match_and_rewrite(op: memref_stream.GenericOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/memref_stream_interleave.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: memref_stream.GenericOp, rewriter: PatternRewriter
) -> None:
    """
    Interleave one parallel dimension of `op`: replicate the body
    `interleave_factor` times and append an `interleaved` iterator dimension.
    No-op for already-interleaved ops, ops without a reduction, or factor 1.
    """
    if memref_stream.IteratorTypeAttr.interleaved() in op.iterator_types:
        # Already interleaved
        return

    if memref_stream.IteratorTypeAttr.reduction() not in op.iterator_types:
        # No reduction
        return

    # iterator_index and unroll_factor must be supplied together (or not at all).
    assert (self.iterator_index is None) == (self.unroll_factor is None)

    if self.iterator_index is not None and self.unroll_factor is not None:
        interleave_bound_index = self.iterator_index
        interleave_factor = self.unroll_factor
    else:
        # Nothing specified: pick an option with the heuristic.
        indices_and_factors = self.indices_and_factors(op)
        if not indices_and_factors:
            return

        t = IndexAndFactor.choose(indices_and_factors, self.pipeline_depth)
        if t is None:
            return

        interleave_bound_index, interleave_factor = t

    if interleave_factor == 1:
        # If unroll factor is 1, rewrite is a no-op
        return

    old_block = op.body.block
    # Each original block argument expands to `interleave_factor` consecutive
    # arguments, one per body replica.
    new_region = Region(
        Block(
            arg_types=(
                t
                for arg in old_block.args
                for t in repeat(arg.type, interleave_factor)
            )
        )
    )
    with ImplicitBuilder(new_region) as args:
        # For each interleaved block replica, a mapping from old values to new values
        value_map: tuple[dict[SSAValue, SSAValue], ...] = tuple(
            {} for _ in range(interleave_factor)
        )
        for arg_index, new_arg in enumerate(args):
            # Replicas of one old arg are contiguous: old arg is
            # arg_index // factor, replica is arg_index % factor.
            old_arg = old_block.args[arg_index // interleave_factor]
            value_map[arg_index % interleave_factor][old_arg] = new_arg
            new_arg.name_hint = old_arg.name_hint
        for block_op in old_block.ops:
            if isinstance(block_op, memref_stream.YieldOp):
                # Single yield of all replicas' results, grouped by replica.
                memref_stream.YieldOp(
                    *([vm[arg] for vm in value_map for arg in block_op.arguments])
                )
            else:
                # Clone every non-yield op once per replica, remapping operands.
                for i in range(interleave_factor):
                    block_op.clone(value_mapper=value_map[i])

    # New maps are the same, except that they have one more dimension and the
    # dimension that is interleaved is updated to
    # `dim * interleave_factor + new_dim`.
    # NOTE(review): the replacement tuple runs to num_dims + 1, more entries
    # than the map has dimensions; the extras appear unused — confirm against
    # replace_dims_and_symbols semantics.
    new_indexing_maps = ArrayAttr(
        AffineMapAttr(
            m.data.replace_dims_and_symbols(
                (
                    tuple(
                        AffineExpr.dimension(i)
                        for i in range(interleave_bound_index)
                    )
                    + (
                        AffineExpr.dimension(interleave_bound_index)
                        * interleave_factor
                        + AffineExpr.dimension(m.data.num_dims),
                    )
                    + tuple(
                        AffineExpr.dimension(i)
                        for i in range(
                            interleave_bound_index + 1, m.data.num_dims + 2
                        )
                    )
                ),
                (),
                m.data.num_dims + 1,
                0,
            )
        )
        for m in op.indexing_maps
    )

    # The new bounds are the same, except there is one more bound
    new_bounds = list(op.bounds)
    new_bounds.append(IntegerAttr.from_index_int_value(interleave_factor))
    interleave_bound = op.bounds.data[interleave_bound_index].value.data
    # NOTE(review): floor division assumes the factor divides the bound
    # evenly — guaranteed when the factor comes from factors(); confirm for
    # explicitly supplied unroll factors.
    new_bounds[interleave_bound_index] = IntegerAttr.from_index_int_value(
        interleave_bound // interleave_factor
    )

    rewriter.replace_op(
        op,
        memref_stream.GenericOp(
            op.inputs,
            op.outputs,
            op.inits,
            new_region,
            new_indexing_maps,
            ArrayAttr(
                op.iterator_types.data
                + (memref_stream.IteratorTypeAttr.interleaved(),)
            ),
            ArrayAttr(new_bounds),
            op.init_indices,
            op.doc,
            op.library_call,
        ),
    )

MemRefStreamInterleavePass dataclass

Bases: ModulePass

Tiles the innermost parallel dimension of a memref_stream.generic. If specified, the pipeline-depth parameter specifies the number of operations in the resulting body that should be executed concurrently. The pass will select the largest factor of the corresponding bound smaller than pipeline-depth * 2. The search range is bound by pipeline-depth * 2 as very large interleaving factors can increase register pressure and potentially exhaust all available registers. In the future, it would be good to take the number of available registers into account when choosing a search range, as well as inspecting the generic body for read-after-write dependencies.

Source code in xdsl/transforms/memref_stream_interleave.py
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
@dataclass(frozen=True)
class MemRefStreamInterleavePass(ModulePass):
    """
    Tiles the innermost parallel dimension of a `memref_stream.generic`.
    The `pipeline-depth` parameter, when given, is the number of operations in
    the resulting body that should execute concurrently; the pass then picks
    the largest factor of the corresponding bound below `pipeline-depth * 2`.
    That cap exists because very large interleaving factors raise register
    pressure and can exhaust the register file. Future work: derive the search
    range from the actual register count and inspect the body for
    read-after-write dependencies.
    """

    name = "memref-stream-interleave"

    # Target concurrency of the rewritten body; drives the heuristic cap.
    pipeline_depth: int = 4
    # When set, rewrite only the op at this walk index instead of the module.
    op_index: int | None = None
    # Explicit iterator/factor pair, forwarded to the rewrite pattern.
    iterator_index: int | None = None
    unroll_factor: int | None = None

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        pattern = PipelineGenericPattern(
            self.pipeline_depth, self.iterator_index, self.unroll_factor
        )
        if self.op_index is None:
            # No target selected: walk the whole module.
            PatternRewriteWalker(pattern, apply_recursively=False).rewrite_module(op)
        else:
            # Rewrite exactly the selected memref_stream.generic op.
            target = OpSelector(self.op_index, "memref_stream.generic").get_op(op)
            pattern.match_and_rewrite(target, PatternRewriter(target))

    @classmethod
    def schedule_space(cls, ctx: Context, module_op: ModuleOp):
        """One pass configuration per (generic op, iterator index, factor)."""
        configurations: list[MemRefStreamInterleavePass] = []
        for position, walked_op in enumerate(module_op.walk()):
            if not isinstance(walked_op, memref_stream.GenericOp):
                continue
            for index, factor in PipelineGenericPattern.indices_and_factors(
                walked_op
            ):
                configurations.append(
                    MemRefStreamInterleavePass(
                        op_index=position,
                        iterator_index=index,
                        unroll_factor=factor,
                    )
                )
        return tuple(configurations)

name = 'memref-stream-interleave' class-attribute instance-attribute

pipeline_depth: int = field(default=4) class-attribute instance-attribute

op_index: int | None = field(default=None) class-attribute instance-attribute

iterator_index: int | None = field(default=None) class-attribute instance-attribute

unroll_factor: int | None = field(default=None) class-attribute instance-attribute

__init__(pipeline_depth: int = 4, op_index: int | None = None, iterator_index: int | None = None, unroll_factor: int | None = None) -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/memref_stream_interleave.py
257
258
259
260
261
262
263
264
265
266
267
268
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Run the interleave pattern, either on one selected op or module-wide."""
    pattern = PipelineGenericPattern(
        self.pipeline_depth, self.iterator_index, self.unroll_factor
    )
    if self.op_index is None:
        # No target selected: walk the whole module.
        PatternRewriteWalker(pattern, apply_recursively=False).rewrite_module(op)
        return
    # Rewrite exactly the selected memref_stream.generic op.
    target = OpSelector(self.op_index, "memref_stream.generic").get_op(op)
    pattern.match_and_rewrite(target, PatternRewriter(target))

schedule_space(ctx: Context, module_op: ModuleOp) classmethod

Source code in xdsl/transforms/memref_stream_interleave.py
270
271
272
273
274
275
276
277
278
279
280
281
282
283
@classmethod
def schedule_space(cls, ctx: Context, module_op: ModuleOp):
    """
    Enumerate one pass configuration per (generic op, iterator index, factor)
    triple found in the module.
    """
    options: list[MemRefStreamInterleavePass] = []
    for position, candidate in enumerate(module_op.walk()):
        if not isinstance(candidate, memref_stream.GenericOp):
            continue
        for index, factor in PipelineGenericPattern.indices_and_factors(candidate):
            options.append(
                MemRefStreamInterleavePass(
                    op_index=position,
                    iterator_index=index,
                    unroll_factor=factor,
                )
            )
    return tuple(options)

factors(num: int) -> tuple[int, ...]

For a positive integer input, returns the tuple of all numbers that evenly divide it; returns an empty tuple for 0 or negative inputs.

Source code in xdsl/transforms/memref_stream_interleave.py
29
30
31
32
33
34
35
36
37
38
39
40
def factors(num: int) -> tuple[int, ...]:
    """
    For all positive integers, returns the n-tuple of all numbers that evenly divide the
    input, returns an empty tuple for 0 or negative inputs.
    """
    if num <= 0:
        return ()

    # The general case already yields (1,) for num == 1, so no special case is
    # needed; every candidate from 1 to num inclusive is tested for divisibility.
    return tuple(factor for factor in range(1, num + 1) if num % factor == 0)