Commit f5242405 authored by Alexandre Julliard's avatar Alexandre Julliard

New mechanism to transfer file descriptors from client to server.

parent 0ba59090
......@@ -184,10 +184,14 @@ void FILE_SetDosError(void)
HANDLE FILE_DupUnixHandle( int fd, DWORD access )
{
HANDLE ret;
wine_server_send_fd( fd );
SERVER_START_REQ( alloc_file_handle )
{
req->access = access;
SERVER_CALL_FD( fd );
req->fd = fd;
SERVER_CALL();
ret = req->handle;
}
SERVER_END_REQ;
......
......@@ -118,6 +118,12 @@ typedef struct
union debug_event_data info; /* event information */
} debug_event_t;
/* structure used in sending an fd from client to server */
struct send_fd
{
void *tid; /* thread id */
int fd; /* file descriptor on client-side */
};
/* Create a new process from the context of the parent */
struct new_process_request
......@@ -539,6 +545,7 @@ struct alloc_file_handle_request
{
REQUEST_HEADER; /* request header */
IN unsigned int access; /* wanted access rights */
IN int fd; /* file descriptor on the client side */
OUT handle_t handle; /* handle to the file */
};
......@@ -1592,7 +1599,7 @@ union generic_request
struct async_result_request async_result;
};
#define SERVER_PROTOCOL_VERSION 39
#define SERVER_PROTOCOL_VERSION 40
/* ### make_requests end ### */
/* Everything above this line is generated automatically by tools/make_requests */
......@@ -1611,10 +1618,10 @@ union generic_request
/* client communication functions */
extern unsigned int wine_server_call( union generic_request *req, size_t size );
extern unsigned int wine_server_call_fd( union generic_request *req, size_t size, int fd );
extern void server_protocol_error( const char *err, ... ) WINE_NORETURN;
extern void server_protocol_perror( const char *err ) WINE_NORETURN;
extern void wine_server_alloc_req( union generic_request *req, size_t size );
extern void wine_server_send_fd( int fd );
extern int wine_server_recv_fd( int handle, int cache );
extern const char *get_config_dir(void);
......@@ -1685,7 +1692,6 @@ struct __server_exception_frame
#define SERVER_CALL() (wine_server_call( &__req, sizeof(*req) ))
#define SERVER_CALL_ERR() (__server_call_err( &__req, sizeof(*req) ))
#define SERVER_CALL_FD(fd) (wine_server_call_fd( &__req, sizeof(*req), (fd) ))
extern int CLIENT_InitServer(void);
......
......@@ -55,7 +55,8 @@ struct cmsg_fd
};
static void *boot_thread_id;
static sigset_t block_set; /* signals to block during server calls */
static int fd_socket; /* socket to exchange file descriptors with the server */
/* die on a fatal error; use only during initialization */
static void fatal_error( const char *err, ... ) WINE_NORETURN;
......@@ -159,47 +160,6 @@ static void send_request( union generic_request *request )
}
/***********************************************************************
* send_request_fd
*
* Send a request to the server, passing a file descriptor.
*/
static void send_request_fd( union generic_request *request, int fd )
{
#ifndef HAVE_MSGHDR_ACCRIGHTS
struct cmsg_fd cmsg;
#endif
struct msghdr msghdr;
struct iovec vec;
int ret;
vec.iov_base = (void *)request;
vec.iov_len = sizeof(*request);
msghdr.msg_name = NULL;
msghdr.msg_namelen = 0;
msghdr.msg_iov = &vec;
msghdr.msg_iovlen = 1;
#ifdef HAVE_MSGHDR_ACCRIGHTS
msghdr.msg_accrights = (void *)&fd;
msghdr.msg_accrightslen = sizeof(fd);
#else /* HAVE_MSGHDR_ACCRIGHTS */
cmsg.len = sizeof(cmsg);
cmsg.level = SOL_SOCKET;
cmsg.type = SCM_RIGHTS;
cmsg.fd = fd;
msghdr.msg_control = &cmsg;
msghdr.msg_controllen = sizeof(cmsg);
msghdr.msg_flags = 0;
#endif /* HAVE_MSGHDR_ACCRIGHTS */
if ((ret = sendmsg( NtCurrentTeb()->socket, &msghdr, 0 )) == sizeof(*request)) return;
if (ret >= 0) server_protocol_error( "partial write %d\n", ret );
if (errno == EPIPE) SYSDEPS_ExitThread(0);
server_protocol_perror( "sendmsg" );
}
/***********************************************************************
* wait_reply
*
* Wait for a reply from the server.
......@@ -230,24 +190,64 @@ static void wait_reply( union generic_request *req )
*/
unsigned int wine_server_call( union generic_request *req, size_t size )
{
sigset_t old_set;
memset( (char *)req + size, 0, sizeof(*req) - size );
sigprocmask( SIG_BLOCK, &block_set, &old_set );
send_request( req );
wait_reply( req );
sigprocmask( SIG_SETMASK, &old_set, NULL );
return req->header.error;
}
/***********************************************************************
* wine_server_call_fd
* wine_server_send_fd
*
* Perform a server call, passing a file descriptor.
* Send a file descriptor to the server.
*/
unsigned int wine_server_call_fd( union generic_request *req, size_t size, int fd )
void wine_server_send_fd( int fd )
{
memset( (char *)req + size, 0, sizeof(*req) - size );
send_request_fd( req, fd );
wait_reply( req );
return req->header.error;
#ifndef HAVE_MSGHDR_ACCRIGHTS
struct cmsg_fd cmsg;
#endif
struct send_fd data;
struct msghdr msghdr;
struct iovec vec;
int ret;
vec.iov_base = (void *)&data;
vec.iov_len = sizeof(data);
msghdr.msg_name = NULL;
msghdr.msg_namelen = 0;
msghdr.msg_iov = &vec;
msghdr.msg_iovlen = 1;
#ifdef HAVE_MSGHDR_ACCRIGHTS
msghdr.msg_accrights = (void *)&fd;
msghdr.msg_accrightslen = sizeof(fd);
#else /* HAVE_MSGHDR_ACCRIGHTS */
cmsg.len = sizeof(cmsg);
cmsg.level = SOL_SOCKET;
cmsg.type = SCM_RIGHTS;
cmsg.fd = fd;
msghdr.msg_control = &cmsg;
msghdr.msg_controllen = sizeof(cmsg);
msghdr.msg_flags = 0;
#endif /* HAVE_MSGHDR_ACCRIGHTS */
data.tid = (void *)GetCurrentThreadId();
data.fd = fd;
for (;;)
{
if ((ret = sendmsg( fd_socket, &msghdr, 0 )) == sizeof(data)) return;
if (ret >= 0) server_protocol_error( "partial write %d\n", ret );
if (errno == EINTR) continue;
if (errno == EPIPE) SYSDEPS_ExitThread(0);
server_protocol_perror( "sendmsg" );
}
}
......@@ -563,6 +563,7 @@ int CLIENT_InitServer(void)
/* connect to the server */
fd = server_connect( oldcwd, serverdir );
fd_socket = dup(fd);
/* switch back to the starting directory */
if (oldcwd)
......@@ -570,6 +571,14 @@ int CLIENT_InitServer(void)
chdir( oldcwd );
free( oldcwd );
}
/* setup the signal mask */
sigemptyset( &block_set );
sigaddset( &block_set, SIGALRM );
sigaddset( &block_set, SIGIO );
sigaddset( &block_set, SIGINT );
sigaddset( &block_set, SIGHUP );
return fd;
}
......
......@@ -403,6 +403,12 @@ static int set_file_time( handle_t handle, time_t access_time, time_t write_time
if (!(file = get_file_obj( current->process, handle, GENERIC_WRITE )))
return 0;
if (!file->name)
{
set_error( STATUS_INVALID_HANDLE );
release_object( file );
return 0;
}
if (!access_time || !write_time)
{
struct stat st;
......@@ -453,19 +459,19 @@ DECL_HANDLER(create_file)
DECL_HANDLER(alloc_file_handle)
{
struct file *file;
int fd;
req->handle = 0;
if (current->pass_fd != -1)
if ((fd = thread_get_inflight_fd( current, req->fd )) == -1)
{
if ((file = create_file_for_fd( current->pass_fd, req->access,
FILE_SHARE_READ | FILE_SHARE_WRITE, 0 )))
set_error( STATUS_INVALID_HANDLE );
return;
}
if ((file = create_file_for_fd( fd, req->access, FILE_SHARE_READ | FILE_SHARE_WRITE, 0 )))
{
req->handle = alloc_handle( current->process, file, req->access, 0 );
release_object( file );
}
current->pass_fd = -1;
}
else set_error( STATUS_INVALID_PARAMETER );
}
/* get a Unix fd to access a file */
......
......@@ -36,6 +36,7 @@ static int running_processes;
static void process_dump( struct object *obj, int verbose );
static int process_signaled( struct object *obj, struct thread *thread );
static void process_poll_event( struct object *obj, int event );
static void process_destroy( struct object *obj );
static const struct object_ops process_ops =
......@@ -47,7 +48,7 @@ static const struct object_ops process_ops =
process_signaled, /* signaled */
no_satisfied, /* satisfied */
NULL, /* get_poll_events */
NULL, /* poll_event */
process_poll_event, /* poll_event */
no_get_fd, /* get_fd */
no_flush, /* flush */
no_get_file_info, /* get_file_info */
......@@ -147,11 +148,7 @@ struct thread *create_process( int fd )
struct process *process;
struct thread *thread = NULL;
if (!(process = alloc_object( &process_ops, -1 )))
{
close( fd );
return NULL;
}
if (!(process = alloc_object( &process_ops, fd ))) return NULL;
process->next = NULL;
process->prev = NULL;
process->thread_list = NULL;
......@@ -183,8 +180,9 @@ struct thread *create_process( int fd )
if (!(process->init_event = create_event( NULL, 0, 1, 0 ))) goto error;
/* create the main thread */
if (!(thread = create_thread( fd, process ))) goto error;
if (!(thread = create_thread( dup(fd), process ))) goto error;
set_select_events( &process->obj, POLLIN ); /* start listening to events */
release_object( process );
return thread;
......@@ -311,6 +309,15 @@ static int process_signaled( struct object *obj, struct thread *thread )
}
static void process_poll_event( struct object *obj, int event )
{
struct process *process = (struct process *)obj;
assert( obj->ops == &process_ops );
if (event & (POLLERR | POLLHUP)) set_select_events( obj, -1 );
else if (event & POLLIN) receive_fd( process );
}
static void startup_info_destroy( struct object *obj )
{
struct startup_info *info = (struct startup_info *)obj;
......@@ -491,7 +498,7 @@ void resume_process( struct process *process )
}
/* kill a process on the spot */
static void kill_process( struct process *process, struct thread *skip, int exit_code )
void kill_process( struct process *process, struct thread *skip, int exit_code )
{
struct thread *thread = process->thread_list;
while (thread)
......
......@@ -82,6 +82,7 @@ extern void remove_process_thread( struct process *process,
struct thread *thread );
extern void suspend_process( struct process *process );
extern void resume_process( struct process *process );
extern void kill_process( struct process *process, struct thread *skip, int exit_code );
extern void kill_debugged_processes( struct thread *debugger, int exit_code );
extern struct process_snapshot *process_snap( int *count );
extern struct module_snapshot *module_snap( struct process *process, int *count );
......
......@@ -204,8 +204,6 @@ void send_reply( struct thread *thread, union generic_request *request )
{
int ret;
assert (thread->pass_fd == -1);
if (debug_level) trace_reply( thread, request );
request->header.error = thread->error;
......@@ -221,43 +219,68 @@ void send_reply( struct thread *thread, union generic_request *request )
}
}
/* read a message from a client that has something to say */
void read_request( struct thread *thread )
/* receive a file descriptor on the process socket */
int receive_fd( struct process *process )
{
union generic_request req;
int ret;
struct send_fd data;
int fd, ret;
#ifdef HAVE_MSGHDR_ACCRIGHTS
msghdr.msg_accrightslen = sizeof(int);
msghdr.msg_accrights = (void *)&thread->pass_fd;
msghdr.msg_accrights = (void *)&fd;
#else /* HAVE_MSGHDR_ACCRIGHTS */
msghdr.msg_control = &cmsg;
msghdr.msg_controllen = sizeof(cmsg);
cmsg.fd = -1;
#endif /* HAVE_MSGHDR_ACCRIGHTS */
assert( thread->pass_fd == -1 );
myiovec.iov_base = &data;
myiovec.iov_len = sizeof(data);
myiovec.iov_base = &req;
myiovec.iov_len = sizeof(req);
ret = recvmsg( thread->obj.fd, &msghdr, 0 );
ret = recvmsg( process->obj.fd, &msghdr, 0 );
#ifndef HAVE_MSGHDR_ACCRIGHTS
thread->pass_fd = cmsg.fd;
fd = cmsg.fd;
#endif
if (ret == sizeof(req))
if (ret == sizeof(data))
{
call_req_handler( thread, &req );
thread->pass_fd = -1;
return;
struct thread *thread = get_thread_from_id( data.tid );
if (!thread || thread->process != process)
{
if (debug_level)
fprintf( stderr, "%08x: *fd* %d <- %d bad thread id\n",
(unsigned int)data.tid, data.fd, fd );
close( fd );
}
if (!ret) /* closed pipe */
kill_thread( thread, 0 );
else if (ret > 0)
fatal_protocol_error( thread, "partial recvmsg %d\n", ret );
else
fatal_protocol_perror( thread, "recvmsg" );
{
if (debug_level)
fprintf( stderr, "%08x: *fd* %d <- %d\n",
(unsigned int)thread, data.fd, fd );
thread_add_inflight_fd( thread, data.fd, fd );
}
return 0;
}
if (!ret)
{
set_select_events( &process->obj, -1 ); /* stop waiting on it */
}
else if (ret > 0)
{
fprintf( stderr, "Protocol error: process %p: partial recvmsg %d for fd\n", process, ret );
kill_process( process, NULL, 1 );
}
else if (ret < 0)
{
if (errno != EWOULDBLOCK && errno != EAGAIN)
{
fprintf( stderr, "Protocol error: process %p: ", process );
perror( "recvmsg" );
kill_process( process, NULL, 1 );
}
}
return -1;
}
/* send the wakeup signal to a thread */
......@@ -280,7 +303,7 @@ int send_client_fd( struct thread *thread, int fd, handle_t handle )
int ret;
if (debug_level)
fprintf( stderr, "%08x: *fd* %d = %d\n", (unsigned int)thread, handle, fd );
fprintf( stderr, "%08x: *fd* %d -> %d\n", (unsigned int)thread, handle, fd );
#ifdef HAVE_MSGHDR_ACCRIGHTS
msghdr.msg_accrightslen = sizeof(fd);
......
......@@ -31,7 +31,7 @@ extern void fatal_protocol_error( struct thread *thread, const char *err, ... );
extern void fatal_error( const char *err, ... ) WINE_NORETURN;
extern void fatal_perror( const char *err, ... ) WINE_NORETURN;
extern const char *get_config_dir(void);
extern void read_request( struct thread *thread );
extern int receive_fd( struct process *process );
extern int send_thread_wakeup( struct thread *thread, int signaled );
extern int send_client_fd( struct thread *thread, int fd, handle_t handle );
extern void send_reply( struct thread *thread, union generic_request *request );
......
......@@ -61,7 +61,7 @@ struct thread_apc
static void dump_thread( struct object *obj, int verbose );
static int thread_signaled( struct object *obj, struct thread *thread );
extern void thread_poll_event( struct object *obj, int event );
static void thread_poll_event( struct object *obj, int event );
static void destroy_thread( struct object *obj );
static struct thread_apc *thread_dequeue_apc( struct thread *thread, int system_only );
......@@ -139,15 +139,10 @@ static int alloc_client_buffer( struct thread *thread )
return 0;
}
/* create a new thread */
struct thread *create_thread( int fd, struct process *process )
/* initialize the structure for a newly allocated thread */
inline static void init_thread_structure( struct thread *thread )
{
struct thread *thread;
int flags = fcntl( fd, F_GETFL, 0 );
fcntl( fd, F_SETFL, flags | O_NONBLOCK );
if (!(thread = alloc_object( &thread_ops, fd ))) return NULL;
int i;
thread->unix_pid = 0; /* not known yet */
thread->context = NULL;
......@@ -163,7 +158,6 @@ struct thread *create_thread( int fd, struct process *process )
thread->user_apc.head = NULL;
thread->user_apc.tail = NULL;
thread->error = 0;
thread->pass_fd = -1;
thread->request_fd = NULL;
thread->reply_fd = -1;
thread->wait_fd = -1;
......@@ -177,8 +171,24 @@ struct thread *create_thread( int fd, struct process *process )
thread->suspend = 0;
thread->buffer = (void *)-1;
thread->last_req = REQ_get_thread_buffer;
thread->process = (struct process *)grab_object( process );
for (i = 0; i < MAX_INFLIGHT_FDS; i++)
thread->inflight[i].server = thread->inflight[i].client = -1;
}
/* create a new thread */
struct thread *create_thread( int fd, struct process *process )
{
struct thread *thread;
int flags = fcntl( fd, F_GETFL, 0 );
fcntl( fd, F_SETFL, flags | O_NONBLOCK );
if (!(thread = alloc_object( &thread_ops, fd ))) return NULL;
init_thread_structure( thread );
thread->process = (struct process *)grab_object( process );
if (!current) current = thread;
if (!booting_thread) /* first thread ever */
......@@ -190,7 +200,9 @@ struct thread *create_thread( int fd, struct process *process )
if ((thread->next = first_thread) != NULL) thread->next->prev = thread;
first_thread = thread;
#if 0
set_select_events( &thread->obj, POLLIN ); /* start listening to events */
#endif
if (!alloc_client_buffer( thread )) goto error;
return thread;
......@@ -200,13 +212,41 @@ struct thread *create_thread( int fd, struct process *process )
}
/* handle a client event */
void thread_poll_event( struct object *obj, int event )
static void thread_poll_event( struct object *obj, int event )
{
struct thread *thread = (struct thread *)obj;
assert( obj->ops == &thread_ops );
if (event & (POLLERR | POLLHUP)) kill_thread( thread, 0 );
#if 0
else if (event & POLLIN) read_request( thread );
#endif
}
/* cleanup everything that is no longer needed by a dead thread */
/* used by destroy_thread and kill_thread */
static void cleanup_thread( struct thread *thread )
{
int i;
struct thread_apc *apc;
while ((apc = thread_dequeue_apc( thread, 0 ))) free( apc );
if (thread->buffer != (void *)-1) munmap( thread->buffer, MAX_REQUEST_LENGTH );
if (thread->reply_fd != -1) close( thread->reply_fd );
if (thread->wait_fd != -1) close( thread->wait_fd );
if (thread->request_fd) release_object( thread->request_fd );
for (i = 0; i < MAX_INFLIGHT_FDS; i++)
{
if (thread->inflight[i].client != -1)
{
close( thread->inflight[i].server );
thread->inflight[i].client = thread->inflight[i].server = -1;
}
}
thread->buffer = (void *)-1;
thread->reply_fd = -1;
thread->wait_fd = -1;
thread->request_fd = NULL;
}
/* destroy a thread when its refcount is 0 */
......@@ -224,11 +264,7 @@ static void destroy_thread( struct object *obj )
while ((apc = thread_dequeue_apc( thread, 0 ))) free( apc );
if (thread->info) release_object( thread->info );
if (thread->queue) release_object( thread->queue );
if (thread->buffer != (void *)-1) munmap( thread->buffer, MAX_REQUEST_LENGTH );
if (thread->reply_fd != -1) close( thread->reply_fd );
if (thread->wait_fd != -1) close( thread->wait_fd );
if (thread->pass_fd != -1) close( thread->pass_fd );
if (thread->request_fd) release_object( thread->request_fd );
cleanup_thread( thread );
}
/* dump a thread on stdout for debugging purposes */
......@@ -599,6 +635,62 @@ static struct thread_apc *thread_dequeue_apc( struct thread *thread, int system_
return apc;
}
/* add an fd to the inflight list */
/* return list index, or -1 on error */
int thread_add_inflight_fd( struct thread *thread, int client, int server )
{
int i;
if (server == -1) return -1;
if (client == -1)
{
close( server );
return -1;
}
/* first check if we already have an entry for this fd */
for (i = 0; i < MAX_INFLIGHT_FDS; i++)
if (thread->inflight[i].client == client)
{
close( thread->inflight[i].server );
thread->inflight[i].server = server;
return i;
}
/* now find a free spot to store it */
for (i = 0; i < MAX_INFLIGHT_FDS; i++)
if (thread->inflight[i].client == -1)
{
thread->inflight[i].client = client;
thread->inflight[i].server = server;
return i;
}
return -1;
}
/* get an inflight fd and purge it from the list */
/* the fd must be closed when no longer used */
int thread_get_inflight_fd( struct thread *thread, int client )
{
int i, ret;
if (client == -1) return -1;
do
{
for (i = 0; i < MAX_INFLIGHT_FDS; i++)
{
if (thread->inflight[i].client == client)
{
ret = thread->inflight[i].server;
thread->inflight[i].server = thread->inflight[i].client = -1;
return ret;
}
}
} while (!receive_fd( thread->process )); /* in case it is still in the socket buffer */
return -1;
}
/* retrieve an LDT selector entry */
static void get_selector_entry( struct thread *thread, int entry,
unsigned int *base, unsigned int *limit,
......@@ -649,14 +741,7 @@ void kill_thread( struct thread *thread, int violent_death )
wake_up( &thread->obj, 0 );
detach_thread( thread, violent_death ? SIGTERM : 0 );
remove_select_user( &thread->obj );
release_object( thread->request_fd );
close( thread->reply_fd );
close( thread->wait_fd );
munmap( thread->buffer, MAX_REQUEST_LENGTH );
thread->request_fd = NULL;
thread->reply_fd = -1;
thread->wait_fd = -1;
thread->buffer = (void *)-1;
cleanup_thread( thread );
release_object( thread );
}
......
......@@ -36,6 +36,14 @@ struct apc_queue
struct thread_apc *tail;
};
/* descriptor for fds currently in flight from client to server */
struct inflight_fd
{
int client; /* fd on the client side (or -1 if entry is free) */
int server; /* fd on the server side */
};
#define MAX_INFLIGHT_FDS 16 /* max number of fds in flight per thread */
struct thread
{
struct object obj; /* object header */
......@@ -52,9 +60,9 @@ struct thread
struct thread_wait *wait; /* current wait condition if sleeping */
struct apc_queue system_apc; /* queue of system async procedure calls */
struct apc_queue user_apc; /* queue of user async procedure calls */
struct inflight_fd inflight[MAX_INFLIGHT_FDS]; /* fds currently in flight */
unsigned int error; /* current error code */
struct object *request_fd; /* fd for receiving client requests */
int pass_fd; /* fd to pass to the client */
int reply_fd; /* fd to send a reply to a client */
int wait_fd; /* fd to use to wake a sleeping client */
enum run_state state; /* running state */
......@@ -96,6 +104,8 @@ extern void wake_up( struct object *obj, int max );
extern int thread_queue_apc( struct thread *thread, struct object *owner, void *func,
enum apc_type type, int system, int nb_args, ... );
extern void thread_cancel_apc( struct thread *thread, struct object *owner, int system );
extern int thread_add_inflight_fd( struct thread *thread, int client, int server );
extern int thread_get_inflight_fd( struct thread *thread, int client );
extern struct thread_snapshot *thread_snap( int *count );
/* ptrace functions */
......
......@@ -670,7 +670,8 @@ static void dump_create_file_reply( const struct create_file_request *req )
static void dump_alloc_file_handle_request( const struct alloc_file_handle_request *req )
{
fprintf( stderr, " access=%08x", req->access );
fprintf( stderr, " access=%08x,", req->access );
fprintf( stderr, " fd=%d", req->fd );
}
static void dump_alloc_file_handle_reply( const struct alloc_file_handle_request *req )
......@@ -1889,11 +1890,9 @@ void trace_request( struct thread *thread, const union generic_request *request
fprintf( stderr, "%08x: %s(", (unsigned int)thread, req_names[req] );
cur_pos = 0;
req_dumpers[req]( request );
fprintf( stderr, " )\n" );
}
else
fprintf( stderr, "%08x: %d(", (unsigned int)thread, req );
if (thread->pass_fd != -1) fprintf( stderr, " ) fd=%d\n", thread->pass_fd );
else fprintf( stderr, " )\n" );
else fprintf( stderr, "%08x: %d(???)\n", (unsigned int)thread, req );
}
void trace_reply( struct thread *thread, const union generic_request *request )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment