Skip to content

Memref stream fold fill

memref_stream_fold_fill

MemRefStreamFoldFillPass dataclass

Bases: ModulePass

Folds memref_stream.fill operations that run immediately before a memref_stream.generic operation into the init value. Assumes that none of the memrefs involved are aliased.

Source code in xdsl/transforms/memref_stream_fold_fill.py
57
58
59
60
61
62
63
64
65
66
67
68
@dataclass(frozen=True)
class MemRefStreamFoldFillPass(ModulePass):
    """
    Folds `memref_stream.fill` operations that run immediately before a
    `memref_stream.generic` operation into the init value.
    Assumes that none of the memrefs involved are aliased.
    """

    # Name of the pass (presumably the identifier used to select it from a
    # pass pipeline -- confirm against ModulePass).
    name = "memref-stream-fold-fill"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """Apply the fold to `op` by delegating to `fold_fills_in_module`.

        `ctx` is accepted but not used by this pass.
        """
        fold_fills_in_module(op)

name = 'memref-stream-fold-fill' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/memref_stream_fold_fill.py
67
68
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Apply the fold to `op` by delegating to `fold_fills_in_module`.

    `ctx` is accepted but not used by this pass.
    """
    fold_fills_in_module(op)

fold_fills_in_module(module_op: ModuleOp)

Source code in xdsl/transforms/memref_stream_fold_fill.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def fold_fills_in_module(module_op: ModuleOp):
    """
    Fold `memref_stream.fill` ops into the inits of later
    `memref_stream.generic` ops found while walking `module_op`.

    The walk remembers the most recent fill seen for each memref. When a
    generic op has outputs with a remembered fill, and the generic op reports
    `is_imperfectly_nested`, the generic op is replaced by a new one that
    receives the fill values as inits, and the folded fills are erased.
    A remembered fill is forgotten as soon as any non-fill op uses its memref
    as an operand.
    """
    # Latest not-yet-invalidated fill for each memref value.
    pending_fills: dict[SSAValue, memref_stream.FillOp] = {}

    for op in module_op.walk():
        if isinstance(op, memref_stream.FillOp):
            target = op.memref
            superseded = pending_fills.get(target)
            if superseded is not None:
                # A newer fill of the same memref makes the earlier one dead.
                Rewriter.erase_op(superseded)
            pending_fills[target] = op
            continue

        if isinstance(op, memref_stream.GenericOp):
            # (output position, fill op) for every output with a pending fill.
            folded = [
                (position, pending_fills[output])
                for position, output in enumerate(op.outputs)
                if output in pending_fills
            ]
            if folded and op.is_imperfectly_nested:
                init_indices = ArrayAttr(
                    IntAttr(position) for position, _ in folded
                )
                inits = tuple(fill.value for _, fill in folded)
                replacement = memref_stream.GenericOp(
                    op.inputs,
                    op.outputs,
                    inits,
                    Rewriter.move_region_contents_to_new_regions(op.body),
                    op.indexing_maps,
                    op.iterator_types,
                    op.bounds,
                    init_indices,
                    op.doc,
                    op.library_call,
                )
                Rewriter.replace_op(op, replacement)
                # The same fill may feed several outputs; erase each only once.
                for stale_fill in {fill for _, fill in folded}:
                    Rewriter.erase_op(stale_fill)

        # Reached for every non-fill op, generic ops included: any operand use
        # of a tracked memref means its pending fill can no longer be folded.
        for operand in op.operands:
            pending_fills.pop(operand, None)