diff options
Diffstat (limited to 'include/linux/mean_and_variance.h')
-rw-r--r-- | include/linux/mean_and_variance.h | 95 |
1 files changed, 91 insertions, 4 deletions
diff --git a/include/linux/mean_and_variance.h b/include/linux/mean_and_variance.h index 0a64c1be9d20..1ae948547d99 100644 --- a/include/linux/mean_and_variance.h +++ b/include/linux/mean_and_variance.h @@ -1,15 +1,100 @@ /* SPDX-License-Identifier: GPL-2.0 */ -#ifndef STATS_H_ -#define STATS_H_ +#ifndef MEAN_AND_VARIANCE_H_ +#define MEAN_AND_VAIRANCE_H_ #include <linux/types.h> +#include <linux/limits.h> +#include <linux/math64.h> +#include <linux/printbuf.h> #define SQRT_U64_MAX 4294967295ULL +//#ifdef __SIZEOF_INT128__ + +//typedef unsigned __int128 u128; + +//#else + +typedef struct { + u64 hi; + u64 lo; +} u128; + +static inline u128 u128_init(u64 a, u64 b) +{ + return (u128){ .hi = a, .lo = b }; +} + +static inline u128 u128_add(u128 a, u128 b) +{ + u128 c; + c.lo = a.lo + b.lo; + c.hi = a.hi + b.hi + (c.lo < a.lo); + return c; +} + +static inline u128 u128_sub(u128 a, u128 b) +{ + u128 c; + c.lo = a.lo - b.lo; + c.hi = a.hi - (b.hi + c.lo > a.lo); + return c; +} +static inline u128 u128_shl(u128 i, s8 s1) { + u128 r; + s8 s2 = 64 - s1; + r.lo = i.lo << s1; + r.hi = (i.hi << s1) + (i.lo >> s2); + return r; +} + +static inline u128 u128_square(u64 i) +{ + u128 r; + u64 h = i >> 32, l = i & (u64)U32_MAX; + u64 x; + // overflows: + // + // ( (a*a) << 128) + (( a*b) << 97) + ((a*c) << 65) + ((b*b) << 65) + printk("square %llu", i); + printk("h = %llu, l = %llu\n", h, l); + r = u128_init(h*h, 0); + x = h*l; + printk("hi = %llu, x = %llu, x >> 31 = %llu\n", r.hi, x, x >> 31); + r = u128_add(r, u128_shl(u128_init(0,x), 33)); + printk("hi = %llu, lo = %llu\n", r.hi, r.lo); + x = l*l; + printk("x = %llu\n", x); + r = u128_add(r, u128_init(0, x)); + printk("hi = %llu, lo = %llu\n", r.hi, r.lo); + return r; +} + +static inline u128 u128_div(u128 n, u64 d) { + u128 result; + u64 r; + u64 rem; + u64 hh = n.hi & ((u64)U32_MAX << 32), hl = (n.hi & (u64)U32_MAX), + lh = n.lo & ((u64)U32_MAX << 32), ll = (n.lo & (u64)U32_MAX); + printk("divide: %llu::%llu / %llu", n.hi, n.lo, d); + printk("hi = %llu, hh = %llu, hl = %llu, hh+hl = %llu\n", n.hi, hh, hl, hh+hl); + r = div64_u64_rem(hh, d, &rem); + result.hi = r; + printk("hi = %llu, r = %llu, rem = %llu \n", result.hi, r, rem); + r = div64_u64_rem(((hl + rem) << 32), d, &rem); + result.hi += r >> 32; + printk("hi = %llu, r = %llu, rem = %llu \n", result.hi, r, rem); + r = div64_u64_rem((n.lo + ((rem) << 32)), d, &rem) + (r << 32); + result.lo = r; + printk("lo = %llu, r = %llu, rem = %llu \n", result.hi, r, rem); + return result; +} +//#endif + struct mean_and_variance { s64 n; s64 sum; - u64 sum_squares; + u128 sum_squares; }; /* expontentially weighted variant */ @@ -34,4 +119,6 @@ struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_ u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s); u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s); -#endif // STATS_H_ + + +#endif // MEAN_AND_VAIRANCE_H_ |