summaryrefslogtreecommitdiff
path: root/net/mctp
diff options
context:
space:
mode:
Diffstat (limited to 'net/mctp')
-rw-r--r--net/mctp/Kconfig12
-rw-r--r--net/mctp/Makefile3
-rw-r--r--net/mctp/af_mctp.c152
-rw-r--r--net/mctp/device.c104
-rw-r--r--net/mctp/neigh.c4
-rw-r--r--net/mctp/route.c362
-rw-r--r--net/mctp/test/route-test.c544
-rw-r--r--net/mctp/test/utils.c67
-rw-r--r--net/mctp/test/utils.h20
9 files changed, 1165 insertions, 103 deletions
diff --git a/net/mctp/Kconfig b/net/mctp/Kconfig
index 2cdf3d0a28c9..3a5c0e70da77 100644
--- a/net/mctp/Kconfig
+++ b/net/mctp/Kconfig
@@ -1,7 +1,7 @@
menuconfig MCTP
depends on NET
- tristate "MCTP core protocol support"
+ bool "MCTP core protocol support"
help
Management Component Transport Protocol (MCTP) is an in-system
protocol for communicating between management controllers and
@@ -11,3 +11,13 @@ menuconfig MCTP
This option enables core MCTP support. For communicating with other
devices, you'll want to enable a driver for a specific hardware
channel.
+
+config MCTP_TEST
+ bool "MCTP core tests" if !KUNIT_ALL_TESTS
+ depends on MCTP=y && KUNIT=y
+ default KUNIT_ALL_TESTS
+
+config MCTP_FLOWS
+ bool
+ depends on MCTP
+ select SKB_EXTENSIONS
diff --git a/net/mctp/Makefile b/net/mctp/Makefile
index 0171333384d7..6cd55233e685 100644
--- a/net/mctp/Makefile
+++ b/net/mctp/Makefile
@@ -1,3 +1,6 @@
# SPDX-License-Identifier: GPL-2.0
obj-$(CONFIG_MCTP) += mctp.o
mctp-objs := af_mctp.o device.o route.o neigh.o
+
+# tests
+obj-$(CONFIG_MCTP_TEST) += test/utils.o
diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c
index a9526ac29dff..d344b02a1cde 100644
--- a/net/mctp/af_mctp.c
+++ b/net/mctp/af_mctp.c
@@ -16,6 +16,9 @@
#include <net/mctpdevice.h>
#include <net/sock.h>
+#define CREATE_TRACE_POINTS
+#include <trace/events/mctp.h>
+
/* socket implementation */
static int mctp_release(struct socket *sock)
@@ -74,6 +77,7 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
int rc, addrlen = msg->msg_namelen;
struct sock *sk = sock->sk;
+ struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
struct mctp_skb_cb *cb;
struct mctp_route *rt;
struct sk_buff *skb;
@@ -97,11 +101,6 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
if (addr->smctp_network == MCTP_NET_ANY)
addr->smctp_network = mctp_default_net(sock_net(sk));
- rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
- addr->smctp_addr.s_addr);
- if (!rt)
- return -EHOSTUNREACH;
-
skb = sock_alloc_send_skb(sk, hlen + 1 + len,
msg->msg_flags & MSG_DONTWAIT, &rc);
if (!skb)
@@ -113,19 +112,45 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
*(u8 *)skb_put(skb, 1) = addr->smctp_type;
rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len);
- if (rc < 0) {
- kfree_skb(skb);
- return rc;
- }
+ if (rc < 0)
+ goto err_free;
/* set up cb */
cb = __mctp_cb(skb);
cb->net = addr->smctp_network;
+ /* direct addressing */
+ if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
+ DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
+ extaddr, msg->msg_name);
+
+ if (extaddr->smctp_halen > sizeof(cb->haddr)) {
+ rc = -EINVAL;
+ goto err_free;
+ }
+
+ cb->ifindex = extaddr->smctp_ifindex;
+ cb->halen = extaddr->smctp_halen;
+ memcpy(cb->haddr, extaddr->smctp_haddr, cb->halen);
+
+ rt = NULL;
+ } else {
+ rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
+ addr->smctp_addr.s_addr);
+ if (!rt) {
+ rc = -EHOSTUNREACH;
+ goto err_free;
+ }
+ }
+
rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
addr->smctp_tag);
return rc ? : len;
+
+err_free:
+ kfree_skb(skb);
+ return rc;
}
static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
@@ -133,6 +158,7 @@ static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
{
DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
struct sock *sk = sock->sk;
+ struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
struct sk_buff *skb;
size_t msglen;
u8 type;
@@ -178,6 +204,16 @@ static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
addr->smctp_tag = hdr->flags_seq_tag &
(MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
msg->msg_namelen = sizeof(*addr);
+
+ if (msk->addr_ext) {
+ DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, ae,
+ msg->msg_name);
+ msg->msg_namelen = sizeof(*ae);
+ ae->smctp_ifindex = cb->ifindex;
+ ae->smctp_halen = cb->halen;
+ memset(ae->smctp_haddr, 0x0, sizeof(ae->smctp_haddr));
+ memcpy(ae->smctp_haddr, cb->haddr, cb->halen);
+ }
}
rc = len;
@@ -193,12 +229,45 @@ out_free:
static int mctp_setsockopt(struct socket *sock, int level, int optname,
sockptr_t optval, unsigned int optlen)
{
- return -EINVAL;
+ struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
+ int val;
+
+ if (level != SOL_MCTP)
+ return -EINVAL;
+
+ if (optname == MCTP_OPT_ADDR_EXT) {
+ if (optlen != sizeof(int))
+ return -EINVAL;
+ if (copy_from_sockptr(&val, optval, sizeof(int)))
+ return -EFAULT;
+ msk->addr_ext = val;
+ return 0;
+ }
+
+ return -ENOPROTOOPT;
}
static int mctp_getsockopt(struct socket *sock, int level, int optname,
char __user *optval, int __user *optlen)
{
+ struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
+ int len, val;
+
+ if (level != SOL_MCTP)
+ return -EINVAL;
+
+ if (get_user(len, optlen))
+ return -EFAULT;
+
+ if (optname == MCTP_OPT_ADDR_EXT) {
+ if (len != sizeof(int))
+ return -EINVAL;
+ val = !!msk->addr_ext;
+ if (copy_to_user(optval, &val, len))
+ return -EFAULT;
+ return 0;
+ }
+
return -EINVAL;
}
@@ -223,16 +292,61 @@ static const struct proto_ops mctp_dgram_ops = {
.sendpage = sock_no_sendpage,
};
+static void mctp_sk_expire_keys(struct timer_list *timer)
+{
+ struct mctp_sock *msk = container_of(timer, struct mctp_sock,
+ key_expiry);
+ struct net *net = sock_net(&msk->sk);
+ unsigned long next_expiry, flags;
+ struct mctp_sk_key *key;
+ struct hlist_node *tmp;
+ bool next_expiry_valid = false;
+
+ spin_lock_irqsave(&net->mctp.keys_lock, flags);
+
+ hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
+ spin_lock(&key->lock);
+
+ if (!time_after_eq(key->expiry, jiffies)) {
+ trace_mctp_key_release(key, MCTP_TRACE_KEY_TIMEOUT);
+ key->valid = false;
+ hlist_del_rcu(&key->hlist);
+ hlist_del_rcu(&key->sklist);
+ spin_unlock(&key->lock);
+ mctp_key_unref(key);
+ continue;
+ }
+
+ if (next_expiry_valid) {
+ if (time_before(key->expiry, next_expiry))
+ next_expiry = key->expiry;
+ } else {
+ next_expiry = key->expiry;
+ next_expiry_valid = true;
+ }
+ spin_unlock(&key->lock);
+ }
+
+ spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
+
+ if (next_expiry_valid)
+ mod_timer(timer, next_expiry);
+}
+
static int mctp_sk_init(struct sock *sk)
{
struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
INIT_HLIST_HEAD(&msk->keys);
+ timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0);
return 0;
}
static void mctp_sk_close(struct sock *sk, long timeout)
{
+ struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
+
+ del_timer_sync(&msk->key_expiry);
sk_common_release(sk);
}
@@ -263,21 +377,23 @@ static void mctp_sk_unhash(struct sock *sk)
/* remove tag allocations */
spin_lock_irqsave(&net->mctp.keys_lock, flags);
hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
- hlist_del_rcu(&key->sklist);
- hlist_del_rcu(&key->hlist);
+ hlist_del(&key->sklist);
+ hlist_del(&key->hlist);
- spin_lock(&key->reasm_lock);
+ trace_mctp_key_release(key, MCTP_TRACE_KEY_CLOSED);
+
+ spin_lock(&key->lock);
if (key->reasm_head)
kfree_skb(key->reasm_head);
key->reasm_head = NULL;
key->reasm_dead = true;
- spin_unlock(&key->reasm_lock);
+ key->valid = false;
+ spin_unlock(&key->lock);
- kfree_rcu(key, rcu);
+ /* key is no longer on the lookup lists, unref */
+ mctp_key_unref(key);
}
spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
-
- synchronize_rcu();
}
static struct proto mctp_proto = {
@@ -385,7 +501,7 @@ static __exit void mctp_exit(void)
sock_unregister(PF_MCTP);
}
-module_init(mctp_init);
+subsys_initcall(mctp_init);
module_exit(mctp_exit);
MODULE_DESCRIPTION("MCTP core");
diff --git a/net/mctp/device.c b/net/mctp/device.c
index b9f38e765f61..8799ee77e7b7 100644
--- a/net/mctp/device.c
+++ b/net/mctp/device.c
@@ -35,14 +35,6 @@ struct mctp_dev *mctp_dev_get_rtnl(const struct net_device *dev)
return rtnl_dereference(dev->mctp_ptr);
}
-static void mctp_dev_destroy(struct mctp_dev *mdev)
-{
- struct net_device *dev = mdev->dev;
-
- dev_put(dev);
- kfree_rcu(mdev, rcu);
-}
-
static int mctp_fill_addrinfo(struct sk_buff *skb, struct netlink_callback *cb,
struct mctp_dev *mdev, mctp_eid_t eid)
{
@@ -255,6 +247,37 @@ static int mctp_rtm_deladdr(struct sk_buff *skb, struct nlmsghdr *nlh,
return 0;
}
+void mctp_dev_hold(struct mctp_dev *mdev)
+{
+ refcount_inc(&mdev->refs);
+}
+
+void mctp_dev_put(struct mctp_dev *mdev)
+{
+ if (refcount_dec_and_test(&mdev->refs)) {
+ dev_put(mdev->dev);
+ kfree_rcu(mdev, rcu);
+ }
+}
+
+void mctp_dev_release_key(struct mctp_dev *dev, struct mctp_sk_key *key)
+ __must_hold(&key->lock)
+{
+ if (!dev)
+ return;
+ if (dev->ops && dev->ops->release_flow)
+ dev->ops->release_flow(dev, key);
+ key->dev = NULL;
+ mctp_dev_put(dev);
+}
+
+void mctp_dev_set_key(struct mctp_dev *dev, struct mctp_sk_key *key)
+ __must_hold(&key->lock)
+{
+ mctp_dev_hold(dev);
+ key->dev = dev;
+}
+
static struct mctp_dev *mctp_add_dev(struct net_device *dev)
{
struct mctp_dev *mdev;
@@ -270,7 +293,9 @@ static struct mctp_dev *mctp_add_dev(struct net_device *dev)
mdev->net = mctp_default_net(dev_net(dev));
/* associate to net_device */
+ refcount_set(&mdev->refs, 1);
rcu_assign_pointer(dev->mctp_ptr, mdev);
+
dev_hold(dev);
mdev->dev = dev;
@@ -330,12 +355,26 @@ static int mctp_set_link_af(struct net_device *dev, const struct nlattr *attr,
return 0;
}
+/* Matches netdev types that should have MCTP handling */
+static bool mctp_known(struct net_device *dev)
+{
+ /* only register specific types (inc. NONE for TUN devices) */
+ return dev->type == ARPHRD_MCTP ||
+ dev->type == ARPHRD_LOOPBACK ||
+ dev->type == ARPHRD_NONE;
+}
+
static void mctp_unregister(struct net_device *dev)
{
struct mctp_dev *mdev;
mdev = mctp_dev_get_rtnl(dev);
-
+ if (mctp_known(dev) != (bool)mdev) {
+ // Sanity check, should match what was set in mctp_register
+ netdev_warn(dev, "%s: mdev pointer %d but type (%d) match is %d",
+ __func__, (bool)mdev, mctp_known(dev), dev->type);
+ return;
+ }
if (!mdev)
return;
@@ -345,7 +384,7 @@ static void mctp_unregister(struct net_device *dev)
mctp_neigh_remove_dev(mdev);
kfree(mdev->addrs);
- mctp_dev_destroy(mdev);
+ mctp_dev_put(mdev);
}
static int mctp_register(struct net_device *dev)
@@ -353,11 +392,17 @@ static int mctp_register(struct net_device *dev)
struct mctp_dev *mdev;
/* Already registered? */
- if (rtnl_dereference(dev->mctp_ptr))
+ mdev = rtnl_dereference(dev->mctp_ptr);
+
+ if (mdev) {
+ if (!mctp_known(dev))
+ netdev_warn(dev, "%s: mctp_dev set for unknown type %d",
+ __func__, dev->type);
return 0;
+ }
- /* only register specific types; MCTP-specific and loopback for now */
- if (dev->type != ARPHRD_MCTP && dev->type != ARPHRD_LOOPBACK)
+ /* only register specific types */
+ if (!mctp_known(dev))
return 0;
mdev = mctp_add_dev(dev);
@@ -387,6 +432,39 @@ static int mctp_dev_notify(struct notifier_block *this, unsigned long event,
return NOTIFY_OK;
}
+static int mctp_register_netdevice(struct net_device *dev,
+ const struct mctp_netdev_ops *ops)
+{
+ struct mctp_dev *mdev;
+
+ mdev = mctp_add_dev(dev);
+ if (IS_ERR(mdev))
+ return PTR_ERR(mdev);
+
+ mdev->ops = ops;
+
+ return register_netdevice(dev);
+}
+
+int mctp_register_netdev(struct net_device *dev,
+ const struct mctp_netdev_ops *ops)
+{
+ int rc;
+
+ rtnl_lock();
+ rc = mctp_register_netdevice(dev, ops);
+ rtnl_unlock();
+
+ return rc;
+}
+EXPORT_SYMBOL_GPL(mctp_register_netdev);
+
+void mctp_unregister_netdev(struct net_device *dev)
+{
+ unregister_netdev(dev);
+}
+EXPORT_SYMBOL_GPL(mctp_unregister_netdev);
+
static struct rtnl_af_ops mctp_af_ops = {
.family = AF_MCTP,
.fill_link_af = mctp_fill_link_af,
diff --git a/net/mctp/neigh.c b/net/mctp/neigh.c
index 90ed2f02d1fb..5cc042121493 100644
--- a/net/mctp/neigh.c
+++ b/net/mctp/neigh.c
@@ -47,7 +47,7 @@ static int mctp_neigh_add(struct mctp_dev *mdev, mctp_eid_t eid,
}
INIT_LIST_HEAD(&neigh->list);
neigh->dev = mdev;
- dev_hold(neigh->dev->dev);
+ mctp_dev_hold(neigh->dev);
neigh->eid = eid;
neigh->source = source;
memcpy(neigh->ha, lladdr, lladdr_len);
@@ -63,7 +63,7 @@ static void __mctp_neigh_free(struct rcu_head *rcu)
{
struct mctp_neigh *neigh = container_of(rcu, struct mctp_neigh, rcu);
- dev_put(neigh->dev->dev);
+ mctp_dev_put(neigh->dev);
kfree(neigh);
}
diff --git a/net/mctp/route.c b/net/mctp/route.c
index 5ca186d53cb0..46c44823edb7 100644
--- a/net/mctp/route.c
+++ b/net/mctp/route.c
@@ -11,6 +11,7 @@
*/
#include <linux/idr.h>
+#include <linux/kconfig.h>
#include <linux/mctp.h>
#include <linux/netdevice.h>
#include <linux/rtnetlink.h>
@@ -23,7 +24,12 @@
#include <net/netlink.h>
#include <net/sock.h>
+#include <trace/events/mctp.h>
+
static const unsigned int mctp_message_maxlen = 64 * 1024;
+static const unsigned long mctp_key_lifetime = 6 * CONFIG_HZ;
+
+static void mctp_flow_prepare_output(struct sk_buff *skb, struct mctp_dev *dev);
/* route output callbacks */
static int mctp_route_discard(struct mctp_route *route, struct sk_buff *skb)
@@ -83,25 +89,43 @@ static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
return true;
}
+/* returns a key (with key->lock held, and refcounted), or NULL if no such
+ * key exists.
+ */
static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
- mctp_eid_t peer)
+ mctp_eid_t peer,
+ unsigned long *irqflags)
+ __acquires(&key->lock)
{
struct mctp_sk_key *key, *ret;
+ unsigned long flags;
struct mctp_hdr *mh;
u8 tag;
- WARN_ON(!rcu_read_lock_held());
-
mh = mctp_hdr(skb);
tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
ret = NULL;
+ spin_lock_irqsave(&net->mctp.keys_lock, flags);
- hlist_for_each_entry_rcu(key, &net->mctp.keys, hlist) {
- if (mctp_key_match(key, mh->dest, peer, tag)) {
+ hlist_for_each_entry(key, &net->mctp.keys, hlist) {
+ if (!mctp_key_match(key, mh->dest, peer, tag))
+ continue;
+
+ spin_lock(&key->lock);
+ if (key->valid) {
+ refcount_inc(&key->refs);
ret = key;
break;
}
+ spin_unlock(&key->lock);
+ }
+
+ if (ret) {
+ spin_unlock(&net->mctp.keys_lock);
+ *irqflags = flags;
+ } else {
+ spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
}
return ret;
@@ -121,11 +145,30 @@ static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
key->local_addr = local;
key->tag = tag;
key->sk = &msk->sk;
- spin_lock_init(&key->reasm_lock);
+ key->valid = true;
+ spin_lock_init(&key->lock);
+ refcount_set(&key->refs, 1);
return key;
}
+void mctp_key_unref(struct mctp_sk_key *key)
+{
+ unsigned long flags;
+
+ if (!refcount_dec_and_test(&key->refs))
+ return;
+
+ /* even though no refs exist here, the lock allows us to stay
+ * consistent with the locking requirement of mctp_dev_release_key
+ */
+ spin_lock_irqsave(&key->lock, flags);
+ mctp_dev_release_key(key->dev, key);
+ spin_unlock_irqrestore(&key->lock, flags);
+
+ kfree(key);
+}
+
static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
{
struct net *net = sock_net(&msk->sk);
@@ -138,12 +181,20 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
hlist_for_each_entry(tmp, &net->mctp.keys, hlist) {
if (mctp_key_match(tmp, key->local_addr, key->peer_addr,
key->tag)) {
- rc = -EEXIST;
- break;
+ spin_lock(&tmp->lock);
+ if (tmp->valid)
+ rc = -EEXIST;
+ spin_unlock(&tmp->lock);
+ if (rc)
+ break;
}
}
if (!rc) {
+ refcount_inc(&key->refs);
+ key->expiry = jiffies + mctp_key_lifetime;
+ timer_reduce(&msk->key_expiry, key->expiry);
+
hlist_add_head(&key->hlist, &net->mctp.keys);
hlist_add_head(&key->sklist, &msk->keys);
}
@@ -153,30 +204,72 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
return rc;
}
-/* Must be called with key->reasm_lock, which it will release. Will schedule
- * the key for an RCU free.
+/* We're done with the key; unset valid and remove from lists. There may still
+ * be outstanding refs on the key though...
*/
static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net,
unsigned long flags)
- __releases(&key->reasm_lock)
+ __releases(&key->lock)
{
struct sk_buff *skb;
skb = key->reasm_head;
key->reasm_head = NULL;
key->reasm_dead = true;
- spin_unlock_irqrestore(&key->reasm_lock, flags);
+ key->valid = false;
+ mctp_dev_release_key(key->dev, key);
+ spin_unlock_irqrestore(&key->lock, flags);
spin_lock_irqsave(&net->mctp.keys_lock, flags);
- hlist_del_rcu(&key->hlist);
- hlist_del_rcu(&key->sklist);
+ hlist_del(&key->hlist);
+ hlist_del(&key->sklist);
spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
- kfree_rcu(key, rcu);
+
+ /* one unref for the lists */
+ mctp_key_unref(key);
+
+ /* and one for the local reference */
+ mctp_key_unref(key);
if (skb)
kfree_skb(skb);
+
}
+#ifdef CONFIG_MCTP_FLOWS
+static void mctp_skb_set_flow(struct sk_buff *skb, struct mctp_sk_key *key)
+{
+ struct mctp_flow *flow;
+
+ flow = skb_ext_add(skb, SKB_EXT_MCTP);
+ if (!flow)
+ return;
+
+ refcount_inc(&key->refs);
+ flow->key = key;
+}
+
+static void mctp_flow_prepare_output(struct sk_buff *skb, struct mctp_dev *dev)
+{
+ struct mctp_sk_key *key;
+ struct mctp_flow *flow;
+
+ flow = skb_ext_find(skb, SKB_EXT_MCTP);
+ if (!flow)
+ return;
+
+ key = flow->key;
+
+ if (WARN_ON(key->dev && key->dev != dev))
+ return;
+
+ mctp_dev_set_key(dev, key);
+}
+#else
+static void mctp_skb_set_flow(struct sk_buff *skb, struct mctp_sk_key *key) {}
+static void mctp_flow_prepare_output(struct sk_buff *skb, struct mctp_dev *dev) {}
+#endif
+
static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb)
{
struct mctp_hdr *hdr = mctp_hdr(skb);
@@ -248,8 +341,10 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
rcu_read_lock();
- /* lookup socket / reasm context, exactly matching (src,dest,tag) */
- key = mctp_lookup_key(net, skb, mh->src);
+ /* lookup socket / reasm context, exactly matching (src,dest,tag).
+ * we hold a ref on the key, and key->lock held.
+ */
+ key = mctp_lookup_key(net, skb, mh->src, &f);
if (flags & MCTP_HDR_FLAG_SOM) {
if (key) {
@@ -260,10 +355,12 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
* key for reassembly - we'll create a more specific
* one for future packets if required (ie, !EOM).
*/
- key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY);
+ key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f);
if (key) {
msk = container_of(key->sk,
struct mctp_sock, sk);
+ spin_unlock_irqrestore(&key->lock, f);
+ mctp_key_unref(key);
key = NULL;
}
}
@@ -282,11 +379,13 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
if (flags & MCTP_HDR_FLAG_EOM) {
sock_queue_rcv_skb(&msk->sk, skb);
if (key) {
- spin_lock_irqsave(&key->reasm_lock, f);
/* we've hit a pending reassembly; not much we
* can do but drop it
*/
+ trace_mctp_key_release(key,
+ MCTP_TRACE_KEY_REPLIED);
__mctp_key_unlock_drop(key, net, f);
+ key = NULL;
}
rc = 0;
goto out_unlock;
@@ -303,7 +402,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
goto out_unlock;
}
- /* we can queue without the reasm lock here, as the
+ /* we can queue without the key lock here, as the
* key isn't observable yet
*/
mctp_frag_queue(key, skb);
@@ -318,17 +417,22 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
if (rc)
kfree(key);
- } else {
- /* existing key: start reassembly */
- spin_lock_irqsave(&key->reasm_lock, f);
+ trace_mctp_key_acquire(key);
+
+ /* we don't need to release key->lock on exit */
+ mctp_key_unref(key);
+ key = NULL;
+ } else {
if (key->reasm_head || key->reasm_dead) {
/* duplicate start? drop everything */
+ trace_mctp_key_release(key,
+ MCTP_TRACE_KEY_INVALIDATED);
__mctp_key_unlock_drop(key, net, f);
rc = -EEXIST;
+ key = NULL;
} else {
rc = mctp_frag_queue(key, skb);
- spin_unlock_irqrestore(&key->reasm_lock, f);
}
}
@@ -337,8 +441,6 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
* using the message-specific key
*/
- spin_lock_irqsave(&key->reasm_lock, f);
-
/* we need to be continuing an existing reassembly... */
if (!key->reasm_head)
rc = -EINVAL;
@@ -351,9 +453,9 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
if (!rc && flags & MCTP_HDR_FLAG_EOM) {
sock_queue_rcv_skb(key->sk, key->reasm_head);
key->reasm_head = NULL;
+ trace_mctp_key_release(key, MCTP_TRACE_KEY_REPLIED);
__mctp_key_unlock_drop(key, net, f);
- } else {
- spin_unlock_irqrestore(&key->reasm_lock, f);
+ key = NULL;
}
} else {
@@ -363,6 +465,10 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
out_unlock:
rcu_read_unlock();
+ if (key) {
+ spin_unlock_irqrestore(&key->lock, f);
+ mctp_key_unref(key);
+ }
out:
if (rc)
kfree_skb(skb);
@@ -376,6 +482,7 @@ static unsigned int mctp_route_mtu(struct mctp_route *rt)
static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb)
{
+ struct mctp_skb_cb *cb = mctp_cb(skb);
struct mctp_hdr *hdr = mctp_hdr(skb);
char daddr_buf[MAX_ADDR_LEN];
char *daddr = NULL;
@@ -390,9 +497,14 @@ static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb)
return -EMSGSIZE;
}
- /* If lookup fails let the device handle daddr==NULL */
- if (mctp_neigh_lookup(route->dev, hdr->dest, daddr_buf) == 0)
- daddr = daddr_buf;
+ if (cb->ifindex) {
+ /* direct route; use the hwaddr we stashed in sendmsg */
+ daddr = cb->haddr;
+ } else {
+ /* If lookup fails let the device handle daddr==NULL */
+ if (mctp_neigh_lookup(route->dev, hdr->dest, daddr_buf) == 0)
+ daddr = daddr_buf;
+ }
rc = dev_hard_header(skb, skb->dev, ntohs(skb->protocol),
daddr, skb->dev->dev_addr, skb->len);
@@ -401,6 +513,8 @@ static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb)
return -EHOSTUNREACH;
}
+ mctp_flow_prepare_output(skb, route->dev);
+
rc = dev_queue_xmit(skb);
if (rc)
rc = net_xmit_errno(rc);
@@ -412,7 +526,7 @@ static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb)
static void mctp_route_release(struct mctp_route *rt)
{
if (refcount_dec_and_test(&rt->refs)) {
- dev_put(rt->dev->dev);
+ mctp_dev_put(rt->dev);
kfree_rcu(rt, rcu);
}
}
@@ -454,30 +568,38 @@ static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
lockdep_assert_held(&mns->keys_lock);
+ key->expiry = jiffies + mctp_key_lifetime;
+ timer_reduce(&msk->key_expiry, key->expiry);
+
/* we hold the net->key_lock here, allowing updates to both
* then net and sk
*/
hlist_add_head_rcu(&key->hlist, &mns->keys);
hlist_add_head_rcu(&key->sklist, &msk->keys);
+ refcount_inc(&key->refs);
}
/* Allocate a locally-owned tag value for (saddr, daddr), and reserve
* it for the socket msk
*/
-static int mctp_alloc_local_tag(struct mctp_sock *msk,
- mctp_eid_t saddr, mctp_eid_t daddr, u8 *tagp)
+static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
+ mctp_eid_t saddr,
+ mctp_eid_t daddr, u8 *tagp)
{
struct net *net = sock_net(&msk->sk);
struct netns_mctp *mns = &net->mctp;
struct mctp_sk_key *key, *tmp;
unsigned long flags;
- int rc = -EAGAIN;
u8 tagbits;
+ /* for NULL destination EIDs, we may get a response from any peer */
+ if (daddr == MCTP_ADDR_NULL)
+ daddr = MCTP_ADDR_ANY;
+
/* be optimistic, alloc now */
key = mctp_key_alloc(msk, saddr, daddr, 0, GFP_KERNEL);
if (!key)
- return -ENOMEM;
+ return ERR_PTR(-ENOMEM);
/* 8 possible tag values */
tagbits = 0xff;
@@ -488,14 +610,26 @@ static int mctp_alloc_local_tag(struct mctp_sock *msk,
* tags. If we find a conflict, clear that bit from tagbits
*/
hlist_for_each_entry(tmp, &mns->keys, hlist) {
+ /* We can check the lookup fields (*_addr, tag) without the
+ * lock held, they don't change over the lifetime of the key.
+ */
+
/* if we don't own the tag, it can't conflict */
if (tmp->tag & MCTP_HDR_FLAG_TO)
continue;
- if ((tmp->peer_addr == daddr ||
- tmp->peer_addr == MCTP_ADDR_ANY) &&
- tmp->local_addr == saddr)
+ if (!((tmp->peer_addr == daddr ||
+ tmp->peer_addr == MCTP_ADDR_ANY) &&
+ tmp->local_addr == saddr))
+ continue;
+
+ spin_lock(&tmp->lock);
+ /* key must still be valid. If we find a match, clear the
+ * potential tag value
+ */
+ if (tmp->valid)
tagbits &= ~(1 << tmp->tag);
+ spin_unlock(&tmp->lock);
if (!tagbits)
break;
@@ -504,16 +638,19 @@ static int mctp_alloc_local_tag(struct mctp_sock *msk,
if (tagbits) {
key->tag = __ffs(tagbits);
mctp_reserve_tag(net, key, msk);
+ trace_mctp_key_acquire(key);
+
*tagp = key->tag;
- rc = 0;
}
spin_unlock_irqrestore(&mns->keys_lock, flags);
- if (!tagbits)
+ if (!tagbits) {
kfree(key);
+ return ERR_PTR(-EBUSY);
+ }
- return rc;
+ return key;
}
/* routing lookups */
@@ -552,14 +689,18 @@ struct mctp_route *mctp_route_lookup(struct net *net, unsigned int dnet,
return rt;
}
-/* sends a skb to rt and releases the route. */
-int mctp_do_route(struct mctp_route *rt, struct sk_buff *skb)
+static struct mctp_route *mctp_route_lookup_null(struct net *net,
+ struct net_device *dev)
{
- int rc;
+ struct mctp_route *rt;
- rc = rt->output(rt, skb);
- mctp_route_release(rt);
- return rc;
+ list_for_each_entry_rcu(rt, &net->mctp.routes, list) {
+ if (rt->dev->dev == dev && rt->type == RTN_LOCAL &&
+ refcount_inc_not_zero(&rt->refs))
+ return rt;
+ }
+
+ return NULL;
}
static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
@@ -628,7 +769,7 @@ static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
/* copy message payload */
skb_copy_bits(skb, pos, skb_transport_header(skb2), size);
- /* do route, but don't drop the rt reference */
+ /* do route */
rc = rt->output(rt, skb2);
if (rc)
break;
@@ -637,7 +778,6 @@ static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
pos += size;
}
- mctp_route_release(rt);
consume_skb(skb);
return rc;
}
@@ -647,15 +787,52 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
{
struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
struct mctp_skb_cb *cb = mctp_cb(skb);
+ struct mctp_route tmp_rt;
+ struct mctp_sk_key *key;
+ struct net_device *dev;
struct mctp_hdr *hdr;
unsigned long flags;
unsigned int mtu;
mctp_eid_t saddr;
+ bool ext_rt;
int rc;
u8 tag;
- if (WARN_ON(!rt->dev))
+ rc = -ENODEV;
+
+ if (rt) {
+ ext_rt = false;
+ dev = NULL;
+
+ if (WARN_ON(!rt->dev))
+ goto out_release;
+
+ } else if (cb->ifindex) {
+ ext_rt = true;
+ rt = &tmp_rt;
+
+ rcu_read_lock();
+ dev = dev_get_by_index_rcu(sock_net(sk), cb->ifindex);
+ if (!dev) {
+ rcu_read_unlock();
+ return rc;
+ }
+
+ rt->dev = __mctp_dev_get(dev);
+ rcu_read_unlock();
+
+ if (!rt->dev)
+ goto out_release;
+
+ /* establish temporary route - we set up enough to keep
+ * mctp_route_output happy
+ */
+ rt->output = mctp_route_output;
+ rt->mtu = 0;
+
+ } else {
return -EINVAL;
+ }
spin_lock_irqsave(&rt->dev->addrs_lock, flags);
if (rt->dev->num_addrs == 0) {
@@ -668,18 +845,23 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
spin_unlock_irqrestore(&rt->dev->addrs_lock, flags);
if (rc)
- return rc;
+ goto out_release;
if (req_tag & MCTP_HDR_FLAG_TO) {
- rc = mctp_alloc_local_tag(msk, saddr, daddr, &tag);
- if (rc)
- return rc;
+ key = mctp_alloc_local_tag(msk, saddr, daddr, &tag);
+ if (IS_ERR(key)) {
+ rc = PTR_ERR(key);
+ goto out_release;
+ }
+ mctp_skb_set_flow(skb, key);
+ /* done with the key in this scope */
+ mctp_key_unref(key);
tag |= MCTP_HDR_FLAG_TO;
} else {
+ key = NULL;
tag = req_tag;
}
-
skb->protocol = htons(ETH_P_MCTP);
skb->priority = 0;
skb_reset_transport_header(skb);
@@ -699,12 +881,22 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
mtu = mctp_route_mtu(rt);
if (skb->len + sizeof(struct mctp_hdr) <= mtu) {
- hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM |
- tag;
- return mctp_do_route(rt, skb);
+ hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM |
+ MCTP_HDR_FLAG_EOM | tag;
+ rc = rt->output(rt, skb);
} else {
- return mctp_do_fragment_route(rt, skb, mtu, tag);
+ rc = mctp_do_fragment_route(rt, skb, mtu, tag);
}
+
+out_release:
+ if (!ext_rt)
+ mctp_route_release(rt);
+
+ if (dev)
+ dev_put(dev);
+
+ return rc;
+
}
/* route management */
@@ -741,7 +933,7 @@ static int mctp_route_add(struct mctp_dev *mdev, mctp_eid_t daddr_start,
rt->max = daddr_start + daddr_extent;
rt->mtu = mtu;
rt->dev = mdev;
- dev_hold(rt->dev->dev);
+ mctp_dev_hold(rt->dev);
rt->type = type;
rt->output = rtfn;
@@ -821,13 +1013,18 @@ static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev,
struct net_device *orig_dev)
{
struct net *net = dev_net(dev);
+ struct mctp_dev *mdev;
struct mctp_skb_cb *cb;
struct mctp_route *rt;
struct mctp_hdr *mh;
- /* basic non-data sanity checks */
- if (dev->type != ARPHRD_MCTP)
+ rcu_read_lock();
+ mdev = __mctp_dev_get(dev);
+ rcu_read_unlock();
+ if (!mdev) {
+ /* basic non-data sanity checks */
goto err_drop;
+ }
if (!pskb_may_pull(skb, sizeof(struct mctp_hdr)))
goto err_drop;
@@ -840,16 +1037,27 @@ static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev,
if (mh->ver < MCTP_VER_MIN || mh->ver > MCTP_VER_MAX)
goto err_drop;
- cb = __mctp_cb(skb);
- rcu_read_lock();
- cb->net = READ_ONCE(__mctp_dev_get(dev)->net);
- rcu_read_unlock();
+ /* MCTP drivers must populate halen/haddr */
+ if (dev->type == ARPHRD_MCTP) {
+ cb = mctp_cb(skb);
+ } else {
+ cb = __mctp_cb(skb);
+ cb->halen = 0;
+ }
+ cb->net = READ_ONCE(mdev->net);
+ cb->ifindex = dev->ifindex;
rt = mctp_route_lookup(net, cb->net, mh->dest);
+
+ /* NULL EID, but addressed to our physical address */
+ if (!rt && mh->dest == MCTP_ADDR_NULL && skb->pkt_type == PACKET_HOST)
+ rt = mctp_route_lookup_null(net, dev);
+
if (!rt)
goto err_drop;
- mctp_do_route(rt, skb);
+ rt->output(rt, skb);
+ mctp_route_release(rt);
return NET_RX_SUCCESS;
@@ -926,10 +1134,15 @@ static int mctp_route_nlparse(struct sk_buff *skb, struct nlmsghdr *nlh,
return 0;
}
+static const struct nla_policy rta_metrics_policy[RTAX_MAX + 1] = {
+ [RTAX_MTU] = { .type = NLA_U32 },
+};
+
static int mctp_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
struct netlink_ext_ack *extack)
{
struct nlattr *tb[RTA_MAX + 1];
+ struct nlattr *tbx[RTAX_MAX + 1];
mctp_eid_t daddr_start;
struct mctp_dev *mdev;
struct rtmsg *rtm;
@@ -946,8 +1159,15 @@ static int mctp_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
return -EINVAL;
}
- /* TODO: parse mtu from nlparse */
mtu = 0;
+ if (tb[RTA_METRICS]) {
+ rc = nla_parse_nested(tbx, RTAX_MAX, tb[RTA_METRICS],
+ rta_metrics_policy, NULL);
+ if (rc < 0)
+ return rc;
+ if (tbx[RTAX_MTU])
+ mtu = nla_get_u32(tbx[RTAX_MTU]);
+ }
if (rtm->rtm_type != RTN_UNICAST)
return -EINVAL;
@@ -1116,3 +1336,7 @@ void __exit mctp_routes_exit(void)
rtnl_unregister(PF_MCTP, RTM_GETROUTE);
dev_remove_pack(&mctp_packet_type);
}
+
+#if IS_ENABLED(CONFIG_MCTP_TEST)
+#include "test/route-test.c"
+#endif
diff --git a/net/mctp/test/route-test.c b/net/mctp/test/route-test.c
new file mode 100644
index 000000000000..36fac3daf86a
--- /dev/null
+++ b/net/mctp/test/route-test.c
@@ -0,0 +1,544 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#include <kunit/test.h>
+
+#include "utils.h"
+
+struct mctp_test_route {
+ struct mctp_route rt;
+ struct sk_buff_head pkts;
+};
+
+static int mctp_test_route_output(struct mctp_route *rt, struct sk_buff *skb)
+{
+ struct mctp_test_route *test_rt = container_of(rt, struct mctp_test_route, rt);
+
+ skb_queue_tail(&test_rt->pkts, skb);
+
+ return 0;
+}
+
+/* local version of mctp_route_alloc() */
+static struct mctp_test_route *mctp_route_test_alloc(void)
+{
+ struct mctp_test_route *rt;
+
+ rt = kzalloc(sizeof(*rt), GFP_KERNEL);
+ if (!rt)
+ return NULL;
+
+ INIT_LIST_HEAD(&rt->rt.list);
+ refcount_set(&rt->rt.refs, 1);
+ rt->rt.output = mctp_test_route_output;
+
+ skb_queue_head_init(&rt->pkts);
+
+ return rt;
+}
+
+static struct mctp_test_route *mctp_test_create_route(struct net *net,
+ struct mctp_dev *dev,
+ mctp_eid_t eid,
+ unsigned int mtu)
+{
+ struct mctp_test_route *rt;
+
+ rt = mctp_route_test_alloc();
+ if (!rt)
+ return NULL;
+
+ rt->rt.min = eid;
+ rt->rt.max = eid;
+ rt->rt.mtu = mtu;
+ rt->rt.type = RTN_UNSPEC;
+ if (dev)
+ mctp_dev_hold(dev);
+ rt->rt.dev = dev;
+
+ list_add_rcu(&rt->rt.list, &net->mctp.routes);
+
+ return rt;
+}
+
+static void mctp_test_route_destroy(struct kunit *test,
+ struct mctp_test_route *rt)
+{
+ unsigned int refs;
+
+ rtnl_lock();
+ list_del_rcu(&rt->rt.list);
+ rtnl_unlock();
+
+ skb_queue_purge(&rt->pkts);
+ if (rt->rt.dev)
+ mctp_dev_put(rt->rt.dev);
+
+ refs = refcount_read(&rt->rt.refs);
+ KUNIT_ASSERT_EQ_MSG(test, refs, 1, "route ref imbalance");
+
+ kfree_rcu(&rt->rt, rcu);
+}
+
+static struct sk_buff *mctp_test_create_skb(const struct mctp_hdr *hdr,
+ unsigned int data_len)
+{
+ size_t hdr_len = sizeof(*hdr);
+ struct sk_buff *skb;
+ unsigned int i;
+ u8 *buf;
+
+ skb = alloc_skb(hdr_len + data_len, GFP_KERNEL);
+ if (!skb)
+ return NULL;
+
+ memcpy(skb_put(skb, hdr_len), hdr, hdr_len);
+
+ buf = skb_put(skb, data_len);
+ for (i = 0; i < data_len; i++)
+ buf[i] = i & 0xff;
+
+ return skb;
+}
+
+static struct sk_buff *__mctp_test_create_skb_data(const struct mctp_hdr *hdr,
+ const void *data,
+ size_t data_len)
+{
+ size_t hdr_len = sizeof(*hdr);
+ struct sk_buff *skb;
+
+ skb = alloc_skb(hdr_len + data_len, GFP_KERNEL);
+ if (!skb)
+ return NULL;
+
+ memcpy(skb_put(skb, hdr_len), hdr, hdr_len);
+ memcpy(skb_put(skb, data_len), data, data_len);
+
+ return skb;
+}
+
+#define mctp_test_create_skb_data(h, d) \
+ __mctp_test_create_skb_data(h, d, sizeof(*d))
+
+struct mctp_frag_test {
+ unsigned int mtu;
+ unsigned int msgsize;
+ unsigned int n_frags;
+};
+
+static void mctp_test_fragment(struct kunit *test)
+{
+ const struct mctp_frag_test *params;
+ int rc, i, n, mtu, msgsize;
+ struct mctp_test_route *rt;
+ struct sk_buff *skb;
+ struct mctp_hdr hdr;
+ u8 seq;
+
+ params = test->param_value;
+ mtu = params->mtu;
+ msgsize = params->msgsize;
+
+ hdr.ver = 1;
+ hdr.src = 8;
+ hdr.dest = 10;
+ hdr.flags_seq_tag = MCTP_HDR_FLAG_TO;
+
+ skb = mctp_test_create_skb(&hdr, msgsize);
+ KUNIT_ASSERT_TRUE(test, skb);
+
+ rt = mctp_test_create_route(&init_net, NULL, 10, mtu);
+ KUNIT_ASSERT_TRUE(test, rt);
+
+ /* The refcount would usually be incremented as part of a route lookup,
+ * but we're setting the route directly here.
+ */
+ refcount_inc(&rt->rt.refs);
+
+ rc = mctp_do_fragment_route(&rt->rt, skb, mtu, MCTP_TAG_OWNER);
+ KUNIT_EXPECT_FALSE(test, rc);
+
+ n = rt->pkts.qlen;
+
+ KUNIT_EXPECT_EQ(test, n, params->n_frags);
+
+ for (i = 0;; i++) {
+ struct mctp_hdr *hdr2;
+ struct sk_buff *skb2;
+ u8 tag_mask, seq2;
+ bool first, last;
+
+ first = i == 0;
+ last = i == (n - 1);
+
+ skb2 = skb_dequeue(&rt->pkts);
+
+ if (!skb2)
+ break;
+
+ hdr2 = mctp_hdr(skb2);
+
+ tag_mask = MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO;
+
+ KUNIT_EXPECT_EQ(test, hdr2->ver, hdr.ver);
+ KUNIT_EXPECT_EQ(test, hdr2->src, hdr.src);
+ KUNIT_EXPECT_EQ(test, hdr2->dest, hdr.dest);
+ KUNIT_EXPECT_EQ(test, hdr2->flags_seq_tag & tag_mask,
+ hdr.flags_seq_tag & tag_mask);
+
+ KUNIT_EXPECT_EQ(test,
+ !!(hdr2->flags_seq_tag & MCTP_HDR_FLAG_SOM), first);
+ KUNIT_EXPECT_EQ(test,
+ !!(hdr2->flags_seq_tag & MCTP_HDR_FLAG_EOM), last);
+
+ seq2 = (hdr2->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT) &
+ MCTP_HDR_SEQ_MASK;
+
+ if (first) {
+ seq = seq2;
+ } else {
+ seq++;
+ KUNIT_EXPECT_EQ(test, seq2, seq & MCTP_HDR_SEQ_MASK);
+ }
+
+ if (!last)
+ KUNIT_EXPECT_EQ(test, skb2->len, mtu);
+ else
+ KUNIT_EXPECT_LE(test, skb2->len, mtu);
+
+ kfree_skb(skb2);
+ }
+
+ mctp_test_route_destroy(test, rt);
+}
+
+static const struct mctp_frag_test mctp_frag_tests[] = {
+ {.mtu = 68, .msgsize = 63, .n_frags = 1},
+ {.mtu = 68, .msgsize = 64, .n_frags = 1},
+ {.mtu = 68, .msgsize = 65, .n_frags = 2},
+ {.mtu = 68, .msgsize = 66, .n_frags = 2},
+ {.mtu = 68, .msgsize = 127, .n_frags = 2},
+ {.mtu = 68, .msgsize = 128, .n_frags = 2},
+ {.mtu = 68, .msgsize = 129, .n_frags = 3},
+ {.mtu = 68, .msgsize = 130, .n_frags = 3},
+};
+
+static void mctp_frag_test_to_desc(const struct mctp_frag_test *t, char *desc)
+{
+ sprintf(desc, "mtu %d len %d -> %d frags",
+ t->msgsize, t->mtu, t->n_frags);
+}
+
+KUNIT_ARRAY_PARAM(mctp_frag, mctp_frag_tests, mctp_frag_test_to_desc);
+
+struct mctp_rx_input_test {
+ struct mctp_hdr hdr;
+ bool input;
+};
+
+static void mctp_test_rx_input(struct kunit *test)
+{
+ const struct mctp_rx_input_test *params;
+ struct mctp_test_route *rt;
+ struct mctp_test_dev *dev;
+ struct sk_buff *skb;
+
+ params = test->param_value;
+
+ dev = mctp_test_create_dev();
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
+
+ rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt);
+
+ skb = mctp_test_create_skb(&params->hdr, 1);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
+
+ __mctp_cb(skb);
+
+ mctp_pkttype_receive(skb, dev->ndev, &mctp_packet_type, NULL);
+
+ KUNIT_EXPECT_EQ(test, !!rt->pkts.qlen, params->input);
+
+ mctp_test_route_destroy(test, rt);
+ mctp_test_destroy_dev(dev);
+}
+
+#define RX_HDR(_ver, _src, _dest, _fst) \
+ { .ver = _ver, .src = _src, .dest = _dest, .flags_seq_tag = _fst }
+
+/* we have a route for EID 8 only */
+static const struct mctp_rx_input_test mctp_rx_input_tests[] = {
+ { .hdr = RX_HDR(1, 10, 8, 0), .input = true },
+ { .hdr = RX_HDR(1, 10, 9, 0), .input = false }, /* no input route */
+ { .hdr = RX_HDR(2, 10, 8, 0), .input = false }, /* invalid version */
+};
+
+static void mctp_rx_input_test_to_desc(const struct mctp_rx_input_test *t,
+ char *desc)
+{
+ sprintf(desc, "{%x,%x,%x,%x}", t->hdr.ver, t->hdr.src, t->hdr.dest,
+ t->hdr.flags_seq_tag);
+}
+
+KUNIT_ARRAY_PARAM(mctp_rx_input, mctp_rx_input_tests,
+ mctp_rx_input_test_to_desc);
+
+/* set up a local dev, route on EID 8, and a socket listening on type 0 */
+static void __mctp_route_test_init(struct kunit *test,
+ struct mctp_test_dev **devp,
+ struct mctp_test_route **rtp,
+ struct socket **sockp)
+{
+ struct sockaddr_mctp addr;
+ struct mctp_test_route *rt;
+ struct mctp_test_dev *dev;
+ struct socket *sock;
+ int rc;
+
+ dev = mctp_test_create_dev();
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
+
+ rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt);
+
+ rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock);
+ KUNIT_ASSERT_EQ(test, rc, 0);
+
+ addr.smctp_family = AF_MCTP;
+ addr.smctp_network = MCTP_NET_ANY;
+ addr.smctp_addr.s_addr = 8;
+ addr.smctp_type = 0;
+ rc = kernel_bind(sock, (struct sockaddr *)&addr, sizeof(addr));
+ KUNIT_ASSERT_EQ(test, rc, 0);
+
+ *rtp = rt;
+ *devp = dev;
+ *sockp = sock;
+}
+
+static void __mctp_route_test_fini(struct kunit *test,
+ struct mctp_test_dev *dev,
+ struct mctp_test_route *rt,
+ struct socket *sock)
+{
+ sock_release(sock);
+ mctp_test_route_destroy(test, rt);
+ mctp_test_destroy_dev(dev);
+}
+
+struct mctp_route_input_sk_test {
+ struct mctp_hdr hdr;
+ u8 type;
+ bool deliver;
+};
+
+static void mctp_test_route_input_sk(struct kunit *test)
+{
+ const struct mctp_route_input_sk_test *params;
+ struct sk_buff *skb, *skb2;
+ struct mctp_test_route *rt;
+ struct mctp_test_dev *dev;
+ struct socket *sock;
+ int rc;
+
+ params = test->param_value;
+
+ __mctp_route_test_init(test, &dev, &rt, &sock);
+
+ skb = mctp_test_create_skb_data(&params->hdr, &params->type);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
+
+ skb->dev = dev->ndev;
+ __mctp_cb(skb);
+
+ rc = mctp_route_input(&rt->rt, skb);
+
+ if (params->deliver) {
+ KUNIT_EXPECT_EQ(test, rc, 0);
+
+ skb2 = skb_recv_datagram(sock->sk, 0, 1, &rc);
+ KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2);
+ KUNIT_EXPECT_EQ(test, skb->len, 1);
+
+ skb_free_datagram(sock->sk, skb2);
+
+ } else {
+ KUNIT_EXPECT_NE(test, rc, 0);
+ skb2 = skb_recv_datagram(sock->sk, 0, 1, &rc);
+ KUNIT_EXPECT_PTR_EQ(test, skb2, NULL);
+ }
+
+ __mctp_route_test_fini(test, dev, rt, sock);
+}
+
+#define FL_S (MCTP_HDR_FLAG_SOM)
+#define FL_E (MCTP_HDR_FLAG_EOM)
+#define FL_T (MCTP_HDR_FLAG_TO)
+
+static const struct mctp_route_input_sk_test mctp_route_input_sk_tests[] = {
+ { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_T), .type = 0, .deliver = true },
+ { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_T), .type = 1, .deliver = false },
+ { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E), .type = 0, .deliver = false },
+ { .hdr = RX_HDR(1, 10, 8, FL_E | FL_T), .type = 0, .deliver = false },
+ { .hdr = RX_HDR(1, 10, 8, FL_T), .type = 0, .deliver = false },
+ { .hdr = RX_HDR(1, 10, 8, 0), .type = 0, .deliver = false },
+};
+
+static void mctp_route_input_sk_to_desc(const struct mctp_route_input_sk_test *t,
+ char *desc)
+{
+ sprintf(desc, "{%x,%x,%x,%x} type %d", t->hdr.ver, t->hdr.src,
+ t->hdr.dest, t->hdr.flags_seq_tag, t->type);
+}
+
+KUNIT_ARRAY_PARAM(mctp_route_input_sk, mctp_route_input_sk_tests,
+ mctp_route_input_sk_to_desc);
+
+struct mctp_route_input_sk_reasm_test {
+ const char *name;
+ struct mctp_hdr hdrs[4];
+ int n_hdrs;
+ int rx_len;
+};
+
+static void mctp_test_route_input_sk_reasm(struct kunit *test)
+{
+ const struct mctp_route_input_sk_reasm_test *params;
+ struct sk_buff *skb, *skb2;
+ struct mctp_test_route *rt;
+ struct mctp_test_dev *dev;
+ struct socket *sock;
+ int i, rc;
+ u8 c;
+
+ params = test->param_value;
+
+ __mctp_route_test_init(test, &dev, &rt, &sock);
+
+ for (i = 0; i < params->n_hdrs; i++) {
+ c = i;
+ skb = mctp_test_create_skb_data(&params->hdrs[i], &c);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
+
+ skb->dev = dev->ndev;
+ __mctp_cb(skb);
+
+ rc = mctp_route_input(&rt->rt, skb);
+ }
+
+ skb2 = skb_recv_datagram(sock->sk, 0, 1, &rc);
+
+ if (params->rx_len) {
+ KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2);
+ KUNIT_EXPECT_EQ(test, skb2->len, params->rx_len);
+ skb_free_datagram(sock->sk, skb2);
+
+ } else {
+ KUNIT_EXPECT_PTR_EQ(test, skb2, NULL);
+ }
+
+ __mctp_route_test_fini(test, dev, rt, sock);
+}
+
+#define RX_FRAG(f, s) RX_HDR(1, 10, 8, FL_T | (f) | ((s) << MCTP_HDR_SEQ_SHIFT))
+
+static const struct mctp_route_input_sk_reasm_test mctp_route_input_sk_reasm_tests[] = {
+ {
+ .name = "single packet",
+ .hdrs = {
+ RX_FRAG(FL_S | FL_E, 0),
+ },
+ .n_hdrs = 1,
+ .rx_len = 1,
+ },
+ {
+ .name = "single packet, offset seq",
+ .hdrs = {
+ RX_FRAG(FL_S | FL_E, 1),
+ },
+ .n_hdrs = 1,
+ .rx_len = 1,
+ },
+ {
+ .name = "start & end packets",
+ .hdrs = {
+ RX_FRAG(FL_S, 0),
+ RX_FRAG(FL_E, 1),
+ },
+ .n_hdrs = 2,
+ .rx_len = 2,
+ },
+ {
+ .name = "start & end packets, offset seq",
+ .hdrs = {
+ RX_FRAG(FL_S, 1),
+ RX_FRAG(FL_E, 2),
+ },
+ .n_hdrs = 2,
+ .rx_len = 2,
+ },
+ {
+ .name = "start & end packets, out of order",
+ .hdrs = {
+ RX_FRAG(FL_E, 1),
+ RX_FRAG(FL_S, 0),
+ },
+ .n_hdrs = 2,
+ .rx_len = 0,
+ },
+ {
+ .name = "start, middle & end packets",
+ .hdrs = {
+ RX_FRAG(FL_S, 0),
+ RX_FRAG(0, 1),
+ RX_FRAG(FL_E, 2),
+ },
+ .n_hdrs = 3,
+ .rx_len = 3,
+ },
+ {
+ .name = "missing seq",
+ .hdrs = {
+ RX_FRAG(FL_S, 0),
+ RX_FRAG(FL_E, 2),
+ },
+ .n_hdrs = 2,
+ .rx_len = 0,
+ },
+ {
+ .name = "seq wrap",
+ .hdrs = {
+ RX_FRAG(FL_S, 3),
+ RX_FRAG(FL_E, 0),
+ },
+ .n_hdrs = 2,
+ .rx_len = 2,
+ },
+};
+
+static void mctp_route_input_sk_reasm_to_desc(
+ const struct mctp_route_input_sk_reasm_test *t,
+ char *desc)
+{
+ sprintf(desc, "%s", t->name);
+}
+
+KUNIT_ARRAY_PARAM(mctp_route_input_sk_reasm, mctp_route_input_sk_reasm_tests,
+ mctp_route_input_sk_reasm_to_desc);
+
+static struct kunit_case mctp_test_cases[] = {
+ KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params),
+ KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params),
+ KUNIT_CASE_PARAM(mctp_test_route_input_sk, mctp_route_input_sk_gen_params),
+ KUNIT_CASE_PARAM(mctp_test_route_input_sk_reasm,
+ mctp_route_input_sk_reasm_gen_params),
+ {}
+};
+
+static struct kunit_suite mctp_test_suite = {
+ .name = "mctp",
+ .test_cases = mctp_test_cases,
+};
+
+kunit_test_suite(mctp_test_suite);
diff --git a/net/mctp/test/utils.c b/net/mctp/test/utils.c
new file mode 100644
index 000000000000..cc6b8803aa9d
--- /dev/null
+++ b/net/mctp/test/utils.c
@@ -0,0 +1,67 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#include <linux/netdevice.h>
+#include <linux/mctp.h>
+#include <linux/if_arp.h>
+
+#include <net/mctpdevice.h>
+#include <net/pkt_sched.h>
+
+#include "utils.h"
+
+static netdev_tx_t mctp_test_dev_tx(struct sk_buff *skb,
+ struct net_device *ndev)
+{
+ kfree(skb);
+ return NETDEV_TX_OK;
+}
+
+static const struct net_device_ops mctp_test_netdev_ops = {
+ .ndo_start_xmit = mctp_test_dev_tx,
+};
+
+static void mctp_test_dev_setup(struct net_device *ndev)
+{
+ ndev->type = ARPHRD_MCTP;
+ ndev->mtu = MCTP_DEV_TEST_MTU;
+ ndev->hard_header_len = 0;
+ ndev->addr_len = 0;
+ ndev->tx_queue_len = DEFAULT_TX_QUEUE_LEN;
+ ndev->flags = IFF_NOARP;
+ ndev->netdev_ops = &mctp_test_netdev_ops;
+ ndev->needs_free_netdev = true;
+}
+
+struct mctp_test_dev *mctp_test_create_dev(void)
+{
+ struct mctp_test_dev *dev;
+ struct net_device *ndev;
+ int rc;
+
+ ndev = alloc_netdev(sizeof(*dev), "mctptest%d", NET_NAME_ENUM,
+ mctp_test_dev_setup);
+ if (!ndev)
+ return NULL;
+
+ dev = netdev_priv(ndev);
+ dev->ndev = ndev;
+
+ rc = register_netdev(ndev);
+ if (rc) {
+ free_netdev(ndev);
+ return NULL;
+ }
+
+ rcu_read_lock();
+ dev->mdev = __mctp_dev_get(ndev);
+ mctp_dev_hold(dev->mdev);
+ rcu_read_unlock();
+
+ return dev;
+}
+
+void mctp_test_destroy_dev(struct mctp_test_dev *dev)
+{
+ mctp_dev_put(dev->mdev);
+ unregister_netdev(dev->ndev);
+}
diff --git a/net/mctp/test/utils.h b/net/mctp/test/utils.h
new file mode 100644
index 000000000000..df6aa1c03440
--- /dev/null
+++ b/net/mctp/test/utils.h
@@ -0,0 +1,20 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+
+#ifndef __NET_MCTP_TEST_UTILS_H
+#define __NET_MCTP_TEST_UTILS_H
+
+#include <kunit/test.h>
+
+#define MCTP_DEV_TEST_MTU 68
+
+struct mctp_test_dev {
+ struct net_device *ndev;
+ struct mctp_dev *mdev;
+};
+
+struct mctp_test_dev;
+
+struct mctp_test_dev *mctp_test_create_dev(void);
+void mctp_test_destroy_dev(struct mctp_test_dev *dev);
+
+#endif /* __NET_MCTP_TEST_UTILS_H */