Skip to content

Apply eqsat pdl

apply_eqsat_pdl

ApplyEqsatPDLPass dataclass

Bases: ModulePass

A pass that applies PDL patterns using equality saturation.

Source code in xdsl/transforms/apply_eqsat_pdl.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
@dataclass(frozen=True)
class ApplyEqsatPDLPass(ModulePass):
    """
    A pass that applies PDL patterns using equality saturation.

    The patterns are lowered from `pdl` to `pdl_interp` with
    `mlir-opt --convert-pdl-to-pdl-interp`, then converted to the eqsat flavour
    of `pdl_interp`, and finally interpreted against the input module for at
    most `max_iterations` rounds. Each round matches the patterns and executes
    the pending rewrites. The loop stops early once the worklist is empty;
    otherwise `rebuild` is called before the next round.
    """

    name = "apply-eqsat-pdl"

    pdl_file: str | None = None
    """Path to an external PDL file containing the patterns. If None, the patterns are taken from the input module."""

    max_iterations: int = 20
    """Maximum number of iterations to run the equality saturation algorithm."""

    individual_patterns: bool = False
    """
    Whether to convert and apply patterns individually rather than all together.

    When True, each pattern is converted to `pdl_interp` separately and applied
    individually in each iteration.

    When False (default), all patterns are converted together, which can produce
    a more efficient matcher by reusing equivalent expressions.
    """

    optimize_matcher: bool = False
    """When enabled, the matcher is optimized to evaluate equality constraints early."""
    # NOTE(review): no method of this class reads `optimize_matcher`, so setting
    # it has no effect in the code visible here; confirm whether it should be
    # forwarded to the matcher conversion.

    def _load_pdl_module(self, ctx: Context, op: builtin.ModuleOp) -> builtin.ModuleOp:
        """
        Return the module that holds the PDL patterns.

        If `pdl_file` is set, that file is read and parsed into a new module.
        Otherwise the input module `op` itself is returned. Note that
        `_apply_combined_patterns` lowers the returned module in place.

        Raises:
            FileNotFoundError: if `pdl_file` is set but the path does not exist.
        """
        if self.pdl_file is None:
            return op
        # An explicit check instead of `assert`: the path is user input, and
        # asserts are stripped when Python runs with `-O`.
        if not os.path.exists(self.pdl_file):
            raise FileNotFoundError(f"PDL file not found: {self.pdl_file}")
        with open(self.pdl_file, encoding="utf-8") as f:
            pdl_module_str = f.read()
        # Parse after the file handle is closed; the parser only needs the text.
        return Parser(ctx, pdl_module_str).parse_module()

    @staticmethod
    def _lower_to_eqsat_pdl_interp(ctx: Context, module: builtin.ModuleOp) -> None:
        """
        Lower the `pdl` patterns in `module` to eqsat `pdl_interp`, in place.

        This first runs `mlir-opt --convert-pdl-to-pdl-interp`, then
        `ConvertPDLInterpToEqsatPDLInterpPass`.
        """
        pdl_to_pdl_interp = MLIROptPass(
            arguments=("--convert-pdl-to-pdl-interp", "-allow-unregistered-dialect")
        )
        pdl_to_pdl_interp.apply(ctx, module)
        ConvertPDLInterpToEqsatPDLInterpPass().apply(ctx, module)

    def _convert_single_pattern(
        self, ctx: Context, pattern_op: pdl.PatternOp
    ) -> builtin.ModuleOp:
        """
        Convert one PDL pattern to eqsat `pdl_interp`.

        The pattern is cloned into a fresh module first, so `pattern_op` and
        its parent module are left untouched. That module, after lowering, is
        returned.
        """
        temp_module = builtin.ModuleOp([pattern_op.clone()])
        self._lower_to_eqsat_pdl_interp(ctx, temp_module)
        return temp_module

    def _extract_matcher_and_rewriters(
        self, temp_module: builtin.ModuleOp
    ) -> tuple[pdl_interp.FuncOp, pdl_interp.FuncOp]:
        """
        Extract the matcher and rewriter functions from a converted module.

        Returns:
            The `matcher` function, and the first op of the nested `rewriters`
            module (checked to be a `pdl_interp.FuncOp`).
        """
        matcher = SymbolTable.lookup_symbol(temp_module, "matcher")
        # `isinstance` also rejects None, so this single check carries the message.
        assert isinstance(matcher, pdl_interp.FuncOp), "matcher function not found"

        rewriter_module = SymbolTable.lookup_symbol(temp_module, "rewriters")
        assert isinstance(
            rewriter_module, builtin.ModuleOp
        ), "rewriters module not found"
        rewriter_block = rewriter_module.body.first_block
        assert rewriter_block is not None, "rewriters module has no body block"
        # Only the first rewriter is taken; presumably a single converted
        # pattern yields exactly one rewriter function — TODO confirm.
        rewriter_func = rewriter_block.first_op
        assert isinstance(
            rewriter_func, pdl_interp.FuncOp
        ), "rewriter function not found"

        return matcher, rewriter_func

    def _apply_individual_patterns(
        self, ctx: Context, op: builtin.ModuleOp, pdl_module: builtin.ModuleOp
    ) -> None:
        """
        Convert each pattern separately and apply them one by one every round.

        All converted matchers are moved into one module, with their rewriters
        in a nested `rewriters` module. That module backs a single shared
        interpreter.

        Each round runs one walker per pattern over `op` and then executes the
        pending rewrites. It stops once the worklist is empty; otherwise it
        calls `rebuild` before the next round.
        """
        patterns = (
            pattern for pattern in pdl_module.ops if isinstance(pattern, pdl.PatternOp)
        )

        implementations = EqsatPDLInterpFunctions()
        implementations.populate_known_ops(op)

        matchers_module = builtin.ModuleOp([])
        rewriters_module = builtin.ModuleOp([], sym_name=StringAttr("rewriters"))
        matchers_builder = Builder(InsertPoint.at_end(matchers_module.body.block))
        matchers_builder.insert_op(rewriters_module)
        rewriters_builder = Builder(InsertPoint.at_end(rewriters_module.body.block))

        interpreter = Interpreter(matchers_module)
        PDLInterpFunctions.set_ctx(interpreter, ctx)
        interpreter.register_implementations(implementations)
        vanilla_pdl_interp_implementations = PDLInterpFunctions()
        interpreter.register_implementations(vanilla_pdl_interp_implementations)
        interpreter.register_implementations(EqsatConstraintFunctions())

        rewrite_patterns: list[PDLInterpRewritePattern] = []
        for pattern_op in patterns:
            temp_module = self._convert_single_pattern(ctx, pattern_op)
            matcher, rewriter_func = self._extract_matcher_and_rewriters(temp_module)

            assert matcher.body.last_block is not None
            assert isinstance(
                recordmatch := matcher.body.last_block.last_op,
                eqsat_pdl_interp.RecordMatchOp,
            )
            # Unnamed patterns get a positional name so each rewriter has a
            # distinct symbol in the shared `rewriters` module.
            # NOTE(review): a user pattern literally named `pattern_<k>` could
            # collide with a generated name.
            name = (
                pattern_op.sym_name
                if pattern_op.sym_name
                else StringAttr(f"pattern_{len(rewrite_patterns)}")
            )
            # Point the matcher's record at the renamed rewriter.
            recordmatch.rewriter = builtin.SymbolRefAttr("rewriters", (name,))
            rewriter_func.sym_name = name

            # Move the matcher and rewriter out of the temporary module into
            # the shared modules backing the interpreter.
            matcher.detach()
            matchers_builder.insert_op(matcher)

            rewriter_func.detach()
            rewriters_builder.insert_op(rewriter_func)

            rewrite_patterns.append(
                PDLInterpRewritePattern(
                    matcher, interpreter, vanilla_pdl_interp_implementations, name.data
                )
            )

        # Forward operation modifications to the eqsat implementations.
        listener = PatternRewriterListener()
        listener.operation_modification_handler.append(
            implementations.modification_handler
        )

        for _i in range(self.max_iterations):
            # Match every pattern individually before any rewrite is executed.
            for rewrite_pattern in rewrite_patterns:
                assert rewrite_pattern.matcher is not None
                walker = PatternRewriteWalker(rewrite_pattern, apply_recursively=False)
                walker.listener = listener
                walker.rewrite_module(op)

            implementations.execute_pending_rewrites(interpreter)

            if not implementations.worklist:
                break

            implementations.rebuild(interpreter)

    def _apply_combined_patterns(
        self, ctx: Context, op: builtin.ModuleOp, pdl_module: builtin.ModuleOp
    ) -> None:
        """
        Convert all patterns together into one matcher and apply it.

        `pdl_module` is lowered in place. When no `pdl_file` is given, this is
        the input module itself.
        """
        self._lower_to_eqsat_pdl_interp(ctx, pdl_module)

        matcher = SymbolTable.lookup_symbol(pdl_module, "matcher")
        # `isinstance` also rejects None, so this single check carries the message.
        assert isinstance(matcher, pdl_interp.FuncOp), "matcher function not found"

        interpreter = Interpreter(pdl_module)
        pdl_interp_functions = PDLInterpFunctions()
        eqsat_pdl_interp_functions = EqsatPDLInterpFunctions()
        PDLInterpFunctions.set_ctx(interpreter, ctx)
        eqsat_pdl_interp_functions.populate_known_ops(op)
        interpreter.register_implementations(eqsat_pdl_interp_functions)
        interpreter.register_implementations(pdl_interp_functions)
        interpreter.register_implementations(EqsatConstraintFunctions())
        rewrite_pattern = PDLInterpRewritePattern(
            matcher, interpreter, pdl_interp_functions
        )

        # Forward operation modifications to the eqsat implementations.
        listener = PatternRewriterListener()
        listener.operation_modification_handler.append(
            eqsat_pdl_interp_functions.modification_handler
        )
        walker = PatternRewriteWalker(rewrite_pattern, apply_recursively=False)
        walker.listener = listener

        for _i in range(self.max_iterations):
            walker.rewrite_module(op)
            eqsat_pdl_interp_functions.execute_pending_rewrites(interpreter)

            if not eqsat_pdl_interp_functions.worklist:
                break

            eqsat_pdl_interp_functions.rebuild(interpreter)

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """Load the patterns and apply them to `op` in the configured mode."""
        pdl_module = self._load_pdl_module(ctx, op)

        if self.individual_patterns:
            self._apply_individual_patterns(ctx, op, pdl_module)
        else:
            self._apply_combined_patterns(ctx, op, pdl_module)

name = 'apply-eqsat-pdl' class-attribute instance-attribute

pdl_file: str | None = None class-attribute instance-attribute

Path to an external PDL file containing the patterns. If None, the patterns are taken from the input module.

max_iterations: int = 20 class-attribute instance-attribute

Maximum number of iterations to run the equality saturation algorithm.

individual_patterns: bool = False class-attribute instance-attribute

Whether to convert and apply patterns individually rather than all together.

When True, each pattern is converted to `pdl_interp` separately and applied individually in each iteration.

When False (default), all patterns are converted together, which can produce a more efficient matcher by reusing equivalent expressions.

optimize_matcher: bool = False class-attribute instance-attribute

When enabled, the matcher is optimized to evaluate equality constraints early.

__init__(pdl_file: str | None = None, max_iterations: int = 20, individual_patterns: bool = False, optimize_matcher: bool = False) -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/apply_eqsat_pdl.py
217
218
219
220
221
222
223
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """
    Run the pass on `op`.

    The PDL patterns are loaded first, either from `pdl_file` or from `op`
    itself. They are then handed to the per-pattern driver when
    `individual_patterns` is set, and to the combined driver otherwise.
    """
    patterns_module = self._load_pdl_module(ctx, op)
    driver = (
        self._apply_individual_patterns
        if self.individual_patterns
        else self._apply_combined_patterns
    )
    driver(ctx, op, patterns_module)