Skip to content

Eqsat create eclasses

eqsat_create_eclasses

InsertEclassOps

Bases: RewritePattern

Inserts an equivalence.class op after each operation, except module and function ops.

Source code in xdsl/transforms/eqsat_create_eclasses.py
49
50
51
52
53
54
55
56
class InsertEclassOps(RewritePattern):
    """
    Rewrite pattern that seeds e-classes inside a `func.func`.

    For a matched function, the ops and arguments of its body block
    (`op.body.block`) are wrapped in `equivalence.class` ops by
    `insert_eclass_ops`; `func.return` ops are left unwrapped.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
        # op_type_rewrite_pattern is expected to restrict matching to the
        # annotated op type (func.FuncOp).
        body_block = op.body.block
        insert_eclass_ops(body_block, rewriter)

match_and_rewrite(op: func.FuncOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/eqsat_create_eclasses.py
54
55
56
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
    """Wrap the body block of the matched `func.func` in `equivalence.class` ops."""
    body_block = op.body.block
    insert_eclass_ops(body_block, rewriter)

EqsatCreateEclassesPass dataclass

Bases: ModulePass

Create initial eclasses from an MLIR program.

Input example
func.func @test(%a : index, %b : index) -> (index) {
    %c = arith.addi %a, %b : index
    func.return %c : index
}

Output example
func.func @test(%a : index, %b : index) -> (index) {
    %a_eq = equivalence.class %a : index
    %b_eq = equivalence.class %b : index
    %c = arith.addi %a_eq, %b_eq : index
    %c_eq = equivalence.class %c : index
    func.return %c_eq : index
}

Source code in xdsl/transforms/eqsat_create_eclasses.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class EqsatCreateEclassesPass(ModulePass):
    """
    Create initial eclasses from an MLIR program.

    Every `func.func` in the module has its body block's results and
    arguments wrapped in `equivalence.class` ops, and their uses are
    redirected to the eclass values.

    Input example:
        ```mlir
        func.func @test(%a : index, %b : index) -> (index) {
            %c = arith.addi %a, %b : index
            func.return %c : index
        }
        ```

    Output example (illustrative; SSA names and the relative order of the
    argument eclass ops may differ in the printed IR, since each argument's
    eclass op is inserted at the start of the block):
        ```mlir
        func.func @test(%a : index, %b : index) -> (index) {
            %a_eq = equivalence.class %a : index
            %b_eq = equivalence.class %b : index
            %c = arith.addi %a_eq, %b_eq : index
            %c_eq = equivalence.class %c : index
            func.return %c_eq : index
        }
        ```
    """

    name = "eqsat-create-eclasses"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """Walk `op` once with the `InsertEclassOps` pattern."""
        pattern_applier = GreedyRewritePatternApplier([InsertEclassOps()])
        # apply_recursively=False: the pattern is not meant to be re-run on
        # the IR it has just produced.
        walker = PatternRewriteWalker(pattern_applier, apply_recursively=False)
        walker.rewrite_module(op)

name = 'eqsat-create-eclasses' class-attribute instance-attribute

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/eqsat_create_eclasses.py
84
85
86
87
88
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Walk `op` once with the `InsertEclassOps` pattern."""
    pattern_applier = GreedyRewritePatternApplier([InsertEclassOps()])
    walker = PatternRewriteWalker(pattern_applier, apply_recursively=False)
    walker.rewrite_module(op)

insert_eclass_ops(block: Block, rewriter: PatternRewriter)

Source code in xdsl/transforms/eqsat_create_eclasses.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def insert_eclass_ops(block: Block, rewriter: PatternRewriter):
    """
    Wrap the op results and arguments of `block` in `equivalence.class` ops.

    For every operation directly in `block` (nested regions are not visited),
    an `equivalence.class` op is inserted right after it. All uses of the
    op's result, except uses by `equivalence.class` ops, are then redirected
    to the eclass result. `func.return` ops are skipped. Afterwards, each
    block argument is wrapped the same way, with its eclass op inserted at
    the start of the block.

    Args:
        block: The block whose operations and arguments are wrapped.
        rewriter: The pattern rewriter used for all insertions and use
            replacements, so that every change is made through it.

    Raises:
        DiagnosticException: If an op other than `func.return` does not have
            exactly one result.
    """
    # Insert equivalence.class for each operation.
    # NOTE(review): ops are inserted while iterating `block.ops`. This relies
    # on the block's op iterator having already advanced past `op`, so the
    # freshly inserted eclass op is not visited itself. Confirm against Block.ops.
    for op in block.ops:
        # Skip special ops such as return ops
        if isinstance(op, func.ReturnOp):
            continue

        results = op.results
        if len(results) != 1:
            raise DiagnosticException("Ops with non-single results not handled")

        eclass_op = equivalence.ClassOp(results[0])
        # Insert through the PatternRewriter rather than the static Rewriter,
        # so the insertion goes through the same rewriter as the use
        # replacement below.
        rewriter.insert_op(eclass_op, InsertPoint.after(op))
        # Uses by equivalence.class ops (including the one just created, whose
        # operand is this value) must keep the original value; otherwise the
        # new eclass op would consume its own result.
        rewriter.replace_uses_with_if(
            results[0],
            eclass_op.results[0],
            lambda u: not isinstance(u.operation, equivalence.ClassOp),
        )

    # Insert equivalence.class for each arg.
    # NOTE(review): every arg's eclass op goes to the very start of the block,
    # so these ops end up in reverse argument order. The pass docstring's
    # example shows argument order; confirm which order is intended.
    for arg in block.args:
        eclass_op = equivalence.ClassOp(arg)
        rewriter.insert_op(eclass_op, InsertPoint.at_start(block))
        rewriter.replace_uses_with_if(
            arg,
            eclass_op.results[0],
            lambda u: not isinstance(u.operation, equivalence.ClassOp),
        )