mirror of
https://github.com/mozilla/gecko-dev.git
synced 2024-11-08 20:47:44 +00:00
1097 lines
33 KiB
C++
1097 lines
33 KiB
C++
/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
|
|
/* ***** BEGIN LICENSE BLOCK *****
|
|
* Version: MPL 1.1/GPL 2.0/LGPL 2.1
|
|
*
|
|
* The contents of this file are subject to the Mozilla Public License Version
|
|
* 1.1 (the "License"); you may not use this file except in compliance with
|
|
* the License. You may obtain a copy of the License at
|
|
* http://www.mozilla.org/MPL/
|
|
*
|
|
* Software distributed under the License is distributed on an "AS IS" basis,
|
|
* WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
|
|
* for the specific language governing rights and limitations under the
|
|
* License.
|
|
*
|
|
* The Original Code is Mozilla test code
|
|
*
|
|
* The Initial Developer of the Original Code is
|
|
* Mozilla Foundation
|
|
* Portions created by the Initial Developer are Copyright (C) 2008
|
|
* the Initial Developer. All Rights Reserved.
|
|
*
|
|
* Contributor(s):
|
|
* Ted Mielczarek <ted.mielczarek@gmail.com>
|
|
* Honza Bambas <honzab@firemni.cz>
|
|
*
|
|
* Alternatively, the contents of this file may be used under the terms of
|
|
* either the GNU General Public License Version 2 or later (the "GPL"), or
|
|
* the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
|
|
* in which case the provisions of the GPL or the LGPL are applicable instead
|
|
* of those above. If you wish to allow use of your version of this file only
|
|
* under the terms of either the GPL or the LGPL, and not to allow others to
|
|
* use your version of this file under the terms of the MPL, indicate your
|
|
* decision by deleting the provisions above and replace them with the notice
|
|
* and other provisions required by the GPL or the LGPL. If you do not delete
|
|
* the provisions above, a recipient may use your version of this file under
|
|
* the terms of any one of the MPL, the GPL or the LGPL.
|
|
*
|
|
* ***** END LICENSE BLOCK ***** */
|
|
|
|
/*
|
|
* WARNING: DO NOT USE THIS CODE IN PRODUCTION SYSTEMS. It is highly likely to
|
|
* be plagued with the usual problems endemic to C (buffer overflows
|
|
* and the like). We don't especially care here (but would accept
|
|
* patches!) because this is only intended for use in our test
|
|
* harnesses in controlled situations where input is guaranteed not to
|
|
* be malicious.
|
|
*/
|
|
|
|
#include <assert.h>
|
|
#include <stdio.h>
|
|
#include <string>
|
|
#include <vector>
|
|
#include <algorithm>
|
|
#include "prinit.h"
|
|
#include "prerror.h"
|
|
#include "prio.h"
|
|
#include "prnetdb.h"
|
|
#include "prtpool.h"
|
|
#include "prtypes.h"
|
|
#include "nss.h"
|
|
#include "pk11func.h"
|
|
#include "key.h"
|
|
#include "keyt.h"
|
|
#include "ssl.h"
|
|
#include "plhash.h"
|
|
|
|
using std::string;
|
|
using std::vector;
|
|
|
|
#define IS_DELIM(m, c) ((m)[(c) >> 3] & (1 << ((c) & 7)))
|
|
#define SET_DELIM(m, c) ((m)[(c) >> 3] |= (1 << ((c) & 7)))
|
|
#define DELIM_TABLE_SIZE 32
|
|
|
|
// Copied from nsCRT
|
|
char* strtok2(char* string, const char* delims, char* *newStr)
|
|
{
|
|
PR_ASSERT(string);
|
|
|
|
char delimTable[DELIM_TABLE_SIZE];
|
|
PRUint32 i;
|
|
char* result;
|
|
char* str = string;
|
|
|
|
for (i = 0; i < DELIM_TABLE_SIZE; i++)
|
|
delimTable[i] = '\0';
|
|
|
|
for (i = 0; delims[i]; i++) {
|
|
SET_DELIM(delimTable, static_cast<PRUint8>(delims[i]));
|
|
}
|
|
|
|
// skip to beginning
|
|
while (*str && IS_DELIM(delimTable, static_cast<PRUint8>(*str))) {
|
|
str++;
|
|
}
|
|
result = str;
|
|
|
|
// fix up the end of the token
|
|
while (*str) {
|
|
if (IS_DELIM(delimTable, static_cast<PRUint8>(*str))) {
|
|
*str++ = '\0';
|
|
break;
|
|
}
|
|
str++;
|
|
}
|
|
*newStr = str;
|
|
|
|
return str == result ? NULL : result;
|
|
}
|
|
|
|
|
|
|
|
// How the SSL server side treats client certificates for a target host.
enum client_auth_option {
  caNone,      // 0: do not ask the client for a certificate
  caRequire,   // 1: request a client certificate and require it
  caRequest    // 2: request a client certificate but do not require it
};
|
|
|
|
// Structs for passing data into jobs on the thread pool
|
|
typedef struct {
|
|
PRInt32 listen_port;
|
|
string cert_nickname;
|
|
PLHashTable* host_cert_table;
|
|
PLHashTable* host_clientauth_table;
|
|
} server_info_t;
|
|
|
|
typedef struct {
|
|
PRFileDesc* client_sock;
|
|
PRNetAddr client_addr;
|
|
server_info_t* server_info;
|
|
} connection_info_t;
|
|
|
|
const PRInt32 BUF_SIZE = 16384;
|
|
const PRInt32 BUF_MARGIN = 1024;
|
|
const PRInt32 BUF_TOTAL = BUF_SIZE + BUF_MARGIN;
|
|
|
|
struct relayBuffer
|
|
{
|
|
char *buffer, *bufferhead, *buffertail, *bufferend;
|
|
|
|
relayBuffer()
|
|
{
|
|
// Leave 1024 bytes more for request line manipulations
|
|
bufferhead = buffertail = buffer = new char[BUF_TOTAL];
|
|
bufferend = buffer + BUF_SIZE;
|
|
}
|
|
|
|
~relayBuffer()
|
|
{
|
|
delete [] buffer;
|
|
}
|
|
|
|
void compact() {
|
|
if (buffertail == bufferhead)
|
|
buffertail = bufferhead = buffer;
|
|
}
|
|
|
|
bool empty() { return bufferhead == buffertail; }
|
|
size_t free() { return bufferend - buffertail; }
|
|
size_t margin() { return free() + BUF_MARGIN; }
|
|
size_t present() { return buffertail - bufferhead; }
|
|
};
|
|
|
|
// A couple of stack classes for managing NSS/NSPR resources
|
|
class AutoCert {
|
|
public:
|
|
AutoCert(CERTCertificate* cert) { cert_ = cert; }
|
|
~AutoCert() { if (cert_) CERT_DestroyCertificate(cert_); }
|
|
operator CERTCertificate*() { return cert_; }
|
|
private:
|
|
CERTCertificate* cert_;
|
|
};
|
|
|
|
class AutoKey {
|
|
public:
|
|
AutoKey(SECKEYPrivateKey* key) { key_ = key; }
|
|
~AutoKey() { if (key_) SECKEY_DestroyPrivateKey(key_); }
|
|
operator SECKEYPrivateKey*() { return key_; }
|
|
private:
|
|
SECKEYPrivateKey* key_;
|
|
};
|
|
|
|
class AutoFD {
|
|
public:
|
|
AutoFD(PRFileDesc* fd) { fd_ = fd; }
|
|
~AutoFD() {
|
|
if (fd_) {
|
|
PR_Shutdown(fd_, PR_SHUTDOWN_BOTH);
|
|
PR_Close(fd_);
|
|
}
|
|
}
|
|
operator PRFileDesc*() { return fd_; }
|
|
PRFileDesc* reset(PRFileDesc* newfd) {
|
|
PRFileDesc* oldfd = fd_;
|
|
fd_ = newfd;
|
|
return oldfd;
|
|
}
|
|
private:
|
|
PRFileDesc* fd_;
|
|
};
|
|
|
|
// These are suggestions. If the number of ports to proxy on * 2
|
|
// is greater than either of these, then we'll use that value instead.
|
|
const PRUint32 INITIAL_THREADS = 1;
|
|
const PRUint32 MAX_THREADS = 5;
|
|
const PRUint32 DEFAULT_STACKSIZE = (512 * 1024);
|
|
|
|
// global data
|
|
string nssconfigdir;
|
|
vector<server_info_t> servers;
|
|
PRNetAddr remote_addr;
|
|
PRThreadPool* threads = NULL;
|
|
PRLock* shutdown_lock = NULL;
|
|
PRCondVar* shutdown_condvar = NULL;
|
|
// Not really used, unless something fails to start
|
|
bool shutdown_server = false;
|
|
bool do_http_proxy = false;
|
|
bool any_host_spec_config = false;
|
|
|
|
/*
 * Value comparator for the host_clientauth_table hash tables; both
 * arguments point to client_auth_option values.
 *
 * PL_HashTableAdd calls the value comparator when the key is already present.
 * A nonzero result means "same value, keep the entry as it is"; zero means
 * "different value, store the new one". PL_CompareStrings, used for
 * host_cert_table, follows the same contract. So this must return nonzero
 * exactly when the two options are equal. A strcmp-style ordering would make
 * a repeated clientauth line with a different option be ignored.
 */
PR_CALLBACK PRIntn ClientAuthValueComparator(const void *v1, const void *v2)
{
  const client_auth_option first = *static_cast<const client_auth_option*>(v1);
  const client_auth_option second = *static_cast<const client_auth_option*>(v2);
  return first == second;
}
|
|
|
|
/*
|
|
* Signal the main thread that the application should shut down.
|
|
*/
|
|
void SignalShutdown()
|
|
{
|
|
PR_Lock(shutdown_lock);
|
|
PR_NotifyCondVar(shutdown_condvar);
|
|
PR_Unlock(shutdown_lock);
|
|
}
|
|
|
|
/*
 * Parse the initial "CONNECT host:port HTTP/x.y" request the client sent to
 * the proxy.
 *
 * Returns false while the request is incomplete, i.e. the buffered data does
 * not end with CRLFCRLF yet. The caller reads more data and calls again.
 *
 * Returns true once the request has been parsed:
 *  - |*result| is 200 for a well-formed CONNECT request and 400 otherwise.
 *  - |*clientauth| is the host-specific client auth option, or caNone.
 *  - |certificate| receives the host-specific certificate nickname, if one
 *    is configured for the host.
 *  - |host| becomes "https://host:port".
 *
 * The buffer contents are modified: they are NUL-terminated and tokenized
 * in place.
 */
bool ReadConnectRequest(server_info_t* server_info,
  relayBuffer& buffer, PRInt32* result, string& certificate,
  client_auth_option* clientauth, string& host)
{
  if (buffer.present() < 4) {
    printf(" !! only %d bytes present in the buffer", (int)buffer.present());
    return false;
  }
  if (strncmp(buffer.buffertail-4, "\r\n\r\n", 4)) {
    printf(" !! request is not tailed with CRLFCRLF but with %x %x %x %x",
           *(buffer.buffertail-4),
           *(buffer.buffertail-3),
           *(buffer.buffertail-2),
           *(buffer.buffertail-1));
    return false;
  }

  printf(" parsing initial connect request, dump:\n%.*s\n", (int)buffer.present(), buffer.bufferhead);

  *result = 400;
  // Give the caller a defined value even when the request is rejected.
  *clientauth = caNone;

  // strtok2 scans up to a NUL, but received data is not NUL-terminated.
  // There is room for the terminator: reads never fill past bufferend and
  // the buffer keeps a margin after it.
  *buffer.buffertail = '\0';

  char* token;
  char* _caret;
  token = strtok2(buffer.bufferhead, " ", &_caret);
  if (!token) {
    printf(" no space found");
    return true;
  }
  if (strcmp(token, "CONNECT")) {
    printf(" not CONNECT request but %s", token);
    return true;
  }

  // "host:port" of the tunnel target.
  token = strtok2(_caret, " ", &_caret);
  if (!token) {
    printf(" no host specified");
    return true;
  }

  void* c = PL_HashTableLookup(server_info->host_cert_table, token);
  if (c)
    certificate = static_cast<char*>(c);

  host = "https://";
  host += token;

  c = PL_HashTableLookup(server_info->host_clientauth_table, token);
  if (c)
    *clientauth = *static_cast<client_auth_option*>(c);
  else
    *clientauth = caNone;

  token = strtok2(_caret, "/", &_caret);
  if (!token || strcmp(token, "HTTP")) {
    printf(" not tailed with HTTP but with %s", token ? token : "(nothing)");
    return true;
  }

  *result = 200;
  return true;
}
|
|
|
|
/*
 * Layer SSL on top of |socket| and configure it as an SSL server.
 *
 * The server certificate is looked up by nickname: |certificate| when it is
 * not empty (a host-specific certificate), otherwise si->cert_nickname. Its
 * private key is found with PK11_FindKeyByAnyCert. Client certificates are
 * requested unless |clientAuth| is caNone, and required when it is
 * caRequire. The handshake is then reset to run as the server side.
 *
 * Returns false, after printing a message to stderr, when the certificate
 * or key cannot be found or the SSL import/configuration fails.
 */
bool ConfigureSSLServerSocket(PRFileDesc* socket, server_info_t* si, string &certificate, client_auth_option clientAuth)
{
  const char* nickname;
  if (certificate.empty())
    nickname = si->cert_nickname.c_str();
  else
    nickname = certificate.c_str();

  AutoCert serverCert(PK11_FindCertFromNickname(nickname, NULL));
  if (!serverCert) {
    fprintf(stderr, "Failed to find cert %s\n", nickname);
    return false;
  }

  AutoKey serverKey(PK11_FindKeyByAnyCert(serverCert, NULL));
  if (!serverKey) {
    fprintf(stderr, "Failed to find private key\n");
    return false;
  }

  PRFileDesc* sslFd = SSL_ImportFD(NULL, socket);
  if (!sslFd) {
    fprintf(stderr, "Error importing SSL socket\n");
    return false;
  }

  const SSLKEAType keaType = NSS_FindCertKEAType(serverCert);
  const SECStatus configStatus =
    SSL_ConfigSecureServer(sslFd, serverCert, serverKey, keaType);
  if (configStatus != SECSuccess) {
    fprintf(stderr, "Error configuring SSL server socket\n");
    return false;
  }

  SSL_OptionSet(sslFd, SSL_SECURITY, PR_TRUE);
  SSL_OptionSet(sslFd, SSL_HANDSHAKE_AS_CLIENT, PR_FALSE);
  SSL_OptionSet(sslFd, SSL_HANDSHAKE_AS_SERVER, PR_TRUE);

  if (clientAuth != caNone) {
    SSL_OptionSet(sslFd, SSL_REQUEST_CERTIFICATE, PR_TRUE);
    SSL_OptionSet(sslFd, SSL_REQUIRE_CERTIFICATE, clientAuth == caRequire);
  }

  SSL_ResetHandshake(sslFd, PR_TRUE);
  return true;
}
|
|
|
|
/**
|
|
* This function prefixes Request-URI path with a full scheme-host-port
|
|
* string.
|
|
*/
|
|
bool AdjustRequestURI(relayBuffer& buffer, string *host)
|
|
{
|
|
assert(buffer.margin());
|
|
|
|
// Cannot use strnchr so add a null char at the end. There is always some space left
|
|
// because we preserve a margin.
|
|
buffer.buffertail[1] = '\0';
|
|
printf(" incoming request to adjust:\n%s\n", buffer.bufferhead);
|
|
|
|
char *token, *path;
|
|
path = strchr(buffer.bufferhead, ' ') + 1;
|
|
if (!path)
|
|
return false;
|
|
|
|
// If the path doesn't start with a slash don't change it, it is probably '*' or a full
|
|
// path already. Return true, we are done with this request adjustment.
|
|
if (*path != '/')
|
|
return true;
|
|
|
|
token = strchr(path, ' ') + 1;
|
|
if (!token)
|
|
return false;
|
|
|
|
if (strncmp(token, "HTTP/", 5))
|
|
return false;
|
|
|
|
size_t hostlength = host->length();
|
|
assert(hostlength <= buffer.margin());
|
|
|
|
memmove(path + hostlength, path, buffer.buffertail - path);
|
|
memcpy(path, host->c_str(), hostlength);
|
|
buffer.buffertail += hostlength;
|
|
|
|
return true;
|
|
}
|
|
|
|
/*
 * Connect |fd| to |addr|, waiting at most |timeout|, and then put the socket
 * into non-blocking mode for the poll loop. The result of setting the
 * option is not checked.
 *
 * Returns false when PR_Connect fails, true otherwise.
 */
bool ConnectSocket(PRFileDesc *fd, const PRNetAddr *addr, PRIntervalTime timeout)
{
  if (PR_Connect(fd, addr, timeout) != PR_SUCCESS)
    return false;

  PRSocketOptionData nonBlocking;
  nonBlocking.option = PR_SockOpt_Nonblocking;
  nonBlocking.value.non_blocking = PR_TRUE;
  PR_SetSocketOption(fd, &nonBlocking);
  return true;
}
|
|
|
|
/*
|
|
* Handle an incoming client connection. The server thread has already
|
|
* accepted the connection, so we just need to connect to the remote
|
|
* port and then proxy data back and forth.
|
|
* The data parameter is a connection_info_t*, and must be deleted
|
|
* by this function.
|
|
*/
|
|
void HandleConnection(void* data)
|
|
{
|
|
connection_info_t* ci = static_cast<connection_info_t*>(data);
|
|
PRIntervalTime connect_timeout = PR_SecondsToInterval(2);
|
|
|
|
AutoFD other_sock(PR_NewTCPSocket());
|
|
bool client_done = false;
|
|
bool client_error = false;
|
|
bool connect_accepted = !do_http_proxy;
|
|
bool ssl_updated = !do_http_proxy;
|
|
bool expect_request_start = do_http_proxy;
|
|
string certificateToUse;
|
|
client_auth_option clientAuth;
|
|
string fullHost;
|
|
|
|
printf("SSLTUNNEL(%p): incoming connection csock(0)=%p, ssock(1)=%p\n", data, ci->client_sock, (PRFileDesc*)other_sock);
|
|
if (other_sock)
|
|
{
|
|
PRInt32 numberOfSockets = 1;
|
|
|
|
relayBuffer buffers[2];
|
|
|
|
if (!do_http_proxy)
|
|
{
|
|
if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, certificateToUse, caNone))
|
|
client_error = true;
|
|
else if (!ConnectSocket(other_sock, &remote_addr, connect_timeout))
|
|
client_error = true;
|
|
else
|
|
numberOfSockets = 2;
|
|
}
|
|
|
|
PRPollDesc sockets[2] =
|
|
{
|
|
{ci->client_sock, PR_POLL_READ, 0},
|
|
{other_sock, PR_POLL_READ, 0}
|
|
};
|
|
PRBool socketErrorState[2] = {PR_FALSE, PR_FALSE};
|
|
|
|
while (!((client_error||client_done) && buffers[0].empty() && buffers[1].empty()))
|
|
{
|
|
sockets[0].in_flags |= PR_POLL_EXCEPT;
|
|
sockets[1].in_flags |= PR_POLL_EXCEPT;
|
|
printf("SSLTUNNEL(%p): polling flags csock(0)=%c%c, ssock(1)=%c%c\n", data,
|
|
sockets[0].in_flags & PR_POLL_READ ? 'R' : '-',
|
|
sockets[0].in_flags & PR_POLL_WRITE ? 'W' : '-',
|
|
sockets[1].in_flags & PR_POLL_READ ? 'R' : '-',
|
|
sockets[1].in_flags & PR_POLL_WRITE ? 'W' : '-');
|
|
PRInt32 pollStatus = PR_Poll(sockets, numberOfSockets, PR_MillisecondsToInterval(1000));
|
|
if (pollStatus < 0)
|
|
{
|
|
printf("SSLTUNNEL(%p): pollStatus=%d, exiting\n", data, pollStatus);
|
|
client_error = true;
|
|
break;
|
|
}
|
|
|
|
if (pollStatus == 0)
|
|
{
|
|
// timeout
|
|
printf("SSLTUNNEL(%p): poll timeout, looping\n", data);
|
|
continue;
|
|
}
|
|
|
|
for (PRInt32 s = 0; s < numberOfSockets; ++s)
|
|
{
|
|
PRInt32 s2 = s == 1 ? 0 : 1;
|
|
PRInt16 out_flags = sockets[s].out_flags;
|
|
PRInt16 &in_flags = sockets[s].in_flags;
|
|
PRInt16 &in_flags2 = sockets[s2].in_flags;
|
|
sockets[s].out_flags = 0;
|
|
|
|
printf("SSLTUNNEL(%p): %csock(%d)=%p out_flags=%d", data, s==0?'c':'s', s, sockets[s].fd, out_flags);
|
|
if (out_flags & (PR_POLL_EXCEPT | PR_POLL_ERR | PR_POLL_HUP))
|
|
{
|
|
printf(" :exception\n");
|
|
client_error = true;
|
|
socketErrorState[s] = PR_TRUE;
|
|
// We got a fatal error state on the socket. Clear the output buffer
|
|
// for this socket to break the main loop, we will never more be able
|
|
// to send those data anyway.
|
|
buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
|
|
continue;
|
|
} // PR_POLL_EXCEPT, PR_POLL_ERR, PR_POLL_HUP handling
|
|
|
|
if (out_flags & PR_POLL_READ && !buffers[s].free())
|
|
{
|
|
printf(" no place in read buffer but got read flag, dropping it now!");
|
|
in_flags &= ~PR_POLL_READ;
|
|
}
|
|
|
|
if (out_flags & PR_POLL_READ && buffers[s].free())
|
|
{
|
|
printf(" :reading");
|
|
PRInt32 bytesRead = PR_Recv(sockets[s].fd, buffers[s].buffertail,
|
|
buffers[s].free(), 0, PR_INTERVAL_NO_TIMEOUT);
|
|
|
|
if (bytesRead == 0)
|
|
{
|
|
printf(" socket gracefully closed");
|
|
client_done = true;
|
|
in_flags &= ~PR_POLL_READ;
|
|
}
|
|
else if (bytesRead < 0)
|
|
{
|
|
if (PR_GetError() != PR_WOULD_BLOCK_ERROR)
|
|
{
|
|
printf(" error=%d", PR_GetError());
|
|
// We are in error state, indicate that the connection was
|
|
// not closed gracefully
|
|
client_error = true;
|
|
socketErrorState[s] = PR_TRUE;
|
|
// Wipe out our send buffer, we cannot send it anyway.
|
|
buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
|
|
}
|
|
else
|
|
printf(" would block");
|
|
}
|
|
else
|
|
{
|
|
// If the other socket is in error state (unable to send/receive)
|
|
// throw this data away and continue loop
|
|
if (socketErrorState[s2])
|
|
{
|
|
printf(" have read but other socket is in error state\n");
|
|
continue;
|
|
}
|
|
|
|
buffers[s].buffertail += bytesRead;
|
|
printf(", read %d bytes", bytesRead);
|
|
|
|
// We have to accept and handle the initial CONNECT request here
|
|
PRInt32 response;
|
|
if (!connect_accepted && ReadConnectRequest(ci->server_info, buffers[s],
|
|
&response, certificateToUse, &clientAuth, fullHost))
|
|
{
|
|
// Clean the request as it would be read
|
|
buffers[s].bufferhead = buffers[s].buffertail = buffers[s].buffer;
|
|
in_flags |= PR_POLL_WRITE;
|
|
connect_accepted = true;
|
|
|
|
// Store response to the oposite buffer
|
|
if (response != 200)
|
|
{
|
|
printf(" could not read the connect request, closing connection with %d", response);
|
|
client_done = true;
|
|
sprintf(buffers[s2].buffer, "HTTP/1.1 %d ERROR\r\nConnection: close\r\n\r\n", response);
|
|
buffers[s2].buffertail = buffers[s2].buffer + strlen(buffers[s2].buffer);
|
|
break;
|
|
}
|
|
|
|
strcpy(buffers[s2].buffer, "HTTP/1.1 200 Connected\r\nConnection: keep-alive\r\n\r\n");
|
|
buffers[s2].buffertail = buffers[s2].buffer + strlen(buffers[s2].buffer);
|
|
|
|
if (!ConnectSocket(other_sock, &remote_addr, connect_timeout))
|
|
{
|
|
printf(" could not open connection to the real server\n");
|
|
client_error = true;
|
|
break;
|
|
}
|
|
|
|
printf(" accepted CONNECT request, connected to the server, sending OK to the client\n");
|
|
// Send the response to the client socket
|
|
break;
|
|
} // end of CONNECT handling
|
|
|
|
if (!buffers[s].free())
|
|
{
|
|
// Do not poll for read when the buffer is full
|
|
printf(" no place in our read buffer, stop reading");
|
|
in_flags &= ~PR_POLL_READ;
|
|
}
|
|
|
|
if (ssl_updated)
|
|
{
|
|
if (s == 0 && expect_request_start)
|
|
expect_request_start = !AdjustRequestURI(buffers[s], &fullHost);
|
|
|
|
in_flags2 |= PR_POLL_WRITE;
|
|
printf(" telling the other socket to write");
|
|
}
|
|
else
|
|
printf(" we have something for the other socket to write, but ssl has not been administered on it");
|
|
}
|
|
} // PR_POLL_READ handling
|
|
|
|
if (out_flags & PR_POLL_WRITE)
|
|
{
|
|
printf(" :writting");
|
|
PRInt32 bytesWrite = PR_Send(sockets[s].fd, buffers[s2].bufferhead,
|
|
buffers[s2].present(), 0, PR_INTERVAL_NO_TIMEOUT);
|
|
|
|
if (bytesWrite < 0)
|
|
{
|
|
if (PR_GetError() != PR_WOULD_BLOCK_ERROR) {
|
|
printf(" error=%d", PR_GetError());
|
|
client_error = true;
|
|
socketErrorState[s] = PR_TRUE;
|
|
// We got a fatal error while writting the buffer. Clear it to break
|
|
// the main loop, we will never more be able to send it.
|
|
buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
|
|
}
|
|
else
|
|
printf(" would block");
|
|
}
|
|
else
|
|
{
|
|
printf(", writen %d bytes", bytesWrite);
|
|
buffers[s2].buffertail[1] = '\0';
|
|
printf(" dump:\n%.*s\n", bytesWrite, buffers[s2].bufferhead);
|
|
|
|
buffers[s2].bufferhead += bytesWrite;
|
|
if (buffers[s2].present())
|
|
{
|
|
printf(" still have to write %d bytes", (int)buffers[s2].present());
|
|
in_flags |= PR_POLL_WRITE;
|
|
}
|
|
else
|
|
{
|
|
if (!ssl_updated)
|
|
{
|
|
printf(" proxy response sent to the client");
|
|
// Proxy response has just been writen, update to ssl
|
|
ssl_updated = true;
|
|
if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, certificateToUse, clientAuth))
|
|
{
|
|
printf(" but failed to config server socket\n");
|
|
client_error = true;
|
|
break;
|
|
}
|
|
|
|
printf(" client socket updated to SSL");
|
|
numberOfSockets = 2;
|
|
} // sslUpdate
|
|
|
|
printf(" dropping our write flag and setting other socket read flag");
|
|
in_flags &= ~PR_POLL_WRITE;
|
|
in_flags2 |= PR_POLL_READ;
|
|
buffers[s2].compact();
|
|
}
|
|
}
|
|
} // PR_POLL_WRITE handling
|
|
printf("\n"); // end the log
|
|
} // for...
|
|
} // while, poll
|
|
}
|
|
else
|
|
client_error = true;
|
|
|
|
printf("SSLTUNNEL(%p): exiting root function for csock=%p, ssock=%p\n", data, ci->client_sock, (PRFileDesc*)other_sock);
|
|
if (!client_error)
|
|
PR_Shutdown(ci->client_sock, PR_SHUTDOWN_SEND);
|
|
PR_Close(ci->client_sock);
|
|
|
|
delete ci;
|
|
}
|
|
|
|
/*
|
|
* Start listening for SSL connections on a specified port, handing
|
|
* them off to client threads after accepting the connection.
|
|
* The data parameter is a server_info_t*, owned by the calling
|
|
* function.
|
|
*/
|
|
void StartServer(void* data)
|
|
{
|
|
server_info_t* si = static_cast<server_info_t*>(data);
|
|
|
|
//TODO: select ciphers?
|
|
AutoFD listen_socket(PR_NewTCPSocket());
|
|
if (!listen_socket) {
|
|
fprintf(stderr, "failed to create socket\n");
|
|
SignalShutdown();
|
|
return;
|
|
}
|
|
|
|
// In case the socket is still open in the TIME_WAIT state from a previous
|
|
// instance of ssltunnel we ask to reuse the port.
|
|
PRSocketOptionData socket_option;
|
|
socket_option.option = PR_SockOpt_Reuseaddr;
|
|
socket_option.value.reuse_addr = PR_TRUE;
|
|
PR_SetSocketOption(listen_socket, &socket_option);
|
|
|
|
PRNetAddr server_addr;
|
|
PR_InitializeNetAddr(PR_IpAddrAny, si->listen_port, &server_addr);
|
|
if (PR_Bind(listen_socket, &server_addr) != PR_SUCCESS) {
|
|
fprintf(stderr, "failed to bind socket\n");
|
|
SignalShutdown();
|
|
return;
|
|
}
|
|
|
|
if (PR_Listen(listen_socket, 1) != PR_SUCCESS) {
|
|
fprintf(stderr, "failed to listen on socket\n");
|
|
SignalShutdown();
|
|
return;
|
|
}
|
|
|
|
printf("Server listening on port %d with cert %s\n", si->listen_port,
|
|
si->cert_nickname.c_str());
|
|
|
|
while (!shutdown_server) {
|
|
connection_info_t* ci = new connection_info_t();
|
|
ci->server_info = si;
|
|
// block waiting for connections
|
|
ci->client_sock = PR_Accept(listen_socket, &ci->client_addr,
|
|
PR_INTERVAL_NO_TIMEOUT);
|
|
|
|
PRSocketOptionData option;
|
|
option.option = PR_SockOpt_Nonblocking;
|
|
option.value.non_blocking = PR_TRUE;
|
|
PR_SetSocketOption(ci->client_sock, &option);
|
|
|
|
if (ci->client_sock)
|
|
// Not actually using this PRJob*...
|
|
//PRJob* job =
|
|
PR_QueueJob(threads, HandleConnection, ci, PR_TRUE);
|
|
else
|
|
delete ci;
|
|
}
|
|
}
|
|
|
|
// bogus password func, just don't use passwords. :-P
|
|
char* password_func(PK11SlotInfo* slot, PRBool retry, void* arg)
|
|
{
|
|
if (retry)
|
|
return NULL;
|
|
|
|
return PL_strdup("");
|
|
}
|
|
|
|
/*
 * Return the configured server listening on |portnumber|, or NULL if there
 * is none. The pointer refers into the global |servers| vector.
 */
server_info_t* findServerInfo(int portnumber)
{
  for (size_t idx = 0; idx < servers.size(); ++idx) {
    server_info_t& candidate = servers[idx];
    if (candidate.listen_port == portnumber)
      return &candidate;
  }
  return NULL;
}
|
|
|
|
int processConfigLine(char* configLine)
|
|
{
|
|
if (*configLine == 0 || *configLine == '#')
|
|
return 0;
|
|
|
|
char* _caret;
|
|
char* keyword = strtok2(configLine, ":", &_caret);
|
|
|
|
// Configure usage of http/ssl tunneling proxy behavior
|
|
if (!strcmp(keyword, "httpproxy"))
|
|
{
|
|
char* value = strtok2(_caret, ":", &_caret);
|
|
if (!strcmp(value, "1"))
|
|
do_http_proxy = true;
|
|
|
|
return 0;
|
|
}
|
|
|
|
// Configure the forward address of the target server
|
|
if (!strcmp(keyword, "forward"))
|
|
{
|
|
char* ipstring = strtok2(_caret, ":", &_caret);
|
|
if (PR_StringToNetAddr(ipstring, &remote_addr) != PR_SUCCESS) {
|
|
fprintf(stderr, "Invalid remote IP address: %s\n", ipstring);
|
|
return 1;
|
|
}
|
|
char* serverportstring = strtok2(_caret, ":", &_caret);
|
|
int port = atoi(serverportstring);
|
|
if (port <= 0) {
|
|
fprintf(stderr, "Invalid remote port: %s\n", serverportstring);
|
|
return 1;
|
|
}
|
|
remote_addr.inet.port = PR_htons(port);
|
|
|
|
return 0;
|
|
}
|
|
|
|
// Configure all listen sockets and port+certificate bindings
|
|
if (!strcmp(keyword, "listen"))
|
|
{
|
|
char* hostname = strtok2(_caret, ":", &_caret);
|
|
char* hostportstring = NULL;
|
|
if (strcmp(hostname, "*"))
|
|
{
|
|
any_host_spec_config = true;
|
|
hostportstring = strtok2(_caret, ":", &_caret);
|
|
}
|
|
|
|
char* serverportstring = strtok2(_caret, ":", &_caret);
|
|
char* certnick = strtok2(_caret, ":", &_caret);
|
|
|
|
int port = atoi(serverportstring);
|
|
if (port <= 0) {
|
|
fprintf(stderr, "Invalid port specified: %s\n", serverportstring);
|
|
return 1;
|
|
}
|
|
|
|
if (server_info_t* existingServer = findServerInfo(port))
|
|
{
|
|
char *certnick_copy = new char[strlen(certnick)+1];
|
|
char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2];
|
|
|
|
strcpy(hostname_copy, hostname);
|
|
strcat(hostname_copy, ":");
|
|
strcat(hostname_copy, hostportstring);
|
|
strcpy(certnick_copy, certnick);
|
|
|
|
PLHashEntry* entry = PL_HashTableAdd(existingServer->host_cert_table, hostname_copy, certnick_copy);
|
|
if (!entry) {
|
|
fprintf(stderr, "Out of memory");
|
|
return 1;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
server_info_t server;
|
|
server.cert_nickname = certnick;
|
|
server.listen_port = port;
|
|
server.host_cert_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, PL_CompareStrings, NULL, NULL);
|
|
if (!server.host_cert_table)
|
|
{
|
|
fprintf(stderr, "Internal, could not create hash table\n");
|
|
return 1;
|
|
}
|
|
server.host_clientauth_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, ClientAuthValueComparator, NULL, NULL);
|
|
if (!server.host_clientauth_table)
|
|
{
|
|
fprintf(stderr, "Internal, could not create hash table\n");
|
|
return 1;
|
|
}
|
|
servers.push_back(server);
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
if (!strcmp(keyword, "clientauth"))
|
|
{
|
|
char* hostname = strtok2(_caret, ":", &_caret);
|
|
char* hostportstring = strtok2(_caret, ":", &_caret);
|
|
char* serverportstring = strtok2(_caret, ":", &_caret);
|
|
|
|
int port = atoi(serverportstring);
|
|
if (port <= 0) {
|
|
fprintf(stderr, "Invalid port specified: %s\n", serverportstring);
|
|
return 1;
|
|
}
|
|
|
|
if (server_info_t* existingServer = findServerInfo(port))
|
|
{
|
|
char* authoptionstring = strtok2(_caret, ":", &_caret);
|
|
client_auth_option* authoption = new client_auth_option;
|
|
if (!authoption) {
|
|
fprintf(stderr, "Out of memory");
|
|
return 1;
|
|
}
|
|
|
|
if (!strcmp(authoptionstring, "require"))
|
|
*authoption = caRequire;
|
|
else if (!strcmp(authoptionstring, "request"))
|
|
*authoption = caRequest;
|
|
else if (!strcmp(authoptionstring, "none"))
|
|
*authoption = caNone;
|
|
else
|
|
{
|
|
fprintf(stderr, "Incorrect client auth option modifier for host '%s'", hostname);
|
|
return 1;
|
|
}
|
|
|
|
any_host_spec_config = true;
|
|
|
|
char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2];
|
|
if (!hostname_copy) {
|
|
fprintf(stderr, "Out of memory");
|
|
return 1;
|
|
}
|
|
|
|
strcpy(hostname_copy, hostname);
|
|
strcat(hostname_copy, ":");
|
|
strcat(hostname_copy, hostportstring);
|
|
|
|
PLHashEntry* entry = PL_HashTableAdd(existingServer->host_clientauth_table, hostname_copy, authoption);
|
|
if (!entry) {
|
|
fprintf(stderr, "Out of memory");
|
|
return 1;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
fprintf(stderr, "Server on port %d for client authentication option is not defined, use 'listen' option first", port);
|
|
return 1;
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
// Configure the NSS certificate database directory
|
|
if (!strcmp(keyword, "certdbdir"))
|
|
{
|
|
nssconfigdir = strtok2(_caret, "\n", &_caret);
|
|
return 0;
|
|
}
|
|
|
|
printf("Error: keyword \"%s\" unexpected\n", keyword);
|
|
return 1;
|
|
}
|
|
|
|
int parseConfigFile(const char* filePath)
|
|
{
|
|
FILE* f = fopen(filePath, "r");
|
|
if (!f)
|
|
return 1;
|
|
|
|
char buffer[1024], *b = buffer;
|
|
while (!feof(f))
|
|
{
|
|
char c;
|
|
fscanf(f, "%c", &c);
|
|
switch (c)
|
|
{
|
|
case '\n':
|
|
*b++ = 0;
|
|
if (processConfigLine(buffer))
|
|
return 1;
|
|
b = buffer;
|
|
case '\r':
|
|
continue;
|
|
default:
|
|
*b++ = c;
|
|
}
|
|
}
|
|
|
|
fclose(f);
|
|
|
|
// Check mandatory items
|
|
if (nssconfigdir.empty())
|
|
{
|
|
printf("Error: missing path to NSS certification database\n,use certdbdir:<path> in the config file\n");
|
|
return 1;
|
|
}
|
|
|
|
if (any_host_spec_config && !do_http_proxy)
|
|
{
|
|
printf("Warning: any host-specific configurations are ignored, add httpproxy:1 to allow them\n");
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
/*
 * PL_HashTableEnumerateEntries callback for host_cert_table. It frees the
 * "host:port" key and the certificate nickname value, both allocated with
 * new[] by processConfigLine, and asks for the entry to be removed.
 */
PRIntn freeHostCertHashItems(PLHashEntry *he, PRIntn i, void *arg)
{
  char* hostAndPort = static_cast<char*>(const_cast<void*>(he->key));
  char* nickname = static_cast<char*>(he->value);
  delete [] hostAndPort;
  delete [] nickname;
  return HT_ENUMERATE_REMOVE;
}
|
|
|
|
/*
 * PL_HashTableEnumerateEntries callback for host_clientauth_table. It frees
 * the new[]-allocated "host:port" key and the single client_auth_option
 * value, and asks for the entry to be removed.
 */
PRIntn freeClientAuthHashItems(PLHashEntry *he, PRIntn i, void *arg)
{
  char* hostAndPort = static_cast<char*>(const_cast<void*>(he->key));
  client_auth_option* option = static_cast<client_auth_option*>(he->value);
  delete [] hostAndPort;
  delete option;
  return HT_ENUMERATE_REMOVE;
}
|
|
|
|
int main(int argc, char** argv)
|
|
{
|
|
const char* configFilePath;
|
|
if (argc == 1)
|
|
configFilePath = "ssltunnel.cfg";
|
|
else
|
|
configFilePath = argv[1];
|
|
|
|
if (parseConfigFile(configFilePath)) {
|
|
fprintf(stderr, "Error: config file \"%s\" missing or formating incorrect\n"
|
|
"Specify path to the config file as parameter to ssltunnel or \n"
|
|
"create ssltunnel.cfg in the working directory.\n\n"
|
|
"Example format of the config file:\n\n"
|
|
" # Enable http/ssl tunneling proxy-like behavior.\n"
|
|
" # If not specified ssltunnel simply does direct forward.\n"
|
|
" httpproxy:1\n\n"
|
|
" # Specify path to the certification database used.\n"
|
|
" certdbdir:/path/to/certdb\n\n"
|
|
" # Forward/proxy all requests in raw to 127.0.0.1:8888.\n"
|
|
" forward:127.0.0.1:8888\n\n"
|
|
" # Accept connections on port 4443 or 5678 resp. and authenticate\n"
|
|
" # to any host ('*') using the 'server cert' or 'server cert 2' resp.\n"
|
|
" listen:*:4443:server cert\n"
|
|
" listen:*:5678:server cert 2\n\n"
|
|
" # Accept connections on port 4443 and authenticate using\n"
|
|
" # 'a different cert' when target host is 'my.host.name:443'.\n"
|
|
" # This only works in httpproxy mode and has higher priority\n"
|
|
" # than the previous option.\n"
|
|
" listen:my.host.name:443:4443:a different cert\n\n"
|
|
" # To make a specific host require or just request a client certificate\n"
|
|
" # to authenticate use the following options. This can only be used\n"
|
|
" # in httpproxy mode and only after the 'listen' option has been\n"
|
|
" # specified. You also have to specify the tunnel listen port.\n"
|
|
" clientauth:requesting-client-cert.host.com:443:4443:request\n"
|
|
" clientauth:requiring-client-cert.host.com:443:4443:require\n",
|
|
configFilePath);
|
|
return 1;
|
|
}
|
|
|
|
// create a thread pool to handle connections
|
|
threads = PR_CreateThreadPool(PR_MAX(INITIAL_THREADS, servers.size()*2),
|
|
PR_MAX(MAX_THREADS, servers.size()*2),
|
|
DEFAULT_STACKSIZE);
|
|
if (!threads) {
|
|
fprintf(stderr, "Failed to create thread pool\n");
|
|
return 1;
|
|
}
|
|
|
|
shutdown_lock = PR_NewLock();
|
|
if (!shutdown_lock) {
|
|
fprintf(stderr, "Failed to create lock\n");
|
|
PR_ShutdownThreadPool(threads);
|
|
return 1;
|
|
}
|
|
shutdown_condvar = PR_NewCondVar(shutdown_lock);
|
|
if (!shutdown_condvar) {
|
|
fprintf(stderr, "Failed to create condvar\n");
|
|
PR_ShutdownThreadPool(threads);
|
|
PR_DestroyLock(shutdown_lock);
|
|
return 1;
|
|
}
|
|
|
|
PK11_SetPasswordFunc(password_func);
|
|
|
|
// Initialize NSS
|
|
if (NSS_Init(nssconfigdir.c_str()) != SECSuccess) {
|
|
PRInt32 errorlen = PR_GetErrorTextLength();
|
|
char* err = new char[errorlen+1];
|
|
PR_GetErrorText(err);
|
|
fprintf(stderr, "Failed to init NSS: %s", err);
|
|
delete[] err;
|
|
PR_ShutdownThreadPool(threads);
|
|
PR_DestroyCondVar(shutdown_condvar);
|
|
PR_DestroyLock(shutdown_lock);
|
|
return 1;
|
|
}
|
|
|
|
if (NSS_SetDomesticPolicy() != SECSuccess) {
|
|
fprintf(stderr, "NSS_SetDomesticPolicy failed\n");
|
|
PR_ShutdownThreadPool(threads);
|
|
PR_DestroyCondVar(shutdown_condvar);
|
|
PR_DestroyLock(shutdown_lock);
|
|
NSS_Shutdown();
|
|
return 1;
|
|
}
|
|
|
|
// these values should make NSS use the defaults
|
|
if (SSL_ConfigServerSessionIDCache(0, 0, 0, NULL) != SECSuccess) {
|
|
fprintf(stderr, "SSL_ConfigServerSessionIDCache failed\n");
|
|
PR_ShutdownThreadPool(threads);
|
|
PR_DestroyCondVar(shutdown_condvar);
|
|
PR_DestroyLock(shutdown_lock);
|
|
NSS_Shutdown();
|
|
return 1;
|
|
}
|
|
|
|
for (vector<server_info_t>::iterator it = servers.begin();
|
|
it != servers.end(); it++) {
|
|
// Not actually using this PRJob*...
|
|
// PRJob* server_job =
|
|
PR_QueueJob(threads, StartServer, &(*it), PR_TRUE);
|
|
}
|
|
// now wait for someone to tell us to quit
|
|
PR_Lock(shutdown_lock);
|
|
PR_WaitCondVar(shutdown_condvar, PR_INTERVAL_NO_TIMEOUT);
|
|
PR_Unlock(shutdown_lock);
|
|
shutdown_server = true;
|
|
printf("Shutting down...\n");
|
|
// cleanup
|
|
PR_ShutdownThreadPool(threads);
|
|
PR_JoinThreadPool(threads);
|
|
PR_DestroyCondVar(shutdown_condvar);
|
|
PR_DestroyLock(shutdown_lock);
|
|
if (NSS_Shutdown() == SECFailure) {
|
|
fprintf(stderr, "Leaked NSS objects!\n");
|
|
}
|
|
|
|
for (vector<server_info_t>::iterator it = servers.begin();
|
|
it != servers.end(); it++)
|
|
{
|
|
PL_HashTableEnumerateEntries(it->host_cert_table, freeHostCertHashItems, NULL);
|
|
PL_HashTableEnumerateEntries(it->host_clientauth_table, freeClientAuthHashItems, NULL);
|
|
PL_HashTableDestroy(it->host_cert_table);
|
|
PL_HashTableDestroy(it->host_clientauth_table);
|
|
}
|
|
|
|
PR_Cleanup();
|
|
return 0;
|
|
}
|