diff --git a/libpathod/pathod.py b/libpathod/pathod.py index 9730834a3..e4dac32bf 100644 --- a/libpathod/pathod.py +++ b/libpathod/pathod.py @@ -6,6 +6,7 @@ import version, app, language, utils DEFAULT_CERT_DOMAIN = "pathod.net" CONFDIR = "~/.mitmproxy" +CERTSTORE_BASENAME = "mitmproxy" CA_CERT_NAME = "mitmproxy-ca.pem" logger = logging.getLogger('pathod') @@ -14,28 +15,23 @@ class PathodError(Exception): pass class SSLOptions: - def __init__(self, confdir=CONFDIR, cn=None, certfile=None, keyfile=None, - not_after_connect=None, request_client_cert=False, - sslversion=tcp.SSLv23_METHOD, ciphers=None): + def __init__(self, confdir=CONFDIR, cn=None, not_after_connect=None, + request_client_cert=False, sslversion=tcp.SSLv23_METHOD, + ciphers=None, certs=None): self.confdir = confdir self.cn = cn - if keyfile: - self.keyfile = os.path.expanduser(keyfile) - else: - keyfile = os.path.join(confdir, CA_CERT_NAME) - self.keyfile = os.path.expanduser(keyfile) - if not os.path.exists(self.keyfile): - certutils.dummy_ca(self.keyfile) - self.certstore = certutils.CertStore(self.keyfile) - self.certfile = certfile + self.certstore = certutils.CertStore.from_store( + os.path.expanduser(confdir), + CERTSTORE_BASENAME + ) + for i in certs or []: + self.certstore.add_cert_file(*i) self.not_after_connect = not_after_connect self.request_client_cert = request_client_cert self.ciphers = ciphers self.sslversion = sslversion def get_cert(self, name): - if self.certfile: - return certutils.SSLCert.from_pem(file(self.certfile, "rb").read()) if self.cn: name = self.cn elif not name: @@ -97,9 +93,9 @@ class PathodHandler(tcp.BaseHandler): self.wfile.flush() if not self.server.ssloptions.not_after_connect: try: + cert, key = self.server.ssloptions.get_cert(m.v[0]) self.convert_to_ssl( - self.server.ssloptions.get_cert(None), - self.server.ssloptions.keyfile, + cert, key, handle_sni = self.handle_sni, request_client_cert = self.server.ssloptions.request_client_cert, cipher_list = self.server.ssloptions.ciphers, @@ -213,9 +209,9 @@ class PathodHandler(tcp.BaseHandler): def handle(self): if self.server.ssl: try: + cert, key = self.server.ssloptions.get_cert(None) self.convert_to_ssl( - self.server.ssloptions.get_cert(None), - self.server.ssloptions.keyfile, + cert, key, handle_sni = self.handle_sni, request_client_cert = self.server.ssloptions.request_client_cert, cipher_list = self.server.ssloptions.ciphers, diff --git a/pathod b/pathod index d150eac0d..2e9fafc4f 100755 --- a/pathod +++ b/pathod @@ -31,14 +31,23 @@ def daemonize (stdin='/dev/null', stdout='/dev/null', stderr='/dev/null'): def main(parser, args): + certs = [] + for i in args.ssl_certs: + parts = i.split("=", 1) + if len(parts) == 1: + parts = ["*", parts[0]] + parts[1] = os.path.expanduser(parts[1]) + if not os.path.exists(parts[1]): + parser.error("Certificate file does not exist: %s"%parts[1]) + certs.append(parts) + ssloptions = pathod.SSLOptions( cn = args.cn, confdir = args.confdir, - certfile = args.ssl_certfile, - keyfile = args.ssl_keyfile or args.ssl_certfile, not_after_connect = args.ssl_not_after_connect, ciphers = args.ciphers, - sslversion = utils.SSLVERSIONS[args.sslversion] + sslversion = utils.SSLVERSIONS[args.sslversion], + certs = certs ) alst = [] @@ -174,12 +183,12 @@ if __name__ == "__main__": help="Don't expect SSL after a CONNECT request." ) group.add_argument( - "--certfile", dest='ssl_certfile', default=None, type=str, - help='SSL certificate in PEM format, optionally with the key in the same file.' - ) - group.add_argument( - "--keyfile", dest='ssl_keyfile', default=None, type=str, - help='Key matching certfile.' + "--cert", dest='ssl_certs', default=[], type=str, + metavar = "SPEC", action="append", + help='Add an SSL certificate. SPEC is of the form "[domain=]path". '\ + 'The domain may include a wildcard, and is equal to "*" if not specified. '\ + 'The file at path is a certificate in PEM format. If a private key is included in the PEM, '\ + 'it is used, else the default key in the conf dir is used. Can be passed multiple times.' ) group.add_argument( "--ciphers", dest="ciphers", type=str, default=False, @@ -218,7 +227,6 @@ if __name__ == "__main__": "-x", dest="hexdump", action="store_true", default=False, help="Log request/response in hexdump format" ) - args = parser.parse_args() if args.daemonize: daemonize() diff --git a/test/test_pathod.py b/test/test_pathod.py index 1ab330954..b1529f772 100644 --- a/test/test_pathod.py +++ b/test/test_pathod.py @@ -1,6 +1,6 @@ import pprint from libpathod import pathod, version -from netlib import tcp, http +from netlib import tcp, http, certutils import requests import tutils @@ -66,14 +66,13 @@ class TestNotAfterConnect(tutils.DaemonTests): class TestCustomCert(tutils.DaemonTests): ssl = True ssloptions = dict( - certfile = tutils.test_data.path("data/testkey.pem"), - keyfile = tutils.test_data.path("data/testkey.pem"), + certs = [("*", tutils.test_data.path("data/testkey.pem"))], ) def test_connect(self): r = self.pathoc(r"get:/p/202") assert r.status_code == 202 assert r.sslinfo - + assert "Widgits" in str(r.sslinfo.certchain[0].get_subject()) class TestSSLCN(tutils.DaemonTests): @@ -224,5 +223,3 @@ class TestDaemonSSL(CommonTests): assert r.status_code == 202 assert self.d.last_log()["cipher"][1] > 0 - -