diff --git a/dlls/rsaenh/rsaenh.c b/dlls/rsaenh/rsaenh.c index 02e3e6e32b..eafc92c960 100644 --- a/dlls/rsaenh/rsaenh.c +++ b/dlls/rsaenh/rsaenh.c @@ -1940,12 +1940,20 @@ BOOL WINAPI RSAENH_CPEncrypt(HCRYPTPROV hProv, HCRYPTKEY hKey, HCRYPTHASH hHash, memcpy(in, out, pCryptKey->dwBlockLen); } } else if (GET_ALG_TYPE(pCryptKey->aiAlgid) == ALG_TYPE_STREAM) { + if (pbData == NULL) { + *pdwDataLen = dwBufLen; + return TRUE; + } encrypt_stream_impl(pCryptKey->aiAlgid, &pCryptKey->context, pbData, *pdwDataLen); } else if (GET_ALG_TYPE(pCryptKey->aiAlgid) == ALG_TYPE_RSA) { if (pCryptKey->aiAlgid == CALG_RSA_SIGN) { SetLastError(NTE_BAD_KEY); return FALSE; } + if (!pbData) { + *pdwDataLen = pCryptKey->dwBlockLen; + return TRUE; + } if (dwBufLen < pCryptKey->dwBlockLen) { SetLastError(ERROR_MORE_DATA); return FALSE; diff --git a/dlls/rsaenh/tests/rsaenh.c b/dlls/rsaenh/tests/rsaenh.c index 01cbb0c941..90caeb4923 100644 --- a/dlls/rsaenh/tests/rsaenh.c +++ b/dlls/rsaenh/tests/rsaenh.c @@ -303,6 +303,11 @@ static void test_block_cipher_modes(void) result = CryptSetKeyParam(hKey, KP_MODE, (BYTE*)&dwMode, 0); ok(result, "%08lx\n", GetLastError()); + dwLen = 23; + result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, NULL, &dwLen, 24); + ok(result, "CryptEncrypt failed: %08lx\n", GetLastError()); + ok(dwLen == 24, "Unexpected length %ld\n", dwLen); + SetLastError(ERROR_SUCCESS); dwLen = 23; result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, abData, &dwLen, 24); @@ -317,6 +322,11 @@ static void test_block_cipher_modes(void) result = CryptSetKeyParam(hKey, KP_MODE, (BYTE*)&dwMode, 0); ok(result, "%08lx\n", GetLastError()); + dwLen = 23; + result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, NULL, &dwLen, 24); + ok(result, "CryptEncrypt failed: %08lx\n", GetLastError()); + ok(dwLen == 24, "Unexpected length %ld\n", dwLen); + dwLen = 23; result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, abData, &dwLen, 24); ok(result && dwLen == 24 && !memcmp(cbc, abData, sizeof(cbc)), @@ -595,6 +605,9 @@ static void test_rc4(void) result = CryptDestroyHash(hHash); ok(result, "%08lx\n", GetLastError()); + dwDataLen = 16; + result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, NULL, &dwDataLen, 24); + ok(result, "%08lx\n", GetLastError()); dwDataLen = 16; result = CryptEncrypt(hKey, (HCRYPTHASH)NULL, TRUE, 0, pbData, &dwDataLen, 24); ok(result, "%08lx\n", GetLastError()); @@ -1104,6 +1117,10 @@ static void test_rsa_encrypt(void) ok (result, "%08lx\n", GetLastError()); if (!result) return; + dwLen = 12; + result = CryptEncrypt(hRSAKey, 0, TRUE, 0, NULL, &dwLen, (DWORD)sizeof(abData)); + ok(result, "CryptEncrypt failed: %08lx\n", GetLastError()); + ok(dwLen == 128, "Unexpected length %ld\n", dwLen); dwLen = 12; result = CryptEncrypt(hRSAKey, 0, TRUE, 0, abData, &dwLen, (DWORD)sizeof(abData)); ok (result, "%08lx\n", GetLastError());