Skip to content

Dmp

dmp

DmpSwapShapeInference

Bases: RewritePattern

Infer the shape of the dmp.swap operation.

Source code in xdsl/transforms/shape_inference_patterns/dmp.py
12
13
14
15
16
17
18
19
20
21
22
23
24
class DmpSwapShapeInference(RewritePattern):
    """
    Infer the shape of the `dmp.swap` operation.

    If the op has a `swapped_values` result and its type differs from the
    type of `input_stencil`, the input value is given the result's type.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: dmp.SwapOp, rewrite: PatternRewriter):
        swapped = op.swapped_values
        # Without a result there is no type to propagate.
        if not swapped:
            return

        target_type = swapped.type
        if op.input_stencil.type != target_type:
            # Retype the input to match the result, then report the change
            # on `op` to the rewriter.
            rewrite.replace_value_with_new_type(op.input_stencil, target_type)
            rewrite.handle_operation_modification(op)

match_and_rewrite(op: dmp.SwapOp, rewrite: PatternRewriter)

Source code in xdsl/transforms/shape_inference_patterns/dmp.py
17
18
19
20
21
22
23
24
@op_type_rewrite_pattern
def match_and_rewrite(self, op: dmp.SwapOp, rewrite: PatternRewriter):
    """
    Give `op.input_stencil` the type of `op.swapped_values` when they differ.

    Does nothing if the op has no `swapped_values` result.
    """
    swapped = op.swapped_values
    # Without a result there is no type to propagate.
    if not swapped:
        return

    target_type = swapped.type
    if op.input_stencil.type != target_type:
        # Retype the input to match the result, then report the change
        # on `op` to the rewriter.
        rewrite.replace_value_with_new_type(op.input_stencil, target_type)
        rewrite.handle_operation_modification(op)

DmpSwapSwapsInference

Bases: RewritePattern

Infer the exact exchanges this dmp.swap needs to perform.

Source code in xdsl/transforms/shape_inference_patterns/dmp.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
class DmpSwapSwapsInference(RewritePattern):
    """
    Infer the exact exchanges this `dmp.swap` needs to perform.

    The core bounds are read from the first `stencil.apply` that uses the
    swap's result, the buffer bounds from the type of the swap's input. Both
    are handed to the op's strategy, and the resulting non-empty exchanges are
    stored in `op.swaps`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: dmp.SwapOp, rewrite: PatternRewriter):
        core_lb: stencil.IndexAttr | None = None
        core_ub: stencil.IndexAttr | None = None

        if not op.swapped_values:
            return

        # Take the core bounds from the first stencil.apply using the result.
        for use in op.swapped_values.uses:
            if not isinstance(use.operation, stencil.ApplyOp):
                continue
            assert use.operation.res
            bounds = use.operation.get_bounds()
            if isinstance(bounds, builtin.IntAttr):
                # No lb/ub pair available on this kind of bounds; give up.
                return
            core_lb = bounds.lb
            core_ub = bounds.ub
            break

        # this shouldn't have changed since the op was created!
        temp = op.input_stencil.type
        assert isa(temp, stencil.TempType[Attribute])
        if not isinstance(temp.bounds, stencil.StencilBoundsAttr):
            return
        buff_lb = temp.bounds.lb
        buff_ub = temp.bounds.ub

        # No stencil.apply user was found, so there are no core bounds.
        if core_lb is None or core_ub is None:
            return

        # if buff_* has fewer dimensions than core_*, we need to find out which dimensions of core_*
        # those in buff_* map to. This information only exists in the offset_mapping of a stencil.access
        if len(buff_lb) < len(core_lb):
            remapped = False
            for use in op.swapped_values.uses:
                if not isinstance(apply_op := use.operation, stencil.ApplyOp):
                    continue
                arg_idx = apply_op.operands.index(op.swapped_values)
                arg = apply_op.region.block.args[arg_idx]
                for arg_use in arg.uses:
                    if (
                        not isinstance(access_op := arg_use.operation, stencil.AccessOp)
                        or not access_op.offset_mapping
                    ):
                        continue
                    accessed_dims = tuple(access_op.offset_mapping)

                    # reconstruct what dims buff_* maps to
                    new_lb = [0] * len(core_lb)
                    new_ub = [0] * len(core_ub)
                    for idx, dim in enumerate(accessed_dims):
                        new_lb[dim] = buff_lb.array.data[idx].data
                        new_ub[dim] = buff_ub.array.data[idx].data
                    buff_lb = stencil.IndexAttr.get(*new_lb)
                    buff_ub = stencil.IndexAttr.get(*new_ub)

                    # todo breaking here means we do not verify that all accesses map to the same dimension
                    remapped = True
                    break
                # buff_* now has the core's rank. Stop scanning the remaining
                # users: re-applying a mapping to the already expanded bounds
                # would index them with the low-rank positions and corrupt them.
                if remapped:
                    break

        swaps = builtin.ArrayAttr(
            exchange
            for exchange in op.strategy.halo_exchange_defs(
                dmp.ShapeAttr.from_index_attrs(
                    buff_lb=buff_lb,
                    core_lb=core_lb,
                    buff_ub=buff_ub,
                    core_ub=core_ub,
                )
            )
            # drop empty exchanges
            if exchange.elem_count > 0
        )
        # Only touch the op (and notify the rewriter) when something changed.
        if swaps != op.swaps:
            op.swaps = swaps
            rewrite.handle_operation_modification(op)

match_and_rewrite(op: dmp.SwapOp, rewrite: PatternRewriter)

Source code in xdsl/transforms/shape_inference_patterns/dmp.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
@op_type_rewrite_pattern
def match_and_rewrite(self, op: dmp.SwapOp, rewrite: PatternRewriter):
    """
    Compute the non-empty halo exchanges for `op` and store them in `op.swaps`.

    The core bounds are read from the first `stencil.apply` that uses the
    swap's result, the buffer bounds from the type of the swap's input.
    Returns without changes when either set of bounds is unavailable.
    """
    core_lb: stencil.IndexAttr | None = None
    core_ub: stencil.IndexAttr | None = None

    if not op.swapped_values:
        return

    # Take the core bounds from the first stencil.apply using the result.
    for use in op.swapped_values.uses:
        if not isinstance(use.operation, stencil.ApplyOp):
            continue
        assert use.operation.res
        bounds = use.operation.get_bounds()
        if isinstance(bounds, builtin.IntAttr):
            # No lb/ub pair available on this kind of bounds; give up.
            return
        core_lb = bounds.lb
        core_ub = bounds.ub
        break

    # this shouldn't have changed since the op was created!
    temp = op.input_stencil.type
    assert isa(temp, stencil.TempType[Attribute])
    if not isinstance(temp.bounds, stencil.StencilBoundsAttr):
        return
    buff_lb = temp.bounds.lb
    buff_ub = temp.bounds.ub

    # No stencil.apply user was found, so there are no core bounds.
    if core_lb is None or core_ub is None:
        return

    # if buff_* has fewer dimensions than core_*, we need to find out which dimensions of core_*
    # those in buff_* map to. This information only exists in the offset_mapping of a stencil.access
    if len(buff_lb) < len(core_lb):
        remapped = False
        for use in op.swapped_values.uses:
            if not isinstance(apply_op := use.operation, stencil.ApplyOp):
                continue
            arg_idx = apply_op.operands.index(op.swapped_values)
            arg = apply_op.region.block.args[arg_idx]
            for arg_use in arg.uses:
                if (
                    not isinstance(access_op := arg_use.operation, stencil.AccessOp)
                    or not access_op.offset_mapping
                ):
                    continue
                accessed_dims = tuple(access_op.offset_mapping)

                # reconstruct what dims buff_* maps to
                new_lb = [0] * len(core_lb)
                new_ub = [0] * len(core_ub)
                for idx, dim in enumerate(accessed_dims):
                    new_lb[dim] = buff_lb.array.data[idx].data
                    new_ub[dim] = buff_ub.array.data[idx].data
                buff_lb = stencil.IndexAttr.get(*new_lb)
                buff_ub = stencil.IndexAttr.get(*new_ub)

                # todo breaking here means we do not verify that all accesses map to the same dimension
                remapped = True
                break
            # buff_* now has the core's rank. Stop scanning the remaining
            # users: re-applying a mapping to the already expanded bounds
            # would index them with the low-rank positions and corrupt them.
            if remapped:
                break

    swaps = builtin.ArrayAttr(
        exchange
        for exchange in op.strategy.halo_exchange_defs(
            dmp.ShapeAttr.from_index_attrs(
                buff_lb=buff_lb,
                core_lb=core_lb,
                buff_ub=buff_ub,
                core_ub=core_ub,
            )
        )
        # drop empty exchanges
        if exchange.elem_count > 0
    )
    # Only touch the op (and notify the rewriter) when something changed.
    if swaps != op.swaps:
        op.swaps = swaps
        rewrite.handle_operation_modification(op)