Skip to content

Pattern rewriter

pattern_rewriter

PatternRewriterListener dataclass

Bases: BuilderListener

A listener for pattern rewriter events.

Source code in xdsl/pattern_rewriter.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
@dataclass(eq=False)
class PatternRewriterListener(BuilderListener):
    """
    Listener for the events emitted while a pattern rewriter mutates the IR.

    On top of the `BuilderListener` callbacks, it keeps one callback list per
    rewrite event (removal, modification, replacement). Firing an event calls
    every callback registered for it, in registration order.
    """

    operation_removal_handler: list[Callable[[Operation], None]] = field(
        default_factory=list[Callable[[Operation], None]], kw_only=True
    )
    """Callbacks invoked with an operation that is about to be removed."""

    operation_modification_handler: list[Callable[[Operation], None]] = field(
        default_factory=list[Callable[[Operation], None]], kw_only=True
    )
    """Callbacks invoked with an operation that has just been modified."""

    operation_replacement_handler: list[
        Callable[[Operation, Sequence[SSAValue | None]], None]
    ] = field(
        default_factory=list[Callable[[Operation, Sequence[SSAValue | None]], None]],
        kw_only=True,
    )
    """Callbacks invoked with an operation about to be replaced and its new values."""

    def handle_operation_removal(self, op: Operation) -> None:
        """Notify every removal callback that `op` is about to be removed."""
        callbacks = self.operation_removal_handler
        for callback in callbacks:
            callback(op)

    def handle_operation_modification(self, op: Operation) -> None:
        """Notify every modification callback that `op` was just modified."""
        callbacks = self.operation_modification_handler
        for callback in callbacks:
            callback(op)

    def handle_operation_replacement(
        self, op: Operation, new_results: Sequence[SSAValue | None]
    ) -> None:
        """
        Notify every replacement callback that `op` is about to be replaced.

        Each callback receives `op` together with the values that will take over
        its results (`None` entries mean "no replacement value").
        """
        callbacks = self.operation_replacement_handler
        for callback in callbacks:
            callback(op, new_results)

    def extend_from_listener(self, listener: BuilderListener | PatternRewriterListener):
        """
        Append every callback registered on `listener` to this listener.

        The base-class callbacks are handled by `BuilderListener`; the
        rewrite-event callbacks are only copied when `listener` is itself a
        `PatternRewriterListener`.
        """
        super().extend_from_listener(listener)
        if not isinstance(listener, PatternRewriterListener):
            return
        self.operation_removal_handler.extend(listener.operation_removal_handler)
        self.operation_modification_handler.extend(
            listener.operation_modification_handler
        )
        self.operation_replacement_handler.extend(
            listener.operation_replacement_handler
        )

operation_removal_handler: list[Callable[[Operation], None]] = field(default_factory=(list[Callable[[Operation], None]]), kw_only=True) class-attribute instance-attribute

Callbacks that are called when an operation is removed.

operation_modification_handler: list[Callable[[Operation], None]] = field(default_factory=(list[Callable[[Operation], None]]), kw_only=True) class-attribute instance-attribute

Callbacks that are called when an operation is modified.

operation_replacement_handler: list[Callable[[Operation, Sequence[SSAValue | None]], None]] = field(default_factory=(list[Callable[[Operation, Sequence[SSAValue | None]], None]]), kw_only=True) class-attribute instance-attribute

Callbacks that are called when an operation is replaced.

__init__(*, operation_insertion_handler: list[Callable[[Operation], None]] = list[Callable[[Operation], None]](), block_creation_handler: list[Callable[[Block], None]] = list[Callable[[Block], None]](), operation_removal_handler: list[Callable[[Operation], None]] = list[Callable[[Operation], None]](), operation_modification_handler: list[Callable[[Operation], None]] = list[Callable[[Operation], None]](), operation_replacement_handler: list[Callable[[Operation, Sequence[SSAValue | None]], None]] = list[Callable[[Operation, Sequence[SSAValue | None]], None]]()) -> None

handle_operation_removal(op: Operation) -> None

Pass the operation that will be removed to the registered callbacks.

Source code in xdsl/pattern_rewriter.py
57
58
59
60
def handle_operation_removal(self, op: Operation) -> None:
    """Notify every removal callback, in order, that `op` is about to be removed."""
    callbacks = self.operation_removal_handler
    for callback in callbacks:
        callback(op)

handle_operation_modification(op: Operation) -> None

Pass the operation that was just modified to the registered callbacks.

Source code in xdsl/pattern_rewriter.py
62
63
64
65
def handle_operation_modification(self, op: Operation) -> None:
    """Notify every modification callback, in order, that `op` was just modified."""
    callbacks = self.operation_modification_handler
    for callback in callbacks:
        callback(op)

handle_operation_replacement(op: Operation, new_results: Sequence[SSAValue | None]) -> None

Pass the operation that will be replaced to the registered callbacks.

Source code in xdsl/pattern_rewriter.py
67
68
69
70
71
72
def handle_operation_replacement(
    self, op: Operation, new_results: Sequence[SSAValue | None]
) -> None:
    """
    Notify every replacement callback, in order, that `op` is about to be replaced.

    Each callback receives `op` and the values that will take over its results.
    """
    callbacks = self.operation_replacement_handler
    for callback in callbacks:
        callback(op, new_results)

extend_from_listener(listener: BuilderListener | PatternRewriterListener)

Forward all callbacks from listener to this listener.

Source code in xdsl/pattern_rewriter.py
74
75
76
77
78
79
80
81
82
83
84
def extend_from_listener(self, listener: BuilderListener | PatternRewriterListener):
    """
    Append every callback registered on `listener` to this listener.

    The base-class callbacks are delegated to `BuilderListener`; the
    rewrite-event callbacks are only copied when `listener` is a
    `PatternRewriterListener`.
    """
    super().extend_from_listener(listener)
    if not isinstance(listener, PatternRewriterListener):
        return
    self.operation_removal_handler.extend(listener.operation_removal_handler)
    self.operation_modification_handler.extend(
        listener.operation_modification_handler
    )
    self.operation_replacement_handler.extend(
        listener.operation_replacement_handler
    )

PatternRewriter dataclass

Bases: Builder, PatternRewriterListener

A rewriter used during pattern matching. Once an operation is matched, this rewriter is used to apply modifications to the operation and its children.

Source code in xdsl/pattern_rewriter.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
@dataclass(eq=False, init=False)
class PatternRewriter(Builder, PatternRewriterListener):
    """
    A rewriter used during pattern matching.
    Once an operation is matched, this rewriter is used to apply
    modifications to the operation and its children.

    Every method that changes the IR sets `has_done_action`, and several of
    them also forward the change to the registered listener callbacks.
    """

    current_operation: Operation
    """The matched operation."""

    has_done_action: bool = field(default=False, init=False)
    """Has the rewriter done any action during the current match."""

    def __init__(self, current_operation: Operation):
        """
        Create a rewriter for `current_operation`.

        The builder's insertion point starts just before `current_operation`.
        """
        PatternRewriterListener.__init__(self)
        self.current_operation = current_operation
        Builder.__init__(self, InsertPoint.before(current_operation))

    def insert_op(
        self,
        op: InsertOpInvT,
        insertion_point: InsertPoint | None = None,
    ) -> InsertOpInvT:
        """Insert operations at a certain location in a block."""
        self.has_done_action = True
        return super().insert_op(op, insertion_point)

    @deprecated(
        "Please use `rewriter.insert_op(op, InsertPoint.before(rewriter.current_operation))` instead"
    )
    def insert_op_before_matched_op(self, op: InsertOpInvT) -> InsertOpInvT:
        """Insert operations before the matched operation."""
        return self.insert_op(op, InsertPoint.before(self.current_operation))

    @deprecated(
        "Please use `rewriter.insert_op(op, InsertPoint.after(rewriter.current_operation))` instead"
    )
    def insert_op_after_matched_op(self, op: InsertOpInvT) -> InsertOpInvT:
        """Insert operations after the matched operation."""
        return self.insert_op(op, InsertPoint.after(self.current_operation))

    @deprecated("Please use `erase_op(op)` instead")
    def erase_matched_op(self, safe_erase: bool = True):
        """
        Erase the matched operation.
        If safe_erase is True, check that the operation has no uses.
        Otherwise, replace its uses with ErasedSSAValue.
        """
        self.erase_op(self.current_operation, safe_erase=safe_erase)

    def erase_op(self, op: Operation, safe_erase: bool = True):
        """
        Erase an operation.
        If safe_erase is True, check that the operation has no uses.
        Otherwise, replace its uses with ErasedSSAValue.
        Removal callbacks are notified before the operation is erased.
        """
        self.has_done_action = True
        self.handle_operation_removal(op)
        Rewriter.erase_op(op, safe_erase=safe_erase)

    def replace_all_uses_with(
        self, from_value: SSAValue, to_value: SSAValue | None, safe_erase: bool = True
    ):
        """
        Replace all uses of `from_value` with `to_value`.

        If `to_value` is None, `from_value` is erased instead, via
        `SSAValue.erase(safe_erase=safe_erase)`.
        Every operation that used `from_value` is reported as modified.
        """
        modified_ops = [use.operation for use in from_value.uses]
        if modified_ops:
            # Rewiring uses mutates the IR; record it like every other mutator
            # so the current match is not reported as having done nothing.
            self.has_done_action = True
        if to_value is None:
            from_value.erase(safe_erase=safe_erase)
        else:
            from_value.replace_all_uses_with(to_value)
        for op in modified_ops:
            self.handle_operation_modification(op)

    def replace_uses_with_if(
        self,
        from_value: SSAValue,
        to_value: SSAValue,
        predicate: Callable[[Use], bool],
    ):
        """
        Replace the uses of `from_value` satisfying `predicate` with `to_value`.

        Each operation whose use was replaced is reported as modified.
        """
        # NOTE(review): `_TrackingPredicate` is defined elsewhere; it presumably
        # records the operations of the uses for which `predicate` returned True.
        tracking = _TrackingPredicate(predicate)
        from_value.replace_uses_with_if(to_value, tracking)

        modified_ops = list(tracking.modified_ops)
        if modified_ops:
            # Only flag an action when at least one use was actually rewired.
            self.has_done_action = True
        for op in modified_ops:
            self.handle_operation_modification(op)

    def replace_matched_op(
        self,
        new_ops: Operation | Sequence[Operation],
        new_results: Sequence[SSAValue | None] | None = None,
        safe_erase: bool = True,
    ):
        """
        Replace the matched operation with new operations.
        Also, optionally specify SSA values to replace the operation results.
        If safe_erase is True, check that the operation has no uses.
        Otherwise, replace its uses with ErasedSSAValue.
        """
        self.replace_op(
            self.current_operation, new_ops, new_results, safe_erase=safe_erase
        )

    def replace_op(
        self,
        op: Operation,
        new_ops: Operation | Sequence[Operation],
        new_results: Sequence[SSAValue | None] | None = None,
        safe_erase: bool = True,
    ):
        """
        Replace an operation with new operations.
        Also, optionally specify SSA values to replace the operation results;
        by default the results of the last new operation are used.
        If safe_erase is True, check that the operation has no uses.
        Otherwise, replace its uses with ErasedSSAValue.

        Raises:
            ValueError: if the number of replacement values differs from the
                number of results of `op`.
        """
        self.has_done_action = True

        if isinstance(new_ops, Operation):
            new_ops = (new_ops,)

        # First, insert the new operations before the matched operation
        self.insert_op(new_ops, InsertPoint.before(op))

        if new_results is None:
            new_results = new_ops[-1].results if new_ops else []

        if len(op.results) != len(new_results):
            raise ValueError(
                f"Expected {len(op.results)} new results, but got {len(new_results)}"
            )

        # Then, replace the results with new ones
        self.handle_operation_replacement(op, new_results)
        for old_result, new_result in zip(op.results, new_results):
            self.replace_all_uses_with(old_result, new_result, safe_erase=safe_erase)

            # Preserve name hints for ops with multiple results
            if new_result is not None and not new_result.name_hint:
                new_result.name_hint = old_result.name_hint

        # Add name hints for existing ops, only if there is a single new result
        if (
            len(new_results) == 1
            and (only_result := new_results[0]) is not None
            and (name_hint := only_result.name_hint) is not None
        ):
            for new_op in new_ops:
                for res in new_op.results:
                    if not res.name_hint:
                        res.name_hint = name_hint

        # Then, erase the original operation
        self.erase_op(op, safe_erase=safe_erase)

    def replace_value_with_new_type(
        self, val: SSAValue, new_type: Attribute
    ) -> SSAValue:
        """
        Replace a value with a value of a new type, and return the new value.
        This will insert the new value in the operation or block, and remove the existing
        value.
        The defining operation (or the parent operation of the owning block) is
        reported as modified.
        """
        self.has_done_action = True
        if isinstance(val, OpResult):
            self.handle_operation_modification(val.op)
        if isinstance(val, BlockArgument):
            if (op := val.block.parent_op()) is not None:
                self.handle_operation_modification(op)
        return Rewriter.replace_value_with_new_type(val, new_type)

    def insert_block_argument(
        self, block: Block, index: int, arg_type: Attribute
    ) -> BlockArgument:
        """Insert a new block argument of type `arg_type` at position `index`."""
        self.has_done_action = True
        return block.insert_arg(arg_type, index)

    def erase_block_argument(self, arg: BlockArgument, safe_erase: bool = True) -> None:
        """
        Erase a block argument.
        If safe_erase is True, then raise an exception if the block argument still
        has uses, otherwise, replace them with an ErasedSSAValue.
        """
        self.has_done_action = True
        self.replace_all_uses_with(arg, None, safe_erase=safe_erase)
        arg.block.erase_arg(arg, safe_erase)

    def inline_block(
        self,
        block: Block,
        insertion_point: InsertPoint,
        arg_values: Sequence[SSAValue] = (),
    ):
        """
        Move the block operations to the specified insertion point.
        """
        self.has_done_action = True
        Rewriter.inline_block(block, insertion_point, arg_values=arg_values)

    @deprecated("Please use `inline_block(block, InsertPoint.before(op))`")
    def inline_block_before_matched_op(
        self, block: Block, arg_values: Sequence[SSAValue] = ()
    ):
        """
        Move the block operations before the matched operation.
        The block should not be a parent of the operation.
        """
        self.inline_block(
            block, InsertPoint.before(self.current_operation), arg_values=arg_values
        )

    @deprecated("Please use `inline_block(block, InsertPoint.after(op))`")
    def inline_block_after_matched_op(
        self, block: Block, arg_values: Sequence[SSAValue] = ()
    ):
        """
        Move the block operations after the matched operation.
        The block should not be a parent of the operation.
        """
        self.inline_block(
            block, InsertPoint.after(self.current_operation), arg_values=arg_values
        )

    def move_region_contents_to_new_regions(self, region: Region) -> Region:
        """Move the region blocks to a new region."""
        self.has_done_action = True
        return Rewriter.move_region_contents_to_new_regions(region)

    def inline_region(self, region: Region, insertion_point: BlockInsertPoint) -> None:
        """Move the region blocks to the specified insertion point."""
        self.has_done_action = True
        Rewriter.inline_region(region, insertion_point)

    def notify_op_modified(self, op: Operation) -> None:
        """
        Notify the rewriter that an operation was modified in the pattern.
        This will correctly update the rewriter state.
        """
        self.has_done_action = True
        self.handle_operation_modification(op)

has_done_action: bool = field(default=False, init=False) class-attribute instance-attribute

Has the rewriter done any action during the current match.

current_operation: Operation = current_operation instance-attribute

The matched operation.

__init__(current_operation: Operation)

Source code in xdsl/pattern_rewriter.py
125
126
127
128
def __init__(self, current_operation: Operation):
    """
    Create a rewriter for `current_operation`.

    The listener state is initialised first; the builder's insertion point is
    then placed just before `current_operation`.
    """
    PatternRewriterListener.__init__(self)
    self.current_operation = current_operation
    start = InsertPoint.before(current_operation)
    Builder.__init__(self, start)

insert_op(op: InsertOpInvT, insertion_point: InsertPoint | None = None) -> InsertOpInvT

Insert operations at a certain location in a block.

Source code in xdsl/pattern_rewriter.py
130
131
132
133
134
135
136
137
def insert_op(
    self,
    op: InsertOpInvT,
    insertion_point: InsertPoint | None = None,
) -> InsertOpInvT:
    """
    Insert `op` (one or several operations) and flag the match as having acted.

    The actual insertion is delegated to the base builder's `insert_op`.
    """
    self.has_done_action = True
    inserted = super().insert_op(op, insertion_point)
    return inserted

insert_op_before_matched_op(op: InsertOpInvT) -> InsertOpInvT

Insert operations before the matched operation.

Source code in xdsl/pattern_rewriter.py
139
140
141
142
143
144
@deprecated(
    "Please use `rewriter.insert_op(op, InsertPoint.before(rewriter.current_operation))` instead"
)
def insert_op_before_matched_op(self, op: InsertOpInvT) -> InsertOpInvT:
    """Deprecated: insert `op` immediately before the matched operation."""
    anchor = InsertPoint.before(self.current_operation)
    return self.insert_op(op, anchor)

insert_op_after_matched_op(op: InsertOpInvT) -> InsertOpInvT

Insert operations after the matched operation.

Source code in xdsl/pattern_rewriter.py
146
147
148
149
150
151
@deprecated(
    "Please use `rewriter.insert_op(op, InsertPoint.after(rewriter.current_operation))` instead"
)
def insert_op_after_matched_op(self, op: InsertOpInvT) -> InsertOpInvT:
    """Deprecated: insert `op` immediately after the matched operation."""
    anchor = InsertPoint.after(self.current_operation)
    return self.insert_op(op, anchor)

erase_matched_op(safe_erase: bool = True)

Erase the matched operation. If safe_erase is True, check that the operation has no uses. Otherwise, replace its uses with ErasedSSAValue.

Source code in xdsl/pattern_rewriter.py
153
154
155
156
157
158
159
160
@deprecated("Please use `erase_op(op)` instead")
def erase_matched_op(self, safe_erase: bool = True):
    """
    Deprecated: erase the matched operation.

    With `safe_erase`, the operation must have no remaining uses; without it,
    its uses are replaced with ErasedSSAValue.
    """
    target = self.current_operation
    self.erase_op(target, safe_erase=safe_erase)

erase_op(op: Operation, safe_erase: bool = True)

Erase an operation. If safe_erase is True, check that the operation has no uses. Otherwise, replace its uses with ErasedSSAValue.

Source code in xdsl/pattern_rewriter.py
162
163
164
165
166
167
168
169
170
def erase_op(self, op: Operation, safe_erase: bool = True):
    """
    Erase `op` from the IR.

    Removal callbacks are notified first, then the erasure is delegated to
    `Rewriter.erase_op`. With `safe_erase`, the operation must have no
    remaining uses; without it, its uses are replaced with ErasedSSAValue.
    """
    self.has_done_action = True
    # Listeners must see the operation while it is still attached.
    self.handle_operation_removal(op)
    Rewriter.erase_op(op, safe_erase=safe_erase)

replace_all_uses_with(from_value: SSAValue, to_value: SSAValue | None, safe_erase: bool = True)

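Replace all uses of from_value with to_value, or erase from_value when to_value is None. Every operation that used from_value is reported as modified.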

Source code in xdsl/pattern_rewriter.py
172
173
174
175
176
177
178
179
180
181
182
def replace_all_uses_with(
    self, from_value: SSAValue, to_value: SSAValue | None, safe_erase: bool = True
):
    """
    Replace all uses of `from_value` with `to_value`.

    If `to_value` is None, `from_value` is erased instead, via
    `SSAValue.erase(safe_erase=safe_erase)`.
    Every operation that used `from_value` is reported as modified, and the
    current match is flagged as having done an action when any use existed.
    """
    modified_ops = [use.operation for use in from_value.uses]
    if modified_ops:
        # Rewiring uses mutates the IR; record it like every other mutator
        # so the current match is not reported as having done nothing.
        self.has_done_action = True
    if to_value is None:
        from_value.erase(safe_erase=safe_erase)
    else:
        from_value.replace_all_uses_with(to_value)
    for op in modified_ops:
        self.handle_operation_modification(op)

replace_uses_with_if(from_value: SSAValue, to_value: SSAValue, predicate: Callable[[Use], bool])

Replace the uses of from_value that satisfy predicate with to_value. Each operation whose use was replaced is reported as modified.

Source code in xdsl/pattern_rewriter.py
184
185
186
187
188
189
190
191
192
193
194
195
def replace_uses_with_if(
    self,
    from_value: SSAValue,
    to_value: SSAValue,
    predicate: Callable[[Use], bool],
):
    """
    Replace the uses of `from_value` satisfying `predicate` with `to_value`.

    Each operation whose use was replaced is reported as modified, and the
    current match is flagged as having done an action if any use was replaced.
    """
    # NOTE(review): `_TrackingPredicate` is defined elsewhere; it presumably
    # records the operations of the uses for which `predicate` returned True.
    tracking = _TrackingPredicate(predicate)
    from_value.replace_uses_with_if(to_value, tracking)

    modified_ops = list(tracking.modified_ops)
    if modified_ops:
        # Only flag an action when at least one use was actually rewired.
        self.has_done_action = True
    for op in modified_ops:
        self.handle_operation_modification(op)

replace_matched_op(new_ops: Operation | Sequence[Operation], new_results: Sequence[SSAValue | None] | None = None, safe_erase: bool = True)

Replace the matched operation with new operations. Also, optionally specify SSA values to replace the operation results. If safe_erase is True, check that the operation has no uses. Otherwise, replace its uses with ErasedSSAValue.

Source code in xdsl/pattern_rewriter.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
def replace_matched_op(
    self,
    new_ops: Operation | Sequence[Operation],
    new_results: Sequence[SSAValue | None] | None = None,
    safe_erase: bool = True,
):
    """
    Replace the matched operation; a thin wrapper around `replace_op`.

    `new_results` optionally gives the values that take over the matched
    operation's results. With `safe_erase`, the operation must have no
    remaining uses once replaced; without it, leftover uses become
    ErasedSSAValue.
    """
    target = self.current_operation
    self.replace_op(target, new_ops, new_results, safe_erase=safe_erase)

replace_op(op: Operation, new_ops: Operation | Sequence[Operation], new_results: Sequence[SSAValue | None] | None = None, safe_erase: bool = True)

Replace an operation with new operations. Also, optionally specify SSA values to replace the operation results. If safe_erase is True, check that the operation has no uses. Otherwise, replace its uses with ErasedSSAValue.

Source code in xdsl/pattern_rewriter.py
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
def replace_op(
    self,
    op: Operation,
    new_ops: Operation | Sequence[Operation],
    new_results: Sequence[SSAValue | None] | None = None,
    safe_erase: bool = True,
):
    """
    Replace `op` with `new_ops`, inserted just before it, then erase `op`.

    `new_results` gives the values that take over the uses of `op`'s results,
    one per result; a None entry erases the uses of that result instead. When
    omitted, the results of the last new operation are used (or no values if
    `new_ops` is empty). `safe_erase` is forwarded to the use replacement and
    to the final erasure: if True, check that the operation has no uses;
    otherwise, replace its uses with ErasedSSAValue.

    Raises:
        ValueError: if the number of replacement values differs from the
            number of results of `op`.
    """
    self.has_done_action = True

    inserted = (new_ops,) if isinstance(new_ops, Operation) else new_ops

    # The replacements go in front of the operation they replace.
    self.insert_op(inserted, InsertPoint.before(op))

    if new_results is None:
        new_results = inserted[-1].results if inserted else []

    expected_count = len(op.results)
    given_count = len(new_results)
    if expected_count != given_count:
        raise ValueError(
            f"Expected {expected_count} new results, but got {given_count}"
        )

    self.handle_operation_replacement(op, new_results)
    for old_result, new_result in zip(op.results, new_results):
        self.replace_all_uses_with(old_result, new_result, safe_erase=safe_erase)
        # Carry the old name over to a replacement value that has none.
        if new_result is not None and not new_result.name_hint:
            new_result.name_hint = old_result.name_hint

    # With exactly one replacement value, propagate its name hint to every
    # unnamed result of the newly inserted operations.
    if len(new_results) == 1:
        sole_value = new_results[0]
        hint = None if sole_value is None else sole_value.name_hint
        if hint is not None:
            for inserted_op in inserted:
                for result in inserted_op.results:
                    if not result.name_hint:
                        result.name_hint = hint

    self.erase_op(op, safe_erase=safe_erase)

replace_value_with_new_type(val: SSAValue, new_type: Attribute) -> SSAValue

Replace a value with a value of a new type, and return the new value. This will insert the new value in the operation or block, and remove the existing value.

Source code in xdsl/pattern_rewriter.py
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
def replace_value_with_new_type(
    self, val: SSAValue, new_type: Attribute
) -> SSAValue:
    """
    Swap `val` for a value of type `new_type` and return the new value.

    The defining operation of an OpResult, or the parent operation of a block
    argument's block (when there is one), is reported as modified before the
    replacement is delegated to `Rewriter.replace_value_with_new_type`.
    """
    self.has_done_action = True
    if isinstance(val, OpResult):
        self.handle_operation_modification(val.op)
    if isinstance(val, BlockArgument):
        owner = val.block.parent_op()
        if owner is not None:
            self.handle_operation_modification(owner)
    return Rewriter.replace_value_with_new_type(val, new_type)

insert_block_argument(block: Block, index: int, arg_type: Attribute) -> BlockArgument

Insert a new block argument.

Source code in xdsl/pattern_rewriter.py
281
282
283
284
285
286
def insert_block_argument(
    self, block: Block, index: int, arg_type: Attribute
) -> BlockArgument:
    """Add an argument of type `arg_type` to `block` at position `index` and return it."""
    self.has_done_action = True
    new_arg = block.insert_arg(arg_type, index)
    return new_arg

erase_block_argument(arg: BlockArgument, safe_erase: bool = True) -> None

Erase a block argument. If safe_erase is True, raise an exception if the block argument still has uses; otherwise, replace those uses with an ErasedSSAValue.

Source code in xdsl/pattern_rewriter.py
288
289
290
291
292
293
294
295
296
def erase_block_argument(self, arg: BlockArgument, safe_erase: bool = True) -> None:
    """
    Remove `arg` from its block.

    Its uses are first dropped through `replace_all_uses_with(arg, None)`:
    with `safe_erase`, remaining uses are an error; otherwise they become
    ErasedSSAValue. The argument is then erased from `arg.block`.
    """
    self.has_done_action = True
    owning_block = arg.block
    self.replace_all_uses_with(arg, None, safe_erase=safe_erase)
    owning_block.erase_arg(arg, safe_erase)

inline_block(block: Block, insertion_point: InsertPoint, arg_values: Sequence[SSAValue] = ())

Move the block operations to the specified insertion point.

Source code in xdsl/pattern_rewriter.py
298
299
300
301
302
303
304
305
306
307
308
def inline_block(
    self,
    block: Block,
    insertion_point: InsertPoint,
    arg_values: Sequence[SSAValue] = (),
):
    """
    Move the operations of `block` to `insertion_point`.

    `arg_values` is forwarded to `Rewriter.inline_block` alongside the block.
    """
    self.has_done_action = True
    Rewriter.inline_block(block, insertion_point, arg_values=arg_values)

inline_block_before_matched_op(block: Block, arg_values: Sequence[SSAValue] = ())

Move the block operations before the matched operation. The block should not be a parent of the operation.

Source code in xdsl/pattern_rewriter.py
310
311
312
313
314
315
316
317
318
319
320
@deprecated("Please use `inline_block(block, InsertPoint.before(op))`")
def inline_block_before_matched_op(
    self, block: Block, arg_values: Sequence[SSAValue] = ()
):
    """
    Deprecated: move the operations of `block` in front of the matched operation.

    `block` must not be a parent of the matched operation.
    """
    target = InsertPoint.before(self.current_operation)
    self.inline_block(block, target, arg_values=arg_values)

inline_block_after_matched_op(block: Block, arg_values: Sequence[SSAValue] = ())

Move the block operations after the matched operation. The block should not be a parent of the operation.

Source code in xdsl/pattern_rewriter.py
322
323
324
325
326
327
328
329
330
331
332
@deprecated("Please use `inline_block(block, InsertPoint.after(op))`")
def inline_block_after_matched_op(
    self, block: Block, arg_values: Sequence[SSAValue] = ()
):
    """
    Deprecated: move the operations of `block` right after the matched operation.

    `block` must not be a parent of the matched operation.
    """
    target = InsertPoint.after(self.current_operation)
    self.inline_block(block, target, arg_values=arg_values)

move_region_contents_to_new_regions(region: Region) -> Region

Move the region blocks to a new region.

Source code in xdsl/pattern_rewriter.py
334
335
336
337
def move_region_contents_to_new_regions(self, region: Region) -> Region:
    """Transfer the blocks of `region` into a new region and return that region."""
    self.has_done_action = True
    moved = Rewriter.move_region_contents_to_new_regions(region)
    return moved

inline_region(region: Region, insertion_point: BlockInsertPoint) -> None

Move the region blocks to the specified insertion point.

Source code in xdsl/pattern_rewriter.py
339
340
341
342
def inline_region(self, region: Region, insertion_point: BlockInsertPoint) -> None:
    """Transfer the blocks of `region` to `insertion_point` via `Rewriter.inline_region`."""
    self.has_done_action = True
    Rewriter.inline_region(region, insertion_point)

notify_op_modified(op: Operation) -> None

Notify the rewriter that an operation was modified in the pattern. This will correctly update the rewriter state.

Source code in xdsl/pattern_rewriter.py
344
345
346
347
348
349
350
def notify_op_modified(self, op: Operation) -> None:
    """
    Tell the rewriter that the pattern changed `op` in place.

    Flags the current match as having acted and forwards `op` to the
    modification callbacks.
    """
    self.has_done_action = True
    self.handle_operation_modification(op)

RewritePattern

Bases: ABC

A side-effect free rewrite pattern matching on a DAG.

Source code in xdsl/pattern_rewriter.py
353
354
355
356
357
358
359
360
361
362
363
364
365
366
class RewritePattern(ABC):
    """
    Abstract base class for side-effect free rewrite patterns matching on a DAG.

    Subclasses implement `match_and_rewrite`.
    """

    # The `/` makes `op` and `rewriter` positional-only (PEP 570,
    # https://peps.python.org/pep-0570/); `op_type_rewrite_pattern` relies on it.
    @abstractmethod
    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
        """
        Try to match `op`; on a match, optionally rewrite it through `rewriter`.
        """
        pass

match_and_rewrite(op: Operation, rewriter: PatternRewriter) abstractmethod

Match an operation, and optionally perform a rewrite using the rewriter.

Source code in xdsl/pattern_rewriter.py
361
362
363
364
365
366
@abstractmethod
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
    """
    Try to match `op`; on a match, optionally rewrite it through `rewriter`.
    """
    pass

TypeConversionPattern dataclass

Bases: RewritePattern

Base pattern for type conversion. Subclass it and implement convert_type to define the conversion.

It converts an Operation's result types, dictionary attributes, and block arguments.

One can use @attr_type_rewrite_pattern on this defined method to automatically filter on the Attribute type used.

This base pattern defines two flags:

  • recursive (defaulting to False): recurse over structured attributes to convert their parameters, e.g. a recursive i32 to index conversion will convert vector<i32> to vector<index>.
  • ops (defaulting to any Operation) is a tuple of Operation types on which to apply the defined attribute conversion.
Source code in xdsl/pattern_rewriter.py
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
@dataclass
class TypeConversionPattern(RewritePattern):
    """
    Base pattern for type conversion. Subclass it and implement `convert_type` to
    define the conversion.

    It converts an Operation's result types, attributes, properties, and the
    argument types of the blocks in its regions.

    `@attr_type_rewrite_pattern` can be applied to the `convert_type` implementation
    to automatically filter on the Attribute type used.

    This base pattern defines two flags:

    - `recursive` (defaulting to False): recurse over structured attributes to convert
      parameters.
      e.g. a recursive `i32` to `index` conversion will convert `vector<i32>` to
      `vector<index>`.
    - `ops` (defaulting to any Operation) is a tuple of Operation types on which to apply
      the defined attribute conversion.
    """

    recursive: bool = False
    """
    Recurse over structured attributes to convert parameters.
    Defaults to False.
    """
    ops: tuple[type[Operation], ...] | None = None
    """
    A tuple of Operation types on which to apply the defined attribute conversion.
    Defaults to any operation type.
    """

    @abstractmethod
    def convert_type(self, typ: Attribute, /) -> Attribute | None:
        """
        The method to implement to define a TypeConversionPattern

        This defines how the input Attribute should be converted.
        It allows returning None, meaning "this attribute should not
        be converted".
        """
        raise NotImplementedError()

    @final
    def _convert_or_keep(self, attr: Attribute) -> Attribute:
        """
        Return the (possibly recursive) conversion of `attr`, or `attr` itself
        when no conversion applies.
        """
        converted = self._convert_type_rec(attr)
        # Explicit None check rather than `converted or attr`: an attribute may be
        # falsy (e.g. a container attribute with no elements) and must still be
        # used as the conversion result.
        return attr if converted is None else converted

    @final
    def _convert_type_rec(self, typ: Attribute) -> Attribute | None:
        """
        Convert `typ` with `convert_type`. When `recursive` is set, first convert
        the parameters of parametrized attributes, the elements of array
        attributes and the values of dictionary attributes.

        Returns the converted attribute, or the (possibly parameter-converted)
        input when `convert_type` returns None.
        """
        inp = typ
        if self.recursive:
            if isinstance(typ, ParametrizedAttribute):
                parameters = [self._convert_or_keep(p) for p in typ.parameters]
                inp = type(typ).new(parameters)
            if isa(typ, ArrayAttr[Attribute]):
                elements = tuple(self._convert_or_keep(p) for p in typ)
                inp = type(typ).new(elements)
            if isa(typ, DictionaryAttr):
                entries = {k: self._convert_or_keep(v) for k, v in typ.data.items()}
                inp = type(typ).new(entries)
        converted = self.convert_type(inp)
        return converted if converted is not None else inp

    @final
    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter):
        """
        Pattern application implementation.

        Converts the result types, attributes and properties of `op`. If any of
        them changed, `op` is replaced by a new operation of the same type that
        takes over its operands, successors and regions. Block argument types in
        the regions of `op` are converted through the rewriter.
        """
        if self.ops and not isinstance(op, self.ops):
            return

        new_result_types: list[Attribute] = [
            self._convert_or_keep(result.type) for result in op.results
        ]
        new_attributes: dict[str, Attribute] = {
            name: self._convert_or_keep(attr) for name, attr in op.attributes.items()
        }
        new_properties: dict[str, Attribute] = {
            name: self._convert_or_keep(attr) for name, attr in op.properties.items()
        }
        changed: bool = (
            any(new != old.type for new, old in zip(new_result_types, op.results))
            or any(new_attributes[n] != a for n, a in op.attributes.items())
            or any(new_properties[n] != a for n, a in op.properties.items())
        )

        for region in op.regions:
            for block in region.blocks:
                # Snapshot the arguments: retyping replaces block arguments.
                for arg in tuple(block.args):
                    new_type = self._convert_or_keep(arg.type)
                    if new_type != arg.type:
                        rewriter.replace_value_with_new_type(arg, new_type)

        if not changed:
            return

        # Snapshot the regions: they are detached from `op` while iterating.
        regions = [op.detach_region(r) for r in tuple(op.regions)]
        new_op = type(op).create(
            operands=op.operands,
            result_types=new_result_types,
            properties=new_properties,
            attributes=new_attributes,
            successors=op.successors,
            regions=regions,
        )
        rewriter.replace_op(op, new_op)
        for new, old in zip(new_op.results, op.results):
            new.name_hint = old.name_hint

recursive: bool = False class-attribute instance-attribute

Recurse over structured attributes to convert parameters. Defaults to False.

ops: tuple[type[Operation], ...] | None = None class-attribute instance-attribute

A tuple of Operation types on which to apply the defined attribute conversion. Defaults to any operation type.

__init__(recursive: bool = False, ops: tuple[type[Operation], ...] | None = None) -> None

convert_type(typ: Attribute) -> Attribute | None abstractmethod

The method to implement to define a TypeConversionPattern

This defines how the input Attribute should be converted. It allows returning None, meaning "this attribute should not be converted".

Source code in xdsl/pattern_rewriter.py
455
456
457
458
459
460
461
462
463
464
@abstractmethod
def convert_type(self, typ: Attribute, /) -> Attribute | None:
    """
    Hook that subclasses implement to define a TypeConversionPattern.

    Given an input Attribute, return the Attribute it should be converted
    to, or None to signal that it should be left untouched.
    """
    raise NotImplementedError

match_and_rewrite(op: Operation, rewriter: PatternRewriter)

Pattern application implementation

Source code in xdsl/pattern_rewriter.py
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
@final
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter):
    """
    Pattern application implementation.

    Converts the result types, attributes and properties of `op`. If any of
    them changed, `op` is replaced by a new operation of the same type that
    takes over its operands, successors and regions. Block argument types in
    the regions of `op` are converted through the rewriter.
    """
    if self.ops and not isinstance(op, self.ops):
        return

    def convert(attr: Attribute) -> Attribute:
        # Explicit None check rather than `converted or attr`: an attribute may be
        # falsy (e.g. a container attribute with no elements) and must still be
        # used as the conversion result.
        converted = self._convert_type_rec(attr)
        return attr if converted is None else converted

    new_result_types: list[Attribute] = [convert(r.type) for r in op.results]
    new_attributes: dict[str, Attribute] = {
        name: convert(attr) for name, attr in op.attributes.items()
    }
    new_properties: dict[str, Attribute] = {
        name: convert(attr) for name, attr in op.properties.items()
    }
    changed: bool = (
        any(new != old.type for new, old in zip(new_result_types, op.results))
        or any(new_attributes[n] != a for n, a in op.attributes.items())
        or any(new_properties[n] != a for n, a in op.properties.items())
    )

    for region in op.regions:
        for block in region.blocks:
            # Snapshot the arguments: retyping replaces block arguments.
            for arg in tuple(block.args):
                new_type = convert(arg.type)
                if new_type != arg.type:
                    rewriter.replace_value_with_new_type(arg, new_type)

    if not changed:
        return

    # Snapshot the regions: they are detached from `op` while iterating.
    regions = [op.detach_region(r) for r in tuple(op.regions)]
    new_op = type(op).create(
        operands=op.operands,
        result_types=new_result_types,
        properties=new_properties,
        attributes=new_attributes,
        successors=op.successors,
        regions=regions,
    )
    rewriter.replace_op(op, new_op)
    for new, old in zip(new_op.results, op.results):
        new.name_hint = old.name_hint

GreedyRewritePatternApplier dataclass

Bases: RewritePattern

Apply a list of patterns in order until one of them rewrites the operation. By default, a trivially dead operation is erased before any pattern is tried. If folding_enabled is set, the applier then attempts to fold the operation before trying the patterns.

Source code in xdsl/pattern_rewriter.py
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
@dataclass(eq=False, repr=False)
class GreedyRewritePatternApplier(RewritePattern):
    """
    Try a list of patterns in order and stop at the first one that rewrites
    the operation.
    Unless `dce_enabled` is False, a trivially dead operation is erased before
    anything else is attempted. When `folding_enabled` is set, the operation is
    then folded if possible before any pattern is tried.
    """

    rewrite_patterns: list[RewritePattern]
    """The patterns to try, in order."""

    ctx: Context | None = field(default=None)
    """Context used to materialize constant operations when folding."""

    folding_enabled: bool = field(default=False, kw_only=True)
    """
    Whether the operation folders are invoked.
    Requires `ctx` to be set.
    """

    dce_enabled: bool = field(default=True, kw_only=True)
    """
    Whether trivially dead operations are erased before any rewrite is
    attempted on them.
    """

    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
        from xdsl.transforms.dead_code_elimination import is_trivially_dead

        if self.dce_enabled and is_trivially_dead(op):
            rewriter.erase_op(op)
            return

        # Constant ops are never folded: folding one produces an Attribute that
        # is immediately rematerialized as a fresh constant op and pushed back
        # onto the worklist, which would never terminate.
        can_fold = (
            self.folding_enabled
            and op.has_trait(HasFolder, value_if_unregistered=False)
            and not op.has_trait(ConstantLike, value_if_unregistered=True)
        )
        if can_fold:
            if self.ctx is None:
                raise ValueError("Context is required for folding")
            fold_result = Folder(self.ctx).try_fold(op)
            if fold_result is not None:
                values, ops_to_insert = fold_result
                rewriter.replace_op(op, new_ops=ops_to_insert, new_results=values)
                return

        for candidate in self.rewrite_patterns:
            candidate.match_and_rewrite(op, rewriter)
            if rewriter.has_done_action:
                break

rewrite_patterns: list[RewritePattern] instance-attribute

The list of rewrites to apply in order.

ctx: Context | None = field(default=None) class-attribute instance-attribute

Used to materialize constant operations when folding.

folding_enabled: bool = field(default=False, kw_only=True) class-attribute instance-attribute

Whether the folders should be invoked. If this is True, the GreedyRewritePatternApplier must also have a context.

dce_enabled: bool = field(default=True, kw_only=True) class-attribute instance-attribute

Whether trivial dead code elimination should be run on all operations before attempting to rewrite them.

__init__(rewrite_patterns: list[RewritePattern], ctx: Context | None = None, *, folding_enabled: bool = False, dce_enabled: bool = True) -> None

match_and_rewrite(op: Operation, rewriter: PatternRewriter) -> None

Source code in xdsl/pattern_rewriter.py
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
    """
    Erase `op` if it is trivially dead (when DCE is enabled), otherwise try to
    fold it (when folding is enabled), otherwise try each pattern in order
    until one of them rewrites it.
    """
    from xdsl.transforms.dead_code_elimination import is_trivially_dead

    if self.dce_enabled and is_trivially_dead(op):
        rewriter.erase_op(op)
        return

    # Constant ops are never folded: folding one produces an Attribute that
    # is immediately rematerialized as a fresh constant op and pushed back
    # onto the worklist, which would never terminate.
    can_fold = (
        self.folding_enabled
        and op.has_trait(HasFolder, value_if_unregistered=False)
        and not op.has_trait(ConstantLike, value_if_unregistered=True)
    )
    if can_fold:
        if self.ctx is None:
            raise ValueError("Context is required for folding")
        fold_result = Folder(self.ctx).try_fold(op)
        if fold_result is not None:
            values, ops_to_insert = fold_result
            rewriter.replace_op(op, new_ops=ops_to_insert, new_results=values)
            return

    for candidate in self.rewrite_patterns:
        candidate.match_and_rewrite(op, rewriter)
        if rewriter.has_done_action:
            break

Worklist dataclass

Source code in xdsl/pattern_rewriter.py
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
@dataclass(eq=False)
class Worklist:
    """
    Last-in-first-out worklist of operations with O(1) membership tests and
    O(1) removal of arbitrary entries.
    """

    _op_stack: list[Operation | None] = field(
        default_factory=list[Operation | None], init=False
    )
    """
    Stack of queued operations; the top of the stack is the end of the list.
    Removed operations are replaced by `None` tombstones so removal stays O(1);
    tombstones are discarded once they reach the top.
    """

    _map: dict[Operation, int] = field(default_factory=dict[Operation, int], init=False)
    """
    Maps every queued operation to its index in `_op_stack`, for O(1)
    membership tests and removal.
    """

    def is_empty(self) -> bool:
        """Return whether no live operation is left, discarding tombstones on top."""
        stack = self._op_stack
        while stack:
            if stack[-1] is not None:
                return False
            stack.pop()
        return True

    def push(self, op: Operation):
        """Queue `op` on top of the worklist; do nothing if it is already queued."""
        if op in self._map:
            return
        self._map[op] = len(self._op_stack)
        self._op_stack.append(op)

    def pop(self) -> Operation | None:
        """
        Remove and return the most recently queued live operation, or None when
        no live operation remains.
        """
        stack = self._op_stack
        while stack:
            candidate = stack.pop()
            if candidate is None:
                # Tombstone left behind by `remove`.
                continue
            del self._map[candidate]
            return candidate
        return None

    def remove(self, op: Operation):
        """Unqueue `op` by tombstoning its slot; do nothing if it is not queued."""
        index = self._map.pop(op, None)
        if index is not None:
            self._op_stack[index] = None

__init__() -> None

is_empty() -> bool

Check if the worklist is empty.

Source code in xdsl/pattern_rewriter.py
658
659
660
661
662
def is_empty(self) -> bool:
    """Return whether no live operation is left, discarding tombstones on top."""
    stack = self._op_stack
    while stack:
        if stack[-1] is not None:
            return False
        stack.pop()
    return True

push(op: Operation)

Push an operation to the end of the worklist, if it is not already in it.

Source code in xdsl/pattern_rewriter.py
664
665
666
667
668
669
670
def push(self, op: Operation):
    """Queue `op` on top of the worklist; do nothing if it is already queued."""
    if op in self._map:
        return
    self._map[op] = len(self._op_stack)
    self._op_stack.append(op)

pop() -> Operation | None

Pop the operation at the end of the worklist.

Source code in xdsl/pattern_rewriter.py
672
673
674
675
676
677
678
679
680
681
682
683
def pop(self) -> Operation | None:
    """
    Remove and return the most recently queued live operation, or None when
    no live operation remains.
    """
    stack = self._op_stack
    while stack:
        candidate = stack.pop()
        if candidate is None:
            # Tombstone left behind by a previous removal.
            continue
        del self._map[candidate]
        return candidate
    return None

remove(op: Operation)

Remove an operation from the worklist.

Source code in xdsl/pattern_rewriter.py
685
686
687
688
689
690
def remove(self, op: Operation):
    """Unqueue `op` by tombstoning its slot; do nothing if it is not queued."""
    index = self._map.pop(op, None)
    if index is not None:
        self._op_stack[index] = None

PatternRewriteWalker dataclass

Walks the IR in block and instruction order, and rewrites it in place. Previous references to the walked operations are invalid after the walk. Can walk either the regions first, or the owner operation first. The walker will also walk recursively on the created operations.

Source code in xdsl/pattern_rewriter.py
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
@dataclass(eq=False, repr=False)
class PatternRewriteWalker:
    """
    Walks the IR in block and instruction order, and rewrites it in place.
    Previous references to the walked operations are invalid after the walk.
    Can walk either the regions first, or the owner operation first.
    The walker will also walk recursively on the created operations.
    """

    pattern: RewritePattern
    """Pattern to apply during the walk."""

    walk_regions_first: bool = field(default=False)
    """
    Whether the walker should walk an operation's regions before the
    operation itself.
    """

    apply_recursively: bool = field(default=True)
    """Apply recursively rewrites on new operations."""

    walk_reverse: bool = field(default=False)
    """
    Walk the regions and blocks in reverse order.
    That way, all uses are replaced before the definitions.
    """

    post_walk_func: Callable[[Region, PatternRewriterListener], bool] | None = field(
        default=None
    )
    """
    Function to call between each walk of the IR.
    """

    listener: PatternRewriterListener = field(default_factory=PatternRewriterListener)
    """The listener that will be called when an operation or block is modified."""

    _worklist: Worklist = field(default_factory=Worklist, init=False)
    """The worklist of operations to walk over."""

    def _add_operands_to_worklist(self, operands: Iterable[SSAValue]) -> None:
        """
        Add defining operations of SSA values to the worklist if they have only
        one use. This is a heuristic based on the fact that single-use operations
        have more canonicalization opportunities.
        """
        for operand in operands:
            if (
                operand.has_one_use()
                and not isinstance(operand, ErasedSSAValue)
                and isinstance((op := operand.owner), Operation)
            ):
                self._worklist.push(op)

    def _handle_operation_insertion(self, op: Operation) -> None:
        """Handle insertion of an operation."""
        if self.apply_recursively:
            self._worklist.push(op)

    def _handle_operation_removal(self, op: Operation) -> None:
        """Handle removal of an operation."""
        if self.apply_recursively:
            self._add_operands_to_worklist(op.operands)
        if op.regions:
            for sub_op in op.walk():
                self._worklist.remove(sub_op)
        else:
            self._worklist.remove(op)

    def _handle_operation_modification(self, op: Operation) -> None:
        """Handle modification of an operation."""
        if self.apply_recursively:
            self._worklist.push(op)

    def _handle_operation_replacement(
        self, op: Operation, new_results: Sequence[SSAValue | None]
    ) -> None:
        """Handle replacement of an operation."""
        if self.apply_recursively:
            for result in op.results:
                for user in result.uses:
                    self._worklist.push(user.operation)

    def _get_rewriter_listener(self) -> PatternRewriterListener:
        """
        Get the listener that will be passed to the rewriter.
        It will take care of adding operations to the worklist, and calling the
        listener passed as configuration to the walker.
        """
        return PatternRewriterListener(
            operation_insertion_handler=[
                *self.listener.operation_insertion_handler,
                self._handle_operation_insertion,
            ],
            operation_removal_handler=[
                *self.listener.operation_removal_handler,
                self._handle_operation_removal,
            ],
            operation_modification_handler=[
                *self.listener.operation_modification_handler,
                self._handle_operation_modification,
            ],
            operation_replacement_handler=[
                *self.listener.operation_replacement_handler,
                self._handle_operation_replacement,
            ],
            block_creation_handler=self.listener.block_creation_handler,
        )

    def rewrite_module(self, module: ModuleOp) -> bool:
        """
        Rewrite operations nested in the given module by repeatedly applying the
        pattern. Returns `True` if the IR was mutated.
        """
        return self.rewrite_region(module.body)

    def rewrite_region(self, region: Region) -> bool:
        """
        Rewrite operations nested in the given region by repeatedly applying the
        pattern. Returns `True` if the IR was mutated.

        Each pass pushes every nested operation onto the worklist, drains it, and
        then calls `post_walk_func` (if any). With `apply_recursively`, passes are
        repeated until one makes no modification; otherwise a single pass runs.
        """
        # Start from an empty worklist so that operations left over by an earlier
        # walk that did not run to completion (e.g. one interrupted by an
        # exception) are never revisited.
        self._worklist = Worklist()
        pattern_listener = self._get_rewriter_listener()

        result = False
        while True:
            self._populate_worklist(region)
            op_was_modified = self._process_worklist(pattern_listener)
            if self.post_walk_func is not None:
                op_was_modified |= self.post_walk_func(region, pattern_listener)
            result |= op_was_modified
            if not (op_was_modified and self.apply_recursively):
                return result

    def _populate_worklist(self, op: Operation | Region | Block) -> None:
        """Populate the worklist with all nested operations."""
        # We walk in reverse order since we use a stack for our worklist.
        for sub_op in op.walk(
            reverse=not self.walk_reverse, region_first=not self.walk_regions_first
        ):
            self._worklist.push(sub_op)

    def _process_worklist(self, listener: PatternRewriterListener) -> bool:
        """
        Process the worklist until it is empty.
        Returns true if any modification was done.
        """
        op = self._worklist.pop()
        if op is None:
            return False

        # A single rewriter is created on the first operation and then
        # re-targeted at every following operation.
        rewriter = PatternRewriter(op)
        rewriter.extend_from_listener(listener)

        rewriter_has_done_action = False
        while op is not None:
            # Reset the rewriter on `op`
            rewriter.has_done_action = False
            rewriter.current_operation = op
            rewriter.insertion_point = InsertPoint.before(op)
            rewriter.name_hint = None

            # Apply the pattern on the operation
            try:
                self.pattern.match_and_rewrite(op, rewriter)
            except Exception as err:
                op.emit_error(
                    f"Error while applying pattern: {err}",
                    underlying_error=err,
                )
            rewriter_has_done_action |= rewriter.has_done_action

            op = self._worklist.pop()
        return rewriter_has_done_action

pattern: RewritePattern instance-attribute

Pattern to apply during the walk.

walk_regions_first: bool = field(default=False) class-attribute instance-attribute

Whether the walker should walk an operation's regions before the operation itself.

apply_recursively: bool = field(default=True) class-attribute instance-attribute

Apply recursively rewrites on new operations.

walk_reverse: bool = field(default=False) class-attribute instance-attribute

Walk the regions and blocks in reverse order. That way, all uses are replaced before the definitions.

post_walk_func: Callable[[Region, PatternRewriterListener], bool] | None = field(default=None) class-attribute instance-attribute

Function to call between each walk of the IR.

listener: PatternRewriterListener = field(default_factory=PatternRewriterListener) class-attribute instance-attribute

The listener that will be called when an operation or block is modified.

__init__(pattern: RewritePattern, walk_regions_first: bool = False, apply_recursively: bool = True, walk_reverse: bool = False, post_walk_func: Callable[[Region, PatternRewriterListener], bool] | None = None, listener: PatternRewriterListener = PatternRewriterListener()) -> None

rewrite_module(module: ModuleOp) -> bool

Rewrite operations nested in the given operation by repeatedly applying the pattern. Returns True if the IR was mutated.

Source code in xdsl/pattern_rewriter.py
802
803
804
805
806
807
def rewrite_module(self, module: ModuleOp) -> bool:
    """
    Repeatedly apply the pattern to the operations nested in the body region
    of `module`. Returns `True` if the IR was mutated.
    """
    body = module.body
    return self.rewrite_region(body)

rewrite_region(region: Region) -> bool

Rewrite operations nested in the given region by repeatedly applying the pattern. Returns True if the IR was mutated.

Source code in xdsl/pattern_rewriter.py
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
def rewrite_region(self, region: Region) -> bool:
    """
    Rewrite operations nested in the given region by repeatedly applying the
    pattern. Returns `True` if the IR was mutated.

    Each pass pushes every nested operation onto the worklist, drains it, and
    then calls `post_walk_func` (if any). With `apply_recursively`, passes are
    repeated until one makes no modification; otherwise a single pass runs.
    """
    # Start from an empty worklist so that operations left over by an earlier
    # walk that did not run to completion (e.g. one interrupted by an
    # exception) are never revisited.
    self._worklist = Worklist()
    pattern_listener = self._get_rewriter_listener()

    result = False
    while True:
        self._populate_worklist(region)
        op_was_modified = self._process_worklist(pattern_listener)
        if self.post_walk_func is not None:
            op_was_modified |= self.post_walk_func(region, pattern_listener)
        result |= op_was_modified
        if not (op_was_modified and self.apply_recursively):
            return result

op_type_rewrite_pattern(func: Callable[[_RewritePatternT, _OperationT, PatternRewriter], None]) -> Callable[[_RewritePatternT, Operation, PatternRewriter], None]

This function is intended to be used as a decorator on a RewritePattern method. It uses type hints to match on a specific operation type before calling the decorated function.

Source code in xdsl/pattern_rewriter.py
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
def op_type_rewrite_pattern(
    func: Callable[[_RewritePatternT, _OperationT, PatternRewriter], None],
) -> Callable[[_RewritePatternT, Operation, PatternRewriter], None]:
    """
    This function is intended to be used as a decorator on a RewritePattern
    method. It uses type hints to match on a specific operation type before
    calling the decorated function.

    The decorated function must take exactly three parameters
    (`self, op, rewriter`), and the annotation of `op` must be an `Operation`
    subclass or a union of `Operation` subclasses. The returned function only
    calls `func` when the operation is an instance of that annotation, and does
    nothing otherwise.

    Raises:
        Exception: if `func` does not take exactly three parameters, or if the
            annotation of its second parameter is not an `Operation` subclass
            or a union of `Operation` subclasses.
    """
    params = list(inspect.signature(func, eval_str=True).parameters.values())
    # `impl` below always forwards `(self, op, rewriter)`, so exactly three
    # parameters are required; the name of the first one does not matter.
    if len(params) != 3:
        raise Exception(
            "op_type_rewrite_pattern expects the decorated function to "
            "have two non-self arguments."
        )
    expected_type: type[_OperationT] = params[1].annotation

    expected_types = (expected_type,)
    if get_origin(expected_type) in (Union, UnionType):
        expected_types = get_args(expected_type)

    def is_operation_class(candidate: object) -> bool:
        # `issubclass` raises TypeError on some non-class annotations (e.g.
        # parameterized generics); report those through the error below instead.
        try:
            return isinstance(candidate, type) and issubclass(candidate, Operation)
        except TypeError:
            return False

    if not all(is_operation_class(t) for t in expected_types):
        raise Exception(
            "op_type_rewrite_pattern expects the first non-self argument "
            "type hint to be an `Operation` subclass or a union of `Operation` "
            "subclasses."
        )

    def impl(self: _RewritePatternT, op: Operation, rewriter: PatternRewriter) -> None:
        if isinstance(op, expected_type):
            func(self, op, rewriter)

    return impl

attr_constr_rewrite_pattern(constr: AttrConstraint[_AttributeT]) -> Callable[[Callable[[_TypeConversionPatternT, _AttributeT], Attribute | None]], Callable[[_TypeConversionPatternT, Attribute], Attribute | None]]

This function is intended to be used as a decorator on a TypeConversionPattern method. It uses the passed constraint to match on a specific attribute type before calling the decorated function.

Source code in xdsl/pattern_rewriter.py
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
def attr_constr_rewrite_pattern(
    constr: AttrConstraint[_AttributeT],
) -> Callable[
    [Callable[[_TypeConversionPatternT, _AttributeT], Attribute | None]],
    Callable[[_TypeConversionPatternT, Attribute], Attribute | None],
]:
    """
    Build a decorator for TypeConversionPattern methods: the decorated method
    is only invoked when `constr` verifies the incoming attribute, and None is
    returned for every other attribute.
    """

    def decorate(
        func: Callable[[_TypeConversionPatternT, _AttributeT], _ConvertedT | None],
    ):
        @wraps(func)
        def impl(self: _TypeConversionPatternT, typ: Attribute) -> Attribute | None:
            if not constr.verifies(typ):
                return None
            return func(self, typ)

        return impl

    return decorate

attr_type_rewrite_pattern(func: Callable[[_TypeConversionPatternT, _AttributeT], Attribute | None]) -> Callable[[_TypeConversionPatternT, Attribute], Attribute | None]

This function is intended to be used as a decorator on a TypeConversionPattern method. It uses type hints to match on a specific attribute type before calling the decorated function.

Source code in xdsl/pattern_rewriter.py
567
568
569
570
571
572
573
574
575
576
577
578
def attr_type_rewrite_pattern(
    func: Callable[[_TypeConversionPatternT, _AttributeT], Attribute | None],
) -> Callable[[_TypeConversionPatternT, Attribute], Attribute | None]:
    """
    Decorator for TypeConversionPattern methods. The type annotation of the
    method's last parameter is turned into a `base(...)` constraint, and the
    method is only called for attributes satisfying that constraint.
    """
    signature = inspect.signature(func, eval_str=True)
    last_param = list(signature.parameters.values())[-1]
    expected_type: type[_AttributeT] = last_param.annotation
    return attr_constr_rewrite_pattern(base(expected_type))(func)