Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

python-bindings|prover_disk|tests: Add pickle support for DiskProver #303

Closed
23 changes: 23 additions & 0 deletions python-bindings/chiapos.cpp
Expand Up @@ -77,6 +77,29 @@ PYBIND11_MODULE(chiapos, m)

py::class_<DiskProver>(m, "DiskProver")
.def(py::init<const std::string &>())
.def("is_valid", [](const DiskProver &dp) { return dp.IsValid(); })
.def(py::pickle(
[](const DiskProver& dp) { // __getstate__
return py::make_tuple(dp.GetFilename(),
dp.GetSize(),
dp.GetMemo(),
dp.GetId(),
dp.GetTableBeginPointers(),
dp.GetC2());
},
[](const py::tuple& t) { // __setstate__
if (t.size() != 6)
throw std::runtime_error("Invalid state!");

auto filename = t[0].cast<std::string>();
auto k = t[1].cast<uint8_t>();
auto memo = t[2].cast<std::vector<uint8_t>>();
auto id = t[3].cast<std::vector<uint8_t>>();
xdustinface marked this conversation as resolved.
Show resolved Hide resolved
auto table_begin_pointers = t[4].cast<std::vector<uint64_t>>();
xdustinface marked this conversation as resolved.
Show resolved Hide resolved
auto C2 = t[5].cast<std::vector<uint64_t>>();
return DiskProver(filename, memo, id, k,
table_begin_pointers, C2);
}))
.def(
"get_memo",
[](DiskProver &dp) {
Expand Down
207 changes: 147 additions & 60 deletions src/prover_disk.hpp
Expand Up @@ -47,74 +47,52 @@ struct plot_header {
// The DiskProver, given a correctly formatted plot file, can efficiently generate valid proofs
// of space, for a given challenge.
class DiskProver {
static const size_t table_begin_pointers_size{11};
public:
// The constructor opens the file, and reads the contents of the file header. The table pointers
// will be used to find and seek to all seven tables, at the time of proving.
explicit DiskProver(const std::string& filename) : id(kIdLen)
{
struct plot_header header{};
this->filename = filename;
Read(id, memo, k, table_begin_pointers, C2);
}

std::ifstream disk_file(filename, std::ios::in | std::ios::binary);

if (!disk_file.is_open()) {
throw std::invalid_argument("Invalid file " + filename);
// Note: This constructor presumes the input parameters are valid for the provided file and does
// not validate the file itself. It's recommended to use `IsValid` after construction to make
// sure the given file matches expectations.
DiskProver(std::string filename, std::vector<uint8_t> memo, std::vector<uint8_t> id, uint8_t k,
std::vector<uint64_t> table_begin_pointers, std::vector<uint64_t> C2) :
filename(std::move(filename)),
memo(std::move(memo)),
id(std::move(id)),
k(k),
table_begin_pointers(std::move(table_begin_pointers)),
C2(std::move(C2))
{
xdustinface marked this conversation as resolved.
Show resolved Hide resolved
if (this->id.size() != kIdLen) {
throw std::invalid_argument("Invalid id size: " + std::to_string(this->id.size()));
}
// 19 bytes - "Proof of Space Plot" (utf-8)
// 32 bytes - unique plot id
// 1 byte - k
// 2 bytes - format description length
// x bytes - format description
// 2 bytes - memo length
// x bytes - memo

SafeRead(disk_file, (uint8_t*)&header, sizeof(header));
if (memcmp(header.magic, "Proof of Space Plot", sizeof(header.magic)) != 0)
throw std::invalid_argument("Invalid plot header magic");

uint16_t fmt_desc_len = Util::TwoBytesToInt(header.fmt_desc_len);

if (fmt_desc_len == kFormatDescription.size() &&
!memcmp(header.fmt_desc, kFormatDescription.c_str(), fmt_desc_len)) {
// OK
} else {
throw std::invalid_argument("Invalid plot file format");
if (k < kMinPlotSize || k > kMaxPlotSize) {
throw std::invalid_argument("Invalid k: " + std::to_string(k));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the comment above this constructor suggests the inputs are not validated, but they seem to be, at least to some extent. Are there some aspects left that aren't validated, or is the comment outdated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not say inputs are not validated. It says the inputs are not validated against the file.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

}
memcpy(id.data(), header.id, sizeof(header.id));
this->k = header.k;
SafeSeek(disk_file, offsetof(struct plot_header, fmt_desc) + fmt_desc_len);

uint8_t size_buf[2];
SafeRead(disk_file, size_buf, 2);
memo.resize(Util::TwoBytesToInt(size_buf));
SafeRead(disk_file, memo.data(), memo.size());

this->table_begin_pointers = std::vector<uint64_t>(11, 0);
this->C2 = std::vector<uint64_t>();

uint8_t pointer_buf[8];
for (uint8_t i = 1; i < 11; i++) {
SafeRead(disk_file, pointer_buf, 8);
this->table_begin_pointers[i] = Util::EightBytesToInt(pointer_buf);
}

SafeSeek(disk_file, table_begin_pointers[9]);

uint8_t c2_size = (Util::ByteAlign(k) / 8);
uint32_t c2_entries = (table_begin_pointers[10] - table_begin_pointers[9]) / c2_size;
if (c2_entries == 0 || c2_entries == 1) {
throw std::invalid_argument("Invalid C2 table size");
if (this->table_begin_pointers.size() != table_begin_pointers_size) {
throw std::invalid_argument("Invalid table_begin_pointers size: " + std::to_string(this->table_begin_pointers.size()));
}

// The list of C2 entries is small enough to keep in memory. When proving, we can
// read from disk the C1 and C3 entries.
auto* c2_buf = new uint8_t[c2_size];
for (uint32_t i = 0; i < c2_entries - 1; i++) {
SafeRead(disk_file, c2_buf, c2_size);
this->C2.push_back(Bits(c2_buf, c2_size, c2_size * 8).Slice(0, k).GetValue());
const uint32_t expected_size = ((this->table_begin_pointers[10] - this->table_begin_pointers[9]) / (Util::ByteAlign(k) / 8)) - 1;
if (this->C2.size() != expected_size) {
throw std::invalid_argument("Invalid C2 size: " + std::to_string(this->C2.size()));
}
}

delete[] c2_buf;
// Move constructor: takes over the string and vector members of `other` and copies its k.
// `other._mtx` is held while the members are transferred.
DiskProver(DiskProver&& other) noexcept
{
    const std::lock_guard<std::mutex> guard(other._mtx);

    // The individual transfers are independent of each other.
    this->k = other.k;
    this->filename = std::move(other.filename);
    this->id = std::move(other.id);
    this->memo = std::move(other.memo);
    this->table_begin_pointers = std::move(other.table_begin_pointers);
    this->C2 = std::move(other.C2);
}
xdustinface marked this conversation as resolved.
Show resolved Hide resolved

~DiskProver()
Expand All @@ -126,14 +104,51 @@ class DiskProver {
Encoding::ANSFree(kC3R);
}

const std::vector<uint8_t>& GetMemo() { return memo; }
const std::vector<uint8_t>& GetMemo() const { return memo; }

const std::vector<uint8_t>& GetId() { return id; }
const std::vector<uint8_t>& GetId() const { return id; }

const std::vector<uint64_t>& GetTableBeginPointers() const noexcept { return table_begin_pointers; }

std::string GetFilename() const noexcept { return filename; }
const std::vector<uint64_t>& GetC2() const noexcept { return C2; }

const std::string& GetFilename() const noexcept { return filename; }

uint8_t GetSize() const noexcept { return k; }

void Validate() const
{
uint8_t k_out;
std::vector<uint8_t> id_out, memo_out;
std::vector<uint64_t> table_begin_pointers_out, C2_out;
Read(id_out, memo_out, k_out, table_begin_pointers_out, C2_out);
if (id != id_out) {
throw std::invalid_argument("EnsureValid: Invalid id");
}
if (memo != memo_out) {
throw std::invalid_argument("EnsureValid: Invalid memo");
}
if (k != k_out) {
throw std::invalid_argument("EnsureValid: Invalid k");
}
if (table_begin_pointers != table_begin_pointers_out) {
throw std::invalid_argument("EnsureValid: Invalid table_begin_pointers");
}
if (C2 != C2_out) {
throw std::invalid_argument("EnsureValid: Invalid C2");
}
}

// Non-throwing counterpart of Validate(): returns true when Validate() completes and
// false when it throws anything at all.
bool IsValid() const
{
    try {
        Validate();
        return true;
    } catch (...) {
        return false;
    }
}

// Given a challenge, returns a quality string, which is sha256(challenge + 2 adjacent x
// values), from the 64 value proof. Note that this is more efficient than fetching all 64 x
// values, which are in different parts of the disk.
Expand Down Expand Up @@ -238,10 +253,82 @@ class DiskProver {
std::string filename;
std::vector<uint8_t> memo;
std::vector<uint8_t> id; // Unique plot id
uint8_t k;
uint8_t k{0};
std::vector<uint64_t> table_begin_pointers;
std::vector<uint64_t> C2;

// Parses the plot file at `filename` and hands back its header fields, table pointers and
// C2 entries through the output parameters. No member of this instance is written, so the
// results can either initialise the members or be compared against them.
//
// Outputs:
//   id_out                   - kIdLen bytes holding the plot id copied from the header
//   memo_out                 - memo bytes; the length comes from the 2-byte memo length field
//   k_out                    - plot size k taken from the header
//   table_begin_pointers_out - table_begin_pointers_size entries; index 0 stays 0 and the
//                              remaining entries are read from the file
//   C2_out                   - the C2 entries, without the last one stored in the file
//
// Throws std::invalid_argument when the file cannot be opened, the magic or the format
// description does not match, k is 0, or the C2 table bounds are inconsistent. Anything
// thrown by SafeRead/SafeSeek propagates unchanged.
void Read(std::vector<uint8_t>& id_out,
          std::vector<uint8_t>& memo_out,
          uint8_t& k_out,
          std::vector<uint64_t>& table_begin_pointers_out,
          std::vector<uint64_t>& C2_out) const
{
    // NOTE(review): a reviewer pointed out that this lock was not taken before the parsing
    // code was moved out of the constructor and that it looks unnecessary; it is kept here
    // and left for a follow-up patch.
    std::lock_guard<std::mutex> lock(_mtx);

    std::ifstream disk_file(filename, std::ios::in | std::ios::binary);

    if (!disk_file.is_open()) {
        throw std::invalid_argument("Invalid file " + filename);
    }

    struct plot_header header{};
    // 19 bytes  - "Proof of Space Plot" (utf-8)
    // 32 bytes  - unique plot id
    //  1 byte   - k
    //  2 bytes  - format description length
    //  x bytes  - format description
    //  2 bytes  - memo length
    //  x bytes  - memo

    SafeRead(disk_file, reinterpret_cast<uint8_t*>(&header), sizeof(header));
    if (memcmp(header.magic, "Proof of Space Plot", sizeof(header.magic)) != 0) {
        throw std::invalid_argument("Invalid plot header magic");
    }

    const uint16_t fmt_desc_len = Util::TwoBytesToInt(header.fmt_desc_len);
    if (fmt_desc_len != kFormatDescription.size() ||
        memcmp(header.fmt_desc, kFormatDescription.c_str(), fmt_desc_len) != 0) {
        throw std::invalid_argument("Invalid plot file format");
    }

    id_out = std::vector<uint8_t>(kIdLen);
    memcpy(id_out.data(), header.id, sizeof(header.id));
    k_out = header.k;

    // The memo length field follows the format description.
    SafeSeek(disk_file, offsetof(struct plot_header, fmt_desc) + fmt_desc_len);

    uint8_t size_buf[2];
    SafeRead(disk_file, size_buf, 2);
    memo_out.resize(Util::TwoBytesToInt(size_buf));
    SafeRead(disk_file, memo_out.data(), memo_out.size());

    table_begin_pointers_out = std::vector<uint64_t>(table_begin_pointers_size, 0);
    C2_out = std::vector<uint64_t>();

    // Index 0 is not stored in the file and stays 0.
    uint8_t pointer_buf[8];
    for (size_t i = 1; i < table_begin_pointers_size; i++) {
        SafeRead(disk_file, pointer_buf, 8);
        table_begin_pointers_out[i] = Util::EightBytesToInt(pointer_buf);
    }

    const uint64_t c2_begin = table_begin_pointers_out[9];
    const uint64_t c2_end = table_begin_pointers_out[10];
    SafeSeek(disk_file, c2_begin);

    const uint8_t c2_size = (Util::ByteAlign(k_out) / 8);
    if (c2_size == 0) {
        // k comes from the file; without this guard a zero k would divide by zero below.
        throw std::invalid_argument("Invalid k: " + std::to_string(k_out));
    }
    if (c2_end < c2_begin) {
        // The unsigned subtraction below would wrap around.
        throw std::invalid_argument("Invalid C2 table size");
    }
    const uint32_t c2_entries = (c2_end - c2_begin) / c2_size;
    if (c2_entries == 0 || c2_entries == 1) {
        throw std::invalid_argument("Invalid C2 table size");
    }

    // The list of C2 entries is small enough to keep in memory. When proving, we can
    // read from disk the C1 and C3 entries.
    // The buffer is a std::vector so it is released even when SafeRead throws.
    std::vector<uint8_t> c2_buf(c2_size);
    for (uint32_t i = 0; i < c2_entries - 1; i++) {
        SafeRead(disk_file, c2_buf.data(), c2_size);
        // Slice with the k that was just read from this file, not the member k: Validate()
        // calls Read() on instances whose member k may differ from the file's k.
        C2_out.push_back(Bits(c2_buf.data(), c2_size, c2_size * 8).Slice(0, k_out).GetValue());
    }
}

// Using this method instead of simply seeking will prevent segfaults that would arise when
// continuing the process of looking up qualities.
static void SafeSeek(std::ifstream& disk_file, uint64_t seek_location) {
Expand Down
102 changes: 102 additions & 0 deletions tests/test.cpp
Expand Up @@ -1026,3 +1026,105 @@ TEST_CASE("FilteredDisk")
*/
remove("test_file.bin");
}

// Exercises DiskProver construction: from a plot file, from explicit values, the
// constructor's argument checks, Validate()/IsValid() against the file, and the move
// constructor.
TEST_CASE("DiskProver")
{
    SECTION("Construction")
    {
        std::string filename = "prover_test.plot";
        DiskPlotter plotter = DiskPlotter();
        std::vector<uint8_t> memo{1, 2, 3};
        plotter.CreatePlotDisk(
            ".", ".", ".", filename, 18, memo.data(),
            memo.size(), plot_id_1, 32, 11, 0,
            4000, 2);
        DiskProver prover1(filename);
        std::string p1_filename = prover1.GetFilename();
        std::vector<uint8_t> p1_id = prover1.GetId();
        std::vector<uint8_t> p1_memo = prover1.GetMemo();
        uint8_t p1_k = prover1.GetSize();
        std::vector<uint64_t> p1_pointers = prover1.GetTableBeginPointers();
        std::vector<uint64_t> p1_C2 = prover1.GetC2();

        // A prover built from the values of prover1 must behave like prover1.
        DiskProver prover2(p1_filename,
                           p1_memo,
                           p1_id,
                           p1_k,
                           p1_pointers,
                           p1_C2);
        REQUIRE(prover1.GetFilename() == prover2.GetFilename());
        REQUIRE(prover1.GetSize() == prover2.GetSize());
        REQUIRE(prover1.GetId() == prover2.GetId());
        REQUIRE(prover1.GetMemo() == prover2.GetMemo());
        REQUIRE(prover1.GetTableBeginPointers() == prover2.GetTableBeginPointers());
        REQUIRE(prover1.GetC2() == prover2.GetC2());
        vector<unsigned char> hash_input = intToBytes(0, 4);
        vector<unsigned char> hash(picosha2::k_digest_size);
        picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
        vector<LargeBits> qualities1 = prover1.GetQualitiesForChallenge(hash.data());
        LargeBits proof1 = prover1.GetFullProof(hash.data(), 0);
        vector<LargeBits> qualities2 = prover2.GetQualitiesForChallenge(hash.data());
        LargeBits proof2 = prover2.GetFullProof(hash.data(), 0);
        REQUIRE(qualities1 == qualities2);
        REQUIRE(proof1 == proof2);

        // For inputs that pass the constructor's own checks but do not match the file:
        // construction succeeds, Validate() throws and IsValid() reports false.
        auto test_invalidity = [](std::string filename, std::vector<uint8_t> memo, std::vector<uint8_t> id, uint8_t k,
                                  std::vector<uint64_t> table_begin_pointers, std::vector<uint64_t> C2) {
            DiskProver prover(filename, memo, id, k, table_begin_pointers, C2);
            REQUIRE_THROWS(prover.Validate());
            REQUIRE(!prover.IsValid());
        };

        // Test "Invalid file"
        test_invalidity("invalid", p1_memo, p1_id, p1_k, p1_pointers, p1_C2);

        // Test "Invalid id"
        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, {}, p1_k, p1_pointers, p1_C2));
        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, std::vector<uint8_t>(kIdLen - 1, 0), p1_k, p1_pointers, p1_C2));
        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, std::vector<uint8_t>(kIdLen + 1, 0), p1_k, p1_pointers, p1_C2));
        auto invalid_id = p1_id;
        invalid_id.back()++;
        test_invalidity(p1_filename, p1_memo, invalid_id, p1_k, p1_pointers, p1_C2);

        // Test "Invalid k"
        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, 0, p1_pointers, p1_C2));
        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, kMinPlotSize - 1, p1_pointers, p1_C2));
        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, kMaxPlotSize + 1, p1_pointers, p1_C2));
        test_invalidity(p1_filename, p1_memo, p1_id, p1_k + 1, p1_pointers, p1_C2);

        // Test "Invalid table_begin_pointers size"
        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, {}, p1_C2));
        auto invalid_pointers = std::vector<uint64_t>(p1_pointers.begin(), p1_pointers.end() - 1);
        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, invalid_pointers, p1_C2));
        invalid_pointers.push_back(p1_pointers.back());
        invalid_pointers.push_back(p1_pointers.back());
        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, invalid_pointers, p1_C2));
        invalid_pointers.pop_back();
        invalid_pointers.back()++;
        test_invalidity(p1_filename, p1_memo, p1_id, p1_k, invalid_pointers, p1_C2);

        // Test "Invalid C2 size"
        // The empty vector has to be passed as C2 (last argument); passing it as
        // table_begin_pointers would only repeat the pointer-size check above.
        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, p1_pointers, {}));
        auto invalid_c2 = std::vector<uint64_t>(p1_C2.begin(), p1_C2.end() - 1);
        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, p1_pointers, invalid_c2));
        invalid_c2.push_back(p1_C2.back());
        invalid_c2.push_back(p1_C2.back());
        REQUIRE_THROWS(DiskProver(p1_filename, p1_memo, p1_id, p1_k, p1_pointers, invalid_c2));
        invalid_c2.pop_back();
        invalid_c2.back()++;
        test_invalidity(p1_filename, p1_memo, p1_id, p1_k, p1_pointers, invalid_c2);
        REQUIRE(remove(filename.c_str()) == 0);

        // Test move constructor: the buffers must be handed over, not copied, so the
        // data pointers seen through prover3 are the ones prover1 held before the move.
        auto* p1_filename_ptr = prover1.GetFilename().data();
        auto* p1_memo_ptr = prover1.GetMemo().data();
        auto* p1_id_ptr = prover1.GetId().data();
        auto* p1_table_begin_pointers_ptr = prover1.GetTableBeginPointers().data();
        auto* p1_C2_ptr = prover1.GetC2().data();
        DiskProver prover3(std::move(prover1));
        REQUIRE(prover3.GetFilename().data() == p1_filename_ptr);
        REQUIRE(prover3.GetMemo().data() == p1_memo_ptr);
        REQUIRE(prover3.GetId().data() == p1_id_ptr);
        REQUIRE(prover3.GetTableBeginPointers().data() == p1_table_begin_pointers_ptr);
        REQUIRE(prover3.GetC2().data() == p1_C2_ptr);
        // k is a plain value: the moved-to prover must report the original size.
        REQUIRE(prover3.GetSize() == p1_k);
        // prover1 is inspected after the move on purpose to pin its emptied state.
        REQUIRE(prover1.GetFilename().empty());
        REQUIRE(prover1.GetMemo().empty());
        REQUIRE(prover1.GetId().empty());
        REQUIRE(prover1.GetTableBeginPointers().empty());
        REQUIRE(prover1.GetC2().empty());
        REQUIRE(!prover1.IsValid());
    }
}