// File: gecko-dev/testing/mochitest/ssltunnel/ssltunnel.cpp
// (source-viewer header: 1380 lines, 42 KiB, C++)

/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* ***** BEGIN LICENSE BLOCK *****
* Version: MPL 1.1/GPL 2.0/LGPL 2.1
*
* The contents of this file are subject to the Mozilla Public License Version
* 1.1 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
* http://www.mozilla.org/MPL/
*
* Software distributed under the License is distributed on an "AS IS" basis,
* WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
* for the specific language governing rights and limitations under the
* License.
*
* The Original Code is Mozilla test code
*
* The Initial Developer of the Original Code is
* Mozilla Foundation
* Portions created by the Initial Developer are Copyright (C) 2008
* the Initial Developer. All Rights Reserved.
*
* Contributor(s):
* Ted Mielczarek <ted.mielczarek@gmail.com>
* Honza Bambas <honzab@firemni.cz>
*
* Alternatively, the contents of this file may be used under the terms of
* either the GNU General Public License Version 2 or later (the "GPL"), or
* the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
* in which case the provisions of the GPL or the LGPL are applicable instead
* of those above. If you wish to allow use of your version of this file only
* under the terms of either the GPL or the LGPL, and not to allow others to
* use your version of this file under the terms of the MPL, indicate your
* decision by deleting the provisions above and replace them with the notice
* and other provisions required by the GPL or the LGPL. If you do not delete
* the provisions above, a recipient may use your version of this file under
* the terms of any one of the MPL, the GPL or the LGPL.
*
* ***** END LICENSE BLOCK ***** */
/*
* WARNING: DO NOT USE THIS CODE IN PRODUCTION SYSTEMS. It is highly likely to
* be plagued with the usual problems endemic to C (buffer overflows
* and the like). We don't especially care here (but would accept
* patches!) because this is only intended for use in our test
* harnesses in controlled situations where input is guaranteed not to
* be malicious.
*/
#include <assert.h>
#include <stdio.h>
#include <string>
#include <vector>
#include <algorithm>
#include <stdarg.h>
#include "prinit.h"
#include "prerror.h"
#include "prenv.h"
#include "prio.h"
#include "prnetdb.h"
#include "prtpool.h"
#include "prtypes.h"
#include "nsAlgorithm.h"
#include "nss.h"
#include "pk11func.h"
#include "key.h"
#include "keyt.h"
#include "ssl.h"
#include "plhash.h"
using std::string;
using std::vector;
// Bit-set helpers used by strtok2(): |m| is a table of DELIM_TABLE_SIZE
// bytes holding one bit per byte value (32 bytes * 8 bits = 256 values);
// |c| is the byte value whose bit is tested or set. Callers pass |c| as an
// unsigned 8-bit value so the shift/mask stays within the table.
#define IS_DELIM(m, c) ((m)[(c) >> 3] & (1 << ((c) & 7)))
#define SET_DELIM(m, c) ((m)[(c) >> 3] |= (1 << ((c) & 7)))
#define DELIM_TABLE_SIZE 32
// You can set the level of logging by env var SSLTUNNEL_LOG_LEVEL=n, where n
// is 0 through 3. The default is 1, INFO level logging.
// gLogLevel is the threshold: a message is emitted only when its level is
// >= gLogLevel, so LEVEL_SILENT suppresses all output from these macros.
// gLastLogLevel records the level of the last message actually emitted;
// LOG_BEGIN_BLOCK()/LOG_END_BLOCK() use it to finish a multi-part line.
enum LogLevel {
LEVEL_DEBUG = 0,
LEVEL_INFO = 1,
LEVEL_ERROR = 2,
LEVEL_SILENT = 3
} gLogLevel, gLastLogLevel;
// Core emitter: when |level| passes the gLogLevel threshold, remember it in
// gLastLogLevel and call |func| with the parenthesized argument list
// |params|, e.g. LOG_INFO(("x=%d\n", x)) expands to printf("x=%d\n", x).
// NOTE(review): identifiers starting with '_' + uppercase are reserved in C++.
#define _LOG_OUTPUT(level, func, params) \
PR_BEGIN_MACRO \
if (level >= gLogLevel) { \
gLastLogLevel = level; \
func params;\
} \
PR_END_MACRO
// The most verbose output
#define LOG_DEBUG(params) \
_LOG_OUTPUT(LEVEL_DEBUG, printf, params)
// Top level informative messages
#define LOG_INFO(params) \
_LOG_OUTPUT(LEVEL_INFO, printf, params)
// Serious errors; written to stderr at every level except LEVEL_SILENT.
#define LOG_ERROR(params) \
_LOG_OUTPUT(LEVEL_ERROR, eprintf, params)
// Same as LOG_ERROR, but when logging is set to LEVEL_DEBUG, the message
// will be put to the stdout instead of stderr to keep continuity with other
// LOG_DEBUG message output
#define LOG_ERRORD(params) \
PR_BEGIN_MACRO \
if (gLogLevel == LEVEL_DEBUG) \
_LOG_OUTPUT(LEVEL_ERROR, printf, params); \
else \
_LOG_OUTPUT(LEVEL_ERROR, eprintf, params); \
PR_END_MACRO
// If there is any output written between LOG_BEGIN_BLOCK() and
// LOG_END_BLOCK() then a new line will be put to the proper output (out/err).
// LOG_BEGIN_BLOCK() marks "nothing printed yet" by resetting gLastLogLevel to
// LEVEL_SILENT; LOG_END_BLOCK() then prints "\n" to stderr if the last
// emitted message was an error, or to stdout if it was DEBUG/INFO.
#define LOG_BEGIN_BLOCK() \
gLastLogLevel = LEVEL_SILENT;
#define LOG_END_BLOCK() \
PR_BEGIN_MACRO \
if (gLastLogLevel == LEVEL_ERROR) \
LOG_ERROR(("\n")); \
if (gLastLogLevel < LEVEL_ERROR) \
_LOG_OUTPUT(gLastLogLevel, printf, ("\n")); \
PR_END_MACRO
// printf-style helper that writes to stderr instead of stdout; used by the
// LOG_ERROR family of macros. Returns whatever vfprintf returns: the number
// of characters written, or a negative value on an output error.
int eprintf(const char* str, ...)
{
  va_list args;
  va_start(args, str);
  const int written = vfprintf(stderr, str, args);
  va_end(args);
  return written;
}
// Copied from nsCRT
//
// Reentrant tokenizer. Skips any leading characters of |string| that occur
// in |delims|, then returns the token that follows, NUL-terminating it in
// place by overwriting the first delimiter after it. *newStr is set to the
// position where the next scan should start. Returns NULL when only
// delimiters (or nothing) remain.
char* strtok2(char* string, const char* delims, char* *newStr)
{
  PR_ASSERT(string);

  // One bit per byte value (DELIM_TABLE_SIZE bytes * 8 bits = 256 values).
  char delimTable[DELIM_TABLE_SIZE] = { 0 };
  for (const char* d = delims; *d; ++d)
    SET_DELIM(delimTable, static_cast<PRUint8>(*d));

  // Skip to the beginning of the token.
  char* cursor = string;
  while (*cursor && IS_DELIM(delimTable, static_cast<PRUint8>(*cursor)))
    ++cursor;

  char* token = cursor;

  // Walk to the end of the token; terminate it at the first delimiter and
  // step past that delimiter.
  for (; *cursor; ++cursor) {
    if (IS_DELIM(delimTable, static_cast<PRUint8>(*cursor))) {
      *cursor++ = '\0';
      break;
    }
  }

  *newStr = cursor;
  if (cursor == token)
    return NULL;
  return token;
}
// Client certificate policy for a host, from the "clientauth" config option.
enum client_auth_option {
  caNone,     // 0: do not ask the client for a certificate
  caRequire,  // 1: request a client certificate and require it
  caRequest   // 2: request a client certificate but do not require it
};
// Structs for passing data into jobs on the thread pool

// Per-listening-port configuration, built from "listen" config lines.
struct server_info_t {
  PRInt32 listen_port;
  // Default certificate nickname for this port.
  string cert_nickname;
  // "host:port" -> certificate nickname (both new[]-allocated char arrays).
  PLHashTable* host_cert_table;
  // "host:port" -> client_auth_option* (from "clientauth" lines).
  PLHashTable* host_clientauth_table;
};

// State for one accepted client connection; owned by HandleConnection().
struct connection_info_t {
  PRFileDesc* client_sock;
  PRNetAddr client_addr;
  server_info_t* server_info;
  // the original host in the Host: header for this connection is
  // stored here, for proxied connections
  string original_host;
  // true if no SSL should be used for this connection
  bool http_proxy_only;
  // true if this connection is for a WebSocket
  bool iswebsocket;
};

// Argument for the match_hostname() hash table enumerator.
struct server_match_t {
  string fullHost;
  bool matched;
};
const PRInt32 BUF_SIZE = 16384;
const PRInt32 BUF_MARGIN = 1024;
const PRInt32 BUF_TOTAL = BUF_SIZE + BUF_MARGIN;

/*
 * Fixed-size buffer holding data read from one socket that still has to be
 * written to the other one. Valid data is [bufferhead, buffertail). Reads
 * are limited to bufferend (BUF_SIZE bytes); the extra BUF_MARGIN bytes
 * after it leave room for in-place header rewriting and a terminating NUL.
 *
 * The object owns |buffer| and is therefore not copyable.
 */
struct relayBuffer
{
  char *buffer, *bufferhead, *buffertail, *bufferend;

  relayBuffer()
  {
    // Leave 1024 bytes more for request line manipulations
    bufferhead = buffertail = buffer = new char[BUF_TOTAL];
    bufferend = buffer + BUF_SIZE;
  }

  ~relayBuffer()
  {
    delete [] buffer;
  }

  // Rewind to the start of the allocation once all data has been consumed.
  void compact() {
    if (buffertail == bufferhead)
      buffertail = bufferhead = buffer;
  }

  bool empty() { return bufferhead == buffertail; }

  // Space still available for reading. Header rewriting may move buffertail
  // past bufferend into the margin; report 0 then instead of letting the
  // unsigned subtraction wrap around to a huge size.
  size_t areafree() {
    return buffertail < bufferend ? size_t(bufferend - buffertail) : 0;
  }

  // Total space left after buffertail, margin included.
  size_t margin() {
    char* limit = buffer + BUF_TOTAL;
    return buffertail < limit ? size_t(limit - buffertail) : 0;
  }

  size_t present() { return buffertail - bufferhead; }

private:
  // Not copyable: a copy would delete |buffer| a second time.
  relayBuffer(const relayBuffer&);
  relayBuffer& operator=(const relayBuffer&);
};
// A couple of stack classes for managing NSS/NSPR resources

// Owns a CERTCertificate reference and releases it with
// CERT_DestroyCertificate() on scope exit. Not copyable (a copy would
// release the certificate twice) and not implicitly constructible from a
// raw pointer (that would let a temporary take ownership by accident).
class AutoCert {
public:
  explicit AutoCert(CERTCertificate* cert) : cert_(cert) {}
  ~AutoCert() { if (cert_) CERT_DestroyCertificate(cert_); }
  operator CERTCertificate*() { return cert_; }
private:
  AutoCert(const AutoCert&);
  AutoCert& operator=(const AutoCert&);
  CERTCertificate* cert_;
};
// Owns a SECKEYPrivateKey and releases it with SECKEY_DestroyPrivateKey()
// on scope exit. Not copyable (a copy would destroy the key twice) and not
// implicitly constructible from a raw pointer.
class AutoKey {
public:
  explicit AutoKey(SECKEYPrivateKey* key) : key_(key) {}
  ~AutoKey() { if (key_) SECKEY_DestroyPrivateKey(key_); }
  operator SECKEYPrivateKey*() { return key_; }
private:
  AutoKey(const AutoKey&);
  AutoKey& operator=(const AutoKey&);
  SECKEYPrivateKey* key_;
};
// Owns an NSPR file descriptor. On destruction the socket is shut down in
// both directions and closed. reset() installs a new descriptor and hands
// the previous one back to the caller without closing it. Not copyable (a
// copy would close the descriptor twice) and not implicitly constructible
// from a raw pointer.
class AutoFD {
public:
  explicit AutoFD(PRFileDesc* fd) : fd_(fd) {}
  ~AutoFD() {
    if (fd_) {
      PR_Shutdown(fd_, PR_SHUTDOWN_BOTH);
      PR_Close(fd_);
    }
  }
  operator PRFileDesc*() { return fd_; }
  PRFileDesc* reset(PRFileDesc* newfd) {
    PRFileDesc* oldfd = fd_;
    fd_ = newfd;
    return oldfd;
  }
private:
  AutoFD(const AutoFD&);
  AutoFD& operator=(const AutoFD&);
  PRFileDesc* fd_;
};
// These numbers are multiplied by the number of listening ports (actual
// servers running). According to the thread pool implementation there is no
// need to limit the number of threads initially; threads are allocated
// dynamically and stored in a linked list. The initial number of 2 is chosen
// to allocate a thread for socket accept and to preallocate one for the
// first connection, which is very likely to arrive.
const PRUint32 INITIAL_THREADS = 2;
const PRUint32 MAX_THREADS = 100;
// 512 KiB; presumably the per-thread stack size for the thread pool, which
// is created outside this view — TODO confirm.
const PRUint32 DEFAULT_STACKSIZE = (512 * 1024);
// global data
// NSS certificate database directory, from the "certdbdir" config option.
string nssconfigdir;
// One entry per listening port, added by "listen" config options.
vector<server_info_t> servers;
// Address of the real server that connections are forwarded to ("forward").
PRNetAddr remote_addr;
// Optional separate target for WebSocket connections ("websocketserver");
// a zero port means it is not configured and remote_addr is used instead.
PRNetAddr websocket_server;
// Thread pool that StartServer() queues HandleConnection() jobs on.
PRThreadPool* threads = NULL;
// Lock and condition variable used by SignalShutdown().
PRLock* shutdown_lock = NULL;
PRCondVar* shutdown_condvar = NULL;
// Not really used, unless something fails to start.
// StartServer()'s accept loop runs while this is false.
// NOTE(review): plain bool shared between threads without synchronization.
bool shutdown_server = false;
// Set by "httpproxy:1": expect a CONNECT request before tunneling.
bool do_http_proxy = false;
// Set when any host-specific "listen" or "clientauth" option was seen.
bool any_host_spec_config = false;
/*
 * Value comparator for the host_clientauth_table hash tables
 * (client_auth_option* values).
 *
 * PLHashTable comparators return nonzero when the two values are EQUAL and
 * zero otherwise, the same convention as PL_CompareStrings/PL_CompareValues.
 * PL_HashTableAdd() uses this result to decide whether an existing entry
 * already holds the new value. The previous strcmp-style result (0 for
 * equal) inverted that test, so a later "clientauth" line with a different
 * option for an already configured host was silently ignored.
 */
PR_CALLBACK PRIntn ClientAuthValueComparator(const void *v1, const void *v2)
{
  const client_auth_option a = *static_cast<const client_auth_option*>(v1);
  const client_auth_option b = *static_cast<const client_auth_option*>(v2);
  return a == b;
}
// PL_HashTableEnumerateEntries() callback. Sets match->matched when the
// entry's key (a "host:port" string) occurs anywhere in match->fullHost.
// The enumeration always continues.
static PRIntn match_hostname(PLHashEntry *he, PRIntn index, void* arg)
{
  server_match_t* const match = static_cast<server_match_t*>(arg);
  const char* const hostKey = static_cast<const char*>(he->key);
  const bool found = match->fullHost.find(hostKey) != string::npos;
  if (found)
    match->matched = true;
  return HT_ENUMERATE_NEXT;
}
/*
* Signal the main thread that the application should shut down.
*/
void SignalShutdown()
{
PR_Lock(shutdown_lock);
PR_NotifyCondVar(shutdown_condvar);
PR_Unlock(shutdown_lock);
}
/*
 * Parse the initial "CONNECT host:port HTTP/x.y" request collected in
 * |buffer| (http proxy mode).
 *
 * @return false if the request is not complete yet (fewer than 4 bytes, or
 *         the data does not end with CRLFCRLF); the caller should read more.
 *         true once a complete request has been examined. *result is then
 *         200 for a well-formed CONNECT request and 400 otherwise.
 *
 * *clientauth is reset to caNone before parsing. Once the host token is
 * parsed, |host| becomes "https://" + host:port, |certificate| is set to
 * the nickname configured for that host (left unchanged if there is none),
 * and *clientauth to the host's configured client auth option.
 * The request is tokenized in place, so the buffer content is modified.
 */
bool ReadConnectRequest(server_info_t* server_info,
  relayBuffer& buffer, PRInt32* result, string& certificate,
  client_auth_option* clientauth, string& host)
{
  if (buffer.present() < 4) {
    LOG_DEBUG((" !! only %d bytes present in the buffer", (int)buffer.present()));
    return false;
  }
  if (strncmp(buffer.buffertail-4, "\r\n\r\n", 4)) {
    LOG_ERRORD((" !! request is not tailed with CRLFCRLF but with %x %x %x %x",
               *(buffer.buffertail-4),
               *(buffer.buffertail-3),
               *(buffer.buffertail-2),
               *(buffer.buffertail-1)));
    return false;
  }

  LOG_DEBUG((" parsing initial connect request, dump:\n%.*s\n", (int)buffer.present(), buffer.bufferhead));

  // strtok2() scans until a NUL. Terminate the received data so a malformed
  // request cannot make it run past the end. Reads never advance buffertail
  // beyond bufferend, and the margin after bufferend leaves room for it.
  *buffer.buffertail = '\0';

  *result = 400;
  *clientauth = caNone;

  char* token;
  char* _caret;

  token = strtok2(buffer.bufferhead, " ", &_caret);
  if (!token) {
    LOG_ERRORD((" no space found"));
    return true;
  }
  if (strcmp(token, "CONNECT")) {
    LOG_ERRORD((" not CONNECT request but %s", token));
    return true;
  }

  token = strtok2(_caret, " ", &_caret);
  if (!token) {
    LOG_ERRORD((" no host found in the CONNECT request"));
    return true;
  }

  void* c = PL_HashTableLookup(server_info->host_cert_table, token);
  if (c)
    certificate = static_cast<char*>(c);

  host = "https://";
  host += token;

  c = PL_HashTableLookup(server_info->host_clientauth_table, token);
  if (c)
    *clientauth = *static_cast<client_auth_option*>(c);

  token = strtok2(_caret, "/", &_caret);
  if (!token || strcmp(token, "HTTP")) {
    LOG_ERRORD((" not tailed with HTTP but with %s", token ? token : ""));
    return true;
  }

  *result = 200;
  return true;
}
/*
 * Turn |socket| into an SSL server socket.
 *
 * The server certificate is the one named by |certificate| when it is
 * non-empty, otherwise si->cert_nickname; its private key is looked up from
 * that certificate. When |clientAuth| is not caNone a client certificate is
 * requested, and for caRequire it is also required.
 *
 * @return true on success; false (after logging) if the certificate or key
 *         cannot be found, or the socket cannot be imported or configured.
 */
bool ConfigureSSLServerSocket(PRFileDesc* socket, server_info_t* si, string &certificate, client_auth_option clientAuth)
{
  const char* certnick = si->cert_nickname.c_str();
  if (!certificate.empty())
    certnick = certificate.c_str();

  AutoCert cert(PK11_FindCertFromNickname(certnick, NULL));
  if (!cert) {
    LOG_ERROR(("Failed to find cert %s\n", certnick));
    return false;
  }

  AutoKey privKey(PK11_FindKeyByAnyCert(cert, NULL));
  if (!privKey) {
    LOG_ERROR(("Failed to find private key\n"));
    return false;
  }

  PRFileDesc* ssl_socket = SSL_ImportFD(NULL, socket);
  if (!ssl_socket) {
    LOG_ERROR(("Error importing SSL socket\n"));
    return false;
  }

  const SSLKEAType kea = NSS_FindCertKEAType(cert);
  const bool configured =
    SSL_ConfigSecureServer(ssl_socket, cert, privKey, kea) == SECSuccess;
  if (!configured) {
    LOG_ERROR(("Error configuring SSL server socket\n"));
    return false;
  }

  // Act as an SSL server only.
  SSL_OptionSet(ssl_socket, SSL_SECURITY, PR_TRUE);
  SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_CLIENT, PR_FALSE);
  SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_SERVER, PR_TRUE);

  if (clientAuth != caNone) {
    SSL_OptionSet(ssl_socket, SSL_REQUEST_CERTIFICATE, PR_TRUE);
    SSL_OptionSet(ssl_socket, SSL_REQUIRE_CERTIFICATE, clientAuth == caRequire);
  }

  SSL_ResetHandshake(ssl_socket, PR_TRUE);
  return true;
}
/**
 * This function examines the buffer for a Sec-WebSocket-Location: field,
 * and if it's present, it replaces the hostname in that field with the
 * value in the connection's original_host field. This function works
 * in the reverse direction as AdjustWebSocketHost(), replacing the real
 * hostname of a response with the potentially fake hostname that is expected
 * by the browser (e.g., mochi.test).
 *
 * The hostname is the text after "ws://" up to the next '/' on the same
 * line, or up to the end of the line when the URL has no path.
 *
 * @return true if the header was adjusted successfully, or not found, false
 *         if the header is present but its url (or the end of its line) is
 *         not yet, which should indicate that more data needs to be read
 *         from the socket
 */
bool AdjustWebSocketLocation(relayBuffer& buffer, connection_info_t *ci)
{
  assert(buffer.margin());

  // Cannot use strnchr, so terminate the received data (which ends just
  // before buffertail). There is always some space left because we
  // preserve a margin.
  *buffer.buffertail = '\0';

  char* wsloc = strstr(buffer.bufferhead, "Sec-WebSocket-Location:");
  if (!wsloc)
    return true;

  // advance pointer to the start of the hostname
  wsloc = strstr(wsloc, "ws://");
  if (!wsloc)
    return false;
  wsloc += 5;

  // The header line must be complete before it can be rewritten.
  char *crlf = strstr(wsloc, "\r\n");
  if (!crlf)
    return false;

  // find the end of the hostname, without running into later lines
  char* wslocend = strchr(wsloc + 1, '/');
  if (!wslocend || wslocend > crlf)
    wslocend = crlf;

  if (ci->original_host.empty())
    return true;

  int diff = int(ci->original_host.length()) - int(wslocend - wsloc);
  if (diff > 0)
    assert(size_t(diff) <= buffer.margin());

  // Shift everything from the end of the old hostname up to the end of the
  // data so the new hostname fits exactly.
  memmove(wslocend + diff, wslocend, buffer.buffertail - wslocend);
  buffer.buffertail += diff;

  memcpy(wsloc, ci->original_host.c_str(), ci->original_host.length());
  return true;
}
/**
 * This function examines the buffer for a Host: field, and if it's present,
 * it replaces the hostname in that field with the address of the server the
 * connection is really going to (websocket_server if configured, otherwise
 * remote_addr). This is needed because proxy requests may be coming
 * from mochitest with fake hosts, like mochi.test, and these need to be
 * replaced with the host that the destination server is actually running
 * on.
 *
 * The request must contain an "Upgrade: WebSocket" (or "websocket" /
 * "Websocket") header. The original Host value is saved in
 * ci->original_host so AdjustWebSocketLocation() can restore it in the
 * response.
 *
 * @return true if this is a WebSocket request and the Host header was
 *         rewritten, false otherwise.
 */
bool AdjustWebSocketHost(relayBuffer& buffer, connection_info_t *ci)
{
  const char HEADER_UPGRADE[] = "Upgrade:";
  const char HEADER_HOST[] = "Host:";

  PRNetAddr inet_addr = (websocket_server.inet.port ? websocket_server :
                         remote_addr);

  assert(buffer.margin());

  // Cannot use strnchr, so terminate the received data (which ends just
  // before buffertail). There is always some space left because we
  // preserve a margin.
  *buffer.buffertail = '\0';

  // Verify this is a WebSocket header.
  char* h1 = strstr(buffer.bufferhead, HEADER_UPGRADE);
  if (!h1)
    return false;
  h1 += strlen(HEADER_UPGRADE);
  h1 += strspn(h1, " \t");
  char* h2 = strstr(h1, "WebSocket\r\n");
  if (!h2) h2 = strstr(h1, "websocket\r\n");
  if (!h2) h2 = strstr(h1, "Websocket\r\n");
  if (!h2)
    return false;

  char* host = strstr(buffer.bufferhead, HEADER_HOST);
  if (!host)
    return false;
  // advance pointer to beginning of hostname
  host += strlen(HEADER_HOST);
  host += strspn(host, " \t");

  char* endhost = strstr(host, "\r\n");
  if (!endhost)
    return false;

  // Save the original host, so we can use it later on responses from the
  // server.
  ci->original_host.assign(host, endhost-host);

  char newhost[40];
  if (PR_NetAddrToString(&inet_addr, newhost, sizeof(newhost)) != PR_SUCCESS) {
    LOG_ERRORD((" could not format the WebSocket server address\n"));
    return false;
  }
  size_t addrlen = strlen(newhost);
  // Leave room for ":" + up to 5 port digits + NUL.
  assert(addrlen < sizeof(newhost) - 7);
  // Append ":port" after the address. Formatting newhost into itself, as
  // sprintf(newhost, "%s:%d", newhost, ...) did, is undefined behavior.
  sprintf(newhost + addrlen, ":%d", PR_ntohs(inet_addr.inet.port));
  size_t newhostlen = strlen(newhost);

  int diff = int(newhostlen) - int(endhost - host);
  if (diff > 0)
    assert(size_t(diff) <= buffer.margin());

  // Shift everything from the end of the old host up to the end of the data
  // so the new host fits exactly.
  memmove(endhost + diff, endhost, buffer.buffertail - endhost);
  buffer.buffertail += diff;

  memcpy(host, newhost, newhostlen);
  return true;
}
/**
 * This function prefixes Request-URI path with a full scheme-host-port
 * string, turning "GET /path HTTP/1.1" into "GET <host>/path HTTP/1.1"
 * where |host| is e.g. "https://example.com:443".
 *
 * @return true if the request line was rewritten, or if its Request-URI
 *         does not start with '/' (e.g. '*' or already an absolute URI) and
 *         needs no change; false if the request line is malformed.
 */
bool AdjustRequestURI(relayBuffer& buffer, string *host)
{
  assert(buffer.margin());

  // Cannot use strnchr, so terminate the received data (which ends just
  // before buffertail). There is always some space left because we
  // preserve a margin.
  *buffer.buffertail = '\0';

  LOG_DEBUG((" incoming request to adjust:\n%s\n", buffer.bufferhead));

  // Check strchr()'s result before stepping past the space; adding 1 to a
  // NULL result first would make the check useless.
  char* path = strchr(buffer.bufferhead, ' ');
  if (!path)
    return false;
  ++path;

  // If the path doesn't start with a slash don't change it, it is probably '*' or a full
  // path already. Return true, we are done with this request adjustment.
  if (*path != '/')
    return true;

  char* token = strchr(path, ' ');
  if (!token)
    return false;
  ++token;

  if (strncmp(token, "HTTP/", 5))
    return false;

  size_t hostlength = host->length();
  assert(hostlength <= buffer.margin());

  memmove(path + hostlength, path, buffer.buffertail - path);
  memcpy(path, host->c_str(), hostlength);
  buffer.buffertail += hostlength;

  return true;
}
/*
 * Connect |fd| to |addr|, waiting at most |timeout|, and once connected
 * switch the socket to non-blocking mode for the poll loop.
 *
 * @return false if PR_Connect fails (the socket option is then not touched).
 */
bool ConnectSocket(PRFileDesc *fd, const PRNetAddr *addr, PRIntervalTime timeout)
{
  if (PR_Connect(fd, addr, timeout) != PR_SUCCESS)
    return false;

  PRSocketOptionData nonBlocking;
  nonBlocking.option = PR_SockOpt_Nonblocking;
  nonBlocking.value.non_blocking = PR_TRUE;
  PR_SetSocketOption(fd, &nonBlocking);
  return true;
}
/*
* Handle an incoming client connection. The server thread has already
* accepted the connection, so we just need to connect to the remote
* port and then proxy data back and forth.
* The data parameter is a connection_info_t*, and must be deleted
* by this function.
*/
void HandleConnection(void* data)
{
connection_info_t* ci = static_cast<connection_info_t*>(data);
PRIntervalTime connect_timeout = PR_SecondsToInterval(30);
AutoFD other_sock(PR_NewTCPSocket());
bool client_done = false;
bool client_error = false;
bool connect_accepted = !do_http_proxy;
bool ssl_updated = !do_http_proxy;
bool expect_request_start = do_http_proxy;
string certificateToUse;
client_auth_option clientAuth;
string fullHost;
LOG_DEBUG(("SSLTUNNEL(%p)): incoming connection csock(0)=%p, ssock(1)=%p\n",
static_cast<void*>(data),
static_cast<void*>(ci->client_sock),
static_cast<void*>(other_sock)));
if (other_sock)
{
PRInt32 numberOfSockets = 1;
relayBuffer buffers[2];
if (!do_http_proxy)
{
if (!ci->http_proxy_only &&
!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, certificateToUse, caNone))
client_error = true;
else if (!ConnectSocket(other_sock, &remote_addr, connect_timeout))
client_error = true;
else
numberOfSockets = 2;
}
PRPollDesc sockets[2] =
{
{ci->client_sock, PR_POLL_READ, 0},
{other_sock, PR_POLL_READ, 0}
};
bool socketErrorState[2] = {false, false};
while (!((client_error||client_done) && buffers[0].empty() && buffers[1].empty()))
{
sockets[0].in_flags |= PR_POLL_EXCEPT;
sockets[1].in_flags |= PR_POLL_EXCEPT;
LOG_DEBUG(("SSLTUNNEL(%p)): polling flags csock(0)=%c%c, ssock(1)=%c%c\n",
static_cast<void*>(data),
sockets[0].in_flags & PR_POLL_READ ? 'R' : '-',
sockets[0].in_flags & PR_POLL_WRITE ? 'W' : '-',
sockets[1].in_flags & PR_POLL_READ ? 'R' : '-',
sockets[1].in_flags & PR_POLL_WRITE ? 'W' : '-'));
PRInt32 pollStatus = PR_Poll(sockets, numberOfSockets, PR_MillisecondsToInterval(1000));
if (pollStatus < 0)
{
LOG_DEBUG(("SSLTUNNEL(%p)): pollStatus=%d, exiting\n",
static_cast<void*>(data), pollStatus));
client_error = true;
break;
}
if (pollStatus == 0)
{
// timeout
LOG_DEBUG(("SSLTUNNEL(%p)): poll timeout, looping\n",
static_cast<void*>(data)));
continue;
}
for (PRInt32 s = 0; s < numberOfSockets; ++s)
{
PRInt32 s2 = s == 1 ? 0 : 1;
PRInt16 out_flags = sockets[s].out_flags;
PRInt16 &in_flags = sockets[s].in_flags;
PRInt16 &in_flags2 = sockets[s2].in_flags;
sockets[s].out_flags = 0;
LOG_BEGIN_BLOCK();
LOG_DEBUG(("SSLTUNNEL(%p)): %csock(%d)=%p out_flags=%d",
static_cast<void*>(data),
s == 0 ? 'c' : 's',
s,
static_cast<void*>(sockets[s].fd),
out_flags));
if (out_flags & (PR_POLL_EXCEPT | PR_POLL_ERR | PR_POLL_HUP))
{
LOG_DEBUG((" :exception\n"));
client_error = true;
socketErrorState[s] = PR_TRUE;
// We got a fatal error state on the socket. Clear the output buffer
// for this socket to break the main loop, we will never more be able
// to send those data anyway.
buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
continue;
2009-08-17 17:25:35 +00:00
} // PR_POLL_EXCEPT, PR_POLL_ERR, PR_POLL_HUP handling
if (out_flags & PR_POLL_READ && !buffers[s].areafree())
{
LOG_DEBUG((" no place in read buffer but got read flag, dropping it now!"));
in_flags &= ~PR_POLL_READ;
}
if (out_flags & PR_POLL_READ && buffers[s].areafree())
{
LOG_DEBUG((" :reading"));
PRInt32 bytesRead = PR_Recv(sockets[s].fd, buffers[s].buffertail,
buffers[s].areafree(), 0, PR_INTERVAL_NO_TIMEOUT);
if (bytesRead == 0)
{
LOG_DEBUG((" socket gracefully closed"));
client_done = true;
in_flags &= ~PR_POLL_READ;
}
else if (bytesRead < 0)
{
if (PR_GetError() != PR_WOULD_BLOCK_ERROR)
{
LOG_DEBUG((" error=%d", PR_GetError()));
// We are in error state, indicate that the connection was
// not closed gracefully
client_error = true;
socketErrorState[s] = PR_TRUE;
// Wipe out our send buffer, we cannot send it anyway.
buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
}
else
LOG_DEBUG((" would block"));
}
else
{
// If the other socket is in error state (unable to send/receive)
// throw this data away and continue loop
if (socketErrorState[s2])
{
LOG_DEBUG((" have read but other socket is in error state\n"));
continue;
}
buffers[s].buffertail += bytesRead;
LOG_DEBUG((", read %d bytes", bytesRead));
// We have to accept and handle the initial CONNECT request here
PRInt32 response;
if (!connect_accepted && ReadConnectRequest(ci->server_info, buffers[s],
&response, certificateToUse, &clientAuth, fullHost))
{
// Mark this as a proxy-only connection (no SSL) if the CONNECT
// request didn't come for port 443 or from any of the server's
// cert or clientauth hostnames.
if (fullHost.find(":443") == string::npos)
{
server_match_t match;
match.fullHost = fullHost;
match.matched = false;
PL_HashTableEnumerateEntries(ci->server_info->host_cert_table,
match_hostname,
&match);
PL_HashTableEnumerateEntries(ci->server_info->host_clientauth_table,
match_hostname,
&match);
ci->http_proxy_only = !match.matched;
}
else
{
ci->http_proxy_only = false;
}
// Clean the request as it would be read
buffers[s].bufferhead = buffers[s].buffertail = buffers[s].buffer;
in_flags |= PR_POLL_WRITE;
connect_accepted = true;
// Store response to the oposite buffer
if (response != 200)
{
LOG_ERRORD((" could not read the connect request, closing connection with %d", response));
client_done = true;
sprintf(buffers[s2].buffer, "HTTP/1.1 %d ERROR\r\nConnection: close\r\n\r\n", response);
buffers[s2].buffertail = buffers[s2].buffer + strlen(buffers[s2].buffer);
break;
}
strcpy(buffers[s2].buffer, "HTTP/1.1 200 Connected\r\nConnection: keep-alive\r\n\r\n");
buffers[s2].buffertail = buffers[s2].buffer + strlen(buffers[s2].buffer);
LOG_DEBUG((" accepted CONNECT request, connected to the server, sending OK to the client\n"));
// Send the response to the client socket
break;
} // end of CONNECT handling
if (!buffers[s].areafree())
{
// Do not poll for read when the buffer is full
LOG_DEBUG((" no place in our read buffer, stop reading"));
in_flags &= ~PR_POLL_READ;
}
if (ssl_updated)
{
if (s == 0 && expect_request_start)
{
if (!strstr(buffers[s].bufferhead, "\r\n\r\n"))
{
// We haven't received the complete header yet, so wait.
continue;
}
else
{
ci->iswebsocket = AdjustWebSocketHost(buffers[s], ci);
expect_request_start = !(ci->iswebsocket ||
AdjustRequestURI(buffers[s], &fullHost));
PRNetAddr* addr = &remote_addr;
if (ci->iswebsocket && websocket_server.inet.port)
addr = &websocket_server;
if (!ConnectSocket(other_sock, addr, connect_timeout))
{
LOG_ERRORD((" could not open connection to the real server\n"));
client_error = true;
break;
}
LOG_DEBUG(("\n connected to remote server\n"));
numberOfSockets = 2;
}
}
else if (s == 1 && ci->iswebsocket)
{
if (!AdjustWebSocketLocation(buffers[s], ci))
continue;
}
in_flags2 |= PR_POLL_WRITE;
LOG_DEBUG((" telling the other socket to write"));
}
else
LOG_DEBUG((" we have something for the other socket to write, but ssl has not been administered on it"));
}
} // PR_POLL_READ handling
if (out_flags & PR_POLL_WRITE)
{
LOG_DEBUG((" :writing"));
PRInt32 bytesWrite = PR_Send(sockets[s].fd, buffers[s2].bufferhead,
buffers[s2].present(), 0, PR_INTERVAL_NO_TIMEOUT);
if (bytesWrite < 0)
{
if (PR_GetError() != PR_WOULD_BLOCK_ERROR) {
LOG_DEBUG((" error=%d", PR_GetError()));
client_error = true;
socketErrorState[s] = PR_TRUE;
// We got a fatal error while writting the buffer. Clear it to break
// the main loop, we will never more be able to send it.
buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
}
else
LOG_DEBUG((" would block"));
}
else
{
LOG_DEBUG((", written %d bytes", bytesWrite));
buffers[s2].buffertail[1] = '\0';
LOG_DEBUG((" dump:\n%.*s\n", bytesWrite, buffers[s2].bufferhead));
buffers[s2].bufferhead += bytesWrite;
if (buffers[s2].present())
{
LOG_DEBUG((" still have to write %d bytes", (int)buffers[s2].present()));
in_flags |= PR_POLL_WRITE;
}
else
{
if (!ssl_updated)
{
LOG_DEBUG((" proxy response sent to the client"));
// Proxy response has just been writen, update to ssl
ssl_updated = true;
if (!ci->http_proxy_only &&
!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, certificateToUse, clientAuth))
{
LOG_ERRORD((" failed to config server socket\n"));
client_error = true;
break;
}
LOG_DEBUG((" client socket updated to SSL"));
} // sslUpdate
LOG_DEBUG((" dropping our write flag and setting other socket read flag"));
in_flags &= ~PR_POLL_WRITE;
in_flags2 |= PR_POLL_READ;
buffers[s2].compact();
}
}
} // PR_POLL_WRITE handling
LOG_END_BLOCK(); // end the log
} // for...
} // while, poll
}
else
client_error = true;
LOG_DEBUG(("SSLTUNNEL(%p)): exiting root function for csock=%p, ssock=%p\n",
static_cast<void*>(data),
static_cast<void*>(ci->client_sock),
static_cast<void*>(other_sock)));
if (!client_error)
PR_Shutdown(ci->client_sock, PR_SHUTDOWN_SEND);
PR_Close(ci->client_sock);
delete ci;
}
/*
* Start listening for SSL connections on a specified port, handing
* them off to client threads after accepting the connection.
* The data parameter is a server_info_t*, owned by the calling
* function.
*/
void StartServer(void* data)
{
server_info_t* si = static_cast<server_info_t*>(data);
//TODO: select ciphers?
AutoFD listen_socket(PR_NewTCPSocket());
if (!listen_socket) {
LOG_ERROR(("failed to create socket\n"));
SignalShutdown();
return;
}
// In case the socket is still open in the TIME_WAIT state from a previous
// instance of ssltunnel we ask to reuse the port.
PRSocketOptionData socket_option;
socket_option.option = PR_SockOpt_Reuseaddr;
socket_option.value.reuse_addr = PR_TRUE;
PR_SetSocketOption(listen_socket, &socket_option);
PRNetAddr server_addr;
PR_InitializeNetAddr(PR_IpAddrAny, si->listen_port, &server_addr);
if (PR_Bind(listen_socket, &server_addr) != PR_SUCCESS) {
LOG_ERROR(("failed to bind socket\n"));
SignalShutdown();
return;
}
if (PR_Listen(listen_socket, 1) != PR_SUCCESS) {
LOG_ERROR(("failed to listen on socket\n"));
SignalShutdown();
return;
}
LOG_INFO(("Server listening on port %d with cert %s\n", si->listen_port,
si->cert_nickname.c_str()));
while (!shutdown_server) {
connection_info_t* ci = new connection_info_t();
ci->server_info = si;
// block waiting for connections
ci->client_sock = PR_Accept(listen_socket, &ci->client_addr,
PR_INTERVAL_NO_TIMEOUT);
PRSocketOptionData option;
option.option = PR_SockOpt_Nonblocking;
option.value.non_blocking = PR_TRUE;
PR_SetSocketOption(ci->client_sock, &option);
if (ci->client_sock)
// Not actually using this PRJob*...
//PRJob* job =
PR_QueueJob(threads, HandleConnection, ci, PR_TRUE);
else
delete ci;
}
}
// bogus password func, just don't use passwords. :-P
// PK11 password callback: offers an empty password on the first attempt and
// gives up (returns NULL) when asked again with |retry| set.
char* password_func(PK11SlotInfo* slot, PRBool retry, void* arg)
{
  return retry ? NULL : PL_strdup("");
}
// Return the configured server listening on |portnumber|, or NULL if no
// "listen" option has created one. The result points into |servers|, so it
// is invalidated when that vector grows.
server_info_t* findServerInfo(int portnumber)
{
  for (size_t i = 0; i < servers.size(); ++i)
  {
    server_info_t& candidate = servers[i];
    if (candidate.listen_port == portnumber)
      return &candidate;
  }
  return NULL;
}
int processConfigLine(char* configLine)
{
if (*configLine == 0 || *configLine == '#')
return 0;
char* _caret;
char* keyword = strtok2(configLine, ":", &_caret);
// Configure usage of http/ssl tunneling proxy behavior
if (!strcmp(keyword, "httpproxy"))
{
char* value = strtok2(_caret, ":", &_caret);
if (!strcmp(value, "1"))
do_http_proxy = true;
return 0;
}
if (!strcmp(keyword, "websocketserver"))
{
char* ipstring = strtok2(_caret, ":", &_caret);
if (PR_StringToNetAddr(ipstring, &websocket_server) != PR_SUCCESS) {
LOG_ERROR(("Invalid IP address in proxy config: %s\n", ipstring));
return 1;
}
char* remoteport = strtok2(_caret, ":", &_caret);
int port = atoi(remoteport);
if (port <= 0) {
LOG_ERROR(("Invalid remote port in proxy config: %s\n", remoteport));
return 1;
}
websocket_server.inet.port = PR_htons(port);
return 0;
}
// Configure the forward address of the target server
if (!strcmp(keyword, "forward"))
{
char* ipstring = strtok2(_caret, ":", &_caret);
if (PR_StringToNetAddr(ipstring, &remote_addr) != PR_SUCCESS) {
LOG_ERROR(("Invalid remote IP address: %s\n", ipstring));
return 1;
}
char* serverportstring = strtok2(_caret, ":", &_caret);
int port = atoi(serverportstring);
if (port <= 0) {
LOG_ERROR(("Invalid remote port: %s\n", serverportstring));
return 1;
}
remote_addr.inet.port = PR_htons(port);
return 0;
}
// Configure all listen sockets and port+certificate bindings
if (!strcmp(keyword, "listen"))
{
char* hostname = strtok2(_caret, ":", &_caret);
char* hostportstring = NULL;
if (strcmp(hostname, "*"))
{
any_host_spec_config = true;
hostportstring = strtok2(_caret, ":", &_caret);
}
char* serverportstring = strtok2(_caret, ":", &_caret);
char* certnick = strtok2(_caret, ":", &_caret);
int port = atoi(serverportstring);
if (port <= 0) {
LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
return 1;
}
if (server_info_t* existingServer = findServerInfo(port))
{
char *certnick_copy = new char[strlen(certnick)+1];
char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2];
strcpy(hostname_copy, hostname);
strcat(hostname_copy, ":");
strcat(hostname_copy, hostportstring);
strcpy(certnick_copy, certnick);
PLHashEntry* entry = PL_HashTableAdd(existingServer->host_cert_table, hostname_copy, certnick_copy);
if (!entry) {
LOG_ERROR(("Out of memory"));
return 1;
}
}
else
{
server_info_t server;
server.cert_nickname = certnick;
server.listen_port = port;
server.host_cert_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, PL_CompareStrings, NULL, NULL);
if (!server.host_cert_table)
{
LOG_ERROR(("Internal, could not create hash table\n"));
return 1;
}
server.host_clientauth_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, ClientAuthValueComparator, NULL, NULL);
if (!server.host_clientauth_table)
{
LOG_ERROR(("Internal, could not create hash table\n"));
return 1;
}
servers.push_back(server);
}
return 0;
}
if (!strcmp(keyword, "clientauth"))
{
char* hostname = strtok2(_caret, ":", &_caret);
char* hostportstring = strtok2(_caret, ":", &_caret);
char* serverportstring = strtok2(_caret, ":", &_caret);
int port = atoi(serverportstring);
if (port <= 0) {
LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
return 1;
}
if (server_info_t* existingServer = findServerInfo(port))
{
char* authoptionstring = strtok2(_caret, ":", &_caret);
client_auth_option* authoption = new client_auth_option;
if (!authoption) {
LOG_ERROR(("Out of memory"));
return 1;
}
if (!strcmp(authoptionstring, "require"))
*authoption = caRequire;
else if (!strcmp(authoptionstring, "request"))
*authoption = caRequest;
else if (!strcmp(authoptionstring, "none"))
*authoption = caNone;
else
{
LOG_ERROR(("Incorrect client auth option modifier for host '%s'", hostname));
return 1;
}
any_host_spec_config = true;
char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2];
if (!hostname_copy) {
LOG_ERROR(("Out of memory"));
return 1;
}
strcpy(hostname_copy, hostname);
strcat(hostname_copy, ":");
strcat(hostname_copy, hostportstring);
PLHashEntry* entry = PL_HashTableAdd(existingServer->host_clientauth_table, hostname_copy, authoption);
if (!entry) {
LOG_ERROR(("Out of memory"));
return 1;
}
}
else
{
LOG_ERROR(("Server on port %d for client authentication option is not defined, use 'listen' option first", port));
return 1;
}
return 0;
}
// Configure the NSS certificate database directory
if (!strcmp(keyword, "certdbdir"))
{
nssconfigdir = strtok2(_caret, "\n", &_caret);
return 0;
}
LOG_ERROR(("Error: keyword \"%s\" unexpected\n", keyword));
return 1;
}
int parseConfigFile(const char* filePath)
{
FILE* f = fopen(filePath, "r");
if (!f)
return 1;
char buffer[1024], *b = buffer;
while (!feof(f))
{
char c;
fscanf(f, "%c", &c);
switch (c)
{
case '\n':
*b++ = 0;
if (processConfigLine(buffer))
return 1;
b = buffer;
case '\r':
continue;
default:
*b++ = c;
}
}
fclose(f);
// Check mandatory items
if (nssconfigdir.empty())
{
LOG_ERROR(("Error: missing path to NSS certification database\n,use certdbdir:<path> in the config file\n"));
return 1;
}
if (any_host_spec_config && !do_http_proxy)
{
LOG_ERROR(("Warning: any host-specific configurations are ignored, add httpproxy:1 to allow them\n"));
}
return 0;
}
/*
 * PL_HashTableEnumerateEntries() callback for a server's host_cert_table.
 * Frees the entry's key and value and asks the enumerator to drop the entry.
 * Both the key and the value are released with delete[], so they must have
 * been allocated as char arrays with new[].  |i| and |arg| are unused.
 */
PRIntn freeHostCertHashItems(PLHashEntry *he, PRIntn i, void *arg)
{
  char* keyChars = (char*)he->key;
  char* valueChars = (char*)he->value;
  delete [] keyChars;
  delete [] valueChars;
  return HT_ENUMERATE_REMOVE;
}
/*
 * PL_HashTableEnumerateEntries() callback for a server's
 * host_clientauth_table.  The table is filled by the "clientauth" option
 * handling:
 *   - each key is a new[]-allocated "host:port" string;
 *   - each value is a single new-allocated client_auth_option.
 * Both are freed here and the entry is removed.  |i| and |arg| are unused.
 */
PRIntn freeClientAuthHashItems(PLHashEntry *he, PRIntn i, void *arg)
{
  char* hostAndPort = (char*)he->key;
  client_auth_option* option = (client_auth_option*)he->value;
  delete [] hostAndPort;
  delete option;
  return HT_ENUMERATE_REMOVE;
}
int main(int argc, char** argv)
{
const char* configFilePath;
const char* logLevelEnv = PR_GetEnv("SSLTUNNEL_LOG_LEVEL");
gLogLevel = logLevelEnv ? (LogLevel)atoi(logLevelEnv) : LEVEL_INFO;
if (argc == 1)
configFilePath = "ssltunnel.cfg";
else
configFilePath = argv[1];
memset(&websocket_server, 0, sizeof(PRNetAddr));
if (parseConfigFile(configFilePath)) {
LOG_ERROR(("Error: config file \"%s\" missing or formating incorrect\n"
"Specify path to the config file as parameter to ssltunnel or \n"
"create ssltunnel.cfg in the working directory.\n\n"
"Example format of the config file:\n\n"
" # Enable http/ssl tunneling proxy-like behavior.\n"
" # If not specified ssltunnel simply does direct forward.\n"
" httpproxy:1\n\n"
" # Specify path to the certification database used.\n"
" certdbdir:/path/to/certdb\n\n"
" # Forward/proxy all requests in raw to 127.0.0.1:8888.\n"
" forward:127.0.0.1:8888\n\n"
" # Accept connections on port 4443 or 5678 resp. and authenticate\n"
" # to any host ('*') using the 'server cert' or 'server cert 2' resp.\n"
" listen:*:4443:server cert\n"
" listen:*:5678:server cert 2\n\n"
" # Accept connections on port 4443 and authenticate using\n"
" # 'a different cert' when target host is 'my.host.name:443'.\n"
" # This only works in httpproxy mode and has higher priority\n"
" # than the previous option.\n"
" listen:my.host.name:443:4443:a different cert\n\n"
" # To make a specific host require or just request a client certificate\n"
" # to authenticate use the following options. This can only be used\n"
" # in httpproxy mode and only after the 'listen' option has been\n"
" # specified. You also have to specify the tunnel listen port.\n"
" clientauth:requesting-client-cert.host.com:443:4443:request\n"
" clientauth:requiring-client-cert.host.com:443:4443:require\n"
" # Proxy WebSocket traffic to the server at 127.0.0.1:9999,\n"
" # instead of the server specified in the 'forward' option.\n"
" websocketserver:127.0.0.1:9999\n",
configFilePath));
return 1;
}
// create a thread pool to handle connections
threads = PR_CreateThreadPool(INITIAL_THREADS * servers.size(),
MAX_THREADS * servers.size(),
DEFAULT_STACKSIZE);
if (!threads) {
LOG_ERROR(("Failed to create thread pool\n"));
return 1;
}
shutdown_lock = PR_NewLock();
if (!shutdown_lock) {
LOG_ERROR(("Failed to create lock\n"));
PR_ShutdownThreadPool(threads);
return 1;
}
shutdown_condvar = PR_NewCondVar(shutdown_lock);
if (!shutdown_condvar) {
LOG_ERROR(("Failed to create condvar\n"));
PR_ShutdownThreadPool(threads);
PR_DestroyLock(shutdown_lock);
return 1;
}
PK11_SetPasswordFunc(password_func);
// Initialize NSS
if (NSS_Init(nssconfigdir.c_str()) != SECSuccess) {
PRInt32 errorlen = PR_GetErrorTextLength();
char* err = new char[errorlen+1];
PR_GetErrorText(err);
LOG_ERROR(("Failed to init NSS: %s", err));
delete[] err;
PR_ShutdownThreadPool(threads);
PR_DestroyCondVar(shutdown_condvar);
PR_DestroyLock(shutdown_lock);
return 1;
}
if (NSS_SetDomesticPolicy() != SECSuccess) {
LOG_ERROR(("NSS_SetDomesticPolicy failed\n"));
PR_ShutdownThreadPool(threads);
PR_DestroyCondVar(shutdown_condvar);
PR_DestroyLock(shutdown_lock);
NSS_Shutdown();
return 1;
}
// these values should make NSS use the defaults
if (SSL_ConfigServerSessionIDCache(0, 0, 0, NULL) != SECSuccess) {
LOG_ERROR(("SSL_ConfigServerSessionIDCache failed\n"));
PR_ShutdownThreadPool(threads);
PR_DestroyCondVar(shutdown_condvar);
PR_DestroyLock(shutdown_lock);
NSS_Shutdown();
return 1;
}
for (vector<server_info_t>::iterator it = servers.begin();
it != servers.end(); it++) {
// Not actually using this PRJob*...
// PRJob* server_job =
PR_QueueJob(threads, StartServer, &(*it), PR_TRUE);
}
// now wait for someone to tell us to quit
PR_Lock(shutdown_lock);
PR_WaitCondVar(shutdown_condvar, PR_INTERVAL_NO_TIMEOUT);
PR_Unlock(shutdown_lock);
shutdown_server = true;
LOG_INFO(("Shutting down...\n"));
// cleanup
PR_ShutdownThreadPool(threads);
PR_JoinThreadPool(threads);
PR_DestroyCondVar(shutdown_condvar);
PR_DestroyLock(shutdown_lock);
if (NSS_Shutdown() == SECFailure) {
LOG_DEBUG(("Leaked NSS objects!\n"));
}
for (vector<server_info_t>::iterator it = servers.begin();
it != servers.end(); it++)
{
PL_HashTableEnumerateEntries(it->host_cert_table, freeHostCertHashItems, NULL);
PL_HashTableEnumerateEntries(it->host_clientauth_table, freeClientAuthHashItems, NULL);
PL_HashTableDestroy(it->host_cert_table);
PL_HashTableDestroy(it->host_clientauth_table);
}
PR_Cleanup();
return 0;
}