Skip to content

Scf parallel loop tiling

scf_parallel_loop_tiling

ScfParallelLoopTilingPattern dataclass

Bases: RewritePattern

Source code in xdsl/transforms/scf_parallel_loop_tiling.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
@dataclass
class ScfParallelLoopTilingPattern(RewritePattern):
    tile_sizes: tuple[int, ...]

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ParallelOp, rewriter: PatternRewriter, /):
        # Only tile the innermost parallel loop
        if any(isinstance(o, ParallelOp) for o in op.body.walk()):
            return
        lower = op.lowerBound
        upper = op.upperBound
        step = op.step
        # The pass is meant to work on any parallel loop with any nu;ber of tile sizes.
        # For a loop of dimension N, either use the N first tile sizes or use them all
        # and fill the rest with 1.
        tile_sizes_v = self.tile_sizes[: len(lower)] + (1,) * (
            len(lower) - len(self.tile_sizes)
        )

        zero = arith.ConstantOp(IntegerAttr.from_index_int_value(0))
        tile_sizes_v = {i: s for i, s in enumerate(tile_sizes_v) if s != 0}
        tile_sizes = {
            i: arith.ConstantOp(IntegerAttr.from_index_int_value(s))
            for i, s in tile_sizes_v.items()
        }
        tiled_dims = sorted(tile_sizes.keys())
        if not tiled_dims:
            return
        outter_lower = [lower[d] for d in tiled_dims]
        outter_upper = [upper[d] for d in tiled_dims]
        outter_step = [arith.MuliOp(step[d], tile_sizes[d]) for d in tiled_dims]

        outter_loop = ParallelOp(
            outter_lower,
            outter_upper,
            outter_step,
            Region(
                Block(
                    [(outter_reduce := ReduceOp())],
                    arg_types=[IndexType()] * len(outter_lower),
                )
            ),
        )

        inner_lower = list[SSAValue | Operation]()
        inner_upper = list[SSAValue | Operation]()
        minops = list[Operation]()
        minmap = affine.AffineMapAttr(
            affine.AffineMap(
                3,
                0,
                (
                    affine.AffineExpr.dimension(0),
                    affine.AffineExpr.dimension(1) - affine.AffineExpr.dimension(2),
                ),
            )
        )
        for i in range(len(op.lowerBound)):
            if i in tile_sizes:
                inner_lower.append(zero)
                ilower, iupper, istep = lower[i], upper[i], step[i]
                if (
                    isinstance(ilower, arith.ConstantOp)
                    and isinstance(iupper, arith.ConstantOp)
                    and isinstance(istep, arith.ConstantOp)
                ):
                    lower_v, upper_v, step_v = (
                        c.value for c in (ilower, iupper, istep)
                    )
                    assert isa(lower_v, IntegerAttr[IndexType])
                    assert isa(upper_v, IntegerAttr[IndexType])
                    assert isa(step_v, IntegerAttr[IndexType])
                    lower_v, upper_v, step_v = (
                        c.value.data for c in (lower_v, upper_v, step_v)
                    )
                    iter_count = (upper_v - lower_v) // step_v
                    if iter_count % tile_sizes_v[i] == 0:
                        inner_upper.append(tile_sizes[i])
                        continue

                arg_index = tiled_dims.index(i)
                minop = affine.MinOp(
                    operands=[
                        [
                            tile_sizes[i],
                            outter_upper[arg_index],
                            outter_loop.body.block.args[arg_index],
                        ]
                    ],
                    properties={"map": minmap},
                    result_types=[IndexType()],
                )
                minops.append(minop)
                inner_upper.append(minop)

            else:
                inner_lower.append(lower[i])
                inner_upper.append(upper[i])

        inner_loop = ParallelOp(inner_lower, inner_upper, step, op.detach_region(0))
        for i, arg in reversed(list(enumerate(inner_loop.body.block.args))):
            if i in tile_sizes:
                arg_index = tiled_dims.index(i)
                iv = arith.AddiOp(outter_loop.body.block.args[arg_index], arg)
                assert inner_loop.body.block.first_op is not None
                inner_loop.body.block.insert_op_before(
                    iv, inner_loop.body.block.first_op
                )
                for use in tuple(arg.uses):
                    if use.operation is iv:
                        continue
                    use.operation.operands[use.index] = iv.result
        outter_loop.body.block.insert_ops_before([*minops, inner_loop], outter_reduce)
        rewriter.replace_op(op, [zero, *tile_sizes.values(), *outter_step, outter_loop])

tile_sizes: tuple[int, ...] instance-attribute

__init__(tile_sizes: tuple[int, ...]) -> None

match_and_rewrite(op: ParallelOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/scf_parallel_loop_tiling.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ParallelOp, rewriter: PatternRewriter, /):
    # Only tile the innermost parallel loop
    if any(isinstance(o, ParallelOp) for o in op.body.walk()):
        return
    lower = op.lowerBound
    upper = op.upperBound
    step = op.step
    # The pass is meant to work on any parallel loop with any nu;ber of tile sizes.
    # For a loop of dimension N, either use the N first tile sizes or use them all
    # and fill the rest with 1.
    tile_sizes_v = self.tile_sizes[: len(lower)] + (1,) * (
        len(lower) - len(self.tile_sizes)
    )

    zero = arith.ConstantOp(IntegerAttr.from_index_int_value(0))
    tile_sizes_v = {i: s for i, s in enumerate(tile_sizes_v) if s != 0}
    tile_sizes = {
        i: arith.ConstantOp(IntegerAttr.from_index_int_value(s))
        for i, s in tile_sizes_v.items()
    }
    tiled_dims = sorted(tile_sizes.keys())
    if not tiled_dims:
        return
    outter_lower = [lower[d] for d in tiled_dims]
    outter_upper = [upper[d] for d in tiled_dims]
    outter_step = [arith.MuliOp(step[d], tile_sizes[d]) for d in tiled_dims]

    outter_loop = ParallelOp(
        outter_lower,
        outter_upper,
        outter_step,
        Region(
            Block(
                [(outter_reduce := ReduceOp())],
                arg_types=[IndexType()] * len(outter_lower),
            )
        ),
    )

    inner_lower = list[SSAValue | Operation]()
    inner_upper = list[SSAValue | Operation]()
    minops = list[Operation]()
    minmap = affine.AffineMapAttr(
        affine.AffineMap(
            3,
            0,
            (
                affine.AffineExpr.dimension(0),
                affine.AffineExpr.dimension(1) - affine.AffineExpr.dimension(2),
            ),
        )
    )
    for i in range(len(op.lowerBound)):
        if i in tile_sizes:
            inner_lower.append(zero)
            ilower, iupper, istep = lower[i], upper[i], step[i]
            if (
                isinstance(ilower, arith.ConstantOp)
                and isinstance(iupper, arith.ConstantOp)
                and isinstance(istep, arith.ConstantOp)
            ):
                lower_v, upper_v, step_v = (
                    c.value for c in (ilower, iupper, istep)
                )
                assert isa(lower_v, IntegerAttr[IndexType])
                assert isa(upper_v, IntegerAttr[IndexType])
                assert isa(step_v, IntegerAttr[IndexType])
                lower_v, upper_v, step_v = (
                    c.value.data for c in (lower_v, upper_v, step_v)
                )
                iter_count = (upper_v - lower_v) // step_v
                if iter_count % tile_sizes_v[i] == 0:
                    inner_upper.append(tile_sizes[i])
                    continue

            arg_index = tiled_dims.index(i)
            minop = affine.MinOp(
                operands=[
                    [
                        tile_sizes[i],
                        outter_upper[arg_index],
                        outter_loop.body.block.args[arg_index],
                    ]
                ],
                properties={"map": minmap},
                result_types=[IndexType()],
            )
            minops.append(minop)
            inner_upper.append(minop)

        else:
            inner_lower.append(lower[i])
            inner_upper.append(upper[i])

    inner_loop = ParallelOp(inner_lower, inner_upper, step, op.detach_region(0))
    for i, arg in reversed(list(enumerate(inner_loop.body.block.args))):
        if i in tile_sizes:
            arg_index = tiled_dims.index(i)
            iv = arith.AddiOp(outter_loop.body.block.args[arg_index], arg)
            assert inner_loop.body.block.first_op is not None
            inner_loop.body.block.insert_op_before(
                iv, inner_loop.body.block.first_op
            )
            for use in tuple(arg.uses):
                if use.operation is iv:
                    continue
                use.operation.operands[use.index] = iv.result
    outter_loop.body.block.insert_ops_before([*minops, inner_loop], outter_reduce)
    rewriter.replace_op(op, [zero, *tile_sizes.values(), *outter_step, outter_loop])

ScfParallelLoopTilingPass dataclass

Bases: ModulePass

Source code in xdsl/transforms/scf_parallel_loop_tiling.py
134
135
136
137
138
139
140
141
142
143
144
145
146
@dataclass(frozen=True)
class ScfParallelLoopTilingPass(ModulePass):
    name = "scf-parallel-loop-tiling"

    parallel_loop_tile_sizes: tuple[int, ...]

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        walker = PatternRewriteWalker(
            ScfParallelLoopTilingPattern(tuple(self.parallel_loop_tile_sizes)),
            walk_regions_first=True,
            apply_recursively=False,
        )
        walker.rewrite_module(op)

name = 'scf-parallel-loop-tiling' class-attribute instance-attribute

parallel_loop_tile_sizes: tuple[int, ...] instance-attribute

__init__(parallel_loop_tile_sizes: tuple[int, ...]) -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/scf_parallel_loop_tiling.py
140
141
142
143
144
145
146
def apply(self, ctx: Context, op: ModuleOp) -> None:
    walker = PatternRewriteWalker(
        ScfParallelLoopTilingPattern(tuple(self.parallel_loop_tile_sizes)),
        walk_regions_first=True,
        apply_recursively=False,
    )
    walker.rewrite_module(op)