Commit 45b9dccb authored by Zebediah Figura's avatar Zebediah Figura Committed by Alexandre Julliard

ws2_32: Do not assume that an fd_set is bounded by FD_SETSIZE.

parent 40494359
......@@ -2364,13 +2364,8 @@ static int add_fd_to_set( SOCKET fd, struct fd_set *set )
return 0;
}
if (set->fd_count < FD_SETSIZE)
{
set->fd_array[set->fd_count++] = fd;
return 1;
}
return 0;
set->fd_array[set->fd_count++] = fd;
return 1;
}
......@@ -2380,9 +2375,9 @@ static int add_fd_to_set( SOCKET fd, struct fd_set *set )
int WINAPI select( int count, fd_set *read_ptr, fd_set *write_ptr,
fd_set *except_ptr, const struct timeval *timeout)
{
char buffer[offsetof( struct afd_poll_params, sockets[FD_SETSIZE * 3] )] = {0};
struct afd_poll_params *params = (struct afd_poll_params *)buffer;
struct fd_set read_input;
struct fd_set *read_input = NULL;
struct afd_poll_params *params;
unsigned int poll_count = 0;
ULONG params_size, i, j;
SOCKET poll_socket = 0;
IO_STATUS_BLOCK io;
......@@ -2392,22 +2387,49 @@ int WINAPI select( int count, fd_set *read_ptr, fd_set *write_ptr,
TRACE( "read %p, write %p, except %p, timeout %p\n", read_ptr, write_ptr, except_ptr, timeout );
FD_ZERO( &read_input );
if (read_ptr) read_input.fd_count = read_ptr->fd_count;
if (!(sync_event = get_sync_event())) return -1;
if (read_ptr) poll_count += read_ptr->fd_count;
if (write_ptr) poll_count += write_ptr->fd_count;
if (except_ptr) poll_count += except_ptr->fd_count;
if (!poll_count)
{
SetLastError( WSAEINVAL );
return -1;
}
params_size = offsetof( struct afd_poll_params, sockets[poll_count] );
if (!(params = calloc( params_size, 1 )))
{
SetLastError( WSAENOBUFS );
return -1;
}
if (timeout)
params->timeout = timeout->tv_sec * -10000000 + timeout->tv_usec * -10;
else
params->timeout = TIMEOUT_INFINITE;
for (i = 0; i < read_input.fd_count; ++i)
if (read_ptr)
{
params->sockets[params->count].socket = read_input.fd_array[i] = read_ptr->fd_array[i];
params->sockets[params->count].flags = AFD_POLL_READ | AFD_POLL_ACCEPT | AFD_POLL_HUP;
++params->count;
poll_socket = read_input.fd_array[i];
unsigned int read_size = offsetof( struct fd_set, fd_array[read_ptr->fd_count] );
if (!(read_input = malloc( read_size )))
{
free( params );
SetLastError( WSAENOBUFS );
return -1;
}
memcpy( read_input, read_ptr, read_size );
for (i = 0; i < read_ptr->fd_count; ++i)
{
params->sockets[params->count].socket = read_ptr->fd_array[i];
params->sockets[params->count].flags = AFD_POLL_READ | AFD_POLL_ACCEPT | AFD_POLL_HUP;
++params->count;
poll_socket = read_ptr->fd_array[i];
}
}
if (write_ptr)
......@@ -2432,42 +2454,43 @@ int WINAPI select( int count, fd_set *read_ptr, fd_set *write_ptr,
}
}
if (!params->count)
{
SetLastError( WSAEINVAL );
return -1;
}
params_size = offsetof( struct afd_poll_params, sockets[params->count] );
assert( params->count == poll_count );
status = NtDeviceIoControlFile( (HANDLE)poll_socket, sync_event, NULL, NULL, &io,
IOCTL_AFD_POLL, params, params_size, params, params_size );
if (status == STATUS_PENDING)
{
if (WaitForSingleObject( sync_event, INFINITE ) == WAIT_FAILED)
{
free( read_input );
free( params );
return -1;
}
status = io.u.Status;
}
if (status == STATUS_TIMEOUT) status = STATUS_SUCCESS;
if (!status)
{
/* pointers may alias, so clear them all first */
if (read_ptr) FD_ZERO( read_ptr );
if (write_ptr) FD_ZERO( write_ptr );
if (except_ptr) FD_ZERO( except_ptr );
if (read_ptr) read_ptr->fd_count = 0;
if (write_ptr) write_ptr->fd_count = 0;
if (except_ptr) except_ptr->fd_count = 0;
for (i = 0; i < params->count; ++i)
{
unsigned int flags = params->sockets[i].flags;
SOCKET s = params->sockets[i].socket;
for (j = 0; j < read_input.fd_count; ++j)
if (read_input)
{
if (read_input.fd_array[j] == s
&& (flags & (AFD_POLL_READ | AFD_POLL_ACCEPT | AFD_POLL_HUP | AFD_POLL_CLOSE)))
for (j = 0; j < read_input->fd_count; ++j)
{
ret_count += add_fd_to_set( s, read_ptr );
flags &= ~AFD_POLL_CLOSE;
if (read_input->fd_array[j] == s
&& (flags & (AFD_POLL_READ | AFD_POLL_ACCEPT | AFD_POLL_HUP | AFD_POLL_CLOSE)))
{
ret_count += add_fd_to_set( s, read_ptr );
flags &= ~AFD_POLL_CLOSE;
}
}
}
......@@ -2482,6 +2505,9 @@ int WINAPI select( int count, fd_set *read_ptr, fd_set *write_ptr,
}
}
free( read_input );
free( params );
SetLastError( NtStatusToWSAError( status ) );
return status ? -1 : ret_count;
}
......
......@@ -3151,9 +3151,8 @@ static void test_select(void)
{
static char tmp_buf[1024];
SOCKET fdListen, fdRead, fdWrite;
fd_set readfds, writefds, exceptfds, *alloc_readfds;
unsigned int maxfd;
fd_set readfds, writefds, exceptfds, *alloc_fds;
SOCKET fdListen, fdRead, fdWrite, sockets[200];
int ret, len;
char buffer;
struct timeval select_timeout;
......@@ -3161,6 +3160,7 @@ static void test_select(void)
select_thread_params thread_params;
HANDLE thread_handle;
DWORD ticks, id, old_protect;
unsigned int maxfd, i;
char *page_pair;
fdRead = socket(AF_INET, SOCK_STREAM, 0);
......@@ -3339,16 +3339,33 @@ static void test_select(void)
page_pair = VirtualAlloc(NULL, 0x1000 * 2, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE);
VirtualProtect(page_pair + 0x1000, 0x1000, PAGE_NOACCESS, &old_protect);
alloc_readfds = (fd_set *)((page_pair + 0x1000) - offsetof(fd_set, fd_array[1]));
alloc_readfds->fd_count = 1;
alloc_readfds->fd_array[0] = fdRead;
ret = select(fdRead+1, alloc_readfds, NULL, NULL, &select_timeout);
alloc_fds = (fd_set *)((page_pair + 0x1000) - offsetof(fd_set, fd_array[1]));
alloc_fds->fd_count = 1;
alloc_fds->fd_array[0] = fdRead;
ret = select(fdRead+1, alloc_fds, NULL, NULL, &select_timeout);
ok(ret == 1, "select returned %d\n", ret);
VirtualFree(page_pair, 0, MEM_RELEASE);
closesocket(fdRead);
closesocket(fdWrite);
alloc_fds = malloc(offsetof(fd_set, fd_array[ARRAY_SIZE(sockets)]));
alloc_fds->fd_count = ARRAY_SIZE(sockets);
for (i = 0; i < ARRAY_SIZE(sockets); i += 2)
{
tcp_socketpair(&sockets[i], &sockets[i + 1]);
alloc_fds->fd_array[i] = sockets[i];
alloc_fds->fd_array[i + 1] = sockets[i + 1];
}
ret = select(0, NULL, alloc_fds, NULL, &select_timeout);
ok(ret == ARRAY_SIZE(sockets), "got %d\n", ret);
for (i = 0; i < ARRAY_SIZE(sockets); ++i)
{
ok(alloc_fds->fd_array[i] == sockets[i], "got socket %#Ix at index %u\n", alloc_fds->fd_array[i], i);
closesocket(sockets[i]);
}
free(alloc_fds);
/* select() works in 3 distinct states:
* - to check if a connection attempt ended with success or error;
* - to check if a pending connection is waiting for acceptance;
......
......@@ -19,6 +19,7 @@
#ifndef __WINE_WS2_32_PRIVATE_H
#define __WINE_WS2_32_PRIVATE_H
#include <assert.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment