Implements a lowering from operations on memrefs to operations on pointers.
The default lowering of memrefs in MLIR is to a structure that carries the base pointer,
an offset, and strides
(see [`memref.extract_strided_metadata`](https://mlir.llvm.org/docs/Dialects/MemRef/#memrefextract_strided_metadata-memrefextractstridedmetadataop)).
If the offset and strides are statically known, they can be encoded in the type, but
they can also be dynamic.
When lowering operations on memrefs to operations on pointers, the offset and stride
information must be lowered to operations that perform the required pointer arithmetic.
In order to simplify the lowering of memory accesses, this pass makes the choice to
lower memrefs to a pointer to the buffer including the dynamic offset.
This means that memory accesses can be a simple dot product of statically known strides
and dynamic indices.
On the other hand, operations that create views of memrefs from other memrefs must lower
to the relevant pointer arithmetic to encode the new inner buffer offset, when possible.
Bases: RewritePattern
Source code in xdsl/transforms/convert_memref_to_ptr.py
204
205
206
207
208
209
@dataclass
class ConvertStorePattern(RewritePattern):
    """
    Lowers `memref.store` to a `ptr.StoreOp` on the address of the indexed element.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.StoreOp, rewriter: PatternRewriter, /):
        # Bind the type outside the assert: with `python -O` asserts are stripped,
        # and a walrus inside one would leave `memref_type` unbound.
        memref_type = op.memref.type
        assert isa(memref_type, memref.MemRefType)
        # Address of the element selected by `op.indices`.
        target_ptr = build_target_ptr(op.memref, memref_type, op.indices, rewriter)
        rewriter.replace_op(op, ptr.StoreOp(target_ptr, op.value))
|
Source code in xdsl/transforms/convert_memref_to_ptr.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.StoreOp, rewriter: PatternRewriter, /):
    """Replace a `memref.store` with a `ptr.StoreOp` to the indexed element's address."""
    # Bind the type outside the assert: with `python -O` asserts are stripped,
    # and a walrus inside one would leave `memref_type` unbound.
    memref_type = op.memref.type
    assert isa(memref_type, memref.MemRefType)
    target_ptr = build_target_ptr(op.memref, memref_type, op.indices, rewriter)
    rewriter.replace_op(op, ptr.StoreOp(target_ptr, op.value))
|
Bases: RewritePattern
Source code in xdsl/transforms/convert_memref_to_ptr.py
213
214
215
216
217
218
@dataclass
class ConvertLoadPattern(RewritePattern):
    """
    Lowers `memref.load` to a `ptr.LoadOp` of the memref's element type from the
    address of the indexed element.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /):
        # Bind the type outside the assert: with `python -O` asserts are stripped,
        # and a walrus inside one would leave `memref_type` unbound.
        memref_type = op.memref.type
        assert isa(memref_type, memref.MemRefType)
        # Address of the element selected by `op.indices`.
        target_ptr = build_target_ptr(op.memref, memref_type, op.indices, rewriter)
        rewriter.replace_op(op, ptr.LoadOp(target_ptr, memref_type.element_type))
|
Source code in xdsl/transforms/convert_memref_to_ptr.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /):
    """Replace a `memref.load` with a `ptr.LoadOp` from the indexed element's address."""
    # Bind the type outside the assert: with `python -O` asserts are stripped,
    # and a walrus inside one would leave `memref_type` unbound.
    memref_type = op.memref.type
    assert isa(memref_type, memref.MemRefType)
    target_ptr = build_target_ptr(op.memref, memref_type, op.indices, rewriter)
    rewriter.replace_op(op, ptr.LoadOp(target_ptr, memref_type.element_type))
|
Bases: RewritePattern
Converts the subview to a pointer offset.
From the subview op documentation:
In the absence of rank reductions, the resulting memref type is computed as
follows:
...
result_offset = src_offset + dot_product(offset_operands, src_strides)
The pointer that the source memref is lowered to is assumed to incorporate the
src_offset, so this lowering just adds the dot product.
Source code in xdsl/transforms/convert_memref_to_ptr.py
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
class ConvertSubviewPattern(RewritePattern):
    """
    Converts the subview to a pointer offset.

    From the subview op documentation:

    > In the absence of rank reductions, the resulting memref type is computed as
    follows:
    ```
    ...
    result_offset = src_offset + dot_product(offset_operands, src_strides)
    ```

    The pointer that the source memref is lowered to is assumed to incorporate the
    `src_offset`, so this lowering just adds the dot product.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
        # The result type of the subview op already encodes the shape, strides, and
        # element type that downstream users need for indexing, so this rewrite only
        # emits the pointer offset. For dynamic source shapes the source strides can
        # be SSA values, which the loop below multiplies directly.
        source_type = op.source.type
        assert isa(source_type, builtin.MemRefType)
        result_type = op.result.type
        element_type = result_type.element_type
        # One stride per source dimension: an int when static, else an SSA value.
        source_strides = get_strides(op.source, source_type, rewriter)
        pointer = rewriter.insert_op(ptr.ToPtrOp(op.source)).res
        pointer.name_hint = op.source.name_hint
        # Running sum of offset * stride terms, in elements; None until the first term.
        head = None
        # Dynamic offsets are stored compactly in `op.offsets`; this cursor tracks
        # which one corresponds to the next DYNAMIC_INDEX entry in `static_offsets`.
        dynamic_offset_index = 0
        for stride, offset in zip(
            source_strides, op.static_offsets.iter_values(), strict=True
        ):
            if offset == builtin.DYNAMIC_INDEX:
                offset_val = op.offsets[dynamic_offset_index]
                dynamic_offset_index += 1
            else:
                # NOTE: static offsets are always materialized, including 0.
                offset_val = rewriter.insert_op(
                    arith.ConstantOp(builtin.IntegerAttr(offset, _index_type))
                ).result
                offset_val.name_hint = f"c{offset}"
            if stride == 1:
                # Unit stride: the offset is already the element increment.
                increment = offset_val
            elif isinstance(stride, int):
                # Static non-unit stride: materialize it and multiply.
                stride_val = rewriter.insert_op(
                    arith.ConstantOp(builtin.IntegerAttr(stride, _index_type))
                ).result
                increment = rewriter.insert_op(
                    arith.MuliOp(stride_val, offset_val)
                ).result
                stride_val.name_hint = f"c{stride}"
                increment.name_hint = "increment"
            else:
                # Dynamic stride: it is already an SSA value.
                increment = rewriter.insert_op(arith.MuliOp(stride, offset_val)).result
                increment.name_hint = "increment"
            if head is None:
                head = increment
            else:
                # Otherwise sum up the products.
                head = rewriter.insert_op(arith.AddiOp(head, increment)).result
                head.name_hint = "subview"
        if head is not None:
            # Convert the element offset to bytes and advance the source pointer.
            offset = build_bytes_offset(head, element_type, rewriter)
            pointer = build_offset_pointer(pointer, offset, rewriter)
        rewriter.replace_op(op, ptr.FromPtrOp(pointer, result_type))
|
Source code in xdsl/transforms/convert_memref_to_ptr.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
    """
    Replace a `memref.subview` with the source pointer advanced by
    dot_product(offsets, source_strides) elements, cast back to the result type.
    """
    # The result type of the subview op already encodes the shape, strides, and
    # element type that downstream users need for indexing, so this rewrite only
    # emits the pointer offset. For dynamic source shapes the source strides can
    # be SSA values, which the loop below multiplies directly.
    source_type = op.source.type
    assert isa(source_type, builtin.MemRefType)
    result_type = op.result.type
    element_type = result_type.element_type
    # One stride per source dimension: an int when static, else an SSA value.
    source_strides = get_strides(op.source, source_type, rewriter)
    pointer = rewriter.insert_op(ptr.ToPtrOp(op.source)).res
    pointer.name_hint = op.source.name_hint
    # Running sum of offset * stride terms, in elements; None until the first term.
    head = None
    # Cursor into `op.offsets` for the next DYNAMIC_INDEX entry of `static_offsets`.
    dynamic_offset_index = 0
    for stride, offset in zip(
        source_strides, op.static_offsets.iter_values(), strict=True
    ):
        if offset == builtin.DYNAMIC_INDEX:
            offset_val = op.offsets[dynamic_offset_index]
            dynamic_offset_index += 1
        else:
            # NOTE: static offsets are always materialized, including 0.
            offset_val = rewriter.insert_op(
                arith.ConstantOp(builtin.IntegerAttr(offset, _index_type))
            ).result
            offset_val.name_hint = f"c{offset}"
        if stride == 1:
            # Unit stride: the offset is already the element increment.
            increment = offset_val
        elif isinstance(stride, int):
            # Static non-unit stride: materialize it and multiply.
            stride_val = rewriter.insert_op(
                arith.ConstantOp(builtin.IntegerAttr(stride, _index_type))
            ).result
            increment = rewriter.insert_op(
                arith.MuliOp(stride_val, offset_val)
            ).result
            stride_val.name_hint = f"c{stride}"
            increment.name_hint = "increment"
        else:
            # Dynamic stride: it is already an SSA value.
            increment = rewriter.insert_op(arith.MuliOp(stride, offset_val)).result
            increment.name_hint = "increment"
        if head is None:
            head = increment
        else:
            # Otherwise sum up the products.
            head = rewriter.insert_op(arith.AddiOp(head, increment)).result
            head.name_hint = "subview"
    if head is not None:
        # Convert the element offset to bytes and advance the source pointer.
        offset = build_bytes_offset(head, element_type, rewriter)
        pointer = build_offset_pointer(pointer, offset, rewriter)
    rewriter.replace_op(op, ptr.FromPtrOp(pointer, result_type))
|
Bases: RewritePattern
Rewrites function arguments of MemRefType to PtrType.
Source code in xdsl/transforms/convert_memref_to_ptr.py
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
@dataclass
class LowerMemRefFuncOpPattern(RewritePattern):
    """
    Replaces memref types with `ptr.PtrType` in a `func.func` signature (inputs and
    outputs). For functions with a body, each memref block argument is retyped to a
    pointer; if it has uses, a `ptr.FromPtrOp` at the start of the entry block
    recreates the memref view for those uses.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
        signature = op.function_type
        lowered_inputs = [
            ptr.PtrType() if isinstance(t, builtin.MemRefType) else t
            for t in signature.inputs
        ]
        lowered_outputs = [
            ptr.PtrType() if isinstance(t, builtin.MemRefType) else t
            for t in signature.outputs
        ]
        op.function_type = func.FunctionType.from_lists(
            lowered_inputs,
            lowered_outputs,
        )

        # A declaration has no body, so there are no block arguments to retype.
        if op.is_declaration:
            return

        entry_start = InsertPoint.at_start(op.body.blocks[0])
        for block_arg in op.args:
            arg_type = block_arg.type
            if not isinstance(arg_type, memref.MemRefType):
                continue
            memref_type = cast(memref.MemRefType, arg_type)
            ptr_arg = rewriter.replace_value_with_new_type(block_arg, ptr.PtrType())
            if not ptr_arg.uses:
                continue
            # Existing users still expect a memref: rebuild it from the pointer
            # and reroute every use except the cast's own operand.
            from_ptr = ptr.FromPtrOp(ptr_arg, memref_type)
            rewriter.insert_op(from_ptr, entry_start)
            rewriter.replace_uses_with_if(
                ptr_arg,
                from_ptr.res,
                lambda use: use.operation is not from_ptr,
            )
|
Source code in xdsl/transforms/convert_memref_to_ptr.py
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
    """
    Retype memref inputs/outputs of `op` to `ptr.PtrType`; for defined functions,
    retype memref block arguments and reinsert `ptr.FromPtrOp` casts for their uses.
    """
    signature = op.function_type
    lowered_inputs = [
        ptr.PtrType() if isinstance(t, builtin.MemRefType) else t
        for t in signature.inputs
    ]
    lowered_outputs = [
        ptr.PtrType() if isinstance(t, builtin.MemRefType) else t
        for t in signature.outputs
    ]
    op.function_type = func.FunctionType.from_lists(
        lowered_inputs,
        lowered_outputs,
    )

    # A declaration has no body, so there are no block arguments to retype.
    if op.is_declaration:
        return

    entry_start = InsertPoint.at_start(op.body.blocks[0])
    for block_arg in op.args:
        arg_type = block_arg.type
        if not isinstance(arg_type, memref.MemRefType):
            continue
        memref_type = cast(memref.MemRefType, arg_type)
        ptr_arg = rewriter.replace_value_with_new_type(block_arg, ptr.PtrType())
        if not ptr_arg.uses:
            continue
        # Existing users still expect a memref: rebuild it from the pointer and
        # reroute every use except the cast's own operand.
        from_ptr = ptr.FromPtrOp(ptr_arg, memref_type)
        rewriter.insert_op(from_ptr, entry_start)
        rewriter.replace_uses_with_if(
            ptr_arg,
            from_ptr.res,
            lambda use: use.operation is not from_ptr,
        )
|
Bases: RewritePattern
Rewrites all memref arguments to func.return into ptr.PtrType
Source code in xdsl/transforms/convert_memref_to_ptr.py
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
@dataclass
class LowerMemRefFuncReturnPattern(RewritePattern):
    """
    Rewrites `func.return` so every memref operand is first converted with a
    `ptr.ToPtrOp`; non-memref operands are passed through unchanged.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /):
        operands = op.arguments
        # Nothing to do unless at least one returned value is a memref.
        if all(not isinstance(v.type, memref.MemRefType) for v in operands):
            return
        lowered: list[SSAValue] = []
        for value in operands:
            if not isinstance(value.type, memref.MemRefType):
                lowered.append(value)
                continue
            to_ptr = ptr.ToPtrOp(value)
            rewriter.insert_op(to_ptr)
            lowered.append(to_ptr.res)
            to_ptr.res.name_hint = value.name_hint
        rewriter.replace_op(op, func.ReturnOp(*lowered))
|
Source code in xdsl/transforms/convert_memref_to_ptr.py
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /):
    """Replace `op` with a `func.return` whose memref operands are cast to pointers."""
    operands = op.arguments
    # Nothing to do unless at least one returned value is a memref.
    if all(not isinstance(v.type, memref.MemRefType) for v in operands):
        return
    lowered: list[SSAValue] = []
    for value in operands:
        if not isinstance(value.type, memref.MemRefType):
            lowered.append(value)
            continue
        to_ptr = ptr.ToPtrOp(value)
        rewriter.insert_op(to_ptr)
        lowered.append(to_ptr.res)
        to_ptr.res.name_hint = value.name_hint
    rewriter.replace_op(op, func.ReturnOp(*lowered))
|
Bases: RewritePattern
Source code in xdsl/transforms/convert_memref_to_ptr.py
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
@dataclass
class LowerMemRefFuncCallPattern(RewritePattern):
    """
    Rewrites `func.call` ops that pass or return memrefs. Memref arguments are
    converted with `ptr.ToPtrOp`, memref result types become `ptr.PtrType`, and
    each memref result is converted back with `ptr.FromPtrOp` for existing users.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /):
        if not any(
            isinstance(arg.type, memref.MemRefType) for arg in op.arguments
        ) and not any(
            isinstance(res_type, memref.MemRefType) for res_type in op.result_types
        ):
            return

        # insert `memref -> ptr` casts for memref arguments values, if necessary
        new_arguments: list[SSAValue] = []
        for argument in op.arguments:
            if isinstance(argument.type, memref.MemRefType):
                rewriter.insert_op(to_ptr_op := ptr.ToPtrOp(argument))
                new_arguments.append(to_ptr_op.res)
                to_ptr_op.res.name_hint = argument.name_hint
            else:
                new_arguments.append(argument)

        new_return_types = [
            ptr.PtrType() if isinstance(res_type, memref.MemRefType) else res_type
            for res_type in op.result_types
        ]
        new_ops: list[Operation] = [
            call_op := func.CallOp(op.callee, new_arguments, new_return_types)
        ]
        new_results = list(call_op.results)

        # insert `ptr -> memref` casts for return values, if necessary
        for i, (new_result, old_result) in enumerate(
            zip(call_op.results, op.results, strict=True)
        ):
            new_result.name_hint = old_result.name_hint
            if isa(old_result.type, memref.MemRefType):
                new_ops.append(
                    from_ptr_op := ptr.FromPtrOp(new_result, old_result.type)
                )
                new_results[i] = from_ptr_op.res
        rewriter.replace_op(op, new_ops, new_results)
|
Source code in xdsl/transforms/convert_memref_to_ptr.py
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /):
    """
    Replace a `func.call` involving memrefs with one that passes/returns pointers,
    casting memref arguments with `ptr.ToPtrOp` and memref results back with
    `ptr.FromPtrOp`.
    """
    if not any(
        isinstance(arg.type, memref.MemRefType) for arg in op.arguments
    ) and not any(
        isinstance(res_type, memref.MemRefType) for res_type in op.result_types
    ):
        return

    # insert `memref -> ptr` casts for memref arguments values, if necessary
    new_arguments: list[SSAValue] = []
    for argument in op.arguments:
        if isinstance(argument.type, memref.MemRefType):
            rewriter.insert_op(to_ptr_op := ptr.ToPtrOp(argument))
            new_arguments.append(to_ptr_op.res)
            to_ptr_op.res.name_hint = argument.name_hint
        else:
            new_arguments.append(argument)

    new_return_types = [
        ptr.PtrType() if isinstance(res_type, memref.MemRefType) else res_type
        for res_type in op.result_types
    ]
    new_ops: list[Operation] = [
        call_op := func.CallOp(op.callee, new_arguments, new_return_types)
    ]
    new_results = list(call_op.results)

    # insert `ptr -> memref` casts for return values, if necessary
    for i, (new_result, old_result) in enumerate(
        zip(call_op.results, op.results, strict=True)
    ):
        new_result.name_hint = old_result.name_hint
        if isa(old_result.type, memref.MemRefType):
            new_ops.append(from_ptr_op := ptr.FromPtrOp(new_result, old_result.type))
            new_results[i] = from_ptr_op.res
    rewriter.replace_op(op, new_ops, new_results)
|
Bases: RewritePattern
Source code in xdsl/transforms/convert_memref_to_ptr.py
@dataclass
class ConvertCastOp(RewritePattern):
    """
    Erases `memref.cast`: uses of its result are redirected to its source operand
    and no pointer arithmetic is emitted.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: memref.CastOp, rewriter: PatternRewriter, /):
        source = op.source
        assert isa(source.type, memref.MemRefType)
        rewriter.replace_op(op, (), (source,))
|
Source code in xdsl/transforms/convert_memref_to_ptr.py
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.CastOp, rewriter: PatternRewriter, /):
    """Drop the cast, replacing its result with the source memref value."""
    source = op.source
    assert isa(source.type, memref.MemRefType)
    rewriter.replace_op(op, (), (source,))
|
Bases: RewritePattern
Source code in xdsl/transforms/convert_memref_to_ptr.py
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
@dataclass
class ConvertReinterpretCastOp(RewritePattern):
    """
    Lowers `memref.reinterpret_cast` to a pointer advanced by the op's element
    offset (static or dynamic), cast back to the result memref type. A static
    offset of 0 emits no pointer arithmetic.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: memref.ReinterpretCastOp, rewriter: PatternRewriter, /
    ):
        pointer = rewriter.insert_op(ptr.ToPtrOp(op.source)).res
        pointer.name_hint = op.source.name_hint

        # The op carries a single flat element offset.
        # NOTE(review): the offset is applied to the source pointer, which in this
        # lowering already includes the source's own offset; confirm this matches
        # reinterpret_cast semantics for sources with a non-zero offset.
        static_offset = next(iter(op.static_offsets.iter_values()))
        if static_offset != 0:
            if static_offset != builtin.DYNAMIC_INDEX:
                const_op = arith.ConstantOp(
                    builtin.IntegerAttr(static_offset, _index_type)
                )
                offset_val = rewriter.insert_op(const_op).result
                offset_val.name_hint = f"c{static_offset}"
            else:
                offset_val = op.offsets[0]
            result_element_type = op.result.type.element_type
            byte_offset = build_bytes_offset(offset_val, result_element_type, rewriter)
            pointer = build_offset_pointer(pointer, byte_offset, rewriter)

        rewriter.replace_op(op, ptr.FromPtrOp(pointer, op.result.type))
|
Source code in xdsl/transforms/convert_memref_to_ptr.py
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
@op_type_rewrite_pattern
def match_and_rewrite(
    self, op: memref.ReinterpretCastOp, rewriter: PatternRewriter, /
):
    """
    Replace `op` with the source pointer advanced by its flat element offset
    (skipped when the static offset is 0), cast to the result memref type.
    """
    pointer = rewriter.insert_op(ptr.ToPtrOp(op.source)).res
    pointer.name_hint = op.source.name_hint

    # The op carries a single flat element offset.
    # NOTE(review): the offset is applied to the source pointer, which in this
    # lowering already includes the source's own offset; confirm this matches
    # reinterpret_cast semantics for sources with a non-zero offset.
    static_offset = next(iter(op.static_offsets.iter_values()))
    if static_offset != 0:
        if static_offset != builtin.DYNAMIC_INDEX:
            const_op = arith.ConstantOp(builtin.IntegerAttr(static_offset, _index_type))
            offset_val = rewriter.insert_op(const_op).result
            offset_val.name_hint = f"c{static_offset}"
        else:
            offset_val = op.offsets[0]
        result_element_type = op.result.type.element_type
        byte_offset = build_bytes_offset(offset_val, result_element_type, rewriter)
        pointer = build_offset_pointer(pointer, byte_offset, rewriter)

    rewriter.replace_op(op, ptr.FromPtrOp(pointer, op.result.type))
|
Bases: ModulePass
Source code in xdsl/transforms/convert_memref_to_ptr.py
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
@dataclass(frozen=True)
class ConvertMemRefToPtr(ModulePass):
    """
    Lowers memref store/load/subview/cast/reinterpret_cast ops to `ptr` operations.
    When `lower_func` is set, a second walk also rewrites memref types in
    `func.func` signatures, `func.call` and `func.return`.
    """

    name = "convert-memref-to-ptr"

    lower_func: bool = False

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        memref_patterns = GreedyRewritePatternApplier(
            [
                ConvertStorePattern(),
                ConvertLoadPattern(),
                ConvertSubviewPattern(),
                ConvertCastOp(),
                ConvertReinterpretCastOp(),
            ]
        )
        PatternRewriteWalker(memref_patterns).rewrite_module(op)

        if not self.lower_func:
            return

        func_patterns = GreedyRewritePatternApplier(
            [
                LowerMemRefFuncOpPattern(),
                LowerMemRefFuncCallPattern(),
                LowerMemRefFuncReturnPattern(),
            ]
        )
        PatternRewriteWalker(func_patterns).rewrite_module(op)
|
Source code in xdsl/transforms/convert_memref_to_ptr.py
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """
    Run the memref-op lowering walk over `op`, then, if `self.lower_func` is set,
    a second walk that lowers memref types at function boundaries.
    """
    memref_patterns = GreedyRewritePatternApplier(
        [
            ConvertStorePattern(),
            ConvertLoadPattern(),
            ConvertSubviewPattern(),
            ConvertCastOp(),
            ConvertReinterpretCastOp(),
        ]
    )
    PatternRewriteWalker(memref_patterns).rewrite_module(op)

    if not self.lower_func:
        return

    func_patterns = GreedyRewritePatternApplier(
        [
            LowerMemRefFuncOpPattern(),
            LowerMemRefFuncCallPattern(),
            LowerMemRefFuncReturnPattern(),
        ]
    )
    PatternRewriteWalker(func_patterns).rewrite_module(op)
|
Returns the offset in bytes given an offset in elements and the element type.
Source code in xdsl/transforms/convert_memref_to_ptr.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def build_bytes_offset(
    elements_offset: SSAValue, element_type: Attribute, builder: Builder
) -> SSAValue:
    """
    Scale an element count into a byte count.

    Inserts a `ptr.TypeOffsetOp` for `element_type` (named `bytes_per_element`)
    followed by an `arith.MuliOp` of `elements_offset` by it, and returns the
    product (named `scaled_pointer_offset`).
    """
    element_size = builder.insert_op(ptr.TypeOffsetOp(element_type, _index_type))
    scaled = builder.insert_op(arith.MuliOp(elements_offset, element_size))

    element_size.offset.name_hint = "bytes_per_element"
    result = scaled.result
    result.name_hint = "scaled_pointer_offset"
    return result
|
Returns the pointer incremented by the given number of bytes.
Source code in xdsl/transforms/convert_memref_to_ptr.py
62
63
64
65
66
67
68
69
70
71
def build_offset_pointer(
    pointer: SSAValue,
    bytes_offset: SSAValue,
    builder: Builder,
) -> SSAValue:
    """
    Advance `pointer` by `bytes_offset` bytes.

    Inserts a single `ptr.PtrAddOp` and returns its result, named `offset_pointer`.
    """
    advanced = builder.insert_op(ptr.PtrAddOp(pointer, bytes_offset)).result
    advanced.name_hint = "offset_pointer"
    return advanced
|
Returns the per-dimension strides of memref_val. Each stride is an int when
statically known, or an SSA value when it depends on dynamic source dimensions
(for which memref.dim ops are emitted).
Source code in xdsl/transforms/convert_memref_to_ptr.py
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def get_strides(
    memref_val: SSAValue,
    memref_type: builtin.MemRefType,
    builder: Builder,
) -> Sequence[int | SSAValue]:
    """
    Compute the per-dimension strides (in elements) of `memref_val`.

    - With a `StridedLayoutAttr`, its strides are returned as-is; a dynamic stride
      raises `DiagnosticException`.
    - With no layout and a fully static shape, row-major strides are computed
      statically.
    - With no layout and some dynamic dimensions, strides are built from the
      innermost dimension outwards. A `memref.dim` op is emitted for each dynamic
      size that is needed, and `arith` ops multiply any non-constant factors.
      Static results stay Python ints.

    Any other layout raises `DiagnosticException`.
    """
    layout = memref_type.layout
    if isinstance(layout, builtin.StridedLayoutAttr):
        declared = layout.get_strides()
        if None in declared:
            raise DiagnosticException(
                f"MemRef {memref_type} with dynamic stride is not yet implemented"
            )
        return cast(Sequence[int], declared)
    if not isinstance(layout, builtin.NoneAttr):
        raise DiagnosticException(f"Unsupported layout type {memref_type.layout}")

    shape = memref_type.get_shape()
    if builtin.DYNAMIC_INDEX not in shape:
        return builtin.ShapedType.strides_for_shape(shape)

    # Innermost stride is 1; walk outwards, stride[i] = stride[i + 1] * shape[i + 1].
    strides: list[int | SSAValue] = [1] * len(shape)
    for i in reversed(range(len(shape) - 1)):
        inner: int | SSAValue = shape[i + 1]
        if inner == builtin.DYNAMIC_INDEX:
            dim_index = builder.insert_op(
                arith.ConstantOp.from_int_and_width(i + 1, _index_type)
            )
            dim_index.result.name_hint = "dim_idx"
            inner = builder.insert_op(
                memref.DimOp.from_source_and_index(memref_val, dim_index.result)
            ).result

        outer = strides[i + 1]
        if isinstance(outer, int) and isinstance(inner, int):
            # Both factors known: fold statically.
            strides[i] = outer * inner
        elif outer == 1:
            strides[i] = inner
        elif inner == 1:
            strides[i] = outer
        else:
            # At least one factor is an SSA value: materialize ints and multiply.
            if isinstance(outer, int):
                outer = builder.insert_op(
                    arith.ConstantOp.from_int_and_width(outer, _index_type)
                ).result
            if isinstance(inner, int):
                inner = builder.insert_op(
                    arith.ConstantOp.from_int_and_width(inner, _index_type)
                ).result
            strides[i] = builder.insert_op(arith.MuliOp(outer, inner)).result
    return strides
|
Given SSA values for indices, and strides (each either an int or an SSA value),
insert the arithmetic ops that create the combined index offset.
The length of indices and strides must be the same.
Static strides must be positive.
Source code in xdsl/transforms/convert_memref_to_ptr.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def build_strides_offset(
    indices: Iterable[SSAValue],
    strides: Sequence[int | SSAValue],
    builder: Builder,
) -> SSAValue | None:
    """
    Emit arithmetic computing sum(index * stride) over paired indices and strides.

    Static strides (ints) must be positive; a stride of 1 emits no multiply.
    Dynamic strides (SSA values) are multiplied directly. `indices` and `strides`
    must have equal length (checked by `zip(strict=True)`). Returns None when there
    are no dimensions.
    """
    total: SSAValue | None = None
    for index, stride in zip(indices, strides, strict=True):
        term = index
        if not isinstance(stride, int):
            scaled = builder.insert_op(arith.MuliOp(term, stride))
            scaled.result.name_hint = "pointer_dim_offset"
            term = scaled.result
        else:
            assert stride > 0, f"Strides must be positive, got {stride}"
            if stride != 1:
                stride_const = builder.insert_op(
                    arith.ConstantOp.from_int_and_width(stride, _index_type)
                )
                scaled = builder.insert_op(arith.MuliOp(term, stride_const))
                stride_const.result.name_hint = "pointer_dim_stride"
                scaled.result.name_hint = "pointer_dim_offset"
                term = scaled.result

        if total is None:
            total = term
        else:
            summed = builder.insert_op(arith.AddiOp(total, term))
            summed.result.name_hint = "pointer_dim_stride"
            total = summed.result
    return total
|
Build operations returning a pointer to an element of a memref referenced by indices.
Source code in xdsl/transforms/convert_memref_to_ptr.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def build_target_ptr(
    target_memref: SSAValue,
    memref_type: memref.MemRefType[Any],
    indices: Iterable[SSAValue],
    builder: Builder,
) -> SSAValue:
    """
    Emit ops yielding a pointer to the element of `target_memref` at `indices`.

    The memref is converted with `ptr.ToPtrOp`. The strided element offset is
    computed, scaled to bytes and added to that pointer. With no indices, the
    converted pointer itself is returned.
    """
    base = builder.insert_op(ptr.ToPtrOp(target_memref)).res
    base.name_hint = target_memref.name_hint

    element_offset = build_strides_offset(
        indices, get_strides(target_memref, memref_type, builder), builder
    )
    if element_offset is None:
        return base

    byte_offset = build_bytes_offset(element_offset, memref_type.element_type, builder)
    return build_offset_pointer(base, byte_offset, builder)
|