For device memory TCP, we let the user provide the kernel with a cmsg containing 2 items:
1. the dmabuf pages fd that the user would like to send data from. 2. the offset into this dmabuf that the user would like to start sending from.
In tcp_sendmsg_locked(), if this cmsg is provided, we send the data using the dmabuf NET_TX pages bio_vec.
Also provide drivers with a new skb_devmem_frag_dma_map() helper. This helper is similar to skb_frag_dma_map(), but it first checks whether the frag being mapped is backed by dmabuf NET_TX pages, and provides the correct dma_addr if so.
Signed-off-by: Mina Almasry <almasrymina@google.com> --- include/linux/skbuff.h | 19 +++++++++-- include/net/sock.h | 2 ++ net/core/skbuff.c | 8 ++--- net/core/sock.c | 6 ++++ net/ipv4/tcp.c | 73 +++++++++++++++++++++++++++++++++++++++++- 5 files changed, 101 insertions(+), 7 deletions(-)
diff --git a/include/linux/skbuff.h b/include/linux/skbuff.h index f5e03aa84160..ad4e7bfcab07 100644 --- a/include/linux/skbuff.h +++ b/include/linux/skbuff.h @@ -1660,8 +1660,8 @@ static inline int skb_zerocopy_iter_dgram(struct sk_buff *skb, }
int skb_zerocopy_iter_stream(struct sock *sk, struct sk_buff *skb, - struct msghdr *msg, int len, - struct ubuf_info *uarg); + struct msghdr *msg, struct iov_iter *iov_iter, + int len, struct ubuf_info *uarg);
/* Internal */ #define skb_shinfo(SKB) ((struct skb_shared_info *)(skb_end_pointer(SKB))) @@ -3557,6 +3557,21 @@ static inline dma_addr_t skb_frag_dma_map(struct device *dev, skb_frag_off(frag) + offset, size, dir); }
+/* Similar to skb_frag_dma_map, but handles devmem skbs correctly. */ +static inline dma_addr_t skb_devmem_frag_dma_map(struct device *dev, + const struct sk_buff *skb, + const skb_frag_t *frag, + size_t offset, size_t size, + enum dma_data_direction dir) +{ + if (unlikely(skb->devmem && is_dma_buf_page(skb_frag_page(frag)))) { + dma_addr_t dma_addr = + dma_buf_page_to_dma_addr(skb_frag_page(frag)); + return dma_addr + skb_frag_off(frag) + offset; + } + return skb_frag_dma_map(dev, frag, offset, size, dir); +} + static inline struct sk_buff *pskb_copy(struct sk_buff *skb, gfp_t gfp_mask) { diff --git a/include/net/sock.h b/include/net/sock.h index c615666ff19a..733865f89635 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -1890,6 +1890,8 @@ struct sockcm_cookie { u64 transmit_time; u32 mark; u32 tsflags; + u32 devmem_fd; + u32 devmem_offset; };
static inline void sockcm_init(struct sockcm_cookie *sockc, diff --git a/net/core/skbuff.c b/net/core/skbuff.c index 9b83da794641..b1e28e7ad6a8 100644 --- a/net/core/skbuff.c +++ b/net/core/skbuff.c @@ -1685,8 +1685,8 @@ void msg_zerocopy_put_abort(struct ubuf_info *uarg, bool have_uref) EXPORT_SYMBOL_GPL(msg_zerocopy_put_abort);
int skb_zerocopy_iter_stream(struct sock *sk, struct sk_buff *skb, - struct msghdr *msg, int len, - struct ubuf_info *uarg) + struct msghdr *msg, struct iov_iter *iov_iter, + int len, struct ubuf_info *uarg) { struct ubuf_info *orig_uarg = skb_zcopy(skb); int err, orig_len = skb->len; @@ -1697,12 +1697,12 @@ int skb_zerocopy_iter_stream(struct sock *sk, struct sk_buff *skb, if (orig_uarg && uarg != orig_uarg) return -EEXIST;
- err = __zerocopy_sg_from_iter(msg, sk, skb, &msg->msg_iter, len); + err = __zerocopy_sg_from_iter(msg, sk, skb, iov_iter, len); if (err == -EFAULT || (err == -EMSGSIZE && skb->len == orig_len)) { struct sock *save_sk = skb->sk;
/* Streams do not free skb on error. Reset to prev state. */ - iov_iter_revert(&msg->msg_iter, skb->len - orig_len); + iov_iter_revert(iov_iter, skb->len - orig_len); skb->sk = sk; ___pskb_trim(skb, orig_len); skb->sk = save_sk; diff --git a/net/core/sock.c b/net/core/sock.c index f9b9d9ec7322..854624bee5d0 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -2813,6 +2813,12 @@ int __sock_cmsg_send(struct sock *sk, struct cmsghdr *cmsg, return -EINVAL; sockc->transmit_time = get_unaligned((u64 *)CMSG_DATA(cmsg)); break; + case SCM_DEVMEM_OFFSET: + if (cmsg->cmsg_len != CMSG_LEN(2 * sizeof(u32))) + return -EINVAL; + sockc->devmem_fd = ((u32 *)CMSG_DATA(cmsg))[0]; + sockc->devmem_offset = ((u32 *)CMSG_DATA(cmsg))[1]; + break; /* SCM_RIGHTS and SCM_CREDENTIALS are semantically in SOL_UNIX. */ case SCM_RIGHTS: case SCM_CREDENTIALS: diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index a894b8a9dbb0..85d6cdc832ef 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -280,6 +280,7 @@ #include <asm/ioctls.h> #include <net/busy_poll.h> #include <linux/dma-buf.h> +#include <uapi/linux/dma-buf.h>
/* Track pending CMSGs. */ enum { @@ -1216,6 +1217,52 @@ int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, int *copied, return err; }
+static int tcp_prepare_devmem_data(struct msghdr *msg, int devmem_fd, + unsigned int devmem_offset, + struct file **devmem_file, + struct iov_iter *devmem_tx_iter, size_t size) +{ + struct dma_buf_pages *priv; + int err = 0; + + *devmem_file = fget_raw(devmem_fd); + if (!*devmem_file) { + err = -EINVAL; + goto err; + } + + if (!is_dma_buf_pages_file(*devmem_file)) { + err = -EBADF; + goto err_fput; + } + + priv = (*devmem_file)->private_data; + if (!priv) { + WARN_ONCE(!priv, "dma_buf_pages_file has no private_data"); + err = -EINTR; + goto err_fput; + } + + if (!(priv->type & DMA_BUF_PAGES_NET_TX)) + return -EINVAL; + + if (devmem_offset + size > priv->dmabuf->size) { + err = -ENOSPC; + goto err_fput; + } + + *devmem_tx_iter = priv->net_tx.iter; + iov_iter_advance(devmem_tx_iter, devmem_offset); + + return 0; + +err_fput: + fput(*devmem_file); + *devmem_file = NULL; +err: + return err; +} + int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) { struct tcp_sock *tp = tcp_sk(sk); @@ -1227,6 +1274,8 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) int process_backlog = 0; bool zc = false; long timeo; + struct file *devmem_file = NULL; + struct iov_iter devmem_tx_iter;
flags = msg->msg_flags;
@@ -1295,6 +1344,14 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) } }
+ if (sockc.devmem_fd) { + err = tcp_prepare_devmem_data(msg, sockc.devmem_fd, + sockc.devmem_offset, &devmem_file, + &devmem_tx_iter, size); + if (err) + goto out_err; + } + /* This should be in poll */ sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
@@ -1408,7 +1465,17 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) goto wait_for_space; }
- err = skb_zerocopy_iter_stream(sk, skb, msg, copy, uarg); + if (devmem_file) { + err = skb_zerocopy_iter_stream(sk, skb, msg, + &devmem_tx_iter, + copy, uarg); + if (err > 0) + iov_iter_advance(&msg->msg_iter, err); + } else { + err = skb_zerocopy_iter_stream(sk, skb, msg, + &msg->msg_iter, + copy, uarg); + } if (err == -EMSGSIZE || err == -EEXIST) { tcp_mark_push(tp, skb); goto new_segment; @@ -1462,6 +1529,8 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) } out_nopush: net_zcopy_put(uarg); + if (devmem_file) + fput(devmem_file); return copied + copied_syn;
do_error: @@ -1470,6 +1539,8 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) if (copied + copied_syn) goto out; out_err: + if (devmem_file) + fput(devmem_file); net_zcopy_put_abort(uarg, true); err = sk_stream_error(sk, flags, err); /* make sure we wake any epoll edge trigger waiter */