From 532a82cb78575df08a99fb9ac01cc3c57a9d13e5 Mon Sep 17 00:00:00 2001 From: "dcamp@mozilla.com" Date: Wed, 25 Jul 2007 23:38:43 -0700 Subject: [PATCH] try landing new safebrowsing protocol again. b=387196, r=tony, r=vlad (for new fixes) --- browser/app/profile/firefox.js | 2 +- .../safebrowsing/content/list-warden.js | 97 +- .../safebrowsing/content/sb-loader.js | 6 +- toolkit/components/build/Makefile.in | 1 + .../url-classifier/content/listmanager.js | 336 +-- .../url-classifier/content/moz/alarm.js | 4 + .../url-classifier/public/Makefile.in | 1 - .../public/nsIUrlClassifierDBService.idl | 48 +- .../public/nsIUrlClassifierStreamUpdater.idl | 17 +- .../public/nsIUrlClassifierUtils.idl | 25 +- .../public/nsIUrlListManager.idl | 19 +- .../components/url-classifier/src/Makefile.in | 9 +- .../src/nsUrlClassifierDBService.cpp | 2406 +++++++++++++---- .../src/nsUrlClassifierListManager.js | 1 - .../src/nsUrlClassifierStreamUpdater.cpp | 76 +- .../src/nsUrlClassifierStreamUpdater.h | 2 + .../src/nsUrlClassifierUtils.cpp | 321 ++- .../url-classifier/src/nsUrlClassifierUtils.h | 19 +- .../url-classifier/tests/Makefile.in | 11 +- .../tests/TestUrlClassifierUtils.cpp | 149 +- 20 files changed, 2559 insertions(+), 991 deletions(-) diff --git a/browser/app/profile/firefox.js b/browser/app/profile/firefox.js index cb296d045c48..f4204004bbd0 100644 --- a/browser/app/profile/firefox.js +++ b/browser/app/profile/firefox.js @@ -493,7 +493,7 @@ pref("browser.safebrowsing.enabled", true); pref("browser.safebrowsing.remoteLookups", false); // Non-enhanced mode (local url lists) URL list to check for updates -pref("browser.safebrowsing.provider.0.updateURL", "http://sb.google.com/safebrowsing/update?client={moz:client}&appver={moz:version}&"); +pref("browser.safebrowsing.provider.0.updateURL", "http://sb.google.com/safebrowsing/downloads?client={moz:client}&appver={moz:version}&pver=2.0"); pref("browser.safebrowsing.dataProvider", 0); diff --git 
a/browser/components/safebrowsing/content/list-warden.js b/browser/components/safebrowsing/content/list-warden.js index 677e8ba7fd73..c63e75fa1003 100644 --- a/browser/components/safebrowsing/content/list-warden.js +++ b/browser/components/safebrowsing/content/list-warden.js @@ -188,82 +188,51 @@ function MultiTableQuerier(url, whiteTables, blackTables, callback) { this.debugZone = "multitablequerier"; this.url_ = url; - this.whiteTables_ = whiteTables; - this.blackTables_ = blackTables; - this.whiteIdx_ = 0; - this.blackIdx_ = 0; + this.whiteTables_ = {}; + for (var i = 0; i < whiteTables.length; i++) { + this.whiteTables_[whiteTables[i]] = true; + } + + this.blackTables_ = {}; + for (var i = 0; i < blackTables.length; i++) { + this.blackTables_[blackTables[i]] = true; + } this.callback_ = callback; this.listManager_ = Cc["@mozilla.org/url-classifier/listmanager;1"] .getService(Ci.nsIUrlListManager); } -/** - * We first query the white tables in succession. If any contain - * the url, we stop. If none contain the url, we query the black tables - * in succession. If any contain the url, we call callback and - * stop. If none of the black tables contain the url, then we just stop - * (i.e., it's not black url). - */ MultiTableQuerier.prototype.run = function() { - var whiteTable = this.whiteTables_[this.whiteIdx_]; - var blackTable = this.blackTables_[this.blackIdx_]; - if (whiteTable) { - //G_Debug(this, "Looking in whitetable: " + whiteTable); - ++this.whiteIdx_; - this.listManager_.safeExists(whiteTable, this.url_, - BindToObject(this.whiteTableCallback_, - this)); - } else if (blackTable) { - //G_Debug(this, "Looking in blacktable: " + blackTable); - ++this.blackIdx_; - this.listManager_.safeExists(blackTable, this.url_, - BindToObject(this.blackTableCallback_, - this)); - } else { - // No tables left to check, so we quit. 
- G_Debug(this, "Not found in any tables: " + this.url_); + /* ask the dbservice for all the tables to which this URL belongs */ + this.listManager_.safeLookup(this.url_, + BindToObject(this.lookupCallback_, this)); +} + +MultiTableQuerier.prototype.lookupCallback_ = function(result) { + if (result == "") { this.callback_(PROT_ListWarden.NOT_FOUND); - - // Break circular ref to callback. - this.callback_ = null; - this.listManager_ = null; + return; } -} -/** - * After checking a white table, we return here. If the url is found, - * we can stop. Otherwise, we call run again. - */ -MultiTableQuerier.prototype.whiteTableCallback_ = function(isFound) { - //G_Debug(this, "whiteTableCallback_: " + isFound); - if (!isFound) - this.run(); - else { - G_Debug(this, "Found in whitelist: " + this.url_) - this.callback_(PROT_ListWarden.IN_WHITELIST); + var tableNames = result.split(","); - // Break circular ref to callback. - this.callback_ = null; - this.listManager_ = null; + /* Check the whitelists */ + for (var i = 0; i < tableNames.length; i++) { + if (tableNames[i] && this.whiteTables_[tableNames[i]]) { + this.callback_(PROT_ListWarden.IN_WHITELIST); + return; + } } -} -/** - * After checking a black table, we return here. If the url is found, - * we can call the callback and stop. Otherwise, we call run again. - */ -MultiTableQuerier.prototype.blackTableCallback_ = function(isFound) { - //G_Debug(this, "blackTableCallback_: " + isFound); - if (!isFound) { - this.run(); - } else { - // In the blacklist, must be an evil url. - G_Debug(this, "Found in blacklist: " + this.url_) - this.callback_(PROT_ListWarden.IN_BLACKLIST); - - // Break circular ref to callback. 
- this.callback_ = null; - this.listManager_ = null; + /* Check the blacklists */ + for (var i = 0; i < tableNames.length; i++) { + if (tableNames[i] && this.blackTables_[tableNames[i]]) { + this.callback_(PROT_ListWarden.IN_BLACKLIST); + return; + } } + + /* Not in any lists we know about */ + this.callback_(PROT_ListWarden.NOT_FOUND); } diff --git a/browser/components/safebrowsing/content/sb-loader.js b/browser/components/safebrowsing/content/sb-loader.js index bf1b407e21c9..3a3a16563df3 100644 --- a/browser/components/safebrowsing/content/sb-loader.js +++ b/browser/components/safebrowsing/content/sb-loader.js @@ -79,10 +79,8 @@ var safebrowsing = { // Register tables // XXX: move table names to a pref that we originally will download // from the provider (need to workout protocol details) - phishWarden.registerWhiteTable("goog-white-domain"); - phishWarden.registerWhiteTable("goog-white-url"); - phishWarden.registerBlackTable("goog-black-url"); - phishWarden.registerBlackTable("goog-black-enchash"); + phishWarden.registerWhiteTable("goog-white-exp"); + phishWarden.registerBlackTable("goog-phish-sha128"); // Download/update lists if we're in non-enhanced mode phishWarden.maybeToggleUpdateChecking(); diff --git a/toolkit/components/build/Makefile.in b/toolkit/components/build/Makefile.in index 970bf3c9303f..dc8e297a4442 100644 --- a/toolkit/components/build/Makefile.in +++ b/toolkit/components/build/Makefile.in @@ -129,6 +129,7 @@ endif ifdef MOZ_URL_CLASSIFIER SHARED_LIBRARY_LIBS += ../url-classifier/src/$(LIB_PREFIX)urlclassifier_s.$(LIB_SUFFIX) +EXTRA_DSO_LDOPTS += $(ZLIB_LIBS) endif ifdef MOZ_FEEDS diff --git a/toolkit/components/url-classifier/content/listmanager.js b/toolkit/components/url-classifier/content/listmanager.js index c5564cb5bcc1..8ec60262d960 100644 --- a/toolkit/components/url-classifier/content/listmanager.js +++ b/toolkit/components/url-classifier/content/listmanager.js @@ -38,49 +38,24 @@ // A class that manages lists, namely white and black 
lists for // phishing or malware protection. The ListManager knows how to fetch, -// update, and store lists, and knows the "kind" of list each is (is -// it a whitelist? a blacklist? etc). However it doesn't know how the -// lists are serialized or deserialized (the wireformat classes know -// this) nor the specific format of each list. For example, the list -// could be a map of domains to "1" if the domain is phishy. Or it -// could be a map of hosts to regular expressions to match, who knows? -// Answer: the trtable knows. List are serialized/deserialized by the -// wireformat reader from/to trtables, and queried by the listmanager. +// update, and store lists. // // There is a single listmanager for the whole application. // -// The listmanager is used only in privacy mode; in advanced protection -// mode a remote server is queried. -// -// How to add a new table: -// 1) get it up on the server -// 2) add it to tablesKnown -// 3) if it is not a known table type (trtable.js), add an implementation -// for it in trtable.js -// 4) add a check for it in the phishwarden's isXY() method, for example -// isBlackURL() -// -// TODO: obviously the way this works could use a lot of improvement. In -// particular adding a list should just be a matter of adding -// its name to the listmanager and an implementation to trtable -// (or not if a talbe of that type exists). The format and semantics -// of the list comprise its name, so the listmanager should easily -// be able to figure out what to do with what list (i.e., no -// need for step 4). // TODO more comprehensive update tests, for example add unittest check // that the listmanagers tables are properly written on updates -/** - * The base pref name for where we keep table version numbers. - * We add append the table name to this and set the value to - * the version. E.g., tableversion.goog-black-enchash may have - * a value of 1.1234. 
- */ -const kTableVersionPrefPrefix = "urlclassifier.tableversion."; - // How frequently we check for updates (30 minutes) const kUpdateInterval = 30 * 60 * 1000; +function QueryAdapter(callback) { + this.callback_ = callback; +}; + +QueryAdapter.prototype.handleResponse = function(value) { + this.callback_.handleEvent(value); +} + /** * A ListManager keeps track of black and white lists and knows * how to update them. @@ -96,28 +71,7 @@ function PROT_ListManager() { this.updateserverURL_ = null; - // The lists we know about and the parses we can use to read - // them. Default all to the earlies possible version (1.-1); this - // version will get updated when successfully read from disk or - // fetch updates. - this.tablesKnown_ = {}; this.isTesting_ = false; - - if (this.isTesting_) { - // populate with some tables for unittesting - this.tablesKnown_ = { - // A major version of zero means local, so don't ask for updates - "test1-foo-domain" : new PROT_VersionParser("test1-foo-domain", 0, -1), - "test2-foo-domain" : new PROT_VersionParser("test2-foo-domain", 0, -1), - "test-white-domain" : - new PROT_VersionParser("test-white-domain", 0, -1, true /* require mac*/), - "test-mac-domain" : - new PROT_VersionParser("test-mac-domain", 0, -1, true /* require mac */) - }; - - // expose the object for unittesting - this.wrappedJSObject = this; - } this.tablesData = {}; @@ -133,6 +87,9 @@ function PROT_ListManager() { 10*60*1000 /* error time, 10min */, 60*60*1000 /* backoff interval, 60min */, 6*60*60*1000 /* max backoff, 6hr */); + + this.dbService_ = Cc["@mozilla.org/url-classifier/dbservice;1"] + .getService(Ci.nsIUrlClassifierDBService); } /** @@ -163,7 +120,6 @@ PROT_ListManager.prototype.setUpdateUrl = function(url) { // Remove old tables which probably aren't valid for the new provider. 
for (var name in this.tablesData) { delete this.tablesData[name]; - delete this.tablesKnown_[name]; } } } @@ -188,11 +144,8 @@ PROT_ListManager.prototype.setKeyUrl = function(url) { */ PROT_ListManager.prototype.registerTable = function(tableName, opt_requireMac) { - var table = new PROT_VersionParser(tableName, 1, -1, opt_requireMac); - if (!table) - return false; - this.tablesKnown_[tableName] = table; - this.tablesData[tableName] = newUrlClassifierTable(tableName); + this.tablesData[tableName] = {}; + this.tablesData[tableName].needsUpdate = false; return true; } @@ -203,7 +156,7 @@ PROT_ListManager.prototype.registerTable = function(tableName, */ PROT_ListManager.prototype.enableUpdate = function(tableName) { var changed = false; - var table = this.tablesKnown_[tableName]; + var table = this.tablesData[tableName]; if (table) { G_Debug(this, "Enabling table updates for " + tableName); table.needsUpdate = true; @@ -220,7 +173,7 @@ PROT_ListManager.prototype.enableUpdate = function(tableName) { */ PROT_ListManager.prototype.disableUpdate = function(tableName) { var changed = false; - var table = this.tablesKnown_[tableName]; + var table = this.tablesData[tableName]; if (table) { G_Debug(this, "Disabling table updates for " + tableName); table.needsUpdate = false; @@ -235,14 +188,9 @@ PROT_ListManager.prototype.disableUpdate = function(tableName) { * Determine if we have some tables that need updating. */ PROT_ListManager.prototype.requireTableUpdates = function() { - for (var type in this.tablesKnown_) { - // All tables with a major of 0 are internal tables that we never - // update remotely. 
- if (this.tablesKnown_[type].major == 0) - continue; - + for (var type in this.tablesData) { // Tables that need updating even if other tables dont require it - if (this.tablesKnown_[type].needsUpdate) + if (this.tablesData[type].needsUpdate) return true; } @@ -263,6 +211,22 @@ PROT_ListManager.prototype.maybeStartManagingUpdates = function() { this.maybeToggleUpdateChecking(); } +PROT_ListManager.prototype.kickoffUpdate_ = function (tableData) +{ + this.startingUpdate_ = false; + // If the user has never downloaded tables, do the check now. + // If the user has tables, add a fuzz of a few minutes. + var initialUpdateDelay = 3000; + if (tableData != "") { + // Add a fuzz of 0-5 minutes. + initialUpdateDelay += Math.floor(Math.random() * (5 * 60 * 1000)); + } + + this.currentUpdateChecker_ = + new G_Alarm(BindToObject(this.checkForUpdates, this), + initialUpdateDelay); +} + /** * Determine if we have any tables that require updating. Different * Wardens may call us with new tables that need to be updated. @@ -281,26 +245,10 @@ PROT_ListManager.prototype.maybeToggleUpdateChecking = function() { // Multiple warden can ask us to reenable updates at the same time, but we // really just need to schedule a single update. - if (!this.currentUpdateChecker_) { - // If the user has never downloaded tables, do the check now. - // If the user has tables, add a fuzz of a few minutes. - this.loadTableVersions_(); - var hasTables = false; - for (var table in this.tablesKnown_) { - if (this.tablesKnown_[table].minor != -1) { - hasTables = true; - break; - } - } - - var initialUpdateDelay = 3000; - if (hasTables) { - // Add a fuzz of 0-5 minutes. 
- initialUpdateDelay += Math.floor(Math.random() * (5 * 60 * 1000)); - } - this.currentUpdateChecker_ = - new G_Alarm(BindToObject(this.checkForUpdates, this), - initialUpdateDelay); + if (!this.currentUpdateChecker && !this.startingUpdate_) { + this.startingUpdate_ = true; + // check the current state of tables in the database + this.dbService_.getTables(BindToObject(this.kickoffUpdate_, this)); } } else { G_Debug(this, "Stopping managing lists (if currently active)"); @@ -363,116 +311,19 @@ PROT_ListManager.prototype.stopUpdateChecker = function() { * value in the table corresponding to key. If the table name does not * exist, we return false, too. */ -PROT_ListManager.prototype.safeExists = function(table, key, callback) { +PROT_ListManager.prototype.safeLookup = function(key, callback) { try { - G_Debug(this, "safeExists: " + table + ", " + key); - var map = this.tablesData[table]; - map.exists(key, callback); + G_Debug(this, "safeLookup: " + key); + var cb = new QueryAdapter(callback); + this.dbService_.lookup(key, + BindToObject(cb.handleResponse, cb), + true); } catch(e) { - G_Debug(this, "safeExists masked failure for " + table + ", key " + key + ": " + e); - callback.handleEvent(false); + G_Debug(this, "safeLookup masked failure for key " + key + ": " + e); + callback.handleEvent(""); } } -/** - * We store table versions in user prefs. This method pulls the values out of - * the user prefs and into the tablesKnown objects. - */ -PROT_ListManager.prototype.loadTableVersions_ = function() { - // Pull values out of prefs. 
- var prefBase = kTableVersionPrefPrefix; - for (var table in this.tablesKnown_) { - var version = this.prefs_.getPref(prefBase + table, "1.-1"); - G_Debug(this, "loadTableVersion " + table + ": " + version); - var tokens = version.split("."); - G_Assert(this, tokens.length == 2, "invalid version number"); - - this.tablesKnown_[table].major = tokens[0]; - this.tablesKnown_[table].minor = tokens[1]; - } -} - -/** - * Callback from db update service. As new tables are added to the db, - * this callback is fired so we can update the version number. - * @param versionString String containing the table update response from the - * server - */ -PROT_ListManager.prototype.setTableVersion_ = function(versionString) { - G_Debug(this, "Got version string: " + versionString); - var versionParser = new PROT_VersionParser(""); - if (versionParser.fromString(versionString)) { - var tableName = versionParser.type; - var versionNumber = versionParser.versionString(); - var prefBase = kTableVersionPrefPrefix; - - this.prefs_.setPref(prefBase + tableName, versionNumber); - - if (!this.tablesKnown_[tableName]) { - this.tablesKnown_[tableName] = versionParser; - } else { - this.tablesKnown_[tableName].ImportVersion(versionParser); - } - - if (!this.tablesData[tableName]) - this.tablesData[tableName] = newUrlClassifierTable(tableName); - } - - // Since this is called from the update server, it means there was - // a successful http request. Make sure to notify the request backoff - // object. - this.requestBackoff_.noteServerResponse(200 /* ok */); -} - -/** - * Prepares a URL to fetch upates from. Format is a squence of - * type:major:minor, fields - * - * @param url The base URL to which query parameters are appended; assumes - * already has a trailing ? - * @returns the URL that we should request the table update from. 
- */ -PROT_ListManager.prototype.getRequestURL_ = function(url) { - url += "version="; - var firstElement = true; - var requestMac = false; - - for (var type in this.tablesKnown_) { - // All tables with a major of 0 are internal tables that we never - // update remotely. - if (this.tablesKnown_[type].major == 0) - continue; - - // Check if the table needs updating - if (this.tablesKnown_[type].needsUpdate == false) - continue; - - if (!firstElement) { - url += "," - } else { - firstElement = false; - } - url += type + ":" + this.tablesKnown_[type].toUrl(); - - if (this.tablesKnown_[type].requireMac) - requestMac = true; - } - - // Request a mac only if at least one of the tables to be updated requires - // it - if (requestMac) { - // Add the wrapped key for requesting macs - if (!this.urlCrypto_) - this.urlCrypto_ = new PROT_UrlCrypto(); - - url += "&wrkey=" + - encodeURIComponent(this.urlCrypto_.getManager().getWrappedKey()); - } - - G_Debug(this, "getRequestURL returning: " + url); - return url; -} - /** * Updates our internal tables from the update server * @@ -492,56 +343,87 @@ PROT_ListManager.prototype.checkForUpdates = function() { if (!this.requestBackoff_.canMakeRequest()) return false; - // Check to make sure our tables still exist (maybe the db got corrupted or - // the user deleted the file). If not, we need to reset the table version - // before sending the update check. - var tableNames = []; - for (var tableName in this.tablesKnown_) { - tableNames.push(tableName); - } - var dbService = Cc["@mozilla.org/url-classifier/dbservice;1"] - .getService(Ci.nsIUrlClassifierDBService); - dbService.checkTables(tableNames.join(","), - BindToObject(this.makeUpdateRequest_, this)); + // Grab the current state of the tables from the database + this.dbService_.getTables(BindToObject(this.makeUpdateRequest_, this)); return true; } /** * Method that fires the actual HTTP update request. * First we reset any tables that have disappeared. 
- * @param tableNames String comma separated list of tables that - * don't exist + * @param tableData List of table data already in the database, in the form + * tablename;\n */ -PROT_ListManager.prototype.makeUpdateRequest_ = function(tableNames) { - // Clear prefs that track table version if they no longer exist in the db. - var tables = tableNames.split(","); - for (var i = 0; i < tables.length; ++i) { - G_Debug(this, "Table |" + tables[i] + "| no longer exists, clearing pref."); - this.prefs_.clearPref(kTableVersionPrefPrefix + tables[i]); +PROT_ListManager.prototype.makeUpdateRequest_ = function(tableData) { + var tableNames = {}; + for (var tableName in this.tablesData) { + tableNames[tableName] = true; } - // Ok, now reload the table version. - this.loadTableVersions_(); + var request = ""; + + // For each table already in the database, include the chunk data from + // the database + var lines = tableData.split("\n"); + for (var i = 0; i < lines.length; i++) { + var fields = lines[i].split(";"); + if (tableNames[fields[0]]) { + request += lines[i] + "\n"; + delete tableNames[fields[0]]; + } + } + + // For each requested table that didn't have chunk data in the database, + // request it fresh + for (var tableName in tableNames) { + request += tableName + ";\n"; + } G_Debug(this, 'checkForUpdates: scheduling request..'); - var url = this.getRequestURL_(this.updateserverURL_); var streamer = Cc["@mozilla.org/url-classifier/streamupdater;1"] .getService(Ci.nsIUrlClassifierStreamUpdater); try { - streamer.updateUrl = url; + streamer.updateUrl = this.updateserverURL_; } catch (e) { G_Debug(this, 'invalid url'); return; } - if (!streamer.downloadUpdates(BindToObject(this.setTableVersion_, this), + if (!streamer.downloadUpdates(request, + BindToObject(this.updateSuccess_, this), + BindToObject(this.updateError_, this), BindToObject(this.downloadError_, this))) { G_Debug(this, "pending update, wait until later"); } } /** - * Callback function if there's a download 
error. + * Callback function if the update request succeeded. + * @param waitForUpdate String The number of seconds that the client should + * wait before requesting again. + */ +PROT_ListManager.prototype.updateSuccess_ = function(waitForUpdate) { + G_Debug(this, "update success: " + waitForUpdate); + if (waitForUpdate) { + var delay = parseInt(waitForUpdate, 10); + // As long as the delay is something sane (5 minutes or more), update + // our delay time for requesting updates + if (delay >= (5 * 60) && this.updateChecker_) + this.updateChecker_.setDelay(delay * 1000); + } +} + +/** + * Callback function if applying the update failed. + * @param result String The error code of the failure + */ +PROT_ListManager.prototype.updateError_ = function(result) { + G_Debug(this, "update error: " + result); + // XXX: there was some trouble applying the updates. +} + +/** + * Callback function when the download failed * @param status String http status or an empty string if connection refused. */ PROT_ListManager.prototype.downloadError_ = function(status) { @@ -568,17 +450,3 @@ PROT_ListManager.prototype.QueryInterface = function(iid) { Components.returnCode = Components.results.NS_ERROR_NO_INTERFACE; return null; } - -// A simple factory function that creates nsIUrlClassifierTable instances based -// on a name. The name is a string of the format -// provider_name-semantic_type-table_type. For example, goog-white-enchash -// or goog-black-url. 
-function newUrlClassifierTable(name) { - G_Debug("protfactory", "Creating a new nsIUrlClassifierTable: " + name); - var tokens = name.split('-'); - var type = tokens[2]; - var table = Cc['@mozilla.org/url-classifier/table;1?type=' + type] - .createInstance(Ci.nsIUrlClassifierTable); - table.name = name; - return table; -} diff --git a/toolkit/components/url-classifier/content/moz/alarm.js b/toolkit/components/url-classifier/content/moz/alarm.js index e6438e2b4446..4b428e5199ce 100644 --- a/toolkit/components/url-classifier/content/moz/alarm.js +++ b/toolkit/components/url-classifier/content/moz/alarm.js @@ -134,6 +134,10 @@ G_Alarm.prototype.notify = function(timer) { return ret; } +G_Alarm.prototype.setDelay = function(delay) { + this.timer_.delay = delay; +} + /** * XPCOM cruft */ diff --git a/toolkit/components/url-classifier/public/Makefile.in b/toolkit/components/url-classifier/public/Makefile.in index 65193d2a5eb1..f12d14423400 100644 --- a/toolkit/components/url-classifier/public/Makefile.in +++ b/toolkit/components/url-classifier/public/Makefile.in @@ -10,7 +10,6 @@ XPIDL_MODULE = url-classifier XPIDLSRCS = nsIUrlClassifierDBService.idl \ nsIUrlClassifierStreamUpdater.idl \ - nsIUrlClassifierTable.idl \ nsIUrlClassifierUtils.idl \ nsIUrlListManager.idl \ $(NULL) diff --git a/toolkit/components/url-classifier/public/nsIUrlClassifierDBService.idl b/toolkit/components/url-classifier/public/nsIUrlClassifierDBService.idl index c6102c027d8e..bd99c0ff2bbd 100644 --- a/toolkit/components/url-classifier/public/nsIUrlClassifierDBService.idl +++ b/toolkit/components/url-classifier/public/nsIUrlClassifierDBService.idl @@ -49,32 +49,34 @@ interface nsIUrlClassifierCallback : nsISupports { * It provides async methods for querying and updating the database. As the * methods complete, they call the callback function. 
*/ -[scriptable, uuid(211d5360-4af6-4a1d-99c1-926d35861eaf)] +[scriptable, uuid(10928bf5-e18d-4086-854b-6c4006f2b009)] interface nsIUrlClassifierDBService : nsISupports { /** - * Looks up a key in the database. After it finds a value, it calls - * callback with the value as the first param. If the key is not in - * the db or the table does not exist, the callback is called with - * an empty string parameter. + * Looks up a key in the database. + * + * @param spec: The URL to search for. This URL will be canonicalized + * by the service. + * @param c: The callback will be called with a comma-separated list + * of tables to which the key belongs. + * @param needsProxy: Should be true if the callback needs to be called + * in the main thread, false if the callback is threadsafe. */ - void exists(in ACString tableName, in ACString key, - in nsIUrlClassifierCallback c); + void lookup(in ACString spec, + in nsIUrlClassifierCallback c, + in boolean needsProxy); /** - * Checks to see if the tables exist. tableNames is a comma separated list - * of table names to check. The callback is called with a comma separated - * list of tables that no longer exist (either the db is corrupted or the - * user deleted the file). + * Lists the tables along with which chunks are available in each table. + * This list is in the format of the request body: + * tablename;chunkdata\n + * tablename2;chunkdata2\n + * + * For example: + * goog-phish-regexp;a:10,14,30-40s:56,67 + * goog-white-regexp;a:1-3,5 */ - void checkTables(in ACString tableNames, in nsIUrlClassifierCallback c); - - /** - * Updates the table in the background. Calls callback after each table - * completes processing with the new table line as the parameter. This - * allows us to keep track of the table version in our main thread. 
- */ - void updateTables(in ACString updateString, in nsIUrlClassifierCallback c); + void getTables(in nsIUrlClassifierCallback c); //////////////////////////////////////////////////////////////////////////// // Incremental update methods. These are named to match similar methods @@ -89,10 +91,12 @@ interface nsIUrlClassifierDBService : nsISupports // interface, but it's tricky because of XPCOM proxies. /** - * Finish an incremental update. This commits any pending tables and - * calls the callback for each completed table. + * Finish an incremental update. Calls successCallback with the + * requested delay before the next update, or failureCallback with a + * result code. */ - void finish(in nsIUrlClassifierCallback c); + void finish(in nsIUrlClassifierCallback successCallback, + in nsIUrlClassifierCallback failureCallback); /** * Cancel an incremental update. This rolls back and pending changes. diff --git a/toolkit/components/url-classifier/public/nsIUrlClassifierStreamUpdater.idl b/toolkit/components/url-classifier/public/nsIUrlClassifierStreamUpdater.idl index 53fc5517ecdf..1e58e58d2a88 100644 --- a/toolkit/components/url-classifier/public/nsIUrlClassifierStreamUpdater.idl +++ b/toolkit/components/url-classifier/public/nsIUrlClassifierStreamUpdater.idl @@ -44,7 +44,7 @@ * downloading the whole update and then updating the sqlite database, we * update tables as the data is streaming in. */ -[scriptable, uuid(d9277fa4-7d51-4175-bd4e-546c080a83bf)] +[scriptable, uuid(adf0dfaa-ce91-4cf2-ab15-f5810408e2ec)] interface nsIUrlClassifierStreamUpdater : nsISupports { /** @@ -56,11 +56,14 @@ interface nsIUrlClassifierStreamUpdater : nsISupports * Try to download updates from updateUrl. Only one instance of this * runs at a time, so we return false if another instance is already * running. - * @param aTableCallback Called once for each table that we successfully - * download with the table header as the parameter. 
- * @param aErrorCallback Called if we get an http error or a connection - * refused. + * @param aRequestBody The body for the request. + * @param aSuccessCallback Called after a successful update. + * @param aUpdateErrorCallback Called for problems applying the update + * @param aDownloadErrorCallback Called if we get an http error or a + * connection refused error. */ - boolean downloadUpdates(in nsIUrlClassifierCallback aTableCallback, - in nsIUrlClassifierCallback aErrorCallback); + boolean downloadUpdates(in ACString aRequestBody, + in nsIUrlClassifierCallback aSuccessCallback, + in nsIUrlClassifierCallback aUpdateErrorCallback, + in nsIUrlClassifierCallback aDownloadErrorCallback); }; diff --git a/toolkit/components/url-classifier/public/nsIUrlClassifierUtils.idl b/toolkit/components/url-classifier/public/nsIUrlClassifierUtils.idl index 2e68548c0250..0e329642423b 100644 --- a/toolkit/components/url-classifier/public/nsIUrlClassifierUtils.idl +++ b/toolkit/components/url-classifier/public/nsIUrlClassifierUtils.idl @@ -39,27 +39,18 @@ * Some utility methods used by the url classifier. */ -[scriptable, uuid(89ea43b0-a23f-4db2-8d23-6d90dc55f67a)] +interface nsIURI; + +[scriptable, uuid(e4f0e59c-b922-48b0-a7b6-1735c1f96fed)] interface nsIUrlClassifierUtils : nsISupports { /** - * Canonicalize a URL. DON'T USE THIS DIRECTLY. Use - * PROT_EnchashDecrypter.prototype.getCanonicalUrl instead. This method - * url-decodes a string, but it doesn't normalize the hostname. The method - * in EnchashDecrypter first calls this method, then normalizes the hostname. + * Get the lookup string for a given URI. This normalizes the hostname, + * url-decodes the string, and strips off the protocol. * - * @param url String to canonicalize + * @param uri URI to get the lookup key for. * - * @returns String containing the canonicalized url (maximally url-decoded, - * then specially url-encoded) + * @returns String containing the canonicalized URI. 
*/ - ACString canonicalizeURL(in ACString url); - - /** - * When canonicalizing hostnames, the final step is to url escape everything that - * is not alphanumeric or hyphen or dot. The existing methods (escape, - * encodeURIComponent and encodeURI are close, but not exactly what we want - * so we write our own function to do this. - */ - ACString escapeHostname(in ACString hostname); + ACString getKeyForURI(in nsIURI uri); }; diff --git a/toolkit/components/url-classifier/public/nsIUrlListManager.idl b/toolkit/components/url-classifier/public/nsIUrlListManager.idl index 45661a731753..8da4912e64e5 100644 --- a/toolkit/components/url-classifier/public/nsIUrlListManager.idl +++ b/toolkit/components/url-classifier/public/nsIUrlListManager.idl @@ -39,16 +39,17 @@ #include "nsISupports.idl" /** - * Interface for a class that manages updates of multiple nsIUrlClassifierTables. + * Interface for a class that manages updates of the url classifier database. */ // Interface for JS function callbacks -[scriptable, function, uuid(ba913c5c-13d6-41eb-83c1-de2f4165a516)] +[scriptable, function, uuid(fa4caf12-d057-4e7e-81e9-ce066ceee90b)] interface nsIUrlListManagerCallback : nsISupports { - void handleEvent(in boolean value); + void handleEvent(in ACString value); }; -[scriptable, uuid(d39982d6-da4f-4a27-8d91-f9c7b179aa33)] + +[scriptable, uuid(874d6c95-fb8b-4f89-b36d-85fe267ab356)] interface nsIUrlListManager : nsISupports { /** @@ -82,10 +83,12 @@ interface nsIUrlListManager : nsISupports void disableUpdate(in ACString tableName); /** - * Lookup a key in a table. Should not raise exceptions. Calls - * the callback function with a single parameter: true if the key - * is in the table, false if it isn't. + * Lookup a key. Should not raise exceptions. Calls the callback + * function with a comma-separated list of tables to which the key + * belongs. 
*/ - void safeExists(in ACString tableName, in ACString key, + void safeLookup(in ACString key, in nsIUrlListManagerCallback cb); + + void checkForUpdates(); }; diff --git a/toolkit/components/url-classifier/src/Makefile.in b/toolkit/components/url-classifier/src/Makefile.in index 3e2ba368f514..8a8bb1f024d1 100644 --- a/toolkit/components/url-classifier/src/Makefile.in +++ b/toolkit/components/url-classifier/src/Makefile.in @@ -16,6 +16,7 @@ REQUIRES = necko \ storage \ string \ xpcom \ + $(ZLIB_REQUIRES) \ $(NULL) CPPSRCS = \ @@ -25,14 +26,16 @@ CPPSRCS = \ $(NULL) LOCAL_INCLUDES = \ - -I$(srcdir)/../../build + -I$(srcdir)/../../build \ $(NULL) # Same as JS components that are run through the pre-processor. -EXTRA_PP_COMPONENTS = nsUrlClassifierTable.js \ - nsUrlClassifierLib.js \ +EXTRA_PP_COMPONENTS = nsUrlClassifierLib.js \ nsUrlClassifierListManager.js \ $(NULL) include $(topsrcdir)/config/rules.mk + +export:: $(topsrcdir)/security/nss/lib/freebl/sha512.c + $(INSTALL) $^ . diff --git a/toolkit/components/url-classifier/src/nsUrlClassifierDBService.cpp b/toolkit/components/url-classifier/src/nsUrlClassifierDBService.cpp index d8e07ea0abfe..fa408505b016 100644 --- a/toolkit/components/url-classifier/src/nsUrlClassifierDBService.cpp +++ b/toolkit/components/url-classifier/src/nsUrlClassifierDBService.cpp @@ -22,6 +22,7 @@ * Contributor(s): * Tony Chang (original author) * Brett Wilson + * Dave Camp * * Alternatively, the contents of this file may be used under the terms of * either the GNU General Public License Version 2 or later (the "GPL"), or @@ -37,37 +38,86 @@ * * ***** END LICENSE BLOCK ***** */ +#include "nsCOMPtr.h" #include "mozIStorageService.h" #include "mozIStorageConnection.h" #include "mozIStorageStatement.h" +#include "mozStorageHelper.h" #include "mozStorageCID.h" #include "nsAppDirectoryServiceDefs.h" #include "nsAutoLock.h" -#include "nsCOMPtr.h" #include "nsCRT.h" +#include "nsICryptoHash.h" #include "nsIDirectoryService.h" #include 
"nsIObserverService.h" #include "nsIProperties.h" #include "nsIProxyObjectManager.h" #include "nsToolkitCompsCID.h" +#include "nsIUrlClassifierUtils.h" #include "nsUrlClassifierDBService.h" #include "nsString.h" #include "nsTArray.h" +#include "nsVoidArray.h" +#include "nsNetUtil.h" +#include "nsNetCID.h" #include "nsThreadUtils.h" #include "nsXPCOMStrings.h" #include "prlog.h" +#include "prlock.h" #include "prprf.h" +#include "zlib.h" + +/** + * The DBServices stores a set of Fragments. A fragment is one URL + * fragment containing two or more domain components and some number + * of path components. + * + * Fragment examples: + * example.com/ + * www.example.com/foo/bar + * www.mail.example.com/mail + * + * Fragments are described in "Simplified Regular Expression Lookup" + * section of the protocol document at + * http://code.google.com/p/google-safe-browsing/wiki/Protocolv2Spec + * + * A set of fragments is associated with a domain. The domain for a given + * fragment is the three-host-component domain of the fragment (two host + * components for URLs with only two components) with a trailing slash. + * So for the fragments listed above, the domains are example.com/, + * www.example.com/ and mail.example.com/. A collection of fragments for + * a given domain is referred to in this code as an Entry. + * + * Entries are associated with the table from which its fragments came. + * + * Fragments are added to the database in chunks. Each fragment in an entry + * keeps track of which chunk it came from, and as a chunk is added it keeps + * track of which entries contain its fragments. + * + * Fragments and domains are hashed in the database. The hash is described + * in the protocol document, but it's basically a truncated SHA256 hash. 
+ */ // NSPR_LOG_MODULES=UrlClassifierDbService:5 #if defined(PR_LOGGING) static const PRLogModuleInfo *gUrlClassifierDbServiceLog = nsnull; #define LOG(args) PR_LOG(gUrlClassifierDbServiceLog, PR_LOG_DEBUG, args) +#define LOG_ENABLED() PR_LOG_TEST(gUrlClassifierDbServiceLog, 4) #else #define LOG(args) +#define LOG_ENABLED() (PR_FALSE) #endif // Change filename each time we change the db schema. -#define DATABASE_FILENAME "urlclassifier2.sqlite" +#define DATABASE_FILENAME "urlclassifier3.sqlite" + +#define MAX_HOST_COMPONENTS 5 +#define MAX_PATH_COMPONENTS 4 + +// Updates will fail if fed chunks larger than this +#define MAX_CHUNK_SIZE (1024 * 1024) + +#define KEY_LENGTH 16 // Singleton instance. static nsUrlClassifierDBService* sUrlClassifierDBService; @@ -79,41 +129,252 @@ static nsIThread* gDbBackgroundThread = nsnull; // thread. static PRBool gShuttingDownThread = PR_FALSE; -static const char* kNEW_TABLE_SUFFIX = "_new"; +// ------------------------------------------------------------------------- +// Hash class implementation -// This maps A-M to N-Z and N-Z to A-M. All other characters are left alone. 
-// Copied from mailnews/mime/src/mimetext.cpp -static const unsigned char kRot13Table[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, - 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, - 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, - 59, 60, 61, 62, 63, 64, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, - 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 91, 92, 93, 94, 95, 96, - 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 97, 98, - 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 123, 124, 125, 126, - 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, - 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, - 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, - 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, - 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, - 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, - 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, - 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, - 247, 248, 249, 250, 251, 252, 253, 254, 255 }; +// A convenience wrapper around the 16-byte hash for a domain or fragment. 
-// Does an in place rotation of the line -static void -Rot13Line(nsCString &line) +struct nsUrlClassifierHash { - nsCString::iterator start, end; - line.BeginWriting(start); - line.EndWriting(end); - while (start != end) { - *start = kRot13Table[static_cast(*start)]; - ++start; + PRUint8 buf[KEY_LENGTH]; + + nsresult FromPlaintext(const nsACString& plainText, nsICryptoHash *hash); + void Assign(const nsACString& str); + + const PRBool operator==(const nsUrlClassifierHash& hash) const { + return (memcmp(buf, hash.buf, sizeof(buf)) == 0); } +}; + +nsresult +nsUrlClassifierHash::FromPlaintext(const nsACString& plainText, + nsICryptoHash *hash) +{ + // From the protocol doc: + // Each entry in the chunk is composed of the 128 most significant bits + // of the SHA 256 hash of a suffix/prefix expression. + + nsresult rv = hash->Init(nsICryptoHash::SHA256); + NS_ENSURE_SUCCESS(rv, rv); + + rv = hash->Update + (reinterpret_cast(plainText.BeginReading()), + plainText.Length()); + NS_ENSURE_SUCCESS(rv, rv); + + nsCAutoString hashed; + rv = hash->Finish(PR_FALSE, hashed); + NS_ENSURE_SUCCESS(rv, rv); + + NS_ASSERTION(hashed.Length() >= KEY_LENGTH, + "not enough characters in the hash"); + + memcpy(buf, hashed.BeginReading(), KEY_LENGTH); + + return NS_OK; } +void +nsUrlClassifierHash::Assign(const nsACString& str) +{ + NS_ASSERTION(str.Length() >= KEY_LENGTH, + "string must be at least KEY_LENGTH characters long"); + memcpy(buf, str.BeginReading(), KEY_LENGTH); +} + +// ------------------------------------------------------------------------- +// Entry class implementation + +// This class represents one entry in the classifier database. It is a list +// of fragments and their associated chunks for a given key/table pair. 
+class nsUrlClassifierEntry +{ +public: + nsUrlClassifierEntry() : mId(0) {} + ~nsUrlClassifierEntry() {} + + // Read an entry from a database statement + PRBool ReadStatement(mozIStorageStatement* statement); + + // Prepare a statement to write this entry to the database + nsresult BindStatement(mozIStorageStatement* statement); + + // Add a single fragment associated with a given chunk + PRBool AddFragment(const nsUrlClassifierHash& hash, PRUint32 chunkNum); + + // Add all the fragments in a given entry to this entry + PRBool Merge(const nsUrlClassifierEntry& entry); + + // Remove all fragments in a given entry from this entry + PRBool SubtractFragments(const nsUrlClassifierEntry& entry); + + // Remove all fragments associated with a given chunk + PRBool SubtractChunk(PRUint32 chunkNum); + + // Check if there is a fragment with this hash in the entry + PRBool HasFragment(const nsUrlClassifierHash& hash); + + // Clear out the entry structure + void Clear(); + + PRBool IsEmpty() { return mFragments.Length() == 0; } + + nsUrlClassifierHash mKey; + PRUint32 mId; + PRUint32 mTableId; + +private: + // Add all the fragments from a database blob + PRBool AddFragments(const PRUint8* blob, PRUint32 blobLength); + + // One hash/chunkID pair in the fragment + struct Fragment { + nsUrlClassifierHash hash; + PRUint32 chunkNum; + + PRInt32 Diff(const Fragment& fragment) const { + PRInt32 cmp = memcmp(hash.buf, fragment.hash.buf, sizeof(hash.buf)); + if (cmp != 0) return cmp; + return chunkNum - fragment.chunkNum; + } + + PRBool operator==(const Fragment& fragment) const { + return (Diff(fragment) == 0); + } + + PRBool operator<(const Fragment& fragment) const { + return (Diff(fragment) < 0); + } + }; + + nsTArray mFragments; +}; + +PRBool +nsUrlClassifierEntry::ReadStatement(mozIStorageStatement* statement) +{ + mId = statement->AsInt32(0); + + PRUint32 size; + const PRUint8* blob = statement->AsSharedBlob(1, &size); + if (!blob || (size != KEY_LENGTH)) + return PR_FALSE; + 
memcpy(mKey.buf, blob, KEY_LENGTH); + + blob = statement->AsSharedBlob(2, &size); + if (!AddFragments(blob, size)) + return PR_FALSE; + + mTableId = statement->AsInt32(3); + + return PR_TRUE; +} + +nsresult +nsUrlClassifierEntry::BindStatement(mozIStorageStatement* statement) +{ + nsresult rv; + + if (mId == 0) + rv = statement->BindNullParameter(0); + else + rv = statement->BindInt32Parameter(0, mId); + NS_ENSURE_SUCCESS(rv, rv); + + rv = statement->BindBlobParameter(1, mKey.buf, KEY_LENGTH); + NS_ENSURE_SUCCESS(rv, rv); + + // Store the entries as one big blob. + // This results in a database that isn't portable between machines. + rv = statement->BindBlobParameter + (2, reinterpret_cast(mFragments.Elements()), + mFragments.Length() * sizeof(Fragment)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = statement->BindInt32Parameter(3, mTableId); + NS_ENSURE_SUCCESS(rv, rv); + + return PR_TRUE; +} + +PRBool +nsUrlClassifierEntry::AddFragment(const nsUrlClassifierHash& hash, + PRUint32 chunkNum) +{ + Fragment* fragment = mFragments.AppendElement(); + if (!fragment) + return PR_FALSE; + + fragment->hash = hash; + fragment->chunkNum = chunkNum; + + return PR_TRUE; +} + +PRBool +nsUrlClassifierEntry::AddFragments(const PRUint8* blob, PRUint32 blobLength) +{ + NS_ASSERTION(blobLength % sizeof(Fragment) == 0, + "Fragment blob not the right length"); + Fragment* fragment = mFragments.AppendElements + (reinterpret_cast(blob), blobLength / sizeof(Fragment)); + return (fragment != nsnull); +} + +PRBool +nsUrlClassifierEntry::Merge(const nsUrlClassifierEntry& entry) +{ + Fragment* fragment = mFragments.AppendElements(entry.mFragments); + return (fragment != nsnull); +} + +PRBool +nsUrlClassifierEntry::SubtractFragments(const nsUrlClassifierEntry& entry) +{ + for (PRUint32 i = 0; i < entry.mFragments.Length(); i++) { + for (PRUint32 j = 0; j < mFragments.Length(); j++) { + if (mFragments[j].hash == entry.mFragments[i].hash) { + mFragments.RemoveElementAt(j); + break; + } + } + } + + 
return PR_TRUE; +} + +PRBool +nsUrlClassifierEntry::SubtractChunk(PRUint32 chunkNum) +{ + PRUint32 i = 0; + while (i < mFragments.Length()) { + if (mFragments[i].chunkNum == chunkNum) + mFragments.RemoveElementAt(i); + else + i++; + } + + return PR_TRUE; +} + +PRBool +nsUrlClassifierEntry::HasFragment(const nsUrlClassifierHash& hash) +{ + for (PRUint32 i = 0; i < mFragments.Length(); i++) { + const Fragment& fragment = mFragments[i]; + if (fragment.hash == hash) + return PR_TRUE; + } + + return PR_FALSE; +} + +void +nsUrlClassifierEntry::Clear() +{ + mId = 0; + mFragments.Clear(); +} // ------------------------------------------------------------------------- // Actual worker implemenatation @@ -126,6 +387,13 @@ public: NS_DECL_NSIURLCLASSIFIERDBSERVICE NS_DECL_NSIURLCLASSIFIERDBSERVICEWORKER + // Initialize, called in the main thread + nsresult Init(); + + // Queue a lookup for the worker to perform, called in the main thread. + nsresult QueueLookup(const nsACString& lookupKey, + nsIUrlClassifierCallback* callback); + private: // No subclassing ~nsUrlClassifierDBServiceWorker(); @@ -133,83 +401,498 @@ private: // Disallow copy constructor nsUrlClassifierDBServiceWorker(nsUrlClassifierDBServiceWorker&); - // Table names have hyphens in them, which SQL doesn't allow, - // so we convert them to underscores. - void GetDbTableName(const nsACString& aTableName, nsCString* aDbTableName); - // Try to open the db, DATABASE_FILENAME. nsresult OpenDb(); - // Create a table in the db if it doesn't exist. - nsresult MaybeCreateTable(const nsCString& aTableName); + // Create table in the db if they don't exist. + nsresult MaybeCreateTables(mozIStorageConnection* connection); - // Drop a table if it exists. - nsresult MaybeDropTable(const nsCString& aTableName); + nsresult GetTableName(PRUint32 tableId, nsACString& table); + nsresult GetTableId(const nsACString& table, PRUint32* tableId); - // If this is not an update request, swap the new table - // in for the old table. 
- nsresult MaybeSwapTables(const nsCString& aVersionLine); + // Read the entry for a given key/table from the database + nsresult ReadEntry(const nsUrlClassifierHash& key, + PRUint32 tableId, + nsUrlClassifierEntry& entry); - // Parse a version string of the form [table-name #.###] or - // [table-name #.### update] and return the table name and - // whether or not it's an update. - nsresult ParseVersionString(const nsCSubstring& aLine, - nsCString* aTableName, - PRBool* aIsUpdate); + // Read the entry with a given ID from the database + nsresult ReadEntry(PRUint32 id, nsUrlClassifierEntry& entry); - // Handle a new table line of the form [table-name #.####]. We create the - // table if it doesn't exist and set the aTableName, aUpdateStatement, - // and aDeleteStatement. - nsresult ProcessNewTable(const nsCSubstring& aLine, - nsCString* aTableName, - mozIStorageStatement** aUpdateStatement, - mozIStorageStatement** aDeleteStatement); + // Remove an entry from the database + nsresult DeleteEntry(nsUrlClassifierEntry& entry); - // Handle an add or remove line. We execute additional update or delete - // statements. - nsresult ProcessUpdateTable(const nsCSubstring& aLine, - const nsCString& aTableName, - mozIStorageStatement* aUpdateStatement, - mozIStorageStatement* aDeleteStatement); + // Write an entry to the database + nsresult WriteEntry(nsUrlClassifierEntry& entry); + + // Decompress a zlib'ed chunk (used for -exp tables) + nsresult InflateChunk(nsACString& chunk); + + // Expand a chunk into its individual entries + nsresult GetChunkEntries(const nsACString& table, + PRUint32 tableId, + PRUint32 chunkNum, + nsACString& chunk, + nsTArray& entries); + + // Expand a stringified chunk list into an array of ints. + nsresult ParseChunkList(const nsACString& chunkStr, + nsTArray& chunks); + + // Join an array of ints into a stringified chunk list. 
+ nsresult JoinChunkList(nsTArray& chunks, nsCString& chunkStr); + + // List the add/subtract chunks that have been applied to a table + nsresult GetChunkLists(PRUint32 tableId, + nsACString& addChunks, + nsACString& subChunks); + + // Set the list of add/subtract chunks that have been applied to a table + nsresult SetChunkLists(PRUint32 tableId, + const nsACString& addChunks, + const nsACString& subChunks); + + // Add a list of entries to the database, merging with + // existing entries as necessary + nsresult AddChunk(PRUint32 tableId, PRUint32 chunkNum, + nsTArray& entries); + + // Expire an add chunk + nsresult ExpireAdd(PRUint32 tableId, PRUint32 chunkNum); + + // Subtract a list of entries from the database + nsresult SubChunk(PRUint32 tableId, PRUint32 chunkNum, + nsTArray& entries); + + // Expire a subtract chunk + nsresult ExpireSub(PRUint32 tableId, PRUint32 chunkNum); + + // Handle line-oriented control information from a stream update + nsresult ProcessResponseLines(PRBool* done); + // Handle chunk data from a stream update + nsresult ProcessChunk(PRBool* done); + + // Reset an in-progress update + void ResetUpdate(); + + // take a lookup string (www.hostname.com/path/to/resource.html) and + // expand it into the set of fragments that should be searched for in an + // entry + nsresult GetLookupFragments(const nsCSubstring& spec, + nsTArray& fragments); + + // Get the database key for a given URI. This is the top three + // domain components if they exist, otherwise the top two. + // hostname.com/foo/bar -> hostname + // mail.hostname.com/foo/bar -> mail.hostname.com + // www.mail.hostname.com/foo/bar -> mail.hostname.com + nsresult GetKey(const nsACString& spec, nsUrlClassifierHash& hash); + + // Look for a given lookup string (www.hostname.com/path/to/resource.html) + // in the entries at the given key. Return the tableids found. 
+ nsresult CheckKey(const nsCSubstring& spec, + const nsUrlClassifierHash& key, + nsTArray& tables); + + // Perform a classifier lookup for a given url. + nsresult DoLookup(const nsACString& spec, nsIUrlClassifierCallback* c); + + // Handle any queued-up lookups. We call this function during long-running + // update operations to prevent lookups from blocking for too long. + nsresult HandlePendingLookups(); + + nsCOMPtr mDBFile; + + nsCOMPtr mCryptoHash; // Holds a connection to the Db. We lazily initialize this because it has // to be created in the background thread (currently mozStorageConnection // isn't thread safe). - mozIStorageConnection* mConnection; + nsCOMPtr mConnection; - // True if we're in the middle of a streaming update. - PRBool mHasPendingUpdate; + nsCOMPtr mLookupStatement; + nsCOMPtr mLookupWithTableStatement; + nsCOMPtr mLookupWithIDStatement; - // For incremental updates, keep track of tables that have been updated. - // When finish() is called, we go ahead and pass these update lines to - // the callback. - nsTArray mTableUpdateLines; + nsCOMPtr mUpdateStatement; + nsCOMPtr mDeleteStatement; + + nsCOMPtr mAddChunkEntriesStatement; + nsCOMPtr mGetChunkEntriesStatement; + nsCOMPtr mDeleteChunkEntriesStatement; + + nsCOMPtr mGetChunkListsStatement; + nsCOMPtr mSetChunkListsStatement; + + nsCOMPtr mGetTablesStatement; + nsCOMPtr mGetTableIdStatement; + nsCOMPtr mGetTableNameStatement; + nsCOMPtr mInsertTableIdStatement; // We receive data in small chunks that may be broken in the middle of // a line. So we save the last partial line here. nsCString mPendingStreamUpdate; + + PRInt32 mUpdateWait; + + enum { + STATE_LINE, + STATE_CHUNK + } mState; + + enum { + CHUNK_ADD, + CHUNK_SUB + } mChunkType; + + PRUint32 mChunkNum; + PRUint32 mChunkLen; + + nsCString mUpdateTable; + PRUint32 mUpdateTableId; + + nsresult mUpdateStatus; + + // Pending lookups are stored in a queue for processing. The queue + // is protected by mPendingLookupLock. 
+ PRLock* mPendingLookupLock; + + class PendingLookup { + public: + nsCString mKey; + nsCOMPtr mCallback; + }; + + // list of pending lookups + nsTArray mPendingLookups; }; NS_IMPL_THREADSAFE_ISUPPORTS1(nsUrlClassifierDBServiceWorker, nsIUrlClassifierDBServiceWorker) nsUrlClassifierDBServiceWorker::nsUrlClassifierDBServiceWorker() - : mConnection(nsnull), mHasPendingUpdate(PR_FALSE), mTableUpdateLines() + : mUpdateStatus(NS_OK) + , mPendingLookupLock(nsnull) { } -nsUrlClassifierDBServiceWorker::~nsUrlClassifierDBServiceWorker() -{ - NS_ASSERTION(mConnection == nsnull, - "Db connection not closed, leaking memory! Call CloseDb " - "to close the connection."); -} +nsUrlClassifierDBServiceWorker::~nsUrlClassifierDBServiceWorker() +{ + NS_ASSERTION(!mConnection, + "Db connection not closed, leaking memory! Call CloseDb " + "to close the connection."); + if (mPendingLookupLock) + PR_DestroyLock(mPendingLookupLock); +} + +nsresult +nsUrlClassifierDBServiceWorker::Init() +{ + // Compute database filename + + // Because we dump raw integers into the database, this database isn't + // portable between machine types, so store it in the local profile dir. 
+ nsresult rv = NS_GetSpecialDirectory(NS_APP_USER_PROFILE_LOCAL_50_DIR, + getter_AddRefs(mDBFile)); + if (NS_FAILED(rv)) return rv; + + rv = mDBFile->Append(NS_LITERAL_STRING(DATABASE_FILENAME)); + NS_ENSURE_SUCCESS(rv, rv); + + mPendingLookupLock = PR_NewLock(); + if (!mPendingLookupLock) + return NS_ERROR_OUT_OF_MEMORY; + + ResetUpdate(); + + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::QueueLookup(const nsACString& spec, + nsIUrlClassifierCallback* callback) +{ + nsAutoLock lock(mPendingLookupLock); + + PendingLookup* lookup = mPendingLookups.AppendElement(); + if (!lookup) return NS_ERROR_OUT_OF_MEMORY; + + lookup->mKey = spec; + lookup->mCallback = callback; + + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::GetLookupFragments(const nsACString& spec, + nsTArray& fragments) +{ + fragments.Clear(); + + nsACString::const_iterator begin, end, iter; + spec.BeginReading(begin); + spec.EndReading(end); + + iter = begin; + if (!FindCharInReadable('/', iter, end)) { + return NS_OK; + } + + const nsCSubstring& host = Substring(begin, iter++); + const nsCSubstring& path = Substring(iter, end); + + /** + * From the protocol doc: + * For the hostname, the client will try at most 5 different strings. They + * are: + * a) The exact hostname of the url + * b) The 4 hostnames formed by starting with the last 5 components and + * successivly removing the leading component. The top-level component + * can be skipped. 
+ */ + nsCStringArray hosts; + hosts.AppendCString(host); + + host.BeginReading(begin); + host.EndReading(end); + int numComponents = 0; + while (RFindInReadable(NS_LITERAL_CSTRING("."), begin, end) && + numComponents < MAX_HOST_COMPONENTS) { + // don't bother checking toplevel domains + if (++numComponents >= 2) { + host.EndReading(iter); + hosts.AppendCString(Substring(end, iter)); + } + end = begin; + host.BeginReading(begin); + } + + /** + * From the protocol doc: + * For the path, the client will also try at most 5 different strings. + * They are: + * a) the exact path of the url + * b) the 4 paths formed by starting at the root (/) and + * successively appending path components, including a trailing + * slash. This behavior should only extend up to the next-to-last + * path component, that is, a trailing slash should never be + * appended that was not present in the original url. + */ + nsCStringArray paths; + paths.AppendCString(path); + + numComponents = 0; + path.BeginReading(begin); + path.EndReading(end); + iter = begin; + while (FindCharInReadable('/', iter, end) && + numComponents < MAX_PATH_COMPONENTS) { + iter++; + paths.AppendCString(Substring(begin, iter)); + numComponents++; + } + + /** + * "In addition to these, the client should look up the exact host + * and exact path, with a trailing '$' appended." 
*/ + nsCAutoString key; + key.Assign(spec); + key.Append('$'); + LOG(("Chking %s", key.get())); + + nsUrlClassifierHash* hash = fragments.AppendElement(); + if (!hash) return NS_ERROR_OUT_OF_MEMORY; + hash->FromPlaintext(key, mCryptoHash); + + for (int hostIndex = 0; hostIndex < hosts.Count(); hostIndex++) { + for (int pathIndex = 0; pathIndex < paths.Count(); pathIndex++) { + key.Assign(*hosts[hostIndex]); + key.Append('/'); + key.Append(*paths[pathIndex]); + LOG(("Chking %s", key.get())); + + hash = fragments.AppendElement(); + if (!hash) return NS_ERROR_OUT_OF_MEMORY; + hash->FromPlaintext(key, mCryptoHash); + } + } + + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::CheckKey(const nsACString& spec, + const nsUrlClassifierHash& hash, + nsTArray& tables) +{ + mozStorageStatementScoper lookupScoper(mLookupStatement); + + nsresult rv = mLookupStatement->BindBlobParameter + (0, hash.buf, KEY_LENGTH); + NS_ENSURE_SUCCESS(rv, rv); + + nsTArray fragments; + PRBool haveFragments = PR_FALSE; + + PRBool exists; + rv = mLookupStatement->ExecuteStep(&exists); + NS_ENSURE_SUCCESS(rv, rv); + while (exists) { + if (!haveFragments) { + rv = GetLookupFragments(spec, fragments); + NS_ENSURE_SUCCESS(rv, rv); + haveFragments = PR_TRUE; + } + + nsUrlClassifierEntry entry; + if (!entry.ReadStatement(mLookupStatement)) + return NS_ERROR_FAILURE; + + for (PRUint32 i = 0; i < fragments.Length(); i++) { + if (entry.HasFragment(fragments[i])) { + tables.AppendElement(entry.mTableId); + break; + } + } + + rv = mLookupStatement->ExecuteStep(&exists); + NS_ENSURE_SUCCESS(rv, rv); + } + + return NS_OK; +} + +/** + * Lookup up a key in the database is a two step process: + * + * a) First we look for any Entries in the database that might apply to this + * url. For each URL there are one or two possible domain names to check: + * the two-part domain name (example.com) and the three-part name + * (www.example.com). We check the database for both of these. 
+ * b) If we find any entries, we check the list of fragments for that entry + * against the possible subfragments of the URL as described in the + * "Simplified Regular Expression Lookup" section of the protocol doc. + */ +nsresult +nsUrlClassifierDBServiceWorker::DoLookup(const nsACString& spec, + nsIUrlClassifierCallback* c) +{ + if (gShuttingDownThread) { + c->HandleEvent(EmptyCString()); + return NS_ERROR_NOT_INITIALIZED; + } + + nsresult rv = OpenDb(); + if (NS_FAILED(rv)) { + c->HandleEvent(EmptyCString()); + return NS_ERROR_FAILURE; + } + +#if defined(PR_LOGGING) + PRIntervalTime clockStart = 0; + if (LOG_ENABLED()) { + clockStart = PR_IntervalNow(); + } +#endif + + nsACString::const_iterator begin, end, iter; + spec.BeginReading(begin); + spec.EndReading(end); + + iter = begin; + if (!FindCharInReadable('/', iter, end)) { + return NS_OK; + } + + const nsCSubstring& host = Substring(begin, iter++); + nsCStringArray hostComponents; + hostComponents.ParseString(PromiseFlatCString(host).get(), "."); + + if (hostComponents.Count() < 2) { + // no host or toplevel host, this won't match anything in the db + c->HandleEvent(EmptyCString()); + return NS_OK; + } + + // First check with two domain components + PRInt32 last = hostComponents.Count() - 1; + nsCAutoString lookupHost; + lookupHost.Assign(*hostComponents[last - 1]); + lookupHost.Append("."); + lookupHost.Append(*hostComponents[last]); + lookupHost.Append("/"); + nsUrlClassifierHash hash; + hash.FromPlaintext(lookupHost, mCryptoHash); + + // we ignore failures from CheckKey because we'd rather try to find + // more results than fail. 
+ nsTArray resultTables; + CheckKey(spec, hash, resultTables); + + // Now check with three domain components + if (hostComponents.Count() > 2) { + nsCAutoString lookupHost2; + lookupHost2.Assign(*hostComponents[last - 2]); + lookupHost2.Append("."); + lookupHost2.Append(lookupHost); + hash.FromPlaintext(lookupHost2, mCryptoHash); + + CheckKey(spec, hash, resultTables); + } + + nsCAutoString result; + for (PRUint32 i = 0; i < resultTables.Length(); i++) { + nsCAutoString tableName; + GetTableName(resultTables[i], tableName); + + // ignore GetTableName failures - we want to try to get as many of the + // matched tables as possible + if (!result.IsEmpty()) { + result.Append(','); + } + result.Append(tableName); + } + +#if defined(PR_LOGGING) + if (LOG_ENABLED()) { + PRIntervalTime clockEnd = PR_IntervalNow(); + LOG(("query took %dms\n", + PR_IntervalToMilliseconds(clockEnd - clockStart))); + } +#endif + + c->HandleEvent(result); + + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::HandlePendingLookups() +{ + nsAutoLock lock(mPendingLookupLock); + while (mPendingLookups.Length() > 0) { + PendingLookup lookup = mPendingLookups[0]; + mPendingLookups.RemoveElementAt(0); + lock.unlock(); + + DoLookup(lookup.mKey, lookup.mCallback); + + lock.lock(); + } + + return NS_OK; +} // Lookup a key in the db. 
NS_IMETHODIMP -nsUrlClassifierDBServiceWorker::Exists(const nsACString& tableName, - const nsACString& key, - nsIUrlClassifierCallback *c) +nsUrlClassifierDBServiceWorker::Lookup(const nsACString& spec, + nsIUrlClassifierCallback* c, + PRBool needsProxy) +{ + return HandlePendingLookups(); +} + +NS_IMETHODIMP +nsUrlClassifierDBServiceWorker::GetTables(nsIUrlClassifierCallback* c) { if (gShuttingDownThread) return NS_ERROR_NOT_INITIALIZED; @@ -220,169 +903,864 @@ nsUrlClassifierDBServiceWorker::Exists(const nsACString& tableName, return NS_ERROR_FAILURE; } - nsCAutoString dbTableName; - GetDbTableName(tableName, &dbTableName); + mozStorageStatementScoper scoper(mGetTablesStatement); - nsCOMPtr selectStatement; - nsCAutoString statement; - statement.AssignLiteral("SELECT value FROM "); - statement.Append(dbTableName); - statement.AppendLiteral(" WHERE key = ?1"); + nsCAutoString response; + PRBool hasMore; + while (NS_SUCCEEDED(rv = mGetTablesStatement->ExecuteStep(&hasMore)) && + hasMore) { + nsCAutoString val; + mGetTablesStatement->GetUTF8String(0, val); - rv = mConnection->CreateStatement(statement, - getter_AddRefs(selectStatement)); - - nsAutoString value; - // If CreateStatment failed, this probably means the table doesn't exist. - // That's ok, we just return an emptry string. - if (NS_SUCCEEDED(rv)) { - nsCString keyROT13(key); - Rot13Line(keyROT13); - rv = selectStatement->BindUTF8StringParameter(0, keyROT13); - NS_ENSURE_SUCCESS(rv, rv); - - PRBool hasMore = PR_FALSE; - rv = selectStatement->ExecuteStep(&hasMore); - // If the table has any columns, take the first value. 
- if (NS_SUCCEEDED(rv) && hasMore) { - selectStatement->GetString(0, value); + if (val.IsEmpty()) { + continue; } + + response.Append(val); + response.Append(';'); + + mGetTablesStatement->GetUTF8String(1, val); + + if (!val.IsEmpty()) { + response.Append("a:"); + response.Append(val); + } + + mGetTablesStatement->GetUTF8String(2, val); + if (!val.IsEmpty()) { + response.Append("s:"); + response.Append(val); + } + + response.Append('\n'); } - c->HandleEvent(NS_ConvertUTF16toUTF8(value)); + if (NS_FAILED(rv)) { + response.Truncate(); + } + + c->HandleEvent(response); + + return rv; +} + +nsresult +nsUrlClassifierDBServiceWorker::GetTableId(const nsACString& table, + PRUint32* tableId) +{ + mozStorageStatementScoper findScoper(mGetTableIdStatement); + + nsresult rv = mGetTableIdStatement->BindUTF8StringParameter(0, table); + NS_ENSURE_SUCCESS(rv, rv); + + PRBool exists; + rv = mGetTableIdStatement->ExecuteStep(&exists); + NS_ENSURE_SUCCESS(rv, rv); + if (exists) { + *tableId = mGetTableIdStatement->AsInt32(0); + return NS_OK; + } + + mozStorageStatementScoper insertScoper(mInsertTableIdStatement); + rv = mInsertTableIdStatement->BindUTF8StringParameter(0, table); + NS_ENSURE_SUCCESS(rv, rv); + + rv = mInsertTableIdStatement->Execute(); + NS_ENSURE_SUCCESS(rv, rv); + + PRInt64 rowId; + rv = mConnection->GetLastInsertRowID(&rowId); + NS_ENSURE_SUCCESS(rv, rv); + + if (rowId > PR_UINT32_MAX) + return NS_ERROR_FAILURE; + + *tableId = rowId; + return NS_OK; } -// We get a comma separated list of table names. For each table that doesn't -// exist, we return it in a comma separated list via the callback. 
-NS_IMETHODIMP -nsUrlClassifierDBServiceWorker::CheckTables(const nsACString & tableNames, - nsIUrlClassifierCallback *c) +nsresult +nsUrlClassifierDBServiceWorker::GetTableName(PRUint32 tableId, + nsACString& tableName) { - if (gShuttingDownThread) - return NS_ERROR_NOT_INITIALIZED; + mozStorageStatementScoper findScoper(mGetTableNameStatement); + nsresult rv = mGetTableNameStatement->BindInt32Parameter(0, tableId); + NS_ENSURE_SUCCESS(rv, rv); - nsresult rv = OpenDb(); - if (NS_FAILED(rv)) { - NS_ERROR("Unable to open database"); + PRBool exists; + rv = mGetTableNameStatement->ExecuteStep(&exists); + NS_ENSURE_SUCCESS(rv, rv); + if (!exists) return NS_ERROR_FAILURE; + + return mGetTableNameStatement->GetUTF8String(0, tableName); +} + +nsresult +nsUrlClassifierDBServiceWorker::InflateChunk(nsACString& chunk) +{ + nsCAutoString inflated; + char buf[4096]; + + const nsPromiseFlatCString& flat = PromiseFlatCString(chunk); + + z_stream stream; + memset(&stream, 0, sizeof(stream)); + stream.next_in = (Bytef*)flat.get(); + stream.avail_in = flat.Length(); + + if (inflateInit(&stream) != Z_OK) { return NS_ERROR_FAILURE; } - nsCAutoString changedTables; + int code; + do { + stream.next_out = (Bytef*)buf; + stream.avail_out = sizeof(buf); - // tablesNames is a comma separated list, so get each table name out for - // checking. 
- PRUint32 cur = 0; - PRInt32 next; - while (cur < tableNames.Length()) { - next = tableNames.FindChar(',', cur); - if (kNotFound == next) { - next = tableNames.Length(); - } - const nsCSubstring &tableName = Substring(tableNames, cur, next - cur); - cur = next + 1; + code = inflate(&stream, Z_NO_FLUSH); + PRUint32 numRead = sizeof(buf) - stream.avail_out; - nsCString dbTableName; - GetDbTableName(tableName, &dbTableName); - PRBool exists; - nsresult rv = mConnection->TableExists(dbTableName, &exists); - NS_ENSURE_SUCCESS(rv, rv); - if (!exists) { - if (changedTables.Length() > 0) - changedTables.Append(","); - changedTables.Append(tableName); + if (code == Z_OK || code == Z_STREAM_END) { + inflated.Append(buf, numRead); } + } while (code == Z_OK); + + inflateEnd(&stream); + + if (code != Z_STREAM_END) { + return NS_ERROR_FAILURE; } - c->HandleEvent(changedTables); + chunk = inflated; + return NS_OK; } -// Do a batch update of the database. After we complete processing a table, -// we call the callback with the table line. 
-NS_IMETHODIMP -nsUrlClassifierDBServiceWorker::UpdateTables(const nsACString& updateString, - nsIUrlClassifierCallback *c) +nsresult +nsUrlClassifierDBServiceWorker::ReadEntry(const nsUrlClassifierHash& hash, + PRUint32 tableId, + nsUrlClassifierEntry& entry) { - if (gShuttingDownThread) - return NS_ERROR_NOT_INITIALIZED; + entry.Clear(); - LOG(("Updating tables\n")); + mozStorageStatementScoper scoper(mLookupWithTableStatement); - nsresult rv = OpenDb(); - if (NS_FAILED(rv)) { - NS_ERROR("Unable to open database"); + nsresult rv = mLookupWithTableStatement->BindBlobParameter + (0, hash.buf, KEY_LENGTH); + NS_ENSURE_SUCCESS(rv, rv); + rv = mLookupWithTableStatement->BindInt32Parameter(1, tableId); + NS_ENSURE_SUCCESS(rv, rv); + + PRBool exists; + rv = mLookupWithTableStatement->ExecuteStep(&exists); + NS_ENSURE_SUCCESS(rv, rv); + + if (exists) { + if (!entry.ReadStatement(mLookupWithTableStatement)) + return NS_ERROR_FAILURE; + } else { + // New entry, initialize it + entry.mKey = hash; + entry.mTableId = tableId; + } + + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::ReadEntry(PRUint32 id, + nsUrlClassifierEntry& entry) +{ + entry.Clear(); + entry.mId = id; + + mozStorageStatementScoper scoper(mLookupWithIDStatement); + + nsresult rv = mLookupWithIDStatement->BindInt32Parameter(0, id); + NS_ENSURE_SUCCESS(rv, rv); + rv = mLookupWithIDStatement->BindInt32Parameter(0, id); + NS_ENSURE_SUCCESS(rv, rv); + + PRBool exists; + rv = mLookupWithIDStatement->ExecuteStep(&exists); + NS_ENSURE_SUCCESS(rv, rv); + + if (exists) { + if (!entry.ReadStatement(mLookupWithIDStatement)) + return NS_ERROR_FAILURE; + } + + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::DeleteEntry(nsUrlClassifierEntry& entry) +{ + if (entry.mId == 0) { + return NS_OK; + } + + mozStorageStatementScoper scoper(mDeleteStatement); + mDeleteStatement->BindInt32Parameter(0, entry.mId); + nsresult rv = mDeleteStatement->Execute(); + NS_ENSURE_SUCCESS(rv, rv); + + entry.mId = 0; 
+ + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::WriteEntry(nsUrlClassifierEntry& entry) +{ + mozStorageStatementScoper scoper(mUpdateStatement); + + if (entry.IsEmpty()) { + return DeleteEntry(entry); + } + + PRBool newEntry = (entry.mId == 0); + + nsresult rv = entry.BindStatement(mUpdateStatement); + NS_ENSURE_SUCCESS(rv, rv); + + rv = mUpdateStatement->Execute(); + NS_ENSURE_SUCCESS(rv, rv); + + if (newEntry) { + PRInt64 rowId; + rv = mConnection->GetLastInsertRowID(&rowId); + NS_ENSURE_SUCCESS(rv, rv); + + if (rowId > PR_UINT32_MAX) { + return NS_ERROR_FAILURE; + } + + entry.mId = rowId; + } + + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::GetKey(const nsACString& spec, + nsUrlClassifierHash& hash) +{ + nsACString::const_iterator begin, end, iter; + spec.BeginReading(begin); + spec.EndReading(end); + + iter = begin; + if (!FindCharInReadable('/', iter, end)) { + return NS_OK; + } + + const nsCSubstring& host = Substring(begin, iter++); + nsCStringArray hostComponents; + hostComponents.ParseString(PromiseFlatCString(host).get(), "."); + + if (hostComponents.Count() < 2) + return NS_ERROR_FAILURE; + + PRInt32 last = hostComponents.Count() - 1; + nsCAutoString lookupHost; + + if (hostComponents.Count() > 2) { + lookupHost.Append(*hostComponents[last - 2]); + lookupHost.Append("."); + } + + lookupHost.Append(*hostComponents[last - 1]); + lookupHost.Append("."); + lookupHost.Append(*hostComponents[last]); + lookupHost.Append("/"); + + return hash.FromPlaintext(lookupHost, mCryptoHash); +} + +nsresult +nsUrlClassifierDBServiceWorker::GetChunkEntries(const nsACString& table, + PRUint32 tableId, + PRUint32 chunkNum, + nsACString& chunk, + nsTArray& entries) +{ + nsresult rv; + if (StringEndsWith(table, NS_LITERAL_CSTRING("-exp"))) { + // regexp tables need to be ungzipped + rv = InflateChunk(chunk); + NS_ENSURE_SUCCESS(rv, rv); + } + + if (StringEndsWith(table, NS_LITERAL_CSTRING("-sha128"))) { + PRUint32 start = 0; + while (start + 
KEY_LENGTH + 1 <= chunk.Length()) { + nsUrlClassifierEntry* entry = entries.AppendElement(); + if (!entry) return NS_ERROR_OUT_OF_MEMORY; + + // first 16 bytes are the domain/key + entry->mKey.Assign(Substring(chunk, start, KEY_LENGTH)); + + start += KEY_LENGTH; + // then there is a one-byte count of fragments + PRUint8 numEntries = static_cast(chunk[start]); + start++; + + if (numEntries == 0) { + // if there are no fragments, the domain itself is treated as a + // fragment + entry->AddFragment(entry->mKey, chunkNum); + } else { + if (start + (numEntries * KEY_LENGTH) >= chunk.Length()) { + // there isn't as much data as they said there would be. + return NS_ERROR_FAILURE; + } + + for (PRUint8 i = 0; i < numEntries; i++) { + nsUrlClassifierHash hash; + hash.Assign(Substring(chunk, start, KEY_LENGTH)); + entry->AddFragment(hash, chunkNum); + start += KEY_LENGTH; + } + } + } + } else { + nsCStringArray lines; + lines.ParseString(PromiseFlatCString(chunk).get(), "\n"); + + // non-hashed tables need to be hashed + for (PRInt32 i = 0; i < lines.Count(); i++) { + nsUrlClassifierEntry* entry = entries.AppendElement(); + if (!entry) return NS_ERROR_OUT_OF_MEMORY; + + rv = GetKey(*lines[i], entry->mKey); + NS_ENSURE_SUCCESS(rv, rv); + + entry->mTableId = tableId; + nsUrlClassifierHash hash; + hash.FromPlaintext(*lines[i], mCryptoHash); + entry->AddFragment(hash, mChunkNum); + } + } + + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::ParseChunkList(const nsACString& chunkStr, + nsTArray& chunks) +{ + LOG(("Parsing %s", PromiseFlatCString(chunkStr).get())); + + nsCStringArray elements; + elements.ParseString(PromiseFlatCString(chunkStr).get() , ","); + + for (PRInt32 i = 0; i < elements.Count(); i++) { + nsCString& element = *elements[i]; + + PRUint32 first; + PRUint32 last; + if (PR_sscanf(element.get(), "%u-%u", &first, &last) == 2) { + if (first > last) { + PRUint32 tmp = first; + first = last; + last = tmp; + } + for (PRUint32 num = first; num <= last; 
num++) { + chunks.AppendElement(num); + } + } else if (PR_sscanf(element.get(), "%u", &first) == 1) { + chunks.AppendElement(first); + } + } + + LOG(("Got %d elements.", chunks.Length())); + + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::JoinChunkList(nsTArray& chunks, + nsCString& chunkStr) +{ + chunkStr.Truncate(); + chunks.Sort(); + + PRUint32 i = 0; + while (i < chunks.Length()) { + if (i != 0) { + chunkStr.Append(','); + } + chunkStr.AppendInt(chunks[i]); + + PRUint32 first = i; + PRUint32 last = first; + i++; + while (i < chunks.Length() && chunks[i] == chunks[i - 1] + 1) { + last = chunks[i++]; + } + + if (last != first) { + chunkStr.Append('-'); + chunkStr.AppendInt(last); + } + } + + return NS_OK; +} + + +nsresult +nsUrlClassifierDBServiceWorker::GetChunkLists(PRUint32 tableId, + nsACString& addChunks, + nsACString& subChunks) +{ + addChunks.Truncate(); + subChunks.Truncate(); + + mozStorageStatementScoper scoper(mGetChunkListsStatement); + + nsresult rv = mGetChunkListsStatement->BindInt32Parameter(0, tableId); + NS_ENSURE_SUCCESS(rv, rv); + + PRBool hasMore = PR_FALSE; + rv = mGetChunkListsStatement->ExecuteStep(&hasMore); + NS_ENSURE_SUCCESS(rv, rv); + + if (!hasMore) { + LOG(("Getting chunks for %d, found nothing", tableId)); + return NS_OK; + } + + rv = mGetChunkListsStatement->GetUTF8String(0, addChunks); + NS_ENSURE_SUCCESS(rv, rv); + + rv = mGetChunkListsStatement->GetUTF8String(1, subChunks); + NS_ENSURE_SUCCESS(rv, rv); + + LOG(("Getting chunks for %d, got %s %s", + tableId, + PromiseFlatCString(addChunks).get(), + PromiseFlatCString(subChunks).get())); + + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::SetChunkLists(PRUint32 tableId, + const nsACString& addChunks, + const nsACString& subChunks) +{ + mozStorageStatementScoper scoper(mSetChunkListsStatement); + + mSetChunkListsStatement->BindUTF8StringParameter(0, addChunks); + mSetChunkListsStatement->BindUTF8StringParameter(1, subChunks); + 
mSetChunkListsStatement->BindInt32Parameter(2, tableId); + nsresult rv = mSetChunkListsStatement->Execute(); + NS_ENSURE_SUCCESS(rv, rv); + + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::AddChunk(PRUint32 tableId, + PRUint32 chunkNum, + nsTArray& entries) +{ +#if defined(PR_LOGGING) + PRIntervalTime clockStart = 0; + if (LOG_ENABLED()) { + clockStart = PR_IntervalNow(); + } +#endif + + LOG(("Adding %d entries to chunk %d", entries.Length(), chunkNum)); + + mozStorageTransaction transaction(mConnection, PR_FALSE); + + nsCAutoString addChunks; + nsCAutoString subChunks; + + HandlePendingLookups(); + + nsresult rv = GetChunkLists(tableId, addChunks, subChunks); + NS_ENSURE_SUCCESS(rv, rv); + + nsTArray adds; + ParseChunkList(addChunks, adds); + adds.AppendElement(chunkNum); + JoinChunkList(adds, addChunks); + rv = SetChunkLists(tableId, addChunks, subChunks); + NS_ENSURE_SUCCESS(rv, rv); + + nsTArray entryIDs; + + for (PRUint32 i = 0; i < entries.Length(); i++) { + nsUrlClassifierEntry& thisEntry = entries[i]; + + HandlePendingLookups(); + + nsUrlClassifierEntry existingEntry; + rv = ReadEntry(thisEntry.mKey, tableId, existingEntry); + NS_ENSURE_SUCCESS(rv, rv); + + if (!existingEntry.Merge(thisEntry)) + return NS_ERROR_FAILURE; + + HandlePendingLookups(); + + rv = WriteEntry(existingEntry); + NS_ENSURE_SUCCESS(rv, rv); + + entryIDs.AppendElement(existingEntry.mId); + } + + mozStorageStatementScoper scoper(mAddChunkEntriesStatement); + rv = mAddChunkEntriesStatement->BindInt32Parameter(0, chunkNum); + NS_ENSURE_SUCCESS(rv, rv); + + mAddChunkEntriesStatement->BindInt32Parameter(1, tableId); + NS_ENSURE_SUCCESS(rv, rv); + + mAddChunkEntriesStatement->BindBlobParameter + (2, + reinterpret_cast(entryIDs.Elements()), + entryIDs.Length() * sizeof(PRUint32)); + + HandlePendingLookups(); + + rv = mAddChunkEntriesStatement->Execute(); + NS_ENSURE_SUCCESS(rv, rv); + + HandlePendingLookups(); + + rv = transaction.Commit(); + NS_ENSURE_SUCCESS(rv, rv); + +#if 
defined(PR_LOGGING) + if (LOG_ENABLED()) { + PRIntervalTime clockEnd = PR_IntervalNow(); + printf("adding chunk %d took %dms\n", chunkNum, + PR_IntervalToMilliseconds(clockEnd - clockStart)); + } +#endif + + return rv; +} + +nsresult +nsUrlClassifierDBServiceWorker::ExpireAdd(PRUint32 tableId, + PRUint32 chunkNum) +{ + mozStorageTransaction transaction(mConnection, PR_FALSE); + + LOG(("Expiring chunk %d\n", chunkNum)); + + nsCAutoString addChunks; + nsCAutoString subChunks; + + HandlePendingLookups(); + + nsresult rv = GetChunkLists(tableId, addChunks, subChunks); + NS_ENSURE_SUCCESS(rv, rv); + + nsTArray adds; + ParseChunkList(addChunks, adds); + adds.RemoveElement(chunkNum); + JoinChunkList(adds, addChunks); + rv = SetChunkLists(tableId, addChunks, subChunks); + NS_ENSURE_SUCCESS(rv, rv); + + mozStorageStatementScoper getChunkEntriesScoper(mGetChunkEntriesStatement); + + rv = mGetChunkEntriesStatement->BindInt32Parameter(0, chunkNum); + NS_ENSURE_SUCCESS(rv, rv); + rv = mGetChunkEntriesStatement->BindInt32Parameter(1, tableId); + NS_ENSURE_SUCCESS(rv, rv); + + HandlePendingLookups(); + + PRBool exists; + rv = mGetChunkEntriesStatement->ExecuteStep(&exists); + NS_ENSURE_SUCCESS(rv, rv); + while (exists) { + PRUint32 size; + const PRUint8* blob = mGetChunkEntriesStatement->AsSharedBlob(0, &size); + if (blob) { + const PRUint32* entries = reinterpret_cast(blob); + for (PRUint32 i = 0; i < (size / sizeof(PRUint32)); i++) { + HandlePendingLookups(); + + nsUrlClassifierEntry entry; + rv = ReadEntry(entries[i], entry); + NS_ENSURE_SUCCESS(rv, rv); + + entry.SubtractChunk(chunkNum); + + HandlePendingLookups(); + + rv = WriteEntry(entry); + NS_ENSURE_SUCCESS(rv, rv); + } + } + + HandlePendingLookups(); + rv = mGetChunkEntriesStatement->ExecuteStep(&exists); + NS_ENSURE_SUCCESS(rv, rv); + } + + HandlePendingLookups(); + + mozStorageStatementScoper removeScoper(mDeleteChunkEntriesStatement); + mDeleteChunkEntriesStatement->BindInt32Parameter(0, tableId); + 
mDeleteChunkEntriesStatement->BindInt32Parameter(1, chunkNum); + rv = mDeleteChunkEntriesStatement->Execute(); + NS_ENSURE_SUCCESS(rv, rv); + + HandlePendingLookups(); + + return transaction.Commit(); +} + +nsresult +nsUrlClassifierDBServiceWorker::SubChunk(PRUint32 tableId, + PRUint32 chunkNum, + nsTArray& entries) +{ + mozStorageTransaction transaction(mConnection, PR_FALSE); + + nsCAutoString addChunks; + nsCAutoString subChunks; + + HandlePendingLookups(); + + nsresult rv = GetChunkLists(tableId, addChunks, subChunks); + NS_ENSURE_SUCCESS(rv, rv); + + nsTArray subs; + ParseChunkList(subChunks, subs); + subs.AppendElement(chunkNum); + JoinChunkList(subs, subChunks); + rv = SetChunkLists(tableId, addChunks, subChunks); + NS_ENSURE_SUCCESS(rv, rv); + + for (PRUint32 i = 0; i < entries.Length(); i++) { + nsUrlClassifierEntry& thisEntry = entries[i]; + + HandlePendingLookups(); + + nsUrlClassifierEntry existingEntry; + rv = ReadEntry(thisEntry.mKey, tableId, existingEntry); + NS_ENSURE_SUCCESS(rv, rv); + + if (!existingEntry.SubtractFragments(thisEntry)) + return NS_ERROR_FAILURE; + + HandlePendingLookups(); + + rv = WriteEntry(existingEntry); + NS_ENSURE_SUCCESS(rv, rv); + } + + HandlePendingLookups(); + + return transaction.Commit(); +} + +nsresult +nsUrlClassifierDBServiceWorker::ExpireSub(PRUint32 tableId, PRUint32 chunkNum) +{ + mozStorageTransaction transaction(mConnection, PR_FALSE); + + nsCAutoString addChunks; + nsCAutoString subChunks; + + HandlePendingLookups(); + + nsresult rv = GetChunkLists(tableId, addChunks, subChunks); + NS_ENSURE_SUCCESS(rv, rv); + + nsTArray subs; + ParseChunkList(subChunks, subs); + subs.RemoveElement(chunkNum); + JoinChunkList(subs, subChunks); + rv = SetChunkLists(tableId, addChunks, subChunks); + NS_ENSURE_SUCCESS(rv, rv); + + HandlePendingLookups(); + + return transaction.Commit(); +} + +nsresult +nsUrlClassifierDBServiceWorker::ProcessChunk(PRBool* done) +{ + // wait until the chunk plus terminating \n has been read + if 
(mPendingStreamUpdate.Length() <= static_cast(mChunkLen)) { + *done = PR_TRUE; + return NS_OK; + } + + if (mPendingStreamUpdate[mChunkLen] != '\n') { + LOG(("Didn't get a terminating newline after the chunk, failing the update")); return NS_ERROR_FAILURE; } - rv = mConnection->BeginTransaction(); - NS_ASSERTION(NS_SUCCEEDED(rv), "Unable to begin transaction"); + nsCAutoString chunk; + chunk.Assign(Substring(mPendingStreamUpdate, 0, mChunkLen)); + mPendingStreamUpdate = Substring(mPendingStreamUpdate, mChunkLen); - // Split the update string into lines + LOG(("Handling a chunk sized %d", chunk.Length())); + + nsTArray entries; + GetChunkEntries(mUpdateTable, mUpdateTableId, mChunkNum, chunk, entries); + + nsresult rv; + + if (mChunkType == CHUNK_ADD) { + rv = AddChunk(mUpdateTableId, mChunkNum, entries); + } else { + rv = SubChunk(mUpdateTableId, mChunkNum, entries); + } + + // pop off the chunk and the trailing \n + mPendingStreamUpdate = Substring(mPendingStreamUpdate, 1); + + mState = STATE_LINE; + *done = PR_FALSE; + + return rv; +} + +nsresult +nsUrlClassifierDBServiceWorker::ProcessResponseLines(PRBool* done) +{ PRUint32 cur = 0; PRInt32 next; - PRInt32 count = 0; - nsCAutoString dbTableName; - nsCAutoString lastTableLine; - nsCOMPtr updateStatement; - nsCOMPtr deleteStatement; + + nsresult rv; + // We will run to completion unless we find a chunk line + *done = PR_TRUE; + + nsACString& updateString = mPendingStreamUpdate; + while(cur < updateString.Length() && (next = updateString.FindChar('\n', cur)) != kNotFound) { - const nsCSubstring &line = Substring(updateString, cur, next - cur); - cur = next + 1; // prepare for next run + const nsCSubstring& line = Substring(updateString, cur, next - cur); + cur = next + 1; - // Skip blank lines - if (line.Length() == 0) - continue; + LOG(("Processing %s\n", PromiseFlatCString(line).get())); - count++; - - if ('[' == line[0]) { - rv = ProcessNewTable(line, &dbTableName, - getter_AddRefs(updateStatement), - 
getter_AddRefs(deleteStatement)); - NS_WARN_IF_FALSE(NS_SUCCEEDED(rv), "malformed table line"); - if (NS_SUCCEEDED(rv)) { - // If it's a new table, we may have completed a table. - // Go ahead and post the completion to the UI thread and db. - if (lastTableLine.Length() > 0) { - // If it was a new table, we need to swap in the new table. - rv = MaybeSwapTables(lastTableLine); - if (NS_SUCCEEDED(rv)) { - mConnection->CommitTransaction(); - c->HandleEvent(lastTableLine); - } else { - // failed to swap, rollback - mConnection->RollbackTransaction(); - } - mConnection->BeginTransaction(); - } - lastTableLine.Assign(line); + if (StringBeginsWith(line, NS_LITERAL_CSTRING("n:"))) { + if (PR_sscanf(PromiseFlatCString(line).get(), "n:%d", + &mUpdateWait) != 1) { + LOG(("Error parsing n: field: %s", PromiseFlatCString(line).get())); + mUpdateWait = 0; } + } else if (StringBeginsWith(line, NS_LITERAL_CSTRING("k:"))) { + // XXX: pleaserekey + } else if (StringBeginsWith(line, NS_LITERAL_CSTRING("i:"))) { + const nsCSubstring& data = Substring(line, 2); + PRInt32 comma; + if ((comma = data.FindChar(',')) == kNotFound) { + mUpdateTable = data; + } else { + mUpdateTable = Substring(data, 0, comma); + // The rest is the mac, which we don't support for now + } + GetTableId(mUpdateTable, &mUpdateTableId); + LOG(("update table: '%s' (%d)", mUpdateTable.get(), mUpdateTableId)); + } else if (StringBeginsWith(line, NS_LITERAL_CSTRING("a:")) || + StringBeginsWith(line, NS_LITERAL_CSTRING("s:"))) { + mState = STATE_CHUNK; + char command; + if (PR_sscanf(PromiseFlatCString(line).get(), + "%c:%d:%d", &command, &mChunkNum, &mChunkLen) != 3 || + mChunkLen > MAX_CHUNK_SIZE) { + return NS_ERROR_FAILURE; + } + mChunkType = (command == 'a') ? 
CHUNK_ADD : CHUNK_SUB; + + // Done parsing lines, move to chunk state now + *done = PR_FALSE; + break; + } else if (StringBeginsWith(line, NS_LITERAL_CSTRING("ad:"))) { + PRUint32 chunkNum; + if (PR_sscanf(PromiseFlatCString(line).get(), "ad:%u", &chunkNum) != 1) { + return NS_ERROR_FAILURE; + } + rv = ExpireAdd(mUpdateTableId, chunkNum); + NS_ENSURE_SUCCESS(rv, rv); + } else if (StringBeginsWith(line, NS_LITERAL_CSTRING("sd:"))) { + PRUint32 chunkNum; + if (PR_sscanf(PromiseFlatCString(line).get(), "ad:%u", &chunkNum) != 1) { + return NS_ERROR_FAILURE; + } + rv = ExpireSub(mUpdateTableId, chunkNum); + NS_ENSURE_SUCCESS(rv, rv); } else { - rv = ProcessUpdateTable(line, dbTableName, updateStatement, - deleteStatement); - NS_WARN_IF_FALSE(NS_SUCCEEDED(rv), "malformed update line"); + LOG(("ignoring unknown line: '%s'", PromiseFlatCString(line).get())); } } - LOG(("Num update lines: %d\n", count)); - rv = MaybeSwapTables(lastTableLine); - if (NS_SUCCEEDED(rv)) { - mConnection->CommitTransaction(); - c->HandleEvent(lastTableLine); - } else { - // failed to swap, rollback - mConnection->RollbackTransaction(); - } + mPendingStreamUpdate = Substring(updateString, cur); - LOG(("Finishing table update\n")); return NS_OK; } +void +nsUrlClassifierDBServiceWorker::ResetUpdate() +{ + mUpdateWait = 0; + mState = STATE_LINE; + mChunkNum = 0; + mChunkLen = 0; + mUpdateStatus = NS_OK; + + mUpdateTable.Truncate(); + mPendingStreamUpdate.Truncate(); +} + +/** + * Updating the database: + * + * The Update() method takes a series of chunks seperated with control data, + * as described in + * http://code.google.com/p/google-safe-browsing/wiki/Protocolv2Spec + * + * It will iterate through the control data until it reaches a chunk. By + * the time it reaches a chunk, it should have received + * a) the table to which this chunk applies + * b) the type of chunk (add, delete, expire add, expire delete). + * c) the chunk ID + * d) the length of the chunk. 
+ * + * For add and subtract chunks, it needs to read the chunk data (expires + * don't have any data). Chunk data is a list of URI fragments whose + * encoding depends on the type of table (which is indicated by the end + * of the table name): + * a) tables ending with -exp are a zlib-compressed list of URI fragments + * separated by newlines. + * b) tables ending with -sha128 have the form + * [domain][N][frag0]...[fragN] + * 16 1 16 16 + * If N is 0, the domain is reused as a fragment. + * c) any other tables are assumed to be a plaintext list of URI fragments + * separated by newlines. + * + * Update() can be fed partial data; It will accumulate data until there is + * enough to act on. Finish() should be called when there will be no more + * data. + */ NS_IMETHODIMP nsUrlClassifierDBServiceWorker::Update(const nsACString& chunk) { + if (gShuttingDownThread) + return NS_ERROR_NOT_INITIALIZED; + + HandlePendingLookups(); + LOG(("Update from Stream.")); nsresult rv = OpenDb(); if (NS_FAILED(rv)) { @@ -390,115 +1768,56 @@ nsUrlClassifierDBServiceWorker::Update(const nsACString& chunk) return NS_ERROR_FAILURE; } - nsCAutoString updateString(mPendingStreamUpdate); - updateString.Append(chunk); - - nsCOMPtr updateStatement; - nsCOMPtr deleteStatement; - nsCAutoString dbTableName; - - // If we're not in the middle of an update, we start a new transaction. - // Otherwise, we need to pick up where we left off. 
- if (!mHasPendingUpdate) { - mConnection->BeginTransaction(); - mHasPendingUpdate = PR_TRUE; - } else { - PRUint32 numTables = mTableUpdateLines.Length(); - if (numTables > 0) { - const nsCSubstring &line = Substring( - mTableUpdateLines[numTables - 1], 0); - - rv = ProcessNewTable(line, &dbTableName, - getter_AddRefs(updateStatement), - getter_AddRefs(deleteStatement)); - NS_WARN_IF_FALSE(NS_SUCCEEDED(rv), "malformed table line"); - } + // if something has gone wrong during this update, just throw it away + if (NS_FAILED(mUpdateStatus)) { + return mUpdateStatus; } - PRUint32 cur = 0; - PRInt32 next; - while(cur < updateString.Length() && - (next = updateString.FindChar('\n', cur)) != kNotFound) { - const nsCSubstring &line = Substring(updateString, cur, next - cur); - cur = next + 1; // prepare for next run + LOG(("Got %s\n", PromiseFlatCString(chunk).get())); - // Skip blank lines - if (line.Length() == 0) - continue; + mPendingStreamUpdate.Append(chunk); - if ('[' == line[0]) { - rv = ProcessNewTable(line, &dbTableName, - getter_AddRefs(updateStatement), - getter_AddRefs(deleteStatement)); - NS_WARN_IF_FALSE(NS_SUCCEEDED(rv), "malformed table line"); - if (NS_SUCCEEDED(rv)) { - // Add the line to our array of table lines. - mTableUpdateLines.AppendElement(line); - } + PRBool done = PR_FALSE; + while (!done) { + if (mState == STATE_CHUNK) { + rv = ProcessChunk(&done); } else { - rv = ProcessUpdateTable(line, dbTableName, updateStatement, - deleteStatement); - NS_WARN_IF_FALSE(NS_SUCCEEDED(rv), "malformed update line"); + rv = ProcessResponseLines(&done); + } + if (NS_FAILED(rv)) { + mUpdateStatus = rv; + return rv; } } - // Save the remaining string fragment. 
- mPendingStreamUpdate = Substring(updateString, cur); - LOG(("pending stream update: %s", mPendingStreamUpdate.get())); return NS_OK; } NS_IMETHODIMP -nsUrlClassifierDBServiceWorker::Finish(nsIUrlClassifierCallback *c) +nsUrlClassifierDBServiceWorker::Finish(nsIUrlClassifierCallback* aSuccessCallback, + nsIUrlClassifierCallback* aErrorCallback) { - if (!mHasPendingUpdate) - return NS_OK; - - if (gShuttingDownThread) { - mConnection->RollbackTransaction(); - return NS_ERROR_NOT_INITIALIZED; - } - - nsresult rv = NS_OK; - for (PRUint32 i = 0; i < mTableUpdateLines.Length(); ++i) { - rv = MaybeSwapTables(mTableUpdateLines[i]); - if (NS_FAILED(rv)) { - break; - } - } - - if (NS_SUCCEEDED(rv)) { - LOG(("Finish, committing transaction")); - mConnection->CommitTransaction(); - - // Send update information to main thread. - for (PRUint32 i = 0; i < mTableUpdateLines.Length(); ++i) { - c->HandleEvent(mTableUpdateLines[i]); - } + nsCAutoString arg; + if (NS_SUCCEEDED(mUpdateStatus)) { + arg.AppendInt(mUpdateWait); + aSuccessCallback->HandleEvent(arg); } else { - LOG(("Finish failed (swap table error?), rolling back transaction")); - mConnection->RollbackTransaction(); + arg.AppendInt(mUpdateStatus); + aErrorCallback->HandleEvent(arg); } - mTableUpdateLines.Clear(); - mPendingStreamUpdate.Truncate(); - mHasPendingUpdate = PR_FALSE; + ResetUpdate(); + return NS_OK; } NS_IMETHODIMP nsUrlClassifierDBServiceWorker::CancelStream() { - if (!mHasPendingUpdate) - return NS_OK; + LOG(("CancelStream")); - LOG(("CancelStream, rolling back transaction")); - mConnection->RollbackTransaction(); + ResetUpdate(); - mTableUpdateLines.Clear(); - mPendingStreamUpdate.Truncate(); - mHasPendingUpdate = PR_FALSE; - return NS_OK; } @@ -509,261 +1828,206 @@ nsUrlClassifierDBServiceWorker::CancelStream() NS_IMETHODIMP nsUrlClassifierDBServiceWorker::CloseDb() { - if (mConnection != nsnull) { - NS_RELEASE(mConnection); + if (mConnection) { + mLookupStatement = nsnull; + mLookupWithTableStatement = 
nsnull; + mLookupWithIDStatement = nsnull; + + mUpdateStatement = nsnull; + mDeleteStatement = nsnull; + + mAddChunkEntriesStatement = nsnull; + mGetChunkEntriesStatement = nsnull; + mDeleteChunkEntriesStatement = nsnull; + + mGetChunkListsStatement = nsnull; + mSetChunkListsStatement = nsnull; + + mGetTablesStatement = nsnull; + mGetTableIdStatement = nsnull; + mGetTableNameStatement = nsnull; + mInsertTableIdStatement = nsnull; + + mConnection = nsnull; LOG(("urlclassifier db closed\n")); } - return NS_OK; -} -nsresult -nsUrlClassifierDBServiceWorker::ProcessNewTable( - const nsCSubstring& aLine, - nsCString* aDbTableName, - mozIStorageStatement** aUpdateStatement, - mozIStorageStatement** aDeleteStatement) -{ - // The line format is "[table-name #.####]" or "[table-name #.#### update]" - // The additional "update" in the header means that this is a diff. - // Otherwise, we should blow away the old table and start afresh. - PRBool isUpdate = PR_FALSE; - - // If the version string is bad, give up. - nsresult rv = ParseVersionString(aLine, aDbTableName, &isUpdate); - NS_ENSURE_SUCCESS(rv, rv); - - // If it's not an update, we dump the values into a new table. Once we're - // done with the table, we drop the original table and copy over the values - // from the old table into the new table. 
- if (!isUpdate) - aDbTableName->Append(kNEW_TABLE_SUFFIX); - - // Create the table - rv = MaybeCreateTable(*aDbTableName); - if (NS_FAILED(rv)) - return rv; - - // insert statement - nsCAutoString statement; - statement.AssignLiteral("INSERT OR REPLACE INTO "); - statement.Append(*aDbTableName); - statement.AppendLiteral(" VALUES (?1, ?2)"); - rv = mConnection->CreateStatement(statement, aUpdateStatement); - NS_ENSURE_SUCCESS(rv, rv); - - // delete statement - statement.AssignLiteral("DELETE FROM "); - statement.Append(*aDbTableName); - statement.AppendLiteral(" WHERE key = ?1"); - rv = mConnection->CreateStatement(statement, aDeleteStatement); - NS_ENSURE_SUCCESS(rv, rv); + mCryptoHash = nsnull; return NS_OK; } -nsresult -nsUrlClassifierDBServiceWorker::ProcessUpdateTable( - const nsCSubstring& aLine, - const nsCString& aTableName, - mozIStorageStatement* aUpdateStatement, - mozIStorageStatement* aDeleteStatement) -{ - // We should have seen a table name line by now. - if (aTableName.Length() == 0) - return NS_ERROR_FAILURE; - - if (!aUpdateStatement || !aDeleteStatement) { - NS_NOTREACHED("Statements NULL but table is not"); - return NS_ERROR_FAILURE; - } - // There should at least be an op char and a key - if (aLine.Length() < 2) - return NS_ERROR_FAILURE; - - char op = aLine[0]; - PRInt32 spacePos = aLine.FindChar('\t'); - nsresult rv = NS_ERROR_FAILURE; - - if ('+' == op && spacePos != kNotFound) { - // Insert operation of the form "+KEY\tVALUE" - const nsCSubstring &key = Substring(aLine, 1, spacePos - 1); - const nsCSubstring &value = Substring(aLine, spacePos + 1); - - // We use ROT13 versions of keys to avoid antivirus utilities from - // flagging us as a virus. 
- nsCString keyROT13(key); - Rot13Line(keyROT13); - - aUpdateStatement->BindUTF8StringParameter(0, keyROT13); - aUpdateStatement->BindUTF8StringParameter(1, value); - - rv = aUpdateStatement->Execute(); - NS_WARN_IF_FALSE(NS_SUCCEEDED(rv), "Failed to update"); - } else if ('-' == op) { - // Remove operation of the form "-KEY" - nsCString keyROT13; - if (spacePos == kNotFound) { - // No trailing tab - const nsCSubstring &key = Substring(aLine, 1); - keyROT13.Assign(key); - } else { - // With trailing tab - const nsCSubstring &key = Substring(aLine, 1, spacePos - 1); - keyROT13.Assign(key); - } - Rot13Line(keyROT13); - aDeleteStatement->BindUTF8StringParameter(0, keyROT13); - - rv = aDeleteStatement->Execute(); - NS_WARN_IF_FALSE(NS_SUCCEEDED(rv), "Failed to delete"); - } - - return rv; -} - nsresult nsUrlClassifierDBServiceWorker::OpenDb() { // Connection already open, don't do anything. - if (mConnection != nsnull) + if (mConnection) return NS_OK; LOG(("Opening db\n")); - // Compute database filename - nsCOMPtr dbFile; - - nsresult rv = NS_GetSpecialDirectory(NS_APP_USER_PROFILE_50_DIR, - getter_AddRefs(dbFile)); - NS_ENSURE_SUCCESS(rv, rv); - rv = dbFile->Append(NS_LITERAL_STRING(DATABASE_FILENAME)); - NS_ENSURE_SUCCESS(rv, rv); + nsresult rv; // open the connection nsCOMPtr storageService = do_GetService(MOZ_STORAGE_SERVICE_CONTRACTID, &rv); NS_ENSURE_SUCCESS(rv, rv); - rv = storageService->OpenDatabase(dbFile, &mConnection); + + nsCOMPtr connection; + rv = storageService->OpenDatabase(mDBFile, getter_AddRefs(connection)); if (rv == NS_ERROR_FILE_CORRUPTED) { // delete the db and try opening again - rv = dbFile->Remove(PR_FALSE); + rv = mDBFile->Remove(PR_FALSE); NS_ENSURE_SUCCESS(rv, rv); - rv = storageService->OpenDatabase(dbFile, &mConnection); + rv = storageService->OpenDatabase(mDBFile, getter_AddRefs(connection)); } + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->ExecuteSimpleSQL(NS_LITERAL_CSTRING("PRAGMA synchronous=OFF")); + NS_ENSURE_SUCCESS(rv, rv); 
+ + rv = connection->ExecuteSimpleSQL(NS_LITERAL_CSTRING("PRAGMA page_size=4096")); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->ExecuteSimpleSQL(NS_LITERAL_CSTRING("PRAGMA default_page_size=4096")); + NS_ENSURE_SUCCESS(rv, rv); + + // Create the table + rv = MaybeCreateTables(connection); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("SELECT * FROM moz_classifier" + " WHERE domain=?1"), + getter_AddRefs(mLookupStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("SELECT * FROM moz_classifier" + " WHERE domain=?1 AND table_id=?2"), + getter_AddRefs(mLookupWithTableStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("SELECT * FROM moz_classifier" + " WHERE id=?1"), + getter_AddRefs(mLookupWithIDStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("INSERT OR REPLACE INTO moz_classifier" + " VALUES (?1, ?2, ?3, ?4)"), + getter_AddRefs(mUpdateStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("DELETE FROM moz_classifier" + " WHERE id=?1"), + getter_AddRefs(mDeleteStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("INSERT OR REPLACE INTO moz_chunks VALUES (?1, ?2, ?3)"), + getter_AddRefs(mAddChunkEntriesStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("SELECT entries FROM moz_chunks" + " WHERE chunk_id = ?1 AND table_id = ?2"), + getter_AddRefs(mGetChunkEntriesStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("DELETE FROM moz_chunks WHERE table_id=?1 AND chunk_id=?2"), + getter_AddRefs(mDeleteChunkEntriesStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("SELECT add_chunks, sub_chunks FROM moz_tables" + " WHERE id=?1"), + 
getter_AddRefs(mGetChunkListsStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("UPDATE moz_tables" + " SET add_chunks=?1, sub_chunks=?2" + " WHERE id=?3"), + getter_AddRefs(mSetChunkListsStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("SELECT name, add_chunks, sub_chunks" + " FROM moz_tables"), + getter_AddRefs(mGetTablesStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("SELECT id FROM moz_tables" + " WHERE name = ?1"), + getter_AddRefs(mGetTableIdStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("SELECT name FROM moz_tables" + " WHERE id = ?1"), + getter_AddRefs(mGetTableNameStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->CreateStatement + (NS_LITERAL_CSTRING("INSERT INTO moz_tables(id, name, add_chunks, sub_chunks)" + " VALUES (null, ?1, null, null)"), + getter_AddRefs(mInsertTableIdStatement)); + NS_ENSURE_SUCCESS(rv, rv); + + mConnection = connection; + + mCryptoHash = do_CreateInstance(NS_CRYPTO_HASH_CONTRACTID, &rv); + NS_ENSURE_SUCCESS(rv, rv); + + return NS_OK; +} + +nsresult +nsUrlClassifierDBServiceWorker::MaybeCreateTables(mozIStorageConnection* connection) +{ + LOG(("MaybeCreateTables\n")); + + nsresult rv = connection->ExecuteSimpleSQL( + NS_LITERAL_CSTRING("CREATE TABLE IF NOT EXISTS moz_classifier" + " (id INTEGER PRIMARY KEY," + " domain BLOB," + " data BLOB," + " table_id INTEGER)")); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->ExecuteSimpleSQL( + NS_LITERAL_CSTRING("CREATE UNIQUE INDEX IF NOT EXISTS" + " moz_classifier_domain_index" + " ON moz_classifier(domain, table_id)")); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->ExecuteSimpleSQL( + NS_LITERAL_CSTRING("CREATE TABLE IF NOT EXISTS moz_tables" + " (id INTEGER PRIMARY KEY," + " name TEXT," + " add_chunks TEXT," + " sub_chunks TEXT);")); + NS_ENSURE_SUCCESS(rv, rv); + 
+ rv = connection->ExecuteSimpleSQL( + NS_LITERAL_CSTRING("CREATE TABLE IF NOT EXISTS moz_chunks" + " (chunk_id INTEGER," + " table_id INTEGER," + " entries BLOB)")); + NS_ENSURE_SUCCESS(rv, rv); + + rv = connection->ExecuteSimpleSQL( + NS_LITERAL_CSTRING("CREATE INDEX IF NOT EXISTS moz_chunks_id" + " ON moz_chunks(chunk_id)")); + NS_ENSURE_SUCCESS(rv, rv); + return rv; } -nsresult -nsUrlClassifierDBServiceWorker::MaybeCreateTable(const nsCString& aTableName) -{ - LOG(("MaybeCreateTable %s\n", aTableName.get())); - - nsCOMPtr createStatement; - nsCString statement; - statement.Assign("CREATE TABLE IF NOT EXISTS "); - statement.Append(aTableName); - statement.Append(" (key TEXT PRIMARY KEY, value TEXT)"); - nsresult rv = mConnection->CreateStatement(statement, - getter_AddRefs(createStatement)); - NS_ENSURE_SUCCESS(rv, rv); - - return createStatement->Execute(); -} - -nsresult -nsUrlClassifierDBServiceWorker::MaybeDropTable(const nsCString& aTableName) -{ - LOG(("MaybeDropTable %s\n", aTableName.get())); - nsCAutoString statement("DROP TABLE IF EXISTS "); - statement.Append(aTableName); - return mConnection->ExecuteSimpleSQL(statement); -} - -nsresult -nsUrlClassifierDBServiceWorker::MaybeSwapTables(const nsCString& aVersionLine) -{ - if (aVersionLine.Length() == 0) - return NS_ERROR_FAILURE; - - // Check to see if this was a full table update or not. - nsCAutoString tableName; - PRBool isUpdate; - nsresult rv = ParseVersionString(aVersionLine, &tableName, &isUpdate); - NS_ENSURE_SUCCESS(rv, rv); - - // Updates don't require any fancy logic. - if (isUpdate) - return NS_OK; - - // Not an update, so we need to swap tables by dropping the original table - // and copying in the values from the new table. 
- rv = MaybeDropTable(tableName); - NS_ENSURE_SUCCESS(rv, rv); - - nsCAutoString newTableName(tableName); - newTableName.Append(kNEW_TABLE_SUFFIX); - - // Bring over new table - nsCAutoString sql("ALTER TABLE "); - sql.Append(newTableName); - sql.Append(" RENAME TO "); - sql.Append(tableName); - rv = mConnection->ExecuteSimpleSQL(sql); - NS_ENSURE_SUCCESS(rv, rv); - - LOG(("tables swapped (%s)\n", tableName.get())); - - return NS_OK; -} - -// The line format is "[table-name #.####]" or "[table-name #.#### update]". -nsresult -nsUrlClassifierDBServiceWorker::ParseVersionString(const nsCSubstring& aLine, - nsCString* aTableName, - PRBool* aIsUpdate) -{ - // Blank lines are not valid - if (aLine.Length() == 0) - return NS_ERROR_FAILURE; - - // Max size for an update line (so we don't buffer overflow when sscanf'ing). - const PRUint32 MAX_LENGTH = 2048; - if (aLine.Length() > MAX_LENGTH) - return NS_ERROR_FAILURE; - - nsCAutoString lineData(aLine); - char tableNameBuf[MAX_LENGTH], endChar = ' '; - PRInt32 majorVersion, minorVersion, numConverted; - // Use trailing endChar to make sure the update token gets parsed. - numConverted = PR_sscanf(lineData.get(), "[%s %d.%d update%c", tableNameBuf, - &majorVersion, &minorVersion, &endChar); - if (numConverted != 4 || endChar != ']') { - // Check to see if it's not an update request - numConverted = PR_sscanf(lineData.get(), "[%s %d.%d%c", tableNameBuf, - &majorVersion, &minorVersion, &endChar); - if (numConverted != 4 || endChar != ']') - return NS_ERROR_FAILURE; - *aIsUpdate = PR_FALSE; - } else { - // First sscanf worked, so it's an update string. - *aIsUpdate = PR_TRUE; - } - - LOG(("Is update? %d\n", *aIsUpdate)); - - // Table header looks valid, go ahead and copy over the table name into the - // return variable. 
- GetDbTableName(nsCAutoString(tableNameBuf), aTableName); - return NS_OK; -} - -void -nsUrlClassifierDBServiceWorker::GetDbTableName(const nsACString& aTableName, - nsCString* aDbTableName) -{ - aDbTableName->Assign(aTableName); - aDbTableName->ReplaceChar('-', '_'); -} - // ------------------------------------------------------------------------- // Proxy class implementation @@ -805,6 +2069,9 @@ nsUrlClassifierDBService::~nsUrlClassifierDBService() nsresult nsUrlClassifierDBService::Init() { + NS_ASSERTION(sizeof(nsUrlClassifierHash) == KEY_LENGTH, + "nsUrlClassifierHash must be KEY_LENGTH bytes long!"); + #if defined(PR_LOGGING) if (!gUrlClassifierDbServiceLog) gUrlClassifierDbServiceLog = PR_NewLogModule("UrlClassifierDbService"); @@ -825,6 +2092,12 @@ nsUrlClassifierDBService::Init() if (!mWorker) return NS_ERROR_OUT_OF_MEMORY; + rv = mWorker->Init(); + if (NS_FAILED(rv)) { + mWorker = nsnull; + return rv; + } + // Add an observer for shutdown nsCOMPtr observerService = do_GetService("@mozilla.org/observer-service;1"); @@ -837,23 +2110,43 @@ nsUrlClassifierDBService::Init() return NS_OK; } -NS_IMETHODIMP -nsUrlClassifierDBService::Exists(const nsACString& tableName, - const nsACString& key, - nsIUrlClassifierCallback *c) +nsresult +nsUrlClassifierDBService::Lookup(const nsACString& spec, + nsIUrlClassifierCallback* c, + PRBool needsProxy) { NS_ENSURE_TRUE(gDbBackgroundThread, NS_ERROR_NOT_INITIALIZED); - nsresult rv; - // The proxy callback uses the current thread. 
- nsCOMPtr proxyCallback; - rv = NS_GetProxyForObject(NS_PROXY_TO_CURRENT_THREAD, - NS_GET_IID(nsIUrlClassifierCallback), - c, - NS_PROXY_ASYNC, - getter_AddRefs(proxyCallback)); + nsCOMPtr uri; + + nsresult rv = NS_NewURI(getter_AddRefs(uri), spec); NS_ENSURE_SUCCESS(rv, rv); + uri = NS_GetInnermostURI(uri); + if (!uri) { + return NS_ERROR_FAILURE; + } + + nsCAutoString key; + // Canonicalize the url + nsCOMPtr utilsService = + do_GetService(NS_URLCLASSIFIERUTILS_CONTRACTID); + rv = utilsService->GetKeyForURI(uri, key); + NS_ENSURE_SUCCESS(rv, rv); + + nsCOMPtr proxyCallback; + if (needsProxy) { + // The proxy callback uses the current thread. + rv = NS_GetProxyForObject(NS_PROXY_TO_CURRENT_THREAD, + NS_GET_IID(nsIUrlClassifierCallback), + c, + NS_PROXY_ASYNC, + getter_AddRefs(proxyCallback)); + NS_ENSURE_SUCCESS(rv, rv); + } else { + proxyCallback = c; + } + // The actual worker uses the background thread. nsCOMPtr proxy; rv = NS_GetProxyForObject(gDbBackgroundThread, @@ -863,12 +2156,16 @@ nsUrlClassifierDBService::Exists(const nsACString& tableName, getter_AddRefs(proxy)); NS_ENSURE_SUCCESS(rv, rv); - return proxy->Exists(tableName, key, proxyCallback); + // Queue this lookup and call the lookup function to flush the queue if + // necessary. 
+ rv = mWorker->QueueLookup(key, proxyCallback); + NS_ENSURE_SUCCESS(rv, rv); + + return proxy->Lookup(EmptyCString(), nsnull, PR_FALSE); } NS_IMETHODIMP -nsUrlClassifierDBService::CheckTables(const nsACString & tableNames, - nsIUrlClassifierCallback *c) +nsUrlClassifierDBService::GetTables(nsIUrlClassifierCallback* c) { NS_ENSURE_TRUE(gDbBackgroundThread, NS_ERROR_NOT_INITIALIZED); @@ -891,35 +2188,7 @@ nsUrlClassifierDBService::CheckTables(const nsACString & tableNames, getter_AddRefs(proxy)); NS_ENSURE_SUCCESS(rv, rv); - return proxy->CheckTables(tableNames, proxyCallback); -} - -NS_IMETHODIMP -nsUrlClassifierDBService::UpdateTables(const nsACString& updateString, - nsIUrlClassifierCallback *c) -{ - NS_ENSURE_TRUE(gDbBackgroundThread, NS_ERROR_NOT_INITIALIZED); - - nsresult rv; - // The proxy callback uses the current thread. - nsCOMPtr proxyCallback; - rv = NS_GetProxyForObject(NS_PROXY_TO_CURRENT_THREAD, - NS_GET_IID(nsIUrlClassifierCallback), - c, - NS_PROXY_ASYNC, - getter_AddRefs(proxyCallback)); - NS_ENSURE_SUCCESS(rv, rv); - - // The actual worker uses the background thread. - nsCOMPtr proxy; - rv = NS_GetProxyForObject(gDbBackgroundThread, - NS_GET_IID(nsIUrlClassifierDBServiceWorker), - mWorker, - NS_PROXY_ASYNC, - getter_AddRefs(proxy)); - NS_ENSURE_SUCCESS(rv, rv); - - return proxy->UpdateTables(updateString, proxyCallback); + return proxy->GetTables(proxyCallback); } NS_IMETHODIMP @@ -942,19 +2211,32 @@ nsUrlClassifierDBService::Update(const nsACString& aUpdateChunk) } NS_IMETHODIMP -nsUrlClassifierDBService::Finish(nsIUrlClassifierCallback *c) +nsUrlClassifierDBService::Finish(nsIUrlClassifierCallback* aSuccessCallback, + nsIUrlClassifierCallback* aErrorCallback) { NS_ENSURE_TRUE(gDbBackgroundThread, NS_ERROR_NOT_INITIALIZED); nsresult rv; // The proxy callback uses the current thread. 
- nsCOMPtr proxyCallback; - rv = NS_GetProxyForObject(NS_PROXY_TO_CURRENT_THREAD, - NS_GET_IID(nsIUrlClassifierCallback), - c, - NS_PROXY_ASYNC, - getter_AddRefs(proxyCallback)); - NS_ENSURE_SUCCESS(rv, rv); + nsCOMPtr proxySuccessCallback; + if (aSuccessCallback) { + rv = NS_GetProxyForObject(NS_PROXY_TO_CURRENT_THREAD, + NS_GET_IID(nsIUrlClassifierCallback), + aSuccessCallback, + NS_PROXY_ASYNC, + getter_AddRefs(proxySuccessCallback)); + NS_ENSURE_SUCCESS(rv, rv); + } + + nsCOMPtr proxyErrorCallback; + if (aErrorCallback) { + rv = NS_GetProxyForObject(NS_PROXY_TO_CURRENT_THREAD, + NS_GET_IID(nsIUrlClassifierCallback), + aErrorCallback, + NS_PROXY_ASYNC, + getter_AddRefs(proxyErrorCallback)); + NS_ENSURE_SUCCESS(rv, rv); + } // The actual worker uses the background thread. nsCOMPtr proxy; @@ -965,7 +2247,7 @@ nsUrlClassifierDBService::Finish(nsIUrlClassifierCallback *c) getter_AddRefs(proxy)); NS_ENSURE_SUCCESS(rv, rv); - return proxy->Finish(proxyCallback); + return proxy->Finish(proxySuccessCallback, proxyErrorCallback); } NS_IMETHODIMP diff --git a/toolkit/components/url-classifier/src/nsUrlClassifierListManager.js b/toolkit/components/url-classifier/src/nsUrlClassifierListManager.js index 6791fbe25bdd..fff83d22ec2f 100644 --- a/toolkit/components/url-classifier/src/nsUrlClassifierListManager.js +++ b/toolkit/components/url-classifier/src/nsUrlClassifierListManager.js @@ -39,7 +39,6 @@ const Cc = Components.classes; const Ci = Components.interfaces; #include ../content/listmanager.js -#include ../content/wireformat.js var modScope = this; function Init() { diff --git a/toolkit/components/url-classifier/src/nsUrlClassifierStreamUpdater.cpp b/toolkit/components/url-classifier/src/nsUrlClassifierStreamUpdater.cpp index 5bded3283e8f..fe081b179c85 100644 --- a/toolkit/components/url-classifier/src/nsUrlClassifierStreamUpdater.cpp +++ b/toolkit/components/url-classifier/src/nsUrlClassifierStreamUpdater.cpp @@ -37,10 +37,14 @@ * ***** END LICENSE BLOCK ***** */ 
#include "nsCRT.h" +#include "nsIHttpChannel.h" #include "nsIObserverService.h" +#include "nsIStringStream.h" +#include "nsIUploadChannel.h" #include "nsIURI.h" #include "nsIUrlClassifierDBService.h" #include "nsStreamUtils.h" +#include "nsStringStream.h" #include "nsToolkitCompsCID.h" #include "nsUrlClassifierStreamUpdater.h" #include "prlog.h" @@ -63,8 +67,9 @@ class nsUrlClassifierStreamUpdater; class TableUpdateListener : public nsIStreamListener { public: - TableUpdateListener(nsIUrlClassifierCallback *aTableCallback, - nsIUrlClassifierCallback *aErrorCallback, + TableUpdateListener(nsIUrlClassifierCallback *aSuccessCallback, + nsIUrlClassifierCallback *aUpdateErrorCallback, + nsIUrlClassifierCallback *aDownloadErrorCallback, nsUrlClassifierStreamUpdater* aStreamUpdater); nsCOMPtr mDBService; @@ -76,20 +81,23 @@ private: ~TableUpdateListener() {} // Callback when table updates complete. - nsCOMPtr mTableCallback; - nsCOMPtr mErrorCallback; + nsCOMPtr mSuccessCallback; + nsCOMPtr mUpdateErrorCallback; + nsCOMPtr mDownloadErrorCallback; // Reference to the stream updater that created this. 
nsUrlClassifierStreamUpdater *mStreamUpdater; }; TableUpdateListener::TableUpdateListener( - nsIUrlClassifierCallback *aTableCallback, - nsIUrlClassifierCallback *aErrorCallback, + nsIUrlClassifierCallback *aSuccessCallback, + nsIUrlClassifierCallback *aUpdateErrorCallback, + nsIUrlClassifierCallback *aDownloadErrorCallback, nsUrlClassifierStreamUpdater* aStreamUpdater) { - mTableCallback = aTableCallback; - mErrorCallback = aErrorCallback; + mSuccessCallback = aSuccessCallback; + mDownloadErrorCallback = aDownloadErrorCallback; + mUpdateErrorCallback = aUpdateErrorCallback; mStreamUpdater = aStreamUpdater; } @@ -110,10 +118,13 @@ TableUpdateListener::OnStartRequest(nsIRequest *request, nsISupports* context) nsresult status; rv = httpChannel->GetStatus(&status); NS_ENSURE_SUCCESS(rv, rv); + + LOG(("OnStartRequest (status %d)", status)); + if (NS_ERROR_CONNECTION_REFUSED == status || NS_ERROR_NET_TIMEOUT == status) { // Assume that we're overloading the server and trigger backoff. - mErrorCallback->HandleEvent(nsCString()); + mDownloadErrorCallback->HandleEvent(nsCString()); return NS_ERROR_ABORT; } @@ -151,7 +162,7 @@ TableUpdateListener::OnDataAvailable(nsIRequest *request, nsCAutoString strStatus; strStatus.AppendInt(status); - mErrorCallback->HandleEvent(strStatus); + mDownloadErrorCallback->HandleEvent(strStatus); return NS_ERROR_ABORT; } @@ -180,7 +191,7 @@ TableUpdateListener::OnStopRequest(nsIRequest *request, nsISupports* context, // If we got the whole stream, call Finish to commit the changes. // Otherwise, call Cancel to rollback the changes. 
if (NS_SUCCEEDED(aStatus)) - mDBService->Finish(mTableCallback); + mDBService->Finish(mSuccessCallback, mUpdateErrorCallback); else mDBService->CancelStream(); @@ -235,6 +246,8 @@ nsUrlClassifierStreamUpdater::GetUpdateUrl(nsACString & aUpdateUrl) NS_IMETHODIMP nsUrlClassifierStreamUpdater::SetUpdateUrl(const nsACString & aUpdateUrl) { + LOG(("Update URL is %s\n", PromiseFlatCString(aUpdateUrl).get())); + nsresult rv = NS_NewURI(getter_AddRefs(mUpdateUrl), aUpdateUrl); NS_ENSURE_SUCCESS(rv, rv); @@ -243,8 +256,10 @@ nsUrlClassifierStreamUpdater::SetUpdateUrl(const nsACString & aUpdateUrl) NS_IMETHODIMP nsUrlClassifierStreamUpdater::DownloadUpdates( - nsIUrlClassifierCallback *aTableCallback, - nsIUrlClassifierCallback *aErrorCallback, + const nsACString &aRequestBody, + nsIUrlClassifierCallback *aSuccessCallback, + nsIUrlClassifierCallback *aUpdateErrorCallback, + nsIUrlClassifierCallback *aDownloadErrorCallback, PRBool *_retval) { if (mIsUpdating) { @@ -276,8 +291,12 @@ nsUrlClassifierStreamUpdater::DownloadUpdates( rv = NS_NewChannel(getter_AddRefs(mChannel), mUpdateUrl); NS_ENSURE_SUCCESS(rv, rv); + rv = AddRequestBody(aRequestBody); + NS_ENSURE_SUCCESS(rv, rv); + // Bind to a different callback each time we invoke this method. 
- mListener = new TableUpdateListener(aTableCallback, aErrorCallback, this); + mListener = new TableUpdateListener(aSuccessCallback, aUpdateErrorCallback, + aDownloadErrorCallback, this); // Make the request rv = mChannel->AsyncOpen(mListener.get(), nsnull); @@ -289,6 +308,35 @@ nsUrlClassifierStreamUpdater::DownloadUpdates( return NS_OK; } +nsresult +nsUrlClassifierStreamUpdater::AddRequestBody(const nsACString &aRequestBody) +{ + nsresult rv; + nsCOMPtr strStream = + do_CreateInstance(NS_STRINGINPUTSTREAM_CONTRACTID, &rv); + NS_ENSURE_SUCCESS(rv, rv); + + rv = strStream->SetData(aRequestBody.BeginReading(), + aRequestBody.Length()); + NS_ENSURE_SUCCESS(rv, rv); + + nsCOMPtr uploadChannel = do_QueryInterface(mChannel, &rv); + NS_ENSURE_SUCCESS(rv, rv); + + rv = uploadChannel->SetUploadStream(strStream, + NS_LITERAL_CSTRING("text/plain"), + -1); + NS_ENSURE_SUCCESS(rv, rv); + + nsCOMPtr httpChannel = do_QueryInterface(mChannel, &rv); + NS_ENSURE_SUCCESS(rv, rv); + + rv = httpChannel->SetRequestMethod(NS_LITERAL_CSTRING("POST")); + NS_ENSURE_SUCCESS(rv, rv); + + return NS_OK; +} + /////////////////////////////////////////////////////////////////////////////// // nsIObserver implementation diff --git a/toolkit/components/url-classifier/src/nsUrlClassifierStreamUpdater.h b/toolkit/components/url-classifier/src/nsUrlClassifierStreamUpdater.h index 1e453854974c..82071708d7f0 100644 --- a/toolkit/components/url-classifier/src/nsUrlClassifierStreamUpdater.h +++ b/toolkit/components/url-classifier/src/nsUrlClassifierStreamUpdater.h @@ -71,6 +71,8 @@ private: // Disallow copy constructor nsUrlClassifierStreamUpdater(nsUrlClassifierStreamUpdater&); + nsresult AddRequestBody(const nsACString &aRequestBody); + PRBool mIsUpdating; PRBool mInitialized; nsCOMPtr mUpdateUrl; diff --git a/toolkit/components/url-classifier/src/nsUrlClassifierUtils.cpp b/toolkit/components/url-classifier/src/nsUrlClassifierUtils.cpp index 01e62271365b..50ce3de69517 100644 --- 
a/toolkit/components/url-classifier/src/nsUrlClassifierUtils.cpp +++ b/toolkit/components/url-classifier/src/nsUrlClassifierUtils.cpp @@ -36,7 +36,11 @@ #include "nsEscape.h" #include "nsString.h" +#include "nsIURI.h" +#include "nsNetUtil.h" #include "nsUrlClassifierUtils.h" +#include "nsVoidArray.h" +#include "prprf.h" static char int_to_hex_digit(PRInt32 i) { @@ -44,6 +48,58 @@ static char int_to_hex_digit(PRInt32 i) return static_cast(((i < 10) ? (i + '0') : ((i - 10) + 'A'))); } +static PRBool +IsDecimal(const nsACString & num) +{ + for (PRUint32 i = 0; i < num.Length(); i++) { + if (!isdigit(num[i])) { + return PR_FALSE; + } + } + + return PR_TRUE; +} + +static PRBool +IsHex(const nsACString & num) +{ + if (num.Length() < 3) { + return PR_FALSE; + } + + if (num[0] != '0' || !(num[1] == 'x' || num[1] == 'X')) { + return PR_FALSE; + } + + for (PRUint32 i = 2; i < num.Length(); i++) { + if (!isxdigit(num[i])) { + return PR_FALSE; + } + } + + return PR_TRUE; +} + +static PRBool +IsOctal(const nsACString & num) +{ + if (num.Length() < 2) { + return PR_FALSE; + } + + if (num[0] != '0') { + return PR_FALSE; + } + + for (PRUint32 i = 1; i < num.Length(); i++) { + if (!isdigit(num[i]) || num[i] == '8' || num[i] == '9') { + return PR_FALSE; + } + } + + return PR_TRUE; +} + nsUrlClassifierUtils::nsUrlClassifierUtils() : mEscapeCharmap(nsnull) { } @@ -64,54 +120,252 @@ NS_IMPL_ISUPPORTS1(nsUrlClassifierUtils, nsIUrlClassifierUtils) ///////////////////////////////////////////////////////////////////////////// // nsIUrlClassifierUtils -/* ACString canonicalizeURL (in ACString url); */ NS_IMETHODIMP -nsUrlClassifierUtils::CanonicalizeURL(const nsACString & url, nsACString & _retval) +nsUrlClassifierUtils::GetKeyForURI(nsIURI * uri, nsACString & _retval) { - nsCAutoString decodedUrl(url); - nsCAutoString temp; - while (NS_UnescapeURL(decodedUrl.get(), decodedUrl.Length(), 0, temp)) { - decodedUrl.Assign(temp); - temp.Truncate(); - } - SpecialEncode(decodedUrl, _retval); - 
return NS_OK; -} + nsCOMPtr innerURI = NS_GetInnermostURI(uri); + if (!innerURI) + innerURI = uri; + + nsCAutoString host; + innerURI->GetAsciiHost(host); + + nsresult rv = CanonicalizeHostname(host, _retval); + NS_ENSURE_SUCCESS(rv, rv); + + nsCAutoString path; + rv = innerURI->GetPath(path); + NS_ENSURE_SUCCESS(rv, rv); + + // strip out anchors and query parameters + PRInt32 ref = path.FindChar('#'); + if (ref != kNotFound) + path.SetLength(ref); + + ref = path.FindChar('?'); + if (ref != kNotFound) + path.SetLength(ref); + + nsCAutoString temp; + rv = CanonicalizePath(path, temp); + NS_ENSURE_SUCCESS(rv, rv); + + _retval.Append(temp); -NS_IMETHODIMP -nsUrlClassifierUtils::EscapeHostname(const nsACString & hostname, - nsACString & _retval) -{ - const char* curChar = hostname.BeginReading(); - const char* end = hostname.EndReading(); - while (curChar != end) { - unsigned char c = static_cast(*curChar); - if (mEscapeCharmap->Contains(c)) { - _retval.Append('%'); - _retval.Append(int_to_hex_digit(c / 16)); - _retval.Append(int_to_hex_digit(c % 16)); - } else { - _retval.Append(*curChar); - } - ++curChar; - } - return NS_OK; } ///////////////////////////////////////////////////////////////////////////// // non-interface methods +nsresult +nsUrlClassifierUtils::CanonicalizeHostname(const nsACString & hostname, + nsACString & _retval) +{ + nsCAutoString unescaped; + if (!NS_UnescapeURL(PromiseFlatCString(hostname).get(), + PromiseFlatCString(hostname).Length(), + 0, unescaped)) { + unescaped.Assign(hostname); + } + + nsCAutoString cleaned; + CleanupHostname(unescaped, cleaned); + + nsCAutoString temp; + ParseIPAddress(cleaned, temp); + if (!temp.IsEmpty()) { + cleaned.Assign(temp); + } + + ToLowerCase(cleaned); + SpecialEncode(cleaned, PR_FALSE, _retval); + + return NS_OK; +} + + +nsresult +nsUrlClassifierUtils::CanonicalizePath(const nsACString & path, + nsACString & _retval) +{ + _retval.Truncate(); + + nsCAutoString decodedPath(path); + nsCAutoString temp; + while 
(NS_UnescapeURL(decodedPath.get(), decodedPath.Length(), 0, temp)) { + decodedPath.Assign(temp); + temp.Truncate(); + } + + SpecialEncode(decodedPath, PR_TRUE, _retval); + // XXX: lowercase the path? + + return NS_OK; +} + +void +nsUrlClassifierUtils::CleanupHostname(const nsACString & hostname, + nsACString & _retval) +{ + _retval.Truncate(); + + const char* curChar = hostname.BeginReading(); + const char* end = hostname.EndReading(); + char lastChar = '\0'; + while (curChar != end) { + unsigned char c = static_cast(*curChar); + if (c == '.' && (lastChar == '\0' || lastChar == '.')) { + // skip + } else { + _retval.Append(*curChar); + } + lastChar = c; + ++curChar; + } + + // cut off trailing dots + while (_retval[_retval.Length() - 1] == '.') { + _retval.SetLength(_retval.Length() - 1); + } +} + +void +nsUrlClassifierUtils::ParseIPAddress(const nsACString & host, + nsACString & _retval) +{ + _retval.Truncate(); + nsACString::const_iterator iter, end; + host.BeginReading(iter); + host.EndReading(end); + + if (host.Length() <= 15) { + // The Windows resolver allows a 4-part dotted decimal IP address to + // have a space followed by any old rubbish, so long as the total length + // of the string doesn't get above 15 characters. So, "10.192.95.89 xy" + // is resolved to 10.192.95.89. + // If the string length is greater than 15 characters, e.g. + // "10.192.95.89 xy.wildcard.example.com", it will be resolved through + // DNS. 
+ + if (FindCharInReadable(' ', iter, end)) { + end = iter; + } + } + + for (host.BeginReading(iter); iter != end; iter++) { + if (!(isxdigit(*iter) || *iter == 'x' || *iter == 'X' || *iter == '.')) { + // not an IP + return; + } + } + + host.BeginReading(iter); + nsCStringArray parts; + parts.ParseString(PromiseFlatCString(Substring(iter, end)).get(), "."); + if (parts.Count() > 4) { + return; + } + + // If any potentially-octal numbers (start with 0 but not hex) have + // non-octal digits, no part of the ip can be in octal + // XXX: this came from the old javascript implementation, is it really + // supposed to be like this? + PRBool allowOctal = PR_TRUE; + for (PRInt32 i = 0; i < parts.Count(); i++) { + const nsCString& part = *parts[i]; + if (part[0] == '0') { + for (PRUint32 j = 1; j < part.Length(); j++) { + if (part[j] == 'x') { + break; + } + if (part[j] == '8' || part[j] == '9') { + allowOctal = PR_FALSE; + break; + } + } + } + } + + for (PRInt32 i = 0; i < parts.Count(); i++) { + nsCAutoString canonical; + + if (i == parts.Count() - 1) { + CanonicalNum(*parts[i], 5 - parts.Count(), allowOctal, canonical); + } else { + CanonicalNum(*parts[i], 1, allowOctal, canonical); + } + + if (canonical.IsEmpty()) { + _retval.Truncate(); + return; + } + + if (_retval.IsEmpty()) { + _retval.Assign(canonical); + } else { + _retval.Append('.'); + _retval.Append(canonical); + } + } + return; +} + +void +nsUrlClassifierUtils::CanonicalNum(const nsACString& num, + PRUint32 bytes, + PRBool allowOctal, + nsACString& _retval) +{ + _retval.Truncate(); + + if (num.Length() < 1) { + return; + } + + PRUint32 val; + if (allowOctal && IsOctal(num)) { + if (PR_sscanf(PromiseFlatCString(num).get(), "%o", &val) != 1) { + return; + } + } else if (IsDecimal(num)) { + if (PR_sscanf(PromiseFlatCString(num).get(), "%u", &val) != 1) { + return; + } + } else if (IsHex(num)) { + if (PR_sscanf(PromiseFlatCString(num).get(), num[1] == 'X' ? 
"0X%x" : "0x%x", + &val) != 1) { + return; + } + } else { + return; + } + + while (bytes--) { + char buf[20]; + PR_snprintf(buf, sizeof(buf), "%u", val & 0xff); + if (_retval.IsEmpty()) { + _retval.Assign(buf); + } else { + _retval = nsDependentCString(buf) + NS_LITERAL_CSTRING(".") + _retval; + } + val >>= 8; + } +} + // This function will encode all "special" characters in typical url -// encoding, that is %hh where h is a valid hex digit. See the comment in -// the header file for details. +// encoding, that is %hh where h is a valid hex digit. It will also fold +// any duplicated slashes. PRBool -nsUrlClassifierUtils::SpecialEncode(const nsACString & url, nsACString & _retval) +nsUrlClassifierUtils::SpecialEncode(const nsACString & url, + PRBool foldSlashes, + nsACString & _retval) { PRBool changed = PR_FALSE; const char* curChar = url.BeginReading(); const char* end = url.EndReading(); + unsigned char lastChar = '\0'; while (curChar != end) { unsigned char c = static_cast(*curChar); if (ShouldURLEscape(c)) { @@ -125,9 +379,12 @@ nsUrlClassifierUtils::SpecialEncode(const nsACString & url, nsACString & _retval _retval.Append(int_to_hex_digit(c % 16)); changed = PR_TRUE; + } else if (foldSlashes && (c == '/' && lastChar == '/')) { + // skip } else { _retval.Append(*curChar); } + lastChar = c; curChar++; } return changed; diff --git a/toolkit/components/url-classifier/src/nsUrlClassifierUtils.h b/toolkit/components/url-classifier/src/nsUrlClassifierUtils.h index d61f3dfbe125..d4897637af1e 100644 --- a/toolkit/components/url-classifier/src/nsUrlClassifierUtils.h +++ b/toolkit/components/url-classifier/src/nsUrlClassifierUtils.h @@ -77,19 +77,30 @@ public: nsUrlClassifierUtils(); ~nsUrlClassifierUtils() {} - nsresult Init(); - NS_DECL_ISUPPORTS NS_DECL_NSIURLCLASSIFIERUTILS + nsresult Init(); + + nsresult CanonicalizeHostname(const nsACString & hostname, + nsACString & _retval); + nsresult CanonicalizePath(const nsACString & url, nsACString & _retval); + // This 
function will encode all "special" characters in typical url encoding, // that is %hh where h is a valid hex digit. The characters which are encoded // by this function are any ascii characters under 32(control characters and // space), 37(%), and anything 127 or above (special characters). Url is the // string to encode, ret is the encoded string. Function returns true if // ret != url. - PRBool SpecialEncode(const nsACString & url, nsACString & _retval); + PRBool SpecialEncode(const nsACString & url, + PRBool foldSlashes, + nsACString & _retval); + void ParseIPAddress(const nsACString & host, nsACString & _retval); + void CanonicalNum(const nsACString & num, + PRUint32 bytes, + PRBool allowOctal, + nsACString & _retval); private: // Disallow copy constructor nsUrlClassifierUtils(const nsUrlClassifierUtils&); @@ -97,6 +108,8 @@ private: // Function to tell if we should encode a character. PRBool ShouldURLEscape(const unsigned char c) const; + void CleanupHostname(const nsACString & host, nsACString & _retval); + nsAutoPtr mEscapeCharmap; }; diff --git a/toolkit/components/url-classifier/tests/Makefile.in b/toolkit/components/url-classifier/tests/Makefile.in index f30f9ca237a2..1d9e6f6e406b 100644 --- a/toolkit/components/url-classifier/tests/Makefile.in +++ b/toolkit/components/url-classifier/tests/Makefile.in @@ -52,15 +52,11 @@ REQUIRES = \ string \ url-classifier \ xpcom \ + necko \ $(NULL) -# mochitests -_TEST_FILES = \ - test_enchash-decrypter.xhtml \ - $(NULL) - -libs:: $(_TEST_FILES) - $(INSTALL) $^ $(DEPTH)/_tests/testing/mochitest/tests/$(relativesrcdir) +# xpcshell tests +XPCSHELL_TESTS=unit # simple c++ tests (no xpcom) CPPSRCS = \ @@ -76,6 +72,7 @@ LOCAL_INCLUDES = \ LIBS = \ ../src/$(LIB_PREFIX)urlclassifier_s.$(LIB_SUFFIX) \ + $(MOZ_COMPONENT_LIBS) \ $(XPCOM_LIBS) \ $(NSPR_LIBS) \ $(NULL) diff --git a/toolkit/components/url-classifier/tests/TestUrlClassifierUtils.cpp b/toolkit/components/url-classifier/tests/TestUrlClassifierUtils.cpp index 
ff6e8c072733..f2bd618d6017 100644 --- a/toolkit/components/url-classifier/tests/TestUrlClassifierUtils.cpp +++ b/toolkit/components/url-classifier/tests/TestUrlClassifierUtils.cpp @@ -39,6 +39,8 @@ #include "nsEscape.h" #include "nsString.h" #include "nsUrlClassifierUtils.h" +#include "nsNetUtil.h" +#include "stdlib.h" static int gTotalTests = 0; static int gPassedTests = 0; @@ -117,8 +119,8 @@ void TestEncodeHelper(const char* in, const char* expected) { nsCString out, strIn(in), strExp(expected); nsUrlClassifierUtils utils; - - utils.SpecialEncode(strIn, out); + + utils.SpecialEncode(strIn, PR_TRUE, out); CheckEquals(strExp, out); } @@ -136,7 +138,7 @@ void TestEnc() } nsUrlClassifierUtils utils; nsCString out; - utils.SpecialEncode(noenc, out); + utils.SpecialEncode(noenc, PR_FALSE, out); CheckEquals(noenc, out); // Test that all the chars that we should encode [0,32],37,[127,255] are @@ -151,16 +153,18 @@ void TestEnc() } out.Truncate(); - utils.SpecialEncode(yesAsString, out); + utils.SpecialEncode(yesAsString, PR_FALSE, out); CheckEquals(yesExpectedString, out); + + TestEncodeHelper("blah//blah", "blah/blah"); } void TestCanonicalizeHelper(const char* in, const char* expected) { nsCString out, strIn(in), strExp(expected); nsUrlClassifierUtils utils; - - utils.CanonicalizeURL(strIn, out); + + utils.CanonicalizePath(strIn, out); CheckEquals(strExp, out); } @@ -177,19 +181,142 @@ void TestCanonicalize() "~a!b@c#d$e%25f^00&11*22(33)44_55+"); TestCanonicalizeHelper("", ""); - TestCanonicalizeHelper("http://www.google.com", "http://www.google.com"); - TestCanonicalizeHelper("http://%31%36%38%2e%31%38%38%2e%39%39%2e%32%36/%2E%73%65%63%75%72%65/%77%77%77%2E%65%62%61%79%2E%63%6F%6D/", - "http://168.188.99.26/.secure/www.ebay.com/"); - TestCanonicalizeHelper("http://195.127.0.11/uploads/%20%20%20%20/.verify/.eBaysecure=updateuserdataxplimnbqmn-xplmvalidateinfoswqpcmlx=hgplmcx/", - 
"http://195.127.0.11/uploads/%20%20%20%20/.verify/.eBaysecure=updateuserdataxplimnbqmn-xplmvalidateinfoswqpcmlx=hgplmcx/"); + TestCanonicalizeHelper("%31%36%38%2e%31%38%38%2e%39%39%2e%32%36/%2E%73%65%63%75%72%65/%77%77%77%2E%65%62%61%79%2E%63%6F%6D/", + "168.188.99.26/.secure/www.ebay.com/"); + TestCanonicalizeHelper("195.127.0.11/uploads/%20%20%20%20/.verify/.eBaysecure=updateuserdataxplimnbqmn-xplmvalidateinfoswqpcmlx=hgplmcx/", + "195.127.0.11/uploads/%20%20%20%20/.verify/.eBaysecure=updateuserdataxplimnbqmn-xplmvalidateinfoswqpcmlx=hgplmcx/"); +} + +void TestParseIPAddressHelper(const char *in, const char *expected) +{ + nsCString out, strIn(in), strExp(expected); + nsUrlClassifierUtils utils; + utils.Init(); + + utils.ParseIPAddress(strIn, out); + CheckEquals(strExp, out); +} + +void TestParseIPAddress() +{ + TestParseIPAddressHelper("123.123.0.0.1", ""); + TestParseIPAddressHelper("255.0.0.1", "255.0.0.1"); + TestParseIPAddressHelper("12.0x12.01234", "12.18.2.156"); + TestParseIPAddressHelper("276.2.3", "20.2.0.3"); + TestParseIPAddressHelper("012.034.01.055", "10.28.1.45"); + TestParseIPAddressHelper("0x12.0x43.0x44.0x01", "18.67.68.1"); + TestParseIPAddressHelper("167838211", "10.1.2.3"); + TestParseIPAddressHelper("3279880203", "195.127.0.11"); + TestParseIPAddressHelper("0x12434401", "18.67.68.1"); + TestParseIPAddressHelper("413960661", "24.172.137.213"); + TestParseIPAddressHelper("03053104725", "24.172.137.213"); + TestParseIPAddressHelper("030.0254.0x89d5", "24.172.137.213"); + TestParseIPAddressHelper("1.234.4.0377", "1.234.4.255"); + TestParseIPAddressHelper("1.2.3.00x0", ""); + TestParseIPAddressHelper("10.192.95.89 xy", "10.192.95.89"); + TestParseIPAddressHelper("10.192.95.89 xyz", ""); + TestParseIPAddressHelper("1.2.3.0x0", "1.2.3.0"); + TestParseIPAddressHelper("1.2.3.4", "1.2.3.4"); +} + +void TestCanonicalNumHelper(const char *in, PRUint32 bytes, + bool allowOctal, const char *expected) +{ + nsCString out, strIn(in), strExp(expected); + 
nsUrlClassifierUtils utils; + utils.Init(); + + utils.CanonicalNum(strIn, bytes, allowOctal, out); + CheckEquals(strExp, out); +} + +void TestCanonicalNum() +{ + TestCanonicalNumHelper("", 1, true, ""); + TestCanonicalNumHelper("10", 0, true, ""); + TestCanonicalNumHelper("45", 1, true, "45"); + TestCanonicalNumHelper("0x10", 1, true, "16"); + TestCanonicalNumHelper("367", 2, true, "1.111"); + TestCanonicalNumHelper("012345", 3, true, "0.20.229"); + TestCanonicalNumHelper("0173", 1, true, "123"); + TestCanonicalNumHelper("09", 1, false, "9"); + TestCanonicalNumHelper("0x120x34", 2, true, ""); + TestCanonicalNumHelper("0x12fc", 2, true, "18.252"); + TestCanonicalNumHelper("3279880203", 4, true, "195.127.0.11"); + TestCanonicalNumHelper("0x0000059", 1, true, "89"); + TestCanonicalNumHelper("0x00000059", 1, true, "89"); + TestCanonicalNumHelper("0x0000067", 1, true, "103"); +} + +void TestHostnameHelper(const char *in, const char *expected) +{ + nsCString out, strIn(in), strExp(expected); + nsUrlClassifierUtils utils; + utils.Init(); + + utils.CanonicalizeHostname(strIn, out); + CheckEquals(strExp, out); +} + +void TestHostname() +{ + TestHostnameHelper("abcd123;[]", "abcd123;[]"); + TestHostnameHelper("abc.123", "abc.123"); + TestHostnameHelper("abc..123", "abc.123"); + TestHostnameHelper("trailing.", "trailing"); + TestHostnameHelper("i love trailing dots....", "i%20love%20trailing%20dots"); + TestHostnameHelper(".leading", "leading"); + TestHostnameHelper("..leading", "leading"); + TestHostnameHelper(".dots.", "dots"); + TestHostnameHelper(".both.", "both"); + TestHostnameHelper(".both..", "both"); + TestHostnameHelper("..both.", "both"); + TestHostnameHelper("..both..", "both"); + TestHostnameHelper("..a.b.c.d..", "a.b.c.d"); + TestHostnameHelper("..127.0.0.1..", "127.0.0.1"); + TestHostnameHelper("asdf!@#$a", "asdf!@#$a"); + TestHostnameHelper("AB CD 12354", "ab%20cd%2012354"); + TestHostnameHelper("\1\2\3\4\112\177", "%01%02%03%04j%7F"); + 
TestHostnameHelper("<>.AS/-+", "<>.as/-+"); + +} + +void TestLongHostname() +{ + static const int kTestSize = 1024 * 150; + char *str = static_cast(malloc(kTestSize + 1)); + memset(str, 'x', kTestSize); + str[kTestSize] = '\0'; + + nsUrlClassifierUtils utils; + utils.Init(); + + nsCAutoString out; + nsDependentCString in(str); + PRIntervalTime clockStart = PR_IntervalNow(); + utils.CanonicalizeHostname(in, out); + PRIntervalTime clockEnd = PR_IntervalNow(); + + CheckEquals(in, out); + + printf("CanonicalizeHostname on long string (%dms)\n", + PR_IntervalToMilliseconds(clockEnd - clockStart)); } int main(int argc, char **argv) { + NS_LogInit(); + TestUnescape(); TestEnc(); TestCanonicalize(); + TestCanonicalNum(); + TestParseIPAddress(); + TestHostname(); + TestLongHostname(); + printf("%d of %d tests passed\n", gPassedTests, gTotalTests); // Non-zero return status signals test failure to build system. + return (gPassedTests != gTotalTests); }