Skip to content

Commit

Permalink
python-bindings|prover_disk|tests: Add pickle support for DiskProver
Browse files Browse the repository at this point in the history
  • Loading branch information
xdustinface committed Aug 9, 2021
1 parent 5ba8447 commit fc8b780
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 2 deletions.
22 changes: 22 additions & 0 deletions python-bindings/chiapos.cpp
Expand Up @@ -77,6 +77,28 @@ PYBIND11_MODULE(chiapos, m)

py::class_<DiskProver>(m, "DiskProver")
.def(py::init<const std::string &>())
.def(py::pickle(
[](const DiskProver &dp) { // __getstate__
return py::make_tuple(dp.GetFilename(),
dp.GetSize(),
dp.GetMemo(),
dp.GetId(),
dp.GetTableBeginPointers(),
dp.GetC2());
},
[](const py::tuple& t) { // __setstate__
if (t.size() != 6)
throw std::runtime_error("Invalid state!");

auto filename = t[0].cast<std::string>();
auto k = t[1].cast<uint8_t>();
auto memo = t[2].cast<std::vector<uint8_t>>();
auto id = t[3].cast<std::vector<uint8_t>>();
auto table_begin_pointers = t[4].cast<std::vector<uint64_t>>();
auto C2 = t[5].cast<std::vector<uint64_t>>();
return DiskProver(filename, memo, id, k,
table_begin_pointers, C2);
}))
.def(
"get_memo",
[](DiskProver &dp) {
Expand Down
30 changes: 28 additions & 2 deletions src/prover_disk.hpp
Expand Up @@ -117,6 +117,28 @@ class DiskProver {
delete[] c2_buf;
}

DiskProver(std::string filename, std::vector<uint8_t> memo, std::vector<uint8_t> id, uint8_t k,
std::vector<uint64_t> table_begin_pointers, std::vector<uint64_t> C2) :
filename(std::move(filename)),
memo(std::move(memo)),
id(std::move(id)),
k(k),
table_begin_pointers(std::move(table_begin_pointers)),
C2(std::move(C2))
{
}

DiskProver(DiskProver&& other) noexcept
{
std::lock_guard<std::mutex> lock(other._mtx);
filename = std::move(other.filename);
memo = std::move(other.memo);
id = std::move(other.id);
k = other.k;
table_begin_pointers = std::move(other.table_begin_pointers);
C2 = std::move(other.C2);
}

~DiskProver()
{
std::lock_guard<std::mutex> l(_mtx);
Expand All @@ -126,9 +148,13 @@ class DiskProver {
Encoding::ANSFree(kC3R);
}

std::vector<uint8_t> GetMemo() { return memo; }
std::vector<uint8_t> GetMemo() const { return memo; }

std::vector<uint8_t> GetId() const { return id; }

std::vector<uint64_t> GetTableBeginPointers() const { return table_begin_pointers; }

std::vector<uint8_t> GetId() { return id; }
std::vector<uint64_t> GetC2() const { return C2; }

std::string GetFilename() const noexcept { return filename; }

Expand Down
30 changes: 30 additions & 0 deletions tests/test_python_bindings.py
@@ -1,4 +1,6 @@
import unittest
import pickle

from chiapos import DiskProver, DiskPlotter, Verifier
from hashlib import sha256
from pathlib import Path
Expand Down Expand Up @@ -176,6 +178,34 @@ def test_faulty_plot_doesnt_crash(self):
print(f"Successes: {successes}")
print(f"Failures: {failures}")

def test_pickle_support(self):
if not Path("test_plot.dat").exists():
plot_id: bytes = bytes([i for i in range(0, 32)])
pl = DiskPlotter()
pl.create_plot_disk(
".",
".",
".",
"test_plot.dat",
21,
bytes([1, 2, 3, 4, 5]),
plot_id,
300,
32,
8192,
8,
False,
)

prover1: DiskProver = DiskProver(str(Path("test_plot.dat")))
data = pickle.dumps(prover1)
prover_recovered: DiskProver = pickle.loads(data)
assert prover1.get_size() == prover_recovered.get_size()
assert prover1.get_filename() == prover_recovered.get_filename()
assert prover1.get_id() == prover_recovered.get_id()
assert prover1.get_memo() == prover_recovered.get_memo()
Path("test_plot.dat").unlink()


if __name__ == "__main__":
unittest.main()

0 comments on commit fc8b780

Please sign in to comment.