Linalg transformations
linalg_transformations
FuseMultiplyAddPass
dataclass
Bases: RewritePattern
Source code in xdsl/transforms/linalg_transformations.py
require_scalar_factor: bool
instance-attribute
require_erasable_mul: bool
instance-attribute
__init__(require_scalar_factor: bool, require_erasable_mul: bool) -> None
match_and_rewrite(mul: linalg.MulOp, rewriter: PatternRewriter)
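As a rough illustration of the rewrite that match_and_rewrite performs, the sketch below applies the same idea to a toy expression tree instead of xDSL IR: an add whose operand is a multiply becomes a single fma node. The tuple encoding and the function fuse_multiply_add are hypothetical. Accepting the product on either side of the add is an assumption, not a description of the actual xDSL code.

```python
from typing import Union

# Hypothetical toy encoding (NOT xDSL IR): a leaf is a str name,
# an operation is a tuple (op_name, *operands).
Expr = Union[str, tuple]


def fuse_multiply_add(expr: Expr) -> Expr:
    """Bottom-up rewrite of add(mul(a, b), c) or add(c, mul(a, b)) into fma(a, b, c)."""
    if isinstance(expr, str):
        return expr
    name, *operands = expr
    # Rewrite children first so nested multiply-add chains are fused too.
    operands = [fuse_multiply_add(o) for o in operands]
    if name == "add" and len(operands) == 2:
        lhs, rhs = operands
        # Assumption: the add is commutative, so the product may sit on either side.
        if isinstance(lhs, tuple) and lhs[0] == "mul":
            return ("fma", lhs[1], lhs[2], rhs)
        if isinstance(rhs, tuple) and rhs[0] == "mul":
            return ("fma", rhs[1], rhs[2], lhs)
    return (name, *operands)
```

For example, fuse_multiply_add(("add", ("mul", "x", "y"), "z")) yields ("fma", "x", "y", "z"), while a lone ("mul", "a", "b") is left unchanged.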
is_scalar_constant(op: SSAValue) -> bool
staticmethod
Returns whether the value is a scalar. This currently checks only for scalar constants; it could in the future be extended to check for dynamically provided scalar values expanded via linalg.fill.
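The docstring of is_scalar_constant suggests that a scalar factor is a value holding the same element everywhere, as a tensor built by linalg.fill would. Assuming that reading, a minimal model of the check on a flattened constant payload could look like the following. is_splat and fill are hypothetical helpers, not xDSL functions.

```python
import math
from typing import Sequence


def is_splat(values: Sequence[float]) -> bool:
    """Hypothetical model: True if every element of a flattened dense
    constant equals the first one, i.e. it is a broadcast scalar."""
    return len(values) > 0 and all(v == values[0] for v in values)


def fill(shape: tuple, value: float) -> list:
    """Hypothetical flattened analogue of linalg.fill: expand one scalar
    into a whole tensor of the given shape."""
    return [value] * math.prod(shape)
```

Treating an empty payload as non-scalar is a choice made for this sketch.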
LinalgFuseMultiplyAddPass
dataclass
Bases: ModulePass
Pass that fuses linalg multiply and add ops into a single linalg.generic op computing a fused multiply-add (fma).
name = 'linalg-fuse-multiply-add'
class-attribute
instance-attribute
require_scalar_factor: bool = False
class-attribute
instance-attribute
If set, only fuse when at least one of the multiply's factors is a scalar constant.
require_erasable_mul: bool = False
class-attribute
instance-attribute
If set, only fuse when the multiply has no other uses and can therefore be erased.
__init__(require_scalar_factor: bool = False, require_erasable_mul: bool = False) -> None
apply(ctx: Context, op: ModuleOp) -> None
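The two options require_scalar_factor and require_erasable_mul act as filters on otherwise fusible multiply/add pairs. Below is a minimal sketch of that gating. It assumes the pass can see how many uses the multiply's result has and whether each factor is a scalar constant. should_fuse and its parameters are hypothetical and not part of xDSL.

```python
def should_fuse(
    mul_use_count: int,
    lhs_is_scalar: bool,
    rhs_is_scalar: bool,
    require_scalar_factor: bool = False,
    require_erasable_mul: bool = False,
) -> bool:
    """Hypothetical model of how the two pass options gate a candidate pair."""
    # require_erasable_mul: the add must be the product's only use,
    # so the multiply can be erased after fusion.
    if require_erasable_mul and mul_use_count != 1:
        return False
    # require_scalar_factor: at least one multiply factor must be a scalar constant.
    if require_scalar_factor and not (lhs_is_scalar or rhs_is_scalar):
        return False
    return True
```

With both options left at their default False, as in the LinalgFuseMultiplyAddPass signature, this model accepts every candidate pair.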
build_generic_fma(mul_op1: SSAValue, mul_op2: SSAValue, add_op: SSAValue, out: SSAValue) -> linalg.GenericOp
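The signature of build_generic_fma suggests that the generic op reads mul_op1, mul_op2 and add_op and writes out, presumably computing mul_op1 * mul_op2 + add_op element by element. Here is a minimal pure-Python model of that computation. It assumes identity indexing and equally shaped operands. generic_fma and unfused are hypothetical helpers, and the destination out is modeled by the returned list.

```python
from typing import Sequence


def generic_fma(
    mul_op1: Sequence[float],
    mul_op2: Sequence[float],
    add_op: Sequence[float],
) -> list:
    """Hypothetical element-wise model: out[i] = mul_op1[i] * mul_op2[i] + add_op[i]."""
    if not (len(mul_op1) == len(mul_op2) == len(add_op)):
        raise ValueError("operands must have the same number of elements")
    return [a * b + c for a, b, c in zip(mul_op1, mul_op2, add_op)]


def unfused(mul_op1, mul_op2, add_op):
    """Hypothetical two-step reference: materialize the product, then add."""
    product = [a * b for a, b in zip(mul_op1, mul_op2)]
    return [p + c for p, c in zip(product, add_op)]
```

The fused form avoids materializing the intermediate product. For floating-point inputs, a * b + c in Python rounds twice, so this model says nothing about whether compiled code uses a single-rounding FMA instruction.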