Skip to content

Cf

cf

AssertTrue

Bases: RewritePattern

Erase assertion if argument is constant true.

Source code in xdsl/transforms/canonicalization_patterns/cf.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class AssertTrue(RewritePattern):
    """
    Delete a `cf.assert` whose argument is produced by an `arith.constant`
    holding a non-zero integer attribute: such an assertion always holds.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: cf.AssertOp, rewriter: PatternRewriter):
        defining_op = op.arg.owner
        # Only a literal constant lets the assertion be decided statically.
        if isinstance(defining_op, arith.ConstantOp):
            attr = defining_op.value
            if isinstance(attr, IntegerAttr) and attr.value.data:
                # Replace the assert with an empty list of operations.
                rewriter.replace_op(op, [])

match_and_rewrite(op: cf.AssertOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.AssertOp, rewriter: PatternRewriter):
    """Erase `op` when its argument is a non-zero `arith.constant` integer."""
    defining_op = op.arg.owner
    # Only a literal constant lets the assertion be decided statically.
    if isinstance(defining_op, arith.ConstantOp):
        attr = defining_op.value
        if isinstance(attr, IntegerAttr) and attr.value.data:
            rewriter.replace_op(op, [])

SimplifyBrToBlockWithSinglePred

Bases: RewritePattern

Simplify a branch to a block that has a single predecessor. This effectively merges the two blocks.

Source code in xdsl/transforms/canonicalization_patterns/cf.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class SimplifyBrToBlockWithSinglePred(RewritePattern):
    """
    Merge the destination of an unconditional `cf.br` into the branch's own
    block when that destination is entered from exactly one predecessor.
    The branch is erased and the destination block is inlined at the end of
    the parent block, with the branch operands standing in for its arguments.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: cf.BranchOp, rewriter: PatternRewriter):
        target = op.successor
        source = op.parent_block()
        if source is None:
            return

        # A branch to its own block cannot be merged, and a target with any
        # other incoming edge must remain a separate block.
        mergeable = not (target == source or len(target.predecessors()) != 1)
        if not mergeable:
            return

        forwarded = op.operands
        rewriter.erase_op(op)
        rewriter.inline_block(target, InsertPoint.at_end(source), forwarded)

match_and_rewrite(op: cf.BranchOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.BranchOp, rewriter: PatternRewriter):
    """Inline the single-predecessor destination of `op` into `op`'s block."""
    target = op.successor
    source = op.parent_block()
    if source is None:
        return

    # A branch to its own block cannot be merged, and a target with any
    # other incoming edge must remain a separate block.
    mergeable = not (target == source or len(target.predecessors()) != 1)
    if not mergeable:
        return

    forwarded = op.operands
    rewriter.erase_op(op)
    rewriter.inline_block(target, InsertPoint.at_end(source), forwarded)

SimplifyPassThroughBr

Bases: RewritePattern

`br ^bb1`, where block `^bb1` contains only `br ^bbN(...)`, is rewritten to branch directly: `br ^bbN(...)`.

Source code in xdsl/transforms/canonicalization_patterns/cf.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
class SimplifyPassThroughBr(RewritePattern):
    """
    Retarget an unconditional branch past a destination block that does
    nothing but branch onward (see `collapse_branch`):

      br ^bb1
    ^bb1
      br ^bbN(...)

     -> br ^bbN(...)
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: cf.BranchOp, rewriter: PatternRewriter):
        here = op.parent_block()
        # Skip detached branches and branches back to their own block.
        if here is None or op.successor == here:
            return

        collapsed = collapse_branch(op.successor, op.arguments)
        if collapsed is not None:
            new_dest, new_args = collapsed
            rewriter.replace_op(op, cf.BranchOp(new_dest, *new_args))

match_and_rewrite(op: cf.BranchOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
119
120
121
122
123
124
125
126
127
128
129
130
131
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.BranchOp, rewriter: PatternRewriter):
    """Retarget `op` past a destination that only branches onward."""
    here = op.parent_block()
    # Skip detached branches and branches back to their own block.
    if here is None or op.successor == here:
        return

    collapsed = collapse_branch(op.successor, op.arguments)
    if collapsed is not None:
        new_dest, new_args = collapsed
        rewriter.replace_op(op, cf.BranchOp(new_dest, *new_args))

SimplifyConstCondBranchPred

Bases: RewritePattern

`cf.cond_br true, ^bb1, ^bb2` -> `br ^bb1`; `cf.cond_br false, ^bb1, ^bb2` -> `br ^bb2`

Source code in xdsl/transforms/canonicalization_patterns/cf.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class SimplifyConstCondBranchPred(RewritePattern):
    """
    Turn a `cf.cond_br` whose condition evaluates to a constant into an
    unconditional branch to the selected destination:

    cf.cond_br true, ^bb1, ^bb2
     -> br ^bb1
    cf.cond_br false, ^bb1, ^bb2
     -> br ^bb2
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: cf.ConditionalBranchOp, rewriter: PatternRewriter):
        known = const_evaluate_operand(op.cond)
        if known is None:
            # The condition is not a compile-time constant.
            return

        if known:
            dest, args = op.then_block, op.then_arguments
        else:
            dest, args = op.else_block, op.else_arguments
        rewriter.replace_op(op, cf.BranchOp(dest, *args))

match_and_rewrite(op: cf.ConditionalBranchOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
142
143
144
145
146
147
148
149
150
151
152
153
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.ConditionalBranchOp, rewriter: PatternRewriter):
    """Replace `op` by a `cf.br` to the edge chosen by a constant condition."""
    known = const_evaluate_operand(op.cond)
    if known is None:
        # The condition is not a compile-time constant.
        return

    if known:
        dest, args = op.then_block, op.then_arguments
    else:
        dest, args = op.else_block, op.else_arguments
    rewriter.replace_op(op, cf.BranchOp(dest, *args))

SimplifyPassThroughCondBranch

Bases: RewritePattern

`cf.cond_br %cond, ^bb1, ^bb2`, where `^bb1` contains only `br ^bbN(...)` and `^bb2` contains only `br ^bbK(...)`, becomes `cf.cond_br %cond, ^bbN(...), ^bbK(...)`.

Source code in xdsl/transforms/canonicalization_patterns/cf.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
class SimplifyPassThroughCondBranch(RewritePattern):
    """
    Retarget either edge of a `cf.cond_br` past a destination block that only
    branches onward (see `collapse_branch`):

      cf.cond_br %cond, ^bb1, ^bb2
    ^bb1
      br ^bbN(...)
    ^bb2
      br ^bbK(...)

     -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: cf.ConditionalBranchOp, rewriter: PatternRewriter):
        then_target = collapse_branch(op.then_block, op.then_arguments)
        else_target = collapse_branch(op.else_block, op.else_arguments)

        # Nothing to do unless at least one edge can be forwarded.
        if then_target is None and else_target is None:
            return

        # An edge that could not be collapsed keeps its current destination.
        if then_target is None:
            then_target = (op.then_block, op.then_arguments)
        if else_target is None:
            else_target = (op.else_block, op.else_arguments)

        then_dest, then_args = then_target
        else_dest, else_args = else_target
        rewriter.replace_op(
            op,
            cf.ConditionalBranchOp(op.cond, then_dest, then_args, else_dest, else_args),
        )

match_and_rewrite(op: cf.ConditionalBranchOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.ConditionalBranchOp, rewriter: PatternRewriter):
    """Forward either edge of `op` past a block that only branches onward."""
    then_target = collapse_branch(op.then_block, op.then_arguments)
    else_target = collapse_branch(op.else_block, op.else_arguments)

    # Nothing to do unless at least one edge can be forwarded.
    if then_target is None and else_target is None:
        return

    # An edge that could not be collapsed keeps its current destination.
    if then_target is None:
        then_target = (op.then_block, op.then_arguments)
    if else_target is None:
        else_target = (op.else_block, op.else_arguments)

    then_dest, then_args = then_target
    else_dest, else_args = else_target
    rewriter.replace_op(
        op,
        cf.ConditionalBranchOp(op.cond, then_dest, then_args, else_dest, else_args),
    )

SimplifyCondBranchIdenticalSuccessors

Bases: RewritePattern

`cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)` -> `br ^bb1(A, ..., N)`

`cf.cond_br %cond, ^bb1(A), ^bb1(B)` -> `%select = arith.select %cond, A, B` followed by `br ^bb1(%select)`

Source code in xdsl/transforms/canonicalization_patterns/cf.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
class SimplifyCondBranchIdenticalSuccessors(RewritePattern):
    """
    Replace a `cf.cond_br` whose two edges reach the same block with a single
    `cf.br`. Argument pairs that differ between the edges are merged with an
    `arith.select` on the branch condition:

    cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
     -> br ^bb1(A, ..., N)

    cf.cond_br %cond, ^bb1(A), ^bb1(B)
     -> %select = arith.select %cond, A, B
        br ^bb1(%select)
    """

    @staticmethod
    def _merge_operand(
        op1: SSAValue,
        op2: SSAValue,
        rewriter: PatternRewriter,
        cond_br: cf.ConditionalBranchOp,
    ) -> SSAValue:
        """
        Return one value equivalent to `op1` on the true edge and `op2` on the
        false edge of `cond_br`. A new `arith.select` is inserted before
        `cond_br` only when the two values differ.
        """
        if op1 == op2:
            # Identical values need no selection.
            return op1
        chooser = arith.SelectOp(cond_br.cond, op1, op2)
        rewriter.insert_op(chooser, InsertPoint.before(cond_br))
        return chooser.result

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: cf.ConditionalBranchOp, rewriter: PatternRewriter):
        target = op.then_block
        if target != op.else_block:
            return

        merged: list[SSAValue] = []
        pairs = zip(op.then_arguments, op.else_arguments, strict=True)
        for true_val, false_val in pairs:
            merged.append(self._merge_operand(true_val, false_val, rewriter, op))

        rewriter.replace_op(op, cf.BranchOp(target, *merged))

match_and_rewrite(op: cf.ConditionalBranchOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
212
213
214
215
216
217
218
219
220
221
222
223
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.ConditionalBranchOp, rewriter: PatternRewriter):
    """Collapse `op` to a `cf.br` when both edges target the same block."""
    target = op.then_block
    if target != op.else_block:
        return

    merged: list[SSAValue] = []
    pairs = zip(op.then_arguments, op.else_arguments, strict=True)
    for true_val, false_val in pairs:
        merged.append(self._merge_operand(true_val, false_val, rewriter, op))

    rewriter.replace_op(op, cf.BranchOp(target, *merged))

CondBranchTruthPropagation

Bases: RewritePattern

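Given `cf.cond_br %arg0, ^trueB, ^falseB`, where `^trueB` contains `"test.consumer1"(%arg0) : (i1) -> ()` and `^falseB` contains `"test.consumer2"(%arg0) : (i1) -> ()`:

->

The `cf.cond_br` is unchanged. Inside each destination that has a single predecessor, uses of `%arg0` become constants: `^trueB` now contains `"test.consumer1"(%true) : (i1) -> ()` and `^falseB` contains `"test.consumer2"(%false) : (i1) -> ()`.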

Source code in xdsl/transforms/canonicalization_patterns/cf.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
class CondBranchTruthPropagation(RewritePattern):
    """
    Within a destination of a `cf.cond_br` that has exactly one predecessor
    edge, the branch condition's value is known: true in the "then"
    destination and false in the "else" destination. Uses of the condition by
    operations located directly in such a destination are replaced with an
    `arith.constant` inserted before the `cf.cond_br`.

      cf.cond_br %arg0, ^trueB, ^falseB

    ^trueB:
      "test.consumer1"(%arg0) : (i1) -> ()
       ...

    ^falseB:
      "test.consumer2"(%arg0) : (i1) -> ()
      ...

    ->

      cf.cond_br %arg0, ^trueB, ^falseB
    ^trueB:
      "test.consumer1"(%true) : (i1) -> ()
      ...

    ^falseB:
      "test.consumer2"(%false) : (i1) -> ()
      ...
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: cf.ConditionalBranchOp, rewriter: PatternRewriter):
        def propagate(dest: Block, known_value: bool) -> None:
            # Any other incoming edge would leave the condition's value unknown.
            if len(dest.predecessors()) != 1:
                return
            # Avoid creating a constant when nothing in `dest` would use it.
            if not any(
                use.operation.parent_block() is dest for use in op.cond.uses
            ):
                return
            constant = arith.ConstantOp(BoolAttr.from_bool(known_value))
            rewriter.insert_op(constant, InsertPoint.before(op))
            rewriter.replace_uses_with_if(
                op.cond,
                constant.result,
                lambda use: use.operation.parent_block() is dest,
            )

        propagate(op.then_block, True)
        propagate(op.else_block, False)

match_and_rewrite(op: cf.ConditionalBranchOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.ConditionalBranchOp, rewriter: PatternRewriter):
    """Replace uses of `op.cond` in single-predecessor targets by constants."""

    def propagate(dest: Block, known_value: bool) -> None:
        # Any other incoming edge would leave the condition's value unknown.
        if len(dest.predecessors()) != 1:
            return
        # Avoid creating a constant when nothing in `dest` would use it.
        if not any(use.operation.parent_block() is dest for use in op.cond.uses):
            return
        constant = arith.ConstantOp(BoolAttr.from_bool(known_value))
        rewriter.insert_op(constant, InsertPoint.before(op))
        rewriter.replace_uses_with_if(
            op.cond,
            constant.result,
            lambda use: use.operation.parent_block() is dest,
        )

    propagate(op.then_block, True)
    propagate(op.else_block, False)

SimplifySwitchWithOnlyDefault

Bases: RewritePattern

`switch %flag : i32, [ default: ^bb1 ]` -> `br ^bb1`

Source code in xdsl/transforms/canonicalization_patterns/cf.py
276
277
278
279
280
281
282
283
284
285
286
287
class SimplifySwitchWithOnlyDefault(RewritePattern):
    """
    A `cf.switch` with no case destinations always takes its default edge,
    so it becomes an unconditional branch:

    switch %flag : i32, [
      default:  ^bb1
    ]
     -> br ^bb1
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
        if op.case_blocks:
            # At least one explicit case remains; the switch is still needed.
            return
        unconditional = cf.BranchOp(op.default_block, *op.default_operands)
        rewriter.replace_op(op, unconditional)

match_and_rewrite(op: cf.SwitchOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
284
285
286
287
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
    """Turn a case-less switch into a `cf.br` to its default destination."""
    if op.case_blocks:
        # At least one explicit case remains; the switch is still needed.
        return
    unconditional = cf.BranchOp(op.default_block, *op.default_operands)
    rewriter.replace_op(op, unconditional)

DropSwitchCasesThatMatchDefault

Bases: RewritePattern

`switch %flag : i32, [ default: ^bb1, 42: ^bb1, 43: ^bb2 ]` -> `switch %flag : i32, [ default: ^bb1, 43: ^bb2 ]` (case 42 goes to the default destination with the same operands, so it is dropped)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
class DropSwitchCasesThatMatchDefault(RewritePattern):
    """
    Remove `cf.switch` cases that branch to the default destination with the
    same operands as the default edge, since taking them is
    indistinguishable from falling through to the default:

    switch %flag : i32, [
      default: ^bb1,
      42: ^bb1,
      43: ^bb2
    ]
    ->
    switch %flag : i32, [
      default: ^bb1,
      43: ^bb2
    ]
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
        default_dest = op.default_block
        default_args = op.default_operands

        def same_as_default(
            switch_case: IntegerAttr,
            block: Block,
            operands: Sequence[Operation | SSAValue],
        ) -> bool:
            # The case value is irrelevant: only the edge it selects matters.
            return block == default_dest and operands == default_args

        drop_case_helper(rewriter, op, same_as_default)

match_and_rewrite(op: cf.SwitchOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
349
350
351
352
353
354
355
356
357
358
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
    """Drop cases of `op` whose edge is identical to the default edge."""
    default_dest = op.default_block
    default_args = op.default_operands

    def same_as_default(
        switch_case: IntegerAttr,
        block: Block,
        operands: Sequence[Operation | SSAValue],
    ) -> bool:
        # The case value is irrelevant: only the edge it selects matters.
        return block == default_dest and operands == default_args

    drop_case_helper(rewriter, op, same_as_default)

SimplifyConstSwitchValue

Bases: RewritePattern

`switch %c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ]` -> `br ^bb2` (where `%c_42` is the constant 42)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
387
388
389
390
391
392
393
394
395
396
397
398
399
400
class SimplifyConstSwitchValue(RewritePattern):
    """
    When the `cf.switch` flag evaluates to a constant, replace the switch
    with a branch to the matching case, or to the default if no case matches
    (see `fold_switch`):

    switch %c_42 : i32, [
      default: ^bb1,
      42: ^bb2,
      43: ^bb3
    ]
    -> br ^bb2
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
        constant_flag = const_evaluate_operand(op.flag)
        if constant_flag is None:
            return
        fold_switch(op, rewriter, constant_flag)

match_and_rewrite(op: cf.SwitchOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
397
398
399
400
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
    """Fold `op` to a branch when its flag is a compile-time constant."""
    constant_flag = const_evaluate_operand(op.flag)
    if constant_flag is None:
        return
    fold_switch(op, rewriter, constant_flag)

SimplifyPassThroughSwitch

Bases: RewritePattern

`switch %c_42 : i32, [ default: ^bb1, 42: ^bb2 ]`, where `^bb2` contains only `br ^bb3`, -> `switch %c_42 : i32, [ default: ^bb1, 42: ^bb3 ]`

Source code in xdsl/transforms/canonicalization_patterns/cf.py
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
class SimplifyPassThroughSwitch(RewritePattern):
    """
    Retarget any `cf.switch` edge (default or case) past a destination block
    that only branches onward (see `collapse_branch`):

    switch %c_42 : i32, [
      default: ^bb1,
      42: ^bb2,
    ]
    ^bb2:
      br ^bb3
    ->
    switch %c_42 : i32, [
      default: ^bb1,
      42: ^bb3,
    ]
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
        changed = False
        case_blocks: list[Block] = []
        case_operands: list[Sequence[Operation | SSAValue]] = []

        for dest, args in zip(op.case_blocks, op.case_operand, strict=True):
            target = collapse_branch(dest, args)
            if target is None:
                # Keep the edge as it is.
                case_blocks.append(dest)
                case_operands.append(args)
            else:
                changed = True
                new_dest, new_args = target
                case_blocks.append(new_dest)
                case_operands.append(new_args)

        default_target = collapse_branch(op.default_block, op.default_operands)
        if default_target is None:
            default_block, default_operands = op.default_block, op.default_operands
        else:
            changed = True
            default_block, default_operands = default_target

        if not changed:
            return

        rewriter.replace_op(
            op,
            cf.SwitchOp(
                op.flag,
                default_block,
                default_operands,
                op.case_values,
                case_blocks,
                case_operands,
            ),
        )

match_and_rewrite(op: cf.SwitchOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
    """Forward each edge of `op` past blocks that only branch onward."""
    changed = False
    case_blocks: list[Block] = []
    case_operands: list[Sequence[Operation | SSAValue]] = []

    for dest, args in zip(op.case_blocks, op.case_operand, strict=True):
        target = collapse_branch(dest, args)
        if target is None:
            # Keep the edge as it is.
            case_blocks.append(dest)
            case_operands.append(args)
        else:
            changed = True
            new_dest, new_args = target
            case_blocks.append(new_dest)
            case_operands.append(new_args)

    default_target = collapse_branch(op.default_block, op.default_operands)
    if default_target is None:
        default_block, default_operands = op.default_block, op.default_operands
    else:
        changed = True
        default_block, default_operands = default_target

    if not changed:
        return

    rewriter.replace_op(
        op,
        cf.SwitchOp(
            op.flag,
            default_block,
            default_operands,
            op.case_values,
            case_blocks,
            case_operands,
        ),
    )

SimplifySwitchFromSwitchOnSameCondition

Bases: RewritePattern

`switch %flag : i32, [ default: ^bb1, 42: ^bb2 ]`, where `^bb2` contains `switch %flag : i32, [ default: ^bb3, 42: ^bb4 ]` -> the inner switch in `^bb2` becomes `br ^bb4` (the flag is known to be 42 there).

and

`switch %flag : i32, [ default: ^bb1, 42: ^bb2 ]`, where `^bb2` contains `switch %flag : i32, [ default: ^bb3, 43: ^bb4 ]` -> the inner switch in `^bb2` becomes `br ^bb3` (the flag is 42, which matches no inner case).

and

`switch %flag : i32, [ default: ^bb1, 42: ^bb2 ]`, where `^bb1` contains `switch %flag : i32, [ default: ^bb3, 42: ^bb4, 43: ^bb5 ]` -> the inner switch in `^bb1` becomes `switch %flag : i32, [ default: ^bb3, 43: ^bb5 ]` (the flag cannot be 42 on the outer default edge).

Source code in xdsl/transforms/canonicalization_patterns/cf.py
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
class SimplifySwitchFromSwitchOnSameCondition(RewritePattern):
    """
    Simplify a `cf.switch` whose block is entered through a single edge coming
    from another `cf.switch` on the same flag. That edge already constrains
    the flag's value.

    Entered through a case edge, the flag equals that case value, so the
    switch folds to an unconditional branch:

    switch %flag : i32, [
      default: ^bb1,
      42: ^bb2,
    ]
    ^bb2:
      switch %flag : i32, [
        default: ^bb3,
        42: ^bb4
      ]
    ->
    switch %flag : i32, [
      default: ^bb1,
      42: ^bb2,
    ]
    ^bb2:
      br ^bb4

     and

    switch %flag : i32, [
      default: ^bb1,
      42: ^bb2,
    ]
    ^bb2:
      switch %flag : i32, [
        default: ^bb3,
        43: ^bb4
      ]
    ->
    switch %flag : i32, [
      default: ^bb1,
      42: ^bb2,
    ]
    ^bb2:
      br ^bb3

    Entered through the default edge, the flag matches none of the
    predecessor's case values, so cases on those values are dropped:

    switch %flag : i32, [
      default: ^bb1,
      42: ^bb2
    ]
    ^bb1:
      switch %flag : i32, [
        default: ^bb3,
        42: ^bb4,
        43: ^bb5
      ]
    ->
    switch %flag : i32, [
      default: ^bb1,
      42: ^bb2,
    ]
    ^bb1:
      switch %flag : i32, [
        default: ^bb3,
        43: ^bb5
      ]
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
        block = op.parent_block()
        if block is None:
            return
        # Require a single incoming edge, i.e. a unique use of the block.
        if (pred := block.get_unique_use()) is None:
            return
        switch = pred.operation
        if not isinstance(switch, cf.SwitchOp):
            return

        if switch.flag != op.flag:
            return

        case_values = switch.case_values
        if case_values is None:
            return

        if pred.index != 0:
            # Use index 0 is the predecessor's default edge. Index i > 0 is
            # case i - 1, so the flag is known to equal that case value here.
            fold_switch(
                op,
                rewriter,
                case_values.get_values()[pred.index - 1],
            )
            return

        # Entered through the default edge: every predecessor case value is
        # impossible here. Materialise them once rather than on every
        # predicate call (once per case of `op`).
        excluded = tuple(case_values.get_attrs())

        def predicate(
            switch_case: IntegerAttr,
            case_block: Block,
            operands: Sequence[Operation | SSAValue],
        ) -> bool:
            return switch_case in excluded

        drop_case_helper(rewriter, op, predicate)

match_and_rewrite(op: cf.SwitchOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/canonicalization_patterns/cf.py
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
    """
    Simplify `op` using the value constraint implied by a unique predecessor
    `cf.switch` on the same flag. A case edge folds `op` to a branch; the
    default edge drops `op`'s cases on the predecessor's case values.
    """
    block = op.parent_block()
    if block is None:
        return
    # Require a single incoming edge, i.e. a unique use of the block.
    if (pred := block.get_unique_use()) is None:
        return
    switch = pred.operation
    if not isinstance(switch, cf.SwitchOp):
        return

    if switch.flag != op.flag:
        return

    case_values = switch.case_values
    if case_values is None:
        return

    if pred.index != 0:
        # Use index 0 is the predecessor's default edge. Index i > 0 is
        # case i - 1, so the flag is known to equal that case value here.
        fold_switch(
            op,
            rewriter,
            case_values.get_values()[pred.index - 1],
        )
        return

    # Entered through the default edge: every predecessor case value is
    # impossible here. Materialise them once rather than on every
    # predicate call (once per case of `op`).
    excluded = tuple(case_values.get_attrs())

    def predicate(
        switch_case: IntegerAttr,
        case_block: Block,
        operands: Sequence[Operation | SSAValue],
    ) -> bool:
        return switch_case in excluded

    drop_case_helper(rewriter, op, predicate)

collapse_branch(successor: Block, successor_operands: Sequence[SSAValue]) -> tuple[Block, Sequence[SSAValue]] | None

Given a successor, try to collapse it to a new destination if it contains only a pass-through unconditional branch. If the successor can be collapsed, returns a `(new_successor, new_operands)` tuple, in which any operand that was one of the collapsed block's arguments is replaced by the corresponding value from `successor_operands`; otherwise returns `None`. The arguments themselves are not modified.

Source code in xdsl/transforms/canonicalization_patterns/cf.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def collapse_branch(
    successor: Block, successor_operands: Sequence[SSAValue]
) -> tuple[Block, Sequence[SSAValue]] | None:
    """
    Try to forward a branch edge past a block that only contains an
    unconditional branch.

    Args:
        successor: the block currently targeted by the edge.
        successor_operands: the values passed to `successor`'s arguments
            along that edge.

    Returns:
        `(new_successor, new_operands)` if `successor` consists of a single
        `cf.br` that does not branch back to `successor`, and `successor`'s
        block arguments are used only by that branch. In that case
        `new_operands` are the inner branch's operands, with each of
        `successor`'s own arguments replaced by the matching entry of
        `successor_operands`. Otherwise returns `None`.

    Neither argument is modified; callers use the returned pair instead.
    """
    # The block must consist of nothing but its terminator.
    if len(successor.ops) != 1:
        return None

    branch = successor.ops.first
    # That terminator must be an unconditional branch.
    if not isinstance(branch, cf.BranchOp):
        return None

    # If a block argument were used by any other operation, bypassing this
    # block would leave that use without its definition.
    for argument in successor.args:
        for use in argument.uses:
            if use.operation != branch:
                return None

    # Don't try to collapse branches to infinite loops.
    if branch.successor == successor:
        return None

    # Substitute the incoming values for this block's own arguments.
    new_operands = tuple(
        successor_operands[operand.index]
        if isinstance(operand, BlockArgument) and operand.owner is successor
        else operand
        for operand in branch.operands
    )

    return (branch.successor, new_operands)

drop_case_helper(rewriter: PatternRewriter, op: cf.SwitchOp, predicate: Callable[[IntegerAttr, Block, Sequence[Operation | SSAValue]], bool])

Source code in xdsl/transforms/canonicalization_patterns/cf.py
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
def drop_case_helper(
    rewriter: PatternRewriter,
    op: cf.SwitchOp,
    predicate: Callable[[IntegerAttr, Block, Sequence[Operation | SSAValue]], bool],
):
    """
    Replace `op` with a `cf.switch` that omits every case for which
    `predicate(case_value, case_block, case_operands)` returns True.

    The flag and the default edge are kept unchanged. Nothing is rewritten if
    `op` has no case values or if no case is dropped.
    """
    case_values = op.case_values
    if case_values is None:
        return

    kept_values: list[int] = []
    kept_blocks: list[Block] = []
    kept_operands: list[Sequence[Operation | SSAValue]] = []
    dropped_any = False

    cases = zip(
        case_values.get_attrs(),
        op.case_blocks,
        op.case_operand,
        strict=True,
    )
    for attr, dest, args in cases:
        value = cast(IntegerAttr, attr)
        if predicate(value, dest, args):
            dropped_any = True
        else:
            kept_values.append(value.value.data)
            kept_blocks.append(dest)
            kept_operands.append(args)

    if not dropped_any:
        return

    # The surviving case values keep the original element type.
    new_case_values = DenseIntElementsAttr.from_list(
        VectorType(case_values.get_element_type(), (len(kept_values),)),
        kept_values,
    )
    rewriter.replace_op(
        op,
        cf.SwitchOp(
            op.flag,
            op.default_block,
            op.default_operands,
            new_case_values,
            kept_blocks,
            kept_operands,
        ),
    )

fold_switch(switch: cf.SwitchOp, rewriter: PatternRewriter, flag: int)

Helper for folding a switch whose flag has a known constant value: `switch %c_42 : i32, [ default: ^bb1, 42: ^bb2, 43: ^bb3 ]` -> `br ^bb2`

Source code in xdsl/transforms/canonicalization_patterns/cf.py
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
def fold_switch(switch: cf.SwitchOp, rewriter: PatternRewriter, flag: int):
    """
    Replace `switch` with an unconditional branch, given that its flag is
    known to equal `flag`. The target is the first case whose value equals
    `flag`; if no case matches, it is the default destination.

    switch %c_42 : i32, [
      default: ^bb1 ,
      42: ^bb2,
      43: ^bb3
    ]
    -> br ^bb2
    """
    if switch.case_values is None:
        case_values = ()
    else:
        case_values = switch.case_values.get_attrs()

    target, target_args = switch.default_block, switch.default_operands
    cases = zip(case_values, switch.case_blocks, switch.case_operand, strict=True)
    for value_attr, dest, args in cases:
        if flag == value_attr.value.data:
            target, target_args = dest, args
            break

    rewriter.replace_op(switch, cf.BranchOp(target, *target_args))