Skip to content

Lift arith to linalg

lift_arith_to_linalg

LiftAddfPass

Bases: RewritePattern

Source code in xdsl/transforms/lift_arith_to_linalg.py
18
19
20
21
22
23
24
class LiftAddfPass(RewritePattern):
    """Rewrite pattern that turns tensor-typed `arith.AddfOp` into `linalg.AddOp`."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.AddfOp, rewriter: PatternRewriter, /):
        """
        Replace `op` with a `linalg.AddOp` when its result type is a tensor.

        Ops whose result type is not a `TensorType` are left untouched.
        """
        result_type = op.result.type
        if not isa(result_type, TensorType[Attribute]):
            return

        # `op.lhs` is handed over as the second argument (presumably the
        # destination operand of the linalg op -- confirm against linalg.AddOp).
        lifted = linalg.AddOp(op.operands, [op.lhs], [result_type])
        rewriter.replace_op(op, lifted)

match_and_rewrite(op: arith.AddfOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lift_arith_to_linalg.py
19
20
21
22
23
24
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.AddfOp, rewriter: PatternRewriter, /):
    """
    Replace `op` with a `linalg.AddOp` when its result type is a tensor.

    Ops whose result type is not a `TensorType` are left untouched.
    """
    result_type = op.result.type
    if not isa(result_type, TensorType[Attribute]):
        return

    # `op.lhs` is handed over as the second argument (presumably the
    # destination operand of the linalg op -- confirm against linalg.AddOp).
    lifted = linalg.AddOp(op.operands, [op.lhs], [result_type])
    rewriter.replace_op(op, lifted)

LiftSubfPass

Bases: RewritePattern

Source code in xdsl/transforms/lift_arith_to_linalg.py
27
28
29
30
31
32
33
class LiftSubfPass(RewritePattern):
    """Rewrite pattern that turns tensor-typed `arith.SubfOp` into `linalg.SubOp`."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.SubfOp, rewriter: PatternRewriter, /):
        """
        Replace `op` with a `linalg.SubOp` when its result type is a tensor.

        Ops whose result type is not a `TensorType` are left untouched.
        """
        result_type = op.result.type
        if not isa(result_type, TensorType[Attribute]):
            return

        # `op.lhs` is handed over as the second argument (presumably the
        # destination operand of the linalg op -- confirm against linalg.SubOp).
        lifted = linalg.SubOp(op.operands, [op.lhs], [result_type])
        rewriter.replace_op(op, lifted)

match_and_rewrite(op: arith.SubfOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lift_arith_to_linalg.py
28
29
30
31
32
33
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.SubfOp, rewriter: PatternRewriter, /):
    """
    Replace `op` with a `linalg.SubOp` when its result type is a tensor.

    Ops whose result type is not a `TensorType` are left untouched.
    """
    result_type = op.result.type
    if not isa(result_type, TensorType[Attribute]):
        return

    # `op.lhs` is handed over as the second argument (presumably the
    # destination operand of the linalg op -- confirm against linalg.SubOp).
    lifted = linalg.SubOp(op.operands, [op.lhs], [result_type])
    rewriter.replace_op(op, lifted)

LiftMulfPass

Bases: RewritePattern

Source code in xdsl/transforms/lift_arith_to_linalg.py
36
37
38
39
40
41
42
class LiftMulfPass(RewritePattern):
    """Rewrite pattern that turns tensor-typed `arith.MulfOp` into `linalg.MulOp`."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: arith.MulfOp, rewriter: PatternRewriter, /):
        """
        Replace `op` with a `linalg.MulOp` when its result type is a tensor.

        Ops whose result type is not a `TensorType` are left untouched.
        """
        result_type = op.result.type
        if not isa(result_type, TensorType[Attribute]):
            return

        # `op.lhs` is handed over as the second argument (presumably the
        # destination operand of the linalg op -- confirm against linalg.MulOp).
        lifted = linalg.MulOp(op.operands, [op.lhs], [result_type])
        rewriter.replace_op(op, lifted)

match_and_rewrite(op: arith.MulfOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/lift_arith_to_linalg.py
37
38
39
40
41
42
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.MulfOp, rewriter: PatternRewriter, /):
    """
    Replace `op` with a `linalg.MulOp` when its result type is a tensor.

    Ops whose result type is not a `TensorType` are left untouched.
    """
    result_type = op.result.type
    if not isa(result_type, TensorType[Attribute]):
        return

    # `op.lhs` is handed over as the second argument (presumably the
    # destination operand of the linalg op -- confirm against linalg.MulOp).
    lifted = linalg.MulOp(op.operands, [op.lhs], [result_type])
    rewriter.replace_op(op, lifted)

LiftArithToLinalg dataclass

Bases: ModulePass

Pass that lifts arith ops to linalg in order to make use of destination-passing style and bufferization.

Source code in xdsl/transforms/lift_arith_to_linalg.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
@dataclass(frozen=True)
class LiftArithToLinalg(ModulePass):
    """
    Module pass that lifts arith ops to their linalg counterparts, so that
    destination-passing style and bufferization can be used.
    """

    name = "lift-arith-to-linalg"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """
        Run the addf/subf/mulf lifting patterns over the module `op`.

        The walker is configured with `walk_reverse=False` and
        `apply_recursively=False`. `ctx` is not used here.
        """
        lifting_patterns = [
            LiftAddfPass(),
            LiftSubfPass(),
            LiftMulfPass(),
        ]
        applier = GreedyRewritePatternApplier(lifting_patterns)
        walker = PatternRewriteWalker(
            applier,
            walk_reverse=False,
            apply_recursively=False,
        )
        walker.rewrite_module(op)

name = 'lift-arith-to-linalg' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/lift_arith_to_linalg.py
53
54
55
56
57
58
59
60
61
62
63
64
65
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """
    Run the addf/subf/mulf lifting patterns over the module `op`.

    The walker is configured with `walk_reverse=False` and
    `apply_recursively=False`. `ctx` is not used here.
    """
    lifting_patterns = [
        LiftAddfPass(),
        LiftSubfPass(),
        LiftMulfPass(),
    ]
    applier = GreedyRewritePatternApplier(lifting_patterns)
    walker = PatternRewriteWalker(
        applier,
        walk_reverse=False,
        apply_recursively=False,
    )
    walker.rewrite_module(op)