summaryrefslogtreecommitdiff
path: root/kernel/ucount.c
diff options
context:
space:
mode:
Diffstat (limited to 'kernel/ucount.c')
-rw-r--r--kernel/ucount.c65
1 files changed, 60 insertions, 5 deletions
diff --git a/kernel/ucount.c b/kernel/ucount.c
index bb51849e6375..4f5613dac227 100644
--- a/kernel/ucount.c
+++ b/kernel/ucount.c
@@ -150,9 +150,15 @@ static void hlist_add_ucounts(struct ucounts *ucounts)
spin_unlock_irq(&ucounts_lock);
}
+static inline bool get_ucounts_or_wrap(struct ucounts *ucounts)
+{
+ /* Returns true on a successful get, false if the count wraps. */
+ return !atomic_add_negative(1, &ucounts->count);
+}
+
struct ucounts *get_ucounts(struct ucounts *ucounts)
{
- if (ucounts && atomic_add_negative(1, &ucounts->count)) {
+ if (!get_ucounts_or_wrap(ucounts)) {
put_ucounts(ucounts);
ucounts = NULL;
}
@@ -163,7 +169,7 @@ struct ucounts *alloc_ucounts(struct user_namespace *ns, kuid_t uid)
{
struct hlist_head *hashent = ucounts_hashentry(ns, uid);
struct ucounts *ucounts, *new;
- long overflow;
+ bool wrapped;
spin_lock_irq(&ucounts_lock);
ucounts = find_ucounts(ns, uid, hashent);
@@ -188,9 +194,9 @@ struct ucounts *alloc_ucounts(struct user_namespace *ns, kuid_t uid)
return new;
}
}
- overflow = atomic_add_negative(1, &ucounts->count);
+ wrapped = !get_ucounts_or_wrap(ucounts);
spin_unlock_irq(&ucounts_lock);
- if (overflow) {
+ if (wrapped) {
put_ucounts(ucounts);
return NULL;
}
@@ -276,7 +282,7 @@ bool dec_rlimit_ucounts(struct ucounts *ucounts, enum ucount_type type, long v)
struct ucounts *iter;
long new = -1; /* Silence compiler warning */
for (iter = ucounts; iter; iter = iter->ns->ucounts) {
- long dec = atomic_long_add_return(-v, &iter->ucount[type]);
+ long dec = atomic_long_sub_return(v, &iter->ucount[type]);
WARN_ON_ONCE(dec < 0);
if (iter == ucounts)
new = dec;
@@ -284,6 +290,55 @@ bool dec_rlimit_ucounts(struct ucounts *ucounts, enum ucount_type type, long v)
return (new == 0);
}
+static void do_dec_rlimit_put_ucounts(struct ucounts *ucounts,
+ struct ucounts *last, enum ucount_type type)
+{
+ struct ucounts *iter, *next;
+ for (iter = ucounts; iter != last; iter = next) {
+ long dec = atomic_long_sub_return(1, &iter->ucount[type]);
+ WARN_ON_ONCE(dec < 0);
+ next = iter->ns->ucounts;
+ if (dec == 0)
+ put_ucounts(iter);
+ }
+}
+
+void dec_rlimit_put_ucounts(struct ucounts *ucounts, enum ucount_type type)
+{
+ do_dec_rlimit_put_ucounts(ucounts, NULL, type);
+}
+
+long inc_rlimit_get_ucounts(struct ucounts *ucounts, enum ucount_type type)
+{
+ /* Caller must hold a reference to ucounts */
+ struct ucounts *iter;
+ long dec, ret = 0;
+
+ for (iter = ucounts; iter; iter = iter->ns->ucounts) {
+ long max = READ_ONCE(iter->ns->ucount_max[type]);
+ long new = atomic_long_add_return(1, &iter->ucount[type]);
+ if (new < 0 || new > max)
+ goto unwind;
+ if (iter == ucounts)
+ ret = new;
+ /*
+ * Grab an extra ucount reference for the caller when
+ * the rlimit count was previously 0.
+ */
+ if (new != 1)
+ continue;
+ if (!get_ucounts(iter))
+ goto dec_unwind;
+ }
+ return ret;
+dec_unwind:
+ dec = atomic_long_sub_return(1, &iter->ucount[type]);
+ WARN_ON_ONCE(dec < 0);
+unwind:
+ do_dec_rlimit_put_ucounts(ucounts, iter, type);
+ return 0;
+}
+
bool is_ucounts_overlimit(struct ucounts *ucounts, enum ucount_type type, unsigned long max)
{
struct ucounts *iter;