Skip to content

Csl stencil set global coeffs

csl_stencil_set_global_coeffs

GenerateCoeffAPICalls dataclass

Bases: RewritePattern

Generates calls to the stencil_comms API to set coefficients.

The API currently supports only f32 coeffs.

Source code in xdsl/transforms/csl_stencil_set_global_coeffs.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
@dataclass(frozen=True)
class GenerateCoeffAPICalls(RewritePattern):
    """
    Generates calls to the stencil_comms API to set coefficients.

    Every `csl_stencil.apply` op inside the matched `csl_wrapper.ModuleOp` must
    carry the same coefficients (compared after sorting by offset); otherwise
    the pattern leaves the IR untouched. When they agree, the ops built by
    `get_coeff_api_ops` are inserted once into the main function and the
    per-apply coefficients are cleared.

    The API currently supports only f32 coeffs.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: csl_wrapper.ModuleOp, rewriter: PatternRewriter, /):
        """
        Hoist the coefficients shared by all apply ops into one API call.

        Returns without rewriting when the module contains no apply ops, when
        the apply ops disagree on their coefficients, or when the shared
        coefficient list is empty.
        """

        def offset_sorted_coeffs(apply_op: csl_stencil.ApplyOp):
            # Canonical offset-ordered view, so apply ops that list the same
            # coefficients in a different order still compare equal.
            if not apply_op.coeffs:
                return []
            return sorted(apply_op.coeffs.data, key=lambda coeff: coeff.offset)

        applies: list[csl_stencil.ApplyOp] = [
            candidate
            for candidate in op.walk()
            if isinstance(candidate, csl_stencil.ApplyOp)
        ]
        if not applies:
            return

        # The first apply op defines the reference; all others must match it.
        global_coeffs = offset_sorted_coeffs(applies[0])
        if any(global_coeffs != offset_sorted_coeffs(other) for other in applies[1:]):
            return
        if not global_coeffs:
            return

        # Climb from the first apply op until the parent is a csl.func, or until
        # the parent's own parent is the wrapper module (the latter case then
        # trips the FuncOp assertion below). `anchor` ends up as the op that
        # sits directly inside that parent.
        anchor = applies[0]
        enclosing = anchor.parent_op()
        while (
            enclosing
            and not isinstance(enclosing, csl.FuncOp)
            and not isinstance(enclosing.parent_op(), csl_wrapper.ModuleOp)
        ):
            anchor = enclosing
            enclosing = anchor.parent_op()

        # Defensive guard: anchor starts as an apply op and is only ever
        # replaced by a truthy parent.
        if not anchor:
            return

        assert isinstance(enclosing, csl.FuncOp)
        assert enclosing.sym_name == op.program_name, "Apply must be in the main function"

        rewriter.insert_op(
            get_coeff_api_ops(applies[0], op), InsertPoint.before(anchor)
        )

        # The coefficients are now supplied by the inserted API call.
        for apply_op in applies:
            apply_op.coeffs = None

__init__() -> None

match_and_rewrite(op: csl_wrapper.ModuleOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/csl_stencil_set_global_coeffs.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl_wrapper.ModuleOp, rewriter: PatternRewriter, /):
    """
    Hoist the coefficients shared by all apply ops into one API call.

    Returns without rewriting when the module contains no apply ops, when
    the apply ops disagree on their coefficients, or when the shared
    coefficient list is empty.
    """

    def offset_sorted_coeffs(apply_op: csl_stencil.ApplyOp):
        # Canonical offset-ordered view, so apply ops that list the same
        # coefficients in a different order still compare equal.
        if not apply_op.coeffs:
            return []
        return sorted(apply_op.coeffs.data, key=lambda coeff: coeff.offset)

    applies: list[csl_stencil.ApplyOp] = [
        candidate
        for candidate in op.walk()
        if isinstance(candidate, csl_stencil.ApplyOp)
    ]
    if not applies:
        return

    # The first apply op defines the reference; all others must match it.
    global_coeffs = offset_sorted_coeffs(applies[0])
    if any(global_coeffs != offset_sorted_coeffs(other) for other in applies[1:]):
        return
    if not global_coeffs:
        return

    # Climb from the first apply op until the parent is a csl.func, or until
    # the parent's own parent is the wrapper module (the latter case then
    # trips the FuncOp assertion below). `anchor` ends up as the op that
    # sits directly inside that parent.
    anchor = applies[0]
    enclosing = anchor.parent_op()
    while (
        enclosing
        and not isinstance(enclosing, csl.FuncOp)
        and not isinstance(enclosing.parent_op(), csl_wrapper.ModuleOp)
    ):
        anchor = enclosing
        enclosing = anchor.parent_op()

    # Defensive guard: anchor starts as an apply op and is only ever
    # replaced by a truthy parent.
    if not anchor:
        return

    assert isinstance(enclosing, csl.FuncOp)
    assert enclosing.sym_name == op.program_name, "Apply must be in the main function"

    rewriter.insert_op(
        get_coeff_api_ops(applies[0], op), InsertPoint.before(anchor)
    )

    # The coefficients are now supplied by the inserted API call.
    for apply_op in applies:
        apply_op.coeffs = None

CslStencilSetGlobalCoeffs dataclass

Bases: ModulePass

Generates a single coefficient API call. This only works if all csl_stencil.apply ops use the same coefficients, and the csl_stencil.apply ops must be inside the main csl.func of a module wrapper.

Source code in xdsl/transforms/csl_stencil_set_global_coeffs.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
@dataclass(frozen=True)
class CslStencilSetGlobalCoeffs(ModulePass):
    """
    Replace per-apply stencil coefficients with a single coefficient API call.

    This only works if every `csl_stencil.apply` op uses the same coefficients.
    The `csl_stencil.apply` ops must be inside the main `csl.func` of a
    module wrapper.
    """

    name = "csl-stencil-set-global-coeffs"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        # Run GenerateCoeffAPICalls over the module with apply_recursively=False.
        walker = PatternRewriteWalker(
            GenerateCoeffAPICalls(), apply_recursively=False
        )
        walker.rewrite_module(op)

name = 'csl-stencil-set-global-coeffs' class-attribute instance-attribute

__init__() -> None

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/csl_stencil_set_global_coeffs.py
183
184
185
186
187
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Run `GenerateCoeffAPICalls` over `op` with `apply_recursively=False`."""
    walker = PatternRewriteWalker(
        GenerateCoeffAPICalls(), apply_recursively=False
    )
    walker.rewrite_module(op)

get_dir_and_distance(offset: stencil.IndexAttr | tuple[int, ...]) -> tuple[csl.Direction, int]

Given an access offset, return the direction and distance, assuming an access to a neighbour (not self) in a star-shaped pattern.

Source code in xdsl/transforms/csl_stencil_set_global_coeffs.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def get_dir_and_distance(
    offset: stencil.IndexAttr | tuple[int, ...],
) -> tuple[csl.Direction, int]:
    """
    Map a 2-D star-shaped neighbour offset to a `(direction, distance)` pair.

    Exactly one of the two components is expected to be non-zero, i.e. the
    offset addresses a neighbour rather than the point itself. A negative or
    positive first component maps to EAST or WEST. A negative or positive
    second component maps to NORTH or SOUTH. The distance is the largest
    absolute component.

    Raises:
        AssertionError: if the offset is not 2-D or not a star-shaped
            neighbour access.
        ValueError: if no component is non-zero (only reachable when
            assertions are disabled).
    """
    coords = tuple(offset) if isinstance(offset, stencil.IndexAttr) else offset
    assert len(coords) == 2, "Expecting 2-dimensional access"
    assert (coords[0] == 0) != (coords[1] == 0), (
        "Expecting neighbour access in a star-shape pattern"
    )

    first, second = coords[0], coords[1]
    # Checked in priority order; the first condition that holds wins.
    candidates = (
        (first < 0, csl.Direction.EAST),
        (first > 0, csl.Direction.WEST),
        (second < 0, csl.Direction.NORTH),
        (second > 0, csl.Direction.SOUTH),
    )
    direction = next((dir_ for hit, dir_ in candidates if hit), None)
    if direction is None:
        raise ValueError(
            "Invalid offset, expecting 2-dimensional star-shape neighbor access"
        )
    return direction, abs(max(coords, key=abs))

get_dir_and_distance_ops(op: csl_stencil.AccessOp) -> tuple[csl.DirectionOp, arith.ConstantOp]

Given an access op, return the direction and distance ops, assuming an access to a neighbour (not self) in a star-shaped pattern.

Source code in xdsl/transforms/csl_stencil_set_global_coeffs.py
48
49
50
51
52
53
54
55
56
def get_dir_and_distance_ops(
    op: csl_stencil.AccessOp,
) -> tuple[csl.DirectionOp, arith.ConstantOp]:
    """
    Materialise the direction and distance of a neighbour access as IR ops.

    Uses `get_dir_and_distance` on the access op's offset, so the same
    star-shaped neighbour (not self) assumption applies. Returns a
    `csl.DirectionOp` for the direction, and an `arith.ConstantOp` holding the
    distance as an `IntegerAttr` of width 16.
    """
    direction, distance = get_dir_and_distance(op.offset)
    direction_op = csl.DirectionOp(direction)
    distance_op = arith.ConstantOp(IntegerAttr(distance, 16))
    return direction_op, distance_op

get_coeff_api_ops(op: csl_stencil.ApplyOp, wrapper: csl_wrapper.ModuleOp)

Source code in xdsl/transforms/csl_stencil_set_global_coeffs.py
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def get_coeff_api_ops(op: csl_stencil.ApplyOp, wrapper: csl_wrapper.ModuleOp):
    """
    Build the ops that hand `op`'s coefficients to the `stencil_comms.csl` API.

    One dense f32 constant is created per direction (NORTH, SOUTH, EAST,
    WEST). Every slot starts at 1.0 and is overwritten by the coefficient
    whose offset maps to that direction and distance.

    - On the "wse2" target, the buffer has `pattern` slots. Slot 0 holds 0 and
      distance `k` goes into slot `k`.
    - On any other target, the buffer has `pattern - 1` slots and distance `k`
      goes into slot `k - 1`.

    Returns a list containing the four constants, then their `csl.AddressOfOp`s
    in EAST, WEST, SOUTH, NORTH order. The final element is a `setCoeffs`
    member call on the `stencil_comms.csl` program import, which takes those
    addresses as arguments.
    """
    coeffs = list(op.coeffs or [])
    pattern = wrapper.get_param_value("pattern").value.data
    is_wse2 = wrapper.target.data == "wse2"

    if is_wse2:
        default_row = [0] + (pattern - 1) * [1.0]
        shape = (pattern,)
        slot_offset = 0
    else:
        default_row = (pattern - 1) * [1.0]
        shape = (pattern - 1,)
        slot_offset = 1

    # Insertion order (N, S, E, W) fixes the order of the returned constants.
    cmap: dict[csl.Direction, list[float]] = {
        direction: list(default_row)
        for direction in (
            csl.Direction.NORTH,
            csl.Direction.SOUTH,
            csl.Direction.EAST,
            csl.Direction.WEST,
        )
    }

    for coeff in coeffs:
        direction, distance = get_dir_and_distance(coeff.offset)
        cmap[direction][distance - slot_offset] = coeff.coeff.value.data

    memref_t = memref.MemRefType(f32, shape)
    ptr_t = csl.PtrType.get(memref_t, is_single=True, is_const=True)

    constants = {
        direction: arith.ConstantOp(DenseIntOrFPElementsAttr.from_list(memref_t, row))
        for direction, row in cmap.items()
    }
    addresses = {
        direction: csl.AddressOfOp(const, ptr_t)
        for direction, const in constants.items()
    }

    # Name each constant after its direction so the printed IR is readable.
    for direction, const in constants.items():
        const.result.name_hint = str(direction)

    # setCoeffs takes its pointers in E, W, S, N order.
    args: list[Operation] = [
        addresses[direction]
        for direction in (
            csl.Direction.EAST,
            csl.Direction.WEST,
            csl.Direction.SOUTH,
            csl.Direction.NORTH,
        )
    ]

    set_coeffs_call = csl.MemberCallOp(
        "setCoeffs",
        None,
        wrapper.get_program_import("stencil_comms.csl"),
        args,
    )
    return [*constants.values(), *args, set_coeffs_call]