mirror of
https://github.com/mozilla/gecko-dev.git
synced 2024-10-20 00:35:44 +00:00
416 lines
13 KiB
C++
416 lines
13 KiB
C++
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */
|
|
|
|
#ifndef tls_agent_h_
|
|
#define tls_agent_h_
|
|
|
|
#include "prio.h"
|
|
#include "ssl.h"
|
|
|
|
#include <iostream>
|
|
#include <functional>
|
|
|
|
#include "test_io.h"
|
|
|
|
#define GTEST_HAS_RTTI 0
|
|
#include "gtest/gtest.h"
|
|
|
|
namespace nss_test {
|
|
|
|
// Writes "<role>: <msg>" to std::cerr, terminated (and flushed) by std::endl.
// `msg` may be a chain of `<<` insertions. The expansion calls role_str()
// unqualified, so it only compiles where a role_str() is in scope.
#define LOG(msg) std::cerr << role_str() << ": " << msg << std::endl
|
|
|
|
// Selects which session resumption mechanisms are enabled. The values are
// bit flags: RESUME_BOTH is the bitwise OR of the session-ID and ticket bits,
// so individual mechanisms can be tested with `mode & RESUME_...`.
enum SessionResumptionMode {
  RESUME_NONE = 0,
  RESUME_SESSIONID = 1,
  RESUME_TICKET = 2,
  RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
};
|
|
|
|
class TlsAgent;
|
|
|
|
typedef
|
|
std::function<SECStatus(TlsAgent& agent, PRBool checksig, PRBool isServer)>
|
|
AuthCertificateCallbackFunction;
|
|
|
|
typedef
|
|
std::function<void(TlsAgent& agent)>
|
|
HandshakeCallbackFunction;
|
|
|
|
typedef
|
|
std::function<int32_t(TlsAgent& agent, const SECItem *srvNameArr,
|
|
PRUint32 srvNameArrSize)>
|
|
SniCallbackFunction;
|
|
|
|
class TlsAgent : public PollTarget {
|
|
public:
|
|
enum Role { CLIENT, SERVER };
|
|
enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR };
|
|
|
|
static const std::string kClient; // the client key is sign only
|
|
static const std::string kServerRsa; // both sign and encrypt
|
|
static const std::string kServerRsaSign;
|
|
static const std::string kServerRsaDecrypt;
|
|
static const std::string kServerEcdsa;
|
|
static const std::string kServerEcdhEcdsa;
|
|
// TODO: static const std::string kServerEcdhRsa;
|
|
static const std::string kServerDsa;
|
|
|
|
TlsAgent(const std::string& name, Role role, Mode mode);
|
|
virtual ~TlsAgent();
|
|
|
|
bool Init() {
|
|
pr_fd_ = DummyPrSocket::CreateFD(role_str(), mode_);
|
|
if (!pr_fd_) return false;
|
|
|
|
adapter_ = DummyPrSocket::GetAdapter(pr_fd_);
|
|
if (!adapter_) return false;
|
|
|
|
return true;
|
|
}
|
|
|
|
void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); }
|
|
|
|
void SetPacketFilter(PacketFilter* filter) {
|
|
adapter_->SetPacketFilter(filter);
|
|
}
|
|
|
|
|
|
void StartConnect();
|
|
void CheckKEAType(SSLKEAType type) const;
|
|
void CheckAuthType(SSLAuthType type) const;
|
|
|
|
void DisableAllCiphers();
|
|
void EnableCiphersByAuthType(SSLAuthType authType);
|
|
void EnableCiphersByKeyExchange(SSLKEAType kea);
|
|
void EnableSingleCipher(uint16_t cipher);
|
|
|
|
void Handshake();
|
|
// Marks the internal state as CONNECTING in anticipation of renegotiation.
|
|
void PrepareForRenegotiate();
|
|
// Prepares for renegotiation, then actually triggers it.
|
|
void StartRenegotiate();
|
|
bool ConfigServerCert(const std::string& name, bool updateKeyBits = false);
|
|
bool EnsureTlsSetup();
|
|
|
|
void SetupClientAuth();
|
|
void RequestClientAuth(bool requireAuth);
|
|
bool GetClientAuthCredentials(CERTCertificate** cert,
|
|
SECKEYPrivateKey** priv) const;
|
|
|
|
void ConfigureSessionCache(SessionResumptionMode mode);
|
|
void SetSessionTicketsEnabled(bool en);
|
|
void SetSessionCacheEnabled(bool en);
|
|
void SetVersionRange(uint16_t minver, uint16_t maxver);
|
|
void GetVersionRange(uint16_t* minver, uint16_t* maxver);
|
|
void CheckPreliminaryInfo();
|
|
void SetExpectedVersion(uint16_t version);
|
|
void SetServerKeyBits(uint16_t bits);
|
|
void SetExpectedReadError(bool err);
|
|
void EnableFalseStart();
|
|
void ExpectResumption();
|
|
void SetSignatureAlgorithms(const SSLSignatureAndHashAlg* algorithms,
|
|
size_t count);
|
|
void EnableAlpn(const uint8_t* val, size_t len);
|
|
void CheckAlpn(SSLNextProtoState expected_state,
|
|
const std::string& expected) const;
|
|
void EnableSrtp();
|
|
void CheckSrtp() const;
|
|
void CheckErrorCode(int32_t expected) const;
|
|
// Send data on the socket, encrypting it.
|
|
void SendData(size_t bytes, size_t blocksize = 1024);
|
|
// Send data directly to the underlying socket, skipping the TLS layer.
|
|
void SendDirect(const DataBuffer& buf);
|
|
void ReadBytes();
|
|
void ResetSentBytes(); // Hack to test drops.
|
|
void EnableExtendedMasterSecret();
|
|
void CheckExtendedMasterSecret(bool expected);
|
|
void DisableRollbackDetection();
|
|
void EnableCompression();
|
|
void SetDowngradeCheckVersion(uint16_t version);
|
|
|
|
const std::string& name() const { return name_; }
|
|
|
|
Role role() const { return role_; }
|
|
std::string role_str() const { return role_ == SERVER ? "server" : "client"; }
|
|
|
|
State state() const { return state_; }
|
|
|
|
const CERTCertificate* peer_cert() const {
|
|
return SSL_PeerCertificate(ssl_fd_);
|
|
}
|
|
|
|
const char* state_str() const { return state_str(state()); }
|
|
|
|
const char* state_str(State state) const { return states[state]; }
|
|
|
|
PRFileDesc* ssl_fd() { return ssl_fd_; }
|
|
DummyPrSocket* adapter() { return adapter_; }
|
|
|
|
bool is_compressed() const {
|
|
return info_.compressionMethod != ssl_compression_null;
|
|
}
|
|
uint16_t server_key_bits() const { return server_key_bits_; }
|
|
uint16_t min_version() const { return vrange_.min; }
|
|
uint16_t max_version() const { return vrange_.max; }
|
|
uint16_t version() const {
|
|
EXPECT_EQ(STATE_CONNECTED, state_);
|
|
return info_.protocolVersion;
|
|
}
|
|
|
|
bool cipher_suite(uint16_t* cipher_suite) const {
|
|
if (state_ != STATE_CONNECTED) return false;
|
|
|
|
*cipher_suite = info_.cipherSuite;
|
|
return true;
|
|
}
|
|
|
|
std::string cipher_suite_name() const {
|
|
if (state_ != STATE_CONNECTED) return "UNKNOWN";
|
|
|
|
return csinfo_.cipherSuiteName;
|
|
}
|
|
|
|
std::vector<uint8_t> session_id() const {
|
|
return std::vector<uint8_t>(info_.sessionID,
|
|
info_.sessionID + info_.sessionIDLength);
|
|
}
|
|
|
|
size_t received_bytes() const { return recv_ctr_; }
|
|
int32_t error_code() const { return error_code_; }
|
|
|
|
bool can_falsestart_hook_called() const { return can_falsestart_hook_called_; }
|
|
|
|
void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) {
|
|
handshake_callback_ = handshake_callback;
|
|
}
|
|
|
|
void SetAuthCertificateCallback(
|
|
AuthCertificateCallbackFunction auth_certificate_callback) {
|
|
auth_certificate_callback_ = auth_certificate_callback;
|
|
}
|
|
|
|
void SetSniCallback(SniCallbackFunction sni_callback) {
|
|
sni_callback_ = sni_callback;
|
|
}
|
|
|
|
private:
|
|
const static char* states[];
|
|
|
|
void SetState(State state) {
|
|
if (state_ == state) return;
|
|
|
|
LOG("Changing state from " << state_str(state_) << " to "
|
|
<< state_str(state));
|
|
state_ = state;
|
|
}
|
|
|
|
// Dummy auth certificate hook.
|
|
static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
|
|
PRBool checksig, PRBool isServer) {
|
|
TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
|
|
agent->CheckPreliminaryInfo();
|
|
agent->auth_certificate_hook_called_ = true;
|
|
if (agent->auth_certificate_callback_) {
|
|
return agent->auth_certificate_callback_(*agent, checksig, isServer);
|
|
}
|
|
return SECSuccess;
|
|
}
|
|
|
|
// Client auth certificate hook.
|
|
static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd,
|
|
PRBool checksig, PRBool isServer) {
|
|
TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
|
|
EXPECT_TRUE(agent->expect_client_auth_);
|
|
EXPECT_TRUE(isServer);
|
|
if (agent->auth_certificate_callback_) {
|
|
return agent->auth_certificate_callback_(*agent, checksig, isServer);
|
|
}
|
|
return SECSuccess;
|
|
}
|
|
|
|
static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd,
|
|
CERTDistNames* caNames,
|
|
CERTCertificate** cert,
|
|
SECKEYPrivateKey** privKey);
|
|
|
|
static void ReadableCallback(PollTarget* self, Event event) {
|
|
TlsAgent* agent = static_cast<TlsAgent*>(self);
|
|
if (event == TIMER_EVENT) {
|
|
agent->timer_handle_ = nullptr;
|
|
}
|
|
agent->ReadableCallback_int();
|
|
}
|
|
|
|
void ReadableCallback_int() {
|
|
LOG("Readable");
|
|
switch (state_) {
|
|
case STATE_CONNECTING:
|
|
Handshake();
|
|
break;
|
|
case STATE_CONNECTED:
|
|
ReadBytes();
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
|
|
static PRInt32 SniHook(PRFileDesc *fd, const SECItem *srvNameArr,
|
|
PRUint32 srvNameArrSize,
|
|
void *arg) {
|
|
TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
|
|
agent->CheckPreliminaryInfo();
|
|
agent->sni_hook_called_ = true;
|
|
EXPECT_EQ(1UL, srvNameArrSize);
|
|
if (agent->sni_callback_) {
|
|
return agent->sni_callback_(*agent, srvNameArr, srvNameArrSize);
|
|
}
|
|
return 0; // First configuration.
|
|
}
|
|
|
|
static SECStatus CanFalseStartCallback(PRFileDesc *fd, void *arg,
|
|
PRBool *canFalseStart) {
|
|
TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
|
|
agent->CheckPreliminaryInfo();
|
|
EXPECT_TRUE(agent->falsestart_enabled_);
|
|
EXPECT_FALSE(agent->can_falsestart_hook_called_);
|
|
agent->can_falsestart_hook_called_ = true;
|
|
*canFalseStart = true;
|
|
return SECSuccess;
|
|
}
|
|
|
|
static void HandshakeCallback(PRFileDesc *fd, void *arg) {
|
|
TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
|
|
agent->handshake_callback_called_ = true;
|
|
agent->Connected();
|
|
if (agent->handshake_callback_) {
|
|
agent->handshake_callback_(*agent);
|
|
}
|
|
}
|
|
|
|
void CheckCallbacks() const;
|
|
void Connected();
|
|
|
|
const std::string name_;
|
|
Mode mode_;
|
|
uint16_t server_key_bits_;
|
|
PRFileDesc* pr_fd_;
|
|
DummyPrSocket* adapter_;
|
|
PRFileDesc* ssl_fd_;
|
|
Role role_;
|
|
State state_;
|
|
Poller::Timer *timer_handle_;
|
|
bool falsestart_enabled_;
|
|
uint16_t expected_version_;
|
|
uint16_t expected_cipher_suite_;
|
|
bool expect_resumption_;
|
|
bool expect_client_auth_;
|
|
bool can_falsestart_hook_called_;
|
|
bool sni_hook_called_;
|
|
bool auth_certificate_hook_called_;
|
|
bool handshake_callback_called_;
|
|
SSLChannelInfo info_;
|
|
SSLCipherSuiteInfo csinfo_;
|
|
SSLVersionRange vrange_;
|
|
int32_t error_code_;
|
|
size_t send_ctr_;
|
|
size_t recv_ctr_;
|
|
bool expected_read_error_;
|
|
HandshakeCallbackFunction handshake_callback_;
|
|
AuthCertificateCallbackFunction auth_certificate_callback_;
|
|
SniCallbackFunction sni_callback_;
|
|
};
|
|
|
|
class TlsAgentTestBase : public ::testing::Test {
|
|
public:
|
|
static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll;
|
|
|
|
TlsAgentTestBase(TlsAgent::Role role,
|
|
Mode mode) : agent_(nullptr),
|
|
fd_(nullptr),
|
|
role_(role),
|
|
mode_(mode) {}
|
|
~TlsAgentTestBase() {
|
|
if (fd_) {
|
|
PR_Close(fd_);
|
|
}
|
|
}
|
|
|
|
void SetUp();
|
|
void TearDown();
|
|
|
|
void MakeRecord(uint8_t type, uint16_t version,
|
|
const uint8_t *buf, size_t len,
|
|
DataBuffer* out, uint32_t seq_num = 0);
|
|
void MakeHandshakeMessage(uint8_t hs_type,
|
|
const uint8_t *data,
|
|
size_t hs_len,
|
|
DataBuffer* out,
|
|
uint32_t seq_num = 0);
|
|
void MakeHandshakeMessageFragment(uint8_t hs_type,
|
|
const uint8_t *data,
|
|
size_t hs_len,
|
|
DataBuffer* out,
|
|
uint32_t seq_num,
|
|
uint32_t fragment_offset,
|
|
uint32_t fragment_length);
|
|
void MakeTrivialHandshakeRecord(uint8_t hs_type,
|
|
size_t hs_len,
|
|
DataBuffer* out);
|
|
static inline TlsAgent::Role ToRole(const std::string& str) {
|
|
return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER;
|
|
}
|
|
|
|
static inline Mode ToMode(const std::string& str) {
|
|
return str == "TLS" ? STREAM : DGRAM;
|
|
}
|
|
|
|
void Init();
|
|
|
|
protected:
|
|
void EnsureInit();
|
|
void ProcessMessage(const DataBuffer& buffer,
|
|
TlsAgent::State expected_state,
|
|
int32_t error_code = 0);
|
|
|
|
|
|
TlsAgent* agent_;
|
|
PRFileDesc* fd_;
|
|
TlsAgent::Role role_;
|
|
Mode mode_;
|
|
};
|
|
|
|
class TlsAgentTest :
|
|
public TlsAgentTestBase,
|
|
public ::testing::WithParamInterface
|
|
<std::tuple<std::string,std::string>> {
|
|
public:
|
|
TlsAgentTest() :
|
|
TlsAgentTestBase(ToRole(std::get<0>(GetParam())),
|
|
ToMode(std::get<1>(GetParam()))) {}
|
|
};
|
|
|
|
class TlsAgentTestClient : public TlsAgentTestBase,
|
|
public ::testing::WithParamInterface<std::string> {
|
|
public:
|
|
TlsAgentTestClient() : TlsAgentTestBase(TlsAgent::CLIENT,
|
|
ToMode(GetParam())) {}
|
|
};
|
|
|
|
class TlsAgentStreamTestClient : public TlsAgentTestBase {
|
|
public:
|
|
TlsAgentStreamTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, STREAM) {}
|
|
};
|
|
|
|
class TlsAgentDgramTestClient : public TlsAgentTestBase {
|
|
public:
|
|
TlsAgentDgramTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, DGRAM) {}
|
|
};
|
|
|
|
} // namespace nss_test
|
|
|
|
#endif
|