summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- net/ipv4/udp.c 22
-rw-r--r-- tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c 135
-rw-r--r-- tools/testing/selftests/bpf/progs/bpf_tracing_net.h 3
-rw-r--r-- tools/testing/selftests/bpf/progs/sock_iter_batch.c 91
-rw-r--r-- tools/testing/selftests/bpf/progs/test_jhash.h 31
5 files changed, 270 insertions, 12 deletions
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 5f742d0b9e07..148ffb007969 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -3137,16 +3137,18 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
struct bpf_udp_iter_state *iter = seq->private;
struct udp_iter_state *state = &iter->state;
struct net *net = seq_file_net(seq);
+ int resume_bucket, resume_offset;
struct udp_table *udptable;
unsigned int batch_sks = 0;
bool resized = false;
struct sock *sk;
+ resume_bucket = state->bucket;
+ resume_offset = iter->offset;
+
/* The current batch is done, so advance the bucket. */
- if (iter->st_bucket_done) {
+ if (iter->st_bucket_done)
state->bucket++;
- iter->offset = 0;
- }
udptable = udp_get_table_seq(seq, net);
@@ -3166,19 +3168,19 @@ again:
for (; state->bucket <= udptable->mask; state->bucket++) {
struct udp_hslot *hslot2 = &udptable->hash2[state->bucket];
- if (hlist_empty(&hslot2->head)) {
- iter->offset = 0;
+ if (hlist_empty(&hslot2->head))
continue;
- }
+ iter->offset = 0;
spin_lock_bh(&hslot2->lock);
udp_portaddr_for_each_entry(sk, &hslot2->head) {
if (seq_sk_match(seq, sk)) {
/* Resume from the last iterated socket at the
* offset in the bucket before iterator was stopped.
*/
- if (iter->offset) {
- --iter->offset;
+ if (state->bucket == resume_bucket &&
+ iter->offset < resume_offset) {
+ ++iter->offset;
continue;
}
if (iter->end_sk < iter->max_sk) {
@@ -3192,9 +3194,6 @@ again:
if (iter->end_sk)
break;
-
- /* Reset the current bucket's offset before moving to the next bucket. */
- iter->offset = 0;
}
/* All done: no batch made. */
@@ -3213,7 +3212,6 @@ again:
/* After allocating a larger batch, retry one more time to grab
* the whole bucket.
*/
- state->bucket--;
goto again;
}
done:
diff --git a/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c b/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c
new file mode 100644
index 000000000000..0c365f36c73b
--- /dev/null
+++ b/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c
@@ -0,0 +1,135 @@
+// SPDX-License-Identifier: GPL-2.0
+// Copyright (c) 2024 Meta
+
+#include <test_progs.h>
+#include "network_helpers.h"
+#include "sock_iter_batch.skel.h"
+
+#define TEST_NS "sock_iter_batch_netns" /* network namespace created and entered by test_sock_iter_batch() */
+
+static const int nr_soreuse = 4; /* number of reuseport sockets opened per port in do_test() */
+
+/*
+ * Exercise batched socket iteration over two reuseport groups.
+ *
+ * Two groups of nr_soreuse IPv6 loopback ("::1") sockets are started on
+ * two kernel-chosen ports. The attached iterator program writes the group
+ * index (0 or 1) of every matching socket to the iterator fd. The test
+ * reads all but one socket of the first group it sees, closes that whole
+ * group, and then checks that every remaining read reports the other
+ * group until end of file.
+ *
+ * @sock_type: SOCK_STREAM attaches iter_tcp_soreuse; any other value
+ *             attaches iter_udp_soreuse.
+ * @onebyone:  read one index per read() call instead of a whole group's
+ *             worth per call.
+ */
+static void do_test(int sock_type, bool onebyone)
+{
+	int err, i, nread, to_read, total_read, iter_fd = -1;
+	int first_idx, second_idx, indices[nr_soreuse];
+	struct bpf_link *link = NULL;
+	struct sock_iter_batch *skel;
+	int *fds[2] = {};
+
+	skel = sock_iter_batch__open();
+	if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open"))
+		return;
+
+	/* Prepare 2 buckets of sockets in the kernel hashtable */
+	for (i = 0; i < ARRAY_SIZE(fds); i++) {
+		int local_port;
+
+		fds[i] = start_reuseport_server(AF_INET6, sock_type, "::1", 0, 0,
+						nr_soreuse);
+		if (!ASSERT_OK_PTR(fds[i], "start_reuseport_server"))
+			goto done;
+		local_port = get_socket_local_port(*fds[i]);
+		if (!ASSERT_GE(local_port, 0, "get_socket_local_port"))
+			goto done;
+		/* rodata must be filled in before the skeleton is loaded. */
+		skel->rodata->ports[i] = ntohs(local_port);
+	}
+
+	err = sock_iter_batch__load(skel);
+	if (!ASSERT_OK(err, "sock_iter_batch__load"))
+		goto done;
+
+	link = bpf_program__attach_iter(sock_type == SOCK_STREAM ?
+					skel->progs.iter_tcp_soreuse :
+					skel->progs.iter_udp_soreuse,
+					NULL);
+	if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter"))
+		goto done;
+
+	iter_fd = bpf_iter_create(bpf_link__fd(link));
+	if (!ASSERT_GE(iter_fd, 0, "bpf_iter_create"))
+		goto done;
+
+	/* Test reading a bucket (either from fds[0] or fds[1]).
+	 * Only read "nr_soreuse - 1" number of sockets
+	 * from a bucket and leave one socket out from
+	 * that bucket on purpose.
+	 */
+	to_read = (nr_soreuse - 1) * sizeof(*indices);
+	total_read = 0;
+	first_idx = -1;
+	do {
+		nread = read(iter_fd, indices, onebyone ? sizeof(*indices) : to_read);
+		if (nread <= 0 || nread % sizeof(*indices))
+			break;
+		total_read += nread;
+
+		if (first_idx == -1)
+			first_idx = indices[0];
+		for (i = 0; i < nread / sizeof(*indices); i++)
+			ASSERT_EQ(indices[i], first_idx, "first_idx");
+	} while (total_read < to_read);
+	ASSERT_EQ(nread, onebyone ? sizeof(*indices) : to_read, "nread");
+	/* first_idx indexes fds[] below and is still -1 when nothing could
+	 * be read, so bail out rather than touch fds[] out of bounds.
+	 */
+	if (!ASSERT_EQ(total_read, to_read, "total_read") ||
+	    first_idx < 0 || first_idx >= (int)ARRAY_SIZE(fds))
+		goto done;
+
+	/* Close every socket of the group that was partially read. */
+	free_fds(fds[first_idx], nr_soreuse);
+	fds[first_idx] = NULL;
+
+	/* Read the "whole" second bucket */
+	to_read = nr_soreuse * sizeof(*indices);
+	total_read = 0;
+	second_idx = !first_idx;
+	do {
+		nread = read(iter_fd, indices, onebyone ? sizeof(*indices) : to_read);
+		if (nread <= 0 || nread % sizeof(*indices))
+			break;
+		total_read += nread;
+
+		for (i = 0; i < nread / sizeof(*indices); i++)
+			ASSERT_EQ(indices[i], second_idx, "second_idx");
+	} while (total_read <= to_read); /* "<=": keep reading until EOF */
+	ASSERT_EQ(nread, 0, "nread");
+	/* Both so_reuseport ports should be in different buckets, so
+	 * total_read must equal the expected to_read.
+	 *
+	 * In the very unlikely case that both ports collide in the same
+	 * bucket, the bucket offset (i.e. 3) will be skipped and the
+	 * to_read number of bytes cannot be expected.
+	 */
+	if (skel->bss->bucket[0] != skel->bss->bucket[1])
+		ASSERT_EQ(total_read, to_read, "total_read");
+
+done:
+	for (i = 0; i < ARRAY_SIZE(fds); i++)
+		free_fds(fds[i], nr_soreuse);
+	/* Close the iterator only if it was actually created. */
+	if (iter_fd >= 0)
+		close(iter_fd);
+	bpf_link__destroy(link);
+	sock_iter_batch__destroy(skel);
+}
+
+/*
+ * Test entry point: run the "tcp" and "udp" subtests inside the TEST_NS
+ * network namespace, each once reading a single index per read() and once
+ * reading a whole group per read(). The namespace is deleted on exit.
+ */
+void test_sock_iter_batch(void)
+{
+	struct nstoken *nstoken = NULL;
+
+	/* Remove a namespace left over from an earlier run, ignoring errors.
+	 * Plain POSIX redirection is used on purpose: "&>" is a bash
+	 * extension, and a POSIX sh such as dash parses "cmd &> file" as
+	 * "cmd &" followed by "> file", which would run the delete in the
+	 * background and let it race with the "ip netns add" below
+	 * (assuming SYS_NOFAIL runs the command through /bin/sh).
+	 */
+	SYS_NOFAIL("ip netns del " TEST_NS " > /dev/null 2>&1");
+	SYS(done, "ip netns add %s", TEST_NS);
+	SYS(done, "ip -net %s link set dev lo up", TEST_NS);
+
+	nstoken = open_netns(TEST_NS);
+	if (!ASSERT_OK_PTR(nstoken, "open_netns"))
+		goto done;
+
+	if (test__start_subtest("tcp")) {
+		do_test(SOCK_STREAM, true);
+		do_test(SOCK_STREAM, false);
+	}
+	if (test__start_subtest("udp")) {
+		do_test(SOCK_DGRAM, true);
+		do_test(SOCK_DGRAM, false);
+	}
+	close_netns(nstoken);
+
+done:
+	SYS_NOFAIL("ip netns del " TEST_NS " > /dev/null 2>&1");
+}
diff --git a/tools/testing/selftests/bpf/progs/bpf_tracing_net.h b/tools/testing/selftests/bpf/progs/bpf_tracing_net.h
index 1bdc680b0e0e..e8bd4b7b5ef7 100644
--- a/tools/testing/selftests/bpf/progs/bpf_tracing_net.h
+++ b/tools/testing/selftests/bpf/progs/bpf_tracing_net.h
@@ -72,6 +72,8 @@
#define inet_rcv_saddr sk.__sk_common.skc_rcv_saddr
#define inet_dport sk.__sk_common.skc_dport
+#define udp_portaddr_hash inet.sk.__sk_common.skc_u16hashes[1]
+
#define ir_loc_addr req.__req_common.skc_rcv_saddr
#define ir_num req.__req_common.skc_num
#define ir_rmt_addr req.__req_common.skc_daddr
@@ -85,6 +87,7 @@
#define sk_rmem_alloc sk_backlog.rmem_alloc
#define sk_refcnt __sk_common.skc_refcnt
#define sk_state __sk_common.skc_state
+#define sk_net __sk_common.skc_net
#define sk_v6_daddr __sk_common.skc_v6_daddr
#define sk_v6_rcv_saddr __sk_common.skc_v6_rcv_saddr
#define sk_flags __sk_common.skc_flags
diff --git a/tools/testing/selftests/bpf/progs/sock_iter_batch.c b/tools/testing/selftests/bpf/progs/sock_iter_batch.c
new file mode 100644
index 000000000000..ffbbfe1fa1c1
--- /dev/null
+++ b/tools/testing/selftests/bpf/progs/sock_iter_batch.c
@@ -0,0 +1,91 @@
+// SPDX-License-Identifier: GPL-2.0
+// Copyright (c) 2024 Meta
+
+#include "vmlinux.h"
+#include <bpf/bpf_helpers.h>
+#include <bpf/bpf_core_read.h>
+#include <bpf/bpf_endian.h>
+#include "bpf_tracing_net.h"
+#include "bpf_kfuncs.h"
+
+#define ATTR __always_inline
+#include "test_jhash.h"
+
+/* Return true iff @a is the IPv6 loopback address ::1, i.e. its first
+ * three 32-bit words are zero and the last one is 1 in network byte order.
+ */
+static bool ipv6_addr_loopback(const struct in6_addr *a)
+{
+	if (a->s6_addr32[0] || a->s6_addr32[1] || a->s6_addr32[2])
+		return false;
+
+	return a->s6_addr32[3] == bpf_htonl(1);
+}
+
+volatile const __u16 ports[2]; /* local ports to match; set via skel->rodata by do_test() before load */
+unsigned int bucket[2]; /* hash bucket computed per ports[] index; read back via skel->bss by do_test() */
+
+/* iter/tcp program: for each IPv6 listening socket bound to ::1 on one of
+ * ports[], store the listener-hash bucket it maps to in bucket[] and write
+ * the matching ports[] index (as an int) to the seq_file.
+ * Always returns 0; sockets that do not match are silently skipped.
+ */
+SEC("iter/tcp")
+int iter_tcp_soreuse(struct bpf_iter__tcp *ctx)
+{
+	struct sock *sk = (struct sock *)ctx->sk_common;
+	struct inet_hashinfo *lhinfo;
+	struct net *netns;
+	unsigned int h;
+	int idx;
+
+	if (!sk)
+		return 0;
+
+	sk = bpf_rdonly_cast(sk, bpf_core_type_id_kernel(struct sock));
+
+	/* Only IPv6 listeners on the loopback address are of interest. */
+	if (sk->sk_family != AF_INET6)
+		return 0;
+	if (sk->sk_state != TCP_LISTEN)
+		return 0;
+	if (!ipv6_addr_loopback(&sk->sk_v6_rcv_saddr))
+		return 0;
+
+	/* Translate the local port into its ports[] slot. */
+	if (sk->sk_num == ports[0])
+		idx = 0;
+	else if (sk->sk_num == ports[1])
+		idx = 1;
+	else
+		return 0;
+
+	/* bucket selection as in inet_lhash2_bucket_sk() */
+	netns = sk->sk_net.net;
+	lhinfo = netns->ipv4.tcp_death_row.hashinfo;
+	h = jhash2(sk->sk_v6_rcv_saddr.s6_addr32, 4, netns->hash_mix) ^
+	    sk->sk_num;
+	bucket[idx] = h & lhinfo->lhash2_mask;
+
+	bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx));
+
+	return 0;
+}
+
+/* Map a struct sock pointer to the struct udp_sock that embeds it. */
+#define udp_sk(ptr) container_of(ptr, struct udp_sock, inet.sk)
+
+/* iter/udp program: for each IPv6 socket bound to ::1 on one of ports[],
+ * store the hash2 bucket it maps to in bucket[] and write the matching
+ * ports[] index (as an int) to the seq_file.
+ * Always returns 0; sockets that do not match are silently skipped.
+ */
+SEC("iter/udp")
+int iter_udp_soreuse(struct bpf_iter__udp *ctx)
+{
+	struct sock *sk = (struct sock *)ctx->udp_sk;
+	struct udp_table *tbl;
+	int idx;
+
+	if (!sk)
+		return 0;
+
+	sk = bpf_rdonly_cast(sk, bpf_core_type_id_kernel(struct sock));
+
+	/* Only IPv6 sockets on the loopback address are of interest. */
+	if (sk->sk_family != AF_INET6)
+		return 0;
+	if (!ipv6_addr_loopback(&sk->sk_v6_rcv_saddr))
+		return 0;
+
+	/* Translate the local port into its ports[] slot. */
+	if (sk->sk_num == ports[0])
+		idx = 0;
+	else if (sk->sk_num == ports[1])
+		idx = 1;
+	else
+		return 0;
+
+	/* bucket selection as in udp_hashslot2() */
+	tbl = sk->sk_net.net->ipv4.udp_table;
+	bucket[idx] = udp_sk(sk)->udp_portaddr_hash & tbl->mask;
+
+	bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx));
+
+	return 0;
+}
+
+char _license[] SEC("license") = "GPL";
diff --git a/tools/testing/selftests/bpf/progs/test_jhash.h b/tools/testing/selftests/bpf/progs/test_jhash.h
index c300734d26f6..ef53559bbbdf 100644
--- a/tools/testing/selftests/bpf/progs/test_jhash.h
+++ b/tools/testing/selftests/bpf/progs/test_jhash.h
@@ -69,3 +69,34 @@ u32 jhash(const void *key, u32 length, u32 initval)
return c;
}
+
+/* jhash2 - hash an array of u32 words
+ * @k:       words to hash
+ * @length:  number of u32 words in @k
+ * @initval: seed folded into the initial state
+ *
+ * Consumes the key three words at a time through __jhash_mix() and folds
+ * the last one to three words in with __jhash_final(). Returns the
+ * resulting 32-bit value.
+ */
+static __always_inline u32 jhash2(const u32 *k, u32 length, u32 initval)
+{
+	u32 a, b, c;
+
+	/* Set up the internal state: all three lanes start from one seed. */
+	c = JHASH_INITVAL + (length << 2) + initval;
+	b = c;
+	a = b;
+
+	/* Full three-word rounds; the final 1..3 words are handled below. */
+	for (; length > 3; length -= 3, k += 3) {
+		a += k[0];
+		b += k[1];
+		c += k[2];
+		__jhash_mix(a, b, c);
+	}
+
+	/* Zero-length key: nothing left to add. */
+	if (!length)
+		return c;
+
+	/* Fold in the remaining one, two or three words. */
+	if (length > 2)
+		c += k[2];
+	if (length > 1)
+		b += k[1];
+	a += k[0];
+	__jhash_final(a, b, c);
+
+	return c;
+}