diff --git a/docshell/base/nsDocShell.cpp b/docshell/base/nsDocShell.cpp index d8243cc86997..409080ebb637 100644 --- a/docshell/base/nsDocShell.cpp +++ b/docshell/base/nsDocShell.cpp @@ -6891,10 +6891,11 @@ nsresult nsDocShell::DoChannelLoad(nsIChannel * aChannel, nsIURILoader * aURILoader) { nsresult rv; - // Mark the channel as being a document URI... + // Mark the channel as being a document URI and allow content sniffing... nsLoadFlags loadFlags = 0; (void) aChannel->GetLoadFlags(&loadFlags); - loadFlags |= nsIChannel::LOAD_DOCUMENT_URI; + loadFlags |= nsIChannel::LOAD_DOCUMENT_URI | + nsIChannel::LOAD_CALL_CONTENT_SNIFFERS; // Load attributes depend on load type... switch (mLoadType) { diff --git a/netwerk/base/public/nsIChannel.idl b/netwerk/base/public/nsIChannel.idl index 3eaa39ae1324..353a80256f7a 100644 --- a/netwerk/base/public/nsIChannel.idl +++ b/netwerk/base/public/nsIChannel.idl @@ -182,7 +182,7 @@ interface nsIChannel : nsIRequest /************************************************************************** * Channel specific load flags: * - * Bits 21-31 are reserved for future use by this interface or one of its + * Bits 22-31 are reserved for future use by this interface or one of its * derivatives (e.g., see nsICachingChannel). */ @@ -216,4 +216,13 @@ interface nsIChannel : nsIRequest * for this load has been determined. */ const unsigned long LOAD_TARGETED = 1 << 20; + + /** + * If this flag is set, the channel should call the content sniffers as + * described in nsNetCID.h about NS_CONTENT_SNIFFER_CATEGORY. + * + * Note: Channels may ignore this flag; however, new channel implementations + * should only do so with good reason. + */ + const unsigned long LOAD_CALL_CONTENT_SNIFFERS = 1 << 21; }; diff --git a/netwerk/base/src/nsBaseChannel.cpp b/netwerk/base/src/nsBaseChannel.cpp index 91458ed23929..fb18a91a9831 100644 --- a/netwerk/base/src/nsBaseChannel.cpp +++ b/netwerk/base/src/nsBaseChannel.cpp @@ -233,11 +233,8 @@ nsBaseChannel::BeginPumpingData() // and especially when we call into the loadgroup. Our caller takes care to // release mPump if we return an error. - mPump = new nsInputStreamPump(); - if (!mPump) - return NS_ERROR_OUT_OF_MEMORY; - - rv = mPump->Init(stream, -1, -1, 0, 0, PR_TRUE); + rv = nsInputStreamPump::Create(getter_AddRefs(mPump), stream, -1, -1, 0, 0, + PR_TRUE); if (NS_SUCCEEDED(rv)) rv = mPump->AsyncRead(this, nsnull); @@ -555,6 +552,25 @@ CallTypeSniffers(void *aClosure, const PRUint8 *aData, PRUint32 aCount) { nsIChannel *chan = NS_STATIC_CAST(nsIChannel*, aClosure); + const nsCOMArray& sniffers = + gIOService->GetContentSniffers(); + PRUint32 length = sniffers.Count(); + for (PRUint32 i = 0; i < length; ++i) { + nsCAutoString newType; + nsresult rv = + sniffers[i]->GetMIMETypeFromContent(chan, aData, aCount, newType); + if (NS_SUCCEEDED(rv) && !newType.IsEmpty()) { + chan->SetContentType(newType); + break; + } + } +} + +static void +CallUnknownTypeSniffer(void *aClosure, const PRUint8 *aData, PRUint32 aCount) +{ + nsIChannel *chan = NS_STATIC_CAST(nsIChannel*, aClosure); + nsCOMPtr sniffer = do_CreateInstance(NS_GENERIC_CONTENT_SNIFFER); if (!sniffer) @@ -572,9 +588,14 @@ nsBaseChannel::OnStartRequest(nsIRequest *request, nsISupports *ctxt) // If our content type is unknown, then use the content type sniffer. If the // sniffer is not available for some reason, then we just keep going as-is. if (NS_SUCCEEDED(mStatus) && mContentType.EqualsLiteral(UNKNOWN_CONTENT_TYPE)) { - mPump->PeekStream(CallTypeSniffers, NS_STATIC_CAST(nsIChannel*, this)); + mPump->PeekStream(CallUnknownTypeSniffer, NS_STATIC_CAST(nsIChannel*, this)); } + // Now, the general type sniffers. Skip this if we have none. + if ((mLoadFlags & LOAD_CALL_CONTENT_SNIFFERS) && + gIOService->GetContentSniffers().Count() != 0) + mPump->PeekStream(CallTypeSniffers, NS_STATIC_CAST(nsIChannel*, this)); + SUSPEND_PUMP_FOR_SCOPE(); return mListener->OnStartRequest(this, mListenerContext); diff --git a/netwerk/base/src/nsIOService.cpp b/netwerk/base/src/nsIOService.cpp index 2ed08dc0d3a5..36f60c866074 100644 --- a/netwerk/base/src/nsIOService.cpp +++ b/netwerk/base/src/nsIOService.cpp @@ -155,6 +155,7 @@ nsIOService::nsIOService() : mOffline(PR_FALSE) , mOfflineForProfileChange(PR_FALSE) , mChannelEventSinks(NS_CHANNEL_EVENT_SINK_CATEGORY) + , mContentSniffers(NS_CONTENT_SNIFFER_CATEGORY) { // Get the allocator ready if (!gBufferCache) diff --git a/netwerk/base/src/nsIOService.h b/netwerk/base/src/nsIOService.h index de76524b541e..a55b8793abf4 100644 --- a/netwerk/base/src/nsIOService.h +++ b/netwerk/base/src/nsIOService.h @@ -56,6 +56,7 @@ #include "nsWeakReference.h" #include "nsINetUtil.h" #include "nsIChannelEventSink.h" +#include "nsIContentSniffer.h" #include "nsCategoryCache.h" #define NS_N(x) (sizeof(x)/sizeof(*x)) @@ -99,6 +100,11 @@ public: nsresult OnChannelRedirect(nsIChannel* oldChan, nsIChannel* newChan, PRUint32 flags); + // Gets the array of registered content sniffers + const nsCOMArray& GetContentSniffers() const { + return mContentSniffers.GetEntries(); + } + private: // These shouldn't be called directly: // - construct using GetInstance @@ -131,6 +137,7 @@ private: // cached categories nsCategoryCache mChannelEventSinks; + nsCategoryCache mContentSniffers; nsVoidArray mRestrictedPortList; diff --git a/netwerk/base/src/nsInputStreamPump.cpp b/netwerk/base/src/nsInputStreamPump.cpp index c018ecc1d53d..b083fa52f17f 100644 --- a/netwerk/base/src/nsInputStreamPump.cpp +++ b/netwerk/base/src/nsInputStreamPump.cpp @@ -81,6 +81,30 @@ nsInputStreamPump::~nsInputStreamPump() { } +nsresult +nsInputStreamPump::Create(nsInputStreamPump **result, + nsIInputStream *stream, + PRInt64 streamPos, + PRInt64 streamLen, + PRUint32 segsize, + PRUint32 segcount, + PRBool closeWhenDone) +{ + nsresult rv = NS_ERROR_OUT_OF_MEMORY; + nsRefPtr pump = new nsInputStreamPump(); + if (pump) { + rv = pump->Init(stream, streamPos, streamLen, + segsize, segcount, closeWhenDone); + if (NS_SUCCEEDED(rv)) { + *result = nsnull; + pump.swap(*result); + } + } + return rv; +} + + + struct PeekData { PeekData(nsInputStreamPump::PeekSegmentFun fun, void* closure) : mFunc(fun), mClosure(closure) {} diff --git a/netwerk/base/src/nsInputStreamPump.h b/netwerk/base/src/nsInputStreamPump.h index 6b7f517a2258..4cce57730dc3 100644 --- a/netwerk/base/src/nsInputStreamPump.h +++ b/netwerk/base/src/nsInputStreamPump.h @@ -61,6 +61,15 @@ public: nsInputStreamPump(); ~nsInputStreamPump(); + static NS_HIDDEN_(nsresult) + Create(nsInputStreamPump **result, + nsIInputStream *stream, + PRInt64 streamPos = -1, + PRInt64 streamLen = -1, + PRUint32 segsize = 0, + PRUint32 segcount = 0, + PRBool closeWhenDone = PR_FALSE); + typedef void (*PeekSegmentFun)(void *closure, const PRUint8 *buf, PRUint32 bufLen); /** diff --git a/netwerk/build/nsNetCID.h b/netwerk/build/nsNetCID.h index 99647300fe69..a8d10c0b6f89 100644 --- a/netwerk/build/nsNetCID.h +++ b/netwerk/build/nsNetCID.h @@ -778,5 +778,22 @@ */ #define NS_CHANNEL_EVENT_SINK_CATEGORY "net-channel-event-sinks" +/** + * Services in this category will get told about each load that happens and get + * the opportunity to override the detected MIME type via nsIContentSniffer. + * Services should not set the MIME type on the channel directly, but return the + * new type. If getMIMETypeFromContent throws an exception, the type will remain + * unchanged. + * + * Note that only channels with the LOAD_CALL_CONTENT_SNIFFERS flag will call + * content sniffers. Also note that there can be security implications about + * changing the MIME type -- proxies filtering responses based on their MIME + * type might consider certain types to be safe, which these sniffers can + * override. + * + * Not all channels may implement content sniffing. See also + * nsIChannel::LOAD_CALL_CONTENT_SNIFFERS. + */ +#define NS_CONTENT_SNIFFER_CATEGORY "net-content-sniffers" #endif // nsNetCID_h__ diff --git a/netwerk/protocol/http/src/nsHttpChannel.cpp b/netwerk/protocol/http/src/nsHttpChannel.cpp index 59e78174a94c..1d4c41173fe0 100644 --- a/netwerk/protocol/http/src/nsHttpChannel.cpp +++ b/netwerk/protocol/http/src/nsHttpChannel.cpp @@ -72,6 +72,7 @@ #include "nsIVariant.h" #include "nsChannelProperties.h" #include "nsStreamUtils.h" +#include "nsIOService.h" // True if the local cache should be bypassed when processing a request. #define BYPASS_LOCAL_CACHE(loadFlags) \ @@ -645,8 +646,8 @@ nsHttpChannel::SetupTransaction() getter_AddRefs(responseStream)); if (NS_FAILED(rv)) return rv; - rv = NS_NewInputStreamPump(getter_AddRefs(mTransactionPump), - responseStream); + rv = nsInputStreamPump::Create(getter_AddRefs(mTransactionPump), + responseStream); return rv; } @@ -713,6 +714,27 @@ nsHttpChannel::ApplyContentConversions() return NS_OK; } +// NOTE: This function duplicates code from nsBaseChannel. This will go away +// once HTTP uses nsBaseChannel (part of bug 312760) +static void +CallTypeSniffers(void *aClosure, const PRUint8 *aData, PRUint32 aCount) +{ + nsIChannel *chan = NS_STATIC_CAST(nsIChannel*, aClosure); + + const nsCOMArray& sniffers = + gIOService->GetContentSniffers(); + PRUint32 length = sniffers.Count(); + for (PRUint32 i = 0; i < length; ++i) { + nsCAutoString newType; + nsresult rv = + sniffers[i]->GetMIMETypeFromContent(chan, aData, aCount, newType); + if (NS_SUCCEEDED(rv) && !newType.IsEmpty()) { + chan->SetContentType(newType); + break; + } + } +} + nsresult nsHttpChannel::CallOnStartRequest() { @@ -750,6 +772,17 @@ nsHttpChannel::CallOnStartRequest() SetPropertyAsInt64(NS_CHANNEL_PROP_CONTENT_LENGTH, mResponseHead->ContentLength()); + // Allow consumers to override our content type + if ((mLoadFlags & LOAD_CALL_CONTENT_SNIFFERS) && + gIOService->GetContentSniffers().Count() != 0) { + if (mTransactionPump) + mTransactionPump->PeekStream(CallTypeSniffers, + NS_STATIC_CAST(nsIChannel*, this)); + else + mCachePump->PeekStream(CallTypeSniffers, + NS_STATIC_CAST(nsIChannel*, this)); + } + LOG((" calling mListener->OnStartRequest\n")); nsresult rv = mListener->OnStartRequest(this, mListenerContext); if (NS_FAILED(rv)) return rv; @@ -1682,9 +1715,9 @@ nsHttpChannel::ReadFromCache() rv = mCacheEntry->OpenInputStream(0, getter_AddRefs(stream)); if (NS_FAILED(rv)) return rv; - rv = NS_NewInputStreamPump(getter_AddRefs(mCachePump), - stream, nsInt64(-1), nsInt64(-1), 0, 0, - PR_TRUE); + rv = nsInputStreamPump::Create(getter_AddRefs(mCachePump), + stream, nsInt64(-1), nsInt64(-1), 0, 0, + PR_TRUE); if (NS_FAILED(rv)) return rv; return mCachePump->AsyncRead(this, mListenerContext); @@ -3976,6 +4009,12 @@ nsHttpChannel::OnStartRequest(nsIRequest *request, nsISupports *ctxt) LOG(("nsHttpChannel::OnStartRequest [this=%x request=%x status=%x]\n", this, request, mStatus)); + // Make sure things are what we expect them to be... + NS_ASSERTION(request == mCachePump || request == mTransactionPump, + "Unexpected request"); + NS_ASSERTION(!mTransactionPump || request == mTransactionPump, + "If we have a txn pump, request must be it"); + // don't enter this block if we're reading from the cache... if (NS_SUCCEEDED(mStatus) && !mCachePump && mTransaction) { // grab the security info from the connection object; the transaction diff --git a/netwerk/protocol/http/src/nsHttpChannel.h b/netwerk/protocol/http/src/nsHttpChannel.h index 1b1123b97a35..4adbe12ba022 100644 --- a/netwerk/protocol/http/src/nsHttpChannel.h +++ b/netwerk/protocol/http/src/nsHttpChannel.h @@ -22,6 +22,7 @@ * * Contributor(s): * Darin Fisher (original author) + * Christian Biesinger * * Alternatively, the contents of this file may be used under the terms of * either the GNU General Public License Version 2 or later (the "GPL"), or @@ -43,7 +44,9 @@ #include "nsHttpTransaction.h" #include "nsHttpRequestHead.h" #include "nsHttpAuthCache.h" +#include "nsInputStreamPump.h" #include "nsXPIDLString.h" +#include "nsAutoPtr.h" #include "nsCOMPtr.h" #include "nsInt64.h" @@ -71,7 +74,6 @@ #include "nsIStringEnumerator.h" #include "nsIOutputStream.h" #include "nsIAsyncInputStream.h" -#include "nsIInputStreamPump.h" #include "nsIPrompt.h" #include "nsIResumableChannel.h" #include "nsISupportsPriority.h" @@ -228,7 +230,7 @@ private: nsHttpRequestHead mRequestHead; nsHttpResponseHead *mResponseHead; - nsCOMPtr mTransactionPump; + nsRefPtr mTransactionPump; nsHttpTransaction *mTransaction; // hard ref nsHttpConnectionInfo *mConnectionInfo; // hard ref @@ -246,7 +248,7 @@ private: // cache specific data nsCOMPtr mCacheEntry; - nsCOMPtr mCachePump; + nsRefPtr mCachePump; nsHttpResponseHead *mCachedResponseHead; nsCacheAccessMode mCacheAccess; PRUint32 mPostID; diff --git a/netwerk/test/unit/test_content_sniffer.js b/netwerk/test/unit/test_content_sniffer.js new file mode 100644 index 000000000000..3582e93f0d94 --- /dev/null +++ b/netwerk/test/unit/test_content_sniffer.js @@ -0,0 +1,124 @@ +// This file tests nsIContentSniffer, introduced in bug 324985 + +const unknownType = "application/x-unknown-content-type"; +const sniffedType = "application/x-sniffed"; + +const snifferCID = Components.ID("{4c93d2db-8a56-48d7-b261-9cf2a8d998eb}"); +const snifferContract = "@mozilla.org/network/unittest/contentsniffer;1"; +const categoryName = "net-content-sniffers"; + +var sniffing_enabled = true; + +/** + * This object is both a factory and an nsIContentSniffer implementation (so, it + * is de-facto a service) + */ +var sniffer = { + QueryInterface: function sniffer_qi(iid) { + if (iid.equals(Components.interfaces.nsISupports) || + iid.equals(Components.interfaces.nsIFactory) || + iid.equals(Components.interfaces.nsIContentSniffer)) + return this; + throw Components.results.NS_ERROR_NO_INTERFACE; + }, + createInstance: function sniffer_ci(outer, iid) { + if (outer) + throw Components.results.NS_ERROR_NO_AGGREGATION; + return this.QueryInterface(iid); + }, + lockFactory: function sniffer_lockf(lock) { + throw Components.results.NS_ERROR_NOT_IMPLEMENTED; + }, + + getMIMETypeFromContent: function (request, data, length) { + return sniffedType; + } +}; + +var listener = { + onStartRequest: function test_onStartR(request, ctx) { + try { + var chan = request.QueryInterface(Components.interfaces.nsIChannel); + if (chan.contentType == unknownType) + do_throw("Type should not be unknown!"); + if (sniffing_enabled && this._iteration > 2 && + chan.contentType != sniffedType) { + do_throw("Expecting <" + sniffedType +"> but got <" + + chan.contentType + "> for " + chan.URI.spec); + } else if (!sniffing_enabled && chan.contentType == sniffedType) { + do_throw("Sniffing not enabled but sniffer called for " + chan.URI.spec); + } + } catch (e) { + do_throw("Unexpected exception: " + e); + } + + throw Components.results.NS_ERROR_ABORT; + }, + + onDataAvailable: function test_ODA() { + throw Components.results.NS_ERROR_UNEXPECTED; + }, + + onStopRequest: function test_onStopR(request, ctx, status) { + run_test_iteration(this._iteration); + do_test_finished(); + }, + + _iteration: 1 +}; + +function makeChan(url) { + var ios = Components.classes["@mozilla.org/network/io-service;1"] + .getService(Components.interfaces.nsIIOService); + var chan = ios.newChannel(url, null, null); + if (sniffing_enabled) + chan.loadFlags |= Components.interfaces.nsIChannel.LOAD_CALL_CONTENT_SNIFFERS; + + return chan; +} + +var urls = [ + // NOTE: First URL here runs without our content sniffer + "data:" + unknownType + ", Some text", + "data:" + unknownType + ", Text", // Make sure sniffing works even if we + // used the unknown content sniffer too + "data:text/plain, Some more text", + "http://localhost:4444" +]; + +function run_test() { + start_server(4444); + + Components.manager.nsIComponentRegistrar.registerFactory(snifferCID, + "Unit test content sniffer", snifferContract, sniffer); + + run_test_iteration(1); +} + +function run_test_iteration(index) { + if (index > urls.length) { + if (sniffing_enabled) { + sniffing_enabled = false; + index = listener._iteration = 1; + } else { + return; // we're done + } + } + + if (sniffing_enabled && index == 2) { + // Register our sniffer only here + // This also makes sure that dynamic registration is working + var catMan = Components.classes["@mozilla.org/categorymanager;1"] + .getService(Components.interfaces.nsICategoryManager); + catMan.nsICategoryManager.addCategoryEntry(categoryName, "unit test", + snifferContract, false, true); + } + + var chan = makeChan(urls[index - 1]); + + listener._iteration++; + chan.asyncOpen(listener, null); + + do_test_pending(); +} +