diff options
Diffstat (limited to 'include/linux/mean_and_variance.h')
-rw-r--r-- | include/linux/mean_and_variance.h | 219 |
1 files changed, 94 insertions, 125 deletions
diff --git a/include/linux/mean_and_variance.h b/include/linux/mean_and_variance.h index 756eb3d1..9ed79f42 100644 --- a/include/linux/mean_and_variance.h +++ b/include/linux/mean_and_variance.h @@ -2,122 +2,112 @@ #ifndef MEAN_AND_VARIANCE_H_ #define MEAN_AND_VARIANCE_H_ -#include <linux/kernel.h> #include <linux/types.h> +#include <linux/kernel.h> #include <linux/limits.h> #include <linux/math64.h> +#include <stdlib.h> #define SQRT_U64_MAX 4294967295ULL -/** - * abs - return absolute value of an argument - * @x: the value. If it is unsigned type, it is converted to signed type first. - * char is treated as if it was signed (regardless of whether it really is) - * but the macro's return type is preserved as char. - * - * Return: an absolute value of x. +/* + * u128_u: u128 user mode, because not all architectures support a real int128 + * type */ -#define abs(x) __abs_choose_expr(x, long long, \ - __abs_choose_expr(x, long, \ - __abs_choose_expr(x, int, \ - __abs_choose_expr(x, short, \ - __abs_choose_expr(x, char, \ - __builtin_choose_expr( \ - __builtin_types_compatible_p(typeof(x), char), \ - (char)({ signed char __x = (x); __x<0?-__x:__x; }), \ - ((void)0))))))) -#define __abs_choose_expr(x, type, other) __builtin_choose_expr( \ - __builtin_types_compatible_p(typeof(x), signed type) || \ - __builtin_types_compatible_p(typeof(x), unsigned type), \ - ({ signed type __x = (x); __x < 0 ? -__x : __x; }), other) +#ifdef __SIZEOF_INT128__ -#if defined(CONFIG_ARCH_SUPPORTS_INT128) && defined(__SIZEOF_INT128__) - -typedef unsigned __int128 u128; +typedef struct { + unsigned __int128 v; +} __aligned(16) u128_u; -static inline u128 u64_to_u128(u64 a) +static inline u128_u u64_to_u128(u64 a) { - return (u128)a; + return (u128_u) { .v = a }; } -static inline u64 u128_to_u64(u128 a) +static inline u64 u128_lo(u128_u a) { - return (u64)a; + return a.v; } -static inline u64 u128_shr64_to_u64(u128 a) +static inline u64 u128_hi(u128_u a) { - return (u64)(a >> 64); + return a.v >> 64; } -static inline u128 u128_add(u128 a, u128 b) +static inline u128_u u128_add(u128_u a, u128_u b) { - return a + b; + a.v += b.v; + return a; } -static inline u128 u128_sub(u128 a, u128 b) +static inline u128_u u128_sub(u128_u a, u128_u b) { - return a - b; + a.v -= b.v; + return a; } -static inline u128 u128_shl(u128 i, s8 shift) +static inline u128_u u128_shl(u128_u a, s8 shift) { - return i << shift; + a.v <<= shift; + return a; } -static inline u128 u128_shl64_add(u64 a, u64 b) +static inline u128_u u128_square(u64 a) { - return ((u128)a << 64) + b; -} + u128_u b = u64_to_u128(a); -static inline u128 u128_square(u64 i) -{ - return i*i; + b.v *= b.v; + return b; } #else typedef struct { u64 hi, lo; -} u128; +} __aligned(16) u128_u; + +/* conversions */ -static inline u128 u64_to_u128(u64 a) +static inline u128_u u64_to_u128(u64 a) { - return (u128){ .lo = a }; + return (u128_u) { .lo = a }; } -static inline u64 u128_to_u64(u128 a) +static inline u64 u128_lo(u128_u a) { return a.lo; } -static inline u64 u128_shr64_to_u64(u128 a) +static inline u64 u128_hi(u128_u a) { return a.hi; } -static inline u128 u128_add(u128 a, u128 b) +/* arithmetic */ + +static inline u128_u u128_add(u128_u a, u128_u b) { - u128 c; + u128_u c; c.lo = a.lo + b.lo; c.hi = a.hi + b.hi + (c.lo < a.lo); return c; } -static inline u128 u128_sub(u128 a, u128 b) +static inline u128_u u128_sub(u128_u a, u128_u b) { - u128 c; + u128_u c; c.lo = a.lo - b.lo; c.hi = a.hi - b.hi - (c.lo > a.lo); return c; } -static inline u128 u128_shl(u128 i, s8 shift) +static inline u128_u u128_shl(u128_u i, s8 shift) { - u128 r; + u128_u r; r.lo = i.lo << shift; if (shift < 64) @@ -129,15 +119,10 @@ static inline u128 u128_shl(u128 i, s8 shift) return r; } -static inline u128 u128_shl64_add(u64 a, u64 b) +static inline u128_u u128_square(u64 i) { - return u128_add(u128_shl(u64_to_u128(a), 64), u64_to_u128(b)); -} - -static inline u128 u128_square(u64 i) -{ - u128 r; - u64 h = i >> 32, l = i & (u64)U32_MAX; + u128_u r; + u64 h = i >> 32, l = i & U32_MAX; r = u128_shl(u64_to_u128(h*h), 64); r = u128_add(r, u128_shl(u64_to_u128(h*l), 32)); @@ -148,85 +133,69 @@ static inline u128 u128_square(u64 i) #endif -static inline u128 u128_div(u128 n, u64 d) +static inline u128_u u64s_to_u128(u64 hi, u64 lo) { - u128 r; - u64 rem; - u64 hi = u128_shr64_to_u64(n); - u64 lo = u128_to_u64(n); - u64 h = hi & ((u64)U32_MAX << 32); - u64 l = (hi & (u64)U32_MAX) << 32; + u128_u c = u64_to_u128(hi); - r = u128_shl(u64_to_u128(div64_u64_rem(h, d, &rem)), 64); - r = u128_add(r, u128_shl(u64_to_u128(div64_u64_rem(l + (rem << 32), d, &rem)), 32)); - r = u128_add(r, u64_to_u128(div64_u64_rem(lo + (rem << 32), d, &rem))); - return r; + c = u128_shl(c, 64); + c = u128_add(c, u64_to_u128(lo)); + return c; } +u128_u u128_div(u128_u n, u64 d); + struct mean_and_variance { - s64 n; - s64 sum; - u128 sum_squares; + s64 n; + s64 sum; + u128_u sum_squares; }; /* expontentially weighted variant */ struct mean_and_variance_weighted { - bool init; - u8 w; - s64 mean; - u64 variance; + bool init; + u8 weight; /* base 2 logarithim */ + s64 mean; + u64 variance; }; -s64 fast_divpow2(s64 n, u8 d); +/** + * fast_divpow2() - fast approximation for n / (1 << d) + * @n: numerator + * @d: the power of 2 denominator. + * + * note: this rounds towards 0. + */ +static inline s64 fast_divpow2(s64 n, u8 d) +{ + return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; +} +/** + * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1 + * and return it. + * @s1: the mean_and_variance to update. + * @v1: the new sample. + * + * see linked pdf equation 12. + */ static inline struct mean_and_variance -mean_and_variance_update_inlined(struct mean_and_variance s1, s64 v1) -{ - struct mean_and_variance s2; - u64 v2 = abs(v1); - - s2.n = s1.n + 1; - s2.sum = s1.sum + v1; - s2.sum_squares = u128_add(s1.sum_squares, u128_square(v2)); - return s2; -} - -static inline struct mean_and_variance_weighted -mean_and_variance_weighted_update_inlined(struct mean_and_variance_weighted s1, s64 x) -{ - struct mean_and_variance_weighted s2; - // previous weighted variance. - u64 var_w0 = s1.variance; - u8 w = s2.w = s1.w; - // new value weighted. - s64 x_w = x << w; - s64 diff_w = x_w - s1.mean; - s64 diff = fast_divpow2(diff_w, w); - // new mean weighted. - s64 u_w1 = s1.mean + diff; - - BUG_ON(w % 2 != 0); - - if (!s1.init) { - s2.mean = x_w; - s2.variance = 0; - } else { - s2.mean = u_w1; - s2.variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w; - } - s2.init = true; - - return s2; +mean_and_variance_update(struct mean_and_variance s, s64 v) +{ + return (struct mean_and_variance) { + .n = s.n + 1, + .sum = s.sum + v, + .sum_squares = u128_add(s.sum_squares, u128_square(abs(v))), + }; } -struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1); - s64 mean_and_variance_get_mean(struct mean_and_variance s); - u64 mean_and_variance_get_variance(struct mean_and_variance s1); - u32 mean_and_variance_get_stddev(struct mean_and_variance s); +s64 mean_and_variance_get_mean(struct mean_and_variance s); +u64 mean_and_variance_get_variance(struct mean_and_variance s1); +u32 mean_and_variance_get_stddev(struct mean_and_variance s); + +void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v); -struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_and_variance_weighted s1, s64 v1); - s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s); - u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s); - u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s); +s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s); +u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s); +u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s); #endif // MEAN_AND_VAIRANCE_H_ |