[PATCH RFC net-next v4 7/8] vsock: Add lockless sendmsg() support

From: Bobby Eshleman <hidden>
Date: 2023-06-10 00:59:32
Also in: bpf, kvm, linux-hyperv, lkml
Subsystem: hyper-v/azure core and drivers, networking [general], the rest, virtio and vhost vsock driver, virtio core, virtio host (vhost), vm sockets (af_vsock), vmware vsock vmci transport driver · Maintainers: "K. Y. Srinivasan", Haiyang Zhang, Wei Liu, Dexuan Cui, Long Li, "David S. Miller", Eric Dumazet, Jakub Kicinski, Paolo Abeni, Linus Torvalds, Stefan Hajnoczi, Stefano Garzarella, "Michael S. Tsirkin", Jason Wang, Bryan Tan, Vishnu Dasa

Because the dgram sendmsg() path for AF_VSOCK acquires the socket lock,
it does not scale when many senders share a socket.

Prior to this patch the socket lock is used to protect both reads and
writes to the local_addr, remote_addr, transport, and buffer size
variables of a vsock socket. What follows are the new protection schemes
for these fields that ensure a race-free and usually lock-free
multi-sender sendmsg() path for vsock dgrams.

- local_addr
local_addr changes as a result of binding a socket. The write path
for local_addr is bind() and various vsock_auto_bind() call sites.
After a socket has been bound via vsock_auto_bind() or bind(), subsequent
calls to bind()/vsock_auto_bind() do not write to local_addr again. bind()
rejects the user request and vsock_auto_bind() early exits.
Therefore, local_addr cannot change while a parallel thread is
in sendmsg(), and lock-free reads of local_addr in sendmsg() are safe.
Change: only acquire lock for auto-binding as-needed in sendmsg().
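
The local_addr scheme above amounts to a racy fast-path check followed by a
re-check under the lock. A minimal userspace sketch of that pattern follows,
assuming a pthread mutex in place of the socket lock and C11 atomics in place
of the kernel's data_race() annotation. The names demo_sock, demo_auto_bind(),
demo_send_prepare(), PORT_ANY, and the port value 1024 are hypothetical and
are not part of this patch.

```c
#include <assert.h>
#include <pthread.h>
#include <stdatomic.h>

#define PORT_ANY 0u /* hypothetical "unbound" sentinel */

struct demo_sock {
	pthread_mutex_t lock;    /* stands in for the socket lock */
	atomic_uint local_port;  /* written at most once, under lock */
	unsigned int bind_count; /* number of real binds, protected by lock */
};

static void demo_sock_init(struct demo_sock *s)
{
	pthread_mutex_init(&s->lock, NULL);
	atomic_init(&s->local_port, PORT_ANY);
	s->bind_count = 0;
}

/* Caller must hold s->lock. Exits early when the socket is already bound,
 * so a bound address is never overwritten.
 */
static int demo_auto_bind(struct demo_sock *s)
{
	if (atomic_load_explicit(&s->local_port, memory_order_relaxed) != PORT_ANY)
		return 0;

	s->bind_count++;
	atomic_store_explicit(&s->local_port, 1024u, memory_order_release);
	return 0;
}

/* Send-path preparation: take the lock only when the socket looks unbound.
 * A stale "unbound" observation costs one lock round trip, nothing more,
 * because demo_auto_bind() re-checks under the lock.
 */
static int demo_send_prepare(struct demo_sock *s)
{
	int err = 0;

	if (atomic_load_explicit(&s->local_port, memory_order_acquire) == PORT_ANY) {
		pthread_mutex_lock(&s->lock);
		err = demo_auto_bind(s);
		pthread_mutex_unlock(&s->lock);
	}
	return err;
}
```

In this sketch, the first demo_send_prepare() call on a fresh socket binds it
under the lock. Later calls see the bound port and skip the lock entirely.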

- buffer size variables
Not used by dgram, so they do not need protection. No change.

- remote_addr and transport
A remote_addr update may result in a changed transport, yet the vsock
send path should be able to read both fields lock-free and coherently.
This patch therefore packages the two fields into a new
struct vsock_remote_info that is referenced by an RCU-protected pointer.

Writes are synchronized as usual by the socket lock. Reads only take
place in RCU read-side critical sections. When remote_addr or transport
is updated, a new remote info is allocated. Old readers still see the
old coherent remote_addr/transport pair, and new readers will refer to
the new coherent pair. The coherency between remote_addr and transport
previously provided by the socket lock alone is now also preserved by
RCU, but with a highly scalable, lock-free read side.
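
The idea of publishing remote_addr and transport together behind one pointer
can be sketched in userspace C11 as below. This is an illustration only: it
uses a plain atomic pointer rather than RCU, so it shows the coherent-pair
publication but not grace periods. The caller of the writer is assumed to
serialize updates and to delay freeing the returned old structure until no
reader can still hold it. All demo_* names are hypothetical.

```c
#include <assert.h>
#include <errno.h>
#include <stdatomic.h>
#include <stdlib.h>

struct demo_transport {
	const char *name;
};

/* Address and transport travel together, so one pointer load yields a
 * matching pair.
 */
struct demo_remote_info {
	unsigned int cid;
	unsigned int port;
	const struct demo_transport *transport;
};

struct demo_sock {
	_Atomic(struct demo_remote_info *) remote_info;
};

/* Writer: build a complete new structure, then publish it with one atomic
 * exchange. Writers are assumed to be serialized by the caller. The old
 * structure is handed back through *old so the caller decides when freeing
 * it is safe.
 */
static int demo_set_remote_info(struct demo_sock *s, unsigned int cid,
				unsigned int port,
				const struct demo_transport *transport,
				struct demo_remote_info **old)
{
	struct demo_remote_info *new_info = malloc(sizeof(*new_info));

	if (!new_info)
		return -ENOMEM;

	new_info->cid = cid;
	new_info->port = port;
	new_info->transport = transport;

	*old = atomic_exchange_explicit(&s->remote_info, new_info,
					memory_order_acq_rel);
	return 0;
}

/* Reader: one acquire load, then copy the pair out. Either the whole old
 * pair or the whole new pair is observed, never a mix of the two.
 */
static int demo_read_remote_info(struct demo_sock *s,
				 struct demo_remote_info *out)
{
	struct demo_remote_info *cur =
		atomic_load_explicit(&s->remote_info, memory_order_acquire);

	if (!cur)
		return -EINVAL;

	*out = *cur;
	return 0;
}
```

A reader that loaded the pointer before an update keeps seeing the old pair,
while a reader that loads it afterwards sees the new one. In the kernel, RCU
additionally guarantees that the old structure is not freed while such a
reader is still inside its read-side critical section.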

Helpers are introduced for accessing and updating the new pointer.

The new structure contains an rcu_head so that kfree_rcu() can be
used. This removes the need for writers to call synchronize_rcu() before
freeing old structures, which is more efficient and reduces code
churn where remote_addr/transport are already being updated inside RCU
read-side sections.

Only virtio has been tested, but updates were necessary to the VMCI and
hyperv code. Unfortunately the author does not have access to
VMCI/hyperv systems so those changes are untested.

Perf Tests (results from patch v2)
vCPUs: 16
Threads: 16
Payload: 4KB
Test Runs: 5
Type: SOCK_DGRAM

Before: 245.2 MB/s
After: 509.2 MB/s (+107%)
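
As a quick check of the reported gain, the relative change can be recomputed
from the two rounded throughput figures. percent_change() below is a
hypothetical helper, not part of the patch. From the rounded inputs it gives
roughly +107.7%, in line with the reported figure.

```c
#include <assert.h>

/* Relative change from before to after, expressed in percent. */
static double percent_change(double before, double after)
{
	return (after - before) / before * 100.0;
}
```

For example, percent_change(245.2, 509.2) evaluates (509.2 - 245.2) / 245.2,
which is about 1.077, or a little over double the original throughput.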

Notably, on the same test system, vsock dgram even outperforms
multi-threaded UDP over virtio-net with vhost and MQ support enabled.

Throughput metrics for single-threaded SOCK_DGRAM and
single/multi-threaded SOCK_STREAM showed no statistically significant
throughput changes (lowest p-value reaching 0.27), with the mean
difference ranging from -5% to +1%.

Signed-off-by: Bobby Eshleman <redacted>
---
 drivers/vhost/vsock.c                   |  12 +-
 include/linux/virtio_vsock.h            |   3 +-
 include/net/af_vsock.h                  |  38 ++-
 net/vmw_vsock/af_vsock.c                | 399 ++++++++++++++++++++++++++------
 net/vmw_vsock/diag.c                    |  10 +-
 net/vmw_vsock/hyperv_transport.c        |  27 ++-
 net/vmw_vsock/virtio_transport_common.c |  34 ++-
 net/vmw_vsock/vmci_transport.c          |  84 +++++--
 net/vmw_vsock/vsock_bpf.c               |  10 +-
 9 files changed, 492 insertions(+), 125 deletions(-)
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 159c1a22c1a8..b027a780d333 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -297,13 +297,17 @@ static int
 vhost_transport_cancel_pkt(struct vsock_sock *vsk)
 {
 	struct vhost_vsock *vsock;
+	unsigned int cid;
 	int cnt = 0;
 	int ret = -ENODEV;
 
 	rcu_read_lock();
+	ret = vsock_remote_addr_cid(vsk, &cid);
+	if (ret < 0)
+		goto out;
 
 	/* Find the vhost_vsock according to guest context id  */
-	vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
+	vsock = vhost_vsock_get(cid);
 	if (!vsock)
 		goto out;
 
@@ -706,6 +710,10 @@ static void vhost_vsock_flush(struct vhost_vsock *vsock)
 static void vhost_vsock_reset_orphans(struct sock *sk)
 {
 	struct vsock_sock *vsk = vsock_sk(sk);
+	unsigned int cid;
+
+	if (vsock_remote_addr_cid(vsk, &cid) < 0)
+		return;
 
 	/* vmci_transport.c doesn't take sk_lock here either.  At least we're
 	 * under vsock_table_lock so the sock cannot disappear while we're
@@ -713,7 +721,7 @@ static void vhost_vsock_reset_orphans(struct sock *sk)
 	 */
 
 	/* If the peer is still valid, no need to reset connection */
-	if (vhost_vsock_get(vsk->remote_addr.svm_cid))
+	if (vhost_vsock_get(cid))
 		return;
 
 	/* If the close timeout is pending, let it expire.  This avoids races
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index 237ca87a2ecd..97656e83606f 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -231,7 +231,8 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk,
 				struct msghdr *msg,
 				size_t len);
 int
-virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
+virtio_transport_dgram_enqueue(const struct vsock_transport *transport,
+			       struct vsock_sock *vsk,
 			       struct sockaddr_vm *remote_addr,
 			       struct msghdr *msg,
 			       size_t len);
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index c115e655b4f5..928b09fbc64b 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -25,12 +25,17 @@ extern spinlock_t vsock_table_lock;
 #define vsock_sk(__sk)    ((struct vsock_sock *)__sk)
 #define sk_vsock(__vsk)   (&(__vsk)->sk)
 
+struct vsock_remote_info {
+	struct sockaddr_vm addr;
+	struct rcu_head rcu;
+	const struct vsock_transport *transport;
+};
+
 struct vsock_sock {
 	/* sk must be the first member. */
 	struct sock sk;
-	const struct vsock_transport *transport;
 	struct sockaddr_vm local_addr;
-	struct sockaddr_vm remote_addr;
+	struct vsock_remote_info __rcu *remote_info;
 	/* Links for the global tables of bound and connected sockets. */
 	struct list_head bound_table;
 	struct list_head connected_table;
@@ -120,8 +125,8 @@ struct vsock_transport {
 
 	/* DGRAM. */
 	int (*dgram_bind)(struct vsock_sock *, struct sockaddr_vm *);
-	int (*dgram_enqueue)(struct vsock_sock *, struct sockaddr_vm *,
-			     struct msghdr *, size_t len);
+	int (*dgram_enqueue)(const struct vsock_transport *, struct vsock_sock *,
+			     struct sockaddr_vm *, struct msghdr *, size_t len);
 	bool (*dgram_allow)(u32 cid, u32 port);
 	int (*dgram_get_cid)(struct sk_buff *skb, unsigned int *cid);
 	int (*dgram_get_port)(struct sk_buff *skb, unsigned int *port);
@@ -196,6 +201,16 @@ void vsock_core_unregister(const struct vsock_transport *t);
 /* The transport may downcast this to access transport-specific functions */
 const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk);
 
+static inline struct vsock_remote_info *
+vsock_core_get_remote_info(struct vsock_sock *vsk)
+{
+	/* vsk->remote_info may be accessed if the rcu read lock is held OR the
+	 * socket lock is held
+	 */
+	return rcu_dereference_check(vsk->remote_info,
+				     lockdep_sock_is_held(sk_vsock(vsk)));
+}
+
 /**** UTILS ****/
 
 /* vsock_table_lock must be held */
@@ -214,7 +229,7 @@ void vsock_release_pending(struct sock *pending);
 void vsock_add_pending(struct sock *listener, struct sock *pending);
 void vsock_remove_pending(struct sock *listener, struct sock *pending);
 void vsock_enqueue_accept(struct sock *listener, struct sock *connected);
-void vsock_insert_connected(struct vsock_sock *vsk);
+int vsock_insert_connected(struct vsock_sock *vsk);
 void vsock_remove_bound(struct vsock_sock *vsk);
 void vsock_remove_connected(struct vsock_sock *vsk);
 struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
@@ -223,7 +238,8 @@ struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
 void vsock_remove_sock(struct vsock_sock *vsk);
 void vsock_for_each_connected_socket(struct vsock_transport *transport,
 				     void (*fn)(struct sock *sk));
-int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk);
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk,
+			   struct sockaddr_vm *remote_addr);
 bool vsock_find_cid(unsigned int cid);
 struct sock *vsock_find_bound_dgram_socket(struct sockaddr_vm *addr);
 
@@ -253,4 +269,14 @@ static inline void __init vsock_bpf_build_proto(void)
 {}
 #endif
 
+/* RCU-protected remote addr helpers */
+int vsock_remote_addr_cid(struct vsock_sock *vsk, unsigned int *cid);
+int vsock_remote_addr_port(struct vsock_sock *vsk, unsigned int *port);
+int vsock_remote_addr_cid_port(struct vsock_sock *vsk, unsigned int *cid,
+			       unsigned int *port);
+int vsock_remote_addr_copy(struct vsock_sock *vsk, struct sockaddr_vm *dest);
+bool vsock_remote_addr_bound(struct vsock_sock *vsk);
+bool vsock_remote_addr_equals(struct vsock_sock *vsk, struct sockaddr_vm *other);
+int vsock_remote_addr_update_cid_port(struct vsock_sock *vsk, u32 cid, u32 port);
+
 #endif /* __AF_VSOCK_H__ */
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index b0b18e7f4299..9e620d67889b 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -114,7 +114,12 @@
 static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr);
 static void vsock_sk_destruct(struct sock *sk);
 static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
+static bool vsock_use_local_transport(unsigned int remote_cid);
 static bool sock_type_connectible(u16 type);
+static const struct vsock_transport *
+vsock_connectible_lookup_transport(unsigned int cid, __u8 flags);
+static const struct vsock_transport *
+vsock_dgram_lookup_transport(unsigned int cid, __u8 flags);
 
 /* Protocol family. */
 struct proto vsock_proto = {
@@ -146,6 +151,123 @@ static const struct vsock_transport *transport_local;
 static DEFINE_MUTEX(vsock_register_mutex);
 
 /**** UTILS ****/
+bool vsock_remote_addr_bound(struct vsock_sock *vsk)
+{
+	struct vsock_remote_info *remote_info;
+	bool ret;
+
+	rcu_read_lock();
+	remote_info = vsock_core_get_remote_info(vsk);
+	if (!remote_info) {
+		rcu_read_unlock();
+		return false;
+	}
+
+	ret = vsock_addr_bound(&remote_info->addr);
+	rcu_read_unlock();
+
+	return ret;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_bound);
+
+int vsock_remote_addr_copy(struct vsock_sock *vsk, struct sockaddr_vm *dest)
+{
+	struct vsock_remote_info *remote_info;
+
+	rcu_read_lock();
+	remote_info = vsock_core_get_remote_info(vsk);
+	if (!remote_info) {
+		rcu_read_unlock();
+		return -EINVAL;
+	}
+	memcpy(dest, &remote_info->addr, sizeof(*dest));
+	rcu_read_unlock();
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_copy);
+
+int vsock_remote_addr_cid(struct vsock_sock *vsk, unsigned int *cid)
+{
+	return vsock_remote_addr_cid_port(vsk, cid, NULL);
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_cid);
+
+int vsock_remote_addr_port(struct vsock_sock *vsk, unsigned int *port)
+{
+	return vsock_remote_addr_cid_port(vsk, NULL, port);
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_port);
+
+int vsock_remote_addr_cid_port(struct vsock_sock *vsk, unsigned int *cid,
+			       unsigned int *port)
+{
+	struct vsock_remote_info *remote_info;
+
+	rcu_read_lock();
+	remote_info = vsock_core_get_remote_info(vsk);
+	if (!remote_info) {
+		rcu_read_unlock();
+		return -EINVAL;
+	}
+
+	if (cid)
+		*cid = remote_info->addr.svm_cid;
+	if (port)
+		*port = remote_info->addr.svm_port;
+
+	rcu_read_unlock();
+	return 0;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_cid_port);
+
+/* The socket lock must be held by the caller */
+static int vsock_set_remote_info(struct vsock_sock *vsk,
+				 const struct vsock_transport *transport,
+				 struct sockaddr_vm *addr)
+{
+	struct vsock_remote_info *old, *new;
+
+	if (addr || transport) {
+		new = kmalloc(sizeof(*new), GFP_KERNEL);
+		if (!new)
+			return -ENOMEM;
+
+		if (addr)
+			memcpy(&new->addr, addr, sizeof(new->addr));
+
+		if (transport)
+			new->transport = transport;
+	} else {
+		new = NULL;
+	}
+
+	old = rcu_replace_pointer(vsk->remote_info, new,
+				  lockdep_sock_is_held(sk_vsock(vsk)));
+	kfree_rcu(old, rcu);
+
+	return 0;
+}
+
+bool vsock_remote_addr_equals(struct vsock_sock *vsk,
+			      struct sockaddr_vm *other)
+{
+	struct vsock_remote_info *remote_info;
+	bool equals;
+
+	rcu_read_lock();
+	remote_info = vsock_core_get_remote_info(vsk);
+	if (!remote_info) {
+		rcu_read_unlock();
+		return false;
+	}
+
+	equals = vsock_addr_equals_addr(&remote_info->addr, other);
+	rcu_read_unlock();
+
+	return equals;
+}
+EXPORT_SYMBOL_GPL(vsock_remote_addr_equals);
 
 /* Each bound VSocket is stored in the bind hash table and each connected
  * VSocket is stored in the connected hash table.
@@ -283,10 +405,17 @@ static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,
 
 	list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
 			    connected_table) {
-		if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
+		struct vsock_remote_info *remote_info;
+
+		rcu_read_lock();
+		remote_info = vsock_core_get_remote_info(vsk);
+		if (remote_info &&
+		    vsock_addr_equals_addr(src, &remote_info->addr) &&
 		    dst->svm_port == vsk->local_addr.svm_port) {
+			rcu_read_unlock();
 			return sk_vsock(vsk);
 		}
+		rcu_read_unlock();
 	}
 
 	return NULL;
@@ -299,14 +428,25 @@ static void vsock_insert_unbound(struct vsock_sock *vsk)
 	spin_unlock_bh(&vsock_table_lock);
 }
 
-void vsock_insert_connected(struct vsock_sock *vsk)
+int vsock_insert_connected(struct vsock_sock *vsk)
 {
-	struct list_head *list = vsock_connected_sockets(
-		&vsk->remote_addr, &vsk->local_addr);
+	struct vsock_remote_info *remote_info;
+	struct list_head *list;
+
+	rcu_read_lock();
+	remote_info = vsock_core_get_remote_info(vsk);
+	if (!remote_info) {
+		rcu_read_unlock();
+		return -EINVAL;
+	}
+	list = vsock_connected_sockets(&remote_info->addr, &vsk->local_addr);
+	rcu_read_unlock();
 
 	spin_lock_bh(&vsock_table_lock);
 	__vsock_insert_connected(list, vsk);
 	spin_unlock_bh(&vsock_table_lock);
+
+	return 0;
 }
 EXPORT_SYMBOL_GPL(vsock_insert_connected);
 
@@ -388,7 +528,7 @@ void vsock_for_each_connected_socket(struct vsock_transport *transport,
 		struct vsock_sock *vsk;
 		list_for_each_entry(vsk, &vsock_connected_table[i],
 				    connected_table) {
-			if (vsk->transport != transport)
+			if (vsock_core_get_transport(vsk) != transport)
 				continue;
 
 			fn(sk_vsock(vsk));
@@ -454,12 +594,19 @@ static bool vsock_use_local_transport(unsigned int remote_cid)
 
 static void vsock_deassign_transport(struct vsock_sock *vsk)
 {
-	if (!vsk->transport)
+	struct vsock_remote_info *remote_info;
+
+	remote_info = rcu_replace_pointer(vsk->remote_info, NULL,
+					  lockdep_sock_is_held(sk_vsock(vsk)));
+	if (!remote_info)
 		return;
 
-	vsk->transport->destruct(vsk);
-	module_put(vsk->transport->module);
-	vsk->transport = NULL;
+	if (remote_info->transport) {
+		remote_info->transport->destruct(vsk);
+		module_put(remote_info->transport->module);
+	}
+
+	kfree_rcu(remote_info, rcu);
 }
 
 static const struct vsock_transport *
@@ -490,26 +637,29 @@ vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
 	return transport_dgram;
 }
 
-/* Assign a transport to a socket and call the .init transport callback.
+/* Assign a transport and remote addr to a socket and call the .init transport
+ * callback.
  *
- * Note: for connection oriented socket this must be called when vsk->remote_addr
- * is set (e.g. during the connect() or when a connection request on a listener
- * socket is received).
- * The vsk->remote_addr is used to decide which transport to use:
+ * The remote_addr is used to decide which transport to use. Both the addr
+ * and transport are updated simultaneously via RCU-protected pointer:
  *  - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if
  *    g2h is not loaded, will use local transport;
  *  - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field
  *    includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport;
  *  - remote CID > VMADDR_CID_HOST will use host->guest transport;
  */
-int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk,
+			   struct sockaddr_vm *remote_addr)
 {
 	const struct vsock_transport *new_transport;
+	struct vsock_remote_info *old_info;
 	struct sock *sk = sk_vsock(vsk);
-	unsigned int remote_cid = vsk->remote_addr.svm_cid;
+	unsigned int remote_cid;
 	__u8 remote_flags;
 	int ret;
 
+	remote_cid = remote_addr->svm_cid;
+
 	/* If the packet is coming with the source and destination CIDs higher
 	 * than VMADDR_CID_HOST, then a vsock channel where all the packets are
 	 * forwarded to the host should be established. Then the host will
@@ -519,10 +669,10 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 	 * the connect path the flag can be set by the user space application.
 	 */
 	if (psk && vsk->local_addr.svm_cid > VMADDR_CID_HOST &&
-	    vsk->remote_addr.svm_cid > VMADDR_CID_HOST)
-		vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST;
+	    remote_cid > VMADDR_CID_HOST)
+		remote_addr->svm_flags |= VMADDR_FLAG_TO_HOST;
 
-	remote_flags = vsk->remote_addr.svm_flags;
+	remote_flags = remote_addr->svm_flags;
 
 	switch (sk->sk_type) {
 	case SOCK_DGRAM:
@@ -538,8 +688,9 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 		return -ESOCKTNOSUPPORT;
 	}
 
-	if (vsk->transport) {
-		if (vsk->transport == new_transport)
+	old_info = vsock_core_get_remote_info(vsk);
+	if (old_info && old_info->transport) {
+		if (old_info->transport == new_transport)
 			return 0;
 
 		/* transport->release() must be called with sock lock acquired.
@@ -548,7 +699,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 		 * function is called on a new socket which is not assigned to
 		 * any transport.
 		 */
-		vsk->transport->release(vsk);
+		old_info->transport->release(vsk);
 		vsock_deassign_transport(vsk);
 	}
 
@@ -566,13 +717,18 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 		}
 	}
 
-	ret = new_transport->init(vsk, psk);
+	ret = vsock_set_remote_info(vsk, new_transport, remote_addr);
 	if (ret) {
 		module_put(new_transport->module);
 		return ret;
 	}
 
-	vsk->transport = new_transport;
+	ret = new_transport->init(vsk, psk);
+	if (ret) {
+		vsock_set_remote_info(vsk, NULL, NULL);
+		module_put(new_transport->module);
+		return ret;
+	}
 
 	return 0;
 }
@@ -629,12 +785,14 @@ static bool vsock_is_pending(struct sock *sk)
 
 static int vsock_send_shutdown(struct sock *sk, int mode)
 {
+	const struct vsock_transport *transport;
 	struct vsock_sock *vsk = vsock_sk(sk);
 
-	if (!vsk->transport)
+	transport = vsock_core_get_transport(vsk);
+	if (!transport)
 		return -ENODEV;
 
-	return vsk->transport->shutdown(vsk, mode);
+	return transport->shutdown(vsk, mode);
 }
 
 static void vsock_pending_work(struct work_struct *work)
@@ -757,7 +915,10 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
 static int vsock_bind_dgram(struct vsock_sock *vsk,
 			    struct sockaddr_vm *addr)
 {
-	if (!vsk->transport || !vsk->transport->dgram_bind) {
+	const struct vsock_transport *transport;
+
+	transport = vsock_core_get_transport(vsk);
+	if (!transport || !transport->dgram_bind) {
 		int retval;
 
 		spin_lock_bh(&vsock_dgram_table_lock);
@@ -768,7 +929,7 @@ static int vsock_bind_dgram(struct vsock_sock *vsk,
 		return retval;
 	}
 
-	return vsk->transport->dgram_bind(vsk, addr);
+	return transport->dgram_bind(vsk, addr);
 }
 
 static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
@@ -817,6 +978,7 @@ static struct sock *__vsock_create(struct net *net,
 				   unsigned short type,
 				   int kern)
 {
+	struct vsock_remote_info *remote_info;
 	struct sock *sk;
 	struct vsock_sock *psk;
 	struct vsock_sock *vsk;
@@ -836,7 +998,14 @@ static struct sock *__vsock_create(struct net *net,
 
 	vsk = vsock_sk(sk);
 	vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
-	vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+
+	remote_info = kmalloc(sizeof(*remote_info), GFP_KERNEL);
+	if (!remote_info) {
+		sk_free(sk);
+		return NULL;
+	}
+	vsock_addr_init(&remote_info->addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+	rcu_assign_pointer(vsk->remote_info, remote_info);
 
 	sk->sk_destruct = vsock_sk_destruct;
 	sk->sk_backlog_rcv = vsock_queue_rcv_skb;
@@ -883,6 +1052,7 @@ static bool sock_type_connectible(u16 type)
 static void __vsock_release(struct sock *sk, int level)
 {
 	if (sk) {
+		const struct vsock_transport *transport;
 		struct sock *pending;
 		struct vsock_sock *vsk;
 
@@ -896,8 +1066,9 @@ static void __vsock_release(struct sock *sk, int level)
 		 */
 		lock_sock_nested(sk, level);
 
-		if (vsk->transport)
-			vsk->transport->release(vsk);
+		transport = vsock_core_get_transport(vsk);
+		if (transport)
+			transport->release(vsk);
 		else if (sock_type_connectible(sk->sk_type))
 			vsock_remove_sock(vsk);
 
@@ -927,8 +1098,6 @@ static void vsock_sk_destruct(struct sock *sk)
 	 * possibly register the address family with the kernel.
 	 */
 	vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
-	vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
-
 	put_cred(vsk->owner);
 }
 
@@ -952,16 +1121,22 @@ EXPORT_SYMBOL_GPL(vsock_create_connected);
 
 s64 vsock_stream_has_data(struct vsock_sock *vsk)
 {
-	return vsk->transport->stream_has_data(vsk);
+	const struct vsock_transport *transport;
+
+	transport = vsock_core_get_transport(vsk);
+
+	return transport->stream_has_data(vsk);
 }
 EXPORT_SYMBOL_GPL(vsock_stream_has_data);
 
 s64 vsock_connectible_has_data(struct vsock_sock *vsk)
 {
+	const struct vsock_transport *transport;
 	struct sock *sk = sk_vsock(vsk);
 
+	transport = vsock_core_get_transport(vsk);
 	if (sk->sk_type == SOCK_SEQPACKET)
-		return vsk->transport->seqpacket_has_data(vsk);
+		return transport->seqpacket_has_data(vsk);
 	else
 		return vsock_stream_has_data(vsk);
 }
@@ -969,7 +1144,10 @@ EXPORT_SYMBOL_GPL(vsock_connectible_has_data);
 
 s64 vsock_stream_has_space(struct vsock_sock *vsk)
 {
-	return vsk->transport->stream_has_space(vsk);
+	const struct vsock_transport *transport;
+
+	transport = vsock_core_get_transport(vsk);
+	return transport->stream_has_space(vsk);
 }
 EXPORT_SYMBOL_GPL(vsock_stream_has_space);
 
@@ -1018,6 +1196,7 @@ static int vsock_getname(struct socket *sock,
 	struct sock *sk;
 	struct vsock_sock *vsk;
 	struct sockaddr_vm *vm_addr;
+	struct vsock_remote_info *rcu_ptr;
 
 	sk = sock->sk;
 	vsk = vsock_sk(sk);
@@ -1030,7 +1209,14 @@ static int vsock_getname(struct socket *sock,
 			err = -ENOTCONN;
 			goto out;
 		}
-		vm_addr = &vsk->remote_addr;
+
+		rcu_ptr = vsock_core_get_remote_info(vsk);
+		if (!rcu_ptr) {
+			err = -EINVAL;
+			goto out;
+		}
+
+		vm_addr = &rcu_ptr->addr;
 	} else {
 		vm_addr = &vsk->local_addr;
 	}
@@ -1154,7 +1340,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
 
 		lock_sock(sk);
 
-		transport = vsk->transport;
+		transport = vsock_core_get_transport(vsk);
 
 		/* Listening sockets that have connections in their accept
 		 * queue can be read.
@@ -1225,9 +1411,11 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
 
 static int vsock_read_skb(struct sock *sk, skb_read_actor_t read_actor)
 {
+	const struct vsock_transport *transport;
 	struct vsock_sock *vsk = vsock_sk(sk);
 
-	return vsk->transport->read_skb(vsk, read_actor);
+	transport = vsock_core_get_transport(vsk);
+	return transport->read_skb(vsk, read_actor);
 }
 
 static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
@@ -1236,7 +1424,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 	int err;
 	struct sock *sk;
 	struct vsock_sock *vsk;
-	struct sockaddr_vm *remote_addr;
+	struct sockaddr_vm stack_addr, *remote_addr;
 	const struct vsock_transport *transport;
 
 	if (msg->msg_flags & MSG_OOB)
@@ -1247,7 +1435,23 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 	sk = sock->sk;
 	vsk = vsock_sk(sk);
 
-	lock_sock(sk);
+	/* If auto-binding is required, acquire the slock to avoid potential
+	 * race conditions. Otherwise, do not acquire the lock.
+	 *
+	 * We know that the first check of local_addr is racy (indicated by
+	 * data_race()). By acquiring the lock and then subsequently checking
+	 * again if local_addr is bound (inside vsock_auto_bind()), we can
+	 * ensure there are no real data races.
+	 *
+	 * This technique is borrowed from inet_send_prepare().
+	 */
+	if (data_race(!vsock_addr_bound(&vsk->local_addr))) {
+		lock_sock(sk);
+		err = vsock_auto_bind(vsk);
+		release_sock(sk);
+		if (err)
+			return err;
+	}
 
 	/* If the provided message contains an address, use that.  Otherwise
 	 * fall back on the socket's remote handle (if it has been connected).
@@ -1257,6 +1461,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 			    &remote_addr) == 0) {
 		transport = vsock_dgram_lookup_transport(remote_addr->svm_cid,
 							 remote_addr->svm_flags);
+
 		if (!transport) {
 			err = -EINVAL;
 			goto out;
@@ -1287,18 +1492,39 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 			goto out;
 		}
 
-		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
+		err = transport->dgram_enqueue(transport, vsk, remote_addr, msg, len);
 		module_put(transport->module);
 	} else if (sock->state == SS_CONNECTED) {
-		remote_addr = &vsk->remote_addr;
-		transport = vsk->transport;
+		struct vsock_remote_info *remote_info;
+		const struct vsock_transport *transport;
 
-		err = vsock_auto_bind(vsk);
-		if (err)
+		rcu_read_lock();
+		remote_info = vsock_core_get_remote_info(vsk);
+		if (!remote_info) {
+			err = -EINVAL;
+			rcu_read_unlock();
 			goto out;
+		}
 
-		if (remote_addr->svm_cid == VMADDR_CID_ANY)
+		transport = remote_info->transport;
+		memcpy(&stack_addr, &remote_info->addr, sizeof(stack_addr));
+		rcu_read_unlock();
+
+		remote_addr = &stack_addr;
+
+		if (remote_addr->svm_cid == VMADDR_CID_ANY) {
 			remote_addr->svm_cid = transport->get_local_cid();
+			lock_sock(sk_vsock(vsk));
+			/* Even though the CID has changed, we do not have to
+			 * look up the transport again because the local CID
+			 * will never resolve to a different transport.
+			 */
+			err = vsock_set_remote_info(vsk, transport, remote_addr);
+			release_sock(sk_vsock(vsk));
+
+			if (err)
+				goto out;
+		}
 
 		/* XXX Should connect() or this function ensure remote_addr is
 		 * bound?
@@ -1314,14 +1540,13 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 			goto out;
 		}
 
-		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
+		err = transport->dgram_enqueue(transport, vsk, &stack_addr, msg, len);
 	} else {
 		err = -EINVAL;
 		goto out;
 	}
 
 out:
-	release_sock(sk);
 	return err;
 }
 
@@ -1332,18 +1557,22 @@ static int vsock_dgram_connect(struct socket *sock,
 	struct sock *sk;
 	struct vsock_sock *vsk;
 	struct sockaddr_vm *remote_addr;
+	const struct vsock_transport *transport;
 
 	sk = sock->sk;
 	vsk = vsock_sk(sk);
 
 	err = vsock_addr_cast(addr, addr_len, &remote_addr);
 	if (err == -EAFNOSUPPORT && remote_addr->svm_family == AF_UNSPEC) {
+		struct sockaddr_vm addr_any;
+
 		lock_sock(sk);
-		vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY,
-				VMADDR_PORT_ANY);
+		vsock_addr_init(&addr_any, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+		err = vsock_set_remote_info(vsk, vsock_core_get_transport(vsk),
+					    &addr_any);
 		sock->state = SS_UNCONNECTED;
 		release_sock(sk);
-		return 0;
+		return err;
 	} else if (err != 0)
 		return -EINVAL;
 
@@ -1353,14 +1582,13 @@ static int vsock_dgram_connect(struct socket *sock,
 	if (err)
 		goto out;
 
-	memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
-
-	err = vsock_assign_transport(vsk, NULL);
+	err = vsock_assign_transport(vsk, NULL, remote_addr);
 	if (err)
 		goto out;
 
-	if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
-					 remote_addr->svm_port)) {
+	transport = vsock_core_get_transport(vsk);
+	if (!transport->dgram_allow(remote_addr->svm_cid,
+				    remote_addr->svm_port)) {
 		err = -EINVAL;
 		goto out;
 	}
@@ -1407,7 +1635,9 @@ int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
 	if (flags & MSG_OOB || flags & MSG_ERRQUEUE)
 		return -EOPNOTSUPP;
 
-	transport = vsk->transport;
+	rcu_read_lock();
+	transport = vsock_core_get_transport(vsk);
+	rcu_read_unlock();
 
 	/* Retrieve the head sk_buff from the socket's receive queue. */
 	err = 0;
@@ -1475,7 +1705,7 @@ static const struct proto_ops vsock_dgram_ops = {
 
 static int vsock_transport_cancel_pkt(struct vsock_sock *vsk)
 {
-	const struct vsock_transport *transport = vsk->transport;
+	const struct vsock_transport *transport = vsock_core_get_transport(vsk);
 
 	if (!transport || !transport->cancel_pkt)
 		return -EOPNOTSUPP;
@@ -1512,6 +1742,7 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr,
 	struct sock *sk;
 	struct vsock_sock *vsk;
 	const struct vsock_transport *transport;
+	struct vsock_remote_info *remote_info;
 	struct sockaddr_vm *remote_addr;
 	long timeout;
 	DEFINE_WAIT(wait);
@@ -1549,14 +1780,20 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr,
 		}
 
 		/* Set the remote address that we are connecting to. */
-		memcpy(&vsk->remote_addr, remote_addr,
-		       sizeof(vsk->remote_addr));
-
-		err = vsock_assign_transport(vsk, NULL);
+		err = vsock_assign_transport(vsk, NULL, remote_addr);
 		if (err)
 			goto out;
 
-		transport = vsk->transport;
+		rcu_read_lock();
+		remote_info = vsock_core_get_remote_info(vsk);
+		if (!remote_info) {
+			err = -EINVAL;
+			rcu_read_unlock();
+			goto out;
+		}
+
+		transport = remote_info->transport;
+		rcu_read_unlock();
 
 		/* The hypervisor and well-known contexts do not have socket
 		 * endpoints.
@@ -1820,7 +2057,7 @@ static int vsock_connectible_setsockopt(struct socket *sock,
 
 	lock_sock(sk);
 
-	transport = vsk->transport;
+	transport = vsock_core_get_transport(vsk);
 
 	switch (optname) {
 	case SO_VM_SOCKETS_BUFFER_SIZE:
@@ -1958,7 +2195,7 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
 
 	lock_sock(sk);
 
-	transport = vsk->transport;
+	transport = vsock_core_get_transport(vsk);
 
 	/* Callers should not provide a destination with connection oriented
 	 * sockets.
@@ -1981,7 +2218,7 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
 		goto out;
 	}
 
-	if (!vsock_addr_bound(&vsk->remote_addr)) {
+	if (!vsock_remote_addr_bound(vsk)) {
 		err = -EDESTADDRREQ;
 		goto out;
 	}
@@ -2102,7 +2339,7 @@ static int vsock_connectible_wait_data(struct sock *sk,
 
 	vsk = vsock_sk(sk);
 	err = 0;
-	transport = vsk->transport;
+	transport = vsock_core_get_transport(vsk);
 
 	while (1) {
 		prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE);
@@ -2170,7 +2407,7 @@ static int __vsock_stream_recvmsg(struct sock *sk, struct msghdr *msg,
 	DEFINE_WAIT(wait);
 
 	vsk = vsock_sk(sk);
-	transport = vsk->transport;
+	transport = vsock_core_get_transport(vsk);
 
 	/* We must not copy less than target bytes into the user's buffer
 	 * before returning successfully, so we wait for the consume queue to
@@ -2246,7 +2483,7 @@ static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
 	DEFINE_WAIT(wait);
 
 	vsk = vsock_sk(sk);
-	transport = vsk->transport;
+	transport = vsock_core_get_transport(vsk);
 
 	timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 
@@ -2303,7 +2540,7 @@ vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 
 	lock_sock(sk);
 
-	transport = vsk->transport;
+	transport = vsock_core_get_transport(vsk);
 
 	if (!transport || sk->sk_state != TCP_ESTABLISHED) {
 		/* Recvmsg is supposed to return 0 if a peer performs an
@@ -2370,7 +2607,7 @@ static int vsock_set_rcvlowat(struct sock *sk, int val)
 	if (val > vsk->buffer_size)
 		return -EINVAL;
 
-	transport = vsk->transport;
+	transport = vsock_core_get_transport(vsk);
 
 	if (transport && transport->set_rcvlowat)
 		return transport->set_rcvlowat(vsk, val);
@@ -2460,7 +2697,10 @@ static int vsock_create(struct net *net, struct socket *sock,
 	vsk = vsock_sk(sk);
 
 	if (sock->type == SOCK_DGRAM) {
-		ret = vsock_assign_transport(vsk, NULL);
+		struct sockaddr_vm remote_addr;
+
+		vsock_addr_init(&remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+		ret = vsock_assign_transport(vsk, NULL, &remote_addr);
 		if (ret < 0) {
 			sock_put(sk);
 			return ret;
@@ -2582,7 +2822,18 @@ static void __exit vsock_exit(void)
 
 const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
 {
-	return vsk->transport;
+	const struct vsock_transport *transport;
+	struct vsock_remote_info *remote_info;
+
+	rcu_read_lock();
+	remote_info = vsock_core_get_remote_info(vsk);
+	if (!remote_info) {
+		rcu_read_unlock();
+		return NULL;
+	}
+	transport = remote_info->transport;
+	rcu_read_unlock();
+	return transport;
 }
 EXPORT_SYMBOL_GPL(vsock_core_get_transport);
 
diff --git a/net/vmw_vsock/diag.c b/net/vmw_vsock/diag.c
index a2823b1c5e28..f843bae86b32 100644
--- a/net/vmw_vsock/diag.c
+++ b/net/vmw_vsock/diag.c
@@ -15,8 +15,14 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
 			u32 portid, u32 seq, u32 flags)
 {
 	struct vsock_sock *vsk = vsock_sk(sk);
+	struct sockaddr_vm remote_addr;
 	struct vsock_diag_msg *rep;
 	struct nlmsghdr *nlh;
+	int err;
+
+	err = vsock_remote_addr_copy(vsk, &remote_addr);
+	if (err < 0)
+		return err;
 
 	nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep),
 			flags);
@@ -36,8 +42,8 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
 	rep->vdiag_shutdown = sk->sk_shutdown;
 	rep->vdiag_src_cid = vsk->local_addr.svm_cid;
 	rep->vdiag_src_port = vsk->local_addr.svm_port;
-	rep->vdiag_dst_cid = vsk->remote_addr.svm_cid;
-	rep->vdiag_dst_port = vsk->remote_addr.svm_port;
+	rep->vdiag_dst_cid = remote_addr.svm_cid;
+	rep->vdiag_dst_port = remote_addr.svm_port;
 	rep->vdiag_ino = sock_i_ino(sk);
 
 	sock_diag_save_cookie(sk, rep->vdiag_cookie);
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index c00bc5da769a..84e8c64b3365 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -323,6 +323,8 @@ static void hvs_open_connection(struct vmbus_channel *chan)
 		goto out;
 
 	if (conn_from_host) {
+		struct sockaddr_vm remote_addr;
+
 		if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog)
 			goto out;
 
@@ -336,10 +338,9 @@ static void hvs_open_connection(struct vmbus_channel *chan)
 		hvs_addr_init(&vnew->local_addr, if_type);
 
 		/* Remote peer is always the host */
-		vsock_addr_init(&vnew->remote_addr,
-				VMADDR_CID_HOST, VMADDR_PORT_ANY);
-		vnew->remote_addr.svm_port = get_port_by_srv_id(if_instance);
-		ret = vsock_assign_transport(vnew, vsock_sk(sk));
+		vsock_addr_init(&remote_addr, VMADDR_CID_HOST, get_port_by_srv_id(if_instance));
+
+		ret = vsock_assign_transport(vnew, vsock_sk(sk), &remote_addr);
 		/* Transport assigned (looking at remote_addr) must be the
 		 * same where we received the request.
 		 */
@@ -459,13 +460,18 @@ static int hvs_connect(struct vsock_sock *vsk)
 {
 	union hvs_service_id vm, host;
 	struct hvsock *h = vsk->trans;
+	int err;
 
 	vm.srv_id = srv_id_template;
 	vm.svm_port = vsk->local_addr.svm_port;
 	h->vm_srv_id = vm.srv_id;
 
 	host.srv_id = srv_id_template;
-	host.svm_port = vsk->remote_addr.svm_port;
+
+	err = vsock_remote_addr_port(vsk, &host.svm_port);
+	if (err < 0)
+		return err;
+
 	h->host_srv_id = host.srv_id;
 
 	return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id);
@@ -566,7 +572,8 @@ static int hvs_dgram_get_length(struct sk_buff *skb, size_t *len)
 	return -EOPNOTSUPP;
 }
 
-static int hvs_dgram_enqueue(struct vsock_sock *vsk,
+static int hvs_dgram_enqueue(const struct vsock_transport *transport,
+			     struct vsock_sock *vsk,
 			     struct sockaddr_vm *remote, struct msghdr *msg,
 			     size_t dgram_len)
 {
@@ -866,7 +873,13 @@ static struct vsock_transport hvs_transport = {
 
 static bool hvs_check_transport(struct vsock_sock *vsk)
 {
-	return vsk->transport == &hvs_transport;
+	bool ret;
+
+	rcu_read_lock();
+	ret = vsock_core_get_transport(vsk) == &hvs_transport;
+	rcu_read_unlock();
+
+	return ret;
 }
 
 static int hvs_probe(struct hv_device *hdev,
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index bc9d459723f5..9d090f208648 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -259,8 +259,9 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
 	src_cid = t_ops->transport.get_local_cid();
 	src_port = vsk->local_addr.svm_port;
 	if (!info->remote_cid) {
-		dst_cid	= vsk->remote_addr.svm_cid;
-		dst_port = vsk->remote_addr.svm_port;
+		ret = vsock_remote_addr_cid_port(vsk, &dst_cid, &dst_port);
+		if (ret < 0)
+			return ret;
 	} else {
 		dst_cid = info->remote_cid;
 		dst_port = info->remote_port;
@@ -878,12 +879,14 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
 
 int
-virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
+virtio_transport_dgram_enqueue(const struct vsock_transport *transport,
+			       struct vsock_sock *vsk,
 			       struct sockaddr_vm *remote_addr,
 			       struct msghdr *msg,
 			       size_t dgram_len)
 {
-	const struct virtio_transport *t_ops;
+	const struct virtio_transport *t_ops =
+		(const struct virtio_transport *)transport;
 	struct virtio_vsock_pkt_info info = {
 		.op = VIRTIO_VSOCK_OP_RW,
 		.msg = msg,
@@ -897,7 +900,6 @@ virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
 	if (dgram_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
 		return -EMSGSIZE;
 
-	t_ops = virtio_transport_get_ops(vsk);
 	src_cid = t_ops->transport.get_local_cid();
 	src_port = vsk->local_addr.svm_port;
 
@@ -1121,7 +1123,11 @@ virtio_transport_recv_connecting(struct sock *sk,
 	case VIRTIO_VSOCK_OP_RESPONSE:
 		sk->sk_state = TCP_ESTABLISHED;
 		sk->sk_socket->state = SS_CONNECTED;
-		vsock_insert_connected(vsk);
+		err = vsock_insert_connected(vsk);
+		if (err) {
+			skerr = ECONNRESET;
+			goto destroy;
+		}
 		sk->sk_state_change(sk);
 		break;
 	case VIRTIO_VSOCK_OP_INVALID:
@@ -1323,6 +1329,7 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
 	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
 	struct vsock_sock *vsk = vsock_sk(sk);
 	struct vsock_sock *vchild;
+	struct sockaddr_vm child_remote;
 	struct sock *child;
 	int ret;
 
@@ -1351,14 +1358,13 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
 	vchild = vsock_sk(child);
 	vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid),
 			le32_to_cpu(hdr->dst_port));
-	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid),
+	vsock_addr_init(&child_remote, le64_to_cpu(hdr->src_cid),
 			le32_to_cpu(hdr->src_port));
-
-	ret = vsock_assign_transport(vchild, vsk);
+	ret = vsock_assign_transport(vchild, vsk, &child_remote);
 	/* Transport assigned (looking at remote_addr) must be the same
 	 * where we received the request.
 	 */
-	if (ret || vchild->transport != &t->transport) {
+	if (ret || vsock_core_get_transport(vchild) != &t->transport) {
 		release_sock(child);
 		virtio_transport_reset_no_sock(t, skb);
 		sock_put(child);
@@ -1368,7 +1374,13 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
 	if (virtio_transport_space_update(child, skb))
 		child->sk_write_space(child);
 
-	vsock_insert_connected(vchild);
+	ret = vsock_insert_connected(vchild);
+	if (ret) {
+		release_sock(child);
+		virtio_transport_reset_no_sock(t, skb);
+		sock_put(child);
+		return ret;
+	}
 	vsock_enqueue_accept(sk, child);
 	virtio_transport_send_response(vchild, skb);
 
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
index bbc63826bf48..943539857ccb 100644
--- a/net/vmw_vsock/vmci_transport.c
+++ b/net/vmw_vsock/vmci_transport.c
@@ -283,18 +283,25 @@ vmci_transport_send_control_pkt(struct sock *sk,
 				u16 proto,
 				struct vmci_handle handle)
 {
+	struct sockaddr_vm addr_stack;
+	struct sockaddr_vm *remote_addr = &addr_stack;
 	struct vsock_sock *vsk;
+	int err;
 
 	vsk = vsock_sk(sk);
 
 	if (!vsock_addr_bound(&vsk->local_addr))
 		return -EINVAL;
 
-	if (!vsock_addr_bound(&vsk->remote_addr))
+	if (!vsock_remote_addr_bound(vsk))
 		return -EINVAL;
 
+	err = vsock_remote_addr_copy(vsk, remote_addr);
+	if (err < 0)
+		return err;
+
 	return vmci_transport_alloc_send_control_pkt(&vsk->local_addr,
-						     &vsk->remote_addr,
+						     remote_addr,
 						     type, size, mode,
 						     wait, proto, handle);
 }
@@ -317,6 +324,7 @@ static int vmci_transport_send_reset(struct sock *sk,
 	struct sockaddr_vm *dst_ptr;
 	struct sockaddr_vm dst;
 	struct vsock_sock *vsk;
+	int err;
 
 	if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST)
 		return 0;
@@ -326,13 +334,16 @@ static int vmci_transport_send_reset(struct sock *sk,
 	if (!vsock_addr_bound(&vsk->local_addr))
 		return -EINVAL;
 
-	if (vsock_addr_bound(&vsk->remote_addr)) {
-		dst_ptr = &vsk->remote_addr;
+	if (vsock_remote_addr_bound(vsk)) {
+		err = vsock_remote_addr_copy(vsk, &dst);
+		if (err < 0)
+			return err;
 	} else {
 		vsock_addr_init(&dst, pkt->dg.src.context,
 				pkt->src_port);
-		dst_ptr = &dst;
 	}
+	dst_ptr = &dst;
+
 	return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, dst_ptr,
 					     VMCI_TRANSPORT_PACKET_TYPE_RST,
 					     0, 0, NULL, VSOCK_PROTO_INVALID,
@@ -490,7 +501,7 @@ static struct sock *vmci_transport_get_pending(
 
 	list_for_each_entry(vpending, &vlistener->pending_links,
 			    pending_links) {
-		if (vsock_addr_equals_addr(&src, &vpending->remote_addr) &&
+		if (vsock_remote_addr_equals(vpending, &src) &&
 		    pkt->dst_port == vpending->local_addr.svm_port) {
 			pending = sk_vsock(vpending);
 			sock_hold(pending);
@@ -940,6 +951,7 @@ static void vmci_transport_recv_pkt_work(struct work_struct *work)
 static int vmci_transport_recv_listen(struct sock *sk,
 				      struct vmci_transport_packet *pkt)
 {
+	struct sockaddr_vm remote_addr;
 	struct sock *pending;
 	struct vsock_sock *vpending;
 	int err;
@@ -1015,10 +1027,10 @@ static int vmci_transport_recv_listen(struct sock *sk,
 
 	vsock_addr_init(&vpending->local_addr, pkt->dg.dst.context,
 			pkt->dst_port);
-	vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context,
-			pkt->src_port);
 
-	err = vsock_assign_transport(vpending, vsock_sk(sk));
+	vsock_addr_init(&remote_addr, pkt->dg.src.context, pkt->src_port);
+
+	err = vsock_assign_transport(vpending, vsock_sk(sk), &remote_addr);
 	/* Transport assigned (looking at remote_addr) must be the same
 	 * where we received the request.
 	 */
@@ -1133,6 +1145,7 @@ vmci_transport_recv_connecting_server(struct sock *listener,
 {
 	struct vsock_sock *vpending;
 	struct vmci_handle handle;
+	unsigned int vpending_remote_cid;
 	struct vmci_qp *qpair;
 	bool is_local;
 	u32 flags;
@@ -1189,8 +1202,13 @@ vmci_transport_recv_connecting_server(struct sock *listener,
 	/* vpending->local_addr always has a context id so we do not need to
 	 * worry about VMADDR_CID_ANY in this case.
 	 */
-	is_local =
-	    vpending->remote_addr.svm_cid == vpending->local_addr.svm_cid;
+	err = vsock_remote_addr_cid(vpending, &vpending_remote_cid);
+	if (err < 0) {
+		skerr = EPROTO;
+		goto destroy;
+	}
+
+	is_local = vpending_remote_cid == vpending->local_addr.svm_cid;
 	flags = VMCI_QPFLAG_ATTACH_ONLY;
 	flags |= is_local ? VMCI_QPFLAG_LOCAL : 0;
 
@@ -1203,7 +1221,7 @@ vmci_transport_recv_connecting_server(struct sock *listener,
 					flags,
 					vmci_transport_is_trusted(
 						vpending,
-						vpending->remote_addr.svm_cid));
+						vpending_remote_cid));
 	if (err < 0) {
 		vmci_transport_send_reset(pending, pkt);
 		skerr = -err;
@@ -1277,6 +1295,8 @@ static int
 vmci_transport_recv_connecting_client(struct sock *sk,
 				      struct vmci_transport_packet *pkt)
 {
+	struct vsock_remote_info *remote_info;
+	struct sockaddr_vm *remote_addr;
 	struct vsock_sock *vsk;
 	int err;
 	int skerr;
@@ -1306,9 +1326,20 @@ vmci_transport_recv_connecting_client(struct sock *sk,
 		break;
 	case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE:
 	case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2:
+		rcu_read_lock();
+		remote_info = vsock_core_get_remote_info(vsk);
+		if (!remote_info) {
+			skerr = EPROTO;
+			err = -EINVAL;
+			rcu_read_unlock();
+			goto destroy;
+		}
+
+		remote_addr = &remote_info->addr;
+
 		if (pkt->u.size == 0
-		    || pkt->dg.src.context != vsk->remote_addr.svm_cid
-		    || pkt->src_port != vsk->remote_addr.svm_port
+		    || pkt->dg.src.context != remote_addr->svm_cid
+		    || pkt->src_port != remote_addr->svm_port
 		    || !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)
 		    || vmci_trans(vsk)->qpair
 		    || vmci_trans(vsk)->produce_size != 0
@@ -1316,9 +1347,10 @@ vmci_transport_recv_connecting_client(struct sock *sk,
 		    || vmci_trans(vsk)->detach_sub_id != VMCI_INVALID_ID) {
 			skerr = EPROTO;
 			err = -EINVAL;
-
+			rcu_read_unlock();
 			goto destroy;
 		}
+		rcu_read_unlock();
 
 		err = vmci_transport_recv_connecting_client_negotiate(sk, pkt);
 		if (err) {
@@ -1379,6 +1411,7 @@ static int vmci_transport_recv_connecting_client_negotiate(
 	int err;
 	struct vsock_sock *vsk;
 	struct vmci_handle handle;
+	unsigned int remote_cid;
 	struct vmci_qp *qpair;
 	u32 detach_sub_id;
 	bool is_local;
@@ -1449,19 +1482,23 @@ static int vmci_transport_recv_connecting_client_negotiate(
 
 	/* Make VMCI select the handle for us. */
 	handle = VMCI_INVALID_HANDLE;
-	is_local = vsk->remote_addr.svm_cid == vsk->local_addr.svm_cid;
+
+	err = vsock_remote_addr_cid(vsk, &remote_cid);
+	if (err < 0)
+		goto destroy;
+
+	is_local = remote_cid == vsk->local_addr.svm_cid;
 	flags = is_local ? VMCI_QPFLAG_LOCAL : 0;
 
 	err = vmci_transport_queue_pair_alloc(&qpair,
 					      &handle,
 					      pkt->u.size,
 					      pkt->u.size,
-					      vsk->remote_addr.svm_cid,
+					      remote_cid,
 					      flags,
 					      vmci_transport_is_trusted(
 						  vsk,
-						  vsk->
-						  remote_addr.svm_cid));
+						  remote_cid));
 	if (err < 0)
 		goto destroy;
 
@@ -1692,6 +1729,7 @@ static int vmci_transport_dgram_bind(struct vsock_sock *vsk,
 }
 
 static int vmci_transport_dgram_enqueue(
+	const struct vsock_transport *transport,
 	struct vsock_sock *vsk,
 	struct sockaddr_vm *remote_addr,
 	struct msghdr *msg,
@@ -2052,7 +2090,13 @@ static struct vsock_transport vmci_transport = {
 
 static bool vmci_check_transport(struct vsock_sock *vsk)
 {
-	return vsk->transport == &vmci_transport;
+	bool retval;
+
+	rcu_read_lock();
+	retval = vsock_core_get_transport(vsk) == &vmci_transport;
+	rcu_read_unlock();
+
+	return retval;
 }
 
 static void vmci_vsock_transport_cb(bool is_host)
diff --git a/net/vmw_vsock/vsock_bpf.c b/net/vmw_vsock/vsock_bpf.c
index a3c97546ab84..4d811c9cdf6e 100644
--- a/net/vmw_vsock/vsock_bpf.c
+++ b/net/vmw_vsock/vsock_bpf.c
@@ -148,6 +148,7 @@ static void vsock_bpf_check_needs_rebuild(struct proto *ops)
 
 int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
 {
+	const struct vsock_transport *transport;
 	struct vsock_sock *vsk;
 
 	if (restore) {
@@ -157,10 +158,15 @@ int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore
 	}
 
 	vsk = vsock_sk(sk);
-	if (!vsk->transport)
+
+	rcu_read_lock();
+	transport = vsock_core_get_transport(vsk);
+	rcu_read_unlock();
+
+	if (!transport)
 		return -ENODEV;
 
-	if (!vsk->transport->read_skb)
+	if (!transport->read_skb)
 		return -EOPNOTSUPP;
 
 	vsock_bpf_check_needs_rebuild(psock->sk_proto);
-- 
2.30.2