Skip to content

Commit

Permalink
wire: cache the non-witness serialization of MsgTx to memoize part of…
Browse files Browse the repository at this point in the history
… TxHash

In this commit, we add a new field to the `MsgTx` struct:
`cachedSeralizedNoWitness`. As we decode the main transaction, we use an
`io.TeeReader` to copy over the non-witness bytes into this new field.
As a result, we can fully cache the tx serialization needed when computing
the TxHash. This serialization has been shown to show up in profiles during
IBD (initial block download). Caching this value allows us to optimize
TxHash calculation across the entire daemon.
  • Loading branch information
Roasbeef committed Nov 16, 2023
1 parent f7e9fba commit 0924825
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 18 deletions.
24 changes: 24 additions & 0 deletions wire/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,18 @@ func TestMessage(t *testing.T) {
spew.Sdump(msg))
continue
}

// Blank out the cached encoding for transactions to ensure the
// deep equality check doesn't fail.
if tx, ok := msg.(*MsgTx); ok {
tx.cachedSeralizedNoWitness = nil
}
if block, ok := msg.(*MsgBlock); ok {
for _, tx := range block.Transactions {
tx.cachedSeralizedNoWitness = nil
}
}

if !reflect.DeepEqual(msg, test.out) {
t.Errorf("ReadMessage #%d\n got: %v want: %v", i,
spew.Sdump(msg), spew.Sdump(test.out))
Expand Down Expand Up @@ -170,6 +182,18 @@ func TestMessage(t *testing.T) {
spew.Sdump(msg))
continue
}

// Blank out the cached encoding for transactions to ensure the
// deep equality check doesn't fail.
if tx, ok := msg.(*MsgTx); ok {
tx.cachedSeralizedNoWitness = nil
}
if block, ok := msg.(*MsgBlock); ok {
for _, tx := range block.Transactions {
tx.cachedSeralizedNoWitness = nil
}
}

if !reflect.DeepEqual(msg, test.out) {
t.Errorf("ReadMessage #%d\n got: %v want: %v", i,
spew.Sdump(msg), spew.Sdump(test.out))
Expand Down
74 changes: 56 additions & 18 deletions wire/msgtx.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,12 @@ type MsgTx struct {
TxIn []*TxIn
TxOut []*TxOut
LockTime uint32

// cachedSeralizedNoWitness is a cached version of the serialization of
// this transaction without witness data. When we decode a transaction,
// we'll write out the non-witness bytes to this so we can quickly
// calculate the TxHash later if needed.
cachedSeralizedNoWitness []byte
}

// AddTxIn adds a transaction input to the message.
Expand All @@ -357,13 +363,19 @@ func (msg *MsgTx) AddTxOut(to *TxOut) {

// TxHash generates the Hash for the transaction.
func (msg *MsgTx) TxHash() chainhash.Hash {
// Encode the transaction and calculate double sha256 on the result.
// Ignore the error returns since the only way the encode could fail
// is being out of memory or due to nil pointers, both of which would
// cause a run-time panic.
buf := bytes.NewBuffer(make([]byte, 0, msg.SerializeSizeStripped()))
_ = msg.SerializeNoWitness(buf)
return chainhash.DoubleHashH(buf.Bytes())
if msg.cachedSeralizedNoWitness == nil {
// Encode the transaction and calculate double sha256 on the
// result. Ignore the error returns since the only way the
// encode could fail is being out of memory or due to nil
// pointers, both of which would cause a run-time panic.
strippedSize := msg.SerializeSizeStripped()
buf := bytes.NewBuffer(make([]byte, 0, strippedSize))
_ = msg.SerializeNoWitness(buf)

msg.cachedSeralizedNoWitness = buf.Bytes()
}

return chainhash.DoubleHashH(msg.cachedSeralizedNoWitness)
}

// WitnessHash generates the hash of the transaction serialized according to
Expand Down Expand Up @@ -461,7 +473,14 @@ func (msg *MsgTx) Copy() *MsgTx {
// See Deserialize for decoding transactions stored to disk, such as in a
// database, as opposed to decoding transactions from the wire.
func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error {
version, err := binarySerializer.Uint32(r, littleEndian)
// We'll use a tee reader in order to incrementally cache the raw
// non-witness serialization of this transaction. We'll then later
// cache this value as it allows us to compute the TxHash more quickly, as
// we don't need to re-serialize the entire transaction.
var rawTxBuf bytes.Buffer
rawTxTeeReader := io.TeeReader(r, &rawTxBuf)

version, err := binarySerializer.Uint32(rawTxTeeReader, littleEndian)
if err != nil {
return err
}
Expand All @@ -472,12 +491,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
return err
}

// A count of zero (meaning no TxIn's to the uninitiated) means that the
// value is a TxFlagMarker, and hence indicates the presence of a flag.
var flag [1]TxFlag
// A count of zero (meaning no TxIn's to the uninitiated) indicates
// this is a transaction with witness data. Notice that we don't use
// the rawTxTeeReader here, as these are segwit specific bytes.
var (
flag [1]byte
hasWitneess bool
)
if count == TxFlagMarker && enc == WitnessEncoding {
// The count varint was in fact the flag marker byte. Next, we need to
// read the flag value, which is a single byte.
// Next, we need to read the flag, which is a single byte.
if _, err = io.ReadFull(r, flag[:]); err != nil {
return err
}
Expand All @@ -495,6 +517,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
if err != nil {
return err
}

hasWitneess = true
}

// Write out the actual number of inputs, as for segwit transactions this
// won't be the byte series that immediately follows the version.
if WriteVarInt(&rawTxBuf, pver, count); err != nil {
str := fmt.Sprintf("unable to write txin count: %v", err)
return messageError("MsgTx.BtcDecode", str)
}

// Prevent more input transactions than could possibly fit into a
Expand Down Expand Up @@ -545,15 +576,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
// and needs to be returned to the pool on error.
ti := &txIns[i]
msg.TxIn[i] = ti
err = readTxIn(r, pver, msg.Version, ti)
err = readTxIn(rawTxTeeReader, pver, msg.Version, ti)
if err != nil {
returnScriptBuffers()
return err
}
totalScriptSize += uint64(len(ti.SignatureScript))
}

count, err = ReadVarInt(r, pver)
count, err = ReadVarInt(rawTxTeeReader, pver)
if err != nil {
returnScriptBuffers()
return err
Expand All @@ -578,7 +609,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
// and needs to be returned to the pool on error.
to := &txOuts[i]
msg.TxOut[i] = to
err = ReadTxOut(r, pver, msg.Version, to)
err = ReadTxOut(rawTxTeeReader, pver, msg.Version, to)
if err != nil {
returnScriptBuffers()
return err
Expand All @@ -588,7 +619,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error

// If the transaction's flag byte isn't 0x00 at this point, then one or
// more of its inputs has accompanying witness data.
if flag[0] != 0 && enc == WitnessEncoding {
if hasWitneess && enc == WitnessEncoding {
for _, txin := range msg.TxIn {
// For each input, the witness is encoded as a stack
// with one or more items. Therefore, we first read a
Expand Down Expand Up @@ -626,7 +657,9 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
}
}

msg.LockTime, err = binarySerializer.Uint32(r, littleEndian)
msg.LockTime, err = binarySerializer.Uint32(
rawTxTeeReader, littleEndian,
)
if err != nil {
returnScriptBuffers()
return err
Expand Down Expand Up @@ -700,6 +733,11 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
scriptPool.Return(pkScript)
}

// Now that we've decoded the entire transaction without any issues,
// we'll cache the non-witness serialization so we can more quickly
// calculate the TxHash in the future.
msg.cachedSeralizedNoWitness = rawTxBuf.Bytes()

return nil
}

Expand Down
53 changes: 53 additions & 0 deletions wire/msgtx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,13 @@ func TestTxHash(t *testing.T) {
t.Errorf("TxHash: wrong hash - got %v, want %v",
spew.Sprint(txHash), spew.Sprint(wantHash))
}

// Compute it again to ensure any cached elements are valid.
txHash = msgTx.TxHash()
if !txHash.IsEqual(wantHash) {
t.Errorf("TxHash: wrong hash - got %v, want %v",
spew.Sprint(txHash), spew.Sprint(wantHash))
}
}

// TestTxSha tests the ability to generate the wtxid, and txid of a transaction
Expand Down Expand Up @@ -258,6 +265,18 @@ func TestWTxSha(t *testing.T) {
t.Errorf("WTxSha: wrong hash - got %v, want %v",
spew.Sprint(wtxid), spew.Sprint(wantHashWTxid))
}

// Compute the values again to ensure any cached elements are valid.
txid = msgTx.TxHash()
if !txid.IsEqual(wantHashTxid) {
t.Errorf("TxSha: wrong hash - got %v, want %v",
spew.Sprint(txid), spew.Sprint(wantHashTxid))
}
wtxid = msgTx.WitnessHash()
if !wtxid.IsEqual(wantHashWTxid) {
t.Errorf("WTxSha: wrong hash - got %v, want %v",
spew.Sprint(wtxid), spew.Sprint(wantHashWTxid))
}
}

// TestTxWire tests the MsgTx wire encode and decode for various numbers
Expand Down Expand Up @@ -393,6 +412,23 @@ func TestTxWire(t *testing.T) {
t.Errorf("BtcDecode #%d error %v", i, err)
continue
}

// If this is the base encoding, then ensure that the cached
// serialization properly matches the raw encoding.
if test.enc == BaseEncoding {
if !bytes.Equal(
test.buf, msg.cachedSeralizedNoWitness,
) {
t.Errorf("BtcdDecode #%d: cached encoding "+
"is wrong, expected %x got %x", i,
test.buf,
msg.cachedSeralizedNoWitness)
continue
}
}

msg.cachedSeralizedNoWitness = nil

if !reflect.DeepEqual(&msg, test.out) {
t.Errorf("BtcDecode #%d\n got: %s want: %s", i,
spew.Sdump(&msg), spew.Sdump(test.out))
Expand Down Expand Up @@ -539,6 +575,23 @@ func TestTxSerialize(t *testing.T) {
t.Errorf("Deserialize #%d error %v", i, err)
continue
}

// Ensure that the raw non-witness encoding matches the cached
// non-witness encoding bytes.
var b bytes.Buffer
if err := tx.SerializeNoWitness(&b); err != nil {
t.Errorf("Deserialize #%d: unable to encode: %v", i, err)
}
if !bytes.Equal(b.Bytes(), tx.cachedSeralizedNoWitness) {
t.Errorf("Deserialize #%d: cached encoding "+
"is wrong, expected %x got %x", i,
b.Bytes(),
tx.cachedSeralizedNoWitness)
continue
}

tx.cachedSeralizedNoWitness = nil

if !reflect.DeepEqual(&tx, test.out) {
t.Errorf("Deserialize #%d\n got: %s want: %s", i,
spew.Sdump(&tx), spew.Sdump(test.out))
Expand Down

0 comments on commit 0924825

Please sign in to comment.