Commit 3eda59af authored by Paul Gofman's avatar Paul Gofman Committed by Alexandre Julliard

winhttp: Replace pending read cancel in WinHttpWebSocketClose() with a generic cancel_queue().

parent 66ff3594
...@@ -182,7 +182,7 @@ static void CALLBACK task_callback( TP_CALLBACK_INSTANCE *instance, void *ctx ) ...@@ -182,7 +182,7 @@ static void CALLBACK task_callback( TP_CALLBACK_INSTANCE *instance, void *ctx )
task = get_next_task( queue, NULL ); task = get_next_task( queue, NULL );
while (task) while (task)
{ {
task->callback( task ); task->callback( task, FALSE );
/* Queue object may be freed by release_object() unless there is another task referencing it. */ /* Queue object may be freed by release_object() unless there is another task referencing it. */
next_task = get_next_task( queue, task ); next_task = get_next_task( queue, task );
release_object( task->obj ); release_object( task->obj );
...@@ -199,6 +199,7 @@ static DWORD queue_task( struct queue *queue, TASK_CALLBACK task, struct task_he ...@@ -199,6 +199,7 @@ static DWORD queue_task( struct queue *queue, TASK_CALLBACK task, struct task_he
TRACE("queueing %p in %p\n", task_hdr, queue); TRACE("queueing %p in %p\n", task_hdr, queue);
task_hdr->callback = task; task_hdr->callback = task;
task_hdr->completion_sent = 0;
task_hdr->refs = 1; task_hdr->refs = 1;
task_hdr->obj = obj; task_hdr->obj = obj;
addref_object( obj ); addref_object( obj );
...@@ -225,6 +226,35 @@ static DWORD queue_task( struct queue *queue, TASK_CALLBACK task, struct task_he ...@@ -225,6 +226,35 @@ static DWORD queue_task( struct queue *queue, TASK_CALLBACK task, struct task_he
return ERROR_SUCCESS; return ERROR_SUCCESS;
} }
static BOOL task_needs_completion( struct task_header *task_hdr )
{
return !InterlockedExchange( &task_hdr->completion_sent, 1 );
}
static void cancel_queue( struct queue *queue )
{
struct task_header *task_hdr, *found;
while (1)
{
AcquireSRWLockExclusive( &queue->lock );
found = NULL;
LIST_FOR_EACH_ENTRY( task_hdr, &queue->queued_tasks, struct task_header, entry )
{
if (task_needs_completion( task_hdr ))
{
found = task_hdr;
addref_task( found );
break;
}
}
ReleaseSRWLockExclusive( &queue->lock );
if (!found) break;
found->callback( found, TRUE );
release_task( found );
}
}
static void free_header( struct header *header ) static void free_header( struct header *header )
{ {
free( header->field ); free( header->field );
...@@ -2231,11 +2261,13 @@ end: ...@@ -2231,11 +2261,13 @@ end:
return ret; return ret;
} }
static void task_send_request( void *ctx ) static void task_send_request( void *ctx, BOOL abort )
{ {
struct send_request *s = ctx; struct send_request *s = ctx;
struct request *request = (struct request *)s->task_hdr.obj; struct request *request = (struct request *)s->task_hdr.obj;
if (abort) return;
TRACE( "running %p\n", ctx ); TRACE( "running %p\n", ctx );
send_request( request, s->headers, s->headers_len, s->optional, s->optional_len, s->total_len, s->context, TRUE ); send_request( request, s->headers, s->headers_len, s->optional, s->optional_len, s->total_len, s->context, TRUE );
...@@ -2813,11 +2845,13 @@ static DWORD receive_response( struct request *request, BOOL async ) ...@@ -2813,11 +2845,13 @@ static DWORD receive_response( struct request *request, BOOL async )
return ret; return ret;
} }
static void task_receive_response( void *ctx ) static void task_receive_response( void *ctx, BOOL abort )
{ {
struct receive_response *r = ctx; struct receive_response *r = ctx;
struct request *request = (struct request *)r->task_hdr.obj; struct request *request = (struct request *)r->task_hdr.obj;
if (abort) return;
TRACE("running %p\n", ctx); TRACE("running %p\n", ctx);
receive_response( request, TRUE ); receive_response( request, TRUE );
} }
...@@ -2905,11 +2939,13 @@ done: ...@@ -2905,11 +2939,13 @@ done:
return ret; return ret;
} }
static void task_query_data_available( void *ctx ) static void task_query_data_available( void *ctx, BOOL abort )
{ {
struct query_data *q = ctx; struct query_data *q = ctx;
struct request *request = (struct request *)q->task_hdr.obj; struct request *request = (struct request *)q->task_hdr.obj;
if (abort) return;
TRACE("running %p\n", ctx); TRACE("running %p\n", ctx);
query_data_available( request, q->available, TRUE ); query_data_available( request, q->available, TRUE );
} }
...@@ -2956,11 +2992,13 @@ BOOL WINAPI WinHttpQueryDataAvailable( HINTERNET hrequest, LPDWORD available ) ...@@ -2956,11 +2992,13 @@ BOOL WINAPI WinHttpQueryDataAvailable( HINTERNET hrequest, LPDWORD available )
return !ret || ret == ERROR_IO_PENDING; return !ret || ret == ERROR_IO_PENDING;
} }
static void task_read_data( void *ctx ) static void task_read_data( void *ctx, BOOL abort )
{ {
struct read_data *r = ctx; struct read_data *r = ctx;
struct request *request = (struct request *)r->task_hdr.obj; struct request *request = (struct request *)r->task_hdr.obj;
if (abort) return;
TRACE("running %p\n", ctx); TRACE("running %p\n", ctx);
read_data( request, r->buffer, r->to_read, r->read, TRUE ); read_data( request, r->buffer, r->to_read, r->read, TRUE );
} }
...@@ -3031,11 +3069,13 @@ static DWORD write_data( struct request *request, const void *buffer, DWORD to_w ...@@ -3031,11 +3069,13 @@ static DWORD write_data( struct request *request, const void *buffer, DWORD to_w
return ret; return ret;
} }
static void task_write_data( void *ctx ) static void task_write_data( void *ctx, BOOL abort )
{ {
struct write_data *w = ctx; struct write_data *w = ctx;
struct request *request = (struct request *)w->task_hdr.obj; struct request *request = (struct request *)w->task_hdr.obj;
if (abort) return;
TRACE("running %p\n", ctx); TRACE("running %p\n", ctx);
write_data( request, w->buffer, w->to_write, w->written, TRUE ); write_data( request, w->buffer, w->to_write, w->written, TRUE );
} }
...@@ -3315,13 +3355,10 @@ static void send_io_complete( struct object_header *hdr ) ...@@ -3315,13 +3355,10 @@ static void send_io_complete( struct object_header *hdr )
} }
/* returns FALSE if sending callback should be omitted. */ /* returns FALSE if sending callback should be omitted. */
static BOOL receive_io_complete( struct socket *socket ) static void receive_io_complete( struct socket *socket )
{ {
LONG count = InterlockedDecrement( &socket->hdr.pending_receives ); LONG count = InterlockedDecrement( &socket->hdr.pending_receives );
assert( count >= 0 || socket->state == SOCKET_STATE_CLOSED); assert( count >= 0 );
/* count is reset to zero during websocket close so if count went negative
* then WinHttpWebSocketClose() is to send the callback. */
return count >= 0;
} }
static BOOL socket_can_send( struct socket *socket ) static BOOL socket_can_send( struct socket *socket )
...@@ -3428,12 +3465,14 @@ static DWORD socket_send( struct socket *socket, WINHTTP_WEB_SOCKET_BUFFER_TYPE ...@@ -3428,12 +3465,14 @@ static DWORD socket_send( struct socket *socket, WINHTTP_WEB_SOCKET_BUFFER_TYPE
return send_frame( socket, opcode, 0, buf, len, final, ovr ); return send_frame( socket, opcode, 0, buf, len, final, ovr );
} }
static void task_socket_send( void *ctx ) static void task_socket_send( void *ctx, BOOL abort )
{ {
struct socket_send *s = ctx; struct socket_send *s = ctx;
struct socket *socket = (struct socket *)s->task_hdr.obj; struct socket *socket = (struct socket *)s->task_hdr.obj;
DWORD ret; DWORD ret;
if (abort) return;
TRACE("running %p\n", ctx); TRACE("running %p\n", ctx);
if (s->complete_async) ret = complete_send_frame( socket, &s->ovr, s->buf ); if (s->complete_async) ret = complete_send_frame( socket, &s->ovr, s->buf );
...@@ -3623,11 +3662,13 @@ static DWORD receive_frame( struct socket *socket, DWORD *ret_len, enum socket_o ...@@ -3623,11 +3662,13 @@ static DWORD receive_frame( struct socket *socket, DWORD *ret_len, enum socket_o
return ERROR_SUCCESS; return ERROR_SUCCESS;
} }
static void task_socket_send_pong( void *ctx ) static void task_socket_send_pong( void *ctx, BOOL abort )
{ {
struct socket_send *s = ctx; struct socket_send *s = ctx;
struct socket *socket = (struct socket *)s->task_hdr.obj; struct socket *socket = (struct socket *)s->task_hdr.obj;
if (abort) return;
TRACE("running %p\n", ctx); TRACE("running %p\n", ctx);
if (s->complete_async) complete_send_frame( socket, &s->ovr, NULL ); if (s->complete_async) complete_send_frame( socket, &s->ovr, NULL );
...@@ -3858,17 +3899,23 @@ static void socket_receive_complete( struct socket *socket, DWORD ret, WINHTTP_W ...@@ -3858,17 +3899,23 @@ static void socket_receive_complete( struct socket *socket, DWORD ret, WINHTTP_W
} }
} }
static void task_socket_receive( void *ctx ) static void task_socket_receive( void *ctx, BOOL abort )
{ {
struct socket_receive *r = ctx; struct socket_receive *r = ctx;
struct socket *socket = (struct socket *)r->task_hdr.obj; struct socket *socket = (struct socket *)r->task_hdr.obj;
DWORD ret, count; DWORD ret, count;
WINHTTP_WEB_SOCKET_BUFFER_TYPE type; WINHTTP_WEB_SOCKET_BUFFER_TYPE type;
if (abort)
{
socket_receive_complete( socket, ERROR_WINHTTP_OPERATION_CANCELLED, 0, 0 );
return;
}
TRACE("running %p\n", ctx); TRACE("running %p\n", ctx);
ret = socket_receive( socket, r->buf, r->len, &count, &type ); ret = socket_receive( socket, r->buf, r->len, &count, &type );
receive_io_complete( socket );
if (receive_io_complete( socket )) if (task_needs_completion( &r->task_hdr ))
socket_receive_complete( socket, ret, type, count ); socket_receive_complete( socket, ret, type, count );
} }
...@@ -3939,12 +3986,14 @@ static void socket_shutdown_complete( struct socket *socket, DWORD ret ) ...@@ -3939,12 +3986,14 @@ static void socket_shutdown_complete( struct socket *socket, DWORD ret )
} }
} }
static void task_socket_shutdown( void *ctx ) static void task_socket_shutdown( void *ctx, BOOL abort )
{ {
struct socket_shutdown *s = ctx; struct socket_shutdown *s = ctx;
struct socket *socket = (struct socket *)s->task_hdr.obj; struct socket *socket = (struct socket *)s->task_hdr.obj;
DWORD ret; DWORD ret;
if (abort) return;
TRACE("running %p\n", ctx); TRACE("running %p\n", ctx);
if (s->complete_async) ret = complete_send_frame( socket, &s->ovr, s->reason ); if (s->complete_async) ret = complete_send_frame( socket, &s->ovr, s->reason );
...@@ -4073,15 +4122,18 @@ static void socket_close_complete( struct socket *socket, DWORD ret ) ...@@ -4073,15 +4122,18 @@ static void socket_close_complete( struct socket *socket, DWORD ret )
} }
} }
static void task_socket_close( void *ctx ) static void task_socket_close( void *ctx, BOOL abort )
{ {
struct socket_shutdown *s = ctx; struct socket_shutdown *s = ctx;
struct socket *socket = (struct socket *)s->task_hdr.obj; struct socket *socket = (struct socket *)s->task_hdr.obj;
DWORD ret; DWORD ret;
if (abort) return;
TRACE("running %p\n", ctx); TRACE("running %p\n", ctx);
ret = socket_close( socket ); ret = socket_close( socket );
receive_io_complete( socket );
socket_close_complete( socket, ret ); socket_close_complete( socket, ret );
} }
...@@ -4113,25 +4165,14 @@ DWORD WINAPI WinHttpWebSocketClose( HINTERNET hsocket, USHORT status, void *reas ...@@ -4113,25 +4165,14 @@ DWORD WINAPI WinHttpWebSocketClose( HINTERNET hsocket, USHORT status, void *reas
if (socket->request->connect->hdr.flags & WINHTTP_FLAG_ASYNC) if (socket->request->connect->hdr.flags & WINHTTP_FLAG_ASYNC)
{ {
/* When closing the socket pending receives are cancelled. Setting socket->hdr.pending_receives to zero pending_receives = InterlockedIncrement( &socket->hdr.pending_receives );
* will prevent pending receives from sending callbacks. */ cancel_queue( &socket->recv_q );
pending_receives = InterlockedExchange( &socket->hdr.pending_receives, 0 );
assert( pending_receives >= 0 );
if (pending_receives)
{
WINHTTP_WEB_SOCKET_ASYNC_RESULT result;
result.AsyncResult.dwResult = 0;
result.AsyncResult.dwError = ERROR_WINHTTP_OPERATION_CANCELLED;
result.Operation = WINHTTP_WEB_SOCKET_RECEIVE_OPERATION;
send_callback( &socket->hdr, WINHTTP_CALLBACK_STATUS_REQUEST_ERROR, &result, sizeof(result) );
}
} }
if (prev_state < SOCKET_STATE_SHUTDOWN if (prev_state < SOCKET_STATE_SHUTDOWN
&& (ret = send_socket_shutdown( socket, status, reason, len, FALSE ))) goto done; && (ret = send_socket_shutdown( socket, status, reason, len, FALSE ))) goto done;
if (!pending_receives && socket->close_frame_received) if (pending_receives == 1 && socket->close_frame_received)
{ {
if (socket->request->connect->hdr.flags & WINHTTP_FLAG_ASYNC) if (socket->request->connect->hdr.flags & WINHTTP_FLAG_ASYNC)
socket_close_complete( socket, socket->close_frame_receive_err ); socket_close_complete( socket, socket->close_frame_receive_err );
...@@ -4144,7 +4185,10 @@ DWORD WINAPI WinHttpWebSocketClose( HINTERNET hsocket, USHORT status, void *reas ...@@ -4144,7 +4185,10 @@ DWORD WINAPI WinHttpWebSocketClose( HINTERNET hsocket, USHORT status, void *reas
if (!(s = calloc( 1, sizeof(*s) ))) return FALSE; if (!(s = calloc( 1, sizeof(*s) ))) return FALSE;
if ((ret = queue_task( &socket->recv_q, task_socket_close, &s->task_hdr, &socket->hdr ))) if ((ret = queue_task( &socket->recv_q, task_socket_close, &s->task_hdr, &socket->hdr )))
{
InterlockedDecrement( &socket->hdr.pending_receives );
free( s ); free( s );
}
} }
else ret = socket_close( socket ); else ret = socket_close( socket );
......
...@@ -274,7 +274,7 @@ struct socket ...@@ -274,7 +274,7 @@ struct socket
BOOL last_receive_final; BOOL last_receive_final;
}; };
typedef void (*TASK_CALLBACK)( void *ctx ); typedef void (*TASK_CALLBACK)( void *ctx, BOOL abort );
struct task_header struct task_header
{ {
...@@ -282,6 +282,7 @@ struct task_header ...@@ -282,6 +282,7 @@ struct task_header
TASK_CALLBACK callback; TASK_CALLBACK callback;
struct object_header *obj; struct object_header *obj;
volatile LONG refs; volatile LONG refs;
volatile LONG completion_sent;
}; };
struct send_request struct send_request
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment