Skip to content

Python code check

python_code_check

BlockMap = dict[str, ast.FunctionDef] module-attribute

FunctionData = tuple[ast.FunctionDef, BlockMap] module-attribute

FunctionMap = dict[str, FunctionData] module-attribute

PythonCodeCheck dataclass

Source code in xdsl/frontend/pyast/utils/python_code_check.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
@dataclass
class PythonCodeCheck:
    @staticmethod
    def run(stmts: Sequence[ast.stmt], file: str | None) -> FunctionMap:
        """
        Validates that the Python code placed inside `CodeContext` is
        supported, raising an exception on the first unsupported construct.

        Two passes are performed, in order:

            1. A structural check (see `CheckStructure`): the code must be a
               sequence of functions, functions may not contain inner
               functions, etc.

            2. A constant check (see `CheckAndInlineConstants`): constant
               expressions must be placed correctly, and each one is inlined
               into the AST.

        Returns the functions (and their blocks) collected by the structural
        check.
        """
        structure = CheckStructure(file)
        structure.run(stmts)

        # The constant pass assumes the code is already well-structured, so it
        # must only run after the structural check succeeded.
        CheckAndInlineConstants.run(stmts, file)

        return structure.functions_and_blocks

__init__() -> None

run(stmts: Sequence[ast.stmt], file: str | None) -> FunctionMap staticmethod

Checks if Python code within CodeContext is supported. On unsupported cases, an exception is raised.

Performed checks and transformations:

1. Checks structure of code inside `CodeContext`. For example, no
   inner functions are allowed, etc. For more information see the
   docstring of `CheckStructure`.

2. Checks the placement of constant expressions and inlines them
   into the AST.
Source code in xdsl/frontend/pyast/utils/python_code_check.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
@staticmethod
def run(stmts: Sequence[ast.stmt], file: str | None) -> FunctionMap:
    """
    Validates that the Python code placed inside `CodeContext` is supported,
    raising an exception on the first unsupported construct.

    Two passes are performed, in order:

        1. A structural check (see `CheckStructure`): the code must be a
           sequence of functions, functions may not contain inner functions,
           etc.

        2. A constant check (see `CheckAndInlineConstants`): constant
           expressions must be placed correctly, and each one is inlined into
           the AST.

    Returns the functions (and their blocks) collected by the structural check.
    """
    structure = CheckStructure(file)
    structure.run(stmts)

    # The constant pass assumes the code is already well-structured, so it must
    # only run after the structural check succeeded.
    CheckAndInlineConstants.run(stmts, file)

    return structure.functions_and_blocks

CheckStructure dataclass

Ensures that the front-end program can be lowered to xDSL.

Any code written within CodeContext must be organized as a sequence of functions (possibly with a dedicated entry point), for example:

with CodeContext(p):
    def foo(x: i32) -> i32:
        return x
    def bar():
        y: i32 = foo(100)
        return

    def main():
        bar()
        return

For each function, it holds that: 1) No function contains inner functions. 2) Every function has an explicit terminator: a return statement.

Additionally, any function can contain explicitly defined blocks, for example:

with CodeContext(p):
    def foo(x: i32) -> i32:
        @block
        def bb0(y: i32) -> i32:
            # unconditional branch to another block
            return bb1(y)

        @block
        def bb1(y: i32) -> i32:
            # terminator for the function
            return y

        # specifies the entry block
        return bb0(x)

For each block, it holds that: 1) No block has inner functions or nested blocks. 2) Every block has an explicit terminator: a return statement. The terminator can either transfer control flow to another block, or terminate the enclosing function. 3) It is up to the user to ensure that the control flow is transferred correctly, e.g. to avoid infinite cycles.

Source code in xdsl/frontend/pyast/utils/python_code_check.py
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
@dataclass
class CheckStructure:
    """
    Ensures that the front-end program can be lowered to xDSL.

    Any code written within `CodeContext` must be organized as a sequence of
    functions (possibly with a dedicated entry point), for example:

    ```
    with CodeContext(p):
        def foo(x: i32) -> i32:
            return x
        def bar():
            y: i32 = foo(100)
            return

        def main():
            bar()
            return
    ```

    For each function, it holds that:
        1) No function contains inner functions (only blocks are allowed).
        2) Every function has an explicit terminator: a `return` statement.

    Additionally, any function can contain explicitly defined blocks, for
    example:

    ```
    with CodeContext(p):
        def foo(x: i32) -> i32:
            @block
            def bb0(y: i32) -> i32:
                # unconditional branch to another block
                return bb1(y)

            @block
            def bb1(y: i32) -> i32:
                # terminator for the function
                return y

            # specifies the entry block
            return bb0(x)
    ```

    For each block, it holds that:
        1) No block has inner functions or nested blocks.
        2) Every block has an explicit terminator: a `return` statement. The
           terminator can either transfer control flow to another block, or
           terminate the enclosing function.
        3) If a function defines blocks, exactly one of them must terminate
           the function (i.e. its `return` is not a branch to another block).
        4) It is up to the user to ensure that the control flow is transferred
           correctly, e.g. to avoid infinite cycles.
    """

    file: str | None = field(default=None)
    """File for error reporting."""

    functions_and_blocks: FunctionMap = field(default_factory=FunctionMap)
    """
    Contains all information about functions and blocks. Populated during the
    structure check.
    """

    def run(self, stmts: Sequence[ast.stmt]) -> None:
        """
        Checks the statements placed directly inside `CodeContext`.

        Every statement must be `pass`, a constant expression, or a function
        definition. Each function is recorded in `functions_and_blocks`
        together with the blocks defined directly in its body, and afterwards
        the structure of every recorded function is checked.

        Raises:
            CodeGenerationException: if any structural rule is violated.
        """
        for stmt in stmts:
            # Allow constant expression statements or pass.
            if is_constant_stmt(stmt) or isinstance(stmt, ast.Pass):
                continue

            # TODO: Right now we want all code to be placed in functions to make
            # code generation easier. This is limiting but can be fixed easily by
            # placing the whole AST into a dummy function, and performing code
            # generation on it. The only challenge is to make error messages
            # consistent.
            if isinstance(stmt, ast.FunctionDef):
                # Function should be a top-level operation.
                if is_block(stmt):
                    raise CodeGenerationException(
                        self.file,
                        stmt.lineno,
                        stmt.col_offset,
                        f"Expected a function, but found a block '{stmt.name}'."
                        " Only functions can be declared in the CodeContext",
                    )

                # Record this function.
                if stmt.name in self.functions_and_blocks:
                    line = self.functions_and_blocks[stmt.name][0].lineno
                    col = self.functions_and_blocks[stmt.name][0].col_offset
                    raise CodeGenerationException(
                        self.file,
                        stmt.lineno,
                        stmt.col_offset,
                        f"Function '{stmt.name}' is already defined at line "
                        f"{line} column {col}.",
                    )
                self.functions_and_blocks[stmt.name] = (stmt, dict())

                # Every function must have an explicit terminator, i.e. a return
                # statement. Using pass is not allowed. This design makes code
                # generation easier and can be relaxed in the future.
                if len(stmt.body) == 0:
                    raise CodeGenerationException(
                        self.file,
                        stmt.lineno,
                        stmt.col_offset,
                        f"Function '{stmt.name}' must have an explicit terminator."
                        " Have you tried adding a return statement?",
                    )
                if not isinstance(stmt.body[-1], ast.Return):
                    raise CodeGenerationException(
                        self.file,
                        stmt.lineno,
                        stmt.col_offset,
                        f"Function '{stmt.name}' must have an explicit return"
                        " in the end.",
                    )

                # Lastly, record basic block information that we can check
                # afterwards.
                blocks = self.functions_and_blocks[stmt.name][1]
                for inner_stmt in stmt.body:
                    if not (
                        isinstance(inner_stmt, ast.FunctionDef) and is_block(inner_stmt)
                    ):
                        continue
                    if inner_stmt.name in blocks:
                        previous = blocks[inner_stmt.name]
                        # Report the location of the duplicate block itself (not
                        # of the enclosing function) and point back at the
                        # first definition in the message.
                        raise CodeGenerationException(
                            self.file,
                            inner_stmt.lineno,
                            inner_stmt.col_offset,
                            f"Block '{inner_stmt.name}' is already defined at line "
                            f"{previous.lineno} column {previous.col_offset}.",
                        )
                    blocks[inner_stmt.name] = inner_stmt
                continue

            # Otherwise, not a function, pass nor constant expression. Abort.
            raise CodeGenerationException(
                self.file,
                stmt.lineno,
                stmt.col_offset,
                "Frontend program must consist of functions or constant expressions.",
            )

        # Check structure of all functions. Block information has already been
        # recorded above, so branch targets can be resolved.
        for function_data in self.functions_and_blocks.values():
            self._check_function_structure(function_data[0])

    def _is_branch(self, function_name: str, node: ast.expr | None) -> bool:
        """
        Returns true if the terminator node is an unconditional branch, i.e. a
        call to one of the blocks recorded for `function_name`.
        """
        return (
            node is not None
            and isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in self.functions_and_blocks[function_name][1]
        )

    def _is_cond_branch(self, function_name: str, node: ast.expr | None) -> bool:
        """
        Returns true if the terminator node is a conditional branch, i.e. an
        `a(...) if cond else b(...)` expression where both `a` and `b` are
        blocks recorded for `function_name`.
        """
        return (
            node is not None
            and isinstance(node, ast.IfExp)
            and isinstance(node.body, ast.Call)
            and isinstance(node.body.func, ast.Name)
            and node.body.func.id in self.functions_and_blocks[function_name][1]
            and isinstance(node.orelse, ast.Call)
            and isinstance(node.orelse.func, ast.Name)
            and node.orelse.func.id in self.functions_and_blocks[function_name][1]
        )

    def _check_block_structure(self, function_name: str, node: ast.FunctionDef) -> bool:
        """
        Checks that the block `node` (inside function `function_name`) is
        well-formed.

        Returns True if the block's terminator branches (conditionally or not)
        to another block of the same function, and False otherwise, i.e. when
        the block terminates the function.

        Raises:
            CodeGenerationException: if the block contains a nested function or
                block, or does not end with a `return` statement.
        """
        for stmt in node.body:
            # No inner functions or nested blocks.
            if isinstance(stmt, ast.FunctionDef):
                if is_block(stmt):
                    raise CodeGenerationException(
                        self.file,
                        stmt.lineno,
                        stmt.col_offset,
                        f"Cannot have a nested block '{stmt.name}'"
                        f" inside the block '{node.name}'.",
                    )
                else:
                    raise CodeGenerationException(
                        self.file,
                        stmt.lineno,
                        stmt.col_offset,
                        f"Cannot have a nested function '{stmt.name}'"
                        f" inside the block '{node.name}'.",
                    )

        # Check blocks have an explicit terminator.
        if len(node.body) == 0 or not isinstance(node.body[-1], ast.Return):
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Block '{node.name}' must have an explicit terminator."
                " Have you tried adding a return statement?",
            )

        # Check if the terminator of the block is well-formed. It can
        # terminate function, branch unconditionally to another block, or be
        # a conditional branch.
        assert isinstance(node.body[-1], ast.Return)
        terminator = node.body[-1].value
        return self._is_branch(function_name, terminator) or self._is_cond_branch(
            function_name, terminator
        )

    def _check_function_structure(self, node: ast.FunctionDef):
        """
        Checks the body of function `node`, including every block it defines.

        Raises:
            CodeGenerationException: if the function contains inner functions,
                has operations outside of blocks while defining blocks, or does
                not have exactly one terminating block when blocks are defined.
        """
        # Functions cannot have inner functions but can have blocks inside
        # which we still have to check.
        num_explicit_blocks = len(self.functions_and_blocks[node.name][1])
        num_explicit_blocks_with_branches = 0

        # The last statement is the function's `return` (checked in `run`).
        for stmt in node.body[:-1]:
            # Constant expressions can be placed inside the function.
            if is_constant_stmt(stmt):
                continue

            # If there are explicit blocks, no operations are allowed outside of
            # them.
            if num_explicit_blocks > 0 and not isinstance(stmt, ast.FunctionDef):
                raise CodeGenerationException(
                    self.file,
                    stmt.lineno,
                    stmt.col_offset,
                    f"Function '{node.name}' cannot contain operations outside"
                    " of blocks apart from explicit entry point or constant "
                    "expressions.",
                )

            # Otherwise we allow anything, and only have to carefully look at
            # inner functions.
            if not isinstance(stmt, ast.FunctionDef):
                continue

            # Only blocks are allowed
            if not is_block(stmt):
                raise CodeGenerationException(
                    self.file,
                    stmt.lineno,
                    stmt.col_offset,
                    f"Cannot have an inner function '{stmt.name}' inside "
                    f"the function '{node.name}'.",
                )

            # Check the block and record if its terminator is a branch or not.
            if self._check_block_structure(node.name, stmt):
                num_explicit_blocks_with_branches += 1

        # Last check: we must have exactly one terminating block if blocks are
        # explicitly defined. This also applies to a function with a single
        # block: if that block only branches, it can never leave the function.
        num_explicit_terminating_blocks = (
            num_explicit_blocks - num_explicit_blocks_with_branches
        )
        if num_explicit_blocks > 0 and num_explicit_terminating_blocks == 0:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Function '{node.name}' does not have a terminating block.",
            )
        if num_explicit_terminating_blocks > 1:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Function '{node.name}' expected one terminating block, got"
                f" {num_explicit_terminating_blocks}.",
            )

file: str | None = field(default=None) class-attribute instance-attribute

File for error reporting.

functions_and_blocks: FunctionMap = field(default_factory=FunctionMap) class-attribute instance-attribute

Contains all information about functions and blocks. Populated during the structure check.

__init__(file: str | None = None, functions_and_blocks: FunctionMap = FunctionMap()) -> None

run(stmts: Sequence[ast.stmt]) -> None

Source code in xdsl/frontend/pyast/utils/python_code_check.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
def run(self, stmts: Sequence[ast.stmt]) -> None:
    """
    Checks the statements placed directly inside `CodeContext`.

    Every statement must be `pass`, a constant expression, or a function
    definition. Each function is recorded in `functions_and_blocks` together
    with the blocks defined directly in its body, and afterwards the structure
    of every recorded function is checked.

    Raises:
        CodeGenerationException: if any structural rule is violated.
    """
    for stmt in stmts:
        # Allow constant expression statements or pass.
        if is_constant_stmt(stmt) or isinstance(stmt, ast.Pass):
            continue

        # TODO: Right now we want all code to be placed in functions to make
        # code generation easier. This is limiting but can be fixed easily by
        # placing the whole AST into a dummy function, and performing code
        # generation on it. The only challenge is to make error messages
        # consistent.
        if isinstance(stmt, ast.FunctionDef):
            # Function should be a top-level operation.
            if is_block(stmt):
                raise CodeGenerationException(
                    self.file,
                    stmt.lineno,
                    stmt.col_offset,
                    f"Expected a function, but found a block '{stmt.name}'."
                    " Only functions can be declared in the CodeContext",
                )

            # Record this function.
            if stmt.name in self.functions_and_blocks:
                line = self.functions_and_blocks[stmt.name][0].lineno
                col = self.functions_and_blocks[stmt.name][0].col_offset
                raise CodeGenerationException(
                    self.file,
                    stmt.lineno,
                    stmt.col_offset,
                    f"Function '{stmt.name}' is already defined at line "
                    f"{line} column {col}.",
                )
            self.functions_and_blocks[stmt.name] = (stmt, dict())

            # Every function must have an explicit terminator, i.e. a return
            # statement. Using pass is not allowed. This design makes code
            # generation easier and can be relaxed in the future.
            if len(stmt.body) == 0:
                raise CodeGenerationException(
                    self.file,
                    stmt.lineno,
                    stmt.col_offset,
                    f"Function '{stmt.name}' must have an explicit terminator."
                    " Have you tried adding a return statement?",
                )
            if not isinstance(stmt.body[-1], ast.Return):
                raise CodeGenerationException(
                    self.file,
                    stmt.lineno,
                    stmt.col_offset,
                    f"Function '{stmt.name}' must have an explicit return"
                    " in the end.",
                )

            # Lastly, record basic block information that we can check
            # afterwards.
            blocks = self.functions_and_blocks[stmt.name][1]
            for inner_stmt in stmt.body:
                if not (
                    isinstance(inner_stmt, ast.FunctionDef) and is_block(inner_stmt)
                ):
                    continue
                if inner_stmt.name in blocks:
                    previous = blocks[inner_stmt.name]
                    # Report the location of the duplicate block itself (not of
                    # the enclosing function) and point back at the first
                    # definition in the message.
                    raise CodeGenerationException(
                        self.file,
                        inner_stmt.lineno,
                        inner_stmt.col_offset,
                        f"Block '{inner_stmt.name}' is already defined at line "
                        f"{previous.lineno} column {previous.col_offset}.",
                    )
                blocks[inner_stmt.name] = inner_stmt
            continue

        # Otherwise, not a function, pass nor constant expression. Abort.
        raise CodeGenerationException(
            self.file,
            stmt.lineno,
            stmt.col_offset,
            "Frontend program must consist of functions or constant expressions.",
        )

    # Check structure of all functions. Block information has already been
    # recorded above, so branch targets can be resolved.
    for function_data in self.functions_and_blocks.values():
        self._check_function_structure(function_data[0])

CheckAndInlineConstants dataclass

This class is responsible for checking that the constants defined in the frontend program are valid. Every valid constant is inlined as a new AST node.

The algorithm for checking and inlining is iterative. When a new constant definition is encountered, the algorithm tries to inline it. This way frontend programs can define constants such as:

a: Const[i32] = 1 + len([1, 2, 3, 4])
b: Const[i32] = a * a
# here b = 25
Source code in xdsl/frontend/pyast/utils/python_code_check.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
@dataclass
class CheckAndInlineConstants:
    """
    This class is responsible for checking that the constants defined in the
    frontend program are valid. Every valid constant is inlined as a new AST
    node.

    The algorithm for checking and inlining is iterative. When a new constant
    definition is encountered, the algorithm tries to inline it. This way
    frontend programs can define constants such as:

    ```
    a: Const[i32] = 1 + len([1, 2, 3, 4])
    b: Const[i32] = a * a
    # here b = 25
    ```
    """

    @staticmethod
    def run(stmts: Sequence[ast.stmt], file: str | None) -> None:
        """Checks and inlines all constants in `stmts`, with no known variables."""
        CheckAndInlineConstants.run_with_variables(stmts, set(), file)

    @staticmethod
    def run_with_variables(
        stmts: Sequence[ast.stmt], defined_variables: set[str], file: str | None
    ) -> None:
        """
        Checks and inlines the constants defined in `stmts`.

        `defined_variables` holds the names of non-constant variables (and
        function/block arguments) already defined in the current scope. It is
        updated in place as new variables are encountered. Redefining any of
        these names as a constant is an error.

        Raises:
            CodeGenerationException: if a constant is assigned to something
                other than a name, redefines an existing variable, cannot be
                evaluated, or evaluates to an unsupported type.
        """
        for i, stmt in enumerate(stmts):
            # This variable (`a = ...`) can be redefined as a constant, and so
            # we have to keep track of these to raise an exception.
            if (
                isinstance(stmt, ast.Assign)
                and len(stmt.targets) == 1
                and isinstance(stmt.targets[0], ast.Name)
            ):
                defined_variables.add(stmt.targets[0].id)
                continue

            # Similarly, this case (`a: i32 = ...`) can also be redefined as a
            # constant.
            if (
                isinstance(stmt, ast.AnnAssign)
                and isinstance(stmt.target, ast.Name)
                and not is_constant(stmt.annotation)
            ):
                defined_variables.add(stmt.target.id)
                continue

            # This is a constant.
            if isinstance(stmt, ast.AnnAssign) and is_constant(stmt.annotation):
                if not isinstance(stmt.target, ast.Name):
                    raise CodeGenerationException(
                        file,
                        stmt.lineno,
                        stmt.col_offset,
                        "All constant expressions have to be assigned to "
                        "'ast.Name' nodes.",
                    )

                name = stmt.target.id

                # A variable or argument with this name already exists in the
                # current scope: redefining it as a constant is not allowed.
                if name in defined_variables:
                    raise CodeGenerationException(
                        file,
                        stmt.lineno,
                        stmt.col_offset,
                        f"Constant '{name}' is already defined.",
                    )

                try:
                    assert stmt.value is not None
                    # NOTE: this evaluates the frontend program's own source
                    # text with the host interpreter; it must never be given
                    # untrusted code.
                    value = eval(ast.unparse(stmt.value))
                except Exception as err:
                    # TODO: This error message can be improved by matching exact
                    # exceptions returned by `eval` call.
                    raise CodeGenerationException(
                        file,
                        stmt.lineno,
                        stmt.col_offset,
                        f"Non-constant expression cannot be assigned to "
                        f"constant variable '{name}' or cannot be evaluated.",
                    ) from err

                # For now, support primitive types only and add a guard to abort
                # in other cases.
                if not isinstance(value, (int, float)):
                    raise CodeGenerationException(
                        file,
                        stmt.lineno,
                        stmt.col_offset,
                        f"Constant '{name}' has evaluated type '{type(value)}' "
                        "which is not supported.",
                    )

                # TODO: We should typecheck the value against the type. This can
                # get tricky since ints can overflow, etc. For example, `a:
                # Const[i16] = 100000000` should give an error.
                new_node = ast.Constant(value)
                inliner = ConstantInliner(name, new_node, file)
                # Inline the constant into every statement that follows it.
                for candidate in stmts[(i + 1) :]:
                    inliner.visit(candidate)
                continue

            # In case of a function/block definition, we must ensure we process
            # the nested list of statements as well. Note that if we reached
            # this then all constants above `i` must have been already inlined.
            # Hence, it is sufficient to check the function body only.
            if isinstance(stmt, ast.FunctionDef):
                new_defined_variables = {arg.arg for arg in stmt.args.args}
                CheckAndInlineConstants.run_with_variables(
                    stmt.body, new_defined_variables, file
                )

__init__() -> None

run(stmts: Sequence[ast.stmt], file: str | None) -> None staticmethod

Source code in xdsl/frontend/pyast/utils/python_code_check.py
346
347
348
@staticmethod
def run(stmts: Sequence[ast.stmt], file: str | None) -> None:
    """Checks and inlines the constants of `stmts`, starting from an empty scope."""
    known_variables: set[str] = set()
    CheckAndInlineConstants.run_with_variables(stmts, known_variables, file)

run_with_variables(stmts: Sequence[ast.stmt], defined_variables: set[str], file: str | None) -> None staticmethod

Source code in xdsl/frontend/pyast/utils/python_code_check.py
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
@staticmethod
def run_with_variables(
    stmts: Sequence[ast.stmt], defined_variables: set[str], file: str | None
) -> None:
    """
    Checks and inlines the constants defined in `stmts`.

    `defined_variables` holds the names of non-constant variables (and
    function/block arguments) already defined in the current scope. It is
    updated in place as new variables are encountered. Redefining any of these
    names as a constant is an error.

    Raises:
        CodeGenerationException: if a constant is assigned to something other
            than a name, redefines an existing variable, cannot be evaluated,
            or evaluates to an unsupported type.
    """
    for i, stmt in enumerate(stmts):
        # This variable (`a = ...`) can be redefined as a constant, and so
        # we have to keep track of these to raise an exception.
        if (
            isinstance(stmt, ast.Assign)
            and len(stmt.targets) == 1
            and isinstance(stmt.targets[0], ast.Name)
        ):
            defined_variables.add(stmt.targets[0].id)
            continue

        # Similarly, this case (`a: i32 = ...`) can also be redefined as a
        # constant.
        if (
            isinstance(stmt, ast.AnnAssign)
            and isinstance(stmt.target, ast.Name)
            and not is_constant(stmt.annotation)
        ):
            defined_variables.add(stmt.target.id)
            continue

        # This is a constant.
        if isinstance(stmt, ast.AnnAssign) and is_constant(stmt.annotation):
            if not isinstance(stmt.target, ast.Name):
                raise CodeGenerationException(
                    file,
                    stmt.lineno,
                    stmt.col_offset,
                    "All constant expressions have to be assigned to "
                    "'ast.Name' nodes.",
                )

            name = stmt.target.id

            # A variable or argument with this name already exists in the
            # current scope: redefining it as a constant is not allowed.
            if name in defined_variables:
                raise CodeGenerationException(
                    file,
                    stmt.lineno,
                    stmt.col_offset,
                    f"Constant '{name}' is already defined.",
                )

            try:
                assert stmt.value is not None
                # NOTE: this evaluates the frontend program's own source text
                # with the host interpreter; it must never be given untrusted
                # code.
                value = eval(ast.unparse(stmt.value))
            except Exception as err:
                # TODO: This error message can be improved by matching exact
                # exceptions returned by `eval` call.
                raise CodeGenerationException(
                    file,
                    stmt.lineno,
                    stmt.col_offset,
                    f"Non-constant expression cannot be assigned to "
                    f"constant variable '{name}' or cannot be evaluated.",
                ) from err

            # For now, support primitive types only and add a guard to abort
            # in other cases.
            if not isinstance(value, (int, float)):
                raise CodeGenerationException(
                    file,
                    stmt.lineno,
                    stmt.col_offset,
                    f"Constant '{name}' has evaluated type '{type(value)}' "
                    "which is not supported.",
                )

            # TODO: We should typecheck the value against the type. This can
            # get tricky since ints can overflow, etc. For example, `a:
            # Const[i16] = 100000000` should give an error.
            new_node = ast.Constant(value)
            inliner = ConstantInliner(name, new_node, file)
            # Inline the constant into every statement that follows it.
            for candidate in stmts[(i + 1) :]:
                inliner.visit(candidate)
            continue

        # In case of a function/block definition, we must ensure we process
        # the nested list of statements as well. Note that if we reached
        # this then all constants above `i` must have been already inlined.
        # Hence, it is sufficient to check the function body only.
        if isinstance(stmt, ast.FunctionDef):
            new_defined_variables = {arg.arg for arg in stmt.args.args}
            CheckAndInlineConstants.run_with_variables(
                stmt.body, new_defined_variables, file
            )

ConstantInliner dataclass

Bases: NodeTransformer

Given the name of a constant and a corresponding AST node, ConstantInliner traverses the AST and replaces the uses of the name with the node. Additionally, it is responsible for performing various checks whether the constant value is correctly used. In cases of a misuse (e.g. assigning to a constant), an exception is raised.

Source code in xdsl/frontend/pyast/utils/python_code_check.py
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
@dataclass
class ConstantInliner(ast.NodeTransformer):
    """
    Given the name of a constant and a corresponding AST node, `ConstantInliner`
    traverses the AST and replaces the uses of the `name` with the node.
    Additionally, it is responsible for performing various checks whether the
    constant value is correctly used. In cases of a misuse (e.g. assigning to a
    constant), an exception is raised.
    """

    name: str
    """The name of the constant to inline."""

    new_node: ast.Constant
    """New AST node to inline."""

    file: str | None = field(default=None)
    """Path to the file containing the program."""

    def visit_Assign(self, node: ast.Assign) -> ast.Assign:
        if (
            len(node.targets) == 1
            and isinstance(node.targets[0], ast.Name)
            and node.targets[0].id == self.name
        ):
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Constant '{self.name}' is already defined and cannot be assigned to.",
            )
        node.value = self.visit(node.value)
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign:
        if isinstance(node.target, ast.Name) and node.target.id == self.name:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Constant '{self.name}' is already defined.",
            )
        assert node.value is not None
        node.value = self.visit(node.value)
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        for arg in node.args.args:
            if arg.arg == self.name:
                raise CodeGenerationException(
                    self.file,
                    node.lineno,
                    node.col_offset,
                    f"Constant '{self.name}' is already defined and cannot be "
                    "used as a function/block argument name.",
                )
        for stmt in node.body:
            self.visit(stmt)
        return node

    def visit_Name(self, node: ast.Name) -> ast.Name | ast.Constant:
        if node.id == self.name:
            return self.new_node
        else:
            return node

name: str instance-attribute

The name of the constant to inline.

new_node: ast.Constant instance-attribute

New AST node to inline.

file: str | None = field(default=None) class-attribute instance-attribute

Path to the file containing the program.

__init__(name: str, new_node: ast.Constant, file: str | None = None) -> None

visit_Assign(node: ast.Assign) -> ast.Assign

Source code in xdsl/frontend/pyast/utils/python_code_check.py
451
452
453
454
455
456
457
458
459
460
461
462
463
464
def visit_Assign(self, node: ast.Assign) -> ast.Assign:
    """Reject assignments that rebind the constant; inline its other uses."""
    # Look through every target, including chained (`a = N = 1`) and
    # unpacking (`a, N = ...`, `*N, = ...`) forms. Only Store-context names
    # bind, so a read such as the index in `arr[N] = 1` is not an error.
    if any(
        isinstance(sub, ast.Name)
        and sub.id == self.name
        and isinstance(sub.ctx, ast.Store)
        for target in node.targets
        for sub in ast.walk(target)
    ):
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Constant '{self.name}' is already defined and cannot be assigned to.",
        )
    # Targets may still read the constant (e.g. `arr[N] = ...`).
    node.targets = [self.visit(target) for target in node.targets]
    node.value = self.visit(node.value)
    return node

visit_AnnAssign(node: ast.AnnAssign) -> ast.AnnAssign

Source code in xdsl/frontend/pyast/utils/python_code_check.py
466
467
468
469
470
471
472
473
474
475
476
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign:
    """Reject re-declarations of the constant; inline its other uses."""
    if isinstance(node.target, ast.Name) and node.target.id == self.name:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Constant '{self.name}' is already defined.",
        )
    # Non-name targets may read the constant (e.g. `arr[N]: T = ...`).
    # The annotation itself is not visited.
    node.target = self.visit(node.target)
    # A bare annotation (`x: T`) has no value to inline into.
    if node.value is not None:
        node.value = self.visit(node.value)
    return node

visit_FunctionDef(node: ast.FunctionDef) -> ast.FunctionDef

Source code in xdsl/frontend/pyast/utils/python_code_check.py
478
479
480
481
482
483
484
485
486
487
488
489
490
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
    """Reject parameters shadowing the constant, then inline in the body."""
    arguments = node.args
    # Check every kind of parameter, not only positional-or-keyword ones:
    # otherwise uses of a shadowing parameter inside the body would be
    # silently replaced by the constant.
    params = [
        *arguments.posonlyargs,
        *arguments.args,
        *arguments.kwonlyargs,
        *(p for p in (arguments.vararg, arguments.kwarg) if p is not None),
    ]
    if any(param.arg == self.name for param in params):
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Constant '{self.name}' is already defined and cannot be "
            "used as a function/block argument name.",
        )
    # Only the body is visited; decorators, defaults and annotations are not.
    node.body = [self.visit(stmt) for stmt in node.body]
    return node

visit_Name(node: ast.Name) -> ast.Name | ast.Constant

Source code in xdsl/frontend/pyast/utils/python_code_check.py
492
493
494
495
496
def visit_Name(self, node: ast.Name) -> ast.Name | ast.Constant:
    if node.id == self.name:
        return self.new_node
    else:
        return node