Skip to content

Csl wrapper hoist buffers

csl_wrapper_hoist_buffers

HoistBuffers dataclass

Bases: RewritePattern

Hoists buffers to the `csl_wrapper.program_module` level.

Source code in xdsl/transforms/csl_wrapper_hoist_buffers.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
@dataclass(frozen=True)
class HoistBuffers(RewritePattern):
    """
    Rewrite pattern that relocates `memref.alloc` ops nested below a
    `csl_wrapper.ModuleOp` to the start of that wrapper's `program_module` block.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.AllocOp, rewriter: PatternRewriter, /):
        # Name hints are refreshed on every match, whether or not the op is hoisted.
        self._set_name_hint(op)

        # Climb the parent chain until a csl_wrapper module is found or the
        # chain runs out.
        direct_parent = op.parent_op()
        wrapper = direct_parent
        while wrapper:
            if isinstance(wrapper, csl_wrapper.ModuleOp):
                break
            wrapper = wrapper.parent_op()

        # Nothing to do without an enclosing wrapper ...
        if not wrapper:
            return
        # ... or when the alloc is already an immediate child of the wrapper.
        if wrapper == direct_parent:
            return

        # A clone placed at module level cannot carry operands along.
        assert len(op.dynamic_sizes) == 0, "not implemented"
        assert len(op.symbol_operands) == 0, "not implemented"

        # Put a copy at the top of the program module, then redirect every use
        # of the original result to the copy while erasing the original.
        hoisted = op.clone()
        target = InsertPoint.at_start(wrapper.program_module.block)
        rewriter.insert_op(hoisted, target)
        rewriter.replace_op(op, [], new_results=[hoisted.memref])

    @staticmethod
    def _set_name_hint(op: memref.AllocOp):
        """
        Searches the uses of the allocation for the chain
          %0 = memref.alloc
          %1 = csl.addressof(%0)
          csl.export(%1) <{var_name = "buf"}>

        For the first chain found, the alloc result is hinted "buf" and the
        addressof result "buf_ptr". Without such a chain nothing is changed.
        """
        for alloc_use in op.memref.uses:
            addr_op = alloc_use.operation
            if not isinstance(addr_op, csl.AddressOfOp):
                continue

            for addr_use in addr_op.res.uses:
                export_op = addr_use.operation
                if not isinstance(export_op, csl.SymbolExportOp):
                    continue
                exported = export_op.get_name()
                op.memref.name_hint = exported
                addr_op.res.name_hint = f"{exported}_ptr"
                # The first exported chain wins; stop searching.
                return

__init__() -> None

match_and_rewrite(op: memref.AllocOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/csl_wrapper_hoist_buffers.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.AllocOp, rewriter: PatternRewriter, /):
    """
    Move `op` to the start of the `program_module` block of the enclosing
    `csl_wrapper.ModuleOp`, unless it is already a direct child of that wrapper
    or no wrapper encloses it.
    """
    # Name hints are refreshed on every match, whether or not the op is hoisted.
    self._set_name_hint(op)

    # Climb the parent chain until a csl_wrapper module is found or the
    # chain runs out.
    direct_parent = op.parent_op()
    wrapper = direct_parent
    while wrapper:
        if isinstance(wrapper, csl_wrapper.ModuleOp):
            break
        wrapper = wrapper.parent_op()

    # Nothing to do without an enclosing wrapper ...
    if not wrapper:
        return
    # ... or when the alloc is already an immediate child of the wrapper.
    if wrapper == direct_parent:
        return

    # A clone placed at module level cannot carry operands along.
    assert len(op.dynamic_sizes) == 0, "not implemented"
    assert len(op.symbol_operands) == 0, "not implemented"

    # Put a copy at the top of the program module, then redirect every use
    # of the original result to the copy while erasing the original.
    hoisted = op.clone()
    target = InsertPoint.at_start(wrapper.program_module.block)
    rewriter.insert_op(hoisted, target)
    rewriter.replace_op(op, [], new_results=[hoisted.memref])

CslWrapperHoistBuffers dataclass

Bases: ModulePass

Hoists buffers to the `csl_wrapper.program_module` level.

Source code in xdsl/transforms/csl_wrapper_hoist_buffers.py
66
67
68
69
70
71
72
73
74
75
76
@dataclass(frozen=True)
class CslWrapperHoistBuffers(ModulePass):
    """
    Module pass that applies the `HoistBuffers` rewrite pattern across a module.
    """

    name = "csl-wrapper-hoist-buffers"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """Walk `op` and apply `HoistBuffers` wherever it matches."""
        # `ctx` is not used by this pass.
        PatternRewriteWalker(HoistBuffers()).rewrite_module(op)

name = 'csl-wrapper-hoist-buffers' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/csl_wrapper_hoist_buffers.py
74
75
76
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Walk `op` and apply `HoistBuffers` wherever it matches."""
    # `ctx` is not used by this pass.
    PatternRewriteWalker(HoistBuffers()).rewrite_module(op)