Skip to content

Context

context

FuncInfo

Bases: NamedTuple

Information about a decorated function being generated into IR.

Source code in xdsl/frontend/pyast/context.py
23
24
25
26
27
28
29
30
31
32
33
class FuncInfo(NamedTuple):
    """Everything the frontend records about one decorated function.

    Instances are ordinary 3-tuples laid out as ``(file, globals, ast)``, so
    they can be unpacked positionally as well as accessed by field name.
    """

    file: str
    """Path of the source file in which the function is defined."""

    globals: dict[str, Any]
    """Global namespace of the defining file, as seen when the function was defined."""

    ast: ast.FunctionDef
    """Parsed Python AST node for the function definition."""

file: str instance-attribute

The path of the file containing the function.

globals: dict[str, Any] instance-attribute

The globals defined in that file up to the point of function definition.

ast: ast.FunctionDef instance-attribute

The Python AST representation of the function.

PyASTContext dataclass

Encapsulate the mapping between Python and IR types and operations.

Source code in xdsl/frontend/pyast/context.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
@dataclass
class PyASTContext:
    """Registry tying Python source constructs to their IR counterparts.

    A context records which IR type each Python type lowers to, which IR
    operation constructor implements each Python function, and which module
    passes (plus an optional between-pass callback) are applied to the built
    module. Python functions become programs through ``parse_program``.
    """

    type_registry: TypeRegistry = field(default_factory=TypeRegistry)
    """Mapping from source-level Python types to IR types."""

    function_registry: FunctionRegistry = field(default_factory=FunctionRegistry)
    """Mapping from source-level functions to the IR operations implementing them."""

    post_transforms: list[ModulePass] = field(
        default_factory=lambda: [FrontendDesymrefyPass()]
    )
    """Module passes applied, in this order, to the built module."""

    post_callback: Callable[[ModulePass, ModuleOp, ModulePass], None] | None = (
        default_pipeline_callback
    )
    """Hook run between consecutive post transforms; may be ``None``."""

    ir_context: Context = field(
        default_factory=lambda: Context(allow_unregistered=True)
    )
    """xDSL context used when transforming the built module."""

    def register_type(
        self,
        source_type: type,
        ir_type: TypeAttribute,
    ) -> None:
        """Record that ``source_type`` in Python source lowers to ``ir_type``."""
        registry = self.type_registry
        registry.insert(source_type, ir_type)

    def register_function(
        self, function: Callable[..., Any], ir_constructor: Callable[..., Operation]
    ) -> None:
        """Bind a source-level function or method to the callable building its IR.

        Args:
            function: The function/method object as referenced from source code.
            ir_constructor: Callable returning the ``Operation`` emitted for it.
        """
        registry = self.function_registry
        registry.insert(function, ir_constructor)

    def register_post_transform(self, transform: ModulePass) -> None:
        """Append ``transform`` to the end of the post-transform pass list."""
        passes = self.post_transforms
        passes.append(transform)

    def register_dialect(self, dialect: Dialect) -> None:
        """Load ``dialect`` into the xDSL context used for transformations."""
        context = self.ir_context
        context.load_dialect(dialect)

    @property
    def pass_pipeline(self) -> PassPipeline:
        """Build a ``PassPipeline`` from a snapshot of the current passes and callback."""
        frozen_passes = tuple(self.post_transforms)
        return PassPipeline(frozen_passes, self.post_callback)

    @classmethod
    def _get_func_info(
        cls,
        current_frame: FrameType | None,
        func: Callable[P, R],
    ) -> FuncInfo:
        """Collect the defining file, globals and undecorated AST of ``func``.

        ``current_frame`` is expected to be the frame of the method that
        received ``func`` (``parse_program``); its caller is taken to be the
        frame in which the function was decorated.
        """
        assert current_frame is not None
        caller = current_frame.f_back
        assert caller is not None

        file_name = caller.f_code.co_filename
        caller_globals = caller.f_globals

        # The retrieved source keeps its original indentation (e.g. for
        # nested definitions), so dedent it before parsing.
        module = ast.parse(textwrap.dedent(getsource(func.__code__)))
        definition = module.body[0]
        assert isinstance(definition, ast.FunctionDef)
        assert definition.name == func.__name__
        # Exactly one decorator is expected (presumably the parse_program
        # decorator itself); it is stripped from the returned AST.
        assert len(definition.decorator_list) == 1
        definition.decorator_list = []

        return FuncInfo(file_name, caller_globals, definition)

    @classmethod
    def _get_wrapped_program(
        cls, func: Callable[P, R], builder: PyASTBuilder
    ) -> PyASTProgram[P, R]:
        """Wrap ``func`` in a ``PyASTProgram`` that generates IR with ``builder``.

        ``functools.update_wrapper`` copies ``func``'s metadata (name,
        docstring, ...) onto the program so it can stand in for ``func``.
        """
        program_cls = PyASTProgram[P, R]
        program = program_cls(name=func.__name__, func=func, _builder=builder)
        functools.update_wrapper(program, func)
        assert program.__doc__ == func.__doc__
        return program

    def parse_program(self, func: Callable[P, R]) -> PyASTProgram[P, R]:
        """Turn ``func`` (typically via decorator syntax) into a ``PyASTProgram``."""
        source_file, source_globals, source_ast = self._get_func_info(
            currentframe(), func
        )
        builder = PyASTBuilder(
            type_registry=self.type_registry,
            function_registry=self.function_registry,
            file=source_file,
            globals=source_globals,
            function_ast=source_ast,
            build_context=self.ir_context,
            post_transforms=self.pass_pipeline,
        )
        return self._get_wrapped_program(func, builder)

type_registry: TypeRegistry = field(default_factory=TypeRegistry) class-attribute instance-attribute

Mappings between source code and IR type.

function_registry: FunctionRegistry = field(default_factory=FunctionRegistry) class-attribute instance-attribute

Mappings between functions and their operation types.

post_transforms: list[ModulePass] = field(default_factory=(lambda: [FrontendDesymrefyPass()])) class-attribute instance-attribute

An ordered list of passes to apply to the built module.

post_callback: Callable[[ModulePass, ModuleOp, ModulePass], None] | None = default_pipeline_callback class-attribute instance-attribute

Callback to run between post transforms.

ir_context: Context = field(default_factory=(lambda: Context(allow_unregistered=True))) class-attribute instance-attribute

The xDSL context to use when applying transformations to the built module.

pass_pipeline: PassPipeline property

Get a pass pipeline from the context state.

__init__(type_registry: TypeRegistry = <factory>, function_registry: FunctionRegistry = <factory>, post_transforms: list[ModulePass] = <factory>, post_callback: Callable[[ModulePass, ModuleOp, ModulePass], None] | None = default_pipeline_callback, ir_context: Context = <factory>) -> None — parameters shown as `<factory>` receive a fresh default on every construction (`TypeRegistry()`, `FunctionRegistry()`, `[FrontendDesymrefyPass()]` and `Context(allow_unregistered=True)` respectively), so separate instances never share them.

register_type(source_type: type, ir_type: TypeAttribute) -> None

Associate a type in the source code with its type in the IR.

Source code in xdsl/frontend/pyast/context.py
68
69
70
71
72
73
74
def register_type(
    self,
    source_type: type,
    ir_type: TypeAttribute,
) -> None:
    """Record that ``source_type`` in Python source lowers to ``ir_type``.

    Args:
        source_type: The Python type as written in frontend source code.
        ir_type: The IR type attribute used to represent ``source_type``.
    """
    registry = self.type_registry
    registry.insert(source_type, ir_type)

register_function(function: Callable[..., Any], ir_constructor: Callable[..., Operation]) -> None

Associate a method on an object in the source code with its IR implementation.

Source code in xdsl/frontend/pyast/context.py
76
77
78
79
80
def register_function(
    self, function: Callable[..., Any], ir_constructor: Callable[..., Operation]
) -> None:
    """Bind a source-level function or method to the callable building its IR.

    Args:
        function: The function/method object as referenced from source code.
        ir_constructor: Callable returning the ``Operation`` emitted for a
            call to ``function``.
    """
    registry = self.function_registry
    registry.insert(function, ir_constructor)

register_post_transform(transform: ModulePass) -> None

Add a module pass to be run on the generated IR.

Source code in xdsl/frontend/pyast/context.py
82
83
84
def register_post_transform(self, transform: ModulePass) -> None:
    """Append ``transform`` to the end of the post-transform pass list."""
    passes = self.post_transforms
    passes.append(transform)

register_dialect(dialect: Dialect) -> None

Add a dialect to the context used for transformation.

Source code in xdsl/frontend/pyast/context.py
86
87
88
def register_dialect(self, dialect: Dialect) -> None:
    """Load ``dialect`` into the xDSL context used for transformations."""
    context = self.ir_context
    context.load_dialect(dialect)

parse_program(func: Callable[P, R]) -> PyASTProgram[P, R]

Get a program wrapper by decorating a function.

Source code in xdsl/frontend/pyast/context.py
139
140
141
142
143
144
145
146
147
148
149
150
151
def parse_program(self, func: Callable[P, R]) -> PyASTProgram[P, R]:
    """Turn ``func`` (typically via decorator syntax) into a ``PyASTProgram``.

    The function's file, globals and AST come from ``_get_func_info``; they
    are handed to a ``PyASTBuilder`` configured with this context's
    registries, IR context and post-transform pipeline.
    """
    source_file, source_globals, source_ast = self._get_func_info(
        currentframe(), func
    )
    builder = PyASTBuilder(
        type_registry=self.type_registry,
        function_registry=self.function_registry,
        file=source_file,
        globals=source_globals,
        function_ast=source_ast,
        build_context=self.ir_context,
        post_transforms=self.pass_pipeline,
    )
    return self._get_wrapped_program(func, builder)

CodeContext dataclass

Bases: AbstractContextManager[Any]

A `with CodeContext(...):` block marks the scope in which code in the custom DSL can be written. This code will be translated to xDSL/MLIR.

Source code in xdsl/frontend/pyast/context.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
@dataclass
class CodeContext(AbstractContextManager[Any]):
    """
    Context manager whose ``with`` block holds code written in the custom DSL.

    On entry, the statements of the enclosing ``with`` block are located in
    the caller's source and stored on ``program``. On exit, they are checked
    and recorded so they can later be translated to xDSL/MLIR.
    """

    program: FrontendProgram
    """
    Frontend program that collects the captured code and can be compiled and
    translated to xDSL/MLIR.
    """

    def __enter__(self) -> None:
        # Frame that executed the ``with`` statement.
        caller = _getframe(1)
        code = caller.f_code
        self.program.file = code.co_filename
        tree = ast.parse(getsource(caller))

        # These globals include, in particular, the caller's imports.
        self.program.globals = caller.f_globals

        # The parsed source starts at ``code.co_firstlineno``; convert the
        # frame's absolute line number into that source's 1-based numbering.
        with_line = caller.f_lineno - code.co_firstlineno + 1
        # The walk does not stop at the first match: if several ``with``
        # nodes share that line, the last one visited wins.
        for node in ast.walk(tree):
            if isinstance(node, ast.With) and node.lineno == with_line:
                # The body of this ``with`` is the DSL program to compile.
                self.program.stmts = node.body

    def __exit__(self, *args: object):
        # The whole block has been processed: check that it is well-formed
        # and can be compiled/executed, and keep the result for code generation.
        program = self.program
        assert program.stmts is not None
        program.functions_and_blocks = PythonCodeCheck.run(
            program.stmts, program.file
        )

program: FrontendProgram instance-attribute

Underlying frontend program which can be compiled and translated to xDSL/MLIR.

__init__(program: FrontendProgram) -> None

__enter__() -> None

Source code in xdsl/frontend/pyast/context.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def __enter__(self) -> None:
    # Frame that executed the ``with`` statement.
    caller = _getframe(1)
    code = caller.f_code
    self.program.file = code.co_filename
    tree = ast.parse(getsource(caller))

    # These globals include, in particular, the caller's imports.
    self.program.globals = caller.f_globals

    # The parsed source starts at ``code.co_firstlineno``; convert the
    # frame's absolute line number into that source's 1-based numbering.
    with_line = caller.f_lineno - code.co_firstlineno + 1
    # The walk does not stop at the first match: if several ``with`` nodes
    # share that line, the last one visited wins.
    for node in ast.walk(tree):
        if isinstance(node, ast.With) and node.lineno == with_line:
            # The body of this ``with`` is the DSL program to compile.
            self.program.stmts = node.body

__exit__(*args: object)

Source code in xdsl/frontend/pyast/context.py
188
189
190
191
192
193
194
def __exit__(self, *args: object):
    # The whole block has been processed: check that it is well-formed and
    # can be compiled/executed, and keep the result for code generation.
    program = self.program
    assert program.stmts is not None
    program.functions_and_blocks = PythonCodeCheck.run(
        program.stmts, program.file
    )

default_pipeline_callback(_previous_pass: ModulePass, module: ModuleOp, _next_pass: ModulePass) -> None

Default callback to verify the module after each transformation pass.

Source code in xdsl/frontend/pyast/context.py
36
37
38
39
40
def default_pipeline_callback(
    _previous_pass: ModulePass, module: ModuleOp, _next_pass: ModulePass
) -> None:
    """Verify ``module`` between two transformation passes.

    Args:
        _previous_pass: Pass that just ran (unused).
        module: Module to verify; any error raised by ``module.verify()``
            propagates to the caller.
        _next_pass: Pass about to run (unused).
    """
    check = module.verify
    check()