Skip to content

Convert ml program to memref

convert_ml_program_to_memref

ConvertGlobalPattern

Bases: RewritePattern

Source code in xdsl/transforms/convert_ml_program_to_memref.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class ConvertGlobalPattern(RewritePattern):
    """
    Rewrites an `ml_program.global` operation into a `memref.global` operation.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: ml_program.GlobalOp, rewriter: PatternRewriter
    ) -> None:
        """
        Replace `op` with a `memref.GlobalOp` that reuses its symbol name, value and
        visibility, typed as a memref built from the tensor's element type and shape.

        Raises `NotImplementedError` when `op` carries no value.
        """
        if op.value is None:
            raise NotImplementedError(
                "Converting ml_program.global with no value not implemented"
            )
        # Bind outside the `assert`: assert statements are removed under `python -O`,
        # which would otherwise leave `op_type` unbound for the lines below.
        op_type = op.type
        assert isinstance(op_type, TensorType)
        op_type = cast(TensorType[Any], op_type)
        new_type = memref.MemRefType(op_type.element_type, op_type.shape)
        rewriter.replace_op(
            op,
            (
                memref.GlobalOp.get(
                    op.sym_name,
                    new_type,
                    op.value,
                    op.sym_visibility,
                    # A unit attribute is passed only when the source global is not
                    # marked mutable.
                    UnitAttr() if op.is_mutable is None else None,
                ),
            ),
        )

match_and_rewrite(op: ml_program.GlobalOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_ml_program_to_memref.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: ml_program.GlobalOp, rewriter: PatternRewriter
) -> None:
    """
    Replace `op` with a `memref.GlobalOp` that reuses its symbol name, value and
    visibility, typed as a memref built from the tensor's element type and shape.

    Raises `NotImplementedError` when `op` carries no value.
    """
    if op.value is None:
        raise NotImplementedError(
            "Converting ml_program.global with no value not implemented"
        )
    # Bind outside the `assert`: assert statements are removed under `python -O`,
    # which would otherwise leave `op_type` unbound for the lines below.
    op_type = op.type
    assert isinstance(op_type, TensorType)
    op_type = cast(TensorType[Any], op_type)
    new_type = memref.MemRefType(op_type.element_type, op_type.shape)
    rewriter.replace_op(
        op,
        (
            memref.GlobalOp.get(
                op.sym_name,
                new_type,
                op.value,
                op.sym_visibility,
                # A unit attribute is passed only when the source global is not
                # marked mutable.
                UnitAttr() if op.is_mutable is None else None,
            ),
        ),
    )

ConvertGlobalLoadConst

Bases: RewritePattern

Source code in xdsl/transforms/convert_ml_program_to_memref.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class ConvertGlobalLoadConst(RewritePattern):
    """
    Rewrites an `ml_program.global_load_const` operation into a `memref.get_global`
    followed by a `bufferization.to_tensor`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: ml_program.GlobalLoadConstantOp, rewriter: PatternRewriter
    ) -> None:
        """
        Replace `op` with a `memref.GetGlobalOp` for the same global, typed as a memref
        built from the result tensor's element type and shape, and a
        `bufferization.ToTensorOp` applied to the fetched memref.
        """
        # Bind outside the `assert`: assert statements are removed under `python -O`,
        # which would otherwise leave `op_type` unbound for the lines below.
        op_type = op.result.type
        assert isinstance(op_type, TensorType)
        op_type = cast(TensorType[Any], op_type)
        new_type = memref.MemRefType(op_type.element_type, op_type.shape)
        get_global = memref.GetGlobalOp(op.global_attr, new_type)
        to_tensor = bufferization.ToTensorOp(get_global.memref)
        # Order matters: the memref must be defined before it is converted.
        rewriter.replace_op(op, (get_global, to_tensor))

match_and_rewrite(op: ml_program.GlobalLoadConstantOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/convert_ml_program_to_memref.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: ml_program.GlobalLoadConstantOp, rewriter: PatternRewriter
) -> None:
    """
    Replace `op` with a `memref.GetGlobalOp` for the same global, typed as a memref
    built from the result tensor's element type and shape, and a
    `bufferization.ToTensorOp` applied to the fetched memref.
    """
    # Bind outside the `assert`: assert statements are removed under `python -O`,
    # which would otherwise leave `op_type` unbound for the lines below.
    op_type = op.result.type
    assert isinstance(op_type, TensorType)
    op_type = cast(TensorType[Any], op_type)
    new_type = memref.MemRefType(op_type.element_type, op_type.shape)
    get_global = memref.GetGlobalOp(op.global_attr, new_type)
    to_tensor = bufferization.ToTensorOp(get_global.memref)
    # Order matters: the memref must be defined before it is converted.
    rewriter.replace_op(op, (get_global, to_tensor))

ConvertMlProgramToMemRefPass dataclass

Bases: ModulePass

Converts operations in the ml_program dialect to memref. ml_program operations are at the tensor level of abstraction, so some of the rewrites insert bufferization ops to bridge the gap to existing consumers of global tensors.

Source code in xdsl/transforms/convert_ml_program_to_memref.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
class ConvertMlProgramToMemRefPass(ModulePass):
    """
    Lowers `ml_program` dialect operations to the `memref` dialect.
    Because `ml_program` operations work on `tensor` values, some rewrites also emit
    `bufferization` ops so that existing consumers of global `tensor`s keep working.
    """

    name = "convert-ml-program-to-memref"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """
        Run the global and global-load-constant conversion patterns over `op` in a
        single, non-recursive walk.
        """
        applier = GreedyRewritePatternApplier(
            [ConvertGlobalPattern(), ConvertGlobalLoadConst()]
        )
        walker = PatternRewriteWalker(applier, apply_recursively=False)
        walker.rewrite_module(op)

name = 'convert-ml-program-to-memref' class-attribute instance-attribute

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/convert_ml_program_to_memref.py
73
74
75
76
77
78
79
80
81
82
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """
    Run the global and global-load-constant conversion patterns over `op` in a
    single, non-recursive walk.
    """
    applier = GreedyRewritePatternApplier(
        [ConvertGlobalPattern(), ConvertGlobalLoadConst()]
    )
    walker = PatternRewriteWalker(applier, apply_recursively=False)
    walker.rewrite_module(op)