[v3,29/55] tls/sw: Support MSG_SPLICE_PAGES

Message ID 20230331160914.1608208-30-dhowells@redhat.com
State New
Headers
Series splice, net: Replace sendpage with sendmsg(MSG_SPLICE_PAGES) |

Commit Message

David Howells March 31, 2023, 4:08 p.m. UTC
  Make TLS's sendmsg() support MSG_SPLICE_PAGES.  This causes pages to be
spliced from the source iterator if possible and copied the data if not.

This allows ->sendpage() to be replaced by something that can handle
multiple multipage folios in a single transaction.

Signed-off-by: David Howells <dhowells@redhat.com>
cc: Chuck Lever <chuck.lever@oracle.com>
cc: Boris Pismenny <borisp@nvidia.com>
cc: John Fastabend <john.fastabend@gmail.com>
cc: Jakub Kicinski <kuba@kernel.org>
cc: Eric Dumazet <edumazet@google.com>
cc: "David S. Miller" <davem@davemloft.net>
cc: Paolo Abeni <pabeni@redhat.com>
cc: Jens Axboe <axboe@kernel.dk>
cc: Matthew Wilcox <willy@infradead.org>
cc: netdev@vger.kernel.org
---
 net/tls/tls_sw.c | 57 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 56 insertions(+), 1 deletion(-)
  

Comments

David Howells March 31, 2023, 4:27 p.m. UTC | #1
Here's a trivial TLS server that can be used to test this.

David
---
/*
 * TLS-over-TCP sink server
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <fcntl.h>
#include <unistd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <linux/tls.h>

#define OSERROR(X, Y) do { if ((long)(X) == -1) { perror(Y); exit(1); } } while(0)

static unsigned char buffer[512 * 1024] __attribute__((aligned(4096)));

static void set_tls(int sock)
{
	struct tls12_crypto_info_aes_gcm_128 crypto_info;

	crypto_info.info.version = TLS_1_2_VERSION;
	crypto_info.info.cipher_type = TLS_CIPHER_AES_GCM_128;
	memset(crypto_info.iv,		0, TLS_CIPHER_AES_GCM_128_IV_SIZE);
	memset(crypto_info.rec_seq,	0, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
	memset(crypto_info.key,		0, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
	memset(crypto_info.salt,	0, TLS_CIPHER_AES_GCM_128_SALT_SIZE);

	OSERROR(setsockopt(sock, SOL_TCP, TCP_ULP, "tls", sizeof("tls")),
		"TCP_ULP");
	OSERROR(setsockopt(sock, SOL_TLS, TLS_TX, &crypto_info, sizeof(crypto_info)),
		"TLS_TX");
	OSERROR(setsockopt(sock, SOL_TLS, TLS_RX, &crypto_info, sizeof(crypto_info)),
		"TLS_RX");
}

int main(int argc, char *argv[])
{
	struct sockaddr_in sin = { .sin_family = AF_INET, .sin_port = htons(5556) };
	int sfd, afd;

	sfd = socket(AF_INET, SOCK_STREAM, 0);
	OSERROR(sfd, "socket");
	OSERROR(bind(sfd, (struct sockaddr *)&sin, sizeof(sin)), "bind");
	OSERROR(listen(sfd, 1), "listen");

	for (;;) {
		afd = accept(sfd, NULL, NULL);
		if (afd != -1) {
			set_tls(afd);
			while (read(afd, buffer, sizeof(buffer)) > 0) {}
			close(afd);
		}
	}
}
  
David Howells March 31, 2023, 4:28 p.m. UTC | #2
Here's a trivial TLS client program for testing this.

David
---
/*
 * TLS-over-TCP send client
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <fcntl.h>
#include <unistd.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/stat.h>
#include <sys/sendfile.h>
#include <linux/tls.h>

#define OSERROR(X, Y) do { if ((long)(X) == -1) { perror(Y); exit(1); } } while(0)

static unsigned char buffer[4096] __attribute__((aligned(4096)));

static void set_tls(int sock)
{
	struct tls12_crypto_info_aes_gcm_128 crypto_info;

	crypto_info.info.version = TLS_1_2_VERSION;
	crypto_info.info.cipher_type = TLS_CIPHER_AES_GCM_128;
	memset(crypto_info.iv,		0, TLS_CIPHER_AES_GCM_128_IV_SIZE);
	memset(crypto_info.rec_seq,	0, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
	memset(crypto_info.key,		0, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
	memset(crypto_info.salt,	0, TLS_CIPHER_AES_GCM_128_SALT_SIZE);

	OSERROR(setsockopt(sock, SOL_TCP, TCP_ULP, "tls", sizeof("tls")),
		"TCP_ULP");
	OSERROR(setsockopt(sock, SOL_TLS, TLS_TX, &crypto_info, sizeof(crypto_info)),
		"TLS_TX");
	OSERROR(setsockopt(sock, SOL_TLS, TLS_RX, &crypto_info, sizeof(crypto_info)),
		"TLS_RX");
}

int main(int argc, char *argv[])
{
	struct sockaddr_in sin = { .sin_family = AF_INET, .sin_port = htons(5556) };
	struct hostent *h;
	struct stat st;
	ssize_t r, o;
	int sf = 0;
	int cfd, fd;

	if (argc > 1 && strcmp(argv[1], "-s") == 0) {
		sf = 1;
		argc--;
		argv++;
	}
	
	if (argc != 3) {
		fprintf(stderr, "tcp-send [-s] <server> <file>\n");
		exit(2);
	}

	h = gethostbyname(argv[1]);
	if (!h) {
		fprintf(stderr, "%s: %s\n", argv[1], hstrerror(h_errno));
		exit(3);
	}

	if (!h->h_addr_list[0]) {
		fprintf(stderr, "%s: No addresses\n", argv[1]);
		exit(3);
	}

	memcpy(&sin.sin_addr, h->h_addr_list[0], h->h_length);
	
	cfd = socket(AF_INET, SOCK_STREAM, 0);
	OSERROR(cfd, "socket");
	OSERROR(connect(cfd, (struct sockaddr *)&sin, sizeof(sin)), "connect");
	set_tls(cfd);

	fd = open(argv[2], O_RDONLY);
	OSERROR(fd, argv[2]);
	OSERROR(fstat(fd, &st), argv[2]);

	if (!sf) {
		for (;;) {
			r = read(fd, buffer, sizeof(buffer));
			OSERROR(r, argv[2]);
			if (r == 0)
				break;

			o = 0;
			do {
				ssize_t w = write(cfd, buffer + o, r - o);
				OSERROR(w, "write");
				o += w;
			} while (o < r);
		}
	} else {
		off_t off = 0;
		r = sendfile(cfd, fd, &off, st.st_size);
		OSERROR(r, "sendfile");
		if (r != st.st_size) {
			fprintf(stderr, "Short sendfile\n");
			exit(1);
		}
	}

	OSERROR(close(cfd), "close/c");
	OSERROR(close(fd), "close/f");
	return 0;
}
  

Patch

diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 782d3701b86f..ce0c289e68ca 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -929,6 +929,49 @@  static int tls_sw_push_pending_record(struct sock *sk, int flags)
 				   &copied, flags);
 }
 
+static int rls_sw_sendmsg_splice(struct sock *sk, struct msghdr *msg,
+				 struct sk_msg *msg_pl, size_t try_to_copy,
+				 ssize_t *copied)
+{
+	struct page *page, **pages = &page;
+
+	do {
+		ssize_t part;
+		size_t off;
+		bool put = false;
+
+		part = iov_iter_extract_pages(&msg->msg_iter, &pages,
+					      try_to_copy, 1, 0, &off);
+		if (part <= 0)
+			return part ?: -EIO;
+
+		if (!sendpage_ok(page)) {
+			const void *p = kmap_local_page(page);
+			void *q;
+
+			q = page_frag_memdup(NULL, p + off, part,
+					     sk->sk_allocation, ULONG_MAX);
+			kunmap_local(p);
+			if (!q) {
+				iov_iter_revert(&msg->msg_iter, part);
+				return -ENOMEM;
+			}
+			page = virt_to_page(q);
+			off = offset_in_page(q);
+			put = true;
+		}
+
+		sk_msg_page_add(msg_pl, page, part, off);
+		sk_mem_charge(sk, part);
+		if (put)
+			put_page(page);
+		*copied += part;
+		try_to_copy -= part;
+	} while (try_to_copy && !sk_msg_full(msg_pl));
+
+	return 0;
+}
+
 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 {
 	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
@@ -1016,6 +1059,17 @@  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 			full_record = true;
 		}
 
+		if (try_to_copy && (msg->msg_flags & MSG_SPLICE_PAGES)) {
+			ret = rls_sw_sendmsg_splice(sk, msg, msg_pl,
+						    try_to_copy, &copied);
+			if (ret < 0)
+				goto send_end;
+			tls_ctx->pending_open_record_frags = true;
+			if (full_record || eor || sk_msg_full(msg_pl))
+				goto copied;
+			continue;
+		}
+
 		if (!is_kvec && (full_record || eor) && !async_capable) {
 			u32 first = msg_pl->sg.end;
 
@@ -1078,8 +1132,9 @@  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 		/* Open records defined only if successfully copied, otherwise
 		 * we would trim the sg but not reset the open record frags.
 		 */
-		tls_ctx->pending_open_record_frags = true;
 		copied += try_to_copy;
+copied:
+		tls_ctx->pending_open_record_frags = true;
 		if (full_record || eor) {
 			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
 						  record_type, &copied,