On Wed, Jul 24, 2024 at 01:32 PM +02, Michal Luczaj wrote:
Rewrite the function so that unneeded socket descriptors are automatically close()d when they go out of scope. Make sure the "ownership" of fds is correctly passed via take_fd(), i.e. descriptors returned to the caller remain valid.
Suggested-by: Jakub Sitnicki <jakub@cloudflare.com>
Signed-off-by: Michal Luczaj <mhal@rbox.co>
.../selftests/bpf/prog_tests/sockmap_helpers.h | 57 ++++++++++++---------- 1 file changed, 32 insertions(+), 25 deletions(-)
diff --git a/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h b/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h index ead8ea4fd0da..2e0f9fe459be 100644 --- a/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h +++ b/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h @@ -182,6 +182,21 @@ __ret; \ }) +#define take_fd(fd) \
- ({ \
__auto_type __val = (fd); \
fd = -EBADF; \
__val; \
- })
This should probably operate on a pointer to fd, to avoid side effects from expanding the macro argument twice, like the __get_and_null() macro in include/linux/cleanup.h does. take_fd() is effectively __get_and_null(fd, -EBADF).
+static inline void close_fd(int *fd) +{
- if (*fd >= 0)
xclose(*fd);
+}
+#define __close_fd __attribute__((cleanup(close_fd)))
static inline int poll_connect(int fd, unsigned int timeout_sec) { struct timeval timeout = { .tv_sec = timeout_sec }; @@ -369,9 +384,10 @@ static inline int socket_loopback(int family, int sotype) static inline int create_pair(int family, int sotype, int *p0, int *p1) {
- __close_fd int s, c = -1, p = -1; struct sockaddr_storage addr; socklen_t len = sizeof(addr);
- int s, c, p, err;
- int err;
s = socket_loopback(family, sotype); if (s < 0) @@ -379,25 +395,23 @@ static inline int create_pair(int family, int sotype, int *p0, int *p1) err = xgetsockname(s, sockaddr(&addr), &len); if (err)
goto close_s;
return err;
c = xsocket(family, sotype, 0);
- if (c < 0) {
err = c;
goto close_s;
- }
- if (c < 0)
return c;
err = connect(c, sockaddr(&addr), len); if (err) { if (errno != EINPROGRESS) { FAIL_ERRNO("connect");
goto close_c;
}return err;
err = poll_connect(c, IO_TIMEOUT_SEC); if (err) { FAIL_ERRNO("poll_connect");
goto close_c;
} }return err;
@@ -405,36 +419,29 @@ static inline int create_pair(int family, int sotype, int *p0, int *p1) case SOCK_DGRAM: err = xgetsockname(c, sockaddr(&addr), &len); if (err)
goto close_c;
return err;
err = xconnect(s, sockaddr(&addr), len);
if (!err) {
*p0 = s;
*p1 = c;
if (err) return err;
}
break; case SOCK_STREAM: case SOCK_SEQPACKET: p = xaccept_nonblock(s, NULL, NULL);*p0 = take_fd(s);
if (p >= 0) {
*p0 = p;
*p1 = c;
goto close_s;
}
if (p < 0)
return p;
err = p;
break; default: FAIL("Unsupported socket type %#x", sotype);*p0 = take_fd(p);
err = -EOPNOTSUPP;
}return -EOPNOTSUPP;
-close_c:
- close(c);
-close_s:
- close(s);
- return err;
- *p1 = take_fd(c);
- return 0;
} static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
This turned out nice and readable, IMHO.