diff --git a/python-bindings/chiapos.cpp b/python-bindings/chiapos.cpp
index b887d9902..8c0824dd4 100644
--- a/python-bindings/chiapos.cpp
+++ b/python-bindings/chiapos.cpp
@@ -77,6 +77,29 @@ PYBIND11_MODULE(chiapos, m)
 
     py::class_<DiskProver>(m, "DiskProver")
         .def(py::init<const std::string &>())
+        .def("is_valid", [](const DiskProver &dp) { return dp.IsValid(); })
+        .def(py::pickle(
+            [](const DiskProver& dp) { // __getstate__
+                return py::make_tuple(dp.GetFilename(),
+                                      dp.GetSize(),
+                                      dp.GetMemo(),
+                                      dp.GetId(),
+                                      dp.GetTableBeginPointers(),
+                                      dp.GetC2());
+            },
+            [](const py::tuple& t) { // __setstate__
+                if (t.size() != 6)
+                    throw std::runtime_error("Invalid state!");
+
+                auto filename = t[0].cast<std::string>();
+                auto k = t[1].cast<uint8_t>();
+                auto memo = t[2].cast<std::vector<uint8_t>>();
+                auto id = t[3].cast<std::vector<uint8_t>>();
+                auto table_begin_pointers = t[4].cast<std::vector<uint64_t>>();
+                auto C2 = t[5].cast<std::vector<uint64_t>>();
+                return DiskProver(filename, memo, id, k,
+                                  table_begin_pointers, C2);
+            }))
         .def(
             "get_memo",
             [](DiskProver &dp) {
diff --git a/src/prover_disk.hpp b/src/prover_disk.hpp
index 4d1191645..4ae638b9f 100644
--- a/src/prover_disk.hpp
+++ b/src/prover_disk.hpp
@@ -47,74 +47,57 @@ struct plot_header {
 // The DiskProver, given a correctly formatted plot file, can efficiently generate valid proofs
 // of space, for a given challenge.
 class DiskProver {
+    // Number of slots in table_begin_pointers; Read() leaves slot 0 at zero and fills 1..10.
+    static const size_t table_begin_pointers_size{11};
 public:
     // The constructor opens the file, and reads the contents of the file header. The table pointers
     // will be used to find and seek to all seven tables, at the time of proving.
     explicit DiskProver(const std::string& filename) : id(kIdLen)
     {
-        struct plot_header header{};
         this->filename = filename;
+        Read(id, memo, k, table_begin_pointers, C2);
+    }
 
-        std::ifstream disk_file(filename, std::ios::in | std::ios::binary);
-
-        if (!disk_file.is_open()) {
-            throw std::invalid_argument("Invalid file " + filename);
+    // Note: This constructor presumes the input parameters are valid for the provided file and does
+    // not validate the file itself. It's recommended to use `IsValid` after construction to make
+    // sure the given file matches expectations.
+    // Throws std::invalid_argument if the id length, k, the number of table pointers or the number
+    // of C2 entries is inconsistent.
+    DiskProver(std::string filename, std::vector<uint8_t> memo, std::vector<uint8_t> id, uint8_t k,
+               std::vector<uint64_t> table_begin_pointers, std::vector<uint64_t> C2) :
+        filename(std::move(filename)),
+        memo(std::move(memo)),
+        id(std::move(id)),
+        k(k),
+        table_begin_pointers(std::move(table_begin_pointers)),
+        C2(std::move(C2))
+    {
+        if (this->id.size() != kIdLen) {
+            throw std::invalid_argument("Invalid id size: " + std::to_string(this->id.size()));
         }
-        // 19 bytes  - "Proof of Space Plot" (utf-8)
-        // 32 bytes  - unique plot id
-        // 1 byte    - k
-        // 2 bytes   - format description length
-        // x bytes   - format description
-        // 2 bytes   - memo length
-        // x bytes   - memo
-
-        SafeRead(disk_file, (uint8_t*)&header, sizeof(header));
-        if (memcmp(header.magic, "Proof of Space Plot", sizeof(header.magic)) != 0)
-            throw std::invalid_argument("Invalid plot header magic");
-
-        uint16_t fmt_desc_len = Util::TwoBytesToInt(header.fmt_desc_len);
-
-        if (fmt_desc_len == kFormatDescription.size() &&
-            !memcmp(header.fmt_desc, kFormatDescription.c_str(), fmt_desc_len)) {
-            // OK
-        } else {
-            throw std::invalid_argument("Invalid plot file format");
+        if (k < kMinPlotSize || k > kMaxPlotSize) {
+            throw std::invalid_argument("Invalid k: " + std::to_string(k));
         }
-        memcpy(id.data(), header.id, sizeof(header.id));
-        this->k = header.k;
-        SafeSeek(disk_file, offsetof(struct plot_header, fmt_desc) + fmt_desc_len);
-
-        uint8_t size_buf[2];
-        SafeRead(disk_file, size_buf, 2);
-        memo.resize(Util::TwoBytesToInt(size_buf));
-        SafeRead(disk_file, memo.data(), memo.size());
-
-        this->table_begin_pointers = std::vector<uint64_t>(11, 0);
-        this->C2 = std::vector<uint64_t>();
-
-        uint8_t pointer_buf[8];
-        for (uint8_t i = 1; i < 11; i++) {
-            SafeRead(disk_file, pointer_buf, 8);
-            this->table_begin_pointers[i] = Util::EightBytesToInt(pointer_buf);
-        }
-
-        SafeSeek(disk_file, table_begin_pointers[9]);
-
-        uint8_t c2_size = (Util::ByteAlign(k) / 8);
-        uint32_t c2_entries = (table_begin_pointers[10] - table_begin_pointers[9]) / c2_size;
-        if (c2_entries == 0 || c2_entries == 1) {
-            throw std::invalid_argument("Invalid C2 table size");
+        if (this->table_begin_pointers.size() != table_begin_pointers_size) {
+            throw std::invalid_argument("Invalid table_begin_pointers size: " + std::to_string(this->table_begin_pointers.size()));
         }
-
-        // The list of C2 entries is small enough to keep in memory. When proving, we can
-        // read from disk the C1 and C3 entries.
-        auto* c2_buf = new uint8_t[c2_size];
-        for (uint32_t i = 0; i < c2_entries - 1; i++) {
-            SafeRead(disk_file, c2_buf, c2_size);
-            this->C2.push_back(Bits(c2_buf, c2_size, c2_size * 8).Slice(0, k).GetValue());
+        const uint32_t expected_size = ((this->table_begin_pointers[10] - this->table_begin_pointers[9]) / (Util::ByteAlign(k) / 8)) - 1;
+        if (this->C2.size() != expected_size) {
+            throw std::invalid_argument("Invalid C2 size: " + std::to_string(this->C2.size()));
         }
+    }
 
-        delete[] c2_buf;
+    // Move constructor: takes over other's filename, memo, id, table pointers and C2 while
+    // holding other's mutex. k is copied, so the moved-from prover keeps its k.
+    DiskProver(DiskProver&& other) noexcept
+    {
+        std::lock_guard<std::mutex> lock(other._mtx);
+        filename = std::move(other.filename);
+        memo = std::move(other.memo);
+        id = std::move(other.id);
+        k = other.k;
+        table_begin_pointers = std::move(other.table_begin_pointers);
+        C2 = std::move(other.C2);
     }
 
     ~DiskProver()
@@ -126,14 +109,54 @@ class DiskProver {
         Encoding::ANSFree(kC3R);
     }
 
-    const std::vector<uint8_t>& GetMemo() { return memo; }
+    const std::vector<uint8_t>& GetMemo() const { return memo; }
 
-    const std::vector<uint8_t>& GetId() { return id; }
+    const std::vector<uint8_t>& GetId() const { return id; }
+
+    const std::vector<uint64_t>& GetTableBeginPointers() const noexcept { return table_begin_pointers; }
 
-    std::string GetFilename() const noexcept { return filename; }
+    const std::vector<uint64_t>& GetC2() const noexcept { return C2; }
+
+    const std::string& GetFilename() const noexcept { return filename; }
 
     uint8_t GetSize() const noexcept { return k; }
 
+    // Re-reads the plot file via Read() and compares the result with the stored state. Throws
+    // std::invalid_argument on the first mismatch; exceptions raised by Read() propagate.
+    void Validate() const
+    {
+        uint8_t k_out;
+        std::vector<uint8_t> id_out, memo_out;
+        std::vector<uint64_t> table_begin_pointers_out, C2_out;
+        Read(id_out, memo_out, k_out, table_begin_pointers_out, C2_out);
+        if (id != id_out) {
+            throw std::invalid_argument("EnsureValid: Invalid id");
+        }
+        if (memo != memo_out) {
+            throw std::invalid_argument("EnsureValid: Invalid memo");
+        }
+        if (k != k_out) {
+            throw std::invalid_argument("EnsureValid: Invalid k");
+        }
+        if (table_begin_pointers != table_begin_pointers_out) {
+            throw std::invalid_argument("EnsureValid: Invalid table_begin_pointers");
+        }
+        if (C2 != C2_out) {
+            throw std::invalid_argument("EnsureValid: Invalid C2");
+        }
+    }
+
+    // Non-throwing wrapper around Validate(): true if it succeeds, false on any exception.
+    bool IsValid() const
+    {
+        try {
+            Validate();
+        } catch (...) {
+            return false;
+        }
+        return true;
+    }
+
     // Given a challenge, returns a quality string, which is sha256(challenge + 2 adjecent x
     // values), from the 64 value proof. Note that this is more efficient than fetching all 64 x
     // values, which are in different parts of the disk.
@@ -238,10 +261,85 @@ class DiskProver {
     std::string filename;
     std::vector<uint8_t> memo;
     std::vector<uint8_t> id;  // Unique plot id
-    uint8_t k;
+    uint8_t k{0};
     std::vector<uint64_t> table_begin_pointers;
     std::vector<uint64_t> C2;
 
+    // Parses the plot file named by `filename` into the output parameters: header fields (id, k,
+    // memo), the table begin pointers and the C2 table. Holds _mtx while reading and throws
+    // std::invalid_argument on an unopenable file, bad magic/format or bad C2 table size.
+    void Read(std::vector<uint8_t>& id_out,
+              std::vector<uint8_t>& memo_out,
+              uint8_t& k_out,
+              std::vector<uint64_t>& table_begin_pointers_out,
+              std::vector<uint64_t>& C2_out) const
+    {
+        std::lock_guard<std::mutex> lock(_mtx);
+
+        std::ifstream disk_file(filename, std::ios::in | std::ios::binary);
+
+        if (!disk_file.is_open()) {
+            throw std::invalid_argument("Invalid file " + filename);
+        }
+        struct plot_header header{};
+        // 19 bytes  - "Proof of Space Plot" (utf-8)
+        // 32 bytes  - unique plot id
+        // 1 byte    - k
+        // 2 bytes   - format description length
+        // x bytes   - format description
+        // 2 bytes   - memo length
+        // x bytes   - memo
+
+        SafeRead(disk_file, (uint8_t*)&header, sizeof(header));
+        if (memcmp(header.magic, "Proof of Space Plot", sizeof(header.magic)) != 0)
+            throw std::invalid_argument("Invalid plot header magic");
+
+        uint16_t fmt_desc_len = Util::TwoBytesToInt(header.fmt_desc_len);
+
+        if (fmt_desc_len == kFormatDescription.size() &&
+            !memcmp(header.fmt_desc, kFormatDescription.c_str(), fmt_desc_len)) {
+            // OK
+        } else {
+            throw std::invalid_argument("Invalid plot file format");
+        }
+        id_out = std::vector<uint8_t>(kIdLen);
+        memcpy(id_out.data(), header.id, sizeof(header.id));
+        k_out = header.k;
+        SafeSeek(disk_file, offsetof(struct plot_header, fmt_desc) + fmt_desc_len);
+
+        uint8_t size_buf[2];
+        SafeRead(disk_file, size_buf, 2);
+        memo_out.resize(Util::TwoBytesToInt(size_buf));
+        SafeRead(disk_file, memo_out.data(), memo_out.size());
+
+        table_begin_pointers_out = std::vector<uint64_t>(table_begin_pointers_size, 0);
+        C2_out = std::vector<uint64_t>();
+
+        uint8_t pointer_buf[8];
+        for (size_t i = 1; i < table_begin_pointers_size; i++) {
+            SafeRead(disk_file, pointer_buf, 8);
+            table_begin_pointers_out[i] = Util::EightBytesToInt(pointer_buf);
+        }
+
+        SafeSeek(disk_file, table_begin_pointers_out[9]);
+
+        uint8_t c2_size = (Util::ByteAlign(k_out) / 8);
+        uint32_t c2_entries = (table_begin_pointers_out[10] - table_begin_pointers_out[9]) / c2_size;
+        if (c2_entries == 0 || c2_entries == 1) {
+            throw std::invalid_argument("Invalid C2 table size");
+        }
+
+        // The list of C2 entries is small enough to keep in memory. When proving, we can
+        // read from disk the C1 and C3 entries.
+        std::vector<uint8_t> c2_buf(c2_size);
+        for (uint32_t i = 0; i < c2_entries - 1; i++) {
+            SafeRead(disk_file, c2_buf.data(), c2_size);
+            // Slice with k_out (the k just read from this file), not the member k, which
+            // Validate() has not yet confirmed to match the file.
+            C2_out.push_back(Bits(c2_buf.data(), c2_size, c2_size * 8).Slice(0, k_out).GetValue());
+        }
+    }
+
     // Using this method instead of simply seeking will prevent segfaults that would arise when
     // continuing the process of looking up qualities.
     static void SafeSeek(std::ifstream& disk_file, uint64_t seek_location) {
diff --git a/tests/test.cpp b/tests/test.cpp
index 50d6f6b01..0475913a4 100644
--- a/tests/test.cpp
+++ b/tests/test.cpp
@@ -1026,3 +1026,105 @@ TEST_CASE("FilteredDisk")
     */
     remove("test_file.bin");
 }
+
+TEST_CASE("DiskProver")
+{
+    SECTION("Construction")
+    {
+        std::string filename = "prover_test.plot";
+        DiskPlotter plotter = DiskPlotter();
+        std::vector<uint8_t> memo{1, 2, 3};
+        plotter.CreatePlotDisk(
+                ".", ".", ".", filename, 18, memo.data(),
+                memo.size(), plot_id_1, 32, 11, 0,
+                4000, 2);
+        DiskProver prover1(filename);
+        std::string p1_filename = prover1.GetFilename();
+        std::vector<uint8_t> p1_id = prover1.GetId();
+        std::vector<uint8_t> p1_memo = prover1.GetMemo();
+        uint8_t p1_k = prover1.GetSize();
+        std::vector<uint64_t> p1_pointers = prover1.GetTableBeginPointers();
+        std::vector<uint64_t> p1_C2 = prover1.GetC2();
+        DiskProver prover2(p1_filename,
+                           p1_memo,
+                           p1_id,
+                           p1_k,
+                           p1_pointers,
+                           p1_C2);
+        REQUIRE(prover1.GetFilename() == prover2.GetFilename());
+        REQUIRE(prover1.GetSize() == prover2.GetSize());
+        REQUIRE(prover1.GetId() == prover2.GetId());
+        REQUIRE(prover1.GetMemo() == prover2.GetMemo());
+        vector<unsigned char> hash_input = intToBytes(0, 4);
+        vector<unsigned char> hash(picosha2::k_digest_size);
+        picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
+        vector<LargeBits> qualities1 = prover1.GetQualitiesForChallenge(hash.data());
+        LargeBits proof1 = prover1.GetFullProof(hash.data(), 0);
+        vector<LargeBits> qualities2 = prover2.GetQualitiesForChallenge(hash.data());
+        LargeBits proof2 = prover2.GetFullProof(hash.data(), 0);
+        REQUIRE(qualities1 == qualities2);
+        REQUIRE(proof1 == proof2);
+
+        auto test_invalidity = [](std::string filename, std::vector<uint8_t> memo, std::vector<uint8_t> id, uint8_t k,
+                                  std::vector<uint64_t> table_begin_pointers, std::vector<uint64_t> C2) {
+            DiskProver prover(filename, memo, id, k, table_begin_pointers, C2);
+            REQUIRE_THROWS(prover.Validate());
+            REQUIRE(!prover.IsValid());
+        };
+
+        // Test "Invalid file"
+        test_invalidity("invalid", p1_memo, p1_id, p1_k, p1_pointers, p1_C2);
+        // Test "Invalid id"
+        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, {}, p1_k, p1_pointers, p1_C2));
+        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, std::vector<uint8_t>(kIdLen - 1, 0), p1_k, p1_pointers, p1_C2));
+        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, std::vector<uint8_t>(kIdLen + 1, 0), p1_k, p1_pointers, p1_C2));
+        auto invalid_id = p1_id;
+        invalid_id.back()++;
+        test_invalidity(p1_filename, p1_memo, invalid_id, p1_k, p1_pointers, p1_C2);
+        // Test "Invalid k"
+        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, 0, p1_pointers, p1_C2));
+        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, kMinPlotSize - 1, p1_pointers, p1_C2));
+        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, kMaxPlotSize + 1, p1_pointers, p1_C2));
+        test_invalidity(p1_filename, p1_memo, p1_id, p1_k + 1, p1_pointers, p1_C2);
+        // Test "Invalid table_begin_pointers size"
+        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, {}, p1_C2));
+        auto invalid_pointers = std::vector<uint64_t>(p1_pointers.begin(), p1_pointers.end() - 1);
+        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, invalid_pointers, p1_C2));
+        invalid_pointers.push_back(p1_pointers.back());
+        invalid_pointers.push_back(p1_pointers.back());
+        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, invalid_pointers, p1_C2));
+        invalid_pointers.pop_back();
+        invalid_pointers.back()++;
+        test_invalidity(p1_filename, p1_memo, p1_id, p1_k, invalid_pointers, p1_C2);
+        // Test "Invalid C2 size"
+        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, p1_pointers, {}));
+        auto invalid_c2 = std::vector<uint64_t>(p1_C2.begin(), p1_C2.end() - 1);
+        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, p1_pointers, invalid_c2));
+        invalid_c2.push_back(p1_C2.back());
+        invalid_c2.push_back(p1_C2.back());
+        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, p1_pointers, invalid_c2));
+        invalid_c2.pop_back();
+        invalid_c2.back()++;
+        test_invalidity(p1_filename, p1_memo, p1_id, p1_k, p1_pointers, invalid_c2);
+        REQUIRE(remove(filename.c_str()) == 0);
+        // Test move constructor
+        auto* p1_memo_ptr = prover1.GetMemo().data();
+        auto* p1_id_ptr = prover1.GetId().data();
+        auto* p1_table_begin_pointers_ptr = prover1.GetTableBeginPointers().data();
+        auto* p1_C2_ptr = prover1.GetC2().data();
+        DiskProver prover3(std::move(prover1));
+        // Filename is compared by value: a std::string stored inline keeps no stable data() address.
+        REQUIRE(prover3.GetFilename() == p1_filename);
+        REQUIRE(prover3.GetMemo().data() == p1_memo_ptr);
+        REQUIRE(prover3.GetId().data() == p1_id_ptr);
+        REQUIRE(prover3.GetTableBeginPointers().data() == p1_table_begin_pointers_ptr);
+        REQUIRE(prover3.GetC2().data() == p1_C2_ptr);
+        REQUIRE(prover1.GetFilename().empty());
+        REQUIRE(prover1.GetMemo().empty());
+        REQUIRE(prover1.GetId().empty());
+        REQUIRE(prover3.GetSize() == p1_k);
+        REQUIRE(prover1.GetTableBeginPointers().empty());
+        REQUIRE(prover1.GetC2().empty());
+        REQUIRE(!prover1.IsValid());
+    }
+}
diff --git a/tests/test_python_bindings.py b/tests/test_python_bindings.py
index a04c696be..2b52cf197 100644
--- a/tests/test_python_bindings.py
+++ b/tests/test_python_bindings.py
@@ -1,4 +1,6 @@
 import unittest
+import pickle
+
 from chiapos import DiskProver, DiskPlotter, Verifier
 from hashlib import sha256
 from pathlib import Path
@@ -62,6 +64,7 @@ def test_k_21(self):
         pl = None
 
         pr = DiskProver(str(Path("myplot.dat")))
+        assert pr.is_valid()
 
         total_proofs: int = 0
         total_proofs2: int = 0
@@ -176,6 +179,36 @@ def test_faulty_plot_doesnt_crash(self):
         print(f"Successes: {successes}")
         print(f"Failures: {failures}")
 
+    def test_pickle_support(self):
+        if not Path("test_plot.dat").exists():
+            plot_id: bytes = bytes([i for i in range(0, 32)])
+            pl = DiskPlotter()
+            pl.create_plot_disk(
+                ".",
+                ".",
+                ".",
+                "test_plot.dat",
+                21,
+                bytes([1, 2, 3, 4, 5]),
+                plot_id,
+                300,
+                32,
+                8192,
+                8,
+                False,
+            )
+
+        prover1: DiskProver = DiskProver(str(Path("test_plot.dat")))
+        data = pickle.dumps(prover1)
+        prover_recovered: DiskProver = pickle.loads(data)
+        assert prover1.get_size() == prover_recovered.get_size()
+        assert prover1.get_filename() == prover_recovered.get_filename()
+        assert prover1.get_id() == prover_recovered.get_id()
+        assert prover1.get_memo() == prover_recovered.get_memo()
+        Path("test_plot.dat").unlink()
+        prover_recovered_invalid: DiskProver = pickle.loads(data)
+        assert not prover_recovered_invalid.is_valid()
+
 
 if __name__ == "__main__":
     unittest.main()