On Sun, 2023-11-05 at 18:44 -0800, Mina Almasry wrote: [...]
+/* On error, returns the -errno. On success, returns number of bytes sent to the
- user. May not consume all of @remaining_len.
- */
+static int tcp_recvmsg_devmem(const struct sock *sk, const struct sk_buff *skb,
unsigned int offset, struct msghdr *msg,
int remaining_len)
+{
- struct cmsg_devmem cmsg_devmem = { 0 };
- unsigned int start;
- int i, copy, n;
- int sent = 0;
- int err = 0;
- do {
start = skb_headlen(skb);
if (!skb_frags_not_readable(skb)) {
As 'skb_frags_not_readable()' is intended to be a possibly wider-scope test than skb->devmem, should the above test explicitly check skb->devmem?
err = -ENODEV;
goto out;
}
/* Copy header. */
copy = start - offset;
if (copy > 0) {
copy = min(copy, remaining_len);
n = copy_to_iter(skb->data + offset, copy,
&msg->msg_iter);
if (n != copy) {
err = -EFAULT;
goto out;
}
offset += copy;
remaining_len -= copy;
/* First a cmsg_devmem for # bytes copied to user
* buffer.
*/
memset(&cmsg_devmem, 0, sizeof(cmsg_devmem));
cmsg_devmem.frag_size = copy;
err = put_cmsg(msg, SOL_SOCKET, SO_DEVMEM_HEADER,
sizeof(cmsg_devmem), &cmsg_devmem);
if (err || msg->msg_flags & MSG_CTRUNC) {
msg->msg_flags &= ~MSG_CTRUNC;
if (!err)
err = -ETOOSMALL;
goto out;
}
sent += copy;
if (remaining_len == 0)
goto out;
}
/* after that, send information of devmem pages through a
* sequence of cmsg
*/
for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
const skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
struct page_pool_iov *ppiov;
u64 frag_offset;
u32 user_token;
int end;
/* skb_frags_not_readable() should indicate that ALL the
* frags in this skb are unreadable page_pool_iovs.
* We're checking for that flag above, but also check
* individual pages here. If the tcp stack is not
* setting skb->devmem correctly, we still don't want to
* crash here when accessing pgmap or priv below.
*/
if (!skb_frag_page_pool_iov(frag)) {
net_err_ratelimited("Found non-devmem skb with page_pool_iov");
err = -ENODEV;
goto out;
}
ppiov = skb_frag_page_pool_iov(frag);
end = start + skb_frag_size(frag);
copy = end - offset;
if (copy > 0) {
copy = min(copy, remaining_len);
frag_offset = page_pool_iov_virtual_addr(ppiov) +
skb_frag_off(frag) + offset -
start;
cmsg_devmem.frag_offset = frag_offset;
cmsg_devmem.frag_size = copy;
err = xa_alloc((struct xarray *)&sk->sk_user_pages,
&user_token, frag->bv_page,
xa_limit_31b, GFP_KERNEL);
if (err)
goto out;
cmsg_devmem.frag_token = user_token;
offset += copy;
remaining_len -= copy;
err = put_cmsg(msg, SOL_SOCKET,
SO_DEVMEM_OFFSET,
sizeof(cmsg_devmem),
&cmsg_devmem);
if (err || msg->msg_flags & MSG_CTRUNC) {
msg->msg_flags &= ~MSG_CTRUNC;
xa_erase((struct xarray *)&sk->sk_user_pages,
user_token);
if (!err)
err = -ETOOSMALL;
goto out;
}
page_pool_iov_get_many(ppiov, 1);
sent += copy;
if (remaining_len == 0)
goto out;
}
start = end;
}
if (!remaining_len)
goto out;
/* if remaining_len is not satisfied yet, we need to go to the
* next frag in the frag_list to satisfy remaining_len.
*/
skb = skb_shinfo(skb)->frag_list ?: skb->next;
I think at this point the 'skb' is still on the sk receive queue, so the above may walk the receive queue.
Later on, only the current queue tail can possibly be consumed by tcp_recvmsg_locked(). This feels confusing to me. Why not limit the loop to only the 'current' skb and its frags?
offset = offset - start;
- } while (skb);
- if (remaining_len) {
err = -EFAULT;
goto out;
- }
+out:
- if (!sent)
sent = err;
- return sent;
+}
/*
- This routine copies from a sock struct into the user buffer.
@@ -2314,6 +2463,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, int *cmsg_flags) { struct tcp_sock *tp = tcp_sk(sk);
- int last_copied_devmem = -1; /* uninitialized */ int copied = 0; u32 peek_seq; u32 *seq;
@@ -2491,15 +2641,44 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, } if (!(flags & MSG_TRUNC)) {
err = skb_copy_datagram_msg(skb, offset, msg, used);
if (err) {
/* Exception. Bailout! */
if (!copied)
copied = -EFAULT;
if (last_copied_devmem != -1 &&
last_copied_devmem != skb->devmem) break;
if (!skb->devmem) {
err = skb_copy_datagram_msg(skb, offset, msg,
used);
if (err) {
/* Exception. Bailout! */
if (!copied)
copied = -EFAULT;
break;
}
} else {
if (!(flags & MSG_SOCK_DEVMEM)) {
/* skb->devmem skbs can only be received
* with the MSG_SOCK_DEVMEM flag.
*/
if (!copied)
copied = -EFAULT;
break;
}
err = tcp_recvmsg_devmem(sk, skb, offset, msg,
used);
if (err <= 0) {
if (!copied)
copied = -EFAULT;
break;
}
used = err;
Minor nit: I personally would find the above more readable if this whole chunk were placed in a single helper (e.g. the current tcp_recvmsg_devmem(), renamed to something more appropriate).
Cheers,
Paolo