Commit 8b6c30ab authored by Rob Shearman's avatar Rob Shearman Committed by Alexandre Julliard

rpcrt4: Open the endpoint from the caller of RpcServerUseProtseq* instead of the…

rpcrt4: Open the endpoint from the caller of RpcServerUseProtseq* instead of the protseq server thread. This allows errors to be returned to the caller and to create more than one connection for an endpoint.
parent 433993ee
......@@ -258,7 +258,7 @@ RPC_STATUS RPCRT4_OpenBinding(RpcBinding* Binding, RpcConnection** Connection,
RPCRT4_CreateConnection(&NewConnection, Binding->server, Binding->Protseq,
Binding->NetworkAddr, Binding->Endpoint, NULL,
Binding->AuthInfo, Binding);
status = RPCRT4_OpenConnection(NewConnection);
status = RPCRT4_OpenClientConnection(NewConnection);
if (status != RPC_S_OK)
{
RPCRT4_DestroyConnection(NewConnection);
......
......@@ -65,7 +65,7 @@ struct connection_ops {
const char *name;
unsigned char epm_protocols[2]; /* only floors 3 and 4. see http://www.opengroup.org/onlinepubs/9629399/apdxl.htm */
RpcConnection *(*alloc)(void);
RPC_STATUS (*open_connection)(RpcConnection *conn);
RPC_STATUS (*open_connection_client)(RpcConnection *conn);
RPC_STATUS (*handoff)(RpcConnection *old_conn, RpcConnection *new_conn);
int (*read)(RpcConnection *conn, void *buffer, unsigned int len);
int (*write)(RpcConnection *conn, const void *buffer, unsigned int len);
......@@ -108,7 +108,7 @@ RpcConnection *RPCRT4_GetIdleConnection(const RPC_SYNTAX_IDENTIFIER *InterfaceId
void RPCRT4_ReleaseIdleConnection(RpcConnection *Connection);
RPC_STATUS RPCRT4_CreateConnection(RpcConnection** Connection, BOOL server, LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCSTR NetworkOptions, RpcAuthInfo* AuthInfo, RpcBinding* Binding);
RPC_STATUS RPCRT4_DestroyConnection(RpcConnection* Connection);
RPC_STATUS RPCRT4_OpenConnection(RpcConnection* Connection);
RPC_STATUS RPCRT4_OpenClientConnection(RpcConnection* Connection);
RPC_STATUS RPCRT4_CloseConnection(RpcConnection* Connection);
RPC_STATUS RPCRT4_SpawnConnection(RpcConnection** Connection, RpcConnection* OldConnection);
......
......@@ -496,8 +496,7 @@ static RPC_STATUS RPCRT4_use_protseq(RpcServerProtseq* ps)
{
RPC_STATUS status;
status = RPCRT4_CreateConnection(&ps->conn, TRUE, ps->Protseq, NULL,
ps->Endpoint, NULL, NULL, NULL);
status = ps->ops->open_endpoint(ps, ps->Endpoint);
if (status != RPC_S_OK)
return status;
......
......@@ -56,6 +56,8 @@ struct protseq_ops
/* returns -1 for failure, 0 for server state changed and 1 to indicate a
* new connection was established */
int (*wait_for_new_connection)(RpcServerProtseq *protseq, unsigned int count, void *wait_array);
/* opens the endpoint and optionally begins listening */
RPC_STATUS (*open_endpoint)(RpcServerProtseq *protseq, LPSTR endpoint);
};
typedef struct _RpcServerInterface
......
......@@ -89,8 +89,9 @@ static struct list connection_pool = LIST_INIT(connection_pool);
typedef struct _RpcConnection_np
{
RpcConnection common;
HANDLE pipe, thread;
HANDLE pipe;
OVERLAPPED ovl;
BOOL listening;
} RpcConnection_np;
static RpcConnection *rpcrt4_conn_np_alloc(void)
......@@ -99,13 +100,35 @@ static RpcConnection *rpcrt4_conn_np_alloc(void)
if (npc)
{
npc->pipe = NULL;
npc->thread = NULL;
memset(&npc->ovl, 0, sizeof(npc->ovl));
npc->listening = FALSE;
}
return &npc->common;
}
static RPC_STATUS rpcrt4_connect_pipe(RpcConnection *Connection, LPCSTR pname)
static RPC_STATUS rpcrt4_conn_listen_pipe(RpcConnection_np *npc)
{
if (npc->listening)
return RPC_S_OK;
npc->listening = TRUE;
if (ConnectNamedPipe(npc->pipe, &npc->ovl))
return RPC_S_OK;
WARN("Couldn't ConnectNamedPipe (error was %ld)\n", GetLastError());
if (GetLastError() == ERROR_PIPE_CONNECTED) {
SetEvent(npc->ovl.hEvent);
return RPC_S_OK;
}
if (GetLastError() == ERROR_IO_PENDING) {
/* FIXME: looks like we need to GetOverlappedResult here? */
return RPC_S_OK;
}
npc->listening = FALSE;
return RPC_S_SERVER_UNAVAILABLE;
}
static RPC_STATUS rpcrt4_conn_create_pipe(RpcConnection *Connection, LPCSTR pname)
{
RpcConnection_np *npc = (RpcConnection_np *) Connection;
TRACE("listening on %s\n", pname);
......@@ -121,22 +144,13 @@ static RPC_STATUS rpcrt4_connect_pipe(RpcConnection *Connection, LPCSTR pname)
memset(&npc->ovl, 0, sizeof(npc->ovl));
npc->ovl.hEvent = CreateEventW(NULL, TRUE, FALSE, NULL);
if (ConnectNamedPipe(npc->pipe, &npc->ovl))
return RPC_S_OK;
WARN("Couldn't ConnectNamedPipe (error was %ld)\n", GetLastError());
if (GetLastError() == ERROR_PIPE_CONNECTED) {
SetEvent(npc->ovl.hEvent);
return RPC_S_OK;
}
if (GetLastError() == ERROR_IO_PENDING) {
/* FIXME: looks like we need to GetOverlappedResult here? */
/* Note: we don't call ConnectNamedPipe here because it must be done in the
* server thread as the thread must be alertable */
return RPC_S_OK;
}
return RPC_S_SERVER_UNAVAILABLE;
}
static RPC_STATUS rpcrt4_open_pipe(RpcConnection *Connection, LPCSTR pname, BOOL wait)
static RPC_STATUS rpcrt4_conn_open_pipe(RpcConnection *Connection, LPCSTR pname, BOOL wait)
{
RpcConnection_np *npc = (RpcConnection_np *) Connection;
HANDLE pipe;
......@@ -188,13 +202,36 @@ static RPC_STATUS rpcrt4_ncalrpc_open(RpcConnection* Connection)
* but we'll implement it with named pipes for now */
pname = I_RpcAllocate(strlen(prefix) + strlen(Connection->Endpoint) + 1);
strcat(strcpy(pname, prefix), Connection->Endpoint);
r = rpcrt4_conn_open_pipe(Connection, pname, TRUE);
I_RpcFree(pname);
if (Connection->server)
r = rpcrt4_connect_pipe(Connection, pname);
else
r = rpcrt4_open_pipe(Connection, pname, TRUE);
return r;
}
static RPC_STATUS rpcrt4_protseq_ncalrpc_open_endpoint(RpcServerProtseq* protseq, LPSTR endpoint)
{
static LPCSTR prefix = "\\\\.\\pipe\\lrpc\\";
RPC_STATUS r;
LPSTR pname;
RpcConnection *Connection;
r = RPCRT4_CreateConnection(&Connection, TRUE, protseq->Protseq, NULL,
endpoint, NULL, NULL, NULL);
if (r != RPC_S_OK)
return r;
/* protseq=ncalrpc: supposed to use NT LPC ports,
* but we'll implement it with named pipes for now */
pname = I_RpcAllocate(strlen(prefix) + strlen(Connection->Endpoint) + 1);
strcat(strcpy(pname, prefix), Connection->Endpoint);
r = rpcrt4_conn_create_pipe(Connection, pname);
I_RpcFree(pname);
EnterCriticalSection(&protseq->cs);
Connection->Next = protseq->conn;
protseq->conn = Connection;
LeaveCriticalSection(&protseq->cs);
return r;
}
......@@ -212,19 +249,35 @@ static RPC_STATUS rpcrt4_ncacn_np_open(RpcConnection* Connection)
/* protseq=ncacn_np: named pipes */
pname = I_RpcAllocate(strlen(prefix) + strlen(Connection->Endpoint) + 1);
strcat(strcpy(pname, prefix), Connection->Endpoint);
if (Connection->server)
r = rpcrt4_connect_pipe(Connection, pname);
else
r = rpcrt4_open_pipe(Connection, pname, FALSE);
r = rpcrt4_conn_open_pipe(Connection, pname, FALSE);
I_RpcFree(pname);
return r;
}
static RPC_STATUS rpcrt4_conn_np_handoff(RpcConnection *old_conn, RpcConnection *new_conn)
static RPC_STATUS rpcrt4_protseq_ncacn_np_open_endpoint(RpcServerProtseq *protseq, LPSTR endpoint)
{
static LPCSTR prefix = "\\\\.";
RPC_STATUS r;
LPSTR pname;
RpcConnection *Connection;
r = RPCRT4_CreateConnection(&Connection, TRUE, protseq->Protseq, NULL,
endpoint, NULL, NULL, NULL);
if (r != RPC_S_OK)
return r;
/* protseq=ncacn_np: named pipes */
pname = I_RpcAllocate(strlen(prefix) + strlen(Connection->Endpoint) + 1);
strcat(strcpy(pname, prefix), Connection->Endpoint);
r = rpcrt4_conn_create_pipe(Connection, pname);
I_RpcFree(pname);
return r;
}
static void rpcrt4_conn_np_handoff(RpcConnection_np *old_npc, RpcConnection_np *new_npc)
{
RpcConnection_np *old_npc = (RpcConnection_np *) old_conn;
RpcConnection_np *new_npc = (RpcConnection_np *) new_conn;
/* because of the way named pipes work, we'll transfer the connected pipe
* to the child, then reopen the server binding to continue listening */
......@@ -232,7 +285,41 @@ static RPC_STATUS rpcrt4_conn_np_handoff(RpcConnection *old_conn, RpcConnection
new_npc->ovl = old_npc->ovl;
old_npc->pipe = 0;
memset(&old_npc->ovl, 0, sizeof(old_npc->ovl));
return RPCRT4_OpenConnection(old_conn);
old_npc->listening = FALSE;
}
static RPC_STATUS rpcrt4_ncacn_np_handoff(RpcConnection *old_conn, RpcConnection *new_conn)
{
RPC_STATUS status;
LPSTR pname;
static LPCSTR prefix = "\\\\.";
rpcrt4_conn_np_handoff((RpcConnection_np *)old_conn, (RpcConnection_np *)new_conn);
pname = I_RpcAllocate(strlen(prefix) + strlen(old_conn->Endpoint) + 1);
strcat(strcpy(pname, prefix), old_conn->Endpoint);
status = rpcrt4_conn_create_pipe(old_conn, pname);
I_RpcFree(pname);
return status;
}
static RPC_STATUS rpcrt4_ncalrpc_handoff(RpcConnection *old_conn, RpcConnection *new_conn)
{
RPC_STATUS status;
LPSTR pname;
static LPCSTR prefix = "\\\\.\\pipe\\lrpc\\";
TRACE("%s\n", old_conn->Endpoint);
rpcrt4_conn_np_handoff((RpcConnection_np *)old_conn, (RpcConnection_np *)new_conn);
pname = I_RpcAllocate(strlen(prefix) + strlen(old_conn->Endpoint) + 1);
strcat(strcpy(pname, prefix), old_conn->Endpoint);
status = rpcrt4_conn_create_pipe(old_conn, pname);
I_RpcFree(pname);
return status;
}
static int rpcrt4_conn_np_read(RpcConnection *Connection,
......@@ -409,7 +496,7 @@ static void *rpcrt4_protseq_np_get_wait_array(RpcServerProtseq *protseq, void *p
*count = 1;
conn = CONTAINING_RECORD(protseq->conn, RpcConnection_np, common);
while (conn) {
RPCRT4_OpenConnection(&conn->common);
rpcrt4_conn_listen_pipe(conn);
if (conn->ovl.hEvent)
(*count)++;
conn = CONTAINING_RECORD(conn->common.Next, RpcConnection_np, common);
......@@ -584,7 +671,7 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection)
if (tcpc->sock != -1)
return RPC_S_OK;
hints.ai_flags = Connection->server ? AI_PASSIVE : 0;
hints.ai_flags = 0;
hints.ai_family = PF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
......@@ -620,8 +707,72 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection)
continue;
}
if (Connection->server)
if (0>connect(sock, ai_cur->ai_addr, ai_cur->ai_addrlen))
{
WARN("connect() failed: %s\n", strerror(errno));
close(sock);
continue;
}
tcpc->sock = sock;
freeaddrinfo(ai);
TRACE("connected\n");
return RPC_S_OK;
}
freeaddrinfo(ai);
ERR("couldn't connect to %s:%s\n", Connection->NetworkAddr, Connection->Endpoint);
return RPC_S_SERVER_UNAVAILABLE;
}
static RPC_STATUS rpcrt4_protseq_ncacn_ip_tcp_open_endpoint(RpcServerProtseq *protseq, LPSTR endpoint)
{
RPC_STATUS status;
int sock;
int ret;
struct addrinfo *ai;
struct addrinfo *ai_cur;
struct addrinfo hints;
TRACE("(%p, %s)\n", protseq, endpoint);
hints.ai_flags = AI_PASSIVE /* for non-localhost addresses */;
hints.ai_family = PF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
hints.ai_addrlen = 0;
hints.ai_addr = NULL;
hints.ai_canonname = NULL;
hints.ai_next = NULL;
ret = getaddrinfo(NULL, endpoint, &hints, &ai);
if (ret)
{
ERR("getaddrinfo for port %s failed: %s\n", endpoint,
gai_strerror(ret));
return RPC_S_SERVER_UNAVAILABLE;
}
for (ai_cur = ai; ai_cur; ai_cur = ai_cur->ai_next)
{
RpcConnection_tcp *tcpc;
if (TRACE_ON(rpc))
{
char host[256];
char service[256];
getnameinfo(ai_cur->ai_addr, ai_cur->ai_addrlen,
host, sizeof(host), service, sizeof(service),
NI_NUMERICHOST | NI_NUMERICSERV);
TRACE("trying %s:%s\n", host, service);
}
sock = socket(ai_cur->ai_family, ai_cur->ai_socktype, ai_cur->ai_protocol);
if (sock < 0)
{
WARN("socket() failed: %s\n", strerror(errno));
continue;
}
ret = bind(sock, ai_cur->ai_addr, ai_cur->ai_addrlen);
if (ret < 0)
{
......@@ -629,6 +780,15 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection)
close(sock);
continue;
}
status = RPCRT4_CreateConnection((RpcConnection **)&tcpc, TRUE,
protseq->Protseq, NULL, endpoint, NULL,
NULL, NULL);
if (status != RPC_S_OK)
{
close(sock);
continue;
}
ret = listen(sock, 10);
if (ret < 0)
{
......@@ -648,25 +808,20 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection)
continue;
}
tcpc->sock = sock;
}
else /* it's a client */
{
if (0>connect(sock, ai_cur->ai_addr, ai_cur->ai_addrlen))
{
WARN("connect() failed: %s\n", strerror(errno));
close(sock);
continue;
}
tcpc->sock = sock;
}
freeaddrinfo(ai);
TRACE("connected\n");
EnterCriticalSection(&protseq->cs);
tcpc->common.Next = protseq->conn;
protseq->conn = &tcpc->common;
LeaveCriticalSection(&protseq->cs);
TRACE("listening on %s\n", endpoint);
return RPC_S_OK;
}
freeaddrinfo(ai);
ERR("couldn't connect to %s:%s\n", Connection->NetworkAddr, Connection->Endpoint);
ERR("couldn't listen on port %s\n", endpoint);
return RPC_S_SERVER_UNAVAILABLE;
}
......@@ -906,7 +1061,6 @@ static void *rpcrt4_protseq_sock_get_wait_array(RpcServerProtseq *protseq, void
*count = 1;
conn = (RpcConnection_tcp *)protseq->conn;
while (conn) {
RPCRT4_OpenConnection(&conn->common);
if (conn->sock != -1)
(*count)++;
conn = (RpcConnection_tcp *)conn->common.Next;
......@@ -1001,7 +1155,7 @@ static const struct connection_ops conn_protseq_list[] = {
{ EPM_PROTOCOL_NCACN, EPM_PROTOCOL_SMB },
rpcrt4_conn_np_alloc,
rpcrt4_ncacn_np_open,
rpcrt4_conn_np_handoff,
rpcrt4_ncacn_np_handoff,
rpcrt4_conn_np_read,
rpcrt4_conn_np_write,
rpcrt4_conn_np_close,
......@@ -1012,7 +1166,7 @@ static const struct connection_ops conn_protseq_list[] = {
{ EPM_PROTOCOL_NCALRPC, EPM_PROTOCOL_PIPE },
rpcrt4_conn_np_alloc,
rpcrt4_ncalrpc_open,
rpcrt4_conn_np_handoff,
rpcrt4_ncalrpc_handoff,
rpcrt4_conn_np_read,
rpcrt4_conn_np_write,
rpcrt4_conn_np_close,
......@@ -1042,6 +1196,7 @@ static const struct protseq_ops protseq_list[] =
rpcrt4_protseq_np_get_wait_array,
rpcrt4_protseq_np_free_wait_array,
rpcrt4_protseq_np_wait_for_new_connection,
rpcrt4_protseq_ncacn_np_open_endpoint,
},
{
"ncalrpc",
......@@ -1050,6 +1205,7 @@ static const struct protseq_ops protseq_list[] =
rpcrt4_protseq_np_get_wait_array,
rpcrt4_protseq_np_free_wait_array,
rpcrt4_protseq_np_wait_for_new_connection,
rpcrt4_protseq_ncalrpc_open_endpoint,
},
{
"ncacn_ip_tcp",
......@@ -1058,6 +1214,7 @@ static const struct protseq_ops protseq_list[] =
rpcrt4_protseq_sock_get_wait_array,
rpcrt4_protseq_sock_free_wait_array,
rpcrt4_protseq_sock_wait_for_new_connection,
rpcrt4_protseq_ncacn_ip_tcp_open_endpoint,
},
};
......@@ -1083,11 +1240,12 @@ static const struct connection_ops *rpcrt4_get_conn_protseq_ops(const char *prot
/**** interface to rest of code ****/
RPC_STATUS RPCRT4_OpenConnection(RpcConnection* Connection)
RPC_STATUS RPCRT4_OpenClientConnection(RpcConnection* Connection)
{
TRACE("(Connection == ^%p)\n", Connection);
return Connection->ops->open_connection(Connection);
assert(!Connection->server);
return Connection->ops->open_connection_client(Connection);
}
RPC_STATUS RPCRT4_CloseConnection(RpcConnection* Connection)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment