Skip to content

Tiling

tiling

OperandTileInfo dataclass

Records how one operand is sliced when entering a tile:

- `source_type` is the operand's original memref type.
- `loop_dims` holds, for each result of the operand's indexing map, the loop dimension that result refers to.
- `result_shape` is the shape the tiled subview should have.

Source code in xdsl/dialects/linalg/transforms/tiling.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
@dataclass(frozen=True)
class OperandTileInfo:
    """
    Records how one operand is sliced when entering a tile.

    - `source_type` is the operand's original (untiled) memref type.
    - `loop_dims` holds, for each result of the operand's indexing map, the
      position of the loop dimension that result refers to.
    - `result_shape` is the shape the tiled subview should have.
    """

    source_type: MemRefType[Attribute]
    loop_dims: tuple[int, ...]
    result_shape: tuple[int, ...]

    @staticmethod
    def analyze(
        indexing_map: AffineMap,
        source_type: MemRefType[Attribute],
        tile_sizes: Sequence[int],
    ) -> "OperandTileInfo":
        """
        Analyze how one operand should be sliced for each tile.

        Args:
            indexing_map: The operand's indexing map. Every result must be a
                plain loop dimension (`AffineDimExpr`).
            source_type: The operand's memref type.
            tile_sizes: One tile size per loop, indexed by loop dimension;
                `0` means that loop is not tiled.

        Returns:
            An `OperandTileInfo` whose `result_shape` uses the tile size for
            tiled loop dimensions and the full source extent otherwise.

        Raises:
            NotImplementedError: if a result of `indexing_map` is not a plain
                loop dimension.
        """
        source_shape = source_type.get_shape()

        loop_dims_list: list[int] = []
        for expr in indexing_map.results:
            # Only plain `d_i` results map onto a single loop. Reject anything
            # else explicitly: `typing.cast` performs no runtime check, so the
            # failure would otherwise surface as an AttributeError on
            # `.position`.
            if not isinstance(expr, AffineDimExpr):
                raise NotImplementedError(
                    "Tiling only supports indexing maps whose results are loop "
                    f"dimensions, got {expr}"
                )
            loop_dims_list.append(expr.position)
        loop_dims = tuple(loop_dims_list)

        # Tiled loop dims take the tile size; untiled ones keep the full extent
        # of the corresponding source dimension.
        result_shape = tuple(
            tile_sizes[loop_dim]
            if tile_sizes[loop_dim] != 0
            else source_shape[result_index]
            for result_index, loop_dim in enumerate(loop_dims)
        )
        return OperandTileInfo(source_type, loop_dims, result_shape)

source_type: MemRefType[Attribute] instance-attribute

loop_dims: tuple[int, ...] instance-attribute

result_shape: tuple[int, ...] instance-attribute

__init__(source_type: MemRefType[Attribute], loop_dims: tuple[int, ...], result_shape: tuple[int, ...]) -> None

analyze(indexing_map: AffineMap, source_type: MemRefType[Attribute], tile_sizes: Sequence[int]) -> OperandTileInfo staticmethod

Analyze how one operand should be sliced for each tile.

Source code in xdsl/dialects/linalg/transforms/tiling.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
@staticmethod
def analyze(
    indexing_map: AffineMap,
    source_type: MemRefType[Attribute],
    tile_sizes: Sequence[int],
) -> "OperandTileInfo":
    """
    Analyze how one operand should be sliced for each tile.

    Args:
        indexing_map: The operand's indexing map. Every result must be a
            plain loop dimension (`AffineDimExpr`).
        source_type: The operand's memref type.
        tile_sizes: One tile size per loop, indexed by loop dimension;
            `0` means that loop is not tiled.

    Returns:
        An `OperandTileInfo` whose `result_shape` uses the tile size for
        tiled loop dimensions and the full source extent otherwise.

    Raises:
        NotImplementedError: if a result of `indexing_map` is not a plain
            loop dimension.
    """
    source_shape = source_type.get_shape()

    loop_dims_list: list[int] = []
    for expr in indexing_map.results:
        # Only plain `d_i` results map onto a single loop. Reject anything else
        # explicitly: `typing.cast` performs no runtime check, so the failure
        # would otherwise surface as an AttributeError on `.position`.
        if not isinstance(expr, AffineDimExpr):
            raise NotImplementedError(
                "Tiling only supports indexing maps whose results are loop "
                f"dimensions, got {expr}"
            )
        loop_dims_list.append(expr.position)
    loop_dims = tuple(loop_dims_list)

    # Tiled loop dims take the tile size; untiled ones keep the full extent of
    # the corresponding source dimension.
    result_shape = tuple(
        tile_sizes[loop_dim]
        if tile_sizes[loop_dim] != 0
        else source_shape[result_index]
        for result_index, loop_dim in enumerate(loop_dims)
    )
    return OperandTileInfo(source_type, loop_dims, result_shape)

TilingPlan dataclass

Stores the information needed to turn one op into tiled loops operating on tiled subviews:

- `loop_ranges` are the original static loop ranges.
- `tiled_dims` are the loop dimensions that are actually tiled (non-zero tile size).
- `operand_infos` stores one `OperandTileInfo` per operand.
- `tile_sizes` are the normalized tile sizes, truncated or zero-padded to match the op's loop count.

Source code in xdsl/dialects/linalg/transforms/tiling.py
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
@dataclass(frozen=True)
class TilingPlan:
    """
    Stores the information needed to turn one op into tiled loops operating on
    tiled subviews.

    - `loop_ranges` are the original static loop ranges.
    - `tiled_dims` are the loop dimensions that are actually tiled
      (those with a non-zero tile size).
    - `operand_infos` stores one `OperandTileInfo` per operand.
    - `tile_sizes` are the normalized tile sizes, truncated or zero-padded to
      match the op's loop count.
    """

    loop_ranges: tuple[int, ...]
    tiled_dims: tuple[int, ...]
    operand_infos: tuple[OperandTileInfo, ...]
    tile_sizes: tuple[int, ...]

    @staticmethod
    def analyze_generic_op(
        op: linalg.ops.GenericOp,
        tile_sizes: Sequence[int],
    ) -> "TilingPlan":
        """
        Analyze one supported `linalg.generic` and return a `TilingPlan`.

        Args:
            op: The `linalg.generic` to analyze.
            tile_sizes: One tile size per loop; `0` leaves that loop untiled.
                Any sequence of ints is accepted (tuple, list, ...). Entries
                beyond the op's loop count are ignored and missing trailing
                entries default to `0`.

        Returns:
            A `TilingPlan`. When no loop is tiled, the plan has empty
            `loop_ranges`, `tiled_dims` and `operand_infos`, and the op is not
            checked for tileability.

        Raises:
            NotImplementedError: if an operand is not a memref.
            Errors raised by `_verify_generic_is_tileable` propagate unchanged
            (presumably `ValueError` / `NotImplementedError`, the types that
            `tile_linalg_generic` converts into a pass failure).
        """
        num_loops = op.get_num_loops()
        # Normalize to a tuple first so callers may pass any sequence: a list
        # slice cannot be concatenated with a tuple. Then truncate to the loop
        # count and zero-pad (a negative repeat count yields an empty tuple).
        requested_sizes = tuple(tile_sizes)
        normalized_tile_sizes = requested_sizes[:num_loops] + (0,) * (
            num_loops - len(requested_sizes)
        )

        tiled_dims = tuple(
            dim for dim, tile_size in enumerate(normalized_tile_sizes) if tile_size != 0
        )

        # Nothing to tile: skip verification and operand analysis entirely.
        if not tiled_dims:
            return TilingPlan(
                loop_ranges=(),
                tiled_dims=(),
                operand_infos=(),
                tile_sizes=normalized_tile_sizes,
            )

        loop_ranges = _verify_generic_is_tileable(
            op,
            normalized_tile_sizes,
            tiled_dims,
        )

        operand_infos_list: list[OperandTileInfo] = []
        for operand, indexing_map in zip(
            op.operands, op.get_indexing_maps(), strict=True
        ):
            source_type = operand.type
            # Operand analysis is memref-only. Raise instead of asserting
            # (asserts are stripped under `python -O`) so callers such as
            # `tile_linalg_generic` can report a clean pass failure.
            if not isa(source_type, MemRefType):
                raise NotImplementedError(
                    f"Tiling only supports memref operands, got {source_type}"
                )
            operand_infos_list.append(
                OperandTileInfo.analyze(
                    indexing_map.data,
                    source_type,
                    normalized_tile_sizes,
                )
            )

        return TilingPlan(
            loop_ranges=loop_ranges,
            tiled_dims=tiled_dims,
            operand_infos=tuple(operand_infos_list),
            tile_sizes=normalized_tile_sizes,
        )

loop_ranges: tuple[int, ...] instance-attribute

tiled_dims: tuple[int, ...] instance-attribute

operand_infos: tuple[OperandTileInfo, ...] instance-attribute

tile_sizes: tuple[int, ...] instance-attribute

__init__(loop_ranges: tuple[int, ...], tiled_dims: tuple[int, ...], operand_infos: tuple[OperandTileInfo, ...], tile_sizes: tuple[int, ...]) -> None

analyze_generic_op(op: linalg.ops.GenericOp, tile_sizes: tuple[int, ...]) -> TilingPlan staticmethod

Analyze one supported linalg.generic and return a TilingPlan.

Source code in xdsl/dialects/linalg/transforms/tiling.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
@staticmethod
def analyze_generic_op(
    op: linalg.ops.GenericOp,
    tile_sizes: Sequence[int],
) -> "TilingPlan":
    """
    Analyze one supported `linalg.generic` and return a `TilingPlan`.

    Args:
        op: The `linalg.generic` to analyze.
        tile_sizes: One tile size per loop; `0` leaves that loop untiled.
            Any sequence of ints is accepted (tuple, list, ...). Entries beyond
            the op's loop count are ignored and missing trailing entries
            default to `0`.

    Returns:
        A `TilingPlan`. When no loop is tiled, the plan has empty
        `loop_ranges`, `tiled_dims` and `operand_infos`, and the op is not
        checked for tileability.

    Raises:
        NotImplementedError: if an operand is not a memref.
        Errors raised by `_verify_generic_is_tileable` propagate unchanged
        (presumably `ValueError` / `NotImplementedError`, the types that
        `tile_linalg_generic` converts into a pass failure).
    """
    num_loops = op.get_num_loops()
    # Normalize to a tuple first so callers may pass any sequence: a list slice
    # cannot be concatenated with a tuple. Then truncate to the loop count and
    # zero-pad (a negative repeat count yields an empty tuple).
    requested_sizes = tuple(tile_sizes)
    normalized_tile_sizes = requested_sizes[:num_loops] + (0,) * (
        num_loops - len(requested_sizes)
    )

    tiled_dims = tuple(
        dim for dim, tile_size in enumerate(normalized_tile_sizes) if tile_size != 0
    )

    # Nothing to tile: skip verification and operand analysis entirely.
    if not tiled_dims:
        return TilingPlan(
            loop_ranges=(),
            tiled_dims=(),
            operand_infos=(),
            tile_sizes=normalized_tile_sizes,
        )

    loop_ranges = _verify_generic_is_tileable(
        op,
        normalized_tile_sizes,
        tiled_dims,
    )

    operand_infos_list: list[OperandTileInfo] = []
    for operand, indexing_map in zip(
        op.operands, op.get_indexing_maps(), strict=True
    ):
        source_type = operand.type
        # Operand analysis is memref-only. Raise instead of asserting (asserts
        # are stripped under `python -O`) so callers such as
        # `tile_linalg_generic` can report a clean pass failure.
        if not isa(source_type, MemRefType):
            raise NotImplementedError(
                f"Tiling only supports memref operands, got {source_type}"
            )
        operand_infos_list.append(
            OperandTileInfo.analyze(
                indexing_map.data,
                source_type,
                normalized_tile_sizes,
            )
        )

    return TilingPlan(
        loop_ranges=loop_ranges,
        tiled_dims=tiled_dims,
        operand_infos=tuple(operand_infos_list),
        tile_sizes=normalized_tile_sizes,
    )

tile_linalg_generic(rewriter: PatternRewriter, op: linalg.ops.GenericOp, tile_sizes: tuple[int, ...]) -> bool

Rewrite a supported linalg.generic op into tiled form.

Source code in xdsl/dialects/linalg/transforms/tiling.py
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
def tile_linalg_generic(
    rewriter: PatternRewriter,
    op: linalg.ops.GenericOp,
    tile_sizes: tuple[int, ...],
) -> bool:
    """
    Rewrite a supported `linalg.generic` op into tiled form.

    A loop nest is built before `op` with `_build_tile_loops`. At the inner
    insertion point it returns, each operand is sliced with
    `_build_tiled_subview`, and a clone of the original `linalg.generic` is
    applied to those subviews. The original op is then erased.

    Args:
        rewriter: Pattern rewriter used to insert the new ops and erase `op`.
        op: The `linalg.generic` to tile.
        tile_sizes: One tile size per loop; `0` leaves that loop untiled.
            `TilingPlan.analyze_generic_op` ignores entries beyond the op's
            loop count and pads missing ones with `0`.

    Returns:
        `True` if `op` was rewritten. `False` if no loop dimension is tiled,
        in which case the IR is left untouched.

    Raises:
        PassFailedException: if analysis raises `ValueError` or
            `NotImplementedError`, i.e. the op cannot be tiled with the
            requested tile sizes.
    """
    # Turn analysis failures (unsupported op / bad tile sizes) into a pass
    # failure, chaining the original error for context.
    try:
        plan = TilingPlan.analyze_generic_op(op, tile_sizes)
    except (ValueError, NotImplementedError) as e:
        raise PassFailedException(str(e)) from e

    # Every tile size is zero: nothing to rewrite.
    if not plan.tiled_dims:
        return False

    # Loops are created right before `op`. `inner_ip` is where the tiled body
    # goes (presumably inside the innermost loop; see `_build_tile_loops`).
    loops, tiled_loop_ivs, inner_ip = _build_tile_loops(
        rewriter,
        InsertPoint.before(op),
        plan.loop_ranges,
        plan.tile_sizes,
        plan.tiled_dims,
    )
    tiled_operands: list[SSAValue] = []

    # Slice every operand (inputs then outputs, in operand order) for the
    # current tile, using the per-operand info computed by the plan.
    for operand, operand_info, indexing_map in zip(
        op.operands, plan.operand_infos, op.get_indexing_maps(), strict=True
    ):
        subview = _build_tiled_subview(
            rewriter, inner_ip, operand, indexing_map.data, operand_info, tiled_loop_ivs
        )
        tiled_operands.append(subview.result)

    # Rebuild the generic on the subviews. The first `len(op.inputs)` tiled
    # operands are inputs and the rest are outputs. The body, indexing maps and
    # iterator types are reused from the original op.
    num_inputs = len(op.inputs)
    tiled_generic = linalg.ops.GenericOp(
        tiled_operands[:num_inputs],
        tiled_operands[num_inputs:],
        op.body.clone(),
        op.get_indexing_maps(),
        op.get_iterator_types(),
    )
    rewriter.insert_op(tiled_generic, inner_ip)

    # Terminate each loop body with an `scf.yield`, after everything already
    # inserted there (innermost first, assuming `loops` is ordered
    # outermost-first).
    for loop in reversed(loops):
        rewriter.insert_op(scf.YieldOp(), InsertPoint.at_end(loop.body.block))

    # The tiled loop nest fully replaces the original op.
    rewriter.erase_op(op)
    return True