Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wire: cache the non-witness serialization of MsgTx to memoize part of… #1376

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions blockchain/fullblocktests/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ func additionalCoinbase(amount btcutil.Amount) func(*wire.MsgBlock) {
// Increase the first proof-of-work coinbase subsidy by the
// provided amount.
b.Transactions[0].TxOut[0].Value += int64(amount)

b.Transactions[0].WipeCache()
}
}

Expand All @@ -402,6 +404,8 @@ func additionalSpendFee(fee btcutil.Amount) func(*wire.MsgBlock) {
fee))
}
b.Transactions[1].TxOut[0].Value -= int64(fee)

b.Transactions[1].WipeCache()
}
}

Expand All @@ -410,6 +414,8 @@ func additionalSpendFee(fee btcutil.Amount) func(*wire.MsgBlock) {
func replaceSpendScript(pkScript []byte) func(*wire.MsgBlock) {
	return func(block *wire.MsgBlock) {
		// Overwrite the public key script on the first output of the
		// block's second transaction.
		tx := block.Transactions[1]
		tx.TxOut[0].PkScript = pkScript

		// The transaction was modified in place, so discard its cached
		// serialization; otherwise a later hash would be computed from
		// the bytes of the unmodified transaction.
		tx.WipeCache()
	}
}

Expand All @@ -418,6 +424,8 @@ func replaceSpendScript(pkScript []byte) func(*wire.MsgBlock) {
func replaceCoinbaseSigScript(script []byte) func(*wire.MsgBlock) {
	return func(block *wire.MsgBlock) {
		// Overwrite the signature script on the first input of the
		// block's first transaction.
		tx := block.Transactions[0]
		tx.TxIn[0].SignatureScript = script

		// The transaction was modified in place, so discard its cached
		// serialization; otherwise a later hash would be computed from
		// the bytes of the unmodified transaction.
		tx.WipeCache()
	}
}

Expand Down
3 changes: 3 additions & 0 deletions btcutil/txsort/txsort_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ func TestSort(t *testing.T) {
// Now sort the transaction using the mutable version and ensure
// the resulting hash is the expected value.
txsort.InPlaceSort(&tx)

tx.WipeCache()

if got := tx.TxHash().String(); got != test.sortedHash {
t.Errorf("SortMutate (%s): sorted hash does not match "+
"expected - got %v, want %v", test.name, got,
Expand Down
24 changes: 24 additions & 0 deletions wire/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,18 @@ func TestMessage(t *testing.T) {
spew.Sdump(msg))
continue
}

// Blank out the cached encoding for transactions to ensure the
// deep equality check doesn't fail.
if tx, ok := msg.(*MsgTx); ok {
tx.cachedSeralizedNoWitness = nil
}
if block, ok := msg.(*MsgBlock); ok {
for _, tx := range block.Transactions {
tx.cachedSeralizedNoWitness = nil
}
}

if !reflect.DeepEqual(msg, test.out) {
t.Errorf("ReadMessage #%d\n got: %v want: %v", i,
spew.Sdump(msg), spew.Sdump(test.out))
Expand Down Expand Up @@ -170,6 +182,18 @@ func TestMessage(t *testing.T) {
spew.Sdump(msg))
continue
}

// Blank out the cached encoding for transactions to ensure the
// deep equality check doesn't fail.
if tx, ok := msg.(*MsgTx); ok {
tx.cachedSeralizedNoWitness = nil
}
if block, ok := msg.(*MsgBlock); ok {
for _, tx := range block.Transactions {
tx.cachedSeralizedNoWitness = nil
}
}

if !reflect.DeepEqual(msg, test.out) {
t.Errorf("ReadMessage #%d\n got: %v want: %v", i,
spew.Sdump(msg), spew.Sdump(test.out))
Expand Down
85 changes: 67 additions & 18 deletions wire/msgtx.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,12 @@ type MsgTx struct {
TxIn []*TxIn
TxOut []*TxOut
LockTime uint32

// cachedSeralizedNoWitness is a cached version of the serialization of
// this transaction without witness data. When we decode a transaction,
// we'll write out the non-witness bytes to this so we can quickly
// calculate the TxHash later if needed.
cachedSeralizedNoWitness []byte
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: s/cachedSeralizedNoWitness/cachedSerializedNoWitness/

}

// AddTxIn adds a transaction input to the message.
Expand All @@ -357,13 +363,26 @@ func (msg *MsgTx) AddTxOut(to *TxOut) {

// TxHash generates the Hash for the transaction.
//
// The hash is the double sha256 of the transaction serialized without witness
// data. That serialization is memoized on the message: the first call (or the
// first call after WipeCache) serializes the transaction and stores the bytes,
// and later calls hash the stored bytes directly. Because of this, any code
// that mutates the transaction must call WipeCache before requesting the hash
// again, otherwise the hash of the pre-mutation bytes is returned.
//
// NOTE(review): this method writes to the receiver without any
// synchronization visible here; confirm callers do not hash the same MsgTx
// from multiple goroutines concurrently.
func (msg *MsgTx) TxHash() chainhash.Hash {
	if msg.cachedSeralizedNoWitness == nil {
		// Encode the transaction so the double sha256 can be
		// calculated on the result. Ignore the error return since the
		// only way the encode could fail is being out of memory or due
		// to nil pointers, both of which would cause a run-time panic.
		strippedSize := msg.SerializeSizeStripped()
		buf := bytes.NewBuffer(make([]byte, 0, strippedSize))
		_ = msg.SerializeNoWitness(buf)

		// Remember the encoding so subsequent calls skip the
		// serialization step entirely.
		msg.cachedSeralizedNoWitness = buf.Bytes()
	}

	return chainhash.DoubleHashH(msg.cachedSeralizedNoWitness)
}

// WipeCache removes the cached serialized bytes of the transaction. This is
// useful to be able to get the correct txid after mutating a transaction's
// state.
//
// TxHash reuses the cached non-witness serialization whenever one is present,
// so code that modifies a transaction after it has been decoded or hashed
// must call WipeCache before asking for the hash again.
func (msg *MsgTx) WipeCache() {
// A nil cache makes the next TxHash call re-serialize the transaction.
msg.cachedSeralizedNoWitness = nil
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we also need to call this in the helper methods like AddTxIn() and AddTxOut()? Or is the assumption that TxHash() would in practice only be called once the transaction is fully built, so the cache doesn't need to be invalidated?

I'm mostly worrying about uses of MsgTx outside of btcd, where I'm not sure we can 100% guarantee that we're always using this pattern...

}

// WitnessHash generates the hash of the transaction serialized according to
Expand Down Expand Up @@ -461,7 +480,14 @@ func (msg *MsgTx) Copy() *MsgTx {
// See Deserialize for decoding transactions stored to disk, such as in a
// database, as opposed to decoding transactions from the wire.
func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error {
version, err := binarySerializer.Uint32(r, littleEndian)
// We'll use a tee reader to incrementally capture the raw non-witness
// serialization of this transaction as it is decoded. Once decoding
// succeeds, we cache these bytes, which allows the TxHash to be computed
// later without re-serializing the entire transaction.
var rawTxBuf bytes.Buffer
rawTxTeeReader := io.TeeReader(r, &rawTxBuf)

version, err := binarySerializer.Uint32(rawTxTeeReader, littleEndian)
if err != nil {
return err
}
Expand All @@ -472,12 +498,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
return err
}

// A count of zero (meaning no TxIn's to the uninitiated) means that the
// value is a TxFlagMarker, and hence indicates the presence of a flag.
var flag [1]TxFlag
// A count of zero (meaning no TxIn's to the uninitiated) indicates
// this is a transaction with witness data. Notice that we don't use
// the rawTxTeeReader here, as these are segwit specific bytes.
var (
flag [1]byte
hasWitneess bool
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/hasWitneess/hasWitness

)
if count == TxFlagMarker && enc == WitnessEncoding {
// The count varint was in fact the flag marker byte. Next, we need to
// read the flag value, which is a single byte.
// Next, we need to read the flag, which is a single byte.
if _, err = io.ReadFull(r, flag[:]); err != nil {
return err
}
Expand All @@ -495,6 +524,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
if err != nil {
return err
}

hasWitneess = true
}

// Write out the actual number of inputs, as this won't be the very next
// byte series after the version in segwit transactions.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: s/versino/version

if WriteVarInt(&rawTxBuf, pver, count); err != nil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to add err := WriteVarInt().

str := fmt.Sprintf("unable to write txin count: %v", err)
return messageError("MsgTx.BtcDecode", str)
}

// Prevent more input transactions than could possibly fit into a
Expand Down Expand Up @@ -545,15 +583,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
// and needs to be returned to the pool on error.
ti := &txIns[i]
msg.TxIn[i] = ti
err = readTxIn(r, pver, msg.Version, ti)
err = readTxIn(rawTxTeeReader, pver, msg.Version, ti)
if err != nil {
returnScriptBuffers()
return err
}
totalScriptSize += uint64(len(ti.SignatureScript))
}

count, err = ReadVarInt(r, pver)
count, err = ReadVarInt(rawTxTeeReader, pver)
if err != nil {
returnScriptBuffers()
return err
Expand All @@ -578,7 +616,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
// and needs to be returned to the pool on error.
to := &txOuts[i]
msg.TxOut[i] = to
err = ReadTxOut(r, pver, msg.Version, to)
err = ReadTxOut(rawTxTeeReader, pver, msg.Version, to)
if err != nil {
returnScriptBuffers()
return err
Expand All @@ -588,7 +626,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error

// If the transaction's flag byte isn't 0x00 at this point, then one or
// more of its inputs has accompanying witness data.
if flag[0] != 0 && enc == WitnessEncoding {
if hasWitneess && enc == WitnessEncoding {
for _, txin := range msg.TxIn {
// For each input, the witness is encoded as a stack
// with one or more items. Therefore, we first read a
Expand Down Expand Up @@ -626,7 +664,9 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
}
}

msg.LockTime, err = binarySerializer.Uint32(r, littleEndian)
msg.LockTime, err = binarySerializer.Uint32(
rawTxTeeReader, littleEndian,
)
if err != nil {
returnScriptBuffers()
return err
Expand Down Expand Up @@ -700,6 +740,11 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
scriptPool.Return(pkScript)
}

// Now that we've decoded the entire transaction without any issues,
// we'll cache the non-witness serialization so we can more quickly
// calculate the TxHash in the future.
msg.cachedSeralizedNoWitness = rawTxBuf.Bytes()

return nil
}

Expand Down Expand Up @@ -832,6 +877,10 @@ func (msg *MsgTx) Serialize(w io.Writer) error {
// Serialize, however even if the source transaction has inputs with witness
// data, the old serialization format will still be used.
func (msg *MsgTx) SerializeNoWitness(w io.Writer) error {
if msg.cachedSeralizedNoWitness != nil {
w.Write(msg.cachedSeralizedNoWitness)
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing return?


return msg.BtcEncode(w, 0, BaseEncoding)
}

Expand Down
53 changes: 53 additions & 0 deletions wire/msgtx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,13 @@ func TestTxHash(t *testing.T) {
t.Errorf("TxHash: wrong hash - got %v, want %v",
spew.Sprint(txHash), spew.Sprint(wantHash))
}

// Compute it again to ensure any cached elements are valid.
txHash = msgTx.TxHash()
if !txHash.IsEqual(wantHash) {
t.Errorf("TxHash: wrong hash - got %v, want %v",
spew.Sprint(txHash), spew.Sprint(wantHash))
}
}

// TestTxSha tests the ability to generate the wtxid, and txid of a transaction
Expand Down Expand Up @@ -258,6 +265,18 @@ func TestWTxSha(t *testing.T) {
t.Errorf("WTxSha: wrong hash - got %v, want %v",
spew.Sprint(wtxid), spew.Sprint(wantHashWTxid))
}

// Compute the values again to ensure any cached elements are valid.
txid = msgTx.TxHash()
if !txid.IsEqual(wantHashTxid) {
t.Errorf("TxSha: wrong hash - got %v, want %v",
spew.Sprint(txid), spew.Sprint(wantHashTxid))
}
wtxid = msgTx.WitnessHash()
if !wtxid.IsEqual(wantHashWTxid) {
t.Errorf("WTxSha: wrong hash - got %v, want %v",
spew.Sprint(wtxid), spew.Sprint(wantHashWTxid))
}
}

// TestTxWire tests the MsgTx wire encode and decode for various numbers
Expand Down Expand Up @@ -393,6 +412,23 @@ func TestTxWire(t *testing.T) {
t.Errorf("BtcDecode #%d error %v", i, err)
continue
}

// If this is the base encoding, then ensure that the cached
// serialization properly matches the raw encoding.
if test.enc == BaseEncoding {
if !bytes.Equal(
test.buf, msg.cachedSeralizedNoWitness,
) {
t.Errorf("BtcdDecode #%d: cached encoding "+
"is wrong, expected %x got %x", i,
test.buf,
msg.cachedSeralizedNoWitness)
continue
}
}

msg.cachedSeralizedNoWitness = nil

if !reflect.DeepEqual(&msg, test.out) {
t.Errorf("BtcDecode #%d\n got: %s want: %s", i,
spew.Sdump(&msg), spew.Sdump(test.out))
Expand Down Expand Up @@ -539,6 +575,23 @@ func TestTxSerialize(t *testing.T) {
t.Errorf("Deserialize #%d error %v", i, err)
continue
}

// Ensure that the raw non-witness encoding matches the cached
// non-witness encoding bytes.
var b bytes.Buffer
if err := tx.SerializeNoWitness(&b); err != nil {
t.Errorf("Deserialize #%d: unable to encode: %v", i, err)
}
if !bytes.Equal(b.Bytes(), tx.cachedSeralizedNoWitness) {
t.Errorf("Deserialize #%d: cached encoding "+
"is wrong, expected %x got %x", i,
b.Bytes(),
tx.cachedSeralizedNoWitness)
continue
}

tx.cachedSeralizedNoWitness = nil

if !reflect.DeepEqual(&tx, test.out) {
t.Errorf("Deserialize #%d\n got: %s want: %s", i,
spew.Sdump(&tx), spew.Sdump(test.out))
Expand Down