Skip to content

Vector split load extract

vector_split_load_extract

VectorSplitLoadExtract

Bases: RewritePattern

Source code in xdsl/transforms/vector_split_load_extract.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class VectorSplitLoadExtract(RewritePattern):
    """
    Replaces a vector `ptr.LoadOp` whose result is only consumed by
    `vector.ExtractOp`s with static positions by one scalar load per extract.

    For every extract, a constant offset, a `ptr.PtrAddOp` on the original
    address and a scalar `ptr.LoadOp` are created, and the extract result is
    replaced by the scalar load result. The original vector load is erased.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ptr.LoadOp, rewriter: PatternRewriter) -> None:
        """
        Rewrite `op` if it loads a 1-D vector used only by static extracts.

        Returns without changes when the loaded value is not a vector, or when
        any use is not a `vector.ExtractOp` with a purely static position.

        Raises:
            NotImplementedError: if the element type is not a
                `CompileTimeFixedBitwidthType`, or the vector is not 1-D.
        """
        # A load of a non-vector value is not a candidate. Without this guard
        # an unused scalar load would pass the (vacuously true) check on its
        # uses below and then fail when reading `element_type`.
        vector_type = op.res.type
        if not isinstance(vector_type, VectorType):
            return

        # Snapshot the uses: the loop below replaces and erases the users,
        # which mutates the live use collection of `op.res`.
        uses = tuple(op.res.uses)

        if not all(
            isinstance(use.operation, vector.ExtractOp)
            and not use.operation.dynamic_position
            for use in uses
        ):
            return

        element_type = vector_type.element_type
        if not isinstance(element_type, CompileTimeFixedBitwidthType):
            raise NotImplementedError(
                "Only element types with a compile-time fixed bitwidth are supported"
            )

        if len(vector_type.shape) != 1:
            raise NotImplementedError("Only 1-D vector loads are supported")

        # Presumably the element size in bytes, used as the pointer stride.
        element_size = element_type.compile_time_size
        name_hint = op.res.name_hint

        for use in uses:
            user = cast(vector.ExtractOp, use.operation)
            indices = user.static_position.get_values()
            # A 1-D vector is indexed by exactly one static position.
            assert len(indices) == 1
            (element_index,) = indices
            rewriter.insert_op(
                (
                    constant_op := arith.ConstantOp(
                        IntegerAttr(element_size * element_index, IndexType())
                    ),
                    add_op := ptr.PtrAddOp(op.addr, constant_op.result),
                    load_op := ptr.LoadOp(add_op.result, element_type),
                )
            )
            rewriter.replace_all_uses_with(user.result, load_op.res)
            rewriter.erase_op(user)
            # Propagate the name hint of the original load to the new values.
            constant_op.result.name_hint = name_hint
            add_op.result.name_hint = name_hint
            load_op.res.name_hint = name_hint

        # All extracts are gone, so the vector load has no remaining uses.
        rewriter.erase_op(op)

match_and_rewrite(op: ptr.LoadOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/vector_split_load_extract.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ptr.LoadOp, rewriter: PatternRewriter) -> None:
    """
    Rewrite `op` if it loads a 1-D vector used only by static extracts.

    Each `vector.ExtractOp` user is replaced by a constant offset, a
    `ptr.PtrAddOp` on the original address and a scalar `ptr.LoadOp`; the
    original vector load is then erased.

    Returns without changes when the loaded value is not a vector, or when
    any use is not a `vector.ExtractOp` with a purely static position.

    Raises:
        NotImplementedError: if the element type is not a
            `CompileTimeFixedBitwidthType`, or the vector is not 1-D.
    """
    # A load of a non-vector value is not a candidate. Without this guard an
    # unused scalar load would pass the (vacuously true) check on its uses
    # below and then fail when reading `element_type`.
    vector_type = op.res.type
    if not isinstance(vector_type, VectorType):
        return

    # Snapshot the uses: the loop below replaces and erases the users, which
    # mutates the live use collection of `op.res`.
    uses = tuple(op.res.uses)

    if not all(
        isinstance(use.operation, vector.ExtractOp)
        and not use.operation.dynamic_position
        for use in uses
    ):
        return

    element_type = vector_type.element_type
    if not isinstance(element_type, CompileTimeFixedBitwidthType):
        raise NotImplementedError(
            "Only element types with a compile-time fixed bitwidth are supported"
        )

    if len(vector_type.shape) != 1:
        raise NotImplementedError("Only 1-D vector loads are supported")

    # Presumably the element size in bytes, used as the pointer stride.
    element_size = element_type.compile_time_size
    name_hint = op.res.name_hint

    for use in uses:
        user = cast(vector.ExtractOp, use.operation)
        indices = user.static_position.get_values()
        # A 1-D vector is indexed by exactly one static position.
        assert len(indices) == 1
        (element_index,) = indices
        rewriter.insert_op(
            (
                constant_op := arith.ConstantOp(
                    IntegerAttr(element_size * element_index, IndexType())
                ),
                add_op := ptr.PtrAddOp(op.addr, constant_op.result),
                load_op := ptr.LoadOp(add_op.result, element_type),
            )
        )
        rewriter.replace_all_uses_with(user.result, load_op.res)
        rewriter.erase_op(user)
        # Propagate the name hint of the original load to the new values.
        constant_op.result.name_hint = name_hint
        add_op.result.name_hint = name_hint
        load_op.res.name_hint = name_hint

    # All extracts are gone, so the vector load has no remaining uses.
    rewriter.erase_op(op)
VectorSplitLoadExtractPass dataclass

Bases: ModulePass

Rewrites a vector load followed only by extracts with scalar loads.

Source code in xdsl/transforms/vector_split_load_extract.py
66
67
68
69
70
71
72
73
74
75
76
77
78
@dataclass(frozen=True)
class VectorSplitLoadExtractPass(ModulePass):
    """
    Rewrites a vector load followed only by extracts with scalar loads.

    The work is delegated to the `VectorSplitLoadExtract` rewrite pattern,
    which is applied over the module by a non-recursive pattern walker.
    """

    name = "vector-split-load-extract"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """Run the `VectorSplitLoadExtract` pattern over the module `op`."""
        pattern = VectorSplitLoadExtract()
        walker = PatternRewriteWalker(pattern, apply_recursively=False)
        walker.rewrite_module(op)

name = 'vector-split-load-extract' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/vector_split_load_extract.py
74
75
76
77
78
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Run the `VectorSplitLoadExtract` pattern over the module `op`."""
    pattern = VectorSplitLoadExtract()
    walker = PatternRewriteWalker(pattern, apply_recursively=False)
    walker.rewrite_module(op)