/*
 * Associations
 *
 * Copyright 2007 Robert Shearman (for CodeWeavers)
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
 *
 */

#include <stdarg.h>
#include <assert.h>

#include "rpc.h"
#include "rpcndr.h"
#include "winternl.h"

#include "wine/unicode.h"
#include "wine/debug.h"

#include "rpc_binding.h"
#include "rpc_assoc.h"
#include "rpc_message.h"

WINE_DEFAULT_DEBUG_CHANNEL(rpc);

/* Lock entered around every traversal and modification of client_assoc_list
 * and server_assoc_list in this file, and around the increments/decrements
 * of RpcAssoc.refs performed by the lookup and release functions. */
static CRITICAL_SECTION assoc_list_cs;
static CRITICAL_SECTION_DEBUG assoc_list_cs_debug =
{
    0, 0, &assoc_list_cs,
    { &assoc_list_cs_debug.ProcessLocksList, &assoc_list_cs_debug.ProcessLocksList },
      0, 0, { (DWORD_PTR)(__FILE__ ": assoc_list_cs") }
};
static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 };

/* Associations handed out by RPCRT4_GetAssociation (client side). */
static struct list client_assoc_list = LIST_INIT(client_assoc_list);
/* Associations handed out by RpcServerAssoc_GetAssociation (server side). */
static struct list server_assoc_list = LIST_INIT(server_assoc_list);

/* Last association group id issued to a newly created server association;
 * advanced with InterlockedIncrement. */
static LONG last_assoc_group_id;

/* Server-side context handle tracked on an association's context_handle_list.
 *
 * NOTE(review): pointers to this structure are cast to and from NDR_SCONTEXT
 * throughout this file, so the field order is presumably significant --
 * confirm against the NDR_SCONTEXT definition before reordering fields. */
typedef struct _RpcContextHandle
{
    struct list entry;            /* link in RpcAssoc.context_handle_list */
    void *user_context;           /* passed to rundown_routine on destruction; not assigned in this file */
    NDR_RUNDOWN rundown_routine;  /* set by RpcServerAssoc_UpdateContextHandle */
    void *ctx_guard;              /* compared by RpcContextHandle_IsGuardCorrect */
    UUID uuid;                    /* nil until RpcServerAssoc_UpdateContextHandle creates one */
    RTL_RWLOCK rw_lock;           /* acquired exclusively on allocation and on lookup */
    unsigned int refs;            /* guarded by the owning association's cs */
} RpcContextHandle;

/* Defined near the end of the file; needed earlier by RpcAssoc_Release. */
static void RpcContextHandle_Destroy(RpcContextHandle *context_handle);

/* Creates a fresh association holding one reference.
 *
 * The binding strings are duplicated with RPCRT4_strdupA / RPCRT4_strdupW,
 * the group id starts at 0 and a new http_uuid is generated. The association
 * is NOT linked into any global list here; that is left to the caller.
 *
 * Returns RPC_S_OK and stores the association in *assoc_out, or
 * RPC_S_OUT_OF_RESOURCES if the association itself cannot be allocated
 * (in which case *assoc_out is left untouched). */
static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr,
                                 LPCSTR Endpoint, LPCWSTR NetworkOptions,
                                 RpcAssoc **assoc_out)
{
    RpcAssoc *new_assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*new_assoc));

    if (new_assoc == NULL)
        return RPC_S_OUT_OF_RESOURCES;

    /* reference count, group id and list heads */
    new_assoc->refs = 1;
    new_assoc->assoc_group_id = 0;
    list_init(&new_assoc->entry);
    list_init(&new_assoc->free_connection_pool);
    list_init(&new_assoc->context_handle_list);

    /* per-association lock */
    InitializeCriticalSection(&new_assoc->cs);
    new_assoc->cs.DebugInfo->Spare[0] = (DWORD_PTR)(__FILE__ ": RpcAssoc.cs");

    /* private copies of the binding strings */
    new_assoc->Protseq = RPCRT4_strdupA(Protseq);
    new_assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr);
    new_assoc->Endpoint = RPCRT4_strdupA(Endpoint);
    if (NetworkOptions)
        new_assoc->NetworkOptions = RPCRT4_strdupW(NetworkOptions);
    else
        new_assoc->NetworkOptions = NULL;

    UuidCreate(&new_assoc->http_uuid);

    *assoc_out = new_assoc;
    return RPC_S_OK;
}

/* Tells whether two optional network-options strings are equal.
 * Two NULL pointers compare equal; NULL never equals a non-NULL string;
 * otherwise the strings are compared with strcmpW. */
static BOOL compare_networkoptions(LPCWSTR opts1, LPCWSTR opts2)
{
    if (!opts1 || !opts2)
        return !opts1 && !opts2;

    return strcmpW(opts1, opts2) == 0;
}

/* Finds or creates the client association for the given binding parameters.
 *
 * Under assoc_list_cs, client_assoc_list is searched for an entry whose
 * protocol sequence, network address, endpoint and network options all
 * match. A match gains one reference and is returned. Otherwise a new
 * association is allocated, put at the head of the list and returned.
 *
 * Returns RPC_S_OK with *assoc_out set, or the failure status of
 * RpcAssoc_Alloc (in which case *assoc_out is not written). */
RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
                                 LPCSTR Endpoint, LPCWSTR NetworkOptions,
                                 RpcAssoc **assoc_out)
{
    RpcAssoc *cur;
    RpcAssoc *created;
    RPC_STATUS status;

    EnterCriticalSection(&assoc_list_cs);

    LIST_FOR_EACH_ENTRY(cur, &client_assoc_list, RpcAssoc, entry)
    {
        if (strcmp(Protseq, cur->Protseq) != 0)
            continue;
        if (strcmp(NetworkAddr, cur->NetworkAddr) != 0)
            continue;
        if (strcmp(Endpoint, cur->Endpoint) != 0)
            continue;
        if (!compare_networkoptions(NetworkOptions, cur->NetworkOptions))
            continue;

        /* reuse the matching association */
        cur->refs++;
        *assoc_out = cur;
        LeaveCriticalSection(&assoc_list_cs);
        TRACE("using existing assoc %p\n", cur);
        return RPC_S_OK;
    }

    /* nothing suitable: create one while still holding the list lock */
    status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &created);
    if (status == RPC_S_OK)
    {
        list_add_head(&client_assoc_list, &created->entry);
        *assoc_out = created;
    }

    LeaveCriticalSection(&assoc_list_cs);

    if (status != RPC_S_OK)
        return status;

    TRACE("new assoc %p\n", created);

    return RPC_S_OK;
}

/* Finds or creates a server association.
 *
 * When assoc_gid is non-zero, server_assoc_list is searched for an
 * association with that group id and matching binding parameters; a match
 * gains a reference and is returned, otherwise *assoc_out is set to NULL and
 * RPC_S_NO_CONTEXT_AVAILABLE is returned.
 *
 * When assoc_gid is zero, a new association is allocated, given the next
 * group id and placed at the head of server_assoc_list.
 *
 * All list access happens under assoc_list_cs. */
RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
                                         LPCSTR Endpoint, LPCWSTR NetworkOptions,
                                         ULONG assoc_gid,
                                         RpcAssoc **assoc_out)
{
    RpcAssoc *cur;
    RpcAssoc *created;
    RPC_STATUS status;

    EnterCriticalSection(&assoc_list_cs);

    if (assoc_gid != 0)
    {
        LIST_FOR_EACH_ENTRY(cur, &server_assoc_list, RpcAssoc, entry)
        {
            if (cur->assoc_group_id != assoc_gid)
                continue;
            if (strcmp(Protseq, cur->Protseq) != 0)
                continue;
            /* FIXME: NetworkAddr shouldn't be NULL */
            if (NetworkAddr && cur->NetworkAddr &&
                strcmp(NetworkAddr, cur->NetworkAddr) != 0)
                continue;
            if (strcmp(Endpoint, cur->Endpoint) != 0)
                continue;
            if (!compare_networkoptions(NetworkOptions, cur->NetworkOptions))
                continue;

            cur->refs++;
            *assoc_out = cur;
            LeaveCriticalSection(&assoc_list_cs);
            TRACE("using existing assoc %p\n", cur);
            return RPC_S_OK;
        }

        /* the requested group does not exist */
        *assoc_out = NULL;
        LeaveCriticalSection(&assoc_list_cs);
        return RPC_S_NO_CONTEXT_AVAILABLE;
    }

    status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &created);
    if (status == RPC_S_OK)
    {
        created->assoc_group_id = InterlockedIncrement(&last_assoc_group_id);
        list_add_head(&server_assoc_list, &created->entry);
        *assoc_out = created;
    }

    LeaveCriticalSection(&assoc_list_cs);

    if (status != RPC_S_OK)
        return status;

    TRACE("new assoc %p\n", created);

    return RPC_S_OK;
}

/* Drops one reference on an association and returns the remaining count.
 *
 * The decrement and, when the count reaches zero, the unlinking from the
 * global list are done under assoc_list_cs. Teardown then happens outside
 * that lock: pooled connections are released, remaining context handles are
 * destroyed, the duplicated strings and the critical section are freed, and
 * finally the association itself. */
ULONG RpcAssoc_Release(RpcAssoc *assoc)
{
    RpcConnection *conn, *next_conn;
    RpcContextHandle *handle, *next_handle;
    ULONG remaining;

    EnterCriticalSection(&assoc_list_cs);
    assoc->refs--;
    remaining = assoc->refs;
    if (remaining == 0)
        list_remove(&assoc->entry);
    LeaveCriticalSection(&assoc_list_cs);

    if (remaining != 0)
        return remaining;

    TRACE("destroying assoc %p\n", assoc);

    /* empty the idle connection pool */
    LIST_FOR_EACH_ENTRY_SAFE(conn, next_conn, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
    {
        list_remove(&conn->conn_pool_entry);
        RPCRT4_ReleaseConnection(conn);
    }

    /* destroy whatever context handles are still attached */
    LIST_FOR_EACH_ENTRY_SAFE(handle, next_handle, &assoc->context_handle_list, RpcContextHandle, entry)
    {
        RpcContextHandle_Destroy(handle);
    }

    HeapFree(GetProcessHeap(), 0, assoc->NetworkOptions);
    HeapFree(GetProcessHeap(), 0, assoc->Endpoint);
    HeapFree(GetProcessHeap(), 0, assoc->NetworkAddr);
    HeapFree(GetProcessHeap(), 0, assoc->Protseq);

    assoc->cs.DebugInfo->Spare[0] = 0;
    DeleteCriticalSection(&assoc->cs);

    HeapFree(GetProcessHeap(), 0, assoc);

    return remaining;
}

#define ROUND_UP(value, alignment) (((value) + ((alignment) - 1)) & ~((alignment)-1))

/* Performs the bind handshake on a freshly opened client connection.
 *
 * A bind header is built from the association's group id and the requested
 * interface / transfer syntax, sent on conn, and the reply is received and
 * interpreted:
 *  - PKT_BIND_ACK: the secondary address string and the result list that
 *    follows it are validated against the received buffer length. On
 *    RESULT_ACCEPT any authorization data is passed to
 *    RPCRT4_ClientConnectionAuth and, on success, the connection's
 *    assoc_group_id, MaxTransmissionSize and ActiveInterface are updated.
 *    Rejections are mapped to an RPC status.
 *  - PKT_BIND_NACK: the reject reason is mapped to an RPC status.
 *  - anything else: RPC_S_PROTOCOL_ERROR.
 *
 * Returns RPC_S_OK on an accepted bind, otherwise the failure status.
 * The received buffer, response header and auth data are freed before
 * returning from the receive path. */
static RPC_STATUS RpcAssoc_BindConnection(const RpcAssoc *assoc, RpcConnection *conn,
                                          const RPC_SYNTAX_IDENTIFIER *InterfaceId,
                                          const RPC_SYNTAX_IDENTIFIER *TransferSyntax)
{
    RpcPktHdr *hdr;
    RpcPktHdr *response_hdr;
    RPC_MESSAGE msg;
    RPC_STATUS status;
    unsigned char *auth_data = NULL;
    ULONG auth_length = 0;

    TRACE("sending bind request to server\n");

    hdr = RPCRT4_BuildBindHeader(NDR_LOCAL_DATA_REPRESENTATION,
                                 RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE,
                                 assoc->assoc_group_id,
                                 InterfaceId, TransferSyntax);
    /* NOTE(review): presumably the builder returns NULL when it cannot
     * allocate the header -- don't hand a NULL header to RPCRT4_Send */
    if (!hdr)
        return RPC_S_OUT_OF_RESOURCES;

    status = RPCRT4_Send(conn, hdr, NULL, 0);
    RPCRT4_FreeHeader(hdr);
    if (status != RPC_S_OK)
        return status;

    status = RPCRT4_ReceiveWithAuth(conn, &response_hdr, &msg, &auth_data, &auth_length);
    if (status != RPC_S_OK)
    {
        ERR("receive failed with error %d\n", status);
        return status;
    }

    switch (response_hdr->common.ptype)
    {
    case PKT_BIND_ACK:
    {
        RpcAddressString *server_address = msg.Buffer;
        unsigned int addr_size = 0;
        BOOL address_fits = FALSE;

        /* The length field must be inside the buffer before it is read, and
         * the (padded) address string it describes must fit as well. Both
         * conditions are required. */
        if (msg.BufferLength >= FIELD_OFFSET(RpcAddressString, string[0]))
        {
            addr_size = ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4);
            address_fits = (msg.BufferLength >= addr_size);
        }

        if (address_fits)
        {
            /* full-width type: the buffer may be larger than 64k */
            unsigned int remaining = msg.BufferLength - addr_size;
            RpcResultList *results = (RpcResultList*)((ULONG_PTR)server_address + addr_size);

            /* check that a one-entry result list fits before reading
             * num_results out of it */
            if ((remaining >= FIELD_OFFSET(RpcResultList, results[1])) &&
                (results->num_results == 1))
            {
                switch (results->results[0].result)
                {
                case RESULT_ACCEPT:
                    /* respond to authorization request */
                    if (auth_length > sizeof(RpcAuthVerifier))
                        status = RPCRT4_ClientConnectionAuth(conn,
                                                             auth_data + sizeof(RpcAuthVerifier),
                                                             auth_length);
                    if (status == RPC_S_OK)
                    {
                        conn->assoc_group_id = response_hdr->bind_ack.assoc_gid;
                        conn->MaxTransmissionSize = response_hdr->bind_ack.max_tsize;
                        conn->ActiveInterface = *InterfaceId;
                    }
                    break;
                case RESULT_PROVIDER_REJECTION:
                    switch (results->results[0].reason)
                    {
                    case REASON_ABSTRACT_SYNTAX_NOT_SUPPORTED:
                        ERR("syntax %s, %d.%d not supported\n",
                            debugstr_guid(&InterfaceId->SyntaxGUID),
                            InterfaceId->SyntaxVersion.MajorVersion,
                            InterfaceId->SyntaxVersion.MinorVersion);
                        status = RPC_S_UNKNOWN_IF;
                        break;
                    case REASON_TRANSFER_SYNTAXES_NOT_SUPPORTED:
                        ERR("transfer syntax not supported\n");
                        status = RPC_S_SERVER_UNAVAILABLE;
                        break;
                    case REASON_NONE:
                    default:
                        status = RPC_S_CALL_FAILED_DNE;
                    }
                    break;
                case RESULT_USER_REJECTION:
                default:
                    ERR("rejection result %d\n", results->results[0].result);
                    status = RPC_S_CALL_FAILED_DNE;
                }
            }
            else
            {
                ERR("incorrect results size\n");
                status = RPC_S_CALL_FAILED_DNE;
            }
        }
        else
        {
            ERR("bind ack packet too small (%d)\n", msg.BufferLength);
            status = RPC_S_PROTOCOL_ERROR;
        }
        break;
    }
    case PKT_BIND_NACK:
        switch (response_hdr->bind_nack.reject_reason)
        {
        case REJECT_LOCAL_LIMIT_EXCEEDED:
        case REJECT_TEMPORARY_CONGESTION:
            ERR("server too busy\n");
            status = RPC_S_SERVER_TOO_BUSY;
            break;
        case REJECT_PROTOCOL_VERSION_NOT_SUPPORTED:
            ERR("protocol version not supported\n");
            status = RPC_S_PROTOCOL_ERROR;
            break;
        case REJECT_UNKNOWN_AUTHN_SERVICE:
            ERR("unknown authentication service\n");
            status = RPC_S_UNKNOWN_AUTHN_SERVICE;
            break;
        case REJECT_INVALID_CHECKSUM:
            ERR("invalid checksum\n");
            status = RPC_S_ACCESS_DENIED;
            break;
        default:
            ERR("rejected bind for reason %d\n", response_hdr->bind_nack.reject_reason);
            status = RPC_S_CALL_FAILED_DNE;
        }
        break;
    default:
        ERR("wrong packet type received %d\n", response_hdr->common.ptype);
        status = RPC_S_PROTOCOL_ERROR;
        break;
    }

    I_RpcFree(msg.Buffer);
    RPCRT4_FreeHeader(response_hdr);
    HeapFree(GetProcessHeap(), 0, auth_data);
    return status;
}

/* Takes a compatible idle connection out of the association's pool.
 *
 * A pooled connection is compatible when its active interface is bytewise
 * equal to InterfaceId and its auth info and QOS compare equal to the
 * requested ones. TransferSyntax is not consulted. The pool is scanned
 * under assoc->cs; the first match is unlinked and returned, or NULL when
 * there is none. */
static RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc,
                                                 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
                                                 const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo,
                                                 const RpcQualityOfService *QOS)
{
    RpcConnection *candidate;
    RpcConnection *found = NULL;

    EnterCriticalSection(&assoc->cs);

    LIST_FOR_EACH_ENTRY(candidate, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
    {
        if (memcmp(&candidate->ActiveInterface, InterfaceId,
                   sizeof(RPC_SYNTAX_IDENTIFIER)) != 0)
            continue;
        if (!RpcAuthInfo_IsEqual(candidate->AuthInfo, AuthInfo))
            continue;
        if (!RpcQualityOfService_IsEqual(candidate->QOS, QOS))
            continue;

        /* claim it for the caller */
        list_remove(&candidate->conn_pool_entry);
        found = candidate;
        break;
    }

    LeaveCriticalSection(&assoc->cs);

    if (found)
        TRACE("got connection from pool %p\n", found);

    return found;
}

/* Obtains a bound client connection for the given interface.
 *
 * An idle pooled connection is reused when one is compatible. Otherwise a
 * new connection is created from the association's binding strings, tied to
 * the association, opened and bound; if opening or binding fails the new
 * connection is released again.
 *
 * Returns RPC_S_OK with *Connection set, or the status of the failing step
 * (in which case *Connection is NULL). */
RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc,
                                        const RPC_SYNTAX_IDENTIFIER *InterfaceId,
                                        const RPC_SYNTAX_IDENTIFIER *TransferSyntax, RpcAuthInfo *AuthInfo,
                                        RpcQualityOfService *QOS, LPCWSTR CookieAuth, RpcConnection **Connection)
{
    RpcConnection *fresh;
    RPC_STATUS status;

    /* prefer a pooled connection */
    *Connection = RpcAssoc_GetIdleConnection(assoc, InterfaceId, TransferSyntax, AuthInfo, QOS);
    if (*Connection != NULL)
        return RPC_S_OK;

    /* none available: build a client (non-server) connection */
    status = RPCRT4_CreateConnection(&fresh, FALSE /* is this a server connection? */,
                                     assoc->Protseq, assoc->NetworkAddr,
                                     assoc->Endpoint, assoc->NetworkOptions,
                                     AuthInfo, QOS, CookieAuth);
    if (status != RPC_S_OK)
        return status;

    fresh->assoc = assoc;

    status = RPCRT4_OpenClientConnection(fresh);
    if (status == RPC_S_OK)
        status = RpcAssoc_BindConnection(assoc, fresh, InterfaceId, TransferSyntax);

    if (status != RPC_S_OK)
    {
        RPCRT4_ReleaseConnection(fresh);
        return status;
    }

    *Connection = fresh;

    return RPC_S_OK;
}

/* Returns a client connection to the association's idle pool.
 *
 * The connection's async state is cleared, and if the association has no
 * group id yet it adopts the connection's one. The connection is placed at
 * the head of the pool. Asserts that the connection is not a server
 * connection. */
void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection)
{
    assert(!Connection->server);

    Connection->async_state = NULL;

    EnterCriticalSection(&assoc->cs);
    if (assoc->assoc_group_id == 0)
    {
        assoc->assoc_group_id = Connection->assoc_group_id;
    }
    list_add_head(&assoc->free_connection_pool, &Connection->conn_pool_entry);
    LeaveCriticalSection(&assoc->cs);
}

/* Creates a new, zero-initialised context handle on the association.
 *
 * The handle records CtxGuard, starts with one reference and is returned
 * with its rw_lock already held exclusively. It is appended to the
 * association's context_handle_list under assoc->cs.
 *
 * Returns RPC_S_OK with *SContext set, or RPC_S_OUT_OF_MEMORY. */
RPC_STATUS RpcServerAssoc_AllocateContextHandle(RpcAssoc *assoc, void *CtxGuard,
                                                NDR_SCONTEXT *SContext)
{
    RpcContextHandle *handle = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*handle));

    if (handle == NULL)
        return RPC_S_OUT_OF_MEMORY;

    handle->refs = 1;
    handle->ctx_guard = CtxGuard;
    RtlInitializeResource(&handle->rw_lock);

    /* take the lock just as the unmarshalling path does, so that freeing a
     * handle that was never marshalled needs no special treatment */
    RtlAcquireResourceExclusive(&handle->rw_lock, TRUE);

    EnterCriticalSection(&assoc->cs);
    list_add_tail(&assoc->context_handle_list, &handle->entry);
    LeaveCriticalSection(&assoc->cs);

    *SContext = (NDR_SCONTEXT)handle;
    return RPC_S_OK;
}

/* Tells whether the context handle was created with the given guard
 * (plain pointer comparison against the stored ctx_guard). */
BOOL RpcContextHandle_IsGuardCorrect(NDR_SCONTEXT SContext, void *CtxGuard)
{
    const RpcContextHandle *handle = (const RpcContextHandle *)SContext;

    if (handle->ctx_guard == CtxGuard)
        return TRUE;
    return FALSE;
}

/* Looks up a context handle on the association by uuid and guard.
 *
 * The list is scanned under assoc->cs. For an entry with matching guard and
 * uuid, *SContext is set and the reference count is incremented; if the
 * count was non-zero beforehand the handle is returned with its rw_lock
 * acquired exclusively (the lock is taken after leaving assoc->cs).
 * Flags is not consulted.
 *
 * Returns RPC_S_OK on success, ERROR_INVALID_HANDLE when no usable handle
 * is found. */
RPC_STATUS RpcServerAssoc_FindContextHandle(RpcAssoc *assoc, const UUID *uuid,
                                            void *CtxGuard, ULONG Flags, NDR_SCONTEXT *SContext)
{
    RpcContextHandle *handle;

    EnterCriticalSection(&assoc->cs);
    LIST_FOR_EACH_ENTRY(handle, &assoc->context_handle_list, RpcContextHandle, entry)
    {
        unsigned int previous_refs;

        if (!RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)handle, CtxGuard))
            continue;
        if (memcmp(&handle->uuid, uuid, sizeof(*uuid)) != 0)
            continue;

        *SContext = (NDR_SCONTEXT)handle;

        /* the increment happens even when the previous count was zero */
        previous_refs = handle->refs;
        handle->refs = previous_refs + 1;
        if (previous_refs == 0)
            continue;

        LeaveCriticalSection(&assoc->cs);
        TRACE("found %p\n", handle);
        RtlAcquireResourceExclusive(&handle->rw_lock, TRUE);
        return RPC_S_OK;
    }
    LeaveCriticalSection(&assoc->cs);

    ERR("no context handle found for uuid %s, guard %p\n",
        debugstr_guid(uuid), CtxGuard);
    return ERROR_INVALID_HANDLE;
}

/* Makes a context handle "valid" the first time it is updated.
 *
 * Fails with ERROR_INVALID_HANDLE when the guard does not match. Otherwise,
 * under assoc->cs, a handle whose uuid is still nil gets an extra reference,
 * a freshly created uuid and the supplied rundown routine. A handle that
 * already has a uuid is left unchanged. Returns RPC_S_OK in both cases. */
RPC_STATUS RpcServerAssoc_UpdateContextHandle(RpcAssoc *assoc,
                                              NDR_SCONTEXT SContext,
                                              void *CtxGuard,
                                              NDR_RUNDOWN rundown_routine)
{
    RpcContextHandle *handle = (RpcContextHandle *)SContext;
    RPC_STATUS nil_status;
    BOOL guard_ok;

    guard_ok = RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)handle, CtxGuard);
    if (!guard_ok)
        return ERROR_INVALID_HANDLE;

    EnterCriticalSection(&assoc->cs);
    if (UuidIsNil(&handle->uuid, &nil_status))
    {
        /* the now-valid data holds its own reference */
        handle->refs++;
        UuidCreate(&handle->uuid);
        handle->rundown_routine = rundown_routine;
        TRACE("allocated uuid %s for context handle %p\n",
              debugstr_guid(&handle->uuid), handle);
    }
    LeaveCriticalSection(&assoc->cs);

    return RPC_S_OK;
}

/* Copies the context handle's uuid into *uuid. */
void RpcContextHandle_GetUuid(NDR_SCONTEXT SContext, UUID *uuid)
{
    const RpcContextHandle *handle = (const RpcContextHandle *)SContext;

    *uuid = handle->uuid;
}

/* Frees a context handle.
 *
 * If the handle has both a user context and a rundown routine, the routine
 * is invoked with the user context first. The rw_lock is then deleted and
 * the memory returned to the process heap. The handle is not unlinked from
 * any list here. */
static void RpcContextHandle_Destroy(RpcContextHandle *context_handle)
{
    NDR_RUNDOWN rundown = context_handle->rundown_routine;
    void *user_data = context_handle->user_context;

    TRACE("freeing %p\n", context_handle);

    if (user_data && rundown)
    {
        TRACE("calling rundown routine %p with user context %p\n",
              rundown, user_data);
        rundown(user_data);
    }

    RtlDeleteResource(&context_handle->rw_lock);

    HeapFree(GetProcessHeap(), 0, context_handle);
}

/* Drops one reference on a context handle and returns the remaining count.
 *
 * When release_lock is set the handle's rw_lock is released first. The
 * decrement and, at zero, the unlinking from the association's list happen
 * under assoc->cs; the handle is destroyed after that lock is left. */
unsigned int RpcServerAssoc_ReleaseContextHandle(RpcAssoc *assoc, NDR_SCONTEXT SContext, BOOL release_lock)
{
    RpcContextHandle *handle = (RpcContextHandle *)SContext;
    unsigned int remaining;

    if (release_lock)
        RtlReleaseResource(&handle->rw_lock);

    EnterCriticalSection(&assoc->cs);
    handle->refs--;
    remaining = handle->refs;
    if (remaining == 0)
        list_remove(&handle->entry);
    LeaveCriticalSection(&assoc->cs);

    if (remaining != 0)
        return remaining;

    RpcContextHandle_Destroy(handle);
    return 0;
}