Bases: RewritePattern
Removes the inputs corresponding to unused arguments in the body.
Source code in xdsl/transforms/canonicalization_patterns/memref_stream.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class RemoveUnusedInitOperandPattern(RewritePattern):
    """
    Removes the inputs corresponding to unused arguments in the body.

    For each input operand whose matching body block argument has no uses,
    the operand, its indexing map, and the block argument are dropped.
    Outputs, inits, and the indexing maps that follow the input maps are
    carried over unchanged. Ops with an interleaved iterator are skipped.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: memref_stream.GenericOp, rewriter: PatternRewriter
    ) -> None:
        if memref_stream.IteratorTypeAttr.interleaved() in op.iterator_types.data:
            # Do not run on interleaved ops
            return

        block = op.body.block
        block_args = block.args
        inputs = op.inputs
        num_inputs = len(inputs)
        indexing_maps = op.indexing_maps.data

        # The first `num_inputs` block arguments are paired with the inputs.
        unused_input_indices = tuple(
            index for index, arg in enumerate(block_args[:num_inputs]) if not arg.uses
        )
        if not unused_input_indices:
            # All args have uses, nothing to remove
            return

        # Keep only the inputs (and their indexing maps) that are still used.
        unused = frozenset(unused_input_indices)
        new_inputs: list[SSAValue] = []
        new_indexing_maps: list[AffineMapAttr] = []
        for index in range(num_inputs):
            if index in unused:
                continue
            new_inputs.append(inputs[index])
            new_indexing_maps.append(indexing_maps[index])
        # Indexing maps after the input maps are preserved as they are.
        new_indexing_maps.extend(indexing_maps[num_inputs:])

        # Remove each unused block argument exactly once, highest index first.
        for index in reversed(unused_input_indices):
            block.erase_arg(block_args[index])

        rewriter.replace_op(
            op,
            memref_stream.GenericOp(
                new_inputs,
                op.outputs,
                op.inits,
                rewriter.move_region_contents_to_new_regions(op.body),
                ArrayAttr(new_indexing_maps),
                op.iterator_types,
                op.bounds,
                op.init_indices,
                op.doc,
                op.library_call,
            ),
        )
|
Source code in xdsl/transforms/canonicalization_patterns/memref_stream.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: memref_stream.GenericOp, rewriter: PatternRewriter
) -> None:
    """
    Drop every input whose body block argument has no uses.

    The dropped input's indexing map and block argument are removed with it,
    and `op` is replaced by a new `memref_stream.GenericOp` built from the
    remaining inputs. Outputs, inits, and the indexing maps that follow the
    input maps are carried over unchanged. Returns without rewriting when the
    op has an interleaved iterator or when every input argument is used.
    """
    if memref_stream.IteratorTypeAttr.interleaved() in op.iterator_types.data:
        # Do not run on interleaved ops
        return

    block = op.body.block
    block_args = block.args
    inputs = op.inputs
    num_inputs = len(inputs)
    indexing_maps = op.indexing_maps.data

    # The first `num_inputs` block arguments are paired with the inputs.
    unused_input_indices = tuple(
        index for index, arg in enumerate(block_args[:num_inputs]) if not arg.uses
    )
    if not unused_input_indices:
        # All args have uses, nothing to remove
        return

    # Keep only the inputs (and their indexing maps) that are still used.
    unused = frozenset(unused_input_indices)
    new_inputs: list[SSAValue] = []
    new_indexing_maps: list[AffineMapAttr] = []
    for index in range(num_inputs):
        if index in unused:
            continue
        new_inputs.append(inputs[index])
        new_indexing_maps.append(indexing_maps[index])
    # Indexing maps after the input maps are preserved as they are.
    new_indexing_maps.extend(indexing_maps[num_inputs:])

    # Remove each unused block argument exactly once, highest index first.
    for index in reversed(unused_input_indices):
        block.erase_arg(block_args[index])

    rewriter.replace_op(
        op,
        memref_stream.GenericOp(
            new_inputs,
            op.outputs,
            op.inits,
            rewriter.move_region_contents_to_new_regions(op.body),
            ArrayAttr(new_indexing_maps),
            op.iterator_types,
            op.bounds,
            op.init_indices,
            op.doc,
            op.library_call,
        ),
    )
|