Skip to content

Commit

Permalink
python-bindings|prover_disk|tests: Add pickle support for DiskProver
Browse files Browse the repository at this point in the history
  • Loading branch information
xdustinface committed Aug 10, 2021
1 parent c93d559 commit 44494f8
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 2 deletions.
22 changes: 22 additions & 0 deletions python-bindings/chiapos.cpp
Expand Up @@ -77,6 +77,28 @@ PYBIND11_MODULE(chiapos, m)

py::class_<DiskProver>(m, "DiskProver")
.def(py::init<const std::string &>())
.def(py::pickle(
[](const DiskProver &dp) { // __getstate__
return py::make_tuple(dp.GetFilename(),
dp.GetSize(),
dp.GetMemo(),
dp.GetId(),
dp.GetTableBeginPointers(),
dp.GetC2());
},
[](const py::tuple& t) { // __setstate__
if (t.size() != 6)
throw std::runtime_error("Invalid state!");

auto filename = t[0].cast<std::string>();
auto k = t[1].cast<uint8_t>();
auto memo = t[2].cast<std::vector<uint8_t>>();
auto id = t[3].cast<std::vector<uint8_t>>();
auto table_begin_pointers = t[4].cast<std::vector<uint64_t>>();
auto C2 = t[5].cast<std::vector<uint64_t>>();
return DiskProver(filename, memo, id, k,
table_begin_pointers, C2);
}))
.def(
"get_memo",
[](DiskProver &dp) {
Expand Down
45 changes: 43 additions & 2 deletions src/prover_disk.hpp
Expand Up @@ -118,6 +118,43 @@ class DiskProver {
delete[] c2_buf;
}

// Note: This constructor presumes the input parameter are valid for the provided file and does
// not validate the file itself.
DiskProver(std::string filename, std::vector<uint8_t> memo, std::vector<uint8_t> id, uint8_t k,
std::vector<uint64_t> table_begin_pointers, std::vector<uint64_t> C2) :
filename(std::move(filename)),
memo(std::move(memo)),
id(std::move(id)),
k(k),
table_begin_pointers(std::move(table_begin_pointers)),
C2(std::move(C2))
{
if (this->id.size() != kIdLen) {
throw std::invalid_argument("Invalid id size: " + std::to_string(this->id.size()));
}
if (k < kMinPlotSize || k > kMaxPlotSize) {
throw std::invalid_argument("Invalid k: " + std::to_string(k));
}
if (this->table_begin_pointers.size() != nTableBeginPointerCount) {
throw std::invalid_argument("Invalid table_begin_pointers size: " + std::to_string(this->table_begin_pointers.size()));
}
uint32_t nExpectedSize = ((this->table_begin_pointers[10] - this->table_begin_pointers[9]) / (Util::ByteAlign(k) / 8)) - 1;
if (this->C2.size() != nExpectedSize) {
throw std::invalid_argument("Invalid C2 size: " + std::to_string(this->C2.size()));
}
}

DiskProver(DiskProver&& other) noexcept
{
std::lock_guard<std::mutex> lock(other._mtx);
filename = std::move(other.filename);
memo = std::move(other.memo);
id = std::move(other.id);
k = other.k;
table_begin_pointers = std::move(other.table_begin_pointers);
C2 = std::move(other.C2);
}

~DiskProver()
{
std::lock_guard<std::mutex> l(_mtx);
Expand All @@ -127,9 +164,13 @@ class DiskProver {
Encoding::ANSFree(kC3R);
}

const std::vector<uint8_t>& GetMemo() { return memo; }
const std::vector<uint8_t>& GetMemo() const { return memo; }

const std::vector<uint8_t>& GetId() const { return id; }

const std::vector<uint64_t>& GetTableBeginPointers() const { return table_begin_pointers; }

const std::vector<uint8_t>& GetId() { return id; }
const std::vector<uint64_t>& GetC2() const { return C2; }

std::string GetFilename() const noexcept { return filename; }

Expand Down
63 changes: 63 additions & 0 deletions tests/test.cpp
Expand Up @@ -1026,3 +1026,66 @@ TEST_CASE("FilteredDisk")
*/
remove("test_file.bin");
}

TEST_CASE("DiskProver")
{
SECTION("Construction")
{
std::string filename = "prover_test.plot";
DiskPlotter plotter = DiskPlotter();
std::vector<uint8_t> memo{1, 2, 3};
plotter.CreatePlotDisk(
".", ".", ".", filename, 18, memo.data(),
memo.size(), plot_id_1, 32, 11, 0,
4000, 2);
DiskProver prover1(filename);
std::string p1_filename = prover1.GetFilename();
std::vector<uint8_t> p1_id = prover1.GetId();
std::vector<uint8_t> p1_memo = prover1.GetMemo();
uint8_t p1_k = prover1.GetSize();
std::vector<uint64_t> p1_pointers = prover1.GetTableBeginPointers();
std::vector<uint64_t> p1_C2 = prover1.GetC2();
DiskProver prover2(p1_filename,
p1_memo,
p1_id,
p1_k,
p1_pointers,
p1_C2);
REQUIRE(prover1.GetFilename() == prover2.GetFilename());
REQUIRE(prover1.GetSize() == prover2.GetSize());
REQUIRE(prover1.GetId() == prover2.GetId());
REQUIRE(prover1.GetMemo() == prover2.GetMemo());
vector<unsigned char> hash_input = intToBytes(0, 4);
vector<unsigned char> hash(picosha2::k_digest_size);
picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
vector<LargeBits> qualities1 = prover1.GetQualitiesForChallenge(hash.data());
LargeBits proof1 = prover1.GetFullProof(hash.data(), 0);
vector<LargeBits> qualities2 = prover2.GetQualitiesForChallenge(hash.data());
LargeBits proof2 = prover2.GetFullProof(hash.data(), 0);
REQUIRE(qualities1 == qualities2);
REQUIRE(proof1 == proof2);
// Test "Invalid id"
REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, {}, p1_k, p1_pointers, p1_C2));
REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, std::vector<uint8_t>(kIdLen - 1, 0), p1_k, p1_pointers, p1_C2));
REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, std::vector<uint8_t>(kIdLen + 1, 0), p1_k, p1_pointers, p1_C2));
// Test "Invalid k"
REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, 0, p1_pointers, p1_C2));
REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, kMinPlotSize - 1, p1_pointers, p1_C2));
REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, kMaxPlotSize + 1, p1_pointers, p1_C2));
// Test "Invalid table_begin_pointers size"
REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, {}, p1_C2));
auto invalid_pointers = std::vector<uint64_t>(p1_pointers.begin(), p1_pointers.end() - 1);
REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, invalid_pointers, p1_C2));
invalid_pointers.push_back(p1_pointers.back());
invalid_pointers.push_back(p1_pointers.back());
REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, invalid_pointers, p1_C2));
// Test "Invalid C2 size"
REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, {}, p1_C2));
auto invalid_c2 = std::vector<uint64_t>(p1_C2.begin(), p1_C2.end() - 1);
REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, p1_pointers, invalid_c2));
invalid_c2.push_back(p1_C2.back());
invalid_c2.push_back(p1_C2.back());
REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, p1_pointers, invalid_c2));
REQUIRE(remove(filename.c_str()) == 0);
}
}
30 changes: 30 additions & 0 deletions tests/test_python_bindings.py
@@ -1,4 +1,6 @@
import unittest
import pickle

from chiapos import DiskProver, DiskPlotter, Verifier
from hashlib import sha256
from pathlib import Path
Expand Down Expand Up @@ -176,6 +178,34 @@ def test_faulty_plot_doesnt_crash(self):
print(f"Successes: {successes}")
print(f"Failures: {failures}")

def test_pickle_support(self):
if not Path("test_plot.dat").exists():
plot_id: bytes = bytes([i for i in range(0, 32)])
pl = DiskPlotter()
pl.create_plot_disk(
".",
".",
".",
"test_plot.dat",
21,
bytes([1, 2, 3, 4, 5]),
plot_id,
300,
32,
8192,
8,
False,
)

prover1: DiskProver = DiskProver(str(Path("test_plot.dat")))
data = pickle.dumps(prover1)
prover_recovered: DiskProver = pickle.loads(data)
assert prover1.get_size() == prover_recovered.get_size()
assert prover1.get_filename() == prover_recovered.get_filename()
assert prover1.get_id() == prover_recovered.get_id()
assert prover1.get_memo() == prover_recovered.get_memo()
Path("test_plot.dat").unlink()


if __name__ == "__main__":
unittest.main()

0 comments on commit 44494f8

Please sign in to comment.