diff --git a/modules/libpref/init/StaticPrefList.yaml b/modules/libpref/init/StaticPrefList.yaml index e5d2cc0841cb..a904a6246bcb 100644 --- a/modules/libpref/init/StaticPrefList.yaml +++ b/modules/libpref/init/StaticPrefList.yaml @@ -11018,6 +11018,18 @@ value: 2048 mirror: always +# How many records we store per entry +- name: network.ssl_tokens_cache_records_per_entry + type: RelaxedAtomicUint32 + value: 2 + mirror: always + +# If true, only use the token once +- name: network.ssl_tokens_cache_use_only_once + type: RelaxedAtomicBool + value: false + mirror: always + # The maximum allowed length for a URL - 1MB default. - name: network.standard-url.max-length type: RelaxedAtomicUint32 diff --git a/netwerk/base/SSLTokensCache.cpp b/netwerk/base/SSLTokensCache.cpp index b3565d24a0c9..7420785bf7e5 100644 --- a/netwerk/base/SSLTokensCache.cpp +++ b/netwerk/base/SSLTokensCache.cpp @@ -19,6 +19,9 @@ namespace net { static LazyLogModule gSSLTokensCacheLog("SSLTokensCache"); #undef LOG #define LOG(args) MOZ_LOG(gSSLTokensCacheLog, mozilla::LogLevel::Debug, args) +#undef LOG5_ENABLED +#define LOG5_ENABLED() \ + MOZ_LOG_TEST(mozilla::net::gSSLTokensCacheLog, mozilla::LogLevel::Verbose) class ExpirationComparator { public: @@ -48,6 +51,15 @@ SessionCacheInfo SessionCacheInfo::Clone() const { StaticRefPtr SSLTokensCache::gInstance; StaticMutex SSLTokensCache::sLock; +uint64_t SSLTokensCache::sRecordId = 0; + +SSLTokensCache::TokenCacheRecord::~TokenCacheRecord() { + if (!gInstance) { + return; + } + + gInstance->OnRecordDestroyed(this); +} uint32_t SSLTokensCache::TokenCacheRecord::Size() const { uint32_t size = mToken.Length() + sizeof(mSessionCacheInfo.mEVStatus) + @@ -71,6 +83,50 @@ void SSLTokensCache::TokenCacheRecord::Reset() { mSessionCacheInfo.mSucceededCertChainBytes.reset(); } +uint32_t SSLTokensCache::TokenCacheEntry::Size() const { + uint32_t size = 0; + for (const auto& rec : mRecords) { + size += rec->Size(); + } + return size; +} + +void SSLTokensCache::TokenCacheEntry::AddRecord( + UniquePtr&& aRecord, + nsTArray& aExpirationArray) { + if (mRecords.Length() == + StaticPrefs::network_ssl_tokens_cache_records_per_entry()) { + aExpirationArray.RemoveElement(mRecords[0].get()); + mRecords.RemoveElementAt(0); + } + + aExpirationArray.AppendElement(aRecord.get()); + for (int32_t i = mRecords.Length() - 1; i >= 0; --i) { + if (aRecord->mExpirationTime > mRecords[i]->mExpirationTime) { + mRecords.InsertElementAt(i + 1, std::move(aRecord)); + return; + } + } + mRecords.InsertElementAt(0, std::move(aRecord)); +} + +UniquePtr +SSLTokensCache::TokenCacheEntry::RemoveWithId(uint64_t aId) { + for (int32_t i = mRecords.Length() - 1; i >= 0; --i) { + if (mRecords[i]->mId == aId) { + UniquePtr record = std::move(mRecords[i]); + mRecords.RemoveElementAt(i); + return record; + } + } + return nullptr; +} + +const UniquePtr& +SSLTokensCache::TokenCacheEntry::Get() { + return mRecords[0]; +} + NS_IMPL_ISUPPORTS(SSLTokensCache, nsIMemoryReporter) // static @@ -204,45 +260,53 @@ nsresult SSLTokensCache::Put(const nsACString& aKey, const uint8_t* aToken, return rv; } - TokenCacheRecord* const rec = + auto makeUniqueRecord = [&]() { + auto rec = MakeUnique(); + rec->mKey = aKey; + rec->mExpirationTime = aExpirationTime; + MOZ_ASSERT(rec->mToken.IsEmpty()); + rec->mToken.AppendElements(aToken, aTokenLen); + rec->mId = ++sRecordId; + rec->mSessionCacheInfo.mServerCertBytes = std::move(certBytes); + + rec->mSessionCacheInfo.mSucceededCertChainBytes = + succeededCertChainBytes + ? Some(TransformIntoNewArray( + *succeededCertChainBytes, + [](auto& element) { return nsTArray(std::move(element)); })) + : Nothing(); + + if (isEV) { + rec->mSessionCacheInfo.mEVStatus = psm::EVStatus::EV; + } + + rec->mSessionCacheInfo.mCertificateTransparencyStatus = + certificateTransparencyStatus; + + rec->mSessionCacheInfo.mIsBuiltCertChainRootBuiltInRoot = + std::move(isBuiltCertChainRootBuiltInRoot); + return rec; + }; + + TokenCacheEntry* const cacheEntry = gInstance->mTokenCacheRecords.WithEntryHandle(aKey, [&](auto&& entry) { if (!entry) { - auto rec = MakeUnique(); - rec->mKey = aKey; - gInstance->mExpirationArray.AppendElement(rec.get()); - entry.Insert(std::move(rec)); + auto rec = makeUniqueRecord(); + auto cacheEntry = MakeUnique(); + cacheEntry->AddRecord(std::move(rec), gInstance->mExpirationArray); + entry.Insert(std::move(cacheEntry)); } else { + // To make sure the cache size is synced, we take away the size of + // whole entry and add it back later. gInstance->mCacheSize -= entry.Data()->Size(); - entry.Data()->Reset(); + entry.Data()->AddRecord(makeUniqueRecord(), + gInstance->mExpirationArray); } return entry->get(); }); - rec->mExpirationTime = aExpirationTime; - MOZ_ASSERT(rec->mToken.IsEmpty()); - rec->mToken.AppendElements(aToken, aTokenLen); - - rec->mSessionCacheInfo.mServerCertBytes = std::move(certBytes); - - rec->mSessionCacheInfo.mSucceededCertChainBytes = - succeededCertChainBytes - ? Some(TransformIntoNewArray( - *succeededCertChainBytes, - [](auto& element) { return nsTArray(std::move(element)); })) - : Nothing(); - - if (isEV) { - rec->mSessionCacheInfo.mEVStatus = psm::EVStatus::EV; - } - - rec->mSessionCacheInfo.mCertificateTransparencyStatus = - certificateTransparencyStatus; - - rec->mSessionCacheInfo.mIsBuiltCertChainRootBuiltInRoot = - std::move(isBuiltCertChainRootBuiltInRoot); - - gInstance->mCacheSize += rec->Size(); + gInstance->mCacheSize += cacheEntry->Size(); gInstance->LogStats(); @@ -252,8 +316,8 @@ nsresult SSLTokensCache::Put(const nsACString& aKey, const uint8_t* aToken, } // static -nsresult SSLTokensCache::Get(const nsACString& aKey, - nsTArray& aToken) { +nsresult SSLTokensCache::Get(const nsACString& aKey, nsTArray& aToken, + SessionCacheInfo& aResult, uint64_t* aTokenId) { StaticMutexAutoLock lock(sLock); LOG(("SSLTokensCache::Get [key=%s]", PromiseFlatCString(aKey).get())); @@ -263,13 +327,38 @@ nsresult SSLTokensCache::Get(const nsACString& aKey, return NS_ERROR_NOT_INITIALIZED; } - TokenCacheRecord* rec = nullptr; + return gInstance->GetLocked(aKey, aToken, aResult, aTokenId); +} - if (gInstance->mTokenCacheRecords.Get(aKey, &rec)) { - if (rec->mToken.Length()) { - aToken = rec->mToken.Clone(); - return NS_OK; +nsresult SSLTokensCache::GetLocked(const nsACString& aKey, + nsTArray& aToken, + SessionCacheInfo& aResult, + uint64_t* aTokenId) { + sLock.AssertCurrentThreadOwns(); + + TokenCacheEntry* cacheEntry = nullptr; + + if (mTokenCacheRecords.Get(aKey, &cacheEntry)) { + if (cacheEntry->RecordCount() == 0) { + MOZ_ASSERT(false, "Found a cacheEntry with no records"); + mTokenCacheRecords.Remove(aKey); + return NS_ERROR_NOT_AVAILABLE; } + + const UniquePtr& rec = cacheEntry->Get(); + aToken = rec->mToken.Clone(); + aResult = rec->mSessionCacheInfo.Clone(); + if (aTokenId) { + *aTokenId = rec->mId; + } + if (StaticPrefs::network_ssl_tokens_cache_use_only_once()) { + mCacheSize -= rec->Size(); + cacheEntry->RemoveWithId(rec->mId); + if (cacheEntry->RecordCount() == 0) { + mTokenCacheRecords.Remove(aKey); + } + } + return NS_OK; } LOG((" token not found")); @@ -277,31 +366,7 @@ nsresult SSLTokensCache::Get(const nsACString& aKey, } // static -bool SSLTokensCache::GetSessionCacheInfo(const nsACString& aKey, - SessionCacheInfo& aResult) { - StaticMutexAutoLock lock(sLock); - - LOG(("SSLTokensCache::GetSessionCacheInfo [key=%s]", - PromiseFlatCString(aKey).get())); - - if (!gInstance) { - LOG((" service not initialized")); - return false; - } - - TokenCacheRecord* rec = nullptr; - - if (gInstance->mTokenCacheRecords.Get(aKey, &rec)) { - aResult = rec->mSessionCacheInfo.Clone(); - return true; - } - - LOG((" token not found")); - return false; -} - -// static -nsresult SSLTokensCache::Remove(const nsACString& aKey) { +nsresult SSLTokensCache::Remove(const nsACString& aKey, uint64_t aId) { StaticMutexAutoLock lock(sLock); LOG(("SSLTokensCache::Remove [key=%s]", PromiseFlatCString(aKey).get())); @@ -311,33 +376,75 @@ nsresult SSLTokensCache::Remove(const nsACString& aKey) { return NS_ERROR_NOT_INITIALIZED; } - return gInstance->RemoveLocked(aKey); + return gInstance->RemoveLocked(aKey, aId); } -nsresult SSLTokensCache::RemoveLocked(const nsACString& aKey) { +nsresult SSLTokensCache::RemoveLocked(const nsACString& aKey, uint64_t aId) { sLock.AssertCurrentThreadOwns(); - LOG(("SSLTokensCache::RemoveLocked [key=%s]", - PromiseFlatCString(aKey).get())); + LOG(("SSLTokensCache::RemoveLocked [key=%s, id=%" PRIu64 "]", + PromiseFlatCString(aKey).get(), aId)); - UniquePtr rec; + TokenCacheEntry* cacheEntry; + if (!mTokenCacheRecords.Get(aKey, &cacheEntry)) { + return NS_ERROR_NOT_AVAILABLE; + } - if (!mTokenCacheRecords.Remove(aKey, &rec)) { - LOG((" token not found")); + UniquePtr rec = cacheEntry->RemoveWithId(aId); + if (!rec) { return NS_ERROR_NOT_AVAILABLE; } mCacheSize -= rec->Size(); - - if (!mExpirationArray.RemoveElement(rec.get())) { - MOZ_ASSERT(false, "token not found in mExpirationArray"); + if (cacheEntry->RecordCount() == 0) { + mTokenCacheRecords.Remove(aKey); } + // Release the record immediately, so mExpirationArray can be also updated. + rec = nullptr; + LogStats(); return NS_OK; } +// static +nsresult SSLTokensCache::RemoveAll(const nsACString& aKey) { + StaticMutexAutoLock lock(sLock); + + LOG(("SSLTokensCache::RemoveAll [key=%s]", PromiseFlatCString(aKey).get())); + + if (!gInstance) { + LOG((" service not initialized")); + return NS_ERROR_NOT_INITIALIZED; + } + + return gInstance->RemovAllLocked(aKey); +} + +nsresult SSLTokensCache::RemovAllLocked(const nsACString& aKey) { + sLock.AssertCurrentThreadOwns(); + + LOG(("SSLTokensCache::RemovAllLocked [key=%s]", + PromiseFlatCString(aKey).get())); + + UniquePtr cacheEntry; + if (!mTokenCacheRecords.Remove(aKey, &cacheEntry)) { + return NS_ERROR_NOT_AVAILABLE; + } + + mCacheSize -= cacheEntry->Size(); + cacheEntry = nullptr; + + LogStats(); + + return NS_OK; +} + +void SSLTokensCache::OnRecordDestroyed(TokenCacheRecord* aRec) { + mExpirationArray.RemoveElement(aRec); +} + void SSLTokensCache::EvictIfNecessary() { // kilobytes to bytes uint32_t capacity = StaticPrefs::network_ssl_tokens_cache_capacity() << 10; @@ -350,17 +457,23 @@ void SSLTokensCache::EvictIfNecessary() { mExpirationArray.Sort(ExpirationComparator()); while (mCacheSize > capacity && mExpirationArray.Length() > 0) { - if (NS_FAILED(RemoveLocked(mExpirationArray[0]->mKey))) { - MOZ_ASSERT(false, - "mExpirationArray and mTokenCacheRecords are out of sync!"); - mExpirationArray.RemoveElementAt(0); - } + DebugOnly rv = + RemoveLocked(mExpirationArray[0]->mKey, mExpirationArray[0]->mId); + MOZ_ASSERT(NS_SUCCEEDED(rv), + "mExpirationArray and mTokenCacheRecords are out of sync!"); } } void SSLTokensCache::LogStats() { + if (!LOG5_ENABLED()) { + return; + } LOG(("SSLTokensCache::LogStats [count=%zu, cacheSize=%u]", mExpirationArray.Length(), mCacheSize)); + for (const auto& ent : mTokenCacheRecords.Values()) { + const UniquePtr& rec = ent->Get(); + LOG(("key=%s count=%d", rec->mKey.get(), ent->RecordCount())); + } } size_t SSLTokensCache::SizeOfIncludingThis( diff --git a/netwerk/base/SSLTokensCache.h b/netwerk/base/SSLTokensCache.h index 22f622504d60..10d47d840f2a 100644 --- a/netwerk/base/SSLTokensCache.h +++ b/netwerk/base/SSLTokensCache.h @@ -45,17 +45,20 @@ class SSLTokensCache : public nsIMemoryReporter { static nsresult Put(const nsACString& aKey, const uint8_t* aToken, uint32_t aTokenLen, nsITransportSecurityInfo* aSecInfo, PRUint32 aExpirationTime); - static nsresult Get(const nsACString& aKey, nsTArray& aToken); - static bool GetSessionCacheInfo(const nsACString& aKey, - SessionCacheInfo& aResult); - static nsresult Remove(const nsACString& aKey); + static nsresult Get(const nsACString& aKey, nsTArray& aToken, + SessionCacheInfo& aResult, uint64_t* aTokenId = nullptr); + static nsresult Remove(const nsACString& aKey, uint64_t aId); + static nsresult RemoveAll(const nsACString& aKey); static void Clear(); private: SSLTokensCache(); virtual ~SSLTokensCache(); - nsresult RemoveLocked(const nsACString& aKey); + nsresult RemoveLocked(const nsACString& aKey, uint64_t aId); + nsresult RemovAllLocked(const nsACString& aKey); + nsresult GetLocked(const nsACString& aKey, nsTArray& aToken, + SessionCacheInfo& aResult, uint64_t* aTokenId); void EvictIfNecessary(); void LogStats(); @@ -64,11 +67,14 @@ class SSLTokensCache : public nsIMemoryReporter { static mozilla::StaticRefPtr gInstance; static StaticMutex sLock MOZ_UNANNOTATED; + static uint64_t sRecordId; uint32_t mCacheSize{0}; // Actual cache size in bytes class TokenCacheRecord { public: + ~TokenCacheRecord(); + uint32_t Size() const; void Reset(); @@ -76,9 +82,33 @@ class SSLTokensCache : public nsIMemoryReporter { PRUint32 mExpirationTime = 0; nsTArray mToken; SessionCacheInfo mSessionCacheInfo; + // An unique id to identify the record. Mostly used when we want to remove a + // record from TokenCacheEntry. + uint64_t mId = 0; }; - nsClassHashtable mTokenCacheRecords; + class TokenCacheEntry { + public: + uint32_t Size() const; + // Add a record into |mRecords|. To make sure |mRecords| is sorted, we + // iterate |mRecords| everytime to find a right place to insert the new + // record. + void AddRecord(UniquePtr&& aRecord, + nsTArray& aExpirationArray); + // This function returns the first record in |mRecords|. + const UniquePtr& Get(); + UniquePtr RemoveWithId(uint64_t aId); + uint32_t RecordCount() const { return mRecords.Length(); } + const nsTArray>& Records() { return mRecords; } + + private: + // The records in this array are ordered by the expiration time. + nsTArray> mRecords; + }; + + void OnRecordDestroyed(TokenCacheRecord* aRec); + + nsClassHashtable mTokenCacheRecords; nsTArray mExpirationArray; }; diff --git a/netwerk/protocol/http/Http3Session.cpp b/netwerk/protocol/http/Http3Session.cpp index ea2d0693c815..c7cde217780c 100644 --- a/netwerk/protocol/http/Http3Session.cpp +++ b/netwerk/protocol/http/Http3Session.cpp @@ -131,10 +131,12 @@ nsresult Http3Session::Init(const nsHttpConnectionInfo* aConnInfo, nsAutoCString peerId; mSocketControl->GetPeerId(peerId); nsTArray token; + SessionCacheInfo info; if (StaticPrefs::network_http_http3_enable_0rtt() && - NS_SUCCEEDED(SSLTokensCache::Get(peerId, token))) { + NS_SUCCEEDED(SSLTokensCache::Get(peerId, token, info))) { LOG(("Found a resumption token in the cache.")); mHttp3Connection->SetResumptionToken(token); + mSocketControl->SetSessionCacheInfo(std::move(info)); if (mHttp3Connection->IsZeroRtt()) { LOG(("Can send ZeroRtt data")); RefPtr self(this); diff --git a/netwerk/test/gtest/TestSSLTokensCache.cpp b/netwerk/test/gtest/TestSSLTokensCache.cpp new file mode 100644 index 000000000000..1b8770be93db --- /dev/null +++ b/netwerk/test/gtest/TestSSLTokensCache.cpp @@ -0,0 +1,165 @@ +#include "gtest/gtest.h" + +#include +#include "mozilla/Preferences.h" +#include "nsITransportSecurityInfo.h" +#include "nsSerializationHelper.h" +#include "SSLTokensCache.h" + +static already_AddRefed createDummySecInfo() { + // clang-format off + nsCString base64Serialization( + "FnhllAKWRHGAlo+ESXykKAAAAAAAAAAAwAAAAAAAAEaphjojH6pBabDSgSnsfLHeAAQAAgAAAAAAAAAAAAAAAAAAAAA" + "B4vFIJp5wRkeyPxAQ9RJGKPqbqVvKO0mKuIl8ec8o/uhmCjImkVxP+7sgiYWmMt8F+O2DZM7ZTG6GukivU8OT5gAAAAIAAAWpMII" + "FpTCCBI2gAwIBAgIQD4svsaKEC+QtqtsU2TF8ITANBgkqhkiG9w0BAQsFADBwMQswCQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUN" + "lcnQgSW5jMRkwFwYDVQQLExB3d3cuZGlnaWNlcnQuY29tMS8wLQYDVQQDEyZEaWdpQ2VydCBTSEEyIEhpZ2ggQXNzdXJhbmNlIFN" + "lcnZlciBDQTAeFw0xNTAyMjMwMDAwMDBaFw0xNjAzMDIxMjAwMDBaMGoxCzAJBgNVBAYTAlVTMRYwFAYDVQQHEw1TYW4gRnJhbmN" + "pc2NvMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRUwEwYDVQQKEwxGYXN0bHksIEluYy4xFzAVBgNVBAMTDnd3dy5naXRodWIuY29tMII" + "BIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA+9WUCgrgUNwP/JC3cUefLAXeDpq8Ko/U8p8IRvny0Ri0I6Uq0t+RP/nF0LJ" + "Avda8QHYujdgeDTePepBX7+OiwBFhA0YO+rM3C2Z8IRaN/i9eLln+Yyc68+1z+E10s1EXdZrtDGvN6MHqygGsdfkXKfBLUJ1BZEh" + "s9sBnfcjq3kh5gZdBArdG9l5NpdmQhtceaFGsPiWuJxGxRzS4i95veUHWkhMpEYDEEBdcDGxqArvQCvzSlngdttQCfx8OUkBTb3B" + "A2okpTwwJfqPsxVetA6qR7UNc+fVb6KHwvm0bzi2rQ3xw3D/syRHwdMkpoVDQPCk43H9WufgfBKRen87dFwIDAQABo4ICPzCCAjs" + "wHwYDVR0jBBgwFoAUUWj/kK8CB3U8zNllZGKiErhZcjswHQYDVR0OBBYEFGS/RLNGCZvPWh1xSaIEcouINIQjMHsGA1UdEQR0MHK" + "CDnd3dy5naXRodWIuY29tggpnaXRodWIuY29tggwqLmdpdGh1Yi5jb22CCyouZ2l0aHViLmlvgglnaXRodWIuaW+CFyouZ2l0aHV" + "idXNlcmNvbnRlbnQuY29tghVnaXRodWJ1c2VyY29udGVudC5jb20wDgYDVR0PAQH/BAQDAgWgMB0GA1UdJQQWMBQGCCsGAQUFBwM" + "BBggrBgEFBQcDAjB1BgNVHR8EbjBsMDSgMqAwhi5odHRwOi8vY3JsMy5kaWdpY2VydC5jb20vc2hhMi1oYS1zZXJ2ZXItZzMuY3J" + "sMDSgMqAwhi5odHRwOi8vY3JsNC5kaWdpY2VydC5jb20vc2hhMi1oYS1zZXJ2ZXItZzMuY3JsMEIGA1UdIAQ7MDkwNwYJYIZIAYb" + "9bAEBMCowKAYIKwYBBQUHAgEWHGh0dHBzOi8vd3d3LmRpZ2ljZXJ0LmNvbS9DUFMwgYMGCCsGAQUFBwEBBHcwdTAkBggrBgEFBQc" + "wAYYYaHR0cDovL29jc3AuZGlnaWNlcnQuY29tME0GCCsGAQUFBzAChkFodHRwOi8vY2FjZXJ0cy5kaWdpY2VydC5jb20vRGlnaUN" + "lcnRTSEEySGlnaEFzc3VyYW5jZVNlcnZlckNBLmNydDAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQAc4dbVmuKvyI7" + "KZ4Txk+ZqcAYToJGKUIVaPL94e5SZGweUisjaCbplAOihnf6Mxt8n6vnuH2IsCaz2NRHqhdcosjT3CwAiJpJNkXPKWVL/txgdSTV" + "2cqB1GG4esFOalvI52dzn+J4fTIYZvNF+AtGyHSLm2XRXYZCw455laUKf6Sk9RDShDgUvzhOKL4GXfTwKXv12MyMknJybH8UCpjC" + "HZmFBVHMcUN/87HsQo20PdOekeEvkjrrMIxW+gxw22Yb67yF/qKgwrWr+43bLN709iyw+LWiU7sQcHL2xk9SYiWQDj2tYz2soObV" + "QYTJm0VUZMEVFhtALq46cx92Zu4vFwC8AAwAAAAABAQAA"); + // clang-format on + nsCOMPtr secInfo; + NS_DeserializeObject(base64Serialization, getter_AddRefs(secInfo)); + + nsCOMPtr securityInfo = do_QueryInterface(secInfo); + return securityInfo.forget(); +} + +static auto MakeTestData(const size_t aDataSize) { + auto data = nsTArray(); + data.SetLength(aDataSize); + std::iota(data.begin(), data.end(), 0); + return data; +} + +static void putToken(const nsACString& aKey, uint32_t aSize) { + nsCOMPtr secInfo = createDummySecInfo(); + nsTArray token = MakeTestData(aSize); + nsresult rv = mozilla::net::SSLTokensCache::Put(aKey, token.Elements(), aSize, + secInfo, aSize); + ASSERT_EQ(rv, NS_OK); +} + +static void getAndCheckResult(const nsACString& aKey, uint32_t aExpectedSize) { + nsTArray result; + mozilla::net::SessionCacheInfo unused; + nsresult rv = mozilla::net::SSLTokensCache::Get(aKey, result, unused); + ASSERT_EQ(rv, NS_OK); + ASSERT_EQ(result.Length(), (size_t)aExpectedSize); +} + +TEST(TestTokensCache, SinglePut) +{ + mozilla::net::SSLTokensCache::Clear(); + mozilla::Preferences::SetInt("network.ssl_tokens_cache_records_per_entry", 1); + mozilla::Preferences::SetBool("network.ssl_tokens_cache_use_only_once", + false); + + putToken("anon:www.example.com:443"_ns, 100); + nsTArray result; + mozilla::net::SessionCacheInfo unused; + uint64_t id = 0; + nsresult rv = mozilla::net::SSLTokensCache::Get("anon:www.example.com:443"_ns, + result, unused, &id); + ASSERT_EQ(rv, NS_OK); + ASSERT_EQ(result.Length(), (size_t)100); + ASSERT_EQ(id, (uint64_t)1); + rv = mozilla::net::SSLTokensCache::Get("anon:www.example.com:443"_ns, result, + unused, &id); + ASSERT_EQ(rv, NS_OK); + + mozilla::Preferences::SetBool("network.ssl_tokens_cache_use_only_once", true); + // network.ssl_tokens_cache_use_only_once is true, so the record will be + // removed after SSLTokensCache::Get below. + rv = mozilla::net::SSLTokensCache::Get("anon:www.example.com:443"_ns, result, + unused); + ASSERT_EQ(rv, NS_OK); + rv = mozilla::net::SSLTokensCache::Get("anon:www.example.com:443"_ns, result, + unused); + ASSERT_EQ(rv, NS_ERROR_NOT_AVAILABLE); +} + +TEST(TestTokensCache, MultiplePut) +{ + mozilla::net::SSLTokensCache::Clear(); + mozilla::Preferences::SetInt("network.ssl_tokens_cache_records_per_entry", 3); + + putToken("anon:www.example1.com:443"_ns, 300); + // This record will be removed because + // "network.ssl_tokens_cache_records_per_entry" is 3. + putToken("anon:www.example1.com:443"_ns, 100); + putToken("anon:www.example1.com:443"_ns, 200); + putToken("anon:www.example1.com:443"_ns, 400); + + // Test if records are ordered by the expiration time + getAndCheckResult("anon:www.example1.com:443"_ns, 200); + getAndCheckResult("anon:www.example1.com:443"_ns, 300); + getAndCheckResult("anon:www.example1.com:443"_ns, 400); +} + +TEST(TestTokensCache, RemoveAll) +{ + mozilla::net::SSLTokensCache::Clear(); + mozilla::Preferences::SetInt("network.ssl_tokens_cache_records_per_entry", 3); + + putToken("anon:www.example1.com:443"_ns, 100); + putToken("anon:www.example1.com:443"_ns, 200); + putToken("anon:www.example1.com:443"_ns, 300); + + putToken("anon:www.example2.com:443"_ns, 100); + putToken("anon:www.example2.com:443"_ns, 200); + putToken("anon:www.example2.com:443"_ns, 300); + + nsTArray result; + mozilla::net::SessionCacheInfo unused; + nsresult rv = mozilla::net::SSLTokensCache::Get( + "anon:www.example1.com:443"_ns, result, unused); + ASSERT_EQ(rv, NS_OK); + ASSERT_EQ(result.Length(), (size_t)100); + + rv = mozilla::net::SSLTokensCache::RemoveAll("anon:www.example1.com:443"_ns); + ASSERT_EQ(rv, NS_OK); + + rv = mozilla::net::SSLTokensCache::Get("anon:www.example1.com:443"_ns, result, + unused); + ASSERT_EQ(rv, NS_ERROR_NOT_AVAILABLE); + + rv = mozilla::net::SSLTokensCache::Get("anon:www.example2.com:443"_ns, result, + unused); + ASSERT_EQ(rv, NS_OK); + ASSERT_EQ(result.Length(), (size_t)100); +} + +TEST(TestTokensCache, Eviction) +{ + mozilla::net::SSLTokensCache::Clear(); + + mozilla::Preferences::SetInt("network.ssl_tokens_cache_records_per_entry", 3); + mozilla::Preferences::SetInt("network.ssl_tokens_cache_capacity", 8); + + putToken("anon:www.example2.com:443"_ns, 300); + putToken("anon:www.example2.com:443"_ns, 400); + putToken("anon:www.example2.com:443"_ns, 500); + // The one has expiration time "300" will be removed because we only allow 3 + // records per entry. + putToken("anon:www.example2.com:443"_ns, 600); + + putToken("anon:www.example3.com:443"_ns, 600); + putToken("anon:www.example3.com:443"_ns, 500); + // The one has expiration time "400" was evicted, so we get "500". + getAndCheckResult("anon:www.example2.com:443"_ns, 500); +} diff --git a/netwerk/test/gtest/moz.build b/netwerk/test/gtest/moz.build index 002f5009eae4..79ca936efbad 100644 --- a/netwerk/test/gtest/moz.build +++ b/netwerk/test/gtest/moz.build @@ -24,6 +24,7 @@ UNIFIED_SOURCES += [ "TestReadStreamToString.cpp", "TestServerTimingHeader.cpp", "TestSocketTransportService.cpp", + "TestSSLTokensCache.cpp", "TestStandardURL.cpp", "TestUDPSocket.cpp", ] diff --git a/security/manager/ssl/CommonSocketControl.cpp b/security/manager/ssl/CommonSocketControl.cpp index bf3a40ed4015..d94576a483ea 100644 --- a/security/manager/ssl/CommonSocketControl.cpp +++ b/security/manager/ssl/CommonSocketControl.cpp @@ -12,7 +12,7 @@ #include "SharedSSLState.h" #include "sslt.h" #include "ssl.h" -#include "mozilla/net/SSLTokensCache.h" +#include "mozilla/StaticPrefs_network.h" #include "nsICertOverrideService.h" #include "nsITlsHandshakeListener.h" @@ -224,10 +224,7 @@ CommonSocketControl::IsAcceptableForHost(const nsACString& hostname, } void CommonSocketControl::RebuildCertificateInfoFromSSLTokenCache() { - nsAutoCString key; - GetPeerId(key); - mozilla::net::SessionCacheInfo info; - if (!mozilla::net::SSLTokensCache::GetSessionCacheInfo(key, info)) { + if (!mSessionCacheInfo) { MOZ_LOG( gPIPNSSLog, LogLevel::Debug, ("CommonSocketControl::RebuildCertificateInfoFromSSLTokenCache cannot " @@ -235,6 +232,7 @@ void CommonSocketControl::RebuildCertificateInfoFromSSLTokenCache() { return; } + mozilla::net::SessionCacheInfo& info = *mSessionCacheInfo; nsCOMPtr cert( new nsNSSCertificate(std::move(info.mServerCertBytes))); SetServerCert(cert, info.mEVStatus); diff --git a/security/manager/ssl/CommonSocketControl.h b/security/manager/ssl/CommonSocketControl.h index 4afd79d2d04d..3f093d384f12 100644 --- a/security/manager/ssl/CommonSocketControl.h +++ b/security/manager/ssl/CommonSocketControl.h @@ -7,6 +7,8 @@ #ifndef CommonSocketControl_h #define CommonSocketControl_h +#include "mozilla/Maybe.h" +#include "mozilla/net/SSLTokensCache.h" #include "nsISSLSocketControl.h" #include "TransportSecurityInfo.h" @@ -20,10 +22,16 @@ class CommonSocketControl : public mozilla::psm::TransportSecurityInfo, uint32_t GetProviderFlags() const { return mProviderFlags; } void SetSSLVersionUsed(int16_t version) { mSSLVersionUsed = version; } + void SetSessionCacheInfo(mozilla::net::SessionCacheInfo&& aInfo) { + mSessionCacheInfo.reset(); + mSessionCacheInfo.emplace(std::move(aInfo)); + } void RebuildCertificateInfoFromSSLTokenCache(); protected: ~CommonSocketControl() = default; + + mozilla::Maybe mSessionCacheInfo; bool mHandshakeCompleted; bool mJoined; bool mSentClientCert; diff --git a/security/manager/ssl/nsNSSIOLayer.cpp b/security/manager/ssl/nsNSSIOLayer.cpp index 51d305f0f785..bca734563774 100644 --- a/security/manager/ssl/nsNSSIOLayer.cpp +++ b/security/manager/ssl/nsNSSIOLayer.cpp @@ -926,7 +926,9 @@ nsresult nsNSSSocketInfo::SetResumptionTokenFromExternalCache() { return rv; } - rv = mozilla::net::SSLTokensCache::Get(peerId, token); + uint64_t tokenId = 0; + mozilla::net::SessionCacheInfo info; + rv = mozilla::net::SSLTokensCache::Get(peerId, token, info, &tokenId); if (NS_FAILED(rv)) { if (rv == NS_ERROR_NOT_AVAILABLE) { // It's ok if we can't find the token. @@ -939,7 +941,7 @@ nsresult nsNSSSocketInfo::SetResumptionTokenFromExternalCache() { SECStatus srv = SSL_SetResumptionToken(mFd, token.Elements(), token.Length()); if (srv == SECFailure) { PRErrorCode error = PR_GetError(); - mozilla::net::SSLTokensCache::Remove(peerId); + mozilla::net::SSLTokensCache::Remove(peerId, tokenId); MOZ_LOG(gPIPNSSLog, LogLevel::Debug, ("Setting token failed with NSS error %d [id=%s]", error, PromiseFlatCString(peerId).get())); @@ -953,6 +955,8 @@ nsresult nsNSSSocketInfo::SetResumptionTokenFromExternalCache() { return NS_ERROR_FAILURE; } + SetSessionCacheInfo(std::move(info)); + return NS_OK; }