diff options
author | Daniel Hill <daniel@gluo.nz> | 2022-09-22 17:41:50 +1200 |
---|---|---|
committer | Daniel Hill <daniel@gluo.nz> | 2022-09-23 12:42:03 +1200 |
commit | df92f413c8bbff3cfe40bfdcccacf00ac9567c76 (patch) | |
tree | 0ca4541ff4457c3dc66be74c2f448e565f4c85f8 | |
parent | 2604a8e4bbac50c014007bee9b1bf1a84ec1fa13 (diff) |
128bit math
-rw-r--r-- | include/linux/mean_and_variance.h | 95 | ||||
-rw-r--r-- | lib/Kconfig.debug | 7 | ||||
-rw-r--r-- | lib/math/Kconfig | 9 | ||||
-rw-r--r-- | lib/math/mean_and_variance.c | 23 | ||||
-rw-r--r-- | lib/math/mean_and_variance_test.c | 53 |
5 files changed, 161 insertions, 26 deletions
diff --git a/include/linux/mean_and_variance.h b/include/linux/mean_and_variance.h index 0a64c1be9d20..1ae948547d99 100644 --- a/include/linux/mean_and_variance.h +++ b/include/linux/mean_and_variance.h @@ -1,15 +1,100 @@ /* SPDX-License-Identifier: GPL-2.0 */ -#ifndef STATS_H_ -#define STATS_H_ +#ifndef MEAN_AND_VARIANCE_H_ +#define MEAN_AND_VAIRANCE_H_ #include <linux/types.h> +#include <linux/limits.h> +#include <linux/math64.h> +#include <linux/printbuf.h> #define SQRT_U64_MAX 4294967295ULL +//#ifdef __SIZEOF_INT128__ + +//typedef unsigned __int128 u128; + +//#else + +typedef struct { + u64 hi; + u64 lo; +} u128; + +static inline u128 u128_init(u64 a, u64 b) +{ + return (u128){ .hi = a, .lo = b }; +} + +static inline u128 u128_add(u128 a, u128 b) +{ + u128 c; + c.lo = a.lo + b.lo; + c.hi = a.hi + b.hi + (c.lo < a.lo); + return c; +} + +static inline u128 u128_sub(u128 a, u128 b) +{ + u128 c; + c.lo = a.lo - b.lo; + c.hi = a.hi - (b.hi + c.lo > a.lo); + return c; +} +static inline u128 u128_shl(u128 i, s8 s1) { + u128 r; + s8 s2 = 64 - s1; + r.lo = i.lo << s1; + r.hi = (i.hi << s1) + (i.lo >> s2); + return r; +} + +static inline u128 u128_square(u64 i) +{ + u128 r; + u64 h = i >> 32, l = i & (u64)U32_MAX; + u64 x; + // overflows: + // + // ( (a*a) << 128) + (( a*b) << 97) + ((a*c) << 65) + ((b*b) << 65) + printk("square %llu", i); + printk("h = %llu, l = %llu\n", h, l); + r = u128_init(h*h, 0); + x = h*l; + printk("hi = %llu, x = %llu, x >> 31 = %llu\n", r.hi, x, x >> 31); + r = u128_add(r, u128_shl(u128_init(0,x), 33)); + printk("hi = %llu, lo = %llu\n", r.hi, r.lo); + x = l*l; + printk("x = %llu\n", x); + r = u128_add(r, u128_init(0, x)); + printk("hi = %llu, lo = %llu\n", r.hi, r.lo); + return r; +} + +static inline u128 u128_div(u128 n, u64 d) { + u128 result; + u64 r; + u64 rem; + u64 hh = n.hi & ((u64)U32_MAX << 32), hl = (n.hi & (u64)U32_MAX), + lh = n.lo & ((u64)U32_MAX << 32), ll = (n.lo & (u64)U32_MAX); + printk("divide: %llu::%llu / %llu", n.hi, n.lo, d); + printk("hi = %llu, hh = %llu, hl = %llu, hh+hl = %llu\n", n.hi, hh, hl, hh+hl); + r = div64_u64_rem(hh, d, &rem); + result.hi = r; + printk("hi = %llu, r = %llu, rem = %llu \n", result.hi, r, rem); + r = div64_u64_rem(((hl + rem) << 32), d, &rem); + result.hi += r >> 32; + printk("hi = %llu, r = %llu, rem = %llu \n", result.hi, r, rem); + r = div64_u64_rem((n.lo + ((rem) << 32)), d, &rem) + (r << 32); + result.lo = r; + printk("lo = %llu, r = %llu, rem = %llu \n", result.hi, r, rem); + return result; +} +//#endif + struct mean_and_variance { s64 n; s64 sum; - u64 sum_squares; + u128 sum_squares; }; /* expontentially weighted variant */ @@ -34,4 +119,6 @@ struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_ u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s); u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s); -#endif // STATS_H_ + + +#endif // MEAN_AND_VAIRANCE_H_ diff --git a/lib/Kconfig.debug b/lib/Kconfig.debug index 1d4ed12a5355..f39f09b835ce 100644 --- a/lib/Kconfig.debug +++ b/lib/Kconfig.debug @@ -2030,6 +2030,13 @@ config LKDTM Documentation on how to use the module can be found in Documentation/fault-injection/provoke-crashes.rst +config MEAN_AND_VARIANCE_UNIT_TEST + tristate "mean_and_variance unit tests" if !KUNIT_ALL_TESTS + depends on MEAN_AND_VARIANCE && KUNIT + default KUNIT_ALL_TESTS + help + This option tests the mean_and_variance module + config TEST_LIST_SORT tristate "Linked list sorting test" if !KUNIT_ALL_TESTS depends on KUNIT diff --git a/lib/math/Kconfig b/lib/math/Kconfig index 12fe28622e82..d5688c673804 100644 --- a/lib/math/Kconfig +++ b/lib/math/Kconfig @@ -17,9 +17,6 @@ config RATIONAL tristate config MEAN_AND_VARIANCE - tristate - -config MEAN_AND_VARIANCE_UNIT_TEST - tristate "mean_and_variance unit tests" if !KUNIT_ALL_TESTS - depends on MEAN_AND_VARIANCE && KUNIT - default KUNIT_ALL_TESTS + tristate "blah" + help + "This is needed for kunit to work" diff --git a/lib/math/mean_and_variance.c b/lib/math/mean_and_variance.c index f15aefc0b4d0..c0552828f8c3 100644 --- a/lib/math/mean_and_variance.c +++ b/lib/math/mean_and_variance.c @@ -57,6 +57,16 @@ inline s64 fast_divpow2(s64 n, u8 d) return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; // + (n < 0 ? 1 : 0); } +/* +inline u128 u128_mul(u128 a, u128 b) +{ + u32 a1 = a.lo, a2 = a.lo >> 32, a3 = a.hi, a4 = a.hi >> 32; + u32 b1 = b.lo, b2 = b.lo >> 32, b3 = b.hi, b4 = b.hi >> 32; + + + a1 * b1 + (a2 * b1) << 32 + (a1 + b2) << 32 + a2 * b2) +} +*/ /** * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1 * and return it. @@ -70,14 +80,15 @@ struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s struct mean_and_variance s2; u64 v2 = abs(v1); + /* if (v2 > SQRT_U64_MAX) { v2 = SQRT_U64_MAX; WARN(true, "stats overflow! %lld^2 > U64_MAX", v1); } - + */ s2.n = s1.n + 1; s2.sum = s1.sum + v1; - s2.sum_squares = s1.sum_squares + v2*v2; + s2.sum_squares = u128_add(s1.sum_squares, u128_square(v2)); return s2; } EXPORT_SYMBOL_GPL(mean_and_variance_update); @@ -98,11 +109,11 @@ EXPORT_SYMBOL_GPL(mean_and_variance_get_mean); */ u64 mean_and_variance_get_variance(struct mean_and_variance s1) { - u64 s2 = s1.sum_squares / s1.n; - u64 s3 = abs(mean_and_variance_get_mean(s1)); + u128 s2 = u128_div(s1.sum_squares,s1.n); + u64 s3 = abs(mean_and_variance_get_mean(s1)); - WARN(s3 > SQRT_U64_MAX, "stats overflow %lld ^2 > S64_MAX", s3); - return s2 - s3*s3; + // WARN(s3 > SQRT_U64_MAX, "stats overflow %llu ^2 > S64_MAX", (u64)(s3 >> 64)); + return u128_sub(s2, u128_square(s3)).lo; } EXPORT_SYMBOL_GPL(mean_and_variance_get_variance); diff --git a/lib/math/mean_and_variance_test.c b/lib/math/mean_and_variance_test.c index 215e128acd86..4384a075dd36 100644 --- a/lib/math/mean_and_variance_test.c +++ b/lib/math/mean_and_variance_test.c @@ -30,12 +30,12 @@ static void mean_and_variance_basic_test(struct kunit *test) s = mean_and_variance_update(s, SQRT_U64_MAX); KUNIT_EXPECT_EQ_MSG(test, - s.sum_squares, + s.sum_squares.lo, MAX_SQR, "%llu == %llu, sqrt: %llu == %llu", - s.sum_squares, + s.sum_squares.lo, MAX_SQR, - int_sqrt64(s.sum_squares), + int_sqrt64(s.sum_squares.lo), SQRT_U64_MAX); s = (struct mean_and_variance){}; @@ -43,25 +43,25 @@ static void mean_and_variance_basic_test(struct kunit *test) s = mean_and_variance_update(s, -(s64)SQRT_U64_MAX); KUNIT_EXPECT_EQ_MSG(test, - s.sum_squares, + s.sum_squares.lo, MAX_SQR, "%llu == %llu, sqrt: %llu == %llu", - s.sum_squares, + s.sum_squares.lo, MAX_SQR, - int_sqrt64(s.sum_squares), + int_sqrt64(s.sum_squares.lo), SQRT_U64_MAX); s = (struct mean_and_variance){}; s = mean_and_variance_update(s, (SQRT_U64_MAX + 1)); - KUNIT_EXPECT_EQ(test, s.sum_squares, MAX_SQR); + KUNIT_EXPECT_EQ(test, s.sum_squares.lo, MAX_SQR); s = (struct mean_and_variance){}; s = mean_and_variance_update(s, (-(s64)SQRT_U64_MAX) - 1); - KUNIT_EXPECT_EQ(test, s.sum_squares, MAX_SQR); + KUNIT_EXPECT_EQ(test, s.sum_squares.lo, MAX_SQR); } /* @@ -145,16 +145,49 @@ static void mean_and_variance_fast_divpow2(struct kunit *test) } } +static void mean_and_variance_u128_test(struct kunit *test) +{ + u128 a = u128_init(0, U64_MAX); + u128 a0 = u128_init(0, 0); + u128 a1 = u128_init(0, 1); + u128 a2 = u128_init(0, 2); + u128 b = u128_init(1, 0); + u128 c = u128_init(0, 1LLU << 63); + + KUNIT_EXPECT_EQ(test, u128_add(a,a1).hi, 1); + KUNIT_EXPECT_EQ(test, u128_add(a,a1).lo, 0); + KUNIT_EXPECT_EQ(test, u128_sub(b,a1).lo, U64_MAX); + KUNIT_EXPECT_EQ(test, u128_sub(b,a1).hi, 0); + + KUNIT_EXPECT_EQ(test, u128_shl(c, 1).hi, 1 ); + KUNIT_EXPECT_EQ(test, u128_shl(c, 1).lo, 0 ); + + KUNIT_EXPECT_EQ(test, u128_square(1).hi, 0); + KUNIT_EXPECT_EQ(test, u128_square(1).lo, 1); + + KUNIT_EXPECT_EQ(test, u128_square(0).hi, 0); + KUNIT_EXPECT_EQ(test, u128_square(0).lo, 0); + + KUNIT_EXPECT_EQ(test, u128_square(2).lo, 4); + + KUNIT_EXPECT_EQ(test, u128_square(U64_MAX).hi, U64_MAX); + KUNIT_EXPECT_EQ(test, u128_square(U64_MAX).lo, 1); + + + KUNIT_EXPECT_EQ(test, u128_div(b, 2).lo, 1LLU << 63); +} + static struct kunit_case mean_and_variance_test_cases[] = { + KUNIT_CASE(mean_and_variance_fast_divpow2), + KUNIT_CASE(mean_and_variance_u128_test), KUNIT_CASE(mean_and_variance_basic_test), KUNIT_CASE(mean_and_variance_weighted_test), KUNIT_CASE(mean_and_variance_weighted_advanced_test), - KUNIT_CASE(mean_and_variance_fast_divpow2), {} }; static struct kunit_suite mean_and_variance_test_suite = { -.name = "statistics", +.name = "mean and variance tests", .test_cases = mean_and_variance_test_cases }; |