Commit 0c35a851 authored by Paul Gofman's avatar Paul Gofman Committed by Alexandre Julliard

crypt32: Support CERT_NAME_SEARCH_ALL_NAMES_FLAG in CertGetNameStringW().

parent b1b9a754
...@@ -905,17 +905,7 @@ DWORD WINAPI CertGetNameStringA(PCCERT_CONTEXT cert, DWORD type, ...@@ -905,17 +905,7 @@ DWORD WINAPI CertGetNameStringA(PCCERT_CONTEXT cert, DWORD type,
return ret; return ret;
} }
/* Searches cert's extensions for the alternate name extension with OID static BOOL cert_get_alt_name_info(PCCERT_CONTEXT cert, BOOL alt_name_issuer, PCERT_ALT_NAME_INFO *info)
* altNameOID, and if found, searches it for the alternate name type entryType.
* If found, returns a pointer to the entry, otherwise returns NULL.
* Regardless of whether an entry of the desired type is found, if the
* alternate name extension is present, sets *info to the decoded alternate
* name extension, which you must free using LocalFree.
* The return value is a pointer within *info, so don't free *info before
* you're done with the return value.
*/
static PCERT_ALT_NAME_ENTRY cert_find_alt_name_entry(PCCERT_CONTEXT cert, BOOL alt_name_issuer,
DWORD entryType, PCERT_ALT_NAME_INFO *info)
{ {
static const char *oids[][2] = static const char *oids[][2] =
{ {
...@@ -924,24 +914,48 @@ static PCERT_ALT_NAME_ENTRY cert_find_alt_name_entry(PCCERT_CONTEXT cert, BOOL a ...@@ -924,24 +914,48 @@ static PCERT_ALT_NAME_ENTRY cert_find_alt_name_entry(PCCERT_CONTEXT cert, BOOL a
}; };
PCERT_EXTENSION ext; PCERT_EXTENSION ext;
DWORD bytes = 0; DWORD bytes = 0;
unsigned int i;
ext = CertFindExtension(oids[!!alt_name_issuer][0], cert->pCertInfo->cExtension, cert->pCertInfo->rgExtension); ext = CertFindExtension(oids[!!alt_name_issuer][0], cert->pCertInfo->cExtension, cert->pCertInfo->rgExtension);
if (!ext) if (!ext)
ext = CertFindExtension(oids[!!alt_name_issuer][1], cert->pCertInfo->cExtension, cert->pCertInfo->rgExtension); ext = CertFindExtension(oids[!!alt_name_issuer][1], cert->pCertInfo->cExtension, cert->pCertInfo->rgExtension);
if (!ext) return NULL; if (!ext) return FALSE;
if (!CryptDecodeObjectEx(cert->dwCertEncodingType, X509_ALTERNATE_NAME, ext->Value.pbData, ext->Value.cbData, return CryptDecodeObjectEx(cert->dwCertEncodingType, X509_ALTERNATE_NAME, ext->Value.pbData, ext->Value.cbData,
CRYPT_DECODE_ALLOC_FLAG, NULL, info, &bytes)) CRYPT_DECODE_ALLOC_FLAG, NULL, info, &bytes);
return NULL; }
for (i = 0; i < (*info)->cAltEntry; ++i) static PCERT_ALT_NAME_ENTRY cert_find_next_alt_name_entry(PCERT_ALT_NAME_INFO info, DWORD entry_type,
if ((*info)->rgAltEntry[i].dwAltNameChoice == entryType) unsigned int *index)
return &(*info)->rgAltEntry[i]; {
unsigned int i;
for (i = *index; i < info->cAltEntry; ++i)
if (info->rgAltEntry[i].dwAltNameChoice == entry_type)
{
*index = i + 1;
return &info->rgAltEntry[i];
}
return NULL; return NULL;
} }
/* Searches cert's extensions for the alternate name extension with OID
* altNameOID, and if found, searches it for the alternate name type entryType.
* If found, returns a pointer to the entry, otherwise returns NULL.
* Regardless of whether an entry of the desired type is found, if the
* alternate name extension is present, sets *info to the decoded alternate
* name extension, which you must free using LocalFree.
* The return value is a pointer within *info, so don't free *info before
* you're done with the return value.
*/
static PCERT_ALT_NAME_ENTRY cert_find_alt_name_entry(PCCERT_CONTEXT cert, BOOL alt_name_issuer,
DWORD entry_type, PCERT_ALT_NAME_INFO *info)
{
unsigned int index = 0;
if (!cert_get_alt_name_info(cert, alt_name_issuer, info)) return NULL;
return cert_find_next_alt_name_entry(*info, entry_type, &index);
}
static DWORD cert_get_name_from_rdn_attr(DWORD encodingType, static DWORD cert_get_name_from_rdn_attr(DWORD encodingType,
const CERT_NAME_BLOB *name, LPCSTR oid, LPWSTR pszNameString, DWORD cchNameString) const CERT_NAME_BLOB *name, LPCSTR oid, LPWSTR pszNameString, DWORD cchNameString)
{ {
...@@ -978,9 +992,10 @@ static DWORD copy_output_str(WCHAR *dst, const WCHAR *src, DWORD dst_size) ...@@ -978,9 +992,10 @@ static DWORD copy_output_str(WCHAR *dst, const WCHAR *src, DWORD dst_size)
DWORD WINAPI CertGetNameStringW(PCCERT_CONTEXT cert, DWORD type, DWORD flags, void *type_para, DWORD WINAPI CertGetNameStringW(PCCERT_CONTEXT cert, DWORD type, DWORD flags, void *type_para,
LPWSTR name_string, DWORD name_len) LPWSTR name_string, DWORD name_len)
{ {
static const DWORD supported_flags = CERT_NAME_ISSUER_FLAG | CERT_NAME_SEARCH_ALL_NAMES_FLAG;
BOOL alt_name_issuer, search_all_names;
CERT_ALT_NAME_INFO *info = NULL; CERT_ALT_NAME_INFO *info = NULL;
PCERT_ALT_NAME_ENTRY entry; PCERT_ALT_NAME_ENTRY entry;
BOOL alt_name_issuer;
PCERT_NAME_BLOB name; PCERT_NAME_BLOB name;
DWORD ret = 0; DWORD ret = 0;
...@@ -989,6 +1004,16 @@ DWORD WINAPI CertGetNameStringW(PCCERT_CONTEXT cert, DWORD type, DWORD flags, vo ...@@ -989,6 +1004,16 @@ DWORD WINAPI CertGetNameStringW(PCCERT_CONTEXT cert, DWORD type, DWORD flags, vo
if (!cert) if (!cert)
goto done; goto done;
if (flags & ~supported_flags)
FIXME("Unsupported flags %#lx.\n", flags);
search_all_names = flags & CERT_NAME_SEARCH_ALL_NAMES_FLAG;
if (search_all_names && type != CERT_NAME_DNS_TYPE)
{
WARN("CERT_NAME_SEARCH_ALL_NAMES_FLAG used with type %lu.\n", type);
goto done;
}
alt_name_issuer = flags & CERT_NAME_ISSUER_FLAG; alt_name_issuer = flags & CERT_NAME_ISSUER_FLAG;
name = alt_name_issuer ? &cert->pCertInfo->Issuer : &cert->pCertInfo->Subject; name = alt_name_issuer ? &cert->pCertInfo->Issuer : &cert->pCertInfo->Subject;
...@@ -1077,15 +1102,43 @@ DWORD WINAPI CertGetNameStringW(PCCERT_CONTEXT cert, DWORD type, DWORD flags, vo ...@@ -1077,15 +1102,43 @@ DWORD WINAPI CertGetNameStringW(PCCERT_CONTEXT cert, DWORD type, DWORD flags, vo
} }
case CERT_NAME_DNS_TYPE: case CERT_NAME_DNS_TYPE:
{ {
entry = cert_find_alt_name_entry(cert, alt_name_issuer, CERT_ALT_NAME_DNS_NAME, &info); unsigned int index = 0, len;
if (entry) if (cert_get_alt_name_info(cert, alt_name_issuer, &info)
&& (entry = cert_find_next_alt_name_entry(info, CERT_ALT_NAME_DNS_NAME, &index)))
{ {
ret = copy_output_str(name_string, entry->u.pwszDNSName, name_len); if (search_all_names)
break; {
do
{
if (name_string && name_len == 1) break;
ret += len = copy_output_str(name_string, entry->u.pwszDNSName, name_len ? name_len - 1 : 0);
if (name_string && name_len)
{
name_string += len;
name_len -= len;
}
}
while ((entry = cert_find_next_alt_name_entry(info, CERT_ALT_NAME_DNS_NAME, &index)));
}
else ret = copy_output_str(name_string, entry->u.pwszDNSName, name_len);
} }
else
{
if (!search_all_names || name_len != 1)
{
len = search_all_names && name_len ? name_len - 1 : name_len;
ret = cert_get_name_from_rdn_attr(cert->dwCertEncodingType, name, szOID_COMMON_NAME, ret = cert_get_name_from_rdn_attr(cert->dwCertEncodingType, name, szOID_COMMON_NAME,
name_string, name_len); name_string, len);
if (name_string) name_string += ret;
}
}
if (search_all_names)
{
if (name_string && name_len) *name_string = 0;
++ret;
}
break; break;
} }
case CERT_NAME_URL_TYPE: case CERT_NAME_URL_TYPE:
......
...@@ -847,39 +847,63 @@ static void test_CertStrToNameW(void) ...@@ -847,39 +847,63 @@ static void test_CertStrToNameW(void)
static void test_CertGetNameString_value_(unsigned int line, PCCERT_CONTEXT context, DWORD type, DWORD flags, static void test_CertGetNameString_value_(unsigned int line, PCCERT_CONTEXT context, DWORD type, DWORD flags,
void *type_para, const char *expected) void *type_para, const char *expected)
{ {
DWORD len, retlen, expected_len;
WCHAR expectedW[512]; WCHAR expectedW[512];
DWORD len, retlen;
WCHAR strW[512]; WCHAR strW[512];
unsigned int i;
char str[512]; char str[512];
for (i = 0; expected[i]; ++i) expected_len = 0;
expectedW[i] = expected[i]; while(expected[expected_len])
expectedW[i] = 0; {
while((expectedW[expected_len] = expected[expected_len]))
++expected_len;
if (!(flags & CERT_NAME_SEARCH_ALL_NAMES_FLAG))
break;
expectedW[expected_len++] = 0;
}
expectedW[expected_len++] = 0;
len = CertGetNameStringA(context, type, flags, type_para, NULL, 0); len = CertGetNameStringA(context, type, flags, type_para, NULL, 0);
ok(len == strlen(expected) + 1, "line %u: unexpected length %ld.\n", line, len); ok(len == expected_len, "line %u: unexpected length %ld, expected %ld.\n", line, len, expected_len);
memset(str, 0xcc, len);
retlen = CertGetNameStringA(context, type, flags, type_para, str, len); retlen = CertGetNameStringA(context, type, flags, type_para, str, len);
ok(retlen == len, "line %u: unexpected len %lu, expected %lu.\n", line, retlen, len); ok(retlen == len, "line %u: unexpected len %lu, expected %lu.\n", line, retlen, len);
ok(!strcmp(str, expected), "line %u: unexpected value %s.\n", line, str); ok(!memcmp(str, expected, expected_len), "line %u: unexpected value %s.\n", line, debugstr_an(str, expected_len));
str[0] = str[1] = 0xcc; str[0] = str[1] = 0xcc;
retlen = CertGetNameStringA(context, type, flags, type_para, str, len - 1); retlen = CertGetNameStringA(context, type, flags, type_para, str, len - 1);
ok(retlen == 1, "line %u: Unexpected len %lu, expected 1.\n", line, retlen); ok(retlen == 1, "line %u: Unexpected len %lu, expected 1.\n", line, retlen);
if (len == 1) return; if (len == 1) return;
ok(!str[0], "line %u: unexpected str[0] %#x.\n", line, str[0]); ok(!str[0], "line %u: unexpected str[0] %#x.\n", line, str[0]);
ok(str[1] == expected[1], "line %u: unexpected str[1] %#x.\n", line, str[1]); ok(str[1] == expected[1], "line %u: unexpected str[1] %#x.\n", line, str[1]);
ok(!memcmp(str + 1, expected + 1, len - 2),
"line %u: str %s, string data mismatch.\n", line, debugstr_a(str + 1));
retlen = CertGetNameStringA(context, type, flags, type_para, str, 0); retlen = CertGetNameStringA(context, type, flags, type_para, str, 0);
ok(retlen == len, "line %u: Unexpected len %lu, expected 1.\n", line, retlen); ok(retlen == len, "line %u: Unexpected len %lu, expected 1.\n", line, retlen);
memset(strW, 0xcc, len * sizeof(*strW));
retlen = CertGetNameStringW(context, type, flags, type_para, strW, len); retlen = CertGetNameStringW(context, type, flags, type_para, strW, len);
ok(retlen == len, "line %u: unexpected len %lu, expected 1.\n", line, retlen); ok(retlen == expected_len, "line %u: unexpected len %lu, expected %lu.\n", line, retlen, expected_len);
ok(!wcscmp(strW, expectedW), "line %u: unexpected value %s.\n", line, debugstr_w(strW)); ok(!memcmp(strW, expectedW, len * sizeof(*strW)), "line %u: unexpected value %s.\n", line, debugstr_wn(strW, len));
strW[0] = strW[1] = 0xcccc; strW[0] = strW[1] = 0xcccc;
retlen = CertGetNameStringW(context, type, flags, type_para, strW, len - 1); retlen = CertGetNameStringW(context, type, flags, type_para, strW, len - 1);
ok(retlen == len - 1, "line %u: unexpected len %lu, expected %lu.\n", line, retlen, len - 1); ok(retlen == len - 1, "line %u: unexpected len %lu, expected %lu.\n", line, retlen, len - 1);
ok(!wcsncmp(strW, expectedW, retlen - 1), "line %u: string data mismatch.\n", line); if (flags & CERT_NAME_SEARCH_ALL_NAMES_FLAG)
{
ok(!memcmp(strW, expectedW, (retlen - 2) * sizeof(*strW)),
"line %u: str %s, string data mismatch.\n", line, debugstr_wn(strW, retlen - 2));
ok(!strW[retlen - 2], "line %u: string is not zero terminated.\n", line);
ok(!strW[retlen - 1], "line %u: string sequence is not zero terminated.\n", line);
retlen = CertGetNameStringW(context, type, flags, type_para, strW, 1);
ok(retlen == 1, "line %u: unexpected len %lu, expected %lu.\n", line, retlen, len - 1);
ok(!strW[retlen - 1], "line %u: string sequence is not zero terminated.\n", line);
}
else
{
ok(!memcmp(strW, expectedW, (retlen - 1) * sizeof(*strW)),
"line %u: str %s, string data mismatch.\n", line, debugstr_wn(strW, retlen - 1));
ok(!strW[retlen - 1], "line %u: string is not zero terminated.\n", line); ok(!strW[retlen - 1], "line %u: string is not zero terminated.\n", line);
}
retlen = CertGetNameStringA(context, type, flags, type_para, NULL, len - 1); retlen = CertGetNameStringA(context, type, flags, type_para, NULL, len - 1);
ok(retlen == len, "line %u: unexpected len %lu, expected %lu\n", line, retlen, len); ok(retlen == len, "line %u: unexpected len %lu, expected %lu\n", line, retlen, len);
retlen = CertGetNameStringW(context, type, flags, type_para, NULL, len - 1); retlen = CertGetNameStringW(context, type, flags, type_para, NULL, len - 1);
...@@ -924,6 +948,9 @@ static void test_CertGetNameString(void) ...@@ -924,6 +948,9 @@ static void test_CertGetNameString(void)
test_CertGetNameString_value(context, CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, NULL, localhost); test_CertGetNameString_value(context, CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, NULL, localhost);
test_CertGetNameString_value(context, CERT_NAME_FRIENDLY_DISPLAY_TYPE, 0, NULL, localhost); test_CertGetNameString_value(context, CERT_NAME_FRIENDLY_DISPLAY_TYPE, 0, NULL, localhost);
test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, 0, NULL, localhost); test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, 0, NULL, localhost);
test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG, NULL, "localhost\0");
test_CertGetNameString_value(context, CERT_NAME_EMAIL_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG, NULL, "");
test_CertGetNameString_value(context, CERT_NAME_SIMPLE_DISPLAY_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG, NULL, "");
CertFreeCertificateContext(context); CertFreeCertificateContext(context);
...@@ -945,6 +972,10 @@ static void test_CertGetNameString(void) ...@@ -945,6 +972,10 @@ static void test_CertGetNameString(void)
test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_ISSUER_FLAG, NULL, "ex3.org"); test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_ISSUER_FLAG, NULL, "ex3.org");
test_CertGetNameString_value(context, CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, NULL, "server_cn.org"); test_CertGetNameString_value(context, CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, NULL, "server_cn.org");
test_CertGetNameString_value(context, CERT_NAME_ATTR_TYPE, 0, (void *)szOID_SUR_NAME, ""); test_CertGetNameString_value(context, CERT_NAME_ATTR_TYPE, 0, (void *)szOID_SUR_NAME, "");
test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG,
NULL, "ex1.org\0*.ex2.org\0");
test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG | CERT_NAME_ISSUER_FLAG,
NULL, "ex3.org\0*.ex4.org\0");
CertFreeCertificateContext(context); CertFreeCertificateContext(context);
} }
......
...@@ -3352,7 +3352,9 @@ typedef struct _CTL_FIND_SUBJECT_PARA ...@@ -3352,7 +3352,9 @@ typedef struct _CTL_FIND_SUBJECT_PARA
#define CERT_NAME_UPN_TYPE 8 #define CERT_NAME_UPN_TYPE 8
#define CERT_NAME_ISSUER_FLAG 0x00000001 #define CERT_NAME_ISSUER_FLAG 0x00000001
#define CERT_NAME_SEARCH_ALL_NAMES_FLAG 0x00000002
#define CERT_NAME_DISABLE_IE4_UTF8_FLAG 0x00010000 #define CERT_NAME_DISABLE_IE4_UTF8_FLAG 0x00010000
#define CERT_NAME_STR_ENABLE_PUNYCODE_FLAG 0x00200000
/* CryptFormatObject flags */ /* CryptFormatObject flags */
#define CRYPT_FORMAT_STR_MULTI_LINE 0x0001 #define CRYPT_FORMAT_STR_MULTI_LINE 0x0001
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment