Skip to content

Individual rewrite

individual_rewrite

ApplyIndividualRewritePass dataclass

Bases: ModulePass

Module pass representing the application of an individual rewrite pattern to a module.

Matches the operation at the provided index within the module and applies the rewrite pattern specified by the operation and pattern names.

Source code in xdsl/transforms/individual_rewrite.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@dataclass(frozen=True)
class ApplyIndividualRewritePass(ModulePass):
    """
    Module pass that applies a single named rewrite pattern to a single operation.

    The target operation is resolved through `OpSelector` from
    `matched_operation_index` and `operation_name`; the pattern is then looked up
    by class name among the canonicalization patterns exposed by that operation.
    """

    name = "apply-individual-rewrite"

    # Index handed to `OpSelector`; `schedule_space` produces it by enumerating
    # `module_op.walk()`.
    matched_operation_index: int = field()
    # Operation name handed to `OpSelector` together with the index.
    operation_name: str = field()
    # Class name (`type(pattern).__name__`) of the pattern to apply.
    pattern_name: str = field()

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """
        Apply the configured pattern to the selected operation of `op`.

        Args:
            ctx: Not used by this method.
            op: Module in which the target operation is looked up and rewritten.

        Raises:
            ValueError: If no canonicalization pattern of the selected operation
                has the class name `pattern_name`, or if the pattern was found but
                the rewriter recorded no action after running it.
        """
        selector = OpSelector(self.matched_operation_index, self.operation_name)
        matched_operation = selector.get_op(op)
        rewriter = PatternRewriter(matched_operation)

        # Lazily scan trait by trait so that only the traits preceding the first
        # match are asked for their patterns.
        candidates = (
            candidate
            for trait in matched_operation.get_traits_of_type(
                HasCanonicalizationPatternsTrait
            )
            for candidate in trait.get_canonicalization_patterns()
            if type(candidate).__name__ == self.pattern_name
        )
        chosen = next(candidates, None)

        if chosen is None:
            raise ValueError(
                f"Pattern name {self.pattern_name} not found for the provided operation name."
            )

        chosen.match_and_rewrite(matched_operation, rewriter)
        if rewriter.has_done_action:
            return

        # The pattern exists but did not rewrite anything for this operation.
        raise ValueError(
            f"Invalid rewrite ({self.pattern_name}) for operation "
            f"({matched_operation}) at location "
            f"{self.matched_operation_index}."
        )

    @classmethod
    def schedule_space(cls, ctx: Context, module_op: ModuleOp):
        """
        Enumerate the individual rewrites that take effect on `module_op`.

        Every operation yielded by `module_op.walk()` that carries
        `HasCanonicalizationPatternsTrait` is tried against each of its patterns
        (keyed by pattern class name). Each trial runs on a fresh clone of the
        module, so `module_op` itself is not rewritten here.

        Args:
            ctx: Not used by this method.
            module_op: Module whose operations are probed.

        Returns:
            A tuple of `ApplyIndividualRewritePass` instances, one for every
            (operation index, pattern name) pair whose rewriter reported an action.
        """
        found: list[ApplyIndividualRewritePass] = []

        for index, candidate_op in enumerate(module_op.walk()):
            trait = candidate_op.get_trait(HasCanonicalizationPatternsTrait)
            if trait is None:
                continue

            # Key patterns by class name; a later pattern with the same class
            # name replaces an earlier one.
            patterns: dict = {}
            for pattern in trait.get_canonicalization_patterns():
                patterns[type(pattern).__name__] = pattern

            selector = OpSelector(index, candidate_op.name)

            for pattern_name in patterns:
                # Probe on a clone so a successful rewrite leaves the input intact.
                cloned_op = selector.get_op(module_op.clone())
                rewriter = PatternRewriter(cloned_op)
                patterns[pattern_name].match_and_rewrite(cloned_op, rewriter)
                if not rewriter.has_done_action:
                    continue
                found.append(
                    ApplyIndividualRewritePass(
                        matched_operation_index=index,
                        operation_name=cloned_op.name,
                        pattern_name=pattern_name,
                    )
                )

        return tuple(found)

name = 'apply-individual-rewrite' class-attribute instance-attribute

matched_operation_index: int = field() class-attribute instance-attribute

operation_name: str = field() class-attribute instance-attribute

pattern_name: str = field() class-attribute instance-attribute

__init__(matched_operation_index: int = field(), operation_name: str = field(), pattern_name: str = field()) -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/individual_rewrite.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """
    Apply the configured pattern to the selected operation of `op`.

    Args:
        ctx: Not used by this method.
        op: Module in which the target operation is looked up and rewritten.

    Raises:
        ValueError: If no canonicalization pattern of the selected operation has
            the class name `pattern_name`, or if the pattern was found but the
            rewriter recorded no action after running it.
    """
    selector = OpSelector(self.matched_operation_index, self.operation_name)
    matched_operation = selector.get_op(op)
    rewriter = PatternRewriter(matched_operation)

    # Lazily scan trait by trait so that only the traits preceding the first
    # match are asked for their patterns.
    candidates = (
        candidate
        for trait in matched_operation.get_traits_of_type(
            HasCanonicalizationPatternsTrait
        )
        for candidate in trait.get_canonicalization_patterns()
        if type(candidate).__name__ == self.pattern_name
    )
    chosen = next(candidates, None)

    if chosen is None:
        raise ValueError(
            f"Pattern name {self.pattern_name} not found for the provided operation name."
        )

    chosen.match_and_rewrite(matched_operation, rewriter)
    if rewriter.has_done_action:
        return

    # The pattern exists but did not rewrite anything for this operation.
    raise ValueError(
        f"Invalid rewrite ({self.pattern_name}) for operation "
        f"({matched_operation}) at location "
        f"{self.matched_operation_index}."
    )

schedule_space(ctx: Context, module_op: ModuleOp) classmethod

Source code in xdsl/transforms/individual_rewrite.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@classmethod
def schedule_space(cls, ctx: Context, module_op: ModuleOp):
    """
    Enumerate the individual rewrites that take effect on `module_op`.

    Every operation yielded by `module_op.walk()` that carries
    `HasCanonicalizationPatternsTrait` is tried against each of its patterns
    (keyed by pattern class name). Each trial runs on a fresh clone of the
    module, so `module_op` itself is not rewritten here.

    Args:
        ctx: Not used by this method.
        module_op: Module whose operations are probed.

    Returns:
        A tuple of `ApplyIndividualRewritePass` instances, one for every
        (operation index, pattern name) pair whose rewriter reported an action.
    """
    found: list[ApplyIndividualRewritePass] = []

    for index, candidate_op in enumerate(module_op.walk()):
        trait = candidate_op.get_trait(HasCanonicalizationPatternsTrait)
        if trait is None:
            continue

        # Key patterns by class name; a later pattern with the same class name
        # replaces an earlier one.
        patterns: dict = {}
        for pattern in trait.get_canonicalization_patterns():
            patterns[type(pattern).__name__] = pattern

        selector = OpSelector(index, candidate_op.name)

        for pattern_name in patterns:
            # Probe on a clone so a successful rewrite leaves the input intact.
            cloned_op = selector.get_op(module_op.clone())
            rewriter = PatternRewriter(cloned_op)
            patterns[pattern_name].match_and_rewrite(cloned_op, rewriter)
            if not rewriter.has_done_action:
                continue
            found.append(
                ApplyIndividualRewritePass(
                    matched_operation_index=index,
                    operation_name=cloned_op.name,
                    pattern_name=pattern_name,
                )
            )

    return tuple(found)