summaryrefslogtreecommitdiff
path: root/include/linux/mean_and_variance.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/linux/mean_and_variance.h')
-rw-r--r--include/linux/mean_and_variance.h95
1 files changed, 91 insertions, 4 deletions
diff --git a/include/linux/mean_and_variance.h b/include/linux/mean_and_variance.h
index 0a64c1be9d20..1ae948547d99 100644
--- a/include/linux/mean_and_variance.h
+++ b/include/linux/mean_and_variance.h
@@ -1,15 +1,100 @@
/* SPDX-License-Identifier: GPL-2.0 */
-#ifndef STATS_H_
-#define STATS_H_
+#ifndef MEAN_AND_VARIANCE_H_
+#define MEAN_AND_VAIRANCE_H_
#include <linux/types.h>
+#include <linux/limits.h>
+#include <linux/math64.h>
+#include <linux/printbuf.h>
#define SQRT_U64_MAX 4294967295ULL
+//#ifdef __SIZEOF_INT128__
+
+//typedef unsigned __int128 u128;
+
+//#else
+
+typedef struct {
+ u64 hi;
+ u64 lo;
+} u128;
+
+static inline u128 u128_init(u64 a, u64 b)
+{
+ return (u128){ .hi = a, .lo = b };
+}
+
+static inline u128 u128_add(u128 a, u128 b)
+{
+ u128 c;
+ c.lo = a.lo + b.lo;
+ c.hi = a.hi + b.hi + (c.lo < a.lo);
+ return c;
+}
+
+static inline u128 u128_sub(u128 a, u128 b)
+{
+ u128 c;
+ c.lo = a.lo - b.lo;
+ c.hi = a.hi - (b.hi + c.lo > a.lo);
+ return c;
+}
+static inline u128 u128_shl(u128 i, s8 s1) {
+ u128 r;
+ s8 s2 = 64 - s1;
+ r.lo = i.lo << s1;
+ r.hi = (i.hi << s1) + (i.lo >> s2);
+ return r;
+}
+
+static inline u128 u128_square(u64 i)
+{
+ u128 r;
+ u64 h = i >> 32, l = i & (u64)U32_MAX;
+ u64 x;
+ // overflows:
+ //
+ // ( (a*a) << 128) + (( a*b) << 97) + ((a*c) << 65) + ((b*b) << 65)
+ printk("square %llu", i);
+ printk("h = %llu, l = %llu\n", h, l);
+ r = u128_init(h*h, 0);
+ x = h*l;
+ printk("hi = %llu, x = %llu, x >> 31 = %llu\n", r.hi, x, x >> 31);
+ r = u128_add(r, u128_shl(u128_init(0,x), 33));
+ printk("hi = %llu, lo = %llu\n", r.hi, r.lo);
+ x = l*l;
+ printk("x = %llu\n", x);
+ r = u128_add(r, u128_init(0, x));
+ printk("hi = %llu, lo = %llu\n", r.hi, r.lo);
+ return r;
+}
+
+static inline u128 u128_div(u128 n, u64 d) {
+ u128 result;
+ u64 r;
+ u64 rem;
+ u64 hh = n.hi & ((u64)U32_MAX << 32), hl = (n.hi & (u64)U32_MAX),
+ lh = n.lo & ((u64)U32_MAX << 32), ll = (n.lo & (u64)U32_MAX);
+ printk("divide: %llu::%llu / %llu", n.hi, n.lo, d);
+ printk("hi = %llu, hh = %llu, hl = %llu, hh+hl = %llu\n", n.hi, hh, hl, hh+hl);
+ r = div64_u64_rem(hh, d, &rem);
+ result.hi = r;
+ printk("hi = %llu, r = %llu, rem = %llu \n", result.hi, r, rem);
+ r = div64_u64_rem(((hl + rem) << 32), d, &rem);
+ result.hi += r >> 32;
+ printk("hi = %llu, r = %llu, rem = %llu \n", result.hi, r, rem);
+ r = div64_u64_rem((n.lo + ((rem) << 32)), d, &rem) + (r << 32);
+ result.lo = r;
+ printk("lo = %llu, r = %llu, rem = %llu \n", result.hi, r, rem);
+ return result;
+}
+//#endif
+
struct mean_and_variance {
s64 n;
s64 sum;
- u64 sum_squares;
+ u128 sum_squares;
};
/* expontentially weighted variant */
@@ -34,4 +119,6 @@ struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_
u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
-#endif // STATS_H_
+
+
+#endif // MEAN_AND_VAIRANCE_H_