Jax use donated arguments
jax_use_donated_arguments
SubstituteDonatedTensors (dataclass)
Bases: RewritePattern
Looks at the returned tensors and, where they match donated argument tensors, asks bufferization to use those arguments as their buffers.
Source code in xdsl/transforms/jax_use_donated_arguments.py
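The substitution this pattern performs can be modeled with plain Python stand-ins for IR values. This is a minimal sketch, not the xDSL implementation: `Value`, `rewrite_return` and the `("materialize", output, arg)` records are hypothetical, the output-to-argument mapping is taken as already computed, and it assumes that `remove_matched_outputs=True` drops matched values from the return operands. In real IR the request to bufferization would be an op such as MLIR's `bufferization.materialize_in_destination`; whether this pass emits exactly that op is also an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Value:
    """Hypothetical stand-in for an SSA value: a name plus a type string."""

    name: str
    type: str


def rewrite_return(
    outputs: list[Value],
    mapping: dict[Value, Value],
    remove_matched_outputs: bool = False,
) -> tuple[list[tuple[str, Value, Value]], list[Value]]:
    """Return (materializations, new return operands) for a toy return op."""
    materializations: list[tuple[str, Value, Value]] = []
    new_operands: list[Value] = []
    for out in outputs:
        donor = mapping.get(out)
        if donor is None:
            # No donated buffer for this value: return it unchanged.
            new_operands.append(out)
            continue
        # Record the request to write `out` into the donated argument's buffer.
        materializations.append(("materialize", out, donor))
        if not remove_matched_outputs:
            # Real IR would return the materialization's result; this toy
            # model keeps the original value for simplicity.
            new_operands.append(out)
    return materializations, new_operands


arg0 = Value("arg0", "tensor<4xf32>")
r0 = Value("r0", "tensor<4xf32>")
r1 = Value("r1", "tensor<2xi32>")
mats, operands = rewrite_return([r0, r1], {r0: arg0}, remove_matched_outputs=True)
print([(src.name, dst.name) for _, src, dst in mats])  # [('r0', 'arg0')]
print([v.name for v in operands])  # ['r1']
```

Here `r0` shares a type with the donated `arg0`, so it is written into that buffer and, under the assumed flag semantics, no longer returned; `r1` has no donor and passes through untouched.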
remove_matched_outputs: bool = False (class-attribute, instance-attribute)
__init__(remove_matched_outputs: bool = False) -> None
match_and_rewrite(op: ReturnOp, rewriter: PatternRewriter)
JaxUseDonatedArguments (dataclass)
Bases: ModulePass
name = 'jax-use-donated-arguments' (class-attribute, instance-attribute)
remove_matched_outputs: bool = False (class-attribute, instance-attribute)
__init__(remove_matched_outputs: bool = False) -> None
apply(ctx: Context, op: builtin.ModuleOp) -> None
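The pass is registered as `jax-use-donated-arguments`, and `remove_matched_outputs` is its only option. The helpers below are a hypothetical sketch that builds a pipeline entry for `xdsl-opt -p`, assuming xDSL's textual `name{option=value}` syntax for pass options and lowercase `true` for booleans; `pipeline_spec` and `xdsl_opt_argv` are illustrative names, not part of xDSL.

```python
PASS_NAME = "jax-use-donated-arguments"


def pipeline_spec(remove_matched_outputs: bool = False) -> str:
    """Build a pass-pipeline entry (syntax assumed: name{key=value})."""
    if not remove_matched_outputs:
        # The default configuration needs no options.
        return PASS_NAME
    return f"{PASS_NAME}{{remove_matched_outputs=true}}"


def xdsl_opt_argv(input_path: str, remove_matched_outputs: bool = False) -> list[str]:
    """Hypothetical argv for running the pass through xdsl-opt."""
    return ["xdsl-opt", input_path, "-p", pipeline_spec(remove_matched_outputs)]


print(pipeline_spec(True))  # jax-use-donated-arguments{remove_matched_outputs=true}
```

Running it, for instance with `subprocess.run(xdsl_opt_argv("module.mlir", True), check=True)`, additionally assumes `xdsl-opt` is on the `PATH`.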
map_donated_input_by_output(donatable_inputs: Sequence[BlockArgument], outputs: Sequence[SSAValue]) -> dict[SSAValue, BlockArgument]
Finds a suitable donated buffer for each returned value. Each buffer can be used at most once, and its type must match the type of the value it is assigned to.
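A minimal sketch of this matching, assuming a greedy first-fit strategy with plain type equality as the compatibility test; the `Value` class stands in for `SSAValue`/`BlockArgument` and is hypothetical. Because compatibility is exact type equality, first-fit already matches as many outputs as possible: within each type it pairs values until either the donated buffers or the outputs of that type run out. The sketch also skips a value that is returned twice so that it does not claim a second buffer, which is another assumption about the real function.

```python
from collections.abc import Sequence
from dataclasses import dataclass


@dataclass(frozen=True)
class Value:
    """Hypothetical stand-in for an SSAValue or BlockArgument."""

    name: str
    type: str


def map_donated_input_by_output(
    donatable_inputs: Sequence[Value], outputs: Sequence[Value]
) -> dict[Value, Value]:
    """Pair each output with the first unused donated input of the same type."""
    available = list(donatable_inputs)  # donated buffers not yet claimed
    mapping: dict[Value, Value] = {}
    for output in outputs:
        if output in mapping:
            # A value returned twice keeps the buffer it already has.
            continue
        for i, candidate in enumerate(available):
            if candidate.type == output.type:
                mapping[output] = candidate
                del available[i]  # each buffer is used at most once
                break
    return mapping


a0, a1 = Value("a0", "tensor<4xf32>"), Value("a1", "tensor<2xi32>")
r0, r1 = Value("r0", "tensor<4xf32>"), Value("r1", "tensor<4xf32>")
m = map_donated_input_by_output([a0, a1], [r0, r1])
print({o.name: i.name for o, i in m.items()})  # {'r0': 'a0'}
```

Here `r0` claims `a0`; `r1` has the same type but the only remaining buffer, `a1`, is an `i32` tensor, so `r1` stays unmatched.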