Skip to content

Apply eqsat pdl interp

apply_eqsat_pdl_interp

EqsatConstraintFunctions dataclass

Bases: InterpreterFunctions

Source code in `xdsl/transforms/apply_eqsat_pdl_interp.py` (lines 28–35)
@register_impls
class EqsatConstraintFunctions(InterpreterFunctions):
    """Interpreter implementations of external constraints used by eqsat patterns."""

    @impl_external("is_not_unsound")
    def run_is_not_unsound(
        self, interp: Interpreter, _op: Operation, args: PythonValues
    ):
        """Implement the ``is_not_unsound`` external constraint.

        Args:
            interp: The running interpreter (unused here).
            _op: Unused here.
            args: Constraint arguments; ``args[0]`` must be an ``Operation``.

        Returns:
            A pair whose first element is ``True`` when ``args[0]`` has no
            ``"unsound"`` attribute (``False`` otherwise), and whose second
            element is an empty tuple.
        """
        # Bind outside the assert: under ``python -O`` assert statements are
        # removed entirely, so a walrus inside one would leave ``op`` unbound
        # and the return line would raise NameError.
        op = args[0]
        assert isinstance(op, Operation)
        return "unsound" not in op.attributes, ()

run_is_not_unsound(interp: Interpreter, _op: Operation, args: PythonValues)

Source code in `xdsl/transforms/apply_eqsat_pdl_interp.py` (lines 30–35)
@impl_external("is_not_unsound")
def run_is_not_unsound(
    self, interp: Interpreter, _op: Operation, args: PythonValues
):
    """Implement the ``is_not_unsound`` external constraint.

    Args:
        interp: The running interpreter (unused here).
        _op: Unused here.
        args: Constraint arguments; ``args[0]`` must be an ``Operation``.

    Returns:
        A pair whose first element is ``True`` when ``args[0]`` has no
        ``"unsound"`` attribute (``False`` otherwise), and whose second
        element is an empty tuple.
    """
    # Bind outside the assert: under ``python -O`` assert statements are
    # removed entirely, so a walrus inside one would leave ``op`` unbound
    # and the return line would raise NameError.
    op = args[0]
    assert isinstance(op, Operation)
    return "unsound" not in op.attributes, ()

ApplyEqsatPDLInterpPass dataclass

Bases: ModulePass

Source code in `xdsl/transforms/apply_eqsat_pdl_interp.py` (lines 83–101)
@dataclass(frozen=True)
class ApplyEqsatPDLInterpPass(ModulePass):
    """Module pass that runs ``apply_eqsat_pdl_interp`` on the input module.

    The pdl_interp matcher module is read from ``pdl_interp_file`` when it is
    set; otherwise the input module itself is used as the matcher module.
    """

    name = "apply-eqsat-pdl-interp"

    # Optional path to a textual module containing the pdl_interp matcher.
    pdl_interp_file: str | None = None
    max_iterations: int = _DEFAULT_MAX_ITERATIONS
    """Maximum number of iterations to run, default 20."""

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """Run equality saturation on ``op``.

        Args:
            ctx: Context used to parse ``pdl_interp_file`` and passed on to
                ``apply_eqsat_pdl_interp``.
            op: The module to transform.

        Raises:
            FileNotFoundError: If ``pdl_interp_file`` is set but does not exist.
        """
        if self.pdl_interp_file is not None:
            # An explicit error instead of ``assert``: asserts are stripped
            # under ``python -O`` and carried no message naming the path.
            if not os.path.exists(self.pdl_interp_file):
                raise FileNotFoundError(
                    f"pdl_interp file not found: {self.pdl_interp_file}"
                )
            with open(self.pdl_interp_file) as f:
                pdl_interp_module_str = f.read()
            # Parse after the file is closed; the parser only needs the text.
            parser = Parser(ctx, pdl_interp_module_str)
            pdl_interp_module = parser.parse_module()
        else:
            pdl_interp_module = op

        apply_eqsat_pdl_interp(op, ctx, pdl_interp_module, self.max_iterations)

name = 'apply-eqsat-pdl-interp' class-attribute instance-attribute

pdl_interp_file: str | None = None class-attribute instance-attribute

max_iterations: int = _DEFAULT_MAX_ITERATIONS class-attribute instance-attribute

Maximum number of iterations to run, default 20.

__init__(pdl_interp_file: str | None = None, max_iterations: int = _DEFAULT_MAX_ITERATIONS) -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in `xdsl/transforms/apply_eqsat_pdl_interp.py` (lines 91–101)
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Run equality saturation on ``op``.

    Args:
        ctx: Context used to parse ``pdl_interp_file`` and passed on to
            ``apply_eqsat_pdl_interp``.
        op: The module to transform.

    Raises:
        FileNotFoundError: If ``pdl_interp_file`` is set but does not exist.
    """
    if self.pdl_interp_file is not None:
        # An explicit error instead of ``assert``: asserts are stripped
        # under ``python -O`` and carried no message naming the path.
        if not os.path.exists(self.pdl_interp_file):
            raise FileNotFoundError(
                f"pdl_interp file not found: {self.pdl_interp_file}"
            )
        with open(self.pdl_interp_file) as f:
            pdl_interp_module_str = f.read()
        # Parse after the file is closed; the parser only needs the text.
        parser = Parser(ctx, pdl_interp_module_str)
        pdl_interp_module = parser.parse_module()
    else:
        pdl_interp_module = op

    apply_eqsat_pdl_interp(op, ctx, pdl_interp_module, self.max_iterations)

apply_eqsat_pdl_interp(op: builtin.ModuleOp, ctx: Context, pdl_interp_module: builtin.ModuleOp, max_iterations: int = _DEFAULT_MAX_ITERATIONS, callback: Callable[[builtin.ModuleOp], None] | None = None)

Source code in `xdsl/transforms/apply_eqsat_pdl_interp.py` (lines 38–80)
def apply_eqsat_pdl_interp(
    op: builtin.ModuleOp,
    ctx: Context,
    pdl_interp_module: builtin.ModuleOp,
    max_iterations: int = _DEFAULT_MAX_ITERATIONS,
    callback: Callable[[builtin.ModuleOp], None] | None = None,
):
    """Run pdl_interp-driven equality saturation over ``op``.

    The ``matcher`` symbol of ``pdl_interp_module`` is executed by an
    ``Interpreter`` while walking ``op``. Each iteration does the following:

    1. Walks the module to register matches.
    2. Executes the rewrites aggregated during matching.
    3. Stops if the eqsat worklist is empty.
    4. Otherwise rebuilds and invokes ``callback`` (if any) on ``op``.

    Args:
        op: The module being rewritten.
        ctx: Context installed on the interpreter for the PDL functions.
        pdl_interp_module: Module that must define a ``pdl_interp.FuncOp``
            named ``matcher``.
        max_iterations: Upper bound on the number of match/rewrite rounds.
        callback: Optional hook called with ``op`` after each rebuild.
    """
    matcher = SymbolTable.lookup_symbol(pdl_interp_module, "matcher")
    # Check for a missing symbol first so that case reports its message;
    # the isinstance check alone would also reject ``None`` but silently.
    assert matcher is not None, "matcher function not found"
    assert isinstance(matcher, pdl_interp.FuncOp), (
        "matcher symbol is not a pdl_interp.FuncOp"
    )

    # Initialize interpreter and implementations once
    interpreter = Interpreter(pdl_interp_module)
    pdl_interp_functions = PDLInterpFunctions()
    eqsat_pdl_interp_functions = EqsatPDLInterpFunctions()
    PDLInterpFunctions.set_ctx(interpreter, ctx)
    eqsat_pdl_interp_functions.populate_known_ops(op)
    interpreter.register_implementations(eqsat_pdl_interp_functions)
    interpreter.register_implementations(pdl_interp_functions)
    interpreter.register_implementations(EqsatConstraintFunctions())
    rewrite_pattern = PDLInterpRewritePattern(
        matcher, interpreter, pdl_interp_functions
    )

    # Forward operation-modification events to the eqsat functions.
    listener = PatternRewriterListener()
    listener.operation_modification_handler.append(
        eqsat_pdl_interp_functions.modification_handler
    )
    walker = PatternRewriteWalker(rewrite_pattern, apply_recursively=False)
    walker.listener = listener

    for _i in range(max_iterations):
        # Register matches by walking the module
        walker.rewrite_module(op)
        # Execute all pending rewrites that were aggregated during matching
        eqsat_pdl_interp_functions.execute_pending_rewrites(interpreter)

        # Nothing left to process: saturation reached before the limit.
        if not eqsat_pdl_interp_functions.worklist:
            break

        eqsat_pdl_interp_functions.rebuild(interpreter)
        if callback is not None:
            callback(op)