From fc8b78002dd571913838dc395570a95719baa04f Mon Sep 17 00:00:00 2001 From: xdustinface Date: Fri, 6 Aug 2021 16:14:38 +0200 Subject: [PATCH] python-bindings|prover_disk|tests: Add `pickle` support for `DiskProver` --- python-bindings/chiapos.cpp | 22 ++++++++++++++++++++++ src/prover_disk.hpp | 30 ++++++++++++++++++++++++++++-- tests/test_python_bindings.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/python-bindings/chiapos.cpp b/python-bindings/chiapos.cpp index aeb7dd44a..4b72937d2 100644 --- a/python-bindings/chiapos.cpp +++ b/python-bindings/chiapos.cpp @@ -77,6 +77,28 @@ PYBIND11_MODULE(chiapos, m) py::class_(m, "DiskProver") .def(py::init()) + .def(py::pickle( + [](const DiskProver &dp) { // __getstate__ + return py::make_tuple(dp.GetFilename(), + dp.GetSize(), + dp.GetMemo(), + dp.GetId(), + dp.GetTableBeginPointers(), + dp.GetC2()); + }, + [](const py::tuple& t) { // __setstate__ + if (t.size() != 6) + throw std::runtime_error("Invalid state!"); + + auto filename = t[0].cast(); + auto k = t[1].cast(); + auto memo = t[2].cast>(); + auto id = t[3].cast>(); + auto table_begin_pointers = t[4].cast>(); + auto C2 = t[5].cast>(); + return DiskProver(filename, memo, id, k, + table_begin_pointers, C2); + })) .def( "get_memo", [](DiskProver &dp) { diff --git a/src/prover_disk.hpp b/src/prover_disk.hpp index a509a7ede..a158e3455 100644 --- a/src/prover_disk.hpp +++ b/src/prover_disk.hpp @@ -117,6 +117,28 @@ class DiskProver { delete[] c2_buf; } + DiskProver(std::string filename, std::vector memo, std::vector id, uint8_t k, + std::vector table_begin_pointers, std::vector C2) : + filename(std::move(filename)), + memo(std::move(memo)), + id(std::move(id)), + k(k), + table_begin_pointers(std::move(table_begin_pointers)), + C2(std::move(C2)) + { + } + + DiskProver(DiskProver&& other) noexcept + { + std::lock_guard lock(other._mtx); + filename = std::move(other.filename); + memo = std::move(other.memo); + id = std::move(other.id); + k = other.k; + table_begin_pointers = std::move(other.table_begin_pointers); + C2 = std::move(other.C2); + } + ~DiskProver() { std::lock_guard l(_mtx); @@ -126,9 +148,13 @@ class DiskProver { Encoding::ANSFree(kC3R); } - std::vector GetMemo() { return memo; } + std::vector GetMemo() const { return memo; } + + std::vector GetId() const { return id; } + + std::vector GetTableBeginPointers() const { return table_begin_pointers; } - std::vector GetId() { return id; } + std::vector GetC2() const { return C2; } std::string GetFilename() const noexcept { return filename; } diff --git a/tests/test_python_bindings.py b/tests/test_python_bindings.py index a04c696be..790e44d75 100644 --- a/tests/test_python_bindings.py +++ b/tests/test_python_bindings.py @@ -1,4 +1,6 @@ import unittest +import pickle + from chiapos import DiskProver, DiskPlotter, Verifier from hashlib import sha256 from pathlib import Path @@ -176,6 +178,34 @@ def test_faulty_plot_doesnt_crash(self): print(f"Successes: {successes}") print(f"Failures: {failures}") + def test_pickle_support(self): + if not Path("test_plot.dat").exists(): + plot_id: bytes = bytes([i for i in range(0, 32)]) + pl = DiskPlotter() + pl.create_plot_disk( + ".", + ".", + ".", + "test_plot.dat", + 21, + bytes([1, 2, 3, 4, 5]), + plot_id, + 300, + 32, + 8192, + 8, + False, + ) + + prover1: DiskProver = DiskProver(str(Path("test_plot.dat"))) + data = pickle.dumps(prover1) + prover_recovered: DiskProver = pickle.loads(data) + assert prover1.get_size() == prover_recovered.get_size() + assert prover1.get_filename() == prover_recovered.get_filename() + assert prover1.get_id() == prover_recovered.get_id() + assert prover1.get_memo() == prover_recovered.get_memo() + Path("test_plot.dat").unlink() + if __name__ == "__main__": unittest.main()