Bug 1790851 - Add libprio-rs dependency. r=emilio,supply-chain-reviewers

Differential Revision: https://phabricator.services.mozilla.com/D157315
This commit is contained in:
Simon Friedberger 2022-09-15 14:39:07 +00:00
parent 2260491bf2
commit e7b5980201
39 changed files with 12829 additions and 0 deletions

17
Cargo.lock generated
View File

@@ -1196,6 +1196,9 @@ dependencies = [
[[package]]
name = "dap_ffi"
version = "0.1.0"
dependencies = [
"prio",
]
[[package]]
name = "darling"
@@ -4132,6 +4135,20 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
name = "prefs_parser"
version = "0.0.1"
[[package]]
name = "prio"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d747fc8ff20e95b1ff2c5c6162143b89a1b08cee1bc8941fada53bfe1b04d3"
dependencies = [
"base64",
"byteorder",
"getrandom",
"serde",
"static_assertions",
"thiserror",
]
[[package]]
name = "proc-macro-error"
version = "1.0.4"

View File

@@ -757,6 +757,12 @@ criteria = "safe-to-deploy"
version = "0.1.1"
notes = "This is a trivial crate."
[[audits.prio]]
who = "Simon Friedberger <simon@mozilla.com>"
criteria = "safe-to-deploy"
version = "0.8.4"
notes = "The crate does not use any unsafe code or ambient capabilities and thus meets the criteria for safe-to-deploy. The cryptography itself should be considered experimental at this phase and is currently undergoing a thorough audit organized by Cloudflare."
[[audits.proc-macro2]]
who = "Nika Layzell <nika@thelayzells.com>"
criteria = "safe-to-deploy"

View File

@@ -0,0 +1 @@
{"files":{"Cargo.lock":"dfc8012d6eccfd9a048855151a21e4c72a4cb721340e16adaee812f352c849e5","Cargo.toml":"8c592c502411530076d915389d63db9c53c664218f64f94ed0941ba1062e5ae3","LICENSE":"5f5a5db8d4baa0eea0ff2d32a5a86c7a899a3343f1496f4477f42e2d651cc6dc","README.md":"9770a3d54717932143402e4234cd0095bf1c9b1271fd8e691d16b6bd6371a998","benches/speed_tests.rs":"5c564c01be0f66c9a2034247175b965342a999d4c58e8c8072438622eb0ba9c9","documentation/releases.md":"14cfe917c88b69d557badc683b887c734254810402c7e19c9a45d815637480a9","examples/sum.rs":"b94c4701c02e0bcf4ca854a8e6edef5a0c92454984a99e5128388c6685c4f276","src/benchmarked.rs":"d570042472a0ab939f7505d8823b7eb1fe71d764f7903dee6958799d04542314","src/client.rs":"52e8a784cae6564aa54891b9e72bc42b61796955960198252fbce688f8dfc584","src/codec.rs":"fa2c87cae856337f4240e682dbf6e0807986ffb10b8fd6c0309e4bd296e046ec","src/encrypt.rs":"a12dd543719ae158cf1dd56a29e08f1d0ab5c15b14ea3ddbb5128c9ee8ea3f81","src/fft.rs":"391094b246d1ee97ade8c8af810b183a3b70d49b2ea358461fbd2a009e90640d","src/field.rs":"78153e47c3025e0b17c4c810362cb4efcf95b6213a7d53a4551c0878f669ae27","src/flp.rs":"04f29b19b0fa6c5e4e7c25cd2961469a343672343e30b262c3224ec44807684a","src/flp/gadgets.rs":"e4e423081d983bb65ce2d332e3c5383cef076a22e345c75ce1b1221f24243c73","src/flp/types.rs":"2d555561c1f65f61f98c9925ecbc68a76dfc776315bcb5732858904e99f07958","src/fp.rs":"de7f76be35463b6afe780bc4177bf12b43d78b65ac3fef6f3ea26faafd715b64","src/lib.rs":"dfa3c8dc32d643e3f7f9a8a747902b17ed65aa5d810db0c367db21f49410fd69","src/polynomial.rs":"6aa1d8474e073687d1296f4d8172f71c2a2b5236a922d7ffe0c6507a8a1403f1","src/prng.rs":"a35073e197ebb8d6ae8efaa4b7281306226b654617348eb5625a361f56327fd2","src/server.rs":"8f410d43541ae2867f0385d60020f2a1a3005f7c99f52e61f2631134c9bcc396","src/test_vector.rs":"4aaacee3994f92066a11b6bdae8a2d1a622cd19f7acd47267d9236942e384185","src/util.rs":"860497817b9d64a1dbf414ab40ec514e683649b9ea77c018462e96cde31521d0","src/vdaf.rs":"7e39b697ccf907f3e79ceea71fc8985cf095634558bee66d68732e1
b0e0f604d","src/vdaf/poplar1.rs":"495b97ee2379d1890df04b1ff42f89cfbb2327fa4e609bf266897539d792da5c","src/vdaf/prg.rs":"46e56c7932f8480bec445eccd3fd69502471f206d34d48f8062363a2e1c0f0eb","src/vdaf/prio2.rs":"c68ae84bb92a080375693216dca26d493c8b277e9e4f5432f67cf3420ad75a9f","src/vdaf/prio3.rs":"b0366a0a65e57efbedc7f1dc152d830c11261d563795fd0eeb2ca6089b8b71b5","src/vdaf/prio3_test.rs":"d67d1e02356e90a3732d7a298fa8de2bfebef3a987bfbe847f21df740254dd74","src/vdaf/test_vec/01/PrgAes128.json":"b2a88bff8f3d63966ae29d286a6388782018e8d2203402a5dc3aee092509afb9","src/vdaf/test_vec/01/Prio3Aes128Count.json":"d74d0eb2fe1530e32dee35ba7bd84c6ee151409406d0066aacbc80086d861f30","src/vdaf/test_vec/01/Prio3Aes128Histogram.json":"e1a29138dd2f4e6121806ade305f988e8f9c03f928870a874a0bd9d5d95571e5","src/vdaf/test_vec/01/Prio3Aes128Sum.json":"1303daafcbc24d2f9d2a1ec6c2565ae046143dcf91dbfaad9220abc6d2d03495","tests/backward_compatibility.rs":"56faee44ff85cad8b0f5b93a47d44fbb5ea5c1275cfa7b4309257aee73497234","tests/test_vectors/fieldpriov2.json":"4955819c163d86c1e417789981444499fc02133a9a51742af25b189ce910e4c4"},"package":"41d747fc8ff20e95b1ff2c5c6162143b89a1b08cee1bc8941fada53bfe1b04d3"}

1024
third_party/rust/prio/Cargo.lock generated vendored Normal file

File diff suppressed because it is too large Load Diff

138
third_party/rust/prio/Cargo.toml vendored Normal file
View File

@@ -0,0 +1,138 @@
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
#
# When uploading crates to the registry Cargo will automatically
# "normalize" Cargo.toml files for maximal compatibility
# with all versions of Cargo and also rewrite `path` dependencies
# to registry (e.g., crates.io) dependencies.
#
# If you are reading this file be aware that the original Cargo.toml
# will likely look very different (and much more reasonable).
# See Cargo.toml.orig for the original contents.
[package]
edition = "2018"
rust-version = "1.58"
name = "prio"
version = "0.8.4"
authors = [
    "Josh Aas <jaas@kflag.net>",
    "Tim Geoghegan <timg@letsencrypt.org>",
    # NOTE(review): the closing '>' was missing from this address in the
    # upstream crate ("<cpatton@cloudflare.com"); fixed here.
    "Christopher Patton <cpatton@cloudflare.com>",
    "Karl Tarbe <tarbe@apple.com>",
]
description = "Implementation of the Prio aggregation system core: https://crypto.stanford.edu/prio/"
readme = "README.md"
license = "MPL-2.0"
repository = "https://github.com/divviup/libprio-rs"
resolver = "2"

[package.metadata.docs.rs]
all-features = true
rustdoc-args = [
    "--cfg",
    "docsrs",
]

# The sum example exercises the Prio v2 client/server path, so it needs the
# optional "prio2" feature.
[[example]]
name = "sum"
required-features = ["prio2"]

[[test]]
name = "backward_compatibility"
path = "tests/backward_compatibility.rs"
required-features = ["prio2"]

# Criterion benchmarks supply their own main; disable the default harness.
[[bench]]
name = "speed_tests"
harness = false

[dependencies.aes]
version = "0.8.1"
optional = true

[dependencies.aes-gcm]
version = "^0.9"
optional = true

[dependencies.base64]
version = "0.13.0"

[dependencies.byteorder]
version = "1.4.3"

[dependencies.cmac]
version = "0.7.1"
optional = true

[dependencies.ctr]
version = "0.9.1"
optional = true

[dependencies.getrandom]
version = "0.2.7"
features = ["std"]

[dependencies.rand]
version = "0.8"
optional = true

[dependencies.rayon]
version = "1.5.3"
optional = true

[dependencies.ring]
version = "0.16.20"
optional = true

[dependencies.serde]
version = "1.0"
features = ["derive"]

[dependencies.serde_json]
version = "1.0"
optional = true

[dependencies.static_assertions]
version = "1.1.0"

[dependencies.thiserror]
version = "1.0"

[dev-dependencies.assert_matches]
version = "1.5.0"

[dev-dependencies.criterion]
version = "0.3"

[dev-dependencies.hex]
version = "0.4.3"
features = ["serde"]

[dev-dependencies.itertools]
version = "0.10.3"

[dev-dependencies.modinverse]
version = "0.1.0"

[dev-dependencies.num-bigint]
version = "0.4.3"

[dev-dependencies.serde_json]
version = "1.0"

[features]
crypto-dependencies = [
    "aes",
    "ctr",
    "cmac",
]
default = ["crypto-dependencies"]
multithreaded = ["rayon"]
prio2 = [
    "aes-gcm",
    "ring",
]
test-util = [
    "rand",
    "serde_json",
]

375
third_party/rust/prio/LICENSE vendored Normal file
View File

@@ -0,0 +1,375 @@
Copyright 2021 ISRG, except where otherwise noted. All rights reserved.
Mozilla Public License Version 2.0
==================================
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.

34
third_party/rust/prio/README.md vendored Normal file
View File

@@ -0,0 +1,34 @@
# libprio-rs
[![Build Status]][actions] [![Latest Version]][crates.io] [![Docs badge]][docs.rs]
[Build Status]: https://github.com/divviup/libprio-rs/workflows/ci-build/badge.svg
[actions]: https://github.com/divviup/libprio-rs/actions?query=branch%3Amain
[Latest Version]: https://img.shields.io/crates/v/prio.svg
[crates.io]: https://crates.io/crates/prio
[Docs badge]: https://img.shields.io/badge/docs.rs-rustdoc-green
[docs.rs]: https://docs.rs/prio/
Pure Rust implementation of [Prio](https://crypto.stanford.edu/prio/), a system for Private, Robust,
and Scalable Computation of Aggregate Statistics.
## Exposure Notifications Private Analytics
This crate is used in the [Exposure Notifications Private Analytics][enpa] system. This is supported
by the interfaces in modules `server` and `client` and is referred to in various places as Prio v2.
See [`prio-server`][prio-server] or the [ENPA whitepaper][enpa-whitepaper] for more details.
## Verifiable Distributed Aggregation Function (EXPERIMENTAL)
Crate `prio` also implements a [Verifiable Distributed Aggregation Function
(VDAF)][vdaf] called "Prio3", implemented in the `vdaf` module, allowing Prio to
be used in the [Distributed Aggregation Protocol][dap] protocol being developed
in the PPM working group at the IETF. This support is still experimental, and is
evolving along with the DAP and VDAF specifications. Formal security analysis is
also forthcoming. Prio3 should not yet be used in production applications.
[enpa]: https://www.abetterinternet.org/post/prio-services-for-covid-en/
[enpa-whitepaper]: https://covid19-static.cdn-apple.com/applications/covid19/current/static/contact-tracing/pdf/ENPA_White_Paper.pdf
[prio-server]: https://github.com/divviup/prio-server
[vdaf]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/
[dap]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/

View File

@@ -0,0 +1,276 @@
// SPDX-License-Identifier: MPL-2.0
use criterion::{criterion_group, criterion_main, Criterion};
use prio::benchmarked::*;
#[cfg(feature = "prio2")]
use prio::client::Client as Prio2Client;
use prio::codec::Encode;
#[cfg(feature = "prio2")]
use prio::encrypt::PublicKey;
use prio::field::{random_vector, Field128 as F, FieldElement};
#[cfg(feature = "multithreaded")]
use prio::flp::gadgets::ParallelSumMultithreaded;
use prio::flp::{
gadgets::{BlindPolyEval, Mul, ParallelSum},
types::CountVec,
Type,
};
#[cfg(feature = "prio2")]
use prio::server::{generate_verification_message, ValidationMemory};
use prio::vdaf::prio3::Prio3;
use prio::vdaf::{prio3::Prio3InputShare, Client as Prio3Client};
/// Compares the concrete performance of the iterative and recursive FFT
/// implementations over a range of input sizes.
pub fn fft(c: &mut Criterion) {
    for &n in [16, 256, 1024, 4096].iter() {
        let coefficients = random_vector(n).unwrap();
        let mut evaluations = vec![F::zero(); n];
        c.bench_function(&format!("iterative FFT, size={}", n), |b| {
            b.iter(|| {
                benchmarked_iterative_fft(&mut evaluations, &coefficients);
            })
        });
        c.bench_function(&format!("recursive FFT, size={}", n), |b| {
            b.iter(|| {
                benchmarked_recursive_fft(&mut evaluations, &coefficients);
            })
        });
    }
}
/// Measures how quickly a seed can be generated and expanded into a
/// pseudorandom sequence of field elements of each test size.
pub fn prng(c: &mut Criterion) {
    for &n in [16, 256, 1024, 4096].iter() {
        c.bench_function(&format!("rand, size={}", n), |b| {
            b.iter(|| random_vector::<F>(n))
        });
    }
}
/// Polynomial multiplication costs `O(n log n)` via FFT versus `O(n^2)` for
/// the schoolbook method, yet the naive method wins concretely on small
/// inputs. This benchmark quantifies that crossover; its result informs the
/// `FFT_THRESHOLD` constant in `src/flp/gadgets.rs`.
pub fn poly_mul(c: &mut Criterion) {
    for &n in [1_usize, 30, 60, 90, 120, 150].iter() {
        // Round up to the next power of two so the FFT path is applicable.
        let padded_len = (n + 1).next_power_of_two();
        let mut gadget: Mul<F> = Mul::new(n);
        let mut product = vec![F::zero(); 2 * padded_len];
        let operands = vec![random_vector(padded_len).unwrap(); 2];
        c.bench_function(&format!("poly mul FFT, size={}", n), |b| {
            b.iter(|| {
                benchmarked_gadget_mul_call_poly_fft(&mut gadget, &mut product, &operands).unwrap();
            })
        });
        c.bench_function(&format!("poly mul direct, size={}", n), |b| {
            b.iter(|| {
                benchmarked_gadget_mul_call_poly_direct(&mut gadget, &mut product, &operands)
                    .unwrap();
            })
        });
    }
}
/// Benchmark generation and verification of boolean vectors.
///
/// Covers the Prio2 client path (behind the "prio2" feature), the Prio3
/// `CountVec` FLP type, and — behind the "multithreaded" feature — the
/// parallel-sum variant of `CountVec`.
pub fn count_vec(c: &mut Criterion) {
let test_sizes = [10, 100, 1_000];
for size in test_sizes.iter() {
// All-zero measurement of the given length; the benchmarks measure
// proof-generation/verification cost, not data-dependent behavior.
let input = vec![F::zero(); *size];
#[cfg(feature = "prio2")]
{
// Public keys used to instantiate the v2 client.
const PUBKEY1: &str = "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=";
const PUBKEY2: &str = "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LE=";
// Prio2
let pk1 = PublicKey::from_base64(PUBKEY1).unwrap();
let pk2 = PublicKey::from_base64(PUBKEY2).unwrap();
let mut client: Prio2Client<F> =
Prio2Client::new(input.len(), pk1.clone(), pk2.clone()).unwrap();
// Report the proof size once per input length, outside the timed loop.
println!(
"prio2 proof size={}\n",
benchmarked_v2_prove(&input, &mut client).len()
);
c.bench_function(&format!("prio2 prove, size={}", *size), |b| {
b.iter(|| {
benchmarked_v2_prove(&input, &mut client);
})
});
// Build one proof up front so the query benchmark measures verification only.
let input_and_proof = benchmarked_v2_prove(&input, &mut client);
let mut validator: ValidationMemory<F> = ValidationMemory::new(input.len());
let eval_at = random_vector(1).unwrap()[0];
c.bench_function(&format!("prio2 query, size={}", *size), |b| {
b.iter(|| {
generate_verification_message(
input.len(),
eval_at,
&input_and_proof,
true,
&mut validator,
)
.unwrap();
})
});
}
// Prio3
let count_vec: CountVec<F, ParallelSum<F, BlindPolyEval<F>>> = CountVec::new(*size);
let joint_rand = random_vector(count_vec.joint_rand_len()).unwrap();
let prove_rand = random_vector(count_vec.prove_rand_len()).unwrap();
let proof = count_vec.prove(&input, &prove_rand, &joint_rand).unwrap();
println!("prio3 countvec proof size={}\n", proof.len());
c.bench_function(&format!("prio3 countvec prove, size={}", *size), |b| {
b.iter(|| {
// Fresh prover randomness each iteration; sampling is part of the timing.
let prove_rand = random_vector(count_vec.prove_rand_len()).unwrap();
count_vec.prove(&input, &prove_rand, &joint_rand).unwrap();
})
});
c.bench_function(&format!("prio3 countvec query, size={}", *size), |b| {
b.iter(|| {
let query_rand = random_vector(count_vec.query_rand_len()).unwrap();
count_vec
.query(&input, &proof, &query_rand, &joint_rand, 1)
.unwrap();
})
});
#[cfg(feature = "multithreaded")]
{
// Same benchmarks again with the multithreaded parallel-sum gadget.
let count_vec: CountVec<F, ParallelSumMultithreaded<F, BlindPolyEval<F>>> =
CountVec::new(*size);
c.bench_function(
&format!("prio3 countvec multithreaded prove, size={}", *size),
|b| {
b.iter(|| {
let prove_rand = random_vector(count_vec.prove_rand_len()).unwrap();
count_vec.prove(&input, &prove_rand, &joint_rand).unwrap();
})
},
);
c.bench_function(
&format!("prio3 countvec multithreaded query, size={}", *size),
|b| {
b.iter(|| {
let query_rand = random_vector(count_vec.query_rand_len()).unwrap();
count_vec
.query(&input, &proof, &query_rand, &joint_rand, 1)
.unwrap();
})
},
);
}
}
}
/// Benchmark prio3 client performance.
///
/// Shards a measurement for each supported VDAF variant (count, histogram,
/// sum, count-vector, and — behind the "multithreaded" feature — the parallel
/// count-vector), printing the encoded input-share size for each before timing.
pub fn prio3_client(c: &mut Criterion) {
let num_shares = 2;
// Count: a single 0/1 measurement.
let prio3 = Prio3::new_aes128_count(num_shares).unwrap();
let measurement = 1;
println!(
"prio3 count size = {}",
prio3_input_share_size(&prio3.shard(&measurement).unwrap())
);
c.bench_function("prio3 count", |b| {
b.iter(|| {
prio3.shard(&1).unwrap();
})
});
// Histogram: 9 bucket boundaries => 10 buckets.
let buckets: Vec<u64> = (1..10).collect();
let prio3 = Prio3::new_aes128_histogram(num_shares, &buckets).unwrap();
let measurement = 17;
println!(
"prio3 histogram ({} buckets) size = {}",
buckets.len() + 1,
prio3_input_share_size(&prio3.shard(&measurement).unwrap())
);
c.bench_function(
&format!("prio3 histogram ({} buckets)", buckets.len() + 1),
|b| {
b.iter(|| {
prio3.shard(&measurement).unwrap();
})
},
);
// Sum of 32-bit integers.
let bits = 32;
let prio3 = Prio3::new_aes128_sum(num_shares, bits).unwrap();
let measurement = 1337;
println!(
"prio3 sum ({} bits) size = {}",
bits,
prio3_input_share_size(&prio3.shard(&measurement).unwrap())
);
c.bench_function(&format!("prio3 sum ({} bits)", bits), |b| {
b.iter(|| {
prio3.shard(&measurement).unwrap();
})
});
// Count-vector of length 1000 (all zeros).
let len = 1000;
let prio3 = Prio3::new_aes128_count_vec(num_shares, len).unwrap();
let measurement = vec![0; len];
println!(
"prio3 countvec ({} len) size = {}",
len,
prio3_input_share_size(&prio3.shard(&measurement).unwrap())
);
c.bench_function(&format!("prio3 countvec ({} len)", len), |b| {
b.iter(|| {
prio3.shard(&measurement).unwrap();
})
});
#[cfg(feature = "multithreaded")]
{
// Same count-vector workload using the multithreaded implementation.
let prio3 = Prio3::new_aes128_count_vec_multithreaded(num_shares, len).unwrap();
let measurement = vec![0; len];
println!(
"prio3 countvec multithreaded ({} len) size = {}",
len,
prio3_input_share_size(&prio3.shard(&measurement).unwrap())
);
c.bench_function(&format!("prio3 parallel countvec ({} len)", len), |b| {
b.iter(|| {
prio3.shard(&measurement).unwrap();
})
});
}
}
/// Returns the total encoded size, in bytes, of a set of Prio3 input shares.
fn prio3_input_share_size<F: FieldElement, const L: usize>(
    input_shares: &[Prio3InputShare<F, L>],
) -> usize {
    input_shares
        .iter()
        .map(|share| share.get_encoded().len())
        .sum()
}
// Register the benchmark groups. The `count_vec` benchmark exercises Prio v2
// code paths, so it is only registered when the "prio2" feature is enabled.
#[cfg(feature = "prio2")]
criterion_group!(benches, count_vec, prio3_client, poly_mul, prng, fft);
#[cfg(not(feature = "prio2"))]
criterion_group!(benches, prio3_client, poly_mul, prng, fft);
criterion_main!(benches);

View File

@ -0,0 +1,13 @@
# Releases
We use a GitHub Action to publish a crate named `prio` to [crates.io](https://crates.io). To cut a
release and publish:
- Bump the version number in `Cargo.toml` to the new version, e.g. `1.2.3`, and merge that change to `main`.
- Tag that commit on main as `v1.2.3`, either in `git` or in [GitHub's releases UI][releases].
- Publish a release in [GitHub's releases UI][releases].
Publishing the release will automatically publish the updated [`prio` crate][crate].
[releases]: https://github.com/divviup/libprio-rs/releases/new
[crate]: https://crates.io/crates/prio

73
third_party/rust/prio/examples/sum.rs vendored Normal file
View File

@ -0,0 +1,73 @@
// SPDX-License-Identifier: MPL-2.0
use prio::client::*;
use prio::encrypt::*;
use prio::field::*;
use prio::server::*;
fn main() {
    // Hard-coded example keypairs for the two Prio servers. Each private key
    // is an X9.62 uncompressed public key concatenated with the secret
    // scalar, base64-encoded.
    let priv_key1 = PrivateKey::from_base64(
        "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgN\
         t9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==",
    )
    .unwrap();
    let priv_key2 = PrivateKey::from_base64(
        "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhF\
         LMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w==",
    )
    .unwrap();
    let pub_key1 = PublicKey::from(&priv_key1);
    let pub_key2 = PublicKey::from(&priv_key2);

    // Two clients, each secret-sharing an 8-element 0/1 vector.
    let dim = 8;
    let mut client1 = Client::new(dim, pub_key1.clone(), pub_key2.clone()).unwrap();
    let mut client2 = Client::new(dim, pub_key1, pub_key2).unwrap();

    let data1_u32 = [0, 0, 1, 0, 0, 0, 0, 0];
    println!("Client 1 Input: {:?}", data1_u32);
    let data1 = data1_u32
        .iter()
        .map(|x| Field32::from(*x))
        .collect::<Vec<Field32>>();

    let data2_u32 = [0, 0, 1, 0, 0, 0, 0, 0];
    println!("Client 2 Input: {:?}", data2_u32);
    let data2 = data2_u32
        .iter()
        .map(|x| Field32::from(*x))
        .collect::<Vec<Field32>>();

    // Each client produces one encrypted share per server.
    let (share1_1, share1_2) = client1.encode_simple(&data1).unwrap();
    let (share2_1, share2_2) = client2.encode_simple(&data2).unwrap();

    // Both servers evaluate the proof polynomial at the same point.
    let eval_at = Field32::from(12313);

    let mut server1 = Server::new(dim, true, priv_key1).unwrap();
    let mut server2 = Server::new(dim, false, priv_key2).unwrap();

    // Each server produces a verification message per submission; both
    // servers' messages for a submission are needed before it is aggregated.
    let v1_1 = server1
        .generate_verification_message(eval_at, &share1_1)
        .unwrap();
    let v1_2 = server2
        .generate_verification_message(eval_at, &share1_2)
        .unwrap();
    let v2_1 = server1
        .generate_verification_message(eval_at, &share2_1)
        .unwrap();
    let v2_2 = server2
        .generate_verification_message(eval_at, &share2_2)
        .unwrap();

    // Aggregate each verified submission into each server's running total.
    let _ = server1.aggregate(&share1_1, &v1_1, &v1_2).unwrap();
    let _ = server2.aggregate(&share1_2, &v1_1, &v1_2).unwrap();
    let _ = server1.aggregate(&share2_1, &v2_1, &v2_2).unwrap();
    let _ = server2.aggregate(&share2_2, &v2_1, &v2_2).unwrap();

    // Combine both servers' total shares to recover the aggregate result.
    server1.merge_total_shares(server2.total_shares()).unwrap();
    println!("Final Publication: {:?}", server1.total_shares());
}

View File

@ -0,0 +1,56 @@
// SPDX-License-Identifier: MPL-2.0
//! This module provides wrappers around internal components of this crate that we want to
//! benchmark, but which we don't want to expose in the public API.
#[cfg(feature = "prio2")]
use crate::client::Client;
use crate::fft::discrete_fourier_transform;
use crate::field::FieldElement;
use crate::flp::gadgets::Mul;
use crate::flp::FlpError;
use crate::polynomial::{poly_fft, PolyAuxMemory};
/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm.
pub fn benchmarked_iterative_fft<F: FieldElement>(outp: &mut [F], inp: &[F]) {
discrete_fourier_transform(outp, inp, inp.len()).unwrap();
}
/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm.
pub fn benchmarked_recursive_fft<F: FieldElement>(outp: &mut [F], inp: &[F]) {
let mut mem = PolyAuxMemory::new(inp.len() / 2);
poly_fft(
outp,
inp,
&mem.roots_2n,
inp.len(),
false,
&mut mem.fft_memory,
)
}
/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function
/// uses FFT for multiplication.
pub fn benchmarked_gadget_mul_call_poly_fft<F: FieldElement>(
g: &mut Mul<F>,
outp: &mut [F],
inp: &[Vec<F>],
) -> Result<(), FlpError> {
g.call_poly_fft(outp, inp)
}
/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function
/// does the multiplication directly.
pub fn benchmarked_gadget_mul_call_poly_direct<F: FieldElement>(
g: &mut Mul<F>,
outp: &mut [F],
inp: &[Vec<F>],
) -> Result<(), FlpError> {
g.call_poly_direct(outp, inp)
}
/// Returns a Prio v2 proof that `data` is a valid boolean vector.
#[cfg(feature = "prio2")]
pub fn benchmarked_v2_prove<F: FieldElement>(data: &[F], client: &mut Client<F>) -> Vec<F> {
client.gen_proof(data)
}

264
third_party/rust/prio/src/client.rs vendored Normal file
View File

@ -0,0 +1,264 @@
// Copyright (c) 2020 Apple Inc.
// SPDX-License-Identifier: MPL-2.0
//! The Prio v2 client. Only 0 / 1 vectors are supported for now.
use crate::{
encrypt::{encrypt_share, EncryptError, PublicKey},
field::FieldElement,
polynomial::{poly_fft, PolyAuxMemory},
prng::{Prng, PrngError},
util::{proof_length, unpack_proof_mut},
vdaf::{prg::SeedStreamAes128, VdafError},
};
use std::convert::TryFrom;
/// The main object that can be used to create Prio shares
///
/// Client is used to create Prio shares.
#[derive(Debug)]
pub struct Client<F: FieldElement> {
    // Number of elements in each input vector.
    dimension: usize,
    // Scratch buffers and PRNG, reused across proofs to avoid reallocation.
    mem: ClientMemory<F>,
    // Encryption key for the first server's share.
    public_key1: PublicKey,
    // Encryption key for the second server's share.
    public_key2: PublicKey,
}

/// Errors that might be emitted by the client.
#[derive(Debug, thiserror::Error)]
pub enum ClientError {
    /// Encryption/decryption error
    #[error("encryption/decryption error")]
    Encrypt(#[from] EncryptError),
    /// PRNG error
    #[error("prng error: {0}")]
    Prng(#[from] PrngError),
    /// VDAF error
    #[error("vdaf error: {0}")]
    Vdaf(#[from] VdafError),
}
impl<F: FieldElement> Client<F> {
    /// Construct a new Prio client
    ///
    /// Fails if `dimension` is too large for the field `F` (see
    /// `ClientMemory::new`).
    pub fn new(
        dimension: usize,
        public_key1: PublicKey,
        public_key2: PublicKey,
    ) -> Result<Self, ClientError> {
        Ok(Client {
            dimension,
            mem: ClientMemory::new(dimension)?,
            public_key1,
            public_key2,
        })
    }

    /// Construct a pair of encrypted shares based on the input data.
    pub fn encode_simple(&mut self, data: &[F]) -> Result<(Vec<u8>, Vec<u8>), ClientError> {
        let copy_data = |share_data: &mut [F]| {
            share_data[..].clone_from_slice(data);
        };
        Ok(self.encode_with(copy_data)?)
    }

    /// Construct a pair of encrypted shares using an initialization function.
    ///
    /// This might be slightly more efficient on large vectors, because one can
    /// avoid copying the input data.
    pub fn encode_with<G>(&mut self, init_function: G) -> Result<(Vec<u8>, Vec<u8>), EncryptError>
    where
        G: FnOnce(&mut [F]),
    {
        let mut proof = self.mem.prove_with(self.dimension, init_function);

        // use prng to share the proof: share2 is the PRNG seed, and proof is mutated
        // in-place
        let mut share2 = [0; 32];
        getrandom::getrandom(&mut share2)?;
        let share2_prng = Prng::from_prio2_seed(&share2);
        // Subtract the PRNG stream from the proof so that share1 plus the
        // stream expanded from share2 reconstructs the original proof vector.
        for (s1, d) in proof.iter_mut().zip(share2_prng.into_iter()) {
            *s1 -= d;
        }
        let share1 = F::slice_into_byte_vec(&proof);

        // encrypt shares with respective keys
        let encrypted_share1 = encrypt_share(&share1, &self.public_key1)?;
        let encrypted_share2 = encrypt_share(&share2, &self.public_key2)?;
        Ok((encrypted_share1, encrypted_share2))
    }

    /// Generate a proof of the input's validity. The output is the encoded input and proof.
    pub(crate) fn gen_proof(&mut self, input: &[F]) -> Vec<F> {
        let copy_data = |share_data: &mut [F]| {
            share_data[..].clone_from_slice(input);
        };
        self.mem.prove_with(self.dimension, copy_data)
    }
}
#[derive(Debug)]
pub(crate) struct ClientMemory<F> {
    // PRNG used to draw the random zero-terms of polynomials f and g.
    prng: Prng<F, SeedStreamAes128>,
    // Evaluation points of f (length n).
    points_f: Vec<F>,
    // Evaluation points of g (length n).
    points_g: Vec<F>,
    // Evaluations of f at the 2n-th roots of unity (length 2n).
    evals_f: Vec<F>,
    // Evaluations of g at the 2n-th roots of unity (length 2n).
    evals_g: Vec<F>,
    // FFT scratch memory and precomputed roots of unity.
    poly_mem: PolyAuxMemory<F>,
}

impl<F: FieldElement> ClientMemory<F> {
    /// Allocates scratch memory for inputs of length `dimension`, where `n`
    /// below is `dimension + 1` rounded up to the next power of two.
    ///
    /// Fails if `2n` either overflows the field's integer type or exceeds the
    /// order of the field's multiplicative generator (no `2n`-th roots of
    /// unity would exist).
    pub(crate) fn new(dimension: usize) -> Result<Self, VdafError> {
        let n = (dimension + 1).next_power_of_two();
        if let Ok(size) = F::Integer::try_from(2 * n) {
            if size > F::generator_order() {
                return Err(VdafError::Uncategorized(
                    "input size exceeds field capacity".into(),
                ));
            }
        } else {
            // 2n is not even representable in the field's integer type.
            return Err(VdafError::Uncategorized(
                "input size exceeds field capacity".into(),
            ));
        }
        Ok(Self {
            prng: Prng::new()?,
            points_f: vec![F::zero(); n],
            points_g: vec![F::zero(); n],
            evals_f: vec![F::zero(); 2 * n],
            evals_g: vec![F::zero(); 2 * n],
            poly_mem: PolyAuxMemory::new(n),
        })
    }
}

impl<F: FieldElement> ClientMemory<F> {
    /// Produces the encoded input-and-proof vector for the data written by
    /// `init_function`, reusing this scratch memory.
    pub(crate) fn prove_with<G>(&mut self, dimension: usize, init_function: G) -> Vec<F>
    where
        G: FnOnce(&mut [F]),
    {
        let mut proof = vec![F::zero(); proof_length(dimension)];
        // unpack one long vector to different subparts
        let unpacked = unpack_proof_mut(&mut proof, dimension).unwrap();
        // initialize the data part
        init_function(unpacked.data);
        // fill in the rest
        construct_proof(
            unpacked.data,
            dimension,
            unpacked.f0,
            unpacked.g0,
            unpacked.h0,
            unpacked.points_h_packed,
            self,
        );
        proof
    }
}
/// Convenience function if one does not want to reuse
/// [`Client`](struct.Client.html).
pub fn encode_simple<F: FieldElement>(
    data: &[F],
    public_key1: PublicKey,
    public_key2: PublicKey,
) -> Result<(Vec<u8>, Vec<u8>), ClientError> {
    // Build a throwaway client sized to the input and encode with it.
    let mut client = Client::new(data.len(), public_key1, public_key2)?;
    client.encode_simple(data)
}
/// Interpolates a polynomial through the `n` input points (given as
/// evaluations at the n-th roots of unity) and evaluates it at the 2n-th
/// roots of unity, writing the 2n results to `evals_out`.
fn interpolate_and_evaluate_at_2n<F: FieldElement>(
    n: usize,
    points_in: &[F],
    evals_out: &mut [F],
    mem: &mut PolyAuxMemory<F>,
) {
    // interpolate through roots of unity (inverse FFT recovers coefficients)
    poly_fft(
        &mut mem.coeffs,
        points_in,
        &mem.roots_n_inverted,
        n,
        true,
        &mut mem.fft_memory,
    );
    // evaluate at 2N roots of unity (forward FFT of the coefficients)
    poly_fft(
        evals_out,
        &mem.coeffs,
        &mem.roots_2n,
        2 * n,
        false,
        &mut mem.fft_memory,
    );
}
/// Proof construction
///
/// Based on Theorem 2.3.3 from Henry Corrigan-Gibbs' dissertation
/// This constructs the output \pi by doing the necessary calculations
fn construct_proof<F: FieldElement>(
    data: &[F],
    dimension: usize,
    f0: &mut F,
    g0: &mut F,
    h0: &mut F,
    points_h_packed: &mut [F],
    mem: &mut ClientMemory<F>,
) {
    let n = (dimension + 1).next_power_of_two();

    // set zero terms to random
    *f0 = mem.prng.get();
    *g0 = mem.prng.get();
    mem.points_f[0] = *f0;
    mem.points_g[0] = *g0;

    // set zero term for the proof polynomial: h = f * g
    *h0 = *f0 * *g0;

    // set f_i = data_(i - 1)
    // set g_i = f_i - 1
    // (for a 0/1 input, f_i * g_i == 0 at every data point)
    #[allow(clippy::needless_range_loop)]
    for i in 0..dimension {
        mem.points_f[i + 1] = data[i];
        mem.points_g[i + 1] = data[i] - F::one();
    }

    // interpolate and evaluate at roots of unity
    interpolate_and_evaluate_at_2n(n, &mem.points_f, &mut mem.evals_f, &mut mem.poly_mem);
    interpolate_and_evaluate_at_2n(n, &mem.points_g, &mut mem.evals_g, &mut mem.poly_mem);

    // calculate the proof polynomial as evals_f(r) * evals_g(r)
    // only add non-zero points
    // NOTE(review): only the odd-indexed evaluations are packed; the even
    // indices appear to be derivable by the verifier — see Thm 2.3.3.
    let mut j: usize = 0;
    let mut i: usize = 1;
    while i < 2 * n {
        points_h_packed[j] = mem.evals_f[i] * mem.evals_g[i];
        j += 1;
        i += 2;
    }
}
#[test]
fn test_encode() {
    use crate::field::Field32;

    // Known-good public keys (X9.62 uncompressed, base64).
    let pub_key1 = PublicKey::from_base64(
        "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=",
    )
    .unwrap();
    let pub_key2 = PublicKey::from_base64(
        "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LE=",
    )
    .unwrap();

    // Encoding a 0/1 vector must succeed.
    let input: Vec<Field32> = [0u32, 1, 0, 1, 1, 0, 0, 0, 1]
        .iter()
        .copied()
        .map(Field32::from)
        .collect();
    assert!(encode_simple(&input, pub_key1, pub_key2).is_ok());
}

598
third_party/rust/prio/src/codec.rs vendored Normal file
View File

@ -0,0 +1,598 @@
// SPDX-License-Identifier: MPL-2.0
//! Module `codec` provides support for encoding and decoding messages to or from the TLS wire
//! encoding, as specified in [RFC 8446, Section 3][1]. It provides traits that can be implemented
//! on values that need to be encoded or decoded, as well as utility functions for encoding
//! sequences of values.
//!
//! [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3
use byteorder::{BigEndian, ReadBytesExt};
use std::{
error::Error,
io::{Cursor, Read},
mem::size_of,
};
#[allow(missing_docs)]
#[derive(Debug, thiserror::Error)]
pub enum CodecError {
    // An I/O error raised while reading from the byte buffer.
    #[error("I/O error")]
    Io(#[from] std::io::Error),
    // Decoding succeeded but did not consume the entire buffer.
    #[error("{0} bytes left in buffer after decoding value")]
    BytesLeftOver(usize),
    // A variable-length vector's length prefix points past the end of the buffer.
    #[error("length prefix of encoded vector overflows buffer: {0}")]
    LengthPrefixTooBig(usize),
    // Catch-all wrapper for errors raised by value decoders.
    #[error("other error: {0}")]
    Other(#[source] Box<dyn Error + 'static + Send + Sync>),
    #[error("unexpected value")]
    UnexpectedValue,
}
/// Describes how to decode an object from a byte sequence.
pub trait Decode: Sized {
    /// Read and decode an encoded object from `bytes`. On success, the decoded value is returned
    /// and `bytes` is advanced by the encoded size of the value. On failure, an error is returned
    /// and no further attempt to read from `bytes` should be made.
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError>;

    /// Convenience method to get decoded value. Returns an error if [`Self::decode`] fails, or if
    /// there are any bytes left in `bytes` after decoding a value.
    fn get_decoded(bytes: &[u8]) -> Result<Self, CodecError> {
        // Delegates through the blanket `ParameterizedDecode` impl below.
        Self::get_decoded_with_param(&(), bytes)
    }
}

/// Describes how to decode an object from a byte sequence, with a decoding parameter provided to
/// provide additional data used in decoding.
pub trait ParameterizedDecode<P>: Sized {
    /// Read and decode an encoded object from `bytes`. `decoding_parameter` provides details of the
    /// wire encoding such as lengths of different portions of the message. On success, the decoded
    /// value is returned and `bytes` is advanced by the encoded size of the value. On failure, an
    /// error is returned and no further attempt to read from `bytes` should be made.
    fn decode_with_param(
        decoding_parameter: &P,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError>;

    /// Convenience method to get decoded value. Returns an error if [`Self::decode_with_param`]
    /// fails, or if there are any bytes left in `bytes` after decoding a value.
    fn get_decoded_with_param(decoding_parameter: &P, bytes: &[u8]) -> Result<Self, CodecError> {
        let mut cursor = Cursor::new(bytes);
        let decoded = Self::decode_with_param(decoding_parameter, &mut cursor)?;
        // Reject trailing garbage: the value must consume the whole buffer.
        if cursor.position() as usize != bytes.len() {
            return Err(CodecError::BytesLeftOver(
                bytes.len() - cursor.position() as usize,
            ));
        }
        Ok(decoded)
    }
}

// Provide a blanket implementation so that any Decode can be used as a ParameterizedDecode<T> for
// any T.
impl<D: Decode + ?Sized, T> ParameterizedDecode<T> for D {
    fn decode_with_param(
        _decoding_parameter: &T,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        Self::decode(bytes)
    }
}

/// Describes how to encode objects into a byte sequence.
pub trait Encode {
    /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
    fn encode(&self, bytes: &mut Vec<u8>);

    /// Convenience method to get encoded value.
    fn get_encoded(&self) -> Vec<u8> {
        // Delegates through the blanket `ParameterizedEncode` impl below.
        self.get_encoded_with_param(&())
    }
}

/// Describes how to encode objects into a byte sequence.
pub trait ParameterizedEncode<P> {
    /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.
    /// `encoding_parameter` provides details of the wire encoding, used to control how the value
    /// is encoded.
    fn encode_with_param(&self, encoding_parameter: &P, bytes: &mut Vec<u8>);

    /// Convenience method to get encoded value.
    fn get_encoded_with_param(&self, encoding_parameter: &P) -> Vec<u8> {
        let mut ret = Vec::new();
        self.encode_with_param(encoding_parameter, &mut ret);
        ret
    }
}

// Provide a blanket implementation so that any Encode can be used as a ParameterizedEncode<T> for
// any T.
impl<E: Encode + ?Sized, T> ParameterizedEncode<T> for E {
    fn encode_with_param(&self, _encoding_parameter: &T, bytes: &mut Vec<u8>) {
        self.encode(bytes)
    }
}
// Trivial codec for the unit type: encodes to nothing, decodes from nothing.
impl Decode for () {
    fn decode(_bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        Ok(())
    }
}

impl Encode for () {
    fn encode(&self, _bytes: &mut Vec<u8>) {}
}

// All integer codecs below use big-endian (network) byte order, matching the
// TLS presentation language.
impl Decode for u8 {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let mut value = [0u8; size_of::<u8>()];
        bytes.read_exact(&mut value)?;
        Ok(value[0])
    }
}

impl Encode for u8 {
    fn encode(&self, bytes: &mut Vec<u8>) {
        bytes.push(*self);
    }
}

impl Decode for u16 {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        Ok(bytes.read_u16::<BigEndian>()?)
    }
}

impl Encode for u16 {
    fn encode(&self, bytes: &mut Vec<u8>) {
        bytes.extend_from_slice(&u16::to_be_bytes(*self));
    }
}

/// 24 bit integer, per
/// [RFC 8446, section 3.3](https://datatracker.ietf.org/doc/html/rfc8446#section-3.3)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct U24(pub u32);

impl Decode for U24 {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        Ok(U24(bytes.read_u24::<BigEndian>()?))
    }
}

impl Encode for U24 {
    fn encode(&self, bytes: &mut Vec<u8>) {
        // Encode lower three bytes of the u32 as u24
        bytes.extend_from_slice(&u32::to_be_bytes(self.0)[1..]);
    }
}

impl Decode for u32 {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        Ok(bytes.read_u32::<BigEndian>()?)
    }
}

impl Encode for u32 {
    fn encode(&self, bytes: &mut Vec<u8>) {
        bytes.extend_from_slice(&u32::to_be_bytes(*self));
    }
}

impl Decode for u64 {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        Ok(bytes.read_u64::<BigEndian>()?)
    }
}

impl Encode for u64 {
    fn encode(&self, bytes: &mut Vec<u8>) {
        bytes.extend_from_slice(&u64::to_be_bytes(*self));
    }
}
/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xff`.
///
/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
pub fn encode_u8_items<P, E: ParameterizedEncode<P>>(
    bytes: &mut Vec<u8>,
    encoding_parameter: &P,
    items: &[E],
) {
    // Reserve space to later write length
    let len_offset = bytes.len();
    bytes.push(0);

    for item in items {
        item.encode_with_param(encoding_parameter, bytes);
    }

    // Backpatch the one-byte length prefix now that the encoded size is
    // known. Panics if the items encode to more than 0xff bytes.
    let len = bytes.len() - len_offset - 1;
    assert!(len <= u8::MAX.into());
    bytes[len_offset] = len as u8;
}
/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
/// maximum length `0xff`.
///
/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
pub fn decode_u8_items<P, D: ParameterizedDecode<P>>(
    decoding_parameter: &P,
    bytes: &mut Cursor<&[u8]>,
) -> Result<Vec<D>, CodecError> {
    // The vector is prefixed by a single length byte.
    let prefix = u8::decode(bytes)?;
    decode_items(usize::from(prefix), decoding_parameter, bytes)
}
/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xffff`.
///
/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
pub fn encode_u16_items<P, E: ParameterizedEncode<P>>(
    bytes: &mut Vec<u8>,
    encoding_parameter: &P,
    items: &[E],
) {
    // Write a two-byte placeholder for the length prefix.
    let prefix_at = bytes.len();
    0u16.encode(bytes);

    for item in items {
        item.encode_with_param(encoding_parameter, bytes);
    }

    // Backpatch the prefix with the actual encoded length (big-endian).
    let len = bytes.len() - prefix_at - 2;
    assert!(len <= u16::MAX.into());
    bytes[prefix_at..prefix_at + 2].copy_from_slice(&(len as u16).to_be_bytes());
}
/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
/// maximum length `0xffff`.
///
/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
pub fn decode_u16_items<P, D: ParameterizedDecode<P>>(
    decoding_parameter: &P,
    bytes: &mut Cursor<&[u8]>,
) -> Result<Vec<D>, CodecError> {
    // The vector is prefixed by a two-byte big-endian length.
    let prefix = u16::decode(bytes)?;
    decode_items(usize::from(prefix), decoding_parameter, bytes)
}
/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of
/// `0xffffff`.
///
/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4.
pub fn encode_u24_items<P, E: ParameterizedEncode<P>>(
    bytes: &mut Vec<u8>,
    encoding_parameter: &P,
    items: &[E],
) {
    // Write a three-byte placeholder for the length prefix.
    let prefix_at = bytes.len();
    U24(0).encode(bytes);

    for item in items {
        item.encode_with_param(encoding_parameter, bytes);
    }

    // Backpatch the prefix with the actual encoded length (big-endian u24:
    // the low three bytes of the u32).
    let len = bytes.len() - prefix_at - 3;
    assert!(len <= 0xffffff);
    bytes[prefix_at..prefix_at + 3].copy_from_slice(&u32::to_be_bytes(len as u32)[1..]);
}
/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of
/// maximum length `0xffffff`.
///
/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4
pub fn decode_u24_items<P, D: ParameterizedDecode<P>>(
    decoding_parameter: &P,
    bytes: &mut Cursor<&[u8]>,
) -> Result<Vec<D>, CodecError> {
    // The vector is prefixed by a three-byte big-endian length.
    let prefix = U24::decode(bytes)?;
    decode_items(prefix.0 as usize, decoding_parameter, bytes)
}
/// Decode the next `length` bytes from `bytes` into as many instances of `D` as possible.
fn decode_items<P, D: ParameterizedDecode<P>>(
    length: usize,
    decoding_parameter: &P,
    bytes: &mut Cursor<&[u8]>,
) -> Result<Vec<D>, CodecError> {
    let mut decoded = Vec::new();
    let initial_position = bytes.position() as usize;

    // Create cursor over specified portion of provided cursor to ensure we can't read past length.
    let inner = bytes.get_ref();

    // Make sure encoded length doesn't overflow usize or go past the end of provided byte buffer.
    let (items_end, overflowed) = initial_position.overflowing_add(length);
    if overflowed || items_end > inner.len() {
        return Err(CodecError::LengthPrefixTooBig(length));
    }

    let mut sub = Cursor::new(&bytes.get_ref()[initial_position..items_end]);
    // Decode items until exactly `length` bytes have been consumed; a decoder
    // that would read past the sub-cursor's end fails with an I/O error.
    while sub.position() < length as u64 {
        decoded.push(D::decode_with_param(decoding_parameter, &mut sub)?);
    }

    // Advance outer cursor by the amount read in the inner cursor
    bytes.set_position(initial_position as u64 + sub.position());

    Ok(decoded)
}
// Unit tests for the primitive codecs and the variable-length vector helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;

    #[test]
    fn encode_nothing() {
        let mut bytes = vec![];
        ().encode(&mut bytes);
        assert_eq!(bytes.len(), 0);
    }

    #[test]
    fn roundtrip_u8() {
        let value = 100u8;
        let mut bytes = vec![];
        value.encode(&mut bytes);
        assert_eq!(bytes.len(), 1);
        let decoded = u8::decode(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(value, decoded);
    }

    #[test]
    fn roundtrip_u16() {
        let value = 1000u16;
        let mut bytes = vec![];
        value.encode(&mut bytes);
        assert_eq!(bytes.len(), 2);
        // Check endianness of encoding
        assert_eq!(bytes, vec![3, 232]);
        let decoded = u16::decode(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(value, decoded);
    }

    #[test]
    fn roundtrip_u24() {
        let value = U24(1_000_000u32);
        let mut bytes = vec![];
        value.encode(&mut bytes);
        assert_eq!(bytes.len(), 3);
        // Check endianness of encoding
        assert_eq!(bytes, vec![15, 66, 64]);
        let decoded = U24::decode(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(value, decoded);
    }

    #[test]
    fn roundtrip_u32() {
        let value = 134_217_728u32;
        let mut bytes = vec![];
        value.encode(&mut bytes);
        assert_eq!(bytes.len(), 4);
        // Check endianness of encoding
        assert_eq!(bytes, vec![8, 0, 0, 0]);
        let decoded = u32::decode(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(value, decoded);
    }

    #[test]
    fn roundtrip_u64() {
        let value = 137_438_953_472u64;
        let mut bytes = vec![];
        value.encode(&mut bytes);
        assert_eq!(bytes.len(), 8);
        // Check endianness of encoding
        assert_eq!(bytes, vec![0, 0, 0, 32, 0, 0, 0, 0]);
        let decoded = u64::decode(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(value, decoded);
    }

    // Composite message exercising every fixed-size integer codec.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage {
        field_u8: u8,
        field_u16: u16,
        field_u24: U24,
        field_u32: u32,
        field_u64: u64,
    }

    impl Encode for TestMessage {
        fn encode(&self, bytes: &mut Vec<u8>) {
            self.field_u8.encode(bytes);
            self.field_u16.encode(bytes);
            self.field_u24.encode(bytes);
            self.field_u32.encode(bytes);
            self.field_u64.encode(bytes);
        }
    }

    impl Decode for TestMessage {
        fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
            let field_u8 = u8::decode(bytes)?;
            let field_u16 = u16::decode(bytes)?;
            let field_u24 = U24::decode(bytes)?;
            let field_u32 = u32::decode(bytes)?;
            let field_u64 = u64::decode(bytes)?;
            Ok(TestMessage {
                field_u8,
                field_u16,
                field_u24,
                field_u32,
                field_u64,
            })
        }
    }

    impl TestMessage {
        fn encoded_length() -> usize {
            // u8 field
            1 +
            // u16 field
            2 +
            // u24 field
            3 +
            // u32 field
            4 +
            // u64 field
            8
        }
    }

    #[test]
    fn roundtrip_message() {
        let value = TestMessage {
            field_u8: 0,
            field_u16: 300,
            field_u24: U24(1_000_000),
            field_u32: 134_217_728,
            field_u64: 137_438_953_472,
        };
        let mut bytes = vec![];
        value.encode(&mut bytes);
        assert_eq!(bytes.len(), TestMessage::encoded_length());
        let decoded = TestMessage::decode(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(value, decoded);
    }

    fn messages_vec() -> Vec<TestMessage> {
        vec![
            TestMessage {
                field_u8: 0,
                field_u16: 300,
                field_u24: U24(1_000_000),
                field_u32: 134_217_728,
                field_u64: 137_438_953_472,
            },
            TestMessage {
                field_u8: 0,
                field_u16: 300,
                field_u24: U24(1_000_000),
                field_u32: 134_217_728,
                field_u64: 137_438_953_472,
            },
            TestMessage {
                field_u8: 0,
                field_u16: 300,
                field_u24: U24(1_000_000),
                field_u32: 134_217_728,
                field_u64: 137_438_953_472,
            },
        ]
    }

    #[test]
    fn roundtrip_variable_length_u8() {
        let values = messages_vec();
        let mut bytes = vec![];
        encode_u8_items(&mut bytes, &(), &values);
        assert_eq!(
            bytes.len(),
            // Length of opaque vector
            1 +
            // 3 TestMessage values
            3 * TestMessage::encoded_length()
        );
        let decoded = decode_u8_items(&(), &mut Cursor::new(&bytes)).unwrap();
        assert_eq!(values, decoded);
    }

    #[test]
    fn roundtrip_variable_length_u16() {
        let values = messages_vec();
        let mut bytes = vec![];
        encode_u16_items(&mut bytes, &(), &values);
        assert_eq!(
            bytes.len(),
            // Length of opaque vector
            2 +
            // 3 TestMessage values
            3 * TestMessage::encoded_length()
        );
        // Check endianness of encoded length
        assert_eq!(bytes[0..2], [0, 3 * TestMessage::encoded_length() as u8]);
        let decoded = decode_u16_items(&(), &mut Cursor::new(&bytes)).unwrap();
        assert_eq!(values, decoded);
    }

    #[test]
    fn roundtrip_variable_length_u24() {
        let values = messages_vec();
        let mut bytes = vec![];
        encode_u24_items(&mut bytes, &(), &values);
        assert_eq!(
            bytes.len(),
            // Length of opaque vector
            3 +
            // 3 TestMessage values
            3 * TestMessage::encoded_length()
        );
        // Check endianness of encoded length
        assert_eq!(bytes[0..3], [0, 0, 3 * TestMessage::encoded_length() as u8]);
        let decoded = decode_u24_items(&(), &mut Cursor::new(&bytes)).unwrap();
        assert_eq!(values, decoded);
    }

    // A length prefix of usize::MAX must be rejected rather than overflowing.
    #[test]
    fn decode_items_overflow() {
        let encoded = vec![1u8];
        let mut cursor = Cursor::new(encoded.as_slice());
        cursor.set_position(1);
        assert_matches!(
            decode_items::<(), u8>(usize::MAX, &(), &mut cursor).unwrap_err(),
            CodecError::LengthPrefixTooBig(usize::MAX)
        );
    }

    // A length prefix pointing past the end of the buffer must be rejected.
    #[test]
    fn decode_items_too_big() {
        let encoded = vec![1u8];
        let mut cursor = Cursor::new(encoded.as_slice());
        cursor.set_position(1);
        assert_matches!(
            decode_items::<(), u8>(2, &(), &mut cursor).unwrap_err(),
            CodecError::LengthPrefixTooBig(2)
        );
    }
}

232
third_party/rust/prio/src/encrypt.rs vendored Normal file
View File

@ -0,0 +1,232 @@
// Copyright (c) 2020 Apple Inc.
// SPDX-License-Identifier: MPL-2.0
//! Utilities for ECIES encryption / decryption used by the Prio client and server.
use crate::prng::PrngError;
use aes_gcm::aead::generic_array::typenum::U16;
use aes_gcm::aead::generic_array::GenericArray;
use aes_gcm::{AeadInPlace, NewAead};
use ring::agreement;
// AES-128-GCM parameterized with a 16-byte nonce; the X9.63 KDF output below
// is split into a 16-byte key followed by a 16-byte nonce.
type Aes128 = aes_gcm::AesGcm<aes_gcm::aes::Aes128, U16>;

/// Length of the EC public key (X9.62 format)
pub const PUBLICKEY_LENGTH: usize = 65;

/// Length of the AES-GCM tag
pub const TAG_LENGTH: usize = 16;

/// Length of the symmetric AES-GCM key
const KEY_LENGTH: usize = 16;
/// Possible errors from encryption / decryption.
#[derive(Debug, thiserror::Error)]
pub enum EncryptError {
    /// Base64 decoding error
    #[error("base64 decoding error")]
    DecodeBase64(#[from] base64::DecodeError),
    /// Error in ECDH
    #[error("error in ECDH")]
    KeyAgreement,
    /// Buffer for ciphertext was not large enough
    #[error("buffer for ciphertext was not large enough")]
    Encryption,
    /// Authentication tags did not match.
    #[error("authentication tags did not match")]
    Decryption,
    /// Input ciphertext was too small
    #[error("input ciphertext was too small")]
    DecryptionLength,
    /// PRNG error
    #[error("prng error: {0}")]
    Prng(#[from] PrngError),
    /// Failure when calling getrandom().
    #[error("getrandom: {0}")]
    GetRandom(#[from] getrandom::Error),
}
/// NIST P-256, public key in X9.62 uncompressed format
// Wrapped bytes are the PUBLICKEY_LENGTH-byte X9.62 uncompressed point.
#[derive(Debug, Clone)]
pub struct PublicKey(Vec<u8>);

/// NIST P-256, private key
///
/// X9.62 uncompressed public key concatenated with the secret scalar.
#[derive(Debug, Clone)]
pub struct PrivateKey(Vec<u8>);
impl PublicKey {
/// Load public key from a base64 encoded X9.62 uncompressed representation.
pub fn from_base64(key: &str) -> Result<Self, EncryptError> {
let keydata = base64::decode(key)?;
Ok(PublicKey(keydata))
}
}
/// Copy public key from a private key
impl std::convert::From<&PrivateKey> for PublicKey {
fn from(pk: &PrivateKey) -> Self {
PublicKey(pk.0[..PUBLICKEY_LENGTH].to_owned())
}
}
impl PrivateKey {
/// Load private key from a base64 encoded string.
pub fn from_base64(key: &str) -> Result<Self, EncryptError> {
let keydata = base64::decode(key)?;
Ok(PrivateKey(keydata))
}
}
/// Encrypt a bytestring using the public key
///
/// This uses ECIES with X9.63 key derivation function and AES-GCM for the
/// symmetric encryption and MAC.
pub fn encrypt_share(share: &[u8], key: &PublicKey) -> Result<Vec<u8>, EncryptError> {
    let rng = ring::rand::SystemRandom::new();
    // Fresh ephemeral keypair for each encryption.
    let ephemeral_priv = agreement::EphemeralPrivateKey::generate(&agreement::ECDH_P256, &rng)
        .map_err(|_| EncryptError::KeyAgreement)?;
    let peer_public = agreement::UnparsedPublicKey::new(&agreement::ECDH_P256, &key.0);
    let ephemeral_pub = ephemeral_priv
        .compute_public_key()
        .map_err(|_| EncryptError::KeyAgreement)?;

    // Derive 32 bytes of key material from the shared secret:
    // 16-byte AES key followed by 16-byte nonce.
    let symmetric_key_bytes = agreement::agree_ephemeral(
        ephemeral_priv,
        &peer_public,
        EncryptError::KeyAgreement,
        |material| Ok(x963_kdf(material, ephemeral_pub.as_ref())),
    )?;

    let in_out = share.to_owned();
    let encrypted = encrypt_aes_gcm(
        &symmetric_key_bytes[..16],
        &symmetric_key_bytes[16..],
        in_out,
    )?;

    // Output layout: ephemeral public key || ciphertext+tag.
    let mut output = Vec::with_capacity(encrypted.len() + ephemeral_pub.as_ref().len());
    output.extend_from_slice(ephemeral_pub.as_ref());
    output.extend_from_slice(&encrypted);
    Ok(output)
}
/// Decrypt a bytestring using the private key
///
/// This uses ECIES with X9.63 key derivation function and AES-GCM for the
/// symmetric encryption and MAC.
pub fn decrypt_share(share: &[u8], key: &PrivateKey) -> Result<Vec<u8>, EncryptError> {
    if share.len() < PUBLICKEY_LENGTH + TAG_LENGTH {
        return Err(EncryptError::DecryptionLength);
    }
    // The sender's ephemeral public key is prepended to the ciphertext.
    let empheral_pub_bytes: &[u8] = &share[0..PUBLICKEY_LENGTH];
    let ephemeral_pub =
        agreement::UnparsedPublicKey::new(&agreement::ECDH_P256, empheral_pub_bytes);

    // Load the static private scalar through ring's fixed-output test RNG so
    // that EphemeralPrivateKey::generate reproduces it.
    let fake_rng = ring::test::rand::FixedSliceRandom {
        // private key consists of the public key + private scalar
        bytes: &key.0[PUBLICKEY_LENGTH..],
    };
    let private_key = agreement::EphemeralPrivateKey::generate(&agreement::ECDH_P256, &fake_rng)
        .map_err(|_| EncryptError::KeyAgreement)?;

    let symmetric_key_bytes = agreement::agree_ephemeral(
        private_key,
        &ephemeral_pub,
        EncryptError::KeyAgreement,
        |material| Ok(x963_kdf(material, empheral_pub_bytes)),
    )?;

    // in_out is the AES-GCM ciphertext+tag, without the ephemeral EC pubkey
    let in_out = share[PUBLICKEY_LENGTH..].to_owned();
    decrypt_aes_gcm(
        &symmetric_key_bytes[..KEY_LENGTH],
        &symmetric_key_bytes[KEY_LENGTH..],
        in_out,
    )
}
/// X9.63 KDF: one SHA-256 block over `z || counter(1) || shared_info`,
/// yielding 32 bytes of key material.
fn x963_kdf(z: &[u8], shared_info: &[u8]) -> [u8; 32] {
    use std::convert::TryInto;

    let mut ctx = ring::digest::Context::new(&ring::digest::SHA256);
    ctx.update(z);
    ctx.update(&1u32.to_be_bytes());
    ctx.update(shared_info);

    // SHA-256 output is always 32 bytes, so this conversion cannot fail.
    ctx.finish().as_ref().try_into().unwrap()
}
// AES-128-GCM decryption of `data` (ciphertext || tag) in place, with no
// associated data. Fails if tag verification fails.
fn decrypt_aes_gcm(key: &[u8], nonce: &[u8], mut data: Vec<u8>) -> Result<Vec<u8>, EncryptError> {
    Aes128::new(GenericArray::from_slice(key))
        .decrypt_in_place(GenericArray::from_slice(nonce), &[], &mut data)
        .map_err(|_| EncryptError::Decryption)?;
    Ok(data)
}
// AES-128-GCM encryption of `data` in place, with no associated data. The
// authentication tag is appended to the returned buffer.
fn encrypt_aes_gcm(key: &[u8], nonce: &[u8], mut data: Vec<u8>) -> Result<Vec<u8>, EncryptError> {
    Aes128::new(GenericArray::from_slice(key))
        .encrypt_in_place(GenericArray::from_slice(nonce), &[], &mut data)
        .map_err(|_| EncryptError::Encryption)?;
    Ok(data)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Round trip: data encrypted to a public key must decrypt back to the
    // same bytes with the matching private key.
    #[test]
    fn test_encrypt_decrypt() -> Result<(), EncryptError> {
        let pub_key = PublicKey::from_base64(
            "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9\
             HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=",
        )?;
        let priv_key = PrivateKey::from_base64(
            "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgN\
             t9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==",
        )?;
        // 100 bytes of random plaintext.
        let data = (0..100).map(|_| rand::random::<u8>()).collect::<Vec<u8>>();
        let encrypted = encrypt_share(&data, &pub_key)?;
        let decrypted = decrypt_share(&encrypted, &priv_key)?;
        assert_eq!(decrypted, data);
        Ok(())
    }

    // Fixed ciphertexts must decrypt to the expected shares, pinning
    // compatibility with ciphertexts produced by other implementations.
    #[test]
    fn test_interop() {
        let share1 = base64::decode("Kbnd2ZWrsfLfcpuxHffMrJ1b7sCrAsNqlb6Y1eAMfwCVUNXt").unwrap();
        let share2 = base64::decode("hu+vT3+8/taHP7B/dWXh/g==").unwrap();
        let encrypted_share1 = base64::decode(
            "BEWObg41JiMJglSEA6Ebk37xOeflD2a1t2eiLmX0OPccJhAER5NmOI+4r4Cfm7aJn141sGKnTbCuIB9+AeVuw\
             MAQnzjsGPu5aNgkdpp+6VowAcVAV1DlzZvtwlQkCFlX4f3xmafTPFTPOokYi2a+H1n8GKwd",
        )
        .unwrap();
        let encrypted_share2 = base64::decode(
            "BNRzQ6TbqSc4pk0S8aziVRNjWm4DXQR5yCYTK2w22iSw4XAPW4OB9RxBpWVa1C/3ywVBT/3yLArOMXEsCEMOG\
             1+d2CiEvtuU1zADH2MVaCnXL/dVXkDchYZsvPWPkDcjQA==",
        )
        .unwrap();
        let priv_key1 = PrivateKey::from_base64(
            "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOg\
             Nt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==",
        )
        .unwrap();
        let priv_key2 = PrivateKey::from_base64(
            "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhF\
             LMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w==",
        )
        .unwrap();
        let decrypted1 = decrypt_share(&encrypted_share1, &priv_key1).unwrap();
        let decrypted2 = decrypt_share(&encrypted_share2, &priv_key2).unwrap();
        assert_eq!(decrypted1, share1);
        assert_eq!(decrypted2, share2);
    }
}

226
third_party/rust/prio/src/fft.rs vendored Normal file
View File

@ -0,0 +1,226 @@
// SPDX-License-Identifier: MPL-2.0
//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier
//! Transform (DFT) over a slice of field elements.
use crate::field::FieldElement;
use crate::fp::{log2, MAX_ROOTS};
use std::convert::TryFrom;
/// An error returned by an FFT operation.
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum FftError {
/// The output is too small.
#[error("output slice is smaller than specified size")]
OutputTooSmall,
/// The specified size is too large.
#[error("size is larger than than maximum permitted")]
SizeTooLarge,
/// The specified size is not a power of 2.
#[error("size is not a power of 2")]
SizeInvalid,
}
/// Sets `outp` to the DFT of `inp`.
///
/// Interpreting the input as the coefficients of a polynomial, the output is equal to the input
/// evaluated at points `p^0, p^1, ... p^(size-1)`, where `p` is the `2^size`-th principal root of
/// unity.
#[allow(clippy::many_single_char_names)]
pub fn discrete_fourier_transform<F: FieldElement>(
    outp: &mut [F],
    inp: &[F],
    size: usize,
) -> Result<(), FftError> {
    // d is the number of butterfly rounds below; presumably log2() truncates
    // (defined in fp.rs) so the power-of-two check further down works —
    // TODO confirm against fp::log2.
    let d = usize::try_from(log2(size as u128)).map_err(|_| FftError::SizeTooLarge)?;

    if size > outp.len() {
        return Err(FftError::OutputTooSmall);
    }

    if size > 1 << MAX_ROOTS {
        return Err(FftError::SizeTooLarge);
    }

    // Rejects sizes that are not an exact power of two.
    if size != 1 << d {
        return Err(FftError::SizeInvalid);
    }

    // Copy the input into the output in bit-reversed order, zero-padding when
    // `inp` is shorter than `size`. This is the standard initial permutation
    // for an in-place iterative FFT.
    #[allow(clippy::needless_range_loop)]
    for i in 0..size {
        let j = bitrev(d, i);
        outp[i] = if j < inp.len() { inp[j] } else { F::zero() }
    }

    // Butterfly rounds: in round `l`, blocks of size 2^l are combined using
    // the 2^l-th principal root of unity `r`; `w` holds its running powers.
    let mut w: F;
    for l in 1..d + 1 {
        w = F::one();
        let r = F::root(l).unwrap();
        let y = 1 << (l - 1);
        for i in 0..y {
            for j in 0..(size / y) >> 1 {
                // `x` and `x + y` index the two inputs of one butterfly.
                let x = (1 << l) * j + i;
                let u = outp[x];
                let v = w * outp[x + y];
                outp[x] = u + v;
                outp[x + y] = u - v;
            }
            w *= r;
        }
    }

    Ok(())
}
/// Sets `outp` to the inverse of the DFT of `inp`.
#[cfg(test)]
pub(crate) fn discrete_fourier_transform_inv<F: FieldElement>(
    outp: &mut [F],
    inp: &[F],
    size: usize,
) -> Result<(), FftError> {
    // The inverse transform is the forward transform followed by scaling with
    // 1/size and reordering, which inv_finish performs.
    let n = F::Integer::try_from(size).unwrap();
    let size_inv = F::from(n).inv();
    discrete_fourier_transform(outp, inp, size)?;
    discrete_fourier_transform_inv_finish(outp, size, size_inv);
    Ok(())
}
/// An intermediate step in the computation of the inverse DFT. Exposing this function allows us to
/// amortize the cost of the modular inverse across multiple inverse DFT operations.
pub(crate) fn discrete_fourier_transform_inv_finish<F: FieldElement>(
    outp: &mut [F],
    size: usize,
    size_inv: F,
) {
    // The first element and the midpoint stay in place; everything else is
    // mirrored around the midpoint. All `size` entries are scaled by 1/size.
    outp[0] *= size_inv;
    outp[size >> 1] *= size_inv;
    let mut lo = 1;
    let mut hi = size - 1;
    while lo < hi {
        let scaled_lo = outp[lo] * size_inv;
        outp[lo] = outp[hi] * size_inv;
        outp[hi] = scaled_lo;
        lo += 1;
        hi -= 1;
    }
}
// bitrev returns the first d bits of x in reverse order. (Thanks, OEIS! https://oeis.org/A030109)
fn bitrev(d: usize, x: usize) -> usize {
    // Shift the accumulated result left and append the next low bit of `x`;
    // after d steps, bit i of x sits at position d - 1 - i.
    (0..d).fold(0, |acc, i| (acc << 1) | ((x >> i) & 1))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::{
        random_vector, split_vector, Field128, Field32, Field64, Field96, FieldPrio2,
    };
    use crate::polynomial::{poly_fft, PolyAuxMemory};

    // Round trip: the inverse DFT of the DFT must reproduce the input for a
    // range of power-of-two sizes.
    fn discrete_fourier_transform_then_inv_test<F: FieldElement>() -> Result<(), FftError> {
        let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048];
        for size in test_sizes.iter() {
            let mut tmp = vec![F::zero(); *size];
            let mut got = vec![F::zero(); *size];
            let want = random_vector(*size).unwrap();

            discrete_fourier_transform(&mut tmp, &want, want.len())?;
            discrete_fourier_transform_inv(&mut got, &tmp, tmp.len())?;
            assert_eq!(got, want);
        }

        Ok(())
    }

    #[test]
    fn test_field32() {
        discrete_fourier_transform_then_inv_test::<Field32>().expect("unexpected error");
    }

    #[test]
    fn test_priov2_field32() {
        discrete_fourier_transform_then_inv_test::<FieldPrio2>().expect("unexpected error");
    }

    #[test]
    fn test_field64() {
        discrete_fourier_transform_then_inv_test::<Field64>().expect("unexpected error");
    }

    #[test]
    fn test_field96() {
        discrete_fourier_transform_then_inv_test::<Field96>().expect("unexpected error");
    }

    #[test]
    fn test_field128() {
        discrete_fourier_transform_then_inv_test::<Field128>().expect("unexpected error");
    }

    // The iterative FFT must agree with the recursive implementation in
    // polynomial.rs on the same input.
    #[test]
    fn test_recursive_fft() {
        let size = 128;
        let mut mem = PolyAuxMemory::new(size / 2);

        let inp = random_vector(size).unwrap();
        let mut want = vec![Field32::zero(); size];
        let mut got = vec![Field32::zero(); size];

        discrete_fourier_transform::<Field32>(&mut want, &inp, inp.len()).unwrap();

        poly_fft(
            &mut got,
            &inp,
            &mem.roots_2n,
            size,
            false,
            &mut mem.fft_memory,
        );

        assert_eq!(got, want);
    }

    // This test demonstrates a consequence of \[BBG+19, Fact 4.4\]: interpolating a polynomial
    // over secret shares and summing up the coefficients is equivalent to interpolating a
    // polynomial over the plaintext data.
    #[test]
    fn test_fft_linearity() {
        let len = 16;
        let num_shares = 3;
        let x: Vec<Field64> = random_vector(len).unwrap();
        let mut x_shares = split_vector(&x, num_shares).unwrap();

        // Just for fun, let's do something different with a subset of the inputs. For the first
        // share, every odd element is set to the plaintext value. For all shares but the first,
        // every odd element is set to 0.
        #[allow(clippy::needless_range_loop)]
        for i in 0..len {
            if i % 2 != 0 {
                x_shares[0][i] = x[i];
            }
            for j in 1..num_shares {
                if i % 2 != 0 {
                    x_shares[j][i] = Field64::zero();
                }
            }
        }

        // Sum the per-share inverse DFTs...
        let mut got = vec![Field64::zero(); len];
        let mut buf = vec![Field64::zero(); len];
        for share in x_shares {
            discrete_fourier_transform_inv(&mut buf, &share, len).unwrap();
            for i in 0..len {
                got[i] += buf[i];
            }
        }

        // ...and compare with the inverse DFT of the plaintext.
        let mut want = vec![Field64::zero(); len];
        discrete_fourier_transform_inv(&mut want, &x, len).unwrap();

        assert_eq!(got, want);
    }
}

960
third_party/rust/prio/src/field.rs vendored Normal file
View File

@ -0,0 +1,960 @@
// Copyright (c) 2020 Apple Inc.
// SPDX-License-Identifier: MPL-2.0
//! Finite field arithmetic.
//!
//! Each field has an associated parameter called the "generator" that generates a multiplicative
//! subgroup of order `2^n` for some `n`.
#[cfg(feature = "crypto-dependencies")]
use crate::prng::{Prng, PrngError};
use crate::{
codec::{CodecError, Decode, Encode},
fp::{FP128, FP32, FP64, FP96},
};
use serde::{
de::{DeserializeOwned, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use std::{
cmp::min,
convert::{TryFrom, TryInto},
fmt::{self, Debug, Display, Formatter},
hash::{Hash, Hasher},
io::{Cursor, Read},
marker::PhantomData,
ops::{Add, AddAssign, BitAnd, Div, DivAssign, Mul, MulAssign, Neg, Shl, Shr, Sub, SubAssign},
};
/// Possible errors from finite field operations.
#[derive(Debug, thiserror::Error)]
pub enum FieldError {
    /// Input sizes do not match.
    #[error("input sizes do not match")]
    InputSizeMismatch,
    /// Returned when decoding a `FieldElement` from a short byte string.
    #[error("short read from bytes")]
    ShortRead,
    /// Returned when decoding a `FieldElement` from a byte string encoding an integer larger than
    /// or equal to the field modulus.
    #[error("read from byte slice exceeds modulus")]
    ModulusOverflow,
    /// Error while performing I/O.
    #[error("I/O error")]
    Io(#[from] std::io::Error),
    /// Error encoding or decoding a field.
    #[error("Codec error")]
    Codec(#[from] CodecError),
    /// Error converting to `FieldElement::Integer`.
    #[error("Integer TryFrom error")]
    IntegerTryFrom,
    /// Error converting `FieldElement::Integer` into something else.
    #[error("Integer TryInto error")]
    IntegerTryInto,
}

/// Byte order for encoding FieldElement values into byte sequences.
// Consumed by `make_field!` below to select the byte layout of the generated
// field's encoding.
#[derive(Clone, Copy, Debug)]
enum ByteOrder {
    /// Big endian byte order.
    BigEndian,
    /// Little endian byte order.
    LittleEndian,
}
/// Objects with this trait represent an element of `GF(p)` for some prime `p`.
pub trait FieldElement:
    Sized
    + Debug
    + Copy
    + PartialEq
    + Eq
    + Add<Output = Self>
    + AddAssign
    + Sub<Output = Self>
    + SubAssign
    + Mul<Output = Self>
    + MulAssign
    + Div<Output = Self>
    + DivAssign
    + Neg<Output = Self>
    + Display
    + From<<Self as FieldElement>::Integer>
    + for<'a> TryFrom<&'a [u8], Error = FieldError>
    // NOTE Ideally we would require `Into<[u8; Self::ENCODED_SIZE]>` instead of `Into<Vec<u8>>`,
    // since the former avoids a heap allocation and can easily be converted into Vec<u8>, but that
    // isn't possible yet[1]. However we can provide the impl on FieldElement implementations.
    // [1]: https://github.com/rust-lang/rust/issues/60551
    + Into<Vec<u8>>
    + Serialize
    + DeserializeOwned
    + Encode
    + Decode
    + 'static // NOTE This bound is needed for downcasting a `dyn Gadget<F>>` to a concrete type.
{
    /// Size in bytes of the encoding of a value.
    const ENCODED_SIZE: usize;

    /// The error returned if converting `usize` to an `Integer` fails.
    type IntegerTryFromError: std::error::Error;

    /// The error returned if converting an `Integer` to a `u64` fails.
    type TryIntoU64Error: std::error::Error;

    /// The integer representation of the field element.
    type Integer: Copy
        + Debug
        + Eq
        + Ord
        + BitAnd<Output = <Self as FieldElement>::Integer>
        + Div<Output = <Self as FieldElement>::Integer>
        + Shl<Output = <Self as FieldElement>::Integer>
        + Shr<Output = <Self as FieldElement>::Integer>
        + Add<Output = <Self as FieldElement>::Integer>
        + Sub<Output = <Self as FieldElement>::Integer>
        + From<Self>
        + TryFrom<usize, Error = Self::IntegerTryFromError>
        + TryInto<u64, Error = Self::TryIntoU64Error>;

    /// Modular exponentation, i.e., `self^exp (mod p)`.
    fn pow(&self, exp: Self::Integer) -> Self;

    /// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined.
    fn inv(&self) -> Self;

    /// Returns the prime modulus `p`.
    fn modulus() -> Self::Integer;

    /// Interprets the next [`Self::ENCODED_SIZE`] bytes from the input slice as an element of the
    /// field. The `m` most significant bits are cleared, where `m` is equal to the length of
    /// [`Self::Integer`] in bits minus the length of the modulus in bits.
    ///
    /// # Errors
    ///
    /// An error is returned if the provided slice is too small to encode a field element or if the
    /// result encodes an integer larger than or equal to the field modulus.
    ///
    /// # Warnings
    ///
    /// This function should only be used within [`prng::Prng`] to convert a random byte string into
    /// a field element. Use [`Self::decode`] to deserialize field elements. Use
    /// [`field::rand`] or [`prng::Prng`] to randomly generate field elements.
    #[doc(hidden)]
    fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError>;

    /// Returns the size of the multiplicative subgroup generated by `generator()`.
    fn generator_order() -> Self::Integer;

    /// Returns the generator of the multiplicative subgroup of size `generator_order()`.
    fn generator() -> Self;

    /// Returns the `2^l`-th principal root of unity for any `l <= 20`. Note that the `2^0`-th
    /// prinicpal root of unity is 1 by definition.
    fn root(l: usize) -> Option<Self>;

    /// Returns the additive identity.
    fn zero() -> Self;

    /// Returns the multiplicative identity.
    fn one() -> Self;

    /// Convert a slice of field elements into a vector of bytes.
    ///
    /// # Notes
    ///
    /// Ideally we would implement `From<&[F: FieldElement]> for Vec<u8>` or the corresponding
    /// `Into`, but the orphan rule and the stdlib's blanket implementations of `Into` make this
    /// impossible.
    fn slice_into_byte_vec(values: &[Self]) -> Vec<u8> {
        let mut vec = Vec::with_capacity(values.len() * Self::ENCODED_SIZE);
        for elem in values {
            vec.append(&mut (*elem).into());
        }
        vec
    }

    /// Convert a slice of bytes into a vector of field elements. The slice is interpreted as a
    /// sequence of [`Self::ENCODED_SIZE`]-byte sequences.
    ///
    /// # Errors
    ///
    /// Returns an error if the length of the provided byte slice is not a multiple of the size of a
    /// field element, or if any of the values in the byte slice are invalid encodings of a field
    /// element, because the encoded integer is larger than or equal to the field modulus.
    ///
    /// # Notes
    ///
    /// Ideally we would implement `From<&[u8]> for Vec<F: FieldElement>` or the corresponding
    /// `Into`, but the orphan rule and the stdlib's blanket implementations of `Into` make this
    /// impossible.
    fn byte_slice_into_vec(bytes: &[u8]) -> Result<Vec<Self>, FieldError> {
        if bytes.len() % Self::ENCODED_SIZE != 0 {
            return Err(FieldError::ShortRead);
        }
        let mut vec = Vec::with_capacity(bytes.len() / Self::ENCODED_SIZE);
        for chunk in bytes.chunks_exact(Self::ENCODED_SIZE) {
            vec.push(Self::get_decoded(chunk)?);
        }
        Ok(vec)
    }
}
/// Methods common to all `FieldElement` implementations that are private to the crate.
pub(crate) trait FieldElementExt: FieldElement {
    /// Encode `input` as `bits`-bit vector of elements of `Self` if it's small enough
    /// to be represented with that many bits.
    ///
    /// # Arguments
    ///
    /// * `input` - The field element to encode
    /// * `bits` - The number of bits to use for the encoding
    ///
    /// # Errors
    ///
    /// Returns `FieldError::InputSizeMismatch` if `input` does not fit in `bits` bits.
    fn encode_into_bitvector_representation(
        input: &Self::Integer,
        bits: usize,
    ) -> Result<Vec<Self>, FieldError> {
        // Create a mutable copy of `input`. In each iteration of the following loop we take the
        // least significant bit, and shift input to the right by one bit.
        let mut i = *input;
        let one = Self::Integer::from(Self::one());
        let mut encoded = Vec::with_capacity(bits);
        for _ in 0..bits {
            let w = Self::from(i & one);
            encoded.push(w);
            i = i >> one;
        }
        // If `i` is still not zero, this means that it cannot be encoded by `bits` bits.
        if i != Self::Integer::from(Self::zero()) {
            return Err(FieldError::InputSizeMismatch);
        }
        Ok(encoded)
    }

    /// Decode the bitvector-represented value `input` into a simple representation as a single
    /// field element.
    ///
    /// # Errors
    ///
    /// This function errors if `2^input.len() - 1` does not fit into the field `Self`.
    fn decode_from_bitvector_representation(input: &[Self]) -> Result<Self, FieldError> {
        if !Self::valid_integer_bitlength(input.len()) {
            return Err(FieldError::ModulusOverflow);
        }
        // Accumulate the power-of-two weight in the field itself rather than
        // computing `1 << l` in `usize`: for fields wider than 64 bits the
        // bitlength check above admits `input.len() > 64`, and `1usize << l`
        // would then overflow (panic in debug builds, silent wraparound in
        // release builds).
        let two = Self::from(Self::Integer::try_from(2).map_err(|_| FieldError::IntegerTryFrom)?);
        let mut weight = Self::one();
        let mut decoded = Self::zero();
        for bit in input.iter() {
            decoded += weight * *bit;
            weight *= two;
        }
        Ok(decoded)
    }

    /// Interpret `i` as [`Self::Integer`] if it's representable in that type and smaller than the
    /// field modulus.
    ///
    /// # Errors
    ///
    /// Returns `FieldError::IntegerTryFrom` if the conversion fails and
    /// `FieldError::ModulusOverflow` if the value is not smaller than the modulus.
    fn valid_integer_try_from<N>(i: N) -> Result<Self::Integer, FieldError>
    where
        Self::Integer: TryFrom<N>,
    {
        let i_int = Self::Integer::try_from(i).map_err(|_| FieldError::IntegerTryFrom)?;
        if Self::modulus() <= i_int {
            return Err(FieldError::ModulusOverflow);
        }
        Ok(i_int)
    }

    /// Check if the largest number representable with `bits` bits (i.e. 2^bits - 1) is
    /// representable in this field.
    fn valid_integer_bitlength(bits: usize) -> bool {
        // If the modulus shifted right by `bits` is nonzero, then
        // modulus >= 2^bits > 2^bits - 1, so any `bits`-bit value fits.
        if let Ok(bits_int) = Self::Integer::try_from(bits) {
            if Self::modulus() >> bits_int != Self::Integer::from(Self::zero()) {
                return true;
            }
        }
        false
    }
}

impl<F: FieldElement> FieldElementExt for F {}
/// serde Visitor implementation used to generically deserialize `FieldElement`
/// values from byte arrays.
struct FieldElementVisitor<F: FieldElement> {
    // Ties the visitor to the concrete field type it produces.
    phantom: PhantomData<F>,
}

impl<'de, F: FieldElement> Visitor<'de> for FieldElementVisitor<F> {
    type Value = F;

    fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
        formatter.write_fmt(format_args!("an array of {} bytes", F::ENCODED_SIZE))
    }

    // Byte-string representations decode directly via `TryFrom<&[u8]>`.
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        Self::Value::try_from(v).map_err(E::custom)
    }

    // Some formats represent byte arrays as sequences of integers; collect
    // them into a buffer and defer to `visit_bytes`.
    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
    where
        A: serde::de::SeqAccess<'de>,
    {
        let mut bytes = vec![];
        while let Some(byte) = seq.next_element()? {
            bytes.push(byte);
        }

        self.visit_bytes(&bytes)
    }
}
// Generates one concrete prime-field type: the struct, arithmetic operator
// impls, conversions, serde support, and the `FieldElement` impl.
macro_rules! make_field {
    (
        $(#[$meta:meta])*
        $elem:ident, $int:ident, $fp:ident, $encoding_size:literal, $encoding_order:expr,
    ) => {
        $(#[$meta])*
        ///
        /// This structure represents a field element in a prime order field. The concrete
        /// representation of the element is via the Montgomery domain. For an element n in GF(p),
        /// we store n * R^-1 mod p (where R is a given power of two). This representation enables
        /// using a more efficient (and branchless) multiplication algorithm, at the expense of
        /// having to convert elements between their Montgomery domain representation and natural
        /// representation. For calculations with many multiplications or exponentiations, this is
        /// worthwhile.
        ///
        /// As an invariant, this integer representing the field element in the Montgomery domain
        /// must be less than the prime p.
        #[derive(Clone, Copy, PartialOrd, Ord, Default)]
        pub struct $elem(u128);

        impl $elem {
            /// Attempts to instantiate an `$elem` from the first `Self::ENCODED_SIZE` bytes in the
            /// provided slice. The decoded value will be bitwise-ANDed with `mask` before reducing
            /// it using the field modulus.
            ///
            /// # Errors
            ///
            /// An error is returned if the provided slice is not long enough to encode a field
            /// element or if the decoded value is greater than the field prime.
            ///
            /// # Notes
            ///
            /// We cannot use `u128::from_le_bytes` or `u128::from_be_bytes` because those functions
            /// expect inputs to be exactly 16 bytes long. Our encoding of most field elements is
            /// more compact, and does not have to correspond to the size of an integer type. For
            /// instance,`Field96`'s encoding is 12 bytes, even though it is a 16 byte `u128` in
            /// memory.
            fn try_from_bytes(bytes: &[u8], mask: u128) -> Result<Self, FieldError> {
                if Self::ENCODED_SIZE > bytes.len() {
                    return Err(FieldError::ShortRead);
                }

                // Assemble the integer byte by byte in the configured order.
                let mut int = 0;
                for i in 0..Self::ENCODED_SIZE {
                    let j = match $encoding_order {
                        ByteOrder::LittleEndian => i,
                        ByteOrder::BigEndian => Self::ENCODED_SIZE - i - 1,
                    };

                    int |= (bytes[j] as u128) << (i << 3);
                }

                int &= mask;

                if int >= $fp.p {
                    return Err(FieldError::ModulusOverflow);
                }
                // FieldParameters::montgomery() will return a value that has been fully reduced
                // mod p, satisfying the invariant on Self.
                Ok(Self($fp.montgomery(int)))
            }
        }

        impl PartialEq for $elem {
            fn eq(&self, rhs: &Self) -> bool {
                // The fields included in this comparison MUST match the fields
                // used in Hash::hash
                // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq

                // Check the invariant that the integer representation is fully reduced.
                debug_assert!(self.0 < $fp.p);
                debug_assert!(rhs.0 < $fp.p);

                self.0 == rhs.0
            }
        }

        impl Hash for $elem {
            fn hash<H: Hasher>(&self, state: &mut H) {
                // The fields included in this hash MUST match the fields used
                // in PartialEq::eq
                // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq

                // Check the invariant that the integer representation is fully reduced.
                debug_assert!(self.0 < $fp.p);

                self.0.hash(state);
            }
        }

        impl Eq for $elem {}

        impl Add for $elem {
            type Output = $elem;
            fn add(self, rhs: Self) -> Self {
                // FieldParameters::add() returns a value that has been fully reduced
                // mod p, satisfying the invariant on Self.
                Self($fp.add(self.0, rhs.0))
            }
        }

        impl Add for &$elem {
            type Output = $elem;
            fn add(self, rhs: Self) -> $elem {
                *self + *rhs
            }
        }

        impl AddAssign for $elem {
            fn add_assign(&mut self, rhs: Self) {
                *self = *self + rhs;
            }
        }

        impl Sub for $elem {
            type Output = $elem;
            fn sub(self, rhs: Self) -> Self {
                // We know that self.0 and rhs.0 are both less than p, thus FieldParameters::sub()
                // returns a value less than p, satisfying the invariant on Self.
                Self($fp.sub(self.0, rhs.0))
            }
        }

        impl Sub for &$elem {
            type Output = $elem;
            fn sub(self, rhs: Self) -> $elem {
                *self - *rhs
            }
        }

        impl SubAssign for $elem {
            fn sub_assign(&mut self, rhs: Self) {
                *self = *self - rhs;
            }
        }

        impl Mul for $elem {
            type Output = $elem;
            fn mul(self, rhs: Self) -> Self {
                // FieldParameters::mul() always returns a value less than p, so the invariant on
                // Self is satisfied.
                Self($fp.mul(self.0, rhs.0))
            }
        }

        impl Mul for &$elem {
            type Output = $elem;
            fn mul(self, rhs: Self) -> $elem {
                *self * *rhs
            }
        }

        impl MulAssign for $elem {
            fn mul_assign(&mut self, rhs: Self) {
                *self = *self * rhs;
            }
        }

        impl Div for $elem {
            type Output = $elem;
            #[allow(clippy::suspicious_arithmetic_impl)]
            fn div(self, rhs: Self) -> Self {
                // Division is multiplication by the modular inverse.
                self * rhs.inv()
            }
        }

        impl Div for &$elem {
            type Output = $elem;
            fn div(self, rhs: Self) -> $elem {
                *self / *rhs
            }
        }

        impl DivAssign for $elem {
            fn div_assign(&mut self, rhs: Self) {
                *self = *self / rhs;
            }
        }

        impl Neg for $elem {
            type Output = $elem;
            fn neg(self) -> Self {
                // FieldParameters::neg() will return a value less than p because self.0 is less
                // than p, and neg() dispatches to sub().
                Self($fp.neg(self.0))
            }
        }

        impl Neg for &$elem {
            type Output = $elem;
            fn neg(self) -> $elem {
                -(*self)
            }
        }

        impl From<$int> for $elem {
            fn from(x: $int) -> Self {
                // FieldParameters::montgomery() will return a value that has been fully reduced
                // mod p, satisfying the invariant on Self.
                Self($fp.montgomery(u128::try_from(x).unwrap()))
            }
        }

        impl From<$elem> for $int {
            fn from(x: $elem) -> Self {
                // residue() converts out of the Montgomery domain.
                $int::try_from($fp.residue(x.0)).unwrap()
            }
        }

        impl PartialEq<$int> for $elem {
            fn eq(&self, rhs: &$int) -> bool {
                $fp.residue(self.0) == u128::try_from(*rhs).unwrap()
            }
        }

        impl<'a> TryFrom<&'a [u8]> for $elem {
            type Error = FieldError;

            fn try_from(bytes: &[u8]) -> Result<Self, FieldError> {
                Self::try_from_bytes(bytes, u128::MAX)
            }
        }

        impl From<$elem> for [u8; $elem::ENCODED_SIZE] {
            fn from(elem: $elem) -> Self {
                let int = $fp.residue(elem.0);
                let mut slice = [0; $elem::ENCODED_SIZE];
                for i in 0..$elem::ENCODED_SIZE {
                    let j = match $encoding_order {
                        ByteOrder::LittleEndian => i,
                        ByteOrder::BigEndian => $elem::ENCODED_SIZE - i - 1,
                    };

                    slice[j] = ((int >> (i << 3)) & 0xff) as u8;
                }
                slice
            }
        }

        impl From<$elem> for Vec<u8> {
            fn from(elem: $elem) -> Self {
                <[u8; $elem::ENCODED_SIZE]>::from(elem).to_vec()
            }
        }

        impl Display for $elem {
            fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
                write!(f, "{}", $fp.residue(self.0))
            }
        }

        impl Debug for $elem {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", $fp.residue(self.0))
            }
        }

        // We provide custom [`serde::Serialize`] and [`serde::Deserialize`] implementations because
        // the derived implementations would represent `FieldElement` values as the backing `u128`,
        // which is not what we want because (1) we can be more efficient in all cases and (2) in
        // some circumstances, [some serializers don't support `u128`](https://github.com/serde-rs/json/issues/625).
        impl Serialize for $elem {
            fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
                let bytes: [u8; $elem::ENCODED_SIZE] = (*self).into();
                serializer.serialize_bytes(&bytes)
            }
        }

        impl<'de> Deserialize<'de> for $elem {
            fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<$elem, D::Error> {
                deserializer.deserialize_bytes(FieldElementVisitor { phantom: PhantomData })
            }
        }

        impl Encode for $elem {
            fn encode(&self, bytes: &mut Vec<u8>) {
                let slice = <[u8; $elem::ENCODED_SIZE]>::from(*self);
                bytes.extend_from_slice(&slice);
            }
        }

        impl Decode for $elem {
            fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
                let mut value = [0u8; $elem::ENCODED_SIZE];
                bytes.read_exact(&mut value)?;
                $elem::try_from_bytes(&value, u128::MAX).map_err(|e| {
                    CodecError::Other(Box::new(e) as Box<dyn std::error::Error + 'static + Send + Sync>)
                })
            }
        }

        impl FieldElement for $elem {
            const ENCODED_SIZE: usize = $encoding_size;
            type Integer = $int;
            type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error;
            type TryIntoU64Error = <Self::Integer as TryInto<u64>>::Error;

            fn pow(&self, exp: Self::Integer) -> Self {
                // FieldParameters::pow() relies on mul(), and will always return a value less
                // than p.
                Self($fp.pow(self.0, u128::try_from(exp).unwrap()))
            }

            fn inv(&self) -> Self {
                // FieldParameters::inv() ultimately relies on mul(), and will always return a
                // value less than p.
                Self($fp.inv(self.0))
            }

            fn modulus() -> Self::Integer {
                $fp.p as $int
            }

            fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError> {
                $elem::try_from_bytes(bytes, $fp.bit_mask)
            }

            fn generator() -> Self {
                Self($fp.g)
            }

            fn generator_order() -> Self::Integer {
                1 << (Self::Integer::try_from($fp.num_roots).unwrap())
            }

            fn root(l: usize) -> Option<Self> {
                if l < min($fp.roots.len(), $fp.num_roots+1) {
                    Some(Self($fp.roots[l]))
                } else {
                    None
                }
            }

            fn zero() -> Self {
                Self(0)
            }

            fn one() -> Self {
                // roots[0] is the Montgomery-domain representation of 1.
                Self($fp.roots[0])
            }
        }
    };
}
// Concrete field instantiations. Each invocation generates a complete field
// type from the parameters in fp.rs.
make_field!(
    /// `GF(4293918721)`, a 32-bit field.
    Field32,
    u32,
    FP32,
    4,
    ByteOrder::BigEndian,
);

make_field!(
    /// Same as Field32, but encoded in little endian for compatibility with Prio v2.
    FieldPrio2,
    u32,
    FP32,
    4,
    ByteOrder::LittleEndian,
);

make_field!(
    /// `GF(18446744069414584321)`, a 64-bit field.
    Field64,
    u64,
    FP64,
    8,
    ByteOrder::BigEndian,
);

make_field!(
    /// `GF(79228148845226978974766202881)`, a 96-bit field.
    Field96,
    u128,
    FP96,
    12,
    ByteOrder::BigEndian,
);

make_field!(
    /// `GF(340282366920938462946865773367900766209)`, a 128-bit field.
    Field128,
    u128,
    FP128,
    16,
    ByteOrder::BigEndian,
);
/// Merge two vectors of fields by summing other_vector into accumulator.
///
/// # Errors
///
/// Fails if the two vectors do not have the same length.
#[cfg(any(test, feature = "prio2"))]
pub(crate) fn merge_vector<F: FieldElement>(
    accumulator: &mut [F],
    other_vector: &[F],
) -> Result<(), FieldError> {
    if accumulator.len() != other_vector.len() {
        return Err(FieldError::InputSizeMismatch);
    }
    // Element-wise in-place addition.
    accumulator
        .iter_mut()
        .zip(other_vector)
        .for_each(|(acc, rhs)| *acc += *rhs);
    Ok(())
}
/// Outputs an additive secret sharing of the input.
#[cfg(feature = "crypto-dependencies")]
pub(crate) fn split_vector<F: FieldElement>(
    inp: &[F],
    num_shares: usize,
) -> Result<Vec<Vec<F>>, PrngError> {
    if num_shares == 0 {
        return Ok(vec![]);
    }

    // The first share starts as the plaintext; each random helper share is
    // subtracted from it so that all shares sum back to the input.
    let mut leader = inp.to_vec();
    let mut helpers = Vec::with_capacity(num_shares - 1);
    for _ in 1..num_shares {
        let share: Vec<F> = random_vector(inp.len())?;
        for (l, s) in leader.iter_mut().zip(share.iter()) {
            *l -= *s;
        }
        helpers.push(share);
    }

    let mut outp = Vec::with_capacity(num_shares);
    outp.push(leader);
    outp.extend(helpers);
    Ok(outp)
}
/// Generate a vector of uniform random field elements.
#[cfg(feature = "crypto-dependencies")]
pub fn random_vector<F: FieldElement>(len: usize) -> Result<Vec<F>, PrngError> {
    let prng = Prng::new()?;
    Ok(prng.take(len).collect())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::fp::MAX_ROOTS;
use crate::prng::Prng;
use assert_matches::assert_matches;
use std::collections::hash_map::DefaultHasher;
#[test]
fn test_endianness() {
let little_endian_encoded: [u8; FieldPrio2::ENCODED_SIZE] =
FieldPrio2(0x12_34_56_78).into();
let mut big_endian_encoded: [u8; Field32::ENCODED_SIZE] = Field32(0x12_34_56_78).into();
big_endian_encoded.reverse();
assert_eq!(little_endian_encoded, big_endian_encoded);
}
#[test]
fn test_accumulate() {
let mut lhs = vec![Field32(1); 10];
let rhs = vec![Field32(2); 10];
merge_vector(&mut lhs, &rhs).unwrap();
lhs.iter().for_each(|f| assert_eq!(*f, Field32(3)));
rhs.iter().for_each(|f| assert_eq!(*f, Field32(2)));
let wrong_len = vec![Field32::zero(); 9];
let result = merge_vector(&mut lhs, &wrong_len);
assert_matches!(result, Err(FieldError::InputSizeMismatch));
}
fn hash_helper<H: Hash>(input: H) -> u64 {
let mut hasher = DefaultHasher::new();
input.hash(&mut hasher);
hasher.finish()
}
    // Some of the checks in this function, like `assert_eq!(one - one, zero)`
    // or `assert_eq!(two / two, one)` trip this clippy lint for tautological
    // comparisons, but we have a legitimate need to verify these basics. We put
    // the #[allow] on the whole function since "attributes on expressions are
    // experimental" https://github.com/rust-lang/rust/issues/15701
    //
    // Exercises the full FieldElement contract for one concrete field F:
    // arithmetic identities, pow/inv, roots of unity, (de)serialization, and
    // Eq/Hash consistency.
    #[allow(clippy::eq_op)]
    fn field_element_test<F: FieldElement + Hash>() {
        let mut prng: Prng<F, _> = Prng::new().unwrap();
        let int_modulus = F::modulus();
        let int_one = F::Integer::try_from(1).unwrap();
        let zero = F::zero();
        let one = F::one();
        let two = F::from(F::Integer::try_from(2).unwrap());
        let four = F::from(F::Integer::try_from(4).unwrap());

        // add
        assert_eq!(F::from(int_modulus - int_one) + one, zero);
        assert_eq!(one + one, two);
        assert_eq!(two + F::from(int_modulus), two);

        // sub
        assert_eq!(zero - one, F::from(int_modulus - int_one));
        assert_eq!(one - one, zero);
        assert_eq!(two - F::from(int_modulus), two);
        assert_eq!(one - F::from(int_modulus - int_one), two);

        // add + sub
        for _ in 0..100 {
            let f = prng.get();
            let g = prng.get();
            assert_eq!(f + g - f - g, zero);
            assert_eq!(f + g - g, f);
            assert_eq!(f + g - f, g);
        }

        // mul
        assert_eq!(two * two, four);
        assert_eq!(two * one, two);
        assert_eq!(two * zero, zero);
        assert_eq!(one * F::from(int_modulus), zero);

        // div
        assert_eq!(four / two, two);
        assert_eq!(two / two, one);
        assert_eq!(zero / two, zero);
        assert_eq!(two / zero, zero); // Undefined behavior
        assert_eq!(zero.inv(), zero); // Undefined behavior

        // mul + div
        for _ in 0..100 {
            let f = prng.get();
            if f == zero {
                continue;
            }
            assert_eq!(f * f.inv(), one);
            assert_eq!(f.inv() * f, one);
        }

        // pow
        assert_eq!(two.pow(F::Integer::try_from(0).unwrap()), one);
        assert_eq!(two.pow(int_one), two);
        assert_eq!(two.pow(F::Integer::try_from(2).unwrap()), four);
        assert_eq!(two.pow(int_modulus - int_one), one);
        assert_eq!(two.pow(int_modulus), two);

        // roots
        let mut int_order = F::generator_order();
        for l in 0..MAX_ROOTS + 1 {
            assert_eq!(
                F::generator().pow(int_order),
                F::root(l).unwrap(),
                "failure for F::root({})",
                l
            );
            int_order = int_order >> int_one;
        }

        // serialization
        let test_inputs = vec![zero, one, prng.get(), F::from(int_modulus - int_one)];
        for want in test_inputs.iter() {
            let mut bytes = vec![];
            want.encode(&mut bytes);

            assert_eq!(bytes.len(), F::ENCODED_SIZE);

            let got = F::get_decoded(&bytes).unwrap();
            assert_eq!(got, *want);
        }

        let serialized_vec = F::slice_into_byte_vec(&test_inputs);
        let deserialized = F::byte_slice_into_vec(&serialized_vec).unwrap();
        assert_eq!(deserialized, test_inputs);

        // equality and hash: Generate many elements, confirm they are not equal, and confirm
        // various products that should be equal have the same hash. Three is chosen as a generator
        // here because it happens to generate fairly large subgroups of (Z/pZ)* for all four
        // primes.
        let three = F::from(F::Integer::try_from(3).unwrap());
        let mut powers_of_three = Vec::with_capacity(500);
        let mut power = one;
        for _ in 0..500 {
            powers_of_three.push(power);
            power *= three;
        }

        // Check all these elements are mutually not equal.
        for i in 0..powers_of_three.len() {
            let first = &powers_of_three[i];
            for second in &powers_of_three[0..i] {
                assert_ne!(first, second);
            }
        }

        // Check that 3^i is the same whether it's calculated with pow() or repeated
        // multiplication, with both equality and hash equality.
        for (i, power) in powers_of_three.iter().enumerate() {
            let result = three.pow(F::Integer::try_from(i).unwrap());
            assert_eq!(result, *power);
            let hash1 = hash_helper(power);
            let hash2 = hash_helper(result);
            assert_eq!(hash1, hash2);
        }

        // Check that 3^n = (3^i)*(3^(n-i)), via both equality and hash equality.
        let expected_product = powers_of_three[powers_of_three.len() - 1];
        let expected_hash = hash_helper(expected_product);
        for i in 0..powers_of_three.len() {
            let a = powers_of_three[i];
            let b = powers_of_three[powers_of_three.len() - 1 - i];
            let product = a * b;
            assert_eq!(product, expected_product);
            assert_eq!(hash_helper(product), expected_hash);
        }

        // Construct an element from a number that needs to be reduced, and test comparisons on it,
        // confirming that FieldParameters::montgomery() reduced it correctly.
        let p = F::from(int_modulus);
        assert_eq!(p, zero);
        assert_eq!(hash_helper(p), hash_helper(zero));

        let p_plus_one = F::from(int_modulus + F::Integer::try_from(1).unwrap());
        assert_eq!(p_plus_one, one);
        assert_eq!(hash_helper(p_plus_one), hash_helper(one));
    }
// Run the shared field-element test battery against the 32-bit field.
#[test]
fn test_field32() {
    field_element_test::<Field32>();
}
// Run the shared field-element test battery against the Prio v2 field.
#[test]
fn test_field_priov2() {
    field_element_test::<FieldPrio2>();
}
// Run the shared field-element test battery against the 64-bit field.
#[test]
fn test_field64() {
    field_element_test::<Field64>();
}
// Run the shared field-element test battery against the 96-bit field.
#[test]
fn test_field96() {
    field_element_test::<Field96>();
}
// Run the shared field-element test battery against the 128-bit field.
#[test]
fn test_field128() {
    field_element_test::<Field128>();
}
}

1029
third_party/rust/prio/src/flp.rs vendored Normal file

File diff suppressed because it is too large Load Diff

669
third_party/rust/prio/src/flp/gadgets.rs vendored Normal file
View File

@ -0,0 +1,669 @@
// SPDX-License-Identifier: MPL-2.0
//! A collection of gadgets.
use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish};
use crate::field::FieldElement;
use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget};
use crate::polynomial::{poly_deg, poly_eval, poly_mul};
#[cfg(feature = "multithreaded")]
use rayon::prelude::*;
use std::any::Any;
use std::convert::TryFrom;
use std::fmt::Debug;
use std::marker::PhantomData;
/// For input polynomials larger than or equal to this threshold, gadgets will use FFT for
/// polynomial multiplication. Otherwise, the gadget uses direct multiplication.
// NOTE(review): the crossover value looks empirically tuned — confirm against benchmarks
// before changing it.
const FFT_THRESHOLD: usize = 60;
/// An arity-2 gadget that multiplies its inputs.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Mul<F: FieldElement> {
    /// Size of buffer for FFT operations.
    n: usize,
    /// Inverse of `n` in `F`.
    n_inv: F,
    /// The number of times this gadget will be called.
    num_calls: usize,
}
impl<F: FieldElement> Mul<F> {
    /// Return a new multiplier gadget. `num_calls` is the number of times this gadget will be
    /// called by the validity circuit.
    pub fn new(num_calls: usize) -> Self {
        // Buffer must hold the degree-2 gadget output polynomial, rounded up to a power of two.
        let n = gadget_poly_fft_mem_len(2, num_calls);
        let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
        Self {
            n,
            n_inv,
            num_calls,
        }
    }

    // Multiply input polynomials directly.
    pub(crate) fn call_poly_direct(
        &mut self,
        outp: &mut [F],
        inp: &[Vec<F>],
    ) -> Result<(), FlpError> {
        let v = poly_mul(&inp[0], &inp[1]);
        // Only the first `v.len()` coefficients are written; any remaining entries of `outp`
        // are left as-is.
        outp[..v.len()].clone_from_slice(&v);
        Ok(())
    }

    // Multiply input polynomials using FFT.
    pub(crate) fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        let n = self.n;
        let mut buf = vec![F::zero(); n];
        // Evaluate both inputs at the n-th roots of unity, multiply pointwise, then
        // interpolate the product back to coefficient form (transform + inverse-finish).
        discrete_fourier_transform(&mut buf, &inp[0], n)?;
        discrete_fourier_transform(outp, &inp[1], n)?;
        for i in 0..n {
            buf[i] *= outp[i];
        }
        discrete_fourier_transform(outp, &buf, n)?;
        discrete_fourier_transform_inv_finish(outp, n, self.n_inv);
        Ok(())
    }
}
impl<F: FieldElement> Gadget<F> for Mul<F> {
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        gadget_call_check(self, inp.len())?;
        let product = inp[0] * inp[1];
        Ok(product)
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;
        // Pick the multiplication strategy based on the wire-polynomial length.
        let use_fft = inp[0].len() >= FFT_THRESHOLD;
        if use_fft {
            self.call_poly_fft(outp, inp)
        } else {
            self.call_poly_direct(outp, inp)
        }
    }

    // This gadget takes two inputs and multiplies them, hence arity 2 and degree 2.
    fn arity(&self) -> usize {
        2
    }

    fn degree(&self) -> usize {
        2
    }

    fn calls(&self) -> usize {
        self.num_calls
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}
/// An arity-1 gadget that evaluates its input on some polynomial.
//
// TODO Make `poly` an array of length determined by a const generic.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PolyEval<F: FieldElement> {
    /// Coefficients of the polynomial, in ascending order of degree.
    poly: Vec<F>,
    /// Size of buffer for FFT operations.
    n: usize,
    /// Inverse of `n` in `F`.
    n_inv: F,
    /// The number of times this gadget will be called.
    num_calls: usize,
}
impl<F: FieldElement> PolyEval<F> {
    /// Returns a gadget that evaluates its input on `poly`. `num_calls` is the number of times
    /// this gadget is called by the validity circuit.
    pub fn new(poly: Vec<F>, num_calls: usize) -> Self {
        // FFT buffer size: enough for the gadget output polynomial, rounded up to a power
        // of two.
        let fft_len = gadget_poly_fft_mem_len(poly_deg(&poly), num_calls);
        let fft_len_inv = F::from(F::Integer::try_from(fft_len).unwrap()).inv();
        Self {
            poly,
            n: fft_len,
            n_inv: fft_len_inv,
            num_calls,
        }
    }
}
impl<F: FieldElement> PolyEval<F> {
    // Multiply input polynomials directly.
    fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        // NOTE(review): indexing `self.poly[0]` assumes a non-empty coefficient vector —
        // confirm constructors can never produce an empty `poly`.
        outp[0] = self.poly[0];
        let mut x = inp[0].to_vec();
        // Accumulate poly[i] * inp^i for each coefficient; `x` tracks inp^i.
        for i in 1..self.poly.len() {
            for j in 0..x.len() {
                outp[j] += self.poly[i] * x[j];
            }
            // Skip the final power-raise; it would go unused.
            if i < self.poly.len() - 1 {
                x = poly_mul(&x, &inp[0]);
            }
        }
        Ok(())
    }

    // Multiply input polynomials using FFT.
    fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        let n = self.n;
        let inp = &inp[0];

        // `inp_vals` holds the input evaluated at the n-th roots of unity; `x_vals` tracks
        // the evaluations of inp^i and `x` its coefficient form.
        let mut inp_vals = vec![F::zero(); n];
        discrete_fourier_transform(&mut inp_vals, inp, n)?;

        let mut x_vals = inp_vals.clone();
        let mut x = vec![F::zero(); n];
        x[..inp.len()].clone_from_slice(inp);

        outp[0] = self.poly[0];
        for i in 1..self.poly.len() {
            for j in 0..n {
                outp[j] += self.poly[i] * x[j];
            }
            if i < self.poly.len() - 1 {
                // Raise inp^i to inp^(i+1) by pointwise multiplication, then interpolate.
                for j in 0..n {
                    x_vals[j] *= inp_vals[j];
                }
                discrete_fourier_transform(&mut x, &x_vals, n)?;
                discrete_fourier_transform_inv_finish(&mut x, n, self.n_inv);
            }
        }
        Ok(())
    }
}
impl<F: FieldElement> Gadget<F> for PolyEval<F> {
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        gadget_call_check(self, inp.len())?;
        Ok(poly_eval(&self.poly, inp[0]))
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;
        // Clear the output buffer first; both evaluation strategies accumulate into it.
        outp.iter_mut().for_each(|coeff| *coeff = F::zero());
        if inp[0].len() < FFT_THRESHOLD {
            self.call_poly_direct(outp, inp)
        } else {
            self.call_poly_fft(outp, inp)
        }
    }

    // A single wire feeds this gadget; its degree is that of the evaluated polynomial.
    fn arity(&self) -> usize {
        1
    }

    fn degree(&self) -> usize {
        poly_deg(&self.poly)
    }

    fn calls(&self) -> usize {
        self.num_calls
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}
/// An arity-2 gadget that returns `poly(in[0]) * in[1]` for some polynomial `poly`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BlindPolyEval<F: FieldElement> {
    /// Coefficients of the polynomial, in ascending order of degree.
    poly: Vec<F>,
    /// Size of buffer for the outer FFT multiplication.
    n: usize,
    /// Inverse of `n` in `F`.
    n_inv: F,
    /// The number of times this gadget will be called.
    num_calls: usize,
}
impl<F: FieldElement> BlindPolyEval<F> {
    /// Returns a `BlindPolyEval` gadget for polynomial `poly`.
    pub fn new(poly: Vec<F>, num_calls: usize) -> Self {
        // The extra `+ 1` accounts for the multiplication by the blinding wire `in[1]`.
        let n = gadget_poly_fft_mem_len(poly_deg(&poly) + 1, num_calls);
        let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv();
        Self {
            poly,
            n,
            n_inv,
            num_calls,
        }
    }

    // Accumulate poly[i] * y * x^i directly; `z` tracks y * x^i.
    fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        let x = &inp[0];
        let y = &inp[1];

        let mut z = y.to_vec();
        for i in 0..self.poly.len() {
            for j in 0..z.len() {
                outp[j] += self.poly[i] * z[j];
            }
            // Skip the final multiplication; the next power would go unused.
            if i < self.poly.len() - 1 {
                z = poly_mul(&z, x);
            }
        }
        Ok(())
    }

    // Same accumulation as `call_poly_direct`, but the repeated products are done pointwise
    // in the evaluation domain and interpolated back per step.
    fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        let n = self.n;
        let x = &inp[0];
        let y = &inp[1];

        let mut x_vals = vec![F::zero(); n];
        discrete_fourier_transform(&mut x_vals, x, n)?;

        let mut z_vals = vec![F::zero(); n];
        discrete_fourier_transform(&mut z_vals, y, n)?;

        let mut z = vec![F::zero(); n];
        // `z_len` tracks how many coefficients of `z` are meaningful at each step.
        let mut z_len = y.len();
        z[..y.len()].clone_from_slice(y);

        for i in 0..self.poly.len() {
            for j in 0..z_len {
                outp[j] += self.poly[i] * z[j];
            }

            if i < self.poly.len() - 1 {
                for j in 0..n {
                    z_vals[j] *= x_vals[j];
                }

                discrete_fourier_transform(&mut z, &z_vals, n)?;
                discrete_fourier_transform_inv_finish(&mut z, n, self.n_inv);
                z_len += x.len();
            }
        }
        Ok(())
    }
}
impl<F: FieldElement> Gadget<F> for BlindPolyEval<F> {
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        gadget_call_check(self, inp.len())?;
        let evaluated = poly_eval(&self.poly, inp[0]);
        Ok(inp[1] * evaluated)
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;
        // Clear the output buffer first; both evaluation strategies accumulate into it.
        outp.iter_mut().for_each(|coeff| *coeff = F::zero());
        if inp[0].len() < FFT_THRESHOLD {
            self.call_poly_direct(outp, inp)
        } else {
            self.call_poly_fft(outp, inp)
        }
    }

    // Two wires: the polynomial input and the blinding factor. The product with the
    // blinding wire raises the degree by one.
    fn arity(&self) -> usize {
        2
    }

    fn degree(&self) -> usize {
        poly_deg(&self.poly) + 1
    }

    fn calls(&self) -> usize {
        self.num_calls
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}
/// Marker trait for abstracting over [`ParallelSum`].
pub trait ParallelSumGadget<F: FieldElement, G>: Gadget<F> + Debug {
    /// Wraps `inner` into a sum gadget with `chunks` chunks.
    fn new(inner: G, chunks: usize) -> Self;
}
/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParallelSum<F: FieldElement, G: Gadget<F>> {
    /// The gadget applied to each chunk of the input.
    inner: G,
    /// The number of chunks the input is split into.
    chunks: usize,
    phantom: PhantomData<F>,
}
impl<F: FieldElement, G: 'static + Gadget<F>> ParallelSumGadget<F, G> for ParallelSum<F, G> {
    // Construct a sum gadget over `chunks` applications of `inner`.
    fn new(inner: G, chunks: usize) -> Self {
        Self {
            inner,
            chunks,
            phantom: PhantomData,
        }
    }
}
impl<F: FieldElement, G: 'static + Gadget<F>> Gadget<F> for ParallelSum<F, G> {
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        gadget_call_check(self, inp.len())?;
        // Feed each arity-sized slice of the input to the inner gadget and total the results.
        let chunk_size = self.inner.arity();
        let mut total = F::zero();
        for chunk in inp.chunks(chunk_size) {
            total += self.inner.call(chunk)?;
        }
        Ok(total)
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;
        for slot in outp.iter_mut() {
            *slot = F::zero();
        }

        // Scratch buffer reused across chunks; the inner gadget writes into it each round.
        let mut chunk_out = vec![F::zero(); outp.len()];
        for chunk in inp.chunks(self.inner.arity()) {
            self.inner.call_poly(&mut chunk_out, chunk)?;
            for (acc, term) in outp.iter_mut().zip(chunk_out.iter()) {
                *acc += *term;
            }
        }
        Ok(())
    }

    // The wrapper exposes `chunks` copies of the inner gadget's wires.
    fn arity(&self) -> usize {
        self.chunks * self.inner.arity()
    }

    fn degree(&self) -> usize {
        self.inner.degree()
    }

    fn calls(&self) -> usize {
        self.inner.calls()
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}
/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the
/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. The sum
/// evaluation is multithreaded.
#[cfg(feature = "multithreaded")]
#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParallelSumMultithreaded<F: FieldElement, G: Gadget<F>> {
    /// The equivalent single-threaded gadget; delegated to for everything except `call_poly`.
    serial_sum: ParallelSum<F, G>,
}
#[cfg(feature = "multithreaded")]
impl<F, G> ParallelSumGadget<F, G> for ParallelSumMultithreaded<F, G>
where
    F: FieldElement + Sync + Send,
    G: 'static + Gadget<F> + Clone + Sync,
{
    // Construct by wrapping the single-threaded variant.
    fn new(inner: G, chunks: usize) -> Self {
        Self {
            serial_sum: ParallelSum::new(inner, chunks),
        }
    }
}
#[cfg(feature = "multithreaded")]
impl<F, G> Gadget<F> for ParallelSumMultithreaded<F, G>
where
    F: FieldElement + Sync + Send,
    G: 'static + Gadget<F> + Clone + Sync,
{
    // Scalar calls are cheap; delegate to the serial implementation.
    fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
        self.serial_sum.call(inp)
    }

    fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> {
        gadget_call_poly_check(self, outp, inp)?;
        // Evaluate each chunk on a rayon worker and sum the per-chunk polynomials.
        let res = inp
            .par_chunks(self.serial_sum.inner.arity())
            .map(|chunk| {
                // Each worker gets its own clone, since the inner gadget is stateful (&mut).
                let mut inner = self.serial_sum.inner.clone();
                let mut partial_outp = vec![F::zero(); outp.len()];

                // NOTE(review): an inner-gadget error panics the worker here; presumably
                // acceptable because inputs were validated by gadget_call_poly_check — confirm.
                inner.call_poly(&mut partial_outp, chunk).unwrap();
                partial_outp
            })
            .reduce(
                || vec![F::zero(); outp.len()],
                |mut x, y| {
                    for i in 0..x.len() {
                        x[i] += y[i];
                    }
                    x
                },
            );

        outp.clone_from_slice(&res[..outp.len()]);
        Ok(())
    }

    fn arity(&self) -> usize {
        self.serial_sum.arity()
    }

    fn degree(&self) -> usize {
        self.serial_sum.degree()
    }

    fn calls(&self) -> usize {
        self.serial_sum.calls()
    }

    fn as_any(&mut self) -> &mut dyn Any {
        self
    }
}
// Check that the input parameters of g.call() are well-formed: the number of inputs must
// equal the gadget's arity, and an arity-0 gadget is never callable.
fn gadget_call_check<F: FieldElement, G: Gadget<F>>(
    gadget: &G,
    in_len: usize,
) -> Result<(), FlpError> {
    let arity = gadget.arity();
    if in_len != arity {
        return Err(FlpError::Gadget(format!(
            "unexpected number of inputs: got {}; want {}",
            in_len, arity
        )));
    }

    if in_len == 0 {
        return Err(FlpError::Gadget("can't call an arity-0 gadget".to_string()));
    }

    Ok(())
}
// Check that the input parameters of g.call_poly() are well-formed: the wire polynomials
// must all have the same length, and the output buffer must have exactly the gadget-
// polynomial length rounded up to the next power of two (as required by the FFT path).
//
// Fix: the trait bound `G: Gadget<F>` was declared twice — once inline and once in a
// redundant `where` clause. The duplicate is dropped for consistency with
// `gadget_call_check`; the effective bounds are unchanged.
fn gadget_call_poly_check<F: FieldElement, G: Gadget<F>>(
    gadget: &G,
    outp: &[F],
    inp: &[Vec<F>],
) -> Result<(), FlpError> {
    gadget_call_check(gadget, inp.len())?;

    // All wire polynomials must match the first one's length.
    for wire in inp.iter().skip(1) {
        if wire.len() != inp[0].len() {
            return Err(FlpError::Gadget(
                "gadget called on wire polynomials with different lengths".to_string(),
            ));
        }
    }

    let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two();
    if outp.len() != expected {
        return Err(FlpError::Gadget(format!(
            "incorrect output length: got {}; want {}",
            outp.len(),
            expected
        )));
    }

    Ok(())
}
// Size of the FFT working buffer for a gadget of the given degree called `num_calls`
// times: the gadget-polynomial length rounded up to the next power of two.
#[inline]
fn gadget_poly_fft_mem_len(degree: usize, num_calls: usize) -> usize {
    gadget_poly_len(degree, wire_poly_len(num_calls)).next_power_of_two()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::{random_vector, Field96 as TestField};
    use crate::prng::Prng;

    #[test]
    fn test_mul() {
        // Test the gadget with input polynomials shorter than `FFT_THRESHOLD`. This exercises the
        // naive multiplication code path.
        let num_calls = FFT_THRESHOLD / 2;
        let mut g: Mul<TestField> = Mul::new(num_calls);
        gadget_test(&mut g, num_calls);

        // Test the gadget with input polynomials longer than `FFT_THRESHOLD`. This exercises
        // FFT-based polynomial multiplication.
        let num_calls = FFT_THRESHOLD;
        let mut g: Mul<TestField> = Mul::new(num_calls);
        gadget_test(&mut g, num_calls);
    }

    #[test]
    fn test_poly_eval() {
        let poly: Vec<TestField> = random_vector(10).unwrap();

        // Below the threshold: direct path.
        let num_calls = FFT_THRESHOLD / 2;
        let mut g: PolyEval<TestField> = PolyEval::new(poly.clone(), num_calls);
        gadget_test(&mut g, num_calls);

        // At the threshold: FFT path.
        let num_calls = FFT_THRESHOLD;
        let mut g: PolyEval<TestField> = PolyEval::new(poly, num_calls);
        gadget_test(&mut g, num_calls);
    }

    #[test]
    fn test_blind_poly_eval() {
        let poly: Vec<TestField> = random_vector(10).unwrap();

        // Below the threshold: direct path.
        let num_calls = FFT_THRESHOLD / 2;
        let mut g: BlindPolyEval<TestField> = BlindPolyEval::new(poly.clone(), num_calls);
        gadget_test(&mut g, num_calls);

        // At the threshold: FFT path.
        let num_calls = FFT_THRESHOLD;
        let mut g: BlindPolyEval<TestField> = BlindPolyEval::new(poly, num_calls);
        gadget_test(&mut g, num_calls);
    }

    #[test]
    fn test_parallel_sum() {
        let poly: Vec<TestField> = random_vector(10).unwrap();
        let num_calls = 10;
        let chunks = 23;

        let mut g = ParallelSum::new(BlindPolyEval::new(poly, num_calls), chunks);
        gadget_test(&mut g, num_calls);
    }

    #[test]
    #[cfg(feature = "multithreaded")]
    fn test_parallel_sum_multithreaded() {
        use std::iter;

        let poly: Vec<TestField> = random_vector(10).unwrap();
        let num_calls = 10;
        let chunks = 23;

        let mut g =
            ParallelSumMultithreaded::new(BlindPolyEval::new(poly.clone(), num_calls), chunks);
        gadget_test(&mut g, num_calls);

        // Test that the multithreaded version has the same output as the normal version.
        let mut g_serial = ParallelSum::new(BlindPolyEval::new(poly, num_calls), chunks);
        assert_eq!(g.arity(), g_serial.arity());
        assert_eq!(g.degree(), g_serial.degree());
        assert_eq!(g.calls(), g_serial.calls());

        let arity = g.arity();
        let degree = g.degree();

        // Test that both gadgets evaluate to the same value when run on scalar inputs.
        let inp: Vec<TestField> = random_vector(arity).unwrap();
        let result = g.call(&inp).unwrap();
        let result_serial = g_serial.call(&inp).unwrap();
        assert_eq!(result, result_serial);

        // Test that both gadgets evaluate to the same value when run on polynomial inputs.
        let mut poly_outp = vec![TestField::zero(); (degree * (1 + num_calls)).next_power_of_two()];
        let mut poly_outp_serial =
            vec![TestField::zero(); (degree * (1 + num_calls)).next_power_of_two()];
        let mut prng: Prng<TestField, _> = Prng::new().unwrap();
        let poly_inp: Vec<_> = iter::repeat_with(|| {
            iter::repeat_with(|| prng.get())
                .take(1 + num_calls)
                .collect::<Vec<_>>()
        })
        .take(arity)
        .collect();

        g.call_poly(&mut poly_outp, &poly_inp).unwrap();
        g_serial
            .call_poly(&mut poly_outp_serial, &poly_inp)
            .unwrap();
        assert_eq!(poly_outp, poly_outp_serial);
    }

    // Test that calling g.call_poly() and evaluating the output at a given point is equivalent
    // to evaluating each of the inputs at the same point and applying g.call() on the results.
    fn gadget_test<F: FieldElement, G: Gadget<F>>(g: &mut G, num_calls: usize) {
        let wire_poly_len = (1 + num_calls).next_power_of_two();
        let mut prng = Prng::new().unwrap();
        let mut inp = vec![F::zero(); g.arity()];
        let mut gadget_poly = vec![F::zero(); gadget_poly_fft_mem_len(g.degree(), num_calls)];
        let mut wire_polys = vec![vec![F::zero(); wire_poly_len]; g.arity()];

        let r = prng.get();
        for i in 0..g.arity() {
            for j in 0..wire_poly_len {
                wire_polys[i][j] = prng.get();
            }
            inp[i] = poly_eval(&wire_polys[i], r);
        }

        g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
        let got = poly_eval(&gadget_poly, r);
        let want = g.call(&inp).unwrap();
        assert_eq!(got, want);

        // Repeat the call to make sure that the gadget's memory is reset properly between calls.
        g.call_poly(&mut gadget_poly, &wire_polys).unwrap();
        let got = poly_eval(&gadget_poly, r);
        assert_eq!(got, want);
    }
}

1174
third_party/rust/prio/src/flp/types.rs vendored Normal file

File diff suppressed because it is too large Load Diff

561
third_party/rust/prio/src/fp.rs vendored Normal file
View File

@ -0,0 +1,561 @@
// SPDX-License-Identifier: MPL-2.0
//! Finite field arithmetic for any field GF(p) for which p < 2^128.
#[cfg(test)]
use rand::{prelude::*, Rng};
/// For each set of field parameters we pre-compute the 1st, 2nd, 4th, ..., 2^20-th principal roots
/// of unity. The largest of these is used to run the FFT algorithm on an input of size 2^20. This
/// is the largest input size we would ever need for the cryptographic applications in this crate.
pub(crate) const MAX_ROOTS: usize = 20;
/// This structure represents the parameters of a finite field GF(p) for which p < 2^128.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct FieldParameters {
    /// The prime modulus `p`.
    pub p: u128,
    /// `mu = -p^(-1) mod 2^64`.
    pub mu: u64,
    /// `r2 = (2^128)^2 mod p`.
    pub r2: u128,
    /// The `2^num_roots`-th principal root of unity. This element is used to generate the
    /// elements of `roots`.
    pub g: u128,
    /// The number of principal roots of unity in `roots`.
    pub num_roots: usize,
    /// Equal to `2^b - 1`, where `b` is the length of `p` in bits.
    pub bit_mask: u128,
    /// `roots[l]` is the `2^l`-th principal root of unity, i.e., `roots[l]` has order `2^l` in the
    /// multiplicative group. `roots[0]` is equal to one by definition.
    pub roots: [u128; MAX_ROOTS + 1],
}
impl FieldParameters {
    /// Addition. The result will be in [0, p), so long as both x and y are as well.
    // Implemented branch-free with masks (presumably to keep the operation constant-time —
    // not verified here).
    pub fn add(&self, x: u128, y: u128) -> u128 {
        //   0,x
        // + 0,y
        // =====
        //   c,z
        let (z, carry) = x.overflowing_add(y);
        //      c, z
        // -    0, p
        // ========
        // b1,s1,s0
        let (s0, b0) = z.overflowing_sub(self.p);
        let (_s1, b1) = (carry as u128).overflowing_sub(b0 as u128);
        // if b1 == 1: return z
        // else:       return s0
        let m = 0u128.wrapping_sub(b1 as u128);
        (z & m) | (s0 & !m)
    }

    /// Subtraction. The result will be in [0, p), so long as both x and y are as well.
    pub fn sub(&self, x: u128, y: u128) -> u128 {
        //     0, x
        // -   0, y
        // ========
        // b1,z1,z0
        let (z0, b0) = x.overflowing_sub(y);
        let (_z1, b1) = 0u128.overflowing_sub(b0 as u128);
        let m = 0u128.wrapping_sub(b1 as u128);
        //   z1,z0
        // +  0, p
        // ========
        //   s1,s0
        z0.wrapping_add(m & self.p)
        // if b1 == 1: return s0
        // else:       return z0
    }

    /// Multiplication of field elements in the Montgomery domain. This uses the REDC algorithm
    /// described
    /// [here](https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf).
    /// The result will be in [0, p).
    ///
    /// # Example usage
    /// ```text
    /// assert_eq!(fp.residue(fp.mul(fp.montgomery(23), fp.montgomery(2))), 46);
    /// ```
    // The schoolbook product and the two REDC folding steps below are written out in 64-bit
    // limbs; the statement order is load-bearing, so do not reorder.
    pub fn mul(&self, x: u128, y: u128) -> u128 {
        let x = [lo64(x), hi64(x)];
        let y = [lo64(y), hi64(y)];
        let p = [lo64(self.p), hi64(self.p)];
        let mut zz = [0; 4];

        // Integer multiplication
        // z = x * y

        //       x1,x0
        // *     y1,y0
        // ===========
        // z3,z2,z1,z0
        let mut result = x[0] * y[0];
        let mut carry = hi64(result);
        zz[0] = lo64(result);
        result = x[0] * y[1];
        let mut hi = hi64(result);
        let mut lo = lo64(result);
        result = lo + carry;
        zz[1] = lo64(result);
        let mut cc = hi64(result);
        result = hi + cc;
        zz[2] = lo64(result);

        result = x[1] * y[0];
        hi = hi64(result);
        lo = lo64(result);
        result = zz[1] + lo;
        zz[1] = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        carry = lo64(result);

        result = x[1] * y[1];
        hi = hi64(result);
        lo = lo64(result);
        result = lo + carry;
        lo = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        hi = lo64(result);
        result = zz[2] + lo;
        zz[2] = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        zz[3] = lo64(result);

        // Montgomery Reduction
        // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64.

        // z3,z2,z1,z0
        // +     p1,p0
        // *           w = mu*z0
        // ===========
        // z3,z2,z1, 0
        let w = self.mu.wrapping_mul(zz[0] as u64);
        result = p[0] * (w as u128);
        hi = hi64(result);
        lo = lo64(result);
        result = zz[0] + lo;
        zz[0] = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        carry = lo64(result);

        result = p[1] * (w as u128);
        hi = hi64(result);
        lo = lo64(result);
        result = lo + carry;
        lo = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        hi = lo64(result);
        result = zz[1] + lo;
        zz[1] = lo64(result);
        cc = hi64(result);
        result = zz[2] + hi + cc;
        zz[2] = lo64(result);
        cc = hi64(result);
        result = zz[3] + cc;
        zz[3] = lo64(result);

        // z3,z2,z1
        // +  p1,p0
        // *        w = mu*z1
        // ========
        // z3,z2, 0
        let w = self.mu.wrapping_mul(zz[1] as u64);
        result = p[0] * (w as u128);
        hi = hi64(result);
        lo = lo64(result);
        result = zz[1] + lo;
        zz[1] = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        carry = lo64(result);

        result = p[1] * (w as u128);
        hi = hi64(result);
        lo = lo64(result);
        result = lo + carry;
        lo = lo64(result);
        cc = hi64(result);
        result = hi + cc;
        hi = lo64(result);
        result = zz[2] + lo;
        zz[2] = lo64(result);
        cc = hi64(result);
        result = zz[3] + hi + cc;
        zz[3] = lo64(result);
        cc = hi64(result);

        // z = (z3,z2)
        let prod = zz[2] | (zz[3] << 64);

        // Final subtraction
        // If z >= p, then z = z - p

        //      0, z
        // -    0, p
        // ========
        // b1,s1,s0
        let (s0, b0) = prod.overflowing_sub(self.p);
        let (_s1, b1) = (cc as u128).overflowing_sub(b0 as u128);
        // if b1 == 1: return z
        // else:       return s0
        let mask = 0u128.wrapping_sub(b1 as u128);
        (prod & mask) | (s0 & !mask)
    }

    /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is the modulus. Note that the
    /// runtime of this algorithm is linear in the bit length of `exp`.
    // Standard square-and-multiply, scanning `exp` from its most significant set bit down.
    pub fn pow(&self, x: u128, exp: u128) -> u128 {
        let mut t = self.montgomery(1);
        for i in (0..128 - exp.leading_zeros()).rev() {
            t = self.mul(t, t);
            if (exp >> i) & 1 != 0 {
                t = self.mul(t, x);
            }
        }
        t
    }

    /// Modular inversion, i.e., x^-1 (mod p) where `p` is the modulus. Note that the runtime of
    /// this algorithm is linear in the bit length of `p`.
    // Uses Fermat's little theorem: x^(p-2) = x^-1 (mod p) for prime p.
    pub fn inv(&self, x: u128) -> u128 {
        self.pow(x, self.p - 2)
    }

    /// Negation, i.e., `-x (mod p)` where `p` is the modulus.
    pub fn neg(&self, x: u128) -> u128 {
        self.sub(0, x)
    }

    /// Maps an integer to its internal representation. Field elements are mapped to the Montgomery
    /// domain in order to carry out field arithmetic. The result will be in [0, p).
    ///
    /// # Example usage
    /// ```text
    /// let integer = 1; // Standard integer representation
    /// let elem = fp.montgomery(integer); // Internal representation in the Montgomery domain
    /// assert_eq!(elem, 2564090464);
    /// ```
    // Multiplying by r2 = R^2 mod p inside REDC yields x*R mod p.
    pub fn montgomery(&self, x: u128) -> u128 {
        modp(self.mul(x, self.r2), self.p)
    }

    /// Returns a random field element mapped.
    #[cfg(test)]
    pub fn rand_elem<R: Rng + ?Sized>(&self, rng: &mut R) -> u128 {
        let uniform = rand::distributions::Uniform::from(0..self.p);
        self.montgomery(uniform.sample(rng))
    }

    /// Maps a field element to its representation as an integer. The result will be in [0, p).
    ///
    /// #Example usage
    /// ```text
    /// let elem = 2564090464; // Internal representation in the Montgomery domain
    /// let integer = fp.residue(elem); // Standard integer representation
    /// assert_eq!(integer, 1);
    /// ```
    // Multiplying by 1 inside REDC strips the Montgomery factor R.
    pub fn residue(&self, x: u128) -> u128 {
        modp(self.mul(x, 1), self.p)
    }

    // Test-only consistency check: recomputes every derived parameter (mu, r2, roots, bit
    // mask) from `p` and `g` with big-integer arithmetic and asserts they match.
    #[cfg(test)]
    pub fn check(&self, p: u128, g: u128, order: u128) {
        use modinverse::modinverse;
        use num_bigint::{BigInt, ToBigInt};
        use std::cmp::max;

        assert_eq!(self.p, p, "p mismatch");

        let mu = match modinverse((-(p as i128)).rem_euclid(1 << 64), 1 << 64) {
            Some(mu) => mu as u64,
            None => panic!("inverse of -p (mod 2^64) is undefined"),
        };
        assert_eq!(self.mu, mu, "mu mismatch");

        let big_p = &p.to_bigint().unwrap();
        let big_r: &BigInt = &(&(BigInt::from(1) << 128) % big_p);
        let big_r2: &BigInt = &(&(big_r * big_r) % big_p);
        let mut it = big_r2.iter_u64_digits();
        let mut r2 = 0;
        r2 |= it.next().unwrap() as u128;
        if let Some(x) = it.next() {
            r2 |= (x as u128) << 64;
        }
        assert_eq!(self.r2, r2, "r2 mismatch");

        assert_eq!(self.g, self.montgomery(g), "g mismatch");
        assert_eq!(
            self.residue(self.pow(self.g, order)),
            1,
            "g order incorrect"
        );

        let num_roots = log2(order) as usize;
        assert_eq!(order, 1 << num_roots, "order not a power of 2");
        assert_eq!(self.num_roots, num_roots, "num_roots mismatch");

        let mut roots = vec![0; max(num_roots, MAX_ROOTS) + 1];
        roots[num_roots] = self.montgomery(g);
        for i in (0..num_roots).rev() {
            roots[i] = self.mul(roots[i + 1], roots[i + 1]);
        }

        assert_eq!(&self.roots, &roots[..MAX_ROOTS + 1], "roots mismatch");
        assert_eq!(self.residue(self.roots[0]), 1, "first root is not one");

        let bit_mask = (BigInt::from(1) << big_p.bits()) - BigInt::from(1);
        assert_eq!(
            self.bit_mask.to_bigint().unwrap(),
            bit_mask,
            "bit_mask mismatch"
        );
    }
}
// Return the low 64 bits of `x` (zero-extended back to u128).
fn lo64(x: u128) -> u128 {
    x & u64::MAX as u128
}
// Return the high 64 bits of `x`.
fn hi64(x: u128) -> u128 {
    x >> 64
}
// Reduce `x` modulo `p`, assuming `x < 2p`: subtract `p` once and add it back if the
// subtraction underflowed. Kept branch-free with a mask (presumably for constant-time
// behavior — preserved as-is).
fn modp(x: u128, p: u128) -> u128 {
    let (diff, underflow) = x.overflowing_sub(p);
    let mask = 0u128.wrapping_sub(underflow as u128);
    diff.wrapping_add(mask & p)
}
// NOTE(review): `g` and `roots` appear to be stored in the Montgomery domain (see
// `FieldParameters::check`, which compares them against `montgomery(g)`) — confirm before
// editing any value. All entries are validated by the `test_fp` test.
pub(crate) const FP32: FieldParameters = FieldParameters {
    p: 4293918721, // 32-bit prime
    mu: 17302828673139736575,
    r2: 1676699750,
    g: 1074114499,
    num_roots: 20,
    bit_mask: 4294967295,
    roots: [
        2564090464, 1729828257, 306605458, 2294308040, 1648889905, 57098624, 2788941825,
        2779858277, 368200145, 2760217336, 594450960, 4255832533, 1372848488, 721329415,
        3873251478, 1134002069, 7138597, 2004587313, 2989350643, 725214187, 1074114499,
    ],
};
// Parameters for the 64-bit field; validated against recomputed values in `test_fp`.
pub(crate) const FP64: FieldParameters = FieldParameters {
    p: 18446744069414584321, // 64-bit prime
    mu: 18446744069414584319,
    r2: 4294967295,
    g: 959634606461954525,
    num_roots: 32,
    bit_mask: 18446744073709551615,
    roots: [
        18446744065119617025,
        4294967296,
        18446462594437939201,
        72057594037927936,
        1152921504338411520,
        16384,
        18446743519658770561,
        18446735273187346433,
        6519596376689022014,
        9996039020351967275,
        15452408553935940313,
        15855629130643256449,
        8619522106083987867,
        13036116919365988132,
        1033106119984023956,
        16593078884869787648,
        16980581328500004402,
        12245796497946355434,
        8709441440702798460,
        8611358103550827629,
        8120528636261052110,
    ],
};
// Parameters for the 96-bit field; validated against recomputed values in `test_fp`.
pub(crate) const FP96: FieldParameters = FieldParameters {
    p: 79228148845226978974766202881, // 96-bit prime
    mu: 18446744073709551615,
    r2: 69162923446439011319006025217,
    g: 11329412859948499305522312170,
    num_roots: 64,
    bit_mask: 79228162514264337593543950335,
    roots: [
        10128756682736510015896859,
        79218020088544242464750306022,
        9188608122889034248261485869,
        10170869429050723924726258983,
        36379376833245035199462139324,
        20898601228930800484072244511,
        2845758484723985721473442509,
        71302585629145191158180162028,
        76552499132904394167108068662,
        48651998692455360626769616967,
        36570983454832589044179852640,
        72716740645782532591407744342,
        73296872548531908678227377531,
        14831293153408122430659535205,
        61540280632476003580389854060,
        42256269782069635955059793151,
        51673352890110285959979141934,
        43102967204983216507957944322,
        3990455111079735553382399289,
        68042997008257313116433801954,
        44344622755749285146379045633,
    ],
};
// Parameters for the 128-bit field; validated against recomputed values in `test_fp`.
pub(crate) const FP128: FieldParameters = FieldParameters {
    p: 340282366920938462946865773367900766209, // 128-bit prime
    mu: 18446744073709551615,
    r2: 403909908237944342183153,
    g: 107630958476043550189608038630704257141,
    num_roots: 66,
    bit_mask: 340282366920938463463374607431768211455,
    roots: [
        516508834063867445247,
        340282366920938462430356939304033320962,
        129526470195413442198896969089616959958,
        169031622068548287099117778531474117974,
        81612939378432101163303892927894236156,
        122401220764524715189382260548353967708,
        199453575871863981432000940507837456190,
        272368408887745135168960576051472383806,
        24863773656265022616993900367764287617,
        257882853788779266319541142124730662203,
        323732363244658673145040701829006542956,
        57532865270871759635014308631881743007,
        149571414409418047452773959687184934208,
        177018931070866797456844925926211239962,
        268896136799800963964749917185333891349,
        244556960591856046954834420512544511831,
        118945432085812380213390062516065622346,
        202007153998709986841225284843501908420,
        332677126194796691532164818746739771387,
        258279638927684931537542082169183965856,
        148221243758794364405224645520862378432,
    ],
};
// Compute the ceiling of the base-2 logarithm of `x`. Undefined for `x == 0` (the
// subtraction below would underflow).
pub(crate) fn log2(x: u128) -> u128 {
    // `floor` is the position of the highest set bit, i.e. floor(log2(x)).
    let floor = (127 - x.leading_zeros()) as u128;
    // Round up unless `x` is exactly a power of two.
    if x > 1 << floor {
        floor + 1
    } else {
        floor
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::ToBigInt;

    #[test]
    fn test_log2() {
        assert_eq!(log2(1), 0);
        assert_eq!(log2(2), 1);
        assert_eq!(log2(3), 2);
        assert_eq!(log2(4), 2);
        assert_eq!(log2(15), 4);
        assert_eq!(log2(16), 4);
        assert_eq!(log2(30), 5);
        assert_eq!(log2(32), 5);
        assert_eq!(log2(1 << 127), 127);
        assert_eq!(log2((1 << 127) + 13), 128);
    }

    struct TestFieldParametersData {
        fp: FieldParameters,  // The parameters being tested
        expected_p: u128,     // Expected fp.p
        expected_g: u128,     // Expected fp.residue(fp.g)
        expected_order: u128, // Expect fp.residue(fp.pow(fp.g, expected_order)) == 1
    }

    #[test]
    fn test_fp() {
        let test_fps = vec![
            TestFieldParametersData {
                fp: FP32,
                expected_p: 4293918721,
                expected_g: 3925978153,
                expected_order: 1 << 20,
            },
            TestFieldParametersData {
                fp: FP64,
                expected_p: 18446744069414584321,
                expected_g: 1753635133440165772,
                expected_order: 1 << 32,
            },
            TestFieldParametersData {
                fp: FP96,
                expected_p: 79228148845226978974766202881,
                expected_g: 34233996298771126927060021012,
                expected_order: 1 << 64,
            },
            TestFieldParametersData {
                fp: FP128,
                expected_p: 340282366920938462946865773367900766209,
                expected_g: 145091266659756586618791329697897684742,
                expected_order: 1 << 66,
            },
        ];

        for t in test_fps.into_iter() {
            // Check that the field parameters have been constructed properly.
            t.fp.check(t.expected_p, t.expected_g, t.expected_order);

            // Check that the generator has the correct order.
            assert_eq!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order)), 1);
            assert_ne!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order / 2)), 1);

            // Test arithmetic using the field parameters.
            arithmetic_test(&t.fp);
        }
    }

    // Cross-check each field operation against big-integer arithmetic on random elements.
    fn arithmetic_test(fp: &FieldParameters) {
        let mut rng = rand::thread_rng();
        let big_p = &fp.p.to_bigint().unwrap();

        for _ in 0..100 {
            let x = fp.rand_elem(&mut rng);
            let y = fp.rand_elem(&mut rng);
            let big_x = &fp.residue(x).to_bigint().unwrap();
            let big_y = &fp.residue(y).to_bigint().unwrap();

            // Test addition.
            let got = fp.add(x, y);
            let want = (big_x + big_y) % big_p;
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);

            // Test subtraction.
            let got = fp.sub(x, y);
            let want = if big_x >= big_y {
                big_x - big_y
            } else {
                big_p - big_y + big_x
            };
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);

            // Test multiplication.
            let got = fp.mul(x, y);
            let want = (big_x * big_y) % big_p;
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);

            // Test inversion.
            let got = fp.inv(x);
            let want = big_x.modpow(&(big_p - 2u128), big_p);
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
            assert_eq!(fp.residue(fp.mul(got, x)), 1);

            // Test negation.
            let got = fp.neg(x);
            let want = (big_p - big_x) % big_p;
            assert_eq!(fp.residue(got).to_bigint().unwrap(), want);
            assert_eq!(fp.residue(fp.add(got, x)), 0);
        }
    }
}

33
third_party/rust/prio/src/lib.rs vendored Normal file
View File

@ -0,0 +1,33 @@
// Copyright (c) 2020 Apple Inc.
// SPDX-License-Identifier: MPL-2.0

#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]

//! Libprio-rs
//!
//! Implementation of the [Prio](https://crypto.stanford.edu/prio/) private data aggregation
//! protocol. For now we only support 0 / 1 vectors.

pub mod benchmarked;
// The `client`, `encrypt`, `server`, and `util` modules implement the legacy Prio v2
// protocol and are only compiled when the "prio2" feature is enabled.
#[cfg(feature = "prio2")]
pub mod client;
#[cfg(feature = "prio2")]
pub mod encrypt;
#[cfg(feature = "prio2")]
pub mod server;
pub mod codec;
mod fft;
pub mod field;
pub mod flp;
// Private implementation details: finite-field parameters (`fp`), polynomial
// arithmetic, and the PRNG used to expand seeds into field elements.
mod fp;
mod polynomial;
mod prng;
// Module test_vector depends on crate `rand` so we make it an optional feature
// to spare most clients the extra dependency.
#[cfg(all(any(feature = "test-util", test), feature = "prio2"))]
pub mod test_vector;
#[cfg(feature = "prio2")]
pub mod util;
pub mod vdaf;

384
third_party/rust/prio/src/polynomial.rs vendored Normal file
View File

@ -0,0 +1,384 @@
// Copyright (c) 2020 Apple Inc.
// SPDX-License-Identifier: MPL-2.0
//! Functions for polynomial interpolation and evaluation
use crate::field::FieldElement;
use std::convert::TryFrom;
/// Temporary memory used for FFT
#[derive(Clone, Debug)]
pub struct PolyFFTTempMemory<F> {
    // Scratch output buffer for the recursive FFT steps.
    fft_tmp: Vec<F>,
    // Scratch buffer holding folded point values for each half-size subproblem.
    fft_y_sub: Vec<F>,
    // Scratch buffer holding the roots of unity for each half-size subproblem.
    fft_roots_sub: Vec<F>,
}

impl<F: FieldElement> PolyFFTTempMemory<F> {
    /// Allocate zero-initialized scratch buffers sized for an FFT over `length` points.
    fn new(length: usize) -> Self {
        PolyFFTTempMemory {
            fft_tmp: vec![F::zero(); length],
            fft_y_sub: vec![F::zero(); length],
            fft_roots_sub: vec![F::zero(); length],
        }
    }
}
/// Auxiliary memory for polynomial interpolation and evaluation
#[derive(Clone, Debug)]
pub struct PolyAuxMemory<F> {
    // 2n-th roots of unity and their inverses.
    pub roots_2n: Vec<F>,
    pub roots_2n_inverted: Vec<F>,
    // n-th roots of unity and their inverses.
    pub roots_n: Vec<F>,
    pub roots_n_inverted: Vec<F>,
    // Scratch buffer for interpolated coefficients (2n elements).
    pub coeffs: Vec<F>,
    // Scratch buffers used by the FFT itself.
    pub fft_memory: PolyFFTTempMemory<F>,
}

impl<F: FieldElement> PolyAuxMemory<F> {
    /// Allocate roots of unity and scratch space for working with polynomials over
    /// `n` (and `2 * n`) points.
    pub fn new(n: usize) -> Self {
        PolyAuxMemory {
            roots_2n: fft_get_roots(2 * n, false),
            roots_2n_inverted: fft_get_roots(2 * n, true),
            roots_n: fft_get_roots(n, false),
            roots_n_inverted: fft_get_roots(n, true),
            coeffs: vec![F::zero(); 2 * n],
            fft_memory: PolyFFTTempMemory::new(2 * n),
        }
    }
}
/// One level of a recursive radix-2 FFT: compute the size-`n` transform of `ys` with
/// respect to the roots of unity `roots`, writing the result into `out`.
///
/// `tmp`, `y_sub`, and `roots_sub` are caller-provided scratch buffers of at least `n`
/// elements each; each is split in half so the two half-size subproblems can reuse them
/// without allocating. `n` is assumed to be a power of two.
fn fft_recurse<F: FieldElement>(
    out: &mut [F],
    n: usize,
    roots: &[F],
    ys: &[F],
    tmp: &mut [F],
    y_sub: &mut [F],
    roots_sub: &mut [F],
) {
    // Base case: the transform of a single point is the point itself.
    if n == 1 {
        out[0] = ys[0];
        return;
    }

    let half_n = n / 2;

    // Split each scratch buffer: the first half serves this level, the second half is
    // passed down to the recursive calls.
    let (tmp_first, tmp_second) = tmp.split_at_mut(half_n);
    let (y_sub_first, y_sub_second) = y_sub.split_at_mut(half_n);
    let (roots_sub_first, roots_sub_second) = roots_sub.split_at_mut(half_n);

    // Recurse on the first half: the pairwise sums, transformed against every other root,
    // produce the even-indexed outputs.
    for i in 0..half_n {
        y_sub_first[i] = ys[i] + ys[i + half_n];
        roots_sub_first[i] = roots[2 * i];
    }

    fft_recurse(
        tmp_first,
        half_n,
        roots_sub_first,
        y_sub_first,
        tmp_second,
        y_sub_second,
        roots_sub_second,
    );
    for i in 0..half_n {
        out[2 * i] = tmp_first[i];
    }

    // Recurse on the second half: the twiddled pairwise differences produce the
    // odd-indexed outputs. `roots_sub_first` still holds the roots set up above.
    for i in 0..half_n {
        y_sub_first[i] = ys[i] - ys[i + half_n];
        y_sub_first[i] *= roots[i];
    }

    fft_recurse(
        tmp_first,
        half_n,
        roots_sub_first,
        y_sub_first,
        tmp_second,
        y_sub_second,
        roots_sub_second,
    );
    for i in 0..half_n {
        out[2 * i + 1] = tmp[i];
    }
}
/// Calculate `count` number of roots of unity of order `count`, returning the inverted
/// roots when `invert` is set.
///
/// Assumes `count` divides `F::generator_order()` (in practice `count` is a power of two
/// no larger than the generator's order — TODO confirm callers never violate this);
/// otherwise the subgroup generator computed below does not have order `count`.
fn fft_get_roots<F: FieldElement>(count: usize, invert: bool) -> Vec<F> {
    let mut roots = vec![F::zero(); count];
    if count == 0 {
        // Nothing to compute. Previously this fell through and panicked indexing `roots[0]`.
        return roots;
    }

    let mut gen = F::generator();
    if invert {
        gen = gen.inv();
    }

    roots[0] = F::one();
    if count == 1 {
        // The only root of unity of order 1 is 1 itself. The general path below would
        // index `roots[1]` out of bounds.
        return roots;
    }

    let step_size = F::generator_order() / F::Integer::try_from(count).unwrap();
    // generator for subgroup of order count
    gen = gen.pow(step_size);

    // Fill in successive powers of the subgroup generator.
    roots[1] = gen;
    for i in 2..count {
        roots[i] = gen * roots[i - 1];
    }

    roots
}
/// Run the FFT in `fft_recurse` over the first `n_points` entries of `ys`, writing the
/// transformed points into `out`. When `invert` is true the result is additionally scaled
/// by `n_points^-1`, turning the forward transform into its inverse.
fn fft_interpolate_raw<F: FieldElement>(
    out: &mut [F],
    ys: &[F],
    n_points: usize,
    roots: &[F],
    invert: bool,
    mem: &mut PolyFFTTempMemory<F>,
) {
    fft_recurse(
        out,
        n_points,
        roots,
        ys,
        &mut mem.fft_tmp,
        &mut mem.fft_y_sub,
        &mut mem.fft_roots_sub,
    );
    if invert {
        // Multiply every output by the field inverse of the point count.
        let scale = F::from(F::Integer::try_from(n_points).unwrap()).inv();
        for value in out.iter_mut().take(n_points) {
            *value *= scale;
        }
    }
}
/// Compute the transform of the first `n_points` entries of `points_in` with respect to
/// `scaled_roots`, writing the result into `points_out`. When `invert` is set, the output
/// is additionally scaled by `n_points^-1` (see `fft_interpolate_raw`), giving the inverse
/// transform. `mem` provides reusable scratch buffers.
pub fn poly_fft<F: FieldElement>(
    points_out: &mut [F],
    points_in: &[F],
    scaled_roots: &[F],
    n_points: usize,
    invert: bool,
    mem: &mut PolyFFTTempMemory<F>,
) {
    fft_interpolate_raw(points_out, points_in, n_points, scaled_roots, invert, mem)
}
// Evaluate a polynomial at `eval_at` using Horner's method. `poly` holds coefficients in
// ascending-degree order; an empty slice evaluates to zero.
pub fn poly_eval<F: FieldElement>(poly: &[F], eval_at: F) -> F {
    // Folding from the highest coefficient down is exactly Horner's rule:
    // acc = (...((c_{n-1}) * x + c_{n-2}) * x + ...) * x + c_0.
    poly.iter()
        .rev()
        .fold(F::zero(), |acc, &coeff| acc * eval_at + coeff)
}
// Returns the degree of polynomial `p`: the index of the last nonzero coefficient, or 0
// when `p` is empty or identically zero.
pub fn poly_deg<F: FieldElement>(p: &[F]) -> usize {
    p.iter()
        .rposition(|&coeff| coeff != F::zero())
        .unwrap_or(0)
}
// Multiplies polynomials `p` and `q` and returns the result, trimmed of trailing zero
// coefficients.
pub fn poly_mul<F: FieldElement>(p: &[F], q: &[F]) -> Vec<F> {
    // Only coefficients up to each operand's degree contribute to the product.
    let p_len = poly_deg(p) + 1;
    let q_len = poly_deg(q) + 1;

    // Schoolbook multiplication: product[i + j] accumulates p[i] * q[j].
    let mut product = vec![F::zero(); p_len + q_len];
    for i in 0..p_len {
        let p_coeff = p[i];
        for j in 0..q_len {
            product[i + j] += p_coeff * q[j];
        }
    }

    product.truncate(poly_deg(&product) + 1);
    product
}
#[cfg(feature = "prio2")]
/// Interpolate the polynomial passing through `points` (via an inverse FFT over `roots`)
/// and evaluate it at `eval_at`. `tmp_coeffs` and `fft_memory` are reused scratch space;
/// only the first `points.len()` entries of `tmp_coeffs` are consumed.
pub fn poly_interpret_eval<F: FieldElement>(
    points: &[F],
    roots: &[F],
    eval_at: F,
    tmp_coeffs: &mut [F],
    fft_memory: &mut PolyFFTTempMemory<F>,
) -> F {
    poly_fft(tmp_coeffs, points, roots, points.len(), true, fft_memory);
    poly_eval(&tmp_coeffs[..points.len()], eval_at)
}
// Returns a polynomial that evaluates to `0` if the input is in range `[start, end)`.
// Otherwise, the output is not `0`.
pub(crate) fn poly_range_check<F: FieldElement>(start: usize, end: usize) -> Vec<F> {
    // Build the product of the linear factors (x - i) for every i in the range.
    (start..end).fold(vec![F::one()], |acc, value| {
        let linear_factor = [
            -F::from(F::Integer::try_from(value).unwrap()),
            F::one(),
        ];
        poly_mul(&acc, &linear_factor)
    })
}
#[test]
fn test_roots() {
    use crate::field::Field32;

    let count = 128;
    let roots = fft_get_roots::<Field32>(count, false);
    let roots_inv = fft_get_roots::<Field32>(count, true);

    for i in 0..count {
        // Each root must cancel its inverse, and both must have order dividing `count`.
        assert_eq!(roots[i] * roots_inv[i], 1);
        assert_eq!(roots[i].pow(u32::try_from(count).unwrap()), 1);
        assert_eq!(roots_inv[i].pow(u32::try_from(count).unwrap()), 1);
    }
}

#[test]
fn test_eval() {
    use crate::field::Field32;

    let mut poly = vec![Field32::from(0); 4];

    // p(x) = 5x^2 + x + 2, evaluated at x = 3.
    poly[0] = 2.into();
    poly[1] = 1.into();
    poly[2] = 5.into();
    // 5*3^2 + 3 + 2 = 50
    assert_eq!(poly_eval(&poly[..3], 3.into()), 50);

    poly[3] = 4.into();
    // 4*3^3 + 5*3^2 + 3 + 2 = 158
    assert_eq!(poly_eval(&poly[..4], 3.into()), 158);
}

#[test]
fn test_poly_deg() {
    use crate::field::Field32;

    let zero = Field32::zero();
    let one = Field32::root(0).unwrap();

    // Trailing zero coefficients must not count toward the degree.
    assert_eq!(poly_deg(&[zero]), 0);
    assert_eq!(poly_deg(&[one]), 0);
    assert_eq!(poly_deg(&[zero, one]), 1);
    assert_eq!(poly_deg(&[zero, zero, one]), 2);
    assert_eq!(poly_deg(&[zero, one, one]), 2);
    assert_eq!(poly_deg(&[zero, one, one, one]), 3);
    assert_eq!(poly_deg(&[zero, one, one, one, zero]), 3);
    assert_eq!(poly_deg(&[zero, one, one, one, zero, zero]), 3);
}

#[test]
fn test_poly_mul() {
    use crate::field::Field64;

    // (2 + 3x) * (1 + 5x^2) = 2 + 3x + 10x^2 + 15x^3
    let p = [
        Field64::from(u64::try_from(2).unwrap()),
        Field64::from(u64::try_from(3).unwrap()),
    ];

    let q = [
        Field64::one(),
        Field64::zero(),
        Field64::from(u64::try_from(5).unwrap()),
    ];

    let want = [
        Field64::from(u64::try_from(2).unwrap()),
        Field64::from(u64::try_from(3).unwrap()),
        Field64::from(u64::try_from(10).unwrap()),
        Field64::from(u64::try_from(15).unwrap()),
    ];

    let got = poly_mul(&p, &q);
    assert_eq!(&got, &want);
}

#[test]
fn test_poly_range_check() {
    use crate::field::Field64;

    let start = 74;
    let end = 112;
    let p = poly_range_check(start, end);

    // Check each number in the range.
    for i in start..end {
        let x = Field64::from(i as u64);
        let y = poly_eval(&p, x);
        assert_eq!(y, Field64::zero(), "range check failed for {}", i);
    }

    // Check the number below the range.
    let x = Field64::from((start - 1) as u64);
    let y = poly_eval(&p, x);
    assert_ne!(y, Field64::zero());

    // Check a number above the range.
    let x = Field64::from(end as u64);
    let y = poly_eval(&p, x);
    assert_ne!(y, Field64::zero());
}

#[test]
fn test_fft() {
    use crate::field::Field32;
    use rand::prelude::*;
    use std::convert::TryFrom;

    let count = 128;
    let mut mem = PolyAuxMemory::new(count / 2);

    let mut poly = vec![Field32::from(0); count];
    let mut points2 = vec![Field32::from(0); count];

    let points = (0..count)
        .into_iter()
        .map(|_| Field32::from(random::<u32>()))
        .collect::<Vec<Field32>>();

    // From points to coeffs and back: the inverse transform must restore the input.
    poly_fft(
        &mut poly,
        &points,
        &mem.roots_2n,
        count,
        false,
        &mut mem.fft_memory,
    );
    poly_fft(
        &mut points2,
        &poly,
        &mem.roots_2n_inverted,
        count,
        true,
        &mut mem.fft_memory,
    );

    assert_eq!(points, points2);

    // interpolation: compare the FFT output against a naive O(n^2) evaluation.
    poly_fft(
        &mut poly,
        &points,
        &mem.roots_2n,
        count,
        false,
        &mut mem.fft_memory,
    );

    #[allow(clippy::needless_range_loop)]
    for i in 0..count {
        let mut should_be = Field32::from(0);
        for j in 0..count {
            should_be = mem.roots_2n[i].pow(u32::try_from(j).unwrap()) * points[j] + should_be;
        }
        assert_eq!(should_be, poly[i]);
    }
}

208
third_party/rust/prio/src/prng.rs vendored Normal file
View File

@ -0,0 +1,208 @@
// Copyright (c) 2020 Apple Inc.
// SPDX-License-Identifier: MPL-2.0
//! Tool for generating pseudorandom field elements.
//!
//! NOTE: The public API for this module is a work in progress.
use crate::field::{FieldElement, FieldError};
use crate::vdaf::prg::SeedStream;
#[cfg(feature = "crypto-dependencies")]
use crate::vdaf::prg::SeedStreamAes128;
#[cfg(feature = "crypto-dependencies")]
use getrandom::getrandom;
use std::marker::PhantomData;
// Number of field elements' worth of PRG output buffered between calls into the
// underlying seed stream.
const BUFFER_SIZE_IN_ELEMENTS: usize = 128;

/// Errors propagated by methods in this module.
#[derive(Debug, thiserror::Error)]
pub enum PrngError {
    /// Failure when calling getrandom().
    #[error("getrandom: {0}")]
    GetRandom(#[from] getrandom::Error),
}
/// This type implements an iterator that generates a pseudorandom sequence of field elements. The
/// sequence is derived from the key stream of AES-128 in CTR mode with a random IV.
#[derive(Debug)]
pub(crate) struct Prng<F, S> {
    phantom: PhantomData<F>,
    // Source of raw pseudorandom bytes.
    seed_stream: S,
    // Buffered PRG output, consumed `F::ENCODED_SIZE` bytes at a time.
    buffer: Vec<u8>,
    // Byte offset of the next unconsumed chunk in `buffer`.
    buffer_index: usize,
    // Count of field elements produced so far.
    output_written: usize,
}
#[cfg(feature = "crypto-dependencies")]
impl<F: FieldElement> Prng<F, SeedStreamAes128> {
    /// Create a [`Prng`] from a seed for Prio 2. The first 16 bytes of the seed and the last 16
    /// bytes of the seed are used, respectively, for the key and initialization vector for AES128
    /// in CTR mode.
    pub(crate) fn from_prio2_seed(seed: &[u8; 32]) -> Self {
        let seed_stream = SeedStreamAes128::new(&seed[..16], &seed[16..]);
        Self::from_seed_stream(seed_stream)
    }

    /// Create a [`Prng`] from a randomly generated seed.
    pub(crate) fn new() -> Result<Self, PrngError> {
        // Draw the 32-byte seed from the operating system's entropy source.
        let mut seed = [0; 32];
        getrandom(&mut seed)?;
        Ok(Self::from_prio2_seed(&seed))
    }
}
impl<F, S> Prng<F, S>
where
    F: FieldElement,
    S: SeedStream,
{
    /// Wrap an existing seed stream, pre-filling the internal buffer with the first chunk
    /// of its output.
    pub(crate) fn from_seed_stream(mut seed_stream: S) -> Self {
        let mut buffer = vec![0; BUFFER_SIZE_IN_ELEMENTS * F::ENCODED_SIZE];
        seed_stream.fill(&mut buffer);

        Self {
            phantom: PhantomData::<F>,
            seed_stream,
            buffer,
            buffer_index: 0,
            output_written: 0,
        }
    }

    /// Return the next pseudorandom field element, using rejection sampling: chunks whose
    /// value would exceed the field modulus are discarded rather than reduced, so the
    /// output is uniform over the field.
    pub(crate) fn get(&mut self) -> F {
        loop {
            // Seek to the next chunk of the buffer that encodes an element of F.
            for i in (self.buffer_index..self.buffer.len()).step_by(F::ENCODED_SIZE) {
                let j = i + F::ENCODED_SIZE;
                if let Some(x) = match F::try_from_random(&self.buffer[i..j]) {
                    Ok(x) => Some(x),
                    Err(FieldError::ModulusOverflow) => None, // reject this sample
                    Err(err) => panic!("unexpected error: {}", err),
                } {
                    // Set the buffer index to the next chunk.
                    self.buffer_index = j;
                    self.output_written += 1;
                    return x;
                }
            }

            // Refresh buffer with the next chunk of PRG output.
            self.seed_stream.fill(&mut self.buffer);
            self.buffer_index = 0;
        }
    }
}
impl<F, S> Iterator for Prng<F, S>
where
    F: FieldElement,
    S: SeedStream,
{
    type Item = F;

    /// The iterator never terminates; each call draws the next pseudorandom element.
    fn next(&mut self) -> Option<F> {
        Some(self.get())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::Decode,
        field::{Field96, FieldPrio2},
        vdaf::prg::{Prg, PrgAes128, Seed},
    };
    use std::convert::TryInto;

    #[test]
    fn secret_sharing_interop() {
        // Fixed seed and the exact field elements it must expand to; guards against
        // accidental changes to the Prio v2 share-expansion algorithm.
        let seed = [
            0xcd, 0x85, 0x5b, 0xd4, 0x86, 0x48, 0xa4, 0xce, 0x52, 0x5c, 0x36, 0xee, 0x5a, 0x71,
            0xf3, 0x0f, 0x66, 0x80, 0xd3, 0x67, 0x53, 0x9a, 0x39, 0x6f, 0x12, 0x2f, 0xad, 0x94,
            0x4d, 0x34, 0xcb, 0x58,
        ];

        let reference = [
            0xd0056ec5, 0xe23f9c52, 0x47e4ddb4, 0xbe5dacf6, 0x4b130aba, 0x530c7a90, 0xe8fc4ee5,
            0xb0569cb7, 0x7774cd3c, 0x7f24e6a5, 0xcc82355d, 0xc41f4f13, 0x67fe193c, 0xc94d63a4,
            0x5d7b474c, 0xcc5c9f5f, 0xe368e1d5, 0x020fa0cf, 0x9e96aa2a, 0xe924137d, 0xfa026ab9,
            0x8ebca0cc, 0x26fc58a5, 0x10a7b173, 0xb9c97291, 0x53ef0e28, 0x069cfb8e, 0xe9383cae,
            0xacb8b748, 0x6f5b9d49, 0x887d061b, 0x86db0c58,
        ];

        let share2 = extract_share_from_seed::<FieldPrio2>(reference.len(), &seed);

        assert_eq!(share2, reference);
    }

    /// takes a seed and hash as base64 encoded strings; expands the seed into `len` field
    /// elements and checks their SHA-256 digest against the expected hash
    #[cfg(feature = "prio2")]
    fn random_data_interop(seed_base64: &str, hash_base64: &str, len: usize) {
        let seed = base64::decode(seed_base64).unwrap();
        let random_data = extract_share_from_seed::<FieldPrio2>(len, &seed);

        let random_bytes = FieldPrio2::slice_into_byte_vec(&random_data);

        let digest = ring::digest::digest(&ring::digest::SHA256, &random_bytes);
        assert_eq!(base64::encode(digest), hash_base64);
    }

    #[test]
    #[cfg(feature = "prio2")]
    fn test_hash_interop() {
        random_data_interop(
            "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=",
            "RtzeQuuiWdD6bW2ZTobRELDmClz1wLy3HUiKsYsITOI=",
            100_000,
        );

        // zero seed
        random_data_interop(
            "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
            "3wHQbSwAn9GPfoNkKe1qSzWdKnu/R+hPPyRwwz6Di+w=",
            100_000,
        );
        // 0, 1, 2 ... seed
        random_data_interop(
            "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=",
            "RtzeQuuiWdD6bW2ZTobRELDmClz1wLy3HUiKsYsITOI=",
            100_000,
        );
        // one arbitrary fixed seed
        random_data_interop(
            "rkLrnVcU8ULaiuXTvR3OKrfpMX0kQidqVzta1pleKKg=",
            "b1fMXYrGUNR3wOZ/7vmUMmY51QHoPDBzwok0fz6xC0I=",
            100_000,
        );
        // all bits set seed
        random_data_interop(
            "//////////////////////////////////////////8=",
            "iBiDaqLrv7/rX/+vs6akPiprGgYfULdh/XhoD61HQXA=",
            100_000,
        );
    }

    // Expand a 32-byte seed into `length` field elements using the Prio 2 PRNG.
    fn extract_share_from_seed<F: FieldElement>(length: usize, seed: &[u8]) -> Vec<F> {
        assert_eq!(seed.len(), 32);

        Prng::from_prio2_seed(seed.try_into().unwrap())
            .take(length)
            .collect()
    }

    #[test]
    fn rejection_sampling_test_vector() {
        // These constants were found in a brute-force search, and they test that the PRG performs
        // rejection sampling correctly when raw AES-CTR output exceeds the prime modulus.
        let seed_stream = PrgAes128::seed_stream(
            &Seed::get_decoded(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 95]).unwrap(),
            b"",
        );
        let mut prng = Prng::<Field96, _>::from_seed_stream(seed_stream);
        let expected = Field96::from(39729620190871453347343769187);
        let actual = prng.nth(145).unwrap();
        assert_eq!(actual, expected);
    }
}

469
third_party/rust/prio/src/server.rs vendored Normal file
View File

@ -0,0 +1,469 @@
// Copyright (c) 2020 Apple Inc.
// SPDX-License-Identifier: MPL-2.0
//! The Prio v2 server. Only 0 / 1 vectors are supported for now.
use crate::{
encrypt::{decrypt_share, EncryptError, PrivateKey},
field::{merge_vector, FieldElement, FieldError},
polynomial::{poly_interpret_eval, PolyAuxMemory},
prng::{Prng, PrngError},
util::{proof_length, unpack_proof, SerializeError},
vdaf::prg::SeedStreamAes128,
};
use serde::{Deserialize, Serialize};
use std::convert::TryInto;
/// Possible errors from server operations
#[derive(Debug, thiserror::Error)]
pub enum ServerError {
    /// Unexpected Share Length (the non-"first" server expects its decrypted share to be a
    /// 32-byte PRNG seed; see `Server::deserialize_share`)
    #[error("unexpected share length")]
    ShareLength,
    /// Encryption/decryption error
    #[error("encryption/decryption error")]
    Encrypt(#[from] EncryptError),
    /// Finite field operation error
    #[error("finite field operation error")]
    Field(#[from] FieldError),
    /// Serialization/deserialization error
    #[error("serialization/deserialization error")]
    Serialize(#[from] SerializeError),
    /// Failure when calling getrandom().
    #[error("getrandom: {0}")]
    GetRandom(#[from] getrandom::Error),
    /// PRNG error.
    #[error("prng error: {0}")]
    Prng(#[from] PrngError),
}
/// Auxiliary memory for constructing a
/// [`VerificationMessage`](struct.VerificationMessage.html)
#[derive(Debug)]
pub struct ValidationMemory<F> {
    // Point evaluations of the data polynomial f (n points).
    points_f: Vec<F>,
    // Point evaluations of g; the first server subtracts one from each data point here.
    points_g: Vec<F>,
    // Point evaluations of the product polynomial h (2n points).
    points_h: Vec<F>,
    // Roots of unity and FFT scratch space shared across validations.
    poly_mem: PolyAuxMemory<F>,
}

impl<F: FieldElement> ValidationMemory<F> {
    /// Construct a new ValidationMemory object for validating proof shares of
    /// length `dimension`.
    pub fn new(dimension: usize) -> Self {
        // n is the smallest power of two that can hold the data plus the zero term.
        let n: usize = (dimension + 1).next_power_of_two();
        ValidationMemory {
            points_f: vec![F::zero(); n],
            points_g: vec![F::zero(); n],
            points_h: vec![F::zero(); 2 * n],
            poly_mem: PolyAuxMemory::new(n),
        }
    }
}
/// Main workhorse of the server.
#[derive(Debug)]
pub struct Server<F> {
    // Used by `choose_eval_at` to draw random evaluation points.
    prng: Prng<F, SeedStreamAes128>,
    // Number of elements in the aggregation vector.
    dimension: usize,
    // Exactly one of the two servers runs with this set (see `new`).
    is_first_server: bool,
    // Running sum of all data shares that passed validation.
    accumulator: Vec<F>,
    // Reusable FFT buffers for proof validation.
    validation_mem: ValidationMemory<F>,
    // Key for decrypting incoming shares.
    private_key: PrivateKey,
}
impl<F: FieldElement> Server<F> {
    /// Construct a new server instance
    ///
    /// Params:
    ///  * `dimension`: the number of elements in the aggregation vector.
    ///  * `is_first_server`: only one of the servers should have this true.
    ///  * `private_key`: the private key for decrypting the share of the proof.
    pub fn new(
        dimension: usize,
        is_first_server: bool,
        private_key: PrivateKey,
    ) -> Result<Server<F>, ServerError> {
        Ok(Server {
            // Seeded from the OS RNG; only used by `choose_eval_at`.
            prng: Prng::new()?,
            dimension,
            is_first_server,
            accumulator: vec![F::zero(); dimension],
            validation_mem: ValidationMemory::new(dimension),
            private_key,
        })
    }

    /// Decrypt and deserialize an encrypted share into field elements.
    ///
    /// The first server's share is explicit field elements; the second server's share is a
    /// 32-byte PRNG seed that is expanded into `proof_length(dimension)` elements.
    fn deserialize_share(&self, encrypted_share: &[u8]) -> Result<Vec<F>, ServerError> {
        let len = proof_length(self.dimension);
        let share = decrypt_share(encrypted_share, &self.private_key)?;
        Ok(if self.is_first_server {
            // NOTE(review): the decrypted length is not compared against `len` here;
            // byte_slice_into_vec is presumably expected to reject malformed input — confirm.
            F::byte_slice_into_vec(&share)?
        } else {
            if share.len() != 32 {
                return Err(ServerError::ShareLength);
            }

            Prng::from_prio2_seed(&share.try_into().unwrap())
                .take(len)
                .collect()
        })
    }

    /// Generate verification message from an encrypted share
    ///
    /// This decrypts the share of the proof and constructs the
    /// [`VerificationMessage`](struct.VerificationMessage.html).
    /// The `eval_at` field should be generated by
    /// [choose_eval_at](#method.choose_eval_at).
    pub fn generate_verification_message(
        &mut self,
        eval_at: F,
        share: &[u8],
    ) -> Result<VerificationMessage<F>, ServerError> {
        let share_field = self.deserialize_share(share)?;
        generate_verification_message(
            self.dimension,
            eval_at,
            &share_field,
            self.is_first_server,
            &mut self.validation_mem,
        )
    }

    /// Add the content of the encrypted share into the accumulator
    ///
    /// This only changes the accumulator if the verification messages `v1` and
    /// `v2` indicate that the share passed validation.
    pub fn aggregate(
        &mut self,
        share: &[u8],
        v1: &VerificationMessage<F>,
        v2: &VerificationMessage<F>,
    ) -> Result<bool, ServerError> {
        let share_field = self.deserialize_share(share)?;
        let is_valid = is_valid_share(v1, v2);
        if is_valid {
            // Add to the accumulator. share_field also includes the proof
            // encoding, so we slice off the first dimension fields, which are
            // the actual data share.
            merge_vector(&mut self.accumulator, &share_field[..self.dimension])?;
        }

        Ok(is_valid)
    }

    /// Return the current accumulated shares.
    ///
    /// These can be merged together using
    /// [`reconstruct_shares`](../util/fn.reconstruct_shares.html).
    pub fn total_shares(&self) -> &[F] {
        &self.accumulator
    }

    /// Merge shares from another server.
    ///
    /// This modifies the current accumulator.
    ///
    /// # Errors
    ///
    /// Returns an error if `other_total_shares.len()` is not equal to this
    /// server's `dimension`.
    pub fn merge_total_shares(&mut self, other_total_shares: &[F]) -> Result<(), ServerError> {
        Ok(merge_vector(&mut self.accumulator, other_total_shares)?)
    }

    /// Choose a random point for polynomial evaluation
    ///
    /// The point returned is not one of the roots used for polynomial
    /// evaluation.
    pub fn choose_eval_at(&mut self) -> F {
        // Rejection-sample until the candidate is not a 2n-th root of unity.
        loop {
            let eval_at = self.prng.get();
            if !self.validation_mem.poly_mem.roots_2n.contains(&eval_at) {
                break eval_at;
            }
        }
    }
}
/// Verification message for proof validation
///
/// Each server produces one message from its proof share; the two messages are additive
/// shares of the polynomial evaluations and are recombined by [`is_valid_share`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VerificationMessage<F> {
    /// f evaluated at random point
    pub f_r: F,
    /// g evaluated at random point
    pub g_r: F,
    /// h evaluated at random point
    pub h_r: F,
}
/// Given a proof and evaluation point, this constructs the verification
/// message.
///
/// The proof share is unpacked into point-value form for three polynomials f, g, and h
/// (using the buffers in `mem`), each is interpolated via inverse FFT, and all three are
/// evaluated at the random point `eval_at`.
pub fn generate_verification_message<F: FieldElement>(
    dimension: usize,
    eval_at: F,
    proof: &[F],
    is_first_server: bool,
    mem: &mut ValidationMemory<F>,
) -> Result<VerificationMessage<F>, ServerError> {
    let unpacked = unpack_proof(proof, dimension)?;
    // h is represented on 2n points, where n = next_power_of_two(dimension + 1).
    let proof_length = 2 * (dimension + 1).next_power_of_two();

    // set zero terms
    mem.points_f[0] = *unpacked.f0;
    mem.points_g[0] = *unpacked.g0;
    mem.points_h[0] = *unpacked.h0;

    // set points_f and points_g
    for (i, x) in unpacked.data.iter().enumerate() {
        mem.points_f[i + 1] = *x;

        if is_first_server {
            // only one server needs to subtract one for point_g
            mem.points_g[i + 1] = *x - F::one();
        } else {
            mem.points_g[i + 1] = *x;
        }
    }

    // set points_h, skipping over elements that should be zero: only the odd-indexed
    // points of h are packed in the proof; the even indices (other than 0) stay at the
    // zero they were initialized to in `ValidationMemory::new`.
    let mut i = 1;
    let mut j = 0;
    while i < proof_length {
        mem.points_h[i] = unpacked.points_h_packed[j];
        j += 1;
        i += 2;
    }

    // evaluate polynomials at random point; `coeffs` and `fft_memory` are reused scratch
    let f_r = poly_interpret_eval(
        &mem.points_f,
        &mem.poly_mem.roots_n_inverted,
        eval_at,
        &mut mem.poly_mem.coeffs,
        &mut mem.poly_mem.fft_memory,
    );
    let g_r = poly_interpret_eval(
        &mem.points_g,
        &mem.poly_mem.roots_n_inverted,
        eval_at,
        &mut mem.poly_mem.coeffs,
        &mut mem.poly_mem.fft_memory,
    );
    let h_r = poly_interpret_eval(
        &mem.points_h,
        &mem.poly_mem.roots_2n_inverted,
        eval_at,
        &mut mem.poly_mem.coeffs,
        &mut mem.poly_mem.fft_memory,
    );

    Ok(VerificationMessage { f_r, g_r, h_r })
}
/// Decides if the distributed proof is valid
///
/// Recombines the two servers' additive shares of each evaluation and checks the
/// polynomial identity f(r) * g(r) == h(r).
pub fn is_valid_share<F: FieldElement>(
    v1: &VerificationMessage<F>,
    v2: &VerificationMessage<F>,
) -> bool {
    (v1.f_r + v2.f_r) * (v1.g_r + v2.g_r) == v1.h_r + v2.h_r
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        encrypt::{encrypt_share, PublicKey},
        field::{Field32, FieldPrio2},
        test_vector::Priov2TestVector,
        util::{self, unpack_proof_mut},
    };
    use serde_json;

    #[test]
    fn test_validation() {
        let dim = 8;
        // A known-valid serialized proof for an 8-dimension input.
        let proof_u32: Vec<u32> = vec![
            1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722,
            3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680,
            2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149,
        ];

        let mut proof: Vec<Field32> = proof_u32.iter().map(|x| Field32::from(*x)).collect();
        // Split into two additive shares; `proof` becomes the first server's share.
        let share2 = util::tests::secret_share(&mut proof);
        let eval_at = Field32::from(12313);

        let mut validation_mem = ValidationMemory::new(dim);

        let v1 =
            generate_verification_message(dim, eval_at, &proof, true, &mut validation_mem).unwrap();
        let v2 = generate_verification_message(dim, eval_at, &share2, false, &mut validation_mem)
            .unwrap();
        assert!(is_valid_share(&v1, &v2));
    }

    #[test]
    fn test_verification_message_serde() {
        let dim = 8;
        // Same known-valid proof as in `test_validation`.
        let proof_u32: Vec<u32> = vec![
            1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722,
            3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680,
            2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149,
        ];

        let mut proof: Vec<Field32> = proof_u32.iter().map(|x| Field32::from(*x)).collect();
        let share2 = util::tests::secret_share(&mut proof);
        let eval_at = Field32::from(12313);

        let mut validation_mem = ValidationMemory::new(dim);

        let v1 =
            generate_verification_message(dim, eval_at, &proof, true, &mut validation_mem).unwrap();
        let v2 = generate_verification_message(dim, eval_at, &share2, false, &mut validation_mem)
            .unwrap();

        // serialize and deserialize the first verification message; validation must still
        // succeed with the round-tripped value
        let serialized = serde_json::to_string(&v1).unwrap();
        let deserialized: VerificationMessage<Field32> = serde_json::from_str(&serialized).unwrap();

        assert!(is_valid_share(&deserialized, &v2));
    }

    // Ways of corrupting the input, its proof share, or the verification messages; every
    // variant except `None` must cause validation to fail.
    #[derive(Debug, Clone, Copy, PartialEq)]
    enum Tweak {
        None,
        WrongInput,
        DataPartOfShare,
        ZeroTermF,
        ZeroTermG,
        ZeroTermH,
        PointsH,
        VerificationF,
        VerificationG,
        VerificationH,
    }

    fn tweaks(tweak: Tweak) {
        let dim = 123;

        // We generate a test vector just to get a `Client` and `Server`s with
        // encryption keys but construct and tweak inputs below.
        let test_vector = Priov2TestVector::new(dim, 0).unwrap();

        let mut server1 = test_vector.server_1().unwrap();
        let mut server2 = test_vector.server_2().unwrap();

        let mut client = test_vector.client().unwrap();

        // all zero data
        let mut data = vec![FieldPrio2::zero(); dim];

        if let Tweak::WrongInput = tweak {
            // 2 is outside the allowed {0, 1} range for the data.
            data[0] = FieldPrio2::from(2);
        }

        let (share1_original, share2) = client.encode_simple(&data).unwrap();

        // Decrypt share 1 so individual proof components can be corrupted in place.
        let decrypted_share1 = decrypt_share(&share1_original, &server1.private_key).unwrap();
        let mut share1_field = FieldPrio2::byte_slice_into_vec(&decrypted_share1).unwrap();
        let unpacked_share1 = unpack_proof_mut(&mut share1_field, dim).unwrap();

        let one = FieldPrio2::from(1);
        match tweak {
            Tweak::DataPartOfShare => unpacked_share1.data[0] += one,
            Tweak::ZeroTermF => *unpacked_share1.f0 += one,
            Tweak::ZeroTermG => *unpacked_share1.g0 += one,
            Tweak::ZeroTermH => *unpacked_share1.h0 += one,
            Tweak::PointsH => unpacked_share1.points_h_packed[0] += one,
            _ => (),
        };

        // reserialize altered share1
        let share1_modified = encrypt_share(
            &FieldPrio2::slice_into_byte_vec(&share1_field),
            &PublicKey::from(&server1.private_key),
        )
        .unwrap();

        let eval_at = server1.choose_eval_at();

        let mut v1 = server1
            .generate_verification_message(eval_at, &share1_modified)
            .unwrap();
        let v2 = server2
            .generate_verification_message(eval_at, &share2)
            .unwrap();

        // Verification-message tweaks are applied after generation.
        match tweak {
            Tweak::VerificationF => v1.f_r += one,
            Tweak::VerificationG => v1.g_r += one,
            Tweak::VerificationH => v1.h_r += one,
            _ => (),
        }

        // Only the untweaked case may aggregate successfully, on both servers.
        let should_be_valid = matches!(tweak, Tweak::None);
        assert_eq!(
            server1.aggregate(&share1_modified, &v1, &v2).unwrap(),
            should_be_valid
        );
        assert_eq!(
            server2.aggregate(&share2, &v1, &v2).unwrap(),
            should_be_valid
        );
    }

    #[test]
    fn tweak_none() {
        tweaks(Tweak::None);
    }

    #[test]
    fn tweak_input() {
        tweaks(Tweak::WrongInput);
    }

    #[test]
    fn tweak_data() {
        tweaks(Tweak::DataPartOfShare);
    }

    #[test]
    fn tweak_f_zero() {
        tweaks(Tweak::ZeroTermF);
    }

    #[test]
    fn tweak_g_zero() {
        tweaks(Tweak::ZeroTermG);
    }

    #[test]
    fn tweak_h_zero() {
        tweaks(Tweak::ZeroTermH);
    }

    #[test]
    fn tweak_h_points() {
        tweaks(Tweak::PointsH);
    }

    #[test]
    fn tweak_f_verif() {
        tweaks(Tweak::VerificationF);
    }

    #[test]
    fn tweak_g_verif() {
        tweaks(Tweak::VerificationG);
    }

    #[test]
    fn tweak_h_verif() {
        tweaks(Tweak::VerificationH);
    }
}

244
third_party/rust/prio/src/test_vector.rs vendored Normal file
View File

@ -0,0 +1,244 @@
// SPDX-License-Identifier: MPL-2.0
//! Module `test_vector` generates test vectors of serialized Prio inputs and
//! support for working with test vectors, enabling backward compatibility
//! testing.
use crate::{
client::{Client, ClientError},
encrypt::{PrivateKey, PublicKey},
field::{FieldElement, FieldPrio2},
server::{Server, ServerError},
};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
/// Errors propagated by functions in this module.
#[derive(Debug, thiserror::Error)]
pub enum TestVectorError {
    /// Error from Prio client
    #[error("Prio client error {0}")]
    Client(#[from] ClientError),
    /// Error from Prio server
    #[error("Prio server error {0}")]
    Server(#[from] ServerError),
    /// Error while converting primitive to FieldElement associated integer type
    #[error("Integer conversion error {0}")]
    IntegerConversion(String),
}
// Fixed, test-only key material. These keys are deliberately constant so that previously
// generated test vectors remain decryptable.
const SERVER_1_PRIVATE_KEY: &str =
    "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBH\
     fNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==";
const SERVER_2_PRIVATE_KEY: &str =
    "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rD\
     ULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w==";

/// An ECDSA P-256 private key suitable for decrypting inputs, used to generate
/// test vectors and later to decrypt them.
fn server_1_private_key() -> PrivateKey {
    PrivateKey::from_base64(SERVER_1_PRIVATE_KEY).unwrap()
}

/// The public portion of [`server_1_private_key`].
fn server_1_public_key() -> PublicKey {
    PublicKey::from(&server_1_private_key())
}

/// An ECDSA P-256 private key suitable for decrypting inputs, used to generate
/// test vectors and later to decrypt them.
fn server_2_private_key() -> PrivateKey {
    PrivateKey::from_base64(SERVER_2_PRIVATE_KEY).unwrap()
}

/// The public portion of [`server_2_private_key`].
fn server_2_public_key() -> PublicKey {
    PublicKey::from(&server_2_private_key())
}
/// A test vector of Prio inputs, serialized and encrypted in the Priov2 format,
/// along with a reference sum. The field is always [`FieldPrio2`].
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct Priov2TestVector {
    /// Base64 encoded private key for the "first" a.k.a. "PHA" server, which
    /// may be used to decrypt `server_1_shares`.
    pub server_1_private_key: String,
    /// Base64 encoded private key for the non-"first" a.k.a. "facilitator"
    /// server, which may be used to decrypt `server_2_shares`.
    pub server_2_private_key: String,
    /// Dimension (number of buckets) of the inputs
    pub dimension: usize,
    /// Encrypted shares of Priov2 format inputs for the "first" a.k.a. "PHA"
    /// server. The inner `Vec`s are encrypted bytes.
    // Serialized as base64 strings (not arrays of integers); see the `base64` module below.
    #[serde(
        serialize_with = "base64::serialize_bytes",
        deserialize_with = "base64::deserialize_bytes"
    )]
    pub server_1_shares: Vec<Vec<u8>>,
    /// Encrypted share of Priov2 format inputs for the non-"first" a.k.a.
    /// "facilitator" server.
    #[serde(
        serialize_with = "base64::serialize_bytes",
        deserialize_with = "base64::deserialize_bytes"
    )]
    pub server_2_shares: Vec<Vec<u8>>,
    /// The sum over the inputs.
    #[serde(
        serialize_with = "base64::serialize_field",
        deserialize_with = "base64::deserialize_field"
    )]
    pub reference_sum: Vec<FieldPrio2>,
    /// The version of the crate that generated this test vector
    pub prio_crate_version: String,
}
impl Priov2TestVector {
    /// Construct a test vector of `number_of_clients` inputs, each of which is a
    /// `dimension`-dimension vector of random Boolean values encoded as
    /// [`FieldPrio2`].
    pub fn new(dimension: usize, number_of_clients: usize) -> Result<Self, TestVectorError> {
        let mut client: Client<FieldPrio2> =
            Client::new(dimension, server_1_public_key(), server_2_public_key())?;

        let mut reference_sum = vec![FieldPrio2::zero(); dimension];
        let mut server_1_shares = Vec::with_capacity(number_of_clients);
        let mut server_2_shares = Vec::with_capacity(number_of_clients);

        let mut rng = rand::thread_rng();

        for _ in 0..number_of_clients {
            // Generate a random vector of booleans
            let data: Vec<FieldPrio2> = (0..dimension)
                .map(|_| FieldPrio2::from(rng.gen_range(0..2)))
                .collect();

            // Update reference sum
            for (r, d) in reference_sum.iter_mut().zip(&data) {
                *r += *d;
            }

            let (server_1_share, server_2_share) = client.encode_simple(&data)?;

            server_1_shares.push(server_1_share);
            server_2_shares.push(server_2_share);
        }

        Ok(Self {
            server_1_private_key: SERVER_1_PRIVATE_KEY.to_owned(),
            server_2_private_key: SERVER_2_PRIVATE_KEY.to_owned(),
            dimension,
            server_1_shares,
            server_2_shares,
            reference_sum,
            prio_crate_version: env!("CARGO_PKG_VERSION").to_owned(),
        })
    }

    /// Construct a [`Client`] that can encrypt input shares to this test
    /// vector's servers.
    pub fn client(&self) -> Result<Client<FieldPrio2>, TestVectorError> {
        Ok(Client::new(
            self.dimension,
            PublicKey::from(&PrivateKey::from_base64(&self.server_1_private_key).unwrap()),
            PublicKey::from(&PrivateKey::from_base64(&self.server_2_private_key).unwrap()),
        )?)
    }

    /// Construct a [`Server`] that can decrypt `server_1_shares`.
    pub fn server_1(&self) -> Result<Server<FieldPrio2>, TestVectorError> {
        Ok(Server::new(
            self.dimension,
            // This is the "first" server (`is_first_server = true`).
            true,
            PrivateKey::from_base64(&self.server_1_private_key).unwrap(),
        )?)
    }

    /// Construct a [`Server`] that can decrypt `server_2_shares`.
    pub fn server_2(&self) -> Result<Server<FieldPrio2>, TestVectorError> {
        Ok(Server::new(
            self.dimension,
            false,
            PrivateKey::from_base64(&self.server_2_private_key).unwrap(),
        )?)
    }
}
mod base64 {
//! Custom serialization module used for some members of struct
//! `Priov2TestVector` so that byte slices are serialized as base64 strings
//! instead of an array of an array of integers when serializing to JSON.
//
// Thank you, Alice! https://users.rust-lang.org/t/serialize-a-vec-u8-to-json-as-base64/57781/2
use crate::field::{FieldElement, FieldPrio2};
use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize_bytes<S: Serializer>(v: &[Vec<u8>], s: S) -> Result<S::Ok, S::Error> {
let base64_vec = v.iter().map(base64::encode).collect();
<Vec<String>>::serialize(&base64_vec, s)
}
pub fn deserialize_bytes<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<Vec<u8>>, D::Error> {
<Vec<String>>::deserialize(d)?
.iter()
.map(|s| base64::decode(s.as_bytes()).map_err(Error::custom))
.collect()
}
pub fn serialize_field<S: Serializer>(v: &[FieldPrio2], s: S) -> Result<S::Ok, S::Error> {
String::serialize(&base64::encode(FieldPrio2::slice_into_byte_vec(v)), s)
}
pub fn deserialize_field<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<FieldPrio2>, D::Error> {
let bytes = base64::decode(String::deserialize(d)?.as_bytes()).map_err(Error::custom)?;
FieldPrio2::byte_slice_into_vec(&bytes).map_err(Error::custom)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::reconstruct_shares;

    /// Serializing a test vector to JSON and back must produce an equal value.
    #[test]
    fn roundtrip_test_vector_serialization() {
        let test_vector = Priov2TestVector::new(123, 100).unwrap();
        let serialized = serde_json::to_vec(&test_vector).unwrap();
        let test_vector_again: Priov2TestVector = serde_json::from_slice(&serialized).unwrap();
        assert_eq!(test_vector, test_vector_again);
    }

    /// Run the full verify-and-aggregate protocol over a generated test vector
    /// and check that the reconstructed total matches the reference sum.
    #[test]
    fn accumulation_field_priov2() {
        let dimension = 123;
        let test_vector = Priov2TestVector::new(dimension, 100).unwrap();
        let mut server1 = test_vector.server_1().unwrap();
        let mut server2 = test_vector.server_2().unwrap();
        for (server_1_share, server_2_share) in test_vector
            .server_1_shares
            .iter()
            .zip(&test_vector.server_2_shares)
        {
            // Server 1 picks the evaluation point; both servers verify the
            // client's proof at that point before aggregating the share.
            let eval_at = server1.choose_eval_at();
            let v1 = server1
                .generate_verification_message(eval_at, server_1_share)
                .unwrap();
            let v2 = server2
                .generate_verification_message(eval_at, server_2_share)
                .unwrap();
            assert!(server1.aggregate(server_1_share, &v1, &v2).unwrap());
            assert!(server2.aggregate(server_2_share, &v1, &v2).unwrap());
        }
        let total1 = server1.total_shares();
        let total2 = server2.total_shares();
        let reconstructed = reconstruct_shares(total1, total2).unwrap();
        assert_eq!(reconstructed, test_vector.reference_sum);
    }
}

201
third_party/rust/prio/src/util.rs vendored Normal file
View File

@ -0,0 +1,201 @@
// Copyright (c) 2020 Apple Inc.
// SPDX-License-Identifier: MPL-2.0
//! Utility functions for handling Prio stuff.
use crate::field::{FieldElement, FieldError};
/// Serialization errors
///
/// Emitted by the proof (un)packing helpers in this module.
#[derive(Debug, thiserror::Error)]
pub enum SerializeError {
    /// Emitted by `unpack_proof[_mut]` if the serialized share+proof has the wrong length
    #[error("serialized input has wrong length")]
    UnpackInputSizeMismatch,
    /// Finite field operation error.
    #[error("finite field operation error")]
    Field(#[from] FieldError),
}
/// Returns the number of field elements in the proof for given dimension of
/// data elements
///
/// Proof is a vector, where the first `dimension` elements are the data
/// elements, the next 3 elements are the zero terms for polynomials f, g and h
/// and the remaining elements are non-zero points of h(x).
pub fn proof_length(dimension: usize) -> usize {
    // Layout: `dimension` data items, three zero terms (for f, g and h), then
    // N = next_power_of_two(dimension + 1) packed points of h(x).
    let zero_terms = 3;
    let points_of_h = (dimension + 1).next_power_of_two();
    dimension + zero_terms + points_of_h
}
/// Unpacked proof with subcomponents
///
/// Borrowed views into a proof vector, produced by `unpack_proof`. The layout
/// is described in the documentation of [`proof_length`].
#[derive(Debug)]
pub struct UnpackedProof<'a, F: FieldElement> {
    /// Data
    pub data: &'a [F],
    /// Zeroth coefficient of polynomial f
    pub f0: &'a F,
    /// Zeroth coefficient of polynomial g
    pub g0: &'a F,
    /// Zeroth coefficient of polynomial h
    pub h0: &'a F,
    /// Non-zero points of polynomial h
    pub points_h_packed: &'a [F],
}

/// Unpacked proof with mutable subcomponents
///
/// Mutable counterpart of [`UnpackedProof`], produced by `unpack_proof_mut`.
#[derive(Debug)]
pub struct UnpackedProofMut<'a, F: FieldElement> {
    /// Data
    pub data: &'a mut [F],
    /// Zeroth coefficient of polynomial f
    pub f0: &'a mut F,
    /// Zeroth coefficient of polynomial g
    pub g0: &'a mut F,
    /// Zeroth coefficient of polynomial h
    pub h0: &'a mut F,
    /// Non-zero points of polynomial h
    pub points_h_packed: &'a mut [F],
}
/// Unpacks the proof vector into borrowed subcomponents.
///
/// Returns `UnpackInputSizeMismatch` if `proof` does not have exactly
/// `proof_length(dimension)` elements.
pub(crate) fn unpack_proof<F: FieldElement>(
    proof: &[F],
    dimension: usize,
) -> Result<UnpackedProof<F>, SerializeError> {
    // A well-formed proof has exactly `proof_length(dimension)` elements.
    if proof.len() != proof_length(dimension) {
        return Err(SerializeError::UnpackInputSizeMismatch);
    }
    // Carve the slice into the data prefix and the proof tail.
    let (data, rest) = proof.split_at(dimension);
    match rest.split_at(3) {
        ([f0, g0, h0], points_h_packed) => Ok(UnpackedProof {
            data,
            f0,
            g0,
            h0,
            points_h_packed,
        }),
        _ => Err(SerializeError::UnpackInputSizeMismatch),
    }
}
/// Unpacks a mutable proof vector into mutable subcomponents
// TODO(timg): This is public because it is used by tests/tweaks.rs. We should
// refactor that test so it doesn't require the crate to expose this function or
// UnpackedProofMut.
pub fn unpack_proof_mut<F: FieldElement>(
    proof: &mut [F],
    dimension: usize,
) -> Result<UnpackedProofMut<F>, SerializeError> {
    // A well-formed proof has exactly `proof_length(dimension)` elements.
    if proof.len() != proof_length(dimension) {
        return Err(SerializeError::UnpackInputSizeMismatch);
    }
    // Carve the slice into the data prefix and the proof tail.
    let (data, rest) = proof.split_at_mut(dimension);
    match rest.split_at_mut(3) {
        ([f0, g0, h0], points_h_packed) => Ok(UnpackedProofMut {
            data,
            f0,
            g0,
            h0,
            points_h_packed,
        }),
        _ => Err(SerializeError::UnpackInputSizeMismatch),
    }
}
/// Add two field element arrays together elementwise.
///
/// Returns None, when array lengths are not equal.
pub fn reconstruct_shares<F: FieldElement>(share1: &[F], share2: &[F]) -> Option<Vec<F>> {
    if share1.len() != share2.len() {
        return None;
    }
    // Elementwise field addition of the two shares.
    let reconstructed = share1
        .iter()
        .zip(share2.iter())
        .map(|(left, right)| *left + *right)
        .collect();
    Some(reconstructed)
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::field::{Field32, Field64};
    use assert_matches::assert_matches;

    /// Additively secret-share `share` in place: on return, `share` holds the
    /// first share and the returned vector holds the second, so that the
    /// elementwise sum of the two equals the original data.
    pub fn secret_share(share: &mut [Field32]) -> Vec<Field32> {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let mut random = vec![0u32; share.len()];
        let mut share2 = vec![Field32::zero(); share.len()];
        rng.fill(&mut random[..]);
        for (r, f) in random.iter().zip(share2.iter_mut()) {
            *f = Field32::from(*r);
        }
        // Subtract the random share so the two shares sum to the original.
        for (f1, f2) in share.iter_mut().zip(share2.iter()) {
            *f1 -= *f2;
        }
        share2
    }

    #[test]
    fn test_unpack_share_mut() {
        let dim = 15;
        let len = proof_length(dim);
        let mut share = vec![Field32::from(0); len];
        let unpacked = unpack_proof_mut(&mut share, dim).unwrap();
        // Writes through the unpacked view must be visible in the backing
        // vector; f0 lives immediately after the `dim` data elements.
        *unpacked.f0 = Field32::from(12);
        assert_eq!(share[dim], 12);
        // A vector one element short must be rejected.
        let mut short_share = vec![Field32::from(0); len - 1];
        assert_matches!(
            unpack_proof_mut(&mut short_share, dim),
            Err(SerializeError::UnpackInputSizeMismatch)
        );
    }

    #[test]
    fn test_unpack_share() {
        let dim = 15;
        let len = proof_length(dim);
        let share = vec![Field64::from(0); len];
        unpack_proof(&share, dim).unwrap();
        // A vector one element short must be rejected.
        let short_share = vec![Field64::from(0); len - 1];
        assert_matches!(
            unpack_proof(&short_share, dim),
            Err(SerializeError::UnpackInputSizeMismatch)
        );
    }

    #[test]
    fn secret_sharing() {
        let mut share1 = vec![Field32::zero(); 10];
        share1[3] = 21.into();
        share1[8] = 123.into();
        let original_data = share1.clone();
        // Splitting and then reconstructing must round-trip the data.
        let share2 = secret_share(&mut share1);
        let reconstructed = reconstruct_shares(&share1, &share2).unwrap();
        assert_eq!(reconstructed, original_data);
    }
}

536
third_party/rust/prio/src/vdaf.rs vendored Normal file
View File

@ -0,0 +1,536 @@
// SPDX-License-Identifier: MPL-2.0
//! Verifiable Distributed Aggregation Functions (VDAFs) as described in
//! [[draft-irtf-cfrg-vdaf-01]].
//!
//! [draft-irtf-cfrg-vdaf-01]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/01/
use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode};
use crate::field::{FieldElement, FieldError};
use crate::flp::FlpError;
use crate::prng::PrngError;
use crate::vdaf::prg::Seed;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::fmt::Debug;
use std::io::Cursor;
/// Errors emitted by this module.
#[derive(Debug, thiserror::Error)]
pub enum VdafError {
    /// An error occurred.
    #[error("vdaf error: {0}")]
    Uncategorized(String),
    /// Field error.
    #[error("field error: {0}")]
    Field(#[from] FieldError),
    /// An error occurred while parsing a message.
    #[error("io error: {0}")]
    IoError(#[from] std::io::Error),
    /// FLP error.
    #[error("flp error: {0}")]
    Flp(#[from] FlpError),
    /// PRNG error.
    #[error("prng error: {0}")]
    Prng(#[from] PrngError),
    /// Failure when calling getrandom().
    #[error("getrandom: {0}")]
    GetRandom(#[from] getrandom::Error),
}
/// An additive share of a vector of field elements.
///
/// The constant `L` parameterizes the [`Seed`] used for the compressed
/// (helper) representation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Share<F, const L: usize> {
    /// An uncompressed share, typically sent to the leader.
    Leader(Vec<F>),
    /// A compressed share, typically sent to the helper.
    Helper(Seed<L>),
}
impl<F: Clone, const L: usize> Share<F, L> {
    /// Truncate the Leader's share to the given length. If this is the Helper's share, then this
    /// method clones the input without modifying it.
    #[cfg(feature = "prio2")]
    pub(crate) fn truncated(&self, len: usize) -> Self {
        match self {
            // Keep only the first `len` elements of an uncompressed share.
            Self::Leader(data) => {
                let prefix = data[..len].to_vec();
                Self::Leader(prefix)
            }
            // A compressed share is an opaque seed; return it unchanged.
            Self::Helper(seed) => Self::Helper(seed.clone()),
        }
    }
}
/// Parameters needed to decode a [`Share`]
///
/// A leader share carries its element count; a helper share is decoded as a
/// [`Seed`] and needs no extra information.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum ShareDecodingParameter<const L: usize> {
    Leader(usize),
    Helper,
}
impl<F: FieldElement, const L: usize> ParameterizedDecode<ShareDecodingParameter<L>>
    for Share<F, L>
{
    /// Decode a share, reading either `share_length` field elements (leader)
    /// or a fixed-size seed (helper) from the cursor.
    fn decode_with_param(
        decoding_parameter: &ShareDecodingParameter<L>,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        match decoding_parameter {
            ShareDecodingParameter::Leader(share_length) => {
                // Read exactly `share_length` elements, stopping at the first
                // decoding failure.
                let data = (0..*share_length)
                    .map(|_| F::decode(bytes))
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(Self::Leader(data))
            }
            ShareDecodingParameter::Helper => Ok(Self::Helper(Seed::decode(bytes)?)),
        }
    }
}
impl<F: FieldElement, const L: usize> Encode for Share<F, L> {
    /// Encode a share: a leader share as the concatenation of its encoded
    /// elements, a helper share as its encoded seed.
    fn encode(&self, bytes: &mut Vec<u8>) {
        match self {
            Share::Leader(share_data) => share_data.iter().for_each(|x| x.encode(bytes)),
            Share::Helper(share_seed) => share_seed.encode(bytes),
        }
    }
}
/// The base trait for VDAF schemes. This trait is inherited by traits [`Client`], [`Aggregator`],
/// and [`Collector`], which define the roles of the various parties involved in the execution of
/// the VDAF.
// TODO(brandon): once GATs are stabilized [https://github.com/rust-lang/rust/issues/44265],
// state the "&AggregateShare must implement Into<Vec<u8>>" constraint in terms of a where clause
// on the associated type instead of a where clause on the trait.
pub trait Vdaf: Clone + Debug
where
    for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
{
    /// The type of Client measurement to be aggregated.
    type Measurement: Clone + Debug;
    /// The aggregate result of the VDAF execution.
    type AggregateResult: Clone + Debug;
    /// The aggregation parameter, used by the Aggregators to map their input shares to output
    /// shares.
    type AggregationParam: Clone + Debug + Decode + Encode;
    /// An input share sent by a Client.
    ///
    /// Decoding is parameterized by the VDAF instance and the index of the
    /// Aggregator the share is intended for.
    type InputShare: Clone + Debug + for<'a> ParameterizedDecode<(&'a Self, usize)> + Encode;
    /// An output share recovered from an input share by an Aggregator.
    type OutputShare: Clone + Debug;
    /// An Aggregator's share of the aggregate result. The where clause on this
    /// trait additionally guarantees that a reference to it converts to bytes.
    type AggregateShare: Aggregatable<OutputShare = Self::OutputShare> + for<'a> TryFrom<&'a [u8]>;
    /// The number of Aggregators. The Client generates as many input shares as there are
    /// Aggregators.
    fn num_aggregators(&self) -> usize;
}
/// The Client's role in the execution of a VDAF.
pub trait Client: Vdaf
where
    for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
{
    /// Shards a measurement into a sequence of input shares, one for each Aggregator
    /// (i.e., the returned vector has [`Vdaf::num_aggregators`] elements).
    fn shard(&self, measurement: &Self::Measurement) -> Result<Vec<Self::InputShare>, VdafError>;
}
/// The Aggregator's role in the execution of a VDAF.
///
/// The constant `L` is the length in bytes of the verification key passed to
/// [`Aggregator::prepare_init`].
pub trait Aggregator<const L: usize>: Vdaf
where
    for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
{
    /// State of the Aggregator during the Prepare process.
    type PrepareState: Clone + Debug;
    /// The type of messages broadcast by each aggregator at each round of the Prepare Process.
    ///
    /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be
    /// associated with any aggregator involved in the execution of the VDAF.
    type PrepareShare: Clone + Debug + ParameterizedDecode<Self::PrepareState> + Encode;
    /// Result of preprocessing a round of preparation shares.
    ///
    /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be
    /// associated with any aggregator involved in the execution of the VDAF.
    type PrepareMessage: Clone + Debug + ParameterizedDecode<Self::PrepareState> + Encode;
    /// Begins the Prepare process with the other Aggregators. The [`Self::PrepareState`] returned
    /// is passed to [`Aggregator::prepare_step`] to get this aggregator's first-round prepare
    /// message.
    fn prepare_init(
        &self,
        verify_key: &[u8; L],
        agg_id: usize,
        agg_param: &Self::AggregationParam,
        nonce: &[u8],
        input_share: &Self::InputShare,
    ) -> Result<(Self::PrepareState, Self::PrepareShare), VdafError>;
    /// Preprocess a round of preparation shares into a single input to [`Aggregator::prepare_step`].
    fn prepare_preprocess<M: IntoIterator<Item = Self::PrepareShare>>(
        &self,
        inputs: M,
    ) -> Result<Self::PrepareMessage, VdafError>;
    /// Compute the next state transition from the current state and the previous round of input
    /// messages. If this returns [`PrepareTransition::Continue`], then the returned
    /// [`Self::PrepareShare`] should be combined with the other Aggregators' `PrepareShare`s from
    /// this round and passed into another call to this method. This continues until this method
    /// returns [`PrepareTransition::Finish`], at which point the returned output share may be
    /// aggregated. If the method returns an error, the aggregator should consider its input share
    /// invalid and not attempt to process it any further.
    fn prepare_step(
        &self,
        state: Self::PrepareState,
        input: Self::PrepareMessage,
    ) -> Result<PrepareTransition<Self, L>, VdafError>;
    /// Aggregates a sequence of output shares into an aggregate share.
    fn aggregate<M: IntoIterator<Item = Self::OutputShare>>(
        &self,
        agg_param: &Self::AggregationParam,
        output_shares: M,
    ) -> Result<Self::AggregateShare, VdafError>;
}
/// The Collector's role in the execution of a VDAF.
pub trait Collector: Vdaf
where
    for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
{
    /// Combines aggregate shares into the aggregate result.
    ///
    /// `num_measurements` is the total number of measurements that contributed
    /// to the aggregate shares.
    fn unshard<M: IntoIterator<Item = Self::AggregateShare>>(
        &self,
        agg_param: &Self::AggregationParam,
        agg_shares: M,
        num_measurements: usize,
    ) -> Result<Self::AggregateResult, VdafError>;
}
/// A state transition of an Aggregator during the Prepare process.
#[derive(Debug)]
pub enum PrepareTransition<V: Aggregator<L>, const L: usize>
where
    for<'a> &'a V::AggregateShare: Into<Vec<u8>>,
{
    /// Continue processing.
    Continue(V::PrepareState, V::PrepareShare),
    /// Finish processing and return the output share.
    Finish(V::OutputShare),
}

/// An aggregate share resulting from aggregating output shares together that
/// can be merged with aggregate shares of the same type.
pub trait Aggregatable: Clone + Debug + From<Self::OutputShare> {
    /// Type of output shares that can be accumulated into an aggregate share.
    type OutputShare;
    /// Update an aggregate share by merging it with another (`agg_share`).
    fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError>;
    /// Update an aggregate share by adding `output_share`.
    fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError>;
}
/// An output share comprised of a vector of `F` elements.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OutputShare<F>(Vec<F>);

impl<F> AsRef<[F]> for OutputShare<F> {
    fn as_ref(&self) -> &[F] {
        &self.0
    }
}

impl<F> From<Vec<F>> for OutputShare<F> {
    fn from(other: Vec<F>) -> Self {
        Self(other)
    }
}

// Decode an output share from bytes; fails with a `FieldError` if the bytes
// cannot be decoded as field elements.
impl<F: FieldElement> TryFrom<&[u8]> for OutputShare<F> {
    type Error = FieldError;
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        fieldvec_try_from_bytes(bytes)
    }
}

// Encode an output share as the byte encoding of its field-element vector.
impl<F: FieldElement> From<&OutputShare<F>> for Vec<u8> {
    fn from(output_share: &OutputShare<F>) -> Self {
        fieldvec_to_vec(&output_share.0)
    }
}
/// An aggregate share suitable for VDAFs whose output shares and aggregate
/// shares are vectors of `F` elements, and an output share needs no special
/// transformation to be merged into an aggregate share.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AggregateShare<F>(Vec<F>);

impl<F> AsRef<[F]> for AggregateShare<F> {
    fn as_ref(&self) -> &[F] {
        &self.0
    }
}

// An output share converts into an aggregate share by reinterpreting the
// underlying vector; no arithmetic is performed.
impl<F> From<OutputShare<F>> for AggregateShare<F> {
    fn from(other: OutputShare<F>) -> Self {
        Self(other.0)
    }
}

impl<F> From<Vec<F>> for AggregateShare<F> {
    fn from(other: Vec<F>) -> Self {
        Self(other)
    }
}
// Both operations delegate to the private elementwise `sum` helper below.
impl<F: FieldElement> Aggregatable for AggregateShare<F> {
    type OutputShare = OutputShare<F>;

    fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError> {
        self.sum(agg_share.as_ref())
    }

    fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError> {
        // For prio3 and poplar1, no conversion is needed between output shares and aggregation
        // shares.
        self.sum(output_share.as_ref())
    }
}
impl<F: FieldElement> AggregateShare<F> {
    /// Elementwise-add `other` into this aggregate share.
    ///
    /// Returns `VdafError::Uncategorized` if the two vectors have different
    /// lengths, since shares of mismatched lengths cannot belong to the same
    /// aggregation.
    fn sum(&mut self, other: &[F]) -> Result<(), VdafError> {
        if self.0.len() != other.len() {
            // BUGFIX: the original message was missing the closing parenthesis
            // after the second length.
            return Err(VdafError::Uncategorized(format!(
                "cannot sum shares of different lengths (left = {}, right = {})",
                self.0.len(),
                other.len()
            )));
        }
        for (x, y) in self.0.iter_mut().zip(other) {
            *x += *y;
        }
        Ok(())
    }
}
// Decode an aggregate share from bytes; fails with a `FieldError` if the
// bytes cannot be decoded as field elements.
impl<F: FieldElement> TryFrom<&[u8]> for AggregateShare<F> {
    type Error = FieldError;
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        fieldvec_try_from_bytes(bytes)
    }
}

// Encode an aggregate share as the byte encoding of its field-element vector.
impl<F: FieldElement> From<&AggregateShare<F>> for Vec<u8> {
    fn from(aggregate_share: &AggregateShare<F>) -> Self {
        fieldvec_to_vec(&aggregate_share.0)
    }
}

/// fieldvec_try_from_bytes converts a slice of bytes to a type that is equivalent to a vector of
/// field elements.
///
/// Fails with a [`FieldError`] if the bytes cannot be decoded.
#[inline(always)]
fn fieldvec_try_from_bytes<F: FieldElement, T: From<Vec<F>>>(
    bytes: &[u8],
) -> Result<T, FieldError> {
    F::byte_slice_into_vec(bytes).map(T::from)
}

/// fieldvec_to_vec converts a type that is equivalent to a vector of field elements into a vector
/// of bytes.
#[inline(always)]
fn fieldvec_to_vec<F: FieldElement, T: AsRef<[F]>>(val: T) -> Vec<u8> {
    F::slice_into_byte_vec(val.as_ref())
}
/// Test helper: run a complete VDAF execution (shard, prepare, aggregate,
/// unshard) over `measurements` with a randomly generated verification key,
/// returning the aggregate result.
#[cfg(test)]
pub(crate) fn run_vdaf<V, M, const L: usize>(
    vdaf: &V,
    agg_param: &V::AggregationParam,
    measurements: M,
) -> Result<V::AggregateResult, VdafError>
where
    V: Client + Aggregator<L> + Collector,
    for<'a> &'a V::AggregateShare: Into<Vec<u8>>,
    M: IntoIterator<Item = V::Measurement>,
{
    use rand::prelude::*;
    let mut verify_key = [0; L];
    thread_rng().fill(&mut verify_key[..]);

    // NOTE Here we use the same nonce for each measurement for testing purposes. However, this is
    // not secure. In use, the Aggregators MUST ensure that nonces are unique for each measurement.
    let nonce = b"this is a nonce";

    let mut agg_shares: Vec<Option<V::AggregateShare>> = vec![None; vdaf.num_aggregators()];
    let mut num_measurements: usize = 0;
    for measurement in measurements.into_iter() {
        num_measurements += 1;
        let input_shares = vdaf.shard(&measurement)?;
        let out_shares = run_vdaf_prepare(vdaf, &verify_key, agg_param, nonce, input_shares)?;
        // Merge each Aggregator's output share into its running aggregate
        // share, initializing the aggregate on the first measurement.
        for (out_share, agg_share) in out_shares.into_iter().zip(agg_shares.iter_mut()) {
            if let Some(ref mut inner) = agg_share {
                inner.merge(&out_share.into())?;
            } else {
                *agg_share = Some(out_share.into());
            }
        }
    }

    let res = vdaf.unshard(
        agg_param,
        agg_shares.into_iter().map(|option| option.unwrap()),
        num_measurements,
    )?;
    Ok(res)
}
/// Test helper: run the Prepare phase of a VDAF over one set of input shares,
/// round-tripping every message through its encoded form, and return each
/// Aggregator's recovered output share.
///
/// Panics if the Aggregators do not all finish in the same round.
#[cfg(test)]
pub(crate) fn run_vdaf_prepare<V, M, const L: usize>(
    vdaf: &V,
    verify_key: &[u8; L],
    agg_param: &V::AggregationParam,
    nonce: &[u8],
    input_shares: M,
) -> Result<Vec<V::OutputShare>, VdafError>
where
    V: Client + Aggregator<L> + Collector,
    for<'a> &'a V::AggregateShare: Into<Vec<u8>>,
    M: IntoIterator<Item = V::InputShare>,
{
    // Encode and re-decode each input share to exercise the codec path.
    let input_shares = input_shares
        .into_iter()
        .map(|input_share| input_share.get_encoded());

    // First round: initialize each Aggregator's state and collect its
    // first-round prepare share.
    let mut states = Vec::new();
    let mut outbound = Vec::new();
    for (agg_id, input_share) in input_shares.enumerate() {
        let (state, msg) = vdaf.prepare_init(
            verify_key,
            agg_id,
            agg_param,
            nonce,
            &V::InputShare::get_decoded_with_param(&(vdaf, agg_id), &input_share)
                .expect("failed to decode input share"),
        )?;
        states.push(state);
        outbound.push(msg.get_encoded());
    }

    // Combine the first round of prepare shares into a single broadcast
    // message. NOTE(review): decoding uses states[0]; the doc on
    // `PrepareShare` says any aggregator's state may be used here.
    let mut inbound = vdaf
        .prepare_preprocess(outbound.iter().map(|encoded| {
            V::PrepareShare::get_decoded_with_param(&states[0], encoded)
                .expect("failed to decode prep share")
        }))?
        .get_encoded();

    let mut out_shares = Vec::new();
    loop {
        // Step every Aggregator with the broadcast message from the previous
        // round; each either continues (emitting a new prepare share) or
        // finishes (emitting an output share).
        let mut outbound = Vec::new();
        for state in states.iter_mut() {
            match vdaf.prepare_step(
                state.clone(),
                V::PrepareMessage::get_decoded_with_param(state, &inbound)
                    .expect("failed to decode prep message"),
            )? {
                PrepareTransition::Continue(new_state, msg) => {
                    outbound.push(msg.get_encoded());
                    *state = new_state
                }
                PrepareTransition::Finish(out_share) => {
                    out_shares.push(out_share);
                }
            }
        }

        if outbound.len() == vdaf.num_aggregators() {
            // Another round is required before output shares are computed.
            inbound = vdaf
                .prepare_preprocess(outbound.iter().map(|encoded| {
                    V::PrepareShare::get_decoded_with_param(&states[0], encoded)
                        .expect("failed to decode prep share")
                }))?
                .get_encoded();
        } else if outbound.is_empty() {
            // Each Aggregator recovered an output share.
            break;
        } else {
            panic!("Aggregators did not finish the prepare phase at the same time");
        }
    }

    Ok(out_shares)
}
#[cfg(test)]
mod tests {
    use super::{AggregateShare, OutputShare};
    use crate::field::{Field128, Field64, FieldElement};
    use itertools::iterate;
    use std::convert::TryFrom;
    use std::fmt::Debug;

    /// Check that a vector-of-field-elements wrapper type `T` survives a
    /// round-trip through its byte encoding.
    fn fieldvec_roundtrip_test<F, T>()
    where
        F: FieldElement,
        for<'a> T: Debug + PartialEq + From<Vec<F>> + TryFrom<&'a [u8]>,
        for<'a> <T as TryFrom<&'a [u8]>>::Error: Debug,
        for<'a> Vec<u8>: From<&'a T>,
    {
        // Generate a value based on an arbitrary vector of field elements
        // (successive powers of the field's generator).
        let g = F::generator();
        let want_value = T::from(iterate(F::one(), |&v| g * v).take(10).collect());
        // Round-trip the value through a byte-vector.
        let buf: Vec<u8> = (&want_value).into();
        let got_value = T::try_from(&buf).unwrap();
        assert_eq!(want_value, got_value);
    }

    #[test]
    fn roundtrip_output_share() {
        fieldvec_roundtrip_test::<Field64, OutputShare<Field64>>();
        fieldvec_roundtrip_test::<Field128, OutputShare<Field128>>();
    }

    #[test]
    fn roundtrip_aggregate_share() {
        fieldvec_roundtrip_test::<Field64, AggregateShare<Field64>>();
        fieldvec_roundtrip_test::<Field128, AggregateShare<Field128>>();
    }
}
#[cfg(feature = "crypto-dependencies")]
pub mod poplar1;
pub mod prg;
#[cfg(feature = "prio2")]
pub mod prio2;
pub mod prio3;
#[cfg(test)]
mod prio3_test;

View File

@ -0,0 +1,901 @@
// SPDX-License-Identifier: MPL-2.0
//! **(NOTE: This module is experimental. Applications should not use it yet.)** This module
//! partially implements the core component of the Poplar protocol [[BBCG+21]]. Named for the
//! Poplar1 [[draft-irtf-cfrg-vdaf-01]], the specification of this VDAF is under active
//! development. Thus this code should be regarded as experimental and not compliant with any
//! existing specification.
//!
//! TODO Make the input shares stateful so that applications can efficiently evaluate the IDPF over
//! multiple rounds. Question: Will this require API changes to [`crate::vdaf::Vdaf`]?
//!
//! TODO Update trait [`Idpf`] so that the IDPF can have a different field type at the leaves than
//! at the inner nodes.
//!
//! TODO Implement the efficient IDPF of [[BBCG+21]]. [`ToyIdpf`] is not space efficient and is
//! merely intended as a proof-of-concept.
//!
//! [BBCG+21]: https://eprint.iacr.org/2021/017
//! [draft-irtf-cfrg-vdaf-01]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/01/
use std::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet};
use std::convert::{TryFrom, TryInto};
use std::fmt::Debug;
use std::io::Cursor;
use std::iter::FromIterator;
use std::marker::PhantomData;
use crate::codec::{
decode_u16_items, decode_u24_items, encode_u16_items, encode_u24_items, CodecError, Decode,
Encode, ParameterizedDecode,
};
use crate::field::{split_vector, FieldElement};
use crate::fp::log2;
use crate::prng::Prng;
use crate::vdaf::prg::{Prg, Seed};
use crate::vdaf::{
Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition,
Share, ShareDecodingParameter, Vdaf, VdafError,
};
/// An input for an IDPF ([`Idpf`]).
///
/// TODO Make this an associated type of `Idpf`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct IdpfInput {
    // The input bits packed LSB-first into a machine word (see `new`).
    index: usize,
    // Number of significant bits of `index` (i.e., the prefix length).
    level: usize,
}
impl IdpfInput {
    /// Constructs an IDPF input using the first `level` bits of `data`.
    pub fn new(data: &[u8], level: usize) -> Result<Self, VdafError> {
        if level > data.len() << 3 {
            return Err(VdafError::Uncategorized(format!(
                "desired bit length ({} bits) exceeds data length ({} bytes)",
                level,
                data.len()
            )));
        }
        // Pack the first `level` bits of `data` into `index`, LSB-first within
        // each byte.
        let mut index = 0;
        for bit_position in 0..level {
            let byte = data[bit_position >> 3];
            let bit = ((byte >> (bit_position & 7)) & 1) as usize;
            index |= bit << bit_position;
        }
        Ok(Self { index, level })
    }

    /// Construct a new input that is a prefix of `self`. Bounds checking is performed by the
    /// caller.
    fn prefix(&self, level: usize) -> Self {
        // Keep only the low `level` bits of the packed index.
        let mask = (1 << level) - 1;
        Self {
            index: self.index & mask,
            level,
        }
    }

    /// Return the position of `self` in the look-up table of `ToyIdpf`.
    fn data_index(&self) -> usize {
        // The high marker bit distinguishes prefixes of different lengths.
        (1 << self.level) | self.index
    }
}
impl Ord for IdpfInput {
    /// Order primarily by prefix length, breaking ties by packed index.
    fn cmp(&self, other: &Self) -> Ordering {
        self.level
            .cmp(&other.level)
            .then_with(|| self.index.cmp(&other.index))
    }
}

impl PartialOrd for IdpfInput {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Encode for IdpfInput {
    fn encode(&self, bytes: &mut Vec<u8>) {
        // Both fields are widened to u64 so the wire format does not depend on
        // the platform's usize width.
        (self.index as u64).encode(bytes);
        (self.level as u64).encode(bytes);
    }
}

impl Decode for IdpfInput {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        // NOTE(review): `as usize` truncates on 32-bit targets for values above
        // u32::MAX — presumably inputs are always short enough; confirm.
        let index = u64::decode(bytes)? as usize;
        let level = u64::decode(bytes)? as usize;
        Ok(Self { index, level })
    }
}
/// An Incremental Distributed Point Function (IDPF), as defined by [[BBCG+21]].
///
/// [BBCG+21]: https://eprint.iacr.org/2021/017
//
// NOTE(cjpatton) The real IDPF API probably needs to be stateful.
pub trait Idpf<const KEY_LEN: usize, const OUT_LEN: usize>:
    Sized + Clone + Debug + Encode + Decode
{
    /// The finite field over which the IDPF is defined.
    //
    // NOTE(cjpatton) The IDPF of [BBCG+21] might use different fields for different levels of the
    // prefix tree.
    type Field: FieldElement;
    /// Generate and return a sequence of `KEY_LEN` IDPF shares for `input`. Parameter `values` is
    /// an iterator that is invoked to get the output value for each successive level of the
    /// prefix tree.
    fn gen<M: IntoIterator<Item = [Self::Field; OUT_LEN]>>(
        input: &IdpfInput,
        values: M,
    ) -> Result<[Self; KEY_LEN], VdafError>;
    /// Evaluate an IDPF share on `prefix`, returning `OUT_LEN` field elements.
    fn eval(&self, prefix: &IdpfInput) -> Result<[Self::Field; OUT_LEN], VdafError>;
}
/// A "toy" IDPF used for demonstration purposes. The space consumed by each share is `O(2^n)`,
/// where `n` is the length of the input. The size of each share is restricted to 1MB, so this IDPF
/// is only suitable for very short inputs.
//
// NOTE(cjpatton) It would be straight-forward to generalize this construction to any `KEY_LEN` and
// `OUT_LEN`.
#[derive(Debug, Clone)]
pub struct ToyIdpf<F> {
    // Additive share of the look-up table for the first output coordinate.
    data0: Vec<F>,
    // Additive share of the look-up table for the second output coordinate.
    data1: Vec<F>,
    // Length in bits of the input this share was generated for.
    level: usize,
}
impl<F: FieldElement> Idpf<2, 2> for ToyIdpf<F> {
    type Field = F;

    fn gen<M: IntoIterator<Item = [Self::Field; 2]>>(
        input: &IdpfInput,
        values: M,
    ) -> Result<[Self; 2], VdafError> {
        // Each share stores a full look-up table, so bound the input length to
        // keep the table under 1MB of encoded field elements.
        const MAX_DATA_BYTES: usize = 1024 * 1024; // 1MB
        let max_input_len =
            usize::try_from(log2((MAX_DATA_BYTES / F::ENCODED_SIZE) as u128)).unwrap();
        if input.level > max_input_len {
            return Err(VdafError::Uncategorized(format!(
                "input length ({}) exceeds maximum of ({})",
                input.level, max_input_len
            )));
        }
        let data_len = 1 << (input.level + 1);
        let mut data0 = vec![F::zero(); data_len];
        let mut data1 = vec![F::zero(); data_len];
        // Record the caller-supplied output value at each prefix of `input`;
        // every other table entry stays zero.
        let mut values = values.into_iter();
        for level in 0..input.level + 1 {
            let value = values.next().unwrap();
            let index = input.prefix(level).data_index();
            data0[index] = value[0];
            data1[index] = value[1];
        }
        // Additively secret-share both tables between the two key holders.
        let mut data0 = split_vector(&data0, 2)?.into_iter();
        let mut data1 = split_vector(&data1, 2)?.into_iter();
        Ok([
            ToyIdpf {
                data0: data0.next().unwrap(),
                data1: data1.next().unwrap(),
                level: input.level,
            },
            ToyIdpf {
                data0: data0.next().unwrap(),
                data1: data1.next().unwrap(),
                level: input.level,
            },
        ])
    }

    fn eval(&self, prefix: &IdpfInput) -> Result<[F; 2], VdafError> {
        if prefix.level > self.level {
            return Err(VdafError::Uncategorized(format!(
                "prefix length ({}) exceeds input length ({})",
                prefix.level, self.level
            )));
        }
        // Evaluation is a single look-up into the shared tables.
        let index = prefix.data_index();
        Ok([self.data0[index], self.data1[index]])
    }
}
impl<F: FieldElement> Encode for ToyIdpf<F> {
    fn encode(&self, bytes: &mut Vec<u8>) {
        // Both tables are encoded as length-prefixed (u24) vectors; `level` is
        // widened to a fixed-width u64.
        encode_u24_items(bytes, &(), &self.data0);
        encode_u24_items(bytes, &(), &self.data1);
        (self.level as u64).encode(bytes);
    }
}

impl<F: FieldElement> Decode for ToyIdpf<F> {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        // Fields are decoded in the same order `encode` wrote them.
        let data0 = decode_u24_items(&(), bytes)?;
        let data1 = decode_u24_items(&(), bytes)?;
        let level = u64::decode(bytes)? as usize;
        Ok(Self {
            data0,
            data1,
            level,
        })
    }
}
impl Encode for BTreeSet<IdpfInput> {
    fn encode(&self, bytes: &mut Vec<u8>) {
        // Encodes the aggregation parameter as a variable length vector of
        // [`IdpfInput`], because the size of the aggregation parameter is not
        // determined by the VDAF.
        let items: Vec<IdpfInput> = self.iter().cloned().collect();
        encode_u24_items(bytes, &(), &items);
    }
}

impl Decode for BTreeSet<IdpfInput> {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        // Collect the decoded inputs back into an ordered set.
        let inputs: Vec<IdpfInput> = decode_u24_items(&(), bytes)?;
        Ok(inputs.into_iter().collect())
    }
}
/// An input share for the `poplar1` VDAF.
#[derive(Debug, Clone)]
pub struct Poplar1InputShare<I: Idpf<2, 2>, const L: usize> {
    /// IDPF share of the input.
    idpf: I,
    /// PRNG seed used to generate the aggregator's share of the randomness used in the first part
    /// of the sketching protocol.
    sketch_start_seed: Seed<L>,
    /// Aggregator's share of the randomness used in the second part of the sketching protocol.
    /// The leader holds an expanded vector of field elements; the helper holds only a seed.
    sketch_next: Share<I::Field, L>,
}
impl<I: Idpf<2, 2>, const L: usize> Encode for Poplar1InputShare<I, L> {
    // Wire format: IDPF share, round-one sketch seed, round-two sketch share.
    // `decode_with_param` below reads in the same order.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.idpf.encode(bytes);
        self.sketch_start_seed.encode(bytes);
        self.sketch_next.encode(bytes);
    }
}
impl<'a, I, P, const L: usize> ParameterizedDecode<(&'a Poplar1<I, P, L>, usize)>
    for Poplar1InputShare<I, L>
where
    I: Idpf<2, 2>,
{
    /// Decode an input share. The VDAF instance supplies the input length and the
    /// aggregator ID determines whether the round-two sketch share is a leader
    /// (expanded vector) or helper (seed) share.
    fn decode_with_param(
        (poplar1, agg_id): &(&'a Poplar1<I, P, L>, usize),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let idpf = I::decode(bytes)?;
        let sketch_start_seed = Seed::decode(bytes)?;
        let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?;
        let share_decoding_parameter = if is_leader {
            // The sketch is two field elements for every bit of input, plus two more, corresponding
            // to construction of shares in `Poplar1::shard`.
            ShareDecodingParameter::Leader((poplar1.input_length + 1) * 2)
        } else {
            ShareDecodingParameter::Helper
        };
        let sketch_next =
            <Share<I::Field, L>>::decode_with_param(&share_decoding_parameter, bytes)?;
        Ok(Self {
            idpf,
            sketch_start_seed,
            sketch_next,
        })
    }
}
/// The poplar1 VDAF.
#[derive(Debug)]
pub struct Poplar1<I, P, const L: usize> {
    /// Length of a client input in bits (`BITS` in the VDAF specification).
    input_length: usize,
    /// Binds the IDPF implementation `I` and PRG `P` to this instance without storing values.
    phantom: PhantomData<(I, P)>,
}
impl<I, P, const L: usize> Poplar1<I, P, L> {
    /// Create an instance of the poplar1 VDAF for inputs of length `bits` bits, corresponding
    /// to `BITS` as defined in the [VDAF specification][1]. The PRG used for deriving
    /// pseudorandom sequences of field elements is selected by the type parameter `P`.
    ///
    /// [1]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/01/
    pub fn new(bits: usize) -> Self {
        Self {
            input_length: bits,
            phantom: PhantomData,
        }
    }
}
impl<I, P, const L: usize> Clone for Poplar1<I, P, L> {
fn clone(&self) -> Self {
Self::new(self.input_length)
}
}
impl<I, P, const L: usize> Vdaf for Poplar1<I, P, L>
where
    I: Idpf<2, 2>,
    P: Prg<L>,
{
    type Measurement = IdpfInput;
    type AggregateResult = BTreeMap<IdpfInput, u64>;
    // The aggregation parameter is the set of candidate prefixes to count.
    type AggregationParam = BTreeSet<IdpfInput>;
    type InputShare = Poplar1InputShare<I, L>;
    type OutputShare = OutputShare<I::Field>;
    type AggregateShare = AggregateShare<I::Field>;

    /// Poplar1 is a two-party protocol: one leader and one helper.
    fn num_aggregators(&self) -> usize {
        2
    }
}
impl<I, P, const L: usize> Client for Poplar1<I, P, L>
where
    I: Idpf<2, 2>,
    P: Prg<L>,
{
    /// Split `input` into two input shares: IDPF key shares of the data and
    /// authenticator vectors, plus correlated randomness for the two-round
    /// sketching protocol.
    #[allow(clippy::many_single_char_names)]
    fn shard(&self, input: &IdpfInput) -> Result<Vec<Poplar1InputShare<I, L>>, VdafError> {
        // One (data, authenticator) pair per prefix level; the authenticator `k` is random.
        let idpf_values: Vec<[I::Field; 2]> = Prng::new()?
            .take(input.level + 1)
            .map(|k| [I::Field::one(), k])
            .collect();

        // For each level of the prefix tree, generate correlated randomness that the aggregators use
        // to validate the output. See [BBCG+21, Appendix C.4].
        let leader_sketch_start_seed = Seed::generate()?;
        let helper_sketch_start_seed = Seed::generate()?;
        let helper_sketch_next_seed = Seed::generate()?;
        let mut leader_sketch_start_prng: Prng<I::Field, _> =
            Prng::from_seed_stream(P::seed_stream(&leader_sketch_start_seed, b""));
        let mut helper_sketch_start_prng: Prng<I::Field, _> =
            Prng::from_seed_stream(P::seed_stream(&helper_sketch_start_seed, b""));
        let mut helper_sketch_next_prng: Prng<I::Field, _> =
            Prng::from_seed_stream(P::seed_stream(&helper_sketch_next_seed, b""));
        let mut leader_sketch_next: Vec<I::Field> = Vec::with_capacity(2 * idpf_values.len());
        for value in idpf_values.iter() {
            let k = value[1];

            // [BBCG+21, Appendix C.4]
            //
            // $(a, b, c)$
            let a = leader_sketch_start_prng.get() + helper_sketch_start_prng.get();
            let b = leader_sketch_start_prng.get() + helper_sketch_start_prng.get();
            let c = leader_sketch_start_prng.get() + helper_sketch_start_prng.get();

            // $A = -2a + k$
            // $B = a^2 + b + -ak + c$
            let d = k - (a + a);
            let e = (a * a) + b - (a * k) + c;
            // The leader's share is the difference from the helper's PRNG-derived share,
            // so the two shares sum to (A, B).
            leader_sketch_next.push(d - helper_sketch_next_prng.get());
            leader_sketch_next.push(e - helper_sketch_next_prng.get());
        }

        // Generate IDPF shares of the data and authentication vectors.
        let idpf_shares = I::gen(input, idpf_values)?;

        Ok(vec![
            Poplar1InputShare {
                idpf: idpf_shares[0].clone(),
                sketch_start_seed: leader_sketch_start_seed,
                sketch_next: Share::Leader(leader_sketch_next),
            },
            Poplar1InputShare {
                idpf: idpf_shares[1].clone(),
                sketch_start_seed: helper_sketch_start_seed,
                sketch_next: Share::Helper(helper_sketch_next_seed),
            },
        ])
    }
}
/// Return the common prefix length of the candidate prefixes in `agg_param`, or an
/// error if the set is empty or the prefixes disagree on length.
fn get_level(agg_param: &BTreeSet<IdpfInput>) -> Result<usize, VdafError> {
    let mut prefixes = agg_param.iter();
    let first = prefixes
        .next()
        .ok_or_else(|| VdafError::Uncategorized("prefix set is empty".to_string()))?;
    for prefix in prefixes {
        if prefix.level != first.level {
            return Err(VdafError::Uncategorized(
                "prefixes must all have the same length".to_string(),
            ));
        }
    }
    Ok(first.level)
}
impl<I, P, const L: usize> Aggregator<L> for Poplar1<I, P, L>
where
    I: Idpf<2, 2>,
    P: Prg<L>,
{
    type PrepareState = Poplar1PrepareState<I::Field>;
    type PrepareShare = Poplar1PrepareMessage<I::Field>;
    type PrepareMessage = Poplar1PrepareMessage<I::Field>;

    /// Begin preparation: evaluate the IDPF share on every candidate prefix and emit this
    /// aggregator's first-round sketch share $(z, z^*, z^{**})$.
    #[allow(clippy::type_complexity)]
    fn prepare_init(
        &self,
        verify_key: &[u8; L],
        agg_id: usize,
        agg_param: &BTreeSet<IdpfInput>,
        nonce: &[u8],
        input_share: &Self::InputShare,
    ) -> Result<
        (
            Poplar1PrepareState<I::Field>,
            Poplar1PrepareMessage<I::Field>,
        ),
        VdafError,
    > {
        let level = get_level(agg_param)?;
        let is_leader = role_try_from(agg_id)?;

        // Derive the verification randomness from the verify key and the nonce.
        let mut p = P::init(verify_key);
        p.update(nonce);
        let mut verify_rand_prng: Prng<I::Field, _> = Prng::from_seed_stream(p.into_seed_stream());

        // Evaluate the IDPF shares and compute the polynomial coefficients.
        let mut z = [I::Field::zero(); 3];
        let mut output_share = Vec::with_capacity(agg_param.len());
        for prefix in agg_param.iter() {
            let value = input_share.idpf.eval(prefix)?;
            let (v, k) = (value[0], value[1]);
            let r = verify_rand_prng.get();

            // [BBCG+21, Appendix C.4]
            //
            // $(z_\sigma, z^*_\sigma, z^{**}_\sigma)$
            let tmp = r * v;
            z[0] += tmp;
            z[1] += r * tmp;
            z[2] += r * k;
            output_share.push(v);
        }

        // [BBCG+21, Appendix C.4]
        //
        // Add blind shares $(a_\sigma b_\sigma, c_\sigma)$
        //
        // NOTE(cjpatton) We can make this faster by a factor of 3 by using three seed shares instead
        // of one. On the other hand, if the input shares are made stateful, then we could store
        // the PRNG state there and avoid fast-forwarding.
        let mut prng = Prng::<I::Field, _>::from_seed_stream(P::seed_stream(
            &input_share.sketch_start_seed,
            b"",
        ))
        .skip(3 * level);
        z[0] += prng.next().unwrap();
        z[1] += prng.next().unwrap();
        z[2] += prng.next().unwrap();

        // Fetch this aggregator's share of $(A, B)$ for the current level: the leader
        // stores them expanded, the helper re-derives them from its seed.
        let (d, e) = match &input_share.sketch_next {
            Share::Leader(data) => (data[2 * level], data[2 * level + 1]),
            Share::Helper(seed) => {
                let mut prng = Prng::<I::Field, _>::from_seed_stream(P::seed_stream(seed, b""))
                    .skip(2 * level);
                (prng.next().unwrap(), prng.next().unwrap())
            }
        };

        // x = 1 for the leader and 0 for the helper, so the non-linear sketch term is
        // counted exactly once when the shares are combined.
        let x = if is_leader {
            I::Field::one()
        } else {
            I::Field::zero()
        };

        Ok((
            Poplar1PrepareState {
                sketch: SketchState::RoundOne,
                output_share: OutputShare(output_share),
                d,
                e,
                x,
            },
            Poplar1PrepareMessage(z.to_vec()),
        ))
    }

    /// Combine the two aggregators' prepare shares by element-wise addition. Errors
    /// unless exactly two equal-length messages are provided.
    fn prepare_preprocess<M: IntoIterator<Item = Poplar1PrepareMessage<I::Field>>>(
        &self,
        inputs: M,
    ) -> Result<Poplar1PrepareMessage<I::Field>, VdafError> {
        let mut output: Option<Vec<I::Field>> = None;
        let mut count = 0;
        for data_share in inputs.into_iter() {
            count += 1;
            if let Some(ref mut data) = output {
                if data_share.0.len() != data.len() {
                    return Err(VdafError::Uncategorized(format!(
                        "unexpected message length: got {}; want {}",
                        data_share.0.len(),
                        data.len(),
                    )));
                }
                for (x, y) in data.iter_mut().zip(data_share.0.iter()) {
                    *x += *y;
                }
            } else {
                output = Some(data_share.0);
            }
        }

        if count != 2 {
            return Err(VdafError::Uncategorized(format!(
                "unexpected message count: got {}; want 2",
                count,
            )));
        }

        Ok(Poplar1PrepareMessage(output.unwrap()))
    }

    /// Advance the sketching protocol: round one emits this aggregator's share of the
    /// sketch polynomial evaluation; round two checks that the combined polynomial
    /// evaluates to zero and finishes with the output share.
    fn prepare_step(
        &self,
        mut state: Poplar1PrepareState<I::Field>,
        msg: Poplar1PrepareMessage<I::Field>,
    ) -> Result<PrepareTransition<Self, L>, VdafError> {
        match &state.sketch {
            SketchState::RoundOne => {
                if msg.0.len() != 3 {
                    return Err(VdafError::Uncategorized(format!(
                        "unexpected message length ({:?}): got {}; want 3",
                        state.sketch,
                        msg.0.len(),
                    )));
                }

                // Compute polynomial coefficients.
                let z: [I::Field; 3] = msg.0.try_into().unwrap();
                let y_share =
                    vec![(state.d * z[0]) + state.e + state.x * ((z[0] * z[0]) - z[1] - z[2])];
                state.sketch = SketchState::RoundTwo;
                Ok(PrepareTransition::Continue(
                    state,
                    Poplar1PrepareMessage(y_share),
                ))
            }

            SketchState::RoundTwo => {
                if msg.0.len() != 1 {
                    return Err(VdafError::Uncategorized(format!(
                        "unexpected message length ({:?}): got {}; want 1",
                        state.sketch,
                        msg.0.len(),
                    )));
                }

                let y = msg.0[0];
                if y != I::Field::zero() {
                    return Err(VdafError::Uncategorized(format!(
                        "output is invalid: polynomial evaluated to {}; want {}",
                        y,
                        I::Field::zero(),
                    )));
                }

                Ok(PrepareTransition::Finish(state.output_share))
            }
        }
    }

    /// Sum a collection of output shares into an aggregate share, one counter per
    /// candidate prefix.
    fn aggregate<M: IntoIterator<Item = OutputShare<I::Field>>>(
        &self,
        agg_param: &BTreeSet<IdpfInput>,
        output_shares: M,
    ) -> Result<AggregateShare<I::Field>, VdafError> {
        let mut agg_share = AggregateShare(vec![I::Field::zero(); agg_param.len()]);
        for output_share in output_shares.into_iter() {
            agg_share.accumulate(&output_share)?;
        }

        Ok(agg_share)
    }
}
/// A prepare message exchanged between Poplar1 aggregators.
#[derive(Clone, Debug)]
pub struct Poplar1PrepareMessage<F>(Vec<F>);
impl<F> AsRef<[F]> for Poplar1PrepareMessage<F> {
    // Expose the underlying field-element vector without copying.
    fn as_ref(&self) -> &[F] {
        &self.0
    }
}
impl<F: FieldElement> Encode for Poplar1PrepareMessage<F> {
    fn encode(&self, bytes: &mut Vec<u8>) {
        // TODO: This is encoded as a variable length vector of F, but we may
        // be able to make this a fixed-length vector for specific Poplar1
        // instantiations.
        encode_u16_items(bytes, &(), &self.0);
    }
}
impl<F: FieldElement> ParameterizedDecode<Poplar1PrepareState<F>> for Poplar1PrepareMessage<F> {
    // The prepare state is currently unused for decoding; the element count is
    // carried on the wire by the u16 length prefix.
    fn decode_with_param(
        _decoding_parameter: &Poplar1PrepareState<F>,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        // TODO: This is decoded as a variable length vector of F, but we may be
        // able to make this a fixed-length vector for specific Poplar1
        // instantiations.
        let items = decode_u16_items(&(), bytes)?;
        Ok(Self(items))
    }
}
/// The state of each Aggregator during the Prepare process.
#[derive(Clone, Debug)]
pub struct Poplar1PrepareState<F> {
    /// State of the secure sketching protocol.
    sketch: SketchState,
    /// The output share.
    output_share: OutputShare<F>,
    /// Aggregator's share of $A = -2a + k$.
    d: F,
    /// Aggregator's share of $B = a^2 + b - ak + c$.
    e: F,
    /// Equal to 1 if this Aggregator is the "leader" and 0 otherwise.
    x: F,
}
/// Tracks which round of the two-round sketching protocol comes next.
#[derive(Clone, Debug)]
enum SketchState {
    RoundOne,
    RoundTwo,
}
impl<I, P, const L: usize> Collector for Poplar1<I, P, L>
where
    I: Idpf<2, 2>,
    P: Prg<L>,
{
    /// Merge the aggregate shares and map each candidate prefix to its combined count.
    fn unshard<M: IntoIterator<Item = AggregateShare<I::Field>>>(
        &self,
        agg_param: &BTreeSet<IdpfInput>,
        agg_shares: M,
        _num_measurements: usize,
    ) -> Result<BTreeMap<IdpfInput, u64>, VdafError> {
        let mut agg_data = AggregateShare(vec![I::Field::zero(); agg_param.len()]);
        for agg_share in agg_shares.into_iter() {
            agg_data.merge(&agg_share)?;
        }

        // `agg_param` and the merged data vector are index-aligned: the i-th count
        // belongs to the i-th prefix in the set's iteration order.
        let mut agg = BTreeMap::new();
        for (prefix, count) in agg_param.iter().zip(agg_data.as_ref()) {
            let count = <I::Field as FieldElement>::Integer::from(*count);
            let count: u64 = count
                .try_into()
                .map_err(|_| VdafError::Uncategorized("aggregate overflow".to_string()))?;
            agg.insert(*prefix, count);
        }
        Ok(agg)
    }
}
/// Map an aggregator ID to a leader flag: ID 0 is the leader (`true`), ID 1 the
/// helper (`false`); any other ID is rejected.
fn role_try_from(agg_id: usize) -> Result<bool, VdafError> {
    if agg_id > 1 {
        return Err(VdafError::Uncategorized("unexpected aggregator id".into()));
    }
    Ok(agg_id == 0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::Field128;
    use crate::vdaf::prg::PrgAes128;
    use crate::vdaf::{run_vdaf, run_vdaf_prepare};
    use rand::prelude::*;

    #[test]
    fn test_idpf() {
        // IDPF input equality tests.
        assert_eq!(
            IdpfInput::new(b"hello", 40).unwrap(),
            IdpfInput::new(b"hello", 40).unwrap()
        );
        // Inputs that agree on every compared bit are equal, even when the
        // underlying byte strings differ past the prefix length.
        assert_eq!(
            IdpfInput::new(b"hi", 9).unwrap(),
            IdpfInput::new(b"ha", 9).unwrap(),
        );
        assert_eq!(
            IdpfInput::new(b"hello", 25).unwrap(),
            IdpfInput::new(b"help", 25).unwrap()
        );
        assert_ne!(
            IdpfInput::new(b"hello", 40).unwrap(),
            IdpfInput::new(b"hello", 39).unwrap()
        );
        assert_ne!(
            IdpfInput::new(b"hello", 40).unwrap(),
            IdpfInput::new(b"hell-", 40).unwrap()
        );

        // IDPF uniqueness tests
        let mut unique = BTreeSet::new();
        assert!(unique.insert(IdpfInput::new(b"hello", 40).unwrap()));
        assert!(!unique.insert(IdpfInput::new(b"hello", 40).unwrap()));
        assert!(unique.insert(IdpfInput::new(b"hello", 39).unwrap()));
        assert!(unique.insert(IdpfInput::new(b"bye", 20).unwrap()));

        // Generate IDPF keys.
        let input = IdpfInput::new(b"hi", 16).unwrap();
        let keys = ToyIdpf::<Field128>::gen(
            &input,
            std::iter::repeat([Field128::one(), Field128::one()]),
        )
        .unwrap();

        // Try evaluating the IDPF keys on all prefixes.
        for prefix_len in 0..input.level + 1 {
            let res = eval_idpf(
                &keys,
                &input.prefix(prefix_len),
                &[Field128::one(), Field128::one()],
            );
            assert!(res.is_ok(), "prefix_len={} error: {:?}", prefix_len, res);
        }

        // Try evaluating the IDPF keys on incorrect prefixes.
        eval_idpf(
            &keys,
            &IdpfInput::new(&[2], 2).unwrap(),
            &[Field128::zero(), Field128::zero()],
        )
        .unwrap();

        eval_idpf(
            &keys,
            &IdpfInput::new(&[23, 1], 12).unwrap(),
            &[Field128::zero(), Field128::zero()],
        )
        .unwrap();
    }

    /// Sum the key shares' evaluations at `input` and check that they reconstruct
    /// `expected_output`.
    fn eval_idpf<I, const KEY_LEN: usize, const OUT_LEN: usize>(
        keys: &[I; KEY_LEN],
        input: &IdpfInput,
        expected_output: &[I::Field; OUT_LEN],
    ) -> Result<(), VdafError>
    where
        I: Idpf<KEY_LEN, OUT_LEN>,
    {
        let mut output = [I::Field::zero(); OUT_LEN];
        for key in keys {
            let output_share = key.eval(input)?;
            for (x, y) in output.iter_mut().zip(output_share) {
                *x += y;
            }
        }

        if expected_output != &output {
            return Err(VdafError::Uncategorized(format!(
                "eval_idpf(): unexpected output: got {:?}; want {:?}",
                output, expected_output
            )));
        }

        Ok(())
    }

    #[test]
    fn test_poplar1() {
        const INPUT_LEN: usize = 8;

        let vdaf: Poplar1<ToyIdpf<Field128>, PrgAes128, 16> = Poplar1::new(INPUT_LEN);
        assert_eq!(vdaf.num_aggregators(), 2);

        // Run the VDAF input-distribution algorithm.
        let input = vec![IdpfInput::new(&[0b0110_1000], INPUT_LEN).unwrap()];
        let mut agg_param = BTreeSet::new();
        agg_param.insert(input[0]);
        check_btree(&run_vdaf(&vdaf, &agg_param, input.clone()).unwrap(), &[1]);

        // Try evaluating the VDAF on each prefix of the input.
        for prefix_len in 0..input[0].level + 1 {
            let mut agg_param = BTreeSet::new();
            agg_param.insert(input[0].prefix(prefix_len));
            check_btree(&run_vdaf(&vdaf, &agg_param, input.clone()).unwrap(), &[1]);
        }

        // Try various prefixes.
        let prefix_len = 4;
        let mut agg_param = BTreeSet::new();
        // At length 4, the next two prefixes are equal. Neither one matches the input.
        agg_param.insert(IdpfInput::new(&[0b0000_0000], prefix_len).unwrap());
        agg_param.insert(IdpfInput::new(&[0b0001_0000], prefix_len).unwrap());
        agg_param.insert(IdpfInput::new(&[0b0000_0001], prefix_len).unwrap());
        agg_param.insert(IdpfInput::new(&[0b0000_1101], prefix_len).unwrap());
        // At length 4, the next two prefixes are equal. Both match the input.
        agg_param.insert(IdpfInput::new(&[0b0111_1101], prefix_len).unwrap());
        agg_param.insert(IdpfInput::new(&[0b0000_1101], prefix_len).unwrap());
        let aggregate = run_vdaf(&vdaf, &agg_param, input.clone()).unwrap();
        assert_eq!(aggregate.len(), agg_param.len());
        check_btree(
            &aggregate,
            // We put six prefixes in the aggregation parameter, but the vector we get back is only
            // 4 elements because at the given prefix length, some of the prefixes are equal.
            &[0, 0, 0, 1],
        );

        let mut verify_key = [0; 16];
        thread_rng().fill(&mut verify_key[..]);
        let nonce = b"this is a nonce";

        // Try evaluating the VDAF with an invalid aggregation parameter. (It's an error to have a
        // mixture of prefix lengths.)
        let mut agg_param = BTreeSet::new();
        agg_param.insert(IdpfInput::new(&[0b0000_0111], 6).unwrap());
        agg_param.insert(IdpfInput::new(&[0b0000_1000], 7).unwrap());
        let input_shares = vdaf.shard(&input[0]).unwrap();
        run_vdaf_prepare(&vdaf, &verify_key, &agg_param, nonce, input_shares).unwrap_err();

        // Try evaluating the VDAF with malformed inputs.
        //
        // This IDPF key pair evaluates to 1 everywhere, which is illegal.
        let mut input_shares = vdaf.shard(&input[0]).unwrap();
        for (i, x) in input_shares[0].idpf.data0.iter_mut().enumerate() {
            if i != input[0].index {
                *x += Field128::one();
            }
        }
        let mut agg_param = BTreeSet::new();
        agg_param.insert(IdpfInput::new(&[0b0000_0111], 8).unwrap());
        run_vdaf_prepare(&vdaf, &verify_key, &agg_param, nonce, input_shares).unwrap_err();

        // This IDPF key pair has a garbled authentication vector.
        let mut input_shares = vdaf.shard(&input[0]).unwrap();
        for x in input_shares[0].idpf.data1.iter_mut() {
            *x = Field128::zero();
        }
        let mut agg_param = BTreeSet::new();
        agg_param.insert(IdpfInput::new(&[0b0000_0111], 8).unwrap());
        run_vdaf_prepare(&vdaf, &verify_key, &agg_param, nonce, input_shares).unwrap_err();
    }

    /// Assert that the btree's values match `counts`, compared in iteration order.
    fn check_btree(btree: &BTreeMap<IdpfInput, u64>, counts: &[u64]) {
        for (got, want) in btree.values().zip(counts.iter()) {
            assert_eq!(got, want, "got {:?} want {:?}", btree.values(), counts);
        }
    }
}

255
third_party/rust/prio/src/vdaf/prg.rs vendored Normal file
View File

@ -0,0 +1,255 @@
// SPDX-License-Identifier: MPL-2.0
//! Implementations of PRGs specified in [[draft-irtf-cfrg-vdaf-01]].
//!
//! [draft-irtf-cfrg-vdaf-01]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/01/
use crate::vdaf::{CodecError, Decode, Encode};
#[cfg(feature = "crypto-dependencies")]
use aes::{
cipher::{KeyIvInit, StreamCipher},
Aes128,
};
#[cfg(feature = "crypto-dependencies")]
use cmac::{Cmac, Mac};
#[cfg(feature = "crypto-dependencies")]
use ctr::Ctr64BE;
#[cfg(feature = "crypto-dependencies")]
use std::fmt::Formatter;
use std::{
fmt::Debug,
io::{Cursor, Read},
};
/// Function pointer to fill a buffer with random bytes. Under normal operation,
/// `getrandom::getrandom()` will be used, but other implementations can be used to control
/// randomness when generating or verifying test vectors.
pub(crate) type RandSource = fn(&mut [u8]) -> Result<(), getrandom::Error>;
/// Input of [`Prg`].
///
/// `PartialEq` is implemented manually below (constant-time compare); only `Eq` is derived.
#[derive(Clone, Debug, Eq)]
pub struct Seed<const L: usize>(pub(crate) [u8; L]);
impl<const L: usize> Seed<L> {
    /// Generate a uniform random seed.
    pub fn generate() -> Result<Self, getrandom::Error> {
        Self::from_rand_source(getrandom::getrandom)
    }

    /// Fill a fresh seed from the given randomness source.
    pub(crate) fn from_rand_source(rand_source: RandSource) -> Result<Self, getrandom::Error> {
        let mut seed = [0; L];
        rand_source(&mut seed)?;
        Ok(Self(seed))
    }

    /// An all-zero seed, used as a placeholder before real bytes are written.
    pub(crate) fn uninitialized() -> Self {
        Self([0; L])
    }

    /// XOR `other` into this seed in place.
    pub(crate) fn xor_accumulate(&mut self, other: &Self) {
        for (byte, other_byte) in self.0.iter_mut().zip(other.0.iter()) {
            *byte ^= *other_byte
        }
    }

    /// Overwrite this seed with the XOR of `left` and `right`.
    pub(crate) fn xor(&mut self, left: &Self, right: &Self) {
        for (i, byte) in self.0.iter_mut().enumerate() {
            *byte = left.0[i] ^ right.0[i]
        }
    }
}
impl<const L: usize> AsRef<[u8; L]> for Seed<L> {
    // Borrow the raw seed bytes.
    fn as_ref(&self) -> &[u8; L] {
        &self.0
    }
}
impl<const L: usize> PartialEq for Seed<L> {
    fn eq(&self, other: &Self) -> bool {
        // Constant-time comparison: OR together the XOR of every byte pair and test
        // the accumulator once, so timing does not reveal the first difference.
        self.0
            .iter()
            .zip(other.0.iter())
            .fold(0u8, |acc, (x, y)| acc | (x ^ y))
            == 0
    }
}
impl<const L: usize> Encode for Seed<L> {
    // A seed is encoded as its L raw bytes, with no length prefix.
    fn encode(&self, bytes: &mut Vec<u8>) {
        bytes.extend_from_slice(&self.0[..]);
    }
}
impl<const L: usize> Decode for Seed<L> {
    // Read exactly L bytes; fails if the buffer is too short.
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let mut seed = [0; L];
        bytes.read_exact(&mut seed)?;
        Ok(Seed(seed))
    }
}
/// A stream of pseudorandom bytes derived from a seed.
pub trait SeedStream {
    /// Fill `buf` with the next `buf.len()` bytes of output, advancing the stream.
    fn fill(&mut self, buf: &mut [u8]);
}
/// A pseudorandom generator (PRG) with the interface specified in [[draft-irtf-cfrg-vdaf-01]].
///
/// [draft-irtf-cfrg-vdaf-01]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/01/
pub trait Prg<const L: usize>: Clone + Debug {
    /// The type of stream produced by this PRG.
    type SeedStream: SeedStream;

    /// Construct an instance of [`Prg`] with the given seed.
    fn init(seed_bytes: &[u8; L]) -> Self;

    /// Update the PRG state by passing in the next fragment of the info string. The final info
    /// string is assembled from the concatenation of sequence of fragments passed to this method.
    fn update(&mut self, data: &[u8]);

    /// Finalize the PRG state, producing a seed stream.
    fn into_seed_stream(self) -> Self::SeedStream;

    /// Finalize the PRG state, producing a seed.
    ///
    /// Default implementation: the first `L` bytes of the seed stream.
    fn into_seed(self) -> Seed<L> {
        let mut new_seed = [0; L];
        let mut seed_stream = self.into_seed_stream();
        seed_stream.fill(&mut new_seed);
        Seed(new_seed)
    }

    /// Construct a seed stream from the given seed and info string.
    ///
    /// Default implementation: `init`, one `update` with the whole info string, then finalize.
    fn seed_stream(seed: &Seed<L>, info: &[u8]) -> Self::SeedStream {
        let mut prg = Self::init(seed.as_ref());
        prg.update(info);
        prg.into_seed_stream()
    }
}
/// The PRG based on AES128 as specified in [[draft-irtf-cfrg-vdaf-01]].
///
/// The seed and info string are absorbed with CMAC-AES128; the resulting tag keys
/// AES128 in CTR mode to produce the output stream (see the `Prg` impl below).
///
/// [draft-irtf-cfrg-vdaf-01]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/01/
#[derive(Clone, Debug)]
#[cfg(feature = "crypto-dependencies")]
pub struct PrgAes128(Cmac<Aes128>);
#[cfg(feature = "crypto-dependencies")]
impl Prg<16> for PrgAes128 {
    type SeedStream = SeedStreamAes128;

    fn init(seed_bytes: &[u8; 16]) -> Self {
        // A 16-byte slice is a valid CMAC-AES128 key, so this cannot fail.
        Self(Cmac::new_from_slice(seed_bytes).unwrap())
    }

    // Absorb the next fragment of the info string into the CMAC state.
    fn update(&mut self, data: &[u8]) {
        self.0.update(data);
    }

    fn into_seed_stream(self) -> SeedStreamAes128 {
        // The CMAC tag becomes the AES-CTR key; the IV is all zeros.
        let key = self.0.finalize().into_bytes();
        SeedStreamAes128::new(&key, &[0; 16])
    }
}
/// The key stream produced by AES128 in CTR-mode.
///
/// Produced by finalizing a [`PrgAes128`].
#[cfg(feature = "crypto-dependencies")]
pub struct SeedStreamAes128(Ctr64BE<Aes128>);
#[cfg(feature = "crypto-dependencies")]
impl SeedStreamAes128 {
    // NOTE(review): `key` and `iv` are presumably expected to be exactly 16 bytes
    // (AES128 key/block size); the `.into()` conversions appear to rely on that.
    pub(crate) fn new(key: &[u8], iv: &[u8]) -> Self {
        SeedStreamAes128(Ctr64BE::<Aes128>::new(key.into(), iv.into()))
    }
}
#[cfg(feature = "crypto-dependencies")]
impl SeedStream for SeedStreamAes128 {
    fn fill(&mut self, buf: &mut [u8]) {
        // Zero the buffer first so XORing in the keystream leaves exactly the raw
        // keystream bytes in `buf`.
        buf.fill(0);
        self.0.apply_keystream(buf);
    }
}
#[cfg(feature = "crypto-dependencies")]
impl Debug for SeedStreamAes128 {
    // Delegate to the inner CTR core's `Debug`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Ctr64BE<Aes128> does not implement Debug, but [`ctr::CtrCore`][1] does, and we get that
        // with [`cipher::StreamCipherCoreWrapper::get_core`][2].
        //
        // [1]: https://docs.rs/ctr/latest/ctr/struct.CtrCore.html
        // [2]: https://docs.rs/cipher/latest/cipher/struct.StreamCipherCoreWrapper.html
        self.0.get_core().fmt(f)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{field::Field128, prng::Prng};
    use serde::{Deserialize, Serialize};
    use std::convert::TryInto;

    /// JSON test-vector layout for `test_vec/01/PrgAes128.json`.
    #[derive(Deserialize, Serialize)]
    struct PrgTestVector {
        #[serde(with = "hex")]
        seed: Vec<u8>,
        #[serde(with = "hex")]
        info: Vec<u8>,
        length: usize,
        #[serde(with = "hex")]
        derived_seed: Vec<u8>,
        #[serde(with = "hex")]
        expanded_vec_field128: Vec<u8>,
    }

    // Test correctness of derived methods.
    fn test_prg<P, const L: usize>()
    where
        P: Prg<L>,
    {
        let seed = Seed::generate().unwrap();
        let info = b"info string";

        let mut prg = P::init(seed.as_ref());
        prg.update(info);

        // `into_seed` must equal the first L bytes of the seed stream.
        let mut want: Seed<L> = Seed::uninitialized();
        prg.clone().into_seed_stream().fill(&mut want.0[..]);
        let got = prg.clone().into_seed();
        assert_eq!(got, want);

        // `seed_stream` must match init + update + into_seed_stream.
        let mut want = [0; 45];
        prg.clone().into_seed_stream().fill(&mut want);
        let mut got = [0; 45];
        P::seed_stream(&seed, info).fill(&mut got);
        assert_eq!(got, want);
    }

    #[test]
    fn prg_aes128() {
        let t: PrgTestVector =
            serde_json::from_str(include_str!("test_vec/01/PrgAes128.json")).unwrap();
        let mut prg = PrgAes128::init(&t.seed.try_into().unwrap());
        prg.update(&t.info);

        assert_eq!(
            prg.clone().into_seed(),
            Seed(t.derived_seed.try_into().unwrap())
        );

        let mut bytes = std::io::Cursor::new(t.expanded_vec_field128.as_slice());
        let mut want = Vec::with_capacity(t.length);
        while (bytes.position() as usize) < t.expanded_vec_field128.len() {
            want.push(Field128::decode(&mut bytes).unwrap())
        }
        let got: Vec<Field128> = Prng::from_seed_stream(prg.clone().into_seed_stream())
            .take(t.length)
            .collect();
        assert_eq!(got, want);

        test_prg::<PrgAes128, 16>();
    }
}

419
third_party/rust/prio/src/vdaf/prio2.rs vendored Normal file
View File

@ -0,0 +1,419 @@
// SPDX-License-Identifier: MPL-2.0
//! Port of the ENPA Prio system to a VDAF. It is backwards compatible with
//! [`Client`](crate::client::Client) and [`Server`](crate::server::Server).
use crate::{
client as v2_client,
codec::{CodecError, Decode, Encode, ParameterizedDecode},
field::{FieldElement, FieldPrio2},
prng::Prng,
server as v2_server,
util::proof_length,
vdaf::{
prg::Seed, Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare,
PrepareTransition, Share, ShareDecodingParameter, Vdaf, VdafError,
},
};
use ring::hmac;
use std::{
convert::{TryFrom, TryInto},
io::Cursor,
};
/// The Prio2 VDAF. It supports the same measurement type as
/// [`Prio3Aes128CountVec`](crate::vdaf::prio3::Prio3Aes128CountVec) but uses the proof system
/// and finite field deployed in ENPA.
#[derive(Clone, Debug)]
pub struct Prio2 {
    /// Number of elements in a client's measurement vector.
    input_len: usize,
}
impl Prio2 {
    /// Returns an instance of the VDAF for the given input length.
    ///
    /// Fails if `2 * next_power_of_two(input_len + 1)` does not fit in a `u32` or
    /// exceeds the order of the field's multiplicative generator.
    pub fn new(input_len: usize) -> Result<Self, VdafError> {
        let n = (input_len + 1).next_power_of_two();
        if let Ok(size) = u32::try_from(2 * n) {
            if size > FieldPrio2::generator_order() {
                return Err(VdafError::Uncategorized(
                    "input size exceeds field capacity".into(),
                ));
            }
        } else {
            return Err(VdafError::Uncategorized(
                "input size exceeds memory capacity".into(),
            ));
        }

        Ok(Prio2 { input_len })
    }

    /// Prepare an input share for aggregation using the given field element `query_rand` to
    /// compute the verifier share.
    ///
    /// In the [`Aggregator`](crate::vdaf::Aggregator) trait implementation for [`Prio2`], the
    /// query randomness is computed jointly by the Aggregators. This method is designed to be used
    /// in applications, like ENPA, in which the query randomness is instead chosen by a
    /// third-party.
    pub fn prepare_init_with_query_rand(
        &self,
        query_rand: FieldPrio2,
        input_share: &Share<FieldPrio2, 32>,
        is_leader: bool,
    ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> {
        // The helper's share is a seed; expand it into the full input-plus-proof vector.
        let expanded_data: Option<Vec<FieldPrio2>> = match input_share {
            Share::Leader(_) => None,
            Share::Helper(ref seed) => {
                let prng = Prng::from_prio2_seed(seed.as_ref());
                Some(prng.take(proof_length(self.input_len)).collect())
            }
        };
        let data = match input_share {
            Share::Leader(ref data) => data,
            Share::Helper(_) => expanded_data.as_ref().unwrap(),
        };

        let mut mem = v2_server::ValidationMemory::new(self.input_len);
        let verifier_share = v2_server::generate_verification_message(
            self.input_len,
            query_rand,
            data, // Combined input and proof shares
            is_leader,
            &mut mem,
        )
        .map_err(|e| VdafError::Uncategorized(e.to_string()))?;

        Ok((
            // Keep only the measurement portion of the share for the output.
            Prio2PrepareState(input_share.truncated(self.input_len)),
            Prio2PrepareShare(verifier_share),
        ))
    }
}
impl Vdaf for Prio2 {
    type Measurement = Vec<u32>;
    type AggregateResult = Vec<u32>;
    // Prio2 aggregation takes no parameter.
    type AggregationParam = ();
    type InputShare = Share<FieldPrio2, 32>;
    type OutputShare = OutputShare<FieldPrio2>;
    type AggregateShare = AggregateShare<FieldPrio2>;

    fn num_aggregators(&self) -> usize {
        // Prio2 can easily be extended to support more than two Aggregators.
        2
    }
}
impl Client for Prio2 {
    /// Secret-share `measurement` into a leader share (expanded input-plus-proof data)
    /// and a helper share (a 32-byte PRNG seed).
    fn shard(&self, measurement: &Vec<u32>) -> Result<Vec<Share<FieldPrio2, 32>>, VdafError> {
        if measurement.len() != self.input_len {
            return Err(VdafError::Uncategorized("incorrect input length".into()));
        }
        let mut input: Vec<FieldPrio2> = Vec::with_capacity(measurement.len());
        for int in measurement {
            input.push((*int).into());
        }

        let mut mem = v2_client::ClientMemory::new(self.input_len)?;
        let copy_data = |share_data: &mut [FieldPrio2]| {
            share_data[..].clone_from_slice(&input);
        };
        let mut leader_data = mem.prove_with(self.input_len, copy_data);

        // Subtract the helper's PRNG-expanded share from the leader's data so the two
        // shares sum to the proof data.
        let helper_seed = Seed::generate()?;
        let helper_prng = Prng::from_prio2_seed(helper_seed.as_ref());
        for (s1, d) in leader_data.iter_mut().zip(helper_prng.into_iter()) {
            *s1 -= d;
        }

        Ok(vec![Share::Leader(leader_data), Share::Helper(helper_seed)])
    }
}
/// State of each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase.
///
/// Holds the input share truncated to the measurement length (the proof portion is dropped).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Prio2PrepareState(Share<FieldPrio2, 32>);
impl Encode for Prio2PrepareState {
    // Delegates to the inner share's encoding.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.0.encode(bytes);
    }
}
impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Prio2PrepareState {
    fn decode_with_param(
        (prio2, agg_id): &(&'a Prio2, usize),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        // Aggregator 0 (the leader) holds an expanded share of `input_len` field
        // elements; any other aggregator holds a seed.
        let share_decoder = if *agg_id == 0 {
            ShareDecodingParameter::Leader(prio2.input_len)
        } else {
            ShareDecodingParameter::Helper
        };
        let out_share = Share::decode_with_param(&share_decoder, bytes)?;
        Ok(Self(out_share))
    }
}
/// Message emitted by each [`Aggregator`](crate::vdaf::Aggregator) during the Preparation phase.
///
/// Wraps the ENPA verification message, a triple of field elements (f(r), g(r), h(r)).
#[derive(Clone, Debug)]
pub struct Prio2PrepareShare(v2_server::VerificationMessage<FieldPrio2>);
impl Encode for Prio2PrepareShare {
    // Wire format: f(r), g(r), h(r), in that order.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.0.f_r.encode(bytes);
        self.0.g_r.encode(bytes);
        self.0.h_r.encode(bytes);
    }
}
impl ParameterizedDecode<Prio2PrepareState> for Prio2PrepareShare {
    // The state carries no information needed here: the message is always three
    // field elements, read in the order `encode` writes them.
    fn decode_with_param(
        _state: &Prio2PrepareState,
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        Ok(Self(v2_server::VerificationMessage {
            f_r: FieldPrio2::decode(bytes)?,
            g_r: FieldPrio2::decode(bytes)?,
            h_r: FieldPrio2::decode(bytes)?,
        }))
    }
}
impl Aggregator<32> for Prio2 {
    type PrepareState = Prio2PrepareState;
    type PrepareShare = Prio2PrepareShare;
    // Verification succeeds or fails as a whole; no data flows between rounds.
    type PrepareMessage = ();

    /// Derive the query randomness from the nonce and delegate to
    /// [`Prio2::prepare_init_with_query_rand`].
    fn prepare_init(
        &self,
        agg_key: &[u8; 32],
        agg_id: usize,
        _agg_param: &(),
        nonce: &[u8],
        input_share: &Share<FieldPrio2, 32>,
    ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> {
        let is_leader = role_try_from(agg_id)?;

        // In the ENPA Prio system, the query randomness is generated by a third party and
        // distributed to the Aggregators after they receive their input shares. In a VDAF, shared
        // randomness is derived from a nonce selected by the client. For Prio2 we compute the
        // query using HMAC-SHA256 evaluated over the nonce.
        let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, agg_key);
        let hmac_tag = hmac::sign(&hmac_key, nonce);
        let query_rand = Prng::from_prio2_seed(hmac_tag.as_ref().try_into().unwrap())
            .next()
            .unwrap();

        self.prepare_init_with_query_rand(query_rand, input_share, is_leader)
    }

    /// Combine the two verifier shares and check the proof; yields `()` on success.
    fn prepare_preprocess<M: IntoIterator<Item = Prio2PrepareShare>>(
        &self,
        inputs: M,
    ) -> Result<(), VdafError> {
        let verifier_shares: Vec<v2_server::VerificationMessage<FieldPrio2>> =
            inputs.into_iter().map(|msg| msg.0).collect();
        if verifier_shares.len() != 2 {
            return Err(VdafError::Uncategorized(
                "wrong number of verifier shares".into(),
            ));
        }

        if !v2_server::is_valid_share(&verifier_shares[0], &verifier_shares[1]) {
            return Err(VdafError::Uncategorized(
                "proof verifier check failed".into(),
            ));
        }

        Ok(())
    }

    /// Single-round preparation: expand the stored share (if a seed) and finish with
    /// the output share.
    fn prepare_step(
        &self,
        state: Prio2PrepareState,
        _input: (),
    ) -> Result<PrepareTransition<Self, 32>, VdafError> {
        let data = match state.0 {
            Share::Leader(data) => data,
            Share::Helper(seed) => {
                let prng = Prng::from_prio2_seed(seed.as_ref());
                prng.take(self.input_len).collect()
            }
        };
        Ok(PrepareTransition::Finish(OutputShare::from(data)))
    }

    /// Sum output shares element-wise into an aggregate share.
    fn aggregate<M: IntoIterator<Item = OutputShare<FieldPrio2>>>(
        &self,
        _agg_param: &(),
        out_shares: M,
    ) -> Result<AggregateShare<FieldPrio2>, VdafError> {
        let mut agg_share = AggregateShare(vec![FieldPrio2::zero(); self.input_len]);
        for out_share in out_shares.into_iter() {
            agg_share.accumulate(&out_share)?;
        }

        Ok(agg_share)
    }
}
impl Collector for Prio2 {
    /// Merge the aggregate shares and convert the combined field elements back to `u32`.
    fn unshard<M: IntoIterator<Item = AggregateShare<FieldPrio2>>>(
        &self,
        _agg_param: &(),
        agg_shares: M,
        _num_measurements: usize,
    ) -> Result<Vec<u32>, VdafError> {
        let mut agg = AggregateShare(vec![FieldPrio2::zero(); self.input_len]);
        for agg_share in agg_shares.into_iter() {
            agg.merge(&agg_share)?;
        }

        Ok(agg.0.into_iter().map(u32::from).collect())
    }
}
impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Share<FieldPrio2, 32> {
    fn decode_with_param(
        (prio2, agg_id): &(&'a Prio2, usize),
        bytes: &mut Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?;
        // The leader's share is the full expanded input-plus-proof vector; the
        // helper's share is just a seed.
        let decoder = if is_leader {
            ShareDecodingParameter::Leader(proof_length(prio2.input_len))
        } else {
            ShareDecodingParameter::Helper
        };
        Share::decode_with_param(&decoder, bytes)
    }
}
/// Map an aggregator ID to a role flag: `true` for the leader (ID 0),
/// `false` for the helper (ID 1). Any other ID is rejected.
fn role_try_from(agg_id: usize) -> Result<bool, VdafError> {
    if agg_id > 1 {
        return Err(VdafError::Uncategorized("unexpected aggregator id".into()));
    }
    Ok(agg_id == 0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        client::encode_simple,
        encrypt::{decrypt_share, encrypt_share, PrivateKey, PublicKey},
        field::random_vector,
        server::Server,
        vdaf::{run_vdaf, run_vdaf_prepare},
    };
    use rand::prelude::*;

    // Fixed base64-encoded ECIES private keys shared by the interop tests below,
    // so results are comparable with the legacy client/server implementations.
    const PRIV_KEY1: &str = "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==";
    const PRIV_KEY2: &str = "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w==";

    // End-to-end run of the Prio2 VDAF over five bit-vector measurements;
    // the expected aggregate is the per-position sum.
    #[test]
    fn run_prio2() {
        let prio2 = Prio2::new(6).unwrap();
        assert_eq!(
            run_vdaf(
                &prio2,
                &(),
                [
                    vec![0, 0, 0, 0, 1, 0],
                    vec![0, 1, 0, 0, 0, 0],
                    vec![0, 1, 1, 0, 0, 0],
                    vec![1, 1, 1, 0, 0, 0],
                    vec![0, 0, 0, 0, 1, 1],
                ]
            )
            .unwrap(),
            vec![1, 3, 2, 0, 2, 1],
        );
    }

    // Shares produced by the legacy client encoding (encode_simple + ECIES
    // encryption) must decrypt and decode into shares the VDAF can prepare.
    #[test]
    fn enpa_client_interop() {
        let mut rng = thread_rng();
        let priv_key1 = PrivateKey::from_base64(PRIV_KEY1).unwrap();
        let priv_key2 = PrivateKey::from_base64(PRIV_KEY2).unwrap();
        let pub_key1 = PublicKey::from(&priv_key1);
        let pub_key2 = PublicKey::from(&priv_key2);
        let data: Vec<FieldPrio2> = [0, 0, 1, 1, 0]
            .iter()
            .map(|x| FieldPrio2::from(*x))
            .collect();
        let (encrypted_input_share1, encrypted_input_share2) =
            encode_simple(&data, pub_key1, pub_key2).unwrap();
        let input_share1 = decrypt_share(&encrypted_input_share1, &priv_key1).unwrap();
        let input_share2 = decrypt_share(&encrypted_input_share2, &priv_key2).unwrap();
        let prio2 = Prio2::new(data.len()).unwrap();
        // Aggregator 0 is the leader, aggregator 1 the helper.
        let input_shares = vec![
            Share::get_decoded_with_param(&(&prio2, 0), &input_share1).unwrap(),
            Share::get_decoded_with_param(&(&prio2, 1), &input_share2).unwrap(),
        ];
        let verify_key = rng.gen();
        let mut nonce = [0; 16];
        rng.fill(&mut nonce);
        run_vdaf_prepare(&prio2, &verify_key, &(), &nonce, input_shares).unwrap();
    }

    // Conversely, shares sharded by the VDAF must be accepted by the legacy
    // server implementation's verify/aggregate flow.
    #[test]
    fn enpa_server_interop() {
        let priv_key1 = PrivateKey::from_base64(PRIV_KEY1).unwrap();
        let priv_key2 = PrivateKey::from_base64(PRIV_KEY2).unwrap();
        let pub_key1 = PublicKey::from(&priv_key1);
        let pub_key2 = PublicKey::from(&priv_key2);
        let data = vec![0, 0, 1, 1, 0];
        let prio2 = Prio2::new(data.len()).unwrap();
        let input_shares = prio2.shard(&data).unwrap();
        let encrypted_input_share1 =
            encrypt_share(&input_shares[0].get_encoded(), &pub_key1).unwrap();
        let encrypted_input_share2 =
            encrypt_share(&input_shares[1].get_encoded(), &pub_key2).unwrap();
        let mut server1 = Server::new(data.len(), true, priv_key1).unwrap();
        let mut server2 = Server::new(data.len(), false, priv_key2).unwrap();
        let eval_at: FieldPrio2 = random_vector(1).unwrap()[0];
        let verifier1 = server1
            .generate_verification_message(eval_at, &encrypted_input_share1)
            .unwrap();
        let verifier2 = server2
            .generate_verification_message(eval_at, &encrypted_input_share2)
            .unwrap();
        server1
            .aggregate(&encrypted_input_share1, &verifier1, &verifier2)
            .unwrap();
        server2
            .aggregate(&encrypted_input_share2, &verifier1, &verifier2)
            .unwrap();
    }

    // The prepare state must survive an encode/decode round trip for both
    // aggregator roles.
    #[test]
    fn prepare_state_serialization() {
        let mut verify_key = [0; 32];
        thread_rng().fill(&mut verify_key[..]);
        let data = vec![0, 0, 1, 1, 0];
        let prio2 = Prio2::new(data.len()).unwrap();
        let input_shares = prio2.shard(&data).unwrap();
        for (agg_id, input_share) in input_shares.iter().enumerate() {
            let (want, _msg) = prio2
                .prepare_init(&verify_key, agg_id, &(), &[], input_share)
                .unwrap();
            let got =
                Prio2PrepareState::get_decoded_with_param(&(&prio2, agg_id), &want.get_encoded())
                    .expect("failed to decode prepare step");
            assert_eq!(got, want);
        }
    }
}

1092
third_party/rust/prio/src/vdaf/prio3.rs vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,162 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{
codec::{Encode, ParameterizedDecode},
flp::Type,
vdaf::{
prg::Prg,
prio3::{Prio3, Prio3InputShare, Prio3PrepareShare},
Aggregator, PrepareTransition,
},
};
use serde::{Deserialize, Serialize};
use std::{convert::TryInto, fmt::Debug};
/// An opaque byte string, serialized as a hex string in the JSON test vectors.
#[derive(Debug, Deserialize, Serialize)]
struct TEncoded(#[serde(with = "hex")] Vec<u8>);
impl AsRef<[u8]> for TEncoded {
    /// Borrow the decoded bytes as a slice.
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
/// One prepare-phase test case: a measurement plus the expected encodings
/// produced at each stage of sharding and preparation.
#[derive(Deserialize, Serialize)]
struct TPrio3Prep<M> {
    measurement: M,
    // Hex-encoded nonce passed to prepare_init.
    #[serde(with = "hex")]
    nonce: Vec<u8>,
    // Expected encoded input share, one entry per aggregator.
    input_shares: Vec<TEncoded>,
    // Expected encoded prepare shares; outer Vec is per round (these vectors
    // record a single round), inner Vec is per aggregator.
    prep_shares: Vec<Vec<TEncoded>>,
    // Expected encoded prepare message, one entry per round.
    prep_messages: Vec<TEncoded>,
    // Expected output share values, one inner Vec per aggregator.
    out_shares: Vec<Vec<M>>,
}
/// Top-level test vector: a shared verification key and a list of
/// prepare-phase test cases.
#[derive(Deserialize, Serialize)]
struct TPrio3<M> {
    verify_key: TEncoded,
    prep: Vec<TPrio3Prep<M>>,
}
// Panic with a uniform "test #N failed" message so a failure is attributed to
// the specific test-vector entry that triggered it.
macro_rules! err {
    (
        $test_num:ident,
        $error:expr,
        $msg:expr
    ) => {
        panic!("test #{} failed: {} err: {}", $test_num, $msg, $error)
    };
}
// TODO Generalize this method to work with any VDAF. To do so we would need to add
// `test_vec_setup()` and `test_vec_shard()` to traits. (There may be a less invasive alternative.)
/// Replay one prepare-phase test-vector entry against `prio3`, checking every
/// intermediate encoding (input shares, prepare shares, prepare message) and
/// the final output shares against the values recorded in `t`.
fn check_prep_test_vec<M, T, P, const L: usize>(
    prio3: &Prio3<T, P, L>,
    verify_key: &[u8; L],
    test_num: usize,
    t: &TPrio3Prep<M>,
) where
    T: Type<Measurement = M>,
    P: Prg<L>,
    M: From<<T as Type>::Field> + Debug + PartialEq,
{
    // Shard via the deterministic test-vector path so the shares can be
    // compared byte-for-byte with the recorded ones.
    let input_shares = prio3
        .test_vec_shard(&t.measurement)
        .expect("failed to generate input shares");
    assert_eq!(2, t.input_shares.len(), "#{}", test_num);
    for (agg_id, want) in t.input_shares.iter().enumerate() {
        // The recorded bytes must decode to the generated share...
        assert_eq!(
            input_shares[agg_id],
            Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), want.as_ref())
                .unwrap_or_else(|e| err!(test_num, e, "decode test vector (input share)")),
            "#{}",
            test_num
        );
        // ...and the generated share must re-encode to the recorded bytes.
        assert_eq!(
            input_shares[agg_id].get_encoded(),
            want.as_ref(),
            "#{}",
            test_num
        )
    }
    // Run prepare-init for every aggregator, collecting states and shares.
    let mut states = Vec::new();
    let mut prep_shares = Vec::new();
    for (agg_id, input_share) in input_shares.iter().enumerate() {
        let (state, prep_share) = prio3
            .prepare_init(verify_key, agg_id, &(), &t.nonce, input_share)
            .unwrap_or_else(|e| err!(test_num, e, "prep state init"));
        states.push(state);
        prep_shares.push(prep_share);
    }
    // These vectors record exactly one round of preparation.
    assert_eq!(1, t.prep_shares.len(), "#{}", test_num);
    for (i, want) in t.prep_shares[0].iter().enumerate() {
        assert_eq!(
            prep_shares[i],
            Prio3PrepareShare::get_decoded_with_param(&states[i], want.as_ref())
                .unwrap_or_else(|e| err!(test_num, e, "decode test vector (prep share)")),
            "#{}",
            test_num
        );
        assert_eq!(prep_shares[i].get_encoded(), want.as_ref(), "#{}", test_num);
    }
    // Combine the prepare shares and compare the resulting prepare message.
    let inbound = prio3
        .prepare_preprocess(prep_shares)
        .unwrap_or_else(|e| err!(test_num, e, "prep preprocess"));
    assert_eq!(t.prep_messages.len(), 1);
    assert_eq!(inbound.get_encoded(), t.prep_messages[0].as_ref());
    // With a single round, every aggregator must finish immediately.
    let mut out_shares = Vec::new();
    for state in states.iter_mut() {
        match prio3.prepare_step(state.clone(), inbound.clone()).unwrap() {
            PrepareTransition::Finish(out_share) => {
                out_shares.push(out_share);
            }
            _ => panic!("unexpected transition"),
        }
    }
    // Finally, the output shares must match the recorded values.
    for (got, want) in out_shares.iter().zip(t.out_shares.iter()) {
        let got: Vec<M> = got.as_ref().iter().map(|x| M::from(*x)).collect();
        assert_eq!(&got, want);
    }
}
// Check the Prio3Aes128Count implementation against its recorded test vector.
#[test]
fn test_vec_prio3_count() {
    let test_vector: TPrio3<u64> =
        serde_json::from_str(include_str!("test_vec/01/Prio3Aes128Count.json")).unwrap();
    let prio3 = Prio3::new_aes128_count(2).unwrap();
    let verify_key = test_vector.verify_key.as_ref().try_into().unwrap();
    for (test_num, prep) in test_vector.prep.iter().enumerate() {
        check_prep_test_vec(&prio3, &verify_key, test_num, prep);
    }
}
// Check the Prio3Aes128Sum implementation (8-bit summands) against its
// recorded test vector.
#[test]
fn test_vec_prio3_sum() {
    let test_vector: TPrio3<u128> =
        serde_json::from_str(include_str!("test_vec/01/Prio3Aes128Sum.json")).unwrap();
    let prio3 = Prio3::new_aes128_sum(2, 8).unwrap();
    let verify_key = test_vector.verify_key.as_ref().try_into().unwrap();
    for (test_num, prep) in test_vector.prep.iter().enumerate() {
        check_prep_test_vec(&prio3, &verify_key, test_num, prep);
    }
}
// Check the Prio3Aes128Histogram implementation (bucket boundaries 1, 10,
// 100) against its recorded test vector.
#[test]
fn test_vec_prio3_histogram() {
    let test_vector: TPrio3<u128> =
        serde_json::from_str(include_str!("test_vec/01/Prio3Aes128Histogram.json")).unwrap();
    let prio3 = Prio3::new_aes128_histogram(2, &[1, 10, 100]).unwrap();
    let verify_key = test_vector.verify_key.as_ref().try_into().unwrap();
    for (test_num, prep) in test_vector.prep.iter().enumerate() {
        check_prep_test_vec(&prio3, &verify_key, test_num, prep);
    }
}

View File

@ -0,0 +1,7 @@
{
"derived_seed": "ccf3be704c982182ad2961e9795a88aa",
"expanded_vec_field128": "ccf3be704c982182ad2961e9795a88aa8df71c0b5ea5c13bcf3173c3f3626505e1bf4738874d5405805082cc38c55d1f04f85fbb88b8cf8592ffed8a4ac7f76991c58d850a15e8deb34fb289ab6fab584554ffef16c683228db2b76e792ca4f3c15760044d0703b438c2aefd7975c5dd4b9992ee6f87f20e570572dea18fa580ee17204903c1234f1332d47a442ea636580518ce7aa5943c415117460a049bc19cc81edbb0114d71890cbdbe4ea2664cd038e57b88fb7fd3557830ad363c20b9840d35fd6bee6c3c8424f026ee7fbca3daf3c396a4d6736d7bd3b65b2c228d22a40f4404e47c61b26ac3c88bebf2f268fa972f8831f18bee374a22af0f8bb94d9331a1584bdf8cf3e8a5318b546efee8acd28f6cba8b21b9d52acbae8e726500340da98d643d0a5f1270ecb94c574130cea61224b0bc6d438b2f4f74152e72b37e6a9541c9dc5515f8f98fd0d1bce8743f033ab3e8574180ffc3363f3a0490f6f9583bf73a87b9bb4b51bfd0ef260637a4288c37a491c6cbdc46b6a86cd26edf611793236e912e7227bfb85b560308b06238bbd978f72ed4a58583cf0c6e134066eb6b399ad2f26fa01d69a62d8a2d04b4b8acf82299b07a834d4c2f48fee23a24c20307f9cabcd34b6d69f1969588ebde777e46e9522e866e6dd1e14119a1cb4c0709fa9ea347d9f872e76a39313e7d49bfbf3e5ce807183f43271ba2b5c6aaeaef22da301327c1fd9fedde7c5a68d9b97fa6eb687ec8ca692cb0f631f46e6699a211a1254026c9a0a43eceb450dc97cfa923321baf1f4b6f233260d46182b844dccec153aaddd20f920e9e13ff11434bcd2aa632bf4f544f41b5ddced962939676476f70e0b8640c3471fc7af62d80053781295b070388f7b7f1fa66220cb3",
"info": "696e666f20737472696e67",
"length": 40,
"seed": "01010101010101010101010101010101"
}

View File

@ -0,0 +1,38 @@
{
"agg_param": null,
"agg_result": [
1
],
"agg_shares": [
"ae5483343eb35a52",
"51ab7ccac14ca5b0"
],
"prep": [
{
"input_shares": [
"ae5483343eb35a52fcb36a62271a7ddb47f09d0ea2c6613807f84ac2e16814c82bcabdc9db5080fdf4f4f778734644fc",
"0101010101010101010101010101010101010101010101010101010101010101"
],
"measurement": 1,
"nonce": "01010101010101010101010101010101",
"out_shares": [
[
12561809521056635474
],
[
5884934548357948848
]
],
"prep_messages": [
""
],
"prep_shares": [
[
"22ce013d3aaa7e7574ed01fe1d074cd845dfbbbc5901cabd487d4e2e228274cc",
"dd31fec1c555818c51ab7ccac14ca5b00aae1c33d835c76dfa9406011a92a8e9"
]
]
}
],
"verify_key": "01010101010101010101010101010101"
}

View File

@ -0,0 +1,52 @@
{
"agg_param": null,
"agg_result": [
0,
0,
1,
0
],
"agg_shares": [
"ae5483353eb35a3371beec8f796e9afd086cb72d05a83a3dbefbe273acb0410787b1afba2065df5389011fd8963091e3004fa07fc91018af378da47c89abf1bd",
"51ab7ccac14ca5b08e41137086916504f79348d2fa57c5a641041d8c534fbefa784e5045df9a209076fee02769cf6e1fffb05f8036efe734c8725b8376540e44"
],
"buckets": [
1,
10,
100
],
"prep": [
{
"input_shares": [
"ae5483353eb35a3371beec8f796e9afd086cb72d05a83a3dbefbe273acb0410787b1afba2065df5389011fd8963091e3004fa07fc91018af378da47c89abf1bdfcb36a63271a7dbe47f09d0ea2c6613956dfe44e1302160dd2ade0205aa0409225caf0f966df97691568169000ef0af27c0985636e34889bc3fef4df192d7ead56e0dd51187bdc6662505cbd2962843cf2a1929642367f32058c531a6c611d76441e4ba82d136ba4aab16f2a612df63678d42d527e59d8a0b4cb2f07ed8aaf04199819a25fad1b8cad62fb2ec5a9bd78b2e013a50250c8bd44a15ad7d5edac35a58bed81a4088c72430afbd6fe34635a737cb7c4d29ffc9947b6b0fb8f3fdede9d8bd495b4d47e8400bded8aa53e4a5a2d063c6091c29613e044082b0555ce74c45b823aa8c5804aacdd3dc92a6ac00587557770972dcdc37eefb42eef43a1b401010101010101010101010101010101d5bf864de68bac19204e29697bf9504d",
"0101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101018e2e553b5e45c62e0ced57ec947c8627"
],
"measurement": 50,
"nonce": "01010101010101010101010101010101",
"out_shares": [
[
231724485416847873323492487111470127869,
11198307274976669387765744195748249863,
180368380143069850478496598824148046307,
413446761563421317675646300023681469
],
[
108557881504090589623373286256430638340,
329084059645961793559100029172152516346,
159913986777868612468369174543752719903,
339868920159375041629190127067877084740
]
],
"prep_messages": [
"5b91d376b8ce6a372ca37e85ef85d66a"
],
"prep_shares": [
[
"93b9dd4a3b46d8941fe7524a5cf1cd47ff8ee9c0c2e9b8230b1b940b665263b7b1c4b370652a333ee774ec9cd379b6e78e2e553b5e45c62e0ced57ec947c8627",
"6c4622b5c4b9274fe018adb5a30e32baa310ddd3e5ed87892dca520d1bed7d02998e38190652ba60a225e19211d77e22d5bf864de68bac19204e29697bf9504d"
]
]
}
],
"verify_key": "01010101010101010101010101010101"
}

View File

@ -0,0 +1,39 @@
{
"agg_param": null,
"agg_result": [
100
],
"agg_shares": [
"b6a735c5636efee29c0c1455e0c0f7d8",
"4958ca3a9c91010163f3ebaa1f3f088d"
],
"bits": 8,
"prep": [
{
"input_shares": [
"ae5483353eb35a3371beec8f796e9afd086cb72d05a83a3dbefbe273acb0410787b1afba2065df5389011fd8963091e3004fa07fc91018af378da47c89abf1bd85047e40874e2cdc5f3bc48f363b89f746770a402a777bed31b5a10c7319b3908d72de0c651215ba78d3cf681e07c564c0b4a9a4508df645bad8fef61e3ddf37fcb36a63271a7dbe47f09d0ea2c661396ad006d8915d149ad88f9b1cdb86e1d13d683c359b7ac899a2454316051e4e235dfd566f3459c336826555ed7f1baabf241e9a697d458912f3bd3778225e832b78cd4f17e57c9b9678cf6043894aff0d0f2e06828982ac3493ae5ded0c9886ea13d52bc0f209dc2f4e676c42b95b548a413f67b03ff18e9e6b699338400e9dffa800563abb495364acffc17126bf0bf8ff3c5caba82333e91352e03c637d44dc4db159a1b19d8db4d5a3fce356f6f2fca4adc9bcf65bec8d4d962b2b40f7ea413aead09979d4958707bf4098bb28829b79e381aaad8f69b7f2e6c159bbcb342ef7df2d9c56a906b171ab61b025b7c19aad8de495a8a97af2baab6d1240d30df417d1cc0fe7a90adaad8115924c0987fe1d16abe0c8a3c297d58a3112b818df72a10a41b34aa6b4ae370b1340a6085c8dcd597eead5d2584fdb160f0a086a56ea6a7736666ae34d3012fdb2c24af3d4b2a6ae735edfe837eaab1309eaa2d8273e7dbfe0fd4166d545ce8354e1237b48456715d12e38d02cd64c96b9daa01a2281d8a930817088c648b7c115e1550ada14b6072ada49be3c7e3f184db2461160d29937caa97db6020a5598063f1dc05653d1d380b34e923bd7170eeeb811bfc3ce12c1df55cf552e986a823743fac4723a48bda6c18ffec653c1f182890197e9fc74631dcbd0283c4258933c03aee9404f01010101010101010101010101010101fb0c701f7c07b9407a4a7b77d1ea017e",
"0101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101012d7667bffd0f81b078896503385f6f13"
],
"measurement": 100,
"nonce": "01010101010101010101010101010101",
"out_shares": [
[
242787699414660215404830418280405596120
],
[
97494667506278247542035355087495170189
]
],
"prep_messages": [
"d67a17a0810838f002c31e74e9b56e6d"
],
"prep_shares": [
[
"9f7aca77f790b930b46e8cd786ff1a239aa00e7aaaa734cc2bbcb121eb7c5bc00aef22fd95a24cd1a0054bde0dba06062d7667bffd0f81b078896503385f6f13",
"60853588086f46b34b9173287900e5de2becb1bdb8a8009d2cdc258674f08e8157e3a202a38282c20e220e733ab61e4cfb0c701f7c07b9407a4a7b77d1ea017e"
]
]
}
],
"verify_key": "01010101010101010101010101010101"
}

View File

@ -0,0 +1,31 @@
use prio::{test_vector::Priov2TestVector, util::reconstruct_shares};
// Replay client shares recorded by an earlier release of the crate (see the
// `prio_crate_version` field of the fixture) through the current server
// implementation, then check the reconstructed sum against the reference.
#[test]
fn priov2_backward_compatibility() {
    let vector: Priov2TestVector =
        serde_json::from_str(include_str!("test_vectors/fieldpriov2.json")).unwrap();
    let mut server1 = vector.server_1().unwrap();
    let mut server2 = vector.server_2().unwrap();
    for (share1, share2) in vector.server_1_shares.iter().zip(&vector.server_2_shares) {
        let eval_at = server1.choose_eval_at();
        let msg1 = server1
            .generate_verification_message(eval_at, share1)
            .unwrap();
        let msg2 = server2
            .generate_verification_message(eval_at, share2)
            .unwrap();
        // Each server must accept the recorded share as valid.
        assert!(server1.aggregate(share1, &msg1, &msg2).unwrap());
        assert!(server2.aggregate(share2, &msg1, &msg2).unwrap());
    }
    let reconstructed = reconstruct_shares(server1.total_shares(), server2.total_shares()).unwrap();
    assert_eq!(reconstructed, vector.reference_sum);
}

View File

@ -0,0 +1,31 @@
{
"server_1_private_key": "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVSq3TfyKwn0NrftKisKKVSaTOt5seJ67P5QL4hxgPWvxw==",
"server_2_private_key": "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LER6hPEHg/ObBnRPV1rwS3nj9Bj0tbjVPPyL9p8QW8B+w==",
"dimension": 10,
"server_1_shares": [
"BOLHuvkvD7N3jfTxwrvHq+O7zEJJ++Ar7439Cr1Z7NfEUO1Z6w0LLs8Koro3Ma1ej4EOyrguwrbANRQMeWMIoOsbYCdy2w1GKIZLpLiaW0OsshAJ2LSUOoZZt2bet9mMe4BH+86hBLpncZZGeW335UIz4PCkOXOgK30FrP1pMEqtMoOuxWBkpwuaWqmCKMmdWe/I+Bb9TUL5mwIM+M8jnI0N9Y7XBmRWEeW/m1ji0zL0AndgkPIYyvMfOokGLdxXhAq5HyQ=",
"BFjYcYjJ3HRGHZfIMNNyd5+qpBpoSEIUWOH6PwXrQyT2PpEwVQuX+wA2tW8/KR4cgmg/OmT0YjEPdFsvy7Eemzdsx6jOZMkiC9eCpkQDre796eyl8aZD1LM5IOsk7StKYZHBXKZrIPfRI375WQfCi7LNaUGpa6qBmphLwu+5HSVt+PVqnWIKsl30jZFpt/53DJUbknVOYBj0pMmQnuNPPbpFoq5GZxGQQOEf7GLpaowl8tUgGYOagACSUvbL77h9zvlAlNM=",
"BI7WE02BRvKAV/n/sJ28Nm5vlGokeeDLqEyTRPRi8Ud/4Dy1qfkxbPm7dOWd1hI5zgowYBLjZ2gkQlqLGWUMkZqpBNBjV3+gMTdTrxDtfipcYNTSZPTMD/bDLOZ6yWVG0VlL0+kp1f3TXllHh+efv7Z7fRPp8DHgw7NUifaySlxMw0fQpZUeu4Ey1fR9bS3HYj02ANxu6ssDR2xv9D4CRo3UFG1hX96Plffc+tmI8cnnTYYxi2hmFwZSnlQaa/mCINxzjys=",
"BBXU7/Pr2/U4VflTuDJEAwZUMY0utzhLXr3ihHK6tK/Yfa1vgxqzFTpByy5vX6AzLlczWd2Wl88qMdCvBR/+VFYFAITOrHpcqPd2fCs1zgDPkRUIgInrgdGqRAIzv+7Qa4OW90TyBXs7VknQCoqSmINTt6t0OnY/zk5IGBS6zwjog5uwoz2vD0MjliVZml+laUhjoaLFMxTD6fSSA9aGYT7seIIl8JSngceQynthkIgiDP6/Rdjxs7xb5e/8oy9RVk2LM+0=",
"BKYxAKwiced4bzyxneyxm03z1xIAyYmW/BKB1XgE+dxTO8z8d24nD8s8SW2x2YetzOyPcstxpApN0p53+NwtUEQivnNsNzXQaZ9VXzy93nu3k+WXUm1F0+DRo03Z9XmVb7LItGwXgBIZzjPXpKbXfrCUJqpYHZzDAZrrYdiSFVpnAu+C1uAnhw0qEI/qUhjNC1Woc5ogmY0hqH9cTdcJ2hAd2Q4zRdNCknw0cCFayOsLMx9YYqzs4iQFeBsBPmOoz7Fl6Oc=",
"BNMLUYo2fCeTBi+YFud7E6irH3bxbhVANM46gBQSku871dWjevs/vZ8mafE47SrulTI2f6J9Kfov7Lqt2VxQqtUecJvYeMB9upVT/IF1+D9LmHcXUp1zeF8aWEV+L1GcCX3J+0GSKNoWsMv9nXk3yUZ7sRNdLAyEPdJfbXXQPgHTpA1xLYyRIHkDT8SVnfYVaY1GuQYxr7ubF5wpy4gqmKRnbNF1gHpzO8SbyufiJsFNWdKf7FsEH8Cxi1LNrNUgoNDbDAU=",
"BIvr6t/PfMl6UTm02mwfAYs4qLKkJsQ6H8cLlnQY/uQYr/JbqfhxURNcMX5w3sh1ndojyy1wKCWSoNsOwPeROfe1fuMFED/vUvKKmMl3iJY9dLijaB7ZMFibMvv6VzmULwOh4d44FqAl1ca9fLsVEwHl6Jr6tyrOD2D4nbU+oY/WfNa8yTdxESMEqNao8vlcGj1oYcPveWUMUsXyk3v5ULq34e0VEoFTMk17ku0ZVg1Z4b1E6hT00pCTgX/yHAIiM1+hF1Y=",
"BMt/IxZcFzyGSQiNDtpXgYT8LQv29RTo0ITUEUpnk+3HfrVUm+JFvEQMqo4OQiRMVzG/uVBAEoS3Pz4McT6aISoN4Xo/TXTAgSgrtNlteCJR1IxAwZQOPCJZAwNirUpSyydazaurRKlqUac6C6zTKbeB9XeLOLOeyCIKJ0HtSQQ5/iJIR1dBkDiASRnyVdbQyfnBmQJ6UM3N7JOSqFYi290WquV/afIF5puOrrFJWyewpdUNJ5tYnT7xLIIjtwCvbUCftU8=",
"BKQ+k5o9W1B8PFpFoKKKpmOqjRu79rhvtXvgAFq1AwP2kW9qlmb3DpTxy+JTXCcWleTSOUn8720BaWIV51Pp5ygmqXPNUetBwYiGozoPKq1yxTJnhd4FWdCXjjYF8FHp/yL+VbAN6BMZqcIwqv+bEhwo6985irKDVHdKC58Y7HQF9vgifcGVrWeLwRe1r+CMNYPMSF4qG8IGxyddVu+2KaOI5n1IwIxiDAWlpwtKJfXRDqJFz65WyTMkF6OmMmcJX9IrAwE=",
"BPzpKXqufT5fAelFMXmUXt6ViHQVMuIpLJnVsuFQozqWdMKsC8MPtT3tcOCxmfeV6fqsNtooIu1XZ3u6SDxfr4iTnpv53lJkINnVvZZmSxvP5vBa4Lzb//eeiI+QoWK4AP3e2ylNJwZ3tv3L2e60/92nGX91AhaXoCPj8cfvpPVyNiQdudRqW7YKaV27UqhuY0ulDoTgSLJ+p04ahWTamllcf1LmWxyIZEhZkl8yZbaZm4Ickw9q30wEQWRy+VpHwRFfR4w="
],
"server_2_shares": [
"BLTSiXCakcw7H0V2F27Px30DPhIM4L/OHlaSU3UN8cpGXMUjXS9sL4iJVAjVSijokWZMRHJAcWfj0bJoynhE752/EowzDYo0B26oIrsyWYiMb7+BCRLQZ64uWOYJEjGCTHw82/mPrOFdE/nZ2UPkUSU=",
"BJVHFQQ/8JAs+X5DDbcv6G861NXEoWiV65Y/b1qcy7djZgDChwhPxYLtuXxJm6P7haPf5RHz4519t8fwkcnM/zV2SV5Fo1VMR1tmnIaNAfl1jP/XWsiDc7RtIMzvb5bEQQkqvVKL5SZsT2QW0RvRG6I=",
"BCWqs+VtNcfHs/BvTxA1BfezliSqILy47wqYoc2CLJF/VZkzh/Hl4gucDN0xXRsVxL+eoY9u/OpObEw0sSdMitIDolLzs92xc6qn9ch90q0+XkjP9VrtGOjqcKezPeGXLVrGhiAAfxQWpv5AEArBmgQ=",
"BIpENXCZMfvLpehnTIN4gyu4i9tJAuATOSddNzVSTYXt1S/LGkVlhy9nqF/mW6kkwK9t85+HQbuno45JuXimwK4onuCc1PUp6N+oLM8tShopAvN6xLUuGeVokSB9YaRV5DahwRzOy2cTx1CSCzzaERk=",
"BKrtC8DSg66NcWVnQVNw7fs15jcFv8bqEuHaROa939rvZoQbGJ9aYdNLi+N6NyVaWvk+y4CamwGri9UUDmdKxDQJ55xxWUAAmcw71K5dMV8wQoZVen97K9//Ti3CSo/12Rx3hFM1fj3nkfb/+LTTsfE=",
"BNEzVoHGjJA7hKskUqMJQ8rD4L7xUwNg1zGvs49pJ9OR/xRgQFMdWWKY9yZ1Z4ZMysf6RZRtCGmLefmP1/6sS/hP/jXIP2lVjwdqKesGsHj20FF4m/5KPRZrTPDCkRxdWddvbVsqFVMF15bJs9Say1Q=",
"BP615a1p0EoVSEXIXAOj9xwScaMkhM+h2774s6dknI/tFebXFnWou2wPhAvMeiZZHwkc8bnnCfSc7Ah3S00sS4VUAqSGpijl0Ghy1dIo3b7mEE7BKTcBso8xwmlp1SgY85F/i49nIE5zWZx7e0ZpzaM=",
"BIS7+DaCGddzr+576NfDYV6TWmkGZUidxUdHgVLs7wsGdAaSO8D4eAhQKg5o1bJy+p2NRVMrfc6BjgBEejTz6FixWQuYcBygbHturW8EnIm/1AUU8YLbz1cHZddjxqBUmIapQg5xMrsDggZH3xIoZzQ=",
"BDWpotQorKrwzmMMwIJQcVqHT+cPkGa1QPZM9W5HFgqkURgIuJhbaAkOk7oix84fVCxOxoO+/spY2vUNn8r+0CQu3g21EcL9AzZwtEjZHjXVLZ6xLWz3v+eFB8yLT635CXisXx7wXeWUd6jp44MunqA=",
"BM0/op40IsjhByCx8+1oXTEe2DglHUNatUHDzR/SDIlkkA0yfpK1D345YE+83yOJmuQO1wKa7TrMeqPfTdaoYAwUkibWmisTdKQd4mljghwf7ECWh5Vy6kwfu0KMBtHbvJj9ti55+kWWOqwNtLzDL1I="
],
"reference_sum": "BgAAAAcAAAAEAAAABQAAAAkAAAAGAAAAAgAAAAYAAAAEAAAABQAAAA==",
"prio_crate_version": "0.5.0"
}

View File

@ -8,3 +8,4 @@ authors = [
license = "MPL-2.0"
[dependencies]
prio = { version = "0.8.4", default-features = false }