[RFC,net-next,v4,3/8] vsock: support multi-transport datagrams

Message ID 20230413-b4-vsock-dgram-v4-3-0cebbb2ae899@bytedance.com
State New
Series virtio/vsock: support datagrams

Commit Message

Bobby Eshleman June 10, 2023, 12:58 a.m. UTC
This patch adds support for multi-transport datagrams.

This includes:
- Per-packet lookup of transports when using sendto(sockaddr_vm)
- Selecting H2G or G2H transport using VMADDR_FLAG_TO_HOST and CID in
  sockaddr_vm

To preserve backwards compatibility with VMCI, some important changes
were made. The "transport_dgram" / VSOCK_TRANSPORT_F_DGRAM is changed to
be used for dgrams iff there is not yet a g2h or h2g transport that has
been registered that can transmit the packet. If there is a g2h/h2g
transport for that remote address, then that transport will be used and
not "transport_dgram". This essentially makes "transport_dgram" a
fallback transport for when h2g/g2h has not yet gone online, which
appears to be the exact use case for VMCI.

This design makes sense, because there is no reason that the
transport_{g2h,h2g} cannot also service datagrams, which makes the role
of transport_dgram difficult to understand outside of the VMCI context.

The logic around "transport_dgram" had to be retained to prevent
breaking VMCI:

1) VMCI datagrams appear to function outside of the h2g/g2h
   paradigm. When the vmci transport comes online, it registers itself
   with the DGRAM feature, but not H2G/G2H. Only later, when the
   transport has more information about its environment, does it
   register H2G or G2H. If a datagram socket becomes active after DGRAM
   registration but before G2H/H2G registration, the "transport_dgram"
   transport needs to be used.

2) VMCI seems to require a special message to be sent by the transport
   when a datagram socket calls bind(). Under the h2g/g2h model, the
   transport is selected using the remote_addr, which is set by
   connect(). At bind time there is no remote_addr because often no
   connect() has been called yet: the transport is null. Therefore,
   with a null transport, there doesn't seem to be any good way for a
   datagram socket to tell the VMCI transport that it has just had
   bind() called upon it.

Only transports with a special datagram fallback use-case such as VMCI
need to register VSOCK_TRANSPORT_F_DGRAM.

Signed-off-by: Bobby Eshleman <bobby.eshleman@bytedance.com>
---
 drivers/vhost/vsock.c                   |  1 -
 include/linux/virtio_vsock.h            |  2 -
 net/vmw_vsock/af_vsock.c                | 78 +++++++++++++++++++++++++--------
 net/vmw_vsock/hyperv_transport.c        |  6 ---
 net/vmw_vsock/virtio_transport.c        |  1 -
 net/vmw_vsock/virtio_transport_common.c |  7 ---
 net/vmw_vsock/vsock_loopback.c          |  1 -
 7 files changed, 60 insertions(+), 36 deletions(-)
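
The fallback order described in the commit message can be modelled in
user space. This is a minimal sketch, not the kernel code: the struct
names, the registry, and the lookup_* helpers are hypothetical stand-ins,
and the connectible lookup is a simplified approximation of
vsock_connectible_lookup_transport() (the local-transport check in
particular is reduced to a plain CID comparison). The CID and flag
values mirror VMADDR_CID_LOCAL, VMADDR_CID_HOST and VMADDR_FLAG_TO_HOST.

```c
#include <stddef.h>
#include <stdint.h>

/* Stand-in for struct vsock_transport; only a name is needed here. */
struct transport {
	const char *name;
};

/* Stand-in for the global transport_{h2g,g2h,local,dgram} pointers. */
struct registry {
	const struct transport *h2g;
	const struct transport *g2h;
	const struct transport *local;
	const struct transport *dgram;	/* the fallback, e.g. VMCI */
};

#define CID_LOCAL	1u	/* mirrors VMADDR_CID_LOCAL */
#define CID_HOST	2u	/* mirrors VMADDR_CID_HOST */
#define FLAG_TO_HOST	0x01	/* mirrors VMADDR_FLAG_TO_HOST */

/* Simplified model of the h2g/g2h selection used for connectible sockets. */
static const struct transport *
lookup_connectible(const struct registry *r, uint32_t cid, uint8_t flags)
{
	if (r->local && cid == CID_LOCAL)
		return r->local;

	/* Guest-to-host when addressing the host, when no h2g transport
	 * exists, or when the caller explicitly asks for the host.
	 */
	if (cid <= CID_HOST || !r->h2g || (flags & FLAG_TO_HOST))
		return r->g2h;

	return r->h2g;
}

/* Datagram lookup: prefer h2g/g2h, fall back to the dgram transport. */
static const struct transport *
lookup_dgram(const struct registry *r, uint32_t cid, uint8_t flags)
{
	const struct transport *t = lookup_connectible(r, cid, flags);

	return t ? t : r->dgram;
}
```

With only the dgram slot populated, every lookup resolves to the
fallback; once a g2h or h2g transport is present for the destination,
the fallback is no longer chosen for it.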
  

Comments

Stefano Garzarella June 22, 2023, 3:19 p.m. UTC | #1
On Sat, Jun 10, 2023 at 12:58:30AM +0000, Bobby Eshleman wrote:
>This patch adds support for multi-transport datagrams.
>
>This includes:
>- Per-packet lookup of transports when using sendto(sockaddr_vm)
>- Selecting H2G or G2H transport using VMADDR_FLAG_TO_HOST and CID in
>  sockaddr_vm
>
>To preserve backwards compatibility with VMCI, some important changes
>were made. The "transport_dgram" / VSOCK_TRANSPORT_F_DGRAM is changed to
>be used for dgrams iff there is not yet a g2h or h2g transport that has

s/iff/if

>been registered that can transmit the packet. If there is a g2h/h2g
>transport for that remote address, then that transport will be used and
>not "transport_dgram". This essentially makes "transport_dgram" a
>fallback transport for when h2g/g2h has not yet gone online, which
>appears to be the exact use case for VMCI.
>
>This design makes sense, because there is no reason that the
>transport_{g2h,h2g} cannot also service datagrams, which makes the role
>of transport_dgram difficult to understand outside of the VMCI context.
>
>The logic around "transport_dgram" had to be retained to prevent
>breaking VMCI:
>
>1) VMCI datagrams appear to function outside of the h2g/g2h
>   paradigm. When the vmci transport becomes online, it registers itself
>   with the DGRAM feature, but not H2G/G2H. Only later when the
>   transport has more information about its environment does it register
>   H2G or G2H. In the case that a datagram socket becomes active
>   after DGRAM registration but before G2H/H2G registration, the
>   "transport_dgram" transport needs to be used.

IIRC we did this because, at that time, only VMCI supported DGRAM. Now
that there are more transports, maybe DGRAM can follow the h2g/g2h
paradigm.

>
>2) VMCI seems to require a special message to be sent by the transport
>   when a datagram socket calls bind(). Under the h2g/g2h model, the
>   transport is selected using the remote_addr, which is set by
>   connect(). At bind time there is no remote_addr because often no
>   connect() has been called yet: the transport is null. Therefore,
>   with a null transport, there doesn't seem to be any good way for a
>   datagram socket to tell the VMCI transport that it has just had
>   bind() called upon it.

@Vishnu, @Bryan do you think we can avoid this in some way?

>
>Only transports with a special datagram fallback use-case such as VMCI
>need to register VSOCK_TRANSPORT_F_DGRAM.

Maybe we should rename it to VSOCK_TRANSPORT_F_DGRAM_FALLBACK or
something like that.

In any case, we definitely need to update the comment in
include/net/af_vsock.h on top of VSOCK_TRANSPORT_F_DGRAM mentioning
this.
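
To make the registration timeline from point 1) concrete, here is a
small user-space model of a transport claiming feature slots. It is a
sketch under stated assumptions: model_register(), struct registry and
the F_* values are hypothetical, F_DGRAM_FALLBACK stands for the rename
suggested above, and returning -EBUSY for an occupied slot is assumed to
resemble what vsock_core_register() does rather than copied from it.

```c
#include <errno.h>
#include <stddef.h>

/* Illustrative feature bits; the values are assumptions for this sketch. */
#define F_H2G			0x1
#define F_G2H			0x2
#define F_DGRAM_FALLBACK	0x4	/* hypothetical renamed F_DGRAM */

struct transport {
	const char *name;
};

struct registry {
	const struct transport *h2g;
	const struct transport *g2h;
	const struct transport *dgram_fallback;
};

/* Claim every requested slot, or none at all if any is already taken. */
static int model_register(struct registry *r, const struct transport *t,
			  int features)
{
	if ((features & F_H2G) && r->h2g)
		return -EBUSY;
	if ((features & F_G2H) && r->g2h)
		return -EBUSY;
	if ((features & F_DGRAM_FALLBACK) && r->dgram_fallback)
		return -EBUSY;

	if (features & F_H2G)
		r->h2g = t;
	if (features & F_G2H)
		r->g2h = t;
	if (features & F_DGRAM_FALLBACK)
		r->dgram_fallback = t;

	return 0;
}
```

A VMCI-like transport would call this once with F_DGRAM_FALLBACK when it
comes online and again with F_G2H or F_H2G later; in between, only the
fallback slot can serve datagram sockets.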

>
>Signed-off-by: Bobby Eshleman <bobby.eshleman@bytedance.com>
>---
> drivers/vhost/vsock.c                   |  1 -
> include/linux/virtio_vsock.h            |  2 -
> net/vmw_vsock/af_vsock.c                | 78 +++++++++++++++++++++++++--------
> net/vmw_vsock/hyperv_transport.c        |  6 ---
> net/vmw_vsock/virtio_transport.c        |  1 -
> net/vmw_vsock/virtio_transport_common.c |  7 ---
> net/vmw_vsock/vsock_loopback.c          |  1 -
> 7 files changed, 60 insertions(+), 36 deletions(-)
>
>diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
>index c8201c070b4b..8f0082da5e70 100644
>--- a/drivers/vhost/vsock.c
>+++ b/drivers/vhost/vsock.c
>@@ -410,7 +410,6 @@ static struct virtio_transport vhost_transport = {
> 		.cancel_pkt               = vhost_transport_cancel_pkt,
>
> 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
>-		.dgram_bind               = virtio_transport_dgram_bind,
> 		.dgram_allow              = virtio_transport_dgram_allow,
> 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
> 		.dgram_get_port		  = virtio_transport_dgram_get_port,
>diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
>index 23521a318cf0..73afa09f4585 100644
>--- a/include/linux/virtio_vsock.h
>+++ b/include/linux/virtio_vsock.h
>@@ -216,8 +216,6 @@ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val);
> u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk);
> bool virtio_transport_stream_is_active(struct vsock_sock *vsk);
> bool virtio_transport_stream_allow(u32 cid, u32 port);
>-int virtio_transport_dgram_bind(struct vsock_sock *vsk,
>-				struct sockaddr_vm *addr);
> bool virtio_transport_dgram_allow(u32 cid, u32 port);
> int virtio_transport_dgram_get_cid(struct sk_buff *skb, unsigned int *cid);
> int virtio_transport_dgram_get_port(struct sk_buff *skb, unsigned int *port);
>diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
>index 74358f0b47fa..ef86765f3765 100644
>--- a/net/vmw_vsock/af_vsock.c
>+++ b/net/vmw_vsock/af_vsock.c
>@@ -438,6 +438,18 @@ vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
> 	return transport;
> }
>
>+static const struct vsock_transport *
>+vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
>+{
>+	const struct vsock_transport *transport;
>+
>+	transport = vsock_connectible_lookup_transport(cid, flags);
>+	if (transport)
>+		return transport;
>+
>+	return transport_dgram;
>+}
>+
> /* Assign a transport to a socket and call the .init transport callback.
>  *
>  * Note: for connection oriented socket this must be called when vsk->remote_addr
>@@ -474,7 +486,8 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
>
> 	switch (sk->sk_type) {
> 	case SOCK_DGRAM:
>-		new_transport = transport_dgram;
>+		new_transport = vsock_dgram_lookup_transport(remote_cid,
>+							     remote_flags);
> 		break;
> 	case SOCK_STREAM:
> 	case SOCK_SEQPACKET:
>@@ -691,6 +704,9 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
> static int __vsock_bind_dgram(struct vsock_sock *vsk,
> 			      struct sockaddr_vm *addr)
> {
>+	if (!vsk->transport || !vsk->transport->dgram_bind)
>+		return -EINVAL;
>+
> 	return vsk->transport->dgram_bind(vsk, addr);
> }
>
>@@ -1172,19 +1188,24 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
>
> 	lock_sock(sk);
>
>-	transport = vsk->transport;
>-
>-	err = vsock_auto_bind(vsk);
>-	if (err)
>-		goto out;
>-
>-
> 	/* If the provided message contains an address, use that.  Otherwise
> 	 * fall back on the socket's remote handle (if it has been connected).
> 	 */
> 	if (msg->msg_name &&
> 	    vsock_addr_cast(msg->msg_name, msg->msg_namelen,
> 			    &remote_addr) == 0) {
>+		transport = vsock_dgram_lookup_transport(remote_addr->svm_cid,
>+							 remote_addr->svm_flags);
>+		if (!transport) {
>+			err = -EINVAL;
>+			goto out;
>+		}
>+
>+		if (!try_module_get(transport->module)) {
>+			err = -ENODEV;
>+			goto out;
>+		}
>+
> 		/* Ensure this address is of the right type and is a valid
> 		 * destination.
> 		 */
>@@ -1193,11 +1214,27 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
> 			remote_addr->svm_cid = transport->get_local_cid();
>

From here ...

> 		if (!vsock_addr_bound(remote_addr)) {
>+			module_put(transport->module);
>+			err = -EINVAL;
>+			goto out;
>+		}
>+
>+		if (!transport->dgram_allow(remote_addr->svm_cid,
>+					    remote_addr->svm_port)) {
>+			module_put(transport->module);
> 			err = -EINVAL;
> 			goto out;
> 		}
>+
>+		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);

... to here, this looks like duplicate code. Can we move it out of the if
block?

>+		module_put(transport->module);
> 	} else if (sock->state == SS_CONNECTED) {
> 		remote_addr = &vsk->remote_addr;
>+		transport = vsk->transport;
>+
>+		err = vsock_auto_bind(vsk);
>+		if (err)
>+			goto out;
>
> 		if (remote_addr->svm_cid == VMADDR_CID_ANY)
> 			remote_addr->svm_cid = transport->get_local_cid();
>@@ -1205,23 +1242,23 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
> 		/* XXX Should connect() or this function ensure remote_addr is
> 		 * bound?
> 		 */
>-		if (!vsock_addr_bound(&vsk->remote_addr)) {
>+		if (!vsock_addr_bound(remote_addr)) {
> 			err = -EINVAL;
> 			goto out;
> 		}
>-	} else {
>-		err = -EINVAL;
>-		goto out;
>-	}
>
>-	if (!transport->dgram_allow(remote_addr->svm_cid,
>-				    remote_addr->svm_port)) {
>+		if (!transport->dgram_allow(remote_addr->svm_cid,
>+					    remote_addr->svm_port)) {
>+			err = -EINVAL;
>+			goto out;
>+		}
>+
>+		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
>+	} else {
> 		err = -EINVAL;
> 		goto out;
> 	}
>
>-	err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
>-
> out:
> 	release_sock(sk);
> 	return err;
>@@ -1255,13 +1292,18 @@ static int vsock_dgram_connect(struct socket *sock,
> 	if (err)
> 		goto out;
>
>+	memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
>+
>+	err = vsock_assign_transport(vsk, NULL);
>+	if (err)
>+		goto out;
>+
> 	if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
> 					 remote_addr->svm_port)) {
> 		err = -EINVAL;
> 		goto out;
> 	}
>
>-	memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
> 	sock->state = SS_CONNECTED;
>
> 	/* sock map disallows redirection of non-TCP sockets with sk_state !=
>diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
>index ff6e87e25fa0..c00bc5da769a 100644
>--- a/net/vmw_vsock/hyperv_transport.c
>+++ b/net/vmw_vsock/hyperv_transport.c
>@@ -551,11 +551,6 @@ static void hvs_destruct(struct vsock_sock *vsk)
> 	kfree(hvs);
> }
>
>-static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
>-{
>-	return -EOPNOTSUPP;
>-}
>-
> static int hvs_dgram_get_cid(struct sk_buff *skb, unsigned int *cid)
> {
> 	return -EOPNOTSUPP;
>@@ -841,7 +836,6 @@ static struct vsock_transport hvs_transport = {
> 	.connect                  = hvs_connect,
> 	.shutdown                 = hvs_shutdown,
>
>-	.dgram_bind               = hvs_dgram_bind,
> 	.dgram_get_cid		  = hvs_dgram_get_cid,
> 	.dgram_get_port		  = hvs_dgram_get_port,
> 	.dgram_get_length	  = hvs_dgram_get_length,
>diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
>index 5763cdf13804..1b7843a7779a 100644
>--- a/net/vmw_vsock/virtio_transport.c
>+++ b/net/vmw_vsock/virtio_transport.c
>@@ -428,7 +428,6 @@ static struct virtio_transport virtio_transport = {
> 		.shutdown                 = virtio_transport_shutdown,
> 		.cancel_pkt               = virtio_transport_cancel_pkt,
>
>-		.dgram_bind               = virtio_transport_dgram_bind,
> 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
> 		.dgram_allow              = virtio_transport_dgram_allow,
> 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
>diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
>index e6903c719964..d5a3c8efe84b 100644
>--- a/net/vmw_vsock/virtio_transport_common.c
>+++ b/net/vmw_vsock/virtio_transport_common.c
>@@ -790,13 +790,6 @@ bool virtio_transport_stream_allow(u32 cid, u32 port)
> }
> EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
>
>-int virtio_transport_dgram_bind(struct vsock_sock *vsk,
>-				struct sockaddr_vm *addr)
>-{
>-	return -EOPNOTSUPP;
>-}
>-EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
>-
> int virtio_transport_dgram_get_cid(struct sk_buff *skb, unsigned int *cid)
> {
> 	return -EOPNOTSUPP;
>diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c
>index 2f3cabc79ee5..e9de45a26fbd 100644
>--- a/net/vmw_vsock/vsock_loopback.c
>+++ b/net/vmw_vsock/vsock_loopback.c
>@@ -61,7 +61,6 @@ static struct virtio_transport loopback_transport = {
> 		.shutdown                 = virtio_transport_shutdown,
> 		.cancel_pkt               = vsock_loopback_cancel_pkt,
>
>-		.dgram_bind               = virtio_transport_dgram_bind,
> 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
> 		.dgram_allow              = virtio_transport_dgram_allow,
> 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
>
>-- 
>2.30.2
>

The rest LGTM!

Stefano
  
Bobby Eshleman June 23, 2023, 2:50 a.m. UTC | #2
On Thu, Jun 22, 2023 at 05:19:08PM +0200, Stefano Garzarella wrote:
> On Sat, Jun 10, 2023 at 12:58:30AM +0000, Bobby Eshleman wrote:
> > This patch adds support for multi-transport datagrams.
> > 
> > This includes:
> > - Per-packet lookup of transports when using sendto(sockaddr_vm)
> > - Selecting H2G or G2H transport using VMADDR_FLAG_TO_HOST and CID in
> >  sockaddr_vm
> > 
> > To preserve backwards compatibility with VMCI, some important changes
> > were made. The "transport_dgram" / VSOCK_TRANSPORT_F_DGRAM is changed to
> > be used for dgrams iff there is not yet a g2h or h2g transport that has
> 
> s/iff/if
> 
> > been registered that can transmit the packet. If there is a g2h/h2g
> > transport for that remote address, then that transport will be used and
> > not "transport_dgram". This essentially makes "transport_dgram" a
> > fallback transport for when h2g/g2h has not yet gone online, which
> > appears to be the exact use case for VMCI.
> > 
> > This design makes sense, because there is no reason that the
> > transport_{g2h,h2g} cannot also service datagrams, which makes the role
> > of transport_dgram difficult to understand outside of the VMCI context.
> > 
> > The logic around "transport_dgram" had to be retained to prevent
> > breaking VMCI:
> > 
> > 1) VMCI datagrams appear to function outside of the h2g/g2h
> >   paradigm. When the vmci transport becomes online, it registers itself
> >   with the DGRAM feature, but not H2G/G2H. Only later when the
> >   transport has more information about its environment does it register
> >   H2G or G2H. In the case that a datagram socket becomes active
> >   after DGRAM registration but before G2H/H2G registration, the
> >   "transport_dgram" transport needs to be used.
> 
> IIRC we did this, because at that time only VMCI supported DGRAM. Now that
> there are more transports, maybe DGRAM can follow the h2g/g2h paradigm.
> 

Totally makes sense. I'll add the detail above that the prior design was
a result of chronology.

> > 
> > 2) VMCI seems to require a special message to be sent by the transport
> >   when a datagram socket calls bind(). Under the h2g/g2h model, the
> >   transport is selected using the remote_addr, which is set by
> >   connect(). At bind time there is no remote_addr because often no
> >   connect() has been called yet: the transport is null. Therefore,
> >   with a null transport, there doesn't seem to be any good way for a
> >   datagram socket to tell the VMCI transport that it has just had
> >   bind() called upon it.
> 
> @Vishnu, @Bryan do you think we can avoid this in some way?
> 
> > 
> > Only transports with a special datagram fallback use-case such as VMCI
> > need to register VSOCK_TRANSPORT_F_DGRAM.
> 
> Maybe we should rename it to VSOCK_TRANSPORT_F_DGRAM_FALLBACK or
> something like that.
> 
> In any case, we definitely need to update the comment in
> include/net/af_vsock.h on top of VSOCK_TRANSPORT_F_DGRAM mentioning
> this.
> 

Agreed. I'll rename to VSOCK_TRANSPORT_F_DGRAM_FALLBACK, unless we find
there is a better way altogether.

> > 
> > Signed-off-by: Bobby Eshleman <bobby.eshleman@bytedance.com>
> > ---
> > drivers/vhost/vsock.c                   |  1 -
> > include/linux/virtio_vsock.h            |  2 -
> > net/vmw_vsock/af_vsock.c                | 78 +++++++++++++++++++++++++--------
> > net/vmw_vsock/hyperv_transport.c        |  6 ---
> > net/vmw_vsock/virtio_transport.c        |  1 -
> > net/vmw_vsock/virtio_transport_common.c |  7 ---
> > net/vmw_vsock/vsock_loopback.c          |  1 -
> > 7 files changed, 60 insertions(+), 36 deletions(-)
> > 
> > diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
> > index c8201c070b4b..8f0082da5e70 100644
> > --- a/drivers/vhost/vsock.c
> > +++ b/drivers/vhost/vsock.c
> > @@ -410,7 +410,6 @@ static struct virtio_transport vhost_transport = {
> > 		.cancel_pkt               = vhost_transport_cancel_pkt,
> > 
> > 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
> > -		.dgram_bind               = virtio_transport_dgram_bind,
> > 		.dgram_allow              = virtio_transport_dgram_allow,
> > 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
> > 		.dgram_get_port		  = virtio_transport_dgram_get_port,
> > diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
> > index 23521a318cf0..73afa09f4585 100644
> > --- a/include/linux/virtio_vsock.h
> > +++ b/include/linux/virtio_vsock.h
> > @@ -216,8 +216,6 @@ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val);
> > u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk);
> > bool virtio_transport_stream_is_active(struct vsock_sock *vsk);
> > bool virtio_transport_stream_allow(u32 cid, u32 port);
> > -int virtio_transport_dgram_bind(struct vsock_sock *vsk,
> > -				struct sockaddr_vm *addr);
> > bool virtio_transport_dgram_allow(u32 cid, u32 port);
> > int virtio_transport_dgram_get_cid(struct sk_buff *skb, unsigned int *cid);
> > int virtio_transport_dgram_get_port(struct sk_buff *skb, unsigned int *port);
> > diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
> > index 74358f0b47fa..ef86765f3765 100644
> > --- a/net/vmw_vsock/af_vsock.c
> > +++ b/net/vmw_vsock/af_vsock.c
> > @@ -438,6 +438,18 @@ vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
> > 	return transport;
> > }
> > 
> > +static const struct vsock_transport *
> > +vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
> > +{
> > +	const struct vsock_transport *transport;
> > +
> > +	transport = vsock_connectible_lookup_transport(cid, flags);
> > +	if (transport)
> > +		return transport;
> > +
> > +	return transport_dgram;
> > +}
> > +
> > /* Assign a transport to a socket and call the .init transport callback.
> >  *
> >  * Note: for connection oriented socket this must be called when vsk->remote_addr
> > @@ -474,7 +486,8 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
> > 
> > 	switch (sk->sk_type) {
> > 	case SOCK_DGRAM:
> > -		new_transport = transport_dgram;
> > +		new_transport = vsock_dgram_lookup_transport(remote_cid,
> > +							     remote_flags);
> > 		break;
> > 	case SOCK_STREAM:
> > 	case SOCK_SEQPACKET:
> > @@ -691,6 +704,9 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
> > static int __vsock_bind_dgram(struct vsock_sock *vsk,
> > 			      struct sockaddr_vm *addr)
> > {
> > +	if (!vsk->transport || !vsk->transport->dgram_bind)
> > +		return -EINVAL;
> > +
> > 	return vsk->transport->dgram_bind(vsk, addr);
> > }
> > 
> > @@ -1172,19 +1188,24 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
> > 
> > 	lock_sock(sk);
> > 
> > -	transport = vsk->transport;
> > -
> > -	err = vsock_auto_bind(vsk);
> > -	if (err)
> > -		goto out;
> > -
> > -
> > 	/* If the provided message contains an address, use that.  Otherwise
> > 	 * fall back on the socket's remote handle (if it has been connected).
> > 	 */
> > 	if (msg->msg_name &&
> > 	    vsock_addr_cast(msg->msg_name, msg->msg_namelen,
> > 			    &remote_addr) == 0) {
> > +		transport = vsock_dgram_lookup_transport(remote_addr->svm_cid,
> > +							 remote_addr->svm_flags);
> > +		if (!transport) {
> > +			err = -EINVAL;
> > +			goto out;
> > +		}
> > +
> > +		if (!try_module_get(transport->module)) {
> > +			err = -ENODEV;
> > +			goto out;
> > +		}
> > +
> > 		/* Ensure this address is of the right type and is a valid
> > 		 * destination.
> > 		 */
> > @@ -1193,11 +1214,27 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
> > 			remote_addr->svm_cid = transport->get_local_cid();
> > 
> 
> From here ...
> 
> > 		if (!vsock_addr_bound(remote_addr)) {
> > +			module_put(transport->module);
> > +			err = -EINVAL;
> > +			goto out;
> > +		}
> > +
> > +		if (!transport->dgram_allow(remote_addr->svm_cid,
> > +					    remote_addr->svm_port)) {
> > +			module_put(transport->module);
> > 			err = -EINVAL;
> > 			goto out;
> > 		}
> > +
> > +		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
> 
> ... to here, looks like duplicate code, can we get it out of the if block?
> 

Yes, I think something like this would work:

[...]
	bool module_got = false;

[...]
		if (!try_module_get(transport->module)) {
			err = -ENODEV;
			goto out;
		}
		module_got = true;

[...]

out:
	if (likely(transport && !err && module_got))
		module_put(transport->module);
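
For reference, the single-exit cleanup idea can be exercised in user
space. This is a hedged sketch, not the proposed kernel change:
model_dgram_send(), model_try_module_get(), model_module_put() and the
structs are hypothetical stand-ins. One deliberate difference from the
snippet above: the reference is dropped whenever it was taken, including
on error paths, because the v4 diff calls module_put() before each error
goto. Whether the final patch should do the same is an assumption here.

```c
#include <errno.h>
#include <stdbool.h>
#include <stddef.h>

/* Stand-in for struct module with a visible reference count. */
struct module {
	int refcnt;
	bool live;		/* false models a module being unloaded */
};

/* Stand-in for a transport; callbacks are reduced to fixed results. */
struct transport {
	struct module *module;
	bool allow;		/* result of ->dgram_allow() */
	int enqueue_ret;	/* result of ->dgram_enqueue() */
};

static bool model_try_module_get(struct module *m)
{
	if (!m->live)
		return false;
	m->refcnt++;
	return true;
}

static void model_module_put(struct module *m)
{
	m->refcnt--;
}

/* per_packet is true when the destination came from msg_name (sendto()),
 * false when the socket's connected transport is reused.
 */
static int model_dgram_send(struct transport *t, bool per_packet,
			    bool addr_bound)
{
	bool module_got = false;
	int err;

	if (!t) {
		err = -EINVAL;
		goto out;
	}

	if (per_packet) {
		if (!model_try_module_get(t->module)) {
			err = -ENODEV;
			goto out;
		}
		module_got = true;
	}

	/* Checks shared by both paths, written only once. */
	if (!addr_bound || !t->allow) {
		err = -EINVAL;
		goto out;
	}

	err = t->enqueue_ret;

out:
	if (module_got)
		model_module_put(t->module);
	return err;
}
```

Tracking the reference with module_got keeps the bound/allow/enqueue
steps common to both branches while leaving the count balanced on every
exit.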

> > +		module_put(transport->module);
> > 	} else if (sock->state == SS_CONNECTED) {
> > 		remote_addr = &vsk->remote_addr;
> > +		transport = vsk->transport;
> > +
> > +		err = vsock_auto_bind(vsk);
> > +		if (err)
> > +			goto out;
> > 
> > 		if (remote_addr->svm_cid == VMADDR_CID_ANY)
> > 			remote_addr->svm_cid = transport->get_local_cid();
> > @@ -1205,23 +1242,23 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
> > 		/* XXX Should connect() or this function ensure remote_addr is
> > 		 * bound?
> > 		 */
> > -		if (!vsock_addr_bound(&vsk->remote_addr)) {
> > +		if (!vsock_addr_bound(remote_addr)) {
> > 			err = -EINVAL;
> > 			goto out;
> > 		}
> > -	} else {
> > -		err = -EINVAL;
> > -		goto out;
> > -	}
> > 
> > -	if (!transport->dgram_allow(remote_addr->svm_cid,
> > -				    remote_addr->svm_port)) {
> > +		if (!transport->dgram_allow(remote_addr->svm_cid,
> > +					    remote_addr->svm_port)) {
> > +			err = -EINVAL;
> > +			goto out;
> > +		}
> > +
> > +		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
> > +	} else {
> > 		err = -EINVAL;
> > 		goto out;
> > 	}
> > 
> > -	err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
> > -
> > out:
> > 	release_sock(sk);
> > 	return err;
> > @@ -1255,13 +1292,18 @@ static int vsock_dgram_connect(struct socket *sock,
> > 	if (err)
> > 		goto out;
> > 
> > +	memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
> > +
> > +	err = vsock_assign_transport(vsk, NULL);
> > +	if (err)
> > +		goto out;
> > +
> > 	if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
> > 					 remote_addr->svm_port)) {
> > 		err = -EINVAL;
> > 		goto out;
> > 	}
> > 
> > -	memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
> > 	sock->state = SS_CONNECTED;
> > 
> > 	/* sock map disallows redirection of non-TCP sockets with sk_state !=
> > diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
> > index ff6e87e25fa0..c00bc5da769a 100644
> > --- a/net/vmw_vsock/hyperv_transport.c
> > +++ b/net/vmw_vsock/hyperv_transport.c
> > @@ -551,11 +551,6 @@ static void hvs_destruct(struct vsock_sock *vsk)
> > 	kfree(hvs);
> > }
> > 
> > -static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
> > -{
> > -	return -EOPNOTSUPP;
> > -}
> > -
> > static int hvs_dgram_get_cid(struct sk_buff *skb, unsigned int *cid)
> > {
> > 	return -EOPNOTSUPP;
> > @@ -841,7 +836,6 @@ static struct vsock_transport hvs_transport = {
> > 	.connect                  = hvs_connect,
> > 	.shutdown                 = hvs_shutdown,
> > 
> > -	.dgram_bind               = hvs_dgram_bind,
> > 	.dgram_get_cid		  = hvs_dgram_get_cid,
> > 	.dgram_get_port		  = hvs_dgram_get_port,
> > 	.dgram_get_length	  = hvs_dgram_get_length,
> > diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
> > index 5763cdf13804..1b7843a7779a 100644
> > --- a/net/vmw_vsock/virtio_transport.c
> > +++ b/net/vmw_vsock/virtio_transport.c
> > @@ -428,7 +428,6 @@ static struct virtio_transport virtio_transport = {
> > 		.shutdown                 = virtio_transport_shutdown,
> > 		.cancel_pkt               = virtio_transport_cancel_pkt,
> > 
> > -		.dgram_bind               = virtio_transport_dgram_bind,
> > 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
> > 		.dgram_allow              = virtio_transport_dgram_allow,
> > 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
> > diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
> > index e6903c719964..d5a3c8efe84b 100644
> > --- a/net/vmw_vsock/virtio_transport_common.c
> > +++ b/net/vmw_vsock/virtio_transport_common.c
> > @@ -790,13 +790,6 @@ bool virtio_transport_stream_allow(u32 cid, u32 port)
> > }
> > EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
> > 
> > -int virtio_transport_dgram_bind(struct vsock_sock *vsk,
> > -				struct sockaddr_vm *addr)
> > -{
> > -	return -EOPNOTSUPP;
> > -}
> > -EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
> > -
> > int virtio_transport_dgram_get_cid(struct sk_buff *skb, unsigned int *cid)
> > {
> > 	return -EOPNOTSUPP;
> > diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c
> > index 2f3cabc79ee5..e9de45a26fbd 100644
> > --- a/net/vmw_vsock/vsock_loopback.c
> > +++ b/net/vmw_vsock/vsock_loopback.c
> > @@ -61,7 +61,6 @@ static struct virtio_transport loopback_transport = {
> > 		.shutdown                 = virtio_transport_shutdown,
> > 		.cancel_pkt               = vsock_loopback_cancel_pkt,
> > 
> > -		.dgram_bind               = virtio_transport_dgram_bind,
> > 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
> > 		.dgram_allow              = virtio_transport_dgram_allow,
> > 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
> > 
> > -- 
> > 2.30.2
> > 
> 
> The rest LGTM!
> 
> Stefano

Thanks,
Bobby
  
Bobby Eshleman June 23, 2023, 2:59 a.m. UTC | #3
On Fri, Jun 23, 2023 at 02:50:01AM +0000, Bobby Eshleman wrote:
> On Thu, Jun 22, 2023 at 05:19:08PM +0200, Stefano Garzarella wrote:
> > On Sat, Jun 10, 2023 at 12:58:30AM +0000, Bobby Eshleman wrote:
> > > This patch adds support for multi-transport datagrams.
> > > 
> > > This includes:
> > > - Per-packet lookup of transports when using sendto(sockaddr_vm)
> > > - Selecting H2G or G2H transport using VMADDR_FLAG_TO_HOST and CID in
> > >  sockaddr_vm
> > > 
> > > To preserve backwards compatibility with VMCI, some important changes
> > > were made. The "transport_dgram" / VSOCK_TRANSPORT_F_DGRAM is changed to
> > > be used for dgrams iff there is not yet a g2h or h2g transport that has
> > 
> > s/iff/if
> > 
> > > been registered that can transmit the packet. If there is a g2h/h2g
> > > transport for that remote address, then that transport will be used and
> > > not "transport_dgram". This essentially makes "transport_dgram" a
> > > fallback transport for when h2g/g2h has not yet gone online, which
> > > appears to be the exact use case for VMCI.
> > > 
> > > This design makes sense, because there is no reason that the
> > > transport_{g2h,h2g} cannot also service datagrams, which makes the role
> > > of transport_dgram difficult to understand outside of the VMCI context.
> > > 
> > > The logic around "transport_dgram" had to be retained to prevent
> > > breaking VMCI:
> > > 
> > > 1) VMCI datagrams appear to function outside of the h2g/g2h
> > >   paradigm. When the vmci transport becomes online, it registers itself
> > >   with the DGRAM feature, but not H2G/G2H. Only later when the
> > >   transport has more information about its environment does it register
> > >   H2G or G2H. In the case that a datagram socket becomes active
> > >   after DGRAM registration but before G2H/H2G registration, the
> > >   "transport_dgram" transport needs to be used.
> > 
> > IIRC we did this, because at that time only VMCI supported DGRAM. Now that
> > there are more transports, maybe DGRAM can follow the h2g/g2h paradigm.
> > 
> 
> Totally makes sense. I'll add the detail above that the prior design was
> a result of chronology.
> 
> > > 
> > > 2) VMCI seems to require a special message to be sent by the transport
> > >   when a datagram socket calls bind(). Under the h2g/g2h model, the
> > >   transport is selected using the remote_addr, which is set by
> > >   connect(). At bind time there is no remote_addr because often no
> > >   connect() has been called yet: the transport is null. Therefore,
> > >   with a null transport, there doesn't seem to be any good way for a
> > >   datagram socket to tell the VMCI transport that it has just had
> > >   bind() called upon it.
> > 
> > @Vishnu, @Bryan do you think we can avoid this in some way?
> > 
> > > 
> > > Only transports with a special datagram fallback use-case such as VMCI
> > > need to register VSOCK_TRANSPORT_F_DGRAM.
> > 
> > Maybe we should rename it to VSOCK_TRANSPORT_F_DGRAM_FALLBACK or
> > something like that.
> > 
> > In any case, we definitely need to update the comment in
> > include/net/af_vsock.h on top of VSOCK_TRANSPORT_F_DGRAM mentioning
> > this.
> > 
> 
> Agreed. I'll rename to VSOCK_TRANSPORT_F_DGRAM_FALLBACK, unless we find
> there is a better way altogether.
> 
> > > 
> > > Signed-off-by: Bobby Eshleman <bobby.eshleman@bytedance.com>
> > > ---
> > > drivers/vhost/vsock.c                   |  1 -
> > > include/linux/virtio_vsock.h            |  2 -
> > > net/vmw_vsock/af_vsock.c                | 78 +++++++++++++++++++++++++--------
> > > net/vmw_vsock/hyperv_transport.c        |  6 ---
> > > net/vmw_vsock/virtio_transport.c        |  1 -
> > > net/vmw_vsock/virtio_transport_common.c |  7 ---
> > > net/vmw_vsock/vsock_loopback.c          |  1 -
> > > 7 files changed, 60 insertions(+), 36 deletions(-)
> > > 
> > > diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
> > > index c8201c070b4b..8f0082da5e70 100644
> > > --- a/drivers/vhost/vsock.c
> > > +++ b/drivers/vhost/vsock.c
> > > @@ -410,7 +410,6 @@ static struct virtio_transport vhost_transport = {
> > > 		.cancel_pkt               = vhost_transport_cancel_pkt,
> > > 
> > > 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
> > > -		.dgram_bind               = virtio_transport_dgram_bind,
> > > 		.dgram_allow              = virtio_transport_dgram_allow,
> > > 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
> > > 		.dgram_get_port		  = virtio_transport_dgram_get_port,
> > > diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
> > > index 23521a318cf0..73afa09f4585 100644
> > > --- a/include/linux/virtio_vsock.h
> > > +++ b/include/linux/virtio_vsock.h
> > > @@ -216,8 +216,6 @@ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val);
> > > u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk);
> > > bool virtio_transport_stream_is_active(struct vsock_sock *vsk);
> > > bool virtio_transport_stream_allow(u32 cid, u32 port);
> > > -int virtio_transport_dgram_bind(struct vsock_sock *vsk,
> > > -				struct sockaddr_vm *addr);
> > > bool virtio_transport_dgram_allow(u32 cid, u32 port);
> > > int virtio_transport_dgram_get_cid(struct sk_buff *skb, unsigned int *cid);
> > > int virtio_transport_dgram_get_port(struct sk_buff *skb, unsigned int *port);
> > > diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
> > > index 74358f0b47fa..ef86765f3765 100644
> > > --- a/net/vmw_vsock/af_vsock.c
> > > +++ b/net/vmw_vsock/af_vsock.c
> > > @@ -438,6 +438,18 @@ vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
> > > 	return transport;
> > > }
> > > 
> > > +static const struct vsock_transport *
> > > +vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
> > > +{
> > > +	const struct vsock_transport *transport;
> > > +
> > > +	transport = vsock_connectible_lookup_transport(cid, flags);
> > > +	if (transport)
> > > +		return transport;
> > > +
> > > +	return transport_dgram;
> > > +}
> > > +
> > > /* Assign a transport to a socket and call the .init transport callback.
> > >  *
> > >  * Note: for connection oriented socket this must be called when vsk->remote_addr
> > > @@ -474,7 +486,8 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
> > > 
> > > 	switch (sk->sk_type) {
> > > 	case SOCK_DGRAM:
> > > -		new_transport = transport_dgram;
> > > +		new_transport = vsock_dgram_lookup_transport(remote_cid,
> > > +							     remote_flags);
> > > 		break;
> > > 	case SOCK_STREAM:
> > > 	case SOCK_SEQPACKET:
> > > @@ -691,6 +704,9 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
> > > static int __vsock_bind_dgram(struct vsock_sock *vsk,
> > > 			      struct sockaddr_vm *addr)
> > > {
> > > +	if (!vsk->transport || !vsk->transport->dgram_bind)
> > > +		return -EINVAL;
> > > +
> > > 	return vsk->transport->dgram_bind(vsk, addr);
> > > }
> > > 
> > > @@ -1172,19 +1188,24 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
> > > 
> > > 	lock_sock(sk);
> > > 
> > > -	transport = vsk->transport;
> > > -
> > > -	err = vsock_auto_bind(vsk);
> > > -	if (err)
> > > -		goto out;
> > > -
> > > -
> > > 	/* If the provided message contains an address, use that.  Otherwise
> > > 	 * fall back on the socket's remote handle (if it has been connected).
> > > 	 */
> > > 	if (msg->msg_name &&
> > > 	    vsock_addr_cast(msg->msg_name, msg->msg_namelen,
> > > 			    &remote_addr) == 0) {
> > > +		transport = vsock_dgram_lookup_transport(remote_addr->svm_cid,
> > > +							 remote_addr->svm_flags);
> > > +		if (!transport) {
> > > +			err = -EINVAL;
> > > +			goto out;
> > > +		}
> > > +
> > > +		if (!try_module_get(transport->module)) {
> > > +			err = -ENODEV;
> > > +			goto out;
> > > +		}
> > > +
> > > 		/* Ensure this address is of the right type and is a valid
> > > 		 * destination.
> > > 		 */
> > > @@ -1193,11 +1214,27 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
> > > 			remote_addr->svm_cid = transport->get_local_cid();
> > > 
> > 
> > From here ...
> > 
> > > 		if (!vsock_addr_bound(remote_addr)) {
> > > +			module_put(transport->module);
> > > +			err = -EINVAL;
> > > +			goto out;
> > > +		}
> > > +
> > > +		if (!transport->dgram_allow(remote_addr->svm_cid,
> > > +					    remote_addr->svm_port)) {
> > > +			module_put(transport->module);
> > > 			err = -EINVAL;
> > > 			goto out;
> > > 		}
> > > +
> > > +		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
> > 
> > ... to here, looks like duplicate code, can we get it out of the if block?
> > 
> 
> Yes, I think using something like this:
> 
> [...]
> 	bool module_got = false;
> 
> [...]
> 		if (!try_module_get(transport->module)) {
> 			err = -ENODEV;
> 			goto out;
> 		}
> 		module_got = true;
> 
> [...]
> 
> out:
> 	if (likely(transport && !err && module_got))

Actually, just...

	if (module_got)

> 		module_put(transport->module);
> 
> > > +		module_put(transport->module);
> > > 	} else if (sock->state == SS_CONNECTED) {
> > > 		remote_addr = &vsk->remote_addr;
> > > +		transport = vsk->transport;
> > > +
> > > +		err = vsock_auto_bind(vsk);
> > > +		if (err)
> > > +			goto out;
> > > 
> > > 		if (remote_addr->svm_cid == VMADDR_CID_ANY)
> > > 			remote_addr->svm_cid = transport->get_local_cid();
> > > @@ -1205,23 +1242,23 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
> > > 		/* XXX Should connect() or this function ensure remote_addr is
> > > 		 * bound?
> > > 		 */
> > > -		if (!vsock_addr_bound(&vsk->remote_addr)) {
> > > +		if (!vsock_addr_bound(remote_addr)) {
> > > 			err = -EINVAL;
> > > 			goto out;
> > > 		}
> > > -	} else {
> > > -		err = -EINVAL;
> > > -		goto out;
> > > -	}
> > > 
> > > -	if (!transport->dgram_allow(remote_addr->svm_cid,
> > > -				    remote_addr->svm_port)) {
> > > +		if (!transport->dgram_allow(remote_addr->svm_cid,
> > > +					    remote_addr->svm_port)) {
> > > +			err = -EINVAL;
> > > +			goto out;
> > > +		}
> > > +
> > > +		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
> > > +	} else {
> > > 		err = -EINVAL;
> > > 		goto out;
> > > 	}
> > > 
> > > -	err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
> > > -
> > > out:
> > > 	release_sock(sk);
> > > 	return err;
> > > @@ -1255,13 +1292,18 @@ static int vsock_dgram_connect(struct socket *sock,
> > > 	if (err)
> > > 		goto out;
> > > 
> > > +	memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
> > > +
> > > +	err = vsock_assign_transport(vsk, NULL);
> > > +	if (err)
> > > +		goto out;
> > > +
> > > 	if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
> > > 					 remote_addr->svm_port)) {
> > > 		err = -EINVAL;
> > > 		goto out;
> > > 	}
> > > 
> > > -	memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
> > > 	sock->state = SS_CONNECTED;
> > > 
> > > 	/* sock map disallows redirection of non-TCP sockets with sk_state !=
> > > diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
> > > index ff6e87e25fa0..c00bc5da769a 100644
> > > --- a/net/vmw_vsock/hyperv_transport.c
> > > +++ b/net/vmw_vsock/hyperv_transport.c
> > > @@ -551,11 +551,6 @@ static void hvs_destruct(struct vsock_sock *vsk)
> > > 	kfree(hvs);
> > > }
> > > 
> > > -static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
> > > -{
> > > -	return -EOPNOTSUPP;
> > > -}
> > > -
> > > static int hvs_dgram_get_cid(struct sk_buff *skb, unsigned int *cid)
> > > {
> > > 	return -EOPNOTSUPP;
> > > @@ -841,7 +836,6 @@ static struct vsock_transport hvs_transport = {
> > > 	.connect                  = hvs_connect,
> > > 	.shutdown                 = hvs_shutdown,
> > > 
> > > -	.dgram_bind               = hvs_dgram_bind,
> > > 	.dgram_get_cid		  = hvs_dgram_get_cid,
> > > 	.dgram_get_port		  = hvs_dgram_get_port,
> > > 	.dgram_get_length	  = hvs_dgram_get_length,
> > > diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
> > > index 5763cdf13804..1b7843a7779a 100644
> > > --- a/net/vmw_vsock/virtio_transport.c
> > > +++ b/net/vmw_vsock/virtio_transport.c
> > > @@ -428,7 +428,6 @@ static struct virtio_transport virtio_transport = {
> > > 		.shutdown                 = virtio_transport_shutdown,
> > > 		.cancel_pkt               = virtio_transport_cancel_pkt,
> > > 
> > > -		.dgram_bind               = virtio_transport_dgram_bind,
> > > 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
> > > 		.dgram_allow              = virtio_transport_dgram_allow,
> > > 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
> > > diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
> > > index e6903c719964..d5a3c8efe84b 100644
> > > --- a/net/vmw_vsock/virtio_transport_common.c
> > > +++ b/net/vmw_vsock/virtio_transport_common.c
> > > @@ -790,13 +790,6 @@ bool virtio_transport_stream_allow(u32 cid, u32 port)
> > > }
> > > EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
> > > 
> > > -int virtio_transport_dgram_bind(struct vsock_sock *vsk,
> > > -				struct sockaddr_vm *addr)
> > > -{
> > > -	return -EOPNOTSUPP;
> > > -}
> > > -EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
> > > -
> > > int virtio_transport_dgram_get_cid(struct sk_buff *skb, unsigned int *cid)
> > > {
> > > 	return -EOPNOTSUPP;
> > > diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c
> > > index 2f3cabc79ee5..e9de45a26fbd 100644
> > > --- a/net/vmw_vsock/vsock_loopback.c
> > > +++ b/net/vmw_vsock/vsock_loopback.c
> > > @@ -61,7 +61,6 @@ static struct virtio_transport loopback_transport = {
> > > 		.shutdown                 = virtio_transport_shutdown,
> > > 		.cancel_pkt               = vsock_loopback_cancel_pkt,
> > > 
> > > -		.dgram_bind               = virtio_transport_dgram_bind,
> > > 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
> > > 		.dgram_allow              = virtio_transport_dgram_allow,
> > > 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
> > > 
> > > -- 
> > > 2.30.2
> > > 
> > 
> > The rest LGTM!
> > 
> > Stefano
> 
> Thanks,
> Bobby
  
Stefano Garzarella June 26, 2023, 2:50 p.m. UTC | #4
On Fri, Jun 23, 2023 at 02:59:23AM +0000, Bobby Eshleman wrote:
>On Fri, Jun 23, 2023 at 02:50:01AM +0000, Bobby Eshleman wrote:
>> On Thu, Jun 22, 2023 at 05:19:08PM +0200, Stefano Garzarella wrote:
>> > On Sat, Jun 10, 2023 at 12:58:30AM +0000, Bobby Eshleman wrote:
>> > > This patch adds support for multi-transport datagrams.
>> > >
>> > > This includes:
>> > > - Per-packet lookup of transports when using sendto(sockaddr_vm)
>> > > - Selecting H2G or G2H transport using VMADDR_FLAG_TO_HOST and CID in
>> > >  sockaddr_vm
>> > >
>> > > To preserve backwards compatibility with VMCI, some important changes
>> > > were made. The "transport_dgram" / VSOCK_TRANSPORT_F_DGRAM is changed to
>> > > be used for dgrams iff there is not yet a g2h or h2g transport that has
>> >
>> > s/iff/if
>> >
>> > > been registered that can transmit the packet. If there is a g2h/h2g
>> > > transport for that remote address, then that transport will be used and
>> > > not "transport_dgram". This essentially makes "transport_dgram" a
>> > > fallback transport for when h2g/g2h has not yet gone online, which
>> > > appears to be the exact use case for VMCI.
>> > >
>> > > This design makes sense, because there is no reason that the
>> > > transport_{g2h,h2g} cannot also service datagrams, which makes the role
>> > > of transport_dgram difficult to understand outside of the VMCI context.
>> > >
>> > > The logic around "transport_dgram" had to be retained to prevent
>> > > breaking VMCI:
>> > >
>> > > 1) VMCI datagrams appear to function outside of the h2g/g2h
>> > >   paradigm. When the vmci transport comes online, it registers itself
>> > >   with the DGRAM feature, but not H2G/G2H. Only later when the
>> > >   transport has more information about its environment does it register
>> > >   H2G or G2H. In the case that a datagram socket becomes active
>> > >   after DGRAM registration but before G2H/H2G registration, the
>> > >   "transport_dgram" transport needs to be used.
>> >
>> > IIRC we did this, because at that time only VMCI supported DGRAM. Now that
>> > there are more transports, maybe DGRAM can follow the h2g/g2h paradigm.
>> >
>>
>> Totally makes sense. I'll add the detail above that the prior design was
>> a result of chronology.
>>
>> > >
>> > > 2) VMCI seems to require that a special message be sent by the
>> > >   transport when a datagram socket calls bind(). Under the h2g/g2h
>> > >   model, the transport is selected using the remote_addr, which is set
>> > >   by connect(). At bind time there is no remote_addr because often no
>> > >   connect() has been called yet: the transport is null. Therefore, with
>> > >   a null transport there doesn't seem to be any good way for a datagram
>> > >   socket to tell the VMCI transport that it has just had bind() called
>> > >   upon it.
>> >
>> > @Vishnu, @Bryan do you think we can avoid this in some way?
>> >
>> > >
>> > > Only transports with a special datagram fallback use-case such as VMCI
>> > > need to register VSOCK_TRANSPORT_F_DGRAM.
>> >
>> > Maybe we should rename it to VSOCK_TRANSPORT_F_DGRAM_FALLBACK or
>> > something like that.
>> >
>> > In any case, we definitely need to update the comment in
>> > include/net/af_vsock.h on top of VSOCK_TRANSPORT_F_DGRAM mentioning
>> > this.
>> >
>>
>> Agreed. I'll rename to VSOCK_TRANSPORT_F_DGRAM_FALLBACK, unless we find
>> there is a better way altogether.
>>
>> > >
>> > > Signed-off-by: Bobby Eshleman <bobby.eshleman@bytedance.com>
>> > > ---
>> > > drivers/vhost/vsock.c                   |  1 -
>> > > include/linux/virtio_vsock.h            |  2 -
>> > > net/vmw_vsock/af_vsock.c                | 78 +++++++++++++++++++++++++--------
>> > > net/vmw_vsock/hyperv_transport.c        |  6 ---
>> > > net/vmw_vsock/virtio_transport.c        |  1 -
>> > > net/vmw_vsock/virtio_transport_common.c |  7 ---
>> > > net/vmw_vsock/vsock_loopback.c          |  1 -
>> > > 7 files changed, 60 insertions(+), 36 deletions(-)
>> > >
>> > > diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
>> > > index c8201c070b4b..8f0082da5e70 100644
>> > > --- a/drivers/vhost/vsock.c
>> > > +++ b/drivers/vhost/vsock.c
>> > > @@ -410,7 +410,6 @@ static struct virtio_transport vhost_transport = {
>> > > 		.cancel_pkt               = vhost_transport_cancel_pkt,
>> > >
>> > > 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
>> > > -		.dgram_bind               = virtio_transport_dgram_bind,
>> > > 		.dgram_allow              = virtio_transport_dgram_allow,
>> > > 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
>> > > 		.dgram_get_port		  = virtio_transport_dgram_get_port,
>> > > diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
>> > > index 23521a318cf0..73afa09f4585 100644
>> > > --- a/include/linux/virtio_vsock.h
>> > > +++ b/include/linux/virtio_vsock.h
>> > > @@ -216,8 +216,6 @@ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val);
>> > > u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk);
>> > > bool virtio_transport_stream_is_active(struct vsock_sock *vsk);
>> > > bool virtio_transport_stream_allow(u32 cid, u32 port);
>> > > -int virtio_transport_dgram_bind(struct vsock_sock *vsk,
>> > > -				struct sockaddr_vm *addr);
>> > > bool virtio_transport_dgram_allow(u32 cid, u32 port);
>> > > int virtio_transport_dgram_get_cid(struct sk_buff *skb, unsigned int *cid);
>> > > int virtio_transport_dgram_get_port(struct sk_buff *skb, unsigned int *port);
>> > > diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
>> > > index 74358f0b47fa..ef86765f3765 100644
>> > > --- a/net/vmw_vsock/af_vsock.c
>> > > +++ b/net/vmw_vsock/af_vsock.c
>> > > @@ -438,6 +438,18 @@ vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
>> > > 	return transport;
>> > > }
>> > >
>> > > +static const struct vsock_transport *
>> > > +vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
>> > > +{
>> > > +	const struct vsock_transport *transport;
>> > > +
>> > > +	transport = vsock_connectible_lookup_transport(cid, flags);
>> > > +	if (transport)
>> > > +		return transport;
>> > > +
>> > > +	return transport_dgram;
>> > > +}
>> > > +
>> > > /* Assign a transport to a socket and call the .init transport callback.
>> > >  *
>> > >  * Note: for connection oriented socket this must be called when vsk->remote_addr
>> > > @@ -474,7 +486,8 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
>> > >
>> > > 	switch (sk->sk_type) {
>> > > 	case SOCK_DGRAM:
>> > > -		new_transport = transport_dgram;
>> > > +		new_transport = vsock_dgram_lookup_transport(remote_cid,
>> > > +							     remote_flags);
>> > > 		break;
>> > > 	case SOCK_STREAM:
>> > > 	case SOCK_SEQPACKET:
>> > > @@ -691,6 +704,9 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
>> > > static int __vsock_bind_dgram(struct vsock_sock *vsk,
>> > > 			      struct sockaddr_vm *addr)
>> > > {
>> > > +	if (!vsk->transport || !vsk->transport->dgram_bind)
>> > > +		return -EINVAL;
>> > > +
>> > > 	return vsk->transport->dgram_bind(vsk, addr);
>> > > }
>> > >
>> > > @@ -1172,19 +1188,24 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
>> > >
>> > > 	lock_sock(sk);
>> > >
>> > > -	transport = vsk->transport;
>> > > -
>> > > -	err = vsock_auto_bind(vsk);
>> > > -	if (err)
>> > > -		goto out;
>> > > -
>> > > -
>> > > 	/* If the provided message contains an address, use that.  Otherwise
>> > > 	 * fall back on the socket's remote handle (if it has been connected).
>> > > 	 */
>> > > 	if (msg->msg_name &&
>> > > 	    vsock_addr_cast(msg->msg_name, msg->msg_namelen,
>> > > 			    &remote_addr) == 0) {
>> > > +		transport = vsock_dgram_lookup_transport(remote_addr->svm_cid,
>> > > +							 remote_addr->svm_flags);
>> > > +		if (!transport) {
>> > > +			err = -EINVAL;
>> > > +			goto out;
>> > > +		}
>> > > +
>> > > +		if (!try_module_get(transport->module)) {
>> > > +			err = -ENODEV;
>> > > +			goto out;
>> > > +		}
>> > > +
>> > > 		/* Ensure this address is of the right type and is a valid
>> > > 		 * destination.
>> > > 		 */
>> > > @@ -1193,11 +1214,27 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
>> > > 			remote_addr->svm_cid = transport->get_local_cid();
>> > >
>> >
>> > From here ...
>> >
>> > > 		if (!vsock_addr_bound(remote_addr)) {
>> > > +			module_put(transport->module);
>> > > +			err = -EINVAL;
>> > > +			goto out;
>> > > +		}
>> > > +
>> > > +		if (!transport->dgram_allow(remote_addr->svm_cid,
>> > > +					    remote_addr->svm_port)) {
>> > > +			module_put(transport->module);
>> > > 			err = -EINVAL;
>> > > 			goto out;
>> > > 		}
>> > > +
>> > > +		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
>> >
>> > ... to here, looks like duplicate code, can we get it out of the if block?
>> >
>>
>> Yes, I think using something like this:
>>
>> [...]
>> 	bool module_got = false;
>>
>> [...]
>> 		if (!try_module_get(transport->module)) {
>> 			err = -ENODEV;
>> 			goto out;
>> 		}
>> 		module_got = true;
>>
>> [...]
>>
>> out:
>> 	if (likely(transport && !err && module_got))
>
>Actually, just...
>
>	if (module_got)
>

Yep, I think it should work ;-)

Thanks,
Stefano
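
For readers following the thread, the cleanup agreed on above (take the module reference once, record that in a flag, and drop it at the single "out:" label) can be sketched in userspace roughly as follows. This is only an illustrative model, not the code of any later revision of the patch: struct fake_module, fake_try_module_get(), fake_module_put(), struct fake_transport and fake_dgram_send() are hypothetical stand-ins for the kernel's struct module, try_module_get(), module_put(), struct vsock_transport and the sendto() branch of vsock_dgram_sendmsg(), and only the reference count is modelled.

```c
#include <assert.h>
#include <errno.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical stand-in for struct module: only a refcount is modelled. */
struct fake_module {
	int refcnt;
	bool unloading;
};

/* Models try_module_get(): fails once the module is going away. */
static bool fake_try_module_get(struct fake_module *m)
{
	if (m->unloading)
		return false;
	m->refcnt++;
	return true;
}

/* Models module_put(): drops one reference. */
static void fake_module_put(struct fake_module *m)
{
	m->refcnt--;
}

/* Hypothetical stand-in for the few vsock_transport members used here. */
struct fake_transport {
	struct fake_module *module;
	bool (*dgram_allow)(unsigned int cid, unsigned int port);
	int (*dgram_enqueue)(unsigned int cid, unsigned int port);
};

static bool fake_allow_all(unsigned int cid, unsigned int port)
{
	(void)cid; (void)port;
	return true;
}

static bool fake_deny_all(unsigned int cid, unsigned int port)
{
	(void)cid; (void)port;
	return false;
}

static int fake_enqueue_ok(unsigned int cid, unsigned int port)
{
	(void)cid; (void)port;
	return 0;
}

/* Every error path jumps to "out", where the reference is dropped only
 * if it was actually taken, so no path needs its own module_put().
 */
static int fake_dgram_send(const struct fake_transport *transport,
			   unsigned int cid, unsigned int port)
{
	bool module_got = false;
	int err;

	if (!transport) {
		err = -EINVAL;
		goto out;
	}

	if (!fake_try_module_get(transport->module)) {
		err = -ENODEV;
		goto out;
	}
	module_got = true;

	if (!transport->dgram_allow(cid, port)) {
		err = -EINVAL;
		goto out;
	}

	err = transport->dgram_enqueue(cid, port);

out:
	if (module_got)
		fake_module_put(transport->module);
	return err;
}

/* Small self-check: the refcount is balanced on success and on failure. */
static void fake_dgram_send_selftest(void)
{
	struct fake_module mod = { 0, false };
	struct fake_transport t = { &mod, fake_allow_all, fake_enqueue_ok };

	assert(fake_dgram_send(&t, 3, 1234) == 0);
	assert(mod.refcnt == 0);

	t.dgram_allow = fake_deny_all;
	assert(fake_dgram_send(&t, 3, 1234) == -EINVAL);
	assert(mod.refcnt == 0);
}
```

The flag is set only after the reference has been taken, so testing module_got alone at "out:" is sufficient. The longer condition proposed first (transport && !err && module_got) would skip the put whenever err is non-zero, leaking the reference on the error paths that follow a successful get.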
  

Patch

diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index c8201c070b4b..8f0082da5e70 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -410,7 +410,6 @@  static struct virtio_transport vhost_transport = {
 		.cancel_pkt               = vhost_transport_cancel_pkt,
 
 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
-		.dgram_bind               = virtio_transport_dgram_bind,
 		.dgram_allow              = virtio_transport_dgram_allow,
 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
 		.dgram_get_port		  = virtio_transport_dgram_get_port,
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index 23521a318cf0..73afa09f4585 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -216,8 +216,6 @@  void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val);
 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk);
 bool virtio_transport_stream_is_active(struct vsock_sock *vsk);
 bool virtio_transport_stream_allow(u32 cid, u32 port);
-int virtio_transport_dgram_bind(struct vsock_sock *vsk,
-				struct sockaddr_vm *addr);
 bool virtio_transport_dgram_allow(u32 cid, u32 port);
 int virtio_transport_dgram_get_cid(struct sk_buff *skb, unsigned int *cid);
 int virtio_transport_dgram_get_port(struct sk_buff *skb, unsigned int *port);
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 74358f0b47fa..ef86765f3765 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -438,6 +438,18 @@  vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
 	return transport;
 }
 
+static const struct vsock_transport *
+vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
+{
+	const struct vsock_transport *transport;
+
+	transport = vsock_connectible_lookup_transport(cid, flags);
+	if (transport)
+		return transport;
+
+	return transport_dgram;
+}
+
 /* Assign a transport to a socket and call the .init transport callback.
  *
  * Note: for connection oriented socket this must be called when vsk->remote_addr
@@ -474,7 +486,8 @@  int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 
 	switch (sk->sk_type) {
 	case SOCK_DGRAM:
-		new_transport = transport_dgram;
+		new_transport = vsock_dgram_lookup_transport(remote_cid,
+							     remote_flags);
 		break;
 	case SOCK_STREAM:
 	case SOCK_SEQPACKET:
@@ -691,6 +704,9 @@  static int __vsock_bind_connectible(struct vsock_sock *vsk,
 static int __vsock_bind_dgram(struct vsock_sock *vsk,
 			      struct sockaddr_vm *addr)
 {
+	if (!vsk->transport || !vsk->transport->dgram_bind)
+		return -EINVAL;
+
 	return vsk->transport->dgram_bind(vsk, addr);
 }
 
@@ -1172,19 +1188,24 @@  static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 
 	lock_sock(sk);
 
-	transport = vsk->transport;
-
-	err = vsock_auto_bind(vsk);
-	if (err)
-		goto out;
-
-
 	/* If the provided message contains an address, use that.  Otherwise
 	 * fall back on the socket's remote handle (if it has been connected).
 	 */
 	if (msg->msg_name &&
 	    vsock_addr_cast(msg->msg_name, msg->msg_namelen,
 			    &remote_addr) == 0) {
+		transport = vsock_dgram_lookup_transport(remote_addr->svm_cid,
+							 remote_addr->svm_flags);
+		if (!transport) {
+			err = -EINVAL;
+			goto out;
+		}
+
+		if (!try_module_get(transport->module)) {
+			err = -ENODEV;
+			goto out;
+		}
+
 		/* Ensure this address is of the right type and is a valid
 		 * destination.
 		 */
@@ -1193,11 +1214,27 @@  static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 			remote_addr->svm_cid = transport->get_local_cid();
 
 		if (!vsock_addr_bound(remote_addr)) {
+			module_put(transport->module);
+			err = -EINVAL;
+			goto out;
+		}
+
+		if (!transport->dgram_allow(remote_addr->svm_cid,
+					    remote_addr->svm_port)) {
+			module_put(transport->module);
 			err = -EINVAL;
 			goto out;
 		}
+
+		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
+		module_put(transport->module);
 	} else if (sock->state == SS_CONNECTED) {
 		remote_addr = &vsk->remote_addr;
+		transport = vsk->transport;
+
+		err = vsock_auto_bind(vsk);
+		if (err)
+			goto out;
 
 		if (remote_addr->svm_cid == VMADDR_CID_ANY)
 			remote_addr->svm_cid = transport->get_local_cid();
@@ -1205,23 +1242,23 @@  static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
 		/* XXX Should connect() or this function ensure remote_addr is
 		 * bound?
 		 */
-		if (!vsock_addr_bound(&vsk->remote_addr)) {
+		if (!vsock_addr_bound(remote_addr)) {
 			err = -EINVAL;
 			goto out;
 		}
-	} else {
-		err = -EINVAL;
-		goto out;
-	}
 
-	if (!transport->dgram_allow(remote_addr->svm_cid,
-				    remote_addr->svm_port)) {
+		if (!transport->dgram_allow(remote_addr->svm_cid,
+					    remote_addr->svm_port)) {
+			err = -EINVAL;
+			goto out;
+		}
+
+		err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
+	} else {
 		err = -EINVAL;
 		goto out;
 	}
 
-	err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
-
 out:
 	release_sock(sk);
 	return err;
@@ -1255,13 +1292,18 @@  static int vsock_dgram_connect(struct socket *sock,
 	if (err)
 		goto out;
 
+	memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
+
+	err = vsock_assign_transport(vsk, NULL);
+	if (err)
+		goto out;
+
 	if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
 					 remote_addr->svm_port)) {
 		err = -EINVAL;
 		goto out;
 	}
 
-	memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
 	sock->state = SS_CONNECTED;
 
 	/* sock map disallows redirection of non-TCP sockets with sk_state !=
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index ff6e87e25fa0..c00bc5da769a 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -551,11 +551,6 @@  static void hvs_destruct(struct vsock_sock *vsk)
 	kfree(hvs);
 }
 
-static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
-{
-	return -EOPNOTSUPP;
-}
-
 static int hvs_dgram_get_cid(struct sk_buff *skb, unsigned int *cid)
 {
 	return -EOPNOTSUPP;
@@ -841,7 +836,6 @@  static struct vsock_transport hvs_transport = {
 	.connect                  = hvs_connect,
 	.shutdown                 = hvs_shutdown,
 
-	.dgram_bind               = hvs_dgram_bind,
 	.dgram_get_cid		  = hvs_dgram_get_cid,
 	.dgram_get_port		  = hvs_dgram_get_port,
 	.dgram_get_length	  = hvs_dgram_get_length,
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index 5763cdf13804..1b7843a7779a 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -428,7 +428,6 @@  static struct virtio_transport virtio_transport = {
 		.shutdown                 = virtio_transport_shutdown,
 		.cancel_pkt               = virtio_transport_cancel_pkt,
 
-		.dgram_bind               = virtio_transport_dgram_bind,
 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
 		.dgram_allow              = virtio_transport_dgram_allow,
 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index e6903c719964..d5a3c8efe84b 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -790,13 +790,6 @@  bool virtio_transport_stream_allow(u32 cid, u32 port)
 }
 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
 
-int virtio_transport_dgram_bind(struct vsock_sock *vsk,
-				struct sockaddr_vm *addr)
-{
-	return -EOPNOTSUPP;
-}
-EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
-
 int virtio_transport_dgram_get_cid(struct sk_buff *skb, unsigned int *cid)
 {
 	return -EOPNOTSUPP;
diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c
index 2f3cabc79ee5..e9de45a26fbd 100644
--- a/net/vmw_vsock/vsock_loopback.c
+++ b/net/vmw_vsock/vsock_loopback.c
@@ -61,7 +61,6 @@  static struct virtio_transport loopback_transport = {
 		.shutdown                 = virtio_transport_shutdown,
 		.cancel_pkt               = vsock_loopback_cancel_pkt,
 
-		.dgram_bind               = virtio_transport_dgram_bind,
 		.dgram_enqueue            = virtio_transport_dgram_enqueue,
 		.dgram_allow              = virtio_transport_dgram_allow,
 		.dgram_get_cid		  = virtio_transport_dgram_get_cid,
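
As a reading aid for the fallback behaviour this patch introduces, here is a small userspace sketch of the lookup order: try the g2h/h2g transport first, and use the datagram fallback only when that lookup yields nothing. Everything named sketch_* is hypothetical. The body of sketch_connectible_lookup() is an assumption about what vsock_connectible_lookup_transport() does (that helper's body is not part of this patch), and the local/loopback transport is deliberately left out. The two constants mirror VMADDR_CID_HOST and VMADDR_FLAG_TO_HOST from the vsock UAPI header.

```c
#include <assert.h>
#include <stddef.h>

#define SKETCH_CID_HOST     2u    /* mirrors VMADDR_CID_HOST */
#define SKETCH_FLAG_TO_HOST 0x01u /* mirrors VMADDR_FLAG_TO_HOST */

/* Hypothetical stand-in for struct vsock_transport. */
struct sketch_transport {
	const char *name;
};

/* Hypothetical stand-in for the globals transport_g2h, transport_h2g
 * and transport_dgram; a NULL slot means "not registered yet".
 */
struct sketch_registry {
	const struct sketch_transport *g2h;
	const struct sketch_transport *h2g;
	const struct sketch_transport *dgram;
};

/* Assumed selection rule: destinations at or below the host CID,
 * destinations flagged TO_HOST, and any destination when no h2g
 * transport exists go through g2h; everything else goes through h2g.
 */
static const struct sketch_transport *
sketch_connectible_lookup(const struct sketch_registry *r,
			  unsigned int cid, unsigned char flags)
{
	if (cid <= SKETCH_CID_HOST || !r->h2g ||
	    (flags & SKETCH_FLAG_TO_HOST))
		return r->g2h;

	return r->h2g;
}

/* Same shape as vsock_dgram_lookup_transport() in the patch: the
 * datagram transport is consulted only when the g2h/h2g lookup
 * returns NULL.
 */
static const struct sketch_transport *
sketch_dgram_lookup(const struct sketch_registry *r,
		    unsigned int cid, unsigned char flags)
{
	const struct sketch_transport *transport;

	transport = sketch_connectible_lookup(r, cid, flags);
	if (transport)
		return transport;

	return r->dgram;
}

/* Self-check for the "VMCI registered DGRAM before G2H/H2G" window. */
static void sketch_dgram_lookup_selftest(void)
{
	static const struct sketch_transport vmci = { "vmci-dgram" };
	static const struct sketch_transport g2h = { "g2h" };
	struct sketch_registry r = { NULL, NULL, &vmci };

	/* Nothing but the fallback is registered: the fallback is used. */
	assert(sketch_dgram_lookup(&r, SKETCH_CID_HOST, 0) == &vmci);

	/* Once g2h is registered it takes precedence over the fallback. */
	r.g2h = &g2h;
	assert(sketch_dgram_lookup(&r, SKETCH_CID_HOST, 0) == &g2h);
}
```

With no transport registered at all the lookup returns NULL, which corresponds to the -EINVAL path that the patch adds to vsock_dgram_sendmsg().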