about summary refs log tree commit diff
path: root/src/math
diff options
context:
space:
mode:
Diffstat (limited to 'src/math')
-rw-r--r--src/math/fma.c150
1 files changed, 143 insertions, 7 deletions
diff --git a/src/math/fma.c b/src/math/fma.c
index f44ecda7..a9de2cb3 100644
--- a/src/math/fma.c
+++ b/src/math/fma.c
@@ -1,3 +1,141 @@
+#include <fenv.h>
+#include "libm.h"
+
+#if LDBL_MANT_DIG==64 && LDBL_MAX_EXP==16384
+union ld80 {
+	long double x;
+	struct {
+		uint64_t m;
+		uint16_t e : 15;
+		uint16_t s : 1;
+		uint16_t pad;
+	} bits;
+};
+
+/* exact add, assumes exponent_x >= exponent_y */
+static void add(long double *hi, long double *lo, long double x, long double y)
+{
+	long double r;
+
+	r = x + y;
+	*hi = r;
+	r -= x;
+	*lo = y - r;
+}
+
+/*
+TODO(nsz): probably simpler mul is enough if we assume x and y are doubles
+so last 11bits are all zeros, no subnormals etc
+*/
+/* exact mul, assumes no over/underflow */
+static void mul(long double *hi, long double *lo, long double x, long double y)
+{
+	static const long double c = 1.0 + 0x1p32L;
+	long double cx, xh, xl, cy, yh, yl;
+
+	cx = c*x;
+	xh = (x - cx) + cx;
+	xl = x - xh;
+	cy = c*y;
+	yh = (y - cy) + cy;
+	yl = y - yh;
+	*hi = x*y;
+	*lo = (xh*yh - *hi) + xh*yl + xl*yh + xl*yl;
+}
+
+/*
+assume (long double)(hi+lo) == hi
+return an adjusted hi so that rounding it to double is correct
+*/
+static long double adjust(long double hi, long double lo)
+{
+	union ld80 uhi, ulo;
+
+	if (lo == 0)
+		return hi;
+	uhi.x = hi;
+	if (uhi.bits.m & 0x3ff)
+		return hi;
+	ulo.x = lo;
+	if (uhi.bits.s == ulo.bits.s)
+		uhi.bits.m++;
+	else
+		uhi.bits.m--;
+	return uhi.x;
+}
+
+static long double dadd(long double x, long double y)
+{
+	add(&x, &y, x, y);
+	return adjust(x, y);
+}
+
+static long double dmul(long double x, long double y)
+{
+	mul(&x, &y, x, y);
+	return adjust(x, y);
+}
+
+static int getexp(long double x)
+{
+	union ld80 u;
+	u.x = x;
+	return u.bits.e;
+}
+
+double fma(double x, double y, double z)
+{
+	long double hi, lo1, lo2, xy;
+	int round, ez, exy;
+
+	/* handle +-inf,nan */
+	if (!isfinite(x) || !isfinite(y))
+		return x*y + z;
+	if (!isfinite(z))
+		return z;
+	/* handle +-0 */
+	if (x == 0.0 || y == 0.0)
+		return x*y + z;
+	round = fegetround();
+	if (z == 0.0) {
+		if (round == FE_TONEAREST)
+			return dmul(x, y);
+		return x*y;
+	}
+
+	/* exact mul and add require nearest rounding */
+	/* spurious inexact exceptions may be raised */
+	fesetround(FE_TONEAREST);
+	mul(&xy, &lo1, x, y);
+	exy = getexp(xy);
+	ez = getexp(z);
+	if (ez > exy) {
+		add(&hi, &lo2, z, xy);
+	} else if (ez > exy - 12) {
+		add(&hi, &lo2, xy, z);
+		if (hi == 0) {
+			fesetround(round);
+			/* TODO: verify that the sign of 0 is always correct */
+			return (xy + z) + lo1;
+		}
+	} else {
+		/*
+		ez <= exy - 12
+		the 12 extra bits (1guard, 11round+sticky) are needed so with
+			lo = dadd(lo1, lo2)
+		elo <= ehi - 11, and we use the last 10 bits in adjust so
+			dadd(hi, lo)
+		gives correct result when rounded to double
+		*/
+		hi = xy;
+		lo2 = z;
+	}
+	fesetround(round);
+	if (round == FE_TONEAREST)
+		return dadd(hi, dadd(lo1, lo2));
+	return hi + (lo1 + lo2);
+}
+#else
 /* origin: FreeBSD /usr/src/lib/msun/src/s_fma.c */
 /*-
  * Copyright (c) 2005-2011 David Schultz <das@FreeBSD.ORG>
@@ -25,9 +163,6 @@
  * SUCH DAMAGE.
  */
 
-#include <fenv.h>
-#include "libm.h"
-
 /*
  * A struct dd represents a floating-point number with twice the precision
  * of a double.  We maintain the invariant that "hi" stores the 53 high-order
@@ -178,14 +313,14 @@ double fma(double x, double y, double z)
 	 * return values here are crucial in handling special cases involving
 	 * infinities, NaNs, overflows, and signed zeroes correctly.
 	 */
-	if (x == 0.0 || y == 0.0)
-		return (x * y + z);
-	if (z == 0.0)
-		return (x * y);
 	if (!isfinite(x) || !isfinite(y))
 		return (x * y + z);
 	if (!isfinite(z))
 		return (z);
+	if (x == 0.0 || y == 0.0)
+		return (x * y + z);
+	if (z == 0.0)
+		return (x * y);
 
 	xs = frexp(x, &ex);
 	ys = frexp(y, &ey);
@@ -278,3 +413,4 @@ double fma(double x, double y, double z)
 	else
 		return (add_and_denormalize(r.hi, adj, spread));
 }
+#endif