Skip to content

Convert memref to ptr

convert_memref_to_ptr

Implements a lowering from operations on memrefs to operations on pointers.

The default lowering of memrefs in MLIR is to a structure that carries the base pointer, an offset, and strides (See (memref.extract_strided_metadata)[https://mlir.llvm.org/docs/Dialects/MemRef/#memrefextract_strided_metadata-memrefextractstridedmetadataop]). If the offset and strides are statically known, they can be encoded in the type, but they can also be dynamic.

When lowering operations on memrefs to operations on pointers, the offset and stride information must be lowered to operations that perform the required pointer arithmetic. In order to simplify the lowering of memory accesses, this pass makes the choice to lower memrefs to a pointer to the buffer including the dynamic offset. This means that memory accesses can be a simple dot product of statically known strides and dynamic indices. On the other hand, operations that create views of memrefs from other memrefs must lower to the relevant pointer arithmetic to encode the new inner buffer offset, when possible.

ConvertStorePattern dataclass

Bases: RewritePattern

Source code in xdsl/transforms/convert_memref_to_ptr.py
162
163
164
165
166
167
168
@dataclass
class ConvertStorePattern(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.StoreOp, rewriter: PatternRewriter, /):
        assert isa(memref_type := op.memref.type, memref.MemRefType)
        target_ptr = get_target_ptr(op.memref, memref_type, op.indices, rewriter)
        rewriter.replace_op(op, ptr.StoreOp(target_ptr, op.value))

__init__() -> None

match_and_rewrite(op: memref.StoreOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/convert_memref_to_ptr.py
164
165
166
167
168
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.StoreOp, rewriter: PatternRewriter, /):
    assert isa(memref_type := op.memref.type, memref.MemRefType)
    target_ptr = get_target_ptr(op.memref, memref_type, op.indices, rewriter)
    rewriter.replace_op(op, ptr.StoreOp(target_ptr, op.value))

ConvertLoadPattern dataclass

Bases: RewritePattern

Source code in xdsl/transforms/convert_memref_to_ptr.py
171
172
173
174
175
176
177
@dataclass
class ConvertLoadPattern(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /):
        assert isa(memref_type := op.memref.type, memref.MemRefType)
        target_ptr = get_target_ptr(op.memref, memref_type, op.indices, rewriter)
        rewriter.replace_op(op, ptr.LoadOp(target_ptr, memref_type.element_type))

__init__() -> None

match_and_rewrite(op: memref.LoadOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/convert_memref_to_ptr.py
173
174
175
176
177
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /):
    assert isa(memref_type := op.memref.type, memref.MemRefType)
    target_ptr = get_target_ptr(op.memref, memref_type, op.indices, rewriter)
    rewriter.replace_op(op, ptr.LoadOp(target_ptr, memref_type.element_type))

ConvertSubviewPattern

Bases: RewritePattern

Converts the subview to a pointer offset.

From the subview op documentation:

In the absence of rank reductions, the resulting memref type is computed as follows:

...
result_offset = src_offset + dot_product(offset_operands, src_strides)

The pointer that the source memref is lowered to is assumed to incorporate the src_offset, so this lowering just addds the dot product.

Source code in xdsl/transforms/convert_memref_to_ptr.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
class ConvertSubviewPattern(RewritePattern):
    """
    Converts the subview to a pointer offset.

    From the subview op documentation:

    > In the absence of rank reductions, the resulting memref type is computed as
    follows:
    ```
    ...
    result_offset = src_offset + dot_product(offset_operands, src_strides)
    ```

    The pointer that the source memref is lowered to is assumed to incorporate the
    `src_offset`, so this lowering just addds the dot product.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
        # The result of the subview op has the necessary information for downstream
        # users to perform indexing, we only need to translate the offset here.

        source_type = op.source.type
        assert isa(source_type, builtin.MemRefType)
        result_type = op.result.type
        element_type = result_type.element_type

        source_shape = source_type.get_shape()
        if builtin.DYNAMIC_INDEX in source_shape:
            raise PassFailedException(
                f"Cannot lower memref subview of memref type with dynamic "
                f"shape {source_type}."
            )
        source_strides = get_constant_strides(source_type)

        pointer = rewriter.insert_op(ptr.ToPtrOp(op.source)).res
        pointer.name_hint = op.source.name_hint
        # The new pointer
        head = None
        dynamic_offset_index = 0

        for stride, offset in zip(
            source_strides, op.static_offsets.iter_values(), strict=True
        ):
            if offset == builtin.DYNAMIC_INDEX:
                offset_val = op.offsets[dynamic_offset_index]
                dynamic_offset_index += 1
            else:
                offset_val = rewriter.insert_op(
                    arith.ConstantOp(builtin.IntegerAttr(offset, _index_type))
                ).result
                offset_val.name_hint = f"c{offset}"

            if stride == 1:
                increment = offset_val
            else:
                stride_val = rewriter.insert_op(
                    arith.ConstantOp(builtin.IntegerAttr(stride, _index_type))
                ).result
                increment = rewriter.insert_op(
                    arith.MuliOp(stride_val, offset_val)
                ).result
                stride_val.name_hint = f"c{stride}"
                increment.name_hint = "increment"

            if head is None:
                head = increment
            else:
                # Otherwise sum up the products.
                head = rewriter.insert_op(arith.AddiOp(head, increment)).result
                head.name_hint = "subview"

        if head is not None:
            offset = get_bytes_offset(head, element_type, rewriter)
            pointer = get_offset_pointer(pointer, offset, rewriter)

        rewriter.replace_op(op, ptr.FromPtrOp(pointer, result_type))

match_and_rewrite(op: memref.SubviewOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/convert_memref_to_ptr.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
    # The result of the subview op has the necessary information for downstream
    # users to perform indexing, we only need to translate the offset here.

    source_type = op.source.type
    assert isa(source_type, builtin.MemRefType)
    result_type = op.result.type
    element_type = result_type.element_type

    source_shape = source_type.get_shape()
    if builtin.DYNAMIC_INDEX in source_shape:
        raise PassFailedException(
            f"Cannot lower memref subview of memref type with dynamic "
            f"shape {source_type}."
        )
    source_strides = get_constant_strides(source_type)

    pointer = rewriter.insert_op(ptr.ToPtrOp(op.source)).res
    pointer.name_hint = op.source.name_hint
    # The new pointer
    head = None
    dynamic_offset_index = 0

    for stride, offset in zip(
        source_strides, op.static_offsets.iter_values(), strict=True
    ):
        if offset == builtin.DYNAMIC_INDEX:
            offset_val = op.offsets[dynamic_offset_index]
            dynamic_offset_index += 1
        else:
            offset_val = rewriter.insert_op(
                arith.ConstantOp(builtin.IntegerAttr(offset, _index_type))
            ).result
            offset_val.name_hint = f"c{offset}"

        if stride == 1:
            increment = offset_val
        else:
            stride_val = rewriter.insert_op(
                arith.ConstantOp(builtin.IntegerAttr(stride, _index_type))
            ).result
            increment = rewriter.insert_op(
                arith.MuliOp(stride_val, offset_val)
            ).result
            stride_val.name_hint = f"c{stride}"
            increment.name_hint = "increment"

        if head is None:
            head = increment
        else:
            # Otherwise sum up the products.
            head = rewriter.insert_op(arith.AddiOp(head, increment)).result
            head.name_hint = "subview"

    if head is not None:
        offset = get_bytes_offset(head, element_type, rewriter)
        pointer = get_offset_pointer(pointer, offset, rewriter)

    rewriter.replace_op(op, ptr.FromPtrOp(pointer, result_type))

LowerMemRefFuncOpPattern dataclass

Bases: RewritePattern

Rewrites function arguments of MemRefType to PtrType.

Source code in xdsl/transforms/convert_memref_to_ptr.py
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
@dataclass
class LowerMemRefFuncOpPattern(RewritePattern):
    """
    Rewrites function arguments of MemRefType to PtrType.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
        # rewrite function declaration
        new_input_types = [
            ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg
            for arg in op.function_type.inputs
        ]
        new_output_types = [
            ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg
            for arg in op.function_type.outputs
        ]
        op.function_type = func.FunctionType.from_lists(
            new_input_types,
            new_output_types,
        )

        if op.is_declaration:
            return

        insert_point = InsertPoint.at_start(op.body.blocks[0])

        # rewrite arguments
        for arg in op.args:
            if not isinstance(arg_type := arg.type, memref.MemRefType):
                continue

            old_type = cast(memref.MemRefType, arg_type)
            arg = rewriter.replace_value_with_new_type(arg, ptr.PtrType())

            if not arg.uses:
                continue

            rewriter.insert_op(
                cast_op := ptr.FromPtrOp(arg, old_type),
                insert_point,
            )
            rewriter.replace_uses_with_if(
                arg,
                cast_op.res,
                lambda x: x.operation is not cast_op,
            )

__init__() -> None

match_and_rewrite(op: func.FuncOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/convert_memref_to_ptr.py
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
    # rewrite function declaration
    new_input_types = [
        ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg
        for arg in op.function_type.inputs
    ]
    new_output_types = [
        ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg
        for arg in op.function_type.outputs
    ]
    op.function_type = func.FunctionType.from_lists(
        new_input_types,
        new_output_types,
    )

    if op.is_declaration:
        return

    insert_point = InsertPoint.at_start(op.body.blocks[0])

    # rewrite arguments
    for arg in op.args:
        if not isinstance(arg_type := arg.type, memref.MemRefType):
            continue

        old_type = cast(memref.MemRefType, arg_type)
        arg = rewriter.replace_value_with_new_type(arg, ptr.PtrType())

        if not arg.uses:
            continue

        rewriter.insert_op(
            cast_op := ptr.FromPtrOp(arg, old_type),
            insert_point,
        )
        rewriter.replace_uses_with_if(
            arg,
            cast_op.res,
            lambda x: x.operation is not cast_op,
        )

LowerMemRefFuncReturnPattern dataclass

Bases: RewritePattern

Rewrites all memref arguments to func.return into ptr.PtrType

Source code in xdsl/transforms/convert_memref_to_ptr.py
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
@dataclass
class LowerMemRefFuncReturnPattern(RewritePattern):
    """
    Rewrites all `memref` arguments to `func.return` into `ptr.PtrType`
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /):
        if not any(isinstance(arg.type, memref.MemRefType) for arg in op.arguments):
            return

        new_arguments: list[SSAValue] = []

        # insert `memref -> ptr` casts for memref return values
        for argument in op.arguments:
            if isinstance(argument.type, memref.MemRefType):
                rewriter.insert_op(cast_op := ptr.ToPtrOp(argument))
                new_arguments.append(cast_op.res)
                cast_op.res.name_hint = argument.name_hint
            else:
                new_arguments.append(argument)

        rewriter.replace_op(op, func.ReturnOp(*new_arguments))

__init__() -> None

match_and_rewrite(op: func.ReturnOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/convert_memref_to_ptr.py
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /):
    if not any(isinstance(arg.type, memref.MemRefType) for arg in op.arguments):
        return

    new_arguments: list[SSAValue] = []

    # insert `memref -> ptr` casts for memref return values
    for argument in op.arguments:
        if isinstance(argument.type, memref.MemRefType):
            rewriter.insert_op(cast_op := ptr.ToPtrOp(argument))
            new_arguments.append(cast_op.res)
            cast_op.res.name_hint = argument.name_hint
        else:
            new_arguments.append(argument)

    rewriter.replace_op(op, func.ReturnOp(*new_arguments))

LowerMemRefFuncCallPattern dataclass

Bases: RewritePattern

Source code in xdsl/transforms/convert_memref_to_ptr.py
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
@dataclass
class LowerMemRefFuncCallPattern(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /):
        if not any(
            isinstance(arg.type, memref.MemRefType) for arg in op.arguments
        ) and not any(isinstance(type, memref.MemRefType) for type in op.result_types):
            return

        # rewrite arguments
        new_arguments: list[SSAValue] = []

        # insert `memref -> ptr` casts for memref arguments values, if necessary
        for argument in op.arguments:
            if isinstance(argument.type, memref.MemRefType):
                rewriter.insert_op(cast_op := ptr.ToPtrOp(argument))
                new_arguments.append(cast_op.res)
                cast_op.res.name_hint = argument.name_hint
            else:
                new_arguments.append(argument)

        new_return_types = [
            ptr.PtrType() if isinstance(type, memref.MemRefType) else type
            for type in op.result_types
        ]

        new_ops: list[Operation] = [
            call_op := func.CallOp(op.callee, new_arguments, new_return_types)
        ]
        new_results = list(call_op.results)

        #  insert `ptr -> memref` casts for return values, if necessary
        for i, (new_result, old_result) in enumerate(zip(call_op.results, op.results)):
            new_result.name_hint = old_result.name_hint
            if isa(old_result.type, memref.MemRefType):
                new_ops.append(cast_op := ptr.FromPtrOp(new_result, old_result.type))
                new_results[i] = cast_op.res

        rewriter.replace_op(op, new_ops, new_results)

__init__() -> None

match_and_rewrite(op: func.CallOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/convert_memref_to_ptr.py
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /):
    if not any(
        isinstance(arg.type, memref.MemRefType) for arg in op.arguments
    ) and not any(isinstance(type, memref.MemRefType) for type in op.result_types):
        return

    # rewrite arguments
    new_arguments: list[SSAValue] = []

    # insert `memref -> ptr` casts for memref arguments values, if necessary
    for argument in op.arguments:
        if isinstance(argument.type, memref.MemRefType):
            rewriter.insert_op(cast_op := ptr.ToPtrOp(argument))
            new_arguments.append(cast_op.res)
            cast_op.res.name_hint = argument.name_hint
        else:
            new_arguments.append(argument)

    new_return_types = [
        ptr.PtrType() if isinstance(type, memref.MemRefType) else type
        for type in op.result_types
    ]

    new_ops: list[Operation] = [
        call_op := func.CallOp(op.callee, new_arguments, new_return_types)
    ]
    new_results = list(call_op.results)

    #  insert `ptr -> memref` casts for return values, if necessary
    for i, (new_result, old_result) in enumerate(zip(call_op.results, op.results)):
        new_result.name_hint = old_result.name_hint
        if isa(old_result.type, memref.MemRefType):
            new_ops.append(cast_op := ptr.FromPtrOp(new_result, old_result.type))
            new_results[i] = cast_op.res

    rewriter.replace_op(op, new_ops, new_results)

ConvertMemRefToPtr dataclass

Bases: ModulePass

Source code in xdsl/transforms/convert_memref_to_ptr.py
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
@dataclass(frozen=True)
class ConvertMemRefToPtr(ModulePass):
    name = "convert-memref-to-ptr"

    lower_func: bool = False

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        PatternRewriteWalker(
            GreedyRewritePatternApplier(
                [
                    ConvertStorePattern(),
                    ConvertLoadPattern(),
                    ConvertSubviewPattern(),
                ]
            )
        ).rewrite_module(op)

        if self.lower_func:
            PatternRewriteWalker(
                GreedyRewritePatternApplier(
                    [
                        LowerMemRefFuncOpPattern(),
                        LowerMemRefFuncCallPattern(),
                        LowerMemRefFuncReturnPattern(),
                    ]
                )
            ).rewrite_module(op)

name = 'convert-memref-to-ptr' class-attribute instance-attribute

lower_func: bool = False class-attribute instance-attribute

__init__(lower_func: bool = False) -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/convert_memref_to_ptr.py
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    PatternRewriteWalker(
        GreedyRewritePatternApplier(
            [
                ConvertStorePattern(),
                ConvertLoadPattern(),
                ConvertSubviewPattern(),
            ]
        )
    ).rewrite_module(op)

    if self.lower_func:
        PatternRewriteWalker(
            GreedyRewritePatternApplier(
                [
                    LowerMemRefFuncOpPattern(),
                    LowerMemRefFuncCallPattern(),
                    LowerMemRefFuncReturnPattern(),
                ]
            )
        ).rewrite_module(op)

get_bytes_offset(elements_offset: SSAValue, element_type: Attribute, builder: Builder) -> SSAValue

Returns the offset in bytes given an offset in elements and the element type.

Source code in xdsl/transforms/convert_memref_to_ptr.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def get_bytes_offset(
    elements_offset: SSAValue, element_type: Attribute, builder: Builder
) -> SSAValue:
    """
    Returns the offset in bytes given an offset in elements and the element type.
    """
    bytes_per_element_op = builder.insert_op(
        ptr.TypeOffsetOp(element_type, _index_type)
    )
    bytes_offset = builder.insert_op(
        arith.MuliOp(elements_offset, bytes_per_element_op)
    )
    bytes_per_element_op.offset.name_hint = "bytes_per_element"
    bytes_offset.result.name_hint = "scaled_pointer_offset"

    return bytes_offset.result

get_offset_pointer(pointer: SSAValue, bytes_offset: SSAValue, builder: Builder) -> SSAValue

Returns the pointer incremented by the given number of bytes.

Source code in xdsl/transforms/convert_memref_to_ptr.py
62
63
64
65
66
67
68
69
70
71
72
def get_offset_pointer(
    pointer: SSAValue,
    bytes_offset: SSAValue,
    builder: Builder,
) -> SSAValue:
    """
    Returns the pointer incremented by the given number of bytes.
    """
    target_ptr = builder.insert_op(ptr.PtrAddOp(pointer, bytes_offset))
    target_ptr.result.name_hint = "offset_pointer"
    return target_ptr.result

get_constant_strides(memref_type: builtin.MemRefType) -> Sequence[int]

If the memref has constant strides and offset, returns them, otherwise raises a DiagnosticException.

Source code in xdsl/transforms/convert_memref_to_ptr.py
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def get_constant_strides(memref_type: builtin.MemRefType) -> Sequence[int]:
    """
    If the memref has constant strides and offset, returns them, otherwise raises a
    DiagnosticException.
    """
    match memref_type.layout:
        case builtin.NoneAttr():
            strides = builtin.ShapedType.strides_for_shape(memref_type.get_shape())
        case builtin.StridedLayoutAttr():
            strides = memref_type.layout.get_strides()
            if None in strides:
                raise DiagnosticException(
                    f"MemRef {memref_type} with dynamic stride is not yet implemented"
                )
            strides = cast(Sequence[int], strides)
        case _:
            raise DiagnosticException(f"Unsupported layout type {memref_type.layout}")
    return strides

get_strides_offset(indices: Iterable[SSAValue], strides: Sequence[int], builder: Builder) -> SSAValue | None

Given SSA values for indices, and constant strides, insert the arithmetic ops that create the combined index offset. The length of indices and strides must be the same. Strides must be positive.

Source code in xdsl/transforms/convert_memref_to_ptr.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def get_strides_offset(
    indices: Iterable[SSAValue], strides: Sequence[int], builder: Builder
) -> SSAValue | None:
    """
    Given SSA values for indices, and constant strides, insert the arithmetic ops that
    create the combined index offset.
    The length of indices and strides must be the same.
    Strides must be positive.
    """
    head: SSAValue | None = None

    for index, stride in zip(indices, strides, strict=True):
        assert stride > 0, f"Strides must be positive, got {stride}"
        # Calculate the offset that needs to be added through the index of the current
        # dimension.
        increment = index

        # Stride 1 is a noop making the index equal to the offset.
        if stride != 1:
            # Otherwise, multiply the stride (which by definition is the number of
            # elements required to be skipped when incrementing that dimension).
            stride_op = builder.insert_op(
                arith.ConstantOp.from_int_and_width(stride, _index_type)
            )
            offset_op = builder.insert_op(arith.MuliOp(increment, stride_op))
            stride_op.result.name_hint = "pointer_dim_stride"
            offset_op.result.name_hint = "pointer_dim_offset"

            increment = offset_op.result

        if head is None:
            # First iteration.
            head = increment
            continue

        # Otherwise sum up the products.
        add_op = builder.insert_op(arith.AddiOp(head, increment))
        add_op.result.name_hint = "pointer_dim_stride"
        head = add_op.result

    return head

get_target_ptr(target_memref: SSAValue, memref_type: memref.MemRefType[Any], indices: Iterable[SSAValue], builder: Builder) -> SSAValue

Get operations returning a pointer to an element of a memref referenced by indices.

Source code in xdsl/transforms/convert_memref_to_ptr.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def get_target_ptr(
    target_memref: SSAValue,
    memref_type: memref.MemRefType[Any],
    indices: Iterable[SSAValue],
    builder: Builder,
) -> SSAValue:
    """
    Get operations returning a pointer to an element of a memref referenced by indices.
    """

    memref_ptr = builder.insert_op(ptr.ToPtrOp(target_memref))
    pointer = memref_ptr.res
    pointer.name_hint = target_memref.name_hint

    strides = get_constant_strides(memref_type)
    head = get_strides_offset(indices, strides, builder)

    if head is not None:
        offset = get_bytes_offset(head, memref_type.element_type, builder)
        pointer = get_offset_pointer(pointer, offset, builder)

    return pointer