Thread (22 messages) 22 messages, 3 authors, 3d ago

[RFC net-next 03/17] tls: add protocol dimension to tls operation cache

From: Geliang Tang <geliang@kernel.org>
Date: 2026-06-22 10:44:26
Also in: mptcp
Subsystem: networking [general], networking [tls], the rest · Maintainers: "David S. Miller", Eric Dumazet, Jakub Kicinski, Paolo Abeni, John Fastabend, Sabrina Dubroca, Linus Torvalds

From: Geliang Tang <redacted>

The current TLS operation cache is indexed solely by IP version
(IPv4/IPv6). This was sufficient when only TCP was supported.
Rename TLS_NUM_PROTS to TLS_NUM_FAMILY to accurately reflect that it
represents the number of address families.

With the introduction of MPTCP, both TCP and MPTCP sockets within the
same IP version now share the same cache entries. When an MPTCP socket
enables TLS, it overwrites the cache with MPTCP-specific operations,
causing existing TCP TLS sockets to use the wrong ops, leading to type
confusion and kernel panics.

Fix by extending the cache arrays with a protocol dimension to separate
TCP and MPTCP. Introduce TLSTCP and TLSMPTCP enum values, along with
separate saved protocol pointers and mutexes for MPTCP. update_sk_prot()
and __tls_build_proto() now select the appropriate cache based on
sk->sk_protocol.

Co-developed-by: Gang Yan <redacted>
Signed-off-by: Gang Yan <redacted>
Co-developed-by: Zqiang <qiang.zhang@linux.dev>
Signed-off-by: Zqiang <qiang.zhang@linux.dev>
Signed-off-by: Geliang Tang <redacted>
---
 net/tls/tls_main.c | 40 +++++++++++++++++++++++++++++-----------
 1 file changed, 29 insertions(+), 11 deletions(-)
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index be824affd1b1..94133d62f73e 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -53,7 +53,13 @@ MODULE_ALIAS_TCP_ULP("tls");
 enum {
 	TLSV4,
 	TLSV6,
-	TLS_NUM_PROTS,
+	TLS_NUM_FAMILY,
+};
+
+enum {
+	TLSTCP,
+	TLSMPTCP,
+	TLS_NUM_PROTO,
 };
 
 #define CHECK_CIPHER_DESC(cipher,ci)				\
@@ -117,23 +123,30 @@ CHECK_CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm);
 CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128);
 CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256);
 
+static const struct proto *saved_mptcpv6_prot;
+static DEFINE_MUTEX(mptcpv6_prot_mutex);
 static const struct proto *saved_tcpv6_prot;
 static DEFINE_MUTEX(tcpv6_prot_mutex);
+static const struct proto *saved_mptcpv4_prot;
+static DEFINE_MUTEX(mptcpv4_prot_mutex);
 static const struct proto *saved_tcpv4_prot;
 static DEFINE_MUTEX(tcpv4_prot_mutex);
-static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
-static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
+static struct proto
+tls_prots[TLS_NUM_FAMILY][TLS_NUM_PROTO][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
+static struct proto_ops
+tls_proto_ops[TLS_NUM_FAMILY][TLS_NUM_PROTO][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 			 const struct proto *base);
 
 static void update_sk_prot(struct sock *sk, struct tls_context *ctx)
 {
+	int proto = sk->sk_protocol == IPPROTO_MPTCP ? TLSMPTCP : TLSTCP;
 	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
 
 	WRITE_ONCE(sk->sk_prot,
-		   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
+		   &tls_prots[ip_ver][proto][ctx->tx_conf][ctx->rx_conf]);
 	WRITE_ONCE(sk->sk_socket->ops,
-		   &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
+		   &tls_proto_ops[ip_ver][proto][ctx->tx_conf][ctx->rx_conf]);
 }
 
 int wait_on_pending_writer(struct sock *sk, long *timeo)
@@ -971,18 +984,19 @@ static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG]
 static void __tls_build_proto(struct sock *sk,
 			      const struct proto *saved_prot,
 			      struct mutex *prot_mutex,
-			      int family)
+			      int family, int protocol)
 {
+	int proto = sk->sk_protocol == IPPROTO_MPTCP ? TLSMPTCP : TLSTCP;
 	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
 	struct proto *prot = READ_ONCE(sk->sk_prot);
 
-	if (ip_ver == family) {
+	if (ip_ver == family && proto == protocol) {
 		/* smp_load_acquire pairs with smp_store_release below */
 		if (unlikely(prot != smp_load_acquire(&saved_prot))) {
 			mutex_lock(prot_mutex);
 			if (likely(prot != saved_prot)) {
-				build_protos(tls_prots[family], prot);
-				build_proto_ops(tls_proto_ops[family],
+				build_protos(tls_prots[family][protocol], prot);
+				build_proto_ops(tls_proto_ops[family][protocol],
 						sk->sk_socket->ops);
 				/* pairs with smp_load_acquire above */
 				smp_store_release(&saved_prot, prot);
@@ -995,10 +1009,14 @@ static void __tls_build_proto(struct sock *sk,
 static void tls_build_proto(struct sock *sk)
 {
 	/* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
+	__tls_build_proto(sk, saved_mptcpv6_prot, &mptcpv6_prot_mutex,
+			  TLSV6, TLSMPTCP);
 	__tls_build_proto(sk, saved_tcpv6_prot, &tcpv6_prot_mutex,
-			  TLSV6);
+			  TLSV6, TLSTCP);
+	__tls_build_proto(sk, saved_mptcpv4_prot, &mptcpv4_prot_mutex,
+			  TLSV4, TLSMPTCP);
 	__tls_build_proto(sk, saved_tcpv4_prot, &tcpv4_prot_mutex,
-			  TLSV4);
+			  TLSV4, TLSTCP);
 }
 
 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
-- 
2.53.0
Keyboard shortcuts
hback out one level
jnext message in thread
kprevious message in thread
ldrill in
Escclose help / fold thread tree
?toggle this help