Skip to content

Approximate math with bitcast

approximate_math_with_bitcast

LBs = {builtin.f16: (2 ** 10, 14), builtin.f32: (2 ** 23, 127), builtin.f64: (2 ** 52, 1023)} module-attribute

MakeBase2 dataclass

Bases: RewritePattern

Source code in xdsl/transforms/approximate_math_with_bitcast.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
@dataclass
class MakeBase2(RewritePattern):
    """
    Rewrite natural-base logarithms and exponentials into base-2 form.

    The rewrites performed (each only when the corresponding flag is set) are:

    - ``math.log(x)``   -> ``math.log2(x) * ln(2)``        (``log``)
    - ``math.log1p(x)`` -> ``math.log(1 + x)``             (``log``)
    - ``math.exp(x)``   -> ``math.exp2(x * log2(e))``      (``exp``)

    Only single-result ops whose result type is a scalar float are considered.
    """

    log: bool  # enables the math.log / math.log1p rewrites
    exp: bool  # enables the math.exp rewrite

    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter):
        # Guard: exactly one result, and it must be a scalar float.
        if len(op.results) != 1:
            return
        result_type = op.results[0].type
        if not isa(result_type, builtin.AnyFloat):
            return

        replacement: list[Operation]
        if self.log and isinstance(op, math.LogOp):
            # ln(x) == log2(x) * ln(2)
            x, flags = op.operand, op.fastmath
            ln2_attr = builtin.FloatAttr(pmath.log(2), result_type)
            ln2 = arith.ConstantOp(ln2_attr)
            log2_x = math.Log2Op(x, flags)
            replacement = [ln2, log2_x, arith.MulfOp(ln2, log2_x, flags)]
        elif (
            self.log
            and isinstance(op, math.Log1pOp)
            and isa(op.operand.type, builtin.AnyFloat)
        ):
            # log1p(x) == ln(x + 1)
            x, flags = op.operand, op.fastmath
            one = arith.ConstantOp(builtin.FloatAttr(1.0, x.type))
            x_plus_one = arith.AddfOp(one, x, flags)
            replacement = [one, x_plus_one, math.LogOp(x_plus_one, flags)]
        elif self.exp and isinstance(op, math.ExpOp):
            # exp(x) == exp2(x * log2(e))
            x, flags = op.operand, op.fastmath
            log2e_attr = builtin.FloatAttr(pmath.log2(pmath.e), result_type)
            log2e = arith.ConstantOp(log2e_attr)
            scaled = arith.MulfOp(log2e, x, flags)
            replacement = [log2e, scaled, math.Exp2Op(scaled, flags)]
        else:
            # TODO: math.powf, math.log10, math.fpowi?
            return

        # The last op built above yields the value that replaces `op`.
        rewriter.replace_matched_op(replacement, replacement[-1].results)

log: bool instance-attribute

exp: bool instance-attribute

__init__(log: bool, exp: bool) -> None

match_and_rewrite(op: Operation, rewriter: PatternRewriter)

Source code in xdsl/transforms/approximate_math_with_bitcast.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter):
    """
    Rewrite ``math.log``/``math.log1p`` (when ``self.log``) and ``math.exp``
    (when ``self.exp``) into base-2 equivalents.

    Only single-result ops whose result type is a scalar float are considered.
    """
    # Guard: exactly one result, and it must be a scalar float.
    if len(op.results) != 1:
        return
    result_type = op.results[0].type
    if not isa(result_type, builtin.AnyFloat):
        return

    replacement: list[Operation]
    if self.log and isinstance(op, math.LogOp):
        # ln(x) == log2(x) * ln(2)
        x, flags = op.operand, op.fastmath
        ln2_attr = builtin.FloatAttr(pmath.log(2), result_type)
        ln2 = arith.ConstantOp(ln2_attr)
        log2_x = math.Log2Op(x, flags)
        replacement = [ln2, log2_x, arith.MulfOp(ln2, log2_x, flags)]
    elif (
        self.log
        and isinstance(op, math.Log1pOp)
        and isa(op.operand.type, builtin.AnyFloat)
    ):
        # log1p(x) == ln(x + 1)
        x, flags = op.operand, op.fastmath
        one = arith.ConstantOp(builtin.FloatAttr(1.0, x.type))
        x_plus_one = arith.AddfOp(one, x, flags)
        replacement = [one, x_plus_one, math.LogOp(x_plus_one, flags)]
    elif self.exp and isinstance(op, math.ExpOp):
        # exp(x) == exp2(x * log2(e))
        x, flags = op.operand, op.fastmath
        log2e_attr = builtin.FloatAttr(pmath.log2(pmath.e), result_type)
        log2e = arith.ConstantOp(log2e_attr)
        scaled = arith.MulfOp(log2e, x, flags)
        replacement = [log2e, scaled, math.Exp2Op(scaled, flags)]
    else:
        # TODO: math.powf, math.log10, math.fpowi?
        return

    # The last op built above yields the value that replaces `op`.
    rewriter.replace_matched_op(replacement, replacement[-1].results)

MakeApprox dataclass

Bases: RewritePattern

Source code in xdsl/transforms/approximate_math_with_bitcast.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
@dataclass
class MakeApprox(RewritePattern):
    log: bool
    exp: bool

    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
        if len(op.results) == 0:
            return

        t = op.results[0].type
        if not isa(t, builtin.AnyFloat):
            return

        L, B = LBs[t] if t in LBs else (1, 1)

        int_t = builtin.IntegerType(t.bitwidth)

        match op:
            # log2(%a) -> fp(L * (B - eps + A))
            case math.Log2Op(operand=x, fastmath=ff) if self.log:
                rewriter.replace_matched_op(
                    [
                        a := arith.ConstantOp(builtin.FloatAttr(L, t)),
                        b := arith.ConstantOp(builtin.FloatAttr(L * (B - 0.045), t)),
                        ax := arith.MulfOp(a, x, ff),
                        axpb := arith.AddfOp(b, ax, ff),
                        asint := arith.FPToSIOp(axpb, int_t),
                        res := arith.BitcastOp(asint, t),
                    ],
                    res.results,
                )
            case math.Exp2Op(operand=x, fastmath=ff) if self.exp:
                # 2^%x -> int(%x) * 1/L - B + eps
                rewriter.replace_matched_op(
                    [
                        a := arith.ConstantOp(builtin.FloatAttr(1 / L, t)),
                        b := arith.ConstantOp(builtin.FloatAttr(-B + 0.045, t)),
                        xi := arith.BitcastOp(x, int_t),
                        xif := arith.SIToFPOp(xi, t),
                        ax := arith.MulfOp(a, xif, ff),
                        axpb := arith.AddfOp(b, ax, ff),
                    ],
                    axpb.results,
                )
            case _:
                pass

log: bool instance-attribute

exp: bool instance-attribute

__init__(log: bool, exp: bool) -> None

match_and_rewrite(op: Operation, rewriter: PatternRewriter)

Source code in xdsl/transforms/approximate_math_with_bitcast.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
    if len(op.results) == 0:
        return

    t = op.results[0].type
    if not isa(t, builtin.AnyFloat):
        return

    L, B = LBs[t] if t in LBs else (1, 1)

    int_t = builtin.IntegerType(t.bitwidth)

    match op:
        # log2(%a) -> fp(L * (B - eps + A))
        case math.Log2Op(operand=x, fastmath=ff) if self.log:
            rewriter.replace_matched_op(
                [
                    a := arith.ConstantOp(builtin.FloatAttr(L, t)),
                    b := arith.ConstantOp(builtin.FloatAttr(L * (B - 0.045), t)),
                    ax := arith.MulfOp(a, x, ff),
                    axpb := arith.AddfOp(b, ax, ff),
                    asint := arith.FPToSIOp(axpb, int_t),
                    res := arith.BitcastOp(asint, t),
                ],
                res.results,
            )
        case math.Exp2Op(operand=x, fastmath=ff) if self.exp:
            # 2^%x -> int(%x) * 1/L - B + eps
            rewriter.replace_matched_op(
                [
                    a := arith.ConstantOp(builtin.FloatAttr(1 / L, t)),
                    b := arith.ConstantOp(builtin.FloatAttr(-B + 0.045, t)),
                    xi := arith.BitcastOp(x, int_t),
                    xif := arith.SIToFPOp(xi, t),
                    ax := arith.MulfOp(a, xif, ff),
                    axpb := arith.AddfOp(b, ax, ff),
                ],
                axpb.results,
            )
        case _:
            pass

ApproximateMathWithBitcastPass dataclass

Bases: ModulePass

This pass applies approximations for some math operations (currently log and exp) and converts them to bitcasting-based approximations.

These are intended for environments that don't need high accuracy and do not have specialized hardware support for expf and log.

It makes use of the fact that IEEE floating-point numbers are encoded as three base-2 numbers:

s eeeeeeee mmmmmmmmmmmmmmmmmmmmmmm

With the final floating-point number being $$x = (-1)^s * 2^{e-B} * (1 + m/L)$$. Hence, for normal numbers, $$e - B$$ equals $$\lfloor \log_2 |x| \rfloor$$, and $$e$$ can be extracted by bit-casting and right-shifting.

The following approximations are enabled through this:

1) $$\log_2(x) \approx bc-int(x)/L - B + \varepsilon$$
2) $$2^x \approx bc-float(L(B-\varepsilon + x))$$

With $$\varepsilon$$ being a tunable constant that we initialize to 0.045 for simplicity.

This pass first applies some rewrites to convert suitable arithmetic into base-2 format, and then applies the above approximations.

Individual rewrites can be controlled via pass arguments.

Source code in xdsl/transforms/approximate_math_with_bitcast.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
@dataclass(frozen=True)
class ApproximateMathWithBitcastPass(ModulePass):
    r"""
    This pass applies approximations for some math operations (currently log and exp)
    and converts them to bitcasting-based approximations.

    These are intended for environments that don't need high accuracy and do not have
    specialized hardware support for expf and log.

    It makes use of the fact that IEEE floating-point numbers are encoded as three base-2
    numbers:

        s eeeeeeee mmmmmmmmmmmmmmmmmmmmmmm

    With the final floating-point number being $$x = (-1)^s * 2^{e-B} * (1 + m/L)$$.
    Hence, for normal numbers, $$e - B$$ equals $$\lfloor \log_2 |x| \rfloor$$, and
    $$e$$ can be extracted by bit-casting and right-shifting.

    The following approximations are enabled through this:

    1) $$\log_2(x) \approx bc-int(x)/L - B + \varepsilon$$
    2) $$2^x       \approx bc-float(L(B-\varepsilon + x))$$

    With $$\varepsilon$$ being a tunable constant that we initialize to 0.045 for
    simplicity.

    This pass first applies some rewrites to convert suitable arithmetic into base-2
    format, and then applies the above approximations.

    Individual rewrites can be controlled via pass arguments.
    """

    name = "approximate-math-with-bitcast"

    log: bool = True  # enable the logarithm rewrites/approximations
    exp: bool = True  # enable the exponential rewrites/approximations

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """Rewrite every eligible math op inside ``op`` in place."""
        # MakeBase2 turns ln/log1p/exp into base-2 form; MakeApprox then
        # replaces the resulting log2/exp2 with bit-cast approximations.
        PatternRewriteWalker(
            GreedyRewritePatternApplier(
                [
                    MakeBase2(self.log, self.exp),
                    MakeApprox(self.log, self.exp),
                ]
            )
        ).rewrite_module(op)

name = 'approximate-math-with-bitcast' class-attribute instance-attribute

log: bool = True class-attribute instance-attribute

exp: bool = True class-attribute instance-attribute

__init__(log: bool = True, exp: bool = True) -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/approximate_math_with_bitcast.py
164
165
166
167
168
169
170
171
172
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Rewrite every eligible math op inside ``op`` in place."""
    # Base-2 normalisation runs first so the approximation pattern can then
    # pick up the log2/exp2 ops it produces.
    patterns = [
        MakeBase2(self.log, self.exp),
        MakeApprox(self.log, self.exp),
    ]
    walker = PatternRewriteWalker(GreedyRewritePatternApplier(patterns))
    walker.rewrite_module(op)