summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Hill <daniel@gluo.nz>2022-09-22 17:41:50 +1200
committerDaniel Hill <daniel@gluo.nz>2022-09-23 12:42:03 +1200
commitdf92f413c8bbff3cfe40bfdcccacf00ac9567c76 (patch)
tree0ca4541ff4457c3dc66be74c2f448e565f4c85f8
parent2604a8e4bbac50c014007bee9b1bf1a84ec1fa13 (diff)
128bit math
-rw-r--r--include/linux/mean_and_variance.h95
-rw-r--r--lib/Kconfig.debug7
-rw-r--r--lib/math/Kconfig9
-rw-r--r--lib/math/mean_and_variance.c23
-rw-r--r--lib/math/mean_and_variance_test.c53
5 files changed, 161 insertions, 26 deletions
diff --git a/include/linux/mean_and_variance.h b/include/linux/mean_and_variance.h
index 0a64c1be9d20..1ae948547d99 100644
--- a/include/linux/mean_and_variance.h
+++ b/include/linux/mean_and_variance.h
@@ -1,15 +1,100 @@
/* SPDX-License-Identifier: GPL-2.0 */
-#ifndef STATS_H_
-#define STATS_H_
+#ifndef MEAN_AND_VARIANCE_H_
+#define MEAN_AND_VAIRANCE_H_
#include <linux/types.h>
+#include <linux/limits.h>
+#include <linux/math64.h>
+#include <linux/printbuf.h>
#define SQRT_U64_MAX 4294967295ULL
+//#ifdef __SIZEOF_INT128__
+
+//typedef unsigned __int128 u128;
+
+//#else
+
+typedef struct {
+ u64 hi;
+ u64 lo;
+} u128;
+
+static inline u128 u128_init(u64 a, u64 b)
+{
+ return (u128){ .hi = a, .lo = b };
+}
+
+static inline u128 u128_add(u128 a, u128 b)
+{
+ u128 c;
+ c.lo = a.lo + b.lo;
+ c.hi = a.hi + b.hi + (c.lo < a.lo);
+ return c;
+}
+
+static inline u128 u128_sub(u128 a, u128 b)
+{
+ u128 c;
+ c.lo = a.lo - b.lo;
+ c.hi = a.hi - (b.hi + c.lo > a.lo);
+ return c;
+}
+static inline u128 u128_shl(u128 i, s8 s1) {
+ u128 r;
+ s8 s2 = 64 - s1;
+ r.lo = i.lo << s1;
+ r.hi = (i.hi << s1) + (i.lo >> s2);
+ return r;
+}
+
+static inline u128 u128_square(u64 i)
+{
+ u128 r;
+ u64 h = i >> 32, l = i & (u64)U32_MAX;
+ u64 x;
+ // overflows:
+ //
+ // ( (a*a) << 128) + (( a*b) << 97) + ((a*c) << 65) + ((b*b) << 65)
+ printk("square %llu", i);
+ printk("h = %llu, l = %llu\n", h, l);
+ r = u128_init(h*h, 0);
+ x = h*l;
+ printk("hi = %llu, x = %llu, x >> 31 = %llu\n", r.hi, x, x >> 31);
+ r = u128_add(r, u128_shl(u128_init(0,x), 33));
+ printk("hi = %llu, lo = %llu\n", r.hi, r.lo);
+ x = l*l;
+ printk("x = %llu\n", x);
+ r = u128_add(r, u128_init(0, x));
+ printk("hi = %llu, lo = %llu\n", r.hi, r.lo);
+ return r;
+}
+
+static inline u128 u128_div(u128 n, u64 d) {
+ u128 result;
+ u64 r;
+ u64 rem;
+ u64 hh = n.hi & ((u64)U32_MAX << 32), hl = (n.hi & (u64)U32_MAX),
+ lh = n.lo & ((u64)U32_MAX << 32), ll = (n.lo & (u64)U32_MAX);
+ printk("divide: %llu::%llu / %llu", n.hi, n.lo, d);
+ printk("hi = %llu, hh = %llu, hl = %llu, hh+hl = %llu\n", n.hi, hh, hl, hh+hl);
+ r = div64_u64_rem(hh, d, &rem);
+ result.hi = r;
+ printk("hi = %llu, r = %llu, rem = %llu \n", result.hi, r, rem);
+ r = div64_u64_rem(((hl + rem) << 32), d, &rem);
+ result.hi += r >> 32;
+ printk("hi = %llu, r = %llu, rem = %llu \n", result.hi, r, rem);
+ r = div64_u64_rem((n.lo + ((rem) << 32)), d, &rem) + (r << 32);
+ result.lo = r;
+ printk("lo = %llu, r = %llu, rem = %llu \n", result.hi, r, rem);
+ return result;
+}
+//#endif
+
struct mean_and_variance {
s64 n;
s64 sum;
- u64 sum_squares;
+ u128 sum_squares;
};
/* expontentially weighted variant */
@@ -34,4 +119,6 @@ struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_
u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
-#endif // STATS_H_
+
+
+#endif // MEAN_AND_VAIRANCE_H_
diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug
index 1d4ed12a5355..f39f09b835ce 100644
--- a/lib/Kconfig.debug
+++ b/lib/Kconfig.debug
@@ -2030,6 +2030,13 @@ config LKDTM
Documentation on how to use the module can be found in
Documentation/fault-injection/provoke-crashes.rst
+config MEAN_AND_VARIANCE_UNIT_TEST
+ tristate "mean_and_variance unit tests" if !KUNIT_ALL_TESTS
+ depends on MEAN_AND_VARIANCE && KUNIT
+ default KUNIT_ALL_TESTS
+ help
+ This option tests the mean_and_variance module
+
config TEST_LIST_SORT
tristate "Linked list sorting test" if !KUNIT_ALL_TESTS
depends on KUNIT
diff --git a/lib/math/Kconfig b/lib/math/Kconfig
index 12fe28622e82..d5688c673804 100644
--- a/lib/math/Kconfig
+++ b/lib/math/Kconfig
@@ -17,9 +17,6 @@ config RATIONAL
tristate
config MEAN_AND_VARIANCE
- tristate
-
-config MEAN_AND_VARIANCE_UNIT_TEST
- tristate "mean_and_variance unit tests" if !KUNIT_ALL_TESTS
- depends on MEAN_AND_VARIANCE && KUNIT
- default KUNIT_ALL_TESTS
+ tristate "blah"
+ help
+ "This is needed for kunit to work"
diff --git a/lib/math/mean_and_variance.c b/lib/math/mean_and_variance.c
index f15aefc0b4d0..c0552828f8c3 100644
--- a/lib/math/mean_and_variance.c
+++ b/lib/math/mean_and_variance.c
@@ -57,6 +57,16 @@ inline s64 fast_divpow2(s64 n, u8 d)
return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; // + (n < 0 ? 1 : 0);
}
+/*
+inline u128 u128_mul(u128 a, u128 b)
+{
+ u32 a1 = a.lo, a2 = a.lo >> 32, a3 = a.hi, a4 = a.hi >> 32;
+ u32 b1 = b.lo, b2 = b.lo >> 32, b3 = b.hi, b4 = b.hi >> 32;
+
+
+ a1 * b1 + (a2 * b1) << 32 + (a1 + b2) << 32 + a2 * b2)
+}
+*/
/**
* mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1
* and return it.
@@ -70,14 +80,15 @@ struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s
struct mean_and_variance s2;
u64 v2 = abs(v1);
+ /*
if (v2 > SQRT_U64_MAX) {
v2 = SQRT_U64_MAX;
WARN(true, "stats overflow! %lld^2 > U64_MAX", v1);
}
-
+ */
s2.n = s1.n + 1;
s2.sum = s1.sum + v1;
- s2.sum_squares = s1.sum_squares + v2*v2;
+ s2.sum_squares = u128_add(s1.sum_squares, u128_square(v2));
return s2;
}
EXPORT_SYMBOL_GPL(mean_and_variance_update);
@@ -98,11 +109,11 @@ EXPORT_SYMBOL_GPL(mean_and_variance_get_mean);
*/
u64 mean_and_variance_get_variance(struct mean_and_variance s1)
{
- u64 s2 = s1.sum_squares / s1.n;
- u64 s3 = abs(mean_and_variance_get_mean(s1));
+ u128 s2 = u128_div(s1.sum_squares,s1.n);
+ u64 s3 = abs(mean_and_variance_get_mean(s1));
- WARN(s3 > SQRT_U64_MAX, "stats overflow %lld ^2 > S64_MAX", s3);
- return s2 - s3*s3;
+ // WARN(s3 > SQRT_U64_MAX, "stats overflow %llu ^2 > S64_MAX", (u64)(s3 >> 64));
+ return u128_sub(s2, u128_square(s3)).lo;
}
EXPORT_SYMBOL_GPL(mean_and_variance_get_variance);
diff --git a/lib/math/mean_and_variance_test.c b/lib/math/mean_and_variance_test.c
index 215e128acd86..4384a075dd36 100644
--- a/lib/math/mean_and_variance_test.c
+++ b/lib/math/mean_and_variance_test.c
@@ -30,12 +30,12 @@ static void mean_and_variance_basic_test(struct kunit *test)
s = mean_and_variance_update(s, SQRT_U64_MAX);
KUNIT_EXPECT_EQ_MSG(test,
- s.sum_squares,
+ s.sum_squares.lo,
MAX_SQR,
"%llu == %llu, sqrt: %llu == %llu",
- s.sum_squares,
+ s.sum_squares.lo,
MAX_SQR,
- int_sqrt64(s.sum_squares),
+ int_sqrt64(s.sum_squares.lo),
SQRT_U64_MAX);
s = (struct mean_and_variance){};
@@ -43,25 +43,25 @@ static void mean_and_variance_basic_test(struct kunit *test)
s = mean_and_variance_update(s, -(s64)SQRT_U64_MAX);
KUNIT_EXPECT_EQ_MSG(test,
- s.sum_squares,
+ s.sum_squares.lo,
MAX_SQR,
"%llu == %llu, sqrt: %llu == %llu",
- s.sum_squares,
+ s.sum_squares.lo,
MAX_SQR,
- int_sqrt64(s.sum_squares),
+ int_sqrt64(s.sum_squares.lo),
SQRT_U64_MAX);
s = (struct mean_and_variance){};
s = mean_and_variance_update(s, (SQRT_U64_MAX + 1));
- KUNIT_EXPECT_EQ(test, s.sum_squares, MAX_SQR);
+ KUNIT_EXPECT_EQ(test, s.sum_squares.lo, MAX_SQR);
s = (struct mean_and_variance){};
s = mean_and_variance_update(s, (-(s64)SQRT_U64_MAX) - 1);
- KUNIT_EXPECT_EQ(test, s.sum_squares, MAX_SQR);
+ KUNIT_EXPECT_EQ(test, s.sum_squares.lo, MAX_SQR);
}
/*
@@ -145,16 +145,49 @@ static void mean_and_variance_fast_divpow2(struct kunit *test)
}
}
+static void mean_and_variance_u128_test(struct kunit *test)
+{
+ u128 a = u128_init(0, U64_MAX);
+ u128 a0 = u128_init(0, 0);
+ u128 a1 = u128_init(0, 1);
+ u128 a2 = u128_init(0, 2);
+ u128 b = u128_init(1, 0);
+ u128 c = u128_init(0, 1LLU << 63);
+
+ KUNIT_EXPECT_EQ(test, u128_add(a,a1).hi, 1);
+ KUNIT_EXPECT_EQ(test, u128_add(a,a1).lo, 0);
+ KUNIT_EXPECT_EQ(test, u128_sub(b,a1).lo, U64_MAX);
+ KUNIT_EXPECT_EQ(test, u128_sub(b,a1).hi, 0);
+
+ KUNIT_EXPECT_EQ(test, u128_shl(c, 1).hi, 1 );
+ KUNIT_EXPECT_EQ(test, u128_shl(c, 1).lo, 0 );
+
+ KUNIT_EXPECT_EQ(test, u128_square(1).hi, 0);
+ KUNIT_EXPECT_EQ(test, u128_square(1).lo, 1);
+
+ KUNIT_EXPECT_EQ(test, u128_square(0).hi, 0);
+ KUNIT_EXPECT_EQ(test, u128_square(0).lo, 0);
+
+ KUNIT_EXPECT_EQ(test, u128_square(2).lo, 4);
+
+ KUNIT_EXPECT_EQ(test, u128_square(U64_MAX).hi, U64_MAX);
+ KUNIT_EXPECT_EQ(test, u128_square(U64_MAX).lo, 1);
+
+
+ KUNIT_EXPECT_EQ(test, u128_div(b, 2).lo, 1LLU << 63);
+}
+
static struct kunit_case mean_and_variance_test_cases[] = {
+ KUNIT_CASE(mean_and_variance_fast_divpow2),
+ KUNIT_CASE(mean_and_variance_u128_test),
KUNIT_CASE(mean_and_variance_basic_test),
KUNIT_CASE(mean_and_variance_weighted_test),
KUNIT_CASE(mean_and_variance_weighted_advanced_test),
- KUNIT_CASE(mean_and_variance_fast_divpow2),
{}
};
static struct kunit_suite mean_and_variance_test_suite = {
-.name = "statistics",
+.name = "mean and variance tests",
.test_cases = mean_and_variance_test_cases
};