Commit 24e276a9 authored by Evan Tang's avatar Evan Tang Committed by Alexandre Julliard

secur32: Schannel AcceptSecurityContext support.

parent 5267fcca
......@@ -769,14 +769,11 @@ static BOOL validate_input_buffers(SecBufferDesc *desc)
return TRUE;
}
/***********************************************************************
* InitializeSecurityContextW
*/
static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
static SECURITY_STATUS establish_context(
PCredHandle phCredential, PCtxtHandle phContext, SEC_WCHAR *pszTargetName,
ULONG fContextReq, ULONG Reserved1, ULONG TargetDataRep,
PSecBufferDesc pInput, ULONG Reserved2, PCtxtHandle phNewContext,
PSecBufferDesc pOutput, ULONG *pfContextAttr, PTimeStamp ptsExpiry)
PSecBufferDesc pInput, ULONG fContextReq, ULONG TargetDataRep,
PCtxtHandle phNewContext, PSecBufferDesc pOutput, ULONG *pfContextAttr,
PTimeStamp ptsTimeStamp, BOOL bIsServer)
{
const ULONG extra_size = 0x10000;
struct schan_context *ctx;
......@@ -791,26 +788,20 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
ULONG input_offset = 0, output_offset = 0;
SecBufferDesc input_desc, output_desc;
TRACE("%p %p %s 0x%08lx %ld %ld %p %ld %p %p %p %p\n", phCredential, phContext,
debugstr_w(pszTargetName), fContextReq, Reserved1, TargetDataRep, pInput,
Reserved1, phNewContext, pOutput, pfContextAttr, ptsExpiry);
dump_buffer_desc(pInput);
dump_buffer_desc(pOutput);
if (ptsExpiry)
if (ptsTimeStamp)
{
ptsExpiry->LowPart = 0;
ptsExpiry->HighPart = 0;
ptsTimeStamp->LowPart = 0;
ptsTimeStamp->HighPart = 0;
}
if (!pOutput || !pOutput->cBuffers) return SEC_E_INVALID_TOKEN;
for (i = 0; i < pOutput->cBuffers; i++)
{
ULONG type = pOutput->pBuffers[i].BufferType;
ULONG allocate_memory_flag = bIsServer ? ASC_REQ_ALLOCATE_MEMORY : ISC_REQ_ALLOCATE_MEMORY;
if (type != SECBUFFER_TOKEN && type != SECBUFFER_ALERT) continue;
if (!pOutput->pBuffers[i].cbBuffer && !(fContextReq & ISC_REQ_ALLOCATE_MEMORY))
if (!pOutput->pBuffers[i].cbBuffer && !(fContextReq & allocate_memory_flag))
return SEC_E_INSUFFICIENT_MEMORY;
}
......@@ -818,15 +809,16 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
{
ULONG_PTR handle;
struct create_session_params create_params;
ULONG credential_use = bIsServer ? SECPKG_CRED_INBOUND : SECPKG_CRED_OUTBOUND;
if (!phCredential) return SEC_E_INVALID_HANDLE;
cred = schan_get_object(phCredential->dwLower, SCHAN_HANDLE_CRED);
if (!cred) return SEC_E_INVALID_HANDLE;
if (!(cred->credential_use & SECPKG_CRED_OUTBOUND))
if (!(cred->credential_use & credential_use))
{
WARN("Invalid credential use %#lx\n", cred->credential_use);
WARN("Invalid credential use %#lx, expected %#lx\n", cred->credential_use, credential_use);
return SEC_E_INVALID_HANDLE;
}
......@@ -848,7 +840,7 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
return SEC_E_INTERNAL_ERROR;
}
if (cred->enabled_protocols & (SP_PROT_DTLS1_0_CLIENT | SP_PROT_DTLS1_2_CLIENT))
if (cred->enabled_protocols & SP_PROT_DTLS1_X)
ctx->header_size = HEADER_SIZE_DTLS;
else
ctx->header_size = HEADER_SIZE_TLS;
......@@ -894,12 +886,13 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
phNewContext->dwLower = handle;
phNewContext->dwUpper = 0;
}
else
if (bIsServer || phContext)
{
SIZE_T record_size = 0;
unsigned char *ptr;
if (!(ctx = schan_get_object(phContext->dwLower, SCHAN_HANDLE_CTX))) return SEC_E_INVALID_HANDLE;
if (phContext && !(ctx = schan_get_object(phContext->dwLower, SCHAN_HANDLE_CTX))) return SEC_E_INVALID_HANDLE;
if (!pInput && !ctx->shutdown_requested && !is_dtls_context(ctx)) return SEC_E_INCOMPLETE_MESSAGE;
if (!ctx->shutdown_requested && pInput)
......@@ -938,7 +931,7 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
TRACE("Using expected_size %Iu.\n", expected_size);
}
if (phNewContext) *phNewContext = *phContext;
if (phNewContext && phContext) *phNewContext = *phContext;
}
ctx->req_ctx_attr = fContextReq;
......@@ -1014,17 +1007,46 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
if (buffer->BufferType == SECBUFFER_ALERT) buffer->cbBuffer = 0;
}
*pfContextAttr = ISC_RET_REPLAY_DETECT | ISC_RET_SEQUENCE_DETECT | ISC_RET_CONFIDENTIALITY | ISC_RET_STREAM;
if (ctx->req_ctx_attr & ISC_REQ_EXTENDED_ERROR) *pfContextAttr |= ISC_RET_EXTENDED_ERROR;
if (ctx->req_ctx_attr & ISC_REQ_DATAGRAM) *pfContextAttr |= ISC_RET_DATAGRAM;
if (ctx->req_ctx_attr & ISC_REQ_ALLOCATE_MEMORY) *pfContextAttr |= ISC_RET_ALLOCATED_MEMORY;
if (ctx->req_ctx_attr & ISC_REQ_USE_SUPPLIED_CREDS) *pfContextAttr |= ISC_RET_USED_SUPPLIED_CREDS;
if (ctx->req_ctx_attr & ISC_REQ_MANUAL_CRED_VALIDATION) *pfContextAttr |= ISC_RET_MANUAL_CRED_VALIDATION;
if (bIsServer)
{
*pfContextAttr = ASC_RET_REPLAY_DETECT | ASC_RET_SEQUENCE_DETECT | ASC_RET_CONFIDENTIALITY | ASC_RET_STREAM;
if (ctx->req_ctx_attr & ASC_REQ_EXTENDED_ERROR) *pfContextAttr |= ASC_RET_EXTENDED_ERROR;
if (ctx->req_ctx_attr & ASC_REQ_DATAGRAM) *pfContextAttr |= ASC_RET_DATAGRAM;
if (ctx->req_ctx_attr & ASC_REQ_ALLOCATE_MEMORY) *pfContextAttr |= ASC_RET_ALLOCATED_MEMORY;
}
else
{
*pfContextAttr = ISC_RET_REPLAY_DETECT | ISC_RET_SEQUENCE_DETECT | ISC_RET_CONFIDENTIALITY | ISC_RET_STREAM;
if (ctx->req_ctx_attr & ISC_REQ_EXTENDED_ERROR) *pfContextAttr |= ISC_RET_EXTENDED_ERROR;
if (ctx->req_ctx_attr & ISC_REQ_DATAGRAM) *pfContextAttr |= ISC_RET_DATAGRAM;
if (ctx->req_ctx_attr & ISC_REQ_ALLOCATE_MEMORY) *pfContextAttr |= ISC_RET_ALLOCATED_MEMORY;
if (ctx->req_ctx_attr & ISC_REQ_USE_SUPPLIED_CREDS) *pfContextAttr |= ISC_RET_USED_SUPPLIED_CREDS;
if (ctx->req_ctx_attr & ISC_REQ_MANUAL_CRED_VALIDATION) *pfContextAttr |= ISC_RET_MANUAL_CRED_VALIDATION;
}
return ret;
}
/***********************************************************************
* InitializeSecurityContextW
*/
static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
PCredHandle phCredential, PCtxtHandle phContext, SEC_WCHAR *pszTargetName,
ULONG fContextReq, ULONG Reserved1, ULONG TargetDataRep,
PSecBufferDesc pInput, ULONG Reserved2, PCtxtHandle phNewContext,
PSecBufferDesc pOutput, ULONG *pfContextAttr, PTimeStamp ptsExpiry)
{
TRACE("%p %p %s 0x%08lx %ld %ld %p %ld %p %p %p %p\n", phCredential, phContext,
debugstr_w(pszTargetName), fContextReq, Reserved1, TargetDataRep, pInput,
Reserved1, phNewContext, pOutput, pfContextAttr, ptsExpiry);
dump_buffer_desc(pInput);
dump_buffer_desc(pOutput);
return establish_context(phCredential, phContext, pszTargetName, pInput, fContextReq, TargetDataRep, phNewContext, pOutput, pfContextAttr, ptsExpiry, FALSE);
}
/***********************************************************************
* InitializeSecurityContextA
*/
static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextA(
......@@ -1055,6 +1077,23 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextA(
return ret;
}
/***********************************************************************
* AcceptSecurityContext
*/
static SECURITY_STATUS SEC_ENTRY schan_AcceptSecurityContext(
PCredHandle phCredential, PCtxtHandle phContext, PSecBufferDesc pInput,
ULONG fContextReq, ULONG TargetDataRep, PCtxtHandle phNewContext,
PSecBufferDesc pOutput, ULONG *pfContextAttr, PTimeStamp ptsTimeStamp)
{
TRACE("%p %p %p 0x%08lx %ld %p %p %p %p\n", phCredential, phContext, pInput,
fContextReq, TargetDataRep, phNewContext, pOutput, pfContextAttr, ptsTimeStamp);
dump_buffer_desc(pInput);
dump_buffer_desc(pOutput);
return establish_context(phCredential, phContext, NULL, pInput, fContextReq, TargetDataRep, phNewContext, pOutput, pfContextAttr, ptsTimeStamp, TRUE);
}
static void *get_alg_name(ALG_ID id, BOOL wide)
{
static const struct {
......@@ -1604,7 +1643,7 @@ static const SecurityFunctionTableA schanTableA = {
schan_FreeCredentialsHandle,
NULL, /* Reserved2 */
schan_InitializeSecurityContextA,
NULL, /* AcceptSecurityContext */
schan_AcceptSecurityContext,
NULL, /* CompleteAuthToken */
schan_DeleteSecurityContext,
schan_ApplyControlToken, /* ApplyControlToken */
......@@ -1635,7 +1674,7 @@ static const SecurityFunctionTableW schanTableW = {
schan_FreeCredentialsHandle,
NULL, /* Reserved2 */
schan_InitializeSecurityContextW,
NULL, /* AcceptSecurityContext */
schan_AcceptSecurityContext,
NULL, /* CompleteAuthToken */
schan_DeleteSecurityContext,
schan_ApplyControlToken, /* ApplyControlToken */
......
......@@ -354,10 +354,12 @@ static ssize_t push_adapter(gnutls_transport_ptr_t transport, const void *buff,
return len;
}
static const struct {
struct protocol_priority_flag {
DWORD enable_flag;
const char *gnutls_flag;
} protocol_priority_flags[] = {
};
static const struct protocol_priority_flag client_protocol_priority_flags[] = {
{SP_PROT_DTLS1_2_CLIENT, "VERS-DTLS1.2"},
{SP_PROT_DTLS1_0_CLIENT, "VERS-DTLS1.0"},
{SP_PROT_TLS1_3_CLIENT, "VERS-TLS1.3"},
......@@ -368,33 +370,46 @@ static const struct {
/* {SP_PROT_SSL2_CLIENT} is not supported by GnuTLS */
};
static const struct protocol_priority_flag server_protocol_priority_flags[] = {
{SP_PROT_DTLS1_2_SERVER, "VERS-DTLS1.2"},
{SP_PROT_DTLS1_0_SERVER, "VERS-DTLS1.0"},
{SP_PROT_TLS1_3_SERVER, "VERS-TLS1.3"},
{SP_PROT_TLS1_2_SERVER, "VERS-TLS1.2"},
{SP_PROT_TLS1_1_SERVER, "VERS-TLS1.1"},
{SP_PROT_TLS1_0_SERVER, "VERS-TLS1.0"},
{SP_PROT_SSL3_SERVER, "VERS-SSL3.0"}
/* {SP_PROT_SSL2_SERVER} is not supported by GnuTLS */
};
static DWORD supported_protocols;
static void check_supported_protocols(void)
static void check_supported_protocols(
const struct protocol_priority_flag *flags, int num_flags, BOOLEAN server)
{
const char *type_desc = server ? "server" : "client";
gnutls_session_t session;
char priority[64];
unsigned i;
int err;
err = pgnutls_init(&session, GNUTLS_CLIENT);
err = pgnutls_init(&session, server ? GNUTLS_SERVER : GNUTLS_CLIENT);
if (err != GNUTLS_E_SUCCESS)
{
pgnutls_perror(err);
return;
}
for(i = 0; i < ARRAY_SIZE(protocol_priority_flags); i++)
for(i = 0; i < num_flags; i++)
{
sprintf(priority, "NORMAL:-%s", protocol_priority_flags[i].gnutls_flag);
sprintf(priority, "NORMAL:-%s", flags[i].gnutls_flag);
err = pgnutls_priority_set_direct(session, priority, NULL);
if (err == GNUTLS_E_SUCCESS)
{
TRACE("%s is supported\n", protocol_priority_flags[i].gnutls_flag);
supported_protocols |= protocol_priority_flags[i].enable_flag;
TRACE("%s %s is supported\n", type_desc, flags[i].gnutls_flag);
supported_protocols |= flags[i].enable_flag;
}
else
TRACE("%s is not supported\n", protocol_priority_flags[i].gnutls_flag);
TRACE("%s %s is not supported\n", type_desc, flags[i].gnutls_flag);
}
pgnutls_deinit(session);
......@@ -420,6 +435,11 @@ static int pull_timeout(gnutls_transport_ptr_t transport, unsigned int timeout)
static NTSTATUS set_priority(schan_credentials *cred, gnutls_session_t session)
{
char priority[128] = "NORMAL:%LATEST_RECORD_VERSION", *p;
BOOL server = !!(cred->credential_use & SECPKG_CRED_INBOUND);
const struct protocol_priority_flag *protocols =
server ? server_protocol_priority_flags : client_protocol_priority_flags;
int num_protocols = server ? ARRAYSIZE(server_protocol_priority_flags)
: ARRAYSIZE(client_protocol_priority_flags);
BOOL using_vers_all = FALSE, disabled;
int i, err;
......@@ -447,16 +467,16 @@ static NTSTATUS set_priority(schan_credentials *cred, gnutls_session_t session)
using_vers_all = TRUE;
}
for (i = 0; i < ARRAY_SIZE(protocol_priority_flags); i++)
for (i = 0; i < num_protocols; i++)
{
if (!(supported_protocols & protocol_priority_flags[i].enable_flag)) continue;
if (!(supported_protocols & protocols[i].enable_flag)) continue;
disabled = !(cred->enabled_protocols & protocol_priority_flags[i].enable_flag);
disabled = !(cred->enabled_protocols & protocols[i].enable_flag);
if (using_vers_all && disabled) continue;
*p++ = ':';
*p++ = disabled ? '-' : '+';
strcpy(p, protocol_priority_flags[i].gnutls_flag);
strcpy(p, protocols[i].gnutls_flag);
p += strlen(p);
}
......@@ -483,7 +503,7 @@ static NTSTATUS schan_create_session( void *args )
*params->session = 0;
if (cred->enabled_protocols & (SP_PROT_DTLS1_0_CLIENT | SP_PROT_DTLS1_2_CLIENT))
if (cred->enabled_protocols & SP_PROT_DTLS1_X)
{
flags |= GNUTLS_DATAGRAM | GNUTLS_NONBLOCK;
}
......@@ -1505,7 +1525,8 @@ static NTSTATUS process_attach( void *args )
pgnutls_global_set_log_function(gnutls_log);
}
check_supported_protocols();
check_supported_protocols(client_protocol_priority_flags, ARRAYSIZE(client_protocol_priority_flags), FALSE);
check_supported_protocols(server_protocol_priority_flags, ARRAYSIZE(server_protocol_priority_flags), TRUE);
return STATUS_SUCCESS;
fail:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment