Skip to content

Code generation

code_generation

CodeGeneration dataclass

Source code in xdsl/frontend/pyast/code_generation.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
@dataclass
class CodeGeneration:
    """Entry point that turns Python function definitions into one xDSL module."""

    @staticmethod
    def run_with_type_converter(
        type_converter: TypeConverter,
        source: FunctionMap | ast.FunctionDef,
        file: str | None,
    ) -> builtin.ModuleOp:
        """Generates xDSL code and returns it encapsulated into a single module."""
        module = builtin.ModuleOp([])
        visitor = CodeGenerationVisitor(type_converter, module, file)

        # Normalise both accepted inputs to a lazy stream of function
        # definitions so that a single loop drives the visitor.
        if isinstance(source, ast.FunctionDef):
            pending = iter((source,))
        else:
            # Each map value is a pair; only its first element (the AST) is used.
            pending = (function_def for function_def, _ in source.values())

        for function_def in pending:
            visitor.visit(function_def)
        return module

__init__() -> None

run_with_type_converter(type_converter: TypeConverter, source: FunctionMap | ast.FunctionDef, file: str | None) -> builtin.ModuleOp staticmethod

Generates xDSL code and returns it encapsulated into a single module.

Source code in xdsl/frontend/pyast/code_generation.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
@staticmethod
def run_with_type_converter(
    type_converter: TypeConverter,
    source: FunctionMap | ast.FunctionDef,
    file: str | None,
) -> builtin.ModuleOp:
    """Generates xDSL code and returns it encapsulated into a single module."""
    module = builtin.ModuleOp([])
    visitor = CodeGenerationVisitor(type_converter, module, file)

    # A lone function definition is lowered directly.
    if isinstance(source, ast.FunctionDef):
        visitor.visit(source)
        return module

    # Otherwise every value of the map is a pair whose first element is the
    # function definition to lower; the second element is not needed here.
    for function_def, _ in source.values():
        visitor.visit(function_def)
    return module

CodeGenerationVisitor dataclass

Bases: NodeVisitor

Visitor that generates xDSL from the Python AST.

Source code in xdsl/frontend/pyast/code_generation.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
@dataclass(init=False)
class CodeGenerationVisitor(ast.NodeVisitor):
    """Visitor that generates xDSL from the Python AST."""

    type_converter: TypeConverter
    """Used for type conversion during code generation."""

    inserter: OpInserter
    """Used for inserting newly generated operations to the right block."""

    symbol_table: dict[str, Attribute] | None = field(default=None)
    """
    Maps local variable names to their xDSL types. A single dictionary is sufficient
    because inner functions and global variables are not allowed (yet).
    """

    file: str | None
    """Path of the file containing the program being processed."""

    def __init__(
        self,
        type_converter: TypeConverter,
        module: builtin.ModuleOp,
        file: str | None,
    ) -> None:
        self.type_converter = type_converter
        self.file = file

        assert len(module.body.blocks) == 1
        self.inserter = OpInserter(module.body.block)

    def get_symbol(self, node: ast.Name) -> Attribute:
        assert self.symbol_table is not None
        if node.id not in self.symbol_table:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Symbol '{node.id}' is not defined.",
            )
        return self.symbol_table[node.id]

    def visit(self, node: ast.AST) -> None:
        super().visit(node)

    def generic_visit(self, node: ast.AST) -> None:
        raise CodeGenerationException(
            self.file,
            getattr(node, "lineno"),
            getattr(node, "col_offset"),
            f"Unsupported Python AST node {str(node)}",
        )

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        # TODO: Implement assignemnt in the next patch.
        pass

    def visit_Assert(self, node: ast.Assert) -> None:
        self.visit(node.test)
        if node.msg is None:
            msg = ""
        else:
            if not isinstance(node.msg, ast.Constant) or not isinstance(
                node.msg.value, str
            ):
                raise CodeGenerationException(
                    self.file,
                    node.lineno,
                    node.col_offset,
                    "Expected a string constant for assertion message, found "
                    f"'ast.{type(node.msg).__qualname__}'",
                )
            msg = str(node.msg.value)
        op = cf.AssertOp(self.inserter.get_operand(), msg)
        self.inserter.insert_op(op)

    def visit_Assign(self, node: ast.Assign) -> None:
        # TODO: Implement assignemnt in the next patch.
        pass

    def visit_BinOp(self, node: ast.BinOp) -> None:
        op_name: str = node.op.__class__.__qualname__

        # Table with mappings of Python AST operator to Python methods.
        python_AST_operator_to_python_overload = {
            "Add": "__add__",
            "Sub": "__sub__",
            "Mult": "__mul__",
            "Div": "__truediv__",
            "FloorDiv": "__floordiv__",
            "Mod": "__mod__",
            "Pow": "__pow__",
            "LShift": "__lshift__",
            "RShift": "__rshift__",
            "BitOr": "__or__",
            "BitXor": "__xor__",
            "BitAnd": "__and__",
            "MatMult": "__matmul__",
        }

        if op_name not in python_AST_operator_to_python_overload:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Unexpected binary operation {op_name}.",
            )

        # Check that the types of the operands are the same.
        # This is a (temporary?) restriction over Python for implementation simplicity.
        # This also means that we do not need to support reflected operations
        # (__radd__, __rsub__, etc.) which only exist for operations between different types.
        self.visit(node.right)
        rhs = self.inserter.get_operand()
        self.visit(node.left)
        lhs = self.inserter.get_operand()
        if lhs.type != rhs.type:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Expected the same types for binary operation '{op_name}', "
                f"but got {lhs.type} and {rhs.type}.",
            )

        ir_type = cast(TypeAttribute, lhs.type)
        source_type = self.type_converter.type_registry.get_annotation(ir_type)
        if source_type is None:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"IR type '{ir_type}' is not registered with a source type.",
            )

        method_name = python_AST_operator_to_python_overload[op_name]
        function_name = f"{source_type.__qualname__}.{method_name}"
        op = self.type_converter.function_registry.resolve_operation(
            module_name=source_type.__module__,
            method_name=function_name,
            args=(lhs, rhs),
        )
        if op is not None:
            self.inserter.insert_op(op)
            return

        overload_name = python_AST_operator_to_python_overload[op_name]
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Binary operation '{op_name}' "
            f"is not supported by type '{source_type.__qualname__}' "
            f"which does not overload '{overload_name}'.",
        )

    def visit_Call(self, node: ast.Call) -> None:
        match node.func:
            case ast.Name():
                source_kind = "function"
                source, source_name = self._call_source_function(node)
            case ast.Attribute():
                source_kind = "classmethod"
                source, source_name = self._call_source_classmethod(node)
            case _:
                raise CodeGenerationException(
                    self.file,
                    node.lineno,
                    node.col_offset,
                    "Unsupported call expression.",
                )

        ir_op = self.type_converter.function_registry.get_operation_constructor(source)
        if ir_op is None:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"{source_kind.capitalize()} '{source_name}' is not registered.",
            )

        # Resolve arguments
        assert self.symbol_table is not None
        args: list[symref.FetchOp] = []
        for arg in node.args:
            if not isinstance(arg, ast.Name) or arg.id not in self.symbol_table:
                raise CodeGenerationException(
                    self.file,
                    node.lineno,
                    node.col_offset,
                    f"{source_kind.capitalize()} arguments must be declared variables.",
                )
            args.append(arg_op := symref.FetchOp(arg.id, self.symbol_table[arg.id]))
            self.inserter.insert_op(arg_op)

        # Resolve keyword arguments
        kwargs: dict[str, symref.FetchOp] = {}
        for keyword in node.keywords:
            if (
                not isinstance(keyword.value, ast.Name)
                or keyword.value.id not in self.symbol_table
            ):
                raise CodeGenerationException(
                    self.file,
                    node.lineno,
                    node.col_offset,
                    f"{source_kind.capitalize()} arguments must be declared variables.",
                )
            assert keyword.arg is not None
            kwargs[keyword.arg] = symref.FetchOp(
                keyword.value.id, self.symbol_table[keyword.value.id]
            )
            self.inserter.insert_op(kwargs[keyword.arg])

        self.inserter.insert_op(ir_op(*args, **kwargs))

    # Get called function for a call expression.
    def _call_source_function(self, node: ast.Call) -> tuple[Callable[..., Any], str]:
        assert isinstance(node.func, ast.Name)

        func_name = node.func.id
        func = self.type_converter.globals.get(func_name, None)
        if func is None:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Function '{func_name}' is not defined in scope.",
            )
        return func, func_name

    # Get called classmethod for a call expression.
    def _call_source_classmethod(
        self, node: ast.Call
    ) -> tuple[Callable[..., Any], str]:
        assert isinstance(node.func, ast.Attribute)
        assert isinstance(node.func.value, ast.Name)

        class_name = node.func.value.id
        method_name = node.func.attr
        classmethod_name = f"{class_name}.{method_name}"

        source_class = self.type_converter.globals.get(class_name, None)
        if source_class is None:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Class '{class_name}' is not defined in scope.",
            )
        classmethod_ = getattr(source_class, method_name, None)
        if classmethod_ is None:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Method '{method_name}' is not defined on class '{class_name}'.",
            )
        return classmethod_, classmethod_name

    def visit_Compare(self, node: ast.Compare) -> None:
        # Allow a single comparison only.
        if len(node.comparators) != 1 or len(node.ops) != 1:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Expected a single comparator, but found {len(node.comparators)}.",
            )
        comp = node.comparators[0]
        op_name: str = node.ops[0].__class__.__qualname__

        # Table with mappings of Python AST cmpop to Python method.
        python_AST_cmpop_to_python_overload = {
            "Eq": "__eq__",
            "Gt": "__gt__",
            "GtE": "__ge__",
            "Lt": "__lt__",
            "LtE": "__le__",
            "NotEq": "__ne__",
            "In": "__contains__",
            "NotIn": "__contains__",
        }

        # Table with currently unsupported Python AST cmpops.
        # The "is" and "is not" operators are (currently) not supported,
        # since the frontend does not consider/preserve object identity.
        # Finally, "not in" does not directly correspond to a special method
        # and is instead simply implemented as the negation of __contains__
        # which the current mapping framework cannot handle.
        unsupported_python_AST_cmpop = {"Is", "IsNot", "NotIn"}

        if op_name in unsupported_python_AST_cmpop:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Unsupported comparison operation '{op_name}'.",
            )

        # Check that the types of the operands are the same.
        # This is a (temporary?) restriction over Python for implementation simplicity.
        # This also means that we do not need to consider swapping arguments
        # (__eq__ and __ne__ are their own reflection, __lt__ <-> __gt__  and __le__ <-> __ge__).
        self.visit(comp)
        rhs = self.inserter.get_operand()
        self.visit(node.left)
        lhs = self.inserter.get_operand()
        if lhs.type != rhs.type:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Expected the same types for comparison operator '{op_name}',"
                f" but got {lhs.type} and {rhs.type}.",
            )

        ir_type = cast(TypeAttribute, lhs.type)
        source_type = self.type_converter.type_registry.get_annotation(ir_type)
        if source_type is None:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"IR type '{ir_type}' is not registered with a source type.",
            )

        method_name = python_AST_cmpop_to_python_overload[op_name]
        function_name = f"{source_type.__qualname__}.{method_name}"
        op = self.type_converter.function_registry.resolve_operation(
            module_name=source_type.__module__,
            method_name=function_name,
            args=(lhs, rhs),
        )
        if op is not None:
            self.inserter.insert_op(op)
            return

        python_op = python_AST_cmpop_to_python_overload[op_name]
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Comparison operation '{op_name}' "
            f"is not supported by type '{ir_type.name}' "
            f"which does not overload '{python_op}'.",
        )

    def visit_Expr(self, node: ast.Expr) -> None:
        self.visit(node.value)

    def visit_For(self, node: ast.For) -> None:
        raise NotImplementedError("For loops are currently not supported!")

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        # Set the symbol table.
        if self.symbol_table is not None:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Cannot have an inner function '{node.name}' inside another function.",
            )
        self.symbol_table = dict()

        # Then, convert types in the function signature.
        argument_types: list[Attribute] = []
        for i, arg in enumerate(node.args.args):
            if arg.annotation is None:
                raise CodeGenerationException(
                    self.file,
                    arg.lineno,
                    arg.col_offset,
                    "Function arguments must be type hinted",
                )
            xdsl_type = self.type_converter.type_registry.resolve_attribute(
                ast.unparse(arg.annotation), self.type_converter.globals
            )
            if xdsl_type is None:
                raise CodeGenerationException(
                    self.file,
                    arg.lineno,
                    arg.col_offset,
                    f"Unsupported function argument type: '{ast.unparse(arg.annotation)}'",
                )
            argument_types.append(xdsl_type)

        returns = node.returns
        return_types: list[Attribute] = []
        if not (
            returns is None
            or (isinstance(returns, ast.Constant) and returns.value is None)
        ):
            xdsl_type = self.type_converter.type_registry.resolve_attribute(
                ast.unparse(returns), self.type_converter.globals
            )
            if xdsl_type is None:
                raise CodeGenerationException(
                    self.file,
                    node.lineno,
                    node.col_offset,
                    f"Unsupported function return type: '{ast.unparse(returns)}'",
                )
            return_types.append(xdsl_type)

        # Create a function operation.
        entry_block = Block()
        body_region = Region(entry_block)
        func_op = func.FuncOp.from_region(
            node.name, argument_types, return_types, body_region
        )

        self.inserter.insert_op(func_op)
        self.inserter.set_insertion_point_from_block(entry_block)

        # All arguments are declared using symref.
        for i, arg in enumerate(node.args.args):
            symbol_name = str(arg.arg)
            block_arg = entry_block.insert_arg(argument_types[i], i)
            block_arg.name_hint = symbol_name
            self.symbol_table[symbol_name] = argument_types[i]
            entry_block.add_op(symref.DeclareOp(symbol_name))
            entry_block.add_op(symref.UpdateOp(symbol_name, block_arg))

        # Parse function body.
        for stmt in node.body:
            self.visit(stmt)

        # If function does not end with a return statement to be visited, we
        # must insert a ReturnOp here.
        if not isinstance(node.body[-1], ast.Return):
            self.inserter.insert_op(func.ReturnOp())

        # When function definition is processed, reset the symbol table and set
        # the insertion point.
        self.symbol_table = None
        parent_op = func_op.parent_op()
        assert parent_op is not None
        self.inserter.set_insertion_point_from_op(parent_op)

    def visit_If(self, node: ast.If) -> None:
        # Get the condition.
        self.visit(node.test)
        cond = self.inserter.get_operand()
        cond_block = self.inserter.insertion_point

        def visit_region(stmts: list[ast.stmt]) -> Region:
            region = Region([Block()])
            self.inserter.set_insertion_point_from_region(region)
            for stmt in stmts:
                self.visit(stmt)
            return region

        # Generate code for both branches.
        true_region = visit_region(node.body)
        false_region = visit_region(node.orelse)

        # In our case, if statement never returns a value and therefore we can
        # simply yield nothing. It is the responsibility of subsequent passes to
        # ensure SSA-form of IR and that values are yielded correctly.
        true_region.blocks[-1].add_op(scf.YieldOp())
        false_region.blocks[-1].add_op(scf.YieldOp())
        op = scf.IfOp(cond, [], true_region, false_region)

        # Reset insertion point and insert a new operation.
        self.inserter.set_insertion_point_from_block(cond_block)
        self.inserter.insert_op(op)

    def visit_IfExp(self, node: ast.IfExp) -> None:
        self.visit(node.test)
        cond = self.inserter.get_operand()
        cond_block = self.inserter.insertion_point

        def visit_expr(expr: ast.expr) -> tuple[Attribute, Region]:
            region = Region([Block()])
            self.inserter.set_insertion_point_from_region(region)
            self.visit(expr)
            result = self.inserter.get_operand()
            self.inserter.insert_op(scf.YieldOp(result))
            return result.type, region

        # Generate code for both branches.
        true_type, true_region = visit_expr(node.body)
        false_type, false_region = visit_expr(node.orelse)

        # Check types are the same for this to be a valid if statement.
        if true_type != false_type:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Expected the same types for if expression,"
                f" but got {true_type} and {false_type}.",
            )
        op = scf.IfOp(cond, [true_type], true_region, false_region)

        # Reset insertion point to add scf.if.
        self.inserter.set_insertion_point_from_block(cond_block)
        self.inserter.insert_op(op)

    def visit_Name(self, node: ast.Name) -> None:
        fetch_op = symref.FetchOp(node.id, self.get_symbol(node))
        self.inserter.insert_op(fetch_op)

    def visit_Pass(self, node: ast.Pass) -> None:
        pass

    def visit_Return(self, node: ast.Return) -> None:
        # First of all, we should only be able to return if the statement is directly
        # in the function. Cases like:
        #
        # def foo(cond: i1):
        #   if cond:
        #     return 1
        #   else:
        #     return 0
        #
        # are not allowed at the moment.
        parent_op = self.inserter.insertion_point.parent_op()
        if not isinstance(parent_op, func.FuncOp):
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                "Return statement should be placed only at the end of the "
                "function body.",
            )

        callee = parent_op.sym_name.data
        func_return_types = parent_op.function_type.outputs.data

        value = node.value
        if value is None or (isinstance(value, ast.Constant) and value.value is None):
            # Return nothing, check function signature matches.
            if func_return_types:
                raise CodeGenerationException(
                    self.file,
                    node.lineno,
                    node.col_offset,
                    f"Expected non-zero number of return types in function "
                    f"'{callee}', but got 0.",
                )
            self.inserter.insert_op(func.ReturnOp())
        else:
            # Return some type, check function signature matches as well.
            # TODO: Support multiple return values if we allow multiple assignemnts.
            self.visit(value)
            operands = [self.inserter.get_operand()]

            if not func_return_types:
                raise CodeGenerationException(
                    self.file,
                    node.lineno,
                    node.col_offset,
                    f"Expected no return types in function '{callee}'.",
                )

            for i in range(len(operands)):
                if func_return_types[i] != operands[i].type:
                    raise CodeGenerationException(
                        self.file,
                        node.lineno,
                        node.col_offset,
                        f"Type signature and the type of the return value do "
                        f"not match at position {i}: expected {func_return_types[i]},"
                        f" got {operands[i].type}.",
                    )

            self.inserter.insert_op(func.ReturnOp(*operands))

symbol_table: dict[str, Attribute] | None = field(default=None) class-attribute instance-attribute

Maps local variable names to their xDSL types. A single dictionary is sufficient because inner functions and global variables are not allowed (yet).

type_converter: TypeConverter = type_converter instance-attribute

Used for type conversion during code generation.

file: str | None = file instance-attribute

Path of the file containing the program being processed.

inserter: OpInserter = OpInserter(module.body.block) instance-attribute

Used for inserting newly generated operations to the right block.

__init__(type_converter: TypeConverter, module: builtin.ModuleOp, file: str | None) -> None

Source code in xdsl/frontend/pyast/code_generation.py
59
60
61
62
63
64
65
66
67
68
69
def __init__(
    self,
    type_converter: TypeConverter,
    module: builtin.ModuleOp,
    file: str | None,
) -> None:
    """Create a visitor that inserts generated operations into `module`.

    `module` must have exactly one block in its body; that block becomes the
    initial insertion point.
    """
    self.file = file
    self.type_converter = type_converter

    # Operations are appended to the module's single top-level block.
    top_level_blocks = module.body.blocks
    assert len(top_level_blocks) == 1
    self.inserter = OpInserter(module.body.block)

get_symbol(node: ast.Name) -> Attribute

Source code in xdsl/frontend/pyast/code_generation.py
71
72
73
74
75
76
77
78
79
80
def get_symbol(self, node: ast.Name) -> Attribute:
    """Return the xDSL type recorded for the variable named by `node`.

    Raises `CodeGenerationException` if the name is not in the symbol table.
    """
    known_symbols = self.symbol_table
    assert known_symbols is not None

    name = node.id
    if name in known_symbols:
        return known_symbols[name]

    raise CodeGenerationException(
        self.file,
        node.lineno,
        node.col_offset,
        f"Symbol '{node.id}' is not defined.",
    )

visit(node: ast.AST) -> None

Source code in xdsl/frontend/pyast/code_generation.py
82
83
def visit(self, node: ast.AST) -> None:
    """Dispatch `node` through the base-class `visit`, discarding any return value."""
    super().visit(node)

generic_visit(node: ast.AST) -> None

Source code in xdsl/frontend/pyast/code_generation.py
85
86
87
88
89
90
91
def generic_visit(self, node: ast.AST) -> None:
    """Reject any AST node that has no dedicated `visit_*` method.

    Always raises `CodeGenerationException`.
    """
    # Not every AST node carries position information, so fall back to 0
    # rather than masking the real diagnostic with an AttributeError.
    raise CodeGenerationException(
        self.file,
        getattr(node, "lineno", 0),
        getattr(node, "col_offset", 0),
        f"Unsupported Python AST node {str(node)}",
    )

visit_AnnAssign(node: ast.AnnAssign) -> None

Source code in xdsl/frontend/pyast/code_generation.py
93
94
95
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
    """Accept an annotated assignment without generating any IR.

    TODO: Implement assignment in the next patch.
    """
    return None

visit_Assert(node: ast.Assert) -> None

Source code in xdsl/frontend/pyast/code_generation.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def visit_Assert(self, node: ast.Assert) -> None:
    """Lower `assert test[, "message"]` to a `cf.AssertOp`.

    Raises `CodeGenerationException` if a message is given that is not a
    string constant.
    """
    # Lower the condition first; its value is fetched further down.
    self.visit(node.test)

    message_node = node.msg
    if message_node is None:
        msg = ""
    elif isinstance(message_node, ast.Constant) and isinstance(
        message_node.value, str
    ):
        msg = str(message_node.value)
    else:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            "Expected a string constant for assertion message, found "
            f"'ast.{type(node.msg).__qualname__}'",
        )

    condition = self.inserter.get_operand()
    self.inserter.insert_op(cf.AssertOp(condition, msg))

visit_Assign(node: ast.Assign) -> None

Source code in xdsl/frontend/pyast/code_generation.py
116
117
118
def visit_Assign(self, node: ast.Assign) -> None:
    """Accept an assignment without generating any IR.

    TODO: Implement assignment in the next patch.
    """
    return None

visit_BinOp(node: ast.BinOp) -> None

Source code in xdsl/frontend/pyast/code_generation.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def visit_BinOp(self, node: ast.BinOp) -> None:
    """Lower a binary operation through the operand type's operator overload.

    Both operands must have the same IR type. The source type registered for
    that IR type is looked up, and the operation registered for its matching
    special method (e.g. `__add__`) is inserted.

    Raises:
        CodeGenerationException: if the operator is unknown, the operand types
            differ, the IR type has no registered source type, or the source
            type does not provide the overload.
    """
    operator_name: str = type(node.op).__qualname__

    # Python AST operator class name -> special method implementing it.
    special_method_for = {
        "Add": "__add__",
        "Sub": "__sub__",
        "Mult": "__mul__",
        "Div": "__truediv__",
        "FloorDiv": "__floordiv__",
        "Mod": "__mod__",
        "Pow": "__pow__",
        "LShift": "__lshift__",
        "RShift": "__rshift__",
        "BitOr": "__or__",
        "BitXor": "__xor__",
        "BitAnd": "__and__",
        "MatMult": "__matmul__",
    }

    special_method = special_method_for.get(operator_name)
    if special_method is None:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Unexpected binary operation {operator_name}.",
        )

    # Operands must share a type. This is a (temporary?) restriction over
    # Python for implementation simplicity; it also removes the need for
    # reflected operations (__radd__, __rsub__, etc.), which only exist for
    # operations between different types.
    # Note that the right operand is visited before the left one.
    self.visit(node.right)
    right = self.inserter.get_operand()
    self.visit(node.left)
    left = self.inserter.get_operand()
    if left.type != right.type:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Expected the same types for binary operation '{operator_name}', "
            f"but got {left.type} and {right.type}.",
        )

    ir_type = cast(TypeAttribute, left.type)
    source_type = self.type_converter.type_registry.get_annotation(ir_type)
    if source_type is None:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"IR type '{ir_type}' is not registered with a source type.",
        )

    resolved = self.type_converter.function_registry.resolve_operation(
        module_name=source_type.__module__,
        method_name=f"{source_type.__qualname__}.{special_method}",
        args=(left, right),
    )
    if resolved is None:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Binary operation '{operator_name}' "
            f"is not supported by type '{source_type.__qualname__}' "
            f"which does not overload '{special_method}'.",
        )

    self.inserter.insert_op(resolved)

visit_Call(node: ast.Call) -> None

Source code in xdsl/frontend/pyast/code_generation.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
def visit_Call(self, node: ast.Call) -> None:
    match node.func:
        case ast.Name():
            source_kind = "function"
            source, source_name = self._call_source_function(node)
        case ast.Attribute():
            source_kind = "classmethod"
            source, source_name = self._call_source_classmethod(node)
        case _:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                "Unsupported call expression.",
            )

    ir_op = self.type_converter.function_registry.get_operation_constructor(source)
    if ir_op is None:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"{source_kind.capitalize()} '{source_name}' is not registered.",
        )

    # Resolve arguments
    assert self.symbol_table is not None
    args: list[symref.FetchOp] = []
    for arg in node.args:
        if not isinstance(arg, ast.Name) or arg.id not in self.symbol_table:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"{source_kind.capitalize()} arguments must be declared variables.",
            )
        args.append(arg_op := symref.FetchOp(arg.id, self.symbol_table[arg.id]))
        self.inserter.insert_op(arg_op)

    # Resolve keyword arguments
    kwargs: dict[str, symref.FetchOp] = {}
    for keyword in node.keywords:
        if (
            not isinstance(keyword.value, ast.Name)
            or keyword.value.id not in self.symbol_table
        ):
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"{source_kind.capitalize()} arguments must be declared variables.",
            )
        assert keyword.arg is not None
        kwargs[keyword.arg] = symref.FetchOp(
            keyword.value.id, self.symbol_table[keyword.value.id]
        )
        self.inserter.insert_op(kwargs[keyword.arg])

    self.inserter.insert_op(ir_op(*args, **kwargs))

visit_Compare(node: ast.Compare) -> None

Source code in xdsl/frontend/pyast/code_generation.py
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
def visit_Compare(self, node: ast.Compare) -> None:
    """Lower a single comparison through the operand type's comparison overload.

    Exactly one comparator is allowed and both operands must have the same IR
    type. The source type registered for that IR type is looked up, and the
    operation registered for its matching special method (e.g. `__lt__`) is
    inserted.

    Raises:
        CodeGenerationException: if the comparison is chained, the operator is
            unsupported (`is`, `is not`, `not in`), the operand types differ,
            the IR type has no registered source type, or the source type does
            not provide the overload.
    """
    # Allow a single comparison only.
    if len(node.comparators) != 1 or len(node.ops) != 1:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Expected a single comparator, but found {len(node.comparators)}.",
        )
    comp = node.comparators[0]
    op_name: str = node.ops[0].__class__.__qualname__

    # Table with mappings of Python AST cmpop to Python method.
    python_AST_cmpop_to_python_overload = {
        "Eq": "__eq__",
        "Gt": "__gt__",
        "GtE": "__ge__",
        "Lt": "__lt__",
        "LtE": "__le__",
        "NotEq": "__ne__",
        "In": "__contains__",
        "NotIn": "__contains__",
    }

    # Table with currently unsupported Python AST cmpops.
    # The "is" and "is not" operators are (currently) not supported,
    # since the frontend does not consider/preserve object identity.
    # Finally, "not in" does not directly correspond to a special method
    # and is instead simply implemented as the negation of __contains__
    # which the current mapping framework cannot handle.
    unsupported_python_AST_cmpop = {"Is", "IsNot", "NotIn"}

    if op_name in unsupported_python_AST_cmpop:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Unsupported comparison operation '{op_name}'.",
        )

    # Check that the types of the operands are the same.
    # This is a (temporary?) restriction over Python for implementation simplicity.
    # This also means that the rich comparisons need no reflected fallback
    # (__eq__ and __ne__ are their own reflection, __lt__ <-> __gt__  and __le__ <-> __ge__).
    # Note that the comparator (right operand) is visited before the left operand.
    self.visit(comp)
    rhs = self.inserter.get_operand()
    self.visit(node.left)
    lhs = self.inserter.get_operand()
    if lhs.type != rhs.type:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Expected the same types for comparison operator '{op_name}',"
            f" but got {lhs.type} and {rhs.type}.",
        )

    ir_type = cast(TypeAttribute, lhs.type)
    source_type = self.type_converter.type_registry.get_annotation(ir_type)
    if source_type is None:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"IR type '{ir_type}' is not registered with a source type.",
        )

    method_name = python_AST_cmpop_to_python_overload[op_name]
    function_name = f"{source_type.__qualname__}.{method_name}"
    # `a in b` is evaluated as `type(b).__contains__(b, a)`: the container is
    # the right-hand operand, so membership tests pass their operands swapped.
    call_args = (rhs, lhs) if op_name == "In" else (lhs, rhs)
    op = self.type_converter.function_registry.resolve_operation(
        module_name=source_type.__module__,
        method_name=function_name,
        args=call_args,
    )
    if op is not None:
        self.inserter.insert_op(op)
        return

    raise CodeGenerationException(
        self.file,
        node.lineno,
        node.col_offset,
        f"Comparison operation '{op_name}' "
        f"is not supported by type '{ir_type.name}' "
        f"which does not overload '{method_name}'.",
    )

visit_Expr(node: ast.Expr) -> None

Source code in xdsl/frontend/pyast/code_generation.py
388
389
def visit_Expr(self, node: ast.Expr) -> None:
    """Visit the expression wrapped by an expression statement."""
    wrapped = node.value
    self.visit(wrapped)

visit_For(node: ast.For) -> None

Source code in xdsl/frontend/pyast/code_generation.py
391
392
def visit_For(self, node: ast.For) -> None:
    """Reject `for` loops.

    Raises:
        NotImplementedError: always.
    """
    reason = "For loops are currently not supported!"
    raise NotImplementedError(reason)

visit_FunctionDef(node: ast.FunctionDef) -> None

Source code in xdsl/frontend/pyast/code_generation.py
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
    """Lower a function definition to a `func.FuncOp` and generate its body.

    Argument and return annotations are converted through the type registry.
    Each argument becomes a block argument of the entry block that is declared
    and initialised through `symref`. A `func.ReturnOp` is appended when the
    body does not end with a `return` statement.

    Raises:
        CodeGenerationException: if the function is nested inside another one,
            an argument lacks a type hint, or an argument/return annotation
            cannot be converted.
    """
    # A symbol table that is already set means another function is in progress.
    if self.symbol_table is not None:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Cannot have an inner function '{node.name}' inside another function.",
        )
    self.symbol_table = {}

    # Convert the annotated type of every argument in the signature.
    argument_types: list[Attribute] = []
    for parameter in node.args.args:
        annotation = parameter.annotation
        if annotation is None:
            raise CodeGenerationException(
                self.file,
                parameter.lineno,
                parameter.col_offset,
                "Function arguments must be type hinted",
            )
        converted = self.type_converter.type_registry.resolve_attribute(
            ast.unparse(annotation), self.type_converter.globals
        )
        if converted is None:
            raise CodeGenerationException(
                self.file,
                parameter.lineno,
                parameter.col_offset,
                f"Unsupported function argument type: '{ast.unparse(annotation)}'",
            )
        argument_types.append(converted)

    # A missing annotation and an explicit `-> None` both mean "no results".
    returns = node.returns
    return_types: list[Attribute] = []
    if returns is not None and not (
        isinstance(returns, ast.Constant) and returns.value is None
    ):
        converted = self.type_converter.type_registry.resolve_attribute(
            ast.unparse(returns), self.type_converter.globals
        )
        if converted is None:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Unsupported function return type: '{ast.unparse(returns)}'",
            )
        return_types.append(converted)

    # Create the function operation and continue inserting into its entry block.
    entry_block = Block()
    func_op = func.FuncOp.from_region(
        node.name, argument_types, return_types, Region(entry_block)
    )
    self.inserter.insert_op(func_op)
    self.inserter.set_insertion_point_from_block(entry_block)

    # All arguments are declared using symref.
    for index, (parameter, ir_type) in enumerate(
        zip(node.args.args, argument_types)
    ):
        symbol_name = str(parameter.arg)
        block_arg = entry_block.insert_arg(ir_type, index)
        block_arg.name_hint = symbol_name
        self.symbol_table[symbol_name] = ir_type
        entry_block.add_op(symref.DeclareOp(symbol_name))
        entry_block.add_op(symref.UpdateOp(symbol_name, block_arg))

    # Generate the function body.
    for statement in node.body:
        self.visit(statement)

    # Visiting a trailing `return` statement emits the ReturnOp itself;
    # otherwise one has to be inserted here.
    if not isinstance(node.body[-1], ast.Return):
        self.inserter.insert_op(func.ReturnOp())

    # The function is done: clear the symbol table and move the insertion
    # point back to the operation enclosing the function.
    self.symbol_table = None
    enclosing_op = func_op.parent_op()
    assert enclosing_op is not None
    self.inserter.set_insertion_point_from_op(enclosing_op)

visit_If(node: ast.If) -> None

Source code in xdsl/frontend/pyast/code_generation.py
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
def visit_If(self, node: ast.If) -> None:
    """Lower an `if` statement to a result-less `scf.IfOp`.

    The condition is visited first, then the `if` body and the `else` body are
    each generated into their own single-block region.
    """
    # Evaluate the condition and remember where the scf.if has to be inserted.
    self.visit(node.test)
    condition = self.inserter.get_operand()
    resume_point = self.inserter.insertion_point

    def build_branch(statements: list[ast.stmt]) -> Region:
        branch = Region([Block()])
        self.inserter.set_insertion_point_from_region(branch)
        for statement in statements:
            self.visit(statement)
        return branch

    # Generate the true branch first, then the false branch.
    branches = [build_branch(body) for body in (node.body, node.orelse)]

    # The if statement never returns a value here, so each branch simply
    # yields nothing. It is the responsibility of subsequent passes to ensure
    # SSA-form of IR and that values are yielded correctly.
    for branch in branches:
        branch.blocks[-1].add_op(scf.YieldOp())
    if_op = scf.IfOp(condition, [], *branches)

    # Go back to where the condition was generated and insert the scf.if.
    self.inserter.set_insertion_point_from_block(resume_point)
    self.inserter.insert_op(if_op)

visit_IfExp(node: ast.IfExp) -> None

Source code in xdsl/frontend/pyast/code_generation.py
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
def visit_IfExp(self, node: ast.IfExp) -> None:
    """Lower a conditional expression to an `scf.IfOp` with a single result.

    Each alternative is generated into its own single-block region that yields
    the alternative's value; both values must have the same type.

    Raises:
        CodeGenerationException: if the two alternatives have different types.
    """
    # Evaluate the condition and remember where the scf.if has to be inserted.
    self.visit(node.test)
    condition = self.inserter.get_operand()
    resume_point = self.inserter.insertion_point

    def build_branch(expression: ast.expr) -> tuple[Attribute, Region]:
        branch = Region([Block()])
        self.inserter.set_insertion_point_from_region(branch)
        self.visit(expression)
        yielded = self.inserter.get_operand()
        self.inserter.insert_op(scf.YieldOp(yielded))
        return yielded.type, branch

    # Generate the true alternative first, then the false alternative.
    then_type, then_region = build_branch(node.body)
    else_type, else_region = build_branch(node.orelse)

    # Both alternatives must agree on the type of the resulting value.
    if then_type != else_type:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Expected the same types for if expression,"
            f" but got {then_type} and {else_type}.",
        )
    if_op = scf.IfOp(condition, [then_type], then_region, else_region)

    # Go back to where the condition was generated and insert the scf.if.
    self.inserter.set_insertion_point_from_block(resume_point)
    self.inserter.insert_op(if_op)

visit_Name(node: ast.Name) -> None

Source code in xdsl/frontend/pyast/code_generation.py
540
541
542
def visit_Name(self, node: ast.Name) -> None:
    """Insert a `symref.FetchOp` for the name referenced by `node`."""
    # Presumably `get_symbol` returns the type recorded for the symbol.
    fetch = symref.FetchOp(node.id, self.get_symbol(node))
    self.inserter.insert_op(fetch)

visit_Pass(node: ast.Pass) -> None

Source code in xdsl/frontend/pyast/code_generation.py
544
545
def visit_Pass(self, node: ast.Pass) -> None:
    """Accept a `pass` statement; nothing is generated for it."""
    return None

visit_Return(node: ast.Return) -> None

Source code in xdsl/frontend/pyast/code_generation.py
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
def visit_Return(self, node: ast.Return) -> None:
    """Lower a `return` statement to a `func.ReturnOp`.

    The statement must sit directly in a function body, and the returned value
    (or its absence) must agree with the function's declared return types.

    Raises:
        CodeGenerationException: if the statement is not directly inside a
            function, or the returned value does not match the signature.
    """
    # A return is only accepted when the statement is directly in the
    # function. Cases like:
    #
    # def foo(cond: i1):
    #   if cond:
    #     return 1
    #   else:
    #     return 0
    #
    # are not allowed at the moment.
    enclosing_op = self.inserter.insertion_point.parent_op()
    if not isinstance(enclosing_op, func.FuncOp):
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            "Return statement should be placed only at the end of the "
            "function body.",
        )

    callee = enclosing_op.sym_name.data
    declared_types = enclosing_op.function_type.outputs.data

    returned = node.value
    returns_nothing = returned is None or (
        isinstance(returned, ast.Constant) and returned.value is None
    )
    if returns_nothing:
        # Nothing is returned, so the signature must not declare any result.
        if declared_types:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Expected non-zero number of return types in function "
                f"'{callee}', but got 0.",
            )
        self.inserter.insert_op(func.ReturnOp())
        return

    # A value is returned: generate it, then check it against the signature.
    # TODO: Support multiple return values if we allow multiple assignments.
    self.visit(returned)
    operands = [self.inserter.get_operand()]

    if not declared_types:
        raise CodeGenerationException(
            self.file,
            node.lineno,
            node.col_offset,
            f"Expected no return types in function '{callee}'.",
        )

    for position, operand in enumerate(operands):
        expected_type = declared_types[position]
        if expected_type != operand.type:
            raise CodeGenerationException(
                self.file,
                node.lineno,
                node.col_offset,
                f"Type signature and the type of the return value do "
                f"not match at position {position}: expected {expected_type},"
                f" got {operand.type}.",
            )

    self.inserter.insert_op(func.ReturnOp(*operands))