kernel/libnetwork: add accept4() from POSIX.1/2024

* add SOCK_NONBLOCK and SOCK_CLOEXEC
* also extends the type parameter on socketpair() and socket()

Change-Id: I73570d5bfb57c2da00c1086149c9f07547ba61ce
Reviewed-on: https://review.haiku-os.org/c/haiku/+/8515
Tested-by: Commit checker robot <no-reply+buildbot@haiku-os.org>
Reviewed-by: waddlesplash <waddlesplash@gmail.com>
This commit is contained in:
Jérôme Duval 2024-10-30 16:25:28 +01:00 committed by waddlesplash
parent 966076b273
commit 6beff0d163
10 changed files with 153 additions and 21 deletions

View File

@ -47,6 +47,9 @@ typedef uint8_t sa_family_t;
#define SOCK_SEQPACKET 5
#define SOCK_MISC 255
#define SOCK_NONBLOCK 0x00040000
#define SOCK_CLOEXEC 0x00080000
/* Socket options for SOL_SOCKET level */
#define SOL_SOCKET -1
@ -163,6 +166,7 @@ extern "C" {
#endif
int accept(int socket, struct sockaddr *address, socklen_t *_addressLength);
int accept4(int socket, struct sockaddr *address, socklen_t *_addressLength, int flags);
int bind(int socket, const struct sockaddr *address,
socklen_t addressLength);
int connect(int socket, const struct sockaddr *address,

View File

@ -264,7 +264,7 @@ status_t _user_connect(int socket, const struct sockaddr *address,
socklen_t addressLength);
status_t _user_listen(int socket, int backlog);
int _user_accept(int socket, struct sockaddr *address,
socklen_t *_addressLength);
socklen_t *_addressLength, int flags);
ssize_t _user_recv(int socket, void *data, size_t length, int flags);
ssize_t _user_recvfrom(int socket, void *data, size_t length, int flags,
struct sockaddr *address, socklen_t *_addressLength);

View File

@ -365,7 +365,7 @@ extern status_t _kern_connect(int socket, const struct sockaddr *address,
socklen_t addressLength);
extern status_t _kern_listen(int socket, int backlog);
extern int _kern_accept(int socket, struct sockaddr *address,
socklen_t *_addressLength);
socklen_t *_addressLength, int flags);
extern ssize_t _kern_recv(int socket, void *data, size_t length,
int flags);
extern ssize_t _kern_recvfrom(int socket, void *data, size_t length,

View File

@ -179,7 +179,7 @@ TypeHandlerImpl<const char*>::GetReturnValue(Context &context, uint64 value)
return read_string(context, (void *)value);
}
// #pragma mark - enums, flags
// #pragma mark - enums, flags, enum_flags
EnumTypeHandler::EnumTypeHandler(const EnumMap &m) : fMap(m) {}
@ -258,6 +258,43 @@ FlagsTypeHandler::RenderValue(Context &context, unsigned int value) const
return context.FormatUnsigned(value);
}
EnumFlagsTypeHandler::EnumFlagsTypeHandler(const EnumMap &m, const FlagsTypeHandler::FlagsList &l)
:
EnumTypeHandler(m), fList(l) {}
string
EnumFlagsTypeHandler::RenderValue(Context &context, unsigned int value) const
{
if (context.GetContents(Context::ENUMERATIONS)) {
string rendered;
FlagsTypeHandler::FlagsList::const_reverse_iterator i = fList.rbegin();
for (; i != fList.rend(); i++) {
if (value == 0)
break;
if ((value & i->value) != i->value)
continue;
if (!rendered.empty())
rendered.insert(0, "|");
rendered.insert(0, i->name);
value &= ~(i->value);
}
EnumMap::const_iterator j = fMap.find(value);
if (j != fMap.end() && j->second != NULL) {
if (!rendered.empty())
rendered.insert(0, "|");
rendered.insert(0, j->second);
}
return rendered;
}
return context.FormatUnsigned(value);
}
TypeHandlerSelector::TypeHandlerSelector(const SelectMap &m, int sibling,
TypeHandler *def)
: fMap(m), fSibling(sibling), fDefault(def) {}

View File

@ -35,7 +35,7 @@ public:
virtual string GetReturnValue(Context &, uint64 value) = 0;
};
class EnumTypeHandler : public TypeHandler {
class EnumTypeHandler : virtual public TypeHandler {
public:
typedef std::map<int, const char *> EnumMap;
@ -44,13 +44,13 @@ public:
string GetParameterValue(Context &c, Parameter *, const void *);
string GetReturnValue(Context &, uint64 value);
string RenderValue(Context &, unsigned int value) const;
virtual string RenderValue(Context &, unsigned int value) const;
private:
protected:
const EnumMap &fMap;
};
class FlagsTypeHandler : public TypeHandler {
class FlagsTypeHandler : virtual public TypeHandler {
public:
struct FlagInfo {
unsigned int value;
@ -63,12 +63,22 @@ public:
string GetParameterValue(Context &c, Parameter *, const void *);
string GetReturnValue(Context &, uint64 value);
string RenderValue(Context &, unsigned int value) const;
virtual string RenderValue(Context &, unsigned int value) const;
private:
const FlagsList &fList;
};
class EnumFlagsTypeHandler : public EnumTypeHandler {
public:
EnumFlagsTypeHandler(const EnumMap &, const FlagsTypeHandler::FlagsList &);
string RenderValue(Context &, unsigned int value) const;
private:
const FlagsTypeHandler::FlagsList &fList;
};
// currently limited to select ints
class TypeHandlerSelector : public TypeHandler {
public:

View File

@ -79,10 +79,19 @@ static const enum_info kShutdownHow[] = {
};
static const FlagsTypeHandler::FlagInfo kSocketFlagInfos[] = {
FLAG_INFO_ENTRY(SOCK_NONBLOCK),
FLAG_INFO_ENTRY(SOCK_CLOEXEC),
{ 0, NULL }
};
static FlagsTypeHandler::FlagsList kRecvFlags;
static EnumTypeHandler::EnumMap kSocketFamilyMap;
static EnumTypeHandler::EnumMap kSocketTypeMap;
static EnumTypeHandler::EnumMap kShutdownHowMap;
static FlagsTypeHandler::FlagsList kSocketFlags;
void
@ -100,6 +109,9 @@ patch_network()
for (int i = 0; kShutdownHow[i].name != NULL; i++) {
kShutdownHowMap[kShutdownHow[i].index] = kShutdownHow[i].name;
}
for (int i = 0; kSocketFlagInfos[i].name != NULL; i++) {
kSocketFlags.push_back(kSocketFlagInfos[i]);
}
Syscall *recv = get_syscall("_kern_recv");
recv->GetParameter("flags")->SetHandler(new FlagsTypeHandler(kRecvFlags));
@ -118,7 +130,7 @@ patch_network()
socket->GetParameter("family")->SetHandler(
new EnumTypeHandler(kSocketFamilyMap));
socket->GetParameter("type")->SetHandler(
new EnumTypeHandler(kSocketTypeMap));
new EnumFlagsTypeHandler(kSocketTypeMap, kSocketFlags));
Syscall *shutdown = get_syscall("_kern_shutdown_socket");
shutdown->GetParameter("how")->SetHandler(
@ -130,5 +142,8 @@ patch_network()
socketPair->GetParameter("family")->SetHandler(
new EnumTypeHandler(kSocketFamilyMap));
socketPair->GetParameter("type")->SetHandler(
new EnumTypeHandler(kSocketTypeMap));
new EnumFlagsTypeHandler(kSocketTypeMap, kSocketFlags));
Syscall *accept = get_syscall("_kern_accept");
accept->GetParameter("flags")->SetHandler(new FlagsTypeHandler(kSocketFlags));
}

View File

@ -358,7 +358,7 @@ get_socket_descriptor(int fd, bool kernel, file_descriptor*& descriptor)
static int
create_socket_fd(net_socket* socket, bool kernel)
create_socket_fd(net_socket* socket, int flags, bool kernel)
{
// Get the socket's non-blocking flag, so we can set the respective
// open mode flag.
@ -368,6 +368,11 @@ create_socket_fd(net_socket* socket, bool kernel)
SO_NONBLOCK, &nonBlock, &nonBlockLen);
if (error != B_OK)
return error;
int oflags = 0;
if ((flags & SOCK_CLOEXEC) != 0)
oflags |= O_CLOEXEC;
if ((flags & SOCK_NONBLOCK) != 0 || nonBlock)
oflags |= O_NONBLOCK;
// allocate a file descriptor
file_descriptor* descriptor = alloc_fd();
@ -377,15 +382,20 @@ create_socket_fd(net_socket* socket, bool kernel)
// init it
descriptor->ops = &sSocketFDOps;
descriptor->cookie = socket;
descriptor->open_mode = O_RDWR | (nonBlock ? O_NONBLOCK : 0);
descriptor->open_mode = O_RDWR | oflags;
// publish it
int fd = new_fd(get_current_io_context(kernel), descriptor);
io_context* context = get_current_io_context(kernel);
int fd = new_fd(context, descriptor);
if (fd < 0) {
descriptor->ops = NULL;
put_fd(descriptor);
}
mutex_lock(&context->io_mutex);
fd_set_close_on_exec(context, fd, (oflags & O_CLOEXEC) != 0);
mutex_unlock(&context->io_mutex);
return fd;
}
@ -399,6 +409,9 @@ common_socket(int family, int type, int protocol, bool kernel)
if (!get_stack_interface_module())
return B_UNSUPPORTED;
int sflags = type & (SOCK_CLOEXEC | SOCK_NONBLOCK);
type &= ~(SOCK_CLOEXEC | SOCK_NONBLOCK);
// create the socket
net_socket* socket;
status_t error = sStackInterface->open(family, type, protocol, &socket);
@ -408,7 +421,7 @@ common_socket(int family, int type, int protocol, bool kernel)
}
// allocate the FD
int fd = create_socket_fd(socket, kernel);
int fd = create_socket_fd(socket, sflags, kernel);
if (fd < 0) {
sStackInterface->close(socket);
sStackInterface->free(socket);
@ -467,13 +480,16 @@ common_listen(int fd, int backlog, bool kernel)
static int
common_accept(int fd, struct sockaddr *address, socklen_t *_addressLength,
common_accept(int fd, struct sockaddr *address, socklen_t *_addressLength, int flags,
bool kernel)
{
file_descriptor* descriptor;
GET_SOCKET_FD_OR_RETURN(fd, kernel, descriptor);
FileDescriptorPutter _(descriptor);
if ((flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK)) != 0)
RETURN_AND_SET_ERRNO(B_BAD_VALUE);
net_socket* acceptedSocket;
status_t error = sStackInterface->accept(FD_SOCKET(descriptor), address,
_addressLength, &acceptedSocket);
@ -481,7 +497,7 @@ common_accept(int fd, struct sockaddr *address, socklen_t *_addressLength,
return error;
// allocate the FD
int acceptedFD = create_socket_fd(acceptedSocket, kernel);
int acceptedFD = create_socket_fd(acceptedSocket, flags, kernel);
if (acceptedFD < 0) {
sStackInterface->close(acceptedSocket);
sStackInterface->free(acceptedSocket);
@ -633,6 +649,9 @@ common_socketpair(int family, int type, int protocol, int fds[2], bool kernel)
if (!get_stack_interface_module())
return B_UNSUPPORTED;
int sflags = type & (SOCK_CLOEXEC | SOCK_NONBLOCK);
type &= ~(SOCK_CLOEXEC | SOCK_NONBLOCK);
net_socket* sockets[2];
status_t error = sStackInterface->socketpair(family, type, protocol,
sockets);
@ -643,7 +662,7 @@ common_socketpair(int family, int type, int protocol, int fds[2], bool kernel)
// allocate the FDs
for (int i = 0; i < 2; i++) {
fds[i] = create_socket_fd(sockets[i], kernel);
fds[i] = create_socket_fd(sockets[i], sflags, kernel);
if (fds[i] < 0) {
sStackInterface->close(sockets[i]);
sStackInterface->free(sockets[i]);
@ -719,7 +738,15 @@ int
accept(int socket, struct sockaddr *address, socklen_t *_addressLength)
{
SyscallFlagUnsetter _;
RETURN_AND_SET_ERRNO(common_accept(socket, address, _addressLength, true));
RETURN_AND_SET_ERRNO(common_accept(socket, address, _addressLength, 0, true));
}
int
accept4(int socket, struct sockaddr *address, socklen_t *_addressLength, int flags)
{
SyscallFlagUnsetter _;
RETURN_AND_SET_ERRNO(common_accept(socket, address, _addressLength, flags, true));
}
@ -907,7 +934,7 @@ _user_listen(int socket, int backlog)
int
_user_accept(int socket, struct sockaddr *userAddress,
socklen_t *_addressLength)
socklen_t *_addressLength, int flags)
{
// check parameters
socklen_t addressLength = 0;
@ -922,7 +949,7 @@ _user_accept(int socket, struct sockaddr *userAddress,
char address[MAX_SOCKET_ADDRESS_LENGTH];
socklen_t userAddressBufferSize = addressLength;
result = common_accept(socket,
userAddress != NULL ? (sockaddr*)address : NULL, &addressLength, false);
userAddress != NULL ? (sockaddr*)address : NULL, &addressLength, flags, false);
// copy address size and address back to userland
if (copy_address_to_userland(address, addressLength, userAddress,

View File

@ -191,6 +191,13 @@ listen(int socket, int backlog)
extern "C" int
accept(int socket, struct sockaddr *_address, socklen_t *_addressLength)
{
return accept4(socket, _address, _addressLength, 0);
}
extern "C" int
accept4(int socket, struct sockaddr *_address, socklen_t *_addressLength, int flags)
{
bool r5compatible = check_r5_compatibility();
struct sockaddr haikuAddr;
@ -206,7 +213,7 @@ accept(int socket, struct sockaddr *_address, socklen_t *_addressLength)
addressLength = _addressLength ? *_addressLength : 0;
}
int acceptSocket = _kern_accept(socket, address, &addressLength);
int acceptSocket = _kern_accept(socket, address, &addressLength, flags);
pthread_testcancel();

View File

@ -25,6 +25,9 @@ SimpleTest unix_send_test : unix_send_test.c : $(TARGET_NETWORK_LIBS) ;
SimpleTest tcp_connection_test : tcp_connection_test.cpp
: $(TARGET_NETWORK_LIBS) ;
SimpleTest test4 : test4.c
: $(TARGET_NETWORK_LIBS) ;
SubInclude HAIKU_TOP src tests system network icmp ;
SubInclude HAIKU_TOP src tests system network ipv6 ;
SubInclude HAIKU_TOP src tests system network multicast ;

View File

@ -0,0 +1,29 @@
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <fcntl.h>
int main(int argc, char **argv)
{
int sock = socket(AF_INET, SOCK_DGRAM | SOCK_CLOEXEC, 0);
if (sock < 0) {
printf("Failed! Socket could not be created.\n");
return -1;
}
int flags = fcntl(sock, F_GETFD);
int ret = 0;
if ((flags & FD_CLOEXEC) == 0) {
printf("Failed! Descriptor flag not found.\n");
ret = -1;
}
close(sock);
printf("Test complete.\n");
return ret;
}