Skip to content

Linalg transformations

linalg_transformations

FuseMultiplyAddPass dataclass

Bases: RewritePattern

Source code in xdsl/transforms/linalg_transformations.py (lines 42–98)
@dataclass(frozen=True)
class FuseMultiplyAddPass(RewritePattern):
    """
    Rewrite pattern that fuses a `linalg.mul` whose result is read as an input
    by one or more `linalg.add` ops into `linalg.generic` fused multiply-adds.

    Each matching `linalg.add` is replaced in place by its own fused op built by
    `build_generic_fma`; the `linalg.mul` is erased once it has no uses left.
    """

    # If set, only fuse when at least one `mul` operand is a scalar constant
    # (as judged by `is_scalar_constant`).
    require_scalar_factor: bool
    # If set, only fuse when the `mul` result has exactly one user operation,
    # so the `mul` can be erased after fusion.
    require_erasable_mul: bool

    @op_type_rewrite_pattern
    def match_and_rewrite(self, mul: linalg.MulOp, rewriter: PatternRewriter, /):
        """
        Replace every `linalg.add` that consumes `mul`'s result as an input with
        a fused multiply-add `linalg.generic`.

        Nothing is rewritten when:
        - `mul` does not have exactly one result;
        - `require_erasable_mul` is set and the result does not have exactly
          one distinct user operation;
        - `require_scalar_factor` is set and neither `mul` operand is a scalar
          constant.
        """
        if len(mul.res) != 1:
            return
        mul_result = mul.res[0]

        if (
            self.require_erasable_mul
            and len({use.operation for use in mul_result.uses}) != 1
        ):
            return

        # The scalar-factor requirement depends only on `mul`, so check it once
        # instead of once per candidate `add`.
        if (
            self.require_scalar_factor
            and not self.is_scalar_constant(mul.inputs[0])
            and not self.is_scalar_constant(mul.inputs[1])
        ):
            return

        # Distinct `linalg.add` users that read `mul`'s result as an *input*
        # (not as an init/output). Materialized up front because rewriting
        # mutates the use list; `dict.fromkeys` dedupes while keeping use-list
        # order so rewrites happen deterministically.
        adds = list(
            dict.fromkeys(
                use.operation
                for use in mul_result.uses
                if isinstance(use.operation, linalg.AddOp)
                and mul_result in use.operation.inputs
            )
        )

        for add in adds:
            # The operand of `add` that is not the `mul` result (if both are,
            # this is the `mul` result itself).
            add_operand = (
                add.inputs[0] if mul_result == add.inputs[1] else add.inputs[1]
            )

            fma = build_generic_fma(
                mul.inputs[0], mul.inputs[1], add_operand, mul.outputs[0]
            )

            # Replace in position of the add op.
            rewriter.replace_op(add, fma)

        # All fused adds are gone; erase `mul` if nothing else still reads it.
        if adds and not mul_result.uses:
            rewriter.erase_op(mul)

    @staticmethod
    def is_scalar_constant(op: SSAValue) -> bool:
        """
        Return whether the value is a scalar constant.

        Accepted: the result of an `arith.constant` whose payload is an
        `IntegerAttr` or `FloatAttr`, or a `DenseIntOrFPElementsAttr` splat
        (all elements equal). Any other payload is rejected, e.g. a non-splat
        dense tensor, or a resource-backed tensor constant if the constant op
        admits one. This could in the future be extended to recognise
        dynamically provided scalar values expanded via `linalg.fill`.
        """
        from xdsl.dialects.builtin import FloatAttr, IntegerAttr

        if not isinstance(op, OpResult) or not isinstance(op.op, arith.ConstantOp):
            return False
        value = op.op.value
        if isinstance(value, DenseIntOrFPElementsAttr):
            return value.is_splat()
        # Only genuine scalar payloads count; other attribute kinds may encode
        # arbitrary, non-uniform tensor contents.
        return isinstance(value, (IntegerAttr, FloatAttr))

require_scalar_factor: bool instance-attribute

require_erasable_mul: bool instance-attribute

__init__(require_scalar_factor: bool, require_erasable_mul: bool) -> None

match_and_rewrite(mul: linalg.MulOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/linalg_transformations.py (lines 47–83)
@op_type_rewrite_pattern
def match_and_rewrite(self, mul: linalg.MulOp, rewriter: PatternRewriter, /):
    """
    Replace every `linalg.add` that consumes `mul`'s result as an input with a
    fused multiply-add `linalg.generic`.

    Nothing is rewritten when:
    - `mul` does not have exactly one result;
    - `require_erasable_mul` is set and the result does not have exactly one
      distinct user operation;
    - `require_scalar_factor` is set and neither `mul` operand is a scalar
      constant.
    """
    if len(mul.res) != 1:
        return
    mul_result = mul.res[0]

    if (
        self.require_erasable_mul
        and len({use.operation for use in mul_result.uses}) != 1
    ):
        return

    # The scalar-factor requirement depends only on `mul`, so check it once
    # instead of once per candidate `add`.
    if (
        self.require_scalar_factor
        and not self.is_scalar_constant(mul.inputs[0])
        and not self.is_scalar_constant(mul.inputs[1])
    ):
        return

    # Distinct `linalg.add` users that read `mul`'s result as an *input* (not
    # as an init/output). Materialized up front because rewriting mutates the
    # use list; `dict.fromkeys` dedupes while keeping use-list order so
    # rewrites happen deterministically.
    adds = list(
        dict.fromkeys(
            use.operation
            for use in mul_result.uses
            if isinstance(use.operation, linalg.AddOp)
            and mul_result in use.operation.inputs
        )
    )

    for add in adds:
        # The operand of `add` that is not the `mul` result (if both are, this
        # is the `mul` result itself).
        add_operand = (
            add.inputs[0] if mul_result == add.inputs[1] else add.inputs[1]
        )

        fma = build_generic_fma(
            mul.inputs[0], mul.inputs[1], add_operand, mul.outputs[0]
        )

        # Replace in position of the add op.
        rewriter.replace_op(add, fma)

    # All fused adds are gone; erase `mul` if nothing else still reads it.
    if adds and not mul_result.uses:
        rewriter.erase_op(mul)

is_scalar_constant(op: SSAValue) -> bool staticmethod

Returns whether the value is a scalar. This currently checks only for scalar constants, and could in the future be extended to check for dynamically provided scalar values expanded via `linalg.fill`.

Source code in xdsl/transforms/linalg_transformations.py (lines 85–98)
@staticmethod
def is_scalar_constant(op: SSAValue) -> bool:
    """
    Return whether the value is a scalar constant.

    Accepted: the result of an `arith.constant` whose payload is an
    `IntegerAttr` or `FloatAttr`, or a `DenseIntOrFPElementsAttr` splat (all
    elements equal). Any other payload is rejected, e.g. a non-splat dense
    tensor, or a resource-backed tensor constant if the constant op admits one.
    This could in the future be extended to recognise dynamically provided
    scalar values expanded via `linalg.fill`.
    """
    from xdsl.dialects.builtin import FloatAttr, IntegerAttr

    if not isinstance(op, OpResult) or not isinstance(op.op, arith.ConstantOp):
        return False
    value = op.op.value
    if isinstance(value, DenseIntOrFPElementsAttr):
        return value.is_splat()
    # Only genuine scalar payloads count; other attribute kinds may encode
    # arbitrary, non-uniform tensor contents.
    return isinstance(value, (IntegerAttr, FloatAttr))

LinalgFuseMultiplyAddPass dataclass

Bases: ModulePass

Pass that fuses `linalg.mul` and `linalg.add` ops into a single `linalg.generic` fused multiply-add (FMA).

Source code in xdsl/transforms/linalg_transformations.py (lines 101–120)
@dataclass(frozen=True)
class LinalgFuseMultiplyAddPass(ModulePass):
    """
    Pass that fuses linalg multiply and add ops into a `generic` fma.
    """

    name = "linalg-fuse-multiply-add"

    require_scalar_factor: bool = False
    """Set to require one of the mul factors to be a scalar constant"""

    require_erasable_mul: bool = False
    """Set to only fuse ops if the multiply has no other use and can be erased"""

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        # Forward this pass's options to the rewrite pattern, then run a single
        # walk over the module with apply_recursively=False.
        fusion_pattern = FuseMultiplyAddPass(
            require_scalar_factor=self.require_scalar_factor,
            require_erasable_mul=self.require_erasable_mul,
        )
        walker = PatternRewriteWalker(fusion_pattern, apply_recursively=False)
        walker.rewrite_module(op)

name = 'linalg-fuse-multiply-add' class-attribute instance-attribute

require_scalar_factor: bool = False class-attribute instance-attribute

Set to require one of the mul factors to be a scalar constant

require_erasable_mul: bool = False class-attribute instance-attribute

Set to only fuse ops if the multiply has no other use and can be erased

__init__(require_scalar_factor: bool = False, require_erasable_mul: bool = False) -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/linalg_transformations.py (lines 115–120)
def apply(self, ctx: Context, op: ModuleOp) -> None:
    # Forward this pass's options to the rewrite pattern, then run a single
    # walk over the module with apply_recursively=False.
    fusion_pattern = FuseMultiplyAddPass(
        require_scalar_factor=self.require_scalar_factor,
        require_erasable_mul=self.require_erasable_mul,
    )
    walker = PatternRewriteWalker(fusion_pattern, apply_recursively=False)
    walker.rewrite_module(op)

build_generic_fma(mul_op1: SSAValue, mul_op2: SSAValue, add_op: SSAValue, out: SSAValue) -> linalg.GenericOp

Source code in xdsl/transforms/linalg_transformations.py (lines 18–39)
def build_generic_fma(
    mul_op1: SSAValue, mul_op2: SSAValue, add_op: SSAValue, out: SSAValue
) -> linalg.GenericOp:
    """
    Build a `linalg.generic` computing `mul_op1 * mul_op2 + add_op` element-wise.

    The iteration space has one parallel dimension per dimension of `out`'s
    shaped type. Shaped operands use the identity indexing map; operands of a
    non-shaped (scalar) type get a map with no results, so the same scalar is
    read at every point. Integer/index element types use `arith.muli` /
    `arith.addi`; every other element type uses `arith.mulf` / `arith.addf`.
    For rank-1 float tensors the result is the same as the previous
    rank-1-only implementation.

    Args:
        mul_op1: First multiplication factor.
        mul_op2: Second multiplication factor.
        add_op: Addend.
        out: Init value; its type is also the generic's result type.

    Returns:
        The newly built (not yet inserted) `linalg.GenericOp`.
    """
    from xdsl.dialects.builtin import IndexType, IntegerType, ShapedType

    inputs = (mul_op1, mul_op2, add_op)
    outputs = (out,)

    arg_types = linalg.NamedOperation.body_arg_types((*inputs, *outputs))

    # Choose integer or floating-point arithmetic from the output element type,
    # so integer tensors do not receive (invalid) `mulf`/`addf` ops.
    is_integer = isinstance(arg_types[-1], (IntegerType, IndexType))

    @Builder.implicit_region(arg_types)
    def body(args: tuple[BlockArgument, ...]) -> None:
        if is_integer:
            m = arith.MuliOp(args[0], args[1])
            a = arith.AddiOp(m, args[2])
        else:
            m = arith.MulfOp(args[0], args[1])
            a = arith.AddfOp(m, args[2])
        linalg.YieldOp(a)

    # One parallel loop per output dimension instead of a hard-coded rank of 1.
    rank = out.type.get_num_dims() if isinstance(out.type, ShapedType) else 0
    identity_map = AffineMapAttr(AffineMap.identity(rank))
    scalar_map = AffineMapAttr(AffineMap(rank, 0, ()))
    indexing_maps = [
        identity_map if isinstance(value.type, ShapedType) else scalar_map
        for value in (*inputs, *outputs)
    ]

    return linalg.GenericOp(
        inputs,
        outputs,
        body,
        indexing_maps,
        rank * [linalg.IteratorTypeAttr.parallel()],
        [out.type],
    )