Skip to content

Conversion

conversion

PDL to PDL_interp Transformation

ConvertPDLToPDLInterpPass dataclass

Bases: ModulePass

Pass to convert PDL operations to PDL interpreter operations. This is a somewhat faithful port of the implementation in MLIR, but it may not generate the same exact results.

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
@dataclass(frozen=True)
class ConvertPDLToPDLInterpPass(ModulePass):
    """
    Pass to convert PDL operations to PDL interpreter operations.
    This is a somewhat faithful port of the implementation in MLIR, but it may not generate the same exact results.
    """

    name = "convert-pdl-to-pdl-interp"

    optimize_for_eqsat: bool = True

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        # Gather every pdl.pattern found at the top level of the module.
        all_patterns = [
            child for child in op.body.ops if isinstance(child, pdl.PatternOp)
        ]

        # Sub-module that will receive the generated rewriter functions.
        rewriters = ModuleOp([], sym_name=StringAttr("rewriters"))

        # The matcher entry point takes a single pdl.operation (the root).
        entry_func = pdl_interp.FuncOp("matcher", ((pdl.OperationType(),), ()))
        MatcherGenerator(entry_func, rewriters, self.optimize_for_eqsat).lower(
            all_patterns
        )
        op.body.block.add_op(entry_func)

        # The source patterns are fully lowered now: erase them and append
        # the rewriter module in their place.
        eraser = Rewriter()
        for pat in all_patterns:
            eraser.erase_op(pat)
        op.body.block.add_op(rewriters)

name = 'convert-pdl-to-pdl-interp' class-attribute instance-attribute

optimize_for_eqsat: bool = True class-attribute instance-attribute

__init__(optimize_for_eqsat: bool = True) -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Lower every pdl.pattern in the module to a pdl_interp matcher plus rewriters."""
    collected = [
        body_op for body_op in op.body.ops if isinstance(body_op, pdl.PatternOp)
    ]

    # Container module for the generated rewriter functions.
    rewriters = ModuleOp([], sym_name=StringAttr("rewriters"))

    # Build the matcher entry function (single pdl.operation argument).
    func = pdl_interp.FuncOp("matcher", ((pdl.OperationType(),), ()))
    MatcherGenerator(func, rewriters, self.optimize_for_eqsat).lower(collected)
    op.body.block.add_op(func)

    # Drop the now-lowered pattern ops, then attach the rewriter module.
    eraser = Rewriter()
    for pat in collected:
        eraser.erase_op(pat)
    op.body.block.add_op(rewriters)

MatcherNode dataclass

Bases: ABC

Base class for matcher tree nodes

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
108
109
110
111
112
113
114
@dataclass
class MatcherNode(ABC):
    """
    Base class for matcher tree nodes.

    Each node pairs a position in the matched IR with the predicate question
    asked at that position; `failure_node` is the continuation taken when the
    check at this node does not hold.
    """

    # Position in the matched IR this node inspects (may be None).
    position: Position | None = None
    # Predicate question evaluated at `position` (may be None).
    question: Question | None = None
    # Continuation when the check fails (forward reference to this class).
    failure_node: Optional["MatcherNode"] = None

position: Position | None = None class-attribute instance-attribute

question: Question | None = None class-attribute instance-attribute

failure_node: Optional[MatcherNode] = None class-attribute instance-attribute

__init__(position: Position | None = None, question: Question | None = None, failure_node: Optional[MatcherNode] = None) -> None

BoolNode dataclass

Bases: MatcherNode

Boolean predicate node

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
117
118
119
120
121
122
123
124
@dataclass(kw_only=True)
class BoolNode(MatcherNode):
    """
    Boolean predicate node.

    Branches to `success_node` when the question yields `answer`, and to
    `failure_node` otherwise.
    """

    # Continuation when the predicate holds.
    success_node: MatcherNode | None = None
    # Continuation when the predicate fails (re-declared from the base class).
    failure_node: MatcherNode | None = None

    # Expected answer for this node's question; required, keyword-only
    # (kw_only=True lets this non-defaulted field follow defaulted ones).
    answer: Answer

success_node: MatcherNode | None = None class-attribute instance-attribute

failure_node: MatcherNode | None = None class-attribute instance-attribute

answer: Answer instance-attribute

__init__(position: Position | None = None, question: Question | None = None, *, failure_node: MatcherNode | None = None, success_node: MatcherNode | None = None, answer: Answer) -> None

SwitchNode dataclass

Bases: MatcherNode

Multi-way switch node

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
127
128
129
130
131
@dataclass
class SwitchNode(MatcherNode):
    """
    Multi-way switch node.

    Maps each possible answer of this node's question to the child node that
    continues matching; a None child marks a not-yet-populated branch.
    """

    # `dict` itself is the idiomatic default factory — no lambda needed.
    children: dict[Answer, MatcherNode | None] = field(default_factory=dict)

children: dict[Answer, MatcherNode | None] = field(default_factory=(lambda: {})) class-attribute instance-attribute

__init__(position: Position | None = None, question: Question | None = None, failure_node: Optional[MatcherNode] = None, children: dict[Answer, MatcherNode | None] = (lambda: {})()) -> None

SuccessNode dataclass

Bases: MatcherNode

Successful pattern match

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
134
135
136
137
138
139
@dataclass(kw_only=True)
class SuccessNode(MatcherNode):
    """
    Successful pattern match.

    Terminal node recording which pattern matched and, when known, the root
    value the match was anchored on.
    """

    pattern: pdl.PatternOp  # The pdl.pattern that matched (required, keyword-only).
    root: SSAValue | None = None  # Root value the match was anchored on, if any.

pattern: pdl.PatternOp instance-attribute

root: SSAValue | None = None class-attribute instance-attribute

__init__(position: Position | None = None, question: Question | None = None, failure_node: Optional[MatcherNode] = None, *, pattern: pdl.PatternOp, root: SSAValue | None = None) -> None

ExitNode dataclass

Bases: MatcherNode

Exit/failure node

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
142
143
144
145
146
@dataclass
class ExitNode(MatcherNode):
    """Exit/failure node; carries no data beyond the base class fields."""

__init__(position: Position | None = None, question: Question | None = None, failure_node: Optional[MatcherNode] = None) -> None

PatternAnalyzer

Analyzes PDL patterns and extracts predicates

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
class PatternAnalyzer:
    """
    Analyzes PDL patterns and extracts predicates.

    The predicates produced here describe, positionally, the checks the
    generated matcher must perform (operation names, operand/result counts,
    non-null checks, equality constraints, native constraints, ...).
    """

    def detect_roots(self, pattern: pdl.PatternOp) -> list[OpResult[pdl.OperationType]]:
        """
        Detect root operations in a pattern.

        A root is a `pdl.OperationOp` whose result is not consumed (via
        `pdl.result`/`pdl.results`) by another operation in the pattern —
        except that the rewrite's explicit root is always kept as a root.
        """
        # Operations whose results feed another pdl.operation's operands.
        used = {
            operand.owner.parent_
            for operation_op in pattern.body.ops
            if isinstance(operation_op, pdl.OperationOp)
            for operand in operation_op.operand_values
            if isinstance(operand.owner, pdl.ResultOp | pdl.ResultsOp)
        }

        # The block terminator is the rewrite; an explicitly named rewrite
        # root stays a root even when its results are consumed.
        rewriter = pattern.body.block.last_op
        assert isinstance(rewriter, pdl.RewriteOp)
        if rewriter.root is not None:
            if rewriter.root in used:
                used.remove(rewriter.root)

        roots = [
            op.op
            for op in pattern.body.ops
            if isinstance(op, pdl.OperationOp) and op.op not in used
        ]
        return roots

    def extract_tree_predicates(
        self,
        value: SSAValue,
        position: Position,
        inputs: dict[SSAValue, Position],
        ignore_operand: int | None = None,
    ) -> list[PositionalPredicate]:
        """
        Extract predicates by walking the operation tree.

        `inputs` maps already-visited SSA values to the position at which
        they were first seen; it is mutated as the walk proceeds.
        `ignore_operand` skips one operand index of the root operation
        (useful when that operand is handled separately by the caller).
        """
        predicates: list[PositionalPredicate] = []

        # Check if this value has been visited before
        existing_pos = inputs.get(value)
        if existing_pos is not None:
            # If this is an input value that has been visited in the tree,
            # add a constraint to ensure both instances refer to the same value
            defining_op = value.owner
            if isinstance(
                defining_op,
                pdl.AttributeOp
                | pdl.OperandOp
                | pdl.OperandsOp
                | pdl.OperationOp
                | pdl.TypeOp
                | pdl.TypesOp,
            ):
                # Order positions by depth (deeper position gets the equality predicate)
                if position.get_operation_depth() > existing_pos.get_operation_depth():
                    deeper_pos, shallower_pos = position, existing_pos
                else:
                    deeper_pos, shallower_pos = existing_pos, position

                equal_pred = Predicate.get_equal_to(shallower_pos)
                predicates.append(
                    PositionalPredicate(
                        q=equal_pred.q, a=equal_pred.a, position=deeper_pos
                    )
                )
            return predicates

        # First visit: record where this value was reached.
        inputs[value] = position

        # Dispatch based on position type (not value type!)
        match position:
            case AttributePosition():
                assert isinstance(value, OpResult)
                predicates.extend(
                    self._extract_attribute_predicates(value.owner, position, inputs)
                )
            case OperationPosition():
                assert isinstance(value, OpResult)
                predicates.extend(
                    self._extract_operation_predicates(
                        value.owner, position, inputs, ignore_operand
                    )
                )
            case TypePosition():
                assert isinstance(value, OpResult)
                predicates.extend(
                    self._extract_type_predicates(value.owner, position, inputs)
                )
            case OperandPosition() | OperandGroupPosition():
                assert isinstance(value, SSAValue)
                predicates.extend(
                    self._extract_operand_tree_predicates(value, position, inputs)
                )
            case _:
                raise TypeError(f"Unexpected position kind: {type(position)}")

        return predicates

    def _get_num_non_range_values(self, values: Sequence[SSAValue]) -> int:
        """Returns the number of non-range elements within values"""
        return sum(1 for v in values if not isinstance(v.type, pdl.RangeType))

    def _extract_attribute_predicates(
        self,
        attr_op: Operation,
        attr_pos: AttributePosition,
        inputs: dict[SSAValue, Position],
    ) -> list[PositionalPredicate]:
        """Extract predicates for an attribute"""
        predicates: list[PositionalPredicate] = []

        # An attribute at this position must exist.
        is_not_null = Predicate.get_is_not_null()
        predicates.append(
            PositionalPredicate(q=is_not_null.q, a=is_not_null.a, position=attr_pos)
        )

        if isinstance(attr_op, pdl.AttributeOp):
            # A typed attribute constrains its type subtree; otherwise a
            # constant value becomes a direct attribute constraint.
            if attr_op.value_type:
                type_pos = attr_pos.get_type()
                predicates.extend(
                    self.extract_tree_predicates(attr_op.value_type, type_pos, inputs)
                )

            elif attr_op.value:
                attr_constraint = Predicate.get_attribute_constraint(attr_op.value)
                predicates.append(
                    PositionalPredicate(
                        q=attr_constraint.q, a=attr_constraint.a, position=attr_pos
                    )
                )

        return predicates

    def _extract_operation_predicates(
        self,
        op_op: Operation,
        op_pos: OperationPosition,
        inputs: dict[SSAValue, Position],
        ignore_operand: int | None = None,
    ) -> list[PositionalPredicate]:
        """
        Extract predicates for an operation: name, operand/result counts,
        attributes, and (recursively) its operand and result subtrees.
        """
        predicates: list[PositionalPredicate] = []

        if not isinstance(op_op, pdl.OperationOp):
            return predicates

        # The root operation is the matcher's entry value and is known to
        # exist; only non-root operations need a not-null check.
        if not op_pos.is_root():
            is_not_null = Predicate.get_is_not_null()
            predicates.append(
                PositionalPredicate(q=is_not_null.q, a=is_not_null.a, position=op_pos)
            )

        # Operation name check
        if op_op.opName:
            op_name = op_op.opName.data
            op_name_pred = Predicate.get_operation_name(op_name)
            predicates.append(
                PositionalPredicate(q=op_name_pred.q, a=op_name_pred.a, position=op_pos)
            )

        operands = op_op.operand_values
        min_operands = self._get_num_non_range_values(operands)
        if min_operands != len(operands):
            # Has variadic operands - check minimum
            if min_operands > 0:
                operand_count_pred = Predicate.get_operand_count_at_least(min_operands)
                predicates.append(
                    PositionalPredicate(
                        q=operand_count_pred.q, a=operand_count_pred.a, position=op_pos
                    )
                )
        else:
            # All non-variadic - check exact count
            operand_count_pred = Predicate.get_operand_count(min_operands)
            predicates.append(
                PositionalPredicate(
                    q=operand_count_pred.q, a=operand_count_pred.a, position=op_pos
                )
            )

        types = op_op.type_values
        min_results = self._get_num_non_range_values(types)
        if min_results == len(types):
            # All non-variadic - check exact count
            result_count_pred = Predicate.get_result_count(len(types))
            predicates.append(
                PositionalPredicate(
                    q=result_count_pred.q, a=result_count_pred.a, position=op_pos
                )
            )
        elif min_results > 0:
            # Has variadic results - check minimum
            result_count_pred = Predicate.get_result_count_at_least(min_results)
            predicates.append(
                PositionalPredicate(
                    q=result_count_pred.q, a=result_count_pred.a, position=op_pos
                )
            )

        # Process attributes
        for attr_name, attr in zip(
            op_op.attributeValueNames, op_op.attribute_values, strict=True
        ):
            attr_pos = op_pos.get_attribute(attr_name.data)
            predicates.extend(self.extract_tree_predicates(attr, attr_pos, inputs))

        if len(operands) == 1 and isinstance(operands[0].type, pdl.RangeType):
            # Special case: single variadic operand represents all operands
            if op_pos.is_root() or op_pos.is_operand_defining_op():
                all_operands_pos = op_pos.get_all_operands()
                predicates.extend(
                    self.extract_tree_predicates(operands[0], all_operands_pos, inputs)
                )
        else:
            # Process individual operands
            found_variable_length = False
            for i, operand in enumerate(operands):
                is_variadic = isinstance(operand.type, pdl.RangeType)
                found_variable_length = found_variable_length or is_variadic

                if ignore_operand is not None and i == ignore_operand:
                    continue

                # Switch to group-based positioning after first variadic
                if found_variable_length:
                    operand_pos = op_pos.get_operand_group(i, is_variadic)
                else:
                    operand_pos = op_pos.get_operand(i)

                predicates.extend(
                    self.extract_tree_predicates(operand, operand_pos, inputs)
                )

        if len(types) == 1 and isinstance(types[0].type, pdl.RangeType):
            # Single variadic result represents all results
            all_results_pos = op_pos.get_all_results()
            type_pos = all_results_pos.get_type()
            predicates.extend(self.extract_tree_predicates(types[0], type_pos, inputs))
        else:
            # Process individual results
            found_variable_length = False
            for i, type_value in enumerate(types):
                is_variadic = isinstance(type_value.type, pdl.RangeType)
                found_variable_length = found_variable_length or is_variadic

                # Switch to group-based positioning after first variadic
                if found_variable_length:
                    result_pos = op_pos.get_result_group(i, is_variadic)
                else:
                    result_pos = op_pos.get_result(i)

                # Add not-null check for each result
                is_not_null = Predicate.get_is_not_null()
                predicates.append(
                    PositionalPredicate(
                        q=is_not_null.q, a=is_not_null.a, position=result_pos
                    )
                )

                # Process the result type
                type_pos = result_pos.get_type()
                predicates.extend(
                    self.extract_tree_predicates(type_value, type_pos, inputs)
                )

        return predicates

    def _extract_operand_tree_predicates(
        self,
        operand_value: SSAValue,
        operand_pos: OperandPosition | OperandGroupPosition,
        inputs: dict[SSAValue, Position],
    ) -> list[PositionalPredicate]:
        """Extract predicates for an operand or operand group"""
        predicates: list[PositionalPredicate] = []

        # Get the defining operation
        defining_op = operand_value.owner
        is_variadic = isinstance(operand_value.type, pdl.RangeType)

        match defining_op:
            case pdl.OperandOp() | pdl.OperandsOp():
                # A single operand, or an indexed operand group, must exist;
                # an unindexed pdl.operands (whole group) needs no such check.
                match defining_op:
                    case pdl.OperandOp():
                        is_not_null = Predicate.get_is_not_null()
                        predicates.append(
                            PositionalPredicate(
                                q=is_not_null.q, a=is_not_null.a, position=operand_pos
                            )
                        )
                    case pdl.OperandsOp() if (
                        isinstance(operand_pos, OperandGroupPosition)
                        and operand_pos.group_number is not None
                    ):
                        is_not_null = Predicate.get_is_not_null()
                        predicates.append(
                            PositionalPredicate(
                                q=is_not_null.q, a=is_not_null.a, position=operand_pos
                            )
                        )
                    case _:
                        pass

                # A declared value type constrains the operand's type subtree.
                if defining_op.value_type:
                    type_pos = operand_pos.get_type()
                    predicates.extend(
                        self.extract_tree_predicates(
                            defining_op.value_type, type_pos, inputs
                        )
                    )

            case pdl.ResultOp() | pdl.ResultsOp():
                # The operand is a result of another matched operation.
                index_attr = defining_op.index
                index = index_attr.value.data if index_attr is not None else None

                # An explicitly indexed result must exist.
                if index is not None:
                    is_not_null = Predicate.get_is_not_null()
                    predicates.append(
                        PositionalPredicate(
                            q=is_not_null.q, a=is_not_null.a, position=operand_pos
                        )
                    )

                # Get the parent operation position
                parent_op = defining_op.parent_
                defining_op_pos = operand_pos.get_defining_op()

                # Parent operation should not be null
                is_not_null = Predicate.get_is_not_null()
                predicates.append(
                    PositionalPredicate(
                        q=is_not_null.q, a=is_not_null.a, position=defining_op_pos
                    )
                )

                match defining_op:
                    case pdl.ResultOp():
                        result_pos = defining_op_pos.get_result(
                            index if index is not None else 0
                        )
                    case pdl.ResultsOp():
                        # index may be None here (whole result group);
                        # get_result_group accepts a None index.
                        result_pos = defining_op_pos.get_result_group(
                            index, is_variadic
                        )

                # Tie the operand position to the producing result position.
                equal_to = Predicate.get_equal_to(operand_pos)
                predicates.append(
                    PositionalPredicate(q=equal_to.q, a=equal_to.a, position=result_pos)
                )

                # Recursively process the parent operation
                predicates.extend(
                    self.extract_tree_predicates(parent_op, defining_op_pos, inputs)
                )
            case _:
                pass

        return predicates

    def _extract_type_predicates(
        self,
        type_op: Operation,
        type_pos: TypePosition,
        inputs: dict[SSAValue, Position],
    ) -> list[PositionalPredicate]:
        """Extract predicates for a type"""
        predicates: list[PositionalPredicate] = []

        # Only constant types/type-lists produce a constraint; unconstrained
        # pdl.type / pdl.types add nothing here.
        match type_op:
            case pdl.TypeOp(constantType=const_type) if const_type:
                type_constraint = Predicate.get_type_constraint(const_type)
                predicates.append(
                    PositionalPredicate(
                        q=type_constraint.q, a=type_constraint.a, position=type_pos
                    )
                )
            case pdl.TypesOp(constantTypes=const_types) if const_types:
                type_constraint = Predicate.get_type_constraint(const_types)
                predicates.append(
                    PositionalPredicate(
                        q=type_constraint.q, a=type_constraint.a, position=type_pos
                    )
                )
            case _:
                pass

        return predicates

    def extract_non_tree_predicates(
        self,
        pattern: pdl.PatternOp,
        inputs: dict[SSAValue, Position],
    ) -> list[PositionalPredicate]:
        """
        Extract predicates that cannot be determined via tree walking.

        Handles constant attributes/types not reached by the tree walk,
        native constraint applications, and result accesses whose parent
        operation was visited but whose results were not.
        """
        predicates: list[PositionalPredicate] = []

        for op in pattern.body.ops:
            match op:
                case pdl.AttributeOp():
                    if op.output not in inputs:
                        if op.value:
                            # Create literal position for constant attribute
                            attr_pos = AttributeLiteralPosition(
                                value=op.value, parent=None
                            )
                            inputs[op.output] = attr_pos

                case pdl.ApplyNativeConstraintOp():
                    # Collect all argument positions
                    arg_positions = tuple(inputs.get(arg) for arg in op.args)
                    for pos in arg_positions:
                        assert pos is not None
                    arg_positions = cast(tuple[Position, ...], arg_positions)

                    # Find the furthest position (deepest)
                    furthest_pos = max(
                        arg_positions, key=lambda p: p.get_operation_depth() if p else 0
                    )

                    # Create the constraint predicate
                    result_types = tuple(r.type for r in op.res)
                    is_negated = bool(op.is_negated.value.data)
                    constraint_pred = Predicate.get_constraint(
                        op.constraint_name.data, arg_positions, result_types, is_negated
                    )

                    # Register positions for constraint results
                    for i, result in enumerate(op.results):
                        assert isinstance(constraint_pred.q, ConstraintQuestion)
                        constraint_pos = ConstraintPosition.get_constraint(
                            constraint_pred.q, i
                        )
                        existing = inputs.get(result)
                        if existing:
                            # Add equality constraint if result already has a position
                            deeper, shallower = (
                                (constraint_pos, existing)
                                if furthest_pos.get_operation_depth()
                                > existing.get_operation_depth()
                                else (existing, constraint_pos)
                            )
                            eq_pred = Predicate.get_equal_to(shallower)
                            predicates.append(
                                PositionalPredicate(
                                    q=eq_pred.q, a=eq_pred.a, position=deeper
                                )
                            )
                        else:
                            inputs[result] = constraint_pos

                    # The constraint itself is anchored at the deepest
                    # argument position.
                    predicates.append(
                        PositionalPredicate(
                            q=constraint_pred.q,
                            a=constraint_pred.a,
                            position=furthest_pos,
                        )
                    )

                case pdl.ResultOp():
                    # Ensure result exists
                    if op.val not in inputs:
                        assert isinstance(op.parent_.owner, pdl.OperationOp)
                        parent_pos = inputs.get(op.parent_.owner.op)
                        if parent_pos and isinstance(parent_pos, OperationPosition):
                            result_pos = parent_pos.get_result(op.index.value.data)
                            inputs[op.val] = result_pos
                            is_not_null = Predicate.get_is_not_null()
                            predicates.append(
                                PositionalPredicate(
                                    q=is_not_null.q,
                                    a=is_not_null.a,
                                    position=result_pos,
                                )
                            )

                case pdl.ResultsOp():
                    # Handle result groups
                    if op.val not in inputs:
                        assert isinstance(op.parent_.owner, pdl.OperationOp)
                        parent_pos = inputs.get(op.parent_.owner.op)
                        if parent_pos and isinstance(parent_pos, OperationPosition):
                            is_variadic = isinstance(op.val.type, pdl.RangeType)
                            index = op.index.value.data if op.index else None
                            result_pos = parent_pos.get_result_group(index, is_variadic)
                            inputs[op.val] = result_pos
                            # Only an indexed group warrants a not-null check.
                            if index is not None:
                                is_not_null = Predicate.get_is_not_null()
                                predicates.append(
                                    PositionalPredicate(
                                        q=is_not_null.q,
                                        a=is_not_null.a,
                                        position=result_pos,
                                    )
                                )

                case pdl.TypeOp():
                    # Handle constant types
                    if op.result not in inputs and op.constantType:
                        type_pos = TypeLiteralPosition.get_type_literal(
                            value=op.constantType
                        )
                        inputs[op.result] = type_pos

                case pdl.TypesOp():
                    # Handle constant type arrays
                    if op.result not in inputs and op.constantTypes:
                        type_pos = TypeLiteralPosition.get_type_literal(
                            value=op.constantTypes
                        )
                        inputs[op.result] = type_pos

                case _:
                    pass

        return predicates

detect_roots(pattern: pdl.PatternOp) -> list[OpResult[pdl.OperationType]]

Detect root operations in a pattern

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def detect_roots(self, pattern: pdl.PatternOp) -> list[OpResult[pdl.OperationType]]:
    """Detect root operations in a pattern"""
    # An operation is disqualified as a root when another pdl.operation
    # consumes one of its results through pdl.result / pdl.results.
    consumed = set()
    for operation_op in pattern.body.ops:
        if not isinstance(operation_op, pdl.OperationOp):
            continue
        for operand in operation_op.operand_values:
            if isinstance(operand.owner, pdl.ResultOp | pdl.ResultsOp):
                consumed.add(operand.owner.parent_)

    # The rewrite's explicit root stays a root even when it is consumed.
    terminator = pattern.body.block.last_op
    assert isinstance(terminator, pdl.RewriteOp)
    if terminator.root is not None:
        consumed.discard(terminator.root)

    return [
        candidate.op
        for candidate in pattern.body.ops
        if isinstance(candidate, pdl.OperationOp) and candidate.op not in consumed
    ]

extract_tree_predicates(value: SSAValue, position: Position, inputs: dict[SSAValue, Position], ignore_operand: int | None = None) -> list[PositionalPredicate]

Extract predicates by walking the operation tree

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def extract_tree_predicates(
    self,
    value: SSAValue,
    position: Position,
    inputs: dict[SSAValue, Position],
    ignore_operand: int | None = None,
) -> list[PositionalPredicate]:
    """Extract predicates by walking the operation tree.

    Recursively walks the pattern from ``value`` (located at ``position``)
    and produces the positional predicates needed to match it. ``inputs``
    is mutated: it maps each visited SSA value to the position it was
    first seen at, and is used to detect re-visits. ``ignore_operand``
    is forwarded to operation extraction so one operand edge can be
    skipped (presumably the edge we arrived through — TODO confirm
    against the _extract_operation_predicates implementation, which is
    outside this view).
    """
    predicates: list[PositionalPredicate] = []

    # Check if this value has been visited before
    existing_pos = inputs.get(value)
    if existing_pos is not None:
        # If this is an input value that has been visited in the tree,
        # add a constraint to ensure both instances refer to the same value
        defining_op = value.owner
        if isinstance(
            defining_op,
            pdl.AttributeOp
            | pdl.OperandOp
            | pdl.OperandsOp
            | pdl.OperationOp
            | pdl.TypeOp
            | pdl.TypesOp,
        ):
            # Order positions by depth (deeper position gets the equality predicate)
            if position.get_operation_depth() > existing_pos.get_operation_depth():
                deeper_pos, shallower_pos = position, existing_pos
            else:
                deeper_pos, shallower_pos = existing_pos, position

            # The equality check is emitted at the deeper position and
            # compares against the shallower (first-seen) position.
            equal_pred = Predicate.get_equal_to(shallower_pos)
            predicates.append(
                PositionalPredicate(
                    q=equal_pred.q, a=equal_pred.a, position=deeper_pos
                )
            )
        # Already visited: no further recursion into this value.
        return predicates

    # First visit: remember where this value lives.
    inputs[value] = position

    # Dispatch based on position type (not value type!)
    match position:
        case AttributePosition():
            assert isinstance(value, OpResult)
            predicates.extend(
                self._extract_attribute_predicates(value.owner, position, inputs)
            )
        case OperationPosition():
            assert isinstance(value, OpResult)
            predicates.extend(
                self._extract_operation_predicates(
                    value.owner, position, inputs, ignore_operand
                )
            )
        case TypePosition():
            assert isinstance(value, OpResult)
            predicates.extend(
                self._extract_type_predicates(value.owner, position, inputs)
            )
        case OperandPosition() | OperandGroupPosition():
            assert isinstance(value, SSAValue)
            predicates.extend(
                self._extract_operand_tree_predicates(value, position, inputs)
            )
        case _:
            raise TypeError(f"Unexpected position kind: {type(position)}")

    return predicates

extract_non_tree_predicates(pattern: pdl.PatternOp, inputs: dict[SSAValue, Position]) -> list[PositionalPredicate]

Extract predicates that cannot be determined via tree walking

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
def extract_non_tree_predicates(
    self,
    pattern: pdl.PatternOp,
    inputs: dict[SSAValue, Position],
) -> list[PositionalPredicate]:
    """Extract predicates that cannot be determined via tree walking.

    Runs after ``extract_tree_predicates`` and handles ops that are not
    reachable from the root traversal: constant attributes/types, native
    constraints, and result accesses. ``inputs`` (value -> position) is
    both read and extended here; values already placed by the tree walk
    are left untouched.
    """
    predicates: list[PositionalPredicate] = []

    for op in pattern.body.ops:
        match op:
            case pdl.AttributeOp():
                if op.output not in inputs:
                    if op.value:
                        # Create literal position for constant attribute
                        attr_pos = AttributeLiteralPosition(
                            value=op.value, parent=None
                        )
                        inputs[op.output] = attr_pos

            case pdl.ApplyNativeConstraintOp():
                # Collect all argument positions; every argument must
                # already have a position from the tree walk.
                arg_positions = tuple(inputs.get(arg) for arg in op.args)
                for pos in arg_positions:
                    assert pos is not None
                arg_positions = cast(tuple[Position, ...], arg_positions)

                # Find the furthest position (deepest)
                furthest_pos = max(
                    arg_positions, key=lambda p: p.get_operation_depth() if p else 0
                )

                # Create the constraint predicate
                result_types = tuple(r.type for r in op.res)
                is_negated = bool(op.is_negated.value.data)
                constraint_pred = Predicate.get_constraint(
                    op.constraint_name.data, arg_positions, result_types, is_negated
                )

                # Register positions for constraint results
                for i, result in enumerate(op.results):
                    assert isinstance(constraint_pred.q, ConstraintQuestion)
                    constraint_pos = ConstraintPosition.get_constraint(
                        constraint_pred.q, i
                    )
                    existing = inputs.get(result)
                    if existing:
                        # Add equality constraint if result already has a position;
                        # the deeper of the two positions carries the check.
                        deeper, shallower = (
                            (constraint_pos, existing)
                            if furthest_pos.get_operation_depth()
                            > existing.get_operation_depth()
                            else (existing, constraint_pos)
                        )
                        eq_pred = Predicate.get_equal_to(shallower)
                        predicates.append(
                            PositionalPredicate(
                                q=eq_pred.q, a=eq_pred.a, position=deeper
                            )
                        )
                    else:
                        inputs[result] = constraint_pos

                # The constraint itself is evaluated at the deepest argument.
                predicates.append(
                    PositionalPredicate(
                        q=constraint_pred.q,
                        a=constraint_pred.a,
                        position=furthest_pos,
                    )
                )

            case pdl.ResultOp():
                # Ensure result exists: give it a position under its parent
                # operation and require it to be non-null.
                if op.val not in inputs:
                    assert isinstance(op.parent_.owner, pdl.OperationOp)
                    parent_pos = inputs.get(op.parent_.owner.op)
                    if parent_pos and isinstance(parent_pos, OperationPosition):
                        result_pos = parent_pos.get_result(op.index.value.data)
                        inputs[op.val] = result_pos
                        is_not_null = Predicate.get_is_not_null()
                        predicates.append(
                            PositionalPredicate(
                                q=is_not_null.q,
                                a=is_not_null.a,
                                position=result_pos,
                            )
                        )

            case pdl.ResultsOp():
                # Handle result groups; only an indexed group gets a
                # not-null check (an unindexed group may be empty).
                if op.val not in inputs:
                    assert isinstance(op.parent_.owner, pdl.OperationOp)
                    parent_pos = inputs.get(op.parent_.owner.op)
                    if parent_pos and isinstance(parent_pos, OperationPosition):
                        is_variadic = isinstance(op.val.type, pdl.RangeType)
                        index = op.index.value.data if op.index else None
                        result_pos = parent_pos.get_result_group(index, is_variadic)
                        inputs[op.val] = result_pos
                        if index is not None:
                            is_not_null = Predicate.get_is_not_null()
                            predicates.append(
                                PositionalPredicate(
                                    q=is_not_null.q,
                                    a=is_not_null.a,
                                    position=result_pos,
                                )
                            )

            case pdl.TypeOp():
                # Handle constant types
                if op.result not in inputs and op.constantType:
                    type_pos = TypeLiteralPosition.get_type_literal(
                        value=op.constantType
                    )
                    inputs[op.result] = type_pos

            case pdl.TypesOp():
                # Handle constant type arrays
                if op.result not in inputs and op.constantTypes:
                    type_pos = TypeLiteralPosition.get_type_literal(
                        value=op.constantTypes
                    )
                    inputs[op.result] = type_pos

            case _:
                # Other PDL ops contribute no non-tree predicates.
                pass

    return predicates

OrderedPredicate dataclass

Predicate with ordering information

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
@dataclass
class OrderedPredicate:
    """Predicate with ordering information.

    Identity is the (position, question) pair; the score fields are
    mutable bookkeeping accumulated during frequency analysis.
    """

    position: Position
    question: Question
    primary_score: int = 0  # Frequency across patterns
    secondary_score: int = 0  # Squared sum within patterns
    tie_breaker: int = 0  # Insertion order
    # Per-pattern expected answer for this (position, question) pair.
    pattern_answers: dict[pdl.PatternOp, Answer] = field(default_factory=lambda: {})

    def __lt__(self, other: "OrderedPredicate") -> bool:
        """Comparison for priority ordering.

        NOTE: deliberately inverted — this returns ``self_key > other_key``
        so that ``sorted()`` (ascending) yields the highest-priority
        predicate first: higher scores win, then shallower positions,
        cheaper positions/questions, and earlier insertion order.
        """
        return (
            self.primary_score,
            self.secondary_score,
            -self.position.get_operation_depth(),  # Prefer lower depth
            -get_position_cost(self.position),  # Position dependency
            -get_question_cost(self.question),  # Predicate dependency
            -self.tie_breaker,  # Deterministic order
        ) > (
            other.primary_score,
            other.secondary_score,
            -other.position.get_operation_depth(),
            -get_position_cost(other.position),
            -get_question_cost(other.question),
            -other.tie_breaker,
        )

    def __hash__(self):
        """The hash is based on the immutable identity of the predicate."""
        # Scores are excluded on purpose: they change during analysis.
        return hash((self.position, self.question))

position: Position instance-attribute

question: Question instance-attribute

primary_score: int = 0 class-attribute instance-attribute

secondary_score: int = 0 class-attribute instance-attribute

tie_breaker: int = 0 class-attribute instance-attribute

pattern_answers: dict[pdl.PatternOp, Answer] = field(default_factory=(lambda: {})) class-attribute instance-attribute

__init__(position: Position, question: Question, primary_score: int = 0, secondary_score: int = 0, tie_breaker: int = 0, pattern_answers: dict[pdl.PatternOp, Answer] = (lambda: {})()) -> None

__lt__(other: OrderedPredicate) -> bool

Comparison for priority ordering

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
def __lt__(self, other: "OrderedPredicate") -> bool:
    """Comparison for priority ordering.

    NOTE: deliberately inverted — returns ``self_key > other_key`` so an
    ascending ``sorted()`` places the highest-priority predicate first.
    """
    return (
        self.primary_score,
        self.secondary_score,
        -self.position.get_operation_depth(),  # Prefer lower depth
        -get_position_cost(self.position),  # Position dependency
        -get_question_cost(self.question),  # Predicate dependency
        -self.tie_breaker,  # Deterministic order
    ) > (
        other.primary_score,
        other.secondary_score,
        -other.position.get_operation_depth(),
        -get_position_cost(other.position),
        -get_question_cost(other.question),
        -other.tie_breaker,
    )

__hash__()

The hash is based on the immutable identity of the predicate.

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
703
704
705
def __hash__(self):
    """The hash is based on the immutable identity of the predicate."""
    # Mutable score fields are intentionally excluded from the hash.
    return hash((self.position, self.question))

PredicateTreeBuilder

Builds optimized predicate matching trees

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
class PredicateTreeBuilder:
    """Builds optimized predicate matching trees"""

    # Extracts positional predicates from individual patterns.
    analyzer: PatternAnalyzer
    # Root SSA value chosen per pattern (filled by build_predicate_tree).
    _pattern_roots: dict[pdl.PatternOp, SSAValue]
    # Per-pattern map from PDL SSA values to their matcher positions.
    pattern_value_positions: dict[pdl.PatternOp, dict[SSAValue, Position]]

    def __init__(self):
        self.analyzer = PatternAnalyzer()
        self._pattern_roots = {}
        self.pattern_value_positions = {}

    def build_predicate_tree(self, patterns: list[pdl.PatternOp]) -> MatcherNode:
        """Build optimized matcher tree from multiple patterns"""

        # Extract predicates for all patterns
        all_pattern_predicates: list[
            tuple[pdl.PatternOp, list[PositionalPredicate]]
        ] = []
        for pattern in patterns:
            predicates, root, inputs = self._extract_pattern_predicates(pattern)
            all_pattern_predicates.append((pattern, predicates))
            self._pattern_roots[pattern] = root
            self.pattern_value_positions[pattern] = inputs

        # Create ordered predicates with frequency analysis
        ordered_predicates = self._create_ordered_predicates(all_pattern_predicates)
        # Sort predicates by priority (OrderedPredicate.__lt__ is inverted,
        # so the highest-priority predicate comes first), then fix up any
        # ordering violations while keeping the sort stable.
        sorted_predicates = sorted(ordered_predicates.values())
        sorted_predicates = _stable_topological_sort(sorted_predicates)

        # Build matcher tree by propagating patterns
        root_node = None
        for pattern, predicates in all_pattern_predicates:
            if not predicates:
                continue
            # Lookup table: (position, question) -> predicate of this pattern.
            pattern_predicate_set = {
                (pred.position, pred.q): pred for pred in predicates
            }
            root_node = self._propagate_pattern(
                root_node, pattern, pattern_predicate_set, sorted_predicates, 0
            )

        # Add exit node and optimize
        if root_node is not None:
            root_node = self._optimize_tree(root_node)
            root_node = self._insert_exit_node(root_node)
            return root_node
        else:
            # Return a default exit node if no patterns were processed
            return ExitNode()

    def _extract_pattern_predicates(
        self, pattern: pdl.PatternOp
    ) -> tuple[list[PositionalPredicate], SSAValue, dict[SSAValue, Position]]:
        """Extract all predicates for a single pattern.

        Returns the predicates, the chosen root value, and the
        value-to-position map built during extraction.
        """
        predicates: list[PositionalPredicate] = []
        inputs: dict[SSAValue, Position] = {}

        roots = self.analyzer.detect_roots(pattern)
        if len(roots) != 1:
            raise ValueError("Multi-root patterns are not yet supported.")

        # Prefer the rewrite's explicit root over the detected one.
        rewriter = pattern.body.block.last_op
        assert isinstance(rewriter, pdl.RewriteOp)
        best_root = rewriter.root if rewriter.root is not None else roots[0]

        # Downward traversal from the best root
        root_pos = OperationPosition(depth=0)
        predicates.extend(
            self.analyzer.extract_tree_predicates(best_root, root_pos, inputs)
        )

        predicates.extend(self.analyzer.extract_non_tree_predicates(pattern, inputs))
        return predicates, best_root, inputs

    def _create_ordered_predicates(
        self,
        all_pattern_predicates: list[tuple[pdl.PatternOp, list[PositionalPredicate]]],
    ) -> dict[tuple[Position, Question], OrderedPredicate]:
        """Create ordered predicates with frequency analysis.

        primary_score counts how often a predicate occurs across patterns;
        secondary_score sums the squared primary scores of each pattern's
        unique predicates, favoring predicates shared by "heavy" patterns.
        """
        predicate_map: dict[tuple[Position, Question], OrderedPredicate] = {}
        tie_breaker = 0

        # Collect unique predicates
        for pattern, predicates in all_pattern_predicates:
            for pred in predicates:
                key = (pred.position, pred.q)

                if key not in predicate_map:
                    ordered_pred = OrderedPredicate(
                        position=pred.position,
                        question=pred.q,
                        tie_breaker=tie_breaker,
                    )
                    predicate_map[key] = ordered_pred
                    tie_breaker += 1

                # Track pattern answers and increment frequency
                predicate_map[key].pattern_answers[pattern] = pred.a
                predicate_map[key].primary_score += 1

        # Calculate secondary scores
        for pattern, predicates in all_pattern_predicates:
            pattern_primary_sum = 0
            seen_keys: set[tuple[Position, Question]] = (
                set()
            )  # Track unique keys per pattern

            # First pass: collect unique predicates for this pattern
            for pred in predicates:
                key = (pred.position, pred.q)
                if key not in seen_keys:
                    seen_keys.add(key)
                    ordered_pred = predicate_map[key]
                    pattern_primary_sum += ordered_pred.primary_score**2

            # Second pass: add secondary score to each unique predicate
            for key in seen_keys:
                ordered_pred = predicate_map[key]
                ordered_pred.secondary_score += pattern_primary_sum

        return predicate_map

    def _propagate_pattern(
        self,
        node: MatcherNode | None,
        pattern: pdl.PatternOp,
        pattern_predicates: dict[tuple[Position, Question], PositionalPredicate],
        sorted_predicates: list[OrderedPredicate],
        predicate_index: int,
    ) -> MatcherNode:
        """Propagate a pattern through the predicate tree.

        Recursively threads the pattern through the globally sorted
        predicate list, reusing existing nodes on matching paths and
        branching to failure paths on divergence.
        """

        # Base case: reached end of predicates — the pattern matched.
        if predicate_index >= len(sorted_predicates):
            root_val = self._pattern_roots.get(pattern)
            return SuccessNode(pattern=pattern, root=root_val, failure_node=node)

        current_predicate = sorted_predicates[predicate_index]
        pred_key = (current_predicate.position, current_predicate.question)

        # Skip predicates not in this pattern
        if pred_key not in pattern_predicates:
            return self._propagate_pattern(
                node,
                pattern,
                pattern_predicates,
                sorted_predicates,
                predicate_index + 1,
            )

        # Create or match existing node
        if node is None:
            # Create new switch node
            node = SwitchNode(
                position=current_predicate.position, question=current_predicate.question
            )

        if self._nodes_match(node, current_predicate):
            # Continue down matching path
            pattern_answer = pattern_predicates[pred_key].a

            if isinstance(node, SwitchNode):
                if pattern_answer not in node.children:
                    node.children[pattern_answer] = None

                node.children[pattern_answer] = self._propagate_pattern(
                    node.children[pattern_answer],
                    pattern,
                    pattern_predicates,
                    sorted_predicates,
                    predicate_index + 1,
                )

        else:
            # Divergence - continue down failure path with the SAME
            # predicate index (this node asks a different question).
            node.failure_node = self._propagate_pattern(
                node.failure_node,
                pattern,
                pattern_predicates,
                sorted_predicates,
                predicate_index,
            )

        return node

    def _nodes_match(self, node: MatcherNode, predicate: OrderedPredicate) -> bool:
        """Check if node matches the given predicate"""
        return (
            node.position == predicate.position and node.question == predicate.question
        )

    def _insert_exit_node(self, root: MatcherNode) -> MatcherNode:
        """Insert exit node at end of failure paths"""
        # Walk to the end of the top-level failure chain and terminate it.
        curr = root
        while curr.failure_node:
            curr = curr.failure_node
        curr.failure_node = ExitNode()
        return root

    def _optimize_tree(self, root: MatcherNode) -> MatcherNode:
        """Optimize the tree by collapsing single-child switches to bools"""
        # Recursively optimize children
        if isinstance(root, SwitchNode):
            for answer in root.children:
                child_node = root.children[answer]
                if child_node is not None:
                    root.children[answer] = self._optimize_tree(child_node)
        elif isinstance(root, BoolNode):
            if root.success_node is not None:
                root.success_node = self._optimize_tree(root.success_node)

        if root.failure_node is not None:
            root.failure_node = self._optimize_tree(root.failure_node)

        if isinstance(root, SwitchNode) and len(root.children) == 1:
            # Convert switch to bool node
            answer, child = next(iter(root.children.items()))
            bool_node = BoolNode(
                position=root.position,
                question=root.question,
                success_node=child,
                failure_node=root.failure_node,
                answer=answer,
            )
            return bool_node

        return root

analyzer: PatternAnalyzer = PatternAnalyzer() instance-attribute

pattern_value_positions: dict[pdl.PatternOp, dict[SSAValue, Position]] = {} instance-attribute

__init__()

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
775
776
777
778
def __init__(self):
    # Fresh analyzer plus empty per-pattern bookkeeping maps.
    self.analyzer = PatternAnalyzer()
    self._pattern_roots = {}
    self.pattern_value_positions = {}

build_predicate_tree(patterns: list[pdl.PatternOp]) -> MatcherNode

Build optimized matcher tree from multiple patterns

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
def build_predicate_tree(self, patterns: list[pdl.PatternOp]) -> MatcherNode:
    """Build optimized matcher tree from multiple patterns"""

    # Extract predicates for all patterns
    all_pattern_predicates: list[
        tuple[pdl.PatternOp, list[PositionalPredicate]]
    ] = []
    for pattern in patterns:
        predicates, root, inputs = self._extract_pattern_predicates(pattern)
        all_pattern_predicates.append((pattern, predicates))
        self._pattern_roots[pattern] = root
        self.pattern_value_positions[pattern] = inputs

    # Create ordered predicates with frequency analysis
    ordered_predicates = self._create_ordered_predicates(all_pattern_predicates)
    # Sort predicates by priority (highest first — __lt__ is inverted),
    # then repair dependency ordering while keeping the sort stable.
    sorted_predicates = sorted(ordered_predicates.values())
    sorted_predicates = _stable_topological_sort(sorted_predicates)

    # Build matcher tree by propagating patterns
    root_node = None
    for pattern, predicates in all_pattern_predicates:
        if not predicates:
            continue
        # Lookup table: (position, question) -> this pattern's predicate.
        pattern_predicate_set = {
            (pred.position, pred.q): pred for pred in predicates
        }
        root_node = self._propagate_pattern(
            root_node, pattern, pattern_predicate_set, sorted_predicates, 0
        )

    # Add exit node and optimize
    if root_node is not None:
        root_node = self._optimize_tree(root_node)
        root_node = self._insert_exit_node(root_node)
        return root_node
    else:
        # Return a default exit node if no patterns were processed
        return ExitNode()

MatcherGenerator

Generates PDL interpreter matcher from matcher tree

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
class MatcherGenerator:
    """Generates PDL interpreter matcher from matcher tree"""

    # Function that receives the generated matcher body (entry point "matcher").
    matcher_func: pdl_interp.FuncOp
    # Module holding one pdl_interp.func per pattern rewriter.
    rewriter_module: ModuleOp
    # Builder positioned inside `rewriter_module` while emitting rewriter funcs.
    rewriter_builder: Builder
    # Per-pattern map from the pattern's SSA values to their matcher positions.
    value_to_position: dict[pdl.PatternOp, dict[SSAValue, Position]]
    # Scoped cache of the SSA value materialized for each position.
    values: ScopedDict[Position, SSAValue]
    # Stack of blocks to branch to when a predicate check fails.
    failure_block_stack: list[Block]
    # Builder positioned inside the matcher function.
    builder: Builder
    # Maps each constraint question to the op applying it, for result access.
    constraint_op_map: dict[ConstraintQuestion, pdl_interp.ApplyConstraintOp]
    # Usage count per rewriter name, used to derive unique suffixes.
    rewriter_names: dict[str, int]

    def __init__(
        self,
        matcher_func: pdl_interp.FuncOp,
        rewriter_module: ModuleOp,
        optimize_for_eqsat: bool = False,
    ) -> None:
        """Set up builders and caches for matcher/rewriter generation.

        NOTE(review): `optimize_for_eqsat` is accepted but never stored or
        read in this method — confirm whether it is used elsewhere or dead.
        """
        self.matcher_func = matcher_func
        self.rewriter_module = rewriter_module
        # Rewriter functions are appended at the end of the rewriter module.
        self.rewriter_builder = Builder(InsertPoint.at_end(rewriter_module.body.block))
        self.value_to_position = {}
        self.values = ScopedDict()
        self.failure_block_stack = []
        # Matcher ops are emitted starting at the top of the matcher body.
        self.builder = Builder(InsertPoint.at_start(matcher_func.body.block))
        self.constraint_op_map = {}
        self.rewriter_names = {}

    def lower(self, patterns: list[pdl.PatternOp]) -> None:
        """Lower PDL patterns to PDL interpreter"""

        # Combine all patterns into a single predicate tree.
        predicate_tree_builder = PredicateTreeBuilder()
        tree_root = predicate_tree_builder.build_predicate_tree(patterns)
        self.value_to_position = predicate_tree_builder.pattern_value_positions

        # The matcher's single block argument is the root operation; seed the
        # position -> value cache with it before emitting any checks.
        entry = self.matcher_func.body.block
        self.values[OperationPosition(depth=0)] = entry.args[0]

        # Emit interpreter ops for the whole tree into the entry block.
        _ = self.generate_matcher(tree_root, self.matcher_func.body, block=entry)

    def generate_matcher(
        self, node: MatcherNode, region: Region, block: Block | None = None
    ) -> Block:
        """Generate PDL interpreter operations for a matcher node.

        Recursively emits the ops for `node` (including its failure subtree)
        into `region`, allocating a fresh block when `block` is not given.
        Returns the block containing the ops generated for `node`.
        """

        # Create block if needed
        if block is None:
            block = Block()
            region.add_block(block)

        # Set insertion point to end of this block
        self.builder.insertion_point = InsertPoint.at_end(block)

        # Handle exit node - just add finalize
        if isinstance(node, ExitNode):
            self.builder.insert(pdl_interp.FinalizeOp())
            return block

        # Open a fresh scope: values materialized for this subtree must not
        # leak into sibling subtrees.
        self.values = ScopedDict(self.values)
        assert self.values.parent is not None

        # Handle failure node
        failure_block = None
        if node.failure_node:
            # Generate the failure subtree first so checks emitted below can
            # branch to it.
            failure_block = self.generate_matcher(node.failure_node, region)
            self.failure_block_stack.append(failure_block)
            # Restore insertion point after generating failure node
            self.builder.insertion_point = InsertPoint.at_end(block)
        else:
            # No local failure node: reuse the enclosing failure destination.
            assert self.failure_block_stack, "Expected valid failure block"
            failure_block = self.failure_block_stack[-1]

        # Get value for position if exists (may change insertion point)
        val = None
        if node.position:
            val = self.get_value_at(node.position)

        # Dispatch based on node type
        match node:
            case BoolNode():
                assert val is not None
                self.generate_bool_node(node, val)
            case SwitchNode():
                assert val is not None
                self.generate_switch_node(node, val)
            case SuccessNode():
                self.generate_success_node(node)
            case _:
                raise NotImplementedError(f"Unhandled node type {type(node)}")

        # Pop failure block if we pushed one
        if node.failure_node:
            self.failure_block_stack.pop()

        self.values = self.values.parent  # Pop scope
        return block

    def get_value_at(self, position: Position) -> SSAValue:
        """Get or create SSA value for a position.

        Positions form a parent chain rooted at the matched operation; the
        parent's value is resolved first, then an extraction op is emitted
        for this position. Results are cached in the current scope.

        Assumes self.builder.insertion_point is correctly set.
        May modify the insertion point (e.g., when creating foreach loops).
        """

        # Check cache
        if position in self.values:
            return self.values[position]

        # Get parent value if needed (may change insertion point)
        parent_val = None
        if position.parent:
            parent_val = self.get_value_at(position.parent)

        # Create value based on position type
        value = None

        if isinstance(position, OperationPosition):
            if position.is_operand_defining_op():
                assert parent_val is not None
                # Get defining operation of operand
                defining_op = pdl_interp.GetDefiningOpOp(parent_val)
                # Attach the position repr to the op — presumably for
                # debugging/tracing; confirm downstream consumers.
                defining_op.attributes["position"] = StringAttr(position.__repr__())
                self.builder.insert(defining_op)
                value = defining_op.input_op
            else:
                # Passthrough
                value = parent_val

        elif isinstance(position, OperandPosition):
            assert parent_val is not None
            get_operand_op = pdl_interp.GetOperandOp(
                position.operand_number, parent_val
            )
            self.builder.insert(get_operand_op)
            value = get_operand_op.value

        elif isinstance(position, OperandGroupPosition):
            assert parent_val is not None
            # Get operands (possibly variadic)
            result_type = (
                pdl.RangeType(pdl.ValueType())
                if position.is_variadic
                else pdl.ValueType()
            )
            get_operands_op = pdl_interp.GetOperandsOp(
                position.group_number, parent_val, result_type
            )
            self.builder.insert(get_operands_op)
            value = get_operands_op.value

        elif isinstance(position, ResultPosition):
            assert parent_val is not None
            get_result_op = pdl_interp.GetResultOp(position.result_number, parent_val)
            self.builder.insert(get_result_op)
            value = get_result_op.value

        elif isinstance(position, ResultGroupPosition):
            assert parent_val is not None
            # Get results (possibly variadic)
            result_type = (
                pdl.RangeType(pdl.ValueType())
                if position.is_variadic
                else pdl.ValueType()
            )
            get_results_op = pdl_interp.GetResultsOp(
                position.group_number, parent_val, result_type
            )
            self.builder.insert(get_results_op)
            value = get_results_op.value

        elif isinstance(position, AttributePosition):
            assert parent_val is not None
            get_attr_op = pdl_interp.GetAttributeOp(position.attribute_name, parent_val)
            self.builder.insert(get_attr_op)
            value = get_attr_op.value

        elif isinstance(position, AttributeLiteralPosition):
            # Create a constant attribute
            create_attr_op = pdl_interp.CreateAttributeOp(position.value)
            self.builder.insert(create_attr_op)
            value = create_attr_op.attribute

        elif isinstance(position, TypePosition):
            assert parent_val is not None
            # Get type of value or attribute, depending on the parent's kind.
            if parent_val.type == pdl.AttributeType():
                get_type_op = pdl_interp.GetAttributeTypeOp(parent_val)
            else:
                get_type_op = pdl_interp.GetValueTypeOp(parent_val)
            self.builder.insert(get_type_op)
            value = get_type_op.result

        elif isinstance(position, TypeLiteralPosition):
            # Create a constant type or types
            raw_type_attr = position.value
            if isinstance(raw_type_attr, TypeAttribute):
                create_type_op = pdl_interp.CreateTypeOp(raw_type_attr)
                self.builder.insert(create_type_op)
                value = create_type_op.result
            else:
                # Assume it's an ArrayAttr of types
                assert isinstance(raw_type_attr, ArrayAttr)
                type_attr = cast(ArrayAttr[TypeAttribute], raw_type_attr)
                create_types_op = pdl_interp.CreateTypesOp(type_attr)
                self.builder.insert(create_types_op)
                value = create_types_op.result

        elif isinstance(position, ConstraintPosition):
            # The constraint op has already been created (by
            # generate_bool_node), find it in the map and use its result.
            constraint_op = self.constraint_op_map.get(position.constraint)
            assert constraint_op is not None
            value = constraint_op.results[position.result_index]

        elif isinstance(position, UsersPosition):
            raise NotImplementedError("UsersPosition not implemented in lowering")
        elif isinstance(position, ForEachPosition):
            raise NotImplementedError("ForEachPosition not implemented in lowering")
        else:
            raise NotImplementedError(f"Unhandled position type {type(position)}")

        # Cache and return
        assert value is not None
        self.values[position] = value
        return value

    def generate_bool_node(self, node: BoolNode, val: SSAValue) -> None:
        """Generate operations for a boolean predicate node.

        Emits a check op on `val` that branches to a new success block (whose
        contents are generated from `node.success_node`) or to the current
        failure block.

        Assumes self.builder.insertion_point is correctly set.
        """

        question = node.question
        answer = node.answer
        block = self.builder.insertion_point.block
        region = block.parent
        assert region is not None, "Block must be in a region"

        # Handle getValue queries first for constraint questions (may change insertion point)
        args: list[SSAValue] = []
        if isinstance(question, EqualToQuestion):
            args = [self.get_value_at(question.other_position)]
        elif isinstance(question, ConstraintQuestion):
            for position in question.arg_positions:
                args.append(self.get_value_at(position))

        # Get the current block after potentially changed insertion point
        block = self.builder.insertion_point.block
        region = block.parent
        assert region is not None, "Block must be in a region"

        # Create success block
        success_block = Block()
        region.add_block(success_block)
        failure_block = self.failure_block_stack[-1]

        # Generate predicate check operation based on question type
        match question:
            case IsNotNullQuestion():
                check_op = pdl_interp.IsNotNullOp(val, success_block, failure_block)
            case OperationNameQuestion():
                assert isinstance(answer, StringAnswer)
                check_op = pdl_interp.CheckOperationNameOp(
                    answer.value, val, success_block, failure_block
                )
            case OperandCountQuestion() | OperandCountAtLeastQuestion():
                assert isinstance(answer, UnsignedAnswer)
                compare_at_least = isinstance(question, OperandCountAtLeastQuestion)
                check_op = pdl_interp.CheckOperandCountOp(
                    val, answer.value, success_block, failure_block, compare_at_least
                )
            case ResultCountQuestion() | ResultCountAtLeastQuestion():
                assert isinstance(answer, UnsignedAnswer)
                compare_at_least = isinstance(question, ResultCountAtLeastQuestion)
                check_op = pdl_interp.CheckResultCountOp(
                    val, answer.value, success_block, failure_block, compare_at_least
                )
            case EqualToQuestion():
                # Get the other value to compare with. It was already cached
                # by the pre-pass above, so this emits no new ops.
                other_val = self.get_value_at(question.other_position)
                # Update block reference after potential insertion point change
                # NOTE(review): `block` is not read after this assignment.
                block = self.builder.insertion_point.block
                assert isinstance(answer, TrueAnswer)
                check_op = pdl_interp.AreEqualOp(
                    val, other_val, success_block, failure_block
                )
            case AttributeConstraintQuestion():
                assert isinstance(answer, AttributeAnswer)
                check_op = pdl_interp.CheckAttributeOp(
                    answer.value, val, success_block, failure_block
                )
            case TypeConstraintQuestion():
                assert isinstance(answer, TypeAnswer)
                if isinstance(val.type, pdl.RangeType):
                    # Check multiple types
                    assert isinstance(answer.value, ArrayAttr)
                    check_op = pdl_interp.CheckTypesOp(
                        answer.value, val, success_block, failure_block
                    )
                else:
                    # Check single type
                    assert isinstance(answer.value, TypeAttribute)
                    check_op = pdl_interp.CheckTypeOp(
                        answer.value, val, success_block, failure_block
                    )
            case ConstraintQuestion():
                check_op = pdl_interp.ApplyConstraintOp(
                    question.name,
                    args,
                    success_block,
                    failure_block,
                    is_negated=question.is_negated,
                    res_types=question.result_types,
                )
                # Store the constraint op for later result access
                # (see ConstraintPosition handling in get_value_at).
                self.constraint_op_map[question] = check_op
            case _:
                raise NotImplementedError(f"Unhandled question type {type(question)}")

        self.builder.insert(check_op)

        # Generate matcher for success node
        if node.success_node:
            self.generate_matcher(node.success_node, region, success_block)

    def generate_switch_node(self, node: SwitchNode, val: SSAValue) -> None:
        """Generate operations for a switch node.

        Emits either a chain of at-least checks (for operand/result-count
        lower bounds) or a single pdl_interp switch op dispatching on `val`
        to one block per answer, with the current failure block as default.

        Assumes self.builder.insertion_point is correctly set.
        """

        question = node.question
        block = self.builder.insertion_point.block
        region = block.parent
        assert region is not None, "Block must be in a region"
        default_dest = self.failure_block_stack[-1]

        # Handle at-least questions specially: there is no switch op for
        # them, so emit a cascade of range checks from largest to smallest.
        if isinstance(
            question, OperandCountAtLeastQuestion | ResultCountAtLeastQuestion
        ):
            # Sort children in reverse numerical order
            sorted_children = sorted(
                node.children.items(),
                key=lambda x: cast(UnsignedAnswer, x[0]).value,
                reverse=True,
            )

            # Push temporary entry to failure block stack
            self.failure_block_stack.append(default_dest)

            for answer, child_node in sorted_children:
                if child_node:
                    # Child matcher body first, then a check block that jumps
                    # into it on success; each check becomes the failure
                    # destination of the next (smaller-bound) child.
                    success_block = self.generate_matcher(child_node, region)
                    current_check_block = Block()
                    region.insert_block_before(current_check_block, success_block)
                    self.builder.insertion_point = InsertPoint.at_end(
                        current_check_block
                    )
                    assert isinstance(answer, UnsignedAnswer)
                    if isinstance(question, OperandCountAtLeastQuestion):
                        check_op = pdl_interp.CheckOperandCountOp(
                            val,
                            answer.value,
                            success_block,
                            default_dest,
                            True,
                        )
                    else:
                        check_op = pdl_interp.CheckResultCountOp(
                            val,
                            answer.value,
                            success_block,
                            default_dest,
                            True,
                        )
                    self.builder.insert(check_op)

                    # Update failure block stack for next child matcher
                    self.failure_block_stack[-1] = current_check_block

            # Pop the temporary entry from failure block stack
            first_predicate_block = self.failure_block_stack.pop()

            # Move ops from the first check block into the main block, so the
            # cascade starts inline instead of via an extra branch.
            for op in list(first_predicate_block.ops):
                op.detach()
                block.add_op(op)
            assert first_predicate_block.parent is not None
            first_predicate_block.parent.detach_block(first_predicate_block)
            first_predicate_block.erase()

            return

        # Generate child blocks and collect case values
        case_blocks: list[Block] = []
        case_values: list[Answer] = []

        for answer, child_node in node.children.items():
            if child_node:
                child_block = self.generate_matcher(child_node, region)
                case_blocks.append(child_block)
                case_values.append(answer)

        # Restore insertion point after generating child matchers
        self.builder.insertion_point = InsertPoint.at_end(block)

        # Create switch operation based on question type
        match question:
            case OperationNameQuestion():
                # Extract string values from StringAnswer objects
                switch_values = [cast(StringAnswer, ans).value for ans in case_values]
                switch_attr = ArrayAttr([StringAttr(v) for v in switch_values])
                switch_op = pdl_interp.SwitchOperationNameOp(
                    switch_attr, val, default_dest, case_blocks
                )
            case OperandCountQuestion():
                # Extract integer values from UnsignedAnswer objects
                switch_values = [cast(UnsignedAnswer, ans).value for ans in case_values]
                switch_op = pdl_interp.SwitchOperandCountOp(
                    switch_values, val, default_dest, case_blocks
                )
            case ResultCountQuestion():
                # Extract integer values from UnsignedAnswer objects
                switch_values = [cast(UnsignedAnswer, ans).value for ans in case_values]
                switch_op = pdl_interp.SwitchResultCountOp(
                    switch_values, val, default_dest, case_blocks
                )
            case TypeConstraintQuestion():
                # Extract type attributes from TypeAnswer objects
                switch_values = [cast(TypeAnswer, ans).value for ans in case_values]
                if isinstance(val.type, pdl.RangeType):
                    assert isa(switch_values, list[ArrayAttr[TypeAttribute]])
                    switch_attr = ArrayAttr(switch_values)

                    switch_op = pdl_interp.SwitchTypesOp(
                        switch_attr, val, default_dest, case_blocks
                    )
                else:
                    assert isa(switch_values, list[TypeAttribute])
                    switch_attr = ArrayAttr(switch_values)
                    switch_op = pdl_interp.SwitchTypeOp(
                        switch_attr, val, default_dest, case_blocks
                    )
            case AttributeConstraintQuestion():
                # Extract attribute values from AttributeAnswer objects
                switch_values = [
                    cast(AttributeAnswer, ans).value for ans in case_values
                ]
                switch_attr = ArrayAttr(switch_values)
                switch_op = pdl_interp.SwitchAttributeOp(
                    val, switch_attr, default_dest, case_blocks
                )
            case _:
                raise NotImplementedError(f"Unhandled question type {type(question)}")

        self.builder.insert(switch_op)

    def generate_success_node(self, node: SuccessNode) -> None:
        """Generate operations for a successful match.

        Emits a `pdl_interp.record_match` that hands the values matched for
        this pattern to its generated rewriter function, then branches to the
        current failure block to continue matching other patterns.

        Assumes self.builder.insertion_point is correctly set.
        """

        pattern = node.pattern
        root = node.root

        # Generate a rewriter for the pattern; it fills in the matcher
        # positions whose values it needs as arguments.
        used_match_positions: list[Position] = []
        rewriter_func_ref = self.generate_rewriter(pattern, used_match_positions)

        # Process values used in the rewrite that are defined in the match
        # (may change insertion point)
        mapped_match_values = [self.get_value_at(pos) for pos in used_match_positions]

        # Collect generated op names, but only for DAG rewriters: an external
        # rewriter (one with a `name_`) may create arbitrary operations.
        rewriter_op = pattern.body.block.last_op
        assert isinstance(rewriter_op, pdl.RewriteOp)
        # Bug fix: this previously tested `rewriter_op.name`, which is the
        # *operation* name ("pdl.rewrite") and is always truthy, so the
        # generated op names were never collected. The external rewriter
        # name attribute is `name_` (see generate_rewriter).
        if not rewriter_op.name_:
            assert rewriter_op.body is not None
            generated_op_names = ArrayAttr(
                [
                    op.opName
                    for op in rewriter_op.body.walk()
                    if isinstance(op, pdl.OperationOp) and op.opName
                ]
            )
        else:
            generated_op_names = None
        # Get root kind if present
        root_kind: StringAttr | None = None
        if root:
            defining_op = root.owner
            if isinstance(defining_op, pdl.OperationOp) and defining_op.opName:
                root_kind = StringAttr(defining_op.opName.data)

        # Create the RecordMatchOp
        record_op = pdl_interp.RecordMatchOp(
            rewriter_func_ref,
            root_kind,
            generated_op_names,
            pattern.benefit,
            mapped_match_values,
            [],
            self.failure_block_stack[-1],
        )
        self.builder.insert(record_op)

    def generate_rewriter(
        self, pattern: pdl.PatternOp, used_match_positions: list[Position]
    ) -> SymbolRefAttr:
        """
        Generate a rewriter function for the given pattern, and return a
        reference to that function.

        Appends every matcher position whose value the rewriter consumes to
        `used_match_positions` (in argument order); the caller passes the
        corresponding values when recording the match.
        """
        rewriter_op = pattern.body.block.last_op
        assert isinstance(rewriter_op, pdl.RewriteOp)

        if pattern.sym_name:
            rewriter_name = pattern.sym_name.data
        else:
            rewriter_name = "pdl_generated_rewriter"
        if rewriter_name in self.rewriter_names:
            # duplicate names get a numeric suffix starting from 0 (foo, foo_0, foo_1, ...)
            # The counter starts at 1 for the unsuffixed name, so after the
            # increment `count - 2` yields 0 for the first duplicate.
            self.rewriter_names[rewriter_name] += 1
            rewriter_name = f"{rewriter_name}_{self.rewriter_names[rewriter_name] - 2}"
        else:
            self.rewriter_names[rewriter_name] = 1

        # Create the rewriter function; arguments are added lazily below as
        # matcher values are first referenced, and the type is fixed up last.
        rewriter_func = pdl_interp.FuncOp(rewriter_name, ([], []))

        self.rewriter_module.body.block.add_op(rewriter_func)
        entry_block = rewriter_func.body.block
        self.rewriter_builder.insertion_point = InsertPoint.at_end(entry_block)

        # Maps values from the pdl pattern to their pdl_interp counterparts.
        rewrite_values: dict[SSAValue, SSAValue] = {}
        pattern_value_positions = self.value_to_position[pattern]

        def map_rewrite_value(old_value: SSAValue) -> SSAValue:
            # Resolve a pattern value to its rewriter-function value, either
            # by rematerializing a constant or by adding a block argument.
            if new_value := rewrite_values.get(old_value):
                return new_value

            # Prefer materializing constants directly when possible.
            old_op = old_value.owner
            new_val_op: Operation | None = None
            if isinstance(old_op, pdl.AttributeOp) and old_op.value:
                new_val_op = pdl_interp.CreateAttributeOp(old_op.value)
            elif isinstance(old_op, pdl.TypeOp) and old_op.constantType:
                new_val_op = pdl_interp.CreateTypeOp(old_op.constantType)
            elif isinstance(old_op, pdl.TypesOp) and old_op.constantTypes:
                new_val_op = pdl_interp.CreateTypesOp(old_op.constantTypes)

            if new_val_op is not None:
                self.rewriter_builder.insert(new_val_op)
                new_value = new_val_op.results[0]
                rewrite_values[old_value] = new_value
                return new_value

            # Otherwise, it's an input from the matcher.
            input_pos = pattern_value_positions.get(old_value)
            assert input_pos is not None, "Expected value to be a pattern input"
            if input_pos not in used_match_positions:
                used_match_positions.append(input_pos)

            arg = entry_block.insert_arg(old_value.type, len(entry_block.args))
            rewrite_values[old_value] = arg
            return arg

        # If this is a custom rewriter, dispatch to the registered method.
        if rewriter_op.name_:
            args: list[SSAValue] = []
            if rewriter_op.root:
                args.append(map_rewrite_value(rewriter_op.root))
            args.extend(map_rewrite_value(arg) for arg in rewriter_op.external_args)

            apply_op = pdl_interp.ApplyRewriteOp(rewriter_op.name_.data, args)
            self.rewriter_builder.insert(apply_op)
        else:
            # Otherwise, this is a DAG rewriter defined using PDL operations.
            assert rewriter_op.body is not None
            for op in rewriter_op.body.ops:
                match op:
                    case pdl.ApplyNativeRewriteOp():
                        self._generate_rewriter_for_apply_native_rewrite(
                            op, rewrite_values, map_rewrite_value
                        )
                    case pdl.AttributeOp():
                        self._generate_rewriter_for_attribute(
                            op, rewrite_values, map_rewrite_value
                        )
                    case pdl.EraseOp():
                        self._generate_rewriter_for_erase(
                            op, rewrite_values, map_rewrite_value
                        )
                    case pdl.OperationOp():
                        self._generate_rewriter_for_operation(
                            op, rewrite_values, map_rewrite_value
                        )
                    case pdl.RangeOp():
                        self._generate_rewriter_for_range(
                            op, rewrite_values, map_rewrite_value
                        )
                    case pdl.ReplaceOp():
                        self._generate_rewriter_for_replace(
                            op, rewrite_values, map_rewrite_value
                        )
                    case pdl.ResultOp():
                        self._generate_rewriter_for_result(
                            op, rewrite_values, map_rewrite_value
                        )
                    case pdl.ResultsOp():
                        self._generate_rewriter_for_results(
                            op, rewrite_values, map_rewrite_value
                        )
                    case pdl.TypeOp():
                        self._generate_rewriter_for_type(
                            op, rewrite_values, map_rewrite_value
                        )
                    case pdl.TypesOp():
                        self._generate_rewriter_for_types(
                            op, rewrite_values, map_rewrite_value
                        )
                    case _:
                        raise TypeError(f"Unexpected op type: {type(op)}")

        # Update the signature of the rewrite function.
        rewriter_func.function_type = FunctionType.from_lists(entry_block.arg_types, ())

        self.rewriter_builder.insert(pdl_interp.FinalizeOp())
        return SymbolRefAttr(
            "rewriters",
            [
                StringAttr(rewriter_name),
            ],
        )

    def _generate_rewriter_for_apply_native_rewrite(
        self,
        op: pdl.ApplyNativeRewriteOp,
        rewrite_values: dict[SSAValue, SSAValue],
        map_rewrite_value: Callable[[SSAValue], SSAValue],
    ):
        """Lower a pdl.apply_native_rewrite to pdl_interp.apply_rewrite."""
        mapped_args = [map_rewrite_value(argument) for argument in op.args]
        res_types = [result.type for result in op.res]
        apply_op = pdl_interp.ApplyRewriteOp(op.constraint_name, mapped_args, res_types)
        self.rewriter_builder.insert(apply_op)
        # Record the mapping from the pdl results to the interp results.
        for old, new in zip(op.results, apply_op.results, strict=True):
            rewrite_values[old] = new

    def _generate_rewriter_for_attribute(
        self,
        op: pdl.AttributeOp,
        rewrite_values: dict[SSAValue, SSAValue],
        map_rewrite_value: Callable[[SSAValue], SSAValue],
    ):
        """Materialize a constant pdl.attribute inside the rewriter."""
        if op.value is None:
            # Nothing to materialize for non-constant attributes.
            return
        create_attr = pdl_interp.CreateAttributeOp(op.value)
        self.rewriter_builder.insert(create_attr)
        rewrite_values[op.output] = create_attr.attribute

    def _generate_rewriter_for_erase(
        self,
        op: pdl.EraseOp,
        rewrite_values: dict[SSAValue, SSAValue],
        map_rewrite_value: Callable[[SSAValue], SSAValue],
    ) -> None:
        """Lower a pdl.erase to pdl_interp.erase on the mapped operation."""
        target = map_rewrite_value(op.op_value)
        erase = pdl_interp.EraseOp(target)
        self.rewriter_builder.insert(erase)

    def _generate_rewriter_for_operation(
        self,
        op: pdl.OperationOp,
        rewrite_values: dict[SSAValue, SSAValue],
        map_rewrite_value: Callable[[SSAValue], SSAValue],
    ):
        """Lower a pdl.operation to pdl_interp.create_operation, then map any
        constrained result types back into `rewrite_values`."""
        operands = tuple(map_rewrite_value(operand) for operand in op.operand_values)
        attributes = tuple(map_rewrite_value(attr) for attr in op.attribute_values)

        # Delegate result-type computation to a helper; it fills `types` and
        # reports whether the created op should infer its result types.
        types: list[SSAValue] = []
        has_inferred_result_types = self._generate_operation_result_type_rewriter(
            op, map_rewrite_value, types, rewrite_values
        )

        if op.opName is None:
            raise ValueError("Cannot create operation without a name.")

        create_op = pdl_interp.CreateOperationOp(
            op.opName,
            UnitAttr() if has_inferred_result_types else None,
            op.attributeValueNames,
            operands,
            attributes,
            types,
        )
        self.rewriter_builder.insert(create_op)
        created_op_val = create_op.result_op
        rewrite_values[op.op] = created_op_val

        # Generate accesses for any results that have their types constrained.
        result_types = op.type_values
        # Special case: a single range-typed result type stands for the whole
        # result list of the created op.
        if len(result_types) == 1 and isinstance(result_types[0].type, pdl.RangeType):
            if result_types[0] not in rewrite_values:
                get_results = pdl_interp.GetResultsOp(
                    None, created_op_val, pdl.RangeType(pdl.ValueType())
                )
                self.rewriter_builder.insert(get_results)
                get_type = pdl_interp.GetValueTypeOp(get_results.value)
                self.rewriter_builder.insert(get_type)
                rewrite_values[result_types[0]] = get_type.result
            return

        # Once a variable-length result group has been seen, result indices
        # no longer map to single results, so fall back to group accesses.
        seen_variable_length = False
        for i, type_value in enumerate(result_types):
            if type_value in rewrite_values:
                continue
            is_variadic = isinstance(type_value.type, pdl.RangeType)
            seen_variable_length = seen_variable_length or is_variadic

            result_val: SSAValue
            if seen_variable_length:
                get_results = pdl_interp.GetResultsOp(
                    i, created_op_val, pdl.RangeType(pdl.ValueType())
                )
                self.rewriter_builder.insert(get_results)
                result_val = get_results.value
            else:
                get_result = pdl_interp.GetResultOp(i, created_op_val)
                self.rewriter_builder.insert(get_result)
                result_val = get_result.value

            get_type = pdl_interp.GetValueTypeOp(result_val)
            self.rewriter_builder.insert(get_type)
            rewrite_values[type_value] = get_type.result

    def _generate_rewriter_for_range(
        self,
        op: pdl.RangeOp,
        rewrite_values: dict[SSAValue, SSAValue],
        map_rewrite_value: Callable[[SSAValue], SSAValue],
    ) -> None:
        """Lower a pdl.range to pdl_interp.create_range over mapped elements."""
        mapped_elements = [map_rewrite_value(element) for element in op.arguments]
        range_op = pdl_interp.CreateRangeOp(mapped_elements, op.result.type)
        self.rewriter_builder.insert(range_op)
        rewrite_values[op.result] = range_op.result

    def _generate_rewriter_for_replace(
        self,
        op: pdl.ReplaceOp,
        rewrite_values: dict[SSAValue, SSAValue],
        map_rewrite_value: Callable[[SSAValue], SSAValue],
    ):
        """Lower a pdl.replace, degrading to an erase when there are no
        replacement values for the matched operation's results."""
        replacement_values: tuple[SSAValue, ...] = ()
        if op.repl_operation:
            source_op = op.op_value.owner
            # Unless we statically know the replaced op produces no results,
            # conservatively fetch all results of the replacement op so we
            # don't erase an op whose results might still be used.
            statically_no_results = (
                isinstance(source_op, pdl.OperationOp) and not source_op.type_values
            )
            if not statically_no_results:
                all_results = pdl_interp.GetResultsOp(
                    None,
                    map_rewrite_value(op.repl_operation),
                    pdl.RangeType(pdl.ValueType()),
                )
                self.rewriter_builder.insert(all_results)
                replacement_values = (all_results.value,)
        else:
            replacement_values = tuple(map_rewrite_value(v) for v in op.repl_values)

        target = map_rewrite_value(op.op_value)
        if replacement_values:
            self.rewriter_builder.insert(
                pdl_interp.ReplaceOp(target, replacement_values)
            )
        else:
            # With nothing to replace the results with, replacement reduces
            # to erasing the op; a replacement op (if any) was already
            # inserted by its own create_operation lowering.
            self.rewriter_builder.insert(pdl_interp.EraseOp(target))

    def _generate_rewriter_for_result(
        self,
        op: pdl.ResultOp,
        rewrite_values: dict[SSAValue, SSAValue],
        map_rewrite_value: Callable[[SSAValue], SSAValue],
    ):
        """Lower a `pdl.result` op to a `pdl_interp.get_result`."""
        parent_val = map_rewrite_value(op.parent_)
        getter = pdl_interp.GetResultOp(op.index, parent_val)
        self.rewriter_builder.insert(getter)
        # Record the mapping so later rewrite ops can refer to this result.
        rewrite_values[op.val] = getter.value

    def _generate_rewriter_for_results(
        self,
        op: pdl.ResultsOp,
        rewrite_values: dict[SSAValue, SSAValue],
        map_rewrite_value: Callable[[SSAValue], SSAValue],
    ):
        """Lower a `pdl.results` op to a `pdl_interp.get_results`."""
        parent_val = map_rewrite_value(op.parent_)
        # Preserve the original result type (single value or range).
        getter = pdl_interp.GetResultsOp(op.index, parent_val, op.val.type)
        self.rewriter_builder.insert(getter)
        rewrite_values[op.val] = getter.value

    def _generate_rewriter_for_type(
        self,
        op: pdl.TypeOp,
        rewrite_values: dict[SSAValue, SSAValue],
        map_rewrite_value: Callable[[SSAValue], SSAValue],
    ):
        """Lower a `pdl.type` op; only constant types materialize anything."""
        if not op.constantType:
            return
        const_type = pdl_interp.CreateTypeOp(op.constantType)
        self.rewriter_builder.insert(const_type)
        rewrite_values[op.result] = const_type.result

    def _generate_rewriter_for_types(
        self,
        op: pdl.TypesOp,
        rewrite_values: dict[SSAValue, SSAValue],
        map_rewrite_value: Callable[[SSAValue], SSAValue],
    ):
        """Lower a `pdl.types` op; only constant type lists materialize anything.

        A `pdl.types` op in the rewrite section is not used as a declarative
        constraint, so when it carries no `constantTypes` it is essentially a
        no-op and nothing needs to be created.
        """
        if not op.constantTypes:
            return
        const_types = pdl_interp.CreateTypesOp(op.constantTypes)
        self.rewriter_builder.insert(const_types)
        rewrite_values[op.result] = const_types.result

    def _generate_operation_result_type_rewriter(
        self,
        op: pdl.OperationOp,
        map_rewrite_value: Callable[[SSAValue], SSAValue],
        types_list: list[SSAValue],
        rewrite_values: dict[SSAValue, SSAValue],
    ) -> bool:
        """Populate `types_list` with the result types for a created operation.

        Tries a sequence of strategies: resolving each declared type
        individually, interface-based inference (not yet ported to xDSL),
        and inferring from an operation that this one replaces.
        Returns `has_inferred_result_types` (False for every strategy
        currently implemented); raises `ValueError` if no strategy applies.
        """
        rewriter_block = op.parent
        assert rewriter_block is not None
        result_type_values = op.type_values

        # Strategy 1: Resolve all types individually
        if result_type_values:
            temp_types: list[SSAValue] = []
            can_resolve_all = True
            for result_type in result_type_values:
                if (val := rewrite_values.get(result_type)) is not None:
                    # Already materialized earlier in this rewriter.
                    temp_types.append(val)
                elif result_type.owner.parent is not rewriter_block:
                    # Defined in the match section; map it into the rewriter.
                    temp_types.append(map_rewrite_value(result_type))
                else:
                    # Defined later in the rewriter body; cannot resolve yet.
                    can_resolve_all = False
                    break
            if can_resolve_all:
                types_list.extend(temp_types)
                return False

        # Strategy 2: Check if created op has `inferredResultTypes` interface
        # This interface doesn't exist in xDSL, so we don't do this yet.
        # https://github.com/xdslproject/xdsl/issues/5455

        # Strategy 3: Infer from a replaced operation
        for use in op.op.uses:
            user_op = use.operation
            # Only consider `pdl.replace` users where this op is the
            # *replacement* (operand 0 of `pdl.replace` is the replaced op).
            if not isinstance(user_op, pdl.ReplaceOp) or use.index == 0:
                continue

            replaced_op_val = user_op.op_value
            replaced_op_def = replaced_op_val.owner
            assert isinstance(replaced_op_def, Operation)
            # The replaced op must be available before this op in the
            # rewriter; otherwise its results cannot be queried here.
            if (
                replaced_op_def.parent is rewriter_block
                and not replaced_op_def.is_before_in_block(op)
            ):
                continue

            # Take the replaced op's full result range and use its types.
            mapped_replaced_op = map_rewrite_value(replaced_op_val)
            get_results = pdl_interp.GetResultsOp(
                None, mapped_replaced_op, pdl.RangeType(pdl.ValueType())
            )
            self.rewriter_builder.insert(get_results)
            get_type = pdl_interp.GetValueTypeOp(get_results.value)
            self.rewriter_builder.insert(get_type)
            types_list.append(get_type.result)
            return False

        # Strategy 4: If no explicit types, assume no results
        if not result_type_values:
            return False

        raise ValueError(f"Unable to infer result types for pdl.operation {op.opName}")

matcher_func: pdl_interp.FuncOp = matcher_func instance-attribute

rewriter_module: ModuleOp = rewriter_module instance-attribute

rewriter_builder: Builder = Builder(InsertPoint.at_end(rewriter_module.body.block)) instance-attribute

value_to_position: dict[pdl.PatternOp, dict[SSAValue, Position]] = {} instance-attribute

values: ScopedDict[Position, SSAValue] = ScopedDict() instance-attribute

failure_block_stack: list[Block] = [] instance-attribute

builder: Builder = Builder(InsertPoint.at_start(matcher_func.body.block)) instance-attribute

constraint_op_map: dict[ConstraintQuestion, pdl_interp.ApplyConstraintOp] = {} instance-attribute

rewriter_names: dict[str, int] = {} instance-attribute

__init__(matcher_func: pdl_interp.FuncOp, rewriter_module: ModuleOp, optimize_for_eqsat: bool = False) -> None

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
def __init__(
    self,
    matcher_func: pdl_interp.FuncOp,
    rewriter_module: ModuleOp,
    optimize_for_eqsat: bool = False,
) -> None:
    """Initialize the matcher generator.

    Args:
        matcher_func: The `pdl_interp.func` the matcher is emitted into.
        rewriter_module: The module receiving generated rewriter functions.
        optimize_for_eqsat: Whether to enable equality-saturation-specific
            optimizations.
    """
    self.matcher_func = matcher_func
    self.rewriter_module = rewriter_module
    # Fix: this flag was previously accepted but silently discarded.
    self.optimize_for_eqsat = optimize_for_eqsat
    # Builder for the rewriter functions, appending to the rewriter module.
    self.rewriter_builder = Builder(InsertPoint.at_end(rewriter_module.body.block))
    self.value_to_position = {}
    self.values = ScopedDict()
    self.failure_block_stack = []
    # Builder for the matcher body itself.
    self.builder = Builder(InsertPoint.at_start(matcher_func.body.block))
    self.constraint_op_map = {}
    self.rewriter_names = {}

lower(patterns: list[pdl.PatternOp]) -> None

Lower PDL patterns to PDL interpreter

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
def lower(self, patterns: list[pdl.PatternOp]) -> None:
    """Lower PDL patterns to PDL interpreter"""

    # Derive the predicate tree for all patterns up front.
    builder = PredicateTreeBuilder()
    tree_root = builder.build_predicate_tree(patterns)
    self.value_to_position = builder.pattern_value_positions

    # The matcher function's entry block receives the root operation as its
    # single argument; seed the position-to-value map with it.
    entry = self.matcher_func.body.block
    self.values[OperationPosition(depth=0)] = entry.args[0]

    # Emit interpreter operations for the whole tree into the matcher body.
    self.generate_matcher(tree_root, self.matcher_func.body, block=entry)

generate_matcher(node: MatcherNode, region: Region, block: Block | None = None) -> Block

Generate PDL interpreter operations for a matcher node

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
def generate_matcher(
    self, node: MatcherNode, region: Region, block: Block | None = None
) -> Block:
    """Generate PDL interpreter operations for a matcher node.

    Emits `node` (and, recursively, its failure and success subtrees) into
    `block`, creating a fresh block in `region` when none is given.
    Returns the block the node was emitted into.
    """

    # Create block if needed
    if block is None:
        block = Block()
        region.add_block(block)

    # Set insertion point to end of this block
    self.builder.insertion_point = InsertPoint.at_end(block)

    # Handle exit node - just add finalize
    if isinstance(node, ExitNode):
        self.builder.insert(pdl_interp.FinalizeOp())
        return block

    # Push a fresh value scope so values realized for this subtree do not
    # leak into sibling subtrees.
    self.values = ScopedDict(self.values)
    assert self.values.parent is not None

    # Handle failure node
    failure_block = None
    if node.failure_node:
        # Emit the failure subtree first so the checks below have a
        # destination block to branch to on failure.
        failure_block = self.generate_matcher(node.failure_node, region)
        self.failure_block_stack.append(failure_block)
        # Restore insertion point after generating failure node
        self.builder.insertion_point = InsertPoint.at_end(block)
    else:
        # Reuse the innermost enclosing failure destination.
        assert self.failure_block_stack, "Expected valid failure block"
        failure_block = self.failure_block_stack[-1]

    # Get value for position if exists (may change insertion point)
    val = None
    if node.position:
        val = self.get_value_at(node.position)

    # Dispatch based on node type
    match node:
        case BoolNode():
            assert val is not None
            self.generate_bool_node(node, val)
        case SwitchNode():
            assert val is not None
            self.generate_switch_node(node, val)
        case SuccessNode():
            self.generate_success_node(node)
        case _:
            raise NotImplementedError(f"Unhandled node type {type(node)}")

    # Pop failure block if we pushed one
    if node.failure_node:
        self.failure_block_stack.pop()

    self.values = self.values.parent  # Pop scope
    return block

get_value_at(position: Position) -> SSAValue

Get or create SSA value for a position.

Assumes self.builder.insertion_point is correctly set. May modify the insertion point (e.g., when creating foreach loops).

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
def get_value_at(self, position: Position) -> SSAValue:
    """Get or create SSA value for a position.

    Recursively realizes the position's parent first, then emits the
    `pdl_interp` op that extracts this position's value from the parent.
    Results are memoized in `self.values`, so each position's ops are
    emitted at most once per scope.

    Assumes self.builder.insertion_point is correctly set.
    May modify the insertion point (e.g., when creating foreach loops).
    """

    # Check cache
    if position in self.values:
        return self.values[position]

    # Get parent value if needed (may change insertion point)
    parent_val = None
    if position.parent:
        parent_val = self.get_value_at(position.parent)

    # Create value based on position type
    value: SSAValue | None = None

    if isinstance(position, OperationPosition):
        if position.is_operand_defining_op():
            assert parent_val is not None
            # Get defining operation of operand
            defining_op = pdl_interp.GetDefiningOpOp(parent_val)
            # Debug aid: record which position this op was emitted for.
            defining_op.attributes["position"] = StringAttr(position.__repr__())
            self.builder.insert(defining_op)
            value = defining_op.input_op
        else:
            # Passthrough
            value = parent_val

    elif isinstance(position, OperandPosition):
        assert parent_val is not None
        get_operand_op = pdl_interp.GetOperandOp(
            position.operand_number, parent_val
        )
        self.builder.insert(get_operand_op)
        value = get_operand_op.value

    elif isinstance(position, OperandGroupPosition):
        assert parent_val is not None
        # Get operands (possibly variadic)
        result_type = (
            pdl.RangeType(pdl.ValueType())
            if position.is_variadic
            else pdl.ValueType()
        )
        get_operands_op = pdl_interp.GetOperandsOp(
            position.group_number, parent_val, result_type
        )
        self.builder.insert(get_operands_op)
        value = get_operands_op.value

    elif isinstance(position, ResultPosition):
        assert parent_val is not None
        get_result_op = pdl_interp.GetResultOp(position.result_number, parent_val)
        self.builder.insert(get_result_op)
        value = get_result_op.value

    elif isinstance(position, ResultGroupPosition):
        assert parent_val is not None
        # Get results (possibly variadic)
        result_type = (
            pdl.RangeType(pdl.ValueType())
            if position.is_variadic
            else pdl.ValueType()
        )
        get_results_op = pdl_interp.GetResultsOp(
            position.group_number, parent_val, result_type
        )
        self.builder.insert(get_results_op)
        value = get_results_op.value

    elif isinstance(position, AttributePosition):
        assert parent_val is not None
        get_attr_op = pdl_interp.GetAttributeOp(position.attribute_name, parent_val)
        self.builder.insert(get_attr_op)
        value = get_attr_op.value

    elif isinstance(position, AttributeLiteralPosition):
        # Create a constant attribute
        create_attr_op = pdl_interp.CreateAttributeOp(position.value)
        self.builder.insert(create_attr_op)
        value = create_attr_op.attribute

    elif isinstance(position, TypePosition):
        assert parent_val is not None
        # Get type of value or attribute
        if parent_val.type == pdl.AttributeType():
            get_type_op = pdl_interp.GetAttributeTypeOp(parent_val)
        else:
            get_type_op = pdl_interp.GetValueTypeOp(parent_val)
        self.builder.insert(get_type_op)
        value = get_type_op.result

    elif isinstance(position, TypeLiteralPosition):
        # Create a constant type or types
        raw_type_attr = position.value
        if isinstance(raw_type_attr, TypeAttribute):
            create_type_op = pdl_interp.CreateTypeOp(raw_type_attr)
            self.builder.insert(create_type_op)
            value = create_type_op.result
        else:
            # Assume it's an ArrayAttr of types
            assert isinstance(raw_type_attr, ArrayAttr)
            type_attr = cast(ArrayAttr[TypeAttribute], raw_type_attr)
            create_types_op = pdl_interp.CreateTypesOp(type_attr)
            self.builder.insert(create_types_op)
            value = create_types_op.result

    elif isinstance(position, ConstraintPosition):
        # The constraint op has already been created, find it in the map
        constraint_op = self.constraint_op_map.get(position.constraint)
        assert constraint_op is not None
        value = constraint_op.results[position.result_index]

    elif isinstance(position, UsersPosition):
        raise NotImplementedError("UsersPosition not implemented in lowering")
    elif isinstance(position, ForEachPosition):
        raise NotImplementedError("ForEachPosition not implemented in lowering")
    else:
        raise NotImplementedError(f"Unhandled position type {type(position)}")

    # Cache and return
    assert value is not None
    self.values[position] = value
    return value

generate_bool_node(node: BoolNode, val: SSAValue) -> None

Generate operations for a boolean predicate node.

Assumes self.builder.insertion_point is correctly set.

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
def generate_bool_node(self, node: BoolNode, val: SSAValue) -> None:
    """Generate operations for a boolean predicate node.

    Assumes self.builder.insertion_point is correctly set.
    """

    question = node.question
    answer = node.answer
    block = self.builder.insertion_point.block
    region = block.parent
    assert region is not None, "Block must be in a region"

    # Handle getValue queries first for constraint questions (may change insertion point)
    args: list[SSAValue] = []
    if isinstance(question, EqualToQuestion):
        args = [self.get_value_at(question.other_position)]
    elif isinstance(question, ConstraintQuestion):
        for position in question.arg_positions:
            args.append(self.get_value_at(position))

    # Get the current block after potentially changed insertion point
    block = self.builder.insertion_point.block
    region = block.parent
    assert region is not None, "Block must be in a region"

    # Create success block
    success_block = Block()
    region.add_block(success_block)
    failure_block = self.failure_block_stack[-1]

    # Generate predicate check operation based on question type
    match question:
        case IsNotNullQuestion():
            check_op = pdl_interp.IsNotNullOp(val, success_block, failure_block)
        case OperationNameQuestion():
            assert isinstance(answer, StringAnswer)
            check_op = pdl_interp.CheckOperationNameOp(
                answer.value, val, success_block, failure_block
            )
        case OperandCountQuestion() | OperandCountAtLeastQuestion():
            assert isinstance(answer, UnsignedAnswer)
            compare_at_least = isinstance(question, OperandCountAtLeastQuestion)
            check_op = pdl_interp.CheckOperandCountOp(
                val, answer.value, success_block, failure_block, compare_at_least
            )
        case ResultCountQuestion() | ResultCountAtLeastQuestion():
            assert isinstance(answer, UnsignedAnswer)
            compare_at_least = isinstance(question, ResultCountAtLeastQuestion)
            check_op = pdl_interp.CheckResultCountOp(
                val, answer.value, success_block, failure_block, compare_at_least
            )
        case EqualToQuestion():
            # Get the other value to compare with
            other_val = self.get_value_at(question.other_position)
            # Update block reference after potential insertion point change
            block = self.builder.insertion_point.block
            assert isinstance(answer, TrueAnswer)
            check_op = pdl_interp.AreEqualOp(
                val, other_val, success_block, failure_block
            )
        case AttributeConstraintQuestion():
            assert isinstance(answer, AttributeAnswer)
            check_op = pdl_interp.CheckAttributeOp(
                answer.value, val, success_block, failure_block
            )
        case TypeConstraintQuestion():
            assert isinstance(answer, TypeAnswer)
            if isinstance(val.type, pdl.RangeType):
                # Check multiple types
                assert isinstance(answer.value, ArrayAttr)
                check_op = pdl_interp.CheckTypesOp(
                    answer.value, val, success_block, failure_block
                )
            else:
                # Check single type
                assert isinstance(answer.value, TypeAttribute)
                check_op = pdl_interp.CheckTypeOp(
                    answer.value, val, success_block, failure_block
                )
        case ConstraintQuestion():
            check_op = pdl_interp.ApplyConstraintOp(
                question.name,
                args,
                success_block,
                failure_block,
                is_negated=question.is_negated,
                res_types=question.result_types,
            )
            # Store the constraint op for later result access
            self.constraint_op_map[question] = check_op
        case _:
            raise NotImplementedError(f"Unhandled question type {type(question)}")

    self.builder.insert(check_op)

    # Generate matcher for success node
    if node.success_node:
        self.generate_matcher(node.success_node, region, success_block)

generate_switch_node(node: SwitchNode, val: SSAValue) -> None

Generate operations for a switch node.

Assumes self.builder.insertion_point is correctly set.

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
def generate_switch_node(self, node: SwitchNode, val: SSAValue) -> None:
    """Generate operations for a switch node.

    Emits either a chain of at-least comparisons (for at-least count
    questions) or a single `pdl_interp` switch op dispatching over the
    node's children.

    Assumes self.builder.insertion_point is correctly set.
    """

    question = node.question
    block = self.builder.insertion_point.block
    region = block.parent
    assert region is not None, "Block must be in a region"
    default_dest = self.failure_block_stack[-1]

    # Handle at-least questions specially
    if isinstance(
        question, OperandCountAtLeastQuestion | ResultCountAtLeastQuestion
    ):
        # At-least checks cannot be expressed as a switch; instead build a
        # cascade of checks from the largest count down, where failing a
        # check falls through to the next-smaller one.
        # Sort children in reverse numerical order
        sorted_children = sorted(
            node.children.items(),
            key=lambda x: cast(UnsignedAnswer, x[0]).value,
            reverse=True,
        )

        # Push temporary entry to failure block stack
        self.failure_block_stack.append(default_dest)

        for answer, child_node in sorted_children:
            if child_node:
                success_block = self.generate_matcher(child_node, region)
                current_check_block = Block()
                region.insert_block_before(current_check_block, success_block)
                self.builder.insertion_point = InsertPoint.at_end(
                    current_check_block
                )
                assert isinstance(answer, UnsignedAnswer)
                if isinstance(question, OperandCountAtLeastQuestion):
                    check_op = pdl_interp.CheckOperandCountOp(
                        val,
                        answer.value,
                        success_block,
                        default_dest,
                        True,
                    )
                else:
                    check_op = pdl_interp.CheckResultCountOp(
                        val,
                        answer.value,
                        success_block,
                        default_dest,
                        True,
                    )
                self.builder.insert(check_op)

                # Update failure block stack for next child matcher
                self.failure_block_stack[-1] = current_check_block

        # Pop the temporary entry from failure block stack
        first_predicate_block = self.failure_block_stack.pop()

        # Move ops from the first check block into the main block, so the
        # cascade starts inline rather than via an extra branch.
        for op in list(first_predicate_block.ops):
            op.detach()
            block.add_op(op)
        assert first_predicate_block.parent is not None
        first_predicate_block.parent.detach_block(first_predicate_block)
        first_predicate_block.erase()

        return

    # Generate child blocks and collect case values
    case_blocks: list[Block] = []
    case_values: list[Answer] = []

    for answer, child_node in node.children.items():
        if child_node:
            child_block = self.generate_matcher(child_node, region)
            case_blocks.append(child_block)
            case_values.append(answer)

    # Restore insertion point after generating child matchers
    self.builder.insertion_point = InsertPoint.at_end(block)

    # Create switch operation based on question type
    match question:
        case OperationNameQuestion():
            # Extract string values from StringAnswer objects
            switch_values = [cast(StringAnswer, ans).value for ans in case_values]
            switch_attr = ArrayAttr([StringAttr(v) for v in switch_values])
            switch_op = pdl_interp.SwitchOperationNameOp(
                switch_attr, val, default_dest, case_blocks
            )
        case OperandCountQuestion():
            # Extract integer values from UnsignedAnswer objects
            switch_values = [cast(UnsignedAnswer, ans).value for ans in case_values]
            switch_op = pdl_interp.SwitchOperandCountOp(
                switch_values, val, default_dest, case_blocks
            )
        case ResultCountQuestion():
            # Extract integer values from UnsignedAnswer objects
            switch_values = [cast(UnsignedAnswer, ans).value for ans in case_values]
            switch_op = pdl_interp.SwitchResultCountOp(
                switch_values, val, default_dest, case_blocks
            )
        case TypeConstraintQuestion():
            # Extract type attributes from TypeAnswer objects
            switch_values = [cast(TypeAnswer, ans).value for ans in case_values]
            if isinstance(val.type, pdl.RangeType):
                assert isa(switch_values, list[ArrayAttr[TypeAttribute]])
                switch_attr = ArrayAttr(switch_values)

                switch_op = pdl_interp.SwitchTypesOp(
                    switch_attr, val, default_dest, case_blocks
                )
            else:
                assert isa(switch_values, list[TypeAttribute])
                switch_attr = ArrayAttr(switch_values)
                switch_op = pdl_interp.SwitchTypeOp(
                    switch_attr, val, default_dest, case_blocks
                )
        case AttributeConstraintQuestion():
            # Extract attribute values from AttributeAnswer objects
            switch_values = [
                cast(AttributeAnswer, ans).value for ans in case_values
            ]
            switch_attr = ArrayAttr(switch_values)
            switch_op = pdl_interp.SwitchAttributeOp(
                val, switch_attr, default_dest, case_blocks
            )
        case _:
            raise NotImplementedError(f"Unhandled question type {type(question)}")

    self.builder.insert(switch_op)

generate_success_node(node: SuccessNode) -> None

Generate operations for a successful match.

Assumes self.builder.insertion_point is correctly set.

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
def generate_success_node(self, node: SuccessNode) -> None:
    """Generate operations for a successful match.

    Creates the rewriter function for the pattern and emits a
    `pdl_interp.record_match` referencing it.

    Assumes self.builder.insertion_point is correctly set.
    """

    pattern = node.pattern
    root = node.root

    # Generate a rewriter for the pattern
    used_match_positions: list[Position] = []
    rewriter_func_ref = self.generate_rewriter(pattern, used_match_positions)

    # Process values used in the rewrite that are defined in the match
    # (may change insertion point)
    mapped_match_values = [self.get_value_at(pos) for pos in used_match_positions]

    # Collect generated op names from DAG rewriter
    rewriter_op = pattern.body.block.last_op
    assert isinstance(rewriter_op, pdl.RewriteOp)
    # NOTE(review): `rewriter_op.name` is presumably meant to test for an
    # external rewriter symbol, but on xDSL operations `.name` typically
    # resolves to the operation name ("pdl.rewrite"), which is always
    # truthy — confirm this reads the intended attribute (often `name_`).
    if not rewriter_op.name:
        generated_op_names = ArrayAttr(
            [
                op.opName
                for op in rewriter_op.body.walk()
                if isinstance(op, pdl.OperationOp) and op.opName
            ]
        )
    else:
        generated_op_names = None
    # Get root kind if present
    root_kind: StringAttr | None = None
    if root:
        defining_op = root.owner
        if isinstance(defining_op, pdl.OperationOp) and defining_op.opName:
            root_kind = StringAttr(defining_op.opName.data)

    # Create the RecordMatchOp
    record_op = pdl_interp.RecordMatchOp(
        rewriter_func_ref,
        root_kind,
        generated_op_names,
        pattern.benefit,
        mapped_match_values,
        [],
        self.failure_block_stack[-1],
    )
    self.builder.insert(record_op)

generate_rewriter(pattern: pdl.PatternOp, used_match_positions: list[Position]) -> SymbolRefAttr

Generate a rewriter function for the given pattern, and return a reference to that function.

Source code in xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
def generate_rewriter(
    self, pattern: pdl.PatternOp, used_match_positions: list[Position]
) -> SymbolRefAttr:
    """
    Generate a rewriter function for the given pattern, and return a
    reference to that function.

    The generated ``pdl_interp.func`` is appended to
    ``self.rewriter_module`` (the module holding the ``rewriters``
    symbol table), and the returned reference is the nested symbol
    ``@rewriters::@<name>``.

    Note: ``used_match_positions`` is extended *in place*: every matcher
    position whose value the rewriter needs as an input is appended to
    it (inside ``map_rewrite_value``), so the caller can thread those
    values through the eventual ``pdl_interp.record_match``.
    """
    # The pdl.rewrite op is the terminator of the pattern body, hence
    # always the block's last op.
    rewriter_op = pattern.body.block.last_op
    assert isinstance(rewriter_op, pdl.RewriteOp)

    # Pick a symbol name for the rewriter function: the pattern's own
    # name if it has one, otherwise a generic placeholder.
    if pattern.sym_name:
        rewriter_name = pattern.sym_name.data
    else:
        rewriter_name = "pdl_generated_rewriter"
    if rewriter_name in self.rewriter_names:
        # duplicate names get a numeric suffix starting from 0 (foo, foo_0, foo_1, ...)
        # (the counter starts at 1 for the first occurrence, so the
        # suffix for the (k+1)-th duplicate is (k+2) - 2 = k)
        # NOTE(review): the suffixed name itself is never recorded in
        # self.rewriter_names, so a pattern explicitly named e.g.
        # "foo_0" could still collide — TODO confirm this is acceptable.
        self.rewriter_names[rewriter_name] += 1
        rewriter_name = f"{rewriter_name}_{self.rewriter_names[rewriter_name] - 2}"
    else:
        self.rewriter_names[rewriter_name] = 1

    # Create the rewriter function with an empty signature; block
    # arguments are added lazily by map_rewrite_value below, and the
    # function type is patched up at the end to match.
    rewriter_func = pdl_interp.FuncOp(rewriter_name, ([], []))

    self.rewriter_module.body.block.add_op(rewriter_func)
    entry_block = rewriter_func.body.block
    self.rewriter_builder.insertion_point = InsertPoint.at_end(entry_block)

    # Maps SSA values of the pdl pattern to their counterparts inside
    # the generated rewriter function.
    rewrite_values: dict[SSAValue, SSAValue] = {}
    pattern_value_positions = self.value_to_position[pattern]

    def map_rewrite_value(old_value: SSAValue) -> SSAValue:
        # Resolve a value from the pattern to a value usable inside the
        # rewriter body, memoized via rewrite_values.
        if new_value := rewrite_values.get(old_value):
            return new_value

        # Prefer materializing constants directly when possible.
        old_op = old_value.owner
        new_val_op: Operation | None = None
        if isinstance(old_op, pdl.AttributeOp) and old_op.value:
            new_val_op = pdl_interp.CreateAttributeOp(old_op.value)
        elif isinstance(old_op, pdl.TypeOp) and old_op.constantType:
            new_val_op = pdl_interp.CreateTypeOp(old_op.constantType)
        elif isinstance(old_op, pdl.TypesOp) and old_op.constantTypes:
            new_val_op = pdl_interp.CreateTypesOp(old_op.constantTypes)

        if new_val_op is not None:
            self.rewriter_builder.insert(new_val_op)
            new_value = new_val_op.results[0]
            rewrite_values[old_value] = new_value
            return new_value

        # Otherwise, it's an input from the matcher.
        # Record its matcher position (side effect on the caller's list)
        # and expose it as a new block argument of the rewriter function.
        input_pos = pattern_value_positions.get(old_value)
        assert input_pos is not None, "Expected value to be a pattern input"
        if input_pos not in used_match_positions:
            used_match_positions.append(input_pos)

        arg = entry_block.insert_arg(old_value.type, len(entry_block.args))
        rewrite_values[old_value] = arg
        return arg

    # If this is a custom rewriter, dispatch to the registered method.
    # (name_ is the optional external rewriter name attribute of
    # pdl.rewrite; its presence means there is no PDL rewrite body.)
    if rewriter_op.name_:
        args: list[SSAValue] = []
        if rewriter_op.root:
            args.append(map_rewrite_value(rewriter_op.root))
        args.extend(map_rewrite_value(arg) for arg in rewriter_op.external_args)

        apply_op = pdl_interp.ApplyRewriteOp(rewriter_op.name_.data, args)
        self.rewriter_builder.insert(apply_op)
    else:
        # Otherwise, this is a DAG rewriter defined using PDL operations.
        # Each pdl op in the rewrite body is lowered by its dedicated
        # helper; ops are visited in body order, so producers are
        # lowered before their users.
        assert rewriter_op.body is not None
        for op in rewriter_op.body.ops:
            match op:
                case pdl.ApplyNativeRewriteOp():
                    self._generate_rewriter_for_apply_native_rewrite(
                        op, rewrite_values, map_rewrite_value
                    )
                case pdl.AttributeOp():
                    self._generate_rewriter_for_attribute(
                        op, rewrite_values, map_rewrite_value
                    )
                case pdl.EraseOp():
                    self._generate_rewriter_for_erase(
                        op, rewrite_values, map_rewrite_value
                    )
                case pdl.OperationOp():
                    self._generate_rewriter_for_operation(
                        op, rewrite_values, map_rewrite_value
                    )
                case pdl.RangeOp():
                    self._generate_rewriter_for_range(
                        op, rewrite_values, map_rewrite_value
                    )
                case pdl.ReplaceOp():
                    self._generate_rewriter_for_replace(
                        op, rewrite_values, map_rewrite_value
                    )
                case pdl.ResultOp():
                    self._generate_rewriter_for_result(
                        op, rewrite_values, map_rewrite_value
                    )
                case pdl.ResultsOp():
                    self._generate_rewriter_for_results(
                        op, rewrite_values, map_rewrite_value
                    )
                case pdl.TypeOp():
                    self._generate_rewriter_for_type(
                        op, rewrite_values, map_rewrite_value
                    )
                case pdl.TypesOp():
                    self._generate_rewriter_for_types(
                        op, rewrite_values, map_rewrite_value
                    )
                case _:
                    raise TypeError(f"Unexpected op type: {type(op)}")

    # Update the signature of the rewrite function.
    # Block arguments were added on demand by map_rewrite_value, so the
    # function type is only now known.
    rewriter_func.function_type = FunctionType.from_lists(entry_block.arg_types, ())

    # Terminate the rewriter body and hand back a nested symbol
    # reference into the "rewriters" module.
    self.rewriter_builder.insert(pdl_interp.FinalizeOp())
    return SymbolRefAttr(
        "rewriters",
        [
            StringAttr(rewriter_name),
        ],
    )