Commit e6499942 authored by Juan Lang's avatar Juan Lang Committed by Alexandre Julliard

rsaenh: Test and fix CryptEncrypt with NULL buffer.

parent ff421fde
...@@ -1940,12 +1940,20 @@ BOOL WINAPI RSAENH_CPEncrypt(HCRYPTPROV hProv, HCRYPTKEY hKey, HCRYPTHASH hHash, ...@@ -1940,12 +1940,20 @@ BOOL WINAPI RSAENH_CPEncrypt(HCRYPTPROV hProv, HCRYPTKEY hKey, HCRYPTHASH hHash,
memcpy(in, out, pCryptKey->dwBlockLen); memcpy(in, out, pCryptKey->dwBlockLen);
} }
} else if (GET_ALG_TYPE(pCryptKey->aiAlgid) == ALG_TYPE_STREAM) { } else if (GET_ALG_TYPE(pCryptKey->aiAlgid) == ALG_TYPE_STREAM) {
if (pbData == NULL) {
*pdwDataLen = dwBufLen;
return TRUE;
}
encrypt_stream_impl(pCryptKey->aiAlgid, &pCryptKey->context, pbData, *pdwDataLen); encrypt_stream_impl(pCryptKey->aiAlgid, &pCryptKey->context, pbData, *pdwDataLen);
} else if (GET_ALG_TYPE(pCryptKey->aiAlgid) == ALG_TYPE_RSA) { } else if (GET_ALG_TYPE(pCryptKey->aiAlgid) == ALG_TYPE_RSA) {
if (pCryptKey->aiAlgid == CALG_RSA_SIGN) { if (pCryptKey->aiAlgid == CALG_RSA_SIGN) {
SetLastError(NTE_BAD_KEY); SetLastError(NTE_BAD_KEY);
return FALSE; return FALSE;
} }
if (!pbData) {
*pdwDataLen = pCryptKey->dwBlockLen;
return TRUE;
}
if (dwBufLen < pCryptKey->dwBlockLen) { if (dwBufLen < pCryptKey->dwBlockLen) {
SetLastError(ERROR_MORE_DATA); SetLastError(ERROR_MORE_DATA);
return FALSE; return FALSE;
......
...@@ -303,6 +303,11 @@ static void test_block_cipher_modes(void) ...@@ -303,6 +303,11 @@ static void test_block_cipher_modes(void)
result = CryptSetKeyParam(hKey, KP_MODE, (BYTE*)&dwMode, 0); result = CryptSetKeyParam(hKey, KP_MODE, (BYTE*)&dwMode, 0);
ok(result, "%08lx\n", GetLastError()); ok(result, "%08lx\n", GetLastError());
dwLen = 23;
result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, NULL, &dwLen, 24);
ok(result, "CryptEncrypt failed: %08lx\n", GetLastError());
ok(dwLen == 24, "Unexpected length %ld\n", dwLen);
SetLastError(ERROR_SUCCESS); SetLastError(ERROR_SUCCESS);
dwLen = 23; dwLen = 23;
result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, abData, &dwLen, 24); result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, abData, &dwLen, 24);
...@@ -318,6 +323,11 @@ static void test_block_cipher_modes(void) ...@@ -318,6 +323,11 @@ static void test_block_cipher_modes(void)
ok(result, "%08lx\n", GetLastError()); ok(result, "%08lx\n", GetLastError());
dwLen = 23; dwLen = 23;
result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, NULL, &dwLen, 24);
ok(result, "CryptEncrypt failed: %08lx\n", GetLastError());
ok(dwLen == 24, "Unexpected length %ld\n", dwLen);
dwLen = 23;
result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, abData, &dwLen, 24); result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, abData, &dwLen, 24);
ok(result && dwLen == 24 && !memcmp(cbc, abData, sizeof(cbc)), ok(result && dwLen == 24 && !memcmp(cbc, abData, sizeof(cbc)),
"%08lx, dwLen: %ld\n", GetLastError(), dwLen); "%08lx, dwLen: %ld\n", GetLastError(), dwLen);
...@@ -596,6 +606,9 @@ static void test_rc4(void) ...@@ -596,6 +606,9 @@ static void test_rc4(void)
ok(result, "%08lx\n", GetLastError()); ok(result, "%08lx\n", GetLastError());
dwDataLen = 16; dwDataLen = 16;
result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, NULL, &dwDataLen, 24);
ok(result, "%08lx\n", GetLastError());
dwDataLen = 16;
result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, pbData, &dwDataLen, 24); result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, pbData, &dwDataLen, 24);
ok(result, "%08lx\n", GetLastError()); ok(result, "%08lx\n", GetLastError());
...@@ -1105,6 +1118,10 @@ static void test_rsa_encrypt(void) ...@@ -1105,6 +1118,10 @@ static void test_rsa_encrypt(void)
if (!result) return; if (!result) return;
dwLen = 12; dwLen = 12;
result = CryptEncrypt(hRSAKey, 0, TRUE, 0, NULL, &dwLen, (DWORD)sizeof(abData));
ok(result, "CryptEncrypt failed: %08lx\n", GetLastError());
ok(dwLen == 128, "Unexpected length %ld\n", dwLen);
dwLen = 12;
result = CryptEncrypt(hRSAKey, 0, TRUE, 0, abData, &dwLen, (DWORD)sizeof(abData)); result = CryptEncrypt(hRSAKey, 0, TRUE, 0, abData, &dwLen, (DWORD)sizeof(abData));
ok (result, "%08lx\n", GetLastError()); ok (result, "%08lx\n", GetLastError());
if (!result) return; if (!result) return;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment