
Test vectorize matmul

test_vectorize_matmul

VectorizeMatmulOp dataclass

Bases: RewritePattern

Source code in xdsl/transforms/test_vectorize_matmul.py
@dataclass
class VectorizeMatmulOp(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: linalg.MatmulOp, rewriter: PatternRewriter, /):
        # C += A * B
        # C: M x N, A: M x K, B: K x N

        a, b = op.inputs
        c = op.outputs[0]

        a_type = a.type
        b_type = b.type
        c_type = c.type

        # Only handle matmul on memrefs for now
        if (
            not isa(a_type, builtin.MemRefType)
            or not isa(b_type, builtin.MemRefType)
            or not isa(c_type, builtin.MemRefType)
        ):
            raise DiagnosticException(
                "Vectorizing matmul on tensors not yet implemented."
            )

        M, K = a_type.get_shape()
        _K, N = b_type.get_shape()
        _M, _N = c_type.get_shape()

        assert M == _M
        assert N == _N
        assert K == _K
        assert M != -1
        assert N != -1
        assert K != -1

        vector_type = builtin.VectorType(a_type.element_type, (N,))

        # All operations created inside this block are inserted before the matched op
        with ImplicitBuilder(rewriter):
            # Insert all the integer constants we'll need to index into the matrices
            constants = tuple(
                arith.ConstantOp(builtin.IntegerAttr(i, _index_type)).result
                for i in range(max(M, N, K))
            )
            # Zero for convenience
            c0 = constants[0]

            # Load the rows of C as vectors
            c_rows = [
                vector.LoadOp(c, (constants[m], c0), vector_type).result
                for m in range(M)
            ]

            # Load the rows of B as vectors
            b_rows = tuple(
                vector.LoadOp(b, (constants[k], c0), vector_type).result
                for k in range(K)
            )

            for m in range(M):
                # Load the mth row of A as scalars
                a_row = tuple(
                    memref.LoadOp.get(a, (constants[m], constants[k])).res
                    for k in range(K)
                )
                # Broadcast each element of the mth row of A to a vector
                a_row_vectors = tuple(
                    vector.BroadcastOp(a_row[k], vector_type) for k in range(K)
                )

                for k in range(K):
                    # Accumulate A[m, k] * B[k, :] into row m of C
                    # The list c_rows is updated in place for convenience, but we're
                    # really creating a new SSA value on each iteration
                    c_rows[m] = vector.FMAOp(a_row_vectors[k], b_rows[k], c_rows[m]).res

            for m in range(M):
                vector.StoreOp(c_rows[m], c, (constants[m], c0))

        rewriter.erase_op(op)
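
The pattern fully unrolls the matmul at rewrite time: each row of C and of B lives in a single vector<N> value, every element of A is broadcast, and each partial product is folded in with one vector.fma. A minimal NumPy sketch of the same row-wise strategy (illustrative only, not part of xDSL; the function name and array arguments are assumptions) reads:

import numpy as np


def matmul_rowwise_fma(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> None:
    # Reference for the lowering above: computes C += A @ B row by row,
    # mirroring vector.load / vector.broadcast / vector.fma / vector.store.
    m_dim, k_dim = a.shape
    k_dim_b, n_dim = b.shape
    assert k_dim == k_dim_b and c.shape == (m_dim, n_dim)

    for m in range(m_dim):
        c_row = c[m, :].copy()              # vector.load of row m of C
        for k in range(k_dim):
            a_mk = a[m, k]                  # memref.load of A[m, k]
            c_row = a_mk * b[k, :] + c_row  # broadcast + vector.fma
        c[m, :] = c_row                     # vector.store of row m of C

Unlike this sketch, the generated IR hoists the vector loads of B's rows above the per-row updates, so each row of B is loaded exactly once.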

__init__() -> None

match_and_rewrite(op: linalg.MatmulOp, rewriter: PatternRewriter)


TestVectorizeMatmulPass dataclass

Bases: ModulePass

A test pass vectorizing linalg.matmul with a specific vectorization strategy.
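
The strategy in question is the one implemented by VectorizeMatmulOp above: the matmul is fully unrolled, and for each row m of C the update

    C[m, :] += sum over k in [0, K) of A[m, k] * B[k, :]

is emitted as K vector.fma operations on vector<N> values, so no loops remain in the rewritten IR.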

Source code in xdsl/transforms/test_vectorize_matmul.py
@dataclass(frozen=True)
class TestVectorizeMatmulPass(ModulePass):
    """
    A test pass vectorizing linalg.matmul with a specific vectorization strategy.
    """

    name = "test-vectorize-matmul"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        PatternRewriteWalker(
            VectorizeMatmulOp(), apply_recursively=False
        ).rewrite_module(op)

name = 'test-vectorize-matmul' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None
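
A minimal sketch of invoking the pass from Python, assuming a recent xDSL where Context is importable from xdsl.context, and assuming a Context and a builtin.ModuleOp containing a linalg.matmul on fully static memrefs have already been set up (their construction is not shown); the helper name vectorize_matmuls is hypothetical:

from xdsl.context import Context
from xdsl.dialects import builtin
from xdsl.transforms.test_vectorize_matmul import TestVectorizeMatmulPass


def vectorize_matmuls(ctx: Context, module: builtin.ModuleOp) -> None:
    # Rewrites every linalg.matmul on fully static memrefs in `module` into
    # vector.load / vector.broadcast / vector.fma / vector.store operations.
    TestVectorizeMatmulPass().apply(ctx, module)

The walker is constructed with apply_recursively=False, so it makes a single pass over the module instead of re-applying the pattern until a fixed point is reached.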
