Commit d8948da1 authored by Jacek Caban's avatar Jacek Caban Committed by Alexandre Julliard

wininet: Improved non-blocking mode in secure NETCON_recv.

parent 0767e060
...@@ -86,6 +86,11 @@ ...@@ -86,6 +86,11 @@
#define RESPONSE_TIMEOUT 30 /* FROM internet.c */ #define RESPONSE_TIMEOUT 30 /* FROM internet.c */
#ifdef MSG_DONTWAIT
#define WINE_MSG_DONTWAIT MSG_DONTWAIT
#else
#define WINE_MSG_DONTWAIT 0
#endif
WINE_DEFAULT_DEBUG_CHANNEL(wininet); WINE_DEFAULT_DEBUG_CHANNEL(wininet);
...@@ -755,37 +760,53 @@ DWORD NETCON_send(netconn_t *connection, const void *msg, size_t len, int flags, ...@@ -755,37 +760,53 @@ DWORD NETCON_send(netconn_t *connection, const void *msg, size_t len, int flags,
} }
} }
static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T *ret_size, BOOL *eof) static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, blocking_mode_t mode, SIZE_T *ret_size, BOOL *eof)
{ {
const SIZE_T ssl_buf_size = conn->ssl_sizes.cbHeader+conn->ssl_sizes.cbMaximumMessage+conn->ssl_sizes.cbTrailer; const SIZE_T ssl_buf_size = conn->ssl_sizes.cbHeader+conn->ssl_sizes.cbMaximumMessage+conn->ssl_sizes.cbTrailer;
SecBuffer bufs[4]; SecBuffer bufs[4];
SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs}; SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
SSIZE_T size, buf_len; SSIZE_T size, buf_len = 0;
blocking_mode_t tmp_mode;
int i; int i;
SECURITY_STATUS res; SECURITY_STATUS res;
assert(conn->extra_len < ssl_buf_size); assert(conn->extra_len < ssl_buf_size);
/* BLOCKING_WAITALL is handled by caller */
if(mode == BLOCKING_WAITALL)
mode = BLOCKING_ALLOW;
if(conn->extra_len) { if(conn->extra_len) {
memcpy(conn->ssl_buf, conn->extra_buf, conn->extra_len); memcpy(conn->ssl_buf, conn->extra_buf, conn->extra_len);
buf_len = conn->extra_len; buf_len = conn->extra_len;
conn->extra_len = 0; conn->extra_len = 0;
heap_free(conn->extra_buf); heap_free(conn->extra_buf);
conn->extra_buf = NULL; conn->extra_buf = NULL;
}else { }
buf_len = recv(conn->socket, conn->ssl_buf+conn->extra_len, ssl_buf_size-conn->extra_len, 0);
if(buf_len < 0) {
WARN("recv failed\n");
return FALSE;
}
tmp_mode = buf_len ? BLOCKING_DISALLOW : mode;
set_socket_blocking(conn->socket, tmp_mode);
size = recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, tmp_mode == BLOCKING_ALLOW ? 0 : WINE_MSG_DONTWAIT);
if(size < 0) {
if(!buf_len) { if(!buf_len) {
*eof = TRUE; if(errno == EAGAIN || errno == EWOULDBLOCK) {
return TRUE; TRACE("would block\n");
return WSAEWOULDBLOCK;
}
WARN("recv failed\n");
return ERROR_INTERNET_CONNECTION_ABORTED;
} }
}else {
buf_len += size;
}
*ret_size = buf_len;
if(!buf_len) {
*eof = TRUE;
return ERROR_SUCCESS;
} }
*ret_size = 0;
*eof = FALSE; *eof = FALSE;
do { do {
...@@ -801,19 +822,34 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T * ...@@ -801,19 +822,34 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T *
case SEC_I_CONTEXT_EXPIRED: case SEC_I_CONTEXT_EXPIRED:
TRACE("context expired\n"); TRACE("context expired\n");
*eof = TRUE; *eof = TRUE;
return TRUE; return ERROR_SUCCESS;
case SEC_E_INCOMPLETE_MESSAGE: case SEC_E_INCOMPLETE_MESSAGE:
assert(buf_len < ssl_buf_size); assert(buf_len < ssl_buf_size);
size = recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, 0); set_socket_blocking(conn->socket, mode);
if(size < 1) size = recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, mode == BLOCKING_ALLOW ? 0 : WINE_MSG_DONTWAIT);
return FALSE; if(size < 1) {
if(size < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
TRACE("would block\n");
/* FIXME: Optimize extra_buf usage. */
conn->extra_buf = heap_alloc(buf_len);
if(!conn->extra_buf)
return ERROR_NOT_ENOUGH_MEMORY;
conn->extra_len = buf_len;
memcpy(conn->extra_buf, conn->ssl_buf, conn->extra_len);
return WSAEWOULDBLOCK;
}
return ERROR_INTERNET_CONNECTION_ABORTED;
}
buf_len += size; buf_len += size;
continue; continue;
default: default:
WARN("failed: %08x\n", res); WARN("failed: %08x\n", res);
return FALSE; return ERROR_INTERNET_CONNECTION_ABORTED;
} }
} while(res != SEC_E_OK); } while(res != SEC_E_OK);
...@@ -825,7 +861,7 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T * ...@@ -825,7 +861,7 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T *
assert(!conn->peek_len); assert(!conn->peek_len);
conn->peek_msg_mem = conn->peek_msg = heap_alloc(bufs[i].cbBuffer - size); conn->peek_msg_mem = conn->peek_msg = heap_alloc(bufs[i].cbBuffer - size);
if(!conn->peek_msg) if(!conn->peek_msg)
return FALSE; return ERROR_NOT_ENOUGH_MEMORY;
conn->peek_len = bufs[i].cbBuffer-size; conn->peek_len = bufs[i].cbBuffer-size;
memcpy(conn->peek_msg, (char*)bufs[i].pvBuffer+size, conn->peek_len); memcpy(conn->peek_msg, (char*)bufs[i].pvBuffer+size, conn->peek_len);
} }
...@@ -838,14 +874,14 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T * ...@@ -838,14 +874,14 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T *
if(bufs[i].BufferType == SECBUFFER_EXTRA) { if(bufs[i].BufferType == SECBUFFER_EXTRA) {
conn->extra_buf = heap_alloc(bufs[i].cbBuffer); conn->extra_buf = heap_alloc(bufs[i].cbBuffer);
if(!conn->extra_buf) if(!conn->extra_buf)
return FALSE; return ERROR_NOT_ENOUGH_MEMORY;
conn->extra_len = bufs[i].cbBuffer; conn->extra_len = bufs[i].cbBuffer;
memcpy(conn->extra_buf, bufs[i].pvBuffer, conn->extra_len); memcpy(conn->extra_buf, bufs[i].pvBuffer, conn->extra_len);
} }
} }
return TRUE; return ERROR_SUCCESS;
} }
/****************************************************************************** /******************************************************************************
...@@ -867,9 +903,7 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t ...@@ -867,9 +903,7 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t
case BLOCKING_ALLOW: case BLOCKING_ALLOW:
break; break;
case BLOCKING_DISALLOW: case BLOCKING_DISALLOW:
#ifdef MSG_DONTWAIT flags = WINE_MSG_DONTWAIT;
flags = MSG_DONTWAIT;
#endif
break; break;
case BLOCKING_WAITALL: case BLOCKING_WAITALL:
flags = MSG_WAITALL; flags = MSG_WAITALL;
...@@ -883,7 +917,8 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t ...@@ -883,7 +917,8 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t
else else
{ {
SIZE_T size = 0, cread; SIZE_T size = 0, cread;
BOOL res, eof; BOOL eof;
DWORD res;
if(connection->peek_msg) { if(connection->peek_msg) {
size = min(len, connection->peek_len); size = min(len, connection->peek_len);
...@@ -900,18 +935,19 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t ...@@ -900,18 +935,19 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t
*recvd = size; *recvd = size;
return ERROR_SUCCESS; return ERROR_SUCCESS;
} }
}
if(mode == BLOCKING_DISALLOW) mode = BLOCKING_DISALLOW;
return WSAEWOULDBLOCK; /* FIXME: We can do better */ }
set_socket_blocking(connection->socket, BLOCKING_ALLOW);
do { do {
res = read_ssl_chunk(connection, (BYTE*)buf+size, len-size, &cread, &eof); res = read_ssl_chunk(connection, (BYTE*)buf+size, len-size, mode, &cread, &eof);
if(!res) { if(res != ERROR_SUCCESS) {
WARN("read_ssl_chunk failed\n"); if(res == WSAEWOULDBLOCK) {
if(!size) if(size)
return ERROR_INTERNET_CONNECTION_ABORTED; res = ERROR_SUCCESS;
}else {
WARN("read_ssl_chunk failed\n");
}
break; break;
} }
...@@ -925,7 +961,7 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t ...@@ -925,7 +961,7 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t
TRACE("received %ld bytes\n", size); TRACE("received %ld bytes\n", size);
*recvd = size; *recvd = size;
return ERROR_SUCCESS; return res;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment