Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

src|python-bindings: Use std::vector<uint8_t> for id and memo #302

Merged
merged 1 commit into from Aug 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 4 additions & 10 deletions python-bindings/chiapos.cpp
Expand Up @@ -80,20 +80,14 @@ PYBIND11_MODULE(chiapos, m)
.def(
"get_memo",
[](DiskProver &dp) {
uint8_t *memo = new uint8_t[dp.GetMemoSize()];
dp.GetMemo(memo);
py::bytes ret = py::bytes(reinterpret_cast<char *>(memo), dp.GetMemoSize());
delete[] memo;
return ret;
const std::vector<uint8_t>& memo = dp.GetMemo();
return py::bytes(reinterpret_cast<const char*>(memo.data()), memo.size());
})
.def(
"get_id",
[](DiskProver &dp) {
uint8_t *id = new uint8_t[kIdLen];
dp.GetId(id);
py::bytes ret = py::bytes(reinterpret_cast<char *>(id), kIdLen);
delete[] id;
return ret;
const std::vector<uint8_t>& id = dp.GetId();
return py::bytes(reinterpret_cast<const char*>(id.data()), id.size());
})
.def("get_size", [](DiskProver &dp) { return dp.GetSize(); })
.def("get_filename", [](DiskProver &dp) { return dp.GetFilename(); })
Expand Down
7 changes: 3 additions & 4 deletions src/cli.cpp
Expand Up @@ -247,13 +247,12 @@ int main(int argc, char *argv[]) try {
Verifier verifier = Verifier();

uint32_t success = 0;
uint8_t id_bytes[32];
prover.GetId(id_bytes);
std::vector<uint8_t> id_bytes = prover.GetId();
k = prover.GetSize();

for (uint32_t num = 0; num < iterations; num++) {
vector<unsigned char> hash_input = intToBytes(num, 4);
hash_input.insert(hash_input.end(), &id_bytes[0], &id_bytes[32]);
hash_input.insert(hash_input.end(), id_bytes.begin(), id_bytes.end());

vector<unsigned char> hash(picosha2::k_digest_size);
picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
Expand All @@ -269,7 +268,7 @@ int main(int argc, char *argv[]) try {
cout << "challenge: 0x" << Util::HexStr(hash.data(), 256 / 8) << endl;
cout << "proof: 0x" << Util::HexStr(proof_data, k * 8) << endl;
LargeBits quality =
verifier.ValidateProof(id_bytes, k, hash.data(), proof_data, k * 8);
verifier.ValidateProof(id_bytes.data(), k, hash.data(), proof_data, k * 8);
if (quality.GetSize() == 256 && quality == qualities[i]) {
cout << "quality: " << quality << endl;
cout << "Proof verification succeeded. k = " << static_cast<int>(k) << endl;
Expand Down
24 changes: 9 additions & 15 deletions src/prover_disk.hpp
Expand Up @@ -50,7 +50,7 @@ class DiskProver {
public:
// The constructor opens the file, and reads the contents of the file header. The table pointers
// will be used to find and seek to all seven tables, at the time of proving.
explicit DiskProver(const std::string& filename)
explicit DiskProver(const std::string& filename) : id(kIdLen)
{
struct plot_header header{};
this->filename = filename;
Expand Down Expand Up @@ -80,16 +80,14 @@ class DiskProver {
} else {
throw std::invalid_argument("Invalid plot file format");
}

memcpy(this->id, header.id, sizeof(header.id));
memcpy(id.data(), header.id, sizeof(header.id));
this->k = header.k;
SafeSeek(disk_file, offsetof(struct plot_header, fmt_desc) + fmt_desc_len);

uint8_t size_buf[2];
SafeRead(disk_file, size_buf, 2);
this->memo_size = Util::TwoBytesToInt(size_buf);
this->memo = new uint8_t[this->memo_size];
SafeRead(disk_file, this->memo, this->memo_size);
memo.resize(Util::TwoBytesToInt(size_buf));
xdustinface marked this conversation as resolved.
Show resolved Hide resolved
SafeRead(disk_file, memo.data(), memo.size());

this->table_begin_pointers = std::vector<uint64_t>(11, 0);
this->C2 = std::vector<uint64_t>();
Expand Down Expand Up @@ -122,18 +120,15 @@ class DiskProver {
~DiskProver()
{
std::lock_guard<std::mutex> l(_mtx);
delete[] this->memo;
for (int i = 0; i < 6; i++) {
Encoding::ANSFree(kRValues[i]);
}
Encoding::ANSFree(kC3R);
}

void GetMemo(uint8_t* buffer) { memcpy(buffer, memo, this->memo_size); }

uint32_t GetMemoSize() const noexcept { return this->memo_size; }
const std::vector<uint8_t>& GetMemo() { return memo; }

void GetId(uint8_t* buffer) { memcpy(buffer, id, kIdLen); }
const std::vector<uint8_t>& GetId() { return id; }

std::string GetFilename() const noexcept { return filename; }

Expand Down Expand Up @@ -241,9 +236,8 @@ class DiskProver {
private:
mutable std::mutex _mtx;
std::string filename;
uint32_t memo_size;
uint8_t* memo;
uint8_t id[kIdLen]{}; // Unique plot id
std::vector<uint8_t> memo;
std::vector<uint8_t> id; // Unique plot id
uint8_t k;
std::vector<uint64_t> table_begin_pointers;
std::vector<uint64_t> C2;
Expand Down Expand Up @@ -577,7 +571,7 @@ class DiskProver {
// Where a < b is defined as: max(b) > max(a) where a and b are lists of k bit elements
std::vector<LargeBits> ReorderProof(const std::vector<Bits>& xs_input) const
{
F1Calculator f1(k, id);
F1Calculator f1(k, id.data());
xdustinface marked this conversation as resolved.
Show resolved Hide resolved
std::vector<std::pair<Bits, Bits> > results;
LargeBits xs;

Expand Down