Is there still a fast invsqrt magic number for float128?

140 views
#include <cstring> // std::memcpy for well-defined type punning

// Bit-cast `data`'s object representation into a same-sized T.
// The original union trick (writing one member, reading another) is
// undefined behavior in C++; std::memcpy is the well-defined
// equivalent and optimizes to the same register move.
template<class T,class U>
T union_cast(U data){
    static_assert(sizeof(T) == sizeof(U),
                  "union_cast requires same-sized source and target types");
    T out;
    std::memcpy(&out, &data, sizeof out);
    return out;
}
// One-step "magic number" inverse-square-root seed: reinterpret X as a
// 128-bit integer, halve the bit pattern, and subtract it from the
// candidate magic constant mnum. No Newton refinement is applied.
float128_t quick_invsqrt_with_magic_num(int128_t mnum,float128_t X){
    return union_cast<float128_t>(mnum - (union_cast<int128_t>(X) >> 1));
}

I'm trying to figure out a fast invsqrt hack for float128 to complete my library, and based on the R0 value given in Chris Lomont's paper, I get that the magic number should probably be 0x2FFF6EC85E7DE30DAABC6027118577EF.

// Prints candidate magic-number components: the binary32 case
// (190 >> 1 and r0 scaled by 2^23) next to the binary128 case
// (24574 >> 1 and r0 scaled by 2^112).
// NOTE(review): `r0` and the `_128` / `_u128` literal operators are
// user-defined elsewhere; this fragment does not compile on its own,
// and the prints run together with no separators.
int main() {
    cout << (190_128 >> 1);
    cout << int128_t(r0*(1_u128 << 23)+0.5);
    cout << endl;
    cout << (24574_128 >> 1);
    cout << int128_t(r0*(1_u128 << 112)+0.5);
    cout << endl;//Theoretical optimum? (from FAST INVERSE SQUARE ROOT by CHRIS LOMONT)
}

But in my test programme this theoretical magic number generates NaN 14 times in 36 test data, which is never seen in invsqrt for float64. I tried to generate the magic number with random numbers and checked its quality repeatedly, the programme worked for a whole day on my PC and the best result I got was 13 NaNs out of 36 tests.

2FFF6EC85E7DE30DAABC6027118577EF : 2.53097e+237
not good enough on 2 13 7.2e+73 1.14514e+196 980 0.142857 0.0588235 0.0138889 0.0136986 0.000499251 0.000497265 0.000495786 0.00049334 0.000502765
bad rate 14/36 (38.8889%)

Is it theoretically impossible to derive fast invsqrt magic numbers for float128 and larger IEEE floating point numbers? If so, why? If not, is there any more accurate way to get fast invsqrt magic numbers for IEEE floats with arbitrary parameters?

A link to an online run of the code I used for testing is here.

1

There is 1 best solution below.

1
On BEST ANSWER

The paper's method can be extrapolated to give a way to calculate the magic numbers for any floating point type, but finding $r_0$ (which doesn't depend on floating point type) is complicated, so I just used the given constant.

This snippet of Haskell code calculates the magic number for a particular floating point type given the value of $r_0$ from the paper:

import Data.Ratio ((%))
import Numeric (showHex) -- used later

-- r0: the initial-approximation constant from Lomont's "Fast Inverse
-- Square Root" paper; it does not depend on the floating-point format.
r0 :: Rational
r0 = 0.432744889959443195468521587014

-- The magic constant for a binary float format with the given exponent
-- and explicit-mantissa widths, as the Rational whose encoding is the
-- magic integer.  The exponent component is 3*2^(eb-2) - 2 minus the
-- bias (2^(eb-1) - 1); the mantissa is 1 + r0 scaled and rounded to
-- the format's precision.
magic :: Integer -> Integer -> Rational
magic exponentBits explicitMantissaBits = decode 2 explicitMantissaBits (0, 3 * 2 ^ (exponentBits - 2) - 2 - (2 ^ (exponentBits - 1) - 1), floor (2 ^ explicitMantissaBits * (1 + r0) + 0.5))

-- Reassemble a (sign, exponent, mantissa) triple into a Rational.
-- The mantissa is read as a fixed-point value with mantissaDigits
-- digits after the radix point, scaled by base^^e; (num % den) builds
-- a Rational.  Fix: the original ignored `base` and hard-coded 2;
-- all existing call sites pass base = 2, so behavior is unchanged.
decode :: Integer -> Integer -> (Integer, Integer, Integer) -> Rational
decode base mantissaDigits (s, e, m) =
  (if s == 0 then id else negate) (m % base ^ mantissaDigits) * fromInteger base ^^ e

Then you need to encode it into a binary floating point format; see below for inefficient (but hopefully correct) code. Printing the resulting integer in hex gives these magic numbers for half/single/double/quadruple — the last of which is not the same as yours (how did you calculate it?):

59bb
5f37642f
5fe6ec85e7de30db
5ffe6ec85e7de30dab0ff77452ab6769

Appendix 1 (binary format conversion):

-- Normalize a Rational into a (sign, exponent, mantissa) triple by
-- recursively scaling the value into [1, base); simple but O(|e|).
-- NOTE(review): produces no subnormals, and rounding happens only in
-- the final `round` step.
encode :: Integer -> Integer -> Rational -> (Integer, Integer, Integer)
encode _ _ 0 = (0, 0, 0)
encode base mantissaDigits n
  | n < 0 = case encode base mantissaDigits (negate n) of (s, e, m) -> (1 - s, e, m)
  | n < 1 = case encode base mantissaDigits (n * fromInteger base) of (s, e, m) -> (s, e - 1, m)
  | n >= fromInteger base = case encode base mantissaDigits (n / fromInteger base) of (s, e, m) -> (s, e + 1, m)
  | otherwise = (0, 0, round (n * fromInteger (base ^ (mantissaDigits - 1))))

-- Build an (encoder, decoder, magic-integer) triple for a binary
-- floating-point format with the given sign / exponent /
-- explicit-mantissa bit widths.  The magic integer is the encoding of
-- `magic` for this format.
binary :: Integer -> Integer -> Integer -> (Rational -> Integer, Integer -> Maybe Rational, Integer)
binary signBits exponentBits explicitMantissaBits = (enc, dec, enc (magic exponentBits explicitMantissaBits))
 where
  -- Decode bit pattern n: Nothing for the all-ones exponent (Inf/NaN),
  -- 0 for the subnormal range (subnormals not implemented).
  dec :: Integer -> Maybe Rational
  dec n
    | biasedExponent == 2 ^ exponentBits - 1 = Nothing
    | biasedExponent == 0 = Just 0 -- todo subnormal
    | otherwise = Just (decode 2 (explicitMantissaBits) (sign, unbiasedExponent, mantissa))
    where
     sign
       | n >= 2 ^ (exponentBits + explicitMantissaBits) = 1
       | otherwise = 0
     unbiasedExponent = biasedExponent - (2 ^ (exponentBits - 1) - 1)
     -- restore the hidden leading 1 bit
     mantissa = implicitMantissa + 2 ^ explicitMantissaBits
     implicitMantissa = n `mod` 2 ^ explicitMantissaBits
     biasedExponent = (n `div` 2 ^ explicitMantissaBits) `mod` 2 ^ exponentBits
  -- Encode a Rational as an integer bit pattern; overflow clamps to
  -- infinity, underflow flushes to zero (todo subnormal).
  enc :: Rational -> Integer
  enc n = case encode 2 (explicitMantissaBits + 1) n of
   (s, unbiasedExponent, explicitMantissa)
    | n == 0 -> 0
    | biasedExponent <= 0 -> 0 -- underflow, todo subnormal
    | biasedExponent >= 2 ^ exponentBits - 1 -> (sign * 2 ^ exponentBits + 2 ^ exponentBits - 1) * 2 ^ explicitMantissaBits -- overflow, infinity
    | otherwise -> (sign * 2 ^ exponentBits + biasedExponent) * 2 ^ explicitMantissaBits + implicitMantissa
    where
      sign
        | signBits > 0 = s
        | otherwise = 0
      mantissaBits = explicitMantissaBits + 1
      biasedExponent = unbiasedExponent + (2 ^ (exponentBits - 1) - 1)
      -- drop the hidden leading bit before packing
      implicitMantissa = explicitMantissa - (2 ^ (mantissaBits - 1))

-- IEEE 754 interchange formats: binary(sign, exponent, explicit mantissa) bits.
half = binary 1 5 10
single = binary 1 8 23
double = binary 1 11 52
quadruple = binary 1 15 112

-- Round-trip n through the format, print the magic number, the
-- round-trip error, and x * rsqrt(x)^2 (ideally 1).
-- Fix: the original match was non-exhaustive — a Nothing from the
-- decoder (all-ones exponent) crashed at runtime; report it instead.
test float@(enc, dec, mgc) n = case dec (enc (toRational n)) of
  Just m -> do
    putStrLn (showHex mgc "") -- magic number
    print (n - fromRational m) -- should be small
    let r = rSqrtLinear float (toRational n)
        l = toRational n * r * r
    print (fromRational l) -- should be 1
  Nothing -> putStrLn (showHex mgc ": input did not round-trip (Inf/NaN encoding)")

-- Run the round-trip/rsqrt test on pi for each format, smallest first.
-- (pi defaults to Double precision here, as in the original.)
main :: IO ()
main = mapM_ (\fmt -> test fmt pi) [half, single, double, quadruple]

-- The shift-and-subtract rsqrt approximation done on the integer
-- encoding.  NOTE(review): partial match — errors if dec gives Nothing.
rSqrtLinear (enc, dec, mgc) x = case dec (mgc - (enc x `div` 2)) of Just y -> y

Appendix 2 (test code in C):

#define _GNU_SOURCE
#include <math.h>
#include <stdio.h>
#include <string.h>

typedef _Float128 float128_t; /* GCC binary128 type (needs sqrtf128/fabsf128 from _GNU_SOURCE) */
typedef __int128 int128_t;    /* GCC 128-bit integer extension */

/* Return the 128-bit integer bit pattern of f.
   Fix: *(int128_t*)&f violates strict aliasing (UB); memcpy is the
   well-defined way to copy the object representation and compiles to
   the same code. */
int128_t decode(float128_t f)
{
  int128_t i;
  memcpy(&i, &f, sizeof i);
  return i;
}

/* Reinterpret a 128-bit integer bit pattern as a float128.
   Fix: *(float128_t*)&i violates strict aliasing (UB); use memcpy. */
float128_t encode(int128_t i)
{
  float128_t f;
  memcpy(&f, &i, sizeof f);
  return f;
}

/* The magic constant written as the float128 whose bit pattern is the
   desired magic integer — presumably 0x5ffe6ec85e7de30dab0ff77452ab6769
   computed above (verify); decoded into magic_i128 in main(). */
float128_t magic_f128 = 7.813819087707539518680007964123963e2465f128;
int128_t magic_i128;

/* One-step magic-number rsqrt seed: halve the bit pattern and subtract
   it from the magic integer; no Newton refinement. */
float128_t rsqrt_approx(float128_t x)
{
  int128_t bits = decode(x);
  return encode(magic_i128 - (bits >> 1));
}

/* Relative error of the approximation against 1/sqrt(x). */
float128_t relative_error(float128_t x)
{
  float128_t exact = 1 / sqrtf128(x);
  float128_t approx = rsqrt_approx(x);
  return fabsf128(approx - exact) / exact;
}

/* Positive test inputs — small primes, a few large magnitudes — and
   their reciprocals, spanning a wide range of exponents. */
float128_t test_cases[] =
{
2,7,17,13,72,73,2003,2011,2017,2027,1989,
72893,72901,72907,72911,72923,72931,72937,72949,72953,
9.99e+37,72e72,17e17,114514e191,980,
1.0/3,1.0/7,1.0/17,1.0/13,1.0/72,1.0/73,1.0/2003,1.0/2011,1.0/2017,1.0/2027,1.0/1989,
1.0/72893,1.0/72901,1.0/72907,1.0/72911,1.0/72923,1.0/72931,1.0/72937,1.0/72949,1.0/72953,
1.0/9.99e+37,1.0/72e72,1.0/17e17,1.0/114514e191,1.0/980,
};

float128_t rms_error(void)
{
  int count = sizeof(test_cases) / sizeof(test_cases[0]);
  float128_t sum = 0;
  for (int i = 0; i < count; ++i)
  {
    float128_t err = relative_error(test_cases[i]);
    sum += err * err;
  }
  return sqrtf128(sum / count);
}

/* Derive the magic integer from the decimal float literal once, then
   report the RMS relative error over the whole suite. */
int main(int argc, char **argv)
{
  (void) argc; /* unused */
  (void) argv; /* unused */
  magic_i128 = decode(magic_f128);
  printf("%g\n", (double) rms_error());
  return 0;
}

This test code (compiled with gcc version 12.2.0 (Debian 12.2.0-14) x86_64) outputs 0.0198455, which shows the algorithm is working (relative error about 2%).