Thread (14 messages), 2 authors, 2019-11-04

Re: [oss-drivers] Re: [PATCH net] net/tls: fix sk_msg trim on fallback to copy mode

From: John Fastabend <john.fastabend@gmail.com>
Date: 2019-11-04 18:57:05
Also in: linux-crypto

Jakub Kicinski wrote:
On Fri, 1 Nov 2019 10:22:38 -0700, Jakub Kicinski wrote:
+		msg->sg.copybreak = 0;
+	} else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >
+		   sk_msg_iter_dist(msg->sg.end, msg->sg.curr)) {
+		sk_msg_iter_var_prev(i);    
I suspect with a small update to the dist logic the special case could also
be dropped here. But I have a preference for my example above at the
moment. Just getting coffee now so will think on it though.  
Okay, I like the dist thing, I thought that's where you were going in
your first email :)

I need to do some more admin, and then I'll probably write a unit test
for this code (user space version), so we can test either patch with it.
Attaching my "unit test", you should be able to just replace
sk_msg_trim() with yours and re-run. That said my understanding of the
expected geometry of the buffer may not be correct :)

The patch I posted yesterday, with the small adjustment to set curr to
start on empty message passes that test, here it is again:

----->8-----

From 953df5bc0992e31a2c7863ea8b8e490ba7a07356 Mon Sep 17 00:00:00 2001
From: Jakub Kicinski <redacted>
Date: Tue, 29 Oct 2019 20:20:49 -0700
Subject: [PATCH net] net/tls: fix sk_msg trim on fallback to copy mode

sk_msg_trim() tries to only update curr pointer if it falls into
the trimmed region. The logic, however, does not take into account
the pointer wrapping that sk_msg_iter_var_prev() does nor
(as John points out) the fact that msg->sg is a ring buffer.

This means that when the message was trimmed completely, the new
curr pointer would have the value of MAX_MSG_FRAGS - 1, which is
neither smaller than any other value, nor would it actually be
correct.

Special case the trimming to 0 length a little bit and rework
the comparison between curr and end to take into account wrapping.

This bug caused the TLS code to not copy all of the message, if
zero copy filled in fewer sg entries than memcopy would need.

Big thanks to Alexander Potapenko for the non-KMSAN reproducer.

v2:
 - take into account that msg->sg is a ring buffer (John).

Fixes: d829e9c4112b ("tls: convert to generic sk_msg interface")
Suggested-by: John Fastabend <john.fastabend@gmail.com>
Reported-by: syzbot+f8495bff23a879a6d0bd@syzkaller.appspotmail.com
Reported-by: syzbot+6f50c99e8f6194bf363f@syzkaller.appspotmail.com
Signed-off-by: Jakub Kicinski <redacted>
---
 include/linux/skmsg.h |  9 ++++++---
 net/core/skmsg.c      | 20 +++++++++++++++-----
 2 files changed, 21 insertions(+), 8 deletions(-)
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index e4b3fb4bb77c..ce7055259877 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -139,6 +139,11 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
 	}
 }
 
+static inline u32 sk_msg_iter_dist(u32 start, u32 end)
+{
+	return end >= start ? end - start : end + (MAX_MSG_FRAGS - start);
+}
+
 #define sk_msg_iter_var_prev(var)			\
 	do {						\
 		if (var == 0)				\
@@ -198,9 +203,7 @@ static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
 	if (sk_msg_full(msg))
 		return MAX_MSG_FRAGS;
 
-	return msg->sg.end >= msg->sg.start ?
-		msg->sg.end - msg->sg.start :
-		msg->sg.end + (MAX_MSG_FRAGS - msg->sg.start);
+	return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
 }
I think it's nice to pull this into a helper, so I'm OK with also using the
dist below, except for one comment below.
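
For anyone following along in user space, here is a minimal sketch of the
ring distance that the new helper computes. RING_SLOTS and ring_dist() are
hypothetical stand-ins for MAX_MSG_FRAGS and sk_msg_iter_dist(); the value 17
is only an assumption for illustration, not taken from the patch.

```c
#include <stdint.h>

/* Assumption: illustrative ring size standing in for MAX_MSG_FRAGS. */
#define RING_SLOTS 17u

/*
 * Forward distance from 'start' to 'end' on a ring of RING_SLOTS slots.
 * Same arithmetic as sk_msg_iter_dist() in the patch: if 'end' is
 * numerically behind 'start', the walk wraps through slot 0.
 */
static inline uint32_t ring_dist(uint32_t start, uint32_t end)
{
	return end >= start ? end - start : end + (RING_SLOTS - start);
}
```

With RING_SLOTS at 17, ring_dist(1, 4) is 3 and ring_dist(15, 2) is 4. The
distance is only measured forward, so argument order matters:
ring_dist(4, 3) is 16, not 1.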
 
 static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index cf390e0aa73d..f87fde3a846c 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -270,18 +270,28 @@ void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
 
 	msg->sg.data[i].length -= trim;
 	sk_mem_uncharge(sk, trim);
+	/* Adjust copybreak if it falls into the trimmed part of last buf */
+	if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length)
+		msg->sg.copybreak = msg->sg.data[i].length;
 out:
-	/* If we trim data before curr pointer update copybreak and current
-	 * so that any future copy operations start at new copy location.
+	sk_msg_iter_var_next(i);
+	msg->sg.end = i;
+
+	/* If we trim data a full sg elem before curr pointer update
+	 * copybreak and current so that any future copy operations
+	 * start at new copy location.
 	 * However trimed data that has not yet been used in a copy op
 	 * does not require an update.
 	 */
-	if (msg->sg.curr >= i) {
+	if (!msg->sg.size) {
+		msg->sg.curr = msg->sg.start;
+		msg->sg.copybreak = 0;
+	} else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >
+		   sk_msg_iter_dist(msg->sg.end, msg->sg.curr)) {
I'm not seeing how this can work. Take the simple case with start < end,
i.e. normal geometry without wrapping. Let,

 start = 1
 curr  = 3
 end   = 4

We could trim an index to get,

 start = 1
  curr = 3
     i = 3
   end = 4

Then after out: label this would push end up one,

 start = 1
  curr = 3
     i = 3
   end = 4

But dist(start,curr) = 2 and dist(end, curr) = 1 and we would set curr
to '3' but clear the copybreak? I think a better comparison would be,

  if (sk_msg_iter_dist(msg->sg.start, i) <
      sk_msg_iter_dist(msg->sg.start, msg->sg.curr))

To check if 'i' walked past curr so we can reset curr/copybreak?
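
The comparison suggested above can be tried out in user space with a sketch
like the following. It measures both i and curr forward from start, so the
result does not depend on whether the ring has wrapped. RING_SLOTS,
ring_dist() and trimmed_past_curr() are hypothetical names for illustration,
and the ring size of 17 is an assumption, not something the patch defines.

```c
#include <stdbool.h>
#include <stdint.h>

/* Assumption: illustrative ring size standing in for MAX_MSG_FRAGS. */
#define RING_SLOTS 17u

/* Forward distance on the ring, same arithmetic as sk_msg_iter_dist(). */
static inline uint32_t ring_dist(uint32_t start, uint32_t end)
{
	return end >= start ? end - start : end + (RING_SLOTS - start);
}

/*
 * True when index 'i' lies strictly before 'curr', with both positions
 * measured forward from 'start'. This is the condition under which the
 * suggested check would reset curr/copybreak.
 */
static inline bool trimmed_past_curr(uint32_t start, uint32_t i, uint32_t curr)
{
	return ring_dist(start, i) < ring_dist(start, curr);
}
```

With start = 1 and curr = 3, i = 2 gives true and i = 3 gives false. In a
wrapped layout with start = 15 and curr = 1, i = 16 gives true because 16 is
one slot from start while curr is three slots from start.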
+		sk_msg_iter_var_prev(i);
 		msg->sg.curr = i;
 		msg->sg.copybreak = msg->sg.data[i].length;
 	}
-	sk_msg_iter_var_next(i);
-	msg->sg.end = i;
 }
 EXPORT_SYMBOL_GPL(sk_msg_trim);
 
-- 
2.23.0
  