Commit a51486c6 authored by Rob Shearman's avatar Rob Shearman Committed by Alexandre Julliard

rpcrt4: Implement RpcCancelThread for the ncacn_ip_tcp protocol sequence.

parent 4dda7c63
......@@ -104,6 +104,7 @@ struct connection_ops {
int (*read)(RpcConnection *conn, void *buffer, unsigned int len);
int (*write)(RpcConnection *conn, const void *buffer, unsigned int len);
int (*close)(RpcConnection *conn);
void (*cancel_call)(RpcConnection *conn);
size_t (*get_top_of_tower)(unsigned char *tower_data, const char *networkaddr, const char *endpoint);
RPC_STATUS (*parse_top_of_tower)(const unsigned char *tower_data, size_t tower_size, char **networkaddr, char **endpoint);
};
......@@ -190,6 +191,11 @@ static inline int rpcrt4_conn_close(RpcConnection *Connection)
return Connection->ops->close(Connection);
}
static inline void rpcrt4_conn_cancel_call(RpcConnection *Connection)
{
Connection->ops->cancel_call(Connection);
}
static inline RPC_STATUS rpcrt4_conn_handoff(RpcConnection *old_conn, RpcConnection *new_conn)
{
return old_conn->ops->handoff(old_conn, new_conn);
......@@ -199,4 +205,6 @@ static inline RPC_STATUS rpcrt4_conn_handoff(RpcConnection *old_conn, RpcConnect
RPC_STATUS RpcTransport_GetTopOfTower(unsigned char *tower_data, size_t *tower_size, const char *protseq, const char *networkaddr, const char *endpoint);
RPC_STATUS RpcTransport_ParseTopOfTower(const unsigned char *tower_data, size_t tower_size, char **protseq, char **networkaddr, char **endpoint);
void RPCRT4_SetThreadCurrentConnection(RpcConnection *Connection);
#endif
......@@ -449,6 +449,8 @@ static RPC_STATUS RPCRT4_SendAuth(RpcConnection *Connection, RpcPktHdr *Header,
LONG alen;
RPC_STATUS status;
RPCRT4_SetThreadCurrentConnection(Connection);
buffer_pos = Buffer;
/* The packet building functions save the packet header size, so we can use it. */
hdr_size = Header->common.frag_len;
......@@ -518,6 +520,7 @@ static RPC_STATUS RPCRT4_SendAuth(RpcConnection *Connection, RpcPktHdr *Header,
if (status != RPC_S_OK)
{
HeapFree(GetProcessHeap(), 0, pkt);
RPCRT4_SetThreadCurrentConnection(NULL);
return status;
}
}
......@@ -528,6 +531,7 @@ write:
HeapFree(GetProcessHeap(), 0, pkt);
if (count<0) {
WARN("rpcrt4_conn_write failed (auth)\n");
RPCRT4_SetThreadCurrentConnection(NULL);
return RPC_S_PROTOCOL_ERROR;
}
......@@ -536,6 +540,7 @@ write:
Header->common.flags &= ~RPC_FLG_FIRST;
}
RPCRT4_SetThreadCurrentConnection(NULL);
return RPC_S_OK;
}
......@@ -697,6 +702,8 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
TRACE("(%p, %p, %p)\n", Connection, Header, pMsg);
RPCRT4_SetThreadCurrentConnection(Connection);
/* read packet common header */
dwRead = rpcrt4_conn_read(Connection, &common_hdr, sizeof(common_hdr));
if (dwRead != sizeof(common_hdr)) {
......@@ -872,6 +879,7 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
status = RPC_S_OK;
fail:
RPCRT4_SetThreadCurrentConnection(NULL);
if (status != RPC_S_OK) {
RPCRT4_FreeHeader(*Header);
*Header = NULL;
......
......@@ -416,6 +416,11 @@ static int rpcrt4_conn_np_close(RpcConnection *Connection)
return 0;
}
static void rpcrt4_conn_np_cancel_call(RpcConnection *Connection)
{
/* FIXME: implement when named pipe writes use overlapped I/O */
}
static size_t rpcrt4_ncacn_np_get_top_of_tower(unsigned char *tower_data,
const char *networkaddr,
const char *endpoint)
......@@ -703,6 +708,7 @@ typedef struct _RpcConnection_tcp
{
RpcConnection common;
int sock;
int cancel_fds[2];
} RpcConnection_tcp;
static RpcConnection *rpcrt4_conn_tcp_alloc(void)
......@@ -712,6 +718,12 @@ static RpcConnection *rpcrt4_conn_tcp_alloc(void)
if (tcpc == NULL)
return NULL;
tcpc->sock = -1;
if (socketpair(PF_UNIX, SOCK_STREAM, 0, tcpc->cancel_fds) < 0)
{
ERR("socketpair() failed: %s\n", strerror(errno));
HeapFree(GetProcessHeap(), 0, tcpc);
return NULL;
}
return &tcpc->common;
}
......@@ -777,6 +789,7 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection)
/* RPC depends on having minimal latency so disable the Nagle algorithm */
val = 1;
setsockopt(sock, SOL_TCP, TCP_NODELAY, &val, sizeof(val));
fcntl(sock, F_SETFL, O_NONBLOCK); /* make socket nonblocking */
tcpc->sock = sock;
......@@ -942,18 +955,64 @@ static int rpcrt4_conn_tcp_read(RpcConnection *Connection,
void *buffer, unsigned int count)
{
RpcConnection_tcp *tcpc = (RpcConnection_tcp *) Connection;
int r = recv(tcpc->sock, buffer, count, MSG_WAITALL);
TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, r);
return r;
int bytes_read = 0;
do
{
int r = recv(tcpc->sock, (char *)buffer + bytes_read, count - bytes_read, 0);
if (r >= 0)
bytes_read += r;
else if (errno != EAGAIN)
return -1;
else
{
struct pollfd pfds[2];
pfds[0].fd = tcpc->sock;
pfds[0].events = POLLIN;
pfds[1].fd = tcpc->cancel_fds[0];
pfds[1].fd = POLLIN;
if (poll(pfds, 2, -1 /* infinite */) == -1 && errno != EINTR)
{
ERR("poll() failed: %s\n", strerror(errno));
return -1;
}
if (pfds[1].revents & POLLIN) /* canceled */
{
char dummy;
read(pfds[1].fd, &dummy, sizeof(dummy));
return -1;
}
}
} while (bytes_read != count);
TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, bytes_read);
return bytes_read;
}
static int rpcrt4_conn_tcp_write(RpcConnection *Connection,
const void *buffer, unsigned int count)
{
RpcConnection_tcp *tcpc = (RpcConnection_tcp *) Connection;
int r = write(tcpc->sock, buffer, count);
TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, r);
return r;
int bytes_written = 0;
do
{
int r = write(tcpc->sock, (const char *)buffer + bytes_written, count - bytes_written);
if (r >= 0)
bytes_written += r;
else if (errno != EAGAIN)
return -1;
else
{
struct pollfd pfd;
pfd.fd = tcpc->sock;
pfd.events = POLLOUT;
if (poll(&pfd, 1, -1 /* infinite */) == -1 && errno != EINTR)
{
ERR("poll() failed: %s\n", strerror(errno));
return -1;
}
}
} while (bytes_written != count);
TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, bytes_written);
return bytes_written;
}
static int rpcrt4_conn_tcp_close(RpcConnection *Connection)
......@@ -965,9 +1024,21 @@ static int rpcrt4_conn_tcp_close(RpcConnection *Connection)
if (tcpc->sock != -1)
close(tcpc->sock);
tcpc->sock = -1;
close(tcpc->cancel_fds[0]);
close(tcpc->cancel_fds[1]);
return 0;
}
static void rpcrt4_conn_tcp_cancel_call(RpcConnection *Connection)
{
RpcConnection_tcp *tcpc = (RpcConnection_tcp *) Connection;
char dummy = 1;
TRACE("%p\n", Connection);
write(tcpc->cancel_fds[1], &dummy, 1);
}
static size_t rpcrt4_ncacn_ip_tcp_get_top_of_tower(unsigned char *tower_data,
const char *networkaddr,
const char *endpoint)
......@@ -1250,6 +1321,7 @@ static const struct connection_ops conn_protseq_list[] = {
rpcrt4_conn_np_read,
rpcrt4_conn_np_write,
rpcrt4_conn_np_close,
rpcrt4_conn_np_cancel_call,
rpcrt4_ncacn_np_get_top_of_tower,
rpcrt4_ncacn_np_parse_top_of_tower,
},
......@@ -1261,6 +1333,7 @@ static const struct connection_ops conn_protseq_list[] = {
rpcrt4_conn_np_read,
rpcrt4_conn_np_write,
rpcrt4_conn_np_close,
rpcrt4_conn_np_cancel_call,
rpcrt4_ncalrpc_get_top_of_tower,
rpcrt4_ncalrpc_parse_top_of_tower,
},
......@@ -1272,6 +1345,7 @@ static const struct connection_ops conn_protseq_list[] = {
rpcrt4_conn_tcp_read,
rpcrt4_conn_tcp_write,
rpcrt4_conn_tcp_close,
rpcrt4_conn_tcp_cancel_call,
rpcrt4_ncacn_ip_tcp_get_top_of_tower,
rpcrt4_ncacn_ip_tcp_parse_top_of_tower,
}
......
......@@ -100,6 +100,8 @@
#include "winerror.h"
#include "winbase.h"
#include "winuser.h"
#include "winnt.h"
#include "winternl.h"
#include "iptypes.h"
#include "iphlpapi.h"
#include "wine/unicode.h"
......@@ -133,6 +135,25 @@ static CRITICAL_SECTION_DEBUG critsect_debug =
};
static CRITICAL_SECTION uuid_cs = { &critsect_debug, -1, 0, 0, 0, 0 };
static CRITICAL_SECTION threaddata_cs;
static CRITICAL_SECTION_DEBUG threaddata_cs_debug =
{
0, 0, &uuid_cs,
{ &threaddata_cs_debug.ProcessLocksList, &threaddata_cs_debug.ProcessLocksList },
0, 0, { (DWORD_PTR)(__FILE__ ": threaddata_cs") }
};
static CRITICAL_SECTION threaddata_cs = { &threaddata_cs_debug, -1, 0, 0, 0, 0 };
struct list threaddata_list = LIST_INIT(threaddata_list);
struct threaddata
{
struct list entry;
CRITICAL_SECTION cs;
DWORD thread_id;
RpcConnection *connection;
};
/***********************************************************************
* DllMain
*
......@@ -148,14 +169,29 @@ static CRITICAL_SECTION uuid_cs = { &critsect_debug, -1, 0, 0, 0, 0 };
BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved)
{
struct threaddata *tdata;
switch (fdwReason) {
case DLL_PROCESS_ATTACH:
DisableThreadLibraryCalls(hinstDLL);
master_mutex = CreateMutexA( NULL, FALSE, RPCSS_MASTER_MUTEX_NAME);
if (!master_mutex)
ERR("Failed to create master mutex\n");
break;
case DLL_THREAD_DETACH:
tdata = NtCurrentTeb()->ReservedForNtRpc;
if (tdata)
{
EnterCriticalSection(&threaddata_cs);
list_remove(&tdata->entry);
LeaveCriticalSection(&threaddata_cs);
DeleteCriticalSection(&tdata->cs);
if (tdata->connection)
ERR("tdata->connection should be NULL but is still set to %p\n", tdata);
HeapFree(GetProcessHeap(), 0, tdata);
}
case DLL_PROCESS_DETACH:
CloseHandle(master_mutex);
master_mutex = NULL;
......@@ -847,11 +883,53 @@ RPC_STATUS RPC_ENTRY RpcMgmtSetCancelTimeout(LONG Timeout)
return RPC_S_OK;
}
void RPCRT4_SetThreadCurrentConnection(RpcConnection *Connection)
{
struct threaddata *tdata = NtCurrentTeb()->ReservedForNtRpc;
if (!tdata)
{
tdata = HeapAlloc(GetProcessHeap(), 0, sizeof(*tdata));
if (!tdata) return;
InitializeCriticalSection(&tdata->cs);
tdata->thread_id = GetCurrentThreadId();
tdata->connection = Connection;
EnterCriticalSection(&threaddata_cs);
list_add_tail(&threaddata_list, &tdata->entry);
LeaveCriticalSection(&threaddata_cs);
NtCurrentTeb()->ReservedForNtRpc = tdata;
return;
}
EnterCriticalSection(&tdata->cs);
tdata->connection = Connection;
LeaveCriticalSection(&tdata->cs);
}
/******************************************************************************
* RpcCancelThread (rpcrt4.@)
*/
RPC_STATUS RPC_ENTRY RpcCancelThread(HANDLE ThreadHandle)
{
FIXME("(%p): stub\n", ThreadHandle);
DWORD target_tid;
struct threaddata *tdata;
TRACE("(%p)\n", ThreadHandle);
target_tid = GetThreadId(ThreadHandle);
if (!target_tid)
return RPC_S_INVALID_ARG;
EnterCriticalSection(&threaddata_cs);
LIST_FOR_EACH_ENTRY(tdata, &threaddata_list, struct threaddata, entry)
if (tdata->thread_id == target_tid)
{
rpcrt4_conn_cancel_call(tdata->connection);
break;
}
LeaveCriticalSection(&threaddata_cs);
return RPC_S_OK;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment