Skip to content

RISC-V: lower parallel mov

riscv_lower_parallel_mov

`ALLOWED_INT_WIDTHS = [32, 64]` (module attribute)

ParallelMovPattern

Bases: RewritePattern

Source code in xdsl/transforms/riscv_lower_parallel_mov.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
class ParallelMovPattern(RewritePattern):
    """
    Lower a fully register-allocated `riscv.ParallelMovOp` into individual moves.

    Acyclic chains of moves are emitted in post-order, so that no register is
    overwritten before every move reading it has been emitted. Cycles are broken
    with a free register of the same register class when one is available, or,
    for integer registers only, with xor swaps.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: riscv.ParallelMovOp, rewriter: PatternRewriter):
        """
        Replace `op` with explicit moves that realise the same parallel assignment.

        Raises:
            PassFailedException: if any input or output register is unallocated,
                or if a cycle of non-integer registers has no free register
                available to break it.
        """
        srcs = cast(SSAValues[SSAValue[riscv.RISCVRegisterType]], op.inputs)
        dsts = cast(SSAValues[SSAValue[riscv.RISCVRegisterType]], op.outputs)
        src_types = cast(Sequence[riscv.RISCVRegisterType], op.inputs.types)
        dst_types = cast(Sequence[riscv.RISCVRegisterType], op.outputs.types)

        if not (
            all(i.is_allocated for i in src_types)
            and all(i.is_allocated for i in dst_types)
        ):
            raise PassFailedException("All registers must be allocated")

        # make a list of free registers for each type so we can add to it later
        free_registers: dict[
            type[riscv.RISCVRegisterType], list[riscv.RISCVRegisterType]
        ] = defaultdict(list)
        if op.free_registers is not None:
            for reg in op.free_registers:
                free_registers[type(reg)].append(reg)

        num_operands = len(op.operands)

        # One replacement value per output; None means "not emitted yet".
        results: list[SSAValue | None] = [None] * num_operands

        # Map each output register type to its index in the outputs.
        # Because it is keyed by register type, it can also be queried with the
        # register type of an input.
        output_index = {register: idx for idx, register in enumerate(dst_types)}

        # Width attribute of each move, keyed by its source value. It is passed
        # through to `_insert_mv_op`.
        src_type_by_src = {
            src: src_type
            for src, src_type in zip(srcs, op.input_widths.iter_values(), strict=True)
        }

        # We have a graph with nodes as registers and directed edges as moves,
        # pointing from source to destination.
        # Every node has at most 1 in edge since we can't write to a register twice.
        # Therefore the graph forms a directed pseudoforest, which is a group of
        # connected components with at most 1 cycle each.

        # If we ignore the cycles, we will have a forest.
        # For each tree, we need to perform each move such that all out edges of a node
        # are before the in edge, so a post-order traversal.
        # We can do this iteratively by storing processed edges for each node.
        # Then we iterate up the tree from every leaf, stopping whenever we encounter
        # a node where all out edges haven't been processed yet.

        # store the back edges of the graph
        src_by_dst_type: dict[
            riscv.RISCVRegisterType, SSAValue[riscv.RISCVRegisterType]
        ] = {}
        leaves = set(dst_types)
        unprocessed_children = Counter[SSAValue]()

        for idx, src, dst in zip(range(num_operands), srcs, dsts, strict=True):
            # src.type points to something so it can't be a leaf
            leaves.discard(src.type)

            if src.type == dst.type:
                # Trivial case of moving register to itself.
                # We can ignore all instances of this
                results[idx] = src
            else:
                src_by_dst_type[dst.type] = src
                unprocessed_children[src] += 1

        for dst_type in dst_types:
            if dst_type not in leaves:
                continue
            # Iterate up the tree by traversing back edges.
            while dst_type in src_by_dst_type:
                src = src_by_dst_type[dst_type]
                mvop = _insert_mv_op(rewriter, src, dst_type, src_type_by_src[src])
                # sanity check since we should only have 1 result per output
                assert results[output_index[dst_type]] is None
                results[output_index[dst_type]] = mvop.results[0]
                unprocessed_children[src] -= 1
                # only continue up the tree if all children were processed
                if unprocessed_children[src]:
                    break
                dst_type = src.type

            # When the walk reaches a tree root, every move reading that register
            # has been emitted, so it can serve as scratch space for breaking
            # cycles. A root that is also the destination of a trivial self-move
            # is still an output whose value must be preserved. Only registers
            # that are not outputs at all are therefore free.
            if dst_type not in output_index:
                free_registers[type(dst_type)].append(dst_type)

        # If we have a cycle in the graph, all trees pointing into the cycle cannot
        # enter the cycle because it will have an unprocessed node from its previous
        # node in the cycle.
        # Therefore, all nodes in the cycle will be unprocessed, and their results
        # will still be None

        for idx, val in enumerate(results):
            if val is None:
                reg_type = type(dst_types[idx])
                # Find a free register.
                # We don't have to modify its value since all the cycles
                # can use the same register.
                if not free_registers[reg_type]:
                    if reg_type != riscv.IntRegisterType:
                        raise PassFailedException(
                            "Float cyclic move without free register"
                        )

                    # Otherwise if the registers are all integers, we can use the xor swapping
                    # trick to repeatedly swap values to perform the cyclic move.

                    # we don't take srcs[idx] -> dsts[idx] since we need
                    # the SSAValue for both input and output
                    out = srcs[idx]
                    inp = src_by_dst_type[out.type]

                    # NOTE(review): this loop only terminates if `_insert_swap_ops`
                    # returns `nw_out` in `out`'s register, so that `out.type` stays
                    # `src_types[idx]`. With that convention, walking
                    # `src_by_dst_type` visits the cycle backwards, which appears
                    # to realise the inverse rotation for cycles of length >= 3.
                    # Confirm with a 3-register cycle test.
                    while inp.type != out.type:
                        # we know these are ints since input and output are of the same type
                        inp = cast(SSAValue[riscv.IntRegisterType], inp)
                        out = cast(SSAValue[riscv.IntRegisterType], out)
                        nw_out, nw_inp = _insert_swap_ops(rewriter, inp, out)
                        # after the swap, the input is in the right place, the input's input
                        # needs to be moved to the new output
                        results[output_index[nw_inp.type]] = nw_inp
                        inp = src_by_dst_type[inp.type]
                        out = nw_out

                    results[output_index[src_types[idx]]] = out
                    continue

                # Break the cycle by using free register
                temp_reg = free_registers[reg_type][0]
                # split the current mov
                cur_input = srcs[idx]
                cur_output = dsts[idx]
                temp_ssa_type = op.input_widths.get_values()[idx]
                temp_ssa = _insert_mv_op(rewriter, cur_input, temp_reg, temp_ssa_type)
                # iterate up the chain until we reach the current output
                dst_type = cur_input.type
                while dst_type != cur_output.type:
                    src = src_by_dst_type[dst_type]
                    mvop = _insert_mv_op(rewriter, src, dst_type, src_type_by_src[src])
                    results[output_index[dst_type]] = mvop.results[0]
                    dst_type = src.type
                # finish the split mov
                mvop = _insert_mv_op(rewriter, temp_ssa, cur_output.type, temp_ssa_type)
                results[idx] = mvop.results[0]

        rewriter.replace_matched_op((), results)

match_and_rewrite(op: riscv.ParallelMovOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/riscv_lower_parallel_mov.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv.ParallelMovOp, rewriter: PatternRewriter):
    """
    Replace `op` with explicit moves that realise the same parallel assignment.

    Acyclic chains of moves are emitted in post-order. Cycles are broken with a
    free register of the same register class, or, for integer registers only,
    with xor swaps.

    Raises:
        PassFailedException: if any input or output register is unallocated,
            or if a cycle of non-integer registers has no free register
            available to break it.
    """
    srcs = cast(SSAValues[SSAValue[riscv.RISCVRegisterType]], op.inputs)
    dsts = cast(SSAValues[SSAValue[riscv.RISCVRegisterType]], op.outputs)
    src_types = cast(Sequence[riscv.RISCVRegisterType], op.inputs.types)
    dst_types = cast(Sequence[riscv.RISCVRegisterType], op.outputs.types)

    if not (
        all(i.is_allocated for i in src_types)
        and all(i.is_allocated for i in dst_types)
    ):
        raise PassFailedException("All registers must be allocated")

    # make a list of free registers for each type so we can add to it later
    free_registers: dict[
        type[riscv.RISCVRegisterType], list[riscv.RISCVRegisterType]
    ] = defaultdict(list)
    if op.free_registers is not None:
        for reg in op.free_registers:
            free_registers[type(reg)].append(reg)

    num_operands = len(op.operands)

    # One replacement value per output; None means "not emitted yet".
    results: list[SSAValue | None] = [None] * num_operands

    # Map each output register type to its index in the outputs.
    # Because it is keyed by register type, it can also be queried with the
    # register type of an input.
    output_index = {register: idx for idx, register in enumerate(dst_types)}

    # Width attribute of each move, keyed by its source value. It is passed
    # through to `_insert_mv_op`.
    src_type_by_src = {
        src: src_type
        for src, src_type in zip(srcs, op.input_widths.iter_values(), strict=True)
    }

    # We have a graph with nodes as registers and directed edges as moves,
    # pointing from source to destination.
    # Every node has at most 1 in edge since we can't write to a register twice.
    # Therefore the graph forms a directed pseudoforest, which is a group of
    # connected components with at most 1 cycle each.

    # If we ignore the cycles, we will have a forest.
    # For each tree, we need to perform each move such that all out edges of a node
    # are before the in edge, so a post-order traversal.
    # We can do this iteratively by storing processed edges for each node.
    # Then we iterate up the tree from every leaf, stopping whenever we encounter
    # a node where all out edges haven't been processed yet.

    # store the back edges of the graph
    src_by_dst_type: dict[
        riscv.RISCVRegisterType, SSAValue[riscv.RISCVRegisterType]
    ] = {}
    leaves = set(dst_types)
    unprocessed_children = Counter[SSAValue]()

    for idx, src, dst in zip(range(num_operands), srcs, dsts, strict=True):
        # src.type points to something so it can't be a leaf
        leaves.discard(src.type)

        if src.type == dst.type:
            # Trivial case of moving register to itself.
            # We can ignore all instances of this
            results[idx] = src
        else:
            src_by_dst_type[dst.type] = src
            unprocessed_children[src] += 1

    for dst_type in dst_types:
        if dst_type not in leaves:
            continue
        # Iterate up the tree by traversing back edges.
        while dst_type in src_by_dst_type:
            src = src_by_dst_type[dst_type]
            mvop = _insert_mv_op(rewriter, src, dst_type, src_type_by_src[src])
            # sanity check since we should only have 1 result per output
            assert results[output_index[dst_type]] is None
            results[output_index[dst_type]] = mvop.results[0]
            unprocessed_children[src] -= 1
            # only continue up the tree if all children were processed
            if unprocessed_children[src]:
                break
            dst_type = src.type

        # When the walk reaches a tree root, every move reading that register
        # has been emitted, so it can serve as scratch space for breaking
        # cycles. A root that is also the destination of a trivial self-move
        # is still an output whose value must be preserved. Only registers
        # that are not outputs at all are therefore free.
        if dst_type not in output_index:
            free_registers[type(dst_type)].append(dst_type)

    # If we have a cycle in the graph, all trees pointing into the cycle cannot
    # enter the cycle because it will have an unprocessed node from its previous
    # node in the cycle.
    # Therefore, all nodes in the cycle will be unprocessed, and their results
    # will still be None

    for idx, val in enumerate(results):
        if val is None:
            reg_type = type(dst_types[idx])
            # Find a free register.
            # We don't have to modify its value since all the cycles
            # can use the same register.
            if not free_registers[reg_type]:
                if reg_type != riscv.IntRegisterType:
                    raise PassFailedException(
                        "Float cyclic move without free register"
                    )

                # Otherwise if the registers are all integers, we can use the xor swapping
                # trick to repeatedly swap values to perform the cyclic move.

                # we don't take srcs[idx] -> dsts[idx] since we need
                # the SSAValue for both input and output
                out = srcs[idx]
                inp = src_by_dst_type[out.type]

                # NOTE(review): this loop only terminates if `_insert_swap_ops`
                # returns `nw_out` in `out`'s register, so that `out.type` stays
                # `src_types[idx]`. With that convention, walking
                # `src_by_dst_type` visits the cycle backwards, which appears
                # to realise the inverse rotation for cycles of length >= 3.
                # Confirm with a 3-register cycle test.
                while inp.type != out.type:
                    # we know these are ints since input and output are of the same type
                    inp = cast(SSAValue[riscv.IntRegisterType], inp)
                    out = cast(SSAValue[riscv.IntRegisterType], out)
                    nw_out, nw_inp = _insert_swap_ops(rewriter, inp, out)
                    # after the swap, the input is in the right place, the input's input
                    # needs to be moved to the new output
                    results[output_index[nw_inp.type]] = nw_inp
                    inp = src_by_dst_type[inp.type]
                    out = nw_out

                results[output_index[src_types[idx]]] = out
                continue

            # Break the cycle by using free register
            temp_reg = free_registers[reg_type][0]
            # split the current mov
            cur_input = srcs[idx]
            cur_output = dsts[idx]
            temp_ssa_type = op.input_widths.get_values()[idx]
            temp_ssa = _insert_mv_op(rewriter, cur_input, temp_reg, temp_ssa_type)
            # iterate up the chain until we reach the current output
            dst_type = cur_input.type
            while dst_type != cur_output.type:
                src = src_by_dst_type[dst_type]
                mvop = _insert_mv_op(rewriter, src, dst_type, src_type_by_src[src])
                results[output_index[dst_type]] = mvop.results[0]
                dst_type = src.type
            # finish the split mov
            mvop = _insert_mv_op(rewriter, temp_ssa, cur_output.type, temp_ssa_type)
            results[idx] = mvop.results[0]

    rewriter.replace_matched_op((), results)

RISCVLowerParallelMovPass dataclass

Bases: ModulePass

Lowers ParallelMovOp in a module into separate moves.

Source code in xdsl/transforms/riscv_lower_parallel_mov.py
214
215
216
217
218
219
220
221
@dataclass(frozen=True)
class RISCVLowerParallelMovPass(ModulePass):
    """
    Module pass that splits each `riscv.ParallelMovOp` into individual moves.
    """

    name = "riscv-lower-parallel-mov"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        # Drive the ParallelMovPattern over the module in place.
        walker = PatternRewriteWalker(ParallelMovPattern())
        walker.rewrite_module(op)

`name = 'riscv-lower-parallel-mov'` (class attribute, instance attribute)

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/riscv_lower_parallel_mov.py
220
221
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Rewrite `op` in place using the parallel-mov lowering pattern."""
    walker = PatternRewriteWalker(ParallelMovPattern())
    walker.rewrite_module(op)