Rewrite create_pair() so that unneeded socket descriptors are automatically close()d when they go out of scope. Make sure the "ownership" of the fds is correctly passed via take_fd(), i.e. the descriptors returned to the caller remain valid.
Suggested-by: Jakub Sitnicki <jakub@cloudflare.com>
Signed-off-by: Michal Luczaj <mhal@rbox.co>
---
 .../selftests/bpf/prog_tests/sockmap_helpers.h | 61 +++++++++++++---------
 1 file changed, 36 insertions(+), 25 deletions(-)
diff --git a/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h b/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h index ead8ea4fd0da..38e35c72bdaa 100644 --- a/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h +++ b/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h @@ -17,6 +17,17 @@
#define __always_unused __attribute__((__unused__))
+/* include/linux/cleanup.h */ +#define __get_and_null(p, nullvalue) \ + ({ \ + __auto_type __ptr = &(p); \ + __auto_type __val = *__ptr; \ + *__ptr = nullvalue; \ + __val; \ + }) + +#define take_fd(fd) __get_and_null(fd, -EBADF) + #define _FAIL(errnum, fmt...) \ ({ \ error_at_line(0, (errnum), __func__, __LINE__, fmt); \ @@ -182,6 +193,14 @@ __ret; \ })
+static inline void close_fd(int *fd) +{ + if (*fd >= 0) + xclose(*fd); +} + +#define __close_fd __attribute__((cleanup(close_fd))) + static inline int poll_connect(int fd, unsigned int timeout_sec) { struct timeval timeout = { .tv_sec = timeout_sec }; @@ -369,9 +388,10 @@ static inline int socket_loopback(int family, int sotype)
static inline int create_pair(int family, int sotype, int *p0, int *p1) { + __close_fd int s, c = -1, p = -1; struct sockaddr_storage addr; socklen_t len = sizeof(addr); - int s, c, p, err; + int err;
s = socket_loopback(family, sotype); if (s < 0) @@ -379,25 +399,23 @@ static inline int create_pair(int family, int sotype, int *p0, int *p1)
err = xgetsockname(s, sockaddr(&addr), &len); if (err) - goto close_s; + return err;
c = xsocket(family, sotype, 0); - if (c < 0) { - err = c; - goto close_s; - } + if (c < 0) + return c;
err = connect(c, sockaddr(&addr), len); if (err) { if (errno != EINPROGRESS) { FAIL_ERRNO("connect"); - goto close_c; + return err; }
err = poll_connect(c, IO_TIMEOUT_SEC); if (err) { FAIL_ERRNO("poll_connect"); - goto close_c; + return err; } }
@@ -405,36 +423,29 @@ static inline int create_pair(int family, int sotype, int *p0, int *p1) case SOCK_DGRAM: err = xgetsockname(c, sockaddr(&addr), &len); if (err) - goto close_c; + return err;
err = xconnect(s, sockaddr(&addr), len); - if (!err) { - *p0 = s; - *p1 = c; + if (err) return err; - } + + *p0 = take_fd(s); break; case SOCK_STREAM: case SOCK_SEQPACKET: p = xaccept_nonblock(s, NULL, NULL); - if (p >= 0) { - *p0 = p; - *p1 = c; - goto close_s; - } + if (p < 0) + return p;
- err = p; + *p0 = take_fd(p); break; default: FAIL("Unsupported socket type %#x", sotype); - err = -EOPNOTSUPP; + return -EOPNOTSUPP; }
-close_c: - close(c); -close_s: - close(s); - return err; + *p1 = take_fd(c); + return 0; }
static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,