Adjust for upstream cert store changes, improve cert handling significantly

This commit is contained in:
Aldo Cortesi 2014-03-05 15:03:31 +13:00
parent 944f213ebc
commit e54bf1a804
3 changed files with 35 additions and 34 deletions

View File

@ -6,6 +6,7 @@ import version, app, language, utils
DEFAULT_CERT_DOMAIN = "pathod.net" DEFAULT_CERT_DOMAIN = "pathod.net"
CONFDIR = "~/.mitmproxy" CONFDIR = "~/.mitmproxy"
CERTSTORE_BASENAME = "mitmproxy"
CA_CERT_NAME = "mitmproxy-ca.pem" CA_CERT_NAME = "mitmproxy-ca.pem"
logger = logging.getLogger('pathod') logger = logging.getLogger('pathod')
@ -14,28 +15,23 @@ class PathodError(Exception): pass
class SSLOptions: class SSLOptions:
def __init__(self, confdir=CONFDIR, cn=None, certfile=None, keyfile=None, def __init__(self, confdir=CONFDIR, cn=None, not_after_connect=None,
not_after_connect=None, request_client_cert=False, request_client_cert=False, sslversion=tcp.SSLv23_METHOD,
sslversion=tcp.SSLv23_METHOD, ciphers=None): ciphers=None, certs=None):
self.confdir = confdir self.confdir = confdir
self.cn = cn self.cn = cn
if keyfile: self.certstore = certutils.CertStore.from_store(
self.keyfile = os.path.expanduser(keyfile) os.path.expanduser(confdir),
else: CERTSTORE_BASENAME
keyfile = os.path.join(confdir, CA_CERT_NAME) )
self.keyfile = os.path.expanduser(keyfile) for i in certs or []:
if not os.path.exists(self.keyfile): self.certstore.add_cert_file(*i)
certutils.dummy_ca(self.keyfile)
self.certstore = certutils.CertStore(self.keyfile)
self.certfile = certfile
self.not_after_connect = not_after_connect self.not_after_connect = not_after_connect
self.request_client_cert = request_client_cert self.request_client_cert = request_client_cert
self.ciphers = ciphers self.ciphers = ciphers
self.sslversion = sslversion self.sslversion = sslversion
def get_cert(self, name): def get_cert(self, name):
if self.certfile:
return certutils.SSLCert.from_pem(file(self.certfile, "rb").read())
if self.cn: if self.cn:
name = self.cn name = self.cn
elif not name: elif not name:
@ -97,9 +93,9 @@ class PathodHandler(tcp.BaseHandler):
self.wfile.flush() self.wfile.flush()
if not self.server.ssloptions.not_after_connect: if not self.server.ssloptions.not_after_connect:
try: try:
cert, key = self.server.ssloptions.get_cert(m.v[0])
self.convert_to_ssl( self.convert_to_ssl(
self.server.ssloptions.get_cert(None), cert, key,
self.server.ssloptions.keyfile,
handle_sni = self.handle_sni, handle_sni = self.handle_sni,
request_client_cert = self.server.ssloptions.request_client_cert, request_client_cert = self.server.ssloptions.request_client_cert,
cipher_list = self.server.ssloptions.ciphers, cipher_list = self.server.ssloptions.ciphers,
@ -213,9 +209,9 @@ class PathodHandler(tcp.BaseHandler):
def handle(self): def handle(self):
if self.server.ssl: if self.server.ssl:
try: try:
cert, key = self.server.ssloptions.get_cert(None)
self.convert_to_ssl( self.convert_to_ssl(
self.server.ssloptions.get_cert(None), cert, key,
self.server.ssloptions.keyfile,
handle_sni = self.handle_sni, handle_sni = self.handle_sni,
request_client_cert = self.server.ssloptions.request_client_cert, request_client_cert = self.server.ssloptions.request_client_cert,
cipher_list = self.server.ssloptions.ciphers, cipher_list = self.server.ssloptions.ciphers,

28
pathod
View File

@ -31,14 +31,23 @@ def daemonize (stdin='/dev/null', stdout='/dev/null', stderr='/dev/null'):
def main(parser, args): def main(parser, args):
certs = []
for i in args.ssl_certs:
parts = i.split("=", 1)
if len(parts) == 1:
parts = ["*", parts[0]]
parts[1] = os.path.expanduser(parts[1])
if not os.path.exists(parts[1]):
parser.error("Certificate file does not exist: %s"%parts[1])
certs.append(parts)
ssloptions = pathod.SSLOptions( ssloptions = pathod.SSLOptions(
cn = args.cn, cn = args.cn,
confdir = args.confdir, confdir = args.confdir,
certfile = args.ssl_certfile,
keyfile = args.ssl_keyfile or args.ssl_certfile,
not_after_connect = args.ssl_not_after_connect, not_after_connect = args.ssl_not_after_connect,
ciphers = args.ciphers, ciphers = args.ciphers,
sslversion = utils.SSLVERSIONS[args.sslversion] sslversion = utils.SSLVERSIONS[args.sslversion],
certs = certs
) )
alst = [] alst = []
@ -174,12 +183,12 @@ if __name__ == "__main__":
help="Don't expect SSL after a CONNECT request." help="Don't expect SSL after a CONNECT request."
) )
group.add_argument( group.add_argument(
"--certfile", dest='ssl_certfile', default=None, type=str, "--cert", dest='ssl_certs', default=[], type=str,
help='SSL certificate in PEM format, optionally with the key in the same file.' metavar = "SPEC", action="append",
) help='Add an SSL certificate. SPEC is of the form "[domain=]path". '\
group.add_argument( 'The domain may include a wildcard, and is equal to "*" if not specified. '\
"--keyfile", dest='ssl_keyfile', default=None, type=str, 'The file at path is a certificate in PEM format. If a private key is included in the PEM, '\
help='Key matching certfile.' 'it is used, else the default key in the conf dir is used. Can be passed multiple times.'
) )
group.add_argument( group.add_argument(
"--ciphers", dest="ciphers", type=str, default=False, "--ciphers", dest="ciphers", type=str, default=False,
@ -218,7 +227,6 @@ if __name__ == "__main__":
"-x", dest="hexdump", action="store_true", default=False, "-x", dest="hexdump", action="store_true", default=False,
help="Log request/response in hexdump format" help="Log request/response in hexdump format"
) )
args = parser.parse_args() args = parser.parse_args()
if args.daemonize: if args.daemonize:
daemonize() daemonize()

View File

@ -1,6 +1,6 @@
import pprint import pprint
from libpathod import pathod, version from libpathod import pathod, version
from netlib import tcp, http from netlib import tcp, http, certutils
import requests import requests
import tutils import tutils
@ -66,14 +66,13 @@ class TestNotAfterConnect(tutils.DaemonTests):
class TestCustomCert(tutils.DaemonTests): class TestCustomCert(tutils.DaemonTests):
ssl = True ssl = True
ssloptions = dict( ssloptions = dict(
certfile = tutils.test_data.path("data/testkey.pem"), certs = [("*", tutils.test_data.path("data/testkey.pem"))],
keyfile = tutils.test_data.path("data/testkey.pem"),
) )
def test_connect(self): def test_connect(self):
r = self.pathoc(r"get:/p/202") r = self.pathoc(r"get:/p/202")
assert r.status_code == 202 assert r.status_code == 202
assert r.sslinfo assert r.sslinfo
assert "Widgits" in str(r.sslinfo.certchain[0].get_subject())
class TestSSLCN(tutils.DaemonTests): class TestSSLCN(tutils.DaemonTests):
@ -224,5 +223,3 @@ class TestDaemonSSL(CommonTests):
assert r.status_code == 202 assert r.status_code == 202
assert self.d.last_log()["cipher"][1] > 0 assert self.d.last_log()["cipher"][1] > 0