13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class LowerRiscvScfForPattern(RewritePattern):
    """
    Inline the for loop body into its parent region, using `Block`s to represent control
    flow. The `Block` containing the `ForOp` is split into two, and the blocks in the
    `body` of the for loop are spliced between them. Additional operations are inserted
    into the block before, and the block after to handle the initialization of the
    iteration argument, and loop-carried variables, as well as control flow. If the for
    loop contained other `riscv_scf` ops, they will have been rewritten by the time this
    rewrite is called. Two comparison operations are inserted, one just before the loop
    blocks, skipping the loop entirely if the condition is not met, and one at the end of
    the loop body, to exit or continue the loop. A canonicalization step may be able to
    eliminate the first check if the bounds are known at compile time.

    ```
       +--------------------------------------------------------------+
       | <code before the ForOp>                                      |
       | <definitions of %args_init...>                               |
       | <compute initial %iv value>                                  |
       | riscv_cf.bge %iv, %ub, end, body (%iv, %args_init...)        |
       +--------------------------------------------------------------+
                                      |            |
    |------------------|              |            -----------------------|
    |                  v              v                                   |
    |  +--------------------------------------------------------------+   |
    |  | body-first(%iv, %args_body...):                              |   |
    |  | <body contents>                                              |   |
    |  +--------------------------------------------------------------+   |
    |                                 |                                   |
    |                                ...                                  |
    |                                 |                                   |
    |  +--------------------------------------------------------------+   |
    |  | body-last:                                                   |   |
    |  | <body contents>                                              |   |
    |  | <%yields... = operands of yield>                             |   |
    |  | <%ub and %step visible by dominance>                         |   |
    |  | %new_iv = <add %step to %iv>                                 |   |
    |  | riscv_cf.blt %new_iv, %ub, body, end (%new_iv, %yields...)   |   |
    |  +--------------------------------------------------------------+   |
    |                  |              |                                   |
    |-------------------              |            |-----------------------
                                      v            v
       +--------------------------------------------------------------+
       | end(%iv, %args_end...):                                      |
       | <results of ForOp = %args_end>                               |
       | <code after the ForOp>                                       |
       +--------------------------------------------------------------+
    ```
    """

    # Index of the most recently rewritten loop; starts at -1 so that the first loop
    # gets index 0. Used to build a distinct label suffix for every loop.
    for_idx: int

    def __init__(self):
        super().__init__()
        self.for_idx = -1

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /):
        """
        Replace `op` with explicit blocks and conditional branches.

        The block holding `op` is split at `op`, the loop body blocks are inlined
        between the two halves, and `op` itself is replaced by a label whose results
        are the loop-carried block arguments of the end block.
        """
        # To ensure that we have a unique label for each (nested) loop, we use an index
        # that is incremented for each loop as a suffix.
        self.for_idx += 1
        suffix = f"{self.for_idx}_for"

        # Start by splitting the block containing the 'scf.for' into two parts.
        # The part before will get the init code, the part after will be the end point.
        init_block = op.parent_block()
        assert init_block is not None

        # Use the first block of the loop body as the condition block since it is the
        # block that has the induction variable and loop-carried values as arguments.
        # The end block takes the same argument types so that both the skip branch
        # and the loop exit branch can forward (iv, loop-carried values...) to it.
        first_body_block = op.body.blocks[0]
        # TODO: add method to rewriter
        end_block = init_block.split_before(op, arg_types=first_body_block.arg_types)
        last_body_block = op.body.blocks[-1]

        # The first argument of the loop body block is the loop counter by SCF
        # invariant. Get the induction variable and its register.
        iv = first_body_block.args[0]
        iv_reg = iv.type
        assert isinstance(iv_reg, riscv.IntRegisterType)

        # Append the induction variable stepping logic to the last body block, add
        # comparison with upper bound, and conditionally branch back into the body.
        yield_op = last_body_block.last_op
        assert isinstance(yield_op, riscv_scf.YieldOp)

        rewriter.replace_op(
            yield_op,
            (
                add_op := riscv.AddOp(iv, op.step, rd=iv_reg),
                riscv_cf.BltOp(
                    add_op.rd,
                    op.ub,
                    (add_op.rd, *yield_op.operands),
                    (add_op.rd, *yield_op.operands),
                    first_body_block,
                    end_block,
                ),
            ),
        )

        # Move all body blocks from the loop body region to the region containing the
        # loop, placing them just before the end block.
        rewriter.inline_region(op.body, BlockInsertPoint.before(end_block))

        # Move lb to new register to initialize the iv.
        # Skip for loop if condition is not satisfied at start.
        rewriter.insert_op(
            (
                mv_op := riscv.MVOp(op.lb, rd=iv_reg),
                riscv_cf.BgeOp(
                    mv_op.rd,
                    op.ub,
                    (mv_op.rd, *op.iter_args),
                    (mv_op.rd, *op.iter_args),
                    end_block,
                    first_body_block,
                ),
            ),
            InsertPoint.at_end(init_block),
        )

        # Insert label at the start of the first body block.
        rewriter.insert_op(
            riscv.LabelOp(f"scf_body_{suffix}"), InsertPoint.at_start(first_body_block)
        )

        # Replace the operation by a label, using the arguments of the newly created
        # end block as its results. The first end block argument is the induction
        # variable, which is not a result of the loop, so it is skipped.
        rewriter.replace_op(
            op,
            riscv.LabelOp(f"scf_body_end_{suffix}"),
            end_block.args[1:],
        )
|