Skip to content

Transform interpreter

transform_interpreter

Transform interpreter.

TransformInterpreterPass dataclass

Bases: ModulePass

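Transform dialect interpreter: a module pass that looks up the named sequence called `entry_point` (default `__transform_main`) in the module and interprets it with the module as its argument.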

Source code in xdsl/transforms/transform_interpreter.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
@dataclass(frozen=True)
class TransformInterpreterPass(ModulePass):
    """Module pass that interprets a transform-dialect schedule.

    The pass looks up the named sequence called ``entry_point`` inside the
    module it is applied to and hands it to an ``Interpreter``, passing the
    module itself as the sequence's only argument.
    """

    name = "transform-interpreter"

    # Symbol name of the named sequence used as the schedule's entry point.
    entry_point: str = "__transform_main"

    @staticmethod
    def find_transform_entry_point(
        root: builtin.ModuleOp, entry_point: str
    ) -> transform.NamedSequenceOp:
        """Return the first named sequence under ``root`` called ``entry_point``.

        Operations are visited in the order produced by ``root.walk()``; the
        search stops at the first ``transform.NamedSequenceOp`` whose symbol
        name equals ``entry_point``.

        Raises:
            PassFailedException: if no such named sequence is found.
        """
        match = next(
            (
                candidate
                for candidate in root.walk()
                if isinstance(candidate, transform.NamedSequenceOp)
                and candidate.sym_name.data == entry_point
            ),
            None,
        )
        if match is None:
            raise PassFailedException(
                f"{root} could not find a nested named sequence with name: {entry_point}"
            )
        return match

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """Locate the entry-point sequence in ``op`` and interpret it.

        The interpreter is created for ``op``, given the transform
        implementations built from ``ctx`` and ``get_all_passes()``, and then
        asked to call the entry-point sequence with ``op`` as its argument.

        Raises:
            PassFailedException: if ``op`` contains no named sequence called
                ``self.entry_point``.
        """
        entry_sequence = TransformInterpreterPass.find_transform_entry_point(
            op, self.entry_point
        )
        runner = Interpreter(op)
        transform_impls = TransformFunctions(ctx, get_all_passes())
        runner.register_implementations(transform_impls)
        runner.call_op(entry_sequence, (op,))

name = 'transform-interpreter' class-attribute instance-attribute

entry_point: str = '__transform_main' class-attribute instance-attribute

__init__(entry_point: str = '__transform_main') -> None

find_transform_entry_point(root: builtin.ModuleOp, entry_point: str) -> transform.NamedSequenceOp staticmethod

Source code in xdsl/transforms/transform_interpreter.py
21
22
23
24
25
26
27
28
29
30
31
32
33
@staticmethod
def find_transform_entry_point(
    root: builtin.ModuleOp, entry_point: str
) -> transform.NamedSequenceOp:
    """Return the first named sequence under ``root`` called ``entry_point``.

    Operations are visited in the order produced by ``root.walk()``; the
    search stops at the first ``transform.NamedSequenceOp`` whose symbol
    name equals ``entry_point``.

    Raises:
        PassFailedException: if no such named sequence is found.
    """
    match = next(
        (
            candidate
            for candidate in root.walk()
            if isinstance(candidate, transform.NamedSequenceOp)
            and candidate.sym_name.data == entry_point
        ),
        None,
    )
    if match is None:
        raise PassFailedException(
            f"{root} could not find a nested named sequence with name: {entry_point}"
        )
    return match

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/transform_interpreter.py
35
36
37
38
39
40
41
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Locate the entry-point sequence in ``op`` and interpret it.

    The interpreter is created for ``op``, given the transform
    implementations built from ``ctx`` and ``get_all_passes()``, and then
    asked to call the entry-point sequence with ``op`` as its argument.
    """
    entry_sequence = TransformInterpreterPass.find_transform_entry_point(
        op, self.entry_point
    )
    runner = Interpreter(op)
    transform_impls = TransformFunctions(ctx, get_all_passes())
    runner.register_implementations(transform_impls)
    runner.call_op(entry_sequence, (op,))