Skip to content

Csl stencil to csl wrapper

csl_stencil_to_csl_wrapper

ConvertStencilFuncToModuleWrappedPattern dataclass

Bases: RewritePattern

Wraps a program in the csl_stencil dialect in a csl_wrapper module. Scans the csl_stencil.apply ops for stencil-related params and passes them as properties to the wrapped module (note: properties are in turn passed as block args to the layout_module and program_module blocks).

The layout module wrapper can be used to initialise general program module params. This pass generates code to initialise stencil-specific program params and yields them from the layout module.

Source code in xdsl/transforms/csl_stencil_to_csl_wrapper.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
@dataclass(frozen=True)
class ConvertStencilFuncToModuleWrappedPattern(RewritePattern):
    """
    Wraps program in the csl_stencil dialect in a csl_wrapper module.
    Scans the csl_stencil.apply ops for stencil-related params, passing them as properties to the wrapped module
    (note, properties are in turn passed as block args to the layout_module and program_module blocks).

    The layout module wrapper can be used to initialise general program module params. This pass generates code
    to initialise stencil-specific program params and yields them from the layout module.
    """

    target: csl.Target
    """
    Specifies the target architecture.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
        """
        Wrap a function containing csl_stencil.apply ops in a csl_wrapper.module,
        or erase a timer stub declaration.

        Functions without any csl_stencil.apply op are left untouched. The grid
        width/height and the `z_dim`, `pattern`, `num_chunks`, `chunk_size` and
        `padded_z_dim` params are the maxima over all apply ops in the function.

        Raises:
            ValueError: on diagonal or non-2-dimensional accesses, on a
                non-2-dimensional apply topology, on a function that returns
                values, or on a function argument that cannot be translated.
        """
        # erase timer stubs
        if op.is_declaration and op.sym_name.data in [TIMER_START, TIMER_END]:
            rewriter.erase_op(op)
            return
        # find csl_stencil.apply ops, abort if there are none
        apply_ops = self.get_csl_stencil_apply_ops(op)
        if len(apply_ops) == 0:
            return
        max_distance: int = 1
        width: int = 1
        height: int = 1
        z_dim: int = 1
        num_chunks: int = 1
        chunk_size: int = 1
        for apply_op in apply_ops:
            # loop over accesses to get max_distance (from which we build `pattern`)
            for ap in apply_op.get_accesses():
                if ap.is_diagonal:
                    raise ValueError("Diagonal accesses are currently not supported")
                if len(ap.offsets) > 0:
                    if ap.dims != 2:
                        raise ValueError(
                            "Stencil accesses must be 2-dimensional at this stage"
                        )
                    max_distance = max(max_distance, ap.max_distance())

            # find max x and y dimensions
            if len(shape := apply_op.topo.shape.get_values()) == 2:
                width = max(width, shape[0])
                height = max(height, shape[1])
            else:
                raise ValueError("Stencil accesses must be 2-dimensional at this stage")

            # find max z dimension from done_exchange arg[0]; this handles both
            # unbufferized (stencil type wrapping a tensor/memref) and bufferized
            # (plain memref) csl_stencils
            if stencil.StencilTypeConstr.verifies(
                field_t := apply_op.done_exchange.block.args[0].type
            ) and isa(el_type := field_t.element_type, TensorType | MemRefType):
                # unbufferized csl_stencil
                z_dim = max(z_dim, el_type.get_shape()[-1])
            elif isa(field_t, memref.MemRefType):
                # bufferized csl_stencil
                z_dim = max(z_dim, field_t.get_shape()[-1])

            num_chunks = max(num_chunks, apply_op.num_chunks.value.data)
            # the chunk size is the innermost dimension of the receive_chunk buffer
            if isa(
                buf_t := apply_op.receive_chunk.block.args[0].type,
                TensorType | MemRefType,
            ):
                chunk_size = max(chunk_size, buf_t.get_shape()[-1])

        # Validate the terminator before any IR is modified, so unsupported input
        # is rejected without leaving the function half-rewritten. A raise (not an
        # assert) keeps this check active under `python -O`.
        func_return = op.body.block.last_op
        assert isinstance(func_return, func.ReturnOp)
        if len(func_return.arguments) != 0:
            raise ValueError("Non-empty returns currently not supported")

        padded_z_dim: int = chunk_size * num_chunks

        # initialise module op
        module_op = csl_wrapper.ModuleOp(
            width=IntegerAttr(width + (max_distance * 2), 16),
            height=IntegerAttr(height + (max_distance * 2), 16),
            target=self.target,
            params={
                "z_dim": IntegerAttr(z_dim, 16),
                "pattern": IntegerAttr(max_distance + 1, 16),
                "num_chunks": IntegerAttr(num_chunks, 16),
                "chunk_size": IntegerAttr(chunk_size, 16),
                "padded_z_dim": IntegerAttr(padded_z_dim, 16),
            },
        )

        self.initialise_layout_module(module_op)
        module_op.program_name = op.sym_name

        # add yield op args to program_module block args
        module_op.update_program_block_args()

        # set up main function and move func.func ops into this csl.func
        main_func = csl.FuncOp(op.sym_name.data, ((), None))
        func_export = csl.SymbolExportOp(main_func.sym_name, main_func.function_type)
        args_to_ops, arg_mappings = self._translate_function_args(op.args, op.arg_attrs)
        rewriter.inline_block(
            op.body.block,
            InsertPoint.at_start(main_func.body.block),
            arg_mappings,
        )

        # initialise program_module and add main func and empty yield op
        self.initialise_program_module(
            module_op, add_ops=[*args_to_ops, func_export, main_func]
        )

        # replace func.return (now the last op of main_func) by
        # unblock_cmd_stream and csl.return
        memcpy = module_op.get_program_import("<memcpy/memcpy>")
        unblock_call = csl.MemberCallOp(
            struct=memcpy, fname="unblock_cmd_stream", params=[], result_type=None
        )
        rewriter.replace_op(func_return, [unblock_call, csl.ReturnOp()])

        # replace (now empty) func by module wrapper
        rewriter.replace_op(op, module_op)

    def get_csl_stencil_apply_ops(
        self, op: func.FuncOp
    ) -> Sequence[csl_stencil.ApplyOp]:
        """Return all csl_stencil.apply ops nested in the body of `op`, in walk order."""
        result: list[csl_stencil.ApplyOp] = []
        for apply_op in op.body.walk():
            if isinstance(apply_op, csl_stencil.ApplyOp):
                result.append(apply_op)
        return result

    def _translate_function_args(
        self, args: Sequence[BlockArgument], attrs: ArrayAttr[DictionaryAttr] | None
    ) -> tuple[Sequence[Operation], Sequence[SSAValue]]:
        """
        Args of the top-level function act as the interface to the program and need to
        be translated to writable buffers.

        Stencil fields of tensors and memrefs become exported memref allocations
        (cast back to the original type where it differs). LLVM pointer args whose
        uses all store the result of a TIMER_END call become exported u16 timer
        buffers; the `<time>` library is imported once if any such arg exists.
        An `llvm.name` entry in an arg's attributes overrides its exported name.

        Returns:
            The ops to add to the program module, and one replacement value per
            arg (in argument order) for inlining the function body.

        Raises:
            ValueError: if an argument's type cannot be translated.
        """
        arg_ops: list[Operation] = []
        arg_op_mapping: list[SSAValue] = []
        ptr_converts: list[Operation] = []
        export_ops: list[Operation] = []
        cast_ops: list[Operation] = []
        import_ops: list[Operation] = []

        if attrs is not None:
            for arg, attr in zip(args, attrs, strict=True):
                assert isinstance(attr, DictionaryAttr)
                if "llvm.name" in attr.data:
                    nh = attr.data["llvm.name"]
                    assert isinstance(nh, StringAttr)
                    arg.name_hint = nh.data

        # enumerate instead of `args.index(arg)`, which made this loop quadratic
        for idx, arg in enumerate(args):
            arg_name = arg.name_hint or f"arg{idx}"

            if isa(arg.type, stencil.FieldType[TensorType[Attribute]]) or isa(
                arg.type, memref.MemRefType
            ):
                arg_t = (
                    csl_stencil_bufferize.tensor_to_memref_type(
                        arg.type.get_element_type()
                    )
                    if isa(arg.type, stencil.FieldType[TensorType[Attribute]])
                    else arg.type
                )
                arg_ops.append(alloc := memref.AllocOp([], [], arg_t))
                ptr_converts.append(
                    address := csl.AddressOfOp(
                        alloc,
                        csl.PtrType.get(
                            arg_t.get_element_type(), is_single=False, is_const=False
                        ),
                    )
                )
                export_ops.append(csl.SymbolExportOp(arg_name, SSAValue.get(address)))
                if arg_t != arg.type:
                    cast_ops.append(
                        cast_op := builtin.UnrealizedConversionCastOp.get(
                            [alloc], [arg.type]
                        )
                    )
                    arg_op_mapping.append(cast_op.outputs[0])
                else:
                    arg_op_mapping.append(alloc.memref)
            # check if this looks like a timer
            elif isinstance(arg.type, llvm.LLVMPointerType) and all(
                isinstance(u.operation, llvm.StoreOp)
                and isinstance(u.operation.value, OpResult)
                and isinstance(u.operation.value.op, func.CallOp)
                and u.operation.value.op.callee.string_value() == TIMER_END
                for u in arg.uses
            ):
                start_end_size = 3
                arg_t = memref.MemRefType(
                    IntegerType(16, Signedness.UNSIGNED), (2 * start_end_size,)
                )
                arg_ops.append(alloc := memref.AllocOp([], [], arg_t))
                ptr_converts.append(
                    address := csl.AddressOfOp(
                        alloc,
                        csl.PtrType.get(
                            arg_t.get_element_type(), is_single=False, is_const=False
                        ),
                    )
                )
                export_ops.append(csl.SymbolExportOp(arg_name, SSAValue.get(address)))
                arg_op_mapping.append(alloc.memref)
                # import the time library only once, however many timer args exist
                if not import_ops:
                    import_ops.append(
                        csl_wrapper.ImportOp(
                            "<time>",
                            field_name_mapping={},
                        )
                    )
            else:
                # every arg needs a replacement value when the body is inlined;
                # fail here with a clear message rather than later in inlining
                raise ValueError(
                    f"Cannot translate argument '{arg_name}' of type {arg.type} "
                    "to a program buffer"
                )

        return [
            *arg_ops,
            *cast_ops,
            *ptr_converts,
            *export_ops,
            *import_ops,
        ], arg_op_mapping

    def initialise_layout_module(self, module_op: csl_wrapper.ModuleOp):
        """Initialises the layout_module (wrapper block) by setting up (esp. stencil-related) program params"""

        # extract layout module params as the function has linear complexity
        param_width = module_op.get_layout_param("width")
        param_height = module_op.get_layout_param("height")
        param_x = module_op.get_layout_param("x")
        param_y = module_op.get_layout_param("y")
        param_pattern = module_op.get_layout_param("pattern")

        # fill layout module wrapper block with ops
        with ImplicitBuilder(module_op.layout_module.block):
            # import memcpy/get_params and routes
            memcpy = csl_wrapper.ImportOp(
                "<memcpy/get_params>",
                {
                    "width": param_width,
                    "height": param_height,
                },
            )
            routes = csl_wrapper.ImportOp(
                "routes.csl",
                {
                    "pattern": param_pattern,
                    "peWidth": param_width,
                    "peHeight": param_height,
                },
            )

            # set up program param `stencil_comms_params`
            all_routes = csl.MemberCallOp(
                "computeAllRoutes",
                csl.ComptimeStructType(),
                routes,
                params=[
                    param_x,
                    param_y,
                    param_width,
                    param_height,
                    param_pattern,
                ],
            )
            # set up program param `memcpy_params`
            memcpy_params = csl.MemberCallOp(
                "get_params",
                csl.ComptimeStructType(),
                memcpy,
                params=[
                    param_x,
                ],
            )

            # set up program param `isBorderRegionPE`: true if the PE lies in the
            # outermost `pattern - 1` columns or rows of the grid
            one = arith.ConstantOp(IntegerAttr(1, 16))
            pattern_minus_one = arith.SubiOp(param_pattern, one)
            width_minus_x = arith.SubiOp(param_width, param_x)
            height_minus_y = arith.SubiOp(param_height, param_y)
            x_lt_pattern_minus_one = arith.CmpiOp(param_x, pattern_minus_one, "slt")
            y_lt_pattern_minus_one = arith.CmpiOp(param_y, pattern_minus_one, "slt")
            width_minus_x_lt_pattern = arith.CmpiOp(
                width_minus_x, param_pattern, "slt"
            )
            height_minus_y_lt_pattern = arith.CmpiOp(
                height_minus_y, param_pattern, "slt"
            )
            or1_op = arith.OrIOp(x_lt_pattern_minus_one, y_lt_pattern_minus_one)
            or2_op = arith.OrIOp(or1_op, width_minus_x_lt_pattern)
            is_border_region_pe = arith.OrIOp(or2_op, height_minus_y_lt_pattern)

            # yield things as named params to the program module
            csl_wrapper.YieldOp.from_field_name_mapping(
                field_name_mapping={
                    "memcpy_params": memcpy_params.results[0],
                    "stencil_comms_params": all_routes.results[0],
                    "isBorderRegionPE": is_border_region_pe.result,
                }
            )

    def initialise_program_module(
        self, module_op: csl_wrapper.ModuleOp, add_ops: Sequence[Operation]
    ):
        """
        Populate the program_module block: import `<memcpy/memcpy>` and
        `stencil_comms.csl`, append `add_ops`, and terminate with an empty yield.
        """
        with ImplicitBuilder(module_op.program_module.block):
            csl_wrapper.ImportOp(
                "<memcpy/memcpy>",
                field_name_mapping={"": module_op.get_program_param("memcpy_params")},
            )
            csl_wrapper.ImportOp(
                "stencil_comms.csl",
                field_name_mapping={
                    "pattern": module_op.get_program_param("pattern"),
                    "chunkSize": module_op.get_program_param("chunk_size"),
                    "": module_op.get_program_param("stencil_comms_params"),
                },
            )
        module_op.program_module.block.add_ops(add_ops)
        module_op.program_module.block.add_op(csl_wrapper.YieldOp([], []))

target: csl.Target instance-attribute

Specifies the target architecture.

__init__(target: csl.Target) -> None

match_and_rewrite(op: func.FuncOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/csl_stencil_to_csl_wrapper.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
    """
    Wrap a function containing csl_stencil.apply ops in a csl_wrapper.module,
    or erase a timer stub declaration.

    Functions without any csl_stencil.apply op are left untouched. The grid
    width/height and the `z_dim`, `pattern`, `num_chunks`, `chunk_size` and
    `padded_z_dim` params are the maxima over all apply ops in the function.

    Raises:
        ValueError: on diagonal or non-2-dimensional accesses, on a
            non-2-dimensional apply topology, or on a function that returns values.
    """
    # erase timer stubs
    if op.is_declaration and op.sym_name.data in [TIMER_START, TIMER_END]:
        rewriter.erase_op(op)
        return
    # find csl_stencil.apply ops, abort if there are none
    apply_ops = self.get_csl_stencil_apply_ops(op)
    if len(apply_ops) == 0:
        return
    max_distance: int = 1
    width: int = 1
    height: int = 1
    z_dim: int = 1
    num_chunks: int = 1
    chunk_size: int = 1
    for apply_op in apply_ops:
        # loop over accesses to get max_distance (from which we build `pattern`)
        for ap in apply_op.get_accesses():
            if ap.is_diagonal:
                raise ValueError("Diagonal accesses are currently not supported")
            if len(ap.offsets) > 0:
                if ap.dims != 2:
                    raise ValueError(
                        "Stencil accesses must be 2-dimensional at this stage"
                    )
                max_distance = max(max_distance, ap.max_distance())

        # find max x and y dimensions
        if len(shape := apply_op.topo.shape.get_values()) == 2:
            width = max(width, shape[0])
            height = max(height, shape[1])
        else:
            raise ValueError("Stencil accesses must be 2-dimensional at this stage")

        # find max z dimension from done_exchange arg[0]; this handles both
        # unbufferized (stencil type wrapping a tensor/memref) and bufferized
        # (plain memref) csl_stencils
        if stencil.StencilTypeConstr.verifies(
            field_t := apply_op.done_exchange.block.args[0].type
        ) and isa(el_type := field_t.element_type, TensorType | MemRefType):
            # unbufferized csl_stencil
            z_dim = max(z_dim, el_type.get_shape()[-1])
        elif isa(field_t, memref.MemRefType):
            # bufferized csl_stencil
            z_dim = max(z_dim, field_t.get_shape()[-1])

        num_chunks = max(num_chunks, apply_op.num_chunks.value.data)
        # the chunk size is the innermost dimension of the receive_chunk buffer
        if isa(
            buf_t := apply_op.receive_chunk.block.args[0].type,
            TensorType | MemRefType,
        ):
            chunk_size = max(chunk_size, buf_t.get_shape()[-1])

    # Validate the terminator before any IR is modified, so unsupported input
    # is rejected without leaving the function half-rewritten. A raise (not an
    # assert) keeps this check active under `python -O`.
    func_return = op.body.block.last_op
    assert isinstance(func_return, func.ReturnOp)
    if len(func_return.arguments) != 0:
        raise ValueError("Non-empty returns currently not supported")

    padded_z_dim: int = chunk_size * num_chunks

    # initialise module op
    module_op = csl_wrapper.ModuleOp(
        width=IntegerAttr(width + (max_distance * 2), 16),
        height=IntegerAttr(height + (max_distance * 2), 16),
        target=self.target,
        params={
            "z_dim": IntegerAttr(z_dim, 16),
            "pattern": IntegerAttr(max_distance + 1, 16),
            "num_chunks": IntegerAttr(num_chunks, 16),
            "chunk_size": IntegerAttr(chunk_size, 16),
            "padded_z_dim": IntegerAttr(padded_z_dim, 16),
        },
    )

    self.initialise_layout_module(module_op)
    module_op.program_name = op.sym_name

    # add yield op args to program_module block args
    module_op.update_program_block_args()

    # set up main function and move func.func ops into this csl.func
    main_func = csl.FuncOp(op.sym_name.data, ((), None))
    func_export = csl.SymbolExportOp(main_func.sym_name, main_func.function_type)
    args_to_ops, arg_mappings = self._translate_function_args(op.args, op.arg_attrs)
    rewriter.inline_block(
        op.body.block,
        InsertPoint.at_start(main_func.body.block),
        arg_mappings,
    )

    # initialise program_module and add main func and empty yield op
    self.initialise_program_module(
        module_op, add_ops=[*args_to_ops, func_export, main_func]
    )

    # replace func.return (now the last op of main_func) by
    # unblock_cmd_stream and csl.return
    memcpy = module_op.get_program_import("<memcpy/memcpy>")
    unblock_call = csl.MemberCallOp(
        struct=memcpy, fname="unblock_cmd_stream", params=[], result_type=None
    )
    rewriter.replace_op(func_return, [unblock_call, csl.ReturnOp()])

    # replace (now empty) func by module wrapper
    rewriter.replace_op(op, module_op)

get_csl_stencil_apply_ops(op: func.FuncOp) -> Sequence[csl_stencil.ApplyOp]

Source code in xdsl/transforms/csl_stencil_to_csl_wrapper.py
179
180
181
182
183
184
185
186
def get_csl_stencil_apply_ops(
    self, op: func.FuncOp
) -> Sequence[csl_stencil.ApplyOp]:
    """
    Collect the csl_stencil.apply ops nested anywhere in the body of `op`.

    The ops are returned as a list, in the order `op.body.walk()` visits them.
    """
    return [
        nested_op
        for nested_op in op.body.walk()
        if isinstance(nested_op, csl_stencil.ApplyOp)
    ]

initialise_layout_module(module_op: csl_wrapper.ModuleOp)

Initialises the layout_module (wrapper block) by setting up (esp. stencil-related) program params

Source code in xdsl/transforms/csl_stencil_to_csl_wrapper.py
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
def initialise_layout_module(self, module_op: csl_wrapper.ModuleOp):
    """
    Populate the layout_module block of `module_op`.

    Imports `<memcpy/get_params>` and `routes.csl`. From the layout params it
    derives the program params `memcpy_params`, `stencil_comms_params` and
    `isBorderRegionPE`, then yields them by name to the program module.
    """
    # Look each layout param up once; the original notes the lookup is linear.
    grid_w = module_op.get_layout_param("width")
    grid_h = module_op.get_layout_param("height")
    pe_x = module_op.get_layout_param("x")
    pe_y = module_op.get_layout_param("y")
    pattern = module_op.get_layout_param("pattern")

    # NOTE: the builder appends ops in creation order, so the statement order
    # below is the order of the generated IR.
    with ImplicitBuilder(module_op.layout_module.block):
        get_params_lib = csl_wrapper.ImportOp(
            "<memcpy/get_params>",
            {"width": grid_w, "height": grid_h},
        )
        routes_lib = csl_wrapper.ImportOp(
            "routes.csl",
            {"pattern": pattern, "peWidth": grid_w, "peHeight": grid_h},
        )

        # program param `stencil_comms_params`
        comms_params_call = csl.MemberCallOp(
            "computeAllRoutes",
            csl.ComptimeStructType(),
            routes_lib,
            params=[pe_x, pe_y, grid_w, grid_h, pattern],
        )
        # program param `memcpy_params`
        memcpy_params_call = csl.MemberCallOp(
            "get_params",
            csl.ComptimeStructType(),
            get_params_lib,
            params=[pe_x],
        )

        # program param `isBorderRegionPE`: set when the PE sits in the
        # outermost `pattern - 1` columns or rows of the grid
        const_one = arith.ConstantOp(IntegerAttr(1, 16))
        border_size = arith.SubiOp(pattern, const_one)
        cols_remaining = arith.SubiOp(grid_w, pe_x)
        rows_remaining = arith.SubiOp(grid_h, pe_y)
        in_left_border = arith.CmpiOp(pe_x, border_size, "slt")
        in_top_border = arith.CmpiOp(pe_y, border_size, "slt")
        in_right_border = arith.CmpiOp(cols_remaining, pattern, "slt")
        in_bottom_border = arith.CmpiOp(rows_remaining, pattern, "slt")
        left_or_top = arith.OrIOp(in_left_border, in_top_border)
        left_top_or_right = arith.OrIOp(left_or_top, in_right_border)
        in_any_border = arith.OrIOp(left_top_or_right, in_bottom_border)

        # hand the derived values to the program module as named params
        csl_wrapper.YieldOp.from_field_name_mapping(
            field_name_mapping={
                "memcpy_params": memcpy_params_call.results[0],
                "stencil_comms_params": comms_params_call.results[0],
                "isBorderRegionPE": in_any_border.result,
            }
        )

initialise_program_module(module_op: csl_wrapper.ModuleOp, add_ops: Sequence[Operation])

Source code in xdsl/transforms/csl_stencil_to_csl_wrapper.py
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
def initialise_program_module(
    self, module_op: csl_wrapper.ModuleOp, add_ops: Sequence[Operation]
):
    """
    Populate the program_module block of `module_op`.

    Emits the `<memcpy/memcpy>` and `stencil_comms.csl` imports. It then
    appends `add_ops` and finishes with an empty csl_wrapper.yield.
    """
    program_block = module_op.program_module.block
    with ImplicitBuilder(program_block):
        memcpy_fields = {"": module_op.get_program_param("memcpy_params")}
        csl_wrapper.ImportOp("<memcpy/memcpy>", field_name_mapping=memcpy_fields)
        comms_fields = {
            "pattern": module_op.get_program_param("pattern"),
            "chunkSize": module_op.get_program_param("chunk_size"),
            "": module_op.get_program_param("stencil_comms_params"),
        }
        csl_wrapper.ImportOp("stencil_comms.csl", field_name_mapping=comms_fields)
    program_block.add_ops(add_ops)
    program_block.add_op(csl_wrapper.YieldOp([], []))

LowerTimerFuncCall dataclass

Bases: RewritePattern

Lowers calls to the start and end timer to csl API calls.

Source code in xdsl/transforms/csl_stencil_to_csl_wrapper.py
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
@dataclass(frozen=True)
class LowerTimerFuncCall(RewritePattern):
    """
    Lowers calls to the start and end timer to csl API calls.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: llvm.StoreOp, rewriter: PatternRewriter, /):
        """
        Lower `llvm.store(TIMER_END(TIMER_START(...), ...), ptr)` to `<time>` calls.

        The pattern matches a store with these properties:
        - the stored value is the result of a TIMER_END call;
        - that call's first argument is the result of a TIMER_START call;
        - the store lives inside a csl_wrapper module;
        - its target has a memref type.

        Before the start call it emits `enable_tsc` and `get_timestamp` on the
        buffer base. Before the end call it emits `get_timestamp` on element 3
        of the buffer, followed by `disable_tsc`. The store and both calls are
        then erased.
        """
        if (
            not isinstance(end_call := op.value.owner, func.CallOp)
            or not end_call.callee.string_value() == TIMER_END
            # guard the `arguments[0]` access below against argument-less calls
            or len(end_call.arguments) == 0
            or not (isinstance(start_call := end_call.arguments[0].owner, func.CallOp))
            or not start_call.callee.string_value() == TIMER_START
            or not (wrapper := _get_module_wrapper(op))
            or not isa(op.ptr.type, MemRefType)
        ):
            return

        time_lib = wrapper.get_program_import("<time>")

        # single, mutable pointer to a 3-element memref of the buffer's element type
        three_elem_ptr_type = csl.PtrType(
            memref.MemRefType(op.ptr.type.get_element_type(), (3,)),
            csl.PtrKindAttr(csl.PtrKind.SINGLE),
            csl.PtrConstAttr(csl.PtrConst.VAR),
        )

        # end timestamp: written starting at element 3 of the buffer
        # (presumably elements 0-2 hold the start and 3-5 the end timestamp)
        rewriter.insert_op(
            [
                three := arith.ConstantOp.from_int_and_width(3, IndexType()),
                load_three := memref.LoadOp.get(op.ptr, [three]),
                addr_of := csl.AddressOfOp(
                    load_three,
                    csl.PtrType.get(
                        op.ptr.type.get_element_type(), is_single=True, is_const=False
                    ),
                ),
                ptrcast := csl.PtrCastOp(addr_of, three_elem_ptr_type),
                csl.MemberCallOp("get_timestamp", None, time_lib, [ptrcast]),
                csl.MemberCallOp("disable_tsc", None, time_lib, []),
            ],
            InsertPoint.before(end_call),
        )
        # start timestamp: enable the timestamp counter, then write at the buffer base
        rewriter.insert_op(
            [
                addr_of := csl.AddressOfOp(
                    op.ptr,
                    csl.PtrType.get(
                        op.ptr.type.get_element_type(), is_single=False, is_const=False
                    ),
                ),
                ptrcast := csl.PtrCastOp(addr_of, three_elem_ptr_type),
                csl.MemberCallOp("enable_tsc", None, time_lib, []),
                csl.MemberCallOp("get_timestamp", None, time_lib, [ptrcast]),
            ],
            InsertPoint.before(start_call),
        )
        # erase users before producers: the store uses end_call, which uses start_call
        rewriter.erase_op(op)
        rewriter.erase_op(end_call)
        rewriter.erase_op(start_call)

__init__() -> None

match_and_rewrite(op: llvm.StoreOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/csl_stencil_to_csl_wrapper.py
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
@op_type_rewrite_pattern
def match_and_rewrite(self, op: llvm.StoreOp, rewriter: PatternRewriter, /):
    """
    Lower `llvm.store(TIMER_END(TIMER_START(...), ...), ptr)` to `<time>` calls.

    The pattern matches a store with these properties:
    - the stored value is the result of a TIMER_END call;
    - that call's first argument is the result of a TIMER_START call;
    - the store lives inside a csl_wrapper module;
    - its target has a memref type.

    Before the start call it emits `enable_tsc` and `get_timestamp` on the
    buffer base. Before the end call it emits `get_timestamp` on element 3 of
    the buffer, followed by `disable_tsc`. The store and both calls are then
    erased.
    """
    if (
        not isinstance(end_call := op.value.owner, func.CallOp)
        or not end_call.callee.string_value() == TIMER_END
        # guard the `arguments[0]` access below against argument-less calls
        or len(end_call.arguments) == 0
        or not (isinstance(start_call := end_call.arguments[0].owner, func.CallOp))
        or not start_call.callee.string_value() == TIMER_START
        or not (wrapper := _get_module_wrapper(op))
        or not isa(op.ptr.type, MemRefType)
    ):
        return

    time_lib = wrapper.get_program_import("<time>")

    # single, mutable pointer to a 3-element memref of the buffer's element type
    three_elem_ptr_type = csl.PtrType(
        memref.MemRefType(op.ptr.type.get_element_type(), (3,)),
        csl.PtrKindAttr(csl.PtrKind.SINGLE),
        csl.PtrConstAttr(csl.PtrConst.VAR),
    )

    # end timestamp: written starting at element 3 of the buffer
    # (presumably elements 0-2 hold the start and 3-5 the end timestamp)
    rewriter.insert_op(
        [
            three := arith.ConstantOp.from_int_and_width(3, IndexType()),
            load_three := memref.LoadOp.get(op.ptr, [three]),
            addr_of := csl.AddressOfOp(
                load_three,
                csl.PtrType.get(
                    op.ptr.type.get_element_type(), is_single=True, is_const=False
                ),
            ),
            ptrcast := csl.PtrCastOp(addr_of, three_elem_ptr_type),
            csl.MemberCallOp("get_timestamp", None, time_lib, [ptrcast]),
            csl.MemberCallOp("disable_tsc", None, time_lib, []),
        ],
        InsertPoint.before(end_call),
    )
    # start timestamp: enable the timestamp counter, then write at the buffer base
    rewriter.insert_op(
        [
            addr_of := csl.AddressOfOp(
                op.ptr,
                csl.PtrType.get(
                    op.ptr.type.get_element_type(), is_single=False, is_const=False
                ),
            ),
            ptrcast := csl.PtrCastOp(addr_of, three_elem_ptr_type),
            csl.MemberCallOp("enable_tsc", None, time_lib, []),
            csl.MemberCallOp("get_timestamp", None, time_lib, [ptrcast]),
        ],
        InsertPoint.before(start_call),
    )
    # erase users before producers: the store uses end_call, which uses start_call
    rewriter.erase_op(op)
    rewriter.erase_op(end_call)
    rewriter.erase_op(start_call)

CslStencilToCslWrapperPass dataclass

Bases: ModulePass

Wraps program in the csl_stencil dialect in a csl_wrapper by translating each top-level function to one module wrapper.

Source code in xdsl/transforms/csl_stencil_to_csl_wrapper.py
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
@dataclass(frozen=True)
class CslStencilToCslWrapperPass(ModulePass):
    """
    Module pass that turns every top-level function of a csl_stencil program
    into its own csl_wrapper module.
    """

    name = "csl-stencil-to-csl-wrapper"

    target: csl.Target
    """
    The target architecture the generated wrapper modules are built for.
    """

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """Run the function-wrapping and timer-lowering patterns once over `op`."""
        patterns = [
            ConvertStencilFuncToModuleWrappedPattern(self.target),
            LowerTimerFuncCall(),
        ]
        walker = PatternRewriteWalker(
            GreedyRewritePatternApplier(patterns),
            apply_recursively=False,
        )
        walker.rewrite_module(op)

name = 'csl-stencil-to-csl-wrapper' class-attribute instance-attribute

target: csl.Target instance-attribute

Specifies the target architecture.

__init__(target: csl.Target) -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/csl_stencil_to_csl_wrapper.py
453
454
455
456
457
458
459
460
461
462
463
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Run the function-wrapping and timer-lowering patterns once over `op`."""
    patterns = [
        ConvertStencilFuncToModuleWrappedPattern(self.target),
        LowerTimerFuncCall(),
    ]
    walker = PatternRewriteWalker(
        GreedyRewritePatternApplier(patterns),
        apply_recursively=False,
    )
    walker.rewrite_module(op)