Skip to content

Eqsat extract

eqsat_extract

EqsatExtractPass dataclass

Bases: ModulePass

Extracts the subprogram with the lowest cost, as specified by the `min_cost_index` attribute.

Source code in `xdsl/transforms/eqsat_extract.py` (lines 49–57)
class EqsatExtractPass(ModulePass):
    """
    Extracts the subprogram with the lowest cost, as specified by the `min_cost_index`
    of each e-class operation.

    The pass has no options: `apply` simply hands the module to `eqsat_extract`,
    which rewrites it in place.
    """

    name = "eqsat-extract"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # `ctx` is unused; extraction only needs the module itself.
        eqsat_extract(op)

name = 'eqsat-extract' (class attribute, instance attribute)

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in `xdsl/transforms/eqsat_extract.py` (lines 56–57)
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """
    Run e-class extraction on `op` in place by delegating to `eqsat_extract`.

    `ctx` is not used.
    """
    eqsat_extract(op)

eqsat_extract(module_op: builtin.ModuleOp)

Source code in `xdsl/transforms/eqsat_extract.py` (lines 8–46)
def _eclass_ops_to_erase(eclass, skip_index=None, kept_owner=None):
    """
    Return `eclass` followed by the distinct ops that define its operands.

    The operand at `skip_index` is ignored, as is any operand whose defining op
    is `kept_owner`. Block arguments have no defining op and are skipped. Every
    op appears at most once (compared by identity), so no op is erased twice.
    The e-class comes first so that erasing it drops its uses of the operands
    before their defining ops are erased.
    """
    ops = [eclass]
    seen = {id(eclass)}
    if kept_owner is not None:
        seen.add(id(kept_owner))
    for index, value in enumerate(eclass.operands):
        if index == skip_index or not isinstance(value, OpResult):
            continue
        owner = value.owner
        if id(owner) not in seen:
            seen.add(id(owner))
            ops.append(owner)
    return ops


def eqsat_extract(module_op: builtin.ModuleOp):
    """
    Collapse the e-class operations in `module_op`, modifying it in place.

    Every `equivalence.AnyClassOp` found by walking `module_op` is visited in
    reverse walk order:

    * if its result has no uses, the e-class is erased together with the ops
      defining its operands;
    * otherwise, if it has a `min_cost_index`, all uses of its result (other than
      by the e-class itself) are redirected to the selected operand, that
      operand's `equivalence.EQSAT_COST_LABEL` attribute is removed, and the
      e-class plus the ops defining its other operands are erased;
    * otherwise the e-class and its operands are left untouched.
    """
    eclass_ops = [
        op for op in module_op.walk() if isinstance(op, equivalence.AnyClassOp)
    ]

    # Popping from the end visits e-classes in reverse walk order.
    while eclass_ops:
        eclass = eclass_ops.pop()

        if not eclass.result.uses:
            # Dead e-class: erase it and the ops defining its operands.
            ops_to_erase = _eclass_ops_to_erase(eclass)

        elif (min_cost_index := eclass.min_cost_index) is not None:
            chosen_index = min_cost_index.data
            chosen = eclass.operands[chosen_index]

            # Replace the e-class result by the selected operand.
            eclass.result.replace_uses_with_if(
                chosen, lambda use: use.operation is not eclass
            )

            # The selected operand's defining op now carries the redirected uses,
            # so it must survive even if it also defines a non-selected operand.
            kept_owner = chosen.owner if isinstance(chosen, OpResult) else None
            ops_to_erase = _eclass_ops_to_erase(
                eclass, skip_index=chosen_index, kept_owner=kept_owner
            )

            # The cost annotation is no longer meaningful once extracted.
            if (
                isinstance(chosen, OpResult)
                and equivalence.EQSAT_COST_LABEL in chosen.op.attributes
            ):
                del chosen.op.attributes[equivalence.EQSAT_COST_LABEL]

        else:
            # No cost selected yet: don't touch this e-class or its operands.
            ops_to_erase = ()

        for dead_op in ops_to_erase:
            Rewriter.erase_op(dead_op)