Skip to content

Commit

Permalink
wire/msgblock+msgtx: user block-level script slab
Browse files Browse the repository at this point in the history
  • Loading branch information
cfromknecht committed Jan 25, 2020
1 parent 7458cea commit dacbb65
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 19 deletions.
13 changes: 9 additions & 4 deletions wire/msgblock.go
Expand Up @@ -86,17 +86,19 @@ func (msg *MsgBlock) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) er
return messageError("MsgBlock.BtcDecode", str)
}

scriptBuf := scriptPool.Borrow()
msg.Transactions = make([]*MsgTx, 0, txCount)
for i := uint64(0); i < txCount; i++ {
tx := MsgTx{}
err := tx.btcDecode(r, pver, enc, buf)
err := tx.btcDecode(r, pver, enc, buf, scriptBuf[:])
if err != nil {
scriptPool.Return(scriptBuf)
binarySerializer.Return(buf)
return err
}
msg.Transactions = append(msg.Transactions, &tx)
}

scriptPool.Return(scriptBuf)
binarySerializer.Return(buf)

return nil
Expand Down Expand Up @@ -164,22 +166,25 @@ func (msg *MsgBlock) DeserializeTxLoc(r *bytes.Buffer) ([]TxLoc, error) {
return nil, messageError("MsgBlock.DeserializeTxLoc", str)
}

scriptBuf := scriptPool.Borrow()

// Deserialize each transaction while keeping track of its location
// within the byte stream.
msg.Transactions = make([]*MsgTx, 0, txCount)
txLocs := make([]TxLoc, txCount)
for i := uint64(0); i < txCount; i++ {
txLocs[i].TxStart = fullLen - r.Len()
tx := MsgTx{}
err := tx.btcDecode(r, 0, WitnessEncoding, buf)
err := tx.btcDecode(r, 0, WitnessEncoding, buf, scriptBuf[:])
if err != nil {
scriptPool.Return(scriptBuf)
binarySerializer.Return(buf)
return nil, err
}
msg.Transactions = append(msg.Transactions, &tx)
txLocs[i].TxLen = (fullLen - r.Len()) - txLocs[i].TxStart
}

scriptPool.Return(scriptBuf)
binarySerializer.Return(buf)

return txLocs, nil
Expand Down
19 changes: 4 additions & 15 deletions wire/msgtx.go
Expand Up @@ -404,13 +404,15 @@ func (msg *MsgTx) Copy() *MsgTx {
// database, as opposed to decoding transactions from the wire.
func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error {
buf := binarySerializer.Borrow()
err := msg.btcDecode(r, pver, enc, buf)
sbuf := scriptPool.Borrow()
err := msg.btcDecode(r, pver, enc, buf, sbuf[:])
scriptPool.Return(sbuf)
binarySerializer.Return(buf)
return err
}

func (msg *MsgTx) btcDecode(r io.Reader, pver uint32, enc MessageEncoding,
buf []byte) error {
buf, sbuf []byte) error {

if _, err := io.ReadFull(r, buf[:4]); err != nil {
return err
Expand Down Expand Up @@ -456,9 +458,6 @@ func (msg *MsgTx) btcDecode(r io.Reader, pver uint32, enc MessageEncoding,
return messageError("MsgTx.BtcDecode", str)
}

scriptBuf := scriptPool.Borrow()
sbuf := scriptBuf[:]

// Deserialize the inputs.
var totalScriptSize uint64
txIns := make([]TxIn, count)
Expand All @@ -470,7 +469,6 @@ func (msg *MsgTx) btcDecode(r io.Reader, pver uint32, enc MessageEncoding,
msg.TxIn[i] = ti
err = readTxInBuf(r, pver, msg.Version, ti, buf, sbuf)
if err != nil {
scriptPool.Return(scriptBuf)
return err
}
totalScriptSize += uint64(len(ti.SignatureScript))
Expand All @@ -479,15 +477,13 @@ func (msg *MsgTx) btcDecode(r io.Reader, pver uint32, enc MessageEncoding,

count, err = ReadVarIntBuf(r, pver, buf)
if err != nil {
scriptPool.Return(scriptBuf)
return err
}

// Prevent more output transactions than could possibly fit into a
// message. It would be possible to cause memory exhaustion and panics
// without a sane upper bound on this count.
if count > uint64(maxTxOutPerMessage) {
scriptPool.Return(scriptBuf)
str := fmt.Sprintf("too many output transactions to fit into "+
"max message size [count %d, max %d]", count,
maxTxOutPerMessage)
Expand All @@ -504,7 +500,6 @@ func (msg *MsgTx) btcDecode(r io.Reader, pver uint32, enc MessageEncoding,
msg.TxOut[i] = to
err = readTxOutBuf(r, pver, msg.Version, to, buf, sbuf)
if err != nil {
scriptPool.Return(scriptBuf)
return err
}
totalScriptSize += uint64(len(to.PkScript))
Expand All @@ -520,14 +515,12 @@ func (msg *MsgTx) btcDecode(r io.Reader, pver uint32, enc MessageEncoding,
// varint which encodes the number of stack items.
witCount, err := ReadVarIntBuf(r, pver, buf)
if err != nil {
scriptPool.Return(scriptBuf)
return err
}

// Prevent a possible memory exhaustion attack by
// limiting the witCount value to a sane upper bound.
if witCount > maxWitnessItemsPerInput {
scriptPool.Return(scriptBuf)
str := fmt.Sprintf("too many witness items to fit "+
"into max message size [count %d, max %d]",
witCount, maxWitnessItemsPerInput)
Expand All @@ -542,7 +535,6 @@ func (msg *MsgTx) btcDecode(r io.Reader, pver uint32, enc MessageEncoding,
txin.Witness[j], err = readScriptBuf(r, pver, buf, sbuf,
maxWitnessItemSize, "script witness item")
if err != nil {
scriptPool.Return(scriptBuf)
return err
}
totalScriptSize += uint64(len(txin.Witness[j]))
Expand All @@ -552,7 +544,6 @@ func (msg *MsgTx) btcDecode(r io.Reader, pver uint32, enc MessageEncoding,
}

if _, err := io.ReadFull(r, buf[:4]); err != nil {
scriptPool.Return(scriptBuf)
return err
}
msg.LockTime = littleEndian.Uint32(buf[:4])
Expand Down Expand Up @@ -615,8 +606,6 @@ func (msg *MsgTx) btcDecode(r io.Reader, pver uint32, enc MessageEncoding,
offset += scriptSize
}

scriptPool.Return(scriptBuf)

return nil
}

Expand Down

0 comments on commit dacbb65

Please sign in to comment.