From 82532113143a80285d9053090ed80bb2bf566c4a Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Fri, 16 Jun 2023 14:30:27 +0700
Subject: [PATCH 001/119] Implement RLP struct processing for encoding/decoding

---
 rlp/internal/rlpstruct/rlpstruct.go | 213 ++++++++++++++++++++++++++++
 1 file changed, 213 insertions(+)
 create mode 100644 rlp/internal/rlpstruct/rlpstruct.go

diff --git a/rlp/internal/rlpstruct/rlpstruct.go b/rlp/internal/rlpstruct/rlpstruct.go
new file mode 100644
index 0000000000..2e3eeb6881
--- /dev/null
+++ b/rlp/internal/rlpstruct/rlpstruct.go
@@ -0,0 +1,213 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package rlpstruct implements struct processing for RLP encoding/decoding.
+//
+// In particular, this package handles all rules around field filtering,
+// struct tags and nil value determination.
+package rlpstruct
+
+import (
+	"fmt"
+	"reflect"
+	"strings"
+)
+
+// Field represents a struct field.
+type Field struct {
+	Name     string
+	Index    int
+	Exported bool
+	Type     Type
+	Tag      string
+}
+
+// Type represents the attributes of a Go type.
+type Type struct {
+	Name      string
+	Kind      reflect.Kind
+	IsEncoder bool  // whether type implements rlp.Encoder
+	IsDecoder bool  // whether type implements rlp.Decoder
+	Elem      *Type // non-nil for Kind values of Ptr, Slice, Array
+}
+
+// DefaultNilValue determines whether a nil pointer to t encodes/decodes
+// as an empty string or empty list.
+func (t Type) DefaultNilValue() NilKind {
+	k := t.Kind
+	if isUint(k) || k == reflect.String || k == reflect.Bool || isByteArray(t) {
+		return NilKindString
+	}
+	return NilKindList
+}
+
+// NilKind is the RLP value encoded in place of nil pointers.
+type NilKind uint8
+
+const (
+	NilKindString NilKind = 0x80
+	NilKindList   NilKind = 0xC0
+)
+
+// Tags represents struct tags.
+type Tags struct {
+	// rlp:"nil" controls whether empty input results in a nil pointer.
+	// nilKind is the kind of empty value allowed for the field.
+	NilKind NilKind
+	NilOK   bool
+
+	// rlp:"optional" allows for a field to be missing in the input list.
+	// If this is set, all subsequent fields must also be optional.
+	Optional bool
+
+	// rlp:"tail" controls whether this field swallows additional list elements. It can
+	// only be set for the last field, which must be of slice type.
+	Tail bool
+
+	// rlp:"-" ignores fields.
+	Ignored bool
+}
+
+// TagError is raised for invalid struct tags.
+type TagError struct {
+	StructType string
+
+	// These are set by this package.
+	Field string
+	Tag   string
+	Err   string
+}
+
+func (e TagError) Error() string {
+	field := "field " + e.Field
+	if e.StructType != "" {
+		field = e.StructType + "." 
+ e.Field + } + return fmt.Sprintf("rlp: invalid struct tag %q for %s (%s)", e.Tag, field, e.Err) +} + +// ProcessFields filters the given struct fields, returning only fields +// that should be considered for encoding/decoding. +func ProcessFields(allFields []Field) ([]Field, []Tags, error) { + lastPublic := lastPublicField(allFields) + + // Gather all exported fields and their tags. + var fields []Field + var tags []Tags + for _, field := range allFields { + if !field.Exported { + continue + } + ts, err := parseTag(field, lastPublic) + if err != nil { + return nil, nil, err + } + if ts.Ignored { + continue + } + fields = append(fields, field) + tags = append(tags, ts) + } + + // Verify optional field consistency. If any optional field exists, + // all fields after it must also be optional. Note: optional + tail + // is supported. + var anyOptional bool + var firstOptionalName string + for i, ts := range tags { + name := fields[i].Name + if ts.Optional || ts.Tail { + if !anyOptional { + firstOptionalName = name + } + anyOptional = true + } else { + if anyOptional { + msg := fmt.Sprintf("must be optional because preceding field %q is optional", firstOptionalName) + return nil, nil, TagError{Field: name, Err: msg} + } + } + } + return fields, tags, nil +} + +func parseTag(field Field, lastPublic int) (Tags, error) { + name := field.Name + tag := reflect.StructTag(field.Tag) + var ts Tags + for _, t := range strings.Split(tag.Get("rlp"), ",") { + switch t = strings.TrimSpace(t); t { + case "": + // empty tag is allowed for some reason + case "-": + ts.Ignored = true + case "nil", "nilString", "nilList": + ts.NilOK = true + if field.Type.Kind != reflect.Ptr { + return ts, TagError{Field: name, Tag: t, Err: "field is not a pointer"} + } + switch t { + case "nil": + ts.NilKind = field.Type.Elem.DefaultNilValue() + case "nilString": + ts.NilKind = NilKindString + case "nilList": + ts.NilKind = NilKindList + } + case "optional": + ts.Optional = true + if ts.Tail { + return ts, TagError{Field: name, Tag: t, Err: `also has "tail" tag`} + } + case "tail": + ts.Tail = true + if field.Index != lastPublic { + return ts, TagError{Field: name, Tag: t, Err: "must be on last field"} + } + if ts.Optional { + return ts, TagError{Field: name, Tag: t, Err: `also has "optional" tag`} + } + if field.Type.Kind != reflect.Slice { + return ts, TagError{Field: name, Tag: t, Err: "field type is not slice"} + } + default: + return ts, TagError{Field: name, Tag: t, Err: "unknown tag"} + } + } + return ts, nil +} + +func lastPublicField(fields []Field) int { + last := 0 + for _, f := range fields { + if f.Exported { + last = f.Index + } + } + return last +} + +func isUint(k reflect.Kind) bool { + return k >= reflect.Uint && k <= reflect.Uintptr +} + +func isByte(typ Type) bool { + return typ.Kind == reflect.Uint8 && !typ.IsEncoder +} + +func isByteArray(typ Type) bool { + return (typ.Kind == reflect.Slice || typ.Kind == reflect.Array) && isByte(*typ.Elem) +} From fa67be8815ab694e88f71a4a31ba11a63282079c Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Fri, 16 Jun 2023 16:01:10 +0700 Subject: [PATCH 002/119] Update RLP lib --- go.mod | 64 ++--- go.sum | 171 ++++++++---- rlp/decode.go | 668 ++++++++++++++++++++++++++++------------------- rlp/doc.go | 149 ++++++++++- rlp/encbuffer.go | 423 ++++++++++++++++++++++++++++++ rlp/encode.go | 596 ++++++++++++++++-------------------------- rlp/iterator.go | 60 +++++ rlp/raw.go | 138 ++++++++++ rlp/safe.go | 27 ++ rlp/typecache.go | 268 ++++++++++++------- rlp/unsafe.go | 35 
+++ 11 files changed, 1775 insertions(+), 824 deletions(-) create mode 100644 rlp/encbuffer.go create mode 100644 rlp/iterator.go create mode 100644 rlp/safe.go create mode 100644 rlp/unsafe.go diff --git a/go.mod b/go.mod index 15d820f802..e5db078ba3 100644 --- a/go.mod +++ b/go.mod @@ -4,44 +4,45 @@ go 1.19 require ( bazil.org/fuse v0.0.0-20180421153158-65cc252bf669 - github.com/VictoriaMetrics/fastcache v1.5.7 + github.com/VictoriaMetrics/fastcache v1.6.0 github.com/aristanetworks/goarista v0.0.0-20191023202215-f096da5361bb github.com/btcsuite/btcd v0.0.0-20171128150713-2e60448ffcc6 github.com/cespare/cp v1.1.1 github.com/davecgh/go-spew v1.1.1 github.com/deckarep/golang-set v0.0.0-20180603214616-504e848d77ea - github.com/docker/docker v1.4.2-0.20180625184442-8e610b2b55bf + github.com/docker/docker v1.6.2 github.com/dop251/goja v0.0.0-20230531210528-d7324b2d74f7 github.com/edsrzf/mmap-go v1.0.0 - github.com/fatih/color v1.6.0 + github.com/fatih/color v1.7.0 github.com/gizak/termui v2.2.0+incompatible github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 - github.com/go-stack/stack v1.8.0 - github.com/golang/protobuf v1.3.2 - github.com/golang/snappy v0.0.1 + github.com/go-stack/stack v1.8.1 + github.com/golang/protobuf v1.5.2 + github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb github.com/hashicorp/golang-lru v0.5.3 - github.com/huin/goupnp v1.0.0 + github.com/holiman/uint256 v1.2.2 + github.com/huin/goupnp v1.0.3 github.com/influxdata/influxdb v1.7.9 - github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458 + github.com/jackpal/go-nat-pmp v1.0.2 github.com/julienschmidt/httprouter v1.3.0 github.com/karalabe/hid v1.0.0 - github.com/mattn/go-colorable v0.1.0 + github.com/mattn/go-colorable v0.1.13 github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416 - github.com/olekukonko/tablewriter v0.0.2-0.20190409134802-7e037d187b0c + github.com/olekukonko/tablewriter v0.0.5 github.com/pborman/uuid v1.2.0 github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7 - github.com/pkg/errors v0.8.1 + github.com/pkg/errors v0.9.1 github.com/prometheus/prometheus v1.7.2-0.20170814170113-3101606756c5 github.com/rjeczalik/notify v0.9.2 - github.com/rs/cors v1.6.0 + github.com/rs/cors v1.7.0 github.com/steakknife/bloomfilter v0.0.0-20180922174646-6819c0d2a570 - github.com/stretchr/testify v1.4.0 - github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d - golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 - golang.org/x/net v0.0.0-20220722155237-a158d28d115b - golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 - golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f - golang.org/x/tools v0.1.12 + github.com/stretchr/testify v1.8.1 + github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 + golang.org/x/crypto v0.1.0 + golang.org/x/net v0.8.0 + golang.org/x/sync v0.1.0 + golang.org/x/sys v0.7.0 + golang.org/x/tools v0.7.0 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c gopkg.in/karalabe/cookiejar.v2 v2.0.0-20150724131613-8dcd6a7f4951 gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce @@ -50,27 +51,28 @@ require ( ) require ( - github.com/cespare/xxhash/v2 v2.1.1 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/dlclark/regexp2 v1.7.0 // indirect + github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect - github.com/google/go-cmp v0.3.1 // indirect github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect - github.com/google/uuid v1.0.0 // 
indirect - github.com/kr/pretty v0.3.0 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/maruel/panicparse v0.0.0-20160720141634-ad661195ed0e // indirect github.com/maruel/ut v1.0.2 // indirect - github.com/mattn/go-isatty v0.0.5-0.20180830101745-3fb116b82035 // indirect - github.com/mattn/go-runewidth v0.0.4 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect + github.com/mattn/go-runewidth v0.0.9 // indirect github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect github.com/naoina/go-stringutil v0.1.0 // indirect github.com/nsf/termbox-go v0.0.0-20170211012700-3540b76b9c77 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.6.1 // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3 // indirect - golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect - golang.org/x/text v0.3.8 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect - gotest.tools v2.2.0+incompatible // indirect + golang.org/x/mod v0.9.0 // indirect + golang.org/x/term v0.6.0 // indirect + golang.org/x/text v0.8.0 // indirect + golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df // indirect + google.golang.org/protobuf v1.28.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 2fff78c90b..c913b65f5a 100644 --- a/go.sum +++ b/go.sum @@ -5,8 +5,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/DataDog/zstd v1.3.6-0.20190409195224-796139022798/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= github.com/Shopify/sarama v1.23.1/go.mod h1:XLH1GYJnLVE0XCr6KdJGVJRTwY30moWNJ4sERjXX6fs= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= -github.com/VictoriaMetrics/fastcache v1.5.7 h1:4y6y0G8PRzszQUYIQHHssv/jgPHAb5qQuuDNdCbyAgw= -github.com/VictoriaMetrics/fastcache v1.5.7/go.mod h1:ptDBkNMQI4RtmVo8VS/XwRY6RoTu1dAWCbrk+6WsEM8= +github.com/VictoriaMetrics/fastcache v1.6.0 h1:C/3Oi3EiBCqufydp1neRZkqcwmEiuRT9c3fqvvgKm5o= +github.com/VictoriaMetrics/fastcache v1.6.0/go.mod h1:0qHz5QP0GMX4pfmMA/zt5RgfNuXJrTP0zS7DqpHGGTw= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8= @@ -23,8 +23,9 @@ github.com/btcsuite/btcd v0.0.0-20171128150713-2e60448ffcc6 h1:Eey/GGQ/E5Xp1P2Ly github.com/btcsuite/btcd v0.0.0-20171128150713-2e60448ffcc6/go.mod h1:Dmm/EzmjnCiweXmzRIAiUWCInVmPgjkzgv5k4tVyXiQ= github.com/cespare/cp v1.1.1 h1:nCb6ZLdB7NRaqsm91JtQTAme2SKJzXVsdPIPkyJr1MU= github.com/cespare/cp v1.1.1/go.mod h1:SOGHArjBr4JWaSDEVpWpo/hNg6RoKrls6Oh40hiwW+s= -github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY= 
github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic= github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -38,8 +39,8 @@ github.com/deckarep/golang-set v0.0.0-20180603214616-504e848d77ea/go.mod h1:93vs github.com/dlclark/regexp2 v1.4.1-0.20201116162257-a2a8dda75c91/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= github.com/dlclark/regexp2 v1.7.0 h1:7lJfhqlPssTb1WQx4yvTHN0uElPEv52sbaECrAQxjAo= github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/docker/docker v1.4.2-0.20180625184442-8e610b2b55bf h1:sh8rkQZavChcmakYiSlqu2425CHyFXLZZnvm7PDpU8M= -github.com/docker/docker v1.4.2-0.20180625184442-8e610b2b55bf/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v1.6.2 h1:HlFGsy+9/xrgMmhmN+NGhCc5SHGJ7I+kHosRR1xc/aI= +github.com/docker/docker v1.6.2/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/dop251/goja v0.0.0-20211022113120-dc8c55024d06/go.mod h1:R9ET47fwRVRPZnOGvHxxhuZcbrMCuiqOz3Rlrh4KSnk= github.com/dop251/goja v0.0.0-20230531210528-d7324b2d74f7 h1:cVGkvrdHgyBkYeB6kMCaF5j2d9Bg4trgbIpcUrKrvk4= github.com/dop251/goja v0.0.0-20230531210528-d7324b2d74f7/go.mod h1:QMWlm50DNe14hD7t24KEqZuUdC9sOTy8W6XbCU1mlw4= @@ -50,9 +51,12 @@ github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1 github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/edsrzf/mmap-go v1.0.0 h1:CEBF7HpRnUCSJgGUb5h1Gm7e3VkmVDrR8lvWVLtrOFw= github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= -github.com/fatih/color v1.6.0 h1:66qjqZk8kalYAvDRtM1AdAJQI0tj4Wrue3Eq3B3pmFU= -github.com/fatih/color v1.6.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= +github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= github.com/gizak/termui v2.2.0+incompatible h1:qvZU9Xll/Xd/Xr/YO+HfBKXhy8a8/94ao6vV9DSXzUE= github.com/gizak/termui v2.2.0+incompatible/go.mod h1:PkJoWUt/zacQKysNfQtcw1RW+eK2SxkieVBtl+4ovLA= @@ -63,40 +67,56 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= -github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw= +github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod 
h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb h1:PBC98N2aIaM3XXiurYmW7fx4GZkL8feAMVq7nEjURHk= +github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U= github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg= -github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/golang-lru v0.5.3 h1:YPkqC67at8FYaadspW/6uE0COsBxS2656RLEr8Bppgk= github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= -github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/holiman/uint256 
v1.2.2 h1:TXKcSGc2WaxPD2+bmzAsVthL4+pEN0YwXcL5qED83vk= +github.com/holiman/uint256 v1.2.2/go.mod h1:SC8Ryt4n+UBbPbIBKaG9zbbDlp4jOru9xFZmPzLUTxw= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/huin/goupnp v1.0.0 h1:wg75sLpL6DZqwHQN6E1Cfk6mtfzS45z8OV+ic+DtHRo= -github.com/huin/goupnp v1.0.0/go.mod h1:n9v9KO1tAxYH82qOn+UTIFQDmx5n1Zxd/ClZDMX7Bnc= +github.com/huin/goupnp v1.0.3 h1:N8No57ls+MnjlB+JPiCVSOyy/ot7MJTqlo7rn+NYSqQ= +github.com/huin/goupnp v1.0.3/go.mod h1:ZxNlw5WqJj6wSsRK5+YfflQGXYfccj5VgQsMNixHM7Y= github.com/huin/goutil v0.0.0-20170803182201-1ca381bf3150/go.mod h1:PpLOETDnJ0o3iZrZfqZzyLl6l7F3c6L1oWn7OICBi6o= github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/influxdata/influxdb v1.7.9 h1:uSeBTNO4rBkbp1Be5FKRsAmglM9nlx25TzVQRQt1An4= github.com/influxdata/influxdb v1.7.9/go.mod h1:qZna6X/4elxqT3yI9iZYdZrWWdeFOOprn86kgg4+IzY= github.com/influxdata/influxdb1-client v0.0.0-20190809212627-fc22c7df067e/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= -github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458 h1:6OvNmYgJyexcZ3pYbTI9jWx5tHo1Dee/tWbLMfPe2TA= -github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= +github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= +github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= github.com/jcmturner/gofork v0.0.0-20190328161633-dc7c13fece03/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -111,8 +131,9 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -123,13 +144,13 @@ github.com/maruel/panicparse v0.0.0-20160720141634-ad661195ed0e h1:e2z/lz9pvtRrE github.com/maruel/panicparse v0.0.0-20160720141634-ad661195ed0e/go.mod h1:nty42YY5QByNC5MM7q/nj938VbgPU7avs45z6NClpxI= github.com/maruel/ut v1.0.2 h1:mQTlQk3jubTbdTcza+hwoZQWhzcvE4L6K6RTtAFlA1k= github.com/maruel/ut v1.0.2/go.mod h1:RV8PwPD9dd2KFlnlCc/DB2JVvkXmyaalfc5xvmSrRSs= -github.com/mattn/go-colorable v0.1.0 h1:v2XXALHHh6zHfYTJ+cSkwtyffnaOyR1MXaA91mTrb8o= -github.com/mattn/go-colorable v0.1.0/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= -github.com/mattn/go-isatty v0.0.5-0.20180830101745-3fb116b82035 h1:USWjF42jDCSEeikX/G1g40ZWnsPXN5WkZ4jMHZWyBK4= -github.com/mattn/go-isatty v0.0.5-0.20180830101745-3fb116b82035/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= 
+github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= -github.com/mattn/go-runewidth v0.0.4 h1:2BvfKmzob6Bmd4YsL0zygOqfdFnK7GR4QL06Do4/p7Y= -github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= +github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 h1:DpOJ2HYzCv8LZP15IdmG+YdwD2luVPHITV96TkirNBM= github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= @@ -144,15 +165,19 @@ github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416 h1:shk/vn9oCoOTmwcou github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416/go.mod h1:NBIhNtsFMo3G2szEBne+bO4gS192HuIYRqfvOWb4i1E= github.com/nsf/termbox-go v0.0.0-20170211012700-3540b76b9c77 h1:gKl78uP/I7JZ56OFtRf7nc4m1icV38hwV0In5pEGzeA= github.com/nsf/termbox-go v0.0.0-20170211012700-3540b76b9c77/go.mod h1:IuKpRQcYE1Tfu+oAQqaLisqDeXgjyyltCfsaoYN18NQ= -github.com/olekukonko/tablewriter v0.0.2-0.20190409134802-7e037d187b0c h1:1RHs3tNxjXGHeul8z2t6H2N2TlAqpKe5yryJztRx4Jk= -github.com/olekukonko/tablewriter v0.0.2-0.20190409134802-7e037d187b0c/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= +github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.10.1 h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo= github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/onsi/gomega v1.7.0 h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.0 h1:2mOpI4JVVPBN+WQRa0WKH2eXR+Ey+uK4n7Zj0aYpIQA= +github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1 h1:o0+MgICZLuZ7xjH7Vx6zS/zcu93/BEp1VwkIW1mEXCE= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/openconfig/gnmi v0.0.0-20190823184014-89b2bf29312c/go.mod h1:t+O9It+LKzfOAhKTT5O0ehDix+MTqbtT0T9t+7zzOvc= github.com/openconfig/reference v0.0.0-20190727015836-8dfd928c9696/go.mod h1:ym2A+zigScwkSEb/cVQB0/ZMpU3rqiH6X7WRRsxgOGw= github.com/pborman/uuid 
v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g= @@ -160,9 +185,10 @@ github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtP github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7 h1:oYW+YCJ1pachXTQmzR3rNLYGGz4g/UgFcjb28p/viDM= github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7/go.mod h1:CRroGNssyjTd/qIG2FyxByd2S8JEAZXBl4qUrZf8GS0= github.com/pierrec/lz4 v0.0.0-20190327172049-315a67e90e41/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -181,10 +207,11 @@ github.com/prometheus/prometheus v1.7.2-0.20170814170113-3101606756c5/go.mod h1: github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rjeczalik/notify v0.9.2 h1:MiTWrPj55mNDHEiIX5YUSKefw/+lCQVoAFmD6oQm5w8= github.com/rjeczalik/notify v0.9.2/go.mod h1:aErll2f0sUX9PXZnVNyeiObbmTlk5jnMoCa4QEjJeqM= -github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= -github.com/rs/cors v1.6.0 h1:G9tHG9lebljV9mfp9SNPDL36nCDxmo3zTlAf1YgvzmI= -github.com/rs/cors v1.6.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= +github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/steakknife/bloomfilter v0.0.0-20180922174646-6819c0d2a570 h1:gIlAHnH1vJb5vwEjIp5kBj/eu99p/bl0Ay2goiPe5xE= @@ -193,12 +220,16 @@ github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3 h1:njlZPzLwU639 github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3/go.mod h1:hpGUWaI9xL8pRQCTXQgocU38Qw1g0Us7n5PxxTwTCYU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod 
h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d h1:gZZadD8H+fF+n9CmNhYL1Y0dJB+kLOmKd7FbPJLeGHs= -github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d/go.mod h1:9OrXJhf154huy1nPWmuSrkgjPUtUNhA+Zmy+6AESzuA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY= +github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= github.com/templexxx/cpufeat v0.0.0-20180724012125-cef66df7f161/go.mod h1:wM7WEvslTq+iOEAMDLSzhVuOt5BRZ05WirO+b09GHQU= github.com/templexxx/xor v0.0.0-20181023030647-4e92f724b73b/go.mod h1:5XA7W9S6mni3h5uvOC75dA3m9CCCaS83lltmc0ukdi4= github.com/tjfoc/gmsm v1.0.1/go.mod h1:XxO4hdhhrzAd+G4CjDqaOkd0hUzmtPR/d3EiBBMn/wc= @@ -210,64 +241,100 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5t golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190404164418-38d8ce5564a5/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= +golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= +golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181011144130-49bb7cea24b1/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190912160710-24e19bdeb0f2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod 
h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180926160741-c2ed4eda69e7/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190801041406-cbf593c0f2f3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190912141932-bc967efca4b8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190912185636-87d9f09c5d89/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4= +golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df h1:5Pf6pFKu98ODmgnpvkJ3kFUOQGGLIzLIkbzUHp47618= +golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= 
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= +google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/bsm/ratelimit.v1 v1.0.0-20160220154919-db14e161995a/go.mod h1:KF9sEfUPAXdG8Oev9e99iLGnl2uJMjc5B+4y3O7x610= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -275,7 +342,6 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo= gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q= @@ -295,8 +361,11 @@ gopkg.in/urfave/cli.v1 v1.20.0 h1:NdAVW6RYxDif9DhDHaAortIu956m2c0v+09AZBPTbE0= gopkg.in/urfave/cli.v1 v1.20.0/go.mod h1:vuBzUtMdQeixQj8LVd+/98pzhxNGQoyuPBlsXHOQNO0= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= -gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod 
h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/rlp/decode.go b/rlp/decode.go index 60d9dab2b5..20c454ca9c 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -26,100 +26,78 @@ import ( "math/big" "reflect" "strings" + "sync" + + "github.com/tomochain/tomochain/rlp/internal/rlpstruct" + + "github.com/holiman/uint256" ) +//lint:ignore ST1012 EOL is not an error. + +// EOL is returned when the end of the current list +// has been reached during streaming. +var EOL = errors.New("rlp: end of list") + var ( + ErrExpectedString = errors.New("rlp: expected String or Byte") + ErrExpectedList = errors.New("rlp: expected List") + ErrCanonInt = errors.New("rlp: non-canonical integer format") + ErrCanonSize = errors.New("rlp: non-canonical size information") + ErrElemTooLarge = errors.New("rlp: element is larger than containing list") + ErrValueTooLarge = errors.New("rlp: value size exceeds available input length") + ErrMoreThanOneValue = errors.New("rlp: input contains more than one value") + + // internal errors + errNotInList = errors.New("rlp: call of ListEnd outside of any list") + errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") + errUintOverflow = errors.New("rlp: uint overflow") errNoPointer = errors.New("rlp: interface given to Decode must be a pointer") errDecodeIntoNil = errors.New("rlp: pointer given to Decode must not be nil") + errUint256Large = errors.New("rlp: value too large for uint256") + + streamPool = sync.Pool{ + New: func() interface{} { return new(Stream) }, + } ) -// Decoder is implemented by types that require custom RLP -// decoding rules or need to decode into private fields. +// Decoder is implemented by types that require custom RLP decoding rules or need to decode +// into private fields. // -// The DecodeRLP method should read one value from the given -// Stream. It is not forbidden to read less or more, but it might -// be confusing. +// The DecodeRLP method should read one value from the given Stream. It is not forbidden to +// read less or more, but it might be confusing. type Decoder interface { DecodeRLP(*Stream) error } -// Decode parses RLP-encoded data from r and stores the result in the -// value pointed to by val. Val must be a non-nil pointer. If r does -// not implement ByteReader, Decode will do its own buffering. -// -// Decode uses the following type-dependent decoding rules: -// -// If the type implements the Decoder interface, decode calls -// DecodeRLP. -// -// To decode into a pointer, Decode will decode into the value pointed -// to. If the pointer is nil, a new value of the pointer's element -// type is allocated. If the pointer is non-nil, the existing value -// will be reused. -// -// To decode into a struct, Decode expects the input to be an RLP -// list. The decoded elements of the list are assigned to each public -// field in the order given by the struct's definition. The input list -// must contain an element for each decoded field. Decode returns an -// error if there are too few or too many elements. -// -// The decoding of struct fields honours certain struct tags, "tail", -// "nil" and "-". -// -// The "-" tag ignores fields. -// -// For an explanation of "tail", see the example. -// -// The "nil" tag applies to pointer-typed fields and changes the decoding -// rules for the field such that input values of size zero decode as a nil -// pointer. This tag can be useful when decoding recursive types. 
-//
-//	type StructWithEmptyOK struct {
-//		Foo *[20]byte `rlp:"nil"`
-//	}
-//
-// To decode into a slice, the input must be a list and the resulting
-// slice will contain the input elements in order. For byte slices,
-// the input must be an RLP string. Array types decode similarly, with
-// the additional restriction that the number of input elements (or
-// bytes) must match the array's length.
-//
-// To decode into a Go string, the input must be an RLP string. The
-// input bytes are taken as-is and will not necessarily be valid UTF-8.
-//
-// To decode into an unsigned integer type, the input must also be an RLP
-// string. The bytes are interpreted as a big endian representation of
-// the integer. If the RLP string is larger than the bit size of the
-// type, Decode will return an error. Decode also supports *big.Int.
-// There is no size limit for big integers.
-//
-// To decode into an interface value, Decode stores one of these
-// in the value:
-//
-//	[]interface{}, for RLP lists
-//	[]byte, for RLP strings
+// Decode parses RLP-encoded data from r and stores the result in the value pointed to by
+// val. Please see package-level documentation for the decoding rules. Val must be a
+// non-nil pointer.
 //
-// Non-empty interface types are not supported, nor are booleans,
-// signed integers, floating point numbers, maps, channels and
-// functions.
+// If r does not implement ByteReader, Decode will do its own buffering.
 //
-// Note that Decode does not set an input limit for all readers
-// and may be vulnerable to panics cause by huge value sizes. If
-// you need an input limit, use
+// Note that Decode does not set an input limit for all readers and may be vulnerable to
+// panics caused by huge value sizes. If you need an input limit, use
 //
-//	NewStream(r, limit).Decode(val)
+//	NewStream(r, limit).Decode(val)
 func Decode(r io.Reader, val interface{}) error {
-	// TODO: this could use a Stream from a pool.
-	return NewStream(r, 0).Decode(val)
+	stream := streamPool.Get().(*Stream)
+	defer streamPool.Put(stream)
+
+	stream.Reset(r, 0)
+	return stream.Decode(val)
 }
 
-// DecodeBytes parses RLP data from b into val.
-// Please see the documentation of Decode for the decoding rules.
-// The input must contain exactly one value and no trailing data.
+// DecodeBytes parses RLP data from b into val. Please see package-level documentation for
+// the decoding rules. The input must contain exactly one value and no trailing data.
 func DecodeBytes(b []byte, val interface{}) error {
-	// TODO: this could use a Stream from a pool.
r := bytes.NewReader(b) - if err := NewStream(r, uint64(len(b))).Decode(val); err != nil { + + stream := streamPool.Get().(*Stream) + defer streamPool.Put(stream) + + stream.Reset(r, uint64(len(b))) + if err := stream.Decode(val); err != nil { return err } if r.Len() > 0 { @@ -173,21 +151,26 @@ func addErrorContext(err error, ctx string) error { var ( decoderInterface = reflect.TypeOf(new(Decoder)).Elem() bigInt = reflect.TypeOf(big.Int{}) + u256Int = reflect.TypeOf(uint256.Int{}) ) -func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) { +func makeDecoder(typ reflect.Type, tags rlpstruct.Tags) (dec decoder, err error) { kind := typ.Kind() switch { case typ == rawValueType: return decodeRawValue, nil - case typ.Implements(decoderInterface): - return decodeDecoder, nil - case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(decoderInterface): - return decodeDecoderNoPtr, nil case typ.AssignableTo(reflect.PtrTo(bigInt)): return decodeBigInt, nil case typ.AssignableTo(bigInt): return decodeBigIntNoPtr, nil + case typ == reflect.PtrTo(u256Int): + return decodeU256, nil + case typ == u256Int: + return decodeU256NoPtr, nil + case kind == reflect.Ptr: + return makePtrDecoder(typ, tags) + case reflect.PtrTo(typ).Implements(decoderInterface): + return decodeDecoder, nil case isUint(kind): return decodeUint, nil case kind == reflect.Bool: @@ -198,11 +181,6 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) { return makeListDecoder(typ, tags) case kind == reflect.Struct: return makeStructDecoder(typ) - case kind == reflect.Ptr: - if tags.nilOK { - return makeOptionalPtrDecoder(typ) - } - return makePtrDecoder(typ) case kind == reflect.Interface: return decodeInterface, nil default: @@ -252,35 +230,48 @@ func decodeBigIntNoPtr(s *Stream, val reflect.Value) error { } func decodeBigInt(s *Stream, val reflect.Value) error { - b, err := s.Bytes() + i := val.Interface().(*big.Int) + if i == nil { + i = new(big.Int) + val.Set(reflect.ValueOf(i)) + } + + err := s.decodeBigInt(i) if err != nil { return wrapStreamError(err, val.Type()) } - i := val.Interface().(*big.Int) + return nil +} + +func decodeU256NoPtr(s *Stream, val reflect.Value) error { + return decodeU256(s, val.Addr()) +} + +func decodeU256(s *Stream, val reflect.Value) error { + i := val.Interface().(*uint256.Int) if i == nil { - i = new(big.Int) + i = new(uint256.Int) val.Set(reflect.ValueOf(i)) } - // Reject leading zero bytes - if len(b) > 0 && b[0] == 0 { - return wrapStreamError(ErrCanonInt, val.Type()) + + err := s.ReadUint256(i) + if err != nil { + return wrapStreamError(err, val.Type()) } - i.SetBytes(b) return nil } -func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) { +func makeListDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) { etype := typ.Elem() if etype.Kind() == reflect.Uint8 && !reflect.PtrTo(etype).Implements(decoderInterface) { if typ.Kind() == reflect.Array { return decodeByteArray, nil - } else { - return decodeByteSlice, nil } + return decodeByteSlice, nil } - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{}) + if etypeinfo.decoderErr != nil { + return nil, etypeinfo.decoderErr } var dec decoder switch { @@ -288,7 +279,7 @@ func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) { dec = func(s *Stream, val reflect.Value) error { return decodeListArray(s, val, etypeinfo.decoder) } - case tag.tail: + case tag.Tail: // A slice with "tail" tag 
can occur as the last field // of a struct and is supposed to swallow all remaining // list elements. The struct decoder already called s.List, @@ -381,25 +372,23 @@ func decodeByteArray(s *Stream, val reflect.Value) error { if err != nil { return err } - vlen := val.Len() + slice := byteArrayBytes(val, val.Len()) switch kind { case Byte: - if vlen == 0 { + if len(slice) == 0 { return &decodeError{msg: "input string too long", typ: val.Type()} - } - if vlen > 1 { + } else if len(slice) > 1 { return &decodeError{msg: "input string too short", typ: val.Type()} } - bv, _ := s.Uint() - val.Index(0).SetUint(bv) + slice[0] = s.byteval + s.kind = -1 case String: - if uint64(vlen) < size { + if uint64(len(slice)) < size { return &decodeError{msg: "input string too long", typ: val.Type()} } - if uint64(vlen) > size { + if uint64(len(slice)) > size { return &decodeError{msg: "input string too short", typ: val.Type()} } - slice := val.Slice(0, vlen).Interface().([]byte) if err := s.readFull(slice); err != nil { return err } @@ -418,13 +407,25 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { if err != nil { return nil, err } + for _, f := range fields { + if f.info.decoderErr != nil { + return nil, structFieldError{typ, f.index, f.info.decoderErr} + } + } dec := func(s *Stream, val reflect.Value) (err error) { if _, err := s.List(); err != nil { return wrapStreamError(err, typ) } - for _, f := range fields { + for i, f := range fields { err := f.info.decoder(s, val.Field(f.index)) if err == EOL { + if f.optional { + // The field is optional, so reaching the end of the list before + // reaching the last field is acceptable. All remaining undecoded + // fields are zeroed. + zeroFields(val, fields[i:]) + break + } return &decodeError{msg: "too few elements", typ: typ} } else if err != nil { return addErrorContext(err, "."+typ.Field(f.index).Name) @@ -435,15 +436,29 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { return dec, nil } -// makePtrDecoder creates a decoder that decodes into -// the pointer's element type. -func makePtrDecoder(typ reflect.Type) (decoder, error) { +func zeroFields(structval reflect.Value, fields []field) { + for _, f := range fields { + fv := structval.Field(f.index) + fv.Set(reflect.Zero(fv.Type())) + } +} + +// makePtrDecoder creates a decoder that decodes into the pointer's element type. +func makePtrDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) { etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{}) + switch { + case etypeinfo.decoderErr != nil: + return nil, etypeinfo.decoderErr + case !tag.NilOK: + return makeSimplePtrDecoder(etype, etypeinfo), nil + default: + return makeNilPtrDecoder(etype, etypeinfo, tag), nil } - dec := func(s *Stream, val reflect.Value) (err error) { +} + +func makeSimplePtrDecoder(etype reflect.Type, etypeinfo *typeinfo) decoder { + return func(s *Stream, val reflect.Value) (err error) { newval := val if val.IsNil() { newval = reflect.New(etype) @@ -453,30 +468,39 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) { } return err } - return dec, nil } -// makeOptionalPtrDecoder creates a decoder that decodes empty values -// as nil. Non-empty values are decoded into a value of the element type, -// just like makePtrDecoder does. +// makeNilPtrDecoder creates a decoder that decodes empty values as nil. 
Non-empty +// values are decoded into a value of the element type, just like makePtrDecoder does. // // This decoder is used for pointer-typed struct fields with struct tag "nil". -func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { - etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err - } - dec := func(s *Stream, val reflect.Value) (err error) { +func makeNilPtrDecoder(etype reflect.Type, etypeinfo *typeinfo, ts rlpstruct.Tags) decoder { + typ := reflect.PtrTo(etype) + nilPtr := reflect.Zero(typ) + + // Determine the value kind that results in nil pointer. + nilKind := typeNilKind(etype, ts) + + return func(s *Stream, val reflect.Value) (err error) { kind, size, err := s.Kind() - if err != nil || size == 0 && kind != Byte { + if err != nil { + val.Set(nilPtr) + return wrapStreamError(err, typ) + } + // Handle empty values as a nil pointer. + if kind != Byte && size == 0 { + if kind != nilKind { + return &decodeError{ + msg: fmt.Sprintf("wrong kind of empty value (got %v, want %v)", kind, nilKind), + typ: typ, + } + } // rearm s.Kind. This is important because the input // position must advance to the next value even though // we don't read anything. s.kind = -1 - // set the pointer to nil. - val.Set(reflect.Zero(typ)) - return err + val.Set(nilPtr) + return nil } newval := val if val.IsNil() { @@ -487,7 +511,6 @@ func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { } return err } - return dec, nil } var ifsliceType = reflect.TypeOf([]interface{}{}) @@ -516,25 +539,12 @@ func decodeInterface(s *Stream, val reflect.Value) error { return nil } -// This decoder is used for non-pointer values of types -// that implement the Decoder interface using a pointer receiver. -func decodeDecoderNoPtr(s *Stream, val reflect.Value) error { - return val.Addr().Interface().(Decoder).DecodeRLP(s) -} - func decodeDecoder(s *Stream, val reflect.Value) error { - // Decoder instances are not handled using the pointer rule if the type - // implements Decoder with pointer receiver (i.e. always) - // because it might handle empty values specially. - // We need to allocate one here in this case, like makePtrDecoder does. - if val.Kind() == reflect.Ptr && val.IsNil() { - val.Set(reflect.New(val.Type().Elem())) - } - return val.Interface().(Decoder).DecodeRLP(s) + return val.Addr().Interface().(Decoder).DecodeRLP(s) } // Kind represents the kind of value contained in an RLP stream. -type Kind int +type Kind int8 const ( Byte Kind = iota @@ -555,29 +565,6 @@ func (k Kind) String() string { } } -var ( - // EOL is returned when the end of the current list - // has been reached during streaming. - EOL = errors.New("rlp: end of list") - - // Actual Errors - ErrExpectedString = errors.New("rlp: expected String or Byte") - ErrExpectedList = errors.New("rlp: expected List") - ErrCanonInt = errors.New("rlp: non-canonical integer format") - ErrCanonSize = errors.New("rlp: non-canonical size information") - ErrElemTooLarge = errors.New("rlp: element is larger than containing list") - ErrValueTooLarge = errors.New("rlp: value size exceeds available input length") - - // This error is reported by DecodeBytes if the slice contains - // additional data after the first RLP value. 
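// (Example, per DecodeBytes above: decoding the concatenated input
// 0xC0 0xC0 fails with ErrMoreThanOneValue, because bytes remain in the
// reader after the first list has been decoded.)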
- ErrMoreThanOneValue = errors.New("rlp: input contains more than one value") - - // internal errors - errNotInList = errors.New("rlp: call of ListEnd outside of any list") - errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") - errUintOverflow = errors.New("rlp: uint overflow") -) - // ByteReader must be implemented by any input reader for a Stream. It // is implemented by e.g. bufio.Reader and bytes.Reader. type ByteReader interface { @@ -600,22 +587,16 @@ type ByteReader interface { type Stream struct { r ByteReader - // number of bytes remaining to be read from r. - remaining uint64 - limited bool - - // auxiliary buffer for integer decoding - uintbuf []byte - - kind Kind // kind of value ahead - size uint64 // size of value ahead - byteval byte // value of single byte in type tag - kinderr error // error from last readKind - stack []listpos + remaining uint64 // number of bytes remaining to be read from r + size uint64 // size of value ahead + kinderr error // error from last readKind + stack []uint64 // list sizes + uintbuf [32]byte // auxiliary buffer for integer decoding + kind Kind // kind of value ahead + byteval byte // value of single byte in type tag + limited bool // true if input limit is in effect } -type listpos struct{ pos, size uint64 } - // NewStream creates a new decoding stream reading from r. // // If r implements the ByteReader interface, Stream will @@ -675,6 +656,37 @@ func (s *Stream) Bytes() ([]byte, error) { } } +// ReadBytes decodes the next RLP value and stores the result in b. +// The value size must match len(b) exactly. +func (s *Stream) ReadBytes(b []byte) error { + kind, size, err := s.Kind() + if err != nil { + return err + } + switch kind { + case Byte: + if len(b) != 1 { + return fmt.Errorf("input value has wrong size 1, want %d", len(b)) + } + b[0] = s.byteval + s.kind = -1 // rearm Kind + return nil + case String: + if uint64(len(b)) != size { + return fmt.Errorf("input value has wrong size %d, want %d", size, len(b)) + } + if err = s.readFull(b); err != nil { + return err + } + if size == 1 && b[0] < 128 { + return ErrCanonSize + } + return nil + default: + return ErrExpectedString + } +} + // Raw reads a raw encoded value including RLP type information. func (s *Stream) Raw() ([]byte, error) { kind, size, err := s.Kind() @@ -685,8 +697,8 @@ func (s *Stream) Raw() ([]byte, error) { s.kind = -1 // rearm Kind return []byte{s.byteval}, nil } - // the original header has already been read and is no longer - // available. read content and put a new header in front of it. + // The original header has already been read and is no longer + // available. Read content and put a new header in front of it. start := headsize(size) buf := make([]byte, uint64(start)+size) if err := s.readFull(buf[start:]); err != nil { @@ -703,10 +715,31 @@ func (s *Stream) Raw() ([]byte, error) { // Uint reads an RLP string of up to 8 bytes and returns its contents // as an unsigned integer. If the input does not contain an RLP string, the // returned error will be ErrExpectedString. +// +// Deprecated: use s.Uint64 instead. 
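// As a minimal, non-normative sketch of how the streaming API above is
// driven by hand (import path as used elsewhere in this patch,
// github.com/tomochain/tomochain/rlp):
//
//	input := []byte{0xC5, 0x83, 'c', 'a', 't', 0x01} // the list ["cat", 1]
//	s := rlp.NewStream(bytes.NewReader(input), uint64(len(input)))
//	_, _ = s.List()     // enter the list
//	str, _ := s.Bytes() // "cat"
//	n, _ := s.Uint64()  // 1
//	_ = s.ListEnd()     // must be positioned at the end of the list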
func (s *Stream) Uint() (uint64, error) { return s.uint(64) } +func (s *Stream) Uint64() (uint64, error) { + return s.uint(64) +} + +func (s *Stream) Uint32() (uint32, error) { + i, err := s.uint(32) + return uint32(i), err +} + +func (s *Stream) Uint16() (uint16, error) { + i, err := s.uint(16) + return uint16(i), err +} + +func (s *Stream) Uint8() (uint8, error) { + i, err := s.uint(8) + return uint8(i), err +} + func (s *Stream) uint(maxbits int) (uint64, error) { kind, size, err := s.Kind() if err != nil { @@ -769,7 +802,14 @@ func (s *Stream) List() (size uint64, err error) { if kind != List { return 0, ErrExpectedList } - s.stack = append(s.stack, listpos{0, size}) + + // Remove size of inner list from outer list before pushing the new size + // onto the stack. This ensures that the remaining outer list size will + // be correct after the matching call to ListEnd. + if inList, limit := s.listLimit(); inList { + s.stack[len(s.stack)-1] = limit - size + } + s.stack = append(s.stack, size) s.kind = -1 s.size = 0 return size, nil @@ -778,22 +818,116 @@ func (s *Stream) List() (size uint64, err error) { // ListEnd returns to the enclosing list. // The input reader must be positioned at the end of a list. func (s *Stream) ListEnd() error { - if len(s.stack) == 0 { + // Ensure that no more data is remaining in the current list. + if inList, listLimit := s.listLimit(); !inList { return errNotInList - } - tos := s.stack[len(s.stack)-1] - if tos.pos != tos.size { + } else if listLimit > 0 { return errNotAtEOL } s.stack = s.stack[:len(s.stack)-1] // pop - if len(s.stack) > 0 { - s.stack[len(s.stack)-1].pos += tos.size - } s.kind = -1 s.size = 0 return nil } +// MoreDataInList reports whether the current list context contains +// more data to be read. +func (s *Stream) MoreDataInList() bool { + _, listLimit := s.listLimit() + return listLimit > 0 +} + +// BigInt decodes an arbitrary-size integer value. +func (s *Stream) BigInt() (*big.Int, error) { + i := new(big.Int) + if err := s.decodeBigInt(i); err != nil { + return nil, err + } + return i, nil +} + +func (s *Stream) decodeBigInt(dst *big.Int) error { + var buffer []byte + kind, size, err := s.Kind() + switch { + case err != nil: + return err + case kind == List: + return ErrExpectedString + case kind == Byte: + buffer = s.uintbuf[:1] + buffer[0] = s.byteval + s.kind = -1 // re-arm Kind + case size == 0: + // Avoid zero-length read. + s.kind = -1 + case size <= uint64(len(s.uintbuf)): + // For integers smaller than s.uintbuf, allocating a buffer + // can be avoided. + buffer = s.uintbuf[:size] + if err := s.readFull(buffer); err != nil { + return err + } + // Reject inputs where single byte encoding should have been used. + if size == 1 && buffer[0] < 128 { + return ErrCanonSize + } + default: + // For large integers, a temporary buffer is needed. + buffer = make([]byte, size) + if err := s.readFull(buffer); err != nil { + return err + } + } + + // Reject leading zero bytes. + if len(buffer) > 0 && buffer[0] == 0 { + return ErrCanonInt + } + // Set the integer bytes. + dst.SetBytes(buffer) + return nil +} + +// ReadUint256 decodes the next value as a uint256. +func (s *Stream) ReadUint256(dst *uint256.Int) error { + var buffer []byte + kind, size, err := s.Kind() + switch { + case err != nil: + return err + case kind == List: + return ErrExpectedString + case kind == Byte: + buffer = s.uintbuf[:1] + buffer[0] = s.byteval + s.kind = -1 // re-arm Kind + case size == 0: + // Avoid zero-length read. 
+ s.kind = -1 + case size <= uint64(len(s.uintbuf)): + // All possible uint256 values fit into s.uintbuf. + buffer = s.uintbuf[:size] + if err := s.readFull(buffer); err != nil { + return err + } + // Reject inputs where single byte encoding should have been used. + if size == 1 && buffer[0] < 128 { + return ErrCanonSize + } + default: + return errUint256Large + } + + // Reject leading zero bytes. + if len(buffer) > 0 && buffer[0] == 0 { + return ErrCanonInt + } + // Set the integer bytes. + dst.SetBytes(buffer) + return nil +} + // Decode decodes a value and stores the result in the value pointed // to by val. Please see the documentation for the Decode function // to learn about the decoding rules. @@ -809,14 +943,14 @@ func (s *Stream) Decode(val interface{}) error { if rval.IsNil() { return errDecodeIntoNil } - info, err := cachedTypeInfo(rtyp.Elem(), tags{}) + decoder, err := cachedDecoder(rtyp.Elem()) if err != nil { return err } - err = info.decoder(s, rval.Elem()) + err = decoder(s, rval.Elem()) if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 { - // add decode target type to error so context has more meaning + // Add decode target type to error so context has more meaning. decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")")) } return err @@ -839,6 +973,9 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) { case *bytes.Reader: s.remaining = uint64(br.Len()) s.limited = true + case *bytes.Buffer: + s.remaining = uint64(br.Len()) + s.limited = true case *strings.Reader: s.remaining = uint64(br.Len()) s.limited = true @@ -857,9 +994,8 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) { s.size = 0 s.kind = -1 s.kinderr = nil - if s.uintbuf == nil { - s.uintbuf = make([]byte, 8) - } + s.byteval = 0 + s.uintbuf = [32]byte{} } // Kind returns the kind and size of the next value in the @@ -874,35 +1010,29 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) { // the value. Subsequent calls to Kind (until the value is decoded) // will not advance the input reader and return cached information. func (s *Stream) Kind() (kind Kind, size uint64, err error) { - var tos *listpos - if len(s.stack) > 0 { - tos = &s.stack[len(s.stack)-1] - } - if s.kind < 0 { - s.kinderr = nil - // Don't read further if we're at the end of the - // innermost list. - if tos != nil && tos.pos == tos.size { - return 0, 0, EOL - } - s.kind, s.size, s.kinderr = s.readKind() - if s.kinderr == nil { - if tos == nil { - // At toplevel, check that the value is smaller - // than the remaining input length. - if s.limited && s.size > s.remaining { - s.kinderr = ErrValueTooLarge - } - } else { - // Inside a list, check that the value doesn't overflow the list. - if s.size > tos.size-tos.pos { - s.kinderr = ErrElemTooLarge - } - } + if s.kind >= 0 { + return s.kind, s.size, s.kinderr + } + + // Check for end of list. This needs to be done here because readKind + // checks against the list size, and would return the wrong error. + inList, listLimit := s.listLimit() + if inList && listLimit == 0 { + return 0, 0, EOL + } + // Read the actual size tag. + s.kind, s.size, s.kinderr = s.readKind() + if s.kinderr == nil { + // Check the data size of the value ahead against input limits. This + // is done here because many decoders require allocating an input + // buffer matching the value size. Checking it here protects those + // decoders from inputs declaring very large value size. 
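// (Example: inside a list with 4 payload bytes remaining, a header
// declaring a 5-byte string makes Kind fail with ErrElemTooLarge before
// any buffer for the value is allocated.)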
+		if inList && s.size > listLimit {
+			s.kinderr = ErrElemTooLarge
+		} else if s.limited && s.size > s.remaining {
+			s.kinderr = ErrValueTooLarge
+		}
 	}
-	// Note: this might return a sticky error generated
-	// by an earlier call to readKind.
 	return s.kind, s.size, s.kinderr
 }
 
@@ -929,37 +1059,35 @@ func (s *Stream) readKind() (kind Kind, size uint64, err error) {
 		s.byteval = b
 		return Byte, 0, nil
 	case b < 0xB8:
-		// Otherwise, if a string is 0-55 bytes long,
-		// the RLP encoding consists of a single byte with value 0x80 plus the
-		// length of the string followed by the string. The range of the first
-		// byte is thus [0x80, 0xB7].
+		// Otherwise, if a string is 0-55 bytes long, the RLP encoding consists
+		// of a single byte with value 0x80 plus the length of the string
+		// followed by the string. The range of the first byte is thus [0x80, 0xB7].
 		return String, uint64(b - 0x80), nil
 	case b < 0xC0:
-		// If a string is more than 55 bytes long, the
-		// RLP encoding consists of a single byte with value 0xB7 plus the length
-		// of the length of the string in binary form, followed by the length of
-		// the string, followed by the string. For example, a length-1024 string
-		// would be encoded as 0xB90400 followed by the string. The range of
-		// the first byte is thus [0xB8, 0xBF].
+		// If a string is more than 55 bytes long, the RLP encoding consists of a
+		// single byte with value 0xB7 plus the length of the length of the
+		// string in binary form, followed by the length of the string, followed
+		// by the string. For example, a length-1024 string would be encoded as
+		// 0xB90400 followed by the string. The range of the first byte is thus
+		// [0xB8, 0xBF].
 		size, err = s.readUint(b - 0xB7)
 		if err == nil && size < 56 {
 			err = ErrCanonSize
 		}
 		return String, size, err
 	case b < 0xF8:
-		// If the total payload of a list
-		// (i.e. the combined length of all its items) is 0-55 bytes long, the
-		// RLP encoding consists of a single byte with value 0xC0 plus the length
-		// of the list followed by the concatenation of the RLP encodings of the
-		// items. The range of the first byte is thus [0xC0, 0xF7].
+		// If the total payload of a list (i.e. the combined length of all its
+		// items) is 0-55 bytes long, the RLP encoding consists of a single byte
+		// with value 0xC0 plus the length of the list followed by the
+		// concatenation of the RLP encodings of the items. The range of the
+		// first byte is thus [0xC0, 0xF7].
 		return List, uint64(b - 0xC0), nil
 	default:
-		// If the total payload of a list is more than 55 bytes long,
-		// the RLP encoding consists of a single byte with value 0xF7
-		// plus the length of the length of the payload in binary
-		// form, followed by the length of the payload, followed by
-		// the concatenation of the RLP encodings of the items. The
-		// range of the first byte is thus [0xF8, 0xFF].
+		// If the total payload of a list is more than 55 bytes long, the RLP
+		// encoding consists of a single byte with value 0xF7 plus the length of
+		// the length of the payload in binary form, followed by the length of
+		// the payload, followed by the concatenation of the RLP encodings of
+		// the items. The range of the first byte is thus [0xF8, 0xFF].
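+		// (Worked example: a 300-byte list payload is encoded as 0xF9 0x01 0x2C
+		// followed by the payload, since 300 = 0x012C requires two length bytes.)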
size, err = s.readUint(b - 0xF7) if err == nil && size < 56 { err = ErrCanonSize @@ -977,23 +1105,24 @@ func (s *Stream) readUint(size byte) (uint64, error) { b, err := s.readByte() return uint64(b), err default: - start := int(8 - size) - for i := 0; i < start; i++ { - s.uintbuf[i] = 0 + buffer := s.uintbuf[:8] + for i := range buffer { + buffer[i] = 0 } - if err := s.readFull(s.uintbuf[start:]); err != nil { + start := int(8 - size) + if err := s.readFull(buffer[start:]); err != nil { return 0, err } - if s.uintbuf[start] == 0 { - // Note: readUint is also used to decode integer - // values. The error needs to be adjusted to become - // ErrCanonInt in this case. + if buffer[start] == 0 { + // Note: readUint is also used to decode integer values. + // The error needs to be adjusted to become ErrCanonInt in this case. return 0, ErrCanonSize } - return binary.BigEndian.Uint64(s.uintbuf), nil + return binary.BigEndian.Uint64(buffer[:]), nil } } +// readFull reads into buf from the underlying stream. func (s *Stream) readFull(buf []byte) (err error) { if err := s.willRead(uint64(len(buf))); err != nil { return err @@ -1004,11 +1133,18 @@ func (s *Stream) readFull(buf []byte) (err error) { n += nn } if err == io.EOF { - err = io.ErrUnexpectedEOF + if n < len(buf) { + err = io.ErrUnexpectedEOF + } else { + // Readers are allowed to give EOF even though the read succeeded. + // In such cases, we discard the EOF, like io.ReadFull() does. + err = nil + } } return err } +// readByte reads a single byte from the underlying stream. func (s *Stream) readByte() (byte, error) { if err := s.willRead(1); err != nil { return 0, err @@ -1020,16 +1156,16 @@ func (s *Stream) readByte() (byte, error) { return b, err } +// willRead is called before any read from the underlying stream. It checks +// n against size limits, and updates the limits if n doesn't overflow them. func (s *Stream) willRead(n uint64) error { s.kind = -1 // rearm Kind - if len(s.stack) > 0 { - // check list overflow - tos := s.stack[len(s.stack)-1] - if n > tos.size-tos.pos { + if inList, limit := s.listLimit(); inList { + if n > limit { return ErrElemTooLarge } - s.stack[len(s.stack)-1].pos += n + s.stack[len(s.stack)-1] = limit - n } if s.limited { if n > s.remaining { @@ -1039,3 +1175,11 @@ func (s *Stream) willRead(n uint64) error { } return nil } + +// listLimit returns the amount of data remaining in the innermost list. +func (s *Stream) listLimit() (inList bool, limit uint64) { + if len(s.stack) == 0 { + return false, 0 + } + return true, s.stack[len(s.stack)-1] +} diff --git a/rlp/doc.go b/rlp/doc.go index b3a81fe232..eeeee9a43a 100644 --- a/rlp/doc.go +++ b/rlp/doc.go @@ -17,17 +17,142 @@ /* Package rlp implements the RLP serialization format. -The purpose of RLP (Recursive Linear Prefix) is to encode arbitrarily -nested arrays of binary data, and RLP is the main encoding method used -to serialize objects in Ethereum. The only purpose of RLP is to encode -structure; encoding specific atomic data types (eg. strings, ints, -floats) is left up to higher-order protocols; in Ethereum integers -must be represented in big endian binary form with no leading zeroes -(thus making the integer value zero equivalent to the empty byte -array). - -RLP values are distinguished by a type tag. The type tag precedes the -value in the input stream and defines the size and kind of the bytes -that follow. 
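+(A standard illustration of the format, for concreteness: the string "dog"
+encodes as 0x83 0x64 0x6F 0x67, the empty string as 0x80, and the list
+["cat", "dog"] as 0xC8 0x83 0x63 0x61 0x74 0x83 0x64 0x6F 0x67, where the
+tag 0xC8 announces a list carrying 8 bytes of payload.)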
+The purpose of RLP (Recursive Linear Prefix) is to encode arbitrarily nested arrays of +binary data, and RLP is the main encoding method used to serialize objects in Ethereum. +The only purpose of RLP is to encode structure; encoding specific atomic data types (eg. +strings, ints, floats) is left up to higher-order protocols. In Ethereum integers must be +represented in big endian binary form with no leading zeroes (thus making the integer +value zero equivalent to the empty string). + +RLP values are distinguished by a type tag. The type tag precedes the value in the input +stream and defines the size and kind of the bytes that follow. + +# Encoding Rules + +Package rlp uses reflection and encodes RLP based on the Go type of the value. + +If the type implements the Encoder interface, Encode calls EncodeRLP. It does not +call EncodeRLP on nil pointer values. + +To encode a pointer, the value being pointed to is encoded. A nil pointer to a struct +type, slice or array always encodes as an empty RLP list unless the slice or array has +element type byte. A nil pointer to any other value encodes as the empty string. + +Struct values are encoded as an RLP list of all their encoded public fields. Recursive +struct types are supported. + +To encode slices and arrays, the elements are encoded as an RLP list of the value's +elements. Note that arrays and slices with element type uint8 or byte are always encoded +as an RLP string. + +A Go string is encoded as an RLP string. + +An unsigned integer value is encoded as an RLP string. Zero always encodes as an empty RLP +string. big.Int values are treated as integers. Signed integers (int, int8, int16, ...) +are not supported and will return an error when encoding. + +Boolean values are encoded as the unsigned integers zero (false) and one (true). + +An interface value encodes as the value contained in the interface. + +Floating point numbers, maps, channels and functions are not supported. + +# Decoding Rules + +Decoding uses the following type-dependent rules: + +If the type implements the Decoder interface, DecodeRLP is called. + +To decode into a pointer, the value will be decoded as the element type of the pointer. If +the pointer is nil, a new value of the pointer's element type is allocated. If the pointer +is non-nil, the existing value will be reused. Note that package rlp never leaves a +pointer-type struct field as nil unless one of the "nil" struct tags is present. + +To decode into a struct, decoding expects the input to be an RLP list. The decoded +elements of the list are assigned to each public field in the order given by the struct's +definition. The input list must contain an element for each decoded field. Decoding +returns an error if there are too few or too many elements for the struct. + +To decode into a slice, the input must be a list and the resulting slice will contain the +input elements in order. For byte slices, the input must be an RLP string. Array types +decode similarly, with the additional restriction that the number of input elements (or +bytes) must match the array's defined length. + +To decode into a Go string, the input must be an RLP string. The input bytes are taken +as-is and will not necessarily be valid UTF-8. + +To decode into an unsigned integer type, the input must also be an RLP string. The bytes +are interpreted as a big endian representation of the integer. If the RLP string is larger +than the bit size of the type, decoding will return an error. Decode also supports +*big.Int. 
There is no size limit for big integers. + +To decode into a boolean, the input must contain an unsigned integer of value zero (false) +or one (true). + +To decode into an interface value, one of these types is stored in the value: + + []interface{}, for RLP lists + []byte, for RLP strings + +Non-empty interface types are not supported when decoding. +Signed integers, floating point numbers, maps, channels and functions cannot be decoded into. + +# Struct Tags + +As with other encoding packages, the "-" tag ignores fields. + + type StructWithIgnoredField struct{ + Ignored uint `rlp:"-"` + Field uint + } + +Go struct values encode/decode as RLP lists. There are two ways of influencing the mapping +of fields to list elements. The "tail" tag, which may only be used on the last exported +struct field, allows slurping up any excess list elements into a slice. + + type StructWithTail struct{ + Field uint + Tail []string `rlp:"tail"` + } + +The "optional" tag says that the field may be omitted if it is zero-valued. If this tag is +used on a struct field, all subsequent public fields must also be declared optional. + +When encoding a struct with optional fields, the output RLP list contains all values up to +the last non-zero optional field. + +When decoding into a struct, optional fields may be omitted from the end of the input +list. For the example below, this means input lists of one, two, or three elements are +accepted. + + type StructWithOptionalFields struct{ + Required uint + Optional1 uint `rlp:"optional"` + Optional2 uint `rlp:"optional"` + } + +The "nil", "nilList" and "nilString" tags apply to pointer-typed fields only, and change +the decoding rules for the field type. For regular pointer fields without the "nil" tag, +input values must always match the required input length exactly and the decoder does not +produce nil values. When the "nil" tag is set, input values of size zero decode as a nil +pointer. This is especially useful for recursive types. + + type StructWithNilField struct { + Field *[3]byte `rlp:"nil"` + } + +In the example above, Field allows two possible input sizes. For input 0xC180 (a list +containing an empty string) Field is set to nil after decoding. For input 0xC483000000 (a +list containing a 3-byte string), Field is set to a non-nil array pointer. + +RLP supports two kinds of empty values: empty lists and empty strings. When using the +"nil" tag, the kind of empty value allowed for a type is chosen automatically. A field +whose Go type is a pointer to an unsigned integer, string, boolean or byte array/slice +expects an empty RLP string. Any other pointer field type encodes/decodes as an empty RLP +list. + +The choice of null value can be made explicit with the "nilList" and "nilString" struct +tags. Using these tags encodes/decodes a Go nil pointer value as the empty RLP value kind +defined by the tag. */ package rlp diff --git a/rlp/encbuffer.go b/rlp/encbuffer.go new file mode 100644 index 0000000000..8d3a3b2293 --- /dev/null +++ b/rlp/encbuffer.go @@ -0,0 +1,423 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp + +import ( + "encoding/binary" + "io" + "math/big" + "reflect" + "sync" + + "github.com/holiman/uint256" +) + +type encBuffer struct { + str []byte // string data, contains everything except list headers + lheads []listhead // all list headers + lhsize int // sum of sizes of all encoded list headers + sizebuf [9]byte // auxiliary buffer for uint encoding +} + +// The global encBuffer pool. +var encBufferPool = sync.Pool{ + New: func() interface{} { return new(encBuffer) }, +} + +func getEncBuffer() *encBuffer { + buf := encBufferPool.Get().(*encBuffer) + buf.reset() + return buf +} + +func (buf *encBuffer) reset() { + buf.lhsize = 0 + buf.str = buf.str[:0] + buf.lheads = buf.lheads[:0] +} + +// size returns the length of the encoded data. +func (buf *encBuffer) size() int { + return len(buf.str) + buf.lhsize +} + +// makeBytes creates the encoder output. +func (buf *encBuffer) makeBytes() []byte { + out := make([]byte, buf.size()) + buf.copyTo(out) + return out +} + +func (buf *encBuffer) copyTo(dst []byte) { + strpos := 0 + pos := 0 + for _, head := range buf.lheads { + // write string data before header + n := copy(dst[pos:], buf.str[strpos:head.offset]) + pos += n + strpos += n + // write the header + enc := head.encode(dst[pos:]) + pos += len(enc) + } + // copy string data after the last list header + copy(dst[pos:], buf.str[strpos:]) +} + +// writeTo writes the encoder output to w. +func (buf *encBuffer) writeTo(w io.Writer) (err error) { + strpos := 0 + for _, head := range buf.lheads { + // write string data before header + if head.offset-strpos > 0 { + n, err := w.Write(buf.str[strpos:head.offset]) + strpos += n + if err != nil { + return err + } + } + // write the header + enc := head.encode(buf.sizebuf[:]) + if _, err = w.Write(enc); err != nil { + return err + } + } + if strpos < len(buf.str) { + // write string data after the last list header + _, err = w.Write(buf.str[strpos:]) + } + return err +} + +// Write implements io.Writer and appends b directly to the output. +func (buf *encBuffer) Write(b []byte) (int, error) { + buf.str = append(buf.str, b...) + return len(b), nil +} + +// writeBool writes b as the integer 0 (false) or 1 (true). +func (buf *encBuffer) writeBool(b bool) { + if b { + buf.str = append(buf.str, 0x01) + } else { + buf.str = append(buf.str, 0x80) + } +} + +func (buf *encBuffer) writeUint64(i uint64) { + if i == 0 { + buf.str = append(buf.str, 0x80) + } else if i < 128 { + // fits single byte + buf.str = append(buf.str, byte(i)) + } else { + s := putint(buf.sizebuf[1:], i) + buf.sizebuf[0] = 0x80 + byte(s) + buf.str = append(buf.str, buf.sizebuf[:s+1]...) + } +} + +func (buf *encBuffer) writeBytes(b []byte) { + if len(b) == 1 && b[0] <= 0x7F { + // fits single byte, no string header + buf.str = append(buf.str, b[0]) + } else { + buf.encodeStringHeader(len(b)) + buf.str = append(buf.str, b...) + } +} + +func (buf *encBuffer) writeString(s string) { + buf.writeBytes([]byte(s)) +} + +// wordBytes is the number of bytes in a big.Word +const wordBytes = (32 << (uint64(^big.Word(0)) >> 63)) / 8 + +// writeBigInt writes i as an integer. 
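// (Illustrative: writeBigInt(big.NewInt(1024)) has BitLen 11 and takes the
// fast path below through writeUint64, appending 0x82 0x04 0x00, the
// canonical RLP string for 1024; only integers wider than 64 bits reach
// the word-by-word copy loop.)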
+func (buf *encBuffer) writeBigInt(i *big.Int) { + bitlen := i.BitLen() + if bitlen <= 64 { + buf.writeUint64(i.Uint64()) + return + } + // Integer is larger than 64 bits, encode from i.Bits(). + // The minimal byte length is bitlen rounded up to the next + // multiple of 8, divided by 8. + length := ((bitlen + 7) & -8) >> 3 + buf.encodeStringHeader(length) + buf.str = append(buf.str, make([]byte, length)...) + index := length + bytesBuf := buf.str[len(buf.str)-length:] + for _, d := range i.Bits() { + for j := 0; j < wordBytes && index > 0; j++ { + index-- + bytesBuf[index] = byte(d) + d >>= 8 + } + } +} + +// writeUint256 writes z as an integer. +func (buf *encBuffer) writeUint256(z *uint256.Int) { + bitlen := z.BitLen() + if bitlen <= 64 { + buf.writeUint64(z.Uint64()) + return + } + nBytes := byte((bitlen + 7) / 8) + var b [33]byte + binary.BigEndian.PutUint64(b[1:9], z[3]) + binary.BigEndian.PutUint64(b[9:17], z[2]) + binary.BigEndian.PutUint64(b[17:25], z[1]) + binary.BigEndian.PutUint64(b[25:33], z[0]) + b[32-nBytes] = 0x80 + nBytes + buf.str = append(buf.str, b[32-nBytes:]...) +} + +// list adds a new list header to the header stack. It returns the index of the header. +// Call listEnd with this index after encoding the content of the list. +func (buf *encBuffer) list() int { + buf.lheads = append(buf.lheads, listhead{offset: len(buf.str), size: buf.lhsize}) + return len(buf.lheads) - 1 +} + +func (buf *encBuffer) listEnd(index int) { + lh := &buf.lheads[index] + lh.size = buf.size() - lh.offset - lh.size + if lh.size < 56 { + buf.lhsize++ // length encoded into kind tag + } else { + buf.lhsize += 1 + intsize(uint64(lh.size)) + } +} + +func (buf *encBuffer) encode(val interface{}) error { + rval := reflect.ValueOf(val) + writer, err := cachedWriter(rval.Type()) + if err != nil { + return err + } + return writer(rval, buf) +} + +func (buf *encBuffer) encodeStringHeader(size int) { + if size < 56 { + buf.str = append(buf.str, 0x80+byte(size)) + } else { + sizesize := putint(buf.sizebuf[1:], uint64(size)) + buf.sizebuf[0] = 0xB7 + byte(sizesize) + buf.str = append(buf.str, buf.sizebuf[:sizesize+1]...) + } +} + +// encReader is the io.Reader returned by EncodeToReader. +// It releases its encbuf at EOF. +type encReader struct { + buf *encBuffer // the buffer we're reading from. this is nil when we're at EOF. + lhpos int // index of list header that we're reading + strpos int // current position in string buffer + piece []byte // next piece to be read +} + +func (r *encReader) Read(b []byte) (n int, err error) { + for { + if r.piece = r.next(); r.piece == nil { + // Put the encode buffer back into the pool at EOF when it + // is first encountered. Subsequent calls still return EOF + // as the error but the buffer is no longer valid. + if r.buf != nil { + encBufferPool.Put(r.buf) + r.buf = nil + } + return n, io.EOF + } + nn := copy(b[n:], r.piece) + n += nn + if nn < len(r.piece) { + // piece didn't fit, see you next time. + r.piece = r.piece[nn:] + return n, nil + } + r.piece = nil + } +} + +// next returns the next piece of data to be read. +// it returns nil at EOF. +func (r *encReader) next() []byte { + switch { + case r.buf == nil: + return nil + + case r.piece != nil: + // There is still data available for reading. + return r.piece + + case r.lhpos < len(r.buf.lheads): + // We're before the last list header. + head := r.buf.lheads[r.lhpos] + sizebefore := head.offset - r.strpos + if sizebefore > 0 { + // String data before header. 
+ p := r.buf.str[r.strpos:head.offset] + r.strpos += sizebefore + return p + } + r.lhpos++ + return head.encode(r.buf.sizebuf[:]) + + case r.strpos < len(r.buf.str): + // String data at the end, after all list headers. + p := r.buf.str[r.strpos:] + r.strpos = len(r.buf.str) + return p + + default: + return nil + } +} + +func encBufferFromWriter(w io.Writer) *encBuffer { + switch w := w.(type) { + case EncoderBuffer: + return w.buf + case *EncoderBuffer: + return w.buf + case *encBuffer: + return w + default: + return nil + } +} + +// EncoderBuffer is a buffer for incremental encoding. +// +// The zero value is NOT ready for use. To get a usable buffer, +// create it using NewEncoderBuffer or call Reset. +type EncoderBuffer struct { + buf *encBuffer + dst io.Writer + + ownBuffer bool +} + +// NewEncoderBuffer creates an encoder buffer. +func NewEncoderBuffer(dst io.Writer) EncoderBuffer { + var w EncoderBuffer + w.Reset(dst) + return w +} + +// Reset truncates the buffer and sets the output destination. +func (w *EncoderBuffer) Reset(dst io.Writer) { + if w.buf != nil && !w.ownBuffer { + panic("can't Reset derived EncoderBuffer") + } + + // If the destination writer has an *encBuffer, use it. + // Note that w.ownBuffer is left false here. + if dst != nil { + if outer := encBufferFromWriter(dst); outer != nil { + *w = EncoderBuffer{outer, nil, false} + return + } + } + + // Get a fresh buffer. + if w.buf == nil { + w.buf = encBufferPool.Get().(*encBuffer) + w.ownBuffer = true + } + w.buf.reset() + w.dst = dst +} + +// Flush writes encoded RLP data to the output writer. This can only be called once. +// If you want to re-use the buffer after Flush, you must call Reset. +func (w *EncoderBuffer) Flush() error { + var err error + if w.dst != nil { + err = w.buf.writeTo(w.dst) + } + // Release the internal buffer. + if w.ownBuffer { + encBufferPool.Put(w.buf) + } + *w = EncoderBuffer{} + return err +} + +// ToBytes returns the encoded bytes. +func (w *EncoderBuffer) ToBytes() []byte { + return w.buf.makeBytes() +} + +// AppendToBytes appends the encoded bytes to dst. +func (w *EncoderBuffer) AppendToBytes(dst []byte) []byte { + size := w.buf.size() + out := append(dst, make([]byte, size)...) + w.buf.copyTo(out[len(dst):]) + return out +} + +// Write appends b directly to the encoder output. +func (w EncoderBuffer) Write(b []byte) (int, error) { + return w.buf.Write(b) +} + +// WriteBool writes b as the integer 0 (false) or 1 (true). +func (w EncoderBuffer) WriteBool(b bool) { + w.buf.writeBool(b) +} + +// WriteUint64 encodes an unsigned integer. +func (w EncoderBuffer) WriteUint64(i uint64) { + w.buf.writeUint64(i) +} + +// WriteBigInt encodes a big.Int as an RLP string. +// Note: Unlike with Encode, the sign of i is ignored. +func (w EncoderBuffer) WriteBigInt(i *big.Int) { + w.buf.writeBigInt(i) +} + +// WriteUint256 encodes uint256.Int as an RLP string. +func (w EncoderBuffer) WriteUint256(i *uint256.Int) { + w.buf.writeUint256(i) +} + +// WriteBytes encodes b as an RLP string. +func (w EncoderBuffer) WriteBytes(b []byte) { + w.buf.writeBytes(b) +} + +// WriteString encodes s as an RLP string. +func (w EncoderBuffer) WriteString(s string) { + w.buf.writeString(s) +} + +// List starts a list. It returns an internal index. Call EndList with +// this index after encoding the content to finish the list. +func (w EncoderBuffer) List() int { + return w.buf.list() +} + +// ListEnd finishes the given list. 
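// A minimal usage sketch of the buffer API defined above:
//
//	var b bytes.Buffer
//	w := rlp.NewEncoderBuffer(&b)
//	idx := w.List()
//	w.WriteUint64(1)
//	w.WriteString("dog")
//	w.ListEnd(idx)
//	if err := w.Flush(); err == nil {
//		// b now holds 0xC5 0x01 0x83 0x64 0x6F 0x67
//	}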
+func (w EncoderBuffer) ListEnd(index int) { + w.buf.listEnd(index) +} diff --git a/rlp/encode.go b/rlp/encode.go index 44592c2f53..f34be7f3df 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -17,20 +17,29 @@ package rlp import ( + "errors" "fmt" "io" "math/big" "reflect" - "sync" + + "github.com/tomochain/tomochain/rlp/internal/rlpstruct" + + "github.com/holiman/uint256" ) var ( // Common encoded values. // These are useful when implementing EncodeRLP. + + // EmptyString is the encoding of an empty string. EmptyString = []byte{0x80} - EmptyList = []byte{0xC0} + // EmptyList is the encoding of an empty list. + EmptyList = []byte{0xC0} ) +var ErrNegativeBigInt = errors.New("rlp: cannot encode negative big.Int") + // Encoder is implemented by types that require custom // encoding rules or want to encode private fields. type Encoder interface { @@ -49,80 +58,48 @@ type Encoder interface { // perform many small writes in some cases. Consider making w // buffered. // -// Encode uses the following type-dependent encoding rules: -// -// If the type implements the Encoder interface, Encode calls -// EncodeRLP. This is true even for nil pointers, please see the -// documentation for Encoder. -// -// To encode a pointer, the value being pointed to is encoded. For nil -// pointers, Encode will encode the zero value of the type. A nil -// pointer to a struct type always encodes as an empty RLP list. -// A nil pointer to an array encodes as an empty list (or empty string -// if the array has element type byte). -// -// Struct values are encoded as an RLP list of all their encoded -// public fields. Recursive struct types are supported. -// -// To encode slices and arrays, the elements are encoded as an RLP -// list of the value's elements. Note that arrays and slices with -// element type uint8 or byte are always encoded as an RLP string. -// -// A Go string is encoded as an RLP string. -// -// An unsigned integer value is encoded as an RLP string. Zero always -// encodes as an empty RLP string. Encode also supports *big.Int. -// -// An interface value encodes as the value contained in the interface. -// -// Boolean values are not supported, nor are signed integers, floating -// point numbers, maps, channels and functions. +// Please see package-level documentation of encoding rules. func Encode(w io.Writer, val interface{}) error { - if outer, ok := w.(*encbuf); ok { - // Encode was called by some type's EncodeRLP. - // Avoid copying by writing to the outer encbuf directly. - return outer.encode(val) + // Optimization: reuse *encBuffer when called by EncodeRLP. + if buf := encBufferFromWriter(w); buf != nil { + return buf.encode(val) } - eb := encbufPool.Get().(*encbuf) - defer encbufPool.Put(eb) - eb.reset() - if err := eb.encode(val); err != nil { + + buf := getEncBuffer() + defer encBufferPool.Put(buf) + if err := buf.encode(val); err != nil { return err } - return eb.toWriter(w) + return buf.writeTo(w) } -// EncodeBytes returns the RLP encoding of val. -// Please see the documentation of Encode for the encoding rules. +// EncodeToBytes returns the RLP encoding of val. +// Please see package-level documentation for the encoding rules. 
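// A typical call (illustrative):
//
//	enc, err := rlp.EncodeToBytes([]uint{1, 2, 3}) // enc == 0xC3 0x01 0x02 0x03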
func EncodeToBytes(val interface{}) ([]byte, error) { - eb := encbufPool.Get().(*encbuf) - defer encbufPool.Put(eb) - eb.reset() - if err := eb.encode(val); err != nil { + buf := getEncBuffer() + defer encBufferPool.Put(buf) + + if err := buf.encode(val); err != nil { return nil, err } - return eb.toBytes(), nil + return buf.makeBytes(), nil } -// EncodeReader returns a reader from which the RLP encoding of val +// EncodeToReader returns a reader from which the RLP encoding of val // can be read. The returned size is the total size of the encoded // data. // // Please see the documentation of Encode for the encoding rules. func EncodeToReader(val interface{}) (size int, r io.Reader, err error) { - eb := encbufPool.Get().(*encbuf) - eb.reset() - if err := eb.encode(val); err != nil { + buf := getEncBuffer() + if err := buf.encode(val); err != nil { + encBufferPool.Put(buf) return 0, nil, err } - return eb.size(), &encReader{buf: eb}, nil -} - -type encbuf struct { - str []byte // string data, contains everything except list headers - lheads []*listhead // all list headers - lhsize int // sum of sizes of all encoded list headers - sizebuf []byte // 9-byte auxiliary buffer for uint encoding + // Note: can't put the reader back into the pool here + // because it is held by encReader. The reader puts it + // back when it has been fully consumed. + return buf.size(), &encReader{buf: buf}, nil } type listhead struct { @@ -151,214 +128,32 @@ func puthead(buf []byte, smalltag, largetag byte, size uint64) int { if size < 56 { buf[0] = smalltag + byte(size) return 1 - } else { - sizesize := putint(buf[1:], size) - buf[0] = largetag + byte(sizesize) - return sizesize + 1 - } -} - -// encbufs are pooled. -var encbufPool = sync.Pool{ - New: func() interface{} { return &encbuf{sizebuf: make([]byte, 9)} }, -} - -func (w *encbuf) reset() { - w.lhsize = 0 - if w.str != nil { - w.str = w.str[:0] - } - if w.lheads != nil { - w.lheads = w.lheads[:0] - } -} - -// encbuf implements io.Writer so it can be passed it into EncodeRLP. -func (w *encbuf) Write(b []byte) (int, error) { - w.str = append(w.str, b...) - return len(b), nil -} - -func (w *encbuf) encode(val interface{}) error { - rval := reflect.ValueOf(val) - ti, err := cachedTypeInfo(rval.Type(), tags{}) - if err != nil { - return err - } - return ti.writer(rval, w) -} - -func (w *encbuf) encodeStringHeader(size int) { - if size < 56 { - w.str = append(w.str, 0x80+byte(size)) - } else { - // TODO: encode to w.str directly - sizesize := putint(w.sizebuf[1:], uint64(size)) - w.sizebuf[0] = 0xB7 + byte(sizesize) - w.str = append(w.str, w.sizebuf[:sizesize+1]...) - } -} - -func (w *encbuf) encodeString(b []byte) { - if len(b) == 1 && b[0] <= 0x7F { - // fits single byte, no string header - w.str = append(w.str, b[0]) - } else { - w.encodeStringHeader(len(b)) - w.str = append(w.str, b...) 
- } -} - -func (w *encbuf) list() *listhead { - lh := &listhead{offset: len(w.str), size: w.lhsize} - w.lheads = append(w.lheads, lh) - return lh -} - -func (w *encbuf) listEnd(lh *listhead) { - lh.size = w.size() - lh.offset - lh.size - if lh.size < 56 { - w.lhsize += 1 // length encoded into kind tag - } else { - w.lhsize += 1 + intsize(uint64(lh.size)) - } -} - -func (w *encbuf) size() int { - return len(w.str) + w.lhsize -} - -func (w *encbuf) toBytes() []byte { - out := make([]byte, w.size()) - strpos := 0 - pos := 0 - for _, head := range w.lheads { - // write string data before header - n := copy(out[pos:], w.str[strpos:head.offset]) - pos += n - strpos += n - // write the header - enc := head.encode(out[pos:]) - pos += len(enc) } - // copy string data after the last list header - copy(out[pos:], w.str[strpos:]) - return out + sizesize := putint(buf[1:], size) + buf[0] = largetag + byte(sizesize) + return sizesize + 1 } -func (w *encbuf) toWriter(out io.Writer) (err error) { - strpos := 0 - for _, head := range w.lheads { - // write string data before header - if head.offset-strpos > 0 { - n, err := out.Write(w.str[strpos:head.offset]) - strpos += n - if err != nil { - return err - } - } - // write the header - enc := head.encode(w.sizebuf) - if _, err = out.Write(enc); err != nil { - return err - } - } - if strpos < len(w.str) { - // write string data after the last list header - _, err = out.Write(w.str[strpos:]) - } - return err -} - -// encReader is the io.Reader returned by EncodeToReader. -// It releases its encbuf at EOF. -type encReader struct { - buf *encbuf // the buffer we're reading from. this is nil when we're at EOF. - lhpos int // index of list header that we're reading - strpos int // current position in string buffer - piece []byte // next piece to be read -} - -func (r *encReader) Read(b []byte) (n int, err error) { - for { - if r.piece = r.next(); r.piece == nil { - // Put the encode buffer back into the pool at EOF when it - // is first encountered. Subsequent calls still return EOF - // as the error but the buffer is no longer valid. - if r.buf != nil { - encbufPool.Put(r.buf) - r.buf = nil - } - return n, io.EOF - } - nn := copy(b[n:], r.piece) - n += nn - if nn < len(r.piece) { - // piece didn't fit, see you next time. - r.piece = r.piece[nn:] - return n, nil - } - r.piece = nil - } -} - -// next returns the next piece of data to be read. -// it returns nil at EOF. -func (r *encReader) next() []byte { - switch { - case r.buf == nil: - return nil - - case r.piece != nil: - // There is still data available for reading. - return r.piece - - case r.lhpos < len(r.buf.lheads): - // We're before the last list header. - head := r.buf.lheads[r.lhpos] - sizebefore := head.offset - r.strpos - if sizebefore > 0 { - // String data before header. - p := r.buf.str[r.strpos:head.offset] - r.strpos += sizebefore - return p - } else { - r.lhpos++ - return head.encode(r.buf.sizebuf) - } - - case r.strpos < len(r.buf.str): - // String data at the end, after all list headers. - p := r.buf.str[r.strpos:] - r.strpos = len(r.buf.str) - return p - - default: - return nil - } -} - -var ( - encoderInterface = reflect.TypeOf(new(Encoder)).Elem() - big0 = big.NewInt(0) -) +var encoderInterface = reflect.TypeOf(new(Encoder)).Elem() // makeWriter creates a writer function for the given type. 
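// (Note on the rewritten dispatch below: case order matters. *big.Int must
// be matched before the generic pointer case, and the pointer case now runs
// before the Encoder-interface check, so nil-pointer handling stays uniform;
// e.g. a uint64 value selects writeUint, while *uint256.Int selects
// writeU256IntPtr.)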
-func makeWriter(typ reflect.Type, ts tags) (writer, error) {
+func makeWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) {
 	kind := typ.Kind()
 	switch {
 	case typ == rawValueType:
 		return writeRawValue, nil
-	case typ.Implements(encoderInterface):
-		return writeEncoder, nil
-	case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(encoderInterface):
-		return writeEncoderNoPtr, nil
-	case kind == reflect.Interface:
-		return writeInterface, nil
 	case typ.AssignableTo(reflect.PtrTo(bigInt)):
 		return writeBigIntPtr, nil
 	case typ.AssignableTo(bigInt):
 		return writeBigIntNoPtr, nil
+	case typ == reflect.PtrTo(u256Int):
+		return writeU256IntPtr, nil
+	case typ == u256Int:
+		return writeU256IntNoPtr, nil
+	case kind == reflect.Ptr:
+		return makePtrWriter(typ, ts)
+	case reflect.PtrTo(typ).Implements(encoderInterface):
+		return makeEncoderWriter(typ), nil
 	case isUint(kind):
 		return writeUint, nil
 	case kind == reflect.Bool:
@@ -368,97 +163,116 @@ func makeWriter(typ reflect.Type, ts tags) (writer, error) {
 	case kind == reflect.Slice && isByte(typ.Elem()):
 		return writeBytes, nil
 	case kind == reflect.Array && isByte(typ.Elem()):
-		return writeByteArray, nil
+		return makeByteArrayWriter(typ), nil
 	case kind == reflect.Slice || kind == reflect.Array:
 		return makeSliceWriter(typ, ts)
 	case kind == reflect.Struct:
 		return makeStructWriter(typ)
-	case kind == reflect.Ptr:
-		return makePtrWriter(typ)
+	case kind == reflect.Interface:
+		return writeInterface, nil
 	default:
 		return nil, fmt.Errorf("rlp: type %v is not RLP-serializable", typ)
 	}
 }
 
-func isByte(typ reflect.Type) bool {
-	return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface)
-}
-
-func writeRawValue(val reflect.Value, w *encbuf) error {
+func writeRawValue(val reflect.Value, w *encBuffer) error {
 	w.str = append(w.str, val.Bytes()...)
 	return nil
 }
 
-func writeUint(val reflect.Value, w *encbuf) error {
-	i := val.Uint()
-	if i == 0 {
-		w.str = append(w.str, 0x80)
-	} else if i < 128 {
-		// fits single byte
-		w.str = append(w.str, byte(i))
-	} else {
-		// TODO: encode int to w.str directly
-		s := putint(w.sizebuf[1:], i)
-		w.sizebuf[0] = 0x80 + byte(s)
-		w.str = append(w.str, w.sizebuf[:s+1]...)
- } +func writeUint(val reflect.Value, w *encBuffer) error { + w.writeUint64(val.Uint()) return nil } -func writeBool(val reflect.Value, w *encbuf) error { - if val.Bool() { - w.str = append(w.str, 0x01) - } else { - w.str = append(w.str, 0x80) - } +func writeBool(val reflect.Value, w *encBuffer) error { + w.writeBool(val.Bool()) return nil } -func writeBigIntPtr(val reflect.Value, w *encbuf) error { +func writeBigIntPtr(val reflect.Value, w *encBuffer) error { ptr := val.Interface().(*big.Int) if ptr == nil { w.str = append(w.str, 0x80) return nil } - return writeBigInt(ptr, w) + if ptr.Sign() == -1 { + return ErrNegativeBigInt + } + w.writeBigInt(ptr) + return nil } -func writeBigIntNoPtr(val reflect.Value, w *encbuf) error { +func writeBigIntNoPtr(val reflect.Value, w *encBuffer) error { i := val.Interface().(big.Int) - return writeBigInt(&i, w) + if i.Sign() == -1 { + return ErrNegativeBigInt + } + w.writeBigInt(&i) + return nil } -func writeBigInt(i *big.Int, w *encbuf) error { - if cmp := i.Cmp(big0); cmp == -1 { - return fmt.Errorf("rlp: cannot encode negative *big.Int") - } else if cmp == 0 { +func writeU256IntPtr(val reflect.Value, w *encBuffer) error { + ptr := val.Interface().(*uint256.Int) + if ptr == nil { w.str = append(w.str, 0x80) - } else { - w.encodeString(i.Bytes()) + return nil } + w.writeUint256(ptr) + return nil +} + +func writeU256IntNoPtr(val reflect.Value, w *encBuffer) error { + i := val.Interface().(uint256.Int) + w.writeUint256(&i) return nil } -func writeBytes(val reflect.Value, w *encbuf) error { - w.encodeString(val.Bytes()) +func writeBytes(val reflect.Value, w *encBuffer) error { + w.writeBytes(val.Bytes()) return nil } -func writeByteArray(val reflect.Value, w *encbuf) error { - if !val.CanAddr() { - // Slice requires the value to be addressable. - // Make it addressable by copying. - copy := reflect.New(val.Type()).Elem() - copy.Set(val) - val = copy +func makeByteArrayWriter(typ reflect.Type) writer { + switch typ.Len() { + case 0: + return writeLengthZeroByteArray + case 1: + return writeLengthOneByteArray + default: + length := typ.Len() + return func(val reflect.Value, w *encBuffer) error { + if !val.CanAddr() { + // Getting the byte slice of val requires it to be addressable. Make it + // addressable by copying. + copy := reflect.New(val.Type()).Elem() + copy.Set(val) + val = copy + } + slice := byteArrayBytes(val, length) + w.encodeStringHeader(len(slice)) + w.str = append(w.str, slice...) + return nil + } } - size := val.Len() - slice := val.Slice(0, size).Bytes() - w.encodeString(slice) +} + +func writeLengthZeroByteArray(val reflect.Value, w *encBuffer) error { + w.str = append(w.str, 0x80) return nil } -func writeString(val reflect.Value, w *encbuf) error { +func writeLengthOneByteArray(val reflect.Value, w *encBuffer) error { + b := byte(val.Index(0).Uint()) + if b <= 0x7f { + w.str = append(w.str, b) + } else { + w.str = append(w.str, 0x81, b) + } + return nil +} + +func writeString(val reflect.Value, w *encBuffer) error { s := val.String() if len(s) == 1 && s[0] <= 0x7f { // fits single byte, no string header @@ -470,27 +284,7 @@ func writeString(val reflect.Value, w *encbuf) error { return nil } -func writeEncoder(val reflect.Value, w *encbuf) error { - return val.Interface().(Encoder).EncodeRLP(w) -} - -// writeEncoderNoPtr handles non-pointer values that implement Encoder -// with a pointer receiver. -func writeEncoderNoPtr(val reflect.Value, w *encbuf) error { - if !val.CanAddr() { - // We can't get the address. 
It would be possible to make the - // value addressable by creating a shallow copy, but this - // creates other problems so we're not doing it (yet). - // - // package json simply doesn't call MarshalJSON for cases like - // this, but encodes the value as if it didn't implement the - // interface. We don't want to handle it that way. - return fmt.Errorf("rlp: game over: unadressable value of type %v, EncodeRLP is pointer method", val.Type()) - } - return val.Addr().Interface().(Encoder).EncodeRLP(w) -} - -func writeInterface(val reflect.Value, w *encbuf) error { +func writeInterface(val reflect.Value, w *encBuffer) error { if val.IsNil() { // Write empty list. This is consistent with the previous RLP // encoder that we had and should therefore avoid any @@ -499,31 +293,51 @@ func writeInterface(val reflect.Value, w *encbuf) error { return nil } eval := val.Elem() - ti, err := cachedTypeInfo(eval.Type(), tags{}) + writer, err := cachedWriter(eval.Type()) if err != nil { return err } - return ti.writer(eval, w) + return writer(eval, w) } -func makeSliceWriter(typ reflect.Type, ts tags) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) - if err != nil { - return nil, err +func makeSliceWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) { + etypeinfo := theTC.infoWhileGenerating(typ.Elem(), rlpstruct.Tags{}) + if etypeinfo.writerErr != nil { + return nil, etypeinfo.writerErr } - writer := func(val reflect.Value, w *encbuf) error { - if !ts.tail { - defer w.listEnd(w.list()) + + var wfn writer + if ts.Tail { + // This is for struct tail slices. + // w.list is not called for them. + wfn = func(val reflect.Value, w *encBuffer) error { + vlen := val.Len() + for i := 0; i < vlen; i++ { + if err := etypeinfo.writer(val.Index(i), w); err != nil { + return err + } + } + return nil } - vlen := val.Len() - for i := 0; i < vlen; i++ { - if err := etypeinfo.writer(val.Index(i), w); err != nil { - return err + } else { + // This is for regular slices and arrays. + wfn = func(val reflect.Value, w *encBuffer) error { + vlen := val.Len() + if vlen == 0 { + w.str = append(w.str, 0xC0) + return nil + } + listOffset := w.list() + for i := 0; i < vlen; i++ { + if err := etypeinfo.writer(val.Index(i), w); err != nil { + return err + } } + w.listEnd(listOffset) + return nil } - return nil } - return writer, nil + return wfn, nil } func makeStructWriter(typ reflect.Type) (writer, error) { @@ -531,56 +345,86 @@ func makeStructWriter(typ reflect.Type) (writer, error) { if err != nil { return nil, err } - writer := func(val reflect.Value, w *encbuf) error { - lh := w.list() - for _, f := range fields { - if err := f.info.writer(val.Field(f.index), w); err != nil { - return err + for _, f := range fields { + if f.info.writerErr != nil { + return nil, structFieldError{typ, f.index, f.info.writerErr} + } + } + + var writer writer + firstOptionalField := firstOptionalField(fields) + if firstOptionalField == len(fields) { + // This is the writer function for structs without any optional fields. + writer = func(val reflect.Value, w *encBuffer) error { + lh := w.list() + for _, f := range fields { + if err := f.info.writer(val.Field(f.index), w); err != nil { + return err + } } + w.listEnd(lh) + return nil + } + } else { + // If there are any "optional" fields, the writer needs to perform additional + // checks to determine the output list length. 
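// (Example: encoding StructWithOptionalFields{Required: 1} from the package
// docs emits just 0xC1 0x01, because the loop below trims the trailing
// zero-valued optional fields before the list header is sized.)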
+ writer = func(val reflect.Value, w *encBuffer) error { + lastField := len(fields) - 1 + for ; lastField >= firstOptionalField; lastField-- { + if !val.Field(fields[lastField].index).IsZero() { + break + } + } + lh := w.list() + for i := 0; i <= lastField; i++ { + if err := fields[i].info.writer(val.Field(fields[i].index), w); err != nil { + return err + } + } + w.listEnd(lh) + return nil } - w.listEnd(lh) - return nil } return writer, nil } -func makePtrWriter(typ reflect.Type) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) - if err != nil { - return nil, err +func makePtrWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) { + nilEncoding := byte(0xC0) + if typeNilKind(typ.Elem(), ts) == String { + nilEncoding = 0x80 } - // determine nil pointer handler - var nilfunc func(*encbuf) error - kind := typ.Elem().Kind() - switch { - case kind == reflect.Array && isByte(typ.Elem().Elem()): - nilfunc = func(w *encbuf) error { - w.str = append(w.str, 0x80) - return nil - } - case kind == reflect.Struct || kind == reflect.Array: - nilfunc = func(w *encbuf) error { - // encoding the zero value of a struct/array could trigger - // infinite recursion, avoid that. - w.listEnd(w.list()) - return nil - } - default: - zero := reflect.Zero(typ.Elem()) - nilfunc = func(w *encbuf) error { - return etypeinfo.writer(zero, w) + etypeinfo := theTC.infoWhileGenerating(typ.Elem(), rlpstruct.Tags{}) + if etypeinfo.writerErr != nil { + return nil, etypeinfo.writerErr + } + + writer := func(val reflect.Value, w *encBuffer) error { + if ev := val.Elem(); ev.IsValid() { + return etypeinfo.writer(ev, w) } + w.str = append(w.str, nilEncoding) + return nil } + return writer, nil +} - writer := func(val reflect.Value, w *encbuf) error { - if val.IsNil() { - return nilfunc(w) - } else { - return etypeinfo.writer(val.Elem(), w) +func makeEncoderWriter(typ reflect.Type) writer { + if typ.Implements(encoderInterface) { + return func(val reflect.Value, w *encBuffer) error { + return val.Interface().(Encoder).EncodeRLP(w) + } + } + w := func(val reflect.Value, w *encBuffer) error { + if !val.CanAddr() { + // package json simply doesn't call MarshalJSON for this case, but encodes the + // value as if it didn't implement the interface. We don't want to handle it that + // way. + return fmt.Errorf("rlp: unadressable value of type %v, EncodeRLP is pointer method", val.Type()) } + return val.Addr().Interface().(Encoder).EncodeRLP(w) } - return writer, err + return w } // putint writes i to the beginning of b in big endian byte diff --git a/rlp/iterator.go b/rlp/iterator.go new file mode 100644 index 0000000000..6be574572e --- /dev/null +++ b/rlp/iterator.go @@ -0,0 +1,60 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package rlp + +type listIterator struct { + data []byte + next []byte + err error +} + +// NewListIterator creates an iterator for the (list) represented by data +// TODO: Consider removing this implementation, as it is no longer used. +func NewListIterator(data RawValue) (*listIterator, error) { + k, t, c, err := readKind(data) + if err != nil { + return nil, err + } + if k != List { + return nil, ErrExpectedList + } + it := &listIterator{ + data: data[t : t+c], + } + return it, nil +} + +// Next forwards the iterator one step, returns true if it was not at end yet +func (it *listIterator) Next() bool { + if len(it.data) == 0 { + return false + } + _, t, c, err := readKind(it.data) + it.next = it.data[:t+c] + it.data = it.data[t+c:] + it.err = err + return true +} + +// Value returns the current value +func (it *listIterator) Value() []byte { + return it.next +} + +func (it *listIterator) Err() error { + return it.err +} diff --git a/rlp/raw.go b/rlp/raw.go index 2b3f328f66..773aa7e614 100644 --- a/rlp/raw.go +++ b/rlp/raw.go @@ -28,12 +28,53 @@ type RawValue []byte var rawValueType = reflect.TypeOf(RawValue{}) +// StringSize returns the encoded size of a string. +func StringSize(s string) uint64 { + switch { + case len(s) == 0: + return 1 + case len(s) == 1: + if s[0] <= 0x7f { + return 1 + } else { + return 2 + } + default: + return uint64(headsize(uint64(len(s))) + len(s)) + } +} + +// BytesSize returns the encoded size of a byte slice. +func BytesSize(b []byte) uint64 { + switch { + case len(b) == 0: + return 1 + case len(b) == 1: + if b[0] <= 0x7f { + return 1 + } else { + return 2 + } + default: + return uint64(headsize(uint64(len(b))) + len(b)) + } +} + // ListSize returns the encoded size of an RLP list with the given // content size. func ListSize(contentSize uint64) uint64 { return uint64(headsize(contentSize)) + contentSize } +// IntSize returns the encoded size of the integer x. Note: The return type of this +// function is 'int' for backwards-compatibility reasons. The result is always positive. +func IntSize(x uint64) int { + if x < 0x80 { + return 1 + } + return 1 + intsize(x) +} + // Split returns the content of first RLP value and any // bytes after the value as subslices of b. func Split(b []byte) (k Kind, content, rest []byte, err error) { @@ -57,6 +98,32 @@ func SplitString(b []byte) (content, rest []byte, err error) { return content, rest, nil } +// SplitUint64 decodes an integer at the beginning of b. +// It also returns the remaining data after the integer in 'rest'. +func SplitUint64(b []byte) (x uint64, rest []byte, err error) { + content, rest, err := SplitString(b) + if err != nil { + return 0, b, err + } + switch { + case len(content) == 0: + return 0, rest, nil + case len(content) == 1: + if content[0] == 0 { + return 0, b, ErrCanonInt + } + return uint64(content[0]), rest, nil + case len(content) > 8: + return 0, b, errUintOverflow + default: + x, err = readSize(content, byte(len(content))) + if err != nil { + return 0, b, ErrCanonInt + } + return x, rest, nil + } +} + // SplitList splits b into the content of a list and any remaining // bytes after the list. func SplitList(b []byte) (content, rest []byte, err error) { @@ -154,3 +221,74 @@ func readSize(b []byte, slen byte) (uint64, error) { } return s, nil } + +// AppendUint64 appends the RLP encoding of i to b, and returns the resulting slice. 
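+// Zero encodes as the empty string (0x80), values 1-127 encode as a single
+// byte, and larger values get a 0x81..0x88 length prefix followed by the
+// big-endian bytes of i. For example, AppendUint64(nil, 0) yields {0x80},
+// AppendUint64(nil, 127) yields {0x7F}, and AppendUint64(nil, 1024) yields
+// {0x82, 0x04, 0x00}.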
+func AppendUint64(b []byte, i uint64) []byte { + if i == 0 { + return append(b, 0x80) + } else if i < 128 { + return append(b, byte(i)) + } + switch { + case i < (1 << 8): + return append(b, 0x81, byte(i)) + case i < (1 << 16): + return append(b, 0x82, + byte(i>>8), + byte(i), + ) + case i < (1 << 24): + return append(b, 0x83, + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 32): + return append(b, 0x84, + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 40): + return append(b, 0x85, + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + + case i < (1 << 48): + return append(b, 0x86, + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 56): + return append(b, 0x87, + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + + default: + return append(b, 0x88, + byte(i>>56), + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + } +} diff --git a/rlp/safe.go b/rlp/safe.go new file mode 100644 index 0000000000..3c910337b6 --- /dev/null +++ b/rlp/safe.go @@ -0,0 +1,27 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +//go:build nacl || js || !cgo +// +build nacl js !cgo + +package rlp + +import "reflect" + +// byteArrayBytes returns a slice of the byte array v. +func byteArrayBytes(v reflect.Value, length int) []byte { + return v.Slice(0, length).Bytes() +} diff --git a/rlp/typecache.go b/rlp/typecache.go index 3df799e1ec..c3244050bf 100644 --- a/rlp/typecache.go +++ b/rlp/typecache.go @@ -19,138 +19,222 @@ package rlp import ( "fmt" "reflect" - "strings" "sync" -) + "sync/atomic" -var ( - typeCacheMutex sync.RWMutex - typeCache = make(map[typekey]*typeinfo) + "github.com/tomochain/tomochain/rlp/internal/rlpstruct" ) +// typeinfo is an entry in the type cache. type typeinfo struct { - decoder - writer -} - -// represents struct tags -type tags struct { - // rlp:"nil" controls whether empty input results in a nil pointer. - nilOK bool - // rlp:"tail" controls whether this field swallows additional list - // elements. It can only be set for the last field, which must be - // of slice type. - tail bool - // rlp:"-" ignores fields. - ignored bool + decoder decoder + decoderErr error // error from makeDecoder + writer writer + writerErr error // error from makeWriter } +// typekey is the key of a type in typeCache. It includes the struct tags because +// they might generate a different decoder. type typekey struct { reflect.Type - // the key must include the struct tags because they - // might generate a different decoder. 
- tags + rlpstruct.Tags } type decoder func(*Stream, reflect.Value) error -type writer func(reflect.Value, *encbuf) error +type writer func(reflect.Value, *encBuffer) error + +var theTC = newTypeCache() + +type typeCache struct { + cur atomic.Value + + // This lock synchronizes writers. + mu sync.Mutex + next map[typekey]*typeinfo +} + +func newTypeCache() *typeCache { + c := new(typeCache) + c.cur.Store(make(map[typekey]*typeinfo)) + return c +} + +func cachedDecoder(typ reflect.Type) (decoder, error) { + info := theTC.info(typ) + return info.decoder, info.decoderErr +} + +func cachedWriter(typ reflect.Type) (writer, error) { + info := theTC.info(typ) + return info.writer, info.writerErr +} + +func (c *typeCache) info(typ reflect.Type) *typeinfo { + key := typekey{Type: typ} + if info := c.cur.Load().(map[typekey]*typeinfo)[key]; info != nil { + return info + } + + // Not in the cache, need to generate info for this type. + return c.generate(typ, rlpstruct.Tags{}) +} + +func (c *typeCache) generate(typ reflect.Type, tags rlpstruct.Tags) *typeinfo { + c.mu.Lock() + defer c.mu.Unlock() + + cur := c.cur.Load().(map[typekey]*typeinfo) + if info := cur[typekey{typ, tags}]; info != nil { + return info + } -func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) { - typeCacheMutex.RLock() - info := typeCache[typekey{typ, tags}] - typeCacheMutex.RUnlock() - if info != nil { - return info, nil + // Copy cur to next. + c.next = make(map[typekey]*typeinfo, len(cur)+1) + for k, v := range cur { + c.next[k] = v } - // not in the cache, need to generate info for this type. - typeCacheMutex.Lock() - defer typeCacheMutex.Unlock() - return cachedTypeInfo1(typ, tags) + + // Generate. + info := c.infoWhileGenerating(typ, tags) + + // next -> cur + c.cur.Store(c.next) + c.next = nil + return info } -func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) { +func (c *typeCache) infoWhileGenerating(typ reflect.Type, tags rlpstruct.Tags) *typeinfo { key := typekey{typ, tags} - info := typeCache[key] - if info != nil { - // another goroutine got the write lock first - return info, nil + if info := c.next[key]; info != nil { + return info } - // put a dummmy value into the cache before generating. - // if the generator tries to lookup itself, it will get + // Put a dummy value into the cache before generating. + // If the generator tries to lookup itself, it will get // the dummy value and won't call itself recursively. - typeCache[key] = new(typeinfo) - info, err := genTypeInfo(typ, tags) - if err != nil { - // remove the dummy value if the generator fails - delete(typeCache, key) - return nil, err - } - *typeCache[key] = *info - return typeCache[key], err + info := new(typeinfo) + c.next[key] = info + info.generate(typ, tags) + return info } type field struct { - index int - info *typeinfo + index int + info *typeinfo + optional bool } +// structFields resolves the typeinfo of all public fields in a struct type. func structFields(typ reflect.Type) (fields []field, err error) { + // Convert fields to rlpstruct.Field. 
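+	// The reflect-based field list is first translated into the
+	// reflection-free description used by the rlpstruct package, so the same
+	// tag and filtering rules apply here and in the rlpgen tool.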
+ var allStructFields []rlpstruct.Field for i := 0; i < typ.NumField(); i++ { - if f := typ.Field(i); f.PkgPath == "" { // exported - tags, err := parseStructTag(typ, i) - if err != nil { - return nil, err - } - if tags.ignored { - continue - } - info, err := cachedTypeInfo1(f.Type, tags) - if err != nil { - return nil, err - } - fields = append(fields, field{i, info}) + rf := typ.Field(i) + allStructFields = append(allStructFields, rlpstruct.Field{ + Name: rf.Name, + Index: i, + Exported: rf.PkgPath == "", + Tag: string(rf.Tag), + Type: *rtypeToStructType(rf.Type, nil), + }) + } + + // Filter/validate fields. + structFields, structTags, err := rlpstruct.ProcessFields(allStructFields) + if err != nil { + if tagErr, ok := err.(rlpstruct.TagError); ok { + tagErr.StructType = typ.String() + return nil, tagErr } + return nil, err + } + + // Resolve typeinfo. + for i, sf := range structFields { + typ := typ.Field(sf.Index).Type + tags := structTags[i] + info := theTC.infoWhileGenerating(typ, tags) + fields = append(fields, field{sf.Index, info, tags.Optional}) } return fields, nil } -func parseStructTag(typ reflect.Type, fi int) (tags, error) { - f := typ.Field(fi) - var ts tags - for _, t := range strings.Split(f.Tag.Get("rlp"), ",") { - switch t = strings.TrimSpace(t); t { - case "": - case "-": - ts.ignored = true - case "nil": - ts.nilOK = true - case "tail": - ts.tail = true - if fi != typ.NumField()-1 { - return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (must be on last field)`, typ, f.Name) - } - if f.Type.Kind() != reflect.Slice { - return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (field type is not slice)`, typ, f.Name) - } - default: - return ts, fmt.Errorf("rlp: unknown struct tag %q on %v.%s", t, typ, f.Name) +// firstOptionalField returns the index of the first field with "optional" tag. +func firstOptionalField(fields []field) int { + for i, f := range fields { + if f.optional { + return i } } - return ts, nil + return len(fields) } -func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) { - info = new(typeinfo) - if info.decoder, err = makeDecoder(typ, tags); err != nil { - return nil, err +type structFieldError struct { + typ reflect.Type + field int + err error +} + +func (e structFieldError) Error() string { + return fmt.Sprintf("%v (struct field %v.%s)", e.err, e.typ, e.typ.Field(e.field).Name) +} + +func (i *typeinfo) generate(typ reflect.Type, tags rlpstruct.Tags) { + i.decoder, i.decoderErr = makeDecoder(typ, tags) + i.writer, i.writerErr = makeWriter(typ, tags) +} + +// rtypeToStructType converts typ to rlpstruct.Type. +func rtypeToStructType(typ reflect.Type, rec map[reflect.Type]*rlpstruct.Type) *rlpstruct.Type { + k := typ.Kind() + if k == reflect.Invalid { + panic("invalid kind") } - if info.writer, err = makeWriter(typ, tags); err != nil { - return nil, err + + if prev := rec[typ]; prev != nil { + return prev // short-circuit for recursive types + } + if rec == nil { + rec = make(map[reflect.Type]*rlpstruct.Type) + } + + t := &rlpstruct.Type{ + Name: typ.String(), + Kind: k, + IsEncoder: typ.Implements(encoderInterface), + IsDecoder: typ.Implements(decoderInterface), + } + rec[typ] = t + if k == reflect.Array || k == reflect.Slice || k == reflect.Ptr { + t.Elem = rtypeToStructType(typ.Elem(), rec) + } + return t +} + +// typeNilKind gives the RLP value kind for nil pointers to 'typ'. 
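+// A rlp:"nilString" or rlp:"nilList" tag picks the kind explicitly; otherwise
+// the element type's default nil value decides between String and List.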
+func typeNilKind(typ reflect.Type, tags rlpstruct.Tags) Kind { + styp := rtypeToStructType(typ, nil) + + var nk rlpstruct.NilKind + if tags.NilOK { + nk = tags.NilKind + } else { + nk = styp.DefaultNilValue() + } + switch nk { + case rlpstruct.NilKindString: + return String + case rlpstruct.NilKindList: + return List + default: + panic("invalid nil kind value") } - return info, nil } func isUint(k reflect.Kind) bool { return k >= reflect.Uint && k <= reflect.Uintptr } + +func isByte(typ reflect.Type) bool { + return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface) +} diff --git a/rlp/unsafe.go b/rlp/unsafe.go new file mode 100644 index 0000000000..2152ba35fc --- /dev/null +++ b/rlp/unsafe.go @@ -0,0 +1,35 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +//go:build !nacl && !js && cgo +// +build !nacl,!js,cgo + +package rlp + +import ( + "reflect" + "unsafe" +) + +// byteArrayBytes returns a slice of the byte array v. +func byteArrayBytes(v reflect.Value, length int) []byte { + var s []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = v.UnsafeAddr() + hdr.Cap = length + hdr.Len = length + return s +} From 9b5dadb40ac0eb52e8dbe2975864866faec9cd08 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Fri, 16 Jun 2023 16:01:50 +0700 Subject: [PATCH 003/119] Implement RLP encoder code generation tool --- rlp/rlpgen/gen.go | 800 +++++++++++++++++++++++++++ rlp/rlpgen/gen_test.go | 107 ++++ rlp/rlpgen/main.go | 147 +++++ rlp/rlpgen/testdata/bigint.in.txt | 10 + rlp/rlpgen/testdata/bigint.out.txt | 49 ++ rlp/rlpgen/testdata/nil.in.txt | 30 + rlp/rlpgen/testdata/nil.out.txt | 289 ++++++++++ rlp/rlpgen/testdata/optional.in.txt | 17 + rlp/rlpgen/testdata/optional.out.txt | 153 +++++ rlp/rlpgen/testdata/rawvalue.in.txt | 11 + rlp/rlpgen/testdata/rawvalue.out.txt | 64 +++ rlp/rlpgen/testdata/uint256.in.txt | 10 + rlp/rlpgen/testdata/uint256.out.txt | 44 ++ rlp/rlpgen/testdata/uints.in.txt | 10 + rlp/rlpgen/testdata/uints.out.txt | 53 ++ rlp/rlpgen/types.go | 124 +++++ 16 files changed, 1918 insertions(+) create mode 100644 rlp/rlpgen/gen.go create mode 100644 rlp/rlpgen/gen_test.go create mode 100644 rlp/rlpgen/main.go create mode 100644 rlp/rlpgen/testdata/bigint.in.txt create mode 100644 rlp/rlpgen/testdata/bigint.out.txt create mode 100644 rlp/rlpgen/testdata/nil.in.txt create mode 100644 rlp/rlpgen/testdata/nil.out.txt create mode 100644 rlp/rlpgen/testdata/optional.in.txt create mode 100644 rlp/rlpgen/testdata/optional.out.txt create mode 100644 rlp/rlpgen/testdata/rawvalue.in.txt create mode 100644 rlp/rlpgen/testdata/rawvalue.out.txt create mode 100644 rlp/rlpgen/testdata/uint256.in.txt create mode 100644 rlp/rlpgen/testdata/uint256.out.txt create mode 100644 rlp/rlpgen/testdata/uints.in.txt create mode 100644 
rlp/rlpgen/testdata/uints.out.txt create mode 100644 rlp/rlpgen/types.go diff --git a/rlp/rlpgen/gen.go b/rlp/rlpgen/gen.go new file mode 100644 index 0000000000..26ccdc574e --- /dev/null +++ b/rlp/rlpgen/gen.go @@ -0,0 +1,800 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "bytes" + "fmt" + "go/format" + "go/types" + "sort" + + "github.com/tomochain/tomochain/rlp/internal/rlpstruct" +) + +// buildContext keeps the data needed for make*Op. +type buildContext struct { + topType *types.Named // the type we're creating methods for + + encoderIface *types.Interface + decoderIface *types.Interface + rawValueType *types.Named + + typeToStructCache map[types.Type]*rlpstruct.Type +} + +func newBuildContext(packageRLP *types.Package) *buildContext { + enc := packageRLP.Scope().Lookup("Encoder").Type().Underlying() + dec := packageRLP.Scope().Lookup("Decoder").Type().Underlying() + rawv := packageRLP.Scope().Lookup("RawValue").Type() + return &buildContext{ + typeToStructCache: make(map[types.Type]*rlpstruct.Type), + encoderIface: enc.(*types.Interface), + decoderIface: dec.(*types.Interface), + rawValueType: rawv.(*types.Named), + } +} + +func (bctx *buildContext) isEncoder(typ types.Type) bool { + return types.Implements(typ, bctx.encoderIface) +} + +func (bctx *buildContext) isDecoder(typ types.Type) bool { + return types.Implements(typ, bctx.decoderIface) +} + +// typeToStructType converts typ to rlpstruct.Type. +func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type { + if prev := bctx.typeToStructCache[typ]; prev != nil { + return prev // short-circuit for recursive types. + } + + // Resolve named types to their underlying type, but keep the name. + name := types.TypeString(typ, nil) + for { + utype := typ.Underlying() + if utype == typ { + break + } + typ = utype + } + + // Create the type and store it in cache. + t := &rlpstruct.Type{ + Name: name, + Kind: typeReflectKind(typ), + IsEncoder: bctx.isEncoder(typ), + IsDecoder: bctx.isDecoder(typ), + } + bctx.typeToStructCache[typ] = t + + // Assign element type. + switch typ.(type) { + case *types.Array, *types.Slice, *types.Pointer: + etype := typ.(interface{ Elem() types.Type }).Elem() + t.Elem = bctx.typeToStructType(etype) + } + return t +} + +// genContext is passed to the gen* methods of op when generating +// the output code. It tracks packages to be imported by the output +// file and assigns unique names of temporary variables. 
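+// Temporary variables are named _tmp0, _tmp1, and so on; the counter is
+// reset at the start of each generated method.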
+type genContext struct { + inPackage *types.Package + imports map[string]struct{} + tempCounter int +} + +func newGenContext(inPackage *types.Package) *genContext { + return &genContext{ + inPackage: inPackage, + imports: make(map[string]struct{}), + } +} + +func (ctx *genContext) temp() string { + v := fmt.Sprintf("_tmp%d", ctx.tempCounter) + ctx.tempCounter++ + return v +} + +func (ctx *genContext) resetTemp() { + ctx.tempCounter = 0 +} + +func (ctx *genContext) addImport(path string) { + if path == ctx.inPackage.Path() { + return // avoid importing the package that we're generating in. + } + // TODO: renaming? + ctx.imports[path] = struct{}{} +} + +// importsList returns all packages that need to be imported. +func (ctx *genContext) importsList() []string { + imp := make([]string, 0, len(ctx.imports)) + for k := range ctx.imports { + imp = append(imp, k) + } + sort.Strings(imp) + return imp +} + +// qualify is the types.Qualifier used for printing types. +func (ctx *genContext) qualify(pkg *types.Package) string { + if pkg.Path() == ctx.inPackage.Path() { + return "" + } + ctx.addImport(pkg.Path()) + // TODO: renaming? + return pkg.Name() +} + +type op interface { + // genWrite creates the encoder. The generated code should write v, + // which is any Go expression, to the rlp.EncoderBuffer 'w'. + genWrite(ctx *genContext, v string) string + + // genDecode creates the decoder. The generated code should read + // a value from the rlp.Stream 'dec' and store it to dst. + genDecode(ctx *genContext) (string, string) +} + +// basicOp handles basic types bool, uint*, string. +type basicOp struct { + typ types.Type + writeMethod string // calle write the value + writeArgType types.Type // parameter type of writeMethod + decMethod string + decResultType types.Type // return type of decMethod + decUseBitSize bool // if true, result bit size is appended to decMethod +} + +func (*buildContext) makeBasicOp(typ *types.Basic) (op, error) { + op := basicOp{typ: typ} + kind := typ.Kind() + switch { + case kind == types.Bool: + op.writeMethod = "WriteBool" + op.writeArgType = types.Typ[types.Bool] + op.decMethod = "Bool" + op.decResultType = types.Typ[types.Bool] + case kind >= types.Uint8 && kind <= types.Uint64: + op.writeMethod = "WriteUint64" + op.writeArgType = types.Typ[types.Uint64] + op.decMethod = "Uint" + op.decResultType = typ + op.decUseBitSize = true + case kind == types.String: + op.writeMethod = "WriteString" + op.writeArgType = types.Typ[types.String] + op.decMethod = "String" + op.decResultType = types.Typ[types.String] + default: + return nil, fmt.Errorf("unhandled basic type: %v", typ) + } + return op, nil +} + +func (*buildContext) makeByteSliceOp(typ *types.Slice) op { + if !isByte(typ.Elem()) { + panic("non-byte slice type in makeByteSliceOp") + } + bslice := types.NewSlice(types.Typ[types.Uint8]) + return basicOp{ + typ: typ, + writeMethod: "WriteBytes", + writeArgType: bslice, + decMethod: "Bytes", + decResultType: bslice, + } +} + +func (bctx *buildContext) makeRawValueOp() op { + bslice := types.NewSlice(types.Typ[types.Uint8]) + return basicOp{ + typ: bctx.rawValueType, + writeMethod: "Write", + writeArgType: bslice, + decMethod: "Raw", + decResultType: bslice, + } +} + +func (op basicOp) writeNeedsConversion() bool { + return !types.AssignableTo(op.typ, op.writeArgType) +} + +func (op basicOp) decodeNeedsConversion() bool { + return !types.AssignableTo(op.decResultType, op.typ) +} + +func (op basicOp) genWrite(ctx *genContext, v string) string { + if op.writeNeedsConversion() 
{ + v = fmt.Sprintf("%s(%s)", op.writeArgType, v) + } + return fmt.Sprintf("w.%s(%s)\n", op.writeMethod, v) +} + +func (op basicOp) genDecode(ctx *genContext) (string, string) { + var ( + resultV = ctx.temp() + result = resultV + method = op.decMethod + ) + if op.decUseBitSize { + // Note: For now, this only works for platform-independent integer + // sizes. makeBasicOp forbids the platform-dependent types. + var sizes types.StdSizes + method = fmt.Sprintf("%s%d", op.decMethod, sizes.Sizeof(op.typ)*8) + } + + // Call the decoder method. + var b bytes.Buffer + fmt.Fprintf(&b, "%s, err := dec.%s()\n", resultV, method) + fmt.Fprintf(&b, "if err != nil { return err }\n") + if op.decodeNeedsConversion() { + conv := ctx.temp() + fmt.Fprintf(&b, "%s := %s(%s)\n", conv, types.TypeString(op.typ, ctx.qualify), resultV) + result = conv + } + return result, b.String() +} + +// byteArrayOp handles [...]byte. +type byteArrayOp struct { + typ types.Type + name types.Type // name != typ for named byte array types (e.g. common.Address) +} + +func (bctx *buildContext) makeByteArrayOp(name *types.Named, typ *types.Array) byteArrayOp { + nt := types.Type(name) + if name == nil { + nt = typ + } + return byteArrayOp{typ, nt} +} + +func (op byteArrayOp) genWrite(ctx *genContext, v string) string { + return fmt.Sprintf("w.WriteBytes(%s[:])\n", v) +} + +func (op byteArrayOp) genDecode(ctx *genContext) (string, string) { + var resultV = ctx.temp() + + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(op.name, ctx.qualify)) + fmt.Fprintf(&b, "if err := dec.ReadBytes(%s[:]); err != nil { return err }\n", resultV) + return resultV, b.String() +} + +// bigIntOp handles big.Int. +// This exists because big.Int has it's own decoder operation on rlp.Stream, +// but the decode method returns *big.Int, so it needs to be dereferenced. +type bigIntOp struct { + pointer bool +} + +func (op bigIntOp) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + + fmt.Fprintf(&b, "if %s.Sign() == -1 {\n", v) + fmt.Fprintf(&b, " return rlp.ErrNegativeBigInt\n") + fmt.Fprintf(&b, "}\n") + dst := v + if !op.pointer { + dst = "&" + v + } + fmt.Fprintf(&b, "w.WriteBigInt(%s)\n", dst) + + // Wrap with nil check. + if op.pointer { + code := b.String() + b.Reset() + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write(rlp.EmptyString)") + fmt.Fprintf(&b, "} else {\n") + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "}\n") + } + + return b.String() +} + +func (op bigIntOp) genDecode(ctx *genContext) (string, string) { + var resultV = ctx.temp() + + var b bytes.Buffer + fmt.Fprintf(&b, "%s, err := dec.BigInt()\n", resultV) + fmt.Fprintf(&b, "if err != nil { return err }\n") + + result := resultV + if !op.pointer { + result = "(*" + resultV + ")" + } + return result, b.String() +} + +// uint256Op handles "github.com/holiman/uint256".Int +type uint256Op struct { + pointer bool +} + +func (op uint256Op) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + + dst := v + if !op.pointer { + dst = "&" + v + } + fmt.Fprintf(&b, "w.WriteUint256(%s)\n", dst) + + // Wrap with nil check. 
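+	// A nil *uint256.Int encodes as the canonical empty string (0x80) instead
+	// of being passed to WriteUint256.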
+ if op.pointer { + code := b.String() + b.Reset() + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write(rlp.EmptyString)") + fmt.Fprintf(&b, "} else {\n") + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "}\n") + } + + return b.String() +} + +func (op uint256Op) genDecode(ctx *genContext) (string, string) { + ctx.addImport("github.com/holiman/uint256") + + var b bytes.Buffer + resultV := ctx.temp() + fmt.Fprintf(&b, "var %s uint256.Int\n", resultV) + fmt.Fprintf(&b, "if err := dec.ReadUint256(&%s); err != nil { return err }\n", resultV) + + result := resultV + if op.pointer { + result = "&" + resultV + } + return result, b.String() +} + +// encoderDecoderOp handles rlp.Encoder and rlp.Decoder. +// In order to be used with this, the type must implement both interfaces. +// This restriction may be lifted in the future by creating separate ops for +// encoding and decoding. +type encoderDecoderOp struct { + typ types.Type +} + +func (op encoderDecoderOp) genWrite(ctx *genContext, v string) string { + return fmt.Sprintf("if err := %s.EncodeRLP(w); err != nil { return err }\n", v) +} + +func (op encoderDecoderOp) genDecode(ctx *genContext) (string, string) { + // DecodeRLP must have pointer receiver, and this is verified in makeOp. + etyp := op.typ.(*types.Pointer).Elem() + var resultV = ctx.temp() + + var b bytes.Buffer + fmt.Fprintf(&b, "%s := new(%s)\n", resultV, types.TypeString(etyp, ctx.qualify)) + fmt.Fprintf(&b, "if err := %s.DecodeRLP(dec); err != nil { return err }\n", resultV) + return resultV, b.String() +} + +// ptrOp handles pointer types. +type ptrOp struct { + elemTyp types.Type + elem op + nilOK bool + nilValue rlpstruct.NilKind +} + +func (bctx *buildContext) makePtrOp(elemTyp types.Type, tags rlpstruct.Tags) (op, error) { + elemOp, err := bctx.makeOp(nil, elemTyp, rlpstruct.Tags{}) + if err != nil { + return nil, err + } + op := ptrOp{elemTyp: elemTyp, elem: elemOp} + + // Determine nil value. + if tags.NilOK { + op.nilOK = true + op.nilValue = tags.NilKind + } else { + styp := bctx.typeToStructType(elemTyp) + op.nilValue = styp.DefaultNilValue() + } + return op, nil +} + +func (op ptrOp) genWrite(ctx *genContext, v string) string { + // Note: in writer functions, accesses to v are read-only, i.e. v is any Go + // expression. To make all accesses work through the pointer, we substitute + // v with (*v). This is required for most accesses including `v`, `call(v)`, + // and `v[index]` on slices. + // + // For `v.field` and `v[:]` on arrays, the dereference operation is not required. + var vv string + _, isStruct := op.elem.(structOp) + _, isByteArray := op.elem.(byteArrayOp) + if isStruct || isByteArray { + vv = v + } else { + vv = fmt.Sprintf("(*%s)", v) + } + + var b bytes.Buffer + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write([]byte{0x%X})\n", op.nilValue) + fmt.Fprintf(&b, "} else {\n") + fmt.Fprintf(&b, " %s", op.elem.genWrite(ctx, vv)) + fmt.Fprintf(&b, "}\n") + return b.String() +} + +func (op ptrOp) genDecode(ctx *genContext) (string, string) { + result, code := op.elem.genDecode(ctx) + if !op.nilOK { + // If nil pointers are not allowed, we can just decode the element. + return "&" + result, code + } + + // nil is allowed, so check the kind and size first. + // If size is zero and kind matches the nilKind of the type, + // the value decodes as a nil pointer. 
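+	// Otherwise the value is decoded normally and the address of the result
+	// is assigned to the pointer variable.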
+ var ( + resultV = ctx.temp() + kindV = ctx.temp() + sizeV = ctx.temp() + wantKind string + ) + if op.nilValue == rlpstruct.NilKindList { + wantKind = "rlp.List" + } else { + wantKind = "rlp.String" + } + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(types.NewPointer(op.elemTyp), ctx.qualify)) + fmt.Fprintf(&b, "if %s, %s, err := dec.Kind(); err != nil {\n", kindV, sizeV) + fmt.Fprintf(&b, " return err\n") + fmt.Fprintf(&b, "} else if %s != 0 || %s != %s {\n", sizeV, kindV, wantKind) + fmt.Fprint(&b, code) + fmt.Fprintf(&b, " %s = &%s\n", resultV, result) + fmt.Fprintf(&b, "}\n") + return resultV, b.String() +} + +// structOp handles struct types. +type structOp struct { + named *types.Named + typ *types.Struct + fields []*structField + optionalFields []*structField +} + +type structField struct { + name string + typ types.Type + elem op +} + +func (bctx *buildContext) makeStructOp(named *types.Named, typ *types.Struct) (op, error) { + // Convert fields to []rlpstruct.Field. + var allStructFields []rlpstruct.Field + for i := 0; i < typ.NumFields(); i++ { + f := typ.Field(i) + allStructFields = append(allStructFields, rlpstruct.Field{ + Name: f.Name(), + Exported: f.Exported(), + Index: i, + Tag: typ.Tag(i), + Type: *bctx.typeToStructType(f.Type()), + }) + } + + // Filter/validate fields. + fields, tags, err := rlpstruct.ProcessFields(allStructFields) + if err != nil { + return nil, err + } + + // Create field ops. + var op = structOp{named: named, typ: typ} + for i, field := range fields { + // Advanced struct tags are not supported yet. + tag := tags[i] + if err := checkUnsupportedTags(field.Name, tag); err != nil { + return nil, err + } + typ := typ.Field(field.Index).Type() + elem, err := bctx.makeOp(nil, typ, tags[i]) + if err != nil { + return nil, fmt.Errorf("field %s: %v", field.Name, err) + } + f := &structField{name: field.Name, typ: typ, elem: elem} + if tag.Optional { + op.optionalFields = append(op.optionalFields, f) + } else { + op.fields = append(op.fields, f) + } + } + return op, nil +} + +func checkUnsupportedTags(field string, tag rlpstruct.Tags) error { + if tag.Tail { + return fmt.Errorf(`field %s has unsupported struct tag "tail"`, field) + } + return nil +} + +func (op structOp) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + var listMarker = ctx.temp() + fmt.Fprintf(&b, "%s := w.List()\n", listMarker) + for _, field := range op.fields { + selector := v + "." + field.name + fmt.Fprint(&b, field.elem.genWrite(ctx, selector)) + } + op.writeOptionalFields(&b, ctx, v) + fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker) + return b.String() +} + +func (op structOp) writeOptionalFields(b *bytes.Buffer, ctx *genContext, v string) { + if len(op.optionalFields) == 0 { + return + } + // First check zero-ness of all optional fields. + var zeroV = make([]string, len(op.optionalFields)) + for i, field := range op.optionalFields { + selector := v + "." + field.name + zeroV[i] = ctx.temp() + fmt.Fprintf(b, "%s := %s\n", zeroV[i], nonZeroCheck(selector, field.typ, ctx.qualify)) + } + // Now write the fields. + for i, field := range op.optionalFields { + selector := v + "." + field.name + cond := "" + for j := i; j < len(op.optionalFields); j++ { + if j > i { + cond += " || " + } + cond += zeroV[j] + } + fmt.Fprintf(b, "if %s {\n", cond) + fmt.Fprint(b, field.elem.genWrite(ctx, selector)) + fmt.Fprintf(b, "}\n") + } +} + +func (op structOp) genDecode(ctx *genContext) (string, string) { + // Get the string representation of the type. 
+ // Here, named types are handled separately because the output + // would contain a copy of the struct definition otherwise. + var typeName string + if op.named != nil { + typeName = types.TypeString(op.named, ctx.qualify) + } else { + typeName = types.TypeString(op.typ, ctx.qualify) + } + + // Create struct object. + var resultV = ctx.temp() + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", resultV, typeName) + + // Decode fields. + fmt.Fprintf(&b, "{\n") + fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n") + for _, field := range op.fields { + result, code := field.elem.genDecode(ctx) + fmt.Fprintf(&b, "// %s:\n", field.name) + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "%s.%s = %s\n", resultV, field.name, result) + } + op.decodeOptionalFields(&b, ctx, resultV) + fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n") + fmt.Fprintf(&b, "}\n") + return resultV, b.String() +} + +func (op structOp) decodeOptionalFields(b *bytes.Buffer, ctx *genContext, resultV string) { + var suffix bytes.Buffer + for _, field := range op.optionalFields { + result, code := field.elem.genDecode(ctx) + fmt.Fprintf(b, "// %s:\n", field.name) + fmt.Fprintf(b, "if dec.MoreDataInList() {\n") + fmt.Fprint(b, code) + fmt.Fprintf(b, "%s.%s = %s\n", resultV, field.name, result) + fmt.Fprintf(&suffix, "}\n") + } + suffix.WriteTo(b) +} + +// sliceOp handles slice types. +type sliceOp struct { + typ *types.Slice + elemOp op +} + +func (bctx *buildContext) makeSliceOp(typ *types.Slice) (op, error) { + elemOp, err := bctx.makeOp(nil, typ.Elem(), rlpstruct.Tags{}) + if err != nil { + return nil, err + } + return sliceOp{typ: typ, elemOp: elemOp}, nil +} + +func (op sliceOp) genWrite(ctx *genContext, v string) string { + var ( + listMarker = ctx.temp() // holds return value of w.List() + iterElemV = ctx.temp() // iteration variable + elemCode = op.elemOp.genWrite(ctx, iterElemV) + ) + + var b bytes.Buffer + fmt.Fprintf(&b, "%s := w.List()\n", listMarker) + fmt.Fprintf(&b, "for _, %s := range %s {\n", iterElemV, v) + fmt.Fprint(&b, elemCode) + fmt.Fprintf(&b, "}\n") + fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker) + return b.String() +} + +func (op sliceOp) genDecode(ctx *genContext) (string, string) { + var sliceV = ctx.temp() // holds the output slice + elemResult, elemCode := op.elemOp.genDecode(ctx) + + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", sliceV, types.TypeString(op.typ, ctx.qualify)) + fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n") + fmt.Fprintf(&b, "for dec.MoreDataInList() {\n") + fmt.Fprintf(&b, " %s", elemCode) + fmt.Fprintf(&b, " %s = append(%s, %s)\n", sliceV, sliceV, elemResult) + fmt.Fprintf(&b, "}\n") + fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n") + return sliceV, b.String() +} + +func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstruct.Tags) (op, error) { + switch typ := typ.(type) { + case *types.Named: + if isBigInt(typ) { + return bigIntOp{}, nil + } + if isUint256(typ) { + return uint256Op{}, nil + } + if typ == bctx.rawValueType { + return bctx.makeRawValueOp(), nil + } + if bctx.isDecoder(typ) { + return nil, fmt.Errorf("type %v implements rlp.Decoder with non-pointer receiver", typ) + } + // TODO: same check for encoder? + return bctx.makeOp(typ, typ.Underlying(), tags) + case *types.Pointer: + if isBigInt(typ.Elem()) { + return bigIntOp{pointer: true}, nil + } + if isUint256(typ.Elem()) { + return uint256Op{pointer: true}, nil + } + // Encoder/Decoder interfaces. 
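+		// encoderDecoderOp is only used when the pointer type implements both
+		// rlp.Encoder and rlp.Decoder; implementing just one of them is an error.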
+ if bctx.isEncoder(typ) { + if bctx.isDecoder(typ) { + return encoderDecoderOp{typ}, nil + } + return nil, fmt.Errorf("type %v implements rlp.Encoder but not rlp.Decoder", typ) + } + if bctx.isDecoder(typ) { + return nil, fmt.Errorf("type %v implements rlp.Decoder but not rlp.Encoder", typ) + } + // Default pointer handling. + return bctx.makePtrOp(typ.Elem(), tags) + case *types.Basic: + return bctx.makeBasicOp(typ) + case *types.Struct: + return bctx.makeStructOp(name, typ) + case *types.Slice: + etyp := typ.Elem() + if isByte(etyp) && !bctx.isEncoder(etyp) { + return bctx.makeByteSliceOp(typ), nil + } + return bctx.makeSliceOp(typ) + case *types.Array: + etyp := typ.Elem() + if isByte(etyp) && !bctx.isEncoder(etyp) { + return bctx.makeByteArrayOp(name, typ), nil + } + return nil, fmt.Errorf("unhandled array type: %v", typ) + default: + return nil, fmt.Errorf("unhandled type: %v", typ) + } +} + +// generateDecoder generates the DecodeRLP method on 'typ'. +func generateDecoder(ctx *genContext, typ string, op op) []byte { + ctx.resetTemp() + ctx.addImport(pathOfPackageRLP) + + result, code := op.genDecode(ctx) + var b bytes.Buffer + fmt.Fprintf(&b, "func (obj *%s) DecodeRLP(dec *rlp.Stream) error {\n", typ) + fmt.Fprint(&b, code) + fmt.Fprintf(&b, " *obj = %s\n", result) + fmt.Fprintf(&b, " return nil\n") + fmt.Fprintf(&b, "}\n") + return b.Bytes() +} + +// generateEncoder generates the EncodeRLP method on 'typ'. +func generateEncoder(ctx *genContext, typ string, op op) []byte { + ctx.resetTemp() + ctx.addImport("io") + ctx.addImport(pathOfPackageRLP) + + var b bytes.Buffer + fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ) + fmt.Fprintf(&b, " w := rlp.NewEncoderBuffer(_w)\n") + fmt.Fprint(&b, op.genWrite(ctx, "obj")) + fmt.Fprintf(&b, " return w.Flush()\n") + fmt.Fprintf(&b, "}\n") + return b.Bytes() +} + +func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]byte, error) { + bctx.topType = typ + + pkg := typ.Obj().Pkg() + op, err := bctx.makeOp(nil, typ, rlpstruct.Tags{}) + if err != nil { + return nil, err + } + + var ( + ctx = newGenContext(pkg) + encSource []byte + decSource []byte + ) + if encoder { + encSource = generateEncoder(ctx, typ.Obj().Name(), op) + } + if decoder { + decSource = generateDecoder(ctx, typ.Obj().Name(), op) + } + + var b bytes.Buffer + fmt.Fprintf(&b, "package %s\n\n", pkg.Name()) + for _, imp := range ctx.importsList() { + fmt.Fprintf(&b, "import %q\n", imp) + } + if encoder { + fmt.Fprintln(&b) + b.Write(encSource) + } + if decoder { + fmt.Fprintln(&b) + b.Write(decSource) + } + + source := b.Bytes() + // fmt.Println(string(source)) + return format.Source(source) +} diff --git a/rlp/rlpgen/gen_test.go b/rlp/rlpgen/gen_test.go new file mode 100644 index 0000000000..3b4f5df287 --- /dev/null +++ b/rlp/rlpgen/gen_test.go @@ -0,0 +1,107 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. 
+// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/importer" + "go/parser" + "go/token" + "go/types" + "os" + "path/filepath" + "testing" +) + +// Package RLP is loaded only once and reused for all tests. +var ( + testFset = token.NewFileSet() + testImporter = importer.ForCompiler(testFset, "source", nil).(types.ImporterFrom) + testPackageRLP *types.Package +) + +func init() { + cwd, err := os.Getwd() + if err != nil { + panic(err) + } + testPackageRLP, err = testImporter.ImportFrom(pathOfPackageRLP, cwd, 0) + if err != nil { + panic(fmt.Errorf("can't load package RLP: %v", err)) + } +} + +var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"} + +func TestOutput(t *testing.T) { + for _, test := range tests { + test := test + t.Run(test, func(t *testing.T) { + inputFile := filepath.Join("testdata", test+".in.txt") + outputFile := filepath.Join("testdata", test+".out.txt") + bctx, typ, err := loadTestSource(inputFile, "Test") + if err != nil { + t.Fatal("error loading test source:", err) + } + output, err := bctx.generate(typ, true, true) + if err != nil { + t.Fatal("error in generate:", err) + } + + // Set this environment variable to regenerate the test outputs. + if os.Getenv("WRITE_TEST_FILES") != "" { + os.WriteFile(outputFile, output, 0644) + } + + // Check if output matches. + wantOutput, err := os.ReadFile(outputFile) + if err != nil { + t.Fatal("error loading expected test output:", err) + } + if !bytes.Equal(output, wantOutput) { + t.Fatalf("output mismatch, want: %v got %v", string(wantOutput), string(output)) + } + }) + } +} + +func loadTestSource(file string, typeName string) (*buildContext, *types.Named, error) { + // Load the test input. + content, err := os.ReadFile(file) + if err != nil { + return nil, nil, err + } + f, err := parser.ParseFile(testFset, file, content, 0) + if err != nil { + return nil, nil, err + } + conf := types.Config{Importer: testImporter} + pkg, err := conf.Check("test", testFset, []*ast.File{f}, nil) + if err != nil { + return nil, nil, err + } + + // Find the test struct. + bctx := newBuildContext(testPackageRLP) + typ, err := lookupStructType(pkg.Scope(), typeName) + if err != nil { + return nil, nil, fmt.Errorf("can't find type %s: %v", typeName, err) + } + return bctx, typ, nil +} diff --git a/rlp/rlpgen/main.go b/rlp/rlpgen/main.go new file mode 100644 index 0000000000..87aebbc47a --- /dev/null +++ b/rlp/rlpgen/main.go @@ -0,0 +1,147 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package main + +import ( + "bytes" + "errors" + "flag" + "fmt" + "go/types" + "os" + + "golang.org/x/tools/go/packages" +) + +const pathOfPackageRLP = "github.com/tomochain/tomochain/rlp" + +func main() { + var ( + pkgdir = flag.String("dir", ".", "input package") + output = flag.String("out", "-", "output file (default is stdout)") + genEncoder = flag.Bool("encoder", true, "generate EncodeRLP?") + genDecoder = flag.Bool("decoder", false, "generate DecodeRLP?") + typename = flag.String("type", "", "type to generate methods for") + ) + flag.Parse() + + cfg := Config{ + Dir: *pkgdir, + Type: *typename, + GenerateEncoder: *genEncoder, + GenerateDecoder: *genDecoder, + } + code, err := cfg.process() + if err != nil { + fatal(err) + } + if *output == "-" { + os.Stdout.Write(code) + } else if err := os.WriteFile(*output, code, 0600); err != nil { + fatal(err) + } +} + +func fatal(args ...interface{}) { + fmt.Fprintln(os.Stderr, args...) + os.Exit(1) +} + +type Config struct { + Dir string // input package directory + Type string + + GenerateEncoder bool + GenerateDecoder bool +} + +// process generates the Go code. +func (cfg *Config) process() (code []byte, err error) { + // Load packages. + pcfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedTypes | packages.NeedImports | packages.NeedDeps, + Dir: cfg.Dir, + BuildFlags: []string{"-tags", "norlpgen"}, + } + ps, err := packages.Load(pcfg, pathOfPackageRLP, ".") + if err != nil { + return nil, err + } + if len(ps) == 0 { + return nil, fmt.Errorf("no Go package found in %s", cfg.Dir) + } + packages.PrintErrors(ps) + + // Find the packages that were loaded. + var ( + pkg *types.Package + packageRLP *types.Package + ) + for _, p := range ps { + if len(p.Errors) > 0 { + return nil, fmt.Errorf("package %s has errors", p.PkgPath) + } + if p.PkgPath == pathOfPackageRLP { + packageRLP = p.Types + } else { + pkg = p.Types + } + } + bctx := newBuildContext(packageRLP) + + // Find the type and generate. + typ, err := lookupStructType(pkg.Scope(), cfg.Type) + if err != nil { + return nil, fmt.Errorf("can't find %s in %s: %v", cfg.Type, pkg, err) + } + code, err = bctx.generate(typ, cfg.GenerateEncoder, cfg.GenerateDecoder) + if err != nil { + return nil, err + } + + // Add build comments. + // This is done here to avoid processing these lines with gofmt. + var header bytes.Buffer + fmt.Fprint(&header, "// Code generated by rlpgen. 
DO NOT EDIT.\n\n") + fmt.Fprint(&header, "//go:build !norlpgen\n") + fmt.Fprint(&header, "// +build !norlpgen\n\n") + return append(header.Bytes(), code...), nil +} + +func lookupStructType(scope *types.Scope, name string) (*types.Named, error) { + typ, err := lookupType(scope, name) + if err != nil { + return nil, err + } + _, ok := typ.Underlying().(*types.Struct) + if !ok { + return nil, errors.New("not a struct type") + } + return typ, nil +} + +func lookupType(scope *types.Scope, name string) (*types.Named, error) { + obj := scope.Lookup(name) + if obj == nil { + return nil, errors.New("no such identifier") + } + typ, ok := obj.(*types.TypeName) + if !ok { + return nil, errors.New("not a type") + } + return typ.Type().(*types.Named), nil +} diff --git a/rlp/rlpgen/testdata/bigint.in.txt b/rlp/rlpgen/testdata/bigint.in.txt new file mode 100644 index 0000000000..d23d84a287 --- /dev/null +++ b/rlp/rlpgen/testdata/bigint.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +import "math/big" + +type Test struct { + Int *big.Int + IntNoPtr big.Int +} diff --git a/rlp/rlpgen/testdata/bigint.out.txt b/rlp/rlpgen/testdata/bigint.out.txt new file mode 100644 index 0000000000..6dc7bea3bf --- /dev/null +++ b/rlp/rlpgen/testdata/bigint.out.txt @@ -0,0 +1,49 @@ +package test + +import "github.com/tomochain/tomochain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Int == nil { + w.Write(rlp.EmptyString) + } else { + if obj.Int.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(obj.Int) + } + if obj.IntNoPtr.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(&obj.IntNoPtr) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Int: + _tmp1, err := dec.BigInt() + if err != nil { + return err + } + _tmp0.Int = _tmp1 + // IntNoPtr: + _tmp2, err := dec.BigInt() + if err != nil { + return err + } + _tmp0.IntNoPtr = (*_tmp2) + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/nil.in.txt b/rlp/rlpgen/testdata/nil.in.txt new file mode 100644 index 0000000000..a28ff34487 --- /dev/null +++ b/rlp/rlpgen/testdata/nil.in.txt @@ -0,0 +1,30 @@ +// -*- mode: go -*- + +package test + +type Aux struct{ + A uint32 +} + +type Test struct{ + Uint8 *byte `rlp:"nil"` + Uint8List *byte `rlp:"nilList"` + + Uint32 *uint32 `rlp:"nil"` + Uint32List *uint32 `rlp:"nilList"` + + Uint64 *uint64 `rlp:"nil"` + Uint64List *uint64 `rlp:"nilList"` + + String *string `rlp:"nil"` + StringList *string `rlp:"nilList"` + + ByteArray *[3]byte `rlp:"nil"` + ByteArrayList *[3]byte `rlp:"nilList"` + + ByteSlice *[]byte `rlp:"nil"` + ByteSliceList *[]byte `rlp:"nilList"` + + Struct *Aux `rlp:"nil"` + StructString *Aux `rlp:"nilString"` +} diff --git a/rlp/rlpgen/testdata/nil.out.txt b/rlp/rlpgen/testdata/nil.out.txt new file mode 100644 index 0000000000..b3bdd0b86f --- /dev/null +++ b/rlp/rlpgen/testdata/nil.out.txt @@ -0,0 +1,289 @@ +package test + +import "github.com/tomochain/tomochain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Uint8 == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64(uint64((*obj.Uint8))) + } + if obj.Uint8List == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteUint64(uint64((*obj.Uint8List))) + } + if obj.Uint32 
== nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64(uint64((*obj.Uint32))) + } + if obj.Uint32List == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteUint64(uint64((*obj.Uint32List))) + } + if obj.Uint64 == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64((*obj.Uint64)) + } + if obj.Uint64List == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteUint64((*obj.Uint64List)) + } + if obj.String == nil { + w.Write([]byte{0x80}) + } else { + w.WriteString((*obj.String)) + } + if obj.StringList == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteString((*obj.StringList)) + } + if obj.ByteArray == nil { + w.Write([]byte{0x80}) + } else { + w.WriteBytes(obj.ByteArray[:]) + } + if obj.ByteArrayList == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteBytes(obj.ByteArrayList[:]) + } + if obj.ByteSlice == nil { + w.Write([]byte{0x80}) + } else { + w.WriteBytes((*obj.ByteSlice)) + } + if obj.ByteSliceList == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteBytes((*obj.ByteSliceList)) + } + if obj.Struct == nil { + w.Write([]byte{0xC0}) + } else { + _tmp1 := w.List() + w.WriteUint64(uint64(obj.Struct.A)) + w.ListEnd(_tmp1) + } + if obj.StructString == nil { + w.Write([]byte{0x80}) + } else { + _tmp2 := w.List() + w.WriteUint64(uint64(obj.StructString.A)) + w.ListEnd(_tmp2) + } + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Uint8: + var _tmp2 *byte + if _tmp3, _tmp4, err := dec.Kind(); err != nil { + return err + } else if _tmp4 != 0 || _tmp3 != rlp.String { + _tmp1, err := dec.Uint8() + if err != nil { + return err + } + _tmp2 = &_tmp1 + } + _tmp0.Uint8 = _tmp2 + // Uint8List: + var _tmp6 *byte + if _tmp7, _tmp8, err := dec.Kind(); err != nil { + return err + } else if _tmp8 != 0 || _tmp7 != rlp.List { + _tmp5, err := dec.Uint8() + if err != nil { + return err + } + _tmp6 = &_tmp5 + } + _tmp0.Uint8List = _tmp6 + // Uint32: + var _tmp10 *uint32 + if _tmp11, _tmp12, err := dec.Kind(); err != nil { + return err + } else if _tmp12 != 0 || _tmp11 != rlp.String { + _tmp9, err := dec.Uint32() + if err != nil { + return err + } + _tmp10 = &_tmp9 + } + _tmp0.Uint32 = _tmp10 + // Uint32List: + var _tmp14 *uint32 + if _tmp15, _tmp16, err := dec.Kind(); err != nil { + return err + } else if _tmp16 != 0 || _tmp15 != rlp.List { + _tmp13, err := dec.Uint32() + if err != nil { + return err + } + _tmp14 = &_tmp13 + } + _tmp0.Uint32List = _tmp14 + // Uint64: + var _tmp18 *uint64 + if _tmp19, _tmp20, err := dec.Kind(); err != nil { + return err + } else if _tmp20 != 0 || _tmp19 != rlp.String { + _tmp17, err := dec.Uint64() + if err != nil { + return err + } + _tmp18 = &_tmp17 + } + _tmp0.Uint64 = _tmp18 + // Uint64List: + var _tmp22 *uint64 + if _tmp23, _tmp24, err := dec.Kind(); err != nil { + return err + } else if _tmp24 != 0 || _tmp23 != rlp.List { + _tmp21, err := dec.Uint64() + if err != nil { + return err + } + _tmp22 = &_tmp21 + } + _tmp0.Uint64List = _tmp22 + // String: + var _tmp26 *string + if _tmp27, _tmp28, err := dec.Kind(); err != nil { + return err + } else if _tmp28 != 0 || _tmp27 != rlp.String { + _tmp25, err := dec.String() + if err != nil { + return err + } + _tmp26 = &_tmp25 + } + _tmp0.String = _tmp26 + // StringList: + var _tmp30 *string + if _tmp31, _tmp32, err := dec.Kind(); err != nil { + return err + } else if _tmp32 != 0 || _tmp31 != rlp.List { + _tmp29, err := dec.String() + if err != nil { + return err + } + _tmp30 = &_tmp29 + } + 
_tmp0.StringList = _tmp30 + // ByteArray: + var _tmp34 *[3]byte + if _tmp35, _tmp36, err := dec.Kind(); err != nil { + return err + } else if _tmp36 != 0 || _tmp35 != rlp.String { + var _tmp33 [3]byte + if err := dec.ReadBytes(_tmp33[:]); err != nil { + return err + } + _tmp34 = &_tmp33 + } + _tmp0.ByteArray = _tmp34 + // ByteArrayList: + var _tmp38 *[3]byte + if _tmp39, _tmp40, err := dec.Kind(); err != nil { + return err + } else if _tmp40 != 0 || _tmp39 != rlp.List { + var _tmp37 [3]byte + if err := dec.ReadBytes(_tmp37[:]); err != nil { + return err + } + _tmp38 = &_tmp37 + } + _tmp0.ByteArrayList = _tmp38 + // ByteSlice: + var _tmp42 *[]byte + if _tmp43, _tmp44, err := dec.Kind(); err != nil { + return err + } else if _tmp44 != 0 || _tmp43 != rlp.String { + _tmp41, err := dec.Bytes() + if err != nil { + return err + } + _tmp42 = &_tmp41 + } + _tmp0.ByteSlice = _tmp42 + // ByteSliceList: + var _tmp46 *[]byte + if _tmp47, _tmp48, err := dec.Kind(); err != nil { + return err + } else if _tmp48 != 0 || _tmp47 != rlp.List { + _tmp45, err := dec.Bytes() + if err != nil { + return err + } + _tmp46 = &_tmp45 + } + _tmp0.ByteSliceList = _tmp46 + // Struct: + var _tmp51 *Aux + if _tmp52, _tmp53, err := dec.Kind(); err != nil { + return err + } else if _tmp53 != 0 || _tmp52 != rlp.List { + var _tmp49 Aux + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp50, err := dec.Uint32() + if err != nil { + return err + } + _tmp49.A = _tmp50 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp51 = &_tmp49 + } + _tmp0.Struct = _tmp51 + // StructString: + var _tmp56 *Aux + if _tmp57, _tmp58, err := dec.Kind(); err != nil { + return err + } else if _tmp58 != 0 || _tmp57 != rlp.String { + var _tmp54 Aux + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp55, err := dec.Uint32() + if err != nil { + return err + } + _tmp54.A = _tmp55 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp56 = &_tmp54 + } + _tmp0.StructString = _tmp56 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/optional.in.txt b/rlp/rlpgen/testdata/optional.in.txt new file mode 100644 index 0000000000..f1ac9f7899 --- /dev/null +++ b/rlp/rlpgen/testdata/optional.in.txt @@ -0,0 +1,17 @@ +// -*- mode: go -*- + +package test + +type Aux struct { + A uint64 +} + +type Test struct { + Uint64 uint64 `rlp:"optional"` + Pointer *uint64 `rlp:"optional"` + String string `rlp:"optional"` + Slice []uint64 `rlp:"optional"` + Array [3]byte `rlp:"optional"` + NamedStruct Aux `rlp:"optional"` + AnonStruct struct{ A string } `rlp:"optional"` +} diff --git a/rlp/rlpgen/testdata/optional.out.txt b/rlp/rlpgen/testdata/optional.out.txt new file mode 100644 index 0000000000..fb9b95d44d --- /dev/null +++ b/rlp/rlpgen/testdata/optional.out.txt @@ -0,0 +1,153 @@ +package test + +import "github.com/tomochain/tomochain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + _tmp1 := obj.Uint64 != 0 + _tmp2 := obj.Pointer != nil + _tmp3 := obj.String != "" + _tmp4 := len(obj.Slice) > 0 + _tmp5 := obj.Array != ([3]byte{}) + _tmp6 := obj.NamedStruct != (Aux{}) + _tmp7 := obj.AnonStruct != (struct{ A string }{}) + if _tmp1 || _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + w.WriteUint64(obj.Uint64) + } + if _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + if obj.Pointer == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64((*obj.Pointer)) 
+ } + } + if _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + w.WriteString(obj.String) + } + if _tmp4 || _tmp5 || _tmp6 || _tmp7 { + _tmp8 := w.List() + for _, _tmp9 := range obj.Slice { + w.WriteUint64(_tmp9) + } + w.ListEnd(_tmp8) + } + if _tmp5 || _tmp6 || _tmp7 { + w.WriteBytes(obj.Array[:]) + } + if _tmp6 || _tmp7 { + _tmp10 := w.List() + w.WriteUint64(obj.NamedStruct.A) + w.ListEnd(_tmp10) + } + if _tmp7 { + _tmp11 := w.List() + w.WriteString(obj.AnonStruct.A) + w.ListEnd(_tmp11) + } + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Uint64: + if dec.MoreDataInList() { + _tmp1, err := dec.Uint64() + if err != nil { + return err + } + _tmp0.Uint64 = _tmp1 + // Pointer: + if dec.MoreDataInList() { + _tmp2, err := dec.Uint64() + if err != nil { + return err + } + _tmp0.Pointer = &_tmp2 + // String: + if dec.MoreDataInList() { + _tmp3, err := dec.String() + if err != nil { + return err + } + _tmp0.String = _tmp3 + // Slice: + if dec.MoreDataInList() { + var _tmp4 []uint64 + if _, err := dec.List(); err != nil { + return err + } + for dec.MoreDataInList() { + _tmp5, err := dec.Uint64() + if err != nil { + return err + } + _tmp4 = append(_tmp4, _tmp5) + } + if err := dec.ListEnd(); err != nil { + return err + } + _tmp0.Slice = _tmp4 + // Array: + if dec.MoreDataInList() { + var _tmp6 [3]byte + if err := dec.ReadBytes(_tmp6[:]); err != nil { + return err + } + _tmp0.Array = _tmp6 + // NamedStruct: + if dec.MoreDataInList() { + var _tmp7 Aux + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp8, err := dec.Uint64() + if err != nil { + return err + } + _tmp7.A = _tmp8 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp0.NamedStruct = _tmp7 + // AnonStruct: + if dec.MoreDataInList() { + var _tmp9 struct{ A string } + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp10, err := dec.String() + if err != nil { + return err + } + _tmp9.A = _tmp10 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp0.AnonStruct = _tmp9 + } + } + } + } + } + } + } + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/rawvalue.in.txt b/rlp/rlpgen/testdata/rawvalue.in.txt new file mode 100644 index 0000000000..6c17849954 --- /dev/null +++ b/rlp/rlpgen/testdata/rawvalue.in.txt @@ -0,0 +1,11 @@ +// -*- mode: go -*- + +package test + +import "github.com/tomochain/tomochain/rlp" + +type Test struct { + RawValue rlp.RawValue + PointerToRawValue *rlp.RawValue + SliceOfRawValue []rlp.RawValue +} diff --git a/rlp/rlpgen/testdata/rawvalue.out.txt b/rlp/rlpgen/testdata/rawvalue.out.txt new file mode 100644 index 0000000000..4b6eb385d6 --- /dev/null +++ b/rlp/rlpgen/testdata/rawvalue.out.txt @@ -0,0 +1,64 @@ +package test + +import "github.com/tomochain/tomochain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + w.Write(obj.RawValue) + if obj.PointerToRawValue == nil { + w.Write([]byte{0x80}) + } else { + w.Write((*obj.PointerToRawValue)) + } + _tmp1 := w.List() + for _, _tmp2 := range obj.SliceOfRawValue { + w.Write(_tmp2) + } + w.ListEnd(_tmp1) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // RawValue: + _tmp1, err := dec.Raw() + if err != nil { + return err 
+ } + _tmp0.RawValue = _tmp1 + // PointerToRawValue: + _tmp2, err := dec.Raw() + if err != nil { + return err + } + _tmp0.PointerToRawValue = &_tmp2 + // SliceOfRawValue: + var _tmp3 []rlp.RawValue + if _, err := dec.List(); err != nil { + return err + } + for dec.MoreDataInList() { + _tmp4, err := dec.Raw() + if err != nil { + return err + } + _tmp3 = append(_tmp3, _tmp4) + } + if err := dec.ListEnd(); err != nil { + return err + } + _tmp0.SliceOfRawValue = _tmp3 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/uint256.in.txt b/rlp/rlpgen/testdata/uint256.in.txt new file mode 100644 index 0000000000..ed16e0a788 --- /dev/null +++ b/rlp/rlpgen/testdata/uint256.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +import "github.com/holiman/uint256" + +type Test struct { + Int *uint256.Int + IntNoPtr uint256.Int +} diff --git a/rlp/rlpgen/testdata/uint256.out.txt b/rlp/rlpgen/testdata/uint256.out.txt new file mode 100644 index 0000000000..5d99ca2e6d --- /dev/null +++ b/rlp/rlpgen/testdata/uint256.out.txt @@ -0,0 +1,44 @@ +package test + +import "github.com/holiman/uint256" +import "github.com/tomochain/tomochain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Int == nil { + w.Write(rlp.EmptyString) + } else { + w.WriteUint256(obj.Int) + } + w.WriteUint256(&obj.IntNoPtr) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Int: + var _tmp1 uint256.Int + if err := dec.ReadUint256(&_tmp1); err != nil { + return err + } + _tmp0.Int = &_tmp1 + // IntNoPtr: + var _tmp2 uint256.Int + if err := dec.ReadUint256(&_tmp2); err != nil { + return err + } + _tmp0.IntNoPtr = _tmp2 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/uints.in.txt b/rlp/rlpgen/testdata/uints.in.txt new file mode 100644 index 0000000000..8095da997d --- /dev/null +++ b/rlp/rlpgen/testdata/uints.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +type Test struct{ + A uint8 + B uint16 + C uint32 + D uint64 +} diff --git a/rlp/rlpgen/testdata/uints.out.txt b/rlp/rlpgen/testdata/uints.out.txt new file mode 100644 index 0000000000..17896dd305 --- /dev/null +++ b/rlp/rlpgen/testdata/uints.out.txt @@ -0,0 +1,53 @@ +package test + +import "github.com/tomochain/tomochain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + w.WriteUint64(uint64(obj.A)) + w.WriteUint64(uint64(obj.B)) + w.WriteUint64(uint64(obj.C)) + w.WriteUint64(obj.D) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp1, err := dec.Uint8() + if err != nil { + return err + } + _tmp0.A = _tmp1 + // B: + _tmp2, err := dec.Uint16() + if err != nil { + return err + } + _tmp0.B = _tmp2 + // C: + _tmp3, err := dec.Uint32() + if err != nil { + return err + } + _tmp0.C = _tmp3 + // D: + _tmp4, err := dec.Uint64() + if err != nil { + return err + } + _tmp0.D = _tmp4 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/types.go b/rlp/rlpgen/types.go new file mode 100644 index 0000000000..ea7dc96d88 --- /dev/null +++ b/rlp/rlpgen/types.go @@ 
-0,0 +1,124 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "fmt" + "go/types" + "reflect" +) + +// typeReflectKind gives the reflect.Kind that represents typ. +func typeReflectKind(typ types.Type) reflect.Kind { + switch typ := typ.(type) { + case *types.Basic: + k := typ.Kind() + if k >= types.Bool && k <= types.Complex128 { + // value order matches for Bool..Complex128 + return reflect.Bool + reflect.Kind(k-types.Bool) + } + if k == types.String { + return reflect.String + } + if k == types.UnsafePointer { + return reflect.UnsafePointer + } + panic(fmt.Errorf("unhandled BasicKind %v", k)) + case *types.Array: + return reflect.Array + case *types.Chan: + return reflect.Chan + case *types.Interface: + return reflect.Interface + case *types.Map: + return reflect.Map + case *types.Pointer: + return reflect.Ptr + case *types.Signature: + return reflect.Func + case *types.Slice: + return reflect.Slice + case *types.Struct: + return reflect.Struct + default: + panic(fmt.Errorf("unhandled type %T", typ)) + } +} + +// nonZeroCheck returns the expression that checks whether 'v' is a non-zero value of type 'vtyp'. +func nonZeroCheck(v string, vtyp types.Type, qualify types.Qualifier) string { + // Resolve type name. + typ := resolveUnderlying(vtyp) + switch typ := typ.(type) { + case *types.Basic: + k := typ.Kind() + switch { + case k == types.Bool: + return v + case k >= types.Uint && k <= types.Complex128: + return fmt.Sprintf("%s != 0", v) + case k == types.String: + return fmt.Sprintf(`%s != ""`, v) + default: + panic(fmt.Errorf("unhandled BasicKind %v", k)) + } + case *types.Array, *types.Struct: + return fmt.Sprintf("%s != (%s{})", v, types.TypeString(vtyp, qualify)) + case *types.Interface, *types.Pointer, *types.Signature: + return fmt.Sprintf("%s != nil", v) + case *types.Slice, *types.Map: + return fmt.Sprintf("len(%s) > 0", v) + default: + panic(fmt.Errorf("unhandled type %T", typ)) + } +} + +// isBigInt checks whether 'typ' is "math/big".Int. +func isBigInt(typ types.Type) bool { + named, ok := typ.(*types.Named) + if !ok { + return false + } + name := named.Obj() + return name.Pkg().Path() == "math/big" && name.Name() == "Int" +} + +// isUint256 checks whether 'typ' is "github.com/holiman/uint256".Int. +func isUint256(typ types.Type) bool { + named, ok := typ.(*types.Named) + if !ok { + return false + } + name := named.Obj() + return name.Pkg().Path() == "github.com/holiman/uint256" && name.Name() == "Int" +} + +// isByte checks whether the underlying type of 'typ' is uint8. 
+func isByte(typ types.Type) bool { + basic, ok := resolveUnderlying(typ).(*types.Basic) + return ok && basic.Kind() == types.Uint8 +} + +func resolveUnderlying(typ types.Type) types.Type { + for { + t := typ.Underlying() + if t == typ { + return t + } + typ = t + } +} From ddf35f8c4f74ab1b2a801f4960428bb61f0b0526 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Fri, 16 Jun 2023 16:02:10 +0700 Subject: [PATCH 004/119] Update unit tests and benchmark --- rlp/decode_test.go | 520 ++++++++++++++++++++++++++++++++-- rlp/encbuffer_example_test.go | 45 +++ rlp/encode_test.go | 338 ++++++++++++++++++++-- rlp/encoder_example_test.go | 20 +- rlp/iterator_test.go | 59 ++++ rlp/raw_test.go | 160 ++++++++++- 6 files changed, 1074 insertions(+), 68 deletions(-) create mode 100644 rlp/encbuffer_example_test.go create mode 100644 rlp/iterator_test.go diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 4d8abd0012..3ee237fb09 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -26,6 +26,10 @@ import ( "reflect" "strings" "testing" + + "github.com/tomochain/tomochain/common/math" + + "github.com/holiman/uint256" ) func TestStreamKind(t *testing.T) { @@ -284,6 +288,47 @@ func TestStreamRaw(t *testing.T) { } } +func TestStreamReadBytes(t *testing.T) { + tests := []struct { + input string + size int + err string + }{ + // kind List + {input: "C0", size: 1, err: "rlp: expected String or Byte"}, + // kind Byte + {input: "04", size: 0, err: "input value has wrong size 1, want 0"}, + {input: "04", size: 1}, + {input: "04", size: 2, err: "input value has wrong size 1, want 2"}, + // kind String + {input: "820102", size: 0, err: "input value has wrong size 2, want 0"}, + {input: "820102", size: 1, err: "input value has wrong size 2, want 1"}, + {input: "820102", size: 2}, + {input: "820102", size: 3, err: "input value has wrong size 2, want 3"}, + } + + for _, test := range tests { + test := test + name := fmt.Sprintf("input_%s/size_%d", test.input, test.size) + t.Run(name, func(t *testing.T) { + s := NewStream(bytes.NewReader(unhex(test.input)), 0) + b := make([]byte, test.size) + err := s.ReadBytes(b) + if test.err == "" { + if err != nil { + t.Errorf("unexpected error %q", err) + } + } else { + if err == nil { + t.Errorf("expected error, got nil") + } else if err.Error() != test.err { + t.Errorf("wrong error %q", err) + } + } + }) + } +} + func TestDecodeErrors(t *testing.T) { r := bytes.NewReader(nil) @@ -327,6 +372,15 @@ type recstruct struct { Child *recstruct `rlp:"nil"` } +type bigIntStruct struct { + I *big.Int + B string +} + +type invalidNilTag struct { + X []byte `rlp:"nil"` +} + type invalidTail1 struct { A uint `rlp:"tail"` B string @@ -347,19 +401,79 @@ type tailUint struct { Tail []uint `rlp:"tail"` } -var ( - veryBigInt = big.NewInt(0).Add( - big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), - big.NewInt(0xFFFF), - ) -) +type tailPrivateFields struct { + A uint + Tail []uint `rlp:"tail"` + x, y bool //lint:ignore U1000 unused fields required for testing purposes. 
+} + +type nilListUint struct { + X *uint `rlp:"nilList"` +} + +type nilStringSlice struct { + X *[]uint `rlp:"nilString"` +} + +type intField struct { + X int +} + +type optionalFields struct { + A uint + B uint `rlp:"optional"` + C uint `rlp:"optional"` +} + +type optionalAndTailField struct { + A uint + B uint `rlp:"optional"` + Tail []uint `rlp:"tail"` +} + +type optionalBigIntField struct { + A uint + B *big.Int `rlp:"optional"` +} + +type optionalPtrField struct { + A uint + B *[3]byte `rlp:"optional"` +} + +type nonOptionalPtrField struct { + A uint + B *[3]byte +} -type hasIgnoredField struct { +type multipleOptionalFields struct { + A *[3]byte `rlp:"optional"` + B *[3]byte `rlp:"optional"` +} + +type optionalPtrFieldNil struct { + A uint + B *[3]byte `rlp:"optional,nil"` +} + +type ignoredField struct { A uint B uint `rlp:"-"` C uint } +var ( + veryBigInt = new(big.Int).Add( + new(big.Int).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), + big.NewInt(0xFFFF), + ) + veryVeryBigInt = new(big.Int).Exp(veryBigInt, big.NewInt(8), nil) +) + +var ( + veryBigInt256, _ = uint256.FromBig(veryBigInt) +) + var decodeTests = []decodeTest{ // booleans {input: "01", ptr: new(bool), value: true}, @@ -428,12 +542,31 @@ var decodeTests = []decodeTest{ {input: "C0", ptr: new(string), error: "rlp: expected input string or byte for string"}, // big ints + {input: "80", ptr: new(*big.Int), value: big.NewInt(0)}, {input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, + {input: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001", ptr: new(*big.Int), value: veryVeryBigInt}, {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works + + // big int errors {input: "C0", ptr: new(*big.Int), error: "rlp: expected input string or byte for *big.Int"}, - {input: "820001", ptr: new(big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, - {input: "8105", ptr: new(big.Int), error: "rlp: non-canonical size information for *big.Int"}, + {input: "00", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, + {input: "820001", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, + {input: "8105", ptr: new(*big.Int), error: "rlp: non-canonical size information for *big.Int"}, + + // uint256 + {input: "80", ptr: new(*uint256.Int), value: uint256.NewInt(0)}, + {input: "01", ptr: new(*uint256.Int), value: uint256.NewInt(1)}, + {input: "88FFFFFFFFFFFFFFFF", ptr: new(*uint256.Int), value: uint256.NewInt(math.MaxUint64)}, + {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*uint256.Int), value: veryBigInt256}, + {input: "10", ptr: new(uint256.Int), value: *uint256.NewInt(16)}, // non-pointer also works + + // uint256 errors + {input: "C0", ptr: new(*uint256.Int), error: "rlp: expected input string or byte for *uint256.Int"}, + {input: "00", ptr: new(*uint256.Int), error: "rlp: non-canonical integer (leading zero bytes) for *uint256.Int"}, + {input: "820001", ptr: new(*uint256.Int), error: "rlp: non-canonical integer (leading zero bytes) for *uint256.Int"}, + {input: "8105", ptr: new(*uint256.Int), error: "rlp: non-canonical size information for *uint256.Int"}, + {input: "A1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00", ptr: new(*uint256.Int), error: "rlp: value too large for uint256"}, // structs { @@ -446,6 +579,13 @@ var 
decodeTests = []decodeTest{ ptr: new(recstruct), value: recstruct{1, &recstruct{2, &recstruct{3, nil}}}, }, + { + // This checks that empty big.Int works correctly in struct context. It's easy to + // miss the update of s.kind for this case, so it needs its own test. + input: "C58083343434", + ptr: new(bigIntStruct), + value: bigIntStruct{new(big.Int), "444"}, + }, // struct errors { @@ -479,20 +619,20 @@ var decodeTests = []decodeTest{ error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I", }, { - input: "C0", - ptr: new(invalidTail1), - error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail1.A (must be on last field)", - }, - { - input: "C0", - ptr: new(invalidTail2), - error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail2.B (field type is not slice)", + input: "C103", + ptr: new(intField), + error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)", }, { input: "C50102C20102", ptr: new(tailUint), error: "rlp: expected input string or byte for uint, decoding into (rlp.tailUint).Tail[1]", }, + { + input: "C0", + ptr: new(invalidNilTag), + error: `rlp: invalid struct tag "nil" for rlp.invalidNilTag.X (field is not a pointer)`, + }, // struct tag "tail" { @@ -510,12 +650,192 @@ var decodeTests = []decodeTest{ ptr: new(tailRaw), value: tailRaw{A: 1, Tail: []RawValue{}}, }, + { + input: "C3010203", + ptr: new(tailPrivateFields), + value: tailPrivateFields{A: 1, Tail: []uint{2, 3}}, + }, + { + input: "C0", + ptr: new(invalidTail1), + error: `rlp: invalid struct tag "tail" for rlp.invalidTail1.A (must be on last field)`, + }, + { + input: "C0", + ptr: new(invalidTail2), + error: `rlp: invalid struct tag "tail" for rlp.invalidTail2.B (field type is not slice)`, + }, // struct tag "-" { input: "C20102", - ptr: new(hasIgnoredField), - value: hasIgnoredField{A: 1, C: 2}, + ptr: new(ignoredField), + value: ignoredField{A: 1, C: 2}, + }, + + // struct tag "nilList" + { + input: "C180", + ptr: new(nilListUint), + error: "rlp: wrong kind of empty value (got String, want List) for *uint, decoding into (rlp.nilListUint).X", + }, + { + input: "C1C0", + ptr: new(nilListUint), + value: nilListUint{}, + }, + { + input: "C103", + ptr: new(nilListUint), + value: func() interface{} { + v := uint(3) + return nilListUint{X: &v} + }(), + }, + + // struct tag "nilString" + { + input: "C1C0", + ptr: new(nilStringSlice), + error: "rlp: wrong kind of empty value (got List, want String) for *[]uint, decoding into (rlp.nilStringSlice).X", + }, + { + input: "C180", + ptr: new(nilStringSlice), + value: nilStringSlice{}, + }, + { + input: "C2C103", + ptr: new(nilStringSlice), + value: nilStringSlice{X: &[]uint{3}}, + }, + + // struct tag "optional" + { + input: "C101", + ptr: new(optionalFields), + value: optionalFields{1, 0, 0}, + }, + { + input: "C20102", + ptr: new(optionalFields), + value: optionalFields{1, 2, 0}, + }, + { + input: "C3010203", + ptr: new(optionalFields), + value: optionalFields{1, 2, 3}, + }, + { + input: "C401020304", + ptr: new(optionalFields), + error: "rlp: input list has too many elements for rlp.optionalFields", + }, + { + input: "C101", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1}, + }, + { + input: "C20102", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}}, + }, + { + input: "C401020304", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{3, 4}}, + }, + { + input: "C101", + ptr: new(optionalBigIntField), + value: 
optionalBigIntField{A: 1, B: nil}, + }, + { + input: "C20102", + ptr: new(optionalBigIntField), + value: optionalBigIntField{A: 1, B: big.NewInt(2)}, + }, + { + input: "C101", + ptr: new(optionalPtrField), + value: optionalPtrField{A: 1}, + }, + { + input: "C20180", // not accepted because "optional" doesn't enable "nil" + ptr: new(optionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B", + }, + { + input: "C20102", + ptr: new(optionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B", + }, + { + input: "C50183010203", + ptr: new(optionalPtrField), + value: optionalPtrField{A: 1, B: &[3]byte{1, 2, 3}}, + }, + { + // all optional fields nil + input: "C0", + ptr: new(multipleOptionalFields), + value: multipleOptionalFields{A: nil, B: nil}, + }, + { + // all optional fields set + input: "C88301020383010203", + ptr: new(multipleOptionalFields), + value: multipleOptionalFields{A: &[3]byte{1, 2, 3}, B: &[3]byte{1, 2, 3}}, + }, + { + // nil optional field appears before a non-nil one + input: "C58083010203", + ptr: new(multipleOptionalFields), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.multipleOptionalFields).A", + }, + { + // decode a nil ptr into a ptr that is not nil or not optional + input: "C20180", + ptr: new(nonOptionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.nonOptionalPtrField).B", + }, + { + input: "C101", + ptr: new(optionalPtrFieldNil), + value: optionalPtrFieldNil{A: 1}, + }, + { + input: "C20180", // accepted because "nil" tag allows empty input + ptr: new(optionalPtrFieldNil), + value: optionalPtrFieldNil{A: 1}, + }, + { + input: "C20102", + ptr: new(optionalPtrFieldNil), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrFieldNil).B", + }, + + // struct tag "optional" field clearing + { + input: "C101", + ptr: &optionalFields{A: 9, B: 8, C: 7}, + value: optionalFields{A: 1, B: 0, C: 0}, + }, + { + input: "C20102", + ptr: &optionalFields{A: 9, B: 8, C: 7}, + value: optionalFields{A: 1, B: 2, C: 0}, + }, + { + input: "C20102", + ptr: &optionalAndTailField{A: 9, B: 8, Tail: []uint{7, 6, 5}}, + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}}, + }, + { + input: "C101", + ptr: &optionalPtrField{A: 9, B: &[3]byte{8, 7, 6}}, + value: optionalPtrField{A: 1}, }, // RawValue @@ -591,6 +911,26 @@ func TestDecodeWithByteReader(t *testing.T) { }) } +func testDecodeWithEncReader(t *testing.T, n int) { + s := strings.Repeat("0", n) + _, r, _ := EncodeToReader(s) + var decoded string + err := Decode(r, &decoded) + if err != nil { + t.Errorf("Unexpected decode error with n=%v: %v", n, err) + } + if decoded != s { + t.Errorf("Decode mismatch with n=%v", n) + } +} + +// This is a regression test checking that decoding from encReader +// works for RLP values of size 8192 bytes or more. +func TestDecodeWithEncReader(t *testing.T) { + testDecodeWithEncReader(t, 8188) // length with header is 8191 + testDecodeWithEncReader(t, 8189) // length with header is 8192 +} + // plainReader reads from a byte slice but does not // implement ReadByte. It is also not recognized by the // size validation. 
This is useful to test how the decoder @@ -661,6 +1001,22 @@ func TestDecodeDecoder(t *testing.T) { } } +func TestDecodeDecoderNilPointer(t *testing.T) { + var s struct { + T1 *testDecoder `rlp:"nil"` + T2 *testDecoder + } + if err := Decode(bytes.NewReader(unhex("C2C002")), &s); err != nil { + t.Fatalf("Decode error: %v", err) + } + if s.T1 != nil { + t.Errorf("decoder T1 allocated for empty input (called: %v)", s.T1.called) + } + if s.T2 == nil || !s.T2.called { + t.Errorf("decoder T2 not allocated/called") + } +} + type byteDecoder byte func (bd *byteDecoder) DecodeRLP(s *Stream) error { @@ -691,13 +1047,66 @@ func TestDecoderInByteSlice(t *testing.T) { } } +type unencodableDecoder func() + +func (f *unencodableDecoder) DecodeRLP(s *Stream) error { + if _, err := s.List(); err != nil { + return err + } + if err := s.ListEnd(); err != nil { + return err + } + *f = func() {} + return nil +} + +func TestDecoderFunc(t *testing.T) { + var x func() + if err := DecodeBytes([]byte{0xC0}, (*unencodableDecoder)(&x)); err != nil { + t.Fatal(err) + } + x() +} + +// This tests the validity checks for fields with struct tag "optional". +func TestInvalidOptionalField(t *testing.T) { + type ( + invalid1 struct { + A uint `rlp:"optional"` + B uint + } + invalid2 struct { + T []uint `rlp:"tail,optional"` + } + invalid3 struct { + T []uint `rlp:"optional,tail"` + } + ) + + tests := []struct { + v interface{} + err string + }{ + {v: new(invalid1), err: `rlp: invalid struct tag "" for rlp.invalid1.B (must be optional because preceding field "A" is optional)`}, + {v: new(invalid2), err: `rlp: invalid struct tag "optional" for rlp.invalid2.T (also has "tail" tag)`}, + {v: new(invalid3), err: `rlp: invalid struct tag "tail" for rlp.invalid3.T (also has "optional" tag)`}, + } + for _, test := range tests { + err := DecodeBytes(unhex("C20102"), test.v) + if err == nil { + t.Errorf("no error for %T", test.v) + } else if err.Error() != test.err { + t.Errorf("wrong error for %T: %v", test.v, err.Error()) + } + } +} + func ExampleDecode() { input, _ := hex.DecodeString("C90A1486666F6F626172") type example struct { - A, B uint - private uint // private fields are ignored - String string + A, B uint + String string } var s example @@ -708,7 +1117,7 @@ func ExampleDecode() { fmt.Printf("Decoded value: %#v\n", s) } // Output: - // Decoded value: rlp.example{A:0xa, B:0x14, private:0x0, String:"foobar"} + // Decoded value: rlp.example{A:0xa, B:0x14, String:"foobar"} } func ExampleDecode_structTagNil() { @@ -768,7 +1177,7 @@ func ExampleStream() { // [102 111 111 98 97 114] } -func BenchmarkDecode(b *testing.B) { +func BenchmarkDecodeUints(b *testing.B) { enc := encodeTestSlice(90000) b.SetBytes(int64(len(enc))) b.ReportAllocs() @@ -783,7 +1192,7 @@ func BenchmarkDecode(b *testing.B) { } } -func BenchmarkDecodeIntSliceReuse(b *testing.B) { +func BenchmarkDecodeUintsReused(b *testing.B) { enc := encodeTestSlice(100000) b.SetBytes(int64(len(enc))) b.ReportAllocs() @@ -798,6 +1207,65 @@ func BenchmarkDecodeIntSliceReuse(b *testing.B) { } } +func BenchmarkDecodeByteArrayStruct(b *testing.B) { + enc, err := EncodeToBytes(&byteArrayStruct{}) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out byteArrayStruct + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDecodeBigInts(b *testing.B) { + ints := make([]*big.Int, 200) + for i := range ints { + ints[i] = math.BigPow(2, int64(i)) + } + enc, err := 
EncodeToBytes(ints) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out []*big.Int + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDecodeU256Ints(b *testing.B) { + ints := make([]*uint256.Int, 200) + for i := range ints { + ints[i], _ = uint256.FromBig(math.BigPow(2, int64(i))) + } + enc, err := EncodeToBytes(ints) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out []*uint256.Int + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + func encodeTestSlice(n uint) []byte { s := make([]uint, n) for i := uint(0); i < n; i++ { @@ -811,7 +1279,7 @@ func encodeTestSlice(n uint) []byte { } func unhex(str string) []byte { - b, err := hex.DecodeString(strings.Replace(str, " ", "", -1)) + b, err := hex.DecodeString(strings.ReplaceAll(str, " ", "")) if err != nil { panic(fmt.Sprintf("invalid hex string: %q", str)) } diff --git a/rlp/encbuffer_example_test.go b/rlp/encbuffer_example_test.go new file mode 100644 index 0000000000..c41de60f02 --- /dev/null +++ b/rlp/encbuffer_example_test.go @@ -0,0 +1,45 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp_test + +import ( + "bytes" + "fmt" + + "github.com/tomochain/tomochain/rlp" +) + +func ExampleEncoderBuffer() { + var w bytes.Buffer + + // Encode [4, [5, 6]] to w. 
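+	// The inner list [5, 6] encodes as C2 05 06; together with the
+	// leading 04 under the outer list header C4, the final bytes are
+	// C404C20506 (see the Output comment below).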
+ buf := rlp.NewEncoderBuffer(&w) + l1 := buf.List() + buf.WriteUint64(4) + l2 := buf.List() + buf.WriteUint64(5) + buf.WriteUint64(6) + buf.ListEnd(l2) + buf.ListEnd(l1) + + if err := buf.Flush(); err != nil { + panic(err) + } + fmt.Printf("%X\n", w.Bytes()) + // Output: + // C404C20506 +} diff --git a/rlp/encode_test.go b/rlp/encode_test.go index 827960f7c1..7b8775c12b 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -21,10 +21,14 @@ import ( "errors" "fmt" "io" - "io/ioutil" "math/big" + "runtime" "sync" "testing" + + "github.com/tomochain/tomochain/common/math" + + "github.com/holiman/uint256" ) type testEncoder struct { @@ -33,12 +37,19 @@ type testEncoder struct { func (e *testEncoder) EncodeRLP(w io.Writer) error { if e == nil { - w.Write([]byte{0, 0, 0, 0}) - } else if e.err != nil { + panic("EncodeRLP called on nil value") + } + if e.err != nil { return e.err - } else { - w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}) } + w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}) + return nil +} + +type testEncoderValueMethod struct{} + +func (e testEncoderValueMethod) EncodeRLP(w io.Writer) error { + w.Write([]byte{0xFA, 0xFE, 0xF0}) return nil } @@ -49,6 +60,13 @@ func (e byteEncoder) EncodeRLP(w io.Writer) error { return nil } +type undecodableEncoder func() + +func (f undecodableEncoder) EncodeRLP(w io.Writer) error { + w.Write([]byte{0xF5, 0xF5, 0xF5}) + return nil +} + type encodableReader struct { A, B uint } @@ -103,35 +121,95 @@ var encTests = []encTest{ {val: big.NewInt(0xFFFFFFFFFFFF), output: "86FFFFFFFFFFFF"}, {val: big.NewInt(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"}, { - val: big.NewInt(0).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), + val: new(big.Int).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), output: "8F102030405060708090A0B0C0D0E0F2", }, { - val: big.NewInt(0).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), + val: new(big.Int).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), output: "9C0100020003000400050006000700080009000A000B000C000D000E01", }, { - val: big.NewInt(0).SetBytes(unhex("010000000000000000000000000000000000000000000000000000000000000000")), + val: new(big.Int).SetBytes(unhex("010000000000000000000000000000000000000000000000000000000000000000")), output: "A1010000000000000000000000000000000000000000000000000000000000000000", }, + { + val: veryBigInt, + output: "89FFFFFFFFFFFFFFFFFF", + }, + { + val: veryVeryBigInt, + output: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001", + }, // non-pointer big.Int {val: *big.NewInt(0), output: "80"}, {val: *big.NewInt(0xFFFFFF), output: "83FFFFFF"}, // negative ints are not supported - {val: big.NewInt(-1), error: "rlp: cannot encode negative *big.Int"}, - - // byte slices, strings + {val: big.NewInt(-1), error: "rlp: cannot encode negative big.Int"}, + {val: *big.NewInt(-1), error: "rlp: cannot encode negative big.Int"}, + + // uint256 + {val: uint256.NewInt(0), output: "80"}, + {val: uint256.NewInt(1), output: "01"}, + {val: uint256.NewInt(127), output: "7F"}, + {val: uint256.NewInt(128), output: "8180"}, + {val: uint256.NewInt(256), output: "820100"}, + {val: uint256.NewInt(1024), output: "820400"}, + {val: uint256.NewInt(0xFFFFFF), output: "83FFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFF), output: "84FFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFF), output: "85FFFFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFFFF), output: 
"86FFFFFFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"}, + { + val: new(uint256.Int).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), + output: "8F102030405060708090A0B0C0D0E0F2", + }, + { + val: new(uint256.Int).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), + output: "9C0100020003000400050006000700080009000A000B000C000D000E01", + }, + // non-pointer uint256.Int + {val: *uint256.NewInt(0), output: "80"}, + {val: *uint256.NewInt(0xFFFFFF), output: "83FFFFFF"}, + + // byte arrays + {val: [0]byte{}, output: "80"}, + {val: [1]byte{0}, output: "00"}, + {val: [1]byte{1}, output: "01"}, + {val: [1]byte{0x7F}, output: "7F"}, + {val: [1]byte{0x80}, output: "8180"}, + {val: [1]byte{0xFF}, output: "81FF"}, + {val: [3]byte{1, 2, 3}, output: "83010203"}, + {val: [57]byte{1, 2, 3}, output: "B839010203000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}, + + // named byte type arrays + {val: [0]namedByteType{}, output: "80"}, + {val: [1]namedByteType{0}, output: "00"}, + {val: [1]namedByteType{1}, output: "01"}, + {val: [1]namedByteType{0x7F}, output: "7F"}, + {val: [1]namedByteType{0x80}, output: "8180"}, + {val: [1]namedByteType{0xFF}, output: "81FF"}, + {val: [3]namedByteType{1, 2, 3}, output: "83010203"}, + {val: [57]namedByteType{1, 2, 3}, output: "B839010203000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}, + + // byte slices {val: []byte{}, output: "80"}, + {val: []byte{0}, output: "00"}, {val: []byte{0x7E}, output: "7E"}, {val: []byte{0x7F}, output: "7F"}, {val: []byte{0x80}, output: "8180"}, {val: []byte{1, 2, 3}, output: "83010203"}, + // named byte type slices + {val: []namedByteType{}, output: "80"}, + {val: []namedByteType{0}, output: "00"}, + {val: []namedByteType{0x7E}, output: "7E"}, + {val: []namedByteType{0x7F}, output: "7F"}, + {val: []namedByteType{0x80}, output: "8180"}, {val: []namedByteType{1, 2, 3}, output: "83010203"}, - {val: [...]namedByteType{1, 2, 3}, output: "83010203"}, + // strings {val: "", output: "80"}, {val: "\x7E", output: "7E"}, {val: "\x7F", output: "7F"}, @@ -204,6 +282,12 @@ var encTests = []encTest{ output: "F90200CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376", }, + // Non-byte arrays are encoded as lists. + // Note that it is important to test [4]uint64 specifically, + // because that's the underlying type of uint256.Int. 
+ {val: [4]uint32{1, 2, 3, 4}, output: "C401020304"}, + {val: [4]uint64{1, 2, 3, 4}, output: "C401020304"}, + // RawValue {val: RawValue(unhex("01")), output: "01"}, {val: RawValue(unhex("82FFFF")), output: "82FFFF"}, @@ -214,11 +298,34 @@ var encTests = []encTest{ {val: simplestruct{A: 3, B: "foo"}, output: "C50383666F6F"}, {val: &recstruct{5, nil}, output: "C205C0"}, {val: &recstruct{5, &recstruct{4, &recstruct{3, nil}}}, output: "C605C404C203C0"}, + {val: &intField{X: 3}, error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)"}, + + // struct tag "-" + {val: &ignoredField{A: 1, B: 2, C: 3}, output: "C20103"}, + + // struct tag "tail" {val: &tailRaw{A: 1, Tail: []RawValue{unhex("02"), unhex("03")}}, output: "C3010203"}, {val: &tailRaw{A: 1, Tail: []RawValue{unhex("02")}}, output: "C20102"}, {val: &tailRaw{A: 1, Tail: []RawValue{}}, output: "C101"}, {val: &tailRaw{A: 1, Tail: nil}, output: "C101"}, - {val: &hasIgnoredField{A: 1, B: 2, C: 3}, output: "C20103"}, + + // struct tag "optional" + {val: &optionalFields{}, output: "C180"}, + {val: &optionalFields{A: 1}, output: "C101"}, + {val: &optionalFields{A: 1, B: 2}, output: "C20102"}, + {val: &optionalFields{A: 1, B: 2, C: 3}, output: "C3010203"}, + {val: &optionalFields{A: 1, B: 0, C: 3}, output: "C3018003"}, + {val: &optionalAndTailField{A: 1}, output: "C101"}, + {val: &optionalAndTailField{A: 1, B: 2}, output: "C20102"}, + {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"}, + {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"}, + {val: &optionalBigIntField{A: 1}, output: "C101"}, + {val: &optionalPtrField{A: 1}, output: "C101"}, + {val: &optionalPtrFieldNil{A: 1}, output: "C101"}, + {val: &multipleOptionalFields{A: nil, B: nil}, output: "C0"}, + {val: &multipleOptionalFields{A: &[3]byte{1, 2, 3}, B: &[3]byte{1, 2, 3}}, output: "C88301020383010203"}, + {val: &multipleOptionalFields{A: nil, B: &[3]byte{1, 2, 3}}, output: "C58083010203"}, // encodes without error but decode will fail + {val: &nonOptionalPtrField{A: 1}, output: "C20180"}, // encodes without error but decode will fail // nil {val: (*uint)(nil), output: "80"}, @@ -226,26 +333,73 @@ var encTests = []encTest{ {val: (*[]byte)(nil), output: "80"}, {val: (*[10]byte)(nil), output: "80"}, {val: (*big.Int)(nil), output: "80"}, + {val: (*uint256.Int)(nil), output: "80"}, {val: (*[]string)(nil), output: "C0"}, {val: (*[10]string)(nil), output: "C0"}, {val: (*[]interface{})(nil), output: "C0"}, {val: (*[]struct{ uint })(nil), output: "C0"}, {val: (*interface{})(nil), output: "C0"}, + // nil struct fields + { + val: struct { + X *[]byte + }{}, + output: "C180", + }, + { + val: struct { + X *[2]byte + }{}, + output: "C180", + }, + { + val: struct { + X *uint64 + }{}, + output: "C180", + }, + { + val: struct { + X *uint64 `rlp:"nilList"` + }{}, + output: "C1C0", + }, + { + val: struct { + X *[]uint64 + }{}, + output: "C1C0", + }, + { + val: struct { + X *[]uint64 `rlp:"nilString"` + }{}, + output: "C180", + }, + // interfaces {val: []io.Reader{reader}, output: "C3C20102"}, // the contained value is a struct // Encoder - {val: (*testEncoder)(nil), output: "00000000"}, + {val: (*testEncoder)(nil), output: "C0"}, {val: &testEncoder{}, output: "00010001000100010001"}, {val: &testEncoder{errors.New("test error")}, error: "test error"}, - // verify that pointer method testEncoder.EncodeRLP is called for + {val: struct{ E testEncoderValueMethod }{}, output: "C3FAFEF0"}, + {val: struct{ E *testEncoderValueMethod }{}, output: 
"C1C0"}, + + // Verify that the Encoder interface works for unsupported types like func(). + {val: undecodableEncoder(func() {}), output: "F5F5F5"}, + + // Verify that pointer method testEncoder.EncodeRLP is called for // addressable non-pointer values. {val: &struct{ TE testEncoder }{testEncoder{}}, output: "CA00010001000100010001"}, {val: &struct{ TE testEncoder }{testEncoder{errors.New("test error")}}, error: "test error"}, - // verify the error for non-addressable non-pointer Encoder - {val: testEncoder{}, error: "rlp: game over: unadressable value of type rlp.testEncoder, EncodeRLP is pointer method"}, - // verify the special case for []byte + + // Verify the error for non-addressable non-pointer Encoder. + {val: testEncoder{}, error: "rlp: unadressable value of type rlp.testEncoder, EncodeRLP is pointer method"}, + + // Verify Encoder takes precedence over []byte. {val: []byteEncoder{0, 1, 2, 3, 4}, output: "C5C0C0C0C0C0"}, } @@ -281,13 +435,28 @@ func TestEncodeToBytes(t *testing.T) { runEncTests(t, EncodeToBytes) } +func TestEncodeAppendToBytes(t *testing.T) { + buffer := make([]byte, 20) + runEncTests(t, func(val interface{}) ([]byte, error) { + w := NewEncoderBuffer(nil) + defer w.Flush() + + err := Encode(w, val) + if err != nil { + return nil, err + } + output := w.AppendToBytes(buffer[:0]) + return output, nil + }) +} + func TestEncodeToReader(t *testing.T) { runEncTests(t, func(val interface{}) ([]byte, error) { _, r, err := EncodeToReader(val) if err != nil { return nil, err } - return ioutil.ReadAll(r) + return io.ReadAll(r) }) } @@ -328,7 +497,7 @@ func TestEncodeToReaderReturnToPool(t *testing.T) { go func() { for i := 0; i < 1000; i++ { _, r, _ := EncodeToReader("foo") - ioutil.ReadAll(r) + io.ReadAll(r) r.Read(buf) r.Read(buf) r.Read(buf) @@ -339,3 +508,132 @@ func TestEncodeToReaderReturnToPool(t *testing.T) { } wg.Wait() } + +var sink interface{} + +func BenchmarkIntsize(b *testing.B) { + for i := 0; i < b.N; i++ { + sink = intsize(0x12345678) + } +} + +func BenchmarkPutint(b *testing.B) { + buf := make([]byte, 8) + for i := 0; i < b.N; i++ { + putint(buf, 0x12345678) + sink = buf + } +} + +func BenchmarkEncodeBigInts(b *testing.B) { + ints := make([]*big.Int, 200) + for i := range ints { + ints[i] = math.BigPow(2, int64(i)) + } + out := bytes.NewBuffer(make([]byte, 0, 4096)) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(out, ints); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeU256Ints(b *testing.B) { + ints := make([]*uint256.Int, 200) + for i := range ints { + ints[i], _ = uint256.FromBig(math.BigPow(2, int64(i))) + } + out := bytes.NewBuffer(make([]byte, 0, 4096)) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(out, ints); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeConcurrentInterface(b *testing.B) { + type struct1 struct { + A string + B *big.Int + C [20]byte + } + value := []interface{}{ + uint(999), + &struct1{A: "hello", B: big.NewInt(0xFFFFFFFF)}, + [10]byte{1, 2, 3, 4, 5, 6}, + []string{"yeah", "yeah", "yeah"}, + } + + var wg sync.WaitGroup + for cpu := 0; cpu < runtime.NumCPU(); cpu++ { + wg.Add(1) + go func() { + defer wg.Done() + + var buffer bytes.Buffer + for i := 0; i < b.N; i++ { + buffer.Reset() + err := Encode(&buffer, value) + if err != nil { + panic(err) + } + } + }() + } + wg.Wait() +} + +type byteArrayStruct struct { + A [20]byte + B [32]byte + C [32]byte +} + +func BenchmarkEncodeByteArrayStruct(b 
*testing.B) { + var out bytes.Buffer + var value byteArrayStruct + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(&out, &value); err != nil { + b.Fatal(err) + } + } +} + +type structSliceElem struct { + X uint64 + Y uint64 + Z uint64 +} + +type structPtrSlice []*structSliceElem + +func BenchmarkEncodeStructPtrSlice(b *testing.B) { + var out bytes.Buffer + var value = structPtrSlice{ + &structSliceElem{1, 1, 1}, + &structSliceElem{2, 2, 2}, + &structSliceElem{3, 3, 3}, + &structSliceElem{5, 5, 5}, + &structSliceElem{6, 6, 6}, + &structSliceElem{7, 7, 7}, + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(&out, &value); err != nil { + b.Fatal(err) + } + } +} diff --git a/rlp/encoder_example_test.go b/rlp/encoder_example_test.go index 1cffa241c2..6291bfafe5 100644 --- a/rlp/encoder_example_test.go +++ b/rlp/encoder_example_test.go @@ -14,11 +14,13 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package rlp +package rlp_test import ( "fmt" "io" + + "github.com/tomochain/tomochain/rlp" ) type MyCoolType struct { @@ -28,27 +30,19 @@ type MyCoolType struct { // EncodeRLP writes x as RLP list [a, b] that omits the Name field. func (x *MyCoolType) EncodeRLP(w io.Writer) (err error) { - // Note: the receiver can be a nil pointer. This allows you to - // control the encoding of nil, but it also means that you have to - // check for a nil receiver. - if x == nil { - err = Encode(w, []uint{0, 0}) - } else { - err = Encode(w, []uint{x.a, x.b}) - } - return err + return rlp.Encode(w, []uint{x.a, x.b}) } func ExampleEncoder() { var t *MyCoolType // t is nil pointer to MyCoolType - bytes, _ := EncodeToBytes(t) + bytes, _ := rlp.EncodeToBytes(t) fmt.Printf("%v → %X\n", t, bytes) t = &MyCoolType{Name: "foobar", a: 5, b: 6} - bytes, _ = EncodeToBytes(t) + bytes, _ = rlp.EncodeToBytes(t) fmt.Printf("%v → %X\n", t, bytes) // Output: - // → C28080 + // → C0 // &{foobar 5 6} → C20506 } diff --git a/rlp/iterator_test.go b/rlp/iterator_test.go new file mode 100644 index 0000000000..87c11bdbae --- /dev/null +++ b/rlp/iterator_test.go @@ -0,0 +1,59 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp + +import ( + "testing" + + "github.com/tomochain/tomochain/common/hexutil" +) + +// TestIterator tests some basic things about the ListIterator. 
A more +// comprehensive test can be found in core/rlp_test.go, where we can +// use both types and rlp without dependency cycles +func TestIterator(t *testing.T) { + bodyRlpHex := "0xf902cbf8d6f869800182c35094000000000000000000000000000000000000aaaa808a000000000000000000001ba01025c66fad28b4ce3370222624d952c35529e602af7cbe04f667371f61b0e3b3a00ab8813514d1217059748fd903288ace1b4001a4bc5fbde2790debdc8167de2ff869010182c35094000000000000000000000000000000000000aaaa808a000000000000000000001ca05ac4cf1d19be06f3742c21df6c49a7e929ceb3dbaf6a09f3cfb56ff6828bd9a7a06875970133a35e63ac06d360aa166d228cc013e9b96e0a2cae7f55b22e1ee2e8f901f0f901eda0c75448377c0e426b8017b23c5f77379ecf69abc1d5c224284ad3ba1c46c59adaa00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000808080808080a00000000000000000000000000000000000000000000000000000000000000000880000000000000000" + bodyRlp := hexutil.MustDecode(bodyRlpHex) + + it, err := NewListIterator(bodyRlp) + if err != nil { + t.Fatal(err) + } + // Check that txs exist + if !it.Next() { + t.Fatal("expected two elems, got zero") + } + txs := it.Value() + // Check that uncles exist + if !it.Next() { + t.Fatal("expected two elems, got one") + } + txit, err := NewListIterator(txs) + if err != nil { + t.Fatal(err) + } + var i = 0 + for txit.Next() { + if txit.err != nil { + t.Fatal(txit.err) + } + i++ + } + if exp := 2; i != exp { + t.Errorf("count wrong, expected %d got %d", i, exp) + } +} diff --git a/rlp/raw_test.go b/rlp/raw_test.go index 2aad042100..7b3255eca3 100644 --- a/rlp/raw_test.go +++ b/rlp/raw_test.go @@ -18,9 +18,10 @@ package rlp import ( "bytes" + "errors" "io" - "reflect" "testing" + "testing/quick" ) func TestCountValues(t *testing.T) { @@ -53,21 +54,84 @@ func TestCountValues(t *testing.T) { if count != test.count { t.Errorf("test %d: count mismatch, got %d want %d\ninput: %s", i, count, test.count, test.input) } - if !reflect.DeepEqual(err, test.err) { + if !errors.Is(err, test.err) { t.Errorf("test %d: err mismatch, got %q want %q\ninput: %s", i, err, test.err, test.input) } } } -func TestSplitTypes(t *testing.T) { - if _, _, err := SplitString(unhex("C100")); err != ErrExpectedString { - t.Errorf("SplitString returned %q, want %q", err, ErrExpectedString) +func TestSplitString(t *testing.T) { + for i, test := range []string{ + "C0", + "C100", + "C3010203", + "C88363617483646F67", + "F8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974", + } { + if _, _, err := SplitString(unhex(test)); !errors.Is(err, ErrExpectedString) { + t.Errorf("test %d: error mismatch: have %q, want %q", i, err, ErrExpectedString) + } + } +} + +func TestSplitList(t *testing.T) { + for i, test := range []string{ + "80", + "00", + "01", + "8180", + "81FF", + "820400", 
+ "83636174", + "83646F67", + "B8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974", + } { + if _, _, err := SplitList(unhex(test)); !errors.Is(err, ErrExpectedList) { + t.Errorf("test %d: error mismatch: have %q, want %q", i, err, ErrExpectedList) + } } - if _, _, err := SplitList(unhex("01")); err != ErrExpectedList { - t.Errorf("SplitString returned %q, want %q", err, ErrExpectedList) +} + +func TestSplitUint64(t *testing.T) { + tests := []struct { + input string + val uint64 + rest string + err error + }{ + {"01", 1, "", nil}, + {"7FFF", 0x7F, "FF", nil}, + {"80FF", 0, "FF", nil}, + {"81FAFF", 0xFA, "FF", nil}, + {"82FAFAFF", 0xFAFA, "FF", nil}, + {"83FAFAFAFF", 0xFAFAFA, "FF", nil}, + {"84FAFAFAFAFF", 0xFAFAFAFA, "FF", nil}, + {"85FAFAFAFAFAFF", 0xFAFAFAFAFA, "FF", nil}, + {"86FAFAFAFAFAFAFF", 0xFAFAFAFAFAFA, "FF", nil}, + {"87FAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFA, "FF", nil}, + {"88FAFAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFAFA, "FF", nil}, + + // errors + {"", 0, "", io.ErrUnexpectedEOF}, + {"00", 0, "00", ErrCanonInt}, + {"81", 0, "81", ErrValueTooLarge}, + {"8100", 0, "8100", ErrCanonSize}, + {"8200FF", 0, "8200FF", ErrCanonInt}, + {"8103FF", 0, "8103FF", ErrCanonSize}, + {"89FAFAFAFAFAFAFAFAFAFF", 0, "89FAFAFAFAFAFAFAFAFAFF", errUintOverflow}, } - if _, _, err := SplitList(unhex("81FF")); err != ErrExpectedList { - t.Errorf("SplitString returned %q, want %q", err, ErrExpectedList) + + for i, test := range tests { + val, rest, err := SplitUint64(unhex(test.input)) + if val != test.val { + t.Errorf("test %d: val mismatch: got %x, want %x (input %q)", i, val, test.val, test.input) + } + if !bytes.Equal(rest, unhex(test.rest)) { + t.Errorf("test %d: rest mismatch: got %x, want %s (input %q)", i, rest, test.rest, test.input) + } + if err != test.err { + t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err) + } } } @@ -78,7 +142,9 @@ func TestSplit(t *testing.T) { val, rest string err error }{ + {input: "00FFFF", kind: Byte, val: "00", rest: "FFFF"}, {input: "01FFFF", kind: Byte, val: "01", rest: "FFFF"}, + {input: "7FFFFF", kind: Byte, val: "7F", rest: "FFFF"}, {input: "80FFFF", kind: String, val: "", rest: "FFFF"}, {input: "C3010203", kind: List, val: "010203"}, @@ -194,3 +260,79 @@ func TestReadSize(t *testing.T) { } } } + +func TestAppendUint64(t *testing.T) { + tests := []struct { + input uint64 + slice []byte + output string + }{ + {0, nil, "80"}, + {1, nil, "01"}, + {2, nil, "02"}, + {127, nil, "7F"}, + {128, nil, "8180"}, + {129, nil, "8181"}, + {0xFFFFFF, nil, "83FFFFFF"}, + {127, []byte{1, 2, 3}, "0102037F"}, + {0xFFFFFF, []byte{1, 2, 3}, "01020383FFFFFF"}, + } + + for _, test := range tests { + x := AppendUint64(test.slice, test.input) + if !bytes.Equal(x, unhex(test.output)) { + t.Errorf("AppendUint64(%v, %d): got %x, want %s", test.slice, test.input, x, test.output) + } + + // Check that IntSize returns the appended size. 
+ length := len(x) - len(test.slice) + if s := IntSize(test.input); s != length { + t.Errorf("IntSize(%d): got %d, want %d", test.input, s, length) + } + } +} + +func TestAppendUint64Random(t *testing.T) { + fn := func(i uint64) bool { + enc, _ := EncodeToBytes(i) + encAppend := AppendUint64(nil, i) + return bytes.Equal(enc, encAppend) + } + config := quick.Config{MaxCountScale: 50} + if err := quick.Check(fn, &config); err != nil { + t.Fatal(err) + } +} + +func TestBytesSize(t *testing.T) { + tests := []struct { + v []byte + size uint64 + }{ + {v: []byte{}, size: 1}, + {v: []byte{0x1}, size: 1}, + {v: []byte{0x7E}, size: 1}, + {v: []byte{0x7F}, size: 1}, + {v: []byte{0x80}, size: 2}, + {v: []byte{0xFF}, size: 2}, + {v: []byte{0xFF, 0xF0}, size: 3}, + {v: make([]byte, 55), size: 56}, + {v: make([]byte, 56), size: 58}, + } + + for _, test := range tests { + s := BytesSize(test.v) + if s != test.size { + t.Errorf("BytesSize(%#x) -> %d, want %d", test.v, s, test.size) + } + s = StringSize(string(test.v)) + if s != test.size { + t.Errorf("StringSize(%#x) -> %d, want %d", test.v, s, test.size) + } + // Sanity check: + enc, _ := EncodeToBytes(test.v) + if uint64(len(enc)) != test.size { + t.Errorf("len(EncodeToBytes(%#x)) -> %d, test says %d", test.v, len(enc), test.size) + } + } +} From 17e4876ed655c4adaf50a37339f858d711a13509 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 20 Jun 2023 18:13:56 +0700 Subject: [PATCH 005/119] Include RLPgen tool to CI --- build/ci.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/build/ci.go b/build/ci.go index ea44817049..6af2b18afe 100644 --- a/build/ci.go +++ b/build/ci.go @@ -14,6 +14,7 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . +//go:build none // +build none /* @@ -23,14 +24,13 @@ Usage: go run build/ci.go Available commands are: - install [ -arch architecture ] [ -cc compiler ] [ packages... ] -- builds packages and executables - test [ -coverage ] [ packages... ] -- runs the tests - lint -- runs certain pre-selected linters - importkeys -- imports signing keys from env - xgo [ -alltools ] [ options ] -- cross builds according to options + install [ -arch architecture ] [ -cc compiler ] [ packages... ] -- builds packages and executables + test [ -coverage ] [ packages... ] -- runs the tests + lint -- runs certain pre-selected linters + importkeys -- imports signing keys from env + xgo [ -alltools ] [ options ] -- cross builds according to options For all commands, -n prevents execution of external programs (dry run mode). - */ package main @@ -62,6 +62,7 @@ var ( executablePath("rlpdump"), executablePath("swarm"), executablePath("wnode"), + executablePath("rlp/rlpgen"), } ) From aa7c4b4078b1b672298785a0009f60041d4b6cca Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Wed, 21 Jun 2023 10:59:49 +0700 Subject: [PATCH 006/119] Add benchmarks for types RLP encoding/decoding --- core/types/types_test.go | 111 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 core/types/types_test.go diff --git a/core/types/types_test.go b/core/types/types_test.go new file mode 100644 index 0000000000..03c29a159b --- /dev/null +++ b/core/types/types_test.go @@ -0,0 +1,111 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "math/big" + "testing" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/rlp" +) + +type devnull struct{ len int } + +func (d *devnull) Write(p []byte) (int, error) { + d.len += len(p) + return len(p), nil +} + +func BenchmarkEncodeRLP(b *testing.B) { + benchRLP(b, true) +} + +func BenchmarkDecodeRLP(b *testing.B) { + benchRLP(b, false) +} + +func benchRLP(b *testing.B, encode bool) { + key, _ := crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + to := common.HexToAddress("0x00000000000000000000000000000000deadbeef") + signer := NewEIP155Signer(big.NewInt(1337)) + tx := NewTransaction(1, to, big.NewInt(1), 1000000, big.NewInt(500), nil) + signedTx, err := SignTx(tx, signer, key) + if err != nil { + b.Fatal("cannot sign transaction for benchmarking") + } + for _, tc := range []struct { + name string + obj interface{} + }{ + { + "header", + &Header{ + Difficulty: big.NewInt(10000000000), + Number: big.NewInt(1000), + GasLimit: 8_000_000, + GasUsed: 8_000_000, + Time: big.NewInt(555), + Extra: make([]byte, 32), + }, + }, + { + "receipt-for-storage", + &ReceiptForStorage{ + Status: ReceiptStatusSuccessful, + CumulativeGasUsed: 0x888888888, + Logs: make([]*Log, 0), + }, + }, + { + "receipt-full", + &Receipt{ + Status: ReceiptStatusSuccessful, + CumulativeGasUsed: 0x888888888, + Logs: make([]*Log, 0), + }, + }, + { + "transaction", + signedTx, + }, + } { + if encode { + b.Run(tc.name, func(b *testing.B) { + b.ReportAllocs() + var null = &devnull{} + for i := 0; i < b.N; i++ { + rlp.Encode(null, tc.obj) + } + b.SetBytes(int64(null.len / b.N)) + }) + } else { + data, _ := rlp.EncodeToBytes(tc.obj) + // Test decoding + b.Run(tc.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if err := rlp.DecodeBytes(data, tc.obj); err != nil { + b.Fatal(err) + } + } + b.SetBytes(int64(len(data))) + }) + } + } +} From ac89cd3357d9620cbaa6aab4752c0282a2b82ccf Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Wed, 21 Jun 2023 11:40:16 +0700 Subject: [PATCH 007/119] Generate RLP encoder for some structs --- core/types/gen_header_rlp.go | 58 ++++++++++++++++++++++++++++++++++++ core/types/gen_log_rlp.go | 26 ++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 core/types/gen_header_rlp.go create mode 100644 core/types/gen_log_rlp.go diff --git a/core/types/gen_header_rlp.go b/core/types/gen_header_rlp.go new file mode 100644 index 0000000000..1422cf6b16 --- /dev/null +++ b/core/types/gen_header_rlp.go @@ -0,0 +1,58 @@ +// Code generated by rlpgen. DO NOT EDIT. 
+ +//go:build !norlpgen +// +build !norlpgen + +package types + +import ( + "io" + + "github.com/tomochain/tomochain/rlp" +) + +func (obj *Header) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + w.WriteBytes(obj.ParentHash[:]) + w.WriteBytes(obj.UncleHash[:]) + w.WriteBytes(obj.Coinbase[:]) + w.WriteBytes(obj.Root[:]) + w.WriteBytes(obj.TxHash[:]) + w.WriteBytes(obj.ReceiptHash[:]) + w.WriteBytes(obj.Bloom[:]) + if obj.Difficulty == nil { + w.Write(rlp.EmptyString) + } else { + if obj.Difficulty.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(obj.Difficulty) + } + if obj.Number == nil { + w.Write(rlp.EmptyString) + } else { + if obj.Number.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(obj.Number) + } + w.WriteUint64(obj.GasLimit) + w.WriteUint64(obj.GasUsed) + if obj.Time == nil { + w.Write(rlp.EmptyString) + } else { + if obj.Time.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(obj.Time) + } + w.WriteBytes(obj.Extra) + w.WriteBytes(obj.MixDigest[:]) + w.WriteBytes(obj.Nonce[:]) + w.WriteBytes(obj.Validators) + w.WriteBytes(obj.Validator) + w.WriteBytes(obj.Penalties) + w.ListEnd(_tmp0) + return w.Flush() +} diff --git a/core/types/gen_log_rlp.go b/core/types/gen_log_rlp.go new file mode 100644 index 0000000000..3f2c3ddc06 --- /dev/null +++ b/core/types/gen_log_rlp.go @@ -0,0 +1,26 @@ +// Code generated by rlpgen. DO NOT EDIT. + +//go:build !norlpgen +// +build !norlpgen + +package types + +import ( + "io" + + "github.com/tomochain/tomochain/rlp" +) + +func (obj *rlpLog) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + w.WriteBytes(obj.Address[:]) + _tmp1 := w.List() + for _, _tmp2 := range obj.Topics { + w.WriteBytes(_tmp2[:]) + } + w.ListEnd(_tmp1) + w.WriteBytes(obj.Data) + w.ListEnd(_tmp0) + return w.Flush() +} From 060f4ce0c6df48e50c42acef8f91448fb019d0c2 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Wed, 21 Jun 2023 11:41:17 +0700 Subject: [PATCH 008/119] Convert status of receipts from uint to uint64 Reference: https://github.com/ethereum/go-ethereum/pull/16784 --- core/types/gen_receipt_json.go | 8 ++++---- core/types/log.go | 3 +++ core/types/receipt.go | 6 +++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/core/types/gen_receipt_json.go b/core/types/gen_receipt_json.go index ffc851f2db..c698b9e36d 100644 --- a/core/types/gen_receipt_json.go +++ b/core/types/gen_receipt_json.go @@ -15,7 +15,7 @@ var _ = (*receiptMarshaling)(nil) func (r Receipt) MarshalJSON() ([]byte, error) { type Receipt struct { PostState hexutil.Bytes `json:"root"` - Status hexutil.Uint `json:"status"` + Status hexutil.Uint64 `json:"status"` CumulativeGasUsed hexutil.Uint64 `json:"cumulativeGasUsed" gencodec:"required"` Bloom Bloom `json:"logsBloom" gencodec:"required"` Logs []*Log `json:"logs" gencodec:"required"` @@ -25,7 +25,7 @@ func (r Receipt) MarshalJSON() ([]byte, error) { } var enc Receipt enc.PostState = r.PostState - enc.Status = hexutil.Uint(r.Status) + enc.Status = hexutil.Uint64(r.Status) enc.CumulativeGasUsed = hexutil.Uint64(r.CumulativeGasUsed) enc.Bloom = r.Bloom enc.Logs = r.Logs @@ -38,7 +38,7 @@ func (r Receipt) MarshalJSON() ([]byte, error) { func (r *Receipt) UnmarshalJSON(input []byte) error { type Receipt struct { PostState *hexutil.Bytes `json:"root"` - Status *hexutil.Uint `json:"status"` + Status *hexutil.Uint64 `json:"status"` CumulativeGasUsed *hexutil.Uint64 `json:"cumulativeGasUsed" gencodec:"required"` Bloom *Bloom 
`json:"logsBloom" gencodec:"required"` Logs []*Log `json:"logs" gencodec:"required"` @@ -54,7 +54,7 @@ func (r *Receipt) UnmarshalJSON(input []byte) error { r.PostState = *dec.PostState } if dec.Status != nil { - r.Status = uint(*dec.Status) + r.Status = uint64(*dec.Status) } if dec.CumulativeGasUsed == nil { return errors.New("missing required field 'cumulativeGasUsed' for Receipt") diff --git a/core/types/log.go b/core/types/log.go index af8e515eac..bee50763a8 100644 --- a/core/types/log.go +++ b/core/types/log.go @@ -63,6 +63,9 @@ type logMarshaling struct { Index hexutil.Uint } +//go:generate go run ../../rlp/rlpgen -type rlpLog -out gen_log_rlp.go + +// rlpLog is used to RLP-encode both the consensus and storage formats. type rlpLog struct { Address common.Address Topics []common.Hash diff --git a/core/types/receipt.go b/core/types/receipt.go index 3c55c12247..879aaf29c9 100644 --- a/core/types/receipt.go +++ b/core/types/receipt.go @@ -36,17 +36,17 @@ var ( const ( // ReceiptStatusFailed is the status code of a transaction if execution failed. - ReceiptStatusFailed = uint(0) + ReceiptStatusFailed = uint64(0) // ReceiptStatusSuccessful is the status code of a transaction if execution succeeded. - ReceiptStatusSuccessful = uint(1) + ReceiptStatusSuccessful = uint64(1) ) // Receipt represents the results of a transaction. type Receipt struct { // Consensus fields PostState []byte `json:"root"` - Status uint `json:"status"` + Status uint64 `json:"status"` CumulativeGasUsed uint64 `json:"cumulativeGasUsed" gencodec:"required"` Bloom Bloom `json:"logsBloom" gencodec:"required"` Logs []*Log `json:"logs" gencodec:"required"` From 788eef59c35f3438ba4a8a3f1488ccdec1a29444 Mon Sep 17 00:00:00 2001 From: Enda Dinh <90235926+endadinh@users.noreply.github.com> Date: Thu, 13 Jul 2023 11:21:06 +0700 Subject: [PATCH 009/119] cmd/tomo: added counters to the geth inspect report --- core/rawdb/database.go | 186 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 186 insertions(+) diff --git a/core/rawdb/database.go b/core/rawdb/database.go index 1183a74f51..f46f7ec8e7 100644 --- a/core/rawdb/database.go +++ b/core/rawdb/database.go @@ -17,10 +17,17 @@ package rawdb import ( + "bytes" "fmt" + "os" + "time" + + "github.com/olekukonko/tablewriter" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/ethdb/leveldb" "github.com/tomochain/tomochain/ethdb/memorydb" + "github.com/tomochain/tomochain/log" ) // freezerdb is a database wrapper that enabled freezer data retrievals. @@ -108,3 +115,182 @@ func NewLevelDBDatabase(file string, cache int, handles int, namespace string) ( } return NewDatabase(db), nil } + +type counter uint64 + +func (c counter) String() string { + return fmt.Sprintf("%d", c) +} + +func (c counter) Percentage(current uint64) string { + return fmt.Sprintf("%d", current*100/uint64(c)) +} + +// stat stores sizes and count for a parameter +type stat struct { + size common.StorageSize + count counter +} + +// Add size to the stat and increase the counter by 1 +func (s *stat) Add(size common.StorageSize) { + s.size += size + s.count++ +} + +func (s *stat) Size() string { + return s.size.String() +} + +func (s *stat) Count() string { + return s.count.String() +} + +// InspectDatabase traverses the entire database and checks the size +// of all different categories of data. 
+func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { + it := db.NewIterator(keyPrefix, keyStart) + defer it.Release() + + var ( + count int64 + start = time.Now() + logged = time.Now() + + // Key-value store statistics + headers stat + bodies stat + receipts stat + tds stat + numHashPairings stat + hashNumPairings stat + tries stat + codes stat + txLookups stat + accountSnaps stat + storageSnaps stat + preimages stat + bloomBits stat + cliqueSnaps stat + + // Ancient store statistics + ancientHeadersSize common.StorageSize + ancientBodiesSize common.StorageSize + ancientReceiptsSize common.StorageSize + ancientTdsSize common.StorageSize + ancientHashesSize common.StorageSize + + // Les statistic + chtTrieNodes stat + bloomTrieNodes stat + + // Meta- and unaccounted data + metadata stat + unaccounted stat + + // Totals + total common.StorageSize + ) + // Inspect key-value database first. + for it.Next() { + var ( + key = it.Key() + size = common.StorageSize(len(key) + len(it.Value())) + ) + total += size + switch { + case bytes.HasPrefix(key, headerPrefix) && len(key) == (len(headerPrefix)+8+common.HashLength): + headers.Add(size) + case bytes.HasPrefix(key, blockBodyPrefix) && len(key) == (len(blockBodyPrefix)+8+common.HashLength): + bodies.Add(size) + case bytes.HasPrefix(key, blockReceiptsPrefix) && len(key) == (len(blockReceiptsPrefix)+8+common.HashLength): + receipts.Add(size) + case bytes.HasPrefix(key, headerPrefix) && bytes.HasSuffix(key, headerTDSuffix): + tds.Add(size) + case bytes.HasPrefix(key, headerPrefix) && bytes.HasSuffix(key, headerHashSuffix): + numHashPairings.Add(size) + case bytes.HasPrefix(key, headerNumberPrefix) && len(key) == (len(headerNumberPrefix)+common.HashLength): + hashNumPairings.Add(size) + case len(key) == common.HashLength: + tries.Add(size) + // case bytes.HasPrefix(key, codePrefix) && len(key) == len(codePrefix)+common.HashLength: + // codes.Add(size) + case bytes.HasPrefix(key, txLookupPrefix) && len(key) == (len(txLookupPrefix)+common.HashLength): + txLookups.Add(size) + case bytes.HasPrefix(key, SnapshotAccountPrefix) && len(key) == (len(SnapshotAccountPrefix)+common.HashLength): + accountSnaps.Add(size) + case bytes.HasPrefix(key, SnapshotStoragePrefix) && len(key) == (len(SnapshotStoragePrefix)+2*common.HashLength): + storageSnaps.Add(size) + case bytes.HasPrefix(key, preimagePrefix) && len(key) == (len(preimagePrefix)+common.HashLength): + preimages.Add(size) + case bytes.HasPrefix(key, bloomBitsPrefix) && len(key) == (len(bloomBitsPrefix)+10+common.HashLength): + bloomBits.Add(size) + case bytes.HasPrefix(key, []byte("clique-")) && len(key) == 7+common.HashLength: + cliqueSnaps.Add(size) + case bytes.HasPrefix(key, []byte("cht-")) && len(key) == 4+common.HashLength: + chtTrieNodes.Add(size) + case bytes.HasPrefix(key, []byte("blt-")) && len(key) == 4+common.HashLength: + bloomTrieNodes.Add(size) + default: + var accounted bool + for _, meta := range [][]byte{databaseVerisionKey, headHeaderKey, headBlockKey, headFastBlockKey, fastTrieProgressKey} { + if bytes.Equal(key, meta) { + metadata.Add(size) + accounted = true + break + } + } + if !accounted { + unaccounted.Add(size) + } + } + count += 1 + if count%1000 == 0 && time.Since(logged) > 8*time.Second { + log.Info("Inspecting database", "count", count, "elapsed", common.PrettyDuration(time.Since(start))) + logged = time.Now() + } + } + // Inspect append-only file store then. 
+ ancientSizes := []*common.StorageSize{&ancientHeadersSize, &ancientBodiesSize, &ancientReceiptsSize, &ancientHashesSize, &ancientTdsSize} + for i, category := range []string{freezerHeaderTable, freezerBodiesTable, freezerReceiptTable, freezerHashTable, freezerDifficultyTable} { + if size, err := db.AncientSize(category); err == nil { + *ancientSizes[i] += common.StorageSize(size) + total += common.StorageSize(size) + } + } + // Display the database statistic. + stats := [][]string{ + {"Key-Value store", "Headers", headers.Size(), headers.Count()}, + {"Key-Value store", "Bodies", bodies.Size(), bodies.Count()}, + {"Key-Value store", "Receipt lists", receipts.Size(), receipts.Count()}, + {"Key-Value store", "Difficulties", tds.Size(), tds.Count()}, + {"Key-Value store", "Block number->hash", numHashPairings.Size(), numHashPairings.Count()}, + {"Key-Value store", "Block hash->number", hashNumPairings.Size(), hashNumPairings.Count()}, + {"Key-Value store", "Transaction index", txLookups.Size(), txLookups.Count()}, + {"Key-Value store", "Bloombit index", bloomBits.Size(), bloomBits.Count()}, + {"Key-Value store", "Contract codes", codes.Size(), codes.Count()}, + {"Key-Value store", "Trie nodes", tries.Size(), tries.Count()}, + {"Key-Value store", "Trie preimages", preimages.Size(), preimages.Count()}, + {"Key-Value store", "Account snapshot", accountSnaps.Size(), accountSnaps.Count()}, + {"Key-Value store", "Storage snapshot", storageSnaps.Size(), storageSnaps.Count()}, + {"Key-Value store", "Clique snapshots", cliqueSnaps.Size(), cliqueSnaps.Count()}, + {"Key-Value store", "Singleton metadata", metadata.Size(), metadata.Count()}, + // {"Ancient store", "Headers", ancientHeadersSize.String(), ancients.String()}, + // {"Ancient store", "Bodies", ancientBodiesSize.String(), ancients.String()}, + // {"Ancient store", "Receipt lists", ancientReceiptsSize.String(), ancients.String()}, + // {"Ancient store", "Difficulties", ancientTdsSize.String(), ancients.String()}, + // {"Ancient store", "Block number->hash", ancientHashesSize.String(), ancients.String()}, + {"Light client", "CHT trie nodes", chtTrieNodes.Size(), chtTrieNodes.Count()}, + {"Light client", "Bloom trie nodes", bloomTrieNodes.Size(), bloomTrieNodes.Count()}, + } + table := tablewriter.NewWriter(os.Stdout) + table.SetHeader([]string{"Database", "Category", "Size", "Items"}) + table.SetFooter([]string{"", "Total", total.String(), " "}) + table.AppendBulk(stats) + table.Render() + + if unaccounted.size > 0 { + log.Error("Database contains unaccounted data", "size", unaccounted) + } + return nil +} From 81be74946a6430846c30a86f989cf912c1748c79 Mon Sep 17 00:00:00 2001 From: Enda Dinh <90235926+endadinh@users.noreply.github.com> Date: Thu, 13 Jul 2023 16:07:58 +0700 Subject: [PATCH 010/119] Change handling of dirty objects in state --- core/state/dump.go | 2 +- core/state/journal.go | 132 ++++++++++++++++++++++---- core/state/state_object.go | 52 +++-------- core/state/statedb.go | 187 +++++++++++++++++++++++++++++-------- core/state/statedb_test.go | 15 +-- 5 files changed, 281 insertions(+), 107 deletions(-) diff --git a/core/state/dump.go b/core/state/dump.go index f08c6e7df3..3fb154bbf0 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -53,7 +53,7 @@ func (self *StateDB) RawDump() Dump { panic(err) } - obj := newObject(nil, common.BytesToAddress(addr), data, nil) + obj := newObject(nil, common.BytesToAddress(addr), data) account := DumpAccount{ Balance: data.Balance.String(), Nonce: data.Nonce, diff --git 
a/core/state/journal.go b/core/state/journal.go index 1ac5cdbf25..cbb443706d 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -22,11 +22,67 @@ import ( "github.com/tomochain/tomochain/common" ) +// journalEntry is a modification entry in the state change journal that can be +// reverted on demand. type journalEntry interface { - undo(*StateDB) + // revert undoes the changes introduced by this journal entry. + revert(*StateDB) + + // dirtied returns the Ethereum address modified by this journal entry. + dirtied() *common.Address +} + +// journal contains the list of state modifications applied since the last state +// commit. These are tracked to be able to be reverted in case of an execution +// exception or revertal request. +type journal struct { + entries []journalEntry // Current changes tracked by the journal + dirties map[common.Address]int // Dirty accounts and the number of changes +} + +// newJournal create a new initialized journal. +func newJournal() *journal { + return &journal{ + dirties: make(map[common.Address]int), + } +} + +// append inserts a new modification entry to the end of the change journal. +func (j *journal) append(entry journalEntry) { + j.entries = append(j.entries, entry) + if addr := entry.dirtied(); addr != nil { + j.dirties[*addr]++ + } +} + +// revert undoes a batch of journalled modifications along with any reverted +// dirty handling too. +func (j *journal) revert(statedb *StateDB, snapshot int) { + for i := len(j.entries) - 1; i >= snapshot; i-- { + // Undo the changes made by the operation + j.entries[i].revert(statedb) + + // Drop any dirty tracking induced by the change + if addr := j.entries[i].dirtied(); addr != nil { + if j.dirties[*addr]--; j.dirties[*addr] == 0 { + delete(j.dirties, *addr) + } + } + } + j.entries = j.entries[:snapshot] +} + +// dirty explicitly sets an address to dirty, even if the change entries would +// otherwise suggest it as clean. This method is an ugly hack to handle the RIPEMD +// precompile consensus exception. +func (j *journal) dirty(addr common.Address) { + j.dirties[addr]++ } -type journal []journalEntry +// length returns the current number of entries in the journal. +func (j *journal) length() int { + return len(j.entries) +} type ( // Changes to the account trie. 
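The dirties map introduced in this hunk is the key piece of the rewrite: append bumps a per-address counter for every change entry, revert decrements it and drops the address once its last change is unwound, so Finalise can iterate only accounts that are still genuinely dirty. A standalone toy sketch of that bookkeeping (simplified types, not the package's own):

package main

import "fmt"

type entry struct{ addr string }

type journal struct {
	entries []entry
	dirties map[string]int
}

// append mirrors journal.append above: record the entry, bump the counter.
func (j *journal) append(e entry) {
	j.entries = append(j.entries, e)
	j.dirties[e.addr]++
}

// revert mirrors journal.revert: walk back to the snapshot, dropping an
// address from the dirty set once its last change is unwound.
func (j *journal) revert(snapshot int) {
	for i := len(j.entries) - 1; i >= snapshot; i-- {
		addr := j.entries[i].addr
		if j.dirties[addr]--; j.dirties[addr] == 0 {
			delete(j.dirties, addr)
		}
	}
	j.entries = j.entries[:snapshot]
}

func main() {
	j := &journal{dirties: make(map[string]int)}
	j.append(entry{"0xaa"})
	snap := len(j.entries) // Snapshot() records the journal length
	j.append(entry{"0xaa"})
	j.append(entry{"0xbb"})
	fmt.Println(j.dirties) // map[0xaa:2 0xbb:1]
	j.revert(snap)         // RevertToSnapshot unwinds entries back to snap
	fmt.Println(j.dirties) // map[0xaa:1]
}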
@@ -77,16 +133,24 @@ type ( } ) -func (ch createObjectChange) undo(s *StateDB) { +func (ch createObjectChange) revert(s *StateDB) { delete(s.stateObjects, *ch.account) delete(s.stateObjectsDirty, *ch.account) } -func (ch resetObjectChange) undo(s *StateDB) { +func (ch createObjectChange) dirtied() *common.Address { + return ch.account +} + +func (ch resetObjectChange) revert(s *StateDB) { s.setStateObject(ch.prev) } -func (ch suicideChange) undo(s *StateDB) { +func (ch resetObjectChange) dirtied() *common.Address { + return nil +} + +func (ch suicideChange) revert(s *StateDB) { obj := s.getStateObject(*ch.account) if obj != nil { obj.suicided = ch.prev @@ -94,38 +158,60 @@ func (ch suicideChange) undo(s *StateDB) { } } +func (ch suicideChange) dirtied() *common.Address { + return ch.account +} + var ripemd = common.HexToAddress("0000000000000000000000000000000000000003") -func (ch touchChange) undo(s *StateDB) { - if !ch.prev && *ch.account != ripemd { - s.getStateObject(*ch.account).touched = ch.prev - if !ch.prevDirty { - delete(s.stateObjectsDirty, *ch.account) - } - } +func (ch touchChange) revert(s *StateDB) { +} + +func (ch touchChange) dirtied() *common.Address { + return ch.account } -func (ch balanceChange) undo(s *StateDB) { +func (ch balanceChange) revert(s *StateDB) { s.getStateObject(*ch.account).setBalance(ch.prev) } -func (ch nonceChange) undo(s *StateDB) { +func (ch balanceChange) dirtied() *common.Address { + return ch.account +} + +func (ch nonceChange) revert(s *StateDB) { s.getStateObject(*ch.account).setNonce(ch.prev) } -func (ch codeChange) undo(s *StateDB) { +func (ch nonceChange) dirtied() *common.Address { + return ch.account +} + +func (ch codeChange) revert(s *StateDB) { s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode) } -func (ch storageChange) undo(s *StateDB) { +func (ch codeChange) dirtied() *common.Address { + return ch.account +} + +func (ch storageChange) revert(s *StateDB) { s.getStateObject(*ch.account).setState(ch.key, ch.prevalue) } -func (ch refundChange) undo(s *StateDB) { +func (ch storageChange) dirtied() *common.Address { + return ch.account +} + +func (ch refundChange) revert(s *StateDB) { s.refund = ch.prev } -func (ch addLogChange) undo(s *StateDB) { +func (ch refundChange) dirtied() *common.Address { + return nil +} + +func (ch addLogChange) revert(s *StateDB) { logs := s.logs[ch.txhash] if len(logs) == 1 { delete(s.logs, ch.txhash) @@ -135,6 +221,14 @@ func (ch addLogChange) undo(s *StateDB) { s.logSize-- } -func (ch addPreimageChange) undo(s *StateDB) { +func (ch addLogChange) dirtied() *common.Address { + return nil +} + +func (ch addPreimageChange) revert(s *StateDB) { delete(s.preimages, ch.hash) } + +func (ch addPreimageChange) dirtied() *common.Address { + return nil +} diff --git a/core/state/state_object.go b/core/state/state_object.go index b03231e23b..aca54472f5 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -85,9 +85,7 @@ type stateObject struct { // during the "update" phase of the state transition. dirtyCode bool // true if the code was updated suicided bool - touched bool deleted bool - onDirty func(addr common.Address) // Callback method to mark a state object newly dirty } // empty returns whether the account is considered empty. @@ -105,7 +103,7 @@ type Account struct { } // newObject creates a state object. 
-func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *stateObject { +func newObject(db *StateDB, address common.Address, data Account) *stateObject { if data.Balance == nil { data.Balance = new(big.Int) } @@ -119,7 +117,6 @@ func newObject(db *StateDB, address common.Address, data Account, onDirty func(a data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), - onDirty: onDirty, } } @@ -137,23 +134,17 @@ func (self *stateObject) setError(err error) { func (self *stateObject) markSuicided() { self.suicided = true - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil - } } func (c *stateObject) touch() { - c.db.journal = append(c.db.journal, touchChange{ - account: &c.address, - prev: c.touched, - prevDirty: c.onDirty == nil, + c.db.journal.append(touchChange{ + account: &c.address, }) - if c.onDirty != nil { - c.onDirty(c.Address()) - c.onDirty = nil + if c.address == ripemd { + // Explicitly put it in the dirty-cache, which is otherwise generated from + // flattened journals. + c.db.journal.dirty(c.address) } - c.touched = true } func (c *stateObject) getTrie(db Database) Trie { @@ -212,7 +203,7 @@ func (self *stateObject) GetState(db Database, key common.Hash) common.Hash { // SetState updates a value in account storage. func (self *stateObject) SetState(db Database, key, value common.Hash) { - self.db.journal = append(self.db.journal, storageChange{ + self.db.journal.append(storageChange{ account: &self.address, key: key, prevalue: self.GetState(db, key), @@ -223,11 +214,6 @@ func (self *stateObject) SetState(db Database, key, value common.Hash) { func (self *stateObject) setState(key, value common.Hash) { self.cachedStorage[key] = value self.dirtyStorage[key] = value - - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil - } } // updateTrie writes cached storage modifications into the object's storage trie. @@ -291,7 +277,7 @@ func (c *stateObject) SubBalance(amount *big.Int) { } func (self *stateObject) SetBalance(amount *big.Int) { - self.db.journal = append(self.db.journal, balanceChange{ + self.db.journal.append(balanceChange{ account: &self.address, prev: new(big.Int).Set(self.data.Balance), }) @@ -300,17 +286,13 @@ func (self *stateObject) SetBalance(amount *big.Int) { func (self *stateObject) setBalance(amount *big.Int) { self.data.Balance = amount - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil - } } // Return the gas back to the origin. 
Used by the Virtual machine or Closures func (c *stateObject) ReturnGas(gas *big.Int) {} -func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject { - stateObject := newObject(db, self.address, self.data, onDirty) +func (self *stateObject) deepCopy(db *StateDB) *stateObject { + stateObject := newObject(db, self.address, self.data) if self.trie != nil { stateObject.trie = db.db.CopyTrie(self.trie) } @@ -350,7 +332,7 @@ func (self *stateObject) Code(db Database) []byte { func (self *stateObject) SetCode(codeHash common.Hash, code []byte) { prevcode := self.Code(self.db.db) - self.db.journal = append(self.db.journal, codeChange{ + self.db.journal.append(codeChange{ account: &self.address, prevhash: self.CodeHash(), prevcode: prevcode, @@ -362,14 +344,10 @@ func (self *stateObject) setCode(codeHash common.Hash, code []byte) { self.code = code self.data.CodeHash = codeHash[:] self.dirtyCode = true - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil - } } func (self *stateObject) SetNonce(nonce uint64) { - self.db.journal = append(self.db.journal, nonceChange{ + self.db.journal.append(nonceChange{ account: &self.address, prev: self.data.Nonce, }) @@ -378,10 +356,6 @@ func (self *stateObject) SetNonce(nonce uint64) { func (self *stateObject) setNonce(nonce uint64) { self.data.Nonce = nonce - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil - } } func (self *stateObject) CodeHash() []byte { diff --git a/core/state/statedb.go b/core/state/statedb.go index 7a3357b3e8..822a417e37 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -22,11 +22,14 @@ import ( "math/big" "sort" "sync" + "time" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/state/snapshot" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" ) @@ -37,6 +40,9 @@ type revision struct { } var ( + // emptyRoot is the known root hash of an empty trie. + emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") + // emptyState is the known hash of an empty state trie entry. emptyState = crypto.Keccak256Hash(nil) @@ -53,6 +59,12 @@ type StateDB struct { db Database trie Trie + snaps *snapshot.Tree + snap snapshot.Snapshot + snapDestructs map[common.Hash]struct{} + snapAccounts map[common.Hash][]byte + snapStorage map[common.Hash]map[common.Hash][]byte + // This map holds 'live' objects, which will get modified while processing a state transition. stateObjects map[common.Address]*stateObject stateObjectsDirty map[common.Address]struct{} @@ -76,15 +88,28 @@ type StateDB struct { // Journal of state modifications. This is the backbone of // Snapshot and RevertToSnapshot. 
- journal journal + journal *journal validRevisions []revision nextRevisionId int + // Measurements gathered during execution for debugging purposes + AccountReads time.Duration + AccountHashes time.Duration + AccountUpdates time.Duration + AccountCommits time.Duration + StorageReads time.Duration + StorageHashes time.Duration + StorageUpdates time.Duration + StorageCommits time.Duration + SnapshotAccountReads time.Duration + SnapshotStorageReads time.Duration + SnapshotCommits time.Duration + lock sync.Mutex } func (self *StateDB) SubRefund(gas uint64) { - self.journal = append(self.journal, refundChange{ + self.journal.append(refundChange{ prev: self.refund}) if gas > self.refund { panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, self.refund)) @@ -101,19 +126,29 @@ func (self *StateDB) GetCommittedState(addr common.Address, hash common.Hash) co } // Create a new state from a given trie. -func New(root common.Hash, db Database) (*StateDB, error) { +func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { tr, err := db.OpenTrie(root) if err != nil { return nil, err } - return &StateDB{ + sdb := &StateDB{ db: db, trie: tr, + snaps: snaps, stateObjects: make(map[common.Address]*stateObject), stateObjectsDirty: make(map[common.Address]struct{}), logs: make(map[common.Hash][]*types.Log), preimages: make(map[common.Hash][]byte), - }, nil + journal: newJournal(), + } + if sdb.snaps != nil { + if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil { + sdb.snapDestructs = make(map[common.Hash]struct{}) + sdb.snapAccounts = make(map[common.Hash][]byte) + sdb.snapStorage = make(map[common.Hash]map[common.Hash][]byte) + } + } + return sdb, nil } // setError remembers the first non-nil error it is called with. @@ -144,11 +179,20 @@ func (self *StateDB) Reset(root common.Hash) error { self.logSize = 0 self.preimages = make(map[common.Hash][]byte) self.clearJournalAndRefund() + + if self.snaps != nil { + self.snapAccounts, self.snapDestructs, self.snapStorage = nil, nil, nil + if self.snap = self.snaps.Snapshot(root); self.snap != nil { + self.snapDestructs = make(map[common.Hash]struct{}) + self.snapAccounts = make(map[common.Hash][]byte) + self.snapStorage = make(map[common.Hash]map[common.Hash][]byte) + } + } return nil } func (self *StateDB) AddLog(log *types.Log) { - self.journal = append(self.journal, addLogChange{txhash: self.thash}) + self.journal.append(addLogChange{txhash: self.thash}) log.TxHash = self.thash log.BlockHash = self.bhash @@ -173,7 +217,7 @@ func (self *StateDB) Logs() []*types.Log { // AddPreimage records a SHA3 preimage seen by the VM. 
func (self *StateDB) AddPreimage(hash common.Hash, preimage []byte) {
 	if _, ok := self.preimages[hash]; !ok {
-		self.journal = append(self.journal, addPreimageChange{hash: hash})
+		self.journal.append(addPreimageChange{hash: hash})
 		pi := make([]byte, len(preimage))
 		copy(pi, preimage)
 		self.preimages[hash] = pi
@@ -186,7 +230,7 @@ func (self *StateDB) Preimages() map[common.Hash][]byte {
 }
 
 func (self *StateDB) AddRefund(gas uint64) {
-	self.journal = append(self.journal, refundChange{prev: self.refund})
+	self.journal.append(refundChange{prev: self.refund})
 	self.refund += gas
 }
 
@@ -272,7 +316,7 @@ func (self *StateDB) StorageTrie(addr common.Address) Trie {
 	if stateObject == nil {
 		return nil
 	}
-	cpy := stateObject.deepCopy(self, nil)
+	cpy := stateObject.deepCopy(self)
 	return cpy.updateTrie(self.db)
 }
 
@@ -342,7 +386,7 @@ func (self *StateDB) Suicide(addr common.Address) bool {
 	if stateObject == nil {
 		return false
 	}
-	self.journal = append(self.journal, suicideChange{
+	self.journal.append(suicideChange{
 		account:     &addr,
 		prev:        stateObject.suicided,
 		prevbalance: new(big.Int).Set(stateObject.Balance()),
@@ -358,13 +402,22 @@ func (self *StateDB) Suicide(addr common.Address) bool {
 //
 
 // updateStateObject writes the given object to the trie.
-func (self *StateDB) updateStateObject(stateObject *stateObject) {
-	addr := stateObject.Address()
-	data, err := rlp.EncodeToBytes(stateObject)
+func (s *StateDB) updateStateObject(obj *stateObject) {
+	addr := obj.Address()
+	data, err := rlp.EncodeToBytes(obj)
 	if err != nil {
 		panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err))
 	}
-	self.setError(self.trie.TryUpdate(addr[:], data))
+	s.setError(s.trie.TryUpdate(addr[:], data))
+
+	// If state snapshotting is active, cache the data til commit. Note, this
+	// update mechanism is not symmetric to the deletion, because whereas it is
+	// enough to track account updates at commit time, deletions need tracking
+	// at transaction boundary level to ensure we capture state clearing.
+	if s.snap != nil {
+		s.snapAccounts[obj.addrHash] = snapshot.AccountRLP(obj.data.Nonce, obj.data.Balance, obj.data.Root, obj.data.CodeHash)
+	}
+
 }
 
 // deleteStateObject removes the given object from the state trie.
@@ -404,11 +457,65 @@ func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObje
 		return nil
 	}
 	// Insert into the live set.
-	obj := newObject(self, addr, data, self.MarkStateObjectDirty)
+	obj := newObject(self, addr, data)
 	self.setStateObject(obj)
 	return obj
 }
 
+// getDeletedStateObject is similar to getStateObject, but instead of returning
+// nil for a deleted state object, it returns the actual object with the deleted
+// flag set. This is needed by the state journal to revert to the correct self-
+// destructed object instead of wiping all knowledge about the state object.
+func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject { + // Prefer live objects if any is available + if obj := s.stateObjects[addr]; obj != nil { + return obj + } + // If no live objects are available, attempt to use snapshots + var ( + data Account + err error + ) + if s.snap != nil { + if metrics.EnabledExpensive { + defer func(start time.Time) { s.SnapshotAccountReads += time.Since(start) }(time.Now()) + } + var acc *snapshot.Account + if acc, err = s.snap.Account(crypto.Keccak256Hash(addr[:])); err == nil { + if acc == nil { + return nil + } + data.Nonce, data.Balance, data.CodeHash = acc.Nonce, acc.Balance, acc.CodeHash + if len(data.CodeHash) == 0 { + data.CodeHash = emptyCodeHash + } + data.Root = common.BytesToHash(acc.Root) + if data.Root == (common.Hash{}) { + data.Root = emptyRoot + } + } + } + // If snapshot unavailable or reading from it failed, load from the database + if s.snap == nil || err != nil { + if metrics.EnabledExpensive { + defer func(start time.Time) { s.AccountReads += time.Since(start) }(time.Now()) + } + enc, err := s.trie.TryGet(addr[:]) + if len(enc) == 0 { + s.setError(err) + return nil + } + if err := rlp.DecodeBytes(enc, &data); err != nil { + log.Error("Failed to decode state object", "addr", addr, "err", err) + return nil + } + } + // Insert into the live set + obj := newObject(s, addr, data) + s.setStateObject(obj) + return obj +} + func (self *StateDB) setStateObject(object *stateObject) { self.stateObjects[object.Address()] = object } @@ -422,22 +529,16 @@ func (self *StateDB) GetOrNewStateObject(addr common.Address) *stateObject { return stateObject } -// MarkStateObjectDirty adds the specified object to the dirty map to avoid costly -// state object cache iteration to find a handful of modified ones. -func (self *StateDB) MarkStateObjectDirty(addr common.Address) { - self.stateObjectsDirty[addr] = struct{}{} -} - // createObject creates a new state object. If there is an existing account with // the given address, it is overwritten and returned as the second return value. func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) { prev = self.getStateObject(addr) - newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty) + newobj = newObject(self, addr, Account{}) newobj.setNonce(0) // sets the object to dirty if prev == nil { - self.journal = append(self.journal, createObjectChange{account: &addr}) + self.journal.append(createObjectChange{account: &addr}) } else { - self.journal = append(self.journal, resetObjectChange{prev: prev}) + self.journal.append(resetObjectChange{prev: prev}) } self.setStateObject(newobj) return newobj, prev @@ -449,8 +550,8 @@ func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObjec // CreateAccount is called during the EVM CREATE operation. The situation might arise that // a contract does the following: // -// 1. sends funds to sha(account ++ (nonce + 1)) -// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) +// 1. sends funds to sha(account ++ (nonce + 1)) +// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) // // Carrying over the balance ensures that Ether doesn't disappear. 
func (self *StateDB) CreateAccount(addr common.Address) { @@ -492,16 +593,16 @@ func (self *StateDB) Copy() *StateDB { state := &StateDB{ db: self.db, trie: self.db.CopyTrie(self.trie), - stateObjects: make(map[common.Address]*stateObject, len(self.stateObjectsDirty)), - stateObjectsDirty: make(map[common.Address]struct{}, len(self.stateObjectsDirty)), + stateObjects: make(map[common.Address]*stateObject, len(self.journal.dirties)), + stateObjectsDirty: make(map[common.Address]struct{}, len(self.journal.dirties)), refund: self.refund, logs: make(map[common.Hash][]*types.Log, len(self.logs)), logSize: self.logSize, preimages: make(map[common.Hash][]byte), } // Copy the dirty states, logs, and preimages - for addr := range self.stateObjectsDirty { - state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty) + for addr := range self.journal.dirties { + state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state) state.stateObjectsDirty[addr] = struct{}{} } for hash, logs := range self.logs { @@ -518,7 +619,7 @@ func (self *StateDB) Copy() *StateDB { func (self *StateDB) Snapshot() int { id := self.nextRevisionId self.nextRevisionId++ - self.validRevisions = append(self.validRevisions, revision{id, len(self.journal)}) + self.validRevisions = append(self.validRevisions, revision{id, self.journal.length()}) return id } @@ -533,13 +634,8 @@ func (self *StateDB) RevertToSnapshot(revid int) { } snapshot := self.validRevisions[idx].journalIndex - // Replay the journal to undo changes. - for i := len(self.journal) - 1; i >= snapshot; i-- { - self.journal[i].undo(self) - } - self.journal = self.journal[:snapshot] - - // Remove invalidated snapshots from the stack. + // Replay the journal to undo changes and remove invalidated snapshots + self.journal.revert(self, snapshot) self.validRevisions = self.validRevisions[:idx] } @@ -551,14 +647,19 @@ func (self *StateDB) GetRefund() uint64 { // Finalise finalises the state by removing the self destructed objects // and clears the journal as well as the refunds. func (s *StateDB) Finalise(deleteEmptyObjects bool) { - for addr := range s.stateObjectsDirty { - stateObject := s.stateObjects[addr] + for addr := range s.journal.dirties { + stateObject, exist := s.stateObjects[addr] + if !exist { + continue + } + if stateObject.suicided || (deleteEmptyObjects && stateObject.empty()) { s.deleteStateObject(stateObject) } else { stateObject.updateRoot(s.db) s.updateStateObject(stateObject) } + s.stateObjectsDirty[addr] = struct{}{} } // Invalidate journal because reverting across transactions is not allowed. s.clearJournalAndRefund() @@ -602,7 +703,7 @@ func (s *StateDB) DeleteSuicides() { } func (s *StateDB) clearJournalAndRefund() { - s.journal = nil + s.journal = newJournal() s.validRevisions = s.validRevisions[:0] s.refund = 0 } @@ -611,6 +712,10 @@ func (s *StateDB) clearJournalAndRefund() { func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) { defer s.clearJournalAndRefund() + for addr := range s.journal.dirties { + s.stateObjectsDirty[addr] = struct{}{} + } + // Commit objects to the trie. 
for addr, stateObject := range s.stateObjects { _, isDirty := s.stateObjectsDirty[addr] diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 865ee073b8..ce44d00f80 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -20,7 +20,6 @@ import ( "bytes" "encoding/binary" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math" "math/big" "math/rand" @@ -29,6 +28,8 @@ import ( "testing" "testing/quick" + "github.com/tomochain/tomochain/core/rawdb" + check "gopkg.in/check.v1" "github.com/tomochain/tomochain/common" @@ -40,7 +41,7 @@ import ( func TestUpdateLeaks(t *testing.T) { // Create an empty state database db := rawdb.NewMemoryDatabase() - state, _ := New(common.Hash{}, NewDatabase(db)) + state, _ := New(common.Hash{}, NewDatabase(db), nil) // Update it with some accounts for i := byte(0); i < 255; i++ { @@ -70,8 +71,8 @@ func TestIntermediateLeaks(t *testing.T) { // Create two state databases, one transitioning to the final state, the other final from the beginning transDb := rawdb.NewMemoryDatabase() finalDb := rawdb.NewMemoryDatabase() - transState, _ := New(common.Hash{}, NewDatabase(transDb)) - finalState, _ := New(common.Hash{}, NewDatabase(finalDb)) + transState, _ := New(common.Hash{}, NewDatabase(transDb), nil) + finalState, _ := New(common.Hash{}, NewDatabase(finalDb), nil) modify := func(state *StateDB, addr common.Address, i, tweak byte) { state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) @@ -129,7 +130,7 @@ func TestIntermediateLeaks(t *testing.T) { func TestCopy(t *testing.T) { // Create a random state test to copy and modify "independently" db := rawdb.NewMemoryDatabase() - orig, _ := New(common.Hash{}, NewDatabase(db)) + orig, _ := New(common.Hash{}, NewDatabase(db), nil) for i := byte(0); i < 255; i++ { obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) @@ -341,7 +342,7 @@ func (test *snapshotTest) run() bool { // Run all actions and create snapshots. var ( db = rawdb.NewMemoryDatabase() - state, _ = New(common.Hash{}, NewDatabase(db)) + state, _ = New(common.Hash{}, NewDatabase(db), nil) snapshotRevs = make([]int, len(test.snapshots)) sindex = 0 ) @@ -355,7 +356,7 @@ func (test *snapshotTest) run() bool { // Revert all snapshots in reverse order. Each revert must yield a state // that is equivalent to fresh state with all actions up the snapshot applied. for sindex--; sindex >= 0; sindex-- { - checkstate, _ := New(common.Hash{}, state.Database()) + checkstate, _ := New(common.Hash{}, state.Database(), nil) for _, action := range test.actions[:test.snapshots[sindex]] { action.fn(action, checkstate) } From 2eb2792431e27fc3da71322ca7e3665baf1594f4 Mon Sep 17 00:00:00 2001 From: terryyyz-coin98 Date: Fri, 14 Jul 2023 11:29:16 +0700 Subject: [PATCH 011/119] feat: define ExecutionResult --- core/state_transition.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/core/state_transition.go b/core/state_transition.go index 9a2b079249..e4bcc30055 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -76,6 +76,40 @@ type Message interface { BalanceTokenFee() *big.Int } +// message no matter the execution itself is successful or not. 
+ type ExecutionResult struct { + UsedGas uint64 // Total used gas but include the refunded gas + Err error // Any error encountered during the execution(listed in core/vm/errors.go) + ReturnData []byte // Returned data from evm(function result or data supplied with revert opcode) + } + + // Unwrap returns the internal evm error which allows us for further + // analysis outside. + func (result *ExecutionResult) Unwrap() error { + return result.Err + } + + // Failed returns the indicator whether the execution is successful or not + func (result *ExecutionResult) Failed() bool { return result.Err != nil } + + // Return is a helper function to help caller distinguish between revert reason + // and function return. Return returns the data after execution if no error occurs. + func (result *ExecutionResult) Return() []byte { + if result.Err != nil { + return nil + } + return common.CopyBytes(result.ReturnData) + } + + // Revert returns the concrete revert reason if the execution is aborted by `REVERT` + // opcode. Note the reason can be nil if no data supplied with revert opcode. + func (result *ExecutionResult) Revert() []byte { + if result.Err != vm.ErrExecutionReverted { + return nil + } + return common.CopyBytes(result.ReturnData) + } + // IntrinsicGas computes the 'intrinsic gas' for a message with the given data. func IntrinsicGas(data []byte, contractCreation, homestead bool) (uint64, error) { // Set the starting gas for the raw transaction From a88ee1c3b071db59cf9d115b287a44a3e68fafa5 Mon Sep 17 00:00:00 2001 From: terryyyz-coin98 Date: Fri, 14 Jul 2023 15:02:44 +0700 Subject: [PATCH 012/119] apply new return type --- accounts/abi/abi.go | 26 +++++++ accounts/abi/abi_test.go | 32 +++++++++ accounts/abi/bind/backends/simulated.go | 71 ++++++++++++++----- core/error.go | 7 ++ core/state_processor.go | 12 ++-- core/state_transition.go | 92 +++++++++++++------------ core/token_validator.go | 4 +- eth/api_tracer.go | 10 +-- internal/ethapi/api.go | 90 ++++++++++++++++++------ tests/state_test_util.go | 2 +- 10 files changed, 251 insertions(+), 95 deletions(-) diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 254b1f7fb4..6acf0e2b66 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -21,6 +21,9 @@ import ( "encoding/json" "fmt" "io" + "errors" + + "github.com/tomochain/tomochain/crypto" ) // The ABI holds information about a contract's context and available @@ -144,3 +147,26 @@ func (abi *ABI) MethodById(sigdata []byte) (*Method, error) { } return nil, fmt.Errorf("no method with id: %#x", sigdata[:4]) } + +// revertSelector is a special function selector for revert reason unpacking. + var revertSelector = crypto.Keccak256([]byte("Error(string)"))[:4] + + // UnpackRevert resolves the abi-encoded revert reason. According to the solidity + // spec https://solidity.readthedocs.io/en/latest/control-structures.html#revert, + // the provided revert reason is abi-encoded as if it were a call to a function + // `Error(string)`. So it's a special tool for it. 
+ func UnpackRevert(data []byte) (string, error) { + if len(data) < 4 { + return "", errors.New("invalid data for unpacking") + } + if !bytes.Equal(data[:4], revertSelector) { + return "", errors.New("invalid data for unpacking") + } + var reason string + // typ, _ := NewType("string", "", nil) + typ, _ := NewType("string") + if err := (Arguments{{Type: typ}}).Unpack(&reason, data[4:]); err != nil { + return "", err + } + return reason, nil + } diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go index 5a128bfe54..b7aad7eb4c 100644 --- a/accounts/abi/abi_test.go +++ b/accounts/abi/abi_test.go @@ -21,6 +21,7 @@ import ( "encoding/hex" "fmt" "log" + "errors" "math/big" "strings" "testing" @@ -713,3 +714,34 @@ func TestABI_MethodById(t *testing.T) { } } + +func TestUnpackRevert(t *testing.T) { + t.Parallel() + + var cases = []struct { + input string + expect string + expectErr error + }{ + {"", "", errors.New("invalid data for unpacking")}, + {"08c379a1", "", errors.New("invalid data for unpacking")}, + {"08c379a00000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000d72657665727420726561736f6e00000000000000000000000000000000000000", "revert reason", nil}, + } + for index, c := range cases { + t.Run(fmt.Sprintf("case %d", index), func(t *testing.T) { + got, err := UnpackRevert(common.Hex2Bytes(c.input)) + if c.expectErr != nil { + if err == nil { + t.Fatalf("Expected non-nil error") + } + if err.Error() != c.expectErr.Error() { + t.Fatalf("Expected error mismatch, want %v, got %v", c.expectErr, err) + } + return + } + if c.expect != got { + t.Fatalf("Output mismatch, want %v, got %v", c.expect, got) + } + }) + } + } diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 7411f492a8..7cfb5e5546 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -27,6 +27,7 @@ import ( "time" "github.com/tomochain/tomochain" + "github.com/tomochain/tomochain/accounts/abi" "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" @@ -198,8 +199,11 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call tomochain.Call if err != nil { return nil, err } - rval, _, _, err := b.callContract(ctx, call, b.blockchain.CurrentBlock(), state) - return rval, err + res, err := b.callContract(ctx, call, b.blockchain.CurrentBlock(), state) + if err != nil { + return nil, err + } + return res.Return(), nil } //FIXME: please use copyState for this function @@ -228,11 +232,11 @@ func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain c vmenv := vm.NewEVM(evmContext, statedb, nil, chain.Config(), vm.Config{}) gaspool := new(core.GasPool).AddGas(1000000) owner := common.Address{} - rval, _, _, err := core.NewStateTransition(vmenv, msg, gaspool).TransitionDb(owner) + result, err := core.NewStateTransition(vmenv, msg, gaspool).TransitionDb(owner) if err != nil { return nil, err } - return rval, err + return result.Return(), nil } // PendingCallContract executes a contract call on the pending state. 
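Taken together, ExecutionResult.Revert() from the previous patch and the new abi.UnpackRevert turn an ABI-encoded Error(string) payload back into a readable message. A minimal standalone sketch (not part of the patch), reusing the payload from TestUnpackRevert above:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/accounts/abi"
	"github.com/tomochain/tomochain/common"
)

func main() {
	// ABI encoding of Error("revert reason"), the payload a REVERT opcode
	// leaves in ExecutionResult.ReturnData for revert("revert reason").
	payload := common.Hex2Bytes("08c379a0" +
		"0000000000000000000000000000000000000000000000000000000000000020" +
		"000000000000000000000000000000000000000000000000000000000000000d" +
		"72657665727420726561736f6e00000000000000000000000000000000000000")

	reason, err := abi.UnpackRevert(payload)
	if err != nil {
		// Not an Error(string) payload; callers such as EstimateGas fall
		// back to reporting the raw bytes.
		fmt.Printf("revert (raw): %#x\n", payload)
		return
	}
	fmt.Printf("revert reason: %q\n", reason) // "revert reason"
}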
@@ -241,8 +245,11 @@ func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call tomocha defer b.mu.Unlock() defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot()) - rval, _, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) - return rval, err + res, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) + if err != nil { + return nil, err + } + return res.Return(), nil } // PendingNonceAt implements PendingStateReader.PendingNonceAt, retrieving @@ -280,40 +287,68 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM cap = hi // Create a helper to check if a gas allowance results in an executable transaction - executable := func(gas uint64) bool { + executable := func(gas uint64) (bool, *core.ExecutionResult, error) { call.Gas = gas snapshot := b.pendingState.Snapshot() - _, _, failed, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) - fmt.Println("EstimateGas",err,failed) + res, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) b.pendingState.RevertToSnapshot(snapshot) - if err != nil || failed { - return false + if err != nil { + if err == core.ErrIntrinsicGas { + return true, nil, nil // Special case, raise gas limit + } + return true, nil, err } - return true + return res.Failed(), res, nil } // Execute the binary search and hone in on an executable gas limit for lo+1 < hi { mid := (hi + lo) / 2 - if !executable(mid) { - lo = mid + failed, _, err := executable(mid) + // If the error is not nil(consensus error), it means the provided message + // call or transaction will never be accepted no matter how much gas it is + // assigned. Return the error directly, don't struggle any more + if err != nil { + return 0, err + } + if failed { + lo = mid } else { hi = mid } } // Reject the transaction as invalid if it still fails at the highest allowance if hi == cap { - if !executable(hi) { - return 0, errGasEstimationFailed - } + failed, result, err := executable(hi) + if err != nil { + return 0, err + } + if failed { + if result != nil && result.Err != vm.ErrOutOfGas { + errMsg := fmt.Sprintf("always failing transaction (%v)", result.Err) + + if len(result.Revert()) > 0 { + ret, err := abi.UnpackRevert(result.Revert()) + if err != nil { + errMsg += fmt.Sprintf(" (%#x)", result.Revert()) + } else { + errMsg += fmt.Sprintf(" (%s)", ret) + } + } + return 0, errors.New(errMsg) + } + + // Otherwise, the specified gas cap is too low + return 0, fmt.Errorf("gas required exceeds allowance (%d)", cap) + } } return hi, nil } // callContract implements common code between normal and pending contract calls. // state is modified during execution, make sure to copy it if necessary. -func (b *SimulatedBackend) callContract(ctx context.Context, call tomochain.CallMsg, block *types.Block, statedb *state.StateDB) ([]byte, uint64, bool, error) { +func (b *SimulatedBackend) callContract(ctx context.Context, call tomochain.CallMsg, block *types.Block, statedb *state.StateDB) (*core.ExecutionResult, error) { // Ensure message is initialized properly. if call.GasPrice == nil { call.GasPrice = big.NewInt(1) diff --git a/core/error.go b/core/error.go index 63be6ab83d..cf4599a439 100644 --- a/core/error.go +++ b/core/error.go @@ -38,4 +38,11 @@ var ( ErrNotFoundM1 = errors.New("list M1 not found ") ErrStopPreparingBlock = errors.New("stop calculating a block not verified by M2") + + // ErrGasUintOverflow is returned when calculating gas usage. 
+ ErrGasUintOverflow = errors.New("gas uint64 overflow") + + // ErrInsufficientFundsForTransfer is returned if the transaction sender doesn't + // have enough funds for transfer(topmost call only). + ErrInsufficientFundsForTransfer = errors.New("insufficient funds for transfer") ) diff --git a/core/state_processor.go b/core/state_processor.go index 035c15f2b3..a75403b25f 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -408,7 +408,7 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* // End Bypass blacklist address // Apply the transaction to the current state (included in the env) - _, gas, failed, err := ApplyMessage(vmenv, msg, gp, coinbaseOwner) + result, err := ApplyMessage(vmenv, msg, gp, coinbaseOwner) if err != nil { return nil, 0, err, false @@ -420,13 +420,13 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* } else { root = statedb.IntermediateRoot(config.IsEIP158(header.Number)).Bytes() } - *usedGas += gas + *usedGas += result.UsedGas // Create a new receipt for the transaction, storing the intermediate root and gas used by the tx // based on the eip phase, we're passing wether the root touch-delete accounts. - receipt := types.NewReceipt(root, failed, *usedGas) + receipt := types.NewReceipt(root, result.Failed(), *usedGas) receipt.TxHash = tx.Hash() - receipt.GasUsed = gas + receipt.GasUsed = result.UsedGas // if the transaction created a contract, store the creation address in the receipt. if msg.To() == nil { receipt.ContractAddress = crypto.CreateAddress(vmenv.Context.Origin, tx.Nonce()) @@ -434,10 +434,10 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* // Set the receipt logs and create a bloom for filtering receipt.Logs = statedb.GetLogs(tx.Hash()) receipt.Bloom = types.CreateBloom(types.Receipts{receipt}) - if balanceFee != nil && failed { + if balanceFee != nil && result.Failed() { state.PayFeeWithTRC21TxFail(statedb, msg.From(), *tx.To()) } - return receipt, gas, err, balanceFee != nil + return receipt, result.UsedGas, err, balanceFee != nil } func ApplySignTransaction(config *params.ChainConfig, statedb *state.StateDB, header *types.Header, tx *types.Transaction, usedGas *uint64) (*types.Receipt, uint64, error, bool) { diff --git a/core/state_transition.go b/core/state_transition.go index e4bcc30055..80e6c9516a 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -23,7 +23,6 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/vm" - "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" ) @@ -130,13 +129,13 @@ func IntrinsicGas(data []byte, contractCreation, homestead bool) (uint64, error) } // Make sure we don't exceed uint64 for all data combinations if (math.MaxUint64-gas)/params.TxDataNonZeroGas < nz { - return 0, vm.ErrOutOfGas + return 0, ErrGasUintOverflow } gas += nz * params.TxDataNonZeroGas z := uint64(len(data)) - nz if (math.MaxUint64-gas)/params.TxDataZeroGas < z { - return 0, vm.ErrOutOfGas + return 0, ErrGasUintOverflow } gas += z * params.TxDataZeroGas } @@ -163,7 +162,7 @@ func NewStateTransition(evm *vm.EVM, msg Message, gp *GasPool) *StateTransition // the gas used (which includes gas refunds) and an error if it failed. An error always // indicates a core error meaning that the message would always fail for that particular // state and would never be accepted within a block. 
-func ApplyMessage(evm *vm.EVM, msg Message, gp *GasPool, owner common.Address) ([]byte, uint64, bool, error) {
+func ApplyMessage(evm *vm.EVM, msg Message, gp *GasPool, owner common.Address) (*ExecutionResult, error) {
 	return NewStateTransition(evm, msg, gp).TransitionDb(owner)
 }
 
@@ -195,15 +194,6 @@ func (st *StateTransition) to() vm.AccountRef {
 	return reference
 }
 
-func (st *StateTransition) useGas(amount uint64) error {
-	if st.gas < amount {
-		return vm.ErrOutOfGas
-	}
-	st.gas -= amount
-
-	return nil
-}
-
 func (st *StateTransition) buyGas() error {
 	var (
 		state = st.state
@@ -247,11 +237,32 @@ func (st *StateTransition) preCheck() error {
 }
 
 // TransitionDb will transition the state by applying the current message and
-// returning the result including the the used gas. It returns an error if it
-// failed. An error indicates a consensus issue.
-func (st *StateTransition) TransitionDb(owner common.Address) (ret []byte, usedGas uint64, failed bool, err error) {
-	if err = st.preCheck(); err != nil {
-		return
+// returning the evm execution result with the following fields.
+//
+// - used gas:
+//   total gas used (including gas being refunded)
+// - returndata:
+//   the returned data from evm
+// - concrete execution error:
+//   various **EVM** error which aborts the execution,
+//   e.g. ErrOutOfGas, ErrExecutionReverted
+//
+// However, if any consensus issue is encountered, the error is returned
+// directly with a nil evm execution result.
+func (st *StateTransition) TransitionDb(owner common.Address) (*ExecutionResult, error) {
+	// First check this message satisfies all consensus rules before
+	// applying the message. The rules include these clauses
+	//
+	// 1. the nonce of the message caller is correct
+	// 2. caller has enough balance to cover transaction fee(gaslimit * gasprice)
+	// 3. the amount of gas required is available in the block
+	// 4. the purchased gas is enough to cover intrinsic usage
+	// 5. there is no overflow when calculating intrinsic gas
+	// 6. caller has enough balance to cover asset transfer for **topmost** call
+
+	// Check clauses 1-3, buy gas if everything is correct
+	if err := st.preCheck(); err != nil {
+		return nil, err
 	}
 	msg := st.msg
 	sender := st.from() // err checked in preCheck
 
 	homestead := st.evm.ChainConfig().IsHomestead(st.evm.BlockNumber)
 	contractCreation := msg.To() == nil
 
-	// Pay intrinsic gas
+	// Check clauses 4-5, subtract intrinsic gas if everything is correct
 	gas, err := IntrinsicGas(st.data, contractCreation, homestead)
 	if err != nil {
-		return nil, 0, false, err
-	}
-	if err = st.useGas(gas); err != nil {
-		return nil, 0, false, err
+		return nil, err
 	}
+	if st.gas < gas {
+		return nil, ErrIntrinsicGas
+	}
+	st.gas -= gas
+
+	// check clause 6
+	if msg.Value().Sign() > 0 && !st.evm.CanTransfer(st.state, msg.From(), msg.Value()) {
+		return nil, ErrInsufficientFundsForTransfer
+	}
 	var (
-		evm = st.evm
-		// vm errors do not effect consensus and are therefor
-		// not assigned to err, except for insufficient balance
-		// error.
- vmerr error + ret []byte + vmerr error ) // for debugging purpose // TODO: clean it after fixing the issue https://github.com/tomochain/tomochain/issues/401 - var contractAction string nonce := uint64(1) if contractCreation { - ret, _, st.gas, vmerr = evm.Create(sender, st.data, st.gas, st.value) - contractAction = "contract creation" + ret, _, st.gas, vmerr = st.evm.Create(sender, st.data, st.gas, st.value) } else { // Increment the nonce for the next transaction nonce = st.state.GetNonce(sender.Address()) + 1 st.state.SetNonce(sender.Address(), nonce) - ret, st.gas, vmerr = evm.Call(sender, st.to().Address(), st.data, st.gas, st.value) - contractAction = "contract call" - } - if vmerr != nil { - log.Debug("VM returned with error", "action", contractAction, "contract address", st.to().Address(), "gas", st.gas, "gasPrice", st.gasPrice, "nonce", nonce, "err", vmerr) - // The only possible consensus-error would be if there wasn't - // sufficient balance to make the transfer happen. The first - // balance transfer may never fail. - if vmerr == vm.ErrInsufficientBalance { - return nil, 0, false, vmerr - } + ret, st.gas, vmerr = st.evm.Call(sender, st.to().Address(), st.data, st.gas, st.value) } st.refundGas() @@ -308,7 +310,11 @@ func (st *StateTransition) TransitionDb(owner common.Address) (ret []byte, usedG st.state.AddBalance(st.evm.Coinbase, new(big.Int).Mul(new(big.Int).SetUint64(st.gasUsed()), st.gasPrice)) } - return ret, st.gasUsed(), vmerr != nil, err + return &ExecutionResult{ + UsedGas: st.gasUsed(), + Err: vmerr, + ReturnData: ret, + }, err } func (st *StateTransition) refundGas() { diff --git a/core/token_validator.go b/core/token_validator.go index 485ff05c59..a85c156a4b 100644 --- a/core/token_validator.go +++ b/core/token_validator.go @@ -111,11 +111,11 @@ func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, vmenv := vm.NewEVM(evmContext, statedb, nil, chain.Config(), vm.Config{}) gaspool := new(GasPool).AddGas(1000000) owner := common.Address{} - rval, _, _, err := NewStateTransition(vmenv, msg, gaspool).TransitionDb(owner) + result, err := NewStateTransition(vmenv, msg, gaspool).TransitionDb(owner) if err != nil { return nil, err } - return rval, err + return result.Return(), err } // make sure that balance of token is at slot 0 diff --git a/eth/api_tracer.go b/eth/api_tracer.go index e1744dc2c1..3977607d70 100644 --- a/eth/api_tracer.go +++ b/eth/api_tracer.go @@ -474,7 +474,7 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, vmenv := vm.NewEVM(vmctx, statedb, tomoxState, api.config, vm.Config{}) owner := common.Address{} - if _, _, _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(msg.Gas()), owner); err != nil { + if _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(msg.Gas()), owner); err != nil { failed = err break } @@ -630,7 +630,7 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v vmenv := vm.NewEVM(vmctx, statedb, nil, api.config, vm.Config{Debug: true, Tracer: tracer}) owner := common.Address{} - ret, gas, failed, err := core.ApplyMessage(vmenv, message, new(core.GasPool).AddGas(message.Gas()), owner) + result, err := core.ApplyMessage(vmenv, message, new(core.GasPool).AddGas(message.Gas()), owner) if err != nil { return nil, fmt.Errorf("tracing failed: %v", err) } @@ -638,9 +638,9 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v switch tracer := tracer.(type) { case *vm.StructLogger: return 
&ethapi.ExecutionResult{ - Gas: gas, - Failed: failed, - ReturnValue: fmt.Sprintf("%x", ret), + Gas: result.UsedGas, + Failed: result.Failed(), + ReturnValue: fmt.Sprintf("%x", result.Return()), StructLogs: ethapi.FormatLogs(tracer.StructLogs()), }, nil diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 33376a1071..810bf68923 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -32,6 +32,7 @@ import ( "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/util" "github.com/tomochain/tomochain/accounts" + "github.com/tomochain/tomochain/accounts/abi" "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/accounts/keystore" "github.com/tomochain/tomochain/common" @@ -1025,12 +1026,12 @@ type CallArgs struct { Data hexutil.Bytes `json:"data"` } -func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber, vmCfg vm.Config, timeout time.Duration) ([]byte, uint64, bool, error) { +func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber, vmCfg vm.Config, timeout time.Duration) (*core.ExecutionResult, error) { defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) statedb, header, err := s.b.StateAndHeaderByNumber(ctx, blockNr) if statedb == nil || err != nil { - return nil, 0, false, err + return nil, err } // Set sender address or use a default if none specified addr := args.From @@ -1068,20 +1069,20 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr block, err := s.b.BlockByNumber(ctx, blockNr) if err != nil { - return nil, 0, false, err + return nil, err } author, err := s.b.GetEngine().Author(block.Header()) if err != nil { - return nil, 0, false, err + return nil, err } tomoxState, err := s.b.TomoxService().GetTradingState(block, author) if err != nil { - return nil, 0, false, err + return nil, err } // Get a new instance of the EVM. evm, vmError, err := s.b.GetEVM(ctx, msg, statedb, tomoxState, header, vmCfg) if err != nil { - return nil, 0, false, err + return nil, err } // Wait for the context to be done and cancel the evm. Even if the // EVM has finished, cancelling may be done (repeatedly) @@ -1094,18 +1095,38 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr // and apply the message. gp := new(core.GasPool).AddGas(math.MaxUint64) owner := common.Address{} - res, gas, failed, err := core.ApplyMessage(evm, msg, gp, owner) + result, err := core.ApplyMessage(evm, msg, gp, owner) if err := vmError(); err != nil { - return nil, 0, false, err + return nil, err } - return res, gas, failed, err + return result, err }
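// The EstimateGas rework below narrows the gas limit by binary search,
// probing whether the call fails at a given allowance. The core of the loop,
// as a standalone sketch (estimateWithProbe is a hypothetical name; the
// probe mirrors the executable closure introduced below):
//
//	func estimateWithProbe(lo, hi uint64, probe func(uint64) (bool, error)) (uint64, error) {
//		for lo+1 < hi {
//			mid := (hi + lo) / 2
//			failed, err := probe(mid)
//			if err != nil {
//				return 0, err
//			}
//			if failed {
//				lo = mid // too little gas, move the floor up
//			} else {
//				hi = mid // call succeeded, try a smaller limit
//			}
//		}
//		return hi, nil
//	}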
// Call executes the given transaction on the state for the given block number. // It doesn't make any changes in the state/blockchain and is useful to execute and retrieve values. func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber) (hexutil.Bytes, error) { - result, _, _, err := s.doCall(ctx, args, blockNr, vm.Config{}, 5*time.Second) - return (hexutil.Bytes)(result), err + result, err := s.doCall(ctx, args, blockNr, vm.Config{}, 5*time.Second) + if err != nil { + return nil, err + } + return result.Return(), nil +} + +type EstimateGasError struct { + error string // Concrete error type if it's failed to estimate gas usage + vmerr error // Additional field, it's non-nil if the given transaction is invalid + revert string // Additional field, it's non-empty if the transaction is reverted and reason is provided +} + +func (e EstimateGasError) Error() string { + errMsg := e.error + if e.vmerr != nil { + errMsg += fmt.Sprintf(" (%v)", e.vmerr) + } + if e.revert != "" { + errMsg += fmt.Sprintf(" (%s)", e.revert) + } + return errMsg } // EstimateGas returns an estimate of the amount of gas needed to execute the @@ -1130,19 +1151,26 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h cap = hi // Create a helper to check if a gas allowance results in an executable transaction - executable := func(gas uint64) bool { + executable := func(gas uint64) (bool, *core.ExecutionResult, error) { args.Gas = hexutil.Uint64(gas) - _, _, failed, err := s.doCall(ctx, args, rpc.LatestBlockNumber, vm.Config{}, 0) - if err != nil || failed { - return false + result, err := s.doCall(ctx, args, rpc.LatestBlockNumber, vm.Config{}, 0) + if err != nil { + if err == core.ErrIntrinsicGas { + return true, nil, nil // Special case, raise gas limit + } + return true, nil, err } - return true + return result.Failed(), result, nil } // Execute the binary search and hone in on an executable gas limit for lo+1 < hi { mid := (hi + lo) / 2 - if !executable(mid) { + failed, _, err := executable(mid) + if err != nil { + return 0, err + } + if failed { lo = mid } else { hi = mid @@ -1150,9 +1178,31 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h } // Reject the transaction as invalid if it still fails at the highest allowance if hi == cap { - if !executable(hi) { - return 0, fmt.Errorf("gas required exceeds allowance or always failing transaction") - } + failed, result, err := executable(hi) + if err != nil { + return 0, err + } + + if failed { + if result != nil && result.Err != vm.ErrOutOfGas { + var revert string + + if len(result.Revert()) > 0 { + ret, err := abi.UnpackRevert(result.Revert()) + if err != nil { + revert = hexutil.Encode(result.Revert()) + } else { + revert = ret + } + } + return 0, EstimateGasError { + error: "always failing transaction", + vmerr: result.Err, + revert: revert, + } + } + return 0, EstimateGasError{error: fmt.Sprintf("gas required exceeds allowance (%d)", cap)} + } } return hexutil.Uint64(hi), nil } diff --git a/tests/state_test_util.go b/tests/state_test_util.go index e532aa8a46..6360457689 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -144,7 +144,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD snapshot := statedb.Snapshot() coinbase := &t.json.Env.Coinbase - if _, _, _, err := core.ApplyMessage(evm, msg, gaspool, *coinbase); err != nil { + if _, err := core.ApplyMessage(evm, msg, gaspool, *coinbase); err != nil { statedb.RevertToSnapshot(snapshot) } if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) {
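// The revert branch in the EstimateGas path above depends on Solidity's
// convention: revert payloads are the 4-byte selector of Error(string)
// followed by the ABI-encoded reason, which abi.UnpackRevert decodes. A
// minimal sketch of the fallback logic (decodeRevertReason is a hypothetical
// helper):
//
//	func decodeRevertReason(data []byte) string {
//		if len(data) == 0 {
//			return ""
//		}
//		if reason, err := abi.UnpackRevert(data); err == nil {
//			return reason
//		}
//		return hexutil.Encode(data) // undecodable payloads stay as raw hex
//	}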
From 25f4168df284950a50d339debc90061a96da68ca Mon Sep 17 00:00:00 2001 From: terryyyz-coin98 Date: Fri, 14 Jul 2023 15:25:51 +0700 Subject: [PATCH 013/119] handle for simulate --- accounts/abi/bind/backends/simulated.go | 59 ++++++++++++++++----- internal/ethapi/api.go | 68 ++++++++++++++----------- 2 files changed, 84 insertions(+), 43 deletions(-) diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 7cfb5e5546..e180b832de 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -20,16 +20,18 @@ import ( "context" "errors" "fmt" - "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "sync" "time" + "github.com/tomochain/tomochain/consensus" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain" "github.com/tomochain/tomochain/accounts/abi" "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" @@ -187,6 +189,36 @@ func (b *SimulatedBackend) PendingCodeAt(ctx context.Context, contract common.Ad return b.pendingState.GetCode(contract), nil } +func newRevertError(result *core.ExecutionResult) *revertError { + reason, errUnpack := abi.UnpackRevert(result.Revert()) + err := errors.New("execution reverted") + if errUnpack == nil { + err = fmt.Errorf("execution reverted: %v", reason) + } + return &revertError{ + error: err, + reason: hexutil.Encode(result.Revert()), + } + } + + // revertError is an API error that encompasses an EVM revert with JSON error + // code and a binary data blob. + type revertError struct { + error + reason string // revert reason hex encoded + } + + // ErrorCode returns the JSON error code for a revert. + // See: https://github.com/ethereum/wiki/wiki/JSON-RPC-Error-Codes-Improvement-Proposal + func (e *revertError) ErrorCode() int { + return 3 + } + + // ErrorData returns the hex encoded revert reason. + func (e *revertError) ErrorData() interface{} { + return e.reason + }
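// Usage sketch for the revertError type defined above (values are
// illustrative; result is assumed to be a *core.ExecutionResult whose
// Revert() is non-empty):
//
//	rerr := newRevertError(result)
//	fmt.Println(rerr.Error())     // "execution reverted: <reason>" when the reason unpacks
//	fmt.Println(rerr.ErrorCode()) // 3, per the JSON-RPC error-code improvement proposal
//	fmt.Println(rerr.ErrorData()) // hex-encoded revert payload for RPC clients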
// CallContract executes a contract call. func (b *SimulatedBackend) CallContract(ctx context.Context, call tomochain.CallMsg, blockNumber *big.Int) ([]byte, error) { b.mu.Lock() @@ -203,7 +235,12 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call tomochain.Call if err != nil { return nil, err } - return res.Return(), nil + + if len(res.Revert()) > 0 { + return nil, newRevertError(res) + } + + return res.Return(), res.Err } //FIXME: please use copyState for this function @@ -249,7 +286,11 @@ func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call tomocha if err != nil { return nil, err } - return res.Return(), nil + if len(res.Revert()) > 0 { + return nil, newRevertError(res) + } + + return res.Return(), res.Err } // PendingNonceAt implements PendingStateReader.PendingNonceAt, retrieving @@ -326,17 +367,11 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM } if failed { if result != nil && result.Err != vm.ErrOutOfGas { - errMsg := fmt.Sprintf("always failing transaction (%v)", result.Err) if len(result.Revert()) > 0 { - ret, err := abi.UnpackRevert(result.Revert()) - if err != nil { - errMsg += fmt.Sprintf(" (%#x)", result.Revert()) - } else { - errMsg += fmt.Sprintf(" (%s)", ret) - } + return 0, newRevertError(result) } - return 0, errors.New(errMsg) + return 0, result.Err } // Otherwise, the specified gas cap is too low diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 810bf68923..37bfd1fb82 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -1102,6 +1102,36 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr return result, err } +func newRevertError(result *core.ExecutionResult) *revertError { + reason, errUnpack := abi.UnpackRevert(result.Revert()) + err := errors.New("execution reverted") + if errUnpack == nil { + err = fmt.Errorf("execution reverted: %v", reason) + } + return &revertError{ + error: err, + reason: hexutil.Encode(result.Revert()), + } + } + + // revertError is an API error that encompasses an EVM revert with JSON error + // code and a binary data blob. + type revertError struct { + error + reason string // revert reason hex encoded + } + + // ErrorCode returns the JSON error code for a revert. + // See: https://github.com/ethereum/wiki/wiki/JSON-RPC-Error-Codes-Improvement-Proposal + func (e *revertError) ErrorCode() int { + return 3 + } + + // ErrorData returns the hex encoded revert reason. + func (e *revertError) ErrorData() interface{} { + return e.reason + } + // Call executes the given transaction on the state for the given block number. // It doesn't make any changes in the state/blockchain and is useful to execute and retrieve values.
func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber) (hexutil.Bytes, error) { @@ -1109,24 +1139,11 @@ func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNr r if err != nil { return nil, err } - return result.Return(), nil -} - -type EstimateGasError struct { - error string // Concrete error type if it's failed to estimate gas usage - vmerr error // Additional field, it's non-nil if the given transaction is invalid - revert string // Additional field, it's non-empty if the transaction is reverted and reason is provided -} -func (e EstimateGasError) Error() string { - errMsg := e.error - if e.vmerr != nil { - errMsg += fmt.Sprintf(" (%v)", e.vmerr) - } - if e.revert != "" { - errMsg += fmt.Sprintf(" (%s)", e.revert) - } - return errMsg + if len(result.Revert()) > 0 { + return nil, newRevertError(result) + } + return result.Return(), result.Err } // EstimateGas returns an estimate of the amount of gas needed to execute the @@ -1185,23 +1202,12 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h if failed { if result != nil && result.Err != vm.ErrOutOfGas { - var revert string - if len(result.Revert()) > 0 { - ret, err := abi.UnpackRevert(result.Revert()) - if err != nil { - revert = hexutil.Encode(result.Revert()) - } else { - revert = ret - } - } - return 0, EstimateGasError { - error: "always failing transaction", - vmerr: result.Err, - revert: revert, + return 0, newRevertError(result) } + return 0, result.Err } - return 0, EstimateGasError{error: fmt.Sprintf("gas required exceeds allowance (%d)", cap)} + return 0, fmt.Errorf("gas required exceeds allowance (%d)", cap) } } return hexutil.Uint64(hi), nil From 5e52c84c6432bb3e28047504d0c974ae13cdf00c Mon Sep 17 00:00:00 2001 From: terryyyz-coin98 Date: Fri, 14 Jul 2023 15:37:53 +0700 Subject: [PATCH 014/119] return data in json --- rpc/json.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/rpc/json.go b/rpc/json.go index 715f33ee16..9d57a9cf70 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -56,6 +56,8 @@ type jsonError struct { Data interface{} `json:"data,omitempty"` } + + type jsonErrResponse struct { Version string `json:"jsonrpc"` Id interface{} `json:"id,omitempty"` @@ -96,6 +98,10 @@ func (err *jsonError) ErrorCode() int { return err.Code } +func (err *jsonError) ErrorData() interface{} { + return err.Data + } + // NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based // on explicitly given encoding and decoding methods. 
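// With the ErrorData method added to jsonError above, the revert payload
// reaches RPC clients through the error's data field. A client-side sketch,
// asserting only on the method shape rather than a concrete type
// (illustrative, not part of this patch):
//
//	if de, ok := err.(interface{ ErrorData() interface{} }); ok {
//		fmt.Printf("error data: %v\n", de.ErrorData())
//	}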
func NewCodec(rwc io.ReadWriteCloser, encode, decode func(v interface{}) error) ServerCodec { From 8843c93a889be1dad8d1fb63655b6f112c8798f8 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Sat, 15 Jul 2023 00:47:30 +0700 Subject: [PATCH 015/119] Move Message struct from package types to core --- accounts/abi/bind/backends/simulated.go | 40 ++++++--- core/error.go | 7 ++ core/evm.go | 6 +- core/state_processor.go | 15 ++-- core/state_transition.go | 114 +++++++++++++++--------- core/token_validator.go | 35 +++++--- core/types/hashes.go | 26 ++++++ core/types/transaction.go | 70 --------------- eth/api_backend.go | 14 ++- eth/api_tracer.go | 32 +++---- internal/ethapi/api.go | 32 ++++--- internal/ethapi/backend.go | 9 +- les/api_backend.go | 11 ++- tests/state_test_util.go | 22 ++++- 14 files changed, 237 insertions(+), 196 deletions(-) create mode 100644 core/types/hashes.go diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 7411f492a8..00013fe190 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -20,8 +20,6 @@ import ( "context" "errors" "fmt" - "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "sync" "time" @@ -30,9 +28,11 @@ import ( "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -202,7 +202,7 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call tomochain.Call return rval, err } -//FIXME: please use copyState for this function +// FIXME: please use copyState for this function // CallContractWithState executes a contract call at the given state. func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) { // Ensure message is initialized properly. @@ -215,11 +215,19 @@ func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain c call.Value = new(big.Int) } // Execute the call. - msg := callmsg{call} + msg := &core.Message{ + To: call.To, + From: call.From, + Value: call.Value, + GasLimit: call.Gas, + GasPrice: call.GasPrice, + Data: call.Data, + SkipAccountChecks: false, + } feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) - if msg.To() != nil { - if value, ok := feeCapacity[*msg.To()]; ok { - msg.CallMsg.BalanceTokenFee = value + if msg.To != nil { + if value, ok := feeCapacity[*msg.To]; ok { + msg.BalanceTokenFee = value } } evmContext := core.NewEVMContext(msg, chain.CurrentHeader(), chain, nil) @@ -285,7 +293,7 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM snapshot := b.pendingState.Snapshot() _, _, failed, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) - fmt.Println("EstimateGas",err,failed) + fmt.Println("EstimateGas", err, failed) b.pendingState.RevertToSnapshot(snapshot) if err != nil || failed { @@ -328,11 +336,19 @@ func (b *SimulatedBackend) callContract(ctx context.Context, call tomochain.Call from := statedb.GetOrNewStateObject(call.From) from.SetBalance(math.MaxBig256) // Execute the call. 
- msg := callmsg{call} + msg := &core.Message{ + To: call.To, + From: call.From, + Value: call.Value, + GasLimit: call.Gas, + GasPrice: call.GasPrice, + Data: call.Data, + SkipAccountChecks: true, + } feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) - if msg.To() != nil { - if value, ok := feeCapacity[*msg.To()]; ok { - msg.CallMsg.BalanceTokenFee = value + if msg.To != nil { + if value, ok := feeCapacity[*msg.To]; ok { + msg.BalanceTokenFee = value } } evmContext := core.NewEVMContext(msg, block.Header(), b.blockchain, nil) diff --git a/core/error.go b/core/error.go index 63be6ab83d..ec0f3a8166 100644 --- a/core/error.go +++ b/core/error.go @@ -33,9 +33,16 @@ var ( // next one expected based on the local chain. ErrNonceTooHigh = errors.New("nonce too high") + // ErrNonceMax is returned if the nonce of a transaction sender account has + // maximum allowed value and would become invalid if incremented. + ErrNonceMax = errors.New("nonce has max value") + ErrNotPoSV = errors.New("Posv not found in config") ErrNotFoundM1 = errors.New("list M1 not found ") ErrStopPreparingBlock = errors.New("stop calculating a block not verified by M2") + + // ErrSenderNoEOA is returned if the sender of a transaction is a contract. + ErrSenderNoEOA = errors.New("sender not an eoa") ) diff --git a/core/evm.go b/core/evm.go index 04636999b3..f3ac62a73f 100644 --- a/core/evm.go +++ b/core/evm.go @@ -26,7 +26,7 @@ import ( ) // NewEVMContext creates a new context for use in the EVM. -func NewEVMContext(msg Message, header *types.Header, chain consensus.ChainContext, author *common.Address) vm.Context { +func NewEVMContext(msg *Message, header *types.Header, chain consensus.ChainContext, author *common.Address) vm.Context { // If we don't have an explicit author (i.e. 
not mining), extract from the header var beneficiary common.Address if author == nil { @@ -38,13 +38,13 @@ func NewEVMContext(msg Message, header *types.Header, chain consensus.ChainConte CanTransfer: CanTransfer, Transfer: Transfer, GetHash: GetHashFn(header, chain), - Origin: msg.From(), + Origin: msg.From, Coinbase: beneficiary, BlockNumber: new(big.Int).Set(header.Number), Time: new(big.Int).Set(header.Time), Difficulty: new(big.Int).Set(header.Difficulty), GasLimit: header.GasLimit, - GasPrice: new(big.Int).Set(msg.GasPrice()), + GasPrice: new(big.Int).Set(msg.GasPrice), } } diff --git a/core/state_processor.go b/core/state_processor.go index 035c15f2b3..f3da43717a 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -18,9 +18,6 @@ package core import ( "fmt" - - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/log" "math/big" "runtime" "strings" @@ -33,7 +30,9 @@ import ( "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/tomox/tradingstate" ) // StateProcessor is a basic Processor, which takes care of transitioning @@ -243,7 +242,7 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* balanceFee = value } } - msg, err := tx.AsMessage(types.MakeSigner(config, header.Number), balanceFee, header.Number) + msg, err := TransactionToMessage(tx, types.MakeSigner(config, header.Number), balanceFee) if err != nil { return nil, 0, err, false } @@ -391,7 +390,7 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* blockMap[9147453] = "0x3538a544021c07869c16b764424c5987409cba48" blockMap[9147459] = "0xe187cf86c2274b1f16e8225a7da9a75aba4f1f5f" - addrFrom := msg.From().Hex() + addrFrom := msg.From.Hex() currentBlockNumber := header.Number.Int64() if addr, ok := blockMap[currentBlockNumber]; ok { @@ -409,7 +408,6 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* // Apply the transaction to the current state (included in the env) _, gas, failed, err := ApplyMessage(vmenv, msg, gp, coinbaseOwner) - if err != nil { return nil, 0, err, false } @@ -428,14 +426,14 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* receipt.TxHash = tx.Hash() receipt.GasUsed = gas // if the transaction created a contract, store the creation address in the receipt. 
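// For the receipt handling below: a creation transaction (msg.To == nil)
// deterministically yields its contract address from sender and nonce alone,
// so the receipt can record it without consulting the EVM. A minimal sketch
// (the address literal is illustrative):
//
//	from := common.HexToAddress("0x0000000000000000000000000000000000000001")
//	created := crypto.CreateAddress(from, 0) // keccak256(rlp([sender, nonce]))[12:]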
- if msg.To() == nil { + if msg.To == nil { receipt.ContractAddress = crypto.CreateAddress(vmenv.Context.Origin, tx.Nonce()) } // Set the receipt logs and create a bloom for filtering receipt.Logs = statedb.GetLogs(tx.Hash()) receipt.Bloom = types.CreateBloom(types.Receipts{receipt}) if balanceFee != nil && failed { - state.PayFeeWithTRC21TxFail(statedb, msg.From(), *tx.To()) + state.PayFeeWithTRC21TxFail(statedb, msg.From, *tx.To()) } return receipt, gas, err, balanceFee != nil } @@ -517,7 +515,6 @@ func InitSignerInTransactions(config *params.ChainConfig, header *types.Header, go func(from int, to int) { for j := from; j < to; j++ { types.CacheSigner(signer, txs[j]) - txs[j].CacheHash() } wg.Done() }(from, to) diff --git a/core/state_transition.go b/core/state_transition.go index 9a2b079249..0f396cac92 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -18,10 +18,12 @@ package core import ( "errors" + "fmt" "math" "math/big" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" @@ -42,15 +44,17 @@ The state transitioning model does all all the necessary work to work out a vali 3) Create a new state object if the recipient is \0*32 4) Value transfer == If contract creation == - 4a) Attempt to run transaction data - 4b) If valid, use result as code for the new state object + + 4a) Attempt to run transaction data + 4b) If valid, use result as code for the new state object + == end == 5) Run Script section 6) Derive new state root */ type StateTransition struct { gp *GasPool - msg Message + msg *Message gas uint64 gasPrice *big.Int initialGas uint64 @@ -60,20 +64,22 @@ type StateTransition struct { evm *vm.EVM } -// Message represents a message sent to a contract. -type Message interface { - From() common.Address - //FromFrontier() (common.Address, error) - To() *common.Address - - GasPrice() *big.Int - Gas() uint64 - Value() *big.Int - - Nonce() uint64 - CheckNonce() bool - Data() []byte - BalanceTokenFee() *big.Int +// A Message contains the data derived from a single transaction that is relevant to state +// processing. +type Message struct { + To *common.Address + From common.Address + Nonce uint64 + Value *big.Int + GasLimit uint64 + GasPrice *big.Int + Data []byte + BalanceTokenFee *big.Int + + // When SkipAccountChecks is true, the message nonce is not checked against the + // account nonce in state. It also disables checking that the sender is an EOA. + // This field will be set to true for operations like RPC eth_call. + SkipAccountChecks bool } // IntrinsicGas computes the 'intrinsic gas' for a message with the given data. @@ -110,18 +116,35 @@ func IntrinsicGas(data []byte, contractCreation, homestead bool) (uint64, error) } // NewStateTransition initialises and returns a new state transition object. -func NewStateTransition(evm *vm.EVM, msg Message, gp *GasPool) *StateTransition { +func NewStateTransition(evm *vm.EVM, msg *Message, gp *GasPool) *StateTransition { return &StateTransition{ gp: gp, evm: evm, msg: msg, - gasPrice: msg.GasPrice(), - value: msg.Value(), - data: msg.Data(), + gasPrice: msg.GasPrice, + value: msg.Value, + data: msg.Data, state: evm.StateDB, } } +// TransactionToMessage converts a transaction into a Message. 
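// The struct-based Message above is a plain value, so call sites can either
// convert a signed transaction via TransactionToMessage below or fill one in
// directly. A minimal usage sketch, assuming tx, config and header are in
// scope:
//
//	msg, err := TransactionToMessage(tx, types.MakeSigner(config, header.Number), nil)
//	if err != nil {
//		return err // sender recovery failed
//	}
//	msg.SkipAccountChecks = true // e.g. for an eth_call-style simulation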
+func TransactionToMessage(tx *types.Transaction, s types.Signer, balanceFee *big.Int) (*Message, error) { + msg := &Message{ + Nonce: tx.Nonce(), + GasLimit: tx.Gas(), + GasPrice: new(big.Int).Set(tx.GasPrice()), + To: tx.To(), + Value: tx.Value(), + Data: tx.Data(), + SkipAccountChecks: false, + BalanceTokenFee: balanceFee, + } + var err error + msg.From, err = types.Sender(s, tx) + return msg, err +} + // ApplyMessage computes the new state by applying the given message // against the old state within the environment. // @@ -129,12 +152,12 @@ func NewStateTransition(evm *vm.EVM, msg Message, gp *GasPool) *StateTransition // the gas used (which includes gas refunds) and an error if it failed. An error always // indicates a core error meaning that the message would always fail for that particular // state and would never be accepted within a block. -func ApplyMessage(evm *vm.EVM, msg Message, gp *GasPool, owner common.Address) ([]byte, uint64, bool, error) { +func ApplyMessage(evm *vm.EVM, msg *Message, gp *GasPool, owner common.Address) ([]byte, uint64, bool, error) { return NewStateTransition(evm, msg, gp).TransitionDb(owner) } func (st *StateTransition) from() vm.AccountRef { - f := st.msg.From() + f := st.msg.From if !st.state.Exist(f) { st.state.CreateAccount(f) } @@ -142,14 +165,14 @@ func (st *StateTransition) from() vm.AccountRef { } func (st *StateTransition) balanceTokenFee() *big.Int { - return st.msg.BalanceTokenFee() + return st.msg.BalanceTokenFee } func (st *StateTransition) to() vm.AccountRef { if st.msg == nil { return vm.AccountRef{} } - to := st.msg.To() + to := st.msg.To if to == nil { return vm.AccountRef{} // contract creation } @@ -176,7 +199,7 @@ func (st *StateTransition) buyGas() error { balanceTokenFee = st.balanceTokenFee() from = st.from() ) - mgval := new(big.Int).Mul(new(big.Int).SetUint64(st.msg.Gas()), st.gasPrice) + mgval := new(big.Int).Mul(new(big.Int).SetUint64(st.msg.GasLimit), st.gasPrice) if balanceTokenFee == nil { if state.GetBalance(from.Address()).Cmp(mgval) < 0 { return errInsufficientBalanceForGas @@ -184,12 +207,12 @@ func (st *StateTransition) buyGas() error { } else if balanceTokenFee.Cmp(mgval) < 0 { return errInsufficientBalanceForGas } - if err := st.gp.SubGas(st.msg.Gas()); err != nil { + if err := st.gp.SubGas(st.msg.GasLimit); err != nil { return err } - st.gas += st.msg.Gas() + st.gas += st.msg.GasLimit - st.initialGas = st.msg.Gas() + st.initialGas = st.msg.GasLimit if balanceTokenFee == nil { state.SubBalance(from.Address(), mgval) } @@ -197,23 +220,34 @@ func (st *StateTransition) buyGas() error { } func (st *StateTransition) preCheck() error { + // Only check transactions that are not fake msg := st.msg - sender := st.from() - - // Make sure this transaction's nonce is correct - if msg.CheckNonce() { - nonce := st.state.GetNonce(sender.Address()) - if nonce < msg.Nonce() { - return ErrNonceTooHigh - } else if nonce > msg.Nonce() { - return ErrNonceTooLow + if !msg.SkipAccountChecks { + // Make sure this transaction's nonce is correct. 
+ stNonce := st.state.GetNonce(msg.From) + if msgNonce := msg.Nonce; stNonce < msgNonce { + return fmt.Errorf("%w: address %v, tx: %d state: %d", ErrNonceTooHigh, + msg.From.Hex(), msgNonce, stNonce) + } else if stNonce > msgNonce { + return fmt.Errorf("%w: address %v, tx: %d state: %d", ErrNonceTooLow, + msg.From.Hex(), msgNonce, stNonce) + } else if stNonce+1 < stNonce { + return fmt.Errorf("%w: address %v, nonce: %d", ErrNonceMax, + msg.From.Hex(), stNonce) + } + // Make sure the sender is an EOA + codeHash := st.state.GetCodeHash(msg.From) + if codeHash != (common.Hash{}) && codeHash != types.EmptyCodeHash { + return fmt.Errorf("%w: address %v, codehash: %s", ErrSenderNoEOA, + msg.From.Hex(), codeHash) } } + return st.buyGas() } // TransitionDb will transition the state by applying the current message and -// returning the result including the the used gas. It returns an error if it +// returning the result including the used gas. It returns an error if it // failed. An error indicates a consensus issue. func (st *StateTransition) TransitionDb(owner common.Address) (ret []byte, usedGas uint64, failed bool, err error) { if err = st.preCheck(); err != nil { @@ -223,7 +257,7 @@ func (st *StateTransition) TransitionDb(owner common.Address) (ret []byte, usedG sender := st.from() // err checked in preCheck homestead := st.evm.ChainConfig().IsHomestead(st.evm.BlockNumber) - contractCreation := msg.To() == nil + contractCreation := msg.To == nil // Pay intrinsic gas gas, err := IntrinsicGas(st.data, contractCreation, homestead) diff --git a/core/token_validator.go b/core/token_validator.go index 485ff05c59..b446106293 100644 --- a/core/token_validator.go +++ b/core/token_validator.go @@ -17,7 +17,11 @@ package core import ( "fmt" - ethereum "github.com/tomochain/tomochain" + "math/big" + "math/rand" + "strings" + + tomochain "github.com/tomochain/tomochain" "github.com/tomochain/tomochain/accounts/abi" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" @@ -25,9 +29,6 @@ import ( "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/log" - "math/big" - "math/rand" - "strings" ) const ( @@ -38,7 +39,7 @@ const ( // callmsg implements core.Message to allow passing it as a transaction simulator. 
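// The EOA check in preCheck above relies on the fact that an externally
// owned account has either no code hash in state or the hash of empty code.
// The same test in isolation, as a sketch (statedb and addr are assumed to
// exist):
//
//	codeHash := statedb.GetCodeHash(addr)
//	isEOA := codeHash == (common.Hash{}) || codeHash == types.EmptyCodeHash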
type callmsg struct { - ethereum.CallMsg + tomochain.CallMsg } func (m callmsg) From() common.Address { return m.CallMsg.From } @@ -52,7 +53,7 @@ func (m callmsg) Data() []byte { return m.CallMsg.Data } func (m callmsg) BalanceTokenFee() *big.Int { return m.CallMsg.BalanceTokenFee } type SimulatedBackend interface { - CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) + CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) } // GetTokenAbi return token abi @@ -72,7 +73,7 @@ func RunContract(chain consensus.ChainContext, statedb *state.StateDB, contractA } fakeCaller := common.HexToAddress("0x0000000000000000000000000000000000000001") statedb.SetBalance(fakeCaller, common.BasePrice) - msg := ethereum.CallMsg{To: &contractAddr, Data: input, From: fakeCaller} + msg := tomochain.CallMsg{To: &contractAddr, Data: input, From: fakeCaller} result, err := CallContractWithState(msg, chain, statedb) if err != nil { return nil, err @@ -85,9 +86,9 @@ func RunContract(chain consensus.ChainContext, statedb *state.StateDB, contractA return unpackResult, nil } -//FIXME: please use copyState for this function +// FIXME: please use copyState for this function // CallContractWithState executes a contract call at the given state. -func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) { +func CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) { // Ensure message is initialized properly. call.GasPrice = big.NewInt(0) @@ -98,11 +99,19 @@ func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, call.Value = new(big.Int) } // Execute the call. - msg := callmsg{call} + msg := &Message{ + To: call.To, + From: call.From, + Value: call.Value, + GasLimit: call.Gas, + GasPrice: call.GasPrice, + Data: call.Data, + SkipAccountChecks: false, + } feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) - if msg.To() != nil { - if value, ok := feeCapacity[*msg.To()]; ok { - msg.CallMsg.BalanceTokenFee = value + if msg.To != nil { + if value, ok := feeCapacity[*msg.To]; ok { + msg.BalanceTokenFee = value } } evmContext := NewEVMContext(msg, chain.CurrentHeader(), chain, nil) diff --git a/core/types/hashes.go b/core/types/hashes.go new file mode 100644 index 0000000000..2c9d8b6900 --- /dev/null +++ b/core/types/hashes.go @@ -0,0 +1,26 @@ +// Copyright 2023 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "github.com/tomochain/tomochain/crypto" +) + +var ( + // EmptyCodeHash is the known hash of the empty EVM bytecode. 
+ EmptyCodeHash = crypto.Keccak256Hash(nil) // c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470 +) diff --git a/core/types/transaction.go b/core/types/transaction.go index cf546c4420..dd983bcb27 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -242,34 +242,6 @@ func (tx *Transaction) Size() common.StorageSize { return common.StorageSize(c) } -// AsMessage returns the transaction as a core.Message. -// -// AsMessage requires a signer to derive the sender. -// -// XXX Rename message to something less arbitrary? -func (tx *Transaction) AsMessage(s Signer, balanceFee *big.Int, number *big.Int) (Message, error) { - msg := Message{ - nonce: tx.data.AccountNonce, - gasLimit: tx.data.GasLimit, - gasPrice: new(big.Int).Set(tx.data.Price), - to: tx.data.Recipient, - amount: tx.data.Amount, - data: tx.data.Payload, - checkNonce: true, - balanceTokenFee: balanceFee, - } - var err error - msg.from, err = Sender(s, tx) - if balanceFee != nil { - if number.Cmp(common.TIPTRC21Fee) > 0 { - msg.gasPrice = common.TRC21GasPrice - } else { - msg.gasPrice = common.TRC21GasPriceBefore - } - } - return msg, err -} - // WithSignature returns a new transaction with the given signature. // This signature needs to be formatted as described in the yellow paper (v+27). func (tx *Transaction) WithSignature(signer Signer, sig []byte) (*Transaction, error) { @@ -680,45 +652,3 @@ func (t *TransactionsByPriceAndNonce) Shift() { func (t *TransactionsByPriceAndNonce) Pop() { heap.Pop(&t.heads) } - -// Message is a fully derived transaction and implements core.Message -// -// NOTE: In a future PR this will be removed. -type Message struct { - to *common.Address - from common.Address - nonce uint64 - amount *big.Int - gasLimit uint64 - gasPrice *big.Int - data []byte - checkNonce bool - balanceTokenFee *big.Int -} - -func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte, checkNonce bool, balanceTokenFee *big.Int) Message { - if balanceTokenFee != nil { - gasPrice = common.TRC21GasPrice - } - return Message{ - from: from, - to: to, - nonce: nonce, - amount: amount, - gasLimit: gasLimit, - gasPrice: gasPrice, - data: data, - checkNonce: checkNonce, - balanceTokenFee: balanceTokenFee, - } -} - -func (m Message) From() common.Address { return m.from } -func (m Message) BalanceTokenFee() *big.Int { return m.balanceTokenFee } -func (m Message) To() *common.Address { return m.to } -func (m Message) GasPrice() *big.Int { return m.gasPrice } -func (m Message) Value() *big.Int { return m.amount } -func (m Message) Gas() uint64 { return m.gasLimit } -func (m Message) Nonce() uint64 { return m.nonce } -func (m Message) Data() []byte { return m.data } -func (m Message) CheckNonce() bool { return m.checkNonce } diff --git a/eth/api_backend.go b/eth/api_backend.go index 67554b4480..d2022670b3 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -21,20 +21,15 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending" "io/ioutil" "math/big" "path/filepath" - "github.com/tomochain/tomochain/tomox" - - "github.com/tomochain/tomochain/consensus/posv" - "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus" + "github.com/tomochain/tomochain/consensus/posv" "github.com/tomochain/tomochain/contracts" 
"github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" @@ -50,6 +45,9 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending" ) // EthApiBackend implements ethapi.Backend for full nodes @@ -136,8 +134,8 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } -func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { - state.SetBalance(msg.From(), math.MaxBig256) +func (b *EthApiBackend) GetEVM(ctx context.Context, msg *core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { + state.SetBalance(msg.From, math.MaxBig256) vmError := func() error { return nil } context := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil) diff --git a/eth/api_tracer.go b/eth/api_tracer.go index e1744dc2c1..4581ec8d08 100644 --- a/eth/api_tracer.go +++ b/eth/api_tracer.go @@ -21,7 +21,6 @@ import ( "context" "errors" "fmt" - "github.com/tomochain/tomochain/tomox/tradingstate" "io/ioutil" "math/big" "runtime" @@ -39,6 +38,7 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox/tradingstate" "github.com/tomochain/tomochain/trie" ) @@ -198,13 +198,13 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl feeCapacity := state.GetTRC21FeeCapacityFromState(task.statedb) // Trace all the transactions contained within for i, tx := range task.block.Transactions() { - var balacne *big.Int + var balanceFee *big.Int if tx.To() != nil { if value, ok := feeCapacity[*tx.To()]; ok { - balacne = value + balanceFee = value } } - msg, _ := tx.AsMessage(signer, balacne, task.block.Number()) + msg, _ := core.TransactionToMessage(tx, signer, balanceFee) vmctx := core.NewEVMContext(msg, task.block.Header(), api.eth.blockchain, nil) res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) @@ -438,13 +438,13 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, // Fetch and execute the next transaction trace tasks for task := range jobs { feeCapacity := state.GetTRC21FeeCapacityFromState(task.statedb) - var balacne *big.Int + var balanceFee *big.Int if txs[task.index].To() != nil { if value, ok := feeCapacity[*txs[task.index].To()]; ok { - balacne = value + balanceFee = value } } - msg, _ := txs[task.index].AsMessage(signer, balacne, block.Number()) + msg, _ := core.TransactionToMessage(txs[task.index], signer, balanceFee) vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) @@ -462,19 +462,19 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, for i, tx := range txs { // Send the trace task over for execution jobs <- &txTraceTask{statedb: statedb.Copy(), index: i} - var balacne *big.Int + var balanceFee *big.Int if tx.To() != nil { if value, ok := feeCapacity[*tx.To()]; ok { - balacne = value + balanceFee = value } } // Generate the next state snapshot fast without tracing - msg, _ := tx.AsMessage(signer, balacne, 
block.Number()) + msg, _ := core.TransactionToMessage(tx, signer, balanceFee) vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) vmenv := vm.NewEVM(vmctx, statedb, tomoxState, api.config, vm.Config{}) owner := common.Address{} - if _, _, _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(msg.Gas()), owner); err != nil { + if _, _, _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(msg.GasLimit), owner); err != nil { failed = err break } @@ -567,7 +567,7 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (* } size, _ := database.TrieDB().Size() log.Info("Historical state regenerated", "block", block.NumberU64(), "elapsed", time.Since(start), "size", size) - return statedb,tomoxState, nil + return statedb, tomoxState, nil } // TraceTransaction returns the structured logs created during the execution of EVM @@ -593,7 +593,7 @@ func (api *PrivateDebugAPI) TraceTransaction(ctx context.Context, hash common.Ha // traceTx configures a new tracer according to the provided configuration, and // executes the given message in the provided environment. The return value will // be tracer dependent. -func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, vmctx vm.Context, statedb *state.StateDB, config *TraceConfig) (interface{}, error) { +func (api *PrivateDebugAPI) traceTx(ctx context.Context, message *core.Message, vmctx vm.Context, statedb *state.StateDB, config *TraceConfig) (interface{}, error) { // Assemble the structured logger or the JavaScript tracer var ( tracer vm.Tracer @@ -630,7 +630,7 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v vmenv := vm.NewEVM(vmctx, statedb, nil, api.config, vm.Config{Debug: true, Tracer: tracer}) owner := common.Address{} - ret, gas, failed, err := core.ApplyMessage(vmenv, message, new(core.GasPool).AddGas(message.Gas()), owner) + ret, gas, failed, err := core.ApplyMessage(vmenv, message, new(core.GasPool).AddGas(message.GasLimit), owner) if err != nil { return nil, fmt.Errorf("tracing failed: %v", err) } @@ -653,7 +653,7 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v } // computeTxEnv returns the execution environment of a certain transaction. 
-func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int, reexec uint64) (core.Message, vm.Context, *state.StateDB, error) { +func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int, reexec uint64) (*core.Message, vm.Context, *state.StateDB, error) { // Create the parent state database block := api.eth.blockchain.GetBlockByHash(blockHash) if block == nil { @@ -687,7 +687,7 @@ func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int, ree balanceFee = value } } - msg, err := tx.AsMessage(types.MakeSigner(api.config, block.Header().Number), balanceFee, block.Number()) + msg, err := core.TransactionToMessage(tx, types.MakeSigner(api.config, block.Header().Number), balanceFee) if err != nil { return nil, vm.Context{}, nil, fmt.Errorf("tx %x failed: %v", tx.Hash(), err) } diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 33376a1071..35e8c723b2 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -21,14 +21,11 @@ import ( "context" "errors" "fmt" - "github.com/tomochain/tomochain/tomoxlending/lendingstate" "math/big" "sort" "strings" "time" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/util" "github.com/tomochain/tomochain/accounts" @@ -50,6 +47,8 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending/lendingstate" ) const ( @@ -424,7 +423,8 @@ func (s *PrivateAccountAPI) SignTransaction(ctx context.Context, args SendTxArgs // safely used to calculate a signature from. // // The hash is calculated as -// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}). +// +// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}). // // This gives context to the signed message and prevents signing of transactions. func signHash(data []byte) []byte { @@ -1052,7 +1052,17 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr balanceTokenFee := big.NewInt(0).SetUint64(gas) balanceTokenFee = balanceTokenFee.Mul(balanceTokenFee, gasPrice) // Create new call message - msg := types.NewMessage(addr, args.To, 0, args.Value.ToInt(), gas, gasPrice, args.Data, false, balanceTokenFee) + msg := &core.Message{ + To: args.To, + From: addr, + Nonce: 0, + Value: args.Value.ToInt(), + GasLimit: gas, + GasPrice: gasPrice, + Data: args.Data, + BalanceTokenFee: balanceTokenFee, + SkipAccountChecks: false, + } // Setup context so it may be cancelled when the call has completed // or, in case of unmetered gas, setup a context with a timeout. @@ -1305,8 +1315,8 @@ func (s *PublicBlockChainAPI) findNearestSignedBlock(ctx context.Context, b *typ } /* - findFinalityOfBlock return finality of a block - Use blocksHashCache for to keep track - refer core/blockchain.go for more detail +findFinalityOfBlock return finality of a block +Use blocksHashCache to keep track - refer to core/blockchain.go for more detail */ func (s *PublicBlockChainAPI) findFinalityOfBlock(ctx context.Context, b *types.Block, masternodes []common.Address) (uint, error) { engine, _ := s.b.GetEngine().(*posv.Posv) @@ -1371,7 +1381,7 @@ func (s *PublicBlockChainAPI) findFinalityOfBlock(ctx context.Context, b *types.
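// For reference, the prefixed hash documented for signHash earlier in this
// file can be computed as below; a sketch consistent with the documented
// format, using crypto.Keccak256 as elsewhere in this codebase:
//
//	prefixed := fmt.Sprintf("\x19Ethereum Signed Message:\n%d%s", len(data), data)
//	hash := crypto.Keccak256([]byte(prefixed))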
} /* - Extract signers from block +Extract signers from block */ func (s *PublicBlockChainAPI) getSigners(ctx context.Context, block *types.Block, engine *posv.Posv) ([]common.Address, error) { var err error @@ -2965,7 +2975,8 @@ func GetSignersFromBlocks(b Backend, blockNumber uint64, blockHash common.Hash, // GetStakerROI Estimate ROI for stakers using the last epoch reward // then multiply by epochs per year, if the address is not masternode of last epoch - return 0 // Formula: -// ROI = average_latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100 +// +// ROI = average_latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100 func (s *PublicBlockChainAPI) GetStakerROI() float64 { blockNumber := s.b.CurrentBlock().Number().Uint64() lastCheckpointNumber := blockNumber - (blockNumber % s.b.ChainConfig().Posv.Epoch) - s.b.ChainConfig().Posv.Epoch // calculate for 2 epochs ago @@ -2991,7 +3002,8 @@ func (s *PublicBlockChainAPI) GetStakerROI() float64 { // GetStakerROIMasternode Estimate ROI for stakers of a specific masternode using the last epoch reward // then multiply by epochs per year, if the address is not masternode of last epoch - return 0 // Formula: -// ROI = latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100 +// +// ROI = latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100 func (s *PublicBlockChainAPI) GetStakerROIMasternode(masternode common.Address) float64 { votersReward := s.b.GetVotersRewards(masternode) if votersReward == nil { diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go index 16edc3a17f..9a197d4e2a 100644 --- a/internal/ethapi/backend.go +++ b/internal/ethapi/backend.go @@ -19,12 +19,8 @@ package ethapi import ( "context" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending" "math/big" - "github.com/tomochain/tomochain/tomox" - "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" @@ -38,6 +34,9 @@ import ( "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending" ) // Backend interface provides the common API services (that are provided by @@ -61,7 +60,7 @@ type Backend interface { GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) GetTd(blockHash common.Hash) *big.Int - GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) + GetEVM(ctx context.Context, msg *core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription SubscribeChainSideEvent(ch chan<- core.ChainSideEvent) event.Subscription diff --git a/les/api_backend.go b/les/api_backend.go index d8285da97d..d5ba3c2c91 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -20,14 +20,10 @@ import ( "context" "encoding/json" "errors" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending"
"io/ioutil" "math/big" "path/filepath" - "github.com/tomochain/tomochain/tomox" - "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" @@ -45,6 +41,9 @@ import ( "github.com/tomochain/tomochain/light" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending" ) type LesApiBackend struct { @@ -105,8 +104,8 @@ func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } -func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { - state.SetBalance(msg.From(), math.MaxBig256) +func (b *LesApiBackend) GetEVM(ctx context.Context, msg *core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { + state.SetBalance(msg.From, math.MaxBig256) context := core.NewEVMContext(msg, header, b.eth.blockchain, nil) return vm.NewEVM(context, state, tomoxState, b.eth.chainConfig, vmCfg), state.Error, nil } diff --git a/tests/state_test_util.go b/tests/state_test_util.go index e532aa8a46..217d519c89 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -19,8 +19,8 @@ package tests import ( "encoding/hex" "encoding/json" + "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "strings" @@ -28,8 +28,8 @@ import ( "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" - "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/crypto/sha3" @@ -190,7 +190,7 @@ func (t *StateTest) genesis(config *params.ChainConfig) *core.Genesis { } } -func (tx *stTransaction) toMessage(ps stPostState) (core.Message, error) { +func (tx *stTransaction) toMessage(ps stPostState) (*core.Message, error) { // Derive sender from private key if present. var from common.Address if len(tx.PrivateKey) > 0 { @@ -235,7 +235,21 @@ func (tx *stTransaction) toMessage(ps stPostState) (core.Message, error) { if err != nil { return nil, fmt.Errorf("invalid tx data %q", dataHex) } - msg := types.NewMessage(from, to, tx.Nonce, value, gasLimit, tx.GasPrice, data, true, nil) + // If baseFee provided, set gasPrice to effectiveGasPrice. 
+ gasPrice := tx.GasPrice + if gasPrice == nil { + return nil, errors.New("no gas price provided") + } + + msg := &core.Message{ + From: from, + To: to, + Nonce: tx.Nonce, + Value: value, + GasLimit: gasLimit, + GasPrice: gasPrice, + Data: data, + } return msg, nil } From 4148d3f55cb47e6cacd0006d0433a6277efdadad Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Sat, 15 Jul 2023 01:08:07 +0700 Subject: [PATCH 016/119] Calculate TRC21GasPrice when convert a transaction to a core.Message --- core/database_util.go | 8 ++++---- core/genesis.go | 12 ++++++------ core/headerchain.go | 11 ++++++----- core/state_processor.go | 2 +- core/state_transition.go | 9 ++++++++- eth/api_tracer.go | 8 ++++---- 6 files changed, 29 insertions(+), 21 deletions(-) diff --git a/core/database_util.go b/core/database_util.go index a5ab18687d..82f17df197 100644 --- a/core/database_util.go +++ b/core/database_util.go @@ -22,10 +22,10 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" @@ -100,16 +100,16 @@ func GetCanonicalHash(db DatabaseReader, number uint64) common.Hash { return common.BytesToHash(data) } -// missingNumber is returned by GetBlockNumber if no header with the +// MissingNumber is returned by GetBlockNumber if no header with the // given block hash has been stored in the database -const missingNumber = uint64(0xffffffffffffffff) +const MissingNumber = uint64(0xffffffffffffffff) // GetBlockNumber returns the block number assigned to a block hash // if the corresponding header is present in the database func GetBlockNumber(db DatabaseReader, hash common.Hash) uint64 { data, _ := db.Get(append(blockHashPrefix, hash.Bytes()...)) if len(data) != 8 { - return missingNumber + return MissingNumber } return binary.BigEndian.Uint64(data) } diff --git a/core/genesis.go b/core/genesis.go index e1b7185a41..fcd196bd47 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -22,13 +22,13 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "strings" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" @@ -140,10 +140,10 @@ func (e *GenesisMismatchError) Error() string { // SetupGenesisBlock writes or updates the genesis block in db. // The block that will be used is: // -// genesis == nil genesis != nil -// +------------------------------------------ -// db has no genesis | main-net default | genesis -// db has genesis | from DB | genesis (if compatible) +// genesis == nil genesis != nil +// +------------------------------------------ +// db has no genesis | main-net default | genesis +// db has genesis | from DB | genesis (if compatible) // // The stored chain configuration will be updated if it is compatible (i.e. does not // specify a fork block below the local head block). In case of a conflict, the @@ -197,7 +197,7 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig // Check config compatibility and write the config. Compatibility errors // are returned to the caller unless we're already at block zero. 
height := GetBlockNumber(db, GetHeadHeaderHash(db)) - if height == missingNumber { + if height == MissingNumber { return newcfg, stored, fmt.Errorf("missing block number for head header hash") } compatErr := storedcfg.CheckCompatible(newcfg, height) diff --git a/core/headerchain.go b/core/headerchain.go index 8365f2127d..f3cc8cf77b 100644 --- a/core/headerchain.go +++ b/core/headerchain.go @@ -26,7 +26,7 @@ import ( "sync/atomic" "time" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core/types" @@ -66,9 +66,10 @@ type HeaderChain struct { } // NewHeaderChain creates a new HeaderChain structure. -// getValidator should return the parent's validator -// procInterrupt points to the parent's interrupt semaphore -// wg points to the parent's shutdown wait group +// +// getValidator should return the parent's validator +// procInterrupt points to the parent's interrupt semaphore +// wg points to the parent's shutdown wait group func NewHeaderChain(chainDb ethdb.Database, config *params.ChainConfig, engine consensus.Engine, procInterrupt func() bool) (*HeaderChain, error) { headerCache, _ := lru.New(headerCacheLimit) tdCache, _ := lru.New(tdCacheLimit) @@ -114,7 +115,7 @@ func (hc *HeaderChain) GetBlockNumber(hash common.Hash) uint64 { return cached.(uint64) } number := GetBlockNumber(hc.chainDb, hash) - if number != missingNumber { + if number != MissingNumber { hc.numberCache.Add(hash, number) } return number diff --git a/core/state_processor.go b/core/state_processor.go index f3da43717a..de70802235 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -242,7 +242,7 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* balanceFee = value } } - msg, err := TransactionToMessage(tx, types.MakeSigner(config, header.Number), balanceFee) + msg, err := TransactionToMessage(tx, types.MakeSigner(config, header.Number), balanceFee, header.Number) if err != nil { return nil, 0, err, false } diff --git a/core/state_transition.go b/core/state_transition.go index 0f396cac92..402ea9fbf0 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -129,7 +129,7 @@ func NewStateTransition(evm *vm.EVM, msg *Message, gp *GasPool) *StateTransition } // TransactionToMessage converts a transaction into a Message. 
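// The signature change below threads the block number into
// TransactionToMessage so the TRC21 gas-price override can be applied during
// conversion. The selection rule, sketched with the constants the diff
// references:
//
//	gasPrice := new(big.Int).Set(tx.GasPrice())
//	if balanceFee != nil { // fee is paid in a TRC21 token
//		if number.Cmp(common.TIPTRC21Fee) > 0 {
//			gasPrice = common.TRC21GasPrice
//		} else {
//			gasPrice = common.TRC21GasPriceBefore
//		}
//	}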
-func TransactionToMessage(tx *types.Transaction, s types.Signer, balanceFee *big.Int) (*Message, error) { +func TransactionToMessage(tx *types.Transaction, s types.Signer, balanceFee *big.Int, number *big.Int) (*Message, error) { msg := &Message{ Nonce: tx.Nonce(), GasLimit: tx.Gas(), @@ -142,6 +142,13 @@ func TransactionToMessage(tx *types.Transaction, s types.Signer, balanceFee *big } var err error msg.From, err = types.Sender(s, tx) + if balanceFee != nil { + if number.Cmp(common.TIPTRC21Fee) > 0 { + msg.GasPrice = common.TRC21GasPrice + } else { + msg.GasPrice = common.TRC21GasPriceBefore + } + } return msg, err } diff --git a/eth/api_tracer.go b/eth/api_tracer.go index 4581ec8d08..6d0da2466e 100644 --- a/eth/api_tracer.go +++ b/eth/api_tracer.go @@ -204,7 +204,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl balanceFee = value } } - msg, _ := core.TransactionToMessage(tx, signer, balanceFee) + msg, _ := core.TransactionToMessage(tx, signer, balanceFee, task.block.Number()) vmctx := core.NewEVMContext(msg, task.block.Header(), api.eth.blockchain, nil) res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) @@ -444,7 +444,7 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, balanceFee = value } } - msg, _ := core.TransactionToMessage(txs[task.index], signer, balanceFee) + msg, _ := core.TransactionToMessage(txs[task.index], signer, balanceFee, block.Number()) vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) @@ -469,7 +469,7 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, } } // Generate the next state snapshot fast without tracing - msg, _ := core.TransactionToMessage(tx, signer, balanceFee) + msg, _ := core.TransactionToMessage(tx, signer, balanceFee, block.Number()) vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) vmenv := vm.NewEVM(vmctx, statedb, tomoxState, api.config, vm.Config{}) @@ -687,7 +687,7 @@ func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int, ree balanceFee = value } } - msg, err := core.TransactionToMessage(tx, types.MakeSigner(api.config, block.Header().Number), balanceFee) + msg, err := core.TransactionToMessage(tx, types.MakeSigner(api.config, block.Header().Number), balanceFee, block.Number()) if err != nil { return nil, vm.Context{}, nil, fmt.Errorf("tx %x failed: %v", tx.Hash(), err) } From 47b9f539636718381b8921fb1de6394b845d88ae Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Sat, 15 Jul 2023 01:08:24 +0700 Subject: [PATCH 017/119] Update unit tests --- cmd/tomo/bugcmd.go | 3 +-- cmd/tomo/consolecmd_test.go | 4 +-- console/console_test.go | 15 ++++++----- eth/tracers/tracers_test.go | 7 ++--- les/odr_test.go | 34 +++++++++++++++++------- light/odr_test.go | 52 ++++++++++++++++++++++++++----------- p2p/discover/node_test.go | 2 +- p2p/discv5/node_test.go | 2 +- tests/state_test.go | 3 +++ tests/vm_test.go | 7 +++-- 10 files changed, 86 insertions(+), 43 deletions(-) diff --git a/cmd/tomo/bugcmd.go b/cmd/tomo/bugcmd.go index 3174f73881..5cec10ad49 100644 --- a/cmd/tomo/bugcmd.go +++ b/cmd/tomo/bugcmd.go @@ -105,5 +105,4 @@ const header = `Please answer these questions before submitting your issue. Than #### What did you see instead? 
-#### System details -` +#### System details` diff --git a/cmd/tomo/consolecmd_test.go b/cmd/tomo/consolecmd_test.go index 241373f521..894f55c698 100644 --- a/cmd/tomo/consolecmd_test.go +++ b/cmd/tomo/consolecmd_test.go @@ -52,7 +52,7 @@ func TestConsoleWelcome(t *testing.T) { tomo.SetTemplateFunc("goarch", func() string { return runtime.GOARCH }) tomo.SetTemplateFunc("gover", runtime.Version) tomo.SetTemplateFunc("tomover", func() string { return params.Version }) - tomo.SetTemplateFunc("niltime", func() string { return time.Unix(1544771829, 0).Format(time.RFC1123) }) + tomo.SetTemplateFunc("niltime", func() string { return time.Unix(1544771829, 0).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (MST)") }) tomo.SetTemplateFunc("apis", func() string { return ipcAPIs }) // Verify the actual welcome message to the required template @@ -137,7 +137,7 @@ func testAttachWelcome(t *testing.T, tomo *testtomo, endpoint, apis string) { attach.SetTemplateFunc("gover", runtime.Version) attach.SetTemplateFunc("tomover", func() string { return params.Version }) attach.SetTemplateFunc("etherbase", func() string { return tomo.Etherbase }) - attach.SetTemplateFunc("niltime", func() string { return time.Unix(1544771829, 0).Format(time.RFC1123) }) + attach.SetTemplateFunc("niltime", func() string { return time.Unix(1544771829, 0).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (MST)") }) attach.SetTemplateFunc("ipc", func() bool { return strings.HasPrefix(endpoint, "ipc") }) attach.SetTemplateFunc("datadir", func() string { return tomo.Datadir }) attach.SetTemplateFunc("apis", func() string { return apis }) diff --git a/console/console_test.go b/console/console_test.go index 22527f4ddc..98f85c4b43 100644 --- a/console/console_test.go +++ b/console/console_test.go @@ -19,8 +19,6 @@ package console import ( "bytes" "errors" - "github.com/tomochain/tomochain/tomox" - "github.com/tomochain/tomochain/tomoxlending" "io/ioutil" "os" "strings" @@ -29,10 +27,13 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" + "github.com/tomochain/tomochain/console/prompt" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/eth" "github.com/tomochain/tomochain/internal/jsre" "github.com/tomochain/tomochain/node" + "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomoxlending" ) const ( @@ -67,10 +68,10 @@ func (p *hookedPrompter) PromptPassword(prompt string) (string, error) { func (p *hookedPrompter) PromptConfirm(prompt string) (bool, error) { return false, errors.New("not implemented") } -func (p *hookedPrompter) SetHistory(history []string) {} -func (p *hookedPrompter) AppendHistory(command string) {} -func (p *hookedPrompter) ClearHistory() {} -func (p *hookedPrompter) SetWordCompleter(completer WordCompleter) {} +func (p *hookedPrompter) SetHistory(history []string) {} +func (p *hookedPrompter) AppendHistory(command string) {} +func (p *hookedPrompter) ClearHistory() {} +func (p *hookedPrompter) SetWordCompleter(completer prompt.WordCompleter) {} // tester is a console test environment for the console tests to operate on. 
 type tester struct {
@@ -262,7 +263,7 @@ func TestPrettyError(t *testing.T) {
 	defer tester.Close(t)
 	tester.console.Evaluate("throw 'hello'")
 
-	want := jsre.ErrorColor("hello") + "\n"
+	want := jsre.ErrorColor("hello") + "\n\tat <eval>:1:1(1)\n\n"
 	if output := tester.output.String(); output != want {
 		t.Fatalf("pretty error mismatch: have %s, want %s", output, want)
 	}
diff --git a/eth/tracers/tracers_test.go b/eth/tracers/tracers_test.go
index 38d4075175..9f469aeb89 100644
--- a/eth/tracers/tracers_test.go
+++ b/eth/tracers/tracers_test.go
@@ -20,17 +20,18 @@ import (
 	"crypto/ecdsa"
 	"crypto/rand"
 	"encoding/json"
-	"github.com/tomochain/tomochain/core/rawdb"
 	"io/ioutil"
 	"math/big"
 	"path/filepath"
 	"reflect"
 	"strings"
 	"testing"
 
+	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/common/hexutil"
 	"github.com/tomochain/tomochain/common/math"
 	"github.com/tomochain/tomochain/core"
+	"github.com/tomochain/tomochain/core/rawdb"
 	"github.com/tomochain/tomochain/core/types"
 	"github.com/tomochain/tomochain/core/vm"
 	"github.com/tomochain/tomochain/crypto"
@@ -178,7 +179,7 @@ func TestPrestateTracerCreate2(t *testing.T) {
 	}
 	evm := vm.NewEVM(context, statedb, nil, params.MainnetChainConfig, vm.Config{Debug: true, Tracer: tracer})
 
-	msg, err := tx.AsMessage(signer, nil, nil)
+	msg, err := core.TransactionToMessage(tx, signer, nil, nil)
 	if err != nil {
 		t.Fatalf("failed to prepare transaction for tracing: %v", err)
 	}
@@ -253,7 +254,7 @@ func TestCallTracer(t *testing.T) {
 	}
 	evm := vm.NewEVM(context, statedb, nil, test.Genesis.Config, vm.Config{Debug: true, Tracer: tracer})
 
-	msg, err := tx.AsMessage(signer, nil, common.Big0)
+	msg, err := core.TransactionToMessage(tx, signer, nil, common.Big0)
 	if err != nil {
 		t.Fatalf("failed to prepare transaction for tracing: %v", err)
 	}
diff --git a/les/odr_test.go b/les/odr_test.go
index 3858e34028..4e95ecdb5d 100644
--- a/les/odr_test.go
+++ b/les/odr_test.go
@@ -19,7 +19,6 @@ package les
 import (
 	"bytes"
 	"context"
-	"github.com/tomochain/tomochain/core/rawdb"
 	"math/big"
 	"testing"
 	"time"
@@ -27,6 +26,7 @@ import (
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/common/math"
 	"github.com/tomochain/tomochain/core"
+	"github.com/tomochain/tomochain/core/rawdb"
 	"github.com/tomochain/tomochain/core/state"
 	"github.com/tomochain/tomochain/core/types"
 	"github.com/tomochain/tomochain/core/vm"
@@ -109,12 +109,6 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon
 //
 //func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, odrContractCall) }
 
-type callmsg struct {
-	types.Message
-}
-
-func (callmsg) CheckNonce() bool { return false }
-
 func odrContractCall(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
 	data := common.Hex2Bytes("60CD26850000000000000000000000000000000000000000000000000000000000000000")
 
@@ -133,8 +127,18 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
 		if value, ok := feeCapacity[testContractAddr]; ok {
 			balanceTokenFee = value
 		}
-		msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, false, balanceTokenFee)}
-
+		fromAddr := from.Address()
+		msg := &core.Message{
+			From:              fromAddr,
+			To:                &testContractAddr,
+			Nonce:             0,
+			Value:             new(big.Int),
+			GasLimit:          100000,
+			GasPrice:          new(big.Int),
+			Data:              data,
+			SkipAccountChecks: false,
+			BalanceTokenFee:   balanceTokenFee,
+		}
 		context := core.NewEVMContext(msg, header, bc, nil)
 		vmenv := vm.NewEVM(context, statedb, nil, config, vm.Config{})
@@ -153,7 +157,17 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
 		if value, ok := feeCapacity[testContractAddr]; ok {
 			balanceTokenFee = value
 		}
-		msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, false, balanceTokenFee)}
+		msg := &core.Message{
+			From:              testBankAddress,
+			To:                &testContractAddr,
+			Nonce:             0,
+			Value:             new(big.Int),
+			GasLimit:          100000,
+			GasPrice:          new(big.Int),
+			Data:              data,
+			SkipAccountChecks: false,
+			BalanceTokenFee:   balanceTokenFee,
+		}
 		context := core.NewEVMContext(msg, header, lc, nil)
 		vmenv := vm.NewEVM(context, statedb, nil, config, vm.Config{})
 		gp := new(core.GasPool).AddGas(math.MaxUint64)
diff --git a/light/odr_test.go b/light/odr_test.go
index 0c5fc78573..43e7eaaea3 100644
--- a/light/odr_test.go
+++ b/light/odr_test.go
@@ -20,16 +20,16 @@ import (
 	"bytes"
 	"context"
 	"errors"
-	"github.com/tomochain/tomochain/consensus"
-	"github.com/tomochain/tomochain/core/rawdb"
 	"math/big"
 	"testing"
 	"time"
 
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/common/math"
+	"github.com/tomochain/tomochain/consensus"
 	"github.com/tomochain/tomochain/consensus/ethash"
 	"github.com/tomochain/tomochain/core"
+	"github.com/tomochain/tomochain/core/rawdb"
 	"github.com/tomochain/tomochain/core/state"
 	"github.com/tomochain/tomochain/core/types"
 	"github.com/tomochain/tomochain/core/vm"
@@ -43,7 +43,7 @@ import (
 var (
 	testBankKey, _  = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
 	testBankAddress = crypto.PubkeyToAddress(testBankKey.PublicKey)
-	testBankFunds   = big.NewInt(100000000)
+	testBankFunds   = big.NewInt(1_000_000_000_000_000_000)
 
 	acc1Key, _ = crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
 	acc2Key, _ = crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee")
@@ -74,7 +74,10 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error {
 	case *BlockRequest:
 		req.Rlp = core.GetBodyRLP(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash))
 	case *ReceiptsRequest:
-		req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash))
+		number := core.GetBlockNumber(odr.sdb, req.Hash)
+		if number != core.MissingNumber {
+			req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, number)
+		}
 	case *TrieRequest:
 		t, _ := trie.New(req.Id.Root, trie.NewDatabase(odr.sdb))
 		nodes := NewNodeSet()
@@ -110,9 +113,16 @@ func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, odrGetReceipts) }
 func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) {
 	var receipts types.Receipts
 	if bc != nil {
-		receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash))
+		if number := core.GetBlockNumber(db, bhash); number != core.MissingNumber {
+			if block := core.GetBlock(db, bhash, number); block != nil {
+				receipts = core.GetBlockReceipts(db, bhash, number)
+			}
+		}
 	} else {
-		receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash))
+		number := core.GetBlockNumber(db, bhash)
+		if number != core.MissingNumber {
+			receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, number)
+		}
 	}
 	if receipts == nil {
 		return nil, nil
@@ -148,7 +158,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc
 
 func TestOdrContractCallLes1(t *testing.T) {
testChainOdr(t, 1, odrContractCall) } type callmsg struct { - types.Message + core.Message } func (callmsg) CheckNonce() bool { return false } @@ -183,7 +193,16 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain if value, ok := feeCapacity[testContractAddr]; ok { balanceTokenFee = value } - msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 1000000, new(big.Int), data, false, balanceTokenFee)} + msg := &core.Message{ + From: testBankAddress, + To: &testContractAddr, + Value: new(big.Int), + GasLimit: 1000000, + GasPrice: new(big.Int), + Data: data, + SkipAccountChecks: true, + BalanceTokenFee: balanceTokenFee, + } context := core.NewEVMContext(msg, header, chain, nil) vmenv := vm.NewEVM(context, st, nil, config, vm.Config{}) gp := new(core.GasPool).AddGas(math.MaxUint64) @@ -202,17 +221,17 @@ func testChainGen(i int, block *core.BlockGen) { switch i { case 0: // In block 1, the test bank sends account #1 some ether. - tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testBankKey) + tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(10_000_000_000_000_000), params.TxGas, nil, nil), signer, testBankKey) block.AddTx(tx) case 1: // In block 2, the test bank sends some more ether to account #1. // acc1Addr passes it on to account #2. // acc1Addr creates a test contract. - tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testBankKey) + tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(1_000_000_000_000_000), params.TxGas, nil, nil), signer, testBankKey) nonce := block.TxNonce(acc1Addr) - tx2, _ := types.SignTx(types.NewTransaction(nonce, acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key) + tx2, _ := types.SignTx(types.NewTransaction(nonce, acc2Addr, big.NewInt(1_000_000_000_000_000), params.TxGas, nil, nil), signer, acc1Key) nonce++ - tx3, _ := types.SignTx(types.NewContractCreation(nonce, big.NewInt(0), 1000000, big.NewInt(0), testContractCode), signer, acc1Key) + tx3, _ := types.SignTx(types.NewContractCreation(nonce, big.NewInt(0), 1000000, nil, testContractCode), signer, acc1Key) testContractAddr = crypto.CreateAddress(acc1Addr, nonce) block.AddTx(tx1) block.AddTx(tx2) @@ -240,9 +259,12 @@ func testChainGen(i int, block *core.BlockGen) { func testChainOdr(t *testing.T, protocol int, fn odrTestFn) { var ( - sdb = rawdb.NewMemoryDatabase() - ldb = rawdb.NewMemoryDatabase() - gspec = core.Genesis{Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}} + sdb = rawdb.NewMemoryDatabase() + ldb = rawdb.NewMemoryDatabase() + gspec = core.Genesis{ + Config: params.TestChainConfig, + Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}, + } genesis = gspec.MustCommit(sdb) ) gspec.MustCommit(ldb) diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go index 8e3da2c2aa..ddf8a7bd98 100644 --- a/p2p/discover/node_test.go +++ b/p2p/discover/node_test.go @@ -142,7 +142,7 @@ var parseNodeTests = []struct { { // This test checks that errors from url.Parse are handled. 
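+		// Newer Go releases quote the raw URL in url.Error messages,
+		// hence the updated expectation below.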
rawurl: "://foo", - wantError: `parse ://foo: missing protocol scheme`, + wantError: `parse "://foo": missing protocol scheme`, }, } diff --git a/p2p/discv5/node_test.go b/p2p/discv5/node_test.go index a28f298252..d0fa6880a3 100644 --- a/p2p/discv5/node_test.go +++ b/p2p/discv5/node_test.go @@ -141,7 +141,7 @@ var parseNodeTests = []struct { { // This test checks that errors from url.Parse are handled. rawurl: "://foo", - wantError: `parse ://foo: missing protocol scheme`, + wantError: `parse "://foo": missing protocol scheme`, }, } diff --git a/tests/state_test.go b/tests/state_test.go index 7c8c5e9268..81a7370d60 100644 --- a/tests/state_test.go +++ b/tests/state_test.go @@ -26,6 +26,9 @@ import ( ) func TestState(t *testing.T) { + if testing.Short() { + t.Skip("skipping testing in short mode") + } t.Parallel() st := new(testMatcher) diff --git a/tests/vm_test.go b/tests/vm_test.go index 9e1f735436..7fda3cc6f5 100644 --- a/tests/vm_test.go +++ b/tests/vm_test.go @@ -17,15 +17,18 @@ package tests import ( - "github.com/tomochain/tomochain/common" "math/big" "testing" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/vm" ) func TestVM(t *testing.T) { - common.TIPTomoXCancellationFee=big.NewInt(100000000) + if testing.Short() { + t.Skip("skipping testing in short mode") + } + common.TIPTomoXCancellationFee = big.NewInt(100000000) t.Parallel() vmt := new(testMatcher) vmt.fails("^vmSystemOperationsTest.json/createNameRegistrator$", "fails without parallel execution") From 3b6e3f3dae249817ea3942298b388149e56c5384 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Sat, 15 Jul 2023 02:11:29 +0700 Subject: [PATCH 018/119] Add more information fields to transaction receipts and regen codec/RLP --- consensus/posv/posv.go | 9 +- core/types/gen_log_json.go | 2 + core/types/gen_log_rlp.go | 7 +- core/types/gen_receipt_json.go | 21 +++++ core/types/log.go | 2 +- core/types/receipt.go | 153 +++++++++++++++++++++++++++------ 6 files changed, 156 insertions(+), 38 deletions(-) diff --git a/consensus/posv/posv.go b/consensus/posv/posv.go index 0027104970..bd2b818d1c 100644 --- a/consensus/posv/posv.go +++ b/consensus/posv/posv.go @@ -21,9 +21,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending/lendingstate" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" "io/ioutil" "math/big" "math/rand" @@ -50,6 +47,10 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending/lendingstate" + + "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) const ( @@ -1146,7 +1147,7 @@ func (c *Posv) CacheData(header *types.Header, txs []*types.Transaction, receipt signTxs := []*types.Transaction{} for _, tx := range txs { if tx.IsSigningTransaction() { - var b uint + var b uint64 for _, r := range receipts { if r.TxHash == tx.Hash() { if len(r.PostState) > 0 { diff --git a/core/types/gen_log_json.go b/core/types/gen_log_json.go index 759ff8814c..ae61caf6b9 100644 --- a/core/types/gen_log_json.go +++ b/core/types/gen_log_json.go @@ -12,6 +12,7 @@ import ( var _ = (*logMarshaling)(nil) +// MarshalJSON marshals as JSON. 
func (l Log) MarshalJSON() ([]byte, error) { type Log struct { Address common.Address `json:"address" gencodec:"required"` @@ -37,6 +38,7 @@ func (l Log) MarshalJSON() ([]byte, error) { return json.Marshal(&enc) } +// UnmarshalJSON unmarshals from JSON. func (l *Log) UnmarshalJSON(input []byte) error { type Log struct { Address *common.Address `json:"address" gencodec:"required"` diff --git a/core/types/gen_log_rlp.go b/core/types/gen_log_rlp.go index 3f2c3ddc06..9301635297 100644 --- a/core/types/gen_log_rlp.go +++ b/core/types/gen_log_rlp.go @@ -5,11 +5,8 @@ package types -import ( - "io" - - "github.com/tomochain/tomochain/rlp" -) +import "github.com/tomochain/tomochain/rlp" +import "io" func (obj *rlpLog) EncodeRLP(_w io.Writer) error { w := rlp.NewEncoderBuffer(_w) diff --git a/core/types/gen_receipt_json.go b/core/types/gen_receipt_json.go index c698b9e36d..03494c8a6f 100644 --- a/core/types/gen_receipt_json.go +++ b/core/types/gen_receipt_json.go @@ -5,6 +5,7 @@ package types import ( "encoding/json" "errors" + "math/big" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" @@ -12,6 +13,7 @@ import ( var _ = (*receiptMarshaling)(nil) +// MarshalJSON marshals as JSON. func (r Receipt) MarshalJSON() ([]byte, error) { type Receipt struct { PostState hexutil.Bytes `json:"root"` @@ -22,6 +24,9 @@ func (r Receipt) MarshalJSON() ([]byte, error) { TxHash common.Hash `json:"transactionHash" gencodec:"required"` ContractAddress common.Address `json:"contractAddress"` GasUsed hexutil.Uint64 `json:"gasUsed" gencodec:"required"` + BlockHash common.Hash `json:"blockHash,omitempty"` + BlockNumber *hexutil.Big `json:"blockNumber,omitempty"` + TransactionIndex hexutil.Uint `json:"transactionIndex"` } var enc Receipt enc.PostState = r.PostState @@ -32,9 +37,13 @@ func (r Receipt) MarshalJSON() ([]byte, error) { enc.TxHash = r.TxHash enc.ContractAddress = r.ContractAddress enc.GasUsed = hexutil.Uint64(r.GasUsed) + enc.BlockHash = r.BlockHash + enc.BlockNumber = (*hexutil.Big)(r.BlockNumber) + enc.TransactionIndex = hexutil.Uint(r.TransactionIndex) return json.Marshal(&enc) } +// UnmarshalJSON unmarshals from JSON. 
func (r *Receipt) UnmarshalJSON(input []byte) error { type Receipt struct { PostState *hexutil.Bytes `json:"root"` @@ -45,6 +54,9 @@ func (r *Receipt) UnmarshalJSON(input []byte) error { TxHash *common.Hash `json:"transactionHash" gencodec:"required"` ContractAddress *common.Address `json:"contractAddress"` GasUsed *hexutil.Uint64 `json:"gasUsed" gencodec:"required"` + BlockHash *common.Hash `json:"blockHash,omitempty"` + BlockNumber *hexutil.Big `json:"blockNumber,omitempty"` + TransactionIndex *hexutil.Uint `json:"transactionIndex"` } var dec Receipt if err := json.Unmarshal(input, &dec); err != nil { @@ -79,5 +91,14 @@ func (r *Receipt) UnmarshalJSON(input []byte) error { return errors.New("missing required field 'gasUsed' for Receipt") } r.GasUsed = uint64(*dec.GasUsed) + if dec.BlockHash != nil { + r.BlockHash = *dec.BlockHash + } + if dec.BlockNumber != nil { + r.BlockNumber = (*big.Int)(dec.BlockNumber) + } + if dec.TransactionIndex != nil { + r.TransactionIndex = uint(*dec.TransactionIndex) + } return nil } diff --git a/core/types/log.go b/core/types/log.go index bee50763a8..93567b1e61 100644 --- a/core/types/log.go +++ b/core/types/log.go @@ -25,7 +25,7 @@ import ( "github.com/tomochain/tomochain/rlp" ) -//go:generate gencodec -type Log -field-override logMarshaling -out gen_log_json.go +//go:generate go run github.com/fjl/gencodec -type Log -field-override logMarshaling -out gen_log_json.go // Log represents a contract log event. These events are generated by the LOG opcode and // stored/indexed by the node. diff --git a/core/types/receipt.go b/core/types/receipt.go index 879aaf29c9..121c647a31 100644 --- a/core/types/receipt.go +++ b/core/types/receipt.go @@ -18,16 +18,20 @@ package types import ( "bytes" + "errors" "fmt" "io" + "math/big" "unsafe" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" ) -//go:generate gencodec -type Receipt -field-override receiptMarshaling -out gen_receipt_json.go +//go:generate go run github.com/fjl/gencodec -type Receipt -field-override receiptMarshaling -out gen_receipt_json.go var ( receiptStatusFailedRLP = []byte{} @@ -55,13 +59,21 @@ type Receipt struct { TxHash common.Hash `json:"transactionHash" gencodec:"required"` ContractAddress common.Address `json:"contractAddress"` GasUsed uint64 `json:"gasUsed" gencodec:"required"` + + // Inclusion information: These fields provide information about the inclusion of the + // transaction corresponding to this receipt. + BlockHash common.Hash `json:"blockHash,omitempty"` + BlockNumber *big.Int `json:"blockNumber,omitempty"` + TransactionIndex uint `json:"transactionIndex"` } type receiptMarshaling struct { PostState hexutil.Bytes - Status hexutil.Uint + Status hexutil.Uint64 CumulativeGasUsed hexutil.Uint64 GasUsed hexutil.Uint64 + BlockNumber *hexutil.Big + TransactionIndex hexutil.Uint } // receiptRLP is the consensus encoding of a receipt. @@ -72,7 +84,14 @@ type receiptRLP struct { Logs []*Log } -type receiptStorageRLP struct { +// StoredReceiptRLP is the storage encoding of a receipt. +type StoredReceiptRLP struct { + PostStateOrStatus []byte + CumulativeGasUsed uint64 + Logs []*Log +} + +type legacyStoredReceiptRLP struct { PostStateOrStatus []byte CumulativeGasUsed uint64 Bloom Bloom @@ -141,7 +160,6 @@ func (r *Receipt) statusEncoding() []byte { // to approximate and limit the memory consumption of various caches. 
 func (r *Receipt) Size() common.StorageSize {
 	size := common.StorageSize(unsafe.Sizeof(*r)) + common.StorageSize(len(r.PostState))
-
 	size += common.StorageSize(len(r.Logs)) * common.StorageSize(unsafe.Sizeof(Log{}))
 	for _, log := range r.Logs {
 		size += common.StorageSize(len(log.Topics)*common.HashLength + len(log.Data))
@@ -163,50 +181,129 @@ type ReceiptForStorage Receipt
 
 // EncodeRLP implements rlp.Encoder, and flattens all content fields of a receipt
 // into an RLP stream.
-func (r *ReceiptForStorage) EncodeRLP(w io.Writer) error {
-	enc := &receiptStorageRLP{
-		PostStateOrStatus: (*Receipt)(r).statusEncoding(),
-		CumulativeGasUsed: r.CumulativeGasUsed,
-		Bloom:             r.Bloom,
-		TxHash:            r.TxHash,
-		ContractAddress:   r.ContractAddress,
-		Logs:              make([]*LogForStorage, len(r.Logs)),
-		GasUsed:           r.GasUsed,
-	}
-	for i, log := range r.Logs {
-		enc.Logs[i] = (*LogForStorage)(log)
+func (r *ReceiptForStorage) EncodeRLP(_w io.Writer) error {
+	w := rlp.NewEncoderBuffer(_w)
+	outerList := w.List()
+	w.WriteBytes((*Receipt)(r).statusEncoding())
+	w.WriteUint64(r.CumulativeGasUsed)
+	logList := w.List()
+	for _, log := range r.Logs {
+		if err := rlp.Encode(w, log); err != nil {
+			return err
+		}
 	}
-	return rlp.Encode(w, enc)
+	w.ListEnd(logList)
+	w.ListEnd(outerList)
+	return w.Flush()
 }
 
 // DecodeRLP implements rlp.Decoder, and loads both consensus and implementation
 // fields of a receipt from an RLP stream.
 func (r *ReceiptForStorage) DecodeRLP(s *rlp.Stream) error {
-	var dec receiptStorageRLP
-	if err := s.Decode(&dec); err != nil {
+	// Retrieve the entire receipt blob as we need to try multiple decoders
+	blob, err := s.Raw()
+	if err != nil {
+		return err
+	}
+	// Try decoding the current storage format first for future proofing, then
+	// fall back to the legacy format still present in older databases.
+	if err := decodeStoredReceiptRLP(r, blob); err == nil {
+		return nil
+	}
+	return decodeLegacyStoredReceiptRLP(r, blob)
+}
+
+func decodeStoredReceiptRLP(r *ReceiptForStorage, blob []byte) error {
+	var stored StoredReceiptRLP
+	if err := rlp.DecodeBytes(blob, &stored); err != nil {
 		return err
 	}
-	if err := (*Receipt)(r).setStatus(dec.PostStateOrStatus); err != nil {
+	if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil {
 		return err
 	}
-	// Assign the consensus fields
-	r.CumulativeGasUsed, r.Bloom = dec.CumulativeGasUsed, dec.Bloom
-	r.Logs = make([]*Log, len(dec.Logs))
-	for i, log := range dec.Logs {
+	r.CumulativeGasUsed = stored.CumulativeGasUsed
+	r.Logs = stored.Logs
+	r.Bloom = CreateBloom(Receipts{(*Receipt)(r)})
+
+	return nil
+}
+
+func decodeLegacyStoredReceiptRLP(r *ReceiptForStorage, blob []byte) error {
+	var stored legacyStoredReceiptRLP
+	if err := rlp.DecodeBytes(blob, &stored); err != nil {
+		return err
+	}
+	if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil {
+		return err
+	}
+	r.CumulativeGasUsed = stored.CumulativeGasUsed
+	r.TxHash = stored.TxHash
+	r.ContractAddress = stored.ContractAddress
+	r.GasUsed = stored.GasUsed
+	r.Logs = make([]*Log, len(stored.Logs))
+	for i, log := range stored.Logs {
 		r.Logs[i] = (*Log)(log)
 	}
-	// Assign the implementation fields
-	r.TxHash, r.ContractAddress, r.GasUsed = dec.TxHash, dec.ContractAddress, dec.GasUsed
+	r.Bloom = CreateBloom(Receipts{(*Receipt)(r)})
+
 	return nil
 }
 
-// Receipts is a wrapper around a Receipt array to implement DerivableList.
+// Receipts implements DerivableList for receipts.
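+// DeriveSha consumes the list via Len and GetRlp when computing a block's
+// receipt trie root.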
 type Receipts []*Receipt
 
 // Len returns the number of receipts in this list.
 func (r Receipts) Len() int { return len(r) }
 
-// GetRlp returns the RLP encoding of one receipt from the list.
+// DeriveFields fills the receipts with their computed fields based on consensus
+// data and contextual info such as the containing block and its transactions.
+func (rs Receipts) DeriveFields(config *params.ChainConfig, hash common.Hash, number uint64, txs []*Transaction) error {
+	signer := MakeSigner(config, new(big.Int).SetUint64(number))
+
+	logIndex := uint(0)
+	if len(txs) != len(rs) {
+		return errors.New("transaction and receipt count mismatch")
+	}
+	for i := 0; i < len(rs); i++ {
+		// The transaction type and hash can be retrieved from the transaction itself
+		rs[i].TxHash = txs[i].Hash()
+
+		// block location fields
+		rs[i].BlockHash = hash
+		rs[i].BlockNumber = new(big.Int).SetUint64(number)
+		rs[i].TransactionIndex = uint(i)
+
+		// The contract address can be derived from the transaction itself
+		if txs[i].To() == nil {
+			// Deriving the signer is expensive, only do if it's actually needed
+			from, _ := Sender(signer, txs[i])
+			rs[i].ContractAddress = crypto.CreateAddress(from, txs[i].Nonce())
+		} else {
+			rs[i].ContractAddress = common.Address{}
+		}
+
+		// The used gas can be calculated from the previous receipt's cumulative gas used
+		if i == 0 {
+			rs[i].GasUsed = rs[i].CumulativeGasUsed
+		} else {
+			rs[i].GasUsed = rs[i].CumulativeGasUsed - rs[i-1].CumulativeGasUsed
+		}
+
+		// The derived log fields can simply be set from the block and transaction
+		for j := 0; j < len(rs[i].Logs); j++ {
+			rs[i].Logs[j].BlockNumber = number
+			rs[i].Logs[j].BlockHash = hash
+			rs[i].Logs[j].TxHash = rs[i].TxHash
+			rs[i].Logs[j].TxIndex = uint(i)
+			rs[i].Logs[j].Index = logIndex
+			logIndex++
+		}
+	}
+	return nil
+}
+
+// GetRlp returns the RLP encoding of one receipt from the list.
 func (r Receipts) GetRlp(i int) []byte {
 	bytes, err := rlp.EncodeToBytes(r[i])
 	if err != nil {

From 0c7da38ccbc6ad3f4448f8d0b49acfd72e172742 Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Sat, 15 Jul 2023 02:18:00 +0700
Subject: [PATCH 019/119] Add database accessor methods for receipts/logs

---
 core/database_util.go | 108 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 101 insertions(+), 7 deletions(-)

diff --git a/core/database_util.go b/core/database_util.go
index a5ab18687d..297b264121 100644
--- a/core/database_util.go
+++ b/core/database_util.go
@@ -22,9 +22,10 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/tomochain/tomochain/core/rawdb"
 	"math/big"
 
+	"github.com/tomochain/tomochain/core/rawdb"
+
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/core/types"
 	"github.com/tomochain/tomochain/ethdb"
@@ -224,6 +225,66 @@ func GetTd(db DatabaseReader, hash common.Hash, number uint64) *big.Int {
 	return td
 }
 
+// receiptLogs is a barebones version of ReceiptForStorage which only keeps
+// the list of logs. When decoding a stored receipt into this object we
+// avoid creating the bloom filter.
+type receiptLogs struct {
+	Logs []*types.Log
+}
+
+// DecodeRLP implements rlp.Decoder.
+func (r *receiptLogs) DecodeRLP(s *rlp.Stream) error {
+	var stored types.StoredReceiptRLP
+	if err := s.Decode(&stored); err != nil {
+		return err
+	}
+	r.Logs = stored.Logs
+	return nil
+}
+
+// deriveLogFields fills the logs in receiptLogs with information such as block number, txhash, etc.
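+// It mirrors the log handling in Receipts.DeriveFields without touching any
+// receipt-level fields.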
+func deriveLogFields(receipts []*receiptLogs, hash common.Hash, number uint64, txs types.Transactions) error { + logIndex := uint(0) + if len(txs) != len(receipts) { + return errors.New("transaction and receipt count mismatch") + } + for i := 0; i < len(receipts); i++ { + txHash := txs[i].Hash() + // The derived log fields can simply be set from the block and transaction + for j := 0; j < len(receipts[i].Logs); j++ { + receipts[i].Logs[j].BlockNumber = number + receipts[i].Logs[j].BlockHash = hash + receipts[i].Logs[j].TxHash = txHash + receipts[i].Logs[j].TxIndex = uint(i) + receipts[i].Logs[j].Index = logIndex + logIndex++ + } + } + return nil +} + +// ReadLogs retrieves the logs for all transactions in a block. In case +// receipts is not found, a nil is returned. +// Note: ReadLogs does not derive unstored log fields. +func ReadLogs(db ethdb.Reader, hash common.Hash, number uint64, config *params.ChainConfig) [][]*types.Log { + // Retrieve the flattened receipt slice + data := ReadReceiptsRLP(db, hash, number) + if len(data) == 0 { + return nil + } + var receipts []*receiptLogs + if err := rlp.DecodeBytes(data, &receipts); err != nil { + log.Error("Invalid receipt array RLP", "hash", hash, "err", err) + return nil + } + + logs := make([][]*types.Log, len(receipts)) + for i, receipt := range receipts { + logs[i] = receipt.Logs + } + return logs +} + // GetBlock retrieves an entire block corresponding to the hash, assembling it // back from the stored header and body. If either the header or body could not // be retrieved nil is returned. @@ -244,14 +305,25 @@ func GetBlock(db DatabaseReader, hash common.Hash, number uint64) *types.Block { return types.NewBlockWithHeader(header).WithBody(body.Transactions, body.Uncles) } -// GetBlockReceipts retrieves the receipts generated by the transactions included -// in a block given by its hash. -func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64) types.Receipts { +// ReadReceiptsRLP retrieves all the transaction receipts belonging to a block in RLP encoding. +func ReadReceiptsRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue { data, _ := db.Get(append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash[:]...)) if len(data) == 0 { return nil } - storageReceipts := []*types.ReceiptForStorage{} + return data +} + +// ReadRawReceipts retrieves all the transaction receipts belonging to a block. +// The receipt metadata fields are not guaranteed to be populated, so they +// should not be used. Use ReadReceipts instead if the metadata is needed. +func ReadRawReceipts(db DatabaseReader, hash common.Hash, number uint64) types.Receipts { + // Retrieve the flattened receipt slice + data := ReadReceiptsRLP(db, hash, number) + if len(data) == 0 { + return nil + } + var storageReceipts []*types.ReceiptForStorage if err := rlp.DecodeBytes(data, &storageReceipts); err != nil { log.Error("Invalid receipt array RLP", "hash", hash, "err", err) return nil @@ -263,6 +335,28 @@ func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64) types. return receipts } +// GetBlockReceipts retrieves the receipts generated by the transactions included +// in a block given by its hash. 
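+// The chain config is required so that the receipts' derived fields can be
+// computed via Receipts.DeriveFields.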
+func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64, config *params.ChainConfig) types.Receipts {
+	// We're deriving many fields from the block body, retrieve beside the receipt
+	receipts := ReadRawReceipts(db, hash, number)
+	if receipts == nil {
+		return nil
+	}
+
+	body := GetBody(db, hash, number)
+	if body == nil {
+		log.Error("Missing body but have receipt", "hash", hash, "number", number)
+		return nil
+	}
+	if err := receipts.DeriveFields(config, hash, number, body.Transactions); err != nil {
+		log.Error("Failed to derive block receipts fields", "hash", hash, "number", number, "err", err)
+		return nil
+	}
+
+	return receipts
+}
+
 // GetTxLookupEntry retrieves the positional metadata associated with a transaction
 // hash to allow retrieving the transaction or receipt by hash.
 func GetTxLookupEntry(db DatabaseReader, hash common.Hash) (common.Hash, uint64, uint64) {
@@ -317,12 +411,12 @@ func GetTransaction(db DatabaseReader, hash common.Hash) (*types.Transaction, co
 
 // GetReceipt retrieves a specific transaction receipt from the database, along with
 // its added positional metadata.
-func GetReceipt(db DatabaseReader, hash common.Hash) (*types.Receipt, common.Hash, uint64, uint64) {
+func GetReceipt(db DatabaseReader, hash common.Hash, config *params.ChainConfig) (*types.Receipt, common.Hash, uint64, uint64) {
 	// Retrieve the lookup metadata and resolve the receipt from the receipts
 	blockHash, blockNumber, receiptIndex := GetTxLookupEntry(db, hash)
 	if blockHash != (common.Hash{}) {
-		receipts := GetBlockReceipts(db, blockHash, blockNumber)
+		receipts := GetBlockReceipts(db, blockHash, blockNumber, config)
 		if len(receipts) <= int(receiptIndex) {
 			log.Error("Receipt refereced missing", "number", blockNumber, "hash", blockHash, "index", receiptIndex)
 			return nil, common.Hash{}, 0, 0

From e1226ecd80e4d4f0f7c96fc14a3c97577c053751 Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Sat, 15 Jul 2023 02:26:28 +0700
Subject: [PATCH 020/119] Add chain config as a requirement for retrieving receipts/logs

---
 accounts/abi/bind/backends/simulated.go | 14 +++++++-------
 contracts/utils.go                      |  2 +-
 core/blockchain.go                      |  4 ++--
 eth/api_backend.go                      | 14 ++++++--------
 eth/downloader/fakepeer.go              |  2 +-
 les/api_backend.go                      | 11 +++++------
 les/handler.go                          |  4 ++--
 les/odr_test.go                         |  6 +++---
 light/odr_util.go                       |  9 +++++----
 light/txpool.go                         | 11 +++++++----
 10 files changed, 39 insertions(+), 38 deletions(-)

diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go
index 7411f492a8..e16aa8f34a 100644
--- a/accounts/abi/bind/backends/simulated.go
+++ b/accounts/abi/bind/backends/simulated.go
@@ -20,8 +20,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"github.com/tomochain/tomochain/consensus"
-	"github.com/tomochain/tomochain/core/rawdb"
 	"math/big"
 	"sync"
 	"time"
@@ -30,9 +28,11 @@ import (
 	"github.com/tomochain/tomochain/accounts/abi/bind"
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/common/math"
+	"github.com/tomochain/tomochain/consensus"
 	"github.com/tomochain/tomochain/consensus/ethash"
 	"github.com/tomochain/tomochain/core"
 	"github.com/tomochain/tomochain/core/bloombits"
+	"github.com/tomochain/tomochain/core/rawdb"
 	"github.com/tomochain/tomochain/core/state"
 	"github.com/tomochain/tomochain/core/types"
 	"github.com/tomochain/tomochain/core/vm"
@@ -174,7 +174,7 @@ func (b *SimulatedBackend) ForEachStorageAt(ctx context.Context, contract common
 
 // TransactionReceipt returns the receipt of a transaction.
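+// The receipt's derived fields are populated using the backend's chain config.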
func (b *SimulatedBackend) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) { - receipt, _, _, _ := core.GetReceipt(b.database, txHash) + receipt, _, _, _ := core.GetReceipt(b.database, txHash, b.config) return receipt, nil } @@ -202,7 +202,7 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call tomochain.Call return rval, err } -//FIXME: please use copyState for this function +// FIXME: please use copyState for this function // CallContractWithState executes a contract call at the given state. func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) { // Ensure message is initialized properly. @@ -285,7 +285,7 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM snapshot := b.pendingState.Snapshot() _, _, failed, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) - fmt.Println("EstimateGas",err,failed) + fmt.Println("EstimateGas", err, failed) b.pendingState.RevertToSnapshot(snapshot) if err != nil || failed { @@ -485,11 +485,11 @@ func (fb *filterBackend) HeaderByNumber(ctx context.Context, block rpc.BlockNumb } func (fb *filterBackend) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) { - return core.GetBlockReceipts(fb.db, hash, core.GetBlockNumber(fb.db, hash)), nil + return core.GetBlockReceipts(fb.db, hash, core.GetBlockNumber(fb.db, hash), fb.bc.Config()), nil } func (fb *filterBackend) GetLogs(ctx context.Context, hash common.Hash) ([][]*types.Log, error) { - receipts := core.GetBlockReceipts(fb.db, hash, core.GetBlockNumber(fb.db, hash)) + receipts := core.GetBlockReceipts(fb.db, hash, core.GetBlockNumber(fb.db, hash), fb.bc.Config()) if receipts == nil { return nil, nil } diff --git a/contracts/utils.go b/contracts/utils.go index 4468b5de9a..f4e58b7711 100644 --- a/contracts/utils.go +++ b/contracts/utils.go @@ -336,7 +336,7 @@ func GetRewardForCheckpoint(c *posv.Posv, chain consensus.ChainReader, header *t block := chain.GetBlock(header.Hash(), i) txs := block.Transactions() if !chain.Config().IsTIPSigning(header.Number) { - receipts := core.GetBlockReceipts(c.GetDb(), header.Hash(), i) + receipts := core.GetBlockReceipts(c.GetDb(), header.Hash(), i, chain.Config()) signData = c.CacheData(header, txs, receipts) } else { signData = c.CacheSigner(header.Hash(), txs) diff --git a/core/blockchain.go b/core/blockchain.go index f763189be7..27c5e0554b 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -800,7 +800,7 @@ func (bc *BlockChain) GetBlockByNumber(number uint64) *types.Block { // GetReceiptsByHash retrieves the receipts for all transactions in a given block. func (bc *BlockChain) GetReceiptsByHash(hash common.Hash) types.Receipts { - return GetBlockReceipts(bc.db, hash, GetBlockNumber(bc.db, hash)) + return GetBlockReceipts(bc.db, hash, GetBlockNumber(bc.db, hash), bc.chainConfig) } // GetBlocksFromHash returns the block corresponding to hash and up to n-1 ancestors. @@ -2120,7 +2120,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { // These logs are later announced as deleted. collectLogs = func(h common.Hash) { // Coalesce logs and set 'Removed'. 
- receipts := GetBlockReceipts(bc.db, h, bc.hc.GetBlockNumber(h)) + receipts := GetBlockReceipts(bc.db, h, bc.hc.GetBlockNumber(h), bc.chainConfig) for _, receipt := range receipts { for _, log := range receipt.Logs { del := *log diff --git a/eth/api_backend.go b/eth/api_backend.go index 67554b4480..09190d63ea 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -21,20 +21,15 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending" "io/ioutil" "math/big" "path/filepath" - "github.com/tomochain/tomochain/tomox" - - "github.com/tomochain/tomochain/consensus/posv" - "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus" + "github.com/tomochain/tomochain/consensus/posv" "github.com/tomochain/tomochain/contracts" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" @@ -50,6 +45,9 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending" ) // EthApiBackend implements ethapi.Backend for full nodes @@ -117,11 +115,11 @@ func (b *EthApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*t } func (b *EthApiBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) { - return core.GetBlockReceipts(b.eth.chainDb, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)), nil + return core.GetBlockReceipts(b.eth.chainDb, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash), b.ChainConfig()), nil } func (b *EthApiBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) { - receipts := core.GetBlockReceipts(b.eth.chainDb, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) + receipts := core.GetBlockReceipts(b.eth.chainDb, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash), b.ChainConfig()) if receipts == nil { return nil, nil } diff --git a/eth/downloader/fakepeer.go b/eth/downloader/fakepeer.go index 4d7c5ac280..c2a5178342 100644 --- a/eth/downloader/fakepeer.go +++ b/eth/downloader/fakepeer.go @@ -140,7 +140,7 @@ func (p *FakePeer) RequestBodies(hashes []common.Hash) error { func (p *FakePeer) RequestReceipts(hashes []common.Hash) error { var receipts [][]*types.Receipt for _, hash := range hashes { - receipts = append(receipts, core.GetBlockReceipts(p.db, hash, p.hc.GetBlockNumber(hash))) + receipts = append(receipts, core.GetBlockReceipts(p.db, hash, p.hc.GetBlockNumber(hash), p.hc.Config())) } p.dl.DeliverReceipts(p.id, receipts) return nil diff --git a/les/api_backend.go b/les/api_backend.go index d8285da97d..46ebfacf9f 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -20,14 +20,10 @@ import ( "context" "encoding/json" "errors" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending" "io/ioutil" "math/big" "path/filepath" - "github.com/tomochain/tomochain/tomox" - "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" @@ -45,6 +41,9 @@ import ( "github.com/tomochain/tomochain/light" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox" + 
"github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending" ) type LesApiBackend struct { @@ -94,11 +93,11 @@ func (b *LesApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*t } func (b *LesApiBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) { - return light.GetBlockReceipts(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) + return light.GetBlockReceipts(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash), b.ChainConfig()) } func (b *LesApiBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) { - return light.GetBlockLogs(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) + return light.GetBlockLogs(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash), b.ChainConfig()) } func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int { diff --git a/les/handler.go b/les/handler.go index b426f7fdd1..c338ca62ae 100644 --- a/les/handler.go +++ b/les/handler.go @@ -21,7 +21,6 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "net" "sync" @@ -30,6 +29,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/eth/downloader" @@ -646,7 +646,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { break } // Retrieve the requested block's receipts, skipping if unknown to us - results := core.GetBlockReceipts(pm.chainDb, hash, core.GetBlockNumber(pm.chainDb, hash)) + results := core.GetBlockReceipts(pm.chainDb, hash, core.GetBlockNumber(pm.chainDb, hash), pm.chainConfig) if results == nil { if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { continue diff --git a/les/odr_test.go b/les/odr_test.go index 3858e34028..f6a5a9bd07 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -19,7 +19,6 @@ package les import ( "bytes" "context" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" "time" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -64,9 +64,9 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var receipts types.Receipts if bc != nil { - receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash)) + receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash), config) } else { - receipts, _ = light.GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash)) + receipts, _ = light.GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash), config) } if receipts == nil { return nil diff --git a/light/odr_util.go b/light/odr_util.go index 89a63eb2b9..6adbc9303f 100644 --- a/light/odr_util.go +++ b/light/odr_util.go @@ -24,6 +24,7 @@ import ( "github.com/tomochain/tomochain/core" 
"github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" ) @@ -125,9 +126,9 @@ func GetBlock(ctx context.Context, odr OdrBackend, hash common.Hash, number uint // GetBlockReceipts retrieves the receipts generated by the transactions included // in a block given by its hash. -func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (types.Receipts, error) { +func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64, config *params.ChainConfig) (types.Receipts, error) { // Retrieve the potentially incomplete receipts from disk or network - receipts := core.GetBlockReceipts(odr.Database(), hash, number) + receipts := core.GetBlockReceipts(odr.Database(), hash, number, config) if receipts == nil { r := &ReceiptsRequest{Hash: hash, Number: number} if err := odr.Retrieve(ctx, r); err != nil { @@ -154,9 +155,9 @@ func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, num // GetBlockLogs retrieves the logs generated by the transactions included in a // block given by its hash. -func GetBlockLogs(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) ([][]*types.Log, error) { +func GetBlockLogs(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64, config *params.ChainConfig) ([][]*types.Log, error) { // Retrieve the potentially incomplete receipts from disk or network - receipts := core.GetBlockReceipts(odr.Database(), hash, number) + receipts := core.GetBlockReceipts(odr.Database(), hash, number, config) if receipts == nil { r := &ReceiptsRequest{Hash: hash, Number: number} if err := odr.Retrieve(ctx, r); err != nil { diff --git a/light/txpool.go b/light/txpool.go index 7af86dbd6b..9d75448a67 100644 --- a/light/txpool.go +++ b/light/txpool.go @@ -74,10 +74,13 @@ type TxPool struct { // // Send instructs backend to forward new transactions // NewHead notifies backend about a new head after processed by the tx pool, -// including mined and rolled back transactions since the last event +// +// including mined and rolled back transactions since the last event +// // Discard notifies backend about transactions that should be discarded either -// because they have been replaced by a re-send or because they have been mined -// long ago and no rollback is expected +// +// because they have been replaced by a re-send or because they have been mined +// long ago and no rollback is expected type TxRelayBackend interface { Send(txs types.Transactions) NewHead(head common.Hash, mined []common.Hash, rollback []common.Hash) @@ -180,7 +183,7 @@ func (pool *TxPool) checkMinedTxs(ctx context.Context, hash common.Hash, number // If some transactions have been mined, write the needed data to disk and update if list != nil { // Retrieve all the receipts belonging to this block and write the loopup table - if _, err := GetBlockReceipts(ctx, pool.odr, hash, number); err != nil { // ODR caches, ignore results + if _, err := GetBlockReceipts(ctx, pool.odr, hash, number, pool.config); err != nil { // ODR caches, ignore results return err } if err := core.WriteTxLookupEntries(pool.chainDb, block); err != nil { From 0eb25b258e21c9868a2d680a1271563d4b126470 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Sat, 15 Jul 2023 02:42:18 +0700 Subject: [PATCH 021/119] [WIP] Fix unit tests --- accounts/abi/bind/backends/simulated.go | 1 - core/blockchain_test.go | 11 ++++++----- core/database_util.go | 6 
+++--- core/database_util_test.go | 10 ++++++---- core/genesis.go | 12 ++++++------ core/headerchain.go | 11 ++++++----- eth/filters/filter_system_test.go | 10 +++++++--- les/handler_test.go | 4 ++-- light/odr_test.go | 16 +++++++++++----- 9 files changed, 47 insertions(+), 34 deletions(-) diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index e16aa8f34a..de836c5033 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -285,7 +285,6 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM snapshot := b.pendingState.Snapshot() _, _, failed, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) - fmt.Println("EstimateGas", err, failed) b.pendingState.RevertToSnapshot(snapshot) if err != nil || failed { diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 6860924112..a76dc87814 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -18,13 +18,14 @@ package core import ( "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "math/rand" "sync" "testing" "time" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core/state" @@ -622,7 +623,7 @@ func TestFastVsFullChains(t *testing.T) { } else if types.CalcUncleHash(fblock.Uncles()) != types.CalcUncleHash(ablock.Uncles()) { t.Errorf("block #%d [%x]: uncles mismatch: have %v, want %v", num, hash, fblock.Uncles(), ablock.Uncles()) } - if freceipts, areceipts := GetBlockReceipts(fastDb, hash, GetBlockNumber(fastDb, hash)), GetBlockReceipts(archiveDb, hash, GetBlockNumber(archiveDb, hash)); types.DeriveSha(freceipts) != types.DeriveSha(areceipts) { + if freceipts, areceipts := GetBlockReceipts(fastDb, hash, GetBlockNumber(fastDb, hash), fast.Config()), GetBlockReceipts(archiveDb, hash, GetBlockNumber(archiveDb, hash), fast.Config()); types.DeriveSha(freceipts) != types.DeriveSha(areceipts) { t.Errorf("block #%d [%x]: receipts mismatch: have %v, want %v", num, hash, freceipts, areceipts) } } @@ -807,7 +808,7 @@ func TestChainTxReorgs(t *testing.T) { if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn != nil { t.Errorf("drop %d: tx %v found while shouldn't have been", i, txn) } - if rcpt, _, _, _ := GetReceipt(db, tx.Hash()); rcpt != nil { + if rcpt, _, _, _ := GetReceipt(db, tx.Hash(), blockchain.Config()); rcpt != nil { t.Errorf("drop %d: receipt %v found while shouldn't have been", i, rcpt) } } @@ -816,7 +817,7 @@ func TestChainTxReorgs(t *testing.T) { if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn == nil { t.Errorf("add %d: expected tx to be found", i) } - if rcpt, _, _, _ := GetReceipt(db, tx.Hash()); rcpt == nil { + if rcpt, _, _, _ := GetReceipt(db, tx.Hash(), blockchain.Config()); rcpt == nil { t.Errorf("add %d: expected receipt to be found", i) } } @@ -825,7 +826,7 @@ func TestChainTxReorgs(t *testing.T) { if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn == nil { t.Errorf("share %d: expected tx to be found", i) } - if rcpt, _, _, _ := GetReceipt(db, tx.Hash()); rcpt == nil { + if rcpt, _, _, _ := GetReceipt(db, tx.Hash(), blockchain.Config()); rcpt == nil { t.Errorf("share %d: expected receipt to be found", i) } } diff --git a/core/database_util.go b/core/database_util.go index 297b264121..b1c2f3f3d2 100644 --- a/core/database_util.go +++ b/core/database_util.go @@ -101,16 +101,16 @@ func GetCanonicalHash(db DatabaseReader, number uint64) 
common.Hash { return common.BytesToHash(data) } -// missingNumber is returned by GetBlockNumber if no header with the +// MissingNumber is returned by GetBlockNumber if no header with the // given block hash has been stored in the database -const missingNumber = uint64(0xffffffffffffffff) +const MissingNumber = uint64(0xffffffffffffffff) // GetBlockNumber returns the block number assigned to a block hash // if the corresponding header is present in the database func GetBlockNumber(db DatabaseReader, hash common.Hash) uint64 { data, _ := db.Get(append(blockHashPrefix, hash.Bytes()...)) if len(data) != 8 { - return missingNumber + return MissingNumber } return binary.BigEndian.Uint64(data) } diff --git a/core/database_util_test.go b/core/database_util_test.go index f28ca160a5..19cbf790fb 100644 --- a/core/database_util_test.go +++ b/core/database_util_test.go @@ -18,10 +18,12 @@ package core import ( "bytes" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto/sha3" @@ -361,14 +363,14 @@ func TestBlockReceiptStorage(t *testing.T) { // Check that no receipt entries are in a pristine database hash := common.BytesToHash([]byte{0x03, 0x14}) - if rs := GetBlockReceipts(db, hash, 0); len(rs) != 0 { + if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) != 0 { t.Fatalf("non existent receipts returned: %v", rs) } // Insert the receipt slice into the database and check presence if err := WriteBlockReceipts(db, hash, 0, receipts); err != nil { t.Fatalf("failed to write block receipts: %v", err) } - if rs := GetBlockReceipts(db, hash, 0); len(rs) == 0 { + if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) == 0 { t.Fatalf("no receipts returned") } else { for i := 0; i < len(receipts); i++ { @@ -382,7 +384,7 @@ func TestBlockReceiptStorage(t *testing.T) { } // Delete the receipt slice and check purge DeleteBlockReceipts(db, hash, 0) - if rs := GetBlockReceipts(db, hash, 0); len(rs) != 0 { + if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) != 0 { t.Fatalf("deleted receipts returned: %v", rs) } } diff --git a/core/genesis.go b/core/genesis.go index e1b7185a41..fcd196bd47 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -22,13 +22,13 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "strings" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" @@ -140,10 +140,10 @@ func (e *GenesisMismatchError) Error() string { // SetupGenesisBlock writes or updates the genesis block in db. // The block that will be used is: // -// genesis == nil genesis != nil -// +------------------------------------------ -// db has no genesis | main-net default | genesis -// db has genesis | from DB | genesis (if compatible) +// genesis == nil genesis != nil +// +------------------------------------------ +// db has no genesis | main-net default | genesis +// db has genesis | from DB | genesis (if compatible) // // The stored chain configuration will be updated if it is compatible (i.e. does not // specify a fork block below the local head block). 
In case of a conflict, the @@ -197,7 +197,7 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig // Check config compatibility and write the config. Compatibility errors // are returned to the caller unless we're already at block zero. height := GetBlockNumber(db, GetHeadHeaderHash(db)) - if height == missingNumber { + if height == MissingNumber { return newcfg, stored, fmt.Errorf("missing block number for head header hash") } compatErr := storedcfg.CheckCompatible(newcfg, height) diff --git a/core/headerchain.go b/core/headerchain.go index 8365f2127d..f3cc8cf77b 100644 --- a/core/headerchain.go +++ b/core/headerchain.go @@ -26,7 +26,7 @@ import ( "sync/atomic" "time" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core/types" @@ -66,9 +66,10 @@ type HeaderChain struct { } // NewHeaderChain creates a new HeaderChain structure. -// getValidator should return the parent's validator -// procInterrupt points to the parent's interrupt semaphore -// wg points to the parent's shutdown wait group +// +// getValidator should return the parent's validator +// procInterrupt points to the parent's interrupt semaphore +// wg points to the parent's shutdown wait group func NewHeaderChain(chainDb ethdb.Database, config *params.ChainConfig, engine consensus.Engine, procInterrupt func() bool) (*HeaderChain, error) { headerCache, _ := lru.New(headerCacheLimit) tdCache, _ := lru.New(tdCacheLimit) @@ -114,7 +115,7 @@ func (hc *HeaderChain) GetBlockNumber(hash common.Hash) uint64 { return cached.(uint64) } number := GetBlockNumber(hc.chainDb, hash) - if number != missingNumber { + if number != MissingNumber { hc.numberCache.Add(hash, number) } return number diff --git a/eth/filters/filter_system_test.go b/eth/filters/filter_system_test.go index d947a672ac..eb3e7cce4b 100644 --- a/eth/filters/filter_system_test.go +++ b/eth/filters/filter_system_test.go @@ -19,7 +19,6 @@ package filters import ( "context" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "math/rand" "reflect" @@ -31,6 +30,7 @@ import ( "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" @@ -48,6 +48,10 @@ type testBackend struct { chainFeed *event.Feed } +func (b *testBackend) ChainConfig() *params.ChainConfig { + return params.TestChainConfig +} + func (b *testBackend) ChainDb() ethdb.Database { return b.db } @@ -71,12 +75,12 @@ func (b *testBackend) HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumbe func (b *testBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) { number := core.GetBlockNumber(b.db, blockHash) - return core.GetBlockReceipts(b.db, blockHash, number), nil + return core.GetBlockReceipts(b.db, blockHash, number, b.ChainConfig()), nil } func (b *testBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) { number := core.GetBlockNumber(b.db, blockHash) - receipts := core.GetBlockReceipts(b.db, blockHash, number) + receipts := core.GetBlockReceipts(b.db, blockHash, number, b.ChainConfig()) logs := make([][]*types.Log, len(receipts)) for i, receipt := range receipts { diff --git a/les/handler_test.go b/les/handler_test.go index 
225900dd52..3bb88ab9b9 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -18,7 +18,6 @@ package les import ( "encoding/binary" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "math/rand" "testing" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/eth/downloader" @@ -304,7 +304,7 @@ func testGetReceipt(t *testing.T, protocol int) { block := bc.GetBlockByNumber(i) hashes = append(hashes, block.Hash()) - receipts = append(receipts, core.GetBlockReceipts(db, block.Hash(), block.NumberU64())) + receipts = append(receipts, core.GetBlockReceipts(db, block.Hash(), block.NumberU64(), bc.Config())) } // Send the hash request and verify the response cost := peer.GetRequestCost(GetReceiptsMsg, len(hashes)) diff --git a/light/odr_test.go b/light/odr_test.go index 0c5fc78573..d497f0d228 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -20,16 +20,16 @@ import ( "bytes" "context" "errors" - "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -74,7 +74,10 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error { case *BlockRequest: req.Rlp = core.GetBodyRLP(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash)) case *ReceiptsRequest: - req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash)) + number := core.GetBlockNumber(odr.sdb, req.Hash) + if number != core.MissingNumber { + req.Receipts = core.ReadRawReceipts(odr.sdb, req.Hash, number) + } case *TrieRequest: t, _ := trie.New(req.Id.Root, trie.NewDatabase(odr.sdb)) nodes := NewNodeSet() @@ -110,9 +113,12 @@ func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, odrGetReceipts) } func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { var receipts types.Receipts if bc != nil { - receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash)) + receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash), bc.Config()) } else { - receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash)) + number := core.GetBlockNumber(db, bhash) + if number != core.MissingNumber { + receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, number, lc.Config()) + } } if receipts == nil { return nil, nil From 8fdaf686beb465eda94af6393c09fa135501923c Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Sun, 16 Jul 2023 19:36:02 +0700 Subject: [PATCH 022/119] Unify multiple keccak interface into one inside crypto package --- core/vm/instructions.go | 25 ++++++++++++++-------- core/vm/interpreter.go | 14 +++--------- crypto/crypto.go | 47 +++++++++++++++++++++++++++++++++++------ trie/committer.go | 6 +++--- trie/hasher.go | 15 +++---------- 5 files changed, 66 insertions(+), 41 deletions(-) diff --git a/core/vm/instructions.go 
b/core/vm/instructions.go index 16f3685852..ab962bd65d 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -17,13 +17,13 @@ package vm import ( - "github.com/tomochain/tomochain/params" "math/big" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/core/types" - "golang.org/x/crypto/sha3" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/params" ) var ( @@ -381,7 +381,7 @@ func opSha3(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]by data := callContext.memory.GetPtr(offset.Int64(), size.Int64()) if interpreter.hasher == nil { - interpreter.hasher = sha3.NewLegacyKeccak256().(keccakState) + interpreter.hasher = crypto.NewKeccakState() } else { interpreter.hasher.Reset() } @@ -513,16 +513,21 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx // opExtCodeHash returns the code hash of a specified account. // There are several cases when the function is called, while we can relay everything // to `state.GetCodeHash` function to ensure the correctness. -// (1) Caller tries to get the code hash of a normal contract account, state +// +// (1) Caller tries to get the code hash of a normal contract account, state +// // should return the relative code hash and set it as the result. // -// (2) Caller tries to get the code hash of a non-existent account, state should +// (2) Caller tries to get the code hash of a non-existent account, state should +// // return common.Hash{} and zero will be set as the result. // -// (3) Caller tries to get the code hash for an account without contract code, +// (3) Caller tries to get the code hash for an account without contract code, +// // state should return emptyCodeHash(0xc5d246...) as the result. // -// (4) Caller tries to get the code hash of a precompiled account, the result +// (4) Caller tries to get the code hash of a precompiled account, the result +// // should be zero or emptyCodeHash. // // It is worth noting that in order to avoid unnecessary create and clean, @@ -531,10 +536,12 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx // If the precompile account is not transferred any amount on a private or // customized chain, the return value will be zero. // -// (5) Caller tries to get the code hash for an account which is marked as suicided +// (5) Caller tries to get the code hash for an account which is marked as suicided +// // in the current transaction, the code hash of this account should be returned. // -// (6) Caller tries to get the code hash for an account which is marked as deleted, +// (6) Caller tries to get the code hash for an account which is marked as deleted, +// // this account should be regarded as a non-existent account and zero should be returned. func opExtCodeHash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { slot := callContext.stack.peek() diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index fc5b17a4f3..36027be797 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -17,11 +17,11 @@ package vm import ( - "hash" "sync/atomic" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/log" ) @@ -70,14 +70,6 @@ type callCtx struct { contract *Contract } -// keccakState wraps sha3.state. 
In addition to the usual hash methods, it also supports -// Read to get a variable amount of data from the hash state. Read is faster than Sum -// because it doesn't copy the internal state, but also modifies the internal state. -type keccakState interface { - hash.Hash - Read([]byte) (int, error) -} - // EVMInterpreter represents an EVM interpreter type EVMInterpreter struct { evm *EVM @@ -85,8 +77,8 @@ type EVMInterpreter struct { intPool *intPool - hasher keccakState // Keccak256 hasher instance shared across opcodes - hasherBuf common.Hash // Keccak256 hasher result array shared aross opcodes + hasher crypto.KeccakState // Keccak256 hasher instance shared across opcodes + hasherBuf common.Hash // Keccak256 hasher result array shared across opcodes readOnly bool // Whether to throw on stateful modifications returnData []byte // Last CALL's return data for subsequent reuse diff --git a/crypto/crypto.go b/crypto/crypto.go index 18386f85c0..6affee64ce 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -23,6 +23,7 @@ import ( "encoding/hex" "errors" "fmt" + "hash" "io" "io/ioutil" "math/big" @@ -30,38 +31,72 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" - "github.com/tomochain/tomochain/crypto/sha3" "github.com/tomochain/tomochain/rlp" + "golang.org/x/crypto/sha3" ) +// SignatureLength indicates the byte length required to carry a signature with recovery id. +const SignatureLength = 64 + 1 // 64 bytes ECDSA signature + 1 byte recovery id + +// RecoveryIDOffset points to the byte offset within the signature that contains the recovery id. +const RecoveryIDOffset = 64 + +// DigestLength sets the signature digest exact length +const DigestLength = 32 + var ( secp256k1_N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) secp256k1_halfN = new(big.Int).Div(secp256k1_N, big.NewInt(2)) ) +var errInvalidPubkey = errors.New("invalid secp256k1 public key") + +// KeccakState wraps sha3.state. In addition to the usual hash methods, it also supports +// Read to get a variable amount of data from the hash state. Read is faster than Sum +// because it doesn't copy the internal state, but also modifies the internal state. +type KeccakState interface { + hash.Hash + Read([]byte) (int, error) +} + +// NewKeccakState creates a new KeccakState +func NewKeccakState() KeccakState { + return sha3.NewLegacyKeccak256().(KeccakState) +} + +// HashData hashes the provided data using the KeccakState and returns a 32 byte hash +func HashData(kh KeccakState, data []byte) (h common.Hash) { + kh.Reset() + kh.Write(data) + kh.Read(h[:]) + return h +} + // Keccak256 calculates and returns the Keccak256 hash of the input data. func Keccak256(data ...[]byte) []byte { - d := sha3.NewKeccak256() + b := make([]byte, 32) + d := NewKeccakState() for _, b := range data { d.Write(b) } - return d.Sum(nil) + d.Read(b) + return b } // Keccak256Hash calculates and returns the Keccak256 hash of the input data, // converting it to an internal Hash data structure. func Keccak256Hash(data ...[]byte) (h common.Hash) { - d := sha3.NewKeccak256() + d := NewKeccakState() for _, b := range data { d.Write(b) } - d.Sum(h[:0]) + d.Read(h[:]) return h } // Keccak512 calculates and returns the Keccak512 hash of the input data. 
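The KeccakState additions above replace one-shot sha3 hashers with a reusable state whose Read avoids the internal-state copy that Sum performs. A minimal standalone sketch of how the new helpers fit together (illustrative only, assuming the crypto package as patched here):

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/crypto"
)

func main() {
	// One reusable hasher; HashData resets it, writes the input and Reads
	// the 32-byte digest out without the state copy that Sum would make.
	kh := crypto.NewKeccakState()
	h1 := crypto.HashData(kh, []byte("hello"))
	h2 := crypto.Keccak256Hash([]byte("hello"))
	fmt.Println(h1 == h2) // true: both compute the same Keccak256 digest
}
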
func Keccak512(data ...[]byte) []byte { - d := sha3.NewKeccak512() + d := sha3.NewLegacyKeccak512() for _, b := range data { d.Write(b) } diff --git a/trie/committer.go b/trie/committer.go index 78ed86bb4a..43a31381b9 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -22,8 +22,8 @@ import ( "sync" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/rlp" - "golang.org/x/crypto/sha3" ) // leafChanSize is the size of the leafCh. It's a pretty arbitrary number, to allow @@ -46,7 +46,7 @@ type leaf struct { // processed sequentially - onleaf will never be called in parallel or out of order. type committer struct { tmp sliceBuffer - sha keccakState + sha crypto.KeccakState onleaf LeafCallback leafCh chan *leaf @@ -57,7 +57,7 @@ var committerPool = sync.Pool{ New: func() interface{} { return &committer{ tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode. - sha: sha3.NewLegacyKeccak256().(keccakState), + sha: crypto.NewKeccakState(), } }, } diff --git a/trie/hasher.go b/trie/hasher.go index 8a2ea18068..e306b35e7b 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -17,21 +17,12 @@ package trie import ( - "hash" "sync" + "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/rlp" - "golang.org/x/crypto/sha3" ) -// keccakState wraps sha3.state. In addition to the usual hash methods, it also supports -// Read to get a variable amount of data from the hash state. Read is faster than Sum -// because it doesn't copy the internal state, but also modifies the internal state. -type keccakState interface { - hash.Hash - Read([]byte) (int, error) -} - type sliceBuffer []byte func (b *sliceBuffer) Write(data []byte) (n int, err error) { @@ -46,7 +37,7 @@ func (b *sliceBuffer) Reset() { // hasher is a type used for the trie Hash operation. A hasher has some // internal preallocated temp space type hasher struct { - sha keccakState + sha crypto.KeccakState tmp sliceBuffer parallel bool // Whether to use paralallel threads when hashing } @@ -56,7 +47,7 @@ var hasherPool = sync.Pool{ New: func() interface{} { return &hasher{ tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode. - sha: sha3.NewLegacyKeccak256().(keccakState), + sha: crypto.NewKeccakState(), } }, } From 340bf2baa69d9bc0c9009dc307d401bf237bf097 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Sun, 16 Jul 2023 20:14:25 +0700 Subject: [PATCH 023/119] Add encode method to Node interface --- trie/database.go | 7 ++++- trie/hasher.go | 82 +++++++++++++++++++++++++----------------------- trie/node.go | 14 +++------ trie/node_enc.go | 64 +++++++++++++++++++++++++++++++++++++ 4 files changed, 116 insertions(+), 51 deletions(-) create mode 100644 trie/node_enc.go diff --git a/trie/database.go b/trie/database.go index bb2da07c29..bf3f2e89de 100644 --- a/trie/database.go +++ b/trie/database.go @@ -106,6 +106,11 @@ type rawNode []byte func (n rawNode) Cache() (HashNode, bool) { panic("this should never end up in a live trie") } func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } +func (n rawNode) EncodeRLP(w io.Writer) error { + _, err := w.Write(n) + return err +} + // rawFullNode represents only the useful data content of a full Node, with the // caches and flags stripped out to minimize its data storage. This type honors // the same RLP encoding as the original parent. 
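The rawNode.EncodeRLP added above is a pure pass-through: a rawNode already holds valid RLP, so it is copied out verbatim rather than encoded a second time. A standalone sketch of the same pattern, where preEncoded is a hypothetical stand-in for the unexported rawNode:

package main

import (
	"bytes"
	"fmt"
	"io"

	"github.com/tomochain/tomochain/rlp"
)

// preEncoded holds bytes that are already valid RLP, like trie.rawNode.
type preEncoded []byte

// EncodeRLP writes the stored encoding through unchanged.
func (n preEncoded) EncodeRLP(w io.Writer) error {
	_, err := w.Write(n)
	return err
}

func main() {
	raw, _ := rlp.EncodeToBytes("abc") // 0x83 'a' 'b' 'c'
	var buf bytes.Buffer
	rlp.Encode(&buf, preEncoded(raw))          // pass-through: no double encoding
	fmt.Println(bytes.Equal(buf.Bytes(), raw)) // true
}
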
@@ -184,7 +189,7 @@ func (n *cachedNode) obj(hash common.Hash) Node { // forChilds invokes the callback for all the tracked children of this Node, // both the implicit ones from inside the Node as well as the explicit ones -//from outside the Node. +// from outside the Node. func (n *cachedNode) forChilds(onChild func(hash common.Hash)) { for child := range n.children { onChild(child) diff --git a/trie/hasher.go b/trie/hasher.go index e306b35e7b..0cda29e3f5 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -1,4 +1,4 @@ -// Copyright 2019 The go-ethereum Authors +// Copyright 2016 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify @@ -23,31 +23,22 @@ import ( "github.com/tomochain/tomochain/rlp" ) -type sliceBuffer []byte - -func (b *sliceBuffer) Write(data []byte) (n int, err error) { - *b = append(*b, data...) - return len(data), nil -} - -func (b *sliceBuffer) Reset() { - *b = (*b)[:0] -} - // hasher is a type used for the trie Hash operation. A hasher has some // internal preallocated temp space type hasher struct { sha crypto.KeccakState - tmp sliceBuffer - parallel bool // Whether to use paralallel threads when hashing + tmp []byte + encbuf rlp.EncoderBuffer + parallel bool // Whether to use parallel threads when hashing } // hasherPool holds pureHashers var hasherPool = sync.Pool{ New: func() interface{} { return &hasher{ - tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode. - sha: crypto.NewKeccakState(), + tmp: make([]byte, 0, 550), // cap is as large as a full fullNode. + sha: crypto.NewKeccakState(), + encbuf: rlp.NewEncoderBuffer(nil), } }, } @@ -62,14 +53,14 @@ func returnHasherToPool(h *hasher) { hasherPool.Put(h) } -// hash collapses a Node down into a hash Node, also returning a copy of the -// original Node initialized with the computed hash to replace the original one. +// hash collapses a node down into a hash node, also returning a copy of the +// original node initialized with the computed hash to replace the original one. func (h *hasher) hash(n Node, force bool) (hashed Node, cached Node) { - // We're not storing the Node, just hashing, use available cached data + // Return the cached hash if it's available if hash, _ := n.Cache(); hash != nil { return hash, n } - // Trie not processed yet or needs storage, walk the children + // Trie not processed yet, walk the children switch n := n.(type) { case *ShortNode: collapsed, cached := h.hashShortNodeChildren(n) @@ -97,11 +88,11 @@ func (h *hasher) hash(n Node, force bool) (hashed Node, cached Node) { } } -// hashShortNodeChildren collapses the short Node. The returned collapsed Node +// hashShortNodeChildren collapses the short node. The returned collapsed node // holds a live reference to the Key, and must not be modified. // The cached func (h *hasher) hashShortNodeChildren(n *ShortNode) (collapsed, cached *ShortNode) { - // Hash the short Node's child, caching the newly hashed subtree + // Hash the short node's child, caching the newly hashed subtree collapsed, cached = n.copy(), n.copy() // Previously, we did copy this one. 
We don't seem to need to actually
 	// do that, since we don't overwrite/reuse keys
@@ -116,7 +107,7 @@ func (h *hasher) hashShortNodeChildren(n *ShortNode) (collapsed, cached *ShortNo
 }
 
 func (h *hasher) hashFullNodeChildren(n *FullNode) (collapsed *FullNode, cached *FullNode) {
-	// Hash the full Node's children, caching the newly hashed subtrees
+	// Hash the full node's children, caching the newly hashed subtrees
 	cached = n.copy()
 	collapsed = n.copy()
 	if h.parallel {
@@ -147,35 +138,46 @@ func (h *hasher) hashFullNodeChildren(n *FullNode) (collapsed *FullNode, cached
 	return collapsed, cached
 }
 
-// shortnodeToHash creates a HashNode from a ShortNode. The supplied shortnode
+// shortnodeToHash creates a hashNode from a shortNode. The supplied shortnode
 // should have hex-type Key, which will be converted (without modification)
 // into compact form for RLP encoding.
 // If the rlp data is smaller than 32 bytes, `nil` is returned.
 func (h *hasher) shortnodeToHash(n *ShortNode, force bool) Node {
-	h.tmp.Reset()
-	if err := rlp.Encode(&h.tmp, n); err != nil {
-		panic("encode error: " + err.Error())
-	}
+	n.encode(h.encbuf)
+	enc := h.encodedBytes()
 
-	if len(h.tmp) < 32 && !force {
+	if len(enc) < 32 && !force {
 		return n // Nodes smaller than 32 bytes are stored inside their parent
 	}
-	return h.hashData(h.tmp)
+	return h.hashData(enc)
 }
 
-// shortnodeToHash is used to creates a HashNode from a set of hashNodes, (which
+// fullnodeToHash creates a hashNode from a set of hashNodes (which
 // may contain nil values)
 func (h *hasher) fullnodeToHash(n *FullNode, force bool) Node {
-	h.tmp.Reset()
-	// Generate the RLP encoding of the Node
-	if err := n.EncodeRLP(&h.tmp); err != nil {
-		panic("encode error: " + err.Error())
-	}
+	n.encode(h.encbuf)
+	enc := h.encodedBytes()
 
-	if len(h.tmp) < 32 && !force {
+	if len(enc) < 32 && !force {
 		return n // Nodes smaller than 32 bytes are stored inside their parent
 	}
-	return h.hashData(h.tmp)
+	return h.hashData(enc)
+}
+
+// encodedBytes returns the result of the last encoding operation on h.encbuf.
+// This also resets the encoder buffer.
+//
+// All node encoding must be done like this:
+//
+//	node.encode(h.encbuf)
+//	enc := h.encodedBytes()
+//
+// This convention exists because node.encode can only be inlined/escape-analyzed when
+// called on a concrete receiver type.
+func (h *hasher) encodedBytes() []byte {
+	h.tmp = h.encbuf.AppendToBytes(h.tmp[:0])
+	h.encbuf.Reset(nil)
+	return h.tmp
 }
 
 // hashData hashes the provided data
@@ -188,8 +190,8 @@ func (h *hasher) hashData(data []byte) HashNode {
 }
 
 // proofHash is used to construct trie proofs, and returns the 'collapsed'
-// Node (for later RLP encoding) aswell as the hashed Node -- unless the
-// Node is smaller than 32 bytes, in which case it will be returned as is.
+// node (for later RLP encoding) as well as the hashed node -- unless the
+// node is smaller than 32 bytes, in which case it will be returned as is.
 // This method does not do anything on value- or hash-nodes.
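The encodedBytes convention above exists because the hasher now writes nodes through an rlp.EncoderBuffer that it reuses between operations. A minimal standalone sketch of that buffer API, with illustrative values:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/rlp"
)

func main() {
	w := rlp.NewEncoderBuffer(nil)
	offset := w.List()         // open a list, remember where it starts
	w.WriteBytes([]byte{0x61}) // "a"
	w.WriteBytes([]byte{0x62}) // "b"
	w.ListEnd(offset)          // patch in the list header

	enc := w.AppendToBytes(nil) // same call encodedBytes makes
	w.Reset(nil)                // buffer is ready for reuse, like h.encbuf

	fmt.Printf("%x\n", enc) // c26162: 0xc2 list header followed by "a" and "b"
}
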
func (h *hasher) proofHash(original Node) (collapsed, hashed Node) {
 	switch n := original.(type) {
diff --git a/trie/node.go b/trie/node.go
index ffb2f18116..fbbe293413 100644
--- a/trie/node.go
+++ b/trie/node.go
@@ -30,6 +30,7 @@ var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b
 type Node interface {
 	fstring(string) string
 	Cache() (HashNode, bool)
+	encode(w rlp.EncoderBuffer)
 }
 
 type (
@@ -52,16 +53,9 @@ var nilValueNode = ValueNode(nil)
 
 // EncodeRLP encodes a full Node into the consensus RLP format.
 func (n *FullNode) EncodeRLP(w io.Writer) error {
-	var nodes [17]Node
-
-	for i, child := range &n.Children {
-		if child != nil {
-			nodes[i] = child
-		} else {
-			nodes[i] = nilValueNode
-		}
-	}
-	return rlp.Encode(w, nodes)
+	eb := rlp.NewEncoderBuffer(w)
+	n.encode(eb)
+	return eb.Flush()
 }
 
 func (n *FullNode) copy() *FullNode { copy := *n; return &copy }
diff --git a/trie/node_enc.go b/trie/node_enc.go
new file mode 100644
index 0000000000..b5b0660f2d
--- /dev/null
+++ b/trie/node_enc.go
@@ -0,0 +1,64 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+	"github.com/tomochain/tomochain/rlp"
+)
+
+func nodeToBytes(n Node) []byte {
+	w := rlp.NewEncoderBuffer(nil)
+	n.encode(w)
+	result := w.ToBytes()
+	w.Flush()
+	return result
+}
+
+func (n *FullNode) encode(w rlp.EncoderBuffer) {
+	offset := w.List()
+	for _, c := range n.Children {
+		if c != nil {
+			c.encode(w)
+		} else {
+			w.Write(rlp.EmptyString)
+		}
+	}
+	w.ListEnd(offset)
+}
+
+func (n *ShortNode) encode(w rlp.EncoderBuffer) {
+	offset := w.List()
+	w.WriteBytes(n.Key)
+	if n.Val != nil {
+		n.Val.encode(w)
+	} else {
+		w.Write(rlp.EmptyString)
+	}
+	w.ListEnd(offset)
+}
+
+func (n HashNode) encode(w rlp.EncoderBuffer) {
+	w.WriteBytes(n)
+}
+
+func (n ValueNode) encode(w rlp.EncoderBuffer) {
+	w.WriteBytes(n)
+}
+
+func (n rawNode) encode(w rlp.EncoderBuffer) {
+	w.Write(n)
+}

From 33149c9b1c9e763152d388c4fd8a1ff08d72e597 Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Sun, 16 Jul 2023 20:14:39 +0700
Subject: [PATCH 024/119] Implement stacktrie

---
 trie/stacktrie.go      | 534 +++++++++++++++++++++++++++++++++++++++++
 trie/stacktrie_test.go | 392 ++++++++++++++++++++++++++++++
 2 files changed, 926 insertions(+)
 create mode 100644 trie/stacktrie.go
 create mode 100644 trie/stacktrie_test.go

diff --git a/trie/stacktrie.go b/trie/stacktrie.go
new file mode 100644
index 0000000000..78da41f7ea
--- /dev/null
+++ b/trie/stacktrie.go
@@ -0,0 +1,534 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "bufio" + "bytes" + "encoding/gob" + "errors" + "io" + "sync" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/log" +) + +var ErrCommitDisabled = errors.New("no database for committing") + +var stPool = sync.Pool{ + New: func() interface{} { + return NewStackTrie(nil) + }, +} + +// NodeWriteFunc is used to provide all information of a dirty node for committing +// so that callers can flush nodes into database with desired scheme. +type NodeWriteFunc = func(owner common.Hash, path []byte, hash common.Hash, blob []byte) + +func stackTrieFromPool(writeFn NodeWriteFunc, owner common.Hash) *StackTrie { + st := stPool.Get().(*StackTrie) + st.owner = owner + st.writeFn = writeFn + return st +} + +func returnToPool(st *StackTrie) { + st.Reset() + stPool.Put(st) +} + +// StackTrie is a trie implementation that expects keys to be inserted +// in order. Once it determines that a subtree will no longer be inserted +// into, it will hash it and free up the memory it uses. +type StackTrie struct { + owner common.Hash // the owner of the trie + nodeType uint8 // node type (as in branch, ext, leaf) + val []byte // value contained by this node if it's a leaf + key []byte // key chunk covered by this (leaf|ext) node + children [16]*StackTrie // list of children (for branch and exts) + writeFn NodeWriteFunc // function for committing nodes, can be nil +} + +// NewStackTrie allocates and initializes an empty trie. +func NewStackTrie(writeFn NodeWriteFunc) *StackTrie { + return &StackTrie{ + nodeType: emptyNode, + writeFn: writeFn, + } +} + +// NewStackTrieWithOwner allocates and initializes an empty trie, but with +// the additional owner field. +func NewStackTrieWithOwner(writeFn NodeWriteFunc, owner common.Hash) *StackTrie { + return &StackTrie{ + owner: owner, + nodeType: emptyNode, + writeFn: writeFn, + } +} + +// NewFromBinary initialises a serialized stacktrie with the given db. 
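For reference, a minimal sketch of driving the StackTrie constructors above; the keys and values are hypothetical, and keys must be inserted in sorted order since the trie folds up finished subtrees as it goes:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/trie"
)

func main() {
	st := trie.NewStackTrie(nil) // nil NodeWriteFunc: hash only, no commits
	st.Update([]byte{0x01}, []byte("one"))
	st.Update([]byte{0x02}, []byte("two")) // keys must arrive in increasing order
	// Earlier keys' subtrees are hashed and released as later keys arrive.
	fmt.Printf("root: %x\n", st.Hash())
}
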
+func NewFromBinary(data []byte, writeFn NodeWriteFunc) (*StackTrie, error) { + var st StackTrie + if err := st.UnmarshalBinary(data); err != nil { + return nil, err + } + // If a database is used, we need to recursively add it to every child + if writeFn != nil { + st.setWriter(writeFn) + } + return &st, nil +} + +// MarshalBinary implements encoding.BinaryMarshaler +func (st *StackTrie) MarshalBinary() (data []byte, err error) { + var ( + b bytes.Buffer + w = bufio.NewWriter(&b) + ) + if err := gob.NewEncoder(w).Encode(struct { + Owner common.Hash + NodeType uint8 + Val []byte + Key []byte + }{ + st.owner, + st.nodeType, + st.val, + st.key, + }); err != nil { + return nil, err + } + for _, child := range st.children { + if child == nil { + w.WriteByte(0) + continue + } + w.WriteByte(1) + if childData, err := child.MarshalBinary(); err != nil { + return nil, err + } else { + w.Write(childData) + } + } + w.Flush() + return b.Bytes(), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (st *StackTrie) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + return st.unmarshalBinary(r) +} + +func (st *StackTrie) unmarshalBinary(r io.Reader) error { + var dec struct { + Owner common.Hash + NodeType uint8 + Val []byte + Key []byte + } + if err := gob.NewDecoder(r).Decode(&dec); err != nil { + return err + } + st.owner = dec.Owner + st.nodeType = dec.NodeType + st.val = dec.Val + st.key = dec.Key + + var hasChild = make([]byte, 1) + for i := range st.children { + if _, err := r.Read(hasChild); err != nil { + return err + } else if hasChild[0] == 0 { + continue + } + var child StackTrie + if err := child.unmarshalBinary(r); err != nil { + return err + } + st.children[i] = &child + } + return nil +} + +func (st *StackTrie) setWriter(writeFn NodeWriteFunc) { + st.writeFn = writeFn + for _, child := range st.children { + if child != nil { + child.setWriter(writeFn) + } + } +} + +func newLeaf(owner common.Hash, key, val []byte, writeFn NodeWriteFunc) *StackTrie { + st := stackTrieFromPool(writeFn, owner) + st.nodeType = leafNode + st.key = append(st.key, key...) + st.val = val + return st +} + +func newExt(owner common.Hash, key []byte, child *StackTrie, writeFn NodeWriteFunc) *StackTrie { + st := stackTrieFromPool(writeFn, owner) + st.nodeType = extNode + st.key = append(st.key, key...) + st.children[0] = child + return st +} + +// List all values that StackTrie#nodeType can hold +const ( + emptyNode = iota + branchNode + extNode + leafNode + hashedNode +) + +// Update inserts a (key, value) pair into the stack trie. +func (st *StackTrie) Update(key, value []byte) error { + k := keybytesToHex(key) + if len(value) == 0 { + panic("deletion not supported") + } + st.insert(k[:len(k)-1], value, nil) + return nil +} + +// MustUpdate is a wrapper of Update and will omit any encountered error but +// just print out an error message. +func (st *StackTrie) MustUpdate(key, value []byte) { + if err := st.Update(key, value); err != nil { + log.Error("Unhandled trie error in StackTrie.Update", "err", err) + } +} + +func (st *StackTrie) Reset() { + st.owner = common.Hash{} + st.writeFn = nil + st.key = st.key[:0] + st.val = nil + for i := range st.children { + st.children[i] = nil + } + st.nodeType = emptyNode +} + +// Helper function that, given a full key, determines the index +// at which the chunk pointed by st.keyOffset is different from +// the same chunk in the full key. 
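MarshalBinary and NewFromBinary above let a partially built trie be checkpointed and resumed later. A sketch of that round trip, again with hypothetical values:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/trie"
)

func main() {
	st := trie.NewStackTrie(nil)
	st.Update([]byte{0x01}, []byte("one"))

	blob, err := st.MarshalBinary() // checkpoint mid-build
	if err != nil {
		panic(err)
	}
	resumed, err := trie.NewFromBinary(blob, nil) // resume, still without a writer
	if err != nil {
		panic(err)
	}
	resumed.Update([]byte{0x02}, []byte("two"))
	fmt.Printf("root: %x\n", resumed.Hash()) // same root as an uninterrupted build
}
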
+func (st *StackTrie) getDiffIndex(key []byte) int {
+	for idx, nibble := range st.key {
+		if nibble != key[idx] {
+			return idx
+		}
+	}
+	return len(st.key)
+}
+
+// Helper function that inserts a (key, value) pair into
+// the trie.
+func (st *StackTrie) insert(key, value []byte, prefix []byte) {
+	switch st.nodeType {
+	case branchNode: /* Branch */
+		idx := int(key[0])
+
+		// Unresolve elder siblings
+		for i := idx - 1; i >= 0; i-- {
+			if st.children[i] != nil {
+				if st.children[i].nodeType != hashedNode {
+					st.children[i].hash(append(prefix, byte(i)))
+				}
+				break
+			}
+		}
+
+		// Add new child
+		if st.children[idx] == nil {
+			st.children[idx] = newLeaf(st.owner, key[1:], value, st.writeFn)
+		} else {
+			st.children[idx].insert(key[1:], value, append(prefix, key[0]))
+		}
+
+	case extNode: /* Ext */
+		// Compare both key chunks and see where they differ
+		diffidx := st.getDiffIndex(key)
+
+		// Check if chunks are identical. If so, recurse into
+		// the child node. Otherwise, the key has to be split
+		// into 1) an optional common prefix, 2) the fullnode
+		// representing the two differing paths, and 3) a leaf
+		// for each of the differentiated subtrees.
+		if diffidx == len(st.key) {
+			// Ext key and key segment are identical, recurse into
+			// the child node.
+			st.children[0].insert(key[diffidx:], value, append(prefix, key[:diffidx]...))
+			return
+		}
+		// Save the original part. Depending if the break is
+		// at the extension's last byte or not, create an
+		// intermediate extension or use the extension's child
+		// node directly.
+		var n *StackTrie
+		if diffidx < len(st.key)-1 {
+			// Break on the non-last byte, insert an intermediate
+			// extension. The path prefix of the newly-inserted
+			// extension should also contain the different byte.
+			n = newExt(st.owner, st.key[diffidx+1:], st.children[0], st.writeFn)
+			n.hash(append(prefix, st.key[:diffidx+1]...))
+		} else {
+			// Break on the last byte, no need to insert
+			// an extension node: reuse the current node.
+			// The path prefix of the original part should
+			// still be the same.
+			n = st.children[0]
+			n.hash(append(prefix, st.key...))
+		}
+		var p *StackTrie
+		if diffidx == 0 {
+			// the break is on the first byte, so
+			// the current node is converted into
+			// a branch node.
+			st.children[0] = nil
+			p = st
+			st.nodeType = branchNode
+		} else {
+			// the common prefix is at least one byte
+			// long, insert a new intermediate branch
+			// node.
+			st.children[0] = stackTrieFromPool(st.writeFn, st.owner)
+			st.children[0].nodeType = branchNode
+			p = st.children[0]
+		}
+		// Create a leaf for the inserted part
+		o := newLeaf(st.owner, key[diffidx+1:], value, st.writeFn)
+
+		// Insert both child leaves where they belong:
+		origIdx := st.key[diffidx]
+		newIdx := key[diffidx]
+		p.children[origIdx] = n
+		p.children[newIdx] = o
+		st.key = st.key[:diffidx]
+
+	case leafNode: /* Leaf */
+		// Compare both key chunks and see where they differ
+		diffidx := st.getDiffIndex(key)
+
+		// Overwriting a key isn't supported, which means that
+		// the current leaf is expected to be split into 1) an
+		// optional extension for the common prefix of these 2
+		// keys, 2) a fullnode selecting the path on which the
+		// keys differ, and 3) one leaf for the differentiated
+		// component of each key.
+		if diffidx >= len(st.key) {
+			panic("Trying to insert into existing key")
+		}
+
+		// Check if the split occurs at the first nibble of the
+		// chunk. In that case, no prefix extnode is necessary.
+ // Otherwise, create that + var p *StackTrie + if diffidx == 0 { + // Convert current leaf into a branch + st.nodeType = branchNode + p = st + st.children[0] = nil + } else { + // Convert current node into an ext, + // and insert a child branch node. + st.nodeType = extNode + st.children[0] = NewStackTrieWithOwner(st.writeFn, st.owner) + st.children[0].nodeType = branchNode + p = st.children[0] + } + + // Create the two child leaves: one containing the original + // value and another containing the new value. The child leaf + // is hashed directly in order to free up some memory. + origIdx := st.key[diffidx] + p.children[origIdx] = newLeaf(st.owner, st.key[diffidx+1:], st.val, st.writeFn) + p.children[origIdx].hash(append(prefix, st.key[:diffidx+1]...)) + + newIdx := key[diffidx] + p.children[newIdx] = newLeaf(st.owner, key[diffidx+1:], value, st.writeFn) + + // Finally, cut off the key part that has been passed + // over to the children. + st.key = st.key[:diffidx] + st.val = nil + + case emptyNode: /* Empty */ + st.nodeType = leafNode + st.key = key + st.val = value + + case hashedNode: + panic("trying to insert into hash") + + default: + panic("invalid type") + } +} + +// hash converts st into a 'hashedNode', if possible. Possible outcomes: +// +// 1. The rlp-encoded value was >= 32 bytes: +// - Then the 32-byte `hash` will be accessible in `st.val`. +// - And the 'st.type' will be 'hashedNode' +// +// 2. The rlp-encoded value was < 32 bytes +// - Then the <32 byte rlp-encoded value will be accessible in 'st.val'. +// - And the 'st.type' will be 'hashedNode' AGAIN +// +// This method also sets 'st.type' to hashedNode, and clears 'st.key'. +func (st *StackTrie) hash(path []byte) { + h := newHasher(false) + defer returnHasherToPool(h) + + st.hashRec(h, path) +} + +func (st *StackTrie) hashRec(hasher *hasher, path []byte) { + // The switch below sets this to the RLP-encoding of this node. + var encodedNode []byte + + switch st.nodeType { + case hashedNode: + return + + case emptyNode: + st.val = types.EmptyRootHash.Bytes() + st.key = st.key[:0] + st.nodeType = hashedNode + return + + case branchNode: + var nodes FullNode + for i, child := range st.children { + if child == nil { + nodes.Children[i] = nilValueNode + continue + } + child.hashRec(hasher, append(path, byte(i))) + if len(child.val) < 32 { + nodes.Children[i] = rawNode(child.val) + } else { + nodes.Children[i] = HashNode(child.val) + } + + // Release child back to pool. + st.children[i] = nil + returnToPool(child) + } + + nodes.encode(hasher.encbuf) + encodedNode = hasher.encodedBytes() + + case extNode: + st.children[0].hashRec(hasher, append(path, st.key...)) + + n := ShortNode{Key: hexToCompact(st.key)} + if len(st.children[0].val) < 32 { + n.Val = rawNode(st.children[0].val) + } else { + n.Val = HashNode(st.children[0].val) + } + + n.encode(hasher.encbuf) + encodedNode = hasher.encodedBytes() + + // Release child back to pool. + returnToPool(st.children[0]) + st.children[0] = nil + + case leafNode: + st.key = append(st.key, byte(16)) + n := ShortNode{Key: hexToCompact(st.key), Val: ValueNode(st.val)} + + n.encode(hasher.encbuf) + encodedNode = hasher.encodedBytes() + + default: + panic("invalid node type") + } + + st.nodeType = hashedNode + st.key = st.key[:0] + if len(encodedNode) < 32 { + st.val = common.CopyBytes(encodedNode) + return + } + + // Write the hash to the 'val'. 
We allocate a new val here to not mutate + // input values + st.val = hasher.hashData(encodedNode) + if st.writeFn != nil { + st.writeFn(st.owner, path, common.BytesToHash(st.val), encodedNode) + } +} + +// Hash returns the hash of the current node. +func (st *StackTrie) Hash() (h common.Hash) { + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + st.hashRec(hasher, nil) + if len(st.val) == 32 { + copy(h[:], st.val) + return h + } + // If the node's RLP isn't 32 bytes long, the node will not + // be hashed, and instead contain the rlp-encoding of the + // node. For the top level node, we need to force the hashing. + hasher.sha.Reset() + hasher.sha.Write(st.val) + hasher.sha.Read(h[:]) + return h +} + +// Commit will firstly hash the entire trie if it's still not hashed +// and then commit all nodes to the associated database. Actually most +// of the trie nodes MAY have been committed already. The main purpose +// here is to commit the root node. +// +// The associated database is expected, otherwise the whole commit +// functionality should be disabled. +func (st *StackTrie) Commit() (h common.Hash, err error) { + if st.writeFn == nil { + return common.Hash{}, ErrCommitDisabled + } + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + st.hashRec(hasher, nil) + if len(st.val) == 32 { + copy(h[:], st.val) + return h, nil + } + // If the node's RLP isn't 32 bytes long, the node will not + // be hashed (and committed), and instead contain the rlp-encoding of the + // node. For the top level node, we need to force the hashing+commit. + hasher.sha.Reset() + hasher.sha.Write(st.val) + hasher.sha.Read(h[:]) + + st.writeFn(st.owner, nil, h, st.val) + return h, nil +} diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go new file mode 100644 index 0000000000..7908290128 --- /dev/null +++ b/trie/stacktrie_test.go @@ -0,0 +1,392 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "bytes" + "math/big" + "testing" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/crypto" +) + +func TestStackTrieInsertAndHash(t *testing.T) { + type KeyValueHash struct { + K string // Hex string for key. + V string // Value, directly converted to bytes. + H string // Expected root hash after insert of (K, V) to an existing trie. 
+ } + tests := [][]KeyValueHash{ + { // {0:0, 7:0, f:0} + {"00", "v_______________________0___0", "5cb26357b95bb9af08475be00243ceb68ade0b66b5cd816b0c18a18c612d2d21"}, + {"70", "v_______________________0___1", "8ff64309574f7a437a7ad1628e690eb7663cfde10676f8a904a8c8291dbc1603"}, + {"f0", "v_______________________0___2", "9e3a01bd8d43efb8e9d4b5506648150b8e3ed1caea596f84ee28e01a72635470"}, + }, + { // {1:0cc, e:{1:fc, e:fc}} + {"10cc", "v_______________________1___0", "233e9b257843f3dfdb1cce6676cdaf9e595ac96ee1b55031434d852bc7ac9185"}, + {"e1fc", "v_______________________1___1", "39c5e908ae83d0c78520c7c7bda0b3782daf594700e44546e93def8f049cca95"}, + {"eefc", "v_______________________1___2", "d789567559fd76fe5b7d9cc42f3750f942502ac1c7f2a466e2f690ec4b6c2a7c"}, + }, + { // {b:{a:ac, b:ac}, d:acc} + {"baac", "v_______________________2___0", "8be1c86ba7ec4c61e14c1a9b75055e0464c2633ae66a055a24e75450156a5d42"}, + {"bbac", "v_______________________2___1", "8495159b9895a7d88d973171d737c0aace6fe6ac02a4769fff1bc43bcccce4cc"}, + {"dacc", "v_______________________2___2", "9bcfc5b220a27328deb9dc6ee2e3d46c9ebc9c69e78acda1fa2c7040602c63ca"}, + }, + { // {0:0cccc, 2:456{0:0, 2:2} + {"00cccc", "v_______________________3___0", "e57dc2785b99ce9205080cb41b32ebea7ac3e158952b44c87d186e6d190a6530"}, + {"245600", "v_______________________3___1", "0335354adbd360a45c1871a842452287721b64b4234dfe08760b243523c998db"}, + {"245622", "v_______________________3___2", "9e6832db0dca2b5cf81c0e0727bfde6afc39d5de33e5720bccacc183c162104e"}, + }, + { // {1:4567{1:1c, 3:3c}, 3:0cccccc} + {"1456711c", "v_______________________4___0", "f2389e78d98fed99f3e63d6d1623c1d4d9e8c91cb1d585de81fbc7c0e60d3529"}, + {"1456733c", "v_______________________4___1", "101189b3fab852be97a0120c03d95eefcf984d3ed639f2328527de6def55a9c0"}, + {"30cccccc", "v_______________________4___2", "3780ce111f98d15751dfde1eb21080efc7d3914b429e5c84c64db637c55405b3"}, + }, + { // 8800{1:f, 2:e, 3:d} + {"88001f", "v_______________________5___0", "e817db50d84f341d443c6f6593cafda093fc85e773a762421d47daa6ac993bd5"}, + {"88002e", "v_______________________5___1", "d6e3e6047bdc110edd296a4d63c030aec451bee9d8075bc5a198eee8cda34f68"}, + {"88003d", "v_______________________5___2", "b6bdf8298c703342188e5f7f84921a402042d0e5fb059969dd53a6b6b1fb989e"}, + }, + { // 0{1:fc, 2:ec, 4:dc} + {"01fc", "v_______________________6___0", "693268f2ca80d32b015f61cd2c4dba5a47a6b52a14c34f8e6945fad684e7a0d5"}, + {"02ec", "v_______________________6___1", "e24ddd44469310c2b785a2044618874bf486d2f7822603a9b8dce58d6524d5de"}, + {"04dc", "v_______________________6___2", "33fc259629187bbe54b92f82f0cd8083b91a12e41a9456b84fc155321e334db7"}, + }, + { // f{0:fccc, f:ff{0:f, f:f}} + {"f0fccc", "v_______________________7___0", "b0966b5aa469a3e292bc5fcfa6c396ae7a657255eef552ea7e12f996de795b90"}, + {"ffff0f", "v_______________________7___1", "3b1ca154ec2a3d96d8d77bddef0abfe40a53a64eb03cecf78da9ec43799fa3d0"}, + {"ffffff", "v_______________________7___2", "e75463041f1be8252781be0ace579a44ea4387bf5b2739f4607af676f7719678"}, + }, + { // ff{0:f{0:f, f:f}, f:fcc} + {"ff0f0f", "v_______________________8___0", "0928af9b14718ec8262ab89df430f1e5fbf66fac0fed037aff2b6767ae8c8684"}, + {"ff0fff", "v_______________________8___1", "d870f4d3ce26b0bf86912810a1960693630c20a48ba56be0ad04bc3e9ddb01e6"}, + {"ffffcc", "v_______________________8___2", "4239f10dd9d9915ecf2e047d6a576bdc1733ed77a30830f1bf29deaf7d8e966f"}, + }, + { + {"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"}, + 
{"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"}, + {"123f", "x___________________________2", "1164d7299964e74ac40d761f9189b2a3987fae959800d0f7e29d3aaf3eae9e15"}, + }, + { + {"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"}, + {"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"}, + {"124a", "x___________________________2", "661a96a669869d76b7231380da0649d013301425fbea9d5c5fae6405aa31cfce"}, + }, + { + {"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"}, + {"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"}, + {"13aa", "x___________________________2", "6590120e1fd3ffd1a90e8de5bb10750b61079bb0776cca4414dd79a24e4d4356"}, + }, + { + {"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"}, + {"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"}, + {"2aaa", "x___________________________2", "f869b40e0c55eace1918332ef91563616fbf0755e2b946119679f7ef8e44b514"}, + }, + { + {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"}, + {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"}, + {"1234fa", "x___________________________2", "4f4e368ab367090d5bc3dbf25f7729f8bd60df84de309b4633a6b69ab66142c0"}, + }, + { + {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"}, + {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"}, + {"1235aa", "x___________________________2", "21840121d11a91ac8bbad9a5d06af902a5c8d56a47b85600ba813814b7bfcb9b"}, + }, + { + {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"}, + {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"}, + {"124aaa", "x___________________________2", "ea4040ddf6ae3fbd1524bdec19c0ab1581015996262006632027fa5cf21e441e"}, + }, + { + {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"}, + {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"}, + {"13aaaa", "x___________________________2", "e4beb66c67e44f2dd8ba36036e45a44ff68f8d52942472b1911a45f886a34507"}, + }, + { + {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"}, + {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"}, + {"2aaaaa", "x___________________________2", "5f5989b820ff5d76b7d49e77bb64f26602294f6c42a1a3becc669cd9e0dc8ec9"}, + }, + { + {"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"}, + {"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"}, + {"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"}, + {"1234fa", "x___________________________3", "65bb3aafea8121111d693ffe34881c14d27b128fd113fa120961f251fe28428d"}, + }, + { + {"000000", 
"x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"}, + {"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"}, + {"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"}, + {"1235aa", "x___________________________3", "f670e4d2547c533c5f21e0045442e2ecb733f347ad6d29ef36e0f5ba31bb11a8"}, + }, + { + {"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"}, + {"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"}, + {"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"}, + {"124aaa", "x___________________________3", "c17464123050a9a6f29b5574bb2f92f6d305c1794976b475b7fb0316b6335598"}, + }, + { + {"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"}, + {"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"}, + {"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"}, + {"13aaaa", "x___________________________3", "aa8301be8cb52ea5cd249f5feb79fb4315ee8de2140c604033f4b3fff78f0105"}, + }, + { + {"0000", "x___________________________0", "cb8c09ad07ae882136f602b3f21f8733a9f5a78f1d2525a8d24d1c13258000b2"}, + {"123d", "x___________________________1", "8f09663deb02f08958136410dc48565e077f76bb6c9d8c84d35fc8913a657d31"}, + {"123e", "x___________________________2", "0d230561e398c579e09a9f7b69ceaf7d3970f5a436fdb28b68b7a37c5bdd6b80"}, + {"123f", "x___________________________3", "80f7bad1893ca57e3443bb3305a517723a74d3ba831bcaca22a170645eb7aafb"}, + }, + { + {"0000", "x___________________________0", "cb8c09ad07ae882136f602b3f21f8733a9f5a78f1d2525a8d24d1c13258000b2"}, + {"123d", "x___________________________1", "8f09663deb02f08958136410dc48565e077f76bb6c9d8c84d35fc8913a657d31"}, + {"123e", "x___________________________2", "0d230561e398c579e09a9f7b69ceaf7d3970f5a436fdb28b68b7a37c5bdd6b80"}, + {"124a", "x___________________________3", "383bc1bb4f019e6bc4da3751509ea709b58dd1ac46081670834bae072f3e9557"}, + }, + { + {"0000", "x___________________________0", "cb8c09ad07ae882136f602b3f21f8733a9f5a78f1d2525a8d24d1c13258000b2"}, + {"123d", "x___________________________1", "8f09663deb02f08958136410dc48565e077f76bb6c9d8c84d35fc8913a657d31"}, + {"123e", "x___________________________2", "0d230561e398c579e09a9f7b69ceaf7d3970f5a436fdb28b68b7a37c5bdd6b80"}, + {"13aa", "x___________________________3", "ff0dc70ce2e5db90ee42a4c2ad12139596b890e90eb4e16526ab38fa465b35cf"}, + }, + } + st := NewStackTrie(nil) + for i, test := range tests { + // The StackTrie does not allow Insert(), Hash(), Insert(), ... + // so we will create new trie for every sequence length of inserts. 
+	for i, test := range tests {
+		for l := 1; l <= len(test); l++ {
+			st.Reset()
+			for j := 0; j < l; j++ {
+				kv := &test[j]
+				if err := st.Update(common.FromHex(kv.K), []byte(kv.V)); err != nil {
+					t.Fatal(err)
+				}
+			}
+			expected := common.HexToHash(test[l-1].H)
+			if h := st.Hash(); h != expected {
+				t.Errorf("%d(%d): root hash mismatch: %x, expected %x", i, l, h, expected)
+			}
+		}
+	}
+}
+
+func TestSizeBug(t *testing.T) {
+	st := NewStackTrie(nil)
+	nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+
+	leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
+	value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
+
+	nt.Update(leaf, value)
+	st.Update(leaf, value)
+
+	if nt.Hash() != st.Hash() {
+		t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+	}
+}
+
+func TestEmptyBug(t *testing.T) {
+	st := NewStackTrie(nil)
+	nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+
+	//leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
+	//value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
+	kvs := []struct {
+		K string
+		V string
+	}{
+		{K: "405787fa12a823e0f2b7631cc41b3ba8828b3321ca811111fa75cd3aa3bb5ace", V: "9496f4ec2bf9dab484cac6be589e8417d84781be08"},
+		{K: "40edb63a35fcf86c08022722aa3287cdd36440d671b4918131b2514795fefa9c", V: "01"},
+		{K: "b10e2d527612073b26eecdfd717e6a320cf44b4afac2b0732d9fcbe2b7fa0cf6", V: "947a30f7736e48d6599356464ba4c150d8da0302ff"},
+		{K: "c2575a0e9e593c00f959f8c92f12db2869c3395a3b0502d05e2516446f71f85b", V: "02"},
+	}
+
+	for _, kv := range kvs {
+		nt.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+		st.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+	}
+
+	if nt.Hash() != st.Hash() {
+		t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+	}
+}
+
+func TestValLength56(t *testing.T) {
+	st := NewStackTrie(nil)
+	nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+
+	//leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
+	//value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
+	kvs := []struct {
+		K string
+		V string
+	}{
+		{K: "405787fa12a823e0f2b7631cc41b3ba8828b3321ca811111fa75cd3aa3bb5ace", V: "1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111"},
+	}
+
+	for _, kv := range kvs {
+		nt.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+		st.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+	}
+
+	if nt.Hash() != st.Hash() {
+		t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+	}
+}
+
+// TestUpdateSmallNodes tests a case where the leaves are small (both key and value),
+// which causes a lot of node-within-node. This case was found via fuzzing.
+func TestUpdateSmallNodes(t *testing.T) {
+	st := NewStackTrie(nil)
+	nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+	kvs := []struct {
+		K string
+		V string
+	}{
+		{"63303030", "3041"}, // stacktrie.Update
+		{"65", "3000"},       // stacktrie.Update
+	}
+	for _, kv := range kvs {
+		nt.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+		st.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+	}
+	if nt.Hash() != st.Hash() {
+		t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+	}
+}
+
+// TestUpdateVariableKeys contains a case in which the stacktrie fails: when keys of different
+// sizes are used, and the second one has the same prefix as the first, then the
+// stacktrie fails, since it's unable to 'expand' on an already added leaf.
+// For all practical purposes, this is fine, since keys are of fixed length
+// in account and storage tries.
+//
+// The test is marked as 'skipped', and exists just to have the behaviour documented.
+// This case was found via fuzzing.
+func TestUpdateVariableKeys(t *testing.T) {
+	t.SkipNow()
+	st := NewStackTrie(nil)
+	nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+	kvs := []struct {
+		K string
+		V string
+	}{
+		{"0x33303534636532393561313031676174", "303030"},
+		{"0x3330353463653239356131303167617430", "313131"},
+	}
+	for _, kv := range kvs {
+		nt.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+		st.Update(common.FromHex(kv.K), common.FromHex(kv.V))
+	}
+	if nt.Hash() != st.Hash() {
+		t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+	}
+}
+
+// TestStacktrieNotModifyValues checks that inserting blobs of data into the
+// stacktrie does not mutate the blobs.
+func TestStacktrieNotModifyValues(t *testing.T) {
+	st := NewStackTrie(nil)
+	{ // Test a very small trie
+		// Give it the value as a slice with large backing alloc,
+		// so if the stacktrie tries to append, it won't have to realloc
+		value := make([]byte, 1, 100)
+		value[0] = 0x2
+		want := common.CopyBytes(value)
+		st.Update([]byte{0x01}, value)
+		st.Hash()
+		if have := value; !bytes.Equal(have, want) {
+			t.Fatalf("tiny trie: have %#x want %#x", have, want)
+		}
+		st = NewStackTrie(nil)
+	}
+	// Test with a larger trie
+	keyB := big.NewInt(1)
+	keyDelta := big.NewInt(1)
+	var vals [][]byte
+	getValue := func(i int) []byte {
+		if i%2 == 0 { // large
+			return crypto.Keccak256(big.NewInt(int64(i)).Bytes())
+		} else { // small
+			return big.NewInt(int64(i)).Bytes()
+		}
+	}
+	for i := 0; i < 1000; i++ {
+		key := common.BigToHash(keyB)
+		value := getValue(i)
+		st.Update(key.Bytes(), value)
+		vals = append(vals, value)
+		keyB = keyB.Add(keyB, keyDelta)
+		keyDelta.Add(keyDelta, common.Big1)
+	}
+	st.Hash()
+	for i := 0; i < 1000; i++ {
+		want := getValue(i)
+
+		have := vals[i]
+		if !bytes.Equal(have, want) {
+			t.Fatalf("item %d, have %#x want %#x", i, have, want)
+		}
+	}
+}
+
+// TestStacktrieSerialization tests that the stacktrie works well if we
+// serialize/deserialize it a lot
+func TestStacktrieSerialization(t *testing.T) {
+	var (
+		st       = NewStackTrie(nil)
+		nt       = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+		keyB     = big.NewInt(1)
+		keyDelta = big.NewInt(1)
+		vals     [][]byte
+		keys     [][]byte
+	)
+	getValue := func(i int) []byte {
+		if i%2 == 0 { // large
+			return crypto.Keccak256(big.NewInt(int64(i)).Bytes())
+		} else { // small
+			return big.NewInt(int64(i)).Bytes()
+		}
+	}
+	for i := 0; i < 10; i++ {
+		vals = append(vals, getValue(i))
+		keys = append(keys, common.BigToHash(keyB).Bytes())
+		keyB = keyB.Add(keyB, keyDelta)
+		keyDelta.Add(keyDelta, common.Big1)
+	}
+	for i, k := range keys {
+		nt.Update(k, common.CopyBytes(vals[i]))
+	}
+
+	for i, k := range keys {
+		blob, err := st.MarshalBinary()
+		if err != nil {
+			t.Fatal(err)
+		}
+		newSt, err := NewFromBinary(blob, nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		st = newSt
+		st.Update(k, common.CopyBytes(vals[i]))
+	}
+	if have, want := st.Hash(), nt.Hash(); have != want {
+		t.Fatalf("have %#x want %#x", have, want)
+	}
+}

From 2e9abc43d5c689eec15eac752c9db9f8c440fd94 Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Sun, 16 Jul 2023 20:22:40 +0700
Subject: [PATCH 025/119] Add encode method for rawShortNode and rawFullNode

---
 trie/hasher.go   | 11 +++++++++++
 trie/node_enc.go |  8 ++++++++
 2 files changed, 19 insertions(+)

diff --git a/trie/hasher.go b/trie/hasher.go
index 0cda29e3f5..d4a36dd5ed 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -23,6 +23,17 @@ import ( "github.com/tomochain/tomochain/rlp" ) +type sliceBuffer []byte + +func (b *sliceBuffer) Write(data []byte) (n int, err error) { + *b = append(*b, data...) + return len(data), nil +} + +func (b *sliceBuffer) Reset() { + *b = (*b)[:0] +} + // hasher is a type used for the trie Hash operation. A hasher has some // internal preallocated temp space type hasher struct { diff --git a/trie/node_enc.go b/trie/node_enc.go index b5b0660f2d..b987abfbf5 100644 --- a/trie/node_enc.go +++ b/trie/node_enc.go @@ -62,3 +62,11 @@ func (n ValueNode) encode(w rlp.EncoderBuffer) { func (n rawNode) encode(w rlp.EncoderBuffer) { w.Write(n) } + +func (n rawShortNode) encode(w rlp.EncoderBuffer) { + panic("this should never end up in a live trie") +} + +func (n rawFullNode) encode(w rlp.EncoderBuffer) { + panic("this should never end up in a live trie") +} From 008274fc410b37af86e7a9d7164e83c701459fe8 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Sun, 16 Jul 2023 20:25:46 +0700 Subject: [PATCH 026/119] Fix import cycle --- consensus/posv/posv.go | 8 ++++---- trie/stacktrie.go | 9 ++++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/consensus/posv/posv.go b/consensus/posv/posv.go index 0027104970..f2b48fde93 100644 --- a/consensus/posv/posv.go +++ b/consensus/posv/posv.go @@ -21,9 +21,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending/lendingstate" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" "io/ioutil" "math/big" "math/rand" @@ -50,6 +47,9 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) const ( @@ -1146,7 +1146,7 @@ func (c *Posv) CacheData(header *types.Header, txs []*types.Transaction, receipt signTxs := []*types.Transaction{} for _, tx := range txs { if tx.IsSigningTransaction() { - var b uint + var b uint64 for _, r := range receipts { if r.TxHash == tx.Hash() { if len(r.PostState) > 0 { diff --git a/trie/stacktrie.go b/trie/stacktrie.go index 78da41f7ea..f640a61085 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -25,11 +25,14 @@ import ( "sync" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/log" ) -var ErrCommitDisabled = errors.New("no database for committing") +var ( + emptyRootHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") + + ErrCommitDisabled = errors.New("no database for committing") +) var stPool = sync.Pool{ New: func() interface{} { @@ -414,7 +417,7 @@ func (st *StackTrie) hashRec(hasher *hasher, path []byte) { return case emptyNode: - st.val = types.EmptyRootHash.Bytes() + st.val = emptyRootHash.Bytes() st.key = st.key[:0] st.nodeType = hashedNode return From f88dba48b80e6186ead7956ec6d1c0c2da962604 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Sun, 16 Jul 2023 20:34:48 +0700 Subject: [PATCH 027/119] Fix stacktrie unit tests --- accounts/abi/bind/backends/simulated.go | 7 +++--- trie/stacktrie.go | 8 ++---- trie/stacktrie_test.go | 33 ++++++++++++++++++++----- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/accounts/abi/bind/backends/simulated.go 
b/accounts/abi/bind/backends/simulated.go index 7411f492a8..6f4039654e 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -20,8 +20,6 @@ import ( "context" "errors" "fmt" - "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "sync" "time" @@ -30,9 +28,11 @@ import ( "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -202,7 +202,7 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call tomochain.Call return rval, err } -//FIXME: please use copyState for this function +// FIXME: please use copyState for this function // CallContractWithState executes a contract call at the given state. func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) { // Ensure message is initialized properly. @@ -285,7 +285,6 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM snapshot := b.pendingState.Snapshot() _, _, failed, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) - fmt.Println("EstimateGas",err,failed) b.pendingState.RevertToSnapshot(snapshot) if err != nil || failed { diff --git a/trie/stacktrie.go b/trie/stacktrie.go index f640a61085..48417e556c 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -28,11 +28,7 @@ import ( "github.com/tomochain/tomochain/log" ) -var ( - emptyRootHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") - - ErrCommitDisabled = errors.New("no database for committing") -) +var ErrCommitDisabled = errors.New("no database for committing") var stPool = sync.Pool{ New: func() interface{} { @@ -417,7 +413,7 @@ func (st *StackTrie) hashRec(hasher *hasher, path []byte) { return case emptyNode: - st.val = emptyRootHash.Bytes() + st.val = emptyRoot.Bytes() st.key = st.key[:0] st.nodeType = hashedNode return diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go index 7908290128..dd5206c87c 100644 --- a/trie/stacktrie_test.go +++ b/trie/stacktrie_test.go @@ -188,7 +188,10 @@ func TestStackTrieInsertAndHash(t *testing.T) { func TestSizeBug(t *testing.T) { st := NewStackTrie(nil) - nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase())) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") @@ -203,7 +206,10 @@ func TestSizeBug(t *testing.T) { func TestEmptyBug(t *testing.T) { st := NewStackTrie(nil) - nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase())) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } //leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") //value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") @@ -229,7 +235,10 @@ func TestEmptyBug(t *testing.T) { func TestValLength56(t 
*testing.T) { st := NewStackTrie(nil) - nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase())) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } //leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") //value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") @@ -254,7 +263,11 @@ func TestValLength56(t *testing.T) { // which causes a lot of node-within-node. This case was found via fuzzing. func TestUpdateSmallNodes(t *testing.T) { st := NewStackTrie(nil) - nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase())) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + kvs := []struct { K string V string @@ -282,7 +295,11 @@ func TestUpdateSmallNodes(t *testing.T) { func TestUpdateVariableKeys(t *testing.T) { t.SkipNow() st := NewStackTrie(nil) - nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase())) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + kvs := []struct { K string V string @@ -351,12 +368,16 @@ func TestStacktrieNotModifyValues(t *testing.T) { func TestStacktrieSerialization(t *testing.T) { var ( st = NewStackTrie(nil) - nt = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) keyB = big.NewInt(1) keyDelta = big.NewInt(1) vals [][]byte keys [][]byte ) + nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase())) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + getValue := func(i int) []byte { if i%2 == 0 { // large return crypto.Keccak256(big.NewInt(int64(i)).Bytes()) From af788c6184bbaa16efb2392f07321659fb3101eb Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Sun, 16 Jul 2023 20:48:00 +0700 Subject: [PATCH 028/119] Convert status of receipts from uint to uint64 --- accounts/abi/bind/backends/simulated.go | 7 +++---- consensus/posv/posv.go | 8 ++++---- core/types/gen_log_rlp.go | 26 ------------------------- 3 files changed, 7 insertions(+), 34 deletions(-) delete mode 100644 core/types/gen_log_rlp.go diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 7411f492a8..6f4039654e 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -20,8 +20,6 @@ import ( "context" "errors" "fmt" - "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "sync" "time" @@ -30,9 +28,11 @@ import ( "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -202,7 +202,7 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call tomochain.Call return rval, err } -//FIXME: please use copyState for this function +// FIXME: please use copyState for this function // CallContractWithState executes a contract call at the given state. 
func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) { // Ensure message is initialized properly. @@ -285,7 +285,6 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM snapshot := b.pendingState.Snapshot() _, _, failed, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) - fmt.Println("EstimateGas",err,failed) b.pendingState.RevertToSnapshot(snapshot) if err != nil || failed { diff --git a/consensus/posv/posv.go b/consensus/posv/posv.go index 0027104970..f2b48fde93 100644 --- a/consensus/posv/posv.go +++ b/consensus/posv/posv.go @@ -21,9 +21,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending/lendingstate" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" "io/ioutil" "math/big" "math/rand" @@ -50,6 +47,9 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) const ( @@ -1146,7 +1146,7 @@ func (c *Posv) CacheData(header *types.Header, txs []*types.Transaction, receipt signTxs := []*types.Transaction{} for _, tx := range txs { if tx.IsSigningTransaction() { - var b uint + var b uint64 for _, r := range receipts { if r.TxHash == tx.Hash() { if len(r.PostState) > 0 { diff --git a/core/types/gen_log_rlp.go b/core/types/gen_log_rlp.go deleted file mode 100644 index 3f2c3ddc06..0000000000 --- a/core/types/gen_log_rlp.go +++ /dev/null @@ -1,26 +0,0 @@ -// Code generated by rlpgen. DO NOT EDIT. 
- -//go:build !norlpgen -// +build !norlpgen - -package types - -import ( - "io" - - "github.com/tomochain/tomochain/rlp" -) - -func (obj *rlpLog) EncodeRLP(_w io.Writer) error { - w := rlp.NewEncoderBuffer(_w) - _tmp0 := w.List() - w.WriteBytes(obj.Address[:]) - _tmp1 := w.List() - for _, _tmp2 := range obj.Topics { - w.WriteBytes(_tmp2[:]) - } - w.ListEnd(_tmp1) - w.WriteBytes(obj.Data) - w.ListEnd(_tmp0) - return w.Flush() -} From b440640a18e0a8dd29d3850de1fd25d2ef0deb9c Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Sun, 16 Jul 2023 20:52:33 +0700 Subject: [PATCH 029/119] Fix unit tests --- core/bench_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/bench_test.go b/core/bench_test.go index 137b57f031..c217b2d946 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -18,7 +18,6 @@ package core import ( "crypto/ecdsa" - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "os" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus/ethash" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" @@ -294,7 +294,7 @@ func benchReadChain(b *testing.B, full bool, count uint64) { if full { hash := header.Hash() GetBody(db, hash, n) - GetBlockReceipts(db, hash, n) + GetBlockReceipts(db, hash, n, chain.Config()) } } From e0f71f51864b776d31e6fe0055fc8a2e333835b2 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Sun, 16 Jul 2023 21:20:29 +0700 Subject: [PATCH 030/119] Fix unit tests --- core/database_util_test.go | 65 ++++++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 13 deletions(-) diff --git a/core/database_util_test.go b/core/database_util_test.go index 19cbf790fb..5e01e274f8 100644 --- a/core/database_util_test.go +++ b/core/database_util_test.go @@ -18,6 +18,8 @@ package core import ( "bytes" + "encoding/hex" + "fmt" "math/big" "testing" @@ -337,6 +339,13 @@ func TestLookupStorage(t *testing.T) { func TestBlockReceiptStorage(t *testing.T) { db := rawdb.NewMemoryDatabase() + // Create a live block since we need metadata to reconstruct the receipt + tx1 := types.NewTransaction(1, common.HexToAddress("0x1"), big.NewInt(1), 1, big.NewInt(1), nil) + tx2 := types.NewTransaction(2, common.HexToAddress("0x2"), big.NewInt(2), 2, big.NewInt(2), nil) + + body := &types.Body{Transactions: types.Transactions{tx1, tx2}} + + // Create the two receipts to manage afterwards receipt1 := &types.Receipt{ Status: types.ReceiptStatusFailed, CumulativeGasUsed: 1, @@ -344,10 +353,12 @@ func TestBlockReceiptStorage(t *testing.T) { {Address: common.BytesToAddress([]byte{0x11})}, {Address: common.BytesToAddress([]byte{0x01, 0x11})}, }, - TxHash: common.BytesToHash([]byte{0x11, 0x11}), + TxHash: tx1.Hash(), ContractAddress: common.BytesToAddress([]byte{0x01, 0x11, 0x11}), GasUsed: 111111, } + receipt1.Bloom = types.CreateBloom(types.Receipts{receipt1}) + receipt2 := &types.Receipt{ PostState: common.Hash{2}.Bytes(), CumulativeGasUsed: 2, @@ -355,10 +366,11 @@ func TestBlockReceiptStorage(t *testing.T) { {Address: common.BytesToAddress([]byte{0x22})}, {Address: common.BytesToAddress([]byte{0x02, 0x22})}, }, - TxHash: common.BytesToHash([]byte{0x22, 0x22}), + TxHash: tx2.Hash(), ContractAddress: common.BytesToAddress([]byte{0x02, 0x22, 0x22}), GasUsed: 222222, } + receipt2.Bloom = types.CreateBloom(types.Receipts{receipt2}) 
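+
+	// The blooms are filled in up-front: the read path is assumed to
+	// recompute them from the stored logs when deriving receipt metadata,
+	// and the RLP comparison below covers all consensus fields, so the
+	// in-memory receipts must carry matching blooms.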
receipts := []*types.Receipt{receipt1, receipt2} // Check that no receipt entries are in a pristine database @@ -366,25 +378,52 @@ func TestBlockReceiptStorage(t *testing.T) { if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) != 0 { t.Fatalf("non existent receipts returned: %v", rs) } + // Insert the body that corresponds to the receipts + WriteBody(db, hash, 0, body) + // Insert the receipt slice into the database and check presence - if err := WriteBlockReceipts(db, hash, 0, receipts); err != nil { - t.Fatalf("failed to write block receipts: %v", err) - } + WriteBlockReceipts(db, hash, 0, receipts) if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) == 0 { t.Fatalf("no receipts returned") } else { - for i := 0; i < len(receipts); i++ { - rlpHave, _ := rlp.EncodeToBytes(rs[i]) - rlpWant, _ := rlp.EncodeToBytes(receipts[i]) - - if !bytes.Equal(rlpHave, rlpWant) { - t.Fatalf("receipt #%d: receipt mismatch: have %v, want %v", i, rs[i], receipts[i]) - } + if err := checkReceiptsRLP(rs, receipts); err != nil { + t.Fatalf(err.Error()) } } - // Delete the receipt slice and check purge + // Delete the body and ensure that the receipts are no longer returned (metadata can't be recomputed) + DeleteBody(db, hash, 0) + if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); rs != nil { + t.Fatalf("receipts returned when body was deleted: %v", rs) + } + // Ensure that receipts without metadata can be returned without the block body too + if err := checkReceiptsRLP(ReadRawReceipts(db, hash, 0), receipts); err != nil { + t.Fatalf(err.Error()) + } + // Sanity check that body alone without the receipt is a full purge + WriteBody(db, hash, 0, body) + DeleteBlockReceipts(db, hash, 0) if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) != 0 { t.Fatalf("deleted receipts returned: %v", rs) } } + +func checkReceiptsRLP(have, want types.Receipts) error { + if len(have) != len(want) { + return fmt.Errorf("receipts sizes mismatch: have %d, want %d", len(have), len(want)) + } + for i := 0; i < len(want); i++ { + rlpHave, err := rlp.EncodeToBytes(have[i]) + if err != nil { + return err + } + rlpWant, err := rlp.EncodeToBytes(want[i]) + if err != nil { + return err + } + if !bytes.Equal(rlpHave, rlpWant) { + return fmt.Errorf("receipt #%d: receipt mismatch: have %s, want %s", i, hex.EncodeToString(rlpHave), hex.EncodeToString(rlpWant)) + } + } + return nil +} From 986a3c49b90513d4b05bb9ccf31cc64423d491a7 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 17 Jul 2023 10:35:55 +0700 Subject: [PATCH 031/119] Remove irrelevant code and fix unit tests --- core/rawdb/database.go | 154 ------------------------------- core/state/managed_state_test.go | 2 +- core/state/statedb.go | 82 +++------------- core/state/statedb_test.go | 24 +++-- metrics/metrics.go | 7 +- 5 files changed, 36 insertions(+), 233 deletions(-) diff --git a/core/rawdb/database.go b/core/rawdb/database.go index f46f7ec8e7..cf80d12d0d 100644 --- a/core/rawdb/database.go +++ b/core/rawdb/database.go @@ -17,17 +17,12 @@ package rawdb import ( - "bytes" "fmt" - "os" - "time" - "github.com/olekukonko/tablewriter" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/ethdb/leveldb" "github.com/tomochain/tomochain/ethdb/memorydb" - "github.com/tomochain/tomochain/log" ) // freezerdb is a database wrapper that enabled freezer data retrievals. 
@@ -145,152 +140,3 @@ func (s *stat) Size() string { func (s *stat) Count() string { return s.count.String() } - -// InspectDatabase traverses the entire database and checks the size -// of all different categories of data. -func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { - it := db.NewIterator(keyPrefix, keyStart) - defer it.Release() - - var ( - count int64 - start = time.Now() - logged = time.Now() - - // Key-value store statistics - headers stat - bodies stat - receipts stat - tds stat - numHashPairings stat - hashNumPairings stat - tries stat - codes stat - txLookups stat - accountSnaps stat - storageSnaps stat - preimages stat - bloomBits stat - cliqueSnaps stat - - // Ancient store statistics - ancientHeadersSize common.StorageSize - ancientBodiesSize common.StorageSize - ancientReceiptsSize common.StorageSize - ancientTdsSize common.StorageSize - ancientHashesSize common.StorageSize - - // Les statistic - chtTrieNodes stat - bloomTrieNodes stat - - // Meta- and unaccounted data - metadata stat - unaccounted stat - - // Totals - total common.StorageSize - ) - // Inspect key-value database first. - for it.Next() { - var ( - key = it.Key() - size = common.StorageSize(len(key) + len(it.Value())) - ) - total += size - switch { - case bytes.HasPrefix(key, headerPrefix) && len(key) == (len(headerPrefix)+8+common.HashLength): - headers.Add(size) - case bytes.HasPrefix(key, blockBodyPrefix) && len(key) == (len(blockBodyPrefix)+8+common.HashLength): - bodies.Add(size) - case bytes.HasPrefix(key, blockReceiptsPrefix) && len(key) == (len(blockReceiptsPrefix)+8+common.HashLength): - receipts.Add(size) - case bytes.HasPrefix(key, headerPrefix) && bytes.HasSuffix(key, headerTDSuffix): - tds.Add(size) - case bytes.HasPrefix(key, headerPrefix) && bytes.HasSuffix(key, headerHashSuffix): - numHashPairings.Add(size) - case bytes.HasPrefix(key, headerNumberPrefix) && len(key) == (len(headerNumberPrefix)+common.HashLength): - hashNumPairings.Add(size) - case len(key) == common.HashLength: - tries.Add(size) - // case bytes.HasPrefix(key, codePrefix) && len(key) == len(codePrefix)+common.HashLength: - // codes.Add(size) - case bytes.HasPrefix(key, txLookupPrefix) && len(key) == (len(txLookupPrefix)+common.HashLength): - txLookups.Add(size) - case bytes.HasPrefix(key, SnapshotAccountPrefix) && len(key) == (len(SnapshotAccountPrefix)+common.HashLength): - accountSnaps.Add(size) - case bytes.HasPrefix(key, SnapshotStoragePrefix) && len(key) == (len(SnapshotStoragePrefix)+2*common.HashLength): - storageSnaps.Add(size) - case bytes.HasPrefix(key, preimagePrefix) && len(key) == (len(preimagePrefix)+common.HashLength): - preimages.Add(size) - case bytes.HasPrefix(key, bloomBitsPrefix) && len(key) == (len(bloomBitsPrefix)+10+common.HashLength): - bloomBits.Add(size) - case bytes.HasPrefix(key, []byte("clique-")) && len(key) == 7+common.HashLength: - cliqueSnaps.Add(size) - case bytes.HasPrefix(key, []byte("cht-")) && len(key) == 4+common.HashLength: - chtTrieNodes.Add(size) - case bytes.HasPrefix(key, []byte("blt-")) && len(key) == 4+common.HashLength: - bloomTrieNodes.Add(size) - default: - var accounted bool - for _, meta := range [][]byte{databaseVerisionKey, headHeaderKey, headBlockKey, headFastBlockKey, fastTrieProgressKey} { - if bytes.Equal(key, meta) { - metadata.Add(size) - accounted = true - break - } - } - if !accounted { - unaccounted.Add(size) - } - } - count += 1 - if count%1000 == 0 && time.Since(logged) > 8*time.Second { - log.Info("Inspecting database", "count", count, 
"elapsed", common.PrettyDuration(time.Since(start))) - logged = time.Now() - } - } - // Inspect append-only file store then. - ancientSizes := []*common.StorageSize{&ancientHeadersSize, &ancientBodiesSize, &ancientReceiptsSize, &ancientHashesSize, &ancientTdsSize} - for i, category := range []string{freezerHeaderTable, freezerBodiesTable, freezerReceiptTable, freezerHashTable, freezerDifficultyTable} { - if size, err := db.AncientSize(category); err == nil { - *ancientSizes[i] += common.StorageSize(size) - total += common.StorageSize(size) - } - } - // Display the database statistic. - stats := [][]string{ - {"Key-Value store", "Headers", headers.Size(), headers.Count()}, - {"Key-Value store", "Bodies", bodies.Size(), bodies.Count()}, - {"Key-Value store", "Receipt lists", receipts.Size(), receipts.Count()}, - {"Key-Value store", "Difficulties", tds.Size(), tds.Count()}, - {"Key-Value store", "Block number->hash", numHashPairings.Size(), numHashPairings.Count()}, - {"Key-Value store", "Block hash->number", hashNumPairings.Size(), hashNumPairings.Count()}, - {"Key-Value store", "Transaction index", txLookups.Size(), txLookups.Count()}, - {"Key-Value store", "Bloombit index", bloomBits.Size(), bloomBits.Count()}, - {"Key-Value store", "Contract codes", codes.Size(), codes.Count()}, - {"Key-Value store", "Trie nodes", tries.Size(), tries.Count()}, - {"Key-Value store", "Trie preimages", preimages.Size(), preimages.Count()}, - {"Key-Value store", "Account snapshot", accountSnaps.Size(), accountSnaps.Count()}, - {"Key-Value store", "Storage snapshot", storageSnaps.Size(), storageSnaps.Count()}, - {"Key-Value store", "Clique snapshots", cliqueSnaps.Size(), cliqueSnaps.Count()}, - {"Key-Value store", "Singleton metadata", metadata.Size(), metadata.Count()}, - // {"Ancient store", "Headers", ancientHeadersSize.String(), ancients.String()}, - // {"Ancient store", "Bodies", ancientBodiesSize.String(), ancients.String()}, - // {"Ancient store", "Receipt lists", ancientReceiptsSize.String(), ancients.String()}, - // {"Ancient store", "Difficulties", ancientTdsSize.String(), ancients.String()}, - // {"Ancient store", "Block number->hash", ancientHashesSize.String(), ancients.String()}, - {"Light client", "CHT trie nodes", chtTrieNodes.Size(), chtTrieNodes.Count()}, - {"Light client", "Bloom trie nodes", bloomTrieNodes.Size(), bloomTrieNodes.Count()}, - } - table := tablewriter.NewWriter(os.Stdout) - table.SetHeader([]string{"Database", "Category", "Size", "Items"}) - table.SetFooter([]string{"", "Total", total.String(), " "}) - table.AppendBulk(stats) - table.Render() - - if unaccounted.size > 0 { - log.Error("Database contains unaccounted data", "size", unaccounted) - } - return nil -} diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go index 79220dc077..46deebdfc2 100644 --- a/core/state/managed_state_test.go +++ b/core/state/managed_state_test.go @@ -17,10 +17,10 @@ package state import ( - "github.com/tomochain/tomochain/core/rawdb" "testing" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" ) var addr = common.BytesToAddress([]byte("test")) diff --git a/core/state/statedb.go b/core/state/statedb.go index 822a417e37..644d80c703 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -25,7 +25,6 @@ import ( "time" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core/state/snapshot" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/log" 
@@ -59,8 +58,6 @@ type StateDB struct { db Database trie Trie - snaps *snapshot.Tree - snap snapshot.Snapshot snapDestructs map[common.Hash]struct{} snapAccounts map[common.Hash][]byte snapStorage map[common.Hash]map[common.Hash][]byte @@ -126,7 +123,7 @@ func (self *StateDB) GetCommittedState(addr common.Address, hash common.Hash) co } // Create a new state from a given trie. -func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { +func New(root common.Hash, db Database) (*StateDB, error) { tr, err := db.OpenTrie(root) if err != nil { return nil, err @@ -134,20 +131,12 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) sdb := &StateDB{ db: db, trie: tr, - snaps: snaps, stateObjects: make(map[common.Address]*stateObject), stateObjectsDirty: make(map[common.Address]struct{}), logs: make(map[common.Hash][]*types.Log), preimages: make(map[common.Hash][]byte), journal: newJournal(), } - if sdb.snaps != nil { - if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil { - sdb.snapDestructs = make(map[common.Hash]struct{}) - sdb.snapAccounts = make(map[common.Hash][]byte) - sdb.snapStorage = make(map[common.Hash]map[common.Hash][]byte) - } - } return sdb, nil } @@ -179,15 +168,6 @@ func (self *StateDB) Reset(root common.Hash) error { self.logSize = 0 self.preimages = make(map[common.Hash][]byte) self.clearJournalAndRefund() - - if self.snaps != nil { - self.snapAccounts, self.snapDestructs, self.snapStorage = nil, nil, nil - if self.snap = self.snaps.Snapshot(root); self.snap != nil { - self.snapDestructs = make(map[common.Hash]struct{}) - self.snapAccounts = make(map[common.Hash][]byte) - self.snapStorage = make(map[common.Hash]map[common.Hash][]byte) - } - } return nil } @@ -409,15 +389,6 @@ func (s *StateDB) updateStateObject(obj *stateObject) { panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err)) } s.setError(s.trie.TryUpdate(addr[:], data)) - - // If state snapshotting is active, cache the data til commit. Note, this - // update mechanism is not symmetric to the deletion, because whereas it is - // enough to track account updates at commit time, deletions need tracking - // at transaction boundary level to ensure we capture state clearing. - if s.snap != nil { - s.snapAccounts[obj.addrHash] = snapshot.AccountRLP(obj.data.Nonce, obj.data.Balance, obj.data.Root, obj.data.CodeHash) - } - } // deleteStateObject removes the given object from the state trie. 
@@ -471,44 +442,18 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject { if obj := s.stateObjects[addr]; obj != nil { return obj } - // If no live objects are available, attempt to use snapshots - var ( - data Account - err error - ) - if s.snap != nil { - if metrics.EnabledExpensive { - defer func(start time.Time) { s.SnapshotAccountReads += time.Since(start) }(time.Now()) - } - var acc *snapshot.Account - if acc, err = s.snap.Account(crypto.Keccak256Hash(addr[:])); err == nil { - if acc == nil { - return nil - } - data.Nonce, data.Balance, data.CodeHash = acc.Nonce, acc.Balance, acc.CodeHash - if len(data.CodeHash) == 0 { - data.CodeHash = emptyCodeHash - } - data.Root = common.BytesToHash(acc.Root) - if data.Root == (common.Hash{}) { - data.Root = emptyRoot - } - } + var data Account + if metrics.EnabledExpensive { + defer func(start time.Time) { s.AccountReads += time.Since(start) }(time.Now()) } - // If snapshot unavailable or reading from it failed, load from the database - if s.snap == nil || err != nil { - if metrics.EnabledExpensive { - defer func(start time.Time) { s.AccountReads += time.Since(start) }(time.Now()) - } - enc, err := s.trie.TryGet(addr[:]) - if len(enc) == 0 { - s.setError(err) - return nil - } - if err := rlp.DecodeBytes(enc, &data); err != nil { - log.Error("Failed to decode state object", "addr", addr, "err", err) - return nil - } + enc, err := s.trie.TryGet(addr[:]) + if len(enc) == 0 { + s.setError(err) + return nil + } + if err := rlp.DecodeBytes(enc, &data); err != nil { + log.Error("Failed to decode state object", "addr", addr, "err", err) + return nil } // Insert into the live set obj := newObject(s, addr, data) @@ -599,6 +544,7 @@ func (self *StateDB) Copy() *StateDB { logs: make(map[common.Hash][]*types.Log, len(self.logs)), logSize: self.logSize, preimages: make(map[common.Hash][]byte), + journal: newJournal(), } // Copy the dirty states, logs, and preimages for addr := range self.journal.dirties { @@ -712,7 +658,7 @@ func (s *StateDB) clearJournalAndRefund() { func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) { defer s.clearJournalAndRefund() - for addr := range s.journal.dirties { + for addr := range s.journal.dirties { s.stateObjectsDirty[addr] = struct{}{} } diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index ce44d00f80..b87bc6685a 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -41,7 +41,7 @@ import ( func TestUpdateLeaks(t *testing.T) { // Create an empty state database db := rawdb.NewMemoryDatabase() - state, _ := New(common.Hash{}, NewDatabase(db), nil) + state, _ := New(common.Hash{}, NewDatabase(db)) // Update it with some accounts for i := byte(0); i < 255; i++ { @@ -71,8 +71,8 @@ func TestIntermediateLeaks(t *testing.T) { // Create two state databases, one transitioning to the final state, the other final from the beginning transDb := rawdb.NewMemoryDatabase() finalDb := rawdb.NewMemoryDatabase() - transState, _ := New(common.Hash{}, NewDatabase(transDb), nil) - finalState, _ := New(common.Hash{}, NewDatabase(finalDb), nil) + transState, _ := New(common.Hash{}, NewDatabase(transDb)) + finalState, _ := New(common.Hash{}, NewDatabase(finalDb)) modify := func(state *StateDB, addr common.Address, i, tweak byte) { state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) @@ -130,7 +130,7 @@ func TestIntermediateLeaks(t *testing.T) { func TestCopy(t *testing.T) { // Create a random state test to copy and modify "independently" db := 
rawdb.NewMemoryDatabase() - orig, _ := New(common.Hash{}, NewDatabase(db), nil) + orig, _ := New(common.Hash{}, NewDatabase(db)) for i := byte(0); i < 255; i++ { obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) @@ -342,7 +342,7 @@ func (test *snapshotTest) run() bool { // Run all actions and create snapshots. var ( db = rawdb.NewMemoryDatabase() - state, _ = New(common.Hash{}, NewDatabase(db), nil) + state, _ = New(common.Hash{}, NewDatabase(db)) snapshotRevs = make([]int, len(test.snapshots)) sindex = 0 ) @@ -356,7 +356,7 @@ func (test *snapshotTest) run() bool { // Revert all snapshots in reverse order. Each revert must yield a state // that is equivalent to fresh state with all actions up the snapshot applied. for sindex--; sindex >= 0; sindex-- { - checkstate, _ := New(common.Hash{}, state.Database(), nil) + checkstate, _ := New(common.Hash{}, state.Database()) for _, action := range test.actions[:test.snapshots[sindex]] { action.fn(action, checkstate) } @@ -416,15 +416,21 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { func (s *StateSuite) TestTouchDelete(c *check.C) { s.state.GetOrNewStateObject(common.Address{}) root, _ := s.state.Commit(false) - s.state.Reset(root) + s.state, _ = New(root, s.state.db) snapshot := s.state.Snapshot() s.state.AddBalance(common.Address{}, new(big.Int)) - if len(s.state.stateObjectsDirty) != 1 { + if len(s.state.journal.dirties) != 1 { + c.Fatal("expected one dirty state object") + } + if s.state.journal.dirties[common.Address{}] != 1 { c.Fatal("expected one dirty state object") } s.state.RevertToSnapshot(snapshot) - if len(s.state.stateObjectsDirty) != 0 { + if len(s.state.journal.dirties) != 0 { + c.Fatal("expected no dirty state object") + } + if s.state.journal.dirties[common.Address{}] != 0 { c.Fatal("expected no dirty state object") } } diff --git a/metrics/metrics.go b/metrics/metrics.go index dbb2727ec0..3e315b19e1 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -19,7 +19,12 @@ import ( // // This global kill-switch helps quantify the observer effect and makes // for less cluttered pprof profiles. -var Enabled bool = false +var Enabled = false + +// EnabledExpensive is a soft-flag meant for external packages to check if costly +// metrics gathering is allowed or not. The goal is to separate standard metrics +// for health monitoring and debug metrics that might impact runtime performance. +var EnabledExpensive = false // MetricsEnabledFlag is the CLI flag name to use to enable metrics collections. const MetricsEnabledFlag = "metrics" From 327c90dbaa55488fb0e515d3d6b045117dfa2f9b Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 17 Jul 2023 14:29:49 +0700 Subject: [PATCH 032/119] Remove rlpLog RLP encoder --- core/types/gen_log_rlp.go | 26 -------------------------- 1 file changed, 26 deletions(-) delete mode 100644 core/types/gen_log_rlp.go diff --git a/core/types/gen_log_rlp.go b/core/types/gen_log_rlp.go deleted file mode 100644 index 3f2c3ddc06..0000000000 --- a/core/types/gen_log_rlp.go +++ /dev/null @@ -1,26 +0,0 @@ -// Code generated by rlpgen. DO NOT EDIT. 
- -//go:build !norlpgen -// +build !norlpgen - -package types - -import ( - "io" - - "github.com/tomochain/tomochain/rlp" -) - -func (obj *rlpLog) EncodeRLP(_w io.Writer) error { - w := rlp.NewEncoderBuffer(_w) - _tmp0 := w.List() - w.WriteBytes(obj.Address[:]) - _tmp1 := w.List() - for _, _tmp2 := range obj.Topics { - w.WriteBytes(_tmp2[:]) - } - w.ListEnd(_tmp1) - w.WriteBytes(obj.Data) - w.ListEnd(_tmp0) - return w.Flush() -} From 4d22f8dc0dab634ad30163513258f731716bd1f5 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 18 Jul 2023 16:04:29 +0700 Subject: [PATCH 033/119] Fix unit tests --- core/bench_test.go | 12 ++++++++++-- core/blockchain.go | 8 ++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/core/bench_test.go b/core/bench_test.go index 137b57f031..cef95625c6 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -18,7 +18,6 @@ package core import ( "crypto/ecdsa" - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "os" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus/ethash" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" @@ -238,10 +238,16 @@ func makeChainForBench(db ethdb.Database, full bool, count uint64) { WriteHeader(db, header) WriteCanonicalHash(db, hash, n) WriteTd(db, hash, n, big.NewInt(int64(n+1))) + if n == 0 { + WriteChainConfig(db, hash, params.AllEthashProtocolChanges) + } + WriteHeadHeaderHash(db, hash) + if full || n == 0 { block := types.NewBlockWithHeader(header) WriteBody(db, hash, n, block.Body()) WriteBlockReceipts(db, hash, n, nil) + WriteHeadBlockHash(db, hash) } } } @@ -275,6 +281,8 @@ func benchReadChain(b *testing.B, full bool, count uint64) { } makeChainForBench(db, full, count) db.Close() + cacheConfig := defaultCacheConfig + cacheConfig.Disabled = true b.ReportAllocs() b.ResetTimer() @@ -284,7 +292,7 @@ func benchReadChain(b *testing.B, full bool, count uint64) { if err != nil { b.Fatalf("error opening database at %v: %v", dir, err) } - chain, err := NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}) + chain, err := NewBlockChain(db, cacheConfig, params.TestChainConfig, ethash.NewFaker(), vm.Config{}) if err != nil { b.Fatalf("error creating chain: %v", err) } diff --git a/core/blockchain.go b/core/blockchain.go index f763189be7..18b1521ff5 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -82,6 +82,14 @@ type CacheConfig struct { TrieNodeLimit int // Memory limit (MB) at which to flush the current in-memory trie to disk TrieTimeLimit time.Duration // Time limit after which to flush the current in-memory trie to disk } + +// defaultCacheConfig are the default caching values if none are specified by the +// user (also used during testing). 
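+//
+// Per the CacheConfig field comments above: TrieNodeLimit is the memory
+// limit (MB) at which the in-memory trie is flushed to disk, and
+// TrieTimeLimit is the time limit after which a flush is forced.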
+var defaultCacheConfig = &CacheConfig{ + TrieNodeLimit: 256, + TrieTimeLimit: 5 * time.Minute, +} + type ResultProcessBlock struct { logs []*types.Log receipts []*types.Receipt From 5f1e2b415473a77180ae78a9395165346559517a Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 18 Jul 2023 16:11:56 +0700 Subject: [PATCH 034/119] Minor fix after reviewing --- core/headerchain.go | 7 +++---- core/state_processor.go | 1 + 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/headerchain.go b/core/headerchain.go index f3cc8cf77b..af2b392645 100644 --- a/core/headerchain.go +++ b/core/headerchain.go @@ -66,10 +66,9 @@ type HeaderChain struct { } // NewHeaderChain creates a new HeaderChain structure. -// -// getValidator should return the parent's validator -// procInterrupt points to the parent's interrupt semaphore -// wg points to the parent's shutdown wait group +// getValidator should return the parent's validator +// procInterrupt points to the parent's interrupt semaphore +// wg points to the parent's shutdown wait group func NewHeaderChain(chainDb ethdb.Database, config *params.ChainConfig, engine consensus.Engine, procInterrupt func() bool) (*HeaderChain, error) { headerCache, _ := lru.New(headerCacheLimit) tdCache, _ := lru.New(tdCacheLimit) diff --git a/core/state_processor.go b/core/state_processor.go index de70802235..d77697dea7 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -515,6 +515,7 @@ func InitSignerInTransactions(config *params.ChainConfig, header *types.Header, go func(from int, to int) { for j := from; j < to; j++ { types.CacheSigner(signer, txs[j]) + txs[j].CacheHash() } wg.Done() }(from, to) From c37affb1536c8c9cd3cd64bd647bbe64de62ec58 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 20 Jul 2023 16:07:54 +0700 Subject: [PATCH 035/119] Move statedb.Account struct to types.StateAccount --- accounts/keystore/keystore_wallet.go | 8 +-- accounts/usbwallet/wallet.go | 2 +- cmd/faucet/faucet.go | 2 +- cmd/gc/main.go | 12 ++-- consensus/clique/snapshot.go | 2 +- consensus/posv/posv.go | 8 +-- consensus/posv/snapshot.go | 2 +- core/state/dump.go | 3 +- core/state/iterator.go | 3 +- core/state/state_object.go | 18 ++--- core/state/statedb.go | 10 +-- core/state/sync.go | 3 +- core/types/block.go | 5 -- core/types/hashes.go | 39 ++++++++++ core/types/state_account.go | 103 +++++++++++++++++++++++++++ les/handler.go | 12 ++-- 16 files changed, 182 insertions(+), 50 deletions(-) create mode 100644 core/types/hashes.go create mode 100644 core/types/state_account.go diff --git a/accounts/keystore/keystore_wallet.go b/accounts/keystore/keystore_wallet.go index 01ffd75a8e..91ac138786 100644 --- a/accounts/keystore/keystore_wallet.go +++ b/accounts/keystore/keystore_wallet.go @@ -90,7 +90,7 @@ func (w *keystoreWallet) SignHash(account accounts.Account, hash []byte) ([]byte if account.URL != (accounts.URL{}) && account.URL != w.account.URL { return nil, accounts.ErrUnknownAccount } - // Account seems valid, request the keystore to sign + // StateAccount seems valid, request the keystore to sign return w.keystore.SignHash(account, hash) } @@ -106,7 +106,7 @@ func (w *keystoreWallet) SignTx(account accounts.Account, tx *types.Transaction, if account.URL != (accounts.URL{}) && account.URL != w.account.URL { return nil, accounts.ErrUnknownAccount } - // Account seems valid, request the keystore to sign + // StateAccount seems valid, request the keystore to sign return w.keystore.SignTx(account, tx, chainID) } @@ -120,7 +120,7 @@ func (w 
*keystoreWallet) SignHashWithPassphrase(account accounts.Account, passph if account.URL != (accounts.URL{}) && account.URL != w.account.URL { return nil, accounts.ErrUnknownAccount } - // Account seems valid, request the keystore to sign + // StateAccount seems valid, request the keystore to sign return w.keystore.SignHashWithPassphrase(account, passphrase, hash) } @@ -134,6 +134,6 @@ func (w *keystoreWallet) SignTxWithPassphrase(account accounts.Account, passphra if account.URL != (accounts.URL{}) && account.URL != w.account.URL { return nil, accounts.ErrUnknownAccount } - // Account seems valid, request the keystore to sign + // StateAccount seems valid, request the keystore to sign return w.keystore.SignTxWithPassphrase(account, passphrase, tx, chainID) } diff --git a/accounts/usbwallet/wallet.go b/accounts/usbwallet/wallet.go index d3cda1f21e..2cb2ca2ae7 100644 --- a/accounts/usbwallet/wallet.go +++ b/accounts/usbwallet/wallet.go @@ -319,7 +319,7 @@ func (w *wallet) selfDerive() { // Termination requested continue case reqc = <-w.deriveReq: - // Account discovery requested + // StateAccount discovery requested } // Derivation needs a chain and device access, skip if either unavailable w.stateLock.RLock() diff --git a/cmd/faucet/faucet.go b/cmd/faucet/faucet.go index 6014f3c5a2..45a5e6cb4f 100644 --- a/cmd/faucet/faucet.go +++ b/cmd/faucet/faucet.go @@ -200,7 +200,7 @@ type faucet struct { index []byte // Index page to serve up on the web keystore *keystore.KeyStore // Keystore containing the single signer - account accounts.Account // Account funding user faucet requests + account accounts.Account // StateAccount funding user faucet requests nonce uint64 // Current pending nonce of the faucet price *big.Int // Current gas price to issue funds with diff --git a/cmd/gc/main.go b/cmd/gc/main.go index 567349ee42..7e1fc4e6d9 100644 --- a/cmd/gc/main.go +++ b/cmd/gc/main.go @@ -3,9 +3,6 @@ package main import ( "flag" "fmt" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/ethdb" - "github.com/tomochain/tomochain/ethdb/leveldb" "os" "os/signal" "runtime" @@ -13,12 +10,15 @@ import ( "sync/atomic" "time" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" "github.com/tomochain/tomochain/cmd/utils" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" - "github.com/tomochain/tomochain/core/state" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/eth" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/ethdb/leveldb" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" ) @@ -82,7 +82,7 @@ func main() { if running { for _, address := range cleanAddress { enc := trieRoot.trie.Get(address.Bytes()) - var data state.Account + var data types.StateAccount rlp.DecodeBytes(enc, &data) fmt.Println(time.Now().Format(time.RFC3339), "Start clean state address ", address.Hex(), " at block ", trieRoot.number) signerRoot, err := resolveHash(data.Root[:], db) diff --git a/consensus/clique/snapshot.go b/consensus/clique/snapshot.go index 3c2bf703d8..9a1e9e8846 100644 --- a/consensus/clique/snapshot.go +++ b/consensus/clique/snapshot.go @@ -32,7 +32,7 @@ import ( type Vote struct { Signer common.Address `json:"signer"` // Authorized signer that cast this vote Block uint64 `json:"block"` // Block number the vote was cast in (expire old votes) - Address common.Address `json:"address"` // Account being voted on to change 
its authorization + Address common.Address `json:"address"` // StateAccount being voted on to change its authorization Authorize bool `json:"authorize"` // Whether to authorize or deauthorize the voted account } diff --git a/consensus/posv/posv.go b/consensus/posv/posv.go index 0027104970..a25bdd764d 100644 --- a/consensus/posv/posv.go +++ b/consensus/posv/posv.go @@ -21,9 +21,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending/lendingstate" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" "io/ioutil" "math/big" "math/rand" @@ -50,6 +47,9 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) const ( @@ -181,7 +181,7 @@ var ( // SignerFn is a signer callback function to request a hash to be signed by a // backing account. -//type SignerFn func(accounts.Account, []byte) ([]byte, error) +//type SignerFn func(accounts.StateAccount, []byte) ([]byte, error) // sigHash returns the hash which is used as input for the proof-of-stake-voting // signing. It is the hash of the entire header apart from the 65 byte signature diff --git a/consensus/posv/snapshot.go b/consensus/posv/snapshot.go index aef9e2a39f..01f9d50e42 100644 --- a/consensus/posv/snapshot.go +++ b/consensus/posv/snapshot.go @@ -32,7 +32,7 @@ import ( //type Vote struct { // Signer common.Address `json:"signer"` // Authorized signer that cast this vote // Block uint64 `json:"block"` // Block number the vote was cast in (expire old votes) -// Address common.Address `json:"address"` // Account being voted on to change its authorization +// Address common.Address `json:"address"` // StateAccount being voted on to change its authorization // Authorize bool `json:"authorize"` // Whether to authorize or deauthorize the voted account //} diff --git a/core/state/dump.go b/core/state/dump.go index f08c6e7df3..7368146ca7 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" ) @@ -48,7 +49,7 @@ func (self *StateDB) RawDump() Dump { it := trie.NewIterator(self.trie.NodeIterator(nil)) for it.Next() { addr := self.trie.GetKey(it.Key) - var data Account + var data types.StateAccount if err := rlp.DecodeBytes(it.Value, &data); err != nil { panic(err) } diff --git a/core/state/iterator.go b/core/state/iterator.go index 3cfc592ecb..d69321f36a 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" ) @@ -104,7 +105,7 @@ func (it *NodeIterator) step() error { return nil } // Otherwise we've reached an account node, initiate data iteration - var account Account + var account types.StateAccount if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil { return err } diff --git a/core/state/state_object.go b/core/state/state_object.go index b03231e23b..478823be58 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -23,6 +23,7 @@ import ( "math/big" 
"github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/rlp" ) @@ -58,12 +59,12 @@ func (self Storage) Copy() Storage { // // The usage pattern is as follows: // First you need to obtain a state object. -// Account values can be accessed and modified through the object. +// StateAccount values can be accessed and modified through the object. // Finally, call CommitTrie to write the modified storage trie into a database. type stateObject struct { address common.Address addrHash common.Hash // hash of ethereum address of the account - data Account + data types.StateAccount db *StateDB // DB error. @@ -95,17 +96,8 @@ func (s *stateObject) empty() bool { return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, emptyCodeHash) } -// Account is the Ethereum consensus representation of accounts. -// These objects are stored in the main account trie. -type Account struct { - Nonce uint64 - Balance *big.Int - Root common.Hash // merkle root of the storage trie - CodeHash []byte -} - // newObject creates a state object. -func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *stateObject { +func newObject(db *StateDB, address common.Address, data types.StateAccount, onDirty func(addr common.Address)) *stateObject { if data.Balance == nil { data.Balance = new(big.Int) } @@ -397,7 +389,7 @@ func (self *stateObject) Nonce() uint64 { } // Never called, but must be present to allow stateObject to be used -// as a vm.Account interface that also satisfies the vm.ContractRef +// as a vm.StateAccount interface that also satisfies the vm.ContractRef // interface. Interfaces are awesome. func (self *stateObject) Value() *big.Int { panic("Value on stateObject should never be called") diff --git a/core/state/statedb.go b/core/state/statedb.go index 7a3357b3e8..818d3d0aaa 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -398,7 +398,7 @@ func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObje self.setError(err) return nil } - var data Account + var data types.StateAccount if err := rlp.DecodeBytes(enc, &data); err != nil { log.Error("Failed to decode state object", "addr", addr, "err", err) return nil @@ -432,7 +432,7 @@ func (self *StateDB) MarkStateObjectDirty(addr common.Address) { // the given address, it is overwritten and returned as the second return value. func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) { prev = self.getStateObject(addr) - newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty) + newobj = newObject(self, addr, types.StateAccount{}, self.MarkStateObjectDirty) newobj.setNonce(0) // sets the object to dirty if prev == nil { self.journal = append(self.journal, createObjectChange{account: &addr}) @@ -449,8 +449,8 @@ func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObjec // CreateAccount is called during the EVM CREATE operation. The situation might arise that // a contract does the following: // -// 1. sends funds to sha(account ++ (nonce + 1)) -// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) +// 1. sends funds to sha(account ++ (nonce + 1)) +// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) // // Carrying over the balance ensures that Ether doesn't disappear. 
func (self *StateDB) CreateAccount(addr common.Address) { @@ -636,7 +636,7 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) } // Write trie changes. root, err = s.trie.Commit(func(leaf []byte, parent common.Hash) error { - var account Account + var account types.StateAccount if err := rlp.DecodeBytes(leaf, &account); err != nil { return nil } diff --git a/core/state/sync.go b/core/state/sync.go index 95f29b2879..e26281c7db 100644 --- a/core/state/sync.go +++ b/core/state/sync.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" @@ -29,7 +30,7 @@ import ( func NewStateSync(root common.Hash, database ethdb.KeyValueReader, bloom *trie.SyncBloom) *trie.Sync { var syncer *trie.Sync callback := func(leaf []byte, parent common.Hash) error { - var obj Account + var obj types.StateAccount if err := rlp.Decode(bytes.NewReader(leaf), &obj); err != nil { return err } diff --git a/core/types/block.go b/core/types/block.go index a055ced147..9e95a1d82c 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -33,11 +33,6 @@ import ( "github.com/tomochain/tomochain/rlp" ) -var ( - EmptyRootHash = DeriveSha(Transactions{}) - EmptyUncleHash = CalcUncleHash(nil) -) - // A BlockNonce is a 64-bit hash which proves (combined with the // mix-hash) that a sufficient amount of computation has been carried // out on a block. diff --git a/core/types/hashes.go b/core/types/hashes.go new file mode 100644 index 0000000000..35fc6dc9f9 --- /dev/null +++ b/core/types/hashes.go @@ -0,0 +1,39 @@ +// Copyright 2023 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/crypto" +) + +var ( + // EmptyRootHash is the known root hash of an empty trie. + EmptyRootHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") + + // EmptyUncleHash is the known hash of the empty uncle set. + EmptyUncleHash = rlpHash([]*Header(nil)) // 1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347 + + // EmptyCodeHash is the known hash of the empty EVM bytecode. + EmptyCodeHash = crypto.Keccak256Hash(nil) // c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470 + + // EmptyTxsHash is the known hash of the empty transaction set. + EmptyTxsHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") + + // EmptyReceiptsHash is the known hash of the empty receipt set. 
+	EmptyReceiptsHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
+)
diff --git a/core/types/state_account.go b/core/types/state_account.go
new file mode 100644
index 0000000000..01c552a04c
--- /dev/null
+++ b/core/types/state_account.go
@@ -0,0 +1,103 @@
+package types
+
+import (
+	"bytes"
+	"math/big"
+
+	"github.com/tomochain/tomochain/common"
+	"github.com/tomochain/tomochain/rlp"
+)
+
+// StateAccount is the Ethereum consensus representation of accounts.
+// These objects are stored in the main account trie.
+type StateAccount struct {
+	Nonce    uint64
+	Balance  *big.Int
+	Root     common.Hash // merkle root of the storage trie
+	CodeHash []byte
+}
+
+// NewEmptyStateAccount constructs an empty state account.
+func NewEmptyStateAccount() *StateAccount {
+	return &StateAccount{
+		Balance:  new(big.Int),
+		Root:     EmptyRootHash,
+		CodeHash: EmptyCodeHash.Bytes(),
+	}
+}
+
+// Copy returns a deep-copied state account object.
+func (acct *StateAccount) Copy() *StateAccount {
+	var balance *big.Int
+	if acct.Balance != nil {
+		balance = new(big.Int).Set(acct.Balance)
+	}
+	return &StateAccount{
+		Nonce:    acct.Nonce,
+		Balance:  balance,
+		Root:     acct.Root,
+		CodeHash: common.CopyBytes(acct.CodeHash),
+	}
+}
+
+// SlimAccount is a modified version of a StateAccount, where the root is
+// replaced with a byte slice. This format can be used to represent the full
+// consensus format or the slim format, which replaces the empty root and
+// code hash with nil byte slices.
+type SlimAccount struct {
+	Nonce    uint64
+	Balance  *big.Int
+	Root     []byte // Nil if root equals EmptyRootHash
+	CodeHash []byte // Nil if hash equals EmptyCodeHash
+}
+
+// SlimAccountRLP encodes the state account in 'slim RLP' format.
+func SlimAccountRLP(account StateAccount) []byte {
+	slim := SlimAccount{
+		Nonce:   account.Nonce,
+		Balance: account.Balance,
+	}
+	if account.Root != EmptyRootHash {
+		slim.Root = account.Root[:]
+	}
+	if !bytes.Equal(account.CodeHash, EmptyCodeHash[:]) {
+		slim.CodeHash = account.CodeHash
+	}
+	data, err := rlp.EncodeToBytes(slim)
+	if err != nil {
+		panic(err)
+	}
+	return data
+}
+
+// FullAccount decodes data in the 'slim RLP' format and returns
+// the consensus format account.
+func FullAccount(data []byte) (*StateAccount, error) {
+	var slim SlimAccount
+	if err := rlp.DecodeBytes(data, &slim); err != nil {
+		return nil, err
+	}
+	var account StateAccount
+	account.Nonce, account.Balance = slim.Nonce, slim.Balance
+
+	// Interpret the storage root and code hash in slim format.
+	if len(slim.Root) == 0 {
+		account.Root = EmptyRootHash
+	} else {
+		account.Root = common.BytesToHash(slim.Root)
+	}
+	if len(slim.CodeHash) == 0 {
+		account.CodeHash = EmptyCodeHash[:]
+	} else {
+		account.CodeHash = slim.CodeHash
+	}
+	return &account, nil
+}
+
+// FullAccountRLP converts data in the 'slim RLP' format into the full RLP format.
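+// It is the inverse of SlimAccountRLP: the slim payload is decoded via
+// FullAccount and the account is re-encoded in the consensus encoding.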
+func FullAccountRLP(data []byte) ([]byte, error) {
+	account, err := FullAccount(data)
+	if err != nil {
+		return nil, err
+	}
+	return rlp.EncodeToBytes(account)
+}
diff --git a/les/handler.go b/les/handler.go
index b426f7fdd1..d5bd36f470 100644
--- a/les/handler.go
+++ b/les/handler.go
@@ -21,7 +21,6 @@ import (
 	"encoding/binary"
 	"errors"
 	"fmt"
-	"github.com/tomochain/tomochain/core/rawdb"
 	"math/big"
 	"net"
 	"sync"
@@ -30,6 +29,7 @@ import (
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/consensus"
 	"github.com/tomochain/tomochain/core"
+	"github.com/tomochain/tomochain/core/rawdb"
 	"github.com/tomochain/tomochain/core/state"
 	"github.com/tomochain/tomochain/core/types"
 	"github.com/tomochain/tomochain/eth/downloader"
@@ -1095,18 +1095,18 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 }
 
 // getAccount retrieves an account from the state based at root.
-func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common.Hash) (state.Account, error) {
+func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common.Hash) (types.StateAccount, error) {
 	trie, err := trie.New(root, statedb.Database().TrieDB())
 	if err != nil {
-		return state.Account{}, err
+		return types.StateAccount{}, err
 	}
 	blob, err := trie.TryGet(hash[:])
 	if err != nil {
-		return state.Account{}, err
+		return types.StateAccount{}, err
 	}
-	var account state.Account
+	var account types.StateAccount
 	if err = rlp.DecodeBytes(blob, &account); err != nil {
-		return state.Account{}, err
+		return types.StateAccount{}, err
 	}
 	return account, nil
 }

From e0776fb55f3d6302da4ece2762d196f87768f5a7 Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Thu, 20 Jul 2023 16:24:04 +0700
Subject: [PATCH 036/119] Move database_util.go to rawdb package

---
 accounts/abi/bind/backends/simulated.go | 14 ++---
 cmd/gc/main.go | 17 +++---
 cmd/tomo/dao_test.go | 5 +-
 cmd/utils/cmd.go | 5 +-
 contracts/utils.go | 3 +-
 core/bench_test.go | 16 +++---
 core/blockchain.go | 72 ++++++++++++-------
 core/blockchain_test.go | 32 +++++------
 core/chain_indexer.go | 7 +--
 core/chain_indexer_test.go | 8 +--
 core/genesis.go | 38 ++++++-------
 core/genesis_test.go | 4 +-
 core/headerchain.go | 51 +++++++++---------
 core/{ => rawdb}/database_util.go | 19 ++++---
 core/{ => rawdb}/database_util_test.go | 21 ++++----
 eth/api.go | 6 +--
 eth/api_backend.go | 15 +++---
 eth/api_tracer.go | 7 +--
 eth/backend.go | 26 ++++-----
 eth/bloombits.go | 10 ++--
 eth/downloader/downloader.go | 36 ++++++-------
 eth/downloader/fakepeer.go | 5 +-
 eth/downloader/statesync.go | 6 +--
 eth/filters/bench_test.go | 27 +++++-----
 eth/filters/filter_system.go | 5 +-
 eth/filters/filter_system_test.go | 22 ++++----
 eth/filters/filter_test.go | 22 ++++----
 internal/ethapi/api.go | 33 ++++++------
 les/api_backend.go | 12 ++---
 les/backend.go | 3 +-
 les/fetcher.go | 6 +--
 les/handler.go | 26 ++++-----
 les/handler_test.go | 10 ++--
 les/odr_requests.go | 8 +--
 les/odr_test.go | 8 +--
 les/protocol.go | 3 +-
 les/request_test.go | 9 ++--
 les/server.go | 5 +-
 les/sync.go | 4 +-
 light/lightchain.go | 10 ++--
 light/lightchain_test.go | 6 +--
 light/odr.go | 15 +++---
 light/odr_test.go | 14 ++---
 light/odr_util.go | 33 ++++++------
 light/postprocess.go | 6 +--
 light/txpool.go | 16 +++---
 46 files changed, 369 insertions(+), 357 deletions(-)
 rename core/{ => rawdb}/database_util.go (98%)
 rename core/{ => rawdb}/database_util_test.go (97%)

diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go
index
7411f492a8..716c86857d 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -20,8 +20,6 @@ import ( "context" "errors" "fmt" - "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "sync" "time" @@ -30,9 +28,11 @@ import ( "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -174,7 +174,7 @@ func (b *SimulatedBackend) ForEachStorageAt(ctx context.Context, contract common // TransactionReceipt returns the receipt of a transaction. func (b *SimulatedBackend) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) { - receipt, _, _, _ := core.GetReceipt(b.database, txHash) + receipt, _, _, _ := rawdb.GetReceipt(b.database, txHash) return receipt, nil } @@ -202,7 +202,7 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call tomochain.Call return rval, err } -//FIXME: please use copyState for this function +// FIXME: please use copyState for this function // CallContractWithState executes a contract call at the given state. func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) { // Ensure message is initialized properly. @@ -285,7 +285,7 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM snapshot := b.pendingState.Snapshot() _, _, failed, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) - fmt.Println("EstimateGas",err,failed) + fmt.Println("EstimateGas", err, failed) b.pendingState.RevertToSnapshot(snapshot) if err != nil || failed { @@ -485,11 +485,11 @@ func (fb *filterBackend) HeaderByNumber(ctx context.Context, block rpc.BlockNumb } func (fb *filterBackend) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) { - return core.GetBlockReceipts(fb.db, hash, core.GetBlockNumber(fb.db, hash)), nil + return rawdb.GetBlockReceipts(fb.db, hash, rawdb.GetBlockNumber(fb.db, hash)), nil } func (fb *filterBackend) GetLogs(ctx context.Context, hash common.Hash) ([][]*types.Log, error) { - receipts := core.GetBlockReceipts(fb.db, hash, core.GetBlockNumber(fb.db, hash)) + receipts := rawdb.GetBlockReceipts(fb.db, hash, rawdb.GetBlockNumber(fb.db, hash)) if receipts == nil { return nil, nil } diff --git a/cmd/gc/main.go b/cmd/gc/main.go index 567349ee42..2413d34869 100644 --- a/cmd/gc/main.go +++ b/cmd/gc/main.go @@ -3,9 +3,6 @@ package main import ( "flag" "fmt" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/ethdb" - "github.com/tomochain/tomochain/ethdb/leveldb" "os" "os/signal" "runtime" @@ -13,12 +10,14 @@ import ( "sync/atomic" "time" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" "github.com/tomochain/tomochain/cmd/utils" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/eth" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/ethdb/leveldb" 
"github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" ) @@ -54,15 +53,15 @@ func main() { flag.Parse() db, _ := leveldb.New(*dir, eth.DefaultConfig.DatabaseCache, utils.MakeDatabaseHandles(), "") lddb := rawdb.NewDatabase(db) - head := core.GetHeadBlockHash(lddb) - currentHeader := core.GetHeader(lddb, head, core.GetBlockNumber(lddb, head)) + head := rawdb.GetHeadBlockHash(lddb) + currentHeader := rawdb.GetHeader(lddb, head, rawdb.GetBlockNumber(lddb, head)) tridb := trie.NewDatabase(lddb) catchEventInterupt(db) cache, _ = lru.New(*cacheSize) go func() { for i := uint64(1); i <= currentHeader.Number.Uint64(); i++ { - hash := core.GetCanonicalHash(lddb, i) - root := core.GetHeader(lddb, hash, i).Root + hash := rawdb.GetCanonicalHash(lddb, i) + root := rawdb.GetHeader(lddb, hash, i).Root trieRoot, err := trie.NewSecure(root, tridb) if err != nil { continue diff --git a/cmd/tomo/dao_test.go b/cmd/tomo/dao_test.go index 773f1ed152..768a7bb762 100644 --- a/cmd/tomo/dao_test.go +++ b/cmd/tomo/dao_test.go @@ -17,7 +17,6 @@ package main import ( - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "os" @@ -25,7 +24,7 @@ import ( "testing" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" ) // Genesis block for nodes which don't care about the DAO fork (i.e. not configured) @@ -130,7 +129,7 @@ func testDAOForkBlockNewChain(t *testing.T, test int, genesis string, expectBloc if genesis != "" { genesisHash = daoGenesisHash } - config, err := core.GetChainConfig(db, genesisHash) + config, err := rawdb.GetChainConfig(db, genesisHash) if err != nil { t.Errorf("test %d: failed to retrieve chain config: %v", test, err) return // we want to return here, the other checks can't make it past this point (nil panic). 
diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go index a3787f7311..667098e90e 100644 --- a/cmd/utils/cmd.go +++ b/cmd/utils/cmd.go @@ -29,6 +29,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/ethdb" @@ -271,7 +272,7 @@ func ImportPreimages(db ethdb.Database, fn string) error { // Accumulate the preimages and flush when enough ws gathered preimages[crypto.Keccak256Hash(blob)] = common.CopyBytes(blob) if len(preimages) > 1024 { - if err := core.WritePreimages(db, 0, preimages); err != nil { + if err := rawdb.WritePreimages(db, 0, preimages); err != nil { return err } preimages = make(map[common.Hash][]byte) @@ -279,7 +280,7 @@ func ImportPreimages(db ethdb.Database, fn string) error { } // Flush the last batch preimage data if len(preimages) > 0 { - return core.WritePreimages(db, 0, preimages) + return rawdb.WritePreimages(db, 0, preimages) } return nil } diff --git a/contracts/utils.go b/contracts/utils.go index 4468b5de9a..26d36ddd69 100644 --- a/contracts/utils.go +++ b/contracts/utils.go @@ -39,6 +39,7 @@ import ( "github.com/tomochain/tomochain/contracts/blocksigner/contract" randomizeContract "github.com/tomochain/tomochain/contracts/randomize/contract" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" stateDatabase "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" @@ -336,7 +337,7 @@ func GetRewardForCheckpoint(c *posv.Posv, chain consensus.ChainReader, header *t block := chain.GetBlock(header.Hash(), i) txs := block.Transactions() if !chain.Config().IsTIPSigning(header.Number) { - receipts := core.GetBlockReceipts(c.GetDb(), header.Hash(), i) + receipts := rawdb.GetBlockReceipts(c.GetDb(), header.Hash(), i) signData = c.CacheData(header, txs, receipts) } else { signData = c.CacheSigner(header.Hash(), txs) diff --git a/core/bench_test.go b/core/bench_test.go index 137b57f031..c8fa85ede5 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -18,7 +18,6 @@ package core import ( "crypto/ecdsa" - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "os" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus/ethash" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" @@ -235,13 +235,13 @@ func makeChainForBench(db ethdb.Database, full bool, count uint64) { ReceiptHash: types.EmptyRootHash, } hash = header.Hash() - WriteHeader(db, header) - WriteCanonicalHash(db, hash, n) - WriteTd(db, hash, n, big.NewInt(int64(n+1))) + rawdb.WriteHeader(db, header) + rawdb.WriteCanonicalHash(db, hash, n) + rawdb.WriteTd(db, hash, n, big.NewInt(int64(n+1))) if full || n == 0 { block := types.NewBlockWithHeader(header) - WriteBody(db, hash, n, block.Body()) - WriteBlockReceipts(db, hash, n, nil) + rawdb.WriteBody(db, hash, n, block.Body()) + rawdb.WriteBlockReceipts(db, hash, n, nil) } } } @@ -293,8 +293,8 @@ func benchReadChain(b *testing.B, full bool, count uint64) { header := chain.GetHeaderByNumber(n) if full { hash := header.Hash() - GetBody(db, hash, n) - GetBlockReceipts(db, hash, n) + rawdb.GetBody(db, hash, n) + rawdb.GetBlockReceipts(db, hash, n) } 
} diff --git a/core/blockchain.go b/core/blockchain.go index f763189be7..eb8a816f55 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -28,17 +28,16 @@ import ( "sync/atomic" "time" - "github.com/tomochain/tomochain/tomoxlending/lendingstate" - - "github.com/tomochain/tomochain/accounts/abi/bind" - "github.com/tomochain/tomochain/tomox/tradingstate" + "gopkg.in/karalabe/cookiejar.v2/collections/prque" lru "github.com/hashicorp/golang-lru" + "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/mclock" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/posv" contractValidator "github.com/tomochain/tomochain/contracts/validator/contract" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -50,8 +49,9 @@ import ( "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending/lendingstate" "github.com/tomochain/tomochain/trie" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) var ( @@ -276,7 +276,7 @@ func (bc *BlockChain) addTomoxDb(tomoxDb ethdb.TomoxDatabase) { // assumes that the chain manager mutex is held. func (bc *BlockChain) loadLastState() error { // Restore the last known head block - head := GetHeadBlockHash(bc.db) + head := rawdb.GetHeadBlockHash(bc.db) if head == (common.Hash{}) { // Corrupt or empty database, init from scratch log.Warn("Empty database, resetting chain") @@ -344,7 +344,7 @@ func (bc *BlockChain) loadLastState() error { // Restore the last known head header currentHeader := currentBlock.Header() - if head := GetHeadHeaderHash(bc.db); head != (common.Hash{}) { + if head := rawdb.GetHeadHeaderHash(bc.db); head != (common.Hash{}) { if header := bc.GetHeaderByHash(head); header != nil { currentHeader = header } @@ -353,7 +353,7 @@ func (bc *BlockChain) loadLastState() error { // Restore the last known head fast block bc.currentFastBlock.Store(currentBlock) - if head := GetHeadFastBlockHash(bc.db); head != (common.Hash{}) { + if head := rawdb.GetHeadFastBlockHash(bc.db); head != (common.Hash{}) { if block := bc.GetBlockByHash(head); block != nil { bc.currentFastBlock.Store(block) } @@ -385,7 +385,7 @@ func (bc *BlockChain) SetHead(head uint64) error { // Rewind the header chain, deleting all block bodies until then delFn := func(hash common.Hash, num uint64) { - DeleteBody(bc.db, hash, num) + rawdb.DeleteBody(bc.db, hash, num) } bc.hc.SetHead(head, delFn) currentHeader := bc.hc.CurrentHeader() @@ -420,10 +420,10 @@ func (bc *BlockChain) SetHead(head uint64) error { } currentBlock := bc.CurrentBlock() currentFastBlock := bc.CurrentFastBlock() - if err := WriteHeadBlockHash(bc.db, currentBlock.Hash()); err != nil { + if err := rawdb.WriteHeadBlockHash(bc.db, currentBlock.Hash()); err != nil { log.Crit("Failed to reset head full block", "err", err) } - if err := WriteHeadFastBlockHash(bc.db, currentFastBlock.Hash()); err != nil { + if err := rawdb.WriteHeadFastBlockHash(bc.db, currentFastBlock.Hash()); err != nil { log.Crit("Failed to reset head fast block", "err", err) } return bc.loadLastState() @@ -562,7 +562,7 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error { if err := bc.hc.WriteTd(genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); 
err != nil { log.Crit("Failed to write genesis block TD", "err", err) } - if err := WriteBlock(bc.db, genesis); err != nil { + if err := rawdb.WriteBlock(bc.db, genesis); err != nil { log.Crit("Failed to write genesis block", "err", err) } bc.genesisBlock = genesis @@ -658,13 +658,13 @@ func (bc *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error { // Note, this function assumes that the `mu` mutex is held! func (bc *BlockChain) insert(block *types.Block) { // If the block is on a side chain or an unknown one, force other heads onto it too - updateHeads := GetCanonicalHash(bc.db, block.NumberU64()) != block.Hash() + updateHeads := rawdb.GetCanonicalHash(bc.db, block.NumberU64()) != block.Hash() // Add the block to the canonical chain number scheme and mark as the head - if err := WriteCanonicalHash(bc.db, block.Hash(), block.NumberU64()); err != nil { + if err := rawdb.WriteCanonicalHash(bc.db, block.Hash(), block.NumberU64()); err != nil { log.Crit("Failed to insert block number", "err", err) } - if err := WriteHeadBlockHash(bc.db, block.Hash()); err != nil { + if err := rawdb.WriteHeadBlockHash(bc.db, block.Hash()); err != nil { log.Crit("Failed to insert head block hash", "err", err) } bc.currentBlock.Store(block) @@ -681,7 +681,7 @@ func (bc *BlockChain) insert(block *types.Block) { if updateHeads { bc.hc.SetCurrentHeader(block.Header()) - if err := WriteHeadFastBlockHash(bc.db, block.Hash()); err != nil { + if err := rawdb.WriteHeadFastBlockHash(bc.db, block.Hash()); err != nil { log.Crit("Failed to insert head fast block hash", "err", err) } bc.currentFastBlock.Store(block) @@ -701,7 +701,7 @@ func (bc *BlockChain) GetBody(hash common.Hash) *types.Body { body := cached.(*types.Body) return body } - body := GetBody(bc.db, hash, bc.hc.GetBlockNumber(hash)) + body := rawdb.GetBody(bc.db, hash, bc.hc.GetBlockNumber(hash)) if body == nil { return nil } @@ -717,7 +717,7 @@ func (bc *BlockChain) GetBodyRLP(hash common.Hash) rlp.RawValue { if cached, ok := bc.bodyRLPCache.Get(hash); ok { return cached.(rlp.RawValue) } - body := GetBodyRLP(bc.db, hash, bc.hc.GetBlockNumber(hash)) + body := rawdb.GetBodyRLP(bc.db, hash, bc.hc.GetBlockNumber(hash)) if len(body) == 0 { return nil } @@ -731,7 +731,7 @@ func (bc *BlockChain) HasBlock(hash common.Hash, number uint64) bool { if bc.blockCache.Contains(hash) { return true } - ok, _ := bc.db.Has(blockBodyKey(hash, number)) + ok, _ := bc.db.Has(rawdb.BlockBodyKey(hash, number)) return ok } @@ -774,7 +774,7 @@ func (bc *BlockChain) GetBlock(hash common.Hash, number uint64) *types.Block { if block, ok := bc.blockCache.Get(hash); ok { return block.(*types.Block) } - block := GetBlock(bc.db, hash, number) + block := rawdb.GetBlock(bc.db, hash, number) if block == nil { return nil } @@ -791,7 +791,7 @@ func (bc *BlockChain) GetBlockByHash(hash common.Hash) *types.Block { // GetBlockByNumber retrieves a block from the database by number, caching it // (associated with its hash) if found. func (bc *BlockChain) GetBlockByNumber(number uint64) *types.Block { - hash := GetCanonicalHash(bc.db, number) + hash := rawdb.GetCanonicalHash(bc.db, number) if hash == (common.Hash{}) { return nil } @@ -800,7 +800,7 @@ func (bc *BlockChain) GetBlockByNumber(number uint64) *types.Block { // GetReceiptsByHash retrieves the receipts for all transactions in a given block. 
func (bc *BlockChain) GetReceiptsByHash(hash common.Hash) types.Receipts { - return GetBlockReceipts(bc.db, hash, GetBlockNumber(bc.db, hash)) + return rawdb.GetBlockReceipts(bc.db, hash, rawdb.GetBlockNumber(bc.db, hash)) } // GetBlocksFromHash returns the block corresponding to hash and up to n-1 ancestors. @@ -996,12 +996,12 @@ func (bc *BlockChain) Rollback(chain []common.Hash) { if currentFastBlock := bc.CurrentFastBlock(); currentFastBlock.Hash() == hash { newFastBlock := bc.GetBlock(currentFastBlock.ParentHash(), currentFastBlock.NumberU64()-1) bc.currentFastBlock.Store(newFastBlock) - WriteHeadFastBlockHash(bc.db, newFastBlock.Hash()) + rawdb.WriteHeadFastBlockHash(bc.db, newFastBlock.Hash()) } if currentBlock := bc.CurrentBlock(); currentBlock.Hash() == hash { newBlock := bc.GetBlock(currentBlock.ParentHash(), currentBlock.NumberU64()-1) bc.currentBlock.Store(newBlock) - WriteHeadBlockHash(bc.db, newBlock.Hash()) + rawdb.WriteHeadBlockHash(bc.db, newBlock.Hash()) } } } @@ -1086,13 +1086,13 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [ return i, fmt.Errorf("failed to set receipts data: %v", err) } // Write all the data out into the database - if err := WriteBody(batch, block.Hash(), block.NumberU64(), block.Body()); err != nil { + if err := rawdb.WriteBody(batch, block.Hash(), block.NumberU64(), block.Body()); err != nil { return i, fmt.Errorf("failed to write block body: %v", err) } - if err := WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil { + if err := rawdb.WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil { return i, fmt.Errorf("failed to write block receipts: %v", err) } - if err := WriteTxLookupEntries(batch, block); err != nil { + if err := rawdb.WriteTxLookupEntries(batch, block); err != nil { return i, fmt.Errorf("failed to write lookup metadata: %v", err) } stats.processed++ @@ -1118,7 +1118,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [ if td := bc.GetTd(head.Hash(), head.NumberU64()); td != nil { // Rewind may have occurred, skip in that case currentFastBlock := bc.CurrentFastBlock() if bc.GetTd(currentFastBlock.Hash(), currentFastBlock.NumberU64()).Cmp(td) < 0 { - if err := WriteHeadFastBlockHash(bc.db, head.Hash()); err != nil { + if err := rawdb.WriteHeadFastBlockHash(bc.db, head.Hash()); err != nil { log.Crit("Failed to update head fast block hash", "err", err) } bc.currentFastBlock.Store(head) @@ -1148,7 +1148,7 @@ func (bc *BlockChain) WriteBlockWithoutState(block *types.Block, td *big.Int) (e if err := bc.hc.WriteTd(block.Hash(), block.NumberU64(), td); err != nil { return err } - if err := WriteBlock(bc.db, block); err != nil { + if err := rawdb.WriteBlock(bc.db, block); err != nil { return err } return nil @@ -1178,7 +1178,7 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types. } // Write other block data using a batch. batch := bc.db.NewBatch() - if err := WriteBlock(batch, block); err != nil { + if err := rawdb.WriteBlock(batch, block); err != nil { return NonStatTy, err } root, err := state.Commit(bc.chainConfig.IsEIP158(block.Number())) @@ -1324,7 +1324,7 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types. 
} } } - if err := WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil { + if err := rawdb.WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil { return NonStatTy, err } // If the total difficulty is higher than our known, add it to the canonical chain @@ -1344,11 +1344,11 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types. } } // Write the positional metadata for transaction and receipt lookups - if err := WriteTxLookupEntries(batch, block); err != nil { + if err := rawdb.WriteTxLookupEntries(batch, block); err != nil { return NonStatTy, err } // Write hash preimages - if err := WritePreimages(bc.db, block.NumberU64(), state.Preimages()); err != nil { + if err := rawdb.WritePreimages(bc.db, block.NumberU64(), state.Preimages()); err != nil { return NonStatTy, err } status = CanonStatTy @@ -2120,7 +2120,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { // These logs are later announced as deleted. collectLogs = func(h common.Hash) { // Coalesce logs and set 'Removed'. - receipts := GetBlockReceipts(bc.db, h, bc.hc.GetBlockNumber(h)) + receipts := rawdb.GetBlockReceipts(bc.db, h, bc.hc.GetBlockNumber(h)) for _, receipt := range receipts { for _, log := range receipt.Logs { del := *log @@ -2189,7 +2189,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { // insert the block in the canonical way, re-writing history bc.insert(newChain[i]) // write lookup entries for hash based transaction/receipt searches - if err := WriteTxLookupEntries(bc.db, newChain[i]); err != nil { + if err := rawdb.WriteTxLookupEntries(bc.db, newChain[i]); err != nil { return err } addedTxs = append(addedTxs, newChain[i].Transactions()...) @@ -2199,7 +2199,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { // When transactions get deleted from the database that means the // receipts that were created in the fork must also be deleted for _, tx := range diff { - DeleteTxLookupEntry(bc.db, tx.Hash()) + rawdb.DeleteTxLookupEntry(bc.db, tx.Hash()) } if len(deletedLogs) > 0 { go bc.rmLogsFeed.Send(RemovedLogsEvent{deletedLogs}) diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 6860924112..7c61a61433 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -18,7 +18,6 @@ package core import ( "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "math/rand" "sync" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -128,8 +128,8 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error { return err } blockchain.mu.Lock() - WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash()))) - WriteBlock(blockchain.db, block) + rawdb.WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash()))) + rawdb.WriteBlock(blockchain.db, block) statedb.Commit(true) blockchain.mu.Unlock() } @@ -146,8 +146,8 @@ func testHeaderChainImport(chain []*types.Header, blockchain *BlockChain) error } // Manually insert the header into the database, but don't reorganise (allows subsequent testing) blockchain.mu.Lock() - WriteTd(blockchain.db, 
header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash))) - WriteHeader(blockchain.db, header) + rawdb.WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash))) + rawdb.WriteHeader(blockchain.db, header) blockchain.mu.Unlock() } return nil @@ -173,7 +173,7 @@ func TestLastBlock(t *testing.T) { if _, err := blockchain.InsertChain(blocks); err != nil { t.Fatalf("Failed to insert block: %v", err) } - if blocks[len(blocks)-1].Hash() != GetHeadBlockHash(blockchain.db) { + if blocks[len(blocks)-1].Hash() != rawdb.GetHeadBlockHash(blockchain.db) { t.Fatalf("Write/Get HeadBlockHash failed") } } @@ -622,13 +622,13 @@ func TestFastVsFullChains(t *testing.T) { } else if types.CalcUncleHash(fblock.Uncles()) != types.CalcUncleHash(ablock.Uncles()) { t.Errorf("block #%d [%x]: uncles mismatch: have %v, want %v", num, hash, fblock.Uncles(), ablock.Uncles()) } - if freceipts, areceipts := GetBlockReceipts(fastDb, hash, GetBlockNumber(fastDb, hash)), GetBlockReceipts(archiveDb, hash, GetBlockNumber(archiveDb, hash)); types.DeriveSha(freceipts) != types.DeriveSha(areceipts) { + if freceipts, areceipts := rawdb.GetBlockReceipts(fastDb, hash, rawdb.GetBlockNumber(fastDb, hash)), rawdb.GetBlockReceipts(archiveDb, hash, rawdb.GetBlockNumber(archiveDb, hash)); types.DeriveSha(freceipts) != types.DeriveSha(areceipts) { t.Errorf("block #%d [%x]: receipts mismatch: have %v, want %v", num, hash, freceipts, areceipts) } } // Check that the canonical chains are the same between the databases for i := 0; i < len(blocks)+1; i++ { - if fhash, ahash := GetCanonicalHash(fastDb, uint64(i)), GetCanonicalHash(archiveDb, uint64(i)); fhash != ahash { + if fhash, ahash := rawdb.GetCanonicalHash(fastDb, uint64(i)), rawdb.GetCanonicalHash(archiveDb, uint64(i)); fhash != ahash { t.Errorf("block #%d: canonical hash mismatch: have %v, want %v", i, fhash, ahash) } } @@ -804,28 +804,28 @@ func TestChainTxReorgs(t *testing.T) { // removed tx for i, tx := range (types.Transactions{pastDrop, freshDrop}) { - if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn != nil { + if txn, _, _, _ := rawdb.GetTransaction(db, tx.Hash()); txn != nil { t.Errorf("drop %d: tx %v found while shouldn't have been", i, txn) } - if rcpt, _, _, _ := GetReceipt(db, tx.Hash()); rcpt != nil { + if rcpt, _, _, _ := rawdb.GetReceipt(db, tx.Hash()); rcpt != nil { t.Errorf("drop %d: receipt %v found while shouldn't have been", i, rcpt) } } // added tx for i, tx := range (types.Transactions{pastAdd, freshAdd, futureAdd}) { - if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn == nil { + if txn, _, _, _ := rawdb.GetTransaction(db, tx.Hash()); txn == nil { t.Errorf("add %d: expected tx to be found", i) } - if rcpt, _, _, _ := GetReceipt(db, tx.Hash()); rcpt == nil { + if rcpt, _, _, _ := rawdb.GetReceipt(db, tx.Hash()); rcpt == nil { t.Errorf("add %d: expected receipt to be found", i) } } // shared tx for i, tx := range (types.Transactions{postponed, swapped}) { - if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn == nil { + if txn, _, _, _ := rawdb.GetTransaction(db, tx.Hash()); txn == nil { t.Errorf("share %d: expected tx to be found", i) } - if rcpt, _, _, _ := GetReceipt(db, tx.Hash()); rcpt == nil { + if rcpt, _, _, _ := rawdb.GetReceipt(db, tx.Hash()); rcpt == nil { t.Errorf("share %d: expected receipt to be found", i) } } @@ -980,14 +980,14 @@ func TestCanonicalBlockRetrieval(t *testing.T) { // try to retrieve a 
block by its canonical hash and see if the block data can be retrieved. for { - ch := GetCanonicalHash(blockchain.db, block.NumberU64()) + ch := rawdb.GetCanonicalHash(blockchain.db, block.NumberU64()) if ch == (common.Hash{}) { continue // busy wait for canonical hash to be written } if ch != block.Hash() { t.Fatalf("unknown canonical hash, want %s, got %s", block.Hash().Hex(), ch.Hex()) } - fb := GetBlock(blockchain.db, ch, block.NumberU64()) + fb := rawdb.GetBlock(blockchain.db, ch, block.NumberU64()) if fb == nil { t.Fatalf("unable to retrieve block %d for canonical hash: %s", block.NumberU64(), ch.Hex()) } diff --git a/core/chain_indexer.go b/core/chain_indexer.go index 95190eea93..41f3919904 100644 --- a/core/chain_indexer.go +++ b/core/chain_indexer.go @@ -24,6 +24,7 @@ import ( "time" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" @@ -206,7 +207,7 @@ func (c *ChainIndexer) eventLoop(currentHeader *types.Header, events chan ChainE // TODO(karalabe): This operation is expensive and might block, causing the event system to // potentially also lock up. We need to do with on a different thread somehow. - if h := FindCommonAncestor(c.chainDb, prevHeader, header); h != nil { + if h := rawdb.FindCommonAncestor(c.chainDb, prevHeader, header); h != nil { c.newHead(h.Number.Uint64(), true) } } @@ -349,11 +350,11 @@ func (c *ChainIndexer) processSection(section uint64, lastHead common.Hash) (com } for number := section * c.sectionSize; number < (section+1)*c.sectionSize; number++ { - hash := GetCanonicalHash(c.chainDb, number) + hash := rawdb.GetCanonicalHash(c.chainDb, number) if hash == (common.Hash{}) { return common.Hash{}, fmt.Errorf("canonical block #%d unknown", number) } - header := GetHeader(c.chainDb, hash, number) + header := rawdb.GetHeader(c.chainDb, hash, number) if header == nil { return common.Hash{}, fmt.Errorf("block #%d [%x…] not found", number, hash[:4]) } else if header.ParentHash != lastHead { diff --git a/core/chain_indexer_test.go b/core/chain_indexer_test.go index a954c062d9..3a50819b9d 100644 --- a/core/chain_indexer_test.go +++ b/core/chain_indexer_test.go @@ -18,13 +18,13 @@ package core import ( "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "math/rand" "testing" "time" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" ) @@ -92,10 +92,10 @@ func testChainIndexer(t *testing.T, count int) { inject := func(number uint64) { header := &types.Header{Number: big.NewInt(int64(number)), Extra: big.NewInt(rand.Int63()).Bytes()} if number > 0 { - header.ParentHash = GetCanonicalHash(db, number-1) + header.ParentHash = rawdb.GetCanonicalHash(db, number-1) } - WriteHeader(db, header) - WriteCanonicalHash(db, header.Hash(), number) + rawdb.WriteHeader(db, header) + rawdb.WriteCanonicalHash(db, header.Hash(), number) } // Start indexer with an already existing chain for i := uint64(0); i <= 100; i++ { diff --git a/core/genesis.go b/core/genesis.go index e1b7185a41..b646c3b4c2 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -22,13 +22,13 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "strings" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/core/rawdb" 
"github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" @@ -140,10 +140,10 @@ func (e *GenesisMismatchError) Error() string { // SetupGenesisBlock writes or updates the genesis block in db. // The block that will be used is: // -// genesis == nil genesis != nil -// +------------------------------------------ -// db has no genesis | main-net default | genesis -// db has genesis | from DB | genesis (if compatible) +// genesis == nil genesis != nil +// +------------------------------------------ +// db has no genesis | main-net default | genesis +// db has genesis | from DB | genesis (if compatible) // // The stored chain configuration will be updated if it is compatible (i.e. does not // specify a fork block below the local head block). In case of a conflict, the @@ -156,7 +156,7 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig } // Just commit the new block if there is no stored genesis block. - stored := GetCanonicalHash(db, 0) + stored := rawdb.GetCanonicalHash(db, 0) if (stored == common.Hash{}) { if genesis == nil { log.Info("Writing default main-net genesis block") @@ -178,12 +178,12 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig // Get the existing chain configuration. newcfg := genesis.configOrDefault(stored) - storedcfg, err := GetChainConfig(db, stored) + storedcfg, err := rawdb.GetChainConfig(db, stored) if err != nil { - if err == ErrChainConfigNotFound { + if err == rawdb.ErrChainConfigNotFound { // This case happens if a genesis write was interrupted. log.Warn("Found genesis block without chain config") - err = WriteChainConfig(db, stored, newcfg) + err = rawdb.WriteChainConfig(db, stored, newcfg) } return newcfg, stored, err } @@ -196,15 +196,15 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig // Check config compatibility and write the config. Compatibility errors // are returned to the caller unless we're already at block zero. 
- height := GetBlockNumber(db, GetHeadHeaderHash(db)) - if height == missingNumber { + height := rawdb.GetBlockNumber(db, rawdb.GetHeadHeaderHash(db)) + if height == rawdb.MissingNumber { return newcfg, stored, fmt.Errorf("missing block number for head header hash") } compatErr := storedcfg.CheckCompatible(newcfg, height) if compatErr != nil && height != 0 && compatErr.RewindTo != 0 { return newcfg, stored, compatErr } - return newcfg, stored, WriteChainConfig(db, stored, newcfg) + return newcfg, stored, rawdb.WriteChainConfig(db, stored, newcfg) } func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig { @@ -268,29 +268,29 @@ func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) { if block.Number().Sign() != 0 { return nil, fmt.Errorf("can't commit genesis block with number > 0") } - if err := WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty); err != nil { + if err := rawdb.WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty); err != nil { return nil, err } - if err := WriteBlock(db, block); err != nil { + if err := rawdb.WriteBlock(db, block); err != nil { return nil, err } - if err := WriteBlockReceipts(db, block.Hash(), block.NumberU64(), nil); err != nil { + if err := rawdb.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), nil); err != nil { return nil, err } - if err := WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { + if err := rawdb.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { return nil, err } - if err := WriteHeadBlockHash(db, block.Hash()); err != nil { + if err := rawdb.WriteHeadBlockHash(db, block.Hash()); err != nil { return nil, err } - if err := WriteHeadHeaderHash(db, block.Hash()); err != nil { + if err := rawdb.WriteHeadHeaderHash(db, block.Hash()); err != nil { return nil, err } config := g.Config if config == nil { config = params.AllEthashProtocolChanges } - return block, WriteChainConfig(db, block.Hash(), config) + return block, rawdb.WriteChainConfig(db, block.Hash(), config) } // MustCommit writes the genesis block and state to db, panicking on error. diff --git a/core/genesis_test.go b/core/genesis_test.go index 177798a5d2..ee32b6705d 100644 --- a/core/genesis_test.go +++ b/core/genesis_test.go @@ -17,7 +17,6 @@ package core import ( - "github.com/tomochain/tomochain/core/rawdb" "math/big" "reflect" "testing" @@ -25,6 +24,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/params" @@ -155,7 +155,7 @@ func TestSetupGenesis(t *testing.T) { t.Errorf("%s: returned hash %s, want %s", test.name, hash.Hex(), test.wantHash.Hex()) } else if err == nil { // Check database content. 
- stored := GetBlock(db, test.wantHash, 0) + stored := rawdb.GetBlock(db, test.wantHash, 0) if stored.Hash() != test.wantHash { t.Errorf("%s: block in DB has hash %s, want %s", test.name, stored.Hash(), test.wantHash) } diff --git a/core/headerchain.go b/core/headerchain.go index 8365f2127d..3519e2cb92 100644 --- a/core/headerchain.go +++ b/core/headerchain.go @@ -26,9 +26,11 @@ import ( "sync/atomic" "time" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" @@ -66,9 +68,10 @@ type HeaderChain struct { } // NewHeaderChain creates a new HeaderChain structure. -// getValidator should return the parent's validator -// procInterrupt points to the parent's interrupt semaphore -// wg points to the parent's shutdown wait group +// +// getValidator should return the parent's validator +// procInterrupt points to the parent's interrupt semaphore +// wg points to the parent's shutdown wait group func NewHeaderChain(chainDb ethdb.Database, config *params.ChainConfig, engine consensus.Engine, procInterrupt func() bool) (*HeaderChain, error) { headerCache, _ := lru.New(headerCacheLimit) tdCache, _ := lru.New(tdCacheLimit) @@ -97,7 +100,7 @@ func NewHeaderChain(chainDb ethdb.Database, config *params.ChainConfig, engine c } hc.currentHeader.Store(hc.genesisHeader) - if head := GetHeadBlockHash(chainDb); head != (common.Hash{}) { + if head := rawdb.GetHeadBlockHash(chainDb); head != (common.Hash{}) { if chead := hc.GetHeaderByHash(head); chead != nil { hc.currentHeader.Store(chead) } @@ -113,8 +116,8 @@ func (hc *HeaderChain) GetBlockNumber(hash common.Hash) uint64 { if cached, ok := hc.numberCache.Get(hash); ok { return cached.(uint64) } - number := GetBlockNumber(hc.chainDb, hash) - if number != missingNumber { + number := rawdb.GetBlockNumber(hc.chainDb, hash) + if number != rawdb.MissingNumber { hc.numberCache.Add(hash, number) } return number @@ -147,7 +150,7 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er if err := hc.WriteTd(hash, number, externTd); err != nil { log.Crit("Failed to write header total difficulty", "err", err) } - if err := WriteHeader(hc.chainDb, header); err != nil { + if err := rawdb.WriteHeader(hc.chainDb, header); err != nil { log.Crit("Failed to write header content", "err", err) } // If the total difficulty is higher than our known, add it to the canonical chain @@ -156,11 +159,11 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er if externTd.Cmp(localTd) > 0 || (externTd.Cmp(localTd) == 0 && mrand.Float64() < 0.5) { // Delete any canonical number assignments above the new head for i := number + 1; ; i++ { - hash := GetCanonicalHash(hc.chainDb, i) + hash := rawdb.GetCanonicalHash(hc.chainDb, i) if hash == (common.Hash{}) { break } - DeleteCanonicalHash(hc.chainDb, i) + rawdb.DeleteCanonicalHash(hc.chainDb, i) } // Overwrite any stale canonical number assignments var ( @@ -168,18 +171,18 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er headNumber = header.Number.Uint64() - 1 headHeader = hc.GetHeader(headHash, headNumber) ) - for GetCanonicalHash(hc.chainDb, headNumber) != headHash { - WriteCanonicalHash(hc.chainDb, headHash, headNumber) + for rawdb.GetCanonicalHash(hc.chainDb, headNumber) != headHash { + 
rawdb.WriteCanonicalHash(hc.chainDb, headHash, headNumber) headHash = headHeader.ParentHash headNumber = headHeader.Number.Uint64() - 1 headHeader = hc.GetHeader(headHash, headNumber) } // Extend the canonical chain with the new header - if err := WriteCanonicalHash(hc.chainDb, hash, number); err != nil { + if err := rawdb.WriteCanonicalHash(hc.chainDb, hash, number); err != nil { log.Crit("Failed to insert header number", "err", err) } - if err := WriteHeadHeaderHash(hc.chainDb, hash); err != nil { + if err := rawdb.WriteHeadHeaderHash(hc.chainDb, hash); err != nil { log.Crit("Failed to insert head header hash", "err", err) } hc.currentHeaderHash = hash @@ -316,7 +319,7 @@ func (hc *HeaderChain) GetTd(hash common.Hash, number uint64) *big.Int { if cached, ok := hc.tdCache.Get(hash); ok { return cached.(*big.Int) } - td := GetTd(hc.chainDb, hash, number) + td := rawdb.GetTd(hc.chainDb, hash, number) if td == nil { return nil } @@ -334,7 +337,7 @@ func (hc *HeaderChain) GetTdByHash(hash common.Hash) *big.Int { // WriteTd stores a block's total difficulty into the database, also caching it // along the way. func (hc *HeaderChain) WriteTd(hash common.Hash, number uint64, td *big.Int) error { - if err := WriteTd(hc.chainDb, hash, number, td); err != nil { + if err := rawdb.WriteTd(hc.chainDb, hash, number, td); err != nil { return err } hc.tdCache.Add(hash, new(big.Int).Set(td)) @@ -348,7 +351,7 @@ func (hc *HeaderChain) GetHeader(hash common.Hash, number uint64) *types.Header if header, ok := hc.headerCache.Get(hash); ok { return header.(*types.Header) } - header := GetHeader(hc.chainDb, hash, number) + header := rawdb.GetHeader(hc.chainDb, hash, number) if header == nil { return nil } @@ -368,14 +371,14 @@ func (hc *HeaderChain) HasHeader(hash common.Hash, number uint64) bool { if hc.numberCache.Contains(hash) || hc.headerCache.Contains(hash) { return true } - ok, _ := hc.chainDb.Has(headerKey(hash, number)) + ok, _ := hc.chainDb.Has(rawdb.HeaderKey(hash, number)) return ok } // GetHeaderByNumber retrieves a block header from the database by number, // caching it (associated with its hash) if found. func (hc *HeaderChain) GetHeaderByNumber(number uint64) *types.Header { - hash := GetCanonicalHash(hc.chainDb, number) + hash := rawdb.GetCanonicalHash(hc.chainDb, number) if hash == (common.Hash{}) { return nil } @@ -390,7 +393,7 @@ func (hc *HeaderChain) CurrentHeader() *types.Header { // SetCurrentHeader sets the current head header of the canonical chain. 
func (hc *HeaderChain) SetCurrentHeader(head *types.Header) { - if err := WriteHeadHeaderHash(hc.chainDb, head.Hash()); err != nil { + if err := rawdb.WriteHeadHeaderHash(hc.chainDb, head.Hash()); err != nil { log.Crit("Failed to insert head header hash", "err", err) } hc.currentHeader.Store(head) @@ -416,13 +419,13 @@ func (hc *HeaderChain) SetHead(head uint64, delFn DeleteCallback) { if delFn != nil { delFn(hash, num) } - DeleteHeader(hc.chainDb, hash, num) - DeleteTd(hc.chainDb, hash, num) + rawdb.DeleteHeader(hc.chainDb, hash, num) + rawdb.DeleteTd(hc.chainDb, hash, num) hc.currentHeader.Store(hc.GetHeader(hdr.ParentHash, hdr.Number.Uint64()-1)) } // Roll back the canonical chain numbering for i := height; i > head; i-- { - DeleteCanonicalHash(hc.chainDb, i) + rawdb.DeleteCanonicalHash(hc.chainDb, i) } // Clear out any stale content from the caches hc.headerCache.Purge() @@ -434,7 +437,7 @@ func (hc *HeaderChain) SetHead(head uint64, delFn DeleteCallback) { } hc.currentHeaderHash = hc.CurrentHeader().Hash() - if err := WriteHeadHeaderHash(hc.chainDb, hc.currentHeaderHash); err != nil { + if err := rawdb.WriteHeadHeaderHash(hc.chainDb, hc.currentHeaderHash); err != nil { log.Crit("Failed to reset head header hash", "err", err) } } diff --git a/core/database_util.go b/core/rawdb/database_util.go similarity index 98% rename from core/database_util.go rename to core/rawdb/database_util.go index a5ab18687d..3b3eb01727 100644 --- a/core/database_util.go +++ b/core/rawdb/database_util.go @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package core +package rawdb import ( "bytes" @@ -22,7 +22,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "github.com/tomochain/tomochain/common" @@ -100,16 +99,16 @@ func GetCanonicalHash(db DatabaseReader, number uint64) common.Hash { return common.BytesToHash(data) } -// missingNumber is returned by GetBlockNumber if no header with the +// MissingNumber is returned by GetBlockNumber if no header with the // given block hash has been stored in the database -const missingNumber = uint64(0xffffffffffffffff) +const MissingNumber = uint64(0xffffffffffffffff) // GetBlockNumber returns the block number assigned to a block hash // if the corresponding header is present in the database func GetBlockNumber(db DatabaseReader, hash common.Hash) uint64 { data, _ := db.Get(append(blockHashPrefix, hash.Bytes()...)) if len(data) != 8 { - return missingNumber + return MissingNumber } return binary.BigEndian.Uint64(data) } @@ -161,7 +160,7 @@ func GetTrieSyncProgress(db DatabaseReader) uint64 { // GetHeaderRLP retrieves a block header in its raw RLP database encoding, or nil // if the header's not found. func GetHeaderRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue { - data, _ := db.Get(headerKey(hash, number)) + data, _ := db.Get(HeaderKey(hash, number)) return data } @@ -182,15 +181,15 @@ func GetHeader(db DatabaseReader, hash common.Hash, number uint64) *types.Header // GetBodyRLP retrieves the block body (transactions and uncles) in RLP encoding. 
func GetBodyRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue { - data, _ := db.Get(blockBodyKey(hash, number)) + data, _ := db.Get(BlockBodyKey(hash, number)) return data } -func headerKey(hash common.Hash, number uint64) []byte { +func HeaderKey(hash common.Hash, number uint64) []byte { return append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...) } -func blockBodyKey(hash common.Hash, number uint64) []byte { +func BlockBodyKey(hash common.Hash, number uint64) []byte { return append(append(bodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...) } @@ -555,7 +554,7 @@ func DeleteTxLookupEntry(db DatabaseDeleter, hash common.Hash) { // PreimageTable returns a Database instance with the key prefix for preimage entries. func PreimageTable(db ethdb.Database) ethdb.Database { - return rawdb.NewTable(db, preimagePrefix) + return NewTable(db, preimagePrefix) } // WritePreimages writes the provided set of preimages to the database. `number` is the diff --git a/core/database_util_test.go b/core/rawdb/database_util_test.go similarity index 97% rename from core/database_util_test.go rename to core/rawdb/database_util_test.go index f28ca160a5..8f4385c580 100644 --- a/core/database_util_test.go +++ b/core/rawdb/database_util_test.go @@ -14,11 +14,10 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package core +package rawdb import ( "bytes" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" @@ -30,7 +29,7 @@ import ( // Tests block header storage and retrieval operations. func TestHeaderStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() // Create a test header to move around the database and make sure it's really new header := &types.Header{Number: big.NewInt(42), Extra: []byte("test header")} @@ -65,7 +64,7 @@ func TestHeaderStorage(t *testing.T) { // Tests block body storage and retrieval operations. func TestBodyStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() // Create a test body to move around the database and make sure it's really new body := &types.Body{Uncles: []*types.Header{{Extra: []byte("test header")}}} @@ -105,7 +104,7 @@ func TestBodyStorage(t *testing.T) { // Tests block storage and retrieval operations. func TestBlockStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() // Create a test block to move around the database and make sure it's really new block := types.NewBlockWithHeader(&types.Header{ @@ -157,7 +156,7 @@ func TestBlockStorage(t *testing.T) { // Tests that partial block contents don't get reassembled into full blocks. func TestPartialBlockStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() block := types.NewBlockWithHeader(&types.Header{ Extra: []byte("test block"), UncleHash: types.EmptyUncleHash, @@ -198,7 +197,7 @@ func TestPartialBlockStorage(t *testing.T) { // Tests block total difficulty storage and retrieval operations. func TestTdStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() // Create a test TD to move around the database and make sure it's really new hash, td := common.Hash{}, big.NewInt(314) @@ -223,7 +222,7 @@ func TestTdStorage(t *testing.T) { // Tests that canonical numbers can be mapped to hashes and retrieved. 
func TestCanonicalMappingStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() // Create a test canonical number and assinged hash to move around hash, number := common.Hash{0: 0xff}, uint64(314) @@ -248,7 +247,7 @@ func TestCanonicalMappingStorage(t *testing.T) { // Tests that head headers and head blocks can be assigned, individually. func TestHeadStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() blockHead := types.NewBlockWithHeader(&types.Header{Extra: []byte("test block header")}) blockFull := types.NewBlockWithHeader(&types.Header{Extra: []byte("test block full")}) @@ -288,7 +287,7 @@ func TestHeadStorage(t *testing.T) { // Tests that positional lookup metadata can be stored and retrieved. func TestLookupStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() tx1 := types.NewTransaction(1, common.BytesToAddress([]byte{0x11}), big.NewInt(111), 1111, big.NewInt(11111), []byte{0x11, 0x11, 0x11}) tx2 := types.NewTransaction(2, common.BytesToAddress([]byte{0x22}), big.NewInt(222), 2222, big.NewInt(22222), []byte{0x22, 0x22, 0x22}) @@ -333,7 +332,7 @@ func TestLookupStorage(t *testing.T) { // Tests that receipts associated with a single block can be stored and retrieved. func TestBlockReceiptStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() receipt1 := &types.Receipt{ Status: types.ReceiptStatusFailed, diff --git a/eth/api.go b/eth/api.go index 76a466a49f..e885f6d600 100644 --- a/eth/api.go +++ b/eth/api.go @@ -28,6 +28,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/log" @@ -343,7 +344,7 @@ func NewPrivateDebugAPI(config *params.ChainConfig, eth *Ethereum) *PrivateDebug // Preimage is a debug API function that returns the preimage for a sha3 hash, if known. 
func (api *PrivateDebugAPI) Preimage(ctx context.Context, hash common.Hash) (hexutil.Bytes, error) { - db := core.PreimageTable(api.eth.ChainDb()) + db := rawdb.PreimageTable(api.eth.ChainDb()) return db.Get(hash.Bytes()) } @@ -494,11 +495,10 @@ func (api *PublicEthereumAPI) ChainId() hexutil.Uint64 { } // GetOwner return masternode owner of the given coinbase address -func (api *PublicEthereumAPI) GetOwnerByCoinbase(ctx context.Context, coinbase common.Address, blockNr rpc.BlockNumber) (common.Address, error) { +func (api *PublicEthereumAPI) GetOwnerByCoinbase(ctx context.Context, coinbase common.Address, blockNr rpc.BlockNumber) (common.Address, error) { statedb, _, err := api.e.ApiBackend.StateAndHeaderByNumber(ctx, blockNr) if err != nil { return common.Address{}, err } return statedb.GetOwner(coinbase), nil } - diff --git a/eth/api_backend.go b/eth/api_backend.go index 67554b4480..13020714b1 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -21,23 +21,19 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending" "io/ioutil" "math/big" "path/filepath" - "github.com/tomochain/tomochain/tomox" - - "github.com/tomochain/tomochain/consensus/posv" - "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus" + "github.com/tomochain/tomochain/consensus/posv" "github.com/tomochain/tomochain/contracts" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" stateDatabase "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" @@ -50,6 +46,9 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending" ) // EthApiBackend implements ethapi.Backend for full nodes @@ -117,11 +116,11 @@ func (b *EthApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*t } func (b *EthApiBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) { - return core.GetBlockReceipts(b.eth.chainDb, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)), nil + return rawdb.GetBlockReceipts(b.eth.chainDb, blockHash, rawdb.GetBlockNumber(b.eth.chainDb, blockHash)), nil } func (b *EthApiBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) { - receipts := core.GetBlockReceipts(b.eth.chainDb, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) + receipts := rawdb.GetBlockReceipts(b.eth.chainDb, blockHash, rawdb.GetBlockNumber(b.eth.chainDb, blockHash)) if receipts == nil { return nil, nil } diff --git a/eth/api_tracer.go b/eth/api_tracer.go index e1744dc2c1..6aa3ba3171 100644 --- a/eth/api_tracer.go +++ b/eth/api_tracer.go @@ -21,7 +21,6 @@ import ( "context" "errors" "fmt" - "github.com/tomochain/tomochain/tomox/tradingstate" "io/ioutil" "math/big" "runtime" @@ -31,6 +30,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -39,6 
+39,7 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox/tradingstate" "github.com/tomochain/tomochain/trie" ) @@ -567,14 +568,14 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (* } size, _ := database.TrieDB().Size() log.Info("Historical state regenerated", "block", block.NumberU64(), "elapsed", time.Since(start), "size", size) - return statedb,tomoxState, nil + return statedb, tomoxState, nil } // TraceTransaction returns the structured logs created during the execution of EVM // and returns them as a JSON object. func (api *PrivateDebugAPI) TraceTransaction(ctx context.Context, hash common.Hash, config *TraceConfig) (interface{}, error) { // Retrieve the transaction and assemble its EVM context - tx, blockHash, _, index := core.GetTransaction(api.eth.ChainDb(), hash) + tx, blockHash, _, index := rawdb.GetTransaction(api.eth.ChainDb(), hash) if tx == nil { return nil, fmt.Errorf("transaction %x not found", hash) } diff --git a/eth/backend.go b/eth/backend.go index 412c67d230..8bd7806bfc 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -18,6 +18,7 @@ package eth import ( + "bytes" "errors" "fmt" "math/big" @@ -27,18 +28,10 @@ import ( "sync/atomic" "time" - "github.com/tomochain/tomochain/tomoxlending" - - "github.com/tomochain/tomochain/accounts/abi/bind" - "github.com/tomochain/tomochain/common/hexutil" - "github.com/tomochain/tomochain/core/state" - "github.com/tomochain/tomochain/eth/filters" - "github.com/tomochain/tomochain/rlp" - - "bytes" - "github.com/tomochain/tomochain/accounts" + "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/consensus/posv" @@ -46,11 +39,12 @@ import ( contractValidator "github.com/tomochain/tomochain/contracts/validator/contract" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" - - //"github.com/tomochain/tomochain/core/state" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/eth/downloader" + "github.com/tomochain/tomochain/eth/filters" "github.com/tomochain/tomochain/eth/gasprice" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" @@ -60,8 +54,10 @@ import ( "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomoxlending" ) type LesServer interface { @@ -160,11 +156,11 @@ func New(ctx *node.ServiceContext, config *Config, tomoXServ *tomox.TomoX, lendi log.Info("Initialising Ethereum protocol", "versions", ProtocolVersions, "network", config.NetworkId) if !config.SkipBcVersionCheck { - bcVersion := core.GetBlockChainVersion(chainDb) + bcVersion := rawdb.GetBlockChainVersion(chainDb) if bcVersion != core.BlockChainVersion && bcVersion != 0 { return nil, fmt.Errorf("Blockchain DB version mismatch (%d / %d). 
Run geth upgradedb.\n", bcVersion, core.BlockChainVersion) } - core.WriteBlockChainVersion(chainDb, core.BlockChainVersion) + rawdb.WriteBlockChainVersion(chainDb, core.BlockChainVersion) } var ( vmConfig = vm.Config{EnablePreimageRecording: config.EnablePreimageRecording} @@ -187,7 +183,7 @@ func New(ctx *node.ServiceContext, config *Config, tomoXServ *tomox.TomoX, lendi if compat, ok := genesisErr.(*params.ConfigCompatError); ok { log.Warn("Rewinding chain to upgrade configuration", "err", compat) eth.blockchain.SetHead(compat.RewindTo) - core.WriteChainConfig(chainDb, genesisHash, chainConfig) + rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig) } eth.bloomIndexer.Start(eth.blockchain) diff --git a/eth/bloombits.go b/eth/bloombits.go index abe8c5d671..39695f43e8 100644 --- a/eth/bloombits.go +++ b/eth/bloombits.go @@ -17,13 +17,13 @@ package eth import ( - "github.com/tomochain/tomochain/core/rawdb" "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/bitutil" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/params" @@ -61,8 +61,8 @@ func (eth *Ethereum) startBloomHandlers() { task := <-request task.Bitsets = make([][]byte, len(task.Sections)) for i, section := range task.Sections { - head := core.GetCanonicalHash(eth.chainDb, (section+1)*params.BloomBitsBlocks-1) - if compVector, err := core.GetBloomBits(eth.chainDb, task.Bit, section, head); err == nil { + head := rawdb.GetCanonicalHash(eth.chainDb, (section+1)*params.BloomBitsBlocks-1) + if compVector, err := rawdb.GetBloomBits(eth.chainDb, task.Bit, section, head); err == nil { if blob, err := bitutil.DecompressBytes(compVector, int(params.BloomBitsBlocks)/8); err == nil { task.Bitsets[i] = blob } else { @@ -108,7 +108,7 @@ func NewBloomIndexer(db ethdb.Database, size uint64) *core.ChainIndexer { db: db, size: size, } - table := rawdb.NewTable(db, string(core.BloomBitsIndexPrefix)) + table := rawdb.NewTable(db, string(rawdb.BloomBitsIndexPrefix)) return core.NewChainIndexer(db, table, backend, size, bloomConfirms, bloomThrottling, "bloombits") } @@ -138,7 +138,7 @@ func (b *BloomIndexer) Commit() error { if err != nil { return err } - core.WriteBloomBits(batch, uint(i), b.section, b.head, bitutil.CompressBytes(bits)) + rawdb.WriteBloomBits(batch, uint(i), b.section, b.head, bitutil.CompressBytes(bits)) } return batch.Write() } diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index eba7fad779..f9faf2ff6d 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -27,7 +27,7 @@ import ( "github.com/tomochain/tomochain" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" @@ -225,7 +225,7 @@ func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, chain BlockC stateCh: make(chan dataPack), stateSyncStart: make(chan *stateSync), syncStatsState: stateSyncStats{ - processed: core.GetTrieSyncProgress(stateDb), + processed: rawdb.GetTrieSyncProgress(stateDb), }, trackStateReq: make(chan *stateReq), } @@ -975,22 +975,22 @@ func (d *Downloader) fetchReceipts(from uint64) error { // various callbacks to handle the slight differences between processing them. 
// // The instrumentation parameters: -// - errCancel: error type to return if the fetch operation is cancelled (mostly makes logging nicer) -// - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers) -// - deliver: processing callback to deliver data packets into type specific download queues (usually within `queue`) -// - wakeCh: notification channel for waking the fetcher when new tasks are available (or sync completed) -// - expire: task callback method to abort requests that took too long and return the faulty peers (traffic shaping) -// - pending: task callback for the number of requests still needing download (detect completion/non-completability) -// - inFlight: task callback for the number of in-progress requests (wait for all active downloads to finish) -// - throttle: task callback to check if the processing queue is full and activate throttling (bound memory use) -// - reserve: task callback to reserve new download tasks to a particular peer (also signals partial completions) -// - fetchHook: tester callback to notify of new tasks being initiated (allows testing the scheduling logic) -// - fetch: network callback to actually send a particular download request to a physical remote peer -// - cancel: task callback to abort an in-flight download request and allow rescheduling it (in case of lost peer) -// - capacity: network callback to retrieve the estimated type-specific bandwidth capacity of a peer (traffic shaping) -// - idle: network callback to retrieve the currently (type specific) idle peers that can be assigned tasks -// - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping) -// - kind: textual label of the type being downloaded to display in log mesages +// - errCancel: error type to return if the fetch operation is cancelled (mostly makes logging nicer) +// - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers) +// - deliver: processing callback to deliver data packets into type specific download queues (usually within `queue`) +// - wakeCh: notification channel for waking the fetcher when new tasks are available (or sync completed) +// - expire: task callback method to abort requests that took too long and return the faulty peers (traffic shaping) +// - pending: task callback for the number of requests still needing download (detect completion/non-completability) +// - inFlight: task callback for the number of in-progress requests (wait for all active downloads to finish) +// - throttle: task callback to check if the processing queue is full and activate throttling (bound memory use) +// - reserve: task callback to reserve new download tasks to a particular peer (also signals partial completions) +// - fetchHook: tester callback to notify of new tasks being initiated (allows testing the scheduling logic) +// - fetch: network callback to actually send a particular download request to a physical remote peer +// - cancel: task callback to abort an in-flight download request and allow rescheduling it (in case of lost peer) +// - capacity: network callback to retrieve the estimated type-specific bandwidth capacity of a peer (traffic shaping) +// - idle: network callback to retrieve the currently (type specific) idle peers that can be assigned tasks +// - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping) +// - kind: textual label of the type being downloaded to display in 
log messages func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliver func(dataPack) (int, error), wakeCh chan bool, expire func() map[string]int, pending func() int, inFlight func() bool, throttle func() bool, reserve func(*peerConnection, int) (*fetchRequest, bool, error), fetchHook func([]*types.Header), fetch func(*peerConnection, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peerConnection) int, diff --git a/eth/downloader/fakepeer.go b/eth/downloader/fakepeer.go index 4d7c5ac280..26e12307d1 100644 --- a/eth/downloader/fakepeer.go +++ b/eth/downloader/fakepeer.go @@ -21,6 +21,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" ) @@ -126,7 +127,7 @@ func (p *FakePeer) RequestBodies(hashes []common.Hash) error { uncles [][]*types.Header ) for _, hash := range hashes { - block := core.GetBlock(p.db, hash, p.hc.GetBlockNumber(hash)) + block := rawdb.GetBlock(p.db, hash, p.hc.GetBlockNumber(hash)) txs = append(txs, block.Transactions()) uncles = append(uncles, block.Uncles()) @@ -140,7 +141,7 @@ func (p *FakePeer) RequestReceipts(hashes []common.Hash) error { var receipts [][]*types.Receipt for _, hash := range hashes { - receipts = append(receipts, core.GetBlockReceipts(p.db, hash, p.hc.GetBlockNumber(hash))) + receipts = append(receipts, rawdb.GetBlockReceipts(p.db, hash, p.hc.GetBlockNumber(hash))) } p.dl.DeliverReceipts(p.id, receipts) return nil diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go index 3809a0c579..747c9f9cff 100644 --- a/eth/downloader/statesync.go +++ b/eth/downloader/statesync.go @@ -18,16 +18,16 @@ package downloader import ( "fmt" - "github.com/tomochain/tomochain/ethdb/memorydb" "hash" "sync" "time" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/crypto/sha3" "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/ethdb/memorydb" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/trie" ) @@ -470,6 +470,6 @@ func (s *stateSync) updateStats(written, duplicate, unexpected int, duration tim log.Info("Imported new state entries", "count", written, "elapsed", common.PrettyDuration(duration), "processed", s.d.syncStatsState.processed, "pending", s.d.syncStatsState.pending, "retry", len(s.tasks), "duplicate", s.d.syncStatsState.duplicate, "unexpected", s.d.syncStatsState.unexpected) } if written > 0 { - core.WriteTrieSyncProgress(s.d.stateDB, s.d.syncStatsState.processed) + rawdb.WriteTrieSyncProgress(s.d.stateDB, s.d.syncStatsState.processed) } } diff --git a/eth/filters/bench_test.go b/eth/filters/bench_test.go index 3648a3db2f..9822a85e43 100644 --- a/eth/filters/bench_test.go +++ b/eth/filters/bench_test.go @@ -20,14 +20,13 @@ import ( "bytes" "context" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "testing" "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/bitutil" - "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" @@ -68,18 +67,18 @@ func
benchmarkBloomBits(b *testing.B, sectionSize uint64) { benchDataDir := node.DefaultDataDir() + "/geth/chaindata" fmt.Println("Running bloombits benchmark section size:", sectionSize) - db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024,"") + db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024, "") if err != nil { b.Fatalf("error opening database at %v: %v", benchDataDir, err) } - head := core.GetHeadBlockHash(db) + head := rawdb.GetHeadBlockHash(db) if head == (common.Hash{}) { b.Fatalf("chain data not found at %v", benchDataDir) } clearBloomBits(db) fmt.Println("Generating bloombits data...") - headNum := core.GetBlockNumber(db, head) + headNum := rawdb.GetBlockNumber(db, head) if headNum < sectionSize+512 { b.Fatalf("not enough blocks for running a benchmark") } @@ -94,14 +93,14 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) { } var header *types.Header for i := sectionIdx * sectionSize; i < (sectionIdx+1)*sectionSize; i++ { - hash := core.GetCanonicalHash(db, i) - header = core.GetHeader(db, hash, i) + hash := rawdb.GetCanonicalHash(db, i) + header = rawdb.GetHeader(db, hash, i) if header == nil { b.Fatalf("Error creating bloomBits data") } bc.AddBloom(uint(i-sectionIdx*sectionSize), header.Bloom) } - sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*sectionSize-1) + sectionHead := rawdb.GetCanonicalHash(db, (sectionIdx+1)*sectionSize-1) for i := 0; i < types.BloomBitLength; i++ { data, err := bc.Bitset(uint(i)) if err != nil { @@ -110,7 +109,7 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) { comp := bitutil.CompressBytes(data) dataSize += uint64(len(data)) compSize += uint64(len(comp)) - core.WriteBloomBits(db, uint(i), sectionIdx, sectionHead, comp) + rawdb.WriteBloomBits(db, uint(i), sectionIdx, sectionHead, comp) } //if sectionIdx%50 == 0 { // fmt.Println(" section", sectionIdx, "/", cnt) @@ -130,7 +129,7 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) { for i := 0; i < benchFilterCnt; i++ { if i%20 == 0 { db.Close() - db, _ = rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024,"") + db, _ = rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024, "") backend = &testBackend{mux, db, cnt, new(event.Feed), new(event.Feed), new(event.Feed), new(event.Feed)} } var addr common.Address @@ -148,7 +147,7 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) { } func forEachKey(db ethdb.Database, startPrefix, endPrefix []byte, fn func(key []byte)) { - it := db.NewIterator(startPrefix,nil) + it := db.NewIterator(startPrefix, nil) for it.Next() { key := it.Key() cmpLen := len(key) @@ -176,15 +175,15 @@ func clearBloomBits(db ethdb.Database) { func BenchmarkNoBloomBits(b *testing.B) { benchDataDir := node.DefaultDataDir() + "/geth/chaindata" fmt.Println("Running benchmark without bloombits") - db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024,"") + db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024, "") if err != nil { b.Fatalf("error opening database at %v: %v", benchDataDir, err) } - head := core.GetHeadBlockHash(db) + head := rawdb.GetHeadBlockHash(db) if head == (common.Hash{}) { b.Fatalf("chain data not found at %v", benchDataDir) } - headNum := core.GetBlockNumber(db, head) + headNum := rawdb.GetBlockNumber(db, head) clearBloomBits(db) diff --git a/eth/filters/filter_system.go b/eth/filters/filter_system.go index 3d92fc1ac7..75c3c5e417 100644 --- a/eth/filters/filter_system.go +++ b/eth/filters/filter_system.go @@ -28,6 +28,7 @@ import ( ethereum "github.com/tomochain/tomochain" 
"github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/rpc" @@ -348,11 +349,11 @@ func (es *EventSystem) lightFilterNewHead(newHeader *types.Header, callBack func for oldh.Hash() != newh.Hash() { if oldh.Number.Uint64() >= newh.Number.Uint64() { oldHeaders = append(oldHeaders, oldh) - oldh = core.GetHeader(es.backend.ChainDb(), oldh.ParentHash, oldh.Number.Uint64()-1) + oldh = rawdb.GetHeader(es.backend.ChainDb(), oldh.ParentHash, oldh.Number.Uint64()-1) } if oldh.Number.Uint64() < newh.Number.Uint64() { newHeaders = append(newHeaders, newh) - newh = core.GetHeader(es.backend.ChainDb(), newh.ParentHash, newh.Number.Uint64()-1) + newh = rawdb.GetHeader(es.backend.ChainDb(), newh.ParentHash, newh.Number.Uint64()-1) if newh == nil { // happens when CHT syncing, nothing to do newh = oldh diff --git a/eth/filters/filter_system_test.go b/eth/filters/filter_system_test.go index d947a672ac..dbd0195262 100644 --- a/eth/filters/filter_system_test.go +++ b/eth/filters/filter_system_test.go @@ -19,7 +19,6 @@ package filters import ( "context" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "math/rand" "reflect" @@ -31,6 +30,7 @@ import ( "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" @@ -60,23 +60,23 @@ func (b *testBackend) HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumbe var hash common.Hash var num uint64 if blockNr == rpc.LatestBlockNumber { - hash = core.GetHeadBlockHash(b.db) - num = core.GetBlockNumber(b.db, hash) + hash = rawdb.GetHeadBlockHash(b.db) + num = rawdb.GetBlockNumber(b.db, hash) } else { num = uint64(blockNr) - hash = core.GetCanonicalHash(b.db, num) + hash = rawdb.GetCanonicalHash(b.db, num) } - return core.GetHeader(b.db, hash, num), nil + return rawdb.GetHeader(b.db, hash, num), nil } func (b *testBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) { - number := core.GetBlockNumber(b.db, blockHash) - return core.GetBlockReceipts(b.db, blockHash, number), nil + number := rawdb.GetBlockNumber(b.db, blockHash) + return rawdb.GetBlockReceipts(b.db, blockHash, number), nil } func (b *testBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) { - number := core.GetBlockNumber(b.db, blockHash) - receipts := core.GetBlockReceipts(b.db, blockHash, number) + number := rawdb.GetBlockNumber(b.db, blockHash) + receipts := rawdb.GetBlockReceipts(b.db, blockHash, number) logs := make([][]*types.Log, len(receipts)) for i, receipt := range receipts { @@ -122,8 +122,8 @@ func (b *testBackend) ServiceFilter(ctx context.Context, session *bloombits.Matc task.Bitsets = make([][]byte, len(task.Sections)) for i, section := range task.Sections { if rand.Int()%4 != 0 { // Handle occasional missing deliveries - head := core.GetCanonicalHash(b.db, (section+1)*params.BloomBitsBlocks-1) - task.Bitsets[i], _ = core.GetBloomBits(b.db, task.Bit, section, head) + head := rawdb.GetCanonicalHash(b.db, (section+1)*params.BloomBitsBlocks-1) + task.Bitsets[i], _ = rawdb.GetBloomBits(b.db, task.Bit, section, head) } } request <- task diff --git a/eth/filters/filter_test.go 
b/eth/filters/filter_test.go index bdfb6e37f8..4db307b056 100644 --- a/eth/filters/filter_test.go +++ b/eth/filters/filter_test.go @@ -18,7 +18,6 @@ package filters import ( "context" - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "os" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/event" @@ -50,7 +50,7 @@ func BenchmarkFilters(b *testing.B) { defer os.RemoveAll(dir) var ( - db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0,"") + db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0, "") mux = new(event.TypeMux) txFeed = new(event.Feed) rmLogsFeed = new(event.Feed) @@ -84,14 +84,14 @@ func BenchmarkFilters(b *testing.B) { } }) for i, block := range chain { - core.WriteBlock(db, block) - if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { + rawdb.WriteBlock(db, block) + if err := rawdb.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { b.Fatalf("failed to insert block number: %v", err) } - if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil { + if err := rawdb.WriteHeadBlockHash(db, block.Hash()); err != nil { b.Fatalf("failed to insert block number: %v", err) } - if err := core.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts[i]); err != nil { + if err := rawdb.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts[i]); err != nil { b.Fatal("error writing block receipts:", err) } } @@ -115,7 +115,7 @@ func TestFilters(t *testing.T) { defer os.RemoveAll(dir) var ( - db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0,"") + db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0, "") mux = new(event.TypeMux) txFeed = new(event.Feed) rmLogsFeed = new(event.Feed) @@ -174,14 +174,14 @@ func TestFilters(t *testing.T) { } }) for i, block := range chain { - core.WriteBlock(db, block) - if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { + rawdb.WriteBlock(db, block) + if err := rawdb.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { t.Fatalf("failed to insert block number: %v", err) } - if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil { + if err := rawdb.WriteHeadBlockHash(db, block.Hash()); err != nil { t.Fatalf("failed to insert block number: %v", err) } - if err := core.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts[i]); err != nil { + if err := rawdb.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts[i]); err != nil { t.Fatal("error writing block receipts:", err) } } diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 33376a1071..e16d3557bc 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -21,14 +21,11 @@ import ( "context" "errors" "fmt" - "github.com/tomochain/tomochain/tomoxlending/lendingstate" "math/big" "sort" "strings" "time" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/util" "github.com/tomochain/tomochain/accounts" @@ -41,6 +38,7 @@ import ( "github.com/tomochain/tomochain/consensus/posv" contractValidator "github.com/tomochain/tomochain/contracts/validator/contract" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" 
"github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -50,6 +48,8 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending/lendingstate" ) const ( @@ -424,7 +424,8 @@ func (s *PrivateAccountAPI) SignTransaction(ctx context.Context, args SendTxArgs // safely used to calculate a signature from. // // The hash is calulcated as -// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}). +// +// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}). // // This gives context to the signed message and prevents signing of transactions. func signHash(data []byte) []byte { @@ -1305,8 +1306,8 @@ func (s *PublicBlockChainAPI) findNearestSignedBlock(ctx context.Context, b *typ } /* - findFinalityOfBlock return finality of a block - Use blocksHashCache for to keep track - refer core/blockchain.go for more detail +findFinalityOfBlock return finality of a block +Use blocksHashCache for to keep track - refer core/blockchain.go for more detail */ func (s *PublicBlockChainAPI) findFinalityOfBlock(ctx context.Context, b *types.Block, masternodes []common.Address) (uint, error) { engine, _ := s.b.GetEngine().(*posv.Posv) @@ -1371,7 +1372,7 @@ func (s *PublicBlockChainAPI) findFinalityOfBlock(ctx context.Context, b *types. } /* - Extract signers from block +Extract signers from block */ func (s *PublicBlockChainAPI) getSigners(ctx context.Context, block *types.Block, engine *posv.Posv) ([]common.Address, error) { var err error @@ -1594,7 +1595,7 @@ func (s *PublicTransactionPoolAPI) GetTransactionCount(ctx context.Context, addr // GetTransactionByHash returns the transaction for the given hash func (s *PublicTransactionPoolAPI) GetTransactionByHash(ctx context.Context, hash common.Hash) *RPCTransaction { // Try to return an already finalized transaction - if tx, blockHash, blockNumber, index := core.GetTransaction(s.b.ChainDb(), hash); tx != nil { + if tx, blockHash, blockNumber, index := rawdb.GetTransaction(s.b.ChainDb(), hash); tx != nil { return newRPCTransaction(tx, blockHash, blockNumber, index) } // No finalized transaction, try to retrieve it from the pool @@ -1610,7 +1611,7 @@ func (s *PublicTransactionPoolAPI) GetRawTransactionByHash(ctx context.Context, var tx *types.Transaction // Retrieve a finalized transaction, or a pooled otherwise - if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil { + if tx, _, _, _ = rawdb.GetTransaction(s.b.ChainDb(), hash); tx == nil { if tx = s.b.GetPoolTransaction(hash); tx == nil { // Transaction not found anywhere, abort return nil, nil @@ -1622,7 +1623,7 @@ func (s *PublicTransactionPoolAPI) GetRawTransactionByHash(ctx context.Context, // GetTransactionReceipt returns the transaction receipt for the given transaction hash. 
func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, hash common.Hash) (map[string]interface{}, error) { - tx, blockHash, blockNumber, index := core.GetTransaction(s.b.ChainDb(), hash) + tx, blockHash, blockNumber, index := rawdb.GetTransaction(s.b.ChainDb(), hash) if tx == nil { return nil, nil } @@ -1867,7 +1868,7 @@ func (s *PublicTomoXTransactionPoolAPI) SendLendingRawTransaction(ctx context.Co func (s *PublicTomoXTransactionPoolAPI) GetOrderTxMatchByHash(ctx context.Context, hash common.Hash) ([]*tradingstate.OrderItem, error) { var tx *types.Transaction orders := []*tradingstate.OrderItem{} - if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil { + if tx, _, _, _ = rawdb.GetTransaction(s.b.ChainDb(), hash); tx == nil { if tx = s.b.GetPoolTransaction(hash); tx == nil { return []*tradingstate.OrderItem{}, nil } @@ -2598,7 +2599,7 @@ func (s *PublicTomoXTransactionPoolAPI) GetBorrows(ctx context.Context, lendingT // GetLendingTxMatchByHash returns lendingItems which have been processed at tx of the given txhash func (s *PublicTomoXTransactionPoolAPI) GetLendingTxMatchByHash(ctx context.Context, hash common.Hash) ([]*lendingstate.LendingItem, error) { var tx *types.Transaction - if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil { + if tx, _, _, _ = rawdb.GetTransaction(s.b.ChainDb(), hash); tx == nil { if tx = s.b.GetPoolTransaction(hash); tx == nil { return []*lendingstate.LendingItem{}, nil } @@ -2614,7 +2615,7 @@ func (s *PublicTomoXTransactionPoolAPI) GetLendingTxMatchByHash(ctx context.Cont // GetLiquidatedTradesByTxHash returns trades which closed by TomoX protocol at the tx of the give hash func (s *PublicTomoXTransactionPoolAPI) GetLiquidatedTradesByTxHash(ctx context.Context, hash common.Hash) (lendingstate.FinalizedResult, error) { var tx *types.Transaction - if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil { + if tx, _, _, _ = rawdb.GetTransaction(s.b.ChainDb(), hash); tx == nil { if tx = s.b.GetPoolTransaction(hash); tx == nil { return lendingstate.FinalizedResult{}, nil } @@ -2965,7 +2966,8 @@ func GetSignersFromBlocks(b Backend, blockNumber uint64, blockHash common.Hash, // GetStakerROI Estimate ROI for stakers using the last epoc reward // then multiple by epoch per year, if the address is not masternode of last epoch - return 0 // Formular: -// ROI = average_latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100 +// +// ROI = average_latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100 func (s *PublicBlockChainAPI) GetStakerROI() float64 { blockNumber := s.b.CurrentBlock().Number().Uint64() lastCheckpointNumber := blockNumber - (blockNumber % s.b.ChainConfig().Posv.Epoch) - s.b.ChainConfig().Posv.Epoch // calculate for 2 epochs ago @@ -2991,7 +2993,8 @@ func (s *PublicBlockChainAPI) GetStakerROI() float64 { // GetStakerROIMasternode Estimate ROI for stakers of a specific masternode using the last epoc reward // then multiple by epoch per year, if the address is not masternode of last epoch - return 0 // Formular: -// ROI = latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100 +// +// ROI = latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100 func (s *PublicBlockChainAPI) GetStakerROIMasternode(masternode common.Address) float64 { votersReward := s.b.GetVotersRewards(masternode) if votersReward == nil { diff --git a/les/api_backend.go b/les/api_backend.go index d8285da97d..31eacf2ae9 100644 --- 
a/les/api_backend.go +++ b/les/api_backend.go @@ -20,20 +20,17 @@ import ( "context" "encoding/json" "errors" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending" "io/ioutil" "math/big" "path/filepath" - "github.com/tomochain/tomochain/tomox" - "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -45,6 +42,9 @@ import ( "github.com/tomochain/tomochain/light" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending" ) type LesApiBackend struct { @@ -94,11 +94,11 @@ func (b *LesApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*t } func (b *LesApiBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) { - return light.GetBlockReceipts(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) + return light.GetBlockReceipts(ctx, b.eth.odr, blockHash, rawdb.GetBlockNumber(b.eth.chainDb, blockHash)) } func (b *LesApiBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) { - return light.GetBlockLogs(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) + return light.GetBlockLogs(ctx, b.eth.odr, blockHash, rawdb.GetBlockNumber(b.eth.chainDb, blockHash)) } func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int { diff --git a/les/backend.go b/les/backend.go index 1a5cae11b8..9cebbd40e4 100644 --- a/les/backend.go +++ b/les/backend.go @@ -28,6 +28,7 @@ import ( "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/eth" "github.com/tomochain/tomochain/eth/downloader" @@ -122,7 +123,7 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { if compat, ok := genesisErr.(*params.ConfigCompatError); ok { log.Warn("Rewinding chain to upgrade configuration", "err", compat) leth.blockchain.SetHead(compat.RewindTo) - core.WriteChainConfig(chainDb, genesisHash, chainConfig) + rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig) } leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay) diff --git a/les/fetcher.go b/les/fetcher.go index 7edfe808bb..80568bc322 100644 --- a/les/fetcher.go +++ b/les/fetcher.go @@ -25,7 +25,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/mclock" "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/light" "github.com/tomochain/tomochain/log" @@ -280,7 +280,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) { // if one of root's children is canonical, keep it, delete other branches and root itself var newRoot *fetcherTreeNode for i, nn := range fp.root.children { - if core.GetCanonicalHash(f.pm.chainDb, nn.number) == nn.hash { + 
if rawdb.GetCanonicalHash(f.pm.chainDb, nn.number) == nn.hash { fp.root.children = append(fp.root.children[:i], fp.root.children[i+1:]...) nn.parent = nil newRoot = nn @@ -363,7 +363,7 @@ func (f *lightFetcher) peerHasBlock(p *peer, hash common.Hash, number uint64) bo // // when syncing, just check if it is part of the known chain, there is nothing better we // can do since we do not know the most recent block hash yet - return core.GetCanonicalHash(f.pm.chainDb, fp.root.number) == fp.root.hash && core.GetCanonicalHash(f.pm.chainDb, number) == hash + return rawdb.GetCanonicalHash(f.pm.chainDb, fp.root.number) == fp.root.hash && rawdb.GetCanonicalHash(f.pm.chainDb, number) == hash } // requestAmount calculates the amount of headers to be downloaded starting diff --git a/les/handler.go b/les/handler.go index b426f7fdd1..6c57247788 100644 --- a/les/handler.go +++ b/les/handler.go @@ -21,7 +21,6 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "net" "sync" @@ -30,6 +29,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/eth/downloader" @@ -529,7 +529,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { break } // Retrieve the requested block body, stopping if enough was found - if data := core.GetBodyRLP(pm.chainDb, hash, core.GetBlockNumber(pm.chainDb, hash)); len(data) != 0 { + if data := rawdb.GetBodyRLP(pm.chainDb, hash, rawdb.GetBlockNumber(pm.chainDb, hash)); len(data) != 0 { bodies = append(bodies, data) bytes += len(data) } @@ -580,7 +580,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } for _, req := range req.Reqs { // Retrieve the requested state entry, stopping if enough was found - if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { + if header := rawdb.GetHeader(pm.chainDb, req.BHash, rawdb.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { statedb, err := pm.blockchain.State() if err != nil { continue @@ -646,7 +646,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { break } // Retrieve the requested block's receipts, skipping if unknown to us - results := core.GetBlockReceipts(pm.chainDb, hash, core.GetBlockNumber(pm.chainDb, hash)) + results := rawdb.GetBlockReceipts(pm.chainDb, hash, rawdb.GetBlockNumber(pm.chainDb, hash)) if results == nil { if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { continue @@ -706,7 +706,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } for _, req := range req.Reqs { // Retrieve the requested state entry, stopping if enough was found - if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { + if header := rawdb.GetHeader(pm.chainDb, req.BHash, rawdb.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { statedb, err := pm.blockchain.State() if err != nil { continue @@ -764,7 +764,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { if statedb == nil || req.BHash != lastBHash { statedb, root, lastBHash = nil, common.Hash{}, req.BHash - if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { + if header := rawdb.GetHeader(pm.chainDb, req.BHash, rawdb.GetBlockNumber(pm.chainDb, 
req.BHash)); header != nil { statedb, _ = pm.blockchain.State() root = header.Root } @@ -860,7 +860,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { trieDb := trie.NewDatabase(rawdb.NewTable(pm.chainDb, light.ChtTablePrefix)) for _, req := range req.Reqs { if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil { - sectionHead := core.GetCanonicalHash(pm.chainDb, req.ChtNum*light.CHTFrequencyServer-1) + sectionHead := rawdb.GetCanonicalHash(pm.chainDb, req.ChtNum*light.CHTFrequencyServer-1) if root := light.GetChtRoot(pm.chainDb, req.ChtNum-1, sectionHead); root != (common.Hash{}) { trie, err := trie.New(root, trieDb) if err != nil { @@ -1115,10 +1115,10 @@ func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common. func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) { switch id { case htCanonical: - sectionHead := core.GetCanonicalHash(pm.chainDb, (idx+1)*light.CHTFrequencyClient-1) + sectionHead := rawdb.GetCanonicalHash(pm.chainDb, (idx+1)*light.CHTFrequencyClient-1) return light.GetChtV2Root(pm.chainDb, idx, sectionHead), light.ChtTablePrefix case htBloomBits: - sectionHead := core.GetCanonicalHash(pm.chainDb, (idx+1)*light.BloomTrieFrequency-1) + sectionHead := rawdb.GetCanonicalHash(pm.chainDb, (idx+1)*light.BloomTrieFrequency-1) return light.GetBloomTrieRoot(pm.chainDb, idx, sectionHead), light.BloomTrieTablePrefix } return common.Hash{}, "" @@ -1129,8 +1129,8 @@ func (pm *ProtocolManager) getHelperTrieAuxData(req HelperTrieReq) []byte { switch { case req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8: blockNum := binary.BigEndian.Uint64(req.Key) - hash := core.GetCanonicalHash(pm.chainDb, blockNum) - return core.GetHeaderRLP(pm.chainDb, hash, blockNum) + hash := rawdb.GetCanonicalHash(pm.chainDb, blockNum) + return rawdb.GetHeaderRLP(pm.chainDb, hash, blockNum) } return nil } @@ -1143,9 +1143,9 @@ func (pm *ProtocolManager) txStatus(hashes []common.Hash) []txStatus { // If the transaction is unknown to the pool, try looking it up locally if stat == core.TxStatusUnknown { - if block, number, index := core.GetTxLookupEntry(pm.chainDb, hashes[i]); block != (common.Hash{}) { + if block, number, index := rawdb.GetTxLookupEntry(pm.chainDb, hashes[i]); block != (common.Hash{}) { stats[i].Status = core.TxStatusIncluded - stats[i].Lookup = &core.TxLookupEntry{BlockHash: block, BlockIndex: number, Index: index} + stats[i].Lookup = &rawdb.TxLookupEntry{BlockHash: block, BlockIndex: number, Index: index} } } } diff --git a/les/handler_test.go b/les/handler_test.go index 225900dd52..e3c526cb13 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -18,7 +18,6 @@ package les import ( "encoding/binary" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "math/rand" "testing" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/eth/downloader" @@ -304,7 +304,7 @@ func testGetReceipt(t *testing.T, protocol int) { block := bc.GetBlockByNumber(i) hashes = append(hashes, block.Hash()) - receipts = append(receipts, core.GetBlockReceipts(db, block.Hash(), block.NumberU64())) + receipts = append(receipts, rawdb.GetBlockReceipts(db, block.Hash(), block.NumberU64())) } // Send the hash request and verify the response cost := 
peer.GetRequestCost(GetReceiptsMsg, len(hashes)) @@ -555,9 +555,9 @@ func TestTransactionStatusLes2(t *testing.T) { } // check if their status is included now - block1hash := core.GetCanonicalHash(db, 1) - test(tx1, false, txStatus{Status: core.TxStatusIncluded, Lookup: &core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}}) - test(tx2, false, txStatus{Status: core.TxStatusIncluded, Lookup: &core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}}) + block1hash := rawdb.GetCanonicalHash(db, 1) + test(tx1, false, txStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}}) + test(tx2, false, txStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}}) // create a reorg that rolls them back gchain, _ = core.GenerateChain(params.TestChainConfig, chain.GetBlockByNumber(0), ethash.NewFaker(), db, 2, func(i int, block *core.BlockGen) {}) diff --git a/les/odr_requests.go b/les/odr_requests.go index e6e68e7621..cca89b1e7e 100644 --- a/les/odr_requests.go +++ b/les/odr_requests.go @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -// Package light implements on-demand retrieval capable state and chain objects +// Package les implements on-demand retrieval capable state and chain objects // for the Ethereum Light Client. package les @@ -24,7 +24,7 @@ import ( "fmt" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/ethdb" @@ -110,7 +110,7 @@ func (r *BlockRequest) Validate(db ethdb.Database, msg *Msg) error { body := bodies[0] // Retrieve our stored header and validate block content against it - header := core.GetHeader(db, r.Hash, r.Number) + header := rawdb.GetHeader(db, r.Hash, r.Number) if header == nil { return errHeaderUnavailable } @@ -166,7 +166,7 @@ func (r *ReceiptsRequest) Validate(db ethdb.Database, msg *Msg) error { receipt := receipts[0] // Retrieve our stored header and validate receipt content against it - header := core.GetHeader(db, r.Hash, r.Number) + header := rawdb.GetHeader(db, r.Hash, r.Number) if header == nil { return errHeaderUnavailable } diff --git a/les/odr_test.go b/les/odr_test.go index 3858e34028..29a7163afd 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -19,7 +19,6 @@ package les import ( "bytes" "context" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" "time" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -64,9 +64,9 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var receipts types.Receipts if bc != nil { - receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash)) + receipts = rawdb.GetBlockReceipts(db, bhash, rawdb.GetBlockNumber(db, bhash)) } else { - receipts, _ = light.GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, 
bhash)) + receipts, _ = light.GetBlockReceipts(ctx, lc.Odr(), bhash, rawdb.GetBlockNumber(db, bhash)) } if receipts == nil { return nil @@ -190,7 +190,7 @@ func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { test := func(expFail uint64) { for i := uint64(0); i <= pm.blockchain.CurrentHeader().Number.Uint64(); i++ { - bhash := core.GetCanonicalHash(db, i) + bhash := rawdb.GetCanonicalHash(db, i) b1 := fn(light.NoOdr, db, pm.chainConfig, pm.blockchain.(*core.BlockChain), nil, bhash) ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) diff --git a/les/protocol.go b/les/protocol.go index 9ca62e73e3..26e4573369 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -28,6 +28,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/crypto/secp256k1" "github.com/tomochain/tomochain/rlp" @@ -224,6 +225,6 @@ type proofsData [][]rlp.RawValue type txStatus struct { Status core.TxStatus - Lookup *core.TxLookupEntry `rlp:"nil"` + Lookup *rawdb.TxLookupEntry `rlp:"nil"` Error string } diff --git a/les/request_test.go b/les/request_test.go index 183128d839..2313e738a5 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -18,12 +18,11 @@ package les import ( "context" - "github.com/tomochain/tomochain/core/rawdb" "testing" "time" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/eth" "github.com/tomochain/tomochain/ethdb" @@ -59,7 +58,7 @@ func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light //func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) } func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { - return &light.TrieRequest{Id: light.StateTrieID(core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash))), Key: testBankSecureTrieKey} + return &light.TrieRequest{Id: light.StateTrieID(rawdb.GetHeader(db, bhash, rawdb.GetBlockNumber(db, bhash))), Key: testBankSecureTrieKey} } //func TestCodeAccessLes1(t *testing.T) { testAccess(t, 1, tfCodeAccess) } @@ -67,7 +66,7 @@ func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) ligh //func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) } func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { - header := core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash)) + header := rawdb.GetHeader(db, bhash, rawdb.GetBlockNumber(db, bhash)) if header.Number.Uint64() < testContractDeployed { return nil } @@ -100,7 +99,7 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) { test := func(expFail uint64) { for i := uint64(0); i <= pm.blockchain.CurrentHeader().Number.Uint64(); i++ { - bhash := core.GetCanonicalHash(db, i) + bhash := rawdb.GetCanonicalHash(db, i) if req := fn(ldb, bhash, i); req != nil { ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() diff --git a/les/server.go b/les/server.go index b56d2cad4b..4705f599da 100644 --- a/les/server.go +++ b/les/server.go @@ -25,6 +25,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/eth" 
"github.com/tomochain/tomochain/ethdb" @@ -329,11 +330,11 @@ func (pm *ProtocolManager) blockLoop() { header := ev.Block.Header() hash := header.Hash() number := header.Number.Uint64() - td := core.GetTd(pm.chainDb, hash, number) + td := rawdb.GetTd(pm.chainDb, hash, number) if td != nil && td.Cmp(lastBroadcastTd) > 0 { var reorg uint64 if lastHead != nil { - reorg = lastHead.Number.Uint64() - core.FindCommonAncestor(pm.chainDb, header, lastHead).Number.Uint64() + reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(pm.chainDb, header, lastHead).Number.Uint64() } lastHead = header lastBroadcastTd = td diff --git a/les/sync.go b/les/sync.go index 8e3cd47ca3..993e96a581 100644 --- a/les/sync.go +++ b/les/sync.go @@ -20,7 +20,7 @@ import ( "context" "time" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/eth/downloader" "github.com/tomochain/tomochain/light" ) @@ -61,7 +61,7 @@ func (pm *ProtocolManager) syncer() { func (pm *ProtocolManager) needToSync(peerHead blockInfo) bool { head := pm.blockchain.CurrentHeader() - currentTd := core.GetTd(pm.chainDb, head.Hash(), head.Number.Uint64()) + currentTd := rawdb.GetTd(pm.chainDb, head.Hash(), head.Number.Uint64()) return currentTd != nil && peerHead.Td.Cmp(currentTd) > 0 } diff --git a/light/lightchain.go b/light/lightchain.go index 6c91389777..42717f1ced 100644 --- a/light/lightchain.go +++ b/light/lightchain.go @@ -24,10 +24,12 @@ import ( "sync/atomic" "time" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" @@ -142,7 +144,7 @@ func (self *LightChain) Odr() OdrBackend { // loadLastState loads the last known chain state from the database. This method // assumes that the chain manager mutex is held. 
func (self *LightChain) loadLastState() error { - if head := core.GetHeadHeaderHash(self.chainDb); head == (common.Hash{}) { + if head := rawdb.GetHeadHeaderHash(self.chainDb); head == (common.Hash{}) { // Corrupt or empty database, init from scratch self.Reset() } else { @@ -189,10 +191,10 @@ func (bc *LightChain) ResetWithGenesisBlock(genesis *types.Block) { defer bc.mu.Unlock() // Prepare the genesis block and reinitialise the chain - if err := core.WriteTd(bc.chainDb, genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil { + if err := rawdb.WriteTd(bc.chainDb, genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil { log.Crit("Failed to write genesis block TD", "err", err) } - if err := core.WriteBlock(bc.chainDb, genesis); err != nil { + if err := rawdb.WriteBlock(bc.chainDb, genesis); err != nil { log.Crit("Failed to write genesis block", "err", err) } bc.genesisBlock = genesis diff --git a/light/lightchain_test.go b/light/lightchain_test.go index 21836cc88c..073efecb00 100644 --- a/light/lightchain_test.go +++ b/light/lightchain_test.go @@ -18,13 +18,13 @@ package light import ( "context" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/params" @@ -123,8 +123,8 @@ func testHeaderChainImport(chain []*types.Header, lightchain *LightChain) error } // Manually insert the header into the database, but don't reorganize (allows subsequent testing) lightchain.mu.Lock() - core.WriteTd(lightchain.chainDb, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, lightchain.GetTdByHash(header.ParentHash))) - core.WriteHeader(lightchain.chainDb, header) + rawdb.WriteTd(lightchain.chainDb, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, lightchain.GetTdByHash(header.ParentHash))) + rawdb.WriteHeader(lightchain.chainDb, header) lightchain.mu.Unlock() } return nil diff --git a/light/odr.go b/light/odr.go index b5591fdd93..9fe919cb39 100644 --- a/light/odr.go +++ b/light/odr.go @@ -24,6 +24,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" ) @@ -112,7 +113,7 @@ type BlockRequest struct { // StoreResult stores the retrieved data in local database func (req *BlockRequest) StoreResult(db ethdb.Database) { - core.WriteBodyRLP(db, req.Hash, req.Number, req.Rlp) + rawdb.WriteBodyRLP(db, req.Hash, req.Number, req.Rlp) } // ReceiptsRequest is the ODR request type for retrieving block bodies @@ -125,7 +126,7 @@ type ReceiptsRequest struct { // StoreResult stores the retrieved data in local database func (req *ReceiptsRequest) StoreResult(db ethdb.Database) { - core.WriteBlockReceipts(db, req.Hash, req.Number, req.Receipts) + rawdb.WriteBlockReceipts(db, req.Hash, req.Number, req.Receipts) } // ChtRequest is the ODR request type for state/storage trie entries @@ -141,10 +142,10 @@ type ChtRequest struct { // StoreResult stores the retrieved data in local database func (req *ChtRequest) StoreResult(db ethdb.Database) { // if there is a canonical hash, there is a header too - core.WriteHeader(db, req.Header) + rawdb.WriteHeader(db, req.Header) hash, num := 
req.Header.Hash(), req.Header.Number.Uint64() - core.WriteTd(db, hash, num, req.Td) - core.WriteCanonicalHash(db, hash, num) + rawdb.WriteTd(db, hash, num, req.Td) + rawdb.WriteCanonicalHash(db, hash, num) } // BloomRequest is the ODR request type for retrieving bloom filters from a CHT structure @@ -161,11 +162,11 @@ type BloomRequest struct { // StoreResult stores the retrieved data in local database func (req *BloomRequest) StoreResult(db ethdb.Database) { for i, sectionIdx := range req.SectionIdxList { - sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1) + sectionHead := rawdb.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1) // if we don't have the canonical hash stored for this section head number, we'll still store it under // a key with a zero sectionHead. GetBloomBits will look there too if we still don't have the canonical // hash. In the unlikely case we've retrieved the section head hash since then, we'll just retrieve the // bit vector again from the network. - core.WriteBloomBits(db, req.BitIdx, sectionIdx, sectionHead, req.BloomBits[i]) + rawdb.WriteBloomBits(db, req.BitIdx, sectionIdx, sectionHead, req.BloomBits[i]) } } diff --git a/light/odr_test.go b/light/odr_test.go index 0c5fc78573..f1dbc13407 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -20,16 +20,16 @@ import ( "bytes" "context" "errors" - "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -72,9 +72,9 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error { } switch req := req.(type) { case *BlockRequest: - req.Rlp = core.GetBodyRLP(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash)) + req.Rlp = rawdb.GetBodyRLP(odr.sdb, req.Hash, rawdb.GetBlockNumber(odr.sdb, req.Hash)) case *ReceiptsRequest: - req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash)) + req.Receipts = rawdb.GetBlockReceipts(odr.sdb, req.Hash, rawdb.GetBlockNumber(odr.sdb, req.Hash)) case *TrieRequest: t, _ := trie.New(req.Id.Root, trie.NewDatabase(odr.sdb)) nodes := NewNodeSet() @@ -110,9 +110,9 @@ func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, odrGetReceipts) } func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { var receipts types.Receipts if bc != nil { - receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash)) + receipts = rawdb.GetBlockReceipts(db, bhash, rawdb.GetBlockNumber(db, bhash)) } else { - receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash)) + receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, rawdb.GetBlockNumber(db, bhash)) } if receipts == nil { return nil, nil @@ -268,7 +268,7 @@ func testChainOdr(t *testing.T, protocol int, fn odrTestFn) { test := func(expFail int) { for i := uint64(0); i <= blockchain.CurrentHeader().Number.Uint64(); i++ { - bhash := core.GetCanonicalHash(sdb, i) + bhash := rawdb.GetCanonicalHash(sdb, i) b1, err := fn(NoOdr, sdb, blockchain, nil, bhash) if err != nil { t.Fatalf("error in full-node test 
for block %d: %v", i, err) diff --git a/light/odr_util.go b/light/odr_util.go index 89a63eb2b9..371375d81d 100644 --- a/light/odr_util.go +++ b/light/odr_util.go @@ -22,6 +22,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/rlp" @@ -31,10 +32,10 @@ var sha3_nil = crypto.Keccak256Hash(nil) func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*types.Header, error) { db := odr.Database() - hash := core.GetCanonicalHash(db, number) + hash := rawdb.GetCanonicalHash(db, number) if (hash != common.Hash{}) { // if there is a canonical hash, there is a header too - header := core.GetHeader(db, hash, number) + header := rawdb.GetHeader(db, hash, number) if header == nil { panic("Canonical hash present but header not found") } @@ -47,14 +48,14 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ ) if odr.ChtIndexer() != nil { chtCount, sectionHeadNum, sectionHead = odr.ChtIndexer().Sections() - canonicalHash := core.GetCanonicalHash(db, sectionHeadNum) + canonicalHash := rawdb.GetCanonicalHash(db, sectionHeadNum) // if the CHT was injected as a trusted checkpoint, we have no canonical hash yet so we accept zero hash too for chtCount > 0 && canonicalHash != sectionHead && canonicalHash != (common.Hash{}) { chtCount-- if chtCount > 0 { sectionHeadNum = chtCount*CHTFrequencyClient - 1 sectionHead = odr.ChtIndexer().SectionHead(chtCount - 1) - canonicalHash = core.GetCanonicalHash(db, sectionHeadNum) + canonicalHash = rawdb.GetCanonicalHash(db, sectionHeadNum) } } } @@ -69,7 +70,7 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ } func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (common.Hash, error) { - hash := core.GetCanonicalHash(odr.Database(), number) + hash := rawdb.GetCanonicalHash(odr.Database(), number) if (hash != common.Hash{}) { return hash, nil } @@ -82,7 +83,7 @@ func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (commo // GetBodyRLP retrieves the block body (transactions and uncles) in RLP encoding. func GetBodyRLP(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (rlp.RawValue, error) { - if data := core.GetBodyRLP(odr.Database(), hash, number); data != nil { + if data := rawdb.GetBodyRLP(odr.Database(), hash, number); data != nil { return data, nil } r := &BlockRequest{Hash: hash, Number: number} @@ -111,7 +112,7 @@ func GetBody(ctx context.Context, odr OdrBackend, hash common.Hash, number uint6 // back from the stored header and body. func GetBlock(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (*types.Block, error) { // Retrieve the block header and body contents - header := core.GetHeader(odr.Database(), hash, number) + header := rawdb.GetHeader(odr.Database(), hash, number) if header == nil { return nil, ErrNoHeader } @@ -127,7 +128,7 @@ func GetBlock(ctx context.Context, odr OdrBackend, hash common.Hash, number uint // in a block given by its hash. 
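Every accessor in odr_util.go follows the same disk-first shape seen in these hunks: read through rawdb, and only fire an ODR request when the local lookup misses. A minimal sketch of that pattern, assuming the light package's OdrBackend and BlockRequest types as used above; the helper name fetchBodyRLP is hypothetical:

package light

import (
	"context"

	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/rlp"
)

// fetchBodyRLP mirrors GetBodyRLP: local database first, network second.
func fetchBodyRLP(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (rlp.RawValue, error) {
	// Fast path: the body may already be in the local chain database.
	if data := rawdb.GetBodyRLP(odr.Database(), hash, number); data != nil {
		return data, nil
	}
	// Slow path: ask the ODR backend to retrieve it from the network.
	r := &BlockRequest{Hash: hash, Number: number}
	if err := odr.Retrieve(ctx, r); err != nil {
		return nil, err
	}
	return r.Rlp, nil
}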
func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (types.Receipts, error) { // Retrieve the potentially incomplete receipts from disk or network - receipts := core.GetBlockReceipts(odr.Database(), hash, number) + receipts := rawdb.GetBlockReceipts(odr.Database(), hash, number) if receipts == nil { r := &ReceiptsRequest{Hash: hash, Number: number} if err := odr.Retrieve(ctx, r); err != nil { @@ -141,13 +142,13 @@ func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, num if err != nil { return nil, err } - genesis := core.GetCanonicalHash(odr.Database(), 0) - config, _ := core.GetChainConfig(odr.Database(), genesis) + genesis := rawdb.GetCanonicalHash(odr.Database(), 0) + config, _ := rawdb.GetChainConfig(odr.Database(), genesis) if err := core.SetReceiptsData(config, block, receipts); err != nil { return nil, err } - core.WriteBlockReceipts(odr.Database(), hash, number, receipts) + rawdb.WriteBlockReceipts(odr.Database(), hash, number, receipts) } return receipts, nil } @@ -156,7 +157,7 @@ func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, num // block given by its hash. func GetBlockLogs(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) ([][]*types.Log, error) { // Retrieve the potentially incomplete receipts from disk or network - receipts := core.GetBlockReceipts(odr.Database(), hash, number) + receipts := rawdb.GetBlockReceipts(odr.Database(), hash, number) if receipts == nil { r := &ReceiptsRequest{Hash: hash, Number: number} if err := odr.Retrieve(ctx, r); err != nil { @@ -187,24 +188,24 @@ func GetBloomBits(ctx context.Context, odr OdrBackend, bitIdx uint, sectionIdxLi ) if odr.BloomTrieIndexer() != nil { bloomTrieCount, sectionHeadNum, sectionHead = odr.BloomTrieIndexer().Sections() - canonicalHash := core.GetCanonicalHash(db, sectionHeadNum) + canonicalHash := rawdb.GetCanonicalHash(db, sectionHeadNum) // if the BloomTrie was injected as a trusted checkpoint, we have no canonical hash yet so we accept zero hash too for bloomTrieCount > 0 && canonicalHash != sectionHead && canonicalHash != (common.Hash{}) { bloomTrieCount-- if bloomTrieCount > 0 { sectionHeadNum = bloomTrieCount*BloomTrieFrequency - 1 sectionHead = odr.BloomTrieIndexer().SectionHead(bloomTrieCount - 1) - canonicalHash = core.GetCanonicalHash(db, sectionHeadNum) + canonicalHash = rawdb.GetCanonicalHash(db, sectionHeadNum) } } } for i, sectionIdx := range sectionIdxList { - sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1) + sectionHead := rawdb.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1) // if we don't have the canonical hash stored for this section head number, we'll still look for // an entry with a zero sectionHead (we store it with zero section head too if we don't know it // at the time of the retrieval) - bloomBits, err := core.GetBloomBits(db, bitIdx, sectionIdx, sectionHead) + bloomBits, err := rawdb.GetBloomBits(db, bitIdx, sectionIdx, sectionHead) if err == nil { result[i] = bloomBits } else { diff --git a/light/postprocess.go b/light/postprocess.go index 1e83a3cd7a..22526d943f 100644 --- a/light/postprocess.go +++ b/light/postprocess.go @@ -19,13 +19,13 @@ package light import ( "encoding/binary" "errors" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/bitutil" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" 
"github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" @@ -162,7 +162,7 @@ func (c *ChtIndexerBackend) Process(header *types.Header) { hash, num := header.Hash(), header.Number.Uint64() c.lastHash = hash - td := core.GetTd(c.diskdb, hash, num) + td := rawdb.GetTd(c.diskdb, hash, num) if td == nil { panic(nil) } @@ -273,7 +273,7 @@ func (b *BloomTrieIndexerBackend) Commit() error { binary.BigEndian.PutUint64(encKey[2:10], b.section) var decomp []byte for j := uint64(0); j < b.bloomTrieRatio; j++ { - data, err := core.GetBloomBits(b.diskdb, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j]) + data, err := rawdb.GetBloomBits(b.diskdb, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j]) if err != nil { return err } diff --git a/light/txpool.go b/light/txpool.go index 7af86dbd6b..11f7e019d9 100644 --- a/light/txpool.go +++ b/light/txpool.go @@ -24,6 +24,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" @@ -74,10 +75,13 @@ type TxPool struct { // // Send instructs backend to forward new transactions // NewHead notifies backend about a new head after processed by the tx pool, -// including mined and rolled back transactions since the last event +// +// including mined and rolled back transactions since the last event +// // Discard notifies backend about transactions that should be discarded either -// because they have been replaced by a re-send or because they have been mined -// long ago and no rollback is expected +// +// because they have been replaced by a re-send or because they have been mined +// long ago and no rollback is expected type TxRelayBackend interface { Send(txs types.Transactions) NewHead(head common.Hash, mined []common.Hash, rollback []common.Hash) @@ -183,7 +187,7 @@ func (pool *TxPool) checkMinedTxs(ctx context.Context, hash common.Hash, number if _, err := GetBlockReceipts(ctx, pool.odr, hash, number); err != nil { // ODR caches, ignore results return err } - if err := core.WriteTxLookupEntries(pool.chainDb, block); err != nil { + if err := rawdb.WriteTxLookupEntries(pool.chainDb, block); err != nil { return err } // Update the transaction pool's state @@ -202,7 +206,7 @@ func (pool *TxPool) rollbackTxs(hash common.Hash, txc txStateChanges) { if list, ok := pool.mined[hash]; ok { for _, tx := range list { txHash := tx.Hash() - core.DeleteTxLookupEntry(pool.chainDb, txHash) + rawdb.DeleteTxLookupEntry(pool.chainDb, txHash) pool.pending[txHash] = tx txc.setState(txHash, false) } @@ -258,7 +262,7 @@ func (pool *TxPool) reorgOnNewHead(ctx context.Context, newHeader *types.Header) idx2 := idx - txPermanent if len(pool.mined) > 0 { for i := pool.clearIdx; i < idx2; i++ { - hash := core.GetCanonicalHash(pool.chainDb, i) + hash := rawdb.GetCanonicalHash(pool.chainDb, i) if list, ok := pool.mined[hash]; ok { hashes := make([]common.Hash, len(list)) for i, tx := range list { From 5ce901b99b793b89951e2b24e317e65424644a8b Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 20 Jul 2023 16:52:37 +0700 Subject: [PATCH 037/119] Move db keys to schema.go --- core/blockchain.go | 2 +- core/headerchain.go | 2 +- core/rawdb/database_util.go | 82 ++++++------------------------- core/rawdb/schema.go | 98 +++++++++++++++++++++++++++++++++++++ 4 files changed, 116 insertions(+), 68 deletions(-) create mode 100644 
core/rawdb/schema.go diff --git a/core/blockchain.go b/core/blockchain.go index eb8a816f55..1d8fe69fa9 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -731,7 +731,7 @@ func (bc *BlockChain) HasBlock(hash common.Hash, number uint64) bool { if bc.blockCache.Contains(hash) { return true } - ok, _ := bc.db.Has(rawdb.BlockBodyKey(hash, number)) + ok, _ := bc.db.Has(rawdb.BlockBodyKey(number, hash)) return ok } diff --git a/core/headerchain.go b/core/headerchain.go index 3519e2cb92..4fe236824d 100644 --- a/core/headerchain.go +++ b/core/headerchain.go @@ -371,7 +371,7 @@ func (hc *HeaderChain) HasHeader(hash common.Hash, number uint64) bool { if hc.numberCache.Contains(hash) || hc.headerCache.Contains(hash) { return true } - ok, _ := hc.chainDb.Has(rawdb.HeaderKey(hash, number)) + ok, _ := hc.chainDb.Has(rawdb.HeaderKey(number, hash)) return ok } diff --git a/core/rawdb/database_util.go b/core/rawdb/database_util.go index 3b3eb01727..f32c3af465 100644 --- a/core/rawdb/database_util.go +++ b/core/rawdb/database_util.go @@ -20,7 +20,6 @@ import ( "bytes" "encoding/binary" "encoding/json" - "errors" "fmt" "math/big" @@ -28,7 +27,6 @@ import ( "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" ) @@ -43,38 +41,6 @@ type DatabaseDeleter interface { Delete(key []byte) error } -var ( - headHeaderKey = []byte("LastHeader") - headBlockKey = []byte("LastBlock") - headFastKey = []byte("LastFast") - trieSyncKey = []byte("TrieSync") - - // Data item prefixes (use single byte to avoid mixing data types, avoid `i`). - headerPrefix = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header - tdSuffix = []byte("t") // headerPrefix + num (uint64 big endian) + hash + tdSuffix -> td - numSuffix = []byte("n") // headerPrefix + num (uint64 big endian) + numSuffix -> hash - blockHashPrefix = []byte("H") // blockHashPrefix + hash -> num (uint64 big endian) - bodyPrefix = []byte("b") // bodyPrefix + num (uint64 big endian) + hash -> block body - blockReceiptsPrefix = []byte("r") // blockReceiptsPrefix + num (uint64 big endian) + hash -> block receipts - lookupPrefix = []byte("l") // lookupPrefix + hash -> transaction/receipt lookup metadata - bloomBitsPrefix = []byte("B") // bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash -> bloom bits - - preimagePrefix = "secure-key-" // preimagePrefix + hash -> preimage - configPrefix = []byte("ethereum-config-") // config prefix for the db - - // Chain index prefixes (use `i` + single byte to avoid mixing data types). - BloomBitsIndexPrefix = []byte("iB") // BloomBitsIndexPrefix is the data table of a chain indexer to track its progress - - // used by old db, now only used for conversion - oldReceiptsPrefix = []byte("receipts-") - oldTxMetaSuffix = []byte{0x01} - - ErrChainConfigNotFound = errors.New("ChainConfig not found") // general config not found error - - preimageCounter = metrics.NewRegisteredCounter("db/preimage/total", nil) - preimageHitCounter = metrics.NewRegisteredCounter("db/preimage/hits", nil) -) - // TxLookupEntry is a positional metadata to help looking up the data content of // a transaction or receipt given only its hash. 
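Note the argument order being fixed in the HasBlock and HasHeader hunks above: the schema helpers take the block number before the hash, because the on-disk key is prefix + num (uint64 big endian) + hash, which makes entries for consecutive blocks sort together. A self-contained sketch of that layout, assuming only the single-byte "h" header prefix from schema.go:

package main

import (
	"encoding/binary"
	"fmt"
)

// headerKey reproduces the headerPrefix + num + hash layout from schema.go.
func headerKey(number uint64, hash [32]byte) []byte {
	enc := make([]byte, 8)
	binary.BigEndian.PutUint64(enc, number) // big endian: numeric order == byte order
	return append(append([]byte("h"), enc...), hash[:]...)
}

func main() {
	var hash [32]byte // zero hash, for illustration only
	fmt.Printf("%x\n", headerKey(4000000, hash)) // 68 00000000003d0900 00...00
}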
type TxLookupEntry struct { @@ -160,7 +126,7 @@ func GetTrieSyncProgress(db DatabaseReader) uint64 { // GetHeaderRLP retrieves a block header in its raw RLP database encoding, or nil // if the header's not found. func GetHeaderRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue { - data, _ := db.Get(HeaderKey(hash, number)) + data, _ := db.Get(HeaderKey(number, hash)) return data } @@ -181,18 +147,10 @@ func GetHeader(db DatabaseReader, hash common.Hash, number uint64) *types.Header // GetBodyRLP retrieves the block body (transactions and uncles) in RLP encoding. func GetBodyRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue { - data, _ := db.Get(BlockBodyKey(hash, number)) + data, _ := db.Get(BlockBodyKey(number, hash)) return data } -func HeaderKey(hash common.Hash, number uint64) []byte { - return append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...) -} - -func BlockBodyKey(hash common.Hash, number uint64) []byte { - return append(append(bodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...) -} - // GetBody retrieves the block body (transactons, uncles) corresponding to the // hash, nil if none found. func GetBody(db DatabaseReader, hash common.Hash, number uint64) *types.Body { @@ -211,7 +169,7 @@ func GetBody(db DatabaseReader, hash common.Hash, number uint64) *types.Body { // GetTd retrieves a block's total difficulty corresponding to the hash, nil if // none found. func GetTd(db DatabaseReader, hash common.Hash, number uint64) *big.Int { - data, _ := db.Get(append(append(append(headerPrefix, encodeBlockNumber(number)...), hash[:]...), tdSuffix...)) + data, _ := db.Get(headerTDKey(number, hash)) if len(data) == 0 { return nil } @@ -246,7 +204,7 @@ func GetBlock(db DatabaseReader, hash common.Hash, number uint64) *types.Block { // GetBlockReceipts retrieves the receipts generated by the transactions included // in a block given by its hash. func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64) types.Receipts { - data, _ := db.Get(append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash[:]...)) + data, _ := db.Get(blockReceiptsKey(number, hash)) if len(data) == 0 { return nil } @@ -266,7 +224,7 @@ func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64) types. // hash to allow retrieving the transaction or receipt by hash. func GetTxLookupEntry(db DatabaseReader, hash common.Hash) (common.Hash, uint64, uint64) { // Load the positional metadata from disk and bail if it fails - data, _ := db.Get(append(lookupPrefix, hash.Bytes()...)) + data, _ := db.Get(txLookupKey(hash)) if len(data) == 0 { return common.Hash{}, 0, 0 } @@ -344,12 +302,7 @@ func GetReceipt(db DatabaseReader, hash common.Hash) (*types.Receipt, common.Has // GetBloomBits retrieves the compressed bloom bit vector belonging to the given // section and bit index from the. func GetBloomBits(db DatabaseReader, bit uint, section uint64, head common.Hash) ([]byte, error) { - key := append(append(bloomBitsPrefix, make([]byte, 10)...), head.Bytes()...) - - binary.BigEndian.PutUint16(key[1:], uint16(bit)) - binary.BigEndian.PutUint64(key[3:], section) - - return db.Get(key) + return db.Get(bloomBitsKey(bit, section, head)) } // WriteCanonicalHash stores the canonical hash for the given block number. @@ -425,8 +378,7 @@ func WriteBody(db ethdb.KeyValueWriter, hash common.Hash, number uint64, body *t // WriteBodyRLP writes a serialized body of a block into the database. 
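bloomBitsKey, referenced in the hunk above, reserves ten zero bytes after the prefix and then overwrites them in place: two big-endian bytes for the bit index, eight for the section number. A standalone sketch of the same construction:

package main

import (
	"encoding/binary"
	"fmt"
)

// bloomBitsKey = "B" + bit (uint16 big endian) + section (uint64 big endian) + head.
func bloomBitsKey(bit uint, section uint64, head [32]byte) []byte {
	key := append(append([]byte("B"), make([]byte, 10)...), head[:]...)
	binary.BigEndian.PutUint16(key[1:], uint16(bit)) // bytes 1..2
	binary.BigEndian.PutUint64(key[3:], section)     // bytes 3..10
	return key
}

func main() {
	var head [32]byte
	fmt.Printf("%x\n", bloomBitsKey(5, 7, head)) // 42 0005 0000000000000007 00...00
}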
func WriteBodyRLP(db ethdb.KeyValueWriter, hash common.Hash, number uint64, rlp rlp.RawValue) error { - key := append(append(bodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...) - if err := db.Put(key, rlp); err != nil { + if err := db.Put(BlockBodyKey(number, hash), rlp); err != nil { log.Crit("Failed to store block body", "err", err) } return nil @@ -438,8 +390,7 @@ func WriteTd(db ethdb.KeyValueWriter, hash common.Hash, number uint64, td *big.I if err != nil { return err } - key := append(append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...), tdSuffix...) - if err := db.Put(key, data); err != nil { + if err := db.Put(headerTDKey(number, hash), data); err != nil { log.Crit("Failed to store block total difficulty", "err", err) } return nil @@ -472,8 +423,7 @@ func WriteBlockReceipts(db ethdb.KeyValueWriter, hash common.Hash, number uint64 return err } // Store the flattened receipt slice - key := append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash.Bytes()...) - if err := db.Put(key, bytes); err != nil { + if err := db.Put(blockReceiptsKey(number, hash), bytes); err != nil { log.Crit("Failed to store block receipts", "err", err) } return nil @@ -493,7 +443,7 @@ func WriteTxLookupEntries(db ethdb.KeyValueWriter, block *types.Block) error { if err != nil { return err } - if err := db.Put(append(lookupPrefix, tx.Hash().Bytes()...), data); err != nil { + if err := db.Put(txLookupKey(tx.Hash()), data); err != nil { return err } } @@ -526,12 +476,12 @@ func DeleteHeader(db DatabaseDeleter, hash common.Hash, number uint64) { // DeleteBody removes all block body data associated with a hash. func DeleteBody(db DatabaseDeleter, hash common.Hash, number uint64) { - db.Delete(append(append(bodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...)) + db.Delete(BlockBodyKey(number, hash)) } // DeleteTd removes all block total difficulty data associated with a hash. func DeleteTd(db DatabaseDeleter, hash common.Hash, number uint64) { - db.Delete(append(append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...), tdSuffix...)) + db.Delete(headerTDKey(number, hash)) } // DeleteBlock removes all block data associated with a hash. @@ -544,12 +494,12 @@ func DeleteBlock(db DatabaseDeleter, hash common.Hash, number uint64) { // DeleteBlockReceipts removes all receipt data associated with a block hash. func DeleteBlockReceipts(db DatabaseDeleter, hash common.Hash, number uint64) { - db.Delete(append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash.Bytes()...)) + db.Delete(blockReceiptsKey(number, hash)) } // DeleteTxLookupEntry removes all transaction data associated with a hash. func DeleteTxLookupEntry(db DatabaseDeleter, hash common.Hash) { - db.Delete(append(lookupPrefix, hash.Bytes()...)) + db.Delete(txLookupKey(hash)) } // PreimageTable returns a Database instance with the key prefix for preimage entries. @@ -606,12 +556,12 @@ func WriteChainConfig(db ethdb.KeyValueWriter, hash common.Hash, cfg *params.Cha return err } - return db.Put(append(configPrefix, hash[:]...), jsonChainConfig) + return db.Put(configKey(hash), jsonChainConfig) } // GetChainConfig will fetch the network settings based on the given hash. 
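GetChainConfig is keyed by the genesis hash, so loading the active config is a two-step read: resolve the canonical hash of block zero, then fetch the JSON blob stored under configKey. A hedged sketch of that lookup; loadChainConfig is a hypothetical wrapper around the accessors shown in this patch:

package example

import (
	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/params"
)

// loadChainConfig resolves the genesis hash and reads the config stored for it.
// rawdb.GetChainConfig returns ErrChainConfigNotFound if none was ever written.
func loadChainConfig(db rawdb.DatabaseReader) (*params.ChainConfig, error) {
	genesis := rawdb.GetCanonicalHash(db, 0)
	return rawdb.GetChainConfig(db, genesis)
}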
func GetChainConfig(db DatabaseReader, hash common.Hash) (*params.ChainConfig, error) { - jsonChainConfig, _ := db.Get(append(configPrefix, hash[:]...)) + jsonChainConfig, _ := db.Get(configKey(hash)) if len(jsonChainConfig) == 0 { return nil, ErrChainConfigNotFound } diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go new file mode 100644 index 0000000000..491c3efa1c --- /dev/null +++ b/core/rawdb/schema.go @@ -0,0 +1,98 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package rawdb contains a collection of low level database accessors. +package rawdb + +import ( + "encoding/binary" + "errors" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/metrics" +) + +var ( + headHeaderKey = []byte("LastHeader") + headBlockKey = []byte("LastBlock") + headFastKey = []byte("LastFast") + trieSyncKey = []byte("TrieSync") + + // Data item prefixes (use single byte to avoid mixing data types, avoid `i`). + headerPrefix = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header + tdSuffix = []byte("t") // headerPrefix + num (uint64 big endian) + hash + tdSuffix -> td + numSuffix = []byte("n") // headerPrefix + num (uint64 big endian) + numSuffix -> hash + blockHashPrefix = []byte("H") // blockHashPrefix + hash -> num (uint64 big endian) + bodyPrefix = []byte("b") // bodyPrefix + num (uint64 big endian) + hash -> block body + blockReceiptsPrefix = []byte("r") // blockReceiptsPrefix + num (uint64 big endian) + hash -> block receipts + lookupPrefix = []byte("l") // lookupPrefix + hash -> transaction/receipt lookup metadata + bloomBitsPrefix = []byte("B") // bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash -> bloom bits + + preimagePrefix = "secure-key-" // preimagePrefix + hash -> preimage + configPrefix = []byte("ethereum-config-") // config prefix for the db + + // BloomBitsIndexPrefix is the data table of a chain indexer to track its progress + BloomBitsIndexPrefix = []byte("iB") // BloomBitsIndexPrefix is the data table of a chain indexer to track its progress + + // used by old db, now only used for conversion + oldReceiptsPrefix = []byte("receipts-") + oldTxMetaSuffix = []byte{0x01} + + ErrChainConfigNotFound = errors.New("ChainConfig not found") // general config not found error + + preimageCounter = metrics.NewRegisteredCounter("db/preimage/total", nil) + preimageHitCounter = metrics.NewRegisteredCounter("db/preimage/hits", nil) +) + +// configKey = configPrefix + hash +func configKey(hash common.Hash) []byte { + return append(configPrefix, hash.Bytes()...) +} + +// headerTDKey = headerPrefix + num (uint64 big endian) + hash + tdSuffix +func headerTDKey(number uint64, hash common.Hash) []byte { + return append(HeaderKey(number, hash), tdSuffix...) 
+} + +// HeaderKey = headerPrefix + num (uint64 big endian) + hash +func HeaderKey(number uint64, hash common.Hash) []byte { + return append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...) +} + +// BlockBodyKey = bodyPrefix + num (uint64 big endian) + hash +func BlockBodyKey(number uint64, hash common.Hash) []byte { + return append(append(bodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...) +} + +// blockReceiptsKey = blockReceiptsPrefix + num (uint64 big endian) + hash +func blockReceiptsKey(number uint64, hash common.Hash) []byte { + return append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash.Bytes()...) +} + +// txLookupKey = lookupPrefix + hash +func txLookupKey(hash common.Hash) []byte { + return append(lookupPrefix, hash.Bytes()...) +} + +// bloomBitsKey = bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash +func bloomBitsKey(bit uint, section uint64, hash common.Hash) []byte { + key := append(append(bloomBitsPrefix, make([]byte, 10)...), hash.Bytes()...) + + binary.BigEndian.PutUint16(key[1:], uint16(bit)) + binary.BigEndian.PutUint64(key[3:], section) + + return key +} From 36373e7df7d02b50c534152125060937f02b89ce Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 20 Jul 2023 21:39:23 +0700 Subject: [PATCH 038/119] Refine db keys --- core/rawdb/database_util.go | 44 ++++++++++++++-------------------- core/rawdb/schema.go | 47 ++++++++++++++++++++++++++++--------- 2 files changed, 54 insertions(+), 37 deletions(-) diff --git a/core/rawdb/database_util.go b/core/rawdb/database_util.go index f32c3af465..c27345b4e1 100644 --- a/core/rawdb/database_util.go +++ b/core/rawdb/database_util.go @@ -58,7 +58,7 @@ func encodeBlockNumber(number uint64) []byte { // GetCanonicalHash retrieves a hash assigned to a canonical block number. func GetCanonicalHash(db DatabaseReader, number uint64) common.Hash { - data, _ := db.Get(append(append(headerPrefix, encodeBlockNumber(number)...), numSuffix...)) + data, _ := db.Get(headerHashKey(number)) if len(data) == 0 { return common.Hash{} } @@ -72,7 +72,7 @@ const MissingNumber = uint64(0xffffffffffffffff) // GetBlockNumber returns the block number assigned to a block hash // if the corresponding header is present in the database func GetBlockNumber(db DatabaseReader, hash common.Hash) uint64 { - data, _ := db.Get(append(blockHashPrefix, hash.Bytes()...)) + data, _ := db.Get(headerNumberKey(hash)) if len(data) != 8 { return MissingNumber } @@ -114,7 +114,7 @@ func GetHeadFastBlockHash(db DatabaseReader) common.Hash { } // GetTrieSyncProgress retrieves the number of tries nodes fast synced to allow -// reportinc correct numbers across restarts. +// reporting correct numbers across restarts. func GetTrieSyncProgress(db DatabaseReader) uint64 { data, _ := db.Get(trieSyncKey) if len(data) == 0 { @@ -151,7 +151,7 @@ func GetBodyRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue return data } -// GetBody retrieves the block body (transactons, uncles) corresponding to the +// GetBody retrieves the block body (transactions, uncles) corresponding to the // hash, nil if none found. func GetBody(db DatabaseReader, hash common.Hash, number uint64) *types.Body { data := GetBodyRLP(db, hash, number) @@ -208,7 +208,7 @@ func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64) types. 
if len(data) == 0 { return nil } - storageReceipts := []*types.ReceiptForStorage{} + var storageReceipts []*types.ReceiptForStorage if err := rlp.DecodeBytes(data, &storageReceipts); err != nil { log.Error("Invalid receipt array RLP", "hash", hash, "err", err) return nil @@ -251,7 +251,7 @@ func GetTransaction(db DatabaseReader, hash common.Hash) (*types.Transaction, co } return body.Transactions[txIndex], blockHash, blockNumber, txIndex } - // Old transaction representation, load the transaction and it's metadata separately + // Old transaction representation, load the transaction and its metadata separately data, _ := db.Get(hash.Bytes()) if len(data) == 0 { return nil, common.Hash{}, 0, 0 @@ -261,7 +261,7 @@ func GetTransaction(db DatabaseReader, hash common.Hash) (*types.Transaction, co return nil, common.Hash{}, 0, 0 } // Retrieve the blockchain positional metadata - data, _ = db.Get(append(hash.Bytes(), oldTxMetaSuffix...)) + data, _ = db.Get(oldTxMetaKey(hash)) if len(data) == 0 { return nil, common.Hash{}, 0, 0 } @@ -281,13 +281,13 @@ func GetReceipt(db DatabaseReader, hash common.Hash) (*types.Receipt, common.Has if blockHash != (common.Hash{}) { receipts := GetBlockReceipts(db, blockHash, blockNumber) if len(receipts) <= int(receiptIndex) { - log.Error("Receipt refereced missing", "number", blockNumber, "hash", blockHash, "index", receiptIndex) + log.Error("Receipt referenced missing", "number", blockNumber, "hash", blockHash, "index", receiptIndex) return nil, common.Hash{}, 0, 0 } return receipts[receiptIndex], blockHash, blockNumber, receiptIndex } // Old receipt representation, load the receipt and set an unknown metadata - data, _ := db.Get(append(oldReceiptsPrefix, hash[:]...)) + data, _ := db.Get(oldReceiptsKey(hash)) if len(data) == 0 { return nil, common.Hash{}, 0, 0 } @@ -300,15 +300,14 @@ func GetReceipt(db DatabaseReader, hash common.Hash) (*types.Receipt, common.Has } // GetBloomBits retrieves the compressed bloom bit vector belonging to the given -// section and bit index from the. +// bit index and section indexes. func GetBloomBits(db DatabaseReader, bit uint, section uint64, head common.Hash) ([]byte, error) { return db.Get(bloomBitsKey(bit, section, head)) } // WriteCanonicalHash stores the canonical hash for the given block number. func WriteCanonicalHash(db ethdb.KeyValueWriter, hash common.Hash, number uint64) error { - key := append(append(headerPrefix, encodeBlockNumber(number)...), numSuffix...) - if err := db.Put(key, hash.Bytes()); err != nil { + if err := db.Put(headerHashKey(number), hash.Bytes()); err != nil { log.Crit("Failed to store number to hash mapping", "err", err) } return nil @@ -353,15 +352,13 @@ func WriteHeader(db ethdb.KeyValueWriter, header *types.Header) error { if err != nil { return err } - hash := header.Hash().Bytes() + hash := header.Hash() num := header.Number.Uint64() encNum := encodeBlockNumber(num) - key := append(blockHashPrefix, hash...) - if err := db.Put(key, encNum); err != nil { + if err := db.Put(headerNumberKey(hash), encNum); err != nil { log.Crit("Failed to store hash to number mapping", "err", err) } - key = append(append(headerPrefix, encNum...), hash...) 
- if err := db.Put(key, data); err != nil { + if err := db.Put(headerKey(num, hash), data); err != nil { log.Crit("Failed to store header", "err", err) } return nil @@ -453,25 +450,20 @@ func WriteTxLookupEntries(db ethdb.KeyValueWriter, block *types.Block) error { // WriteBloomBits writes the compressed bloom bits vector belonging to the given // section and bit index. func WriteBloomBits(db ethdb.KeyValueWriter, bit uint, section uint64, head common.Hash, bits []byte) { - key := append(append(bloomBitsPrefix, make([]byte, 10)...), head.Bytes()...) - - binary.BigEndian.PutUint16(key[1:], uint16(bit)) - binary.BigEndian.PutUint64(key[3:], section) - - if err := db.Put(key, bits); err != nil { + if err := db.Put(bloomBitsKey(bit, section, head), bits); err != nil { log.Crit("Failed to store bloom bits", "err", err) } } // DeleteCanonicalHash removes the number to hash canonical mapping. func DeleteCanonicalHash(db DatabaseDeleter, number uint64) { - db.Delete(append(append(headerPrefix, encodeBlockNumber(number)...), numSuffix...)) + db.Delete(headerHashKey(number)) } // DeleteHeader removes all block header data associated with a hash. func DeleteHeader(db DatabaseDeleter, hash common.Hash, number uint64) { - db.Delete(append(blockHashPrefix, hash.Bytes()...)) - db.Delete(append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...)) + db.Delete(headerNumberKey(hash)) + db.Delete(headerHashKey(number)) } // DeleteBody removes all block body data associated with a hash. diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go index 491c3efa1c..50c343cc24 100644 --- a/core/rawdb/schema.go +++ b/core/rawdb/schema.go @@ -33,12 +33,12 @@ var ( // Data item prefixes (use single byte to avoid mixing data types, avoid `i`). headerPrefix = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header - tdSuffix = []byte("t") // headerPrefix + num (uint64 big endian) + hash + tdSuffix -> td - numSuffix = []byte("n") // headerPrefix + num (uint64 big endian) + numSuffix -> hash - blockHashPrefix = []byte("H") // blockHashPrefix + hash -> num (uint64 big endian) - bodyPrefix = []byte("b") // bodyPrefix + num (uint64 big endian) + hash -> block body + headerTDSuffix = []byte("t") // headerPrefix + num (uint64 big endian) + hash + headerTDSuffix -> td + headerHashSuffix = []byte("n") // headerPrefix + num (uint64 big endian) + headerHashSuffix -> hash + headerNumberPrefix = []byte("H") // headerNumberPrefix + hash -> num (uint64 big endian) + blockBodyPrefix = []byte("b") // blockBodyPrefix + num (uint64 big endian) + hash -> block body blockReceiptsPrefix = []byte("r") // blockReceiptsPrefix + num (uint64 big endian) + hash -> block receipts - lookupPrefix = []byte("l") // lookupPrefix + hash -> transaction/receipt lookup metadata + txLookupPrefix = []byte("l") // txLookupPrefix + hash -> transaction/receipt lookup metadata bloomBitsPrefix = []byte("B") // bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash -> bloom bits preimagePrefix = "secure-key-" // preimagePrefix + hash -> preimage @@ -62,9 +62,19 @@ func configKey(hash common.Hash) []byte { return append(configPrefix, hash.Bytes()...) } -// headerTDKey = headerPrefix + num (uint64 big endian) + hash + tdSuffix +// headerKey = headerPrefix + num (uint64 big endian) + hash +func headerKey(number uint64, hash common.Hash) []byte { + return append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...) 
+} + +// headerTDKey = headerPrefix + num (uint64 big endian) + hash + headerTDSuffix func headerTDKey(number uint64, hash common.Hash) []byte { - return append(HeaderKey(number, hash), tdSuffix...) + return append(HeaderKey(number, hash), headerTDSuffix...) +} + +// headerHashKey = headerPrefix + num (uint64 big endian) + headerHashSuffix +func headerHashKey(number uint64) []byte { + return append(append(headerPrefix, encodeBlockNumber(number)...), headerHashSuffix...) } // HeaderKey = headerPrefix + num (uint64 big endian) + hash @@ -72,9 +82,14 @@ func HeaderKey(number uint64, hash common.Hash) []byte { return append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...) } -// BlockBodyKey = bodyPrefix + num (uint64 big endian) + hash +// headerNumberKey = headerNumberPrefix + hash +func headerNumberKey(hash common.Hash) []byte { + return append(headerNumberPrefix, hash.Bytes()...) +} + +// BlockBodyKey = blockBodyPrefix + num (uint64 big endian) + hash func BlockBodyKey(number uint64, hash common.Hash) []byte { - return append(append(bodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...) + return append(append(blockBodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...) } // blockReceiptsKey = blockReceiptsPrefix + num (uint64 big endian) + hash @@ -82,9 +97,9 @@ func blockReceiptsKey(number uint64, hash common.Hash) []byte { return append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash.Bytes()...) } -// txLookupKey = lookupPrefix + hash +// txLookupKey = txLookupPrefix + hash func txLookupKey(hash common.Hash) []byte { - return append(lookupPrefix, hash.Bytes()...) + return append(txLookupPrefix, hash.Bytes()...) } // bloomBitsKey = bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash @@ -96,3 +111,13 @@ func bloomBitsKey(bit uint, section uint64, hash common.Hash) []byte { return key } + +// oldTxMetaKey = hash + oldTxMetaSuffix +func oldTxMetaKey(hash common.Hash) []byte { + return append(hash.Bytes(), oldTxMetaSuffix...) +} + +// oldReceiptsKey = oldReceiptsPrefix + hash +func oldReceiptsKey(hash common.Hash) []byte { + return append(oldReceiptsPrefix, hash[:]...) 
+} From 6392afdcbaeadc77d531be2630b26417985fd27e Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 20 Jul 2023 21:55:35 +0700 Subject: [PATCH 039/119] Rename --- core/rawdb/{database_util.go => accessors_chain.go} | 0 core/rawdb/{database_util_test.go => accessors_chain_test.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename core/rawdb/{database_util.go => accessors_chain.go} (100%) rename core/rawdb/{database_util_test.go => accessors_chain_test.go} (100%) diff --git a/core/rawdb/database_util.go b/core/rawdb/accessors_chain.go similarity index 100% rename from core/rawdb/database_util.go rename to core/rawdb/accessors_chain.go diff --git a/core/rawdb/database_util_test.go b/core/rawdb/accessors_chain_test.go similarity index 100% rename from core/rawdb/database_util_test.go rename to core/rawdb/accessors_chain_test.go From cf62b980401450498fdf1fcf1ed9cb5ac33609cc Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 20 Jul 2023 22:08:27 +0700 Subject: [PATCH 040/119] Split into sub accessor files --- core/rawdb/accessors_chain.go | 176 +------------------------------ core/rawdb/accessors_indexes.go | 144 +++++++++++++++++++++++++ core/rawdb/accessors_metadata.go | 71 +++++++++++++ core/rawdb/schema.go | 8 ++ 4 files changed, 224 insertions(+), 175 deletions(-) create mode 100644 core/rawdb/accessors_indexes.go create mode 100644 core/rawdb/accessors_metadata.go diff --git a/core/rawdb/accessors_chain.go b/core/rawdb/accessors_chain.go index c27345b4e1..e80153e530 100644 --- a/core/rawdb/accessors_chain.go +++ b/core/rawdb/accessors_chain.go @@ -19,7 +19,6 @@ package rawdb import ( "bytes" "encoding/binary" - "encoding/json" "fmt" "math/big" @@ -27,7 +26,6 @@ import ( "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" ) @@ -41,14 +39,6 @@ type DatabaseDeleter interface { Delete(key []byte) error } -// TxLookupEntry is a positional metadata to help looking up the data content of -// a transaction or receipt given only its hash. -type TxLookupEntry struct { - BlockHash common.Hash - BlockIndex uint64 - Index uint64 -} - // encodeBlockNumber encodes a block number as big endian uint64 func encodeBlockNumber(number uint64) []byte { enc := make([]byte, 8) @@ -220,91 +210,6 @@ func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64) types. return receipts } -// GetTxLookupEntry retrieves the positional metadata associated with a transaction -// hash to allow retrieving the transaction or receipt by hash. -func GetTxLookupEntry(db DatabaseReader, hash common.Hash) (common.Hash, uint64, uint64) { - // Load the positional metadata from disk and bail if it fails - data, _ := db.Get(txLookupKey(hash)) - if len(data) == 0 { - return common.Hash{}, 0, 0 - } - // Parse and return the contents of the lookup entry - var entry TxLookupEntry - if err := rlp.DecodeBytes(data, &entry); err != nil { - log.Error("Invalid lookup entry RLP", "hash", hash, "err", err) - return common.Hash{}, 0, 0 - } - return entry.BlockHash, entry.BlockIndex, entry.Index -} - -// GetTransaction retrieves a specific transaction from the database, along with -// its added positional metadata. 
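GetTransaction and the other index accessors are deleted here and reappear in accessors_indexes.go below; resolution stays a two-read affair, lookup entry first, block body second. A sketch of that read path, assuming the exported rawdb accessors as they stand after this patch; resolveTransaction is a hypothetical helper:

package example

import (
	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/core/types"
)

// resolveTransaction shows the lookup-then-body flow used by GetTransaction.
func resolveTransaction(db rawdb.DatabaseReader, txHash common.Hash) *types.Transaction {
	// First read: positional metadata stored under txLookupKey(txHash).
	blockHash, blockNumber, txIndex := rawdb.GetTxLookupEntry(db, txHash)
	if blockHash == (common.Hash{}) {
		return nil // no lookup entry for this hash
	}
	// Second read: the block body that actually contains the transaction.
	body := rawdb.GetBody(db, blockHash, blockNumber)
	if body == nil || len(body.Transactions) <= int(txIndex) {
		return nil
	}
	return body.Transactions[txIndex]
}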
-func GetTransaction(db DatabaseReader, hash common.Hash) (*types.Transaction, common.Hash, uint64, uint64) { - // Retrieve the lookup metadata and resolve the transaction from the body - blockHash, blockNumber, txIndex := GetTxLookupEntry(db, hash) - - if blockHash != (common.Hash{}) { - body := GetBody(db, blockHash, blockNumber) - if body == nil || len(body.Transactions) <= int(txIndex) { - log.Error("Transaction referenced missing", "number", blockNumber, "hash", blockHash, "index", txIndex) - return nil, common.Hash{}, 0, 0 - } - return body.Transactions[txIndex], blockHash, blockNumber, txIndex - } - // Old transaction representation, load the transaction and its metadata separately - data, _ := db.Get(hash.Bytes()) - if len(data) == 0 { - return nil, common.Hash{}, 0, 0 - } - var tx types.Transaction - if err := rlp.DecodeBytes(data, &tx); err != nil { - return nil, common.Hash{}, 0, 0 - } - // Retrieve the blockchain positional metadata - data, _ = db.Get(oldTxMetaKey(hash)) - if len(data) == 0 { - return nil, common.Hash{}, 0, 0 - } - var entry TxLookupEntry - if err := rlp.DecodeBytes(data, &entry); err != nil { - return nil, common.Hash{}, 0, 0 - } - return &tx, entry.BlockHash, entry.BlockIndex, entry.Index -} - -// GetReceipt retrieves a specific transaction receipt from the database, along with -// its added positional metadata. -func GetReceipt(db DatabaseReader, hash common.Hash) (*types.Receipt, common.Hash, uint64, uint64) { - // Retrieve the lookup metadata and resolve the receipt from the receipts - blockHash, blockNumber, receiptIndex := GetTxLookupEntry(db, hash) - - if blockHash != (common.Hash{}) { - receipts := GetBlockReceipts(db, blockHash, blockNumber) - if len(receipts) <= int(receiptIndex) { - log.Error("Receipt referenced missing", "number", blockNumber, "hash", blockHash, "index", receiptIndex) - return nil, common.Hash{}, 0, 0 - } - return receipts[receiptIndex], blockHash, blockNumber, receiptIndex - } - // Old receipt representation, load the receipt and set an unknown metadata - data, _ := db.Get(oldReceiptsKey(hash)) - if len(data) == 0 { - return nil, common.Hash{}, 0, 0 - } - var receipt types.ReceiptForStorage - err := rlp.DecodeBytes(data, &receipt) - if err != nil { - log.Error("Invalid receipt RLP", "hash", hash, "err", err) - } - return (*types.Receipt)(&receipt), common.Hash{}, 0, 0 -} - -// GetBloomBits retrieves the compressed bloom bit vector belonging to the given -// bit index and section indexes. -func GetBloomBits(db DatabaseReader, bit uint, section uint64, head common.Hash) ([]byte, error) { - return db.Get(bloomBitsKey(bit, section, head)) -} - // WriteCanonicalHash stores the canonical hash for the given block number. func WriteCanonicalHash(db ethdb.KeyValueWriter, hash common.Hash, number uint64) error { if err := db.Put(headerHashKey(number), hash.Bytes()); err != nil { @@ -426,35 +331,6 @@ func WriteBlockReceipts(db ethdb.KeyValueWriter, hash common.Hash, number uint64 return nil } -// WriteTxLookupEntries stores a positional metadata for every transaction from -// a block, enabling hash based transaction and receipt lookups. 
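The write side of the index is the WriteTxLookupEntries / DeleteTxLookupEntry pair being moved here; the light transaction pool earlier in this series calls exactly these when pooled transactions are mined or rolled back. A hedged usage sketch, with db and block assumed in scope and db assumed to satisfy both the writer and deleter interfaces:

// After importing a block, index its transactions so GetTransaction can find them.
if err := rawdb.WriteTxLookupEntries(db, block); err != nil {
	return err
}

// On a rollback, drop the entries again so stale positions are not served.
for _, tx := range block.Transactions() {
	rawdb.DeleteTxLookupEntry(db, tx.Hash())
}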
-func WriteTxLookupEntries(db ethdb.KeyValueWriter, block *types.Block) error { - // Iterate over each transaction and encode its metadata - for i, tx := range block.Transactions() { - entry := TxLookupEntry{ - BlockHash: block.Hash(), - BlockIndex: block.NumberU64(), - Index: uint64(i), - } - data, err := rlp.EncodeToBytes(entry) - if err != nil { - return err - } - if err := db.Put(txLookupKey(tx.Hash()), data); err != nil { - return err - } - } - return nil -} - -// WriteBloomBits writes the compressed bloom bits vector belonging to the given -// section and bit index. -func WriteBloomBits(db ethdb.KeyValueWriter, bit uint, section uint64, head common.Hash, bits []byte) { - if err := db.Put(bloomBitsKey(bit, section, head), bits); err != nil { - log.Crit("Failed to store bloom bits", "err", err) - } -} - // DeleteCanonicalHash removes the number to hash canonical mapping. func DeleteCanonicalHash(db DatabaseDeleter, number uint64) { db.Delete(headerHashKey(number)) @@ -463,7 +339,7 @@ func DeleteCanonicalHash(db DatabaseDeleter, number uint64) { // DeleteHeader removes all block header data associated with a hash. func DeleteHeader(db DatabaseDeleter, hash common.Hash, number uint64) { db.Delete(headerNumberKey(hash)) - db.Delete(headerHashKey(number)) + db.Delete(headerKey(number, hash)) } // DeleteBody removes all block body data associated with a hash. @@ -489,11 +365,6 @@ func DeleteBlockReceipts(db DatabaseDeleter, hash common.Hash, number uint64) { db.Delete(blockReceiptsKey(number, hash)) } -// DeleteTxLookupEntry removes all transaction data associated with a hash. -func DeleteTxLookupEntry(db DatabaseDeleter, hash common.Hash) { - db.Delete(txLookupKey(hash)) -} - // PreimageTable returns a Database instance with the key prefix for preimage entries. func PreimageTable(db ethdb.Database) ethdb.Database { return NewTable(db, preimagePrefix) @@ -521,51 +392,6 @@ func WritePreimages(db ethdb.Database, number uint64, preimages map[common.Hash] return nil } -// GetBlockChainVersion reads the version number from db. -func GetBlockChainVersion(db DatabaseReader) int { - var vsn uint - enc, _ := db.Get([]byte("BlockchainVersion")) - rlp.DecodeBytes(enc, &vsn) - return int(vsn) -} - -// WriteBlockChainVersion writes vsn as the version number to db. -func WriteBlockChainVersion(db ethdb.KeyValueWriter, vsn int) { - enc, _ := rlp.EncodeToBytes(uint(vsn)) - db.Put([]byte("BlockchainVersion"), enc) -} - -// WriteChainConfig writes the chain config settings to the database. -func WriteChainConfig(db ethdb.KeyValueWriter, hash common.Hash, cfg *params.ChainConfig) error { - // short circuit and ignore if nil config. GetChainConfig - // will return a default. - if cfg == nil { - return nil - } - - jsonChainConfig, err := json.Marshal(cfg) - if err != nil { - return err - } - - return db.Put(configKey(hash), jsonChainConfig) -} - -// GetChainConfig will fetch the network settings based on the given hash. 
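The chain metadata accessors head to accessors_metadata.go below. The blockchain version, for example, is stored as nothing more than an RLP-encoded uint under a fixed key, so a round trip is two calls; a small sketch with db assumed to be a full ethdb.Database:

// Record the schema version, then read it back.
rawdb.WriteBlockChainVersion(db, 3)
if v := rawdb.GetBlockChainVersion(db); v != 3 {
	panic("unexpected blockchain version")
}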
-func GetChainConfig(db DatabaseReader, hash common.Hash) (*params.ChainConfig, error) { - jsonChainConfig, _ := db.Get(configKey(hash)) - if len(jsonChainConfig) == 0 { - return nil, ErrChainConfigNotFound - } - - var config params.ChainConfig - if err := json.Unmarshal(jsonChainConfig, &config); err != nil { - return nil, err - } - - return &config, nil -} - // FindCommonAncestor returns the last common ancestor of two block headers func FindCommonAncestor(db DatabaseReader, a, b *types.Header) *types.Header { for bn := b.Number.Uint64(); a.Number.Uint64() > bn; { diff --git a/core/rawdb/accessors_indexes.go b/core/rawdb/accessors_indexes.go new file mode 100644 index 0000000000..ae14c990bb --- /dev/null +++ b/core/rawdb/accessors_indexes.go @@ -0,0 +1,144 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rawdb + +import ( + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/rlp" +) + +// GetTxLookupEntry retrieves the positional metadata associated with a transaction +// hash to allow retrieving the transaction or receipt by hash. +func GetTxLookupEntry(db DatabaseReader, hash common.Hash) (common.Hash, uint64, uint64) { + // Load the positional metadata from disk and bail if it fails + data, _ := db.Get(txLookupKey(hash)) + if len(data) == 0 { + return common.Hash{}, 0, 0 + } + // Parse and return the contents of the lookup entry + var entry TxLookupEntry + if err := rlp.DecodeBytes(data, &entry); err != nil { + log.Error("Invalid lookup entry RLP", "hash", hash, "err", err) + return common.Hash{}, 0, 0 + } + return entry.BlockHash, entry.BlockIndex, entry.Index +} + +// WriteTxLookupEntries stores a positional metadata for every transaction from +// a block, enabling hash based transaction and receipt lookups. +func WriteTxLookupEntries(db ethdb.KeyValueWriter, block *types.Block) error { + // Iterate over each transaction and encode its metadata + for i, tx := range block.Transactions() { + entry := TxLookupEntry{ + BlockHash: block.Hash(), + BlockIndex: block.NumberU64(), + Index: uint64(i), + } + data, err := rlp.EncodeToBytes(entry) + if err != nil { + return err + } + if err := db.Put(txLookupKey(tx.Hash()), data); err != nil { + return err + } + } + return nil +} + +// DeleteTxLookupEntry removes all transaction data associated with a hash. +func DeleteTxLookupEntry(db DatabaseDeleter, hash common.Hash) { + db.Delete(txLookupKey(hash)) +} + +// GetTransaction retrieves a specific transaction from the database, along with +// its added positional metadata. 
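TxLookupEntry itself, now declared in schema.go, is a plain three-field struct, so its stored form is just its RLP encoding. A minimal round trip with illustrative values:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/rawdb"
	"github.com/tomochain/tomochain/rlp"
)

func main() {
	entry := rawdb.TxLookupEntry{
		BlockHash:  common.HexToHash("0x01"), // illustrative hash
		BlockIndex: 42,                       // block number
		Index:      3,                        // position within the block
	}
	data, _ := rlp.EncodeToBytes(entry)

	var decoded rawdb.TxLookupEntry
	if err := rlp.DecodeBytes(data, &decoded); err != nil {
		panic(err)
	}
	fmt.Println(decoded.BlockIndex, decoded.Index) // 42 3
}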
+func GetTransaction(db DatabaseReader, hash common.Hash) (*types.Transaction, common.Hash, uint64, uint64) { + // Retrieve the lookup metadata and resolve the transaction from the body + blockHash, blockNumber, txIndex := GetTxLookupEntry(db, hash) + + if blockHash != (common.Hash{}) { + body := GetBody(db, blockHash, blockNumber) + if body == nil || len(body.Transactions) <= int(txIndex) { + log.Error("Transaction referenced missing", "number", blockNumber, "hash", blockHash, "index", txIndex) + return nil, common.Hash{}, 0, 0 + } + return body.Transactions[txIndex], blockHash, blockNumber, txIndex + } + // Old transaction representation, load the transaction and its metadata separately + data, _ := db.Get(hash.Bytes()) + if len(data) == 0 { + return nil, common.Hash{}, 0, 0 + } + var tx types.Transaction + if err := rlp.DecodeBytes(data, &tx); err != nil { + return nil, common.Hash{}, 0, 0 + } + // Retrieve the blockchain positional metadata + data, _ = db.Get(oldTxMetaKey(hash)) + if len(data) == 0 { + return nil, common.Hash{}, 0, 0 + } + var entry TxLookupEntry + if err := rlp.DecodeBytes(data, &entry); err != nil { + return nil, common.Hash{}, 0, 0 + } + return &tx, entry.BlockHash, entry.BlockIndex, entry.Index +} + +// GetReceipt retrieves a specific transaction receipt from the database, along with +// its added positional metadata. +func GetReceipt(db DatabaseReader, hash common.Hash) (*types.Receipt, common.Hash, uint64, uint64) { + // Retrieve the lookup metadata and resolve the receipt from the receipts + blockHash, blockNumber, receiptIndex := GetTxLookupEntry(db, hash) + + if blockHash != (common.Hash{}) { + receipts := GetBlockReceipts(db, blockHash, blockNumber) + if len(receipts) <= int(receiptIndex) { + log.Error("Receipt referenced missing", "number", blockNumber, "hash", blockHash, "index", receiptIndex) + return nil, common.Hash{}, 0, 0 + } + return receipts[receiptIndex], blockHash, blockNumber, receiptIndex + } + // Old receipt representation, load the receipt and set an unknown metadata + data, _ := db.Get(oldReceiptsKey(hash)) + if len(data) == 0 { + return nil, common.Hash{}, 0, 0 + } + var receipt types.ReceiptForStorage + err := rlp.DecodeBytes(data, &receipt) + if err != nil { + log.Error("Invalid receipt RLP", "hash", hash, "err", err) + } + return (*types.Receipt)(&receipt), common.Hash{}, 0, 0 +} + +// GetBloomBits retrieves the compressed bloom bit vector belonging to the given +// bit index and section indexes. +func GetBloomBits(db DatabaseReader, bit uint, section uint64, head common.Hash) ([]byte, error) { + return db.Get(bloomBitsKey(bit, section, head)) +} + +// WriteBloomBits writes the compressed bloom bits vector belonging to the given +// section and bit index. +func WriteBloomBits(db ethdb.KeyValueWriter, bit uint, section uint64, head common.Hash, bits []byte) { + if err := db.Put(bloomBitsKey(bit, section, head), bits); err != nil { + log.Crit("Failed to store bloom bits", "err", err) + } +} diff --git a/core/rawdb/accessors_metadata.go b/core/rawdb/accessors_metadata.go new file mode 100644 index 0000000000..16fbbd77b0 --- /dev/null +++ b/core/rawdb/accessors_metadata.go @@ -0,0 +1,71 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rawdb + +import ( + "encoding/json" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/rlp" +) + +// GetBlockChainVersion reads the version number from db. +func GetBlockChainVersion(db DatabaseReader) int { + var vsn uint + enc, _ := db.Get([]byte("BlockchainVersion")) + rlp.DecodeBytes(enc, &vsn) + return int(vsn) +} + +// WriteBlockChainVersion writes vsn as the version number to db. +func WriteBlockChainVersion(db ethdb.KeyValueWriter, vsn int) { + enc, _ := rlp.EncodeToBytes(uint(vsn)) + db.Put([]byte("BlockchainVersion"), enc) +} + +// WriteChainConfig writes the chain config settings to the database. +func WriteChainConfig(db ethdb.KeyValueWriter, hash common.Hash, cfg *params.ChainConfig) error { + // short circuit and ignore if nil config. GetChainConfig + // will return a default. + if cfg == nil { + return nil + } + + jsonChainConfig, err := json.Marshal(cfg) + if err != nil { + return err + } + + return db.Put(configKey(hash), jsonChainConfig) +} + +// GetChainConfig will fetch the network settings based on the given hash. +func GetChainConfig(db DatabaseReader, hash common.Hash) (*params.ChainConfig, error) { + jsonChainConfig, _ := db.Get(configKey(hash)) + if len(jsonChainConfig) == 0 { + return nil, ErrChainConfigNotFound + } + + var config params.ChainConfig + if err := json.Unmarshal(jsonChainConfig, &config); err != nil { + return nil, err + } + + return &config, nil +} diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go index 50c343cc24..528f0e15ee 100644 --- a/core/rawdb/schema.go +++ b/core/rawdb/schema.go @@ -57,6 +57,14 @@ var ( preimageHitCounter = metrics.NewRegisteredCounter("db/preimage/hits", nil) ) +// TxLookupEntry is a positional metadata to help looking up the data content of +// a transaction or receipt given only its hash. +type TxLookupEntry struct { + BlockHash common.Hash + BlockIndex uint64 + Index uint64 +} + // configKey = configPrefix + hash func configKey(hash common.Hash) []byte { return append(configPrefix, hash.Bytes()...) 
From 97a62cab91a05f82a871981b4b20fb06497cc964 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Fri, 21 Jul 2023 22:10:41 +0700 Subject: [PATCH 041/119] Re-enable VM tests --- tests/vm_test.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/vm_test.go b/tests/vm_test.go index 7fda3cc6f5..8377d4bc32 100644 --- a/tests/vm_test.go +++ b/tests/vm_test.go @@ -25,9 +25,6 @@ import ( ) func TestVM(t *testing.T) { - if testing.Short() { - t.Skip("skipping testing in short mode") - } common.TIPTomoXCancellationFee = big.NewInt(100000000) t.Parallel() vmt := new(testMatcher) From 024181b82b8888bfd758ca2e7d9e5dfcd4a354c9 Mon Sep 17 00:00:00 2001 From: c98tristan Date: Mon, 17 Jul 2023 14:46:01 +0700 Subject: [PATCH 042/119] chore: remove intPool --- core/vm/interpreter.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index 36027be797..1b86bf37ce 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -75,8 +75,6 @@ type EVMInterpreter struct { evm *EVM cfg Config - intPool *intPool - hasher crypto.KeccakState // Keccak256 hasher instance shared across opcodes hasherBuf common.Hash // Keccak256 hasher result array shared across opcodes From c3846466a57fd2aca567edee54c47f326bacc86b Mon Sep 17 00:00:00 2001 From: c98tristan Date: Wed, 19 Jul 2023 10:03:33 +0700 Subject: [PATCH 043/119] Feat: Update derive SHA for stacktrie --- consensus/clique/clique.go | 3 ++- consensus/ethash/consensus.go | 3 ++- consensus/posv/posv.go | 3 ++- core/bench_test.go | 4 ++-- core/block_validator.go | 6 ++++-- core/blockchain_test.go | 8 ++++--- core/database_util.go | 3 ++- core/database_util_test.go | 10 +++++---- core/genesis.go | 14 ++++++------ core/tx_pool_test.go | 20 ++++++++++++------ core/types/block.go | 8 +++---- core/types/block_test.go | 40 ++++++++++++++++++++++++++++++++++- core/types/derive_sha.go | 17 ++++++++++----- eth/downloader/queue.go | 5 +++-- eth/fetcher/fetcher.go | 8 ++++--- eth/fetcher/fetcher_test.go | 6 ++++-- les/odr_requests.go | 4 ++-- miner/worker.go | 3 +++ rlp/encode.go | 2 ++ trie/database.go | 6 ++++-- trie/trie.go | 6 ++++++ 21 files changed, 130 insertions(+), 49 deletions(-) diff --git a/consensus/clique/clique.go b/consensus/clique/clique.go index f63373e17e..28610581d2 100644 --- a/consensus/clique/clique.go +++ b/consensus/clique/clique.go @@ -40,6 +40,7 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/trie" ) const ( @@ -575,7 +576,7 @@ func (c *Clique) Finalize(chain consensus.ChainReader, header *types.Header, sta header.UncleHash = types.CalcUncleHash(nil) // Assemble and return the final block for sealing - return types.NewBlock(header, txs, nil, receipts), nil + return types.NewBlock(header, txs, nil, receipts, new(trie.Trie)), nil } // Authorize injects a private key into the consensus engine to mint new blocks diff --git a/consensus/ethash/consensus.go b/consensus/ethash/consensus.go index 12f63cfde7..ee90ca7bac 100644 --- a/consensus/ethash/consensus.go +++ b/consensus/ethash/consensus.go @@ -32,6 +32,7 @@ import ( "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/trie" ) // Ethash proof-of-work protocol constants. 
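Patch 043 threads an explicit trie through types.DeriveSha, so every call site constructs the hasher it wants instead of relying on one hidden inside the function. A sketch of deriving and checking roots under the new signature; txs, receipts and header are assumed in scope:

// One fresh trie per derived root; DeriveSha fills it from the list.
txRoot := types.DeriveSha(txs, new(trie.Trie))
if txRoot != header.TxHash {
	return fmt.Errorf("transaction root hash mismatch: have %x, want %x", txRoot, header.TxHash)
}

// Receipt roots are derived the same way, as ValidateState does below.
receiptRoot := types.DeriveSha(receipts, new(trie.Trie))
_ = receiptRoot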
@@ -519,7 +520,7 @@ func (ethash *Ethash) Finalize(chain consensus.ChainReader, header *types.Header header.Root = state.IntermediateRoot(chain.Config().IsEIP158(header.Number)) // Header seems complete, assemble into a block and return - return types.NewBlock(header, txs, uncles, receipts), nil + return types.NewBlock(header, txs, uncles, receipts, new(trie.Trie)), nil } // Some weird constants to avoid constant memory allocs for them. diff --git a/consensus/posv/posv.go b/consensus/posv/posv.go index f2b48fde93..fa31a2d605 100644 --- a/consensus/posv/posv.go +++ b/consensus/posv/posv.go @@ -49,6 +49,7 @@ import ( "github.com/tomochain/tomochain/rpc" "github.com/tomochain/tomochain/tomox/tradingstate" "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "github.com/tomochain/tomochain/trie" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) @@ -985,7 +986,7 @@ func (c *Posv) Finalize(chain consensus.ChainReader, header *types.Header, state header.UncleHash = types.CalcUncleHash(nil) // Assemble and return the final block for sealing - return types.NewBlock(header, txs, nil, receipts), nil + return types.NewBlock(header, txs, nil, receipts, new(trie.Trie)), nil } // Authorize injects a private key into the consensus engine to mint new blocks diff --git a/core/bench_test.go b/core/bench_test.go index cef95625c6..ac36d44a95 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -23,10 +23,11 @@ import ( "os" "testing" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus/ethash" - "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" @@ -242,7 +243,6 @@ func makeChainForBench(db ethdb.Database, full bool, count uint64) { WriteChainConfig(db, hash, params.AllEthashProtocolChanges) } WriteHeadHeaderHash(db, hash) - if full || n == 0 { block := types.NewBlockWithHeader(header) WriteBody(db, hash, n, block.Body()) diff --git a/core/block_validator.go b/core/block_validator.go index 34fde4cedd..384fec62f0 100644 --- a/core/block_validator.go +++ b/core/block_validator.go @@ -18,6 +18,7 @@ package core import ( "fmt" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/posv" @@ -27,6 +28,7 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/tomox/tradingstate" "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "github.com/tomochain/tomochain/trie" ) // BlockValidator is responsible for validating block headers, uncles and @@ -71,7 +73,7 @@ func (v *BlockValidator) ValidateBody(block *types.Block) error { if hash := types.CalcUncleHash(block.Uncles()); hash != header.UncleHash { return fmt.Errorf("uncle root hash mismatch: have %x, want %x", hash, header.UncleHash) } - if hash := types.DeriveSha(block.Transactions()); hash != header.TxHash { + if hash := types.DeriveSha(block.Transactions(), new(trie.Trie)); hash != header.TxHash { return fmt.Errorf("transaction root hash mismatch: have %x, want %x", hash, header.TxHash) } return nil @@ -93,7 +95,7 @@ func (v *BlockValidator) ValidateState(block, parent *types.Block, statedb *stat return fmt.Errorf("invalid bloom (remote: %x local: %x)", header.Bloom, rbloom) } // Tre receipt Trie's root (R = (Tr [[H1, R1], ... 
[Hn, R1]])) - receiptSha := types.DeriveSha(receipts) + receiptSha := types.DeriveSha(receipts, new(trie.Trie)) if receiptSha != header.ReceiptHash { return fmt.Errorf("invalid receipt root hash (remote: %x local: %x)", header.ReceiptHash, receiptSha) } diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 6860924112..124880ba87 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -18,13 +18,15 @@ package core import ( "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "math/rand" "sync" "testing" "time" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/trie" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core/state" @@ -617,12 +619,12 @@ func TestFastVsFullChains(t *testing.T) { } if fblock, ablock := fast.GetBlockByHash(hash), archive.GetBlockByHash(hash); fblock.Hash() != ablock.Hash() { t.Errorf("block #%d [%x]: block mismatch: have %v, want %v", num, hash, fblock, ablock) - } else if types.DeriveSha(fblock.Transactions()) != types.DeriveSha(ablock.Transactions()) { + } else if types.DeriveSha(fblock.Transactions(), new(trie.Trie)) != types.DeriveSha(ablock.Transactions(), new(trie.Trie)) { t.Errorf("block #%d [%x]: transactions mismatch: have %v, want %v", num, hash, fblock.Transactions(), ablock.Transactions()) } else if types.CalcUncleHash(fblock.Uncles()) != types.CalcUncleHash(ablock.Uncles()) { t.Errorf("block #%d [%x]: uncles mismatch: have %v, want %v", num, hash, fblock.Uncles(), ablock.Uncles()) } - if freceipts, areceipts := GetBlockReceipts(fastDb, hash, GetBlockNumber(fastDb, hash)), GetBlockReceipts(archiveDb, hash, GetBlockNumber(archiveDb, hash)); types.DeriveSha(freceipts) != types.DeriveSha(areceipts) { + if freceipts, areceipts := GetBlockReceipts(fastDb, hash, GetBlockNumber(fastDb, hash)), GetBlockReceipts(archiveDb, hash, GetBlockNumber(archiveDb, hash)); types.DeriveSha(freceipts, new(trie.Trie)) != types.DeriveSha(areceipts, new(trie.Trie)) { t.Errorf("block #%d [%x]: receipts mismatch: have %v, want %v", num, hash, freceipts, areceipts) } } diff --git a/core/database_util.go b/core/database_util.go index a5ab18687d..a668434f16 100644 --- a/core/database_util.go +++ b/core/database_util.go @@ -22,9 +22,10 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" diff --git a/core/database_util_test.go b/core/database_util_test.go index f28ca160a5..d15d5d2e53 100644 --- a/core/database_util_test.go +++ b/core/database_util_test.go @@ -18,10 +18,12 @@ package core import ( "bytes" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/trie" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto/sha3" @@ -83,7 +85,7 @@ func TestBodyStorage(t *testing.T) { } if entry := GetBody(db, hash, 0); entry == nil { t.Fatalf("Stored body not found") - } else if types.DeriveSha(types.Transactions(entry.Transactions)) != types.DeriveSha(types.Transactions(body.Transactions)) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(body.Uncles) { + } else if types.DeriveSha(types.Transactions(entry.Transactions), new(trie.Trie)) != 
types.DeriveSha(types.Transactions(body.Transactions), new(trie.Trie)) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(body.Uncles) { t.Fatalf("Retrieved body mismatch: have %v, want %v", entry, body) } if entry := GetBodyRLP(db, hash, 0); entry == nil { @@ -139,7 +141,7 @@ func TestBlockStorage(t *testing.T) { } if entry := GetBody(db, block.Hash(), block.NumberU64()); entry == nil { t.Fatalf("Stored body not found") - } else if types.DeriveSha(types.Transactions(entry.Transactions)) != types.DeriveSha(block.Transactions()) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(block.Uncles()) { + } else if types.DeriveSha(types.Transactions(entry.Transactions), new(trie.Trie)) != types.DeriveSha(block.Transactions(), new(trie.Trie)) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(block.Uncles()) { t.Fatalf("Retrieved body mismatch: have %v, want %v", entry, block.Body()) } // Delete the block and verify the execution @@ -295,7 +297,7 @@ func TestLookupStorage(t *testing.T) { tx3 := types.NewTransaction(3, common.BytesToAddress([]byte{0x33}), big.NewInt(333), 3333, big.NewInt(33333), []byte{0x33, 0x33, 0x33}) txs := []*types.Transaction{tx1, tx2, tx3} - block := types.NewBlock(&types.Header{Number: big.NewInt(314)}, txs, nil, nil) + block := types.NewBlock(&types.Header{Number: big.NewInt(314)}, txs, nil, nil, new(trie.Trie)) // Check that no transactions entries are in a pristine database for i, tx := range txs { diff --git a/core/genesis.go b/core/genesis.go index e1b7185a41..4f417d64ba 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -22,10 +22,12 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "strings" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/trie" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" @@ -140,10 +142,10 @@ func (e *GenesisMismatchError) Error() string { // SetupGenesisBlock writes or updates the genesis block in db. // The block that will be used is: // -// genesis == nil genesis != nil -// +------------------------------------------ -// db has no genesis | main-net default | genesis -// db has genesis | from DB | genesis (if compatible) +// genesis == nil genesis != nil +// +------------------------------------------ +// db has no genesis | main-net default | genesis +// db has genesis | from DB | genesis (if compatible) // // The stored chain configuration will be updated if it is compatible (i.e. does not // specify a fork block below the local head block). In case of a conflict, the @@ -258,7 +260,7 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block { statedb.Commit(false) statedb.Database().TrieDB().Commit(root, true) - return types.NewBlock(head, nil, nil, nil) + return types.NewBlock(head, nil, nil, nil, new(trie.Trie)) } // Commit writes the block and state of a genesis specification to the database. 
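The thread running through these hunks: DeriveSha and NewBlock no longer build a trie internally, so every call site now passes the Hasher explicitly (here still new(trie.Trie); a later patch in this series switches call sites to new(trie.StackTrie)). A minimal sketch of the new call shape:

package main

import (
	"fmt"
	"math/big"

	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/types"
	"github.com/tomochain/tomochain/trie"
)

func main() {
	tx := types.NewTransaction(0, common.BytesToAddress([]byte{0x11}), big.NewInt(1), 21000, big.NewInt(1), nil)

	// The caller now owns the choice of Hasher.
	root := types.DeriveSha(types.Transactions{tx}, new(trie.Trie))
	fmt.Printf("tx root: %x\n", root)

	// An empty list still reproduces the well-known empty root constant.
	fmt.Println(types.DeriveSha(types.Transactions{}, new(trie.Trie)) == types.EmptyRootHash)
}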
diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 4c6a311119..c458d0d35a 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -19,8 +19,6 @@ package core import ( "crypto/ecdsa" "fmt" - "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "math/rand" @@ -28,6 +26,10 @@ import ( "testing" "time" + "github.com/tomochain/tomochain/consensus" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/trie" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" @@ -70,7 +72,7 @@ func (bc *testBlockChain) Config() *params.ChainConfig { func (bc *testBlockChain) CurrentBlock() *types.Block { return types.NewBlock(&types.Header{ GasLimit: bc.gasLimit, - }, nil, nil, nil) + }, nil, nil, nil, new(trie.Trie)) } func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block { @@ -872,8 +874,10 @@ func testTransactionQueueGlobalLimiting(t *testing.T, nolocals bool) { // // This logic should not hold for local transactions, unless the local tracking // mechanism is disabled. -func TestTransactionQueueTimeLimiting(t *testing.T) { testTransactionQueueTimeLimiting(t, false) } -func TestTransactionQueueTimeLimitingNoLocals(t *testing.T) { testTransactionQueueTimeLimiting(t, true) } +func TestTransactionQueueTimeLimiting(t *testing.T) { testTransactionQueueTimeLimiting(t, false) } +func TestTransactionQueueTimeLimitingNoLocals(t *testing.T) { + testTransactionQueueTimeLimiting(t, true) +} func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { common.MinGasPrice = big.NewInt(0) @@ -981,8 +985,10 @@ func TestTransactionPendingLimiting(t *testing.T) { // Tests that the transaction limits are enforced the same way irrelevant whether // the transactions are added one by one or in batches. -func TestTransactionQueueLimitingEquivalency(t *testing.T) { testTransactionLimitingEquivalency(t, 1) } -func TestTransactionPendingLimitingEquivalency(t *testing.T) { testTransactionLimitingEquivalency(t, 0) } +func TestTransactionQueueLimitingEquivalency(t *testing.T) { testTransactionLimitingEquivalency(t, 1) } +func TestTransactionPendingLimitingEquivalency(t *testing.T) { + testTransactionLimitingEquivalency(t, 0) +} func testTransactionLimitingEquivalency(t *testing.T, origin uint64) { t.Parallel() diff --git a/core/types/block.go b/core/types/block.go index a055ced147..0f06e935e6 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -34,7 +34,7 @@ import ( ) var ( - EmptyRootHash = DeriveSha(Transactions{}) + EmptyRootHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") EmptyUncleHash = CalcUncleHash(nil) ) @@ -225,14 +225,14 @@ type storageblock struct { // The values of TxHash, UncleHash, ReceiptHash and Bloom in header // are ignored and set to values derived from the given txs, uncles // and receipts. 
-func NewBlock(header *Header, txs []*Transaction, uncles []*Header, receipts []*Receipt) *Block { +func NewBlock(header *Header, txs []*Transaction, uncles []*Header, receipts []*Receipt, hasher Hasher) *Block { b := &Block{header: CopyHeader(header), td: new(big.Int)} // TODO: panic if len(txs) != len(receipts) if len(txs) == 0 { b.header.TxHash = EmptyRootHash } else { - b.header.TxHash = DeriveSha(Transactions(txs)) + b.header.TxHash = DeriveSha(Transactions(txs), hasher) b.transactions = make(Transactions, len(txs)) copy(b.transactions, txs) } @@ -240,7 +240,7 @@ func NewBlock(header *Header, txs []*Transaction, uncles []*Header, receipts []* if len(receipts) == 0 { b.header.ReceiptHash = EmptyRootHash } else { - b.header.ReceiptHash = DeriveSha(Receipts(receipts)) + b.header.ReceiptHash = DeriveSha(Receipts(receipts), hasher) b.header.Bloom = CreateBloom(receipts) } diff --git a/core/types/block_test.go b/core/types/block_test.go index 9b78b653c7..460dc35ba6 100644 --- a/core/types/block_test.go +++ b/core/types/block_test.go @@ -17,13 +17,16 @@ package types import ( + "hash" "math/big" "testing" "bytes" + "reflect" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/rlp" - "reflect" + "golang.org/x/crypto/sha3" ) // from bcValidBlockTest.json, "SimpleTx" @@ -59,3 +62,38 @@ func TestBlockEncoding(t *testing.T) { t.Errorf("encoded block mismatch:\ngot: %x\nwant: %x", ourBlockEnc, blockEnc) } } + +func TestUncleHash(t *testing.T) { + uncles := make([]*Header, 0) + h := CalcUncleHash(uncles) + exp := common.HexToHash("1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347") + if h != exp { + t.Fatalf("empty uncle hash is wrong, got %x != %x", h, exp) + } +} + +var benchBuffer = bytes.NewBuffer(make([]byte, 0, 32000)) + +// testHasher is the helper tool for transaction/receipt list hashing. +// The original hasher is trie, in order to get rid of import cycle, +// use the testing hasher instead. +type testHasher struct { + hasher hash.Hash +} + +func newHasher() *testHasher { + return &testHasher{hasher: sha3.NewLegacyKeccak256()} +} + +func (h *testHasher) Reset() { + h.hasher.Reset() +} + +func (h *testHasher) Update(key, val []byte) { + h.hasher.Write(key) + h.hasher.Write(val) +} + +func (h *testHasher) Hash() common.Hash { + return common.BytesToHash(h.hasher.Sum(nil)) +} diff --git a/core/types/derive_sha.go b/core/types/derive_sha.go index 2731c39cbb..78eab91597 100644 --- a/core/types/derive_sha.go +++ b/core/types/derive_sha.go @@ -21,21 +21,28 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/rlp" - "github.com/tomochain/tomochain/trie" ) +// DerivableList is the interface which can derive the hash. type DerivableList interface { Len() int GetRlp(i int) []byte } -func DeriveSha(list DerivableList) common.Hash { +// Hasher is the tool used to calculate the hash of derivable list. 
+type Hasher interface { + Reset() + Update([]byte, []byte) + Hash() common.Hash +} + +func DeriveSha(list DerivableList, hasher Hasher) common.Hash { + hasher.Reset() keybuf := new(bytes.Buffer) - trie := new(trie.Trie) for i := 0; i < list.Len(); i++ { keybuf.Reset() rlp.Encode(keybuf, uint(i)) - trie.Update(keybuf.Bytes(), list.GetRlp(i)) + hasher.Update(keybuf.Bytes(), list.GetRlp(i)) } - return trie.Hash() + return hasher.Hash() } diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index 0ed4e75faa..38a3f839f5 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -29,6 +29,7 @@ import ( "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/metrics" + "github.com/tomochain/tomochain/trie" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) @@ -767,7 +768,7 @@ func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLi defer q.lock.Unlock() reconstruct := func(header *types.Header, index int, result *fetchResult) error { - if types.DeriveSha(types.Transactions(txLists[index])) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash { + if types.DeriveSha(types.Transactions(txLists[index]), new(trie.Trie)) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash { return errInvalidBody } result.Transactions = txLists[index] @@ -785,7 +786,7 @@ func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) (int, defer q.lock.Unlock() reconstruct := func(header *types.Header, index int, result *fetchResult) error { - if types.DeriveSha(types.Receipts(receiptList[index])) != header.ReceiptHash { + if types.DeriveSha(types.Receipts(receiptList[index]), new(trie.Trie)) != header.ReceiptHash { return errInvalidReceipt } result.Receipts = receiptList[index] diff --git a/eth/fetcher/fetcher.go b/eth/fetcher/fetcher.go index 65b15094d2..6b3080ce13 100644 --- a/eth/fetcher/fetcher.go +++ b/eth/fetcher/fetcher.go @@ -19,14 +19,16 @@ package fetcher import ( "errors" - "github.com/hashicorp/golang-lru" "math/rand" "time" + lru "github.com/hashicorp/golang-lru" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/trie" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) @@ -468,7 +470,7 @@ func (f *Fetcher) loop() { announce.time = task.time // If the block is empty (header only), short circuit into the final import queue - if header.TxHash == types.DeriveSha(types.Transactions{}) && header.UncleHash == types.CalcUncleHash([]*types.Header{}) { + if header.TxHash == types.EmptyRootHash && header.UncleHash == types.CalcUncleHash([]*types.Header{}) { log.Trace("Block empty, skipping body retrieval", "peer", announce.origin, "number", header.Number, "hash", header.Hash()) block := types.NewBlockWithHeader(header) @@ -530,7 +532,7 @@ func (f *Fetcher) loop() { for hash, announce := range f.completing { if f.queued[hash] == nil { - txnHash := types.DeriveSha(types.Transactions(task.transactions[i])) + txnHash := types.DeriveSha(types.Transactions(task.transactions[i]), new(trie.Trie)) uncleHash := types.CalcUncleHash(task.uncles[i]) if txnHash == announce.header.TxHash && uncleHash == announce.header.UncleHash && announce.origin == task.peer { diff --git a/eth/fetcher/fetcher_test.go b/eth/fetcher/fetcher_test.go index ab7e03aaa1..10acef3e6b 100644 --- a/eth/fetcher/fetcher_test.go +++ 
b/eth/fetcher/fetcher_test.go @@ -18,13 +18,15 @@ package fetcher import ( "errors" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "sync" "sync/atomic" "testing" "time" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/trie" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" @@ -38,7 +40,7 @@ var ( testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") testAddress = crypto.PubkeyToAddress(testKey.PublicKey) genesis = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000)) - unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit}, nil, nil, nil) + unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit}, nil, nil, nil, new(trie.Trie)) ) // makeChain creates a chain of n blocks starting at and including parent. diff --git a/les/odr_requests.go b/les/odr_requests.go index e6e68e7621..f061331dbc 100644 --- a/les/odr_requests.go +++ b/les/odr_requests.go @@ -114,7 +114,7 @@ func (r *BlockRequest) Validate(db ethdb.Database, msg *Msg) error { if header == nil { return errHeaderUnavailable } - if header.TxHash != types.DeriveSha(types.Transactions(body.Transactions)) { + if header.TxHash != types.DeriveSha(types.Transactions(body.Transactions), new(trie.Trie)) { return errTxHashMismatch } if header.UncleHash != types.CalcUncleHash(body.Uncles) { @@ -170,7 +170,7 @@ func (r *ReceiptsRequest) Validate(db ethdb.Database, msg *Msg) error { if header == nil { return errHeaderUnavailable } - if header.ReceiptHash != types.DeriveSha(receipt) { + if header.ReceiptHash != types.DeriveSha(receipt, new(trie.Trie)) { return errReceiptHashMismatch } // Validations passed, store and return diff --git a/miner/worker.go b/miner/worker.go index 995c401690..a8985a2a81 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -23,6 +23,7 @@ import ( "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "github.com/tomochain/tomochain/trie" "math/big" "os" @@ -204,6 +205,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) { self.current.txs, nil, self.current.receipts, + new(trie.Trie), ), self.current.state.Copy() } return self.current.Block, self.current.state.Copy() @@ -219,6 +221,7 @@ func (self *worker) pendingBlock() *types.Block { self.current.txs, nil, self.current.receipts, + new(trie.Trie), ) } return self.current.Block diff --git a/rlp/encode.go b/rlp/encode.go index f34be7f3df..977ff2088d 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -79,6 +79,8 @@ func EncodeToBytes(val interface{}) ([]byte, error) { buf := getEncBuffer() defer encBufferPool.Put(buf) + fmt.Println("val: ", val) + if err := buf.encode(val); err != nil { return nil, err } diff --git a/trie/database.go b/trie/database.go index bf3f2e89de..879c8be3bc 100644 --- a/trie/database.go +++ b/trie/database.go @@ -107,11 +107,11 @@ func (n rawNode) Cache() (HashNode, bool) { panic("this should never end up in func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } func (n rawNode) EncodeRLP(w io.Writer) error { - _, err := w.Write(n) + _, err := w.Write([]byte(n)) return err } -// rawFullNode represents only the useful data content of a full Node, with the +// rawFullNode represents only the useful data content of a full node, with the // caches and flags stripped out to minimize its data storage. 
This type honors // the same RLP encoding as the original parent. type rawFullNode [17]Node @@ -171,6 +171,7 @@ func (n *cachedNode) rlp() []byte { if node, ok := n.node.(rawNode); ok { return node } + fmt.Println("cachedNode rlp", n.node) blob, err := rlp.EncodeToBytes(n.node) if err != nil { panic(err) @@ -790,6 +791,7 @@ func (db *Database) commit(hash common.Hash, batch ethdb.Batch, uncacher *cleane if err != nil { return err } + fmt.Println("commit", node) if err := batch.Put(hash[:], node.rlp()); err != nil { return err } diff --git a/trie/trie.go b/trie/trie.go index 589a96186d..e630a6a546 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -637,3 +637,9 @@ func (t *Trie) hashRoot(db *Database) (Node, Node, error) { t.unhashed = 0 return hashed, cached, nil } + +// Reset drops the referenced root node and cleans all internal state. +func (t *Trie) Reset() { + t.root = nil + t.unhashed = 0 +} From b8421e5dec4ca7c9288e6498f96efaad7cdd63ca Mon Sep 17 00:00:00 2001 From: c98tristan Date: Wed, 19 Jul 2023 10:55:24 +0700 Subject: [PATCH 044/119] Chore: Add benchmark for Stacktrie --- core/bench_test.go | 1 + core/types/hashing_test.go | 79 ++++++++++++++++++++++++++++++++++++++ rlp/encode.go | 2 - 3 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 core/types/hashing_test.go diff --git a/core/bench_test.go b/core/bench_test.go index ac36d44a95..af2132ea20 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -243,6 +243,7 @@ func makeChainForBench(db ethdb.Database, full bool, count uint64) { WriteChainConfig(db, hash, params.AllEthashProtocolChanges) } WriteHeadHeaderHash(db, hash) + if full || n == 0 { block := types.NewBlockWithHeader(header) WriteBody(db, hash, n, block.Body()) diff --git a/core/types/hashing_test.go b/core/types/hashing_test.go new file mode 100644 index 0000000000..d2f2781a6b --- /dev/null +++ b/core/types/hashing_test.go @@ -0,0 +1,79 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package types_test + +import ( + "math/big" + "testing" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/trie" +) + +func BenchmarkDeriveSha200(b *testing.B) { + txs, err := genTxs(200) + if err != nil { + b.Fatal(err) + } + var exp common.Hash + var got common.Hash + b.Run("std_trie", func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase())) + exp = types.DeriveSha(txs, tr) + } + }) + + b.Run("stack_trie", func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + got = types.DeriveSha(txs, trie.NewStackTrie(nil)) + } + }) + if got != exp { + b.Errorf("got %x exp %x", got, exp) + } +} + +func genTxs(num uint64) (types.Transactions, error) { + key, err := crypto.HexToECDSA("deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") + if err != nil { + return nil, err + } + var addr = crypto.PubkeyToAddress(key.PublicKey) + newTx := func(i uint64) (*types.Transaction, error) { + signer := types.NewEIP155Signer(big.NewInt(18)) + utx := types.NewTransaction(i, addr, new(big.Int), 0, new(big.Int).SetUint64(10000000), nil) + tx, err := types.SignTx(utx, signer, key) + return tx, err + } + var txs types.Transactions + for i := uint64(0); i < num; i++ { + tx, err := newTx(i) + if err != nil { + return nil, err + } + txs = append(txs, tx) + } + return txs, nil +} diff --git a/rlp/encode.go b/rlp/encode.go index 977ff2088d..f34be7f3df 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -79,8 +79,6 @@ func EncodeToBytes(val interface{}) ([]byte, error) { buf := getEncBuffer() defer encBufferPool.Put(buf) - fmt.Println("val: ", val) - if err := buf.encode(val); err != nil { return nil, err } From 6099609aed68b4d3ff5469d055f04e42ac40a656 Mon Sep 17 00:00:00 2001 From: c98tristan Date: Wed, 19 Jul 2023 14:18:52 +0700 Subject: [PATCH 045/119] Chore: Fix Update function missing return type --- core/types/derive_sha.go | 2 +- trie/trie.go | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/core/types/derive_sha.go b/core/types/derive_sha.go index 78eab91597..fecd0775b3 100644 --- a/core/types/derive_sha.go +++ b/core/types/derive_sha.go @@ -32,7 +32,7 @@ type DerivableList interface { // Hasher is the tool used to calculate the hash of derivable list. type Hasher interface { Reset() - Update([]byte, []byte) + Update([]byte, []byte) error Hash() common.Hash } diff --git a/trie/trie.go b/trie/trie.go index e630a6a546..680e8c4e93 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -316,10 +316,13 @@ func (t *Trie) tryGetBestRightKeyAndValue(origNode Node, prefix []byte) (key []b // // The value bytes must not be modified by the caller while they are // stored in the trie. -func (t *Trie) Update(key, value []byte) { +func (t *Trie) Update(key, value []byte) error { if err := t.TryUpdate(key, value); err != nil { log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + return err } + + return nil } // TryUpdate associates key with value in the trie. 
Subsequent calls to From ae6ca826cc46e122dedacb29db0f1435cc207239 Mon Sep 17 00:00:00 2001 From: c98tristan Date: Mon, 24 Jul 2023 11:03:27 +0700 Subject: [PATCH 046/119] Chore: Change GetRlp to EncodeIndex in LendingTransaction and OrderTransaction --- core/types/lending_transaction.go | 13 ++++++++----- core/types/order_transaction.go | 11 +++++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/core/types/lending_transaction.go b/core/types/lending_transaction.go index e33826829c..7152461117 100644 --- a/core/types/lending_transaction.go +++ b/core/types/lending_transaction.go @@ -17,6 +17,7 @@ package types import ( + "bytes" "container/heap" "errors" "io" @@ -319,10 +320,12 @@ func (s LendingTransactions) Len() int { return len(s) } // Swap swaps the i'th and the j'th element in s. func (s LendingTransactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -// GetRlp implements Rlpable and returns the i'th element of s in rlp. -func (s LendingTransactions) GetRlp(i int) []byte { - enc, _ := rlp.EncodeToBytes(s[i]) - return enc +// EncodeIndex encodes the i'th transaction to w. Note that this does not check for errors +// because we assume that *Transaction will only ever contain valid txs that were either +// constructed by decoding or via public API in this package. +func (s LendingTransactions) EncodeIndex(i int, w *bytes.Buffer) { + tx := s[i] + rlp.Encode(w, tx.data) } // LendingTxDifference returns a new set t which is the difference between a to b. @@ -363,7 +366,7 @@ func (s *LendingTxByNonce) Pop() interface{} { return x } -//LendingTransactionByNonce sort transaction by nonce +// LendingTransactionByNonce sort transaction by nonce type LendingTransactionByNonce struct { txs map[common.Address]LendingTransactions heads LendingTxByNonce diff --git a/core/types/order_transaction.go b/core/types/order_transaction.go index d51884e3f5..e7150b991e 100644 --- a/core/types/order_transaction.go +++ b/core/types/order_transaction.go @@ -17,6 +17,7 @@ package types import ( + "bytes" "container/heap" "errors" "io" @@ -250,10 +251,12 @@ func (s OrderTransactions) Len() int { return len(s) } // Swap swaps the i'th and the j'th element in s. func (s OrderTransactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -// GetRlp implements Rlpable and returns the i'th element of s in rlp. -func (s OrderTransactions) GetRlp(i int) []byte { - enc, _ := rlp.EncodeToBytes(s[i]) - return enc +// EncodeIndex encodes the i'th transaction to w. Note that this does not check for errors +// because we assume that *Transaction will only ever contain valid txs that were either +// constructed by decoding or via public API in this package. +func (s OrderTransactions) EncodeIndex(i int, w *bytes.Buffer) { + tx := s[i] + rlp.Encode(w, tx.data) } // OrderTxDifference returns a new set t which is the difference between a to b. 
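Both conversions above follow the buffer contract that the next patch formalizes on DerivableList: the caller supplies the buffer, may reuse it across items, and must treat its contents as valid only until the next call. A hedged usage sketch (OrderTransactions stands in for any such list):

package main

import (
	"bytes"
	"fmt"

	"github.com/tomochain/tomochain/core/types"
)

// dumpEncoded re-encodes every element of the list, reusing one buffer
// across iterations as EncodeIndex permits.
func dumpEncoded(list types.OrderTransactions) {
	var buf bytes.Buffer
	for i := 0; i < list.Len(); i++ {
		buf.Reset() // stale bytes from item i-1 must not leak into item i
		list.EncodeIndex(i, &buf)
		fmt.Printf("item %d: %d RLP bytes\n", i, buf.Len())
	}
}

func main() {
	dumpEncoded(types.OrderTransactions{})
}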
From 515a614188b4509f623cb8dca35d5eced97515fe Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Wed, 19 Jul 2023 14:47:39 +0700 Subject: [PATCH 047/119] Update DeriveSha with new Hasher and DerivableList interface --- core/types/derive_sha.go | 42 +++++++++++++++++++++++++++++++++------ core/types/hashing.go | 11 ++++++++++ core/types/receipt.go | 12 +++++------ core/types/transaction.go | 15 ++++++++------ trie/database.go | 1 - trie/trie.go | 1 - 6 files changed, 61 insertions(+), 21 deletions(-) create mode 100644 core/types/hashing.go diff --git a/core/types/derive_sha.go b/core/types/derive_sha.go index fecd0775b3..210ee26e06 100644 --- a/core/types/derive_sha.go +++ b/core/types/derive_sha.go @@ -26,7 +26,7 @@ import ( // DerivableList is the interface which can derive the hash. type DerivableList interface { Len() int - GetRlp(i int) []byte + EncodeIndex(int, *bytes.Buffer) } // Hasher is the tool used to calculate the hash of derivable list. @@ -36,13 +36,43 @@ type Hasher interface { Hash() common.Hash } +func encodeForDerive(list DerivableList, i int, buf *bytes.Buffer) []byte { + buf.Reset() + list.EncodeIndex(i, buf) + // It's really unfortunate that we need to do perform this copy. + // StackTrie holds onto the values until Hash is called, so the values + // written to it must not alias. + return common.CopyBytes(buf.Bytes()) +} + +// DeriveSha creates the tree hashes of transactions, receipts, and withdrawals in a block header. func DeriveSha(list DerivableList, hasher Hasher) common.Hash { hasher.Reset() - keybuf := new(bytes.Buffer) - for i := 0; i < list.Len(); i++ { - keybuf.Reset() - rlp.Encode(keybuf, uint(i)) - hasher.Update(keybuf.Bytes(), list.GetRlp(i)) + + valueBuf := encodeBufferPool.Get().(*bytes.Buffer) + defer encodeBufferPool.Put(valueBuf) + + // StackTrie requires values to be inserted in increasing hash order, which is not the + // order that `list` provides hashes in. This insertion sequence ensures that the + // order is correct. + // + // The error returned by hasher is omitted because hasher will produce an incorrect + // hash in case any error occurs. + var indexBuf []byte + for i := 1; i < list.Len() && i <= 0x7f; i++ { + indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i)) + value := encodeForDerive(list, i, valueBuf) + hasher.Update(indexBuf, value) + } + if list.Len() > 0 { + indexBuf = rlp.AppendUint64(indexBuf[:0], 0) + value := encodeForDerive(list, 0, valueBuf) + hasher.Update(indexBuf, value) + } + for i := 0x80; i < list.Len(); i++ { + indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i)) + value := encodeForDerive(list, i, valueBuf) + hasher.Update(indexBuf, value) } return hasher.Hash() } diff --git a/core/types/hashing.go b/core/types/hashing.go new file mode 100644 index 0000000000..8b9cb92b94 --- /dev/null +++ b/core/types/hashing.go @@ -0,0 +1,11 @@ +package types + +import ( + "bytes" + "sync" +) + +// encodeBufferPool holds temporary encoder buffers for DeriveSha and TX encoding. +var encodeBufferPool = sync.Pool{ + New: func() interface{} { return new(bytes.Buffer) }, +} diff --git a/core/types/receipt.go b/core/types/receipt.go index 879aaf29c9..1fe49d7aef 100644 --- a/core/types/receipt.go +++ b/core/types/receipt.go @@ -206,11 +206,9 @@ type Receipts []*Receipt // Len returns the number of receipts in this list. func (r Receipts) Len() int { return len(r) } -// GetRlp returns the RLP encoding of one receipt from the list. 
-func (r Receipts) GetRlp(i int) []byte { - bytes, err := rlp.EncodeToBytes(r[i]) - if err != nil { - panic(err) - } - return bytes +// EncodeIndex encodes the i'th receipt to w. +func (rs Receipts) EncodeIndex(i int, w *bytes.Buffer) { + r := rs[i] + data := &receiptRLP{r.statusEncoding(), r.CumulativeGasUsed, r.Bloom, r.Logs} + rlp.Encode(w, data) } diff --git a/core/types/transaction.go b/core/types/transaction.go index cf546c4420..5b155ac6af 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -17,6 +17,7 @@ package types import ( + "bytes" "container/heap" "errors" "fmt" @@ -523,15 +524,17 @@ type Transactions []*Transaction // Len returns the length of s. func (s Transactions) Len() int { return len(s) } +// EncodeIndex encodes the i'th transaction to w. Note that this does not check for errors +// because we assume that *Transaction will only ever contain valid txs that were either +// constructed by decoding or via public API in this package. +func (s Transactions) EncodeIndex(i int, w *bytes.Buffer) { + tx := s[i] + rlp.Encode(w, tx.data) +} + // Swap swaps the i'th and the j'th element in s. func (s Transactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -// GetRlp implements Rlpable and returns the i'th element of s in rlp. -func (s Transactions) GetRlp(i int) []byte { - enc, _ := rlp.EncodeToBytes(s[i]) - return enc -} - // TxDifference returns a new set t which is the difference between a to b. func TxDifference(a, b Transactions) (keep Transactions) { keep = make(Transactions, 0, len(a)) diff --git a/trie/database.go b/trie/database.go index 879c8be3bc..43ce266371 100644 --- a/trie/database.go +++ b/trie/database.go @@ -171,7 +171,6 @@ func (n *cachedNode) rlp() []byte { if node, ok := n.node.(rawNode); ok { return node } - fmt.Println("cachedNode rlp", n.node) blob, err := rlp.EncodeToBytes(n.node) if err != nil { panic(err) diff --git a/trie/trie.go b/trie/trie.go index 680e8c4e93..a0c627d232 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -321,7 +321,6 @@ func (t *Trie) Update(key, value []byte) error { log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) return err } - return nil } From 44d2c5469db5268ea2e1372ee3b2ff2913888403 Mon Sep 17 00:00:00 2001 From: c98tristan Date: Mon, 24 Jul 2023 11:52:44 +0700 Subject: [PATCH 048/119] Chore: Add intPool in interpreter.go --- core/vm/interpreter.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index 1b86bf37ce..36027be797 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -75,6 +75,8 @@ type EVMInterpreter struct { evm *EVM cfg Config + intPool *intPool + hasher crypto.KeccakState // Keccak256 hasher instance shared across opcodes hasherBuf common.Hash // Keccak256 hasher result array shared across opcodes From 72ca227ae786a81c82aac57866b975d45c88b96b Mon Sep 17 00:00:00 2001 From: c98tristan Date: Tue, 25 Jul 2023 14:20:38 +0700 Subject: [PATCH 049/119] Chore: Change parameter of NewBlock function from Trie to Stacktrie --- consensus/clique/clique.go | 2 +- consensus/ethash/consensus.go | 2 +- consensus/posv/posv.go | 2 +- core/block_validator.go | 4 ++-- core/blockchain_test.go | 4 ++-- core/database_util_test.go | 6 +++--- core/genesis.go | 2 +- core/tx_pool_test.go | 2 +- core/types/block_test.go | 5 ++--- eth/downloader/queue.go | 4 ++-- eth/fetcher/fetcher.go | 2 +- eth/fetcher/fetcher_test.go | 2 +- les/odr_requests.go | 4 ++-- trie/database.go | 2 +- 14 files changed, 21 insertions(+), 22 deletions(-) diff --git 
a/consensus/clique/clique.go b/consensus/clique/clique.go index 28610581d2..5c03e332cc 100644 --- a/consensus/clique/clique.go +++ b/consensus/clique/clique.go @@ -576,7 +576,7 @@ func (c *Clique) Finalize(chain consensus.ChainReader, header *types.Header, sta header.UncleHash = types.CalcUncleHash(nil) // Assemble and return the final block for sealing - return types.NewBlock(header, txs, nil, receipts, new(trie.Trie)), nil + return types.NewBlock(header, txs, nil, receipts, new(trie.StackTrie)), nil } // Authorize injects a private key into the consensus engine to mint new blocks diff --git a/consensus/ethash/consensus.go b/consensus/ethash/consensus.go index ee90ca7bac..7064569927 100644 --- a/consensus/ethash/consensus.go +++ b/consensus/ethash/consensus.go @@ -520,7 +520,7 @@ func (ethash *Ethash) Finalize(chain consensus.ChainReader, header *types.Header header.Root = state.IntermediateRoot(chain.Config().IsEIP158(header.Number)) // Header seems complete, assemble into a block and return - return types.NewBlock(header, txs, uncles, receipts, new(trie.Trie)), nil + return types.NewBlock(header, txs, uncles, receipts, new(trie.StackTrie)), nil } // Some weird constants to avoid constant memory allocs for them. diff --git a/consensus/posv/posv.go b/consensus/posv/posv.go index fa31a2d605..180a2ab956 100644 --- a/consensus/posv/posv.go +++ b/consensus/posv/posv.go @@ -986,7 +986,7 @@ func (c *Posv) Finalize(chain consensus.ChainReader, header *types.Header, state header.UncleHash = types.CalcUncleHash(nil) // Assemble and return the final block for sealing - return types.NewBlock(header, txs, nil, receipts, new(trie.Trie)), nil + return types.NewBlock(header, txs, nil, receipts, new(trie.StackTrie)), nil } // Authorize injects a private key into the consensus engine to mint new blocks diff --git a/core/block_validator.go b/core/block_validator.go index 384fec62f0..63e3f54383 100644 --- a/core/block_validator.go +++ b/core/block_validator.go @@ -73,7 +73,7 @@ func (v *BlockValidator) ValidateBody(block *types.Block) error { if hash := types.CalcUncleHash(block.Uncles()); hash != header.UncleHash { return fmt.Errorf("uncle root hash mismatch: have %x, want %x", hash, header.UncleHash) } - if hash := types.DeriveSha(block.Transactions(), new(trie.Trie)); hash != header.TxHash { + if hash := types.DeriveSha(block.Transactions(), new(trie.StackTrie)); hash != header.TxHash { return fmt.Errorf("transaction root hash mismatch: have %x, want %x", hash, header.TxHash) } return nil @@ -95,7 +95,7 @@ func (v *BlockValidator) ValidateState(block, parent *types.Block, statedb *stat return fmt.Errorf("invalid bloom (remote: %x local: %x)", header.Bloom, rbloom) } // Tre receipt Trie's root (R = (Tr [[H1, R1], ... 
[Hn, R1]])) - receiptSha := types.DeriveSha(receipts, new(trie.Trie)) + receiptSha := types.DeriveSha(receipts, new(trie.StackTrie)) if receiptSha != header.ReceiptHash { return fmt.Errorf("invalid receipt root hash (remote: %x local: %x)", header.ReceiptHash, receiptSha) } diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 124880ba87..f9132863b4 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -619,12 +619,12 @@ func TestFastVsFullChains(t *testing.T) { } if fblock, ablock := fast.GetBlockByHash(hash), archive.GetBlockByHash(hash); fblock.Hash() != ablock.Hash() { t.Errorf("block #%d [%x]: block mismatch: have %v, want %v", num, hash, fblock, ablock) - } else if types.DeriveSha(fblock.Transactions(), new(trie.Trie)) != types.DeriveSha(ablock.Transactions(), new(trie.Trie)) { + } else if types.DeriveSha(fblock.Transactions(), new(trie.StackTrie)) != types.DeriveSha(ablock.Transactions(), new(trie.StackTrie)) { t.Errorf("block #%d [%x]: transactions mismatch: have %v, want %v", num, hash, fblock.Transactions(), ablock.Transactions()) } else if types.CalcUncleHash(fblock.Uncles()) != types.CalcUncleHash(ablock.Uncles()) { t.Errorf("block #%d [%x]: uncles mismatch: have %v, want %v", num, hash, fblock.Uncles(), ablock.Uncles()) } - if freceipts, areceipts := GetBlockReceipts(fastDb, hash, GetBlockNumber(fastDb, hash)), GetBlockReceipts(archiveDb, hash, GetBlockNumber(archiveDb, hash)); types.DeriveSha(freceipts, new(trie.Trie)) != types.DeriveSha(areceipts, new(trie.Trie)) { + if freceipts, areceipts := GetBlockReceipts(fastDb, hash, GetBlockNumber(fastDb, hash)), GetBlockReceipts(archiveDb, hash, GetBlockNumber(archiveDb, hash)); types.DeriveSha(freceipts, new(trie.StackTrie)) != types.DeriveSha(areceipts, new(trie.StackTrie)) { t.Errorf("block #%d [%x]: receipts mismatch: have %v, want %v", num, hash, freceipts, areceipts) } } diff --git a/core/database_util_test.go b/core/database_util_test.go index d15d5d2e53..0d61403f05 100644 --- a/core/database_util_test.go +++ b/core/database_util_test.go @@ -85,7 +85,7 @@ func TestBodyStorage(t *testing.T) { } if entry := GetBody(db, hash, 0); entry == nil { t.Fatalf("Stored body not found") - } else if types.DeriveSha(types.Transactions(entry.Transactions), new(trie.Trie)) != types.DeriveSha(types.Transactions(body.Transactions), new(trie.Trie)) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(body.Uncles) { + } else if types.DeriveSha(types.Transactions(entry.Transactions), new(trie.StackTrie)) != types.DeriveSha(types.Transactions(body.Transactions), new(trie.StackTrie)) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(body.Uncles) { t.Fatalf("Retrieved body mismatch: have %v, want %v", entry, body) } if entry := GetBodyRLP(db, hash, 0); entry == nil { @@ -141,7 +141,7 @@ func TestBlockStorage(t *testing.T) { } if entry := GetBody(db, block.Hash(), block.NumberU64()); entry == nil { t.Fatalf("Stored body not found") - } else if types.DeriveSha(types.Transactions(entry.Transactions), new(trie.Trie)) != types.DeriveSha(block.Transactions(), new(trie.Trie)) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(block.Uncles()) { + } else if types.DeriveSha(types.Transactions(entry.Transactions), new(trie.StackTrie)) != types.DeriveSha(block.Transactions(), new(trie.StackTrie)) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(block.Uncles()) { t.Fatalf("Retrieved body mismatch: have %v, want %v", entry, block.Body()) } // Delete the block and verify the execution @@ -297,7 
+297,7 @@ func TestLookupStorage(t *testing.T) { tx3 := types.NewTransaction(3, common.BytesToAddress([]byte{0x33}), big.NewInt(333), 3333, big.NewInt(33333), []byte{0x33, 0x33, 0x33}) txs := []*types.Transaction{tx1, tx2, tx3} - block := types.NewBlock(&types.Header{Number: big.NewInt(314)}, txs, nil, nil, new(trie.Trie)) + block := types.NewBlock(&types.Header{Number: big.NewInt(314)}, txs, nil, nil, new(trie.StackTrie)) // Check that no transactions entries are in a pristine database for i, tx := range txs { diff --git a/core/genesis.go b/core/genesis.go index 4f417d64ba..6ee2550d9c 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -260,7 +260,7 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block { statedb.Commit(false) statedb.Database().TrieDB().Commit(root, true) - return types.NewBlock(head, nil, nil, nil, new(trie.Trie)) + return types.NewBlock(head, nil, nil, nil, new(trie.StackTrie)) } // Commit writes the block and state of a genesis specification to the database. diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index c458d0d35a..7a5e9f213a 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -72,7 +72,7 @@ func (bc *testBlockChain) Config() *params.ChainConfig { func (bc *testBlockChain) CurrentBlock() *types.Block { return types.NewBlock(&types.Header{ GasLimit: bc.gasLimit, - }, nil, nil, nil, new(trie.Trie)) + }, nil, nil, nil, new(trie.StackTrie)) } func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block { diff --git a/core/types/block_test.go b/core/types/block_test.go index 460dc35ba6..e93ae02de8 100644 --- a/core/types/block_test.go +++ b/core/types/block_test.go @@ -17,12 +17,11 @@ package types import ( + "bytes" "hash" "math/big" - "testing" - - "bytes" "reflect" + "testing" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/rlp" diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index 38a3f839f5..43569da2df 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -768,7 +768,7 @@ func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLi defer q.lock.Unlock() reconstruct := func(header *types.Header, index int, result *fetchResult) error { - if types.DeriveSha(types.Transactions(txLists[index]), new(trie.Trie)) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash { + if types.DeriveSha(types.Transactions(txLists[index]), new(trie.StackTrie)) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash { return errInvalidBody } result.Transactions = txLists[index] @@ -786,7 +786,7 @@ func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) (int, defer q.lock.Unlock() reconstruct := func(header *types.Header, index int, result *fetchResult) error { - if types.DeriveSha(types.Receipts(receiptList[index]), new(trie.Trie)) != header.ReceiptHash { + if types.DeriveSha(types.Receipts(receiptList[index]), new(trie.StackTrie)) != header.ReceiptHash { return errInvalidReceipt } result.Receipts = receiptList[index] diff --git a/eth/fetcher/fetcher.go b/eth/fetcher/fetcher.go index 6b3080ce13..d1bc108fd2 100644 --- a/eth/fetcher/fetcher.go +++ b/eth/fetcher/fetcher.go @@ -532,7 +532,7 @@ func (f *Fetcher) loop() { for hash, announce := range f.completing { if f.queued[hash] == nil { - txnHash := types.DeriveSha(types.Transactions(task.transactions[i]), new(trie.Trie)) + txnHash := types.DeriveSha(types.Transactions(task.transactions[i]), new(trie.StackTrie)) uncleHash := 
types.CalcUncleHash(task.uncles[i]) if txnHash == announce.header.TxHash && uncleHash == announce.header.UncleHash && announce.origin == task.peer { diff --git a/eth/fetcher/fetcher_test.go b/eth/fetcher/fetcher_test.go index 10acef3e6b..3a79ce1cb3 100644 --- a/eth/fetcher/fetcher_test.go +++ b/eth/fetcher/fetcher_test.go @@ -40,7 +40,7 @@ var ( testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") testAddress = crypto.PubkeyToAddress(testKey.PublicKey) genesis = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000)) - unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit}, nil, nil, nil, new(trie.Trie)) + unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit}, nil, nil, nil, new(trie.StackTrie)) ) // makeChain creates a chain of n blocks starting at and including parent. diff --git a/les/odr_requests.go b/les/odr_requests.go index f061331dbc..a4d24d6dde 100644 --- a/les/odr_requests.go +++ b/les/odr_requests.go @@ -114,7 +114,7 @@ func (r *BlockRequest) Validate(db ethdb.Database, msg *Msg) error { if header == nil { return errHeaderUnavailable } - if header.TxHash != types.DeriveSha(types.Transactions(body.Transactions), new(trie.Trie)) { + if header.TxHash != types.DeriveSha(types.Transactions(body.Transactions), new(trie.StackTrie)) { return errTxHashMismatch } if header.UncleHash != types.CalcUncleHash(body.Uncles) { @@ -170,7 +170,7 @@ func (r *ReceiptsRequest) Validate(db ethdb.Database, msg *Msg) error { if header == nil { return errHeaderUnavailable } - if header.ReceiptHash != types.DeriveSha(receipt, new(trie.Trie)) { + if header.ReceiptHash != types.DeriveSha(receipt, new(trie.StackTrie)) { return errReceiptHashMismatch } // Validations passed, store and return diff --git a/trie/database.go b/trie/database.go index 43ce266371..a1422dcb0a 100644 --- a/trie/database.go +++ b/trie/database.go @@ -790,7 +790,7 @@ func (db *Database) commit(hash common.Hash, batch ethdb.Batch, uncacher *cleane if err != nil { return err } - fmt.Println("commit", node) + if err := batch.Put(hash[:], node.rlp()); err != nil { return err } From 530cbb1e14912ffa7ca7c6627ffb09d54bf58332 Mon Sep 17 00:00:00 2001 From: c98tristan Date: Tue, 25 Jul 2023 16:55:29 +0700 Subject: [PATCH 050/119] Chore: Sorting imported library --- core/blockchain_test.go | 5 ++--- core/database_util_test.go | 5 ++--- core/genesis.go | 5 ++--- core/tx_pool_test.go | 5 ++--- eth/fetcher/fetcher.go | 1 - eth/fetcher/fetcher_test.go | 5 ++--- 6 files changed, 10 insertions(+), 16 deletions(-) diff --git a/core/blockchain_test.go b/core/blockchain_test.go index f9132863b4..0035058801 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -24,16 +24,15 @@ import ( "testing" "time" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/trie" - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/trie" ) // Test fork of length N starting from block i diff --git a/core/database_util_test.go b/core/database_util_test.go index 0d61403f05..f0ae8d520b 100644 --- a/core/database_util_test.go +++ b/core/database_util_test.go @@ -21,13 +21,12 @@ import ( "math/big" "testing" - 
"github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/trie" - "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto/sha3" "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/trie" ) // Tests block header storage and retrieval operations. diff --git a/core/genesis.go b/core/genesis.go index 6ee2550d9c..7cc74e64ee 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -25,18 +25,17 @@ import ( "math/big" "strings" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/trie" - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/trie" ) //go:generate gencodec -type Genesis -field-override genesisSpecMarshaling -out gen_genesis.go diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 7a5e9f213a..8ddb0650ea 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -26,16 +26,15 @@ import ( "testing" "time" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/trie" - - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/trie" ) // testTxPoolConfig is a transaction pool configuration without stateful disk diff --git a/eth/fetcher/fetcher.go b/eth/fetcher/fetcher.go index d1bc108fd2..142089586c 100644 --- a/eth/fetcher/fetcher.go +++ b/eth/fetcher/fetcher.go @@ -23,7 +23,6 @@ import ( "time" lru "github.com/hashicorp/golang-lru" - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core/types" diff --git a/eth/fetcher/fetcher_test.go b/eth/fetcher/fetcher_test.go index 3a79ce1cb3..951b2fcd6c 100644 --- a/eth/fetcher/fetcher_test.go +++ b/eth/fetcher/fetcher_test.go @@ -24,15 +24,14 @@ import ( "testing" "time" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/trie" - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/trie" ) var ( From 260f47e834e421e1cbf7b85a4418b543341d0e4a Mon Sep 17 00:00:00 2001 From: c98tristan Date: Tue, 25 Jul 2023 17:02:00 +0700 Subject: [PATCH 051/119] Chore: Sorting imported library --- core/database_util.go | 3 +-- eth/filters/bench_test.go | 10 +++++----- rlp/decode.go | 3 +-- rlp/encode.go | 3 +-- rlp/encode_test.go | 3 +-- 5 files changed, 9 insertions(+), 13 deletions(-) diff --git a/core/database_util.go b/core/database_util.go index a668434f16..f501134392 100644 --- a/core/database_util.go +++ b/core/database_util.go @@ -24,9 +24,8 @@ import ( "fmt" "math/big" - 
"github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" diff --git a/eth/filters/bench_test.go b/eth/filters/bench_test.go index 3648a3db2f..d1b4820343 100644 --- a/eth/filters/bench_test.go +++ b/eth/filters/bench_test.go @@ -20,7 +20,6 @@ import ( "bytes" "context" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "testing" "time" @@ -28,6 +27,7 @@ import ( "github.com/tomochain/tomochain/common/bitutil" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" @@ -68,7 +68,7 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) { benchDataDir := node.DefaultDataDir() + "/geth/chaindata" fmt.Println("Running bloombits benchmark section size:", sectionSize) - db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024,"") + db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024, "") if err != nil { b.Fatalf("error opening database at %v: %v", benchDataDir, err) } @@ -130,7 +130,7 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) { for i := 0; i < benchFilterCnt; i++ { if i%20 == 0 { db.Close() - db, _ = rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024,"") + db, _ = rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024, "") backend = &testBackend{mux, db, cnt, new(event.Feed), new(event.Feed), new(event.Feed), new(event.Feed)} } var addr common.Address @@ -148,7 +148,7 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) { } func forEachKey(db ethdb.Database, startPrefix, endPrefix []byte, fn func(key []byte)) { - it := db.NewIterator(startPrefix,nil) + it := db.NewIterator(startPrefix, nil) for it.Next() { key := it.Key() cmpLen := len(key) @@ -176,7 +176,7 @@ func clearBloomBits(db ethdb.Database) { func BenchmarkNoBloomBits(b *testing.B) { benchDataDir := node.DefaultDataDir() + "/geth/chaindata" fmt.Println("Running benchmark without bloombits") - db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024,"") + db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024, "") if err != nil { b.Fatalf("error opening database at %v: %v", benchDataDir, err) } diff --git a/rlp/decode.go b/rlp/decode.go index 20c454ca9c..ac93c139a9 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -28,9 +28,8 @@ import ( "strings" "sync" - "github.com/tomochain/tomochain/rlp/internal/rlpstruct" - "github.com/holiman/uint256" + "github.com/tomochain/tomochain/rlp/internal/rlpstruct" ) //lint:ignore ST1012 EOL is not an error. 
diff --git a/rlp/encode.go b/rlp/encode.go
index f34be7f3df..2ca283c0a3 100644
--- a/rlp/encode.go
+++ b/rlp/encode.go
@@ -23,9 +23,8 @@ import (
 	"math/big"
 	"reflect"
 
-	"github.com/tomochain/tomochain/rlp/internal/rlpstruct"
-
 	"github.com/holiman/uint256"
+	"github.com/tomochain/tomochain/rlp/internal/rlpstruct"
 )
 
 var (
diff --git a/rlp/encode_test.go b/rlp/encode_test.go
index 7b8775c12b..9f2e6c38f9 100644
--- a/rlp/encode_test.go
+++ b/rlp/encode_test.go
@@ -26,9 +26,8 @@ import (
 	"sync"
 	"testing"
 
-	"github.com/tomochain/tomochain/common/math"
-
 	"github.com/holiman/uint256"
+	"github.com/tomochain/tomochain/common/math"
 )
 
 type testEncoder struct {
From b1cb0fc8c15176c021804a5f418d3a50b1abdd32 Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Tue, 25 Jul 2023 17:36:25 +0700
Subject: [PATCH 052/119] Resolve conflicts after merge

---
 core/bench_test.go            | 2 --
 core/rawdb/accessors_chain.go | 1 -
 core/types/block.go           | 5 -----
 3 files changed, 8 deletions(-)

diff --git a/core/bench_test.go b/core/bench_test.go
index e6698510da..9af7791f5d 100644
--- a/core/bench_test.go
+++ b/core/bench_test.go
@@ -23,8 +23,6 @@ import (
 	"os"
 	"testing"
 
-	"github.com/tomochain/tomochain/core/rawdb"
-
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/common/math"
 	"github.com/tomochain/tomochain/consensus/ethash"
diff --git a/core/rawdb/accessors_chain.go b/core/rawdb/accessors_chain.go
index 5817395c06..e80153e530 100644
--- a/core/rawdb/accessors_chain.go
+++ b/core/rawdb/accessors_chain.go
@@ -23,7 +23,6 @@ import (
 	"math/big"
 
 	"github.com/tomochain/tomochain/common"
-	"github.com/tomochain/tomochain/core/rawdb"
 	"github.com/tomochain/tomochain/core/types"
 	"github.com/tomochain/tomochain/ethdb"
 	"github.com/tomochain/tomochain/log"
diff --git a/core/types/block.go b/core/types/block.go
index 0f06e935e6..66baecf2cf 100644
--- a/core/types/block.go
+++ b/core/types/block.go
@@ -33,11 +33,6 @@ import (
 	"github.com/tomochain/tomochain/rlp"
 )
 
-var (
-	EmptyRootHash  = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
-	EmptyUncleHash = CalcUncleHash(nil)
-)
-
 // A BlockNonce is a 64-bit hash which proves (combined with the
 // mix-hash) that a sufficient amount of computation has been carried
 // out on a block.
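The EmptyRootHash/EmptyUncleHash vars deleted above were duplicate declarations; presumably the canonical copies live elsewhere in the tree after the merge. A minimal sketch recomputing both values from the deleted declarations (it assumes only common.HexToHash and types.CalcUncleHash, both used elsewhere in this series):

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/core/types"
)

func main() {
	// Root hash of the empty trie, as hard-coded in the removed declaration.
	emptyRoot := common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
	// Hash of an empty uncle list, matching the removed EmptyUncleHash.
	emptyUncles := types.CalcUncleHash(nil)
	fmt.Println(emptyRoot.Hex(), emptyUncles.Hex())
}
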
From 5e8bde341f11d2a3901d3a337738e6edb631629e Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 31 Jul 2023 14:28:00 +0700 Subject: [PATCH 053/119] Refactor some rawdb methods --- consensus/posv/posv.go | 2 ++ core/blockchain_test.go | 2 +- core/rawdb/accessors_chain.go | 28 --------------- core/rawdb/accessors_state.go | 58 +++++++++++++++++++++++++++++++ core/rawdb/accessors_trie.go | 64 +++++++++++++++++++++++++++++++++++ 5 files changed, 125 insertions(+), 29 deletions(-) create mode 100644 core/rawdb/accessors_state.go create mode 100644 core/rawdb/accessors_trie.go diff --git a/consensus/posv/posv.go b/consensus/posv/posv.go index f7531d8151..30a4d2a51c 100644 --- a/consensus/posv/posv.go +++ b/consensus/posv/posv.go @@ -32,6 +32,7 @@ import ( "time" lru "github.com/hashicorp/golang-lru" + "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" @@ -49,6 +50,7 @@ import ( "github.com/tomochain/tomochain/rpc" "github.com/tomochain/tomochain/tomox/tradingstate" "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "github.com/tomochain/tomochain/trie" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 2ec70a80ef..d8d845ac10 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -623,7 +623,7 @@ func TestFastVsFullChains(t *testing.T) { } else if types.CalcUncleHash(fblock.Uncles()) != types.CalcUncleHash(ablock.Uncles()) { t.Errorf("block #%d [%x]: uncles mismatch: have %v, want %v", num, hash, fblock.Uncles(), ablock.Uncles()) } - if freceipts, areceipts := GetBlockReceipts(fastDb, hash, GetBlockNumber(fastDb, hash)), GetBlockReceipts(archiveDb, hash, GetBlockNumber(archiveDb, hash)); types.DeriveSha(freceipts, new(trie.StackTrie)) != types.DeriveSha(areceipts, new(trie.StackTrie)) { + if freceipts, areceipts := rawdb.GetBlockReceipts(fastDb, hash, rawdb.GetBlockNumber(fastDb, hash)), rawdb.GetBlockReceipts(archiveDb, hash, rawdb.GetBlockNumber(archiveDb, hash)); types.DeriveSha(freceipts, new(trie.StackTrie)) != types.DeriveSha(areceipts, new(trie.StackTrie)) { t.Errorf("block #%d [%x]: receipts mismatch: have %v, want %v", num, hash, freceipts, areceipts) } } diff --git a/core/rawdb/accessors_chain.go b/core/rawdb/accessors_chain.go index e80153e530..aea37e500c 100644 --- a/core/rawdb/accessors_chain.go +++ b/core/rawdb/accessors_chain.go @@ -19,7 +19,6 @@ package rawdb import ( "bytes" "encoding/binary" - "fmt" "math/big" "github.com/tomochain/tomochain/common" @@ -365,33 +364,6 @@ func DeleteBlockReceipts(db DatabaseDeleter, hash common.Hash, number uint64) { db.Delete(blockReceiptsKey(number, hash)) } -// PreimageTable returns a Database instance with the key prefix for preimage entries. -func PreimageTable(db ethdb.Database) ethdb.Database { - return NewTable(db, preimagePrefix) -} - -// WritePreimages writes the provided set of preimages to the database. `number` is the -// current block number, and is used for debug messages only. 
-func WritePreimages(db ethdb.Database, number uint64, preimages map[common.Hash][]byte) error { - table := PreimageTable(db) - batch := table.NewBatch() - hitCount := 0 - for hash, preimage := range preimages { - if _, err := table.Get(hash.Bytes()); err != nil { - batch.Put(hash.Bytes(), preimage) - hitCount++ - } - } - preimageCounter.Inc(int64(len(preimages))) - preimageHitCounter.Inc(int64(hitCount)) - if hitCount > 0 { - if err := batch.Write(); err != nil { - return fmt.Errorf("preimage write fail for block %d: %v", number, err) - } - } - return nil -} - // FindCommonAncestor returns the last common ancestor of two block headers func FindCommonAncestor(db DatabaseReader, a, b *types.Header) *types.Header { for bn := b.Number.Uint64(); a.Number.Uint64() > bn; { diff --git a/core/rawdb/accessors_state.go b/core/rawdb/accessors_state.go new file mode 100644 index 0000000000..23048f0f8c --- /dev/null +++ b/core/rawdb/accessors_state.go @@ -0,0 +1,58 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rawdb + +import ( + "fmt" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/ethdb" +) + +// PreimageTable returns a Database instance with the key prefix for preimage entries. +func PreimageTable(db ethdb.Database) ethdb.Database { + return NewTable(db, preimagePrefix) +} + +// ReadPreimage retrieves a single preimage of the provided hash. +func ReadPreimage(db ethdb.Database, hash common.Hash) []byte { + table := PreimageTable(db) + data, _ := table.Get(hash.Bytes()) + return data +} + +// WritePreimages writes the provided set of preimages to the database. `number` is the +// current block number, and is used for debug messages only. +func WritePreimages(db ethdb.Database, number uint64, preimages map[common.Hash][]byte) error { + table := PreimageTable(db) + batch := table.NewBatch() + hitCount := 0 + for hash, preimage := range preimages { + if _, err := table.Get(hash.Bytes()); err != nil { + batch.Put(hash.Bytes(), preimage) + hitCount++ + } + } + preimageCounter.Inc(int64(len(preimages))) + preimageHitCounter.Inc(int64(hitCount)) + if hitCount > 0 { + if err := batch.Write(); err != nil { + return fmt.Errorf("preimage write fail for block %d: %v", number, err) + } + } + return nil +} diff --git a/core/rawdb/accessors_trie.go b/core/rawdb/accessors_trie.go new file mode 100644 index 0000000000..7e1bbcaa2f --- /dev/null +++ b/core/rawdb/accessors_trie.go @@ -0,0 +1,64 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see + +package rawdb + +import ( + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/log" +) + +// HashScheme is the legacy hash-based state scheme with which trie nodes are +// stored in the disk with node hash as the database key. The advantage of this +// scheme is that different versions of trie nodes can be stored in disk, which +// is very beneficial for constructing archive nodes. The drawback is it will +// store different trie nodes on the same path to different locations on the disk +// with no data locality, and it's unfriendly for designing state pruning. +// +// Now this scheme is still kept for backward compatibility, and it will be used +// for archive node and some other tries(e.g. light trie). +const HashScheme = "hashScheme" + +// ReadLegacyTrieNode retrieves the legacy trie node with the given +// associated node hash. +func ReadLegacyTrieNode(db ethdb.KeyValueReader, hash common.Hash) []byte { + data, err := db.Get(hash.Bytes()) + if err != nil { + return nil + } + return data +} + +// HasLegacyTrieNode checks if the trie node with the provided hash is present in db. +func HasLegacyTrieNode(db ethdb.KeyValueReader, hash common.Hash) bool { + ok, _ := db.Has(hash.Bytes()) + return ok +} + +// WriteLegacyTrieNode writes the provided legacy trie node to database. +func WriteLegacyTrieNode(db ethdb.KeyValueWriter, hash common.Hash, node []byte) { + if err := db.Put(hash.Bytes(), node); err != nil { + log.Crit("Failed to store legacy trie node", "err", err) + } +} + +// DeleteLegacyTrieNode deletes the specified legacy trie node from database. 
+func DeleteLegacyTrieNode(db ethdb.KeyValueWriter, hash common.Hash) { + if err := db.Delete(hash.Bytes()); err != nil { + log.Crit("Failed to delete legacy trie node", "err", err) + } +} From cb81493257c6e989e375ab74216db44119ff5cd4 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 31 Jul 2023 15:28:36 +0700 Subject: [PATCH 054/119] [WIP] Implement new trie interface and separate preimageStore --- cmd/gc/main.go | 6 +- core/state/database.go | 45 +++++-- core/state/dump.go | 2 +- core/state/state_object.go | 14 +-- core/state/statedb.go | 23 ++-- go.mod | 4 +- go.sum | 9 +- les/handler.go | 2 +- light/trie.go | 64 ++++++++-- tomox/tradingstate/tomox_trie.go | 14 +-- tomoxlending/lendingstate/tomox_trie.go | 14 +-- trie/database.go | 45 ++++--- trie/preimages.go | 94 ++++++++++++++ trie/proof.go | 26 ++-- trie/secure_trie.go | 155 +++++++++++++++++------- trie/sync_test.go | 21 ++-- trie/trie.go | 112 ++++++++++------- 17 files changed, 458 insertions(+), 192 deletions(-) create mode 100644 trie/preimages.go diff --git a/cmd/gc/main.go b/cmd/gc/main.go index 4f2e0a1291..af68df50ee 100644 --- a/cmd/gc/main.go +++ b/cmd/gc/main.go @@ -14,11 +14,9 @@ import ( "github.com/tomochain/tomochain/cmd/utils" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/eth" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/ethdb/leveldb" - "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" ) @@ -80,9 +78,7 @@ func main() { atomic.StoreInt32(&finish, 1) if running { for _, address := range cleanAddress { - enc := trieRoot.trie.Get(address.Bytes()) - var data types.StateAccount - rlp.DecodeBytes(enc, &data) + data, _ := trieRoot.trie.GetAccount(address) fmt.Println(time.Now().Format(time.RFC3339), "Start clean state address ", address.Hex(), " at block ", trieRoot.number) signerRoot, err := resolveHash(data.Root[:], db) if err != nil { diff --git a/core/state/database.go b/core/state/database.go index b57f134db8..8a678ccdc4 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -21,6 +21,7 @@ import ( lru "github.com/hashicorp/golang-lru" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/trie" ) @@ -59,20 +60,40 @@ type Trie interface { // TODO(fjl): remove this when SecureTrie is removed GetKey([]byte) []byte - // TryGet returns the value for key stored in the trie. The value bytes must - // not be modified by the caller. If a node was not found in the database, a - // trie.MissingNodeError is returned. - TryGet(key []byte) ([]byte, error) - - // TryUpdate associates key with value in the trie. If value has length zero, any - // existing value is deleted from the trie. The value bytes must not be modified + // GetStorage returns the value for key stored in the trie. The value bytes + // must not be modified by the caller. If a node was not found in the database, + // a trie.MissingNodeError is returned. + GetStorage(addr common.Address, key []byte) ([]byte, error) + + // GetAccount abstracts an account read from the trie. It retrieves the + // account blob from the trie with provided account address and decodes it + // with associated decoding algorithm. If the specified account is not in + // the trie, nil will be returned. If the trie is corrupted(e.g. 
some nodes + // are missing or the account blob is incorrect for decoding), an error will + // be returned. + GetAccount(address common.Address) (*types.StateAccount, error) + + // UpdateStorage associates key with value in the trie. If value has length zero, + // any existing value is deleted from the trie. The value bytes must not be modified // by the caller while they are stored in the trie. If a node was not found in the // database, a trie.MissingNodeError is returned. - TryUpdate(key, value []byte) error + UpdateStorage(addr common.Address, key, value []byte) error + + // UpdateAccount abstracts an account write to the trie. It encodes the + // provided account object with associated algorithm and then updates it + // in the trie with provided address. + UpdateAccount(address common.Address, account *types.StateAccount) error + + // UpdateContractCode abstracts code write to the trie. It is expected + // to be moved to the stateWriter interface when the latter is ready. + UpdateContractCode(address common.Address, codeHash common.Hash, code []byte) error + + // DeleteStorage removes any existing value for key from the trie. If a node + // was not found in the database, a trie.MissingNodeError is returned. + DeleteStorage(addr common.Address, key []byte) error - // TryDelete removes any existing value for key from the trie. If a node was not - // found in the database, a trie.MissingNodeError is returned. - TryDelete(key []byte) error + // DeleteAccount abstracts an account deletion from the trie. + DeleteAccount(address common.Address) error // Hash returns the root hash of the trie. It does not write to the database and // can be used even if the trie doesn't have one. @@ -109,7 +130,7 @@ func NewDatabase(db ethdb.Database) Database { func NewDatabaseWithCache(db ethdb.Database, cache int) Database { csc, _ := lru.New(codeSizeCacheSize) return &cachingDB{ - db: trie.NewDatabaseWithCache(db, cache), + db: trie.NewDatabaseWithCache(db, &trie.Config{Cache: cache, Preimages: true}), codeSizeCache: csc, } } diff --git a/core/state/dump.go b/core/state/dump.go index 7368146ca7..4e21044bbc 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -54,7 +54,7 @@ func (self *StateDB) RawDump() Dump { panic(err) } - obj := newObject(nil, common.BytesToAddress(addr), data, nil) + obj := newObject(nil, common.BytesToAddress(addr), &data, nil) account := DumpAccount{ Balance: data.Balance.String(), Nonce: data.Nonce, diff --git a/core/state/state_object.go b/core/state/state_object.go index 478823be58..bb40953b69 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -97,7 +97,7 @@ func (s *stateObject) empty() bool { } // newObject creates a state object. -func newObject(db *StateDB, address common.Address, data types.StateAccount, onDirty func(addr common.Address)) *stateObject { +func newObject(db *StateDB, address common.Address, data *types.StateAccount, onDirty func(addr common.Address)) *stateObject { if data.Balance == nil { data.Balance = new(big.Int) } @@ -108,7 +108,7 @@ func newObject(db *StateDB, address common.Address, data types.StateAccount, onD db: db, address: address, addrHash: crypto.Keccak256Hash(address[:]), - data: data, + data: *data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty, @@ -163,7 +163,7 @@ func (c *stateObject) getTrie(db Database) Trie { func (self *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { value := common.Hash{} // Load from DB in case it is missing. 
- enc, err := self.getTrie(db).TryGet(key[:]) + enc, err := self.getTrie(db).GetStorage(self.address, key.Bytes()) if err != nil { self.setError(err) return common.Hash{} @@ -184,7 +184,7 @@ func (self *stateObject) GetState(db Database, key common.Hash) common.Hash { return value } // Load from DB in case it is missing. - enc, err := self.getTrie(db).TryGet(key[:]) + enc, err := self.getTrie(db).GetStorage(self.address, key.Bytes()) if err != nil { self.setError(err) return common.Hash{} @@ -228,12 +228,12 @@ func (self *stateObject) updateTrie(db Database) Trie { for key, value := range self.dirtyStorage { delete(self.dirtyStorage, key) if (value == common.Hash{}) { - self.setError(tr.TryDelete(key[:])) + self.setError(tr.DeleteStorage(self.address, key[:])) continue } // Encoding []byte cannot fail, ok to ignore the error. v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) - self.setError(tr.TryUpdate(key[:], v)) + self.setError(tr.UpdateStorage(self.address, key[:], v)) } return tr } @@ -302,7 +302,7 @@ func (self *stateObject) setBalance(amount *big.Int) { func (c *stateObject) ReturnGas(gas *big.Int) {} func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject { - stateObject := newObject(db, self.address, self.data, onDirty) + stateObject := newObject(db, self.address, &self.data, onDirty) if self.trie != nil { stateObject.trie = db.db.CopyTrie(self.trie) } diff --git a/core/state/statedb.go b/core/state/statedb.go index 818d3d0aaa..6e17f493d7 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -26,7 +26,6 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" - "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" ) @@ -360,18 +359,18 @@ func (self *StateDB) Suicide(addr common.Address) bool { // updateStateObject writes the given object to the trie. func (self *StateDB) updateStateObject(stateObject *stateObject) { addr := stateObject.Address() - data, err := rlp.EncodeToBytes(stateObject) - if err != nil { - panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err)) + if err := self.trie.UpdateAccount(addr, &stateObject.data); err != nil { + self.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err)) } - self.setError(self.trie.TryUpdate(addr[:], data)) } // deleteStateObject removes the given object from the state trie. func (self *StateDB) deleteStateObject(stateObject *stateObject) { stateObject.deleted = true addr := stateObject.Address() - self.setError(self.trie.TryDelete(addr[:])) + if err := self.trie.DeleteAccount(addr); err != nil { + self.setError(fmt.Errorf("deleteStateObject (%x) error: %v", addr[:], err)) + } } // DeleteAddress removes the address from the state trie. @@ -393,14 +392,12 @@ func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObje } // Load the object from the database. - enc, err := self.trie.TryGet(addr[:]) - if len(enc) == 0 { - self.setError(err) + data, err := self.trie.GetAccount(addr) + if err != nil { + self.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", addr.Bytes(), err)) return nil } - var data types.StateAccount - if err := rlp.DecodeBytes(enc, &data); err != nil { - log.Error("Failed to decode state object", "addr", addr, "err", err) + if data == nil { return nil } // Insert into the live set. 
@@ -432,7 +429,7 @@ func (self *StateDB) MarkStateObjectDirty(addr common.Address) { // the given address, it is overwritten and returned as the second return value. func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) { prev = self.getStateObject(addr) - newobj = newObject(self, addr, types.StateAccount{}, self.MarkStateObjectDirty) + newobj = newObject(self, addr, &types.StateAccount{}, self.MarkStateObjectDirty) newobj.setNonce(0) // sets the object to dirty if prev == nil { self.journal = append(self.journal, createObjectChange{account: &addr}) diff --git a/go.mod b/go.mod index e5db078ba3..ae326fcb35 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( github.com/stretchr/testify v1.8.1 github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 golang.org/x/crypto v0.1.0 + golang.org/x/exp v0.0.0-20230728194245-b0cb94b80691 golang.org/x/net v0.8.0 golang.org/x/sync v0.1.0 golang.org/x/sys v0.7.0 @@ -55,6 +56,7 @@ require ( github.com/dlclark/regexp2 v1.7.0 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect + github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect github.com/google/uuid v1.3.0 // indirect github.com/kr/pretty v0.3.1 // indirect @@ -69,7 +71,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3 // indirect - golang.org/x/mod v0.9.0 // indirect + golang.org/x/mod v0.11.0 // indirect golang.org/x/term v0.6.0 // indirect golang.org/x/text v0.8.0 // indirect golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df // indirect diff --git a/go.sum b/go.sum index c913b65f5a..6699d53904 100644 --- a/go.sum +++ b/go.sum @@ -94,8 +94,9 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U= github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg= @@ -245,10 +246,12 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +golang.org/x/exp v0.0.0-20230728194245-b0cb94b80691 h1:/yRP+0AN7mf5DkD3BAI6TOFnd51gEoDEb8o35jIFtgw= +golang.org/x/exp v0.0.0-20230728194245-b0cb94b80691/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= -golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= +golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= diff --git a/les/handler.go b/les/handler.go index f354584e5b..a186190a63 100644 --- a/les/handler.go +++ b/les/handler.go @@ -1100,7 +1100,7 @@ func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common. if err != nil { return types.StateAccount{}, err } - blob, err := trie.TryGet(hash[:]) + blob, err := trie.Get(hash[:]) if err != nil { return types.StateAccount{}, err } diff --git a/light/trie.go b/light/trie.go index d247f145ea..c469491ef8 100644 --- a/light/trie.go +++ b/light/trie.go @@ -26,6 +26,7 @@ import ( "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" ) @@ -95,27 +96,74 @@ type odrTrie struct { trie *trie.Trie } -func (t *odrTrie) TryGet(key []byte) ([]byte, error) { +func (t *odrTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { key = crypto.Keccak256(key) - var res []byte + var enc []byte err := t.do(key, func() (err error) { - res, err = t.trie.TryGet(key) + enc, err = t.trie.Get(key) return err }) - return res, err + if err != nil || len(enc) == 0 { + return nil, err + } + _, content, _, err := rlp.Split(enc) + return content, err +} + +func (t *odrTrie) GetAccount(address common.Address) (*types.StateAccount, error) { + var ( + enc []byte + key = crypto.Keccak256(address.Bytes()) + ) + err := t.do(key, func() (err error) { + enc, err = t.trie.Get(key) + return err + }) + if err != nil || len(enc) == 0 { + return nil, err + } + acct := new(types.StateAccount) + if err := rlp.DecodeBytes(enc, acct); err != nil { + return nil, err + } + return acct, nil +} + +func (t *odrTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error { + key := crypto.Keccak256(address.Bytes()) + value, err := rlp.EncodeToBytes(acc) + if err != nil { + return fmt.Errorf("decoding error in account update: %w", err) + } + return t.do(key, func() error { + return t.trie.Update(key, value) + }) +} + +func (t *odrTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error { + return nil } -func (t *odrTrie) TryUpdate(key, value []byte) error { +func (t *odrTrie) UpdateStorage(_ common.Address, key, value []byte) error { key = crypto.Keccak256(key) + v, _ := rlp.EncodeToBytes(value) return t.do(key, func() error { - return t.trie.TryDelete(key) + return t.trie.Update(key, v) }) } -func (t *odrTrie) TryDelete(key []byte) error { +func (t *odrTrie) DeleteStorage(_ common.Address, key []byte) error { key = crypto.Keccak256(key) return t.do(key, func() error { - return t.trie.TryDelete(key) + return t.trie.Delete(key) + }) +} + +// DeleteAccount abstracts an account deletion from the trie. 
+func (t *odrTrie) DeleteAccount(address common.Address) error {
+	key := crypto.Keccak256(address.Bytes())
+	return t.do(key, func() error {
+		return t.trie.Delete(key)
 	})
 }
 
diff --git a/tomox/tradingstate/tomox_trie.go b/tomox/tradingstate/tomox_trie.go
index 908648def9..197e50b4c0 100644
--- a/tomox/tradingstate/tomox_trie.go
+++ b/tomox/tradingstate/tomox_trie.go
@@ -18,11 +18,11 @@ package tradingstate
 
 import (
 	"fmt"
-	"github.com/tomochain/tomochain/ethdb"
-	"github.com/tomochain/tomochain/trie"
 
 	"github.com/tomochain/tomochain/common"
+	"github.com/tomochain/tomochain/ethdb"
 	"github.com/tomochain/tomochain/log"
+	"github.com/tomochain/tomochain/trie"
 )
 
 // TomoXTrie wraps a trie with key hashing. In a secure trie, all
@@ -78,10 +78,10 @@ func (t *TomoXTrie) Get(key []byte) []byte {
 // The value bytes must not be modified by the caller.
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *TomoXTrie) TryGet(key []byte) ([]byte, error) {
-	return t.trie.TryGet(key)
+	return t.trie.Get(key)
 }
 
-// TryGetBestLeftKey returns the value of max left leaf
+// TryGetBestLeftKeyAndValue returns the value of max left leaf
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *TomoXTrie) TryGetBestLeftKeyAndValue() ([]byte, []byte, error) {
 	return t.trie.TryGetBestLeftKeyAndValue()
@@ -91,7 +91,7 @@ func (t *TomoXTrie) TryGetAllLeftKeyAndValue(limit []byte) ([][]byte, [][]byte,
 	return t.trie.TryGetAllLeftKeyAndValue(limit)
 }
 
-// TryGetBestRightKey returns the value of max left leaf
+// TryGetBestRightKeyAndValue returns the value of max right leaf
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *TomoXTrie) TryGetBestRightKeyAndValue() ([]byte, []byte, error) {
 	return t.trie.TryGetBestRightKeyAndValue()
@@ -118,7 +118,7 @@ func (t *TomoXTrie) Update(key, value []byte) {
 //
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *TomoXTrie) TryUpdate(key, value []byte) error {
-	err := t.trie.TryUpdate(key, value)
+	err := t.trie.Update(key, value)
 	if err != nil {
 		return err
 	}
@@ -137,7 +137,7 @@ func (t *TomoXTrie) Delete(key []byte) {
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *TomoXTrie) TryDelete(key []byte) error {
 	delete(t.getSecKeyCache(), string(key))
-	return t.trie.TryDelete(key)
+	return t.trie.Delete(key)
 }
 
 // GetKey returns the sha3 preimage of a hashed key that was
diff --git a/tomoxlending/lendingstate/tomox_trie.go b/tomoxlending/lendingstate/tomox_trie.go
index 8ff0a5633a..2852139ae0 100644
--- a/tomoxlending/lendingstate/tomox_trie.go
+++ b/tomoxlending/lendingstate/tomox_trie.go
@@ -18,11 +18,11 @@ package lendingstate
 
 import (
 	"fmt"
-	"github.com/tomochain/tomochain/ethdb"
-	"github.com/tomochain/tomochain/trie"
 
 	"github.com/tomochain/tomochain/common"
+	"github.com/tomochain/tomochain/ethdb"
 	"github.com/tomochain/tomochain/log"
+	"github.com/tomochain/tomochain/trie"
 )
 
 // TomoXTrie wraps a trie with key hashing. In a secure trie, all
@@ -78,16 +78,16 @@ func (t *TomoXTrie) Get(key []byte) []byte {
 // The value bytes must not be modified by the caller.
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *TomoXTrie) TryGet(key []byte) ([]byte, error) {
-	return t.trie.TryGet(key)
+	return t.trie.Get(key)
 }
 
-// TryGetBestLeftKey returns the value of max left leaf
+// TryGetBestLeftKeyAndValue returns the value of max left leaf
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *TomoXTrie) TryGetBestLeftKeyAndValue() ([]byte, []byte, error) {
 	return t.trie.TryGetBestLeftKeyAndValue()
 }
 
-// TryGetBestRightKey returns the value of max left leaf
+// TryGetBestRightKeyAndValue returns the value of max right leaf
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *TomoXTrie) TryGetBestRightKeyAndValue() ([]byte, []byte, error) {
 	return t.trie.TryGetBestRightKeyAndValue()
@@ -114,7 +114,7 @@ func (t *TomoXTrie) Update(key, value []byte) {
 //
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *TomoXTrie) TryUpdate(key, value []byte) error {
-	err := t.trie.TryUpdate(key, value)
+	err := t.trie.Update(key, value)
 	if err != nil {
 		return err
 	}
@@ -133,7 +133,7 @@ func (t *TomoXTrie) Delete(key []byte) {
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *TomoXTrie) TryDelete(key []byte) error {
 	delete(t.getSecKeyCache(), string(key))
-	return t.trie.TryDelete(key)
+	return t.trie.Delete(key)
 }
 
 // GetKey returns the sha3 preimage of a hashed key that was
diff --git a/trie/database.go b/trie/database.go
index a1422dcb0a..6d9f3868f9 100644
--- a/trie/database.go
+++ b/trie/database.go
@@ -25,6 +25,7 @@ import (
 	"time"
 
 	"github.com/VictoriaMetrics/fastcache"
+
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/ethdb"
 	"github.com/tomochain/tomochain/log"
@@ -65,6 +66,12 @@ const secureKeyPrefixLength = 11
 // secureKeyLength is the length of the above prefix + 32byte hash.
 const secureKeyLength = secureKeyPrefixLength + 32
 
+// Config defines all necessary options for database.
+type Config struct {
+	Cache     int  // Memory allowance (MB) to use for caching trie nodes in memory
+	Preimages bool // Flag whether the preimage of trie key is recorded
+}
+
 // Database is an intermediate write layer between the trie data structures and
 // the disk database. The aim is to accumulate trie writes in-memory and only
 // periodically flush a couple tries to disk, garbage collecting the remainder.
@@ -74,6 +81,7 @@ const secureKeyLength = secureKeyPrefixLength + 32
 // behind this split design is to provide read access to RPC handlers and sync
 // servers even while the trie is executing expensive garbage collection.
 type Database struct {
+	config *Config             // Configuration for trie database
 	diskdb ethdb.KeyValueStore // Persistent storage for matured trie nodes
 
 	cleans  *fastcache.Cache            // GC friendly memory Cache of clean Node RLPs
@@ -81,7 +89,7 @@ type Database struct {
 	oldest  common.Hash                 // Oldest tracked Node, flush-list head
 	newest  common.Hash                 // Newest tracked Node, flush-list tail
 
-	preimages map[common.Hash][]byte // Preimages of nodes from the secure trie
+	preimages *preimageStore // The store for caching preimages
 
 	gctime  time.Duration      // Time spent on garbage collection since last commit
 	gcnodes uint64             // Nodes garbage collected since last commit
@@ -282,26 +290,32 @@ func expandNode(hash HashNode, n Node) Node {
 // NewDatabase creates a new trie database to store ephemeral trie content before
 // its written out to disk or garbage collected. No read Cache is created, so all
 // data retrievals will hit the underlying disk database.
-func NewDatabase(diskdb ethdb.KeyValueStore) *Database { - return NewDatabaseWithCache(diskdb, 0) +func NewDatabase(diskdb ethdb.Database) *Database { + return NewDatabaseWithCache(diskdb, &Config{Cache: 0}) } // NewDatabaseWithCache creates a new trie database to store ephemeral trie content // before its written out to disk or garbage collected. It also acts as a read Cache // for nodes loaded from disk. -func NewDatabaseWithCache(diskdb ethdb.KeyValueStore, cache int) *Database { +func NewDatabaseWithCache(diskdb ethdb.Database, config *Config) *Database { var cleans *fastcache.Cache - if cache > 0 { - cleans = fastcache.New(cache * 1024 * 1024) + if config.Cache > 0 { + cleans = fastcache.New(config.Cache * 1024 * 1024) } - return &Database{ + var preimages *preimageStore + if config != nil && config.Preimages { + preimages = newPreimageStore(diskdb) + } + db := &Database{ diskdb: diskdb, cleans: cleans, dirties: map[common.Hash]*cachedNode{{}: { children: make(map[common.Hash]uint16), }}, - preimages: make(map[common.Hash][]byte), + preimages: preimages, } + + return db } // DiskDB retrieves the persistent storage backing the trie database. @@ -357,11 +371,12 @@ func (db *Database) insert(hash common.Hash, size int, node Node) { // yet unknown. The method will make a copy of the slice. // // Note, this method assumes that the database's Lock is held! +// This function's still be kept because of TomoX tries func (db *Database) InsertPreimage(hash common.Hash, preimage []byte) { - if _, ok := db.preimages[hash]; ok { + if _, ok := db.preimages.preimages[hash]; ok { return } - db.preimages[hash] = common.CopyBytes(preimage) + db.preimages.preimages[hash] = common.CopyBytes(preimage) db.preimagesSize += common.StorageSize(common.HashLength + len(preimage)) } @@ -445,7 +460,7 @@ func (db *Database) Node(hash common.Hash) ([]byte, error) { func (db *Database) Preimage(hash common.Hash) ([]byte, error) { // Retrieve the Node from Cache if available db.Lock.RLock() - preimage := db.preimages[hash] + preimage := db.preimages.preimages[hash] db.Lock.RUnlock() if preimage != nil { @@ -612,7 +627,7 @@ func (db *Database) Cap(limit common.StorageSize) error { // leave for later to deduplicate writes. 
flushPreimages := db.preimagesSize > 4*1024*1024 if flushPreimages { - for hash, preimage := range db.preimages { + for hash, preimage := range db.preimages.preimages { copy(keyBuf[secureKeyPrefixLength:], hash[:]) if err := batch.Put(keyBuf[:], preimage); err != nil { log.Error("Failed to commit Preimage from trie database", "err", err) @@ -661,7 +676,7 @@ func (db *Database) Cap(limit common.StorageSize) error { defer db.Lock.Unlock() if flushPreimages { - db.preimages = make(map[common.Hash][]byte) + db.preimages.preimages = make(map[common.Hash][]byte) db.preimagesSize = 0 } for db.oldest != oldest { @@ -711,7 +726,7 @@ func (db *Database) Commit(node common.Hash, report bool) error { copy(keyBuf[:], secureKeyPrefix) // Move all of the accumulated preimages into a write batch - for hash, preimage := range db.preimages { + for hash, preimage := range db.preimages.preimages { copy(keyBuf[secureKeyPrefixLength:], hash[:]) if err := batch.Put(keyBuf[:], preimage); err != nil { log.Error("Failed to commit Preimage from trie database", "err", err) @@ -753,7 +768,7 @@ func (db *Database) Commit(node common.Hash, report bool) error { batch.Reset() // Reset the storage counters and bumpd metrics - db.preimages = make(map[common.Hash][]byte) + db.preimages.preimages = make(map[common.Hash][]byte) db.preimagesSize = 0 memcacheCommitTimeTimer.Update(time.Since(start)) diff --git a/trie/preimages.go b/trie/preimages.go new file mode 100644 index 0000000000..760f2290f4 --- /dev/null +++ b/trie/preimages.go @@ -0,0 +1,94 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "sync" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/ethdb" +) + +// preimageStore is the store for caching preimages of node key. +type preimageStore struct { + lock sync.RWMutex + disk ethdb.Database + preimages map[common.Hash][]byte // Preimages of nodes from the secure trie + preimagesSize common.StorageSize // Storage size of the preimages cache +} + +// newPreimageStore initializes the store for caching preimages. +func newPreimageStore(disk ethdb.Database) *preimageStore { + return &preimageStore{ + disk: disk, + preimages: make(map[common.Hash][]byte), + } +} + +// insertPreimage writes a new trie node pre-image to the memory database if it's +// yet unknown. The method will NOT make a copy of the slice, only use if the +// preimage will NOT be changed later on. 
+func (store *preimageStore) insertPreimage(preimages map[common.Hash][]byte) { + store.lock.Lock() + defer store.lock.Unlock() + + for hash, preimage := range preimages { + if _, ok := store.preimages[hash]; ok { + continue + } + store.preimages[hash] = preimage + store.preimagesSize += common.StorageSize(common.HashLength + len(preimage)) + } +} + +// preimage retrieves a cached trie node pre-image from memory. If it cannot be +// found cached, the method queries the persistent database for the content. +func (store *preimageStore) preimage(hash common.Hash) []byte { + store.lock.RLock() + preimage := store.preimages[hash] + store.lock.RUnlock() + + if preimage != nil { + return preimage + } + return rawdb.ReadPreimage(store.disk, hash) +} + +// commit flushes the cached preimages into the disk. +func (store *preimageStore) commit(force bool) error { + store.lock.Lock() + defer store.lock.Unlock() + + if store.preimagesSize <= 4*1024*1024 && !force { + return nil + } + if err := rawdb.WritePreimages(store.disk, 0, store.preimages); err != nil { + return err + } + + store.preimages, store.preimagesSize = make(map[common.Hash][]byte), 0 + return nil +} + +// size returns the current storage size of accumulated preimages. +func (store *preimageStore) size() common.StorageSize { + store.lock.RLock() + defer store.lock.RUnlock() + + return store.preimagesSize +} diff --git a/trie/proof.go b/trie/proof.go index 9e4082a27e..28320e8a06 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -22,8 +22,8 @@ import ( "fmt" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/ethdb" - "github.com/tomochain/tomochain/ethdb/memorydb" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/rlp" ) @@ -395,11 +395,11 @@ func hasRightElement(node Node, key []byte) bool { // Expect the normal case, this function can also be used to verify the following // range proofs(note this function doesn't accept zero element proof): // -// - All elements proof. In this case the left and right proof can be nil, but the -// range should be all the leaves in the trie. +// - All elements proof. In this case the left and right proof can be nil, but the +// range should be all the leaves in the trie. // -// - One element proof. In this case no matter the left edge proof is a non-existent -// proof or not, we can always verify the correctness of the proof. +// - One element proof. In this case no matter the left edge proof is a non-existent +// proof or not, we can always verify the correctness of the proof. // // Except returning the error to indicate the proof is valid or not, the function will // also return a flag to indicate whether there exists more accounts/slots in the trie. @@ -419,15 +419,12 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, valu // Special case, there is no edge proof at all. The given range is expected // to be the whole leaf-set in the trie. 
if firstProof == nil && lastProof == nil { - emptytrie, err := New(common.Hash{}, NewDatabase(memorydb.New())) - if err != nil { - return err, false - } + tr := NewStackTrie(nil) for index, key := range keys { - emptytrie.TryUpdate(key, values[index]) + tr.Update(key, values[index]) } - if emptytrie.Hash() != rootHash { - return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, emptytrie.Hash()), false + if have, want := tr.Hash(), rootHash; have != want { + return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash()), false } return nil, false // no more element. } @@ -464,9 +461,10 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, valu } // Rebuild the trie with the leave stream, the shape of trie // should be same with the original one. - newtrie := &Trie{root: root, Db: NewDatabase(memorydb.New())} + + newtrie := &Trie{root: root, Db: NewDatabase(rawdb.NewMemoryDatabase())} for index, key := range keys { - newtrie.TryUpdate(key, values[index]) + newtrie.Update(key, values[index]) } if newtrie.Hash() != rootHash { return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, newtrie.Hash()), false diff --git a/trie/secure_trie.go b/trie/secure_trie.go index f62d3d06de..db95c7cc41 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -17,10 +17,9 @@ package trie import ( - "fmt" - "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/rlp" ) // SecureTrie wraps a trie with key hashing. In a secure trie, all @@ -35,6 +34,7 @@ import ( // SecureTrie is not safe for concurrent use. type SecureTrie struct { trie Trie + preimages *preimageStore hashKeyBuf [common.HashLength]byte secKeyCache map[string][]byte secKeyCacheOwner *SecureTrie // Pointer to self, replace the key Cache on mismatch @@ -50,7 +50,7 @@ type SecureTrie struct { // Accessing the trie loads nodes from the database or Node pool on demand. // Loaded nodes are kept around until their 'Cache generation' expires. // A new Cache generation is created by each call to Commit. -// cachelimit sets the number of past Cache generations to keep. +// cache limit sets the number of past Cache generations to keep. func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) { if db == nil { panic("trie.NewSecure called without a database") @@ -59,49 +59,84 @@ func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) { if err != nil { return nil, err } - return &SecureTrie{trie: *trie}, nil + return &SecureTrie{trie: *trie, preimages: db.preimages}, nil } -// Get returns the value for key stored in the trie. +// MustGet returns the value for key stored in the trie. // The value bytes must not be modified by the caller. -func (t *SecureTrie) Get(key []byte) []byte { - res, err := t.TryGet(key) - if err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) +// +// This function will omit any encountered error but just +// print out an error message. +func (t *SecureTrie) MustGet(key []byte) []byte { + return t.trie.MustGet(t.hashKey(key)) +} + +// GetStorage attempts to retrieve a storage slot with provided account address +// and slot key. The value bytes must not be modified by the caller. +// If the specified storage slot is not in the trie, nil will be returned. +// If a trie node is not found in the database, a MissingNodeError is returned. 
+func (t *SecureTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { + enc, err := t.trie.Get(t.hashKey(key)) + if err != nil || len(enc) == 0 { + return nil, err } - return res + _, content, _, err := rlp.Split(enc) + return content, err } -// TryGet returns the value for key stored in the trie. -// The value bytes must not be modified by the caller. -// If a Node was not found in the database, a MissingNodeError is returned. -func (t *SecureTrie) TryGet(key []byte) ([]byte, error) { - return t.trie.TryGet(t.hashKey(key)) +// GetAccount attempts to retrieve an account with provided account address. +// If the specified account is not in the trie, nil will be returned. +// If a trie node is not found in the database, a MissingNodeError is returned. +func (t *SecureTrie) GetAccount(address common.Address) (*types.StateAccount, error) { + res, err := t.trie.Get(t.hashKey(address.Bytes())) + if res == nil || err != nil { + return nil, err + } + ret := new(types.StateAccount) + err = rlp.DecodeBytes(res, ret) + return ret, err } -// Update associates key with value in the trie. Subsequent calls to +// GetAccountByHash does the same thing as GetAccount, however it expects an +// account hash that is the hash of address. This constitutes an abstraction +// leak, since the client code needs to know the key format. +func (t *SecureTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) { + res, err := t.trie.Get(addrHash.Bytes()) + if res == nil || err != nil { + return nil, err + } + ret := new(types.StateAccount) + err = rlp.DecodeBytes(res, ret) + return ret, err +} + +// MustUpdate associates key with value in the trie. Subsequent calls to // Get will return value. If value has length zero, any existing value // is deleted from the trie and calls to Get will return nil. // // The value bytes must not be modified by the caller while they are // stored in the trie. -func (t *SecureTrie) Update(key, value []byte) { - if err := t.TryUpdate(key, value); err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) - } +// +// This function will omit any encountered error but just print out an +// error message. +func (t *SecureTrie) MustUpdate(key, value []byte) { + hk := t.hashKey(key) + t.trie.MustUpdate(hk, value) + t.getSecKeyCache()[string(hk)] = common.CopyBytes(key) } -// TryUpdate associates key with value in the trie. Subsequent calls to +// UpdateStorage associates key with value in the trie. Subsequent calls to // Get will return value. If value has length zero, any existing value // is deleted from the trie and calls to Get will return nil. // // The value bytes must not be modified by the caller while they are // stored in the trie. // -// If a Node was not found in the database, a MissingNodeError is returned. -func (t *SecureTrie) TryUpdate(key, value []byte) error { +// If a node is not found in the database, a MissingNodeError is returned. +func (t *SecureTrie) UpdateStorage(_ common.Address, key, value []byte) error { hk := t.hashKey(key) - err := t.trie.TryUpdate(hk, value) + v, _ := rlp.EncodeToBytes(value) + err := t.trie.Update(hk, v) if err != nil { return err } @@ -109,19 +144,47 @@ func (t *SecureTrie) TryUpdate(key, value []byte) error { return nil } -// Delete removes any existing value for key from the trie. -func (t *SecureTrie) Delete(key []byte) { - if err := t.TryDelete(key); err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) +// UpdateAccount will abstract the write of an account to the secure trie. 
+ +func (t *SecureTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error { + hk := t.hashKey(address.Bytes()) + data, err := rlp.EncodeToBytes(acc) + if err != nil { + return err } + if err := t.trie.Update(hk, data); err != nil { + return err + } + t.getSecKeyCache()[string(hk)] = address.Bytes() + return nil +} + +func (t *SecureTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error { + return nil +} + +// MustDelete removes any existing value for key from the trie. This function +// will omit any encountered error but just print out an error message. +func (t *SecureTrie) MustDelete(key []byte) { + hk := t.hashKey(key) + delete(t.getSecKeyCache(), string(hk)) + t.trie.MustDelete(hk) } -// TryDelete removes any existing value for key from the trie. -// If a Node was not found in the database, a MissingNodeError is returned. -func (t *SecureTrie) TryDelete(key []byte) error { +// DeleteStorage removes any existing storage slot from the trie. +// If the specified trie node is not in the trie, nothing will be changed. +// If a node is not found in the database, a MissingNodeError is returned. +func (t *SecureTrie) DeleteStorage(_ common.Address, key []byte) error { hk := t.hashKey(key) delete(t.getSecKeyCache(), string(hk)) - return t.trie.TryDelete(hk) + return t.trie.Delete(hk) +} + +// DeleteAccount abstracts an account deletion from the trie. +func (t *SecureTrie) DeleteAccount(address common.Address) error { + hk := t.hashKey(address.Bytes()) + delete(t.getSecKeyCache(), string(hk)) + return t.trie.Delete(hk) } // GetKey returns the sha3 Preimage of a hashed key that was @@ -130,8 +193,10 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { if key, ok := t.getSecKeyCache()[string(shaKey)]; ok { return key } - key, _ := t.trie.Db.Preimage(common.BytesToHash(shaKey)) - return key + if t.preimages == nil { + return nil + } + return t.preimages.preimage(common.BytesToHash(shaKey)) } // Commit writes all nodes and the secure hash pre-images to the trie's database. @@ -142,12 +207,15 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { func (t *SecureTrie) Commit(onleaf LeafCallback) (root common.Hash, err error) { // Write all the pre-images to the actual disk database if len(t.getSecKeyCache()) > 0 { - t.trie.Db.Lock.Lock() - for hk, key := range t.secKeyCache { - t.trie.Db.InsertPreimage(common.BytesToHash([]byte(hk)), key) + if t.preimages != nil { + t.trie.Db.Lock.Lock() + preimages := make(map[common.Hash][]byte) + for hk, key := range t.secKeyCache { + preimages[common.BytesToHash([]byte(hk))] = key + } + t.preimages.insertPreimage(preimages) + t.trie.Db.Lock.Unlock() } - t.trie.Db.Lock.Unlock() - t.secKeyCache = make(map[string][]byte) } // Commit the trie to its intermediate Node database @@ -162,8 +230,11 @@ func (t *SecureTrie) Hash() common.Hash { // Copy returns a copy of SecureTrie. func (t *SecureTrie) Copy() *SecureTrie { - cpy := *t - return &cpy + return &SecureTrie{ + trie: *t.trie.Copy(), + preimages: t.preimages, + secKeyCache: t.secKeyCache, + } } // NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration diff --git a/trie/sync_test.go b/trie/sync_test.go index b7627054ae..25baa5c67c 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -21,13 +21,14 @@ import ( "testing" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/ethdb/memorydb" ) // makeTestTrie create a sample test trie to test Node-wise reconstruction. 
func makeTestTrie() (*Database, *Trie, map[string][]byte) { // Create an empty trie - triedb := NewDatabase(memorydb.New()) + triedb := NewDatabase(rawdb.NewMemoryDatabase()) trie, _ := New(common.Hash{}, triedb) // Fill it with some arbitrary data @@ -67,7 +68,7 @@ func checkTrieContents(t *testing.T, db *Database, root []byte, content map[stri t.Fatalf("inconsistent trie at %x: %v", root, err) } for key, val := range content { - if have := trie.Get([]byte(key)); !bytes.Equal(have, val) { + if have, _ := trie.Get([]byte(key)); !bytes.Equal(have, val) { t.Errorf("entry %x: content mismatch: have %x, want %x", key, have, val) } } @@ -88,8 +89,8 @@ func checkTrieConsistency(db *Database, root common.Hash) error { // Tests that an empty trie is not scheduled for syncing. func TestEmptySync(t *testing.T) { - dbA := NewDatabase(memorydb.New()) - dbB := NewDatabase(memorydb.New()) + dbA := NewDatabase(rawdb.NewMemoryDatabase()) + dbB := NewDatabase(rawdb.NewMemoryDatabase()) emptyA, _ := New(common.Hash{}, dbA) emptyB, _ := New(emptyRoot, dbB) @@ -110,7 +111,7 @@ func testIterativeSync(t *testing.T, count int) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) @@ -145,7 +146,7 @@ func TestIterativeDelayedSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) @@ -185,7 +186,7 @@ func testIterativeRandomSync(t *testing.T, count int) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) @@ -228,7 +229,7 @@ func TestIterativeRandomDelayedSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) @@ -277,7 +278,7 @@ func TestDuplicateAvoidanceSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) @@ -319,7 +320,7 @@ func TestIncompleteSync(t *testing.T) { srcDb, srcTrie, _ := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) diff --git a/trie/trie.go b/trie/trie.go index a0c627d232..8d63ee9620 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -82,35 +82,45 @@ func New(root common.Hash, db *Database) (*Trie, error) { return trie, nil } +// Copy returns a copy of Trie. +func (t *Trie) Copy() *Trie { + return &Trie{ + Db: t.Db, + root: t.root, + unhashed: t.unhashed, + } +} + // NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at // the key after the given start key. 
func (t *Trie) NodeIterator(start []byte) NodeIterator { return newNodeIterator(t, start) } -// Get returns the value for key stored in the trie. -// The value bytes must not be modified by the caller. -func (t *Trie) Get(key []byte) []byte { - res, err := t.TryGet(key) +// MustGet is a wrapper of Get and will omit any encountered error but just +// print out an error message. +func (t *Trie) MustGet(key []byte) []byte { + res, err := t.Get(key) if err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + log.Error("Unhandled trie error in Trie.Get", "err", err) } return res } -// TryGet returns the value for key stored in the trie. +// Get returns the value for key stored in the trie. // The value bytes must not be modified by the caller. -// If a Node was not found in the database, a MissingNodeError is returned. -func (t *Trie) TryGet(key []byte) ([]byte, error) { - key = keybytesToHex(key) - value, newroot, didResolve, err := t.tryGet(t.root, key, 0) +// +// If the requested node is not present in trie, no error will be returned. +// If the trie is corrupted, a MissingNodeError is returned. +func (t *Trie) Get(key []byte) ([]byte, error) { + value, newroot, didResolve, err := t.get(t.root, keybytesToHex(key), 0) if err == nil && didResolve { t.root = newroot } return value, err } -func (t *Trie) tryGet(origNode Node, key []byte, pos int) (value []byte, newnode Node, didResolve bool, err error) { +func (t *Trie) get(origNode Node, key []byte, pos int) (value []byte, newnode Node, didResolve bool, err error) { switch n := (origNode).(type) { case nil: return nil, nil, false, nil @@ -121,14 +131,14 @@ func (t *Trie) tryGet(origNode Node, key []byte, pos int) (value []byte, newnode // key not found in trie return nil, n, false, nil } - value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key)) + value, newnode, didResolve, err = t.get(n.Val, key, pos+len(n.Key)) if err == nil && didResolve { n = n.copy() n.Val = newnode } return value, n, didResolve, err case *FullNode: - value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1) + value, newnode, didResolve, err = t.get(n.Children[key[pos]], key, pos+1) if err == nil && didResolve { n = n.copy() n.Children[key[pos]] = newnode @@ -139,10 +149,10 @@ func (t *Trie) tryGet(origNode Node, key []byte, pos int) (value []byte, newnode if err != nil { return nil, n, true, err } - value, newnode, _, err := t.tryGet(child, key, pos) + value, newnode, _, err := t.get(child, key, pos) return value, newnode, true, err default: - panic(fmt.Sprintf("%T: invalid Node: %v", origNode, origNode)) + panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) } } @@ -310,29 +320,28 @@ func (t *Trie) tryGetBestRightKeyAndValue(origNode Node, prefix []byte) (key []b return nil, nil, nil, false, fmt.Errorf("%T: invalid Node: %v", origNode, origNode) } -// Update associates key with value in the trie. Subsequent calls to -// Get will return value. If value has length zero, any existing value -// is deleted from the trie and calls to Get will return nil. -// -// The value bytes must not be modified by the caller while they are -// stored in the trie. -func (t *Trie) Update(key, value []byte) error { - if err := t.TryUpdate(key, value); err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) - return err +// MustUpdate is a wrapper of Update and will omit any encountered error but +// just print out an error message. 
+func (t *Trie) MustUpdate(key, value []byte) { + if err := t.Update(key, value); err != nil { + log.Error("Unhandled trie error in Trie.Update", "err", err) } - return nil } -// TryUpdate associates key with value in the trie. Subsequent calls to +// Update associates key with value in the trie. Subsequent calls to // Get will return value. If value has length zero, any existing value // is deleted from the trie and calls to Get will return nil. // // The value bytes must not be modified by the caller while they are // stored in the trie. // -// If a Node was not found in the database, a MissingNodeError is returned. -func (t *Trie) TryUpdate(key, value []byte) error { +// If the requested node is not present in trie, no error will be returned. +// If the trie is corrupted, a MissingNodeError is returned. +func (t *Trie) Update(key, value []byte) error { + return t.update(key, value) +} + +func (t *Trie) update(key, value []byte) error { t.unhashed++ k := keybytesToHex(key) if len(value) != 0 { @@ -420,16 +429,19 @@ func (t *Trie) insert(n Node, prefix, key []byte, value Node) (bool, Node, error } } -// Delete removes any existing value for key from the trie. -func (t *Trie) Delete(key []byte) { - if err := t.TryDelete(key); err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) +// MustDelete is a wrapper of Delete and will omit any encountered error but +// just print out an error message. +func (t *Trie) MustDelete(key []byte) { + if err := t.Delete(key); err != nil { + log.Error("Unhandled trie error in Trie.Delete", "err", err) } } -// TryDelete removes any existing value for key from the trie. -// If a Node was not found in the database, a MissingNodeError is returned. -func (t *Trie) TryDelete(key []byte) error { +// Delete removes any existing value for key from the trie. +// +// If the requested node is not present in trie, no error will be returned. +// If the trie is corrupted, a MissingNodeError is returned. +func (t *Trie) Delete(key []byte) error { t.unhashed++ k := keybytesToHex(key) _, n, err := t.delete(t.root, nil, k) @@ -464,8 +476,8 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { switch child := child.(type) { case *ShortNode: // Deleting from the subtrie reduced it to another - // short Node. Merge the nodes to avoid creating a - // ShortNode{..., ShortNode{...}}. Use concat (which + // short node. Merge the nodes to avoid creating a + // shortNode{..., shortNode{...}}. Use concat (which // always creates a new slice) instead of append to // avoid modifying n.Key since it might be shared with // other nodes. @@ -483,10 +495,18 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { n.flags = t.newFlag() n.Children[key[0]] = nn + // Because n is a full node, it must've contained at least two children + // before the delete operation. If the new child value is non-nil, n still + // has at least two children after the deletion, and cannot be reduced to + // a short node. + if nn != nil { + return true, n, nil + } + // Reduction: // Check how many non-nil entries are left after deleting and - // reduce the full Node to a short Node if only one entry is + // reduce the full node to a short node if only one entry is // left. Since n must've contained at least two children - // before deletion (otherwise it would not be a full Node) n + // before deletion (otherwise it would not be a full node) n // can never be reduced to nil. 
// // When the loop is done, pos contains the index of the single @@ -505,13 +525,13 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { } if pos >= 0 { if pos != 16 { - // If the remaining entry is a short Node, it replaces + // If the remaining entry is a short node, it replaces // n and its key gets the missing nibble tacked to the // front. This avoids creating an invalid - // ShortNode{..., ShortNode{...}}. Since the entry + // shortNode{..., shortNode{...}}. Since the entry // might not be loaded yet, resolve it just for this // check. - cnode, err := t.resolve(n.Children[pos], prefix) + cnode, err := t.resolve(n.Children[pos], append(prefix, byte(pos))) if err != nil { return false, nil, err } @@ -520,7 +540,7 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { return true, &ShortNode{k, cnode.Val, t.newFlag()}, nil } } - // Otherwise, n is replaced by a one-nibble short Node + // Otherwise, n is replaced by a one-nibble short node // containing the child. return true, &ShortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil } @@ -535,7 +555,7 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { case HashNode: // We've hit a part of the trie that isn't loaded yet. Load - // the Node and delete from it. This leaves all child nodes on + // the node and delete from it. This leaves all child nodes on // the path to the value in the trie. rn, err := t.resolveHash(n, prefix) if err != nil { @@ -548,7 +568,7 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { return true, nn, nil default: - panic(fmt.Sprintf("%T: invalid Node: %v (%v)", n, n, key)) + panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key)) } } From a4096363bbdfc8fe3f8102624183c334850aacb5 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 31 Jul 2023 16:04:13 +0700 Subject: [PATCH 055/119] Refactor NewDatabaseWithConfig --- core/state/database.go | 2 +- tomox/tradingstate/database.go | 2 +- tomoxlending/lendingstate/database.go | 2 +- trie/database.go | 8 ++++---- trie/database_test.go | 4 ++-- trie/iterator_test.go | 6 +++--- trie/secure_trie_test.go | 6 +++--- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/core/state/database.go b/core/state/database.go index 8a678ccdc4..7396ef7018 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -130,7 +130,7 @@ func NewDatabase(db ethdb.Database) Database { func NewDatabaseWithCache(db ethdb.Database, cache int) Database { csc, _ := lru.New(codeSizeCacheSize) return &cachingDB{ - db: trie.NewDatabaseWithCache(db, &trie.Config{Cache: cache, Preimages: true}), + db: trie.NewDatabaseWithConfig(db, &trie.Config{Cache: cache, Preimages: true}), codeSizeCache: csc, } } diff --git a/tomox/tradingstate/database.go b/tomox/tradingstate/database.go index 56acf61ec6..e77b6be1a6 100644 --- a/tomox/tradingstate/database.go +++ b/tomox/tradingstate/database.go @@ -81,7 +81,7 @@ type Trie interface { func NewDatabase(db ethdb.Database) Database { csc, _ := lru.New(codeSizeCacheSize) return &cachingDB{ - db: trie.NewDatabase(db), + db: trie.NewDatabaseWithConfig(db, &trie.Config{Preimages: true}), codeSizeCache: csc, } } diff --git a/tomoxlending/lendingstate/database.go b/tomoxlending/lendingstate/database.go index d823602599..c27c41dcfe 100644 --- a/tomoxlending/lendingstate/database.go +++ b/tomoxlending/lendingstate/database.go @@ -80,7 +80,7 @@ type Trie interface { func NewDatabase(db ethdb.Database) Database { csc, _ := lru.New(codeSizeCacheSize) return 
&cachingDB{ - db: trie.NewDatabase(db), + db: trie.NewDatabaseWithConfig(db, &trie.Config{Preimages: true}), codeSizeCache: csc, } } diff --git a/trie/database.go b/trie/database.go index 6d9f3868f9..4d9c7368fa 100644 --- a/trie/database.go +++ b/trie/database.go @@ -291,15 +291,15 @@ func expandNode(hash HashNode, n Node) Node { // its written out to disk or garbage collected. No read Cache is created, so all // data retrievals will hit the underlying disk database. func NewDatabase(diskdb ethdb.Database) *Database { - return NewDatabaseWithCache(diskdb, &Config{Cache: 0}) + return NewDatabaseWithConfig(diskdb, nil) } -// NewDatabaseWithCache creates a new trie database to store ephemeral trie content +// NewDatabaseWithConfig creates a new trie database to store ephemeral trie content // before its written out to disk or garbage collected. It also acts as a read Cache // for nodes loaded from disk. -func NewDatabaseWithCache(diskdb ethdb.Database, config *Config) *Database { +func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database { var cleans *fastcache.Cache - if config.Cache > 0 { + if config != nil && config.Cache > 0 { cleans = fastcache.New(config.Cache * 1024 * 1024) } var preimages *preimageStore diff --git a/trie/database_test.go b/trie/database_test.go index ed6b58fdc5..126923b12c 100644 --- a/trie/database_test.go +++ b/trie/database_test.go @@ -20,13 +20,13 @@ import ( "testing" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/ethdb/memorydb" + "github.com/tomochain/tomochain/core/rawdb" ) // Tests that the trie database returns a missing trie Node error if attempting // to retrieve the meta root. func TestDatabaseMetarootFetch(t *testing.T) { - db := NewDatabase(memorydb.New()) + db := NewDatabase(rawdb.NewMemoryDatabase()) if _, err := db.Node(common.Hash{}); err == nil { t.Fatalf("metaroot retrieval succeeded") } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 26d48c95cd..b93d664220 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -23,7 +23,7 @@ import ( "testing" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/ethdb/memorydb" + "github.com/tomochain/tomochain/core/rawdb" ) func TestIterator(t *testing.T) { @@ -292,7 +292,7 @@ func TestIteratorContinueAfterErrorDisk(t *testing.T) { testIteratorContinueA func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) } func testIteratorContinueAfterError(t *testing.T, memonly bool) { - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) tr, _ := New(common.Hash{}, triedb) @@ -383,7 +383,7 @@ func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) { func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) { // Commit test trie to Db, then remove the Node containing "bars". 
- diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) ctr, _ := New(common.Hash{}, triedb) diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index a015ffcff6..7dcb5680c5 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -23,19 +23,19 @@ import ( "testing" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/crypto" - "github.com/tomochain/tomochain/ethdb/memorydb" ) func newEmptySecure() *SecureTrie { - trie, _ := NewSecure(common.Hash{}, NewDatabase(memorydb.New())) + trie, _ := NewSecure(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) return trie } // makeTestSecureTrie creates a large enough secure trie for testing. func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) { // Create an empty trie - triedb := NewDatabase(memorydb.New()) + triedb := NewDatabase(rawdb.NewMemoryDatabase()) trie, _ := NewSecure(common.Hash{}, triedb) // Fill it with some arbitrary data From 4e70b3e8bb2d5d00731e4856fbb58451d977b597 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 31 Jul 2023 16:28:46 +0700 Subject: [PATCH 056/119] Include configs when init trie databases --- cmd/evm/runner.go | 9 ++++++--- core/state/database.go | 10 +++++----- trie/database.go | 38 ++++++++++++++++++-------------------- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index 5d3b242898..c6def708c3 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -20,12 +20,14 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "os" "runtime/pprof" "time" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/trie" + goruntime "runtime" "github.com/tomochain/tomochain/cmd/evm/internal/compiler" @@ -83,6 +85,7 @@ func runCmd(ctx *cli.Context) error { debugLogger *vm.StructLogger statedb *state.StateDB chainConfig *params.ChainConfig + preimages = ctx.Bool(DumpFlag.Name) sender = common.StringToAddress("sender") receiver = common.StringToAddress("receiver") ) @@ -98,11 +101,11 @@ func runCmd(ctx *cli.Context) error { gen := readGenesis(ctx.GlobalString(GenesisFlag.Name)) db := rawdb.NewMemoryDatabase() genesis := gen.ToBlock(db) - statedb, _ = state.New(genesis.Root(), state.NewDatabase(db)) + statedb, _ = state.New(genesis.Root(), state.NewDatabaseWithConfig(db, &trie.Config{Preimages: preimages})) chainConfig = gen.Config } else { db := rawdb.NewMemoryDatabase() - statedb, _ = state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ = state.New(common.Hash{}, state.NewDatabaseWithConfig(db, &trie.Config{Preimages: preimages})) } if ctx.GlobalString(SenderFlag.Name) != "" { sender = common.HexToAddress(ctx.GlobalString(SenderFlag.Name)) diff --git a/core/state/database.go b/core/state/database.go index 7396ef7018..8f47c33965 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -119,18 +119,18 @@ type Trie interface { // NewDatabase creates a backing store for state. The returned database is safe for // concurrent use, but does not retain any recent trie nodes in memory. To keep some -// historical state in memory, use the NewDatabaseWithCache constructor. +// historical state in memory, use the NewDatabaseWithConfig constructor. func NewDatabase(db ethdb.Database) Database { - return NewDatabaseWithCache(db, 0) + return NewDatabaseWithConfig(db, nil) } -// NewDatabaseWithCache creates a backing store for state. 
The returned database +// NewDatabaseWithConfig creates a backing store for state. The returned database // is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a // large memory cache. -func NewDatabaseWithCache(db ethdb.Database, cache int) Database { +func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database { csc, _ := lru.New(codeSizeCacheSize) return &cachingDB{ - db: trie.NewDatabaseWithConfig(db, &trie.Config{Cache: cache, Preimages: true}), + db: trie.NewDatabaseWithConfig(db, config), codeSizeCache: csc, } } diff --git a/trie/database.go b/trie/database.go index 4d9c7368fa..8aa296fda8 100644 --- a/trie/database.go +++ b/trie/database.go @@ -726,26 +726,28 @@ func (db *Database) Commit(node common.Hash, report bool) error { copy(keyBuf[:], secureKeyPrefix) // Move all of the accumulated preimages into a write batch - for hash, preimage := range db.preimages.preimages { - copy(keyBuf[secureKeyPrefixLength:], hash[:]) - if err := batch.Put(keyBuf[:], preimage); err != nil { - log.Error("Failed to commit Preimage from trie database", "err", err) - return err - } - // If the batch is too large, flush to disk - if batch.ValueSize() > ethdb.IdealBatchSize { - if err := batch.Write(); err != nil { + if db.preimages != nil { + for hash, preimage := range db.preimages.preimages { + copy(keyBuf[secureKeyPrefixLength:], hash[:]) + if err := batch.Put(keyBuf[:], preimage); err != nil { + log.Error("Failed to commit Preimage from trie database", "err", err) return err } - batch.Reset() + // If the batch is too large, flush to disk + if batch.ValueSize() > ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + return err + } + batch.Reset() + } } + // Since we're going to replay trie Node writes into the clean Cache, flush out + // any batched pre-images before continuing. + if err := batch.Write(); err != nil { + return err + } + batch.Reset() } - // Since we're going to replay trie Node writes into the clean Cache, flush out - // any batched pre-images before continuing. - if err := batch.Write(); err != nil { - return err - } - batch.Reset() // Move the trie itself into the batch, flushing if enough data is accumulated nodes, storage := len(db.dirties), db.dirtiesSize @@ -767,10 +769,6 @@ func (db *Database) Commit(node common.Hash, report bool) error { batch.Replay(uncacher) batch.Reset() - // Reset the storage counters and bumpd metrics - db.preimages.preimages = make(map[common.Hash][]byte) - db.preimagesSize = 0 - memcacheCommitTimeTimer.Update(time.Since(start)) memcacheCommitSizeMeter.Mark(int64(storage - db.dirtiesSize)) memcacheCommitNodesMeter.Mark(int64(nodes - len(db.dirties))) From 1b9882dd8d915f811ed7a5e859c2e1ce4f6176e9 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 31 Jul 2023 17:38:41 +0700 Subject: [PATCH 057/119] Nitpick --- trie/database.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/trie/database.go b/trie/database.go index 8aa296fda8..446fda34c2 100644 --- a/trie/database.go +++ b/trie/database.go @@ -813,9 +813,12 @@ func (db *Database) commit(hash common.Hash, batch ethdb.Batch, uncacher *cleane return err } db.Lock.Lock() - batch.Replay(uncacher) + err := batch.Replay(uncacher) batch.Reset() db.Lock.Unlock() + if err != nil { + return err + } } return nil } @@ -829,7 +832,7 @@ type cleaner struct { // Put reacts to database writes and implements dirty data uncaching. 
This is the // post-processing step of a commit operation where the already persisted trie is // removed from the dirty Cache and moved into the clean Cache. The reason behind -// the two-phase commit is to ensure ensure data availability while moving from +// the two-phase commit is to ensure data availability while moving from // memory to disk. func (c *cleaner) Put(key []byte, rlp []byte) error { hash := common.BytesToHash(key) From 475b70acfc46cfa47d4d709cace17e135c34eff2 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 1 Aug 2023 11:12:12 +0700 Subject: [PATCH 058/119] Merge pull request #9 from c98tristan:feat/update-trie-to-stacktrie --- consensus/clique/clique.go | 3 +- consensus/ethash/consensus.go | 3 +- consensus/posv/posv.go | 3 +- core/bench_test.go | 3 +- core/block_validator.go | 6 +- core/blockchain_test.go | 5 +- core/genesis.go | 3 +- core/rawdb/accessors_chain_test.go | 427 ----------------------------- core/tx_pool_test.go | 19 +- core/types/block.go | 6 +- core/types/block_test.go | 41 ++- core/types/derive_sha.go | 57 +++- core/types/hashing.go | 11 + core/types/hashing_test.go | 79 ++++++ core/types/lending_transaction.go | 13 +- core/types/order_transaction.go | 11 +- core/types/receipt.go | 7 + core/types/transaction.go | 15 +- eth/downloader/queue.go | 5 +- eth/fetcher/fetcher.go | 7 +- eth/fetcher/fetcher_test.go | 5 +- les/odr_requests.go | 4 +- miner/worker.go | 3 + rlp/decode.go | 3 +- rlp/encode.go | 3 +- rlp/encode_test.go | 3 +- trie/database.go | 5 +- trie/trie.go | 10 +- 28 files changed, 269 insertions(+), 491 deletions(-) delete mode 100644 core/rawdb/accessors_chain_test.go create mode 100644 core/types/hashing.go create mode 100644 core/types/hashing_test.go diff --git a/consensus/clique/clique.go b/consensus/clique/clique.go index f63373e17e..5c03e332cc 100644 --- a/consensus/clique/clique.go +++ b/consensus/clique/clique.go @@ -40,6 +40,7 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/trie" ) const ( @@ -575,7 +576,7 @@ func (c *Clique) Finalize(chain consensus.ChainReader, header *types.Header, sta header.UncleHash = types.CalcUncleHash(nil) // Assemble and return the final block for sealing - return types.NewBlock(header, txs, nil, receipts), nil + return types.NewBlock(header, txs, nil, receipts, new(trie.StackTrie)), nil } // Authorize injects a private key into the consensus engine to mint new blocks diff --git a/consensus/ethash/consensus.go b/consensus/ethash/consensus.go index 12f63cfde7..7064569927 100644 --- a/consensus/ethash/consensus.go +++ b/consensus/ethash/consensus.go @@ -32,6 +32,7 @@ import ( "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/trie" ) // Ethash proof-of-work protocol constants. @@ -519,7 +520,7 @@ func (ethash *Ethash) Finalize(chain consensus.ChainReader, header *types.Header header.Root = state.IntermediateRoot(chain.Config().IsEIP158(header.Number)) // Header seems complete, assemble into a block and return - return types.NewBlock(header, txs, uncles, receipts), nil + return types.NewBlock(header, txs, uncles, receipts, new(trie.StackTrie)), nil } // Some weird constants to avoid constant memory allocs for them. 
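
The pattern running through every Finalize change in this commit: block assembly no longer builds a throwaway trie inside DeriveSha, the caller now supplies the hasher explicitly. A minimal sketch of the new call shape, given header, txs and receipts already in scope; both spellings used below, new(trie.StackTrie) and trie.NewStackTrie(nil), are the ones appearing in this patch, and nothing beyond them is assumed:

	// The StackTrie derives the tx and receipt roots while the block is assembled.
	block := types.NewBlock(header, txs, nil, receipts, new(trie.StackTrie))

	// A validator can recompute the same root with a fresh hasher:
	txRoot := types.DeriveSha(types.Transactions(txs), trie.NewStackTrie(nil))
	// txRoot == block.Header().TxHash for a correctly assembled block
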
diff --git a/consensus/posv/posv.go b/consensus/posv/posv.go index 6a632bb357..f9570deb78 100644 --- a/consensus/posv/posv.go +++ b/consensus/posv/posv.go @@ -49,6 +49,7 @@ import ( "github.com/tomochain/tomochain/rpc" "github.com/tomochain/tomochain/tomox/tradingstate" "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "github.com/tomochain/tomochain/trie" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) @@ -985,7 +986,7 @@ func (c *Posv) Finalize(chain consensus.ChainReader, header *types.Header, state header.UncleHash = types.CalcUncleHash(nil) // Assemble and return the final block for sealing - return types.NewBlock(header, txs, nil, receipts), nil + return types.NewBlock(header, txs, nil, receipts, new(trie.StackTrie)), nil } // Authorize injects a private key into the consensus engine to mint new blocks diff --git a/core/bench_test.go b/core/bench_test.go index 5380398f91..e157b17080 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -23,10 +23,11 @@ import ( "os" "testing" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus/ethash" - "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" diff --git a/core/block_validator.go b/core/block_validator.go index 34fde4cedd..63e3f54383 100644 --- a/core/block_validator.go +++ b/core/block_validator.go @@ -18,6 +18,7 @@ package core import ( "fmt" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/posv" @@ -27,6 +28,7 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/tomox/tradingstate" "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "github.com/tomochain/tomochain/trie" ) // BlockValidator is responsible for validating block headers, uncles and @@ -71,7 +73,7 @@ func (v *BlockValidator) ValidateBody(block *types.Block) error { if hash := types.CalcUncleHash(block.Uncles()); hash != header.UncleHash { return fmt.Errorf("uncle root hash mismatch: have %x, want %x", hash, header.UncleHash) } - if hash := types.DeriveSha(block.Transactions()); hash != header.TxHash { + if hash := types.DeriveSha(block.Transactions(), new(trie.StackTrie)); hash != header.TxHash { return fmt.Errorf("transaction root hash mismatch: have %x, want %x", hash, header.TxHash) } return nil @@ -93,7 +95,7 @@ func (v *BlockValidator) ValidateState(block, parent *types.Block, statedb *stat return fmt.Errorf("invalid bloom (remote: %x local: %x)", header.Bloom, rbloom) } // Tre receipt Trie's root (R = (Tr [[H1, R1], ... 
[Hn, R1]])) - receiptSha := types.DeriveSha(receipts) + receiptSha := types.DeriveSha(receipts, new(trie.StackTrie)) if receiptSha != header.ReceiptHash { return fmt.Errorf("invalid receipt root hash (remote: %x local: %x)", header.ReceiptHash, receiptSha) } diff --git a/core/blockchain_test.go b/core/blockchain_test.go index c4d171de03..dd8ea644e3 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -32,6 +32,7 @@ import ( "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/trie" ) // Test fork of length N starting from block i @@ -617,12 +618,12 @@ func TestFastVsFullChains(t *testing.T) { } if fblock, ablock := fast.GetBlockByHash(hash), archive.GetBlockByHash(hash); fblock.Hash() != ablock.Hash() { t.Errorf("block #%d [%x]: block mismatch: have %v, want %v", num, hash, fblock, ablock) - } else if types.DeriveSha(fblock.Transactions()) != types.DeriveSha(ablock.Transactions()) { + } else if types.DeriveSha(fblock.Transactions(), new(trie.StackTrie)) != types.DeriveSha(ablock.Transactions(), new(trie.StackTrie)) { t.Errorf("block #%d [%x]: transactions mismatch: have %v, want %v", num, hash, fblock.Transactions(), ablock.Transactions()) } else if types.CalcUncleHash(fblock.Uncles()) != types.CalcUncleHash(ablock.Uncles()) { t.Errorf("block #%d [%x]: uncles mismatch: have %v, want %v", num, hash, fblock.Uncles(), ablock.Uncles()) } - if freceipts, areceipts := rawdb.GetBlockReceipts(fastDb, hash, rawdb.GetBlockNumber(fastDb, hash), fast.Config()), rawdb.GetBlockReceipts(archiveDb, hash, rawdb.GetBlockNumber(archiveDb, hash), fast.Config()); types.DeriveSha(freceipts) != types.DeriveSha(areceipts) { + if freceipts, areceipts := rawdb.GetBlockReceipts(fastDb, hash, rawdb.GetBlockNumber(fastDb, hash), fast.Config()), rawdb.GetBlockReceipts(archiveDb, hash, rawdb.GetBlockNumber(archiveDb, hash), fast.Config()); types.DeriveSha(freceipts, trie.NewStackTrie(nil)) != types.DeriveSha(areceipts, trie.NewStackTrie(nil)) { t.Errorf("block #%d [%x]: receipts mismatch: have %v, want %v", num, hash, freceipts, areceipts) } } diff --git a/core/genesis.go b/core/genesis.go index b646c3b4c2..d068960cb1 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -35,6 +35,7 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/trie" ) //go:generate gencodec -type Genesis -field-override genesisSpecMarshaling -out gen_genesis.go @@ -258,7 +259,7 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block { statedb.Commit(false) statedb.Database().TrieDB().Commit(root, true) - return types.NewBlock(head, nil, nil, nil) + return types.NewBlock(head, nil, nil, nil, new(trie.StackTrie)) } // Commit writes the block and state of a genesis specification to the database. diff --git a/core/rawdb/accessors_chain_test.go b/core/rawdb/accessors_chain_test.go deleted file mode 100644 index 2e81df6279..0000000000 --- a/core/rawdb/accessors_chain_test.go +++ /dev/null @@ -1,427 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. 
-// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package rawdb - -import ( - "bytes" - "encoding/hex" - "fmt" - "math/big" - "testing" - - "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core/types" - "github.com/tomochain/tomochain/crypto/sha3" - "github.com/tomochain/tomochain/params" - "github.com/tomochain/tomochain/rlp" -) - -// Tests block header storage and retrieval operations. -func TestHeaderStorage(t *testing.T) { - db := NewMemoryDatabase() - - // Create a test header to move around the database and make sure it's really new - header := &types.Header{Number: big.NewInt(42), Extra: []byte("test header")} - if entry := GetHeader(db, header.Hash(), header.Number.Uint64()); entry != nil { - t.Fatalf("Non existent header returned: %v", entry) - } - // Write and verify the header in the database - if err := WriteHeader(db, header); err != nil { - t.Fatalf("Failed to write header into database: %v", err) - } - if entry := GetHeader(db, header.Hash(), header.Number.Uint64()); entry == nil { - t.Fatalf("Stored header not found") - } else if entry.Hash() != header.Hash() { - t.Fatalf("Retrieved header mismatch: have %v, want %v", entry, header) - } - if entry := GetHeaderRLP(db, header.Hash(), header.Number.Uint64()); entry == nil { - t.Fatalf("Stored header RLP not found") - } else { - hasher := sha3.NewKeccak256() - hasher.Write(entry) - - if hash := common.BytesToHash(hasher.Sum(nil)); hash != header.Hash() { - t.Fatalf("Retrieved RLP header mismatch: have %v, want %v", entry, header) - } - } - // Delete the header and verify the execution - DeleteHeader(db, header.Hash(), header.Number.Uint64()) - if entry := GetHeader(db, header.Hash(), header.Number.Uint64()); entry != nil { - t.Fatalf("Deleted header returned: %v", entry) - } -} - -// Tests block body storage and retrieval operations. 
-func TestBodyStorage(t *testing.T) { - db := NewMemoryDatabase() - - // Create a test body to move around the database and make sure it's really new - body := &types.Body{Uncles: []*types.Header{{Extra: []byte("test header")}}} - - hasher := sha3.NewKeccak256() - rlp.Encode(hasher, body) - hash := common.BytesToHash(hasher.Sum(nil)) - - if entry := GetBody(db, hash, 0); entry != nil { - t.Fatalf("Non existent body returned: %v", entry) - } - // Write and verify the body in the database - if err := WriteBody(db, hash, 0, body); err != nil { - t.Fatalf("Failed to write body into database: %v", err) - } - if entry := GetBody(db, hash, 0); entry == nil { - t.Fatalf("Stored body not found") - } else if types.DeriveSha(types.Transactions(entry.Transactions)) != types.DeriveSha(types.Transactions(body.Transactions)) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(body.Uncles) { - t.Fatalf("Retrieved body mismatch: have %v, want %v", entry, body) - } - if entry := GetBodyRLP(db, hash, 0); entry == nil { - t.Fatalf("Stored body RLP not found") - } else { - hasher := sha3.NewKeccak256() - hasher.Write(entry) - - if calc := common.BytesToHash(hasher.Sum(nil)); calc != hash { - t.Fatalf("Retrieved RLP body mismatch: have %v, want %v", entry, body) - } - } - // Delete the body and verify the execution - DeleteBody(db, hash, 0) - if entry := GetBody(db, hash, 0); entry != nil { - t.Fatalf("Deleted body returned: %v", entry) - } -} - -// Tests block storage and retrieval operations. -func TestBlockStorage(t *testing.T) { - db := NewMemoryDatabase() - - // Create a test block to move around the database and make sure it's really new - block := types.NewBlockWithHeader(&types.Header{ - Extra: []byte("test block"), - UncleHash: types.EmptyUncleHash, - TxHash: types.EmptyRootHash, - ReceiptHash: types.EmptyRootHash, - }) - if entry := GetBlock(db, block.Hash(), block.NumberU64()); entry != nil { - t.Fatalf("Non existent block returned: %v", entry) - } - if entry := GetHeader(db, block.Hash(), block.NumberU64()); entry != nil { - t.Fatalf("Non existent header returned: %v", entry) - } - if entry := GetBody(db, block.Hash(), block.NumberU64()); entry != nil { - t.Fatalf("Non existent body returned: %v", entry) - } - // Write and verify the block in the database - if err := WriteBlock(db, block); err != nil { - t.Fatalf("Failed to write block into database: %v", err) - } - if entry := GetBlock(db, block.Hash(), block.NumberU64()); entry == nil { - t.Fatalf("Stored block not found") - } else if entry.Hash() != block.Hash() { - t.Fatalf("Retrieved block mismatch: have %v, want %v", entry, block) - } - if entry := GetHeader(db, block.Hash(), block.NumberU64()); entry == nil { - t.Fatalf("Stored header not found") - } else if entry.Hash() != block.Header().Hash() { - t.Fatalf("Retrieved header mismatch: have %v, want %v", entry, block.Header()) - } - if entry := GetBody(db, block.Hash(), block.NumberU64()); entry == nil { - t.Fatalf("Stored body not found") - } else if types.DeriveSha(types.Transactions(entry.Transactions)) != types.DeriveSha(block.Transactions()) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(block.Uncles()) { - t.Fatalf("Retrieved body mismatch: have %v, want %v", entry, block.Body()) - } - // Delete the block and verify the execution - DeleteBlock(db, block.Hash(), block.NumberU64()) - if entry := GetBlock(db, block.Hash(), block.NumberU64()); entry != nil { - t.Fatalf("Deleted block returned: %v", entry) - } - if entry := GetHeader(db, block.Hash(), block.NumberU64()); 
entry != nil { - t.Fatalf("Deleted header returned: %v", entry) - } - if entry := GetBody(db, block.Hash(), block.NumberU64()); entry != nil { - t.Fatalf("Deleted body returned: %v", entry) - } -} - -// Tests that partial block contents don't get reassembled into full blocks. -func TestPartialBlockStorage(t *testing.T) { - db := NewMemoryDatabase() - block := types.NewBlockWithHeader(&types.Header{ - Extra: []byte("test block"), - UncleHash: types.EmptyUncleHash, - TxHash: types.EmptyRootHash, - ReceiptHash: types.EmptyRootHash, - }) - // Store a header and check that it's not recognized as a block - if err := WriteHeader(db, block.Header()); err != nil { - t.Fatalf("Failed to write header into database: %v", err) - } - if entry := GetBlock(db, block.Hash(), block.NumberU64()); entry != nil { - t.Fatalf("Non existent block returned: %v", entry) - } - DeleteHeader(db, block.Hash(), block.NumberU64()) - - // Store a body and check that it's not recognized as a block - if err := WriteBody(db, block.Hash(), block.NumberU64(), block.Body()); err != nil { - t.Fatalf("Failed to write body into database: %v", err) - } - if entry := GetBlock(db, block.Hash(), block.NumberU64()); entry != nil { - t.Fatalf("Non existent block returned: %v", entry) - } - DeleteBody(db, block.Hash(), block.NumberU64()) - - // Store a header and a body separately and check reassembly - if err := WriteHeader(db, block.Header()); err != nil { - t.Fatalf("Failed to write header into database: %v", err) - } - if err := WriteBody(db, block.Hash(), block.NumberU64(), block.Body()); err != nil { - t.Fatalf("Failed to write body into database: %v", err) - } - if entry := GetBlock(db, block.Hash(), block.NumberU64()); entry == nil { - t.Fatalf("Stored block not found") - } else if entry.Hash() != block.Hash() { - t.Fatalf("Retrieved block mismatch: have %v, want %v", entry, block) - } -} - -// Tests block total difficulty storage and retrieval operations. -func TestTdStorage(t *testing.T) { - db := NewMemoryDatabase() - - // Create a test TD to move around the database and make sure it's really new - hash, td := common.Hash{}, big.NewInt(314) - if entry := GetTd(db, hash, 0); entry != nil { - t.Fatalf("Non existent TD returned: %v", entry) - } - // Write and verify the TD in the database - if err := WriteTd(db, hash, 0, td); err != nil { - t.Fatalf("Failed to write TD into database: %v", err) - } - if entry := GetTd(db, hash, 0); entry == nil { - t.Fatalf("Stored TD not found") - } else if entry.Cmp(td) != 0 { - t.Fatalf("Retrieved TD mismatch: have %v, want %v", entry, td) - } - // Delete the TD and verify the execution - DeleteTd(db, hash, 0) - if entry := GetTd(db, hash, 0); entry != nil { - t.Fatalf("Deleted TD returned: %v", entry) - } -} - -// Tests that canonical numbers can be mapped to hashes and retrieved. 
-func TestCanonicalMappingStorage(t *testing.T) { - db := NewMemoryDatabase() - - // Create a test canonical number and assinged hash to move around - hash, number := common.Hash{0: 0xff}, uint64(314) - if entry := GetCanonicalHash(db, number); entry != (common.Hash{}) { - t.Fatalf("Non existent canonical mapping returned: %v", entry) - } - // Write and verify the TD in the database - if err := WriteCanonicalHash(db, hash, number); err != nil { - t.Fatalf("Failed to write canonical mapping into database: %v", err) - } - if entry := GetCanonicalHash(db, number); entry == (common.Hash{}) { - t.Fatalf("Stored canonical mapping not found") - } else if entry != hash { - t.Fatalf("Retrieved canonical mapping mismatch: have %v, want %v", entry, hash) - } - // Delete the TD and verify the execution - DeleteCanonicalHash(db, number) - if entry := GetCanonicalHash(db, number); entry != (common.Hash{}) { - t.Fatalf("Deleted canonical mapping returned: %v", entry) - } -} - -// Tests that head headers and head blocks can be assigned, individually. -func TestHeadStorage(t *testing.T) { - db := NewMemoryDatabase() - - blockHead := types.NewBlockWithHeader(&types.Header{Extra: []byte("test block header")}) - blockFull := types.NewBlockWithHeader(&types.Header{Extra: []byte("test block full")}) - blockFast := types.NewBlockWithHeader(&types.Header{Extra: []byte("test block fast")}) - - // Check that no head entries are in a pristine database - if entry := GetHeadHeaderHash(db); entry != (common.Hash{}) { - t.Fatalf("Non head header entry returned: %v", entry) - } - if entry := GetHeadBlockHash(db); entry != (common.Hash{}) { - t.Fatalf("Non head block entry returned: %v", entry) - } - if entry := GetHeadFastBlockHash(db); entry != (common.Hash{}) { - t.Fatalf("Non fast head block entry returned: %v", entry) - } - // Assign separate entries for the head header and block - if err := WriteHeadHeaderHash(db, blockHead.Hash()); err != nil { - t.Fatalf("Failed to write head header hash: %v", err) - } - if err := WriteHeadBlockHash(db, blockFull.Hash()); err != nil { - t.Fatalf("Failed to write head block hash: %v", err) - } - if err := WriteHeadFastBlockHash(db, blockFast.Hash()); err != nil { - t.Fatalf("Failed to write fast head block hash: %v", err) - } - // Check that both heads are present, and different (i.e. two heads maintained) - if entry := GetHeadHeaderHash(db); entry != blockHead.Hash() { - t.Fatalf("Head header hash mismatch: have %v, want %v", entry, blockHead.Hash()) - } - if entry := GetHeadBlockHash(db); entry != blockFull.Hash() { - t.Fatalf("Head block hash mismatch: have %v, want %v", entry, blockFull.Hash()) - } - if entry := GetHeadFastBlockHash(db); entry != blockFast.Hash() { - t.Fatalf("Fast head block hash mismatch: have %v, want %v", entry, blockFast.Hash()) - } -} - -// Tests that positional lookup metadata can be stored and retrieved. 
-func TestLookupStorage(t *testing.T) { - db := NewMemoryDatabase() - - tx1 := types.NewTransaction(1, common.BytesToAddress([]byte{0x11}), big.NewInt(111), 1111, big.NewInt(11111), []byte{0x11, 0x11, 0x11}) - tx2 := types.NewTransaction(2, common.BytesToAddress([]byte{0x22}), big.NewInt(222), 2222, big.NewInt(22222), []byte{0x22, 0x22, 0x22}) - tx3 := types.NewTransaction(3, common.BytesToAddress([]byte{0x33}), big.NewInt(333), 3333, big.NewInt(33333), []byte{0x33, 0x33, 0x33}) - txs := []*types.Transaction{tx1, tx2, tx3} - - block := types.NewBlock(&types.Header{Number: big.NewInt(314)}, txs, nil, nil) - - // Check that no transactions entries are in a pristine database - for i, tx := range txs { - if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn != nil { - t.Fatalf("tx #%d [%x]: non existent transaction returned: %v", i, tx.Hash(), txn) - } - } - // Insert all the transactions into the database, and verify contents - if err := WriteBlock(db, block); err != nil { - t.Fatalf("failed to write block contents: %v", err) - } - if err := WriteTxLookupEntries(db, block); err != nil { - t.Fatalf("failed to write transactions: %v", err) - } - for i, tx := range txs { - if txn, hash, number, index := GetTransaction(db, tx.Hash()); txn == nil { - t.Fatalf("tx #%d [%x]: transaction not found", i, tx.Hash()) - } else { - if hash != block.Hash() || number != block.NumberU64() || index != uint64(i) { - t.Fatalf("tx #%d [%x]: positional metadata mismatch: have %x/%d/%d, want %x/%v/%v", i, tx.Hash(), hash, number, index, block.Hash(), block.NumberU64(), i) - } - if tx.String() != txn.String() { - t.Fatalf("tx #%d [%x]: transaction mismatch: have %v, want %v", i, tx.Hash(), txn, tx) - } - } - } - // Delete the transactions and check purge - for i, tx := range txs { - DeleteTxLookupEntry(db, tx.Hash()) - if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn != nil { - t.Fatalf("tx #%d [%x]: deleted transaction returned: %v", i, tx.Hash(), txn) - } - } -} - -// Tests that receipts associated with a single block can be stored and retrieved. 
-func TestBlockReceiptStorage(t *testing.T) { - db := NewMemoryDatabase() - - // Create a live block since we need metadata to reconstruct the receipt - tx1 := types.NewTransaction(1, common.HexToAddress("0x1"), big.NewInt(1), 1, big.NewInt(1), nil) - tx2 := types.NewTransaction(2, common.HexToAddress("0x2"), big.NewInt(2), 2, big.NewInt(2), nil) - - body := &types.Body{Transactions: types.Transactions{tx1, tx2}} - - // Create the two receipts to manage afterwards - receipt1 := &types.Receipt{ - Status: types.ReceiptStatusFailed, - CumulativeGasUsed: 1, - Logs: []*types.Log{ - {Address: common.BytesToAddress([]byte{0x11})}, - {Address: common.BytesToAddress([]byte{0x01, 0x11})}, - }, - TxHash: tx1.Hash(), - ContractAddress: common.BytesToAddress([]byte{0x01, 0x11, 0x11}), - GasUsed: 111111, - } - receipt1.Bloom = types.CreateBloom(types.Receipts{receipt1}) - - receipt2 := &types.Receipt{ - PostState: common.Hash{2}.Bytes(), - CumulativeGasUsed: 2, - Logs: []*types.Log{ - {Address: common.BytesToAddress([]byte{0x22})}, - {Address: common.BytesToAddress([]byte{0x02, 0x22})}, - }, - TxHash: tx2.Hash(), - ContractAddress: common.BytesToAddress([]byte{0x02, 0x22, 0x22}), - GasUsed: 222222, - } - receipt2.Bloom = types.CreateBloom(types.Receipts{receipt2}) - receipts := []*types.Receipt{receipt1, receipt2} - - // Check that no receipt entries are in a pristine database - hash := common.BytesToHash([]byte{0x03, 0x14}) - if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) != 0 { - t.Fatalf("non existent receipts returned: %v", rs) - } - // Insert the body that corresponds to the receipts - WriteBody(db, hash, 0, body) - - // Insert the receipt slice into the database and check presence - WriteBlockReceipts(db, hash, 0, receipts) - if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) == 0 { - t.Fatalf("no receipts returned") - } else { - if err := checkReceiptsRLP(rs, receipts); err != nil { - t.Fatalf(err.Error()) - } - } - // Delete the body and ensure that the receipts are no longer returned (metadata can't be recomputed) - DeleteBody(db, hash, 0) - if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); rs != nil { - t.Fatalf("receipts returned when body was deleted: %v", rs) - } - // Ensure that receipts without metadata can be returned without the block body too - if err := checkReceiptsRLP(ReadRawReceipts(db, hash, 0), receipts); err != nil { - t.Fatalf(err.Error()) - } - // Sanity check that body alone without the receipt is a full purge - WriteBody(db, hash, 0, body) - - DeleteBlockReceipts(db, hash, 0) - if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) != 0 { - t.Fatalf("deleted receipts returned: %v", rs) - } -} - -func checkReceiptsRLP(have, want types.Receipts) error { - if len(have) != len(want) { - return fmt.Errorf("receipts sizes mismatch: have %d, want %d", len(have), len(want)) - } - for i := 0; i < len(want); i++ { - rlpHave, err := rlp.EncodeToBytes(have[i]) - if err != nil { - return err - } - rlpWant, err := rlp.EncodeToBytes(want[i]) - if err != nil { - return err - } - if !bytes.Equal(rlpHave, rlpWant) { - return fmt.Errorf("receipt #%d: receipt mismatch: have %s, want %s", i, hex.EncodeToString(rlpHave), hex.EncodeToString(rlpWant)) - } - } - return nil -} diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 4c6a311119..8ddb0650ea 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -19,8 +19,6 @@ package core import ( "crypto/ecdsa" "fmt" - "github.com/tomochain/tomochain/consensus" 
- "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "math/rand" @@ -29,11 +27,14 @@ import ( "time" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/consensus" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/trie" ) // testTxPoolConfig is a transaction pool configuration without stateful disk @@ -70,7 +71,7 @@ func (bc *testBlockChain) Config() *params.ChainConfig { func (bc *testBlockChain) CurrentBlock() *types.Block { return types.NewBlock(&types.Header{ GasLimit: bc.gasLimit, - }, nil, nil, nil) + }, nil, nil, nil, new(trie.StackTrie)) } func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block { @@ -872,8 +873,10 @@ func testTransactionQueueGlobalLimiting(t *testing.T, nolocals bool) { // // This logic should not hold for local transactions, unless the local tracking // mechanism is disabled. -func TestTransactionQueueTimeLimiting(t *testing.T) { testTransactionQueueTimeLimiting(t, false) } -func TestTransactionQueueTimeLimitingNoLocals(t *testing.T) { testTransactionQueueTimeLimiting(t, true) } +func TestTransactionQueueTimeLimiting(t *testing.T) { testTransactionQueueTimeLimiting(t, false) } +func TestTransactionQueueTimeLimitingNoLocals(t *testing.T) { + testTransactionQueueTimeLimiting(t, true) +} func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { common.MinGasPrice = big.NewInt(0) @@ -981,8 +984,10 @@ func TestTransactionPendingLimiting(t *testing.T) { // Tests that the transaction limits are enforced the same way irrelevant whether // the transactions are added one by one or in batches. -func TestTransactionQueueLimitingEquivalency(t *testing.T) { testTransactionLimitingEquivalency(t, 1) } -func TestTransactionPendingLimitingEquivalency(t *testing.T) { testTransactionLimitingEquivalency(t, 0) } +func TestTransactionQueueLimitingEquivalency(t *testing.T) { testTransactionLimitingEquivalency(t, 1) } +func TestTransactionPendingLimitingEquivalency(t *testing.T) { + testTransactionLimitingEquivalency(t, 0) +} func testTransactionLimitingEquivalency(t *testing.T, origin uint64) { t.Parallel() diff --git a/core/types/block.go b/core/types/block.go index 9e95a1d82c..66baecf2cf 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -220,14 +220,14 @@ type storageblock struct { // The values of TxHash, UncleHash, ReceiptHash and Bloom in header // are ignored and set to values derived from the given txs, uncles // and receipts. 
-func NewBlock(header *Header, txs []*Transaction, uncles []*Header, receipts []*Receipt) *Block {
+func NewBlock(header *Header, txs []*Transaction, uncles []*Header, receipts []*Receipt, hasher Hasher) *Block {
 	b := &Block{header: CopyHeader(header), td: new(big.Int)}
 
 	// TODO: panic if len(txs) != len(receipts)
 	if len(txs) == 0 {
 		b.header.TxHash = EmptyRootHash
 	} else {
-		b.header.TxHash = DeriveSha(Transactions(txs))
+		b.header.TxHash = DeriveSha(Transactions(txs), hasher)
 		b.transactions = make(Transactions, len(txs))
 		copy(b.transactions, txs)
 	}
@@ -235,7 +235,7 @@ func NewBlock(header *Header, txs []*Transaction, uncles []*Header, receipts []*
 	if len(receipts) == 0 {
 		b.header.ReceiptHash = EmptyRootHash
 	} else {
-		b.header.ReceiptHash = DeriveSha(Receipts(receipts))
+		b.header.ReceiptHash = DeriveSha(Receipts(receipts), hasher)
 		b.header.Bloom = CreateBloom(receipts)
 	}
 
diff --git a/core/types/block_test.go b/core/types/block_test.go
index 9b78b653c7..e93ae02de8 100644
--- a/core/types/block_test.go
+++ b/core/types/block_test.go
@@ -17,13 +17,15 @@
 package types
 
 import (
+	"bytes"
+	"hash"
 	"math/big"
+	"reflect"
 	"testing"
 
-	"bytes"
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/rlp"
-	"reflect"
+	"golang.org/x/crypto/sha3"
 )
 
 // from bcValidBlockTest.json, "SimpleTx"
@@ -59,3 +61,39 @@ func TestBlockEncoding(t *testing.T) {
 		t.Errorf("encoded block mismatch:\ngot: %x\nwant: %x", ourBlockEnc, blockEnc)
 	}
 }
+
+func TestUncleHash(t *testing.T) {
+	uncles := make([]*Header, 0)
+	h := CalcUncleHash(uncles)
+	exp := common.HexToHash("1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347")
+	if h != exp {
+		t.Fatalf("empty uncle hash is wrong, got %x != %x", h, exp)
+	}
+}
+
+var benchBuffer = bytes.NewBuffer(make([]byte, 0, 32000))
+
+// testHasher is the helper tool for transaction/receipt list hashing.
+// The original hasher lives in the trie package; to avoid an import
+// cycle, the tests use this hasher instead.
+type testHasher struct {
+	hasher hash.Hash
+}
+
+func newHasher() *testHasher {
+	return &testHasher{hasher: sha3.NewLegacyKeccak256()}
+}
+
+func (h *testHasher) Reset() {
+	h.hasher.Reset()
+}
+
+func (h *testHasher) Update(key, val []byte) error {
+	h.hasher.Write(key)
+	h.hasher.Write(val)
+	return nil
+}
+
+func (h *testHasher) Hash() common.Hash {
+	return common.BytesToHash(h.hasher.Sum(nil))
+}
diff --git a/core/types/derive_sha.go b/core/types/derive_sha.go
index 2731c39cbb..210ee26e06 100644
--- a/core/types/derive_sha.go
+++ b/core/types/derive_sha.go
@@ -21,21 +21,58 @@ import (
 
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/rlp"
-	"github.com/tomochain/tomochain/trie"
 )
 
+// DerivableList is the interface implemented by lists whose root hash can be derived.
 type DerivableList interface {
 	Len() int
-	GetRlp(i int) []byte
+	EncodeIndex(int, *bytes.Buffer)
}
 
-func DeriveSha(list DerivableList) common.Hash {
-	keybuf := new(bytes.Buffer)
-	trie := new(trie.Trie)
-	for i := 0; i < list.Len(); i++ {
-		keybuf.Reset()
-		rlp.Encode(keybuf, uint(i))
-		trie.Update(keybuf.Bytes(), list.GetRlp(i))
+// Hasher is the tool used to calculate the hash of a derivable list.
+type Hasher interface {
+	Reset()
+	Update([]byte, []byte) error
+	Hash() common.Hash
+}
+
+func encodeForDerive(list DerivableList, i int, buf *bytes.Buffer) []byte {
+	buf.Reset()
+	list.EncodeIndex(i, buf)
+	// It's really unfortunate that we need to perform this copy.
+	// StackTrie holds onto the values until Hash is called, so the values
+	// written to it must not alias.
+	return common.CopyBytes(buf.Bytes())
+}
+
+// DeriveSha creates the tree hashes of transactions and receipts in a block header.
+func DeriveSha(list DerivableList, hasher Hasher) common.Hash {
+	hasher.Reset()
+
+	valueBuf := encodeBufferPool.Get().(*bytes.Buffer)
+	defer encodeBufferPool.Put(valueBuf)
+
+	// StackTrie requires values to be inserted in increasing hash order, which is not the
+	// order that `list` provides hashes in. This insertion sequence ensures that the
+	// order is correct.
+	//
+	// The error returned by hasher is ignored because the hasher will simply
+	// produce an incorrect hash if any error occurs.
+	var indexBuf []byte
+	for i := 1; i < list.Len() && i <= 0x7f; i++ {
+		indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i))
+		value := encodeForDerive(list, i, valueBuf)
+		hasher.Update(indexBuf, value)
+	}
+	if list.Len() > 0 {
+		indexBuf = rlp.AppendUint64(indexBuf[:0], 0)
+		value := encodeForDerive(list, 0, valueBuf)
+		hasher.Update(indexBuf, value)
+	}
+	for i := 0x80; i < list.Len(); i++ {
+		indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i))
+		value := encodeForDerive(list, i, valueBuf)
+		hasher.Update(indexBuf, value)
 	}
-	return trie.Hash()
+	return hasher.Hash()
 }
diff --git a/core/types/hashing.go b/core/types/hashing.go
new file mode 100644
index 0000000000..8b9cb92b94
--- /dev/null
+++ b/core/types/hashing.go
@@ -0,0 +1,11 @@
+package types
+
+import (
+	"bytes"
+	"sync"
+)
+
+// encodeBufferPool holds temporary encoder buffers for DeriveSha and TX encoding.
+var encodeBufferPool = sync.Pool{
+	New: func() interface{} { return new(bytes.Buffer) },
+}
diff --git a/core/types/hashing_test.go b/core/types/hashing_test.go
new file mode 100644
index 0000000000..d2f2781a6b
--- /dev/null
+++ b/core/types/hashing_test.go
@@ -0,0 +1,79 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
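
The loop order in the derive_sha.go hunk above (indices 1..0x7f first, then 0, then 0x80 and up) looks odd at first glance. The trie key for list element i is the RLP encoding of i, and a stack trie must be fed keys in ascending byte order; the three loops produce exactly that order. A short sketch of the key encodings involved, using the same rlp.AppendUint64 helper as the hunk above:

	rlp.AppendUint64(nil, 1)   // 0x01      (1..127 encode as a single byte below 0x80)
	rlp.AppendUint64(nil, 127) // 0x7f
	rlp.AppendUint64(nil, 0)   // 0x80      (zero encodes as the empty string)
	rlp.AppendUint64(nil, 128) // 0x81 0x80 (larger values get a length prefix)

So the keys 0x01..0x7f sort before 0x80 (index 0), which in turn sorts before every 0x81-prefixed key (indices 128 and up).
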
+
+package types_test
+
+import (
+	"math/big"
+	"testing"
+
+	"github.com/tomochain/tomochain/common"
+	"github.com/tomochain/tomochain/core/rawdb"
+	"github.com/tomochain/tomochain/core/types"
+	"github.com/tomochain/tomochain/crypto"
+	"github.com/tomochain/tomochain/trie"
+)
+
+func BenchmarkDeriveSha200(b *testing.B) {
+	txs, err := genTxs(200)
+	if err != nil {
+		b.Fatal(err)
+	}
+	var exp common.Hash
+	var got common.Hash
+	b.Run("std_trie", func(b *testing.B) {
+		b.ResetTimer()
+		b.ReportAllocs()
+		for i := 0; i < b.N; i++ {
+			tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase()))
+			exp = types.DeriveSha(txs, tr)
+		}
+	})
+
+	b.Run("stack_trie", func(b *testing.B) {
+		b.ResetTimer()
+		b.ReportAllocs()
+		for i := 0; i < b.N; i++ {
+			got = types.DeriveSha(txs, trie.NewStackTrie(nil))
+		}
+	})
+	if got != exp {
+		b.Errorf("got %x exp %x", got, exp)
+	}
+}
+
+func genTxs(num uint64) (types.Transactions, error) {
+	key, err := crypto.HexToECDSA("deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef")
+	if err != nil {
+		return nil, err
+	}
+	var addr = crypto.PubkeyToAddress(key.PublicKey)
+	newTx := func(i uint64) (*types.Transaction, error) {
+		signer := types.NewEIP155Signer(big.NewInt(18))
+		utx := types.NewTransaction(i, addr, new(big.Int), 0, new(big.Int).SetUint64(10000000), nil)
+		tx, err := types.SignTx(utx, signer, key)
+		return tx, err
+	}
+	var txs types.Transactions
+	for i := uint64(0); i < num; i++ {
+		tx, err := newTx(i)
+		if err != nil {
+			return nil, err
+		}
+		txs = append(txs, tx)
+	}
+	return txs, nil
+}
diff --git a/core/types/lending_transaction.go b/core/types/lending_transaction.go
index e33826829c..7152461117 100644
--- a/core/types/lending_transaction.go
+++ b/core/types/lending_transaction.go
@@ -17,6 +17,7 @@
 package types
 
 import (
+	"bytes"
 	"container/heap"
 	"errors"
 	"io"
@@ -319,10 +320,12 @@ func (s LendingTransactions) Len() int { return len(s) }
 // Swap swaps the i'th and the j'th element in s.
 func (s LendingTransactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
 
-// GetRlp implements Rlpable and returns the i'th element of s in rlp.
-func (s LendingTransactions) GetRlp(i int) []byte {
-	enc, _ := rlp.EncodeToBytes(s[i])
-	return enc
+// EncodeIndex encodes the i'th transaction to w. Note that this does not check for errors
+// because we assume that *LendingTransaction will only ever contain valid txs that were either
+// constructed by decoding or via public API in this package.
+func (s LendingTransactions) EncodeIndex(i int, w *bytes.Buffer) {
+	tx := s[i]
+	rlp.Encode(w, tx.data)
 }
 
 // LendingTxDifference returns a new set t which is the difference between a to b.
@@ -363,7 +366,7 @@ func (s *LendingTxByNonce) Pop() interface{} {
 	return x
 }
 
-//LendingTransactionByNonce sort transaction by nonce
+// LendingTransactionByNonce sorts transactions by nonce
 type LendingTransactionByNonce struct {
 	txs   map[common.Address]LendingTransactions
 	heads LendingTxByNonce
diff --git a/core/types/order_transaction.go b/core/types/order_transaction.go
index d51884e3f5..e7150b991e 100644
--- a/core/types/order_transaction.go
+++ b/core/types/order_transaction.go
@@ -17,6 +17,7 @@
 package types
 
 import (
+	"bytes"
 	"container/heap"
 	"errors"
 	"io"
@@ -250,10 +251,12 @@ func (s OrderTransactions) Len() int { return len(s) }
 // Swap swaps the i'th and the j'th element in s.
 func (s OrderTransactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
 
-// GetRlp implements Rlpable and returns the i'th element of s in rlp.
-func (s OrderTransactions) GetRlp(i int) []byte {
-	enc, _ := rlp.EncodeToBytes(s[i])
-	return enc
+// EncodeIndex encodes the i'th transaction to w. Note that this does not check for errors
+// because we assume that OrderTransactions will only ever contain valid transactions that
+// were either constructed by decoding or via the public API in this package.
+func (s OrderTransactions) EncodeIndex(i int, w *bytes.Buffer) {
+	tx := s[i]
+	rlp.Encode(w, tx.data)
 }
 
 // OrderTxDifference returns a new set t which is the difference between a to b.
diff --git a/core/types/receipt.go b/core/types/receipt.go
index 121c647a31..3235af79cf 100644
--- a/core/types/receipt.go
+++ b/core/types/receipt.go
@@ -256,6 +256,13 @@ type Receipts []*Receipt
 // Len returns the number of receipts in this list.
 func (r Receipts) Len() int { return len(r) }
 
+// EncodeIndex encodes the i'th receipt to w.
+func (rs Receipts) EncodeIndex(i int, w *bytes.Buffer) {
+	r := rs[i]
+	data := &receiptRLP{r.statusEncoding(), r.CumulativeGasUsed, r.Bloom, r.Logs}
+	rlp.Encode(w, data)
+}
+
 // DeriveFields fills the receipts with their computed fields based on consensus
 // data and contextual infos like containing block and transactions.
 func (rs Receipts) DeriveFields(config *params.ChainConfig, hash common.Hash, number uint64, txs []*Transaction) error {
diff --git a/core/types/transaction.go b/core/types/transaction.go
index dd983bcb27..d0d2ce215d 100644
--- a/core/types/transaction.go
+++ b/core/types/transaction.go
@@ -17,6 +17,7 @@
 package types
 
 import (
+	"bytes"
 	"container/heap"
 	"errors"
 	"fmt"
@@ -495,15 +496,17 @@ type Transactions []*Transaction
 // Len returns the length of s.
 func (s Transactions) Len() int { return len(s) }
 
+// EncodeIndex encodes the i'th transaction to w. Note that this does not check for errors
+// because we assume that Transactions will only ever contain valid txs that were either
+// constructed by decoding or via the public API in this package.
+func (s Transactions) EncodeIndex(i int, w *bytes.Buffer) {
+	tx := s[i]
+	rlp.Encode(w, tx.data)
+}
+
 // Swap swaps the i'th and the j'th element in s.
 func (s Transactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
 
-// GetRlp implements Rlpable and returns the i'th element of s in rlp.
-func (s Transactions) GetRlp(i int) []byte {
-	enc, _ := rlp.EncodeToBytes(s[i])
-	return enc
-}
-
 // TxDifference returns a new set t which is the difference between a to b.
func TxDifference(a, b Transactions) (keep Transactions) { keep = make(Transactions, 0, len(a)) diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index 0ed4e75faa..43569da2df 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -29,6 +29,7 @@ import ( "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/metrics" + "github.com/tomochain/tomochain/trie" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) @@ -767,7 +768,7 @@ func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLi defer q.lock.Unlock() reconstruct := func(header *types.Header, index int, result *fetchResult) error { - if types.DeriveSha(types.Transactions(txLists[index])) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash { + if types.DeriveSha(types.Transactions(txLists[index]), new(trie.StackTrie)) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash { return errInvalidBody } result.Transactions = txLists[index] @@ -785,7 +786,7 @@ func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) (int, defer q.lock.Unlock() reconstruct := func(header *types.Header, index int, result *fetchResult) error { - if types.DeriveSha(types.Receipts(receiptList[index])) != header.ReceiptHash { + if types.DeriveSha(types.Receipts(receiptList[index]), new(trie.StackTrie)) != header.ReceiptHash { return errInvalidReceipt } result.Receipts = receiptList[index] diff --git a/eth/fetcher/fetcher.go b/eth/fetcher/fetcher.go index 65b15094d2..142089586c 100644 --- a/eth/fetcher/fetcher.go +++ b/eth/fetcher/fetcher.go @@ -19,14 +19,15 @@ package fetcher import ( "errors" - "github.com/hashicorp/golang-lru" "math/rand" "time" + lru "github.com/hashicorp/golang-lru" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/trie" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) @@ -468,7 +469,7 @@ func (f *Fetcher) loop() { announce.time = task.time // If the block is empty (header only), short circuit into the final import queue - if header.TxHash == types.DeriveSha(types.Transactions{}) && header.UncleHash == types.CalcUncleHash([]*types.Header{}) { + if header.TxHash == types.EmptyRootHash && header.UncleHash == types.CalcUncleHash([]*types.Header{}) { log.Trace("Block empty, skipping body retrieval", "peer", announce.origin, "number", header.Number, "hash", header.Hash()) block := types.NewBlockWithHeader(header) @@ -530,7 +531,7 @@ func (f *Fetcher) loop() { for hash, announce := range f.completing { if f.queued[hash] == nil { - txnHash := types.DeriveSha(types.Transactions(task.transactions[i])) + txnHash := types.DeriveSha(types.Transactions(task.transactions[i]), new(trie.StackTrie)) uncleHash := types.CalcUncleHash(task.uncles[i]) if txnHash == announce.header.TxHash && uncleHash == announce.header.UncleHash && announce.origin == task.peer { diff --git a/eth/fetcher/fetcher_test.go b/eth/fetcher/fetcher_test.go index ab7e03aaa1..951b2fcd6c 100644 --- a/eth/fetcher/fetcher_test.go +++ b/eth/fetcher/fetcher_test.go @@ -18,7 +18,6 @@ package fetcher import ( "errors" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "sync" "sync/atomic" @@ -28,9 +27,11 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + 
"github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/trie" ) var ( @@ -38,7 +39,7 @@ var ( testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") testAddress = crypto.PubkeyToAddress(testKey.PublicKey) genesis = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000)) - unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit}, nil, nil, nil) + unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit}, nil, nil, nil, new(trie.StackTrie)) ) // makeChain creates a chain of n blocks starting at and including parent. diff --git a/les/odr_requests.go b/les/odr_requests.go index cca89b1e7e..8bf12f6e8b 100644 --- a/les/odr_requests.go +++ b/les/odr_requests.go @@ -114,7 +114,7 @@ func (r *BlockRequest) Validate(db ethdb.Database, msg *Msg) error { if header == nil { return errHeaderUnavailable } - if header.TxHash != types.DeriveSha(types.Transactions(body.Transactions)) { + if header.TxHash != types.DeriveSha(types.Transactions(body.Transactions), new(trie.StackTrie)) { return errTxHashMismatch } if header.UncleHash != types.CalcUncleHash(body.Uncles) { @@ -170,7 +170,7 @@ func (r *ReceiptsRequest) Validate(db ethdb.Database, msg *Msg) error { if header == nil { return errHeaderUnavailable } - if header.ReceiptHash != types.DeriveSha(receipt) { + if header.ReceiptHash != types.DeriveSha(receipt, new(trie.StackTrie)) { return errReceiptHashMismatch } // Validations passed, store and return diff --git a/miner/worker.go b/miner/worker.go index 995c401690..a8985a2a81 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -23,6 +23,7 @@ import ( "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "github.com/tomochain/tomochain/trie" "math/big" "os" @@ -204,6 +205,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) { self.current.txs, nil, self.current.receipts, + new(trie.Trie), ), self.current.state.Copy() } return self.current.Block, self.current.state.Copy() @@ -219,6 +221,7 @@ func (self *worker) pendingBlock() *types.Block { self.current.txs, nil, self.current.receipts, + new(trie.Trie), ) } return self.current.Block diff --git a/rlp/decode.go b/rlp/decode.go index 20c454ca9c..ac93c139a9 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -28,9 +28,8 @@ import ( "strings" "sync" - "github.com/tomochain/tomochain/rlp/internal/rlpstruct" - "github.com/holiman/uint256" + "github.com/tomochain/tomochain/rlp/internal/rlpstruct" ) //lint:ignore ST1012 EOL is not an error. 
diff --git a/rlp/encode.go b/rlp/encode.go index f34be7f3df..2ca283c0a3 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -23,9 +23,8 @@ import ( "math/big" "reflect" - "github.com/tomochain/tomochain/rlp/internal/rlpstruct" - "github.com/holiman/uint256" + "github.com/tomochain/tomochain/rlp/internal/rlpstruct" ) var ( diff --git a/rlp/encode_test.go b/rlp/encode_test.go index 7b8775c12b..9f2e6c38f9 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -26,9 +26,8 @@ import ( "sync" "testing" - "github.com/tomochain/tomochain/common/math" - "github.com/holiman/uint256" + "github.com/tomochain/tomochain/common/math" ) type testEncoder struct { diff --git a/trie/database.go b/trie/database.go index bf3f2e89de..a1422dcb0a 100644 --- a/trie/database.go +++ b/trie/database.go @@ -107,11 +107,11 @@ func (n rawNode) Cache() (HashNode, bool) { panic("this should never end up in func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } func (n rawNode) EncodeRLP(w io.Writer) error { - _, err := w.Write(n) + _, err := w.Write([]byte(n)) return err } -// rawFullNode represents only the useful data content of a full Node, with the +// rawFullNode represents only the useful data content of a full node, with the // caches and flags stripped out to minimize its data storage. This type honors // the same RLP encoding as the original parent. type rawFullNode [17]Node @@ -790,6 +790,7 @@ func (db *Database) commit(hash common.Hash, batch ethdb.Batch, uncacher *cleane if err != nil { return err } + if err := batch.Put(hash[:], node.rlp()); err != nil { return err } diff --git a/trie/trie.go b/trie/trie.go index 589a96186d..a0c627d232 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -316,10 +316,12 @@ func (t *Trie) tryGetBestRightKeyAndValue(origNode Node, prefix []byte) (key []b // // The value bytes must not be modified by the caller while they are // stored in the trie. -func (t *Trie) Update(key, value []byte) { +func (t *Trie) Update(key, value []byte) error { if err := t.TryUpdate(key, value); err != nil { log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + return err } + return nil } // TryUpdate associates key with value in the trie. Subsequent calls to @@ -637,3 +639,9 @@ func (t *Trie) hashRoot(db *Database) (Node, Node, error) { t.unhashed = 0 return hashed, cached, nil } + +// Reset drops the referenced root node and cleans all internal state. 
+func (t *Trie) Reset() { + t.root = nil + t.unhashed = 0 +} From 62f26cf7cf6b284ed429441ab3dc64ce2257b3a3 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 1 Aug 2023 11:39:56 +0700 Subject: [PATCH 059/119] Minor fix --- internal/ethapi/api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 9b8bab7352..dda53d2c09 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -1062,7 +1062,7 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr GasPrice: gasPrice, Data: args.Data, BalanceTokenFee: balanceTokenFee, - SkipAccountChecks: false, + SkipAccountChecks: true, } // Setup context so it may be cancelled the call has completed From 097326cadd2e2daeac7b218579e932d4f49fc152 Mon Sep 17 00:00:00 2001 From: trinhdn2 <136422909+trinhdn2@users.noreply.github.com> Date: Wed, 2 Aug 2023 10:48:15 +0700 Subject: [PATCH 060/119] Bump client version to 2.4.0 --- params/version.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/params/version.go b/params/version.go index af4d16e53c..c220c12734 100644 --- a/params/version.go +++ b/params/version.go @@ -21,10 +21,10 @@ import ( ) const ( - VersionMajor = 2 // Major version component of the current release - VersionMinor = 3 // Minor version component of the current release - VersionPatch = 2 // Patch version component of the current release - VersionMeta = "stable" // Version metadata to append to the version string + VersionMajor = 2 // Major version component of the current release + VersionMinor = 4 // Minor version component of the current release + VersionPatch = 0 // Patch version component of the current release + VersionMeta = "dev" // Version metadata to append to the version string ) // Version holds the textual version string. From 943f5000d279927aff49c2747f226fa40aca5203 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 3 Aug 2023 14:46:22 +0700 Subject: [PATCH 061/119] Fix UpdateStorage in trie --- cmd/evm/runner.go | 9 ++++----- trie/secure_trie.go | 3 +-- trie/trie.go | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index c6def708c3..f3fcef9c7e 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -22,24 +22,23 @@ import ( "fmt" "io/ioutil" "os" + goruntime "runtime" "runtime/pprof" "time" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/trie" - - goruntime "runtime" + cli "gopkg.in/urfave/cli.v1" "github.com/tomochain/tomochain/cmd/evm/internal/compiler" "github.com/tomochain/tomochain/cmd/utils" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/core/vm/runtime" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" - cli "gopkg.in/urfave/cli.v1" + "github.com/tomochain/tomochain/trie" ) var runCommand = cli.Command{ diff --git a/trie/secure_trie.go b/trie/secure_trie.go index db95c7cc41..cbffd559e3 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -135,8 +135,7 @@ func (t *SecureTrie) MustUpdate(key, value []byte) { // If a node is not found in the database, a MissingNodeError is returned. 
func (t *SecureTrie) UpdateStorage(_ common.Address, key, value []byte) error { hk := t.hashKey(key) - v, _ := rlp.EncodeToBytes(value) - err := t.trie.Update(hk, v) + err := t.trie.Update(hk, value) if err != nil { return err } diff --git a/trie/trie.go b/trie/trie.go index 8d63ee9620..9df6e56559 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -531,7 +531,7 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { // shortNode{..., shortNode{...}}. Since the entry // might not be loaded yet, resolve it just for this // check. - cnode, err := t.resolve(n.Children[pos], append(prefix, byte(pos))) + cnode, err := t.resolve(n.Children[pos], prefix) if err != nil { return false, nil, err } From be70268ce5ff86dac130a31d058bf16201dbc3ea Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 3 Aug 2023 15:27:53 +0700 Subject: [PATCH 062/119] Fix trie unit tests --- core/rawdb/accessors_chain_test.go | 9 ++--- internal/blocktest/test_hash.go | 60 ++++++++++++++++++++++++++++++ tests/state_test.go | 3 ++ trie/secure_trie_test.go | 20 +++++----- trie/trie_test.go | 43 ++++++++++----------- 5 files changed, 97 insertions(+), 38 deletions(-) create mode 100644 internal/blocktest/test_hash.go diff --git a/core/rawdb/accessors_chain_test.go b/core/rawdb/accessors_chain_test.go index d978aa79ed..4f66f8d0f7 100644 --- a/core/rawdb/accessors_chain_test.go +++ b/core/rawdb/accessors_chain_test.go @@ -22,11 +22,10 @@ import ( "testing" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto/sha3" + "github.com/tomochain/tomochain/internal/blocktest" "github.com/tomochain/tomochain/rlp" - "github.com/tomochain/tomochain/trie" ) // Tests block header storage and retrieval operations. 
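
The test hunks below replace new(trie.StackTrie) with blocktest.NewHasher(). The helper (added later in this same commit) hashes the (index, value) stream with plain keccak256, so it does not produce a real trie root; it is only meaningful when both sides of a comparison use it, which is exactly how the tests below employ it. A small sketch of the intended usage, with made-up transaction values:

    package main

    import (
        "fmt"
        "math/big"

        "github.com/tomochain/tomochain/common"
        "github.com/tomochain/tomochain/core/types"
        "github.com/tomochain/tomochain/internal/blocktest"
    )

    func main() {
        tx := types.NewTransaction(0, common.BytesToAddress([]byte{0x11}), big.NewInt(111), 1111, big.NewInt(11111), nil)
        txs := types.Transactions{tx}

        // Two independent hashers over the same list agree with each other,
        // but the result is not comparable with a root computed by a real trie.
        a := types.DeriveSha(txs, blocktest.NewHasher())
        b := types.DeriveSha(txs, blocktest.NewHasher())
        fmt.Println(a == b) // true
    }
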
@@ -84,7 +83,7 @@ func TestBodyStorage(t *testing.T) {
 	}
 	if entry := GetBody(db, hash, 0); entry == nil {
 		t.Fatalf("Stored body not found")
-	} else if types.DeriveSha(types.Transactions(entry.Transactions), new(trie.StackTrie)) != types.DeriveSha(types.Transactions(body.Transactions), new(trie.StackTrie)) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(body.Uncles) {
+	} else if types.DeriveSha(types.Transactions(entry.Transactions), blocktest.NewHasher()) != types.DeriveSha(types.Transactions(body.Transactions), blocktest.NewHasher()) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(body.Uncles) {
 		t.Fatalf("Retrieved body mismatch: have %v, want %v", entry, body)
 	}
 	if entry := GetBodyRLP(db, hash, 0); entry == nil {
@@ -140,7 +139,7 @@ func TestBlockStorage(t *testing.T) {
 	}
 	if entry := GetBody(db, block.Hash(), block.NumberU64()); entry == nil {
 		t.Fatalf("Stored body not found")
-	} else if types.DeriveSha(types.Transactions(entry.Transactions), new(trie.StackTrie)) != types.DeriveSha(block.Transactions(), new(trie.StackTrie)) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(block.Uncles()) {
+	} else if types.DeriveSha(types.Transactions(entry.Transactions), blocktest.NewHasher()) != types.DeriveSha(block.Transactions(), blocktest.NewHasher()) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(block.Uncles()) {
 		t.Fatalf("Retrieved body mismatch: have %v, want %v", entry, block.Body())
 	}
 	// Delete the block and verify the execution
@@ -296,7 +295,7 @@ func TestLookupStorage(t *testing.T) {
 	tx3 := types.NewTransaction(3, common.BytesToAddress([]byte{0x33}), big.NewInt(333), 3333, big.NewInt(33333), []byte{0x33, 0x33, 0x33})
 	txs := []*types.Transaction{tx1, tx2, tx3}
 
-	block := types.NewBlock(&types.Header{Number: big.NewInt(314)}, txs, nil, nil, new(trie.StackTrie))
+	block := types.NewBlock(&types.Header{Number: big.NewInt(314)}, txs, nil, nil, blocktest.NewHasher())
 
 	// Check that no transactions entries are in a pristine database
 	for i, tx := range txs {
diff --git a/internal/blocktest/test_hash.go b/internal/blocktest/test_hash.go
new file mode 100644
index 0000000000..37d979e319
--- /dev/null
+++ b/internal/blocktest/test_hash.go
@@ -0,0 +1,60 @@
+// Copyright 2023 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

+// Package blocktest provides hashing helpers for tests of block-related types.
+// It exists so that tests can hash transaction and receipt lists without
+// importing the trie package, avoiding an import cycle.
+
+package blocktest
+
+import (
+	"hash"
+
+	"golang.org/x/crypto/sha3"
+
+	"github.com/tomochain/tomochain/common"
+)
+
+// testHasher is the helper tool for transaction/receipt list hashing.
+// The original hasher is trie, in order to get rid of import cycle, +// use the testing hasher instead. +type testHasher struct { + hasher hash.Hash +} + +// NewHasher returns a new testHasher instance. +func NewHasher() *testHasher { + return &testHasher{hasher: sha3.NewLegacyKeccak256()} +} + +// Reset resets the hash state. +func (h *testHasher) Reset() { + h.hasher.Reset() +} + +// Update updates the hash state with the given key and value. +func (h *testHasher) Update(key, val []byte) error { + h.hasher.Write(key) + h.hasher.Write(val) + return nil +} + +// Hash returns the hash value. +func (h *testHasher) Hash() common.Hash { + return common.BytesToHash(h.hasher.Sum(nil)) +} diff --git a/tests/state_test.go b/tests/state_test.go index 7c8c5e9268..81a7370d60 100644 --- a/tests/state_test.go +++ b/tests/state_test.go @@ -26,6 +26,9 @@ import ( ) func TestState(t *testing.T) { + if testing.Short() { + t.Skip("skipping testing in short mode") + } t.Parallel() st := new(testMatcher) diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index 7dcb5680c5..bc17b2ca40 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -44,17 +44,17 @@ func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) { // Map the same data under multiple keys key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} content[string(key)] = val - trie.Update(key, val) + trie.MustUpdate(key, val) key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} content[string(key)] = val - trie.Update(key, val) + trie.MustUpdate(key, val) // Add some other data to inflate the trie for j := byte(3); j < 13; j++ { key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} content[string(key)] = val - trie.Update(key, val) + trie.MustUpdate(key, val) } } trie.Commit(nil) @@ -77,9 +77,9 @@ func TestSecureDelete(t *testing.T) { } for _, val := range vals { if val.v != "" { - trie.Update([]byte(val.k), []byte(val.v)) + trie.MustUpdate([]byte(val.k), []byte(val.v)) } else { - trie.Delete([]byte(val.k)) + trie.MustDelete([]byte(val.k)) } } hash := trie.Hash() @@ -91,13 +91,13 @@ func TestSecureDelete(t *testing.T) { func TestSecureGetKey(t *testing.T) { trie := newEmptySecure() - trie.Update([]byte("foo"), []byte("bar")) + trie.MustUpdate([]byte("foo"), []byte("bar")) key := []byte("foo") value := []byte("bar") seckey := crypto.Keccak256(key) - if !bytes.Equal(trie.Get(key), value) { + if !bytes.Equal(trie.MustGet(key), value) { t.Errorf("Get did not return bar") } if k := trie.GetKey(seckey); !bytes.Equal(k, key) { @@ -125,15 +125,15 @@ func TestSecureTrieConcurrency(t *testing.T) { for j := byte(0); j < 255; j++ { // Map the same data under multiple keys key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j} - tries[index].Update(key, val) + tries[index].MustUpdate(key, val) key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j} - tries[index].Update(key, val) + tries[index].MustUpdate(key, val) // Add some other data to inflate the trie for k := byte(3); k < 13; k++ { key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j} - tries[index].Update(key, val) + tries[index].MustUpdate(key, val) } } tries[index].Commit(nil) diff --git a/trie/trie_test.go b/trie/trie_test.go index 8087a4a8a9..fdfcf4858b 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -29,10 +29,11 @@ import ( "testing/quick" "github.com/davecgh/go-spew/spew" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" 
"github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/ethdb/leveldb" - "github.com/tomochain/tomochain/ethdb/memorydb" "github.com/tomochain/tomochain/rlp" ) @@ -43,7 +44,7 @@ func init() { // Used for testing func newEmpty() *Trie { - trie, _ := New(common.Hash{}, NewDatabase(memorydb.New())) + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) return trie } @@ -61,13 +62,13 @@ func TestNull(t *testing.T) { key := make([]byte, 32) value := []byte("test") trie.Update(key, value) - if !bytes.Equal(trie.Get(key), value) { + if !bytes.Equal(trie.MustGet(key), value) { t.Fatal("wrong value") } } func TestMissingRoot(t *testing.T) { - trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), NewDatabase(memorydb.New())) + trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), NewDatabase(rawdb.NewMemoryDatabase())) if trie != nil { t.Error("New returned non-nil trie for invalid root") } @@ -80,7 +81,7 @@ func TestMissingNodeDisk(t *testing.T) { testMissingNode(t, false) } func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) } func testMissingNode(t *testing.T, memonly bool) { - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) trie, _ := New(common.Hash{}, triedb) @@ -92,27 +93,27 @@ func testMissingNode(t *testing.T, memonly bool) { } trie, _ = New(root, triedb) - _, err := trie.TryGet([]byte("120000")) + _, err := trie.Get([]byte("120000")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(root, triedb) - _, err = trie.TryGet([]byte("120099")) + _, err = trie.Get([]byte("120099")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(root, triedb) - _, err = trie.TryGet([]byte("123456")) + _, err = trie.Get([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(root, triedb) - err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) + err = trie.Update([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(root, triedb) - err = trie.TryDelete([]byte("123456")) + err = trie.Delete([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -125,27 +126,27 @@ func testMissingNode(t *testing.T, memonly bool) { } trie, _ = New(root, triedb) - _, err = trie.TryGet([]byte("120000")) + _, err = trie.Get([]byte("120000")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } trie, _ = New(root, triedb) - _, err = trie.TryGet([]byte("120099")) + _, err = trie.Get([]byte("120099")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } trie, _ = New(root, triedb) - _, err = trie.TryGet([]byte("123456")) + _, err = trie.Get([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(root, triedb) - err = trie.TryUpdate([]byte("120099"), []byte("zxcv")) + err = trie.Update([]byte("120099"), []byte("zxcv")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } trie, _ = New(root, triedb) - err = trie.TryDelete([]byte("123456")) + err = trie.Delete([]byte("123456")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } @@ -403,7 +404,7 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value { } func runRandTest(rt randTest) bool { - triedb := NewDatabase(memorydb.New()) + triedb := NewDatabase(rawdb.NewMemoryDatabase()) tr, _ := New(common.Hash{}, 
triedb) values := make(map[string]string) // tracks content of the trie @@ -419,7 +420,7 @@ func runRandTest(rt randTest) bool { tr.Delete(step.key) delete(values, string(step.key)) case opGet: - v := tr.Get(step.key) + v := tr.MustGet(step.key) want := values[string(step.key)] if string(v) != want { rt[i].err = fmt.Errorf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want) @@ -823,15 +824,11 @@ func tempDB() (string, *Database) { if err != nil { panic(fmt.Sprintf("can't create temporary directory: %v", err)) } - diskdb, err := leveldb.New(dir, 256, 0, "") - if err != nil { - panic(fmt.Sprintf("can't create temporary database: %v", err)) - } - return dir, NewDatabase(diskdb) + return dir, NewDatabase(rawdb.NewMemoryDatabase()) } func getString(trie *Trie, k string) []byte { - return trie.Get([]byte(k)) + return trie.MustGet([]byte(k)) } func updateString(trie *Trie, k, v string) { From bf418512602daa7c999aefd52eeb989d09faa8ae Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 3 Aug 2023 17:43:22 +0700 Subject: [PATCH 063/119] Fix GetCommittedState and GetState --- core/state/state_object.go | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/core/state/state_object.go b/core/state/state_object.go index bb40953b69..b4d5c915a4 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -163,18 +163,12 @@ func (c *stateObject) getTrie(db Database) Trie { func (self *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { value := common.Hash{} // Load from DB in case it is missing. - enc, err := self.getTrie(db).GetStorage(self.address, key.Bytes()) + val, err := self.getTrie(db).GetStorage(self.address, key.Bytes()) if err != nil { self.setError(err) return common.Hash{} } - if len(enc) > 0 { - _, content, _, err := rlp.Split(enc) - if err != nil { - self.setError(err) - } - value.SetBytes(content) - } + value.SetBytes(val) return value } @@ -184,18 +178,13 @@ func (self *stateObject) GetState(db Database, key common.Hash) common.Hash { return value } // Load from DB in case it is missing. 
- enc, err := self.getTrie(db).GetStorage(self.address, key.Bytes()) + val, err := self.getTrie(db).GetStorage(self.address, key.Bytes()) if err != nil { self.setError(err) return common.Hash{} } - if len(enc) > 0 { - _, content, _, err := rlp.Split(enc) - if err != nil { - self.setError(err) - } - value.SetBytes(content) - } + + value.SetBytes(val) if (value != common.Hash{}) { self.cachedStorage[key] = value } From 12d08d02832abea6b30b1e7961d81d0674458236 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Fri, 4 Aug 2023 11:52:06 +0700 Subject: [PATCH 064/119] Fix state test --- core/state/state_test.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/core/state/state_test.go b/core/state/state_test.go index 30cca6c361..85cb7ee5b7 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -18,14 +18,16 @@ package state import ( "bytes" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" + checker "gopkg.in/check.v1" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/ethdb" - checker "gopkg.in/check.v1" + "github.com/tomochain/tomochain/trie" ) type StateSuite struct { @@ -88,8 +90,9 @@ func (s *StateSuite) TestDump(c *checker.C) { } func (s *StateSuite) SetUpTest(c *checker.C) { - s.db= rawdb.NewMemoryDatabase() - s.state, _ = New(common.Hash{}, NewDatabase(s.db)) + s.db = rawdb.NewMemoryDatabase() + tdb := NewDatabaseWithConfig(s.db, &trie.Config{Preimages: true}) + s.state, _ = New(common.Hash{}, tdb) } func (s *StateSuite) TestNull(c *checker.C) { From 47e5d54ee25c9d7e0034e85518cd11ca2801b775 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Fri, 4 Aug 2023 13:59:14 +0700 Subject: [PATCH 065/119] Implement AddUncheckedTx in chain maker --- core/chain_makers.go | 11 ++++++++++- core/database_util.go | 3 --- eth/filters/filter_test.go | 11 ++++++++--- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/core/chain_makers.go b/core/chain_makers.go index ac7c311fd2..986556ca31 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -18,12 +18,12 @@ package core import ( "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/misc" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -115,6 +115,15 @@ func (b *BlockGen) AddTxWithChain(bc *BlockChain, tx *types.Transaction) { } } +// AddUncheckedTx forcefully adds a transaction to the block without any +// validation. +// +// AddUncheckedTx will cause consensus failures when used during real +// chain processing. This is best used in conjunction with raw block insertion. +func (b *BlockGen) AddUncheckedTx(tx *types.Transaction) { + b.txs = append(b.txs, tx) +} + // Number returns the block number of the block being generated. 
func (b *BlockGen) Number() *big.Int { return new(big.Int).Set(b.header.Number) diff --git a/core/database_util.go b/core/database_util.go index 0ee4eb0e98..e0fcd7f489 100644 --- a/core/database_util.go +++ b/core/database_util.go @@ -24,8 +24,6 @@ import ( "fmt" "math/big" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" @@ -344,7 +342,6 @@ func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64, config if receipts == nil { return nil } - body := GetBody(db, hash, number) if body == nil { log.Error("Missing body but have receipt", "hash", hash, "number", number) diff --git a/eth/filters/filter_test.go b/eth/filters/filter_test.go index bdfb6e37f8..fb619f18fb 100644 --- a/eth/filters/filter_test.go +++ b/eth/filters/filter_test.go @@ -18,7 +18,6 @@ package filters import ( "context" - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "os" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/event" @@ -50,7 +50,7 @@ func BenchmarkFilters(b *testing.B) { defer os.RemoveAll(dir) var ( - db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0,"") + db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0, "") mux = new(event.TypeMux) txFeed = new(event.Feed) rmLogsFeed = new(event.Feed) @@ -115,7 +115,7 @@ func TestFilters(t *testing.T) { defer os.RemoveAll(dir) var ( - db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0,"") + db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0, "") mux = new(event.TypeMux) txFeed = new(event.Feed) rmLogsFeed = new(event.Feed) @@ -144,6 +144,7 @@ func TestFilters(t *testing.T) { }, } gen.AddUncheckedReceipt(receipt) + gen.AddUncheckedTx(types.NewTransaction(999, common.HexToAddress("0x999"), big.NewInt(999), 999, nil, nil)) case 2: receipt := types.NewReceipt(nil, false, 0) receipt.Logs = []*types.Log{ @@ -153,6 +154,7 @@ func TestFilters(t *testing.T) { }, } gen.AddUncheckedReceipt(receipt) + gen.AddUncheckedTx(types.NewTransaction(999, common.HexToAddress("0x999"), big.NewInt(999), 999, nil, nil)) case 998: receipt := types.NewReceipt(nil, false, 0) receipt.Logs = []*types.Log{ @@ -162,6 +164,7 @@ func TestFilters(t *testing.T) { }, } gen.AddUncheckedReceipt(receipt) + gen.AddUncheckedTx(types.NewTransaction(999, common.HexToAddress("0x999"), big.NewInt(999), 999, nil, nil)) case 999: receipt := types.NewReceipt(nil, false, 0) receipt.Logs = []*types.Log{ @@ -171,8 +174,10 @@ func TestFilters(t *testing.T) { }, } gen.AddUncheckedReceipt(receipt) + gen.AddUncheckedTx(types.NewTransaction(999, common.HexToAddress("0x999"), big.NewInt(999), 999, nil, nil)) } }) + for i, block := range chain { core.WriteBlock(db, block) if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { From 0d0b12417aa720d04016f529ed6726cb8d0f1feb Mon Sep 17 00:00:00 2001 From: Enda Dinh <90235926+endadinh@users.noreply.github.com> Date: Fri, 14 Jul 2023 17:35:13 +0700 Subject: [PATCH 066/119] Dynamic state snapshots --- accounts/abi/bind/backends/simulated.go | 6 +- cmd/evm/runner.go | 4 +- cmd/evm/staterunner.go | 2 +- cmd/tomo/chaincmd.go | 3 +- cmd/tomo/main.go | 2 + cmd/tomo/usage.go | 18 +- cmd/utils/flags.go | 9 + common/bytes.go | 22 + core/blockchain.go | 116 +++- 
 core/blockchain_test.go | 4 +-
 core/chain_makers.go | 2 +-
 core/genesis.go | 5 +-
 core/rawdb/accessors_snapshot.go | 135 ++++
 core/rawdb/accessors_state.go | 2 +-
 core/rawdb/database.go | 150 +++++
 core/rawdb/key_length_iterator.go | 47 ++
 core/rawdb/schema.go | 55 +-
 core/state/iterator_test.go | 2 +-
 core/state/managed_state_test.go | 4 +-
 core/state/snapshot/account.go | 54 ++
 core/state/snapshot/difflayer.go | 533 ++++++++++++++++
 core/state/snapshot/difflayer_test.go | 399 ++++++++++++
 core/state/snapshot/disklayer.go | 166 +++++
 core/state/snapshot/generate.go | 284 +++++++++
 core/state/snapshot/iterator.go | 221 +++++++
 core/state/snapshot/iterator_binary.go | 115 ++++
 core/state/snapshot/iterator_fast.go | 302 +++++++++
 core/state/snapshot/journal.go | 243 +++++++
 core/state/snapshot/snapshot.go | 597 ++++++++++++++++++
 core/state/snapshot/snapshot_test.go | 348 ++++++++++
 core/state/snapshot/sort.go | 36 ++
 core/state/state_object.go | 2 +-
 core/state/state_test.go | 2 +-
 core/state/statedb.go | 13 +-
 core/state/sync_test.go | 6 +-
 core/tx_pool_test.go | 37 +-
 core/vm/gas_table_test.go | 7 +-
 core/vm/runtime/runtime.go | 7 +-
 core/vm/runtime/runtime_test.go | 7 +-
 eth/api_test.go | 5 +-
 eth/api_tracer.go | 8 +-
 eth/backend.go | 8 +-
 eth/config.go | 2 +
 eth/fetcher/fetcher.go | 1 +
 eth/handler_test.go | 11 +-
 eth/tracers/tracers_test.go | 4 +-
 go.mod | 2 +-
 go.sum | 4 +-
 les/odr_test.go | 6 +-
 light/odr_test.go | 7 +-
 light/trie.go | 2 +-
 tests/state_test.go | 16 +-
 tests/state_test_util.go | 16 +-
 tests/vm_test.go | 5 +-
 tests/vm_test_util.go | 7 +-
 tomoxlending/lendingstate/lendingitem_test.go | 13 +-
 56 files changed, 3960 insertions(+), 124 deletions(-)
 create mode 100644 core/rawdb/accessors_snapshot.go
 create mode 100644 core/rawdb/key_length_iterator.go
 create mode 100644 core/state/snapshot/account.go
 create mode 100644 core/state/snapshot/difflayer.go
 create mode 100644 core/state/snapshot/difflayer_test.go
 create mode 100644 core/state/snapshot/disklayer.go
 create mode 100644 core/state/snapshot/generate.go
 create mode 100644 core/state/snapshot/iterator.go
 create mode 100644 core/state/snapshot/iterator_binary.go
 create mode 100644 core/state/snapshot/iterator_fast.go
 create mode 100644 core/state/snapshot/journal.go
 create mode 100644 core/state/snapshot/snapshot.go
 create mode 100644 core/state/snapshot/snapshot_test.go
 create mode 100644 core/state/snapshot/sort.go

diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go
index 6c7ff609de..d960036fba 100644
--- a/accounts/abi/bind/backends/simulated.go
+++ b/accounts/abi/bind/backends/simulated.go
@@ -107,7 +107,7 @@ func (b *SimulatedBackend) rollback() {
 	statedb, _ := b.blockchain.State()
 
 	b.pendingBlock = blocks[0]
-	b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
+	b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database(), nil)
 }
 
 // CodeAt returns the code associated with a certain account in the blockchain.
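
As the simulated.go hunk above illustrates, this commit threads a new third argument through state.New: the snapshot tree. A minimal sketch of the call pattern (nil opts out of snapshots and preserves the old trie-only behaviour, exactly as in the genesis.go and chain_makers.go hunks later in this commit):

    package main

    import (
        "fmt"
        "math/big"

        "github.com/tomochain/tomochain/common"
        "github.com/tomochain/tomochain/core/rawdb"
        "github.com/tomochain/tomochain/core/state"
    )

    func main() {
        db := rawdb.NewMemoryDatabase()

        // nil snapshot tree: identical behaviour to the old two-argument state.New.
        statedb, err := state.New(common.Hash{}, state.NewDatabase(db), nil)
        if err != nil {
            panic(err)
        }
        statedb.AddBalance(common.Address{0x01}, big.NewInt(1))
        fmt.Println(statedb.GetBalance(common.Address{0x01}))
    }

When a snapshot tree is supplied (the blockchain passes bc.snaps, as the core/blockchain.go hunks below show), state reads can be served from the flat snapshot layers instead of traversing the trie, which is the point of the whole change.
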
@@ -383,7 +383,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa statedb, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database()) + b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database(), nil) return nil } @@ -462,7 +462,7 @@ func (b *SimulatedBackend) AdjustTime(adjustment time.Duration) error { statedb, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database()) + b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database(), nil) return nil } diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index f3fcef9c7e..75abb768ce 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -100,11 +100,11 @@ func runCmd(ctx *cli.Context) error { gen := readGenesis(ctx.GlobalString(GenesisFlag.Name)) db := rawdb.NewMemoryDatabase() genesis := gen.ToBlock(db) - statedb, _ = state.New(genesis.Root(), state.NewDatabaseWithConfig(db, &trie.Config{Preimages: preimages})) + statedb, _ = state.New(genesis.Root(), state.NewDatabaseWithConfig(db, &trie.Config{Preimages: preimages}), nil) chainConfig = gen.Config } else { db := rawdb.NewMemoryDatabase() - statedb, _ = state.New(common.Hash{}, state.NewDatabaseWithConfig(db, &trie.Config{Preimages: preimages})) + statedb, _ = state.New(common.Hash{}, state.NewDatabaseWithConfig(db, &trie.Config{Preimages: preimages}), nil) } if ctx.GlobalString(SenderFlag.Name) != "" { sender = common.HexToAddress(ctx.GlobalString(SenderFlag.Name)) diff --git a/cmd/evm/staterunner.go b/cmd/evm/staterunner.go index 5499be6962..018a7c5262 100644 --- a/cmd/evm/staterunner.go +++ b/cmd/evm/staterunner.go @@ -94,7 +94,7 @@ func stateTestCmd(ctx *cli.Context) error { for _, st := range test.Subtests() { // Run the test and aggregate the result result := &StatetestResult{Name: key, Fork: st.Fork, Pass: true} - state, err := test.Run(st, cfg) + state, err := test.Run(st, cfg, false) if err != nil { // Test failed, mark as so and dump any state to aid debugging result.Pass, result.Error = false, err.Error() diff --git a/cmd/tomo/chaincmd.go b/cmd/tomo/chaincmd.go index dc0a274ba5..e1c23c0cdf 100644 --- a/cmd/tomo/chaincmd.go +++ b/cmd/tomo/chaincmd.go @@ -66,6 +66,7 @@ It expects the genesis file as argument.`, utils.CacheFlag, utils.LightModeFlag, utils.GCModeFlag, + utils.SnapshotFlag, utils.CacheDatabaseFlag, utils.CacheGCFlag, }, @@ -450,7 +451,7 @@ func dump(ctx *cli.Context) error { fmt.Println("{}") utils.Fatalf("block not found") } else { - state, err := state.New(block.Root(), state.NewDatabase(chainDb)) + state, err := state.New(block.Root(), state.NewDatabase(chainDb), nil) if err != nil { utils.Fatalf("could not create new state: %v", err) } diff --git a/cmd/tomo/main.go b/cmd/tomo/main.go index 2a606fbb78..1b08a78ced 100644 --- a/cmd/tomo/main.go +++ b/cmd/tomo/main.go @@ -86,6 +86,7 @@ var ( utils.LightModeFlag, utils.SyncModeFlag, utils.GCModeFlag, + utils.SnapshotFlag, //utils.LightServFlag, //utils.LightPeersFlag, //utils.LightKDFFlag, @@ -93,6 +94,7 @@ var ( //utils.CacheDatabaseFlag, //utils.CacheGCFlag, //utils.TrieCacheGenFlag, + utils.CacheSnapshotFlag, utils.ListenPortFlag, utils.MaxPeersFlag, utils.MaxPendingPeersFlag, diff --git a/cmd/tomo/usage.go b/cmd/tomo/usage.go index f166d9aae9..8840af8394 100644 --- a/cmd/tomo/usage.go +++ b/cmd/tomo/usage.go @@ -123,15 +123,15 @@ var AppHelpFlagGroups = []flagGroup{ // utils.TxPoolLifetimeFlag, // }, //}, - //{ - // 
Name: "PERFORMANCE TUNING", - // Flags: []cli.Flag{ - // utils.CacheFlag, - // utils.CacheDatabaseFlag, - // utils.CacheGCFlag, - // utils.TrieCacheGenFlag, - // }, - //}, + { + Name: "PERFORMANCE TUNING", + Flags: []cli.Flag{ + utils.CacheFlag, + utils.CacheDatabaseFlag, + utils.CacheGCFlag, + utils.CacheSnapshotFlag, + }, + }, { Name: "ACCOUNT", Flags: []cli.Flag{ diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 59a0cdeaf0..90e85528b2 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -190,6 +190,10 @@ var ( Usage: `Blockchain garbage collection mode ("full", "archive")`, Value: "full", } + SnapshotFlag = cli.BoolFlag{ + Name: "snapshot", + Usage: `Enables snapshot-database mode -- experimental work in progress feature`, + } LightServFlag = cli.IntFlag{ Name: "lightserv", Usage: "Maximum percentage of time allowed for serving LES requests (0-90)", @@ -305,6 +309,11 @@ var ( Usage: "Percentage of cache memory allowance to use for trie pruning", Value: 25, } + CacheSnapshotFlag = cli.IntFlag{ + Name: "cache.snapshot", + Usage: "Percentage of cache memory allowance to use for snapshot caching (default = 10% full mode, 20% archive mode)", + Value: 10, + } // Miner settings StakingEnabledFlag = cli.BoolFlag{ Name: "mine", diff --git a/common/bytes.go b/common/bytes.go index ba00e8a4b2..1801cb1cae 100644 --- a/common/bytes.go +++ b/common/bytes.go @@ -119,3 +119,25 @@ func LeftPadBytes(slice []byte, l int) []byte { return padded } + +// TrimLeftZeroes returns a subslice of s without leading zeroes +func TrimLeftZeroes(s []byte) []byte { + idx := 0 + for ; idx < len(s); idx++ { + if s[idx] != 0 { + break + } + } + return s[idx:] +} + +// TrimRightZeroes returns a subslice of s without trailing zeroes +func TrimRightZeroes(s []byte) []byte { + idx := len(s) + for ; idx > 0; idx-- { + if s[idx-1] != 0 { + break + } + } + return s[:idx] +} diff --git a/core/blockchain.go b/core/blockchain.go index 2d77d154d1..e6afc1f78f 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -28,6 +28,8 @@ import ( "sync/atomic" "time" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/tomoxlending/lendingstate" "gopkg.in/karalabe/cookiejar.v2/collections/prque" lru "github.com/hashicorp/golang-lru" @@ -37,8 +39,8 @@ import ( "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/posv" contractValidator "github.com/tomochain/tomochain/contracts/validator/contract" - "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" + "github.com/tomochain/tomochain/core/state/snapshot" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" @@ -50,14 +52,38 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending/lendingstate" "github.com/tomochain/tomochain/trie" ) var ( - blockInsertTimer = metrics.NewRegisteredTimer("chain/inserts", nil) - CheckpointCh = make(chan int) - ErrNoGenesis = errors.New("Genesis not found in chain") + accountReadTimer = metrics.NewRegisteredTimer("chain/account/reads", nil) + accountHashTimer = metrics.NewRegisteredTimer("chain/account/hashes", nil) + accountUpdateTimer = metrics.NewRegisteredTimer("chain/account/updates", nil) + accountCommitTimer = metrics.NewRegisteredTimer("chain/account/commits", nil) + + storageReadTimer = metrics.NewRegisteredTimer("chain/storage/reads", nil) + 
storageHashTimer = metrics.NewRegisteredTimer("chain/storage/hashes", nil)
+	storageUpdateTimer = metrics.NewRegisteredTimer("chain/storage/updates", nil)
+	storageCommitTimer = metrics.NewRegisteredTimer("chain/storage/commits", nil)
+
+	snapshotAccountReadTimer = metrics.NewRegisteredTimer("chain/snapshot/account/reads", nil)
+	snapshotStorageReadTimer = metrics.NewRegisteredTimer("chain/snapshot/storage/reads", nil)
+	snapshotCommitTimer = metrics.NewRegisteredTimer("chain/snapshot/commits", nil)
+
+	blockInsertTimer = metrics.NewRegisteredTimer("chain/inserts", nil)
+	blockValidationTimer = metrics.NewRegisteredTimer("chain/validation", nil)
+	blockExecutionTimer = metrics.NewRegisteredTimer("chain/execution", nil)
+	blockWriteTimer = metrics.NewRegisteredTimer("chain/write", nil)
+	blockReorgAddMeter = metrics.NewRegisteredMeter("chain/reorg/add", nil)
+	blockReorgDropMeter = metrics.NewRegisteredMeter("chain/reorg/drop", nil)
+
+	blockPrefetchExecuteTimer = metrics.NewRegisteredTimer("chain/prefetch/executes", nil)
+	blockPrefetchInterruptMeter = metrics.NewRegisteredMeter("chain/prefetch/interrupts", nil)
+
+	errInsertionInterrupted = errors.New("insertion is interrupted")
+
+	CheckpointCh = make(chan int)
+	ErrNoGenesis = errors.New("Genesis not found in chain")
 )
 
 const (
@@ -81,6 +107,9 @@ type CacheConfig struct {
 	Disabled      bool          // Whether to disable trie write caching (archive node)
 	TrieNodeLimit int           // Memory limit (MB) at which to flush the current in-memory trie to disk
 	TrieTimeLimit time.Duration // Time limit after which to flush the current in-memory trie to disk
+	SnapshotLimit int           // Memory allowance (MB) to use for caching snapshot entries in memory
+
+	SnapshotWait bool // Wait for snapshot construction on startup. TODO(karalabe): This is a dirty hack for testing, nuke it
 }
 
 // defaultCacheConfig are the default caching values if none are specified by the
@@ -120,8 +149,9 @@ type BlockChain struct {
 	db      ethdb.Database // Low level persistent database to store final content in
 	tomoxDb ethdb.TomoxDatabase
 
-	triegc *prque.Prque  // Priority queue mapping block numbers to tries to gc
-	gcproc time.Duration // Accumulates canonical block processing for trie dumping
+	snaps  *snapshot.Tree // Snapshot tree for fast trie leaf access
+	triegc *prque.Prque   // Priority queue mapping block numbers to tries to gc
+	gcproc time.Duration  // Accumulates canonical block processing for trie dumping
 
 	hc            *HeaderChain
 	rmLogsFeed    event.Feed
@@ -183,6 +213,8 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
 		cacheConfig = &CacheConfig{
 			TrieNodeLimit: 256 * 1024 * 1024,
 			TrieTimeLimit: 5 * time.Minute,
+			SnapshotLimit: 256,
+			SnapshotWait:  true,
 		}
 	}
 	bodyCache, _ := lru.New(bodyCacheLimit)
@@ -255,6 +287,10 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
 			}
 		}
 	}
+	// Load any existing snapshot, regenerating it if loading failed
+	if bc.cacheConfig.SnapshotLimit > 0 {
+		bc.snaps = snapshot.New(bc.db, bc.stateCache.TrieDB(), bc.cacheConfig.SnapshotLimit, bc.CurrentBlock().Root(), !bc.cacheConfig.SnapshotWait)
+	}
 	// Take ownership of this particular state
 	go bc.update()
 	return bc, nil
@@ -297,12 +333,25 @@ func (bc *BlockChain) loadLastState() error {
 		log.Warn("Head block missing, resetting chain", "hash", head)
 		return bc.Reset()
 	}
+	// Make sure the state associated with the block is available
+	if _, err := state.New(currentBlock.Root(), bc.stateCache, bc.snaps); err != nil {
+		// Dangling block without a state associated, init from
scratch
+		log.Warn("Head state missing, repairing chain", "number", currentBlock.Number(), "hash", currentBlock.Hash())
+		if err := bc.repair(&currentBlock); err != nil {
+			return err
+		}
+		rawdb.WriteHeadBlockHash(bc.db, currentBlock.Hash())
+	}
+
+	// Everything seems to be fine, set as the head block
+	bc.currentBlock.Store(currentBlock)
+
 	repair := false
 	if common.Rewound != uint64(0) {
 		repair = true
 	}
 	// Make sure the state associated with the block is available
-	_, err := state.New(currentBlock.Root(), bc.stateCache)
+	_, err := state.New(currentBlock.Root(), bc.stateCache, bc.snaps)
 	if err != nil {
 		repair = true
 	} else {
@@ -410,7 +460,7 @@ func (bc *BlockChain) SetHead(head uint64) error {
 		bc.currentBlock.Store(bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64()))
 	}
 	if currentBlock := bc.CurrentBlock(); currentBlock != nil {
-		if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil {
+		if _, err := state.New(currentBlock.Root(), bc.stateCache, bc.snaps); err != nil {
 			// Rewound state missing, rolled back to before pivot, reset to genesis
 			bc.currentBlock.Store(bc.genesisBlock)
 		}
@@ -453,6 +503,11 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error {
 	bc.currentBlock.Store(block)
 	bc.mu.Unlock()
 
+	// Destroy any existing state snapshot and regenerate it in the background
+	if bc.snaps != nil {
+		log.Info("Rebuilding state snapshot in background", "root", block.Root())
+		bc.snaps.Rebuild(block.Root())
+	}
 	log.Info("Committed new head block", "number", block.Number(), "hash", hash)
 	return nil
 }
@@ -509,7 +564,7 @@ func (bc *BlockChain) State() (*state.StateDB, error) {
 
 // StateAt returns a new mutable state based on a particular point in time.
 func (bc *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) {
	return state.New(root, bc.stateCache, bc.snaps)
 }
 
 // OrderStateAt returns a new mutable state based on a particular point in time.
@@ -600,7 +655,7 @@ func (bc *BlockChain) repair(head **types.Block) error {
 	for {
 		// Abort if we've rewound to a head block that does have associated state
 		if (common.Rewound == uint64(0)) || ((*head).Number().Uint64() < common.Rewound) {
-			if _, err := state.New((*head).Root(), bc.stateCache); err == nil {
+			if _, err := state.New((*head).Root(), bc.stateCache, bc.snaps); err == nil {
 				log.Info("Rewound blockchain to past state", "number", (*head).Number(), "hash", (*head).Hash())
 				engine, ok := bc.Engine().(*posv.Posv)
 				if ok {
@@ -882,6 +937,14 @@ func (bc *BlockChain) SaveData() {
 	// Make sure no inconsistent state is leaked during insertion
 	bc.mu.Lock()
 	defer bc.mu.Unlock()
+	// Ensure that the entirety of the state snapshot is journalled to disk.
+	var snapBase common.Hash
+	if bc.snaps != nil {
+		var err error
+		if snapBase, err = bc.snaps.Journal(bc.CurrentBlock().Root()); err != nil {
+			log.Error("Failed to journal state snapshot", "err", err)
+		}
+	}
 	// Ensure the state of a recent block is also stored to disk before exiting.
// We're writing three different states to catch different restart scenarios: // - HEAD: So we don't need to reprocess any blocks in the general case @@ -933,6 +996,12 @@ func (bc *BlockChain) SaveData() { } } } + if snapBase != (common.Hash{}) { + log.Info("Writing snapshot state to disk", "root", snapBase) + if err := triedb.Commit(snapBase, true); err != nil { + log.Error("Failed to commit recent state trie", "err", err) + } + } for !bc.triegc.Empty() { triedb.Dereference(bc.triegc.PopItem().(common.Hash)) } @@ -1536,7 +1605,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty } else { parent = chain[i-1] } - statedb, err := state.New(parent.Root(), bc.stateCache) + statedb, err := state.New(parent.Root(), bc.stateCache, bc.snaps) if err != nil { return i, events, coalescedLogs, err } @@ -1642,6 +1711,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty } feeCapacity := state.GetTRC21FeeCapacityFromStateWithCache(parent.Root(), statedb) // Process block using the parent state as reference point. + substart := time.Now() receipts, logs, usedGas, err := bc.processor.Process(block, statedb, tradingState, bc.vmConfig, feeCapacity) if err != nil { bc.reportBlock(block, receipts, err) @@ -1653,12 +1723,32 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty bc.reportBlock(block, receipts, err) return i, events, coalescedLogs, err } + // Update the metrics touched during block processing + accountReadTimer.Update(statedb.AccountReads) // Account reads are complete, we can mark them + storageReadTimer.Update(statedb.StorageReads) // Storage reads are complete, we can mark them + accountUpdateTimer.Update(statedb.AccountUpdates) // Account updates are complete, we can mark them + storageUpdateTimer.Update(statedb.StorageUpdates) // Storage updates are complete, we can mark them + snapshotAccountReadTimer.Update(statedb.SnapshotAccountReads) // Account reads are complete, we can mark them + snapshotStorageReadTimer.Update(statedb.SnapshotStorageReads) // Storage reads are complete, we can mark them + + triehash := statedb.AccountHashes + statedb.StorageHashes // Save to not double count in validation + trieproc := statedb.SnapshotAccountReads + statedb.AccountReads + statedb.AccountUpdates + trieproc += statedb.SnapshotStorageReads + statedb.StorageReads + statedb.StorageUpdates + + blockExecutionTimer.Update(time.Since(substart) - trieproc - triehash) + proctime := time.Since(bstart) // Write the block to the chain and get the status. status, err := bc.WriteBlockWithState(block, receipts, statedb, tradingState, lendingState) if err != nil { return i, events, coalescedLogs, err } + + // Update the metrics touched during block commit + accountCommitTimer.Update(statedb.AccountCommits) // Account commits are complete, we can mark them + storageCommitTimer.Update(statedb.StorageCommits) // Storage commits are complete, we can mark them + snapshotCommitTimer.Update(statedb.SnapshotCommits) // Snapshot commits are complete, we can mark them + if bc.chainConfig.Posv != nil { c := bc.engine.(*posv.Posv) coinbase := c.Signer() @@ -1828,7 +1918,7 @@ func (bc *BlockChain) getResultBlock(block *types.Block, verifiedM2 bool) (*Resu // Create a new statedb using the parent block and report an // error if it fails. 
var parent = bc.GetBlock(block.ParentHash(), block.NumberU64()-1)
-	statedb, err := state.New(parent.Root(), bc.stateCache)
+	statedb, err := state.New(parent.Root(), bc.stateCache, bc.snaps)
 	if err != nil {
 		return nil, err
 	}
diff --git a/core/blockchain_test.go b/core/blockchain_test.go
index dd8ea644e3..08ed2f3ba9 100644
--- a/core/blockchain_test.go
+++ b/core/blockchain_test.go
@@ -24,6 +24,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/tomochain/tomochain/core/rawdb"
+
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/consensus/ethash"
-	"github.com/tomochain/tomochain/core/rawdb"
@@ -114,7 +115,7 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
 		}
 			return err
 		}
-		statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.stateCache)
+		statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.stateCache, nil)
 		if err != nil {
 			return err
 		}
diff --git a/core/chain_makers.go b/core/chain_makers.go
index 986556ca31..1e2aeb88b0 100644
--- a/core/chain_makers.go
+++ b/core/chain_makers.go
@@ -234,7 +234,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
 		return nil, nil
 	}
 	for i := 0; i < n; i++ {
-		statedb, err := state.New(parent.Root(), state.NewDatabase(db))
+		statedb, err := state.New(parent.Root(), state.NewDatabase(db), nil)
 		if err != nil {
 			panic(err)
 		}
diff --git a/core/genesis.go b/core/genesis.go
index d068960cb1..77970085fd 100644
--- a/core/genesis.go
+++ b/core/genesis.go
@@ -25,10 +25,11 @@ import (
 	"math/big"
 	"strings"
 
+	"github.com/tomochain/tomochain/core/rawdb"
+
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/common/hexutil"
 	"github.com/tomochain/tomochain/common/math"
-	"github.com/tomochain/tomochain/core/rawdb"
 	"github.com/tomochain/tomochain/core/state"
 	"github.com/tomochain/tomochain/core/types"
 	"github.com/tomochain/tomochain/ethdb"
@@ -227,7 +228,7 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block {
 	if db == nil {
 		db = rawdb.NewMemoryDatabase()
 	}
-	statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+	statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
 	for addr, account := range g.Alloc {
 		statedb.AddBalance(addr, account.Balance)
 		statedb.SetCode(addr, account.Code)
diff --git a/core/rawdb/accessors_snapshot.go b/core/rawdb/accessors_snapshot.go
new file mode 100644
index 0000000000..6ef285019b
--- /dev/null
+++ b/core/rawdb/accessors_snapshot.go
@@ -0,0 +1,135 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rawdb
+
+import (
+	"github.com/tomochain/tomochain/common"
+	"github.com/tomochain/tomochain/ethdb"
+	"github.com/tomochain/tomochain/log"
+)
+
+// ReadSnapshotRoot retrieves the root of the block whose state is contained in
+// the persisted snapshot.
+func ReadSnapshotRoot(db ethdb.KeyValueReader) common.Hash {
+	data, _ := db.Get(snapshotRootKey)
+	if len(data) != common.HashLength {
+		return common.Hash{}
+	}
+	return common.BytesToHash(data)
+}
+
+// WriteSnapshotRoot stores the root of the block whose state is contained in
+// the persisted snapshot.
+func WriteSnapshotRoot(db ethdb.KeyValueWriter, root common.Hash) {
+	if err := db.Put(snapshotRootKey, root[:]); err != nil {
+		log.Crit("Failed to store snapshot root", "err", err)
+	}
+}
+
+// DeleteSnapshotRoot deletes the hash of the block whose state is contained in
+// the persisted snapshot. Since snapshots are not immutable, this method can
+// be used during updates, so a crash or failure will mark the entire snapshot
+// invalid.
+func DeleteSnapshotRoot(db ethdb.KeyValueWriter) {
+	if err := db.Delete(snapshotRootKey); err != nil {
+		log.Crit("Failed to remove snapshot root", "err", err)
+	}
+}
+
+// ReadAccountSnapshot retrieves the snapshot entry of an account trie leaf.
+func ReadAccountSnapshot(db ethdb.KeyValueReader, hash common.Hash) []byte {
+	data, _ := db.Get(accountSnapshotKey(hash))
+	return data
+}
+
+// WriteAccountSnapshot stores the snapshot entry of an account trie leaf.
+func WriteAccountSnapshot(db ethdb.KeyValueWriter, hash common.Hash, entry []byte) {
+	if err := db.Put(accountSnapshotKey(hash), entry); err != nil {
+		log.Crit("Failed to store account snapshot", "err", err)
+	}
+}
+
+// DeleteAccountSnapshot removes the snapshot entry of an account trie leaf.
+func DeleteAccountSnapshot(db ethdb.KeyValueWriter, hash common.Hash) {
+	if err := db.Delete(accountSnapshotKey(hash)); err != nil {
+		log.Crit("Failed to delete account snapshot", "err", err)
+	}
+}
+
+// ReadStorageSnapshot retrieves the snapshot entry of a storage trie leaf.
+func ReadStorageSnapshot(db ethdb.KeyValueReader, accountHash, storageHash common.Hash) []byte {
+	data, _ := db.Get(storageSnapshotKey(accountHash, storageHash))
+	return data
+}
+
+// WriteStorageSnapshot stores the snapshot entry of a storage trie leaf.
+func WriteStorageSnapshot(db ethdb.KeyValueWriter, accountHash, storageHash common.Hash, entry []byte) {
+	if err := db.Put(storageSnapshotKey(accountHash, storageHash), entry); err != nil {
+		log.Crit("Failed to store storage snapshot", "err", err)
+	}
+}
+
+// DeleteStorageSnapshot removes the snapshot entry of a storage trie leaf.
+func DeleteStorageSnapshot(db ethdb.KeyValueWriter, accountHash, storageHash common.Hash) {
+	if err := db.Delete(storageSnapshotKey(accountHash, storageHash)); err != nil {
+		log.Crit("Failed to delete storage snapshot", "err", err)
+	}
+}
+
+// IterateStorageSnapshots returns an iterator for walking the entire storage
+// space of a specific account.
+func IterateStorageSnapshots(db ethdb.Iteratee, accountHash common.Hash) ethdb.Iterator {
+	return NewKeyLengthIterator(db.NewIterator(storageSnapshotsKey(accountHash), nil), len(SnapshotStoragePrefix)+2*common.HashLength)
+}
+
+// ReadSnapshotJournal retrieves the serialized in-memory diff layers saved at
+// the last shutdown. The blob is expected to be at most a few tens of megabytes.
+func ReadSnapshotJournal(db ethdb.KeyValueReader) []byte {
+	data, _ := db.Get(snapshotJournalKey)
+	return data
+}
+
+// WriteSnapshotJournal stores the serialized in-memory diff layers to save at
+// shutdown. The blob is expected to be at most a few tens of megabytes.
+func WriteSnapshotJournal(db ethdb.KeyValueWriter, journal []byte) {
+	if err := db.Put(snapshotJournalKey, journal); err != nil {
+		log.Crit("Failed to store snapshot journal", "err", err)
+	}
+}
+
+// DeleteSnapshotJournal deletes the serialized in-memory diff layers saved at
+// the last shutdown.
+func DeleteSnapshotJournal(db ethdb.KeyValueWriter) {
+	if err := db.Delete(snapshotJournalKey); err != nil {
+		log.Crit("Failed to remove snapshot journal", "err", err)
+	}
+}
+
+// ReadSnapshotGenerator retrieves the serialized snapshot generator saved at
+// the last shutdown.
+func ReadSnapshotGenerator(db ethdb.KeyValueReader) []byte {
+	data, _ := db.Get(snapshotGeneratorKey)
+	return data
+}
+
+// WriteSnapshotGenerator stores the serialized snapshot generator to save at
+// shutdown.
+func WriteSnapshotGenerator(db ethdb.KeyValueWriter, generator []byte) {
+	if err := db.Put(snapshotGeneratorKey, generator); err != nil {
+		log.Crit("Failed to store snapshot generator", "err", err)
+	}
+}
diff --git a/core/rawdb/accessors_state.go b/core/rawdb/accessors_state.go
index 23048f0f8c..28bba40f3c 100644
--- a/core/rawdb/accessors_state.go
+++ b/core/rawdb/accessors_state.go
@@ -25,7 +25,7 @@ import (
 
 // PreimageTable returns a Database instance with the key prefix for preimage entries.
 func PreimageTable(db ethdb.Database) ethdb.Database {
-	return NewTable(db, preimagePrefix)
+	return NewTable(db, PreimagePrefix)
 }
 
 // ReadPreimage retrieves a single preimage of the provided hash.
diff --git a/core/rawdb/database.go b/core/rawdb/database.go
index cf80d12d0d..0ebedd7f7f 100644
--- a/core/rawdb/database.go
+++ b/core/rawdb/database.go
@@ -17,7 +17,13 @@
 package rawdb
 
 import (
+	"bytes"
 	"fmt"
+	"os"
+	"time"
+
+	"github.com/olekukonko/tablewriter"
+	"github.com/tomochain/tomochain/log"
 
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/ethdb"
@@ -140,3 +146,147 @@ func (s *stat) Size() string {
 func (s *stat) Count() string {
 	return s.count.String()
 }
+
+// InspectDatabase traverses the entire database and checks the size
+// of all different categories of data.
+func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error {
+	it := db.NewIterator(keyPrefix, keyStart)
+	defer it.Release()
+
+	var (
+		count  int64
+		start  = time.Now()
+		logged = time.Now()
+
+		// Key-value store statistics
+		headers         stat
+		bodies          stat
+		receipts        stat
+		tds             stat
+		numHashPairings stat
+		hashNumPairings stat
+		tries           stat
+		codes           stat
+		txLookups       stat
+		accountSnaps    stat
+		storageSnaps    stat
+		preimages       stat
+		bloomBits       stat
+		cliqueSnaps     stat
+
+		// Ancient store statistics
+		ancientHeadersSize  common.StorageSize
+		ancientBodiesSize   common.StorageSize
+		ancientReceiptsSize common.StorageSize
+		ancientTdsSize      common.StorageSize
+		ancientHashesSize   common.StorageSize
+
+		// LES statistics
+		chtTrieNodes   stat
+		bloomTrieNodes stat
+
+		// Meta- and unaccounted data
+		metadata    stat
+		unaccounted stat
+
+		// Totals
+		total common.StorageSize
+	)
+	// Inspect key-value database first.
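+	// Every key below is bucketed by its prefix and its expected length; the
+	// general pattern, sketched with the package-level prefix variables, is:
+	//
+	//	if bytes.HasPrefix(key, headerPrefix) && len(key) == len(headerPrefix)+8+common.HashLength {
+	//		headers.Add(size)
+	//	}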
+	for it.Next() {
+		var (
+			key  = it.Key()
+			size = common.StorageSize(len(key) + len(it.Value()))
+		)
+		total += size
+		switch {
+		case bytes.HasPrefix(key, headerPrefix) && len(key) == (len(headerPrefix)+8+common.HashLength):
+			headers.Add(size)
+		case bytes.HasPrefix(key, blockBodyPrefix) && len(key) == (len(blockBodyPrefix)+8+common.HashLength):
+			bodies.Add(size)
+		case bytes.HasPrefix(key, blockReceiptsPrefix) && len(key) == (len(blockReceiptsPrefix)+8+common.HashLength):
+			receipts.Add(size)
+		case bytes.HasPrefix(key, headerPrefix) && bytes.HasSuffix(key, headerTDSuffix):
+			tds.Add(size)
+		case bytes.HasPrefix(key, headerPrefix) && bytes.HasSuffix(key, headerHashSuffix):
+			numHashPairings.Add(size)
+		case bytes.HasPrefix(key, headerNumberPrefix) && len(key) == (len(headerNumberPrefix)+common.HashLength):
+			hashNumPairings.Add(size)
+		case len(key) == common.HashLength:
+			tries.Add(size)
+		case bytes.HasPrefix(key, txLookupPrefix) && len(key) == (len(txLookupPrefix)+common.HashLength):
+			txLookups.Add(size)
+		case bytes.HasPrefix(key, SnapshotAccountPrefix) && len(key) == (len(SnapshotAccountPrefix)+common.HashLength):
+			accountSnaps.Add(size)
+		case bytes.HasPrefix(key, SnapshotStoragePrefix) && len(key) == (len(SnapshotStoragePrefix)+2*common.HashLength):
+			storageSnaps.Add(size)
+		case bytes.HasPrefix(key, []byte(PreimagePrefix)) && len(key) == (len(PreimagePrefix)+common.HashLength):
+			preimages.Add(size)
+		case bytes.HasPrefix(key, bloomBitsPrefix) && len(key) == (len(bloomBitsPrefix)+10+common.HashLength):
+			bloomBits.Add(size)
+		case bytes.HasPrefix(key, []byte("clique-")) && len(key) == 7+common.HashLength:
+			cliqueSnaps.Add(size)
+		case bytes.HasPrefix(key, []byte("cht-")) && len(key) == 4+common.HashLength:
+			chtTrieNodes.Add(size)
+		case bytes.HasPrefix(key, []byte("blt-")) && len(key) == 4+common.HashLength:
+			bloomTrieNodes.Add(size)
+		default:
+			var accounted bool
+			for _, meta := range [][]byte{databaseVersionKey, headHeaderKey, headBlockKey, headFastBlockKey, fastTrieProgressKey} {
+				if bytes.Equal(key, meta) {
+					metadata.Add(size)
+					accounted = true
+					break
+				}
+			}
+			if !accounted {
+				unaccounted.Add(size)
+			}
+		}
+		count++
+		if count%1000 == 0 && time.Since(logged) > 8*time.Second {
+			log.Info("Inspecting database", "count", count, "elapsed", common.PrettyDuration(time.Since(start)))
+			logged = time.Now()
+		}
+	}
+	// Get number of ancient rows inside the freezer
+	ancients := counter(0)
+	if count, err := db.Ancients(); err == nil {
+		ancients = counter(count)
+	}
+	// Display the database statistics.
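+	// Each row below renders one stat accumulated during the iteration; the
+	// stat helper pairs a size with a count per category, hypothetically a
+	// minimal form of it (the real type is defined earlier in this file):
+	//
+	//	type stat struct {
+	//		size  common.StorageSize
+	//		count counter
+	//	}
+	//
+	//	func (s *stat) Add(size common.StorageSize) { s.size += size; s.count++ }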
+ stats := [][]string{ + {"Key-Value store", "Headers", headers.Size(), headers.Count()}, + {"Key-Value store", "Bodies", bodies.Size(), bodies.Count()}, + {"Key-Value store", "Receipt lists", receipts.Size(), receipts.Count()}, + {"Key-Value store", "Difficulties", tds.Size(), tds.Count()}, + {"Key-Value store", "Block number->hash", numHashPairings.Size(), numHashPairings.Count()}, + {"Key-Value store", "Block hash->number", hashNumPairings.Size(), hashNumPairings.Count()}, + {"Key-Value store", "Transaction index", txLookups.Size(), txLookups.Count()}, + {"Key-Value store", "Bloombit index", bloomBits.Size(), bloomBits.Count()}, + {"Key-Value store", "Contract codes", codes.Size(), codes.Count()}, + {"Key-Value store", "Trie nodes", tries.Size(), tries.Count()}, + {"Key-Value store", "Trie preimages", preimages.Size(), preimages.Count()}, + {"Key-Value store", "Account snapshot", accountSnaps.Size(), accountSnaps.Count()}, + {"Key-Value store", "Storage snapshot", storageSnaps.Size(), storageSnaps.Count()}, + {"Key-Value store", "Clique snapshots", cliqueSnaps.Size(), cliqueSnaps.Count()}, + {"Key-Value store", "Singleton metadata", metadata.Size(), metadata.Count()}, + {"Ancient store", "Headers", ancientHeadersSize.String(), ancients.String()}, + {"Ancient store", "Bodies", ancientBodiesSize.String(), ancients.String()}, + {"Ancient store", "Receipt lists", ancientReceiptsSize.String(), ancients.String()}, + {"Ancient store", "Difficulties", ancientTdsSize.String(), ancients.String()}, + {"Ancient store", "Block number->hash", ancientHashesSize.String(), ancients.String()}, + {"Light client", "CHT trie nodes", chtTrieNodes.Size(), chtTrieNodes.Count()}, + {"Light client", "Bloom trie nodes", bloomTrieNodes.Size(), bloomTrieNodes.Count()}, + } + table := tablewriter.NewWriter(os.Stdout) + table.SetHeader([]string{"Database", "Category", "Size", "Items"}) + table.SetFooter([]string{"", "Total", total.String(), " "}) + table.AppendBulk(stats) + table.Render() + + if unaccounted.size > 0 { + log.Error("Database contains unaccounted data", "size", unaccounted.size, "count", unaccounted.count) + } + return nil +} diff --git a/core/rawdb/key_length_iterator.go b/core/rawdb/key_length_iterator.go new file mode 100644 index 0000000000..9e24f0ec32 --- /dev/null +++ b/core/rawdb/key_length_iterator.go @@ -0,0 +1,47 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rawdb + +import "github.com/tomochain/tomochain/ethdb" + +// KeyLengthIterator is a wrapper for a database iterator that ensures only key-value pairs +// with a specific key length will be returned. 
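+//
+// A minimal usage sketch (assuming an iterator over the account snapshot
+// table, mirroring what the snapshot accessors earlier in this patch do):
+//
+//	it := db.NewIterator(SnapshotAccountPrefix, nil)
+//	iter := NewKeyLengthIterator(it, len(SnapshotAccountPrefix)+common.HashLength)
+//	for iter.Next() {
+//		// iter.Key() is guaranteed to be exactly prefix + 32 bytes long here
+//	}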
+type KeyLengthIterator struct {
+	requiredKeyLength int
+	ethdb.Iterator
+}
+
+// NewKeyLengthIterator returns a wrapped version of the iterator that only
+// returns key-value pairs whose keys have the given key length.
+func NewKeyLengthIterator(it ethdb.Iterator, keyLen int) ethdb.Iterator {
+	return &KeyLengthIterator{
+		Iterator:          it,
+		requiredKeyLength: keyLen,
+	}
+}
+
+func (it *KeyLengthIterator) Next() bool {
+	// Return true as soon as a key with the required key length is discovered
+	for it.Iterator.Next() {
+		if len(it.Iterator.Key()) == it.requiredKeyLength {
+			return true
+		}
+	}
+
+	// Return false when we exhaust the keys in the underlying iterator.
+	return false
+}
diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go
index 528f0e15ee..b49e238ab6 100644
--- a/core/rawdb/schema.go
+++ b/core/rawdb/schema.go
@@ -26,22 +26,42 @@
 )
 
 var (
+	// databaseVersionKey tracks the current database version.
+	databaseVersionKey = []byte("DatabaseVersion")
+
+	// headFastBlockKey tracks the latest known incomplete block's hash during fast sync.
+	headFastBlockKey = []byte("LastFast")
+
+	// fastTrieProgressKey tracks the number of trie entries imported during fast sync.
+	fastTrieProgressKey = []byte("TrieSync")
+
+	// snapshotRootKey tracks the hash of the last snapshot.
+	snapshotRootKey = []byte("SnapshotRoot")
+
+	// snapshotJournalKey tracks the in-memory diff layers across restarts.
+	snapshotJournalKey = []byte("SnapshotJournal")
+
+	// snapshotGeneratorKey tracks the snapshot generation marker across restarts.
+	snapshotGeneratorKey = []byte("SnapshotGenerator")
+
 	headHeaderKey = []byte("LastHeader")
 	headBlockKey  = []byte("LastBlock")
 	headFastKey   = []byte("LastFast")
 	trieSyncKey   = []byte("TrieSync")
 
 	// Data item prefixes (use single byte to avoid mixing data types, avoid `i`).
-	headerPrefix        = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header
-	headerTDSuffix      = []byte("t") // headerPrefix + num (uint64 big endian) + hash + headerTDSuffix -> td
-	headerHashSuffix    = []byte("n") // headerPrefix + num (uint64 big endian) + headerHashSuffix -> hash
-	headerNumberPrefix  = []byte("H") // headerNumberPrefix + hash -> num (uint64 big endian)
-	blockBodyPrefix     = []byte("b") // blockBodyPrefix + num (uint64 big endian) + hash -> block body
-	blockReceiptsPrefix = []byte("r") // blockReceiptsPrefix + num (uint64 big endian) + hash -> block receipts
-	txLookupPrefix      = []byte("l") // txLookupPrefix + hash -> transaction/receipt lookup metadata
-	bloomBitsPrefix     = []byte("B") // bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash -> bloom bits
-
-	preimagePrefix = "secure-key-" // preimagePrefix + hash -> preimage
+	headerPrefix          = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header
+	headerTDSuffix        = []byte("t") // headerPrefix + num (uint64 big endian) + hash + headerTDSuffix -> td
+	headerHashSuffix      = []byte("n") // headerPrefix + num (uint64 big endian) + headerHashSuffix -> hash
+	headerNumberPrefix    = []byte("H") // headerNumberPrefix + hash -> num (uint64 big endian)
+	blockBodyPrefix       = []byte("b") // blockBodyPrefix + num (uint64 big endian) + hash -> block body
+	blockReceiptsPrefix   = []byte("r") // blockReceiptsPrefix + num (uint64 big endian) + hash -> block receipts
+	txLookupPrefix        = []byte("l") // txLookupPrefix + hash -> transaction/receipt lookup metadata
+	bloomBitsPrefix       = []byte("B") // bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash -> bloom bits
+	SnapshotAccountPrefix = []byte("a") // SnapshotAccountPrefix + account hash -> account trie value
+	SnapshotStoragePrefix = []byte("o") // SnapshotStoragePrefix + account hash + storage hash -> storage trie value
+
+	PreimagePrefix = "secure-key-" // PreimagePrefix + hash -> preimage
 	configPrefix   = []byte("ethereum-config-") // config prefix for the db
 
 	// BloomBitsIndexPrefix is the data table of a chain indexer to track its progress
@@ -129,3 +149,18 @@ func oldTxMetaKey(hash common.Hash) []byte {
 
 func oldReceiptsKey(hash common.Hash) []byte {
 	return append(oldReceiptsPrefix, hash[:]...)
 }
+
+// accountSnapshotKey = SnapshotAccountPrefix + hash
+func accountSnapshotKey(hash common.Hash) []byte {
+	return append(SnapshotAccountPrefix, hash.Bytes()...)
+}
+
+// storageSnapshotKey = SnapshotStoragePrefix + account hash + storage hash
+func storageSnapshotKey(accountHash, storageHash common.Hash) []byte {
+	return append(append(SnapshotStoragePrefix, accountHash.Bytes()...), storageHash.Bytes()...)
+}
+
+// storageSnapshotsKey = SnapshotStoragePrefix + account hash
+func storageSnapshotsKey(accountHash common.Hash) []byte {
+	return append(SnapshotStoragePrefix, accountHash.Bytes()...)
+}
diff --git a/core/state/iterator_test.go b/core/state/iterator_test.go
index 20864e0768..b5b7287025 100644
--- a/core/state/iterator_test.go
+++ b/core/state/iterator_test.go
@@ -29,7 +29,7 @@ func TestNodeIteratorCoverage(t *testing.T) {
 	// Create some arbitrary test state to iterate
 	db, root, _ := makeTestState()
 
-	state, err := New(root, db)
+	state, err := New(root, db, nil)
 	if err != nil {
 		t.Fatalf("failed to create state trie at %x: %v", root, err)
 	}
diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go
index 46deebdfc2..1d19b087da 100644
--- a/core/state/managed_state_test.go
+++ b/core/state/managed_state_test.go
@@ -19,6 +19,8 @@ package state
 import (
 	"testing"
 
+	"github.com/tomochain/tomochain/core/rawdb"
+
 	"github.com/tomochain/tomochain/common"
-	"github.com/tomochain/tomochain/core/rawdb"
 )
@@ -27,7 +29,7 @@ var addr = common.BytesToAddress([]byte("test"))
 
 func create() (*ManagedState, *account) {
 	db := rawdb.NewMemoryDatabase()
-	statedb, _ := New(common.Hash{}, NewDatabase(db))
+	statedb, _ := New(common.Hash{}, NewDatabase(db), nil)
 	ms := ManageState(statedb)
 	ms.StateDB.SetNonce(addr, 100)
 	ms.accounts[addr] = newAccount(ms.StateDB.getStateObject(addr))
diff --git a/core/state/snapshot/account.go b/core/state/snapshot/account.go
new file mode 100644
index 0000000000..3177f25d92
--- /dev/null
+++ b/core/state/snapshot/account.go
@@ -0,0 +1,54 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package snapshot
+
+import (
+	"bytes"
+	"math/big"
+
+	"github.com/tomochain/tomochain/common"
+	"github.com/tomochain/tomochain/rlp"
+)
+
+// Account is a slim version of a state.Account, where the root and code hash
+// are replaced with a nil byte slice for empty accounts.
+type Account struct {
+	Nonce    uint64
+	Balance  *big.Int
+	Root     []byte
+	CodeHash []byte
+}
+
+// AccountRLP converts the given state.Account content into its RLP-encoded
+// slim snapshot version.
+func AccountRLP(nonce uint64, balance *big.Int, root common.Hash, codehash []byte) []byte {
+	slim := Account{
+		Nonce:   nonce,
+		Balance: balance,
+	}
+	if root != emptyRoot {
+		slim.Root = root[:]
+	}
+	if !bytes.Equal(codehash, emptyCode[:]) {
+		slim.CodeHash = codehash
+	}
+	data, err := rlp.EncodeToBytes(slim)
+	if err != nil {
+		panic(err)
+	}
+	return data
+}
diff --git a/core/state/snapshot/difflayer.go b/core/state/snapshot/difflayer.go
new file mode 100644
index 0000000000..1a75761bf8
--- /dev/null
+++ b/core/state/snapshot/difflayer.go
@@ -0,0 +1,533 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package snapshot
+
+import (
+	"encoding/binary"
+	"fmt"
+	"math"
+	"math/rand"
+	"sort"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	bloomfilter "github.com/holiman/bloomfilter/v2"
+	"github.com/tomochain/tomochain/common"
+	"github.com/tomochain/tomochain/rlp"
+)
+
+var (
+	// aggregatorMemoryLimit is the maximum size of the bottom-most diff layer
+	// that aggregates the writes from above until it's flushed into the disk
+	// layer.
+	//
+	// Note, bumping this up might drastically increase the size of the bloom
+	// filters that's stored in every diff layer. Don't do that without fully
+	// understanding all the implications.
+	aggregatorMemoryLimit = uint64(4 * 1024 * 1024)
+
+	// aggregatorItemLimit is an approximate number of items that will end up
+	// in the aggregator layer before it's flushed out to disk. A plain account
+	// weighs around 14B (+hash), a storage slot 32B (+hash), a deleted slot
+	// 0B (+hash). Slots are mostly set/unset in lockstep, so they average at
+	// 16B (+hash). All in all, the average entry seems to be 15+32=47B. Use a
+	// smaller number to be on the safe side.
+	aggregatorItemLimit = aggregatorMemoryLimit / 42
+
+	// bloomTargetError is the target false positive rate when the aggregator
+	// layer is at its fullest. The actual value will probably move around up
+	// and down from this number, it's mostly a ballpark figure.
+	//
+	// Note, dropping this down might drastically increase the size of the bloom
+	// filters that's stored in every diff layer. Don't do that without fully
+	// understanding all the implications.
+	bloomTargetError = 0.02
+
+	// bloomSize is the ideal bloom filter size given the maximum number of items
+	// it's expected to hold and the target false positive error rate.
+	bloomSize = math.Ceil(float64(aggregatorItemLimit) * math.Log(bloomTargetError) / math.Log(1/math.Pow(2, math.Log(2))))
+
+	// bloomFuncs is the ideal number of bits a single entry should set in the
+	// bloom filter to keep its size to a minimum (given its size and maximum
+	// entry count).
+	bloomFuncs = math.Round((bloomSize / float64(aggregatorItemLimit)) * math.Log(2))
+
+	// The bloom offsets are runtime constants which determine which part of the
+	// account/storage hash the hasher functions look at, to determine the
+	// bloom key for an account/slot. This is randomized at init(), so that the
+	// global population of nodes does not all display the exact same behaviour
+	// with regard to bloom content
+	bloomDestructHasherOffset = 0
+	bloomAccountHasherOffset  = 0
+	bloomStorageHasherOffset  = 0
+)
+
+func init() {
+	// Init the bloom offsets in the range [0:24] (requires 8 bytes)
+	bloomDestructHasherOffset = rand.Intn(25)
+	bloomAccountHasherOffset = rand.Intn(25)
+	bloomStorageHasherOffset = rand.Intn(25)
+
+	// The destruct and account blooms must be different, as the storage slots
+	// will check for destruction too for every bloom miss. It should not collide
+	// with modified accounts.
+	for bloomAccountHasherOffset == bloomDestructHasherOffset {
+		bloomAccountHasherOffset = rand.Intn(25)
+	}
+}
+
+// diffLayer represents a collection of modifications made to a state snapshot
+// after running a block on top. It contains one sorted list for the account
+// trie and one list for each storage trie.
+//
+// The goal of a diff layer is to act as a journal, tracking recent modifications
+// made to the state, that have not yet graduated into a semi-immutable state.
+type diffLayer struct {
+	origin *diskLayer // Base disk layer to directly use on bloom misses
+	parent snapshot   // Parent snapshot modified by this one, never nil
+	memory uint64     // Approximate guess as to how much memory we use
+
+	root  common.Hash // Root hash to which this snapshot diff belongs to
+	stale uint32      // Signals that the layer became stale (state progressed)
+
+	destructSet map[common.Hash]struct{}               // Keyed markers for deleted (and potentially recreated) accounts
+	accountList []common.Hash                          // List of accounts for iteration. If it exists, it's sorted, otherwise it's nil
+	accountData map[common.Hash][]byte                 // Keyed accounts for direct retrieval (nil means deleted)
+	storageList map[common.Hash][]common.Hash          // List of storage slots for iterated retrievals, one per account. Any existing lists are sorted if non-nil
+	storageData map[common.Hash]map[common.Hash][]byte // Keyed storage slots for direct retrieval, one per account (nil means deleted)
+
+	diffed *bloomfilter.Filter // Bloom filter tracking all the diffed items up to the disk layer
+
+	lock sync.RWMutex
+}
+
+// destructBloomHasher is a wrapper around a common.Hash to satisfy the interface
+// API requirements of the bloom library used. It's used to convert a destruct
+// event into a 64 bit mini hash.
+type destructBloomHasher common.Hash
+
+func (h destructBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") }
+func (h destructBloomHasher) Sum(b []byte) []byte               { panic("not implemented") }
+func (h destructBloomHasher) Reset()                            { panic("not implemented") }
+func (h destructBloomHasher) BlockSize() int                    { panic("not implemented") }
+func (h destructBloomHasher) Size() int                         { return 8 }
+func (h destructBloomHasher) Sum64() uint64 {
+	return binary.BigEndian.Uint64(h[bloomDestructHasherOffset : bloomDestructHasherOffset+8])
+}
+
+// accountBloomHasher is a wrapper around a common.Hash to satisfy the interface
+// API requirements of the bloom library used. It's used to convert an account
+// hash into a 64 bit mini hash.
+type accountBloomHasher common.Hash
+
+func (h accountBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") }
+func (h accountBloomHasher) Sum(b []byte) []byte               { panic("not implemented") }
+func (h accountBloomHasher) Reset()                            { panic("not implemented") }
+func (h accountBloomHasher) BlockSize() int                    { panic("not implemented") }
+func (h accountBloomHasher) Size() int                         { return 8 }
+func (h accountBloomHasher) Sum64() uint64 {
+	return binary.BigEndian.Uint64(h[bloomAccountHasherOffset : bloomAccountHasherOffset+8])
+}
+
+// storageBloomHasher is a wrapper around a [2]common.Hash to satisfy the interface
+// API requirements of the bloom library used. It's used to convert an account
+// and storage hash combo into a 64 bit mini hash.
+type storageBloomHasher [2]common.Hash
+
+func (h storageBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") }
+func (h storageBloomHasher) Sum(b []byte) []byte               { panic("not implemented") }
+func (h storageBloomHasher) Reset()                            { panic("not implemented") }
+func (h storageBloomHasher) BlockSize() int                    { panic("not implemented") }
+func (h storageBloomHasher) Size() int                         { return 8 }
+func (h storageBloomHasher) Sum64() uint64 {
+	return binary.BigEndian.Uint64(h[0][bloomStorageHasherOffset:bloomStorageHasherOffset+8]) ^
+		binary.BigEndian.Uint64(h[1][bloomStorageHasherOffset:bloomStorageHasherOffset+8])
+}
+
+// newDiffLayer creates a new diff on top of an existing snapshot, whether that's a low
+// level persistent database or a hierarchical diff already.
+func newDiffLayer(parent snapshot, root common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer {
+	// Create the new layer with some pre-allocated data segments
+	dl := &diffLayer{
+		parent:      parent,
+		root:        root,
+		destructSet: destructs,
+		accountData: accounts,
+		storageData: storage,
+	}
+	switch parent := parent.(type) {
+	case *diskLayer:
+		dl.rebloom(parent)
+	case *diffLayer:
+		dl.rebloom(parent.origin)
+	default:
+		panic("unknown parent type")
+	}
+	// Sanity check that accounts or storage slots are never nil
+	for accountHash, blob := range accounts {
+		if blob == nil {
+			panic(fmt.Sprintf("account %#x nil", accountHash))
+		}
+	}
+	for accountHash, slots := range storage {
+		if slots == nil {
+			panic(fmt.Sprintf("storage %#x nil", accountHash))
+		}
+	}
+	// Determine memory size and track the dirty writes
+	for _, data := range accounts {
+		dl.memory += uint64(common.HashLength + len(data))
+		snapshotDirtyAccountWriteMeter.Mark(int64(len(data)))
+	}
+	// Fill the storage hashes and sort them for the iterator
+	dl.storageList = make(map[common.Hash][]common.Hash)
+	for accountHash := range destructs {
+		dl.storageList[accountHash] = nil
+	}
+	// Determine memory size and track the dirty writes
+	for _, slots := range storage {
+		for _, data := range slots {
+			dl.memory += uint64(common.HashLength + len(data))
+			snapshotDirtyStorageWriteMeter.Mark(int64(len(data)))
+		}
+	}
+	dl.memory += uint64(len(dl.storageList) * common.HashLength)
+	return dl
+}
+
+// rebloom discards the layer's current bloom and rebuilds it from scratch based
+// on the parent's and the local diffs.
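+//
+// The error-rate gauge updated at the end of rebloom follows the standard
+// bloom filter false positive approximation for k hash functions, n inserted
+// items and m bits:
+//
+//	p ≈ (1 - e^(-k·n/m))^k
+//
+// (the code uses the slightly refined (n+0.5)/(m-1) variant).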
+func (dl *diffLayer) rebloom(origin *diskLayer) {
+	dl.lock.Lock()
+	defer dl.lock.Unlock()
+
+	defer func(start time.Time) {
+		snapshotBloomIndexTimer.Update(time.Since(start))
+	}(time.Now())
+
+	// Inject the new origin that triggered the rebloom
+	dl.origin = origin
+
+	// Retrieve the parent bloom or create a fresh empty one
+	if parent, ok := dl.parent.(*diffLayer); ok {
+		parent.lock.RLock()
+		dl.diffed, _ = parent.diffed.Copy()
+		parent.lock.RUnlock()
+	} else {
+		dl.diffed, _ = bloomfilter.New(uint64(bloomSize), uint64(bloomFuncs))
+	}
+	// Iterate over all the accounts and storage slots and index them
+	for hash := range dl.destructSet {
+		dl.diffed.Add(destructBloomHasher(hash))
+	}
+	for hash := range dl.accountData {
+		dl.diffed.Add(accountBloomHasher(hash))
+	}
+	for accountHash, slots := range dl.storageData {
+		for storageHash := range slots {
+			dl.diffed.Add(storageBloomHasher{accountHash, storageHash})
+		}
+	}
+	// Calculate the current false positive rate and update the error rate meter.
+	// This is a bit cheating because subsequent layers will overwrite it, but it
+	// should be fine, we're only interested in ballpark figures.
+	k := float64(dl.diffed.K())
+	n := float64(dl.diffed.N())
+	m := float64(dl.diffed.M())
+	snapshotBloomErrorGauge.Update(math.Pow(1.0-math.Exp((-k)*(n+0.5)/(m-1)), k))
+}
+
+// Root returns the root hash for which this snapshot was made.
+func (dl *diffLayer) Root() common.Hash {
+	return dl.root
+}
+
+// Parent returns the layer this diff layer is stacked on top of.
+func (dl *diffLayer) Parent() snapshot {
+	return dl.parent
+}
+
+// Stale returns whether this layer has become stale (was flattened across) or
+// if it's still live.
+func (dl *diffLayer) Stale() bool {
+	return atomic.LoadUint32(&dl.stale) != 0
+}
+
+// Account directly retrieves the account associated with a particular hash in
+// the snapshot slim data format.
+func (dl *diffLayer) Account(hash common.Hash) (*Account, error) {
+	data, err := dl.AccountRLP(hash)
+	if err != nil {
+		return nil, err
+	}
+	if len(data) == 0 { // can be both nil and []byte{}
+		return nil, nil
+	}
+	account := new(Account)
+	if err := rlp.DecodeBytes(data, account); err != nil {
+		panic(err)
+	}
+	return account, nil
+}
+
+// AccountRLP directly retrieves the account RLP associated with a particular
+// hash in the snapshot slim data format.
+func (dl *diffLayer) AccountRLP(hash common.Hash) ([]byte, error) {
+	// Check the bloom filter first whether there's even a point in reaching into
+	// all the maps in all the layers below
+	dl.lock.RLock()
+	hit := dl.diffed.Contains(accountBloomHasher(hash))
+	if !hit {
+		hit = dl.diffed.Contains(destructBloomHasher(hash))
+	}
+	dl.lock.RUnlock()
+
+	// If the bloom filter misses, don't even bother with traversing the memory
+	// diff layers, reach straight into the bottom persistent disk layer
+	if !hit {
+		snapshotBloomAccountMissMeter.Mark(1)
+		return dl.origin.AccountRLP(hash)
+	}
+	// The bloom filter hit, start poking in the internal maps
+	return dl.accountRLP(hash, 0)
+}
+
+// accountRLP is an internal version of AccountRLP that skips the bloom filter
+// checks and uses the internal maps to try and retrieve the data. It's meant
+// to be used if a higher layer's bloom filter hit already.
+func (dl *diffLayer) accountRLP(hash common.Hash, depth int) ([]byte, error) {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	// If the layer was flattened into, consider it invalid (any live reference to
+	// the original should be marked as unusable).
+	if dl.Stale() {
+		return nil, ErrSnapshotStale
+	}
+	// If the account is known locally, return it
+	if data, ok := dl.accountData[hash]; ok {
+		snapshotDirtyAccountHitMeter.Mark(1)
+		snapshotDirtyAccountHitDepthHist.Update(int64(depth))
+		snapshotDirtyAccountReadMeter.Mark(int64(len(data)))
+		snapshotBloomAccountTrueHitMeter.Mark(1)
+		return data, nil
+	}
+	// If the account is known locally, but deleted, return nil
+	if _, ok := dl.destructSet[hash]; ok {
+		snapshotDirtyAccountHitMeter.Mark(1)
+		snapshotDirtyAccountHitDepthHist.Update(int64(depth))
+		snapshotDirtyAccountInexMeter.Mark(1)
+		snapshotBloomAccountTrueHitMeter.Mark(1)
+		return nil, nil
+	}
+	// Account unknown to this diff, resolve from parent
+	if diff, ok := dl.parent.(*diffLayer); ok {
+		return diff.accountRLP(hash, depth+1)
+	}
+	// Failed to resolve through diff layers, mark a bloom error and use the disk
+	snapshotBloomAccountFalseHitMeter.Mark(1)
+	return dl.parent.AccountRLP(hash)
+}
+
+// Storage directly retrieves the storage data associated with a particular hash,
+// within a particular account. If the slot is unknown to this diff, its parent
+// is consulted.
+func (dl *diffLayer) Storage(accountHash, storageHash common.Hash) ([]byte, error) {
+	// Check the bloom filter first whether there's even a point in reaching into
+	// all the maps in all the layers below
+	dl.lock.RLock()
+	hit := dl.diffed.Contains(storageBloomHasher{accountHash, storageHash})
+	if !hit {
+		hit = dl.diffed.Contains(destructBloomHasher(accountHash))
+	}
+	dl.lock.RUnlock()
+
+	// If the bloom filter misses, don't even bother with traversing the memory
+	// diff layers, reach straight into the bottom persistent disk layer
+	if !hit {
+		snapshotBloomStorageMissMeter.Mark(1)
+		return dl.origin.Storage(accountHash, storageHash)
+	}
+	// The bloom filter hit, start poking in the internal maps
+	return dl.storage(accountHash, storageHash, 0)
+}
+
+// storage is an internal version of Storage that skips the bloom filter checks
+// and uses the internal maps to try and retrieve the data. It's meant to be
+// used if a higher layer's bloom filter hit already.
+func (dl *diffLayer) storage(accountHash, storageHash common.Hash, depth int) ([]byte, error) {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	// If the layer was flattened into, consider it invalid (any live reference to
+	// the original should be marked as unusable).
+	if dl.Stale() {
+		return nil, ErrSnapshotStale
+	}
+	// If the account is known locally, try to resolve the slot locally
+	if storage, ok := dl.storageData[accountHash]; ok {
+		if data, ok := storage[storageHash]; ok {
+			snapshotDirtyStorageHitMeter.Mark(1)
+			snapshotDirtyStorageHitDepthHist.Update(int64(depth))
+			if n := len(data); n > 0 {
+				snapshotDirtyStorageReadMeter.Mark(int64(n))
+			} else {
+				snapshotDirtyStorageInexMeter.Mark(1)
+			}
+			snapshotBloomStorageTrueHitMeter.Mark(1)
+			return data, nil
+		}
+	}
+	// If the account is known locally, but deleted, return an empty slot
+	if _, ok := dl.destructSet[accountHash]; ok {
+		snapshotDirtyStorageHitMeter.Mark(1)
+		snapshotDirtyStorageHitDepthHist.Update(int64(depth))
+		snapshotDirtyStorageInexMeter.Mark(1)
+		snapshotBloomStorageTrueHitMeter.Mark(1)
+		return nil, nil
+	}
+	// Storage slot unknown to this diff, resolve from parent
+	if diff, ok := dl.parent.(*diffLayer); ok {
+		return diff.storage(accountHash, storageHash, depth+1)
+	}
+	// Failed to resolve through diff layers, mark a bloom error and use the disk
+	snapshotBloomStorageFalseHitMeter.Mark(1)
+	return dl.parent.Storage(accountHash, storageHash)
+}
+
+// Update creates a new layer on top of the existing snapshot diff tree with
+// the specified data items.
+func (dl *diffLayer) Update(blockRoot common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer {
+	return newDiffLayer(dl, blockRoot, destructs, accounts, storage)
+}
+
+// flatten pushes all data from this point downwards, flattening everything into
+// a single diff at the bottom. Since usually the lowermost diff is the largest,
+// the flattening builds up from there in reverse.
+func (dl *diffLayer) flatten() snapshot {
+	// If the parent is not diff, we're the first in line, return unmodified
+	parent, ok := dl.parent.(*diffLayer)
+	if !ok {
+		return dl
+	}
+	// Parent is a diff, flatten it first (note, apart from weird corner cases,
+	// flatten will realistically only ever merge 1 layer, so there's no need to
+	// be smarter about grouping flattens together).
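+	// Sketch of the recursion shape: for a chain disk <- d1 <- d2 <- d3,
+	// calling d3.flatten() first folds d2 into d1, then folds d3 into the
+	// result, always leaving a single combined diff on top of the disk layer.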
+	parent = parent.flatten().(*diffLayer)
+
+	parent.lock.Lock()
+	defer parent.lock.Unlock()
+
+	// Before actually writing all our data to the parent, first ensure that the
+	// parent hasn't been 'corrupted' by someone else already flattening into it
+	if atomic.SwapUint32(&parent.stale, 1) != 0 {
+		panic("parent diff layer is stale") // we've flattened into the same parent from two children, boo
+	}
+	// Overwrite all the updated accounts blindly, merge the sorted list
+	for hash := range dl.destructSet {
+		parent.destructSet[hash] = struct{}{}
+		delete(parent.accountData, hash)
+		delete(parent.storageData, hash)
+	}
+	for hash, data := range dl.accountData {
+		parent.accountData[hash] = data
+	}
+	// Overwrite all the updated storage slots (individually)
+	for accountHash, storage := range dl.storageData {
+		// If storage didn't exist (or was deleted) in the parent, overwrite blindly
+		if _, ok := parent.storageData[accountHash]; !ok {
+			parent.storageData[accountHash] = storage
+			continue
+		}
+		// Storage exists in both parent and child, merge the slots
+		comboData := parent.storageData[accountHash]
+		for storageHash, data := range storage {
+			comboData[storageHash] = data
+		}
+		parent.storageData[accountHash] = comboData
+	}
+	// Return the combo parent
+	return &diffLayer{
+		parent:      parent.parent,
+		origin:      parent.origin,
+		root:        dl.root,
+		destructSet: parent.destructSet,
+		accountData: parent.accountData,
+		storageData: parent.storageData,
+		storageList: make(map[common.Hash][]common.Hash),
+		diffed:      dl.diffed,
+		memory:      parent.memory + dl.memory,
+	}
+}
+
+// AccountList returns a sorted list of all accounts in this difflayer, including
+// the deleted ones.
+//
+// Note, the returned slice is not a copy, so do not modify it.
+func (dl *diffLayer) AccountList() []common.Hash {
+	// If an old list already exists, return it
+	dl.lock.RLock()
+	list := dl.accountList
+	dl.lock.RUnlock()
+
+	if list != nil {
+		return list
+	}
+	// No old sorted account list exists, generate a new one
+	dl.lock.Lock()
+	defer dl.lock.Unlock()
+
+	dl.accountList = make([]common.Hash, 0, len(dl.destructSet)+len(dl.accountData))
+	for hash := range dl.accountData {
+		dl.accountList = append(dl.accountList, hash)
+	}
+	for hash := range dl.destructSet {
+		if _, ok := dl.accountData[hash]; !ok {
+			dl.accountList = append(dl.accountList, hash)
+		}
+	}
+	sort.Sort(hashes(dl.accountList))
+	return dl.accountList
+}
+
+// StorageList returns a sorted list of all storage slot hashes in this difflayer
+// for the given account.
+//
+// Note, the returned slice is not a copy, so do not modify it.
+func (dl *diffLayer) StorageList(accountHash common.Hash) []common.Hash {
+	// If an old list already exists, return it
+	dl.lock.RLock()
+	list := dl.storageList[accountHash]
+	dl.lock.RUnlock()
+
+	if list != nil {
+		return list
+	}
+	// No old sorted storage list exists, generate a new one
+	dl.lock.Lock()
+	defer dl.lock.Unlock()
+
+	storageMap := dl.storageData[accountHash]
+	storageList := make([]common.Hash, 0, len(storageMap))
+	for k := range storageMap {
+		storageList = append(storageList, k)
+	}
+	sort.Sort(hashes(storageList))
+	dl.storageList[accountHash] = storageList
+	return storageList
+}
diff --git a/core/state/snapshot/difflayer_test.go b/core/state/snapshot/difflayer_test.go
new file mode 100644
index 0000000000..89814432bb
--- /dev/null
+++ b/core/state/snapshot/difflayer_test.go
@@ -0,0 +1,399 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "math/rand" + "testing" + + "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/ethdb/memorydb" +) + +func copyDestructs(destructs map[common.Hash]struct{}) map[common.Hash]struct{} { + copy := make(map[common.Hash]struct{}) + for hash := range destructs { + copy[hash] = struct{}{} + } + return copy +} + +func copyAccounts(accounts map[common.Hash][]byte) map[common.Hash][]byte { + copy := make(map[common.Hash][]byte) + for hash, blob := range accounts { + copy[hash] = blob + } + return copy +} + +func copyStorage(storage map[common.Hash]map[common.Hash][]byte) map[common.Hash]map[common.Hash][]byte { + copy := make(map[common.Hash]map[common.Hash][]byte) + for accHash, slots := range storage { + copy[accHash] = make(map[common.Hash][]byte) + for slotHash, blob := range slots { + copy[accHash][slotHash] = blob + } + } + return copy +} + +// TestMergeBasics tests some simple merges +func TestMergeBasics(t *testing.T) { + var ( + destructs = make(map[common.Hash]struct{}) + accounts = make(map[common.Hash][]byte) + storage = make(map[common.Hash]map[common.Hash][]byte) + ) + // Fill up a parent + for i := 0; i < 100; i++ { + h := randomHash() + data := randomAccount() + + accounts[h] = data + if rand.Intn(4) == 0 { + destructs[h] = struct{}{} + } + if rand.Intn(2) == 0 { + accStorage := make(map[common.Hash][]byte) + value := make([]byte, 32) + rand.Read(value) + accStorage[randomHash()] = value + storage[h] = accStorage + } + } + // Add some (identical) layers on top + parent := newDiffLayer(emptyLayer(), common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage)) + child := newDiffLayer(parent, common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage)) + child = newDiffLayer(child, common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage)) + child = newDiffLayer(child, common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage)) + child = newDiffLayer(child, common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage)) + // And flatten + merged := (child.flatten()).(*diffLayer) + + { // Check account lists + if have, want := len(merged.accountList), 0; have != want { + t.Errorf("accountList wrong: have %v, want %v", have, want) + } + if have, want := len(merged.AccountList()), len(accounts); have != want { + t.Errorf("AccountList() wrong: have %v, want %v", have, want) + } + if have, want := len(merged.accountList), len(accounts); have != want { + t.Errorf("accountList [2] wrong: have %v, want %v", have, want) + } + } + { // Check account drops + if have, want := len(merged.destructSet), len(destructs); have != want { + t.Errorf("accountDrop 
wrong: have %v, want %v", have, want)
+		}
+	}
+	{ // Check storage lists
+		i := 0
+		for aHash, sMap := range storage {
+			if have, want := len(merged.storageList), i; have != want {
+				t.Errorf("[1] storageList wrong: have %v, want %v", have, want)
+			}
+			if have, want := len(merged.StorageList(aHash)), len(sMap); have != want {
+				t.Errorf("[2] StorageList() wrong: have %v, want %v", have, want)
+			}
+			if have, want := len(merged.storageList[aHash]), len(sMap); have != want {
+				t.Errorf("storageList wrong: have %v, want %v", have, want)
+			}
+			i++
+		}
+	}
+}
+
+// TestMergeDelete tests some deletion
+func TestMergeDelete(t *testing.T) {
+	var (
+		storage = make(map[common.Hash]map[common.Hash][]byte)
+	)
+	// Fill up a parent
+	h1 := common.HexToHash("0x01")
+	h2 := common.HexToHash("0x02")
+
+	flipDrops := func() map[common.Hash]struct{} {
+		return map[common.Hash]struct{}{
+			h2: struct{}{},
+		}
+	}
+	flipAccs := func() map[common.Hash][]byte {
+		return map[common.Hash][]byte{
+			h1: randomAccount(),
+		}
+	}
+	flopDrops := func() map[common.Hash]struct{} {
+		return map[common.Hash]struct{}{
+			h1: struct{}{},
+		}
+	}
+	flopAccs := func() map[common.Hash][]byte {
+		return map[common.Hash][]byte{
+			h2: randomAccount(),
+		}
+	}
+	// Add some flipAccs-flopping layers on top
+	parent := newDiffLayer(emptyLayer(), common.Hash{}, flipDrops(), flipAccs(), storage)
+	child := parent.Update(common.Hash{}, flopDrops(), flopAccs(), storage)
+	child = child.Update(common.Hash{}, flipDrops(), flipAccs(), storage)
+	child = child.Update(common.Hash{}, flopDrops(), flopAccs(), storage)
+	child = child.Update(common.Hash{}, flipDrops(), flipAccs(), storage)
+	child = child.Update(common.Hash{}, flopDrops(), flopAccs(), storage)
+	child = child.Update(common.Hash{}, flipDrops(), flipAccs(), storage)
+
+	if data, _ := child.Account(h1); data == nil {
+		t.Errorf("last diff layer: expected %x account to be non-nil", h1)
+	}
+	if data, _ := child.Account(h2); data != nil {
+		t.Errorf("last diff layer: expected %x account to be nil", h2)
+	}
+	if _, ok := child.destructSet[h1]; ok {
+		t.Errorf("last diff layer: expected %x drop to be missing", h1)
+	}
+	if _, ok := child.destructSet[h2]; !ok {
+		t.Errorf("last diff layer: expected %x drop to be present", h2)
+	}
+	// And flatten
+	merged := (child.flatten()).(*diffLayer)
+
+	if data, _ := merged.Account(h1); data == nil {
+		t.Errorf("merged layer: expected %x account to be non-nil", h1)
+	}
+	if data, _ := merged.Account(h2); data != nil {
+		t.Errorf("merged layer: expected %x account to be nil", h2)
+	}
+	if _, ok := merged.destructSet[h1]; !ok { // Note, drops stay alive until persisted to disk!
+		t.Errorf("merged diff layer: expected %x drop to be present", h1)
+	}
+	if _, ok := merged.destructSet[h2]; !ok { // Note, drops stay alive until persisted to disk!
+		t.Errorf("merged diff layer: expected %x drop to be present", h2)
+	}
+	// If we add more granular metering of memory, we can enable this again,
+	// but it's not implemented for now
+	//if have, want := merged.memory, child.memory; have != want {
+	//	t.Errorf("mem wrong: have %d, want %d", have, want)
+	//}
+}
+
+// This tests that if we create a new account, and set a slot, and then merge
+// it, the lists will be correct.
+func TestInsertAndMerge(t *testing.T) {
+	// Fill up a parent
+	var (
+		acc    = common.HexToHash("0x01")
+		slot   = common.HexToHash("0x02")
+		parent *diffLayer
+		child  *diffLayer
+	)
+	{
+		var (
+			destructs = make(map[common.Hash]struct{})
+			accounts  = make(map[common.Hash][]byte)
+			storage   = make(map[common.Hash]map[common.Hash][]byte)
+		)
+		parent = newDiffLayer(emptyLayer(), common.Hash{}, destructs, accounts, storage)
+	}
+	{
+		var (
+			destructs = make(map[common.Hash]struct{})
+			accounts  = make(map[common.Hash][]byte)
+			storage   = make(map[common.Hash]map[common.Hash][]byte)
+		)
+		accounts[acc] = randomAccount()
+		storage[acc] = make(map[common.Hash][]byte)
+		storage[acc][slot] = []byte{0x01}
+		child = newDiffLayer(parent, common.Hash{}, destructs, accounts, storage)
+	}
+	// And flatten
+	merged := (child.flatten()).(*diffLayer)
+	{ // Check that slot value is present
+		have, _ := merged.Storage(acc, slot)
+		if want := []byte{0x01}; !bytes.Equal(have, want) {
+			t.Errorf("merged slot value wrong: have %x, want %x", have, want)
+		}
+	}
+}
+
+func emptyLayer() *diskLayer {
+	return &diskLayer{
+		diskdb: memorydb.New(),
+		cache:  fastcache.New(500 * 1024),
+	}
+}
+
+// BenchmarkSearch checks how long it takes to find a non-existing key
+// BenchmarkSearch-6   200000   10481 ns/op (1K per layer)
+// BenchmarkSearch-6   200000   10760 ns/op (10K per layer)
+// BenchmarkSearch-6   100000   17866 ns/op
+//
+// BenchmarkSearch-6   500000    3723 ns/op (10k per layer, only top-level RLock())
+func BenchmarkSearch(b *testing.B) {
+	// First, we set up 128 diff layers, with 10K items each
+	fill := func(parent snapshot) *diffLayer {
+		var (
+			destructs = make(map[common.Hash]struct{})
+			accounts  = make(map[common.Hash][]byte)
+			storage   = make(map[common.Hash]map[common.Hash][]byte)
+		)
+		for i := 0; i < 10000; i++ {
+			accounts[randomHash()] = randomAccount()
+		}
+		return newDiffLayer(parent, common.Hash{}, destructs, accounts, storage)
+	}
+	var layer snapshot
+	layer = emptyLayer()
+	for i := 0; i < 128; i++ {
+		layer = fill(layer)
+	}
+	key := crypto.Keccak256Hash([]byte{0x13, 0x38})
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		layer.AccountRLP(key)
+	}
+}
+
+// BenchmarkSearchSlot checks how long it takes to find a non-existing key
+// - Number of layers: 128
+// - Each layer contains the account, with a couple of storage slots
+// BenchmarkSearchSlot-6   100000   14554 ns/op
+// BenchmarkSearchSlot-6   100000   22254 ns/op (when checking parent root using mutex)
+// BenchmarkSearchSlot-6   100000   14551 ns/op (when checking parent number using atomic)
+// With bloom filter:
+// BenchmarkSearchSlot-6   3467835    351 ns/op
+func BenchmarkSearchSlot(b *testing.B) {
+	// First, we set up 128 diff layers, each containing the account and a few
+	// storage slots
+	accountKey := crypto.Keccak256Hash([]byte{0x13, 0x37})
+	storageKey := crypto.Keccak256Hash([]byte{0x13, 0x37})
+	accountRLP := randomAccount()
+	fill := func(parent snapshot) *diffLayer {
+		var (
+			destructs = make(map[common.Hash]struct{})
+			accounts  = make(map[common.Hash][]byte)
+			storage   = make(map[common.Hash]map[common.Hash][]byte)
+		)
+		accounts[accountKey] = accountRLP
+
+		accStorage := make(map[common.Hash][]byte)
+		for i := 0; i < 5; i++ {
+			value := make([]byte, 32)
+			rand.Read(value)
+			accStorage[randomHash()] = value
+			storage[accountKey] = accStorage
+		}
+		return newDiffLayer(parent, common.Hash{}, destructs, accounts, storage)
+	}
+	var layer snapshot
+	layer = emptyLayer()
+	for i := 0; i < 128; i++ {
+		layer = fill(layer)
+	}
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		layer.Storage(accountKey, storageKey)
+	}
+}
+
+// With accountList and sorting
+// BenchmarkFlatten-6   50   29890856 ns/op
+//
+// Without sorting and tracking accountList
+// BenchmarkFlatten-6   300   5511511 ns/op
+func BenchmarkFlatten(b *testing.B) {
+	fill := func(parent snapshot) *diffLayer {
+		var (
+			destructs = make(map[common.Hash]struct{})
+			accounts  = make(map[common.Hash][]byte)
+			storage   = make(map[common.Hash]map[common.Hash][]byte)
+		)
+		for i := 0; i < 100; i++ {
+			accountKey := randomHash()
+			accounts[accountKey] = randomAccount()
+
+			accStorage := make(map[common.Hash][]byte)
+			for i := 0; i < 20; i++ {
+				value := make([]byte, 32)
+				rand.Read(value)
+				accStorage[randomHash()] = value
+
+			}
+			storage[accountKey] = accStorage
+		}
+		return newDiffLayer(parent, common.Hash{}, destructs, accounts, storage)
+	}
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		b.StopTimer()
+		var layer snapshot
+		layer = emptyLayer()
+		for i := 1; i < 128; i++ {
+			layer = fill(layer)
+		}
+		b.StartTimer()
+
+		for i := 1; i < 128; i++ {
+			dl, ok := layer.(*diffLayer)
+			if !ok {
+				break
+			}
+			layer = dl.flatten()
+		}
+		b.StopTimer()
+	}
+}
+
+// This benchmark writes ~324M of diff layers to disk, spread over
+// - 128 individual layers,
+// - each with 200 accounts
+// - containing 200 slots
+//
+// BenchmarkJournal-6   1   1471373923 ns/op
+// BenchmarkJournal-6   1   1208083335 ns/op // bufio writer
+func BenchmarkJournal(b *testing.B) {
+	fill := func(parent snapshot) *diffLayer {
+		var (
+			destructs = make(map[common.Hash]struct{})
+			accounts  = make(map[common.Hash][]byte)
+			storage   = make(map[common.Hash]map[common.Hash][]byte)
+		)
+		for i := 0; i < 200; i++ {
+			accountKey := randomHash()
+			accounts[accountKey] = randomAccount()
+
+			accStorage := make(map[common.Hash][]byte)
+			for i := 0; i < 200; i++ {
+				value := make([]byte, 32)
+				rand.Read(value)
+				accStorage[randomHash()] = value
+
+			}
+			storage[accountKey] = accStorage
+		}
+		return newDiffLayer(parent, common.Hash{}, destructs, accounts, storage)
+	}
+	layer := snapshot(new(diskLayer))
+	for i := 1; i < 128; i++ {
+		layer = fill(layer)
+	}
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		layer.Journal(new(bytes.Buffer))
+	}
+}
diff --git a/core/state/snapshot/disklayer.go b/core/state/snapshot/disklayer.go
new file mode 100644
index 0000000000..4fa43660c7
--- /dev/null
+++ b/core/state/snapshot/disklayer.go
@@ -0,0 +1,166 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package snapshot
+
+import (
+	"bytes"
+	"sync"
+
+	"github.com/VictoriaMetrics/fastcache"
+	"github.com/tomochain/tomochain/common"
+	"github.com/tomochain/tomochain/core/rawdb"
+	"github.com/tomochain/tomochain/ethdb"
+	"github.com/tomochain/tomochain/rlp"
+	"github.com/tomochain/tomochain/trie"
+)
+
+// diskLayer is a low level persistent snapshot built on top of a key-value store.
+type diskLayer struct {
+	diskdb ethdb.KeyValueStore // Key-value store containing the base snapshot
+	triedb *trie.Database      // Trie node cache for reconstruction purposes
+	cache  *fastcache.Cache    // Cache to avoid hitting the disk for direct access
+
+	root  common.Hash // Root hash of the base snapshot
+	stale bool        // Signals that the layer became stale (state progressed)
+
+	genMarker  []byte                    // Marker for the state that's indexed during initial layer generation
+	genPending chan struct{}             // Notification channel when generation is done (test synchronicity)
+	genAbort   chan chan *generatorStats // Notification channel to abort generating the snapshot in this layer
+
+	lock sync.RWMutex
+}
+
+// Root returns the root hash for which this snapshot was made.
+func (dl *diskLayer) Root() common.Hash {
+	return dl.root
+}
+
+// Parent always returns nil as there's no layer below the disk.
+func (dl *diskLayer) Parent() snapshot {
+	return nil
+}
+
+// Stale returns whether this layer has become stale (was flattened across) or
+// if it's still live.
+func (dl *diskLayer) Stale() bool {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	return dl.stale
+}
+
+// Account directly retrieves the account associated with a particular hash in
+// the snapshot slim data format.
+func (dl *diskLayer) Account(hash common.Hash) (*Account, error) {
+	data, err := dl.AccountRLP(hash)
+	if err != nil {
+		return nil, err
+	}
+	if len(data) == 0 { // can be both nil and []byte{}
+		return nil, nil
+	}
+	account := new(Account)
+	if err := rlp.DecodeBytes(data, account); err != nil {
+		panic(err)
+	}
+	return account, nil
+}
+
+// AccountRLP directly retrieves the account RLP associated with a particular
+// hash in the snapshot slim data format.
+func (dl *diskLayer) AccountRLP(hash common.Hash) ([]byte, error) {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	// If the layer was flattened into, consider it invalid (any live reference to
+	// the original should be marked as unusable).
+	if dl.stale {
+		return nil, ErrSnapshotStale
+	}
+	// If the layer is being generated, ensure the requested hash has already been
+	// covered by the generator.
+	if dl.genMarker != nil && bytes.Compare(hash[:], dl.genMarker) > 0 {
+		return nil, ErrNotCoveredYet
+	}
+	// If we're in the disk layer, all diff layers missed
+	snapshotDirtyAccountMissMeter.Mark(1)
+
+	// Try to retrieve the account from the memory cache
+	if blob, found := dl.cache.HasGet(nil, hash[:]); found {
+		snapshotCleanAccountHitMeter.Mark(1)
+		snapshotCleanAccountReadMeter.Mark(int64(len(blob)))
+		return blob, nil
+	}
+	// Cache doesn't contain account, pull from disk and cache for later
+	blob := rawdb.ReadAccountSnapshot(dl.diskdb, hash)
+	dl.cache.Set(hash[:], blob)
+
+	snapshotCleanAccountMissMeter.Mark(1)
+	if n := len(blob); n > 0 {
+		snapshotCleanAccountWriteMeter.Mark(int64(n))
+	} else {
+		snapshotCleanAccountInexMeter.Mark(1)
+	}
+	return blob, nil
+}
+
+// Storage directly retrieves the storage data associated with a particular hash,
+// within a particular account.
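+//
+// Like AccountRLP above, the lookup is gated on the generation marker: any
+// key that sorts after dl.genMarker has not been indexed yet and yields
+// ErrNotCoveredYet. The cache key is simply the two hashes concatenated:
+//
+//	key := append(accountHash[:], storageHash[:]...) // 64 bytes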
+func (dl *diskLayer) Storage(accountHash, storageHash common.Hash) ([]byte, error) { + dl.lock.RLock() + defer dl.lock.RUnlock() + + // If the layer was flattened into, consider it invalid (any live reference to + // the original should be marked as unusable). + if dl.stale { + return nil, ErrSnapshotStale + } + key := append(accountHash[:], storageHash[:]...) + + // If the layer is being generated, ensure the requested hash has already been + // covered by the generator. + if dl.genMarker != nil && bytes.Compare(key, dl.genMarker) > 0 { + return nil, ErrNotCoveredYet + } + // If we're in the disk layer, all diff layers missed + snapshotDirtyStorageMissMeter.Mark(1) + + // Try to retrieve the storage slot from the memory cache + if blob, found := dl.cache.HasGet(nil, key); found { + snapshotCleanStorageHitMeter.Mark(1) + snapshotCleanStorageReadMeter.Mark(int64(len(blob))) + return blob, nil + } + // Cache doesn't contain storage slot, pull from disk and cache for later + blob := rawdb.ReadStorageSnapshot(dl.diskdb, accountHash, storageHash) + dl.cache.Set(key, blob) + + snapshotCleanStorageMissMeter.Mark(1) + if n := len(blob); n > 0 { + snapshotCleanStorageWriteMeter.Mark(int64(n)) + } else { + snapshotCleanStorageInexMeter.Mark(1) + } + return blob, nil +} + +// Update creates a new layer on top of the existing snapshot diff tree with +// the specified data items. Note, the maps are retained by the method to avoid +// copying everything. +func (dl *diskLayer) Update(blockHash common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer { + return newDiffLayer(dl, blockHash, destructs, accounts, storage) +} diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go new file mode 100644 index 0000000000..a38eeab75d --- /dev/null +++ b/core/state/snapshot/generate.go @@ -0,0 +1,284 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "encoding/binary" + "fmt" + "math/big" + "time" + + "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/trie" +) + +var ( + // emptyRoot is the known root hash of an empty trie. + emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") + + // emptyCode is the known hash of the empty EVM bytecode. + emptyCode = crypto.Keccak256Hash(nil) +) + +// generatorStats is a collection of statistics gathered by the snapshot generator +// for logging purposes. 
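+//
+// The stats are handed back through the genAbort channel whenever generation
+// is interrupted, so an aborted run can later be resumed or journalled with
+// its counters intact.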
+type generatorStats struct {
+	origin   uint64             // Origin prefix where generation started
+	start    time.Time          // Timestamp when generation started
+	accounts uint64             // Number of accounts indexed
+	slots    uint64             // Number of storage slots indexed
+	storage  common.StorageSize // Account and storage slot size
+}
+
+// Log creates a contextual log with the given message and the context pulled
+// from the internally maintained statistics.
+func (gs *generatorStats) Log(msg string, marker []byte) {
+	var ctx []interface{}
+
+	// Figure out whether we're after or within an account
+	switch len(marker) {
+	case common.HashLength:
+		ctx = append(ctx, []interface{}{"at", common.BytesToHash(marker)}...)
+	case 2 * common.HashLength:
+		ctx = append(ctx, []interface{}{
+			"in", common.BytesToHash(marker[:common.HashLength]),
+			"at", common.BytesToHash(marker[common.HashLength:]),
+		}...)
+	}
+	// Add the usual measurements
+	ctx = append(ctx, []interface{}{
+		"accounts", gs.accounts,
+		"slots", gs.slots,
+		"storage", gs.storage,
+		"elapsed", common.PrettyDuration(time.Since(gs.start)),
+	}...)
+	// Calculate the estimated indexing time based on current stats
+	if len(marker) > 0 {
+		if done := binary.BigEndian.Uint64(marker[:8]) - gs.origin; done > 0 {
+			left := math.MaxUint64 - binary.BigEndian.Uint64(marker[:8])
+
+			speed := done/uint64(time.Since(gs.start)/time.Millisecond+1) + 1 // +1 to avoid division by zero
+			ctx = append(ctx, []interface{}{
+				"eta", common.PrettyDuration(time.Duration(left/speed) * time.Millisecond),
+			}...)
+		}
+	}
+	log.Info(msg, ctx...)
+}
+
+// generateSnapshot regenerates a brand new snapshot based on an existing state
+// database and head block asynchronously. The snapshot is returned immediately
+// and generation is continued in the background until done.
+func generateSnapshot(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash) *diskLayer {
+	// Create a new disk layer with an initialized state marker at zero
+	var (
+		stats     = &generatorStats{start: time.Now()}
+		batch     = diskdb.NewBatch()
+		genMarker = []byte{} // Initialized but empty!
+	)
+	rawdb.WriteSnapshotRoot(batch, root)
+	if err := batch.Write(); err != nil {
+		log.Crit("Failed to write initialized state marker", "err", err)
+	}
+	base := &diskLayer{
+		diskdb:     diskdb,
+		triedb:     triedb,
+		root:       root,
+		cache:      fastcache.New(cache * 1024 * 1024),
+		genMarker:  genMarker, // Initialized but empty!
+		genPending: make(chan struct{}),
+		genAbort:   make(chan chan *generatorStats),
+	}
+	go base.generate(stats)
+	log.Debug("Start snapshot generation", "root", root)
+	return base
+}
+
+// journalProgress persists the generator stats into the database to resume later.
+func journalProgress(db ethdb.KeyValueWriter, marker []byte, stats *generatorStats) {
+	// Write out the generator marker. Note it's a standalone disk layer generator
+	// which is not mixed with journal. It's ok if the generator is persisted while
+	// journal is not.
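+	// The marker value encodes how far generation has progressed, e.g.:
+	//
+	//	marker == nil                       => generation done
+	//	len(marker) == 0                    => generation not yet started
+	//	len(marker) == common.HashLength    => paused at an account boundary
+	//	len(marker) == 2*common.HashLength  => paused inside an account's storage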
+ entry := journalGenerator{ + Done: marker == nil, + Marker: marker, + } + if stats != nil { + entry.Accounts = stats.accounts + entry.Slots = stats.slots + entry.Storage = uint64(stats.storage) + } + blob, err := rlp.EncodeToBytes(entry) + if err != nil { + panic(err) // Cannot happen, here to catch dev errors + } + var logstr string + switch { + case marker == nil: + logstr = "done" + case bytes.Equal(marker, []byte{}): + logstr = "empty" + case len(marker) == common.HashLength: + logstr = fmt.Sprintf("%#x", marker) + default: + logstr = fmt.Sprintf("%#x:%#x", marker[:common.HashLength], marker[common.HashLength:]) + } + log.Debug("Journalled generator progress", "progress", logstr) + rawdb.WriteSnapshotGenerator(db, blob) +} + +// generate is a background thread that iterates over the state and storage tries, +// constructing the state snapshot. All the arguments are purely for statistics +// gethering and logging, since the method surfs the blocks as they arrive, often +// being restarted. +func (dl *diskLayer) generate(stats *generatorStats) { + // Create an account and state iterator pointing to the current generator marker + accTrie, err := trie.NewSecure(dl.root, dl.triedb) + if err != nil { + // The account trie is missing (GC), surf the chain until one becomes available + stats.Log("Trie missing, state snapshotting paused", dl.genMarker) + + abort := <-dl.genAbort + abort <- stats + return + } + stats.Log("Resuming state snapshot generation", dl.genMarker) + + var accMarker []byte + if len(dl.genMarker) > 0 { // []byte{} is the start, use nil for that + accMarker = dl.genMarker[:common.HashLength] + } + accIt := trie.NewIterator(accTrie.NodeIterator(accMarker)) + batch := dl.diskdb.NewBatch() + + // Iterate from the previous marker and continue generating the state snapshot + logged := time.Now() + for accIt.Next() { + // Retrieve the current account and flatten it into the internal format + accountHash := common.BytesToHash(accIt.Key) + + var acc struct { + Nonce uint64 + Balance *big.Int + Root common.Hash + CodeHash []byte + } + if err := rlp.DecodeBytes(accIt.Value, &acc); err != nil { + log.Crit("Invalid account encountered during snapshot creation", "err", err) + } + data := AccountRLP(acc.Nonce, acc.Balance, acc.Root, acc.CodeHash) + + // If the account is not yet in-progress, write it out + if accMarker == nil || !bytes.Equal(accountHash[:], accMarker) { + rawdb.WriteAccountSnapshot(batch, accountHash, data) + stats.storage += common.StorageSize(1 + common.HashLength + len(data)) + stats.accounts++ + } + // If we've exceeded our batch allowance or termination was requested, flush to disk + var abort chan *generatorStats + select { + case abort = <-dl.genAbort: + default: + } + if batch.ValueSize() > ethdb.IdealBatchSize || abort != nil { + // Only write and set the marker if we actually did something useful + if batch.ValueSize() > 0 { + batch.Write() + batch.Reset() + + dl.lock.Lock() + dl.genMarker = accountHash[:] + dl.lock.Unlock() + } + if abort != nil { + stats.Log("Aborting state snapshot generation", accountHash[:]) + abort <- stats + return + } + } + // If the account is in-progress, continue where we left off (otherwise iterate all) + if acc.Root != emptyRoot { + storeTrie, err := trie.NewSecure(acc.Root, dl.triedb) + if err != nil { + log.Crit("Storage trie inaccessible for snapshot generation", "err", err) + } + var storeMarker []byte + if accMarker != nil && bytes.Equal(accountHash[:], accMarker) && len(dl.genMarker) > common.HashLength { + storeMarker = 
dl.genMarker[common.HashLength:] + } + storeIt := trie.NewIterator(storeTrie.NodeIterator(storeMarker)) + for storeIt.Next() { + rawdb.WriteStorageSnapshot(batch, accountHash, common.BytesToHash(storeIt.Key), storeIt.Value) + stats.storage += common.StorageSize(1 + 2*common.HashLength + len(storeIt.Value)) + stats.slots++ + + // If we've exceeded our batch allowance or termination was requested, flush to disk + var abort chan *generatorStats + select { + case abort = <-dl.genAbort: + default: + } + if batch.ValueSize() > ethdb.IdealBatchSize || abort != nil { + // Only write and set the marker if we actually did something useful + if batch.ValueSize() > 0 { + batch.Write() + batch.Reset() + + dl.lock.Lock() + dl.genMarker = append(accountHash[:], storeIt.Key...) + dl.lock.Unlock() + } + if abort != nil { + stats.Log("Aborting state snapshot generation", append(accountHash[:], storeIt.Key...)) + abort <- stats + return + } + } + } + } + if time.Since(logged) > 8*time.Second { + stats.Log("Generating state snapshot", accIt.Key) + logged = time.Now() + } + // Some account processed, unmark the marker + accMarker = nil + } + // Snapshot fully generated, set the marker to nil + if batch.ValueSize() > 0 { + batch.Write() + } + log.Info("Generated state snapshot", "accounts", stats.accounts, "slots", stats.slots, + "storage", stats.storage, "elapsed", common.PrettyDuration(time.Since(stats.start))) + + dl.lock.Lock() + dl.genMarker = nil + close(dl.genPending) + dl.lock.Unlock() + + // Someone will be looking for us, wait it out + abort := <-dl.genAbort + abort <- nil +} diff --git a/core/state/snapshot/iterator.go b/core/state/snapshot/iterator.go new file mode 100644 index 0000000000..b62fb30e34 --- /dev/null +++ b/core/state/snapshot/iterator.go @@ -0,0 +1,221 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "fmt" + "sort" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/ethdb" +) + +// Iterator is an iterator to step over all the accounts or the specific +// storage in a snapshot which may or may not be composed of multiple layers. +type Iterator interface { + // Next steps the iterator forward one element, returning false if exhausted, + // or an error if iteration failed for some reason (e.g. root being iterated + // becomes stale and garbage collected). + Next() bool + + // Error returns any failure that occurred during iteration, which might have + // caused a premature iteration exit (e.g. snapshot stack becoming stale). + Error() error + + // Hash returns the hash of the account or storage slot the iterator is + // currently at. + Hash() common.Hash + + // Release releases associated resources. 
 Release should always succeed and
+	// can be called multiple times without causing error.
+	Release()
+}
+
+// AccountIterator is an iterator to step over all the accounts in a snapshot,
+// which may or may not be composed of multiple layers.
+type AccountIterator interface {
+	Iterator
+
+	// Account returns the RLP encoded slim account the iterator is currently at.
+	// An error will be returned if the iterator becomes invalid.
+	Account() []byte
+}
+
+// diffAccountIterator is an account iterator that steps over the accounts (both
+// live and deleted) contained within a single diff layer. Higher order iterators
+// will use the deleted accounts to skip deeper iterators.
+type diffAccountIterator struct {
+	// curHash is the current hash the iterator is positioned on. The field is
+	// explicitly tracked since the referenced diff layer might go stale after
+	// the iterator was positioned and we don't want to fail accessing the old
+	// hash as long as the iterator is not touched any more.
+	curHash common.Hash
+
+	layer *diffLayer    // Live layer to retrieve values from
+	keys  []common.Hash // Keys left in the layer to iterate
+	fail  error         // Any failures encountered (stale)
+}
+
+// StorageIterator is an iterator to step over the specific storage in a snapshot,
+// which may or may not be composed of multiple layers.
+type StorageIterator interface {
+	Iterator
+
+	// Slot returns the storage slot the iterator is currently at. An error will
+	// be returned if the iterator becomes invalid.
+	Slot() []byte
+}
+
+// AccountIterator creates an account iterator over a single diff layer.
+func (dl *diffLayer) AccountIterator(seek common.Hash) AccountIterator {
+	// Seek out the requested starting account
+	hashes := dl.AccountList()
+	index := sort.Search(len(hashes), func(i int) bool {
+		return bytes.Compare(seek[:], hashes[i][:]) < 0
+	})
+	// Assemble and return the already seeked iterator
+	return &diffAccountIterator{
+		layer: dl,
+		keys:  hashes[index:],
+	}
+}
+
+// Next steps the iterator forward one element, returning false if exhausted.
+func (it *diffAccountIterator) Next() bool {
+	// If the iterator was already stale, consider it a programmer error. Although
+	// we could just return false here, triggering this path would probably mean
+	// somebody forgot to check for Error, so let's blow up instead of undefined
+	// behavior that's hard to debug.
+	if it.fail != nil {
+		panic(fmt.Sprintf("called Next of failed iterator: %v", it.fail))
+	}
+	// Stop iterating if all keys were exhausted
+	if len(it.keys) == 0 {
+		return false
+	}
+	if it.layer.Stale() {
+		it.fail, it.keys = ErrSnapshotStale, nil
+		return false
+	}
+	// Iterator seems to be still alive, retrieve and cache the live hash
+	it.curHash = it.keys[0]
+	// Key cached, shift the iterator and notify the user of success
+	it.keys = it.keys[1:]
+	return true
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+func (it *diffAccountIterator) Error() error {
+	return it.fail
+}
+
+// Hash returns the hash of the account the iterator is currently at.
+func (it *diffAccountIterator) Hash() common.Hash {
+	return it.curHash
+}
+
+// Account returns the RLP encoded slim account the iterator is currently at.
+// This method may _fail_, if the underlying layer has been flattened between
+// the call to Next and Account. That type of error will set it.fail.
+// This method assumes that flattening does not delete elements from
+// the accountdata mapping (writing nil into it is fine though), and will panic
+// if elements have been deleted.
+func (it *diffAccountIterator) Account() []byte {
+	it.layer.lock.RLock()
+	blob, ok := it.layer.accountData[it.curHash]
+	if !ok {
+		if _, ok := it.layer.destructSet[it.curHash]; ok {
+			it.layer.lock.RUnlock()
+			return nil
+		}
+		panic(fmt.Sprintf("iterator referenced non-existent account: %x", it.curHash))
+	}
+	it.layer.lock.RUnlock()
+	if it.layer.Stale() {
+		it.fail, it.keys = ErrSnapshotStale, nil
+	}
+	return blob
+}
+
+// Release is a noop for diff account iterators as there are no held resources.
+func (it *diffAccountIterator) Release() {}
+
+// diskAccountIterator is an account iterator that steps over the live accounts
+// contained within a disk layer.
+type diskAccountIterator struct {
+	layer *diskLayer
+	it    ethdb.Iterator
+}
+
+// AccountIterator creates an account iterator over a disk layer.
+func (dl *diskLayer) AccountIterator(seek common.Hash) AccountIterator {
+	pos := common.TrimRightZeroes(seek[:])
+	return &diskAccountIterator{
+		layer: dl,
+		it:    dl.diskdb.NewIterator(rawdb.SnapshotAccountPrefix, pos),
+	}
+}
+
+// Next steps the iterator forward one element, returning false if exhausted.
+func (it *diskAccountIterator) Next() bool {
+	// If the iterator was already exhausted, don't bother
+	if it.it == nil {
+		return false
+	}
+	// Try to advance the iterator and release it if we reached the end
+	for {
+		if !it.it.Next() || !bytes.HasPrefix(it.it.Key(), rawdb.SnapshotAccountPrefix) {
+			it.it.Release()
+			it.it = nil
+			return false
+		}
+		if len(it.it.Key()) == len(rawdb.SnapshotAccountPrefix)+common.HashLength {
+			break
+		}
+	}
+	return true
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+//
+// The disk iterator simply surfaces any failure of the underlying database
+// iterator; once that iterator is exhausted and released, there is nothing
+// left to fail on.
+func (it *diskAccountIterator) Error() error {
+	if it.it == nil {
+		return nil // Iterator is exhausted and released
+	}
+	return it.it.Error()
+}
+
+// Hash returns the hash of the account the iterator is currently at.
+func (it *diskAccountIterator) Hash() common.Hash {
+	return common.BytesToHash(it.it.Key())
+}
+
+// Account returns the RLP encoded slim account the iterator is currently at.
+func (it *diskAccountIterator) Account() []byte {
+	return it.it.Value()
+}
+
+// Release releases the database snapshot held during iteration.
+func (it *diskAccountIterator) Release() {
+	// The iterator is auto-released on exhaustion, so make sure it's still alive
+	if it.it != nil {
+		it.it.Release()
+		it.it = nil
+	}
+}
diff --git a/core/state/snapshot/iterator_binary.go b/core/state/snapshot/iterator_binary.go
new file mode 100644
index 0000000000..d8df968ea5
--- /dev/null
+++ b/core/state/snapshot/iterator_binary.go
@@ -0,0 +1,115 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package snapshot
+
+import (
+	"bytes"
+
+	"github.com/tomochain/tomochain/common"
+)
+
+// binaryAccountIterator is a simplistic iterator to step over the accounts in
+// a snapshot, which may or may not be composed of multiple layers. Performance
+// wise this iterator is slow; it's meant for cross-validating the fast one.
+type binaryAccountIterator struct {
+	a     *diffAccountIterator
+	b     AccountIterator
+	aDone bool
+	bDone bool
+	k     common.Hash
+	fail  error
+}
+
+// newBinaryAccountIterator creates a simplistic account iterator to step over
+// all the accounts in a slow, but easily verifiable way.
+func (dl *diffLayer) newBinaryAccountIterator() AccountIterator {
+	parent, ok := dl.parent.(*diffLayer)
+	if !ok {
+		// parent is the disk layer
+		return dl.AccountIterator(common.Hash{})
+	}
+	l := &binaryAccountIterator{
+		a: dl.AccountIterator(common.Hash{}).(*diffAccountIterator),
+		b: parent.newBinaryAccountIterator(),
+	}
+	l.aDone = !l.a.Next()
+	l.bDone = !l.b.Next()
+	return l
+}
+
+// Next steps the iterator forward one element, returning false if exhausted,
+// or an error if iteration failed for some reason (e.g. root being iterated
+// becomes stale and garbage collected).
+func (it *binaryAccountIterator) Next() bool {
+	if it.aDone && it.bDone {
+		return false
+	}
+	nextB := it.b.Hash()
+first:
+	nextA := it.a.Hash()
+	if it.aDone {
+		it.bDone = !it.b.Next()
+		it.k = nextB
+		return true
+	}
+	if it.bDone {
+		it.aDone = !it.a.Next()
+		it.k = nextA
+		return true
+	}
+	if diff := bytes.Compare(nextA[:], nextB[:]); diff < 0 {
+		it.aDone = !it.a.Next()
+		it.k = nextA
+		return true
+	} else if diff == 0 {
+		// Now we need to advance one of them
+		it.aDone = !it.a.Next()
+		goto first
+	}
+	it.bDone = !it.b.Next()
+	it.k = nextB
+	return true
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+func (it *binaryAccountIterator) Error() error {
+	return it.fail
+}
+
+// Hash returns the hash of the account the iterator is currently at.
+func (it *binaryAccountIterator) Hash() common.Hash {
+	return it.k
+}
+
+// Account returns the RLP encoded slim account the iterator is currently at, or
+// nil if the iterated snapshot stack became stale (you can check Error after
+// to see if it failed or not).
+func (it *binaryAccountIterator) Account() []byte {
+	blob, err := it.a.layer.AccountRLP(it.k)
+	if err != nil {
+		it.fail = err
+		return nil
+	}
+	return blob
+}
+
+// Release recursively releases all the iterators in the stack.
+func (it *binaryAccountIterator) Release() {
+	it.a.Release()
+	it.b.Release()
+}
diff --git a/core/state/snapshot/iterator_fast.go b/core/state/snapshot/iterator_fast.go
new file mode 100644
index 0000000000..afbe70c2bc
--- /dev/null
+++ b/core/state/snapshot/iterator_fast.go
@@ -0,0 +1,302 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "fmt" + "sort" + + "github.com/tomochain/tomochain/common" +) + +// weightedAccountIterator is an account iterator with an assigned weight. It is +// used to prioritise which account is the correct one if multiple iterators find +// the same one (modified in multiple consecutive blocks). +type weightedAccountIterator struct { + it AccountIterator + priority int +} + +// weightedAccountIterators is a set of iterators implementing the sort.Interface. +type weightedAccountIterators []*weightedAccountIterator + +// Len implements sort.Interface, returning the number of active iterators. +func (its weightedAccountIterators) Len() int { return len(its) } + +// Less implements sort.Interface, returning which of two iterators in the stack +// is before the other. +func (its weightedAccountIterators) Less(i, j int) bool { + // Order the iterators primarily by the account hashes + hashI := its[i].it.Hash() + hashJ := its[j].it.Hash() + + switch bytes.Compare(hashI[:], hashJ[:]) { + case -1: + return true + case 1: + return false + } + // Same account in multiple layers, split by priority + return its[i].priority < its[j].priority +} + +// Swap implements sort.Interface, swapping two entries in the iterator stack. +func (its weightedAccountIterators) Swap(i, j int) { + its[i], its[j] = its[j], its[i] +} + +// fastAccountIterator is a more optimized multi-layer iterator which maintains a +// direct mapping of all iterators leading down to the bottom layer. +type fastAccountIterator struct { + tree *Tree // Snapshot tree to reinitialize stale sub-iterators with + root common.Hash // Root hash to reinitialize stale sub-iterators through + curAccount []byte + + iterators weightedAccountIterators + initiated bool + fail error +} + +// newFastAccountIterator creates a new hierarhical account iterator with one +// element per diff layer. The returned combo iterator can be used to walk over +// the entire snapshot diff stack simultaneously. +func newFastAccountIterator(tree *Tree, root common.Hash, seek common.Hash) (AccountIterator, error) { + snap := tree.Snapshot(root) + if snap == nil { + return nil, fmt.Errorf("unknown snapshot: %x", root) + } + fi := &fastAccountIterator{ + tree: tree, + root: root, + } + current := snap.(snapshot) + for depth := 0; current != nil; depth++ { + fi.iterators = append(fi.iterators, &weightedAccountIterator{ + it: current.AccountIterator(seek), + priority: depth, + }) + current = current.Parent() + } + fi.init() + return fi, nil +} + +// init walks over all the iterators and resolves any clashes between them, after +// which it prepares the stack for step-by-step iteration. +func (fi *fastAccountIterator) init() { + // Track which account hashes are iterators positioned on + var positioned = make(map[common.Hash]int) + + // Position all iterators and track how many remain live + for i := 0; i < len(fi.iterators); i++ { + // Retrieve the first element and if it clashes with a previous iterator, + // advance either the current one or the old one. Repeat until nothing is + // clashing any more. 
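+			// E.g. if the head layer and its parent are both positioned on the
+			// same hash, the parent's iterator (the larger priority value) is
+			// advanced past it, so only the most recent version of the account
+			// is ever emitted.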
+ it := fi.iterators[i] + for { + // If the iterator is exhausted, drop it off the end + if !it.it.Next() { + it.it.Release() + last := len(fi.iterators) - 1 + + fi.iterators[i] = fi.iterators[last] + fi.iterators[last] = nil + fi.iterators = fi.iterators[:last] + + i-- + break + } + // The iterator is still alive, check for collisions with previous ones + hash := it.it.Hash() + if other, exist := positioned[hash]; !exist { + positioned[hash] = i + break + } else { + // Iterators collide, one needs to be progressed, use priority to + // determine which. + // + // This whole else-block can be avoided, if we instead + // do an initial priority-sort of the iterators. If we do that, + // then we'll only wind up here if a lower-priority (preferred) iterator + // has the same value, and then we will always just continue. + // However, it costs an extra sort, so it's probably not better + if fi.iterators[other].priority < it.priority { + // The 'it' should be progressed + continue + } else { + // The 'other' should be progressed, swap them + it = fi.iterators[other] + fi.iterators[other], fi.iterators[i] = fi.iterators[i], fi.iterators[other] + continue + } + } + } + } + // Re-sort the entire list + sort.Sort(fi.iterators) + fi.initiated = false +} + +// Next steps the iterator forward one element, returning false if exhausted. +func (fi *fastAccountIterator) Next() bool { + if len(fi.iterators) == 0 { + return false + } + if !fi.initiated { + // Don't forward first time -- we had to 'Next' once in order to + // do the sorting already + fi.initiated = true + fi.curAccount = fi.iterators[0].it.Account() + if innerErr := fi.iterators[0].it.Error(); innerErr != nil { + fi.fail = innerErr + return false + } + if fi.curAccount != nil { + return true + } + // Implicit else: we've hit a nil-account, and need to fall through to the + // loop below to land on something non-nil + } + // If an account is deleted in one of the layers, the key will still be there, + // but the actual value will be nil. However, the iterator should not + // export nil-values (but instead simply omit the key), so we need to loop + // here until we either + // - get a non-nil value, + // - hit an error, + // - or exhaust the iterator + for { + if !fi.next(0) { + return false // exhausted + } + fi.curAccount = fi.iterators[0].it.Account() + if innerErr := fi.iterators[0].it.Error(); innerErr != nil { + fi.fail = innerErr + return false // error + } + if fi.curAccount != nil { + break // non-nil value found + } + } + return true +} + +// next handles the next operation internally and should be invoked when we know +// that two elements in the list may have the same value. +// +// For example, if the iterated hashes become [2,3,5,5,8,9,10], then we should +// invoke next(3), which will call Next on elem 3 (the second '5') and will +// cascade along the list, applying the same operation if needed. +func (fi *fastAccountIterator) next(idx int) bool { + // If this particular iterator got exhausted, remove it and return true (the + // next one is surely not exhausted yet, otherwise it would have been removed + // already). + if it := fi.iterators[idx].it; !it.Next() { + it.Release() + + fi.iterators = append(fi.iterators[:idx], fi.iterators[idx+1:]...) 
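+		// With the exhausted sub-iterator dropped, iteration can only continue
+		// as long as at least one live sub-iterator remains.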
+		return len(fi.iterators) > 0
+	}
+	// If there's no one left to cascade into, return
+	if idx == len(fi.iterators)-1 {
+		return true
+	}
+	// We next-ed the iterator at 'idx', now we may have to re-sort that element
+	var (
+		cur, next         = fi.iterators[idx], fi.iterators[idx+1]
+		curHash, nextHash = cur.it.Hash(), next.it.Hash()
+	)
+	if diff := bytes.Compare(curHash[:], nextHash[:]); diff < 0 {
+		// It is still in correct place
+		return true
+	} else if diff == 0 && cur.priority < next.priority {
+		// So still in correct place, but we need to iterate on the next
+		fi.next(idx + 1)
+		return true
+	}
+	// At this point, the iterator is in the wrong location, but the remaining
+	// list is sorted. Find out where to move the item.
+	clash := -1
+	index := sort.Search(len(fi.iterators), func(n int) bool {
+		// The iterator always advances forward, so anything before the old slot
+		// is known to be behind us, so just skip them altogether. This actually
+		// is an important clause since the sort order got invalidated.
+		if n < idx {
+			return false
+		}
+		if n == len(fi.iterators)-1 {
+			// Can always place an elem last
+			return true
+		}
+		nextHash := fi.iterators[n+1].it.Hash()
+		if diff := bytes.Compare(curHash[:], nextHash[:]); diff < 0 {
+			return true
+		} else if diff > 0 {
+			return false
+		}
+		// The elem we're placing it next to has the same value,
+		// so whichever winds up on n+1 will need further iteration
+		clash = n + 1
+
+		return cur.priority < fi.iterators[n+1].priority
+	})
+	fi.move(idx, index)
+	if clash != -1 {
+		fi.next(clash)
+	}
+	return true
+}
+
+// move advances an iterator to another position in the list.
+func (fi *fastAccountIterator) move(index, newpos int) {
+	elem := fi.iterators[index]
+	copy(fi.iterators[index:], fi.iterators[index+1:newpos+1])
+	fi.iterators[newpos] = elem
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+func (fi *fastAccountIterator) Error() error {
+	return fi.fail
+}
+
+// Hash returns the current key
+func (fi *fastAccountIterator) Hash() common.Hash {
+	return fi.iterators[0].it.Hash()
+}
+
+// Account returns the current account blob
+func (fi *fastAccountIterator) Account() []byte {
+	return fi.curAccount
+}
+
+// Release iterates over all the remaining live layer iterators and releases each
+// of them individually.
+func (fi *fastAccountIterator) Release() {
+	for _, it := range fi.iterators {
+		it.it.Release()
+	}
+	fi.iterators = nil
+}
+
+// Debug is a convenience helper during testing
+func (fi *fastAccountIterator) Debug() {
+	for _, it := range fi.iterators {
+		fmt.Printf("[p=%v v=%v] ", it.priority, it.it.Hash()[0])
+	}
+	fmt.Println()
+}
diff --git a/core/state/snapshot/journal.go b/core/state/snapshot/journal.go
new file mode 100644
index 0000000000..0c0e3a960c
--- /dev/null
+++ b/core/state/snapshot/journal.go
@@ -0,0 +1,243 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "time" + + "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/trie" +) + +// journalGenerator is a disk layer entry containing the generator progress marker. +type journalGenerator struct { + Wiping bool // Whether the database was in progress of being wiped + Done bool // Whether the generator finished creating the snapshot + Marker []byte + Accounts uint64 + Slots uint64 + Storage uint64 +} + +// journalDestruct is an account deletion entry in a diffLayer's disk journal. +type journalDestruct struct { + Hash common.Hash +} + +// journalAccount is an account entry in a diffLayer's disk journal. +type journalAccount struct { + Hash common.Hash + Blob []byte +} + +// journalStorage is an account's storage map in a diffLayer's disk journal. +type journalStorage struct { + Hash common.Hash + Keys []common.Hash + Vals [][]byte +} + +// loadSnapshot loads a pre-existing state snapshot backed by a key-value store. +func loadSnapshot(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash) (snapshot, error) { + // Retrieve the block number and hash of the snapshot, failing if no snapshot + // is present in the database (or crashed mid-update). + baseRoot := rawdb.ReadSnapshotRoot(diskdb) + if baseRoot == (common.Hash{}) { + return nil, errors.New("missing or corrupted snapshot") + } + base := &diskLayer{ + diskdb: diskdb, + triedb: triedb, + cache: fastcache.New(cache * 1024 * 1024), + root: baseRoot, + } + // Retrieve the journal, it must exist since even for 0 layer it stores whether + // we've already generated the snapshot or are in progress only + journal := rawdb.ReadSnapshotJournal(diskdb) + if len(journal) == 0 { + return nil, errors.New("missing or corrupted snapshot journal") + } + r := rlp.NewStream(bytes.NewReader(journal), 0) + + // Read the snapshot generation progress for the disk layer + var generator journalGenerator + if err := r.Decode(&generator); err != nil { + return nil, fmt.Errorf("failed to load snapshot progress marker: %v", err) + } + // Load all the snapshot diffs from the journal + snapshot, err := loadDiffLayer(base, r) + if err != nil { + return nil, err + } + // Entire snapshot journal loaded, sanity check the head and return + // Journal doesn't exist, don't worry if it's not supposed to + if head := snapshot.Root(); head != root { + return nil, fmt.Errorf("head doesn't match snapshot: have %#x, want %#x", head, root) + } + // Everything loaded correctly, resume any suspended operations + if !generator.Done { + // Whether or not wiping was in progress, load any generator progress too + base.genMarker = generator.Marker + if base.genMarker == nil { + base.genMarker = []byte{} + } + base.genPending = make(chan struct{}) + base.genAbort = make(chan chan *generatorStats) + + var origin uint64 + if len(generator.Marker) >= 8 { + origin = binary.BigEndian.Uint64(generator.Marker) + } + go base.generate(&generatorStats{ + origin: origin, + start: time.Now(), + accounts: generator.Accounts, + slots: generator.Slots, + storage: 
common.StorageSize(generator.Storage), + }) + } + return snapshot, nil +} + +// loadDiffLayer reads the next sections of a snapshot journal, reconstructing a new +// diff and verifying that it can be linked to the requested parent. +func loadDiffLayer(parent snapshot, r *rlp.Stream) (snapshot, error) { + // Read the next diff journal entry + var root common.Hash + if err := r.Decode(&root); err != nil { + // The first read may fail with EOF, marking the end of the journal + if err == io.EOF { + return parent, nil + } + return nil, fmt.Errorf("load diff root: %v", err) + } + var destructs []journalDestruct + if err := r.Decode(&destructs); err != nil { + return nil, fmt.Errorf("load diff destructs: %v", err) + } + destructSet := make(map[common.Hash]struct{}) + for _, entry := range destructs { + destructSet[entry.Hash] = struct{}{} + } + var accounts []journalAccount + if err := r.Decode(&accounts); err != nil { + return nil, fmt.Errorf("load diff accounts: %v", err) + } + accountData := make(map[common.Hash][]byte) + for _, entry := range accounts { + accountData[entry.Hash] = entry.Blob + } + var storage []journalStorage + if err := r.Decode(&storage); err != nil { + return nil, fmt.Errorf("load diff storage: %v", err) + } + storageData := make(map[common.Hash]map[common.Hash][]byte) + for _, entry := range storage { + slots := make(map[common.Hash][]byte) + for i, key := range entry.Keys { + slots[key] = entry.Vals[i] + } + storageData[entry.Hash] = slots + } + return loadDiffLayer(newDiffLayer(parent, root, destructSet, accountData, storageData), r) +} + +// Journal writes the persistent layer generator stats into a buffer to be stored +// in the database as the snapshot journal. +func (dl *diskLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { + // If the snapshot is currently being generated, abort it + var stats *generatorStats + if dl.genAbort != nil { + abort := make(chan *generatorStats) + dl.genAbort <- abort + + if stats = <-abort; stats != nil { + stats.Log("Journalling in-progress snapshot", dl.genMarker) + } + } + // Ensure the layer didn't get stale + dl.lock.RLock() + defer dl.lock.RUnlock() + + if dl.stale { + return common.Hash{}, ErrSnapshotStale + } + // Ensure the generator stats is written even if none was ran this cycle + journalProgress(dl.diskdb, dl.genMarker, stats) + + log.Debug("Journalled disk layer", "root", dl.root) + return dl.root, nil +} + +// Journal writes the memory layer contents into a buffer to be stored in the +// database as the snapshot journal. 
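+//
+// Each diff layer is journalled as four consecutive RLP values, in the same
+// order loadDiffLayer decodes them:
+//
+//	root      common.Hash
+//	destructs []journalDestruct
+//	accounts  []journalAccount
+//	storage   []journalStorage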
+func (dl *diffLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { + // Journal the parent first + base, err := dl.parent.Journal(buffer) + if err != nil { + return common.Hash{}, err + } + // Ensure the layer didn't get stale + dl.lock.RLock() + defer dl.lock.RUnlock() + + if dl.Stale() { + return common.Hash{}, ErrSnapshotStale + } + // Everything below was journalled, persist this layer too + if err := rlp.Encode(buffer, dl.root); err != nil { + return common.Hash{}, err + } + destructs := make([]journalDestruct, 0, len(dl.destructSet)) + for hash := range dl.destructSet { + destructs = append(destructs, journalDestruct{Hash: hash}) + } + if err := rlp.Encode(buffer, destructs); err != nil { + return common.Hash{}, err + } + accounts := make([]journalAccount, 0, len(dl.accountData)) + for hash, blob := range dl.accountData { + accounts = append(accounts, journalAccount{Hash: hash, Blob: blob}) + } + if err := rlp.Encode(buffer, accounts); err != nil { + return common.Hash{}, err + } + storage := make([]journalStorage, 0, len(dl.storageData)) + for hash, slots := range dl.storageData { + keys := make([]common.Hash, 0, len(slots)) + vals := make([][]byte, 0, len(slots)) + for key, val := range slots { + keys = append(keys, key) + vals = append(vals, val) + } + storage = append(storage, journalStorage{Hash: hash, Keys: keys, Vals: vals}) + } + if err := rlp.Encode(buffer, storage); err != nil { + return common.Hash{}, err + } + return base, nil +} diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go new file mode 100644 index 0000000000..34d7c77177 --- /dev/null +++ b/core/state/snapshot/snapshot.go @@ -0,0 +1,597 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package snapshot implements a journalled, dynamic state dump. 
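+//
+// A snapshot consists of a single persistent disk layer, holding the flat
+// state at some past block, with an in-memory tree of diff layers on top
+// tracking the per-block changes accumulated since.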
+package snapshot + +import ( + "bytes" + "errors" + "fmt" + "sync" + "sync/atomic" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/metrics" + "github.com/tomochain/tomochain/trie" +) + +var ( + snapshotCleanAccountHitMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/hit", nil) + snapshotCleanAccountMissMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/miss", nil) + snapshotCleanAccountInexMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/inex", nil) + snapshotCleanAccountReadMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/read", nil) + snapshotCleanAccountWriteMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/write", nil) + + snapshotCleanStorageHitMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/hit", nil) + snapshotCleanStorageMissMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/miss", nil) + snapshotCleanStorageInexMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/inex", nil) + snapshotCleanStorageReadMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/read", nil) + snapshotCleanStorageWriteMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/write", nil) + + snapshotDirtyAccountHitMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/hit", nil) + snapshotDirtyAccountMissMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/miss", nil) + snapshotDirtyAccountInexMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/inex", nil) + snapshotDirtyAccountReadMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/read", nil) + snapshotDirtyAccountWriteMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/write", nil) + + snapshotDirtyStorageHitMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/hit", nil) + snapshotDirtyStorageMissMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/miss", nil) + snapshotDirtyStorageInexMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/inex", nil) + snapshotDirtyStorageReadMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/read", nil) + snapshotDirtyStorageWriteMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/write", nil) + + snapshotDirtyAccountHitDepthHist = metrics.NewRegisteredHistogram("state/snapshot/dirty/account/hit/depth", nil, metrics.NewExpDecaySample(1028, 0.015)) + snapshotDirtyStorageHitDepthHist = metrics.NewRegisteredHistogram("state/snapshot/dirty/storage/hit/depth", nil, metrics.NewExpDecaySample(1028, 0.015)) + + snapshotFlushAccountItemMeter = metrics.NewRegisteredMeter("state/snapshot/flush/account/item", nil) + snapshotFlushAccountSizeMeter = metrics.NewRegisteredMeter("state/snapshot/flush/account/size", nil) + snapshotFlushStorageItemMeter = metrics.NewRegisteredMeter("state/snapshot/flush/storage/item", nil) + snapshotFlushStorageSizeMeter = metrics.NewRegisteredMeter("state/snapshot/flush/storage/size", nil) + + snapshotBloomIndexTimer = metrics.NewRegisteredResettingTimer("state/snapshot/bloom/index", nil) + snapshotBloomErrorGauge = metrics.NewRegisteredGaugeFloat64("state/snapshot/bloom/error", nil) + + snapshotBloomAccountTrueHitMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/account/truehit", nil) + snapshotBloomAccountFalseHitMeter = 
metrics.NewRegisteredMeter("state/snapshot/bloom/account/falsehit", nil) + snapshotBloomAccountMissMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/account/miss", nil) + + snapshotBloomStorageTrueHitMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/storage/truehit", nil) + snapshotBloomStorageFalseHitMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/storage/falsehit", nil) + snapshotBloomStorageMissMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/storage/miss", nil) + + // ErrSnapshotStale is returned from data accessors if the underlying snapshot + // layer had been invalidated due to the chain progressing forward far enough + // to not maintain the layer's original state. + ErrSnapshotStale = errors.New("snapshot stale") + + // ErrNotCoveredYet is returned from data accessors if the underlying snapshot + // is being generated currently and the requested data item is not yet in the + // range of accounts covered. + ErrNotCoveredYet = errors.New("not covered yet") + + // errSnapshotCycle is returned if a snapshot is attempted to be inserted + // that forms a cycle in the snapshot tree. + errSnapshotCycle = errors.New("snapshot cycle") +) + +// Snapshot represents the functionality supported by a snapshot storage layer. +type Snapshot interface { + // Root returns the root hash for which this snapshot was made. + Root() common.Hash + + // Account directly retrieves the account associated with a particular hash in + // the snapshot slim data format. + Account(hash common.Hash) (*Account, error) + + // AccountRLP directly retrieves the account RLP associated with a particular + // hash in the snapshot slim data format. + AccountRLP(hash common.Hash) ([]byte, error) + + // Storage directly retrieves the storage data associated with a particular hash, + // within a particular account. + Storage(accountHash, storageHash common.Hash) ([]byte, error) +} + +// snapshot is the internal version of the snapshot data layer that supports some +// additional methods compared to the public API. +type snapshot interface { + Snapshot + + // Parent returns the subsequent layer of a snapshot, or nil if the base was + // reached. + // + // Note, the method is an internal helper to avoid type switching between the + // disk and diff layers. There is no locking involved. + Parent() snapshot + + // Update creates a new layer on top of the existing snapshot diff tree with + // the specified data items. + // + // Note, the maps are retained by the method to avoid copying everything. + Update(blockRoot common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer + + // Journal commits an entire diff hierarchy to disk into a single journal entry. + // This is meant to be used during shutdown to persist the snapshot without + // flattening everything down (bad for reorgs). + Journal(buffer *bytes.Buffer) (common.Hash, error) + + // Stale return whether this layer has become stale (was flattened across) or + // if it's still live. + Stale() bool + + // AccountIterator creates an account iterator over an arbitrary layer. + AccountIterator(seek common.Hash) AccountIterator +} + +// SnapshotTree is an Ethereum state snapshot tree. It consists of one persistent +// base layer backed by a key-value store, on top of which arbitrarily many in- +// memory diff layers are topped. The memory diffs can form a tree with branching, +// but the disk layer is singleton and common to all. 
If a reorg goes deeper than +// the disk layer, everything needs to be deleted. +// +// The goal of a state snapshot is twofold: to allow direct access to account and +// storage data to avoid expensive multi-level trie lookups; and to allow sorted, +// cheap iteration of the account/storage tries for sync aid. +type Tree struct { + diskdb ethdb.KeyValueStore // Persistent database to store the snapshot + triedb *trie.Database // In-memory cache to access the trie through + cache int // Megabytes permitted to use for read caches + layers map[common.Hash]snapshot // Collection of all known layers + lock sync.RWMutex +} + +// New attempts to load an already existing snapshot from a persistent key-value +// store (with a number of memory layers from a journal), ensuring that the head +// of the snapshot matches the expected one. +// +// If the snapshot is missing or inconsistent, the entirety is deleted and will +// be reconstructed from scratch based on the tries in the key-value store, on a +// background thread. +func New(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash, async bool) *Tree { + // Create a new, empty snapshot tree + snap := &Tree{ + diskdb: diskdb, + triedb: triedb, + cache: cache, + layers: make(map[common.Hash]snapshot), + } + if !async { + defer snap.waitBuild() + } + // Attempt to load a previously persisted snapshot and rebuild one if failed + head, err := loadSnapshot(diskdb, triedb, cache, root) + if err != nil { + log.Warn("Failed to load snapshot, regenerating", "err", err) + snap.Rebuild(root) + return snap + } + // Existing snapshot loaded, seed all the layers + for head != nil { + snap.layers[head.Root()] = head + head = head.Parent() + } + return snap +} + +// waitBuild blocks until the snapshot finishes rebuilding. This method is meant +// to be used by tests to ensure we're testing what we believe we are. +func (t *Tree) waitBuild() { + // Find the rebuild termination channel + var done chan struct{} + + t.lock.RLock() + for _, layer := range t.layers { + if layer, ok := layer.(*diskLayer); ok { + done = layer.genPending + break + } + } + t.lock.RUnlock() + + // Wait until the snapshot is generated + if done != nil { + <-done + } +} + +// Snapshot retrieves a snapshot belonging to the given block root, or nil if no +// snapshot is maintained for that block. +func (t *Tree) Snapshot(blockRoot common.Hash) Snapshot { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.layers[blockRoot] +} + +// Update adds a new snapshot into the tree, if that can be linked to an existing +// old parent. It is disallowed to insert a disk layer (the origin of all). +func (t *Tree) Update(blockRoot common.Hash, parentRoot common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) error { + // Reject noop updates to avoid self-loops in the snapshot tree. This is a + // special case that can only happen for Clique networks where empty blocks + // don't modify the state (0 block subsidy). + // + // Although we could silently ignore this internally, it should be the caller's + // responsibility to avoid even attempting to insert such a snapshot. 
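+	// (Such a layer would share its root with its parent, and since the layer
+	// tree is keyed by root, the two could no longer be told apart.)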
+	if blockRoot == parentRoot {
+		return errSnapshotCycle
+	}
+	// Generate a new snapshot on top of the parent
+	parent := t.Snapshot(parentRoot)
+	if parent == nil {
+		return fmt.Errorf("parent [%#x] snapshot missing", parentRoot)
+	}
+	snap := parent.(snapshot).Update(blockRoot, destructs, accounts, storage)
+
+	// Save the new snapshot for later
+	t.lock.Lock()
+	defer t.lock.Unlock()
+
+	t.layers[snap.root] = snap
+	return nil
+}
+
+// Cap traverses downwards the snapshot tree from a head block hash until the
+// number of allowed layers is crossed. All layers beyond the permitted number
+// are flattened downwards.
+func (t *Tree) Cap(root common.Hash, layers int) error {
+	// Retrieve the head snapshot to cap from
+	snap := t.Snapshot(root)
+	if snap == nil {
+		return fmt.Errorf("snapshot [%#x] missing", root)
+	}
+	diff, ok := snap.(*diffLayer)
+	if !ok {
+		return fmt.Errorf("snapshot [%#x] is disk layer", root)
+	}
+	// Run the internal capping and discard all stale layers
+	t.lock.Lock()
+	defer t.lock.Unlock()
+
+	// Flattening the bottom-most diff layer requires special casing since there's
+	// no child to rewire to the grandparent. In that case we can fake a temporary
+	// child for the capping and then remove it.
+	var persisted *diskLayer
+
+	switch layers {
+	case 0:
+		// If full commit was requested, flatten the diffs and merge onto disk
+		diff.lock.RLock()
+		base := diffToDisk(diff.flatten().(*diffLayer))
+		diff.lock.RUnlock()
+
+		// Replace the entire snapshot tree with the flat base
+		t.layers = map[common.Hash]snapshot{base.root: base}
+		return nil
+
+	case 1:
+		// If full flattening was requested, flatten the diffs but only merge if the
+		// memory limit was reached
+		var (
+			bottom *diffLayer
+			base   *diskLayer
+		)
+		diff.lock.RLock()
+		bottom = diff.flatten().(*diffLayer)
+		if bottom.memory >= aggregatorMemoryLimit {
+			base = diffToDisk(bottom)
+		}
+		diff.lock.RUnlock()
+
+		// If all diff layers were removed, replace the entire snapshot tree
+		if base != nil {
+			t.layers = map[common.Hash]snapshot{base.root: base}
+			return nil
+		}
+		// Merge the new aggregated layer into the snapshot tree, clean stales below
+		t.layers[bottom.root] = bottom
+
+	default:
+		// Many layers requested to be retained, cap normally
+		persisted = t.cap(diff, layers)
+	}
+	// Remove any layer that is stale or links into a stale layer
+	children := make(map[common.Hash][]common.Hash)
+	for root, snap := range t.layers {
+		if diff, ok := snap.(*diffLayer); ok {
+			parent := diff.parent.Root()
+			children[parent] = append(children[parent], root)
+		}
+	}
+	var remove func(root common.Hash)
+	remove = func(root common.Hash) {
+		delete(t.layers, root)
+		for _, child := range children[root] {
+			remove(child)
+		}
+		delete(children, root)
+	}
+	for root, snap := range t.layers {
+		if snap.Stale() {
+			remove(root)
+		}
+	}
+	// If the disk layer was modified, regenerate all the cumulative blooms
+	if persisted != nil {
+		var rebloom func(root common.Hash)
+		rebloom = func(root common.Hash) {
+			if diff, ok := t.layers[root].(*diffLayer); ok {
+				diff.rebloom(persisted)
+			}
+			for _, child := range children[root] {
+				rebloom(child)
+			}
+		}
+		rebloom(persisted.root)
+	}
+	return nil
+}
+
+// cap traverses downwards the diff tree until the number of allowed layers is
+// crossed. All diffs beyond the permitted number are flattened downwards. If the
+// layer limit is reached, memory cap is also enforced (but not before).
+//
+// The method returns the new disk layer if diffs were persisted into it.
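+//
+// Roughly speaking, after a cap the tree retains up to the requested number of
+// diff layers, with everything deeper merged into a single accumulator layer;
+// that accumulator is only pushed into the disk layer once it outgrows
+// aggregatorMemoryLimit or a generator is running underneath it.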
+func (t *Tree) cap(diff *diffLayer, layers int) *diskLayer { + // Dive until we run out of layers or reach the persistent database + for ; layers > 2; layers-- { + // If we still have diff layers below, continue down + if parent, ok := diff.parent.(*diffLayer); ok { + diff = parent + } else { + // Diff stack too shallow, return without modifications + return nil + } + } + // We're out of layers, flatten anything below, stopping if it's the disk or if + // the memory limit is not yet exceeded. + switch parent := diff.parent.(type) { + case *diskLayer: + return nil + + case *diffLayer: + // Flatten the parent into the grandparent. The flattening internally obtains a + // write lock on grandparent. + flattened := parent.flatten().(*diffLayer) + t.layers[flattened.root] = flattened + + diff.lock.Lock() + defer diff.lock.Unlock() + + diff.parent = flattened + if flattened.memory < aggregatorMemoryLimit { + // Accumulator layer is smaller than the limit, so we can abort, unless + // there's a snapshot being generated currently. In that case, the trie + // will move fron underneath the generator so we **must** merge all the + // partial data down into the snapshot and restart the generation. + if flattened.parent.(*diskLayer).genAbort == nil { + return nil + } + } + default: + panic(fmt.Sprintf("unknown data layer: %T", parent)) + } + // If the bottom-most layer is larger than our memory cap, persist to disk + bottom := diff.parent.(*diffLayer) + + bottom.lock.RLock() + base := diffToDisk(bottom) + bottom.lock.RUnlock() + + t.layers[base.root] = base + diff.parent = base + return base +} + +// diffToDisk merges a bottom-most diff into the persistent disk layer underneath +// it. The method will panic if called onto a non-bottom-most diff layer. +func diffToDisk(bottom *diffLayer) *diskLayer { + var ( + base = bottom.parent.(*diskLayer) + batch = base.diskdb.NewBatch() + stats *generatorStats + ) + // If the disk layer is running a snapshot generator, abort it + if base.genAbort != nil { + abort := make(chan *generatorStats) + base.genAbort <- abort + stats = <-abort + } + // Start by temporarily deleting the current snapshot block marker. This + // ensures that in the case of a crash, the entire snapshot is invalidated. 
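+	// The deletion is queued into the same batch as the data updates below and
+	// the new root is only rewritten in the final batch, so a crash anywhere in
+	// between leaves the database without a snapshot root and triggers a full
+	// rebuild on the next startup.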
+ rawdb.DeleteSnapshotRoot(batch) + + // Mark the original base as stale as we're going to create a new wrapper + base.lock.Lock() + if base.stale { + panic("parent disk layer is stale") // we've committed into the same base from two children, boo + } + base.stale = true + base.lock.Unlock() + + // Destroy all the destructed accounts from the database + for hash := range bottom.destructSet { + // Skip any account not covered yet by the snapshot + if base.genMarker != nil && bytes.Compare(hash[:], base.genMarker) > 0 { + continue + } + // Remove all storage slots + rawdb.DeleteAccountSnapshot(batch, hash) + base.cache.Set(hash[:], nil) + + it := rawdb.IterateStorageSnapshots(base.diskdb, hash) + for it.Next() { + if key := it.Key(); len(key) == 65 { // TODO(karalabe): Yuck, we should move this into the iterator + batch.Delete(key) + base.cache.Del(key[1:]) + + snapshotFlushStorageItemMeter.Mark(1) + } + } + it.Release() + } + // Push all updated accounts into the database + for hash, data := range bottom.accountData { + // Skip any account not covered yet by the snapshot + if base.genMarker != nil && bytes.Compare(hash[:], base.genMarker) > 0 { + continue + } + // Push the account to disk + rawdb.WriteAccountSnapshot(batch, hash, data) + base.cache.Set(hash[:], data) + snapshotCleanAccountWriteMeter.Mark(int64(len(data))) + + if batch.ValueSize() > ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + log.Crit("Failed to write account snapshot", "err", err) + } + batch.Reset() + } + snapshotFlushAccountItemMeter.Mark(1) + snapshotFlushAccountSizeMeter.Mark(int64(len(data))) + } + // Push all the storage slots into the database + for accountHash, storage := range bottom.storageData { + // Skip any account not covered yet by the snapshot + if base.genMarker != nil && bytes.Compare(accountHash[:], base.genMarker) > 0 { + continue + } + // Generation might be mid-account, track that case too + midAccount := base.genMarker != nil && bytes.Equal(accountHash[:], base.genMarker[:common.HashLength]) + + for storageHash, data := range storage { + // Skip any slot not covered yet by the snapshot + if midAccount && bytes.Compare(storageHash[:], base.genMarker[common.HashLength:]) > 0 { + continue + } + if len(data) > 0 { + rawdb.WriteStorageSnapshot(batch, accountHash, storageHash, data) + base.cache.Set(append(accountHash[:], storageHash[:]...), data) + snapshotCleanStorageWriteMeter.Mark(int64(len(data))) + } else { + rawdb.DeleteStorageSnapshot(batch, accountHash, storageHash) + base.cache.Set(append(accountHash[:], storageHash[:]...), nil) + } + snapshotFlushStorageItemMeter.Mark(1) + snapshotFlushStorageSizeMeter.Mark(int64(len(data))) + } + if batch.ValueSize() > ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + log.Crit("Failed to write storage snapshot", "err", err) + } + batch.Reset() + } + } + // Update the snapshot block marker and write any remainder data + rawdb.WriteSnapshotRoot(batch, bottom.root) + if err := batch.Write(); err != nil { + log.Crit("Failed to write leftover snapshot", "err", err) + } + res := &diskLayer{ + root: bottom.root, + cache: base.cache, + diskdb: base.diskdb, + triedb: base.triedb, + genMarker: base.genMarker, + genPending: base.genPending, + } + // If snapshot generation hasn't finished yet, port over all the starts and + // continue where the previous round left off. + // + // Note, the `base.genAbort` comparison is not used normally, it's checked + // to allow the tests to play with the marker without triggering this path. 
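+	// (In practice an interrupted generation leaves both fields set, so this
+	// resume path triggers whenever the previous round did not finish.)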
+	if base.genMarker != nil && base.genAbort != nil {
+		res.genMarker = base.genMarker
+		res.genAbort = make(chan chan *generatorStats)
+		go res.generate(stats)
+	}
+	return res
+}
+
+// Journal commits an entire diff hierarchy to disk into a single journal entry.
+// This is meant to be used during shutdown to persist the snapshot without
+// flattening everything down (bad for reorgs).
+//
+// The method returns the root hash of the base layer that needs to be persisted
+// to disk as a trie as well, to allow continuing any pending generation op.
+func (t *Tree) Journal(root common.Hash) (common.Hash, error) {
+	// Retrieve the head snapshot to journal from
+	snap := t.Snapshot(root)
+	if snap == nil {
+		return common.Hash{}, fmt.Errorf("snapshot [%#x] missing", root)
+	}
+	// Run the journaling
+	t.lock.Lock()
+	defer t.lock.Unlock()
+
+	journal := new(bytes.Buffer)
+	base, err := snap.(snapshot).Journal(journal)
+	if err != nil {
+		return common.Hash{}, err
+	}
+	// Store the journal into the database and return
+	rawdb.WriteSnapshotJournal(t.diskdb, journal.Bytes())
+	return base, nil
+}
+
+// Rebuild wipes all available snapshot data from the persistent database and
+// discards all caches and diff layers. Afterwards, it starts a new snapshot
+// generator with the given root hash.
+func (t *Tree) Rebuild(root common.Hash) {
+	t.lock.Lock()
+	defer t.lock.Unlock()
+
+	// Iterate over and mark all layers stale
+	for _, layer := range t.layers {
+		switch layer := layer.(type) {
+		case *diskLayer:
+			// If the base layer is generating, abort it and save
+			if layer.genAbort != nil {
+				abort := make(chan *generatorStats)
+				layer.genAbort <- abort
+				<-abort
+			}
+			// Layer should be inactive now, mark it as stale
+			layer.lock.Lock()
+			layer.stale = true
+			layer.lock.Unlock()
+
+		case *diffLayer:
+			// If the layer is a simple diff, simply mark as stale
+			layer.lock.Lock()
+			atomic.StoreUint32(&layer.stale, 1)
+			layer.lock.Unlock()
+
+		default:
+			panic(fmt.Sprintf("unknown layer type: %T", layer))
+		}
+	}
+	// Start generating a new snapshot from scratch on a background thread. The
+	// generator will run a wiper first if there's not one running right now.
+	log.Info("Rebuilding state snapshot")
+	t.layers = map[common.Hash]snapshot{
+		root: generateSnapshot(t.diskdb, t.triedb, t.cache, root),
+	}
+}
+
+// AccountIterator creates a new account iterator for the specified root hash and
+// seeks to a starting account hash.
+func (t *Tree) AccountIterator(root common.Hash, seek common.Hash) (AccountIterator, error) {
+	return newFastAccountIterator(t, root, seek)
+}
diff --git a/core/state/snapshot/snapshot_test.go b/core/state/snapshot/snapshot_test.go
new file mode 100644
index 0000000000..75e53186b8
--- /dev/null
+++ b/core/state/snapshot/snapshot_test.go
@@ -0,0 +1,348 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package snapshot
+
+import (
+	"fmt"
+	"math/big"
+	"math/rand"
+	"testing"
+
+	"github.com/VictoriaMetrics/fastcache"
+	"github.com/tomochain/tomochain/common"
+	"github.com/tomochain/tomochain/core/rawdb"
+	"github.com/tomochain/tomochain/rlp"
+)
+
+// randomHash generates a random blob of data and returns it as a hash.
+func randomHash() common.Hash {
+	var hash common.Hash
+	if n, err := rand.Read(hash[:]); n != common.HashLength || err != nil {
+		panic(err)
+	}
+	return hash
+}
+
+// randomAccount generates a random account and returns it RLP encoded.
+func randomAccount() []byte {
+	root := randomHash()
+	a := Account{
+		Balance:  big.NewInt(rand.Int63()),
+		Nonce:    rand.Uint64(),
+		Root:     root[:],
+		CodeHash: emptyCode[:],
+	}
+	data, _ := rlp.EncodeToBytes(a)
+	return data
+}
+
+// randomAccountSet generates a set of random accounts with the given strings as
+// the account address hashes.
+func randomAccountSet(hashes ...string) map[common.Hash][]byte {
+	accounts := make(map[common.Hash][]byte)
+	for _, hash := range hashes {
+		accounts[common.HexToHash(hash)] = randomAccount()
+	}
+	return accounts
+}
+
+// Tests that if a disk layer becomes stale, no active external references will
+// be returned with junk data. This version of the test flattens every diff layer
+// to check an internal corner case around the bottom-most memory accumulator.
+func TestDiskLayerExternalInvalidationFullFlatten(t *testing.T) {
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	// Retrieve a reference to the base and commit a diff on top
+	ref := snaps.Snapshot(base.root)
+
+	accounts := map[common.Hash][]byte{
+		common.HexToHash("0xa1"): randomAccount(),
+	}
+	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
+		t.Fatalf("failed to create a diff layer: %v", err)
+	}
+	if n := len(snaps.layers); n != 2 {
+		t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 2)
+	}
+	// Commit the diff layer onto the disk and ensure it's persisted
+	if err := snaps.Cap(common.HexToHash("0x02"), 0); err != nil {
+		t.Fatalf("failed to merge diff layer onto disk: %v", err)
+	}
+	// Since the base layer was modified, ensure that data retrievals on the external reference fail
+	if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale {
+		t.Errorf("stale reference returned account: %#x (err: %v)", acc, err)
+	}
+	if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale {
+		t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err)
+	}
+	if n := len(snaps.layers); n != 1 {
+		t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 1)
+		fmt.Println(snaps.layers)
+	}
+}
+
+// Tests that if a disk layer becomes stale, no active external references will
+// be returned with junk data. This version of the test retains the bottom diff
+// layer to check the usual mode of operation where the accumulator is retained.
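+// (The test forces aggregatorMemoryLimit down to zero, so a Cap with two
+// allowed layers immediately persists the accumulator into the disk layer.)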
+func TestDiskLayerExternalInvalidationPartialFlatten(t *testing.T) {
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	// Retrieve a reference to the base and commit two diffs on top
+	ref := snaps.Snapshot(base.root)
+
+	accounts := map[common.Hash][]byte{
+		common.HexToHash("0xa1"): randomAccount(),
+	}
+	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
+		t.Fatalf("failed to create a diff layer: %v", err)
+	}
+	if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil {
+		t.Fatalf("failed to create a diff layer: %v", err)
+	}
+	if n := len(snaps.layers); n != 3 {
+		t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 3)
+	}
+	// Commit the diff layer onto the disk and ensure it's persisted
+	defer func(memcap uint64) { aggregatorMemoryLimit = memcap }(aggregatorMemoryLimit)
+	aggregatorMemoryLimit = 0
+
+	if err := snaps.Cap(common.HexToHash("0x03"), 2); err != nil {
+		t.Fatalf("failed to merge diff layer onto disk: %v", err)
+	}
+	// Since the base layer was modified, ensure that data retrievals on the external reference fail
+	if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale {
+		t.Errorf("stale reference returned account: %#x (err: %v)", acc, err)
+	}
+	if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale {
+		t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err)
+	}
+	if n := len(snaps.layers); n != 2 {
+		t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 2)
+		fmt.Println(snaps.layers)
+	}
+}
+
+// Tests that if a diff layer becomes stale, no active external references will
+// be returned with junk data. This version of the test flattens every diff layer
+// to check an internal corner case around the bottom-most memory accumulator.
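+// (A Cap with a single allowed layer merges all diffs into the bottom
+// accumulator without persisting it, invalidating references into the
+// flattened layers.)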
+func TestDiffLayerExternalInvalidationFullFlatten(t *testing.T) {
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	// Commit two diffs on top and retrieve a reference to the bottommost
+	accounts := map[common.Hash][]byte{
+		common.HexToHash("0xa1"): randomAccount(),
+	}
+	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
+		t.Fatalf("failed to create a diff layer: %v", err)
+	}
+	if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil {
+		t.Fatalf("failed to create a diff layer: %v", err)
+	}
+	if n := len(snaps.layers); n != 3 {
+		t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 3)
+	}
+	ref := snaps.Snapshot(common.HexToHash("0x02"))
+
+	// Flatten the diff layer into the bottom accumulator
+	if err := snaps.Cap(common.HexToHash("0x03"), 1); err != nil {
+		t.Fatalf("failed to flatten diff layer into accumulator: %v", err)
+	}
+	// Since the accumulator diff layer was modified, ensure that data retrievals on the external reference fail
+	if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale {
+		t.Errorf("stale reference returned account: %#x (err: %v)", acc, err)
+	}
+	if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale {
+		t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err)
+	}
+	if n := len(snaps.layers); n != 2 {
+		t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 2)
+		fmt.Println(snaps.layers)
+	}
+}
+
+// Tests that if a diff layer becomes stale, no active external references will
+// be returned with junk data. This version of the test retains the bottom diff
+// layer to check the usual mode of operation where the accumulator is retained.
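+// (A Cap with a generous layer allowance is exercised first as a no-op,
+// after which a Cap with two allowed layers folds the bottom diffs into the
+// accumulator.)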
+func TestDiffLayerExternalInvalidationPartialFlatten(t *testing.T) {
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	// Commit three diffs on top and retrieve a reference to the bottommost
+	accounts := map[common.Hash][]byte{
+		common.HexToHash("0xa1"): randomAccount(),
+	}
+	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
+		t.Fatalf("failed to create a diff layer: %v", err)
+	}
+	if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil {
+		t.Fatalf("failed to create a diff layer: %v", err)
+	}
+	if err := snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, accounts, nil); err != nil {
+		t.Fatalf("failed to create a diff layer: %v", err)
+	}
+	if n := len(snaps.layers); n != 4 {
+		t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 4)
+	}
+	ref := snaps.Snapshot(common.HexToHash("0x02"))
+
+	// Doing a Cap operation with many allowed layers should be a no-op
+	exp := len(snaps.layers)
+	if err := snaps.Cap(common.HexToHash("0x04"), 2000); err != nil {
+		t.Fatalf("failed to flatten diff layer into accumulator: %v", err)
+	}
+	if got := len(snaps.layers); got != exp {
+		t.Errorf("layers modified, got %d exp %d", got, exp)
+	}
+	// Flatten the diff layer into the bottom accumulator
+	if err := snaps.Cap(common.HexToHash("0x04"), 2); err != nil {
+		t.Fatalf("failed to flatten diff layer into accumulator: %v", err)
+	}
+	// Since the accumulator diff layer was modified, ensure that data retrievals on the external reference fail
+	if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale {
+		t.Errorf("stale reference returned account: %#x (err: %v)", acc, err)
+	}
+	if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale {
+		t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err)
+	}
+	if n := len(snaps.layers); n != 3 {
+		t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 3)
+		fmt.Println(snaps.layers)
+	}
+}
+
+// TestPostCapBasicDataAccess exercises account reads across the snapshot tree
+// before and after capping/flattening.
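+// It builds a small fork on top of the disk layer (0x01 <- 0xa1 <- {0xa2 <- 0xa3,
+// 0xb2 <- 0xb3}) and verifies which accounts remain readable before and after the
+// a-chain is capped into the disk layer.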
+func TestPostCapBasicDataAccess(t *testing.T) {
+	// setAccount is a helper to construct a random account entry and assign it to
+	// an account slot in a snapshot
+	setAccount := func(accKey string) map[common.Hash][]byte {
+		return map[common.Hash][]byte{
+			common.HexToHash(accKey): randomAccount(),
+		}
+	}
+	// Create a starting base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	// The lowest difflayer
+	snaps.Update(common.HexToHash("0xa1"), common.HexToHash("0x01"), nil, setAccount("0xa1"), nil)
+	snaps.Update(common.HexToHash("0xa2"), common.HexToHash("0xa1"), nil, setAccount("0xa2"), nil)
+	snaps.Update(common.HexToHash("0xb2"), common.HexToHash("0xa1"), nil, setAccount("0xb2"), nil)
+
+	snaps.Update(common.HexToHash("0xa3"), common.HexToHash("0xa2"), nil, setAccount("0xa3"), nil)
+	snaps.Update(common.HexToHash("0xb3"), common.HexToHash("0xb2"), nil, setAccount("0xb3"), nil)
+
+	// checkExist verifies that an account exists in a snapshot
+	checkExist := func(layer *diffLayer, key string) error {
+		if data, _ := layer.Account(common.HexToHash(key)); data == nil {
+			return fmt.Errorf("expected %x to exist, got nil", common.HexToHash(key))
+		}
+		return nil
+	}
+	// shouldErr checks that an account access errors as expected
+	shouldErr := func(layer *diffLayer, key string) error {
+		if data, err := layer.Account(common.HexToHash(key)); err == nil {
+			return fmt.Errorf("expected error, got data %x", data)
+		}
+		return nil
+	}
+	// check basics
+	snap := snaps.Snapshot(common.HexToHash("0xb3")).(*diffLayer)
+
+	if err := checkExist(snap, "0xa1"); err != nil {
+		t.Error(err)
+	}
+	if err := checkExist(snap, "0xb2"); err != nil {
+		t.Error(err)
+	}
+	if err := checkExist(snap, "0xb3"); err != nil {
+		t.Error(err)
+	}
+	// Cap to a bad root should fail
+	if err := snaps.Cap(common.HexToHash("0x1337"), 0); err == nil {
+		t.Errorf("expected error, got none")
+	}
+	// Now, merge the a-chain
+	snaps.Cap(common.HexToHash("0xa3"), 0)
+
+	// At this point, a2 got merged into a1. Thus, a1 is now modified, and as a1 is
+	// the parent of b2, b2 should no longer be able to iterate into its parent.
+
+	// These should still be accessible
+	if err := checkExist(snap, "0xb2"); err != nil {
+		t.Error(err)
+	}
+	if err := checkExist(snap, "0xb3"); err != nil {
+		t.Error(err)
+	}
+	// But these would need iteration into the modified parent
+	if err := shouldErr(snap, "0xa1"); err != nil {
+		t.Error(err)
+	}
+	if err := shouldErr(snap, "0xa2"); err != nil {
+		t.Error(err)
+	}
+	if err := shouldErr(snap, "0xa3"); err != nil {
+		t.Error(err)
+	}
+	// Now, merge it again, just for fun. It should now error, since a3
+	// is a disk layer
+	if err := snaps.Cap(common.HexToHash("0xa3"), 0); err == nil {
+		t.Error("expected error capping the disk layer, got none")
+	}
+}
diff --git a/core/state/snapshot/sort.go b/core/state/snapshot/sort.go
new file mode 100644
index 0000000000..dc877911a1
--- /dev/null
+++ b/core/state/snapshot/sort.go
@@ -0,0 +1,36 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + + "github.com/tomochain/tomochain/common" +) + +// hashes is a helper to implement sort.Interface. +type hashes []common.Hash + +// Len is the number of elements in the collection. +func (hs hashes) Len() int { return len(hs) } + +// Less reports whether the element with index i should sort before the element +// with index j. +func (hs hashes) Less(i, j int) bool { return bytes.Compare(hs[i][:], hs[j][:]) < 0 } + +// Swap swaps the elements with indexes i and j. +func (hs hashes) Swap(i, j int) { hs[i], hs[j] = hs[j], hs[i] } diff --git a/core/state/state_object.go b/core/state/state_object.go index 0aa936fc9c..477bc02da2 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -329,7 +329,7 @@ func (self *stateObject) setCode(codeHash common.Hash, code []byte) { func (self *stateObject) SetNonce(nonce uint64) { self.db.journal.append(nonceChange{ - account: &self.address, + account: &self.address, prev: self.data.Nonce, }) self.setNonce(nonce) diff --git a/core/state/state_test.go b/core/state/state_test.go index 85cb7ee5b7..2c40de412b 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -138,7 +138,7 @@ func (s *StateSuite) TestSnapshotEmpty(c *checker.C) { // printing/logging in tests (-check.vv does not work) func TestSnapshot2(t *testing.T) { db := rawdb.NewMemoryDatabase() - state, _ := New(common.Hash{}, NewDatabase(db)) + state, _ := New(common.Hash{}, NewDatabase(db), nil) stateobjaddr0 := toAddr([]byte("so0")) stateobjaddr1 := toAddr([]byte("so1")) diff --git a/core/state/statedb.go b/core/state/statedb.go index a6dfb7f0a6..9552e68b51 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -24,6 +24,8 @@ import ( "sync" "time" + "github.com/tomochain/tomochain/core/state/snapshot" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" @@ -57,9 +59,8 @@ type StateDB struct { db Database trie Trie - snapDestructs map[common.Hash]struct{} - snapAccounts map[common.Hash][]byte - snapStorage map[common.Hash]map[common.Hash][]byte + snaps *snapshot.Tree // Nil if snapshot is not available + snap snapshot.Snapshot // Nil if snapshot is not available // This map holds 'live' objects, which will get modified while processing a state transition. stateObjects map[common.Address]*stateObject @@ -122,7 +123,7 @@ func (self *StateDB) GetCommittedState(addr common.Address, hash common.Hash) co } // Create a new state from a given trie. 
-func New(root common.Hash, db Database) (*StateDB, error) { +func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { tr, err := db.OpenTrie(root) if err != nil { return nil, err @@ -130,12 +131,16 @@ func New(root common.Hash, db Database) (*StateDB, error) { sdb := &StateDB{ db: db, trie: tr, + snaps: snaps, stateObjects: make(map[common.Address]*stateObject), stateObjectsDirty: make(map[common.Address]struct{}), logs: make(map[common.Hash][]*types.Log), preimages: make(map[common.Hash][]byte), journal: newJournal(), } + if sdb.snaps != nil { + sdb.snap = sdb.snaps.Snapshot(root) + } return sdb, nil } diff --git a/core/state/sync_test.go b/core/state/sync_test.go index 19fefb6548..69c6491f01 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -41,7 +41,7 @@ type testAccount struct { func makeTestState() (Database, common.Hash, []*testAccount) { // Create an empty state db := NewDatabase(rawdb.NewMemoryDatabase()) - state, _ := New(common.Hash{}, db) + state, _ := New(common.Hash{}, db, nil) // Fill it with some arbitrary data accounts := []*testAccount{} @@ -72,7 +72,7 @@ func makeTestState() (Database, common.Hash, []*testAccount) { // account array. func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) { // Check root availability and state contents - state, err := New(root, NewDatabase(db)) + state, err := New(root, NewDatabase(db), nil) if err != nil { t.Fatalf("failed to create state trie at %x: %v", root, err) } @@ -113,7 +113,7 @@ func checkStateConsistency(db ethdb.Database, root common.Hash) error { if _, err := db.Get(root.Bytes()); err != nil { return nil // Consider a non existent state consistent. } - state, err := New(root, NewDatabase(db)) + state, err := New(root, NewDatabase(db), nil) if err != nil { return err } diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 8ddb0650ea..058968d24e 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -26,6 +26,9 @@ import ( "testing" "time" + "github.com/tomochain/tomochain/consensus" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core/rawdb" @@ -97,7 +100,7 @@ func pricedTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ec func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { diskdb := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(diskdb)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(diskdb), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} key, _ := crypto.GenerateKey() @@ -177,7 +180,7 @@ func (c *testChain) State() (*state.StateDB, error) { stdb := c.statedb if *c.trigger { db := rawdb.NewMemoryDatabase() - c.statedb, _ = state.New(common.Hash{}, state.NewDatabase(db)) + c.statedb, _ = state.New(common.Hash{}, state.NewDatabase(db), nil) // simulate that the new head block included tx0 and tx1 c.statedb.SetNonce(c.address, 2) c.statedb.SetBalance(c.address, new(big.Int).SetUint64(params.Ether)) @@ -196,7 +199,7 @@ func TestStateChangeDuringTransactionPoolReset(t *testing.T) { db = rawdb.NewMemoryDatabase() key, _ = crypto.GenerateKey() address = crypto.PubkeyToAddress(key.PublicKey) - statedb, _ = state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ = state.New(common.Hash{}, state.NewDatabase(db), nil) trigger = false ) @@ -356,7 +359,7 @@ func TestTransactionChainFork(t *testing.T) { addr := 
crypto.PubkeyToAddress(key.PublicKey) resetState := func() { db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) statedb.AddBalance(addr, big.NewInt(100000000000000)) pool.chain = &testBlockChain{statedb, 1000000, new(event.Feed)} @@ -386,7 +389,7 @@ func TestTransactionDoubleNonce(t *testing.T) { addr := crypto.PubkeyToAddress(key.PublicKey) resetState := func() { db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) statedb.AddBalance(addr, big.NewInt(100000000000000)) pool.chain = &testBlockChain{statedb, 1000000, new(event.Feed)} @@ -577,7 +580,7 @@ func TestTransactionPostponing(t *testing.T) { // Create the pool to test the postponing with db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) @@ -793,7 +796,7 @@ func testTransactionQueueGlobalLimiting(t *testing.T, nolocals bool) { // Create the pool to test the limit enforcement with db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} config := testTxPoolConfig @@ -886,7 +889,7 @@ func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { // Create the pool to test the non-expiration enforcement db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} config := testTxPoolConfig @@ -1043,7 +1046,7 @@ func TestTransactionPendingGlobalLimiting(t *testing.T) { // Create the pool to test the limit enforcement with db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} config := testTxPoolConfig @@ -1090,7 +1093,7 @@ func TestTransactionCapClearsFromAll(t *testing.T) { // Create the pool to test the limit enforcement with db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} config := testTxPoolConfig @@ -1125,7 +1128,7 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) { // Create the pool to test the limit enforcement with db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} config := testTxPoolConfig @@ -1175,7 +1178,7 @@ func TestTransactionPoolRepricing(t *testing.T) { // Create the pool to test the pricing enforcement with db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} pool := NewTxPool(testTxPoolConfig, 
params.TestChainConfig, blockchain) @@ -1297,7 +1300,7 @@ func TestTransactionPoolRepricingKeepsLocals(t *testing.T) { // Create the pool to test the pricing enforcement with db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) @@ -1360,7 +1363,7 @@ func TestTransactionPoolUnderpricing(t *testing.T) { // Create the pool to test the pricing enforcement with db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} config := testTxPoolConfig @@ -1462,7 +1465,7 @@ func TestTransactionReplacement(t *testing.T) { // Create the pool to test the pricing enforcement with db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) @@ -1557,7 +1560,7 @@ func testTransactionJournaling(t *testing.T, nolocals bool) { // Create the original pool to inject transaction into the journal db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} config := testTxPoolConfig @@ -1656,7 +1659,7 @@ func TestTransactionStatusCheck(t *testing.T) { // Create the pool to test the status retrievals with db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) diff --git a/core/vm/gas_table_test.go b/core/vm/gas_table_test.go index ba31cf4945..7e7df4f891 100644 --- a/core/vm/gas_table_test.go +++ b/core/vm/gas_table_test.go @@ -17,11 +17,12 @@ package vm import ( - "github.com/tomochain/tomochain/core/rawdb" "math" "math/big" "testing" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/core/state" @@ -81,7 +82,7 @@ func TestEIP2200(t *testing.T) { for i, tt := range eip2200Tests { address := common.BytesToAddress([]byte("contract")) db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) statedb.CreateAccount(address) statedb.SetCode(address, hexutil.MustDecode(tt.input)) statedb.SetState(address, common.Hash{}, common.BytesToHash([]byte{tt.original})) @@ -91,7 +92,7 @@ func TestEIP2200(t *testing.T) { CanTransfer: func(StateDB, common.Address, *big.Int) bool { return true }, Transfer: func(StateDB, common.Address, common.Address, *big.Int) {}, } - vmenv := NewEVM(vmctx, statedb, nil,params.AllEthashProtocolChanges, Config{ExtraEips: []int{2200}}) + vmenv := NewEVM(vmctx, statedb, nil, params.AllEthashProtocolChanges, Config{ExtraEips: []int{2200}}) _, gas, err := vmenv.Call(AccountRef(common.Address{}), address, nil, tt.gaspool, 
new(big.Int)) if err != tt.failure { diff --git a/core/vm/runtime/runtime.go b/core/vm/runtime/runtime.go index 683cad1d1c..9a13d3d6f6 100644 --- a/core/vm/runtime/runtime.go +++ b/core/vm/runtime/runtime.go @@ -17,11 +17,12 @@ package runtime import ( - "github.com/tomochain/tomochain/core/rawdb" "math" "math/big" "time" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/vm" @@ -100,7 +101,7 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) { if cfg.State == nil { db := rawdb.NewMemoryDatabase() - cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db)) + cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db), nil) } var ( address = common.BytesToAddress([]byte("contract")) @@ -131,7 +132,7 @@ func Create(input []byte, cfg *Config) ([]byte, common.Address, uint64, error) { if cfg.State == nil { db := rawdb.NewMemoryDatabase() - cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db)) + cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db), nil) } var ( vmenv = NewEnv(cfg) diff --git a/core/vm/runtime/runtime_test.go b/core/vm/runtime/runtime_test.go index e430c2b2ae..0b95751d34 100644 --- a/core/vm/runtime/runtime_test.go +++ b/core/vm/runtime/runtime_test.go @@ -17,11 +17,12 @@ package runtime import ( - "github.com/tomochain/tomochain/core/rawdb" "math/big" "strings" "testing" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/accounts/abi" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" @@ -99,7 +100,7 @@ func TestExecute(t *testing.T) { func TestCall(t *testing.T) { db := rawdb.NewMemoryDatabase() - state, _ := state.New(common.Hash{}, state.NewDatabase(db)) + state, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) address := common.HexToAddress("0x0a") state.SetCode(address, []byte{ byte(vm.PUSH1), 10, @@ -156,7 +157,7 @@ func BenchmarkCall(b *testing.B) { func benchmarkEVM_Create(bench *testing.B, code string) { var ( db = rawdb.NewMemoryDatabase() - statedb, _ = state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ = state.New(common.Hash{}, state.NewDatabase(db), nil) sender = common.BytesToAddress([]byte("sender")) receiver = common.BytesToAddress([]byte("receiver")) ) diff --git a/eth/api_test.go b/eth/api_test.go index f9f2fc43da..f0a48df523 100644 --- a/eth/api_test.go +++ b/eth/api_test.go @@ -17,10 +17,11 @@ package eth import ( - "github.com/tomochain/tomochain/core/rawdb" "reflect" "testing" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/davecgh/go-spew/spew" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/state" @@ -32,7 +33,7 @@ func TestStorageRangeAt(t *testing.T) { // Create a state where account 0x010000... has a few storage entries. 
var ( db = rawdb.NewMemoryDatabase() - state, _ = state.New(common.Hash{}, state.NewDatabase(db)) + state, _ = state.New(common.Hash{}, state.NewDatabase(db), nil) addr = common.Address{0x01} keys = []common.Hash{ // hashes of Keys of storage common.HexToHash("340dd630ad21bf010b4e676dbfa9ba9a02175262d1fa356232cfde6cb5b47ef2"), diff --git a/eth/api_tracer.go b/eth/api_tracer.go index 9429dee29f..f5abeab34d 100644 --- a/eth/api_tracer.go +++ b/eth/api_tracer.go @@ -145,7 +145,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl return nil, fmt.Errorf("parent block #%d not found", number-1) } } - statedb, err := state.New(start.Root(), database) + statedb, err := state.New(start.Root(), database, nil) var tomoxState *tradingstate.TradingStateDB if err != nil { // If the starting state is missing, allow some number of blocks to be reexecuted @@ -159,7 +159,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl if start == nil { break } - if statedb, err = state.New(start.Root(), database); err == nil { + if statedb, err = state.New(start.Root(), database, nil); err == nil { tomoxState, err = tradingstate.New(start.Root(), tradingstate.NewDatabase(api.eth.TomoX.GetLevelDB())) if err == nil { break @@ -514,8 +514,8 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (* if block == nil { break } - if statedb, err = state.New(block.Root(), database); err == nil { - tomoxState, err = api.eth.blockchain.OrderStateAt(block) + if statedb, err = state.New(block.Root(), database, nil); err == nil { + tomoxState, err = tradingstate.New(block.Root(), tradingstate.NewDatabase(api.eth.TomoX.GetLevelDB())) if err == nil { break } diff --git a/eth/backend.go b/eth/backend.go index 8bd7806bfc..1fe58ca2da 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -121,6 +121,7 @@ func New(ctx *node.ServiceContext, config *Config, tomoXServ *tomox.TomoX, lendi if !config.SyncMode.IsValid() { return nil, fmt.Errorf("invalid sync mode %d", config.SyncMode) } + chainDb, err := CreateDB(ctx, config, "chaindata") if err != nil { return nil, err @@ -164,7 +165,12 @@ func New(ctx *node.ServiceContext, config *Config, tomoXServ *tomox.TomoX, lendi } var ( vmConfig = vm.Config{EnablePreimageRecording: config.EnablePreimageRecording} - cacheConfig = &core.CacheConfig{Disabled: config.NoPruning, TrieNodeLimit: config.TrieCache, TrieTimeLimit: config.TrieTimeout} + cacheConfig = &core.CacheConfig{ + Disabled: config.NoPruning, + TrieNodeLimit: config.TrieCache, + TrieTimeLimit: config.TrieTimeout, + SnapshotLimit: config.SnapshotCache, + } ) if eth.chainConfig.Posv != nil { c := eth.engine.(*posv.Posv) diff --git a/eth/config.go b/eth/config.go index a86f084561..8b62ab7e48 100644 --- a/eth/config.go +++ b/eth/config.go @@ -48,6 +48,7 @@ var DefaultConfig = Config{ DatabaseCache: 768, TrieCache: 256, TrieTimeout: 5 * time.Minute, + SnapshotCache: 256, GasPrice: big.NewInt(0.25 * params.Shannon), TxPool: core.DefaultTxPoolConfig, @@ -93,6 +94,7 @@ type Config struct { DatabaseCache int TrieCache int TrieTimeout time.Duration + SnapshotCache int // Mining-related options Etherbase common.Address `toml:",omitempty"` diff --git a/eth/fetcher/fetcher.go b/eth/fetcher/fetcher.go index 142089586c..d1bc108fd2 100644 --- a/eth/fetcher/fetcher.go +++ b/eth/fetcher/fetcher.go @@ -23,6 +23,7 @@ import ( "time" lru "github.com/hashicorp/golang-lru" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" 
"github.com/tomochain/tomochain/core/types" diff --git a/eth/handler_test.go b/eth/handler_test.go index d8d2f00979..bee29ea90e 100644 --- a/eth/handler_test.go +++ b/eth/handler_test.go @@ -17,13 +17,14 @@ package eth import ( - "github.com/tomochain/tomochain/core/rawdb" "math" "math/big" "math/rand" "testing" "time" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" @@ -343,9 +344,9 @@ func testGetNodeData(t *testing.T, protocol int) { // Fetch for now the entire chain db hashes := []common.Hash{} - it:=db.NewIterator(nil,nil) + it := db.NewIterator(nil, nil) for it.Next() { - key:=it.Key() + key := it.Key() if len(key) == len(common.Hash{}) { hashes = append(hashes, common.BytesToHash(key)) } @@ -374,7 +375,7 @@ func testGetNodeData(t *testing.T, protocol int) { } accounts := []common.Address{testBank, acc1Addr, acc2Addr} for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ { - trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb)) + trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb), nil) for j, acc := range accounts { state, _ := pm.blockchain.State() @@ -470,7 +471,7 @@ func testDAOChallenge(t *testing.T, localForked, remoteForked bool, timeout bool var ( evmux = new(event.TypeMux) pow = ethash.NewFaker() - db = rawdb.NewMemoryDatabase() + db = rawdb.NewMemoryDatabase() config = ¶ms.ChainConfig{DAOForkBlock: big.NewInt(1), DAOForkSupport: localForked} gspec = &core.Genesis{Config: config} genesis = gspec.MustCommit(db) diff --git a/eth/tracers/tracers_test.go b/eth/tracers/tracers_test.go index 9f469aeb89..b0f96d4a68 100644 --- a/eth/tracers/tracers_test.go +++ b/eth/tracers/tracers_test.go @@ -170,7 +170,7 @@ func TestPrestateTracerCreate2(t *testing.T) { Balance: big.NewInt(500000000000000), } db := rawdb.NewMemoryDatabase() - statedb := tests.MakePreState(db, alloc) + statedb := tests.MakePreState(db, alloc, false) // Create the tracer, the EVM environment and run it tracer, err := New("prestateTracer") @@ -245,7 +245,7 @@ func TestCallTracer(t *testing.T) { GasPrice: tx.GasPrice(), } db := rawdb.NewMemoryDatabase() - statedb := tests.MakePreState(db, test.Genesis.Alloc) + statedb := tests.MakePreState(db, test.Genesis.Alloc, false) // Create the tracer, the EVM environment and run it tracer, err := New("callTracer") diff --git a/go.mod b/go.mod index ae326fcb35..6c16905987 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,6 @@ require ( github.com/stretchr/testify v1.8.1 github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 golang.org/x/crypto v0.1.0 - golang.org/x/exp v0.0.0-20230728194245-b0cb94b80691 golang.org/x/net v0.8.0 golang.org/x/sync v0.1.0 golang.org/x/sys v0.7.0 @@ -59,6 +58,7 @@ require ( github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect github.com/google/uuid v1.3.0 // indirect + github.com/holiman/bloomfilter/v2 v2.0.3 github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/maruel/panicparse v0.0.0-20160720141634-ad661195ed0e // indirect diff --git a/go.sum b/go.sum index 6699d53904..693c5630c6 100644 --- a/go.sum +++ b/go.sum @@ -106,6 +106,8 @@ github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/golang-lru 
v0.5.3 h1:YPkqC67at8FYaadspW/6uE0COsBxS2656RLEr8Bppgk= github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/holiman/bloomfilter/v2 v2.0.3 h1:73e0e/V0tCydx14a0SCYS/EWCxgwLZ18CZcZKVu0fao= +github.com/holiman/bloomfilter/v2 v2.0.3/go.mod h1:zpoh+gs7qcpqrHr3dB55AMiJwo0iURXE7ZOP9L9hSkA= github.com/holiman/uint256 v1.2.2 h1:TXKcSGc2WaxPD2+bmzAsVthL4+pEN0YwXcL5qED83vk= github.com/holiman/uint256 v1.2.2/go.mod h1:SC8Ryt4n+UBbPbIBKaG9zbbDlp4jOru9xFZmPzLUTxw= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= @@ -246,8 +248,6 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= -golang.org/x/exp v0.0.0-20230728194245-b0cb94b80691 h1:/yRP+0AN7mf5DkD3BAI6TOFnd51gEoDEb8o35jIFtgw= -golang.org/x/exp v0.0.0-20230728194245-b0cb94b80691/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= diff --git a/les/odr_test.go b/les/odr_test.go index 556f6467c1..302a294259 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -23,6 +23,8 @@ import ( "testing" "time" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/core" @@ -91,7 +93,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon for _, addr := range acc { if bc != nil { header := bc.GetHeaderByHash(bhash) - st, err = state.New(header.Root, state.NewDatabase(db)) + st, err = state.New(header.Root, state.NewDatabase(db), nil) } else { header := lc.GetHeaderByHash(bhash) st = light.NewState(ctx, header, lc.Odr()) @@ -117,7 +119,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai data[35] = byte(i) if bc != nil { header := bc.GetHeaderByHash(bhash) - statedb, err := state.New(header.Root, state.NewDatabase(db)) + statedb, err := state.New(header.Root, state.NewDatabase(db), nil) if err == nil { from := statedb.GetOrNewStateObject(testBankAddress) diff --git a/light/odr_test.go b/light/odr_test.go index 997f7049ac..1debcd4c1e 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -24,6 +24,9 @@ import ( "testing" "time" + "github.com/tomochain/tomochain/consensus" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus" @@ -140,7 +143,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc st = NewState(ctx, header, lc.Odr()) } else { header := bc.GetHeaderByHash(bhash) - st, _ = state.New(header.Root, state.NewDatabase(db)) + st, _ = state.New(header.Root, state.NewDatabase(db), nil) } var res []byte @@ -180,7 +183,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain } else { chain = bc header = bc.GetHeaderByHash(bhash) - st, _ = state.New(header.Root, state.NewDatabase(db)) + st, _ = state.New(header.Root, state.NewDatabase(db), nil) 
} // Perform read-only call. diff --git a/light/trie.go b/light/trie.go index c469491ef8..8d32392f46 100644 --- a/light/trie.go +++ b/light/trie.go @@ -31,7 +31,7 @@ import ( ) func NewState(ctx context.Context, head *types.Header, odr OdrBackend) *state.StateDB { - state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr)) + state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr), nil) return state } diff --git a/tests/state_test.go b/tests/state_test.go index 81a7370d60..a6d23edacf 100644 --- a/tests/state_test.go +++ b/tests/state_test.go @@ -53,13 +53,17 @@ func TestState(t *testing.T) { subtest := subtest key := fmt.Sprintf("%s/%d", subtest.Fork, subtest.Index) name := name + "/" + key - t.Run(key, func(t *testing.T) { - if subtest.Fork == "Constantinople" { - t.Skip("constantinople not supported yet") - } + + t.Run(key+"/trie", func(t *testing.T) { + withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { + _, err := test.Run(subtest, vmconfig, false) + return st.checkFailure(t, name+"/trie", err) + }) + }) + t.Run(key+"/snap", func(t *testing.T) { withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { - _, err := test.Run(subtest, vmconfig) - return st.checkFailure(t, name, err) + _, err := test.Run(subtest, vmconfig, true) + return st.checkFailure(t, name+"/snap", err) }) }) } diff --git a/tests/state_test_util.go b/tests/state_test_util.go index 217d519c89..39b8250648 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -30,6 +30,7 @@ import ( "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" + "github.com/tomochain/tomochain/core/state/snapshot" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/crypto/sha3" @@ -121,14 +122,14 @@ func (t *StateTest) Subtests() []StateSubtest { } // Run executes a specific subtest. -func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateDB, error) { +func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config, snapshotter bool) (*state.StateDB, error) { config, ok := Forks[subtest.Fork] if !ok { return nil, UnsupportedForkError{subtest.Fork} } block := t.genesis(config).ToBlock(nil) db := rawdb.NewMemoryDatabase() - statedb := MakePreState(db, t.json.Pre) + statedb := MakePreState(db, t.json.Pre, snapshotter) post := t.json.Post[subtest.Fork][subtest.Index] msg, err := t.json.Tx.toMessage(post) @@ -161,9 +162,9 @@ func (t *StateTest) gasLimit(subtest StateSubtest) uint64 { return t.json.Tx.GasLimit[t.json.Post[subtest.Fork][subtest.Index].Indexes.Gas] } -func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB { +func MakePreState(db ethdb.Database, accounts core.GenesisAlloc, snapshotter bool) *state.StateDB { sdb := state.NewDatabase(db) - statedb, _ := state.New(common.Hash{}, sdb) + statedb, _ := state.New(common.Hash{}, sdb, nil) for addr, a := range accounts { statedb.SetCode(addr, a.Code) statedb.SetNonce(addr, a.Nonce) @@ -174,7 +175,12 @@ func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB } // Commit and re-open to start with a clean state. 
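+	// When snapshotter is enabled, a snapshot tree is constructed over the
+	// committed root below and handed to state.New.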
root, _ := statedb.Commit(false) - statedb, _ = state.New(root, sdb) + + var snaps *snapshot.Tree + if snapshotter { + snaps = snapshot.New(db, sdb.TrieDB(), 1, root, false) + } + statedb, _ = state.New(root, sdb, snaps) return statedb } diff --git a/tests/vm_test.go b/tests/vm_test.go index 8377d4bc32..234d73620c 100644 --- a/tests/vm_test.go +++ b/tests/vm_test.go @@ -37,7 +37,10 @@ func TestVM(t *testing.T) { vmt.walk(t, vmTestDir, func(t *testing.T, name string, test *VMTest) { withTrace(t, test.json.Exec.GasLimit, func(vmconfig vm.Config) error { - return vmt.checkFailure(t, name, test.Run(vmconfig)) + return vmt.checkFailure(t, name+"/trie", test.Run(vmconfig, false)) + }) + withTrace(t, test.json.Exec.GasLimit, func(vmconfig vm.Config) error { + return vmt.checkFailure(t, name+"/snap", test.Run(vmconfig, true)) }) }) } diff --git a/tests/vm_test_util.go b/tests/vm_test_util.go index 01c471af27..c2a56d7796 100644 --- a/tests/vm_test_util.go +++ b/tests/vm_test_util.go @@ -20,9 +20,10 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" @@ -78,9 +79,9 @@ type vmExecMarshaling struct { GasPrice *math.HexOrDecimal256 } -func (t *VMTest) Run(vmconfig vm.Config) error { +func (t *VMTest) Run(vmconfig vm.Config, snapshotter bool) error { db := rawdb.NewMemoryDatabase() - statedb := MakePreState(db, t.json.Pre) + statedb := MakePreState(db, t.json.Pre, snapshotter) ret, gasRemaining, err := t.exec(statedb, vmconfig) if t.json.GasRemaining == nil { diff --git a/tomoxlending/lendingstate/lendingitem_test.go b/tomoxlending/lendingstate/lendingitem_test.go index b83c59ebee..564dffddf6 100644 --- a/tomoxlending/lendingstate/lendingitem_test.go +++ b/tomoxlending/lendingstate/lendingitem_test.go @@ -2,17 +2,18 @@ package lendingstate import ( "fmt" + "math/big" + "math/rand" + "os" + "testing" + "time" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/crypto/sha3" "github.com/tomochain/tomochain/rpc" - "math/big" - "math/rand" - "os" - "testing" - "time" ) func TestLendingItem_VerifyLendingSide(t *testing.T) { @@ -152,7 +153,7 @@ func SetCollateralDetail(statedb *state.StateDB, token common.Address, depositRa func TestVerifyBalance(t *testing.T) { db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) relayer := common.HexToAddress("0x0D3ab14BBaD3D99F4203bd7a11aCB94882050E7e") uAddr := common.HexToAddress("0xDeE6238780f98c0ca2c2C28453149bEA49a3Abc9") lendingToken := common.HexToAddress("0xd9bb01454c85247B2ef35BB5BE57384cC275a8cf") // USD From a3c9a579f6937a011bfdd792fefa93bc4bc798e2 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 17 Jul 2023 13:48:30 +0700 Subject: [PATCH 067/119] Fix after rebase --- core/blockchain_test.go | 2 - core/rawdb/database.go | 3 +- core/state/dump.go | 18 +- core/state/managed_state_test.go | 2 - core/state/state_test.go | 2 +- core/state/statedb.go | 382 ++++++++++++++++++------------- core/state/statedb_test.go | 14 +- core/tx_pool_test.go | 3 - les/odr_test.go | 2 - light/odr_test.go | 3 - 10 files changed, 236 insertions(+), 195 deletions(-) diff --git 
a/core/blockchain_test.go b/core/blockchain_test.go index 08ed2f3ba9..ac911ca3fd 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -24,8 +24,6 @@ import ( "testing" "time" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core/rawdb" diff --git a/core/rawdb/database.go b/core/rawdb/database.go index 0ebedd7f7f..ea1dfe2347 100644 --- a/core/rawdb/database.go +++ b/core/rawdb/database.go @@ -23,12 +23,11 @@ import ( "time" "github.com/olekukonko/tablewriter" - "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/ethdb/leveldb" "github.com/tomochain/tomochain/ethdb/memorydb" + "github.com/tomochain/tomochain/log" ) // freezerdb is a database wrapper that enabled freezer data retrievals. diff --git a/core/state/dump.go b/core/state/dump.go index 3ac545a9c7..6d8994462f 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -40,15 +40,15 @@ type Dump struct { Accounts map[string]DumpAccount `json:"accounts"` } -func (self *StateDB) RawDump() Dump { +func (s *StateDB) RawDump() Dump { dump := Dump{ - Root: fmt.Sprintf("%x", self.trie.Hash()), + Root: fmt.Sprintf("%x", s.trie.Hash()), Accounts: make(map[string]DumpAccount), } - it := trie.NewIterator(self.trie.NodeIterator(nil)) + it := trie.NewIterator(s.trie.NodeIterator(nil)) for it.Next() { - addr := self.trie.GetKey(it.Key) + addr := s.trie.GetKey(it.Key) var data types.StateAccount if err := rlp.DecodeBytes(it.Value, &data); err != nil { panic(err) @@ -60,20 +60,20 @@ func (self *StateDB) RawDump() Dump { Nonce: data.Nonce, Root: common.Bytes2Hex(data.Root[:]), CodeHash: common.Bytes2Hex(data.CodeHash), - Code: common.Bytes2Hex(obj.Code(self.db)), + Code: common.Bytes2Hex(obj.Code(s.db)), Storage: make(map[string]string), } - storageIt := trie.NewIterator(obj.getTrie(self.db).NodeIterator(nil)) + storageIt := trie.NewIterator(obj.getTrie(s.db).NodeIterator(nil)) for storageIt.Next() { - account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value) + account.Storage[common.Bytes2Hex(s.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value) } dump.Accounts[common.Bytes2Hex(addr)] = account } return dump } -func (self *StateDB) Dump() []byte { - json, err := json.MarshalIndent(self.RawDump(), "", " ") +func (s *StateDB) Dump() []byte { + json, err := json.MarshalIndent(s.RawDump(), "", " ") if err != nil { fmt.Println("dump err", err) } diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go index 1d19b087da..c4fa4937aa 100644 --- a/core/state/managed_state_test.go +++ b/core/state/managed_state_test.go @@ -19,8 +19,6 @@ package state import ( "testing" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/rawdb" ) diff --git a/core/state/state_test.go b/core/state/state_test.go index 2c40de412b..8ed63c3e34 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -92,7 +92,7 @@ func (s *StateSuite) TestDump(c *checker.C) { func (s *StateSuite) SetUpTest(c *checker.C) { s.db = rawdb.NewMemoryDatabase() tdb := NewDatabaseWithConfig(s.db, &trie.Config{Preimages: true}) - s.state, _ = New(common.Hash{}, tdb) + s.state, _ = New(common.Hash{}, tdb, nil) } func (s *StateSuite) TestNull(c *checker.C) { diff --git a/core/state/statedb.go 
b/core/state/statedb.go index 9552e68b51..58c4494cba 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -24,9 +24,8 @@ import ( "sync" "time" - "github.com/tomochain/tomochain/core/state/snapshot" - "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/state/snapshot" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/metrics" @@ -59,8 +58,11 @@ type StateDB struct { db Database trie Trie - snaps *snapshot.Tree // Nil if snapshot is not available - snap snapshot.Snapshot // Nil if snapshot is not available + snaps *snapshot.Tree + snap snapshot.Snapshot + snapDestructs map[common.Hash]struct{} + snapAccounts map[common.Hash][]byte + snapStorage map[common.Hash]map[common.Hash][]byte // This map holds 'live' objects, which will get modified while processing a state transition. stateObjects map[common.Address]*stateObject @@ -105,19 +107,19 @@ type StateDB struct { lock sync.Mutex } -func (self *StateDB) SubRefund(gas uint64) { - self.journal.append(refundChange{ - prev: self.refund}) - if gas > self.refund { - panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, self.refund)) +func (s *StateDB) SubRefund(gas uint64) { + s.journal.append(refundChange{ + prev: s.refund}) + if gas > s.refund { + panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, s.refund)) } - self.refund -= gas + s.refund -= gas } -func (self *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { + stateObject := s.getStateObject(addr) if stateObject != nil { - return stateObject.GetCommittedState(self.db, hash) + return stateObject.GetCommittedState(s.db, hash) } return common.Hash{} } @@ -141,107 +143,123 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) if sdb.snaps != nil { sdb.snap = sdb.snaps.Snapshot(root) } + if sdb.snaps != nil { + if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil { + sdb.snapDestructs = make(map[common.Hash]struct{}) + sdb.snapAccounts = make(map[common.Hash][]byte) + sdb.snapStorage = make(map[common.Hash]map[common.Hash][]byte) + } + } return sdb, nil } // setError remembers the first non-nil error it is called with. -func (self *StateDB) setError(err error) { - if self.dbErr == nil { - self.dbErr = err +func (s *StateDB) setError(err error) { + if s.dbErr == nil { + s.dbErr = err } } -func (self *StateDB) Error() error { - return self.dbErr +func (s *StateDB) Error() error { + return s.dbErr } // Reset clears out all ephemeral state objects from the state db, but keeps // the underlying state trie to avoid reloading data for the next operations. 
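A hedged usage sketch of the constructor wired up above: the snapshot tree is built once from the disk database and handed to state.New, which then probes it for a layer at the requested root. The snapshot.New call and its arguments mirror the MakePreState test helper updated earlier in this series; openWithSnapshot is an illustrative name only, not part of the patch.

	import (
		"github.com/tomochain/tomochain/common"
		"github.com/tomochain/tomochain/core/state"
		"github.com/tomochain/tomochain/core/state/snapshot"
		"github.com/tomochain/tomochain/ethdb"
	)

	// openWithSnapshot opens the state at root, reading through the snapshot
	// tree; passing nil instead of snaps falls back to plain trie access,
	// which is what the updated tests do with state.New(root, sdb, nil).
	func openWithSnapshot(db ethdb.Database, root common.Hash) (*state.StateDB, error) {
		sdb := state.NewDatabase(db)
		snaps := snapshot.New(db, sdb.TrieDB(), 1, root, false)
		return state.New(root, sdb, snaps)
	}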
-func (self *StateDB) Reset(root common.Hash) error { - tr, err := self.db.OpenTrie(root) +func (s *StateDB) Reset(root common.Hash) error { + tr, err := s.db.OpenTrie(root) if err != nil { return err } - self.trie = tr - self.stateObjects = make(map[common.Address]*stateObject) - self.stateObjectsDirty = make(map[common.Address]struct{}) - self.thash = common.Hash{} - self.bhash = common.Hash{} - self.txIndex = 0 - self.logs = make(map[common.Hash][]*types.Log) - self.logSize = 0 - self.preimages = make(map[common.Hash][]byte) - self.clearJournalAndRefund() + s.trie = tr + s.stateObjects = make(map[common.Address]*stateObject) + s.stateObjectsDirty = make(map[common.Address]struct{}) + s.thash = common.Hash{} + s.bhash = common.Hash{} + s.txIndex = 0 + s.logs = make(map[common.Hash][]*types.Log) + s.logSize = 0 + s.preimages = make(map[common.Hash][]byte) + s.clearJournalAndRefund() + + if s.snaps != nil { + s.snapAccounts, s.snapDestructs, s.snapStorage = nil, nil, nil + if s.snap = s.snaps.Snapshot(root); s.snap != nil { + s.snapDestructs = make(map[common.Hash]struct{}) + s.snapAccounts = make(map[common.Hash][]byte) + s.snapStorage = make(map[common.Hash]map[common.Hash][]byte) + } + } return nil } -func (self *StateDB) AddLog(log *types.Log) { - self.journal.append(addLogChange{txhash: self.thash}) +func (s *StateDB) AddLog(log *types.Log) { + s.journal.append(addLogChange{txhash: s.thash}) - log.TxHash = self.thash - log.BlockHash = self.bhash - log.TxIndex = uint(self.txIndex) - log.Index = self.logSize - self.logs[self.thash] = append(self.logs[self.thash], log) - self.logSize++ + log.TxHash = s.thash + log.BlockHash = s.bhash + log.TxIndex = uint(s.txIndex) + log.Index = s.logSize + s.logs[s.thash] = append(s.logs[s.thash], log) + s.logSize++ } -func (self *StateDB) GetLogs(hash common.Hash) []*types.Log { - return self.logs[hash] +func (s *StateDB) GetLogs(hash common.Hash) []*types.Log { + return s.logs[hash] } -func (self *StateDB) Logs() []*types.Log { +func (s *StateDB) Logs() []*types.Log { var logs []*types.Log - for _, lgs := range self.logs { + for _, lgs := range s.logs { logs = append(logs, lgs...) } return logs } // AddPreimage records a SHA3 preimage seen by the VM. -func (self *StateDB) AddPreimage(hash common.Hash, preimage []byte) { - if _, ok := self.preimages[hash]; !ok { - self.journal.append(addPreimageChange{hash: hash}) +func (s *StateDB) AddPreimage(hash common.Hash, preimage []byte) { + if _, ok := s.preimages[hash]; !ok { + s.journal.append(addPreimageChange{hash: hash}) pi := make([]byte, len(preimage)) copy(pi, preimage) - self.preimages[hash] = pi + s.preimages[hash] = pi } } // Preimages returns a list of SHA3 preimages that have been submitted. -func (self *StateDB) Preimages() map[common.Hash][]byte { - return self.preimages +func (s *StateDB) Preimages() map[common.Hash][]byte { + return s.preimages } -func (self *StateDB) AddRefund(gas uint64) { - self.journal.append(refundChange{prev: self.refund}) - self.refund += gas +func (s *StateDB) AddRefund(gas uint64) { + s.journal.append(refundChange{prev: s.refund}) + s.refund += gas } // Exist reports whether the given account address exists in the state. // Notably this also returns true for suicided accounts. 
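Exist and Empty below differ exactly by the EIP-161 emptiness rule; a minimal sketch of that rule, with isEmptyAccount as a hypothetical stand-in for stateObject.empty() and emptyCodeHash assumed to follow the usual convention of crypto.Keccak256(nil):

	import (
		"bytes"
		"math/big"

		"github.com/tomochain/tomochain/crypto"
	)

	var emptyCodeHash = crypto.Keccak256(nil)

	// isEmptyAccount reports whether an account is "empty" per EIP-161:
	// zero nonce, zero balance and no code.
	func isEmptyAccount(nonce uint64, balance *big.Int, codeHash []byte) bool {
		return nonce == 0 && balance.Sign() == 0 && bytes.Equal(codeHash, emptyCodeHash)
	}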
-func (self *StateDB) Exist(addr common.Address) bool { - return self.getStateObject(addr) != nil +func (s *StateDB) Exist(addr common.Address) bool { + return s.getStateObject(addr) != nil } // Empty returns whether the state object is either non-existent // or empty according to the EIP161 specification (balance = nonce = code = 0) -func (self *StateDB) Empty(addr common.Address) bool { - so := self.getStateObject(addr) +func (s *StateDB) Empty(addr common.Address) bool { + so := s.getStateObject(addr) return so == nil || so.empty() } // Retrieve the balance from the given address or 0 if object not found -func (self *StateDB) GetBalance(addr common.Address) *big.Int { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetBalance(addr common.Address) *big.Int { + stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.Balance() } return common.Big0 } -func (self *StateDB) GetNonce(addr common.Address) uint64 { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetNonce(addr common.Address) uint64 { + stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.Nonce() } @@ -249,63 +267,63 @@ func (self *StateDB) GetNonce(addr common.Address) uint64 { return 0 } -func (self *StateDB) GetCode(addr common.Address) []byte { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetCode(addr common.Address) []byte { + stateObject := s.getStateObject(addr) if stateObject != nil { - return stateObject.Code(self.db) + return stateObject.Code(s.db) } return nil } -func (self *StateDB) GetCodeSize(addr common.Address) int { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetCodeSize(addr common.Address) int { + stateObject := s.getStateObject(addr) if stateObject == nil { return 0 } if stateObject.code != nil { return len(stateObject.code) } - size, err := self.db.ContractCodeSize(stateObject.addrHash, common.BytesToHash(stateObject.CodeHash())) + size, err := s.db.ContractCodeSize(stateObject.addrHash, common.BytesToHash(stateObject.CodeHash())) if err != nil { - self.setError(err) + s.setError(err) } return size } -func (self *StateDB) GetCodeHash(addr common.Address) common.Hash { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetCodeHash(addr common.Address) common.Hash { + stateObject := s.getStateObject(addr) if stateObject == nil { return common.Hash{} } return common.BytesToHash(stateObject.CodeHash()) } -func (self *StateDB) GetState(addr common.Address, bhash common.Hash) common.Hash { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetState(addr common.Address, bhash common.Hash) common.Hash { + stateObject := s.getStateObject(addr) if stateObject != nil { - return stateObject.GetState(self.db, bhash) + return stateObject.GetState(s.db, bhash) } return common.Hash{} } // Database retrieves the low level database supporting the lower level trie ops. -func (self *StateDB) Database() Database { - return self.db +func (s *StateDB) Database() Database { + return s.db } // StorageTrie returns the storage trie of an account. // The return value is a copy and is nil for non-existent accounts. 
-func (self *StateDB) StorageTrie(addr common.Address) Trie { - stateObject := self.getStateObject(addr) +func (s *StateDB) StorageTrie(addr common.Address) Trie { + stateObject := s.getStateObject(addr) if stateObject == nil { return nil } - cpy := stateObject.deepCopy(self) - return cpy.updateTrie(self.db) + cpy := stateObject.deepCopy(s) + return cpy.updateTrie(s.db) } -func (self *StateDB) HasSuicided(addr common.Address) bool { - stateObject := self.getStateObject(addr) +func (s *StateDB) HasSuicided(addr common.Address) bool { + stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.suicided } @@ -317,46 +335,46 @@ func (self *StateDB) HasSuicided(addr common.Address) bool { */ // AddBalance adds amount to the account associated with addr. -func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) { - stateObject := self.GetOrNewStateObject(addr) +func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) { + stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { stateObject.AddBalance(amount) } } // SubBalance subtracts amount from the account associated with addr. -func (self *StateDB) SubBalance(addr common.Address, amount *big.Int) { - stateObject := self.GetOrNewStateObject(addr) +func (s *StateDB) SubBalance(addr common.Address, amount *big.Int) { + stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { stateObject.SubBalance(amount) } } -func (self *StateDB) SetBalance(addr common.Address, amount *big.Int) { - stateObject := self.GetOrNewStateObject(addr) +func (s *StateDB) SetBalance(addr common.Address, amount *big.Int) { + stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { stateObject.SetBalance(amount) } } -func (self *StateDB) SetNonce(addr common.Address, nonce uint64) { - stateObject := self.GetOrNewStateObject(addr) +func (s *StateDB) SetNonce(addr common.Address, nonce uint64) { + stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { stateObject.SetNonce(nonce) } } -func (self *StateDB) SetCode(addr common.Address, code []byte) { - stateObject := self.GetOrNewStateObject(addr) +func (s *StateDB) SetCode(addr common.Address, code []byte) { + stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { stateObject.SetCode(crypto.Keccak256Hash(code), code) } } -func (self *StateDB) SetState(addr common.Address, key, value common.Hash) { - stateObject := self.GetOrNewStateObject(addr) +func (s *StateDB) SetState(addr common.Address, key, value common.Hash) { + stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { - stateObject.SetState(self.db, key, value) + stateObject.SetState(s.db, key, value) } } @@ -365,12 +383,12 @@ func (self *StateDB) SetState(addr common.Address, key, value common.Hash) { // // The account's state object is still available until the state is committed, // getStateObject will return a non-nil account after Suicide. -func (self *StateDB) Suicide(addr common.Address) bool { - stateObject := self.getStateObject(addr) +func (s *StateDB) Suicide(addr common.Address) bool { + stateObject := s.getStateObject(addr) if stateObject == nil { return false } - self.journal.append(suicideChange{ + s.journal.append(suicideChange{ account: &addr, prev: stateObject.suicided, prevbalance: new(big.Int).Set(stateObject.Balance()), @@ -386,34 +404,43 @@ func (self *StateDB) Suicide(addr common.Address) bool { // // updateStateObject writes the given object to the trie. 
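The snapshot caching added to updateStateObject below is one half of an asymmetric scheme, as its comment explains; a hedged sketch of both halves, where recordSnapUpdate and recordSnapDestruct are illustrative names and snapDestructs is assumed to be filled in when an account is destructed:

	// Updates only need to be mirrored once, when the object is written back:
	func (s *StateDB) recordSnapUpdate(obj *stateObject) {
		if s.snap != nil {
			s.snapAccounts[obj.addrHash] = snapshot.AccountRLP(obj.data.Nonce, obj.data.Balance, obj.data.Root, obj.data.CodeHash)
		}
	}

	// Destructions must be recorded the moment they happen, otherwise an
	// account deleted and recreated within the same block would never be
	// cleared from the snapshot:
	func (s *StateDB) recordSnapDestruct(obj *stateObject) {
		if s.snap != nil {
			s.snapDestructs[obj.addrHash] = struct{}{}
		}
	}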
-func (self *StateDB) updateStateObject(stateObject *stateObject) { +func (s *StateDB) updateStateObject(stateObject *stateObject) { addr := stateObject.Address() - if err := self.trie.UpdateAccount(addr, &stateObject.data); err != nil { - self.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err)) + if err := s.trie.UpdateAccount(addr, &stateObject.data); err != nil { + s.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err)) + } + + // If state snapshotting is active, cache the data til commit. Note, this + // update mechanism is not symmetric to the deletion, because whereas it is + // enough to track account updates at commit time, deletions need tracking + // at transaction boundary level to ensure we capture state clearing. + if s.snap != nil { + s.snapAccounts[stateObject.addrHash] = snapshot.AccountRLP(stateObject.data.Nonce, stateObject.data.Balance, stateObject.data.Root, stateObject.data.CodeHash) } + } // deleteStateObject removes the given object from the state trie. -func (self *StateDB) deleteStateObject(stateObject *stateObject) { +func (s *StateDB) deleteStateObject(stateObject *stateObject) { stateObject.deleted = true addr := stateObject.Address() - if err := self.trie.DeleteAccount(addr); err != nil { - self.setError(fmt.Errorf("deleteStateObject (%x) error: %v", addr[:], err)) + if err := s.trie.DeleteAccount(addr); err != nil { + s.setError(fmt.Errorf("deleteStateObject (%x) error: %v", addr[:], err)) } } // DeleteAddress removes the address from the state trie. -func (self *StateDB) DeleteAddress(addr common.Address) { - stateObject := self.getStateObject(addr) +func (s *StateDB) DeleteAddress(addr common.Address) { + stateObject := s.getStateObject(addr) if stateObject != nil && !stateObject.deleted { - self.deleteStateObject(stateObject) + s.deleteStateObject(stateObject) } } // Retrieve a state object given my the address. Returns nil if not found. -func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObject) { +func (s *StateDB) getStateObject(addr common.Address) (stateObject *stateObject) { // Prefer 'live' objects. - if obj := self.stateObjects[addr]; obj != nil { + if obj := s.stateObjects[addr]; obj != nil { if obj.deleted { return nil } @@ -421,17 +448,17 @@ func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObje } // Load the object from the database. - data, err := self.trie.GetAccount(addr) + data, err := s.trie.GetAccount(addr) if err != nil { - self.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", addr.Bytes(), err)) + s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", addr.Bytes(), err)) return nil } if data == nil { return nil } // Insert into the live set. 
-		obj := newObject(self, addr, data)
-		self.setStateObject(obj)
+		obj := newObject(s, addr, data)
+		s.setStateObject(obj)
 		return obj
 	}
 
@@ -444,13 +471,40 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject {
 	if obj := s.stateObjects[addr]; obj != nil {
 		return obj
 	}
-	if metrics.EnabledExpensive {
-		defer func(start time.Time) { s.AccountReads += time.Since(start) }(time.Now())
+	// If no live objects are available, attempt to use snapshots
+	var (
+		data *types.StateAccount
+		err  error
+	)
+	if s.snap != nil {
+		if metrics.EnabledExpensive {
+			defer func(start time.Time) { s.SnapshotAccountReads += time.Since(start) }(time.Now())
+		}
+		var acc *snapshot.Account
+		if acc, err = s.snap.Account(crypto.Keccak256Hash(addr[:])); err == nil {
+			if acc == nil {
+				return nil
+			}
+			data = &types.StateAccount{Nonce: acc.Nonce, Balance: acc.Balance, CodeHash: acc.CodeHash}
+			if len(data.CodeHash) == 0 {
+				data.CodeHash = emptyCodeHash
+			}
+			data.Root = common.BytesToHash(acc.Root)
+			if data.Root == (common.Hash{}) {
+				data.Root = emptyRoot
+			}
+		}
 	}
-	data, err := s.trie.GetAccount(addr)
-	if err != nil {
-		s.setError(err)
-		return nil
+	// If snapshot unavailable or reading from it failed, load from the database
+	if s.snap == nil || err != nil {
+		if metrics.EnabledExpensive {
+			defer func(start time.Time) { s.AccountReads += time.Since(start) }(time.Now())
+		}
+		data, err = s.trie.GetAccount(addr)
+		if err != nil {
+			s.setError(err)
+			return nil
+		}
 	}
 	// Insert into the live set
 	obj := newObject(s, addr, data)
@@ -458,31 +512,31 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject {
 
-func (self *StateDB) setStateObject(object *stateObject) {
-	self.stateObjects[object.Address()] = object
+func (s *StateDB) setStateObject(object *stateObject) {
+	s.stateObjects[object.Address()] = object
 }
 
 // Retrieve a state object or create a new state object if nil.
-func (self *StateDB) GetOrNewStateObject(addr common.Address) *stateObject {
-	stateObject := self.getStateObject(addr)
+func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject {
+	stateObject := s.getStateObject(addr)
 	if stateObject == nil || stateObject.deleted {
-		stateObject, _ = self.createObject(addr)
+		stateObject, _ = s.createObject(addr)
 	}
 	return stateObject
 }
 
 // createObject creates a new state object. If there is an existing account with
 // the given address, it is overwritten and returned as the second return value.
-func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) {
-	prev = self.getStateObject(addr)
-	newobj = newObject(self, addr, &types.StateAccount{})
+func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) {
+	prev = s.getStateObject(addr)
+	newobj = newObject(s, addr, &types.StateAccount{})
 	newobj.setNonce(0) // sets the object to dirty
 	if prev == nil {
-		self.journal.append(createObjectChange{account: &addr})
+		s.journal.append(createObjectChange{account: &addr})
 	} else {
-		self.journal.append(resetObjectChange{prev: prev})
+		s.journal.append(resetObjectChange{prev: prev})
 	}
-	self.setStateObject(newobj)
+	s.setStateObject(newobj)
 	return newobj, prev
 }
 
@@ -496,15 +550,15 @@ func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObjec
 // 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1)
 //
 // Carrying over the balance ensures that Ether doesn't disappear.
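A small worked example of the carry-over rule stated above, using the public API from this file; demoCarryOver, the address and the amount are purely illustrative:

	import (
		"math/big"

		"github.com/tomochain/tomochain/common"
		"github.com/tomochain/tomochain/core/state"
	)

	func demoCarryOver(statedb *state.StateDB) {
		addr := common.HexToAddress("0x01")
		statedb.AddBalance(addr, big.NewInt(5))
		statedb.CreateAccount(addr)  // nonce, code and storage are reset...
		_ = statedb.GetBalance(addr) // ...but the 5 wei are carried over
	}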
-func (self *StateDB) CreateAccount(addr common.Address) { - new, prev := self.createObject(addr) +func (s *StateDB) CreateAccount(addr common.Address) { + new, prev := s.createObject(addr) if prev != nil { new.setBalance(prev.data.Balance) } } -func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error { - so := db.getStateObject(addr) +func (s *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error { + so := s.getStateObject(addr) if so == nil { return nil } @@ -514,10 +568,10 @@ func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common cb(h, value) } - it := trie.NewIterator(so.getTrie(db.db).NodeIterator(nil)) + it := trie.NewIterator(so.getTrie(s.db).NodeIterator(nil)) for it.Next() { // ignore cached values - key := common.BytesToHash(db.trie.GetKey(it.Key)) + key := common.BytesToHash(s.trie.GetKey(it.Key)) if _, ok := so.cachedStorage[key]; !ok { cb(key, common.BytesToHash(it.Value)) } @@ -527,64 +581,64 @@ func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common // Copy creates a deep, independent copy of the state. // Snapshots of the copied state cannot be applied to the copy. -func (self *StateDB) Copy() *StateDB { - self.lock.Lock() - defer self.lock.Unlock() +func (s *StateDB) Copy() *StateDB { + s.lock.Lock() + defer s.lock.Unlock() // Copy all the basic fields, initialize the memory ones state := &StateDB{ - db: self.db, - trie: self.db.CopyTrie(self.trie), - stateObjects: make(map[common.Address]*stateObject, len(self.journal.dirties)), - stateObjectsDirty: make(map[common.Address]struct{}, len(self.journal.dirties)), - refund: self.refund, - logs: make(map[common.Hash][]*types.Log, len(self.logs)), - logSize: self.logSize, + db: s.db, + trie: s.db.CopyTrie(s.trie), + stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)), + stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)), + refund: s.refund, + logs: make(map[common.Hash][]*types.Log, len(s.logs)), + logSize: s.logSize, preimages: make(map[common.Hash][]byte), journal: newJournal(), } // Copy the dirty states, logs, and preimages - for addr := range self.journal.dirties { - state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state) + for addr := range s.journal.dirties { + state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) state.stateObjectsDirty[addr] = struct{}{} } - for hash, logs := range self.logs { + for hash, logs := range s.logs { state.logs[hash] = make([]*types.Log, len(logs)) copy(state.logs[hash], logs) } - for hash, preimage := range self.preimages { + for hash, preimage := range s.preimages { state.preimages[hash] = preimage } return state } // Snapshot returns an identifier for the current revision of the state. -func (self *StateDB) Snapshot() int { - id := self.nextRevisionId - self.nextRevisionId++ - self.validRevisions = append(self.validRevisions, revision{id, self.journal.length()}) +func (s *StateDB) Snapshot() int { + id := s.nextRevisionId + s.nextRevisionId++ + s.validRevisions = append(s.validRevisions, revision{id, s.journal.length()}) return id } // RevertToSnapshot reverts all state changes made since the given revision. -func (self *StateDB) RevertToSnapshot(revid int) { +func (s *StateDB) RevertToSnapshot(revid int) { // Find the snapshot in the stack of valid snapshots. 
- idx := sort.Search(len(self.validRevisions), func(i int) bool { - return self.validRevisions[i].id >= revid + idx := sort.Search(len(s.validRevisions), func(i int) bool { + return s.validRevisions[i].id >= revid }) - if idx == len(self.validRevisions) || self.validRevisions[idx].id != revid { + if idx == len(s.validRevisions) || s.validRevisions[idx].id != revid { panic(fmt.Errorf("revision id %v cannot be reverted", revid)) } - snapshot := self.validRevisions[idx].journalIndex + snapshot := s.validRevisions[idx].journalIndex // Replay the journal to undo changes and remove invalidated snapshots - self.journal.revert(self, snapshot) - self.validRevisions = self.validRevisions[:idx] + s.journal.revert(s, snapshot) + s.validRevisions = s.validRevisions[:idx] } // GetRefund returns the current value of the refund counter. -func (self *StateDB) GetRefund() uint64 { - return self.refund +func (s *StateDB) GetRefund() uint64 { + return s.refund } // Finalise finalises the state by removing the self destructed objects @@ -618,10 +672,10 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash { // Prepare sets the current transaction hash and index and block hash which is // used when the EVM emits new state logs. -func (self *StateDB) Prepare(thash, bhash common.Hash, ti int) { - self.thash = thash - self.bhash = bhash - self.txIndex = ti +func (s *StateDB) Prepare(thash, bhash common.Hash, ti int) { + s.thash = thash + s.bhash = bhash + s.txIndex = ti } // DeleteSuicides flags the suicided objects for deletion so that it diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index b87bc6685a..d69f5cb7c3 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -41,7 +41,7 @@ import ( func TestUpdateLeaks(t *testing.T) { // Create an empty state database db := rawdb.NewMemoryDatabase() - state, _ := New(common.Hash{}, NewDatabase(db)) + state, _ := New(common.Hash{}, NewDatabase(db), nil) // Update it with some accounts for i := byte(0); i < 255; i++ { @@ -71,8 +71,8 @@ func TestIntermediateLeaks(t *testing.T) { // Create two state databases, one transitioning to the final state, the other final from the beginning transDb := rawdb.NewMemoryDatabase() finalDb := rawdb.NewMemoryDatabase() - transState, _ := New(common.Hash{}, NewDatabase(transDb)) - finalState, _ := New(common.Hash{}, NewDatabase(finalDb)) + transState, _ := New(common.Hash{}, NewDatabase(transDb), nil) + finalState, _ := New(common.Hash{}, NewDatabase(finalDb), nil) modify := func(state *StateDB, addr common.Address, i, tweak byte) { state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) @@ -130,7 +130,7 @@ func TestIntermediateLeaks(t *testing.T) { func TestCopy(t *testing.T) { // Create a random state test to copy and modify "independently" db := rawdb.NewMemoryDatabase() - orig, _ := New(common.Hash{}, NewDatabase(db)) + orig, _ := New(common.Hash{}, NewDatabase(db), nil) for i := byte(0); i < 255; i++ { obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) @@ -342,7 +342,7 @@ func (test *snapshotTest) run() bool { // Run all actions and create snapshots. var ( db = rawdb.NewMemoryDatabase() - state, _ = New(common.Hash{}, NewDatabase(db)) + state, _ = New(common.Hash{}, NewDatabase(db), nil) snapshotRevs = make([]int, len(test.snapshots)) sindex = 0 ) @@ -356,7 +356,7 @@ func (test *snapshotTest) run() bool { // Revert all snapshots in reverse order. 
Each revert must yield a state // that is equivalent to fresh state with all actions up the snapshot applied. for sindex--; sindex >= 0; sindex-- { - checkstate, _ := New(common.Hash{}, state.Database()) + checkstate, _ := New(common.Hash{}, state.Database(), nil) for _, action := range test.actions[:test.snapshots[sindex]] { action.fn(action, checkstate) } @@ -416,7 +416,7 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { func (s *StateSuite) TestTouchDelete(c *check.C) { s.state.GetOrNewStateObject(common.Address{}) root, _ := s.state.Commit(false) - s.state, _ = New(root, s.state.db) + s.state, _ = New(root, s.state.db, nil) snapshot := s.state.Snapshot() s.state.AddBalance(common.Address{}, new(big.Int)) diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 058968d24e..adfd813147 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -26,9 +26,6 @@ import ( "testing" "time" - "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core/rawdb" diff --git a/les/odr_test.go b/les/odr_test.go index 302a294259..a1f004e619 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -23,8 +23,6 @@ import ( "testing" "time" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/core" diff --git a/light/odr_test.go b/light/odr_test.go index 1debcd4c1e..81ae37b95b 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -24,9 +24,6 @@ import ( "testing" "time" - "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus" From 2ad92f6c5ae6a5ecb1bb49bb571358e7ac57af07 Mon Sep 17 00:00:00 2001 From: Enda Dinh <90235926+endadinh@users.noreply.github.com> Date: Mon, 17 Jul 2023 14:36:31 +0700 Subject: [PATCH 068/119] fix unit tests --- core/state/managed_state_test.go | 1 - core/state/snapshot/disklayer_test.go | 435 ++++++++++++ .../{iterator_binary => iterator_binary.go} | 0 core/state/snapshot/iterator_test.go | 658 ++++++++++++++++++ core/state/state_object.go | 2 +- core/state/statedb_test.go | 2 +- 6 files changed, 1095 insertions(+), 3 deletions(-) create mode 100644 core/state/snapshot/disklayer_test.go rename core/state/snapshot/{iterator_binary => iterator_binary.go} (100%) create mode 100644 core/state/snapshot/iterator_test.go diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go index c4fa4937aa..9df24323f5 100644 --- a/core/state/managed_state_test.go +++ b/core/state/managed_state_test.go @@ -20,7 +20,6 @@ import ( "testing" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core/rawdb" ) var addr = common.BytesToAddress([]byte("test")) diff --git a/core/state/snapshot/disklayer_test.go b/core/state/snapshot/disklayer_test.go new file mode 100644 index 0000000000..652e531b25 --- /dev/null +++ b/core/state/snapshot/disklayer_test.go @@ -0,0 +1,435 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "testing" + + "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/ethdb/memorydb" +) + +// reverse reverses the contents of a byte slice. It's used to update random accs +// with deterministic changes. +func reverse(blob []byte) []byte { + res := make([]byte, len(blob)) + for i, b := range blob { + res[len(blob)-1-i] = b + } + return res +} + +// Tests that merging something into a disk layer persists it into the database +// and invalidates any previously written and cached values. +func TestDiskMerge(t *testing.T) { + // Create some accounts in the disk layer + db := memorydb.New() + + var ( + accNoModNoCache = common.Hash{0x1} + accNoModCache = common.Hash{0x2} + accModNoCache = common.Hash{0x3} + accModCache = common.Hash{0x4} + accDelNoCache = common.Hash{0x5} + accDelCache = common.Hash{0x6} + conNoModNoCache = common.Hash{0x7} + conNoModNoCacheSlot = common.Hash{0x70} + conNoModCache = common.Hash{0x8} + conNoModCacheSlot = common.Hash{0x80} + conModNoCache = common.Hash{0x9} + conModNoCacheSlot = common.Hash{0x90} + conModCache = common.Hash{0xa} + conModCacheSlot = common.Hash{0xa0} + conDelNoCache = common.Hash{0xb} + conDelNoCacheSlot = common.Hash{0xb0} + conDelCache = common.Hash{0xc} + conDelCacheSlot = common.Hash{0xc0} + conNukeNoCache = common.Hash{0xd} + conNukeNoCacheSlot = common.Hash{0xd0} + conNukeCache = common.Hash{0xe} + conNukeCacheSlot = common.Hash{0xe0} + baseRoot = randomHash() + diffRoot = randomHash() + ) + + rawdb.WriteAccountSnapshot(db, accNoModNoCache, accNoModNoCache[:]) + rawdb.WriteAccountSnapshot(db, accNoModCache, accNoModCache[:]) + rawdb.WriteAccountSnapshot(db, accModNoCache, accModNoCache[:]) + rawdb.WriteAccountSnapshot(db, accModCache, accModCache[:]) + rawdb.WriteAccountSnapshot(db, accDelNoCache, accDelNoCache[:]) + rawdb.WriteAccountSnapshot(db, accDelCache, accDelCache[:]) + + rawdb.WriteAccountSnapshot(db, conNoModNoCache, conNoModNoCache[:]) + rawdb.WriteStorageSnapshot(db, conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:]) + rawdb.WriteAccountSnapshot(db, conNoModCache, conNoModCache[:]) + rawdb.WriteStorageSnapshot(db, conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:]) + rawdb.WriteAccountSnapshot(db, conModNoCache, conModNoCache[:]) + rawdb.WriteStorageSnapshot(db, conModNoCache, conModNoCacheSlot, conModNoCacheSlot[:]) + rawdb.WriteAccountSnapshot(db, conModCache, conModCache[:]) + rawdb.WriteStorageSnapshot(db, conModCache, conModCacheSlot, conModCacheSlot[:]) + rawdb.WriteAccountSnapshot(db, conDelNoCache, conDelNoCache[:]) + rawdb.WriteStorageSnapshot(db, conDelNoCache, conDelNoCacheSlot, conDelNoCacheSlot[:]) + rawdb.WriteAccountSnapshot(db, conDelCache, conDelCache[:]) + 
rawdb.WriteStorageSnapshot(db, conDelCache, conDelCacheSlot, conDelCacheSlot[:])
+
+	rawdb.WriteAccountSnapshot(db, conNukeNoCache, conNukeNoCache[:])
+	rawdb.WriteStorageSnapshot(db, conNukeNoCache, conNukeNoCacheSlot, conNukeNoCacheSlot[:])
+	rawdb.WriteAccountSnapshot(db, conNukeCache, conNukeCache[:])
+	rawdb.WriteStorageSnapshot(db, conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:])
+
+	rawdb.WriteSnapshotRoot(db, baseRoot)
+
+	// Create a disk layer based on the above and cache in some data
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			baseRoot: &diskLayer{
+				diskdb: db,
+				cache:  fastcache.New(500 * 1024),
+				root:   baseRoot,
+			},
+		},
+	}
+	base := snaps.Snapshot(baseRoot)
+	base.AccountRLP(accNoModCache)
+	base.AccountRLP(accModCache)
+	base.AccountRLP(accDelCache)
+	base.Storage(conNoModCache, conNoModCacheSlot)
+	base.Storage(conModCache, conModCacheSlot)
+	base.Storage(conDelCache, conDelCacheSlot)
+	base.Storage(conNukeCache, conNukeCacheSlot)
+
+	// Modify or delete some accounts, flatten everything onto disk
+	if err := snaps.Update(diffRoot, baseRoot, map[common.Hash]struct{}{
+		accDelNoCache:  struct{}{},
+		accDelCache:    struct{}{},
+		conNukeNoCache: struct{}{},
+		conNukeCache:   struct{}{},
+	}, map[common.Hash][]byte{
+		accModNoCache: reverse(accModNoCache[:]),
+		accModCache:   reverse(accModCache[:]),
+	}, map[common.Hash]map[common.Hash][]byte{
+		conModNoCache: {conModNoCacheSlot: reverse(conModNoCacheSlot[:])},
+		conModCache:   {conModCacheSlot: reverse(conModCacheSlot[:])},
+		conDelNoCache: {conDelNoCacheSlot: nil},
+		conDelCache:   {conDelCacheSlot: nil},
+	}); err != nil {
+		t.Fatalf("failed to update snapshot tree: %v", err)
+	}
+	if err := snaps.Cap(diffRoot, 0); err != nil {
+		t.Fatalf("failed to flatten snapshot tree: %v", err)
+	}
+	// Retrieve all the data through the disk layer and validate it
+	base = snaps.Snapshot(diffRoot)
+	if _, ok := base.(*diskLayer); !ok {
+		t.Fatalf("update not flattened into the disk layer")
+	}
+
+	// assertAccount ensures that an account matches the given blob.
+	assertAccount := func(account common.Hash, data []byte) {
+		t.Helper()
+		blob, err := base.AccountRLP(account)
+		if err != nil {
+			t.Errorf("account access (%x) failed: %v", account, err)
+		} else if !bytes.Equal(blob, data) {
+			t.Errorf("account access (%x) mismatch: have %x, want %x", account, blob, data)
+		}
+	}
+	assertAccount(accNoModNoCache, accNoModNoCache[:])
+	assertAccount(accNoModCache, accNoModCache[:])
+	assertAccount(accModNoCache, reverse(accModNoCache[:]))
+	assertAccount(accModCache, reverse(accModCache[:]))
+	assertAccount(accDelNoCache, nil)
+	assertAccount(accDelCache, nil)
+
+	// assertStorage ensures that a storage slot matches the given blob.
+ assertStorage := func(account common.Hash, slot common.Hash, data []byte) { + t.Helper() + blob, err := base.Storage(account, slot) + if err != nil { + t.Errorf("storage access (%x:%x) failed: %v", account, slot, err) + } else if !bytes.Equal(blob, data) { + t.Errorf("storage access (%x:%x) mismatch: have %x, want %x", account, slot, blob, data) + } + } + assertStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:]) + assertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:]) + assertStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:])) + assertStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:])) + assertStorage(conDelNoCache, conDelNoCacheSlot, nil) + assertStorage(conDelCache, conDelCacheSlot, nil) + assertStorage(conNukeNoCache, conNukeNoCacheSlot, nil) + assertStorage(conNukeCache, conNukeCacheSlot, nil) + + // Retrieve all the data directly from the database and validate it + + // assertDatabaseAccount ensures that an account from the database matches the given blob. + assertDatabaseAccount := func(account common.Hash, data []byte) { + t.Helper() + if blob := rawdb.ReadAccountSnapshot(db, account); !bytes.Equal(blob, data) { + t.Errorf("account database access (%x) mismatch: have %x, want %x", account, blob, data) + } + } + assertDatabaseAccount(accNoModNoCache, accNoModNoCache[:]) + assertDatabaseAccount(accNoModCache, accNoModCache[:]) + assertDatabaseAccount(accModNoCache, reverse(accModNoCache[:])) + assertDatabaseAccount(accModCache, reverse(accModCache[:])) + assertDatabaseAccount(accDelNoCache, nil) + assertDatabaseAccount(accDelCache, nil) + + // assertDatabaseStorage ensures that a storage slot from the database matches the given blob. + assertDatabaseStorage := func(account common.Hash, slot common.Hash, data []byte) { + t.Helper() + if blob := rawdb.ReadStorageSnapshot(db, account, slot); !bytes.Equal(blob, data) { + t.Errorf("storage database access (%x:%x) mismatch: have %x, want %x", account, slot, blob, data) + } + } + assertDatabaseStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:]) + assertDatabaseStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:]) + assertDatabaseStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:])) + assertDatabaseStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:])) + assertDatabaseStorage(conDelNoCache, conDelNoCacheSlot, nil) + assertDatabaseStorage(conDelCache, conDelCacheSlot, nil) + assertDatabaseStorage(conNukeNoCache, conNukeNoCacheSlot, nil) + assertDatabaseStorage(conNukeCache, conNukeCacheSlot, nil) +} + +// Tests that merging something into a disk layer persists it into the database +// and invalidates any previously written and cached values, discarding anything +// after the in-progress generation marker. +func TestDiskPartialMerge(t *testing.T) { + // Iterate the test a few times to ensure we pick various internal orderings + // for the data slots as well as the progress marker. 
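+	// A minimal sketch of the coverage rule this test exercises (covered is
+	// an assumed helper, not part of the patch): a key is served by the disk
+	// layer only if it does not sort after the in-progress generator marker,
+	//
+	//	covered := func(key, genMarker []byte) bool {
+	//		return bytes.Compare(key, genMarker) <= 0
+	//	}
+	//
+	// while uncovered keys must instead fail with ErrNotCoveredYet.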
+	for i := 0; i < 1024; i++ {
+		// Create some accounts in the disk layer
+		db := memorydb.New()
+
+		var (
+			accNoModNoCache     = randomHash()
+			accNoModCache       = randomHash()
+			accModNoCache       = randomHash()
+			accModCache         = randomHash()
+			accDelNoCache       = randomHash()
+			accDelCache         = randomHash()
+			conNoModNoCache     = randomHash()
+			conNoModNoCacheSlot = randomHash()
+			conNoModCache       = randomHash()
+			conNoModCacheSlot   = randomHash()
+			conModNoCache       = randomHash()
+			conModNoCacheSlot   = randomHash()
+			conModCache         = randomHash()
+			conModCacheSlot     = randomHash()
+			conDelNoCache       = randomHash()
+			conDelNoCacheSlot   = randomHash()
+			conDelCache         = randomHash()
+			conDelCacheSlot     = randomHash()
+			conNukeNoCache      = randomHash()
+			conNukeNoCacheSlot  = randomHash()
+			conNukeCache        = randomHash()
+			conNukeCacheSlot    = randomHash()
+			baseRoot            = randomHash()
+			diffRoot            = randomHash()
+			genMarker           = append(randomHash().Bytes(), randomHash().Bytes()...)
+		)
+
+		// insertAccount injects an account into the database if it's at or
+		// before the generator marker, and drops the op otherwise. This is
+		// needed to seed the database with a valid starting snapshot.
+		insertAccount := func(account common.Hash, data []byte) {
+			if bytes.Compare(account[:], genMarker) <= 0 {
+				rawdb.WriteAccountSnapshot(db, account, data[:])
+			}
+		}
+		insertAccount(accNoModNoCache, accNoModNoCache[:])
+		insertAccount(accNoModCache, accNoModCache[:])
+		insertAccount(accModNoCache, accModNoCache[:])
+		insertAccount(accModCache, accModCache[:])
+		insertAccount(accDelNoCache, accDelNoCache[:])
+		insertAccount(accDelCache, accDelCache[:])
+
+		// insertStorage injects a storage slot into the database if it's at
+		// or before the generator marker, and drops the op otherwise. This is
+		// needed to seed the database with a valid starting snapshot.
+		insertStorage := func(account common.Hash, slot common.Hash, data []byte) {
+			if bytes.Compare(append(account[:], slot[:]...), genMarker) <= 0 {
+				rawdb.WriteStorageSnapshot(db, account, slot, data[:])
+			}
+		}
+		insertAccount(conNoModNoCache, conNoModNoCache[:])
+		insertStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+		insertAccount(conNoModCache, conNoModCache[:])
+		insertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+		insertAccount(conModNoCache, conModNoCache[:])
+		insertStorage(conModNoCache, conModNoCacheSlot, conModNoCacheSlot[:])
+		insertAccount(conModCache, conModCache[:])
+		insertStorage(conModCache, conModCacheSlot, conModCacheSlot[:])
+		insertAccount(conDelNoCache, conDelNoCache[:])
+		insertStorage(conDelNoCache, conDelNoCacheSlot, conDelNoCacheSlot[:])
+		insertAccount(conDelCache, conDelCache[:])
+		insertStorage(conDelCache, conDelCacheSlot, conDelCacheSlot[:])
+
+		insertAccount(conNukeNoCache, conNukeNoCache[:])
+		insertStorage(conNukeNoCache, conNukeNoCacheSlot, conNukeNoCacheSlot[:])
+		insertAccount(conNukeCache, conNukeCache[:])
+		insertStorage(conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:])
+
+		rawdb.WriteSnapshotRoot(db, baseRoot)
+
+		// Create a disk layer based on the above using a random progress marker
+		// and cache in some data.
+		snaps := &Tree{
+			layers: map[common.Hash]snapshot{
+				baseRoot: &diskLayer{
+					diskdb: db,
+					cache:  fastcache.New(500 * 1024),
+					root:   baseRoot,
+				},
+			},
+		}
+		snaps.layers[baseRoot].(*diskLayer).genMarker = genMarker
+		base := snaps.Snapshot(baseRoot)
+
+		// assertAccount ensures that an account matches the given blob if it's
+		// already covered by the disk snapshot, and errors out otherwise.
+		assertAccount := func(account common.Hash, data []byte) {
+			t.Helper()
+			blob, err := base.AccountRLP(account)
+			if bytes.Compare(account[:], genMarker) > 0 && err != ErrNotCoveredYet {
+				t.Fatalf("test %d: post-marker (%x) account access (%x) succeeded: %x", i, genMarker, account, blob)
+			}
+			if bytes.Compare(account[:], genMarker) <= 0 && !bytes.Equal(blob, data) {
+				t.Fatalf("test %d: pre-marker (%x) account access (%x) mismatch: have %x, want %x", i, genMarker, account, blob, data)
+			}
+		}
+		assertAccount(accNoModCache, accNoModCache[:])
+		assertAccount(accModCache, accModCache[:])
+		assertAccount(accDelCache, accDelCache[:])
+
+		// assertStorage ensures that a storage slot matches the given blob if
+		// it's already covered by the disk snapshot, and errors out otherwise.
+		assertStorage := func(account common.Hash, slot common.Hash, data []byte) {
+			t.Helper()
+			blob, err := base.Storage(account, slot)
+			if bytes.Compare(append(account[:], slot[:]...), genMarker) > 0 && err != ErrNotCoveredYet {
+				t.Fatalf("test %d: post-marker (%x) storage access (%x:%x) succeeded: %x", i, genMarker, account, slot, blob)
+			}
+			if bytes.Compare(append(account[:], slot[:]...), genMarker) <= 0 && !bytes.Equal(blob, data) {
+				t.Fatalf("test %d: pre-marker (%x) storage access (%x:%x) mismatch: have %x, want %x", i, genMarker, account, slot, blob, data)
+			}
+		}
+		assertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+		assertStorage(conModCache, conModCacheSlot, conModCacheSlot[:])
+		assertStorage(conDelCache, conDelCacheSlot, conDelCacheSlot[:])
+		assertStorage(conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:])
+
+		// Modify or delete some accounts, flatten everything onto disk
+		if err := snaps.Update(diffRoot, baseRoot, map[common.Hash]struct{}{
+			accDelNoCache:  struct{}{},
+			accDelCache:    struct{}{},
+			conNukeNoCache: struct{}{},
+			conNukeCache:   struct{}{},
+		}, map[common.Hash][]byte{
+			accModNoCache: reverse(accModNoCache[:]),
+			accModCache:   reverse(accModCache[:]),
+		}, map[common.Hash]map[common.Hash][]byte{
+			conModNoCache: {conModNoCacheSlot: reverse(conModNoCacheSlot[:])},
+			conModCache:   {conModCacheSlot: reverse(conModCacheSlot[:])},
+			conDelNoCache: {conDelNoCacheSlot: nil},
+			conDelCache:   {conDelCacheSlot: nil},
+		}); err != nil {
+			t.Fatalf("test %d: failed to update snapshot tree: %v", i, err)
+		}
+		if err := snaps.Cap(diffRoot, 0); err != nil {
+			t.Fatalf("test %d: failed to flatten snapshot tree: %v", i, err)
+		}
+		// Retrieve all the data through the disk layer and validate it
+		base = snaps.Snapshot(diffRoot)
+		if _, ok := base.(*diskLayer); !ok {
+			t.Fatalf("test %d: update not flattened into the disk layer", i)
+		}
+		assertAccount(accNoModNoCache, accNoModNoCache[:])
+		assertAccount(accNoModCache, accNoModCache[:])
+		assertAccount(accModNoCache, reverse(accModNoCache[:]))
+		assertAccount(accModCache, reverse(accModCache[:]))
+		assertAccount(accDelNoCache, nil)
+		assertAccount(accDelCache, nil)
+
+		assertStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+		assertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+		assertStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:]))
+		assertStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:]))
+		assertStorage(conDelNoCache, conDelNoCacheSlot, nil)
+		assertStorage(conDelCache, conDelCacheSlot, nil)
+		assertStorage(conNukeNoCache, conNukeNoCacheSlot, nil)
+		assertStorage(conNukeCache, conNukeCacheSlot, nil)
+
+		// Retrieve all the data directly from the database and
validate it + + // assertDatabaseAccount ensures that an account inside the database matches + // the given blob if it's already covered by the disk snapshot, and does not + // exist otherwise. + assertDatabaseAccount := func(account common.Hash, data []byte) { + t.Helper() + blob := rawdb.ReadAccountSnapshot(db, account) + if bytes.Compare(account[:], genMarker) > 0 && blob != nil { + t.Fatalf("test %d: post-marker (%x) account database access (%x) succeeded: %x", i, genMarker, account, blob) + } + if bytes.Compare(account[:], genMarker) <= 0 && !bytes.Equal(blob, data) { + t.Fatalf("test %d: pre-marker (%x) account database access (%x) mismatch: have %x, want %x", i, genMarker, account, blob, data) + } + } + assertDatabaseAccount(accNoModNoCache, accNoModNoCache[:]) + assertDatabaseAccount(accNoModCache, accNoModCache[:]) + assertDatabaseAccount(accModNoCache, reverse(accModNoCache[:])) + assertDatabaseAccount(accModCache, reverse(accModCache[:])) + assertDatabaseAccount(accDelNoCache, nil) + assertDatabaseAccount(accDelCache, nil) + + // assertDatabaseStorage ensures that a storage slot inside the database + // matches the given blob if it's already covered by the disk snapshot, + // and does not exist otherwise. + assertDatabaseStorage := func(account common.Hash, slot common.Hash, data []byte) { + t.Helper() + blob := rawdb.ReadStorageSnapshot(db, account, slot) + if bytes.Compare(append(account[:], slot[:]...), genMarker) > 0 && blob != nil { + t.Fatalf("test %d: post-marker (%x) storage database access (%x:%x) succeeded: %x", i, genMarker, account, slot, blob) + } + if bytes.Compare(append(account[:], slot[:]...), genMarker) <= 0 && !bytes.Equal(blob, data) { + t.Fatalf("test %d: pre-marker (%x) storage database access (%x:%x) mismatch: have %x, want %x", i, genMarker, account, slot, blob, data) + } + } + assertDatabaseStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:]) + assertDatabaseStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:]) + assertDatabaseStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:])) + assertDatabaseStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:])) + assertDatabaseStorage(conDelNoCache, conDelNoCacheSlot, nil) + assertDatabaseStorage(conDelCache, conDelCacheSlot, nil) + assertDatabaseStorage(conNukeNoCache, conNukeNoCacheSlot, nil) + assertDatabaseStorage(conNukeCache, conNukeCacheSlot, nil) + } +} + +// Tests that merging something into a disk layer persists it into the database +// and invalidates any previously written and cached values, discarding anything +// after the in-progress generation marker. +// +// This test case is a tiny specialized case of TestDiskPartialMerge, which tests +// some very specific cornercases that random tests won't ever trigger. +func TestDiskMidAccountPartialMerge(t *testing.T) { +} diff --git a/core/state/snapshot/iterator_binary b/core/state/snapshot/iterator_binary.go similarity index 100% rename from core/state/snapshot/iterator_binary rename to core/state/snapshot/iterator_binary.go diff --git a/core/state/snapshot/iterator_test.go b/core/state/snapshot/iterator_test.go new file mode 100644 index 0000000000..613bd9955d --- /dev/null +++ b/core/state/snapshot/iterator_test.go @@ -0,0 +1,658 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "encoding/binary" + "fmt" + "math/rand" + "testing" + + "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" +) + +// TestAccountIteratorBasics tests some simple single-layer iteration +func TestAccountIteratorBasics(t *testing.T) { + var ( + destructs = make(map[common.Hash]struct{}) + accounts = make(map[common.Hash][]byte) + storage = make(map[common.Hash]map[common.Hash][]byte) + ) + // Fill up a parent + for i := 0; i < 100; i++ { + h := randomHash() + data := randomAccount() + + accounts[h] = data + if rand.Intn(4) == 0 { + destructs[h] = struct{}{} + } + if rand.Intn(2) == 0 { + accStorage := make(map[common.Hash][]byte) + value := make([]byte, 32) + rand.Read(value) + accStorage[randomHash()] = value + storage[h] = accStorage + } + } + // Add some (identical) layers on top + parent := newDiffLayer(emptyLayer(), common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage)) + it := parent.AccountIterator(common.Hash{}) + verifyIterator(t, 100, it) +} + +type testIterator struct { + values []byte +} + +func newTestIterator(values ...byte) *testIterator { + return &testIterator{values} +} + +func (ti *testIterator) Seek(common.Hash) { + panic("implement me") +} + +func (ti *testIterator) Next() bool { + ti.values = ti.values[1:] + return len(ti.values) > 0 +} + +func (ti *testIterator) Error() error { + return nil +} + +func (ti *testIterator) Hash() common.Hash { + return common.BytesToHash([]byte{ti.values[0]}) +} + +func (ti *testIterator) Account() []byte { + return nil +} + +func (ti *testIterator) Release() {} + +func TestFastIteratorBasics(t *testing.T) { + type testCase struct { + lists [][]byte + expKeys []byte + } + for i, tc := range []testCase{ + {lists: [][]byte{{0, 1, 8}, {1, 2, 8}, {2, 9}, {4}, + {7, 14, 15}, {9, 13, 15, 16}}, + expKeys: []byte{0, 1, 2, 4, 7, 8, 9, 13, 14, 15, 16}}, + {lists: [][]byte{{0, 8}, {1, 2, 8}, {7, 14, 15}, {8, 9}, + {9, 10}, {10, 13, 15, 16}}, + expKeys: []byte{0, 1, 2, 7, 8, 9, 10, 13, 14, 15, 16}}, + } { + var iterators []*weightedAccountIterator + for i, data := range tc.lists { + it := newTestIterator(data...) 
+ iterators = append(iterators, &weightedAccountIterator{it, i}) + + } + fi := &fastAccountIterator{ + iterators: iterators, + initiated: false, + } + count := 0 + for fi.Next() { + if got, exp := fi.Hash()[31], tc.expKeys[count]; exp != got { + t.Errorf("tc %d, [%d]: got %d exp %d", i, count, got, exp) + } + count++ + } + } +} + +func verifyIterator(t *testing.T, expCount int, it AccountIterator) { + t.Helper() + + var ( + count = 0 + last = common.Hash{} + ) + for it.Next() { + hash := it.Hash() + if bytes.Compare(last[:], hash[:]) >= 0 { + t.Errorf("wrong order: %x >= %x", last, hash) + } + if it.Account() == nil { + t.Errorf("iterator returned nil-value for hash %x", hash) + } + count++ + } + if count != expCount { + t.Errorf("iterator count mismatch: have %d, want %d", count, expCount) + } + if err := it.Error(); err != nil { + t.Errorf("iterator failed: %v", err) + } +} + +// TestAccountIteratorTraversal tests some simple multi-layer iteration. +func TestAccountIteratorTraversal(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Stack three diff layers on top with various overlaps + snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, + randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil) + + snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, + randomAccountSet("0xbb", "0xdd", "0xf0"), nil) + + snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, + randomAccountSet("0xcc", "0xf0", "0xff"), nil) + + // Verify the single and multi-layer iterators + head := snaps.Snapshot(common.HexToHash("0x04")) + + verifyIterator(t, 3, head.(snapshot).AccountIterator(common.Hash{})) + verifyIterator(t, 7, head.(*diffLayer).newBinaryAccountIterator()) + + it, _ := snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{}) + defer it.Release() + + verifyIterator(t, 7, it) +} + +// TestAccountIteratorTraversalValues tests some multi-layer iteration, where we +// also expect the correct values to show up. 
+func TestAccountIteratorTraversalValues(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Create a batch of account sets to seed subsequent layers with + var ( + a = make(map[common.Hash][]byte) + b = make(map[common.Hash][]byte) + c = make(map[common.Hash][]byte) + d = make(map[common.Hash][]byte) + e = make(map[common.Hash][]byte) + f = make(map[common.Hash][]byte) + g = make(map[common.Hash][]byte) + h = make(map[common.Hash][]byte) + ) + for i := byte(2); i < 0xff; i++ { + a[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 0, i)) + if i > 20 && i%2 == 0 { + b[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 1, i)) + } + if i%4 == 0 { + c[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 2, i)) + } + if i%7 == 0 { + d[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 3, i)) + } + if i%8 == 0 { + e[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 4, i)) + } + if i > 50 || i < 85 { + f[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 5, i)) + } + if i%64 == 0 { + g[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 6, i)) + } + if i%128 == 0 { + h[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 7, i)) + } + } + // Assemble a stack of snapshots from the account layers + snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, a, nil) + snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, b, nil) + snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, c, nil) + snaps.Update(common.HexToHash("0x05"), common.HexToHash("0x04"), nil, d, nil) + snaps.Update(common.HexToHash("0x06"), common.HexToHash("0x05"), nil, e, nil) + snaps.Update(common.HexToHash("0x07"), common.HexToHash("0x06"), nil, f, nil) + snaps.Update(common.HexToHash("0x08"), common.HexToHash("0x07"), nil, g, nil) + snaps.Update(common.HexToHash("0x09"), common.HexToHash("0x08"), nil, h, nil) + + it, _ := snaps.AccountIterator(common.HexToHash("0x09"), common.Hash{}) + defer it.Release() + + head := snaps.Snapshot(common.HexToHash("0x09")) + for it.Next() { + hash := it.Hash() + want, err := head.AccountRLP(hash) + if err != nil { + t.Fatalf("failed to retrieve expected account: %v", err) + } + if have := it.Account(); !bytes.Equal(want, have) { + t.Fatalf("hash %x: account mismatch: have %x, want %x", hash, have, want) + } + } +} + +// This testcase is notorious, all layers contain the exact same 200 accounts. 
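Before the large-traversal test below, a hedged sketch of the property it (and TestFastIteratorBasics above) checks: iterating a layer stack must visit the union of all keys exactly once, in ascending order, no matter how many layers repeat them. mergeLayers models only this observable behaviour, not the actual heap-based fast iterator:

	package main

	import (
		"fmt"
		"sort"
	)

	// mergeLayers unions single-byte keys across layers (shallowest first)
	// and reports each key once, sorted ascending.
	func mergeLayers(layers ...[]byte) []byte {
		seen := make(map[byte]bool)
		var out []byte
		for _, layer := range layers {
			for _, k := range layer {
				if !seen[k] {
					seen[k] = true
					out = append(out, k)
				}
			}
		}
		sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
		return out
	}

	func main() {
		// Mirrors the first expectation in TestFastIteratorBasics:
		fmt.Println(mergeLayers([]byte{0, 1, 8}, []byte{1, 2, 8}, []byte{2, 9}, []byte{4}, []byte{7, 14, 15}, []byte{9, 13, 15, 16}))
		// Output: [0 1 2 4 7 8 9 13 14 15 16]
	}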
+func TestAccountIteratorLargeTraversal(t *testing.T) {
+	// Create a custom account factory to recreate the same addresses
+	makeAccounts := func(num int) map[common.Hash][]byte {
+		accounts := make(map[common.Hash][]byte)
+		for i := 0; i < num; i++ {
+			h := common.Hash{}
+			binary.BigEndian.PutUint64(h[:], uint64(i+1))
+			accounts[h] = randomAccount()
+		}
+		return accounts
+	}
+	// Build up a large stack of snapshots
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	for i := 1; i < 128; i++ {
+		snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(200), nil)
+	}
+	// Iterate the entire stack and ensure everything is hit only once
+	head := snaps.Snapshot(common.HexToHash("0x80"))
+	verifyIterator(t, 200, head.(snapshot).AccountIterator(common.Hash{}))
+	verifyIterator(t, 200, head.(*diffLayer).newBinaryAccountIterator())
+
+	it, _ := snaps.AccountIterator(common.HexToHash("0x80"), common.Hash{})
+	defer it.Release()
+
+	verifyIterator(t, 200, it)
+}
+
+// TestAccountIteratorFlattening tests what happens when we
+// - have a live iterator on child C (parent C1 -> C2 .. CN)
+// - flatten C2 all the way into CN
+// - continue iterating
+func TestAccountIteratorFlattening(t *testing.T) {
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	// Create a stack of diffs on top
+	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
+		randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil)
+
+	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
+		randomAccountSet("0xbb", "0xdd", "0xf0"), nil)
+
+	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
+		randomAccountSet("0xcc", "0xf0", "0xff"), nil)
+
+	// Create an iterator and flatten the data from underneath it
+	it, _ := snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{})
+	defer it.Release()
+
+	if err := snaps.Cap(common.HexToHash("0x04"), 1); err != nil {
+		t.Fatalf("failed to flatten snapshot stack: %v", err)
+	}
+	//verifyIterator(t, 7, it)
+}
+
+func TestAccountIteratorSeek(t *testing.T) {
+	// Create a snapshot stack with some initial data
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
+		randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil)
+
+	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
+		randomAccountSet("0xbb", "0xdd", "0xf0"), nil)
+
+	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
+		randomAccountSet("0xcc", "0xf0", "0xff"), nil)
+
+	// Construct various iterators and ensure their traversal is correct
+	it, _ := snaps.AccountIterator(common.HexToHash("0x02"), common.HexToHash("0xdd"))
+	defer it.Release()
+	verifyIterator(t, 3, it) // expected: ee, f0, ff
+
+	it, _ = snaps.AccountIterator(common.HexToHash("0x02"), common.HexToHash("0xaa"))
+	defer it.Release()
+	verifyIterator(t, 3, it) // expected: ee, f0, ff
+
+	it, _ = 
snaps.AccountIterator(common.HexToHash("0x02"), common.HexToHash("0xff")) + defer it.Release() + verifyIterator(t, 0, it) // expected: nothing + + it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xbb")) + defer it.Release() + verifyIterator(t, 5, it) // expected: cc, dd, ee, f0, ff + + it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xef")) + defer it.Release() + verifyIterator(t, 2, it) // expected: f0, ff + + it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xf0")) + defer it.Release() + verifyIterator(t, 1, it) // expected: ff + + it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xff")) + defer it.Release() + verifyIterator(t, 0, it) // expected: nothing +} + +// TestIteratorDeletions tests that the iterator behaves correct when there are +// deleted accounts (where the Account() value is nil). The iterator +// should not output any accounts or nil-values for those cases. +func TestIteratorDeletions(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Stack three diff layers on top with various overlaps + snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), + nil, randomAccountSet("0x11", "0x22", "0x33"), nil) + + deleted := common.HexToHash("0x22") + destructed := map[common.Hash]struct{}{ + deleted: struct{}{}, + } + snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), + destructed, randomAccountSet("0x11", "0x33"), nil) + + snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), + nil, randomAccountSet("0x33", "0x44", "0x55"), nil) + + // The output should be 11,33,44,55 + it, _ := snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{}) + // Do a quick check + verifyIterator(t, 4, it) + it.Release() + + // And a more detailed verification that we indeed do not see '0x22' + it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{}) + defer it.Release() + for it.Next() { + hash := it.Hash() + if it.Account() == nil { + t.Errorf("iterator returned nil-value for hash %x", hash) + } + if hash == deleted { + t.Errorf("expected deleted elem %x to not be returned by iterator", deleted) + } + } +} + +// BenchmarkAccountIteratorTraversal is a bit a bit notorious -- all layers contain the +// exact same 200 accounts. That means that we need to process 2000 items, but +// only spit out 200 values eventually. 
+
+// BenchmarkAccountIteratorTraversal is a bit notorious -- all layers contain the
+// exact same 200 accounts. That means that we need to process 2000 items, but
+// only spit out 200 values eventually.
+//
+// The value-fetching benchmark is easy on the binary iterator, since it never has to reach
+// down at any depth for retrieving the values -- all are on the topmost layer
+//
+// BenchmarkAccountIteratorTraversal/binary_iterator_keys-6    2239  483674 ns/op
+// BenchmarkAccountIteratorTraversal/binary_iterator_values-6  2403  501810 ns/op
+// BenchmarkAccountIteratorTraversal/fast_iterator_keys-6      1923  677966 ns/op
+// BenchmarkAccountIteratorTraversal/fast_iterator_values-6    1741  649967 ns/op
+func BenchmarkAccountIteratorTraversal(b *testing.B) {
+	// Create a custom account factory to recreate the same addresses
+	makeAccounts := func(num int) map[common.Hash][]byte {
+		accounts := make(map[common.Hash][]byte)
+		for i := 0; i < num; i++ {
+			h := common.Hash{}
+			binary.BigEndian.PutUint64(h[:], uint64(i+1))
+			accounts[h] = randomAccount()
+		}
+		return accounts
+	}
+	// Build up a large stack of snapshots
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	for i := 1; i <= 100; i++ {
+		snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(200), nil)
+	}
+	// We call this once before the benchmark, so the creation of
+	// the sorted account lists is not included in the results.
+	head := snaps.Snapshot(common.HexToHash("0x65"))
+	head.(*diffLayer).newBinaryAccountIterator()
+
+	b.Run("binary iterator keys", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			got := 0
+			it := head.(*diffLayer).newBinaryAccountIterator()
+			for it.Next() {
+				got++
+			}
+			if exp := 200; got != exp {
+				b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+			}
+		}
+	})
+	b.Run("binary iterator values", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			got := 0
+			it := head.(*diffLayer).newBinaryAccountIterator()
+			for it.Next() {
+				got++
+				head.(*diffLayer).accountRLP(it.Hash(), 0)
+			}
+			if exp := 200; got != exp {
+				b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+			}
+		}
+	})
+	b.Run("fast iterator keys", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{})
+			defer it.Release()
+
+			got := 0
+			for it.Next() {
+				got++
+			}
+			if exp := 200; got != exp {
+				b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+			}
+		}
+	})
+	b.Run("fast iterator values", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{})
+			defer it.Release()
+
+			got := 0
+			for it.Next() {
+				got++
+				it.Account()
+			}
+			if exp := 200; got != exp {
+				b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+			}
+		}
+	})
+}
+
+// BenchmarkAccountIteratorLargeBaselayer is a pretty realistic benchmark, where
+// the baselayer is a lot larger than the upper layer.
+//
+// This is heavy on the binary iterator, which in most cases will have to
+// recurse through all 100 layers for the majority of the values
+//
+// BenchmarkAccountIteratorLargeBaselayer/binary_iterator_(keys)-6      514   1971999 ns/op
+// BenchmarkAccountIteratorLargeBaselayer/binary_iterator_(values)-6     61  18997492 ns/op
+// BenchmarkAccountIteratorLargeBaselayer/fast_iterator_(keys)-6      10000    114385 ns/op
+// BenchmarkAccountIteratorLargeBaselayer/fast_iterator_(values)-6     4047    296823 ns/op
+func BenchmarkAccountIteratorLargeBaselayer(b *testing.B) {
+	// Create a custom account factory to recreate the same addresses
+	makeAccounts := func(num int) map[common.Hash][]byte {
+		accounts := make(map[common.Hash][]byte)
+		for i := 0; i < num; i++ {
+			h := common.Hash{}
+			binary.BigEndian.PutUint64(h[:], uint64(i+1))
+			accounts[h] = randomAccount()
+		}
+		return accounts
+	}
+	// Build up a large stack of snapshots
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, makeAccounts(2000), nil)
+	for i := 2; i <= 100; i++ {
+		snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(20), nil)
+	}
+	// We call this once before the benchmark, so the creation of
+	// the sorted account lists is not included in the results.
+	head := snaps.Snapshot(common.HexToHash("0x65"))
+	head.(*diffLayer).newBinaryAccountIterator()
+
+	b.Run("binary iterator (keys)", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			got := 0
+			it := head.(*diffLayer).newBinaryAccountIterator()
+			for it.Next() {
+				got++
+			}
+			if exp := 2000; got != exp {
+				b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+			}
+		}
+	})
+	b.Run("binary iterator (values)", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			got := 0
+			it := head.(*diffLayer).newBinaryAccountIterator()
+			for it.Next() {
+				got++
+				v := it.Hash()
+				head.(*diffLayer).accountRLP(v, 0)
+			}
+			if exp := 2000; got != exp {
+				b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+			}
+		}
+	})
+	b.Run("fast iterator (keys)", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{})
+			defer it.Release()
+
+			got := 0
+			for it.Next() {
+				got++
+			}
+			if exp := 2000; got != exp {
+				b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+			}
+		}
+	})
+	b.Run("fast iterator (values)", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{})
+			defer it.Release()
+
+			got := 0
+			for it.Next() {
+				it.Account()
+				got++
+			}
+			if exp := 2000; got != exp {
+				b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+			}
+		}
+	})
+}
+
+/*
+func BenchmarkBinaryAccountIteration(b *testing.B) {
+	benchmarkAccountIteration(b, func(snap snapshot) AccountIterator {
+		return snap.(*diffLayer).newBinaryAccountIterator()
+	})
+}
+func BenchmarkFastAccountIteration(b *testing.B) {
+	benchmarkAccountIteration(b, newFastAccountIterator)
+}
+func benchmarkAccountIteration(b *testing.B, iterator func(snap snapshot) AccountIterator) {
+	// Create a diff stack and randomize the accounts across them
+	layers := make([]map[common.Hash][]byte, 128)
+	for i := 0; i < len(layers); i++ {
+		layers[i] = make(map[common.Hash][]byte)
+	}
+	for i := 0; i < b.N; i++ {
+		depth := 
rand.Intn(len(layers)) + layers[depth][randomHash()] = randomAccount() + } + stack := snapshot(emptyLayer()) + for _, layer := range layers { + stack = stack.Update(common.Hash{}, layer, nil, nil) + } + // Reset the timers and report all the stats + it := iterator(stack) + b.ResetTimer() + b.ReportAllocs() + for it.Next() { + } +} +*/ diff --git a/core/state/state_object.go b/core/state/state_object.go index 477bc02da2..0aa936fc9c 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -329,7 +329,7 @@ func (self *stateObject) setCode(codeHash common.Hash, code []byte) { func (self *stateObject) SetNonce(nonce uint64) { self.db.journal.append(nonceChange{ - account: &self.address, + account: &self.address, prev: self.data.Nonce, }) self.setNonce(nonce) diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index d69f5cb7c3..085d8f7c27 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -416,7 +416,7 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { func (s *StateSuite) TestTouchDelete(c *check.C) { s.state.GetOrNewStateObject(common.Address{}) root, _ := s.state.Commit(false) - s.state, _ = New(root, s.state.db, nil) + s.state, _ = New(root, s.state.db, s.state.snaps) snapshot := s.state.Snapshot() s.state.AddBalance(common.Address{}, new(big.Int)) From a0d94748e555f0143943743d0c052259087a7b1d Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 17 Jul 2023 15:04:31 +0700 Subject: [PATCH 069/119] Remove duplicate accessor methods --- core/blockchain.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/blockchain.go b/core/blockchain.go index e6afc1f78f..49ee8d6f9b 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -28,17 +28,16 @@ import ( "sync/atomic" "time" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/tomoxlending/lendingstate" + lru "github.com/hashicorp/golang-lru" "gopkg.in/karalabe/cookiejar.v2/collections/prque" - lru "github.com/hashicorp/golang-lru" "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/mclock" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/posv" contractValidator "github.com/tomochain/tomochain/contracts/validator/contract" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/state/snapshot" "github.com/tomochain/tomochain/core/types" @@ -52,6 +51,7 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending/lendingstate" "github.com/tomochain/tomochain/trie" ) From 995f139e050ead0d06def782755b04be74ad8ff0 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 7 Aug 2023 15:08:28 +0700 Subject: [PATCH 070/119] Record prevdestruct in journal --- core/state/journal.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/state/journal.go b/core/state/journal.go index cbb443706d..2c75c9dbef 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -90,7 +90,8 @@ type ( account *common.Address } resetObjectChange struct { - prev *stateObject + prev *stateObject + prevdestruct bool } suicideChange struct { account *common.Address @@ -144,6 +145,9 @@ func (ch createObjectChange) dirtied() *common.Address { func (ch resetObjectChange) revert(s *StateDB) { 
s.setStateObject(ch.prev) + if !ch.prevdestruct && s.snap != nil { + delete(s.snapDestructs, ch.prev.addrHash) + } } func (ch resetObjectChange) dirtied() *common.Address { From 8d083251c018d29fdce2045e49d0271a244b7507 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 7 Aug 2023 15:09:42 +0700 Subject: [PATCH 071/119] Use types.SlimAccount instead of snapshot.Account --- core/state/snapshot/account.go | 54 ---------------------------- core/state/snapshot/difflayer.go | 6 ++-- core/state/snapshot/disklayer.go | 6 ++-- core/state/snapshot/generate.go | 4 ++- core/state/snapshot/snapshot.go | 3 +- core/state/snapshot/snapshot_test.go | 4 ++- 6 files changed, 16 insertions(+), 61 deletions(-) delete mode 100644 core/state/snapshot/account.go diff --git a/core/state/snapshot/account.go b/core/state/snapshot/account.go deleted file mode 100644 index 3177f25d92..0000000000 --- a/core/state/snapshot/account.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2019 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package snapshot - -import ( - "bytes" - "math/big" - - "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/rlp" -) - -// Account is a slim version of a state.Account, where the root and code hash -// are replaced with a nil byte slice for empty accounts. -type Account struct { - Nonce uint64 - Balance *big.Int - Root []byte - CodeHash []byte -} - -// AccountRLP converts a state.Account content into a slim snapshot version RLP -// encoded. -func AccountRLP(nonce uint64, balance *big.Int, root common.Hash, codehash []byte) []byte { - slim := Account{ - Nonce: nonce, - Balance: balance, - } - if root != emptyRoot { - slim.Root = root[:] - } - if !bytes.Equal(codehash, emptyCode[:]) { - slim.CodeHash = codehash - } - data, err := rlp.EncodeToBytes(slim) - if err != nil { - panic(err) - } - return data -} diff --git a/core/state/snapshot/difflayer.go b/core/state/snapshot/difflayer.go index 1a75761bf8..98214497c8 100644 --- a/core/state/snapshot/difflayer.go +++ b/core/state/snapshot/difflayer.go @@ -27,7 +27,9 @@ import ( "time" bloomfilter "github.com/holiman/bloomfilter/v2" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/rlp" ) @@ -270,7 +272,7 @@ func (dl *diffLayer) Stale() bool { // Account directly retrieves the account associated with a particular hash in // the snapshot slim data format. 
-func (dl *diffLayer) Account(hash common.Hash) (*Account, error) { +func (dl *diffLayer) Account(hash common.Hash) (*types.SlimAccount, error) { data, err := dl.AccountRLP(hash) if err != nil { return nil, err @@ -278,7 +280,7 @@ func (dl *diffLayer) Account(hash common.Hash) (*Account, error) { if len(data) == 0 { // can be both nil and []byte{} return nil, nil } - account := new(Account) + account := new(types.SlimAccount) if err := rlp.DecodeBytes(data, account); err != nil { panic(err) } diff --git a/core/state/snapshot/disklayer.go b/core/state/snapshot/disklayer.go index 4fa43660c7..febb3e6753 100644 --- a/core/state/snapshot/disklayer.go +++ b/core/state/snapshot/disklayer.go @@ -21,8 +21,10 @@ import ( "sync" "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" @@ -65,7 +67,7 @@ func (dl *diskLayer) Stale() bool { // Account directly retrieves the account associated with a particular hash in // the snapshot slim data format. -func (dl *diskLayer) Account(hash common.Hash) (*Account, error) { +func (dl *diskLayer) Account(hash common.Hash) (*types.SlimAccount, error) { data, err := dl.AccountRLP(hash) if err != nil { return nil, err @@ -73,7 +75,7 @@ func (dl *diskLayer) Account(hash common.Hash) (*Account, error) { if len(data) == 0 { // can be both nil and []byte{} return nil, nil } - account := new(Account) + account := new(types.SlimAccount) if err := rlp.DecodeBytes(data, account); err != nil { panic(err) } diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go index a38eeab75d..ea5b59a72f 100644 --- a/core/state/snapshot/generate.go +++ b/core/state/snapshot/generate.go @@ -24,9 +24,11 @@ import ( "time" "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" @@ -189,7 +191,7 @@ func (dl *diskLayer) generate(stats *generatorStats) { if err := rlp.DecodeBytes(accIt.Value, &acc); err != nil { log.Crit("Invalid account encountered during snapshot creation", "err", err) } - data := AccountRLP(acc.Nonce, acc.Balance, acc.Root, acc.CodeHash) + data := types.SlimAccountRLP(acc) // If the account is not yet in-progress, write it out if accMarker == nil || !bytes.Equal(accountHash[:], accMarker) { diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go index 34d7c77177..82cc1addce 100644 --- a/core/state/snapshot/snapshot.go +++ b/core/state/snapshot/snapshot.go @@ -26,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/metrics" @@ -98,7 +99,7 @@ type Snapshot interface { // Account directly retrieves the account associated with a particular hash in // the snapshot slim data format. - Account(hash common.Hash) (*Account, error) + Account(hash common.Hash) (*types.SlimAccount, error) // AccountRLP directly retrieves the account RLP associated with a particular // hash in the snapshot slim data format. 
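The slim encoding that types.SlimAccount takes over from the deleted snapshot.Account keeps the convention visible above: the storage root and code hash are replaced by nil byte slices when they equal the empty-trie root and empty-code hash, shrinking the RLP blob for the common case. A standalone sketch of that conversion (the struct below is illustrative; its field shapes are assumed to match the deleted snapshot.Account):

package main

import (
	"bytes"
	"fmt"
	"math/big"

	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/rlp"
)

var (
	// Known hashes of an empty trie root and of empty contract code.
	emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
	emptyCode = common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470")
)

// slimAccount mirrors the slim convention: Root and CodeHash collapse
// to nil for empty accounts, so their RLP becomes an empty string.
type slimAccount struct {
	Nonce    uint64
	Balance  *big.Int
	Root     []byte
	CodeHash []byte
}

func slimRLP(nonce uint64, balance *big.Int, root common.Hash, codeHash []byte) []byte {
	slim := slimAccount{Nonce: nonce, Balance: balance}
	if root != emptyRoot {
		slim.Root = root[:]
	}
	if !bytes.Equal(codeHash, emptyCode[:]) {
		slim.CodeHash = codeHash
	}
	data, err := rlp.EncodeToBytes(slim)
	if err != nil {
		panic(err) // cannot fail for this struct shape
	}
	return data
}

func main() {
	// For an empty account both hash fields collapse, keeping the blob tiny.
	fmt.Printf("%x\n", slimRLP(1, big.NewInt(10), emptyRoot, emptyCode[:]))
}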
diff --git a/core/state/snapshot/snapshot_test.go b/core/state/snapshot/snapshot_test.go index 75e53186b8..35fe62c839 100644 --- a/core/state/snapshot/snapshot_test.go +++ b/core/state/snapshot/snapshot_test.go @@ -23,8 +23,10 @@ import ( "testing" "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/rlp" ) @@ -40,7 +42,7 @@ func randomHash() common.Hash { // randomAccount generates a random account and returns it RLP encoded. func randomAccount() []byte { root := randomHash() - a := Account{ + a := types.SlimAccount{ Balance: big.NewInt(rand.Int63()), Nonce: rand.Uint64(), Root: root[:], From 781ac07d676a4d93bd897536976fe06467c61d9f Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 7 Aug 2023 15:12:06 +0700 Subject: [PATCH 072/119] Rename receivers and implement pendingStorage --- core/state/state_object.go | 362 +++++++++++++++++++++++-------------- 1 file changed, 228 insertions(+), 134 deletions(-) diff --git a/core/state/state_object.go b/core/state/state_object.go index 0aa936fc9c..05e11073c7 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -21,10 +21,12 @@ import ( "fmt" "io" "math/big" + "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/rlp" ) @@ -78,11 +80,12 @@ type stateObject struct { trie Trie // storage trie, which becomes non-nil on first access code Code // contract bytecode, which gets set when code is loaded - cachedStorage Storage // Storage entry cache to avoid duplicate reads - dirtyStorage Storage // Storage entries that need to be flushed to disk + originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction + pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block + dirtyStorage Storage // Storage entries that have been modified in the current transaction execution // Cache flags. - // When an object is marked suicided it will be delete from the trie + // When an object is marked suicided it will be deleted from the trie // during the "update" phase of the state transition. dirtyCode bool // true if the code was updated suicided bool @@ -103,186 +106,277 @@ func newObject(db *StateDB, address common.Address, data *types.StateAccount) *s data.CodeHash = emptyCodeHash } return &stateObject{ - db: db, - address: address, - addrHash: crypto.Keccak256Hash(address[:]), - data: *data, - cachedStorage: make(Storage), - dirtyStorage: make(Storage), + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + data: *data, + originStorage: make(Storage), + pendingStorage: make(Storage), + dirtyStorage: make(Storage), } } // EncodeRLP implements rlp.Encoder. -func (c *stateObject) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, c.data) +func (s *stateObject) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, s.data) } // setError remembers the first non-nil error it is called with. 
-func (self *stateObject) setError(err error) {
-	if self.dbErr == nil {
-		self.dbErr = err
+func (s *stateObject) setError(err error) {
+	if s.dbErr == nil {
+		s.dbErr = err
 	}
 }
 
-func (self *stateObject) markSuicided() {
-	self.suicided = true
+func (s *stateObject) markSuicided() {
+	s.suicided = true
 }
 
-func (c *stateObject) touch() {
-	c.db.journal.append(touchChange{
-		account: &c.address,
+func (s *stateObject) touch() {
+	s.db.journal.append(touchChange{
+		account: &s.address,
 	})
-	if c.address == ripemd {
+	if s.address == ripemd {
 		// Explicitly put it in the dirty-cache, which is otherwise generated from
 		// flattened journals.
-		c.db.journal.dirty(c.address)
+		s.db.journal.dirty(s.address)
 	}
 }
 
-func (c *stateObject) getTrie(db Database) Trie {
-	if c.trie == nil {
+func (s *stateObject) getTrie(db Database) Trie {
+	if s.trie == nil {
 		var err error
-		c.trie, err = db.OpenStorageTrie(c.addrHash, c.data.Root)
+		s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root)
 		if err != nil {
-			c.trie, _ = db.OpenStorageTrie(c.addrHash, common.Hash{})
-			c.setError(fmt.Errorf("can't create storage trie: %v", err))
+			s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{})
+			s.setError(fmt.Errorf("can't create storage trie: %v", err))
 		}
 	}
-	return c.trie
+	return s.trie
 }
 
-func (self *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash {
-	value := common.Hash{}
-	// Load from DB in case it is missing.
-	val, err := self.getTrie(db).GetStorage(self.address, key.Bytes())
-	if err != nil {
-		self.setError(err)
-		return common.Hash{}
+func (s *stateObject) GetState(db Database, key common.Hash) common.Hash {
+	// If we have a dirty value for this state entry, return it
+	value, dirty := s.dirtyStorage[key]
+	if dirty {
+		return value
 	}
-	value.SetBytes(val)
-	return value
+	// Otherwise return the entry's original value
+	return s.GetCommittedState(db, key)
 }
 
-func (self *stateObject) GetState(db Database, key common.Hash) common.Hash {
-	value, exists := self.cachedStorage[key]
-	if exists {
+func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash {
+	// If we have a pending write or clean cached, return that
+	if value, pending := s.pendingStorage[key]; pending {
 		return value
 	}
-	// Load from DB in case it is missing.
-	val, err := self.getTrie(db).GetStorage(self.address, key.Bytes())
-	if err != nil {
-		self.setError(err)
-		return common.Hash{}
+	if value, cached := s.originStorage[key]; cached {
+		return value
 	}
-
-	value.SetBytes(val)
-	if (value != common.Hash{}) {
-		self.cachedStorage[key] = value
+	// If no live objects are available, attempt to use snapshots
+	var (
+		enc []byte
+		err error
+	)
+	if s.db.snap != nil {
+		if metrics.EnabledExpensive {
+			defer func(start time.Time) { s.db.SnapshotStorageReads += time.Since(start) }(time.Now())
+		}
+		// If the object was destructed in *this* block (and potentially resurrected),
+		// the storage has been cleared out, and we should *not* consult the previous
+		// snapshot about any storage values. The only possible alternatives are:
+		// 1) resurrect happened, and new slot values were set -- those should
+		//    have been handled via pendingStorage above.
+		// 2) we don't have new values, and can deliver empty response back
+		if _, destructed := s.db.snapDestructs[s.addrHash]; destructed {
+			return common.Hash{}
+		}
+		enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key[:]))
+	}
+	// If snapshot unavailable or reading from it failed, load from the database
+	if s.db.snap == nil || err != nil {
+		if metrics.EnabledExpensive {
+			defer func(start time.Time) { s.db.StorageReads += time.Since(start) }(time.Now())
+		}
+		if enc, err = s.getTrie(db).GetStorage(s.address, key.Bytes()); err != nil {
+			s.setError(err)
+			return common.Hash{}
+		}
+	}
+	var value common.Hash
+	if len(enc) > 0 {
+		_, content, _, err := rlp.Split(enc)
+		if err != nil {
+			s.setError(err)
+		}
+		value.SetBytes(content)
 	}
+	s.originStorage[key] = value
 	return value
 }
 
 // SetState updates a value in account storage.
-func (self *stateObject) SetState(db Database, key, value common.Hash) {
-	self.db.journal.append(storageChange{
-		account:  &self.address,
+func (s *stateObject) SetState(db Database, key, value common.Hash) {
+	// If the new value is the same as old, don't set
+	prev := s.GetState(db, key)
+	if prev == value {
+		return
+	}
+	s.db.journal.append(storageChange{
+		account:  &s.address,
 		key:      key,
-		prevalue: self.GetState(db, key),
+		prevalue: prev,
 	})
-	self.setState(key, value)
+	s.setState(key, value)
+}
+
+func (s *stateObject) setState(key, value common.Hash) {
+	s.dirtyStorage[key] = value
 }
 
-func (self *stateObject) setState(key, value common.Hash) {
-	self.cachedStorage[key] = value
-	self.dirtyStorage[key] = value
+// finalise moves all dirty storage slots into the pending area to be hashed or
+// committed later. It is invoked at the end of every transaction.
+func (s *stateObject) finalise() {
+	for key, value := range s.dirtyStorage {
+		s.pendingStorage[key] = value
+	}
+	if len(s.dirtyStorage) > 0 {
+		s.dirtyStorage = make(Storage)
+	}
 }
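The dirty / pending / origin split introduced above forms a small write-back hierarchy: writes land in dirtyStorage during a transaction, finalise promotes them to pendingStorage at transaction end, and updateTrie (below) only flushes pending slots that actually differ from originStorage. A reduced model of the precedence rules, using plain strings in place of hashes (illustrative only):

package main

import "fmt"

// tieredStorage is a toy model of the three per-object storage maps.
type tieredStorage struct {
	dirty   map[string]string // writes from the current transaction
	pending map[string]string // finalised writes awaiting a trie flush
	origin  map[string]string // committed values, the dedup baseline
}

// get follows the same precedence as GetState/GetCommittedState:
// dirty beats pending, pending beats origin.
func (t *tieredStorage) get(key string) (string, bool) {
	if v, ok := t.dirty[key]; ok {
		return v, true
	}
	if v, ok := t.pending[key]; ok {
		return v, true
	}
	v, ok := t.origin[key]
	return v, ok
}

// finalise mirrors stateObject.finalise: promote dirty slots to pending.
func (t *tieredStorage) finalise() {
	for k, v := range t.dirty {
		t.pending[k] = v
	}
	t.dirty = make(map[string]string)
}

func main() {
	s := &tieredStorage{
		dirty:   map[string]string{"slot": "txn-value"},
		pending: make(map[string]string),
		origin:  map[string]string{"slot": "committed"},
	}
	v, _ := s.get("slot")
	fmt.Println(v) // "txn-value"
	s.finalise()
	v, _ = s.get("slot")
	fmt.Println(v) // still "txn-value", now served from pending
}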
 
 // updateTrie writes cached storage modifications into the object's storage trie.
-func (self *stateObject) updateTrie(db Database) Trie {
-	tr := self.getTrie(db)
-	for key, value := range self.dirtyStorage {
-		delete(self.dirtyStorage, key)
-		if (value == common.Hash{}) {
-			self.setError(tr.DeleteStorage(self.address, key[:]))
+func (s *stateObject) updateTrie(db Database) Trie {
+	// Make sure all dirty slots are finalized into the pending storage area
+	s.finalise()
+	if len(s.pendingStorage) == 0 {
+		return s.trie
+	}
+	// Track the amount of time wasted on updating the storage trie
+	if metrics.EnabledExpensive {
+		defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now())
+	}
+	// Retrieve the snapshot storage map for the object
+	var storage map[common.Hash][]byte
+	if s.db.snap != nil {
+		// Retrieve the old storage map, if available, create a new one otherwise
+		storage = s.db.snapStorage[s.addrHash]
+		if storage == nil {
+			storage = make(map[common.Hash][]byte)
+			s.db.snapStorage[s.addrHash] = storage
+		}
+	}
+	// Insert all the pending updates into the trie
+	tr := s.getTrie(db)
+	for key, value := range s.pendingStorage {
+		// Skip noop changes, persist actual changes
+		if value == s.originStorage[key] {
 			continue
 		}
-		// Encoding []byte cannot fail, ok to ignore the error.
-		v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00"))
-		self.setError(tr.UpdateStorage(self.address, key[:], v))
+		s.originStorage[key] = value
+
+		var v []byte
+		if (value == common.Hash{}) {
+			s.setError(tr.DeleteStorage(s.address, key[:]))
+		} else {
+			// Encoding []byte cannot fail, ok to ignore the error.
+			v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:]))
+			s.setError(tr.UpdateStorage(s.address, key[:], v))
+		}
+		// If state snapshotting is active, cache the data til commit
+		if storage != nil {
+			storage[crypto.Keccak256Hash(key[:])] = v // v will be nil if value is 0x00
+		}
+	}
+	if len(s.pendingStorage) > 0 {
+		s.pendingStorage = make(Storage)
 	}
 	return tr
 }
 
 // updateRoot sets the trie root to the current root hash of the storage trie.
-func (self *stateObject) updateRoot(db Database) {
-	self.updateTrie(db)
-	self.data.Root = self.trie.Hash()
+func (s *stateObject) updateRoot(db Database) {
+	// If nothing changed, don't bother with hashing anything
+	if s.updateTrie(db) == nil {
+		return
+	}
+	// Track the amount of time wasted on hashing the storage trie
+	if metrics.EnabledExpensive {
+		defer func(start time.Time) { s.db.StorageHashes += time.Since(start) }(time.Now())
+	}
+	s.data.Root = s.trie.Hash()
 }
 
 // CommitTrie commits the storage trie of the object to db.
 // This updates the trie root.
-func (self *stateObject) CommitTrie(db Database) error {
-	self.updateTrie(db)
-	if self.dbErr != nil {
-		return self.dbErr
+func (s *stateObject) CommitTrie(db Database) error {
+	// If nothing changed, don't bother with hashing anything
+	if s.updateTrie(db) == nil {
+		return nil
+	}
+	if s.dbErr != nil {
+		return s.dbErr
+	}
+	// Track the amount of time wasted on committing the storage trie
+	if metrics.EnabledExpensive {
+		defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now())
 	}
-	root, err := self.trie.Commit(nil)
+	root, err := s.trie.Commit(nil)
 	if err == nil {
-		self.data.Root = root
+		s.data.Root = root
 	}
 	return err
 }
 
 // AddBalance adds amount to s's balance.
 // It is used to add funds to the destination account of a transfer.
-func (c *stateObject) AddBalance(amount *big.Int) {
+func (s *stateObject) AddBalance(amount *big.Int) {
 	// EIP158: We must check emptiness for the objects such that the account
 	// clearing (0,0,0 objects) can take effect.
 	if amount.Sign() == 0 {
-		if c.empty() {
-			c.touch()
+		if s.empty() {
+			s.touch()
 		}
 		return
 	}
-	c.SetBalance(new(big.Int).Add(c.Balance(), amount))
+	s.SetBalance(new(big.Int).Add(s.Balance(), amount))
 }
 
 // SubBalance removes amount from s's balance.
 // It is used to remove funds from the origin account of a transfer.
-func (c *stateObject) SubBalance(amount *big.Int) {
+func (s *stateObject) SubBalance(amount *big.Int) {
 	if amount.Sign() == 0 {
 		return
 	}
-	c.SetBalance(new(big.Int).Sub(c.Balance(), amount))
+	s.SetBalance(new(big.Int).Sub(s.Balance(), amount))
 }
 
-func (self *stateObject) SetBalance(amount *big.Int) {
-	self.db.journal.append(balanceChange{
-		account: &self.address,
-		prev:    new(big.Int).Set(self.data.Balance),
+func (s *stateObject) SetBalance(amount *big.Int) {
+	s.db.journal.append(balanceChange{
+		account: &s.address,
+		prev:    new(big.Int).Set(s.data.Balance),
 	})
-	self.setBalance(amount)
+	s.setBalance(amount)
 }
 
-func (self *stateObject) setBalance(amount *big.Int) {
-	self.data.Balance = amount
+func (s *stateObject) setBalance(amount *big.Int) {
+	s.data.Balance = amount
 }
 
-// Return the gas back to the origin. 
Used by the Virtual machine or Closures -func (c *stateObject) ReturnGas(gas *big.Int) {} +// ReturnGas returns the gas back to the origin. Used by the Virtual machine or Closures +func (s *stateObject) ReturnGas(gas *big.Int) {} -func (self *stateObject) deepCopy(db *StateDB) *stateObject { - stateObject := newObject(db, self.address, &self.data) - if self.trie != nil { - stateObject.trie = db.db.CopyTrie(self.trie) +func (s *stateObject) deepCopy(db *StateDB) *stateObject { + stateObject := newObject(db, s.address, &s.data) + if s.trie != nil { + stateObject.trie = db.db.CopyTrie(s.trie) } - stateObject.code = self.code - stateObject.dirtyStorage = self.dirtyStorage.Copy() - stateObject.cachedStorage = self.dirtyStorage.Copy() - stateObject.suicided = self.suicided - stateObject.dirtyCode = self.dirtyCode - stateObject.deleted = self.deleted + stateObject.code = s.code + stateObject.dirtyStorage = s.dirtyStorage.Copy() + stateObject.suicided = s.suicided + stateObject.dirtyCode = s.dirtyCode + stateObject.deleted = s.deleted return stateObject } @@ -290,70 +384,70 @@ func (self *stateObject) deepCopy(db *StateDB) *stateObject { // Attribute accessors // -// Returns the address of the contract/account -func (c *stateObject) Address() common.Address { - return c.address +// Address returns the address of the contract/account +func (s *stateObject) Address() common.Address { + return s.address } // Code returns the contract code associated with this object, if any. -func (self *stateObject) Code(db Database) []byte { - if self.code != nil { - return self.code +func (s *stateObject) Code(db Database) []byte { + if s.code != nil { + return s.code } - if bytes.Equal(self.CodeHash(), emptyCodeHash) { + if bytes.Equal(s.CodeHash(), emptyCodeHash) { return nil } - code, err := db.ContractCode(self.addrHash, common.BytesToHash(self.CodeHash())) + code, err := db.ContractCode(s.addrHash, common.BytesToHash(s.CodeHash())) if err != nil { - self.setError(fmt.Errorf("can't load code hash %x: %v", self.CodeHash(), err)) + s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) } - self.code = code + s.code = code return code } -func (self *stateObject) SetCode(codeHash common.Hash, code []byte) { - prevcode := self.Code(self.db.db) - self.db.journal.append(codeChange{ - account: &self.address, - prevhash: self.CodeHash(), +func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { + prevcode := s.Code(s.db.db) + s.db.journal.append(codeChange{ + account: &s.address, + prevhash: s.CodeHash(), prevcode: prevcode, }) - self.setCode(codeHash, code) + s.setCode(codeHash, code) } -func (self *stateObject) setCode(codeHash common.Hash, code []byte) { - self.code = code - self.data.CodeHash = codeHash[:] - self.dirtyCode = true +func (s *stateObject) setCode(codeHash common.Hash, code []byte) { + s.code = code + s.data.CodeHash = codeHash[:] + s.dirtyCode = true } -func (self *stateObject) SetNonce(nonce uint64) { - self.db.journal.append(nonceChange{ - account: &self.address, - prev: self.data.Nonce, +func (s *stateObject) SetNonce(nonce uint64) { + s.db.journal.append(nonceChange{ + account: &s.address, + prev: s.data.Nonce, }) - self.setNonce(nonce) + s.setNonce(nonce) } -func (self *stateObject) setNonce(nonce uint64) { - self.data.Nonce = nonce +func (s *stateObject) setNonce(nonce uint64) { + s.data.Nonce = nonce } -func (self *stateObject) CodeHash() []byte { - return self.data.CodeHash +func (s *stateObject) CodeHash() []byte { + return s.data.CodeHash } -func (self 
*stateObject) Balance() *big.Int { - return self.data.Balance +func (s *stateObject) Balance() *big.Int { + return s.data.Balance } -func (self *stateObject) Nonce() uint64 { - return self.data.Nonce +func (s *stateObject) Nonce() uint64 { + return s.data.Nonce } -// Never called, but must be present to allow stateObject to be used +// Value is never called, but must be present to allow stateObject to be used // as a vm.StateAccount interface that also satisfies the vm.ContractRef // interface. Interfaces are awesome. -func (self *stateObject) Value() *big.Int { +func (s *stateObject) Value() *big.Int { panic("Value on stateObject should never be called") } From d3a156f2d920d48aee9903860788956426bdcbb4 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 7 Aug 2023 15:14:12 +0700 Subject: [PATCH 073/119] Update statedb.go --- core/state/managed_state_test.go | 1 + core/state/state_test.go | 32 ++-- core/state/statedb.go | 244 +++++++++++++++++++++---------- 3 files changed, 187 insertions(+), 90 deletions(-) diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go index 9df24323f5..c4fa4937aa 100644 --- a/core/state/managed_state_test.go +++ b/core/state/managed_state_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" ) var addr = common.BytesToAddress([]byte("test")) diff --git a/core/state/state_test.go b/core/state/state_test.go index 8ed63c3e34..17ecf3b192 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -213,24 +213,30 @@ func compareStateObjects(so0, so1 *stateObject, t *testing.T) { t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code) } - if len(so1.cachedStorage) != len(so0.cachedStorage) { - t.Errorf("Storage size mismatch: have %d, want %d", len(so1.cachedStorage), len(so0.cachedStorage)) + if len(so1.dirtyStorage) != len(so0.dirtyStorage) { + t.Errorf("Dirty storage size mismatch: have %d, want %d", len(so1.dirtyStorage), len(so0.dirtyStorage)) } - for k, v := range so1.cachedStorage { - if so0.cachedStorage[k] != v { - t.Errorf("Storage key %x mismatch: have %v, want %v", k, so0.cachedStorage[k], v) + for k, v := range so1.dirtyStorage { + if so0.dirtyStorage[k] != v { + t.Errorf("Dirty storage key %x mismatch: have %v, want %v", k, so0.dirtyStorage[k], v) } } - for k, v := range so0.cachedStorage { - if so1.cachedStorage[k] != v { - t.Errorf("Storage key %x mismatch: have %v, want none.", k, v) + for k, v := range so0.dirtyStorage { + if so1.dirtyStorage[k] != v { + t.Errorf("Dirty storage key %x mismatch: have %v, want none.", k, v) } } - - if so0.suicided != so1.suicided { - t.Fatalf("suicided mismatch: have %v, want %v", so0.suicided, so1.suicided) + if len(so1.originStorage) != len(so0.originStorage) { + t.Errorf("Origin storage size mismatch: have %d, want %d", len(so1.originStorage), len(so0.originStorage)) + } + for k, v := range so1.originStorage { + if so0.originStorage[k] != v { + t.Errorf("Origin storage key %x mismatch: have %v, want %v", k, so0.originStorage[k], v) + } } - if so0.deleted != so1.deleted { - t.Fatalf("Deleted mismatch: have %v, want %v", so0.deleted, so1.deleted) + for k, v := range so0.originStorage { + if so1.originStorage[k] != v { + t.Errorf("Origin storage key %x mismatch: have %v, want none.", k, v) + } } } diff --git a/core/state/statedb.go b/core/state/statedb.go index 58c4494cba..2c6253be7d 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -28,6 +28,7 @@ import ( 
"github.com/tomochain/tomochain/core/state/snapshot" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" @@ -39,9 +40,6 @@ type revision struct { } var ( - // emptyRoot is the known root hash of an empty trie. - emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") - // emptyState is the known hash of an empty state trie entry. emptyState = crypto.Keccak256Hash(nil) @@ -49,7 +47,7 @@ var ( emptyCode = crypto.Keccak256Hash(nil) ) -// StateDBs within the ethereum protocol are used to store anything +// StateDB within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: // * Contracts @@ -65,8 +63,9 @@ type StateDB struct { snapStorage map[common.Hash]map[common.Hash][]byte // This map holds 'live' objects, which will get modified while processing a state transition. - stateObjects map[common.Address]*stateObject - stateObjectsDirty map[common.Address]struct{} + stateObjects map[common.Address]*stateObject + stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie + stateObjectsDirty map[common.Address]struct{} // DB error. // State objects are used by the consensus core and VM which are @@ -124,24 +123,22 @@ func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) commo return common.Hash{} } -// Create a new state from a given trie. +// New creates a new state from a given trie. func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { tr, err := db.OpenTrie(root) if err != nil { return nil, err } sdb := &StateDB{ - db: db, - trie: tr, - snaps: snaps, - stateObjects: make(map[common.Address]*stateObject), - stateObjectsDirty: make(map[common.Address]struct{}), - logs: make(map[common.Hash][]*types.Log), - preimages: make(map[common.Hash][]byte), - journal: newJournal(), - } - if sdb.snaps != nil { - sdb.snap = sdb.snaps.Snapshot(root) + db: db, + trie: tr, + snaps: snaps, + stateObjects: make(map[common.Address]*stateObject), + stateObjectsPending: make(map[common.Address]struct{}), + stateObjectsDirty: make(map[common.Address]struct{}), + logs: make(map[common.Hash][]*types.Log), + preimages: make(map[common.Hash][]byte), + journal: newJournal(), } if sdb.snaps != nil { if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil { @@ -173,6 +170,7 @@ func (s *StateDB) Reset(root common.Hash) error { } s.trie = tr s.stateObjects = make(map[common.Address]*stateObject) + s.stateObjectsPending = make(map[common.Address]struct{}) s.stateObjectsDirty = make(map[common.Address]struct{}) s.thash = common.Hash{} s.bhash = common.Hash{} @@ -415,7 +413,7 @@ func (s *StateDB) updateStateObject(stateObject *stateObject) { // enough to track account updates at commit time, deletions need tracking // at transaction boundary level to ensure we capture state clearing. if s.snap != nil { - s.snapAccounts[stateObject.addrHash] = snapshot.AccountRLP(stateObject.data.Nonce, stateObject.data.Balance, stateObject.data.Root, stateObject.data.CodeHash) + s.snapAccounts[stateObject.addrHash] = types.SlimAccountRLP(stateObject.data) } } @@ -439,27 +437,10 @@ func (s *StateDB) DeleteAddress(addr common.Address) { // Retrieve a state object given my the address. 
Returns nil if not found. func (s *StateDB) getStateObject(addr common.Address) (stateObject *stateObject) { - // Prefer 'live' objects. - if obj := s.stateObjects[addr]; obj != nil { - if obj.deleted { - return nil - } + if obj := s.getDeletedStateObject(addr); obj != nil && !obj.deleted { return obj } - - // Load the object from the database. - data, err := s.trie.GetAccount(addr) - if err != nil { - s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", addr.Bytes(), err)) - return nil - } - if data == nil { - return nil - } - // Insert into the live set. - obj := newObject(s, addr, data) - s.setStateObject(obj) - return obj + return nil } // getDeletedStateObject is similar to getStateObject, but instead of returning @@ -480,29 +461,36 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject { if metrics.EnabledExpensive { defer func(start time.Time) { s.SnapshotAccountReads += time.Since(start) }(time.Now()) } - var acc *snapshot.Account - if acc, err = s.snap.Account(crypto.Keccak256Hash(addr[:])); err == nil { + if acc, err := s.snap.Account(crypto.Keccak256Hash(addr[:])); err == nil { if acc == nil { return nil } - data.Nonce, data.Balance, data.CodeHash = acc.Nonce, acc.Balance, acc.CodeHash + data = &types.StateAccount{ + Nonce: acc.Nonce, + Balance: acc.Balance, + CodeHash: acc.CodeHash, + Root: common.BytesToHash(acc.Root), + } if len(data.CodeHash) == 0 { - data.CodeHash = emptyCodeHash + data.CodeHash = types.EmptyCodeHash.Bytes() } - data.Root = common.BytesToHash(acc.Root) if data.Root == (common.Hash{}) { - data.Root = emptyRoot + data.Root = types.EmptyRootHash } } } // If snapshot unavailable or reading from it failed, load from the database - if s.snap == nil || err != nil { + if data == nil { + start := time.Now() + data, err = s.trie.GetAccount(addr) if metrics.EnabledExpensive { - defer func(start time.Time) { s.AccountReads += time.Since(start) }(time.Now()) + s.AccountReads += time.Since(start) } - data, err = s.trie.GetAccount(addr) if err != nil { - s.setError(err) + s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", addr.Bytes(), err)) + return nil + } + if data == nil { return nil } } @@ -516,7 +504,7 @@ func (s *StateDB) setStateObject(object *stateObject) { s.stateObjects[object.Address()] = object } -// Retrieve a state object or create a new state object if nil. +// GetOrNewStateObject retrieves a state object or create a new state object if nil. func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject { stateObject := s.getStateObject(addr) if stateObject == nil || stateObject.deleted { @@ -529,12 +517,19 @@ func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject { // the given address, it is overwritten and returned as the second return value. func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) { prev = s.getStateObject(addr) + var prevdestruct bool + if s.snap != nil && prev != nil { + _, prevdestruct = s.snapDestructs[prev.addrHash] + if !prevdestruct { + s.snapDestructs[prev.addrHash] = struct{}{} + } + } newobj = newObject(s, addr, &types.StateAccount{}) newobj.setNonce(0) // sets the object to dirty if prev == nil { s.journal.append(createObjectChange{account: &addr}) } else { - s.journal.append(resetObjectChange{prev: prev}) + s.journal.append(resetObjectChange{prev: prev, prevdestruct: prevdestruct}) } s.setStateObject(newobj) return newobj, prev @@ -562,18 +557,25 @@ func (s *StateDB) ForEachStorage(addr common.Address, cb func(key, value common. 
if so == nil { return nil } - - // When iterating over the storage check the cache first - for h, value := range so.cachedStorage { - cb(h, value) - } - it := trie.NewIterator(so.getTrie(s.db).NodeIterator(nil)) + for it.Next() { - // ignore cached values key := common.BytesToHash(s.trie.GetKey(it.Key)) - if _, ok := so.cachedStorage[key]; !ok { - cb(key, common.BytesToHash(it.Value)) + if value, dirty := so.dirtyStorage[key]; dirty { + if !cb(key, value) { + return nil + } + continue + } + + if len(it.Value) > 0 { + _, content, _, err := rlp.Split(it.Value) + if err != nil { + return err + } + if !cb(key, common.BytesToHash(content)) { + return nil + } } } return nil @@ -587,24 +589,62 @@ func (s *StateDB) Copy() *StateDB { // Copy all the basic fields, initialize the memory ones state := &StateDB{ - db: s.db, - trie: s.db.CopyTrie(s.trie), - stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)), - stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)), - refund: s.refund, - logs: make(map[common.Hash][]*types.Log, len(s.logs)), - logSize: s.logSize, - preimages: make(map[common.Hash][]byte), - journal: newJournal(), + db: s.db, + trie: s.db.CopyTrie(s.trie), + stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)), + stateObjectsPending: make(map[common.Address]struct{}, len(s.stateObjectsPending)), + stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)), + refund: s.refund, + logs: make(map[common.Hash][]*types.Log, len(s.logs)), + logSize: s.logSize, + preimages: make(map[common.Hash][]byte), + journal: newJournal(), + + // In order for the block producer to be able to use and make additions + // to the snapshot tree, we need to copy that as well. Otherwise, any + // block mined by ourselves will cause gaps in the tree, and force the + // miner to operate trie-backed only. + snaps: s.snaps, + snap: s.snap, } // Copy the dirty states, logs, and preimages for addr := range s.journal.dirties { - state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) + // As documented [here](https://github.com/ethereum/go-ethereum/pull/16485#issuecomment-380438527), + // and in the Finalise-method, there is a case where an object is in the journal but not + // in the stateObjects: OOG after touch on ripeMD prior to Byzantium. Thus, we need to check for + // nil + if object, exist := s.stateObjects[addr]; exist { + // Even though the original object is dirty, we are not copying the journal, + // so we need to make sure that any side effect the journal would have caused + // during a commit (or similar op) is already applied to the copy. + state.stateObjects[addr] = object.deepCopy(state) + + state.stateObjectsDirty[addr] = struct{}{} // Mark the copy dirty to force internal (code/state) commits + state.stateObjectsPending[addr] = struct{}{} // Mark the copy pending to force external (account) commits + } + } + // Above, we don't copy the actual journal. This means that if the copy is copied, the + // loop above will be a no-op, since the copy's journal is empty. 
+ // Thus, here we iterate over stateObjects, to enable copies of copies + for addr := range s.stateObjectsPending { + if _, exist := state.stateObjects[addr]; !exist { + state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) + } + state.stateObjectsPending[addr] = struct{}{} + } + for addr := range s.stateObjectsDirty { + if _, exist := state.stateObjects[addr]; !exist { + state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) + } state.stateObjectsDirty[addr] = struct{}{} } for hash, logs := range s.logs { - state.logs[hash] = make([]*types.Log, len(logs)) - copy(state.logs[hash], logs) + cpy := make([]*types.Log, len(logs)) + for i, l := range logs { + cpy[i] = new(types.Log) + *cpy[i] = *l + } + state.logs[hash] = cpy } for hash, preimage := range s.preimages { state.preimages[hash] = preimage @@ -645,17 +685,26 @@ func (s *StateDB) GetRefund() uint64 { // and clears the journal as well as the refunds. func (s *StateDB) Finalise(deleteEmptyObjects bool) { for addr := range s.journal.dirties { - stateObject, exist := s.stateObjects[addr] + obj, exist := s.stateObjects[addr] if !exist { continue } - - if stateObject.suicided || (deleteEmptyObjects && stateObject.empty()) { - s.deleteStateObject(stateObject) + if obj.suicided || (deleteEmptyObjects && obj.empty()) { + obj.deleted = true + + // If state snapshotting is active, also mark the destruction there. + // Note, we can't do this only at the end of a block because multiple + // transactions within the same block might self-destruct and then + // resurrect an account; but the snapshotter needs both events. + if s.snap != nil { + s.snapDestructs[obj.addrHash] = struct{}{} // We need to maintain account deletions explicitly (will remain set indefinitely) + delete(s.snapAccounts, obj.addrHash) // Clear out any previously updated account data (may be recreated via a resurrect) + delete(s.snapStorage, obj.addrHash) // Clear out any previously updated storage data (may be recreated via a resurrect) + } } else { - stateObject.updateRoot(s.db) - s.updateStateObject(stateObject) + obj.finalise() } + s.stateObjectsPending[addr] = struct{}{} s.stateObjectsDirty[addr] = struct{}{} } // Invalidate journal because reverting across transactions is not allowed. @@ -666,7 +715,25 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) { // It is called in between transactions to get the root hash that // goes into transaction receipts. func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash { + // Finalise all the dirty storage states and write them into the tries s.Finalise(deleteEmptyObjects) + + for addr := range s.stateObjectsPending { + obj := s.stateObjects[addr] + if obj.deleted { + s.deleteStateObject(obj) + } else { + obj.updateRoot(s.db) + s.updateStateObject(obj) + } + } + if len(s.stateObjectsPending) > 0 { + s.stateObjectsPending = make(map[common.Address]struct{}) + } + // Track the amount of time wasted on hashing the account trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.AccountHashes += time.Since(start) }(time.Now()) + } return s.trie.Hash() } @@ -737,6 +804,10 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) delete(s.stateObjectsDirty, addr) } // Write trie changes. 
+ var start time.Time + if metrics.EnabledExpensive { + start = time.Now() + } root, err = s.trie.Commit(func(leaf []byte, parent common.Hash) error { var account types.StateAccount if err := rlp.DecodeBytes(leaf, &account); err != nil { @@ -751,6 +822,25 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) } return nil }) + if metrics.EnabledExpensive { + s.AccountCommits += time.Since(start) + } + // If snapshotting is enabled, update the snapshot tree with this new version + if s.snap != nil { + if metrics.EnabledExpensive { + defer func(start time.Time) { s.SnapshotCommits += time.Since(start) }(time.Now()) + } + // Only update if there's a state transition (skip empty Clique blocks) + if parent := s.snap.Root(); parent != root { + if err := s.snaps.Update(root, parent, s.snapDestructs, s.snapAccounts, s.snapStorage); err != nil { + log.Warn("Failed to update snapshot tree", "from", parent, "to", root, "err", err) + } + if err := s.snaps.Cap(root, 127); err != nil { // Persistent layer is 128th, the last available trie + log.Warn("Failed to cap snapshot tree", "root", root, "layers", 127, "err", err) + } + } + s.snap, s.snapDestructs, s.snapAccounts, s.snapStorage = nil, nil, nil, nil + } return root, err } From 3659a0d9e7ade99fa8067c01933b2d222e1242a5 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 10 Aug 2023 17:33:12 +0700 Subject: [PATCH 074/119] Revert --- core/state/state_object.go | 178 +++++++-------------------- core/state/statedb.go | 242 ++++++++++++------------------------- 2 files changed, 118 insertions(+), 302 deletions(-) diff --git a/core/state/state_object.go b/core/state/state_object.go index 05e11073c7..eecaa346fb 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -21,12 +21,10 @@ import ( "fmt" "io" "math/big" - "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" - "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/rlp" ) @@ -80,12 +78,11 @@ type stateObject struct { trie Trie // storage trie, which becomes non-nil on first access code Code // contract bytecode, which gets set when code is loaded - originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction - pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block - dirtyStorage Storage // Storage entries that have been modified in the current transaction execution + cachedStorage Storage // Storage entry cache to avoid duplicate reads + dirtyStorage Storage // Storage entries that need to be flushed to disk // Cache flags. - // When an object is marked suicided it will be deleted from the trie + // When an object is marked suicided it will be delete from the trie // during the "update" phase of the state transition. 
dirtyCode bool // true if the code was updated suicided bool @@ -106,13 +103,12 @@ func newObject(db *StateDB, address common.Address, data *types.StateAccount) *s data.CodeHash = emptyCodeHash } return &stateObject{ - db: db, - address: address, - addrHash: crypto.Keccak256Hash(address[:]), - data: *data, - originStorage: make(Storage), - pendingStorage: make(Storage), - dirtyStorage: make(Storage), + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + data: *data, + cachedStorage: make(Storage), + dirtyStorage: make(Storage), } } @@ -155,172 +151,81 @@ func (s *stateObject) getTrie(db Database) Trie { return s.trie } -func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { - // If we have a dirty value for this state entry, return it - value, dirty := s.dirtyStorage[key] - if dirty { - return value +func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { + value := common.Hash{} + // Load from DB in case it is missing. + val, err := s.getTrie(db).GetStorage(s.address, key.Bytes()) + if err != nil { + s.setError(err) + return common.Hash{} } - // Otherwise return the entry's original value - return s.GetCommittedState(db, key) + value.SetBytes(val) + return value } -func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { - // If we have a pending write or clean cached, return that - if value, pending := s.pendingStorage[key]; pending { - return value - } - if value, cached := s.originStorage[key]; cached { +func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { + value, exists := s.cachedStorage[key] + if exists { return value } - // If no live objects are available, attempt to use snapshots - var ( - enc []byte - err error - ) - if s.db.snap != nil { - if metrics.EnabledExpensive { - defer func(start time.Time) { s.db.SnapshotStorageReads += time.Since(start) }(time.Now()) - } - // If the object was destructed in *this* block (and potentially resurrected), - // the storage has been cleared out, and we should *not* consult the previous - // snapshot about any storage values. The only possible alternatives are: - // 1) resurrect happened, and new slot values were set -- those should - // have been handles via pendingStorage above. - // 2) we don't have new values, and can deliver empty response back - if _, destructed := s.db.snapDestructs[s.addrHash]; destructed { - return common.Hash{} - } - enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key[:])) - } - // If snapshot unavailable or reading from it failed, load from the database - if s.db.snap == nil || err != nil { - if metrics.EnabledExpensive { - defer func(start time.Time) { s.db.StorageReads += time.Since(start) }(time.Now()) - } - if enc, err = s.getTrie(db).GetStorage(s.address, key.Bytes()); err != nil { - s.setError(err) - return common.Hash{} - } + // Load from DB in case it is missing. + val, err := s.getTrie(db).GetStorage(s.address, key.Bytes()) + if err != nil { + s.setError(err) + return common.Hash{} } - var value common.Hash - if len(enc) > 0 { - _, content, _, err := rlp.Split(enc) - if err != nil { - s.setError(err) - } - value.SetBytes(content) + + value.SetBytes(val) + if (value != common.Hash{}) { + s.cachedStorage[key] = value } - s.originStorage[key] = value return value } // SetState updates a value in account storage. 
func (s *stateObject) SetState(db Database, key, value common.Hash) { - // If the new value is the same as old, don't set - prev := s.GetState(db, key) - if prev == value { - return - } s.db.journal.append(storageChange{ account: &s.address, key: key, - prevalue: prev, + prevalue: s.GetState(db, key), }) s.setState(key, value) } func (s *stateObject) setState(key, value common.Hash) { + s.cachedStorage[key] = value s.dirtyStorage[key] = value } -// finalise moves all dirty storage slots into the pending area to be hashed or -// committed later. It is invoked at the end of every transaction. -func (s *stateObject) finalise() { - for key, value := range s.dirtyStorage { - s.pendingStorage[key] = value - } - if len(s.dirtyStorage) > 0 { - s.dirtyStorage = make(Storage) - } -} - // updateTrie writes cached storage modifications into the object's storage trie. func (s *stateObject) updateTrie(db Database) Trie { - // Make sure all dirty slots are finalized into the pending storage area - s.finalise() - if len(s.pendingStorage) == 0 { - return s.trie - } - // Track the amount of time wasted on updating the storge trie - if metrics.EnabledExpensive { - defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now()) - } - // Retrieve the snapshot storage map for the object - var storage map[common.Hash][]byte - if s.db.snap != nil { - // Retrieve the old storage map, if available, create a new one otherwise - storage = s.db.snapStorage[s.addrHash] - if storage == nil { - storage = make(map[common.Hash][]byte) - s.db.snapStorage[s.addrHash] = storage - } - } - // Insert all the pending updates into the trie tr := s.getTrie(db) - for key, value := range s.pendingStorage { - // Skip noop changes, persist actual changes - if value == s.originStorage[key] { - continue - } - s.originStorage[key] = value - - var v []byte + for key, value := range s.dirtyStorage { + delete(s.dirtyStorage, key) if (value == common.Hash{}) { s.setError(tr.DeleteStorage(s.address, key[:])) - } else { - // Encoding []byte cannot fail, ok to ignore the error. - v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) - s.setError(tr.UpdateStorage(s.address, key[:], v)) - } - // If state snapshotting is active, cache the data til commit - if storage != nil { - storage[crypto.Keccak256Hash(key[:])] = v // v will be nil if value is 0x00 + continue } - } - if len(s.pendingStorage) > 0 { - s.pendingStorage = make(Storage) + // Encoding []byte cannot fail, ok to ignore the error. + v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) + s.setError(tr.UpdateStorage(s.address, key[:], v)) } return tr } // UpdateRoot sets the trie root to the current root hash of func (s *stateObject) updateRoot(db Database) { - // If nothing changed, don't bother with hashing anything - if s.updateTrie(db) == nil { - return - } - // Track the amount of time wasted on hashing the storge trie - if metrics.EnabledExpensive { - defer func(start time.Time) { s.db.StorageHashes += time.Since(start) }(time.Now()) - } + s.updateTrie(db) s.data.Root = s.trie.Hash() } // CommitTrie the storage trie of the object to dwb. // This updates the trie root. 
func (s *stateObject) CommitTrie(db Database) error { - // If nothing changed, don't bother with hashing anything - if s.updateTrie(db) == nil { - return nil - } + s.updateTrie(db) if s.dbErr != nil { return s.dbErr } - // Track the amount of time wasted on committing the storage trie - if metrics.EnabledExpensive { - defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now()) - } root, err := s.trie.Commit(nil) if err == nil { s.data.Root = root @@ -364,7 +269,7 @@ func (s *stateObject) setBalance(amount *big.Int) { s.data.Balance = amount } -// ReturnGas returns the gas back to the origin. Used by the Virtual machine or Closures +// ReturnGas return the gas back to the origin. Used by the Virtual machine or Closures func (s *stateObject) ReturnGas(gas *big.Int) {} func (s *stateObject) deepCopy(db *StateDB) *stateObject { @@ -374,6 +279,7 @@ func (s *stateObject) deepCopy(db *StateDB) *stateObject { } stateObject.code = s.code stateObject.dirtyStorage = s.dirtyStorage.Copy() + stateObject.cachedStorage = s.dirtyStorage.Copy() stateObject.suicided = s.suicided stateObject.dirtyCode = s.dirtyCode stateObject.deleted = s.deleted diff --git a/core/state/statedb.go b/core/state/statedb.go index 2c6253be7d..75d69cb061 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -28,7 +28,6 @@ import ( "github.com/tomochain/tomochain/core/state/snapshot" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" - "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" @@ -40,6 +39,9 @@ type revision struct { } var ( + // emptyRoot is the known root hash of an empty trie. + emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") + // emptyState is the known hash of an empty state trie entry. emptyState = crypto.Keccak256Hash(nil) @@ -47,7 +49,7 @@ var ( emptyCode = crypto.Keccak256Hash(nil) ) -// StateDB within the ethereum protocol are used to store anything +// StateDBs within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: // * Contracts @@ -63,9 +65,8 @@ type StateDB struct { snapStorage map[common.Hash]map[common.Hash][]byte // This map holds 'live' objects, which will get modified while processing a state transition. - stateObjects map[common.Address]*stateObject - stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie - stateObjectsDirty map[common.Address]struct{} + stateObjects map[common.Address]*stateObject + stateObjectsDirty map[common.Address]struct{} // DB error. // State objects are used by the consensus core and VM which are @@ -123,22 +124,24 @@ func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) commo return common.Hash{} } -// New creates a new state from a given trie. +// Create a new state from a given trie. 
 func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) {
 	tr, err := db.OpenTrie(root)
 	if err != nil {
 		return nil, err
 	}
 	sdb := &StateDB{
-		db:                  db,
-		trie:                tr,
-		snaps:               snaps,
-		stateObjects:        make(map[common.Address]*stateObject),
-		stateObjectsPending: make(map[common.Address]struct{}),
-		stateObjectsDirty:   make(map[common.Address]struct{}),
-		logs:                make(map[common.Hash][]*types.Log),
-		preimages:           make(map[common.Hash][]byte),
-		journal:             newJournal(),
+		db:                db,
+		trie:              tr,
+		snaps:             snaps,
+		stateObjects:      make(map[common.Address]*stateObject),
+		stateObjectsDirty: make(map[common.Address]struct{}),
+		logs:              make(map[common.Hash][]*types.Log),
+		preimages:         make(map[common.Hash][]byte),
+		journal:           newJournal(),
 	}
 	if sdb.snaps != nil {
 		if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil {
@@ -170,7 +173,6 @@ func (s *StateDB) Reset(root common.Hash) error {
 	}
 	s.trie = tr
 	s.stateObjects = make(map[common.Address]*stateObject)
-	s.stateObjectsPending = make(map[common.Address]struct{})
 	s.stateObjectsDirty = make(map[common.Address]struct{})
 	s.thash = common.Hash{}
 	s.bhash = common.Hash{}
@@ -437,10 +439,27 @@ func (s *StateDB) DeleteAddress(addr common.Address) {
 
 // Retrieve a state object given my the address. Returns nil if not found.
 func (s *StateDB) getStateObject(addr common.Address) (stateObject *stateObject) {
-	if obj := s.getDeletedStateObject(addr); obj != nil && !obj.deleted {
+	// Prefer 'live' objects.
+	if obj := s.stateObjects[addr]; obj != nil {
+		if obj.deleted {
+			return nil
+		}
 		return obj
 	}
-	return nil
+
+	// Load the object from the database.
+	data, err := s.trie.GetAccount(addr)
+	if err != nil {
+		s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", addr.Bytes(), err))
+		return nil
+	}
+	if data == nil {
+		return nil
+	}
+	// Insert into the live set.
+ obj := newObject(s, addr, data) + s.setStateObject(obj) + return obj } // getDeletedStateObject is similar to getStateObject, but instead of returning @@ -461,36 +480,29 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject { if metrics.EnabledExpensive { defer func(start time.Time) { s.SnapshotAccountReads += time.Since(start) }(time.Now()) } - if acc, err := s.snap.Account(crypto.Keccak256Hash(addr[:])); err == nil { + var acc *types.SlimAccount + if acc, err = s.snap.Account(crypto.Keccak256Hash(addr[:])); err == nil { if acc == nil { return nil } - data = &types.StateAccount{ - Nonce: acc.Nonce, - Balance: acc.Balance, - CodeHash: acc.CodeHash, - Root: common.BytesToHash(acc.Root), - } + data.Nonce, data.Balance, data.CodeHash = acc.Nonce, acc.Balance, acc.CodeHash if len(data.CodeHash) == 0 { - data.CodeHash = types.EmptyCodeHash.Bytes() + data.CodeHash = emptyCodeHash } + data.Root = common.BytesToHash(acc.Root) if data.Root == (common.Hash{}) { - data.Root = types.EmptyRootHash + data.Root = emptyRoot } } } // If snapshot unavailable or reading from it failed, load from the database - if data == nil { - start := time.Now() - data, err = s.trie.GetAccount(addr) + if s.snap == nil || err != nil { if metrics.EnabledExpensive { - s.AccountReads += time.Since(start) + defer func(start time.Time) { s.AccountReads += time.Since(start) }(time.Now()) } + data, err = s.trie.GetAccount(addr) if err != nil { - s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", addr.Bytes(), err)) - return nil - } - if data == nil { + s.setError(err) return nil } } @@ -504,7 +516,7 @@ func (s *StateDB) setStateObject(object *stateObject) { s.stateObjects[object.Address()] = object } -// GetOrNewStateObject retrieves a state object or create a new state object if nil. +// Retrieve a state object or create a new state object if nil. func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject { stateObject := s.getStateObject(addr) if stateObject == nil || stateObject.deleted { @@ -517,19 +529,12 @@ func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject { // the given address, it is overwritten and returned as the second return value. func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) { prev = s.getStateObject(addr) - var prevdestruct bool - if s.snap != nil && prev != nil { - _, prevdestruct = s.snapDestructs[prev.addrHash] - if !prevdestruct { - s.snapDestructs[prev.addrHash] = struct{}{} - } - } newobj = newObject(s, addr, &types.StateAccount{}) newobj.setNonce(0) // sets the object to dirty if prev == nil { s.journal.append(createObjectChange{account: &addr}) } else { - s.journal.append(resetObjectChange{prev: prev, prevdestruct: prevdestruct}) + s.journal.append(resetObjectChange{prev: prev}) } s.setStateObject(newobj) return newobj, prev @@ -557,25 +562,18 @@ func (s *StateDB) ForEachStorage(addr common.Address, cb func(key, value common. 
if so == nil { return nil } - it := trie.NewIterator(so.getTrie(s.db).NodeIterator(nil)) + // When iterating over the storage check the cache first + for h, value := range so.cachedStorage { + cb(h, value) + } + + it := trie.NewIterator(so.getTrie(s.db).NodeIterator(nil)) for it.Next() { + // ignore cached values key := common.BytesToHash(s.trie.GetKey(it.Key)) - if value, dirty := so.dirtyStorage[key]; dirty { - if !cb(key, value) { - return nil - } - continue - } - - if len(it.Value) > 0 { - _, content, _, err := rlp.Split(it.Value) - if err != nil { - return err - } - if !cb(key, common.BytesToHash(content)) { - return nil - } + if _, ok := so.cachedStorage[key]; !ok { + cb(key, common.BytesToHash(it.Value)) } } return nil @@ -589,62 +587,24 @@ func (s *StateDB) Copy() *StateDB { // Copy all the basic fields, initialize the memory ones state := &StateDB{ - db: s.db, - trie: s.db.CopyTrie(s.trie), - stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)), - stateObjectsPending: make(map[common.Address]struct{}, len(s.stateObjectsPending)), - stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)), - refund: s.refund, - logs: make(map[common.Hash][]*types.Log, len(s.logs)), - logSize: s.logSize, - preimages: make(map[common.Hash][]byte), - journal: newJournal(), - - // In order for the block producer to be able to use and make additions - // to the snapshot tree, we need to copy that as well. Otherwise, any - // block mined by ourselves will cause gaps in the tree, and force the - // miner to operate trie-backed only. - snaps: s.snaps, - snap: s.snap, + db: s.db, + trie: s.db.CopyTrie(s.trie), + stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)), + stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)), + refund: s.refund, + logs: make(map[common.Hash][]*types.Log, len(s.logs)), + logSize: s.logSize, + preimages: make(map[common.Hash][]byte), + journal: newJournal(), } // Copy the dirty states, logs, and preimages for addr := range s.journal.dirties { - // As documented [here](https://github.com/ethereum/go-ethereum/pull/16485#issuecomment-380438527), - // and in the Finalise-method, there is a case where an object is in the journal but not - // in the stateObjects: OOG after touch on ripeMD prior to Byzantium. Thus, we need to check for - // nil - if object, exist := s.stateObjects[addr]; exist { - // Even though the original object is dirty, we are not copying the journal, - // so we need to make sure that any side effect the journal would have caused - // during a commit (or similar op) is already applied to the copy. - state.stateObjects[addr] = object.deepCopy(state) - - state.stateObjectsDirty[addr] = struct{}{} // Mark the copy dirty to force internal (code/state) commits - state.stateObjectsPending[addr] = struct{}{} // Mark the copy pending to force external (account) commits - } - } - // Above, we don't copy the actual journal. This means that if the copy is copied, the - // loop above will be a no-op, since the copy's journal is empty. 
- // Thus, here we iterate over stateObjects, to enable copies of copies - for addr := range s.stateObjectsPending { - if _, exist := state.stateObjects[addr]; !exist { - state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) - } - state.stateObjectsPending[addr] = struct{}{} - } - for addr := range s.stateObjectsDirty { - if _, exist := state.stateObjects[addr]; !exist { - state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) - } + state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) state.stateObjectsDirty[addr] = struct{}{} } for hash, logs := range s.logs { - cpy := make([]*types.Log, len(logs)) - for i, l := range logs { - cpy[i] = new(types.Log) - *cpy[i] = *l - } - state.logs[hash] = cpy + state.logs[hash] = make([]*types.Log, len(logs)) + copy(state.logs[hash], logs) } for hash, preimage := range s.preimages { state.preimages[hash] = preimage @@ -685,26 +645,17 @@ func (s *StateDB) GetRefund() uint64 { // and clears the journal as well as the refunds. func (s *StateDB) Finalise(deleteEmptyObjects bool) { for addr := range s.journal.dirties { - obj, exist := s.stateObjects[addr] + stateObject, exist := s.stateObjects[addr] if !exist { continue } - if obj.suicided || (deleteEmptyObjects && obj.empty()) { - obj.deleted = true - - // If state snapshotting is active, also mark the destruction there. - // Note, we can't do this only at the end of a block because multiple - // transactions within the same block might self-destruct and then - // resurrect an account; but the snapshotter needs both events. - if s.snap != nil { - s.snapDestructs[obj.addrHash] = struct{}{} // We need to maintain account deletions explicitly (will remain set indefinitely) - delete(s.snapAccounts, obj.addrHash) // Clear out any previously updated account data (may be recreated via a resurrect) - delete(s.snapStorage, obj.addrHash) // Clear out any previously updated storage data (may be recreated via a resurrect) - } + + if stateObject.suicided || (deleteEmptyObjects && stateObject.empty()) { + s.deleteStateObject(stateObject) } else { - obj.finalise() + stateObject.updateRoot(s.db) + s.updateStateObject(stateObject) } - s.stateObjectsPending[addr] = struct{}{} s.stateObjectsDirty[addr] = struct{}{} } // Invalidate journal because reverting across transactions is not allowed. @@ -715,25 +666,7 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) { // It is called in between transactions to get the root hash that // goes into transaction receipts. func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash { - // Finalise all the dirty storage states and write them into the tries s.Finalise(deleteEmptyObjects) - - for addr := range s.stateObjectsPending { - obj := s.stateObjects[addr] - if obj.deleted { - s.deleteStateObject(obj) - } else { - obj.updateRoot(s.db) - s.updateStateObject(obj) - } - } - if len(s.stateObjectsPending) > 0 { - s.stateObjectsPending = make(map[common.Address]struct{}) - } - // Track the amount of time wasted on hashing the account trie - if metrics.EnabledExpensive { - defer func(start time.Time) { s.AccountHashes += time.Since(start) }(time.Now()) - } return s.trie.Hash() } @@ -804,10 +737,6 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) delete(s.stateObjectsDirty, addr) } // Write trie changes. 
- var start time.Time - if metrics.EnabledExpensive { - start = time.Now() - } root, err = s.trie.Commit(func(leaf []byte, parent common.Hash) error { var account types.StateAccount if err := rlp.DecodeBytes(leaf, &account); err != nil { @@ -822,25 +751,6 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) } return nil }) - if metrics.EnabledExpensive { - s.AccountCommits += time.Since(start) - } - // If snapshotting is enabled, update the snapshot tree with this new version - if s.snap != nil { - if metrics.EnabledExpensive { - defer func(start time.Time) { s.SnapshotCommits += time.Since(start) }(time.Now()) - } - // Only update if there's a state transition (skip empty Clique blocks) - if parent := s.snap.Root(); parent != root { - if err := s.snaps.Update(root, parent, s.snapDestructs, s.snapAccounts, s.snapStorage); err != nil { - log.Warn("Failed to update snapshot tree", "from", parent, "to", root, "err", err) - } - if err := s.snaps.Cap(root, 127); err != nil { // Persistent layer is 128th, the last available trie - log.Warn("Failed to cap snapshot tree", "root", root, "layers", 127, "err", err) - } - } - s.snap, s.snapDestructs, s.snapAccounts, s.snapStorage = nil, nil, nil, nil - } return root, err } From 420d3432c3107ef4f9c1d16270aca75090017aea Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 14 Aug 2023 11:30:44 +0700 Subject: [PATCH 075/119] Fix commit and copy state --- contracts/validator/validator_test.go | 3 - core/state/state_object.go | 242 ++++++++++++++++++++------ core/state/statedb.go | 48 +++-- eth/tracers/tracers_test.go | 2 +- 4 files changed, 224 insertions(+), 71 deletions(-) diff --git a/contracts/validator/validator_test.go b/contracts/validator/validator_test.go index c7a452d751..9cdb8bec87 100644 --- a/contracts/validator/validator_test.go +++ b/contracts/validator/validator_test.go @@ -60,10 +60,7 @@ func TestValidator(t *testing.T) { d := time.Now().Add(1000 * time.Millisecond) ctx, cancel := context.WithDeadline(context.Background(), d) defer cancel() - code, _ := contractBackend.CodeAt(ctx, validatorAddress, nil) - t.Log("contract code", common.ToHex(code)) f := func(key, val common.Hash) bool { - t.Log(key.Hex(), val.Hex()) return true } contractBackend.ForEachStorageAt(ctx, validatorAddress, nil, f) diff --git a/core/state/state_object.go b/core/state/state_object.go index eecaa346fb..b0dc8da14d 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -21,10 +21,12 @@ import ( "fmt" "io" "math/big" + "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/rlp" ) @@ -32,23 +34,23 @@ var emptyCodeHash = crypto.Keccak256(nil) type Code []byte -func (self Code) String() string { - return string(self) //strings.Join(Disassemble(self), " ") +func (c Code) String() string { + return string(c) //strings.Join(Disassemble(c), " ") } type Storage map[common.Hash]common.Hash -func (self Storage) String() (str string) { - for key, value := range self { +func (s Storage) String() (str string) { + for key, value := range s { str += fmt.Sprintf("%X : %X\n", key, value) } return } -func (self Storage) Copy() Storage { +func (s Storage) Copy() Storage { cpy := make(Storage) - for key, value := range self { + for key, value := range s { cpy[key] = value } @@ -59,7 +61,7 @@ func (self Storage) Copy() Storage { // // The usage pattern is as follows: // First you need 
to obtain a state object. -// StateAccount values can be accessed and modified through the object. +// Account values can be accessed and modified through the object. // Finally, call CommitTrie to write the modified storage trie into a database. type stateObject struct { address common.Address @@ -78,11 +80,13 @@ type stateObject struct { trie Trie // storage trie, which becomes non-nil on first access code Code // contract bytecode, which gets set when code is loaded - cachedStorage Storage // Storage entry cache to avoid duplicate reads - dirtyStorage Storage // Storage entries that need to be flushed to disk + originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction + pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block + dirtyStorage Storage // Storage entries that have been modified in the current transaction execution + fakeStorage Storage // Fake storage which constructed by caller for debugging purpose. // Cache flags. - // When an object is marked suicided it will be delete from the trie + // When an object is marked suicided it will be deleted from the trie // during the "update" phase of the state transition. dirtyCode bool // true if the code was updated suicided bool @@ -102,13 +106,17 @@ func newObject(db *StateDB, address common.Address, data *types.StateAccount) *s if data.CodeHash == nil { data.CodeHash = emptyCodeHash } + if data.Root == (common.Hash{}) { + data.Root = emptyRoot + } return &stateObject{ - db: db, - address: address, - addrHash: crypto.Keccak256Hash(address[:]), - data: *data, - cachedStorage: make(Storage), - dirtyStorage: make(Storage), + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + data: *data, + originStorage: make(Storage), + pendingStorage: make(Storage), + dirtyStorage: make(Storage), } } @@ -151,81 +159,210 @@ func (s *stateObject) getTrie(db Database) Trie { return s.trie } -func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { - value := common.Hash{} - // Load from DB in case it is missing. - val, err := s.getTrie(db).GetStorage(s.address, key.Bytes()) - if err != nil { - s.setError(err) - return common.Hash{} +// GetState retrieves a value from the account storage trie. +func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { + // If the fake storage is set, only lookup the state here(in the debugging mode) + if s.fakeStorage != nil { + return s.fakeStorage[key] } - value.SetBytes(val) - return value + // If we have a dirty value for this state entry, return it + value, dirty := s.dirtyStorage[key] + if dirty { + return value + } + // Otherwise return the entry's original value + return s.GetCommittedState(db, key) } -func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { - value, exists := s.cachedStorage[key] - if exists { +// GetCommittedState retrieves a value from the committed account storage trie. +func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { + // If the fake storage is set, only lookup the state here(in the debugging mode) + if s.fakeStorage != nil { + return s.fakeStorage[key] + } + // If we have a pending write or clean cached, return that + if value, pending := s.pendingStorage[key]; pending { return value } - // Load from DB in case it is missing. 
- val, err := s.getTrie(db).GetStorage(s.address, key.Bytes()) - if err != nil { - s.setError(err) - return common.Hash{} + if value, cached := s.originStorage[key]; cached { + return value } - - value.SetBytes(val) - if (value != common.Hash{}) { - s.cachedStorage[key] = value + // If no live objects are available, attempt to use snapshots + var ( + enc []byte + err error + value common.Hash + ) + if s.db.snap != nil { + if metrics.EnabledExpensive { + defer func(start time.Time) { s.db.SnapshotStorageReads += time.Since(start) }(time.Now()) + } + // If the object was destructed in *this* block (and potentially resurrected), + // the storage has been cleared out, and we should *not* consult the previous + // snapshot about any storage values. The only possible alternatives are: + // 1) resurrect happened, and new slot values were set -- those should + // have been handles via pendingStorage above. + // 2) we don't have new values, and can deliver empty response back + if _, destructed := s.db.snapDestructs[s.addrHash]; destructed { + return common.Hash{} + } + enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key[:])) + if len(enc) > 0 { + _, content, _, err := rlp.Split(enc) + if err != nil { + s.setError(err) + } + value.SetBytes(content) + } + } + // If snapshot unavailable or reading from it failed, load from the database + if s.db.snap == nil || err != nil { + start := time.Now() + val, err := s.getTrie(db).GetStorage(s.address, key.Bytes()) + if metrics.EnabledExpensive { + s.db.StorageReads += time.Since(start) + } + if err != nil { + s.setError(err) + return common.Hash{} + } + value.SetBytes(val) } + s.originStorage[key] = value return value } // SetState updates a value in account storage. func (s *stateObject) SetState(db Database, key, value common.Hash) { + // If the fake storage is set, put the temporary state update here. + if s.fakeStorage != nil { + s.fakeStorage[key] = value + return + } + // If the new value is the same as old, don't set + prev := s.GetState(db, key) + if prev == value { + return + } + // New value is different, update and journal the change s.db.journal.append(storageChange{ account: &s.address, key: key, - prevalue: s.GetState(db, key), + prevalue: prev, }) s.setState(key, value) } +// SetStorage replaces the entire state storage with the given one. +// +// After this function is called, all original state will be ignored and state +// lookup only happens in the fake state storage. +// +// Note this function should only be used for debugging purpose. +func (s *stateObject) SetStorage(storage map[common.Hash]common.Hash) { + // Allocate fake storage if it's nil. + if s.fakeStorage == nil { + s.fakeStorage = make(Storage) + } + for key, value := range storage { + s.fakeStorage[key] = value + } + // Don't bother journal since this function should only be used for + // debugging and the `fake` storage won't be committed to database. +} + func (s *stateObject) setState(key, value common.Hash) { - s.cachedStorage[key] = value s.dirtyStorage[key] = value } +// finalise moves all dirty storage slots into the pending area to be hashed or +// committed later. It is invoked at the end of every transaction. +func (s *stateObject) finalise() { + for key, value := range s.dirtyStorage { + s.pendingStorage[key] = value + } + if len(s.dirtyStorage) > 0 { + s.dirtyStorage = make(Storage) + } +} + // updateTrie writes cached storage modifications into the object's storage trie. 
+// It will return nil if the trie has not been loaded and no changes have been made func (s *stateObject) updateTrie(db Database) Trie { + // Make sure all dirty slots are finalized into the pending storage area + s.finalise() + if len(s.pendingStorage) == 0 { + return s.trie + } + // Track the amount of time wasted on updating the storage trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now()) + } + // Retrieve the snapshot storage map for the object + var storage map[common.Hash][]byte + if s.db.snap != nil { + // Retrieve the old storage map, if available, create a new one otherwise + storage = s.db.snapStorage[s.addrHash] + if storage == nil { + storage = make(map[common.Hash][]byte) + s.db.snapStorage[s.addrHash] = storage + } + } + // Insert all the pending updates into the trie tr := s.getTrie(db) - for key, value := range s.dirtyStorage { - delete(s.dirtyStorage, key) - if (value == common.Hash{}) { - s.setError(tr.DeleteStorage(s.address, key[:])) + for key, value := range s.pendingStorage { + // Skip noop changes, persist actual changes + if value == s.originStorage[key] { continue } - // Encoding []byte cannot fail, ok to ignore the error. - v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) - s.setError(tr.UpdateStorage(s.address, key[:], v)) + s.originStorage[key] = value + + var v []byte + if (value == common.Hash{}) { + s.setError(tr.DeleteStorage(s.address, key.Bytes())) + } else { + // Encoding []byte cannot fail, ok to ignore the error. + v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) + s.setError(tr.UpdateStorage(s.address, key.Bytes(), v)) + } + // If state snapshotting is active, cache the data til commit + if storage != nil { + storage[crypto.Keccak256Hash(key[:])] = v // v will be nil if value is 0x00 + } + } + if len(s.pendingStorage) > 0 { + s.pendingStorage = make(Storage) } return tr } // UpdateRoot sets the trie root to the current root hash of func (s *stateObject) updateRoot(db Database) { - s.updateTrie(db) + // If nothing changed, don't bother with hashing anything + if s.updateTrie(db) == nil { + return + } + // Track the amount of time wasted on hashing the storage trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.db.StorageHashes += time.Since(start) }(time.Now()) + } s.data.Root = s.trie.Hash() } -// CommitTrie the storage trie of the object to dwb. +// CommitTrie the storage trie of the object to db. // This updates the trie root. func (s *stateObject) CommitTrie(db Database) error { - s.updateTrie(db) + // If nothing changed, don't bother with hashing anything + if s.updateTrie(db) == nil { + return nil + } if s.dbErr != nil { return s.dbErr } + // Track the amount of time wasted on committing the storage trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now()) + } root, err := s.trie.Commit(nil) if err == nil { s.data.Root = root @@ -269,7 +406,7 @@ func (s *stateObject) setBalance(amount *big.Int) { s.data.Balance = amount } -// ReturnGas return the gas back to the origin. Used by the Virtual machine or Closures +// ReturnGas returns the gas back to the origin. 
Used by the Virtual machine or Closures func (s *stateObject) ReturnGas(gas *big.Int) {} func (s *stateObject) deepCopy(db *StateDB) *stateObject { @@ -279,7 +416,8 @@ func (s *stateObject) deepCopy(db *StateDB) *stateObject { } stateObject.code = s.code stateObject.dirtyStorage = s.dirtyStorage.Copy() - stateObject.cachedStorage = s.dirtyStorage.Copy() + stateObject.originStorage = s.originStorage.Copy() + stateObject.pendingStorage = s.pendingStorage.Copy() stateObject.suicided = s.suicided stateObject.dirtyCode = s.dirtyCode stateObject.deleted = s.deleted @@ -352,7 +490,7 @@ func (s *stateObject) Nonce() uint64 { } // Value is never called, but must be present to allow stateObject to be used -// as a vm.StateAccount interface that also satisfies the vm.ContractRef +// as a vm.Account interface that also satisfies the vm.ContractRef // interface. Interfaces are awesome. func (s *stateObject) Value() *big.Int { panic("Value on stateObject should never be called") diff --git a/core/state/statedb.go b/core/state/statedb.go index 75d69cb061..f264ede6da 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -49,7 +49,7 @@ var ( emptyCode = crypto.Keccak256Hash(nil) ) -// StateDBs within the ethereum protocol are used to store anything +// StateDB within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: // * Contracts @@ -124,7 +124,7 @@ func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) commo return common.Hash{} } -// Create a new state from a given trie. +// New creates a new state from a given trie. func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { tr, err := db.OpenTrie(root) if err != nil { @@ -249,7 +249,7 @@ func (s *StateDB) Empty(addr common.Address) bool { return so == nil || so.empty() } -// Retrieve the balance from the given address or 0 if object not found +// GetBalance retrieves the balance from the given address or 0 if object not found func (s *StateDB) GetBalance(addr common.Address) *big.Int { stateObject := s.getStateObject(addr) if stateObject != nil { @@ -516,7 +516,7 @@ func (s *StateDB) setStateObject(object *stateObject) { s.stateObjects[object.Address()] = object } -// Retrieve a state object or create a new state object if nil. +// GetOrNewStateObject retrieves a state object or create a new state object if nil. func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject { stateObject := s.getStateObject(addr) if stateObject == nil || stateObject.deleted { @@ -562,18 +562,27 @@ func (s *StateDB) ForEachStorage(addr common.Address, cb func(key, value common. 
if so == nil { return nil } + tr := so.getTrie(s.db) + trieIt := tr.NodeIterator(nil) + it := trie.NewIterator(trieIt) - // When iterating over the storage check the cache first - for h, value := range so.cachedStorage { - cb(h, value) - } - - it := trie.NewIterator(so.getTrie(s.db).NodeIterator(nil)) for it.Next() { - // ignore cached values key := common.BytesToHash(s.trie.GetKey(it.Key)) - if _, ok := so.cachedStorage[key]; !ok { - cb(key, common.BytesToHash(it.Value)) + if value, dirty := so.dirtyStorage[key]; dirty { + if !cb(key, value) { + return nil + } + continue + } + + if len(it.Value) > 0 { + _, content, _, err := rlp.Split(it.Value) + if err != nil { + return err + } + if !cb(key, common.BytesToHash(content)) { + return nil + } } } return nil @@ -599,8 +608,17 @@ func (s *StateDB) Copy() *StateDB { } // Copy the dirty states, logs, and preimages for addr := range s.journal.dirties { - state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) - state.stateObjectsDirty[addr] = struct{}{} + // As documented [here](https://github.com/ethereum/go-ethereum/pull/16485#issuecomment-380438527), + // and in the Finalise-method, there is a case where an object is in the journal but not + // in the stateObjects: OOG after touch on ripeMD prior to Byzantium. Thus, we need to check for + // nil + if object, exist := s.stateObjects[addr]; exist { + // Even though the original object is dirty, we are not copying the journal, + // so we need to make sure that any side effect the journal would have caused + // during a commit (or similar op) is already applied to the copy. + state.stateObjects[addr] = object.deepCopy(state) + state.stateObjectsDirty[addr] = struct{}{} // Mark the copy dirty to force internal (code/state) commits + } } for hash, logs := range s.logs { state.logs[hash] = make([]*types.Log, len(logs)) diff --git a/eth/tracers/tracers_test.go b/eth/tracers/tracers_test.go index b0f96d4a68..05e91617ed 100644 --- a/eth/tracers/tracers_test.go +++ b/eth/tracers/tracers_test.go @@ -177,7 +177,7 @@ func TestPrestateTracerCreate2(t *testing.T) { if err != nil { t.Fatalf("failed to create call tracer: %v", err) } - evm := vm.NewEVM(context, statedb, nil, params.MainnetChainConfig, vm.Config{Debug: true, Tracer: tracer}) + evm := vm.NewEVM(context, statedb, nil, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) msg, err := core.TransactionToMessage(tx, signer, nil, nil) if err != nil { From bcae7d2aff0f915bfffbec6e67d9875ca51b3352 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Fri, 18 Aug 2023 11:15:55 +0700 Subject: [PATCH 076/119] Record blockWriteTimer metric --- core/blockchain.go | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/core/blockchain.go b/core/blockchain.go index 49ee8d6f9b..2a912a4ce4 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -351,9 +351,7 @@ func (bc *BlockChain) loadLastState() error { repair = true } // Make sure the state associated with the block is available - _, err := state.New(currentBlock.Root(), bc.stateCache, bc.snaps) - // err != nil{} - if err != nil { + if _, err := state.New(currentBlock.Root(), bc.stateCache, bc.snaps); err != nil { repair = true } else { engine, ok := bc.Engine().(*posv.Posv) @@ -951,12 +949,14 @@ func (bc *BlockChain) SaveData() { // - HEAD-1: So we don't do large reorgs if our HEAD becomes an uncle // - HEAD-127: So we have a hard limit on the number of blocks reexecuted if !bc.cacheConfig.Disabled { - var tradingTriedb *trie.Database - var 
lendingTriedb *trie.Database + var ( + tradingTriedb *trie.Database + lendingTriedb *trie.Database + tradingService posv.TradingService + lendingService posv.LendingService + ) engine, _ := bc.Engine().(*posv.Posv) triedb := bc.stateCache.TrieDB() - var tradingService posv.TradingService - var lendingService posv.LendingService if bc.Config().IsTIPTomoX(bc.CurrentBlock().Number()) && bc.chainConfig.Posv != nil && bc.CurrentBlock().NumberU64() > bc.chainConfig.Posv.Epoch && engine != nil { tradingService = engine.GetTomoXService() if tradingService != nil && tradingService.GetStateCache() != nil { @@ -1616,11 +1616,13 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty } parentAuthor, _ := bc.Engine().Author(parent.Header()) // clear the previous dry-run cache - var tradingState *tradingstate.TradingStateDB - var lendingState *lendingstate.LendingStateDB - var tradingService posv.TradingService - var lendingService posv.LendingService - isSDKNode := false + var ( + tradingState *tradingstate.TradingStateDB + lendingState *lendingstate.LendingStateDB + tradingService posv.TradingService + lendingService posv.LendingService + isSDKNode = false + ) if bc.Config().IsTIPTomoX(block.Number()) && bc.chainConfig.Posv != nil && engine != nil && block.NumberU64() > bc.chainConfig.Posv.Epoch { tradingService = engine.GetTomoXService() lendingService = engine.GetLendingService() @@ -1749,6 +1751,8 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty storageCommitTimer.Update(statedb.StorageCommits) // Storage commits are complete, we can mark them snapshotCommitTimer.Update(statedb.SnapshotCommits) // Snapshot commits are complete, we can mark them + blockWriteTimer.Update(time.Since(substart) - statedb.AccountCommits - statedb.StorageCommits - statedb.SnapshotCommits) + if bc.chainConfig.Posv != nil { c := bc.engine.(*posv.Posv) coinbase := c.Signer() From 25b8733a120ea73a1c96378893c48d4844ce57fd Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Fri, 8 Sep 2023 10:50:53 +0700 Subject: [PATCH 077/119] Format code --- accounts/abi/abi.go | 46 ++++++++-------- accounts/abi/abi_test.go | 81 ++++++++++++++-------------- internal/ethapi/api.go | 110 +++++++++++++++++++-------------------- rpc/json.go | 6 +-- 4 files changed, 122 insertions(+), 121 deletions(-) diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 6acf0e2b66..ddf68bff4c 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -19,11 +19,11 @@ package abi import ( "bytes" "encoding/json" + "errors" "fmt" "io" - "errors" - "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/crypto" ) // The ABI holds information about a contract's context and available @@ -149,24 +149,24 @@ func (abi *ABI) MethodById(sigdata []byte) (*Method, error) { } // revertSelector is a special function selector for revert reason unpacking. - var revertSelector = crypto.Keccak256([]byte("Error(string)"))[:4] - - // UnpackRevert resolves the abi-encoded revert reason. According to the solidity - // spec https://solidity.readthedocs.io/en/latest/control-structures.html#revert, - // the provided revert reason is abi-encoded as if it were a call to a function - // `Error(string)`. So it's a special tool for it. 
- func UnpackRevert(data []byte) (string, error) { - if len(data) < 4 { - return "", errors.New("invalid data for unpacking") - } - if !bytes.Equal(data[:4], revertSelector) { - return "", errors.New("invalid data for unpacking") - } - var reason string - // typ, _ := NewType("string", "", nil) - typ, _ := NewType("string") - if err := (Arguments{{Type: typ}}).Unpack(&reason, data[4:]); err != nil { - return "", err - } - return reason, nil - } +var revertSelector = crypto.Keccak256([]byte("Error(string)"))[:4] + +// UnpackRevert resolves the abi-encoded revert reason. According to the solidity +// spec https://solidity.readthedocs.io/en/latest/control-structures.html#revert, +// the provided revert reason is abi-encoded as if it were a call to a function +// `Error(string)`. So it's a special tool for it. +func UnpackRevert(data []byte) (string, error) { + if len(data) < 4 { + return "", errors.New("invalid data for unpacking") + } + if !bytes.Equal(data[:4], revertSelector) { + return "", errors.New("invalid data for unpacking") + } + var reason string + // typ, _ := NewType("string", "", nil) + typ, _ := NewType("string") + if err := (Arguments{{Type: typ}}).Unpack(&reason, data[4:]); err != nil { + return "", err + } + return reason, nil +} diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go index b7aad7eb4c..9092b6cd81 100644 --- a/accounts/abi/abi_test.go +++ b/accounts/abi/abi_test.go @@ -19,9 +19,9 @@ package abi import ( "bytes" "encoding/hex" + "errors" "fmt" "log" - "errors" "math/big" "strings" "testing" @@ -620,16 +620,19 @@ func TestBareEvents(t *testing.T) { } // TestUnpackEvent is based on this contract: -// contract T { -// event received(address sender, uint amount, bytes memo); -// event receivedAddr(address sender); -// function receive(bytes memo) external payable { -// received(msg.sender, msg.value, memo); -// receivedAddr(msg.sender); -// } -// } +// +// contract T { +// event received(address sender, uint amount, bytes memo); +// event receivedAddr(address sender); +// function receive(bytes memo) external payable { +// received(msg.sender, msg.value, memo); +// receivedAddr(msg.sender); +// } +// } +// // When receive("X") is called with sender 0x00... 
and value 1, it produces this tx receipt: -// receipt{status=1 cgas=23949 bloom=00000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000040200000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 logs=[log: b6818c8064f645cd82d99b59a1a267d6d61117ef [75fd880d39c1daf53b6547ab6cb59451fc6452d27caa90e5b6649dd8293b9eed] 000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158 9ae378b6d4409eada347a5dc0c180f186cb62dc68fcc0f043425eb917335aa28 0 95d429d309bb9d753954195fe2d69bd140b4ae731b9b5b605c34323de162cf00 0]} +// +// receipt{status=1 cgas=23949 bloom=00000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000040200000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 logs=[log: b6818c8064f645cd82d99b59a1a267d6d61117ef [75fd880d39c1daf53b6547ab6cb59451fc6452d27caa90e5b6649dd8293b9eed] 000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158 9ae378b6d4409eada347a5dc0c180f186cb62dc68fcc0f043425eb917335aa28 0 95d429d309bb9d753954195fe2d69bd140b4ae731b9b5b605c34323de162cf00 0]} func TestUnpackEvent(t *testing.T) { const abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"receive","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"receivedAddr","type":"event"}]` abi, err := JSON(strings.NewReader(abiJSON)) @@ -716,32 +719,32 @@ func TestABI_MethodById(t *testing.T) { } func TestUnpackRevert(t *testing.T) { - t.Parallel() - - var cases = []struct { - input string - expect string - expectErr error - }{ - {"", "", errors.New("invalid data for unpacking")}, - {"08c379a1", "", errors.New("invalid data for unpacking")}, - {"08c379a00000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000d72657665727420726561736f6e00000000000000000000000000000000000000", "revert reason", nil}, - } - for index, c := range cases { - t.Run(fmt.Sprintf("case %d", index), func(t *testing.T) { - got, err := UnpackRevert(common.Hex2Bytes(c.input)) - if c.expectErr != nil { - if err == nil { - t.Fatalf("Expected non-nil error") - } - if 
err.Error() != c.expectErr.Error() { - t.Fatalf("Expected error mismatch, want %v, got %v", c.expectErr, err) - } - return - } - if c.expect != got { - t.Fatalf("Output mismatch, want %v, got %v", c.expect, got) - } - }) - } - } + t.Parallel() + + var cases = []struct { + input string + expect string + expectErr error + }{ + {"", "", errors.New("invalid data for unpacking")}, + {"08c379a1", "", errors.New("invalid data for unpacking")}, + {"08c379a00000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000d72657665727420726561736f6e00000000000000000000000000000000000000", "revert reason", nil}, + } + for index, c := range cases { + t.Run(fmt.Sprintf("case %d", index), func(t *testing.T) { + got, err := UnpackRevert(common.Hex2Bytes(c.input)) + if c.expectErr != nil { + if err == nil { + t.Fatalf("Expected non-nil error") + } + if err.Error() != c.expectErr.Error() { + t.Fatalf("Expected error mismatch, want %v, got %v", c.expectErr, err) + } + return + } + if c.expect != got { + t.Fatalf("Output mismatch, want %v, got %v", c.expect, got) + } + }) + } +} diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index ae627a3732..154aea7eb6 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -1114,46 +1114,46 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr } func newRevertError(result *core.ExecutionResult) *revertError { - reason, errUnpack := abi.UnpackRevert(result.Revert()) - err := errors.New("execution reverted") - if errUnpack == nil { - err = fmt.Errorf("execution reverted: %v", reason) - } - return &revertError{ - error: err, - reason: hexutil.Encode(result.Revert()), - } - } - - // revertError is an API error that encompassas an EVM revertal with JSON error - // code and a binary data blob. - type revertError struct { - error - reason string // revert reason hex encoded - } - - // ErrorCode returns the JSON error code for a revertal. - // See: https://github.com/ethereum/wiki/wiki/JSON-RPC-Error-Codes-Improvement-Proposal - func (e *revertError) ErrorCode() int { - return 3 - } - - // ErrorData returns the hex encoded revert reason. - func (e *revertError) ErrorData() interface{} { - return e.reason - } + reason, errUnpack := abi.UnpackRevert(result.Revert()) + err := errors.New("execution reverted") + if errUnpack == nil { + err = fmt.Errorf("execution reverted: %v", reason) + } + return &revertError{ + error: err, + reason: hexutil.Encode(result.Revert()), + } +} + +// revertError is an API error that encompassas an EVM revertal with JSON error +// code and a binary data blob. +type revertError struct { + error + reason string // revert reason hex encoded +} + +// ErrorCode returns the JSON error code for a revertal. +// See: https://github.com/ethereum/wiki/wiki/JSON-RPC-Error-Codes-Improvement-Proposal +func (e *revertError) ErrorCode() int { + return 3 +} + +// ErrorData returns the hex encoded revert reason. +func (e *revertError) ErrorData() interface{} { + return e.reason +} // Call executes the given transaction on the state for the given block number. // It doesn't make and changes in the state/blockchain and is useful to execute and retrieve values. 
 func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber) (hexutil.Bytes, error) {
 	result, err := s.doCall(ctx, args, blockNr, vm.Config{}, 5*time.Second)
-	if err != nil {
-		return nil, err
-	}
+	if err != nil {
+		return nil, err
+	}
 
-	if len(result.Revert()) > 0 {
-		return nil, newRevertError(result)
-	}
+	if len(result.Revert()) > 0 {
+		return nil, newRevertError(result)
+	}
 	return result.Return(), result.Err
 }
 
@@ -1184,9 +1184,9 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h
 
 		result, err := s.doCall(ctx, args, rpc.LatestBlockNumber, vm.Config{}, 0)
 		if err != nil {
-			if err == core.ErrIntrinsicGas {
-				return true, nil, nil // Special case, raise gas limit
-			}
+			if err == core.ErrIntrinsicGas {
+				return true, nil, nil // Special case, raise gas limit
+			}
 			return true, nil, err
 		}
 		return result.Failed(), result, nil
@@ -1194,10 +1194,10 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h
 	// Execute the binary search and hone in on an executable gas limit
 	for lo+1 < hi {
 		mid := (hi + lo) / 2
-		failed, _, err := executable(mid)
-		if err != nil {
-			return 0, err
-		}
+		failed, _, err := executable(mid)
+		if err != nil {
+			return 0, err
+		}
 		if failed {
 			lo = mid
 		} else {
@@ -1206,20 +1206,20 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h
 	}
 	// Reject the transaction as invalid if it still fails at the highest allowance
 	if hi == cap {
-		failed, result, err := executable(hi)
-		if err != nil {
-			return 0, nil
-		}
-
-		if failed {
-			if result != nil && result.Err != vm.ErrOutOfGas {
-				if len(result.Revert()) > 0 {
-					return 0, newRevertError(result)
-				}
-				return 0, result.Err
-			}
-			return 0, fmt.Errorf("gas required exceeds allowance (%d)", cap)
-		}
+		failed, result, err := executable(hi)
+		if err != nil {
+			return 0, err
+		}
+
+		if failed {
+			if result != nil && result.Err != vm.ErrOutOfGas {
+				if len(result.Revert()) > 0 {
+					return 0, newRevertError(result)
+				}
+				return 0, result.Err
+			}
+			return 0, fmt.Errorf("gas required exceeds allowance (%d)", cap)
+		}
 	}
 	return hexutil.Uint64(hi), nil
 }
diff --git a/rpc/json.go b/rpc/json.go
index 9d57a9cf70..e35a74118a 100644
--- a/rpc/json.go
+++ b/rpc/json.go
@@ -56,8 +56,6 @@ type jsonError struct {
 	Data    interface{} `json:"data,omitempty"`
 }
 
-
-
 type jsonErrResponse struct {
 	Version string      `json:"jsonrpc"`
 	Id      interface{} `json:"id,omitempty"`
@@ -99,8 +97,8 @@ func (err *jsonError) ErrorCode() int {
 }
 
 func (err *jsonError) ErrorData() interface{} {
-	return err.Data
- }
+	return err.Data
+}
 
 // NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based
 // on explicitly given encoding and decoding methods.
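For reference, the revert payload that UnpackRevert consumes is ABI-encoded as a call to Error(string): the 4-byte selector (Keccak256("Error(string)")[:4], i.e. 0x08c379a0) followed by the offset, length, and right-padded bytes of the reason string. The sketch below replays the happy-path vector from TestUnpackRevert above using only APIs introduced in this patch; it is an illustration, not part of the patched code:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/accounts/abi"
	"github.com/tomochain/tomochain/common"
)

func main() {
	// Keccak256("Error(string)")[:4] == 0x08c379a0, followed by the ABI
	// encoding of one string argument: offset, length, right-padded bytes.
	data := common.Hex2Bytes("08c379a0" +
		"0000000000000000000000000000000000000000000000000000000000000020" + // offset: 32
		"000000000000000000000000000000000000000000000000000000000000000d" + // length: 13
		"72657665727420726561736f6e00000000000000000000000000000000000000") // "revert reason"

	reason, err := abi.UnpackRevert(data)
	if err != nil {
		panic(err)
	}
	fmt.Println(reason) // Output: revert reason
}

Payloads shorter than four bytes, or carrying a different selector, fail with the "invalid data for unpacking" error exercised by the first two cases of TestUnpackRevert.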
From 35e30f600f75160c6f955db780289d77996a273b Mon Sep 17 00:00:00 2001 From: trinhdn97 Date: Thu, 21 Sep 2023 01:43:32 +0700 Subject: [PATCH 078/119] Run gofmt and fix unit tests --- accounts/abi/abi.go | 46 +++++++-------- accounts/abi/abi_test.go | 81 +++++++++++++------------- eth/tracers/tracers_test.go | 7 ++- internal/ethapi/api.go | 110 ++++++++++++++++++------------------ les/odr_test.go | 8 +-- light/odr_test.go | 4 +- rpc/json.go | 6 +- 7 files changed, 132 insertions(+), 130 deletions(-) diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 6acf0e2b66..ddf68bff4c 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -19,11 +19,11 @@ package abi import ( "bytes" "encoding/json" + "errors" "fmt" "io" - "errors" - "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/crypto" ) // The ABI holds information about a contract's context and available @@ -149,24 +149,24 @@ func (abi *ABI) MethodById(sigdata []byte) (*Method, error) { } // revertSelector is a special function selector for revert reason unpacking. - var revertSelector = crypto.Keccak256([]byte("Error(string)"))[:4] - - // UnpackRevert resolves the abi-encoded revert reason. According to the solidity - // spec https://solidity.readthedocs.io/en/latest/control-structures.html#revert, - // the provided revert reason is abi-encoded as if it were a call to a function - // `Error(string)`. So it's a special tool for it. - func UnpackRevert(data []byte) (string, error) { - if len(data) < 4 { - return "", errors.New("invalid data for unpacking") - } - if !bytes.Equal(data[:4], revertSelector) { - return "", errors.New("invalid data for unpacking") - } - var reason string - // typ, _ := NewType("string", "", nil) - typ, _ := NewType("string") - if err := (Arguments{{Type: typ}}).Unpack(&reason, data[4:]); err != nil { - return "", err - } - return reason, nil - } +var revertSelector = crypto.Keccak256([]byte("Error(string)"))[:4] + +// UnpackRevert resolves the abi-encoded revert reason. According to the solidity +// spec https://solidity.readthedocs.io/en/latest/control-structures.html#revert, +// the provided revert reason is abi-encoded as if it were a call to a function +// `Error(string)`. So it's a special tool for it. 
+func UnpackRevert(data []byte) (string, error) { + if len(data) < 4 { + return "", errors.New("invalid data for unpacking") + } + if !bytes.Equal(data[:4], revertSelector) { + return "", errors.New("invalid data for unpacking") + } + var reason string + // typ, _ := NewType("string", "", nil) + typ, _ := NewType("string") + if err := (Arguments{{Type: typ}}).Unpack(&reason, data[4:]); err != nil { + return "", err + } + return reason, nil +} diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go index b7aad7eb4c..9092b6cd81 100644 --- a/accounts/abi/abi_test.go +++ b/accounts/abi/abi_test.go @@ -19,9 +19,9 @@ package abi import ( "bytes" "encoding/hex" + "errors" "fmt" "log" - "errors" "math/big" "strings" "testing" @@ -620,16 +620,19 @@ func TestBareEvents(t *testing.T) { } // TestUnpackEvent is based on this contract: -// contract T { -// event received(address sender, uint amount, bytes memo); -// event receivedAddr(address sender); -// function receive(bytes memo) external payable { -// received(msg.sender, msg.value, memo); -// receivedAddr(msg.sender); -// } -// } +// +// contract T { +// event received(address sender, uint amount, bytes memo); +// event receivedAddr(address sender); +// function receive(bytes memo) external payable { +// received(msg.sender, msg.value, memo); +// receivedAddr(msg.sender); +// } +// } +// // When receive("X") is called with sender 0x00... and value 1, it produces this tx receipt: -// receipt{status=1 cgas=23949 bloom=00000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000040200000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 logs=[log: b6818c8064f645cd82d99b59a1a267d6d61117ef [75fd880d39c1daf53b6547ab6cb59451fc6452d27caa90e5b6649dd8293b9eed] 000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158 9ae378b6d4409eada347a5dc0c180f186cb62dc68fcc0f043425eb917335aa28 0 95d429d309bb9d753954195fe2d69bd140b4ae731b9b5b605c34323de162cf00 0]} +// +// receipt{status=1 cgas=23949 bloom=00000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000040200000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 logs=[log: b6818c8064f645cd82d99b59a1a267d6d61117ef [75fd880d39c1daf53b6547ab6cb59451fc6452d27caa90e5b6649dd8293b9eed] 000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158 9ae378b6d4409eada347a5dc0c180f186cb62dc68fcc0f043425eb917335aa28 0 
95d429d309bb9d753954195fe2d69bd140b4ae731b9b5b605c34323de162cf00 0]} func TestUnpackEvent(t *testing.T) { const abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"receive","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"receivedAddr","type":"event"}]` abi, err := JSON(strings.NewReader(abiJSON)) @@ -716,32 +719,32 @@ func TestABI_MethodById(t *testing.T) { } func TestUnpackRevert(t *testing.T) { - t.Parallel() - - var cases = []struct { - input string - expect string - expectErr error - }{ - {"", "", errors.New("invalid data for unpacking")}, - {"08c379a1", "", errors.New("invalid data for unpacking")}, - {"08c379a00000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000d72657665727420726561736f6e00000000000000000000000000000000000000", "revert reason", nil}, - } - for index, c := range cases { - t.Run(fmt.Sprintf("case %d", index), func(t *testing.T) { - got, err := UnpackRevert(common.Hex2Bytes(c.input)) - if c.expectErr != nil { - if err == nil { - t.Fatalf("Expected non-nil error") - } - if err.Error() != c.expectErr.Error() { - t.Fatalf("Expected error mismatch, want %v, got %v", c.expectErr, err) - } - return - } - if c.expect != got { - t.Fatalf("Output mismatch, want %v, got %v", c.expect, got) - } - }) - } - } + t.Parallel() + + var cases = []struct { + input string + expect string + expectErr error + }{ + {"", "", errors.New("invalid data for unpacking")}, + {"08c379a1", "", errors.New("invalid data for unpacking")}, + {"08c379a00000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000d72657665727420726561736f6e00000000000000000000000000000000000000", "revert reason", nil}, + } + for index, c := range cases { + t.Run(fmt.Sprintf("case %d", index), func(t *testing.T) { + got, err := UnpackRevert(common.Hex2Bytes(c.input)) + if c.expectErr != nil { + if err == nil { + t.Fatalf("Expected non-nil error") + } + if err.Error() != c.expectErr.Error() { + t.Fatalf("Expected error mismatch, want %v, got %v", c.expectErr, err) + } + return + } + if c.expect != got { + t.Fatalf("Output mismatch, want %v, got %v", c.expect, got) + } + }) + } +} diff --git a/eth/tracers/tracers_test.go b/eth/tracers/tracers_test.go index 38d4075175..623b8279e5 100644 --- a/eth/tracers/tracers_test.go +++ b/eth/tracers/tracers_test.go @@ -20,17 +20,18 @@ import ( "crypto/ecdsa" "crypto/rand" "encoding/json" - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "path/filepath" "reflect" "strings" "testing" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" @@ -183,7 +184,7 @@ func TestPrestateTracerCreate2(t *testing.T) { t.Fatalf("failed to prepare transaction for tracing: %v", err) } st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas())) - if _, _, _, err = st.TransitionDb(common.Address{}); err 
!= nil {
 		t.Fatalf("failed to execute transaction: %v", err)
 	}
 	// Retrieve the trace result and compare against the etalon
@@ -258,7 +259,7 @@ func TestCallTracer(t *testing.T) {
 			t.Fatalf("failed to prepare transaction for tracing: %v", err)
 		}
 		st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas()))
-		if _, _, _, err = st.TransitionDb(common.Address{}); err != nil {
+		if _, err = st.TransitionDb(common.Address{}); err != nil {
 			t.Fatalf("failed to execute transaction: %v", err)
 		}
 		// Retrieve the trace result and compare against the etalon
diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go
index 37bfd1fb82..9ed9df2dbf 100644
--- a/internal/ethapi/api.go
+++ b/internal/ethapi/api.go
@@ -1103,46 +1103,46 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr
 }

 func newRevertError(result *core.ExecutionResult) *revertError {
-	reason, errUnpack := abi.UnpackRevert(result.Revert())
-	err := errors.New("execution reverted")
-	if errUnpack == nil {
-		err = fmt.Errorf("execution reverted: %v", reason)
-	}
-	return &revertError{
-		error:  err,
-		reason: hexutil.Encode(result.Revert()),
-	}
-	}
-
-	// revertError is an API error that encompassas an EVM revertal with JSON error
-	// code and a binary data blob.
-	type revertError struct {
-	error
-	reason string // revert reason hex encoded
-	}
-
-	// ErrorCode returns the JSON error code for a revertal.
-	// See: https://github.com/ethereum/wiki/wiki/JSON-RPC-Error-Codes-Improvement-Proposal
-	func (e *revertError) ErrorCode() int {
-	return 3
-	}
-
-	// ErrorData returns the hex encoded revert reason.
-	func (e *revertError) ErrorData() interface{} {
-	return e.reason
-	}
+	reason, errUnpack := abi.UnpackRevert(result.Revert())
+	err := errors.New("execution reverted")
+	if errUnpack == nil {
+		err = fmt.Errorf("execution reverted: %v", reason)
+	}
+	return &revertError{
+		error:  err,
+		reason: hexutil.Encode(result.Revert()),
+	}
+}
+
+// revertError is an API error that encompasses an EVM revertal with JSON error
+// code and a binary data blob.
+type revertError struct {
+	error
+	reason string // revert reason hex encoded
+}
+
+// ErrorCode returns the JSON error code for a revertal.
+// See: https://github.com/ethereum/wiki/wiki/JSON-RPC-Error-Codes-Improvement-Proposal
+func (e *revertError) ErrorCode() int {
+	return 3
+}
+
+// ErrorData returns the hex encoded revert reason.
+func (e *revertError) ErrorData() interface{} {
+	return e.reason
+}

 // Call executes the given transaction on the state for the given block number.
 // It doesn't make any changes in the state/blockchain and is useful to execute and retrieve values.
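 // When a call reverts, the error surfaced to the JSON-RPC client is the
 // revertError built above; a representative error body (illustrative, using
 // the Error(string) fixture from TestUnpackRevert as revert data) would be:
 //
 //	{"code": 3, "message": "execution reverted: revert reason", "data": "0x08c379a0…"}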
func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber) (hexutil.Bytes, error) {
 	result, err := s.doCall(ctx, args, blockNr, vm.Config{}, 5*time.Second)
-	if err != nil {
-	return nil, err
-	}
+	if err != nil {
+		return nil, err
+	}

-	if len(result.Revert()) > 0 {
-	return nil, newRevertError(result)
-	}
+	if len(result.Revert()) > 0 {
+		return nil, newRevertError(result)
+	}
 	return result.Return(), result.Err
 }
@@ -1173,9 +1173,9 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h
 		result, err := s.doCall(ctx, args, rpc.LatestBlockNumber, vm.Config{}, 0)
 		if err != nil {
-			if err == core.ErrIntrinsicGas {
-			return true, nil, nil // Special case, raise gas limit
-			}
+			if err == core.ErrIntrinsicGas {
+				return true, nil, nil // Special case, raise gas limit
+			}
 			return true, nil, err
 		}
 		return result.Failed(), result, nil
@@ -1183,10 +1183,10 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h
 	// Execute the binary search and hone in on an executable gas limit
 	for lo+1 < hi {
 		mid := (hi + lo) / 2
-		failed, _, err := executable(mid)
-		if err != nil {
-		return 0, err
-		}
+		failed, _, err := executable(mid)
+		if err != nil {
+			return 0, err
+		}
 		if failed {
 			lo = mid
 		} else {
@@ -1195,20 +1195,20 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h
 	}
 	// Reject the transaction as invalid if it still fails at the highest allowance
 	if hi == cap {
-		failed, result, err := executable(hi)
-		if err != nil {
-		return 0, nil
-		}
-
-		if failed {
-		if result != nil && result.Err != vm.ErrOutOfGas {
-		if len(result.Revert()) > 0 {
-		return 0, newRevertError(result)
-		}
-		return 0, result.Err
-		}
-		return 0, fmt.Errorf("gas required exceeds allowance (%d)", cap)
-		}
+		failed, result, err := executable(hi)
+		if err != nil {
+			return 0, err
+		}
+
+		if failed {
+			if result != nil && result.Err != vm.ErrOutOfGas {
+				if len(result.Revert()) > 0 {
+					return 0, newRevertError(result)
+				}
+				return 0, result.Err
+			}
+			return 0, fmt.Errorf("gas required exceeds allowance (%d)", cap)
+		}
 	}
 	return hexutil.Uint64(hi), nil
 }
diff --git a/les/odr_test.go b/les/odr_test.go
index 3858e34028..b7e12d454b 100644
--- a/les/odr_test.go
+++ b/les/odr_test.go
@@ -141,8 +141,8 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
 			//vmenv := core.NewEnv(statedb, config, bc, msg, header, vm.Config{})
 			gp := new(core.GasPool).AddGas(math.MaxUint64)
 			owner := common.Address{}
-			ret, _, _, _ := core.ApplyMessage(vmenv, msg, gp, owner)
-			res = append(res, ret...)
+			ret, _ := core.ApplyMessage(vmenv, msg, gp, owner)
+			res = append(res, ret.Return()...)
 		}
 	} else {
 		header := lc.GetHeaderByHash(bhash)
@@ -158,9 +158,9 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
 		vmenv := vm.NewEVM(context, statedb, nil, config, vm.Config{})
 		gp := new(core.GasPool).AddGas(math.MaxUint64)
 		owner := common.Address{}
-		ret, _, _, _ := core.ApplyMessage(vmenv, msg, gp, owner)
+		ret, _ := core.ApplyMessage(vmenv, msg, gp, owner)
 		if statedb.Error() == nil {
-			res = append(res, ret...)
+			res = append(res, ret.Return()...)
} } } diff --git a/light/odr_test.go b/light/odr_test.go index 0c5fc78573..1d9df31d4c 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -188,8 +188,8 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain vmenv := vm.NewEVM(context, st, nil, config, vm.Config{}) gp := new(core.GasPool).AddGas(math.MaxUint64) owner := common.Address{} - ret, _, _, _ := core.ApplyMessage(vmenv, msg, gp, owner) - res = append(res, ret...) + ret, _ := core.ApplyMessage(vmenv, msg, gp, owner) + res = append(res, ret.Return()...) if st.Error() != nil { return res, st.Error() } diff --git a/rpc/json.go b/rpc/json.go index 9d57a9cf70..e35a74118a 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -56,8 +56,6 @@ type jsonError struct { Data interface{} `json:"data,omitempty"` } - - type jsonErrResponse struct { Version string `json:"jsonrpc"` Id interface{} `json:"id,omitempty"` @@ -99,8 +97,8 @@ func (err *jsonError) ErrorCode() int { } func (err *jsonError) ErrorData() interface{} { - return err.Data - } + return err.Data +} // NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based // on explicitly given encoding and decoding methods. From 653e72ee0f7b7d6b76ad505cf8d54a09dfa75dcd Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Fri, 29 Sep 2023 11:52:06 +0700 Subject: [PATCH 079/119] Add packing for dynamic array and slice types --- accounts/abi/argument.go | 21 +++++-------- accounts/abi/pack_test.go | 62 ++++++++++++++++++++++++++++++++++++++- accounts/abi/type.go | 58 ++++++++++++++++++++++++++++++------ 3 files changed, 118 insertions(+), 23 deletions(-) diff --git a/accounts/abi/argument.go b/accounts/abi/argument.go index 512d8fdfa7..cf140698d2 100644 --- a/accounts/abi/argument.go +++ b/accounts/abi/argument.go @@ -232,11 +232,7 @@ func (arguments Arguments) Pack(args ...interface{}) ([]byte, error) { // input offset is the bytes offset for packed output inputOffset := 0 for _, abiArg := range abiArgs { - if abiArg.Type.T == ArrayTy { - inputOffset += 32 * abiArg.Type.Size - } else { - inputOffset += 32 - } + inputOffset += getDynamicTypeOffset(abiArg.Type) } var ret []byte for i, a := range args { @@ -246,14 +242,13 @@ func (arguments Arguments) Pack(args ...interface{}) ([]byte, error) { if err != nil { return nil, err } - // check for a slice type (string, bytes, slice) - if input.Type.requiresLengthPrefix() { - // calculate the offset - offset := inputOffset + len(variableInput) + // check for dynamic types + if isDynamicType(input.Type) { // set the offset - ret = append(ret, packNum(reflect.ValueOf(offset))...) - // Append the packed output to the variable input. The variable input - // will be appended at the end of the input. + ret = append(ret, packNum(reflect.ValueOf(inputOffset))...) + // calculate next offset + inputOffset += len(packed) + // append to variable input variableInput = append(variableInput, packed...) 
} else { // append the packed value to the input @@ -278,7 +273,7 @@ func capitalise(input string) string { return strings.ToUpper(input[:1]) + input[1:] } -//unpackStruct extracts each argument into its corresponding struct field +// unpackStruct extracts each argument into its corresponding struct field func unpackStruct(value, reflectValue reflect.Value, arg Argument) error { name := capitalise(arg.Name) typ := value.Type() diff --git a/accounts/abi/pack_test.go b/accounts/abi/pack_test.go index be48cb5b15..7956157fee 100644 --- a/accounts/abi/pack_test.go +++ b/accounts/abi/pack_test.go @@ -324,6 +324,66 @@ func TestPack(t *testing.T) { "foobar", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000006666f6f6261720000000000000000000000000000000000000000000000000000"), }, + { + "string[]", + []string{"hello", "foobar"}, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002" + // len(array) = 2 + "0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0 + "0000000000000000000000000000000000000000000000000000000000000080" + // offset 128 to i = 1 + "0000000000000000000000000000000000000000000000000000000000000005" + // len(str[0]) = 5 + "68656c6c6f000000000000000000000000000000000000000000000000000000" + // str[0] + "0000000000000000000000000000000000000000000000000000000000000006" + // len(str[1]) = 6 + "666f6f6261720000000000000000000000000000000000000000000000000000"), // str[1] + }, + { + "string[2]", + []string{"hello", "foobar"}, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040" + // offset to i = 0 + "0000000000000000000000000000000000000000000000000000000000000080" + // offset to i = 1 + "0000000000000000000000000000000000000000000000000000000000000005" + // len(str[0]) = 5 + "68656c6c6f000000000000000000000000000000000000000000000000000000" + // str[0] + "0000000000000000000000000000000000000000000000000000000000000006" + // len(str[1]) = 6 + "666f6f6261720000000000000000000000000000000000000000000000000000"), // str[1] + }, + { + "bytes32[][]", + [][]common.Hash{{{1}, {2}}, {{3}, {4}, {5}}}, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002" + // len(array) = 2 + "0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0 + "00000000000000000000000000000000000000000000000000000000000000a0" + // offset 160 to i = 1 + "0000000000000000000000000000000000000000000000000000000000000002" + // len(array[0]) = 2 + "0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0] + "0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1] + "0000000000000000000000000000000000000000000000000000000000000003" + // len(array[1]) = 3 + "0300000000000000000000000000000000000000000000000000000000000000" + // array[1][0] + "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1] + "0500000000000000000000000000000000000000000000000000000000000000"), // array[1][2] + }, + + { + "bytes32[][2]", + [][]common.Hash{{{1}, {2}}, {{3}, {4}, {5}}}, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0 + "00000000000000000000000000000000000000000000000000000000000000a0" + // offset 160 to i = 1 + "0000000000000000000000000000000000000000000000000000000000000002" + // len(array[0]) = 2 + "0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0] + 
"0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1] + "0000000000000000000000000000000000000000000000000000000000000003" + // len(array[1]) = 3 + "0300000000000000000000000000000000000000000000000000000000000000" + // array[1][0] + "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1] + "0500000000000000000000000000000000000000000000000000000000000000"), // array[1][2] + }, + + { + "bytes32[3][2]", + [][]common.Hash{{{1}, {2}, {3}}, {{3}, {4}, {5}}}, + common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0] + "0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1] + "0300000000000000000000000000000000000000000000000000000000000000" + // array[0][2] + "0300000000000000000000000000000000000000000000000000000000000000" + // array[1][0] + "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1] + "0500000000000000000000000000000000000000000000000000000000000000"), // array[1][2] + }, } { typ, err := NewType(test.typ) if err != nil { @@ -336,7 +396,7 @@ func TestPack(t *testing.T) { } if !bytes.Equal(output, test.output) { - t.Errorf("%d failed. Expected bytes: '%x' Got: '%x'", i, test.output, output) + t.Errorf("input %d for typ: %v failed. Expected bytes: '%x' Got: '%x'", i, typ.String(), test.output, output) } } } diff --git a/accounts/abi/type.go b/accounts/abi/type.go index a1f13ffa29..e216cbe071 100644 --- a/accounts/abi/type.go +++ b/accounts/abi/type.go @@ -178,23 +178,39 @@ func (t Type) pack(v reflect.Value) ([]byte, error) { return nil, err } - if t.T == SliceTy || t.T == ArrayTy { - var packed []byte + switch t.T { + case SliceTy, ArrayTy: + var ret []byte + if t.requiresLengthPrefix() { + // append length + ret = append(ret, packNum(reflect.ValueOf(v.Len()))...) + } + + // calculate offset if any + offset := 0 + offsetReq := isDynamicType(*t.Elem) + if offsetReq { + offset = getDynamicTypeOffset(*t.Elem) * v.Len() + } + var tail []byte for i := 0; i < v.Len(); i++ { val, err := t.Elem.pack(v.Index(i)) if err != nil { return nil, err } - packed = append(packed, val...) - } - if t.T == SliceTy { - return packBytesSlice(packed, v.Len()), nil - } else if t.T == ArrayTy { - return packed, nil + if !offsetReq { + ret = append(ret, val...) + continue + } + ret = append(ret, packNum(reflect.ValueOf(offset))...) + offset += len(val) + tail = append(tail, val...) } + return append(ret, tail...), nil + default: + return packElement(t, v), nil } - return packElement(t, v), nil } // requireLengthPrefix returns whether the type requires any sort of length @@ -202,3 +218,27 @@ func (t Type) pack(v reflect.Value) ([]byte, error) { func (t Type) requiresLengthPrefix() bool { return t.T == StringTy || t.T == BytesTy || t.T == SliceTy } + +// isDynamicType returns true if the type is dynamic. +// StringTy, BytesTy, and SliceTy (irrespective of slice element type) are dynamic types +// ArrayTy is considered dynamic if and only if the Array element is a dynamic type. +// This function recursively checks the type for slice and array elements. +func isDynamicType(t Type) bool { + // dynamic types + // array is also a dynamic type if the array type is dynamic + return t.T == StringTy || t.T == BytesTy || t.T == SliceTy || (t.T == ArrayTy && isDynamicType(*t.Elem)) +} + +// getDynamicTypeOffset returns the offset for the type. +// See `isDynamicType` to know which types are considered dynamic. 
+// If the type t is an array and element type is not a dynamic type, then we consider it a static type and
+// return 32 * size of array since length prefix is not required.
+// If t is a dynamic type or element type (for slices and arrays) is dynamic, then we simply return 32 as offset.
+func getDynamicTypeOffset(t Type) int {
+	// if it is an array and there are no dynamic types
+	// then the array is static type
+	if t.T == ArrayTy && !isDynamicType(*t.Elem) {
+		return 32 * t.Size
+	}
+	return 32
+}

From 4e8214cbe92fe0f1f66b6b8c008523cd449f3995 Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Fri, 29 Sep 2023 16:46:07 +0700
Subject: [PATCH 080/119] Tuple support

---
 accounts/abi/abi.go      |   7 +-
 accounts/abi/argument.go | 212 +++++++++++++++++++++++----------------
 accounts/abi/numbers.go  |  37 ++-----
 accounts/abi/reflect.go  | 150 +++++++++++++++++++++------
 accounts/abi/type.go     | 150 +++++++++++++++++++++++----
 accounts/abi/unpack.go   | 119 +++++++++++++++++-----
 6 files changed, 485 insertions(+), 190 deletions(-)

diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go
index 254b1f7fb4..08d5db9798 100644
--- a/accounts/abi/abi.go
+++ b/accounts/abi/abi.go
@@ -58,13 +58,11 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) {
 			return nil, err
 		}
 		return arguments, nil
-
 	}
 	method, exist := abi.Methods[name]
 	if !exist {
 		return nil, fmt.Errorf("method '%s' not found", name)
 	}
-
 	arguments, err := method.Inputs.Pack(args...)
 	if err != nil {
 		return nil, err
@@ -82,7 +80,7 @@ func (abi ABI) Unpack(v interface{}, name string, output []byte) (err error) {
 	// we need to decide whether we're calling a method or an event
 	if method, ok := abi.Methods[name]; ok {
 		if len(output)%32 != 0 {
-			return fmt.Errorf("abi: improperly formatted output")
+			return fmt.Errorf("abi: improperly formatted output: %s - Bytes: [%+v]", string(output), output)
 		}
 		return method.Outputs.Unpack(v, output)
 	} else if event, ok := abi.Events[name]; ok {
@@ -137,6 +135,9 @@ func (abi *ABI) UnmarshalJSON(data []byte) error {
 // MethodById looks up a method by the 4-byte id
 // returns nil if none found
 func (abi *ABI) MethodById(sigdata []byte) (*Method, error) {
+	if len(sigdata) < 4 {
+		return nil, fmt.Errorf("data too short (%d bytes) for abi method lookup", len(sigdata))
+	}
 	for _, method := range abi.Methods {
 		if bytes.Equal(method.Id(), sigdata[:4]) {
 			return &method, nil
diff --git a/accounts/abi/argument.go b/accounts/abi/argument.go
index cf140698d2..d0a6b035c6 100644
--- a/accounts/abi/argument.go
+++ b/accounts/abi/argument.go
@@ -33,24 +33,27 @@ type Argument struct {

 type Arguments []Argument

+type ArgumentMarshaling struct {
+	Name       string
+	Type       string
+	Components []ArgumentMarshaling
+	Indexed    bool
+}
+
 // UnmarshalJSON implements json.Unmarshaler interface
 func (argument *Argument) UnmarshalJSON(data []byte) error {
-	var extarg struct {
-		Name    string
-		Type    string
-		Indexed bool
-	}
-	err := json.Unmarshal(data, &extarg)
+	var arg ArgumentMarshaling
+	err := json.Unmarshal(data, &arg)
 	if err != nil {
 		return fmt.Errorf("argument json err: %v", err)
 	}

-	argument.Type, err = NewType(extarg.Type)
+	argument.Type, err = NewType(arg.Type, arg.Components)
 	if err != nil {
 		return err
 	}
-	argument.Name = extarg.Name
-	argument.Indexed = extarg.Indexed
+	argument.Name = arg.Name
+	argument.Indexed = arg.Indexed

 	return nil
 }
@@ -85,7 +88,6 @@ func (arguments Arguments) isTuple() bool {

 // Unpack performs the operation hexdata -> Go format
 func (arguments Arguments) Unpack(v interface{}, data []byte) error {
-
// make sure the passed value is arguments pointer if reflect.Ptr != reflect.ValueOf(v).Kind() { return fmt.Errorf("abi: Unpack(non-pointer %T)", v) @@ -97,34 +99,123 @@ func (arguments Arguments) Unpack(v interface{}, data []byte) error { if arguments.isTuple() { return arguments.unpackTuple(v, marshalledValues) } - return arguments.unpackAtomic(v, marshalledValues) + return arguments.unpackAtomic(v, marshalledValues[0]) } -func (arguments Arguments) unpackTuple(v interface{}, marshalledValues []interface{}) error { +// unpack sets the unmarshalled value to go format. +// Note the dst here must be settable. +func unpack(t *Type, dst interface{}, src interface{}) error { + var ( + dstVal = reflect.ValueOf(dst).Elem() + srcVal = reflect.ValueOf(src) + ) + + if t.T != TupleTy && !((t.T == SliceTy || t.T == ArrayTy) && t.Elem.T == TupleTy) { + return set(dstVal, srcVal) + } + + switch t.T { + case TupleTy: + if dstVal.Kind() != reflect.Struct { + return fmt.Errorf("abi: invalid dst value for unpack, want struct, got %s", dstVal.Kind()) + } + fieldmap, err := mapArgNamesToStructFields(t.TupleRawNames, dstVal) + if err != nil { + return err + } + for i, elem := range t.TupleElems { + fname := fieldmap[t.TupleRawNames[i]] + field := dstVal.FieldByName(fname) + if !field.IsValid() { + return fmt.Errorf("abi: field %s can't found in the given value", t.TupleRawNames[i]) + } + if err := unpack(elem, field.Addr().Interface(), srcVal.Field(i).Interface()); err != nil { + return err + } + } + return nil + case SliceTy: + if dstVal.Kind() != reflect.Slice { + return fmt.Errorf("abi: invalid dst value for unpack, want slice, got %s", dstVal.Kind()) + } + slice := reflect.MakeSlice(dstVal.Type(), srcVal.Len(), srcVal.Len()) + for i := 0; i < slice.Len(); i++ { + if err := unpack(t.Elem, slice.Index(i).Addr().Interface(), srcVal.Index(i).Interface()); err != nil { + return err + } + } + dstVal.Set(slice) + case ArrayTy: + if dstVal.Kind() != reflect.Array { + return fmt.Errorf("abi: invalid dst value for unpack, want array, got %s", dstVal.Kind()) + } + array := reflect.New(dstVal.Type()).Elem() + for i := 0; i < array.Len(); i++ { + if err := unpack(t.Elem, array.Index(i).Addr().Interface(), srcVal.Index(i).Interface()); err != nil { + return err + } + } + dstVal.Set(array) + } + return nil +} +// unpackAtomic unpacks ( hexdata -> go ) a single value +func (arguments Arguments) unpackAtomic(v interface{}, marshalledValues interface{}) error { + if arguments.LengthNonIndexed() == 0 { + return nil + } + argument := arguments.NonIndexed()[0] + elem := reflect.ValueOf(v).Elem() + + if elem.Kind() == reflect.Struct { + fieldmap, err := mapArgNamesToStructFields([]string{argument.Name}, elem) + if err != nil { + return err + } + field := elem.FieldByName(fieldmap[argument.Name]) + if !field.IsValid() { + return fmt.Errorf("abi: field %s can't be found in the given value", argument.Name) + } + return unpack(&argument.Type, field.Addr().Interface(), marshalledValues) + } + return unpack(&argument.Type, elem.Addr().Interface(), marshalledValues) +} + +// unpackTuple unpacks ( hexdata -> go ) a batch of values. 
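+// For example, a method with outputs (uint256 value, address owner) can be
+// unpacked either element-wise into a slice or into a pointer to a struct
+// with fields Value and Owner; the name pairing is performed by
+// mapArgNamesToStructFields in reflect.go (names here are illustrative).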
+func (arguments Arguments) unpackTuple(v interface{}, marshalledValues []interface{}) error { var ( value = reflect.ValueOf(v).Elem() typ = value.Type() kind = value.Kind() ) - if err := requireUnpackKind(value, typ, kind, arguments); err != nil { return err } - // If the output interface is a struct, make sure names don't collide + + // If the interface is a struct, get of abi->struct_field mapping + var abi2struct map[string]string if kind == reflect.Struct { - if err := requireUniqueStructFieldNames(arguments); err != nil { + var ( + argNames []string + err error + ) + for _, arg := range arguments.NonIndexed() { + argNames = append(argNames, arg.Name) + } + abi2struct, err = mapArgNamesToStructFields(argNames, value) + if err != nil { return err } } for i, arg := range arguments.NonIndexed() { - - reflectValue := reflect.ValueOf(marshalledValues[i]) - switch kind { case reflect.Struct: - err := unpackStruct(value, reflectValue, arg) - if err != nil { + field := value.FieldByName(abi2struct[arg.Name]) + if !field.IsValid() { + return fmt.Errorf("abi: field %s can't be found in the given value", arg.Name) + } + if err := unpack(&arg.Type, field.Addr().Interface(), marshalledValues[i]); err != nil { return err } case reflect.Slice, reflect.Array: @@ -132,11 +223,10 @@ func (arguments Arguments) unpackTuple(v interface{}, marshalledValues []interfa return fmt.Errorf("abi: insufficient number of arguments for unpack, want %d, got %d", len(arguments), value.Len()) } v := value.Index(i) - if err := requireAssignable(v, reflectValue); err != nil { + if err := requireAssignable(v, reflect.ValueOf(marshalledValues[i])); err != nil { return err } - - if err := set(v.Elem(), reflectValue, arg); err != nil { + if err := unpack(&arg.Type, v.Addr().Interface(), marshalledValues[i]); err != nil { return err } default: @@ -144,45 +234,9 @@ func (arguments Arguments) unpackTuple(v interface{}, marshalledValues []interfa } } return nil -} - -// unpackAtomic unpacks ( hexdata -> go ) a single value -func (arguments Arguments) unpackAtomic(v interface{}, marshalledValues []interface{}) error { - if len(marshalledValues) != 1 { - return fmt.Errorf("abi: wrong length, expected single value, got %d", len(marshalledValues)) - } - elem := reflect.ValueOf(v).Elem() - kind := elem.Kind() - reflectValue := reflect.ValueOf(marshalledValues[0]) - - if kind == reflect.Struct { - //make sure names don't collide - if err := requireUniqueStructFieldNames(arguments); err != nil { - return err - } - - return unpackStruct(elem, reflectValue, arguments[0]) - } - - return set(elem, reflectValue, arguments.NonIndexed()[0]) } -// Computes the full size of an array; -// i.e. counting nested arrays, which count towards size for unpacking. -func getArraySize(arr *Type) int { - size := arr.Size - // Arrays can be nested, with each element being the same size - arr = arr.Elem - for arr.T == ArrayTy { - // Keep multiplying by elem.Size while the elem is an array. - size *= arr.Size - arr = arr.Elem - } - // Now we have the full array size, including its children. - return size -} - // UnpackValues can be used to unpack ABI-encoded hexdata according to the ABI-specification, // without supplying a struct to unpack into. Instead, this method returns a list containing the // values. An atomic argument will be a list with one element. 
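To make the tuple/struct pairing concrete, here is a sketch of decoding a method returning (uint256 value, address owner); parsedABI, "get" and ret are hypothetical stand-ins for a parsed ABI, a method name and raw return data:

	type result struct {
		Value *big.Int       // paired with output "value" by field name
		Owner common.Address `abi:"owner"` // paired explicitly via struct tag
	}
	var out result
	if err := parsedABI.Unpack(&out, "get", ret); err != nil {
		// handle the decoding error
	}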
@@ -191,7 +245,7 @@ func (arguments Arguments) UnpackValues(data []byte) ([]interface{}, error) { virtualArgs := 0 for index, arg := range arguments.NonIndexed() { marshalledValue, err := toGoType((index+virtualArgs)*32, arg.Type, data) - if arg.Type.T == ArrayTy { + if arg.Type.T == ArrayTy && !isDynamicType(arg.Type) { // If we have a static array, like [3]uint256, these are coded as // just like uint256,uint256,uint256. // This means that we need to add two 'virtual' arguments when @@ -202,7 +256,11 @@ func (arguments Arguments) UnpackValues(data []byte) ([]interface{}, error) { // // Calculate the full array size to get the correct offset for the next argument. // Decrement it by 1, as the normal index increment is still applied. - virtualArgs += getArraySize(&arg.Type) - 1 + virtualArgs += getTypeSize(arg.Type)/32 - 1 + } else if arg.Type.T == TupleTy && !isDynamicType(arg.Type) { + // If we have a static tuple, like (uint256, bool, uint256), these are + // coded as just like uint256,bool,uint256 + virtualArgs += getTypeSize(arg.Type)/32 - 1 } if err != nil { return nil, err @@ -232,7 +290,7 @@ func (arguments Arguments) Pack(args ...interface{}) ([]byte, error) { // input offset is the bytes offset for packed output inputOffset := 0 for _, abiArg := range abiArgs { - inputOffset += getDynamicTypeOffset(abiArg.Type) + inputOffset += getTypeSize(abiArg.Type) } var ret []byte for i, a := range args { @@ -261,29 +319,13 @@ func (arguments Arguments) Pack(args ...interface{}) ([]byte, error) { return ret, nil } -// capitalise makes the first character of a string upper case, also removing any -// prefixing underscores from the variable names. -func capitalise(input string) string { - for len(input) > 0 && input[0] == '_' { - input = input[1:] - } - if len(input) == 0 { - return "" - } - return strings.ToUpper(input[:1]) + input[1:] -} - -// unpackStruct extracts each argument into its corresponding struct field -func unpackStruct(value, reflectValue reflect.Value, arg Argument) error { - name := capitalise(arg.Name) - typ := value.Type() - for j := 0; j < typ.NumField(); j++ { - // TODO read tags: `abi:"fieldName"` - if typ.Field(j).Name == name { - if err := set(value.Field(j), reflectValue, arg); err != nil { - return err - } +// ToCamelCase converts an under-score string to a camel-case string +func ToCamelCase(input string) string { + parts := strings.Split(input, "_") + for i, s := range parts { + if len(s) > 0 { + parts[i] = strings.ToUpper(s[:1]) + s[1:] } } - return nil + return strings.Join(parts, "") } diff --git a/accounts/abi/numbers.go b/accounts/abi/numbers.go index 3d541ee9a2..491b94d341 100644 --- a/accounts/abi/numbers.go +++ b/accounts/abi/numbers.go @@ -25,35 +25,20 @@ import ( ) var ( - big_t = reflect.TypeOf(&big.Int{}) - derefbig_t = reflect.TypeOf(big.Int{}) - uint8_t = reflect.TypeOf(uint8(0)) - uint16_t = reflect.TypeOf(uint16(0)) - uint32_t = reflect.TypeOf(uint32(0)) - uint64_t = reflect.TypeOf(uint64(0)) - int_t = reflect.TypeOf(int(0)) - int8_t = reflect.TypeOf(int8(0)) - int16_t = reflect.TypeOf(int16(0)) - int32_t = reflect.TypeOf(int32(0)) - int64_t = reflect.TypeOf(int64(0)) - address_t = reflect.TypeOf(common.Address{}) - int_ts = reflect.TypeOf([]int(nil)) - int8_ts = reflect.TypeOf([]int8(nil)) - int16_ts = reflect.TypeOf([]int16(nil)) - int32_ts = reflect.TypeOf([]int32(nil)) - int64_ts = reflect.TypeOf([]int64(nil)) + bigT = reflect.TypeOf(&big.Int{}) + derefbigT = reflect.TypeOf(big.Int{}) + uint8T = reflect.TypeOf(uint8(0)) + uint16T = 
reflect.TypeOf(uint16(0)) + uint32T = reflect.TypeOf(uint32(0)) + uint64T = reflect.TypeOf(uint64(0)) + int8T = reflect.TypeOf(int8(0)) + int16T = reflect.TypeOf(int16(0)) + int32T = reflect.TypeOf(int32(0)) + int64T = reflect.TypeOf(int64(0)) + addressT = reflect.TypeOf(common.Address{}) ) // U256 converts a big Int into a 256bit EVM number. func U256(n *big.Int) []byte { return math.PaddedBigBytes(math.U256(n), 32) } - -// checks whether the given reflect value is signed. This also works for slices with a number type -func isSigned(v reflect.Value) bool { - switch v.Type() { - case int_ts, int8_ts, int16_ts, int32_ts, int64_ts, int_t, int8_t, int16_t, int32_t, int64_t: - return true - } - return false -} diff --git a/accounts/abi/reflect.go b/accounts/abi/reflect.go index 2e6bf7098f..c39b3d0a6b 100644 --- a/accounts/abi/reflect.go +++ b/accounts/abi/reflect.go @@ -19,12 +19,13 @@ package abi import ( "fmt" "reflect" + "strings" ) // indirect recursively dereferences the value until it either gets the value // or finds a big.Int func indirect(v reflect.Value) reflect.Value { - if v.Kind() == reflect.Ptr && v.Elem().Type() != derefbig_t { + if v.Kind() == reflect.Ptr && v.Elem().Type() != derefbigT { return indirect(v.Elem()) } return v @@ -36,26 +37,26 @@ func reflectIntKindAndType(unsigned bool, size int) (reflect.Kind, reflect.Type) switch size { case 8: if unsigned { - return reflect.Uint8, uint8_t + return reflect.Uint8, uint8T } - return reflect.Int8, int8_t + return reflect.Int8, int8T case 16: if unsigned { - return reflect.Uint16, uint16_t + return reflect.Uint16, uint16T } - return reflect.Int16, int16_t + return reflect.Int16, int16T case 32: if unsigned { - return reflect.Uint32, uint32_t + return reflect.Uint32, uint32T } - return reflect.Int32, int32_t + return reflect.Int32, int32T case 64: if unsigned { - return reflect.Uint64, uint64_t + return reflect.Uint64, uint64T } - return reflect.Int64, int64_t + return reflect.Int64, int64T } - return reflect.Ptr, big_t + return reflect.Ptr, bigT } // mustArrayToBytesSlice creates a new byte slice with the exact same size as value @@ -70,22 +71,36 @@ func mustArrayToByteSlice(value reflect.Value) reflect.Value { // // set is a bit more lenient when it comes to assignment and doesn't force an as // strict ruleset as bare `reflect` does. -func set(dst, src reflect.Value, output Argument) error { - dstType := dst.Type() - srcType := src.Type() +func set(dst, src reflect.Value) error { + dstType, srcType := dst.Type(), src.Type() switch { - case dstType.AssignableTo(srcType): - dst.Set(src) case dstType.Kind() == reflect.Interface: + return set(dst.Elem(), src) + case dstType.Kind() == reflect.Ptr && dstType.Elem() != derefbigT: + return set(dst.Elem(), src) + case srcType.AssignableTo(dstType) && dst.CanSet(): dst.Set(src) - case dstType.Kind() == reflect.Ptr: - return set(dst.Elem(), src, output) + case dstType.Kind() == reflect.Slice && srcType.Kind() == reflect.Slice: + return setSlice(dst, src) default: return fmt.Errorf("abi: cannot unmarshal %v in to %v", src.Type(), dst.Type()) } return nil } +// setSlice attempts to assign src to dst when slices are not assignable by default +// e.g. 
src: [][]byte -> dst: [][15]byte +func setSlice(dst, src reflect.Value) error { + slice := reflect.MakeSlice(dst.Type(), src.Len(), src.Len()) + for i := 0; i < src.Len(); i++ { + v := src.Index(i) + reflect.Copy(slice.Index(i), v) + } + + dst.Set(slice) + return nil +} + // requireAssignable assures that `dest` is a pointer and it's not an interface. func requireAssignable(dst, src reflect.Value) error { if dst.Kind() != reflect.Ptr && dst.Kind() != reflect.Interface { @@ -111,18 +126,97 @@ func requireUnpackKind(v reflect.Value, t reflect.Type, k reflect.Kind, return nil } -// requireUniqueStructFieldNames makes sure field names don't collide -func requireUniqueStructFieldNames(args Arguments) error { - exists := make(map[string]bool) - for _, arg := range args { - field := capitalise(arg.Name) - if field == "" { - return fmt.Errorf("abi: purely underscored output cannot unpack to struct") +// mapArgNamesToStructFields maps a slice of argument names to struct fields. +// first round: for each Exportable field that contains a `abi:""` tag +// +// and this field name exists in the given argument name list, pair them together. +// +// second round: for each argument name that has not been already linked, +// +// find what variable is expected to be mapped into, if it exists and has not been +// used, pair them. +// +// Note this function assumes the given value is a struct value. +func mapArgNamesToStructFields(argNames []string, value reflect.Value) (map[string]string, error) { + typ := value.Type() + + abi2struct := make(map[string]string) + struct2abi := make(map[string]string) + + // first round ~~~ + for i := 0; i < typ.NumField(); i++ { + structFieldName := typ.Field(i).Name + + // skip private struct fields. + if structFieldName[:1] != strings.ToUpper(structFieldName[:1]) { + continue + } + // skip fields that have no abi:"" tag. + var ok bool + var tagName string + if tagName, ok = typ.Field(i).Tag.Lookup("abi"); !ok { + continue + } + // check if tag is empty. + if tagName == "" { + return nil, fmt.Errorf("struct: abi tag in '%s' is empty", structFieldName) + } + // check which argument field matches with the abi tag. + found := false + for _, arg := range argNames { + if arg == tagName { + if abi2struct[arg] != "" { + return nil, fmt.Errorf("struct: abi tag in '%s' already mapped", structFieldName) + } + // pair them + abi2struct[arg] = structFieldName + struct2abi[structFieldName] = arg + found = true + } + } + // check if this tag has been mapped. + if !found { + return nil, fmt.Errorf("struct: abi tag '%s' defined but not found in abi", tagName) + } + } + + // second round ~~~ + for _, argName := range argNames { + + structFieldName := ToCamelCase(argName) + + if structFieldName == "" { + return nil, fmt.Errorf("abi: purely underscored output cannot unpack to struct") + } + + // this abi has already been paired, skip it... unless there exists another, yet unassigned + // struct field with the same field name. If so, raise an error: + // abi: [ { "name": "value" } ] + // struct { Value *big.Int , Value1 *big.Int `abi:"value"`} + if abi2struct[argName] != "" { + if abi2struct[argName] != structFieldName && + struct2abi[structFieldName] == "" && + value.FieldByName(structFieldName).IsValid() { + return nil, fmt.Errorf("abi: multiple variables maps to the same abi field '%s'", argName) + } + continue } - if exists[field] { - return fmt.Errorf("abi: multiple outputs mapping to the same struct field '%s'", field) + + // return an error if this struct field has already been paired. 
+ if struct2abi[structFieldName] != "" { + return nil, fmt.Errorf("abi: multiple outputs mapping to the same struct field '%s'", structFieldName) + } + + if value.FieldByName(structFieldName).IsValid() { + // pair them + abi2struct[argName] = structFieldName + struct2abi[structFieldName] = argName + } else { + // not paired, but annotate as used, to detect cases like + // abi : [ { "name": "value" }, { "name": "_value" } ] + // struct { Value *big.Int } + struct2abi[structFieldName] = argName } - exists[field] = true } - return nil + return abi2struct, nil } diff --git a/accounts/abi/type.go b/accounts/abi/type.go index e216cbe071..26151dbd3e 100644 --- a/accounts/abi/type.go +++ b/accounts/abi/type.go @@ -17,6 +17,7 @@ package abi import ( + "errors" "fmt" "reflect" "regexp" @@ -32,6 +33,7 @@ const ( StringTy SliceTy ArrayTy + TupleTy AddressTy FixedBytesTy BytesTy @@ -43,13 +45,16 @@ const ( // Type is the reflection of the supported argument type type Type struct { Elem *Type - Kind reflect.Kind Type reflect.Type Size int T byte // Our own type checking stringKind string // holds the unparsed string for deriving signatures + + // Tuple relative fields + TupleElems []*Type // Type information of all tuple fields + TupleRawNames []string // Raw field name of all tuple fields } var ( @@ -58,7 +63,7 @@ var ( ) // NewType creates a new reflection type of abi type given in t. -func NewType(t string) (typ Type, err error) { +func NewType(t string, components []ArgumentMarshaling) (typ Type, err error) { // check that array brackets are equal if they exist if strings.Count(t, "[") != strings.Count(t, "]") { return Type{}, fmt.Errorf("invalid arg type in abi") @@ -71,7 +76,7 @@ func NewType(t string) (typ Type, err error) { if strings.Count(t, "[") != 0 { i := strings.LastIndex(t, "[") // recursively embed the type - embeddedType, err := NewType(t[:i]) + embeddedType, err := NewType(t[:i], components) if err != nil { return Type{}, err } @@ -87,6 +92,9 @@ func NewType(t string) (typ Type, err error) { typ.Kind = reflect.Slice typ.Elem = &embeddedType typ.Type = reflect.SliceOf(embeddedType.Type) + if embeddedType.T == TupleTy { + typ.stringKind = embeddedType.stringKind + sliced + } } else if len(intz) == 1 { // is a array typ.T = ArrayTy @@ -97,13 +105,21 @@ func NewType(t string) (typ Type, err error) { return Type{}, fmt.Errorf("abi: error parsing variable size: %v", err) } typ.Type = reflect.ArrayOf(typ.Size, embeddedType.Type) + if embeddedType.T == TupleTy { + typ.stringKind = embeddedType.stringKind + sliced + } } else { return Type{}, fmt.Errorf("invalid formatting of array type") } return typ, err } // parse the type and size of the abi-type. 
- parsedType := typeRegex.FindAllStringSubmatch(t, -1)[0] + matches := typeRegex.FindAllStringSubmatch(t, -1) + if len(matches) == 0 { + return Type{}, fmt.Errorf("invalid type '%v'", t) + } + parsedType := matches[0] + // varSize is the size of the variable var varSize int if len(parsedType[3]) > 0 { @@ -135,7 +151,7 @@ func NewType(t string) (typ Type, err error) { typ.Type = reflect.TypeOf(bool(false)) case "address": typ.Kind = reflect.Array - typ.Type = address_t + typ.Type = addressT typ.Size = 20 typ.T = AddressTy case "string": @@ -153,6 +169,40 @@ func NewType(t string) (typ Type, err error) { typ.Size = varSize typ.Type = reflect.ArrayOf(varSize, reflect.TypeOf(byte(0))) } + case "tuple": + var ( + fields []reflect.StructField + elems []*Type + names []string + expression string // canonical parameter expression + ) + expression += "(" + for idx, c := range components { + cType, err := NewType(c.Type, c.Components) + if err != nil { + return Type{}, err + } + if ToCamelCase(c.Name) == "" { + return Type{}, errors.New("abi: purely anonymous or underscored field is not supported") + } + fields = append(fields, reflect.StructField{ + Name: ToCamelCase(c.Name), // reflect.StructOf will panic for any exported field. + Type: cType.Type, + }) + elems = append(elems, &cType) + names = append(names, c.Name) + expression += cType.stringKind + if idx != len(components)-1 { + expression += "," + } + } + expression += ")" + typ.Kind = reflect.Struct + typ.Type = reflect.StructOf(fields) + typ.TupleElems = elems + typ.TupleRawNames = names + typ.T = TupleTy + typ.stringKind = expression case "function": typ.Kind = reflect.Array typ.T = FunctionTy @@ -173,7 +223,6 @@ func (t Type) String() (out string) { func (t Type) pack(v reflect.Value) ([]byte, error) { // dereference pointer first if it's a pointer v = indirect(v) - if err := typeCheck(t, v); err != nil { return nil, err } @@ -191,7 +240,7 @@ func (t Type) pack(v reflect.Value) ([]byte, error) { offset := 0 offsetReq := isDynamicType(*t.Elem) if offsetReq { - offset = getDynamicTypeOffset(*t.Elem) * v.Len() + offset = getTypeSize(*t.Elem) * v.Len() } var tail []byte for i := 0; i < v.Len(); i++ { @@ -208,6 +257,45 @@ func (t Type) pack(v reflect.Value) ([]byte, error) { tail = append(tail, val...) } return append(ret, tail...), nil + case TupleTy: + // (T1,...,Tk) for k >= 0 and any types T1, …, Tk + // enc(X) = head(X(1)) ... head(X(k)) tail(X(1)) ... tail(X(k)) + // where X = (X(1), ..., X(k)) and head and tail are defined for Ti being a static + // type as + // head(X(i)) = enc(X(i)) and tail(X(i)) = "" (the empty string) + // and as + // head(X(i)) = enc(len(head(X(1)) ... head(X(k)) tail(X(1)) ... tail(X(i-1)))) + // tail(X(i)) = enc(X(i)) + // otherwise, i.e. if Ti is a dynamic type. + fieldmap, err := mapArgNamesToStructFields(t.TupleRawNames, v) + if err != nil { + return nil, err + } + // Calculate prefix occupied size. + offset := 0 + for _, elem := range t.TupleElems { + offset += getTypeSize(*elem) + } + var ret, tail []byte + for i, elem := range t.TupleElems { + field := v.FieldByName(fieldmap[t.TupleRawNames[i]]) + if !field.IsValid() { + return nil, fmt.Errorf("field %s for tuple not found in the given struct", t.TupleRawNames[i]) + } + val, err := elem.pack(field) + if err != nil { + return nil, err + } + if isDynamicType(*elem) { + ret = append(ret, packNum(reflect.ValueOf(offset))...) + tail = append(tail, val...) + offset += len(val) + } else { + ret = append(ret, val...) 
+ } + } + return append(ret, tail...), nil + default: return packElement(t, v), nil } @@ -220,25 +308,45 @@ func (t Type) requiresLengthPrefix() bool { } // isDynamicType returns true if the type is dynamic. -// StringTy, BytesTy, and SliceTy (irrespective of slice element type) are dynamic types -// ArrayTy is considered dynamic if and only if the Array element is a dynamic type. -// This function recursively checks the type for slice and array elements. +// The following types are called “dynamic”: +// * bytes +// * string +// * T[] for any T +// * T[k] for any dynamic T and any k >= 0 +// * (T1,...,Tk) if Ti is dynamic for some 1 <= i <= k func isDynamicType(t Type) bool { - // dynamic types - // array is also a dynamic type if the array type is dynamic + if t.T == TupleTy { + for _, elem := range t.TupleElems { + if isDynamicType(*elem) { + return true + } + } + return false + } return t.T == StringTy || t.T == BytesTy || t.T == SliceTy || (t.T == ArrayTy && isDynamicType(*t.Elem)) } -// getDynamicTypeOffset returns the offset for the type. -// See `isDynamicType` to know which types are considered dynamic. -// If the type t is an array and element type is not a dynamic type, then we consider it a static type and -// return 32 * size of array since length prefix is not required. -// If t is a dynamic type or element type(for slices and arrays) is dynamic, then we simply return 32 as offset. -func getDynamicTypeOffset(t Type) int { - // if it is an array and there are no dynamic types - // then the array is static type +// getTypeSize returns the size that this type needs to occupy. +// We distinguish static and dynamic types. Static types are encoded in-place +// and dynamic types are encoded at a separately allocated location after the +// current block. +// So for a static variable, the size returned represents the size that the +// variable actually occupies. +// For a dynamic variable, the returned size is fixed 32 bytes, which is used +// to store the location reference for actual value storage. +func getTypeSize(t Type) int { if t.T == ArrayTy && !isDynamicType(*t.Elem) { - return 32 * t.Size + // Recursively calculate type size if it is a nested array + if t.Elem.T == ArrayTy { + return t.Size * getTypeSize(*t.Elem) + } + return t.Size * 32 + } else if t.T == TupleTy && !isDynamicType(t) { + total := 0 + for _, elem := range t.TupleElems { + total += getTypeSize(*elem) + } + return total } return 32 } diff --git a/accounts/abi/unpack.go b/accounts/abi/unpack.go index 208486349a..86af6f97b9 100644 --- a/accounts/abi/unpack.go +++ b/accounts/abi/unpack.go @@ -25,8 +25,17 @@ import ( "github.com/tomochain/tomochain/common" ) +var ( + maxUint256 = big.NewInt(0).Add( + big.NewInt(0).Exp(big.NewInt(2), big.NewInt(256), nil), + big.NewInt(-1)) + maxInt256 = big.NewInt(0).Add( + big.NewInt(0).Exp(big.NewInt(2), big.NewInt(255), nil), + big.NewInt(-1)) +) + // reads the integer based on its kind -func readInteger(kind reflect.Kind, b []byte) interface{} { +func readInteger(typ byte, kind reflect.Kind, b []byte) interface{} { switch kind { case reflect.Uint8: return b[len(b)-1] @@ -45,7 +54,20 @@ func readInteger(kind reflect.Kind, b []byte) interface{} { case reflect.Int64: return int64(binary.BigEndian.Uint64(b[len(b)-8:])) default: - return new(big.Int).SetBytes(b) + // the only case lefts for integer is int256/uint256. + // big.SetBytes can't tell if a number is negative, positive on itself. + // On EVM, if the returned number > max int256, it is negative. 
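+		// e.g. a return word of 0xff…ff (32 bytes) parses as 2**256 - 1, which
+		// is greater than maxInt256, so the branch below rewrites it to
+		// -(2**256 - ret) = -1, its two's-complement reading.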
+ ret := new(big.Int).SetBytes(b) + if typ == UintTy { + return ret + } + + if ret.Cmp(maxInt256) > 0 { + ret.Add(maxUint256, big.NewInt(0).Neg(ret)) + ret.Add(ret, big.NewInt(1)) + ret.Neg(ret) + } + return ret } } @@ -93,17 +115,6 @@ func readFixedBytes(t Type, word []byte) (interface{}, error) { } -func getFullElemSize(elem *Type) int { - //all other should be counted as 32 (slices have pointers to respective elements) - size := 32 - //arrays wrap it, each element being the same size - for elem.T == ArrayTy { - size *= elem.Size - elem = elem.Elem - } - return size -} - // iteratively unpack elements func forEachUnpack(t Type, output []byte, start, size int) (interface{}, error) { if size < 0 { @@ -128,13 +139,9 @@ func forEachUnpack(t Type, output []byte, start, size int) (interface{}, error) // Arrays have packed elements, resulting in longer unpack steps. // Slices have just 32 bytes per element (pointing to the contents). - elemSize := 32 - if t.T == ArrayTy { - elemSize = getFullElemSize(t.Elem) - } + elemSize := getTypeSize(*t.Elem) for i, j := start, 0; j < size; i, j = i+elemSize, j+1 { - inter, err := toGoType(i, *t.Elem, output) if err != nil { return nil, err @@ -148,6 +155,36 @@ func forEachUnpack(t Type, output []byte, start, size int) (interface{}, error) return refSlice.Interface(), nil } +func forTupleUnpack(t Type, output []byte) (interface{}, error) { + retval := reflect.New(t.Type).Elem() + virtualArgs := 0 + for index, elem := range t.TupleElems { + marshalledValue, err := toGoType((index+virtualArgs)*32, *elem, output) + if elem.T == ArrayTy && !isDynamicType(*elem) { + // If we have a static array, like [3]uint256, these are coded as + // just like uint256,uint256,uint256. + // This means that we need to add two 'virtual' arguments when + // we count the index from now on. + // + // Array values nested multiple levels deep are also encoded inline: + // [2][3]uint256: uint256,uint256,uint256,uint256,uint256,uint256 + // + // Calculate the full array size to get the correct offset for the next argument. + // Decrement it by 1, as the normal index increment is still applied. + virtualArgs += getTypeSize(*elem)/32 - 1 + } else if elem.T == TupleTy && !isDynamicType(*elem) { + // If we have a static tuple, like (uint256, bool, uint256), these are + // coded as just like uint256,bool,uint256 + virtualArgs += getTypeSize(*elem)/32 - 1 + } + if err != nil { + return nil, err + } + retval.Field(index).Set(reflect.ValueOf(marshalledValue)) + } + return retval.Interface(), nil +} + // toGoType parses the output bytes and recursively assigns the value of these bytes // into a go type with accordance with the ABI spec. func toGoType(index int, t Type, output []byte) (interface{}, error) { @@ -156,14 +193,14 @@ func toGoType(index int, t Type, output []byte) (interface{}, error) { } var ( - returnOutput []byte - begin, end int - err error + returnOutput []byte + begin, length int + err error ) // if we require a length prefix, find the beginning word and size returned. 
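	// (for instance, the dynamic string "foobar" from the pack tests is encoded
	// as an offset word, a length word holding 6, and the data 666f6f626172
	// right-padded to 32 bytes; begin then points at the data and length is 6)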
if t.requiresLengthPrefix() { - begin, end, err = lengthPrefixPointsTo(index, output) + begin, length, err = lengthPrefixPointsTo(index, output) if err != nil { return nil, err } @@ -172,14 +209,28 @@ func toGoType(index int, t Type, output []byte) (interface{}, error) { } switch t.T { + case TupleTy: + if isDynamicType(t) { + begin, err := tuplePointsTo(index, output) + if err != nil { + return nil, err + } + return forTupleUnpack(t, output[begin:]) + } else { + return forTupleUnpack(t, output[index:]) + } case SliceTy: - return forEachUnpack(t, output, begin, end) + return forEachUnpack(t, output[begin:], 0, length) case ArrayTy: - return forEachUnpack(t, output, index, t.Size) + if isDynamicType(*t.Elem) { + offset := int64(binary.BigEndian.Uint64(returnOutput[len(returnOutput)-8:])) + return forEachUnpack(t, output[offset:], 0, t.Size) + } + return forEachUnpack(t, output[index:], 0, t.Size) case StringTy: // variable arrays are written at the end of the return bytes - return string(output[begin : begin+end]), nil + return string(output[begin : begin+length]), nil case IntTy, UintTy: - return readInteger(t.Kind, returnOutput), nil + return readInteger(t.T, t.Kind, returnOutput), nil case BoolTy: return readBool(returnOutput) case AddressTy: @@ -187,7 +238,7 @@ func toGoType(index int, t Type, output []byte) (interface{}, error) { case HashTy: return common.BytesToHash(returnOutput), nil case BytesTy: - return output[begin : begin+end], nil + return output[begin : begin+length], nil case FixedBytesTy: return readFixedBytes(t, returnOutput) case FunctionTy: @@ -228,3 +279,17 @@ func lengthPrefixPointsTo(index int, output []byte) (start int, length int, err length = int(lengthBig.Uint64()) return } + +// tuplePointsTo resolves the location reference for dynamic tuple. 
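+// For instance, a single returned dynamic tuple (uint256,string) holding
+// (1, "ok") is laid out as:
+//
+//	word 0: 0x…20  offset of the tuple body (what tuplePointsTo resolves)
+//	word 1: 0x…01  field 0: uint256(1)
+//	word 2: 0x…40  field 1: offset of the string, relative to the tuple start
+//	word 3: 0x…02  string length
+//	word 4: "ok" right-padded to 32 bytes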
+func tuplePointsTo(index int, output []byte) (start int, err error) {
+	offset := big.NewInt(0).SetBytes(output[index : index+32])
+	outputLen := big.NewInt(int64(len(output)))
+
+	if offset.Cmp(outputLen) > 0 {
+		return 0, fmt.Errorf("abi: cannot marshal in to go slice: offset %v would go over slice boundary (len=%v)", offset, outputLen)
+	}
+	if offset.BitLen() > 63 {
+		return 0, fmt.Errorf("abi offset larger than int64: %v", offset)
+	}
+	return int(offset.Uint64()), nil
+}

From a9cbf32b03aa221a9ca185a320726688ecf8e650 Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Fri, 29 Sep 2023 16:46:22 +0700
Subject: [PATCH 081/119] Add unit tests

---
 accounts/abi/abi_test.go     |  89 +++++---
 accounts/abi/numbers_test.go |  11 -
 accounts/abi/pack_test.go    | 302 +++++++++++++++++++++++-
 accounts/abi/type_test.go    | 430 ++++++++++++++++++-----------------
 accounts/abi/unpack_test.go  | 301 +++++++++++++++++++++++-
 5 files changed, 878 insertions(+), 255 deletions(-)

diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go
index 5a128bfe54..354668e206 100644
--- a/accounts/abi/abi_test.go
+++ b/accounts/abi/abi_test.go
@@ -22,11 +22,10 @@ import (
 	"fmt"
 	"log"
 	"math/big"
+	"reflect"
 	"strings"
 	"testing"
 
-	"reflect"
-
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/crypto"
 )
@@ -52,11 +51,14 @@ const jsondata2 = `
 	{ "type" : "function", "name" : "slice", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint32[2]" } ] },
 	{ "type" : "function", "name" : "slice256", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint256[2]" } ] },
 	{ "type" : "function", "name" : "sliceAddress", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "address[]" } ] },
-	{ "type" : "function", "name" : "sliceMultiAddress", "constant" : false, "inputs" : [ { "name" : "a", "type" : "address[]" }, { "name" : "b", "type" : "address[]" } ] }
+	{ "type" : "function", "name" : "sliceMultiAddress", "constant" : false, "inputs" : [ { "name" : "a", "type" : "address[]" }, { "name" : "b", "type" : "address[]" } ] },
+	{ "type" : "function", "name" : "nestedArray", "constant" : false, "inputs" : [ { "name" : "a", "type" : "uint256[2][2]" }, { "name" : "b", "type" : "address[]" } ] },
+	{ "type" : "function", "name" : "nestedArray2", "constant" : false, "inputs" : [ { "name" : "a", "type" : "uint8[][2]" } ] },
+	{ "type" : "function", "name" : "nestedSlice", "constant" : false, "inputs" : [ { "name" : "a", "type" : "uint8[][]" } ] }
 ]`
 
 func TestReader(t *testing.T) {
-	Uint256, _ := NewType("uint256")
+	Uint256, _ := NewType("uint256", nil)
 	exp := ABI{
 		Methods: map[string]Method{
 			"balance": {
@@ -177,7 +179,7 @@ func TestTestSlice(t *testing.T) {
 }
 
 func TestMethodSignature(t *testing.T) {
-	String, _ := NewType("string")
+	String, _ := NewType("string", nil)
 	m := Method{"foo", false, []Argument{{"bar", String, false}, {"baz", String, false}}, nil}
 	exp := "foo(string,string)"
 	if m.Sig() != exp {
@@ -189,12 +191,31 @@ func TestMethodSignature(t *testing.T) {
 		t.Errorf("expected ids to match %x != %x", m.Id(), idexp)
 	}
 
-	uintt, _ := NewType("uint256")
+	uintt, _ := NewType("uint256", nil)
 	m = Method{"foo", false, []Argument{{"bar", uintt, false}}, nil}
 	exp = "foo(uint256)"
 	if m.Sig() != exp {
 		t.Error("signature mismatch", exp, "!=", m.Sig())
 	}
+
+	// Method with tuple arguments
+	s, _ := NewType("tuple", []ArgumentMarshaling{
+		{Name: "a", Type: "int256"},
+		{Name: "b", Type: "int256[]"},
+		{Name: "c", Type: "tuple[]", Components: []ArgumentMarshaling{
+			{Name: "x", Type: "int256"},
+			{Name: "y", Type: "int256"},
+		}},
+		{Name: "d", Type: "tuple[2]", Components: []ArgumentMarshaling{
+			{Name: "x", Type: "int256"},
+			{Name: "y", Type: "int256"},
+		}},
+	})
+	m = Method{"foo", false, []Argument{{"s", s, false}, {"bar", String, false}}, nil}
+	exp = "foo((int256,int256[],(int256,int256)[],(int256,int256)[2]),string)"
+	if m.Sig() != exp {
+		t.Error("signature mismatch", exp, "!=", m.Sig())
+	}
 }
 
 func TestMultiPack(t *testing.T) {
@@ -564,11 +585,13 @@ func TestBareEvents(t *testing.T) {
 	const definition = `[
 	{ "type" : "event", "name" : "balance" },
 	{ "type" : "event", "name" : "anon", "anonymous" : true},
-	{ "type" : "event", "name" : "args", "inputs" : [{ "indexed":false, "name":"arg0", "type":"uint256" }, { "indexed":true, "name":"arg1", "type":"address" }] }
+	{ "type" : "event", "name" : "args", "inputs" : [{ "indexed":false, "name":"arg0", "type":"uint256" }, { "indexed":true, "name":"arg1", "type":"address" }] },
+	{ "type" : "event", "name" : "tuple", "inputs" : [{ "indexed":false, "name":"t", "type":"tuple", "components":[{"name":"a", "type":"uint256"}] }, { "indexed":true, "name":"arg1", "type":"address" }] }
 	]`
 
-	arg0, _ := NewType("uint256")
-	arg1, _ := NewType("address")
+	arg0, _ := NewType("uint256", nil)
+	arg1, _ := NewType("address", nil)
+	tuple, _ := NewType("tuple", []ArgumentMarshaling{{Name: "a", Type: "uint256"}})
 
 	expectedEvents := map[string]struct {
 		Anonymous bool
@@ -580,6 +603,10 @@ func TestBareEvents(t *testing.T) {
 			{Name: "arg0", Type: arg0, Indexed: false},
 			{Name: "arg1", Type: arg1, Indexed: true},
 		}},
+		"tuple": {false, []Argument{
+			{Name: "t", Type: tuple, Indexed: false},
+			{Name: "arg1", Type: arg1, Indexed: true},
+		}},
 	}
 
 	abi, err := JSON(strings.NewReader(definition))
@@ -619,16 +646,19 @@ func TestBareEvents(t *testing.T) {
 }
 
 // TestUnpackEvent is based on this contract:
-// contract T {
-// event received(address sender, uint amount, bytes memo);
-// event receivedAddr(address sender);
-// function receive(bytes memo) external payable {
-// received(msg.sender, msg.value, memo);
-// receivedAddr(msg.sender);
-// }
-// }
+//
+//	contract T {
+//		event received(address sender, uint amount, bytes memo);
+//		event receivedAddr(address sender);
+//		function receive(bytes memo) external payable {
+//			received(msg.sender, msg.value, memo);
+//			receivedAddr(msg.sender);
+//		}
+//	}
+//
+// When receive("X") is called with sender 0x00... and value 1, it produces this tx receipt:
-// receipt{status=1 cgas=23949 bloom=00000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000040200000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 logs=[log: b6818c8064f645cd82d99b59a1a267d6d61117ef [75fd880d39c1daf53b6547ab6cb59451fc6452d27caa90e5b6649dd8293b9eed] 000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158 9ae378b6d4409eada347a5dc0c180f186cb62dc68fcc0f043425eb917335aa28 0 95d429d309bb9d753954195fe2d69bd140b4ae731b9b5b605c34323de162cf00 0]}
+//
+//	receipt{status=1 cgas=23949 bloom=00000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000040200000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 logs=[log: b6818c8064f645cd82d99b59a1a267d6d61117ef [75fd880d39c1daf53b6547ab6cb59451fc6452d27caa90e5b6649dd8293b9eed] 000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158 9ae378b6d4409eada347a5dc0c180f186cb62dc68fcc0f043425eb917335aa28 0 95d429d309bb9d753954195fe2d69bd140b4ae731b9b5b605c34323de162cf00 0]}
 func TestUnpackEvent(t *testing.T) {
 	const abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"receive","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"receivedAddr","type":"event"}]`
 	abi, err := JSON(strings.NewReader(abiJSON))
@@ -646,28 +676,24 @@ func TestUnpackEvent(t *testing.T) {
 	}
 
 	type ReceivedEvent struct {
-		Address common.Address
-		Amount  *big.Int
-		Memo    []byte
+		Sender common.Address
+		Amount *big.Int
+		Memo   []byte
 	}
 	var ev ReceivedEvent
 	err = abi.Unpack(&ev, "received", data)
 	if err != nil {
 		t.Error(err)
-	} else {
-		t.Logf("len(data): %d; received event: %+v", len(data), ev)
 	}
 
 	type ReceivedAddrEvent struct {
-		Address common.Address
+		Sender common.Address
 	}
 	var receivedAddrEv ReceivedAddrEvent
 	err = abi.Unpack(&receivedAddrEv, "receivedAddr", data)
 	if err != nil {
 		t.Error(err)
-	} else {
-		t.Logf("len(data): %d; received event: %+v", len(data), receivedAddrEv)
 	}
 }
@@ -711,5 +737,14 @@ func TestABI_MethodById(t *testing.T) {
 			t.Errorf("Method %v (id %v) not 'findable' by id in ABI", name, common.ToHex(m.Id()))
 		}
 	}
-
+	// Also test empty
+	if _, err := abi.MethodById([]byte{0x00}); err == nil {
+		t.Errorf("Expected error, too short to decode data")
+	}
+	if _, err := abi.MethodById([]byte{}); err == nil {
+		t.Errorf("Expected error, too short to decode data")
+	}
+	if _, err := abi.MethodById(nil); err == nil {
+		t.Errorf("Expected error, nil is too short to decode data")
+	}
 }
diff --git a/accounts/abi/numbers_test.go b/accounts/abi/numbers_test.go
index b9ff5aef17..d25a5abcb5 100644
--- a/accounts/abi/numbers_test.go
+++ b/accounts/abi/numbers_test.go
@@ -19,7 +19,6 @@ package abi
 import (
 	"bytes"
 	"math/big"
-	"reflect"
 	"testing"
 )
 
@@ -32,13 +31,3 @@ func TestNumberTypes(t *testing.T) {
 		t.Errorf("expected %x got %x", ubytes, unsigned)
 	}
 }
-
-func TestSigned(t *testing.T) {
-	if isSigned(reflect.ValueOf(uint(10))) {
-		t.Error("signed")
-	}
-
-	if !isSigned(reflect.ValueOf(int(10))) {
-		t.Error("not signed")
-	}
-}
diff --git a/accounts/abi/pack_test.go b/accounts/abi/pack_test.go
index 7956157fee..8578d03de2 100644
--- a/accounts/abi/pack_test.go
+++ b/accounts/abi/pack_test.go
@@ -29,303 +29,356 @@ import (
 
 func TestPack(t *testing.T) {
 	for i, test := range []struct {
-		typ string
-
-		input  interface{}
-		output []byte
+		typ        string
+		components []ArgumentMarshaling
+		input      interface{}
+		output     []byte
 	}{
 		{
 			"uint8",
+			nil,
 			uint8(2),
 			common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
 		},
 		{
 			"uint8[]",
+			nil,
 			[]uint8{1, 2},
 			common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
 		},
 		{
 			"uint16",
+			nil,
 			uint16(2),
 			common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
 		},
 		{
 			"uint16[]",
+			nil,
 			[]uint16{1, 2},
 			common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
 		},
 		{
 			"uint32",
+			nil,
 			uint32(2),
 			common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
 		},
 		{
 			"uint32[]",
+			nil,
 			[]uint32{1, 2},
 			common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
 		},
 		{
 			"uint64",
+			nil,
 			uint64(2),
 			common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
 		},
 		{
 			"uint64[]",
+			nil,
 			[]uint64{1, 2},
 			common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
 		},
 		{
 			"uint256",
+			nil,
 			big.NewInt(2),
 			common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
 		},
 		{
 			"uint256[]",
+			nil,
 			[]*big.Int{big.NewInt(1), big.NewInt(2)},
 			common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"),
 		},
 		{
 			"int8",
+			nil,
 			int8(2),
 			common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"),
 		},
 		{
 			"int8[]",
+			nil,
 			[]int8{1, 2},
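+			// dynamic slices pack a 32-byte length word followed by one
+			// 32-byte word per element; negative intN values would instead
+			// be sign-extended two's complement across the full word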
common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), }, { "int16", + nil, int16(2), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), }, { "int16[]", + nil, []int16{1, 2}, common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), }, { "int32", + nil, int32(2), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), }, { "int32[]", + nil, []int32{1, 2}, common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), }, { "int64", + nil, int64(2), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), }, { "int64[]", + nil, []int64{1, 2}, common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), }, { "int256", + nil, big.NewInt(2), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), }, { "int256[]", + nil, []*big.Int{big.NewInt(1), big.NewInt(2)}, common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), }, { "bytes1", + nil, [1]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes2", + nil, [2]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes3", + nil, [3]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes4", + nil, [4]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes5", + nil, [5]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes6", + nil, [6]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes7", + nil, [7]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes8", + nil, [8]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes9", + nil, [9]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes10", + nil, [10]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes11", + nil, [11]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes12", + nil, [12]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes13", + nil, [13]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes14", + nil, [14]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes15", + nil, [15]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes16", + 
nil, [16]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes17", + nil, [17]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes18", + nil, [18]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes19", + nil, [19]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes20", + nil, [20]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes21", + nil, [21]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes22", + nil, [22]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes23", + nil, [23]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes24", - [24]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes24", + nil, [24]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes25", + nil, [25]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes26", + nil, [26]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes27", + nil, [27]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes28", + nil, [28]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes29", + nil, [29]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes30", + nil, [30]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes31", + nil, [31]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "bytes32", + nil, [32]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "uint32[2][3][4]", + nil, [4][3][2]uint32{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}, {{13, 14}, {15, 16}, {17, 18}}, {{19, 20}, {21, 22}, {23, 24}}}, 
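+			// a nested fixed-size array is fully static: its 24 words are
+			// packed inline in index order, with no offset or length words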
common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000050000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000700000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000d000000000000000000000000000000000000000000000000000000000000000e000000000000000000000000000000000000000000000000000000000000000f000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000110000000000000000000000000000000000000000000000000000000000000012000000000000000000000000000000000000000000000000000000000000001300000000000000000000000000000000000000000000000000000000000000140000000000000000000000000000000000000000000000000000000000000015000000000000000000000000000000000000000000000000000000000000001600000000000000000000000000000000000000000000000000000000000000170000000000000000000000000000000000000000000000000000000000000018"), }, { "address[]", + nil, []common.Address{{1}, {2}}, common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000001000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000"), }, { "bytes32[]", + nil, []common.Hash{{1}, {2}}, common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000201000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000"), }, { "function", + nil, [24]byte{1}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), }, { "string", + nil, "foobar", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000006666f6f6261720000000000000000000000000000000000000000000000000000"), }, { "string[]", + nil, []string{"hello", "foobar"}, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002" + // len(array) = 2 "0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0 @@ -337,6 +390,7 @@ func TestPack(t *testing.T) { }, { "string[2]", + nil, []string{"hello", "foobar"}, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040" + // offset to i = 0 "0000000000000000000000000000000000000000000000000000000000000080" + // offset to i = 1 @@ -347,6 +401,7 @@ func TestPack(t *testing.T) { }, { "bytes32[][]", + nil, [][]common.Hash{{{1}, {2}}, {{3}, {4}, {5}}}, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002" + // len(array) = 2 "0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0 @@ -362,6 +417,7 @@ func TestPack(t *testing.T) { { "bytes32[][2]", + nil, [][]common.Hash{{{1}, {2}}, {{3}, {4}, {5}}}, common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0 "00000000000000000000000000000000000000000000000000000000000000a0" + 
// offset 160 to i = 1 @@ -376,6 +432,7 @@ func TestPack(t *testing.T) { { "bytes32[3][2]", + nil, [][]common.Hash{{{1}, {2}, {3}}, {{3}, {4}, {5}}}, common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0] "0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1] @@ -384,12 +441,182 @@ func TestPack(t *testing.T) { "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1] "0500000000000000000000000000000000000000000000000000000000000000"), // array[1][2] }, + { + // static tuple + "tuple", + []ArgumentMarshaling{ + {Name: "a", Type: "int64"}, + {Name: "b", Type: "int256"}, + {Name: "c", Type: "int256"}, + {Name: "d", Type: "bool"}, + {Name: "e", Type: "bytes32[3][2]"}, + }, + struct { + A int64 + B *big.Int + C *big.Int + D bool + E [][]common.Hash + }{1, big.NewInt(1), big.NewInt(-1), true, [][]common.Hash{{{1}, {2}, {3}}, {{3}, {4}, {5}}}}, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001" + // struct[a] + "0000000000000000000000000000000000000000000000000000000000000001" + // struct[b] + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // struct[c] + "0000000000000000000000000000000000000000000000000000000000000001" + // struct[d] + "0100000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[0][0] + "0200000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[0][1] + "0300000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[0][2] + "0300000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[1][0] + "0400000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[1][1] + "0500000000000000000000000000000000000000000000000000000000000000"), // struct[e] array[1][2] + }, + { + // dynamic tuple + "tuple", + []ArgumentMarshaling{ + {Name: "a", Type: "string"}, + {Name: "b", Type: "int64"}, + {Name: "c", Type: "bytes"}, + {Name: "d", Type: "string[]"}, + {Name: "e", Type: "int256[]"}, + {Name: "f", Type: "address[]"}, + }, + struct { + FieldA string `abi:"a"` // Test whether abi tag works + FieldB int64 `abi:"b"` + C []byte + D []string + E []*big.Int + F []common.Address + }{"foobar", 1, []byte{1}, []string{"foo", "bar"}, []*big.Int{big.NewInt(1), big.NewInt(-1)}, []common.Address{{1}, {2}}}, + common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000000000c0" + // struct[a] offset + "0000000000000000000000000000000000000000000000000000000000000001" + // struct[b] + "0000000000000000000000000000000000000000000000000000000000000100" + // struct[c] offset + "0000000000000000000000000000000000000000000000000000000000000140" + // struct[d] offset + "0000000000000000000000000000000000000000000000000000000000000220" + // struct[e] offset + "0000000000000000000000000000000000000000000000000000000000000280" + // struct[f] offset + "0000000000000000000000000000000000000000000000000000000000000006" + // struct[a] length + "666f6f6261720000000000000000000000000000000000000000000000000000" + // struct[a] "foobar" + "0000000000000000000000000000000000000000000000000000000000000001" + // struct[c] length + "0100000000000000000000000000000000000000000000000000000000000000" + // []byte{1} + "0000000000000000000000000000000000000000000000000000000000000002" + // struct[d] length + "0000000000000000000000000000000000000000000000000000000000000040" + // foo offset + 
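+				// element offsets inside struct[d] (a string[]) are relative to
+				// the start of the array's own data area, not to the tuple head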
"0000000000000000000000000000000000000000000000000000000000000080" + // bar offset + "0000000000000000000000000000000000000000000000000000000000000003" + // foo length + "666f6f0000000000000000000000000000000000000000000000000000000000" + // foo + "0000000000000000000000000000000000000000000000000000000000000003" + // bar offset + "6261720000000000000000000000000000000000000000000000000000000000" + // bar + "0000000000000000000000000000000000000000000000000000000000000002" + // struct[e] length + "0000000000000000000000000000000000000000000000000000000000000001" + // 1 + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // -1 + "0000000000000000000000000000000000000000000000000000000000000002" + // struct[f] length + "0000000000000000000000000100000000000000000000000000000000000000" + // common.Address{1} + "0000000000000000000000000200000000000000000000000000000000000000"), // common.Address{2} + }, + { + // nested tuple + "tuple", + []ArgumentMarshaling{ + {Name: "a", Type: "tuple", Components: []ArgumentMarshaling{{Name: "a", Type: "uint256"}, {Name: "b", Type: "uint256[]"}}}, + {Name: "b", Type: "int256[]"}, + }, + struct { + A struct { + FieldA *big.Int `abi:"a"` + B []*big.Int + } + B []*big.Int + }{ + A: struct { + FieldA *big.Int `abi:"a"` // Test whether abi tag works for nested tuple + B []*big.Int + }{big.NewInt(1), []*big.Int{big.NewInt(1), big.NewInt(0)}}, + B: []*big.Int{big.NewInt(1), big.NewInt(0)}}, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040" + // a offset + "00000000000000000000000000000000000000000000000000000000000000e0" + // b offset + "0000000000000000000000000000000000000000000000000000000000000001" + // a.a value + "0000000000000000000000000000000000000000000000000000000000000040" + // a.b offset + "0000000000000000000000000000000000000000000000000000000000000002" + // a.b length + "0000000000000000000000000000000000000000000000000000000000000001" + // a.b[0] value + "0000000000000000000000000000000000000000000000000000000000000000" + // a.b[1] value + "0000000000000000000000000000000000000000000000000000000000000002" + // b length + "0000000000000000000000000000000000000000000000000000000000000001" + // b[0] value + "0000000000000000000000000000000000000000000000000000000000000000"), // b[1] value + }, + { + // tuple slice + "tuple[]", + []ArgumentMarshaling{ + {Name: "a", Type: "int256"}, + {Name: "b", Type: "int256[]"}, + }, + []struct { + A *big.Int + B []*big.Int + }{ + {big.NewInt(-1), []*big.Int{big.NewInt(1), big.NewInt(0)}}, + {big.NewInt(1), []*big.Int{big.NewInt(2), big.NewInt(-1)}}, + }, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002" + // tuple length + "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[0] offset + "00000000000000000000000000000000000000000000000000000000000000e0" + // tuple[1] offset + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // tuple[0].A + "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[0].B offset + "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[0].B length + "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[0].B[0] value + "0000000000000000000000000000000000000000000000000000000000000000" + // tuple[0].B[1] value + "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[1].A + "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[1].B offset + 
"0000000000000000000000000000000000000000000000000000000000000002" + // tuple[1].B length + "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[1].B[0] value + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), // tuple[1].B[1] value + }, + { + // static tuple array + "tuple[2]", + []ArgumentMarshaling{ + {Name: "a", Type: "int256"}, + {Name: "b", Type: "int256"}, + }, + [2]struct { + A *big.Int + B *big.Int + }{ + {big.NewInt(-1), big.NewInt(1)}, + {big.NewInt(1), big.NewInt(-1)}, + }, + common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // tuple[0].a + "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[0].b + "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[1].a + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), // tuple[1].b + }, + { + // dynamic tuple array + "tuple[2]", + []ArgumentMarshaling{ + {Name: "a", Type: "int256[]"}, + }, + [2]struct { + A []*big.Int + }{ + {[]*big.Int{big.NewInt(-1), big.NewInt(1)}}, + {[]*big.Int{big.NewInt(1), big.NewInt(-1)}}, + }, + common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040" + // tuple[0] offset + "00000000000000000000000000000000000000000000000000000000000000c0" + // tuple[1] offset + "0000000000000000000000000000000000000000000000000000000000000020" + // tuple[0].A offset + "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[0].A length + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // tuple[0].A[0] + "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[0].A[1] + "0000000000000000000000000000000000000000000000000000000000000020" + // tuple[1].A offset + "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[1].A length + "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[1].A[0] + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), // tuple[1].A[1] + }, } { - typ, err := NewType(test.typ) + typ, err := NewType(test.typ, test.components) if err != nil { t.Fatalf("%v failed. Unexpected parse error: %v", i, err) } - output, err := typ.pack(reflect.ValueOf(test.input)) if err != nil { t.Fatalf("%v failed. Unexpected pack error: %v", i, err) @@ -466,6 +693,59 @@ func TestMethodPack(t *testing.T) { if !bytes.Equal(packed, sig) { t.Errorf("expected %x got %x", sig, packed) } + + a := [2][2]*big.Int{{big.NewInt(1), big.NewInt(1)}, {big.NewInt(2), big.NewInt(0)}} + sig = abi.Methods["nestedArray"].Id() + sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{0}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{0xa0}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes(addrC[:], 32)...) + sig = append(sig, common.LeftPadBytes(addrD[:], 32)...) + packed, err = abi.Pack("nestedArray", a, []common.Address{addrC, addrD}) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(packed, sig) { + t.Errorf("expected %x got %x", sig, packed) + } + + sig = abi.Methods["nestedArray2"].Id() + sig = append(sig, common.LeftPadBytes([]byte{0x20}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{0x40}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{0x80}, 32)...) 
+ sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) + packed, err = abi.Pack("nestedArray2", [2][]uint8{{1}, {1}}) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(packed, sig) { + t.Errorf("expected %x got %x", sig, packed) + } + + sig = abi.Methods["nestedSlice"].Id() + sig = append(sig, common.LeftPadBytes([]byte{0x20}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{0x02}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{0x40}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{0xa0}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) + sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) + packed, err = abi.Pack("nestedSlice", [][]uint8{{1, 2}, {1, 2}}) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(packed, sig) { + t.Errorf("expected %x got %x", sig, packed) + } } func TestPackNumber(t *testing.T) { diff --git a/accounts/abi/type_test.go b/accounts/abi/type_test.go index fc23b07520..48da0c4ef5 100644 --- a/accounts/abi/type_test.go +++ b/accounts/abi/type_test.go @@ -32,72 +32,75 @@ type typeWithoutStringer Type // Tests that all allowed types get recognized by the type parser. func TestTypeRegexp(t *testing.T) { tests := []struct { - blob string - kind Type + blob string + components []ArgumentMarshaling + kind Type }{ - {"bool", Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}}, - {"bool[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]bool(nil)), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}}, - {"bool[2]", Type{Size: 2, Kind: reflect.Array, T: ArrayTy, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}}, - {"bool[2][]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][]"}}, - {"bool[][]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([][]bool{}), Elem: &Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][]"}}, - {"bool[][2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][]bool{}), Elem: &Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][2]"}}, - {"bool[2][2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][2]"}}, - 
{"bool[2][][2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][][2]bool{}), Elem: &Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][]"}, stringKind: "bool[2][][2]"}}, - {"bool[2][2][2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][2][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][2]"}, stringKind: "bool[2][2][2]"}}, - {"bool[][][]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][][]bool{}), Elem: &Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][]bool{}), Elem: &Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][]"}, stringKind: "bool[][][]"}}, - {"bool[][2][]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][2][]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][]bool{}), Elem: &Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][2]"}, stringKind: "bool[][2][]"}}, - {"int8", Type{Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, stringKind: "int8"}}, - {"int16", Type{Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, stringKind: "int16"}}, - {"int32", Type{Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, stringKind: "int32"}}, - {"int64", Type{Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, stringKind: "int64"}}, - {"int256", Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}}, - {"int8[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int8{}), Elem: &Type{Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[]"}}, - {"int8[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]int8{}), Elem: &Type{Kind: reflect.Int8, Type: int8_t, Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[2]"}}, - {"int16[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int16{}), Elem: &Type{Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[]"}}, - {"int16[2]", Type{Size: 2, Kind: reflect.Array, T: ArrayTy, Type: reflect.TypeOf([2]int16{}), Elem: &Type{Kind: reflect.Int16, Type: int16_t, Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[2]"}}, - {"int32[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int32{}), Elem: &Type{Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[]"}}, - {"int32[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]int32{}), Elem: &Type{Kind: reflect.Int32, Type: int32_t, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[2]"}}, - {"int64[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int64{}), 
Elem: &Type{Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[]"}}, - {"int64[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]int64{}), Elem: &Type{Kind: reflect.Int64, Type: int64_t, Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[2]"}}, - {"int256[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]*big.Int{}), Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[]"}}, - {"int256[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]*big.Int{}), Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[2]"}}, - {"uint8", Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}}, - {"uint16", Type{Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, stringKind: "uint16"}}, - {"uint32", Type{Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, stringKind: "uint32"}}, - {"uint64", Type{Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, stringKind: "uint64"}}, - {"uint256", Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, stringKind: "uint256"}}, - {"uint8[]", Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]uint8{}), Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[]"}}, - {"uint8[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint8{}), Elem: &Type{Kind: reflect.Uint8, Type: uint8_t, Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[2]"}}, - {"uint16[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]uint16{}), Elem: &Type{Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[]"}}, - {"uint16[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint16{}), Elem: &Type{Kind: reflect.Uint16, Type: uint16_t, Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[2]"}}, - {"uint32[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]uint32{}), Elem: &Type{Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[]"}}, - {"uint32[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint32{}), Elem: &Type{Kind: reflect.Uint32, Type: uint32_t, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[2]"}}, - {"uint64[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]uint64{}), Elem: &Type{Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[]"}}, - {"uint64[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint64{}), Elem: &Type{Kind: reflect.Uint64, Type: uint64_t, Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[2]"}}, - {"uint256[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]*big.Int{}), Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[]"}}, - {"uint256[2]", Type{Kind: reflect.Array, T: ArrayTy, Type: reflect.TypeOf([2]*big.Int{}), Size: 2, Elem: &Type{Kind: reflect.Ptr, Type: big_t, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[2]"}}, - {"bytes32", Type{Kind: reflect.Array, T: FixedBytesTy, Size: 32, Type: reflect.TypeOf([32]byte{}), stringKind: "bytes32"}}, - {"bytes[]", Type{T: SliceTy, Kind: reflect.Slice, Type: 
reflect.TypeOf([][]byte{}), Elem: &Type{Kind: reflect.Slice, Type: reflect.TypeOf([]byte{}), T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[]"}}, - {"bytes[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][]byte{}), Elem: &Type{T: BytesTy, Type: reflect.TypeOf([]byte{}), Kind: reflect.Slice, stringKind: "bytes"}, stringKind: "bytes[2]"}}, - {"bytes32[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][32]byte{}), Elem: &Type{Kind: reflect.Array, Type: reflect.TypeOf([32]byte{}), T: FixedBytesTy, Size: 32, stringKind: "bytes32"}, stringKind: "bytes32[]"}}, - {"bytes32[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][32]byte{}), Elem: &Type{Kind: reflect.Array, T: FixedBytesTy, Size: 32, Type: reflect.TypeOf([32]byte{}), stringKind: "bytes32"}, stringKind: "bytes32[2]"}}, - {"string", Type{Kind: reflect.String, T: StringTy, Type: reflect.TypeOf(""), stringKind: "string"}}, - {"string[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]string{}), Elem: &Type{Kind: reflect.String, Type: reflect.TypeOf(""), T: StringTy, stringKind: "string"}, stringKind: "string[]"}}, - {"string[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]string{}), Elem: &Type{Kind: reflect.String, T: StringTy, Type: reflect.TypeOf(""), stringKind: "string"}, stringKind: "string[2]"}}, - {"address", Type{Kind: reflect.Array, Type: address_t, Size: 20, T: AddressTy, stringKind: "address"}}, - {"address[]", Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]common.Address{}), Elem: &Type{Kind: reflect.Array, Type: address_t, Size: 20, T: AddressTy, stringKind: "address"}, stringKind: "address[]"}}, - {"address[2]", Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]common.Address{}), Elem: &Type{Kind: reflect.Array, Type: address_t, Size: 20, T: AddressTy, stringKind: "address"}, stringKind: "address[2]"}}, + {"bool", nil, Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}}, + {"bool[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]bool(nil)), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}}, + {"bool[2]", nil, Type{Size: 2, Kind: reflect.Array, T: ArrayTy, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}}, + {"bool[2][]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][]"}}, + {"bool[][]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([][]bool{}), Elem: &Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][]"}}, + {"bool[][2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][]bool{}), Elem: &Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][2]"}}, + {"bool[2][2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: 
reflect.TypeOf([2][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][2]"}}, + {"bool[2][][2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][][2]bool{}), Elem: &Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][]"}, stringKind: "bool[2][][2]"}}, + {"bool[2][2][2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][2][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][2]"}, stringKind: "bool[2][2][2]"}}, + {"bool[][][]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][][]bool{}), Elem: &Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][]bool{}), Elem: &Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][]"}, stringKind: "bool[][][]"}}, + {"bool[][2][]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][2][]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][]bool{}), Elem: &Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][2]"}, stringKind: "bool[][2][]"}}, + {"int8", nil, Type{Kind: reflect.Int8, Type: int8T, Size: 8, T: IntTy, stringKind: "int8"}}, + {"int16", nil, Type{Kind: reflect.Int16, Type: int16T, Size: 16, T: IntTy, stringKind: "int16"}}, + {"int32", nil, Type{Kind: reflect.Int32, Type: int32T, Size: 32, T: IntTy, stringKind: "int32"}}, + {"int64", nil, Type{Kind: reflect.Int64, Type: int64T, Size: 64, T: IntTy, stringKind: "int64"}}, + {"int256", nil, Type{Kind: reflect.Ptr, Type: bigT, Size: 256, T: IntTy, stringKind: "int256"}}, + {"int8[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int8{}), Elem: &Type{Kind: reflect.Int8, Type: int8T, Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[]"}}, + {"int8[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]int8{}), Elem: &Type{Kind: reflect.Int8, Type: int8T, Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[2]"}}, + {"int16[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int16{}), Elem: &Type{Kind: reflect.Int16, Type: int16T, Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[]"}}, + {"int16[2]", nil, Type{Size: 2, Kind: reflect.Array, T: ArrayTy, Type: reflect.TypeOf([2]int16{}), Elem: &Type{Kind: reflect.Int16, Type: int16T, Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[2]"}}, + {"int32[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int32{}), Elem: &Type{Kind: reflect.Int32, Type: int32T, Size: 32, T: IntTy, stringKind: 
"int32"}, stringKind: "int32[]"}}, + {"int32[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]int32{}), Elem: &Type{Kind: reflect.Int32, Type: int32T, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[2]"}}, + {"int64[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int64{}), Elem: &Type{Kind: reflect.Int64, Type: int64T, Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[]"}}, + {"int64[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]int64{}), Elem: &Type{Kind: reflect.Int64, Type: int64T, Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[2]"}}, + {"int256[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]*big.Int{}), Elem: &Type{Kind: reflect.Ptr, Type: bigT, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[]"}}, + {"int256[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]*big.Int{}), Elem: &Type{Kind: reflect.Ptr, Type: bigT, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[2]"}}, + {"uint8", nil, Type{Kind: reflect.Uint8, Type: uint8T, Size: 8, T: UintTy, stringKind: "uint8"}}, + {"uint16", nil, Type{Kind: reflect.Uint16, Type: uint16T, Size: 16, T: UintTy, stringKind: "uint16"}}, + {"uint32", nil, Type{Kind: reflect.Uint32, Type: uint32T, Size: 32, T: UintTy, stringKind: "uint32"}}, + {"uint64", nil, Type{Kind: reflect.Uint64, Type: uint64T, Size: 64, T: UintTy, stringKind: "uint64"}}, + {"uint256", nil, Type{Kind: reflect.Ptr, Type: bigT, Size: 256, T: UintTy, stringKind: "uint256"}}, + {"uint8[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]uint8{}), Elem: &Type{Kind: reflect.Uint8, Type: uint8T, Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[]"}}, + {"uint8[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint8{}), Elem: &Type{Kind: reflect.Uint8, Type: uint8T, Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[2]"}}, + {"uint16[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]uint16{}), Elem: &Type{Kind: reflect.Uint16, Type: uint16T, Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[]"}}, + {"uint16[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint16{}), Elem: &Type{Kind: reflect.Uint16, Type: uint16T, Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[2]"}}, + {"uint32[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]uint32{}), Elem: &Type{Kind: reflect.Uint32, Type: uint32T, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[]"}}, + {"uint32[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint32{}), Elem: &Type{Kind: reflect.Uint32, Type: uint32T, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[2]"}}, + {"uint64[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]uint64{}), Elem: &Type{Kind: reflect.Uint64, Type: uint64T, Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[]"}}, + {"uint64[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint64{}), Elem: &Type{Kind: reflect.Uint64, Type: uint64T, Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[2]"}}, + {"uint256[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]*big.Int{}), Elem: &Type{Kind: reflect.Ptr, Type: bigT, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[]"}}, + 
{"uint256[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Type: reflect.TypeOf([2]*big.Int{}), Size: 2, Elem: &Type{Kind: reflect.Ptr, Type: bigT, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[2]"}}, + {"bytes32", nil, Type{Kind: reflect.Array, T: FixedBytesTy, Size: 32, Type: reflect.TypeOf([32]byte{}), stringKind: "bytes32"}}, + {"bytes[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][]byte{}), Elem: &Type{Kind: reflect.Slice, Type: reflect.TypeOf([]byte{}), T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[]"}}, + {"bytes[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][]byte{}), Elem: &Type{T: BytesTy, Type: reflect.TypeOf([]byte{}), Kind: reflect.Slice, stringKind: "bytes"}, stringKind: "bytes[2]"}}, + {"bytes32[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][32]byte{}), Elem: &Type{Kind: reflect.Array, Type: reflect.TypeOf([32]byte{}), T: FixedBytesTy, Size: 32, stringKind: "bytes32"}, stringKind: "bytes32[]"}}, + {"bytes32[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][32]byte{}), Elem: &Type{Kind: reflect.Array, T: FixedBytesTy, Size: 32, Type: reflect.TypeOf([32]byte{}), stringKind: "bytes32"}, stringKind: "bytes32[2]"}}, + {"string", nil, Type{Kind: reflect.String, T: StringTy, Type: reflect.TypeOf(""), stringKind: "string"}}, + {"string[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]string{}), Elem: &Type{Kind: reflect.String, Type: reflect.TypeOf(""), T: StringTy, stringKind: "string"}, stringKind: "string[]"}}, + {"string[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]string{}), Elem: &Type{Kind: reflect.String, T: StringTy, Type: reflect.TypeOf(""), stringKind: "string"}, stringKind: "string[2]"}}, + {"address", nil, Type{Kind: reflect.Array, Type: addressT, Size: 20, T: AddressTy, stringKind: "address"}}, + {"address[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]common.Address{}), Elem: &Type{Kind: reflect.Array, Type: addressT, Size: 20, T: AddressTy, stringKind: "address"}, stringKind: "address[]"}}, + {"address[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]common.Address{}), Elem: &Type{Kind: reflect.Array, Type: addressT, Size: 20, T: AddressTy, stringKind: "address"}, stringKind: "address[2]"}}, // TODO when fixed types are implemented properly - // {"fixed", Type{}}, - // {"fixed128x128", Type{}}, - // {"fixed[]", Type{}}, - // {"fixed[2]", Type{}}, - // {"fixed128x128[]", Type{}}, - // {"fixed128x128[2]", Type{}}, + // {"fixed", nil, Type{}}, + // {"fixed128x128", nil, Type{}}, + // {"fixed[]", nil, Type{}}, + // {"fixed[2]", nil, Type{}}, + // {"fixed128x128[]", nil, Type{}}, + // {"fixed128x128[2]", nil, Type{}}, + {"tuple", []ArgumentMarshaling{{Name: "a", Type: "int64"}}, Type{Kind: reflect.Struct, T: TupleTy, Type: reflect.TypeOf(struct{ A int64 }{}), stringKind: "(int64)", + TupleElems: []*Type{{Kind: reflect.Int64, T: IntTy, Type: reflect.TypeOf(int64(0)), Size: 64, stringKind: "int64"}}, TupleRawNames: []string{"a"}}}, } for _, tt := range tests { - typ, err := NewType(tt.blob) + typ, err := NewType(tt.blob, tt.components) if err != nil { t.Errorf("type %q: failed to parse type string: %v", tt.blob, err) } @@ -109,151 +112,170 @@ func TestTypeRegexp(t *testing.T) { func TestTypeCheck(t *testing.T) { for i, test := range []struct { - typ string - input interface{} - err string + typ string + components 
[]ArgumentMarshaling + input interface{} + err string }{ - {"uint", big.NewInt(1), "unsupported arg type: uint"}, - {"int", big.NewInt(1), "unsupported arg type: int"}, - {"uint256", big.NewInt(1), ""}, - {"uint256[][3][]", [][3][]*big.Int{{{}}}, ""}, - {"uint256[][][3]", [3][][]*big.Int{{{}}}, ""}, - {"uint256[3][][]", [][][3]*big.Int{{{}}}, ""}, - {"uint256[3][3][3]", [3][3][3]*big.Int{{{}}}, ""}, - {"uint8[][]", [][]uint8{}, ""}, - {"int256", big.NewInt(1), ""}, - {"uint8", uint8(1), ""}, - {"uint16", uint16(1), ""}, - {"uint32", uint32(1), ""}, - {"uint64", uint64(1), ""}, - {"int8", int8(1), ""}, - {"int16", int16(1), ""}, - {"int32", int32(1), ""}, - {"int64", int64(1), ""}, - {"uint24", big.NewInt(1), ""}, - {"uint40", big.NewInt(1), ""}, - {"uint48", big.NewInt(1), ""}, - {"uint56", big.NewInt(1), ""}, - {"uint72", big.NewInt(1), ""}, - {"uint80", big.NewInt(1), ""}, - {"uint88", big.NewInt(1), ""}, - {"uint96", big.NewInt(1), ""}, - {"uint104", big.NewInt(1), ""}, - {"uint112", big.NewInt(1), ""}, - {"uint120", big.NewInt(1), ""}, - {"uint128", big.NewInt(1), ""}, - {"uint136", big.NewInt(1), ""}, - {"uint144", big.NewInt(1), ""}, - {"uint152", big.NewInt(1), ""}, - {"uint160", big.NewInt(1), ""}, - {"uint168", big.NewInt(1), ""}, - {"uint176", big.NewInt(1), ""}, - {"uint184", big.NewInt(1), ""}, - {"uint192", big.NewInt(1), ""}, - {"uint200", big.NewInt(1), ""}, - {"uint208", big.NewInt(1), ""}, - {"uint216", big.NewInt(1), ""}, - {"uint224", big.NewInt(1), ""}, - {"uint232", big.NewInt(1), ""}, - {"uint240", big.NewInt(1), ""}, - {"uint248", big.NewInt(1), ""}, - {"int24", big.NewInt(1), ""}, - {"int40", big.NewInt(1), ""}, - {"int48", big.NewInt(1), ""}, - {"int56", big.NewInt(1), ""}, - {"int72", big.NewInt(1), ""}, - {"int80", big.NewInt(1), ""}, - {"int88", big.NewInt(1), ""}, - {"int96", big.NewInt(1), ""}, - {"int104", big.NewInt(1), ""}, - {"int112", big.NewInt(1), ""}, - {"int120", big.NewInt(1), ""}, - {"int128", big.NewInt(1), ""}, - {"int136", big.NewInt(1), ""}, - {"int144", big.NewInt(1), ""}, - {"int152", big.NewInt(1), ""}, - {"int160", big.NewInt(1), ""}, - {"int168", big.NewInt(1), ""}, - {"int176", big.NewInt(1), ""}, - {"int184", big.NewInt(1), ""}, - {"int192", big.NewInt(1), ""}, - {"int200", big.NewInt(1), ""}, - {"int208", big.NewInt(1), ""}, - {"int216", big.NewInt(1), ""}, - {"int224", big.NewInt(1), ""}, - {"int232", big.NewInt(1), ""}, - {"int240", big.NewInt(1), ""}, - {"int248", big.NewInt(1), ""}, - {"uint30", uint8(1), "abi: cannot use uint8 as type ptr as argument"}, - {"uint8", uint16(1), "abi: cannot use uint16 as type uint8 as argument"}, - {"uint8", uint32(1), "abi: cannot use uint32 as type uint8 as argument"}, - {"uint8", uint64(1), "abi: cannot use uint64 as type uint8 as argument"}, - {"uint8", int8(1), "abi: cannot use int8 as type uint8 as argument"}, - {"uint8", int16(1), "abi: cannot use int16 as type uint8 as argument"}, - {"uint8", int32(1), "abi: cannot use int32 as type uint8 as argument"}, - {"uint8", int64(1), "abi: cannot use int64 as type uint8 as argument"}, - {"uint16", uint16(1), ""}, - {"uint16", uint8(1), "abi: cannot use uint8 as type uint16 as argument"}, - {"uint16[]", []uint16{1, 2, 3}, ""}, - {"uint16[]", [3]uint16{1, 2, 3}, ""}, - {"uint16[]", []uint32{1, 2, 3}, "abi: cannot use []uint32 as type [0]uint16 as argument"}, - {"uint16[3]", [3]uint32{1, 2, 3}, "abi: cannot use [3]uint32 as type [3]uint16 as argument"}, - {"uint16[3]", [4]uint16{1, 2, 3}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"}, - 
{"uint16[3]", []uint16{1, 2, 3}, ""}, - {"uint16[3]", []uint16{1, 2, 3, 4}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"}, - {"address[]", []common.Address{{1}}, ""}, - {"address[1]", []common.Address{{1}}, ""}, - {"address[1]", [1]common.Address{{1}}, ""}, - {"address[2]", [1]common.Address{{1}}, "abi: cannot use [1]array as type [2]array as argument"}, - {"bytes32", [32]byte{}, ""}, - {"bytes31", [31]byte{}, ""}, - {"bytes30", [30]byte{}, ""}, - {"bytes29", [29]byte{}, ""}, - {"bytes28", [28]byte{}, ""}, - {"bytes27", [27]byte{}, ""}, - {"bytes26", [26]byte{}, ""}, - {"bytes25", [25]byte{}, ""}, - {"bytes24", [24]byte{}, ""}, - {"bytes23", [23]byte{}, ""}, - {"bytes22", [22]byte{}, ""}, - {"bytes21", [21]byte{}, ""}, - {"bytes20", [20]byte{}, ""}, - {"bytes19", [19]byte{}, ""}, - {"bytes18", [18]byte{}, ""}, - {"bytes17", [17]byte{}, ""}, - {"bytes16", [16]byte{}, ""}, - {"bytes15", [15]byte{}, ""}, - {"bytes14", [14]byte{}, ""}, - {"bytes13", [13]byte{}, ""}, - {"bytes12", [12]byte{}, ""}, - {"bytes11", [11]byte{}, ""}, - {"bytes10", [10]byte{}, ""}, - {"bytes9", [9]byte{}, ""}, - {"bytes8", [8]byte{}, ""}, - {"bytes7", [7]byte{}, ""}, - {"bytes6", [6]byte{}, ""}, - {"bytes5", [5]byte{}, ""}, - {"bytes4", [4]byte{}, ""}, - {"bytes3", [3]byte{}, ""}, - {"bytes2", [2]byte{}, ""}, - {"bytes1", [1]byte{}, ""}, - {"bytes32", [33]byte{}, "abi: cannot use [33]uint8 as type [32]uint8 as argument"}, - {"bytes32", common.Hash{1}, ""}, - {"bytes31", common.Hash{1}, "abi: cannot use common.Hash as type [31]uint8 as argument"}, - {"bytes31", [32]byte{}, "abi: cannot use [32]uint8 as type [31]uint8 as argument"}, - {"bytes", []byte{0, 1}, ""}, - {"bytes", [2]byte{0, 1}, "abi: cannot use array as type slice as argument"}, - {"bytes", common.Hash{1}, "abi: cannot use array as type slice as argument"}, - {"string", "hello world", ""}, - {"string", string(""), ""}, - {"string", []byte{}, "abi: cannot use slice as type string as argument"}, - {"bytes32[]", [][32]byte{{}}, ""}, - {"function", [24]byte{}, ""}, - {"bytes20", common.Address{}, ""}, - {"address", [20]byte{}, ""}, - {"address", common.Address{}, ""}, + {"uint", nil, big.NewInt(1), "unsupported arg type: uint"}, + {"int", nil, big.NewInt(1), "unsupported arg type: int"}, + {"uint256", nil, big.NewInt(1), ""}, + {"uint256[][3][]", nil, [][3][]*big.Int{{{}}}, ""}, + {"uint256[][][3]", nil, [3][][]*big.Int{{{}}}, ""}, + {"uint256[3][][]", nil, [][][3]*big.Int{{{}}}, ""}, + {"uint256[3][3][3]", nil, [3][3][3]*big.Int{{{}}}, ""}, + {"uint8[][]", nil, [][]uint8{}, ""}, + {"int256", nil, big.NewInt(1), ""}, + {"uint8", nil, uint8(1), ""}, + {"uint16", nil, uint16(1), ""}, + {"uint32", nil, uint32(1), ""}, + {"uint64", nil, uint64(1), ""}, + {"int8", nil, int8(1), ""}, + {"int16", nil, int16(1), ""}, + {"int32", nil, int32(1), ""}, + {"int64", nil, int64(1), ""}, + {"uint24", nil, big.NewInt(1), ""}, + {"uint40", nil, big.NewInt(1), ""}, + {"uint48", nil, big.NewInt(1), ""}, + {"uint56", nil, big.NewInt(1), ""}, + {"uint72", nil, big.NewInt(1), ""}, + {"uint80", nil, big.NewInt(1), ""}, + {"uint88", nil, big.NewInt(1), ""}, + {"uint96", nil, big.NewInt(1), ""}, + {"uint104", nil, big.NewInt(1), ""}, + {"uint112", nil, big.NewInt(1), ""}, + {"uint120", nil, big.NewInt(1), ""}, + {"uint128", nil, big.NewInt(1), ""}, + {"uint136", nil, big.NewInt(1), ""}, + {"uint144", nil, big.NewInt(1), ""}, + {"uint152", nil, big.NewInt(1), ""}, + {"uint160", nil, big.NewInt(1), ""}, + {"uint168", nil, big.NewInt(1), ""}, + {"uint176", nil, big.NewInt(1), 
""}, + {"uint184", nil, big.NewInt(1), ""}, + {"uint192", nil, big.NewInt(1), ""}, + {"uint200", nil, big.NewInt(1), ""}, + {"uint208", nil, big.NewInt(1), ""}, + {"uint216", nil, big.NewInt(1), ""}, + {"uint224", nil, big.NewInt(1), ""}, + {"uint232", nil, big.NewInt(1), ""}, + {"uint240", nil, big.NewInt(1), ""}, + {"uint248", nil, big.NewInt(1), ""}, + {"int24", nil, big.NewInt(1), ""}, + {"int40", nil, big.NewInt(1), ""}, + {"int48", nil, big.NewInt(1), ""}, + {"int56", nil, big.NewInt(1), ""}, + {"int72", nil, big.NewInt(1), ""}, + {"int80", nil, big.NewInt(1), ""}, + {"int88", nil, big.NewInt(1), ""}, + {"int96", nil, big.NewInt(1), ""}, + {"int104", nil, big.NewInt(1), ""}, + {"int112", nil, big.NewInt(1), ""}, + {"int120", nil, big.NewInt(1), ""}, + {"int128", nil, big.NewInt(1), ""}, + {"int136", nil, big.NewInt(1), ""}, + {"int144", nil, big.NewInt(1), ""}, + {"int152", nil, big.NewInt(1), ""}, + {"int160", nil, big.NewInt(1), ""}, + {"int168", nil, big.NewInt(1), ""}, + {"int176", nil, big.NewInt(1), ""}, + {"int184", nil, big.NewInt(1), ""}, + {"int192", nil, big.NewInt(1), ""}, + {"int200", nil, big.NewInt(1), ""}, + {"int208", nil, big.NewInt(1), ""}, + {"int216", nil, big.NewInt(1), ""}, + {"int224", nil, big.NewInt(1), ""}, + {"int232", nil, big.NewInt(1), ""}, + {"int240", nil, big.NewInt(1), ""}, + {"int248", nil, big.NewInt(1), ""}, + {"uint30", nil, uint8(1), "abi: cannot use uint8 as type ptr as argument"}, + {"uint8", nil, uint16(1), "abi: cannot use uint16 as type uint8 as argument"}, + {"uint8", nil, uint32(1), "abi: cannot use uint32 as type uint8 as argument"}, + {"uint8", nil, uint64(1), "abi: cannot use uint64 as type uint8 as argument"}, + {"uint8", nil, int8(1), "abi: cannot use int8 as type uint8 as argument"}, + {"uint8", nil, int16(1), "abi: cannot use int16 as type uint8 as argument"}, + {"uint8", nil, int32(1), "abi: cannot use int32 as type uint8 as argument"}, + {"uint8", nil, int64(1), "abi: cannot use int64 as type uint8 as argument"}, + {"uint16", nil, uint16(1), ""}, + {"uint16", nil, uint8(1), "abi: cannot use uint8 as type uint16 as argument"}, + {"uint16[]", nil, []uint16{1, 2, 3}, ""}, + {"uint16[]", nil, [3]uint16{1, 2, 3}, ""}, + {"uint16[]", nil, []uint32{1, 2, 3}, "abi: cannot use []uint32 as type [0]uint16 as argument"}, + {"uint16[3]", nil, [3]uint32{1, 2, 3}, "abi: cannot use [3]uint32 as type [3]uint16 as argument"}, + {"uint16[3]", nil, [4]uint16{1, 2, 3}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"}, + {"uint16[3]", nil, []uint16{1, 2, 3}, ""}, + {"uint16[3]", nil, []uint16{1, 2, 3, 4}, "abi: cannot use [4]uint16 as type [3]uint16 as argument"}, + {"address[]", nil, []common.Address{{1}}, ""}, + {"address[1]", nil, []common.Address{{1}}, ""}, + {"address[1]", nil, [1]common.Address{{1}}, ""}, + {"address[2]", nil, [1]common.Address{{1}}, "abi: cannot use [1]array as type [2]array as argument"}, + {"bytes32", nil, [32]byte{}, ""}, + {"bytes31", nil, [31]byte{}, ""}, + {"bytes30", nil, [30]byte{}, ""}, + {"bytes29", nil, [29]byte{}, ""}, + {"bytes28", nil, [28]byte{}, ""}, + {"bytes27", nil, [27]byte{}, ""}, + {"bytes26", nil, [26]byte{}, ""}, + {"bytes25", nil, [25]byte{}, ""}, + {"bytes24", nil, [24]byte{}, ""}, + {"bytes23", nil, [23]byte{}, ""}, + {"bytes22", nil, [22]byte{}, ""}, + {"bytes21", nil, [21]byte{}, ""}, + {"bytes20", nil, [20]byte{}, ""}, + {"bytes19", nil, [19]byte{}, ""}, + {"bytes18", nil, [18]byte{}, ""}, + {"bytes17", nil, [17]byte{}, ""}, + {"bytes16", nil, [16]byte{}, ""}, + {"bytes15", nil, 
[15]byte{}, ""}, + {"bytes14", nil, [14]byte{}, ""}, + {"bytes13", nil, [13]byte{}, ""}, + {"bytes12", nil, [12]byte{}, ""}, + {"bytes11", nil, [11]byte{}, ""}, + {"bytes10", nil, [10]byte{}, ""}, + {"bytes9", nil, [9]byte{}, ""}, + {"bytes8", nil, [8]byte{}, ""}, + {"bytes7", nil, [7]byte{}, ""}, + {"bytes6", nil, [6]byte{}, ""}, + {"bytes5", nil, [5]byte{}, ""}, + {"bytes4", nil, [4]byte{}, ""}, + {"bytes3", nil, [3]byte{}, ""}, + {"bytes2", nil, [2]byte{}, ""}, + {"bytes1", nil, [1]byte{}, ""}, + {"bytes32", nil, [33]byte{}, "abi: cannot use [33]uint8 as type [32]uint8 as argument"}, + {"bytes32", nil, common.Hash{1}, ""}, + {"bytes31", nil, common.Hash{1}, "abi: cannot use common.Hash as type [31]uint8 as argument"}, + {"bytes31", nil, [32]byte{}, "abi: cannot use [32]uint8 as type [31]uint8 as argument"}, + {"bytes", nil, []byte{0, 1}, ""}, + {"bytes", nil, [2]byte{0, 1}, "abi: cannot use array as type slice as argument"}, + {"bytes", nil, common.Hash{1}, "abi: cannot use array as type slice as argument"}, + {"string", nil, "hello world", ""}, + {"string", nil, string(""), ""}, + {"string", nil, []byte{}, "abi: cannot use slice as type string as argument"}, + {"bytes32[]", nil, [][32]byte{{}}, ""}, + {"function", nil, [24]byte{}, ""}, + {"bytes20", nil, common.Address{}, ""}, + {"address", nil, [20]byte{}, ""}, + {"address", nil, common.Address{}, ""}, + {"bytes32[]]", nil, "", "invalid arg type in abi"}, + {"invalidType", nil, "", "unsupported arg type: invalidType"}, + {"invalidSlice[]", nil, "", "unsupported arg type: invalidSlice"}, + // simple tuple + {"tuple", []ArgumentMarshaling{{Name: "a", Type: "uint256"}, {Name: "b", Type: "uint256"}}, struct { + A *big.Int + B *big.Int + }{}, ""}, + // tuple slice + {"tuple[]", []ArgumentMarshaling{{Name: "a", Type: "uint256"}, {Name: "b", Type: "uint256"}}, []struct { + A *big.Int + B *big.Int + }{}, ""}, + // tuple array + {"tuple[2]", []ArgumentMarshaling{{Name: "a", Type: "uint256"}, {Name: "b", Type: "uint256"}}, []struct { + A *big.Int + B *big.Int + }{{big.NewInt(0), big.NewInt(0)}, {big.NewInt(0), big.NewInt(0)}}, ""}, } { - typ, err := NewType(test.typ) + typ, err := NewType(test.typ, test.components) if err != nil && len(test.err) == 0 { t.Fatal("unexpected parse error:", err) } else if err != nil && len(test.err) != 0 { diff --git a/accounts/abi/unpack_test.go b/accounts/abi/unpack_test.go index 24b5dc0f11..2906bec20a 100644 --- a/accounts/abi/unpack_test.go +++ b/accounts/abi/unpack_test.go @@ -56,6 +56,23 @@ var unpackTests = []unpackTest{ enc: "0000000000000000000000000000000000000000000000000000000000000001", want: true, }, + { + def: `[{ "type": "bool" }]`, + enc: "0000000000000000000000000000000000000000000000000000000000000000", + want: false, + }, + { + def: `[{ "type": "bool" }]`, + enc: "0000000000000000000000000000000000000000000000000001000000000001", + want: false, + err: "abi: improperly encoded boolean value", + }, + { + def: `[{ "type": "bool" }]`, + enc: "0000000000000000000000000000000000000000000000000000000000000003", + want: false, + err: "abi: improperly encoded boolean value", + }, { def: `[{"type": "uint32"}]`, enc: "0000000000000000000000000000000000000000000000000000000000000001", @@ -100,6 +117,11 @@ var unpackTests = []unpackTest{ enc: "0000000000000000000000000000000000000000000000000000000000000001", want: big.NewInt(1), }, + { + def: `[{"type": "int256"}]`, + enc: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + want: big.NewInt(-1), + }, { def: `[{"type": "address"}]`, enc: 
"0000000000000000000000000100000000000000000000000000000000000000", @@ -151,7 +173,7 @@ var unpackTests = []unpackTest{ // multi dimensional, if these pass, all types that don't require length prefix should pass { def: `[{"type": "uint8[][]"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000000000000000E0000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", + enc: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", want: [][]uint8{{1, 2}, {1, 2}}, }, { @@ -161,7 +183,7 @@ var unpackTests = []unpackTest{ }, { def: `[{"type": "uint8[][2]"}]`, - enc: "000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000001", + enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000001", want: [2][]uint8{{1}, {1}}, }, { @@ -169,6 +191,11 @@ var unpackTests = []unpackTest{ enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", want: [][2]uint8{{1, 2}}, }, + { + def: `[{"type": "uint8[2][]"}]`, + enc: "000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", + want: [][2]uint8{{1, 2}, {1, 2}}, + }, { def: `[{"type": "uint16[]"}]`, enc: 
"0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", @@ -214,6 +241,16 @@ var unpackTests = []unpackTest{ enc: "000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000003", want: [3]*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3)}, }, + { + def: `[{"type": "string[4]"}]`, + enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000140000000000000000000000000000000000000000000000000000000000000000548656c6c6f0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005576f726c64000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000b476f2d657468657265756d0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008457468657265756d000000000000000000000000000000000000000000000000", + want: [4]string{"Hello", "World", "Go-ethereum", "Ethereum"}, + }, + { + def: `[{"type": "string[]"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000008457468657265756d000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000b676f2d657468657265756d000000000000000000000000000000000000000000", + want: []string{"Ethereum", "go-ethereum"}, + }, { def: `[{"type": "int8[]"}]`, enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", @@ -273,6 +310,53 @@ var unpackTests = []unpackTest{ Int2 *big.Int }{big.NewInt(1), big.NewInt(2)}, }, + { + def: `[{"name":"int_one","type":"int256"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", + want: struct { + IntOne *big.Int + }{big.NewInt(1)}, + }, + { + def: `[{"name":"int__one","type":"int256"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", + want: struct { + IntOne *big.Int + }{big.NewInt(1)}, + }, + { + def: `[{"name":"int_one_","type":"int256"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", + want: struct { + IntOne *big.Int + }{big.NewInt(1)}, + }, + { + def: `[{"name":"int_one","type":"int256"}, {"name":"intone","type":"int256"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", + want: struct { + IntOne *big.Int + 
+			Intone *big.Int
+		}{big.NewInt(1), big.NewInt(2)},
+	},
+	{
+		def: `[{"name":"___","type":"int256"}]`,
+		enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
+		want: struct {
+			IntOne *big.Int
+			Intone *big.Int
+		}{},
+		err: "abi: purely underscored output cannot unpack to struct",
+	},
+	{
+		def: `[{"name":"int_one","type":"int256"},{"name":"IntOne","type":"int256"}]`,
+		enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
+		want: struct {
+			Int1 *big.Int
+			Int2 *big.Int
+		}{},
+		err: "abi: multiple outputs mapping to the same struct field 'IntOne'",
+	},
 	{
 		def: `[{"name":"int","type":"int256"},{"name":"Int","type":"int256"}]`,
 		enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002",
@@ -337,6 +421,55 @@ func TestUnpack(t *testing.T) {
 	}
 }

+func TestUnpackSetDynamicArrayOutput(t *testing.T) {
+	abi, err := JSON(strings.NewReader(`[{"constant":true,"inputs":[],"name":"testDynamicFixedBytes15","outputs":[{"name":"","type":"bytes15[]"}],"payable":false,"stateMutability":"view","type":"function"},{"constant":true,"inputs":[],"name":"testDynamicFixedBytes32","outputs":[{"name":"","type":"bytes32[]"}],"payable":false,"stateMutability":"view","type":"function"}]`))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var (
+		marshalledReturn32 = common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000230783132333435363738393000000000000000000000000000000000000000003078303938373635343332310000000000000000000000000000000000000000")
+		marshalledReturn15 = common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000230783031323334350000000000000000000000000000000000000000000000003078393837363534000000000000000000000000000000000000000000000000")
+
+		out32 [][32]byte
+		out15 [][15]byte
+	)
+
+	// test 32
+	err = abi.Unpack(&out32, "testDynamicFixedBytes32", marshalledReturn32)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(out32) != 2 {
+		t.Fatalf("expected array with 2 values, got %d", len(out32))
+	}
+	expected := common.Hex2Bytes("3078313233343536373839300000000000000000000000000000000000000000")
+	if !bytes.Equal(out32[0][:], expected) {
+		t.Errorf("expected %x, got %x\n", expected, out32[0])
+	}
+	expected = common.Hex2Bytes("3078303938373635343332310000000000000000000000000000000000000000")
+	if !bytes.Equal(out32[1][:], expected) {
+		t.Errorf("expected %x, got %x\n", expected, out32[1])
+	}
+
+	// test 15 (the bytes15[] method; the bytes32 name here was a copy-paste slip)
+	err = abi.Unpack(&out15, "testDynamicFixedBytes15", marshalledReturn15)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(out15) != 2 {
+		t.Fatalf("expected array with 2 values, got %d", len(out15))
+	}
+	expected = common.Hex2Bytes("307830313233343500000000000000")
+	if !bytes.Equal(out15[0][:], expected) {
+		t.Errorf("expected %x, got %x\n", expected, out15[0])
+	}
+	expected = common.Hex2Bytes("307839383736353400000000000000")
+	if !bytes.Equal(out15[1][:], expected) {
+		t.Errorf("expected %x, got %x\n", expected, out15[1])
+	}
+}
+
 type methodMultiOutput struct {
 	Int    *big.Int
 	String string
@@ -440,6 +573,68 @@ func TestMultiReturnWithArray(t *testing.T) {
 	}
 }

+func TestMultiReturnWithStringArray(t *testing.T) {
+	const definition = `[{"name" : "multi", "outputs": [{"name": "","type": "uint256[3]"},{"name": "","type": "address"},{"name": "","type": "string[2]"},{"name": "","type": "bool"}]}]`
+	abi, err := JSON(strings.NewReader(definition))
+	if err != nil {
+		t.Fatal(err)
+	}
+	buff := new(bytes.Buffer)
+	buff.Write(common.Hex2Bytes("000000000000000000000000000000000000000000000000000000005c1b78ea0000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000001a055690d9db80000000000000000000000000000ab1257528b3782fb40d7ed5f72e624b744dffb2f00000000000000000000000000000000000000000000000000000000000000c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000008457468657265756d000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001048656c6c6f2c20457468657265756d2100000000000000000000000000000000"))
+	temp, _ := big.NewInt(0).SetString("30000000000000000000", 10)
+	ret1, ret1Exp := new([3]*big.Int), [3]*big.Int{big.NewInt(1545304298), big.NewInt(6), temp}
+	ret2, ret2Exp := new(common.Address), common.HexToAddress("ab1257528b3782fb40d7ed5f72e624b744dffb2f")
+	ret3, ret3Exp := new([2]string), [2]string{"Ethereum", "Hello, Ethereum!"}
+	ret4, ret4Exp := new(bool), false
+	if err := abi.Unpack(&[]interface{}{ret1, ret2, ret3, ret4}, "multi", buff.Bytes()); err != nil {
+		t.Fatal(err)
+	}
+	if !reflect.DeepEqual(*ret1, ret1Exp) {
+		t.Error("big.Int array result", *ret1, "!= Expected", ret1Exp)
+	}
+	if !reflect.DeepEqual(*ret2, ret2Exp) {
+		t.Error("address result", *ret2, "!= Expected", ret2Exp)
+	}
+	if !reflect.DeepEqual(*ret3, ret3Exp) {
+		t.Error("string array result", *ret3, "!= Expected", ret3Exp)
+	}
+	if !reflect.DeepEqual(*ret4, ret4Exp) {
+		t.Error("bool result", *ret4, "!= Expected", ret4Exp)
+	}
+}
+
+func TestMultiReturnWithStringSlice(t *testing.T) {
+	const definition = `[{"name" : "multi", "outputs": [{"name": "","type": "string[]"},{"name": "","type": "uint256[]"}]}]`
+	abi, err := JSON(strings.NewReader(definition))
+	if err != nil {
+		t.Fatal(err)
+	}
+	buff := new(bytes.Buffer)
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) // output[0] offset
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000120")) // output[1] offset
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // output[0] length
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040")) // output[0][0] offset
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000080")) // output[0][1] offset
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000008")) // output[0][0] length
+	buff.Write(common.Hex2Bytes("657468657265756d000000000000000000000000000000000000000000000000")) // output[0][0] value
+	buff.Write(common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000b")) // output[0][1] length
+	buff.Write(common.Hex2Bytes("676f2d657468657265756d000000000000000000000000000000000000000000")) // output[0][1] value
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // output[1] length
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000064")) // output[1][0] value
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000065")) // output[1][1] value
+	ret1, ret1Exp := new([]string), []string{"ethereum", "go-ethereum"}
+	ret2, ret2Exp := new([]*big.Int), []*big.Int{big.NewInt(100), big.NewInt(101)}
+	if err := abi.Unpack(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil {
+		t.Fatal(err)
+	}
+	if !reflect.DeepEqual(*ret1, ret1Exp) {
+		t.Error("string slice result", *ret1, "!= Expected", ret1Exp)
+	}
+	if !reflect.DeepEqual(*ret2, ret2Exp) {
+		t.Error("uint256 slice result", *ret2, "!= Expected", ret2Exp)
+	}
+}
+
 func TestMultiReturnWithDeeplyNestedArray(t *testing.T) {
 	// Similar to TestMultiReturnWithArray, but with a special case in mind:
 	// values of nested static arrays count towards the size as well, and any element following
@@ -729,6 +924,108 @@ func TestUnmarshal(t *testing.T) {
 	}
 }

+func TestUnpackTuple(t *testing.T) {
+	const simpleTuple = `[{"name":"tuple","constant":false,"outputs":[{"type":"tuple","name":"ret","components":[{"type":"int256","name":"a"},{"type":"int256","name":"b"}]}]}]`
+	abi, err := JSON(strings.NewReader(simpleTuple))
+	if err != nil {
+		t.Fatal(err)
+	}
+	buff := new(bytes.Buffer)
+
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // ret[a] = 1
+	buff.Write(common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")) // ret[b] = -1
+
+	v := struct {
+		Ret struct {
+			A *big.Int
+			B *big.Int
+		}
+	}{Ret: struct {
+		A *big.Int
+		B *big.Int
+	}{new(big.Int), new(big.Int)}}
+
+	err = abi.Unpack(&v, "tuple", buff.Bytes())
+	if err != nil {
+		t.Error(err)
+	} else {
+		if v.Ret.A.Cmp(big.NewInt(1)) != 0 {
+			t.Errorf("unexpected value unpacked: want %x, got %x", 1, v.Ret.A)
+		}
+		if v.Ret.B.Cmp(big.NewInt(-1)) != 0 {
+			t.Errorf("unexpected value unpacked: want %x, got %x", -1, v.Ret.B)
+		}
+	}
+
+	// Test nested tuple
+	const nestedTuple = `[{"name":"tuple","constant":false,"outputs":[
+		{"type":"tuple","name":"s","components":[{"type":"uint256","name":"a"},{"type":"uint256[]","name":"b"},{"type":"tuple[]","name":"c","components":[{"name":"x", "type":"uint256"},{"name":"y","type":"uint256"}]}]},
+		{"type":"tuple","name":"t","components":[{"name":"x", "type":"uint256"},{"name":"y","type":"uint256"}]},
+		{"type":"uint256","name":"a"}
+	]}]`
+
+	abi, err = JSON(strings.NewReader(nestedTuple))
+	if err != nil {
+		t.Fatal(err)
+	}
+	buff.Reset()
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000080")) // s offset
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000")) // t.X = 0
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // t.Y = 1
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // a = 1
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // s.A = 1
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000060")) // s.B offset
+	buff.Write(common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000000000c0")) // s.C offset
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // s.B length
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // s.B[0] = 1
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // s.B[1] = 2
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // s.C length
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // s.C[0].X
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // s.C[0].Y
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002")) // s.C[1].X
+	buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // s.C[1].Y
+
+	type T struct {
+		X *big.Int `abi:"x"`
+		Z *big.Int `abi:"y"` // Test whether the abi tag works.
+	}
+
+	type S struct {
+		A *big.Int
+		B []*big.Int
+		C []T
+	}
+
+	type Ret struct {
+		FieldS S `abi:"s"`
+		FieldT T `abi:"t"`
+		A      *big.Int
+	}
+	var ret Ret
+	var expected = Ret{
+		FieldS: S{
+			A: big.NewInt(1),
+			B: []*big.Int{big.NewInt(1), big.NewInt(2)},
+			C: []T{
+				{big.NewInt(1), big.NewInt(2)},
+				{big.NewInt(2), big.NewInt(1)},
+			},
+		},
+		FieldT: T{
+			big.NewInt(0), big.NewInt(1),
+		},
+		A: big.NewInt(1),
+	}
+
+	err = abi.Unpack(&ret, "tuple", buff.Bytes())
+	if err != nil {
+		t.Error(err)
+	}
+	if reflect.DeepEqual(ret, expected) {
+		t.Error("unexpected unpack value")
+	}
+}
+
 func TestOOMMaliciousInput(t *testing.T) {
 	oomTests := []unpackTest{
 		{

From c746706db982db042f0f3da40f9f5823d7d626f5 Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Fri, 6 Oct 2023 17:22:07 +0700
Subject: [PATCH 082/119] [WIP] Split Unpack and UnpackIntoInterface

---
 accounts/abi/abi.go          | 189 +++++--
 accounts/abi/abi_test.go     | 669 +++++++++++++++++------
 accounts/abi/argument.go     | 230 +++----
 accounts/abi/error.go        |  18 +-
 accounts/abi/event.go        |  80 ++-
 accounts/abi/event_test.go   | 189 +++++--
 accounts/abi/method.go       | 155 ++++--
 accounts/abi/pack.go         |  34 +-
 accounts/abi/pack_test.go    | 638 ++--------------------
 accounts/abi/packing_test.go | 990 +++++++++++++++++++++++++++++++++++
 accounts/abi/reflect.go      | 149 ++++--
 accounts/abi/type.go         | 166 ++++--
 accounts/abi/type_test.go    | 181 +++++--
 accounts/abi/unpack.go       |  97 ++--
 accounts/abi/unpack_test.go  | 425 +++++----------
 accounts/abi/utils.go        |  39 ++
 common/math/big.go           |  14 +-
 17 files changed, 2689 insertions(+), 1574 deletions(-)
 create mode 100644 accounts/abi/packing_test.go
 create mode 100644 accounts/abi/utils.go

diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go
index 08d5db9798..7d5d6291b0 100644
--- a/accounts/abi/abi.go
+++ b/accounts/abi/abi.go
@@ -19,8 +19,11 @@ package abi
 import (
 	"bytes"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
+
+	"github.com/tomochain/tomochain/common"
 )

 // The ABI holds information about a contract's context and available
@@ -30,6 +33,12 @@ type ABI struct {
 	Constructor Method
 	Methods     map[string]Method
 	Events      map[string]Event
+
+	// Additional "special" functions introduced in solidity v0.6.0.
+	// It's separated from the original default fallback. Each contract
+	// can only define one fallback and receive function.
+	Fallback Method // Note it's also used to represent legacy fallback before v0.6.0
+	Receive  Method
 }

 // JSON returns a parsed ABI interface and error if it failed.
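[Editor's sketch — not part of the patch. The hunks below split the old single Unpack entry point into Unpack (positional results), UnpackIntoInterface (extra copy into a caller-supplied struct) and UnpackIntoMap (fill a map keyed by output name). A minimal usage sketch of the new shape follows; the contract ABI, return data and variable names are invented for illustration, and the struct-copy behaviour of UnpackIntoInterface is assumed to match upstream go-ethereum:

	package main

	import (
		"fmt"
		"math/big"
		"strings"

		"github.com/tomochain/tomochain/accounts/abi"
		"github.com/tomochain/tomochain/common"
	)

	func main() {
		// One view method returning a single uint256 (made up for the sketch).
		const def = `[{"type":"function","name":"balance","stateMutability":"view","outputs":[{"name":"","type":"uint256"}]}]`
		parsed, err := abi.JSON(strings.NewReader(def))
		if err != nil {
			panic(err)
		}
		ret := common.LeftPadBytes(big.NewInt(42).Bytes(), 32) // one 32-byte return word

		// Unpack returns the outputs positionally as []interface{}.
		vals, _ := parsed.Unpack("balance", ret)
		fmt.Println(vals[0].(*big.Int)) // 42

		// UnpackIntoInterface performs the additional copy into a struct,
		// which may declare fields beyond the ABI outputs.
		var out struct{ Balance *big.Int }
		_ = parsed.UnpackIntoInterface(&out, "balance", ret)
		fmt.Println(out.Balance) // 42
	}
]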
@@ -40,7 +49,6 @@ func JSON(reader io.Reader) (ABI, error) {
 	if err := dec.Decode(&abi); err != nil {
 		return ABI{}, err
 	}
-
 	return abi, nil
 }

@@ -68,80 +76,181 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) {
 		return nil, err
 	}
 	// Pack up the method ID too if not a constructor and return
-	return append(method.Id(), arguments...), nil
+	return append(method.ID, arguments...), nil
 }

-// Unpack output in v according to the abi specification
-func (abi ABI) Unpack(v interface{}, name string, output []byte) (err error) {
-	if len(output) == 0 {
-		return fmt.Errorf("abi: unmarshalling empty output")
-	}
+func (abi ABI) getArguments(name string, data []byte) (Arguments, error) {
 	// since there can't be naming collisions with contracts and events,
 	// we need to decide whether we're calling a method or an event
+	var args Arguments
 	if method, ok := abi.Methods[name]; ok {
-		if len(output)%32 != 0 {
-			return fmt.Errorf("abi: improperly formatted output: %s - Bytes: [%+v]", string(output), output)
+		if len(data)%32 != 0 {
+			return nil, fmt.Errorf("abi: improperly formatted output: %s - Bytes: [%+v]", string(data), data)
 		}
-		return method.Outputs.Unpack(v, output)
-	} else if event, ok := abi.Events[name]; ok {
-		return event.Inputs.Unpack(v, output)
+		args = method.Outputs
+	}
+	if event, ok := abi.Events[name]; ok {
+		args = event.Inputs
+	}
+	if args == nil {
+		return nil, errors.New("abi: could not locate named method or event")
+	}
+	return args, nil
+}
+
+// Unpack unpacks the output according to the abi specification.
+func (abi ABI) Unpack(name string, data []byte) ([]interface{}, error) {
+	args, err := abi.getArguments(name, data)
+	if err != nil {
+		return nil, err
+	}
+	return args.Unpack(data)
+}
+
+// UnpackIntoInterface unpacks the output in v according to the abi specification.
+// It performs an additional copy. Please only use, if you want to unpack into a
+// structure that does not strictly conform to the abi structure (e.g. has additional arguments)
+func (abi ABI) UnpackIntoInterface(v interface{}, name string, data []byte) error {
+	args, err := abi.getArguments(name, data)
+	if err != nil {
+		return err
+	}
+	unpacked, err := args.Unpack(data)
+	if err != nil {
+		return err
+	}
+	return args.Copy(v, unpacked)
+}
+
+// UnpackIntoMap unpacks a log into the provided map[string]interface{}.
+func (abi ABI) UnpackIntoMap(v map[string]interface{}, name string, data []byte) (err error) {
+	args, err := abi.getArguments(name, data)
+	if err != nil {
+		return err
 	}
-	return fmt.Errorf("abi: could not locate named method or event")
+	return args.UnpackIntoMap(v, data)
 }

-// UnmarshalJSON implements json.Unmarshaler interface
+// UnmarshalJSON implements json.Unmarshaler interface.
 func (abi *ABI) UnmarshalJSON(data []byte) error {
 	var fields []struct {
-		Type     string
-		Name     string
-		Constant bool
+		Type    string
+		Name    string
+		Inputs  []Argument
+		Outputs []Argument
+
+		// Status indicator which can be: "pure", "view",
+		// "nonpayable" or "payable".
+		StateMutability string
+
+		// Deprecated Status indicators, but removed in v0.6.0.
+		Constant bool // True if function is either pure or view
+		Payable  bool // True if function is payable
+
+		// Event relevant indicator represents the event is
+		// declared as anonymous.
 		Anonymous bool
-		Inputs    []Argument
-		Outputs   []Argument
 	}
-
 	if err := json.Unmarshal(data, &fields); err != nil {
 		return err
 	}
-
 	abi.Methods = make(map[string]Method)
 	abi.Events = make(map[string]Event)
 	for _, field := range fields {
 		switch field.Type {
 		case "constructor":
-			abi.Constructor = Method{
-				Inputs: field.Inputs,
+			abi.Constructor = NewMethod("", "", Constructor, field.StateMutability, field.Constant, field.Payable, field.Inputs, nil)
+		case "function":
+			name := abi.overloadedMethodName(field.Name)
+			abi.Methods[name] = NewMethod(name, field.Name, Function, field.StateMutability, field.Constant, field.Payable, field.Inputs, field.Outputs)
+		case "fallback":
+			// New introduced function type in v0.6.0, check more detail
+			// here https://solidity.readthedocs.io/en/v0.6.0/contracts.html#fallback-function
+			if abi.HasFallback() {
+				return errors.New("only single fallback is allowed")
 			}
-		// empty defaults to function according to the abi spec
-		case "function", "":
-			abi.Methods[field.Name] = Method{
-				Name:    field.Name,
-				Const:   field.Constant,
-				Inputs:  field.Inputs,
-				Outputs: field.Outputs,
+			abi.Fallback = NewMethod("", "", Fallback, field.StateMutability, field.Constant, field.Payable, nil, nil)
+		case "receive":
+			// New introduced function type in v0.6.0, check more detail
+			// here https://solidity.readthedocs.io/en/v0.6.0/contracts.html#fallback-function
+			if abi.HasReceive() {
+				return errors.New("only single receive is allowed")
 			}
-		case "event":
-			abi.Events[field.Name] = Event{
-				Name:      field.Name,
-				Anonymous: field.Anonymous,
-				Inputs:    field.Inputs,
+			if field.StateMutability != "payable" {
+				return errors.New("the statemutability of receive can only be payable")
 			}
+			abi.Receive = NewMethod("", "", Receive, field.StateMutability, field.Constant, field.Payable, nil, nil)
+		case "event":
+			name := abi.overloadedEventName(field.Name)
+			abi.Events[name] = NewEvent(name, field.Name, field.Anonymous, field.Inputs)
+		default:
+			return fmt.Errorf("abi: could not recognize type %v of field %v", field.Type, field.Name)
 		}
 	}
-
 	return nil
 }

-// MethodById looks up a method by the 4-byte id
-// returns nil if none found
+// overloadedMethodName returns the next available name for a given function.
+// Needed since solidity allows for function overload.
+//
+// e.g. if the abi contains methods send and send0,
+// overloadedMethodName returns send1 for input send.
+func (abi *ABI) overloadedMethodName(rawName string) string {
+	name := rawName
+	_, ok := abi.Methods[name]
+	for idx := 0; ok; idx++ {
+		name = fmt.Sprintf("%s%d", rawName, idx)
+		_, ok = abi.Methods[name]
+	}
+	return name
+}
+
+// overloadedEventName returns the next available name for a given event.
+// Needed since solidity allows for event overload.
+//
+// e.g. if the abi contains events received and received0,
+// overloadedEventName returns received1 for input received.
+func (abi *ABI) overloadedEventName(rawName string) string {
+	name := rawName
+	_, ok := abi.Events[name]
+	for idx := 0; ok; idx++ {
+		name = fmt.Sprintf("%s%d", rawName, idx)
+		_, ok = abi.Events[name]
+	}
+	return name
+}
+
+// MethodById looks up a method by the 4-byte id,
+// returns nil if none found.
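+//
+// A minimal lookup sketch (editor's illustration; calldata here is a
+// hypothetical []byte holding the 4-byte selector followed by the
+// encoded arguments):
+//
+//	if method, err := parsed.MethodById(calldata[:4]); err == nil {
+//		args, _ := method.Inputs.Unpack(calldata[4:])
+//		_ = args
+//	}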
 func (abi *ABI) MethodById(sigdata []byte) (*Method, error) {
 	if len(sigdata) < 4 {
-		return nil, fmt.Errorf("data too short (% bytes) for abi method lookup", len(sigdata))
+		return nil, fmt.Errorf("data too short (%d bytes) for abi method lookup", len(sigdata))
 	}
 	for _, method := range abi.Methods {
-		if bytes.Equal(method.Id(), sigdata[:4]) {
+		if bytes.Equal(method.ID, sigdata[:4]) {
 			return &method, nil
 		}
 	}
 	return nil, fmt.Errorf("no method with id: %#x", sigdata[:4])
 }
+
+// EventByID looks an event up by its topic hash in the
+// ABI and returns nil if none found.
+func (abi *ABI) EventByID(topic common.Hash) (*Event, error) {
+	for _, event := range abi.Events {
+		if bytes.Equal(event.ID.Bytes(), topic.Bytes()) {
+			return &event, nil
+		}
+	}
+	return nil, fmt.Errorf("no event with id: %#x", topic.Hex())
+}
+
+// HasFallback returns an indicator whether a fallback function is included.
+func (abi *ABI) HasFallback() bool {
+	return abi.Fallback.Type == Fallback
+}
+
+// HasReceive returns an indicator whether a receive function is included.
+func (abi *ABI) HasReceive() bool {
+	return abi.Receive.Type == Receive
+}
diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go
index 354668e206..67af037633 100644
--- a/accounts/abi/abi_test.go
+++ b/accounts/abi/abi_test.go
@@ -20,64 +20,111 @@ import (
 	"bytes"
 	"encoding/hex"
 	"fmt"
-	"log"
 	"math/big"
 	"reflect"
 	"strings"
 	"testing"

 	"github.com/tomochain/tomochain/common"
+	"github.com/tomochain/tomochain/common/math"
 	"github.com/tomochain/tomochain/crypto"
 )

 const jsondata = `
 [
-	{ "type" : "function", "name" : "balance", "constant" : true },
-	{ "type" : "function", "name" : "send", "constant" : false, "inputs" : [ { "name" : "amount", "type" : "uint256" } ] }
+	{ "type" : "function", "name" : ""},
+	{ "type" : "function", "name" : "balance", "stateMutability" : "view" },
+	{ "type" : "function", "name" : "send", "inputs" : [ { "name" : "amount", "type" : "uint256" } ] },
+	{ "type" : "function", "name" : "test", "inputs" : [ { "name" : "number", "type" : "uint32" } ] },
+	{ "type" : "function", "name" : "string", "inputs" : [ { "name" : "inputs", "type" : "string" } ] },
+	{ "type" : "function", "name" : "bool", "inputs" : [ { "name" : "inputs", "type" : "bool" } ] },
+	{ "type" : "function", "name" : "address", "inputs" : [ { "name" : "inputs", "type" : "address" } ] },
+	{ "type" : "function", "name" : "uint64[2]", "inputs" : [ { "name" : "inputs", "type" : "uint64[2]" } ] },
+	{ "type" : "function", "name" : "uint64[]", "inputs" : [ { "name" : "inputs", "type" : "uint64[]" } ] },
+	{ "type" : "function", "name" : "int8", "inputs" : [ { "name" : "inputs", "type" : "int8" } ] },
+	{ "type" : "function", "name" : "foo", "inputs" : [ { "name" : "inputs", "type" : "uint32" } ] },
+	{ "type" : "function", "name" : "bar", "inputs" : [ { "name" : "inputs", "type" : "uint32" }, { "name" : "string", "type" : "uint16" } ] },
+	{ "type" : "function", "name" : "slice", "inputs" : [ { "name" : "inputs", "type" : "uint32[2]" } ] },
+	{ "type" : "function", "name" : "slice256", "inputs" : [ { "name" : "inputs", "type" : "uint256[2]" } ] },
+	{ "type" : "function", "name" : "sliceAddress", "inputs" : [ { "name" : "inputs", "type" : "address[]" } ] },
+	{ "type" : "function", "name" : "sliceMultiAddress", "inputs" : [ { "name" : "a", "type" : "address[]" }, { "name" : "b", "type" : "address[]" } ] },
+	{ "type" : "function", "name" : "nestedArray", "inputs" : [ { "name" : "a", "type" : "uint256[2][2]" }, { "name" : "b", "type" : "address[]" } ] },
+	{ "type" : "function", "name" : "nestedArray2", "inputs" : [ { "name" : "a", "type" : "uint8[][2]" } ] },
+	{ "type" : "function", "name" : "nestedSlice", "inputs" : [ { "name" : "a", "type" : "uint8[][]" } ] },
+	{ "type" : "function", "name" : "receive", "inputs" : [ { "name" : "memo", "type" : "bytes" }], "outputs" : [], "payable" : true, "stateMutability" : "payable" },
+	{ "type" : "function", "name" : "fixedArrStr", "stateMutability" : "view", "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr", "type" : "uint256[2]" } ] },
+	{ "type" : "function", "name" : "fixedArrBytes", "stateMutability" : "view", "inputs" : [ { "name" : "bytes", "type" : "bytes" }, { "name" : "fixedArr", "type" : "uint256[2]" } ] },
+	{ "type" : "function", "name" : "mixedArrStr", "stateMutability" : "view", "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr", "type" : "uint256[2]" }, { "name" : "dynArr", "type" : "uint256[]" } ] },
+	{ "type" : "function", "name" : "doubleFixedArrStr", "stateMutability" : "view", "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr1", "type" : "uint256[2]" }, { "name" : "fixedArr2", "type" : "uint256[3]" } ] },
+	{ "type" : "function", "name" : "multipleMixedArrStr", "stateMutability" : "view", "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr1", "type" : "uint256[2]" }, { "name" : "dynArr", "type" : "uint256[]" }, { "name" : "fixedArr2", "type" : "uint256[3]" } ] },
+	{ "type" : "function", "name" : "overloadedNames", "stateMutability" : "view", "inputs": [ { "components": [ { "internalType": "uint256", "name": "_f", "type": "uint256" }, { "internalType": "uint256", "name": "__f", "type": "uint256"}, { "internalType": "uint256", "name": "f", "type": "uint256"}],"internalType": "struct Overloader.F", "name": "f","type": "tuple"}]}
 ]`

-const jsondata2 = `
-[
-	{ "type" : "function", "name" : "balance", "constant" : true },
-	{ "type" : "function", "name" : "send", "constant" : false, "inputs" : [ { "name" : "amount", "type" : "uint256" } ] },
-	{ "type" : "function", "name" : "test", "constant" : false, "inputs" : [ { "name" : "number", "type" : "uint32" } ] },
-	{ "type" : "function", "name" : "string", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "string" } ] },
-	{ "type" : "function", "name" : "bool", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "bool" } ] },
-	{ "type" : "function", "name" : "address", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "address" } ] },
-	{ "type" : "function", "name" : "uint64[2]", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint64[2]" } ] },
-	{ "type" : "function", "name" : "uint64[]", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint64[]" } ] },
-	{ "type" : "function", "name" : "foo", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint32" } ] },
-	{ "type" : "function", "name" : "bar", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint32" }, { "name" : "string", "type" : "uint16" } ] },
-	{ "type" : "function", "name" : "slice", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint32[2]" } ] },
-	{ "type" : "function", "name" : "slice256", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "uint256[2]" } ] },
-	{ "type" : "function", "name" : "sliceAddress", "constant" : false, "inputs" : [ { "name" : "inputs", "type" : "address[]" } ] },
-	{ "type" : "function", "name" : "sliceMultiAddress", "constant" : false, "inputs" : [ { "name" : "a", "type" : "address[]" }, { "name" : "b", "type" : "address[]" } ] },
-	{ "type" : "function", "name" : "nestedArray", "constant" : false, "inputs" : [ { "name" : "a", "type" : "uint256[2][2]" }, { "name" : "b", "type" : "address[]" } ] },
-	{ "type" : "function", "name" : "nestedArray2", "constant" : false, "inputs" : [ { "name" : "a", "type" : "uint8[][2]" } ] },
-	{ "type" : "function", "name" : "nestedSlice", "constant" : false, "inputs" : [ { "name" : "a", "type" : "uint8[][]" } ] }
-]`
+var (
+	Uint256, _    = NewType("uint256", "", nil)
+	Uint32, _     = NewType("uint32", "", nil)
+	Uint16, _     = NewType("uint16", "", nil)
+	String, _     = NewType("string", "", nil)
+	Bool, _       = NewType("bool", "", nil)
+	Bytes, _      = NewType("bytes", "", nil)
+	Address, _    = NewType("address", "", nil)
+	Uint64Arr, _  = NewType("uint64[]", "", nil)
+	AddressArr, _ = NewType("address[]", "", nil)
+	Int8, _       = NewType("int8", "", nil)
+	// Special types for testing
+	Uint32Arr2, _       = NewType("uint32[2]", "", nil)
+	Uint64Arr2, _       = NewType("uint64[2]", "", nil)
+	Uint256Arr, _       = NewType("uint256[]", "", nil)
+	Uint256Arr2, _      = NewType("uint256[2]", "", nil)
+	Uint256Arr3, _      = NewType("uint256[3]", "", nil)
+	Uint256ArrNested, _ = NewType("uint256[2][2]", "", nil)
+	Uint8ArrNested, _   = NewType("uint8[][2]", "", nil)
+	Uint8SliceNested, _ = NewType("uint8[][]", "", nil)
+	TupleF, _           = NewType("tuple", "struct Overloader.F", []ArgumentMarshaling{
+		{Name: "_f", Type: "uint256"},
+		{Name: "__f", Type: "uint256"},
+		{Name: "f", Type: "uint256"}})
+)
+
+var methods = map[string]Method{
+	"":                  NewMethod("", "", Function, "", false, false, nil, nil),
+	"balance":           NewMethod("balance", "balance", Function, "view", false, false, nil, nil),
+	"send":              NewMethod("send", "send", Function, "", false, false, []Argument{{"amount", Uint256, false}}, nil),
+	"test":              NewMethod("test", "test", Function, "", false, false, []Argument{{"number", Uint32, false}}, nil),
+	"string":            NewMethod("string", "string", Function, "", false, false, []Argument{{"inputs", String, false}}, nil),
+	"bool":              NewMethod("bool", "bool", Function, "", false, false, []Argument{{"inputs", Bool, false}}, nil),
+	"address":           NewMethod("address", "address", Function, "", false, false, []Argument{{"inputs", Address, false}}, nil),
+	"uint64[]":          NewMethod("uint64[]", "uint64[]", Function, "", false, false, []Argument{{"inputs", Uint64Arr, false}}, nil),
+	"uint64[2]":         NewMethod("uint64[2]", "uint64[2]", Function, "", false, false, []Argument{{"inputs", Uint64Arr2, false}}, nil),
+	"int8":              NewMethod("int8", "int8", Function, "", false, false, []Argument{{"inputs", Int8, false}}, nil),
+	"foo":               NewMethod("foo", "foo", Function, "", false, false, []Argument{{"inputs", Uint32, false}}, nil),
+	"bar":               NewMethod("bar", "bar", Function, "", false, false, []Argument{{"inputs", Uint32, false}, {"string", Uint16, false}}, nil),
+	"slice":             NewMethod("slice", "slice", Function, "", false, false, []Argument{{"inputs", Uint32Arr2, false}}, nil),
+	"slice256":          NewMethod("slice256", "slice256", Function, "", false, false, []Argument{{"inputs", Uint256Arr2, false}}, nil),
+	"sliceAddress":      NewMethod("sliceAddress", "sliceAddress", Function, "", false, false, []Argument{{"inputs", AddressArr, false}}, nil),
+	"sliceMultiAddress": NewMethod("sliceMultiAddress", "sliceMultiAddress", Function, "", false, false, []Argument{{"a", AddressArr, false}, {"b", AddressArr, false}}, nil),
false, false, []Argument{{"a", Uint256ArrNested, false}, {"b", AddressArr, false}}, nil), + "nestedArray2": NewMethod("nestedArray2", "nestedArray2", Function, "", false, false, []Argument{{"a", Uint8ArrNested, false}}, nil), + "nestedSlice": NewMethod("nestedSlice", "nestedSlice", Function, "", false, false, []Argument{{"a", Uint8SliceNested, false}}, nil), + "receive": NewMethod("receive", "receive", Function, "payable", false, true, []Argument{{"memo", Bytes, false}}, []Argument{}), + "fixedArrStr": NewMethod("fixedArrStr", "fixedArrStr", Function, "view", false, false, []Argument{{"str", String, false}, {"fixedArr", Uint256Arr2, false}}, nil), + "fixedArrBytes": NewMethod("fixedArrBytes", "fixedArrBytes", Function, "view", false, false, []Argument{{"bytes", Bytes, false}, {"fixedArr", Uint256Arr2, false}}, nil), + "mixedArrStr": NewMethod("mixedArrStr", "mixedArrStr", Function, "view", false, false, []Argument{{"str", String, false}, {"fixedArr", Uint256Arr2, false}, {"dynArr", Uint256Arr, false}}, nil), + "doubleFixedArrStr": NewMethod("doubleFixedArrStr", "doubleFixedArrStr", Function, "view", false, false, []Argument{{"str", String, false}, {"fixedArr1", Uint256Arr2, false}, {"fixedArr2", Uint256Arr3, false}}, nil), + "multipleMixedArrStr": NewMethod("multipleMixedArrStr", "multipleMixedArrStr", Function, "view", false, false, []Argument{{"str", String, false}, {"fixedArr1", Uint256Arr2, false}, {"dynArr", Uint256Arr, false}, {"fixedArr2", Uint256Arr3, false}}, nil), + "overloadedNames": NewMethod("overloadedNames", "overloadedNames", Function, "view", false, false, []Argument{{"f", TupleF, false}}, nil), +} func TestReader(t *testing.T) { - Uint256, _ := NewType("uint256", nil) - exp := ABI{ - Methods: map[string]Method{ - "balance": { - "balance", true, nil, nil, - }, - "send": { - "send", false, []Argument{ - {"amount", Uint256, false}, - }, nil, - }, - }, + abi := ABI{ + Methods: methods, } - abi, err := JSON(strings.NewReader(jsondata)) + exp, err := JSON(strings.NewReader(jsondata)) if err != nil { - t.Error(err) + t.Fatal(err) } - // deep equal fails for some reason for name, expM := range exp.Methods { gotM, exist := abi.Methods[name] if !exist { @@ -99,11 +146,58 @@ func TestReader(t *testing.T) { } } -func TestTestNumbers(t *testing.T) { - abi, err := JSON(strings.NewReader(jsondata2)) +func TestInvalidABI(t *testing.T) { + json := `[{ "type" : "function", "name" : "", "constant" : fals }]` + _, err := JSON(strings.NewReader(json)) + if err == nil { + t.Fatal("invalid json should produce error") + } + json2 := `[{ "type" : "function", "name" : "send", "constant" : false, "inputs" : [ { "name" : "amount", "typ" : "uint256" } ] }]` + _, err = JSON(strings.NewReader(json2)) + if err == nil { + t.Fatal("invalid json should produce error") + } +} + +// TestConstructor tests a constructor function. 
+// The test is based on the following contract: +// +// contract TestConstructor { +// constructor(uint256 a, uint256 b) public{} +// } +func TestConstructor(t *testing.T) { + json := `[{ "inputs": [{"internalType": "uint256","name": "a","type": "uint256" },{ "internalType": "uint256","name": "b","type": "uint256"}],"stateMutability": "nonpayable","type": "constructor"}]` + method := NewMethod("", "", Constructor, "nonpayable", false, false, []Argument{{"a", Uint256, false}, {"b", Uint256, false}}, nil) + // Test from JSON + abi, err := JSON(strings.NewReader(json)) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(abi.Constructor, method) { + t.Error("Missing expected constructor") + } + // Test pack/unpack + packed, err := abi.Pack("", big.NewInt(1), big.NewInt(2)) + if err != nil { + t.Error(err) + } + unpacked, err := abi.Constructor.Inputs.Unpack(packed) if err != nil { t.Error(err) - t.FailNow() + } + + if !reflect.DeepEqual(unpacked[0], big.NewInt(1)) { + t.Error("Unable to pack/unpack from constructor") + } + if !reflect.DeepEqual(unpacked[1], big.NewInt(2)) { + t.Error("Unable to pack/unpack from constructor") + } +} + +func TestTestNumbers(t *testing.T) { + abi, err := JSON(strings.NewReader(jsondata)) + if err != nil { + t.Fatal(err) } if _, err := abi.Pack("balance"); err != nil { @@ -137,69 +231,26 @@ func TestTestNumbers(t *testing.T) { } } -func TestTestString(t *testing.T) { - abi, err := JSON(strings.NewReader(jsondata2)) - if err != nil { - t.Error(err) - t.FailNow() - } - - if _, err := abi.Pack("string", "hello world"); err != nil { - t.Error(err) - } -} - -func TestTestBool(t *testing.T) { - abi, err := JSON(strings.NewReader(jsondata2)) - if err != nil { - t.Error(err) - t.FailNow() - } - - if _, err := abi.Pack("bool", true); err != nil { - t.Error(err) - } -} - -func TestTestSlice(t *testing.T) { - abi, err := JSON(strings.NewReader(jsondata2)) - if err != nil { - t.Error(err) - t.FailNow() - } - - slice := make([]uint64, 2) - if _, err := abi.Pack("uint64[2]", slice); err != nil { - t.Error(err) - } - - if _, err := abi.Pack("uint64[]", slice); err != nil { - t.Error(err) - } -} - func TestMethodSignature(t *testing.T) { - String, _ := NewType("string", nil) - m := Method{"foo", false, []Argument{{"bar", String, false}, {"baz", String, false}}, nil} + m := NewMethod("foo", "foo", Function, "", false, false, []Argument{{"bar", String, false}, {"baz", String, false}}, nil) exp := "foo(string,string)" - if m.Sig() != exp { - t.Error("signature mismatch", exp, "!=", m.Sig()) + if m.Sig != exp { + t.Error("signature mismatch", exp, "!=", m.Sig) } idexp := crypto.Keccak256([]byte(exp))[:4] - if !bytes.Equal(m.Id(), idexp) { - t.Errorf("expected ids to match %x != %x", m.Id(), idexp) + if !bytes.Equal(m.ID, idexp) { + t.Errorf("expected ids to match %x != %x", m.ID, idexp) } - uintt, _ := NewType("uint256", nil) - m = Method{"foo", false, []Argument{{"bar", uintt, false}}, nil} + m = NewMethod("foo", "foo", Function, "", false, false, []Argument{{"bar", Uint256, false}}, nil) exp = "foo(uint256)" - if m.Sig() != exp { - t.Error("signature mismatch", exp, "!=", m.Sig()) + if m.Sig != exp { + t.Error("signature mismatch", exp, "!=", m.Sig) } // Method with tuple arguments - s, _ := NewType("tuple", []ArgumentMarshaling{ + s, _ := NewType("tuple", "", []ArgumentMarshaling{ {Name: "a", Type: "int256"}, {Name: "b", Type: "int256[]"}, {Name: "c", Type: "tuple[]", Components: []ArgumentMarshaling{ @@ -211,18 +262,40 @@ func TestMethodSignature(t *testing.T) { {Name: 
"y", Type: "int256"}, }}, }) - m = Method{"foo", false, []Argument{{"s", s, false}, {"bar", String, false}}, nil} + m = NewMethod("foo", "foo", Function, "", false, false, []Argument{{"s", s, false}, {"bar", String, false}}, nil) exp = "foo((int256,int256[],(int256,int256)[],(int256,int256)[2]),string)" - if m.Sig() != exp { - t.Error("signature mismatch", exp, "!=", m.Sig()) + if m.Sig != exp { + t.Error("signature mismatch", exp, "!=", m.Sig) } } +func TestOverloadedMethodSignature(t *testing.T) { + json := `[{"constant":true,"inputs":[{"name":"i","type":"uint256"},{"name":"j","type":"uint256"}],"name":"foo","outputs":[],"payable":false,"stateMutability":"pure","type":"function"},{"constant":true,"inputs":[{"name":"i","type":"uint256"}],"name":"foo","outputs":[],"payable":false,"stateMutability":"pure","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"i","type":"uint256"}],"name":"bar","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"i","type":"uint256"},{"indexed":false,"name":"j","type":"uint256"}],"name":"bar","type":"event"}]` + abi, err := JSON(strings.NewReader(json)) + if err != nil { + t.Fatal(err) + } + check := func(name string, expect string, method bool) { + if method { + if abi.Methods[name].Sig != expect { + t.Fatalf("The signature of overloaded method mismatch, want %s, have %s", expect, abi.Methods[name].Sig) + } + } else { + if abi.Events[name].Sig != expect { + t.Fatalf("The signature of overloaded event mismatch, want %s, have %s", expect, abi.Events[name].Sig) + } + } + } + check("foo", "foo(uint256,uint256)", true) + check("foo0", "foo(uint256)", true) + check("bar", "bar(uint256)", false) + check("bar0", "bar(uint256,uint256)", false) +} + func TestMultiPack(t *testing.T) { - abi, err := JSON(strings.NewReader(jsondata2)) + abi, err := JSON(strings.NewReader(jsondata)) if err != nil { - t.Error(err) - t.FailNow() + t.Fatal(err) } sig := crypto.Keccak256([]byte("bar(uint32,uint16)"))[:4] @@ -232,10 +305,8 @@ func TestMultiPack(t *testing.T) { packed, err := abi.Pack("bar", uint32(10), uint16(11)) if err != nil { - t.Error(err) - t.FailNow() + t.Fatal(err) } - if !bytes.Equal(packed, sig) { t.Errorf("expected %x got %x", sig, packed) } @@ -246,11 +317,11 @@ func ExampleJSON() { abi, err := JSON(strings.NewReader(definition)) if err != nil { - log.Fatalln(err) + panic(err) } out, err := abi.Pack("isBar", common.HexToAddress("01")) if err != nil { - log.Fatalln(err) + panic(err) } fmt.Printf("%x\n", out) @@ -387,15 +458,7 @@ func TestInputVariableInputLength(t *testing.T) { } func TestInputFixedArrayAndVariableInputLength(t *testing.T) { - const definition = `[ - { "type" : "function", "name" : "fixedArrStr", "constant" : true, "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr", "type" : "uint256[2]" } ] }, - { "type" : "function", "name" : "fixedArrBytes", "constant" : true, "inputs" : [ { "name" : "str", "type" : "bytes" }, { "name" : "fixedArr", "type" : "uint256[2]" } ] }, - { "type" : "function", "name" : "mixedArrStr", "constant" : true, "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr", "type": "uint256[2]" }, { "name" : "dynArr", "type": "uint256[]" } ] }, - { "type" : "function", "name" : "doubleFixedArrStr", "constant" : true, "inputs" : [ { "name" : "str", "type" : "string" }, { "name" : "fixedArr1", "type": "uint256[2]" }, { "name" : "fixedArr2", "type": "uint256[3]" } ] }, - { "type" : "function", "name" : "multipleMixedArrStr", "constant" : true, "inputs" : [ { 
"name" : "str", "type" : "string" }, { "name" : "fixedArr1", "type": "uint256[2]" }, { "name" : "dynArr", "type" : "uint256[]" }, { "name" : "fixedArr2", "type" : "uint256[3]" } ] } - ]` - - abi, err := JSON(strings.NewReader(definition)) + abi, err := JSON(strings.NewReader(jsondata)) if err != nil { t.Error(err) } @@ -542,7 +605,7 @@ func TestInputFixedArrayAndVariableInputLength(t *testing.T) { strvalue = common.RightPadBytes([]byte(strin), 32) fixedarrin1value1 = common.LeftPadBytes(fixedarrin1[0].Bytes(), 32) fixedarrin1value2 = common.LeftPadBytes(fixedarrin1[1].Bytes(), 32) - dynarroffset = U256(big.NewInt(int64(256 + ((len(strin)/32)+1)*32))) + dynarroffset = math.U256Bytes(big.NewInt(int64(256 + ((len(strin)/32)+1)*32))) dynarrlength = make([]byte, 32) dynarrlength[31] = byte(len(dynarrin)) dynarrinvalue1 = common.LeftPadBytes(dynarrin[0].Bytes(), 32) @@ -569,7 +632,7 @@ func TestInputFixedArrayAndVariableInputLength(t *testing.T) { } func TestDefaultFunctionParsing(t *testing.T) { - const definition = `[{ "name" : "balance" }]` + const definition = `[{ "name" : "balance", "type" : "function" }]` abi, err := JSON(strings.NewReader(definition)) if err != nil { @@ -589,9 +652,7 @@ func TestBareEvents(t *testing.T) { { "type" : "event", "name" : "tuple", "inputs" : [{ "indexed":false, "name":"t", "type":"tuple", "components":[{"name":"a", "type":"uint256"}] }, { "indexed":true, "name":"arg1", "type":"address" }] } ]` - arg0, _ := NewType("uint256", nil) - arg1, _ := NewType("address", nil) - tuple, _ := NewType("tuple", []ArgumentMarshaling{{Name: "a", Type: "uint256"}}) + tuple, _ := NewType("tuple", "", []ArgumentMarshaling{{Name: "a", Type: "uint256"}}) expectedEvents := map[string]struct { Anonymous bool @@ -600,12 +661,12 @@ func TestBareEvents(t *testing.T) { "balance": {false, nil}, "anon": {true, nil}, "args": {false, []Argument{ - {Name: "arg0", Type: arg0, Indexed: false}, - {Name: "arg1", Type: arg1, Indexed: true}, + {Name: "arg0", Type: Uint256, Indexed: false}, + {Name: "arg1", Type: Address, Indexed: true}, }}, "tuple": {false, []Argument{ {Name: "t", Type: tuple, Indexed: false}, - {Name: "arg1", Type: arg1, Indexed: true}, + {Name: "arg1", Type: Address, Indexed: true}, }}, } @@ -682,7 +743,7 @@ func TestUnpackEvent(t *testing.T) { } var ev ReceivedEvent - err = abi.Unpack(&ev, "received", data) + err = abi.UnpackIntoInterface(&ev, "received", data) if err != nil { t.Error(err) } @@ -691,52 +752,215 @@ func TestUnpackEvent(t *testing.T) { Sender common.Address } var receivedAddrEv ReceivedAddrEvent - err = abi.Unpack(&receivedAddrEv, "receivedAddr", data) + err = abi.UnpackIntoInterface(&receivedAddrEv, "receivedAddr", data) if err != nil { t.Error(err) } } -func TestABI_MethodById(t *testing.T) { - const abiJSON = `[ - {"type":"function","name":"receive","constant":false,"inputs":[{"name":"memo","type":"bytes"}],"outputs":[],"payable":true,"stateMutability":"payable"}, - {"type":"event","name":"received","anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}]}, - {"type":"function","name":"fixedArrStr","constant":true,"inputs":[{"name":"str","type":"string"},{"name":"fixedArr","type":"uint256[2]"}]}, - {"type":"function","name":"fixedArrBytes","constant":true,"inputs":[{"name":"str","type":"bytes"},{"name":"fixedArr","type":"uint256[2]"}]}, - 
{"type":"function","name":"mixedArrStr","constant":true,"inputs":[{"name":"str","type":"string"},{"name":"fixedArr","type":"uint256[2]"},{"name":"dynArr","type":"uint256[]"}]}, - {"type":"function","name":"doubleFixedArrStr","constant":true,"inputs":[{"name":"str","type":"string"},{"name":"fixedArr1","type":"uint256[2]"},{"name":"fixedArr2","type":"uint256[3]"}]}, - {"type":"function","name":"multipleMixedArrStr","constant":true,"inputs":[{"name":"str","type":"string"},{"name":"fixedArr1","type":"uint256[2]"},{"name":"dynArr","type":"uint256[]"},{"name":"fixedArr2","type":"uint256[3]"}]}, - {"type":"function","name":"balance","constant":true}, - {"type":"function","name":"send","constant":false,"inputs":[{"name":"amount","type":"uint256"}]}, - {"type":"function","name":"test","constant":false,"inputs":[{"name":"number","type":"uint32"}]}, - {"type":"function","name":"string","constant":false,"inputs":[{"name":"inputs","type":"string"}]}, - {"type":"function","name":"bool","constant":false,"inputs":[{"name":"inputs","type":"bool"}]}, - {"type":"function","name":"address","constant":false,"inputs":[{"name":"inputs","type":"address"}]}, - {"type":"function","name":"uint64[2]","constant":false,"inputs":[{"name":"inputs","type":"uint64[2]"}]}, - {"type":"function","name":"uint64[]","constant":false,"inputs":[{"name":"inputs","type":"uint64[]"}]}, - {"type":"function","name":"foo","constant":false,"inputs":[{"name":"inputs","type":"uint32"}]}, - {"type":"function","name":"bar","constant":false,"inputs":[{"name":"inputs","type":"uint32"},{"name":"string","type":"uint16"}]}, - {"type":"function","name":"_slice","constant":false,"inputs":[{"name":"inputs","type":"uint32[2]"}]}, - {"type":"function","name":"__slice256","constant":false,"inputs":[{"name":"inputs","type":"uint256[2]"}]}, - {"type":"function","name":"sliceAddress","constant":false,"inputs":[{"name":"inputs","type":"address[]"}]}, - {"type":"function","name":"sliceMultiAddress","constant":false,"inputs":[{"name":"a","type":"address[]"},{"name":"b","type":"address[]"}]} - ] -` +func TestUnpackEventIntoMap(t *testing.T) { + const abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"receive","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"receivedAddr","type":"event"}]` + abi, err := JSON(strings.NewReader(abiJSON)) + if err != nil { + t.Fatal(err) + } + + const hexdata = `000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158` + data, err := hex.DecodeString(hexdata) + if err != nil { + t.Fatal(err) + } + if len(data)%32 == 0 { + t.Errorf("len(data) is %d, want a non-multiple of 32", len(data)) + } + + receivedMap := map[string]interface{}{} + expectedReceivedMap := map[string]interface{}{ + "sender": common.HexToAddress("0x376c47978271565f56DEB45495afa69E59c16Ab2"), + "amount": big.NewInt(1), + "memo": []byte{88}, + } + if err := abi.UnpackIntoMap(receivedMap, "received", data); err != nil { + t.Error(err) + } + if len(receivedMap) != 3 { + t.Error("unpacked `received` map expected to 
have length 3") + } + if receivedMap["sender"] != expectedReceivedMap["sender"] { + t.Error("unpacked `received` map does not match expected map") + } + if receivedMap["amount"].(*big.Int).Cmp(expectedReceivedMap["amount"].(*big.Int)) != 0 { + t.Error("unpacked `received` map does not match expected map") + } + if !bytes.Equal(receivedMap["memo"].([]byte), expectedReceivedMap["memo"].([]byte)) { + t.Error("unpacked `received` map does not match expected map") + } + + receivedAddrMap := map[string]interface{}{} + if err = abi.UnpackIntoMap(receivedAddrMap, "receivedAddr", data); err != nil { + t.Error(err) + } + if len(receivedAddrMap) != 1 { + t.Error("unpacked `receivedAddr` map expected to have length 1") + } + if receivedAddrMap["sender"] != expectedReceivedMap["sender"] { + t.Error("unpacked `receivedAddr` map does not match expected map") + } +} + +func TestUnpackMethodIntoMap(t *testing.T) { + const abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"receive","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"constant":false,"inputs":[],"name":"send","outputs":[{"name":"amount","type":"uint256"}],"payable":true,"stateMutability":"payable","type":"function"},{"constant":false,"inputs":[{"name":"addr","type":"address"}],"name":"get","outputs":[{"name":"hash","type":"bytes"}],"payable":true,"stateMutability":"payable","type":"function"}]` + abi, err := JSON(strings.NewReader(abiJSON)) + if err != nil { + t.Fatal(err) + } + const hexdata = `00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000015800000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000158000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000001580000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000015800000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000158` + data, err := hex.DecodeString(hexdata) + if err != nil { + t.Fatal(err) + } + if len(data)%32 != 0 { + t.Errorf("len(data) is %d, want a multiple of 32", len(data)) + } + + // Tests a method with no outputs + receiveMap := map[string]interface{}{} + if err = abi.UnpackIntoMap(receiveMap, "receive", data); err != nil { + t.Error(err) + } + if len(receiveMap) > 0 { + t.Error("unpacked `receive` map expected to have length 0") + } + + // Tests a method with only outputs + sendMap := map[string]interface{}{} + if err = abi.UnpackIntoMap(sendMap, "send", data); err != nil { + t.Error(err) + } + if len(sendMap) != 1 { + t.Error("unpacked `send` map expected to have length 1") + } + if sendMap["amount"].(*big.Int).Cmp(big.NewInt(1)) != 0 { + t.Error("unpacked `send` map expected `amount` value of 1") + } + + // Tests a method with outputs and inputs + getMap := map[string]interface{}{} + if err = abi.UnpackIntoMap(getMap, "get", data); err != nil { + t.Error(err) + } + if len(getMap) != 1 { + t.Error("unpacked `get` map expected to have length 1") + } + expectedBytes := []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 88, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 88, 0} + if !bytes.Equal(getMap["hash"].([]byte), expectedBytes) { + t.Errorf("unpacked `get` map expected `hash` value of %v", expectedBytes) + } +} + +func TestUnpackIntoMapNamingConflict(t *testing.T) { + // Two methods have the same name + var abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"get","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"constant":false,"inputs":[],"name":"send","outputs":[{"name":"amount","type":"uint256"}],"payable":true,"stateMutability":"payable","type":"function"},{"constant":false,"inputs":[{"name":"addr","type":"address"}],"name":"get","outputs":[{"name":"hash","type":"bytes"}],"payable":true,"stateMutability":"payable","type":"function"}]` abi, err := JSON(strings.NewReader(abiJSON)) if err != nil { t.Fatal(err) } + var hexdata = `00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158` + data, err := hex.DecodeString(hexdata) + if err != nil { + t.Fatal(err) + } + if len(data)%32 == 0 { + t.Errorf("len(data) is %d, want a non-multiple of 32", len(data)) + } + getMap := map[string]interface{}{} + if err = abi.UnpackIntoMap(getMap, "get", data); err == nil { + t.Error("naming conflict between two methods; error expected") + } + + // Two events have the same name + abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"receive","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"received","type":"event"}]` + abi, err = JSON(strings.NewReader(abiJSON)) + if err != nil { + t.Fatal(err) + } + hexdata = `000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158` + data, err = hex.DecodeString(hexdata) + if err != nil { + t.Fatal(err) + } + if len(data)%32 == 0 { + t.Errorf("len(data) is %d, want a non-multiple of 32", len(data)) + } + receivedMap := map[string]interface{}{} + if err = abi.UnpackIntoMap(receivedMap, "received", data); err != nil { + t.Error("naming conflict between two events; no error expected") + } + + // Method and event have the same name + abiJSON = 
`[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"received","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"receivedAddr","type":"event"}]` + abi, err = JSON(strings.NewReader(abiJSON)) + if err != nil { + t.Fatal(err) + } + if len(data)%32 == 0 { + t.Errorf("len(data) is %d, want a non-multiple of 32", len(data)) + } + if err = abi.UnpackIntoMap(receivedMap, "received", data); err == nil { + t.Error("naming conflict between an event and a method; error expected") + } + + // Conflict is case sensitive + abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"received","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"Received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"receivedAddr","type":"event"}]` + abi, err = JSON(strings.NewReader(abiJSON)) + if err != nil { + t.Fatal(err) + } + if len(data)%32 == 0 { + t.Errorf("len(data) is %d, want a non-multiple of 32", len(data)) + } + expectedReceivedMap := map[string]interface{}{ + "sender": common.HexToAddress("0x376c47978271565f56DEB45495afa69E59c16Ab2"), + "amount": big.NewInt(1), + "memo": []byte{88}, + } + if err = abi.UnpackIntoMap(receivedMap, "Received", data); err != nil { + t.Error(err) + } + if len(receivedMap) != 3 { + t.Error("unpacked `received` map expected to have length 3") + } + if receivedMap["sender"] != expectedReceivedMap["sender"] { + t.Error("unpacked `received` map does not match expected map") + } + if receivedMap["amount"].(*big.Int).Cmp(expectedReceivedMap["amount"].(*big.Int)) != 0 { + t.Error("unpacked `received` map does not match expected map") + } + if !bytes.Equal(receivedMap["memo"].([]byte), expectedReceivedMap["memo"].([]byte)) { + t.Error("unpacked `received` map does not match expected map") + } +} + +func TestABI_MethodById(t *testing.T) { + abi, err := JSON(strings.NewReader(jsondata)) + if err != nil { + t.Fatal(err) + } for name, m := range abi.Methods { a := fmt.Sprintf("%v", m) - m2, err := abi.MethodById(m.Id()) + m2, err := abi.MethodById(m.ID) if err != nil { t.Fatalf("Failed to look up ABI method: %v", err) } b := fmt.Sprintf("%v", m2) if a != b { - t.Errorf("Method %v (id %v) not 'findable' by id in ABI", name, common.ToHex(m.Id())) + t.Errorf("Method %v (id %x) not 'findable' by id in ABI", name, m.ID) } } + // test unsuccessful lookups + if _, err = abi.MethodById(crypto.Keccak256()); err == nil { + t.Error("Expected error: no method with this id") + } // Also test empty if _, err := abi.MethodById([]byte{0x00}); err == nil { t.Errorf("Expected error, too short to decode data") @@ -748,3 +972,148 @@ func TestABI_MethodById(t *testing.T) { t.Errorf("Expected error, nil is short to decode data") } } + +func TestABI_EventById(t *testing.T) { + tests := []struct { + name string + json string + event string + }{ + { + name: "", + json: `[ + {"type":"event","name":"received","anonymous":false,"inputs":[ + {"indexed":false,"name":"sender","type":"address"}, + 
{"indexed":false,"name":"amount","type":"uint256"}, + {"indexed":false,"name":"memo","type":"bytes"} + ] + }]`, + event: "received(address,uint256,bytes)", + }, { + name: "", + json: `[ + { "constant": true, "inputs": [], "name": "name", "outputs": [ { "name": "", "type": "string" } ], "payable": false, "stateMutability": "view", "type": "function" }, + { "constant": false, "inputs": [ { "name": "_spender", "type": "address" }, { "name": "_value", "type": "uint256" } ], "name": "approve", "outputs": [ { "name": "", "type": "bool" } ], "payable": false, "stateMutability": "nonpayable", "type": "function" }, + { "constant": true, "inputs": [], "name": "totalSupply", "outputs": [ { "name": "", "type": "uint256" } ], "payable": false, "stateMutability": "view", "type": "function" }, + { "constant": false, "inputs": [ { "name": "_from", "type": "address" }, { "name": "_to", "type": "address" }, { "name": "_value", "type": "uint256" } ], "name": "transferFrom", "outputs": [ { "name": "", "type": "bool" } ], "payable": false, "stateMutability": "nonpayable", "type": "function" }, + { "constant": true, "inputs": [], "name": "decimals", "outputs": [ { "name": "", "type": "uint8" } ], "payable": false, "stateMutability": "view", "type": "function" }, + { "constant": true, "inputs": [ { "name": "_owner", "type": "address" } ], "name": "balanceOf", "outputs": [ { "name": "balance", "type": "uint256" } ], "payable": false, "stateMutability": "view", "type": "function" }, + { "constant": true, "inputs": [], "name": "symbol", "outputs": [ { "name": "", "type": "string" } ], "payable": false, "stateMutability": "view", "type": "function" }, + { "constant": false, "inputs": [ { "name": "_to", "type": "address" }, { "name": "_value", "type": "uint256" } ], "name": "transfer", "outputs": [ { "name": "", "type": "bool" } ], "payable": false, "stateMutability": "nonpayable", "type": "function" }, + { "constant": true, "inputs": [ { "name": "_owner", "type": "address" }, { "name": "_spender", "type": "address" } ], "name": "allowance", "outputs": [ { "name": "", "type": "uint256" } ], "payable": false, "stateMutability": "view", "type": "function" }, + { "payable": true, "stateMutability": "payable", "type": "fallback" }, + { "anonymous": false, "inputs": [ { "indexed": true, "name": "owner", "type": "address" }, { "indexed": true, "name": "spender", "type": "address" }, { "indexed": false, "name": "value", "type": "uint256" } ], "name": "Approval", "type": "event" }, + { "anonymous": false, "inputs": [ { "indexed": true, "name": "from", "type": "address" }, { "indexed": true, "name": "to", "type": "address" }, { "indexed": false, "name": "value", "type": "uint256" } ], "name": "Transfer", "type": "event" } + ]`, + event: "Transfer(address,address,uint256)", + }, + } + + for testnum, test := range tests { + abi, err := JSON(strings.NewReader(test.json)) + if err != nil { + t.Error(err) + } + + topic := test.event + topicID := crypto.Keccak256Hash([]byte(topic)) + + event, err := abi.EventByID(topicID) + if err != nil { + t.Fatalf("Failed to look up ABI method: %v, test #%d", err, testnum) + } + if event == nil { + t.Errorf("We should find a event for topic %s, test #%d", topicID.Hex(), testnum) + } + + if event.ID != topicID { + t.Errorf("Event id %s does not match topic %s, test #%d", event.ID.Hex(), topicID.Hex(), testnum) + } + + unknowntopicID := crypto.Keccak256Hash([]byte("unknownEvent")) + unknownEvent, err := abi.EventByID(unknowntopicID) + if err == nil { + t.Errorf("EventByID should return an error if 
a topic is not found, test #%d", testnum) + } + if unknownEvent != nil { + t.Errorf("We should not find any event for topic %s, test #%d", unknowntopicID.Hex(), testnum) + } + } +} + +// TestDoubleDuplicateMethodNames checks that if transfer0 already exists, there won't be a name +// conflict and that the second transfer method will be renamed transfer1. +func TestDoubleDuplicateMethodNames(t *testing.T) { + abiJSON := `[{"constant":false,"inputs":[{"name":"to","type":"address"},{"name":"value","type":"uint256"}],"name":"transfer","outputs":[{"name":"ok","type":"bool"}],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":false,"inputs":[{"name":"to","type":"address"},{"name":"value","type":"uint256"},{"name":"data","type":"bytes"}],"name":"transfer0","outputs":[{"name":"ok","type":"bool"}],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":false,"inputs":[{"name":"to","type":"address"},{"name":"value","type":"uint256"},{"name":"data","type":"bytes"},{"name":"customFallback","type":"string"}],"name":"transfer","outputs":[{"name":"ok","type":"bool"}],"payable":false,"stateMutability":"nonpayable","type":"function"}]` + contractAbi, err := JSON(strings.NewReader(abiJSON)) + if err != nil { + t.Fatal(err) + } + if _, ok := contractAbi.Methods["transfer"]; !ok { + t.Fatalf("Could not find original method") + } + if _, ok := contractAbi.Methods["transfer0"]; !ok { + t.Fatalf("Could not find duplicate method") + } + if _, ok := contractAbi.Methods["transfer1"]; !ok { + t.Fatalf("Could not find duplicate method") + } + if _, ok := contractAbi.Methods["transfer2"]; ok { + t.Fatalf("Should not have found extra method") + } +} + +// TestDoubleDuplicateEventNames checks that if send0 already exists, there won't be a name +// conflict and that the second send event will be renamed send1. +// The test runs the abi of the following contract. +// +// contract DuplicateEvent { +// event send(uint256 a); +// event send0(); +// event send(); +// } +func TestDoubleDuplicateEventNames(t *testing.T) { + abiJSON := `[{"anonymous": false,"inputs": [{"indexed": false,"internalType": "uint256","name": "a","type": "uint256"}],"name": "send","type": "event"},{"anonymous": false,"inputs": [],"name": "send0","type": "event"},{ "anonymous": false, "inputs": [],"name": "send","type": "event"}]` + contractAbi, err := JSON(strings.NewReader(abiJSON)) + if err != nil { + t.Fatal(err) + } + if _, ok := contractAbi.Events["send"]; !ok { + t.Fatalf("Could not find original event") + } + if _, ok := contractAbi.Events["send0"]; !ok { + t.Fatalf("Could not find duplicate event") + } + if _, ok := contractAbi.Events["send1"]; !ok { + t.Fatalf("Could not find duplicate event") + } + if _, ok := contractAbi.Events["send2"]; ok { + t.Fatalf("Should not have found extra event") + } +} + +// TestUnnamedEventParam checks that an event with unnamed parameters is +// correctly handled. +// The test runs the abi of the following contract. 
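
// A hedged sketch (not this package's exported API) of the renaming rule the
// duplicate-name tests above exercise: while a parsed method or event name is
// already taken, an increasing numeric suffix is appended, which is how the
// second "send" event resolves to "send1" once "send" and "send0" are occupied.
package main

import "fmt"

func resolveNameConflict(rawName string, used func(string) bool) string {
	name := rawName
	for idx := 0; used(name); idx++ {
		name = fmt.Sprintf("%s%d", rawName, idx)
	}
	return name
}

func main() {
	taken := map[string]bool{"send": true, "send0": true}
	fmt.Println(resolveNameConflict("send", func(s string) bool { return taken[s] })) // send1
}
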
+// +// contract TestEvent { +// event send(uint256, uint256); +// } +func TestUnnamedEventParam(t *testing.T) { + abiJSON := `[{ "anonymous": false, "inputs": [{ "indexed": false,"internalType": "uint256", "name": "","type": "uint256"},{"indexed": false,"internalType": "uint256","name": "","type": "uint256"}],"name": "send","type": "event"}]` + contractAbi, err := JSON(strings.NewReader(abiJSON)) + if err != nil { + t.Fatal(err) + } + + event, ok := contractAbi.Events["send"] + if !ok { + t.Fatalf("Could not find event") + } + if event.Inputs[0].Name != "arg0" { + t.Fatalf("Could not find input") + } + if event.Inputs[1].Name != "arg1" { + t.Fatalf("Could not find input") + } +} diff --git a/accounts/abi/argument.go b/accounts/abi/argument.go index d0a6b035c6..2e48d539e0 100644 --- a/accounts/abi/argument.go +++ b/accounts/abi/argument.go @@ -18,6 +18,7 @@ package abi import ( "encoding/json" + "errors" "fmt" "reflect" "strings" @@ -34,13 +35,14 @@ type Argument struct { type Arguments []Argument type ArgumentMarshaling struct { - Name string - Type string - Components []ArgumentMarshaling - Indexed bool + Name string + Type string + InternalType string + Components []ArgumentMarshaling + Indexed bool } -// UnmarshalJSON implements json.Unmarshaler interface +// UnmarshalJSON implements json.Unmarshaler interface. func (argument *Argument) UnmarshalJSON(data []byte) error { var arg ArgumentMarshaling err := json.Unmarshal(data, &arg) @@ -48,7 +50,7 @@ func (argument *Argument) UnmarshalJSON(data []byte) error { return fmt.Errorf("argument json err: %v", err) } - argument.Type, err = NewType(arg.Type, arg.Components) + argument.Type, err = NewType(arg.Type, arg.InternalType, arg.Components) if err != nil { return err } @@ -58,19 +60,7 @@ func (argument *Argument) UnmarshalJSON(data []byte) error { return nil } -// LengthNonIndexed returns the number of arguments when not counting 'indexed' ones. Only events -// can ever have 'indexed' arguments, it should always be false on arguments for method input/output -func (arguments Arguments) LengthNonIndexed() int { - out := 0 - for _, arg := range arguments { - if !arg.Indexed { - out++ - } - } - return out -} - -// NonIndexed returns the arguments with indexed arguments filtered out +// NonIndexed returns the arguments with indexed arguments filtered out. func (arguments Arguments) NonIndexed() Arguments { var ret []Argument for _, arg := range arguments { @@ -81,170 +71,125 @@ func (arguments Arguments) NonIndexed() Arguments { return ret } -// isTuple returns true for non-atomic constructs, like (uint,uint) or uint[] +// isTuple returns true for non-atomic constructs, like (uint,uint) or uint[]. func (arguments Arguments) isTuple() bool { return len(arguments) > 1 } -// Unpack performs the operation hexdata -> Go format -func (arguments Arguments) Unpack(v interface{}, data []byte) error { - // make sure the passed value is arguments pointer - if reflect.Ptr != reflect.ValueOf(v).Kind() { - return fmt.Errorf("abi: Unpack(non-pointer %T)", v) +// Unpack performs the operation hexdata -> Go format. +func (arguments Arguments) Unpack(data []byte) ([]interface{}, error) { + if len(data) == 0 { + if len(arguments.NonIndexed()) != 0 { + return nil, errors.New("abi: attempting to unmarshall an empty string while arguments are expected") + } + return make([]interface{}, 0), nil + } + return arguments.UnpackValues(data) +} + +// UnpackIntoMap performs the operation hexdata -> mapping of argument name to argument value. 
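
// A hedged usage sketch for the map-based unpacking documented above, going
// through the ABI-level wrapper the tests in this patch call; the event
// definition and payload here are illustrative assumptions, not fixtures.
package main

import (
	"fmt"
	"strings"

	"github.com/tomochain/tomochain/accounts/abi"
	"github.com/tomochain/tomochain/common"
)

func main() {
	const def = `[{"anonymous":false,"inputs":[{"indexed":false,"name":"amount","type":"uint256"}],"name":"sent","type":"event"}]`
	parsed, err := abi.JSON(strings.NewReader(def))
	if err != nil {
		panic(err)
	}
	data := common.LeftPadBytes([]byte{42}, 32) // one uint256 word
	out := map[string]interface{}{}
	if err := parsed.UnpackIntoMap(out, "sent", data); err != nil {
		panic(err)
	}
	fmt.Println(out["amount"]) // 42
}
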
+func (arguments Arguments) UnpackIntoMap(v map[string]interface{}, data []byte) error { + // Make sure map is not nil + if v == nil { + return errors.New("abi: cannot unpack into a nil map") + } + if len(data) == 0 { + if len(arguments.NonIndexed()) != 0 { + return errors.New("abi: attempting to unmarshall an empty string while arguments are expected") + } + return nil // Nothing to unmarshal, return } marshalledValues, err := arguments.UnpackValues(data) if err != nil { return err } - if arguments.isTuple() { - return arguments.unpackTuple(v, marshalledValues) + for i, arg := range arguments.NonIndexed() { + v[arg.Name] = marshalledValues[i] } - return arguments.unpackAtomic(v, marshalledValues[0]) + return nil } -// unpack sets the unmarshalled value to go format. -// Note the dst here must be settable. -func unpack(t *Type, dst interface{}, src interface{}) error { - var ( - dstVal = reflect.ValueOf(dst).Elem() - srcVal = reflect.ValueOf(src) - ) - - if t.T != TupleTy && !((t.T == SliceTy || t.T == ArrayTy) && t.Elem.T == TupleTy) { - return set(dstVal, srcVal) +// Copy performs the operation go format -> provided struct. +func (arguments Arguments) Copy(v interface{}, values []interface{}) error { + // make sure the passed value is arguments pointer + if reflect.Ptr != reflect.ValueOf(v).Kind() { + return fmt.Errorf("abi: Unpack(non-pointer %T)", v) } - - switch t.T { - case TupleTy: - if dstVal.Kind() != reflect.Struct { - return fmt.Errorf("abi: invalid dst value for unpack, want struct, got %s", dstVal.Kind()) - } - fieldmap, err := mapArgNamesToStructFields(t.TupleRawNames, dstVal) - if err != nil { - return err - } - for i, elem := range t.TupleElems { - fname := fieldmap[t.TupleRawNames[i]] - field := dstVal.FieldByName(fname) - if !field.IsValid() { - return fmt.Errorf("abi: field %s can't found in the given value", t.TupleRawNames[i]) - } - if err := unpack(elem, field.Addr().Interface(), srcVal.Field(i).Interface()); err != nil { - return err - } - } - return nil - case SliceTy: - if dstVal.Kind() != reflect.Slice { - return fmt.Errorf("abi: invalid dst value for unpack, want slice, got %s", dstVal.Kind()) + if len(values) == 0 { + if len(arguments.NonIndexed()) != 0 { + return errors.New("abi: attempting to copy no values while arguments are expected") } - slice := reflect.MakeSlice(dstVal.Type(), srcVal.Len(), srcVal.Len()) - for i := 0; i < slice.Len(); i++ { - if err := unpack(t.Elem, slice.Index(i).Addr().Interface(), srcVal.Index(i).Interface()); err != nil { - return err - } - } - dstVal.Set(slice) - case ArrayTy: - if dstVal.Kind() != reflect.Array { - return fmt.Errorf("abi: invalid dst value for unpack, want array, got %s", dstVal.Kind()) - } - array := reflect.New(dstVal.Type()).Elem() - for i := 0; i < array.Len(); i++ { - if err := unpack(t.Elem, array.Index(i).Addr().Interface(), srcVal.Index(i).Interface()); err != nil { - return err - } - } - dstVal.Set(array) + return nil // Nothing to copy, return } - return nil + if arguments.isTuple() { + return arguments.copyTuple(v, values) + } + return arguments.copyAtomic(v, values[0]) } // unpackAtomic unpacks ( hexdata -> go ) a single value -func (arguments Arguments) unpackAtomic(v interface{}, marshalledValues interface{}) error { - if arguments.LengthNonIndexed() == 0 { - return nil - } - argument := arguments.NonIndexed()[0] - elem := reflect.ValueOf(v).Elem() +func (arguments Arguments) copyAtomic(v interface{}, marshalledValues interface{}) error { + dst := reflect.ValueOf(v).Elem() + src := 
reflect.ValueOf(marshalledValues) - if elem.Kind() == reflect.Struct { - fieldmap, err := mapArgNamesToStructFields([]string{argument.Name}, elem) - if err != nil { - return err - } - field := elem.FieldByName(fieldmap[argument.Name]) - if !field.IsValid() { - return fmt.Errorf("abi: field %s can't be found in the given value", argument.Name) - } - return unpack(&argument.Type, field.Addr().Interface(), marshalledValues) + if dst.Kind() == reflect.Struct { + return set(dst.Field(0), src) } - return unpack(&argument.Type, elem.Addr().Interface(), marshalledValues) + return set(dst, src) } -// unpackTuple unpacks ( hexdata -> go ) a batch of values. -func (arguments Arguments) unpackTuple(v interface{}, marshalledValues []interface{}) error { - var ( - value = reflect.ValueOf(v).Elem() - typ = value.Type() - kind = value.Kind() - ) - if err := requireUnpackKind(value, typ, kind, arguments); err != nil { - return err - } +// copyTuple copies a batch of values from marshalledValues to v. +func (arguments Arguments) copyTuple(v interface{}, marshalledValues []interface{}) error { + value := reflect.ValueOf(v).Elem() + nonIndexedArgs := arguments.NonIndexed() - // If the interface is a struct, get of abi->struct_field mapping - var abi2struct map[string]string - if kind == reflect.Struct { - var ( - argNames []string - err error - ) - for _, arg := range arguments.NonIndexed() { - argNames = append(argNames, arg.Name) + switch value.Kind() { + case reflect.Struct: + argNames := make([]string, len(nonIndexedArgs)) + for i, arg := range nonIndexedArgs { + argNames[i] = arg.Name } - abi2struct, err = mapArgNamesToStructFields(argNames, value) + var err error + abi2struct, err := mapArgNamesToStructFields(argNames, value) if err != nil { return err } - } - for i, arg := range arguments.NonIndexed() { - switch kind { - case reflect.Struct: + for i, arg := range nonIndexedArgs { field := value.FieldByName(abi2struct[arg.Name]) if !field.IsValid() { return fmt.Errorf("abi: field %s can't be found in the given value", arg.Name) } - if err := unpack(&arg.Type, field.Addr().Interface(), marshalledValues[i]); err != nil { - return err - } - case reflect.Slice, reflect.Array: - if value.Len() < i { - return fmt.Errorf("abi: insufficient number of arguments for unpack, want %d, got %d", len(arguments), value.Len()) - } - v := value.Index(i) - if err := requireAssignable(v, reflect.ValueOf(marshalledValues[i])); err != nil { + if err := set(field, reflect.ValueOf(marshalledValues[i])); err != nil { return err } - if err := unpack(&arg.Type, v.Addr().Interface(), marshalledValues[i]); err != nil { + } + case reflect.Slice, reflect.Array: + if value.Len() < len(marshalledValues) { + return fmt.Errorf("abi: insufficient number of arguments for unpack, want %d, got %d", len(arguments), value.Len()) + } + for i := range nonIndexedArgs { + if err := set(value.Index(i), reflect.ValueOf(marshalledValues[i])); err != nil { return err } - default: - return fmt.Errorf("abi:[2] cannot unmarshal tuple in to %v", typ) } + default: + return fmt.Errorf("abi:[2] cannot unmarshal tuple in to %v", value.Type()) } return nil - } // UnpackValues can be used to unpack ABI-encoded hexdata according to the ABI-specification, // without supplying a struct to unpack into. Instead, this method returns a list containing the // values. An atomic argument will be a list with one element. 
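
// A minimal usage sketch for UnpackValues as documented above, decoding two
// uint256 head words without supplying a destination struct; the method name
// and values are assumptions for illustration.
package main

import (
	"fmt"
	"strings"

	"github.com/tomochain/tomochain/accounts/abi"
	"github.com/tomochain/tomochain/common"
)

func main() {
	const def = `[{"name":"f","type":"function","inputs":[{"name":"a","type":"uint256"},{"name":"b","type":"uint256"}]}]`
	parsed, err := abi.JSON(strings.NewReader(def))
	if err != nil {
		panic(err)
	}
	data := append(
		common.LeftPadBytes([]byte{1}, 32),
		common.LeftPadBytes([]byte{2}, 32)...,
	)
	values, err := parsed.Methods["f"].Inputs.UnpackValues(data)
	if err != nil {
		panic(err)
	}
	fmt.Println(values[0], values[1]) // 1 2
}
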
func (arguments Arguments) UnpackValues(data []byte) ([]interface{}, error) { - retval := make([]interface{}, 0, arguments.LengthNonIndexed()) + nonIndexedArgs := arguments.NonIndexed() + retval := make([]interface{}, 0, len(nonIndexedArgs)) virtualArgs := 0 - for index, arg := range arguments.NonIndexed() { + for index, arg := range nonIndexedArgs { marshalledValue, err := toGoType((index+virtualArgs)*32, arg.Type, data) + if err != nil { + return nil, err + } if arg.Type.T == ArrayTy && !isDynamicType(arg.Type) { // If we have a static array, like [3]uint256, these are coded as // just like uint256,uint256,uint256. @@ -262,26 +207,23 @@ func (arguments Arguments) UnpackValues(data []byte) ([]interface{}, error) { // coded as just like uint256,bool,uint256 virtualArgs += getTypeSize(arg.Type)/32 - 1 } - if err != nil { - return nil, err - } retval = append(retval, marshalledValue) } return retval, nil } -// PackValues performs the operation Go format -> Hexdata -// It is the semantic opposite of UnpackValues +// PackValues performs the operation Go format -> Hexdata. +// It is the semantic opposite of UnpackValues. func (arguments Arguments) PackValues(args []interface{}) ([]byte, error) { return arguments.Pack(args...) } -// Pack performs the operation Go format -> Hexdata +// Pack performs the operation Go format -> Hexdata. func (arguments Arguments) Pack(args ...interface{}) ([]byte, error) { // Make sure arguments match up and pack them abiArgs := arguments if len(args) != len(abiArgs) { - return nil, fmt.Errorf("argument count mismatch: %d for %d", len(args), len(abiArgs)) + return nil, fmt.Errorf("argument count mismatch: got %d for %d", len(args), len(abiArgs)) } // variable input is the output appended at the end of packed // output. This is used for strings and bytes types input. diff --git a/accounts/abi/error.go b/accounts/abi/error.go index 9d8674ad08..f0f71b6c91 100644 --- a/accounts/abi/error.go +++ b/accounts/abi/error.go @@ -39,23 +39,21 @@ func formatSliceString(kind reflect.Kind, sliceSize int) string { // type in t. func sliceTypeCheck(t Type, val reflect.Value) error { if val.Kind() != reflect.Slice && val.Kind() != reflect.Array { - return typeErr(formatSliceString(t.Kind, t.Size), val.Type()) + return typeErr(formatSliceString(t.GetType().Kind(), t.Size), val.Type()) } if t.T == ArrayTy && val.Len() != t.Size { - return typeErr(formatSliceString(t.Elem.Kind, t.Size), formatSliceString(val.Type().Elem().Kind(), val.Len())) + return typeErr(formatSliceString(t.Elem.GetType().Kind(), t.Size), formatSliceString(val.Type().Elem().Kind(), val.Len())) } - if t.Elem.T == SliceTy { + if t.Elem.T == SliceTy || t.Elem.T == ArrayTy { if val.Len() > 0 { return sliceTypeCheck(*t.Elem, val.Index(0)) } - } else if t.Elem.T == ArrayTy { - return sliceTypeCheck(*t.Elem, val.Index(0)) } - if elemKind := val.Type().Elem().Kind(); elemKind != t.Elem.Kind { - return typeErr(formatSliceString(t.Elem.Kind, t.Size), val.Type()) + if val.Type().Elem().Kind() != t.Elem.GetType().Kind() { + return typeErr(formatSliceString(t.Elem.GetType().Kind(), t.Size), val.Type()) } return nil } @@ -68,10 +66,10 @@ func typeCheck(t Type, value reflect.Value) error { } // Check base type validity. Element types will be checked later on. 
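
// A hedged illustration of what this base-kind check rejects: packing a Go
// value whose reflect.Kind differs from the ABI type's kind fails with a type
// error rather than mis-encoding (the method name "f" is an assumption).
package main

import (
	"fmt"
	"strings"

	"github.com/tomochain/tomochain/accounts/abi"
)

func main() {
	const def = `[{"name":"f","type":"function","inputs":[{"name":"a","type":"uint64"}]}]`
	parsed, err := abi.JSON(strings.NewReader(def))
	if err != nil {
		panic(err)
	}
	_, err = parsed.Pack("f", int64(7)) // int64 where uint64 is required
	fmt.Println(err != nil)             // true: the kind mismatch is caught here
}
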
- if t.Kind != value.Kind() { - return typeErr(t.Kind, value.Kind()) + if t.GetType().Kind() != value.Kind() { + return typeErr(t.GetType().Kind(), value.Kind()) } else if t.T == FixedBytesTy && t.Size != value.Len() { - return typeErr(t.Type, value.Type()) + return typeErr(t.GetType(), value.Type()) } else { return nil } diff --git a/accounts/abi/event.go b/accounts/abi/event.go index 082fd71aea..d427aac793 100644 --- a/accounts/abi/event.go +++ b/accounts/abi/event.go @@ -28,30 +28,76 @@ import ( // holds type information (inputs) about the yielded output. Anonymous events // don't get the signature canonical representation as the first LOG topic. type Event struct { - Name string + // Name is the event name used for internal representation. It's derived from + // the raw name and a suffix will be added in the case of event overloading. + // + // e.g. + // These are two events that have the same name: + // * foo(int,int) + // * foo(uint,uint) + // The event name of the first one will be resolved as foo while the second one + // will be resolved as foo0. + Name string + + // RawName is the raw event name parsed from ABI. + RawName string Anonymous bool Inputs Arguments + str string + + // Sig contains the string signature according to the ABI spec. + // e.g. event foo(uint32 a, int b) = "foo(uint32,int256)" + // Please note that "int" is substitute for its canonical representation "int256" + Sig string + + // ID returns the canonical representation of the event's signature used by the + // abi definition to identify event names and types. + ID common.Hash } -func (event Event) String() string { - inputs := make([]string, len(event.Inputs)) - for i, input := range event.Inputs { - inputs[i] = fmt.Sprintf("%v %v", input.Name, input.Type) +// NewEvent creates a new Event. +// It sanitizes the input arguments to remove unnamed arguments. +// It also precomputes the id, signature and string representation +// of the event. +func NewEvent(name, rawName string, anonymous bool, inputs Arguments) Event { + // sanitize inputs to remove inputs without names + // and precompute string and sig representation. + names := make([]string, len(inputs)) + types := make([]string, len(inputs)) + for i, input := range inputs { + if input.Name == "" { + inputs[i] = Argument{ + Name: fmt.Sprintf("arg%d", i), + Indexed: input.Indexed, + Type: input.Type, + } + } else { + inputs[i] = input + } + // string representation + names[i] = fmt.Sprintf("%v %v", input.Type, inputs[i].Name) if input.Indexed { - inputs[i] = fmt.Sprintf("%v indexed %v", input.Name, input.Type) + names[i] = fmt.Sprintf("%v indexed %v", input.Type, inputs[i].Name) } + // sig representation + types[i] = input.Type.String() } - return fmt.Sprintf("event %v(%v)", event.Name, strings.Join(inputs, ", ")) -} -// Id returns the canonical representation of the event's signature used by the -// abi definition to identify event names and types. 
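
// The precomputed ID introduced above replaces the removed Id() method below;
// it is simply the Keccak-256 hash of the canonical signature, i.e. the first
// topic of a non-anonymous log. A hedged sketch using the well-known ERC-20
// Transfer event:
package main

import (
	"fmt"

	"github.com/tomochain/tomochain/crypto"
)

func main() {
	topic := crypto.Keccak256Hash([]byte("Transfer(address,address,uint256)"))
	fmt.Println(topic.Hex())
	// 0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef
}
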
-func (e Event) Id() common.Hash { - types := make([]string, len(e.Inputs)) - i := 0 - for _, input := range e.Inputs { - types[i] = input.Type.String() - i++ + str := fmt.Sprintf("event %v(%v)", rawName, strings.Join(names, ", ")) + sig := fmt.Sprintf("%v(%v)", rawName, strings.Join(types, ",")) + id := common.BytesToHash(crypto.Keccak256([]byte(sig))) + + return Event{ + Name: name, + RawName: rawName, + Anonymous: anonymous, + Inputs: inputs, + str: str, + Sig: sig, + ID: id, } - return common.BytesToHash(crypto.Keccak256([]byte(fmt.Sprintf("%v(%v)", e.Name, strings.Join(types, ","))))) +} + +func (e Event) String() string { + return e.str } diff --git a/accounts/abi/event_test.go b/accounts/abi/event_test.go index c39411d8f6..3a39059a4e 100644 --- a/accounts/abi/event_test.go +++ b/accounts/abi/event_test.go @@ -27,6 +27,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/crypto" ) @@ -58,12 +59,28 @@ var jsonEventPledge = []byte(`{ "type": "event" }`) +var jsonEventMixedCase = []byte(`{ + "anonymous": false, + "inputs": [{ + "indexed": false, "name": "value", "type": "uint256" + }, { + "indexed": false, "name": "_value", "type": "uint256" + }, { + "indexed": false, "name": "Value", "type": "uint256" + }], + "name": "MixedCase", + "type": "event" + }`) + // 1000000 var transferData1 = "00000000000000000000000000000000000000000000000000000000000f4240" // "0x00Ce0d46d924CC8437c806721496599FC3FFA268", 2218516807680, "usd" var pledgeData1 = "00000000000000000000000000ce0d46d924cc8437c806721496599fc3ffa2680000000000000000000000000000000000000000000000000000020489e800007573640000000000000000000000000000000000000000000000000000000000" +// 1000000,2218516807680,1000001 +var mixedCaseData1 = "00000000000000000000000000000000000000000000000000000000000f42400000000000000000000000000000000000000000000000000000020489e8000000000000000000000000000000000000000000000000000000000000000f4241" + func TestEventId(t *testing.T) { var table = []struct { definition string @@ -71,12 +88,45 @@ func TestEventId(t *testing.T) { }{ { definition: `[ - { "type" : "event", "name" : "balance", "inputs": [{ "name" : "in", "type": "uint256" }] }, - { "type" : "event", "name" : "check", "inputs": [{ "name" : "t", "type": "address" }, { "name": "b", "type": "uint256" }] } + { "type" : "event", "name" : "Balance", "inputs": [{ "name" : "in", "type": "uint256" }] }, + { "type" : "event", "name" : "Check", "inputs": [{ "name" : "t", "type": "address" }, { "name": "b", "type": "uint256" }] } ]`, expectations: map[string]common.Hash{ - "balance": crypto.Keccak256Hash([]byte("balance(uint256)")), - "check": crypto.Keccak256Hash([]byte("check(address,uint256)")), + "Balance": crypto.Keccak256Hash([]byte("Balance(uint256)")), + "Check": crypto.Keccak256Hash([]byte("Check(address,uint256)")), + }, + }, + } + + for _, test := range table { + abi, err := JSON(strings.NewReader(test.definition)) + if err != nil { + t.Fatal(err) + } + + for name, event := range abi.Events { + if event.ID != test.expectations[name] { + t.Errorf("expected id to be %x, got %x", test.expectations[name], event.ID) + } + } + } +} + +func TestEventString(t *testing.T) { + var table = []struct { + definition string + expectations map[string]string + }{ + { + definition: `[ + { "type" : "event", "name" : "Balance", "inputs": [{ "name" : "in", "type": "uint256" }] }, + { "type" : "event", "name" : "Check", "inputs": [{ "name" : "t", "type": "address" 
}, { "name": "b", "type": "uint256" }] }, + { "type" : "event", "name" : "Transfer", "inputs": [{ "name": "from", "type": "address", "indexed": true }, { "name": "to", "type": "address", "indexed": true }, { "name": "value", "type": "uint256" }] } + ]`, + expectations: map[string]string{ + "Balance": "event Balance(uint256 in)", + "Check": "event Check(address t, uint256 b)", + "Transfer": "event Transfer(address indexed from, address indexed to, uint256 value)", }, }, } @@ -88,8 +138,8 @@ func TestEventId(t *testing.T) { } for name, event := range abi.Events { - if event.Id() != test.expectations[name] { - t.Errorf("expected id to be %x, got %x", test.expectations[name], event.Id()) + if event.String() != test.expectations[name] { + t.Errorf("expected string to be %s, got %s", test.expectations[name], event.String()) } } } @@ -98,10 +148,6 @@ func TestEventId(t *testing.T) { // TestEventMultiValueWithArrayUnpack verifies that array fields will be counted after parsing array. func TestEventMultiValueWithArrayUnpack(t *testing.T) { definition := `[{"name": "test", "type": "event", "inputs": [{"indexed": false, "name":"value1", "type":"uint8[2]"},{"indexed": false, "name":"value2", "type":"uint8"}]}]` - type testStruct struct { - Value1 [2]uint8 - Value2 uint8 - } abi, err := JSON(strings.NewReader(definition)) require.NoError(t, err) var b bytes.Buffer @@ -109,10 +155,10 @@ func TestEventMultiValueWithArrayUnpack(t *testing.T) { for ; i <= 3; i++ { b.Write(packNum(reflect.ValueOf(i))) } - var rst testStruct - require.NoError(t, abi.Unpack(&rst, "test", b.Bytes())) - require.Equal(t, [2]uint8{1, 2}, rst.Value1) - require.Equal(t, uint8(3), rst.Value2) + unpacked, err := abi.Unpack("test", b.Bytes()) + require.NoError(t, err) + require.Equal(t, [2]uint8{1, 2}, unpacked[0]) + require.Equal(t, uint8(3), unpacked[1]) } func TestEventTupleUnpack(t *testing.T) { @@ -121,6 +167,27 @@ func TestEventTupleUnpack(t *testing.T) { Value *big.Int } + type EventTransferWithTag struct { + // this is valid because `value` is not exportable, + // so value is only unmarshalled into `Value1`. 
+ value *big.Int //lint:ignore U1000 unused field is part of test + Value1 *big.Int `abi:"value"` + } + + type BadEventTransferWithSameFieldAndTag struct { + Value *big.Int + Value1 *big.Int `abi:"value"` + } + + type BadEventTransferWithDuplicatedTag struct { + Value1 *big.Int `abi:"value"` + Value2 *big.Int `abi:"value"` + } + + type BadEventTransferWithEmptyTag struct { + Value *big.Int `abi:""` + } + type EventPledge struct { Who common.Address Wad *big.Int @@ -133,9 +200,16 @@ func TestEventTupleUnpack(t *testing.T) { Currency [3]byte } + type EventMixedCase struct { + Value1 *big.Int `abi:"value"` + Value2 *big.Int `abi:"_value"` + Value3 *big.Int `abi:"Value"` + } + bigint := new(big.Int) bigintExpected := big.NewInt(1000000) bigintExpected2 := big.NewInt(2218516807680) + bigintExpected3 := big.NewInt(1000001) addr := common.HexToAddress("0x00Ce0d46d924CC8437c806721496599FC3FFA268") var testCases = []struct { data string @@ -158,6 +232,34 @@ func TestEventTupleUnpack(t *testing.T) { jsonEventTransfer, "", "Can unpack ERC20 Transfer event into slice", + }, { + transferData1, + &EventTransferWithTag{}, + &EventTransferWithTag{Value1: bigintExpected}, + jsonEventTransfer, + "", + "Can unpack ERC20 Transfer event into structure with abi: tag", + }, { + transferData1, + &BadEventTransferWithDuplicatedTag{}, + &BadEventTransferWithDuplicatedTag{}, + jsonEventTransfer, + "struct: abi tag in 'Value2' already mapped", + "Can not unpack ERC20 Transfer event with duplicated abi tag", + }, { + transferData1, + &BadEventTransferWithSameFieldAndTag{}, + &BadEventTransferWithSameFieldAndTag{}, + jsonEventTransfer, + "abi: multiple variables maps to the same abi field 'value'", + "Can not unpack ERC20 Transfer event with a field and a tag mapping to the same abi variable", + }, { + transferData1, + &BadEventTransferWithEmptyTag{}, + &BadEventTransferWithEmptyTag{}, + jsonEventTransfer, + "struct: abi tag in 'Value' is empty", + "Can not unpack ERC20 Transfer event with an empty tag", }, { pledgeData1, &EventPledge{}, @@ -207,15 +309,22 @@ func TestEventTupleUnpack(t *testing.T) { &[]interface{}{common.Address{}, new(big.Int)}, &[]interface{}{}, jsonEventPledge, - "abi: insufficient number of elements in the list/array for unpack, want 3, got 2", + "abi: insufficient number of arguments for unpack, want 3, got 2", "Can not unpack Pledge event into too short slice", }, { pledgeData1, new(map[string]interface{}), &[]interface{}{}, jsonEventPledge, - "abi: cannot unmarshal tuple into map[string]interface {}", + "abi:[2] cannot unmarshal tuple in to map[string]interface {}", "Can not unpack Pledge event into map", + }, { + mixedCaseData1, + &EventMixedCase{}, + &EventMixedCase{Value1: bigintExpected, Value2: bigintExpected2, Value3: bigintExpected3}, + jsonEventMixedCase, + "", + "Can unpack abi variables with mixed case", }} for _, tc := range testCases { @@ -227,7 +336,7 @@ func TestEventTupleUnpack(t *testing.T) { assert.Nil(err, "Should be able to unpack event data.") assert.Equal(tc.expected, tc.dest, tc.name) } else { - assert.EqualError(err, tc.error) + assert.EqualError(err, tc.error, tc.name) } }) } @@ -239,48 +348,14 @@ func unpackTestEventData(dest interface{}, hexData string, jsonEvent []byte, ass var e Event assert.NoError(json.Unmarshal(jsonEvent, &e), "Should be able to unmarshal event ABI") a := ABI{Events: map[string]Event{"e": e}} - return a.Unpack(dest, "e", data) -} - -/* -Taken from -https://github.com/tomochain/tomochain/pull/15568 -*/ - -type testResult struct { - Values 
[2]*big.Int - Value1 *big.Int - Value2 *big.Int -} - -type testCase struct { - definition string - want testResult -} - -func (tc testCase) encoded(intType, arrayType Type) []byte { - var b bytes.Buffer - if tc.want.Value1 != nil { - val, _ := intType.pack(reflect.ValueOf(tc.want.Value1)) - b.Write(val) - } - - if !reflect.DeepEqual(tc.want.Values, [2]*big.Int{nil, nil}) { - val, _ := arrayType.pack(reflect.ValueOf(tc.want.Values)) - b.Write(val) - } - if tc.want.Value2 != nil { - val, _ := intType.pack(reflect.ValueOf(tc.want.Value2)) - b.Write(val) - } - return b.Bytes() + return a.UnpackIntoInterface(dest, "e", data) } // TestEventUnpackIndexed verifies that indexed field will be skipped by event decoder. func TestEventUnpackIndexed(t *testing.T) { definition := `[{"name": "test", "type": "event", "inputs": [{"indexed": true, "name":"value1", "type":"uint8"},{"indexed": false, "name":"value2", "type":"uint8"}]}]` type testStruct struct { - Value1 uint8 + Value1 uint8 // indexed Value2 uint8 } abi, err := JSON(strings.NewReader(definition)) @@ -288,16 +363,16 @@ func TestEventUnpackIndexed(t *testing.T) { var b bytes.Buffer b.Write(packNum(reflect.ValueOf(uint8(8)))) var rst testStruct - require.NoError(t, abi.Unpack(&rst, "test", b.Bytes())) + require.NoError(t, abi.UnpackIntoInterface(&rst, "test", b.Bytes())) require.Equal(t, uint8(0), rst.Value1) require.Equal(t, uint8(8), rst.Value2) } -// TestEventIndexedWithArrayUnpack verifies that decoder will not overlow when static array is indexed input. +// TestEventIndexedWithArrayUnpack verifies that decoder will not overflow when static array is indexed input. func TestEventIndexedWithArrayUnpack(t *testing.T) { definition := `[{"name": "test", "type": "event", "inputs": [{"indexed": true, "name":"value1", "type":"uint8[2]"},{"indexed": false, "name":"value2", "type":"string"}]}]` type testStruct struct { - Value1 [2]uint8 + Value1 [2]uint8 // indexed Value2 string } abi, err := JSON(strings.NewReader(definition)) @@ -310,7 +385,7 @@ func TestEventIndexedWithArrayUnpack(t *testing.T) { b.Write(common.RightPadBytes([]byte(stringOut), 32)) var rst testStruct - require.NoError(t, abi.Unpack(&rst, "test", b.Bytes())) + require.NoError(t, abi.UnpackIntoInterface(&rst, "test", b.Bytes())) require.Equal(t, [2]uint8{0, 0}, rst.Value1) require.Equal(t, stringOut, rst.Value2) } diff --git a/accounts/abi/method.go b/accounts/abi/method.go index 57a2f0e0a4..e2ca384203 100644 --- a/accounts/abi/method.go +++ b/accounts/abi/method.go @@ -23,57 +23,146 @@ import ( "github.com/tomochain/tomochain/crypto" ) +// FunctionType represents different types of functions a contract might have. +type FunctionType int + +const ( + // Constructor represents the constructor of the contract. + // The constructor function is called while deploying a contract. + Constructor FunctionType = iota + // Fallback represents the fallback function. + // This function is executed if no other function matches the given function + // signature and no receive function is specified. + Fallback + // Receive represents the receive function. + // This function is executed on plain Ether transfers. + Receive + // Function represents a normal function. + Function +) + // Method represents a callable given a `Name` and whether the method is a constant. // If the method is `Const` no transaction needs to be created for this // particular Method call. It can easily be simulated using a local VM. 
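
// A hedged sketch of how callers consume this constancy information, using
// the IsConstant helper added further down in this patch (the method name is
// an assumption):
//
//	if parsed.Methods["balanceOf"].IsConstant() {
//		// read-only: serve via a local VM call, no transaction needed
//	} else {
//		// state-changing: build, sign and broadcast a transaction
//	}
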
// For example a `Balance()` method only needs to retrieve something -// from the storage and therefor requires no Tx to be send to the +// from the storage and therefore requires no Tx to be sent to the // network. A method such as `Transact` does require a Tx and thus will -// be flagged `true`. +// be flagged `false`. // Input specifies the required input parameters for this gives method. type Method struct { + // Name is the method name used for internal representation. It's derived from + // the raw name and a suffix will be added in the case of a function overload. + // + // e.g. + // These are two functions that have the same name: + // * foo(int,int) + // * foo(uint,uint) + // The method name of the first one will be resolved as foo while the second one + // will be resolved as foo0. Name string - Const bool + RawName string // RawName is the raw method name parsed from ABI + + // Type indicates whether the method is a + // special fallback introduced in solidity v0.6.0 + Type FunctionType + + // StateMutability indicates the mutability state of method, + // the default value is nonpayable. It can be empty if the abi + // is generated by legacy compiler. + StateMutability string + + // Legacy indicators generated by compiler before v0.6.0 + Constant bool + Payable bool + Inputs Arguments Outputs Arguments + str string + // Sig returns the methods string signature according to the ABI spec. + // e.g. function foo(uint32 a, int b) = "foo(uint32,int256)" + // Please note that "int" is substitute for its canonical representation "int256" + Sig string + // ID returns the canonical representation of the method's signature used by the + // abi definition to identify method names and types. + ID []byte } -// Sig returns the methods string signature according to the ABI spec. -// -// Example -// -// function foo(uint32 a, int b) = "foo(uint32,int256)" -// -// Please note that "int" is substitute for its canonical representation "int256" -func (method Method) Sig() string { - types := make([]string, len(method.Inputs)) - i := 0 - for _, input := range method.Inputs { +// NewMethod creates a new Method. +// A method should always be created using NewMethod. +// It also precomputes the sig representation and the string representation +// of the method. +func NewMethod(name string, rawName string, funType FunctionType, mutability string, isConst, isPayable bool, inputs Arguments, outputs Arguments) Method { + var ( + types = make([]string, len(inputs)) + inputNames = make([]string, len(inputs)) + outputNames = make([]string, len(outputs)) + ) + for i, input := range inputs { + inputNames[i] = fmt.Sprintf("%v %v", input.Type, input.Name) types[i] = input.Type.String() - i++ - } - return fmt.Sprintf("%v(%v)", method.Name, strings.Join(types, ",")) -} - -func (method Method) String() string { - inputs := make([]string, len(method.Inputs)) - for i, input := range method.Inputs { - inputs[i] = fmt.Sprintf("%v %v", input.Name, input.Type) } - outputs := make([]string, len(method.Outputs)) - for i, output := range method.Outputs { + for i, output := range outputs { + outputNames[i] = output.Type.String() if len(output.Name) > 0 { - outputs[i] = fmt.Sprintf("%v ", output.Name) + outputNames[i] += fmt.Sprintf(" %v", output.Name) } - outputs[i] += output.Type.String() } - constant := "" - if method.Const { - constant = "constant " + // calculate the signature and method id. Note only function + // has meaningful signature and id. 
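
// A hedged worked example of the two values computed below, assuming a method
// transfer(address,uint256):
//
//	sig = "transfer(address,uint256)"
//	id  = crypto.Keccak256([]byte(sig))[:4] // 0xa9059cbb
//
// Constructor, fallback and receive entries skip this step, since only
// ordinary functions carry a four-byte selector.
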
+ var ( + sig string + id []byte + ) + if funType == Function { + sig = fmt.Sprintf("%v(%v)", rawName, strings.Join(types, ",")) + id = crypto.Keccak256([]byte(sig))[:4] + } + // Extract meaningful state mutability of solidity method. + // If it's default value, never print it. + state := mutability + if state == "nonpayable" { + state = "" + } + if state != "" { + state = state + " " + } + identity := fmt.Sprintf("function %v", rawName) + switch funType { + case Fallback: + identity = "fallback" + case Receive: + identity = "receive" + case Constructor: + identity = "constructor" + } + str := fmt.Sprintf("%v(%v) %sreturns(%v)", identity, strings.Join(inputNames, ", "), state, strings.Join(outputNames, ", ")) + + return Method{ + Name: name, + RawName: rawName, + Type: funType, + StateMutability: mutability, + Constant: isConst, + Payable: isPayable, + Inputs: inputs, + Outputs: outputs, + str: str, + Sig: sig, + ID: id, } - return fmt.Sprintf("function %v(%v) %sreturns(%v)", method.Name, strings.Join(inputs, ", "), constant, strings.Join(outputs, ", ")) } -func (method Method) Id() []byte { - return crypto.Keccak256([]byte(method.Sig()))[:4] +func (method Method) String() string { + return method.str +} + +// IsConstant returns the indicator whether the method is read-only. +func (method Method) IsConstant() bool { + return method.StateMutability == "view" || method.StateMutability == "pure" || method.Constant +} + +// IsPayable returns the indicator whether the method can process +// plain ether transfers. +func (method Method) IsPayable() bool { + return method.StateMutability == "payable" || method.Payable } diff --git a/accounts/abi/pack.go b/accounts/abi/pack.go index 7d422f579f..5d8b86edb5 100644 --- a/accounts/abi/pack.go +++ b/accounts/abi/pack.go @@ -17,6 +17,8 @@ package abi import ( + "errors" + "fmt" "math/big" "reflect" @@ -25,7 +27,7 @@ import ( ) // packBytesSlice packs the given bytes as [L, V] as the canonical representation -// bytes slice +// bytes slice. func packBytesSlice(bytes []byte, l int) []byte { len := packNum(reflect.ValueOf(l)) return append(len, common.RightPadBytes(bytes, (l+31)/32*32)...) @@ -33,49 +35,51 @@ func packBytesSlice(bytes []byte, l int) []byte { // packElement packs the given reflect value according to the abi specification in // t. 
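
// A hedged sketch of the canonical [L, V] layout packBytesSlice builds: a
// 32-byte big-endian length word followed by the payload right-padded to a
// multiple of 32 bytes (the payload here is illustrative).
package main

import (
	"fmt"

	"github.com/tomochain/tomochain/common"
)

func main() {
	payload := []byte("hello") // L = 5
	enc := append(
		common.LeftPadBytes([]byte{byte(len(payload))}, 32), // length word
		common.RightPadBytes(payload, 32)...,                // padded value
	)
	fmt.Println(len(enc)) // 64: one length word plus one padded data word
}
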
-func packElement(t Type, reflectValue reflect.Value) []byte {
+func packElement(t Type, reflectValue reflect.Value) ([]byte, error) {
 	switch t.T {
 	case IntTy, UintTy:
-		return packNum(reflectValue)
+		return packNum(reflectValue), nil
 	case StringTy:
-		return packBytesSlice([]byte(reflectValue.String()), reflectValue.Len())
+		return packBytesSlice([]byte(reflectValue.String()), reflectValue.Len()), nil
 	case AddressTy:
 		if reflectValue.Kind() == reflect.Array {
 			reflectValue = mustArrayToByteSlice(reflectValue)
 		}
-		return common.LeftPadBytes(reflectValue.Bytes(), 32)
+		return common.LeftPadBytes(reflectValue.Bytes(), 32), nil
 	case BoolTy:
 		if reflectValue.Bool() {
-			return math.PaddedBigBytes(common.Big1, 32)
+			return math.PaddedBigBytes(common.Big1, 32), nil
 		}
-		return math.PaddedBigBytes(common.Big0, 32)
+		return math.PaddedBigBytes(common.Big0, 32), nil
 	case BytesTy:
 		if reflectValue.Kind() == reflect.Array {
 			reflectValue = mustArrayToByteSlice(reflectValue)
 		}
-		return packBytesSlice(reflectValue.Bytes(), reflectValue.Len())
+		if reflectValue.Type() != reflect.TypeOf([]byte{}) {
+			return []byte{}, errors.New("Bytes type is neither slice nor array")
+		}
+		return packBytesSlice(reflectValue.Bytes(), reflectValue.Len()), nil
 	case FixedBytesTy, FunctionTy:
 		if reflectValue.Kind() == reflect.Array {
 			reflectValue = mustArrayToByteSlice(reflectValue)
 		}
-		return common.RightPadBytes(reflectValue.Bytes(), 32)
+		return common.RightPadBytes(reflectValue.Bytes(), 32), nil
 	default:
-		panic("abi: fatal error")
+		return []byte{}, fmt.Errorf("Could not pack element, unknown type: %v", t.T)
 	}
 }
 
-// packNum packs the given number (using the reflect value) and will cast it to appropriate number representation
+// packNum packs the given number (using the reflected value) and will cast it to the appropriate number representation.
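packNum's switch below funnels every integer kind through math.U256Bytes, which reduces the value into the 2**256 domain (two's complement for negatives) and pads it to 32 bytes. A rough standalone equivalent, assuming Go 1.15+ for big.Int.FillBytes and a hypothetical u256Bytes name:

package main

import (
	"fmt"
	"math/big"
)

// u256Bytes sketches the effect of math.U256Bytes: take the value
// modulo 2**256 (so -1 becomes 2**256-1) and left-pad to 32 bytes.
func u256Bytes(x *big.Int) []byte {
	mod := new(big.Int).Lsh(big.NewInt(1), 256) // 2**256
	v := new(big.Int).Mod(x, mod)               // Go's Mod is Euclidean: result is never negative
	return v.FillBytes(make([]byte, 32))
}

func main() {
	fmt.Printf("%x\n", u256Bytes(big.NewInt(2)))  // 00...02
	fmt.Printf("%x\n", u256Bytes(big.NewInt(-1))) // ff...ff, as in the int256 vectors
}

The new(big.Int).Set(...) copy added in the pointer case below matters because the underlying U256 reduction mutates its argument; copying keeps the caller's value untouched.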
func packNum(value reflect.Value) []byte { switch kind := value.Kind(); kind { case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return U256(new(big.Int).SetUint64(value.Uint())) + return math.U256Bytes(new(big.Int).SetUint64(value.Uint())) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return U256(big.NewInt(value.Int())) + return math.U256Bytes(big.NewInt(value.Int())) case reflect.Ptr: - return U256(value.Interface().(*big.Int)) + return math.U256Bytes(new(big.Int).Set(value.Interface().(*big.Int))) default: panic("abi: fatal error") } - } diff --git a/accounts/abi/pack_test.go b/accounts/abi/pack_test.go index 8578d03de2..ed5585b112 100644 --- a/accounts/abi/pack_test.go +++ b/accounts/abi/pack_test.go @@ -18,623 +18,51 @@ package abi import ( "bytes" + "encoding/hex" + "fmt" "math" "math/big" "reflect" + "strconv" "strings" "testing" "github.com/tomochain/tomochain/common" ) +// TestPack tests the general pack/unpack tests in packing_test.go func TestPack(t *testing.T) { - for i, test := range []struct { - typ string - components []ArgumentMarshaling - input interface{} - output []byte - }{ - { - "uint8", - nil, - uint8(2), - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "uint8[]", - nil, - []uint8{1, 2}, - common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "uint16", - nil, - uint16(2), - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "uint16[]", - nil, - []uint16{1, 2}, - common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "uint32", - nil, - uint32(2), - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "uint32[]", - nil, - []uint32{1, 2}, - common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "uint64", - nil, - uint64(2), - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "uint64[]", - nil, - []uint64{1, 2}, - common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "uint256", - nil, - big.NewInt(2), - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "uint256[]", - nil, - []*big.Int{big.NewInt(1), big.NewInt(2)}, - common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "int8", - nil, - int8(2), - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "int8[]", - nil, - []int8{1, 2}, - common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), - 
}, - { - "int16", - nil, - int16(2), - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "int16[]", - nil, - []int16{1, 2}, - common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "int32", - nil, - int32(2), - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "int32[]", - nil, - []int32{1, 2}, - common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "int64", - nil, - int64(2), - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "int64[]", - nil, - []int64{1, 2}, - common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "int256", - nil, - big.NewInt(2), - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "int256[]", - nil, - []*big.Int{big.NewInt(1), big.NewInt(2)}, - common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002"), - }, - { - "bytes1", - nil, - [1]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes2", - nil, - [2]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes3", - nil, - [3]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes4", - nil, - [4]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes5", - nil, - [5]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes6", - nil, - [6]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes7", - nil, - [7]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes8", - nil, - [8]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes9", - nil, - [9]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes10", - nil, - [10]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes11", - nil, - [11]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes12", - nil, - [12]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes13", - nil, - [13]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes14", - nil, - [14]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes15", - nil, - [15]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, 
- { - "bytes16", - nil, - [16]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes17", - nil, - [17]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes18", - nil, - [18]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes19", - nil, - [19]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes20", - nil, - [20]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes21", - nil, - [21]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes22", - nil, - [22]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes23", - nil, - [23]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes24", - nil, - [24]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes25", - nil, - [25]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes26", - nil, - [26]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes27", - nil, - [27]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes28", - nil, - [28]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes29", - nil, - [29]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes30", - nil, - [30]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes31", - nil, - [31]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "bytes32", - nil, - [32]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "uint32[2][3][4]", - nil, - [4][3][2]uint32{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}, {{13, 14}, {15, 16}, {17, 18}}, {{19, 20}, {21, 22}, {23, 24}}}, - 
common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000050000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000700000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000d000000000000000000000000000000000000000000000000000000000000000e000000000000000000000000000000000000000000000000000000000000000f000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000110000000000000000000000000000000000000000000000000000000000000012000000000000000000000000000000000000000000000000000000000000001300000000000000000000000000000000000000000000000000000000000000140000000000000000000000000000000000000000000000000000000000000015000000000000000000000000000000000000000000000000000000000000001600000000000000000000000000000000000000000000000000000000000000170000000000000000000000000000000000000000000000000000000000000018"), - }, - { - "address[]", - nil, - []common.Address{{1}, {2}}, - common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000200000000000000000000000001000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000"), - }, - { - "bytes32[]", - nil, - []common.Hash{{1}, {2}}, - common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000201000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000"), - }, - { - "function", - nil, - [24]byte{1}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - "string", - nil, - "foobar", - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000006666f6f6261720000000000000000000000000000000000000000000000000000"), - }, - { - "string[]", - nil, - []string{"hello", "foobar"}, - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002" + // len(array) = 2 - "0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0 - "0000000000000000000000000000000000000000000000000000000000000080" + // offset 128 to i = 1 - "0000000000000000000000000000000000000000000000000000000000000005" + // len(str[0]) = 5 - "68656c6c6f000000000000000000000000000000000000000000000000000000" + // str[0] - "0000000000000000000000000000000000000000000000000000000000000006" + // len(str[1]) = 6 - "666f6f6261720000000000000000000000000000000000000000000000000000"), // str[1] - }, - { - "string[2]", - nil, - []string{"hello", "foobar"}, - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040" + // offset to i = 0 - "0000000000000000000000000000000000000000000000000000000000000080" + // offset to i = 1 - "0000000000000000000000000000000000000000000000000000000000000005" + // len(str[0]) = 5 - "68656c6c6f000000000000000000000000000000000000000000000000000000" + // 
str[0] - "0000000000000000000000000000000000000000000000000000000000000006" + // len(str[1]) = 6 - "666f6f6261720000000000000000000000000000000000000000000000000000"), // str[1] - }, - { - "bytes32[][]", - nil, - [][]common.Hash{{{1}, {2}}, {{3}, {4}, {5}}}, - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002" + // len(array) = 2 - "0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0 - "00000000000000000000000000000000000000000000000000000000000000a0" + // offset 160 to i = 1 - "0000000000000000000000000000000000000000000000000000000000000002" + // len(array[0]) = 2 - "0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0] - "0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1] - "0000000000000000000000000000000000000000000000000000000000000003" + // len(array[1]) = 3 - "0300000000000000000000000000000000000000000000000000000000000000" + // array[1][0] - "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1] - "0500000000000000000000000000000000000000000000000000000000000000"), // array[1][2] - }, - - { - "bytes32[][2]", - nil, - [][]common.Hash{{{1}, {2}}, {{3}, {4}, {5}}}, - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0 - "00000000000000000000000000000000000000000000000000000000000000a0" + // offset 160 to i = 1 - "0000000000000000000000000000000000000000000000000000000000000002" + // len(array[0]) = 2 - "0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0] - "0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1] - "0000000000000000000000000000000000000000000000000000000000000003" + // len(array[1]) = 3 - "0300000000000000000000000000000000000000000000000000000000000000" + // array[1][0] - "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1] - "0500000000000000000000000000000000000000000000000000000000000000"), // array[1][2] - }, - - { - "bytes32[3][2]", - nil, - [][]common.Hash{{{1}, {2}, {3}}, {{3}, {4}, {5}}}, - common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0] - "0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1] - "0300000000000000000000000000000000000000000000000000000000000000" + // array[0][2] - "0300000000000000000000000000000000000000000000000000000000000000" + // array[1][0] - "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1] - "0500000000000000000000000000000000000000000000000000000000000000"), // array[1][2] - }, - { - // static tuple - "tuple", - []ArgumentMarshaling{ - {Name: "a", Type: "int64"}, - {Name: "b", Type: "int256"}, - {Name: "c", Type: "int256"}, - {Name: "d", Type: "bool"}, - {Name: "e", Type: "bytes32[3][2]"}, - }, - struct { - A int64 - B *big.Int - C *big.Int - D bool - E [][]common.Hash - }{1, big.NewInt(1), big.NewInt(-1), true, [][]common.Hash{{{1}, {2}, {3}}, {{3}, {4}, {5}}}}, - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001" + // struct[a] - "0000000000000000000000000000000000000000000000000000000000000001" + // struct[b] - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // struct[c] - "0000000000000000000000000000000000000000000000000000000000000001" + // struct[d] - "0100000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[0][0] - 
"0200000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[0][1] - "0300000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[0][2] - "0300000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[1][0] - "0400000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[1][1] - "0500000000000000000000000000000000000000000000000000000000000000"), // struct[e] array[1][2] - }, - { - // dynamic tuple - "tuple", - []ArgumentMarshaling{ - {Name: "a", Type: "string"}, - {Name: "b", Type: "int64"}, - {Name: "c", Type: "bytes"}, - {Name: "d", Type: "string[]"}, - {Name: "e", Type: "int256[]"}, - {Name: "f", Type: "address[]"}, - }, - struct { - FieldA string `abi:"a"` // Test whether abi tag works - FieldB int64 `abi:"b"` - C []byte - D []string - E []*big.Int - F []common.Address - }{"foobar", 1, []byte{1}, []string{"foo", "bar"}, []*big.Int{big.NewInt(1), big.NewInt(-1)}, []common.Address{{1}, {2}}}, - common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000000000c0" + // struct[a] offset - "0000000000000000000000000000000000000000000000000000000000000001" + // struct[b] - "0000000000000000000000000000000000000000000000000000000000000100" + // struct[c] offset - "0000000000000000000000000000000000000000000000000000000000000140" + // struct[d] offset - "0000000000000000000000000000000000000000000000000000000000000220" + // struct[e] offset - "0000000000000000000000000000000000000000000000000000000000000280" + // struct[f] offset - "0000000000000000000000000000000000000000000000000000000000000006" + // struct[a] length - "666f6f6261720000000000000000000000000000000000000000000000000000" + // struct[a] "foobar" - "0000000000000000000000000000000000000000000000000000000000000001" + // struct[c] length - "0100000000000000000000000000000000000000000000000000000000000000" + // []byte{1} - "0000000000000000000000000000000000000000000000000000000000000002" + // struct[d] length - "0000000000000000000000000000000000000000000000000000000000000040" + // foo offset - "0000000000000000000000000000000000000000000000000000000000000080" + // bar offset - "0000000000000000000000000000000000000000000000000000000000000003" + // foo length - "666f6f0000000000000000000000000000000000000000000000000000000000" + // foo - "0000000000000000000000000000000000000000000000000000000000000003" + // bar offset - "6261720000000000000000000000000000000000000000000000000000000000" + // bar - "0000000000000000000000000000000000000000000000000000000000000002" + // struct[e] length - "0000000000000000000000000000000000000000000000000000000000000001" + // 1 - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // -1 - "0000000000000000000000000000000000000000000000000000000000000002" + // struct[f] length - "0000000000000000000000000100000000000000000000000000000000000000" + // common.Address{1} - "0000000000000000000000000200000000000000000000000000000000000000"), // common.Address{2} - }, - { - // nested tuple - "tuple", - []ArgumentMarshaling{ - {Name: "a", Type: "tuple", Components: []ArgumentMarshaling{{Name: "a", Type: "uint256"}, {Name: "b", Type: "uint256[]"}}}, - {Name: "b", Type: "int256[]"}, - }, - struct { - A struct { - FieldA *big.Int `abi:"a"` - B []*big.Int - } - B []*big.Int - }{ - A: struct { - FieldA *big.Int `abi:"a"` // Test whether abi tag works for nested tuple - B []*big.Int - }{big.NewInt(1), []*big.Int{big.NewInt(1), big.NewInt(0)}}, - B: []*big.Int{big.NewInt(1), 
big.NewInt(0)}}, - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040" + // a offset - "00000000000000000000000000000000000000000000000000000000000000e0" + // b offset - "0000000000000000000000000000000000000000000000000000000000000001" + // a.a value - "0000000000000000000000000000000000000000000000000000000000000040" + // a.b offset - "0000000000000000000000000000000000000000000000000000000000000002" + // a.b length - "0000000000000000000000000000000000000000000000000000000000000001" + // a.b[0] value - "0000000000000000000000000000000000000000000000000000000000000000" + // a.b[1] value - "0000000000000000000000000000000000000000000000000000000000000002" + // b length - "0000000000000000000000000000000000000000000000000000000000000001" + // b[0] value - "0000000000000000000000000000000000000000000000000000000000000000"), // b[1] value - }, - { - // tuple slice - "tuple[]", - []ArgumentMarshaling{ - {Name: "a", Type: "int256"}, - {Name: "b", Type: "int256[]"}, - }, - []struct { - A *big.Int - B []*big.Int - }{ - {big.NewInt(-1), []*big.Int{big.NewInt(1), big.NewInt(0)}}, - {big.NewInt(1), []*big.Int{big.NewInt(2), big.NewInt(-1)}}, - }, - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002" + // tuple length - "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[0] offset - "00000000000000000000000000000000000000000000000000000000000000e0" + // tuple[1] offset - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // tuple[0].A - "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[0].B offset - "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[0].B length - "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[0].B[0] value - "0000000000000000000000000000000000000000000000000000000000000000" + // tuple[0].B[1] value - "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[1].A - "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[1].B offset - "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[1].B length - "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[1].B[0] value - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), // tuple[1].B[1] value - }, - { - // static tuple array - "tuple[2]", - []ArgumentMarshaling{ - {Name: "a", Type: "int256"}, - {Name: "b", Type: "int256"}, - }, - [2]struct { - A *big.Int - B *big.Int - }{ - {big.NewInt(-1), big.NewInt(1)}, - {big.NewInt(1), big.NewInt(-1)}, - }, - common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // tuple[0].a - "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[0].b - "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[1].a - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), // tuple[1].b - }, - { - // dynamic tuple array - "tuple[2]", - []ArgumentMarshaling{ - {Name: "a", Type: "int256[]"}, - }, - [2]struct { - A []*big.Int - }{ - {[]*big.Int{big.NewInt(-1), big.NewInt(1)}}, - {[]*big.Int{big.NewInt(1), big.NewInt(-1)}}, - }, - common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000040" + // tuple[0] offset - "00000000000000000000000000000000000000000000000000000000000000c0" + // tuple[1] offset - "0000000000000000000000000000000000000000000000000000000000000020" + // tuple[0].A offset - 
"0000000000000000000000000000000000000000000000000000000000000002" + // tuple[0].A length - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // tuple[0].A[0] - "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[0].A[1] - "0000000000000000000000000000000000000000000000000000000000000020" + // tuple[1].A offset - "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[1].A length - "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[1].A[0] - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), // tuple[1].A[1] - }, - } { - typ, err := NewType(test.typ, test.components) - if err != nil { - t.Fatalf("%v failed. Unexpected parse error: %v", i, err) - } - output, err := typ.pack(reflect.ValueOf(test.input)) - if err != nil { - t.Fatalf("%v failed. Unexpected pack error: %v", i, err) - } - - if !bytes.Equal(output, test.output) { - t.Errorf("input %d for typ: %v failed. Expected bytes: '%x' Got: '%x'", i, typ.String(), test.output, output) - } + for i, test := range packUnpackTests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + encb, err := hex.DecodeString(test.packed) + if err != nil { + t.Fatalf("invalid hex %s: %v", test.packed, err) + } + inDef := fmt.Sprintf(`[{ "name" : "method", "type": "function", "inputs": %s}]`, test.def) + inAbi, err := JSON(strings.NewReader(inDef)) + if err != nil { + t.Fatalf("invalid ABI definition %s, %v", inDef, err) + } + var packed []byte + packed, err = inAbi.Pack("method", test.unpacked) + + if err != nil { + t.Fatalf("test %d (%v) failed: %v", i, test.def, err) + } + if !reflect.DeepEqual(packed[4:], encb) { + t.Errorf("test %d (%v) failed: expected %v, got %v", i, test.def, encb, packed[4:]) + } + }) } } func TestMethodPack(t *testing.T) { - abi, err := JSON(strings.NewReader(jsondata2)) + abi, err := JSON(strings.NewReader(jsondata)) if err != nil { t.Fatal(err) } - sig := abi.Methods["slice"].Id() + sig := abi.Methods["slice"].ID sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) @@ -648,7 +76,7 @@ func TestMethodPack(t *testing.T) { } var addrA, addrB = common.Address{1}, common.Address{2} - sig = abi.Methods["sliceAddress"].Id() + sig = abi.Methods["sliceAddress"].ID sig = append(sig, common.LeftPadBytes([]byte{32}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) sig = append(sig, common.LeftPadBytes(addrA[:], 32)...) @@ -663,7 +91,7 @@ func TestMethodPack(t *testing.T) { } var addrC, addrD = common.Address{3}, common.Address{4} - sig = abi.Methods["sliceMultiAddress"].Id() + sig = abi.Methods["sliceMultiAddress"].ID sig = append(sig, common.LeftPadBytes([]byte{64}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{160}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) @@ -681,7 +109,7 @@ func TestMethodPack(t *testing.T) { t.Errorf("expected %x got %x", sig, packed) } - sig = abi.Methods["slice256"].Id() + sig = abi.Methods["slice256"].ID sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) @@ -695,7 +123,7 @@ func TestMethodPack(t *testing.T) { } a := [2][2]*big.Int{{big.NewInt(1), big.NewInt(1)}, {big.NewInt(2), big.NewInt(0)}} - sig = abi.Methods["nestedArray"].Id() + sig = abi.Methods["nestedArray"].ID sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{1}, 32)...) sig = append(sig, common.LeftPadBytes([]byte{2}, 32)...) 
@@ -712,7 +140,7 @@ func TestMethodPack(t *testing.T) {
 		t.Errorf("expected %x got %x", sig, packed)
 	}
 
-	sig = abi.Methods["nestedArray2"].Id()
+	sig = abi.Methods["nestedArray2"].ID
 	sig = append(sig, common.LeftPadBytes([]byte{0x20}, 32)...)
 	sig = append(sig, common.LeftPadBytes([]byte{0x40}, 32)...)
 	sig = append(sig, common.LeftPadBytes([]byte{0x80}, 32)...)
@@ -728,7 +156,7 @@ func TestMethodPack(t *testing.T) {
 		t.Errorf("expected %x got %x", sig, packed)
 	}
 
-	sig = abi.Methods["nestedSlice"].Id()
+	sig = abi.Methods["nestedSlice"].ID
 	sig = append(sig, common.LeftPadBytes([]byte{0x20}, 32)...)
 	sig = append(sig, common.LeftPadBytes([]byte{0x02}, 32)...)
 	sig = append(sig, common.LeftPadBytes([]byte{0x40}, 32)...)
diff --git a/accounts/abi/packing_test.go b/accounts/abi/packing_test.go
new file mode 100644
index 0000000000..bdf00273aa
--- /dev/null
+++ b/accounts/abi/packing_test.go
@@ -0,0 +1,990 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package abi
+
+import (
+	"math/big"
+
+	"github.com/tomochain/tomochain/common"
+)
+
+type packUnpackTest struct {
+	def      string
+	unpacked interface{}
+	packed   string
+}
+
+var packUnpackTests = []packUnpackTest{
+	// Booleans
+	{
+		def:      `[{ "type": "bool" }]`,
+		packed:   "0000000000000000000000000000000000000000000000000000000000000001",
+		unpacked: true,
+	},
+	{
+		def:      `[{ "type": "bool" }]`,
+		packed:   "0000000000000000000000000000000000000000000000000000000000000000",
+		unpacked: false,
+	},
+	// Integers
+	{
+		def:      `[{ "type": "uint8" }]`,
+		unpacked: uint8(2),
+		packed:   "0000000000000000000000000000000000000000000000000000000000000002",
+	},
+	{
+		def:      `[{ "type": "uint8[]" }]`,
+		unpacked: []uint8{1, 2},
+		packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+			"0000000000000000000000000000000000000000000000000000000000000002" +
+			"0000000000000000000000000000000000000000000000000000000000000001" +
+			"0000000000000000000000000000000000000000000000000000000000000002",
+	},
+	{
+		def:      `[{ "type": "uint16" }]`,
+		unpacked: uint16(2),
+		packed:   "0000000000000000000000000000000000000000000000000000000000000002",
+	},
+	{
+		def:      `[{ "type": "uint16[]" }]`,
+		unpacked: []uint16{1, 2},
+		packed: "0000000000000000000000000000000000000000000000000000000000000020" +
+			"0000000000000000000000000000000000000000000000000000000000000002" +
+			"0000000000000000000000000000000000000000000000000000000000000001" +
+			"0000000000000000000000000000000000000000000000000000000000000002",
+	},
+	{
+		def:      `[{"type": "uint17"}]`,
+		packed:   "0000000000000000000000000000000000000000000000000000000000000001",
+		unpacked: big.NewInt(1),
+	},
+	{
+		def:      `[{"type": "uint32"}]`,
+		packed:   "0000000000000000000000000000000000000000000000000000000000000001",
+		unpacked: uint32(1),
+	},
+	{
+		def:      `[{"type": "uint32[]"}]`,
+		unpacked:
[]uint32{1, 2}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "uint64"}]`, + unpacked: uint64(2), + packed: "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "uint64[]"}]`, + unpacked: []uint64{1, 2}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "uint256"}]`, + unpacked: big.NewInt(2), + packed: "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "uint256[]"}]`, + unpacked: []*big.Int{big.NewInt(1), big.NewInt(2)}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "int8"}]`, + unpacked: int8(2), + packed: "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "int8[]"}]`, + unpacked: []int8{1, 2}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "int16"}]`, + unpacked: int16(2), + packed: "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "int16[]"}]`, + unpacked: []int16{1, 2}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "int17"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001", + unpacked: big.NewInt(1), + }, + { + def: `[{"type": "int32"}]`, + unpacked: int32(2), + packed: "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "int32"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001", + unpacked: int32(1), + }, + { + def: `[{"type": "int32[]"}]`, + unpacked: []int32{1, 2}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "int64"}]`, + unpacked: int64(2), + packed: "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "int64[]"}]`, + unpacked: []int64{1, 2}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + 
"0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "int256"}]`, + unpacked: big.NewInt(2), + packed: "0000000000000000000000000000000000000000000000000000000000000002", + }, + { + def: `[{"type": "int256"}]`, + packed: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + unpacked: big.NewInt(-1), + }, + { + def: `[{"type": "int256[]"}]`, + unpacked: []*big.Int{big.NewInt(1), big.NewInt(2)}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + }, + // Address + { + def: `[{"type": "address"}]`, + packed: "0000000000000000000000000100000000000000000000000000000000000000", + unpacked: common.Address{1}, + }, + { + def: `[{"type": "address[]"}]`, + unpacked: []common.Address{{1}, {2}}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000100000000000000000000000000000000000000" + + "0000000000000000000000000200000000000000000000000000000000000000", + }, + // Bytes + { + def: `[{"type": "bytes1"}]`, + unpacked: [1]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes2"}]`, + unpacked: [2]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes3"}]`, + unpacked: [3]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes4"}]`, + unpacked: [4]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes5"}]`, + unpacked: [5]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes6"}]`, + unpacked: [6]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes7"}]`, + unpacked: [7]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes8"}]`, + unpacked: [8]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes9"}]`, + unpacked: [9]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes10"}]`, + unpacked: [10]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes11"}]`, + unpacked: [11]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes12"}]`, + unpacked: [12]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes13"}]`, + unpacked: [13]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes14"}]`, + unpacked: [14]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes15"}]`, + unpacked: [15]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes16"}]`, + unpacked: [16]byte{1}, + packed: 
"0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes17"}]`, + unpacked: [17]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes18"}]`, + unpacked: [18]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes19"}]`, + unpacked: [19]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes20"}]`, + unpacked: [20]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes21"}]`, + unpacked: [21]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes22"}]`, + unpacked: [22]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes23"}]`, + unpacked: [23]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes24"}]`, + unpacked: [24]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes25"}]`, + unpacked: [25]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes26"}]`, + unpacked: [26]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes27"}]`, + unpacked: [27]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes28"}]`, + unpacked: [28]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes29"}]`, + unpacked: [29]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes30"}]`, + unpacked: [30]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes31"}]`, + unpacked: [31]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes32"}]`, + unpacked: [32]byte{1}, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "bytes32"}]`, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + unpacked: [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + { + def: `[{"type": "bytes"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000020" + + "0100000000000000000000000000000000000000000000000000000000000000", + unpacked: common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), + }, + { + def: `[{"type": "bytes32"}]`, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + unpacked: [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + // Functions + { + def: `[{"type": "function"}]`, + packed: "0100000000000000000000000000000000000000000000000000000000000000", + unpacked: [24]byte{1}, + }, + // Slice and Array + { + def: `[{"type": "uint8[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + 
"0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: []uint8{1, 2}, + }, + { + def: `[{"type": "uint8[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000000", + unpacked: []uint8{}, + }, + { + def: `[{"type": "uint256[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000000", + unpacked: []*big.Int{}, + }, + { + def: `[{"type": "uint8[2]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: [2]uint8{1, 2}, + }, + { + def: `[{"type": "int8[2]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: [2]int8{1, 2}, + }, + { + def: `[{"type": "int16[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: []int16{1, 2}, + }, + { + def: `[{"type": "int16[2]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: [2]int16{1, 2}, + }, + { + def: `[{"type": "int32[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: []int32{1, 2}, + }, + { + def: `[{"type": "int32[2]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: [2]int32{1, 2}, + }, + { + def: `[{"type": "int64[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: []int64{1, 2}, + }, + { + def: `[{"type": "int64[2]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: [2]int64{1, 2}, + }, + { + def: `[{"type": "int256[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: []*big.Int{big.NewInt(1), big.NewInt(2)}, + }, + { + def: `[{"type": "int256[3]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000003", + unpacked: [3]*big.Int{big.NewInt(1), big.NewInt(2), 
big.NewInt(3)}, + }, + // multi dimensional, if these pass, all types that don't require length prefix should pass + { + def: `[{"type": "uint8[][]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000000", + unpacked: [][]uint8{}, + }, + { + def: `[{"type": "uint8[][]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000040" + + "00000000000000000000000000000000000000000000000000000000000000a0" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: [][]uint8{{1, 2}, {1, 2}}, + }, + { + def: `[{"type": "uint8[][]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000040" + + "00000000000000000000000000000000000000000000000000000000000000a0" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000003" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000003", + unpacked: [][]uint8{{1, 2}, {1, 2, 3}}, + }, + { + def: `[{"type": "uint8[2][2]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: [2][2]uint8{{1, 2}, {1, 2}}, + }, + { + def: `[{"type": "uint8[][2]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000040" + + "0000000000000000000000000000000000000000000000000000000000000060" + + "0000000000000000000000000000000000000000000000000000000000000000" + + "0000000000000000000000000000000000000000000000000000000000000000", + unpacked: [2][]uint8{{}, {}}, + }, + { + def: `[{"type": "uint8[][2]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000040" + + "0000000000000000000000000000000000000000000000000000000000000080" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000001", + unpacked: [2][]uint8{{1}, {1}}, + }, + { + def: `[{"type": "uint8[2][]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + 
"0000000000000000000000000000000000000000000000000000000000000000", + unpacked: [][2]uint8{}, + }, + { + def: `[{"type": "uint8[2][]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: [][2]uint8{{1, 2}}, + }, + { + def: `[{"type": "uint8[2][]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: [][2]uint8{{1, 2}, {1, 2}}, + }, + { + def: `[{"type": "uint16[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: []uint16{1, 2}, + }, + { + def: `[{"type": "uint16[2]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: [2]uint16{1, 2}, + }, + { + def: `[{"type": "uint32[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: []uint32{1, 2}, + }, + { + def: `[{"type": "uint32[2][3][4]"}]`, + unpacked: [4][3][2]uint32{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}, {{13, 14}, {15, 16}, {17, 18}}, {{19, 20}, {21, 22}, {23, 24}}}, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000003" + + "0000000000000000000000000000000000000000000000000000000000000004" + + "0000000000000000000000000000000000000000000000000000000000000005" + + "0000000000000000000000000000000000000000000000000000000000000006" + + "0000000000000000000000000000000000000000000000000000000000000007" + + "0000000000000000000000000000000000000000000000000000000000000008" + + "0000000000000000000000000000000000000000000000000000000000000009" + + "000000000000000000000000000000000000000000000000000000000000000a" + + "000000000000000000000000000000000000000000000000000000000000000b" + + "000000000000000000000000000000000000000000000000000000000000000c" + + "000000000000000000000000000000000000000000000000000000000000000d" + + "000000000000000000000000000000000000000000000000000000000000000e" + + "000000000000000000000000000000000000000000000000000000000000000f" + + "0000000000000000000000000000000000000000000000000000000000000010" + + "0000000000000000000000000000000000000000000000000000000000000011" + + "0000000000000000000000000000000000000000000000000000000000000012" + + "0000000000000000000000000000000000000000000000000000000000000013" + + "0000000000000000000000000000000000000000000000000000000000000014" + + 
"0000000000000000000000000000000000000000000000000000000000000015" + + "0000000000000000000000000000000000000000000000000000000000000016" + + "0000000000000000000000000000000000000000000000000000000000000017" + + "0000000000000000000000000000000000000000000000000000000000000018", + }, + + { + def: `[{"type": "bytes32[]"}]`, + unpacked: [][32]byte{{1}, {2}}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0100000000000000000000000000000000000000000000000000000000000000" + + "0200000000000000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "uint32[2]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: [2]uint32{1, 2}, + }, + { + def: `[{"type": "uint64[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: []uint64{1, 2}, + }, + { + def: `[{"type": "uint64[2]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: [2]uint64{1, 2}, + }, + { + def: `[{"type": "uint256[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: []*big.Int{big.NewInt(1), big.NewInt(2)}, + }, + { + def: `[{"type": "uint256[3]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000003", + unpacked: [3]*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3)}, + }, + { + def: `[{"type": "string[4]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000080" + + "00000000000000000000000000000000000000000000000000000000000000c0" + + "0000000000000000000000000000000000000000000000000000000000000100" + + "0000000000000000000000000000000000000000000000000000000000000140" + + "0000000000000000000000000000000000000000000000000000000000000005" + + "48656c6c6f000000000000000000000000000000000000000000000000000000" + + "0000000000000000000000000000000000000000000000000000000000000005" + + "576f726c64000000000000000000000000000000000000000000000000000000" + + "000000000000000000000000000000000000000000000000000000000000000b" + + "476f2d657468657265756d000000000000000000000000000000000000000000" + + "0000000000000000000000000000000000000000000000000000000000000008" + + "457468657265756d000000000000000000000000000000000000000000000000", + unpacked: [4]string{"Hello", "World", "Go-ethereum", "Ethereum"}, + }, + { + def: `[{"type": "string[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000040" + + "0000000000000000000000000000000000000000000000000000000000000080" + + 
"0000000000000000000000000000000000000000000000000000000000000008" + + "457468657265756d000000000000000000000000000000000000000000000000" + + "000000000000000000000000000000000000000000000000000000000000000b" + + "676f2d657468657265756d000000000000000000000000000000000000000000", + unpacked: []string{"Ethereum", "go-ethereum"}, + }, + { + def: `[{"type": "bytes[]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000040" + + "0000000000000000000000000000000000000000000000000000000000000080" + + "0000000000000000000000000000000000000000000000000000000000000003" + + "f0f0f00000000000000000000000000000000000000000000000000000000000" + + "0000000000000000000000000000000000000000000000000000000000000003" + + "f0f0f00000000000000000000000000000000000000000000000000000000000", + unpacked: [][]byte{{0xf0, 0xf0, 0xf0}, {0xf0, 0xf0, 0xf0}}, + }, + { + def: `[{"type": "uint256[2][][]"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000040" + + "00000000000000000000000000000000000000000000000000000000000000e0" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "00000000000000000000000000000000000000000000000000000000000000c8" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "00000000000000000000000000000000000000000000000000000000000003e8" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "00000000000000000000000000000000000000000000000000000000000000c8" + + "0000000000000000000000000000000000000000000000000000000000000001" + + "00000000000000000000000000000000000000000000000000000000000003e8", + unpacked: [][][2]*big.Int{{{big.NewInt(1), big.NewInt(200)}, {big.NewInt(1), big.NewInt(1000)}}, {{big.NewInt(1), big.NewInt(200)}, {big.NewInt(1), big.NewInt(1000)}}}, + }, + // struct outputs + { + def: `[{"components": [{"name":"int1","type":"int256"},{"name":"int2","type":"int256"}], "type":"tuple"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000002", + unpacked: struct { + Int1 *big.Int + Int2 *big.Int + }{big.NewInt(1), big.NewInt(2)}, + }, + { + def: `[{"components": [{"name":"int_one","type":"int256"}], "type":"tuple"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001", + unpacked: struct { + IntOne *big.Int + }{big.NewInt(1)}, + }, + { + def: `[{"components": [{"name":"int__one","type":"int256"}], "type":"tuple"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001", + unpacked: struct { + IntOne *big.Int + }{big.NewInt(1)}, + }, + { + def: `[{"components": [{"name":"int_one_","type":"int256"}], "type":"tuple"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001", + unpacked: struct { + IntOne *big.Int + }{big.NewInt(1)}, + }, + { + def: `[{"components": [{"name":"int_one","type":"int256"}, {"name":"intone","type":"int256"}], "type":"tuple"}]`, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + + 
"0000000000000000000000000000000000000000000000000000000000000002", + unpacked: struct { + IntOne *big.Int + Intone *big.Int + }{big.NewInt(1), big.NewInt(2)}, + }, + { + def: `[{"type": "string"}]`, + unpacked: "foobar", + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000006" + + "666f6f6261720000000000000000000000000000000000000000000000000000", + }, + { + def: `[{"type": "string[]"}]`, + unpacked: []string{"hello", "foobar"}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + // len(array) = 2 + "0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0 + "0000000000000000000000000000000000000000000000000000000000000080" + // offset 128 to i = 1 + "0000000000000000000000000000000000000000000000000000000000000005" + // len(str[0]) = 5 + "68656c6c6f000000000000000000000000000000000000000000000000000000" + // str[0] + "0000000000000000000000000000000000000000000000000000000000000006" + // len(str[1]) = 6 + "666f6f6261720000000000000000000000000000000000000000000000000000", // str[1] + }, + { + def: `[{"type": "string[2]"}]`, + unpacked: [2]string{"hello", "foobar"}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000040" + // offset to i = 0 + "0000000000000000000000000000000000000000000000000000000000000080" + // offset to i = 1 + "0000000000000000000000000000000000000000000000000000000000000005" + // len(str[0]) = 5 + "68656c6c6f000000000000000000000000000000000000000000000000000000" + // str[0] + "0000000000000000000000000000000000000000000000000000000000000006" + // len(str[1]) = 6 + "666f6f6261720000000000000000000000000000000000000000000000000000", // str[1] + }, + { + def: `[{"type": "bytes32[][]"}]`, + unpacked: [][][32]byte{{{1}, {2}}, {{3}, {4}, {5}}}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + // len(array) = 2 + "0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0 + "00000000000000000000000000000000000000000000000000000000000000a0" + // offset 160 to i = 1 + "0000000000000000000000000000000000000000000000000000000000000002" + // len(array[0]) = 2 + "0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0] + "0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1] + "0000000000000000000000000000000000000000000000000000000000000003" + // len(array[1]) = 3 + "0300000000000000000000000000000000000000000000000000000000000000" + // array[1][0] + "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1] + "0500000000000000000000000000000000000000000000000000000000000000", // array[1][2] + }, + { + def: `[{"type": "bytes32[][2]"}]`, + unpacked: [2][][32]byte{{{1}, {2}}, {{3}, {4}, {5}}}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000040" + // offset 64 to i = 0 + "00000000000000000000000000000000000000000000000000000000000000a0" + // offset 160 to i = 1 + "0000000000000000000000000000000000000000000000000000000000000002" + // len(array[0]) = 2 + "0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0] + 
"0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1] + "0000000000000000000000000000000000000000000000000000000000000003" + // len(array[1]) = 3 + "0300000000000000000000000000000000000000000000000000000000000000" + // array[1][0] + "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1] + "0500000000000000000000000000000000000000000000000000000000000000", // array[1][2] + }, + { + def: `[{"type": "bytes32[3][2]"}]`, + unpacked: [2][3][32]byte{{{1}, {2}, {3}}, {{3}, {4}, {5}}}, + packed: "0100000000000000000000000000000000000000000000000000000000000000" + // array[0][0] + "0200000000000000000000000000000000000000000000000000000000000000" + // array[0][1] + "0300000000000000000000000000000000000000000000000000000000000000" + // array[0][2] + "0300000000000000000000000000000000000000000000000000000000000000" + // array[1][0] + "0400000000000000000000000000000000000000000000000000000000000000" + // array[1][1] + "0500000000000000000000000000000000000000000000000000000000000000", // array[1][2] + }, + { + // static tuple + def: `[{"components": [{"name":"a","type":"int64"}, + {"name":"b","type":"int256"}, + {"name":"c","type":"int256"}, + {"name":"d","type":"bool"}, + {"name":"e","type":"bytes32[3][2]"}], "type":"tuple"}]`, + unpacked: struct { + A int64 + B *big.Int + C *big.Int + D bool + E [2][3][32]byte + }{1, big.NewInt(1), big.NewInt(-1), true, [2][3][32]byte{{{1}, {2}, {3}}, {{3}, {4}, {5}}}}, + packed: "0000000000000000000000000000000000000000000000000000000000000001" + // struct[a] + "0000000000000000000000000000000000000000000000000000000000000001" + // struct[b] + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // struct[c] + "0000000000000000000000000000000000000000000000000000000000000001" + // struct[d] + "0100000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[0][0] + "0200000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[0][1] + "0300000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[0][2] + "0300000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[1][0] + "0400000000000000000000000000000000000000000000000000000000000000" + // struct[e] array[1][1] + "0500000000000000000000000000000000000000000000000000000000000000", // struct[e] array[1][2] + }, + { + def: `[{"components": [{"name":"a","type":"string"}, + {"name":"b","type":"int64"}, + {"name":"c","type":"bytes"}, + {"name":"d","type":"string[]"}, + {"name":"e","type":"int256[]"}, + {"name":"f","type":"address[]"}], "type":"tuple"}]`, + unpacked: struct { + A string + B int64 + C []byte + D []string + E []*big.Int + F []common.Address + }{"foobar", 1, []byte{1}, []string{"foo", "bar"}, []*big.Int{big.NewInt(1), big.NewInt(-1)}, []common.Address{{1}, {2}}}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + // struct a + "00000000000000000000000000000000000000000000000000000000000000c0" + // struct[a] offset + "0000000000000000000000000000000000000000000000000000000000000001" + // struct[b] + "0000000000000000000000000000000000000000000000000000000000000100" + // struct[c] offset + "0000000000000000000000000000000000000000000000000000000000000140" + // struct[d] offset + "0000000000000000000000000000000000000000000000000000000000000220" + // struct[e] offset + "0000000000000000000000000000000000000000000000000000000000000280" + // struct[f] offset + 
"0000000000000000000000000000000000000000000000000000000000000006" + // struct[a] length + "666f6f6261720000000000000000000000000000000000000000000000000000" + // struct[a] "foobar" + "0000000000000000000000000000000000000000000000000000000000000001" + // struct[c] length + "0100000000000000000000000000000000000000000000000000000000000000" + // []byte{1} + "0000000000000000000000000000000000000000000000000000000000000002" + // struct[d] length + "0000000000000000000000000000000000000000000000000000000000000040" + // foo offset + "0000000000000000000000000000000000000000000000000000000000000080" + // bar offset + "0000000000000000000000000000000000000000000000000000000000000003" + // foo length + "666f6f0000000000000000000000000000000000000000000000000000000000" + // foo + "0000000000000000000000000000000000000000000000000000000000000003" + // bar offset + "6261720000000000000000000000000000000000000000000000000000000000" + // bar + "0000000000000000000000000000000000000000000000000000000000000002" + // struct[e] length + "0000000000000000000000000000000000000000000000000000000000000001" + // 1 + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // -1 + "0000000000000000000000000000000000000000000000000000000000000002" + // struct[f] length + "0000000000000000000000000100000000000000000000000000000000000000" + // common.Address{1} + "0000000000000000000000000200000000000000000000000000000000000000", // common.Address{2} + }, + { + def: `[{"components": [{ "type": "tuple","components": [{"name": "a","type": "uint256"}, + {"name": "b","type": "uint256[]"}], + "name": "a","type": "tuple"}, + {"name": "b","type": "uint256[]"}], "type": "tuple"}]`, + unpacked: struct { + A struct { + A *big.Int + B []*big.Int + } + B []*big.Int + }{ + A: struct { + A *big.Int + B []*big.Int + }{big.NewInt(1), []*big.Int{big.NewInt(1), big.NewInt(2)}}, + B: []*big.Int{big.NewInt(1), big.NewInt(2)}}, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + // struct a + "0000000000000000000000000000000000000000000000000000000000000040" + // a offset + "00000000000000000000000000000000000000000000000000000000000000e0" + // b offset + "0000000000000000000000000000000000000000000000000000000000000001" + // a.a value + "0000000000000000000000000000000000000000000000000000000000000040" + // a.b offset + "0000000000000000000000000000000000000000000000000000000000000002" + // a.b length + "0000000000000000000000000000000000000000000000000000000000000001" + // a.b[0] value + "0000000000000000000000000000000000000000000000000000000000000002" + // a.b[1] value + "0000000000000000000000000000000000000000000000000000000000000002" + // b length + "0000000000000000000000000000000000000000000000000000000000000001" + // b[0] value + "0000000000000000000000000000000000000000000000000000000000000002", // b[1] value + }, + + { + def: `[{"components": [{"name": "a","type": "int256"}, + {"name": "b","type": "int256[]"}], + "name": "a","type": "tuple[]"}]`, + unpacked: []struct { + A *big.Int + B []*big.Int + }{ + {big.NewInt(-1), []*big.Int{big.NewInt(1), big.NewInt(3)}}, + {big.NewInt(1), []*big.Int{big.NewInt(2), big.NewInt(-1)}}, + }, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + // tuple length + "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[0] offset + "00000000000000000000000000000000000000000000000000000000000000e0" + // tuple[1] offset + 
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // tuple[0].A + "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[0].B offset + "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[0].B length + "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[0].B[0] value + "0000000000000000000000000000000000000000000000000000000000000003" + // tuple[0].B[1] value + "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[1].A + "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[1].B offset + "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[1].B length + "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[1].B[0] value + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", // tuple[1].B[1] value + }, + { + def: `[{"components": [{"name": "a","type": "int256"}, + {"name": "b","type": "int256"}], + "name": "a","type": "tuple[2]"}]`, + unpacked: [2]struct { + A *big.Int + B *big.Int + }{ + {big.NewInt(-1), big.NewInt(1)}, + {big.NewInt(1), big.NewInt(-1)}, + }, + packed: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // tuple[0].a + "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[0].b + "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[1].a + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", // tuple[1].b + }, + { + def: `[{"components": [{"name": "a","type": "int256[]"}], + "name": "a","type": "tuple[2]"}]`, + unpacked: [2]struct { + A []*big.Int + }{ + {[]*big.Int{big.NewInt(-1), big.NewInt(1)}}, + {[]*big.Int{big.NewInt(1), big.NewInt(-1)}}, + }, + packed: "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000040" + // tuple[0] offset + "00000000000000000000000000000000000000000000000000000000000000c0" + // tuple[1] offset + "0000000000000000000000000000000000000000000000000000000000000020" + // tuple[0].A offset + "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[0].A length + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + // tuple[0].A[0] + "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[0].A[1] + "0000000000000000000000000000000000000000000000000000000000000020" + // tuple[1].A offset + "0000000000000000000000000000000000000000000000000000000000000002" + // tuple[1].A length + "0000000000000000000000000000000000000000000000000000000000000001" + // tuple[1].A[0] + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", // tuple[1].A[1] + }, +} diff --git a/accounts/abi/reflect.go b/accounts/abi/reflect.go index c39b3d0a6b..0f4948ac82 100644 --- a/accounts/abi/reflect.go +++ b/accounts/abi/reflect.go @@ -17,49 +17,76 @@ package abi import ( + "errors" "fmt" + "math/big" "reflect" "strings" ) +// ConvertType converts an interface of a runtime type into a interface of the +// given type +// e.g. 
diff --git a/accounts/abi/reflect.go b/accounts/abi/reflect.go
index c39b3d0a6b..0f4948ac82 100644
--- a/accounts/abi/reflect.go
+++ b/accounts/abi/reflect.go
@@ -17,49 +17,76 @@
 package abi
 
 import (
+	"errors"
 	"fmt"
+	"math/big"
 	"reflect"
 	"strings"
 )
 
+// ConvertType converts an interface of a runtime type into an interface of the
+// given type, e.g. turn
+//
+//	var fields []reflect.StructField
+//
+//	fields = append(fields, reflect.StructField{
+//		Name: "X",
+//		Type: reflect.TypeOf(new(big.Int)),
+//		Tag:  reflect.StructTag("json:\"" + "x" + "\""),
+//	})
+//
+// into
+//
+//	type TupleT struct { X *big.Int }
+func ConvertType(in interface{}, proto interface{}) interface{} {
+	protoType := reflect.TypeOf(proto)
+	if reflect.TypeOf(in).ConvertibleTo(protoType) {
+		return reflect.ValueOf(in).Convert(protoType).Interface()
+	}
+	// Use set as a last-ditch effort.
+	if err := set(reflect.ValueOf(proto), reflect.ValueOf(in)); err != nil {
+		panic(err)
+	}
+	return proto
+}
+
 // indirect recursively dereferences the value until it either gets the value
 // or finds a big.Int
 func indirect(v reflect.Value) reflect.Value {
-	if v.Kind() == reflect.Ptr && v.Elem().Type() != derefbigT {
+	if v.Kind() == reflect.Ptr && v.Elem().Type() != reflect.TypeOf(big.Int{}) {
 		return indirect(v.Elem())
 	}
 	return v
 }
 
-// reflectIntKind returns the reflect type of an integer with the given size and
+// reflectIntType returns the reflect type of an integer with the given size and
 // unsignedness.
-func reflectIntKindAndType(unsigned bool, size int) (reflect.Kind, reflect.Type) {
+func reflectIntType(unsigned bool, size int) reflect.Type {
+	if unsigned {
+		switch size {
+		case 8:
+			return reflect.TypeOf(uint8(0))
+		case 16:
+			return reflect.TypeOf(uint16(0))
+		case 32:
+			return reflect.TypeOf(uint32(0))
+		case 64:
+			return reflect.TypeOf(uint64(0))
+		}
+	}
 	switch size {
 	case 8:
-		if unsigned {
-			return reflect.Uint8, uint8T
-		}
-		return reflect.Int8, int8T
+		return reflect.TypeOf(int8(0))
 	case 16:
-		if unsigned {
-			return reflect.Uint16, uint16T
-		}
-		return reflect.Int16, int16T
+		return reflect.TypeOf(int16(0))
 	case 32:
-		if unsigned {
-			return reflect.Uint32, uint32T
-		}
-		return reflect.Int32, int32T
+		return reflect.TypeOf(int32(0))
 	case 64:
-		if unsigned {
-			return reflect.Uint64, uint64T
-		}
-		return reflect.Int64, int64T
+		return reflect.TypeOf(int64(0))
 	}
-	return reflect.Ptr, bigT
+	return reflect.TypeOf(&big.Int{})
 }
 
-// mustArrayToBytesSlice creates a new byte slice with the exact same size as value
+// mustArrayToByteSlice creates a new byte slice with the exact same size as value
 // and copies the bytes in value to the new slice.
 func mustArrayToByteSlice(value reflect.Value) reflect.Value {
 	slice := reflect.MakeSlice(reflect.TypeOf([]byte{}), value.Len(), value.Len())
@@ -74,14 +101,18 @@
 func set(dst, src reflect.Value) error {
 	dstType, srcType := dst.Type(), src.Type()
 	switch {
-	case dstType.Kind() == reflect.Interface:
+	case dstType.Kind() == reflect.Interface && dst.Elem().IsValid():
 		return set(dst.Elem(), src)
-	case dstType.Kind() == reflect.Ptr && dstType.Elem() != derefbigT:
+	case dstType.Kind() == reflect.Ptr && dstType.Elem() != reflect.TypeOf(big.Int{}):
 		return set(dst.Elem(), src)
 	case srcType.AssignableTo(dstType) && dst.CanSet():
 		dst.Set(src)
-	case dstType.Kind() == reflect.Slice && srcType.Kind() == reflect.Slice:
+	case dstType.Kind() == reflect.Slice && srcType.Kind() == reflect.Slice && dst.CanSet():
 		return setSlice(dst, src)
+	case dstType.Kind() == reflect.Array:
+		return setArray(dst, src)
+	case dstType.Kind() == reflect.Struct:
+		return setStruct(dst, src)
 	default:
 		return fmt.Errorf("abi: cannot unmarshal %v in to %v", src.Type(), dst.Type())
 	}
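A usage sketch for ConvertType: the decoder hands back an anonymous struct, and the caller converts it into its own named type with identical layout. MyTuple is a hypothetical caller-side type, not part of the patch:

package main

import (
	"fmt"
	"math/big"

	"github.com/tomochain/tomochain/accounts/abi"
)

// MyTuple is a hypothetical user-defined type whose layout matches the
// anonymous struct the ABI decoder builds via reflect.StructOf.
type MyTuple struct {
	X *big.Int `json:"x"`
}

func main() {
	anon := struct {
		X *big.Int `json:"x"`
	}{X: big.NewInt(42)}
	// Struct-to-struct conversion succeeds because the field layouts match.
	out := abi.ConvertType(anon, MyTuple{}).(MyTuple)
	fmt.Println(out.X) // 42
}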
@@ -90,38 +121,59 @@
 // setSlice attempts to assign src to dst when slices are not assignable by default
 // e.g. src: [][]byte -> dst: [][15]byte
+// setSlice ignores it if we cannot copy all of src's elements.
 func setSlice(dst, src reflect.Value) error {
 	slice := reflect.MakeSlice(dst.Type(), src.Len(), src.Len())
 	for i := 0; i < src.Len(); i++ {
-		v := src.Index(i)
-		reflect.Copy(slice.Index(i), v)
+		// e.g. [][32]uint8 -> []common.Hash
+		if err := set(slice.Index(i), src.Index(i)); err != nil {
+			return err
+		}
 	}
-
-	dst.Set(slice)
-	return nil
+	if dst.CanSet() {
+		dst.Set(slice)
+		return nil
+	}
+	return errors.New("cannot set slice, destination not settable")
 }
 
-// requireAssignable assures that `dest` is a pointer and it's not an interface.
-func requireAssignable(dst, src reflect.Value) error {
-	if dst.Kind() != reflect.Ptr && dst.Kind() != reflect.Interface {
-		return fmt.Errorf("abi: cannot unmarshal %v into %v", src.Type(), dst.Type())
+func setArray(dst, src reflect.Value) error {
+	if src.Kind() == reflect.Ptr {
+		return set(dst, indirect(src))
 	}
-	return nil
+	array := reflect.New(dst.Type()).Elem()
+	min := src.Len()
+	if src.Len() > dst.Len() {
+		min = dst.Len()
+	}
+	for i := 0; i < min; i++ {
+		if err := set(array.Index(i), src.Index(i)); err != nil {
+			return err
+		}
+	}
+	if dst.CanSet() {
+		dst.Set(array)
+		return nil
+	}
+	return errors.New("cannot set array, destination not settable")
 }
 
-// requireUnpackKind verifies preconditions for unpacking `args` into `kind`
-func requireUnpackKind(v reflect.Value, t reflect.Type, k reflect.Kind,
-	args Arguments) error {
-
-	switch k {
-	case reflect.Struct:
-	case reflect.Slice, reflect.Array:
-		if minLen := args.LengthNonIndexed(); v.Len() < minLen {
-			return fmt.Errorf("abi: insufficient number of elements in the list/array for unpack, want %d, got %d",
-				minLen, v.Len())
+func setStruct(dst, src reflect.Value) error {
+	for i := 0; i < src.NumField(); i++ {
+		srcField := src.Field(i)
+		dstField := dst.Field(i)
+		if !dstField.IsValid() || !srcField.IsValid() {
+			return fmt.Errorf("could not find src field: %v value: %v in destination", srcField.Type().Name(), srcField)
+		}
+		if err := set(dstField, srcField); err != nil {
+			return err
 		}
-	default:
-		return fmt.Errorf("abi: cannot unmarshal tuple into %v", t)
 	}
 	return nil
 }
@@ -152,9 +204,8 @@ func mapArgNamesToStructFields(argNames []string, value reflect.Value) (map[stri
 			continue
 		}
 		// skip fields that have no abi:"" tag.
-		var ok bool
-		var tagName string
-		if tagName, ok = typ.Field(i).Tag.Lookup("abi"); !ok {
+		tagName, ok := typ.Field(i).Tag.Lookup("abi")
+		if !ok {
 			continue
 		}
 		// check if tag is empty.
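The element-wise set calls above lean on Go's assignability rules: an unnamed [32]byte value is assignable to a named 32-byte array type. A self-contained sketch of that mechanism, with Hash standing in for common.Hash:

package main

import (
	"fmt"
	"reflect"
)

// Hash stands in for common.Hash: a named 32-byte array type.
type Hash [32]byte

func main() {
	src := [][32]byte{{1}, {2}}
	dst := reflect.MakeSlice(reflect.TypeOf([]Hash{}), len(src), len(src))
	for i := range src {
		// [32]byte is assignable to Hash: identical underlying type,
		// and the source type is unnamed.
		dst.Index(i).Set(reflect.ValueOf(src[i]))
	}
	fmt.Println(dst.Interface().([]Hash)[0][0]) // 1
}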
diff --git a/accounts/abi/type.go b/accounts/abi/type.go
index 26151dbd3e..d243877961 100644
--- a/accounts/abi/type.go
+++ b/accounts/abi/type.go
@@ -23,6 +23,10 @@
 	"regexp"
 	"strconv"
 	"strings"
+	"unicode"
+	"unicode/utf8"
+
+	"github.com/tomochain/tomochain/common"
 )
 
 // Type enumerator
@@ -42,19 +46,19 @@
 	FunctionTy
 )
 
-// Type is the reflection of the supported argument type
+// Type is the reflection of the supported argument type.
 type Type struct {
 	Elem *Type
-	Kind reflect.Kind
-	Type reflect.Type
 	Size int
 	T    byte // Our own type checking
 
 	stringKind string // holds the unparsed string for deriving signatures
 
 	// Tuple relative fields
-	TupleElems    []*Type  // Type information of all tuple fields
-	TupleRawNames []string // Raw field name of all tuple fields
+	TupleRawName  string       // Raw struct name defined in source code, may be empty.
+	TupleElems    []*Type      // Type information of all tuple fields
+	TupleRawNames []string     // Raw field name of all tuple fields
+	TupleType     reflect.Type // Underlying struct of the tuple
 }
 
 var (
@@ -63,20 +67,24 @@
 )
 
 // NewType creates a new reflection type of abi type given in t.
-func NewType(t string, components []ArgumentMarshaling) (typ Type, err error) {
+func NewType(t string, internalType string, components []ArgumentMarshaling) (typ Type, err error) {
 	// check that array brackets are equal if they exist
 	if strings.Count(t, "[") != strings.Count(t, "]") {
-		return Type{}, fmt.Errorf("invalid arg type in abi")
+		return Type{}, errors.New("invalid arg type in abi")
 	}
-	typ.stringKind = t
 
 	// if there are brackets, get ready to go into slice/array mode and
 	// recursively create the type
 	if strings.Count(t, "[") != 0 {
-		i := strings.LastIndex(t, "[")
+		// Note internalType can be empty here.
+		subInternal := internalType
+		if i := strings.LastIndex(internalType, "["); i != -1 {
+			subInternal = subInternal[:i]
+		}
 		// recursively embed the type
-		embeddedType, err := NewType(t[:i], components)
+		i := strings.LastIndex(t, "[")
+		embeddedType, err := NewType(t[:i], subInternal, components)
 		if err != nil {
 			return Type{}, err
 		}
@@ -89,27 +97,19 @@
 		if len(intz) == 0 {
 			// is a slice
 			typ.T = SliceTy
-			typ.Kind = reflect.Slice
 			typ.Elem = &embeddedType
-			typ.Type = reflect.SliceOf(embeddedType.Type)
-			if embeddedType.T == TupleTy {
-				typ.stringKind = embeddedType.stringKind + sliced
-			}
+			typ.stringKind = embeddedType.stringKind + sliced
 		} else if len(intz) == 1 {
-			// is a array
+			// is an array
 			typ.T = ArrayTy
-			typ.Kind = reflect.Array
 			typ.Elem = &embeddedType
 			typ.Size, err = strconv.Atoi(intz[0])
 			if err != nil {
 				return Type{}, fmt.Errorf("abi: error parsing variable size: %v", err)
 			}
-			typ.Type = reflect.ArrayOf(typ.Size, embeddedType.Type)
-			if embeddedType.T == TupleTy {
-				typ.stringKind = embeddedType.stringKind + sliced
-			}
+			typ.stringKind = embeddedType.stringKind + sliced
 		} else {
-			return Type{}, fmt.Errorf("invalid formatting of array type")
+			return Type{}, errors.New("invalid formatting of array type")
 		}
 		return typ, err
 	}
@@ -138,36 +138,27 @@
 	// varType is the parsed abi type
 	switch varType := parsedType[1]; varType {
 	case "int":
-		typ.Kind, typ.Type = reflectIntKindAndType(false, varSize)
 		typ.Size = varSize
 		typ.T = IntTy
 	case "uint":
-		typ.Kind, typ.Type = reflectIntKindAndType(true, varSize)
 		typ.Size = varSize
 		typ.T = UintTy
 	case "bool":
-		typ.Kind = reflect.Bool
 		typ.T = BoolTy
-		typ.Type = reflect.TypeOf(bool(false))
 	case "address":
-		typ.Kind = reflect.Array
-		typ.Type = addressT
 		typ.Size = 20
 		typ.T = AddressTy
 	case "string":
-		typ.Kind = reflect.String
-		typ.Type = reflect.TypeOf("")
 		typ.T = StringTy
 	case "bytes":
 		if varSize == 0 {
 			typ.T = BytesTy
-			typ.Kind = reflect.Slice
-			typ.Type = reflect.SliceOf(reflect.TypeOf(byte(0)))
 		} else {
+			if varSize > 32 {
+				return Type{}, fmt.Errorf("unsupported arg type: %s", t)
+			}
 			typ.T = FixedBytesTy
-			typ.Kind = reflect.Array
 			typ.Size = varSize
-			typ.Type = reflect.ArrayOf(varSize, reflect.TypeOf(byte(0)))
 		}
 	case "tuple":
 		var (
 			fields     []reflect.StructField
 			elems      []*Type
 			names      []string
 			expression string // canonical parameter expression
+			used       = make(map[string]bool)
 		)
 		expression += "("
 		for idx, c := range components {
-			cType, err := NewType(c.Type, c.Components)
+			cType, err := NewType(c.Type, c.InternalType, c.Components)
 			if err != nil {
 				return Type{}, err
 			}
-			if ToCamelCase(c.Name) == "" {
+			name := ToCamelCase(c.Name)
+			if name == "" {
 				return Type{}, errors.New("abi: purely anonymous or underscored field is not supported")
 			}
+			fieldName := ResolveNameConflict(name, func(s string) bool { return used[s] })
+			used[fieldName] = true
+			if !isValidFieldName(fieldName) {
+				return Type{}, fmt.Errorf("field %d has invalid name", idx)
+			}
 			fields = append(fields, reflect.StructField{
-				Name: ToCamelCase(c.Name), // reflect.StructOf will panic for any unexported field.
-				Type: cType.Type,
+				Name: fieldName, // reflect.StructOf will panic for any unexported field.
+				Type: cType.GetType(),
+				Tag:  reflect.StructTag("json:\"" + c.Name + "\""),
 			})
 			elems = append(elems, &cType)
 			names = append(names, c.Name)
@@ -197,17 +199,26 @@
 		}
 		expression += ")"
-		typ.Kind = reflect.Struct
-		typ.Type = reflect.StructOf(fields)
+
+		typ.TupleType = reflect.StructOf(fields)
 		typ.TupleElems = elems
 		typ.TupleRawNames = names
 		typ.T = TupleTy
 		typ.stringKind = expression
+
+		const structPrefix = "struct "
+		// After solidity 0.5.10, a new field of abi "internalType"
+		// is introduced. From that we can obtain the struct name
+		// user defined in the source code.
+		if internalType != "" && strings.HasPrefix(internalType, structPrefix) {
+			// Foo.Bar type definition is not allowed in golang,
+			// convert the format to FooBar
+			typ.TupleRawName = strings.ReplaceAll(internalType[len(structPrefix):], ".", "")
+		}
+
 	case "function":
-		typ.Kind = reflect.Array
 		typ.T = FunctionTy
 		typ.Size = 24
-		typ.Type = reflect.ArrayOf(24, reflect.TypeOf(byte(0)))
 	default:
 		return Type{}, fmt.Errorf("unsupported arg type: %s", t)
 	}
@@ -215,7 +226,43 @@
 
 	return
 }
 
-// String implements Stringer
+// GetType returns the reflection type of the ABI type.
+func (t Type) GetType() reflect.Type {
+	switch t.T {
+	case IntTy:
+		return reflectIntType(false, t.Size)
+	case UintTy:
+		return reflectIntType(true, t.Size)
+	case BoolTy:
+		return reflect.TypeOf(false)
+	case StringTy:
+		return reflect.TypeOf("")
+	case SliceTy:
+		return reflect.SliceOf(t.Elem.GetType())
+	case ArrayTy:
+		return reflect.ArrayOf(t.Size, t.Elem.GetType())
+	case TupleTy:
+		return t.TupleType
+	case AddressTy:
+		return reflect.TypeOf(common.Address{})
+	case FixedBytesTy:
+		return reflect.ArrayOf(t.Size, reflect.TypeOf(byte(0)))
+	case BytesTy:
+		return reflect.SliceOf(reflect.TypeOf(byte(0)))
+	case HashTy:
+		// hashtype currently not used
+		return reflect.ArrayOf(32, reflect.TypeOf(byte(0)))
+	case FixedPointTy:
+		// fixedpoint type currently not used
+		return reflect.ArrayOf(32, reflect.TypeOf(byte(0)))
+	case FunctionTy:
+		return reflect.ArrayOf(24, reflect.TypeOf(byte(0)))
+	default:
+		panic("invalid type")
+	}
+}
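A sketch of the internalType plumbing end to end, using NewType and GetType as defined above; the struct name "Exchange.Order" and its field set are illustrative assumptions:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/accounts/abi"
)

func main() {
	// "struct Exchange.Order" is the shape solc >= 0.5.10 reports for a
	// user-defined struct; the name here is only an example.
	typ, err := abi.NewType("tuple", "struct Exchange.Order",
		[]abi.ArgumentMarshaling{
			{Name: "amount", Type: "uint256"},
			{Name: "maker", Type: "address"},
		})
	if err != nil {
		panic(err)
	}
	fmt.Println(typ.TupleRawName) // ExchangeOrder ("." stripped, as above)
	fmt.Println(typ.GetType())    // the reflect.StructOf tuple type with json tags
}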
+
+// String implements Stringer.
 func (t Type) String() (out string) {
 	return t.stringKind
 }
@@ -297,11 +344,11 @@ func (t Type) pack(v reflect.Value) ([]byte, error) {
 		return append(ret, tail...), nil
 	default:
-		return packElement(t, v), nil
+		return packElement(t, v)
 	}
 }
 
-// requireLengthPrefix returns whether the type requires any sort of length
+// requiresLengthPrefix returns whether the type requires any sort of length
 // prefixing.
 func (t Type) requiresLengthPrefix() bool {
 	return t.T == StringTy || t.T == BytesTy || t.T == SliceTy
@@ -337,7 +384,7 @@ func isDynamicType(t Type) bool {
 func getTypeSize(t Type) int {
 	if t.T == ArrayTy && !isDynamicType(*t.Elem) {
 		// Recursively calculate type size if it is a nested array
-		if t.Elem.T == ArrayTy {
+		if t.Elem.T == ArrayTy || t.Elem.T == TupleTy {
 			return t.Size * getTypeSize(*t.Elem)
 		}
 		return t.Size * 32
@@ -350,3 +397,30 @@
 	}
 	return 32
 }
+
+// isLetter reports whether a given 'rune' is classified as a Letter.
+// This method is copied from reflect/type.go
+func isLetter(ch rune) bool {
+	return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch >= utf8.RuneSelf && unicode.IsLetter(ch)
+}
+
+// isValidFieldName checks if a string is a valid (struct) field name or not.
+//
+// According to the language spec, a field name should be an identifier.
+//
+//	identifier = letter { letter | unicode_digit } .
+//	letter     = unicode_letter | "_" .
+//
+// This method is copied from reflect/type.go
+func isValidFieldName(fieldName string) bool {
+	for i, c := range fieldName {
+		if i == 0 && !isLetter(c) {
+			return false
+		}
+		if !(isLetter(c) || unicode.IsDigit(c)) {
+			return false
+		}
+	}
+	return len(fieldName) > 0
+}
diff --git a/accounts/abi/type_test.go b/accounts/abi/type_test.go
index 48da0c4ef5..3b89029419 100644
--- a/accounts/abi/type_test.go
+++ b/accounts/abi/type_test.go
@@ -22,10 +22,11 @@
 	"testing"
 
 	"github.com/davecgh/go-spew/spew"
+	"github.com/tomochain/tomochain/common"
 )
 
-// typeWithoutStringer is a alias for the Type type which simply doesn't implement
+// typeWithoutStringer is an alias for the Type type which simply doesn't implement
 // the stringer interface to allow printing type details in the tests below.
type typeWithoutStringer Type @@ -36,58 +37,58 @@ func TestTypeRegexp(t *testing.T) { components []ArgumentMarshaling kind Type }{ - {"bool", nil, Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}}, - {"bool[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]bool(nil)), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}}, - {"bool[2]", nil, Type{Size: 2, Kind: reflect.Array, T: ArrayTy, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}}, - {"bool[2][]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][]"}}, - {"bool[][]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([][]bool{}), Elem: &Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][]"}}, - {"bool[][2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][]bool{}), Elem: &Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][2]"}}, - {"bool[2][2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][2]"}}, - {"bool[2][][2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][][2]bool{}), Elem: &Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][]"}, stringKind: "bool[2][][2]"}}, - {"bool[2][2][2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][2][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][2]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][2]"}, stringKind: "bool[2][2][2]"}}, - {"bool[][][]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][][]bool{}), Elem: &Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][]bool{}), Elem: &Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][]"}, stringKind: "bool[][][]"}}, - {"bool[][2][]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][2][]bool{}), Elem: &Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][]bool{}), Elem: &Type{T: 
SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]bool{}), Elem: &Type{Kind: reflect.Bool, T: BoolTy, Type: reflect.TypeOf(bool(false)), stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][2]"}, stringKind: "bool[][2][]"}}, - {"int8", nil, Type{Kind: reflect.Int8, Type: int8T, Size: 8, T: IntTy, stringKind: "int8"}}, - {"int16", nil, Type{Kind: reflect.Int16, Type: int16T, Size: 16, T: IntTy, stringKind: "int16"}}, - {"int32", nil, Type{Kind: reflect.Int32, Type: int32T, Size: 32, T: IntTy, stringKind: "int32"}}, - {"int64", nil, Type{Kind: reflect.Int64, Type: int64T, Size: 64, T: IntTy, stringKind: "int64"}}, - {"int256", nil, Type{Kind: reflect.Ptr, Type: bigT, Size: 256, T: IntTy, stringKind: "int256"}}, - {"int8[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int8{}), Elem: &Type{Kind: reflect.Int8, Type: int8T, Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[]"}}, - {"int8[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]int8{}), Elem: &Type{Kind: reflect.Int8, Type: int8T, Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[2]"}}, - {"int16[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int16{}), Elem: &Type{Kind: reflect.Int16, Type: int16T, Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[]"}}, - {"int16[2]", nil, Type{Size: 2, Kind: reflect.Array, T: ArrayTy, Type: reflect.TypeOf([2]int16{}), Elem: &Type{Kind: reflect.Int16, Type: int16T, Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[2]"}}, - {"int32[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int32{}), Elem: &Type{Kind: reflect.Int32, Type: int32T, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[]"}}, - {"int32[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]int32{}), Elem: &Type{Kind: reflect.Int32, Type: int32T, Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[2]"}}, - {"int64[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]int64{}), Elem: &Type{Kind: reflect.Int64, Type: int64T, Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[]"}}, - {"int64[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]int64{}), Elem: &Type{Kind: reflect.Int64, Type: int64T, Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[2]"}}, - {"int256[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]*big.Int{}), Elem: &Type{Kind: reflect.Ptr, Type: bigT, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[]"}}, - {"int256[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]*big.Int{}), Elem: &Type{Kind: reflect.Ptr, Type: bigT, Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[2]"}}, - {"uint8", nil, Type{Kind: reflect.Uint8, Type: uint8T, Size: 8, T: UintTy, stringKind: "uint8"}}, - {"uint16", nil, Type{Kind: reflect.Uint16, Type: uint16T, Size: 16, T: UintTy, stringKind: "uint16"}}, - {"uint32", nil, Type{Kind: reflect.Uint32, Type: uint32T, Size: 32, T: UintTy, stringKind: "uint32"}}, - {"uint64", nil, Type{Kind: reflect.Uint64, Type: uint64T, Size: 64, T: UintTy, stringKind: "uint64"}}, - {"uint256", nil, Type{Kind: reflect.Ptr, Type: bigT, Size: 256, T: UintTy, stringKind: "uint256"}}, - {"uint8[]", nil, Type{Kind: reflect.Slice, T: SliceTy, Type: reflect.TypeOf([]uint8{}), Elem: &Type{Kind: reflect.Uint8, Type: uint8T, Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[]"}}, - 
{"uint8[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint8{}), Elem: &Type{Kind: reflect.Uint8, Type: uint8T, Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[2]"}}, - {"uint16[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]uint16{}), Elem: &Type{Kind: reflect.Uint16, Type: uint16T, Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[]"}}, - {"uint16[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint16{}), Elem: &Type{Kind: reflect.Uint16, Type: uint16T, Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[2]"}}, - {"uint32[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]uint32{}), Elem: &Type{Kind: reflect.Uint32, Type: uint32T, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[]"}}, - {"uint32[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint32{}), Elem: &Type{Kind: reflect.Uint32, Type: uint32T, Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[2]"}}, - {"uint64[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]uint64{}), Elem: &Type{Kind: reflect.Uint64, Type: uint64T, Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[]"}}, - {"uint64[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]uint64{}), Elem: &Type{Kind: reflect.Uint64, Type: uint64T, Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[2]"}}, - {"uint256[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]*big.Int{}), Elem: &Type{Kind: reflect.Ptr, Type: bigT, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[]"}}, - {"uint256[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Type: reflect.TypeOf([2]*big.Int{}), Size: 2, Elem: &Type{Kind: reflect.Ptr, Type: bigT, Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[2]"}}, - {"bytes32", nil, Type{Kind: reflect.Array, T: FixedBytesTy, Size: 32, Type: reflect.TypeOf([32]byte{}), stringKind: "bytes32"}}, - {"bytes[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][]byte{}), Elem: &Type{Kind: reflect.Slice, Type: reflect.TypeOf([]byte{}), T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[]"}}, - {"bytes[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][]byte{}), Elem: &Type{T: BytesTy, Type: reflect.TypeOf([]byte{}), Kind: reflect.Slice, stringKind: "bytes"}, stringKind: "bytes[2]"}}, - {"bytes32[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([][32]byte{}), Elem: &Type{Kind: reflect.Array, Type: reflect.TypeOf([32]byte{}), T: FixedBytesTy, Size: 32, stringKind: "bytes32"}, stringKind: "bytes32[]"}}, - {"bytes32[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2][32]byte{}), Elem: &Type{Kind: reflect.Array, T: FixedBytesTy, Size: 32, Type: reflect.TypeOf([32]byte{}), stringKind: "bytes32"}, stringKind: "bytes32[2]"}}, - {"string", nil, Type{Kind: reflect.String, T: StringTy, Type: reflect.TypeOf(""), stringKind: "string"}}, - {"string[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]string{}), Elem: &Type{Kind: reflect.String, Type: reflect.TypeOf(""), T: StringTy, stringKind: "string"}, stringKind: "string[]"}}, - {"string[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]string{}), Elem: &Type{Kind: reflect.String, T: StringTy, Type: reflect.TypeOf(""), stringKind: "string"}, stringKind: 
"string[2]"}}, - {"address", nil, Type{Kind: reflect.Array, Type: addressT, Size: 20, T: AddressTy, stringKind: "address"}}, - {"address[]", nil, Type{T: SliceTy, Kind: reflect.Slice, Type: reflect.TypeOf([]common.Address{}), Elem: &Type{Kind: reflect.Array, Type: addressT, Size: 20, T: AddressTy, stringKind: "address"}, stringKind: "address[]"}}, - {"address[2]", nil, Type{Kind: reflect.Array, T: ArrayTy, Size: 2, Type: reflect.TypeOf([2]common.Address{}), Elem: &Type{Kind: reflect.Array, Type: addressT, Size: 20, T: AddressTy, stringKind: "address"}, stringKind: "address[2]"}}, + {"bool", nil, Type{T: BoolTy, stringKind: "bool"}}, + {"bool[]", nil, Type{T: SliceTy, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[]"}}, + {"bool[2]", nil, Type{Size: 2, T: ArrayTy, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[2]"}}, + {"bool[2][]", nil, Type{T: SliceTy, Elem: &Type{T: ArrayTy, Size: 2, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][]"}}, + {"bool[][]", nil, Type{T: SliceTy, Elem: &Type{T: SliceTy, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][]"}}, + {"bool[][2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: SliceTy, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][2]"}}, + {"bool[2][2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: ArrayTy, Size: 2, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][2]"}}, + {"bool[2][][2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: SliceTy, Elem: &Type{T: ArrayTy, Size: 2, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][]"}, stringKind: "bool[2][][2]"}}, + {"bool[2][2][2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: ArrayTy, Size: 2, Elem: &Type{T: ArrayTy, Size: 2, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[2]"}, stringKind: "bool[2][2]"}, stringKind: "bool[2][2][2]"}}, + {"bool[][][]", nil, Type{T: SliceTy, Elem: &Type{T: SliceTy, Elem: &Type{T: SliceTy, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][]"}, stringKind: "bool[][][]"}}, + {"bool[][2][]", nil, Type{T: SliceTy, Elem: &Type{T: ArrayTy, Size: 2, Elem: &Type{T: SliceTy, Elem: &Type{T: BoolTy, stringKind: "bool"}, stringKind: "bool[]"}, stringKind: "bool[][2]"}, stringKind: "bool[][2][]"}}, + {"int8", nil, Type{Size: 8, T: IntTy, stringKind: "int8"}}, + {"int16", nil, Type{Size: 16, T: IntTy, stringKind: "int16"}}, + {"int32", nil, Type{Size: 32, T: IntTy, stringKind: "int32"}}, + {"int64", nil, Type{Size: 64, T: IntTy, stringKind: "int64"}}, + {"int256", nil, Type{Size: 256, T: IntTy, stringKind: "int256"}}, + {"int8[]", nil, Type{T: SliceTy, Elem: &Type{Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[]"}}, + {"int8[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 8, T: IntTy, stringKind: "int8"}, stringKind: "int8[2]"}}, + {"int16[]", nil, Type{T: SliceTy, Elem: &Type{Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[]"}}, + {"int16[2]", nil, Type{Size: 2, T: ArrayTy, Elem: &Type{Size: 16, T: IntTy, stringKind: "int16"}, stringKind: "int16[2]"}}, + {"int32[]", nil, Type{T: SliceTy, Elem: &Type{Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[]"}}, + {"int32[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 32, T: IntTy, stringKind: "int32"}, stringKind: "int32[2]"}}, + {"int64[]", nil, Type{T: SliceTy, Elem: &Type{Size: 64, T: IntTy, stringKind: "int64"}, 
stringKind: "int64[]"}}, + {"int64[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 64, T: IntTy, stringKind: "int64"}, stringKind: "int64[2]"}}, + {"int256[]", nil, Type{T: SliceTy, Elem: &Type{Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[]"}}, + {"int256[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 256, T: IntTy, stringKind: "int256"}, stringKind: "int256[2]"}}, + {"uint8", nil, Type{Size: 8, T: UintTy, stringKind: "uint8"}}, + {"uint16", nil, Type{Size: 16, T: UintTy, stringKind: "uint16"}}, + {"uint32", nil, Type{Size: 32, T: UintTy, stringKind: "uint32"}}, + {"uint64", nil, Type{Size: 64, T: UintTy, stringKind: "uint64"}}, + {"uint256", nil, Type{Size: 256, T: UintTy, stringKind: "uint256"}}, + {"uint8[]", nil, Type{T: SliceTy, Elem: &Type{Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[]"}}, + {"uint8[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 8, T: UintTy, stringKind: "uint8"}, stringKind: "uint8[2]"}}, + {"uint16[]", nil, Type{T: SliceTy, Elem: &Type{Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[]"}}, + {"uint16[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 16, T: UintTy, stringKind: "uint16"}, stringKind: "uint16[2]"}}, + {"uint32[]", nil, Type{T: SliceTy, Elem: &Type{Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[]"}}, + {"uint32[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 32, T: UintTy, stringKind: "uint32"}, stringKind: "uint32[2]"}}, + {"uint64[]", nil, Type{T: SliceTy, Elem: &Type{Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[]"}}, + {"uint64[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 64, T: UintTy, stringKind: "uint64"}, stringKind: "uint64[2]"}}, + {"uint256[]", nil, Type{T: SliceTy, Elem: &Type{Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[]"}}, + {"uint256[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 256, T: UintTy, stringKind: "uint256"}, stringKind: "uint256[2]"}}, + {"bytes32", nil, Type{T: FixedBytesTy, Size: 32, stringKind: "bytes32"}}, + {"bytes[]", nil, Type{T: SliceTy, Elem: &Type{T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[]"}}, + {"bytes[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: BytesTy, stringKind: "bytes"}, stringKind: "bytes[2]"}}, + {"bytes32[]", nil, Type{T: SliceTy, Elem: &Type{T: FixedBytesTy, Size: 32, stringKind: "bytes32"}, stringKind: "bytes32[]"}}, + {"bytes32[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: FixedBytesTy, Size: 32, stringKind: "bytes32"}, stringKind: "bytes32[2]"}}, + {"string", nil, Type{T: StringTy, stringKind: "string"}}, + {"string[]", nil, Type{T: SliceTy, Elem: &Type{T: StringTy, stringKind: "string"}, stringKind: "string[]"}}, + {"string[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{T: StringTy, stringKind: "string"}, stringKind: "string[2]"}}, + {"address", nil, Type{Size: 20, T: AddressTy, stringKind: "address"}}, + {"address[]", nil, Type{T: SliceTy, Elem: &Type{Size: 20, T: AddressTy, stringKind: "address"}, stringKind: "address[]"}}, + {"address[2]", nil, Type{T: ArrayTy, Size: 2, Elem: &Type{Size: 20, T: AddressTy, stringKind: "address"}, stringKind: "address[2]"}}, // TODO when fixed types are implemented properly // {"fixed", nil, Type{}}, // {"fixed128x128", nil, Type{}}, @@ -95,12 +96,18 @@ func TestTypeRegexp(t *testing.T) { // {"fixed[2]", nil, Type{}}, // {"fixed128x128[]", nil, Type{}}, // {"fixed128x128[2]", nil, Type{}}, - {"tuple", []ArgumentMarshaling{{Name: "a", Type: "int64"}}, Type{Kind: reflect.Struct, T: TupleTy, 
Type: reflect.TypeOf(struct{ A int64 }{}), stringKind: "(int64)", - TupleElems: []*Type{{Kind: reflect.Int64, T: IntTy, Type: reflect.TypeOf(int64(0)), Size: 64, stringKind: "int64"}}, TupleRawNames: []string{"a"}}}, + {"tuple", []ArgumentMarshaling{{Name: "a", Type: "int64"}}, Type{T: TupleTy, TupleType: reflect.TypeOf(struct { + A int64 `json:"a"` + }{}), stringKind: "(int64)", + TupleElems: []*Type{{T: IntTy, Size: 64, stringKind: "int64"}}, TupleRawNames: []string{"a"}}}, + {"tuple with long name", []ArgumentMarshaling{{Name: "aTypicalParamName", Type: "int64"}}, Type{T: TupleTy, TupleType: reflect.TypeOf(struct { + ATypicalParamName int64 `json:"aTypicalParamName"` + }{}), stringKind: "(int64)", + TupleElems: []*Type{{T: IntTy, Size: 64, stringKind: "int64"}}, TupleRawNames: []string{"aTypicalParamName"}}}, } for _, tt := range tests { - typ, err := NewType(tt.blob, tt.components) + typ, err := NewType(tt.blob, "", tt.components) if err != nil { t.Errorf("type %q: failed to parse type string: %v", tt.blob, err) } @@ -275,7 +282,7 @@ func TestTypeCheck(t *testing.T) { B *big.Int }{{big.NewInt(0), big.NewInt(0)}, {big.NewInt(0), big.NewInt(0)}}, ""}, } { - typ, err := NewType(test.typ, test.components) + typ, err := NewType(test.typ, "", test.components) if err != nil && len(test.err) == 0 { t.Fatal("unexpected parse error:", err) } else if err != nil && len(test.err) != 0 { @@ -300,3 +307,63 @@ func TestTypeCheck(t *testing.T) { } } } + +func TestInternalType(t *testing.T) { + components := []ArgumentMarshaling{{Name: "a", Type: "int64"}} + internalType := "struct a.b[]" + kind := Type{ + T: TupleTy, + TupleType: reflect.TypeOf(struct { + A int64 `json:"a"` + }{}), + stringKind: "(int64)", + TupleRawName: "ab[]", + TupleElems: []*Type{{T: IntTy, Size: 64, stringKind: "int64"}}, + TupleRawNames: []string{"a"}, + } + + blob := "tuple" + typ, err := NewType(blob, internalType, components) + if err != nil { + t.Errorf("type %q: failed to parse type string: %v", blob, err) + } + if !reflect.DeepEqual(typ, kind) { + t.Errorf("type %q: parsed type mismatch:\nGOT %s\nWANT %s ", blob, spew.Sdump(typeWithoutStringer(typ)), spew.Sdump(typeWithoutStringer(kind))) + } +} + +func TestGetTypeSize(t *testing.T) { + var testCases = []struct { + typ string + components []ArgumentMarshaling + typSize int + }{ + // simple array + {"uint256[2]", nil, 32 * 2}, + {"address[3]", nil, 32 * 3}, + {"bytes32[4]", nil, 32 * 4}, + // array array + {"uint256[2][3][4]", nil, 32 * (2 * 3 * 4)}, + // array tuple + {"tuple[2]", []ArgumentMarshaling{{Name: "x", Type: "bytes32"}, {Name: "y", Type: "bytes32"}}, (32 * 2) * 2}, + // simple tuple + {"tuple", []ArgumentMarshaling{{Name: "x", Type: "uint256"}, {Name: "y", Type: "uint256"}}, 32 * 2}, + // tuple array + {"tuple", []ArgumentMarshaling{{Name: "x", Type: "bytes32[2]"}}, 32 * 2}, + // tuple tuple + {"tuple", []ArgumentMarshaling{{Name: "x", Type: "tuple", Components: []ArgumentMarshaling{{Name: "x", Type: "bytes32"}}}}, 32}, + {"tuple", []ArgumentMarshaling{{Name: "x", Type: "tuple", Components: []ArgumentMarshaling{{Name: "x", Type: "bytes32[2]"}, {Name: "y", Type: "uint256"}}}}, 32 * (2 + 1)}, + } + + for i, data := range testCases { + typ, err := NewType(data.typ, "", data.components) + if err != nil { + t.Errorf("type %q: failed to parse type string: %v", data.typ, err) + } + + result := getTypeSize(typ) + if result != data.typSize { + t.Errorf("case %d type %q: get type size error: actual: %d expected: %d", i, data.typ, result, data.typSize) + } + } +} diff 
--git a/accounts/abi/unpack.go b/accounts/abi/unpack.go
index 86af6f97b9..927f43afe6 100644
--- a/accounts/abi/unpack.go
+++ b/accounts/abi/unpack.go
@@ -26,52 +26,54 @@ import (
 )
 
 var (
-	maxUint256 = big.NewInt(0).Add(
-		big.NewInt(0).Exp(big.NewInt(2), big.NewInt(256), nil),
-		big.NewInt(-1))
-	maxInt256 = big.NewInt(0).Add(
-		big.NewInt(0).Exp(big.NewInt(2), big.NewInt(255), nil),
-		big.NewInt(-1))
+	// MaxUint256 is the maximum value that can be represented by a uint256.
+	MaxUint256 = new(big.Int).Sub(new(big.Int).Lsh(common.Big1, 256), common.Big1)
+	// MaxInt256 is the maximum value that can be represented by an int256.
+	MaxInt256 = new(big.Int).Sub(new(big.Int).Lsh(common.Big1, 255), common.Big1)
 )
 
-// reads the integer based on its kind
-func readInteger(typ byte, kind reflect.Kind, b []byte) interface{} {
-	switch kind {
-	case reflect.Uint8:
-		return b[len(b)-1]
-	case reflect.Uint16:
-		return binary.BigEndian.Uint16(b[len(b)-2:])
-	case reflect.Uint32:
-		return binary.BigEndian.Uint32(b[len(b)-4:])
-	case reflect.Uint64:
-		return binary.BigEndian.Uint64(b[len(b)-8:])
-	case reflect.Int8:
+// ReadInteger reads the integer based on its kind and returns the appropriate value.
+func ReadInteger(typ Type, b []byte) interface{} {
+	if typ.T == UintTy {
+		switch typ.Size {
+		case 8:
+			return b[len(b)-1]
+		case 16:
+			return binary.BigEndian.Uint16(b[len(b)-2:])
+		case 32:
+			return binary.BigEndian.Uint32(b[len(b)-4:])
+		case 64:
+			return binary.BigEndian.Uint64(b[len(b)-8:])
+		default:
+			// the only case left for unsigned integer is uint256.
+			return new(big.Int).SetBytes(b)
+		}
+	}
+	switch typ.Size {
+	case 8:
 		return int8(b[len(b)-1])
-	case reflect.Int16:
+	case 16:
 		return int16(binary.BigEndian.Uint16(b[len(b)-2:]))
-	case reflect.Int32:
+	case 32:
 		return int32(binary.BigEndian.Uint32(b[len(b)-4:]))
-	case reflect.Int64:
+	case 64:
 		return int64(binary.BigEndian.Uint64(b[len(b)-8:]))
 	default:
-		// the only case lefts for integer is int256/uint256.
-		// big.SetBytes can't tell if a number is negative, positive on itself.
+		// the only case left for integer is int256.
+		// big.SetBytes can't tell if a number is negative or positive on its own.
 		// On EVM, if the returned number > max int256, it is negative.
 		ret := new(big.Int).SetBytes(b)
-		if typ == UintTy {
-			return ret
-		}
-
-		if ret.Cmp(maxInt256) > 0 {
-			ret.Add(maxUint256, big.NewInt(0).Neg(ret))
-			ret.Add(ret, big.NewInt(1))
+		// A value is negative exactly when its sign bit (bit 255) is set.
+		if ret.Bit(255) == 1 {
+			ret.Add(MaxUint256, new(big.Int).Neg(ret))
+			ret.Add(ret, common.Big1)
 			ret.Neg(ret)
 		}
 		return ret
 	}
 }
 
-// reads a bool
+// readBool reads a bool.
 func readBool(word []byte) (bool, error) {
 	for _, b := range word[:31] {
 		if b != 0 {
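The int256 branch of ReadInteger reconstructs a negative value from its two's-complement word. The same arithmetic in isolation, as a standalone sketch:

package main

import (
	"fmt"
	"math/big"
)

func main() {
	one := big.NewInt(1)
	maxUint256 := new(big.Int).Sub(new(big.Int).Lsh(one, 256), one)

	// A 32-byte word holding -1 (all bits set), as the EVM returns it.
	word := make([]byte, 32)
	for i := range word {
		word[i] = 0xff
	}
	ret := new(big.Int).SetBytes(word)
	if ret.Bit(255) == 1 { // sign bit set: the value is negative
		ret.Add(maxUint256, new(big.Int).Neg(ret))
		ret.Add(ret, one)
		ret.Neg(ret)
	}
	fmt.Println(ret) // -1
}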
-// This enforces that standard by always presenting it as a 24-array (address + sig = 24 bytes) +// +// readFunctionType enforces that standard by always presenting it as a 24-array (address + sig = 24 bytes) func readFunctionType(t Type, word []byte) (funcTy [24]byte, err error) { if t.T != FunctionTy { return [24]byte{}, fmt.Errorf("abi: invalid type in call to make function type byte array") @@ -102,20 +105,20 @@ func readFunctionType(t Type, word []byte) (funcTy [24]byte, err error) { return } -// through reflection, creates a fixed array to be read from -func readFixedBytes(t Type, word []byte) (interface{}, error) { +// ReadFixedBytes uses reflection to create a fixed array to be read from. +func ReadFixedBytes(t Type, word []byte) (interface{}, error) { if t.T != FixedBytesTy { return nil, fmt.Errorf("abi: invalid type in call to make fixed byte array") } // convert - array := reflect.New(t.Type).Elem() + array := reflect.New(t.GetType()).Elem() reflect.Copy(array, reflect.ValueOf(word[0:t.Size])) return array.Interface(), nil } -// iteratively unpack elements +// forEachUnpack iteratively unpack elements. func forEachUnpack(t Type, output []byte, start, size int) (interface{}, error) { if size < 0 { return nil, fmt.Errorf("cannot marshal input to array, size is negative (%d)", size) @@ -129,10 +132,10 @@ func forEachUnpack(t Type, output []byte, start, size int) (interface{}, error) if t.T == SliceTy { // declare our slice - refSlice = reflect.MakeSlice(t.Type, size, size) + refSlice = reflect.MakeSlice(t.GetType(), size, size) } else if t.T == ArrayTy { // declare our array - refSlice = reflect.New(t.Type).Elem() + refSlice = reflect.New(t.GetType()).Elem() } else { return nil, fmt.Errorf("abi: invalid type in array/slice unpacking stage") } @@ -156,7 +159,7 @@ func forEachUnpack(t Type, output []byte, start, size int) (interface{}, error) } func forTupleUnpack(t Type, output []byte) (interface{}, error) { - retval := reflect.New(t.Type).Elem() + retval := reflect.New(t.GetType()).Elem() virtualArgs := 0 for index, elem := range t.TupleElems { marshalledValue, err := toGoType((index+virtualArgs)*32, *elem, output) @@ -216,21 +219,23 @@ func toGoType(index int, t Type, output []byte) (interface{}, error) { return nil, err } return forTupleUnpack(t, output[begin:]) - } else { - return forTupleUnpack(t, output[index:]) } + return forTupleUnpack(t, output[index:]) case SliceTy: return forEachUnpack(t, output[begin:], 0, length) case ArrayTy: if isDynamicType(*t.Elem) { - offset := int64(binary.BigEndian.Uint64(returnOutput[len(returnOutput)-8:])) + offset := binary.BigEndian.Uint64(returnOutput[len(returnOutput)-8:]) + if offset > uint64(len(output)) { + return nil, fmt.Errorf("abi: toGoType offset greater than output length: offset: %d, len(output): %d", offset, len(output)) + } return forEachUnpack(t, output[offset:], 0, t.Size) } return forEachUnpack(t, output[index:], 0, t.Size) case StringTy: // variable arrays are written at the end of the return bytes return string(output[begin : begin+length]), nil case IntTy, UintTy: - return readInteger(t.T, t.Kind, returnOutput), nil + return ReadInteger(t, returnOutput), nil case BoolTy: return readBool(returnOutput) case AddressTy: @@ -240,7 +245,7 @@ func toGoType(index int, t Type, output []byte) (interface{}, error) { case BytesTy: return output[begin : begin+length], nil case FixedBytesTy: - return readFixedBytes(t, returnOutput) + return ReadFixedBytes(t, returnOutput) case FunctionTy: return readFunctionType(t, returnOutput) 
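 	// (Editor's note on the IntTy/UintTy case above: ReadInteger applies the
 	// two's complement rule documented in unpack.go. A worked example: a
 	// 32-byte word of all 0xff bytes has bit 255 set, so it decodes to
 	// -(MaxUint256 - (2**256-1) + 1) = -1 rather than to 2**256-1.)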
default: @@ -248,7 +253,7 @@ func toGoType(index int, t Type, output []byte) (interface{}, error) { } } -// interprets a 32 byte slice as an offset and then determines which indice to look to decode the type. +// lengthPrefixPointsTo interprets a 32 byte slice as an offset and then determines which indices to look to decode the type. func lengthPrefixPointsTo(index int, output []byte) (start int, length int, err error) { bigOffsetEnd := big.NewInt(0).SetBytes(output[index : index+32]) bigOffsetEnd.Add(bigOffsetEnd, common.Big32) @@ -269,7 +274,7 @@ func lengthPrefixPointsTo(index int, output []byte) (start int, length int, err totalSize.Add(totalSize, bigOffsetEnd) totalSize.Add(totalSize, lengthBig) if totalSize.BitLen() > 63 { - return 0, 0, fmt.Errorf("abi length larger than int64: %v", totalSize) + return 0, 0, fmt.Errorf("abi: length larger than int64: %v", totalSize) } if totalSize.Cmp(outputLength) > 0 { diff --git a/accounts/abi/unpack_test.go b/accounts/abi/unpack_test.go index 2906bec20a..7fda0ccbf1 100644 --- a/accounts/abi/unpack_test.go +++ b/accounts/abi/unpack_test.go @@ -27,9 +27,36 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/tomochain/tomochain/common" ) +// TestUnpack tests the general pack/unpack tests in packing_test.go +func TestUnpack(t *testing.T) { + for i, test := range packUnpackTests { + t.Run(strconv.Itoa(i)+" "+test.def, func(t *testing.T) { + //Unpack + def := fmt.Sprintf(`[{ "name" : "method", "type": "function", "outputs": %s}]`, test.def) + abi, err := JSON(strings.NewReader(def)) + if err != nil { + t.Fatalf("invalid ABI definition %s: %v", def, err) + } + encb, err := hex.DecodeString(test.packed) + if err != nil { + t.Fatalf("invalid hex %s: %v", test.packed, err) + } + out, err := abi.Unpack("method", encb) + if err != nil { + t.Errorf("test %d (%v) failed: %v", i, test.def, err) + return + } + if !reflect.DeepEqual(test.unpacked, ConvertType(out[0], test.unpacked)) { + t.Errorf("test %d (%v) failed: expected %v, got %v", i, test.def, test.unpacked, out[0]) + } + }) + } +} + type unpackTest struct { def string // ABI definition JSON enc string // evm return data @@ -51,16 +78,7 @@ func (test unpackTest) checkError(err error) error { } var unpackTests = []unpackTest{ - { - def: `[{ "type": "bool" }]`, - enc: "0000000000000000000000000000000000000000000000000000000000000001", - want: true, - }, - { - def: `[{ "type": "bool" }]`, - enc: "0000000000000000000000000000000000000000000000000000000000000000", - want: false, - }, + // Bools { def: `[{ "type": "bool" }]`, enc: "0000000000000000000000000000000000000000000000000001000000000001", @@ -73,11 +91,7 @@ var unpackTests = []unpackTest{ want: false, err: "abi: improperly encoded boolean value", }, - { - def: `[{"type": "uint32"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000001", - want: uint32(1), - }, + // Integers { def: `[{"type": "uint32"}]`, enc: "0000000000000000000000000000000000000000000000000000000000000001", @@ -90,16 +104,6 @@ var unpackTests = []unpackTest{ want: uint16(0), err: "abi: cannot unmarshal *big.Int in to uint16", }, - { - def: `[{"type": "uint17"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000001", - want: big.NewInt(1), - }, - { - def: `[{"type": "int32"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000001", - want: int32(1), - }, { def: `[{"type": "int32"}]`, enc: "0000000000000000000000000000000000000000000000000000000000000001", @@ -112,36 +116,10 @@ var unpackTests = 
[]unpackTest{ want: int16(0), err: "abi: cannot unmarshal *big.Int in to int16", }, - { - def: `[{"type": "int17"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000001", - want: big.NewInt(1), - }, - { - def: `[{"type": "int256"}]`, - enc: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", - want: big.NewInt(-1), - }, - { - def: `[{"type": "address"}]`, - enc: "0000000000000000000000000100000000000000000000000000000000000000", - want: common.Address{1}, - }, - { - def: `[{"type": "bytes32"}]`, - enc: "0100000000000000000000000000000000000000000000000000000000000000", - want: [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - }, { def: `[{"type": "bytes"}]`, enc: "000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000200100000000000000000000000000000000000000000000000000000000000000", - want: common.Hex2Bytes("0100000000000000000000000000000000000000000000000000000000000000"), - }, - { - def: `[{"type": "bytes"}]`, - enc: "000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000200100000000000000000000000000000000000000000000000000000000000000", - want: [32]byte{}, - err: "abi: cannot unmarshal []uint8 in to [32]uint8", + want: [32]byte{1}, }, { def: `[{"type": "bytes32"}]`, @@ -149,204 +127,13 @@ var unpackTests = []unpackTest{ want: []byte(nil), err: "abi: cannot unmarshal [32]uint8 in to []uint8", }, - { - def: `[{"type": "bytes32"}]`, - enc: "0100000000000000000000000000000000000000000000000000000000000000", - want: [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - }, - { - def: `[{"type": "function"}]`, - enc: "0100000000000000000000000000000000000000000000000000000000000000", - want: [24]byte{1}, - }, - // slices - { - def: `[{"type": "uint8[]"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: []uint8{1, 2}, - }, - { - def: `[{"type": "uint8[2]"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: [2]uint8{1, 2}, - }, - // multi dimensional, if these pass, all types that don't require length prefix should pass - { - def: `[{"type": "uint8[][]"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: [][]uint8{{1, 2}, {1, 2}}, - }, - { - def: `[{"type": "uint8[2][2]"}]`, - enc: 
"0000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: [2][2]uint8{{1, 2}, {1, 2}}, - }, - { - def: `[{"type": "uint8[][2]"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000001", - want: [2][]uint8{{1}, {1}}, - }, - { - def: `[{"type": "uint8[2][]"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: [][2]uint8{{1, 2}}, - }, - { - def: `[{"type": "uint8[2][]"}]`, - enc: "000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: [][2]uint8{{1, 2}, {1, 2}}, - }, - { - def: `[{"type": "uint16[]"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: []uint16{1, 2}, - }, - { - def: `[{"type": "uint16[2]"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: [2]uint16{1, 2}, - }, - { - def: `[{"type": "uint32[]"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: []uint32{1, 2}, - }, - { - def: `[{"type": "uint32[2]"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: [2]uint32{1, 2}, - }, - { - def: `[{"type": "uint32[2][3][4]"}]`, - enc: 
"000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000050000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000700000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000d000000000000000000000000000000000000000000000000000000000000000e000000000000000000000000000000000000000000000000000000000000000f000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000110000000000000000000000000000000000000000000000000000000000000012000000000000000000000000000000000000000000000000000000000000001300000000000000000000000000000000000000000000000000000000000000140000000000000000000000000000000000000000000000000000000000000015000000000000000000000000000000000000000000000000000000000000001600000000000000000000000000000000000000000000000000000000000000170000000000000000000000000000000000000000000000000000000000000018", - want: [4][3][2]uint32{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}, {{13, 14}, {15, 16}, {17, 18}}, {{19, 20}, {21, 22}, {23, 24}}}, - }, - { - def: `[{"type": "uint64[]"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: []uint64{1, 2}, - }, - { - def: `[{"type": "uint64[2]"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: [2]uint64{1, 2}, - }, - { - def: `[{"type": "uint256[]"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: []*big.Int{big.NewInt(1), big.NewInt(2)}, - }, - { - def: `[{"type": "uint256[3]"}]`, - enc: "000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000003", - want: [3]*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3)}, - }, - { - def: `[{"type": "string[4]"}]`, - enc: 
"0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000140000000000000000000000000000000000000000000000000000000000000000548656c6c6f0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005576f726c64000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000b476f2d657468657265756d0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008457468657265756d000000000000000000000000000000000000000000000000", - want: [4]string{"Hello", "World", "Go-ethereum", "Ethereum"}, - }, - { - def: `[{"type": "string[]"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000008457468657265756d000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000b676f2d657468657265756d000000000000000000000000000000000000000000", - want: []string{"Ethereum", "go-ethereum"}, - }, - { - def: `[{"type": "int8[]"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: []int8{1, 2}, - }, - { - def: `[{"type": "int8[2]"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: [2]int8{1, 2}, - }, - { - def: `[{"type": "int16[]"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: []int16{1, 2}, - }, - { - def: `[{"type": "int16[2]"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: [2]int16{1, 2}, - }, - { - def: `[{"type": "int32[]"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: []int32{1, 2}, - }, - { - def: `[{"type": "int32[2]"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: [2]int32{1, 2}, - }, - { - def: `[{"type": "int64[]"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: []int64{1, 2}, - }, - { - def: `[{"type": "int64[2]"}]`, - enc: 
"00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: [2]int64{1, 2}, - }, - { - def: `[{"type": "int256[]"}]`, - enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: []*big.Int{big.NewInt(1), big.NewInt(2)}, - }, - { - def: `[{"type": "int256[3]"}]`, - enc: "000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000003", - want: [3]*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3)}, - }, - // struct outputs - { - def: `[{"name":"int1","type":"int256"},{"name":"int2","type":"int256"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: struct { - Int1 *big.Int - Int2 *big.Int - }{big.NewInt(1), big.NewInt(2)}, - }, - { - def: `[{"name":"int_one","type":"int256"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: struct { - IntOne *big.Int - }{big.NewInt(1)}, - }, - { - def: `[{"name":"int__one","type":"int256"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: struct { - IntOne *big.Int - }{big.NewInt(1)}, - }, - { - def: `[{"name":"int_one_","type":"int256"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: struct { - IntOne *big.Int - }{big.NewInt(1)}, - }, - { - def: `[{"name":"int_one","type":"int256"}, {"name":"intone","type":"int256"}]`, - enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", - want: struct { - IntOne *big.Int - Intone *big.Int - }{big.NewInt(1), big.NewInt(2)}, - }, { def: `[{"name":"___","type":"int256"}]`, enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", want: struct { IntOne *big.Int Intone *big.Int - }{}, - err: "abi: purely underscored output cannot unpack to struct", + }{IntOne: big.NewInt(1)}, }, { def: `[{"name":"int_one","type":"int256"},{"name":"IntOne","type":"int256"}]`, @@ -393,22 +180,47 @@ var unpackTests = []unpackTest{ }{}, err: "abi: purely underscored output cannot unpack to struct", }, + // Make sure only the first argument is consumed + { + def: `[{"name":"int_one","type":"int256"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", + want: struct { + IntOne *big.Int + }{big.NewInt(1)}, + }, + { + def: `[{"name":"int__one","type":"int256"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", + want: struct { + IntOne *big.Int + }{big.NewInt(1)}, + }, + { + def: `[{"name":"int_one_","type":"int256"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", + want: struct { + IntOne *big.Int + 
}{big.NewInt(1)}, + }, } -func TestUnpack(t *testing.T) { +// TestLocalUnpackTests runs test specially designed only for unpacking. +// All test cases that can be used to test packing and unpacking should move to packing_test.go +func TestLocalUnpackTests(t *testing.T) { for i, test := range unpackTests { t.Run(strconv.Itoa(i), func(t *testing.T) { - def := fmt.Sprintf(`[{ "name" : "method", "outputs": %s}]`, test.def) + //Unpack + def := fmt.Sprintf(`[{ "name" : "method", "type": "function", "outputs": %s}]`, test.def) abi, err := JSON(strings.NewReader(def)) if err != nil { t.Fatalf("invalid ABI definition %s: %v", def, err) } encb, err := hex.DecodeString(test.enc) if err != nil { - t.Fatalf("invalid hex: %s" + test.enc) + t.Fatalf("invalid hex %s: %v", test.enc, err) } outptr := reflect.New(reflect.TypeOf(test.want)) - err = abi.Unpack(outptr.Interface(), "method", encb) + err = abi.UnpackIntoInterface(outptr.Interface(), "method", encb) if err := test.checkError(err); err != nil { t.Errorf("test %d (%v) failed: %v", i, test.def, err) return @@ -421,7 +233,7 @@ func TestUnpack(t *testing.T) { } } -func TestUnpackSetDynamicArrayOutput(t *testing.T) { +func TestUnpackIntoInterfaceSetDynamicArrayOutput(t *testing.T) { abi, err := JSON(strings.NewReader(`[{"constant":true,"inputs":[],"name":"testDynamicFixedBytes15","outputs":[{"name":"","type":"bytes15[]"}],"payable":false,"stateMutability":"view","type":"function"},{"constant":true,"inputs":[],"name":"testDynamicFixedBytes32","outputs":[{"name":"","type":"bytes32[]"}],"payable":false,"stateMutability":"view","type":"function"}]`)) if err != nil { t.Fatal(err) @@ -436,7 +248,7 @@ func TestUnpackSetDynamicArrayOutput(t *testing.T) { ) // test 32 - err = abi.Unpack(&out32, "testDynamicFixedBytes32", marshalledReturn32) + err = abi.UnpackIntoInterface(&out32, "testDynamicFixedBytes32", marshalledReturn32) if err != nil { t.Fatal(err) } @@ -453,7 +265,7 @@ func TestUnpackSetDynamicArrayOutput(t *testing.T) { } // test 15 - err = abi.Unpack(&out15, "testDynamicFixedBytes32", marshalledReturn15) + err = abi.UnpackIntoInterface(&out15, "testDynamicFixedBytes32", marshalledReturn15) if err != nil { t.Fatal(err) } @@ -477,7 +289,7 @@ type methodMultiOutput struct { func methodMultiReturn(require *require.Assertions) (ABI, []byte, methodMultiOutput) { const definition = `[ - { "name" : "multi", "constant" : false, "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]` + { "name" : "multi", "type": "function", "outputs": [ { "name": "Int", "type": "uint256" }, { "name": "String", "type": "string" } ] }]` var expected = methodMultiOutput{big.NewInt(1), "hello"} abi, err := JSON(strings.NewReader(definition)) @@ -497,6 +309,11 @@ func TestMethodMultiReturn(t *testing.T) { Int *big.Int } + newInterfaceSlice := func(len int) interface{} { + slice := make([]interface{}, len) + return &slice + } + abi, data, expected := methodMultiReturn(require.New(t)) bigint := new(big.Int) var testCases = []struct { @@ -524,6 +341,16 @@ func TestMethodMultiReturn(t *testing.T) { &[2]interface{}{&expected.Int, &expected.String}, "", "Can unpack into an array", + }, { + &[2]interface{}{}, + &[2]interface{}{expected.Int, expected.String}, + "", + "Can unpack into interface array", + }, { + newInterfaceSlice(2), + &[]interface{}{expected.Int, expected.String}, + "", + "Can unpack into interface slice", }, { &[]interface{}{new(int), new(int)}, &[]interface{}{&expected.Int, &expected.String}, @@ -532,14 +359,14 @@ func 
TestMethodMultiReturn(t *testing.T) { }, { &[]interface{}{new(int)}, &[]interface{}{}, - "abi: insufficient number of elements in the list/array for unpack, want 2, got 1", + "abi: insufficient number of arguments for unpack, want 2, got 1", "Can not unpack into a slice with wrong types", }} for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { require := require.New(t) - err := abi.Unpack(tc.dest, "multi", data) + err := abi.UnpackIntoInterface(tc.dest, "multi", data) if tc.error == "" { require.Nil(err, "Should be able to unpack method outputs.") require.Equal(tc.expected, tc.dest) @@ -551,7 +378,7 @@ func TestMethodMultiReturn(t *testing.T) { } func TestMultiReturnWithArray(t *testing.T) { - const definition = `[{"name" : "multi", "outputs": [{"type": "uint64[3]"}, {"type": "uint64"}]}]` + const definition = `[{"name" : "multi", "type": "function", "outputs": [{"type": "uint64[3]"}, {"type": "uint64"}]}]` abi, err := JSON(strings.NewReader(definition)) if err != nil { t.Fatal(err) @@ -562,7 +389,7 @@ func TestMultiReturnWithArray(t *testing.T) { ret1, ret1Exp := new([3]uint64), [3]uint64{9, 9, 9} ret2, ret2Exp := new(uint64), uint64(8) - if err := abi.Unpack(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { + if err := abi.UnpackIntoInterface(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { t.Fatal(err) } if !reflect.DeepEqual(*ret1, ret1Exp) { @@ -574,7 +401,7 @@ func TestMultiReturnWithArray(t *testing.T) { } func TestMultiReturnWithStringArray(t *testing.T) { - const definition = `[{"name" : "multi", "outputs": [{"name": "","type": "uint256[3]"},{"name": "","type": "address"},{"name": "","type": "string[2]"},{"name": "","type": "bool"}]}]` + const definition = `[{"name" : "multi", "type": "function", "outputs": [{"name": "","type": "uint256[3]"},{"name": "","type": "address"},{"name": "","type": "string[2]"},{"name": "","type": "bool"}]}]` abi, err := JSON(strings.NewReader(definition)) if err != nil { t.Fatal(err) @@ -586,7 +413,7 @@ func TestMultiReturnWithStringArray(t *testing.T) { ret2, ret2Exp := new(common.Address), common.HexToAddress("ab1257528b3782fb40d7ed5f72e624b744dffb2f") ret3, ret3Exp := new([2]string), [2]string{"Ethereum", "Hello, Ethereum!"} ret4, ret4Exp := new(bool), false - if err := abi.Unpack(&[]interface{}{ret1, ret2, ret3, ret4}, "multi", buff.Bytes()); err != nil { + if err := abi.UnpackIntoInterface(&[]interface{}{ret1, ret2, ret3, ret4}, "multi", buff.Bytes()); err != nil { t.Fatal(err) } if !reflect.DeepEqual(*ret1, ret1Exp) { @@ -604,7 +431,7 @@ func TestMultiReturnWithStringArray(t *testing.T) { } func TestMultiReturnWithStringSlice(t *testing.T) { - const definition = `[{"name" : "multi", "outputs": [{"name": "","type": "string[]"},{"name": "","type": "uint256[]"}]}]` + const definition = `[{"name" : "multi", "type": "function", "outputs": [{"name": "","type": "string[]"},{"name": "","type": "uint256[]"}]}]` abi, err := JSON(strings.NewReader(definition)) if err != nil { t.Fatal(err) @@ -624,7 +451,7 @@ func TestMultiReturnWithStringSlice(t *testing.T) { buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000065")) // output[1][1] value ret1, ret1Exp := new([]string), []string{"ethereum", "go-ethereum"} ret2, ret2Exp := new([]*big.Int), []*big.Int{big.NewInt(100), big.NewInt(101)} - if err := abi.Unpack(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { + if err := abi.UnpackIntoInterface(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { 
t.Fatal(err) } if !reflect.DeepEqual(*ret1, ret1Exp) { @@ -640,7 +467,7 @@ func TestMultiReturnWithDeeplyNestedArray(t *testing.T) { // values of nested static arrays count towards the size as well, and any element following // after such nested array argument should be read with the correct offset, // so that it does not read content from the previous array argument. - const definition = `[{"name" : "multi", "outputs": [{"type": "uint64[3][2][4]"}, {"type": "uint64"}]}]` + const definition = `[{"name" : "multi", "type": "function", "outputs": [{"type": "uint64[3][2][4]"}, {"type": "uint64"}]}]` abi, err := JSON(strings.NewReader(definition)) if err != nil { t.Fatal(err) @@ -664,7 +491,7 @@ func TestMultiReturnWithDeeplyNestedArray(t *testing.T) { {{0x411, 0x412, 0x413}, {0x421, 0x422, 0x423}}, } ret2, ret2Exp := new(uint64), uint64(0x9876) - if err := abi.Unpack(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { + if err := abi.UnpackIntoInterface(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { t.Fatal(err) } if !reflect.DeepEqual(*ret1, ret1Exp) { @@ -677,15 +504,15 @@ func TestMultiReturnWithDeeplyNestedArray(t *testing.T) { func TestUnmarshal(t *testing.T) { const definition = `[ - { "name" : "int", "constant" : false, "outputs": [ { "type": "uint256" } ] }, - { "name" : "bool", "constant" : false, "outputs": [ { "type": "bool" } ] }, - { "name" : "bytes", "constant" : false, "outputs": [ { "type": "bytes" } ] }, - { "name" : "fixed", "constant" : false, "outputs": [ { "type": "bytes32" } ] }, - { "name" : "multi", "constant" : false, "outputs": [ { "type": "bytes" }, { "type": "bytes" } ] }, - { "name" : "intArraySingle", "constant" : false, "outputs": [ { "type": "uint256[3]" } ] }, - { "name" : "addressSliceSingle", "constant" : false, "outputs": [ { "type": "address[]" } ] }, - { "name" : "addressSliceDouble", "constant" : false, "outputs": [ { "name": "a", "type": "address[]" }, { "name": "b", "type": "address[]" } ] }, - { "name" : "mixedBytes", "constant" : true, "outputs": [ { "name": "a", "type": "bytes" }, { "name": "b", "type": "bytes32" } ] }]` + { "name" : "int", "type": "function", "outputs": [ { "type": "uint256" } ] }, + { "name" : "bool", "type": "function", "outputs": [ { "type": "bool" } ] }, + { "name" : "bytes", "type": "function", "outputs": [ { "type": "bytes" } ] }, + { "name" : "fixed", "type": "function", "outputs": [ { "type": "bytes32" } ] }, + { "name" : "multi", "type": "function", "outputs": [ { "type": "bytes" }, { "type": "bytes" } ] }, + { "name" : "intArraySingle", "type": "function", "outputs": [ { "type": "uint256[3]" } ] }, + { "name" : "addressSliceSingle", "type": "function", "outputs": [ { "type": "address[]" } ] }, + { "name" : "addressSliceDouble", "type": "function", "outputs": [ { "name": "a", "type": "address[]" }, { "name": "b", "type": "address[]" } ] }, + { "name" : "mixedBytes", "type": "function", "stateMutability" : "view", "outputs": [ { "name": "a", "type": "bytes" }, { "name": "b", "type": "bytes32" } ] }]` abi, err := JSON(strings.NewReader(definition)) if err != nil { @@ -703,7 +530,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000a")) buff.Write(common.Hex2Bytes("0102000000000000000000000000000000000000000000000000000000000000")) - err = abi.Unpack(&mixedBytes, "mixedBytes", buff.Bytes()) + err = abi.UnpackIntoInterface(&mixedBytes, "mixedBytes", buff.Bytes()) if err != nil { t.Error(err) } else { @@ -718,7 +545,7 @@ 
func TestUnmarshal(t *testing.T) { // marshal int var Int *big.Int - err = abi.Unpack(&Int, "int", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + err = abi.UnpackIntoInterface(&Int, "int", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) if err != nil { t.Error(err) } @@ -729,7 +556,7 @@ func TestUnmarshal(t *testing.T) { // marshal bool var Bool bool - err = abi.Unpack(&Bool, "bool", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + err = abi.UnpackIntoInterface(&Bool, "bool", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) if err != nil { t.Error(err) } @@ -746,7 +573,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(bytesOut) var Bytes []byte - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes()) if err != nil { t.Error(err) } @@ -762,7 +589,7 @@ func TestUnmarshal(t *testing.T) { bytesOut = common.RightPadBytes([]byte("hello"), 64) buff.Write(bytesOut) - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes()) if err != nil { t.Error(err) } @@ -778,7 +605,7 @@ func TestUnmarshal(t *testing.T) { bytesOut = common.RightPadBytes([]byte("hello"), 64) buff.Write(bytesOut) - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes()) if err != nil { t.Error(err) } @@ -788,7 +615,7 @@ func TestUnmarshal(t *testing.T) { } // marshal dynamic bytes output empty - err = abi.Unpack(&Bytes, "bytes", nil) + err = abi.UnpackIntoInterface(&Bytes, "bytes", nil) if err == nil { t.Error("expected error") } @@ -799,7 +626,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005")) buff.Write(common.RightPadBytes([]byte("hello"), 32)) - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes()) if err != nil { t.Error(err) } @@ -813,7 +640,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(common.RightPadBytes([]byte("hello"), 32)) var hash common.Hash - err = abi.Unpack(&hash, "fixed", buff.Bytes()) + err = abi.UnpackIntoInterface(&hash, "fixed", buff.Bytes()) if err != nil { t.Error(err) } @@ -826,12 +653,12 @@ func TestUnmarshal(t *testing.T) { // marshal error buff.Reset() buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes()) if err == nil { t.Error("expected error") } - err = abi.Unpack(&Bytes, "multi", make([]byte, 64)) + err = abi.UnpackIntoInterface(&Bytes, "multi", make([]byte, 64)) if err == nil { t.Error("expected error") } @@ -842,7 +669,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000003")) // marshal int array var intArray [3]*big.Int - err = abi.Unpack(&intArray, "intArraySingle", buff.Bytes()) + err = abi.UnpackIntoInterface(&intArray, "intArraySingle", buff.Bytes()) if err != nil { t.Error(err) } @@ -863,7 +690,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000")) var outAddr []common.Address - err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes()) + err = abi.UnpackIntoInterface(&outAddr, "addressSliceSingle", buff.Bytes()) if 
err != nil { t.Fatal("didn't expect error:", err) } @@ -890,7 +717,7 @@ func TestUnmarshal(t *testing.T) { A []common.Address B []common.Address } - err = abi.Unpack(&outAddrStruct, "addressSliceDouble", buff.Bytes()) + err = abi.UnpackIntoInterface(&outAddrStruct, "addressSliceDouble", buff.Bytes()) if err != nil { t.Fatal("didn't expect error:", err) } @@ -918,14 +745,14 @@ func TestUnmarshal(t *testing.T) { buff.Reset() buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000100")) - err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes()) + err = abi.UnpackIntoInterface(&outAddr, "addressSliceSingle", buff.Bytes()) if err == nil { t.Fatal("expected error:", err) } } func TestUnpackTuple(t *testing.T) { - const simpleTuple = `[{"name":"tuple","constant":false,"outputs":[{"type":"tuple","name":"ret","components":[{"type":"int256","name":"a"},{"type":"int256","name":"b"}]}]}]` + const simpleTuple = `[{"name":"tuple","type":"function","outputs":[{"type":"tuple","name":"ret","components":[{"type":"int256","name":"a"},{"type":"int256","name":"b"}]}]}]` abi, err := JSON(strings.NewReader(simpleTuple)) if err != nil { t.Fatal(err) @@ -935,30 +762,26 @@ func TestUnpackTuple(t *testing.T) { buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) // ret[a] = 1 buff.Write(common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")) // ret[b] = -1 + // If the result is single tuple, use struct as return value container directly. v := struct { - Ret struct { - A *big.Int - B *big.Int - } - }{Ret: struct { A *big.Int B *big.Int - }{new(big.Int), new(big.Int)}} + }{new(big.Int), new(big.Int)} - err = abi.Unpack(&v, "tuple", buff.Bytes()) + err = abi.UnpackIntoInterface(&v, "tuple", buff.Bytes()) if err != nil { t.Error(err) } else { - if v.Ret.A.Cmp(big.NewInt(1)) != 0 { - t.Errorf("unexpected value unpacked: want %x, got %x", 1, v.Ret.A) + if v.A.Cmp(big.NewInt(1)) != 0 { + t.Errorf("unexpected value unpacked: want %x, got %x", 1, v.A) } - if v.Ret.B.Cmp(big.NewInt(-1)) != 0 { - t.Errorf("unexpected value unpacked: want %x, got %x", v.Ret.B, -1) + if v.B.Cmp(big.NewInt(-1)) != 0 { + t.Errorf("unexpected value unpacked: want %x, got %x", -1, v.B) } } // Test nested tuple - const nestedTuple = `[{"name":"tuple","constant":false,"outputs":[ + const nestedTuple = `[{"name":"tuple","type":"function","outputs":[ {"type":"tuple","name":"s","components":[{"type":"uint256","name":"a"},{"type":"uint256[]","name":"b"},{"type":"tuple[]","name":"c","components":[{"name":"x", "type":"uint256"},{"name":"y","type":"uint256"}]}]}, {"type":"tuple","name":"t","components":[{"name":"x", "type":"uint256"},{"name":"y","type":"uint256"}]}, {"type":"uint256","name":"a"} @@ -1017,7 +840,7 @@ func TestUnpackTuple(t *testing.T) { A: big.NewInt(1), } - err = abi.Unpack(&ret, "tuple", buff.Bytes()) + err = abi.UnpackIntoInterface(&ret, "tuple", buff.Bytes()) if err != nil { t.Error(err) } @@ -1080,7 +903,7 @@ func TestOOMMaliciousInput(t *testing.T) { }, } for i, test := range oomTests { - def := fmt.Sprintf(`[{ "name" : "method", "outputs": %s}]`, test.def) + def := fmt.Sprintf(`[{ "name" : "method", "type": "function", "outputs": %s}]`, test.def) abi, err := JSON(strings.NewReader(def)) if err != nil { t.Fatalf("invalid ABI definition %s: %v", def, err) diff --git a/accounts/abi/utils.go b/accounts/abi/utils.go new file mode 100644 index 0000000000..f88d2ee2d4 --- /dev/null +++ b/accounts/abi/utils.go @@ -0,0 +1,39 @@ +// 
This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package abi
+
+import "fmt"
+
+// ResolveNameConflict returns the next available name for a given identifier.
+// This helper can be used for many purposes:
+//
+//   - In Solidity, function overloading is supported: this function can fix
+//     the name conflicts of overloaded functions.
+//   - In Go binding generation, parameter names (in function, event, error,
+//     and struct definitions) are converted to camel case, which may
+//     eventually lead to name conflicts.
+//
+// Name conflicts are resolved by appending a number suffix. E.g. if the ABI
+// contains methods "send" and "send0", ResolveNameConflict returns "send1"
+// for input "send".
+func ResolveNameConflict(rawName string, used func(string) bool) string {
+	name := rawName
+	ok := used(name)
+	for idx := 0; ok; idx++ {
+		name = fmt.Sprintf("%s%d", rawName, idx)
+		ok = used(name)
+	}
+	return name
+}
diff --git a/common/math/big.go b/common/math/big.go
index 7872786503..27068b2282 100644
--- a/common/math/big.go
+++ b/common/math/big.go
@@ -176,13 +176,19 @@ func U256(x *big.Int) *big.Int {
 	return x.And(x, tt256m1)
 }

+// U256Bytes converts a big.Int into a 256-bit EVM number.
+// This operation is destructive.
+func U256Bytes(n *big.Int) []byte {
+	return PaddedBigBytes(U256(n), 32)
+}
+
 // S256 interprets x as a two's complement number.
 // x must not exceed 256 bits (the result is undefined if it does) and is not modified.
 //
-// S256(0) = 0
-// S256(1) = 1
-// S256(2**255) = -2**255
-// S256(2**256-1) = -1
+//	S256(0) = 0
+//	S256(1) = 1
+//	S256(2**255) = -2**255
+//	S256(2**256-1) = -1
 func S256(x *big.Int) *big.Int {
 	if x.Cmp(tt255) < 0 {
 		return x

From f545fafedbee0140a0e5c4d0d66681fd72cb17e4 Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Mon, 9 Oct 2023 00:25:27 +0700
Subject: [PATCH 083/119] Fix unpack tuple unit test

---
 accounts/abi/unpack_test.go | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/accounts/abi/unpack_test.go b/accounts/abi/unpack_test.go
index 7fda0ccbf1..9af0a666fb 100644
--- a/accounts/abi/unpack_test.go
+++ b/accounts/abi/unpack_test.go
@@ -763,20 +763,24 @@ func TestUnpackTuple(t *testing.T) {
 	buff.Write(common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")) // ret[b] = -1

 	// If the result is a single tuple, use a struct as the return value container directly.
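 	// (Editor's note: the fix below switches the destination to a wrapper
 	// struct, the new r{Result v} container, because UnpackIntoInterface
 	// fills a field of a struct destination with the single tuple value
 	// rather than treating the destination itself as the tuple.)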
- v := struct { + type v struct { A *big.Int B *big.Int - }{new(big.Int), new(big.Int)} + } + type r struct { + Result v + } + var ret0 = new(r) + err = abi.UnpackIntoInterface(ret0, "tuple", buff.Bytes()) - err = abi.UnpackIntoInterface(&v, "tuple", buff.Bytes()) if err != nil { t.Error(err) } else { - if v.A.Cmp(big.NewInt(1)) != 0 { - t.Errorf("unexpected value unpacked: want %x, got %x", 1, v.A) + if ret0.Result.A.Cmp(big.NewInt(1)) != 0 { + t.Errorf("unexpected value unpacked: want %x, got %x", 1, ret0.Result.A) } - if v.B.Cmp(big.NewInt(-1)) != 0 { - t.Errorf("unexpected value unpacked: want %x, got %x", -1, v.B) + if ret0.Result.B.Cmp(big.NewInt(-1)) != 0 { + t.Errorf("unexpected value unpacked: want %x, got %x", -1, ret0.Result.B) } } From 9394238e4ca35576d3e688c19153ad07d986ea66 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 9 Oct 2023 00:45:40 +0700 Subject: [PATCH 084/119] Use UnpackIntoInterface instead of the general Unpack --- accounts/abi/bind/base.go | 20 ++++++++++++++++---- accounts/abi/bind/bind.go | 9 +++++---- core/token_validator.go | 4 ++-- tomox/token.go | 4 ++-- 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/accounts/abi/bind/base.go b/accounts/abi/bind/base.go index caf1640496..70b827bfc7 100644 --- a/accounts/abi/bind/base.go +++ b/accounts/abi/bind/base.go @@ -30,6 +30,11 @@ import ( "github.com/tomochain/tomochain/event" ) +var ( + errNoEventSignature = errors.New("no event signature") + errEventSignatureMismatch = errors.New("event signature mismatch") +) + // SignerFn is a signer function callback when a contract requires a method to // sign the transaction before submission. type SignerFn func(types.Signer, common.Address, *types.Transaction) (*types.Transaction, error) @@ -161,7 +166,7 @@ func (c *BoundContract) Call(opts *CallOpts, result interface{}, method string, if err != nil { return err } - return c.abi.Unpack(result, method, output) + return c.abi.UnpackIntoInterface(result, method, output) } // Transact invokes the (paid) contract method with params as input values. @@ -252,7 +257,7 @@ func (c *BoundContract) FilterLogs(opts *FilterOpts, name string, query ...[]int opts = new(FilterOpts) } // Append the event selector to the query parameters and construct the topic set - query = append([][]interface{}{{c.abi.Events[name].Id()}}, query...) + query = append([][]interface{}{{c.abi.Events[name].ID}}, query...) topics, err := makeTopics(query...) if err != nil { @@ -301,7 +306,7 @@ func (c *BoundContract) WatchLogs(opts *WatchOpts, name string, query ...[]inter opts = new(WatchOpts) } // Append the event selector to the query parameters and construct the topic set - query = append([][]interface{}{{c.abi.Events[name].Id()}}, query...) + query = append([][]interface{}{{c.abi.Events[name].ID}}, query...) topics, err := makeTopics(query...) if err != nil { @@ -326,8 +331,15 @@ func (c *BoundContract) WatchLogs(opts *WatchOpts, name string, query ...[]inter // UnpackLog unpacks a retrieved log into the provided output structure. func (c *BoundContract) UnpackLog(out interface{}, event string, log types.Log) error { + // Anonymous events are not supported. 
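+	// (A non-anonymous event always stores its signature hash in Topics[0],
+	// so a log with no topics, or with a different first topic, cannot have
+	// been emitted by this event.)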
+ if len(log.Topics) == 0 { + return errNoEventSignature + } + if log.Topics[0] != c.abi.Events[event].ID { + return errEventSignatureMismatch + } if len(log.Data) > 0 { - if err := c.abi.Unpack(out, event, log.Data); err != nil { + if err := c.abi.UnpackIntoInterface(out, event, log.Data); err != nil { return err } } diff --git a/accounts/abi/bind/bind.go b/accounts/abi/bind/bind.go index efb24e4d88..9b73e7ef2f 100644 --- a/accounts/abi/bind/bind.go +++ b/accounts/abi/bind/bind.go @@ -89,7 +89,7 @@ func Bind(types []string, abis []string, bytecodes []string, pkg string, lang La } } // Append the methods to the call or transact lists - if original.Const { + if original.IsConstant() { calls[original.Name] = &tmplMethod{Original: original, Normalized: normalized, Structured: structured(original.Outputs)} } else { transacts[original.Name] = &tmplMethod{Original: original, Normalized: normalized, Structured: structured(original.Outputs)} @@ -166,9 +166,10 @@ var bindType = map[Lang]func(kind abi.Type) string{ // Helper function for the binding generators. // It reads the unmatched characters after the inner type-match, -// (since the inner type is a prefix of the total type declaration), -// looks for valid arrays (possibly a dynamic one) wrapping the inner type, -// and returns the sizes of these arrays. +// +// (since the inner type is a prefix of the total type declaration), +// looks for valid arrays (possibly a dynamic one) wrapping the inner type, +// and returns the sizes of these arrays. // // Returned array sizes are in the same order as solidity signatures; inner array size first. // Array sizes may also be "", indicating a dynamic array. diff --git a/core/token_validator.go b/core/token_validator.go index 485ff05c59..d9204abbf7 100644 --- a/core/token_validator.go +++ b/core/token_validator.go @@ -78,14 +78,14 @@ func RunContract(chain consensus.ChainContext, statedb *state.StateDB, contractA return nil, err } var unpackResult interface{} - err = abi.Unpack(&unpackResult, method, result) + err = abi.UnpackIntoInterface(&unpackResult, method, result) if err != nil { return nil, err } return unpackResult, nil } -//FIXME: please use copyState for this function +// FIXME: please use copyState for this function // CallContractWithState executes a contract call at the given state. func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) { // Ensure message is initialized properly. 
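(Editor's note: the call-site migration in this patch is mechanical: every destination pointer previously handed to Unpack is now handed to UnpackIntoInterface with the same arguments. A minimal, self-contained sketch of the new pattern follows; the "decimals" method and its one-entry ABI are illustrative only, modeled on the outputs-only definitions used in the unpack tests above, and are not part of the patch.)

	package main

	import (
		"fmt"
		"math/big"
		"strings"

		"github.com/tomochain/tomochain/accounts/abi"
	)

	func main() {
		// A one-method ABI with a single uint256 output.
		def := `[{"name":"decimals","type":"function","outputs":[{"type":"uint256"}]}]`
		parsed, err := abi.JSON(strings.NewReader(def))
		if err != nil {
			panic(err)
		}
		// Raw return data: one 32-byte big-endian word holding the value 18.
		word := make([]byte, 32)
		word[31] = 18
		var decimals *big.Int
		// UnpackIntoInterface takes a pointer destination, as the old Unpack
		// did, and decodes the method's outputs into it.
		if err := parsed.UnpackIntoInterface(&decimals, "decimals", word); err != nil {
			panic(err)
		}
		fmt.Println(decimals) // prints 18
	}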
diff --git a/tomox/token.go b/tomox/token.go index 24e6e138f0..a90a3dce90 100644 --- a/tomox/token.go +++ b/tomox/token.go @@ -37,7 +37,7 @@ func RunContract(chain consensus.ChainContext, statedb *state.StateDB, contractA return nil, err } var unpackResult interface{} - err = abi.Unpack(&unpackResult, method, result) + err = abi.UnpackIntoInterface(&unpackResult, method, result) if err != nil { return nil, err } @@ -75,4 +75,4 @@ func (tomox *TomoX) GetTokenDecimal(chain consensus.ChainContext, statedb *state // FIXME: using in unit tests only func (tomox *TomoX) SetTokenDecimal(token common.Address, decimal *big.Int) { tomox.tokenDecimalCache.Add(token, decimal) -} \ No newline at end of file +} From d2ac293c592ddcdd7ca61f182a6bc714a0aa0bd6 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 9 Oct 2023 00:33:44 +0700 Subject: [PATCH 085/119] Parse ABI only once, create metadata struct --- accounts/abi/bind/base.go | 25 +++ accounts/abi/bind/template.go | 363 +++++++++++++++++++--------------- 2 files changed, 231 insertions(+), 157 deletions(-) diff --git a/accounts/abi/bind/base.go b/accounts/abi/bind/base.go index 70b827bfc7..35c236abba 100644 --- a/accounts/abi/bind/base.go +++ b/accounts/abi/bind/base.go @@ -21,6 +21,8 @@ import ( "errors" "fmt" "math/big" + "strings" + "sync" "github.com/tomochain/tomochain" "github.com/tomochain/tomochain/accounts/abi" @@ -77,6 +79,29 @@ type WatchOpts struct { Context context.Context // Network context to support cancellation and timeouts (nil = no timeout) } +// MetaData collects all metadata for a bound contract. +type MetaData struct { + mu sync.Mutex + Sigs map[string]string + Bin string + ABI string + ab *abi.ABI +} + +func (m *MetaData) GetAbi() (*abi.ABI, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.ab != nil { + return m.ab, nil + } + if parsed, err := abi.JSON(strings.NewReader(m.ABI)); err != nil { + return nil, err + } else { + m.ab = &parsed + } + return m.ab, nil +} + // BoundContract is the base wrapper object that reflects a contract on the // Ethereum network. It contains a collection of methods that are used by the // higher level contract bindings to operate. diff --git a/accounts/abi/bind/template.go b/accounts/abi/bind/template.go index f49b0efd1c..43985bfe6b 100644 --- a/accounts/abi/bind/template.go +++ b/accounts/abi/bind/template.go @@ -22,17 +22,24 @@ import "github.com/tomochain/tomochain/accounts/abi" type tmplData struct { Package string // Name of the package to place the generated file in Contracts map[string]*tmplContract // List of contracts to generate into this file + Libraries map[string]string // Map the bytecode's link pattern to the library name + Structs map[string]*tmplStruct // Contract struct type definitions } // tmplContract contains the data needed to generate an individual contract binding. 
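// (Editor's note on the MetaData type added to base.go above: generated
// bindings now parse the ABI lazily and exactly once. GetAbi caches the
// parsed abi.ABI under a mutex, so a binding would call, for example,
//
//	parsed, err := FooMetaData.GetAbi() // "Foo" is a placeholder contract name
//
// instead of re-running abi.JSON(strings.NewReader(...)) on every bind or
// deploy.)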
type tmplContract struct { Type string // Type name of the main contract binding InputABI string // JSON ABI used as the input to generate the binding from - InputBin string // Optional EVM bytecode used to denetare deploy code from + InputBin string // Optional EVM bytecode used to generate deploy code from + FuncSigs map[string]string // Optional map: string signature -> 4-byte signature Constructor abi.Method // Contract constructor for deploy parametrization Calls map[string]*tmplMethod // Contract calls that only read state data Transacts map[string]*tmplMethod // Contract calls that write state data + Fallback *tmplMethod // Additional special fallback function + Receive *tmplMethod // Additional special receive function Events map[string]*tmplEvent // Contract events accessors + Libraries map[string]string // Same as tmplData, but filtered to only keep what the contract needs + Library bool // Indicator whether the contract is a library } // tmplMethod is a wrapper around an abi.Method that contains a few preprocessed @@ -43,42 +50,120 @@ type tmplMethod struct { Structured bool // Whether the returns should be accumulated into a struct } -// tmplEvent is a wrapper around an a +// tmplEvent is a wrapper around an abi.Event that contains a few preprocessed +// and cached data fields. type tmplEvent struct { Original abi.Event // Original event as parsed by the abi package Normalized abi.Event // Normalized version of the parsed fields } +// tmplField is a wrapper around a struct field with binding language +// struct type definition and relative filed name. +type tmplField struct { + Type string // Field type representation depends on target binding language + Name string // Field name converted from the raw user-defined field name + SolKind abi.Type // Raw abi type information +} + +// tmplStruct is a wrapper around an abi.tuple and contains an auto-generated +// struct name. +type tmplStruct struct { + Name string // Auto-generated struct name(before solidity v0.5.11) or raw name. + Fields []*tmplField // Struct fields definition depends on the binding language. +} + // tmplSource is language to template mapping containing all the supported // programming languages the package can generate to. var tmplSource = map[Lang]string{ - LangGo: tmplSourceGo, - LangJava: tmplSourceJava, + LangGo: tmplSourceGo, } -// tmplSourceGo is the Go source template use to generate the contract binding -// based on. +// tmplSourceGo is the Go source template that the generated Go contract binding +// is based on. const tmplSourceGo = ` // Code generated - DO NOT EDIT. // This file is a generated binding and any manual changes will be lost. package {{.Package}} +import ( + "math/big" + "strings" + "errors" + + "github.com/tomochain/tomochain" + "github.com/tomochain/tomochain/accounts/abi" + "github.com/tomochain/tomochain/accounts/abi/bind" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/event" +) + +// Reference imports to suppress errors if they are not otherwise used. +var ( + _ = errors.New + _ = big.NewInt + _ = strings.NewReader + _ = tomochain.NotFound + _ = bind.Bind + _ = common.Big1 + _ = types.BloomLookup + _ = event.NewSubscription + _ = abi.ConvertType +) + +{{$structs := .Structs}} +{{range $structs}} + // {{.Name}} is an auto generated low-level Go binding around an user-defined struct. 
+ type {{.Name}} struct { + {{range $field := .Fields}} + {{$field.Name}} {{$field.Type}}{{end}} + } +{{end}} + {{range $contract := .Contracts}} + // {{.Type}}MetaData contains all meta data concerning the {{.Type}} contract. + var {{.Type}}MetaData = &bind.MetaData{ + ABI: "{{.InputABI}}", + {{if $contract.FuncSigs -}} + Sigs: map[string]string{ + {{range $strsig, $binsig := .FuncSigs}}"{{$binsig}}": "{{$strsig}}", + {{end}} + }, + {{end -}} + {{if .InputBin -}} + Bin: "0x{{.InputBin}}", + {{end}} + } // {{.Type}}ABI is the input ABI used to generate the binding from. - const {{.Type}}ABI = "{{.InputABI}}" + // Deprecated: Use {{.Type}}MetaData.ABI instead. + var {{.Type}}ABI = {{.Type}}MetaData.ABI + + {{if $contract.FuncSigs}} + // Deprecated: Use {{.Type}}MetaData.Sigs instead. + // {{.Type}}FuncSigs maps the 4-byte function signature to its string representation. + var {{.Type}}FuncSigs = {{.Type}}MetaData.Sigs + {{end}} {{if .InputBin}} // {{.Type}}Bin is the compiled bytecode used for deploying new contracts. - const {{.Type}}Bin = ` + "`" + `{{.InputBin}}` + "`" + ` + // Deprecated: Use {{.Type}}MetaData.Bin instead. + var {{.Type}}Bin = {{.Type}}MetaData.Bin - // Deploy{{.Type}} deploys a new Ethereum contract, binding an instance of {{.Type}} to it. - func Deploy{{.Type}}(auth *bind.TransactOpts, backend bind.ContractBackend {{range .Constructor.Inputs}}, {{.Name}} {{bindtype .Type}}{{end}}) (common.Address, *types.Transaction, *{{.Type}}, error) { - parsed, err := abi.JSON(strings.NewReader({{.Type}}ABI)) + // Deploy{{.Type}} deploys a new Tomochain contract, binding an instance of {{.Type}} to it. + func Deploy{{.Type}}(auth *bind.TransactOpts, backend bind.ContractBackend {{range .Constructor.Inputs}}, {{.Name}} {{bindtype .Type $structs}}{{end}}) (common.Address, *types.Transaction, *{{.Type}}, error) { + parsed, err := {{.Type}}MetaData.GetAbi() if err != nil { return common.Address{}, nil, nil, err } - address, tx, contract, err := bind.DeployContract(auth, parsed, common.FromHex({{.Type}}Bin), backend {{range .Constructor.Inputs}}, {{.Name}}{{end}}) + if parsed == nil { + return common.Address{}, nil, nil, errors.New("GetABI returned nil") + } + {{range $pattern, $name := .Libraries}} + {{decapitalise $name}}Addr, _, _, _ := Deploy{{capitalise $name}}(auth, backend) + {{$contract.Type}}Bin = strings.ReplaceAll({{$contract.Type}}Bin, "__${{$pattern}}$__", {{decapitalise $name}}Addr.String()[2:]) + {{end}} + address, tx, contract, err := bind.DeployContract(auth, *parsed, common.FromHex({{.Type}}Bin), backend {{range .Constructor.Inputs}}, {{.Name}}{{end}}) if err != nil { return common.Address{}, nil, nil, err } @@ -86,29 +171,29 @@ package {{.Package}} } {{end}} - // {{.Type}} is an auto generated Go binding around an Ethereum contract. + // {{.Type}} is an auto generated Go binding around an Tomochain contract. type {{.Type}} struct { {{.Type}}Caller // Read-only binding to the contract {{.Type}}Transactor // Write-only binding to the contract - {{.Type}}Filterer // Log filterer for contract events + {{.Type}}Filterer // Log filterer for contract events } - // {{.Type}}Caller is an auto generated read-only Go binding around an Ethereum contract. + // {{.Type}}Caller is an auto generated read-only Go binding around an Tomochain contract. type {{.Type}}Caller struct { contract *bind.BoundContract // Generic contract wrapper for the low level calls } - // {{.Type}}Transactor is an auto generated write-only Go binding around an Ethereum contract. 
+	// {{.Type}}Transactor is an auto generated write-only Go binding around a Tomochain contract.
	type {{.Type}}Transactor struct {
		contract *bind.BoundContract // Generic contract wrapper for the low level calls
	}

-	// {{.Type}}Filterer is an auto generated log filtering Go binding around an Ethereum contract events.
+	// {{.Type}}Filterer is an auto generated log filtering Go binding around a Tomochain contract's events.
	type {{.Type}}Filterer struct {
		contract *bind.BoundContract // Generic contract wrapper for the low level calls
	}

-	// {{.Type}}Session is an auto generated Go binding around an Ethereum contract,
+	// {{.Type}}Session is an auto generated Go binding around a Tomochain contract,
	// with pre-set call and transact options.
	type {{.Type}}Session struct {
		Contract *{{.Type}} // Generic contract binding to set the session for
@@ -116,31 +201,31 @@ package {{.Package}}
		TransactOpts bind.TransactOpts // Transaction auth options to use throughout this session
	}

-	// {{.Type}}CallerSession is an auto generated read-only Go binding around an Ethereum contract,
+	// {{.Type}}CallerSession is an auto generated read-only Go binding around a Tomochain contract,
	// with pre-set call options.
	type {{.Type}}CallerSession struct {
		Contract *{{.Type}}Caller // Generic contract caller binding to set the session for
		CallOpts bind.CallOpts // Call options to use throughout this session
	}

-	// {{.Type}}TransactorSession is an auto generated write-only Go binding around an Ethereum contract,
+	// {{.Type}}TransactorSession is an auto generated write-only Go binding around a Tomochain contract,
	// with pre-set transact options.
	type {{.Type}}TransactorSession struct {
		Contract *{{.Type}}Transactor // Generic contract transactor binding to set the session for
		TransactOpts bind.TransactOpts // Transaction auth options to use throughout this session
	}

-	// {{.Type}}Raw is an auto generated low-level Go binding around an Ethereum contract.
+	// {{.Type}}Raw is an auto generated low-level Go binding around a Tomochain contract.
	type {{.Type}}Raw struct {
		Contract *{{.Type}} // Generic contract binding to access the raw methods on
	}

-	// {{.Type}}CallerRaw is an auto generated low-level read-only Go binding around an Ethereum contract.
+	// {{.Type}}CallerRaw is an auto generated low-level read-only Go binding around a Tomochain contract.
	type {{.Type}}CallerRaw struct {
		Contract *{{.Type}}Caller // Generic read-only contract binding to access the raw methods on
	}

-	// {{.Type}}TransactorRaw is an auto generated low-level write-only Go binding around an Ethereum contract.
+	// {{.Type}}TransactorRaw is an auto generated low-level write-only Go binding around a Tomochain contract.
	type {{.Type}}TransactorRaw struct {
		Contract *{{.Type}}Transactor // Generic write-only contract binding to access the raw methods on
	}

@@ -183,18 +268,18 @@ package {{.Package}}
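As another hedged illustration (not part of the patch), a binding generated from this template for a hypothetical contract named Token would be driven through the session types above; NewToken is the constructor the full template emits for that contract, while BalanceOf and Transfer stand in for its ABI methods:

	token, err := NewToken(common.HexToAddress("0x0000000000000000000000000000000000000001"), backend)
	if err != nil {
		// handle binding error
	}
	session := &TokenSession{
		Contract:     token,
		CallOpts:     bind.CallOpts{Pending: true},
		TransactOpts: *auth,
	}
	balance, _ := session.BalanceOf(owner) // read-only call routed through the Caller
	tx, _ := session.Transfer(to, amount)  // state-changing call routed through the Transactor

	// bind{{.Type}} binds a generic wrapper to an already deployed contract.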
func bind{{.Type}}(address common.Address, caller bind.ContractCaller, transactor bind.ContractTransactor, filterer bind.ContractFilterer) (*bind.BoundContract, error) { - parsed, err := abi.JSON(strings.NewReader({{.Type}}ABI)) + parsed, err := {{.Type}}MetaData.GetAbi() if err != nil { return nil, err } - return bind.NewBoundContract(address, parsed, caller, transactor, filterer), nil + return bind.NewBoundContract(address, *parsed, caller, transactor, filterer), nil } // Call invokes the (constant) contract method with params as input values and // sets the output to result. The result type might be a single field for simple // returns, a slice of interfaces for anonymous returns and a struct for named // returns. - func (_{{$contract.Type}} *{{$contract.Type}}Raw) Call(opts *bind.CallOpts, result interface{}, method string, params ...interface{}) error { + func (_{{$contract.Type}} *{{$contract.Type}}Raw) Call(opts *bind.CallOpts, result *[]interface{}, method string, params ...interface{}) error { return _{{$contract.Type}}.Contract.{{$contract.Type}}Caller.contract.Call(opts, result, method, params...) } @@ -213,7 +298,7 @@ package {{.Package}} // sets the output to result. The result type might be a single field for simple // returns, a slice of interfaces for anonymous returns and a struct for named // returns. - func (_{{$contract.Type}} *{{$contract.Type}}CallerRaw) Call(opts *bind.CallOpts, result interface{}, method string, params ...interface{}) error { + func (_{{$contract.Type}} *{{$contract.Type}}CallerRaw) Call(opts *bind.CallOpts, result *[]interface{}, method string, params ...interface{}) error { return _{{$contract.Type}}.Contract.contract.Call(opts, result, method, params...) } @@ -229,63 +314,116 @@ package {{.Package}} } {{range .Calls}} - // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.Id}}. + // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.ID}}. 
// // Solidity: {{.Original.String}} - func (_{{$contract.Type}} *{{$contract.Type}}Caller) {{.Normalized.Name}}(opts *bind.CallOpts {{range .Normalized.Inputs}}, {{.Name}} {{bindtype .Type}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}};{{end}} },{{else}}{{range .Normalized.Outputs}}{{bindtype .Type}},{{end}}{{end}} error) { - {{if .Structured}}ret := new(struct{ - {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}} - {{end}} - }){{else}}var ( - {{range $i, $_ := .Normalized.Outputs}}ret{{$i}} = new({{bindtype .Type}}) - {{end}} - ){{end}} - out := {{if .Structured}}ret{{else}}{{if eq (len .Normalized.Outputs) 1}}ret0{{else}}&[]interface{}{ - {{range $i, $_ := .Normalized.Outputs}}ret{{$i}}, - {{end}} - }{{end}}{{end}} - err := _{{$contract.Type}}.contract.Call(opts, out, "{{.Original.Name}}" {{range .Normalized.Inputs}}, {{.Name}}{{end}}) - return {{if .Structured}}*ret,{{else}}{{range $i, $_ := .Normalized.Outputs}}*ret{{$i}},{{end}}{{end}} err + func (_{{$contract.Type}} *{{$contract.Type}}Caller) {{.Normalized.Name}}(opts *bind.CallOpts {{range .Normalized.Inputs}}, {{.Name}} {{bindtype .Type $structs}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type $structs}};{{end}} },{{else}}{{range .Normalized.Outputs}}{{bindtype .Type $structs}},{{end}}{{end}} error) { + var out []interface{} + err := _{{$contract.Type}}.contract.Call(opts, &out, "{{.Original.Name}}" {{range .Normalized.Inputs}}, {{.Name}}{{end}}) + {{if .Structured}} + outstruct := new(struct{ {{range .Normalized.Outputs}} {{.Name}} {{bindtype .Type $structs}}; {{end}} }) + if err != nil { + return *outstruct, err + } + {{range $i, $t := .Normalized.Outputs}} + outstruct.{{.Name}} = *abi.ConvertType(out[{{$i}}], new({{bindtype .Type $structs}})).(*{{bindtype .Type $structs}}){{end}} + + return *outstruct, err + {{else}} + if err != nil { + return {{range $i, $_ := .Normalized.Outputs}}*new({{bindtype .Type $structs}}), {{end}} err + } + {{range $i, $t := .Normalized.Outputs}} + out{{$i}} := *abi.ConvertType(out[{{$i}}], new({{bindtype .Type $structs}})).(*{{bindtype .Type $structs}}){{end}} + + return {{range $i, $t := .Normalized.Outputs}}out{{$i}}, {{end}} err + {{end}} } - // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.Id}}. + // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.ID}}. // // Solidity: {{.Original.String}} - func (_{{$contract.Type}} *{{$contract.Type}}Session) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}};{{end}} }, {{else}} {{range .Normalized.Outputs}}{{bindtype .Type}},{{end}} {{end}} error) { + func (_{{$contract.Type}} *{{$contract.Type}}Session) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type $structs}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type $structs}};{{end}} }, {{else}} {{range .Normalized.Outputs}}{{bindtype .Type $structs}},{{end}} {{end}} error) { return _{{$contract.Type}}.Contract.{{.Normalized.Name}}(&_{{$contract.Type}}.CallOpts {{range .Normalized.Inputs}}, {{.Name}}{{end}}) } - // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.Id}}. 
+ // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.ID}}. // // Solidity: {{.Original.String}} - func (_{{$contract.Type}} *{{$contract.Type}}CallerSession) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type}};{{end}} }, {{else}} {{range .Normalized.Outputs}}{{bindtype .Type}},{{end}} {{end}} error) { + func (_{{$contract.Type}} *{{$contract.Type}}CallerSession) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type $structs}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type $structs}};{{end}} }, {{else}} {{range .Normalized.Outputs}}{{bindtype .Type $structs}},{{end}} {{end}} error) { return _{{$contract.Type}}.Contract.{{.Normalized.Name}}(&_{{$contract.Type}}.CallOpts {{range .Normalized.Inputs}}, {{.Name}}{{end}}) } {{end}} {{range .Transacts}} - // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.Id}}. + // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.ID}}. // // Solidity: {{.Original.String}} - func (_{{$contract.Type}} *{{$contract.Type}}Transactor) {{.Normalized.Name}}(opts *bind.TransactOpts {{range .Normalized.Inputs}}, {{.Name}} {{bindtype .Type}} {{end}}) (*types.Transaction, error) { + func (_{{$contract.Type}} *{{$contract.Type}}Transactor) {{.Normalized.Name}}(opts *bind.TransactOpts {{range .Normalized.Inputs}}, {{.Name}} {{bindtype .Type $structs}} {{end}}) (*types.Transaction, error) { return _{{$contract.Type}}.contract.Transact(opts, "{{.Original.Name}}" {{range .Normalized.Inputs}}, {{.Name}}{{end}}) } - // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.Id}}. + // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.ID}}. // // Solidity: {{.Original.String}} - func (_{{$contract.Type}} *{{$contract.Type}}Session) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type}} {{end}}) (*types.Transaction, error) { + func (_{{$contract.Type}} *{{$contract.Type}}Session) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type $structs}} {{end}}) (*types.Transaction, error) { return _{{$contract.Type}}.Contract.{{.Normalized.Name}}(&_{{$contract.Type}}.TransactOpts {{range $i, $_ := .Normalized.Inputs}}, {{.Name}}{{end}}) } - // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.Id}}. + // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.ID}}. 
// // Solidity: {{.Original.String}} - func (_{{$contract.Type}} *{{$contract.Type}}TransactorSession) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type}} {{end}}) (*types.Transaction, error) { + func (_{{$contract.Type}} *{{$contract.Type}}TransactorSession) {{.Normalized.Name}}({{range $i, $_ := .Normalized.Inputs}}{{if ne $i 0}},{{end}} {{.Name}} {{bindtype .Type $structs}} {{end}}) (*types.Transaction, error) { return _{{$contract.Type}}.Contract.{{.Normalized.Name}}(&_{{$contract.Type}}.TransactOpts {{range $i, $_ := .Normalized.Inputs}}, {{.Name}}{{end}}) } {{end}} + {{if .Fallback}} + // Fallback is a paid mutator transaction binding the contract fallback function. + // + // Solidity: {{.Fallback.Original.String}} + func (_{{$contract.Type}} *{{$contract.Type}}Transactor) Fallback(opts *bind.TransactOpts, calldata []byte) (*types.Transaction, error) { + return _{{$contract.Type}}.contract.RawTransact(opts, calldata) + } + + // Fallback is a paid mutator transaction binding the contract fallback function. + // + // Solidity: {{.Fallback.Original.String}} + func (_{{$contract.Type}} *{{$contract.Type}}Session) Fallback(calldata []byte) (*types.Transaction, error) { + return _{{$contract.Type}}.Contract.Fallback(&_{{$contract.Type}}.TransactOpts, calldata) + } + + // Fallback is a paid mutator transaction binding the contract fallback function. + // + // Solidity: {{.Fallback.Original.String}} + func (_{{$contract.Type}} *{{$contract.Type}}TransactorSession) Fallback(calldata []byte) (*types.Transaction, error) { + return _{{$contract.Type}}.Contract.Fallback(&_{{$contract.Type}}.TransactOpts, calldata) + } + {{end}} + + {{if .Receive}} + // Receive is a paid mutator transaction binding the contract receive function. + // + // Solidity: {{.Receive.Original.String}} + func (_{{$contract.Type}} *{{$contract.Type}}Transactor) Receive(opts *bind.TransactOpts) (*types.Transaction, error) { + return _{{$contract.Type}}.contract.RawTransact(opts, nil) // calldata is disallowed for receive function + } + + // Receive is a paid mutator transaction binding the contract receive function. + // + // Solidity: {{.Receive.Original.String}} + func (_{{$contract.Type}} *{{$contract.Type}}Session) Receive() (*types.Transaction, error) { + return _{{$contract.Type}}.Contract.Receive(&_{{$contract.Type}}.TransactOpts) + } + + // Receive is a paid mutator transaction binding the contract receive function. + // + // Solidity: {{.Receive.Original.String}} + func (_{{$contract.Type}} *{{$contract.Type}}TransactorSession) Receive() (*types.Transaction, error) { + return _{{$contract.Type}}.Contract.Receive(&_{{$contract.Type}}.TransactOpts) + } + {{end}} + {{range .Events}} // {{$contract.Type}}{{.Normalized.Name}}Iterator is returned from Filter{{.Normalized.Name}} and is used to iterate over the raw logs and unpacked data for {{.Normalized.Name}} events raised by the {{$contract.Type}} contract. 
type {{$contract.Type}}{{.Normalized.Name}}Iterator struct { @@ -295,7 +433,7 @@ package {{.Package}} event string // Event name to use for unpacking event data logs chan types.Log // Log channel receiving the found contract events - sub ethereum.Subscription // Subscription for errors, completion and termination + sub tomochain.Subscription // Subscription for errors, completion and termination done bool // Whether the subscription completed delivering logs fail error // Occurred error to stop iteration } @@ -353,14 +491,14 @@ package {{.Package}} // {{$contract.Type}}{{.Normalized.Name}} represents a {{.Normalized.Name}} event raised by the {{$contract.Type}} contract. type {{$contract.Type}}{{.Normalized.Name}} struct { {{range .Normalized.Inputs}} - {{capitalise .Name}} {{if .Indexed}}{{bindtopictype .Type}}{{else}}{{bindtype .Type}}{{end}}; {{end}} + {{capitalise .Name}} {{if .Indexed}}{{bindtopictype .Type $structs}}{{else}}{{bindtype .Type $structs}}{{end}}; {{end}} Raw types.Log // Blockchain specific contextual infos } - // Filter{{.Normalized.Name}} is a free log retrieval operation binding the contract event 0x{{printf "%x" .Original.Id}}. + // Filter{{.Normalized.Name}} is a free log retrieval operation binding the contract event 0x{{printf "%x" .Original.ID}}. // // Solidity: {{.Original.String}} - func (_{{$contract.Type}} *{{$contract.Type}}Filterer) Filter{{.Normalized.Name}}(opts *bind.FilterOpts{{range .Normalized.Inputs}}{{if .Indexed}}, {{.Name}} []{{bindtype .Type}}{{end}}{{end}}) (*{{$contract.Type}}{{.Normalized.Name}}Iterator, error) { + func (_{{$contract.Type}} *{{$contract.Type}}Filterer) Filter{{.Normalized.Name}}(opts *bind.FilterOpts{{range .Normalized.Inputs}}{{if .Indexed}}, {{.Name}} []{{bindtype .Type $structs}}{{end}}{{end}}) (*{{$contract.Type}}{{.Normalized.Name}}Iterator, error) { {{range .Normalized.Inputs}} {{if .Indexed}}var {{.Name}}Rule []interface{} for _, {{.Name}}Item := range {{.Name}} { @@ -374,10 +512,10 @@ package {{.Package}} return &{{$contract.Type}}{{.Normalized.Name}}Iterator{contract: _{{$contract.Type}}.contract, event: "{{.Original.Name}}", logs: logs, sub: sub}, nil } - // Watch{{.Normalized.Name}} is a free log subscription operation binding the contract event 0x{{printf "%x" .Original.Id}}. + // Watch{{.Normalized.Name}} is a free log subscription operation binding the contract event 0x{{printf "%x" .Original.ID}}. // // Solidity: {{.Original.String}} - func (_{{$contract.Type}} *{{$contract.Type}}Filterer) Watch{{.Normalized.Name}}(opts *bind.WatchOpts, sink chan<- *{{$contract.Type}}{{.Normalized.Name}}{{range .Normalized.Inputs}}{{if .Indexed}}, {{.Name}} []{{bindtype .Type}}{{end}}{{end}}) (event.Subscription, error) { + func (_{{$contract.Type}} *{{$contract.Type}}Filterer) Watch{{.Normalized.Name}}(opts *bind.WatchOpts, sink chan<- *{{$contract.Type}}{{.Normalized.Name}}{{range .Normalized.Inputs}}{{if .Indexed}}, {{.Name}} []{{bindtype .Type $structs}}{{end}}{{end}}) (event.Subscription, error) { {{range .Normalized.Inputs}} {{if .Indexed}}var {{.Name}}Rule []interface{} for _, {{.Name}}Item := range {{.Name}} { @@ -415,108 +553,19 @@ package {{.Package}} } }), nil } - {{end}} -{{end}} -` - -// tmplSourceJava is the Java source template use to generate the contract binding -// based on. -const tmplSourceJava = ` -// This file is an automatically generated Java binding. Do not modify as any -// change will likely be lost upon the next re-generation! 
-package {{.Package}}; - -import org.ethereum.geth.*; -import org.ethereum.geth.internal.*; - -{{range $contract := .Contracts}} - public class {{.Type}} { - // ABI is the input ABI used to generate the binding from. - public final static String ABI = "{{.InputABI}}"; - - {{if .InputBin}} - // BYTECODE is the compiled bytecode used for deploying new contracts. - public final static byte[] BYTECODE = "{{.InputBin}}".getBytes(); - - // deploy deploys a new Ethereum contract, binding an instance of {{.Type}} to it. - public static {{.Type}} deploy(TransactOpts auth, EthereumClient client{{range .Constructor.Inputs}}, {{bindtype .Type}} {{.Name}}{{end}}) throws Exception { - Interfaces args = Geth.newInterfaces({{(len .Constructor.Inputs)}}); - {{range $index, $element := .Constructor.Inputs}} - args.set({{$index}}, Geth.newInterface()); args.get({{$index}}).set{{namedtype (bindtype .Type) .Type}}({{.Name}}); - {{end}} - return new {{.Type}}(Geth.deployContract(auth, ABI, BYTECODE, client, args)); - } - - // Internal constructor used by contract deployment. - private {{.Type}}(BoundContract deployment) { - this.Address = deployment.getAddress(); - this.Deployer = deployment.getDeployer(); - this.Contract = deployment; + // Parse{{.Normalized.Name}} is a log parse operation binding the contract event 0x{{printf "%x" .Original.ID}}. + // + // Solidity: {{.Original.String}} + func (_{{$contract.Type}} *{{$contract.Type}}Filterer) Parse{{.Normalized.Name}}(log types.Log) (*{{$contract.Type}}{{.Normalized.Name}}, error) { + event := new({{$contract.Type}}{{.Normalized.Name}}) + if err := _{{$contract.Type}}.contract.UnpackLog(event, "{{.Original.Name}}", log); err != nil { + return nil, err } - {{end}} - - // Ethereum address where this contract is located at. - public final Address Address; - - // Ethereum transaction in which this contract was deployed (if known!). - public final Transaction Deployer; - - // Contract instance bound to a blockchain address. - private final BoundContract Contract; - - // Creates a new instance of {{.Type}}, bound to a specific deployed contract. - public {{.Type}}(Address address, EthereumClient client) throws Exception { - this(Geth.bindContract(address, ABI, client)); + event.Raw = log + return event, nil } - {{range .Calls}} - {{if gt (len .Normalized.Outputs) 1}} - // {{capitalise .Normalized.Name}}Results is the output of a call to {{.Normalized.Name}}. - public class {{capitalise .Normalized.Name}}Results { - {{range $index, $item := .Normalized.Outputs}}public {{bindtype .Type}} {{if ne .Name ""}}{{.Name}}{{else}}Return{{$index}}{{end}}; - {{end}} - } - {{end}} - - // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.Id}}. 
- // - // Solidity: {{.Original.String}} - public {{if gt (len .Normalized.Outputs) 1}}{{capitalise .Normalized.Name}}Results{{else}}{{range .Normalized.Outputs}}{{bindtype .Type}}{{end}}{{end}} {{.Normalized.Name}}(CallOpts opts{{range .Normalized.Inputs}}, {{bindtype .Type}} {{.Name}}{{end}}) throws Exception { - Interfaces args = Geth.newInterfaces({{(len .Normalized.Inputs)}}); - {{range $index, $item := .Normalized.Inputs}}args.set({{$index}}, Geth.newInterface()); args.get({{$index}}).set{{namedtype (bindtype .Type) .Type}}({{.Name}}); - {{end}} - - Interfaces results = Geth.newInterfaces({{(len .Normalized.Outputs)}}); - {{range $index, $item := .Normalized.Outputs}}Interface result{{$index}} = Geth.newInterface(); result{{$index}}.setDefault{{namedtype (bindtype .Type) .Type}}(); results.set({{$index}}, result{{$index}}); - {{end}} - - if (opts == null) { - opts = Geth.newCallOpts(); - } - this.Contract.call(opts, results, "{{.Original.Name}}", args); - {{if gt (len .Normalized.Outputs) 1}} - {{capitalise .Normalized.Name}}Results result = new {{capitalise .Normalized.Name}}Results(); - {{range $index, $item := .Normalized.Outputs}}result.{{if ne .Name ""}}{{.Name}}{{else}}Return{{$index}}{{end}} = results.get({{$index}}).get{{namedtype (bindtype .Type) .Type}}(); - {{end}} - return result; - {{else}}{{range .Normalized.Outputs}}return results.get(0).get{{namedtype (bindtype .Type) .Type}}();{{end}} - {{end}} - } - {{end}} - - {{range .Transacts}} - // {{.Normalized.Name}} is a paid mutator transaction binding the contract method 0x{{printf "%x" .Original.Id}}. - // - // Solidity: {{.Original.String}} - public Transaction {{.Normalized.Name}}(TransactOpts opts{{range .Normalized.Inputs}}, {{bindtype .Type}} {{.Name}}{{end}}) throws Exception { - Interfaces args = Geth.newInterfaces({{(len .Normalized.Inputs)}}); - {{range $index, $item := .Normalized.Inputs}}args.set({{$index}}, Geth.newInterface()); args.get({{$index}}).set{{namedtype (bindtype .Type) .Type}}({{.Name}}); - {{end}} - - return this.Contract.transact(opts, "{{.Original.Name}}" , args); - } - {{end}} - } + {{end}} {{end}} ` From 80996eadc1b3a3960850e52a443a4e53ecb9f773 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 9 Oct 2023 14:43:00 +0700 Subject: [PATCH 086/119] [WIP] Optimize EstimateGas API --- internal/ethapi/api.go | 150 ++++++++++++++++++++++++++++------------- 1 file changed, 104 insertions(+), 46 deletions(-) diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 154aea7eb6..53457fac36 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -511,7 +511,7 @@ func (s *PublicBlockChainAPI) BlockNumber() *big.Int { return header.Number } -// BlockNumber returns the block number of the chain head. +// GetRewardByHash returns the block reward by block hash. 
func (s *PublicBlockChainAPI) GetRewardByHash(hash common.Hash) map[string]map[string]map[string]*big.Int { return s.b.GetRewardByHash(hash) } @@ -1027,17 +1027,17 @@ type CallArgs struct { Data hexutil.Bytes `json:"data"` } -func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber, vmCfg vm.Config, timeout time.Duration) (*core.ExecutionResult, error) { +func DoCall(ctx context.Context, b Backend, args CallArgs, blockNr rpc.BlockNumber, vmCfg vm.Config, timeout time.Duration) (*core.ExecutionResult, error) { defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) - statedb, header, err := s.b.StateAndHeaderByNumber(ctx, blockNr) + statedb, header, err := b.StateAndHeaderByNumber(ctx, blockNr) if statedb == nil || err != nil { return nil, err } // Set sender address or use a default if none specified addr := args.From if addr == (common.Address{}) { - if wallets := s.b.AccountManager().Wallets(); len(wallets) > 0 { + if wallets := b.AccountManager().Wallets(); len(wallets) > 0 { if accounts := wallets[0].Accounts(); len(accounts) > 0 { addr = accounts[0].Address } @@ -1078,20 +1078,20 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr // this makes sure resources are cleaned up. defer cancel() - block, err := s.b.BlockByNumber(ctx, blockNr) + block, err := b.BlockByNumber(ctx, blockNr) if err != nil { return nil, err } - author, err := s.b.GetEngine().Author(block.Header()) + author, err := b.GetEngine().Author(block.Header()) if err != nil { return nil, err } - tomoxState, err := s.b.TomoxService().GetTradingState(block, author) + tomoxState, err := b.TomoxService().GetTradingState(block, author) if err != nil { return nil, err } // Get a new instance of the EVM. - evm, vmError, err := s.b.GetEVM(ctx, msg, statedb, tomoxState, header, vmCfg) + evm, vmError, err := b.GetEVM(ctx, msg, statedb, tomoxState, header, vmCfg) if err != nil { return nil, err } @@ -1146,7 +1146,7 @@ func (e *revertError) ErrorData() interface{} { // Call executes the given transaction on the state for the given block number. // It doesn't make and changes in the state/blockchain and is useful to execute and retrieve values. func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber) (hexutil.Bytes, error) { - result, err := s.doCall(ctx, args, blockNr, vm.Config{}, 5*time.Second) + result, err := DoCall(ctx, s.b, args, blockNr, vm.Config{}, 5*time.Second) if err != nil { return nil, err } @@ -1157,45 +1157,107 @@ func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNr r return result.Return(), result.Err } -// EstimateGas returns an estimate of the amount of gas needed to execute the -// given transaction against the current pending block. -func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (hexutil.Uint64, error) { - // Binary search the gas requirement, as it may be higher than the amount used +// executeEstimate is a helper that executes the transaction under a given gas limit and returns +// true if the transaction fails for a reason that might be related to not enough gas. A non-nil +// error means execution failed due to reasons unrelated to the gas limit. 
+func executeEstimate(ctx context.Context, b Backend, args CallArgs, state *state.StateDB, header *types.Header, gasLimit uint64) (bool, *core.ExecutionResult, error) {
+	args.Gas = (hexutil.Uint64)(gasLimit)
+	result, err := DoCall(ctx, b, args, rpc.BlockNumber(header.Number.Int64()), vm.Config{}, 0)
+	if err != nil {
+		if errors.Is(err, core.ErrIntrinsicGas) {
+			return true, nil, nil // Special case, raise gas limit
+		}
+		return true, nil, err // Bail out
+	}
+	return result.Failed(), result, nil
+}
+
+// DoEstimateGas returns the lowest possible gas limit that allows the transaction to run
+// successfully at block `blockNrOrHash`. It returns error if the transaction would revert, or if
+// there are unexpected failures. The gas limit is capped by both `args.Gas` (if non-nil &
+// non-zero) and `gasCap` (if non-zero).
+func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumber) (hexutil.Uint64, error) {
+	// Binary search the gas limit, as it may need to be higher than the amount used
	var (
-		lo  uint64 = params.TxGas - 1
-		hi  uint64
-		cap uint64
+		lo uint64 // highest-known gas limit where tx execution fails
+		hi uint64 // lowest-known gas limit where tx execution succeeds
	)
+	// Determine the highest gas limit that can be used during the estimation.
	if uint64(args.Gas) >= params.TxGas {
		hi = uint64(args.Gas)
	} else {
-		// Retrieve the current pending block to act as the gas ceiling
-		block, err := s.b.BlockByNumber(ctx, rpc.LatestBlockNumber)
+		// Retrieve the block to act as the gas ceiling
+		block, err := b.BlockByNumber(ctx, blockNrOrHash)
		if err != nil {
			return 0, err
		}
+		if block == nil {
+			return 0, errors.New("block not found")
+		}
		hi = block.GasLimit()
	}
-	cap = hi
+	// Normalize the max fee per gas the call is willing to spend.
+	feeCap := args.GasPrice.ToInt()

-	// Create a helper to check if a gas allowance results in an executable transaction
-	executable := func(gas uint64) (bool, *core.ExecutionResult, error) {
-		args.Gas = hexutil.Uint64(gas)
+	state, header, err := b.StateAndHeaderByNumber(ctx, blockNrOrHash)
+	if state == nil || err != nil {
+		return 0, err
+	}

-		result, err := s.doCall(ctx, args, rpc.LatestBlockNumber, vm.Config{}, 0)
-		if err != nil {
-			if err == core.ErrIntrinsicGas {
-				return true, nil, nil // Special case, raise gas limit
+	// Recap the highest gas limit with the account's available balance.
+	if feeCap.BitLen() != 0 {
+		balance := state.GetBalance(args.From) // from can't be nil
+		available := new(big.Int).Set(balance)
+		if args.Value.ToInt().Cmp(available) >= 0 {
+			return 0, core.ErrInsufficientFundsForTransfer
+		}
+		available.Sub(available, args.Value.ToInt())
+		allowance := new(big.Int).Div(available, feeCap)
+
+		// If the allowance is larger than maximum uint64, skip checking
+		if allowance.IsUint64() && hi > allowance.Uint64() {
+			transfer := args.Value
+			log.Warn("Gas estimation capped by limited funds", "original", hi, "balance", balance,
+				"sent", transfer.ToInt(), "maxFeePerGas", feeCap, "fundable", allowance)
+			hi = allowance.Uint64()
+		}
+	}
+
+	// We first execute the transaction at the highest allowable gas limit, since if this fails we
+	// can return error immediately.
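+	// (A definitive failure at this ceiling, e.g. a plain revert, is thus reported
+	// immediately below instead of after a pointless binary search.)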
+ failed, result, err := executeEstimate(ctx, b, args, state.Copy(), header, hi) + if err != nil { + return 0, err + } + if failed { + if result != nil && !errors.Is(result.Err, vm.ErrOutOfGas) { + if len(result.Revert()) > 0 { + return 0, newRevertError(result) } - return true, nil, err + return 0, result.Err } - return result.Failed(), result, nil + return 0, fmt.Errorf("gas required exceeds allowance (%d)", hi) } - // Execute the binary search and hone in on an executable gas limit + // For almost any transaction, the gas consumed by the unconstrained execution above + // lower-bounds the gas limit required for it to succeed. One exception is those txs that + // explicitly check gas remaining in order to successfully execute within a given limit, but we + // probably don't want to return a lowest possible gas limit for these cases anyway. + lo = result.UsedGas - 1 + + // Binary search for the smallest gas limit that allows the tx to execute successfully. for lo+1 < hi { mid := (hi + lo) / 2 - failed, _, err := executable(mid) + if mid > lo*2 { + // Most txs don't need much higher gas limit than their gas used, and most txs don't + // require near the full block limit of gas, so the selection of where to bisect the + // range here is skewed to favor the low side. + mid = lo * 2 + } + failed, _, err = executeEstimate(ctx, b, args, state.Copy(), header, mid) if err != nil { + // This should not happen under normal conditions since if we make it this far the + // transaction had run without error at least once before. + log.Error("execution error in estimate gas", "err", err) return 0, err } if failed { @@ -1204,24 +1266,20 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h hi = mid } } - // Reject the transaction as invalid if it still fails at the highest allowance - if hi == cap { - failed, result, err := executable(hi) - if err != nil { - return 0, nil - } + return hexutil.Uint64(hi), nil +} - if failed { - if result != nil && result.Err != vm.ErrOutOfGas { - if len(result.Revert()) > 0 { - return 0, newRevertError(result) - } - return 0, result.Err - } - return 0, fmt.Errorf("gas required exceeds allowance (%d)", cap) - } +// EstimateGas returns the lowest possible gas limit that allows the transaction to run +// successfully at block `blockNrOrHash`, or the latest block if `blockNrOrHash` is unspecified. It +// returns error if the transaction would revert or if there are unexpected failures. The returned +// value is capped by both `args.Gas` (if non-nil & non-zero) and the backend's RPCGasCap +// configuration (if non-zero). 
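+//
+// As a rough, hypothetical illustration of the search above: estimating a plain
+// transfer that consumes 21000 gas under a 10M gas ceiling sets lo = 20999, caps
+// the first probe at 2*lo = 41998 rather than roughly 5M, and then bisects the
+// narrow 20999..41998 range, converging in about 15 iterations instead of the
+// roughly 23 an unskewed bisection over the full ceiling would need.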
+func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs, blockNrOrHash *rpc.BlockNumber) (hexutil.Uint64, error) { + bNrOrHash := rpc.LatestBlockNumber + if blockNrOrHash != nil { + bNrOrHash = *blockNrOrHash } - return hexutil.Uint64(hi), nil + return DoEstimateGas(ctx, s.b, args, bNrOrHash) } // ExecutionResult groups all structured logs emitted by the EVM From 97284ef22750d85a9c2940fc2085a8e91b61907d Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 10 Oct 2023 13:58:33 +0700 Subject: [PATCH 087/119] Use same state for each invocation within EstimateGas --- internal/ethapi/api.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 53457fac36..366de5e255 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -1030,10 +1030,14 @@ type CallArgs struct { func DoCall(ctx context.Context, b Backend, args CallArgs, blockNr rpc.BlockNumber, vmCfg vm.Config, timeout time.Duration) (*core.ExecutionResult, error) { defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) - statedb, header, err := b.StateAndHeaderByNumber(ctx, blockNr) - if statedb == nil || err != nil { + state, header, err := b.StateAndHeaderByNumber(ctx, blockNr) + if state == nil || err != nil { return nil, err } + + return doCall(ctx, b, args, state, header, timeout) +} +func doCall(ctx context.Context, b Backend, args CallArgs, state *state.StateDB, header *types.Header, timeout time.Duration) (*core.ExecutionResult, error) { // Set sender address or use a default if none specified addr := args.From if addr == (common.Address{}) { @@ -1078,7 +1082,7 @@ func DoCall(ctx context.Context, b Backend, args CallArgs, blockNr rpc.BlockNumb // this makes sure resources are cleaned up. defer cancel() - block, err := b.BlockByNumber(ctx, blockNr) + block, err := b.BlockByNumber(ctx, rpc.BlockNumber(header.Number.Int64())) if err != nil { return nil, err } @@ -1091,7 +1095,7 @@ func DoCall(ctx context.Context, b Backend, args CallArgs, blockNr rpc.BlockNumb return nil, err } // Get a new instance of the EVM. - evm, vmError, err := b.GetEVM(ctx, msg, statedb, tomoxState, header, vmCfg) + evm, vmError, err := b.GetEVM(ctx, msg, state, tomoxState, header, vm.Config{}) if err != nil { return nil, err } @@ -1162,7 +1166,7 @@ func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNr r // error means execution failed due to reasons unrelated to the gas limit. 
func executeEstimate(ctx context.Context, b Backend, args CallArgs, state *state.StateDB, header *types.Header, gasLimit uint64) (bool, *core.ExecutionResult, error) { args.Gas = (hexutil.Uint64)(gasLimit) - result, err := DoCall(ctx, b, args, rpc.BlockNumber(header.Number.Int64()), vm.Config{}, 0) + result, err := doCall(ctx, b, args, state, header, 0) if err != nil { if errors.Is(err, core.ErrIntrinsicGas) { return true, nil, nil // Special case, raise gas limit From 98382229f9db8d9a295db839bba6db0fd0afb788 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 10 Oct 2023 14:20:13 +0700 Subject: [PATCH 088/119] Restore error functionality --- accounts/abi/bind/base.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/accounts/abi/bind/base.go b/accounts/abi/bind/base.go index 35c236abba..55c13e15c9 100644 --- a/accounts/abi/bind/base.go +++ b/accounts/abi/bind/base.go @@ -179,7 +179,10 @@ func (c *BoundContract) Call(opts *CallOpts, result interface{}, method string, } } else { output, err = c.caller.CallContract(ctx, msg, nil) - if err == nil && len(output) == 0 { + if err != nil { + return err + } + if len(output) == 0 { // Make sure we have a contract to operate on, and bail out otherwise. if code, err = c.caller.CodeAt(ctx, c.address, nil); err != nil { return err From 021856ff56718ccbfdd5876e895440399883a48a Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 10 Oct 2023 16:51:48 +0700 Subject: [PATCH 089/119] Fix merge conflicts --- accounts/abi/abi.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 0cd112495f..7724d7c479 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -24,6 +24,7 @@ import ( "io" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/crypto" ) // The ABI holds information about a contract's context and available @@ -253,6 +254,7 @@ func (abi *ABI) HasFallback() bool { // HasReceive returns an indicator whether a receive function is included. func (abi *ABI) HasReceive() bool { return abi.Receive.Type == Receive +} // revertSelector is a special function selector for revert reason unpacking. 
var revertSelector = crypto.Keccak256([]byte("Error(string)"))[:4]

@@ -268,10 +270,13 @@ func UnpackRevert(data []byte) (string, error) {
	if !bytes.Equal(data[:4], revertSelector) {
		return "", errors.New("invalid data for unpacking")
	}
-	var reason string
-	typ, _ := NewType("string", "", nil)
-	if err := (Arguments{{Type: typ}}).Unpack(&reason, data[4:]); err != nil {
+	typ, err := NewType("string", "", nil)
+	if err != nil {
+		return "", err
+	}
+	unpacked, err := (Arguments{{Type: typ}}).Unpack(data[4:])
+	if err != nil {
		return "", err
	}
-	return reason, nil
+	return unpacked[0].(string), nil
}

From 1d81da4f016e75e71575694cacdc0be02688e6b1 Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Wed, 11 Oct 2023 14:49:37 +0700
Subject: [PATCH 090/119] Add devnet profile and reduce blocks per epoch

---
 cmd/tomo/config.go  | 18 +++++++++++++++++-
 cmd/tomo/main.go    |  4 ++++
 common/constants.go |  9 ++++++---
 eth/backend.go      |  2 +-
 4 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/cmd/tomo/config.go b/cmd/tomo/config.go
index d8dffb4f5b..55d44f8be6 100644
--- a/cmd/tomo/config.go
+++ b/cmd/tomo/config.go
@@ -20,7 +20,6 @@ import (
	"bufio"
	"errors"
	"fmt"
-	"gopkg.in/urfave/cli.v1"
	"io"
	"math/big"
	"os"
@@ -29,6 +28,8 @@ import (
	"unicode"

	"github.com/naoina/toml"
+	"gopkg.in/urfave/cli.v1"
+
	"github.com/tomochain/tomochain/cmd/utils"
	"github.com/tomochain/tomochain/common"
	"github.com/tomochain/tomochain/eth"
@@ -166,6 +167,21 @@ func makeConfigNode(ctx *cli.Context) (*node.Node, tomoConfig) {
		common.BlackListHFNumber = uint64(0)
	}

+	// Check if devnet is enabled
+	if ctx.GlobalUint64(utils.NetworkIdFlag.Name) == 989898 {
+		cfg.Eth.NetworkId = 989898
+		common.IsDevnet = true
+		common.TIP2019Block = big.NewInt(300)
+		common.TIPSigning = big.NewInt(600)
+		common.TIPRandomize = big.NewInt(900)
+		common.TIPTomoX = big.NewInt(1200)
+		common.TIPTomoXLending = big.NewInt(1500)
+		common.TIPTomoXCancellationFee = big.NewInt(1800)
+		common.EpocBlockSecret = uint64(100)
+		common.EpocBlockOpening = uint64(125)
+		common.EpocBlockRandomize = uint64(150)
+	}
+
	// Rewound
	if rewound := ctx.GlobalInt(utils.RewoundFlag.Name); rewound != 0 {
		common.Rewound = uint64(rewound)
diff --git a/cmd/tomo/main.go b/cmd/tomo/main.go
index 1b08a78ced..99851e86ad 100644
--- a/cmd/tomo/main.go
+++ b/cmd/tomo/main.go
@@ -226,6 +226,10 @@ func tomo(ctx *cli.Context) error {
// it unlocks any requested accounts, and starts the RPC/IPC interfaces and the
// miner.
func startNode(ctx *cli.Context, stack *node.Node, cfg tomoConfig) { + if common.IsDevnet { + log.Info("DEVNET configuration applied") + } + // Start up the node itself utils.StartNode(stack) diff --git a/common/constants.go b/common/constants.go index af75a82e70..6278800529 100644 --- a/common/constants.go +++ b/common/constants.go @@ -11,9 +11,6 @@ const ( HexSignMethod = "e341eaa4" HexSetSecret = "34d38600" HexSetOpening = "e11f5ba2" - EpocBlockSecret = 800 - EpocBlockOpening = 850 - EpocBlockRandomize = 900 MaxMasternodes = 150 LimitPenaltyEpoch = 4 BlocksPerYear = uint64(15768000) @@ -29,6 +26,11 @@ const ( var Rewound = uint64(0) +// dynamic configs +var EpocBlockSecret = uint64(800) +var EpocBlockOpening = uint64(850) +var EpocBlockRandomize = uint64(900) + // hardforks var TIP2019Block = big.NewInt(1050000) var TIPSigning = big.NewInt(3000000) @@ -38,6 +40,7 @@ var TIPTomoX = big.NewInt(20581700) var TIPTomoXLending = big.NewInt(21430200) var TIPTomoXCancellationFee = big.NewInt(30915660) var TIPTomoXTestnet = big.NewInt(0) +var IsDevnet bool = false var IsTestnet bool = false var StoreRewardFolder string var RollbackHash Hash diff --git a/eth/backend.go b/eth/backend.go index 1fe58ca2da..1d3e51a845 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -559,7 +559,7 @@ func New(ctx *node.ServiceContext, config *Config, tomoXServ *tomox.TomoX, lendi // Hook verifies masternodes set c.HookVerifyMNs = func(header *types.Header, signers []common.Address) error { number := header.Number.Int64() - if number > 0 && number%common.EpocBlockRandomize == 0 { + if number > 0 && uint64(number)%common.EpocBlockRandomize == 0 { start := time.Now() validators, err := GetValidators(eth.blockchain, signers) log.Debug("Time Calculated HookVerifyMNs ", "block", header.Number.Uint64(), "time", common.PrettyDuration(time.Since(start))) From a8106e27bc2f8e1ab2f35edd399b947682012861 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 26 Oct 2023 15:04:49 +0700 Subject: [PATCH 091/119] Handle solidity panic revert --- accounts/abi/abi.go | 69 ++++++++++++++++++++++++++++++++++++++++ accounts/abi/abi_test.go | 34 ++++++++++++++++++++ 2 files changed, 103 insertions(+) diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 7d5d6291b0..179db856e3 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -22,6 +22,9 @@ import ( "errors" "fmt" "io" + "math/big" + + "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/common" ) @@ -254,3 +257,69 @@ func (abi *ABI) HasFallback() bool { func (abi *ABI) HasReceive() bool { return abi.Receive.Type == Receive } + +// revertSelector is a special function selector for revert reason unpacking. +var revertSelector = crypto.Keccak256([]byte("Error(string)"))[:4] + +// panicSelector is a special function selector for panic reason unpacking. 
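+// Both selectors are the first four bytes of the keccak256 hash of the canonical
+// error signature: 0x08c379a0 for Error(string) and 0x4e487b71 for Panic(uint256),
+// matching the hex test vectors in TestUnpackRevert below.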
+var panicSelector = crypto.Keccak256([]byte("Panic(uint256)"))[:4]
+
+// panicReasons map is for readable panic codes
+// see this link for the details
+// https://docs.soliditylang.org/en/v0.8.21/control-structures.html#panic-via-assert-and-error-via-require
+// the reason string list is copied from ethers.js
+// https://github.com/ethers-io/ethers.js/blob/fa3a883ff7c88611ce766f58bdd4b8ac90814470/src.ts/abi/interface.ts#L207-L218
+var panicReasons = map[uint64]string{
+	0x00: "generic panic",
+	0x01: "assert(false)",
+	0x11: "arithmetic underflow or overflow",
+	0x12: "division or modulo by zero",
+	0x21: "enum overflow",
+	0x22: "invalid encoded storage byte array accessed",
+	0x31: "out-of-bounds array access; popping on an empty array",
+	0x32: "out-of-bounds access of an array or bytesN",
+	0x41: "out of memory",
+	0x51: "uninitialized function",
+}
+
+// UnpackRevert resolves the abi-encoded revert reason. According to the Solidity
+// spec https://solidity.readthedocs.io/en/latest/control-structures.html#revert,
+// the provided revert reason is abi-encoded as if it were a call to function
+// `Error(string)` or `Panic(uint256)`. It is a dedicated helper for exactly those
+// two encodings.
+func UnpackRevert(data []byte) (string, error) {
+	if len(data) < 4 {
+		return "", errors.New("invalid data for unpacking")
+	}
+	switch {
+	case bytes.Equal(data[:4], revertSelector):
+		typ, err := NewType("string", "", nil)
+		if err != nil {
+			return "", err
+		}
+		unpacked, err := (Arguments{{Type: typ}}).Unpack(data[4:])
+		if err != nil {
+			return "", err
+		}
+		return unpacked[0].(string), nil
+	case bytes.Equal(data[:4], panicSelector):
+		typ, err := NewType("uint256", "", nil)
+		if err != nil {
+			return "", err
+		}
+		unpacked, err := (Arguments{{Type: typ}}).Unpack(data[4:])
+		if err != nil {
+			return "", err
+		}
+		pCode := unpacked[0].(*big.Int)
+		// uint64 safety check for the future;
+		// currently no panic code exceeds MAX(uint64)
+		if pCode.IsUint64() {
+			if reason, ok := panicReasons[pCode.Uint64()]; ok {
+				return reason, nil
+			}
+		}
+		return fmt.Sprintf("unknown panic code: %#x", pCode), nil
+	default:
+		return "", errors.New("invalid data for unpacking")
+	}
+}
diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go
index 67af037633..87eceb8b1d 100644
--- a/accounts/abi/abi_test.go
+++ b/accounts/abi/abi_test.go
@@ -19,6 +19,7 @@ package abi
import (
	"bytes"
	"encoding/hex"
+	"errors"
	"fmt"
	"math/big"
	"reflect"
@@ -1117,3 +1118,36 @@ func TestUnnamedEventParam(t *testing.T) {
		t.Fatalf("Could not find input")
	}
}
+
+func TestUnpackRevert(t *testing.T) {
+	t.Parallel()
+
+	var cases = []struct {
+		input     string
+		expect    string
+		expectErr error
+	}{
+		{"", "", errors.New("invalid data for unpacking")},
+		{"08c379a1", "", errors.New("invalid data for unpacking")},
+		{"08c379a00000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000d72657665727420726561736f6e00000000000000000000000000000000000000", "revert reason", nil},
+		{"4e487b710000000000000000000000000000000000000000000000000000000000000000", "generic panic", nil},
+		{"4e487b7100000000000000000000000000000000000000000000000000000000000000ff", "unknown panic code: 0xff", nil},
+	}
+	for index, c := range cases {
+		t.Run(fmt.Sprintf("case %d", index), func(t *testing.T) {
+			got, err := UnpackRevert(common.Hex2Bytes(c.input))
+			if c.expectErr != nil {
+				if err == nil {
+					t.Fatalf("Expected non-nil error")
+				}
+				if err.Error() != c.expectErr.Error() {
+					t.Fatalf("Expected error mismatch,
want %v, got %v", c.expectErr, err) + } + return + } + if c.expect != got { + t.Fatalf("Output mismatch, want %v, got %v", c.expect, got) + } + }) + } +} From 6b84f760d4cd70e91b2576caa0e6049daac86e24 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Wed, 1 Nov 2023 00:42:57 +0700 Subject: [PATCH 092/119] Fix imports --- accounts/abi/abi.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 61ba2d3927..c39c88befb 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -24,8 +24,6 @@ import ( "io" "math/big" - "github.com/tomochain/tomochain/crypto" - "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/crypto" ) From b0cf4285199084caf14d669324f80a50eec6455d Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Fri, 27 Oct 2023 16:50:24 +0700 Subject: [PATCH 093/119] Implement IP tracker --- p2p/netutil/iptrack.go | 130 +++++++++++++++++++++++++++++++++ p2p/netutil/iptrack_test.go | 138 ++++++++++++++++++++++++++++++++++++ 2 files changed, 268 insertions(+) create mode 100644 p2p/netutil/iptrack.go create mode 100644 p2p/netutil/iptrack_test.go diff --git a/p2p/netutil/iptrack.go b/p2p/netutil/iptrack.go new file mode 100644 index 0000000000..a8660a4d73 --- /dev/null +++ b/p2p/netutil/iptrack.go @@ -0,0 +1,130 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package netutil + +import ( + "time" + + "github.com/tomochain/tomochain/common/mclock" +) + +// IPTracker predicts the external endpoint, i.e. IP address and port, of the local host +// based on statements made by other hosts. +type IPTracker struct { + window time.Duration + contactWindow time.Duration + minStatements int + clock mclock.Clock + statements map[string]ipStatement + contact map[string]mclock.AbsTime + lastStatementGC mclock.AbsTime + lastContactGC mclock.AbsTime +} + +type ipStatement struct { + endpoint string + time mclock.AbsTime +} + +// NewIPTracker creates an IP tracker. +// +// The window parameters configure the amount of past network events which are kept. The +// minStatements parameter enforces a minimum number of statements which must be recorded +// before any prediction is made. Higher values for these parameters decrease 'flapping' of +// predictions as network conditions change. Window duration values should typically be in +// the range of minutes. +func NewIPTracker(window, contactWindow time.Duration, minStatements int) *IPTracker { + return &IPTracker{ + window: window, + contactWindow: contactWindow, + statements: make(map[string]ipStatement), + minStatements: minStatements, + contact: make(map[string]mclock.AbsTime), + clock: mclock.System{}, + } +} + +// PredictFullConeNAT checks whether the local host is behind full cone NAT. 
It predicts by +// checking whether any statement has been received from a node we didn't contact before +// the statement was made. +func (it *IPTracker) PredictFullConeNAT() bool { + now := it.clock.Now() + it.gcContact(now) + it.gcStatements(now) + for host, st := range it.statements { + if c, ok := it.contact[host]; !ok || c > st.time { + return true + } + } + return false +} + +// PredictEndpoint returns the current prediction of the external endpoint. +func (it *IPTracker) PredictEndpoint() string { + it.gcStatements(it.clock.Now()) + + // The current strategy is simple: find the endpoint with most statements. + counts := make(map[string]int, len(it.statements)) + maxcount, max := 0, "" + for _, s := range it.statements { + c := counts[s.endpoint] + 1 + counts[s.endpoint] = c + if c > maxcount && c >= it.minStatements { + maxcount, max = c, s.endpoint + } + } + return max +} + +// AddStatement records that a certain host thinks our external endpoint is the one given. +func (it *IPTracker) AddStatement(host, endpoint string) { + now := it.clock.Now() + it.statements[host] = ipStatement{endpoint, now} + if time.Duration(now-it.lastStatementGC) >= it.window { + it.gcStatements(now) + } +} + +// AddContact records that a packet containing our endpoint information has been sent to a +// certain host. +func (it *IPTracker) AddContact(host string) { + now := it.clock.Now() + it.contact[host] = now + if time.Duration(now-it.lastContactGC) >= it.contactWindow { + it.gcContact(now) + } +} + +func (it *IPTracker) gcStatements(now mclock.AbsTime) { + it.lastStatementGC = now + cutoff := now.Add(-it.window) + for host, s := range it.statements { + if s.time < cutoff { + delete(it.statements, host) + } + } +} + +func (it *IPTracker) gcContact(now mclock.AbsTime) { + it.lastContactGC = now + cutoff := now.Add(-it.contactWindow) + for host, ct := range it.contact { + if ct < cutoff { + delete(it.contact, host) + } + } +} diff --git a/p2p/netutil/iptrack_test.go b/p2p/netutil/iptrack_test.go new file mode 100644 index 0000000000..711e588d6d --- /dev/null +++ b/p2p/netutil/iptrack_test.go @@ -0,0 +1,138 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
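As an editorial usage sketch (not part of the patch), a caller such as the p2p server would feed the tracker contact and statement events and read back the predicted endpoint; the endpoint strings below are hypothetical:

	it := netutil.NewIPTracker(10*time.Minute, 10*time.Minute, 5)
	it.AddContact("198.51.100.7:30303")                        // we sent our endpoint to this host
	it.AddStatement("198.51.100.7:30303", "203.0.113.2:30303") // the host reported what it sees
	if ep := it.PredictEndpoint(); ep != "" {
		// at least 5 statements agree; advertise ep as our external endpoint
	}
	behindFullCone := it.PredictFullConeNAT()
	_ = behindFullCone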
+ +package netutil + +import ( + crand "crypto/rand" + "fmt" + "testing" + "time" + + "github.com/tomochain/tomochain/common/mclock" +) + +const ( + opStatement = iota + opContact + opPredict + opCheckFullCone +) + +type iptrackTestEvent struct { + op int + time int // absolute, in milliseconds + ip, from string +} + +func TestIPTracker(t *testing.T) { + tests := map[string][]iptrackTestEvent{ + "minStatements": { + {opPredict, 0, "", ""}, + {opStatement, 0, "127.0.0.1", "127.0.0.2"}, + {opPredict, 1000, "", ""}, + {opStatement, 1000, "127.0.0.1", "127.0.0.3"}, + {opPredict, 1000, "", ""}, + {opStatement, 1000, "127.0.0.1", "127.0.0.4"}, + {opPredict, 1000, "127.0.0.1", ""}, + }, + "window": { + {opStatement, 0, "127.0.0.1", "127.0.0.2"}, + {opStatement, 2000, "127.0.0.1", "127.0.0.3"}, + {opStatement, 3000, "127.0.0.1", "127.0.0.4"}, + {opPredict, 10000, "127.0.0.1", ""}, + {opPredict, 10001, "", ""}, // first statement expired + {opStatement, 10100, "127.0.0.1", "127.0.0.2"}, + {opPredict, 10200, "127.0.0.1", ""}, + }, + "fullcone": { + {opContact, 0, "", "127.0.0.2"}, + {opStatement, 10, "127.0.0.1", "127.0.0.2"}, + {opContact, 2000, "", "127.0.0.3"}, + {opStatement, 2010, "127.0.0.1", "127.0.0.3"}, + {opContact, 3000, "", "127.0.0.4"}, + {opStatement, 3010, "127.0.0.1", "127.0.0.4"}, + {opCheckFullCone, 3500, "false", ""}, + }, + "fullcone_2": { + {opContact, 0, "", "127.0.0.2"}, + {opStatement, 10, "127.0.0.1", "127.0.0.2"}, + {opContact, 2000, "", "127.0.0.3"}, + {opStatement, 2010, "127.0.0.1", "127.0.0.3"}, + {opStatement, 3000, "127.0.0.1", "127.0.0.4"}, + {opContact, 3010, "", "127.0.0.4"}, + {opCheckFullCone, 3500, "true", ""}, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { runIPTrackerTest(t, test) }) + } +} + +func runIPTrackerTest(t *testing.T, evs []iptrackTestEvent) { + var ( + clock mclock.Simulated + it = NewIPTracker(10*time.Second, 10*time.Second, 3) + ) + it.clock = &clock + for i, ev := range evs { + evtime := time.Duration(ev.time) * time.Millisecond + clock.Run(evtime - time.Duration(clock.Now())) + switch ev.op { + case opStatement: + it.AddStatement(ev.from, ev.ip) + case opContact: + it.AddContact(ev.from) + case opPredict: + if pred := it.PredictEndpoint(); pred != ev.ip { + t.Errorf("op %d: wrong prediction %q, want %q", i, pred, ev.ip) + } + case opCheckFullCone: + pred := fmt.Sprintf("%t", it.PredictFullConeNAT()) + if pred != ev.ip { + t.Errorf("op %d: wrong prediction %s, want %s", i, pred, ev.ip) + } + } + } +} + +// This checks that old statements and contacts are GCed even if Predict* isn't called. 
+func TestIPTrackerForceGC(t *testing.T) { + var ( + clock mclock.Simulated + window = 10 * time.Second + rate = 50 * time.Millisecond + max = int(window/rate) + 1 + it = NewIPTracker(window, window, 3) + ) + it.clock = &clock + + for i := 0; i < 5*max; i++ { + e1 := make([]byte, 4) + e2 := make([]byte, 4) + crand.Read(e1) + crand.Read(e2) + it.AddStatement(string(e1), string(e2)) + it.AddContact(string(e1)) + clock.Run(rate) + } + if len(it.contact) > 2*max { + t.Errorf("contacts not GCed, have %d", len(it.contact)) + } + if len(it.statements) > 2*max { + t.Errorf("statements not GCed, have %d", len(it.statements)) + } +} From 67a1980e93367b68eb76be87f4c105d4c1f0dd61 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Fri, 27 Oct 2023 16:52:44 +0700 Subject: [PATCH 094/119] Update Ethereum node record --- p2p/enr/enr.go | 279 +++++++++++++++++++++++++------------------- p2p/enr/enr_test.go | 233 ++++++++++++++++++++---------------- p2p/enr/entries.go | 83 +++++++++++-- 3 files changed, 362 insertions(+), 233 deletions(-) diff --git a/p2p/enr/enr.go b/p2p/enr/enr.go index fabd08ae26..81cdffdf1c 100644 --- a/p2p/enr/enr.go +++ b/p2p/enr/enr.go @@ -29,33 +29,53 @@ package enr import ( "bytes" - "crypto/ecdsa" "errors" "fmt" "io" "sort" - "github.com/tomochain/tomochain/crypto" - "github.com/tomochain/tomochain/crypto/sha3" "github.com/tomochain/tomochain/rlp" ) const SizeLimit = 300 // maximum encoded size of a node record in bytes -const ID_SECP256k1_KECCAK = ID("secp256k1-keccak") // the default identity scheme - var ( - errNoID = errors.New("unknown or unspecified identity scheme") - errInvalidSigsize = errors.New("invalid signature size") - errInvalidSig = errors.New("invalid signature") + ErrInvalidSig = errors.New("invalid signature on node record") errNotSorted = errors.New("record key/value pairs are not sorted by key") errDuplicateKey = errors.New("record contains duplicate key") errIncompletePair = errors.New("record contains incomplete k/v pair") + errIncompleteList = errors.New("record contains less than two list elements") errTooBig = fmt.Errorf("record bigger than %d bytes", SizeLimit) errEncodeUnsigned = errors.New("can't encode unsigned record") errNotFound = errors.New("no such key in record") ) +// An IdentityScheme is capable of verifying record signatures and +// deriving node addresses. +type IdentityScheme interface { + Verify(r *Record, sig []byte) error + NodeAddr(r *Record) []byte +} + +// SchemeMap is a registry of named identity schemes. +type SchemeMap map[string]IdentityScheme + +func (m SchemeMap) Verify(r *Record, sig []byte) error { + s := m[r.IdentityScheme()] + if s == nil { + return ErrInvalidSig + } + return s.Verify(r, sig) +} + +func (m SchemeMap) NodeAddr(r *Record) []byte { + s := m[r.IdentityScheme()] + if s == nil { + return nil + } + return s.NodeAddr(r) +} + // Record represents a node record. The zero value is an empty record. type Record struct { seq uint64 // sequence number @@ -70,9 +90,22 @@ type pair struct { v rlp.RawValue } -// Signed reports whether the record has a valid signature. -func (r *Record) Signed() bool { - return r.signature != nil +// Size returns the encoded size of the record. 
+func (r *Record) Size() uint64 {
+	if r.raw != nil {
+		return uint64(len(r.raw))
+	}
+	return computeSize(r)
+}
+
+func computeSize(r *Record) uint64 {
+	size := uint64(rlp.IntSize(r.seq))
+	size += rlp.BytesSize(r.signature)
+	for _, p := range r.pairs {
+		size += rlp.StringSize(p.k)
+		size += uint64(len(p.v))
+	}
+	return rlp.ListSize(size)
+}
 
 // Seq returns the sequence number.
@@ -81,8 +114,8 @@ func (r *Record) Seq() uint64 {
 }
 
 // SetSeq updates the record sequence number. This invalidates any signature on the record.
-// Calling SetSeq is usually not required because signing the redord increments the
-// sequence number.
+// Calling SetSeq is usually not required because setting any key in a signed record
+// increments the sequence number.
 func (r *Record) SetSeq(s uint64) {
 	r.signature = nil
 	r.raw = nil
@@ -105,66 +138,100 @@ func (r *Record) Load(e Entry) error {
 	return &KeyError{Key: e.ENRKey(), Err: errNotFound}
 }
 
-// Set adds or updates the given entry in the record.
-// It panics if the value can't be encoded.
+// Set adds or updates the given entry in the record. It panics if the value can't be
+// encoded. If the record is signed, Set increments the sequence number and invalidates
+// the signature.
 func (r *Record) Set(e Entry) {
-	r.signature = nil
-	r.raw = nil
 	blob, err := rlp.EncodeToBytes(e)
 	if err != nil {
 		panic(fmt.Errorf("enr: can't encode %s: %v", e.ENRKey(), err))
 	}
+	r.invalidate()
 
-	i := sort.Search(len(r.pairs), func(i int) bool { return r.pairs[i].k >= e.ENRKey() })
-
-	if i < len(r.pairs) && r.pairs[i].k == e.ENRKey() {
+	pairs := make([]pair, len(r.pairs))
+	copy(pairs, r.pairs)
+	i := sort.Search(len(pairs), func(i int) bool { return pairs[i].k >= e.ENRKey() })
+	switch {
+	case i < len(pairs) && pairs[i].k == e.ENRKey():
 		// element is present at r.pairs[i]
-		r.pairs[i].v = blob
-		return
-	} else if i < len(r.pairs) {
+		pairs[i].v = blob
+	case i < len(r.pairs):
 		// insert pair before i-th elem
 		el := pair{e.ENRKey(), blob}
-		r.pairs = append(r.pairs, pair{})
-		copy(r.pairs[i+1:], r.pairs[i:])
-		r.pairs[i] = el
-		return
+		pairs = append(pairs, pair{})
+		copy(pairs[i+1:], pairs[i:])
+		pairs[i] = el
+	default:
+		// element should be placed at the end of r.pairs
+		pairs = append(pairs, pair{e.ENRKey(), blob})
 	}
+	r.pairs = pairs
+}
+
+func (r *Record) invalidate() {
+	if r.signature != nil {
+		r.seq++
+	}
+	r.signature = nil
+	r.raw = nil
+}
 
-	// element should be placed at the end of r.pairs
-	r.pairs = append(r.pairs, pair{e.ENRKey(), blob})
+// Signature returns the signature of the record.
+func (r *Record) Signature() []byte {
+	if r.signature == nil {
+		return nil
+	}
+	cpy := make([]byte, len(r.signature))
+	copy(cpy, r.signature)
+	return cpy
 }
 
 // EncodeRLP implements rlp.Encoder. Encoding fails if
 // the record is unsigned.
 func (r Record) EncodeRLP(w io.Writer) error {
-	if !r.Signed() {
+	if r.signature == nil {
 		return errEncodeUnsigned
 	}
 	_, err := w.Write(r.raw)
 	return err
 }
 
-// DecodeRLP implements rlp.Decoder. Decoding verifies the signature.
+// DecodeRLP implements rlp.Decoder. Decoding doesn't verify the signature.
 func (r *Record) DecodeRLP(s *rlp.Stream) error {
-	raw, err := s.Raw()
+	dec, raw, err := decodeRecord(s)
 	if err != nil {
 		return err
 	}
+	*r = dec
+	r.raw = raw
+	return nil
+}
+
+func decodeRecord(s *rlp.Stream) (dec Record, raw []byte, err error) {
+	raw, err = s.Raw()
+	if err != nil {
+		return dec, raw, err
+	}
 	if len(raw) > SizeLimit {
-		return errTooBig
+		return dec, raw, errTooBig
 	}
 
 	// Decode the RLP container.
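+	// (The container has the shape [signature, seq, k1, v1, k2, v2, ...];
+	// the checks below enforce that keys are sorted and unique.)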
- dec := Record{raw: raw} s = rlp.NewStream(bytes.NewReader(raw), 0) if _, err := s.List(); err != nil { - return err + return dec, raw, err } if err = s.Decode(&dec.signature); err != nil { - return err + if err == rlp.EOL { + err = errIncompleteList + } + return dec, raw, err } if err = s.Decode(&dec.seq); err != nil { - return err + if err == rlp.EOL { + err = errIncompleteList + } + return dec, raw, err } // The rest of the record contains sorted k/v pairs. var prevkey string @@ -174,62 +241,73 @@ func (r *Record) DecodeRLP(s *rlp.Stream) error { if err == rlp.EOL { break } - return err + return dec, raw, err } if err := s.Decode(&kv.v); err != nil { if err == rlp.EOL { - return errIncompletePair + return dec, raw, errIncompletePair } - return err + return dec, raw, err } if i > 0 { if kv.k == prevkey { - return errDuplicateKey + return dec, raw, errDuplicateKey } if kv.k < prevkey { - return errNotSorted + return dec, raw, errNotSorted } } dec.pairs = append(dec.pairs, kv) prevkey = kv.k } - if err := s.ListEnd(); err != nil { - return err - } - - // Verify signature. - if err = dec.verifySignature(); err != nil { - return err - } - *r = dec - return nil + return dec, raw, s.ListEnd() } -type s256raw []byte - -func (s256raw) ENRKey() string { return "secp256k1" } +// IdentityScheme returns the name of the identity scheme in the record. +func (r *Record) IdentityScheme() string { + var id ID + r.Load(&id) + return string(id) +} -// NodeAddr returns the node address. The return value will be nil if the record is -// unsigned. -func (r *Record) NodeAddr() []byte { - var entry s256raw - if r.Load(&entry) != nil { - return nil - } - return crypto.Keccak256(entry) +// VerifySignature checks whether the record is signed using the given identity scheme. +func (r *Record) VerifySignature(s IdentityScheme) error { + return s.Verify(r, r.signature) } -// Sign signs the record with the given private key. It updates the record's identity -// scheme, public key and increments the sequence number. Sign returns an error if the -// encoded record is larger than the size limit. -func (r *Record) Sign(privkey *ecdsa.PrivateKey) error { - r.seq = r.seq + 1 - r.Set(ID_SECP256k1_KECCAK) - r.Set(Secp256k1(privkey.PublicKey)) - return r.signAndEncode(privkey) +// SetSig sets the record signature. It returns an error if the encoded record is larger +// than the size limit or if the signature is invalid according to the passed scheme. +// +// You can also use SetSig to remove the signature explicitly by passing a nil scheme +// and signature. +// +// SetSig panics when either the scheme or the signature (but not both) are nil. +func (r *Record) SetSig(s IdentityScheme, sig []byte) error { + switch { + // Prevent storing invalid data. + case s == nil && sig != nil: + panic("enr: invalid call to SetSig with non-nil signature but nil scheme") + case s != nil && sig == nil: + panic("enr: invalid call to SetSig with nil signature but non-nil scheme") + // Verify if we have a scheme. + case s != nil: + if err := s.Verify(r, sig); err != nil { + return err + } + raw, err := r.encode(sig) + if err != nil { + return err + } + r.signature, r.raw = sig, raw + // Reset otherwise. + default: + r.signature, r.raw = nil, nil + } + return nil } -func (r *Record) appendPairs(list []interface{}) []interface{} { +// AppendElements appends the sequence number and entries to the given slice. 
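+//
+// Identity schemes build the signed content this way; a hypothetical sketch:
+//
+//	list := r.AppendElements(nil)         // [seq, k1, v1, k2, v2, ...]
+//	content, _ := rlp.EncodeToBytes(list) // hash and sign this, then store
+//	                                      // the result via r.SetSig(scheme, sig)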
+func (r *Record) AppendElements(list []interface{}) []interface{} {
 	list = append(list, r.seq)
 	for _, p := range r.pairs {
 		list = append(list, p.k, p.v)
@@ -237,54 +315,15 @@ func (r *Record) appendPairs(list []interface{}) []interface{} {
 	return list
 }
 
-func (r *Record) signAndEncode(privkey *ecdsa.PrivateKey) error {
-	// Put record elements into a flat list. Leave room for the signature.
-	list := make([]interface{}, 1, len(r.pairs)*2+2)
-	list = r.appendPairs(list)
-
-	// Sign the tail of the list.
-	h := sha3.NewKeccak256()
-	rlp.Encode(h, list[1:])
-	sig, err := crypto.Sign(h.Sum(nil), privkey)
-	if err != nil {
-		return err
-	}
-	sig = sig[:len(sig)-1] // remove v
-
-	// Put signature in front.
-	r.signature, list[0] = sig, sig
-	r.raw, err = rlp.EncodeToBytes(list)
-	if err != nil {
-		return err
-	}
-	if len(r.raw) > SizeLimit {
-		return errTooBig
-	}
-	return nil
-}
-
-func (r *Record) verifySignature() error {
-	// Get identity scheme, public key, signature.
-	var id ID
-	var entry s256raw
-	if err := r.Load(&id); err != nil {
-		return err
-	} else if id != ID_SECP256k1_KECCAK {
-		return errNoID
+func (r *Record) encode(sig []byte) (raw []byte, err error) {
+	list := make([]interface{}, 1, 2*len(r.pairs)+2)
+	list[0] = sig
+	list = r.AppendElements(list)
+	if raw, err = rlp.EncodeToBytes(list); err != nil {
+		return nil, err
 	}
-	if err := r.Load(&entry); err != nil {
-		return err
-	} else if len(entry) != 33 {
-		return fmt.Errorf("invalid public key")
-	}
-
-	// Verify the signature.
-	list := make([]interface{}, 0, len(r.pairs)*2+1)
-	list = r.appendPairs(list)
-	h := sha3.NewKeccak256()
-	rlp.Encode(h, list)
-	if !crypto.VerifySignature(entry, h.Sum(nil), r.signature) {
-		return errInvalidSig
+	if len(raw) > SizeLimit {
+		return nil, errTooBig
 	}
-	return nil
+	return raw, nil
 }
diff --git a/p2p/enr/enr_test.go b/p2p/enr/enr_test.go
index bba1738bc2..8ea78fd9f4 100644
--- a/p2p/enr/enr_test.go
+++ b/p2p/enr/enr_test.go
@@ -18,7 +18,7 @@ package enr
 
 import (
 	"bytes"
-	"encoding/hex"
+	"encoding/binary"
 	"fmt"
 	"math/rand"
 	"testing"
@@ -26,13 +26,8 @@
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
-	"github.com/tomochain/tomochain/crypto"
-	"github.com/tomochain/tomochain/rlp"
-)
 
-var (
-	privkey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
-	pubkey     = &privkey.PublicKey
+	"github.com/tomochain/tomochain/rlp"
 )
 
 var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))
@@ -54,63 +49,51 @@ func TestGetSetID(t *testing.T) {
 	assert.Equal(t, id, id2)
 }
 
-// TestGetSetIP4 tests encoding/decoding and setting/getting of the IP4 key.
-func TestGetSetIP4(t *testing.T) {
-	ip := IP4{192, 168, 0, 3}
+// TestGetSetIPv4 tests encoding/decoding and setting/getting of the IP key.
+func TestGetSetIPv4(t *testing.T) {
+	ip := IPv4{192, 168, 0, 3}
 	var r Record
 	r.Set(ip)
 
-	var ip2 IP4
+	var ip2 IPv4
 	require.NoError(t, r.Load(&ip2))
 	assert.Equal(t, ip, ip2)
 }
 
-// TestGetSetIP6 tests encoding/decoding and setting/getting of the IP6 key.
-func TestGetSetIP6(t *testing.T) {
-	ip := IP6{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}
+// TestGetSetIPv6 tests encoding/decoding and setting/getting of the IP6 key.
+func TestGetSetIPv6(t *testing.T) {
+	ip := IPv6{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}
 	var r Record
 	r.Set(ip)
 
-	var ip2 IP6
+	var ip2 IPv6
 	require.NoError(t, r.Load(&ip2))
 	assert.Equal(t, ip, ip2)
 }
 
-// TestGetSetDiscPort tests encoding/decoding and setting/getting of the DiscPort key.
-func TestGetSetDiscPort(t *testing.T) {
-	port := DiscPort(30309)
+// TestGetSetUDP tests encoding/decoding and setting/getting of the UDP key.
+func TestGetSetUDP(t *testing.T) {
+	port := UDP(30309)
 	var r Record
 	r.Set(port)
 
-	var port2 DiscPort
+	var port2 UDP
 	require.NoError(t, r.Load(&port2))
 	assert.Equal(t, port, port2)
 }
 
-// TestGetSetSecp256k1 tests encoding/decoding and setting/getting of the Secp256k1 key.
-func TestGetSetSecp256k1(t *testing.T) {
-	var r Record
-	if err := r.Sign(privkey); err != nil {
-		t.Fatal(err)
-	}
-
-	var pk Secp256k1
-	require.NoError(t, r.Load(&pk))
-	assert.EqualValues(t, pubkey, &pk)
-}
-
 func TestLoadErrors(t *testing.T) {
 	var r Record
-	ip4 := IP4{127, 0, 0, 1}
+	ip4 := IPv4{127, 0, 0, 1}
 	r.Set(ip4)
 
 	// Check error for missing keys.
-	var ip6 IP6
-	err := r.Load(&ip6)
+	var udp UDP
+	err := r.Load(&udp)
 	if !IsNotFound(err) {
 		t.Error("IsNotFound should return true for missing key")
 	}
-	assert.Equal(t, &KeyError{Key: ip6.ENRKey(), Err: errNotFound}, err)
+	assert.Equal(t, &KeyError{Key: udp.ENRKey(), Err: errNotFound}, err)
 
 	// Check error for invalid keys.
 	var list []uint
@@ -167,40 +150,75 @@ func TestSortedGetAndSet(t *testing.T) {
 
 func TestDirty(t *testing.T) {
 	var r Record
 
-	if r.Signed() {
-		t.Error("Signed returned true for zero record")
-	}
 	if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
 		t.Errorf("expected errEncodeUnsigned, got %#v", err)
 	}
 
-	require.NoError(t, r.Sign(privkey))
-	if !r.Signed() {
-		t.Error("Signed return false for signed record")
+	require.NoError(t, signTest([]byte{5}, &r))
+	if len(r.signature) == 0 {
+		t.Error("record is not signed")
 	}
 	_, err := rlp.EncodeToBytes(r)
 	assert.NoError(t, err)
 
 	r.SetSeq(3)
-	if r.Signed() {
-		t.Error("Signed returned true for modified record")
+	if len(r.signature) != 0 {
+		t.Error("signature still set after modification")
 	}
 	if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
 		t.Errorf("expected errEncodeUnsigned, got %#v", err)
 	}
 }
 
+func TestSize(t *testing.T) {
+	var r Record
+
+	// Empty record size is 3 bytes.
+	// Unsigned records cannot be encoded, but if they could, the encoding
+	// would be [ 0, 0 ] -> 0xC28080.
+	assert.Equal(t, uint64(3), r.Size())
+
+	// Add one attribute. The size increases to 5, the encoding
+	// would be [ 0, 0, "k", "v" ] -> 0xC480806B76.
+	r.Set(WithEntry("k", "v"))
+	assert.Equal(t, uint64(5), r.Size())
+
+	// Now add a signature.
+	nodeid := []byte{1, 2, 3, 4, 5, 6, 7, 8}
+	signTest(nodeid, &r)
+	assert.Equal(t, uint64(45), r.Size())
+	enc, _ := rlp.EncodeToBytes(&r)
+	if r.Size() != uint64(len(enc)) {
+		t.Error("Size() not equal encoded length", len(enc))
+	}
+	if r.Size() != computeSize(&r) {
+		t.Error("Size() not equal computed size", computeSize(&r))
+	}
+}
+
+func TestSeq(t *testing.T) {
+	var r Record
+
+	assert.Equal(t, uint64(0), r.Seq())
+	r.Set(UDP(1))
+	assert.Equal(t, uint64(0), r.Seq())
+	signTest([]byte{5}, &r)
+	assert.Equal(t, uint64(0), r.Seq())
+	r.Set(UDP(2))
+	assert.Equal(t, uint64(1), r.Seq())
+}
+
 // TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record.
 func TestGetSetOverwrite(t *testing.T) {
 	var r Record
 
-	ip := IP4{192, 168, 0, 3}
+	ip := IPv4{192, 168, 0, 3}
 	r.Set(ip)
 
-	ip2 := IP4{192, 168, 0, 4}
+	ip2 := IPv4{192, 168, 0, 4}
 	r.Set(ip2)
 
-	var ip3 IP4
+	var ip3 IPv4
 	require.NoError(t, r.Load(&ip3))
 	assert.Equal(t, ip2, ip3)
 }
@@ -208,9 +226,9 @@
 // TestSignEncodeAndDecode tests signing, RLP encoding and RLP decoding of a record.
func TestSignEncodeAndDecode(t *testing.T) { var r Record - r.Set(DiscPort(30303)) - r.Set(IP4{127, 0, 0, 1}) - require.NoError(t, r.Sign(privkey)) + r.Set(UDP(30303)) + r.Set(IPv4{127, 0, 0, 1}) + require.NoError(t, signTest([]byte{5}, &r)) blob, err := rlp.EncodeToBytes(r) require.NoError(t, err) @@ -224,62 +242,43 @@ func TestSignEncodeAndDecode(t *testing.T) { assert.Equal(t, blob, blob2) } -func TestNodeAddr(t *testing.T) { - var r Record - if addr := r.NodeAddr(); addr != nil { - t.Errorf("wrong address on empty record: got %v, want %v", addr, nil) - } - - require.NoError(t, r.Sign(privkey)) - expected := "caaa1485d83b18b32ed9ad666026151bf0cae8a0a88c857ae2d4c5be2daa6726" - assert.Equal(t, expected, hex.EncodeToString(r.NodeAddr())) -} - -var pyRecord, _ = hex.DecodeString("f896b840954dc36583c1f4b69ab59b1375f362f06ee99f3723cd77e64b6de6d211c27d7870642a79d4516997f94091325d2a7ca6215376971455fb221d34f35b277149a1018664697363763582765f82696490736563703235366b312d6b656363616b83697034847f00000189736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138") - -// TestPythonInterop checks that we can decode and verify a record produced by the Python -// implementation. -func TestPythonInterop(t *testing.T) { - var r Record - if err := rlp.DecodeBytes(pyRecord, &r); err != nil { - t.Fatalf("can't decode: %v", err) - } - - var ( - wantAddr, _ = hex.DecodeString("caaa1485d83b18b32ed9ad666026151bf0cae8a0a88c857ae2d4c5be2daa6726") - wantSeq = uint64(1) - wantIP = IP4{127, 0, 0, 1} - wantDiscport = DiscPort(30303) - ) - if r.Seq() != wantSeq { - t.Errorf("wrong seq: got %d, want %d", r.Seq(), wantSeq) - } - if addr := r.NodeAddr(); !bytes.Equal(addr, wantAddr) { - t.Errorf("wrong addr: got %x, want %x", addr, wantAddr) - } - want := map[Entry]interface{}{new(IP4): &wantIP, new(DiscPort): &wantDiscport} - for k, v := range want { - desc := fmt.Sprintf("loading key %q", k.ENRKey()) - if assert.NoError(t, r.Load(k), desc) { - assert.Equal(t, k, v, desc) - } - } -} - // TestRecordTooBig tests that records bigger than SizeLimit bytes cannot be signed. func TestRecordTooBig(t *testing.T) { var r Record key := randomString(10) // set a big value for random key, expect error - r.Set(WithEntry(key, randomString(300))) - if err := r.Sign(privkey); err != errTooBig { + r.Set(WithEntry(key, randomString(SizeLimit))) + if err := signTest([]byte{5}, &r); err != errTooBig { t.Fatalf("expected to get errTooBig, got %#v", err) } // set an acceptable value for random key, expect no error r.Set(WithEntry(key, randomString(100))) - require.NoError(t, r.Sign(privkey)) + require.NoError(t, signTest([]byte{5}, &r)) +} + +// This checks that incomplete RLP inputs are handled correctly. +func TestDecodeIncomplete(t *testing.T) { + type decTest struct { + input []byte + err error + } + tests := []decTest{ + {[]byte{0xC0}, errIncompleteList}, + {[]byte{0xC1, 0x1}, errIncompleteList}, + {[]byte{0xC2, 0x1, 0x2}, nil}, + {[]byte{0xC3, 0x1, 0x2, 0x3}, errIncompletePair}, + {[]byte{0xC4, 0x1, 0x2, 0x3, 0x4}, nil}, + {[]byte{0xC5, 0x1, 0x2, 0x3, 0x4, 0x5}, errIncompletePair}, + } + for _, test := range tests { + var r Record + err := rlp.DecodeBytes(test.input, &r) + if err != test.err { + t.Errorf("wrong error for %X: %v", test.input, err) + } + } } // TestSignEncodeAndDecodeRandom tests encoding/decoding of records containing random key/value pairs. 
@@ -295,9 +294,12 @@ func TestSignEncodeAndDecodeRandom(t *testing.T) { r.Set(WithEntry(key, &value)) } - require.NoError(t, r.Sign(privkey)) - _, err := rlp.EncodeToBytes(r) + require.NoError(t, signTest([]byte{5}, &r)) + + enc, err := rlp.EncodeToBytes(r) require.NoError(t, err) + require.Equal(t, uint64(len(enc)), r.Size()) + require.Equal(t, uint64(len(enc)), computeSize(&r)) for k, v := range pairs { desc := fmt.Sprintf("key %q", k) @@ -308,11 +310,40 @@ func TestSignEncodeAndDecodeRandom(t *testing.T) { } } -func BenchmarkDecode(b *testing.B) { - var r Record - for i := 0; i < b.N; i++ { - rlp.DecodeBytes(pyRecord, &r) +type testSig struct{} + +type testID []byte + +func (id testID) ENRKey() string { return "testid" } + +func signTest(id []byte, r *Record) error { + r.Set(ID("test")) + r.Set(testID(id)) + return r.SetSig(testSig{}, makeTestSig(id, r.Seq())) +} + +func makeTestSig(id []byte, seq uint64) []byte { + sig := make([]byte, 8, len(id)+8) + binary.BigEndian.PutUint64(sig[:8], seq) + sig = append(sig, id...) + return sig +} + +func (testSig) Verify(r *Record, sig []byte) error { + var id []byte + if err := r.Load((*testID)(&id)); err != nil { + return err + } + if !bytes.Equal(sig, makeTestSig(id, r.Seq())) { + return ErrInvalidSig + } + return nil +} + +func (testSig) NodeAddr(r *Record) []byte { + var id []byte + if err := r.Load((*testID)(&id)); err != nil { + return nil } - b.StopTimer() - r.NodeAddr() + return id } diff --git a/p2p/enr/entries.go b/p2p/enr/entries.go index e31a4901a3..f68e7725d2 100644 --- a/p2p/enr/entries.go +++ b/p2p/enr/entries.go @@ -62,27 +62,83 @@ type DiscPort uint16 func (v DiscPort) ENRKey() string { return "discv5" } +// TCP is the "tcp" key, which holds the TCP port of the node. +type TCP uint16 + +func (v TCP) ENRKey() string { return "tcp" } + +// TCP6 is the "tcp6" key, which holds the IPv6-specific tcp6 port of the node. +type TCP6 uint16 + +func (v TCP6) ENRKey() string { return "tcp6" } + +// UDP is the "udp" key, which holds the UDP port of the node. +type UDP uint16 + +func (v UDP) ENRKey() string { return "udp" } + +// UDP6 is the "udp6" key, which holds the IPv6-specific UDP port of the node. +type UDP6 uint16 + +func (v UDP6) ENRKey() string { return "udp6" } + // ID is the "id" key, which holds the name of the identity scheme. type ID string +const IDv4 = ID("v4") // the default identity scheme + func (v ID) ENRKey() string { return "id" } -// IP4 is the "ip4" key, which holds a 4-byte IPv4 address. -type IP4 net.IP +// IP is either the "ip" or "ip6" key, depending on the value. +// Use this value to encode IP addresses that can be either v4 or v6. +// To load an address from a record use the IPv4 or IPv6 types. +type IP net.IP -func (v IP4) ENRKey() string { return "ip4" } +func (v IP) ENRKey() string { + if net.IP(v).To4() == nil { + return "ip6" + } + return "ip" +} // EncodeRLP implements rlp.Encoder. -func (v IP4) EncodeRLP(w io.Writer) error { +func (v IP) EncodeRLP(w io.Writer) error { + if ip4 := net.IP(v).To4(); ip4 != nil { + return rlp.Encode(w, ip4) + } + if ip6 := net.IP(v).To16(); ip6 != nil { + return rlp.Encode(w, ip6) + } + return fmt.Errorf("invalid IP address: %v", net.IP(v)) +} + +// DecodeRLP implements rlp.Decoder. 
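+// It accepts both address widths and rejects anything that is not 4 or 16
+// bytes long; sketch (enc is a hypothetical RLP value from an "ip"/"ip6" key):
+//
+//	var ip IP
+//	err := rlp.DecodeBytes(enc, &ip)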
+func (v *IP) DecodeRLP(s *rlp.Stream) error { + if err := s.Decode((*net.IP)(v)); err != nil { + return err + } + if len(*v) != 4 && len(*v) != 16 { + return fmt.Errorf("invalid IP address, want 4 or 16 bytes: %v", *v) + } + return nil +} + +// IPv4 is the "ip" key, which holds the IP address of the node. +type IPv4 net.IP + +func (v IPv4) ENRKey() string { return "ip" } + +// EncodeRLP implements rlp.Encoder. +func (v IPv4) EncodeRLP(w io.Writer) error { ip4 := net.IP(v).To4() if ip4 == nil { - return fmt.Errorf("invalid IPv4 address: %v", v) + return fmt.Errorf("invalid IPv4 address: %v", net.IP(v)) } return rlp.Encode(w, ip4) } // DecodeRLP implements rlp.Decoder. -func (v *IP4) DecodeRLP(s *rlp.Stream) error { +func (v *IPv4) DecodeRLP(s *rlp.Stream) error { if err := s.Decode((*net.IP)(v)); err != nil { return err } @@ -92,19 +148,22 @@ func (v *IP4) DecodeRLP(s *rlp.Stream) error { return nil } -// IP6 is the "ip6" key, which holds a 16-byte IPv6 address. -type IP6 net.IP +// IPv6 is the "ip6" key, which holds the IP address of the node. +type IPv6 net.IP -func (v IP6) ENRKey() string { return "ip6" } +func (v IPv6) ENRKey() string { return "ip6" } // EncodeRLP implements rlp.Encoder. -func (v IP6) EncodeRLP(w io.Writer) error { - ip6 := net.IP(v) +func (v IPv6) EncodeRLP(w io.Writer) error { + ip6 := net.IP(v).To16() + if ip6 == nil { + return fmt.Errorf("invalid IPv6 address: %v", net.IP(v)) + } return rlp.Encode(w, ip6) } // DecodeRLP implements rlp.Decoder. -func (v *IP6) DecodeRLP(s *rlp.Stream) error { +func (v *IPv6) DecodeRLP(s *rlp.Stream) error { if err := s.Decode((*net.IP)(v)); err != nil { return err } From 7c5781ce62710503be4c6e9b585fe29ba830dcd6 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Fri, 27 Oct 2023 16:53:44 +0700 Subject: [PATCH 095/119] Implement new P2P node representation --- crypto/crypto.go | 9 + p2p/enode/idscheme.go | 161 ++++++++++++ p2p/enode/idscheme_test.go | 74 ++++++ p2p/enode/iter.go | 295 +++++++++++++++++++++ p2p/enode/iter_test.go | 291 +++++++++++++++++++++ p2p/enode/localnode.go | 332 ++++++++++++++++++++++++ p2p/enode/localnode_test.go | 129 ++++++++++ p2p/enode/node.go | 279 ++++++++++++++++++++ p2p/enode/node_test.go | 145 +++++++++++ p2p/enode/nodedb.go | 501 ++++++++++++++++++++++++++++++++++++ p2p/enode/nodedb_test.go | 469 +++++++++++++++++++++++++++++++++ p2p/enode/urlv4.go | 203 +++++++++++++++ p2p/enode/urlv4_test.go | 200 ++++++++++++++ 13 files changed, 3088 insertions(+) create mode 100644 p2p/enode/idscheme.go create mode 100644 p2p/enode/idscheme_test.go create mode 100644 p2p/enode/iter.go create mode 100644 p2p/enode/iter_test.go create mode 100644 p2p/enode/localnode.go create mode 100644 p2p/enode/localnode_test.go create mode 100644 p2p/enode/node.go create mode 100644 p2p/enode/node_test.go create mode 100644 p2p/enode/nodedb.go create mode 100644 p2p/enode/nodedb_test.go create mode 100644 p2p/enode/urlv4.go create mode 100644 p2p/enode/urlv4_test.go diff --git a/crypto/crypto.go b/crypto/crypto.go index 6affee64ce..9154a5e9a9 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -163,6 +163,15 @@ func FromECDSA(priv *ecdsa.PrivateKey) []byte { return math.PaddedBigBytes(priv.D, priv.Params().BitSize/8) } +// UnmarshalPubkey converts bytes to a secp256k1 public key. 
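+//
+// The input must be a 65-byte uncompressed point (0x04 || X || Y), as
+// produced by elliptic.Marshal; hypothetical round trip:
+//
+//	raw := elliptic.Marshal(S256(), key.X, key.Y) // key is an ecdsa.PublicKey
+//	pub, err := UnmarshalPubkey(raw)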
+func UnmarshalPubkey(pub []byte) (*ecdsa.PublicKey, error) { + x, y := elliptic.Unmarshal(S256(), pub) + if x == nil { + return nil, errInvalidPubkey + } + return &ecdsa.PublicKey{Curve: S256(), X: x, Y: y}, nil +} + func ToECDSAPub(pub []byte) *ecdsa.PublicKey { if len(pub) == 0 { return nil diff --git a/p2p/enode/idscheme.go b/p2p/enode/idscheme.go new file mode 100644 index 0000000000..87981db5c2 --- /dev/null +++ b/p2p/enode/idscheme.go @@ -0,0 +1,161 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package enode + +import ( + "crypto/ecdsa" + "fmt" + "io" + + "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/p2p/enr" + "github.com/tomochain/tomochain/rlp" + "golang.org/x/crypto/sha3" +) + +// ValidSchemes is a List of known secure identity schemes. +var ValidSchemes = enr.SchemeMap{ + "v4": V4ID{}, +} + +// ValidSchemesForTesting is a List of identity schemes for testing. +var ValidSchemesForTesting = enr.SchemeMap{ + "v4": V4ID{}, + "null": NullID{}, +} + +// V4ID is the "v4" identity scheme. +type V4ID struct{} + +// SignV4 signs a record using the v4 scheme. +func SignV4(r *enr.Record, privkey *ecdsa.PrivateKey) error { + // Copy r to avoid modifying it if signing fails. + cpy := *r + cpy.Set(enr.ID("v4")) + cpy.Set(Secp256k1(privkey.PublicKey)) + + h := sha3.NewLegacyKeccak256() + rlp.Encode(h, cpy.AppendElements(nil)) + sig, err := crypto.Sign(h.Sum(nil), privkey) + if err != nil { + return err + } + sig = sig[:len(sig)-1] // remove v + if err = cpy.SetSig(V4ID{}, sig); err == nil { + *r = cpy + } + return err +} + +func (V4ID) Verify(r *enr.Record, sig []byte) error { + var entry s256raw + if err := r.Load(&entry); err != nil { + return err + } else if len(entry) != 33 { + return fmt.Errorf("invalid public key") + } + + h := sha3.NewLegacyKeccak256() + rlp.Encode(h, r.AppendElements(nil)) + if !crypto.VerifySignature(entry, h.Sum(nil), sig) { + return enr.ErrInvalidSig + } + return nil +} + +func (V4ID) NodeAddr(r *enr.Record) []byte { + var pubkey Secp256k1 + err := r.Load(&pubkey) + if err != nil { + return nil + } + buf := make([]byte, 64) + math.ReadBits(pubkey.X, buf[:32]) + math.ReadBits(pubkey.Y, buf[32:]) + return crypto.Keccak256(buf) +} + +// Secp256k1 is the "secp256k1" key, which holds a public key. +type Secp256k1 ecdsa.PublicKey + +func (v Secp256k1) ENRKey() string { return "secp256k1" } + +// EncodeRLP implements rlp.Encoder. +func (v Secp256k1) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, crypto.CompressPubkey((*ecdsa.PublicKey)(&v))) +} + +// DecodeRLP implements rlp.Decoder. 
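+// The wire value is the 33-byte compressed key, so a record round trip looks
+// like this (sketch, key being a hypothetical *ecdsa.PrivateKey):
+//
+//	r.Set(Secp256k1(key.PublicKey))
+//	var pk Secp256k1
+//	err := r.Load(&pk) // pk now equals key.PublicKey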
+func (v *Secp256k1) DecodeRLP(s *rlp.Stream) error { + buf, err := s.Bytes() + if err != nil { + return err + } + pk, err := crypto.DecompressPubkey(buf) + if err != nil { + return err + } + *v = (Secp256k1)(*pk) + return nil +} + +// s256raw is an unparsed secp256k1 public key entry. +type s256raw []byte + +func (s256raw) ENRKey() string { return "secp256k1" } + +// v4CompatID is a weaker and insecure version of the "v4" scheme which only checks for the +// presence of a secp256k1 public key, but doesn't verify the signature. +type v4CompatID struct { + V4ID +} + +func (v4CompatID) Verify(r *enr.Record, sig []byte) error { + var pubkey Secp256k1 + return r.Load(&pubkey) +} + +func signV4Compat(r *enr.Record, pubkey *ecdsa.PublicKey) { + r.Set((*Secp256k1)(pubkey)) + if err := r.SetSig(v4CompatID{}, []byte{}); err != nil { + panic(err) + } +} + +// NullID is the "null" ENR identity scheme. This scheme stores the node +// ID in the record without any signature. +type NullID struct{} + +func (NullID) Verify(r *enr.Record, sig []byte) error { + return nil +} + +func (NullID) NodeAddr(r *enr.Record) []byte { + var id ID + r.Load(enr.WithEntry("nulladdr", &id)) + return id[:] +} + +func SignNull(r *enr.Record, id ID) *Node { + r.Set(enr.ID("null")) + r.Set(enr.WithEntry("nulladdr", id)) + if err := r.SetSig(NullID{}, []byte{}); err != nil { + panic(err) + } + return &Node{r: *r, id: id} +} diff --git a/p2p/enode/idscheme_test.go b/p2p/enode/idscheme_test.go new file mode 100644 index 0000000000..8d7440f471 --- /dev/null +++ b/p2p/enode/idscheme_test.go @@ -0,0 +1,74 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package enode + +import ( + "bytes" + "crypto/ecdsa" + "encoding/hex" + "math/big" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/p2p/enr" + "github.com/tomochain/tomochain/rlp" +) + +var ( + privkey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + pubkey = &privkey.PublicKey +) + +func TestEmptyNodeID(t *testing.T) { + var r enr.Record + if addr := ValidSchemes.NodeAddr(&r); addr != nil { + t.Errorf("wrong address on empty record: got %v, want %v", addr, nil) + } + + require.NoError(t, SignV4(&r, privkey)) + expected := "a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7" + assert.Equal(t, expected, hex.EncodeToString(ValidSchemes.NodeAddr(&r))) +} + +// Checks that failure to sign leaves the record unmodified. 
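+// SignV4 guarantees this by mutating a copy and only assigning it back to the
+// record once SetSig has accepted the signature.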
+func TestSignError(t *testing.T) { + invalidKey := &ecdsa.PrivateKey{D: new(big.Int), PublicKey: *pubkey} + + var r enr.Record + emptyEnc, _ := rlp.EncodeToBytes(&r) + if err := SignV4(&r, invalidKey); err == nil { + t.Fatal("expected error from SignV4") + } + newEnc, _ := rlp.EncodeToBytes(&r) + if !bytes.Equal(newEnc, emptyEnc) { + t.Fatal("record modified even though signing failed") + } +} + +// TestGetSetSecp256k1 tests encoding/decoding and setting/getting of the Secp256k1 key. +func TestGetSetSecp256k1(t *testing.T) { + var r enr.Record + if err := SignV4(&r, privkey); err != nil { + t.Fatal(err) + } + + var pk Secp256k1 + require.NoError(t, r.Load(&pk)) + assert.EqualValues(t, pubkey, &pk) +} diff --git a/p2p/enode/iter.go b/p2p/enode/iter.go new file mode 100644 index 0000000000..b8ab4a758a --- /dev/null +++ b/p2p/enode/iter.go @@ -0,0 +1,295 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package enode + +import ( + "sync" + "time" +) + +// Iterator represents a sequence of nodes. The Next method moves to the next node in the +// sequence. It returns false when the sequence has ended or the iterator is closed. Close +// may be called concurrently with Next and Node, and interrupts Next if it is blocked. +type Iterator interface { + Next() bool // moves to next node + Node() *Node // returns current node + Close() // ends the iterator +} + +// ReadNodes reads at most n nodes from the given iterator. The return value contains no +// duplicates and no nil values. To prevent looping indefinitely for small repeating node +// sequences, this function calls Next at most n times. +func ReadNodes(it Iterator, n int) []*Node { + seen := make(map[ID]*Node, n) + for i := 0; i < n && it.Next(); i++ { + // Remove duplicates, keeping the node with higher seq. + node := it.Node() + prevNode, ok := seen[node.ID()] + if ok && prevNode.Seq() > node.Seq() { + continue + } + seen[node.ID()] = node + } + result := make([]*Node, 0, len(seen)) + for _, node := range seen { + result = append(result, node) + } + return result +} + +// IterNodes makes an iterator which runs through the given nodes once. +func IterNodes(nodes []*Node) Iterator { + return &sliceIter{nodes: nodes, index: -1} +} + +// CycleNodes makes an iterator which cycles through the given nodes indefinitely. 
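+//
+// Combined with ReadNodes, a short cycle yields each distinct node once
+// (sketch, a and b being hypothetical nodes):
+//
+//	it := CycleNodes([]*Node{a, b})
+//	nodes := ReadNodes(it, 10) // [a b] in some order; Next was called 10 times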
+func CycleNodes(nodes []*Node) Iterator { + return &sliceIter{nodes: nodes, index: -1, cycle: true} +} + +type sliceIter struct { + mu sync.Mutex + nodes []*Node + index int + cycle bool +} + +func (it *sliceIter) Next() bool { + it.mu.Lock() + defer it.mu.Unlock() + + if len(it.nodes) == 0 { + return false + } + it.index++ + if it.index == len(it.nodes) { + if it.cycle { + it.index = 0 + } else { + it.nodes = nil + return false + } + } + return true +} + +func (it *sliceIter) Node() *Node { + it.mu.Lock() + defer it.mu.Unlock() + if len(it.nodes) == 0 { + return nil + } + return it.nodes[it.index] +} + +func (it *sliceIter) Close() { + it.mu.Lock() + defer it.mu.Unlock() + + it.nodes = nil +} + +// Filter wraps an iterator such that Next only returns nodes for which +// the 'check' function returns true. +func Filter(it Iterator, check func(*Node) bool) Iterator { + return &filterIter{it, check} +} + +type filterIter struct { + Iterator + check func(*Node) bool +} + +func (f *filterIter) Next() bool { + for f.Iterator.Next() { + if f.check(f.Node()) { + return true + } + } + return false +} + +// FairMix aggregates multiple node iterators. The mixer itself is an iterator which ends +// only when Close is called. Source iterators added via AddSource are removed from the +// mix when they end. +// +// The distribution of nodes returned by Next is approximately fair, i.e. FairMix +// attempts to draw from all sources equally often. However, if a certain source is slow +// and doesn't return a node within the configured timeout, a node from any other source +// will be returned. +// +// It's safe to call AddSource and Close concurrently with Next. +type FairMix struct { + wg sync.WaitGroup + fromAny chan *Node + timeout time.Duration + cur *Node + + mu sync.Mutex + closed chan struct{} + sources []*mixSource + last int +} + +type mixSource struct { + it Iterator + next chan *Node + timeout time.Duration +} + +// NewFairMix creates a mixer. +// +// The timeout specifies how long the mixer will wait for the next fairly-chosen source +// before giving up and taking a node from any other source. A good way to set the timeout +// is deciding how long you'd want to wait for a node on average. Passing a negative +// timeout makes the mixer completely fair. +func NewFairMix(timeout time.Duration) *FairMix { + m := &FairMix{ + fromAny: make(chan *Node), + closed: make(chan struct{}), + timeout: timeout, + } + return m +} + +// AddSource adds a source of nodes. +func (m *FairMix) AddSource(it Iterator) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed == nil { + return + } + m.wg.Add(1) + source := &mixSource{it, make(chan *Node), m.timeout} + m.sources = append(m.sources, source) + go m.runSource(m.closed, source) +} + +// Close shuts down the mixer and all current sources. +// Calling this is required to release resources associated with the mixer. +func (m *FairMix) Close() { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed == nil { + return + } + for _, s := range m.sources { + s.it.Close() + } + close(m.closed) + m.wg.Wait() + close(m.fromAny) + m.sources = nil + m.closed = nil +} + +// Next returns a node from a random source. 
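+//
+// Typical use of the mixer (sketch; src1 and src2 are hypothetical Iterators):
+//
+//	mix := NewFairMix(time.Second)
+//	mix.AddSource(src1)
+//	mix.AddSource(src2)
+//	defer mix.Close()
+//	for mix.Next() {
+//		handle(mix.Node()) // handle is a hypothetical consumer
+//	}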
+func (m *FairMix) Next() bool { + m.cur = nil + + for { + source := m.pickSource() + if source == nil { + return m.nextFromAny() + } + + var timeout <-chan time.Time + if source.timeout >= 0 { + timer := time.NewTimer(source.timeout) + timeout = timer.C + defer timer.Stop() + } + + select { + case n, ok := <-source.next: + if ok { + // Here, the timeout is reset to the configured value + // because the source delivered a node. + source.timeout = m.timeout + m.cur = n + return true + } + // This source has ended. + m.deleteSource(source) + case <-timeout: + // The selected source did not deliver a node within the timeout, so the + // timeout duration is halved for next time. This is supposed to improve + // latency with stuck sources. + source.timeout /= 2 + return m.nextFromAny() + } + } +} + +// Node returns the current node. +func (m *FairMix) Node() *Node { + return m.cur +} + +// nextFromAny is used when there are no sources or when the 'fair' choice +// doesn't turn up a node quickly enough. +func (m *FairMix) nextFromAny() bool { + n, ok := <-m.fromAny + if ok { + m.cur = n + } + return ok +} + +// pickSource chooses the next source to read from, cycling through them in order. +func (m *FairMix) pickSource() *mixSource { + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.sources) == 0 { + return nil + } + m.last = (m.last + 1) % len(m.sources) + return m.sources[m.last] +} + +// deleteSource deletes a source. +func (m *FairMix) deleteSource(s *mixSource) { + m.mu.Lock() + defer m.mu.Unlock() + + for i := range m.sources { + if m.sources[i] == s { + copy(m.sources[i:], m.sources[i+1:]) + m.sources[len(m.sources)-1] = nil + m.sources = m.sources[:len(m.sources)-1] + break + } + } +} + +// runSource reads a single source in a loop. +func (m *FairMix) runSource(closed chan struct{}, s *mixSource) { + defer m.wg.Done() + defer close(s.next) + for s.it.Next() { + n := s.it.Node() + select { + case s.next <- n: + case m.fromAny <- n: + case <-closed: + return + } + } +} diff --git a/p2p/enode/iter_test.go b/p2p/enode/iter_test.go new file mode 100644 index 0000000000..ae980345aa --- /dev/null +++ b/p2p/enode/iter_test.go @@ -0,0 +1,291 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package enode + +import ( + "encoding/binary" + "runtime" + "sync/atomic" + "testing" + "time" + + "github.com/tomochain/tomochain/p2p/enr" +) + +func TestReadNodes(t *testing.T) { + nodes := ReadNodes(new(genIter), 10) + checkNodes(t, nodes, 10) +} + +// This test checks that ReadNodes terminates when reading N nodes from an iterator +// which returns less than N nodes in an endless cycle. 
+func TestReadNodesCycle(t *testing.T) {
+	iter := &callCountIter{
+		Iterator: CycleNodes([]*Node{
+			testNode(0, 0),
+			testNode(1, 0),
+			testNode(2, 0),
+		}),
+	}
+	nodes := ReadNodes(iter, 10)
+	checkNodes(t, nodes, 3)
+	if iter.count != 10 {
+		t.Fatalf("%d calls to Next, want %d", iter.count, 10)
+	}
+}
+
+func TestFilterNodes(t *testing.T) {
+	nodes := make([]*Node, 100)
+	for i := range nodes {
+		nodes[i] = testNode(uint64(i), uint64(i))
+	}
+
+	it := Filter(IterNodes(nodes), func(n *Node) bool {
+		return n.Seq() >= 50
+	})
+	for i := 50; i < len(nodes); i++ {
+		if !it.Next() {
+			t.Fatal("Next returned false")
+		}
+		if it.Node() != nodes[i] {
+			t.Fatalf("iterator returned wrong node %v\nwant %v", it.Node(), nodes[i])
+		}
+	}
+	if it.Next() {
+		t.Fatal("Next returned true after underlying iterator has ended")
+	}
+}
+
+func checkNodes(t *testing.T, nodes []*Node, wantLen int) {
+	if len(nodes) != wantLen {
+		t.Errorf("slice has %d nodes, want %d", len(nodes), wantLen)
+		return
+	}
+	seen := make(map[ID]bool, len(nodes))
+	for i, e := range nodes {
+		if e == nil {
+			t.Errorf("nil node at index %d", i)
+			return
+		}
+		if seen[e.ID()] {
+			t.Errorf("slice has duplicate node %v", e.ID())
+			return
+		}
+		seen[e.ID()] = true
+	}
+}
+
+// This test checks fairness of FairMix in the happy case where all sources return nodes
+// within the context's deadline.
+func TestFairMix(t *testing.T) {
+	for i := 0; i < 500; i++ {
+		testMixerFairness(t)
+	}
+}
+
+func testMixerFairness(t *testing.T) {
+	mix := NewFairMix(1 * time.Second)
+	mix.AddSource(&genIter{index: 1})
+	mix.AddSource(&genIter{index: 2})
+	mix.AddSource(&genIter{index: 3})
+	defer mix.Close()
+
+	nodes := ReadNodes(mix, 500)
+	checkNodes(t, nodes, 500)
+
+	// Verify that the nodes slice contains an approximately equal number of nodes
+	// from each source.
+	d := idPrefixDistribution(nodes)
+	for _, count := range d {
+		if approxEqual(count, len(nodes)/3, 30) {
+			t.Fatalf("ID distribution is unfair: %v", d)
+		}
+	}
+}
+
+// This test checks that FairMix falls back to an alternative source when
+// the 'fair' choice doesn't return a node within the timeout.
+func TestFairMixNextFromAll(t *testing.T) {
+	mix := NewFairMix(1 * time.Millisecond)
+	mix.AddSource(&genIter{index: 1})
+	mix.AddSource(CycleNodes(nil))
+	defer mix.Close()
+
+	nodes := ReadNodes(mix, 500)
+	checkNodes(t, nodes, 500)
+
+	d := idPrefixDistribution(nodes)
+	if len(d) > 1 || d[1] != len(nodes) {
+		t.Fatalf("wrong ID distribution: %v", d)
+	}
+}
+
+// This test ensures FairMix works for Next with no sources.
+func TestFairMixEmpty(t *testing.T) {
+	var (
+		mix   = NewFairMix(1 * time.Second)
+		testN = testNode(1, 1)
+		ch    = make(chan *Node)
+	)
+	defer mix.Close()
+
+	go func() {
+		mix.Next()
+		ch <- mix.Node()
+	}()
+
+	mix.AddSource(CycleNodes([]*Node{testN}))
+	if n := <-ch; n != testN {
+		t.Errorf("got wrong node: %v", n)
+	}
+}
+
+// This test checks closing a source while Next runs.
+func TestFairMixRemoveSource(t *testing.T) { + mix := NewFairMix(1 * time.Second) + source := make(blockingIter) + mix.AddSource(source) + + sig := make(chan *Node) + go func() { + <-sig + mix.Next() + sig <- mix.Node() + }() + + sig <- nil + runtime.Gosched() + source.Close() + + wantNode := testNode(0, 0) + mix.AddSource(CycleNodes([]*Node{wantNode})) + n := <-sig + + if len(mix.sources) != 1 { + t.Fatalf("have %d sources, want one", len(mix.sources)) + } + if n != wantNode { + t.Fatalf("mixer returned wrong node") + } +} + +type blockingIter chan struct{} + +func (it blockingIter) Next() bool { + <-it + return false +} + +func (it blockingIter) Node() *Node { + return nil +} + +func (it blockingIter) Close() { + close(it) +} + +func TestFairMixClose(t *testing.T) { + for i := 0; i < 20 && !t.Failed(); i++ { + testMixerClose(t) + } +} + +func testMixerClose(t *testing.T) { + mix := NewFairMix(-1) + mix.AddSource(CycleNodes(nil)) + mix.AddSource(CycleNodes(nil)) + + done := make(chan struct{}) + go func() { + defer close(done) + if mix.Next() { + t.Error("Next returned true") + } + }() + // This call is supposed to make it more likely that NextNode is + // actually executing by the time we call Close. + runtime.Gosched() + + mix.Close() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("Next didn't unblock on Close") + } + + mix.Close() // shouldn't crash +} + +func idPrefixDistribution(nodes []*Node) map[uint32]int { + d := make(map[uint32]int, len(nodes)) + for _, node := range nodes { + id := node.ID() + d[binary.BigEndian.Uint32(id[:4])]++ + } + return d +} + +func approxEqual(x, y, ε int) bool { + if y > x { + x, y = y, x + } + return x-y > ε +} + +// genIter creates fake nodes with numbered IDs based on 'index' and 'gen' +type genIter struct { + node *Node + index, gen uint32 +} + +func (s *genIter) Next() bool { + index := atomic.LoadUint32(&s.index) + if index == ^uint32(0) { + s.node = nil + return false + } + s.node = testNode(uint64(index)<<32|uint64(s.gen), 0) + s.gen++ + return true +} + +func (s *genIter) Node() *Node { + return s.node +} + +func (s *genIter) Close() { + atomic.StoreUint32(&s.index, ^uint32(0)) +} + +func testNode(id, seq uint64) *Node { + var nodeID ID + binary.BigEndian.PutUint64(nodeID[:], id) + r := new(enr.Record) + r.SetSeq(seq) + return SignNull(r, nodeID) +} + +// callCountIter counts calls to NextNode. +type callCountIter struct { + Iterator + count int +} + +func (it *callCountIter) Next() bool { + it.count++ + return it.Iterator.Next() +} diff --git a/p2p/enode/localnode.go b/p2p/enode/localnode.go new file mode 100644 index 0000000000..06f274992f --- /dev/null +++ b/p2p/enode/localnode.go @@ -0,0 +1,332 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package enode + +import ( + "crypto/ecdsa" + "fmt" + "net" + "reflect" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/p2p/enr" + "github.com/tomochain/tomochain/p2p/netutil" +) + +const ( + // IP tracker configuration + iptrackMinStatements = 10 + iptrackWindow = 5 * time.Minute + iptrackContactWindow = 10 * time.Minute + + // time needed to wait between two updates to the local ENR + recordUpdateThrottle = time.Millisecond +) + +// LocalNode produces the signed node record of a local node, i.e. a node run in the +// current process. Setting ENR entries via the Set method updates the record. A new version +// of the record is signed on demand when the Node method is called. +type LocalNode struct { + cur atomic.Value // holds a non-nil node pointer while the record is up-to-date + + id ID + key *ecdsa.PrivateKey + db *DB + + // everything below is protected by a lock + mu sync.RWMutex + seq uint64 + update time.Time // timestamp when the record was last updated + entries map[string]enr.Entry + endpoint4 lnEndpoint + endpoint6 lnEndpoint +} + +type lnEndpoint struct { + track *netutil.IPTracker + staticIP, fallbackIP net.IP + fallbackUDP uint16 // port +} + +// NewLocalNode creates a local node. +func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode { + ln := &LocalNode{ + id: PubkeyToIDV4(&key.PublicKey), + db: db, + key: key, + entries: make(map[string]enr.Entry), + endpoint4: lnEndpoint{ + track: netutil.NewIPTracker(iptrackWindow, iptrackContactWindow, iptrackMinStatements), + }, + endpoint6: lnEndpoint{ + track: netutil.NewIPTracker(iptrackWindow, iptrackContactWindow, iptrackMinStatements), + }, + } + ln.seq = db.localSeq(ln.id) + ln.update = time.Now() + ln.cur.Store((*Node)(nil)) + return ln +} + +// Database returns the node database associated with the local node. +func (ln *LocalNode) Database() *DB { + return ln.db +} + +// Node returns the current version of the local node record. +func (ln *LocalNode) Node() *Node { + // If we have a valid record, return that + n := ln.cur.Load().(*Node) + if n != nil { + return n + } + + // Record was invalidated, sign a new copy. + ln.mu.Lock() + defer ln.mu.Unlock() + + // Double check the current record, since multiple goroutines might be waiting + // on the write mutex. + if n = ln.cur.Load().(*Node); n != nil { + return n + } + + // The initial sequence number is the current timestamp in milliseconds. To ensure + // that the initial sequence number will always be higher than any previous sequence + // number (assuming the clock is correct), we want to avoid updating the record faster + // than once per ms. So we need to sleep here until the next possible update time has + // arrived. + lastChange := time.Since(ln.update) + if lastChange < recordUpdateThrottle { + time.Sleep(recordUpdateThrottle - lastChange) + } + + ln.sign() + ln.update = time.Now() + return ln.cur.Load().(*Node) +} + +// Seq returns the current sequence number of the local node record. +func (ln *LocalNode) Seq() uint64 { + ln.mu.Lock() + defer ln.mu.Unlock() + + return ln.seq +} + +// ID returns the local node ID. +func (ln *LocalNode) ID() ID { + return ln.id +} + +// Set puts the given entry into the local record, overwriting any existing value. +// Use Set*IP and SetFallbackUDP to set IP addresses and UDP port, otherwise they'll +// be overwritten by the endpoint predictor. +// +// Since node record updates are throttled to one per second, Set is asynchronous. 
+// Any update will be queued up and published when at least one second passes from +// the last change. +func (ln *LocalNode) Set(e enr.Entry) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.set(e) +} + +func (ln *LocalNode) set(e enr.Entry) { + val, exists := ln.entries[e.ENRKey()] + if !exists || !reflect.DeepEqual(val, e) { + ln.entries[e.ENRKey()] = e + ln.invalidate() + } +} + +// Delete removes the given entry from the local record. +func (ln *LocalNode) Delete(e enr.Entry) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.delete(e) +} + +func (ln *LocalNode) delete(e enr.Entry) { + _, exists := ln.entries[e.ENRKey()] + if exists { + delete(ln.entries, e.ENRKey()) + ln.invalidate() + } +} + +func (ln *LocalNode) endpointForIP(ip net.IP) *lnEndpoint { + if ip.To4() != nil { + return &ln.endpoint4 + } + return &ln.endpoint6 +} + +// SetStaticIP sets the local IP to the given one unconditionally. +// This disables endpoint prediction. +func (ln *LocalNode) SetStaticIP(ip net.IP) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.endpointForIP(ip).staticIP = ip + ln.updateEndpoints() +} + +// SetFallbackIP sets the last-resort IP address. This address is used +// if no endpoint prediction can be made and no static IP is set. +func (ln *LocalNode) SetFallbackIP(ip net.IP) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.endpointForIP(ip).fallbackIP = ip + ln.updateEndpoints() +} + +// SetFallbackUDP sets the last-resort UDP-on-IPv4 port. This port is used +// if no endpoint prediction can be made. +func (ln *LocalNode) SetFallbackUDP(port int) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.endpoint4.fallbackUDP = uint16(port) + ln.endpoint6.fallbackUDP = uint16(port) + ln.updateEndpoints() +} + +// UDPEndpointStatement should be called whenever a statement about the local node's +// UDP endpoint is received. It feeds the local endpoint predictor. +func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint *net.UDPAddr) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.endpointForIP(endpoint.IP).track.AddStatement(fromaddr.String(), endpoint.String()) + ln.updateEndpoints() +} + +// UDPContact should be called whenever the local node has announced itself to another node +// via UDP. It feeds the local endpoint predictor. +func (ln *LocalNode) UDPContact(toaddr *net.UDPAddr) { + ln.mu.Lock() + defer ln.mu.Unlock() + + ln.endpointForIP(toaddr.IP).track.AddContact(toaddr.String()) + ln.updateEndpoints() +} + +// updateEndpoints updates the record with predicted endpoints. +func (ln *LocalNode) updateEndpoints() { + ip4, udp4 := ln.endpoint4.get() + ip6, udp6 := ln.endpoint6.get() + + if ip4 != nil && !ip4.IsUnspecified() { + ln.set(enr.IPv4(ip4)) + } else { + ln.delete(enr.IPv4{}) + } + if ip6 != nil && !ip6.IsUnspecified() { + ln.set(enr.IPv6(ip6)) + } else { + ln.delete(enr.IPv6{}) + } + if udp4 != 0 { + ln.set(enr.UDP(udp4)) + } else { + ln.delete(enr.UDP(0)) + } + if udp6 != 0 && udp6 != udp4 { + ln.set(enr.UDP6(udp6)) + } else { + ln.delete(enr.UDP6(0)) + } +} + +// get returns the endpoint with highest precedence. +func (e *lnEndpoint) get() (newIP net.IP, newPort uint16) { + newPort = e.fallbackUDP + if e.fallbackIP != nil { + newIP = e.fallbackIP + } + if e.staticIP != nil { + newIP = e.staticIP + } else if ip, port := predictAddr(e.track); ip != nil { + newIP = ip + newPort = port + } + return newIP, newPort +} + +// predictAddr wraps IPTracker.PredictEndpoint, converting from its string-based +// endpoint representation to IP and port types. 
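+// For example (sketch), a prediction of "10.0.0.1:30303" becomes
+// (net.IP{10, 0, 0, 1}, 30303); empty or malformed predictions yield (nil, 0).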
+func predictAddr(t *netutil.IPTracker) (net.IP, uint16) { + ep := t.PredictEndpoint() + if ep == "" { + return nil, 0 + } + ipString, portString, _ := net.SplitHostPort(ep) + ip := net.ParseIP(ipString) + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return nil, 0 + } + return ip, uint16(port) +} + +func (ln *LocalNode) invalidate() { + ln.cur.Store((*Node)(nil)) +} + +func (ln *LocalNode) sign() { + if n := ln.cur.Load().(*Node); n != nil { + return // no changes + } + + var r enr.Record + for _, e := range ln.entries { + r.Set(e) + } + ln.bumpSeq() + r.SetSeq(ln.seq) + if err := SignV4(&r, ln.key); err != nil { + panic(fmt.Errorf("enode: can't sign record: %v", err)) + } + n, err := New(ValidSchemes, &r) + if err != nil { + panic(fmt.Errorf("enode: can't verify local record: %v", err)) + } + ln.cur.Store(n) + log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IP(), "udp", n.UDP(), "tcp", n.TCP()) +} + +func (ln *LocalNode) bumpSeq() { + ln.seq++ + ln.db.storeLocalSeq(ln.id, ln.seq) +} + +// nowMilliseconds gives the current timestamp at millisecond precision. +func nowMilliseconds() uint64 { + ns := time.Now().UnixNano() + if ns < 0 { + return 0 + } + return uint64(ns / 1000 / 1000) +} diff --git a/p2p/enode/localnode_test.go b/p2p/enode/localnode_test.go new file mode 100644 index 0000000000..c7fd79ef99 --- /dev/null +++ b/p2p/enode/localnode_test.go @@ -0,0 +1,129 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package enode + +import ( + "crypto/rand" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/p2p/enr" +) + +func newLocalNodeForTesting() (*LocalNode, *DB) { + db, _ := OpenDB("") + key, _ := crypto.GenerateKey() + return NewLocalNode(db, key), db +} + +func TestLocalNode(t *testing.T) { + ln, db := newLocalNodeForTesting() + defer db.Close() + + if ln.Node().ID() != ln.ID() { + t.Fatal("inconsistent ID") + } + + ln.Set(enr.WithEntry("x", uint(3))) + var x uint + if err := ln.Node().Load(enr.WithEntry("x", &x)); err != nil { + t.Fatal("can't load entry 'x':", err) + } else if x != 3 { + t.Fatal("wrong value for entry 'x':", x) + } +} + +// This test checks that the sequence number is persisted between restarts. +func TestLocalNodeSeqPersist(t *testing.T) { + timestamp := nowMilliseconds() + + ln, db := newLocalNodeForTesting() + defer db.Close() + + initialSeq := ln.Node().Seq() + if initialSeq < timestamp { + t.Fatalf("wrong initial seq %d, want at least %d", initialSeq, timestamp) + } + + ln.Set(enr.WithEntry("x", uint(1))) + if s := ln.Node().Seq(); s != initialSeq+1 { + t.Fatalf("wrong seq %d after set, want %d", s, initialSeq+1) + } + + // Create a new instance, it should reload the sequence number. 
+ // The number increases just after that because a new record is + // created without the "x" entry. + ln2 := NewLocalNode(db, ln.key) + if s := ln2.Node().Seq(); s != initialSeq+2 { + t.Fatalf("wrong seq %d on new instance, want %d", s, initialSeq+2) + } + + finalSeq := ln2.Node().Seq() + + // Create a new instance with a different node key on the same database. + // This should reset the sequence number. + key, _ := crypto.GenerateKey() + ln3 := NewLocalNode(db, key) + if s := ln3.Node().Seq(); s < finalSeq { + t.Fatalf("wrong seq %d on instance with changed key, want >= %d", s, finalSeq) + } +} + +// This test checks behavior of the endpoint predictor. +func TestLocalNodeEndpoint(t *testing.T) { + var ( + fallback = &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 80} + predicted = &net.UDPAddr{IP: net.IP{127, 0, 1, 2}, Port: 81} + staticIP = net.IP{127, 0, 1, 2} + ) + ln, db := newLocalNodeForTesting() + defer db.Close() + + // Nothing is set initially. + assert.Equal(t, net.IP(nil), ln.Node().IP()) + assert.Equal(t, 0, ln.Node().UDP()) + initialSeq := ln.Node().Seq() + + // Set up fallback address. + ln.SetFallbackIP(fallback.IP) + ln.SetFallbackUDP(fallback.Port) + assert.Equal(t, fallback.IP, ln.Node().IP()) + assert.Equal(t, fallback.Port, ln.Node().UDP()) + assert.Equal(t, initialSeq+1, ln.Node().Seq()) + + // Add endpoint statements from random hosts. + for i := 0; i < iptrackMinStatements; i++ { + assert.Equal(t, fallback.IP, ln.Node().IP()) + assert.Equal(t, fallback.Port, ln.Node().UDP()) + assert.Equal(t, initialSeq+1, ln.Node().Seq()) + + from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90} + rand.Read(from.IP) + ln.UDPEndpointStatement(from, predicted) + } + assert.Equal(t, predicted.IP, ln.Node().IP()) + assert.Equal(t, predicted.Port, ln.Node().UDP()) + assert.Equal(t, initialSeq+2, ln.Node().Seq()) + + // Static IP overrides prediction. + ln.SetStaticIP(staticIP) + assert.Equal(t, staticIP, ln.Node().IP()) + assert.Equal(t, fallback.Port, ln.Node().UDP()) + assert.Equal(t, initialSeq+3, ln.Node().Seq()) +} diff --git a/p2p/enode/node.go b/p2p/enode/node.go new file mode 100644 index 0000000000..7606ad40f1 --- /dev/null +++ b/p2p/enode/node.go @@ -0,0 +1,279 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package enode + +import ( + "crypto/ecdsa" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "math/bits" + "net" + "strings" + + "github.com/tomochain/tomochain/p2p/enr" + "github.com/tomochain/tomochain/rlp" +) + +var errMissingPrefix = errors.New("missing 'enr:' prefix for base64-encoded record") + +// Node represents a host on the network. +type Node struct { + r enr.Record + id ID +} + +// New wraps a node record. The record must be valid according to the given +// identity scheme. 
+func New(validSchemes enr.IdentityScheme, r *enr.Record) (*Node, error) { + if err := r.VerifySignature(validSchemes); err != nil { + return nil, err + } + node := &Node{r: *r} + if n := copy(node.id[:], validSchemes.NodeAddr(&node.r)); n != len(ID{}) { + return nil, fmt.Errorf("invalid node ID length %d, need %d", n, len(ID{})) + } + return node, nil +} + +// MustParse parses a node record or enode:// URL. It panics if the input is invalid. +func MustParse(rawurl string) *Node { + n, err := Parse(ValidSchemes, rawurl) + if err != nil { + panic("invalid node: " + err.Error()) + } + return n +} + +// Parse decodes and verifies a base64-encoded node record. +func Parse(validSchemes enr.IdentityScheme, input string) (*Node, error) { + if strings.HasPrefix(input, "enode://") { + return ParseV4(input) + } + if !strings.HasPrefix(input, "enr:") { + return nil, errMissingPrefix + } + bin, err := base64.RawURLEncoding.DecodeString(input[4:]) + if err != nil { + return nil, err + } + var r enr.Record + if err := rlp.DecodeBytes(bin, &r); err != nil { + return nil, err + } + return New(validSchemes, &r) +} + +// ID returns the node identifier. +func (n *Node) ID() ID { + return n.id +} + +// Seq returns the sequence number of the underlying record. +func (n *Node) Seq() uint64 { + return n.r.Seq() +} + +// Incomplete returns true for nodes with no IP address. +func (n *Node) Incomplete() bool { + return n.IP() == nil +} + +// Load retrieves an entry from the underlying record. +func (n *Node) Load(k enr.Entry) error { + return n.r.Load(k) +} + +// IP returns the IP address of the node. This prefers IPv4 addresses. +func (n *Node) IP() net.IP { + var ( + ip4 enr.IPv4 + ip6 enr.IPv6 + ) + if n.Load(&ip4) == nil { + return net.IP(ip4) + } + if n.Load(&ip6) == nil { + return net.IP(ip6) + } + return nil +} + +// UDP returns the UDP port of the node. +func (n *Node) UDP() int { + var port enr.UDP + n.Load(&port) + return int(port) +} + +// TCP returns the TCP port of the node. +func (n *Node) TCP() int { + var port enr.TCP + n.Load(&port) + return int(port) +} + +// Pubkey returns the secp256k1 public key of the node, if present. +func (n *Node) Pubkey() *ecdsa.PublicKey { + var key ecdsa.PublicKey + if n.Load((*Secp256k1)(&key)) != nil { + return nil + } + return &key +} + +// Record returns the node's record. The return value is a copy and may +// be modified by the caller. +func (n *Node) Record() *enr.Record { + cpy := n.r + return &cpy +} + +// ValidateComplete checks whether n has a valid IP and UDP port. +// Deprecated: don't use this method. +func (n *Node) ValidateComplete() error { + if n.Incomplete() { + return errors.New("missing IP address") + } + if n.UDP() == 0 { + return errors.New("missing UDP port") + } + ip := n.IP() + if ip.IsMulticast() || ip.IsUnspecified() { + return errors.New("invalid IP (multicast/unspecified)") + } + // Validate the node key (on curve, etc.). + var key Secp256k1 + return n.Load(&key) +} + +// String returns the text representation of the record. +func (n *Node) String() string { + if isNewV4(n) { + return n.URLv4() // backwards-compatibility glue for NewV4 nodes + } + enc, _ := rlp.EncodeToBytes(&n.r) // always succeeds because record is valid + b64 := base64.RawURLEncoding.EncodeToString(enc) + return "enr:" + b64 +} + +// MarshalText implements encoding.TextMarshaler. +func (n *Node) MarshalText() ([]byte, error) { + return []byte(n.String()), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. 
+func (n *Node) UnmarshalText(text []byte) error {
+	dec, err := Parse(ValidSchemes, string(text))
+	if err == nil {
+		*n = *dec
+	}
+	return err
+}
+
+// ID is a unique identifier for each node.
+type ID [32]byte
+
+// Bytes returns a byte slice representation of the ID.
+func (n ID) Bytes() []byte {
+	return n[:]
+}
+
+// ID prints as a long hexadecimal number.
+func (n ID) String() string {
+	return fmt.Sprintf("%x", n[:])
+}
+
+// GoString returns the Go syntax representation of an ID, which is a call to HexID.
+func (n ID) GoString() string {
+	return fmt.Sprintf("enode.HexID(\"%x\")", n[:])
+}
+
+// TerminalString returns a shortened hex string for terminal logging.
+func (n ID) TerminalString() string {
+	return hex.EncodeToString(n[:8])
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+func (n ID) MarshalText() ([]byte, error) {
+	return []byte(hex.EncodeToString(n[:])), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+func (n *ID) UnmarshalText(text []byte) error {
+	id, err := ParseID(string(text))
+	if err != nil {
+		return err
+	}
+	*n = id
+	return nil
+}
+
+// HexID converts a hex string to an ID.
+// The string may be prefixed with 0x.
+// It panics if the string is not a valid ID.
+func HexID(in string) ID {
+	id, err := ParseID(in)
+	if err != nil {
+		panic(err)
+	}
+	return id
+}
+
+// ParseID parses a hex string (optionally prefixed with 0x) into an ID.
+func ParseID(in string) (ID, error) {
+	var id ID
+	b, err := hex.DecodeString(strings.TrimPrefix(in, "0x"))
+	if err != nil {
+		return id, err
+	} else if len(b) != len(id) {
+		return id, fmt.Errorf("wrong length, want %d hex chars", len(id)*2)
+	}
+	copy(id[:], b)
+	return id, nil
+}
+
+// DistCmp compares the distances a->target and b->target.
+// Returns -1 if a is closer to target, 1 if b is closer to target
+// and 0 if they are equal.
+func DistCmp(target, a, b ID) int {
+	for i := range target {
+		da := a[i] ^ target[i]
+		db := b[i] ^ target[i]
+		if da > db {
+			return 1
+		} else if da < db {
+			return -1
+		}
+	}
+	return 0
+}
+
+// LogDist returns the logarithmic distance between a and b, log2(a ^ b).
+func LogDist(a, b ID) int {
+	lz := 0
+	for i := range a {
+		x := a[i] ^ b[i]
+		if x == 0 {
+			lz += 8
+		} else {
+			lz += bits.LeadingZeros8(x)
+			break
+		}
+	}
+	return len(a)*8 - lz
+}
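
DistCmp and LogDist implement the Kademlia-style XOR metric used by the discovery table: DistCmp orders two IDs by their XOR distance to a target, and LogDist is the bit length of a ^ b, i.e. the index of the highest differing bit. A minimal standalone sketch of the relationship, using only the exported API above (the sample IDs are arbitrary):

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/p2p/enode"
)

func main() {
	// Two IDs that differ only in the last byte: a ends in 0x01, b in 0x02.
	a := enode.HexID("0000000000000000000000000000000000000000000000000000000000000001")
	b := enode.HexID("0000000000000000000000000000000000000000000000000000000000000002")
	target := enode.ID{} // the all-zero ID

	fmt.Println(enode.LogDist(a, b))         // 2: a^b = 0x03, which is 2 bits long
	fmt.Println(enode.DistCmp(target, a, b)) // -1: a is closer to the zero ID
}
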
diff --git a/p2p/enode/node_test.go b/p2p/enode/node_test.go
new file mode 100644
index 0000000000..a2ec526551
--- /dev/null
+++ b/p2p/enode/node_test.go
@@ -0,0 +1,145 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package enode
+
+import (
+	"bytes"
+	"encoding/hex"
+	"fmt"
+	"math/big"
+	"testing"
+	"testing/quick"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/tomochain/tomochain/p2p/enr"
+	"github.com/tomochain/tomochain/rlp"
+)
+
+var pyRecord, _ = hex.DecodeString("f884b8407098ad865b00a582051940cb9cf36836572411a47278783077011599ed5cd16b76f2635f4e234738f30813a89eb9137e3e3df5266e3a1f11df72ecf1145ccb9c01826964827634826970847f00000189736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31388375647082765f")
+
+// TestPythonInterop checks that we can decode and verify a record produced by the Python
+// implementation.
+func TestPythonInterop(t *testing.T) {
+	var r enr.Record
+	if err := rlp.DecodeBytes(pyRecord, &r); err != nil {
+		t.Fatalf("can't decode: %v", err)
+	}
+	n, err := New(ValidSchemes, &r)
+	if err != nil {
+		t.Fatalf("can't verify record: %v", err)
+	}
+
+	var (
+		wantID  = HexID("a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7")
+		wantSeq = uint64(1)
+		wantIP  = enr.IPv4{127, 0, 0, 1}
+		wantUDP = enr.UDP(30303)
+	)
+	if n.Seq() != wantSeq {
+		t.Errorf("wrong seq: got %d, want %d", n.Seq(), wantSeq)
+	}
+	if n.ID() != wantID {
+		t.Errorf("wrong id: got %x, want %x", n.ID(), wantID)
+	}
+	want := map[enr.Entry]interface{}{new(enr.IPv4): &wantIP, new(enr.UDP): &wantUDP}
+	for k, v := range want {
+		desc := fmt.Sprintf("loading key %q", k.ENRKey())
+		if assert.NoError(t, n.Load(k), desc) {
+			assert.Equal(t, k, v, desc)
+		}
+	}
+}
+
+func TestHexID(t *testing.T) {
+	ref := ID{0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
+	id1 := HexID("0x00000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+	id2 := HexID("00000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+
+	if id1 != ref {
+		t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
+	}
+	if id2 != ref {
+		t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
+	}
+}
+
+func TestID_textEncoding(t *testing.T) {
+	ref := ID{
+		0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x10,
+		0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x20,
+		0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x30,
+		0x31, 0x32,
+	}
+	hex := "0102030405060708091011121314151617181920212223242526272829303132"
+
+	text, err := ref.MarshalText()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(text, []byte(hex)) {
+		t.Fatalf("text encoding did not match\nexpected: %s\ngot: %s", hex, text)
+	}
+
+	id := new(ID)
+	if err := id.UnmarshalText(text); err != nil {
+		t.Fatal(err)
+	}
+	if *id != ref {
+		t.Fatalf("text decoding did not match\nexpected: %s\ngot: %s", ref, id)
+	}
+}
+
+func TestID_distcmp(t *testing.T) {
+	distcmpBig := func(target, a, b ID) int {
+		tbig := new(big.Int).SetBytes(target[:])
+		abig := new(big.Int).SetBytes(a[:])
+		bbig := new(big.Int).SetBytes(b[:])
+		return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
+	}
+	if err := quick.CheckEqual(DistCmp, distcmpBig, nil); err != nil {
+		t.Error(err)
+	}
+}
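
The quick.CheckEqual calls in this file validate DistCmp and LogDist against big.Int reference implementations over random inputs. Other algebraic properties can be pinned down the same way; for instance, a hypothetical companion test (not part of this patch) checking that LogDist is symmetric:

// Hypothetical extra property test, sketched for illustration only:
// LogDist(a, b) must equal LogDist(b, a) for all IDs.
func TestID_logdistSymmetric(t *testing.T) {
	symmetric := func(a, b ID) bool {
		return LogDist(a, b) == LogDist(b, a)
	}
	if err := quick.Check(symmetric, nil); err != nil {
		t.Error(err)
	}
}
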
+// The random test above is likely to miss the case where a and b are equal,
+// so this test checks it explicitly.
+func TestID_distcmpEqual(t *testing.T) {
+	base := ID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+	x := ID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
+	if DistCmp(base, x, x) != 0 {
+		t.Errorf("DistCmp(base, x, x) != 0")
+	}
+}
+
+func TestID_logdist(t *testing.T) {
+	logdistBig := func(a, b ID) int {
+		abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
+		return new(big.Int).Xor(abig, bbig).BitLen()
+	}
+	if err := quick.CheckEqual(LogDist, logdistBig, nil); err != nil {
+		t.Error(err)
+	}
+}
+
+// The random test above is likely to miss the case where a and b are equal,
+// so this test checks it explicitly.
+func TestID_logdistEqual(t *testing.T) {
+	x := ID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+	if LogDist(x, x) != 0 {
+		t.Errorf("LogDist(x, x) != 0")
+	}
+}
diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go
new file mode 100644
index 0000000000..466de5ce17
--- /dev/null
+++ b/p2p/enode/nodedb.go
@@ -0,0 +1,501 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package enode
+
+import (
+	"bytes"
+	"crypto/rand"
+	"encoding/binary"
+	"fmt"
+	"net"
+	"os"
+	"sync"
+	"time"
+
+	"github.com/syndtr/goleveldb/leveldb"
+	"github.com/syndtr/goleveldb/leveldb/errors"
+	"github.com/syndtr/goleveldb/leveldb/iterator"
+	"github.com/syndtr/goleveldb/leveldb/opt"
+	"github.com/syndtr/goleveldb/leveldb/storage"
+	"github.com/syndtr/goleveldb/leveldb/util"
+	"github.com/tomochain/tomochain/rlp"
+)
+
+// Keys in the node database.
+const (
+	dbVersionKey   = "version" // Version of the database, flushed if it changes
+	dbNodePrefix   = "n:"      // Identifier to prefix node entries with
+	dbLocalPrefix  = "local:"
+	dbDiscoverRoot = "v4"
+	dbDiscv5Root   = "v5"
+
+	// These fields are stored per ID and IP, the full key is "n:<ID>:v4:<IP>:findfail".
+	// Use nodeItemKey to create those keys.
+	dbNodeFindFails = "findfail"
+	dbNodePing      = "lastping"
+	dbNodePong      = "lastpong"
+	dbNodeSeq       = "seq"
+
+	// Local information is keyed by ID only, the full key is "local:<ID>:seq".
+	// Use localItemKey to create those keys.
+	dbLocalSeq = "seq"
+)
+
+const (
+	dbNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
+	dbCleanupCycle   = time.Hour      // Time period for running the expiration task.
+	dbVersion        = 9
+)
+
+var (
+	errInvalidIP = errors.New("invalid IP")
+)
+
+var zeroIP = make(net.IP, 16)
+
+// DB is the node database, storing previously seen nodes and any collected metadata about
+// them for QoS purposes.
+type DB struct {
+	lvl    *leveldb.DB   // Interface to the database itself
+	runner sync.Once     // Ensures we can start at most one expirer
+	quit   chan struct{} // Channel to signal the expiring thread to stop
+}
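
The constants above fix the leveldb key schema. For reference, a standalone sketch that assembles the per-(ID, IP) metadata key by hand, mirroring (not calling) the unexported nodeItemKey helper defined further down; the all-zero ID is just a placeholder:

package main

import (
	"bytes"
	"fmt"
	"net"
)

func main() {
	id := make([]byte, 32)                // sample all-zero node ID
	ip := net.ParseIP("127.0.0.1").To16() // IPs are stored in 16-byte form

	// nodeKey part: "n:" + <ID> + ":v4"
	key := append([]byte("n:"), id...)
	key = append(key, ':')
	key = append(key, "v4"...)
	// nodeItemKey part: <nodeKey> + ":" + <IP> + ":" + <field>
	key = bytes.Join([][]byte{key, ip, []byte("lastpong")}, []byte{':'})

	fmt.Printf("%q\n", key) // "n:<32 zero bytes>:v4:<16-byte IP>:lastpong"
}
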
+// OpenDB opens a node database for storing and retrieving information about known peers
+// in the network. If no path is given, an in-memory, temporary database is constructed.
+func OpenDB(path string) (*DB, error) {
+	if path == "" {
+		return newMemoryDB()
+	}
+	return newPersistentDB(path)
+}
+
+// newMemoryDB creates a new in-memory node database without a persistent backend.
+func newMemoryDB() (*DB, error) {
+	db, err := leveldb.Open(storage.NewMemStorage(), nil)
+	if err != nil {
+		return nil, err
+	}
+	return &DB{lvl: db, quit: make(chan struct{})}, nil
+}
+
+// newPersistentDB creates/opens a leveldb backed persistent node database,
+// also flushing its contents in case of a version mismatch.
+func newPersistentDB(path string) (*DB, error) {
+	opts := &opt.Options{OpenFilesCacheCapacity: 5}
+	db, err := leveldb.OpenFile(path, opts)
+	if _, iscorrupted := err.(*errors.ErrCorrupted); iscorrupted {
+		db, err = leveldb.RecoverFile(path, nil)
+	}
+	if err != nil {
+		return nil, err
+	}
+	// The nodes contained in the cache correspond to a certain protocol version.
+	// Flush all nodes if the version doesn't match.
+	currentVer := make([]byte, binary.MaxVarintLen64)
+	currentVer = currentVer[:binary.PutVarint(currentVer, int64(dbVersion))]
+
+	blob, err := db.Get([]byte(dbVersionKey), nil)
+	switch err {
+	case leveldb.ErrNotFound:
+		// Version not found (i.e. empty cache), insert it
+		if err := db.Put([]byte(dbVersionKey), currentVer, nil); err != nil {
+			db.Close()
+			return nil, err
+		}
+
+	case nil:
+		// Version present, flush if different
+		if !bytes.Equal(blob, currentVer) {
+			db.Close()
+			if err = os.RemoveAll(path); err != nil {
+				return nil, err
+			}
+			return newPersistentDB(path)
+		}
+	}
+	return &DB{lvl: db, quit: make(chan struct{})}, nil
+}
+
+// nodeKey returns the database key for a node record.
+func nodeKey(id ID) []byte {
+	key := append([]byte(dbNodePrefix), id[:]...)
+	key = append(key, ':')
+	key = append(key, dbDiscoverRoot...)
+	return key
+}
+
+// splitNodeKey returns the node ID of a key created by nodeKey.
+func splitNodeKey(key []byte) (id ID, rest []byte) {
+	if !bytes.HasPrefix(key, []byte(dbNodePrefix)) {
+		return ID{}, nil
+	}
+	item := key[len(dbNodePrefix):]
+	copy(id[:], item[:len(id)])
+	return id, item[len(id)+1:]
+}
+
+// nodeItemKey returns the database key for a node metadata field.
+func nodeItemKey(id ID, ip net.IP, field string) []byte {
+	ip16 := ip.To16()
+	if ip16 == nil {
+		panic(fmt.Errorf("invalid IP (length %d)", len(ip)))
+	}
+	return bytes.Join([][]byte{nodeKey(id), ip16, []byte(field)}, []byte{':'})
+}
+
+// splitNodeItemKey returns the components of a key created by nodeItemKey.
+func splitNodeItemKey(key []byte) (id ID, ip net.IP, field string) {
+	id, key = splitNodeKey(key)
+	// Skip discover root.
+	if string(key) == dbDiscoverRoot {
+		return id, nil, ""
+	}
+	key = key[len(dbDiscoverRoot)+1:]
+	// Split out the IP.
+	ip = key[:16]
+	if ip4 := ip.To4(); ip4 != nil {
+		ip = ip4
+	}
+	key = key[16+1:]
+	// Field is the remainder of key.
+	field = string(key)
+	return id, ip, field
+}
+
+// v5Key returns the database key for a discv5 node metadata field.
+func v5Key(id ID, ip net.IP, field string) []byte {
+	return bytes.Join([][]byte{
+		[]byte(dbNodePrefix),
+		id[:],
+		[]byte(dbDiscv5Root),
+		ip.To16(),
+		[]byte(field),
+	}, []byte{':'})
+}
+
+// localItemKey returns the key of a local node item.
+func localItemKey(id ID, field string) []byte {
+	key := append([]byte(dbLocalPrefix), id[:]...)
+	key = append(key, ':')
+	key = append(key, field...)
+	return key
+}
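
The integer fields stored under these keys are varint-encoded; the fetchInt64/storeInt64 helpers below wrap exactly this round trip. A self-contained sketch using only the standard library:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	blob := make([]byte, binary.MaxVarintLen64)
	blob = blob[:binary.PutVarint(blob, 314)] // encode, trimming to the bytes used
	val, read := binary.Varint(blob)          // decode
	fmt.Println(val, read)                    // 314 2
}
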
+// fetchInt64 retrieves an integer associated with a particular key.
+func (db *DB) fetchInt64(key []byte) int64 {
+	blob, err := db.lvl.Get(key, nil)
+	if err != nil {
+		return 0
+	}
+	val, read := binary.Varint(blob)
+	if read <= 0 {
+		return 0
+	}
+	return val
+}
+
+// storeInt64 stores an integer under the given key.
+func (db *DB) storeInt64(key []byte, n int64) error {
+	blob := make([]byte, binary.MaxVarintLen64)
+	blob = blob[:binary.PutVarint(blob, n)]
+	return db.lvl.Put(key, blob, nil)
+}
+
+// fetchUint64 retrieves an integer associated with a particular key.
+func (db *DB) fetchUint64(key []byte) uint64 {
+	blob, err := db.lvl.Get(key, nil)
+	if err != nil {
+		return 0
+	}
+	val, _ := binary.Uvarint(blob)
+	return val
+}
+
+// storeUint64 stores an integer under the given key.
+func (db *DB) storeUint64(key []byte, n uint64) error {
+	blob := make([]byte, binary.MaxVarintLen64)
+	blob = blob[:binary.PutUvarint(blob, n)]
+	return db.lvl.Put(key, blob, nil)
+}
+
+// Node retrieves a node with a given id from the database.
+func (db *DB) Node(id ID) *Node {
+	blob, err := db.lvl.Get(nodeKey(id), nil)
+	if err != nil {
+		return nil
+	}
+	return mustDecodeNode(id[:], blob)
+}
+
+// mustDecodeNode decodes a stored node record, panicking on database corruption.
+func mustDecodeNode(id, data []byte) *Node {
+	node := new(Node)
+	if err := rlp.DecodeBytes(data, &node.r); err != nil {
+		panic(fmt.Errorf("p2p/enode: can't decode node %x in DB: %v", id, err))
+	}
+	// Restore node id cache.
+	copy(node.id[:], id)
+	return node
+}
+
+// UpdateNode inserts - potentially overwriting - a node into the peer database.
+func (db *DB) UpdateNode(node *Node) error {
+	if node.Seq() < db.NodeSeq(node.ID()) {
+		return nil
+	}
+	blob, err := rlp.EncodeToBytes(&node.r)
+	if err != nil {
+		return err
+	}
+	if err := db.lvl.Put(nodeKey(node.ID()), blob, nil); err != nil {
+		return err
+	}
+	return db.storeUint64(nodeItemKey(node.ID(), zeroIP, dbNodeSeq), node.Seq())
+}
+
+// NodeSeq returns the stored record sequence number of the given node.
+func (db *DB) NodeSeq(id ID) uint64 {
+	return db.fetchUint64(nodeItemKey(id, zeroIP, dbNodeSeq))
+}
+
+// Resolve returns the stored record of the node if it has a larger sequence
+// number than n.
+func (db *DB) Resolve(n *Node) *Node {
+	if n.Seq() > db.NodeSeq(n.ID()) {
+		return n
+	}
+	return db.Node(n.ID())
+}
+
+// DeleteNode deletes all information associated with a node.
+func (db *DB) DeleteNode(id ID) {
+	deleteRange(db.lvl, nodeKey(id))
+}
+
+// deleteRange removes all database entries with the given key prefix.
+func deleteRange(db *leveldb.DB, prefix []byte) {
+	it := db.NewIterator(util.BytesPrefix(prefix), nil)
+	defer it.Release()
+	for it.Next() {
+		db.Delete(it.Key(), nil)
+	}
+}
+
+// ensureExpirer is a small helper method ensuring that the data expiration
+// mechanism is running. If the expiration goroutine is already running, this
+// method simply returns.
+//
+// The goal is to start the data evacuation only after the network successfully
+// bootstrapped itself (to prevent dumping potentially useful seed nodes). Since
+// it would require significant overhead to exactly trace the first successful
+// convergence, it's simpler to "ensure" the correct state when an appropriate
+// condition occurs (i.e. a successful bonding), and discard further events.
+func (db *DB) ensureExpirer() {
+	db.runner.Do(func() { go db.expirer() })
+}
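
ensureExpirer leans on sync.Once so that arbitrarily many callers (for instance, every successful bonding) can request expiration while at most one goroutine is ever launched. The pattern in isolation, as a minimal sketch:

package main

import (
	"fmt"
	"sync"
)

type worker struct{ once sync.Once }

// start launches the worker's background task exactly once, no matter
// how many times (or from how many goroutines) it is called.
func (w *worker) start() {
	w.once.Do(func() { fmt.Println("started") })
}

func main() {
	var w worker
	for i := 0; i < 3; i++ {
		w.start() // prints "started" only on the first call
	}
}
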
+// expirer should be started in a goroutine, and is responsible for looping ad
+// infinitum and dropping stale data from the database.
+func (db *DB) expirer() {
+	tick := time.NewTicker(dbCleanupCycle)
+	defer tick.Stop()
+	for {
+		select {
+		case <-tick.C:
+			db.expireNodes()
+		case <-db.quit:
+			return
+		}
+	}
+}
+
+// expireNodes iterates over the database and deletes all nodes that have not
+// been seen (i.e. received a pong from) for some time.
+func (db *DB) expireNodes() {
+	it := db.lvl.NewIterator(util.BytesPrefix([]byte(dbNodePrefix)), nil)
+	defer it.Release()
+	if !it.Next() {
+		return
+	}
+
+	var (
+		threshold    = time.Now().Add(-dbNodeExpiration).Unix()
+		youngestPong int64
+		atEnd        = false
+	)
+	for !atEnd {
+		id, ip, field := splitNodeItemKey(it.Key())
+		if field == dbNodePong {
+			time, _ := binary.Varint(it.Value())
+			if time > youngestPong {
+				youngestPong = time
+			}
+			if time < threshold {
+				// Last pong from this IP older than threshold, remove fields belonging to it.
+				deleteRange(db.lvl, nodeItemKey(id, ip, ""))
+			}
+		}
+		atEnd = !it.Next()
+		nextID, _ := splitNodeKey(it.Key())
+		if atEnd || nextID != id {
+			// We've moved beyond the last entry of the current ID.
+			// Remove everything if there was no recent enough pong.
+			if youngestPong > 0 && youngestPong < threshold {
+				deleteRange(db.lvl, nodeKey(id))
+			}
+			youngestPong = 0
+		}
+	}
+}
+
+// LastPingReceived retrieves the time of the last ping packet received from
+// a remote node.
+func (db *DB) LastPingReceived(id ID, ip net.IP) time.Time {
+	if ip = ip.To16(); ip == nil {
+		return time.Time{}
+	}
+	return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePing)), 0)
+}
+
+// UpdateLastPingReceived updates the time of the last ping packet received from
+// a remote node.
+func (db *DB) UpdateLastPingReceived(id ID, ip net.IP, instance time.Time) error {
+	if ip = ip.To16(); ip == nil {
+		return errInvalidIP
+	}
+	return db.storeInt64(nodeItemKey(id, ip, dbNodePing), instance.Unix())
+}
+
+// LastPongReceived retrieves the time of the last successful pong from a remote node.
+func (db *DB) LastPongReceived(id ID, ip net.IP) time.Time {
+	if ip = ip.To16(); ip == nil {
+		return time.Time{}
+	}
+	// Launch expirer
+	db.ensureExpirer()
+	return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePong)), 0)
+}
+
+// UpdateLastPongReceived updates the last pong time of a node.
+func (db *DB) UpdateLastPongReceived(id ID, ip net.IP, instance time.Time) error {
+	if ip = ip.To16(); ip == nil {
+		return errInvalidIP
+	}
+	return db.storeInt64(nodeItemKey(id, ip, dbNodePong), instance.Unix())
+}
+
+// FindFails retrieves the number of findnode failures since bonding.
+func (db *DB) FindFails(id ID, ip net.IP) int {
+	if ip = ip.To16(); ip == nil {
+		return 0
+	}
+	return int(db.fetchInt64(nodeItemKey(id, ip, dbNodeFindFails)))
+}
+
+// UpdateFindFails updates the number of findnode failures since bonding.
+func (db *DB) UpdateFindFails(id ID, ip net.IP, fails int) error {
+	if ip = ip.To16(); ip == nil {
+		return errInvalidIP
+	}
+	return db.storeInt64(nodeItemKey(id, ip, dbNodeFindFails), int64(fails))
+}
+
+// FindFailsV5 retrieves the discv5 findnode failure counter.
+func (db *DB) FindFailsV5(id ID, ip net.IP) int {
+	if ip = ip.To16(); ip == nil {
+		return 0
+	}
+	return int(db.fetchInt64(v5Key(id, ip, dbNodeFindFails)))
+}
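
Together, these accessors implement the per-(ID, IP) liveness bookkeeping that decides which peers count as recently seen. A hedged usage sketch against the exported API; the all-zero ID and the 24-hour freshness window are illustrative choices, not values mandated by this package (24h happens to mirror dbNodeExpiration above):

package main

import (
	"fmt"
	"net"
	"time"

	"github.com/tomochain/tomochain/p2p/enode"
)

func main() {
	db, _ := enode.OpenDB("") // "" selects the in-memory backend
	defer db.Close()

	var id enode.ID // all-zero ID, illustration only
	ip := net.IP{127, 0, 0, 1}

	db.UpdateLastPongReceived(id, ip, time.Now())
	alive := time.Since(db.LastPongReceived(id, ip)) < 24*time.Hour
	fmt.Println(alive) // true
}
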
+// UpdateFindFailsV5 stores the discv5 findnode failure counter.
+func (db *DB) UpdateFindFailsV5(id ID, ip net.IP, fails int) error {
+	if ip = ip.To16(); ip == nil {
+		return errInvalidIP
+	}
+	return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails))
+}
+
+// localSeq retrieves the local record sequence counter, defaulting to the current
+// timestamp if no previous counter exists. This ensures that wiping all data associated
+// with a node (apart from its key) will not generate already-used sequence numbers.
+func (db *DB) localSeq(id ID) uint64 {
+	if seq := db.fetchUint64(localItemKey(id, dbLocalSeq)); seq > 0 {
+		return seq
+	}
+	return nowMilliseconds()
+}
+
+// storeLocalSeq stores the local record sequence counter.
+func (db *DB) storeLocalSeq(id ID, n uint64) {
+	db.storeUint64(localItemKey(id, dbLocalSeq), n)
+}
+
+// QuerySeeds retrieves random nodes to be used as potential seed nodes
+// for bootstrapping.
+func (db *DB) QuerySeeds(n int, maxAge time.Duration) []*Node {
+	var (
+		now   = time.Now()
+		nodes = make([]*Node, 0, n)
+		it    = db.lvl.NewIterator(nil, nil)
+		id    ID
+	)
+	defer it.Release()
+
+seek:
+	for seeks := 0; len(nodes) < n && seeks < n*5; seeks++ {
+		// Seek to a random entry. The first byte is incremented by a
+		// random amount each time in order to increase the likelihood
+		// of hitting all existing nodes in very small databases.
+		ctr := id[0]
+		rand.Read(id[:])
+		id[0] = ctr + id[0]%16
+		it.Seek(nodeKey(id))
+
+		n := nextNode(it)
+		if n == nil {
+			id[0] = 0
+			continue seek // iterator exhausted
+		}
+		if now.Sub(db.LastPongReceived(n.ID(), n.IP())) > maxAge {
+			continue seek
+		}
+		for i := range nodes {
+			if nodes[i].ID() == n.ID() {
+				continue seek // duplicate
+			}
+		}
+		nodes = append(nodes, n)
+	}
+	return nodes
+}
+
+// nextNode reads the next node record from the iterator, skipping over other
+// database entries.
+func nextNode(it iterator.Iterator) *Node {
+	for end := false; !end; end = !it.Next() {
+		id, rest := splitNodeKey(it.Key())
+		if string(rest) != dbDiscoverRoot {
+			continue
+		}
+		return mustDecodeNode(id[:], it.Value())
+	}
+	return nil
+}
+
+// Close flushes and closes the database files.
+func (db *DB) Close() {
+	close(db.quit)
+	db.lvl.Close()
+}
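
QuerySeeds is the consumer of this bookkeeping: the discovery table calls it to repopulate its buckets after a restart. A usage sketch (the database path is hypothetical):

package main

import (
	"fmt"
	"time"

	"github.com/tomochain/tomochain/p2p/enode"
)

func main() {
	db, _ := enode.OpenDB("/tmp/nodes") // hypothetical path
	defer db.Close()

	// Ask for up to 10 seed candidates whose last pong is at most a day old.
	for _, n := range db.QuerySeeds(10, 24*time.Hour) {
		fmt.Println(n.ID().TerminalString(), n.IP(), n.UDP())
	}
}
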
diff --git a/p2p/enode/nodedb_test.go b/p2p/enode/nodedb_test.go
new file mode 100644
index 0000000000..38764f31b1
--- /dev/null
+++ b/p2p/enode/nodedb_test.go
@@ -0,0 +1,469 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package enode
+
+import (
+	"bytes"
+	"fmt"
+	"net"
+	"path/filepath"
+	"reflect"
+	"testing"
+	"time"
+)
+
+var keytestID = HexID("51232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439")
+
+func TestDBNodeKey(t *testing.T) {
+	enc := nodeKey(keytestID)
+	want := []byte{
+		'n', ':',
+		0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, // node id
+		0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, //
+		0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, //
+		0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, //
+		':', 'v', '4',
+	}
+	if !bytes.Equal(enc, want) {
+		t.Errorf("wrong encoded key:\ngot %q\nwant %q", enc, want)
+	}
+	id, _ := splitNodeKey(enc)
+	if id != keytestID {
+		t.Errorf("wrong ID from splitNodeKey")
+	}
+}
+
+func TestDBNodeItemKey(t *testing.T) {
+	wantIP := net.IP{127, 0, 0, 3}
+	wantField := "foobar"
+	enc := nodeItemKey(keytestID, wantIP, wantField)
+	want := []byte{
+		'n', ':',
+		0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, // node id
+		0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, //
+		0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, //
+		0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, //
+		':', 'v', '4', ':',
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // IP
+		0x00, 0x00, 0xff, 0xff, 0x7f, 0x00, 0x00, 0x03, //
+		':', 'f', 'o', 'o', 'b', 'a', 'r',
+	}
+	if !bytes.Equal(enc, want) {
+		t.Errorf("wrong encoded key:\ngot %q\nwant %q", enc, want)
+	}
+	id, ip, field := splitNodeItemKey(enc)
+	if id != keytestID {
+		t.Errorf("splitNodeItemKey returned wrong ID: %v", id)
+	}
+	if !ip.Equal(wantIP) {
+		t.Errorf("splitNodeItemKey returned wrong IP: %v", ip)
+	}
+	if field != wantField {
+		t.Errorf("splitNodeItemKey returned wrong field: %q", field)
+	}
+}
+
+var nodeDBInt64Tests = []struct {
+	key   []byte
+	value int64
+}{
+	{key: []byte{0x01}, value: 1},
+	{key: []byte{0x02}, value: 2},
+	{key: []byte{0x03}, value: 3},
+}
+
+func TestDBInt64(t *testing.T) {
+	db, _ := OpenDB("")
+	defer db.Close()
+
+	tests := nodeDBInt64Tests
+	for i := 0; i < len(tests); i++ {
+		// Insert the next value
+		if err := db.storeInt64(tests[i].key, tests[i].value); err != nil {
+			t.Errorf("test %d: failed to store value: %v", i, err)
+		}
+		// Check all existing and non-existing values
+		for j := 0; j < len(tests); j++ {
+			num := db.fetchInt64(tests[j].key)
+			switch {
+			case j <= i && num != tests[j].value:
+				t.Errorf("test %d, item %d: value mismatch: have %v, want %v", i, j, num, tests[j].value)
+			case j > i && num != 0:
+				t.Errorf("test %d, item %d: value mismatch: have %v, want %v", i, j, num, 0)
+			}
+		}
+	}
+}
+
+func TestDBFetchStore(t *testing.T) {
+	node := NewV4(
+		hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+		net.IP{192, 168, 0, 1},
+		30303,
+		30303,
+	)
+	inst := time.Now()
+	num := 314
+
+	db, _ := OpenDB("")
+	defer db.Close()
+
+	// Check fetch/store operations on a node ping object
+	if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != 0 {
+		t.Errorf("ping: non-existing object: %v", stored)
+	}
+	if err := db.UpdateLastPingReceived(node.ID(), node.IP(), inst); err != nil {
+		t.Errorf("ping: failed to update: %v", err)
+	}
+	if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() {
+		t.Errorf("ping: value mismatch: have %v, want %v", stored, inst)
+	}
+	// Check fetch/store operations on a node pong object
+	if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != 0 {
+		t.Errorf("pong: non-existing object: %v", stored)
+	}
+	if err := db.UpdateLastPongReceived(node.ID(), node.IP(), inst); err != nil {
+		t.Errorf("pong: failed to update: %v", err)
+	}
+	if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() {
+		t.Errorf("pong: value mismatch: have %v, want %v", stored, inst)
+	}
+	// Check fetch/store operations on a node findnode-failure object
+	if stored := db.FindFails(node.ID(), node.IP()); stored != 0 {
+		t.Errorf("find-node fails: non-existing object: %v", stored)
+	}
+	if err := db.UpdateFindFails(node.ID(), node.IP(), num); err != nil {
+		t.Errorf("find-node fails: failed to update: %v", err)
+	}
+	if stored := db.FindFails(node.ID(), node.IP()); stored != num {
+		t.Errorf("find-node fails: value mismatch: have %v, want %v", stored, num)
+	}
+	// Check fetch/store operations on an actual node object
+	if stored := db.Node(node.ID()); stored != nil {
+		t.Errorf("node: non-existing object: %v", stored)
+	}
+	if err := db.UpdateNode(node); err != nil {
+		t.Errorf("node: failed to update: %v", err)
+	}
+	if stored := db.Node(node.ID()); stored == nil {
+		t.Errorf("node: not found")
+	} else if !reflect.DeepEqual(stored, node) {
+		t.Errorf("node: data mismatch: have %v, want %v", stored, node)
+	}
+}
+
+var nodeDBSeedQueryNodes = []struct {
+	node *Node
+	pong time.Time
+}{
+	// This one should not be in the result set because its last
+	// pong time is too far in the past.
+	{
+		node: NewV4(
+			hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
+			net.IP{127, 0, 0, 3},
+			30303,
+			30303,
+		),
+		pong: time.Now().Add(-3 * time.Hour),
+	},
+	// This one shouldn't be in the result set because its
+	// nodeID is the local node's ID.
+	{
+		node: NewV4(
+			hexPubkey("ff93ff820abacd4351b0f14e47b324bc82ff014c226f3f66a53535734a3c150e7e38ca03ef0964ba55acddc768f5e99cd59dea95ddd4defbab1339c92fa319b2"),
+			net.IP{127, 0, 0, 3},
+			30303,
+			30303,
+		),
+		pong: time.Now().Add(-4 * time.Second),
+	},
+
+	// These should be in the result set.
+	{
+		node: NewV4(
+			hexPubkey("c2b5eb3f5dde05f815b63777809ee3e7e0cbb20035a6b00ce327191e6eaa8f26a8d461c9112b7ab94698e7361fa19fd647e603e73239002946d76085b6f928d6"),
+			net.IP{127, 0, 0, 1},
+			30303,
+			30303,
+		),
+		pong: time.Now().Add(-2 * time.Second),
+	},
+	{
+		node: NewV4(
+			hexPubkey("6ca1d400c8ddf8acc94bcb0dd254911ad71a57bed5e0ae5aa205beed59b28c2339908e97990c493499613cff8ecf6c3dc7112a8ead220cdcd00d8847ca3db755"),
+			net.IP{127, 0, 0, 2},
+			30303,
+			30303,
+		),
+		pong: time.Now().Add(-3 * time.Second),
+	},
+	{
+		node: NewV4(
+			hexPubkey("234dc63fe4d131212b38236c4c3411288d7bec61cbf7b120ff12c43dc60c96182882f4291d209db66f8a38e986c9c010ff59231a67f9515c7d1668b86b221a47"),
+			net.IP{127, 0, 0, 3},
+			30303,
+			30303,
+		),
+		pong: time.Now().Add(-1 * time.Second),
+	},
+	{
+		node: NewV4(
+			hexPubkey("c013a50b4d1ebce5c377d8af8cb7114fd933ffc9627f96ad56d90fef5b7253ec736fd07ef9a81dc2955a997e54b7bf50afd0aa9f110595e2bec5bb7ce1657004"),
+			net.IP{127, 0, 0, 3},
+			30303,
+			30303,
+		),
+		pong: time.Now().Add(-2 * time.Second),
+	},
+	{
+		node: NewV4(
+			hexPubkey("f141087e3e08af1aeec261ff75f48b5b1637f594ea9ad670e50051646b0416daa3b134c28788cbe98af26992a47652889cd8577ccc108ac02c6a664db2dc1283"),
+			net.IP{127, 0, 0, 3},
+			30303,
+			30303,
+		),
+		pong: time.Now().Add(-2 * time.Second),
+	},
+}
+
+func TestDBSeedQuery(t *testing.T) {
+	// Querying seeds uses seeks and might not find all nodes
+	// every time when the database is small. 
Run the test multiple + // times to avoid flakes. + const attempts = 15 + var err error + for i := 0; i < attempts; i++ { + if err = testSeedQuery(); err == nil { + return + } + } + if err != nil { + t.Errorf("no successful run in %d attempts: %v", attempts, err) + } +} + +func testSeedQuery() error { + db, _ := OpenDB("") + defer db.Close() + + // Insert a batch of nodes for querying + for i, seed := range nodeDBSeedQueryNodes { + if err := db.UpdateNode(seed.node); err != nil { + return fmt.Errorf("node %d: failed to insert: %v", i, err) + } + if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil { + return fmt.Errorf("node %d: failed to insert bondTime: %v", i, err) + } + } + + // Retrieve the entire batch and check for duplicates + seeds := db.QuerySeeds(len(nodeDBSeedQueryNodes)*2, time.Hour) + have := make(map[ID]struct{}, len(seeds)) + for _, seed := range seeds { + have[seed.ID()] = struct{}{} + } + want := make(map[ID]struct{}, len(nodeDBSeedQueryNodes[1:])) + for _, seed := range nodeDBSeedQueryNodes[1:] { + want[seed.node.ID()] = struct{}{} + } + if len(seeds) != len(want) { + return fmt.Errorf("seed count mismatch: have %v, want %v", len(seeds), len(want)) + } + for id := range have { + if _, ok := want[id]; !ok { + return fmt.Errorf("extra seed: %v", id) + } + } + for id := range want { + if _, ok := have[id]; !ok { + return fmt.Errorf("missing seed: %v", id) + } + } + return nil +} + +func TestDBPersistency(t *testing.T) { + root := t.TempDir() + + var ( + testKey = []byte("somekey") + testInt = int64(314) + ) + + // Create a persistent database and store some values + db, err := OpenDB(filepath.Join(root, "database")) + if err != nil { + t.Fatalf("failed to create persistent database: %v", err) + } + if err := db.storeInt64(testKey, testInt); err != nil { + t.Fatalf("failed to store value: %v.", err) + } + db.Close() + + // Reopen the database and check the value + db, err = OpenDB(filepath.Join(root, "database")) + if err != nil { + t.Fatalf("failed to open persistent database: %v", err) + } + if val := db.fetchInt64(testKey); val != testInt { + t.Fatalf("value mismatch: have %v, want %v", val, testInt) + } + db.Close() +} + +var nodeDBExpirationNodes = []struct { + node *Node + pong time.Time + storeNode bool + exp bool +}{ + // Node has new enough pong time and isn't expired: + { + node: NewV4( + hexPubkey("8d110e2ed4b446d9b5fb50f117e5f37fb7597af455e1dab0e6f045a6eeaa786a6781141659020d38bdc5e698ed3d4d2bafa8b5061810dfa63e8ac038db2e9b67"), + net.IP{127, 0, 0, 1}, + 30303, + 30303, + ), + storeNode: true, + pong: time.Now().Add(-dbNodeExpiration + time.Minute), + exp: false, + }, + // Node with pong time before expiration is removed: + { + node: NewV4( + hexPubkey("913a205579c32425b220dfba999d215066e5bdbf900226b11da1907eae5e93eb40616d47412cf819664e9eacbdfcca6b0c6e07e09847a38472d4be46ab0c3672"), + net.IP{127, 0, 0, 2}, + 30303, + 30303, + ), + storeNode: true, + pong: time.Now().Add(-dbNodeExpiration - time.Minute), + exp: true, + }, + // Just pong time, no node stored: + { + node: NewV4( + hexPubkey("b56670e0b6bad2c5dab9f9fe6f061a16cf78d68b6ae2cfda3144262d08d97ce5f46fd8799b6d1f709b1abe718f2863e224488bd7518e5e3b43809ac9bd1138ca"), + net.IP{127, 0, 0, 3}, + 30303, + 30303, + ), + storeNode: false, + pong: time.Now().Add(-dbNodeExpiration - time.Minute), + exp: true, + }, + // Node with multiple pong times, all older than expiration. 
+ { + node: NewV4( + hexPubkey("29f619cebfd32c9eab34aec797ed5e3fe15b9b45be95b4df3f5fe6a9ae892f433eb08d7698b2ef3621568b0fb70d57b515ab30d4e72583b798298e0f0a66b9d1"), + net.IP{127, 0, 0, 4}, + 30303, + 30303, + ), + storeNode: true, + pong: time.Now().Add(-dbNodeExpiration - time.Minute), + exp: true, + }, + { + node: NewV4( + hexPubkey("29f619cebfd32c9eab34aec797ed5e3fe15b9b45be95b4df3f5fe6a9ae892f433eb08d7698b2ef3621568b0fb70d57b515ab30d4e72583b798298e0f0a66b9d1"), + net.IP{127, 0, 0, 5}, + 30303, + 30303, + ), + storeNode: false, + pong: time.Now().Add(-dbNodeExpiration - 2*time.Minute), + exp: true, + }, + // Node with multiple pong times, one newer, one older than expiration. + { + node: NewV4( + hexPubkey("3b73a9e5f4af6c4701c57c73cc8cfa0f4802840b24c11eba92aac3aef65644a3728b4b2aec8199f6d72bd66be2c65861c773129039bd47daa091ca90a6d4c857"), + net.IP{127, 0, 0, 6}, + 30303, + 30303, + ), + storeNode: true, + pong: time.Now().Add(-dbNodeExpiration + time.Minute), + exp: false, + }, + { + node: NewV4( + hexPubkey("3b73a9e5f4af6c4701c57c73cc8cfa0f4802840b24c11eba92aac3aef65644a3728b4b2aec8199f6d72bd66be2c65861c773129039bd47daa091ca90a6d4c857"), + net.IP{127, 0, 0, 7}, + 30303, + 30303, + ), + storeNode: false, + pong: time.Now().Add(-dbNodeExpiration - time.Minute), + exp: true, + }, +} + +func TestDBExpiration(t *testing.T) { + db, _ := OpenDB("") + defer db.Close() + + // Add all the test nodes and set their last pong time. + for i, seed := range nodeDBExpirationNodes { + if seed.storeNode { + if err := db.UpdateNode(seed.node); err != nil { + t.Fatalf("node %d: failed to insert: %v", i, err) + } + } + if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil { + t.Fatalf("node %d: failed to update bondTime: %v", i, err) + } + } + + db.expireNodes() + + // Check that expired entries have been removed. + unixZeroTime := time.Unix(0, 0) + for i, seed := range nodeDBExpirationNodes { + node := db.Node(seed.node.ID()) + pong := db.LastPongReceived(seed.node.ID(), seed.node.IP()) + if seed.exp { + if seed.storeNode && node != nil { + t.Errorf("node %d (%s) shouldn't be present after expiration", i, seed.node.ID().TerminalString()) + } + if !pong.Equal(unixZeroTime) { + t.Errorf("pong time %d (%s %v) shouldn't be present after expiration", i, seed.node.ID().TerminalString(), seed.node.IP()) + } + } else { + if seed.storeNode && node == nil { + t.Errorf("node %d (%s) should be present after expiration", i, seed.node.ID().TerminalString()) + } + if !pong.Equal(seed.pong.Truncate(1 * time.Second)) { + t.Errorf("pong time %d (%s) should be %v after expiration, but is %v", i, seed.node.ID().TerminalString(), seed.pong, pong) + } + } + } +} + +// This test checks that expiration works when discovery v5 data is present +// in the database. +func TestDBExpireV5(t *testing.T) { + db, _ := OpenDB("") + defer db.Close() + + ip := net.IP{127, 0, 0, 1} + db.UpdateFindFailsV5(ID{}, ip, 4) + db.expireNodes() +} diff --git a/p2p/enode/urlv4.go b/p2p/enode/urlv4.go new file mode 100644 index 0000000000..204839dccb --- /dev/null +++ b/p2p/enode/urlv4.go @@ -0,0 +1,203 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package enode
+
+import (
+	"crypto/ecdsa"
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"net"
+	"net/url"
+	"regexp"
+	"strconv"
+
+	"github.com/tomochain/tomochain/common/math"
+	"github.com/tomochain/tomochain/crypto"
+	"github.com/tomochain/tomochain/p2p/enr"
+)
+
+var (
+	incompleteNodeURL = regexp.MustCompile("(?i)^(?:enode://)?([0-9a-f]+)$")
+	lookupIPFunc      = net.LookupIP
+)
+
+// MustParseV4 parses a node URL. It panics if the URL is not valid.
+func MustParseV4(rawurl string) *Node {
+	n, err := ParseV4(rawurl)
+	if err != nil {
+		panic("invalid node URL: " + err.Error())
+	}
+	return n
+}
+
+// ParseV4 parses a node URL.
+//
+// There are two basic forms of node URLs:
+//
+//   - incomplete nodes, which only have the public key (node ID)
+//   - complete nodes, which contain the public key and IP/Port information
+//
+// For incomplete nodes, the designator must look like one of these
+//
+//	enode://<hex node id>
+//	<hex node id>
+//
+// For complete nodes, the node ID is encoded in the username portion
+// of the URL, separated from the host by an @ sign. The hostname can
+// only be given as an IP address or using a DNS domain name.
+// The port in the host name section is the TCP listening port. If the
+// TCP and UDP (discovery) ports differ, the UDP port is specified as
+// query parameter "discport".
+//
+// In the following example, the node URL describes
+// a node with IP address 10.3.58.6, TCP listening port 30303
+// and UDP discovery port 30301.
+//
+//	enode://<hex node id>@10.3.58.6:30303?discport=30301
+func ParseV4(rawurl string) (*Node, error) {
+	if m := incompleteNodeURL.FindStringSubmatch(rawurl); m != nil {
+		id, err := parsePubkey(m[1])
+		if err != nil {
+			return nil, fmt.Errorf("invalid public key (%v)", err)
+		}
+		return NewV4(id, nil, 0, 0), nil
+	}
+	return parseComplete(rawurl)
+}
+
+// NewV4 creates a node from discovery v4 node information. The record
+// contained in the node has a zero-length signature.
+func NewV4(pubkey *ecdsa.PublicKey, ip net.IP, tcp, udp int) *Node {
+	var r enr.Record
+	if len(ip) > 0 {
+		r.Set(enr.IP(ip))
+	}
+	if udp != 0 {
+		r.Set(enr.UDP(udp))
+	}
+	if tcp != 0 {
+		r.Set(enr.TCP(tcp))
+	}
+	signV4Compat(&r, pubkey)
+	n, err := New(v4CompatID{}, &r)
+	if err != nil {
+		panic(err)
+	}
+	return n
+}
+
+// isNewV4 returns true for nodes created by NewV4.
+func isNewV4(n *Node) bool {
+	var k s256raw
+	return n.r.IdentityScheme() == "" && n.r.Load(&k) == nil && len(n.r.Signature()) == 0
+}
+
+func parseComplete(rawurl string) (*Node, error) {
+	var (
+		id               *ecdsa.PublicKey
+		tcpPort, udpPort uint64
+	)
+	u, err := url.Parse(rawurl)
+	if err != nil {
+		return nil, err
+	}
+	if u.Scheme != "enode" {
+		return nil, errors.New("invalid URL scheme, want \"enode\"")
+	}
+	// Parse the Node ID from the user portion.
+	if u.User == nil {
+		return nil, errors.New("does not contain node ID")
+	}
+	if id, err = parsePubkey(u.User.String()); err != nil {
+		return nil, fmt.Errorf("invalid public key (%v)", err)
+	}
+	// Parse the IP address.
+ ip := net.ParseIP(u.Hostname()) + if ip == nil { + ips, err := lookupIPFunc(u.Hostname()) + if err != nil { + return nil, err + } + ip = ips[0] + } + // Ensure the IP is 4 bytes long for IPv4 addresses. + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + // Parse the port numbers. + if tcpPort, err = strconv.ParseUint(u.Port(), 10, 16); err != nil { + return nil, errors.New("invalid port") + } + udpPort = tcpPort + qv := u.Query() + if qv.Get("discport") != "" { + udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16) + if err != nil { + return nil, errors.New("invalid discport in query") + } + } + return NewV4(id, ip, int(tcpPort), int(udpPort)), nil +} + +// parsePubkey parses a hex-encoded secp256k1 public key. +func parsePubkey(in string) (*ecdsa.PublicKey, error) { + b, err := hex.DecodeString(in) + if err != nil { + return nil, err + } else if len(b) != 64 { + return nil, fmt.Errorf("wrong length, want %d hex chars", 128) + } + b = append([]byte{0x4}, b...) + return crypto.UnmarshalPubkey(b) +} + +func (n *Node) URLv4() string { + var ( + scheme enr.ID + nodeid string + key ecdsa.PublicKey + ) + n.Load(&scheme) + n.Load((*Secp256k1)(&key)) + switch { + case scheme == "v4" || key != ecdsa.PublicKey{}: + nodeid = fmt.Sprintf("%x", crypto.FromECDSAPub(&key)[1:]) + default: + nodeid = fmt.Sprintf("%s.%x", scheme, n.id[:]) + } + u := url.URL{Scheme: "enode"} + if n.Incomplete() { + u.Host = nodeid + } else { + addr := net.TCPAddr{IP: n.IP(), Port: n.TCP()} + u.User = url.User(nodeid) + u.Host = addr.String() + if n.UDP() != n.TCP() { + u.RawQuery = "discport=" + strconv.Itoa(n.UDP()) + } + } + return u.String() +} + +// PubkeyToIDV4 derives the v4 node address from the given public key. +func PubkeyToIDV4(key *ecdsa.PublicKey) ID { + e := make([]byte, 64) + math.ReadBits(key.X, e[:len(e)/2]) + math.ReadBits(key.Y, e[len(e)/2:]) + return ID(crypto.Keccak256Hash(e)) +} diff --git a/p2p/enode/urlv4_test.go b/p2p/enode/urlv4_test.go new file mode 100644 index 0000000000..f56d28632b --- /dev/null +++ b/p2p/enode/urlv4_test.go @@ -0,0 +1,200 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package enode + +import ( + "crypto/ecdsa" + "errors" + "net" + "reflect" + "strings" + "testing" + + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/p2p/enr" +) + +func init() { + lookupIPFunc = func(name string) ([]net.IP, error) { + if name == "node.example.org" { + return []net.IP{{33, 44, 55, 66}}, nil + } + return nil, errors.New("no such host") + } +} + +var parseNodeTests = []struct { + input string + wantError string + wantResult *Node +}{ + // Records + { + input: "enr:-IS4QGrdq0ugARp5T2BZ41TrZOqLc_oKvZoPuZP5--anqWE_J-Tucc1xgkOL7qXl0puJgT7qc2KSvcupc4NCb0nr4tdjgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQM6UUF2Rm-oFe1IH_rQkRCi00T2ybeMHRSvw1HDpRvjPYN1ZHCCdl8", + wantResult: func() *Node { + testKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8") + var r enr.Record + r.Set(enr.IP{127, 0, 0, 1}) + r.Set(enr.UDP(30303)) + r.SetSeq(99) + SignV4(&r, testKey) + n, _ := New(ValidSchemes, &r) + return n + }(), + }, + // Invalid Records + { + input: "enr:", + wantError: "EOF", // could be nicer + }, + { + input: "enr:x", + wantError: "illegal base64 data at input byte 0", + }, + { + input: "enr:-EmGZm9vYmFyY4JpZIJ2NIJpcIR_AAABiXNlY3AyNTZrMaEDOlFBdkZvqBXtSB_60JEQotNE9sm3jB0Ur8NRw6Ub4z2DdWRwgnZf", + wantError: enr.ErrInvalidSig.Error(), + }, + // Complete node URLs with IP address and ports + { + input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@invalid.:3", + wantError: `no such host`, + }, + { + input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo", + wantError: `invalid port`, + }, + { + input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo", + wantError: `invalid discport in query`, + }, + { + input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150", + wantResult: NewV4( + hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + net.IP{127, 0, 0, 1}, + 52150, + 52150, + ), + }, + { + input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150", + wantResult: NewV4( + hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + net.ParseIP("::"), + 52150, + 52150, + ), + }, + { + input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[2001:db8:3c4d:15::abcd:ef12]:52150", + wantResult: NewV4( + hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + net.ParseIP("2001:db8:3c4d:15::abcd:ef12"), + 52150, + 52150, + ), + }, + { + input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=22334", + wantResult: NewV4( + hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + net.IP{0x7f, 
0x0, 0x0, 0x1}, + 52150, + 22334, + ), + }, + // Incomplete node URLs with no address + { + input: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439", + wantResult: NewV4( + hexPubkey("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), + nil, 0, 0, + ), + }, + // Invalid URLs + { + input: "", + wantError: errMissingPrefix.Error(), + }, + { + input: "1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439", + wantError: errMissingPrefix.Error(), + }, + { + input: "01010101", + wantError: errMissingPrefix.Error(), + }, + { + input: "enode://01010101@123.124.125.126:3", + wantError: `invalid public key (wrong length, want 128 hex chars)`, + }, + { + input: "enode://01010101", + wantError: `invalid public key (wrong length, want 128 hex chars)`, + }, + { + input: "http://foobar", + wantError: errMissingPrefix.Error(), + }, + { + input: "://foo", + wantError: errMissingPrefix.Error(), + }, +} + +func hexPubkey(h string) *ecdsa.PublicKey { + k, err := parsePubkey(h) + if err != nil { + panic(err) + } + return k +} + +func TestParseNode(t *testing.T) { + for _, test := range parseNodeTests { + n, err := Parse(ValidSchemes, test.input) + if test.wantError != "" { + if err == nil { + t.Errorf("test %q:\n got nil error, expected %#q", test.input, test.wantError) + continue + } else if !strings.Contains(err.Error(), test.wantError) { + t.Errorf("test %q:\n got error %#q, expected %#q", test.input, err.Error(), test.wantError) + continue + } + } else { + if err != nil { + t.Errorf("test %q:\n unexpected error: %v", test.input, err) + continue + } + if !reflect.DeepEqual(n, test.wantResult) { + t.Errorf("test %q:\n result mismatch:\ngot: %#v\nwant: %#v", test.input, n, test.wantResult) + } + } + } +} + +func TestNodeString(t *testing.T) { + for i, test := range parseNodeTests { + if test.wantError == "" && strings.HasPrefix(test.input, "enode://") { + str := test.wantResult.String() + if str != test.input { + t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.input) + } + } + } +} From 331c4d9c1fe10bfe68c904d60b553cfc0e9b8467 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 30 Oct 2023 16:38:52 +0700 Subject: [PATCH 096/119] Port p2p/discover to p2p/enode --- p2p/discover/database.go | 370 ------------------- p2p/discover/database_test.go | 380 ------------------- p2p/discover/node.go | 422 +++------------------ p2p/discover/table.go | 451 +++++++++-------------- p2p/discover/table_test.go | 631 ++++++++++++++------------------ p2p/discover/table_util_test.go | 167 +++++++++ p2p/discover/udp.go | 173 +++++---- p2p/discover/udp_test.go | 98 ++--- 8 files changed, 813 insertions(+), 1879 deletions(-) delete mode 100644 p2p/discover/database.go delete mode 100644 p2p/discover/database_test.go create mode 100644 p2p/discover/table_util_test.go diff --git a/p2p/discover/database.go b/p2p/discover/database.go deleted file mode 100644 index 43a4ca37fb..0000000000 --- a/p2p/discover/database.go +++ /dev/null @@ -1,370 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. 
-// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -// Contains the node database, storing previously seen nodes and any collected -// metadata about them for QoS purposes. - -package discover - -import ( - "bytes" - "crypto/rand" - "encoding/binary" - "os" - "sync" - "time" - - "github.com/syndtr/goleveldb/leveldb" - "github.com/syndtr/goleveldb/leveldb/errors" - "github.com/syndtr/goleveldb/leveldb/iterator" - "github.com/syndtr/goleveldb/leveldb/opt" - "github.com/syndtr/goleveldb/leveldb/storage" - "github.com/syndtr/goleveldb/leveldb/util" - "github.com/tomochain/tomochain/crypto" - "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/rlp" -) - -var ( - nodeDBNilNodeID = NodeID{} // Special node ID to use as a nil element. - nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped. - nodeDBCleanupCycle = time.Hour // Time period for running the expiration task. -) - -// nodeDB stores all nodes we know about. -type nodeDB struct { - lvl *leveldb.DB // Interface to the database itself - self NodeID // Own node id to prevent adding it into the database - runner sync.Once // Ensures we can start at most one expirer - quit chan struct{} // Channel to signal the expiring thread to stop -} - -// Schema layout for the node database -var ( - nodeDBVersionKey = []byte("version") // Version of the database to flush if changes - nodeDBItemPrefix = []byte("n:") // Identifier to prefix node entries with - - nodeDBDiscoverRoot = ":discover" - nodeDBDiscoverPing = nodeDBDiscoverRoot + ":lastping" - nodeDBDiscoverPong = nodeDBDiscoverRoot + ":lastpong" - nodeDBDiscoverFindFails = nodeDBDiscoverRoot + ":findfail" -) - -// newNodeDB creates a new node database for storing and retrieving infos about -// known peers in the network. If no path is given, an in-memory, temporary -// database is constructed. -func newNodeDB(path string, version int, self NodeID) (*nodeDB, error) { - if path == "" { - return newMemoryNodeDB(self) - } - return newPersistentNodeDB(path, version, self) -} - -// newMemoryNodeDB creates a new in-memory node database without a persistent -// backend. -func newMemoryNodeDB(self NodeID) (*nodeDB, error) { - db, err := leveldb.Open(storage.NewMemStorage(), nil) - if err != nil { - return nil, err - } - return &nodeDB{ - lvl: db, - self: self, - quit: make(chan struct{}), - }, nil -} - -// newPersistentNodeDB creates/opens a leveldb backed persistent node database, -// also flushing its contents in case of a version mismatch. 
-func newPersistentNodeDB(path string, version int, self NodeID) (*nodeDB, error) { - opts := &opt.Options{OpenFilesCacheCapacity: 5} - db, err := leveldb.OpenFile(path, opts) - if _, iscorrupted := err.(*errors.ErrCorrupted); iscorrupted { - db, err = leveldb.RecoverFile(path, nil) - } - if err != nil { - return nil, err - } - // The nodes contained in the cache correspond to a certain protocol version. - // Flush all nodes if the version doesn't match. - currentVer := make([]byte, binary.MaxVarintLen64) - currentVer = currentVer[:binary.PutVarint(currentVer, int64(version))] - - blob, err := db.Get(nodeDBVersionKey, nil) - switch err { - case leveldb.ErrNotFound: - // Version not found (i.e. empty cache), insert it - if err := db.Put(nodeDBVersionKey, currentVer, nil); err != nil { - db.Close() - return nil, err - } - - case nil: - // Version present, flush if different - if !bytes.Equal(blob, currentVer) { - db.Close() - if err = os.RemoveAll(path); err != nil { - return nil, err - } - return newPersistentNodeDB(path, version, self) - } - } - return &nodeDB{ - lvl: db, - self: self, - quit: make(chan struct{}), - }, nil -} - -// makeKey generates the leveldb key-blob from a node id and its particular -// field of interest. -func makeKey(id NodeID, field string) []byte { - if bytes.Equal(id[:], nodeDBNilNodeID[:]) { - return []byte(field) - } - return append(nodeDBItemPrefix, append(id[:], field...)...) -} - -// splitKey tries to split a database key into a node id and a field part. -func splitKey(key []byte) (id NodeID, field string) { - // If the key is not of a node, return it plainly - if !bytes.HasPrefix(key, nodeDBItemPrefix) { - return NodeID{}, string(key) - } - // Otherwise split the id and field - item := key[len(nodeDBItemPrefix):] - copy(id[:], item[:len(id)]) - field = string(item[len(id):]) - - return id, field -} - -// fetchInt64 retrieves an integer instance associated with a particular -// database key. -func (db *nodeDB) fetchInt64(key []byte) int64 { - blob, err := db.lvl.Get(key, nil) - if err != nil { - return 0 - } - val, read := binary.Varint(blob) - if read <= 0 { - return 0 - } - return val -} - -// storeInt64 update a specific database entry to the current time instance as a -// unix timestamp. -func (db *nodeDB) storeInt64(key []byte, n int64) error { - blob := make([]byte, binary.MaxVarintLen64) - blob = blob[:binary.PutVarint(blob, n)] - - return db.lvl.Put(key, blob, nil) -} - -// node retrieves a node with a given id from the database. -func (db *nodeDB) node(id NodeID) *Node { - blob, err := db.lvl.Get(makeKey(id, nodeDBDiscoverRoot), nil) - if err != nil { - return nil - } - node := new(Node) - if err := rlp.DecodeBytes(blob, node); err != nil { - log.Error("Failed to decode node RLP", "err", err) - return nil - } - node.sha = crypto.Keccak256Hash(node.ID[:]) - return node -} - -// updateNode inserts - potentially overwriting - a node into the peer database. -func (db *nodeDB) updateNode(node *Node) error { - blob, err := rlp.EncodeToBytes(node) - if err != nil { - return err - } - return db.lvl.Put(makeKey(node.ID, nodeDBDiscoverRoot), blob, nil) -} - -// deleteNode deletes all information/keys associated with a node. 
-func (db *nodeDB) deleteNode(id NodeID) error { - deleter := db.lvl.NewIterator(util.BytesPrefix(makeKey(id, "")), nil) - for deleter.Next() { - if err := db.lvl.Delete(deleter.Key(), nil); err != nil { - return err - } - } - return nil -} - -// ensureExpirer is a small helper method ensuring that the data expiration -// mechanism is running. If the expiration goroutine is already running, this -// method simply returns. -// -// The goal is to start the data evacuation only after the network successfully -// bootstrapped itself (to prevent dumping potentially useful seed nodes). Since -// it would require significant overhead to exactly trace the first successful -// convergence, it's simpler to "ensure" the correct state when an appropriate -// condition occurs (i.e. a successful bonding), and discard further events. -func (db *nodeDB) ensureExpirer() { - db.runner.Do(func() { go db.expirer() }) -} - -// expirer should be started in a go routine, and is responsible for looping ad -// infinitum and dropping stale data from the database. -func (db *nodeDB) expirer() { - tick := time.NewTicker(nodeDBCleanupCycle) - defer tick.Stop() - for { - select { - case <-tick.C: - if err := db.expireNodes(); err != nil { - log.Error("Failed to expire nodedb items", "err", err) - } - case <-db.quit: - return - } - } -} - -// expireNodes iterates over the database and deletes all nodes that have not -// been seen (i.e. received a pong from) for some allotted time. -func (db *nodeDB) expireNodes() error { - threshold := time.Now().Add(-nodeDBNodeExpiration) - - // Find discovered nodes that are older than the allowance - it := db.lvl.NewIterator(nil, nil) - defer it.Release() - - for it.Next() { - // Skip the item if not a discovery node - id, field := splitKey(it.Key()) - if field != nodeDBDiscoverRoot { - continue - } - // Skip the node if not expired yet (and not self) - if !bytes.Equal(id[:], db.self[:]) { - if seen := db.bondTime(id); seen.After(threshold) { - continue - } - } - // Otherwise delete all associated information - db.deleteNode(id) - } - return nil -} - -// lastPing retrieves the time of the last ping packet send to a remote node, -// requesting binding. -func (db *nodeDB) lastPing(id NodeID) time.Time { - return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0) -} - -// updateLastPing updates the last time we tried contacting a remote node. -func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error { - return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) -} - -// bondTime retrieves the time of the last successful pong from remote node. -func (db *nodeDB) bondTime(id NodeID) time.Time { - return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) -} - -// hasBond reports whether the given node is considered bonded. -func (db *nodeDB) hasBond(id NodeID) bool { - return time.Since(db.bondTime(id)) < nodeDBNodeExpiration -} - -// updateBondTime updates the last pong time of a node. -func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error { - return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) -} - -// findFails retrieves the number of findnode failures since bonding. -func (db *nodeDB) findFails(id NodeID) int { - return int(db.fetchInt64(makeKey(id, nodeDBDiscoverFindFails))) -} - -// updateFindFails updates the number of findnode failures since bonding. 
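These per-node ping, pong and failure records do not disappear with this file: the table code reworked later in this patch keeps the same bookkeeping, but through enode.DB. Judging only from the calls visible in this diff, the old accessors correspond roughly as follows (a sketch of the mapping, not an exhaustive list; note that the new API keys its records by IP address as well as node ID):

    findFails(id)               ->  FindFails(id, ip)
    updateFindFails(id, fails)  ->  UpdateFindFails(id, ip, fails)
    bondTime(id)                ->  LastPongReceived(id, ip)
    updateNode(n)               ->  UpdateNode(n)
    querySeeds(count, maxAge)   ->  QuerySeeds(count, maxAge)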
-func (db *nodeDB) updateFindFails(id NodeID, fails int) error { - return db.storeInt64(makeKey(id, nodeDBDiscoverFindFails), int64(fails)) -} - -// querySeeds retrieves random nodes to be used as potential seed nodes -// for bootstrapping. -func (db *nodeDB) querySeeds(n int, maxAge time.Duration) []*Node { - var ( - now = time.Now() - nodes = make([]*Node, 0, n) - it = db.lvl.NewIterator(nil, nil) - id NodeID - ) - defer it.Release() - -seek: - for seeks := 0; len(nodes) < n && seeks < n*5; seeks++ { - // Seek to a random entry. The first byte is incremented by a - // random amount each time in order to increase the likelihood - // of hitting all existing nodes in very small databases. - ctr := id[0] - rand.Read(id[:]) - id[0] = ctr + id[0]%16 - it.Seek(makeKey(id, nodeDBDiscoverRoot)) - - n := nextNode(it) - if n == nil { - id[0] = 0 - continue seek // iterator exhausted - } - if n.ID == db.self { - continue seek - } - if now.Sub(db.bondTime(n.ID)) > maxAge { - continue seek - } - for i := range nodes { - if nodes[i].ID == n.ID { - continue seek // duplicate - } - } - nodes = append(nodes, n) - } - return nodes -} - -// reads the next node record from the iterator, skipping over other -// database entries. -func nextNode(it iterator.Iterator) *Node { - for end := false; !end; end = !it.Next() { - id, field := splitKey(it.Key()) - if field != nodeDBDiscoverRoot { - continue - } - var n Node - if err := rlp.DecodeBytes(it.Value(), &n); err != nil { - log.Warn("Failed to decode node RLP", "id", id, "err", err) - continue - } - return &n - } - return nil -} - -// close flushes and closes the database files. -func (db *nodeDB) close() { - close(db.quit) - db.lvl.Close() -} diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go deleted file mode 100644 index c4fa44d099..0000000000 --- a/p2p/discover/database_test.go +++ /dev/null @@ -1,380 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . 
- -package discover - -import ( - "bytes" - "io/ioutil" - "net" - "os" - "path/filepath" - "reflect" - "testing" - "time" -) - -var nodeDBKeyTests = []struct { - id NodeID - field string - key []byte -}{ - { - id: NodeID{}, - field: "version", - key: []byte{0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e}, // field - }, - { - id: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - field: ":discover", - key: []byte{0x6e, 0x3a, // prefix - 0x1d, 0xd9, 0xd6, 0x5c, 0x45, 0x52, 0xb5, 0xeb, // node id - 0x43, 0xd5, 0xad, 0x55, 0xa2, 0xee, 0x3f, 0x56, // - 0xc6, 0xcb, 0xc1, 0xc6, 0x4a, 0x5c, 0x8d, 0x65, // - 0x9f, 0x51, 0xfc, 0xd5, 0x1b, 0xac, 0xe2, 0x43, // - 0x51, 0x23, 0x2b, 0x8d, 0x78, 0x21, 0x61, 0x7d, // - 0x2b, 0x29, 0xb5, 0x4b, 0x81, 0xcd, 0xef, 0xb9, // - 0xb3, 0xe9, 0xc3, 0x7d, 0x7f, 0xd5, 0xf6, 0x32, // - 0x70, 0xbc, 0xc9, 0xe1, 0xa6, 0xf6, 0xa4, 0x39, // - 0x3a, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x76, 0x65, 0x72, // field - }, - }, -} - -func TestNodeDBKeys(t *testing.T) { - for i, tt := range nodeDBKeyTests { - if key := makeKey(tt.id, tt.field); !bytes.Equal(key, tt.key) { - t.Errorf("make test %d: key mismatch: have 0x%x, want 0x%x", i, key, tt.key) - } - id, field := splitKey(tt.key) - if !bytes.Equal(id[:], tt.id[:]) { - t.Errorf("split test %d: id mismatch: have 0x%x, want 0x%x", i, id, tt.id) - } - if field != tt.field { - t.Errorf("split test %d: field mismatch: have 0x%x, want 0x%x", i, field, tt.field) - } - } -} - -var nodeDBInt64Tests = []struct { - key []byte - value int64 -}{ - {key: []byte{0x01}, value: 1}, - {key: []byte{0x02}, value: 2}, - {key: []byte{0x03}, value: 3}, -} - -func TestNodeDBInt64(t *testing.T) { - db, _ := newNodeDB("", Version, NodeID{}) - defer db.close() - - tests := nodeDBInt64Tests - for i := 0; i < len(tests); i++ { - // Insert the next value - if err := db.storeInt64(tests[i].key, tests[i].value); err != nil { - t.Errorf("test %d: failed to store value: %v", i, err) - } - // Check all existing and non existing values - for j := 0; j < len(tests); j++ { - num := db.fetchInt64(tests[j].key) - switch { - case j <= i && num != tests[j].value: - t.Errorf("test %d, item %d: value mismatch: have %v, want %v", i, j, num, tests[j].value) - case j > i && num != 0: - t.Errorf("test %d, item %d: value mismatch: have %v, want %v", i, j, num, 0) - } - } - } -} - -func TestNodeDBFetchStore(t *testing.T) { - node := NewNode( - MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - net.IP{192, 168, 0, 1}, - 30303, - 30303, - ) - inst := time.Now() - num := 314 - - db, _ := newNodeDB("", Version, NodeID{}) - defer db.close() - - // Check fetch/store operations on a node ping object - if stored := db.lastPing(node.ID); stored.Unix() != 0 { - t.Errorf("ping: non-existing object: %v", stored) - } - if err := db.updateLastPing(node.ID, inst); err != nil { - t.Errorf("ping: failed to update: %v", err) - } - if stored := db.lastPing(node.ID); stored.Unix() != inst.Unix() { - t.Errorf("ping: value mismatch: have %v, want %v", stored, inst) - } - // Check fetch/store operations on a node pong object - if stored := db.bondTime(node.ID); stored.Unix() != 0 { - t.Errorf("pong: non-existing object: %v", stored) - } - if err := db.updateBondTime(node.ID, inst); err != nil { - t.Errorf("pong: failed to update: %v", err) - } - if stored := db.bondTime(node.ID); stored.Unix() != inst.Unix() { - 
t.Errorf("pong: value mismatch: have %v, want %v", stored, inst) - } - // Check fetch/store operations on a node findnode-failure object - if stored := db.findFails(node.ID); stored != 0 { - t.Errorf("find-node fails: non-existing object: %v", stored) - } - if err := db.updateFindFails(node.ID, num); err != nil { - t.Errorf("find-node fails: failed to update: %v", err) - } - if stored := db.findFails(node.ID); stored != num { - t.Errorf("find-node fails: value mismatch: have %v, want %v", stored, num) - } - // Check fetch/store operations on an actual node object - if stored := db.node(node.ID); stored != nil { - t.Errorf("node: non-existing object: %v", stored) - } - if err := db.updateNode(node); err != nil { - t.Errorf("node: failed to update: %v", err) - } - if stored := db.node(node.ID); stored == nil { - t.Errorf("node: not found") - } else if !reflect.DeepEqual(stored, node) { - t.Errorf("node: data mismatch: have %v, want %v", stored, node) - } -} - -var nodeDBSeedQueryNodes = []struct { - node *Node - pong time.Time -}{ - // This one should not be in the result set because its last - // pong time is too far in the past. - { - node: NewNode( - MustHexID("0x84d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - net.IP{127, 0, 0, 3}, - 30303, - 30303, - ), - pong: time.Now().Add(-3 * time.Hour), - }, - // This one shouldn't be in in the result set because its - // nodeID is the local node's ID. - { - node: NewNode( - MustHexID("0x57d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - net.IP{127, 0, 0, 3}, - 30303, - 30303, - ), - pong: time.Now().Add(-4 * time.Second), - }, - - // These should be in the result set. 
- { - node: NewNode( - MustHexID("0x22d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - net.IP{127, 0, 0, 1}, - 30303, - 30303, - ), - pong: time.Now().Add(-2 * time.Second), - }, - { - node: NewNode( - MustHexID("0x44d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - net.IP{127, 0, 0, 2}, - 30303, - 30303, - ), - pong: time.Now().Add(-3 * time.Second), - }, - { - node: NewNode( - MustHexID("0xe2d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - net.IP{127, 0, 0, 3}, - 30303, - 30303, - ), - pong: time.Now().Add(-1 * time.Second), - }, -} - -func TestNodeDBSeedQuery(t *testing.T) { - db, _ := newNodeDB("", Version, nodeDBSeedQueryNodes[1].node.ID) - defer db.close() - - // Insert a batch of nodes for querying - for i, seed := range nodeDBSeedQueryNodes { - if err := db.updateNode(seed.node); err != nil { - t.Fatalf("node %d: failed to insert: %v", i, err) - } - if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to insert bondTime: %v", i, err) - } - } - - // Retrieve the entire batch and check for duplicates - seeds := db.querySeeds(len(nodeDBSeedQueryNodes)*2, time.Hour) - have := make(map[NodeID]struct{}) - for _, seed := range seeds { - have[seed.ID] = struct{}{} - } - want := make(map[NodeID]struct{}) - for _, seed := range nodeDBSeedQueryNodes[2:] { - want[seed.node.ID] = struct{}{} - } - if len(seeds) != len(want) { - t.Errorf("seed count mismatch: have %v, want %v", len(seeds), len(want)) - } - for id := range have { - if _, ok := want[id]; !ok { - t.Errorf("extra seed: %v", id) - } - } - for id := range want { - if _, ok := have[id]; !ok { - t.Errorf("missing seed: %v", id) - } - } -} - -func TestNodeDBPersistency(t *testing.T) { - root, err := ioutil.TempDir("", "nodedb-") - if err != nil { - t.Fatalf("failed to create temporary data folder: %v", err) - } - defer os.RemoveAll(root) - - var ( - testKey = []byte("somekey") - testInt = int64(314) - ) - - // Create a persistent database and store some values - db, err := newNodeDB(filepath.Join(root, "database"), Version, NodeID{}) - if err != nil { - t.Fatalf("failed to create persistent database: %v", err) - } - if err := db.storeInt64(testKey, testInt); err != nil { - t.Fatalf("failed to store value: %v.", err) - } - db.close() - - // Reopen the database and check the value - db, err = newNodeDB(filepath.Join(root, "database"), Version, NodeID{}) - if err != nil { - t.Fatalf("failed to open persistent database: %v", err) - } - if val := db.fetchInt64(testKey); val != testInt { - t.Fatalf("value mismatch: have %v, want %v", val, testInt) - } - db.close() - - // Change the database version and check flush - db, err = newNodeDB(filepath.Join(root, "database"), Version+1, NodeID{}) - if err != nil { - t.Fatalf("failed to open persistent database: %v", err) - } - if val := db.fetchInt64(testKey); val != 0 { - t.Fatalf("value mismatch: have %v, want %v", val, 0) - } - db.close() -} - -var nodeDBExpirationNodes = []struct { - node *Node - pong time.Time - exp bool -}{ - { - node: NewNode( - MustHexID("0x01d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - net.IP{127, 0, 0, 1}, - 30303, - 30303, - ), - pong: time.Now().Add(-nodeDBNodeExpiration + time.Minute), - exp: false, - 
}, { - node: NewNode( - MustHexID("0x02d9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - net.IP{127, 0, 0, 2}, - 30303, - 30303, - ), - pong: time.Now().Add(-nodeDBNodeExpiration - time.Minute), - exp: true, - }, -} - -func TestNodeDBExpiration(t *testing.T) { - db, _ := newNodeDB("", Version, NodeID{}) - defer db.close() - - // Add all the test nodes and set their last pong time - for i, seed := range nodeDBExpirationNodes { - if err := db.updateNode(seed.node); err != nil { - t.Fatalf("node %d: failed to insert: %v", i, err) - } - if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to update bondTime: %v", i, err) - } - } - // Expire some of them, and check the rest - if err := db.expireNodes(); err != nil { - t.Fatalf("failed to expire nodes: %v", err) - } - for i, seed := range nodeDBExpirationNodes { - node := db.node(seed.node.ID) - if (node == nil && !seed.exp) || (node != nil && seed.exp) { - t.Errorf("node %d: expiration mismatch: have %v, want %v", i, node, seed.exp) - } - } -} - -func TestNodeDBSelfExpiration(t *testing.T) { - // Find a node in the tests that shouldn't expire, and assign it as self - var self NodeID - for _, node := range nodeDBExpirationNodes { - if !node.exp { - self = node.node.ID - break - } - } - db, _ := newNodeDB("", Version, self) - defer db.close() - - // Add all the test nodes and set their last pong time - for i, seed := range nodeDBExpirationNodes { - if err := db.updateNode(seed.node); err != nil { - t.Fatalf("node %d: failed to insert: %v", i, err) - } - if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { - t.Fatalf("node %d: failed to update bondTime: %v", i, err) - } - } - // Expire the nodes and make sure self has been evacuated too - if err := db.expireNodes(); err != nil { - t.Fatalf("failed to expire nodes: %v", err) - } - node := db.node(self) - if node != nil { - t.Errorf("self not evacuated") - } -} diff --git a/p2p/discover/node.go b/p2p/discover/node.go index 839f762278..9fe7bdb6d2 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -18,415 +18,87 @@ package discover import ( "crypto/ecdsa" - "crypto/elliptic" - "encoding/hex" "errors" - "fmt" "math/big" - "math/rand" "net" - "net/url" - "regexp" - "strconv" - "strings" "time" - "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/crypto/secp256k1" + "github.com/tomochain/tomochain/p2p/enode" ) -const NodeIDBits = 512 - -// Node represents a host on the network. +// node represents a host on the network. // The fields of Node may not be modified. -type Node struct { - IP net.IP // len 4 for IPv4 or 16 for IPv6 - UDP, TCP uint16 // port numbers - ID NodeID // the node's public key - - // This is a cached copy of sha3(ID) which is used for node - // distance calculations. This is part of Node in order to make it - // possible to write tests that need a node at a certain distance. - // In those tests, the content of sha will not actually correspond - // with ID. - sha common.Hash - - // Time when the node was added to the table. - addedAt time.Time -} - -// NewNode creates a new node. It is mostly meant to be used for -// testing purposes. 
-func NewNode(id NodeID, ip net.IP, udpPort, tcpPort uint16) *Node { - if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 - } - return &Node{ - IP: ip, - UDP: udpPort, - TCP: tcpPort, - ID: id, - sha: crypto.Keccak256Hash(id[:]), - } -} - -func (n *Node) addr() *net.UDPAddr { - return &net.UDPAddr{IP: n.IP, Port: int(n.UDP)} -} - -// Incomplete returns true for nodes with no IP address. -func (n *Node) Incomplete() bool { - return n.IP == nil -} - -// checks whether n is a valid complete node. -func (n *Node) validateComplete() error { - if n.Incomplete() { - return errors.New("incomplete node") - } - if n.UDP == 0 { - return errors.New("missing UDP port") - } - if n.TCP == 0 { - return errors.New("missing TCP port") - } - if n.IP.IsMulticast() || n.IP.IsUnspecified() { - return errors.New("invalid IP (multicast/unspecified)") - } - _, err := n.ID.Pubkey() // validate the key (on curve, etc.) - return err -} - -// The string representation of a Node is a URL. -// Please see ParseNode for a description of the format. -func (n *Node) String() string { - u := url.URL{Scheme: "enode"} - if n.Incomplete() { - u.Host = fmt.Sprintf("%x", n.ID[:]) - } else { - addr := net.TCPAddr{IP: n.IP, Port: int(n.TCP)} - u.User = url.User(fmt.Sprintf("%x", n.ID[:])) - u.Host = addr.String() - if n.UDP != n.TCP { - u.RawQuery = "discport=" + strconv.Itoa(int(n.UDP)) - } - } - return u.String() -} - -var incompleteNodeURL = regexp.MustCompile("(?i)^(?:enode://)?([0-9a-f]+)$") - -// ParseNode parses a node designator. -// -// There are two basic forms of node designators -// - incomplete nodes, which only have the public key (node ID) -// - complete nodes, which contain the public key and IP/Port information -// -// For incomplete nodes, the designator must look like one of these -// -// enode:// -// -// -// For complete nodes, the node ID is encoded in the username portion -// of the URL, separated from the host by an @ sign. The hostname can -// only be given as an IP address, DNS domain names are not allowed. -// The port in the host name section is the TCP listening port. If the -// TCP and UDP (discovery) ports differ, the UDP port is specified as -// query parameter "discport". -// -// In the following example, the node URL describes -// a node with IP address 10.3.58.6, TCP listening port 30303 -// and UDP discovery port 30301. -// -// enode://@10.3.58.6:30303?discport=30301 -func ParseNode(rawurl string) (*Node, error) { - if m := incompleteNodeURL.FindStringSubmatch(rawurl); m != nil { - id, err := HexID(m[1]) - if err != nil { - return nil, fmt.Errorf("invalid node ID (%v)", err) - } - return NewNode(id, nil, 0, 0), nil - } - return parseComplete(rawurl) +type node struct { + enode.Node + addedAt time.Time // time when the node was added to the table } -func parseComplete(rawurl string) (*Node, error) { - var ( - id NodeID - ip net.IP - tcpPort, udpPort uint64 - ) - u, err := url.Parse(rawurl) - if err != nil { - return nil, err - } - if u.Scheme != "enode" { - return nil, errors.New("invalid URL scheme, want \"enode\"") - } - // Parse the Node ID from the user portion. - if u.User == nil { - return nil, errors.New("does not contain node ID") - } - if id, err = HexID(u.User.String()); err != nil { - return nil, fmt.Errorf("invalid node ID (%v)", err) - } - // Parse the IP address. 
- host, port, err := net.SplitHostPort(u.Host) - if err != nil { - return nil, fmt.Errorf("invalid host: %v", err) - } - if ip = net.ParseIP(host); ip == nil { - return nil, errors.New("invalid IP address") - } - // Ensure the IP is 4 bytes long for IPv4 addresses. - if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 - } - // Parse the port numbers. - if tcpPort, err = strconv.ParseUint(port, 10, 16); err != nil { - return nil, errors.New("invalid port") - } - udpPort = tcpPort - qv := u.Query() - if qv.Get("discport") != "" { - udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16) - if err != nil { - return nil, errors.New("invalid discport in query") - } - } - return NewNode(id, ip, uint16(udpPort), uint16(tcpPort)), nil -} - -// MustParseNode parses a node URL. It panics if the URL is not valid. -func MustParseNode(rawurl string) *Node { - n, err := ParseNode(rawurl) - if err != nil { - panic("invalid node URL: " + err.Error()) - } - return n -} +type encPubkey [64]byte -// MarshalText implements encoding.TextMarshaler. -func (n *Node) MarshalText() ([]byte, error) { - return []byte(n.String()), nil +func encodePubkey(key *ecdsa.PublicKey) encPubkey { + var e encPubkey + math.ReadBits(key.X, e[:len(e)/2]) + math.ReadBits(key.Y, e[len(e)/2:]) + return e } -// UnmarshalText implements encoding.TextUnmarshaler. -func (n *Node) UnmarshalText(text []byte) error { - dec, err := ParseNode(string(text)) - if err == nil { - *n = *dec - } - return err -} - -// NodeID is a unique identifier for each node. -// The node identifier is a marshaled elliptic curve public key. -type NodeID [NodeIDBits / 8]byte - -// Bytes returns a byte slice representation of the NodeID -func (n NodeID) Bytes() []byte { - return n[:] -} - -// NodeID prints as a long hexadecimal number. -func (n NodeID) String() string { - return fmt.Sprintf("%x", n[:]) -} - -// The Go syntax representation of a NodeID is a call to HexID. -func (n NodeID) GoString() string { - return fmt.Sprintf("discover.HexID(\"%x\")", n[:]) -} - -// TerminalString returns a shortened hex string for terminal logging. -func (n NodeID) TerminalString() string { - return hex.EncodeToString(n[:8]) -} - -// MarshalText implements the encoding.TextMarshaler interface. -func (n NodeID) MarshalText() ([]byte, error) { - return []byte(hex.EncodeToString(n[:])), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -func (n *NodeID) UnmarshalText(text []byte) error { - id, err := HexID(string(text)) - if err != nil { - return err - } - *n = id - return nil -} - -// BytesID converts a byte slice to a NodeID -func BytesID(b []byte) (NodeID, error) { - var id NodeID - if len(b) != len(id) { - return id, fmt.Errorf("wrong length, want %d bytes", len(id)) - } - copy(id[:], b) - return id, nil -} - -// MustBytesID converts a byte slice to a NodeID. -// It panics if the byte slice is not a valid NodeID. -func MustBytesID(b []byte) NodeID { - id, err := BytesID(b) - if err != nil { - panic(err) +func decodePubkey(e encPubkey) (*ecdsa.PublicKey, error) { + p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)} + half := len(e) / 2 + p.X.SetBytes(e[:half]) + p.Y.SetBytes(e[half:]) + if !p.Curve.IsOnCurve(p.X, p.Y) { + return nil, errors.New("invalid secp256k1 curve point") } - return id + return p, nil } -// HexID converts a hex string to a NodeID. -// The string may be prefixed with 0x. 
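The encPubkey helpers added above take over from NodeID's hand-rolled marshalling. A minimal round-trip sketch, assuming it sits in package discover next to those helpers, with "testing" and the crypto package imported (illustrative only, not part of the patch):

    func TestEncPubkeyRoundTrip(t *testing.T) {
        key, err := crypto.GenerateKey()
        if err != nil {
            t.Fatal(err)
        }
        enc := encodePubkey(&key.PublicKey) // 64 bytes: X || Y, big-endian
        pub, err := decodePubkey(enc)       // rejects points not on the secp256k1 curve
        if err != nil {
            t.Fatal(err)
        }
        if pub.X.Cmp(key.PublicKey.X) != 0 || pub.Y.Cmp(key.PublicKey.Y) != 0 {
            t.Fatal("round trip changed the key")
        }
        _ = enc.id() // the node's enode.ID is Keccak256 of the 64-byte encoding
    }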
-func HexID(in string) (NodeID, error) { - var id NodeID - b, err := hex.DecodeString(strings.TrimPrefix(in, "0x")) - if err != nil { - return id, err - } else if len(b) != len(id) { - return id, fmt.Errorf("wrong length, want %d hex chars", len(id)*2) - } - copy(id[:], b) - return id, nil +func (e encPubkey) id() enode.ID { + return enode.ID(crypto.Keccak256Hash(e[:])) } -// MustHexID converts a hex string to a NodeID. -// It panics if the string is not a valid NodeID. -func MustHexID(in string) NodeID { - id, err := HexID(in) +// recoverNodeKey computes the public key used to sign the +// given hash from the signature. +func recoverNodeKey(hash, sig []byte) (key encPubkey, err error) { + pubkey, err := secp256k1.RecoverPubkey(hash, sig) if err != nil { - panic(err) + return key, err } - return id + copy(key[:], pubkey[1:]) + return key, nil } -// PubkeyID returns a marshaled representation of the given public key. -func PubkeyID(pub *ecdsa.PublicKey) NodeID { - var id NodeID - pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y) - if len(pbytes)-1 != len(id) { - panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes))) - } - copy(id[:], pbytes[1:]) - return id +func wrapNode(n *enode.Node) *node { + return &node{Node: *n} } -// Pubkey returns the public key represented by the node ID. -// It returns an error if the ID is not a point on the curve. -func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) { - p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)} - half := len(id) / 2 - p.X.SetBytes(id[:half]) - p.Y.SetBytes(id[half:]) - if !p.Curve.IsOnCurve(p.X, p.Y) { - return nil, errors.New("id is invalid secp256k1 curve point") +func wrapNodes(ns []*enode.Node) []*node { + result := make([]*node, len(ns)) + for i, n := range ns { + result[i] = wrapNode(n) } - return p, nil + return result } -// recoverNodeID computes the public key used to sign the -// given hash from the signature. -func recoverNodeID(hash, sig []byte) (id NodeID, err error) { - pubkey, err := secp256k1.RecoverPubkey(hash, sig) - if err != nil { - return id, err - } - if len(pubkey)-1 != len(id) { - return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8) - } - for i := range id { - id[i] = pubkey[i+1] - } - return id, nil +func unwrapNode(n *node) *enode.Node { + return &n.Node } -// distcmp compares the distances a->target and b->target. -// Returns -1 if a is closer to target, 1 if b is closer to target -// and 0 if they are equal. 
-func distcmp(target, a, b common.Hash) int { - for i := range target { - da := a[i] ^ target[i] - db := b[i] ^ target[i] - if da > db { - return 1 - } else if da < db { - return -1 - } +func unwrapNodes(ns []*node) []*enode.Node { + result := make([]*enode.Node, len(ns)) + for i, n := range ns { + result[i] = unwrapNode(n) } - return 0 + return result } -// table of leading zero counts for bytes [0..255] -var lzcount = [256]int{ - 8, 7, 6, 6, 5, 5, 5, 5, - 4, 4, 4, 4, 4, 4, 4, 4, - 3, 3, 3, 3, 3, 3, 3, 3, - 3, 3, 3, 3, 3, 3, 3, 3, - 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, +func (n *node) addr() *net.UDPAddr { + return &net.UDPAddr{IP: n.IP(), Port: n.UDP()} } -// logdist returns the logarithmic distance between a and b, log2(a ^ b). -func logdist(a, b common.Hash) int { - lz := 0 - for i := range a { - x := a[i] ^ b[i] - if x == 0 { - lz += 8 - } else { - lz += lzcount[x] - break - } - } - return len(a)*8 - lz -} - -// hashAtDistance returns a random hash such that logdist(a, b) == n -func hashAtDistance(a common.Hash, n int) (b common.Hash) { - if n == 0 { - return a - } - // flip bit at position n, fill the rest with random bits - b = a - pos := len(a) - n/8 - 1 - bit := byte(0x01) << (byte(n%8) - 1) - if bit == 0 { - pos++ - bit = 0x80 - } - b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits - for i := pos + 1; i < len(a); i++ { - b[i] = byte(rand.Intn(255)) - } - return b +func (n *node) String() string { + return n.Node.String() } diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 6fdd2cfd19..729ab1d0e0 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -23,9 +23,9 @@ package discover import ( + "crypto/ecdsa" crand "crypto/rand" "encoding/binary" - "errors" "fmt" mrand "math/rand" "net" @@ -36,13 +36,14 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/netutil" ) const ( - alpha = 3 // Kademlia concurrency factor - bucketSize = 200 // Kademlia bucket size - maxReplacements = 10 // Size of per-bucket replacement list + alpha = 3 // Kademlia concurrency factor + bucketSize = 16 // Kademlia bucket size + maxReplacements = 10 // Size of per-bucket replacement list // We keep buckets for the upper 1/15 of distances because // it's very unlikely we'll ever encounter a node that's closer. 
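Note that the hunk above also drops bucketSize from 200 to 16, the value upstream go-ethereum uses. For reference, the log-distance bucketing that the comment describes works as in the sketch below, assuming upstream's definitions of the constants (hashBits = 256, nBuckets = hashBits/15 = 17, bucketMinDistance = hashBits - nBuckets = 239), which live in a part of this file the diff does not show:

    func bucketIndex(self, id enode.ID) int {
        d := enode.LogDist(self, id) // position of the highest differing bit, in [0, 256]
        if d <= bucketMinDistance {
            return 0 // all nodes closer than the cutoff share the first bucket
        }
        return d - bucketMinDistance - 1 // distances 240..256 map to buckets 0..16
    }

Only the top 1/15 of distances get dedicated buckets because, with uniformly distributed IDs, the number of possible nodes at log-distance d grows like 2^d; closer nodes are exponentially rare.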
@@ -54,76 +55,56 @@ const ( bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24 tableIPLimit, tableSubnet = 10, 24 - maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions - maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped - - refreshInterval = 30 * time.Minute - revalidateInterval = 10 * time.Second - copyNodesInterval = 30 * time.Second - seedMinTableTime = 5 * time.Minute - seedCount = 30 - seedMaxAge = 5 * 24 * time.Hour + maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped + refreshInterval = 30 * time.Minute + revalidateInterval = 10 * time.Second + copyNodesInterval = 30 * time.Second + seedMinTableTime = 5 * time.Minute + seedCount = 30 + seedMaxAge = 5 * 24 * time.Hour ) type Table struct { mutex sync.Mutex // protects buckets, bucket content, nursery, rand buckets [nBuckets]*bucket // index of known nodes by distance - nursery []*Node // bootstrap nodes + nursery []*node // bootstrap nodes rand *mrand.Rand // source of randomness, periodically reseeded ips netutil.DistinctNetSet - db *nodeDB // database of known nodes + db *enode.DB // database of known nodes refreshReq chan chan struct{} initDone chan struct{} closeReq chan struct{} closed chan struct{} - bondmu sync.Mutex - bonding map[NodeID]*bondproc - bondslots chan struct{} // limits total number of active bonding processes - - nodeAddedHook func(*Node) // for testing + nodeAddedHook func(*node) // for testing net transport - self *Node // metadata of the local node -} - -type bondproc struct { - err error - n *Node - done chan struct{} + self *node // metadata of the local node } // transport is implemented by the UDP transport. // it is an interface so we can test without opening lots of UDP // sockets and without generating a private key. type transport interface { - ping(NodeID, *net.UDPAddr) error - waitping(NodeID) error - findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error) + ping(enode.ID, *net.UDPAddr) error + findnode(toid enode.ID, addr *net.UDPAddr, target encPubkey) ([]*node, error) close() } // bucket contains nodes, ordered by their last activity. the entry // that was most recently active is the first element in entries. 
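The slimmed-down transport interface above (waitping disappears along with the bonding machinery) is what the table tests stub out. Below is a sketch of such a stub, in the spirit of the newPingRecorder helper the updated tests use; the real helper lives in the new table_util_test.go, and errTimeout is assumed to be the udp transport's timeout sentinel:

    type pingRecorder struct {
        mu     sync.Mutex
        dead   map[enode.ID]bool // nodes whose ping should time out
        pinged map[enode.ID]bool // nodes this stub has pinged so far
    }

    func newPingRecorder() *pingRecorder {
        return &pingRecorder{
            dead:   make(map[enode.ID]bool),
            pinged: make(map[enode.ID]bool),
        }
    }

    func (t *pingRecorder) ping(toid enode.ID, toaddr *net.UDPAddr) error {
        t.mu.Lock()
        defer t.mu.Unlock()
        t.pinged[toid] = true
        if t.dead[toid] {
            return errTimeout
        }
        return nil
    }

    func (t *pingRecorder) findnode(toid enode.ID, addr *net.UDPAddr, target encPubkey) ([]*node, error) {
        return nil, nil // this stub knows no neighbors
    }

    func (t *pingRecorder) close() {}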
type bucket struct { - entries []*Node // live entries, sorted by time of last contact - replacements []*Node // recently seen nodes to be used if revalidation fails + entries []*node // live entries, sorted by time of last contact + replacements []*node // recently seen nodes to be used if revalidation fails ips netutil.DistinctNetSet } -func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) { - // If no node database was given, use an in-memory one - db, err := newNodeDB(nodeDBPath, Version, ourID) - if err != nil { - return nil, err - } +func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.Node) (*Table, error) { tab := &Table{ net: t, db: db, - self: NewNode(ourID, ourAddr.IP, uint16(ourAddr.Port), uint16(ourAddr.Port)), - bonding: make(map[NodeID]*bondproc), - bondslots: make(chan struct{}, maxBondingPingPongs), + self: wrapNode(self), refreshReq: make(chan chan struct{}), initDone: make(chan struct{}), closeReq: make(chan struct{}), @@ -134,20 +115,14 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string if err := tab.setFallbackNodes(bootnodes); err != nil { return nil, err } - for i := 0; i < cap(tab.bondslots); i++ { - tab.bondslots <- struct{}{} - } for i := range tab.buckets { tab.buckets[i] = &bucket{ ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit}, } } tab.seedRand() - tab.loadSeedNodes(false) - // Start the background expiration goroutine after loading seeds so that the search for - // seed nodes also considers older nodes that would otherwise be removed by the - // expiration. - tab.db.ensureExpirer() + tab.loadSeedNodes() + go tab.loop() return tab, nil } @@ -162,15 +137,13 @@ func (tab *Table) seedRand() { } // Self returns the local node. -// The returned node should not be modified by the caller. -func (tab *Table) Self() *Node { - return tab.self +func (tab *Table) Self() *enode.Node { + return unwrapNode(tab.self) } -// ReadRandomNodes fills the given slice with random nodes from the -// table. It will not write the same node more than once. The nodes in -// the slice are copies and can be modified by the caller. -func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { +// ReadRandomNodes fills the given slice with random nodes from the table. The results +// are guaranteed to be unique for a single invocation, no node will appear twice. +func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) { if !tab.isInitDone() { return 0 } @@ -178,10 +151,10 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { defer tab.mutex.Unlock() // Find all non-empty buckets and get a fresh slice of their entries. - var buckets [][]*Node - for _, b := range tab.buckets { + var buckets [][]*node + for _, b := range &tab.buckets { if len(b.entries) > 0 { - buckets = append(buckets, b.entries[:]) + buckets = append(buckets, b.entries) } } if len(buckets) == 0 { @@ -196,7 +169,7 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) { var i, j int for ; i < len(buf); i, j = i+1, (j+1)%len(buckets) { b := buckets[j] - buf[i] = &(*b[0]) + buf[i] = unwrapNode(b[0]) buckets[j] = b[1:] if len(b) == 1 { buckets = append(buckets[:j], buckets[j+1:]...) @@ -221,20 +194,13 @@ func (tab *Table) Close() { // setFallbackNodes sets the initial points of contact. These nodes // are used to connect to the network if the table is empty and there // are no known nodes in the database. 
-func (tab *Table) setFallbackNodes(nodes []*Node) error { +func (tab *Table) setFallbackNodes(nodes []*enode.Node) error { for _, n := range nodes { - if err := n.validateComplete(); err != nil { - return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err) + if err := n.ValidateComplete(); err != nil { + return fmt.Errorf("bad bootstrap node %q: %v", n, err) } } - tab.nursery = make([]*Node, 0, len(nodes)) - for _, n := range nodes { - cpy := *n - // Recompute cpy.sha because the node might not have been - // created by NewNode or ParseNode. - cpy.sha = crypto.Keccak256Hash(n.ID[:]) - tab.nursery = append(tab.nursery, &cpy) - } + tab.nursery = wrapNodes(nodes) return nil } @@ -250,47 +216,48 @@ func (tab *Table) isInitDone() bool { // Resolve searches for a specific node with the given ID. // It returns nil if the node could not be found. -func (tab *Table) Resolve(targetID NodeID) *Node { +func (tab *Table) Resolve(n *enode.Node) *enode.Node { // If the node is present in the local table, no // network interaction is required. - hash := crypto.Keccak256Hash(targetID[:]) + hash := n.ID() tab.mutex.Lock() cl := tab.closest(hash, 1) tab.mutex.Unlock() - if len(cl.entries) > 0 && cl.entries[0].ID == targetID { - return cl.entries[0] + if len(cl.entries) > 0 && cl.entries[0].ID() == hash { + return unwrapNode(cl.entries[0]) } // Otherwise, do a network lookup. - result := tab.Lookup(targetID) + result := tab.lookup(encodePubkey(n.Pubkey()), true) for _, n := range result { - if n.ID == targetID { - return n + if n.ID() == hash { + return unwrapNode(n) } } return nil } -// Lookup performs a network search for nodes close -// to the given target. It approaches the target by querying -// nodes that are closer to it on each iteration. -// The given target does not need to be an actual node -// identifier. -func (tab *Table) Lookup(targetID NodeID) []*Node { - return tab.lookup(targetID, true) +// LookupRandom finds random nodes in the network. +func (tab *Table) LookupRandom() []*enode.Node { + var target encPubkey + crand.Read(target[:]) + return unwrapNodes(tab.lookup(target, true)) } -func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node { +// lookup performs a network search for nodes close to the given target. It approaches the +// target by querying nodes that are closer to it on each iteration. The given target does +// not need to be an actual node identifier. +func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node { var ( - target = crypto.Keccak256Hash(targetID[:]) - asked = make(map[NodeID]bool) - seen = make(map[NodeID]bool) - reply = make(chan []*Node, alpha) + target = enode.ID(crypto.Keccak256Hash(targetKey[:])) + asked = make(map[enode.ID]bool) + seen = make(map[enode.ID]bool) + reply = make(chan []*node, alpha) pendingQueries = 0 result *nodesByDistance ) // don't query further if we hit ourself. // unlikely to happen often in practice. 
- asked[tab.self.ID] = true + asked[tab.self.ID()] = true for { tab.mutex.Lock() @@ -312,25 +279,10 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node { // ask the alpha closest nodes that we haven't asked yet for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ { n := result.entries[i] - if !asked[n.ID] { - asked[n.ID] = true + if !asked[n.ID()] { + asked[n.ID()] = true pendingQueries++ - go func() { - // Find potential neighbors to bond with - r, err := tab.net.findnode(n.ID, n.addr(), targetID) - if err != nil { - // Bump the failure counter to detect and evacuate non-bonded entries - fails := tab.db.findFails(n.ID) + 1 - tab.db.updateFindFails(n.ID, fails) - log.Trace("Bumping findnode failure counter", "id", n.ID, "failcount", fails) - - if fails >= maxFindnodeFailures { - log.Trace("Too many findnode failures, dropping", "id", n.ID, "failcount", fails) - tab.delete(n) - } - } - reply <- tab.bondall(r) - }() + go tab.findnode(n, targetKey, reply) } } if pendingQueries == 0 { @@ -339,8 +291,8 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node { } // wait for the next reply for _, n := range <-reply { - if n != nil && !seen[n.ID] { - seen[n.ID] = true + if n != nil && !seen[n.ID()] { + seen[n.ID()] = true result.push(n, bucketSize) } } @@ -349,6 +301,29 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node { return result.entries } +func (tab *Table) findnode(n *node, targetKey encPubkey, reply chan<- []*node) { + fails := tab.db.FindFails(n.ID(), n.IP()) + r, err := tab.net.findnode(n.ID(), n.addr(), targetKey) + if err != nil || len(r) == 0 { + fails++ + tab.db.UpdateFindFails(n.ID(), n.IP(), fails) + log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err) + if fails >= maxFindnodeFailures { + log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails) + tab.delete(n) + } + } else if fails > 0 { + tab.db.UpdateFindFails(n.ID(), n.IP(), fails-1) + } + + // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll + // just remove those again during revalidation. + for _, n := range r { + tab.add(n) + } + reply <- r +} + func (tab *Table) refresh() <-chan struct{} { done := make(chan struct{}) select { @@ -401,7 +376,7 @@ loop: case <-revalidateDone: revalidate.Reset(tab.nextRevalidateTime()) case <-copyNodes.C: - go tab.copyBondedNodes() + go tab.copyLiveNodes() case <-tab.closeReq: break loop } @@ -416,7 +391,6 @@ loop: for _, ch := range waiting { close(ch) } - tab.db.close() close(tab.closed) } @@ -429,10 +403,14 @@ func (tab *Table) doRefresh(done chan struct{}) { // Load nodes from the database and insert // them. This should yield a few previously seen nodes that are // (hopefully) still alive. - tab.loadSeedNodes(true) + tab.loadSeedNodes() // Run self lookup to discover new neighbor nodes. - tab.lookup(tab.self.ID, false) + // We can only do this if we have a secp256k1 identity. + var key ecdsa.PublicKey + if err := tab.self.Load((*enode.Secp256k1)(&key)); err == nil { + tab.lookup(encodePubkey(&key), false) + } // The Kademlia paper specifies that the bucket refresh should // perform a lookup in the least recently used bucket. We cannot @@ -441,22 +419,19 @@ func (tab *Table) doRefresh(done chan struct{}) { // sha3 preimage that falls into a chosen bucket. // We perform a few lookups with a random target instead. 
for i := 0; i < 3; i++ { - var target NodeID + var target encPubkey crand.Read(target[:]) tab.lookup(target, false) } } -func (tab *Table) loadSeedNodes(bond bool) { - seeds := tab.db.querySeeds(seedCount, seedMaxAge) +func (tab *Table) loadSeedNodes() { + seeds := wrapNodes(tab.db.QuerySeeds(seedCount, seedMaxAge)) seeds = append(seeds, tab.nursery...) - if bond { - seeds = tab.bondall(seeds) - } for i := range seeds { seed := seeds[i] - age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }} - log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age) + age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) }} + log.Debug("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age) tab.add(seed) } } @@ -473,28 +448,28 @@ func (tab *Table) doRevalidate(done chan<- struct{}) { } // Ping the selected node and wait for a pong. - err := tab.ping(last.ID, last.addr()) + err := tab.net.ping(last.ID(), last.addr()) tab.mutex.Lock() defer tab.mutex.Unlock() b := tab.buckets[bi] if err == nil { // The node responded, move it to the front. - log.Debug("Revalidated node", "b", bi, "id", last.ID) + log.Debug("Revalidated node", "b", bi, "id", last.ID()) b.bump(last) return } // No reply received, pick a replacement or delete the node if there aren't // any replacements. if r := tab.replace(b, last); r != nil { - log.Debug("Replaced dead node", "b", bi, "id", last.ID, "ip", last.IP, "r", r.ID, "rip", r.IP) + log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "r", r.ID(), "rip", r.IP()) } else { - log.Debug("Removed dead node", "b", bi, "id", last.ID, "ip", last.IP) + log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP()) } } // nodeToRevalidate returns the last node in a random, non-empty bucket. -func (tab *Table) nodeToRevalidate() (n *Node, bi int) { +func (tab *Table) nodeToRevalidate() (n *node, bi int) { tab.mutex.Lock() defer tab.mutex.Unlock() @@ -515,17 +490,17 @@ func (tab *Table) nextRevalidateTime() time.Duration { return time.Duration(tab.rand.Int63n(int64(revalidateInterval))) } -// copyBondedNodes adds nodes from the table to the database if they have been in the table +// copyLiveNodes adds nodes from the table to the database if they have been in the table // longer then minTableTime. -func (tab *Table) copyBondedNodes() { +func (tab *Table) copyLiveNodes() { tab.mutex.Lock() defer tab.mutex.Unlock() now := time.Now() - for _, b := range tab.buckets { + for _, b := range &tab.buckets { for _, n := range b.entries { if now.Sub(n.addedAt) >= seedMinTableTime { - tab.db.updateNode(n) + tab.db.UpdateNode(unwrapNode(n)) } } } @@ -533,12 +508,12 @@ func (tab *Table) copyBondedNodes() { // closest returns the n nodes in the table that are closest to the // given id. The caller must hold tab.mutex. -func (tab *Table) closest(target common.Hash, nresults int) *nodesByDistance { +func (tab *Table) closest(target enode.ID, nresults int) *nodesByDistance { // This is a very wasteful way to find the closest nodes but // obviously correct. I believe that tree-based buckets would make // this easier to implement efficiently. 
close := &nodesByDistance{target: target} - for _, b := range tab.buckets { + for _, b := range &tab.buckets { for _, n := range b.entries { close.push(n, nresults) } @@ -547,176 +522,76 @@ func (tab *Table) closest(target common.Hash, nresults int) *nodesByDistance { } func (tab *Table) len() (n int) { - for _, b := range tab.buckets { + for _, b := range &tab.buckets { n += len(b.entries) } return n } -// bondall bonds with all given nodes concurrently and returns -// those nodes for which bonding has probably succeeded. -func (tab *Table) bondall(nodes []*Node) (result []*Node) { - rc := make(chan *Node, len(nodes)) - for i := range nodes { - go func(n *Node) { - nn, _ := tab.bond(false, n.ID, n.addr(), n.TCP) - rc <- nn - }(nodes[i]) - } - for range nodes { - if n := <-rc; n != nil { - result = append(result, n) - } - } - return result -} - -// bond ensures the local node has a bond with the given remote node. -// It also attempts to insert the node into the table if bonding succeeds. -// The caller must not hold tab.mutex. -// -// A bond is must be established before sending findnode requests. -// Both sides must have completed a ping/pong exchange for a bond to -// exist. The total number of active bonding processes is limited in -// order to restrain network use. -// -// bond is meant to operate idempotently in that bonding with a remote -// node which still remembers a previously established bond will work. -// The remote node will simply not send a ping back, causing waitping -// to time out. -// -// If pinged is true, the remote node has just pinged us and one half -// of the process can be skipped. -func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) { - if id == tab.self.ID { - return nil, errors.New("is self") - } - if pinged && !tab.isInitDone() { - return nil, errors.New("still initializing") - } - // Start bonding if we haven't seen this node for a while or if it failed findnode too often. - node, fails := tab.db.node(id), tab.db.findFails(id) - age := time.Since(tab.db.bondTime(id)) - var result error - if fails > 0 || age > nodeDBNodeExpiration { - log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age) - - tab.bondmu.Lock() - w := tab.bonding[id] - if w != nil { - // Wait for an existing bonding process to complete. - tab.bondmu.Unlock() - <-w.done - } else { - // Register a new bonding process. - w = &bondproc{done: make(chan struct{})} - tab.bonding[id] = w - tab.bondmu.Unlock() - // Do the ping/pong. The result goes into w. - tab.pingpong(w, pinged, id, addr, tcpPort) - // Unregister the process after it's done. - tab.bondmu.Lock() - delete(tab.bonding, id) - tab.bondmu.Unlock() - } - // Retrieve the bonding results - result = w.err - if result == nil { - node = w.n - } - } - // Add the node to the table even if the bonding ping/pong - // fails. It will be relaced quickly if it continues to be - // unresponsive. - if node != nil { - tab.add(node) - tab.db.updateFindFails(id, 0) - } - return node, result -} - -func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) { - // Request a bonding slot to limit network usage - <-tab.bondslots - defer func() { tab.bondslots <- struct{}{} }() - - // Ping the remote side and wait for a pong. - if w.err = tab.ping(id, addr); w.err != nil { - close(w.done) - return - } - if !pinged { - // Give the remote node a chance to ping us before we start - // sending findnode requests. 
If they still remember us, - // waitping will simply time out. - tab.net.waitping(id) - } - // Bonding succeeded, update the node database. - w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort) - close(w.done) -} - -// ping a remote endpoint and wait for a reply, also updating the node -// database accordingly. -func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error { - tab.db.updateLastPing(id, time.Now()) - if err := tab.net.ping(id, addr); err != nil { - return err - } - tab.db.updateBondTime(id, time.Now()) - return nil -} - // bucket returns the bucket for the given node ID hash. -func (tab *Table) bucket(sha common.Hash) *bucket { - d := logdist(tab.self.sha, sha) +func (tab *Table) bucket(id enode.ID) *bucket { + d := enode.LogDist(tab.self.ID(), id) if d <= bucketMinDistance { return tab.buckets[0] } return tab.buckets[d-bucketMinDistance-1] } -// add attempts to add the given node its corresponding bucket. If the -// bucket has space available, adding the node succeeds immediately. -// Otherwise, the node is added if the least recently active node in -// the bucket does not respond to a ping packet. +// add attempts to add the given node to its corresponding bucket. If the bucket has space +// available, adding the node succeeds immediately. Otherwise, the node is added if the +// least recently active node in the bucket does not respond to a ping packet. // // The caller must not hold tab.mutex. -func (tab *Table) add(new *Node) { +func (tab *Table) add(n *node) { + if n.ID() == tab.self.ID() { + return + } + tab.mutex.Lock() defer tab.mutex.Unlock() - - b := tab.bucket(new.sha) - if !tab.bumpOrAdd(b, new) { + b := tab.bucket(n.ID()) + if !tab.bumpOrAdd(b, n) { // Node is not in table. Add it to the replacement list. - tab.addReplacement(b, new) + tab.addReplacement(b, n) + } +} + +// addThroughPing adds the given node to the table. Compared to plain +// 'add' there is an additional safety measure: if the table is still +// initializing the node is not added. This prevents an attack where the +// table could be filled by just sending ping repeatedly. +// +// The caller must not hold tab.mutex. +func (tab *Table) addThroughPing(n *node) { + if !tab.isInitDone() { + return } + tab.add(n) } // stuff adds nodes the table to the end of their corresponding bucket // if the bucket is not full. The caller must not hold tab.mutex. -func (tab *Table) stuff(nodes []*Node) { +func (tab *Table) stuff(nodes []*node) { tab.mutex.Lock() defer tab.mutex.Unlock() for _, n := range nodes { - if n.ID == tab.self.ID { + if n.ID() == tab.self.ID() { continue // don't add self } - b := tab.bucket(n.sha) + b := tab.bucket(n.ID()) if len(b.entries) < bucketSize { tab.bumpOrAdd(b, n) } } } -// delete removes an entry from the node table (used to evacuate -// failed/non-bonded discovery peers). -func (tab *Table) delete(node *Node) { +// delete removes an entry from the node table. It is used to evacuate dead nodes. 
+func (tab *Table) delete(node *node) { tab.mutex.Lock() defer tab.mutex.Unlock() - tab.deleteInBucket(tab.bucket(node.sha), node) + tab.deleteInBucket(tab.bucket(node.ID()), node) } func (tab *Table) addIP(b *bucket, ip net.IP) bool { @@ -743,27 +618,27 @@ func (tab *Table) removeIP(b *bucket, ip net.IP) { b.ips.Remove(ip) } -func (tab *Table) addReplacement(b *bucket, n *Node) { +func (tab *Table) addReplacement(b *bucket, n *node) { for _, e := range b.replacements { - if e.ID == n.ID { + if e.ID() == n.ID() { return // already in list } } - if !tab.addIP(b, n.IP) { + if !tab.addIP(b, n.IP()) { return } - var removed *Node + var removed *node b.replacements, removed = pushNode(b.replacements, n, maxReplacements) if removed != nil { - tab.removeIP(b, removed.IP) + tab.removeIP(b, removed.IP()) } } // replace removes n from the replacement list and replaces 'last' with it if it is the // last entry in the bucket. If 'last' isn't the last entry, it has either been replaced // with someone else or became active. -func (tab *Table) replace(b *bucket, last *Node) *Node { - if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID != last.ID { +func (tab *Table) replace(b *bucket, last *node) *node { + if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID() != last.ID() { // Entry has moved, don't replace it. return nil } @@ -775,15 +650,15 @@ func (tab *Table) replace(b *bucket, last *Node) *Node { r := b.replacements[tab.rand.Intn(len(b.replacements))] b.replacements = deleteNode(b.replacements, r) b.entries[len(b.entries)-1] = r - tab.removeIP(b, last.IP) + tab.removeIP(b, last.IP()) return r } // bump moves the given node to the front of the bucket entry list // if it is contained in that list. -func (b *bucket) bump(n *Node) bool { +func (b *bucket) bump(n *node) bool { for i := range b.entries { - if b.entries[i].ID == n.ID { + if b.entries[i].ID() == n.ID() { // move it to the front copy(b.entries[1:], b.entries[:i]) b.entries[0] = n @@ -795,11 +670,11 @@ func (b *bucket) bump(n *Node) bool { // bumpOrAdd moves n to the front of the bucket entry list or adds it if the list isn't // full. The return value is true if n is in the bucket. -func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool { +func (tab *Table) bumpOrAdd(b *bucket, n *node) bool { if b.bump(n) { return true } - if len(b.entries) >= bucketSize || !tab.addIP(b, n.IP) { + if len(b.entries) >= bucketSize || !tab.addIP(b, n.IP()) { return false } b.entries, _ = pushNode(b.entries, n, bucketSize) @@ -811,13 +686,13 @@ func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool { return true } -func (tab *Table) deleteInBucket(b *bucket, n *Node) { +func (tab *Table) deleteInBucket(b *bucket, n *node) { b.entries = deleteNode(b.entries, n) - tab.removeIP(b, n.IP) + tab.removeIP(b, n.IP()) } // pushNode adds n to the front of list, keeping at most max items. -func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) { +func pushNode(list []*node, n *node, max int) ([]*node, *node) { if len(list) < max { list = append(list, nil) } @@ -828,9 +703,9 @@ func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) { } // deleteNode removes n from list. -func deleteNode(list []*Node, n *Node) []*Node { +func deleteNode(list []*node, n *node) []*node { for i := range list { - if list[i].ID == n.ID { + if list[i].ID() == n.ID() { return append(list[:i], list[i+1:]...) } } @@ -840,14 +715,14 @@ func deleteNode(list []*Node, n *Node) []*Node { // nodesByDistance is a list of nodes, ordered by // distance to target. 
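// Ordering in the list relies on enode.DistCmp(target, a, b), which returns
// -1, 0 or +1 depending on whether a or b is closer to target. A minimal
// standalone sketch of the bounded sorted insert that push implements below,
// reusing the file's sort and enode imports; sortedInsertSketch is a
// hypothetical helper, not part of this patch:
func sortedInsertSketch(entries []*node, n *node, target enode.ID, maxElems int) []*node {
	ix := sort.Search(len(entries), func(i int) bool {
		return enode.DistCmp(target, entries[i].ID(), n.ID()) > 0
	})
	if len(entries) < maxElems {
		entries = append(entries, nil) // grow by one slot
	} else if ix == len(entries) {
		return entries // farther than every kept node: discard n
	}
	copy(entries[ix+1:], entries[ix:]) // shift farther nodes one slot down
	entries[ix] = n
	return entries
}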
 type nodesByDistance struct {
-	entries []*Node
-	target  common.Hash
+	entries []*node
+	target  enode.ID
 }
 
 // push adds the given node to the list, keeping the total size below maxElems.
-func (h *nodesByDistance) push(n *Node, maxElems int) {
+func (h *nodesByDistance) push(n *node, maxElems int) {
 	ix := sort.Search(len(h.entries), func(i int) bool {
-		return distcmp(h.target, h.entries[i].sha, n.sha) > 0
+		return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0
 	})
 	if len(h.entries) < maxElems {
 		h.entries = append(h.entries, n)
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
index b81b0bfde9..388baf6dac 100644
--- a/p2p/discover/table_test.go
+++ b/p2p/discover/table_test.go
@@ -20,7 +20,6 @@ import (
 	"crypto/ecdsa"
 	"fmt"
 	"math/rand"
-	"sync"
 	"net"
 	"reflect"
@@ -28,8 +27,9 @@ import (
 	"testing/quick"
 	"time"
 
-	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/crypto"
+	"github.com/tomochain/tomochain/p2p/enode"
+	"github.com/tomochain/tomochain/p2p/enr"
 )
 
 func TestTable_pingReplace(t *testing.T) {
@@ -49,30 +49,27 @@ func TestTable_pingReplace(t *testing.T) {
 func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
 	transport := newPingRecorder()
-	tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
+	tab, db := newTestTable(transport)
 	defer tab.Close()
+	defer db.Close()
 
-	// Wait for init so bond is accepted.
 	<-tab.initDone
 
-	// fill up the sender's bucket.
-	pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
+	// Fill up the sender's bucket.
+	pingKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8")
+	pingSender := wrapNode(enode.NewV4(&pingKey.PublicKey, net.IP{}, 99, 99))
 	last := fillBucket(tab, pingSender)
 
-	// this call to bond should replace the last node
-	// in its bucket if the node is not responding.
-	transport.dead[last.ID] = !lastInBucketIsResponding
-	transport.dead[pingSender.ID] = !newNodeIsResponding
-	tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
+	// Add the sender as if it just pinged us. Revalidate should replace the last node in
+	// its bucket if it is unresponsive; revalidate again to evict a dead replacement too.
+	transport.dead[last.ID()] = !lastInBucketIsResponding
+	transport.dead[pingSender.ID()] = !newNodeIsResponding
+	tab.add(pingSender)
+	tab.doRevalidate(make(chan struct{}, 1))
 	tab.doRevalidate(make(chan struct{}, 1))
 
-	// first ping goes to sender (bonding pingback)
-	if !transport.pinged[pingSender.ID] {
-		t.Error("table did not ping back sender")
-	}
-	if !transport.pinged[last.ID] {
-		// second ping goes to oldest node in bucket
-		// to see whether it is still alive.
+	if !transport.pinged[last.ID()] {
+		// Oldest node in bucket is pinged to see whether it is still alive.
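		// With bonding gone, this revalidation ping is the only one the
		// recorder should observe: add() never sends a ping (a full bucket
		// routes the sender to the replacement list), and each doRevalidate
		// pass pings only the least recently seen entry of one bucket,
		// swapping in a random replacement if that entry is dead.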
t.Error("table did not ping last node in bucket") } @@ -82,14 +79,14 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding if !lastInBucketIsResponding && !newNodeIsResponding { wantSize-- } - if l := len(tab.bucket(pingSender.sha).entries); l != wantSize { + if l := len(tab.bucket(pingSender.ID()).entries); l != wantSize { t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize) } - if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding { + if found := contains(tab.bucket(pingSender.ID()).entries, last.ID()); found != lastInBucketIsResponding { t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding) } wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding - if found := contains(tab.bucket(pingSender.sha).entries, pingSender.ID); found != wantNewEntry { + if found := contains(tab.bucket(pingSender.ID()).entries, pingSender.ID()); found != wantNewEntry { t.Errorf("new entry found: %t, want: %t", found, wantNewEntry) } } @@ -102,9 +99,9 @@ func TestBucket_bumpNoDuplicates(t *testing.T) { Values: func(args []reflect.Value, rand *rand.Rand) { // generate a random list of nodes. this will be the content of the bucket. n := rand.Intn(bucketSize-1) + 1 - nodes := make([]*Node, n) + nodes := make([]*node, n) for i := range nodes { - nodes[i] = nodeAtDistance(common.Hash{}, 200) + nodes[i] = nodeAtDistance(enode.ID{}, 200, intIP(200)) } args[0] = reflect.ValueOf(nodes) // generate random bump positions. @@ -116,8 +113,8 @@ func TestBucket_bumpNoDuplicates(t *testing.T) { }, } - prop := func(nodes []*Node, bumps []int) (ok bool) { - b := &bucket{entries: make([]*Node, len(nodes))} + prop := func(nodes []*node, bumps []int) (ok bool) { + b := &bucket{entries: make([]*node, len(nodes))} copy(b.entries, nodes) for i, pos := range bumps { b.bump(b.entries[pos]) @@ -139,12 +136,12 @@ func TestBucket_bumpNoDuplicates(t *testing.T) { // This checks that the table-wide IP limit is applied correctly. func TestTable_IPLimit(t *testing.T) { transport := newPingRecorder() - tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) + tab, db := newTestTable(transport) defer tab.Close() + defer db.Close() for i := 0; i < tableIPLimit+1; i++ { - n := nodeAtDistance(tab.self.sha, i) - n.IP = net.IP{172, 0, 1, byte(i)} + n := nodeAtDistance(tab.self.ID(), i, net.IP{172, 0, 1, byte(i)}) tab.add(n) } if tab.len() > tableIPLimit { @@ -152,16 +149,16 @@ func TestTable_IPLimit(t *testing.T) { } } -// This checks that the table-wide IP limit is applied correctly. +// This checks that the per-bucket IP limit is applied correctly. func TestTable_BucketIPLimit(t *testing.T) { transport := newPingRecorder() - tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) + tab, db := newTestTable(transport) defer tab.Close() + defer db.Close() d := 3 for i := 0; i < bucketIPLimit+1; i++ { - n := nodeAtDistance(tab.self.sha, d) - n.IP = net.IP{172, 0, 1, byte(i)} + n := nodeAtDistance(tab.self.ID(), d, net.IP{172, 0, 1, byte(i)}) tab.add(n) } if tab.len() > bucketIPLimit { @@ -169,70 +166,18 @@ func TestTable_BucketIPLimit(t *testing.T) { } } -// fillBucket inserts nodes into the given bucket until -// it is full. The node's IDs dont correspond to their -// hashes. 
-func fillBucket(tab *Table, n *Node) (last *Node) { - ld := logdist(tab.self.sha, n.sha) - b := tab.bucket(n.sha) - for len(b.entries) < bucketSize { - b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld)) - } - return b.entries[bucketSize-1] -} - -// nodeAtDistance creates a node for which logdist(base, n.sha) == ld. -// The node's ID does not correspond to n.sha. -func nodeAtDistance(base common.Hash, ld int) (n *Node) { - n = new(Node) - n.sha = hashAtDistance(base, ld) - n.IP = net.IP{byte(ld), 0, 2, byte(ld)} - copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID - return n -} - -type pingRecorder struct { - mu sync.Mutex - dead, pinged map[NodeID]bool -} - -func newPingRecorder() *pingRecorder { - return &pingRecorder{ - dead: make(map[NodeID]bool), - pinged: make(map[NodeID]bool), - } -} - -func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { - return nil, nil -} -func (t *pingRecorder) close() {} -func (t *pingRecorder) waitping(from NodeID) error { - return nil // remote always pings -} -func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error { - t.mu.Lock() - defer t.mu.Unlock() - - t.pinged[toid] = true - if t.dead[toid] { - return errTimeout - } else { - return nil - } -} - func TestTable_closest(t *testing.T) { t.Parallel() test := func(test *closeTest) bool { // for any node table, Target and N transport := newPingRecorder() - tab, _ := newTable(transport, test.Self, &net.UDPAddr{}, "", nil) + tab, db := newTestTable(transport) defer tab.Close() + defer db.Close() tab.stuff(test.All) - // check that doClosest(Target, N) returns nodes + // check that closest(Target, N) returns nodes result := tab.closest(test.Target, test.N).entries if hasDuplicates(result) { t.Errorf("result contains duplicates") @@ -258,15 +203,15 @@ func TestTable_closest(t *testing.T) { // check that the result nodes have minimum distance to target. 
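		// In other words, closest must act as an exact k-nearest selection:
		// no entry left in any bucket may be closer to the target than the
		// farthest node that was returned. The loop below verifies this with
		// enode.DistCmp against every remaining bucket entry.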
for _, b := range tab.buckets { for _, n := range b.entries { - if contains(result, n.ID) { + if contains(result, n.ID()) { continue // don't run the check below for nodes in result } - farthestResult := result[len(result)-1].sha - if distcmp(test.Target, n.sha, farthestResult) < 0 { + farthestResult := result[len(result)-1].ID() + if enode.DistCmp(test.Target, n.ID(), farthestResult) < 0 { t.Errorf("table contains node that is closer to target but it's not in result") t.Logf(" Target: %v", test.Target) t.Logf(" Farthest Result: %v", farthestResult) - t.Logf(" ID: %v", n.ID) + t.Logf(" ID: %v", n.ID()) return false } } @@ -283,25 +228,26 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) { MaxCount: 200, Rand: rand.New(rand.NewSource(time.Now().Unix())), Values: func(args []reflect.Value, rand *rand.Rand) { - args[0] = reflect.ValueOf(make([]*Node, rand.Intn(1000))) + args[0] = reflect.ValueOf(make([]*enode.Node, rand.Intn(1000))) }, } - test := func(buf []*Node) bool { + test := func(buf []*enode.Node) bool { transport := newPingRecorder() - tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) + tab, db := newTestTable(transport) defer tab.Close() + defer db.Close() <-tab.initDone for i := 0; i < len(buf); i++ { ld := cfg.Rand.Intn(len(tab.buckets)) - tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)}) + tab.stuff([]*node{nodeAtDistance(tab.self.ID(), ld, intIP(ld))}) } gotN := tab.ReadRandomNodes(buf) if gotN != tab.len() { t.Errorf("wrong number of nodes, got %d, want %d", gotN, tab.len()) return false } - if hasDuplicates(buf[:gotN]) { + if hasDuplicates(wrapNodes(buf[:gotN])) { t.Errorf("result contains duplicates") return false } @@ -313,302 +259,304 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) { } type closeTest struct { - Self NodeID - Target common.Hash - All []*Node + Self enode.ID + Target enode.ID + All []*node N int } func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { t := &closeTest{ - Self: gen(NodeID{}, rand).(NodeID), - Target: gen(common.Hash{}, rand).(common.Hash), + Self: gen(enode.ID{}, rand).(enode.ID), + Target: gen(enode.ID{}, rand).(enode.ID), N: rand.Intn(bucketSize), } - for _, id := range gen([]NodeID{}, rand).([]NodeID) { - t.All = append(t.All, &Node{ID: id}) + for _, id := range gen([]enode.ID{}, rand).([]enode.ID) { + n := enode.SignNull(new(enr.Record), id) + t.All = append(t.All, wrapNode(n)) } return reflect.ValueOf(t) } -//func TestTable_Lookup(t *testing.T) { -// bucketSizeTest := 16 -// self := nodeAtDistance(common.Hash{}, 0) -// tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "", nil) -// defer tab.Close() -// -// // lookup on empty table returns no nodes -// if results := tab.Lookup(lookupTestnet.target); len(results) > 0 { -// t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) -// } -// // seed table with initial node (otherwise lookup will terminate immediately) -// seed := NewNode(lookupTestnet.dists[256][0], net.IP{}, 256, 0) -// tab.stuff([]*Node{seed}) -// -// results := tab.Lookup(lookupTestnet.target) -// t.Logf("results:") -// for _, e := range results { -// t.Logf(" ld=%d, %x", logdist(lookupTestnet.targetSha, e.sha), e.sha[:]) -// } -// if len(results) != bucketSizeTest { -// t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSizeTest) -// } -// if hasDuplicates(results) { -// t.Errorf("result set contains duplicate entries") -// } -// if !sortedByDistanceTo(lookupTestnet.targetSha, results) { -// t.Errorf("result set not sorted by 
distance to target") -// } -// // TODO: check result nodes are actually closest -//} +func TestTable_Lookup(t *testing.T) { + tab, db := newTestTable(lookupTestnet) + defer tab.Close() + defer db.Close() + + // lookup on empty table returns no nodes + if results := tab.lookup(lookupTestnet.target, false); len(results) > 0 { + t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) + } + // seed table with initial node (otherwise lookup will terminate immediately) + seedKey, _ := decodePubkey(lookupTestnet.dists[256][0]) + seed := wrapNode(enode.NewV4(seedKey, net.IP{}, 0, 256)) + tab.stuff([]*node{seed}) + + results := tab.lookup(lookupTestnet.target, true) + t.Logf("results:") + for _, e := range results { + t.Logf(" ld=%d, %x", enode.LogDist(lookupTestnet.targetSha, e.ID()), e.ID().Bytes()) + } + if len(results) != bucketSize { + t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize) + } + if hasDuplicates(results) { + t.Errorf("result set contains duplicate entries") + } + if !sortedByDistanceTo(lookupTestnet.targetSha, results) { + t.Errorf("result set not sorted by distance to target") + } + // TODO: check result nodes are actually closest +} // This is the test network for the Lookup test. // The nodes were obtained by running testnet.mine with a random NodeID as target. var lookupTestnet = &preminedTestnet{ - target: MustHexID("166aea4f556532c6d34e8b740e5d314af7e9ac0ca79833bd751d6b665f12dfd38ec563c363b32f02aef4a80b44fd3def94612d497b99cb5f17fd24de454927ec"), - targetSha: common.Hash{0x5c, 0x94, 0x4e, 0xe5, 0x1c, 0x5a, 0xe9, 0xf7, 0x2a, 0x95, 0xec, 0xcb, 0x8a, 0xed, 0x3, 0x74, 0xee, 0xcb, 0x51, 0x19, 0xd7, 0x20, 0xcb, 0xea, 0x68, 0x13, 0xe8, 0xe0, 0xd6, 0xad, 0x92, 0x61}, - dists: [257][]NodeID{ + target: hexEncPubkey("166aea4f556532c6d34e8b740e5d314af7e9ac0ca79833bd751d6b665f12dfd38ec563c363b32f02aef4a80b44fd3def94612d497b99cb5f17fd24de454927ec"), + targetSha: enode.HexID("5c944ee51c5ae9f72a95eccb8aed0374eecb5119d720cbea6813e8e0d6ad9261"), + dists: [257][]encPubkey{ 240: { - MustHexID("2001ad5e3e80c71b952161bc0186731cf5ffe942d24a79230a0555802296238e57ea7a32f5b6f18564eadc1c65389448481f8c9338df0a3dbd18f708cbc2cbcb"), - MustHexID("6ba3f4f57d084b6bf94cc4555b8c657e4a8ac7b7baf23c6874efc21dd1e4f56b7eb2721e07f5242d2f1d8381fc8cae535e860197c69236798ba1ad231b105794"), + hexEncPubkey("2001ad5e3e80c71b952161bc0186731cf5ffe942d24a79230a0555802296238e57ea7a32f5b6f18564eadc1c65389448481f8c9338df0a3dbd18f708cbc2cbcb"), + hexEncPubkey("6ba3f4f57d084b6bf94cc4555b8c657e4a8ac7b7baf23c6874efc21dd1e4f56b7eb2721e07f5242d2f1d8381fc8cae535e860197c69236798ba1ad231b105794"), }, 244: { - MustHexID("696ba1f0a9d55c59246f776600542a9e6432490f0cd78f8bb55a196918df2081a9b521c3c3ba48e465a75c10768807717f8f689b0b4adce00e1c75737552a178"), + hexEncPubkey("696ba1f0a9d55c59246f776600542a9e6432490f0cd78f8bb55a196918df2081a9b521c3c3ba48e465a75c10768807717f8f689b0b4adce00e1c75737552a178"), }, 246: { - MustHexID("d6d32178bdc38416f46ffb8b3ec9e4cb2cfff8d04dd7e4311a70e403cb62b10be1b447311b60b4f9ee221a8131fc2cbd45b96dd80deba68a949d467241facfa8"), - MustHexID("3ea3d04a43a3dfb5ac11cffc2319248cf41b6279659393c2f55b8a0a5fc9d12581a9d97ef5d8ff9b5abf3321a290e8f63a4f785f450dc8a672aba3ba2ff4fdab"), - MustHexID("2fc897f05ae585553e5c014effd3078f84f37f9333afacffb109f00ca8e7a3373de810a3946be971cbccdfd40249f9fe7f322118ea459ac71acca85a1ef8b7f4"), + hexEncPubkey("d6d32178bdc38416f46ffb8b3ec9e4cb2cfff8d04dd7e4311a70e403cb62b10be1b447311b60b4f9ee221a8131fc2cbd45b96dd80deba68a949d467241facfa8"), + 
hexEncPubkey("3ea3d04a43a3dfb5ac11cffc2319248cf41b6279659393c2f55b8a0a5fc9d12581a9d97ef5d8ff9b5abf3321a290e8f63a4f785f450dc8a672aba3ba2ff4fdab"), + hexEncPubkey("2fc897f05ae585553e5c014effd3078f84f37f9333afacffb109f00ca8e7a3373de810a3946be971cbccdfd40249f9fe7f322118ea459ac71acca85a1ef8b7f4"), }, 247: { - MustHexID("3155e1427f85f10a5c9a7755877748041af1bcd8d474ec065eb33df57a97babf54bfd2103575fa829115d224c523596b401065a97f74010610fce76382c0bf32"), - MustHexID("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d20951933beea1e4dfa6f968212385e829f04c2d314fc2d4e255e0d3bc08792b069db"), - MustHexID("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c765dd2d96126051913f44582e8c199ad7c6d6819e9a56483f637feaac9448aac"), - MustHexID("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2d47295286fc00cc081bb542d760717d1bdd6bec2c37cd72eca367d6dd3b9df73"), - MustHexID("8b58c6073dd98bbad4e310b97186c8f822d3a5c7d57af40e2136e88e315afd115edb27d2d0685a908cfe5aa49d0debdda6e6e63972691d6bd8c5af2d771dd2a9"), - MustHexID("2cbb718b7dc682da19652e7d9eb4fefaf7b7147d82c1c2b6805edf77b85e29fde9f6da195741467ff2638dc62c8d3e014ea5686693c15ed0080b6de90354c137"), - MustHexID("e84027696d3f12f2de30a9311afea8fbd313c2360daff52bb5fc8c7094d5295758bec3134e4eef24e4cdf377b40da344993284628a7a346eba94f74160998feb"), - MustHexID("f1357a4f04f9d33753a57c0b65ba20a5d8777abbffd04e906014491c9103fb08590e45548d37aa4bd70965e2e81ddba94f31860348df01469eec8c1829200a68"), - MustHexID("4ab0a75941b12892369b4490a1928c8ca52a9ad6d3dffbd1d8c0b907bc200fe74c022d011ec39b64808a39c0ca41f1d3254386c3e7733e7044c44259486461b6"), - MustHexID("d45150a72dc74388773e68e03133a3b5f51447fe91837d566706b3c035ee4b56f160c878c6273394daee7f56cc398985269052f22f75a8057df2fe6172765354"), + hexEncPubkey("3155e1427f85f10a5c9a7755877748041af1bcd8d474ec065eb33df57a97babf54bfd2103575fa829115d224c523596b401065a97f74010610fce76382c0bf32"), + hexEncPubkey("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d20951933beea1e4dfa6f968212385e829f04c2d314fc2d4e255e0d3bc08792b069db"), + hexEncPubkey("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c765dd2d96126051913f44582e8c199ad7c6d6819e9a56483f637feaac9448aac"), + hexEncPubkey("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2d47295286fc00cc081bb542d760717d1bdd6bec2c37cd72eca367d6dd3b9df73"), + hexEncPubkey("8b58c6073dd98bbad4e310b97186c8f822d3a5c7d57af40e2136e88e315afd115edb27d2d0685a908cfe5aa49d0debdda6e6e63972691d6bd8c5af2d771dd2a9"), + hexEncPubkey("2cbb718b7dc682da19652e7d9eb4fefaf7b7147d82c1c2b6805edf77b85e29fde9f6da195741467ff2638dc62c8d3e014ea5686693c15ed0080b6de90354c137"), + hexEncPubkey("e84027696d3f12f2de30a9311afea8fbd313c2360daff52bb5fc8c7094d5295758bec3134e4eef24e4cdf377b40da344993284628a7a346eba94f74160998feb"), + hexEncPubkey("f1357a4f04f9d33753a57c0b65ba20a5d8777abbffd04e906014491c9103fb08590e45548d37aa4bd70965e2e81ddba94f31860348df01469eec8c1829200a68"), + hexEncPubkey("4ab0a75941b12892369b4490a1928c8ca52a9ad6d3dffbd1d8c0b907bc200fe74c022d011ec39b64808a39c0ca41f1d3254386c3e7733e7044c44259486461b6"), + hexEncPubkey("d45150a72dc74388773e68e03133a3b5f51447fe91837d566706b3c035ee4b56f160c878c6273394daee7f56cc398985269052f22f75a8057df2fe6172765354"), }, 248: { - MustHexID("6aadfce366a189bab08ac84721567483202c86590642ea6d6a14f37ca78d82bdb6509eb7b8b2f6f63c78ae3ae1d8837c89509e41497d719b23ad53dd81574afa"), - MustHexID("a605ecfd6069a4cf4cf7f5840e5bc0ce10d23a3ac59e2aaa70c6afd5637359d2519b4524f56fc2ca180cdbebe54262f720ccaae8c1b28fd553c485675831624d"), - 
MustHexID("29701451cb9448ca33fc33680b44b840d815be90146eb521641efbffed0859c154e8892d3906eae9934bfacee72cd1d2fa9dd050fd18888eea49da155ab0efd2"), - MustHexID("3ed426322dee7572b08592e1e079f8b6c6b30e10e6243edd144a6a48fdbdb83df73a6e41b1143722cb82604f2203a32758610b5d9544f44a1a7921ba001528c1"), - MustHexID("b2e2a2b7fdd363572a3256e75435fab1da3b16f7891a8bd2015f30995dae665d7eabfd194d87d99d5df628b4bbc7b04e5b492c596422dd8272746c7a1b0b8e4f"), - MustHexID("0c69c9756162c593e85615b814ce57a2a8ca2df6c690b9c4e4602731b61e1531a3bbe3f7114271554427ffabea80ad8f36fa95a49fa77b675ae182c6ccac1728"), - MustHexID("8d28be21d5a97b0876442fa4f5e5387f5bf3faad0b6f13b8607b64d6e448c0991ca28dd7fe2f64eb8eadd7150bff5d5666aa6ed868b84c71311f4ba9a38569dd"), - MustHexID("2c677e1c64b9c9df6359348a7f5f33dc79e22f0177042486d125f8b6ca7f0dc756b1f672aceee5f1746bcff80aaf6f92a8dc0c9fbeb259b3fa0da060de5ab7e8"), - MustHexID("3994880f94a8678f0cd247a43f474a8af375d2a072128da1ad6cae84a244105ff85e94fc7d8496f639468de7ee998908a91c7e33ef7585fff92e984b210941a1"), - MustHexID("b45a9153c08d002a48090d15d61a7c7dad8c2af85d4ff5bd36ce23a9a11e0709bf8d56614c7b193bc028c16cbf7f20dfbcc751328b64a924995d47b41e452422"), - MustHexID("057ab3a9e53c7a84b0f3fc586117a525cdd18e313f52a67bf31798d48078e325abe5cfee3f6c2533230cb37d0549289d692a29dd400e899b8552d4b928f6f907"), - MustHexID("0ddf663d308791eb92e6bd88a2f8cb45e4f4f35bb16708a0e6ff7f1362aa6a73fedd0a1b1557fb3365e38e1b79d6918e2fae2788728b70c9ab6b51a3b94a4338"), - MustHexID("f637e07ff50cc1e3731735841c4798411059f2023abcf3885674f3e8032531b0edca50fd715df6feb489b6177c345374d64f4b07d257a7745de393a107b013a5"), - MustHexID("e24ec7c6eec094f63c7b3239f56d311ec5a3e45bc4e622a1095a65b95eea6fe13e29f3b6b7a2cbfe40906e3989f17ac834c3102dd0cadaaa26e16ee06d782b72"), - MustHexID("b76ea1a6fd6506ef6e3506a4f1f60ed6287fff8114af6141b2ff13e61242331b54082b023cfea5b3083354a4fb3f9eb8be01fb4a518f579e731a5d0707291a6b"), - MustHexID("9b53a37950ca8890ee349b325032d7b672cab7eced178d3060137b24ef6b92a43977922d5bdfb4a3409a2d80128e02f795f9dae6d7d99973ad0e23a2afb8442f"), + hexEncPubkey("6aadfce366a189bab08ac84721567483202c86590642ea6d6a14f37ca78d82bdb6509eb7b8b2f6f63c78ae3ae1d8837c89509e41497d719b23ad53dd81574afa"), + hexEncPubkey("a605ecfd6069a4cf4cf7f5840e5bc0ce10d23a3ac59e2aaa70c6afd5637359d2519b4524f56fc2ca180cdbebe54262f720ccaae8c1b28fd553c485675831624d"), + hexEncPubkey("29701451cb9448ca33fc33680b44b840d815be90146eb521641efbffed0859c154e8892d3906eae9934bfacee72cd1d2fa9dd050fd18888eea49da155ab0efd2"), + hexEncPubkey("3ed426322dee7572b08592e1e079f8b6c6b30e10e6243edd144a6a48fdbdb83df73a6e41b1143722cb82604f2203a32758610b5d9544f44a1a7921ba001528c1"), + hexEncPubkey("b2e2a2b7fdd363572a3256e75435fab1da3b16f7891a8bd2015f30995dae665d7eabfd194d87d99d5df628b4bbc7b04e5b492c596422dd8272746c7a1b0b8e4f"), + hexEncPubkey("0c69c9756162c593e85615b814ce57a2a8ca2df6c690b9c4e4602731b61e1531a3bbe3f7114271554427ffabea80ad8f36fa95a49fa77b675ae182c6ccac1728"), + hexEncPubkey("8d28be21d5a97b0876442fa4f5e5387f5bf3faad0b6f13b8607b64d6e448c0991ca28dd7fe2f64eb8eadd7150bff5d5666aa6ed868b84c71311f4ba9a38569dd"), + hexEncPubkey("2c677e1c64b9c9df6359348a7f5f33dc79e22f0177042486d125f8b6ca7f0dc756b1f672aceee5f1746bcff80aaf6f92a8dc0c9fbeb259b3fa0da060de5ab7e8"), + hexEncPubkey("3994880f94a8678f0cd247a43f474a8af375d2a072128da1ad6cae84a244105ff85e94fc7d8496f639468de7ee998908a91c7e33ef7585fff92e984b210941a1"), + hexEncPubkey("b45a9153c08d002a48090d15d61a7c7dad8c2af85d4ff5bd36ce23a9a11e0709bf8d56614c7b193bc028c16cbf7f20dfbcc751328b64a924995d47b41e452422"), + 
hexEncPubkey("057ab3a9e53c7a84b0f3fc586117a525cdd18e313f52a67bf31798d48078e325abe5cfee3f6c2533230cb37d0549289d692a29dd400e899b8552d4b928f6f907"), + hexEncPubkey("0ddf663d308791eb92e6bd88a2f8cb45e4f4f35bb16708a0e6ff7f1362aa6a73fedd0a1b1557fb3365e38e1b79d6918e2fae2788728b70c9ab6b51a3b94a4338"), + hexEncPubkey("f637e07ff50cc1e3731735841c4798411059f2023abcf3885674f3e8032531b0edca50fd715df6feb489b6177c345374d64f4b07d257a7745de393a107b013a5"), + hexEncPubkey("e24ec7c6eec094f63c7b3239f56d311ec5a3e45bc4e622a1095a65b95eea6fe13e29f3b6b7a2cbfe40906e3989f17ac834c3102dd0cadaaa26e16ee06d782b72"), + hexEncPubkey("b76ea1a6fd6506ef6e3506a4f1f60ed6287fff8114af6141b2ff13e61242331b54082b023cfea5b3083354a4fb3f9eb8be01fb4a518f579e731a5d0707291a6b"), + hexEncPubkey("9b53a37950ca8890ee349b325032d7b672cab7eced178d3060137b24ef6b92a43977922d5bdfb4a3409a2d80128e02f795f9dae6d7d99973ad0e23a2afb8442f"), }, 249: { - MustHexID("675ae65567c3c72c50c73bc0fd4f61f202ea5f93346ca57b551de3411ccc614fad61cb9035493af47615311b9d44ee7a161972ee4d77c28fe1ec029d01434e6a"), - MustHexID("8eb81408389da88536ae5800392b16ef5109d7ea132c18e9a82928047ecdb502693f6e4a4cdd18b54296caf561db937185731456c456c98bfe7de0baf0eaa495"), - MustHexID("2adba8b1612a541771cb93a726a38a4b88e97b18eced2593eb7daf82f05a5321ca94a72cc780c306ff21e551a932fc2c6d791e4681907b5ceab7f084c3fa2944"), - MustHexID("b1b4bfbda514d9b8f35b1c28961da5d5216fe50548f4066f69af3b7666a3b2e06eac646735e963e5c8f8138a2fb95af15b13b23ff00c6986eccc0efaa8ee6fb4"), - MustHexID("d2139281b289ad0e4d7b4243c4364f5c51aac8b60f4806135de06b12b5b369c9e43a6eb494eab860d115c15c6fbb8c5a1b0e382972e0e460af395b8385363de7"), - MustHexID("4a693df4b8fc5bdc7cec342c3ed2e228d7c5b4ab7321ddaa6cccbeb45b05a9f1d95766b4002e6d4791c2deacb8a667aadea6a700da28a3eea810a30395701bbc"), - MustHexID("ab41611195ec3c62bb8cd762ee19fb182d194fd141f4a66780efbef4b07ce916246c022b841237a3a6b512a93431157edd221e854ed2a259b72e9c5351f44d0c"), - MustHexID("68e8e26099030d10c3c703ae7045c0a48061fb88058d853b3e67880014c449d4311014da99d617d3150a20f1a3da5e34bf0f14f1c51fe4dd9d58afd222823176"), - MustHexID("3fbcacf546fb129cd70fc48de3b593ba99d3c473798bc309292aca280320e0eacc04442c914cad5c4cf6950345ba79b0d51302df88285d4e83ee3fe41339eee7"), - MustHexID("1d4a623659f7c8f80b6c3939596afdf42e78f892f682c768ad36eb7bfba402dbf97aea3a268f3badd8fe7636be216edf3d67ee1e08789ebbc7be625056bd7109"), - MustHexID("a283c474ab09da02bbc96b16317241d0627646fcc427d1fe790b76a7bf1989ced90f92101a973047ae9940c92720dffbac8eff21df8cae468a50f72f9e159417"), - MustHexID("dbf7e5ad7f87c3dfecae65d87c3039e14ed0bdc56caf00ce81931073e2e16719d746295512ff7937a15c3b03603e7c41a4f9df94fcd37bb200dd8f332767e9cb"), - MustHexID("caaa070a26692f64fc77f30d7b5ae980d419b4393a0f442b1c821ef58c0862898b0d22f74a4f8c5d83069493e3ec0b92f17dc1fe6e4cd437c1ec25039e7ce839"), - MustHexID("874cc8d1213beb65c4e0e1de38ef5d8165235893ac74ab5ea937c885eaab25c8d79dad0456e9fd3e9450626cac7e107b004478fb59842f067857f39a47cee695"), - MustHexID("d94193f236105010972f5df1b7818b55846592a0445b9cdc4eaed811b8c4c0f7c27dc8cc9837a4774656d6b34682d6d329d42b6ebb55da1d475c2474dc3dfdf4"), - MustHexID("edd9af6aded4094e9785637c28fccbd3980cbe28e2eb9a411048a23c2ace4bd6b0b7088a7817997b49a3dd05fc6929ca6c7abbb69438dbdabe65e971d2a794b2"), + hexEncPubkey("675ae65567c3c72c50c73bc0fd4f61f202ea5f93346ca57b551de3411ccc614fad61cb9035493af47615311b9d44ee7a161972ee4d77c28fe1ec029d01434e6a"), + hexEncPubkey("8eb81408389da88536ae5800392b16ef5109d7ea132c18e9a82928047ecdb502693f6e4a4cdd18b54296caf561db937185731456c456c98bfe7de0baf0eaa495"), + 
hexEncPubkey("2adba8b1612a541771cb93a726a38a4b88e97b18eced2593eb7daf82f05a5321ca94a72cc780c306ff21e551a932fc2c6d791e4681907b5ceab7f084c3fa2944"), + hexEncPubkey("b1b4bfbda514d9b8f35b1c28961da5d5216fe50548f4066f69af3b7666a3b2e06eac646735e963e5c8f8138a2fb95af15b13b23ff00c6986eccc0efaa8ee6fb4"), + hexEncPubkey("d2139281b289ad0e4d7b4243c4364f5c51aac8b60f4806135de06b12b5b369c9e43a6eb494eab860d115c15c6fbb8c5a1b0e382972e0e460af395b8385363de7"), + hexEncPubkey("4a693df4b8fc5bdc7cec342c3ed2e228d7c5b4ab7321ddaa6cccbeb45b05a9f1d95766b4002e6d4791c2deacb8a667aadea6a700da28a3eea810a30395701bbc"), + hexEncPubkey("ab41611195ec3c62bb8cd762ee19fb182d194fd141f4a66780efbef4b07ce916246c022b841237a3a6b512a93431157edd221e854ed2a259b72e9c5351f44d0c"), + hexEncPubkey("68e8e26099030d10c3c703ae7045c0a48061fb88058d853b3e67880014c449d4311014da99d617d3150a20f1a3da5e34bf0f14f1c51fe4dd9d58afd222823176"), + hexEncPubkey("3fbcacf546fb129cd70fc48de3b593ba99d3c473798bc309292aca280320e0eacc04442c914cad5c4cf6950345ba79b0d51302df88285d4e83ee3fe41339eee7"), + hexEncPubkey("1d4a623659f7c8f80b6c3939596afdf42e78f892f682c768ad36eb7bfba402dbf97aea3a268f3badd8fe7636be216edf3d67ee1e08789ebbc7be625056bd7109"), + hexEncPubkey("a283c474ab09da02bbc96b16317241d0627646fcc427d1fe790b76a7bf1989ced90f92101a973047ae9940c92720dffbac8eff21df8cae468a50f72f9e159417"), + hexEncPubkey("dbf7e5ad7f87c3dfecae65d87c3039e14ed0bdc56caf00ce81931073e2e16719d746295512ff7937a15c3b03603e7c41a4f9df94fcd37bb200dd8f332767e9cb"), + hexEncPubkey("caaa070a26692f64fc77f30d7b5ae980d419b4393a0f442b1c821ef58c0862898b0d22f74a4f8c5d83069493e3ec0b92f17dc1fe6e4cd437c1ec25039e7ce839"), + hexEncPubkey("874cc8d1213beb65c4e0e1de38ef5d8165235893ac74ab5ea937c885eaab25c8d79dad0456e9fd3e9450626cac7e107b004478fb59842f067857f39a47cee695"), + hexEncPubkey("d94193f236105010972f5df1b7818b55846592a0445b9cdc4eaed811b8c4c0f7c27dc8cc9837a4774656d6b34682d6d329d42b6ebb55da1d475c2474dc3dfdf4"), + hexEncPubkey("edd9af6aded4094e9785637c28fccbd3980cbe28e2eb9a411048a23c2ace4bd6b0b7088a7817997b49a3dd05fc6929ca6c7abbb69438dbdabe65e971d2a794b2"), }, 250: { - MustHexID("53a5bd1215d4ab709ae8fdc2ced50bba320bced78bd9c5dc92947fb402250c914891786db0978c898c058493f86fc68b1c5de8a5cb36336150ac7a88655b6c39"), - MustHexID("b7f79e3ab59f79262623c9ccefc8f01d682323aee56ffbe295437487e9d5acaf556a9c92e1f1c6a9601f2b9eb6b027ae1aeaebac71d61b9b78e88676efd3e1a3"), - MustHexID("d374bf7e8d7ffff69cc00bebff38ef5bc1dcb0a8d51c1a3d70e61ac6b2e2d6617109254b0ac224354dfbf79009fe4239e09020c483cc60c071e00b9238684f30"), - MustHexID("1e1eac1c9add703eb252eb991594f8f5a173255d526a855fab24ae57dc277e055bc3c7a7ae0b45d437c4f47a72d97eb7b126f2ba344ba6c0e14b2c6f27d4b1e6"), - MustHexID("ae28953f63d4bc4e706712a59319c111f5ff8f312584f65d7436b4cd3d14b217b958f8486bad666b4481fe879019fb1f767cf15b3e3e2711efc33b56d460448a"), - MustHexID("934bb1edf9c7a318b82306aca67feb3d6b434421fa275d694f0b4927afd8b1d3935b727fd4ff6e3d012e0c82f1824385174e8c6450ade59c2a43281a4b3446b6"), - MustHexID("9eef3f28f70ce19637519a0916555bf76d26de31312ac656cf9d3e379899ea44e4dd7ffcce923b4f3563f8a00489a34bd6936db0cbb4c959d32c49f017e07d05"), - MustHexID("82200872e8f871c48f1fad13daec6478298099b591bb3dbc4ef6890aa28ebee5860d07d70be62f4c0af85085a90ae8179ee8f937cf37915c67ea73e704b03ee7"), - MustHexID("6c75a5834a08476b7fc37ff3dc2011dc3ea3b36524bad7a6d319b18878fad813c0ba76d1f4555cacd3890c865438c21f0e0aed1f80e0a157e642124c69f43a11"), - MustHexID("995b873742206cb02b736e73a88580c2aacb0bd4a3c97a647b647bcab3f5e03c0e0736520a8b3600da09edf4248991fb01091ec7ff3ec7cdc8a1beae011e7aae"), - 
MustHexID("c773a056594b5cdef2e850d30891ff0e927c3b1b9c35cd8e8d53a1017001e237468e1ece3ae33d612ca3e6abb0a9169aa352e9dcda358e5af2ad982b577447db"), - MustHexID("2b46a5f6923f475c6be99ec6d134437a6d11f6bb4b4ac6bcd94572fa1092639d1c08aeefcb51f0912f0a060f71d4f38ee4da70ecc16010b05dd4a674aab14c3a"), - MustHexID("af6ab501366debbaa0d22e20e9688f32ef6b3b644440580fd78de4fe0e99e2a16eb5636bbae0d1c259df8ddda77b35b9a35cbc36137473e9c68fbc9d203ba842"), - MustHexID("c9f6f2dd1a941926f03f770695bda289859e85fabaf94baaae20b93e5015dc014ba41150176a36a1884adb52f405194693e63b0c464a6891cc9cc1c80d450326"), - MustHexID("5b116f0751526868a909b61a30b0c5282c37df6925cc03ddea556ef0d0602a9595fd6c14d371f8ed7d45d89918a032dcd22be4342a8793d88fdbeb3ca3d75bd7"), - MustHexID("50f3222fb6b82481c7c813b2172e1daea43e2710a443b9c2a57a12bd160dd37e20f87aa968c82ad639af6972185609d47036c0d93b4b7269b74ebd7073221c10"), + hexEncPubkey("53a5bd1215d4ab709ae8fdc2ced50bba320bced78bd9c5dc92947fb402250c914891786db0978c898c058493f86fc68b1c5de8a5cb36336150ac7a88655b6c39"), + hexEncPubkey("b7f79e3ab59f79262623c9ccefc8f01d682323aee56ffbe295437487e9d5acaf556a9c92e1f1c6a9601f2b9eb6b027ae1aeaebac71d61b9b78e88676efd3e1a3"), + hexEncPubkey("d374bf7e8d7ffff69cc00bebff38ef5bc1dcb0a8d51c1a3d70e61ac6b2e2d6617109254b0ac224354dfbf79009fe4239e09020c483cc60c071e00b9238684f30"), + hexEncPubkey("1e1eac1c9add703eb252eb991594f8f5a173255d526a855fab24ae57dc277e055bc3c7a7ae0b45d437c4f47a72d97eb7b126f2ba344ba6c0e14b2c6f27d4b1e6"), + hexEncPubkey("ae28953f63d4bc4e706712a59319c111f5ff8f312584f65d7436b4cd3d14b217b958f8486bad666b4481fe879019fb1f767cf15b3e3e2711efc33b56d460448a"), + hexEncPubkey("934bb1edf9c7a318b82306aca67feb3d6b434421fa275d694f0b4927afd8b1d3935b727fd4ff6e3d012e0c82f1824385174e8c6450ade59c2a43281a4b3446b6"), + hexEncPubkey("9eef3f28f70ce19637519a0916555bf76d26de31312ac656cf9d3e379899ea44e4dd7ffcce923b4f3563f8a00489a34bd6936db0cbb4c959d32c49f017e07d05"), + hexEncPubkey("82200872e8f871c48f1fad13daec6478298099b591bb3dbc4ef6890aa28ebee5860d07d70be62f4c0af85085a90ae8179ee8f937cf37915c67ea73e704b03ee7"), + hexEncPubkey("6c75a5834a08476b7fc37ff3dc2011dc3ea3b36524bad7a6d319b18878fad813c0ba76d1f4555cacd3890c865438c21f0e0aed1f80e0a157e642124c69f43a11"), + hexEncPubkey("995b873742206cb02b736e73a88580c2aacb0bd4a3c97a647b647bcab3f5e03c0e0736520a8b3600da09edf4248991fb01091ec7ff3ec7cdc8a1beae011e7aae"), + hexEncPubkey("c773a056594b5cdef2e850d30891ff0e927c3b1b9c35cd8e8d53a1017001e237468e1ece3ae33d612ca3e6abb0a9169aa352e9dcda358e5af2ad982b577447db"), + hexEncPubkey("2b46a5f6923f475c6be99ec6d134437a6d11f6bb4b4ac6bcd94572fa1092639d1c08aeefcb51f0912f0a060f71d4f38ee4da70ecc16010b05dd4a674aab14c3a"), + hexEncPubkey("af6ab501366debbaa0d22e20e9688f32ef6b3b644440580fd78de4fe0e99e2a16eb5636bbae0d1c259df8ddda77b35b9a35cbc36137473e9c68fbc9d203ba842"), + hexEncPubkey("c9f6f2dd1a941926f03f770695bda289859e85fabaf94baaae20b93e5015dc014ba41150176a36a1884adb52f405194693e63b0c464a6891cc9cc1c80d450326"), + hexEncPubkey("5b116f0751526868a909b61a30b0c5282c37df6925cc03ddea556ef0d0602a9595fd6c14d371f8ed7d45d89918a032dcd22be4342a8793d88fdbeb3ca3d75bd7"), + hexEncPubkey("50f3222fb6b82481c7c813b2172e1daea43e2710a443b9c2a57a12bd160dd37e20f87aa968c82ad639af6972185609d47036c0d93b4b7269b74ebd7073221c10"), }, 251: { - MustHexID("9b8f702a62d1bee67bedfeb102eca7f37fa1713e310f0d6651cc0c33ea7c5477575289ccd463e5a2574a00a676a1fdce05658ba447bb9d2827f0ba47b947e894"), - MustHexID("b97532eb83054ed054b4abdf413bb30c00e4205545c93521554dbe77faa3cfaa5bd31ef466a107b0b34a71ec97214c0c83919720142cddac93aa7a3e928d4708"), - 
MustHexID("2f7a5e952bfb67f2f90b8441b5fadc9ee13b1dcde3afeeb3dd64bf937f86663cc5c55d1fa83952b5422763c7df1b7f2794b751c6be316ebc0beb4942e65ab8c1"), - MustHexID("42c7483781727051a0b3660f14faf39e0d33de5e643702ae933837d036508ab856ce7eec8ec89c4929a4901256e5233a3d847d5d4893f91bcf21835a9a880fee"), - MustHexID("873bae27bf1dc854408fba94046a53ab0c965cebe1e4e12290806fc62b88deb1f4a47f9e18f78fc0e7913a0c6e42ac4d0fc3a20cea6bc65f0c8a0ca90b67521e"), - MustHexID("a7e3a370bbd761d413f8d209e85886f68bf73d5c3089b2dc6fa42aab1ecb5162635497eed95dee2417f3c9c74a3e76319625c48ead2e963c7de877cd4551f347"), - MustHexID("528597534776a40df2addaaea15b6ff832ce36b9748a265768368f657e76d58569d9f30dbb91e91cf0ae7efe8f402f17aa0ae15f5c55051ba03ba830287f4c42"), - MustHexID("461d8bd4f13c3c09031fdb84f104ed737a52f630261463ce0bdb5704259bab4b737dda688285b8444dbecaecad7f50f835190b38684ced5e90c54219e5adf1bc"), - MustHexID("6ec50c0be3fd232737090fc0111caaf0bb6b18f72be453428087a11a97fd6b52db0344acbf789a689bd4f5f50f79017ea784f8fd6fe723ad6ae675b9e3b13e21"), - MustHexID("12fc5e2f77a83fdcc727b79d8ae7fe6a516881138d3011847ee136b400fed7cfba1f53fd7a9730253c7aa4f39abeacd04f138417ba7fcb0f36cccc3514e0dab6"), - MustHexID("4fdbe75914ccd0bce02101606a1ccf3657ec963e3b3c20239d5fec87673fe446d649b4f15f1fe1a40e6cfbd446dda2d31d40bb602b1093b8fcd5f139ba0eb46a"), - MustHexID("3753668a0f6281e425ea69b52cb2d17ab97afbe6eb84cf5d25425bc5e53009388857640668fadd7c110721e6047c9697803bd8a6487b43bb343bfa32ebf24039"), - MustHexID("2e81b16346637dec4410fd88e527346145b9c0a849dbf2628049ac7dae016c8f4305649d5659ec77f1e8a0fac0db457b6080547226f06283598e3740ad94849a"), - MustHexID("802c3cc27f91c89213223d758f8d2ecd41135b357b6d698f24d811cdf113033a81c38e0bdff574a5c005b00a8c193dc2531f8c1fa05fa60acf0ab6f2858af09f"), - MustHexID("fcc9a2e1ac3667026ff16192876d1813bb75abdbf39b929a92863012fe8b1d890badea7a0de36274d5c1eb1e8f975785532c50d80fd44b1a4b692f437303393f"), - MustHexID("6d8b3efb461151dd4f6de809b62726f5b89e9b38e9ba1391967f61cde844f7528fecf821b74049207cee5a527096b31f3ad623928cd3ce51d926fa345a6b2951"), + hexEncPubkey("9b8f702a62d1bee67bedfeb102eca7f37fa1713e310f0d6651cc0c33ea7c5477575289ccd463e5a2574a00a676a1fdce05658ba447bb9d2827f0ba47b947e894"), + hexEncPubkey("b97532eb83054ed054b4abdf413bb30c00e4205545c93521554dbe77faa3cfaa5bd31ef466a107b0b34a71ec97214c0c83919720142cddac93aa7a3e928d4708"), + hexEncPubkey("2f7a5e952bfb67f2f90b8441b5fadc9ee13b1dcde3afeeb3dd64bf937f86663cc5c55d1fa83952b5422763c7df1b7f2794b751c6be316ebc0beb4942e65ab8c1"), + hexEncPubkey("42c7483781727051a0b3660f14faf39e0d33de5e643702ae933837d036508ab856ce7eec8ec89c4929a4901256e5233a3d847d5d4893f91bcf21835a9a880fee"), + hexEncPubkey("873bae27bf1dc854408fba94046a53ab0c965cebe1e4e12290806fc62b88deb1f4a47f9e18f78fc0e7913a0c6e42ac4d0fc3a20cea6bc65f0c8a0ca90b67521e"), + hexEncPubkey("a7e3a370bbd761d413f8d209e85886f68bf73d5c3089b2dc6fa42aab1ecb5162635497eed95dee2417f3c9c74a3e76319625c48ead2e963c7de877cd4551f347"), + hexEncPubkey("528597534776a40df2addaaea15b6ff832ce36b9748a265768368f657e76d58569d9f30dbb91e91cf0ae7efe8f402f17aa0ae15f5c55051ba03ba830287f4c42"), + hexEncPubkey("461d8bd4f13c3c09031fdb84f104ed737a52f630261463ce0bdb5704259bab4b737dda688285b8444dbecaecad7f50f835190b38684ced5e90c54219e5adf1bc"), + hexEncPubkey("6ec50c0be3fd232737090fc0111caaf0bb6b18f72be453428087a11a97fd6b52db0344acbf789a689bd4f5f50f79017ea784f8fd6fe723ad6ae675b9e3b13e21"), + hexEncPubkey("12fc5e2f77a83fdcc727b79d8ae7fe6a516881138d3011847ee136b400fed7cfba1f53fd7a9730253c7aa4f39abeacd04f138417ba7fcb0f36cccc3514e0dab6"), + 
hexEncPubkey("4fdbe75914ccd0bce02101606a1ccf3657ec963e3b3c20239d5fec87673fe446d649b4f15f1fe1a40e6cfbd446dda2d31d40bb602b1093b8fcd5f139ba0eb46a"), + hexEncPubkey("3753668a0f6281e425ea69b52cb2d17ab97afbe6eb84cf5d25425bc5e53009388857640668fadd7c110721e6047c9697803bd8a6487b43bb343bfa32ebf24039"), + hexEncPubkey("2e81b16346637dec4410fd88e527346145b9c0a849dbf2628049ac7dae016c8f4305649d5659ec77f1e8a0fac0db457b6080547226f06283598e3740ad94849a"), + hexEncPubkey("802c3cc27f91c89213223d758f8d2ecd41135b357b6d698f24d811cdf113033a81c38e0bdff574a5c005b00a8c193dc2531f8c1fa05fa60acf0ab6f2858af09f"), + hexEncPubkey("fcc9a2e1ac3667026ff16192876d1813bb75abdbf39b929a92863012fe8b1d890badea7a0de36274d5c1eb1e8f975785532c50d80fd44b1a4b692f437303393f"), + hexEncPubkey("6d8b3efb461151dd4f6de809b62726f5b89e9b38e9ba1391967f61cde844f7528fecf821b74049207cee5a527096b31f3ad623928cd3ce51d926fa345a6b2951"), }, 252: { - MustHexID("f1ae93157cc48c2075dd5868fbf523e79e06caf4b8198f352f6e526680b78ff4227263de92612f7d63472bd09367bb92a636fff16fe46ccf41614f7a72495c2a"), - MustHexID("587f482d111b239c27c0cb89b51dd5d574db8efd8de14a2e6a1400c54d4567e77c65f89c1da52841212080b91604104768350276b6682f2f961cdaf4039581c7"), - MustHexID("e3f88274d35cefdaabdf205afe0e80e936cc982b8e3e47a84ce664c413b29016a4fb4f3a3ebae0a2f79671f8323661ed462bf4390af94c424dc8ace0c301b90f"), - MustHexID("0ddc736077da9a12ba410dc5ea63cbcbe7659dd08596485b2bff3435221f82c10d263efd9af938e128464be64a178b7cd22e19f400d5802f4c9df54bf89f2619"), - MustHexID("784aa34d833c6ce63fcc1279630113c3272e82c4ae8c126c5a52a88ac461b6baeed4244e607b05dc14e5b2f41c70a273c3804dea237f14f7a1e546f6d1309d14"), - MustHexID("f253a2c354ee0e27cfcae786d726753d4ad24be6516b279a936195a487de4a59dbc296accf20463749ff55293263ed8c1b6365eecb248d44e75e9741c0d18205"), - MustHexID("a1910b80357b3ad9b4593e0628922939614dc9056a5fbf477279c8b2c1d0b4b31d89a0c09d0d41f795271d14d3360ef08a3f821e65e7e1f56c07a36afe49c7c5"), - MustHexID("f1168552c2efe541160f0909b0b4a9d6aeedcf595cdf0e9b165c97e3e197471a1ee6320e93389edfba28af6eaf10de98597ad56e7ab1b504ed762451996c3b98"), - MustHexID("b0c8e5d2c8634a7930e1a6fd082e448c6cf9d2d8b7293558b59238815a4df926c286bf297d2049f14e8296a6eb3256af614ec1812c4f2bbe807673b58bf14c8c"), - MustHexID("0fb346076396a38badc342df3679b55bd7f40a609ab103411fe45082c01f12ea016729e95914b2b5540e987ff5c9b133e85862648e7f36abdfd23100d248d234"), - MustHexID("f736e0cc83417feaa280d9483f5d4d72d1b036cd0c6d9cbdeb8ac35ceb2604780de46dddaa32a378474e1d5ccdf79b373331c30c7911ade2ae32f98832e5de1f"), - MustHexID("8b02991457602f42b38b342d3f2259ae4100c354b3843885f7e4e07bd644f64dab94bb7f38a3915f8b7f11d8e3f81c28e07a0078cf79d7397e38a7b7e0c857e2"), - MustHexID("9221d9f04a8a184993d12baa91116692bb685f887671302999d69300ad103eb2d2c75a09d8979404c6dd28f12362f58a1a43619c493d9108fd47588a23ce5824"), - MustHexID("652797801744dada833fff207d67484742eea6835d695925f3e618d71b68ec3c65bdd85b4302b2cdcb835ad3f94fd00d8da07e570b41bc0d2bcf69a8de1b3284"), - MustHexID("d84f06fe64debc4cd0625e36d19b99014b6218375262cc2209202bdbafd7dffcc4e34ce6398e182e02fd8faeed622c3e175545864902dfd3d1ac57647cddf4c6"), - MustHexID("d0ed87b294f38f1d741eb601020eeec30ac16331d05880fe27868f1e454446de367d7457b41c79e202eaf9525b029e4f1d7e17d85a55f83a557c005c68d7328a"), + hexEncPubkey("f1ae93157cc48c2075dd5868fbf523e79e06caf4b8198f352f6e526680b78ff4227263de92612f7d63472bd09367bb92a636fff16fe46ccf41614f7a72495c2a"), + hexEncPubkey("587f482d111b239c27c0cb89b51dd5d574db8efd8de14a2e6a1400c54d4567e77c65f89c1da52841212080b91604104768350276b6682f2f961cdaf4039581c7"), + 
hexEncPubkey("e3f88274d35cefdaabdf205afe0e80e936cc982b8e3e47a84ce664c413b29016a4fb4f3a3ebae0a2f79671f8323661ed462bf4390af94c424dc8ace0c301b90f"), + hexEncPubkey("0ddc736077da9a12ba410dc5ea63cbcbe7659dd08596485b2bff3435221f82c10d263efd9af938e128464be64a178b7cd22e19f400d5802f4c9df54bf89f2619"), + hexEncPubkey("784aa34d833c6ce63fcc1279630113c3272e82c4ae8c126c5a52a88ac461b6baeed4244e607b05dc14e5b2f41c70a273c3804dea237f14f7a1e546f6d1309d14"), + hexEncPubkey("f253a2c354ee0e27cfcae786d726753d4ad24be6516b279a936195a487de4a59dbc296accf20463749ff55293263ed8c1b6365eecb248d44e75e9741c0d18205"), + hexEncPubkey("a1910b80357b3ad9b4593e0628922939614dc9056a5fbf477279c8b2c1d0b4b31d89a0c09d0d41f795271d14d3360ef08a3f821e65e7e1f56c07a36afe49c7c5"), + hexEncPubkey("f1168552c2efe541160f0909b0b4a9d6aeedcf595cdf0e9b165c97e3e197471a1ee6320e93389edfba28af6eaf10de98597ad56e7ab1b504ed762451996c3b98"), + hexEncPubkey("b0c8e5d2c8634a7930e1a6fd082e448c6cf9d2d8b7293558b59238815a4df926c286bf297d2049f14e8296a6eb3256af614ec1812c4f2bbe807673b58bf14c8c"), + hexEncPubkey("0fb346076396a38badc342df3679b55bd7f40a609ab103411fe45082c01f12ea016729e95914b2b5540e987ff5c9b133e85862648e7f36abdfd23100d248d234"), + hexEncPubkey("f736e0cc83417feaa280d9483f5d4d72d1b036cd0c6d9cbdeb8ac35ceb2604780de46dddaa32a378474e1d5ccdf79b373331c30c7911ade2ae32f98832e5de1f"), + hexEncPubkey("8b02991457602f42b38b342d3f2259ae4100c354b3843885f7e4e07bd644f64dab94bb7f38a3915f8b7f11d8e3f81c28e07a0078cf79d7397e38a7b7e0c857e2"), + hexEncPubkey("9221d9f04a8a184993d12baa91116692bb685f887671302999d69300ad103eb2d2c75a09d8979404c6dd28f12362f58a1a43619c493d9108fd47588a23ce5824"), + hexEncPubkey("652797801744dada833fff207d67484742eea6835d695925f3e618d71b68ec3c65bdd85b4302b2cdcb835ad3f94fd00d8da07e570b41bc0d2bcf69a8de1b3284"), + hexEncPubkey("d84f06fe64debc4cd0625e36d19b99014b6218375262cc2209202bdbafd7dffcc4e34ce6398e182e02fd8faeed622c3e175545864902dfd3d1ac57647cddf4c6"), + hexEncPubkey("d0ed87b294f38f1d741eb601020eeec30ac16331d05880fe27868f1e454446de367d7457b41c79e202eaf9525b029e4f1d7e17d85a55f83a557c005c68d7328a"), }, 253: { - MustHexID("ad4485e386e3cc7c7310366a7c38fb810b8896c0d52e55944bfd320ca294e7912d6c53c0a0cf85e7ce226e92491d60430e86f8f15cda0161ed71893fb4a9e3a1"), - MustHexID("36d0e7e5b7734f98c6183eeeb8ac5130a85e910a925311a19c4941b1290f945d4fc3996b12ef4966960b6fa0fb29b1604f83a0f81bd5fd6398d2e1a22e46af0c"), - MustHexID("7d307d8acb4a561afa23bdf0bd945d35c90245e26345ec3a1f9f7df354222a7cdcb81339c9ed6744526c27a1a0c8d10857e98df942fa433602facac71ac68a31"), - MustHexID("d97bf55f88c83fae36232661af115d66ca600fc4bd6d1fb35ff9bb4dad674c02cf8c8d05f317525b5522250db58bb1ecafb7157392bf5aa61b178c61f098d995"), - MustHexID("7045d678f1f9eb7a4613764d17bd5698796494d0bf977b16f2dbc272b8a0f7858a60805c022fc3d1fe4f31c37e63cdaca0416c0d053ef48a815f8b19121605e0"), - MustHexID("14e1f21418d445748de2a95cd9a8c3b15b506f86a0acabd8af44bb968ce39885b19c8822af61b3dd58a34d1f265baec30e3ae56149dc7d2aa4a538f7319f69c8"), - MustHexID("b9453d78281b66a4eac95a1546017111eaaa5f92a65d0de10b1122940e92b319728a24edf4dec6acc412321b1c95266d39c7b3a5d265c629c3e49a65fb022c09"), - MustHexID("e8a49248419e3824a00d86af422f22f7366e2d4922b304b7169937616a01d9d6fa5abf5cc01061a352dc866f48e1fa2240dbb453d872b1d7be62bdfc1d5e248c"), - MustHexID("bebcff24b52362f30e0589ee573ce2d86f073d58d18e6852a592fa86ceb1a6c9b96d7fb9ec7ed1ed98a51b6743039e780279f6bb49d0a04327ac7a182d9a56f6"), - MustHexID("d0835e5a4291db249b8d2fca9f503049988180c7d247bedaa2cf3a1bad0a76709360a85d4f9a1423b2cbc82bb4d94b47c0cde20afc430224834c49fe312a9ae3"), - 
MustHexID("6b087fe2a2da5e4f0b0f4777598a4a7fb66bf77dbd5bfc44e8a7eaa432ab585a6e226891f56a7d4f5ed11a7c57b90f1661bba1059590ca4267a35801c2802913"), - MustHexID("d901e5bde52d1a0f4ddf010a686a53974cdae4ebe5c6551b3c37d6b6d635d38d5b0e5f80bc0186a2c7809dbf3a42870dd09643e68d32db896c6da8ba734579e7"), - MustHexID("96419fb80efae4b674402bb969ebaab86c1274f29a83a311e24516d36cdf148fe21754d46c97688cdd7468f24c08b13e4727c29263393638a3b37b99ff60ebca"), - MustHexID("7b9c1889ae916a5d5abcdfb0aaedcc9c6f9eb1c1a4f68d0c2d034fe79ac610ce917c3abc670744150fa891bfcd8ab14fed6983fca964de920aa393fa7b326748"), - MustHexID("7a369b2b8962cc4c65900be046482fbf7c14f98a135bbbae25152c82ad168fb2097b3d1429197cf46d3ce9fdeb64808f908a489cc6019725db040060fdfe5405"), - MustHexID("47bcae48288da5ecc7f5058dfa07cf14d89d06d6e449cb946e237aa6652ea050d9f5a24a65efdc0013ccf232bf88670979eddef249b054f63f38da9d7796dbd8"), + hexEncPubkey("ad4485e386e3cc7c7310366a7c38fb810b8896c0d52e55944bfd320ca294e7912d6c53c0a0cf85e7ce226e92491d60430e86f8f15cda0161ed71893fb4a9e3a1"), + hexEncPubkey("36d0e7e5b7734f98c6183eeeb8ac5130a85e910a925311a19c4941b1290f945d4fc3996b12ef4966960b6fa0fb29b1604f83a0f81bd5fd6398d2e1a22e46af0c"), + hexEncPubkey("7d307d8acb4a561afa23bdf0bd945d35c90245e26345ec3a1f9f7df354222a7cdcb81339c9ed6744526c27a1a0c8d10857e98df942fa433602facac71ac68a31"), + hexEncPubkey("d97bf55f88c83fae36232661af115d66ca600fc4bd6d1fb35ff9bb4dad674c02cf8c8d05f317525b5522250db58bb1ecafb7157392bf5aa61b178c61f098d995"), + hexEncPubkey("7045d678f1f9eb7a4613764d17bd5698796494d0bf977b16f2dbc272b8a0f7858a60805c022fc3d1fe4f31c37e63cdaca0416c0d053ef48a815f8b19121605e0"), + hexEncPubkey("14e1f21418d445748de2a95cd9a8c3b15b506f86a0acabd8af44bb968ce39885b19c8822af61b3dd58a34d1f265baec30e3ae56149dc7d2aa4a538f7319f69c8"), + hexEncPubkey("b9453d78281b66a4eac95a1546017111eaaa5f92a65d0de10b1122940e92b319728a24edf4dec6acc412321b1c95266d39c7b3a5d265c629c3e49a65fb022c09"), + hexEncPubkey("e8a49248419e3824a00d86af422f22f7366e2d4922b304b7169937616a01d9d6fa5abf5cc01061a352dc866f48e1fa2240dbb453d872b1d7be62bdfc1d5e248c"), + hexEncPubkey("bebcff24b52362f30e0589ee573ce2d86f073d58d18e6852a592fa86ceb1a6c9b96d7fb9ec7ed1ed98a51b6743039e780279f6bb49d0a04327ac7a182d9a56f6"), + hexEncPubkey("d0835e5a4291db249b8d2fca9f503049988180c7d247bedaa2cf3a1bad0a76709360a85d4f9a1423b2cbc82bb4d94b47c0cde20afc430224834c49fe312a9ae3"), + hexEncPubkey("6b087fe2a2da5e4f0b0f4777598a4a7fb66bf77dbd5bfc44e8a7eaa432ab585a6e226891f56a7d4f5ed11a7c57b90f1661bba1059590ca4267a35801c2802913"), + hexEncPubkey("d901e5bde52d1a0f4ddf010a686a53974cdae4ebe5c6551b3c37d6b6d635d38d5b0e5f80bc0186a2c7809dbf3a42870dd09643e68d32db896c6da8ba734579e7"), + hexEncPubkey("96419fb80efae4b674402bb969ebaab86c1274f29a83a311e24516d36cdf148fe21754d46c97688cdd7468f24c08b13e4727c29263393638a3b37b99ff60ebca"), + hexEncPubkey("7b9c1889ae916a5d5abcdfb0aaedcc9c6f9eb1c1a4f68d0c2d034fe79ac610ce917c3abc670744150fa891bfcd8ab14fed6983fca964de920aa393fa7b326748"), + hexEncPubkey("7a369b2b8962cc4c65900be046482fbf7c14f98a135bbbae25152c82ad168fb2097b3d1429197cf46d3ce9fdeb64808f908a489cc6019725db040060fdfe5405"), + hexEncPubkey("47bcae48288da5ecc7f5058dfa07cf14d89d06d6e449cb946e237aa6652ea050d9f5a24a65efdc0013ccf232bf88670979eddef249b054f63f38da9d7796dbd8"), }, 254: { - MustHexID("099739d7abc8abd38ecc7a816c521a1168a4dbd359fa7212a5123ab583ffa1cf485a5fed219575d6475dbcdd541638b2d3631a6c7fce7474e7fe3cba1d4d5853"), - MustHexID("c2b01603b088a7182d0cf7ef29fb2b04c70acb320fccf78526bf9472e10c74ee70b3fcfa6f4b11d167bd7d3bc4d936b660f2c9bff934793d97cb21750e7c3d31"), - 
MustHexID("20e4d8f45f2f863e94b45548c1ef22a11f7d36f263e4f8623761e05a64c4572379b000a52211751e2561b0f14f4fc92dd4130410c8ccc71eb4f0e95a700d4ca9"), - MustHexID("27f4a16cc085e72d86e25c98bd2eca173eaaee7565c78ec5a52e9e12b2211f35de81b5b45e9195de2ebfe29106742c59112b951a04eb7ae48822911fc1f9389e"), - MustHexID("55db5ee7d98e7f0b1c3b9d5be6f2bc619a1b86c3cdd513160ad4dcf267037a5fffad527ac15d50aeb32c59c13d1d4c1e567ebbf4de0d25236130c8361f9aac63"), - MustHexID("883df308b0130fc928a8559fe50667a0fff80493bc09685d18213b2db241a3ad11310ed86b0ef662b3ce21fc3d9aa7f3fc24b8d9afe17c7407e9afd3345ae548"), - MustHexID("c7af968cc9bc8200c3ee1a387405f7563be1dce6710a3439f42ea40657d0eae9d2b3c16c42d779605351fcdece4da637b9804e60ca08cfb89aec32c197beffa6"), - MustHexID("3e66f2b788e3ff1d04106b80597915cd7afa06c405a7ae026556b6e583dca8e05cfbab5039bb9a1b5d06083ffe8de5780b1775550e7218f5e98624bf7af9a0a8"), - MustHexID("4fc7f53764de3337fdaec0a711d35d3a923e72fa65025444d12230b3552ed43d9b2d1ad08ccb11f2d50c58809e6dd74dde910e195294fca3b47ae5a3967cc479"), - MustHexID("bafdfdcf6ccaa989436752fa97c77477b6baa7deb374b16c095492c529eb133e8e2f99e1977012b64767b9d34b2cf6d2048ed489bd822b5139b523f6a423167b"), - MustHexID("7f5d78008a4312fe059104ce80202c82b8915c2eb4411c6b812b16f7642e57c00f2c9425121f5cbac4257fe0b3e81ef5dea97ea2dbaa98f6a8b6fd4d1e5980bb"), - MustHexID("598c37fe78f922751a052f463aeb0cb0bc7f52b7c2a4cf2da72ec0931c7c32175d4165d0f8998f7320e87324ac3311c03f9382a5385c55f0407b7a66b2acd864"), - MustHexID("f758c4136e1c148777a7f3275a76e2db0b2b04066fd738554ec398c1c6cc9fb47e14a3b4c87bd47deaeab3ffd2110514c3855685a374794daff87b605b27ee2e"), - MustHexID("0307bb9e4fd865a49dcf1fe4333d1b944547db650ab580af0b33e53c4fef6c789531110fac801bbcbce21fc4d6f61b6d5b24abdf5b22e3030646d579f6dca9c2"), - MustHexID("82504b6eb49bb2c0f91a7006ce9cefdbaf6df38706198502c2e06601091fc9dc91e4f15db3410d45c6af355bc270b0f268d3dff560f956985c7332d4b10bd1ed"), - MustHexID("b39b5b677b45944ceebe76e76d1f051de2f2a0ec7b0d650da52135743e66a9a5dba45f638258f9a7545d9a790c7fe6d3fdf82c25425c7887323e45d27d06c057"), + hexEncPubkey("099739d7abc8abd38ecc7a816c521a1168a4dbd359fa7212a5123ab583ffa1cf485a5fed219575d6475dbcdd541638b2d3631a6c7fce7474e7fe3cba1d4d5853"), + hexEncPubkey("c2b01603b088a7182d0cf7ef29fb2b04c70acb320fccf78526bf9472e10c74ee70b3fcfa6f4b11d167bd7d3bc4d936b660f2c9bff934793d97cb21750e7c3d31"), + hexEncPubkey("20e4d8f45f2f863e94b45548c1ef22a11f7d36f263e4f8623761e05a64c4572379b000a52211751e2561b0f14f4fc92dd4130410c8ccc71eb4f0e95a700d4ca9"), + hexEncPubkey("27f4a16cc085e72d86e25c98bd2eca173eaaee7565c78ec5a52e9e12b2211f35de81b5b45e9195de2ebfe29106742c59112b951a04eb7ae48822911fc1f9389e"), + hexEncPubkey("55db5ee7d98e7f0b1c3b9d5be6f2bc619a1b86c3cdd513160ad4dcf267037a5fffad527ac15d50aeb32c59c13d1d4c1e567ebbf4de0d25236130c8361f9aac63"), + hexEncPubkey("883df308b0130fc928a8559fe50667a0fff80493bc09685d18213b2db241a3ad11310ed86b0ef662b3ce21fc3d9aa7f3fc24b8d9afe17c7407e9afd3345ae548"), + hexEncPubkey("c7af968cc9bc8200c3ee1a387405f7563be1dce6710a3439f42ea40657d0eae9d2b3c16c42d779605351fcdece4da637b9804e60ca08cfb89aec32c197beffa6"), + hexEncPubkey("3e66f2b788e3ff1d04106b80597915cd7afa06c405a7ae026556b6e583dca8e05cfbab5039bb9a1b5d06083ffe8de5780b1775550e7218f5e98624bf7af9a0a8"), + hexEncPubkey("4fc7f53764de3337fdaec0a711d35d3a923e72fa65025444d12230b3552ed43d9b2d1ad08ccb11f2d50c58809e6dd74dde910e195294fca3b47ae5a3967cc479"), + hexEncPubkey("bafdfdcf6ccaa989436752fa97c77477b6baa7deb374b16c095492c529eb133e8e2f99e1977012b64767b9d34b2cf6d2048ed489bd822b5139b523f6a423167b"), + 
hexEncPubkey("7f5d78008a4312fe059104ce80202c82b8915c2eb4411c6b812b16f7642e57c00f2c9425121f5cbac4257fe0b3e81ef5dea97ea2dbaa98f6a8b6fd4d1e5980bb"), + hexEncPubkey("598c37fe78f922751a052f463aeb0cb0bc7f52b7c2a4cf2da72ec0931c7c32175d4165d0f8998f7320e87324ac3311c03f9382a5385c55f0407b7a66b2acd864"), + hexEncPubkey("f758c4136e1c148777a7f3275a76e2db0b2b04066fd738554ec398c1c6cc9fb47e14a3b4c87bd47deaeab3ffd2110514c3855685a374794daff87b605b27ee2e"), + hexEncPubkey("0307bb9e4fd865a49dcf1fe4333d1b944547db650ab580af0b33e53c4fef6c789531110fac801bbcbce21fc4d6f61b6d5b24abdf5b22e3030646d579f6dca9c2"), + hexEncPubkey("82504b6eb49bb2c0f91a7006ce9cefdbaf6df38706198502c2e06601091fc9dc91e4f15db3410d45c6af355bc270b0f268d3dff560f956985c7332d4b10bd1ed"), + hexEncPubkey("b39b5b677b45944ceebe76e76d1f051de2f2a0ec7b0d650da52135743e66a9a5dba45f638258f9a7545d9a790c7fe6d3fdf82c25425c7887323e45d27d06c057"), }, 255: { - MustHexID("5c4d58d46e055dd1f093f81ee60a675e1f02f54da6206720adee4dccef9b67a31efc5c2a2949c31a04ee31beadc79aba10da31440a1f9ff2a24093c63c36d784"), - MustHexID("ea72161ffdd4b1e124c7b93b0684805f4c4b58d617ed498b37a145c670dbc2e04976f8785583d9c805ffbf343c31d492d79f841652bbbd01b61ed85640b23495"), - MustHexID("51caa1d93352d47a8e531692a3612adac1e8ac68d0a200d086c1c57ae1e1a91aa285ab242e8c52ef9d7afe374c9485b122ae815f1707b875569d0433c1c3ce85"), - MustHexID("c08397d5751b47bd3da044b908be0fb0e510d3149574dff7aeab33749b023bb171b5769990fe17469dbebc100bc150e798aeda426a2dcc766699a225fddd75c6"), - MustHexID("0222c1c194b749736e593f937fad67ee348ac57287a15c7e42877aa38a9b87732a408bca370f812efd0eedbff13e6d5b854bf3ba1dec431a796ed47f32552b09"), - MustHexID("03d859cd46ef02d9bfad5268461a6955426845eef4126de6be0fa4e8d7e0727ba2385b78f1a883a8239e95ebb814f2af8379632c7d5b100688eebc5841209582"), - MustHexID("64d5004b7e043c39ff0bd10cb20094c287721d5251715884c280a612b494b3e9e1c64ba6f67614994c7d969a0d0c0295d107d53fc225d47c44c4b82852d6f960"), - MustHexID("b0a5eefb2dab6f786670f35bf9641eefe6dd87fd3f1362bcab4aaa792903500ab23d88fae68411372e0813b057535a601d46e454323745a948017f6063a47b1f"), - MustHexID("0cc6df0a3433d448b5684d2a3ffa9d1a825388177a18f44ad0008c7bd7702f1ec0fc38b83506f7de689c3b6ecb552599927e29699eed6bb867ff08f80068b287"), - MustHexID("50772f7b8c03a4e153355fbbf79c8a80cf32af656ff0c7873c99911099d04a0dae0674706c357e0145ad017a0ade65e6052cb1b0d574fcd6f67da3eee0ace66b"), - MustHexID("1ae37829c9ef41f8b508b82259ebac76b1ed900d7a45c08b7970f25d2d48ddd1829e2f11423a18749940b6dab8598c6e416cef0efd47e46e51f29a0bc65b37cd"), - MustHexID("ba973cab31c2af091fc1644a93527d62b2394999e2b6ccbf158dd5ab9796a43d408786f1803ef4e29debfeb62fce2b6caa5ab2b24d1549c822a11c40c2856665"), - MustHexID("bc413ad270dd6ea25bddba78f3298b03b8ba6f8608ac03d06007d4116fa78ef5a0cfe8c80155089382fc7a193243ee5500082660cb5d7793f60f2d7d18650964"), - MustHexID("5a6a9ef07634d9eec3baa87c997b529b92652afa11473dfee41ef7037d5c06e0ddb9fe842364462d79dd31cff8a59a1b8d5bc2b810dea1d4cbbd3beb80ecec83"), - MustHexID("f492c6ee2696d5f682f7f537757e52744c2ae560f1090a07024609e903d334e9e174fc01609c5a229ddbcac36c9d21adaf6457dab38a25bfd44f2f0ee4277998"), - MustHexID("459e4db99298cb0467a90acee6888b08bb857450deac11015cced5104853be5adce5b69c740968bc7f931495d671a70cad9f48546d7cd203357fe9af0e8d2164"), + hexEncPubkey("5c4d58d46e055dd1f093f81ee60a675e1f02f54da6206720adee4dccef9b67a31efc5c2a2949c31a04ee31beadc79aba10da31440a1f9ff2a24093c63c36d784"), + hexEncPubkey("ea72161ffdd4b1e124c7b93b0684805f4c4b58d617ed498b37a145c670dbc2e04976f8785583d9c805ffbf343c31d492d79f841652bbbd01b61ed85640b23495"), + 
hexEncPubkey("51caa1d93352d47a8e531692a3612adac1e8ac68d0a200d086c1c57ae1e1a91aa285ab242e8c52ef9d7afe374c9485b122ae815f1707b875569d0433c1c3ce85"), + hexEncPubkey("c08397d5751b47bd3da044b908be0fb0e510d3149574dff7aeab33749b023bb171b5769990fe17469dbebc100bc150e798aeda426a2dcc766699a225fddd75c6"), + hexEncPubkey("0222c1c194b749736e593f937fad67ee348ac57287a15c7e42877aa38a9b87732a408bca370f812efd0eedbff13e6d5b854bf3ba1dec431a796ed47f32552b09"), + hexEncPubkey("03d859cd46ef02d9bfad5268461a6955426845eef4126de6be0fa4e8d7e0727ba2385b78f1a883a8239e95ebb814f2af8379632c7d5b100688eebc5841209582"), + hexEncPubkey("64d5004b7e043c39ff0bd10cb20094c287721d5251715884c280a612b494b3e9e1c64ba6f67614994c7d969a0d0c0295d107d53fc225d47c44c4b82852d6f960"), + hexEncPubkey("b0a5eefb2dab6f786670f35bf9641eefe6dd87fd3f1362bcab4aaa792903500ab23d88fae68411372e0813b057535a601d46e454323745a948017f6063a47b1f"), + hexEncPubkey("0cc6df0a3433d448b5684d2a3ffa9d1a825388177a18f44ad0008c7bd7702f1ec0fc38b83506f7de689c3b6ecb552599927e29699eed6bb867ff08f80068b287"), + hexEncPubkey("50772f7b8c03a4e153355fbbf79c8a80cf32af656ff0c7873c99911099d04a0dae0674706c357e0145ad017a0ade65e6052cb1b0d574fcd6f67da3eee0ace66b"), + hexEncPubkey("1ae37829c9ef41f8b508b82259ebac76b1ed900d7a45c08b7970f25d2d48ddd1829e2f11423a18749940b6dab8598c6e416cef0efd47e46e51f29a0bc65b37cd"), + hexEncPubkey("ba973cab31c2af091fc1644a93527d62b2394999e2b6ccbf158dd5ab9796a43d408786f1803ef4e29debfeb62fce2b6caa5ab2b24d1549c822a11c40c2856665"), + hexEncPubkey("bc413ad270dd6ea25bddba78f3298b03b8ba6f8608ac03d06007d4116fa78ef5a0cfe8c80155089382fc7a193243ee5500082660cb5d7793f60f2d7d18650964"), + hexEncPubkey("5a6a9ef07634d9eec3baa87c997b529b92652afa11473dfee41ef7037d5c06e0ddb9fe842364462d79dd31cff8a59a1b8d5bc2b810dea1d4cbbd3beb80ecec83"), + hexEncPubkey("f492c6ee2696d5f682f7f537757e52744c2ae560f1090a07024609e903d334e9e174fc01609c5a229ddbcac36c9d21adaf6457dab38a25bfd44f2f0ee4277998"), + hexEncPubkey("459e4db99298cb0467a90acee6888b08bb857450deac11015cced5104853be5adce5b69c740968bc7f931495d671a70cad9f48546d7cd203357fe9af0e8d2164"), }, 256: { - MustHexID("a8593af8a4aef7b806b5197612017951bac8845a1917ca9a6a15dd6086d608505144990b245785c4cd2d67a295701c7aac2aa18823fb0033987284b019656268"), - MustHexID("d2eebef914928c3aad77fc1b2a495f52d2294acf5edaa7d8a530b540f094b861a68fe8348a46a7c302f08ab609d85912a4968eacfea0740847b29421b4795d9e"), - MustHexID("b14bfcb31495f32b650b63cf7d08492e3e29071fdc73cf2da0da48d4b191a70ba1a65f42ad8c343206101f00f8a48e8db4b08bf3f622c0853e7323b250835b91"), - MustHexID("7feaee0d818c03eb30e4e0bf03ade0f3c21ca38e938a761aa1781cf70bda8cc5cd631a6cc53dd44f1d4a6d3e2dae6513c6c66ee50cb2f0e9ad6f7e319b309fd9"), - MustHexID("4ca3b657b139311db8d583c25dd5963005e46689e1317620496cc64129c7f3e52870820e0ec7941d28809311df6db8a2867bbd4f235b4248af24d7a9c22d1232"), - MustHexID("1181defb1d16851d42dd951d84424d6bd1479137f587fa184d5a8152be6b6b16ed08bcdb2c2ed8539bcde98c80c432875f9f724737c316a2bd385a39d3cab1d8"), - MustHexID("d9dd818769fa0c3ec9f553c759b92476f082817252a04a47dc1777740b1731d280058c66f982812f173a294acf4944a85ba08346e2de153ba3ba41ce8a62cb64"), - MustHexID("bd7c4f8a9e770aa915c771b15e107ca123d838762da0d3ffc53aa6b53e9cd076cffc534ec4d2e4c334c683f1f5ea72e0e123f6c261915ed5b58ac1b59f003d88"), - MustHexID("3dd5739c73649d510456a70e9d6b46a855864a4a3f744e088fd8c8da11b18e4c9b5f2d7da50b1c147b2bae5ca9609ae01f7a3cdea9dce34f80a91d29cd82f918"), - MustHexID("f0d7df1efc439b4bcc0b762118c1cfa99b2a6143a9f4b10e3c9465125f4c9fca4ab88a2504169bbcad65492cf2f50da9dd5d077c39574a944f94d8246529066b"), - 
MustHexID("dd598b9ba441448e5fb1a6ec6c5f5aa9605bad6e223297c729b1705d11d05f6bfd3d41988b694681ae69bb03b9a08bff4beab5596503d12a39bffb5cd6e94c7c"), - MustHexID("3fce284ac97e567aebae681b15b7a2b6df9d873945536335883e4bbc26460c064370537f323fd1ada828ea43154992d14ac0cec0940a2bd2a3f42ec156d60c83"), - MustHexID("7c8dfa8c1311cb14fb29a8ac11bca23ecc115e56d9fcf7b7ac1db9066aa4eb39f8b1dabf46e192a65be95ebfb4e839b5ab4533fef414921825e996b210dd53bd"), - MustHexID("cafa6934f82120456620573d7f801390ed5e16ed619613a37e409e44ab355ef755e83565a913b48a9466db786f8d4fbd590bfec474c2524d4a2608d4eafd6abd"), - MustHexID("9d16600d0dd310d77045769fed2cb427f32db88cd57d86e49390c2ba8a9698cfa856f775be2013237226e7bf47b248871cf865d23015937d1edeb20db5e3e760"), - MustHexID("17be6b6ba54199b1d80eff866d348ea11d8a4b341d63ad9a6681d3ef8a43853ac564d153eb2a8737f0afc9ab320f6f95c55aa11aaa13bbb1ff422fd16bdf8188"), + hexEncPubkey("a8593af8a4aef7b806b5197612017951bac8845a1917ca9a6a15dd6086d608505144990b245785c4cd2d67a295701c7aac2aa18823fb0033987284b019656268"), + hexEncPubkey("d2eebef914928c3aad77fc1b2a495f52d2294acf5edaa7d8a530b540f094b861a68fe8348a46a7c302f08ab609d85912a4968eacfea0740847b29421b4795d9e"), + hexEncPubkey("b14bfcb31495f32b650b63cf7d08492e3e29071fdc73cf2da0da48d4b191a70ba1a65f42ad8c343206101f00f8a48e8db4b08bf3f622c0853e7323b250835b91"), + hexEncPubkey("7feaee0d818c03eb30e4e0bf03ade0f3c21ca38e938a761aa1781cf70bda8cc5cd631a6cc53dd44f1d4a6d3e2dae6513c6c66ee50cb2f0e9ad6f7e319b309fd9"), + hexEncPubkey("4ca3b657b139311db8d583c25dd5963005e46689e1317620496cc64129c7f3e52870820e0ec7941d28809311df6db8a2867bbd4f235b4248af24d7a9c22d1232"), + hexEncPubkey("1181defb1d16851d42dd951d84424d6bd1479137f587fa184d5a8152be6b6b16ed08bcdb2c2ed8539bcde98c80c432875f9f724737c316a2bd385a39d3cab1d8"), + hexEncPubkey("d9dd818769fa0c3ec9f553c759b92476f082817252a04a47dc1777740b1731d280058c66f982812f173a294acf4944a85ba08346e2de153ba3ba41ce8a62cb64"), + hexEncPubkey("bd7c4f8a9e770aa915c771b15e107ca123d838762da0d3ffc53aa6b53e9cd076cffc534ec4d2e4c334c683f1f5ea72e0e123f6c261915ed5b58ac1b59f003d88"), + hexEncPubkey("3dd5739c73649d510456a70e9d6b46a855864a4a3f744e088fd8c8da11b18e4c9b5f2d7da50b1c147b2bae5ca9609ae01f7a3cdea9dce34f80a91d29cd82f918"), + hexEncPubkey("f0d7df1efc439b4bcc0b762118c1cfa99b2a6143a9f4b10e3c9465125f4c9fca4ab88a2504169bbcad65492cf2f50da9dd5d077c39574a944f94d8246529066b"), + hexEncPubkey("dd598b9ba441448e5fb1a6ec6c5f5aa9605bad6e223297c729b1705d11d05f6bfd3d41988b694681ae69bb03b9a08bff4beab5596503d12a39bffb5cd6e94c7c"), + hexEncPubkey("3fce284ac97e567aebae681b15b7a2b6df9d873945536335883e4bbc26460c064370537f323fd1ada828ea43154992d14ac0cec0940a2bd2a3f42ec156d60c83"), + hexEncPubkey("7c8dfa8c1311cb14fb29a8ac11bca23ecc115e56d9fcf7b7ac1db9066aa4eb39f8b1dabf46e192a65be95ebfb4e839b5ab4533fef414921825e996b210dd53bd"), + hexEncPubkey("cafa6934f82120456620573d7f801390ed5e16ed619613a37e409e44ab355ef755e83565a913b48a9466db786f8d4fbd590bfec474c2524d4a2608d4eafd6abd"), + hexEncPubkey("9d16600d0dd310d77045769fed2cb427f32db88cd57d86e49390c2ba8a9698cfa856f775be2013237226e7bf47b248871cf865d23015937d1edeb20db5e3e760"), + hexEncPubkey("17be6b6ba54199b1d80eff866d348ea11d8a4b341d63ad9a6681d3ef8a43853ac564d153eb2a8737f0afc9ab320f6f95c55aa11aaa13bbb1ff422fd16bdf8188"), }, }, } type preminedTestnet struct { - target NodeID - targetSha common.Hash // sha3(target) - dists [hashBits + 1][]NodeID + target encPubkey + targetSha enode.ID // sha3(target) + dists [hashBits + 1][]encPubkey } -func (tn *preminedTestnet) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, 
error) { +func (tn *preminedTestnet) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { // current log distance is encoded in port number // fmt.Println("findnode query at dist", toaddr.Port) if toaddr.Port == 0 { panic("query to node at distance 0") } - next := uint16(toaddr.Port) - 1 - var result []*Node - for i, id := range tn.dists[toaddr.Port] { - result = append(result, NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i))) + next := toaddr.Port - 1 + var result []*node + for i, ekey := range tn.dists[toaddr.Port] { + key, _ := decodePubkey(ekey) + node := wrapNode(enode.NewV4(key, net.ParseIP("127.0.0.1"), i, next)) + result = append(result, node) } return result, nil } -func (*preminedTestnet) close() {} -func (*preminedTestnet) waitping(from NodeID) error { return nil } -func (*preminedTestnet) ping(toid NodeID, toaddr *net.UDPAddr) error { return nil } +func (*preminedTestnet) close() {} +func (*preminedTestnet) waitping(from enode.ID) error { return nil } +func (*preminedTestnet) ping(toid enode.ID, toaddr *net.UDPAddr) error { return nil } // mine generates a testnet struct literal with nodes at // various distances to the given target. -func (n *preminedTestnet) mine(target NodeID) { - n.target = target - n.targetSha = crypto.Keccak256Hash(n.target[:]) +func (tn *preminedTestnet) mine(target encPubkey) { + tn.target = target + tn.targetSha = tn.target.id() found := 0 for found < bucketSize*10 { k := newkey() - id := PubkeyID(&k.PublicKey) - sha := crypto.Keccak256Hash(id[:]) - ld := logdist(n.targetSha, sha) - if len(n.dists[ld]) < bucketSize { - n.dists[ld] = append(n.dists[ld], id) + key := encodePubkey(&k.PublicKey) + ld := enode.LogDist(tn.targetSha, key.id()) + if len(tn.dists[ld]) < bucketSize { + tn.dists[ld] = append(tn.dists[ld], key) fmt.Println("found ID with ld", ld) found++ } } fmt.Println("&preminedTestnet{") - fmt.Printf(" target: %#v,\n", n.target) - fmt.Printf(" targetSha: %#v,\n", n.targetSha) - fmt.Printf(" dists: [%d][]NodeID{\n", len(n.dists)) - for ld, ns := range n.dists { + fmt.Printf(" target: %#v,\n", tn.target) + fmt.Printf(" targetSha: %#v,\n", tn.targetSha) + fmt.Printf(" dists: [%d][]encPubkey{\n", len(tn.dists)) + for ld, ns := range tn.dists { if len(ns) == 0 { continue } - fmt.Printf(" %d: []NodeID{\n", ld) + fmt.Printf(" %d: []encPubkey{\n", ld) for _, n := range ns { - fmt.Printf(" MustHexID(\"%x\"),\n", n[:]) + fmt.Printf(" hexEncPubkey(\"%x\"),\n", n[:]) } fmt.Println(" },") } @@ -616,40 +564,6 @@ func (n *preminedTestnet) mine(target NodeID) { fmt.Println("}") } -func hasDuplicates(slice []*Node) bool { - seen := make(map[NodeID]bool) - for i, e := range slice { - if e == nil { - panic(fmt.Sprintf("nil *Node at %d", i)) - } - if seen[e.ID] { - return true - } - seen[e.ID] = true - } - return false -} - -func sortedByDistanceTo(distbase common.Hash, slice []*Node) bool { - var last common.Hash - for i, e := range slice { - if i > 0 && distcmp(distbase, e.sha, last) < 0 { - return false - } - last = e.sha - } - return true -} - -func contains(ns []*Node, id NodeID) bool { - for _, n := range ns { - if n.ID == id { - return true - } - } - return false -} - // gen wraps quick.Value so it's easier to use. // it generates a random value of the given value's type. 
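The premined testnet above keys its dists table by log distance: enode.LogDist(a, b) is 256 minus the number of leading bits the two IDs share, and the query port doubles as that distance. A standalone illustration (not part of the patch; it assumes the tomochain enode package exposes LogDist with go-ethereum's signature):

	package main

	import (
		"fmt"

		"github.com/tomochain/tomochain/p2p/enode"
	)

	func main() {
		var a, b, c enode.ID
		b[31] = 0x01 // differs from a only in the lowest bit
		c[0] = 0x80  // differs from a in the highest bit

		fmt.Println(enode.LogDist(a, a)) // 0: identical IDs
		fmt.Println(enode.LogDist(a, b)) // 1: 255 shared leading bits
		fmt.Println(enode.LogDist(a, c)) // 256: no shared leading bits
	}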
func gen(typ interface{}, rand *rand.Rand) interface{} { @@ -660,6 +574,13 @@ func gen(typ interface{}, rand *rand.Rand) interface{} { return v.Interface() } +func quickcfg() *quick.Config { + return &quick.Config{ + MaxCount: 5000, + Rand: rand.New(rand.NewSource(time.Now().Unix())), + } +} + func newkey() *ecdsa.PrivateKey { key, err := crypto.GenerateKey() if err != nil { diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go new file mode 100644 index 0000000000..fe55eb5627 --- /dev/null +++ b/p2p/discover/table_util_test.go @@ -0,0 +1,167 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "crypto/ecdsa" + "encoding/hex" + "fmt" + "math/rand" + "net" + "sync" + + "github.com/tomochain/tomochain/p2p/enode" + "github.com/tomochain/tomochain/p2p/enr" +) + +func newTestTable(t transport) (*Table, *enode.DB) { + var r enr.Record + r.Set(enr.IP{0, 0, 0, 0}) + n := enode.SignNull(&r, enode.ID{}) + db, _ := enode.OpenDB("") + tab, _ := newTable(t, n, db, nil) + return tab, db +} + +// nodeAtDistance creates a node for which enode.LogDist(base, n.id) == ld. +func nodeAtDistance(base enode.ID, ld int, ip net.IP) *node { + var r enr.Record + r.Set(enr.IP(ip)) + return wrapNode(enode.SignNull(&r, idAtDistance(base, ld))) +} + +// idAtDistance returns a random hash such that enode.LogDist(a, b) == n +func idAtDistance(a enode.ID, n int) (b enode.ID) { + if n == 0 { + return a + } + // flip bit at position n, fill the rest with random bits + b = a + pos := len(a) - n/8 - 1 + bit := byte(0x01) << (byte(n%8) - 1) + if bit == 0 { + pos++ + bit = 0x80 + } + b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits + for i := pos + 1; i < len(a); i++ { + b[i] = byte(rand.Intn(255)) + } + return b +} + +func intIP(i int) net.IP { + return net.IP{byte(i), 0, 2, byte(i)} +} + +// fillBucket inserts nodes into the given bucket until it is full. 
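idAtDistance above pairs with enode.LogDist: flipping the bit at position n and randomizing everything below it yields an ID at exactly log distance n. A hypothetical in-package check of that invariant (not part of the patch; assumes the standard testing import) might read:

	func TestIDAtDistanceInvariant(t *testing.T) {
		var a enode.ID
		// n == 0 must return a itself; every other n must land
		// exactly at log distance n.
		for n := 0; n <= 256; n++ {
			if got := enode.LogDist(a, idAtDistance(a, n)); got != n {
				t.Fatalf("n=%d: LogDist = %d", n, got)
			}
		}
	}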
+func fillBucket(tab *Table, n *node) (last *node) { + ld := enode.LogDist(tab.self.ID(), n.ID()) + b := tab.bucket(n.ID()) + for len(b.entries) < bucketSize { + b.entries = append(b.entries, nodeAtDistance(tab.self.ID(), ld, intIP(ld))) + } + return b.entries[bucketSize-1] +} + +type pingRecorder struct { + mu sync.Mutex + dead, pinged map[enode.ID]bool +} + +func newPingRecorder() *pingRecorder { + return &pingRecorder{ + dead: make(map[enode.ID]bool), + pinged: make(map[enode.ID]bool), + } +} + +func (t *pingRecorder) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { + return nil, nil +} + +func (t *pingRecorder) waitping(from enode.ID) error { + return nil // remote always pings +} + +func (t *pingRecorder) ping(toid enode.ID, toaddr *net.UDPAddr) error { + t.mu.Lock() + defer t.mu.Unlock() + + t.pinged[toid] = true + if t.dead[toid] { + return errTimeout + } else { + return nil + } +} + +func (t *pingRecorder) close() {} + +func hasDuplicates(slice []*node) bool { + seen := make(map[enode.ID]bool) + for i, e := range slice { + if e == nil { + panic(fmt.Sprintf("nil *Node at %d", i)) + } + if seen[e.ID()] { + return true + } + seen[e.ID()] = true + } + return false +} + +func contains(ns []*node, id enode.ID) bool { + for _, n := range ns { + if n.ID() == id { + return true + } + } + return false +} + +func sortedByDistanceTo(distbase enode.ID, slice []*node) bool { + var last enode.ID + for i, e := range slice { + if i > 0 && enode.DistCmp(distbase, e.ID(), last) < 0 { + return false + } + last = e.ID() + } + return true +} + +func hexEncPubkey(h string) (ret encPubkey) { + b, err := hex.DecodeString(h) + if err != nil { + panic(err) + } + if len(b) != len(ret) { + panic("invalid length") + } + copy(ret[:], b) + return ret +} + +func hexPubkey(h string) *ecdsa.PublicKey { + k, err := decodePubkey(hexEncPubkey(h)) + if err != nil { + panic(err) + } + return k +} diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index 051477cb59..3b73e29393 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -27,13 +27,12 @@ import ( "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/nat" "github.com/tomochain/tomochain/p2p/netutil" "github.com/tomochain/tomochain/rlp" ) -const Version = 4 - // Errors var ( errPacketTooSmall = errors.New("too small") @@ -48,9 +47,9 @@ var ( // Timeouts const ( - respTimeout = 500 * time.Millisecond - sendTimeout = 500 * time.Millisecond - expiration = 20 * time.Second + respTimeout = 500 * time.Millisecond + expiration = 20 * time.Second + bondExpiration = 24 * time.Hour ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP ntpWarningCooldown = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning @@ -63,7 +62,6 @@ const ( pongPacket findnodePacket neighborsPacket - pingTomo ) // RPC request structures @@ -91,7 +89,7 @@ type ( // findnode is a query for nodes close to the given target. findnode struct { - Target NodeID // doesn't need to be an actual public key + Target encPubkey Expiration uint64 // Ignore additional fields (for forward compatibility). 
Rest []rlp.RawValue `rlp:"tail"` @@ -109,7 +107,7 @@ type ( IP net.IP // len 4 for IPv4 or 16 for IPv6 UDP uint16 // for discovery protocol TCP uint16 // for RLPx protocol - ID NodeID + ID encPubkey } rpcEndpoint struct { @@ -127,7 +125,7 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} } -func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { +func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*node, error) { if rn.UDP <= 1024 { return nil, errors.New("low port") } @@ -137,17 +135,26 @@ func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { return nil, errors.New("not contained in netrestrict whitelist") } - n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) - err := n.validateComplete() + key, err := decodePubkey(rn.ID) + if err != nil { + return nil, err + } + n := wrapNode(enode.NewV4(key, rn.IP, int(rn.TCP), int(rn.UDP))) + err = n.ValidateComplete() return n, err } -func nodeToRPC(n *Node) rpcNode { - return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP} +func nodeToRPC(n *node) rpcNode { + var key ecdsa.PublicKey + var ekey encPubkey + if err := n.Load((*enode.Secp256k1)(&key)); err == nil { + ekey = encodePubkey(&key) + } + return rpcNode{ID: ekey, IP: n.IP(), UDP: uint16(n.UDP()), TCP: uint16(n.TCP())} } type packet interface { - handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error + handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error name() string } @@ -185,7 +192,7 @@ type udp struct { // to all the callback functions for that node. type pending struct { // these fields must match in the reply. - from NodeID + from enode.ID ptype byte // time when the request must complete @@ -203,7 +210,7 @@ type pending struct { } type reply struct { - from NodeID + from enode.ID ptype byte data interface{} // loop indicates whether there was @@ -226,7 +233,7 @@ type Config struct { AnnounceAddr *net.UDPAddr // local address announced in the DHT NodeDBPath string // if set, the node database is stored at this filesystem location NetRestrict *netutil.Netlist // network whitelist - Bootnodes []*Node // list of bootstrap nodes + Bootnodes []*enode.Node // list of bootstrap nodes Unhandled chan<- ReadPacket // unhandled packets are sent on this channel } @@ -241,6 +248,16 @@ func ListenUDP(c conn, cfg Config) (*Table, error) { } func newUDP(c conn, cfg Config) (*Table, *udp, error) { + realaddr := c.LocalAddr().(*net.UDPAddr) + if cfg.AnnounceAddr != nil { + realaddr = cfg.AnnounceAddr + } + self := enode.NewV4(&cfg.PrivateKey.PublicKey, realaddr.IP, realaddr.Port, realaddr.Port) + db, err := enode.OpenDB(cfg.NodeDBPath) + if err != nil { + return nil, nil, err + } + udp := &udp{ conn: c, priv: cfg.PrivateKey, @@ -249,13 +266,9 @@ func newUDP(c conn, cfg Config) (*Table, *udp, error) { gotreply: make(chan reply), addpending: make(chan *pending), } - realaddr := c.LocalAddr().(*net.UDPAddr) - if cfg.AnnounceAddr != nil { - realaddr = cfg.AnnounceAddr - } // TODO: separate TCP port udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) - tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes) + tab, err := newTable(udp, self, db, cfg.Bootnodes) if err != nil { return nil, nil, err } @@ -269,36 +282,56 @@ func newUDP(c conn, cfg Config) (*Table, *udp, error) { func (t *udp) close() { close(t.closing) t.conn.Close() + 
t.db.Close() // TODO: wait for the loops to end. } // ping sends a ping message to the given node and waits for a reply. -func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { +func (t *udp) ping(toid enode.ID, toaddr *net.UDPAddr) error { + return <-t.sendPing(toid, toaddr, nil) +} + +// sendPing sends a ping message to the given node and invokes the callback +// when the reply arrives. +func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-chan error { req := &ping{ - Version: Version, + Version: 4, From: t.ourEndpoint, To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB Expiration: uint64(time.Now().Add(expiration).Unix()), } - packet, hash, err := encodePacket(t.priv, pingTomo, req) + packet, hash, err := encodePacket(t.priv, pingPacket, req) if err != nil { - return err + errc := make(chan error, 1) + errc <- err + return errc } errc := t.pending(toid, pongPacket, func(p interface{}) bool { - return bytes.Equal(p.(*pong).ReplyTok, hash) + ok := bytes.Equal(p.(*pong).ReplyTok, hash) + if ok && callback != nil { + callback() + } + return ok }) t.write(toaddr, req.name(), packet) - return <-errc + return errc } -func (t *udp) waitping(from NodeID) error { - return <-t.pending(from, pingTomo, func(interface{}) bool { return true }) +func (t *udp) waitping(from enode.ID) error { + return <-t.pending(from, pingPacket, func(interface{}) bool { return true }) } // findnode sends a findnode request to the given node and waits until // the node has sent up to k neighbors. -func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { - nodes := make([]*Node, 0, bucketSize) +func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { + // If we haven't seen a ping from the destination node for a while, it won't remember + // our endpoint proof and reject findnode. Solicit a ping first. + if time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration { + t.ping(toid, toaddr) + t.waitping(toid) + } + + nodes := make([]*node, 0, bucketSize) nreceived := 0 errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { reply := r.(*neighbors) @@ -317,13 +350,12 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node Target: target, Expiration: uint64(time.Now().Add(expiration).Unix()), }) - err := <-errc - return nodes, err + return nodes, <-errc } // pending adds a reply callback to the pending reply queue. // see the documentation of type pending for a detailed explanation. 
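The pending queue referenced here is the request/response matcher: replies are keyed by sender ID and packet type, and the callback decides when the request is complete (findnode, for instance, keeps accepting neighbors packets until it has enough). Paraphrased as a sketch (an illustration, not code from this patch; matchesPending is a hypothetical helper):

	// matchesPending mirrors the comparison done by the reply loop in udp.go.
	func matchesPending(p *pending, r reply) bool {
		// Sender and packet type must agree; the callback inspects the
		// payload and reports whether the request is now complete.
		return p.from == r.from && p.ptype == r.ptype && p.callback(r.data)
	}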
-func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error { +func (t *udp) pending(id enode.ID, ptype byte, callback func(interface{}) bool) <-chan error { ch := make(chan error, 1) p := &pending{from: id, ptype: ptype, callback: callback, errc: ch} select { @@ -335,7 +367,7 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <- return ch } -func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool { +func (t *udp) handleReply(from enode.ID, ptype byte, req packet) bool { matched := make(chan bool, 1) select { case t.gotreply <- reply{from, ptype, req, matched}: @@ -549,22 +581,23 @@ func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { return err } -func decodePacket(buf []byte) (packet, NodeID, []byte, error) { +func decodePacket(buf []byte) (packet, encPubkey, []byte, error) { if len(buf) < headSize+1 { - return nil, NodeID{}, nil, errPacketTooSmall + return nil, encPubkey{}, nil, errPacketTooSmall } hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] shouldhash := crypto.Keccak256(buf[macSize:]) if !bytes.Equal(hash, shouldhash) { - return nil, NodeID{}, nil, errBadHash + return nil, encPubkey{}, nil, errBadHash } - fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig) + fromKey, err := recoverNodeKey(crypto.Keccak256(buf[headSize:]), sig) if err != nil { - return nil, NodeID{}, hash, err + return nil, fromKey, hash, err } + var req packet switch ptype := sigdata[0]; ptype { - case pingTomo: + case pingPacket: req = new(ping) case pongPacket: req = new(pong) @@ -573,68 +606,78 @@ func decodePacket(buf []byte) (packet, NodeID, []byte, error) { case neighborsPacket: req = new(neighbors) default: - return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype) + return nil, fromKey, hash, fmt.Errorf("unknown type: %d", ptype) } s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0) err = s.Decode(req) - return req, fromID, hash, err + return req, fromKey, hash, err } -func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { +func (req *ping) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error { if expired(req.Expiration) { return errExpired } + key, err := decodePubkey(fromKey) + if err != nil { + return fmt.Errorf("invalid public key: %v", err) + } t.send(from, pongPacket, &pong{ To: makeEndpoint(from, req.From.TCP), ReplyTok: mac, Expiration: uint64(time.Now().Add(expiration).Unix()), }) - if !t.handleReply(fromID, pingTomo, req) { - // Note: we're ignoring the provided IP address right now - go t.bond(true, fromID, from, req.From.TCP) - } + n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port)) + t.handleReply(n.ID(), pingPacket, req) + if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration { + t.sendPing(n.ID(), from, func() { t.addThroughPing(n) }) + } else { + t.addThroughPing(n) + } + t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now()) return nil } -func (req *ping) name() string { return "PING TOMO/v4" } +func (req *ping) name() string { return "PING/v4" } -func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { +func (req *pong) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error { if expired(req.Expiration) { return errExpired } + fromID := fromKey.id() if !t.handleReply(fromID, pongPacket, req) { return errUnsolicitedReply } + t.db.UpdateLastPongReceived(fromID, from.IP, time.Now()) return nil } func (req *pong) name() string { 
return "PONG/v4" } -func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { +func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error { if expired(req.Expiration) { return errExpired } - if !t.db.hasBond(fromID) { - // No bond exists, we don't process the packet. This prevents - // an attack vector where the discovery protocol could be used - // to amplify traffic in a DDOS attack. A malicious actor - // would send a findnode request with the IP address and UDP - // port of the target as the source address. The recipient of - // the findnode packet would then send a neighbors packet - // (which is a much bigger packet than findnode) to the victim. + fromID := fromKey.id() + if time.Since(t.db.LastPongReceived(fromID, from.IP)) > bondExpiration { + // No endpoint proof pong exists, we don't process the packet. This prevents an + // attack vector where the discovery protocol could be used to amplify traffic in a + // DDOS attack. A malicious actor would send a findnode request with the IP address + // and UDP port of the target as the source address. The recipient of the findnode + // packet would then send a neighbors packet (which is a much bigger packet than + // findnode) to the victim. return errUnknownNode } - target := crypto.Keccak256Hash(req.Target[:]) + target := enode.ID(crypto.Keccak256Hash(req.Target[:])) t.mutex.Lock() closest := t.closest(target, bucketSize).entries t.mutex.Unlock() - log.Trace("find neighbors ", "from", from, "fromID", fromID, "closest", len(closest)) + p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} var sent bool // Send neighbors in chunks with at most maxNeighbors per packet // to stay below the 1280 byte limit. for _, n := range closest { - if netutil.CheckRelayIP(from.IP, n.IP) == nil { + if netutil.CheckRelayIP(from.IP, n.IP()) == nil { p.Nodes = append(p.Nodes, nodeToRPC(n)) } if len(p.Nodes) == maxNeighbors { @@ -651,11 +694,11 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte func (req *findnode) name() string { return "FINDNODE/v4" } -func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { +func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte) error { if expired(req.Expiration) { return errExpired } - if !t.handleReply(fromID, neighborsPacket, req) { + if !t.handleReply(fromKey.id(), neighborsPacket, req) { return errUnsolicitedReply } return nil diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index b13a79658d..82ca1ef19f 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -36,6 +36,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/rlp" ) @@ -46,7 +47,7 @@ func init() { // shared test variables var ( futureExp = uint64(time.Now().Add(10 * time.Hour).Unix()) - testTarget = NodeID{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1} + testTarget = encPubkey{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1} testRemote = rpcEndpoint{IP: net.ParseIP("1.1.1.1").To4(), UDP: 1, TCP: 2} testLocalAnnounced = rpcEndpoint{IP: net.ParseIP("2.2.2.2").To4(), UDP: 3, TCP: 4} testLocal = rpcEndpoint{IP: net.ParseIP("3.3.3.3").To4(), UDP: 5, TCP: 6} @@ -124,7 +125,7 @@ func TestUDP_packetErrors(t *testing.T) { test := newUDPTest(t) defer test.table.Close() - test.packetIn(errExpired, pingTomo, 
&ping{From: testRemote, To: testLocalAnnounced, Version: Version})
+	test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4})
 	test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
 	test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp})
 	test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp})
@@ -136,7 +137,7 @@ func TestUDP_pingTimeout(t *testing.T) {
 	defer test.table.Close()
 
 	toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
-	toid := NodeID{1, 2, 3, 4}
+	toid := enode.ID{1, 2, 3, 4}
 	if err := test.udp.ping(toid, toaddr); err != errTimeout {
 		t.Error("expected timeout error, got", err)
 	}
@@ -220,8 +221,8 @@ func TestUDP_findnodeTimeout(t *testing.T) {
 	defer test.table.Close()
 
 	toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
-	toid := NodeID{1, 2, 3, 4}
-	target := NodeID{4, 5, 6, 7}
+	toid := enode.ID{1, 2, 3, 4}
+	target := encPubkey{4, 5, 6, 7}
 	result, err := test.udp.findnode(toid, toaddr, target)
 	if err != errTimeout {
 		t.Error("expected timeout error, got", err)
@@ -232,35 +233,36 @@
 }
 
 func TestUDP_findnode(t *testing.T) {
-	bucketSizeTest := 16
 	test := newUDPTest(t)
 	defer test.table.Close()
 
 	// put a few nodes into the table. their exact
 	// distribution shouldn't matter much, although we need to
 	// take care not to overflow any bucket.
-	targetHash := crypto.Keccak256Hash(testTarget[:])
-	nodes := &nodesByDistance{target: targetHash}
-	for i := 0; i < bucketSizeTest; i++ {
-		nodes.push(nodeAtDistance(test.table.self.sha, i+2), bucketSizeTest)
+	nodes := &nodesByDistance{target: testTarget.id()}
+	for i := 0; i < bucketSize; i++ {
+		key := newkey()
+		n := wrapNode(enode.NewV4(&key.PublicKey, net.IP{10, 13, 0, 1}, 0, i))
+		nodes.push(n, bucketSize)
 	}
 	test.table.stuff(nodes.entries)
 
 	// ensure there's a bond with the test node,
 	// findnode won't be accepted otherwise.
-	test.table.db.updateBondTime(PubkeyID(&test.remotekey.PublicKey), time.Now())
+	remoteID := encodePubkey(&test.remotekey.PublicKey).id()
+	test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now())
 
 	// check that closest neighbors are returned.
 	test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
-	expected := test.table.closest(targetHash, bucketSizeTest)
+	expected := test.table.closest(testTarget.id(), bucketSize)
 
-	waitNeighbors := func(want []*Node) {
+	waitNeighbors := func(want []*node) {
 		test.waitPacketOut(func(p *neighbors) {
 			if len(p.Nodes) != len(want) {
-				t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSizeTest)
+				t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), len(want))
 			}
 			for i := range p.Nodes {
-				if p.Nodes[i].ID != want[i].ID {
+				if p.Nodes[i].ID.id() != want[i].ID() {
 					t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, p.Nodes[i], expected.entries[i])
 				}
 			}
@@ -274,10 +276,13 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
 	test := newUDPTest(t)
 	defer test.table.Close()
 
+	rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey)
+	test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.IP, time.Now())
+
 	// queue a pending findnode request
-	resultc, errc := make(chan []*Node), make(chan error)
+	resultc, errc := make(chan []*node), make(chan error)
 	go func() {
-		rid := PubkeyID(&test.remotekey.PublicKey)
+		rid := encodePubkey(&test.remotekey.PublicKey).id()
 		ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
 		if err != nil && len(ns) == 0 {
 			errc <- err
@@ -295,11 +300,11 @@
 	}
 
 	// send the reply as two packets.
-	list := []*Node{
-		MustParseNode("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303?discport=30304"),
-		MustParseNode("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"),
-		MustParseNode("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17"),
-		MustParseNode("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"),
+	list := []*node{
+		wrapNode(enode.MustParseV4("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303?discport=30304")),
+		wrapNode(enode.MustParseV4("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303")),
+		wrapNode(enode.MustParseV4("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17")),
+		wrapNode(enode.MustParseV4("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303")),
 	}
 	rpclist := make([]rpcNode, len(list))
 	for i := range list {
@@ -324,12 +329,12 @@ func TestUDP_successfulPing(t *testing.T) {
 	test := newUDPTest(t)
-	added := make(chan *Node, 1)
-	test.table.nodeAddedHook = func(n *Node) { added <- n }
+	added := make(chan *node, 1)
+	test.table.nodeAddedHook = func(n *node) { added <- n }
 	defer test.table.Close()
 
 	// The remote side sends a ping packet to initiate the exchange.
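The LastPing/LastPong bookkeeping set up in these tests models the endpoint proof: findnode is only honored while a sufficiently recent pong exists. Condensed into a sketch (illustration only; haveBond is hypothetical, and it assumes t.db is the enode.DB opened in newUDP):

	// haveBond reports whether the endpoint proof for id is still fresh.
	func haveBond(db *enode.DB, id enode.ID, ip net.IP) bool {
		// bondExpiration is 24h in this patch; a pong older than that means
		// the remote has likely forgotten our endpoint proof.
		return time.Since(db.LastPongReceived(id, ip)) <= bondExpiration
	}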
- go test.packetIn(nil, pingTomo, &ping{From: testRemote, To: testLocalAnnounced, Version: Version, Expiration: futureExp}) + go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) // the ping is replied to. test.waitPacketOut(func(p *pong) { @@ -369,18 +374,18 @@ func TestUDP_successfulPing(t *testing.T) { // pong packet. select { case n := <-added: - rid := PubkeyID(&test.remotekey.PublicKey) - if n.ID != rid { - t.Errorf("node has wrong ID: got %v, want %v", n.ID, rid) + rid := encodePubkey(&test.remotekey.PublicKey).id() + if n.ID() != rid { + t.Errorf("node has wrong ID: got %v, want %v", n.ID(), rid) } - if !n.IP.Equal(test.remoteaddr.IP) { - t.Errorf("node has wrong IP: got %v, want: %v", n.IP, test.remoteaddr.IP) + if !n.IP().Equal(test.remoteaddr.IP) { + t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.IP) } - if int(n.UDP) != test.remoteaddr.Port { - t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP, test.remoteaddr.Port) + if int(n.UDP()) != test.remoteaddr.Port { + t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port) } - if n.TCP != testRemote.TCP { - t.Errorf("node has wrong TCP port: got %v, want: %v", n.TCP, testRemote.TCP) + if n.TCP() != int(testRemote.TCP) { + t.Errorf("node has wrong TCP port: got %v, want: %v", n.TCP(), testRemote.TCP) } case <-time.After(2 * time.Second): t.Errorf("node was not added within 2 seconds") @@ -392,7 +397,7 @@ var testPackets = []struct { wantPacket interface{} }{ { - input: "95a4d7d1909e6a58f115e9a451d47a8f016776a8874140366e702e33e85c7b4cd58a82ebece6acd0973342b66b9e716fece46b5c67a3560fc8624063dd15a310469de42ca599474b9d8cb6eb8dc41b0d5236539ea7ae10ef3c630cd94faefd800005ea04cb847f000001820cfa8215a8d790000000000000000000000000000000018208ae820d058443b9a355", + input: "71dbda3a79554728d4f94411e42ee1f8b0d561c10e1e5f5893367948c6a7d70bb87b235fa28a77070271b6c164a2dce8c7e13a5739b53b5e96f2e5acb0e458a02902f5965d55ecbeb2ebb6cabb8b2b232896a36b737666c55265ad0a68412f250001ea04cb847f000001820cfa8215a8d790000000000000000000000000000000018208ae820d058443b9a355", wantPacket: &ping{ Version: 4, From: rpcEndpoint{net.ParseIP("127.0.0.1").To4(), 3322, 5544}, @@ -402,7 +407,7 @@ var testPackets = []struct { }, }, { - input: "57b1c182cc24e21e9297baa70d57a67ade498439123c968ffc048541addf9d463d1d25d10cf473a7f90a3efd6a070818097ebeaef58cd53843cb3af28acaee354272cfe7801b7fa7dbd8aa13309b6059fce877ad376c8dad7524dc34de626bd80105ec04cb847f000001820cfa8215a8d790000000000000000000000000000000018208ae820d058443b9a3550102", + input: "e9614ccfd9fc3e74360018522d30e1419a143407ffcce748de3e22116b7e8dc92ff74788c0b6663aaa3d67d641936511c8f8d6ad8698b820a7cf9e1be7155e9a241f556658c55428ec0563514365799a4be2be5a685a80971ddcfa80cb422cdd0101ec04cb847f000001820cfa8215a8d790000000000000000000000000000000018208ae820d058443b9a3550102", wantPacket: &ping{ Version: 4, From: rpcEndpoint{net.ParseIP("127.0.0.1").To4(), 3322, 5544}, @@ -412,7 +417,7 @@ var testPackets = []struct { }, }, { - input: "e3e987421accd2c75967d4a7229c436c18760def054738d8d9669697ee4726cdc9949c51df3e90d795d33d3f57d508c4687913338f6eb9caa89873aaae9dd49a5473ade5ea452c4df9d1f842eadf03439dbc373c0de8b20b412b6760d7b479140105f83e82022bd79020010db83c4d001500000000abcdef12820cfa8215a8d79020010db885a308d313198a2e037073488208ae82823a8443b9a355c50102030405", + input: 
"577be4349c4dd26768081f58de4c6f375a7a22f3f7adda654d1428637412c3d7fe917cadc56d4e5e7ffae1dbe3efffb9849feb71b262de37977e7c7a44e677295680e9e38ab26bee2fcbae207fba3ff3d74069a50b902a82c9903ed37cc993c50001f83e82022bd79020010db83c4d001500000000abcdef12820cfa8215a8d79020010db885a308d313198a2e037073488208ae82823a8443b9a355c5010203040531b9019afde696e582a78fa8d95ea13ce3297d4afb8ba6433e4154caa5ac6431af1b80ba76023fa4090c408f6b4bc3701562c031041d4702971d102c9ab7fa5eed4cd6bab8f7af956f7d565ee1917084a95398b6a21eac920fe3dd1345ec0a7ef39367ee69ddf092cbfe5b93e5e568ebc491983c09c76d922dc3", wantPacket: &ping{ Version: 555, From: rpcEndpoint{net.ParseIP("2001:db8:3c4d:15::abcd:ef12"), 3322, 5544}, @@ -433,7 +438,7 @@ var testPackets = []struct { { input: "c7c44041b9f7c7e41934417ebac9a8e1a4c6298f74553f2fcfdcae6ed6fe53163eb3d2b52e39fe91831b8a927bf4fc222c3902202027e5e9eb812195f95d20061ef5cd31d502e47ecb61183f74a504fe04c51e73df81f25c4d506b26db4517490103f84eb840ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31387574077f301b421bc84df7266c44e9e6d569fc56be00812904767bf5ccd1fc7f8443b9a35582999983999999280dc62cc8255c73471e0a61da0c89acdc0e035e260add7fc0c04ad9ebf3919644c91cb247affc82b69bd2ca235c71eab8e49737c937a2c396", wantPacket: &findnode{ - Target: MustHexID("ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31387574077f301b421bc84df7266c44e9e6d569fc56be00812904767bf5ccd1fc7f"), + Target: hexEncPubkey("ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31387574077f301b421bc84df7266c44e9e6d569fc56be00812904767bf5ccd1fc7f"), Expiration: 1136239445, Rest: []rlp.RawValue{{0x82, 0x99, 0x99}, {0x83, 0x99, 0x99, 0x99}}, }, @@ -443,25 +448,25 @@ var testPackets = []struct { wantPacket: &neighbors{ Nodes: []rpcNode{ { - ID: MustHexID("3155e1427f85f10a5c9a7755877748041af1bcd8d474ec065eb33df57a97babf54bfd2103575fa829115d224c523596b401065a97f74010610fce76382c0bf32"), + ID: hexEncPubkey("3155e1427f85f10a5c9a7755877748041af1bcd8d474ec065eb33df57a97babf54bfd2103575fa829115d224c523596b401065a97f74010610fce76382c0bf32"), IP: net.ParseIP("99.33.22.55").To4(), UDP: 4444, TCP: 4445, }, { - ID: MustHexID("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d20951933beea1e4dfa6f968212385e829f04c2d314fc2d4e255e0d3bc08792b069db"), + ID: hexEncPubkey("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d20951933beea1e4dfa6f968212385e829f04c2d314fc2d4e255e0d3bc08792b069db"), IP: net.ParseIP("1.2.3.4").To4(), UDP: 1, TCP: 1, }, { - ID: MustHexID("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c765dd2d96126051913f44582e8c199ad7c6d6819e9a56483f637feaac9448aac"), + ID: hexEncPubkey("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c765dd2d96126051913f44582e8c199ad7c6d6819e9a56483f637feaac9448aac"), IP: net.ParseIP("2001:db8:3c4d:15::abcd:ef12"), UDP: 3333, TCP: 3333, }, { - ID: MustHexID("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2d47295286fc00cc081bb542d760717d1bdd6bec2c37cd72eca367d6dd3b9df73"), + ID: hexEncPubkey("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2d47295286fc00cc081bb542d760717d1bdd6bec2c37cd72eca367d6dd3b9df73"), IP: net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"), UDP: 999, TCP: 1000, @@ -475,13 +480,14 @@ var testPackets = []struct { func TestForwardCompatibility(t *testing.T) { testkey, _ := crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") - wantNodeID := PubkeyID(&testkey.PublicKey) + wantNodeKey := encodePubkey(&testkey.PublicKey) + for _, test := range testPackets { input, err := 
hex.DecodeString(test.input) if err != nil { t.Fatalf("invalid hex: %s", test.input) } - packet, nodeid, _, err := decodePacket(input) + packet, nodekey, _, err := decodePacket(input) if err != nil { t.Errorf("did not accept packet %s\n%v", test.input, err) continue @@ -489,8 +495,8 @@ func TestForwardCompatibility(t *testing.T) { if !reflect.DeepEqual(packet, test.wantPacket) { t.Errorf("got %s\nwant %s", spew.Sdump(packet), spew.Sdump(test.wantPacket)) } - if nodeid != wantNodeID { - t.Errorf("got id %v\nwant id %v", nodeid, wantNodeID) + if nodekey != wantNodeKey { + t.Errorf("got id %v\nwant id %v", nodekey, wantNodeKey) } } } From 47ad2d32ad7769d576202cf1c73ff1bfc94ce4ee Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 30 Oct 2023 16:39:33 +0700 Subject: [PATCH 097/119] Port p2p to p2p/enode --- p2p/dial.go | 127 +++++++--------- p2p/dial_test.go | 21 +-- p2p/message.go | 11 +- p2p/peer.go | 58 +++---- p2p/protocol.go | 4 +- p2p/protocols/protocol_test.go | 12 +- p2p/rlpx.go | 77 +++++----- p2p/rlpx_test.go | 78 +++------- p2p/server.go | 266 +++++++++++++++++++++------------ p2p/server_test.go | 230 +++++++++++++++++++--------- 10 files changed, 504 insertions(+), 380 deletions(-) diff --git a/p2p/dial.go b/p2p/dial.go index 454d2198c2..7f93a12bed 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -18,14 +18,13 @@ package p2p import ( "container/heap" - "crypto/rand" "errors" "fmt" "net" "time" "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/netutil" ) @@ -50,7 +49,7 @@ const ( // NodeDialer is used to connect to nodes in the network, typically by using // an underlying net.Dialer but also using net.Pipe in tests type NodeDialer interface { - Dial(*discover.Node) (net.Conn, error) + Dial(*enode.Node) (net.Conn, error) } // TCPDialer implements the NodeDialer interface by using a net.Dialer to @@ -60,8 +59,8 @@ type TCPDialer struct { } // Dial creates a TCP connection to the node -func (t TCPDialer) Dial(dest *discover.Node) (net.Conn, error) { - addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)} +func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) { + addr := &net.TCPAddr{IP: dest.IP(), Port: dest.TCP()} return t.Dialer.Dial("tcp", addr.String()) } @@ -74,22 +73,22 @@ type dialstate struct { netrestrict *netutil.Netlist lookupRunning bool - dialing map[discover.NodeID]connFlag - lookupBuf []*discover.Node // current discovery lookup results - randomNodes []*discover.Node // filled from Table - static map[discover.NodeID]*dialTask + dialing map[enode.ID]connFlag + lookupBuf []*enode.Node // current discovery lookup results + randomNodes []*enode.Node // filled from Table + static map[enode.ID]*dialTask hist *dialHistory - start time.Time // time when the dialer was first used - bootnodes []*discover.Node // default dials when there are no peers + start time.Time // time when the dialer was first used + bootnodes []*enode.Node // default dials when there are no peers } type discoverTable interface { - Self() *discover.Node + Self() *enode.Node Close() - Resolve(target discover.NodeID) *discover.Node - Lookup(target discover.NodeID) []*discover.Node - ReadRandomNodes([]*discover.Node) int + Resolve(*enode.Node) *enode.Node + LookupRandom() []*enode.Node + ReadRandomNodes([]*enode.Node) int } // the dial history remembers recent dials. @@ -97,7 +96,7 @@ type dialHistory []pastDial // pastDial is an entry in the dial history. 
type pastDial struct { - id discover.NodeID + id enode.ID exp time.Time } @@ -109,7 +108,7 @@ type task interface { // fields cannot be accessed while the task is running. type dialTask struct { flags connFlag - dest *discover.Node + dest *enode.Node lastResolved time.Time resolveDelay time.Duration } @@ -118,7 +117,7 @@ type dialTask struct { // Only one discoverTask is active at any time. // discoverTask.Do performs a random lookup. type discoverTask struct { - results []*discover.Node + results []*enode.Node } // A waitExpireTask is generated if there are no other tasks @@ -127,15 +126,15 @@ type waitExpireTask struct { time.Duration } -func newDialState(static []*discover.Node, bootnodes []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { +func newDialState(static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { s := &dialstate{ maxDynDials: maxdyn, ntab: ntab, netrestrict: netrestrict, - static: make(map[discover.NodeID]*dialTask), - dialing: make(map[discover.NodeID]connFlag), - bootnodes: make([]*discover.Node, len(bootnodes)), - randomNodes: make([]*discover.Node, maxdyn/2), + static: make(map[enode.ID]*dialTask), + dialing: make(map[enode.ID]connFlag), + bootnodes: make([]*enode.Node, len(bootnodes)), + randomNodes: make([]*enode.Node, maxdyn/2), hist: new(dialHistory), } copy(s.bootnodes, bootnodes) @@ -145,32 +144,32 @@ func newDialState(static []*discover.Node, bootnodes []*discover.Node, ntab disc return s } -func (s *dialstate) addStatic(n *discover.Node) { - // This overwites the task instead of updating an existing +func (s *dialstate) addStatic(n *enode.Node) { + // This overwrites the task instead of updating an existing // entry, giving users the opportunity to force a resolve operation. - s.static[n.ID] = &dialTask{flags: staticDialedConn, dest: n} + s.static[n.ID()] = &dialTask{flags: staticDialedConn, dest: n} } -func (s *dialstate) removeStatic(n *discover.Node) { +func (s *dialstate) removeStatic(n *enode.Node) { // This removes a task so future attempts to connect will not be made. - delete(s.static, n.ID) + delete(s.static, n.ID()) // This removes a previous dial timestamp so that application // can force a server to reconnect with chosen peer immediately. 
-	s.hist.remove(n.ID)
+	s.hist.remove(n.ID())
 }
 
-func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
+func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
 	if s.start.IsZero() {
 		s.start = now
 	}
 
 	var newtasks []task
-	addDial := func(flag connFlag, n *discover.Node) bool {
+	addDial := func(flag connFlag, n *enode.Node) bool {
 		if err := s.checkDial(n, peers); err != nil {
-			log.Trace("Skipping dial candidate", "id", n.ID, "addr", &net.TCPAddr{IP: n.IP, Port: int(n.TCP)}, "err", err)
+			log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
 			return false
 		}
-		s.dialing[n.ID] = flag
+		s.dialing[n.ID()] = flag
 		newtasks = append(newtasks, &dialTask{flags: flag, dest: n})
 		return true
 	}
@@ -196,8 +195,8 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
 		err := s.checkDial(t.dest, peers)
 		switch err {
 		case errNotWhitelisted, errSelf:
-			log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)}, "err", err)
-			delete(s.static, t.dest.ID)
+			log.Warn("Removing static dial candidate", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
+			delete(s.static, t.dest.ID())
 		case nil:
 			s.dialing[id] = t.flags
 			newtasks = append(newtasks, t)
@@ -260,21 +259,18 @@ var (
 	errNotWhitelisted = errors.New("not contained in netrestrict whitelist")
 )
 
-func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error {
-	_, dialing := s.dialing[n.ID]
+func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
+	_, dialing := s.dialing[n.ID()]
 	switch {
 	case dialing:
 		return errAlreadyDialing
-	case peers[n.ID] != nil:
-		exitsPeer := peers[n.ID]
-		if exitsPeer.PairPeer != nil {
-			return errAlreadyConnected
-		}
-	case s.ntab != nil && n.ID == s.ntab.Self().ID:
+	case peers[n.ID()] != nil:
+		return errAlreadyConnected
+	case s.ntab != nil && n.ID() == s.ntab.Self().ID():
 		return errSelf
-	case s.netrestrict != nil && !s.netrestrict.Contains(n.IP):
+	case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
 		return errNotWhitelisted
-	case s.hist.contains(n.ID):
+	case s.hist.contains(n.ID()):
 		return errRecentlyDialed
 	}
 	return nil
@@ -283,8 +279,8 @@ func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer)
 func (s *dialstate) taskDone(t task, now time.Time) {
 	switch t := t.(type) {
 	case *dialTask:
-		s.hist.add(t.dest.ID, now.Add(dialHistoryExpiration))
-		delete(s.dialing, t.dest.ID)
+		s.hist.add(t.dest.ID(), now.Add(dialHistoryExpiration))
+		delete(s.dialing, t.dest.ID())
 	case *discoverTask:
 		s.lookupRunning = false
 		s.lookupBuf = append(s.lookupBuf, t.results...)
@@ -303,26 +299,10 @@ func (t *dialTask) Do(srv *Server) {
 		// Try resolving the ID of static nodes if dialing failed.
 		if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
 			if t.resolve(srv) {
-				err = t.dial(srv, t.dest)
+				t.dial(srv, t.dest)
 			}
 		}
 	}
-	if err == nil {
-		err = t.dial(srv, t.dest)
-		if err != nil {
-			// Try resolving the ID of static nodes if dialing failed.
-			if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
-				if t.resolve(srv) {
-					err = t.dial(srv, t.dest)
-				}
-			}
-		}
-		if err == nil {
-			log.Trace("Dial pair connection success", "task", t.dest)
-		} else {
-			log.Trace("Dial pair connection error", "task", t.dest, "err", err)
-		}
-	}
 }
 
 // resolve attempts to find the current endpoint for the destination
@@ -342,7 +322,7 @@ func (t *dialTask) resolve(srv *Server) bool {
 	if time.Since(t.lastResolved) < t.resolveDelay {
 		return false
 	}
-	resolved := srv.ntab.Resolve(t.dest.ID)
+	resolved := srv.ntab.Resolve(t.dest)
 	t.lastResolved = time.Now()
 	if resolved == nil {
 		t.resolveDelay *= 2
@@ -355,7 +335,7 @@
 	// The node was found.
 	t.resolveDelay = initialResolveDelay
 	t.dest = resolved
-	log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)})
+	log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
 	return true
 }
 
@@ -364,7 +344,7 @@ type dialError struct {
 }
 
 // dial performs the actual connection attempt.
-func (t *dialTask) dial(srv *Server, dest *discover.Node) error {
+func (t *dialTask) dial(srv *Server, dest *enode.Node) error {
 	fd, err := srv.Dialer.Dial(dest)
 	if err != nil {
 		return &dialError{err}
@@ -374,7 +354,8 @@
 }
 
 func (t *dialTask) String() string {
-	return fmt.Sprintf("%v %x %v:%d", t.flags, t.dest.ID[:8], t.dest.IP, t.dest.TCP)
+	id := t.dest.ID()
+	return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP())
 }
 
 func (t *discoverTask) Do(srv *Server) {
@@ -386,9 +367,7 @@
 		time.Sleep(next.Sub(now))
 	}
 	srv.lastLookup = time.Now()
-	var target discover.NodeID
-	rand.Read(target[:])
-	t.results = srv.ntab.Lookup(target)
+	t.results = srv.ntab.LookupRandom()
 }
 
 func (t *discoverTask) String() string {
@@ -410,11 +389,11 @@ func (t waitExpireTask) String() string {
 
 func (h dialHistory) min() pastDial {
 	return h[0]
 }
-func (h *dialHistory) add(id discover.NodeID, exp time.Time) {
+func (h *dialHistory) add(id enode.ID, exp time.Time) {
 	heap.Push(h, pastDial{id, exp})
 }
-func (h *dialHistory) remove(id discover.NodeID) bool {
+func (h *dialHistory) remove(id enode.ID) bool {
 	for i, v := range *h {
 		if v.id == id {
 			heap.Remove(h, i)
@@ -423,7 +402,7 @@
 	}
 	return false
 }
-func (h dialHistory) contains(id discover.NodeID) bool {
+func (h dialHistory) contains(id enode.ID) bool {
 	for _, v := range h {
 		if v.id == id {
 			return true
diff --git a/p2p/dial_test.go b/p2p/dial_test.go
index 362f22c137..0b88b4cf80 100644
--- a/p2p/dial_test.go
+++ b/p2p/dial_test.go
@@ -25,6 +25,7 @@ import (
 	"github.com/davecgh/go-spew/spew"
 	"github.com/tomochain/tomochain/p2p/discover"
+	"github.com/tomochain/tomochain/p2p/enode"
 	"github.com/tomochain/tomochain/p2p/netutil"
 )
 
@@ -48,8 +49,8 @@ func runDialTest(t *testing.T, test dialtest) {
 		vtime   time.Time
 		running int
 	)
-	pm := func(ps []*Peer) map[discover.NodeID]*Peer {
-		m := make(map[discover.NodeID]*Peer)
+	pm := func(ps []*Peer) map[enode.ID]*Peer {
+		m := make(map[enode.ID]*Peer)
 		for _, p := range ps {
 			m[p.rw.id] = p
 		}
@@ -80,8 +81,8 @@ type fakeTable []*discover.Node
 
 func (t fakeTable) Self() *discover.Node { return new(discover.Node) }
 func (t fakeTable) Close() {}
-func (t fakeTable) Lookup(discover.NodeID) []*discover.Node { return nil }
-func (t fakeTable) Resolve(discover.NodeID)
*discover.Node { return nil } +func (t fakeTable) Lookup(enode.ID) []*discover.Node { return nil } +func (t fakeTable) Resolve(enode.ID) *discover.Node { return nil } func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf, t) } // This test checks that dynamic dials are launched from discovery results. @@ -656,7 +657,7 @@ func TestDialResolve(t *testing.T) { config := Config{Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}} srv := &Server{ntab: table, Config: config} tasks[0].Do(srv) - if !reflect.DeepEqual(table.resolveCalls, []discover.NodeID{dest.ID}) { + if !reflect.DeepEqual(table.resolveCalls, []enode.ID{dest.ID}) { t.Fatalf("wrong resolve calls, got %v", table.resolveCalls) } @@ -684,19 +685,19 @@ next: return true } -func uintID(i uint32) discover.NodeID { - var id discover.NodeID +func uintID(i uint32) enode.ID { + var id enode.ID binary.BigEndian.PutUint32(id[:], i) return id } // implements discoverTable for TestDialResolve type resolveMock struct { - resolveCalls []discover.NodeID + resolveCalls []enode.ID answer *discover.Node } -func (t *resolveMock) Resolve(id discover.NodeID) *discover.Node { +func (t *resolveMock) Resolve(id enode.ID) *discover.Node { t.resolveCalls = append(t.resolveCalls, id) return t.answer } @@ -704,5 +705,5 @@ func (t *resolveMock) Resolve(id discover.NodeID) *discover.Node { func (t *resolveMock) Self() *discover.Node { return new(discover.Node) } func (t *resolveMock) Close() {} func (t *resolveMock) Bootstrap([]*discover.Node) {} -func (t *resolveMock) Lookup(discover.NodeID) []*discover.Node { return nil } +func (t *resolveMock) Lookup(enode.ID) []*discover.Node { return nil } func (t *resolveMock) ReadRandomNodes(buf []*discover.Node) int { return 0 } diff --git a/p2p/message.go b/p2p/message.go index c92b31581d..b29a82c7bf 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -26,7 +26,7 @@ import ( "time" "github.com/tomochain/tomochain/event" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/rlp" ) @@ -100,12 +100,11 @@ func Send(w MsgWriter, msgcode uint64, data interface{}) error { // SendItems writes an RLP with the given code and data elements. 
// For a call such as: // -// SendItems(w, code, e1, e2, e3) +// SendItems(w, code, e1, e2, e3) // // the message payload will be an RLP list containing the items: // -// [e1, e2, e3] -// +// [e1, e2, e3] func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error { return Send(w, msgcode, elems) } @@ -254,13 +253,13 @@ type msgEventer struct { MsgReadWriter feed *event.Feed - peerID discover.NodeID + peerID enode.ID Protocol string } // newMsgEventer returns a msgEventer which sends message events to the given // feed -func newMsgEventer(rw MsgReadWriter, feed *event.Feed, peerID discover.NodeID, proto string) *msgEventer { +func newMsgEventer(rw MsgReadWriter, feed *event.Feed, peerID enode.ID, proto string) *msgEventer { return &msgEventer{ MsgReadWriter: rw, feed: feed, diff --git a/p2p/peer.go b/p2p/peer.go index 1852d59bcf..04b12d893f 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -17,6 +17,7 @@ package p2p import ( + "errors" "fmt" "io" "net" @@ -27,10 +28,15 @@ import ( "github.com/tomochain/tomochain/common/mclock" "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" + "github.com/tomochain/tomochain/p2p/enr" "github.com/tomochain/tomochain/rlp" ) +var ( + ErrShuttingDown = errors.New("shutting down") +) + const ( baseProtocolVersion = 5 baseProtocolLength = uint64(16) @@ -47,8 +53,6 @@ const ( discMsg = 0x01 pingMsg = 0x02 pongMsg = 0x03 - getPeersMsg = 0x04 - peersMsg = 0x05 ) // protoHandshake is the RLP structure of the protocol handshake. @@ -57,7 +61,7 @@ type protoHandshake struct { Name string Caps []Cap ListenPort uint64 - ID discover.NodeID + ID []byte // secp256k1 public key // Ignore additional fields (for forward compatibility). Rest []rlp.RawValue `rlp:"tail"` @@ -87,12 +91,12 @@ const ( // PeerEvent is an event emitted when peers are either added or dropped from // a p2p.Server or when a message is sent or received on a peer connection type PeerEvent struct { - Type PeerEventType `json:"type"` - Peer discover.NodeID `json:"peer"` - Error string `json:"error,omitempty"` - Protocol string `json:"protocol,omitempty"` - MsgCode *uint64 `json:"msg_code,omitempty"` - MsgSize *uint32 `json:"msg_size,omitempty"` + Type PeerEventType `json:"type"` + Peer enode.ID `json:"peer"` + Error string `json:"error,omitempty"` + Protocol string `json:"protocol,omitempty"` + MsgCode *uint64 `json:"msg_code,omitempty"` + MsgSize *uint32 `json:"msg_size,omitempty"` } // Peer represents a connected remote node. @@ -108,22 +112,27 @@ type Peer struct { disc chan DiscReason // events receives message send / receive events if set - events *event.Feed - PairPeer *Peer + events *event.Feed } // NewPeer returns a peer for testing purposes. -func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { +func NewPeer(id enode.ID, name string, caps []Cap) *Peer { pipe, _ := net.Pipe() - conn := &conn{fd: pipe, transport: nil, id: id, caps: caps, name: name} + node := enode.SignNull(new(enr.Record), id) + conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name} peer := newPeer(conn, nil) close(peer.closed) // ensures Disconnect doesn't block return peer } // ID returns the node's public key. -func (p *Peer) ID() discover.NodeID { - return p.rw.id +func (p *Peer) ID() enode.ID { + return p.rw.node.ID() +} + +// Node returns the peer's node descriptor. 
+func (p *Peer) Node() *enode.Node { + return p.rw.node } // Name returns the node name that the remote node advertised. @@ -158,12 +167,13 @@ func (p *Peer) Disconnect(reason DiscReason) { // String implements fmt.Stringer. func (p *Peer) String() string { - return fmt.Sprintf("Peer %x %v ", p.rw.id[:8], p.RemoteAddr()) + id := p.ID() + return fmt.Sprintf("Peer %x %v", id[:8], p.RemoteAddr()) } // Inbound returns true if the peer is an inbound connection func (p *Peer) Inbound() bool { - return p.rw.flags&inboundConn != 0 + return p.rw.is(inboundConn) } func newPeer(conn *conn, protocols []Protocol) *Peer { @@ -175,7 +185,7 @@ func newPeer(conn *conn, protocols []Protocol) *Peer { disc: make(chan DiscReason), protoErr: make(chan error, len(protomap)+1), // protocols + pingLoop closed: make(chan struct{}), - log: log.New("id", conn.id, "conn", conn.flags), + log: log.New("id", conn.node.ID(), "conn", conn.flags), } return p } @@ -223,15 +233,14 @@ loop: reason = discReasonForError(err) break loop case err = <-p.disc: + reason = discReasonForError(err) break loop } } + close(p.closed) p.rw.close(reason) p.wg.Wait() - if p.PairPeer != nil { - go func() { p.PairPeer.Disconnect(DiscPairPeerStop) }() - } return remoteRequested, err } @@ -348,7 +357,6 @@ func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error) rw = newMsgEventer(rw, p.events, p.ID(), proto.Name) } p.log.Trace(fmt.Sprintf("Starting protocol %s/%d", proto.Name, proto.Version)) - go func() { err := proto.Run(p, rw) if err == nil { @@ -376,7 +384,7 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) { type protoRW struct { Protocol - in chan Msg // receices read messages + in chan Msg // receives read messages closed <-chan struct{} // receives when peer is shutting down wstart <-chan struct{} // receives when write may start werr chan<- error // for write results @@ -398,7 +406,7 @@ func (rw *protoRW) WriteMsg(msg Msg) (err error) { // as well but we don't want to rely on that. rw.werr <- err case <-rw.closed: - err = fmt.Errorf("shutting down") + err = ErrShuttingDown } return err } diff --git a/p2p/protocol.go b/p2p/protocol.go index dbdb197012..fafc044d81 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -19,7 +19,7 @@ package p2p import ( "fmt" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" ) // Protocol represents a P2P subprotocol implementation. @@ -51,7 +51,7 @@ type Protocol struct { // PeerInfo is an optional helper method to retrieve protocol specific metadata // about a certain peer in the network. If an info retrieval function is set, // but returns nil, it is assumed that the protocol handshake is still running. 
- PeerInfo func(id discover.NodeID) interface{} + PeerInfo func(id enode.ID) interface{} } func (p Protocol) cap() Cap { diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go index 286bbf97f5..0e4523e403 100644 --- a/p2p/protocols/protocol_test.go +++ b/p2p/protocols/protocol_test.go @@ -24,7 +24,7 @@ import ( "time" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/simulations/adapters" p2ptest "github.com/tomochain/tomochain/p2p/testing" ) @@ -36,14 +36,14 @@ type hs0 struct { // message to kill/drop the peer with nodeID type kill struct { - C discover.NodeID + C enode.ID } // message to drop connection type drop struct { } -/// protoHandshake represents module-independent aspects of the protocol and is +// / protoHandshake represents module-independent aspects of the protocol and is // the first message peers send and receive as part the initial exchange type protoHandshake struct { Version uint // local and remote peer should have identical version @@ -144,7 +144,7 @@ func protocolTester(t *testing.T, pp *p2ptest.TestPeerPool) *p2ptest.ProtocolTes return p2ptest.NewProtocolTester(t, conf.ID, 2, newProtocol(pp)) } -func protoHandshakeExchange(id discover.NodeID, proto *protoHandshake) []p2ptest.Exchange { +func protoHandshakeExchange(id enode.ID, proto *protoHandshake) []p2ptest.Exchange { return []p2ptest.Exchange{ { @@ -197,7 +197,7 @@ func TestProtoHandshakeSuccess(t *testing.T) { runProtoHandshake(t, &protoHandshake{42, "420"}) } -func moduleHandshakeExchange(id discover.NodeID, resp uint) []p2ptest.Exchange { +func moduleHandshakeExchange(id enode.ID, resp uint) []p2ptest.Exchange { return []p2ptest.Exchange{ { @@ -249,7 +249,7 @@ func TestModuleHandshakeSuccess(t *testing.T) { } // testing complex interactions over multiple peers, relaying, dropping -func testMultiPeerSetup(a, b discover.NodeID) []p2ptest.Exchange { +func testMultiPeerSetup(a, b enode.ID) []p2ptest.Exchange { return []p2ptest.Exchange{ { diff --git a/p2p/rlpx.go b/p2p/rlpx.go index 2cc4d42d35..2fcfc1c18f 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -36,12 +36,12 @@ import ( "time" "github.com/golang/snappy" - "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/common/bitutil" + crypto "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/crypto/ecies" "github.com/tomochain/tomochain/crypto/secp256k1" "github.com/tomochain/tomochain/crypto/sha3" - "github.com/tomochain/tomochain/p2p/discover" - "github.com/tomochain/tomochain/rlp" + rlp "github.com/tomochain/tomochain/rlp" ) const ( @@ -122,7 +122,6 @@ func (t *rlpx) close(err error) { } func (t *rlpx) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err error) { - // Writing our handshake happens concurrently, we prefer // returning the handshake read error. If the remote side // disconnects us early with a valid reason, we should return it @@ -166,7 +165,7 @@ func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake, if err := msg.Decode(&hs); err != nil { return nil, err } - if (hs.ID == discover.NodeID{}) { + if len(hs.ID) != 64 || !bitutil.TestBytes(hs.ID) { return nil, DiscInvalidIdentity } return &hs, nil @@ -176,31 +175,29 @@ func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake, // messages. 
the protocol handshake is the first authenticated message // and also verifies whether the encryption handshake 'worked' and the // remote side actually provided the right public key. -func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (discover.NodeID, error) { +func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *ecdsa.PublicKey) (*ecdsa.PublicKey, error) { var ( sec secrets err error ) if dial == nil { - sec, err = receiverEncHandshake(t.fd, prv, nil) + sec, err = receiverEncHandshake(t.fd, prv) } else { - sec, err = initiatorEncHandshake(t.fd, prv, dial.ID, nil) + sec, err = initiatorEncHandshake(t.fd, prv, dial) } if err != nil { - return discover.NodeID{}, err + return nil, err } t.wmu.Lock() t.rw = newRLPXFrameRW(t.fd, sec) t.wmu.Unlock() - return sec.RemoteID, nil + return sec.Remote.ExportECDSA(), nil } // encHandshake contains the state of the encryption handshake. type encHandshake struct { - initiator bool - remoteID discover.NodeID - - remotePub *ecies.PublicKey // remote-pubk + initiator bool + remote *ecies.PublicKey // remote-pubk initNonce, respNonce []byte // nonce randomPrivKey *ecies.PrivateKey // ecdhe-random remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk @@ -209,7 +206,7 @@ type encHandshake struct { // secrets represents the connection secrets // which are negotiated during the encryption handshake. type secrets struct { - RemoteID discover.NodeID + Remote *ecies.PublicKey AES, MAC []byte EgressMAC, IngressMAC hash.Hash Token []byte @@ -250,9 +247,9 @@ func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) { sharedSecret := crypto.Keccak256(ecdheSecret, crypto.Keccak256(h.respNonce, h.initNonce)) aesSecret := crypto.Keccak256(ecdheSecret, sharedSecret) s := secrets{ - RemoteID: h.remoteID, - AES: aesSecret, - MAC: crypto.Keccak256(ecdheSecret, aesSecret), + Remote: h.remote, + AES: aesSecret, + MAC: crypto.Keccak256(ecdheSecret, aesSecret), } // setup sha3 instances for the MACs @@ -274,16 +271,16 @@ func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) { // staticSharedSecret returns the static shared secret, the result // of key agreement between the local and remote static node key. func (h *encHandshake) staticSharedSecret(prv *ecdsa.PrivateKey) ([]byte, error) { - return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen) + return ecies.ImportECDSA(prv).GenerateShared(h.remote, sskLen, sskLen) } // initiatorEncHandshake negotiates a session token on conn. // it should be called on the dialing side of the connection. // // prv is the local client's private key. -func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) { - h := &encHandshake{initiator: true, remoteID: remoteID} - authMsg, err := h.makeAuthMsg(prv, token) +func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remote *ecdsa.PublicKey) (s secrets, err error) { + h := &encHandshake{initiator: true, remote: ecies.ImportECDSAPublic(remote)} + authMsg, err := h.makeAuthMsg(prv) if err != nil { return s, err } @@ -307,15 +304,11 @@ func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID d } // makeAuthMsg creates the initiator handshake message. 
-func (h *encHandshake) makeAuthMsg(prv *ecdsa.PrivateKey, token []byte) (*authMsgV4, error) { - rpub, err := h.remoteID.Pubkey() - if err != nil { - return nil, fmt.Errorf("bad remoteID: %v", err) - } - h.remotePub = ecies.ImportECDSAPublic(rpub) +func (h *encHandshake) makeAuthMsg(prv *ecdsa.PrivateKey) (*authMsgV4, error) { // Generate random initiator nonce. h.initNonce = make([]byte, shaLen) - if _, err := rand.Read(h.initNonce); err != nil { + _, err := rand.Read(h.initNonce) + if err != nil { return nil, err } // Generate random keypair to for ECDH. @@ -325,7 +318,7 @@ func (h *encHandshake) makeAuthMsg(prv *ecdsa.PrivateKey, token []byte) (*authMs } // Sign known message: static-shared-secret ^ nonce - token, err = h.staticSharedSecret(prv) + token, err := h.staticSharedSecret(prv) if err != nil { return nil, err } @@ -353,8 +346,7 @@ func (h *encHandshake) handleAuthResp(msg *authRespV4) (err error) { // it should be called on the listening side of the connection. // // prv is the local client's private key. -// token is the token from a previous session with this node. -func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) { +func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey) (s secrets, err error) { authMsg := new(authMsgV4) authPacket, err := readHandshakeMsg(authMsg, encAuthMsgLen, prv, conn) if err != nil { @@ -386,13 +378,12 @@ func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byt func (h *encHandshake) handleAuthMsg(msg *authMsgV4, prv *ecdsa.PrivateKey) error { // Import the remote identity. - h.initNonce = msg.Nonce[:] - h.remoteID = msg.InitiatorPubkey - rpub, err := h.remoteID.Pubkey() + rpub, err := importPublicKey(msg.InitiatorPubkey[:]) if err != nil { - return fmt.Errorf("bad remoteID: %#v", err) + return err } - h.remotePub = ecies.ImportECDSAPublic(rpub) + h.initNonce = msg.Nonce[:] + h.remote = rpub // Generate random keypair for ECDH. // If a private key is already set, use it instead of generating one (for testing). 
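The hunks above remove the session-token plumbing from the auth message: the initiator now always signs the XOR of the static shared secret and a fresh random nonce with its ephemeral key. A minimal stand-alone sketch of just that XOR step; xorBytes, secret, and nonce are illustrative names, not code from this series:

	package main

	import (
		"crypto/rand"
		"fmt"
	)

	// xorBytes returns a ^ b; both inputs must have equal length,
	// mirroring the shaLen-sized secret and nonce used above.
	func xorBytes(a, b []byte) []byte {
		out := make([]byte, len(a))
		for i := range out {
			out[i] = a[i] ^ b[i]
		}
		return out
	}

	func main() {
		secret := make([]byte, 32) // stand-in for the static shared secret
		nonce := make([]byte, 32)  // stand-in for h.initNonce
		rand.Read(secret)
		rand.Read(nonce)
		// This XOR result is what gets signed with the ephemeral key.
		fmt.Printf("value to sign: %x\n", xorBytes(secret, nonce))
	}
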
@@ -438,7 +429,7 @@ func (msg *authMsgV4) sealPlain(h *encHandshake) ([]byte, error) { n += copy(buf[n:], msg.InitiatorPubkey[:]) n += copy(buf[n:], msg.Nonce[:]) buf[n] = 0 // token-flag - return ecies.Encrypt(rand.Reader, h.remotePub, buf, nil, nil) + return ecies.Encrypt(rand.Reader, h.remote, buf, nil, nil) } func (msg *authMsgV4) decodePlain(input []byte) { @@ -454,7 +445,7 @@ func (msg *authRespV4) sealPlain(hs *encHandshake) ([]byte, error) { buf := make([]byte, authRespLen) n := copy(buf, msg.RandomPubkey[:]) copy(buf[n:], msg.Nonce[:]) - return ecies.Encrypt(rand.Reader, hs.remotePub, buf, nil, nil) + return ecies.Encrypt(rand.Reader, hs.remote, buf, nil, nil) } func (msg *authRespV4) decodePlain(input []byte) { @@ -477,7 +468,7 @@ func sealEIP8(msg interface{}, h *encHandshake) ([]byte, error) { prefix := make([]byte, 2) binary.BigEndian.PutUint16(prefix, uint16(buf.Len()+eciesOverhead)) - enc, err := ecies.Encrypt(rand.Reader, h.remotePub, buf.Bytes(), nil, prefix) + enc, err := ecies.Encrypt(rand.Reader, h.remote, buf.Bytes(), nil, prefix) return append(prefix, enc...), err } @@ -529,9 +520,9 @@ func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) { return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey)) } // TODO: fewer pointless conversions - pub := crypto.ToECDSAPub(pubKey65) - if pub.X == nil { - return nil, fmt.Errorf("invalid public key") + pub, err := crypto.UnmarshalPubkey(pubKey65) + if err != nil { + return nil, err } return ecies.ImportECDSAPublic(pub), nil } diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go index e86a1fb171..207dbf057f 100644 --- a/p2p/rlpx_test.go +++ b/p2p/rlpx_test.go @@ -18,6 +18,7 @@ package p2p import ( "bytes" + "crypto/ecdsa" "crypto/rand" "errors" "fmt" @@ -31,11 +32,11 @@ import ( "time" "github.com/davecgh/go-spew/spew" - "github.com/tomochain/tomochain/crypto" + crypto "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/crypto/ecies" "github.com/tomochain/tomochain/crypto/sha3" - "github.com/tomochain/tomochain/p2p/discover" - "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/p2p/simulations/pipes" + rlp "github.com/tomochain/tomochain/rlp" ) func TestSharedSecret(t *testing.T) { @@ -79,9 +80,9 @@ func TestEncHandshake(t *testing.T) { func testEncHandshake(token []byte) error { type result struct { - side string - id discover.NodeID - err error + side string + pubkey *ecdsa.PublicKey + err error } var ( prv0, _ = crypto.GenerateKey() @@ -96,14 +97,12 @@ func testEncHandshake(token []byte) error { defer func() { output <- r }() defer fd0.Close() - dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)} - r.id, r.err = c0.doEncHandshake(prv0, dest) + r.pubkey, r.err = c0.doEncHandshake(prv0, &prv1.PublicKey) if r.err != nil { return } - id1 := discover.PubkeyID(&prv1.PublicKey) - if r.id != id1 { - r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1) + if !reflect.DeepEqual(r.pubkey, &prv1.PublicKey) { + r.err = fmt.Errorf("remote pubkey mismatch: got %v, want: %v", r.pubkey, &prv1.PublicKey) } }() go func() { @@ -111,13 +110,12 @@ func testEncHandshake(token []byte) error { defer func() { output <- r }() defer fd1.Close() - r.id, r.err = c1.doEncHandshake(prv1, nil) + r.pubkey, r.err = c1.doEncHandshake(prv1, nil) if r.err != nil { return } - id0 := discover.PubkeyID(&prv0.PublicKey) - if r.id != id0 { - r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0) + if !reflect.DeepEqual(r.pubkey, &prv0.PublicKey) { + r.err = 
fmt.Errorf("remote ID mismatch: got %v, want: %v", r.pubkey, &prv0.PublicKey) } }() @@ -149,17 +147,17 @@ func testEncHandshake(token []byte) error { func TestProtocolHandshake(t *testing.T) { var ( prv0, _ = crypto.GenerateKey() - node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33} - hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}} + pub0 = crypto.FromECDSAPub(&prv0.PublicKey)[1:] + hs0 = &protoHandshake{Version: 3, ID: pub0, Caps: []Cap{{"a", 0}, {"b", 2}}} prv1, _ = crypto.GenerateKey() - node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44} - hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}} + pub1 = crypto.FromECDSAPub(&prv1.PublicKey)[1:] + hs1 = &protoHandshake{Version: 3, ID: pub1, Caps: []Cap{{"c", 1}, {"d", 3}}} wg sync.WaitGroup ) - fd0, fd1, err := tcpPipe() + fd0, fd1, err := pipes.TCPPipe() if err != nil { t.Fatal(err) } @@ -169,13 +167,13 @@ func TestProtocolHandshake(t *testing.T) { defer wg.Done() defer fd0.Close() rlpx := newRLPX(fd0) - remid, err := rlpx.doEncHandshake(prv0, node1) + rpubkey, err := rlpx.doEncHandshake(prv0, &prv1.PublicKey) if err != nil { t.Errorf("dial side enc handshake failed: %v", err) return } - if remid != node1.ID { - t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID) + if !reflect.DeepEqual(rpubkey, &prv1.PublicKey) { + t.Errorf("dial side remote pubkey mismatch: got %v, want %v", rpubkey, &prv1.PublicKey) return } @@ -195,13 +193,13 @@ func TestProtocolHandshake(t *testing.T) { defer wg.Done() defer fd1.Close() rlpx := newRLPX(fd1) - remid, err := rlpx.doEncHandshake(prv1, nil) + rpubkey, err := rlpx.doEncHandshake(prv1, nil) if err != nil { t.Errorf("listen side enc handshake failed: %v", err) return } - if remid != node0.ID { - t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID) + if !reflect.DeepEqual(rpubkey, &prv0.PublicKey) { + t.Errorf("listen side remote pubkey mismatch: got %v, want %v", rpubkey, &prv0.PublicKey) return } @@ -601,31 +599,3 @@ func TestHandshakeForwardCompatibility(t *testing.T) { t.Errorf("ingress-mac('foo') mismatch:\ngot %x\nwant %x", fooIngressHash, wantFooIngressHash) } } - -// tcpPipe creates an in process full duplex pipe based on a localhost TCP socket -func tcpPipe() (net.Conn, net.Conn, error) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return nil, nil, err - } - defer l.Close() - - var aconn net.Conn - aerr := make(chan error, 1) - go func() { - var err error - aconn, err = l.Accept() - aerr <- err - }() - - dconn, err := net.Dial("tcp", l.Addr().String()) - if err != nil { - <-aerr - return nil, nil, err - } - if err := <-aerr; err != nil { - dconn.Close() - return nil, nil, err - } - return aconn, dconn, nil -} diff --git a/p2p/server.go b/p2p/server.go index 6a5ea9e613..7875426b65 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -18,19 +18,23 @@ package p2p import ( + "bytes" "crypto/ecdsa" "errors" "fmt" "net" "sync" + "sync/atomic" "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/mclock" + "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/p2p/discover" "github.com/tomochain/tomochain/p2p/discv5" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/nat" "github.com/tomochain/tomochain/p2p/netutil" ) @@ -76,7 +80,7 @@ type 
Config struct { // Disabling is useful for protocol debugging (manual topology). NoDiscovery bool - // DiscoveryV5 specifies whether the the new topic-discovery based V5 discovery + // DiscoveryV5 specifies whether the new topic-discovery based V5 discovery // protocol should be started or not. DiscoveryV5 bool `toml:",omitempty"` @@ -86,7 +90,7 @@ type Config struct { // BootstrapNodes are used to establish connectivity // with the rest of the network. - BootstrapNodes []*discover.Node + BootstrapNodes []*enode.Node // BootstrapNodesV5 are used to establish connectivity // with the rest of the network using the V5 discovery @@ -95,11 +99,11 @@ type Config struct { // Static nodes are used as pre-configured connections which are always // maintained and re-connected on disconnects. - StaticNodes []*discover.Node + StaticNodes []*enode.Node // Trusted nodes are used as pre-configured connections which are always // allowed to connect, even above the peer limit. - TrustedNodes []*discover.Node + TrustedNodes []*enode.Node // Connectivity can be restricted to certain IP networks. // If this option is set to a non-nil value, only hosts which match one of the @@ -167,8 +171,10 @@ type Server struct { peerOpDone chan struct{} quit chan struct{} - addstatic chan *discover.Node - removestatic chan *discover.Node + addstatic chan *enode.Node + removestatic chan *enode.Node + addtrusted chan *enode.Node + removetrusted chan *enode.Node posthandshake chan *conn addpeer chan *conn delpeer chan peerDrop @@ -177,7 +183,7 @@ type Server struct { log log.Logger } -type peerOpFunc func(map[discover.NodeID]*Peer) +type peerOpFunc func(map[enode.ID]*Peer) type peerDrop struct { *Peer @@ -185,7 +191,7 @@ type peerDrop struct { requested bool // true if signaled by the peer } -type connFlag int +type connFlag int32 const ( dynDialedConn connFlag = 1 << iota @@ -199,16 +205,16 @@ const ( type conn struct { fd net.Conn transport + node *enode.Node flags connFlag - cont chan error // The run loop uses cont to signal errors to SetupConn. - id discover.NodeID // valid after the encryption handshake - caps []Cap // valid after the protocol handshake - name string // valid after the protocol handshake + cont chan error // The run loop uses cont to signal errors to SetupConn. + caps []Cap // valid after the protocol handshake + name string // valid after the protocol handshake } type transport interface { // The two handshakes. - doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) + doEncHandshake(prv *ecdsa.PrivateKey, dialDest *ecdsa.PublicKey) (*ecdsa.PublicKey, error) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) // The MsgReadWriter can only be used after the encryption // handshake has completed. 
The code uses conn.id to track this @@ -222,8 +228,8 @@ type transport interface { func (c *conn) String() string { s := c.flags.String() - if (c.id != discover.NodeID{}) { - s += " " + c.id.String() + if (c.node.ID() != enode.ID{}) { + s += " " + c.node.ID().String() } s += " " + c.fd.RemoteAddr().String() return s @@ -250,7 +256,23 @@ func (f connFlag) String() string { } func (c *conn) is(f connFlag) bool { - return c.flags&f != 0 + flags := connFlag(atomic.LoadInt32((*int32)(&c.flags))) + return flags&f != 0 +} + +func (c *conn) set(f connFlag, val bool) { + for { + oldFlags := connFlag(atomic.LoadInt32((*int32)(&c.flags))) + flags := oldFlags + if val { + flags |= f + } else { + flags &= ^f + } + if atomic.CompareAndSwapInt32((*int32)(&c.flags), int32(oldFlags), int32(flags)) { + return + } + } } // Peers returns all connected peers. @@ -260,7 +282,7 @@ func (srv *Server) Peers() []*Peer { // Note: We'd love to put this function into a variable but // that seems to cause a weird compiler error in some // environments. - case srv.peerOp <- func(peers map[discover.NodeID]*Peer) { + case srv.peerOp <- func(peers map[enode.ID]*Peer) { for _, p := range peers { ps = append(ps, p) } @@ -275,7 +297,7 @@ func (srv *Server) Peers() []*Peer { func (srv *Server) PeerCount() int { var count int select { - case srv.peerOp <- func(ps map[discover.NodeID]*Peer) { count = len(ps) }: + case srv.peerOp <- func(ps map[enode.ID]*Peer) { count = len(ps) }: <-srv.peerOpDone case <-srv.quit: } @@ -285,8 +307,7 @@ func (srv *Server) PeerCount() int { // AddPeer connects to the given node and maintains the connection until the // server is shut down. If the connection fails for any reason, the server will // attempt to reconnect the peer. -func (srv *Server) AddPeer(node *discover.Node) { - +func (srv *Server) AddPeer(node *enode.Node) { select { case srv.addstatic <- node: case <-srv.quit: @@ -294,55 +315,83 @@ func (srv *Server) AddPeer(node *discover.Node) { } // RemovePeer disconnects from the given node -func (srv *Server) RemovePeer(node *discover.Node) { +func (srv *Server) RemovePeer(node *enode.Node) { select { case srv.removestatic <- node: case <-srv.quit: } } +// AddTrustedPeer adds the given node to a reserved whitelist which allows the +// node to always connect, even if the slot are full. +func (srv *Server) AddTrustedPeer(node *enode.Node) { + select { + case srv.addtrusted <- node: + case <-srv.quit: + } +} + +// RemoveTrustedPeer removes the given node from the trusted peer set. +func (srv *Server) RemoveTrustedPeer(node *enode.Node) { + select { + case srv.removetrusted <- node: + case <-srv.quit: + } +} + // SubscribePeers subscribes the given channel to peer events func (srv *Server) SubscribeEvents(ch chan *PeerEvent) event.Subscription { return srv.peerFeed.Subscribe(ch) } // Self returns the local node's endpoint information. -func (srv *Server) Self() *discover.Node { +func (srv *Server) Self() *enode.Node { srv.lock.Lock() - defer srv.lock.Unlock() + running, listener, ntab := srv.running, srv.listener, srv.ntab + srv.lock.Unlock() - if !srv.running { - return &discover.Node{IP: net.ParseIP("0.0.0.0")} + if !running { + return enode.NewV4(&srv.PrivateKey.PublicKey, net.ParseIP("0.0.0.0"), 0, 0) } - return srv.makeSelf(srv.listener, srv.ntab) + return srv.makeSelf(listener, ntab) } -func (srv *Server) makeSelf(listener net.Listener, ntab discoverTable) *discover.Node { - // If the server's not running, return an empty node. 
+func (srv *Server) makeSelf(listener net.Listener, ntab discoverTable) *enode.Node { // If the node is running but discovery is off, manually assemble the node infos. if ntab == nil { - // Inbound connections disabled, use zero address. - if listener == nil { - return &discover.Node{IP: net.ParseIP("0.0.0.0"), ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} - } - // Otherwise inject the listener address too - addr := listener.Addr().(*net.TCPAddr) - return &discover.Node{ - ID: discover.PubkeyID(&srv.PrivateKey.PublicKey), - IP: addr.IP, - TCP: uint16(addr.Port), - } + addr := srv.tcpAddr(listener) + return enode.NewV4(&srv.PrivateKey.PublicKey, addr.IP, addr.Port, 0) } // Otherwise return the discovery node. return ntab.Self() } +func (srv *Server) tcpAddr(listener net.Listener) net.TCPAddr { + addr := net.TCPAddr{IP: net.IP{0, 0, 0, 0}} + if listener == nil { + return addr // Inbound connections disabled, use zero address. + } + // Otherwise inject the listener address too. + if a, ok := listener.Addr().(*net.TCPAddr); ok { + addr = *a + } + if srv.NAT != nil { + if ip, err := srv.NAT.ExternalIP(); err == nil { + addr.IP = ip + } + } + if addr.IP.IsUnspecified() { + addr.IP = net.IP{127, 0, 0, 1} + } + return addr +} + // Stop terminates the server and all active peer connections. // It blocks until all active connections have been closed. func (srv *Server) Stop() { srv.lock.Lock() - defer srv.lock.Unlock() if !srv.running { + srv.lock.Unlock() return } srv.running = false @@ -351,6 +400,7 @@ func (srv *Server) Stop() { srv.listener.Close() } close(srv.quit) + srv.lock.Unlock() srv.loopWG.Wait() } @@ -409,8 +459,10 @@ func (srv *Server) Start() (err error) { srv.addpeer = make(chan *conn) srv.delpeer = make(chan peerDrop) srv.posthandshake = make(chan *conn) - srv.addstatic = make(chan *discover.Node) - srv.removestatic = make(chan *discover.Node) + srv.addstatic = make(chan *enode.Node) + srv.removestatic = make(chan *enode.Node) + srv.addtrusted = make(chan *enode.Node) + srv.removetrusted = make(chan *enode.Node) srv.peerOp = make(chan peerOpFunc) srv.peerOpDone = make(chan struct{}) @@ -487,7 +539,8 @@ func (srv *Server) Start() (err error) { dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict) // handshake - srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} + pubkey := crypto.FromECDSAPub(&srv.PrivateKey.PublicKey) + srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: pubkey[1:]} for _, p := range srv.Protocols { srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap()) } @@ -503,7 +556,6 @@ func (srv *Server) Start() (err error) { srv.loopWG.Add(1) go srv.run(dialer) - srv.running = true return nil } @@ -530,27 +582,26 @@ func (srv *Server) startListening() error { } type dialer interface { - newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task + newTasks(running int, peers map[enode.ID]*Peer, now time.Time) []task taskDone(task, time.Time) - addStatic(*discover.Node) - removeStatic(*discover.Node) + addStatic(*enode.Node) + removeStatic(*enode.Node) } func (srv *Server) run(dialstate dialer) { defer srv.loopWG.Done() var ( - peers = make(map[discover.NodeID]*Peer) + peers = make(map[enode.ID]*Peer) inboundCount = 0 - trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes)) + trusted = make(map[enode.ID]bool, len(srv.TrustedNodes)) taskdone = make(chan task, maxActiveDialTasks) 
runningTasks []task queuedTasks []task // tasks that can't run yet ) // Put trusted nodes into a map to speed up checks. - // Trusted peers are loaded on startup and cannot be - // modified while the server is running. + // Trusted peers are loaded on startup or added via AddTrustedPeer RPC. for _, n := range srv.TrustedNodes { - trusted[n.ID] = true + trusted[n.ID()] = true } // removes t from runningTasks @@ -595,17 +646,37 @@ running: // This channel is used by AddPeer to add to the // ephemeral static peer list. Add it to the dialer, // it will keep the node connected. - srv.log.Debug("Adding static node", "node", n) + srv.log.Trace("Adding static node", "node", n) dialstate.addStatic(n) case n := <-srv.removestatic: // This channel is used by RemovePeer to send a // disconnect request to a peer and begin the - // stop keeping the node connected - srv.log.Debug("Removing static node", "node", n) + // stop keeping the node connected. + srv.log.Trace("Removing static node", "node", n) dialstate.removeStatic(n) - if p, ok := peers[n.ID]; ok { + if p, ok := peers[n.ID()]; ok { p.Disconnect(DiscRequested) } + case n := <-srv.addtrusted: + // This channel is used by AddTrustedPeer to add an enode + // to the trusted node set. + srv.log.Trace("Adding trusted node", "node", n) + trusted[n.ID()] = true + // Mark any already-connected peer as trusted + if p, ok := peers[n.ID()]; ok { + p.rw.set(trustedConn, true) + } + case n := <-srv.removetrusted: + // This channel is used by RemoveTrustedPeer to remove an enode + // from the trusted node set. + srv.log.Trace("Removing trusted node", "node", n) + if _, ok := trusted[n.ID()]; ok { + delete(trusted, n.ID()) + } + // Unmark any already-connected peer as trusted + if p, ok := peers[n.ID()]; ok { + p.rw.set(trustedConn, false) + } case op := <-srv.peerOp: // This channel is used by Peers and PeerCount. op(peers) @@ -620,7 +691,7 @@ running: case c := <-srv.posthandshake: // A connection has passed the encryption handshake so // the remote identity is known (but hasn't been verified yet). - if trusted[c.id] { + if trusted[c.node.ID()] { // Ensure that the trusted flag is set before checking against MaxPeers. c.flags |= trustedConn } @@ -643,15 +714,9 @@ running: p.events = &srv.peerFeed } name := truncateName(c.name) - + srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1) go srv.runPeer(p) - if peers[c.id] != nil { - peers[c.id].PairPeer = p - srv.log.Debug("Adding p2p pair peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1) - } else { - peers[c.id] = p - srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1) - } + peers[c.node.ID()] = p if p.Inbound() { inboundCount++ } @@ -698,7 +763,7 @@ running: } } -func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error { +func (srv *Server) protoHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { // Drop connections with no matching protocols. 
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { return DiscUselessPeer @@ -708,19 +773,15 @@ func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inbound return srv.encHandshakeChecks(peers, inboundCount, c) } -func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error { +func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { switch { case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: return DiscTooManyPeers case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns(): return DiscTooManyPeers - case peers[c.id] != nil: - exitPeer := peers[c.id] - if exitPeer.PairPeer != nil { - return DiscAlreadyConnected - } - return nil - case c.id == srv.Self().ID: + case peers[c.node.ID()] != nil: + return DiscAlreadyConnected + case c.node.ID() == srv.Self().ID(): return DiscSelf default: return nil @@ -730,7 +791,6 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCo func (srv *Server) maxInboundConns() int { return srv.MaxPeers - srv.maxDialedConns() } - func (srv *Server) maxDialedConns() int { if srv.NoDiscovery || srv.NoDial { return 0 @@ -750,7 +810,7 @@ type tempError interface { // inbound connections. func (srv *Server) listenLoop() { defer srv.loopWG.Done() - srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab)) + srv.log.Info("RLPx listener up", "self", srv.Self()) tokens := defaultMaxPendingPeers if srv.MaxPendingPeers > 0 { @@ -803,7 +863,7 @@ func (srv *Server) listenLoop() { // SetupConn runs the handshakes and attempts to add the connection // as a peer. It returns when the connection has been added as a peer // or the handshakes have failed. -func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) error { +func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error { self := srv.Self() if self == nil { return errors.New("shutdown") @@ -812,12 +872,12 @@ func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *discover.Nod err := srv.setupConn(c, flags, dialDest) if err != nil { c.close(err) - srv.log.Trace("Setting up connection failed", "id", c.id, "err", err) + srv.log.Trace("Setting up connection failed", "addr", fd.RemoteAddr(), "err", err) } return err } -func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) error { +func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) error { // Prevent leftover pending conns from entering the handshake. srv.lock.Lock() running := srv.running @@ -825,18 +885,30 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) e if !running { return errServerStopped } + // If dialing, figure out the remote public key. + var dialPubkey *ecdsa.PublicKey + if dialDest != nil { + dialPubkey = new(ecdsa.PublicKey) + if err := dialDest.Load((*enode.Secp256k1)(dialPubkey)); err != nil { + return fmt.Errorf("dial destination doesn't have a secp256k1 public key") + } + } // Run the encryption handshake. 
- var err error - if c.id, err = c.doEncHandshake(srv.PrivateKey, dialDest); err != nil { + remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey) + if err != nil { srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err) return err } - clog := srv.log.New("id", c.id, "addr", c.fd.RemoteAddr(), "conn", c.flags) - // For dialed connections, check that the remote public key matches. - if dialDest != nil && c.id != dialDest.ID { - clog.Trace("Dialed identity mismatch", "want", c, dialDest.ID) - return DiscUnexpectedIdentity + if dialDest != nil { + // For dialed connections, check that the remote public key matches. + if dialPubkey.X.Cmp(remotePubkey.X) != 0 || dialPubkey.Y.Cmp(remotePubkey.Y) != 0 { + return DiscUnexpectedIdentity + } + c.node = dialDest + } else { + c.node = nodeFromConn(remotePubkey, c.fd) } + clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags) err = srv.checkpoint(c, srv.posthandshake) if err != nil { clog.Trace("Rejected peer before protocol handshake", "err", err) @@ -848,8 +920,8 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) e clog.Trace("Failed proto handshake", "err", err) return err } - if phs.ID != c.id { - clog.Trace("Wrong devp2p handshake identity", "err", phs.ID) + if id := c.node.ID(); !bytes.Equal(crypto.Keccak256(phs.ID), id[:]) { + clog.Trace("Wrong devp2p handshake identity", "phsid", fmt.Sprintf("%x", phs.ID)) return DiscUnexpectedIdentity } c.caps, c.name = phs.Caps, phs.Name @@ -864,6 +936,16 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *discover.Node) e return nil } +func nodeFromConn(pubkey *ecdsa.PublicKey, conn net.Conn) *enode.Node { + var ip net.IP + var port int + if tcp, ok := conn.RemoteAddr().(*net.TCPAddr); ok { + ip = tcp.IP + port = tcp.Port + } + return enode.NewV4(pubkey, ip, port, port) +} + func truncateName(s string) string { if len(s) > 20 { return s[:20] + "..." 
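The nodeFromConn helper added above synthesizes a descriptor for inbound peers from the remote public key and TCP address, reusing the TCP port for UDP since the real discovery port is unknown at that point. A self-contained sketch of the same construction, assuming the fork keeps go-ethereum's enode.NewV4(pubkey, ip, tcp, udp) signature as the diffs above use it; the key and address below are made up for illustration:

	package main

	import (
		"fmt"
		"net"

		"github.com/tomochain/tomochain/crypto"
		"github.com/tomochain/tomochain/p2p/enode"
	)

	func main() {
		key, _ := crypto.GenerateKey() // throwaway key standing in for the remote peer's
		addr := &net.TCPAddr{IP: net.IP{10, 0, 0, 1}, Port: 30303}
		// As in nodeFromConn, the TCP port doubles as the UDP port.
		n := enode.NewV4(&key.PublicKey, addr.IP, addr.Port, addr.Port)
		fmt.Println(n.ID(), n.IP(), n.TCP())
	}
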
@@ -938,13 +1020,13 @@ func (srv *Server) NodeInfo() *NodeInfo { info := &NodeInfo{ Name: srv.Name, Enode: node.String(), - ID: node.ID.String(), - IP: node.IP.String(), + ID: node.ID().String(), + IP: node.IP().String(), ListenAddr: srv.ListenAddr, Protocols: make(map[string]interface{}), } - info.Ports.Discovery = int(node.UDP) - info.Ports.Listener = int(node.TCP) + info.Ports.Discovery = node.UDP() + info.Ports.Listener = node.TCP() // Gather all the running protocol infos (only once per protocol type) for _, proto := range srv.Protocols { diff --git a/p2p/server_test.go b/p2p/server_test.go index b014bd9c3e..da86ef63dc 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -28,21 +28,22 @@ import ( "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/crypto/sha3" "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" + "github.com/tomochain/tomochain/p2p/enr" ) -func init() { - // log.Root().SetHandler(log.LvlFilterHandler(log.LvlError, log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) -} +// func init() { +// log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) +// } type testTransport struct { - id discover.NodeID + rpub *ecdsa.PublicKey *rlpx closeErr error } -func newTestTransport(id discover.NodeID, fd net.Conn) transport { +func newTestTransport(rpub *ecdsa.PublicKey, fd net.Conn) transport { wrapped := newRLPX(fd).(*rlpx) wrapped.rw = newRLPXFrameRW(fd, secrets{ MAC: zero16, @@ -50,15 +51,16 @@ func newTestTransport(id discover.NodeID, fd net.Conn) transport { IngressMAC: sha3.NewKeccak256(), EgressMAC: sha3.NewKeccak256(), }) - return &testTransport{id: id, rlpx: wrapped} + return &testTransport{rpub: rpub, rlpx: wrapped} } -func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { - return c.id, nil +func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *ecdsa.PublicKey) (*ecdsa.PublicKey, error) { + return c.rpub, nil } func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { - return &protoHandshake{ID: c.id, Name: "test"}, nil + pubkey := crypto.FromECDSAPub(c.rpub)[1:] + return &protoHandshake{ID: pubkey, Name: "test"}, nil } func (c *testTransport) close(err error) { @@ -66,7 +68,7 @@ func (c *testTransport) close(err error) { c.closeErr = err } -func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server { +func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *Server { config := Config{ Name: "test", MaxPeers: 10, @@ -76,7 +78,7 @@ func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server { server := &Server{ Config: config, newPeerHook: pf, - newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) }, + newTransport: func(fd net.Conn) transport { return newTestTransport(remoteKey, fd) }, } if err := server.Start(); err != nil { t.Fatalf("Could not start server: %v", err) @@ -87,14 +89,11 @@ func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server { func TestServerListen(t *testing.T) { // start the test server connected := make(chan *Peer) - remid := randomID() + remid := &newkey().PublicKey srv := startTestServer(t, remid, func(p *Peer) { - if p.ID() != remid { + if p.ID() != enode.PubkeyToIDV4(remid) { t.Error("peer func called with wrong node id") } - if p == nil { - t.Error("peer func called 
with nil conn") - } connected <- p }) defer close(connected) @@ -141,21 +140,23 @@ func TestServerDial(t *testing.T) { // start the server connected := make(chan *Peer) - remid := randomID() + remid := &newkey().PublicKey srv := startTestServer(t, remid, func(p *Peer) { connected <- p }) defer close(connected) defer srv.Stop() // tell the server to connect tcpAddr := listener.Addr().(*net.TCPAddr) - srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}) + node := enode.NewV4(remid, tcpAddr.IP, tcpAddr.Port, 0) + srv.AddPeer(node) select { case conn := <-accepted: defer conn.Close() + select { case peer := <-connected: - if peer.ID() != remid { + if peer.ID() != enode.PubkeyToIDV4(remid) { t.Errorf("peer has wrong id") } if peer.Name() != "test" { @@ -169,25 +170,33 @@ func TestServerDial(t *testing.T) { if !reflect.DeepEqual(peers, []*Peer{peer}) { t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) } - case <-time.After(1 * time.Second): - t.Error("server did not launch peer within one second") - } - select { - case peer := <-connected: - if peer.ID() != remid { - t.Errorf("peer has wrong id") - } - if peer.Name() != "test" { - t.Errorf("peer has wrong name") - } - if peer.RemoteAddr().String() != conn.LocalAddr().String() { - t.Errorf("peer started with wrong conn: got %v, want %v", - peer.RemoteAddr(), conn.LocalAddr()) + // Test AddTrustedPeer/RemoveTrustedPeer and changing Trusted flags + // Particularly for race conditions on changing the flag state. + if peer := srv.Peers()[0]; peer.Info().Network.Trusted { + t.Errorf("peer is trusted prematurely: %v", peer) } + done := make(chan bool) + go func() { + srv.AddTrustedPeer(node) + if peer := srv.Peers()[0]; !peer.Info().Network.Trusted { + t.Errorf("peer is not trusted after AddTrustedPeer: %v", peer) + } + srv.RemoveTrustedPeer(node) + if peer := srv.Peers()[0]; peer.Info().Network.Trusted { + t.Errorf("peer is trusted after RemoveTrustedPeer: %v", peer) + } + done <- true + }() + // Trigger potential race conditions + peer = srv.Peers()[0] + _ = peer.Inbound() + _ = peer.Info() + <-done case <-time.After(1 * time.Second): t.Error("server did not launch peer within one second") } + case <-time.After(1 * time.Second): t.Error("server did not connect within one second") } @@ -201,7 +210,7 @@ func TestServerTaskScheduling(t *testing.T) { quit, returned = make(chan struct{}), make(chan struct{}) tc = 0 tg = taskgen{ - newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { + newFunc: func(running int, peers map[enode.ID]*Peer) []task { tc++ return []task{&testTask{index: tc - 1}} }, @@ -274,7 +283,7 @@ func TestServerManyTasks(t *testing.T) { defer srv.Stop() srv.loopWG.Add(1) go srv.run(taskgen{ - newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { + newFunc: func(running int, peers map[enode.ID]*Peer) []task { start, end = end, end+maxActiveDialTasks+10 if end > len(alltasks) { end = len(alltasks) @@ -309,19 +318,19 @@ func TestServerManyTasks(t *testing.T) { } type taskgen struct { - newFunc func(running int, peers map[discover.NodeID]*Peer) []task + newFunc func(running int, peers map[enode.ID]*Peer) []task doneFunc func(task) } -func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task { +func (tg taskgen) newTasks(running int, peers map[enode.ID]*Peer, now time.Time) []task { return tg.newFunc(running, peers) } func (tg taskgen) taskDone(t task, now time.Time) { tg.doneFunc(t) } -func (tg taskgen) addStatic(*discover.Node) { 
+func (tg taskgen) addStatic(*enode.Node) { } -func (tg taskgen) removeStatic(*discover.Node) { +func (tg taskgen) removeStatic(*enode.Node) { } type testTask struct { @@ -337,13 +346,14 @@ func (t *testTask) Do(srv *Server) { // just after the encryption handshake when the server is // at capacity. Trusted connections should still be accepted. func TestServerAtCap(t *testing.T) { - trustedID := randomID() + trustedNode := newkey() + trustedID := enode.PubkeyToIDV4(&trustedNode.PublicKey) srv := &Server{ Config: Config{ PrivateKey: newkey(), MaxPeers: 10, NoDial: true, - TrustedNodes: []*discover.Node{{ID: trustedID}}, + TrustedNodes: []*enode.Node{newNode(trustedID, nil)}, }, } if err := srv.Start(); err != nil { @@ -351,10 +361,11 @@ func TestServerAtCap(t *testing.T) { } defer srv.Stop() - newconn := func(id discover.NodeID) *conn { + newconn := func(id enode.ID) *conn { fd, _ := net.Pipe() - tx := newTestTransport(id, fd) - return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)} + tx := newTestTransport(&trustedNode.PublicKey, fd) + node := enode.SignNull(new(enr.Record), id) + return &conn{fd: fd, transport: tx, flags: inboundConn, node: node, cont: make(chan error)} } // Inject a few connections to fill up the peer set. @@ -365,7 +376,8 @@ func TestServerAtCap(t *testing.T) { } } // Try inserting a non-trusted connection. - c := newconn(randomID()) + anotherID := randomID() + c := newconn(anotherID) if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { t.Error("wrong error for insert:", err) } @@ -378,62 +390,144 @@ func TestServerAtCap(t *testing.T) { t.Error("Server did not set trusted flag") } + // Remove from trusted set and try again + srv.RemoveTrustedPeer(newNode(trustedID, nil)) + c = newconn(trustedID) + if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { + t.Error("wrong error for insert:", err) + } + + // Add anotherID to trusted set and try again + srv.AddTrustedPeer(newNode(anotherID, nil)) + c = newconn(anotherID) + if err := srv.checkpoint(c, srv.posthandshake); err != nil { + t.Error("unexpected error for trusted conn @posthandshake:", err) + } + if !c.is(trustedConn) { + t.Error("Server did not set trusted flag") + } +} + +func TestServerPeerLimits(t *testing.T) { + srvkey := newkey() + clientkey := newkey() + clientnode := enode.NewV4(&clientkey.PublicKey, nil, 0, 0) + + var tp = &setupTransport{ + pubkey: &clientkey.PublicKey, + phs: protoHandshake{ + ID: crypto.FromECDSAPub(&clientkey.PublicKey)[1:], + // Force "DiscUselessPeer" due to unmatching caps + // Caps: []Cap{discard.cap()}, + }, + } + + srv := &Server{ + Config: Config{ + PrivateKey: srvkey, + MaxPeers: 0, + NoDial: true, + Protocols: []Protocol{discard}, + }, + newTransport: func(fd net.Conn) transport { return tp }, + log: log.New(), + } + if err := srv.Start(); err != nil { + t.Fatalf("couldn't start server: %v", err) + } + defer srv.Stop() + + // Check that server is full (MaxPeers=0) + flags := dynDialedConn + dialDest := clientnode + conn, _ := net.Pipe() + srv.SetupConn(conn, flags, dialDest) + if tp.closeErr != DiscTooManyPeers { + t.Errorf("unexpected close error: %q", tp.closeErr) + } + conn.Close() + + srv.AddTrustedPeer(clientnode) + + // Check that server allows a trusted peer despite being full. 
+ conn, _ = net.Pipe() + srv.SetupConn(conn, flags, dialDest) + if tp.closeErr == DiscTooManyPeers { + t.Errorf("failed to bypass MaxPeers with trusted node: %q", tp.closeErr) + } + + if tp.closeErr != DiscUselessPeer { + t.Errorf("unexpected close error: %q", tp.closeErr) + } + conn.Close() + + srv.RemoveTrustedPeer(clientnode) + + // Check that server is full again. + conn, _ = net.Pipe() + srv.SetupConn(conn, flags, dialDest) + if tp.closeErr != DiscTooManyPeers { + t.Errorf("unexpected close error: %q", tp.closeErr) + } + conn.Close() } func TestServerSetupConn(t *testing.T) { - id := randomID() - srvkey := newkey() - srvid := discover.PubkeyID(&srvkey.PublicKey) + var ( + clientkey, srvkey = newkey(), newkey() + clientpub = &clientkey.PublicKey + srvpub = &srvkey.PublicKey + ) tests := []struct { dontstart bool tt *setupTransport flags connFlag - dialDest *discover.Node + dialDest *enode.Node wantCloseErr error wantCalls string }{ { dontstart: true, - tt: &setupTransport{id: id}, + tt: &setupTransport{pubkey: clientpub}, wantCalls: "close,", wantCloseErr: errServerStopped, }, { - tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")}, + tt: &setupTransport{pubkey: clientpub, encHandshakeErr: errors.New("read error")}, flags: inboundConn, wantCalls: "doEncHandshake,close,", wantCloseErr: errors.New("read error"), }, { - tt: &setupTransport{id: id}, - dialDest: &discover.Node{ID: randomID()}, + tt: &setupTransport{pubkey: clientpub}, + dialDest: enode.NewV4(&newkey().PublicKey, nil, 0, 0), flags: dynDialedConn, wantCalls: "doEncHandshake,close,", wantCloseErr: DiscUnexpectedIdentity, }, { - tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}}, - dialDest: &discover.Node{ID: id}, + tt: &setupTransport{pubkey: clientpub, phs: protoHandshake{ID: randomID().Bytes()}}, + dialDest: enode.NewV4(clientpub, nil, 0, 0), flags: dynDialedConn, wantCalls: "doEncHandshake,doProtoHandshake,close,", wantCloseErr: DiscUnexpectedIdentity, }, { - tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")}, - dialDest: &discover.Node{ID: id}, + tt: &setupTransport{pubkey: clientpub, protoHandshakeErr: errors.New("foo")}, + dialDest: enode.NewV4(clientpub, nil, 0, 0), flags: dynDialedConn, wantCalls: "doEncHandshake,doProtoHandshake,close,", wantCloseErr: errors.New("foo"), }, { - tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}}, + tt: &setupTransport{pubkey: srvpub, phs: protoHandshake{ID: crypto.FromECDSAPub(srvpub)[1:]}}, flags: inboundConn, wantCalls: "doEncHandshake,close,", wantCloseErr: DiscSelf, }, { - tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}}, + tt: &setupTransport{pubkey: clientpub, phs: protoHandshake{ID: crypto.FromECDSAPub(clientpub)[1:]}}, flags: inboundConn, wantCalls: "doEncHandshake,doProtoHandshake,close,", wantCloseErr: DiscUselessPeer, @@ -468,26 +562,26 @@ func TestServerSetupConn(t *testing.T) { } type setupTransport struct { - id discover.NodeID - encHandshakeErr error - - phs *protoHandshake + pubkey *ecdsa.PublicKey + encHandshakeErr error + phs protoHandshake protoHandshakeErr error calls string closeErr error } -func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { +func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *ecdsa.PublicKey) (*ecdsa.PublicKey, error) { c.calls += "doEncHandshake," - return c.id, c.encHandshakeErr + return c.pubkey, c.encHandshakeErr } + func (c *setupTransport) doProtoHandshake(our *protoHandshake) 
(*protoHandshake, error) { c.calls += "doProtoHandshake," if c.protoHandshakeErr != nil { return nil, c.protoHandshakeErr } - return c.phs, nil + return &c.phs, nil } func (c *setupTransport) close(err error) { c.calls += "close," @@ -510,7 +604,7 @@ func newkey() *ecdsa.PrivateKey { return key } -func randomID() (id discover.NodeID) { +func randomID() (id enode.ID) { for i := range id { id[i] = byte(rand.Intn(255)) } From 1fbef4c1bffaea0b54e23eb13e9ed9921cd9afa0 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 30 Oct 2023 16:40:24 +0700 Subject: [PATCH 098/119] Port p2p/simulations and testing to p2p/enode --- p2p/simulations/adapters/docker.go | 4 +- p2p/simulations/adapters/exec.go | 10 ++-- p2p/simulations/adapters/inproc.go | 30 +++++------ p2p/simulations/adapters/types.go | 75 ++++++++++++++++++++------- p2p/simulations/examples/ping-pong.go | 6 +-- p2p/simulations/http.go | 13 +++-- p2p/simulations/http_test.go | 9 ++-- p2p/simulations/mocker.go | 30 +++++------ p2p/simulations/mocker_test.go | 4 +- p2p/simulations/network.go | 70 ++++++++++++------------- p2p/simulations/network_test.go | 10 ++-- p2p/simulations/pipes/pipes.go | 55 ++++++++++++++++++++ p2p/simulations/simulation.go | 14 ++--- p2p/testing/peerpool.go | 12 ++--- p2p/testing/protocolsession.go | 30 +++++------ p2p/testing/protocoltester.go | 8 +-- 16 files changed, 237 insertions(+), 143 deletions(-) create mode 100644 p2p/simulations/pipes/pipes.go diff --git a/p2p/simulations/adapters/docker.go b/p2p/simulations/adapters/docker.go index 51469a4a55..10fd732321 100644 --- a/p2p/simulations/adapters/docker.go +++ b/p2p/simulations/adapters/docker.go @@ -30,7 +30,7 @@ import ( "github.com/docker/docker/pkg/reexec" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/node" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" ) // DockerAdapter is a NodeAdapter which runs simulation nodes inside Docker @@ -61,7 +61,7 @@ func NewDockerAdapter() (*DockerAdapter, error) { return &DockerAdapter{ ExecAdapter{ - nodes: make(map[discover.NodeID]*ExecNode), + nodes: make(map[enode.ID]*ExecNode), }, }, nil } diff --git a/p2p/simulations/adapters/exec.go b/p2p/simulations/adapters/exec.go index 31a7dbe3fd..58e2613123 100644 --- a/p2p/simulations/adapters/exec.go +++ b/p2p/simulations/adapters/exec.go @@ -39,7 +39,7 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/rpc" "golang.org/x/net/websocket" ) @@ -55,7 +55,7 @@ type ExecAdapter struct { // simulation node are created. 
BaseDir string - nodes map[discover.NodeID]*ExecNode + nodes map[enode.ID]*ExecNode } // NewExecAdapter returns an ExecAdapter which stores node data in @@ -63,7 +63,7 @@ type ExecAdapter struct { func NewExecAdapter(baseDir string) *ExecAdapter { return &ExecAdapter{ BaseDir: baseDir, - nodes: make(map[discover.NodeID]*ExecNode), + nodes: make(map[enode.ID]*ExecNode), } } @@ -123,7 +123,7 @@ func (e *ExecAdapter) NewNode(config *NodeConfig) (Node, error) { // ExecNode starts a simulation node by exec'ing the current binary and // running the configured services type ExecNode struct { - ID discover.NodeID + ID enode.ID Dir string Config *execNodeConfig Cmd *exec.Cmd @@ -498,7 +498,7 @@ type wsRPCDialer struct { // DialRPC implements the RPCDialer interface by creating a WebSocket RPC // client of the given node -func (w *wsRPCDialer) DialRPC(id discover.NodeID) (*rpc.Client, error) { +func (w *wsRPCDialer) DialRPC(id enode.ID) (*rpc.Client, error) { addr, ok := w.addrs[id.String()] if !ok { return nil, fmt.Errorf("unknown node: %s", id) diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go index 5ebfb91094..0ca16072cb 100644 --- a/p2p/simulations/adapters/inproc.go +++ b/p2p/simulations/adapters/inproc.go @@ -27,7 +27,7 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/rpc" ) @@ -35,7 +35,7 @@ import ( // connects them using in-memory net.Pipe connections type SimAdapter struct { mtx sync.RWMutex - nodes map[discover.NodeID]*SimNode + nodes map[enode.ID]*SimNode services map[string]ServiceFunc } @@ -44,7 +44,7 @@ type SimAdapter struct { // particular node are passed to the NewNode function in the NodeConfig) func NewSimAdapter(services map[string]ServiceFunc) *SimAdapter { return &SimAdapter{ - nodes: make(map[discover.NodeID]*SimNode), + nodes: make(map[enode.ID]*SimNode), services: services, } } @@ -96,7 +96,7 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) { node: n, adapter: s, running: make(map[string]node.Service), - connected: make(map[discover.NodeID]bool), + connected: make(map[enode.ID]bool), } s.nodes[id] = simNode return simNode, nil @@ -104,12 +104,12 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) { // Dial implements the p2p.NodeDialer interface by connecting to the node using // an in-memory net.Pipe connection -func (s *SimAdapter) Dial(dest *discover.Node) (conn net.Conn, err error) { - node, ok := s.GetNode(dest.ID) +func (s *SimAdapter) Dial(dest *enode.Node) (conn net.Conn, err error) { + node, ok := s.GetNode(dest.ID()) if !ok { return nil, fmt.Errorf("unknown node: %s", dest.ID) } - if node.connected[dest.ID] { + if node.connected[dest.ID()] { return nil, fmt.Errorf("dialed node: %s", dest.ID) } srv := node.Server() @@ -118,13 +118,13 @@ func (s *SimAdapter) Dial(dest *discover.Node) (conn net.Conn, err error) { } pipe1, pipe2 := net.Pipe() go srv.SetupConn(pipe1, 0, nil) - node.connected[dest.ID] = true + node.connected[dest.ID()] = true return pipe2, nil } // DialRPC implements the RPCDialer interface by creating an in-memory RPC // client of the given node -func (s *SimAdapter) DialRPC(id discover.NodeID) (*rpc.Client, error) { +func (s *SimAdapter) DialRPC(id enode.ID) (*rpc.Client, error) { node, ok := s.GetNode(id) if !ok { return nil, fmt.Errorf("unknown node: %s", id) @@ -137,7 +137,7 @@ func (s 
*SimAdapter) DialRPC(id discover.NodeID) (*rpc.Client, error) { } // GetNode returns the node with the given ID if it exists -func (s *SimAdapter) GetNode(id discover.NodeID) (*SimNode, bool) { +func (s *SimAdapter) GetNode(id enode.ID) (*SimNode, bool) { s.mtx.RLock() defer s.mtx.RUnlock() node, ok := s.nodes[id] @@ -149,14 +149,14 @@ func (s *SimAdapter) GetNode(id discover.NodeID) (*SimNode, bool) { // protocols directly over that pipe type SimNode struct { lock sync.RWMutex - ID discover.NodeID + ID enode.ID config *NodeConfig adapter *SimAdapter node *node.Node running map[string]node.Service client *rpc.Client registerOnce sync.Once - connected map[discover.NodeID]bool + connected map[enode.ID]bool } // Addr returns the node's discovery address @@ -164,9 +164,9 @@ func (self *SimNode) Addr() []byte { return []byte(self.Node().String()) } -// Node returns a discover.Node representing the SimNode -func (self *SimNode) Node() *discover.Node { - return discover.NewNode(self.ID, net.IP{127, 0, 0, 1}, 30303, 30303) +// Node returns a node descriptor representing the SimNode +func (sn *SimNode) Node() *enode.Node { + return sn.config.Node() } // Client returns an rpc.Client which can be used to communicate with the diff --git a/p2p/simulations/adapters/types.go b/p2p/simulations/adapters/types.go index f03bbe75b1..089f50ea20 100644 --- a/p2p/simulations/adapters/types.go +++ b/p2p/simulations/adapters/types.go @@ -23,12 +23,14 @@ import ( "fmt" "net" "os" + "strconv" "github.com/docker/docker/pkg/reexec" + "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/rpc" ) @@ -38,7 +40,6 @@ import ( // * SimNode - An in-memory node // * ExecNode - A child process node // * DockerNode - A Docker container node -// type Node interface { // Addr returns the node's address (e.g. 
an Enode URL) Addr() []byte @@ -77,7 +78,7 @@ type NodeAdapter interface { type NodeConfig struct { // ID is the node's ID which is used to identify the node in the // simulation network - ID discover.NodeID + ID enode.ID // PrivateKey is the node's private key which is used by the devp2p // stack to encrypt communications @@ -96,25 +97,31 @@ type NodeConfig struct { Services []string // function to sanction or prevent suggesting a peer - Reachable func(id discover.NodeID) bool + Reachable func(id enode.ID) bool + + Port uint16 } // nodeConfigJSON is used to encode and decode NodeConfig as JSON by encoding // all fields as strings type nodeConfigJSON struct { - ID string `json:"id"` - PrivateKey string `json:"private_key"` - Name string `json:"name"` - Services []string `json:"services"` + ID string `json:"id"` + PrivateKey string `json:"private_key"` + Name string `json:"name"` + Services []string `json:"services"` + EnableMsgEvents bool `json:"enable_msg_events"` + Port uint16 `json:"port"` } // MarshalJSON implements the json.Marshaler interface by encoding the config // fields as strings func (n *NodeConfig) MarshalJSON() ([]byte, error) { confJSON := nodeConfigJSON{ - ID: n.ID.String(), - Name: n.Name, - Services: n.Services, + ID: n.ID.String(), + Name: n.Name, + Services: n.Services, + Port: n.Port, + EnableMsgEvents: n.EnableMsgEvents, } if n.PrivateKey != nil { confJSON.PrivateKey = hex.EncodeToString(crypto.FromECDSA(n.PrivateKey)) @@ -131,11 +138,9 @@ func (n *NodeConfig) UnmarshalJSON(data []byte) error { } if confJSON.ID != "" { - nodeID, err := discover.HexID(confJSON.ID) - if err != nil { + if err := n.ID.UnmarshalText([]byte(confJSON.ID)); err != nil { return err } - n.ID = nodeID } if confJSON.PrivateKey != "" { @@ -152,10 +157,17 @@ func (n *NodeConfig) UnmarshalJSON(data []byte) error { n.Name = confJSON.Name n.Services = confJSON.Services + n.Port = confJSON.Port + n.EnableMsgEvents = confJSON.EnableMsgEvents return nil } +// Node returns the node descriptor represented by the config. 
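The helper that follows builds the descriptor with enode.NewV4, so a node's identity is now derived from its public key rather than carried verbatim in the URL, and enode.ParseV4 recovers it on the way back in (the cmd/wnode and cmd/utils hunks below rely on exactly this). A minimal standalone sketch of that relationship, using a freshly generated key and an arbitrary port rather than a real NodeConfig:

package main

import (
	"fmt"
	"net"

	"github.com/tomochain/tomochain/crypto"
	"github.com/tomochain/tomochain/p2p/enode"
)

func main() {
	key, err := crypto.GenerateKey()
	if err != nil {
		panic(err)
	}
	// Build a v4 descriptor the way NodeConfig.Node does; 30303 stands in
	// for the config's port and is arbitrary here.
	n := enode.NewV4(&key.PublicKey, net.IP{127, 0, 0, 1}, 30303, 30303)

	// Round-trip through the textual enode URL; the ID is derived from the
	// public key, so all three values agree.
	parsed, err := enode.ParseV4(n.String())
	if err != nil {
		panic(err)
	}
	fmt.Println(parsed.ID() == n.ID())                             // true
	fmt.Println(parsed.ID() == enode.PubkeyToIDV4(&key.PublicKey)) // true
}
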
+func (n *NodeConfig) Node() *enode.Node { + return enode.NewV4(&n.PrivateKey.PublicKey, net.IP{127, 0, 0, 1}, int(n.Port), int(n.Port)) +} + // RandomNodeConfig returns node configuration with a randomly generated ID and // PrivateKey func RandomNodeConfig() *NodeConfig { @@ -163,13 +175,36 @@ func RandomNodeConfig() *NodeConfig { if err != nil { panic("unable to generate key") } - var id discover.NodeID - pubkey := crypto.FromECDSAPub(&key.PublicKey) - copy(id[:], pubkey[1:]) + + id := enode.PubkeyToIDV4(&key.PublicKey) + port, err := assignTCPPort() + if err != nil { + panic("unable to assign tcp port") + } return &NodeConfig{ - ID: id, - PrivateKey: key, + ID: id, + Name: fmt.Sprintf("node_%s", id.String()), + PrivateKey: key, + Port: port, + EnableMsgEvents: true, + } +} + +func assignTCPPort() (uint16, error) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return 0, err + } + l.Close() + _, port, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + return 0, err + } + p, err := strconv.ParseInt(port, 10, 32) + if err != nil { + return 0, err } + return uint16(p), nil } // ServiceContext is a collection of options and methods which can be utilised @@ -186,7 +221,7 @@ type ServiceContext struct { // other nodes in the network (for example a simulated Swarm node which needs // to connect to a Geth node to resolve ENS names) type RPCDialer interface { - DialRPC(id discover.NodeID) (*rpc.Client, error) + DialRPC(id enode.ID) (*rpc.Client, error) } // Services is a collection of services which can be run in a simulation diff --git a/p2p/simulations/examples/ping-pong.go b/p2p/simulations/examples/ping-pong.go index dae524d05b..de7a9e6b5c 100644 --- a/p2p/simulations/examples/ping-pong.go +++ b/p2p/simulations/examples/ping-pong.go @@ -28,7 +28,7 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/simulations" "github.com/tomochain/tomochain/p2p/simulations/adapters" "github.com/tomochain/tomochain/rpc" @@ -96,12 +96,12 @@ func main() { // sends a ping to all its connected peers every 10s and receives a pong in // return type pingPongService struct { - id discover.NodeID + id enode.ID log log.Logger received int64 } -func newPingPongService(id discover.NodeID) *pingPongService { +func newPingPongService(id enode.ID) *pingPongService { return &pingPongService{ id: id, log: log.New("node.id", id), diff --git a/p2p/simulations/http.go b/p2p/simulations/http.go index 29159b6fc1..d7ed380a4e 100644 --- a/p2p/simulations/http.go +++ b/p2p/simulations/http.go @@ -29,10 +29,11 @@ import ( "strings" "sync" + "github.com/tomochain/tomochain/p2p/enode" + "github.com/julienschmidt/httprouter" "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" "github.com/tomochain/tomochain/p2p/simulations/adapters" "github.com/tomochain/tomochain/rpc" "golang.org/x/net/websocket" @@ -698,18 +699,19 @@ func (s *Server) JSON(w http.ResponseWriter, status int, data interface{}) { json.NewEncoder(w).Encode(data) } -// wrapHandler returns a httprouter.Handle which wraps a http.HandlerFunc by +// wrapHandler returns an httprouter.Handle which wraps an http.HandlerFunc by // populating request.Context with any objects from the URL params func (s *Server) wrapHandler(handler http.HandlerFunc) httprouter.Handle { return func(w 
http.ResponseWriter, req *http.Request, params httprouter.Params) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - ctx := context.Background() + ctx := req.Context() if id := params.ByName("nodeid"); id != "" { + var nodeID enode.ID var node *Node - if nodeID, err := discover.HexID(id); err == nil { + if nodeID.UnmarshalText([]byte(id)) == nil { node = s.network.GetNode(nodeID) } else { node = s.network.GetNodeByName(id) @@ -722,8 +724,9 @@ func (s *Server) wrapHandler(handler http.HandlerFunc) httprouter.Handle { } if id := params.ByName("peerid"); id != "" { + var peerID enode.ID var peer *Node - if peerID, err := discover.HexID(id); err == nil { + if peerID.UnmarshalText([]byte(id)) == nil { peer = s.network.GetNode(peerID) } else { peer = s.network.GetNodeByName(id) diff --git a/p2p/simulations/http_test.go b/p2p/simulations/http_test.go index e00b8057c7..a89301895f 100644 --- a/p2p/simulations/http_test.go +++ b/p2p/simulations/http_test.go @@ -31,6 +31,7 @@ import ( "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/simulations/adapters" "github.com/tomochain/tomochain/rpc" ) @@ -38,12 +39,12 @@ import ( // testService implements the node.Service interface and provides protocols // and APIs which are useful for testing nodes in a simulation network type testService struct { - id discover.NodeID + id enode.ID // peerCount is incremented once a peer handshake has been performed peerCount int64 - peers map[discover.NodeID]*testPeer + peers map[enode.ID]*testPeer peersMtx sync.Mutex // state stores []byte which is used to test creating and loading @@ -54,7 +55,7 @@ type testService struct { func newTestService(ctx *adapters.ServiceContext) (node.Service, error) { svc := &testService{ id: ctx.Config.ID, - peers: make(map[discover.NodeID]*testPeer), + peers: make(map[enode.ID]*testPeer), } svc.state.Store(ctx.Snapshot) return svc, nil @@ -65,7 +66,7 @@ type testPeer struct { dumReady chan struct{} } -func (t *testService) peer(id discover.NodeID) *testPeer { +func (t *testService) peer(id enode.ID) *testPeer { t.peersMtx.Lock() defer t.peersMtx.Unlock() if peer, ok := t.peers[id]; ok { diff --git a/p2p/simulations/mocker.go b/p2p/simulations/mocker.go index daff17e29d..d052d9e26d 100644 --- a/p2p/simulations/mocker.go +++ b/p2p/simulations/mocker.go @@ -25,23 +25,23 @@ import ( "time" "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" ) -//a map of mocker names to its function +// a map of mocker names to its function var mockerList = map[string]func(net *Network, quit chan struct{}, nodeCount int){ "startStop": startStop, "probabilistic": probabilistic, "boot": boot, } -//Lookup a mocker by its name, returns the mockerFn +// Lookup a mocker by its name, returns the mockerFn func LookupMocker(mockerType string) func(net *Network, quit chan struct{}, nodeCount int) { return mockerList[mockerType] } -//Get a list of mockers (keys of the map) -//Useful for frontend to build available mocker selection +// Get a list of mockers (keys of the map) +// Useful for frontend to build available mocker selection func GetMockerList() []string { list := make([]string, 0, len(mockerList)) for k := range mockerList { @@ -50,7 +50,7 @@ func GetMockerList() []string { return list } -//The boot 
mockerFn only connects the node in a ring and doesn't do anything else +// The boot mockerFn only connects the node in a ring and doesn't do anything else func boot(net *Network, quit chan struct{}, nodeCount int) { _, err := connectNodesInRing(net, nodeCount) if err != nil { @@ -58,7 +58,7 @@ func boot(net *Network, quit chan struct{}, nodeCount int) { } } -//The startStop mockerFn stops and starts nodes in a defined period (ticker) +// The startStop mockerFn stops and starts nodes in a defined period (ticker) func startStop(net *Network, quit chan struct{}, nodeCount int) { nodes, err := connectNodesInRing(net, nodeCount) if err != nil { @@ -95,10 +95,10 @@ func startStop(net *Network, quit chan struct{}, nodeCount int) { } } -//The probabilistic mocker func has a more probabilistic pattern -//(the implementation could probably be improved): -//nodes are connected in a ring, then a varying number of random nodes is selected, -//mocker then stops and starts them in random intervals, and continues the loop +// The probabilistic mocker func has a more probabilistic pattern +// (the implementation could probably be improved): +// nodes are connected in a ring, then a varying number of random nodes is selected, +// mocker then stops and starts them in random intervals, and continues the loop func probabilistic(net *Network, quit chan struct{}, nodeCount int) { nodes, err := connectNodesInRing(net, nodeCount) if err != nil { @@ -147,7 +147,7 @@ func probabilistic(net *Network, quit chan struct{}, nodeCount int) { wg.Done() continue } - go func(id discover.NodeID) { + go func(id enode.ID) { time.Sleep(randWait) err := net.Start(id) if err != nil { @@ -161,9 +161,9 @@ func probabilistic(net *Network, quit chan struct{}, nodeCount int) { } -//connect nodeCount number of nodes in a ring -func connectNodesInRing(net *Network, nodeCount int) ([]discover.NodeID, error) { - ids := make([]discover.NodeID, nodeCount) +// connect nodeCount number of nodes in a ring +func connectNodesInRing(net *Network, nodeCount int) ([]enode.ID, error) { + ids := make([]enode.ID, nodeCount) for i := 0; i < nodeCount; i++ { node, err := net.NewNode() if err != nil { diff --git a/p2p/simulations/mocker_test.go b/p2p/simulations/mocker_test.go index f9a23bfe12..397c4c5dcf 100644 --- a/p2p/simulations/mocker_test.go +++ b/p2p/simulations/mocker_test.go @@ -27,7 +27,7 @@ import ( "testing" "time" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" ) func TestMocker(t *testing.T) { @@ -82,7 +82,7 @@ func TestMocker(t *testing.T) { defer sub.Unsubscribe() //wait until all nodes are started and connected //store every node up event in a map (value is irrelevant, mimic Set datatype) - nodemap := make(map[discover.NodeID]bool) + nodemap := make(map[enode.ID]bool) wg.Add(1) nodesComplete := false connCount := 0 diff --git a/p2p/simulations/network.go b/p2p/simulations/network.go index 08643f7d89..1c50b8f359 100644 --- a/p2p/simulations/network.go +++ b/p2p/simulations/network.go @@ -27,7 +27,7 @@ import ( "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/simulations/adapters" ) @@ -51,7 +51,7 @@ type Network struct { NetworkConfig Nodes []*Node `json:"nodes"` - nodeMap map[discover.NodeID]int + nodeMap map[enode.ID]int Conns []*Conn `json:"conns"` connMap map[string]int @@ -67,7 +67,7 @@ func 
NewNetwork(nodeAdapter adapters.NodeAdapter, conf *NetworkConfig) *Network return &Network{ NetworkConfig: *conf, nodeAdapter: nodeAdapter, - nodeMap: make(map[discover.NodeID]int), + nodeMap: make(map[enode.ID]int), connMap: make(map[string]int), quitc: make(chan struct{}), } @@ -92,14 +92,14 @@ func (self *Network) NewNodeWithConfig(conf *adapters.NodeConfig) (*Node, error) defer self.lock.Unlock() // create a random ID and PrivateKey if not set - if conf.ID == (discover.NodeID{}) { + if conf.ID == (enode.ID{}) { c := adapters.RandomNodeConfig() conf.ID = c.ID conf.PrivateKey = c.PrivateKey } id := conf.ID if conf.Reachable == nil { - conf.Reachable = func(otherID discover.NodeID) bool { + conf.Reachable = func(otherID enode.ID) bool { _, err := self.InitConn(conf.ID, otherID) return err == nil } @@ -174,13 +174,13 @@ func (self *Network) StopAll() error { } // Start starts the node with the given ID -func (self *Network) Start(id discover.NodeID) error { +func (self *Network) Start(id enode.ID) error { return self.startWithSnapshots(id, nil) } // startWithSnapshots starts the node with the given ID using the give // snapshots -func (self *Network) startWithSnapshots(id discover.NodeID, snapshots map[string][]byte) error { +func (self *Network) startWithSnapshots(id enode.ID, snapshots map[string][]byte) error { node := self.GetNode(id) if node == nil { return fmt.Errorf("node %v does not exist", id) @@ -214,7 +214,7 @@ func (self *Network) startWithSnapshots(id discover.NodeID, snapshots map[string // watchPeerEvents reads peer events from the given channel and emits // corresponding network events -func (self *Network) watchPeerEvents(id discover.NodeID, events chan *p2p.PeerEvent, sub event.Subscription) { +func (self *Network) watchPeerEvents(id enode.ID, events chan *p2p.PeerEvent, sub event.Subscription) { defer func() { sub.Unsubscribe() @@ -258,7 +258,7 @@ func (self *Network) watchPeerEvents(id discover.NodeID, events chan *p2p.PeerEv } // Stop stops the node with the given ID -func (self *Network) Stop(id discover.NodeID) error { +func (self *Network) Stop(id enode.ID) error { node := self.GetNode(id) if node == nil { return fmt.Errorf("node %v does not exist", id) @@ -278,7 +278,7 @@ func (self *Network) Stop(id discover.NodeID) error { // Connect connects two nodes together by calling the "admin_addPeer" RPC // method on the "one" node so that it connects to the "other" node -func (self *Network) Connect(oneID, otherID discover.NodeID) error { +func (self *Network) Connect(oneID, otherID enode.ID) error { log.Debug(fmt.Sprintf("connecting %s to %s", oneID, otherID)) conn, err := self.InitConn(oneID, otherID) if err != nil { @@ -294,7 +294,7 @@ func (self *Network) Connect(oneID, otherID discover.NodeID) error { // Disconnect disconnects two nodes by calling the "admin_removePeer" RPC // method on the "one" node so that it disconnects from the "other" node -func (self *Network) Disconnect(oneID, otherID discover.NodeID) error { +func (self *Network) Disconnect(oneID, otherID enode.ID) error { conn := self.GetConn(oneID, otherID) if conn == nil { return fmt.Errorf("connection between %v and %v does not exist", oneID, otherID) @@ -311,7 +311,7 @@ func (self *Network) Disconnect(oneID, otherID discover.NodeID) error { } // DidConnect tracks the fact that the "one" node connected to the "other" node -func (self *Network) DidConnect(one, other discover.NodeID) error { +func (self *Network) DidConnect(one, other enode.ID) error { conn, err := self.GetOrCreateConn(one, other) if err 
!= nil { return fmt.Errorf("connection between %v and %v does not exist", one, other) @@ -326,7 +326,7 @@ func (self *Network) DidConnect(one, other discover.NodeID) error { // DidDisconnect tracks the fact that the "one" node disconnected from the // "other" node -func (self *Network) DidDisconnect(one, other discover.NodeID) error { +func (self *Network) DidDisconnect(one, other enode.ID) error { conn := self.GetConn(one, other) if conn == nil { return fmt.Errorf("connection between %v and %v does not exist", one, other) @@ -341,7 +341,7 @@ func (self *Network) DidDisconnect(one, other discover.NodeID) error { } // DidSend tracks the fact that "sender" sent a message to "receiver" -func (self *Network) DidSend(sender, receiver discover.NodeID, proto string, code uint64) error { +func (self *Network) DidSend(sender, receiver enode.ID, proto string, code uint64) error { msg := &Msg{ One: sender, Other: receiver, @@ -354,7 +354,7 @@ func (self *Network) DidSend(sender, receiver discover.NodeID, proto string, cod } // DidReceive tracks the fact that "receiver" received a message from "sender" -func (self *Network) DidReceive(sender, receiver discover.NodeID, proto string, code uint64) error { +func (self *Network) DidReceive(sender, receiver enode.ID, proto string, code uint64) error { msg := &Msg{ One: sender, Other: receiver, @@ -368,7 +368,7 @@ func (self *Network) DidReceive(sender, receiver discover.NodeID, proto string, // GetNode gets the node with the given ID, returning nil if the node does not // exist -func (self *Network) GetNode(id discover.NodeID) *Node { +func (self *Network) GetNode(id enode.ID) *Node { self.lock.Lock() defer self.lock.Unlock() return self.getNode(id) @@ -382,7 +382,7 @@ func (self *Network) GetNodeByName(name string) *Node { return self.getNodeByName(name) } -func (self *Network) getNode(id discover.NodeID) *Node { +func (self *Network) getNode(id enode.ID) *Node { i, found := self.nodeMap[id] if !found { return nil @@ -410,7 +410,7 @@ func (self *Network) GetNodes() (nodes []*Node) { // GetConn returns the connection which exists between "one" and "other" // regardless of which node initiated the connection -func (self *Network) GetConn(oneID, otherID discover.NodeID) *Conn { +func (self *Network) GetConn(oneID, otherID enode.ID) *Conn { self.lock.Lock() defer self.lock.Unlock() return self.getConn(oneID, otherID) @@ -418,13 +418,13 @@ func (self *Network) GetConn(oneID, otherID discover.NodeID) *Conn { // GetOrCreateConn is like GetConn but creates the connection if it doesn't // already exist -func (self *Network) GetOrCreateConn(oneID, otherID discover.NodeID) (*Conn, error) { +func (self *Network) GetOrCreateConn(oneID, otherID enode.ID) (*Conn, error) { self.lock.Lock() defer self.lock.Unlock() return self.getOrCreateConn(oneID, otherID) } -func (self *Network) getOrCreateConn(oneID, otherID discover.NodeID) (*Conn, error) { +func (self *Network) getOrCreateConn(oneID, otherID enode.ID) (*Conn, error) { if conn := self.getConn(oneID, otherID); conn != nil { return conn, nil } @@ -449,7 +449,7 @@ func (self *Network) getOrCreateConn(oneID, otherID discover.NodeID) (*Conn, err return conn, nil } -func (self *Network) getConn(oneID, otherID discover.NodeID) *Conn { +func (self *Network) getConn(oneID, otherID enode.ID) *Conn { label := ConnLabel(oneID, otherID) i, found := self.connMap[label] if !found { @@ -466,7 +466,7 @@ func (self *Network) getConn(oneID, otherID discover.NodeID) *Conn { // it also checks whether there has been recent attempt to 
connect the peers // this is cheating as the simulation is used as an oracle and know about // remote peers attempt to connect to a node which will then not initiate the connection -func (self *Network) InitConn(oneID, otherID discover.NodeID) (*Conn, error) { +func (self *Network) InitConn(oneID, otherID enode.ID) (*Conn, error) { self.lock.Lock() defer self.lock.Unlock() if oneID == otherID { @@ -501,15 +501,15 @@ func (self *Network) Shutdown() { close(self.quitc) } -//Reset resets all network properties: -//emtpies the nodes and the connection list +// Reset resets all network properties: +// emtpies the nodes and the connection list func (self *Network) Reset() { self.lock.Lock() defer self.lock.Unlock() //re-initialize the maps self.connMap = make(map[string]int) - self.nodeMap = make(map[discover.NodeID]int) + self.nodeMap = make(map[enode.ID]int) self.Nodes = nil self.Conns = nil @@ -528,7 +528,7 @@ type Node struct { } // ID returns the ID of the node -func (self *Node) ID() discover.NodeID { +func (self *Node) ID() enode.ID { return self.Config.ID } @@ -565,10 +565,10 @@ func (self *Node) MarshalJSON() ([]byte, error) { // Conn represents a connection between two nodes in the network type Conn struct { // One is the node which initiated the connection - One discover.NodeID `json:"one"` + One enode.ID `json:"one"` // Other is the node which the connection was made to - Other discover.NodeID `json:"other"` + Other enode.ID `json:"other"` // Up tracks whether or not the connection is active Up bool `json:"up"` @@ -597,11 +597,11 @@ func (self *Conn) String() string { // Msg represents a p2p message sent between two nodes in the network type Msg struct { - One discover.NodeID `json:"one"` - Other discover.NodeID `json:"other"` - Protocol string `json:"protocol"` - Code uint64 `json:"code"` - Received bool `json:"received"` + One enode.ID `json:"one"` + Other enode.ID `json:"other"` + Protocol string `json:"protocol"` + Code uint64 `json:"code"` + Received bool `json:"received"` } // String returns a log-friendly string @@ -612,8 +612,8 @@ func (self *Msg) String() string { // ConnLabel generates a deterministic string which represents a connection // between two nodes, used to compare if two connections are between the same // nodes -func ConnLabel(source, target discover.NodeID) string { - var first, second discover.NodeID +func ConnLabel(source, target enode.ID) string { + var first, second enode.ID if bytes.Compare(source.Bytes(), target.Bytes()) > 0 { first = target second = source diff --git a/p2p/simulations/network_test.go b/p2p/simulations/network_test.go index da97428cf6..a89083b629 100644 --- a/p2p/simulations/network_test.go +++ b/p2p/simulations/network_test.go @@ -22,7 +22,7 @@ import ( "testing" "time" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/simulations/adapters" ) @@ -39,7 +39,7 @@ func TestNetworkSimulation(t *testing.T) { }) defer network.Shutdown() nodeCount := 20 - ids := make([]discover.NodeID, nodeCount) + ids := make([]enode.ID, nodeCount) for i := 0; i < nodeCount; i++ { node, err := network.NewNode() if err != nil { @@ -63,7 +63,7 @@ func TestNetworkSimulation(t *testing.T) { } return nil } - check := func(ctx context.Context, id discover.NodeID) (bool, error) { + check := func(ctx context.Context, id enode.ID) (bool, error) { // check we haven't run out of time select { case <-ctx.Done(): @@ -101,7 +101,7 @@ func TestNetworkSimulation(t *testing.T) { defer cancel() // 
trigger a check every 100ms - trigger := make(chan discover.NodeID) + trigger := make(chan enode.ID) go triggerChecks(ctx, ids, trigger, 100*time.Millisecond) result := NewSimulation(network).Run(ctx, &Step{ @@ -139,7 +139,7 @@ func TestNetworkSimulation(t *testing.T) { } } -func triggerChecks(ctx context.Context, ids []discover.NodeID, trigger chan discover.NodeID, interval time.Duration) { +func triggerChecks(ctx context.Context, ids []enode.ID, trigger chan enode.ID, interval time.Duration) { tick := time.NewTicker(interval) defer tick.Stop() for { diff --git a/p2p/simulations/pipes/pipes.go b/p2p/simulations/pipes/pipes.go new file mode 100644 index 0000000000..ec277c0d14 --- /dev/null +++ b/p2p/simulations/pipes/pipes.go @@ -0,0 +1,55 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package pipes + +import ( + "net" +) + +// NetPipe wraps net.Pipe in a signature returning an error +func NetPipe() (net.Conn, net.Conn, error) { + p1, p2 := net.Pipe() + return p1, p2, nil +} + +// TCPPipe creates an in process full duplex pipe based on a localhost TCP socket +func TCPPipe() (net.Conn, net.Conn, error) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, err + } + defer l.Close() + + var aconn net.Conn + aerr := make(chan error, 1) + go func() { + var err error + aconn, err = l.Accept() + aerr <- err + }() + + dconn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + <-aerr + return nil, nil, err + } + if err := <-aerr; err != nil { + dconn.Close() + return nil, nil, err + } + return aconn, dconn, nil +} diff --git a/p2p/simulations/simulation.go b/p2p/simulations/simulation.go index 6fc879ed10..0879cd9123 100644 --- a/p2p/simulations/simulation.go +++ b/p2p/simulations/simulation.go @@ -20,7 +20,7 @@ import ( "context" "time" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" ) // Simulation provides a framework for running actions in a simulated network @@ -55,7 +55,7 @@ func (s *Simulation) Run(ctx context.Context, step *Step) (result *StepResult) { } // wait for all node expectations to either pass, error or timeout - nodes := make(map[discover.NodeID]struct{}, len(step.Expect.Nodes)) + nodes := make(map[enode.ID]struct{}, len(step.Expect.Nodes)) for _, id := range step.Expect.Nodes { nodes[id] = struct{}{} } @@ -119,7 +119,7 @@ type Step struct { // Trigger is a channel which receives node ids and triggers an // expectation check for that node - Trigger chan discover.NodeID + Trigger chan enode.ID // Expect is the expectation to wait for when performing this step Expect *Expectation @@ -127,15 +127,15 @@ type Step struct { type Expectation struct { // Nodes is a list of nodes to check - Nodes []discover.NodeID + Nodes []enode.ID // Check checks whether a given node meets the 
expectation - Check func(context.Context, discover.NodeID) (bool, error) + Check func(context.Context, enode.ID) (bool, error) } func newStepResult() *StepResult { return &StepResult{ - Passes: make(map[discover.NodeID]time.Time), + Passes: make(map[enode.ID]time.Time), } } @@ -150,7 +150,7 @@ type StepResult struct { FinishedAt time.Time // Passes are the timestamps of the successful node expectations - Passes map[discover.NodeID]time.Time + Passes map[enode.ID]time.Time // NetworkEvents are the network events which occurred during the step NetworkEvents []*Event diff --git a/p2p/testing/peerpool.go b/p2p/testing/peerpool.go index 0934cfbdb8..9b80e8e051 100644 --- a/p2p/testing/peerpool.go +++ b/p2p/testing/peerpool.go @@ -21,22 +21,22 @@ import ( "sync" "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" ) type TestPeer interface { - ID() discover.NodeID + ID() enode.ID Drop(error) } // TestPeerPool is an example peerPool to demonstrate registration of peer connections type TestPeerPool struct { lock sync.Mutex - peers map[discover.NodeID]TestPeer + peers map[enode.ID]TestPeer } func NewTestPeerPool() *TestPeerPool { - return &TestPeerPool{peers: make(map[discover.NodeID]TestPeer)} + return &TestPeerPool{peers: make(map[enode.ID]TestPeer)} } func (self *TestPeerPool) Add(p TestPeer) { @@ -53,14 +53,14 @@ func (self *TestPeerPool) Remove(p TestPeer) { delete(self.peers, p.ID()) } -func (self *TestPeerPool) Has(id discover.NodeID) bool { +func (self *TestPeerPool) Has(id enode.ID) bool { self.lock.Lock() defer self.lock.Unlock() _, ok := self.peers[id] return ok } -func (self *TestPeerPool) Get(id discover.NodeID) TestPeer { +func (self *TestPeerPool) Get(id enode.ID) TestPeer { self.lock.Lock() defer self.lock.Unlock() return self.peers[id] diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go index 6f4d4c4994..783af99bf2 100644 --- a/p2p/testing/protocolsession.go +++ b/p2p/testing/protocolsession.go @@ -24,7 +24,7 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/simulations/adapters" ) @@ -35,7 +35,7 @@ var errTimedOut = errors.New("timed out") // receive (expect) messages type ProtocolSession struct { Server *p2p.Server - IDs []discover.NodeID + IDs []enode.ID adapter *adapters.SimAdapter events chan *p2p.PeerEvent } @@ -56,25 +56,25 @@ type Exchange struct { // Trigger is part of the exchange, incoming message for the pivot node // sent by a peer type Trigger struct { - Msg interface{} // type of message to be sent - Code uint64 // code of message is given - Peer discover.NodeID // the peer to send the message to - Timeout time.Duration // timeout duration for the sending + Msg interface{} // type of message to be sent + Code uint64 // code of message is given + Peer enode.ID // the peer to send the message to + Timeout time.Duration // timeout duration for the sending } // Expect is part of an exchange, outgoing message from the pivot node // received by a peer type Expect struct { - Msg interface{} // type of message to expect - Code uint64 // code of message is now given - Peer discover.NodeID // the peer that expects the message - Timeout time.Duration // timeout duration for receiving + Msg interface{} // type of message to expect + Code uint64 // code of message is now given + Peer enode.ID // the peer that expects 
the message + Timeout time.Duration // timeout duration for receiving } // Disconnect represents a disconnect event, used and checked by TestDisconnected type Disconnect struct { - Peer discover.NodeID // discconnected peer - Error error // disconnect reason + Peer enode.ID // discconnected peer + Error error // disconnect reason } // trigger sends messages from peers @@ -109,7 +109,7 @@ func (self *ProtocolSession) trigger(trig Trigger) error { // expect checks an expectation of a message sent out by the pivot node func (self *ProtocolSession) expect(exps []Expect) error { // construct a map of expectations for each node - peerExpects := make(map[discover.NodeID][]Expect) + peerExpects := make(map[enode.ID][]Expect) for _, exp := range exps { if exp.Msg == nil { return errors.New("no message to expect") @@ -118,7 +118,7 @@ func (self *ProtocolSession) expect(exps []Expect) error { } // construct a map of mockNodes for each node - mockNodes := make(map[discover.NodeID]*mockNode) + mockNodes := make(map[enode.ID]*mockNode) for nodeID := range peerExpects { simNode, ok := self.adapter.GetNode(nodeID) if !ok { @@ -251,7 +251,7 @@ func (self *ProtocolSession) testExchange(e Exchange) error { // TestDisconnected tests the disconnections given as arguments // the disconnect structs describe what disconnect error is expected on which peer func (self *ProtocolSession) TestDisconnected(disconnects ...*Disconnect) error { - expects := make(map[discover.NodeID]error) + expects := make(map[enode.ID]error) for _, disconnect := range disconnects { expects[disconnect.Peer] = disconnect.Error } diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go index 0ac8b05f34..f16acbac95 100644 --- a/p2p/testing/protocoltester.go +++ b/p2p/testing/protocoltester.go @@ -35,7 +35,7 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/simulations" "github.com/tomochain/tomochain/p2p/simulations/adapters" "github.com/tomochain/tomochain/rlp" @@ -52,7 +52,7 @@ type ProtocolTester struct { // NewProtocolTester constructs a new ProtocolTester // it takes as argument the pivot node id, the number of dummy peers and the // protocol run function called on a peer connection by the p2p server -func NewProtocolTester(t *testing.T, id discover.NodeID, n int, run func(*p2p.Peer, p2p.MsgReadWriter) error) *ProtocolTester { +func NewProtocolTester(t *testing.T, id enode.ID, n int, run func(*p2p.Peer, p2p.MsgReadWriter) error) *ProtocolTester { services := adapters.Services{ "test": func(ctx *adapters.ServiceContext) (node.Service, error) { return &testNode{run}, nil @@ -76,7 +76,7 @@ func NewProtocolTester(t *testing.T, id discover.NodeID, n int, run func(*p2p.Pe node := net.GetNode(id).Node.(*adapters.SimNode) peers := make([]*adapters.NodeConfig, n) - peerIDs := make([]discover.NodeID, n) + peerIDs := make([]enode.ID, n) for i := 0; i < n; i++ { peers[i] = adapters.RandomNodeConfig() peers[i].Services = []string{"mock"} @@ -108,7 +108,7 @@ func (self *ProtocolTester) Stop() error { // Connect brings up the remote peer node and connects it using the // p2p/simulations network connection with the in memory network adapter -func (self *ProtocolTester) Connect(selfID discover.NodeID, peers ...*adapters.NodeConfig) { +func (self *ProtocolTester) Connect(selfID enode.ID, peers ...*adapters.NodeConfig) { for _, peer := 
range peers { log.Trace(fmt.Sprintf("start node %v", peer.ID)) if _, err := self.network.NewNodeWithConfig(peer); err != nil { From 70f4fafcf00461bef4c9413ca57d0dd4c6129e93 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 30 Oct 2023 16:41:34 +0700 Subject: [PATCH 099/119] cmd, swarm: port to p2p/enode --- cmd/bootnode/main.go | 3 +- cmd/faucet/faucet.go | 4 +- cmd/p2psim/main.go | 25 ++- cmd/swarm/main.go | 6 +- cmd/utils/flags.go | 6 +- cmd/wnode/main.go | 10 +- eth/handler.go | 4 +- eth/peer.go | 225 +++++++++++++++----------- eth/sync.go | 4 +- eth/sync_test.go | 6 +- les/handler.go | 10 +- les/peer.go | 36 ++--- les/protocol.go | 14 +- les/serverpool.go | 323 ++++++++++++++++++++++++-------------- node/api.go | 6 +- node/config.go | 12 +- swarm/network/hive.go | 12 +- swarm/network/messages.go | 22 +-- swarm/swarm.go | 6 +- whisper/whisperv5/api.go | 14 +- whisper/whisperv6/api.go | 12 +- 21 files changed, 448 insertions(+), 312 deletions(-) diff --git a/cmd/bootnode/main.go b/cmd/bootnode/main.go index c582847c54..f3a8278873 100644 --- a/cmd/bootnode/main.go +++ b/cmd/bootnode/main.go @@ -29,6 +29,7 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/p2p/discover" "github.com/tomochain/tomochain/p2p/discv5" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/nat" "github.com/tomochain/tomochain/p2p/netutil" ) @@ -85,7 +86,7 @@ func main() { } if *writeAddr { - fmt.Printf("%v\n", discover.PubkeyID(&nodeKey.PublicKey)) + fmt.Printf("%v\n", enode.PubkeyToIDV4(&nodeKey.PublicKey)) os.Exit(0) } diff --git a/cmd/faucet/faucet.go b/cmd/faucet/faucet.go index 45a5e6cb4f..33a17f35b7 100644 --- a/cmd/faucet/faucet.go +++ b/cmd/faucet/faucet.go @@ -54,8 +54,8 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" "github.com/tomochain/tomochain/p2p/discv5" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/nat" "github.com/tomochain/tomochain/params" "golang.org/x/net/websocket" @@ -255,7 +255,7 @@ func newFaucet(genesis *core.Genesis, port int, enodes []*discv5.Node, network u return nil, err } for _, boot := range enodes { - old, _ := discover.ParseNode(boot.String()) + old, _ := enode.ParseV4(boot.String()) stack.Server().AddPeer(old) } // Attach to the client and retrieve and interesting metadatas diff --git a/cmd/p2psim/main.go b/cmd/p2psim/main.go index 7ae0b8b56e..a39c5da3ae 100644 --- a/cmd/p2psim/main.go +++ b/cmd/p2psim/main.go @@ -19,21 +19,20 @@ // Here is an example of creating a 2 node network with the first node // connected to the second: // -// $ p2psim node create -// Created node01 +// $ p2psim node create +// Created node01 // -// $ p2psim node start node01 -// Started node01 +// $ p2psim node start node01 +// Started node01 // -// $ p2psim node create -// Created node02 +// $ p2psim node create +// Created node02 // -// $ p2psim node start node02 -// Started node02 -// -// $ p2psim node connect node01 node02 -// Connected node01 to node02 +// $ p2psim node start node02 +// Started node02 // +// $ p2psim node connect node01 node02 +// Connected node01 to node02 package main import ( @@ -47,7 +46,7 @@ import ( "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/simulations" 
"github.com/tomochain/tomochain/p2p/simulations/adapters" "github.com/tomochain/tomochain/rpc" @@ -283,7 +282,7 @@ func createNode(ctx *cli.Context) error { if err != nil { return err } - config.ID = discover.PubkeyID(&privKey.PublicKey) + config.ID = enode.PubkeyToIDV4(&privKey.PublicKey) config.PrivateKey = privKey } if services := ctx.String("services"); services != "" { diff --git a/cmd/swarm/main.go b/cmd/swarm/main.go index ecd6aae792..221ccbbcb7 100644 --- a/cmd/swarm/main.go +++ b/cmd/swarm/main.go @@ -39,7 +39,7 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/swarm" bzzapi "github.com/tomochain/tomochain/swarm/api" @@ -153,7 +153,7 @@ var ( } ) -//declare a few constant error messages, useful for later error check comparisons in test +// declare a few constant error messages, useful for later error check comparisons in test var ( SWARM_ERR_NO_BZZACCOUNT = "bzzaccount option is required but not set; check your config file, command line or environment variables" SWARM_ERR_SWAP_SET_NO_API = "SWAP is enabled but --swap-api is not set" @@ -543,7 +543,7 @@ func getPassPhrase(prompt string, i int, passwords []string) string { func injectBootnodes(srv *p2p.Server, nodes []string) { for _, url := range nodes { - n, err := discover.ParseNode(url) + n, err := enode.ParseV4(url) if err != nil { log.Error("Invalid swarm bootnode", "err", err) continue diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 90e85528b2..5f8542b290 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -46,8 +46,8 @@ import ( "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" "github.com/tomochain/tomochain/p2p/discv5" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/nat" "github.com/tomochain/tomochain/p2p/netutil" "github.com/tomochain/tomochain/params" @@ -643,9 +643,9 @@ func setBootstrapNodes(ctx *cli.Context, cfg *p2p.Config) { case ctx.GlobalBool(TomoTestnetFlag.Name): urls = params.TestnetBootnodes } - cfg.BootstrapNodes = make([]*discover.Node, 0, len(urls)) + cfg.BootstrapNodes = make([]*enode.Node, 0, len(urls)) for _, url := range urls { - node, err := discover.ParseNode(url) + node, err := enode.ParseV4(url) if err != nil { log.Error("Bootstrap URL invalid", "enode", url, "err", err) continue diff --git a/cmd/wnode/main.go b/cmd/wnode/main.go index 78c558bc12..b5795e7158 100644 --- a/cmd/wnode/main.go +++ b/cmd/wnode/main.go @@ -41,7 +41,7 @@ import ( "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/nat" "github.com/tomochain/tomochain/whisper/mailserver" whisper "github.com/tomochain/tomochain/whisper/whisperv6" @@ -175,7 +175,7 @@ func initialize() { log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*argVerbosity), log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) done = make(chan struct{}) - var peers []*discover.Node + var peers []*enode.Node var err error if *generateKey { @@ -203,7 +203,7 @@ func initialize() { if len(*argEnode) == 0 { argEnode = scanLineA("Please enter the peer's enode: ") } - peer := 
discover.MustParseNode(*argEnode) + peer := enode.MustParseV4(*argEnode) peers = append(peers, peer) } @@ -750,11 +750,11 @@ func requestExpiredMessagesLoop() { } func extractIDFromEnode(s string) []byte { - n, err := discover.ParseNode(s) + n, err := enode.ParseV4(s) if err != nil { utils.Fatalf("Failed to parse enode: %s", err) } - return n.ID[:] + return n.ID().Bytes() } // obfuscateBloom adds 16 random bits to the the bloom diff --git a/eth/handler.go b/eth/handler.go index eae95a9b1b..2f377fe8e8 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -38,7 +38,7 @@ import ( "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" ) @@ -178,7 +178,7 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne NodeInfo: func() interface{} { return manager.NodeInfo() }, - PeerInfo: func(id discover.NodeID) interface{} { + PeerInfo: func(id enode.ID) interface{} { if p := manager.peers.Peer(fmt.Sprintf("%x", id[:8])); p != nil { return p.Info() } diff --git a/eth/peer.go b/eth/peer.go index 6942678852..314c384588 100644 --- a/eth/peer.go +++ b/eth/peer.go @@ -24,6 +24,7 @@ import ( "time" mapset "github.com/deckarep/golang-set" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/p2p" @@ -38,10 +39,26 @@ var ( const ( maxKnownTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS) + maxKnownBlocks = 1024 // Maximum block hashes to keep in the known list (prevent DOS) maxKnownOrderTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS) maxKnownLendingTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS) - maxKnownBlocks = 1024 // Maximum block hashes to keep in the known list (prevent DOS) - handshakeTimeout = 5 * time.Second + + // maxQueuedTxs is the maximum number of transaction lists to queue up before + // dropping broadcasts. This is a sensitive number as a transaction list might + // contain a single transaction, or thousands. + maxQueuedTxs = 128 + + // maxQueuedProps is the maximum number of block propagations to queue up before + // dropping broadcasts. There's not much point in queueing stale blocks, so a few + // that might cover uncles should be enough. + maxQueuedProps = 4 + + // maxQueuedAnns is the maximum number of block announcements to queue up before + // dropping broadcasts. Similarly to block propagations, there's no point to queue + // above some healthy uncle limit, so use that. + maxQueuedAnns = 4 + + handshakeTimeout = 5 * time.Second ) // PeerInfo represents a short summary of the Ethereum sub-protocol metadata known @@ -52,12 +69,17 @@ type PeerInfo struct { Head string `json:"head"` // SHA3 hash of the peer's best owned block } +// propEvent is a block propagation, waiting for its turn in the broadcast queue. 
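propEvent and the broadcast loop introduced below give each peer a bounded, drop-on-overflow write queue: the maxQueued* constants above size the channels, a single goroutine drains them, and the enqueue sites never block. A reduced model of the pattern, with illustrative names rather than the real eth types:

package main

import (
	"fmt"
	"time"
)

// item stands in for propEvent; the real queues carry blocks, hashes and
// transaction lists.
type item struct{ n int }

func main() {
	queued := make(chan item, 4) // cf. maxQueuedProps above
	done := make(chan struct{})

	// One writer goroutine per peer drains the queue, like peer.broadcast.
	go func() {
		for it := range queued {
			time.Sleep(10 * time.Millisecond) // simulate a slow peer
			fmt.Println("sent", it.n)
		}
		close(done)
	}()

	// Enqueue sites never block: when the buffer is full the event is
	// dropped, which is exactly the trade-off the Async* helpers make.
	for n := 0; n < 10; n++ {
		select {
		case queued <- item{n}:
		default:
			fmt.Println("dropped", n)
		}
	}
	close(queued)
	<-done
}
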
+type propEvent struct { + block *types.Block + td *big.Int +} + type peer struct { id string *p2p.Peer - rw p2p.MsgReadWriter - pairRw p2p.MsgReadWriter + rw p2p.MsgReadWriter version int // Protocol version negotiated forkDrop *time.Timer // Timed connection dropper if forks aren't validated in time @@ -66,27 +88,66 @@ type peer struct { td *big.Int lock sync.RWMutex - knownTxs mapset.Set // Set of transaction hashes known to be known by this peer - knownBlocks mapset.Set // Set of block hashes known to be known by this peer - knownOrderTxs mapset.Set // Set of order transaction hashes known to be known by this peer - knownLendingTxs mapset.Set // Set of lending transaction hashes known to be known by this peer + knownTxs mapset.Set // Set of transaction hashes known to be known by this peer + knownBlocks mapset.Set // Set of block hashes known to be known by this peer + knownOrderTxs mapset.Set // Set of order transaction hashes known to be known by this peer + knownLendingTxs mapset.Set // Set of lending transaction hashes known to be known by this peer + queuedTxs chan []*types.Transaction // Queue of transactions to broadcast to the peer + queuedProps chan *propEvent // Queue of blocks to broadcast to the peer + queuedAnns chan *types.Block // Queue of blocks to announce to the peer + term chan struct{} // Termination channel to stop the broadcaster } func newPeer(version int, p *p2p.Peer, rw p2p.MsgReadWriter) *peer { - id := p.ID() - return &peer{ - Peer: p, - rw: rw, - version: version, - id: fmt.Sprintf("%x", id[:8]), - knownTxs: mapset.NewSet(), - knownBlocks: mapset.NewSet(), - knownOrderTxs: mapset.NewSet(), - knownLendingTxs: mapset.NewSet(), + Peer: p, + rw: rw, + version: version, + id: fmt.Sprintf("%x", p.ID().Bytes()[:8]), + knownTxs: mapset.NewSet(), + knownBlocks: mapset.NewSet(), + queuedTxs: make(chan []*types.Transaction, maxQueuedTxs), + queuedProps: make(chan *propEvent, maxQueuedProps), + queuedAnns: make(chan *types.Block, maxQueuedAnns), + term: make(chan struct{}), + } +} + +// broadcast is a write loop that multiplexes block propagations, announcements +// and transaction broadcasts into the remote peer. The goal is to have an async +// writer that does not lock up node internals. +func (p *peer) broadcast() { + for { + select { + case txs := <-p.queuedTxs: + if err := p.SendTransactions(txs); err != nil { + return + } + p.Log().Trace("Broadcast transactions", "count", len(txs)) + + case prop := <-p.queuedProps: + if err := p.SendNewBlock(prop.block, prop.td); err != nil { + return + } + p.Log().Trace("Propagated block", "number", prop.block.Number(), "hash", prop.block.Hash(), "td", prop.td) + + case block := <-p.queuedAnns: + if err := p.SendNewBlockHashes([]common.Hash{block.Hash()}, []uint64{block.NumberU64()}); err != nil { + return + } + p.Log().Trace("Announced block", "number", block.Number(), "hash", block.Hash()) + + case <-p.term: + return + } } } +// close signals the broadcast goroutine to terminate. +func (p *peer) close() { + close(p.term) +} + // Info gathers and returns a collection of metadata known about a peer. func (p *peer) Info() *PeerInfo { hash, td := p.Head() @@ -184,6 +245,19 @@ func (p *peer) SendLendingTransactions(txs types.LendingTransactions) error { return p2p.Send(p.rw, LendingTxMsg, txs) } +// AsyncSendTransactions queues list of transactions propagation to a remote +// peer. If the peer's broadcast queue is full, the event is silently dropped. 
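These Async* helpers also record every hash they enqueue in the peer's known sets, which are capped by maxKnownTxs and friends. A reduced sketch of such capped-set bookkeeping using the mapset package imported above; the cap of three and the Pop-based eviction are illustrative, not lifted from this patch:

package main

import (
	"fmt"

	mapset "github.com/deckarep/golang-set"
)

func main() {
	const maxKnown = 3
	known := mapset.NewSet()

	for _, h := range []string{"tx1", "tx2", "tx3", "tx4"} {
		// Make room before adding once the cap is reached.
		for known.Cardinality() >= maxKnown {
			known.Pop() // evict an arbitrary entry
		}
		known.Add(h)
	}
	fmt.Println(known.Cardinality(), known.Contains("tx4")) // 3 true
}
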
+func (p *peer) AsyncSendTransactions(txs []*types.Transaction) { + select { + case p.queuedTxs <- txs: + for _, tx := range txs { + p.knownTxs.Add(tx.Hash()) + } + default: + p.Log().Debug("Dropping transaction propagation", "count", len(txs)) + } +} + // SendNewBlockHashes announces the availability of a number of blocks through // a hash notification. func (p *peer) SendNewBlockHashes(hashes []common.Hash, numbers []uint64) error { @@ -198,127 +272,102 @@ func (p *peer) SendNewBlockHashes(hashes []common.Hash, numbers []uint64) error return p2p.Send(p.rw, NewBlockHashesMsg, request) } +// AsyncSendNewBlockHash queues the availability of a block for propagation to a +// remote peer. If the peer's broadcast queue is full, the event is silently +// dropped. +func (p *peer) AsyncSendNewBlockHash(block *types.Block) { + select { + case p.queuedAnns <- block: + p.knownBlocks.Add(block.Hash()) + default: + p.Log().Debug("Dropping block announcement", "number", block.NumberU64(), "hash", block.Hash()) + } +} + // SendNewBlock propagates an entire block to a remote peer. func (p *peer) SendNewBlock(block *types.Block, td *big.Int) error { p.knownBlocks.Add(block.Hash()) - if p.pairRw != nil { - return p2p.Send(p.pairRw, NewBlockMsg, []interface{}{block, td}) - } else { - return p2p.Send(p.rw, NewBlockMsg, []interface{}{block, td}) + return p2p.Send(p.rw, NewBlockMsg, []interface{}{block, td}) +} + +// AsyncSendNewBlock queues an entire block for propagation to a remote peer. If +// the peer's broadcast queue is full, the event is silently dropped. +func (p *peer) AsyncSendNewBlock(block *types.Block, td *big.Int) { + select { + case p.queuedProps <- &propEvent{block: block, td: td}: + p.knownBlocks.Add(block.Hash()) + default: + p.Log().Debug("Dropping block propagation", "number", block.NumberU64(), "hash", block.Hash()) } } // SendBlockHeaders sends a batch of block headers to the remote peer. func (p *peer) SendBlockHeaders(headers []*types.Header) error { - if p.pairRw != nil { - return p2p.Send(p.pairRw, BlockHeadersMsg, headers) - } else { - return p2p.Send(p.rw, BlockHeadersMsg, headers) - } + return p2p.Send(p.rw, BlockHeadersMsg, headers) } // SendBlockBodies sends a batch of block contents to the remote peer. func (p *peer) SendBlockBodies(bodies []*blockBody) error { - if p.pairRw != nil { - return p2p.Send(p.pairRw, BlockBodiesMsg, blockBodiesData(bodies)) - } else { - return p2p.Send(p.rw, BlockBodiesMsg, blockBodiesData(bodies)) - } + return p2p.Send(p.rw, BlockBodiesMsg, blockBodiesData(bodies)) } // SendBlockBodiesRLP sends a batch of block contents to the remote peer from // an already RLP encoded format. func (p *peer) SendBlockBodiesRLP(bodies []rlp.RawValue) error { - if p.pairRw != nil { - return p2p.Send(p.pairRw, BlockBodiesMsg, bodies) - } else { - return p2p.Send(p.rw, BlockBodiesMsg, bodies) - } + return p2p.Send(p.rw, BlockBodiesMsg, bodies) } // SendNodeDataRLP sends a batch of arbitrary internal data, corresponding to the // hashes requested. func (p *peer) SendNodeData(data [][]byte) error { - if p.pairRw != nil { - return p2p.Send(p.pairRw, NodeDataMsg, data) - } else { - return p2p.Send(p.rw, NodeDataMsg, data) - } + return p2p.Send(p.rw, NodeDataMsg, data) } // SendReceiptsRLP sends a batch of transaction receipts, corresponding to the // ones requested from an already RLP encoded format. 
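With the pairRw special cases removed, every one of these senders funnels through a single p2p.Send on the peer's MsgReadWriter. A self-contained sketch of that call over an in-memory message pipe, the same facility the sync test further down uses; the message code and payload here are arbitrary:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/p2p"
)

func main() {
	// MsgPipe returns a connected pair of in-memory MsgReadWriters.
	rw1, rw2 := p2p.MsgPipe()
	defer rw1.Close()
	defer rw2.Close()

	go func() {
		// Send RLP-encodes the payload and writes one message.
		if err := p2p.Send(rw1, 0x10, []string{"one", "two"}); err != nil {
			panic(err)
		}
	}()

	msg, err := rw2.ReadMsg()
	if err != nil {
		panic(err)
	}
	var payload []string
	if err := msg.Decode(&payload); err != nil {
		panic(err)
	}
	fmt.Println(msg.Code, payload) // 16 [one two]
}
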
func (p *peer) SendReceiptsRLP(receipts []rlp.RawValue) error { - if p.pairRw != nil { - return p2p.Send(p.pairRw, ReceiptsMsg, receipts) - } else { - return p2p.Send(p.rw, ReceiptsMsg, receipts) - } + return p2p.Send(p.rw, ReceiptsMsg, receipts) } // RequestOneHeader is a wrapper around the header query functions to fetch a // single header. It is used solely by the fetcher. func (p *peer) RequestOneHeader(hash common.Hash) error { p.Log().Debug("Fetching single header", "hash", hash) - if p.pairRw != nil { - return p2p.Send(p.pairRw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false}) - } else { - return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false}) - } + return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false}) } // RequestHeadersByHash fetches a batch of blocks' headers corresponding to the // specified header query, based on the hash of an origin block. func (p *peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error { p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse) - if p.pairRw != nil { - return p2p.Send(p.pairRw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) - } else { - return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) - } + return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) } // RequestHeadersByNumber fetches a batch of blocks' headers corresponding to the // specified header query, based on the number of an origin block. func (p *peer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error { p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse) - if p.pairRw != nil { - return p2p.Send(p.pairRw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) - } else { - return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) - } + return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) } // RequestBodies fetches a batch of blocks' bodies corresponding to the hashes // specified. func (p *peer) RequestBodies(hashes []common.Hash) error { p.Log().Debug("Fetching batch of block bodies", "count", len(hashes)) - if p.pairRw != nil { - return p2p.Send(p.pairRw, GetBlockBodiesMsg, hashes) - } else { - return p2p.Send(p.rw, GetBlockBodiesMsg, hashes) - } + return p2p.Send(p.rw, GetBlockBodiesMsg, hashes) } // RequestNodeData fetches a batch of arbitrary data from a node's known state // data, corresponding to the specified hashes. 
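All of the header requests above share one query shape: an origin hash or number, an amount, a skip between consecutive headers, and a direction flag, which together imply a stride of skip+1. A small sketch of the block numbers such a query covers:

package main

import "fmt"

// coveredNumbers lists the block numbers a query touches: a stride of
// skip+1 starting at origin, amount entries, optionally walking backwards.
func coveredNumbers(origin uint64, amount, skip int, reverse bool) []uint64 {
	out := make([]uint64, 0, amount)
	step := uint64(skip) + 1
	n := origin
	for i := 0; i < amount; i++ {
		out = append(out, n)
		if reverse {
			if n < step {
				break // would run past genesis
			}
			n -= step
		} else {
			n += step
		}
	}
	return out
}

func main() {
	fmt.Println(coveredNumbers(100, 4, 2, false)) // [100 103 106 109]
	fmt.Println(coveredNumbers(100, 4, 2, true))  // [100 97 94 91]
}
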
func (p *peer) RequestNodeData(hashes []common.Hash) error { p.Log().Debug("Fetching batch of state data", "count", len(hashes)) - if p.pairRw != nil { - return p2p.Send(p.pairRw, GetNodeDataMsg, hashes) - } else { - return p2p.Send(p.rw, GetNodeDataMsg, hashes) - } + return p2p.Send(p.rw, GetNodeDataMsg, hashes) } // RequestReceipts fetches a batch of transaction receipts from a remote node. func (p *peer) RequestReceipts(hashes []common.Hash) error { p.Log().Debug("Fetching batch of receipts", "count", len(hashes)) - if p.pairRw != nil { - return p2p.Send(p.pairRw, GetReceiptsMsg, hashes) - } else { - return p2p.Send(p.rw, GetReceiptsMsg, hashes) - } + return p2p.Send(p.rw, GetReceiptsMsg, hashes) } // Handshake executes the eth protocol handshake, negotiating version number, @@ -406,7 +455,8 @@ func newPeerSet() *peerSet { } // Register injects a new peer into the working set, or returns an error if the -// peer is already known. +// peer is already known. If a new peer it registered, its broadcast loop is also +// started. func (ps *peerSet) Register(p *peer) error { ps.lock.Lock() defer ps.lock.Unlock() @@ -414,16 +464,12 @@ func (ps *peerSet) Register(p *peer) error { if ps.closed { return errClosed } - if existPeer, ok := ps.peers[p.id]; ok { - if existPeer.pairRw != nil { - return errAlreadyRegistered - } - existPeer.PairPeer = p.Peer - existPeer.pairRw = p.rw - p.PairPeer = existPeer.Peer - return p2p.ErrAddPairPeer + if _, ok := ps.peers[p.id]; ok { + return errAlreadyRegistered } ps.peers[p.id] = p + go p.broadcast() + return nil } @@ -433,10 +479,13 @@ func (ps *peerSet) Unregister(id string) error { ps.lock.Lock() defer ps.lock.Unlock() - if _, ok := ps.peers[id]; !ok { + p, ok := ps.peers[id] + if !ok { return errNotRegistered } delete(ps.peers, id) + p.close() + return nil } @@ -486,7 +535,7 @@ func (ps *peerSet) PeersWithoutTx(hash common.Hash) []*peer { return list } -// PeersWithoutTx retrieves a list of peers that do not have a given transaction +// OrderPeersWithoutTx retrieves a list of peers that do not have a given transaction // in their set of known hashes. func (ps *peerSet) OrderPeersWithoutTx(hash common.Hash) []*peer { ps.lock.RLock() diff --git a/eth/sync.go b/eth/sync.go index a1224b3caf..ae95ad8d5e 100644 --- a/eth/sync.go +++ b/eth/sync.go @@ -25,7 +25,7 @@ import ( "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/eth/downloader" "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" ) const ( @@ -64,7 +64,7 @@ func (pm *ProtocolManager) syncTransactions(p *peer) { // the transactions in small packs to one peer at a time. 
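txsyncLoop below keys its pending map by enode.ID; because enode.ID is a fixed-size byte array it is comparable, so every map previously keyed by discover.NodeID swaps over with no further changes. A tiny sketch:

package main

import (
	"fmt"

	"github.com/tomochain/tomochain/crypto"
	"github.com/tomochain/tomochain/p2p/enode"
)

func main() {
	key, err := crypto.GenerateKey()
	if err != nil {
		panic(err)
	}
	// enode.ID is a [32]byte value, hence comparable and usable as a map
	// key exactly like discover.NodeID was.
	id := enode.PubkeyToIDV4(&key.PublicKey)

	pending := make(map[enode.ID]bool)
	pending[id] = true
	fmt.Println(pending[id], len(pending)) // true 1
}
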
func (pm *ProtocolManager) txsyncLoop() { var ( - pending = make(map[discover.NodeID]*txsync) + pending = make(map[enode.ID]*txsync) sending = false // whether a send is active pack = new(txsync) // the pack that is being sent done = make(chan error, 1) // result of the send diff --git a/eth/sync_test.go b/eth/sync_test.go index 9b447f2a1c..491a7513c7 100644 --- a/eth/sync_test.go +++ b/eth/sync_test.go @@ -23,7 +23,7 @@ import ( "github.com/tomochain/tomochain/eth/downloader" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" ) // Tests that fast sync gets disabled as soon as a real block is successfully @@ -42,8 +42,8 @@ func TestFastSyncDisabling(t *testing.T) { // Sync up the two peers io1, io2 := p2p.MsgPipe() - go pmFull.handle(pmFull.newPeer(63, p2p.NewPeer(discover.NodeID{}, "empty", nil), io2)) - go pmEmpty.handle(pmEmpty.newPeer(63, p2p.NewPeer(discover.NodeID{}, "full", nil), io1)) + go pmFull.handle(pmFull.newPeer(63, p2p.NewPeer(enode.ID{}, "empty", nil), io2)) + go pmEmpty.handle(pmEmpty.newPeer(63, p2p.NewPeer(enode.ID{}, "full", nil), io1)) time.Sleep(250 * time.Millisecond) pmEmpty.synchronise(pmEmpty.peers.BestPeer()) diff --git a/les/handler.go b/les/handler.go index b6d7e601ac..eb6d1886d9 100644 --- a/les/handler.go +++ b/les/handler.go @@ -22,7 +22,6 @@ import ( "errors" "fmt" "math/big" - "net" "sync" "time" @@ -38,8 +37,8 @@ import ( "github.com/tomochain/tomochain/light" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" "github.com/tomochain/tomochain/p2p/discv5" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" @@ -164,8 +163,7 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, protoco var entry *poolEntry peer := manager.newPeer(int(version), networkId, p, rw) if manager.serverPool != nil { - addr := p.RemoteAddr().(*net.TCPAddr) - entry = manager.serverPool.connect(peer, addr.IP, uint16(addr.Port)) + entry = manager.serverPool.connect(peer, peer.Node()) } peer.poolEntry = entry select { @@ -187,7 +185,7 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, protoco NodeInfo: func() interface{} { return manager.NodeInfo() }, - PeerInfo: func(id discover.NodeID) interface{} { + PeerInfo: func(id enode.ID) interface{} { if p := manager.peers.Peer(fmt.Sprintf("%x", id[:8])); p != nil { return p.Info() } @@ -388,7 +386,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } if p.requestAnnounceType == announceTypeSigned { - if err := req.checkSignature(p.pubKey); err != nil { + if err := req.checkSignature(p.ID()); err != nil { p.Log().Trace("Invalid announcement signature", "err", err) return err } diff --git a/les/peer.go b/les/peer.go index 2723003ec2..ca91562f49 100644 --- a/les/peer.go +++ b/les/peer.go @@ -18,8 +18,6 @@ package les import ( - "crypto/ecdsa" - "encoding/binary" "errors" "fmt" "math/big" @@ -36,9 +34,10 @@ import ( ) var ( - errClosed = errors.New("peer set is closed") - errAlreadyRegistered = errors.New("peer is already registered") - errNotRegistered = errors.New("peer is not registered") + errClosed = errors.New("peer set is closed") + errAlreadyRegistered = errors.New("peer is already registered") + errNotRegistered = errors.New("peer is not registered") + errInvalidHelpTrieReq = errors.New("invalid help trie request") ) const 
maxResponseErrors = 50 // number of invalid responses tolerated (makes the protocol less brittle but still avoids spam) @@ -51,7 +50,6 @@ const ( type peer struct { *p2p.Peer - pubKey *ecdsa.PublicKey rw p2p.MsgReadWriter @@ -80,11 +78,9 @@ type peer struct { func newPeer(version int, network uint64, p *p2p.Peer, rw p2p.MsgReadWriter) *peer { id := p.ID() - pubKey, _ := id.Pubkey() return &peer{ Peer: p, - pubKey: pubKey, rw: rw, version: version, network: network, @@ -284,21 +280,21 @@ func (p *peer) RequestProofs(reqID, cost uint64, reqs []ProofReq) error { } // RequestHelperTrieProofs fetches a batch of HelperTrie merkle proofs from a remote node. -func (p *peer) RequestHelperTrieProofs(reqID, cost uint64, reqs []HelperTrieReq) error { - p.Log().Debug("Fetching batch of HelperTrie proofs", "count", len(reqs)) +func (p *peer) RequestHelperTrieProofs(reqID, cost uint64, data interface{}) error { switch p.version { case lpv1: - reqsV1 := make([]ChtReq, len(reqs)) - for i, req := range reqs { - if req.Type != htCanonical || req.AuxReq != auxHeader || len(req.Key) != 8 { - return fmt.Errorf("Request invalid in LES/1 mode") - } - blockNum := binary.BigEndian.Uint64(req.Key) - // convert HelperTrie request to old CHT request - reqsV1[i] = ChtReq{ChtNum: (req.TrieIdx + 1) * (light.CHTFrequencyClient / light.CHTFrequencyServer), BlockNum: blockNum, FromLevel: req.FromLevel} + reqs, ok := data.([]ChtReq) + if !ok { + return errInvalidHelpTrieReq } - return sendRequest(p.rw, GetHeaderProofsMsg, reqID, cost, reqsV1) + p.Log().Debug("Fetching batch of header proofs", "count", len(reqs)) + return sendRequest(p.rw, GetHeaderProofsMsg, reqID, cost, reqs) case lpv2: + reqs, ok := data.([]HelperTrieReq) + if !ok { + return errInvalidHelpTrieReq + } + p.Log().Debug("Fetching batch of HelperTrie proofs", "count", len(reqs)) return sendRequest(p.rw, GetHelperTrieProofsMsg, reqID, cost, reqs) default: panic(nil) @@ -545,9 +541,11 @@ func (ps *peerSet) notify(n peerSetNotify) { func (ps *peerSet) Register(p *peer) error { ps.lock.Lock() if ps.closed { + ps.lock.Unlock() return errClosed } if _, ok := ps.peers[p.id]; ok { + ps.lock.Unlock() return errAlreadyRegistered } ps.peers[p.id] = p diff --git a/les/protocol.go b/les/protocol.go index 26e4573369..eabac6dab8 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -18,9 +18,7 @@ package les import ( - "bytes" "crypto/ecdsa" - "crypto/elliptic" "errors" "fmt" "io" @@ -30,7 +28,7 @@ import ( "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/crypto" - "github.com/tomochain/tomochain/crypto/secp256k1" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/rlp" ) @@ -148,22 +146,20 @@ func (a *announceData) sign(privKey *ecdsa.PrivateKey) { } // checkSignature verifies if the block announcement has a valid signature by the given pubKey -func (a *announceData) checkSignature(pubKey *ecdsa.PublicKey) error { +func (a *announceData) checkSignature(id enode.ID) error { var sig []byte if err := a.Update.decode().get("sign", &sig); err != nil { return err } rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td}) - recPubkey, err := secp256k1.RecoverPubkey(crypto.Keccak256(rlp), sig) + recPubkey, err := crypto.SigToPub(crypto.Keccak256(rlp), sig) if err != nil { return err } - pbytes := elliptic.Marshal(pubKey.Curve, pubKey.X, pubKey.Y) - if bytes.Equal(pbytes, recPubkey) { + if id == enode.PubkeyToIDV4(recPubkey) { return nil - } else { - return errors.New("Wrong 
signature") } + return errors.New("wrong signature") } type blockInfo struct { diff --git a/les/serverpool.go b/les/serverpool.go index 313de65e90..93a37fc279 100644 --- a/les/serverpool.go +++ b/les/serverpool.go @@ -18,6 +18,7 @@ package les import ( + "crypto/ecdsa" "fmt" "io" "math" @@ -28,11 +29,12 @@ import ( "time" "github.com/tomochain/tomochain/common/mclock" + "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" "github.com/tomochain/tomochain/p2p/discv5" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/rlp" ) @@ -73,7 +75,6 @@ const ( // and a short term value which is adjusted exponentially with a factor of // pstatRecentAdjust with each dial/connection and also returned exponentially // to the average with the time constant pstatReturnToMeanTC - pstatRecentAdjust = 0.1 pstatReturnToMeanTC = time.Hour // node address selection weight is dropped by a factor of exp(-addrFailDropLn) after // each unsuccessful connection (restored after a successful one) @@ -83,14 +84,31 @@ const ( responseScoreTC = time.Millisecond * 100 delayScoreTC = time.Second * 5 timeoutPow = 10 - // peerSelectMinWeight is added to calculated weights at request peer selection - // to give poorly performing peers a little chance of coming back - peerSelectMinWeight = 0.005 // initStatsWeight is used to initialize previously unknown peers with good // statistics to give a chance to prove themselves initStatsWeight = 1 ) +// connReq represents a request for peer connection. +type connReq struct { + p *peer + node *enode.Node + result chan *poolEntry +} + +// disconnReq represents a request for peer disconnection. +type disconnReq struct { + entry *poolEntry + stopped bool + done chan struct{} +} + +// registerReq represents a request for peer registration. +type registerReq struct { + entry *poolEntry + done chan struct{} +} + // serverPool implements a pool for storing and selecting newly discovered and already // known light server nodes. It received discovered nodes, stores statistics about // known nodes and takes care of always having enough good quality servers connected. 
@@ -105,14 +123,17 @@ type serverPool struct { topic discv5.Topic discSetPeriod chan time.Duration - discNodes chan *discv5.Node + discNodes chan *enode.Node discLookups chan bool - entries map[discover.NodeID]*poolEntry - lock sync.Mutex + entries map[enode.ID]*poolEntry timeout, enableRetry chan *poolEntry adjustStats chan poolStatAdjust + connCh chan *connReq + disconnCh chan *disconnReq + registerCh chan *registerReq + knownQueue, newQueue poolEntryQueue knownSelect, newSelect *weightedRandomSelect knownSelected, newSelected int @@ -125,10 +146,13 @@ func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup) *s db: db, quit: quit, wg: wg, - entries: make(map[discover.NodeID]*poolEntry), + entries: make(map[enode.ID]*poolEntry), timeout: make(chan *poolEntry, 1), adjustStats: make(chan poolStatAdjust, 100), enableRetry: make(chan *poolEntry, 1), + connCh: make(chan *connReq), + disconnCh: make(chan *disconnReq), + registerCh: make(chan *registerReq), knownSelect: newWeightedRandomSelect(), newSelect: newWeightedRandomSelect(), fastDiscover: true, @@ -147,13 +171,28 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) { if pool.server.DiscV5 != nil { pool.discSetPeriod = make(chan time.Duration, 1) - pool.discNodes = make(chan *discv5.Node, 100) + pool.discNodes = make(chan *enode.Node, 100) pool.discLookups = make(chan bool, 100) - go pool.server.DiscV5.SearchTopic(pool.topic, pool.discSetPeriod, pool.discNodes, pool.discLookups) + go pool.discoverNodes() } - - go pool.eventLoop() pool.checkDial() + go pool.eventLoop() +} + +// discoverNodes wraps SearchTopic, converting result nodes to enode.Node. +func (pool *serverPool) discoverNodes() { + ch := make(chan *discv5.Node) + go func() { + pool.server.DiscV5.SearchTopic(pool.topic, pool.discSetPeriod, ch, pool.discLookups) + close(ch) + }() + for n := range ch { + pubkey, err := decodePubkey64(n.ID[:]) + if err != nil { + continue + } + pool.discNodes <- enode.NewV4(pubkey, n.IP, int(n.TCP), int(n.UDP)) + } } // connect should be called upon any incoming connection. If the connection has been @@ -161,84 +200,45 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) { // Otherwise, the connection should be rejected. // Note that whenever a connection has been accepted and a pool entry has been returned, // disconnect should also always be called. 
-func (pool *serverPool) connect(p *peer, ip net.IP, port uint16) *poolEntry { - pool.lock.Lock() - defer pool.lock.Unlock() - entry := pool.entries[p.ID()] - if entry == nil { - entry = pool.findOrNewNode(p.ID(), ip, port) - } - p.Log().Debug("Connecting to new peer", "state", entry.state) - if entry.state == psConnected || entry.state == psRegistered { +func (pool *serverPool) connect(p *peer, node *enode.Node) *poolEntry { + log.Debug("Connect new entry", "enode", p.id) + req := &connReq{p: p, node: node, result: make(chan *poolEntry, 1)} + select { + case pool.connCh <- req: + case <-pool.quit: return nil } - pool.connWg.Add(1) - entry.peer = p - entry.state = psConnected - addr := &poolEntryAddress{ - ip: ip, - port: port, - lastSeen: mclock.Now(), - } - entry.lastConnected = addr - entry.addr = make(map[string]*poolEntryAddress) - entry.addr[addr.strKey()] = addr - entry.addrSelect = *newWeightedRandomSelect() - entry.addrSelect.update(addr) - return entry + return <-req.result } // registered should be called after a successful handshake func (pool *serverPool) registered(entry *poolEntry) { - log.Debug("Registered new entry", "enode", entry.id) - pool.lock.Lock() - defer pool.lock.Unlock() - - entry.state = psRegistered - entry.regTime = mclock.Now() - if !entry.known { - pool.newQueue.remove(entry) - entry.known = true + log.Debug("Registered new entry", "enode", entry.node.ID()) + req := ®isterReq{entry: entry, done: make(chan struct{})} + select { + case pool.registerCh <- req: + case <-pool.quit: + return } - pool.knownQueue.setLatest(entry) - entry.shortRetry = shortRetryCnt + <-req.done } // disconnect should be called when ending a connection. Service quality statistics // can be updated optionally (not updated if no registration happened, in this case // only connection statistics are updated, just like in case of timeout) func (pool *serverPool) disconnect(entry *poolEntry) { - log.Debug("Disconnected old entry", "enode", entry.id) - pool.lock.Lock() - defer pool.lock.Unlock() - - if entry.state == psRegistered { - connTime := mclock.Now() - entry.regTime - connAdjust := float64(connTime) / float64(targetConnTime) - if connAdjust > 1 { - connAdjust = 1 - } - stopped := false - select { - case <-pool.quit: - stopped = true - default: - } - if stopped { - entry.connectStats.add(1, connAdjust) - } else { - entry.connectStats.add(connAdjust, 1) - } + stopped := false + select { + case <-pool.quit: + stopped = true + default: } + log.Debug("Disconnected old entry", "enode", entry.node.ID()) + req := &disconnReq{entry: entry, stopped: stopped, done: make(chan struct{})} - entry.state = psNotConnected - if entry.knownSelected { - pool.knownSelected-- - } else { - pool.newSelected-- - } - pool.setRetryDial(entry) - pool.connWg.Done() + // Block until disconnection request is served. + pool.disconnCh <- req + <-req.done } const ( @@ -281,25 +281,51 @@ func (pool *serverPool) eventLoop() { if pool.discSetPeriod != nil { pool.discSetPeriod <- time.Millisecond * 100 } + + // disconnect updates service quality statistics depending on the connection time + // and disconnection initiator. + disconnect := func(req *disconnReq, stopped bool) { + // Handle peer disconnection requests. + entry := req.entry + if entry.state == psRegistered { + connAdjust := float64(mclock.Now()-entry.regTime) / float64(targetConnTime) + if connAdjust > 1 { + connAdjust = 1 + } + if stopped { + // disconnect requested by ourselves. 
+ entry.connectStats.add(1, connAdjust) + } else { + // disconnect requested by server side. + entry.connectStats.add(connAdjust, 1) + } + } + entry.state = psNotConnected + + if entry.knownSelected { + pool.knownSelected-- + } else { + pool.newSelected-- + } + pool.setRetryDial(entry) + pool.connWg.Done() + close(req.done) + } + for { select { case entry := <-pool.timeout: - pool.lock.Lock() if !entry.removed { pool.checkDialTimeout(entry) } - pool.lock.Unlock() case entry := <-pool.enableRetry: - pool.lock.Lock() if !entry.removed { entry.delayedRetry = false pool.updateCheckDial(entry) } - pool.lock.Unlock() case adj := <-pool.adjustStats: - pool.lock.Lock() switch adj.adjustType { case pseBlockDelay: adj.entry.delayStats.add(float64(adj.time), 1) @@ -309,13 +335,10 @@ func (pool *serverPool) eventLoop() { case pseResponseTimeout: adj.entry.timeoutStats.add(1, 1) } - pool.lock.Unlock() case node := <-pool.discNodes: - pool.lock.Lock() - entry := pool.findOrNewNode(discover.NodeID(node.ID), node.IP, node.TCP) + entry := pool.findOrNewNode(node) pool.updateCheckDial(entry) - pool.lock.Unlock() case conv := <-pool.discLookups: if conv { @@ -331,31 +354,82 @@ func (pool *serverPool) eventLoop() { } } + case req := <-pool.connCh: + // Handle peer connection requests. + entry := pool.entries[req.p.ID()] + if entry == nil { + entry = pool.findOrNewNode(req.node) + } + if entry.state == psConnected || entry.state == psRegistered { + req.result <- nil + continue + } + pool.connWg.Add(1) + entry.peer = req.p + entry.state = psConnected + addr := &poolEntryAddress{ + ip: req.node.IP(), + port: uint16(req.node.TCP()), + lastSeen: mclock.Now(), + } + entry.lastConnected = addr + entry.addr = make(map[string]*poolEntryAddress) + entry.addr[addr.strKey()] = addr + entry.addrSelect = *newWeightedRandomSelect() + entry.addrSelect.update(addr) + req.result <- entry + + case req := <-pool.registerCh: + // Handle peer registration requests. + entry := req.entry + entry.state = psRegistered + entry.regTime = mclock.Now() + if !entry.known { + pool.newQueue.remove(entry) + entry.known = true + } + pool.knownQueue.setLatest(entry) + entry.shortRetry = shortRetryCnt + close(req.done) + + case req := <-pool.disconnCh: + // Handle peer disconnection requests. + disconnect(req, req.stopped) + case <-pool.quit: if pool.discSetPeriod != nil { close(pool.discSetPeriod) } - pool.connWg.Wait() + + // Spawn a goroutine to close the disconnCh after all connections are disconnected. + go func() { + pool.connWg.Wait() + close(pool.disconnCh) + }() + + // Handle all remaining disconnection requests before exit. 
+ for req := range pool.disconnCh { + disconnect(req, true) + } pool.saveNodes() pool.wg.Done() return - } } } -func (pool *serverPool) findOrNewNode(id discover.NodeID, ip net.IP, port uint16) *poolEntry { +func (pool *serverPool) findOrNewNode(node *enode.Node) *poolEntry { now := mclock.Now() - entry := pool.entries[id] + entry := pool.entries[node.ID()] if entry == nil { - log.Debug("Discovered new entry", "id", id) + log.Debug("Discovered new entry", "id", node.ID()) entry = &poolEntry{ - id: id, + node: node, addr: make(map[string]*poolEntryAddress), addrSelect: *newWeightedRandomSelect(), shortRetry: shortRetryCnt, } - pool.entries[id] = entry + pool.entries[node.ID()] = entry // initialize previously unknown peers with good statistics to give a chance to prove themselves entry.connectStats.add(1, initStatsWeight) entry.delayStats.add(0, initStatsWeight) @@ -363,10 +437,7 @@ func (pool *serverPool) findOrNewNode(id discover.NodeID, ip net.IP, port uint16 entry.timeoutStats.add(0, initStatsWeight) } entry.lastDiscovered = now - addr := &poolEntryAddress{ - ip: ip, - port: port, - } + addr := &poolEntryAddress{ip: node.IP(), port: uint16(node.TCP())} if a, ok := entry.addr[addr.strKey()]; ok { addr = a } else { @@ -393,12 +464,12 @@ func (pool *serverPool) loadNodes() { return } for _, e := range list { - log.Debug("Loaded server stats", "id", e.id, "fails", e.lastConnected.fails, + log.Debug("Loaded server stats", "id", e.node.ID(), "fails", e.lastConnected.fails, "conn", fmt.Sprintf("%v/%v", e.connectStats.avg, e.connectStats.weight), "delay", fmt.Sprintf("%v/%v", time.Duration(e.delayStats.avg), e.delayStats.weight), "response", fmt.Sprintf("%v/%v", time.Duration(e.responseStats.avg), e.responseStats.weight), "timeout", fmt.Sprintf("%v/%v", e.timeoutStats.avg, e.timeoutStats.weight)) - pool.entries[e.id] = e + pool.entries[e.node.ID()] = e pool.knownQueue.setLatest(e) pool.knownSelect.update((*knownEntry)(e)) } @@ -424,7 +495,7 @@ func (pool *serverPool) removeEntry(entry *poolEntry) { pool.newSelect.remove((*discoveredEntry)(entry)) pool.knownSelect.remove((*knownEntry)(entry)) entry.removed = true - delete(pool.entries, entry.id) + delete(pool.entries, entry.node.ID()) } // setRetryDial starts the timer which will enable dialing a certain node again @@ -502,10 +573,10 @@ func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) { pool.newSelected++ } addr := entry.addrSelect.choose().(*poolEntryAddress) - log.Debug("Dialing new peer", "lesaddr", entry.id.String()+"@"+addr.strKey(), "set", len(entry.addr), "known", knownSelected) + log.Debug("Dialing new peer", "lesaddr", entry.node.ID().String()+"@"+addr.strKey(), "set", len(entry.addr), "known", knownSelected) entry.dialed = addr go func() { - pool.server.AddPeer(discover.NewNode(entry.id, addr.ip, addr.port, addr.port)) + pool.server.AddPeer(entry.node) select { case <-pool.quit: case <-time.After(dialTimeout): @@ -523,7 +594,7 @@ func (pool *serverPool) checkDialTimeout(entry *poolEntry) { if entry.state != psDialed { return } - log.Debug("Dial timeout", "lesaddr", entry.id.String()+"@"+entry.dialed.strKey()) + log.Debug("Dial timeout", "lesaddr", entry.node.ID().String()+"@"+entry.dialed.strKey()) entry.state = psNotConnected if entry.knownSelected { pool.knownSelected-- @@ -545,8 +616,9 @@ const ( // poolEntry represents a server node and stores its current state and statistics. 
type poolEntry struct {
 	peer *peer
-	id   discover.NodeID
+	pubkey [64]byte // secp256k1 key of the node
 	addr map[string]*poolEntryAddress
+	node *enode.Node
 	lastConnected, dialed *poolEntryAddress
 	addrSelect weightedRandomSelect
@@ -563,23 +635,39 @@ type poolEntry struct {
 	shortRetry int
 }
 
+// poolEntryEnc is the RLP encoding of poolEntry.
+type poolEntryEnc struct {
+	Pubkey                     []byte
+	IP                         net.IP
+	Port                       uint16
+	Fails                      uint
+	CStat, DStat, RStat, TStat poolStats
+}
+
 func (e *poolEntry) EncodeRLP(w io.Writer) error {
-	return rlp.Encode(w, []interface{}{e.id, e.lastConnected.ip, e.lastConnected.port, e.lastConnected.fails, &e.connectStats, &e.delayStats, &e.responseStats, &e.timeoutStats})
+	return rlp.Encode(w, &poolEntryEnc{
+		Pubkey: encodePubkey64(e.node.Pubkey()),
+		IP:     e.lastConnected.ip,
+		Port:   e.lastConnected.port,
+		Fails:  e.lastConnected.fails,
+		CStat:  e.connectStats,
+		DStat:  e.delayStats,
+		RStat:  e.responseStats,
+		TStat:  e.timeoutStats,
+	})
 }
 
 func (e *poolEntry) DecodeRLP(s *rlp.Stream) error {
-	var entry struct {
-		ID                         discover.NodeID
-		IP                         net.IP
-		Port                       uint16
-		Fails                      uint
-		CStat, DStat, RStat, TStat poolStats
-	}
+	var entry poolEntryEnc
 	if err := s.Decode(&entry); err != nil {
 		return err
 	}
+	pubkey, err := decodePubkey64(entry.Pubkey)
+	if err != nil {
+		return err
+	}
 	addr := &poolEntryAddress{ip: entry.IP, port: entry.Port, fails: entry.Fails, lastSeen: mclock.Now()}
-	e.id = entry.ID
+	e.node = enode.NewV4(pubkey, entry.IP, int(entry.Port), int(entry.Port))
 	e.addr = make(map[string]*poolEntryAddress)
 	e.addr[addr.strKey()] = addr
 	e.addrSelect = *newWeightedRandomSelect()
@@ -594,6 +682,14 @@ func (e *poolEntry) DecodeRLP(s *rlp.Stream) error {
 	return nil
 }
 
+func encodePubkey64(pub *ecdsa.PublicKey) []byte {
+	return crypto.FromECDSAPub(pub)[1:]
+}
+
+func decodePubkey64(b []byte) (*ecdsa.PublicKey, error) {
+	return crypto.UnmarshalPubkey(append([]byte{0x04}, b...))
+}
+
 // discoveredEntry implements wrsItem
 type discoveredEntry poolEntry
 
@@ -605,9 +701,8 @@ func (e *discoveredEntry) Weight() int64 {
 	t := time.Duration(mclock.Now() - e.lastDiscovered)
 	if t <= discoverExpireStart {
 		return 1000000000
-	} else {
-		return int64(1000000000 * math.Exp(-float64(t-discoverExpireStart)/float64(discoverExpireConst)))
 	}
+	return int64(1000000000 * math.Exp(-float64(t-discoverExpireStart)/float64(discoverExpireConst)))
 }
 
 // knownEntry implements wrsItem
diff --git a/node/api.go b/node/api.go
index 23edbe2b3f..a6b92c6b80 100644
--- a/node/api.go
+++ b/node/api.go
@@ -26,7 +26,7 @@ import (
 	"github.com/tomochain/tomochain/crypto"
 	"github.com/tomochain/tomochain/metrics"
 	"github.com/tomochain/tomochain/p2p"
-	"github.com/tomochain/tomochain/p2p/discover"
+	"github.com/tomochain/tomochain/p2p/enode"
 	"github.com/tomochain/tomochain/rpc"
 )
 
@@ -51,7 +51,7 @@ func (api *PrivateAdminAPI) AddPeer(url string) (bool, error) {
 		return false, ErrNodeStopped
 	}
 	// Try to add the url as a static peer and return
-	node, err := discover.ParseNode(url)
+	node, err := enode.ParseV4(url)
 	if err != nil {
 		return false, fmt.Errorf("invalid enode: %v", err)
 	}
@@ -67,7 +67,7 @@ func (api *PrivateAdminAPI) RemovePeer(url string) (bool, error) {
 		return false, ErrNodeStopped
 	}
 	// Try to remove the url as a static peer and return
-	node, err := discover.ParseNode(url)
+	node, err := enode.ParseV4(url)
 	if err != nil {
 		return false, fmt.Errorf("invalid enode: %v", err)
 	}
diff --git a/node/config.go b/node/config.go
index b8ad712fc8..1eb4e528d7 100644
--- a/node/config.go
+++ b/node/config.go
@@ -32,7 +32,7 @@
import ( "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" ) const ( @@ -333,18 +333,18 @@ func (c *Config) NodeKey() *ecdsa.PrivateKey { } // StaticNodes returns a list of node enode URLs configured as static nodes. -func (c *Config) StaticNodes() []*discover.Node { +func (c *Config) StaticNodes() []*enode.Node { return c.parsePersistentNodes(c.resolvePath(datadirStaticNodes)) } // TrustedNodes returns a list of node enode URLs configured as trusted nodes. -func (c *Config) TrustedNodes() []*discover.Node { +func (c *Config) TrustedNodes() []*enode.Node { return c.parsePersistentNodes(c.resolvePath(datadirTrustedNodes)) } // parsePersistentNodes parses a list of discovery node URLs loaded from a .json // file from within the data directory. -func (c *Config) parsePersistentNodes(path string) []*discover.Node { +func (c *Config) parsePersistentNodes(path string) []*enode.Node { // Short circuit if no node config is present if c.DataDir == "" { return nil @@ -359,12 +359,12 @@ func (c *Config) parsePersistentNodes(path string) []*discover.Node { return nil } // Interpret the list as a discovery node array - var nodes []*discover.Node + var nodes []*enode.Node for _, url := range nodelist { if url == "" { continue } - node, err := discover.ParseNode(url) + node, err := enode.ParseV4(url) if err != nil { log.Error(fmt.Sprintf("Node URL %s: %v\n", url, err)) continue diff --git a/swarm/network/hive.go b/swarm/network/hive.go index 413074c474..0b8824c906 100644 --- a/swarm/network/hive.go +++ b/swarm/network/hive.go @@ -25,7 +25,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/metrics" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/netutil" "github.com/tomochain/tomochain/swarm/network/kademlia" "github.com/tomochain/tomochain/swarm/storage" @@ -49,7 +49,7 @@ var ( type Hive struct { listenAddr func() string callInterval uint64 - id discover.NodeID + id enode.ID addr kademlia.Address kad *kademlia.Kademlia path string @@ -77,7 +77,7 @@ type HiveParams struct { *kademlia.KadParams } -//create default params +// create default params func NewDefaultHiveParams() *HiveParams { kad := kademlia.NewDefaultKadParams() // kad.BucketSize = bucketSize @@ -90,8 +90,8 @@ func NewDefaultHiveParams() *HiveParams { } } -//this can only finally be set after all config options (file, cmd line, env vars) -//have been evaluated +// this can only finally be set after all config options (file, cmd line, env vars) +// have been evaluated func (self *HiveParams) Init(path string) { self.KadDbPath = filepath.Join(path, "bzz-peers.json") } @@ -133,7 +133,7 @@ func (self *Hive) Addr() kademlia.Address { // listedAddr is a function to retrieve listening address to advertise to peers // connectPeer is a function to connect to a peer based on its NodeID or enode URL // there are called on the p2p.Server which runs on the node -func (self *Hive) Start(id discover.NodeID, listenAddr func() string, connectPeer func(string) error) (err error) { +func (self *Hive) Start(id enode.ID, listenAddr func() string, connectPeer func(string) error) (err error) { self.toggle = make(chan bool) self.more = make(chan bool) self.quit = make(chan bool) diff --git a/swarm/network/messages.go b/swarm/network/messages.go index 
18ab633535..b434c2ff1b 100644 --- a/swarm/network/messages.go +++ b/swarm/network/messages.go @@ -21,8 +21,9 @@ import ( "net" "time" + "github.com/tomochain/tomochain/p2p/enode" + "github.com/tomochain/tomochain/contracts/chequebook" - "github.com/tomochain/tomochain/p2p/discover" "github.com/tomochain/tomochain/swarm/network/kademlia" "github.com/tomochain/tomochain/swarm/services/swap" "github.com/tomochain/tomochain/swarm/storage" @@ -45,7 +46,7 @@ const ( ) /* - Handshake + Handshake * Version: 8 byte integer version of the protocol * ID: arbitrary byte sequence client identifier human readable @@ -54,7 +55,6 @@ const ( * NetworkID: 8 byte integer network identifier * Caps: swarm-specific capabilities, format identical to devp2p * SyncState: syncronisation state (db iterator key and address space etc) persisted about the peer - */ type statusMsgData struct { Version uint64 @@ -69,12 +69,12 @@ func (self *statusMsgData) String() string { } /* - store requests are forwarded to the peers in their kademlia proximity bin - if they are distant - if they are within our storage radius or have any incentive to store it - then attach your nodeID to the metadata - if the storage request is sufficiently close (within our proxLimit, i. e., the - last row of the routing table) +store requests are forwarded to the peers in their kademlia proximity bin +if they are distant +if they are within our storage radius or have any incentive to store it +then attach your nodeID to the metadata +if the storage request is sufficiently close (within our proxLimit, i. e., the +last row of the routing table) */ type storeRequestMsgData struct { Key storage.Key // hash of datasize | data @@ -181,9 +181,9 @@ type peerAddr struct { // peerAddr pretty prints as enode func (self *peerAddr) String() string { - var nodeid discover.NodeID + var nodeid enode.ID copy(nodeid[:], self.ID) - return discover.NewNode(nodeid, self.IP, 0, self.Port).String() + return nodeid.GoString() } /* diff --git a/swarm/swarm.go b/swarm/swarm.go index 34a790eca6..e970fc55c6 100644 --- a/swarm/swarm.go +++ b/swarm/swarm.go @@ -37,7 +37,7 @@ import ( "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rpc" "github.com/tomochain/tomochain/swarm/api" @@ -275,7 +275,7 @@ Start is called when the stack is started func (self *Swarm) Start(srv *p2p.Server) error { startTime = time.Now() connectPeer := func(url string) error { - node, err := discover.ParseNode(url) + node, err := enode.ParseV4(url) if err != nil { return fmt.Errorf("invalid node URL: %v", err) } @@ -296,7 +296,7 @@ func (self *Swarm) Start(srv *p2p.Server) error { log.Warn(fmt.Sprintf("Starting Swarm service")) self.hive.Start( - discover.PubkeyID(&srv.PrivateKey.PublicKey), + enode.PubkeyToIDV4(&srv.PrivateKey.PublicKey), func() string { return srv.ListenAddr }, connectPeer, ) diff --git a/whisper/whisperv5/api.go b/whisper/whisperv5/api.go index 0ab821b4f4..6c6cb16449 100644 --- a/whisper/whisperv5/api.go +++ b/whisper/whisperv5/api.go @@ -28,7 +28,7 @@ import ( "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/rpc" ) @@ -93,19 +93,19 @@ func (api 
*PublicWhisperAPI) SetMaxMessageSize(ctx context.Context, size uint32) return true, api.w.SetMaxMessageSize(size) } -// SetMinPow sets the minimum PoW for a message before it is accepted. +// SetMinPoW sets the minimum PoW for a message before it is accepted. func (api *PublicWhisperAPI) SetMinPoW(ctx context.Context, pow float64) (bool, error) { return true, api.w.SetMinimumPoW(pow) } // MarkTrustedPeer marks a peer trusted. , which will allow it to send historic (expired) messages. // Note: This function is not adding new nodes, the node needs to exists as a peer. -func (api *PublicWhisperAPI) MarkTrustedPeer(ctx context.Context, enode string) (bool, error) { - n, err := discover.ParseNode(enode) +func (api *PublicWhisperAPI) MarkTrustedPeer(ctx context.Context, url string) (bool, error) { + n, err := enode.ParseV4(url) if err != nil { return false, err } - return true, api.w.AllowP2PMessagesFromPeer(n.ID[:]) + return true, api.w.AllowP2PMessagesFromPeer(n.ID().Bytes()) } // NewKeyPair generates a new public and private key pair for message decryption and encryption. @@ -275,11 +275,11 @@ func (api *PublicWhisperAPI) Post(ctx context.Context, req NewMessage) (bool, er // send to specific node (skip PoW check) if len(req.TargetPeer) > 0 { - n, err := discover.ParseNode(req.TargetPeer) + n, err := enode.ParseV4(req.TargetPeer) if err != nil { return false, fmt.Errorf("failed to parse target peer: %s", err) } - return true, api.w.SendP2PMessage(n.ID[:], env) + return true, api.w.SendP2PMessage(n.ID().Bytes(), env) } // ensure that the message PoW meets the node's minimum accepted PoW diff --git a/whisper/whisperv6/api.go b/whisper/whisperv6/api.go index 32831d1ec1..dca7c8b3f7 100644 --- a/whisper/whisperv6/api.go +++ b/whisper/whisperv6/api.go @@ -28,7 +28,7 @@ import ( "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/rpc" ) @@ -106,12 +106,12 @@ func (api *PublicWhisperAPI) SetBloomFilter(ctx context.Context, bloom hexutil.B // MarkTrustedPeer marks a peer trusted, which will allow it to send historic (expired) messages. // Note: This function is not adding new nodes, the node needs to exists as a peer. -func (api *PublicWhisperAPI) MarkTrustedPeer(ctx context.Context, enode string) (bool, error) { - n, err := discover.ParseNode(enode) +func (api *PublicWhisperAPI) MarkTrustedPeer(ctx context.Context, url string) (bool, error) { + n, err := enode.ParseV4(url) if err != nil { return false, err } - return true, api.w.AllowP2PMessagesFromPeer(n.ID[:]) + return true, api.w.AllowP2PMessagesFromPeer(n.ID().Bytes()) } // NewKeyPair generates a new public and private key pair for message decryption and encryption. 
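The whisper and node API hunks in this patch all apply the same mechanical translation: discover.ParseNode(url) becomes enode.ParseV4(url), and the old direct array access n.ID[:] becomes the accessor chain n.ID().Bytes(). A small standalone sketch of the new parsing API; the enode URL below is only an example value:

    package main

    import (
        "fmt"

        "github.com/tomochain/tomochain/p2p/enode"
    )

    func main() {
        // Example enode URL: 64-byte hex node key, then IP and TCP port.
        url := "enode://ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31387574077f301b421bc84df7266c44e9e6d569fc56be00812904767bf5ccd1fc7f@127.0.0.1:30303"
        n, err := enode.ParseV4(url)
        if err != nil {
            fmt.Println("invalid enode:", err)
            return
        }
        // ID() replaces the old public NodeID field; Bytes() yields the
        // identifier bytes that code previously read as n.ID[:]. Note the
        // new ID is a 32-byte hash of the key, not the 64-byte key itself.
        fmt.Printf("id=%x ip=%v tcp=%d\n", n.ID().Bytes(), n.IP(), n.TCP())
    }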
@@ -294,11 +294,11 @@ func (api *PublicWhisperAPI) Post(ctx context.Context, req NewMessage) (bool, er
 
 	// send to specific node (skip PoW check)
 	if len(req.TargetPeer) > 0 {
-		n, err := discover.ParseNode(req.TargetPeer)
+		n, err := enode.ParseV4(req.TargetPeer)
 		if err != nil {
 			return false, fmt.Errorf("failed to parse target peer: %s", err)
 		}
-		return true, api.w.SendP2PMessage(n.ID[:], env)
+		return true, api.w.SendP2PMessage(n.ID().Bytes(), env)
 	}
 
 	// ensure that the message PoW meets the node's minimum accepted PoW

From d2a975d81360d1ff6a7cecae0edbfc0f2d49166a Mon Sep 17 00:00:00 2001
From: Dang Nhat Trinh
Date: Mon, 30 Oct 2023 18:10:01 +0700
Subject: [PATCH 100/119] [WIP] Fix p2p/enode porting unit tests

---
 les/helper_test.go                 |  10 +-
 p2p/dial_test.go                   | 471 ++++++++++++++---------------
 p2p/simulations/adapters/inproc.go |   6 +-
 3 files changed, 242 insertions(+), 245 deletions(-)

diff --git a/les/helper_test.go b/les/helper_test.go
index 67a932b4ec..4841efb6c6 100644
--- a/les/helper_test.go
+++ b/les/helper_test.go
@@ -37,7 +37,7 @@ import (
 	"github.com/tomochain/tomochain/les/flowcontrol"
 	"github.com/tomochain/tomochain/light"
 	"github.com/tomochain/tomochain/p2p"
-	"github.com/tomochain/tomochain/p2p/discover"
+	"github.com/tomochain/tomochain/p2p/enode"
 	"github.com/tomochain/tomochain/params"
 )
 
@@ -223,8 +223,8 @@ func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, sh
 	app, net := p2p.MsgPipe()
 
 	// Generate a random id and create the peer
-	var id discover.NodeID
+	var id enode.ID
 	rand.Read(id[:])
 
 	peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net)
 
@@ -260,8 +260,8 @@ func newTestPeerPair(name string, version int, pm, pm2 *ProtocolManager) (*peer,
 	app, net := p2p.MsgPipe()
 
 	// Generate a random id and create the peer
-	var id discover.NodeID
+	var id enode.ID
 	rand.Read(id[:])
 
 	peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net)
 	peer2 := pm2.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), app)
diff --git a/p2p/dial_test.go b/p2p/dial_test.go
index 0b88b4cf80..411f49a7c3 100644
--- a/p2p/dial_test.go
+++ b/p2p/dial_test.go
@@ -24,8 +24,9 @@ import (
 	"time"
 
 	"github.com/davecgh/go-spew/spew"
-	"github.com/tomochain/tomochain/p2p/discover"
+
+	"github.com/tomochain/tomochain/p2p/enode"
+	"github.com/tomochain/tomochain/p2p/enr"
 	"github.com/tomochain/tomochain/p2p/netutil"
 )
 
@@ -52,7 +53,7 @@ func runDialTest(t *testing.T, test dialtest) {
 	pm := func(ps []*Peer) map[enode.ID]*Peer {
 		m := make(map[enode.ID]*Peer)
 		for _, p := range ps {
-			m[p.rw.id] = p
+			m[p.ID()] = p
 		}
 		return m
 	}
@@ -70,6 +71,7 @@ func runDialTest(t *testing.T, test dialtest) {
 			t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n",
 				i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
 		}
+		t.Log("tasks:", spew.Sdump(new))
 
 		// Time advances by 16 seconds on every round.
vtime = vtime.Add(16 * time.Second) @@ -77,13 +79,13 @@ func runDialTest(t *testing.T, test dialtest) { } } -type fakeTable []*discover.Node +type fakeTable []*enode.Node -func (t fakeTable) Self() *discover.Node { return new(discover.Node) } -func (t fakeTable) Close() {} -func (t fakeTable) Lookup(enode.ID) []*discover.Node { return nil } -func (t fakeTable) Resolve(enode.ID) *discover.Node { return nil } -func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf, t) } +func (t fakeTable) Self() *enode.Node { return new(enode.Node) } +func (t fakeTable) Close() {} +func (t fakeTable) LookupRandom() []*enode.Node { return nil } +func (t fakeTable) Resolve(*enode.Node) *enode.Node { return nil } +func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t) } // This test checks that dynamic dials are launched from discovery results. func TestDialStateDynDial(t *testing.T) { @@ -93,63 +95,63 @@ func TestDialStateDynDial(t *testing.T) { // A discovery query is launched. { peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, }, new: []task{&discoverTask{}}, }, // Dynamic dials are launched when it completes. { peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, }, done: []task{ - &discoverTask{results: []*discover.Node{ - {ID: uintID(2)}, // this one is already connected and not dialed. - {ID: uintID(3)}, - {ID: uintID(4)}, - {ID: uintID(5)}, - {ID: uintID(6)}, // these are not tried because max dyn dials is 5 - {ID: uintID(7)}, // ... + &discoverTask{results: []*enode.Node{ + newNode(uintID(2), nil), // this one is already connected and not dialed. + newNode(uintID(3), nil), + newNode(uintID(4), nil), + newNode(uintID(5), nil), + newNode(uintID(6), nil), // these are not tried because max dyn dials is 5 + newNode(uintID(7), nil), // ... }}, }, new: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(2)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(3)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, // Some of the dials complete but no new ones are launched yet because // the sum of active dial count and dynamic peer count is == maxDynDials. 
{ peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(3)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(4)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, }, done: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(3)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, }, }, // No new dial tasks are launched in the this round because // maxDynDials has been reached. { peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(3)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(4)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, }, done: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ &waitExpireTask{Duration: 14 * time.Second}, @@ -159,29 +161,31 @@ func TestDialStateDynDial(t *testing.T) { // results from last discovery lookup are reused. { peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(3)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(4)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, + }, + new: []task{ + &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)}, }, - new: []task{}, }, // More peers (3,4) drop off and dial for ID 6 completes. // The last query result from the discovery lookup is reused // and a new one is spawned because more candidates are needed. 
{ peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, }, done: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(6)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)}, }, new: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(7)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)}, + &discoverTask{}, }, }, // Peer 7 is connected, but there still aren't enough dynamic peers @@ -189,29 +193,29 @@ func TestDialStateDynDial(t *testing.T) { // no new is started. { peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(7)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(7), nil)}}, }, done: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(7)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)}, }, }, // Finish the running node discovery with an empty set. A new lookup // should be immediately requested. { peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(7)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(7), nil)}}, }, done: []task{ &discoverTask{}, }, new: []task{ - &waitExpireTask{Duration: 14 * time.Second}, + &discoverTask{}, }, }, }, @@ -220,17 +224,17 @@ func TestDialStateDynDial(t *testing.T) { // Tests that bootnodes are dialed if no peers are connectd, but not otherwise. 
func TestDialStateDynDialBootnode(t *testing.T) { - bootnodes := []*discover.Node{ - {ID: uintID(1)}, - {ID: uintID(2)}, - {ID: uintID(3)}, + bootnodes := []*enode.Node{ + newNode(uintID(1), nil), + newNode(uintID(2), nil), + newNode(uintID(3), nil), } table := fakeTable{ - {ID: uintID(4)}, - {ID: uintID(5)}, - {ID: uintID(6)}, - {ID: uintID(7)}, - {ID: uintID(8)}, + newNode(uintID(4), nil), + newNode(uintID(5), nil), + newNode(uintID(6), nil), + newNode(uintID(7), nil), + newNode(uintID(8), nil), } runDialTest(t, dialtest{ init: newDialState(nil, bootnodes, table, 5, nil), @@ -238,16 +242,16 @@ func TestDialStateDynDialBootnode(t *testing.T) { // 2 dynamic dials attempted, bootnodes pending fallback interval { new: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, &discoverTask{}, }, }, // No dials succeed, bootnodes still pending fallback interval { done: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, // No dials succeed, bootnodes still pending fallback interval @@ -255,54 +259,51 @@ func TestDialStateDynDialBootnode(t *testing.T) { // No dials succeed, 2 dynamic dials attempted and 1 bootnode too as fallback interval was reached { new: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, // No dials succeed, 2nd bootnode is attempted { done: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(2)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, }, }, // No dials succeed, 3rd bootnode is attempted { done: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(2)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, }, new: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(3)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, }, }, // No dials succeed, 1st bootnode is attempted again, expired random nodes retried { done: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(3)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, }, new: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}}, + &dialTask{flags: 
dynDialedConn, dest: newNode(uintID(1), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, // Random dial succeeds, no more bootnodes are attempted { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(4)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, }, done: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, }, @@ -313,14 +314,14 @@ func TestDialStateDynDialFromTable(t *testing.T) { // This table always returns the same random nodes // in the order given below. table := fakeTable{ - {ID: uintID(1)}, - {ID: uintID(2)}, - {ID: uintID(3)}, - {ID: uintID(4)}, - {ID: uintID(5)}, - {ID: uintID(6)}, - {ID: uintID(7)}, - {ID: uintID(8)}, + newNode(uintID(1), nil), + newNode(uintID(2), nil), + newNode(uintID(3), nil), + newNode(uintID(4), nil), + newNode(uintID(5), nil), + newNode(uintID(6), nil), + newNode(uintID(7), nil), + newNode(uintID(8), nil), } runDialTest(t, dialtest{ @@ -329,67 +330,63 @@ func TestDialStateDynDialFromTable(t *testing.T) { // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. { new: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(2)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(3)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, &discoverTask{}, }, }, // Dialing nodes 1,2 succeeds. Dials from the lookup are launched. 
{ peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, }, done: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(2)}}, - &discoverTask{results: []*discover.Node{ - {ID: uintID(10)}, - {ID: uintID(11)}, - {ID: uintID(12)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, + &discoverTask{results: []*enode.Node{ + newNode(uintID(10), nil), + newNode(uintID(11), nil), + newNode(uintID(12), nil), }}, }, new: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(2)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(10)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(11)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(12)}}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)}, + &discoverTask{}, }, }, // Dialing nodes 3,4,5 fails. The dials from the lookup succeed. { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(10)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(11)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(12)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}}, }, done: []task{ - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(3)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(4)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(5)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(10)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(11)}}, - &dialTask{flags: dynDialedConn, dest: &discover.Node{ID: uintID(12)}}, - }, - new: []task{ - &discoverTask{}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)}, }, }, // Waiting for expiry. No waitExpireTask is launched because the // discovery query is still running. 
{ peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(10)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(11)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(12)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}}, }, }, // Nodes 3,4 are not tried again because only the first two @@ -397,30 +394,38 @@ func TestDialStateDynDialFromTable(t *testing.T) { // already connected. { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(10)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(11)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(12)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}}, }, }, }, }) } +func newNode(id enode.ID, ip net.IP) *enode.Node { + var r enr.Record + if ip != nil { + r.Set(enr.IP(ip)) + } + return enode.SignNull(&r, id) +} + // This test checks that candidates that do not match the netrestrict list are not dialed. func TestDialStateNetRestrict(t *testing.T) { // This table always returns the same random nodes // in the order given below. table := fakeTable{ - {ID: uintID(1), IP: net.ParseIP("127.0.0.1")}, - {ID: uintID(2), IP: net.ParseIP("127.0.0.2")}, - {ID: uintID(3), IP: net.ParseIP("127.0.0.3")}, - {ID: uintID(4), IP: net.ParseIP("127.0.0.4")}, - {ID: uintID(5), IP: net.ParseIP("127.0.2.5")}, - {ID: uintID(6), IP: net.ParseIP("127.0.2.6")}, - {ID: uintID(7), IP: net.ParseIP("127.0.2.7")}, - {ID: uintID(8), IP: net.ParseIP("127.0.2.8")}, + newNode(uintID(1), net.ParseIP("127.0.0.1")), + newNode(uintID(2), net.ParseIP("127.0.0.2")), + newNode(uintID(3), net.ParseIP("127.0.0.3")), + newNode(uintID(4), net.ParseIP("127.0.0.4")), + newNode(uintID(5), net.ParseIP("127.0.2.5")), + newNode(uintID(6), net.ParseIP("127.0.2.6")), + newNode(uintID(7), net.ParseIP("127.0.2.7")), + newNode(uintID(8), net.ParseIP("127.0.2.8")), } restrict := new(netutil.Netlist) restrict.Add("127.0.2.0/24") @@ -440,12 +445,12 @@ func TestDialStateNetRestrict(t *testing.T) { // This test checks that static dials are launched. func TestDialStateStaticDial(t *testing.T) { - wantStatic := []*discover.Node{ - {ID: uintID(1)}, - {ID: uintID(2)}, - {ID: uintID(3)}, - {ID: uintID(4)}, - {ID: uintID(5)}, + wantStatic := []*enode.Node{ + newNode(uintID(1), nil), + newNode(uintID(2), nil), + newNode(uintID(3), nil), + newNode(uintID(4), nil), + newNode(uintID(5), nil), } runDialTest(t, dialtest{ @@ -455,70 +460,67 @@ func TestDialStateStaticDial(t *testing.T) { // aren't yet connected. 
{ peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, }, new: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(3)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(4)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(5)}}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)}, }, }, // No new tasks are launched in this round because all static // nodes are either connected or still being dialed. { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, - {rw: &conn{flags: staticDialedConn, id: uintID(3)}}, - }, - new: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(3)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, }, done: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(3)}}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, }, }, // No new dial tasks are launched because all static // nodes are now connected. { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, - {rw: &conn{flags: staticDialedConn, id: uintID(3)}}, - {rw: &conn{flags: staticDialedConn, id: uintID(4)}}, - {rw: &conn{flags: staticDialedConn, id: uintID(5)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(4), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}}, }, done: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(4)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(5)}}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(4)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(5)}}, + &waitExpireTask{Duration: 14 * time.Second}, }, }, // Wait a round for dial history to expire, no new tasks should spawn. 
{ peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, - {rw: &conn{flags: staticDialedConn, id: uintID(3)}}, - {rw: &conn{flags: staticDialedConn, id: uintID(4)}}, - {rw: &conn{flags: staticDialedConn, id: uintID(5)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(4), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}}, }, }, // If a static node is dropped, it should be immediately redialed, // irrespective whether it was originally static or dynamic. { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: staticDialedConn, id: uintID(3)}}, - {rw: &conn{flags: staticDialedConn, id: uintID(5)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}}, + }, + new: []task{ + &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)}, }, - new: []task{}, }, }, }) @@ -526,9 +528,9 @@ func TestDialStateStaticDial(t *testing.T) { // This test checks that static peers will be redialed immediately if they were re-added to a static list. func TestDialStaticAfterReset(t *testing.T) { - wantStatic := []*discover.Node{ - {ID: uintID(1)}, - {ID: uintID(2)}, + wantStatic := []*enode.Node{ + newNode(uintID(1), nil), + newNode(uintID(2), nil), } rounds := []round{ @@ -536,23 +538,22 @@ func TestDialStaticAfterReset(t *testing.T) { { peers: nil, new: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, }, }, // No new dial tasks, all peers are connected. { peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, id: uintID(1)}}, - {rw: &conn{flags: staticDialedConn, id: uintID(2)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, }, done: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, }, new: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, + &waitExpireTask{Duration: 30 * time.Second}, }, }, } @@ -564,19 +565,17 @@ func TestDialStaticAfterReset(t *testing.T) { for _, n := range wantStatic { dTest.init.removeStatic(n) dTest.init.addStatic(n) - delete(dTest.init.dialing, n.ID) } - // without removing peers they will be considered recently dialed runDialTest(t, dTest) } // This test checks that past dials are not retried for some time. 
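// For illustration, the behaviour exercised below amounts to a history
// check of roughly this shape (a sketch, not the dialer's actual code;
// `hist`, mapping node ID to expiry time, and `exp` are hypothetical):
//
//	if exp, ok := hist[n.ID()]; ok && now.Before(exp) {
//		continue // dialed too recently; retry only once the entry expires
//	}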
func TestDialStateCache(t *testing.T) { - wantStatic := []*discover.Node{ - {ID: uintID(1)}, - {ID: uintID(2)}, - {ID: uintID(3)}, + wantStatic := []*enode.Node{ + newNode(uintID(1), nil), + newNode(uintID(2), nil), + newNode(uintID(3), nil), } runDialTest(t, dialtest{ @@ -587,53 +586,52 @@ func TestDialStateCache(t *testing.T) { { peers: nil, new: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(3)}}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, }, }, // No new tasks are launched in this round because all static // nodes are either connected or still being dialed. { peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, id: uintID(1)}}, - {rw: &conn{flags: staticDialedConn, id: uintID(2)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, }, done: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, - }, - new: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(1)}}, - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(2)}}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, }, }, // A salvage task is launched to wait for node 3's history // entry to expire. { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, }, done: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(3)}}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, + }, + new: []task{ + &waitExpireTask{Duration: 14 * time.Second}, }, }, // Still waiting for node 3's entry to expire in the cache. { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, }, }, // The cache entry for node 3 has expired and is retried. { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, - {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, }, new: []task{ - &dialTask{flags: staticDialedConn, dest: &discover.Node{ID: uintID(3)}}, + &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, }, }, }, @@ -641,12 +639,12 @@ func TestDialStateCache(t *testing.T) { } func TestDialResolve(t *testing.T) { - resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444) + resolved := newNode(uintID(1), net.IP{127, 0, 55, 234}) table := &resolveMock{answer: resolved} state := newDialState(nil, nil, table, 0, nil) // Check that the task is generated with an incomplete ID. 
- dest := discover.NewNode(uintID(1), nil, 0, 0) + dest := newNode(uintID(1), nil) state.addStatic(dest) tasks := state.newTasks(0, nil, time.Time{}) if !reflect.DeepEqual(tasks, []task{&dialTask{flags: staticDialedConn, dest: dest}}) { @@ -657,7 +655,7 @@ func TestDialResolve(t *testing.T) { config := Config{Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}} srv := &Server{ntab: table, Config: config} tasks[0].Do(srv) - if !reflect.DeepEqual(table.resolveCalls, []enode.ID{dest.ID}) { + if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) { t.Fatalf("wrong resolve calls, got %v", table.resolveCalls) } @@ -693,17 +691,16 @@ func uintID(i uint32) enode.ID { // implements discoverTable for TestDialResolve type resolveMock struct { - resolveCalls []enode.ID - answer *discover.Node + resolveCalls []*enode.Node + answer *enode.Node } -func (t *resolveMock) Resolve(id enode.ID) *discover.Node { - t.resolveCalls = append(t.resolveCalls, id) +func (t *resolveMock) Resolve(n *enode.Node) *enode.Node { + t.resolveCalls = append(t.resolveCalls, n) return t.answer } -func (t *resolveMock) Self() *discover.Node { return new(discover.Node) } -func (t *resolveMock) Close() {} -func (t *resolveMock) Bootstrap([]*discover.Node) {} -func (t *resolveMock) Lookup(enode.ID) []*discover.Node { return nil } -func (t *resolveMock) ReadRandomNodes(buf []*discover.Node) int { return 0 } +func (t *resolveMock) Self() *enode.Node { return new(enode.Node) } +func (t *resolveMock) Close() {} +func (t *resolveMock) LookupRandom() []*enode.Node { return nil } +func (t *resolveMock) ReadRandomNodes(buf []*enode.Node) int { return 0 } diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go index 0ca16072cb..3a21dd2792 100644 --- a/p2p/simulations/adapters/inproc.go +++ b/p2p/simulations/adapters/inproc.go @@ -107,14 +107,14 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) { func (s *SimAdapter) Dial(dest *enode.Node) (conn net.Conn, err error) { node, ok := s.GetNode(dest.ID()) if !ok { - return nil, fmt.Errorf("unknown node: %s", dest.ID) + return nil, fmt.Errorf("unknown node: %s", dest.ID()) } if node.connected[dest.ID()] { - return nil, fmt.Errorf("dialed node: %s", dest.ID) + return nil, fmt.Errorf("dialed node: %s", dest.ID()) } srv := node.Server() if srv == nil { - return nil, fmt.Errorf("node not running: %s", dest.ID) + return nil, fmt.Errorf("node not running: %s", dest.ID()) } pipe1, pipe2 := net.Pipe() go srv.SetupConn(pipe1, 0, nil) From 3c18913ae30c8448e30a772bdf1fc374be7f234a Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 31 Oct 2023 22:51:00 +0700 Subject: [PATCH 101/119] [WIP] Fix enode unit tests --- p2p/discover/node_test.go | 335 --------------------------------- p2p/discover/udp_test.go | 4 +- p2p/peer_test.go | 11 +- p2p/simulations/http_test.go | 7 +- whisper/whisperv5/peer_test.go | 9 +- whisper/whisperv6/peer_test.go | 13 +- 6 files changed, 23 insertions(+), 356 deletions(-) delete mode 100644 p2p/discover/node_test.go diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go deleted file mode 100644 index ddf8a7bd98..0000000000 --- a/p2p/discover/node_test.go +++ /dev/null @@ -1,335 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. 
-// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package discover - -import ( - "bytes" - "fmt" - "math/big" - "math/rand" - "net" - "reflect" - "strings" - "testing" - "testing/quick" - "time" - - "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/crypto" -) - -func ExampleNewNode() { - id := MustHexID("1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439") - - // Complete nodes contain UDP and TCP endpoints: - n1 := NewNode(id, net.ParseIP("2001:db8:3c4d:15::abcd:ef12"), 52150, 30303) - fmt.Println("n1:", n1) - fmt.Println("n1.Incomplete() ->", n1.Incomplete()) - - // An incomplete node can be created by passing zero values - // for all parameters except id. - n2 := NewNode(id, nil, 0, 0) - fmt.Println("n2:", n2) - fmt.Println("n2.Incomplete() ->", n2.Incomplete()) - - // Output: - // n1: enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[2001:db8:3c4d:15::abcd:ef12]:30303?discport=52150 - // n1.Incomplete() -> false - // n2: enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439 - // n2.Incomplete() -> true -} - -var parseNodeTests = []struct { - rawurl string - wantError string - wantResult *Node -}{ - { - rawurl: "http://foobar", - wantError: `invalid URL scheme, want "enode"`, - }, - { - rawurl: "enode://01010101@123.124.125.126:3", - wantError: `invalid node ID (wrong length, want 128 hex chars)`, - }, - // Complete nodes with IP address. 
- { - rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3", - wantError: `invalid IP address`, - }, - //{ - // rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo", - // wantError: `parse enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo: invalid port ":foo" after host`, - //}, - { - rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo", - wantError: `invalid discport in query`, - }, - { - rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150", - wantResult: NewNode( - MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - net.IP{0x7f, 0x0, 0x0, 0x1}, - 52150, - 52150, - ), - }, - { - rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150", - wantResult: NewNode( - MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - net.ParseIP("::"), - 52150, - 52150, - ), - }, - { - rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[2001:db8:3c4d:15::abcd:ef12]:52150", - wantResult: NewNode( - MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - net.ParseIP("2001:db8:3c4d:15::abcd:ef12"), - 52150, - 52150, - ), - }, - { - rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=22334", - wantResult: NewNode( - MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - net.IP{0x7f, 0x0, 0x0, 0x1}, - 22334, - 52150, - ), - }, - // Incomplete nodes with no address. - { - rawurl: "1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439", - wantResult: NewNode( - MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - nil, 0, 0, - ), - }, - { - rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439", - wantResult: NewNode( - MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"), - nil, 0, 0, - ), - }, - // Invalid URLs - { - rawurl: "01010101", - wantError: `invalid node ID (wrong length, want 128 hex chars)`, - }, - { - rawurl: "enode://01010101", - wantError: `invalid node ID (wrong length, want 128 hex chars)`, - }, - { - // This test checks that errors from url.Parse are handled. 
- rawurl: "://foo", - wantError: `parse "://foo": missing protocol scheme`, - }, -} - -func TestParseNode(t *testing.T) { - for _, test := range parseNodeTests { - n, err := ParseNode(test.rawurl) - if test.wantError != "" { - if err == nil { - t.Errorf("test %q:\n got nil error, expected %#q", test.rawurl, test.wantError) - continue - } else if err.Error() != test.wantError { - t.Errorf("test %q:\n got error %#q, expected %#q", test.rawurl, err.Error(), test.wantError) - continue - } - } else { - if err != nil { - t.Errorf("test %q:\n unexpected error: %v", test.rawurl, err) - continue - } - if !reflect.DeepEqual(n, test.wantResult) { - t.Errorf("test %q:\n result mismatch:\ngot: %#v, want: %#v", test.rawurl, n, test.wantResult) - } - } - } -} - -func TestNodeString(t *testing.T) { - for i, test := range parseNodeTests { - if test.wantError == "" && strings.HasPrefix(test.rawurl, "enode://") { - str := test.wantResult.String() - if str != test.rawurl { - t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl) - } - } - } -} - -func TestHexID(t *testing.T) { - ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188} - id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc") - id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc") - - if id1 != ref { - t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:]) - } - if id2 != ref { - t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:]) - } -} - -func TestNodeID_textEncoding(t *testing.T) { - ref := NodeID{ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x10, - 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x20, - 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x30, - 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x40, - 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x50, - 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x60, - 0x61, 0x62, 0x63, 0x64, - } - hex := "01020304050607080910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364" - - text, err := ref.MarshalText() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(text, []byte(hex)) { - t.Fatalf("text encoding did not match\nexpected: %s\ngot: %s", hex, text) - } - - id := new(NodeID) - if err := id.UnmarshalText(text); err != nil { - t.Fatal(err) - } - if *id != ref { - t.Fatalf("text decoding did not match\nexpected: %s\ngot: %s", ref, id) - } -} - -func TestNodeID_recover(t *testing.T) { - prv := newkey() - hash := make([]byte, 32) - sig, err := crypto.Sign(hash, prv) - if err != nil { - t.Fatalf("signing error: %v", err) - } - - pub := PubkeyID(&prv.PublicKey) - recpub, err := recoverNodeID(hash, sig) - if err != nil { - t.Fatalf("recovery error: %v", err) - } - if pub != recpub { - t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub) - } - - ecdsa, err := pub.Pubkey() - if err != nil { - t.Errorf("Pubkey error: %v", err) - } - if !reflect.DeepEqual(ecdsa, &prv.PublicKey) { - t.Errorf("Pubkey mismatch:\n got: %#v\n want: %#v", ecdsa, &prv.PublicKey) - } -} - -func TestNodeID_pubkeyBad(t *testing.T) { - ecdsa, err := NodeID{}.Pubkey() - if err == nil { - 
t.Error("expected error for zero ID") - } - if ecdsa != nil { - t.Error("expected nil result") - } -} - -func TestNodeID_distcmp(t *testing.T) { - distcmpBig := func(target, a, b common.Hash) int { - tbig := new(big.Int).SetBytes(target[:]) - abig := new(big.Int).SetBytes(a[:]) - bbig := new(big.Int).SetBytes(b[:]) - return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig)) - } - if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg()); err != nil { - t.Error(err) - } -} - -// the random tests is likely to miss the case where they're equal. -func TestNodeID_distcmpEqual(t *testing.T) { - base := common.Hash{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - x := common.Hash{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0} - if distcmp(base, x, x) != 0 { - t.Errorf("distcmp(base, x, x) != 0") - } -} - -func TestNodeID_logdist(t *testing.T) { - logdistBig := func(a, b common.Hash) int { - abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:]) - return new(big.Int).Xor(abig, bbig).BitLen() - } - if err := quick.CheckEqual(logdist, logdistBig, quickcfg()); err != nil { - t.Error(err) - } -} - -// the random tests is likely to miss the case where they're equal. -func TestNodeID_logdistEqual(t *testing.T) { - x := common.Hash{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - if logdist(x, x) != 0 { - t.Errorf("logdist(x, x) != 0") - } -} - -func TestNodeID_hashAtDistance(t *testing.T) { - // we don't use quick.Check here because its output isn't - // very helpful when the test fails. - cfg := quickcfg() - for i := 0; i < cfg.MaxCount; i++ { - a := gen(common.Hash{}, cfg.Rand).(common.Hash) - dist := cfg.Rand.Intn(len(common.Hash{}) * 8) - result := hashAtDistance(a, dist) - actualdist := logdist(result, a) - - if dist != actualdist { - t.Log("a: ", a) - t.Log("result:", result) - t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist) - } - } -} - -func quickcfg() *quick.Config { - return &quick.Config{ - MaxCount: 5000, - Rand: rand.New(rand.NewSource(time.Now().Unix())), - } -} - -// TODO: The Generate method can be dropped when we require Go >= 1.5 -// because testing/quick learned to generate arrays in 1.5. - -func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value { - var id NodeID - m := rand.Intn(len(id)) - for i := len(id) - 1; i > m; i-- { - id[i] = byte(rand.Uint32()) - } - return reflect.ValueOf(id) -} diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index 82ca1ef19f..e5fb32d082 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -250,7 +250,7 @@ func TestUDP_findnode(t *testing.T) { // ensure there's a bond with the test node, // findnode won't be accepted otherwise. remoteID := encodePubkey(&test.remotekey.PublicKey).id() - test.table.db.UpdateLastPongReceived(remoteID, time.Now()) + test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now()) // check that closest neighbors are returned. 
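// (The bond established above matters because of the endpoint-proof rule,
// sketched here for illustration; this is not the actual table code and
// `bondExpiration` is an assumed name: findnode is ignored unless a
// sufficiently recent pong from the sender is on record.)
//
//	if time.Since(db.LastPongReceived(fromID, fromIP)) > bondExpiration {
//		return errUnknownNode // no endpoint proof; findnode is dropped
//	}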
test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) @@ -277,7 +277,7 @@ func TestUDP_findnodeMultiReply(t *testing.T) { defer test.table.Close() rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey) - test.table.db.UpdateLastPingReceived(rid, time.Now()) + test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.IP, time.Now()) // queue a pending findnode request resultc, errc := make(chan []*node), make(chan error) diff --git a/p2p/peer_test.go b/p2p/peer_test.go index a3e1c74fd8..d9e4b6c333 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -44,9 +44,14 @@ var discard = Protocol{ } func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) { - fd1, fd2 := net.Pipe() - c1 := &conn{fd: fd1, transport: newTestTransport(randomID(), fd1)} - c2 := &conn{fd: fd2, transport: newTestTransport(randomID(), fd2)} + var ( + fd1, fd2 = net.Pipe() + key1, key2 = newkey(), newkey() + t1 = newTestTransport(&key2.PublicKey, fd1) + t2 = newTestTransport(&key1.PublicKey, fd2) + ) + c1 := &conn{fd: fd1, transport: t1} + c2 := &conn{fd: fd2, transport: t2} for _, p := range protos { c1.caps = append(c1.caps, p.cap()) c2.caps = append(c2.caps, p.cap()) diff --git a/p2p/simulations/http_test.go b/p2p/simulations/http_test.go index a89301895f..d557e5f996 100644 --- a/p2p/simulations/http_test.go +++ b/p2p/simulations/http_test.go @@ -30,7 +30,6 @@ import ( "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/simulations/adapters" "github.com/tomochain/tomochain/rpc" @@ -412,7 +411,7 @@ func (t *expectEvents) nodeEvent(id string, up bool) *Event { Type: EventTypeNode, Node: &Node{ Config: &adapters.NodeConfig{ - ID: discover.MustHexID(id), + ID: enode.HexID(id), }, Up: up, }, @@ -423,8 +422,8 @@ func (t *expectEvents) connEvent(one, other string, up bool) *Event { return &Event{ Type: EventTypeConn, Conn: &Conn{ - One: discover.MustHexID(one), - Other: discover.MustHexID(other), + One: enode.HexID(one), + Other: enode.HexID(other), Up: up, }, } diff --git a/whisper/whisperv5/peer_test.go b/whisper/whisperv5/peer_test.go index 2805aa3d52..0594c22d58 100644 --- a/whisper/whisperv5/peer_test.go +++ b/whisper/whisperv5/peer_test.go @@ -28,7 +28,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/nat" ) @@ -131,12 +131,11 @@ func initialize(t *testing.T) { port := port0 + i addr := fmt.Sprintf(":%d", port) // e.g. 
":30303" name := common.MakeName("whisper-go", "2.0") - var peers []*discover.Node + var peers []*enode.Node if i > 0 { peerNodeId := nodes[i-1].id - peerPort := uint16(port - 1) - peerNode := discover.PubkeyID(&peerNodeId.PublicKey) - peer := discover.NewNode(peerNode, ip, peerPort, peerPort) + peerPort := port - 1 + peer := enode.NewV4(&peerNodeId.PublicKey, ip, peerPort, peerPort) peers = append(peers, peer) } diff --git a/whisper/whisperv6/peer_test.go b/whisper/whisperv6/peer_test.go index 1f0365eacf..7a3b53265b 100644 --- a/whisper/whisperv6/peer_test.go +++ b/whisper/whisperv6/peer_test.go @@ -31,7 +31,7 @@ import ( "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/p2p" - "github.com/tomochain/tomochain/p2p/discover" + "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/nat" ) @@ -202,12 +202,11 @@ func initialize(t *testing.T) { port := port0 + i addr := fmt.Sprintf(":%d", port) // e.g. ":30303" name := common.MakeName("whisper-go", "2.0") - var peers []*discover.Node + var peers []*enode.Node if i > 0 { peerNodeID := nodes[i-1].id - peerPort := uint16(port - 1) - peerNode := discover.PubkeyID(&peerNodeID.PublicKey) - peer := discover.NewNode(peerNode, ip, peerPort, peerPort) + peerPort := port - 1 + peer := enode.NewV4(&peerNodeID.PublicKey, ip, peerPort, peerPort) peers = append(peers, peer) } @@ -437,7 +436,7 @@ func checkPowExchangeForNodeZeroOnce(t *testing.T, mustPass bool) bool { cnt := 0 for i, node := range nodes { for peer := range node.shh.peers { - if peer.peer.ID() == discover.PubkeyID(&nodes[0].id.PublicKey) { + if peer.peer.ID() == enode.PubkeyToIDV4(&nodes[0].id.PublicKey) { cnt++ if peer.powRequirement != masterPow { if mustPass { @@ -458,7 +457,7 @@ func checkPowExchangeForNodeZeroOnce(t *testing.T, mustPass bool) bool { func checkPowExchange(t *testing.T) { for i, node := range nodes { for peer := range node.shh.peers { - if peer.peer.ID() != discover.PubkeyID(&nodes[0].id.PublicKey) { + if peer.peer.ID() != enode.PubkeyToIDV4(&nodes[0].id.PublicKey) { if peer.powRequirement != masterPow { t.Fatalf("node %d: failed to exchange pow requirement in round %d; expected %f, got %f", i, round, masterPow, peer.powRequirement) From 7c3dbe0562f009c8803b8adfcc54b8bd9acf1ede Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 31 Oct 2023 23:26:24 +0700 Subject: [PATCH 102/119] [WIP] Fix swarm unit tests --- cmd/swarm/config_test.go | 14 ++ cmd/swarm/run_test.go | 273 +++++++++++++++++++++++++++++++++------ eth/helper_test.go | 8 +- node/api.go | 4 +- 4 files changed, 257 insertions(+), 42 deletions(-) diff --git a/cmd/swarm/config_test.go b/cmd/swarm/config_test.go index 05b5eeb90c..4a1e30db18 100644 --- a/cmd/swarm/config_test.go +++ b/cmd/swarm/config_test.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "os" "os/exec" "testing" @@ -552,3 +553,16 @@ func TestValidateConfig(t *testing.T) { } } } + +func assignTCPPort() (string, error) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", err + } + l.Close() + _, port, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + return "", err + } + return port, nil +} diff --git a/cmd/swarm/run_test.go b/cmd/swarm/run_test.go index 6c6d3d66ed..94b4397b2a 100644 --- a/cmd/swarm/run_test.go +++ b/cmd/swarm/run_test.go @@ -17,16 +17,22 @@ package main import ( + "context" + "crypto/ecdsa" "fmt" "io/ioutil" "net" "os" + "path" "path/filepath" "runtime" + "sync" + "syscall" "testing" 
"time" "github.com/docker/docker/pkg/reexec" + "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/accounts/keystore" "github.com/tomochain/tomochain/internal/cmdtest" @@ -81,6 +87,7 @@ type testCluster struct { // // When starting more than one node, they are connected together using the // admin SetPeer RPC method. + func newTestCluster(t *testing.T, size int) *testCluster { cluster := &testCluster{} defer func() { @@ -96,18 +103,7 @@ func newTestCluster(t *testing.T, size int) *testCluster { cluster.TmpDir = tmpdir // start the nodes - cluster.Nodes = make([]*testNode, 0, size) - for i := 0; i < size; i++ { - dir := filepath.Join(cluster.TmpDir, fmt.Sprintf("swarm%02d", i)) - if err := os.Mkdir(dir, 0700); err != nil { - t.Fatal(err) - } - - node := newTestNode(t, dir) - node.Name = fmt.Sprintf("swarm%02d", i) - - cluster.Nodes = append(cluster.Nodes, node) - } + cluster.StartNewNodes(t, size) if size == 1 { return cluster @@ -145,14 +141,52 @@ func (c *testCluster) Shutdown() { os.RemoveAll(c.TmpDir) } +func (c *testCluster) Stop() { + for _, node := range c.Nodes { + node.Shutdown() + } +} + +func (c *testCluster) StartNewNodes(t *testing.T, size int) { + c.Nodes = make([]*testNode, 0, size) + for i := 0; i < size; i++ { + dir := filepath.Join(c.TmpDir, fmt.Sprintf("swarm%02d", i)) + if err := os.Mkdir(dir, 0700); err != nil { + t.Fatal(err) + } + + node := newTestNode(t, dir) + node.Name = fmt.Sprintf("swarm%02d", i) + + c.Nodes = append(c.Nodes, node) + } +} + +func (c *testCluster) StartExistingNodes(t *testing.T, size int, bzzaccount string) { + c.Nodes = make([]*testNode, 0, size) + for i := 0; i < size; i++ { + dir := filepath.Join(c.TmpDir, fmt.Sprintf("swarm%02d", i)) + node := existingTestNode(t, dir, bzzaccount) + node.Name = fmt.Sprintf("swarm%02d", i) + + c.Nodes = append(c.Nodes, node) + } +} + +func (c *testCluster) Cleanup() { + os.RemoveAll(c.TmpDir) +} + type testNode struct { - Name string - Addr string - URL string - Enode string - Dir string - Client *rpc.Client - Cmd *cmdtest.TestCmd + Name string + Addr string + URL string + Enode string + Dir string + IpcPath string + PrivateKey *ecdsa.PrivateKey + Client *rpc.Client + Cmd *cmdtest.TestCmd } const testPassphrase = "swarm-test-passphrase" @@ -181,24 +215,103 @@ func getTestAccount(t *testing.T, dir string) (conf *node.Config, account accoun return conf, account } -func newTestNode(t *testing.T, dir string) *testNode { - - conf, account := getTestAccount(t, dir) +func existingTestNode(t *testing.T, dir string, bzzaccount string) *testNode { + conf, _ := getTestAccount(t, dir) node := &testNode{Dir: dir} + // use a unique IPCPath when running tests on Windows + if runtime.GOOS == "windows" { + conf.IPCPath = fmt.Sprintf("bzzd-%s.ipc", bzzaccount) + } + // assign ports - httpPort, err := assignTCPPort() + ports, err := getAvailableTCPPorts(2) + if err != nil { + t.Fatal(err) + } + p2pPort := ports[0] + httpPort := ports[1] + + // start the node + node.Cmd = runSwarm(t, + "--port", p2pPort, + "--nat", "extip:127.0.0.1", + "--nodiscover", + "--datadir", dir, + "--ipcpath", conf.IPCPath, + "--ens-api", "", + "--bzzaccount", bzzaccount, + "--bzznetworkid", "321", + "--bzzport", httpPort, + "--verbosity", "3", + ) + node.Cmd.InputLine(testPassphrase) + defer func() { + if t.Failed() { + node.Shutdown() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // ensure that all ports have active listeners + // so that the next node will not get the 
same ports
+	// when calling getAvailableTCPPorts
+	err = waitTCPPorts(ctx, ports...)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// wait for the node to start
+	for start := time.Now(); time.Since(start) < 10*time.Second; time.Sleep(50 * time.Millisecond) {
+		node.Client, err = rpc.Dial(conf.IPCEndpoint())
+		if err == nil {
+			break
+		}
+	}
+	if node.Client == nil {
+		t.Fatal(err)
+	}
+
+	// load info
+	var info swarm.Info
+	if err := node.Client.Call(&info, "bzz_info"); err != nil {
+		t.Fatal(err)
+	}
+	node.Addr = net.JoinHostPort("127.0.0.1", info.Port)
+	node.URL = "http://" + node.Addr
+
+	var nodeInfo p2p.NodeInfo
+	if err := node.Client.Call(&nodeInfo, "admin_nodeInfo"); err != nil {
+		t.Fatal(err)
+	}
+	node.Enode = nodeInfo.Enode
+	node.IpcPath = conf.IPCPath
+	return node
+}
+
+func newTestNode(t *testing.T, dir string) *testNode {
+
+	conf, account := getTestAccount(t, dir)
+	ks := keystore.NewKeyStore(path.Join(dir, "keystore"), 1<<18, 1)
+
+	pk := decryptStoreAccount(ks, account.Address.Hex(), []string{testPassphrase})
+
+	node := &testNode{Dir: dir, PrivateKey: pk}
+
+	// assign ports
+	ports, err := getAvailableTCPPorts(2)
 	if err != nil {
 		t.Fatal(err)
 	}
+	p2pPort := ports[0]
+	httpPort := ports[1]
 
 	// start the node
 	node.Cmd = runSwarm(t,
 		"--port", p2pPort,
+		"--nat", "extip:127.0.0.1",
 		"--nodiscover",
 		"--datadir", dir,
 		"--ipcpath", conf.IPCPath,
@@ -206,7 +319,7 @@ func newTestNode(t *testing.T, dir string) *testNode {
 		"--bzzaccount", account.Address.String(),
 		"--bzznetworkid", "321",
 		"--bzzport", httpPort,
-		"--verbosity", "6",
+		"--verbosity", "3",
 	)
 	node.Cmd.InputLine(testPassphrase)
 	defer func() {
@@ -215,6 +328,17 @@ func newTestNode(t *testing.T, dir string) *testNode {
 		}
 	}()
 
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	// ensure that all ports have active listeners
+	// so that the next node will not get the same ports
+	// when calling getAvailableTCPPorts
+	err = waitTCPPorts(ctx, ports...)
+	if err != nil {
+		t.Fatal(err)
+	}
+
 	// wait for the node to start
 	for start := time.Now(); time.Since(start) < 10*time.Second; time.Sleep(50 * time.Millisecond) {
 		node.Client, err = rpc.Dial(conf.IPCEndpoint())
@@ -238,8 +362,8 @@ func newTestNode(t *testing.T, dir string) *testNode {
 	if err := node.Client.Call(&nodeInfo, "admin_nodeInfo"); err != nil {
 		t.Fatal(err)
 	}
-	node.Enode = fmt.Sprintf("enode://%s@127.0.0.1:%s", nodeInfo.ID, p2pPort)
-
+	node.Enode = nodeInfo.Enode
+	node.IpcPath = conf.IPCPath
 	return node
 }
 
@@ -249,15 +373,92 @@ func (n *testNode) Shutdown() {
 	}
 }
 
-func assignTCPPort() (string, error) {
-	l, err := net.Listen("tcp", "127.0.0.1:0")
-	if err != nil {
-		return "", err
+// getAvailableTCPPorts returns a set of ports that
+// nothing is listening on at the time.
+//
+// Function assignTCPPort cannot be called in sequence
+// and guarantee that the same port will not be returned in
+// different calls, as the listener is closed within the function,
+// not after all listeners are started and unique
+// available ports are selected.
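+//
+// A typical call site, sketched for illustration only (not part of this
+// file); it mirrors the pattern used by existingTestNode and newTestNode:
+//
+//	ports, err := getAvailableTCPPorts(2) // reserve two distinct free ports
+//	if err != nil {
+//		t.Fatal(err)
+//	}
+//	// ... start a process that binds ports[0] and ports[1] ...
+//	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+//	defer cancel()
+//	if err := waitTCPPorts(ctx, ports...); err != nil { // block until both listen
+//		t.Fatal(err)
+//	}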
+func getAvailableTCPPorts(count int) (ports []string, err error) {
+	for i := 0; i < count; i++ {
+		l, err := net.Listen("tcp", "127.0.0.1:0")
+		if err != nil {
+			return nil, err
+		}
+		// defer close in the loop to be sure the same port will not
+		// be selected in the next iteration
+		defer l.Close()
+
+		_, port, err := net.SplitHostPort(l.Addr().String())
+		if err != nil {
+			return nil, err
+		}
+		ports = append(ports, port)
 	}
-	l.Close()
-	_, port, err := net.SplitHostPort(l.Addr().String())
-	if err != nil {
-		return "", err
+	return ports, nil
+}
+
+// waitTCPPorts blocks until tcp connections can be
+// established on all provided ports. It runs all
+// port dialers in parallel, and returns the first
+// encountered error.
+// See waitTCPPort also.
+func waitTCPPorts(ctx context.Context, ports ...string) error {
+	var err error
+	// mu guards the err variable, which is assigned in
+	// other goroutines
+	var mu sync.Mutex
+
+	// cancel cancels all goroutines
+	// when the first error is returned,
+	// to prevent unnecessary waiting
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	var wg sync.WaitGroup
+	for _, port := range ports {
+		wg.Add(1)
+		go func(port string) {
+			defer wg.Done()
+
+			e := waitTCPPort(ctx, port)
+
+			mu.Lock()
+			defer mu.Unlock()
+			if e != nil && err == nil {
+				err = e
+				cancel()
+			}
+		}(port)
+	}
+	wg.Wait()
+
+	return err
+}
+
+// waitTCPPort blocks until a tcp connection can be established
+// on a provided port. It has a 3 minute timeout as maximum,
+// to prevent long waiting, but it can be shortened with
+// a provided context instance. The dialer has a 10 second timeout
+// in every iteration, and a connection refused error will be
+// retried in 100 millisecond periods.
+func waitTCPPort(ctx context.Context, port string) error {
+	ctx, cancel := context.WithTimeout(ctx, 3*time.Minute)
+	defer cancel()
+
+	for {
+		c, err := (&net.Dialer{Timeout: 10 * time.Second}).DialContext(ctx, "tcp", "127.0.0.1:"+port)
+		if err != nil {
+			if operr, ok := err.(*net.OpError); ok {
+				if syserr, ok := operr.Err.(*os.SyscallError); ok && syserr.Err == syscall.ECONNREFUSED {
+					time.Sleep(100 * time.Millisecond)
+					continue
+				}
+			}
+			return err
+		}
+		return c.Close()
 	}
-	return port, nil
 }
diff --git a/eth/helper_test.go b/eth/helper_test.go
index 6ea65856f4..85f75f01d1 100644
--- a/eth/helper_test.go
+++ b/eth/helper_test.go
@@ -22,7 +22,6 @@ package eth
 import (
 	"crypto/ecdsa"
 	"crypto/rand"
-	"github.com/tomochain/tomochain/core/rawdb"
 	"math/big"
 	"sort"
 	"sync"
@@ -31,6 +30,7 @@ import (
 	"github.com/tomochain/tomochain/common"
 	"github.com/tomochain/tomochain/consensus/ethash"
 	"github.com/tomochain/tomochain/core"
+	"github.com/tomochain/tomochain/core/rawdb"
 	"github.com/tomochain/tomochain/core/types"
 	"github.com/tomochain/tomochain/core/vm"
 	"github.com/tomochain/tomochain/crypto"
@@ -38,7 +38,7 @@ import (
 	"github.com/tomochain/tomochain/ethdb"
 	"github.com/tomochain/tomochain/event"
 	"github.com/tomochain/tomochain/p2p"
-	"github.com/tomochain/tomochain/p2p/discover"
+	"github.com/tomochain/tomochain/p2p/enode"
 	"github.com/tomochain/tomochain/params"
 )
 
@@ -149,8 +149,8 @@ func newTestPeer(name string, version int, pm *ProtocolManager, shake bool) (*te
 	app, net := p2p.MsgPipe()
 
 	// Generate a random id and create the peer
-	var id discover.NodeID
-	rand.Read(id[:])
+	var id enode.ID
+	rand.Read(id[:])
 
 	peer := pm.newPeer(version, p2p.NewPeer(id, name, nil), net)
 
diff --git a/node/api.go b/node/api.go
index a6b92c6b80..25b67ac1b7 100644
--- a/node/api.go
+++ b/node/api.go
@@ -51,9 +51,9 @@ func (api *PrivateAdminAPI) AddPeer(url string) (bool, error) { return false, ErrNodeStopped } // Try to add the url as a static peer and return - node, err := enode.ParseV4(url) + node, err := enode.Parse(enode.ValidSchemes, url) if err != nil { - return false, fmt.Errorf("invalid enode: %v", err) + return false, fmt.Errorf("invalid enode url: %v, err %v", url, err) } server.AddPeer(node) return true, nil From f562dea2d80c6846eccbf87f1c29d08bf1bc91d0 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 31 Oct 2023 23:52:33 +0700 Subject: [PATCH 103/119] Fix downloader and p2p/peer unit tests --- cmd/swarm/main.go | 2 +- eth/downloader/downloader_test.go | 22 ++++++++++++---------- p2p/peer_test.go | 4 ++-- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/cmd/swarm/main.go b/cmd/swarm/main.go index 221ccbbcb7..83a2609df2 100644 --- a/cmd/swarm/main.go +++ b/cmd/swarm/main.go @@ -543,7 +543,7 @@ func getPassPhrase(prompt string, i int, passwords []string) string { func injectBootnodes(srv *p2p.Server, nodes []string) { for _, url := range nodes { - n, err := enode.ParseV4(url) + n, err := enode.Parse(enode.ValidSchemes, url) if err != nil { log.Error("Invalid swarm bootnode", "err", err) continue diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index af39f9856b..470819224e 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -94,7 +94,7 @@ func newTester() *downloadTester { peerChainTds: make(map[string]map[common.Hash]*big.Int), peerMissingStates: make(map[string]map[common.Hash]bool), } - tester.stateDb= rawdb.NewMemoryDatabase() + tester.stateDb = rawdb.NewMemoryDatabase() tester.stateDb.Put(genesis.Root().Bytes(), []byte{0x00}) tester.downloader = New(FullSync, tester.stateDb, new(event.TypeMux), tester, nil, tester.dropPeer) @@ -160,7 +160,7 @@ func (dl *downloadTester) makeChainFork(n, f int, parent *types.Block, parentRec // Create the common suffix hashes, headers, blocks, receipts := dl.makeChain(n-f, 0, parent, parentReceipts, false) - // Create the forks, making the second heavyer if non balanced forks were requested + // Create the forks, making the second heavier if non balanced forks were requested hashes1, headers1, blocks1, receipts1 := dl.makeChain(f, 1, blocks[hashes[0]], receipts[hashes[0]], false) hashes1 = append(hashes1, hashes[1:]...) @@ -663,12 +663,14 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng // Tests that simple synchronization against a canonical chain works correctly. // In this test common ancestor lookup should be short circuited and not require // binary searching. 
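// (A sketch of that short circuit, for illustration only; `knownHeader`
// is a hypothetical name, not the downloader's actual API: when the
// remote head directly extends a locally known header, the common
// ancestor is found without any binary search.)
//
//	if knownHeader(remoteHead.ParentHash) {
//		return remoteHead.Number.Uint64() - 1 // ancestor found immediately
//	}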
-func TestCanonicalSynchronisation62(t *testing.T) { testCanonicalSynchronisation(t, 62, FullSync) } -func TestCanonicalSynchronisation63Full(t *testing.T) { testCanonicalSynchronisation(t, 63, FullSync) } -func TestCanonicalSynchronisation63Fast(t *testing.T) { testCanonicalSynchronisation(t, 63, FastSync) } -func TestCanonicalSynchronisation64Full(t *testing.T) { testCanonicalSynchronisation(t, 64, FullSync) } -func TestCanonicalSynchronisation64Fast(t *testing.T) { testCanonicalSynchronisation(t, 64, FastSync) } -func TestCanonicalSynchronisation64Light(t *testing.T) { testCanonicalSynchronisation(t, 64, LightSync) } +func TestCanonicalSynchronisation62(t *testing.T) { testCanonicalSynchronisation(t, 62, FullSync) } +func TestCanonicalSynchronisation63Full(t *testing.T) { testCanonicalSynchronisation(t, 63, FullSync) } +func TestCanonicalSynchronisation63Fast(t *testing.T) { testCanonicalSynchronisation(t, 63, FastSync) } +func TestCanonicalSynchronisation64Full(t *testing.T) { testCanonicalSynchronisation(t, 64, FullSync) } +func TestCanonicalSynchronisation64Fast(t *testing.T) { testCanonicalSynchronisation(t, 64, FastSync) } +func TestCanonicalSynchronisation64Light(t *testing.T) { + testCanonicalSynchronisation(t, 64, LightSync) +} func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) { t.Parallel() @@ -1357,8 +1359,8 @@ func testBlockHeaderAttackerDropping(t *testing.T, protocol int) { } } -//Tests that synchronisation progress (origin block number, current block number -//and highest block number) is tracked and updated correctly. +// Tests that synchronisation progress (origin block number, current block number +// and highest block number) is tracked and updated correctly. func TestSyncProgress62(t *testing.T) { testSyncProgress(t, 62, FullSync) } func TestSyncProgress63Full(t *testing.T) { testSyncProgress(t, 63, FullSync) } func TestSyncProgress63Fast(t *testing.T) { testSyncProgress(t, 63, FastSync) } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index d9e4b6c333..1c795f9f80 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -50,8 +50,8 @@ func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) { t1 = newTestTransport(&key2.PublicKey, fd1) t2 = newTestTransport(&key1.PublicKey, fd2) ) - c1 := &conn{fd: fd1, transport: t1} - c2 := &conn{fd: fd2, transport: t2} + c1 := &conn{fd: fd1, node: newNode(randomID(), nil), transport: t1} + c2 := &conn{fd: fd2, node: newNode(randomID(), nil), transport: t2} for _, p := range protos { c1.caps = append(c1.caps, p.cap()) c2.caps = append(c2.caps, p.cap()) From a73dfa8f1f38ebb922d961dff9b7a2f825cd23f7 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Wed, 1 Nov 2023 00:22:08 +0700 Subject: [PATCH 104/119] Fix swarm upload unit test --- cmd/swarm/run_test.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cmd/swarm/run_test.go b/cmd/swarm/run_test.go index 94b4397b2a..3a53ef0091 100644 --- a/cmd/swarm/run_test.go +++ b/cmd/swarm/run_test.go @@ -110,9 +110,15 @@ func newTestCluster(t *testing.T, size int) *testCluster { } // connect the nodes together - for _, node := range cluster.Nodes { - if err := node.Client.Call(nil, "admin_addPeer", cluster.Nodes[0].Enode); err != nil { - t.Fatal(err) + for i, node := range cluster.Nodes { + // TODO(trinhdn2): only need to peer with cluster.Nodes[0], fix this later + for j := 0; j < size; j++ { + if i == j { + continue + } + if err := node.Client.Call(nil, "admin_addPeer", cluster.Nodes[j].Enode); err != nil { + t.Fatal(err) 
+ } } } From 301300cef59afd7b22dcbf6f764580828a45c27b Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 30 Nov 2023 14:50:29 +0700 Subject: [PATCH 105/119] Implement bi-directional communication --- rpc/client.go | 599 +++++++++++++++----------------------------- rpc/endpoints.go | 101 ++++++++ rpc/errors.go | 27 +- rpc/handler.go | 405 ++++++++++++++++++++++++++++++ rpc/http.go | 171 ++++++++++--- rpc/inproc.go | 6 +- rpc/ipc.go | 27 +- rpc/json.go | 477 ++++++++++++++++------------------- rpc/json_test.go | 178 ------------- rpc/server.go | 447 ++++++--------------------------- rpc/service.go | 285 +++++++++++++++++++++ rpc/stdio.go | 54 ++++ rpc/subscription.go | 330 +++++++++++++++++++----- rpc/types.go | 82 +----- rpc/utils.go | 241 ------------------ rpc/utils_test.go | 43 ---- rpc/websocket.go | 83 ++++-- 17 files changed, 1841 insertions(+), 1715 deletions(-) create mode 100644 rpc/endpoints.go create mode 100644 rpc/handler.go delete mode 100644 rpc/json_test.go create mode 100644 rpc/service.go create mode 100644 rpc/stdio.go delete mode 100644 rpc/utils.go delete mode 100644 rpc/utils_test.go diff --git a/rpc/client.go b/rpc/client.go index 9d665d91e9..93ca384715 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -18,17 +18,13 @@ package rpc import ( "bytes" - "container/list" "context" "encoding/json" "errors" "fmt" - "net" "net/url" "reflect" "strconv" - "strings" - "sync" "sync/atomic" "time" @@ -39,13 +35,14 @@ var ( ErrClientQuit = errors.New("client is closed") ErrNoResult = errors.New("no result in JSON-RPC response") ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow") + errClientReconnected = errors.New("client reconnected") + errDead = errors.New("connection lost") ) const ( // Timeouts tcpKeepAliveInterval = 30 * time.Second - defaultDialTimeout = 10 * time.Second // used when dialing if the context has no deadline - defaultWriteTimeout = 10 * time.Second // used for calls if the context has no deadline + defaultDialTimeout = 10 * time.Second // used if context has no deadline subscribeTimeout = 5 * time.Second // overall timeout eth_subscribe, rpc_modules calls ) @@ -60,7 +57,7 @@ const ( // The approach taken here is to maintain a per-subscription linked list buffer // shrinks on demand. If the buffer reaches the size below, the subscription is // dropped. - maxClientSubscriptionBuffer = 8000 + maxClientSubscriptionBuffer = 20000 ) // BatchElem is an element in a batch request. @@ -76,55 +73,57 @@ type BatchElem struct { Error error } -// A value of this type can a JSON-RPC request, notification, successful response or -// error response. Which one it is depends on the fields. -type jsonrpcMessage struct { - Version string `json:"jsonrpc"` - ID json.RawMessage `json:"id,omitempty"` - Method string `json:"method,omitempty"` - Params json.RawMessage `json:"params,omitempty"` - Error *jsonError `json:"error,omitempty"` - Result json.RawMessage `json:"result,omitempty"` -} +// Client represents a connection to an RPC server. +type Client struct { + idgen func() ID // for subscriptions + isHTTP bool + services *serviceRegistry -func (msg *jsonrpcMessage) isNotification() bool { - return msg.ID == nil && msg.Method != "" -} + idCounter uint32 -func (msg *jsonrpcMessage) isResponse() bool { - return msg.hasValidID() && msg.Method == "" && len(msg.Params) == 0 -} + // This function, if non-nil, is called when the connection is lost. 
+	reconnectFunc reconnectFunc
 
-func (msg *jsonrpcMessage) hasValidID() bool {
-	return len(msg.ID) > 0 && msg.ID[0] != '{' && msg.ID[0] != '['
+	// writeConn is used for writing to the connection on the caller's goroutine. It should
+	// only be accessed outside of dispatch, with the write lock held. The write lock is
+	// taken by sending on reqInit and released by sending on reqSent.
+	writeConn jsonWriter
+
+	// for dispatch
+	close       chan struct{}
+	closing     chan struct{}    // closed when client is quitting
+	didClose    chan struct{}    // closed when client quits
+	reconnected chan ServerCodec // where write/reconnect sends the new connection
+	readOp      chan readOp      // read messages
+	readErr     chan error       // errors from read
+	reqInit     chan *requestOp  // register response IDs, takes write lock
+	reqSent     chan error       // signals write completion, releases write lock
+	reqTimeout  chan *requestOp  // removes response IDs when call timeout expires
 }
 
-func (msg *jsonrpcMessage) String() string {
-	b, _ := json.Marshal(msg)
-	return string(b)
+type reconnectFunc func(ctx context.Context) (ServerCodec, error)
+
+type clientContextKey struct{}
+
+type clientConn struct {
+	codec   ServerCodec
+	handler *handler
 }
 
-// Client represents a connection to an RPC server.
-type Client struct {
-	idCounter   uint32
-	connectFunc func(ctx context.Context) (net.Conn, error)
-	isHTTP      bool
+func (c *Client) newClientConn(conn ServerCodec) *clientConn {
+	ctx := context.WithValue(context.Background(), clientContextKey{}, c)
+	handler := newHandler(ctx, conn, c.idgen, c.services)
+	return &clientConn{conn, handler}
+}
 
-	// writeConn is only safe to access outside dispatch, with the
-	// write lock held. The write lock is taken by sending on
-	// requestOp and released by sending on sendDone.
-	writeConn net.Conn
+func (cc *clientConn) close(err error, inflightReq *requestOp) {
+	cc.handler.close(err, inflightReq)
+	cc.codec.Close()
+}
 
-	// for dispatch
-	close       chan struct{}
-	didQuit     chan struct{}                  // closed when client quits
-	reconnected chan net.Conn                  // where write/reconnect sends the new connection
-	readErr     chan error                     // errors from read
-	readResp    chan []*jsonrpcMessage         // valid messages from read
-	requestOp   chan *requestOp                // for registering response IDs
-	sendDone    chan error                     // signals write completion, releases write lock
-	respWait    map[string]*requestOp          // active requests
-	subs        map[string]*ClientSubscription // active subscriptions
+type readOp struct {
+	msgs  []*jsonrpcMessage
+	batch bool
 }
 
 type requestOp struct {
@@ -134,9 +133,14 @@ type requestOp struct {
 	sub *ClientSubscription // only set for EthSubscribe requests
 }
 
-func (op *requestOp) wait(ctx context.Context) (*jsonrpcMessage, error) {
+func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) {
 	select {
 	case <-ctx.Done():
+		// Send the timeout to dispatch so it can remove the request IDs.
+ select { + case c.reqTimeout <- op: + case <-c.closing: + } return nil, ctx.Err() case resp := <-op.resp: return resp, op.err @@ -171,6 +175,8 @@ func DialContext(ctx context.Context, rawurl string) (*Client, error) { return DialHTTP(rawurl) case "ws", "wss": return DialWebsocket(ctx, rawurl, "") + case "stdio": + return DialStdIO(ctx) case "": return DialIPC(ctx, rawurl) default: @@ -178,36 +184,57 @@ func DialContext(ctx context.Context, rawurl string) (*Client, error) { } } -func newClient(initctx context.Context, connectFunc func(context.Context) (net.Conn, error)) (*Client, error) { - conn, err := connectFunc(initctx) +// ClientFromContext retrieves the client from the context, if any. This can be used to perform +// 'reverse calls' in a handler method. +func ClientFromContext(ctx context.Context) (*Client, bool) { + client, ok := ctx.Value(clientContextKey{}).(*Client) + return client, ok +} + +func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) { + conn, err := connect(initctx) if err != nil { return nil, err } - _, isHTTP := conn.(*httpConn) + c := initClient(conn, randomIDGenerator(), new(serviceRegistry)) + c.reconnectFunc = connect + return c, nil +} +func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client { + _, isHTTP := conn.(*httpConn) c := &Client{ - writeConn: conn, + idgen: idgen, isHTTP: isHTTP, - connectFunc: connectFunc, + services: services, + writeConn: conn, close: make(chan struct{}), - didQuit: make(chan struct{}), - reconnected: make(chan net.Conn), + closing: make(chan struct{}), + didClose: make(chan struct{}), + reconnected: make(chan ServerCodec), + readOp: make(chan readOp), readErr: make(chan error), - readResp: make(chan []*jsonrpcMessage), - requestOp: make(chan *requestOp), - sendDone: make(chan error, 1), - respWait: make(map[string]*requestOp), - subs: make(map[string]*ClientSubscription), + reqInit: make(chan *requestOp), + reqSent: make(chan error, 1), + reqTimeout: make(chan *requestOp), } if !isHTTP { go c.dispatch(conn) } - return c, nil + return c +} + +// RegisterName creates a service for the given receiver type under the given name. When no +// methods on the given receiver match the criteria to be either a RPC method or a +// subscription an error is returned. Otherwise a new service is created and added to the +// service collection this client provides to the server. +func (c *Client) RegisterName(name string, receiver interface{}) error { + return c.services.registerName(name, receiver) } func (c *Client) nextID() json.RawMessage { id := atomic.AddUint32(&c.idCounter, 1) - return []byte(strconv.FormatUint(uint64(id), 10)) + return strconv.AppendUint(nil, uint64(id), 10) } // SupportedModules calls the rpc_modules method, retrieving the list of @@ -227,8 +254,8 @@ func (c *Client) Close() { } select { case c.close <- struct{}{}: - <-c.didQuit - case <-c.didQuit: + <-c.didClose + case <-c.didClose: } } @@ -263,8 +290,8 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str return err } - // dispatch has accepted the request and will close the channel it when it quits. - switch resp, err := op.wait(ctx); { + // dispatch has accepted the request and will close the channel when it quits. + switch resp, err := op.wait(ctx, c); { case err != nil: return err case resp.Error != nil: @@ -276,40 +303,6 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str } } -// CallContext performs a JSON-RPC call with the given arguments. 
If the context is -// canceled before the call has successfully returned, CallContext returns immediately. -// -// The result must be a pointer so that package json can unmarshal into it. You -// can also pass nil, in which case the result is ignored. -func (c *Client) GetResultCallContext(ctx context.Context, result interface{}, method string, args ...interface{}) (json.RawMessage, error) { - msg, err := c.newMessage(method, args...) - if err != nil { - return nil, err - } - op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)} - - if c.isHTTP { - err = c.sendHTTP(ctx, op, msg) - } else { - err = c.send(ctx, op, msg) - } - if err != nil { - return nil, err - } - - // dispatch has accepted the request and will close the channel it when it quits. - switch resp, err := op.wait(ctx); { - case err != nil: - return nil, err - case resp.Error != nil: - return nil, resp.Error - case len(resp.Result) == 0: - return nil, ErrNoResult - default: - return resp.Result, json.Unmarshal(resp.Result, &result) - } -} - // BatchCall sends all given requests as a single batch and waits for the server // to return a response for all of them. // @@ -322,7 +315,7 @@ func (c *Client) BatchCall(b []BatchElem) error { return c.BatchCallContext(ctx, b) } -// BatchCall sends all given requests as a single batch and waits for the server +// BatchCallContext sends all given requests as a single batch and waits for the server // to return a response for all of them. The wait duration is bounded by the // context's deadline. // @@ -356,7 +349,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { // Wait for all responses to come back. for n := 0; n < len(b) && err == nil; n++ { var resp *jsonrpcMessage - resp, err = op.wait(ctx) + resp, err = op.wait(ctx, c) if err != nil { break } @@ -383,6 +376,22 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { return err } +// Notify sends a notification, i.e. a method call that doesn't expect a response. +func (c *Client) Notify(ctx context.Context, method string, args ...interface{}) error { + op := new(requestOp) + msg, err := c.newMessage(method, args...) + if err != nil { + return err + } + msg.ID = nil + + if c.isHTTP { + return c.sendHTTP(ctx, op, msg) + } else { + return c.send(ctx, op, msg) + } +} + // EthSubscribe registers a subscripion under the "eth" namespace. func (c *Client) EthSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) { return c.Subscribe(ctx, "eth", channel, args...) @@ -433,53 +442,48 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf if err := c.send(ctx, op, msg); err != nil { return nil, err } - if _, err := op.wait(ctx); err != nil { + if _, err := op.wait(ctx, c); err != nil { return nil, err } return op.sub, nil } func (c *Client) newMessage(method string, paramsIn ...interface{}) (*jsonrpcMessage, error) { - params, err := json.Marshal(paramsIn) - if err != nil { - return nil, err + msg := &jsonrpcMessage{Version: vsn, ID: c.nextID(), Method: method} + if paramsIn != nil { // prevent sending "params":null + var err error + if msg.Params, err = json.Marshal(paramsIn); err != nil { + return nil, err + } } - return &jsonrpcMessage{Version: "2.0", ID: c.nextID(), Method: method, Params: params}, nil + return msg, nil } // send registers op with the dispatch loop, then sends msg on the connection. // if sending fails, op is deregistered. 
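// The reqInit/reqSent pair acts as the connection write lock: dispatch
// grants the lock by receiving the op from reqInit and takes it back when
// the write result arrives on reqSent. A minimal sketch of the grant/release
// protocol, mirroring the dispatch loop further down in this patch:
//
//	case op := <-reqInitLock: // a caller acquired the write lock
//		reqInitLock = nil     // stop accepting new senders until released
//		lastOp = op
//	case err := <-c.reqSent:  // the caller finished writing; lock released
//		reqInitLock = c.reqInit // on error, lastOp is also deregistered
//		lastOp = nil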
func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error { select { - case c.requestOp <- op: - log.Trace("", "msg", log.Lazy{Fn: func() string { - return fmt.Sprint("sending ", msg) - }}) + case c.reqInit <- op: err := c.write(ctx, msg) - c.sendDone <- err + c.reqSent <- err return err case <-ctx.Done(): // This can happen if the client is overloaded or unable to keep up with // subscription notifications. return ctx.Err() - case <-c.didQuit: + case <-c.closing: return ErrClientQuit } } func (c *Client) write(ctx context.Context, msg interface{}) error { - deadline, ok := ctx.Deadline() - if !ok { - deadline = time.Now().Add(defaultWriteTimeout) - } // The previous write failed. Try to establish a new connection. if c.writeConn == nil { if err := c.reconnect(ctx); err != nil { return err } } - c.writeConn.SetWriteDeadline(deadline) - err := json.NewEncoder(c.writeConn).Encode(msg) + err := c.writeConn.Write(ctx, msg) if err != nil { c.writeConn = nil } @@ -487,16 +491,25 @@ func (c *Client) write(ctx context.Context, msg interface{}) error { } func (c *Client) reconnect(ctx context.Context) error { - newconn, err := c.connectFunc(ctx) + if c.reconnectFunc == nil { + return errDead + } + + if _, ok := ctx.Deadline(); !ok { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, defaultDialTimeout) + defer cancel() + } + newconn, err := c.reconnectFunc(ctx) if err != nil { - log.Trace(fmt.Sprintf("reconnect failed: %v", err)) + log.Trace("RPC client reconnect failed", "err", err) return err } select { case c.reconnected <- newconn: c.writeConn = newconn return nil - case <-c.didQuit: + case <-c.didClose: newconn.Close() return ErrClientQuit } @@ -505,321 +518,107 @@ func (c *Client) reconnect(ctx context.Context) error { // dispatch is the main loop of the client. // It sends read messages to waiting calls to Call and BatchCall // and subscription notifications to registered subscriptions. -func (c *Client) dispatch(conn net.Conn) { - // Spawn the initial read loop. - go c.read(conn) - +func (c *Client) dispatch(codec ServerCodec) { var ( - lastOp *requestOp // tracks last send operation - requestOpLock = c.requestOp // nil while the send lock is held - reading = true // if true, a read loop is running + lastOp *requestOp // tracks last send operation + reqInitLock = c.reqInit // nil while the send lock is held + conn = c.newClientConn(codec) + reading = true ) - defer close(c.didQuit) defer func() { - c.closeRequestOps(ErrClientQuit) - conn.Close() + close(c.closing) if reading { - // Empty read channels until read is dead. - for { - select { - case <-c.readResp: - case <-c.readErr: - return - } - } + conn.close(ErrClientQuit, nil) + c.drainRead() } + close(c.didClose) }() + // Spawn the initial read loop. + go c.read(codec) + for { select { case <-c.close: return - // Read path. 
- case batch := <-c.readResp: - for _, msg := range batch { - switch { - case msg.isNotification(): - log.Trace("", "msg", log.Lazy{Fn: func() string { - return fmt.Sprint("<-readResp: notification ", msg) - }}) - c.handleNotification(msg) - case msg.isResponse(): - log.Trace("", "msg", log.Lazy{Fn: func() string { - return fmt.Sprint("<-readResp: response ", msg) - }}) - c.handleResponse(msg) - default: - log.Debug("", "msg", log.Lazy{Fn: func() string { - return fmt.Sprint("<-readResp: dropping weird message", msg) - }}) - // TODO: maybe close - } + // Read path: + case op := <-c.readOp: + if op.batch { + conn.handler.handleBatch(op.msgs) + } else { + conn.handler.handleMsg(op.msgs[0]) } case err := <-c.readErr: - log.Debug(fmt.Sprintf("<-readErr: %v", err)) - c.closeRequestOps(err) - conn.Close() + conn.handler.log.Debug("RPC connection read error", "err", err) + conn.close(err, lastOp) reading = false - case newconn := <-c.reconnected: - log.Debug(fmt.Sprintf("<-reconnected: (reading=%t) %v", reading, conn.RemoteAddr())) + // Reconnect: + case newcodec := <-c.reconnected: + log.Debug("RPC client reconnected", "reading", reading, "conn", newcodec.RemoteAddr()) if reading { - // Wait for the previous read loop to exit. This is a rare case. - conn.Close() - <-c.readErr + // Wait for the previous read loop to exit. This is a rare case which + // happens if this loop isn't notified in time after the connection breaks. + // In those cases the caller will notice first and reconnect. Closing the + // handler terminates all waiting requests (closing op.resp) except for + // lastOp, which will be transferred to the new handler. + conn.close(errClientReconnected, lastOp) + c.drainRead() } - go c.read(newconn) + go c.read(newcodec) reading = true - conn = newconn - - // Send path. - case op := <-requestOpLock: - // Stop listening for further send ops until the current one is done. - requestOpLock = nil + conn = c.newClientConn(newcodec) + // Re-register the in-flight request on the new handler + // because that's where it will be sent. + conn.handler.addRequestOp(lastOp) + + // Send path: + case op := <-reqInitLock: + // Stop listening for further requests until the current one has been sent. + reqInitLock = nil lastOp = op - for _, id := range op.ids { - c.respWait[string(id)] = op - } + conn.handler.addRequestOp(op) - case err := <-c.sendDone: + case err := <-c.reqSent: if err != nil { - // Remove response handlers for the last send. We remove those here - // because the error is already handled in Call or BatchCall. When the - // read loop goes down, it will signal all other current operations. - for _, id := range lastOp.ids { - delete(c.respWait, string(id)) - } + // Remove response handlers for the last send. When the read loop + // goes down, it will signal all other current operations. + conn.handler.removeRequestOp(lastOp) } - // Listen for send ops again. - requestOpLock = c.requestOp + // Let the next request in. + reqInitLock = c.reqInit lastOp = nil - } - } -} -// closeRequestOps unblocks pending send ops and active subscriptions. -func (c *Client) closeRequestOps(err error) { - didClose := make(map[*requestOp]bool) - - for id, op := range c.respWait { - // Remove the op so that later calls will not close op.resp again. 
- delete(c.respWait, id) - - if !didClose[op] { - op.err = err - close(op.resp) - didClose[op] = true + case op := <-c.reqTimeout: + conn.handler.removeRequestOp(op) } } - for id, sub := range c.subs { - delete(c.subs, id) - sub.quitWithError(err, false) - } -} - -func (c *Client) handleNotification(msg *jsonrpcMessage) { - if !strings.HasSuffix(msg.Method, notificationMethodSuffix) { - log.Debug(fmt.Sprint("dropping non-subscription message: ", msg)) - return - } - var subResult struct { - ID string `json:"subscription"` - Result json.RawMessage `json:"result"` - } - if err := json.Unmarshal(msg.Params, &subResult); err != nil { - log.Debug(fmt.Sprint("dropping invalid subscription message: ", msg)) - return - } - if c.subs[subResult.ID] != nil { - c.subs[subResult.ID].deliver(subResult.Result) - } } -func (c *Client) handleResponse(msg *jsonrpcMessage) { - op := c.respWait[string(msg.ID)] - if op == nil { - log.Debug(fmt.Sprintf("unsolicited response %v", msg)) - return - } - delete(c.respWait, string(msg.ID)) - // For normal responses, just forward the reply to Call/BatchCall. - if op.sub == nil { - op.resp <- msg - return - } - // For subscription responses, start the subscription if the server - // indicates success. EthSubscribe gets unblocked in either case through - // the op.resp channel. - defer close(op.resp) - if msg.Error != nil { - op.err = msg.Error - return - } - if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil { - go op.sub.start() - c.subs[op.sub.subid] = op.sub - } -} - -// Reading happens on a dedicated goroutine. - -func (c *Client) read(conn net.Conn) error { - var ( - buf json.RawMessage - dec = json.NewDecoder(conn) - ) - readMessage := func() (rs []*jsonrpcMessage, err error) { - buf = buf[:0] - if err = dec.Decode(&buf); err != nil { - return nil, err - } - if isBatch(buf) { - err = json.Unmarshal(buf, &rs) - } else { - rs = make([]*jsonrpcMessage, 1) - err = json.Unmarshal(buf, &rs[0]) - } - return rs, err - } - +// drainRead drops read messages until an error occurs. +func (c *Client) drainRead() { for { - resp, err := readMessage() - if err != nil { - c.readErr <- err - return err - } - c.readResp <- resp - } -} - -// Subscriptions. - -// A ClientSubscription represents a subscription established through EthSubscribe. -type ClientSubscription struct { - client *Client - etype reflect.Type - channel reflect.Value - namespace string - subid string - in chan json.RawMessage - - quitOnce sync.Once // ensures quit is closed once - quit chan struct{} // quit is closed when the subscription exits - errOnce sync.Once // ensures err is closed once - err chan error -} - -func newClientSubscription(c *Client, namespace string, channel reflect.Value) *ClientSubscription { - sub := &ClientSubscription{ - client: c, - namespace: namespace, - etype: channel.Type().Elem(), - channel: channel, - quit: make(chan struct{}), - err: make(chan error, 1), - in: make(chan json.RawMessage), - } - return sub -} - -// Err returns the subscription error channel. The intended use of Err is to schedule -// resubscription when the client connection is closed unexpectedly. -// -// The error channel receives a value when the subscription has ended due -// to an error. The received error is nil if Close has been called -// on the underlying client and no other error has occurred. -// -// The error channel is closed when Unsubscribe is called on the subscription. 
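The Err and Unsubscribe semantics documented above carry over unchanged to the handler-based client. A hedged resubscription sketch against the public API; the namespace, subscription name and channel element type are illustrative assumptions, not part of this patch:

func subscribeWithRetry(ctx context.Context, c *Client) error {
	heads := make(chan json.RawMessage, 16) // element type assumed for illustration
	for {
		sub, err := c.Subscribe(ctx, "eth", heads, "newHeads")
		if err != nil {
			return err
		}
		select {
		case err := <-sub.Err():
			if err == nil {
				return nil // Client.Close was called; stop, per the semantics above
			}
			// The connection dropped: loop around and resubscribe.
		case <-ctx.Done():
			sub.Unsubscribe()
			return ctx.Err()
		}
	}
}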
-func (sub *ClientSubscription) Err() <-chan error { - return sub.err -} - -// Unsubscribe unsubscribes the notification and closes the error channel. -// It can safely be called more than once. -func (sub *ClientSubscription) Unsubscribe() { - sub.quitWithError(nil, true) - sub.errOnce.Do(func() { close(sub.err) }) -} - -func (sub *ClientSubscription) quitWithError(err error, unsubscribeServer bool) { - sub.quitOnce.Do(func() { - // The dispatch loop won't be able to execute the unsubscribe call - // if it is blocked on deliver. Close sub.quit first because it - // unblocks deliver. - close(sub.quit) - if unsubscribeServer { - sub.requestUnsubscribe() - } - if err != nil { - if err == ErrClientQuit { - err = nil // Adhere to subscription semantics. - } - sub.err <- err + select { + case <-c.readOp: + case <-c.readErr: + return } - }) -} - -func (sub *ClientSubscription) deliver(result json.RawMessage) (ok bool) { - select { - case sub.in <- result: - return true - case <-sub.quit: - return false } } -func (sub *ClientSubscription) start() { - sub.quitWithError(sub.forward()) -} - -func (sub *ClientSubscription) forward() (err error, unsubscribeServer bool) { - cases := []reflect.SelectCase{ - {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.quit)}, - {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.in)}, - {Dir: reflect.SelectSend, Chan: sub.channel}, - } - buffer := list.New() - defer buffer.Init() +// read decodes RPC messages from a codec, feeding them into dispatch. +func (c *Client) read(codec ServerCodec) { for { - var chosen int - var recv reflect.Value - if buffer.Len() == 0 { - // Idle, omit send case. - chosen, recv, _ = reflect.Select(cases[:2]) - } else { - // Non-empty buffer, send the first queued item. - cases[2].Send = reflect.ValueOf(buffer.Front().Value) - chosen, recv, _ = reflect.Select(cases) + msgs, batch, err := codec.Read() + if _, ok := err.(*json.SyntaxError); ok { + codec.Write(context.Background(), errorMessage(&parseError{err.Error()})) } - - switch chosen { - case 0: // <-sub.quit - return nil, false - case 1: // <-sub.in - val, err := sub.unmarshal(recv.Interface().(json.RawMessage)) - if err != nil { - return err, true - } - if buffer.Len() == maxClientSubscriptionBuffer { - return ErrSubscriptionQueueOverflow, true - } - buffer.PushBack(val) - case 2: // sub.channel<- - cases[2].Send = reflect.Value{} // Don't hold onto the value. - buffer.Remove(buffer.Front()) + if err != nil { + c.readErr <- err + return } + c.readOp <- readOp{msgs, batch} } } - -func (sub *ClientSubscription) unmarshal(result json.RawMessage) (interface{}, error) { - val := reflect.New(sub.etype) - err := json.Unmarshal(result, val.Interface()) - return val.Elem().Interface(), err -} - -func (sub *ClientSubscription) requestUnsubscribe() error { - var result interface{} - return sub.client.Call(&result, sub.namespace+unsubscribeMethodSuffix, sub.subid) -} diff --git a/rpc/endpoints.go b/rpc/endpoints.go new file mode 100644 index 0000000000..d91b00fea1 --- /dev/null +++ b/rpc/endpoints.go @@ -0,0 +1,101 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rpc + +import ( + "net" + + "github.com/tomochain/tomochain/log" +) + +// StartHTTPEndpoint starts the HTTP RPC endpoint, configured with cors/vhosts/modules +func StartHTTPEndpoint(endpoint string, apis []API, modules []string, cors []string, vhosts []string, timeouts HTTPTimeouts) (net.Listener, *Server, error) { + // Generate the whitelist based on the allowed modules + whitelist := make(map[string]bool) + for _, module := range modules { + whitelist[module] = true + } + // Register all the APIs exposed by the services + handler := NewServer() + for _, api := range apis { + if whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { + if err := handler.RegisterName(api.Namespace, api.Service); err != nil { + return nil, nil, err + } + log.Debug("HTTP registered", "namespace", api.Namespace) + } + } + // All APIs registered, start the HTTP listener + var ( + listener net.Listener + err error + ) + if listener, err = net.Listen("tcp", endpoint); err != nil { + return nil, nil, err + } + go NewHTTPServer(cors, vhosts, timeouts, handler).Serve(listener) + return listener, handler, err +} + +// StartWSEndpoint starts a websocket endpoint +func StartWSEndpoint(endpoint string, apis []API, modules []string, wsOrigins []string, exposeAll bool) (net.Listener, *Server, error) { + // Generate the whitelist based on the allowed modules + whitelist := make(map[string]bool) + for _, module := range modules { + whitelist[module] = true + } + // Register all the APIs exposed by the services + handler := NewServer() + for _, api := range apis { + if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { + if err := handler.RegisterName(api.Namespace, api.Service); err != nil { + return nil, nil, err + } + log.Debug("WebSocket registered", "service", api.Service, "namespace", api.Namespace) + } + } + // All APIs registered, start the HTTP listener + var ( + listener net.Listener + err error + ) + if listener, err = net.Listen("tcp", endpoint); err != nil { + return nil, nil, err + } + go NewWSServer(wsOrigins, handler).Serve(listener) + return listener, handler, err + +} + +// StartIPCEndpoint starts an IPC endpoint. +func StartIPCEndpoint(ipcEndpoint string, apis []API) (net.Listener, *Server, error) { + // Register all the APIs exposed by the services. + handler := NewServer() + for _, api := range apis { + if err := handler.RegisterName(api.Namespace, api.Service); err != nil { + return nil, nil, err + } + log.Debug("IPC registered", "namespace", api.Namespace) + } + // All APIs registered, start the IPC listener. 
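// A hedged sketch of wiring the three endpoint helpers from node startup
// code; the endpoints, module list and service value are illustrative
// assumptions, not part of this patch:
//
//	apis := []API{{Namespace: "web3", Version: "1.0", Service: svc, Public: true}}
//	httpListener, httpSrv, err1 := StartHTTPEndpoint("127.0.0.1:8545", apis, []string{"web3"}, nil, []string{"localhost"}, DefaultHTTPTimeouts)
//	wsListener, wsSrv, err2 := StartWSEndpoint("127.0.0.1:8546", apis, []string{"web3"}, []string{"*"}, false)
//	ipcListener, ipcSrv, err3 := StartIPCEndpoint("node.ipc", apis)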
+ listener, err := ipcListen(ipcEndpoint) + if err != nil { + return nil, nil, err + } + go handler.ServeListener(listener) + return listener, handler, nil +} diff --git a/rpc/errors.go b/rpc/errors.go index 10509a533b..6c8b8e5899 100644 --- a/rpc/errors.go +++ b/rpc/errors.go @@ -18,10 +18,14 @@ package rpc import "fmt" -// request is for an unknown service -type methodNotFoundError struct { - service string - method string +const defaultErrorCode = -32000 + +type methodNotFoundError struct{ method string } + +func (e *methodNotFoundError) ErrorCode() int { return -32601 } + +func (e *methodNotFoundError) Error() string { + return fmt.Sprintf("the method %s does not exist/is not available", e.method) } // A DataError contains some data in addition to the error message. @@ -30,12 +34,21 @@ type DataError interface { ErrorData() interface{} // returns the error data } -func (e *methodNotFoundError) ErrorCode() int { return -32601 } +type subscriptionNotFoundError struct{ namespace, subscription string } -func (e *methodNotFoundError) Error() string { - return fmt.Sprintf("The method %s%s%s does not exist/is not available", e.service, serviceMethodSeparator, e.method) +func (e *subscriptionNotFoundError) ErrorCode() int { return -32601 } + +func (e *subscriptionNotFoundError) Error() string { + return fmt.Sprintf("no %q subscription in %s namespace", e.subscription, e.namespace) } +// Invalid JSON was received by the server. +type parseError struct{ message string } + +func (e *parseError) ErrorCode() int { return -32700 } + +func (e *parseError) Error() string { return e.message } + // received message isn't a valid request type invalidRequestError struct{ message string } diff --git a/rpc/handler.go b/rpc/handler.go new file mode 100644 index 0000000000..3507dd8411 --- /dev/null +++ b/rpc/handler.go @@ -0,0 +1,405 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rpc + +import ( + "context" + "encoding/json" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/tomochain/tomochain/log" +) + +// handler handles JSON-RPC messages. There is one handler per connection. Note that +// handler is not safe for concurrent use. Message handling never blocks indefinitely +// because RPCs are processed on background goroutines launched by handler. +// +// The entry points for incoming messages are: +// +// h.handleMsg(message) +// h.handleBatch(message) +// +// Outgoing calls use the requestOp struct. Register the request before sending it +// on the connection: +// +// op := &requestOp{ids: ...} +// h.addRequestOp(op) +// +// Now send the request, then wait for the reply to be delivered through handleMsg: +// +// if err := op.wait(...); err != nil { +// h.removeRequestOp(op) // timeout, etc. 
+// } +type handler struct { + reg *serviceRegistry + unsubscribeCb *callback + idgen func() ID // subscription ID generator + respWait map[string]*requestOp // active client requests + clientSubs map[string]*ClientSubscription // active client subscriptions + callWG sync.WaitGroup // pending call goroutines + rootCtx context.Context // canceled by close() + cancelRoot func() // cancel function for rootCtx + conn jsonWriter // where responses will be sent + log log.Logger + allowSubscribe bool + + subLock sync.Mutex + serverSubs map[ID]*Subscription +} + +type callProc struct { + ctx context.Context + notifiers []*Notifier +} + +func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler { + rootCtx, cancelRoot := context.WithCancel(connCtx) + h := &handler{ + reg: reg, + idgen: idgen, + conn: conn, + respWait: make(map[string]*requestOp), + clientSubs: make(map[string]*ClientSubscription), + rootCtx: rootCtx, + cancelRoot: cancelRoot, + allowSubscribe: true, + serverSubs: make(map[ID]*Subscription), + log: log.Root(), + } + if conn.RemoteAddr() != "" { + h.log = h.log.New("conn", conn.RemoteAddr()) + } + h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe)) + return h +} + +// handleBatch executes all messages in a batch and returns the responses. +func (h *handler) handleBatch(msgs []*jsonrpcMessage) { + // Emit error response for empty batches: + if len(msgs) == 0 { + h.startCallProc(func(cp *callProc) { + h.conn.Write(cp.ctx, errorMessage(&invalidRequestError{"empty batch"})) + }) + return + } + + // Handle non-call messages first: + calls := make([]*jsonrpcMessage, 0, len(msgs)) + for _, msg := range msgs { + if handled := h.handleImmediate(msg); !handled { + calls = append(calls, msg) + } + } + if len(calls) == 0 { + return + } + // Process calls on a goroutine because they may block indefinitely: + h.startCallProc(func(cp *callProc) { + answers := make([]*jsonrpcMessage, 0, len(msgs)) + for _, msg := range calls { + if answer := h.handleCallMsg(cp, msg); answer != nil { + answers = append(answers, answer) + } + } + h.addSubscriptions(cp.notifiers) + if len(answers) > 0 { + h.conn.Write(cp.ctx, answers) + } + for _, n := range cp.notifiers { + n.activate() + } + }) +} + +// handleMsg handles a single message. +func (h *handler) handleMsg(msg *jsonrpcMessage) { + if ok := h.handleImmediate(msg); ok { + return + } + h.startCallProc(func(cp *callProc) { + answer := h.handleCallMsg(cp, msg) + h.addSubscriptions(cp.notifiers) + if answer != nil { + h.conn.Write(cp.ctx, answer) + } + for _, n := range cp.notifiers { + n.activate() + } + }) +} + +// close cancels all requests except for inflightReq and waits for +// call goroutines to shut down. +func (h *handler) close(err error, inflightReq *requestOp) { + h.cancelAllRequests(err, inflightReq) + h.cancelRoot() + h.callWG.Wait() + h.cancelServerSubscriptions(err) +} + +// addRequestOp registers a request operation. +func (h *handler) addRequestOp(op *requestOp) { + for _, id := range op.ids { + h.respWait[string(id)] = op + } +} + +// removeRequestOps stops waiting for the given request IDs. +func (h *handler) removeRequestOp(op *requestOp) { + for _, id := range op.ids { + delete(h.respWait, string(id)) + } +} + +// cancelAllRequests unblocks and removes pending requests and active subscriptions. 
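// A requestOp can be registered under several IDs (batch calls), so the
// didClose map below ensures op.resp is closed at most once: closing an
// already-closed channel panics. A minimal illustration of the hazard,
// hypothetical and not part of this patch:
//
//	ch := make(chan *jsonrpcMessage)
//	close(ch)
//	close(ch) // panic: close of closed channel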
+func (h *handler) cancelAllRequests(err error, inflightReq *requestOp) { + didClose := make(map[*requestOp]bool) + if inflightReq != nil { + didClose[inflightReq] = true + } + + for id, op := range h.respWait { + // Remove the op so that later calls will not close op.resp again. + delete(h.respWait, id) + + if !didClose[op] { + op.err = err + close(op.resp) + didClose[op] = true + } + } + for id, sub := range h.clientSubs { + delete(h.clientSubs, id) + sub.quitWithError(err, false) + } +} + +func (h *handler) addSubscriptions(nn []*Notifier) { + h.subLock.Lock() + defer h.subLock.Unlock() + + for _, n := range nn { + if sub := n.takeSubscription(); sub != nil { + h.serverSubs[sub.ID] = sub + } + } +} + +// cancelServerSubscriptions removes all subscriptions and closes their error channels. +func (h *handler) cancelServerSubscriptions(err error) { + h.subLock.Lock() + defer h.subLock.Unlock() + + for id, s := range h.serverSubs { + s.err <- err + close(s.err) + delete(h.serverSubs, id) + } +} + +// startCallProc runs fn in a new goroutine and starts tracking it in the h.calls wait group. +func (h *handler) startCallProc(fn func(*callProc)) { + h.callWG.Add(1) + go func() { + ctx, cancel := context.WithCancel(h.rootCtx) + defer h.callWG.Done() + defer cancel() + fn(&callProc{ctx: ctx}) + }() +} + +// handleImmediate executes non-call messages. It returns false if the message is a +// call or requires a reply. +func (h *handler) handleImmediate(msg *jsonrpcMessage) bool { + start := time.Now() + switch { + case msg.isNotification(): + if strings.HasSuffix(msg.Method, notificationMethodSuffix) { + h.handleSubscriptionResult(msg) + return true + } + return false + case msg.isResponse(): + h.handleResponse(msg) + h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "t", time.Since(start)) + return true + default: + return false + } +} + +// handleSubscriptionResult processes subscription notifications. +func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) { + var result subscriptionResult + if err := json.Unmarshal(msg.Params, &result); err != nil { + h.log.Debug("Dropping invalid subscription message") + return + } + if h.clientSubs[result.ID] != nil { + h.clientSubs[result.ID].deliver(result.Result) + } +} + +// handleResponse processes method call responses. +func (h *handler) handleResponse(msg *jsonrpcMessage) { + op := h.respWait[string(msg.ID)] + if op == nil { + h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID}) + return + } + delete(h.respWait, string(msg.ID)) + // For normal responses, just forward the reply to Call/BatchCall. + if op.sub == nil { + op.resp <- msg + return + } + // For subscription responses, start the subscription if the server + // indicates success. EthSubscribe gets unblocked in either case through + // the op.resp channel. + defer close(op.resp) + if msg.Error != nil { + op.err = msg.Error + return + } + if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil { + go op.sub.start() + h.clientSubs[op.sub.subid] = op.sub + } +} + +// handleCallMsg executes a call message and returns the answer. 
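// For reference, the three shapes distinguished below, as wire messages
// (abbreviated, illustrative):
//
//	{"jsonrpc":"2.0","method":"foo_bar","params":[]}        notification: executed, no reply
//	{"jsonrpc":"2.0","id":1,"method":"foo_bar","params":[]} call: executed, response returned
//	{"jsonrpc":"2.0","id":1}                                neither: answered with "invalid request"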
+func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage { + start := time.Now() + switch { + case msg.isNotification(): + h.handleCall(ctx, msg) + h.log.Debug("Served "+msg.Method, "duration", time.Since(start)) + return nil + + case msg.isCall(): + resp := h.handleCall(ctx, msg) + var ctx []interface{} + ctx = append(ctx, "reqid", idForLog{msg.ID}, "duration", time.Since(start)) + if resp.Error != nil { + ctx = append(ctx, "err", resp.Error.Message) + if resp.Error.Data != nil { + ctx = append(ctx, "errdata", resp.Error.Data) + } + h.log.Warn("Served "+msg.Method, ctx...) + } else { + h.log.Debug("Served "+msg.Method, ctx...) + } + return resp + + case msg.hasValidID(): + return msg.errorResponse(&invalidRequestError{"invalid request"}) + + default: + return errorMessage(&invalidRequestError{"invalid request"}) + } +} + +// handleCall processes method calls. +func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage { + if msg.isSubscribe() { + return h.handleSubscribe(cp, msg) + } + var callb *callback + if msg.isUnsubscribe() { + callb = h.unsubscribeCb + } else { + callb = h.reg.callback(msg.Method) + } + if callb == nil { + return msg.errorResponse(&methodNotFoundError{method: msg.Method}) + } + args, err := parsePositionalArguments(msg.Params, callb.argTypes) + if err != nil { + return msg.errorResponse(&invalidParamsError{err.Error()}) + } + + return h.runMethod(cp.ctx, msg, callb, args) +} + +// handleSubscribe processes *_subscribe method calls. +func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage { + if !h.allowSubscribe { + return msg.errorResponse(ErrNotificationsUnsupported) + } + + // Subscription method name is first argument. + name, err := parseSubscriptionName(msg.Params) + if err != nil { + return msg.errorResponse(&invalidParamsError{err.Error()}) + } + namespace := msg.namespace() + callb := h.reg.subscription(namespace, name) + if callb == nil { + return msg.errorResponse(&subscriptionNotFoundError{namespace, name}) + } + + // Parse subscription name arg too, but remove it before calling the callback. + argTypes := append([]reflect.Type{stringType}, callb.argTypes...) + args, err := parsePositionalArguments(msg.Params, argTypes) + if err != nil { + return msg.errorResponse(&invalidParamsError{err.Error()}) + } + args = args[1:] + + // Install notifier in context so the subscription handler can find it. + n := &Notifier{h: h, namespace: namespace} + cp.notifiers = append(cp.notifiers, n) + ctx := context.WithValue(cp.ctx, notifierKey{}, n) + + return h.runMethod(ctx, msg, callb, args) +} + +// runMethod runs the Go callback for an RPC method. +func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *callback, args []reflect.Value) *jsonrpcMessage { + result, err := callb.call(ctx, msg.Method, args) + if err != nil { + return msg.errorResponse(err) + } + return msg.response(result) +} + +// unsubscribe is the callback function for all *_unsubscribe calls. 
+func (h *handler) unsubscribe(ctx context.Context, id ID) (bool, error) { + h.subLock.Lock() + defer h.subLock.Unlock() + + s := h.serverSubs[id] + if s == nil { + return false, ErrSubscriptionNotFound + } + close(s.err) + delete(h.serverSubs, id) + return true, nil +} + +type idForLog struct{ json.RawMessage } + +func (id idForLog) String() string { + if s, err := strconv.Unquote(string(id.RawMessage)); err == nil { + return s + } + return string(id.RawMessage) +} diff --git a/rpc/http.go b/rpc/http.go index 32badac29c..640a8460c1 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -32,38 +32,77 @@ import ( "time" "github.com/rs/cors" + + "github.com/tomochain/tomochain/log" ) const ( + maxRequestContentLength = 1024 * 512 contentType = "application/json" - maxRequestContentLength = 1024 * 128 ) -var nullAddr, _ = net.ResolveTCPAddr("tcp", "127.0.0.1:0") +// https://www.jsonrpc.org/historical/json-rpc-over-http.html#id13 +var acceptedContentTypes = []string{contentType, "application/json-rpc", "application/jsonrequest"} type httpConn struct { client *http.Client req *http.Request closeOnce sync.Once - closed chan struct{} + closed chan interface{} } // httpConn is treated specially by Client. -func (hc *httpConn) LocalAddr() net.Addr { return nullAddr } -func (hc *httpConn) RemoteAddr() net.Addr { return nullAddr } -func (hc *httpConn) SetReadDeadline(time.Time) error { return nil } -func (hc *httpConn) SetWriteDeadline(time.Time) error { return nil } -func (hc *httpConn) SetDeadline(time.Time) error { return nil } -func (hc *httpConn) Write([]byte) (int, error) { panic("Write called") } - -func (hc *httpConn) Read(b []byte) (int, error) { +func (hc *httpConn) Write(context.Context, interface{}) error { + panic("Write called on httpConn") +} + +func (hc *httpConn) RemoteAddr() string { + return hc.req.URL.String() +} + +func (hc *httpConn) Read() ([]*jsonrpcMessage, bool, error) { <-hc.closed - return 0, io.EOF + return nil, false, io.EOF } -func (hc *httpConn) Close() error { +func (hc *httpConn) Close() { hc.closeOnce.Do(func() { close(hc.closed) }) - return nil +} + +func (hc *httpConn) Closed() <-chan interface{} { + return hc.closed +} + +// HTTPTimeouts represents the configuration params for the HTTP RPC server. +type HTTPTimeouts struct { + // ReadTimeout is the maximum duration for reading the entire + // request, including the body. + // + // Because ReadTimeout does not let Handlers make per-request + // decisions on each request body's acceptable deadline or + // upload rate, most users will prefer to use + // ReadHeaderTimeout. It is valid to use them both. + ReadTimeout time.Duration + + // WriteTimeout is the maximum duration before timing out + // writes of the response. It is reset whenever a new + // request's header is read. Like ReadTimeout, it does not + // let Handlers make decisions on a per-request basis. + WriteTimeout time.Duration + + // IdleTimeout is the maximum amount of time to wait for the + // next request when keep-alives are enabled. If IdleTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, ReadHeaderTimeout is used. + IdleTimeout time.Duration +} + +// DefaultHTTPTimeouts represents the default timeout values used if further +// configuration is not provided. 
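// A hedged configuration sketch; the values are illustrative, and anything
// below one second is replaced with these defaults by NewHTTPServer:
//
//	timeouts := HTTPTimeouts{
//		ReadTimeout:  60 * time.Second,
//		WriteTimeout: 60 * time.Second,
//		IdleTimeout:  2 * time.Minute,
//	}
//	server := NewHTTPServer(nil, nil, timeouts, rpcSrv) // rpcSrv: an http.Handler such as *Server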
+var DefaultHTTPTimeouts = HTTPTimeouts{ + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, } // DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP @@ -77,8 +116,8 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { req.Header.Set("Accept", contentType) initctx := context.Background() - return newClient(initctx, func(context.Context) (net.Conn, error) { - return &httpConn{client: client, req: req, closed: make(chan struct{})}, nil + return newClient(initctx, func(context.Context) (ServerCodec, error) { + return &httpConn{client: client, req: req, closed: make(chan interface{})}, nil }) } @@ -90,10 +129,19 @@ func DialHTTP(endpoint string) (*Client, error) { func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) error { hc := c.writeConn.(*httpConn) respBody, err := hc.doRequest(ctx, msg) + if respBody != nil { + defer respBody.Close() + } + if err != nil { + if respBody != nil { + buf := new(bytes.Buffer) + if _, err2 := buf.ReadFrom(respBody); err2 == nil { + return fmt.Errorf("%v %v", err, buf.String()) + } + } return err } - defer respBody.Close() var respmsg jsonrpcMessage if err := json.NewDecoder(respBody).Decode(&respmsg); err != nil { return err @@ -132,37 +180,68 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos if err != nil { return nil, err } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return resp.Body, errors.New(resp.Status) + } return resp.Body, nil } -// httpReadWriteNopCloser wraps a io.Reader and io.Writer with a NOP Close method. -type httpReadWriteNopCloser struct { +// httpServerConn turns a HTTP connection into a Conn. +type httpServerConn struct { io.Reader io.Writer + r *http.Request } -// Close does nothing and returns always nil -func (t *httpReadWriteNopCloser) Close() error { - return nil +func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec { + body := io.LimitReader(r.Body, maxRequestContentLength) + conn := &httpServerConn{Reader: body, Writer: w, r: r} + return NewJSONCodec(conn) } +// Close does nothing and always returns nil. +func (t *httpServerConn) Close() error { return nil } + +// RemoteAddr returns the peer address of the underlying connection. +func (t *httpServerConn) RemoteAddr() string { + return t.r.RemoteAddr +} + +// SetWriteDeadline does nothing and always returns nil. +func (t *httpServerConn) SetWriteDeadline(time.Time) error { return nil } + // NewHTTPServer creates a new HTTP RPC server around an API provider. 
//
// Deprecated: Server implements http.Handler
-func NewHTTPServer(cors []string, vhosts []string, srv *Server) *http.Server {
+func NewHTTPServer(cors []string, vhosts []string, timeouts HTTPTimeouts, srv http.Handler) *http.Server {
 // Wrap the CORS-handler within a host-handler
 handler := newCorsHandler(srv, cors)
 handler = newVHostHandler(vhosts, handler)
+
+ // Make sure timeout values are meaningful
+ if timeouts.ReadTimeout < time.Second {
+ log.Warn("Sanitizing invalid HTTP read timeout", "provided", timeouts.ReadTimeout, "updated", DefaultHTTPTimeouts.ReadTimeout)
+ timeouts.ReadTimeout = DefaultHTTPTimeouts.ReadTimeout
+ }
+ if timeouts.WriteTimeout < time.Second {
+ log.Warn("Sanitizing invalid HTTP write timeout", "provided", timeouts.WriteTimeout, "updated", DefaultHTTPTimeouts.WriteTimeout)
+ timeouts.WriteTimeout = DefaultHTTPTimeouts.WriteTimeout
+ }
+ if timeouts.IdleTimeout < time.Second {
+ log.Warn("Sanitizing invalid HTTP idle timeout", "provided", timeouts.IdleTimeout, "updated", DefaultHTTPTimeouts.IdleTimeout)
+ timeouts.IdleTimeout = DefaultHTTPTimeouts.IdleTimeout
+ }
+ // Bundle and start the HTTP server
 return &http.Server{
 Handler: handler,
- ReadTimeout: 5 * time.Second,
- WriteTimeout: 10 * time.Second,
- IdleTimeout: 120 * time.Second,
+ ReadTimeout: timeouts.ReadTimeout,
+ WriteTimeout: timeouts.WriteTimeout,
+ IdleTimeout: timeouts.IdleTimeout,
 }
}

// ServeHTTP serves JSON-RPC requests over HTTP.
-func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 // Permit dumb empty requests for remote health-checks (AWS)
 if r.Method == http.MethodGet && r.ContentLength == 0 && r.URL.RawQuery == "" {
 return
@@ -174,12 +253,21 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 // All checks passed, create a codec that reads direct from the request body
 // until EOF and writes the response to w and order the server to process a
 // single request.
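// The plain string context keys set below can be read back inside a service
// method; a hedged sketch, where the method and receiver are illustrative
// assumptions, not part of this patch:
//
//	func (s *DebugService) Caller(ctx context.Context) string {
//		remote, _ := ctx.Value("remote").(string)
//		ua, _ := ctx.Value("User-Agent").(string)
//		return remote + " " + ua
//	}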
- body := io.LimitReader(r.Body, maxRequestContentLength) - codec := NewJSONCodec(&httpReadWriteNopCloser{body, w}) - defer codec.Close() + ctx := r.Context() + ctx = context.WithValue(ctx, "remote", r.RemoteAddr) + ctx = context.WithValue(ctx, "scheme", r.Proto) + ctx = context.WithValue(ctx, "local", r.Host) + if ua := r.Header.Get("User-Agent"); ua != "" { + ctx = context.WithValue(ctx, "User-Agent", ua) + } + if origin := r.Header.Get("Origin"); origin != "" { + ctx = context.WithValue(ctx, "Origin", origin) + } w.Header().Set("content-type", contentType) - srv.ServeSingleRequest(codec, OptionMethodInvocation) + codec := newHTTPServerConn(r, w) + defer codec.Close() + s.serveSingleRequest(ctx, codec) } // validateRequest returns a non-zero response code and error message if the @@ -192,15 +280,24 @@ func validateRequest(r *http.Request) (int, error) { err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxRequestContentLength) return http.StatusRequestEntityTooLarge, err } - mt, _, err := mime.ParseMediaType(r.Header.Get("content-type")) - if r.Method != http.MethodOptions && (err != nil || mt != contentType) { - err := fmt.Errorf("invalid content type, only %s is supported", contentType) - return http.StatusUnsupportedMediaType, err + // Allow OPTIONS (regardless of content-type) + if r.Method == http.MethodOptions { + return 0, nil + } + // Check content-type + if mt, _, err := mime.ParseMediaType(r.Header.Get("content-type")); err == nil { + for _, accepted := range acceptedContentTypes { + if accepted == mt { + return 0, nil + } + } } - return 0, nil + // Invalid content-type + err := fmt.Errorf("invalid content type, only %s is supported", contentType) + return http.StatusUnsupportedMediaType, err } -func newCorsHandler(srv *Server, allowedOrigins []string) http.Handler { +func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler { // disable CORS support if user has not specified a custom CORS configuration if len(allowedOrigins) == 0 { return srv diff --git a/rpc/inproc.go b/rpc/inproc.go index 595a7ca651..c4456cfc4b 100644 --- a/rpc/inproc.go +++ b/rpc/inproc.go @@ -21,13 +21,13 @@ import ( "net" ) -// NewInProcClient attaches an in-process connection to the given RPC server. +// DialInProc attaches an in-process connection to the given RPC server. func DialInProc(handler *Server) *Client { initctx := context.Background() - c, _ := newClient(initctx, func(context.Context) (net.Conn, error) { + c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) { p1, p2 := net.Pipe() go handler.ServeCodec(NewJSONCodec(p1), OptionMethodInvocation|OptionSubscriptions) - return p2, nil + return NewJSONCodec(p2), nil }) return c } diff --git a/rpc/ipc.go b/rpc/ipc.go index 89f02d6bcd..b17e021cf4 100644 --- a/rpc/ipc.go +++ b/rpc/ipc.go @@ -18,27 +18,24 @@ package rpc import ( "context" - "fmt" "net" "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/p2p/netutil" ) -// CreateIPCListener creates an listener, on Unix platforms this is a unix socket, on -// Windows this is a named pipe -func CreateIPCListener(endpoint string) (net.Listener, error) { - return ipcListen(endpoint) -} - // ServeListener accepts connections on l, serving JSON-RPC on them. 
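// A hedged usage sketch; the socket path is illustrative, and ipcListen
// (used by StartIPCEndpoint) normally performs this setup:
//
//	l, err := net.Listen("unix", "/tmp/node.ipc")
//	if err != nil {
//		return err
//	}
//	go srv.ServeListener(l)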
-func (srv *Server) ServeListener(l net.Listener) error {
+func (s *Server) ServeListener(l net.Listener) error {
 for {
 conn, err := l.Accept()
- if err != nil {
+ if netutil.IsTemporaryError(err) {
+ log.Warn("RPC accept error", "err", err)
+ continue
+ } else if err != nil {
 return err
 }
- log.Trace(fmt.Sprint("accepted conn", conn.RemoteAddr()))
- go srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions)
+ log.Trace("Accepted RPC connection", "conn", conn.RemoteAddr())
+ go s.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions)
 }
}

@@ -49,7 +46,11 @@ func (srv *Server) ServeListener(l net.Listener) error {
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
 func DialIPC(ctx context.Context, endpoint string) (*Client, error) {
- return newClient(ctx, func(ctx context.Context) (net.Conn, error) {
- return newIPCConnection(ctx, endpoint)
+ return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
+ conn, err := newIPCConnection(ctx, endpoint)
+ if err != nil {
+ return nil, err
+ }
+ return NewJSONCodec(conn), err
 })
}
diff --git a/rpc/json.go b/rpc/json.go
index e35a74118a..34c825c025 100644
--- a/rpc/json.go
+++ b/rpc/json.go
@@ -18,71 +18,114 @@ package rpc

 import (
 "bytes"
+ "context"
 "encoding/json"
+ "errors"
 "fmt"
 "io"
 "reflect"
- "strconv"
 "strings"
 "sync"
-
- "github.com/tomochain/tomochain/log"
+ "time"
)

const (
- jsonrpcVersion = "2.0"
+ vsn = "2.0"
 serviceMethodSeparator = "_"
 subscribeMethodSuffix = "_subscribe"
 unsubscribeMethodSuffix = "_unsubscribe"
 notificationMethodSuffix = "_subscription"
+
+ defaultWriteTimeout = 10 * time.Second // used if context has no deadline
)

-type jsonRequest struct {
- Method string `json:"method"`
- Version string `json:"jsonrpc"`
- Id json.RawMessage `json:"id,omitempty"`
- Payload json.RawMessage `json:"params,omitempty"`
+var null = json.RawMessage("null")
+
+type subscriptionResult struct {
+ ID string `json:"subscription"`
+ Result json.RawMessage `json:"result,omitempty"`
}

-type jsonSuccessResponse struct {
- Version string `json:"jsonrpc"`
- Id interface{} `json:"id,omitempty"`
- Result interface{} `json:"result"`
+// A value of this type can hold a JSON-RPC request, notification, successful response or
+// error response. Which one it is depends on the fields.
+type jsonrpcMessage struct { + Version string `json:"jsonrpc,omitempty"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Error *jsonError `json:"error,omitempty"` + Result json.RawMessage `json:"result,omitempty"` } -type jsonError struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data,omitempty"` +func (msg *jsonrpcMessage) isNotification() bool { + return msg.hasValidVersion() && msg.ID == nil && msg.Method != "" } -type jsonErrResponse struct { - Version string `json:"jsonrpc"` - Id interface{} `json:"id,omitempty"` - Error jsonError `json:"error"` +func (msg *jsonrpcMessage) isCall() bool { + return msg.hasValidVersion() && msg.hasValidID() && msg.Method != "" } -type jsonSubscription struct { - Subscription string `json:"subscription"` - Result interface{} `json:"result,omitempty"` +func (msg *jsonrpcMessage) isResponse() bool { + return msg.hasValidVersion() && msg.hasValidID() && msg.Method == "" && msg.Params == nil && (msg.Result != nil || msg.Error != nil) } -type jsonNotification struct { - Version string `json:"jsonrpc"` - Method string `json:"method"` - Params jsonSubscription `json:"params"` +func (msg *jsonrpcMessage) hasValidID() bool { + return len(msg.ID) > 0 && msg.ID[0] != '{' && msg.ID[0] != '[' } -// jsonCodec reads and writes JSON-RPC messages to the underlying connection. It -// also has support for parsing arguments and serializing (result) objects. -type jsonCodec struct { - closer sync.Once // close closed channel once - closed chan interface{} // closed on Close - decMu sync.Mutex // guards the decoder - decode func(v interface{}) error // decoder to allow multiple transports - encMu sync.Mutex // guards the encoder - encode func(v interface{}) error // encoder to allow multiple transports - rw io.ReadWriteCloser // connection +func (msg *jsonrpcMessage) hasValidVersion() bool { + return msg.Version == vsn +} + +func (msg *jsonrpcMessage) isSubscribe() bool { + return strings.HasSuffix(msg.Method, subscribeMethodSuffix) +} + +func (msg *jsonrpcMessage) isUnsubscribe() bool { + return strings.HasSuffix(msg.Method, unsubscribeMethodSuffix) +} + +func (msg *jsonrpcMessage) namespace() string { + elem := strings.SplitN(msg.Method, serviceMethodSeparator, 2) + return elem[0] +} + +func (msg *jsonrpcMessage) String() string { + b, _ := json.Marshal(msg) + return string(b) +} + +func (msg *jsonrpcMessage) errorResponse(err error) *jsonrpcMessage { + resp := errorMessage(err) + resp.ID = msg.ID + return resp +} + +func (msg *jsonrpcMessage) response(result interface{}) *jsonrpcMessage { + enc, err := json.Marshal(result) + if err != nil { + // TODO: wrap with 'internal server error' + return msg.errorResponse(err) + } + return &jsonrpcMessage{Version: vsn, ID: msg.ID, Result: enc} +} + +func errorMessage(err error) *jsonrpcMessage { + msg := &jsonrpcMessage{Version: vsn, ID: null, Error: &jsonError{ + Code: defaultErrorCode, + Message: err.Error(), + }} + ec, ok := err.(Error) + if ok { + msg.Error.Code = ec.ErrorCode() + } + return msg +} + +type jsonError struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` } func (err *jsonError) Error() string { @@ -96,280 +139,196 @@ func (err *jsonError) ErrorCode() int { return err.Code } -func (err *jsonError) ErrorData() interface{} { - return err.Data +// Conn is a subset of the methods of net.Conn which are sufficient for ServerCodec. 
+type Conn interface { + io.ReadWriteCloser + SetWriteDeadline(time.Time) error +} + +// ConnRemoteAddr wraps the RemoteAddr operation, which returns a description +// of the peer address of a connection. If a Conn also implements ConnRemoteAddr, this +// description is used in log messages. +type ConnRemoteAddr interface { + RemoteAddr() string +} + +// connWithRemoteAddr overrides the remote address of a connection. +type connWithRemoteAddr struct { + Conn + addr string +} + +func (c connWithRemoteAddr) RemoteAddr() string { return c.addr } + +// jsonCodec reads and writes JSON-RPC messages to the underlying connection. It also has +// support for parsing arguments and serializing (result) objects. +type jsonCodec struct { + remoteAddr string + closer sync.Once // close closed channel once + closed chan interface{} // closed on Close + decode func(v interface{}) error // decoder to allow multiple transports + encMu sync.Mutex // guards the encoder + encode func(v interface{}) error // encoder to allow multiple transports + conn Conn } // NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based // on explicitly given encoding and decoding methods. -func NewCodec(rwc io.ReadWriteCloser, encode, decode func(v interface{}) error) ServerCodec { - return &jsonCodec{ +func NewCodec(conn Conn, encode, decode func(v interface{}) error) ServerCodec { + codec := &jsonCodec{ closed: make(chan interface{}), encode: encode, decode: decode, - rw: rwc, + conn: conn, } + if ra, ok := conn.(ConnRemoteAddr); ok { + codec.remoteAddr = ra.RemoteAddr() + } + return codec } // NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0. -func NewJSONCodec(rwc io.ReadWriteCloser) ServerCodec { - enc := json.NewEncoder(rwc) - dec := json.NewDecoder(rwc) +func NewJSONCodec(conn Conn) ServerCodec { + enc := json.NewEncoder(conn) + dec := json.NewDecoder(conn) dec.UseNumber() + return NewCodec(conn, enc.Encode, dec.Decode) +} - return &jsonCodec{ - closed: make(chan interface{}), - encode: enc.Encode, - decode: dec.Decode, - rw: rwc, - } +func (c *jsonCodec) RemoteAddr() string { + return c.remoteAddr } -// isBatch returns true when the first non-whitespace characters is '[' -func isBatch(msg json.RawMessage) bool { - for _, c := range msg { - // skip insignificant whitespace (http://www.ietf.org/rfc/rfc4627.txt) - if c == 0x20 || c == 0x09 || c == 0x0a || c == 0x0d { - continue - } - return c == '[' +func (c *jsonCodec) Read() (msg []*jsonrpcMessage, batch bool, err error) { + // Decode the next JSON object in the input stream. + // This verifies basic syntax, etc. + var rawmsg json.RawMessage + if err := c.decode(&rawmsg); err != nil { + return nil, false, err } - return false + msg, batch = parseMessage(rawmsg) + return msg, batch, nil } -// ReadRequestHeaders will read new requests without parsing the arguments. It will -// return a collection of requests, an indication if these requests are in batch -// form or an error when the incoming message could not be read/parsed. -func (c *jsonCodec) ReadRequestHeaders() ([]rpcRequest, bool, Error) { - c.decMu.Lock() - defer c.decMu.Unlock() +// Write sends a message to client. 
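// When the context carries no deadline, Write falls back to
// defaultWriteTimeout (10s). The caller therefore chooses between:
//
//	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
//	defer cancel()
//	codec.Write(ctx, msg)                  // write deadline: now + 2s
//	codec.Write(context.Background(), msg) // write deadline: now + 10s (default)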
+func (c *jsonCodec) Write(ctx context.Context, v interface{}) error { + c.encMu.Lock() + defer c.encMu.Unlock() - var incomingMsg json.RawMessage - if err := c.decode(&incomingMsg); err != nil { - return nil, false, &invalidRequestError{err.Error()} - } - if isBatch(incomingMsg) { - return parseBatchRequest(incomingMsg) + deadline, ok := ctx.Deadline() + if !ok { + deadline = time.Now().Add(defaultWriteTimeout) } - return parseRequest(incomingMsg) + c.conn.SetWriteDeadline(deadline) + return c.encode(v) } -// checkReqId returns an error when the given reqId isn't valid for RPC method calls. -// valid id's are strings, numbers or null -func checkReqId(reqId json.RawMessage) error { - if len(reqId) == 0 { - return fmt.Errorf("missing request id") - } - if _, err := strconv.ParseFloat(string(reqId), 64); err == nil { - return nil - } - var str string - if err := json.Unmarshal(reqId, &str); err == nil { - return nil - } - return fmt.Errorf("invalid request id") +// Close the underlying connection +func (c *jsonCodec) Close() { + c.closer.Do(func() { + close(c.closed) + c.conn.Close() + }) } -// parseRequest will parse a single request from the given RawMessage. It will return -// the parsed request, an indication if the request was a batch or an error when -// the request could not be parsed. -func parseRequest(incomingMsg json.RawMessage) ([]rpcRequest, bool, Error) { - var in jsonRequest - if err := json.Unmarshal(incomingMsg, &in); err != nil { - return nil, false, &invalidMessageError{err.Error()} - } - - if err := checkReqId(in.Id); err != nil { - return nil, false, &invalidMessageError{err.Error()} - } - - // subscribe are special, they will always use `subscribeMethod` as first param in the payload - if strings.HasSuffix(in.Method, subscribeMethodSuffix) { - reqs := []rpcRequest{{id: &in.Id, isPubSub: true}} - if len(in.Payload) > 0 { - // first param must be subscription name - var subscribeMethod [1]string - if err := json.Unmarshal(in.Payload, &subscribeMethod); err != nil { - log.Debug(fmt.Sprintf("Unable to parse subscription method: %v\n", err)) - return nil, false, &invalidRequestError{"Unable to parse subscription request"} - } - - reqs[0].service, reqs[0].method = strings.TrimSuffix(in.Method, subscribeMethodSuffix), subscribeMethod[0] - reqs[0].params = in.Payload - return reqs, false, nil - } - return nil, false, &invalidRequestError{"Unable to parse subscription request"} - } - - if strings.HasSuffix(in.Method, unsubscribeMethodSuffix) { - return []rpcRequest{{id: &in.Id, isPubSub: true, - method: in.Method, params: in.Payload}}, false, nil - } +// Closed returns a channel which will be closed when Close is called +func (c *jsonCodec) Closed() <-chan interface{} { + return c.closed +} - elems := strings.Split(in.Method, serviceMethodSeparator) - if len(elems) != 2 { - return nil, false, &methodNotFoundError{in.Method, ""} +// parseMessage parses raw bytes as a (batch of) JSON-RPC message(s). There are no error +// checks in this function because the raw message has already been syntax-checked when it +// is called. Any non-JSON-RPC messages in the input return the zero value of +// jsonrpcMessage. 
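// Illustrative inputs and results:
//
//	{"jsonrpc":"2.0","id":1,"method":"web3_clientVersion"}     -> one message, batch == false
//	[{"jsonrpc":"2.0","id":1,"method":"web3_clientVersion"},
//	 {"jsonrpc":"2.0","id":2,"method":"net_version"}]          -> two messages, batch == true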
+func parseMessage(raw json.RawMessage) ([]*jsonrpcMessage, bool) { + if !isBatch(raw) { + msgs := []*jsonrpcMessage{{}} + json.Unmarshal(raw, &msgs[0]) + return msgs, false } - - // regular RPC call - if len(in.Payload) == 0 { - return []rpcRequest{{service: elems[0], method: elems[1], id: &in.Id}}, false, nil + dec := json.NewDecoder(bytes.NewReader(raw)) + dec.Token() // skip '[' + var msgs []*jsonrpcMessage + for dec.More() { + msgs = append(msgs, new(jsonrpcMessage)) + dec.Decode(&msgs[len(msgs)-1]) } - - return []rpcRequest{{service: elems[0], method: elems[1], id: &in.Id, params: in.Payload}}, false, nil + return msgs, true } -// parseBatchRequest will parse a batch request into a collection of requests from the given RawMessage, an indication -// if the request was a batch or an error when the request could not be read. -func parseBatchRequest(incomingMsg json.RawMessage) ([]rpcRequest, bool, Error) { - var in []jsonRequest - if err := json.Unmarshal(incomingMsg, &in); err != nil { - return nil, false, &invalidMessageError{err.Error()} - } - - requests := make([]rpcRequest, len(in)) - for i, r := range in { - if err := checkReqId(r.Id); err != nil { - return nil, false, &invalidMessageError{err.Error()} - } - - id := &in[i].Id - - // subscribe are special, they will always use `subscriptionMethod` as first param in the payload - if strings.HasSuffix(r.Method, subscribeMethodSuffix) { - requests[i] = rpcRequest{id: id, isPubSub: true} - if len(r.Payload) > 0 { - // first param must be subscription name - var subscribeMethod [1]string - if err := json.Unmarshal(r.Payload, &subscribeMethod); err != nil { - log.Debug(fmt.Sprintf("Unable to parse subscription method: %v\n", err)) - return nil, false, &invalidRequestError{"Unable to parse subscription request"} - } - - requests[i].service, requests[i].method = strings.TrimSuffix(r.Method, subscribeMethodSuffix), subscribeMethod[0] - requests[i].params = r.Payload - continue - } - - return nil, true, &invalidRequestError{"Unable to parse (un)subscribe request arguments"} - } - - if strings.HasSuffix(r.Method, unsubscribeMethodSuffix) { - requests[i] = rpcRequest{id: id, isPubSub: true, method: r.Method, params: r.Payload} +// isBatch returns true when the first non-whitespace characters is '[' +func isBatch(raw json.RawMessage) bool { + for _, c := range raw { + // skip insignificant whitespace (http://www.ietf.org/rfc/rfc4627.txt) + if c == 0x20 || c == 0x09 || c == 0x0a || c == 0x0d { continue } - - if len(r.Payload) == 0 { - requests[i] = rpcRequest{id: id, params: nil} - } else { - requests[i] = rpcRequest{id: id, params: r.Payload} - } - if elem := strings.Split(r.Method, serviceMethodSeparator); len(elem) == 2 { - requests[i].service, requests[i].method = elem[0], elem[1] - } else { - requests[i].err = &methodNotFoundError{r.Method, ""} - } - } - - return requests, true, nil -} - -// ParseRequestArguments tries to parse the given params (json.RawMessage) with the given -// types. It returns the parsed values or an error when the parsing failed. -func (c *jsonCodec) ParseRequestArguments(argTypes []reflect.Type, params interface{}) ([]reflect.Value, Error) { - if args, ok := params.(json.RawMessage); !ok { - return nil, &invalidParamsError{"Invalid params supplied"} - } else { - return parsePositionalArguments(args, argTypes) + return c == '[' } + return false } // parsePositionalArguments tries to parse the given args to an array of values with the // given types. 
It returns the parsed values or an error when the args could not be // parsed. Missing optional arguments are returned as reflect.Zero values. -func parsePositionalArguments(rawArgs json.RawMessage, types []reflect.Type) ([]reflect.Value, Error) { - // Read beginning of the args array. +func parsePositionalArguments(rawArgs json.RawMessage, types []reflect.Type) ([]reflect.Value, error) { dec := json.NewDecoder(bytes.NewReader(rawArgs)) - if tok, _ := dec.Token(); tok != json.Delim('[') { - return nil, &invalidParamsError{"non-array args"} + var args []reflect.Value + tok, err := dec.Token() + switch { + case err == io.EOF || tok == nil && err == nil: + // "params" is optional and may be empty. Also allow "params":null even though it's + // not in the spec because our own client used to send it. + case err != nil: + return nil, err + case tok == json.Delim('['): + // Read argument array. + if args, err = parseArgumentArray(dec, types); err != nil { + return nil, err + } + default: + return nil, errors.New("non-array args") } - // Read args. + // Set any missing args to nil. + for i := len(args); i < len(types); i++ { + if types[i].Kind() != reflect.Ptr { + return nil, fmt.Errorf("missing value for required argument %d", i) + } + args = append(args, reflect.Zero(types[i])) + } + return args, nil +} + +func parseArgumentArray(dec *json.Decoder, types []reflect.Type) ([]reflect.Value, error) { args := make([]reflect.Value, 0, len(types)) for i := 0; dec.More(); i++ { if i >= len(types) { - return nil, &invalidParamsError{fmt.Sprintf("too many arguments, want at most %d", len(types))} + return args, fmt.Errorf("too many arguments, want at most %d", len(types)) } argval := reflect.New(types[i]) if err := dec.Decode(argval.Interface()); err != nil { - return nil, &invalidParamsError{fmt.Sprintf("invalid argument %d: %v", i, err)} + return args, fmt.Errorf("invalid argument %d: %v", i, err) } if argval.IsNil() && types[i].Kind() != reflect.Ptr { - return nil, &invalidParamsError{fmt.Sprintf("missing value for required argument %d", i)} + return args, fmt.Errorf("missing value for required argument %d", i) } args = append(args, argval.Elem()) } // Read end of args array. - if _, err := dec.Token(); err != nil { - return nil, &invalidParamsError{err.Error()} - } - // Set any missing args to nil. - for i := len(args); i < len(types); i++ { - if types[i].Kind() != reflect.Ptr { - return nil, &invalidParamsError{fmt.Sprintf("missing value for required argument %d", i)} - } - args = append(args, reflect.Zero(types[i])) - } - return args, nil + _, err := dec.Token() + return args, err } -// CreateResponse will create a JSON-RPC success response with the given id and reply as result. -func (c *jsonCodec) CreateResponse(id interface{}, reply interface{}) interface{} { - if isHexNum(reflect.TypeOf(reply)) { - return &jsonSuccessResponse{Version: jsonrpcVersion, Id: id, Result: fmt.Sprintf(`%#x`, reply)} +// parseSubscriptionName extracts the subscription name from an encoded argument array. +func parseSubscriptionName(rawArgs json.RawMessage) (string, error) { + dec := json.NewDecoder(bytes.NewReader(rawArgs)) + if tok, _ := dec.Token(); tok != json.Delim('[') { + return "", errors.New("non-array args") } - return &jsonSuccessResponse{Version: jsonrpcVersion, Id: id, Result: reply} -} - -// CreateErrorResponse will create a JSON-RPC error response with the given id and error. 
-func (c *jsonCodec) CreateErrorResponse(id interface{}, err Error) interface{} { - return &jsonErrResponse{Version: jsonrpcVersion, Id: id, Error: jsonError{Code: err.ErrorCode(), Message: err.Error()}} -} - -// CreateErrorResponseWithInfo will create a JSON-RPC error response with the given id and error. -// info is optional and contains additional information about the error. When an empty string is passed it is ignored. -func (c *jsonCodec) CreateErrorResponseWithInfo(id interface{}, err Error, info interface{}) interface{} { - return &jsonErrResponse{Version: jsonrpcVersion, Id: id, - Error: jsonError{Code: err.ErrorCode(), Message: err.Error(), Data: info}} -} - -// CreateNotification will create a JSON-RPC notification with the given subscription id and event as params. -func (c *jsonCodec) CreateNotification(subid, namespace string, event interface{}) interface{} { - if isHexNum(reflect.TypeOf(event)) { - return &jsonNotification{Version: jsonrpcVersion, Method: namespace + notificationMethodSuffix, - Params: jsonSubscription{Subscription: subid, Result: fmt.Sprintf(`%#x`, event)}} + v, _ := dec.Token() + method, ok := v.(string) + if !ok { + return "", errors.New("expected subscription name as first argument") } - - return &jsonNotification{Version: jsonrpcVersion, Method: namespace + notificationMethodSuffix, - Params: jsonSubscription{Subscription: subid, Result: event}} -} - -// Write message to client -func (c *jsonCodec) Write(res interface{}) error { - c.encMu.Lock() - defer c.encMu.Unlock() - - return c.encode(res) -} - -// Close the underlying connection -func (c *jsonCodec) Close() { - c.closer.Do(func() { - close(c.closed) - c.rw.Close() - }) -} - -// Closed returns a channel which will be closed when Close is called -func (c *jsonCodec) Closed() <-chan interface{} { - return c.closed + return method, nil } diff --git a/rpc/json_test.go b/rpc/json_test.go deleted file mode 100644 index 5048d2f7a0..0000000000 --- a/rpc/json_test.go +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . 
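A minimal sketch of the new argument-parsing behavior (not part of the patch; parsePositionalArguments is unexported, so this would live inside package rpc): trailing pointer-typed parameters are optional and come back as zero values, while omitting a non-pointer parameter is an error.

func exampleOptionalParams() {
	// One required int parameter and one optional *int parameter.
	types := []reflect.Type{reflect.TypeOf(0), reflect.TypeOf((*int)(nil))}

	// The second argument is omitted; it is returned as a nil *int.
	args, err := parsePositionalArguments(json.RawMessage(`[7]`), types)
	fmt.Println(err, args[0].Int(), args[1].IsNil()) // <nil> 7 true

	// Omitting a required (non-pointer) argument fails.
	_, err = parsePositionalArguments(json.RawMessage(`[]`), types[:1])
	fmt.Println(err) // missing value for required argument 0
}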
- -package rpc - -import ( - "bufio" - "bytes" - "encoding/json" - "reflect" - "strconv" - "testing" -) - -type RWC struct { - *bufio.ReadWriter -} - -func (rwc *RWC) Close() error { - return nil -} - -func TestJSONRequestParsing(t *testing.T) { - server := NewServer() - service := new(Service) - - if err := server.RegisterName("calc", service); err != nil { - t.Fatalf("%v", err) - } - - req := bytes.NewBufferString(`{"id": 1234, "jsonrpc": "2.0", "method": "calc_add", "params": [11, 22]}`) - var str string - reply := bytes.NewBufferString(str) - rw := &RWC{bufio.NewReadWriter(bufio.NewReader(req), bufio.NewWriter(reply))} - - codec := NewJSONCodec(rw) - - requests, batch, err := codec.ReadRequestHeaders() - if err != nil { - t.Fatalf("%v", err) - } - - if batch { - t.Fatalf("Request isn't a batch") - } - - if len(requests) != 1 { - t.Fatalf("Expected 1 request but got %d requests - %v", len(requests), requests) - } - - if requests[0].service != "calc" { - t.Fatalf("Expected service 'calc' but got '%s'", requests[0].service) - } - - if requests[0].method != "add" { - t.Fatalf("Expected method 'Add' but got '%s'", requests[0].method) - } - - if rawId, ok := requests[0].id.(*json.RawMessage); ok { - id, e := strconv.ParseInt(string(*rawId), 0, 64) - if e != nil { - t.Fatalf("%v", e) - } - if id != 1234 { - t.Fatalf("Expected id 1234 but got %d", id) - } - } else { - t.Fatalf("invalid request, expected *json.RawMesage got %T", requests[0].id) - } - - var arg int - args := []reflect.Type{reflect.TypeOf(arg), reflect.TypeOf(arg)} - - v, err := codec.ParseRequestArguments(args, requests[0].params) - if err != nil { - t.Fatalf("%v", err) - } - - if len(v) != 2 { - t.Fatalf("Expected 2 argument values, got %d", len(v)) - } - - if v[0].Int() != 11 || v[1].Int() != 22 { - t.Fatalf("expected %d == 11 && %d == 22", v[0].Int(), v[1].Int()) - } -} - -func TestJSONRequestParamsParsing(t *testing.T) { - - var ( - stringT = reflect.TypeOf("") - intT = reflect.TypeOf(0) - intPtrT = reflect.TypeOf(new(int)) - - stringV = reflect.ValueOf("abc") - i = 1 - intV = reflect.ValueOf(i) - intPtrV = reflect.ValueOf(&i) - ) - - var validTests = []struct { - input string - argTypes []reflect.Type - expected []reflect.Value - }{ - {`[]`, []reflect.Type{}, []reflect.Value{}}, - {`[]`, []reflect.Type{intPtrT}, []reflect.Value{intPtrV}}, - {`[1]`, []reflect.Type{intT}, []reflect.Value{intV}}, - {`[1,"abc"]`, []reflect.Type{intT, stringT}, []reflect.Value{intV, stringV}}, - {`[null]`, []reflect.Type{intPtrT}, []reflect.Value{intPtrV}}, - {`[null,"abc"]`, []reflect.Type{intPtrT, stringT, intPtrT}, []reflect.Value{intPtrV, stringV, intPtrV}}, - {`[null,"abc",null]`, []reflect.Type{intPtrT, stringT, intPtrT}, []reflect.Value{intPtrV, stringV, intPtrV}}, - } - - codec := jsonCodec{} - - for _, test := range validTests { - params := (json.RawMessage)([]byte(test.input)) - args, err := codec.ParseRequestArguments(test.argTypes, params) - - if err != nil { - t.Fatal(err) - } - - var match []interface{} - json.Unmarshal([]byte(test.input), &match) - - if len(args) != len(test.argTypes) { - t.Fatalf("expected %d parsed args, got %d", len(test.argTypes), len(args)) - } - - for i, arg := range args { - expected := test.expected[i] - - if arg.Kind() != expected.Kind() { - t.Errorf("expected type for param %d in %s", i, test.input) - } - - if arg.Kind() == reflect.Int && arg.Int() != expected.Int() { - t.Errorf("expected int(%d), got int(%d) in %s", expected.Int(), arg.Int(), test.input) - } - - if arg.Kind() == reflect.String && 
arg.String() != expected.String() { - t.Errorf("expected string(%s), got string(%s) in %s", expected.String(), arg.String(), test.input) - } - } - } - - var invalidTests = []struct { - input string - argTypes []reflect.Type - }{ - {`[]`, []reflect.Type{intT}}, - {`[null]`, []reflect.Type{intT}}, - {`[1]`, []reflect.Type{stringT}}, - {`[1,2]`, []reflect.Type{stringT}}, - {`["abc", null]`, []reflect.Type{stringT, intT}}, - } - - for i, test := range invalidTests { - if _, err := codec.ParseRequestArguments(test.argTypes, test.input); err == nil { - t.Errorf("expected test %d - %s to fail", i, test.input) - } - } -} diff --git a/rpc/server.go b/rpc/server.go index acaca96d55..e8eca78564 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -18,20 +18,19 @@ package rpc import ( "context" - "fmt" - "reflect" - "runtime" - "strings" - "sync" + "io" "sync/atomic" mapset "github.com/deckarep/golang-set" + "github.com/tomochain/tomochain/log" ) const MetadataApi = "rpc" -// CodecOption specifies which type of messages this codec supports +// CodecOption specifies which type of messages a codec supports. +// +// Deprecated: this option is no longer honored by Server. type CodecOption int const ( @@ -42,196 +41,87 @@ const ( OptionSubscriptions = 1 << iota // support pub sub ) -// NewServer will create a new server instance with no registered handlers. -func NewServer() *Server { - server := &Server{ - services: make(serviceRegistry), - codecs: mapset.NewSet(), - run: 1, - } +// Server is an RPC server. +type Server struct { + services serviceRegistry + idgen func() ID + run int32 + codecs mapset.Set +} - // register a default service which will provide meta information about the RPC service such as the services and - // methods it offers. +// NewServer creates a new server instance with no registered handlers. +func NewServer() *Server { + server := &Server{idgen: randomIDGenerator(), codecs: mapset.NewSet(), run: 1} + // Register the default service providing meta information about the RPC service such + // as the services and methods it offers. rpcService := &RPCService{server} server.RegisterName(MetadataApi, rpcService) - return server } -// RPCService gives meta information about the server. -// e.g. gives information about the loaded modules. -type RPCService struct { - server *Server -} - -// Modules returns the list of RPC services with their version number -func (s *RPCService) Modules() map[string]string { - modules := make(map[string]string) - for name := range s.server.services { - modules[name] = "1.0" - } - return modules +// RegisterName creates a service for the given receiver type under the given name. When no +// methods on the given receiver match the criteria to be either a RPC method or a +// subscription an error is returned. Otherwise a new service is created and added to the +// service collection this server provides to clients. +func (s *Server) RegisterName(name string, receiver interface{}) error { + return s.services.registerName(name, receiver) } -// RegisterName will create a service for the given rcvr type under the given name. When no methods on the given rcvr -// match the criteria to be either a RPC method or a subscription an error is returned. Otherwise a new service is -// created and added to the service collection this server instance serves. 
-func (s *Server) RegisterName(name string, rcvr interface{}) error { - if s.services == nil { - s.services = make(serviceRegistry) - } - - svc := new(service) - svc.typ = reflect.TypeOf(rcvr) - rcvrVal := reflect.ValueOf(rcvr) - - if name == "" { - return fmt.Errorf("no service name for type %s", svc.typ.String()) - } - if !isExported(reflect.Indirect(rcvrVal).Type().Name()) { - return fmt.Errorf("%s is not exported", reflect.Indirect(rcvrVal).Type().Name()) - } - - methods, subscriptions := suitableCallbacks(rcvrVal, svc.typ) +// ServeCodec reads incoming requests from codec, calls the appropriate callback and writes +// the response back using the given codec. It will block until the codec is closed or the +// server is stopped. In either case the codec is closed. +// +// Note that codec options are no longer supported. +func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { + defer codec.Close() - // already a previous service register under given sname, merge methods/subscriptions - if regsvc, present := s.services[name]; present { - if len(methods) == 0 && len(subscriptions) == 0 { - return fmt.Errorf("Service %T doesn't have any suitable methods/subscriptions to expose", rcvr) - } - for _, m := range methods { - regsvc.callbacks[formatName(m.method.Name)] = m - } - for _, s := range subscriptions { - regsvc.subscriptions[formatName(s.method.Name)] = s - } - return nil + // Don't serve if server is stopped. + if atomic.LoadInt32(&s.run) == 0 { + return } - svc.name = name - svc.callbacks, svc.subscriptions = methods, subscriptions - - if len(svc.callbacks) == 0 && len(svc.subscriptions) == 0 { - return fmt.Errorf("Service %T doesn't have any suitable methods/subscriptions to expose", rcvr) - } + // Add the codec to the set so it can be closed by Stop. + s.codecs.Add(codec) + defer s.codecs.Remove(codec) - s.services[svc.name] = svc - return nil + c := initClient(codec, s.idgen, &s.services) + <-codec.Closed() + c.Close() } -// serveRequest will reads requests from the codec, calls the RPC callback and -// writes the response to the given codec. -// -// If singleShot is true it will process a single request, otherwise it will handle -// requests until the codec returns an error when reading a request (in most cases -// an EOF). It executes requests in parallel when singleShot is false. -func (s *Server) serveRequest(codec ServerCodec, singleShot bool, options CodecOption) error { - var pend sync.WaitGroup - - defer func() { - if err := recover(); err != nil { - const size = 64 << 10 - buf := make([]byte, size) - buf = buf[:runtime.Stack(buf, false)] - log.Error(fmt.Sprintf("RPC serveRequest %s\n", string(buf))) - } - s.codecsMu.Lock() - s.codecs.Remove(codec) - s.codecsMu.Unlock() - }() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // if the codec supports notification include a notifier that callbacks can use - // to send notification to clients. It is thight to the codec/connection. If the - // connection is closed the notifier will stop and cancels all active subscriptions. - if options&OptionSubscriptions == OptionSubscriptions { - ctx = context.WithValue(ctx, notifierKey{}, newNotifier(codec)) - } - s.codecsMu.Lock() - if atomic.LoadInt32(&s.run) != 1 { // server stopped - s.codecsMu.Unlock() - return &shutdownError{} +// serveSingleRequest reads and processes a single RPC request from the given codec. This +// is used to serve HTTP connections. Subscriptions and reverse calls are not allowed in +// this mode. 
+func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { + // Don't serve if server is stopped. + if atomic.LoadInt32(&s.run) == 0 { + return } - s.codecs.Add(codec) - s.codecsMu.Unlock() - // test if the server is ordered to stop - for atomic.LoadInt32(&s.run) == 1 { - reqs, batch, err := s.readRequest(codec) - if err != nil { - // If a parsing error occurred, send an error - if err.Error() != "EOF" { - log.Debug(fmt.Sprintf("read error %v\n", err)) - codec.Write(codec.CreateErrorResponse(nil, err)) - } - // Error or end of stream, wait for requests and tear down - pend.Wait() - return nil - } + h := newHandler(ctx, codec, s.idgen, &s.services) + h.allowSubscribe = false + defer h.close(io.EOF, nil) - // check if server is ordered to shutdown and return an error - // telling the client that his request failed. - if atomic.LoadInt32(&s.run) != 1 { - err = &shutdownError{} - if batch { - resps := make([]interface{}, len(reqs)) - for i, r := range reqs { - resps[i] = codec.CreateErrorResponse(&r.id, err) - } - codec.Write(resps) - } else { - codec.Write(codec.CreateErrorResponse(&reqs[0].id, err)) - } - return nil - } - // If a single shot request is executing, run and return immediately - if singleShot { - if batch { - s.execBatch(ctx, codec, reqs) - } else { - s.exec(ctx, codec, reqs[0]) - } - return nil + reqs, batch, err := codec.Read() + if err != nil { + if err != io.EOF { + codec.Write(ctx, errorMessage(&invalidMessageError{"parse error"})) } - // For multi-shot connections, start a goroutine to serve and loop back - pend.Add(1) - - go func(reqs []*serverRequest, batch bool) { - defer pend.Done() - if batch { - s.execBatch(ctx, codec, reqs) - } else { - s.exec(ctx, codec, reqs[0]) - } - }(reqs, batch) + return + } + if batch { + h.handleBatch(reqs) + } else { + h.handleMsg(reqs[0]) } - return nil -} - -// ServeCodec reads incoming requests from codec, calls the appropriate callback and writes the -// response back using the given codec. It will block until the codec is closed or the server is -// stopped. In either case the codec is closed. -func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { - defer codec.Close() - s.serveRequest(codec, false, options) -} - -// ServeSingleRequest reads and processes a single RPC request from the given codec. It will not -// close the codec unless a non-recoverable error has occurred. Note, this method will return after -// a single request has been processed! -func (s *Server) ServeSingleRequest(codec ServerCodec, options CodecOption) { - s.serveRequest(codec, true, options) } -// Stop will stop reading new requests, wait for stopPendingRequestTimeout to allow pending requests to finish, -// close all codecs which will cancel pending requests/subscriptions. +// Stop stops reading new requests, waits for stopPendingRequestTimeout to allow pending +// requests to finish, then closes all codecs which will cancel pending requests and +// subscriptions. func (s *Server) Stop() { if atomic.CompareAndSwapInt32(&s.run, 1, 0) { - log.Debug("RPC Server shutdown initiatied") - s.codecsMu.Lock() - defer s.codecsMu.Unlock() + log.Debug("RPC server shutting down") s.codecs.Each(func(c interface{}) bool { c.(ServerCodec).Close() return true @@ -239,207 +129,20 @@ func (s *Server) Stop() { } } -// createSubscription will call the subscription callback and returns the subscription id or error. 
-func (s *Server) createSubscription(ctx context.Context, c ServerCodec, req *serverRequest) (ID, error) { - // subscription have as first argument the context following optional arguments - args := []reflect.Value{req.callb.rcvr, reflect.ValueOf(ctx)} - args = append(args, req.args...) - reply := req.callb.method.Func.Call(args) - - if !reply[1].IsNil() { // subscription creation failed - return "", reply[1].Interface().(error) - } - - return reply[0].Interface().(*Subscription).ID, nil -} - -// handle executes a request and returns the response from the callback. -func (s *Server) handle(ctx context.Context, codec ServerCodec, req *serverRequest) (interface{}, func()) { - if req.err != nil { - return codec.CreateErrorResponse(&req.id, req.err), nil - } - - if req.isUnsubscribe { // cancel subscription, first param must be the subscription id - if len(req.args) >= 1 && req.args[0].Kind() == reflect.String { - notifier, supported := NotifierFromContext(ctx) - if !supported { // interface doesn't support subscriptions (e.g. http) - return codec.CreateErrorResponse(&req.id, &callbackError{ErrNotificationsUnsupported.Error()}), nil - } - - subid := ID(req.args[0].String()) - if err := notifier.unsubscribe(subid); err != nil { - return codec.CreateErrorResponse(&req.id, &callbackError{err.Error()}), nil - } - - return codec.CreateResponse(req.id, true), nil - } - return codec.CreateErrorResponse(&req.id, &invalidParamsError{"Expected subscription id as first argument"}), nil - } - - if req.callb.isSubscribe { - subid, err := s.createSubscription(ctx, codec, req) - if err != nil { - return codec.CreateErrorResponse(&req.id, &callbackError{err.Error()}), nil - } - - // active the subscription after the sub id was successfully sent to the client - activateSub := func() { - notifier, _ := NotifierFromContext(ctx) - notifier.activate(subid, req.svcname) - } - - return codec.CreateResponse(req.id, subid), activateSub - } - - // regular RPC call, prepare arguments - if len(req.args) != len(req.callb.argTypes) { - rpcErr := &invalidParamsError{fmt.Sprintf("%s%s%s expects %d parameters, got %d", - req.svcname, serviceMethodSeparator, req.callb.method.Name, - len(req.callb.argTypes), len(req.args))} - return codec.CreateErrorResponse(&req.id, rpcErr), nil - } - - arguments := []reflect.Value{req.callb.rcvr} - if req.callb.hasCtx { - arguments = append(arguments, reflect.ValueOf(ctx)) - } - if len(req.args) > 0 { - arguments = append(arguments, req.args...) - } - - // execute RPC method and return result - reply := req.callb.method.Func.Call(arguments) - if len(reply) == 0 { - return codec.CreateResponse(req.id, nil), nil - } - - if req.callb.errPos >= 0 { // test if method returned an error - if !reply[req.callb.errPos].IsNil() { - e := reply[req.callb.errPos].Interface().(error) - res := codec.CreateErrorResponse(&req.id, &callbackError{e.Error()}) - return res, nil - } - } - return codec.CreateResponse(req.id, reply[0].Interface()), nil -} - -// exec executes the given request and writes the result back using the codec. 
-func (s *Server) exec(ctx context.Context, codec ServerCodec, req *serverRequest) { - var response interface{} - var callback func() - if req.err != nil { - response = codec.CreateErrorResponse(&req.id, req.err) - } else { - response, callback = s.handle(ctx, codec, req) - } - - if err := codec.Write(response); err != nil { - log.Error(fmt.Sprintf("RPC exec %v\n", err)) - codec.Close() - } - - // when request was a subscribe request this allows these subscriptions to be actived - if callback != nil { - callback() - } -} - -// execBatch executes the given requests and writes the result back using the codec. -// It will only write the response back when the last request is processed. -func (s *Server) execBatch(ctx context.Context, codec ServerCodec, requests []*serverRequest) { - responses := make([]interface{}, len(requests)) - var callbacks []func() - for i, req := range requests { - if req.err != nil { - responses[i] = codec.CreateErrorResponse(&req.id, req.err) - } else { - var callback func() - if responses[i], callback = s.handle(ctx, codec, req); callback != nil { - callbacks = append(callbacks, callback) - } - } - } - - if err := codec.Write(responses); err != nil { - log.Error(fmt.Sprintf("RPC execBacth %v\n", err)) - codec.Close() - } - - // when request holds one of more subscribe requests this allows these subscriptions to be activated - for _, c := range callbacks { - c() - } +// RPCService gives meta information about the server. +// e.g. gives information about the loaded modules. +type RPCService struct { + server *Server } -// readRequest requests the next (batch) request from the codec. It will return the collection -// of requests, an indication if the request was a batch, the invalid request identifier and an -// error when the request could not be read/parsed. -func (s *Server) readRequest(codec ServerCodec) ([]*serverRequest, bool, Error) { - reqs, batch, err := codec.ReadRequestHeaders() - if err != nil { - return nil, batch, err - } - - requests := make([]*serverRequest, len(reqs)) - - // verify requests - for i, r := range reqs { - var ok bool - var svc *service - - if r.err != nil { - requests[i] = &serverRequest{id: r.id, err: r.err} - continue - } - - if r.isPubSub && strings.HasSuffix(r.method, unsubscribeMethodSuffix) { - requests[i] = &serverRequest{id: r.id, isUnsubscribe: true} - argTypes := []reflect.Type{reflect.TypeOf("")} // expect subscription id as first arg - if args, err := codec.ParseRequestArguments(argTypes, r.params); err == nil { - requests[i].args = args - } else { - requests[i].err = &invalidParamsError{err.Error()} - } - continue - } - - if svc, ok = s.services[r.service]; !ok { // rpc method isn't available - requests[i] = &serverRequest{id: r.id, err: &methodNotFoundError{r.service, r.method}} - continue - } - - if r.isPubSub { // eth_subscribe, r.method contains the subscription method name - if callb, ok := svc.subscriptions[r.method]; ok { - requests[i] = &serverRequest{id: r.id, svcname: svc.name, callb: callb} - if r.params != nil && len(callb.argTypes) > 0 { - argTypes := []reflect.Type{reflect.TypeOf("")} - argTypes = append(argTypes, callb.argTypes...) 
- if args, err := codec.ParseRequestArguments(argTypes, r.params); err == nil { - requests[i].args = args[1:] // first one is service.method name which isn't an actual argument - } else { - requests[i].err = &invalidParamsError{err.Error()} - } - } - } else { - requests[i] = &serverRequest{id: r.id, err: &methodNotFoundError{r.service, r.method}} - } - continue - } - - if callb, ok := svc.callbacks[r.method]; ok { // lookup RPC method - requests[i] = &serverRequest{id: r.id, svcname: svc.name, callb: callb} - if r.params != nil && len(callb.argTypes) > 0 { - if args, err := codec.ParseRequestArguments(callb.argTypes, r.params); err == nil { - requests[i].args = args - } else { - requests[i].err = &invalidParamsError{err.Error()} - } - } - continue - } +// Modules returns the list of RPC services with their version number +func (s *RPCService) Modules() map[string]string { + s.server.services.mu.Lock() + defer s.server.services.mu.Unlock() - requests[i] = &serverRequest{id: r.id, err: &methodNotFoundError{r.service, r.method}} + modules := make(map[string]string) + for name := range s.server.services.services { + modules[name] = "1.0" } - - return requests, batch, nil + return modules } diff --git a/rpc/service.go b/rpc/service.go new file mode 100644 index 0000000000..86b5eca468 --- /dev/null +++ b/rpc/service.go @@ -0,0 +1,285 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rpc + +import ( + "context" + "errors" + "fmt" + "reflect" + "runtime" + "strings" + "sync" + "unicode" + "unicode/utf8" + + "github.com/tomochain/tomochain/log" +) + +var ( + contextType = reflect.TypeOf((*context.Context)(nil)).Elem() + errorType = reflect.TypeOf((*error)(nil)).Elem() + subscriptionType = reflect.TypeOf(Subscription{}) + stringType = reflect.TypeOf("") +) + +type serviceRegistry struct { + mu sync.Mutex + services map[string]service +} + +// service represents a registered object. 
+type service struct {
+	name          string               // name for service
+	callbacks     map[string]*callback // registered handlers
+	subscriptions map[string]*callback // available subscriptions/notifications
+}
+
+// callback is a method callback which was registered in the server
+type callback struct {
+	fn          reflect.Value  // the function
+	rcvr        reflect.Value  // receiver object of method, set if fn is method
+	argTypes    []reflect.Type // input argument types
+	hasCtx      bool           // method's first argument is a context (not included in argTypes)
+	errPos      int            // err return idx, or -1 when method cannot return error
+	isSubscribe bool           // true if this is a subscription callback
+}
+
+func (r *serviceRegistry) registerName(name string, rcvr interface{}) error {
+	rcvrVal := reflect.ValueOf(rcvr)
+	if name == "" {
+		return fmt.Errorf("no service name for type %s", rcvrVal.Type().String())
+	}
+	callbacks := suitableCallbacks(rcvrVal)
+	if len(callbacks) == 0 {
+		return fmt.Errorf("service %T doesn't have any suitable methods/subscriptions to expose", rcvr)
+	}
+
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.services == nil {
+		r.services = make(map[string]service)
+	}
+	svc, ok := r.services[name]
+	if !ok {
+		svc = service{
+			name:          name,
+			callbacks:     make(map[string]*callback),
+			subscriptions: make(map[string]*callback),
+		}
+		r.services[name] = svc
+	}
+	for name, cb := range callbacks {
+		if cb.isSubscribe {
+			svc.subscriptions[name] = cb
+		} else {
+			svc.callbacks[name] = cb
+		}
+	}
+	return nil
+}
+
+// callback returns the callback corresponding to the given RPC method name.
+func (r *serviceRegistry) callback(method string) *callback {
+	elem := strings.SplitN(method, serviceMethodSeparator, 2)
+	if len(elem) != 2 {
+		return nil
+	}
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	return r.services[elem[0]].callbacks[elem[1]]
+}
+
+// subscription returns a subscription callback in the given service.
+func (r *serviceRegistry) subscription(service, name string) *callback {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	return r.services[service].subscriptions[name]
+}
+
+// suitableCallbacks iterates over the methods of the given type. It determines if a method
+// satisfies the criteria for an RPC callback or a subscription callback and adds it to the
+// collection of callbacks. See server documentation for a summary of these criteria.
+func suitableCallbacks(receiver reflect.Value) map[string]*callback {
+	typ := receiver.Type()
+	callbacks := make(map[string]*callback)
+	for m := 0; m < typ.NumMethod(); m++ {
+		method := typ.Method(m)
+		if method.PkgPath != "" {
+			continue // method not exported
+		}
+		cb := newCallback(receiver, method.Func)
+		if cb == nil {
+			continue // function invalid
+		}
+		name := formatName(method.Name)
+		callbacks[name] = cb
+	}
+	return callbacks
+}
+
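A minimal sketch, not part of the patch, of how registration composes under this scheme; CalcService and ClockService are hypothetical receivers whose exported methods pass newCallback's checks. Repeated registerName calls under one name merge their callbacks into the same service entry rather than replacing it.

func exampleRegister(r *serviceRegistry) error {
	if err := r.registerName("util", new(CalcService)); err != nil {
		return err
	}
	// Registering a second receiver under "util" adds ClockService's
	// callbacks and subscriptions to the existing "util" service.
	return r.registerName("util", new(ClockService))
}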
+// newCallback turns fn (a function) into a callback object. It returns nil if the function
+// is unsuitable as an RPC callback.
+func newCallback(receiver, fn reflect.Value) *callback {
+	fntype := fn.Type()
+	c := &callback{fn: fn, rcvr: receiver, errPos: -1, isSubscribe: isPubSub(fntype)}
+	// Determine parameter types. They must all be exported or builtin types.
+	c.makeArgTypes()
+	if !allExportedOrBuiltin(c.argTypes) {
+		return nil
+	}
+	// Verify return types. The function must return at most one error
+	// and/or one other non-error value.
+	outs := make([]reflect.Type, fntype.NumOut())
+	for i := 0; i < fntype.NumOut(); i++ {
+		outs[i] = fntype.Out(i)
+	}
+	if len(outs) > 2 || !allExportedOrBuiltin(outs) {
+		return nil
+	}
+	// If an error is returned, it must be the last returned value.
+	switch {
+	case len(outs) == 1 && isErrorType(outs[0]):
+		c.errPos = 0
+	case len(outs) == 2:
+		if isErrorType(outs[0]) || !isErrorType(outs[1]) {
+			return nil
+		}
+		c.errPos = 1
+	}
+	return c
+}
+
+// makeArgTypes composes the argTypes list.
+func (c *callback) makeArgTypes() {
+	fntype := c.fn.Type()
+	// Skip receiver and context.Context parameter (if present).
+	firstArg := 0
+	if c.rcvr.IsValid() {
+		firstArg++
+	}
+	if fntype.NumIn() > firstArg && fntype.In(firstArg) == contextType {
+		c.hasCtx = true
+		firstArg++
+	}
+	// Add all remaining parameters.
+	c.argTypes = make([]reflect.Type, fntype.NumIn()-firstArg)
+	for i := firstArg; i < fntype.NumIn(); i++ {
+		c.argTypes[i-firstArg] = fntype.In(i)
+	}
+}
+
+// call invokes the callback.
+func (c *callback) call(ctx context.Context, method string, args []reflect.Value) (res interface{}, errRes error) {
+	// Create the argument slice.
+	fullargs := make([]reflect.Value, 0, 2+len(args))
+	if c.rcvr.IsValid() {
+		fullargs = append(fullargs, c.rcvr)
+	}
+	if c.hasCtx {
+		fullargs = append(fullargs, reflect.ValueOf(ctx))
+	}
+	fullargs = append(fullargs, args...)
+
+	// Catch panic while running the callback.
+	defer func() {
+		if err := recover(); err != nil {
+			const size = 64 << 10
+			buf := make([]byte, size)
+			buf = buf[:runtime.Stack(buf, false)]
+			log.Error("RPC method " + method + " crashed: " + fmt.Sprintf("%v\n%s", err, buf))
+			errRes = errors.New("method handler crashed")
+		}
+	}()
+	// Run the callback.
+	results := c.fn.Call(fullargs)
+	if len(results) == 0 {
+		return nil, nil
+	}
+	if c.errPos >= 0 && !results[c.errPos].IsNil() {
+		// Method has returned non-nil error value.
+		err := results[c.errPos].Interface().(error)
+		return reflect.Value{}, err
+	}
+	return results[0].Interface(), nil
+}
+
+// Is this an exported - upper case - name?
+func isExported(name string) bool {
+	rune, _ := utf8.DecodeRuneInString(name)
+	return unicode.IsUpper(rune)
+}
+
+// Are all those types exported or built-in?
+func allExportedOrBuiltin(types []reflect.Type) bool {
+	for _, typ := range types {
+		for typ.Kind() == reflect.Ptr {
+			typ = typ.Elem()
+		}
+		// PkgPath will be non-empty even for an exported type,
+		// so we need to check the type name as well.
+		if !isExported(typ.Name()) && typ.PkgPath() != "" {
+			return false
+		}
+	}
+	return true
+}
+
+// Is t context.Context or *context.Context?
+func isContextType(t reflect.Type) bool {
+	for t.Kind() == reflect.Ptr {
+		t = t.Elem()
+	}
+	return t == contextType
+}
+
+// Does t satisfy the error interface?
+func isErrorType(t reflect.Type) bool {
+	for t.Kind() == reflect.Ptr {
+		t = t.Elem()
+	}
+	return t.Implements(errorType)
+}
+
+// Is t Subscription or *Subscription?
+func isSubscriptionType(t reflect.Type) bool {
+	for t.Kind() == reflect.Ptr {
+		t = t.Elem()
+	}
+	return t == subscriptionType
+}
+
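To make the criteria above concrete, a minimal sketch (not part of the patch) of method shapes and how newCallback classifies them; EchoService is a hypothetical receiver.

type EchoService struct{}

// Accepted as a plain RPC method: optional leading context.Context, exported
// or builtin parameter types, and at most one value plus one trailing error.
func (s *EchoService) Add(ctx context.Context, a, b int) (int, error) {
	return a + b, nil
}

// Accepted as a subscription: the first parameter is a context.Context and
// the results are exactly (*Subscription, error), which is what isPubSub checks.
func (s *EchoService) Ticks(ctx context.Context) (*Subscription, error) {
	notifier, ok := NotifierFromContext(ctx)
	if !ok {
		return nil, errors.New("notifications not supported")
	}
	return notifier.CreateSubscription(), nil
}

// Rejected (newCallback returns nil): the error is not the last result.
func (s *EchoService) Bad() (error, int) {
	return nil, 0
}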
+// isPubSub tests whether the given method has a context.Context as first argument and
+// returns the pair (Subscription, error).
+func isPubSub(methodType reflect.Type) bool {
+	// numIn(0) is the receiver type
+	if methodType.NumIn() < 2 || methodType.NumOut() != 2 {
+		return false
+	}
+	return isContextType(methodType.In(1)) &&
+		isSubscriptionType(methodType.Out(0)) &&
+		isErrorType(methodType.Out(1))
+}
+
+// formatName converts the first character of name to lowercase.
+func formatName(name string) string {
+	ret := []rune(name)
+	if len(ret) > 0 {
+		ret[0] = unicode.ToLower(ret[0])
+	}
+	return string(ret)
+}
diff --git a/rpc/stdio.go b/rpc/stdio.go
new file mode 100644
index 0000000000..8f6b7bd4bf
--- /dev/null
+++ b/rpc/stdio.go
@@ -0,0 +1,54 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package rpc
+
+import (
+	"context"
+	"errors"
+	"net"
+	"os"
+	"time"
+)
+
+// DialStdIO creates a client on stdin/stdout.
+func DialStdIO(ctx context.Context) (*Client, error) {
+	return newClient(ctx, func(_ context.Context) (ServerCodec, error) {
+		return NewJSONCodec(stdioConn{}), nil
+	})
+}
+
+type stdioConn struct{}
+
+func (io stdioConn) Read(b []byte) (n int, err error) {
+	return os.Stdin.Read(b)
+}
+
+func (io stdioConn) Write(b []byte) (n int, err error) {
+	return os.Stdout.Write(b)
+}
+
+func (io stdioConn) Close() error {
+	return nil
+}
+
+func (io stdioConn) RemoteAddr() string {
+	return "/dev/stdin"
+}
+
+func (io stdioConn) SetWriteDeadline(t time.Time) error {
+	return &net.OpError{Op: "set", Net: "stdio", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
+}
diff --git a/rpc/subscription.go b/rpc/subscription.go
index 6ce7befa1d..c1e869b8a3 100644
--- a/rpc/subscription.go
+++ b/rpc/subscription.go
@@ -17,9 +17,19 @@ package rpc
 
 import (
+	"bufio"
+	"container/list"
 	"context"
+	crand "crypto/rand"
+	"encoding/binary"
+	"encoding/hex"
+	"encoding/json"
 	"errors"
+	"math/rand"
+	"reflect"
+	"strings"
 	"sync"
+	"time"
 )
 
 var (
@@ -29,107 +39,289 @@ var (
 	ErrSubscriptionNotFound = errors.New("subscription not found")
 )
 
+var globalGen = randomIDGenerator()
+
 // ID defines a pseudo random number that is used to identify RPC subscriptions.
 type ID string
 
-// a Subscription is created by a notifier and tight to that notifier. The client can use
-// this subscription to wait for an unsubscribe request for the client, see Err().
-type Subscription struct {
-	ID        ID
-	namespace string
-	err       chan error // closed on unsubscribe
+// NewID returns a new, random ID.
+func NewID() ID {
+	return globalGen()
 }
 
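A small sketch, not part of the patch, of the ID format produced here; encodeID in the following hunk trims leading zeros, so IDs read as 0x-prefixed hex quantities.

func exampleIDs() {
	fmt.Println(NewID())                      // random, e.g. 0x3f19aa... (up to 32 hex digits)
	fmt.Println(encodeID([]byte{0x00, 0x01})) // 0x1: leading zeros are trimmed
	fmt.Println(encodeID([]byte{0x00, 0x00})) // 0x0: an all-zero input collapses to 0x0
}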
-// Err returns a channel that is closed when the client send an unsubscribe request.
-func (s *Subscription) Err() <-chan error {
-	return s.err
+// randomIDGenerator returns a function that generates random IDs.
+func randomIDGenerator() func() ID {
+	seed, err := binary.ReadVarint(bufio.NewReader(crand.Reader))
+	if err != nil {
+		seed = int64(time.Now().Nanosecond())
+	}
+	var (
+		mu  sync.Mutex
+		rng = rand.New(rand.NewSource(seed))
+	)
+	return func() ID {
+		mu.Lock()
+		defer mu.Unlock()
+		id := make([]byte, 16)
+		rng.Read(id)
+		return encodeID(id)
+	}
 }
 
-// notifierKey is used to store a notifier within the connection context.
-type notifierKey struct{}
-
-// Notifier is tight to a RPC connection that supports subscriptions.
-// Server callbacks use the notifier to send notifications.
-type Notifier struct {
-	codec    ServerCodec
-	subMu    sync.RWMutex // guards active and inactive maps
-	active   map[ID]*Subscription
-	inactive map[ID]*Subscription
-}
-
-// newNotifier creates a new notifier that can be used to send subscription
-// notifications to the client.
-func newNotifier(codec ServerCodec) *Notifier {
-	return &Notifier{
-		codec:    codec,
-		active:   make(map[ID]*Subscription),
-		inactive: make(map[ID]*Subscription),
+func encodeID(b []byte) ID {
+	id := hex.EncodeToString(b)
+	id = strings.TrimLeft(id, "0")
+	if id == "" {
+		id = "0" // IDs are RPC quantities, no leading zeros and 0 is 0x0.
 	}
+	return ID("0x" + id)
 }
 
+type notifierKey struct{}
+
 // NotifierFromContext returns the Notifier value stored in ctx, if any.
 func NotifierFromContext(ctx context.Context) (*Notifier, bool) {
 	n, ok := ctx.Value(notifierKey{}).(*Notifier)
 	return n, ok
 }
 
+// Notifier is tied to an RPC connection that supports subscriptions.
+// Server callbacks use the notifier to send notifications.
+type Notifier struct {
+	h         *handler
+	namespace string
+
+	mu           sync.Mutex
+	sub          *Subscription
+	buffer       []json.RawMessage
+	callReturned bool
+	activated    bool
+}
+
 // CreateSubscription returns a new subscription that is coupled to the
 // RPC connection. By default subscriptions are inactive and notifications
 // are dropped until the subscription is marked as active. This is done
 // by the RPC server after the subscription ID is send to the client.
 func (n *Notifier) CreateSubscription() *Subscription {
-	s := &Subscription{ID: NewID(), err: make(chan error)}
-	n.subMu.Lock()
-	n.inactive[s.ID] = s
-	n.subMu.Unlock()
-	return s
+	n.mu.Lock()
+	defer n.mu.Unlock()
+
+	if n.sub != nil {
+		panic("can't create multiple subscriptions with Notifier")
+	} else if n.callReturned {
+		panic("can't create subscription after subscribe call has returned")
+	}
+	n.sub = &Subscription{ID: n.h.idgen(), namespace: n.namespace, err: make(chan error, 1)}
+	return n.sub
 }
 
 // Notify sends a notification to the client with the given data as payload.
 // If an error occurs the RPC connection is closed and the error is returned.
 func (n *Notifier) Notify(id ID, data interface{}) error {
-	n.subMu.RLock()
-	defer n.subMu.RUnlock()
-
-	sub, active := n.active[id]
-	if active {
-		notification := n.codec.CreateNotification(string(id), sub.namespace, data)
-		if err := n.codec.Write(notification); err != nil {
-			n.codec.Close()
-			return err
-		}
+	enc, err := json.Marshal(data)
+	if err != nil {
+		return err
 	}
+
+	n.mu.Lock()
+	defer n.mu.Unlock()
+
+	if n.sub == nil {
+		panic("can't Notify before subscription is created")
+	} else if n.sub.ID != id {
+		panic("Notify with wrong ID")
+	}
+	if n.activated {
+		return n.send(n.sub, enc)
+	}
+	n.buffer = append(n.buffer, enc)
 	return nil
 }
 
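A minimal sketch, not part of the patch, of a subscription handler built on this Notifier; BlockService and the one-second ticker are assumptions for illustration. Notify calls made before the notifier is activated are buffered, so the client always sees the subscription ID before the first notification.

func (s *BlockService) Ticker(ctx context.Context) (*Subscription, error) {
	notifier, ok := NotifierFromContext(ctx)
	if !ok {
		return nil, errors.New("notifications not supported")
	}
	sub := notifier.CreateSubscription()
	go func() {
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
		for i := 0; ; i++ {
			select {
			case <-ticker.C:
				// Buffered by the notifier until the subscription is active.
				if err := notifier.Notify(sub.ID, i); err != nil {
					return
				}
			case <-sub.Err(): // unsubscribe request or connection closed
				return
			}
		}
	}()
	return sub, nil
}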
 // Closed returns a channel that is closed when the RPC connection is closed.
+// Deprecated: use subscription error channel
 func (n *Notifier) Closed() <-chan interface{} {
-	return n.codec.Closed()
-}
-
-// unsubscribe a subscription.
-// If the subscription could not be found ErrSubscriptionNotFound is returned.
-func (n *Notifier) unsubscribe(id ID) error {
-	n.subMu.Lock()
-	defer n.subMu.Unlock()
-	if s, found := n.active[id]; found {
-		close(s.err)
-		delete(n.active, id)
-		return nil
+	return n.h.conn.Closed()
+}
+
+// takeSubscription returns the subscription (if one has been created). No subscription can
+// be created after this call.
+func (n *Notifier) takeSubscription() *Subscription {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	n.callReturned = true
+	return n.sub
+}
+
+// activate is called after the subscription ID was sent to the client. Notifications are
+// buffered before activation. This prevents notifications being sent to the client before
+// the subscription ID is sent to the client.
+func (n *Notifier) activate() error {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+
+	for _, data := range n.buffer {
+		if err := n.send(n.sub, data); err != nil {
+			return err
+		}
 	}
-	return ErrSubscriptionNotFound
-}
-
-// activate enables a subscription. Until a subscription is enabled all
-// notifications are dropped. This method is called by the RPC server after
-// the subscription ID was sent to client. This prevents notifications being
-// send to the client before the subscription ID is send to the client.
-func (n *Notifier) activate(id ID, namespace string) {
-	n.subMu.Lock()
-	defer n.subMu.Unlock()
-	if sub, found := n.inactive[id]; found {
-		sub.namespace = namespace
-		n.active[id] = sub
-		delete(n.inactive, id)
+	n.activated = true
+	return nil
+}
+
+func (n *Notifier) send(sub *Subscription, data json.RawMessage) error {
+	params, _ := json.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data})
+	ctx := context.Background()
+	return n.h.conn.Write(ctx, &jsonrpcMessage{
+		Version: vsn,
+		Method:  n.namespace + notificationMethodSuffix,
+		Params:  params,
+	})
+}
+
+// A Subscription is created by a notifier and tied to that notifier. The client can use
+// this subscription to wait for an unsubscribe request for the client, see Err().
+type Subscription struct {
+	ID        ID
+	namespace string
+	err       chan error // closed on unsubscribe
+}
+
+// Err returns a channel that is closed when the client sends an unsubscribe request.
+func (s *Subscription) Err() <-chan error {
+	return s.err
+}
+
+// MarshalJSON marshals a subscription as its ID.
+func (s *Subscription) MarshalJSON() ([]byte, error) {
+	return json.Marshal(s.ID)
+}
+
+// ClientSubscription is a subscription established through the Client's Subscribe or
+// EthSubscribe methods.
+type ClientSubscription struct {
+	client    *Client
+	etype     reflect.Type
+	channel   reflect.Value
+	namespace string
+	subid     string
+	in        chan json.RawMessage
+
+	quitOnce sync.Once     // ensures quit is closed once
+	quit     chan struct{} // quit is closed when the subscription exits
+	errOnce  sync.Once     // ensures err is closed once
+	err      chan error
+}
+
+func newClientSubscription(c *Client, namespace string, channel reflect.Value) *ClientSubscription {
+	sub := &ClientSubscription{
+		client:    c,
+		namespace: namespace,
+		etype:     channel.Type().Elem(),
+		channel:   channel,
+		quit:      make(chan struct{}),
+		err:       make(chan error, 1),
+		in:        make(chan json.RawMessage),
+	}
+	return sub
+}
+
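On the client side, a sketch (not part of the patch) of consuming such a subscription; it assumes the Client.Subscribe entry point from the client half of this patch series, and the "block"/"ticker" names match the hypothetical handler sketched above. A nil value received on Err() means the client was closed cleanly.

func exampleConsume(ctx context.Context, client *Client) error {
	ch := make(chan int)
	sub, err := client.Subscribe(ctx, "block", ch, "ticker")
	if err != nil {
		return err
	}
	defer sub.Unsubscribe()
	for {
		select {
		case n := <-ch:
			fmt.Println("tick", n)
		case err := <-sub.Err():
			return err // nil when the client shut down cleanly
		}
	}
}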
+// Err returns the subscription error channel. The intended use of Err is to schedule
+// resubscription when the client connection is closed unexpectedly.
+//
+// The error channel receives a value when the subscription has ended due
+// to an error. The received error is nil if Close has been called
+// on the underlying client and no other error has occurred.
+//
+// The error channel is closed when Unsubscribe is called on the subscription.
+func (sub *ClientSubscription) Err() <-chan error {
+	return sub.err
+}
+
+// Unsubscribe unsubscribes the notification and closes the error channel.
+// It can safely be called more than once.
+func (sub *ClientSubscription) Unsubscribe() {
+	sub.quitWithError(nil, true)
+	sub.errOnce.Do(func() { close(sub.err) })
+}
+
+func (sub *ClientSubscription) quitWithError(err error, unsubscribeServer bool) {
+	sub.quitOnce.Do(func() {
+		// The dispatch loop won't be able to execute the unsubscribe call
+		// if it is blocked on deliver. Close sub.quit first because it
+		// unblocks deliver.
+		close(sub.quit)
+		if unsubscribeServer {
+			sub.requestUnsubscribe()
+		}
+		if err != nil {
+			if err == ErrClientQuit {
+				err = nil // Adhere to subscription semantics.
+			}
+			sub.err <- err
+		}
+	})
+}
+
+func (sub *ClientSubscription) deliver(result json.RawMessage) (ok bool) {
+	select {
+	case sub.in <- result:
+		return true
+	case <-sub.quit:
+		return false
+	}
+}
+
+func (sub *ClientSubscription) start() {
+	sub.quitWithError(sub.forward())
+}
+
+func (sub *ClientSubscription) forward() (err error, unsubscribeServer bool) {
+	cases := []reflect.SelectCase{
+		{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.quit)},
+		{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.in)},
+		{Dir: reflect.SelectSend, Chan: sub.channel},
+	}
+	buffer := list.New()
+	defer buffer.Init()
+	for {
+		var chosen int
+		var recv reflect.Value
+		if buffer.Len() == 0 {
+			// Idle, omit send case.
+			chosen, recv, _ = reflect.Select(cases[:2])
+		} else {
+			// Non-empty buffer, send the first queued item.
+			cases[2].Send = reflect.ValueOf(buffer.Front().Value)
+			chosen, recv, _ = reflect.Select(cases)
+		}
+
+		switch chosen {
+		case 0: // <-sub.quit
+			return nil, false
+		case 1: // <-sub.in
+			val, err := sub.unmarshal(recv.Interface().(json.RawMessage))
+			if err != nil {
+				return err, true
+			}
+			if buffer.Len() == maxClientSubscriptionBuffer {
+				return ErrSubscriptionQueueOverflow, true
+			}
+			buffer.PushBack(val)
+		case 2: // sub.channel<-
+			cases[2].Send = reflect.Value{} // Don't hold onto the value.
+ buffer.Remove(buffer.Front()) + } + } +} + +func (sub *ClientSubscription) unmarshal(result json.RawMessage) (interface{}, error) { + val := reflect.New(sub.etype) + err := json.Unmarshal(result, val.Interface()) + return val.Elem().Interface(), err +} + +func (sub *ClientSubscription) requestUnsubscribe() error { + var result interface{} + return sub.client.Call(&result, sub.namespace+unsubscribeMethodSuffix, sub.subid) } diff --git a/rpc/types.go b/rpc/types.go index f32f86bddc..c7539a2b20 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -17,13 +17,11 @@ package rpc import ( + "context" "fmt" "math" - "reflect" "strings" - "sync" - mapset "github.com/deckarep/golang-set" "github.com/tomochain/tomochain/common/hexutil" ) @@ -35,57 +33,6 @@ type API struct { Public bool // indication if the methods must be considered safe for public use } -// callback is a method callback which was registered in the server -type callback struct { - rcvr reflect.Value // receiver of method - method reflect.Method // callback - argTypes []reflect.Type // input argument types - hasCtx bool // method's first argument is a context (not included in argTypes) - errPos int // err return idx, of -1 when method cannot return error - isSubscribe bool // indication if the callback is a subscription -} - -// service represents a registered object -type service struct { - name string // name for service - typ reflect.Type // receiver type - callbacks callbacks // registered handlers - subscriptions subscriptions // available subscriptions/notifications -} - -// serverRequest is an incoming request -type serverRequest struct { - id interface{} - svcname string - callb *callback - args []reflect.Value - isUnsubscribe bool - err Error -} - -type serviceRegistry map[string]*service // collection of services -type callbacks map[string]*callback // collection of RPC callbacks -type subscriptions map[string]*callback // collection of subscription callbacks - -// Server represents a RPC server -type Server struct { - services serviceRegistry - - run int32 - codecsMu sync.Mutex - codecs mapset.Set -} - -// rpcRequest represents a raw incoming RPC request -type rpcRequest struct { - service string - method string - id interface{} - isPubSub bool - params interface{} - err Error // invalid batch element -} - // Error wraps RPC errors, which contain an error code in addition to the message. type Error interface { Error() string // returns the message @@ -96,24 +43,19 @@ type Error interface { // a RPC session. Implementations must be go-routine safe since the codec can be called in // multiple go-routines concurrently. type ServerCodec interface { - // Read next request - ReadRequestHeaders() ([]rpcRequest, bool, Error) - // Parse request argument to the given types - ParseRequestArguments(argTypes []reflect.Type, params interface{}) ([]reflect.Value, Error) - // Assemble success response, expects response id and payload - CreateResponse(id interface{}, reply interface{}) interface{} - // Assemble error response, expects response id and error - CreateErrorResponse(id interface{}, err Error) interface{} - // Assemble error response with extra information about the error through info - CreateErrorResponseWithInfo(id interface{}, err Error, info interface{}) interface{} - // Create notification response - CreateNotification(id, namespace string, event interface{}) interface{} - // Write msg to client. 
- Write(msg interface{}) error - // Close underlying data stream + Read() (msgs []*jsonrpcMessage, isBatch bool, err error) Close() - // Closed when underlying connection is closed + jsonWriter +} + +// jsonWriter can write JSON messages to its underlying connection. +// Implementations must be safe for concurrent use. +type jsonWriter interface { + Write(context.Context, interface{}) error + // Closed returns a channel which is closed when the connection is closed. Closed() <-chan interface{} + // RemoteAddr returns the peer address of the connection. + RemoteAddr() string } type BlockNumber int64 diff --git a/rpc/utils.go b/rpc/utils.go deleted file mode 100644 index 9315cab591..0000000000 --- a/rpc/utils.go +++ /dev/null @@ -1,241 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package rpc - -import ( - "bufio" - "context" - crand "crypto/rand" - "encoding/binary" - "encoding/hex" - "math/big" - "math/rand" - "reflect" - "strings" - "sync" - "time" - "unicode" - "unicode/utf8" -) - -var ( - subscriptionIDGenMu sync.Mutex - subscriptionIDGen = idGenerator() -) - -// Is this an exported - upper case - name? -func isExported(name string) bool { - rune, _ := utf8.DecodeRuneInString(name) - return unicode.IsUpper(rune) -} - -// Is this type exported or a builtin? -func isExportedOrBuiltinType(t reflect.Type) bool { - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - // PkgPath will be non-empty even for an exported type, - // so we need to check the type name as well. 
- return isExported(t.Name()) || t.PkgPath() == "" -} - -var contextType = reflect.TypeOf((*context.Context)(nil)).Elem() - -// isContextType returns an indication if the given t is of context.Context or *context.Context type -func isContextType(t reflect.Type) bool { - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - return t == contextType -} - -var errorType = reflect.TypeOf((*error)(nil)).Elem() - -// Implements this type the error interface -func isErrorType(t reflect.Type) bool { - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - return t.Implements(errorType) -} - -var subscriptionType = reflect.TypeOf((*Subscription)(nil)).Elem() - -// isSubscriptionType returns an indication if the given t is of Subscription or *Subscription type -func isSubscriptionType(t reflect.Type) bool { - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - return t == subscriptionType -} - -// isPubSub tests whether the given method has as as first argument a context.Context -// and returns the pair (Subscription, error) -func isPubSub(methodType reflect.Type) bool { - // numIn(0) is the receiver type - if methodType.NumIn() < 2 || methodType.NumOut() != 2 { - return false - } - - return isContextType(methodType.In(1)) && - isSubscriptionType(methodType.Out(0)) && - isErrorType(methodType.Out(1)) -} - -// formatName will convert to first character to lower case -func formatName(name string) string { - ret := []rune(name) - if len(ret) > 0 { - ret[0] = unicode.ToLower(ret[0]) - } - return string(ret) -} - -var bigIntType = reflect.TypeOf((*big.Int)(nil)).Elem() - -// Indication if this type should be serialized in hex -func isHexNum(t reflect.Type) bool { - if t == nil { - return false - } - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - - return t == bigIntType -} - -// suitableCallbacks iterates over the methods of the given type. It will determine if a method satisfies the criteria -// for a RPC callback or a subscription callback and adds it to the collection of callbacks or subscriptions. See server -// documentation for a summary of these criteria. 
-func suitableCallbacks(rcvr reflect.Value, typ reflect.Type) (callbacks, subscriptions) { - callbacks := make(callbacks) - subscriptions := make(subscriptions) - -METHODS: - for m := 0; m < typ.NumMethod(); m++ { - method := typ.Method(m) - mtype := method.Type - mname := formatName(method.Name) - if method.PkgPath != "" { // method must be exported - continue - } - - var h callback - h.isSubscribe = isPubSub(mtype) - h.rcvr = rcvr - h.method = method - h.errPos = -1 - - firstArg := 1 - numIn := mtype.NumIn() - if numIn >= 2 && mtype.In(1) == contextType { - h.hasCtx = true - firstArg = 2 - } - - if h.isSubscribe { - h.argTypes = make([]reflect.Type, numIn-firstArg) // skip rcvr type - for i := firstArg; i < numIn; i++ { - argType := mtype.In(i) - if isExportedOrBuiltinType(argType) { - h.argTypes[i-firstArg] = argType - } else { - continue METHODS - } - } - - subscriptions[mname] = &h - continue METHODS - } - - // determine method arguments, ignore first arg since it's the receiver type - // Arguments must be exported or builtin types - h.argTypes = make([]reflect.Type, numIn-firstArg) - for i := firstArg; i < numIn; i++ { - argType := mtype.In(i) - if !isExportedOrBuiltinType(argType) { - continue METHODS - } - h.argTypes[i-firstArg] = argType - } - - // check that all returned values are exported or builtin types - for i := 0; i < mtype.NumOut(); i++ { - if !isExportedOrBuiltinType(mtype.Out(i)) { - continue METHODS - } - } - - // when a method returns an error it must be the last returned value - h.errPos = -1 - for i := 0; i < mtype.NumOut(); i++ { - if isErrorType(mtype.Out(i)) { - h.errPos = i - break - } - } - - if h.errPos >= 0 && h.errPos != mtype.NumOut()-1 { - continue METHODS - } - - switch mtype.NumOut() { - case 0, 1, 2: - if mtype.NumOut() == 2 && h.errPos == -1 { // method must one return value and 1 error - continue METHODS - } - callbacks[mname] = &h - } - } - - return callbacks, subscriptions -} - -// idGenerator helper utility that generates a (pseudo) random sequence of -// bytes that are used to generate identifiers. -func idGenerator() *rand.Rand { - if seed, err := binary.ReadVarint(bufio.NewReader(crand.Reader)); err == nil { - return rand.New(rand.NewSource(seed)) - } - return rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) -} - -// NewID generates a identifier that can be used as an identifier in the RPC interface. -// e.g. filter and subscription identifier. -func NewID() ID { - subscriptionIDGenMu.Lock() - defer subscriptionIDGenMu.Unlock() - - id := make([]byte, 16) - for i := 0; i < len(id); i += 7 { - val := subscriptionIDGen.Int63() - for j := 0; i+j < len(id) && j < 7; j++ { - id[i+j] = byte(val) - val >>= 8 - } - } - - rpcId := hex.EncodeToString(id) - // rpc ID's are RPC quantities, no leading zero's and 0 is 0x0 - rpcId = strings.TrimLeft(rpcId, "0") - if rpcId == "" { - rpcId = "0" - } - - return ID("0x" + rpcId) -} diff --git a/rpc/utils_test.go b/rpc/utils_test.go deleted file mode 100644 index e0e063f607..0000000000 --- a/rpc/utils_test.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2016 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. 
-// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package rpc - -import ( - "strings" - "testing" -) - -func TestNewID(t *testing.T) { - hexchars := "0123456789ABCDEFabcdef" - for i := 0; i < 100; i++ { - id := string(NewID()) - if !strings.HasPrefix(id, "0x") { - t.Fatalf("invalid ID prefix, want '0x...', got %s", id) - } - - id = id[2:] - if len(id) == 0 || len(id) > 32 { - t.Fatalf("invalid ID length, want len(id) > 0 && len(id) <= 32), got %d", len(id)) - } - - for i := 0; i < len(id); i++ { - if strings.IndexByte(hexchars, id[i]) == -1 { - t.Fatalf("unexpected byte, want any valid hex char, got %c", id[i]) - } - } - } -} diff --git a/rpc/websocket.go b/rpc/websocket.go index 1faaf400b9..43ef76959e 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -20,7 +20,9 @@ import ( "bytes" "context" "crypto/tls" + "encoding/base64" "encoding/json" + "errors" "fmt" "net" "net/http" @@ -30,8 +32,9 @@ import ( "time" mapset "github.com/deckarep/golang-set" - "github.com/tomochain/tomochain/log" "golang.org/x/net/websocket" + + "github.com/tomochain/tomochain/log" ) // websocketJSONCodec is a custom JSON codec with payload size enforcement and @@ -55,24 +58,39 @@ var websocketJSONCodec = websocket.Codec{ // // allowedOrigins should be a comma-separated list of allowed origin URLs. // To allow connections with any origin, pass "*". -func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler { +func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { return websocket.Server{ Handshake: wsHandshakeValidator(allowedOrigins), Handler: func(conn *websocket.Conn) { - // Create a custom encode/decode pair to enforce payload size and number encoding - conn.MaxPayloadBytes = maxRequestContentLength - - encoder := func(v interface{}) error { - return websocketJSONCodec.Send(conn, v) - } - decoder := func(v interface{}) error { - return websocketJSONCodec.Receive(conn, v) - } - srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions) + codec := newWebsocketCodec(conn) + s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions) }, } } +func newWebsocketCodec(conn *websocket.Conn) ServerCodec { + // Create a custom encode/decode pair to enforce payload size and number encoding + conn.MaxPayloadBytes = maxRequestContentLength + encoder := func(v interface{}) error { + return websocketJSONCodec.Send(conn, v) + } + decoder := func(v interface{}) error { + return websocketJSONCodec.Receive(conn, v) + } + rpcconn := Conn(conn) + if conn.IsServerConn() { + // Override remote address with the actual socket address because + // package websocket crashes if there is no request origin. + addr := conn.Request().RemoteAddr + if wsaddr := conn.RemoteAddr().(*websocket.Addr); wsaddr.URL != nil { + // Add origin if present. + addr += "(" + wsaddr.URL.String() + ")" + } + rpcconn = connWithRemoteAddr{conn, addr} + } + return NewCodec(rpcconn, encoder, decoder) +} + // NewWSServer creates a new websocket RPC server around an API provider. 
// // Deprecated: use Server.WebsocketHandler @@ -104,26 +122,22 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http } } - log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice())) + log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) f := func(cfg *websocket.Config, req *http.Request) error { + // Verify origin against whitelist. origin := strings.ToLower(req.Header.Get("Origin")) if allowAllOrigins || origins.Contains(origin) { return nil } - log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin)) - return fmt.Errorf("origin %s not allowed", origin) + log.Warn("Rejected WebSocket connection", "origin", origin) + return errors.New("origin not allowed") } return f } -// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server -// that is listening on the given endpoint. -// -// The context is used for the initial connection establishment. It does not -// affect subsequent interactions with the client. -func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { +func wsGetConfig(endpoint, origin string) (*websocket.Config, error) { if origin == "" { var err error if origin, err = os.Hostname(); err != nil { @@ -140,8 +154,31 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error return nil, err } - return newClient(ctx, func(ctx context.Context) (net.Conn, error) { - return wsDialContext(ctx, config) + if config.Location.User != nil { + b64auth := base64.StdEncoding.EncodeToString([]byte(config.Location.User.String())) + config.Header.Add("Authorization", "Basic "+b64auth) + config.Location.User = nil + } + return config, nil +} + +// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server +// that is listening on the given endpoint. +// +// The context is used for the initial connection establishment. It does not +// affect subsequent interactions with the client. 
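Taken together, wsGetConfig resolves the origin (falling back to the local hostname), and moves any user:password part of the endpoint URL into a Basic Authorization header, so DialWebsocket itself only has to dial. A hedged usage sketch (the endpoint and the method name "web3_clientVersion" are illustrative assumptions, not taken from this patch):

	// dialAndQuery is a sketch of typical client usage.
	func dialAndQuery(ctx context.Context) (string, error) {
		client, err := DialWebsocket(ctx, "ws://127.0.0.1:8546", "")
		if err != nil {
			return "", err
		}
		defer client.Close()

		var version string
		if err := client.CallContext(ctx, &version, "web3_clientVersion"); err != nil {
			return "", err
		}
		return version, nil
	}

The implementation follows.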
+func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { + config, err := wsGetConfig(endpoint, origin) + if err != nil { + return nil, err + } + + return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { + conn, err := wsDialContext(ctx, config) + if err != nil { + return nil, err + } + return newWebsocketCodec(conn), nil }) } From 4ddc12c852cbbad57333e2a7a4b341d6c48c6431 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 30 Nov 2023 14:50:46 +0700 Subject: [PATCH 106/119] Add HTTP timeout and cleanup --- contracts/trc21issuer/simulation/test/main.go | 43 +++++------ ethclient/ethclient.go | 10 --- node/api.go | 2 +- node/config.go | 5 ++ node/defaults.go | 2 + node/node.go | 73 +++---------------- 6 files changed, 36 insertions(+), 99 deletions(-) diff --git a/contracts/trc21issuer/simulation/test/main.go b/contracts/trc21issuer/simulation/test/main.go index c1467968d7..79c31c50d7 100644 --- a/contracts/trc21issuer/simulation/test/main.go +++ b/contracts/trc21issuer/simulation/test/main.go @@ -2,17 +2,16 @@ package main import ( "context" - "encoding/json" "fmt" + "log" + "math/big" + "time" + "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/contracts/trc21issuer" "github.com/tomochain/tomochain/contracts/trc21issuer/simulation" "github.com/tomochain/tomochain/ethclient" - "log" - "math/big" - "time" ) var ( @@ -42,17 +41,15 @@ func airDropTokenToAccountNoTomo() { fmt.Println("wait 10s to airdrop success ", tx.Hash().Hex()) time.Sleep(10 * time.Second) - _, receiptRpc, err := client.GetTransactionReceiptResult(context.Background(), tx.Hash()) - receipt := map[string]interface{}{} - err = json.Unmarshal(receiptRpc, &receipt) + receipt, err := client.TransactionReceipt(context.Background(), tx.Hash()) if err != nil { - log.Fatal("can't transaction's receipt ", err, "hash", tx.Hash().Hex()) + log.Fatal("can't get transaction's receipt ", err, "hash", tx.Hash().Hex()) } - fee := big.NewInt(0).SetUint64(hexutil.MustDecodeUint64(receipt["gasUsed"].(string))) - if hexutil.MustDecodeUint64(receipt["blockNumber"].(string)) > common.TIPTRC21Fee.Uint64() { + fee := big.NewInt(0).SetUint64(receipt.GasUsed) + if receipt.BlockNumber.Uint64() > common.TIPTRC21Fee.Uint64() { fee = fee.Mul(fee, common.TRC21GasPrice) } - fmt.Println("fee", fee.Uint64(), "number", hexutil.MustDecodeUint64(receipt["blockNumber"].(string))) + fmt.Println("fee", fee.Uint64(), "number", receipt.BlockNumber.Uint64()) remainFee = big.NewInt(0).Sub(remainFee, fee) //check balance fee balanceIssuerFee, err := trc21IssuerInstance.GetTokenCapacity(trc21TokenAddr) @@ -105,17 +102,15 @@ func testTransferTRC21TokenWithAccountNoTomo() { if err != nil || balance.Cmp(remainAirDrop) != 0 { log.Fatal("check balance after fail transferAmount in tr21: ", err, "get", balance, "wanted", remainAirDrop) } - _, receiptRpc, err := client.GetTransactionReceiptResult(context.Background(), tx.Hash()) - receipt := map[string]interface{}{} - err = json.Unmarshal(receiptRpc, &receipt) + receipt, err := client.TransactionReceipt(context.Background(), tx.Hash()) if err != nil { - log.Fatal("can't transaction's receipt ", err, "hash", tx.Hash().Hex()) + log.Fatal("can't get transaction's receipt ", err, "hash", tx.Hash().Hex()) } - fee := big.NewInt(0).SetUint64(hexutil.MustDecodeUint64(receipt["gasUsed"].(string))) - if hexutil.MustDecodeUint64(receipt["blockNumber"].(string)) > 
common.TIPTRC21Fee.Uint64() {
+	fee := big.NewInt(0).SetUint64(receipt.GasUsed)
+	if receipt.BlockNumber.Uint64() > common.TIPTRC21Fee.Uint64() {
 		fee = fee.Mul(fee, common.TRC21GasPrice)
 	}
-	fmt.Println("fee", fee.Uint64(), "number", hexutil.MustDecodeUint64(receipt["blockNumber"].(string)))
+	fmt.Println("fee", fee.Uint64(), "number", receipt.BlockNumber.Uint64())
 	remainFee = big.NewInt(0).Sub(remainFee, fee)
 	//check balance fee
 	balanceIssuerFee, err := trc21IssuerInstance.GetTokenCapacity(trc21TokenAddr)
@@ -172,17 +167,15 @@ func testTransferTrc21Fail() {
 	if err != nil || balance.Cmp(ownerBalance) != 0 {
 		log.Fatal("can't get balance token fee in smart contract: ", err, "got", balanceIssuerFee, "wanted", remainFee)
 	}
-	_, receiptRpc, err := client.GetTransactionReceiptResult(context.Background(), tx.Hash())
-	receipt := map[string]interface{}{}
-	err = json.Unmarshal(receiptRpc, &receipt)
+	receipt, err := client.TransactionReceipt(context.Background(), tx.Hash())
 	if err != nil {
 		log.Fatal("can't get transaction's receipt ", err, "hash", tx.Hash().Hex())
 	}
-	fee := big.NewInt(0).SetUint64(hexutil.MustDecodeUint64(receipt["gasUsed"].(string)))
-	if hexutil.MustDecodeUint64(receipt["blockNumber"].(string)) > common.TIPTRC21Fee.Uint64() {
+	fee := big.NewInt(0).SetUint64(receipt.GasUsed)
+	if receipt.BlockNumber.Uint64() > common.TIPTRC21Fee.Uint64() {
 		fee = fee.Mul(fee, common.TRC21GasPrice)
 	}
-	fmt.Println("fee", fee.Uint64(), "number", hexutil.MustDecodeUint64(receipt["blockNumber"].(string)))
+	fmt.Println("fee", fee.Uint64(), "number", receipt.BlockNumber.Uint64())
 	remainFee = big.NewInt(0).Sub(remainFee, fee)
 	//check balance fee
 	balanceIssuerFee, err = trc21IssuerInstance.GetTokenCapacity(trc21TokenAddr)
diff --git a/ethclient/ethclient.go b/ethclient/ethclient.go
index d4da5b6cc8..3decd2c41d 100644
--- a/ethclient/ethclient.go
+++ b/ethclient/ethclient.go
@@ -253,16 +253,6 @@ func (ec *Client) TransactionReceipt(ctx context.Context, txHash common.Hash) (*
 	return r, err
 }
 
-func (ec *Client) GetTransactionReceiptResult(ctx context.Context, txHash common.Hash) (*types.Receipt, json.RawMessage, error) {
-	var r *types.Receipt
-	result, err := ec.c.GetResultCallContext(ctx, &r, "eth_getTransactionReceipt", txHash)
-	if err == nil {
-		if r == nil {
-			return nil, nil, ethereum.NotFound
-		}
-	}
-	return r, result, err
-}
 func toBlockNumArg(number *big.Int) string {
 	if number == nil {
 		return "latest"
diff --git a/node/api.go b/node/api.go
index 23edbe2b3f..0f40039b60 100644
--- a/node/api.go
+++ b/node/api.go
@@ -157,7 +157,7 @@ func (api *PrivateAdminAPI) StartRPC(host *string, port *int, cors *string, apis
 		}
 	}
 
-	if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts); err != nil {
+	if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts); err != nil {
 		return false, err
 	}
 	return true, nil
diff --git a/node/config.go b/node/config.go
index b8ad712fc8..066af60762 100644
--- a/node/config.go
+++ b/node/config.go
@@ -33,6 +33,7 @@ import (
 	"github.com/tomochain/tomochain/log"
 	"github.com/tomochain/tomochain/p2p"
 	"github.com/tomochain/tomochain/p2p/discover"
+	"github.com/tomochain/tomochain/rpc"
 )
 
 const (
@@ -119,6 +120,10 @@ type Config struct {
 	// exposed.
 	HTTPModules []string `toml:",omitempty"`
 
+	// HTTPTimeouts allows for customization of the timeout values used by the HTTP RPC
+	// interface.
+ HTTPTimeouts rpc.HTTPTimeouts + // WSHost is the host interface on which to start the websocket RPC server. If // this field is empty, no websocket API endpoint will be started. WSHost string `toml:",omitempty"` diff --git a/node/defaults.go b/node/defaults.go index 8eb84740d8..071a13a4d3 100644 --- a/node/defaults.go +++ b/node/defaults.go @@ -24,6 +24,7 @@ import ( "github.com/tomochain/tomochain/p2p" "github.com/tomochain/tomochain/p2p/nat" + "github.com/tomochain/tomochain/rpc" ) const ( @@ -39,6 +40,7 @@ var DefaultConfig = Config{ HTTPPort: DefaultHTTPPort, HTTPModules: []string{"net", "web3"}, HTTPVirtualHosts: []string{"localhost"}, + HTTPTimeouts: rpc.DefaultHTTPTimeouts, WSPort: DefaultWSPort, WSModules: []string{"net", "web3"}, P2P: p2p.Config{ diff --git a/node/node.go b/node/node.go index 2f1d4853e9..21fa842604 100644 --- a/node/node.go +++ b/node/node.go @@ -19,7 +19,6 @@ package node import ( "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "net" "os" "path/filepath" @@ -28,7 +27,9 @@ import ( "sync" "github.com/prometheus/prometheus/util/flock" + "github.com/tomochain/tomochain/accounts" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/internal/debug" @@ -264,7 +265,7 @@ func (n *Node) startRPC(services map[reflect.Type]Service) error { n.stopInProc() return err } - if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts); err != nil { + if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts); err != nil { n.stopIPC() n.stopInProc() return err @@ -304,50 +305,16 @@ func (n *Node) stopInProc() { // startIPC initializes and starts the IPC RPC endpoint. func (n *Node) startIPC(apis []rpc.API) error { - // Short circuit if the IPC endpoint isn't being exposed if n.ipcEndpoint == "" { - return nil - } - // Register all the APIs exposed by the services - handler := rpc.NewServer() - for _, api := range apis { - if err := handler.RegisterName(api.Namespace, api.Service); err != nil { - return err - } - n.log.Debug("IPC registered", "service", api.Service, "namespace", api.Namespace) + return nil // IPC disabled. } - // All APIs registered, start the IPC listener - var ( - listener net.Listener - err error - ) - if listener, err = rpc.CreateIPCListener(n.ipcEndpoint); err != nil { + listener, handler, err := rpc.StartIPCEndpoint(n.ipcEndpoint, apis) + if err != nil { return err } - go func() { - n.log.Info("IPC endpoint opened", "url", n.ipcEndpoint) - - for { - conn, err := listener.Accept() - if err != nil { - // Terminate if the listener was closed - n.lock.RLock() - closed := n.ipcListener == nil - n.lock.RUnlock() - if closed { - return - } - // Not closed, just some error; report and continue - n.log.Error("IPC accept failed", "err", err) - continue - } - go handler.ServeCodec(rpc.NewJSONCodec(conn), rpc.OptionMethodInvocation|rpc.OptionSubscriptions) - } - }() - // All listeners booted successfully n.ipcListener = listener n.ipcHandler = handler - + n.log.Info("IPC endpoint opened", "url", n.ipcEndpoint) return nil } @@ -366,35 +333,15 @@ func (n *Node) stopIPC() { } // startHTTP initializes and starts the HTTP RPC endpoint. 
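With the plumbing above in place, the timeouts flow from Config through startRPC into startHTTP below. An embedder tightening the defaults might write something like the following hedged sketch; it assumes rpc.HTTPTimeouts carries the usual net/http-style ReadTimeout/WriteTimeout/IdleTimeout duration fields, and the values are illustrative only:

	cfg := node.DefaultConfig
	cfg.HTTPTimeouts = rpc.HTTPTimeouts{
		ReadTimeout:  10 * time.Second, // bound slow request bodies
		WriteTimeout: 30 * time.Second, // bound long-running responses
		IdleTimeout:  2 * time.Minute,  // reap idle keep-alive connections
	}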
-func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string) error { +func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts) error { // Short circuit if the HTTP endpoint isn't being exposed if endpoint == "" { return nil } - // Generate the whitelist based on the allowed modules - whitelist := make(map[string]bool) - for _, module := range modules { - whitelist[module] = true - } - // Register all the APIs exposed by the services - handler := rpc.NewServer() - for _, api := range apis { - if whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { - if err := handler.RegisterName(api.Namespace, api.Service); err != nil { - return err - } - n.log.Debug("HTTP registered", "service", api.Service, "namespace", api.Namespace) - } - } - // All APIs registered, start the HTTP listener - var ( - listener net.Listener - err error - ) - if listener, err = net.Listen("tcp", endpoint); err != nil { + listener, handler, err := rpc.StartHTTPEndpoint(endpoint, apis, modules, cors, vhosts, timeouts) + if err != nil { return err } - go rpc.NewHTTPServer(cors, vhosts, handler).Serve(listener) n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%s", endpoint), "cors", strings.Join(cors, ","), "vhosts", strings.Join(vhosts, ",")) // All listeners booted successfully n.httpEndpoint = endpoint From 6af9da78d9c9b3cbf9bbc5b5593f1f3f257f26e7 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 30 Nov 2023 14:50:57 +0700 Subject: [PATCH 107/119] Add rpc unit tests --- rpc/client_example_test.go | 4 +- rpc/client_test.go | 164 ++++++------ rpc/server_test.go | 208 ++++++++------- rpc/subscription_test.go | 389 ++++++++++------------------- rpc/testdata/invalid-badid.js | 7 + rpc/testdata/invalid-badversion.js | 19 ++ rpc/testdata/invalid-batch.js | 17 ++ rpc/testdata/invalid-idonly.js | 7 + rpc/testdata/invalid-nonobj.js | 7 + rpc/testdata/invalid-syntax.json | 5 + rpc/testdata/reqresp-batch.js | 8 + rpc/testdata/reqresp-echo.js | 16 ++ rpc/testdata/reqresp-namedparam.js | 5 + rpc/testdata/reqresp-noargsrets.js | 4 + rpc/testdata/reqresp-nomethod.js | 4 + rpc/testdata/reqresp-noparam.js | 4 + rpc/testdata/reqresp-paramsnull.js | 4 + rpc/testdata/revcall.js | 6 + rpc/testdata/revcall2.js | 7 + rpc/testdata/subscription.js | 12 + rpc/testservice_test.go | 180 +++++++++++++ rpc/websocket_test.go | 54 ++++ 22 files changed, 687 insertions(+), 444 deletions(-) create mode 100644 rpc/testdata/invalid-badid.js create mode 100644 rpc/testdata/invalid-badversion.js create mode 100644 rpc/testdata/invalid-batch.js create mode 100644 rpc/testdata/invalid-idonly.js create mode 100644 rpc/testdata/invalid-nonobj.js create mode 100644 rpc/testdata/invalid-syntax.json create mode 100644 rpc/testdata/reqresp-batch.js create mode 100644 rpc/testdata/reqresp-echo.js create mode 100644 rpc/testdata/reqresp-namedparam.js create mode 100644 rpc/testdata/reqresp-noargsrets.js create mode 100644 rpc/testdata/reqresp-nomethod.js create mode 100644 rpc/testdata/reqresp-noparam.js create mode 100644 rpc/testdata/reqresp-paramsnull.js create mode 100644 rpc/testdata/revcall.js create mode 100644 rpc/testdata/revcall2.js create mode 100644 rpc/testdata/subscription.js create mode 100644 rpc/testservice_test.go create mode 100644 rpc/websocket_test.go diff --git a/rpc/client_example_test.go b/rpc/client_example_test.go index fa0ee9dd98..3f32da368f 100644 --- a/rpc/client_example_test.go +++ 
b/rpc/client_example_test.go @@ -25,7 +25,7 @@ import ( "github.com/tomochain/tomochain/rpc" ) -// In this example, our client whishes to track the latest 'block number' +// In this example, our client wishes to track the latest 'block number' // known to the server. The server supports two methods: // // eth_getBlockByNumber("latest", {}) @@ -66,7 +66,7 @@ func subscribeBlocks(client *rpc.Client, subch chan Block) { defer cancel() // Subscribe to new blocks. - sub, err := client.EthSubscribe(ctx, subch, "newBlocks") + sub, err := client.EthSubscribe(ctx, subch, "newHeads") if err != nil { fmt.Println("subscribe error:", err) return diff --git a/rpc/client_test.go b/rpc/client_test.go index 0dc797677b..d1ee95a726 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -31,17 +31,18 @@ import ( "time" "github.com/davecgh/go-spew/spew" + "github.com/tomochain/tomochain/log" ) func TestClientRequest(t *testing.T) { - server := newTestServer("service", new(Service)) + server := newTestServer() defer server.Stop() client := DialInProc(server) defer client.Close() var resp Result - if err := client.Call(&resp, "service_echo", "hello", 10, &Args{"world"}); err != nil { + if err := client.Call(&resp, "test_echo", "hello", 10, &Args{"world"}); err != nil { t.Fatal(err) } if !reflect.DeepEqual(resp, Result{"hello", 10, &Args{"world"}}) { @@ -50,19 +51,19 @@ func TestClientRequest(t *testing.T) { } func TestClientBatchRequest(t *testing.T) { - server := newTestServer("service", new(Service)) + server := newTestServer() defer server.Stop() client := DialInProc(server) defer client.Close() batch := []BatchElem{ { - Method: "service_echo", + Method: "test_echo", Args: []interface{}{"hello", 10, &Args{"world"}}, Result: new(Result), }, { - Method: "service_echo", + Method: "test_echo", Args: []interface{}{"hello2", 11, &Args{"world"}}, Result: new(Result), }, @@ -77,12 +78,12 @@ func TestClientBatchRequest(t *testing.T) { } wantResult := []BatchElem{ { - Method: "service_echo", + Method: "test_echo", Args: []interface{}{"hello", 10, &Args{"world"}}, Result: &Result{"hello", 10, &Args{"world"}}, }, { - Method: "service_echo", + Method: "test_echo", Args: []interface{}{"hello2", 11, &Args{"world"}}, Result: &Result{"hello2", 11, &Args{"world"}}, }, @@ -90,7 +91,7 @@ func TestClientBatchRequest(t *testing.T) { Method: "no_such_method", Args: []interface{}{1, 2, 3}, Result: new(int), - Error: &jsonError{Code: -32601, Message: "The method no_such_method_ does not exist/is not available"}, + Error: &jsonError{Code: -32601, Message: "the method no_such_method does not exist/is not available"}, }, } if !reflect.DeepEqual(batch, wantResult) { @@ -98,6 +99,17 @@ func TestClientBatchRequest(t *testing.T) { } } +func TestClientNotify(t *testing.T) { + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + if err := client.Notify(context.Background(), "test_echo", "hello", 10, &Args{"world"}); err != nil { + t.Fatal(err) + } +} + // func TestClientCancelInproc(t *testing.T) { testClientCancel("inproc", t) } func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t) } func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t) } @@ -106,7 +118,12 @@ func TestClientCancelIPC(t *testing.T) { testClientCancel("ipc", t) } // This test checks that requests made through CallContext can be canceled by canceling // the context. 
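The contract being exercised: wrap the call in a cancelable context; once the context ends, CallContext must return promptly with an error instead of waiting out the server. A hedged sketch from the caller's side (test_sleep mirrors the test service's sleep method used below):

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	var result interface{}
	err := client.CallContext(ctx, &result, "test_sleep", time.Second)
	// Expected to return roughly when ctx expires, even though the
	// server side is still sleeping.
	fmt.Println("call returned:", err)

The helper below drives exactly this scenario across every supported transport.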
func testClientCancel(transport string, t *testing.T) { - server := newTestServer("service", new(Service)) + // These tests take a lot of time, run them all at once. + // You probably want to run with -parallel 1 or comment out + // the call to t.Parallel if you enable the logging. + t.Parallel() + + server := newTestServer() defer server.Stop() // What we want to achieve is that the context gets canceled @@ -142,11 +159,6 @@ func testClientCancel(transport string, t *testing.T) { panic("unknown transport: " + transport) } - // These tests take a lot of time, run them all at once. - // You probably want to run with -parallel 1 or comment out - // the call to t.Parallel if you enable the logging. - t.Parallel() - // The actual test starts here. var ( wg sync.WaitGroup @@ -174,7 +186,8 @@ func testClientCancel(transport string, t *testing.T) { } // Now perform a call with the context. // The key thing here is that no call will ever complete successfully. - err := client.CallContext(ctx, nil, "service_sleep", 2*maxContextCancelTimeout) + sleepTime := maxContextCancelTimeout + 20*time.Millisecond + err := client.CallContext(ctx, nil, "test_sleep", sleepTime) if err != nil { log.Debug(fmt.Sprint("got expected error:", err)) } else { @@ -191,7 +204,7 @@ func testClientCancel(transport string, t *testing.T) { } func TestClientSubscribeInvalidArg(t *testing.T) { - server := newTestServer("service", new(Service)) + server := newTestServer() defer server.Stop() client := DialInProc(server) defer client.Close() @@ -221,14 +234,14 @@ func TestClientSubscribeInvalidArg(t *testing.T) { } func TestClientSubscribe(t *testing.T) { - server := newTestServer("eth", new(NotificationTestService)) + server := newTestServer() defer server.Stop() client := DialInProc(server) defer client.Close() nc := make(chan int) count := 10 - sub, err := client.EthSubscribe(context.Background(), nc, "someSubscription", count, 0) + sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", count, 0) if err != nil { t.Fatal("can't subscribe:", err) } @@ -251,46 +264,17 @@ func TestClientSubscribe(t *testing.T) { } } -func TestClientSubscribeCustomNamespace(t *testing.T) { - namespace := "custom" - server := newTestServer(namespace, new(NotificationTestService)) - defer server.Stop() - client := DialInProc(server) - defer client.Close() - - nc := make(chan int) - count := 10 - sub, err := client.Subscribe(context.Background(), namespace, nc, "someSubscription", count, 0) - if err != nil { - t.Fatal("can't subscribe:", err) - } - for i := 0; i < count; i++ { - if val := <-nc; val != i { - t.Fatalf("value mismatch: got %d, want %d", val, i) - } - } - - sub.Unsubscribe() - select { - case v := <-nc: - t.Fatal("received value after unsubscribe:", v) - case err := <-sub.Err(): - if err != nil { - t.Fatalf("Err returned a non-nil error after explicit unsubscribe: %q", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("subscription not closed within 1s after unsubscribe") - } -} - -// In this test, the connection drops while EthSubscribe is -// waiting for a response. +// In this test, the connection drops while Subscribe is waiting for a response. 
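For reference, the client-side pattern that these subscription tests exercise looks like this. A hedged sketch mirroring the tests' "nftest" namespace and int payload (consume is an illustrative name):

	func consume(client *Client) {
		ch := make(chan int)
		sub, err := client.Subscribe(context.Background(), "nftest", ch, "someSubscription", 5, 0)
		if err != nil {
			fmt.Println("subscribe failed:", err)
			return
		}
		defer sub.Unsubscribe()
		for {
			select {
			case v := <-ch:
				fmt.Println("notification:", v)
			case err := <-sub.Err():
				// nil after a clean Unsubscribe, non-nil if the
				// connection drops.
				fmt.Println("subscription ended:", err)
				return
			}
		}
	}

The test below then covers the unhappy half of that contract: the connection drops while Subscribe itself is still waiting for its response.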
 func TestClientSubscribeClose(t *testing.T) {
-	service := &NotificationTestService{
+	server := newTestServer()
+	service := &notificationTestService{
 		gotHangSubscriptionReq:  make(chan struct{}),
 		unblockHangSubscription: make(chan struct{}),
 	}
-	server := newTestServer("eth", service)
+	if err := server.RegisterName("nftest2", service); err != nil {
+		t.Fatal(err)
+	}
+
 	defer server.Stop()
 	client := DialInProc(server)
 	defer client.Close()
@@ -302,7 +286,7 @@ func TestClientSubscribeClose(t *testing.T) {
 		err  error
 	)
 	go func() {
-		sub, err = client.EthSubscribe(context.Background(), nc, "hangSubscription", 999)
+		sub, err = client.Subscribe(context.Background(), "nftest2", nc, "hangSubscription", 999)
 		errc <- err
 	}()
 
@@ -313,20 +297,43 @@ func TestClientSubscribeClose(t *testing.T) {
 	select {
 	case err := <-errc:
 		if err == nil {
-			t.Errorf("EthSubscribe returned nil error after Close")
+			t.Errorf("Subscribe returned nil error after Close")
 		}
 		if sub != nil {
-			t.Error("EthSubscribe returned non-nil subscription after Close")
+			t.Error("Subscribe returned non-nil subscription after Close")
 		}
 	case <-time.After(1 * time.Second):
-		t.Fatalf("EthSubscribe did not return within 1s after Close")
+		t.Fatalf("Subscribe did not return within 1s after Close")
+	}
+}
+
+// This test reproduces https://github.com/ethereum/go-ethereum/issues/17837 where the
+// client hangs during shutdown when Unsubscribe races with Client.Close.
+func TestClientCloseUnsubscribeRace(t *testing.T) {
+	server := newTestServer()
+	defer server.Stop()
+
+	for i := 0; i < 20; i++ {
+		client := DialInProc(server)
+		nc := make(chan int)
+		sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", 3, 1)
+		if err != nil {
+			t.Fatal(err)
+		}
+		go client.Close()
+		go sub.Unsubscribe()
+		select {
+		case <-sub.Err():
+		case <-time.After(5 * time.Second):
+			t.Fatal("subscription not closed within timeout")
+		}
 	}
 }
 
 // This test checks that Client doesn't lock up when a single subscriber
 // doesn't read subscription events.
 func TestClientNotificationStorm(t *testing.T) {
-	server := newTestServer("eth", new(NotificationTestService))
+	server := newTestServer()
 	defer server.Stop()
 
 	doTest := func(count int, wantError bool) {
 		// Subscribe on the server. It will start sending many notifications
 		// very quickly.
nc := make(chan int) - sub, err := client.EthSubscribe(ctx, nc, "someSubscription", count, 0) + sub, err := client.Subscribe(ctx, "nftest", nc, "someSubscription", count, 0) if err != nil { t.Fatal("can't subscribe:", err) } @@ -360,7 +367,7 @@ func TestClientNotificationStorm(t *testing.T) { return } var r int - err := client.CallContext(ctx, &r, "eth_echo", i) + err := client.CallContext(ctx, &r, "nftest_echo", i) if err != nil { if !wantError { t.Fatalf("(%d/%d) call error: %v", i, count, err) @@ -375,7 +382,7 @@ func TestClientNotificationStorm(t *testing.T) { } func TestClientHTTP(t *testing.T) { - server := newTestServer("service", new(Service)) + server := newTestServer() defer server.Stop() client, hs := httpTestClient(server, "http", nil) @@ -392,7 +399,7 @@ func TestClientHTTP(t *testing.T) { for i := range results { i := i go func() { - errc <- client.Call(&results[i], "service_echo", + errc <- client.Call(&results[i], "test_echo", wantResult.String, wantResult.Int, wantResult.Args) }() } @@ -421,16 +428,16 @@ func TestClientHTTP(t *testing.T) { func TestClientReconnect(t *testing.T) { startServer := func(addr string) (*Server, net.Listener) { - srv := newTestServer("service", new(Service)) + srv := newTestServer() l, err := net.Listen("tcp", addr) if err != nil { - t.Fatal(err) + t.Fatal("can't listen:", err) } go http.Serve(l, srv.WebsocketHandler([]string{"*"})) return srv, l } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) defer cancel() // Start a server and corresponding client. @@ -442,21 +449,22 @@ func TestClientReconnect(t *testing.T) { // Perform a call. This should work because the server is up. var resp Result - if err := client.CallContext(ctx, &resp, "service_echo", "", 1, nil); err != nil { + if err := client.CallContext(ctx, &resp, "test_echo", "", 1, nil); err != nil { t.Fatal(err) } - // Shut down the server and try calling again. It shouldn't work. + // Shut down the server and allow for some cool down time so we can listen on the same + // address again. l1.Close() s1.Stop() - if err := client.CallContext(ctx, &resp, "service_echo", "", 2, nil); err == nil { + time.Sleep(2 * time.Second) + + // Try calling again. It shouldn't work. + if err := client.CallContext(ctx, &resp, "test_echo", "", 2, nil); err == nil { t.Error("successful call while the server is down") t.Logf("resp: %#v", resp) } - // Allow for some cool down time so we can listen on the same address again. - time.Sleep(2 * time.Second) - // Start it up again and call again. The connection should be reestablished. // We spawn multiple calls here to check whether this hangs somehow. 
 	s2, l2 := startServer(l1.Addr().String())
 	defer l2.Close()
 	defer s2.Stop()
 
 	start := make(chan struct{})
 	errors := make(chan error, 20)
 	for i := 0; i < 20; i++ {
 		go func() {
 			<-start
 			var resp Result
-			errors <- client.CallContext(ctx, &resp, "service_echo", "", 3, nil)
+			errors <- client.CallContext(ctx, &resp, "test_echo", "", 3, nil)
 		}()
 	}
 	close(start)
 	errcount := 0
 	for i := 0; i < 20; i++ {
 		if err = <-errors; err != nil {
 			errcount++
 		}
 	}
-	t.Log("err:", err)
+	t.Logf("%d errors, last error: %v", errcount, err)
 	if errcount > 1 {
 		t.Errorf("expected one error after disconnect, got %d", errcount)
 	}
 }
 
-func newTestServer(serviceName string, service interface{}) *Server {
-	server := NewServer()
-	if err := server.RegisterName(serviceName, service); err != nil {
-		panic(err)
-	}
-	return server
-}
-
 func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client, *httptest.Server) {
 	// Create the HTTP server.
 	var hs *httptest.Server
diff --git a/rpc/server_test.go b/rpc/server_test.go
index 90d62f26d8..2a6926abf9 100644
--- a/rpc/server_test.go
+++ b/rpc/server_test.go
@@ -17,146 +17,136 @@
 package rpc
 
 import (
-	"context"
-	"encoding/json"
+	"bufio"
+	"bytes"
+	"io"
 	"net"
-	"reflect"
+	"os"
+	"path/filepath"
+	"strings"
 	"testing"
 	"time"
 )
 
-type Service struct{}
-
-type Args struct {
-	S string
-}
-
-func (s *Service) NoArgsRets() {
-}
-
-type Result struct {
-	String string
-	Int    int
-	Args   *Args
-}
-
-func (s *Service) Echo(str string, i int, args *Args) Result {
-	return Result{str, i, args}
-}
-
-func (s *Service) EchoWithCtx(ctx context.Context, str string, i int, args *Args) Result {
-	return Result{str, i, args}
-}
-
-func (s *Service) Sleep(ctx context.Context, duration time.Duration) {
-	select {
-	case <-time.After(duration):
-	case <-ctx.Done():
-	}
-}
-
-func (s *Service) Rets() (string, error) {
-	return "", nil
-}
-
-func (s *Service) InvalidRets1() (error, string) {
-	return nil, ""
-}
-
-func (s *Service) InvalidRets2() (string, string) {
-	return "", ""
-}
-
-func (s *Service) InvalidRets3() (string, string, error) {
-	return "", "", nil
-}
-
-func (s *Service) Subscription(ctx context.Context) (*Subscription, error) {
-	return nil, nil
-}
-
 func TestServerRegisterName(t *testing.T) {
 	server := NewServer()
-	service := new(Service)
+	service := new(testService)
 
-	if err := server.RegisterName("calc", service); err != nil {
+	if err := server.RegisterName("test", service); err != nil {
 		t.Fatalf("%v", err)
 	}
 
-	if len(server.services) != 2 {
-		t.Fatalf("Expected 2 service entries, got %d", len(server.services))
+	if len(server.services.services) != 2 {
+		t.Fatalf("Expected 2 service entries, got %d", len(server.services.services))
 	}
 
-	svc, ok := server.services["calc"]
+	svc, ok := server.services.services["test"]
 	if !ok {
 		t.Fatalf("Expected service test to be registered")
 	}
 
-	if len(svc.callbacks) != 5 {
-		t.Errorf("Expected 5 callbacks for service 'calc', got %d", len(svc.callbacks))
-	}
-
-	if len(svc.subscriptions) != 1 {
-		t.Errorf("Expected 1 subscription for service 'calc', got %d", len(svc.subscriptions))
+	wantCallbacks := 7
+	if len(svc.callbacks) != wantCallbacks {
+		t.Errorf("Expected %d callbacks for service 'test', got %d", wantCallbacks, len(svc.callbacks))
 	}
 }
 
+func TestServer(t *testing.T) {
+	files, err := os.ReadDir("testdata")
+	if err != nil {
+		t.Fatal("where'd my testdata go?")
+	}
+	for _, f := range files {
if f.IsDir() || strings.HasPrefix(f.Name(), ".") { + continue + } + path := filepath.Join("testdata", f.Name()) + name := strings.TrimSuffix(f.Name(), filepath.Ext(f.Name())) + t.Run(name, func(t *testing.T) { + runTestScript(t, path) + }) + } +} - stringArg := "string arg" - intArg := 1122 - argsArg := &Args{"abcde"} - params := []interface{}{stringArg, intArg, argsArg} - - request := map[string]interface{}{ - "id": 12345, - "method": "test_" + method, - "version": "2.0", - "params": params, +func runTestScript(t *testing.T, file string) { + server := newTestServer() + content, err := os.ReadFile(file) + if err != nil { + t.Fatal(err) } clientConn, serverConn := net.Pipe() defer clientConn.Close() - - go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation) - - out := json.NewEncoder(clientConn) - in := json.NewDecoder(clientConn) - - if err := out.Encode(request); err != nil { - t.Fatal(err) + go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions) + readbuf := bufio.NewReader(clientConn) + for _, line := range strings.Split(string(content), "\n") { + line = strings.TrimSpace(line) + switch { + case len(line) == 0 || strings.HasPrefix(line, "//"): + // skip comments, blank lines + continue + case strings.HasPrefix(line, "--> "): + t.Log(line) + // write to connection + clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if _, err := io.WriteString(clientConn, line[4:]+"\n"); err != nil { + t.Fatalf("write error: %v", err) + } + case strings.HasPrefix(line, "<-- "): + t.Log(line) + want := line[4:] + // read line from connection and compare text + clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)) + sent, err := readbuf.ReadString('\n') + if err != nil { + t.Fatalf("read error: %v", err) + } + sent = strings.TrimRight(sent, "\r\n") + if sent != want { + t.Errorf("wrong line from server\ngot: %s\nwant: %s", sent, want) + } + default: + panic("invalid line in test script: " + line) + } } +} - response := jsonSuccessResponse{Result: &Result{}} - if err := in.Decode(&response); err != nil { - t.Fatal(err) - } +// This test checks that responses are delivered for very short-lived connections that +// only carry a single request. +func TestServerShortLivedConn(t *testing.T) { + server := newTestServer() + defer server.Stop() - if result, ok := response.Result.(*Result); ok { - if result.String != stringArg { - t.Errorf("expected %s, got : %s\n", stringArg, result.String) + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("can't listen:", err) + } + defer listener.Close() + go server.ServeListener(listener) + + var ( + request = `{"jsonrpc":"2.0","id":1,"method":"rpc_modules"}` + "\n" + wantResp = `{"jsonrpc":"2.0","id":1,"result":{"nftest":"1.0","rpc":"1.0","test":"1.0"}}` + "\n" + deadline = time.Now().Add(10 * time.Second) + ) + for i := 0; i < 20; i++ { + conn, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + t.Fatal("can't dial:", err) } - if result.Int != intArg { - t.Errorf("expected %d, got %d\n", intArg, result.Int) + defer conn.Close() + conn.SetDeadline(deadline) + // Write the request, then half-close the connection so the server stops reading. + conn.Write([]byte(request)) + conn.(*net.TCPConn).CloseWrite() + // Now try to get the response. 
+		buf := make([]byte, 2000)
+		n, err := conn.Read(buf)
+		if err != nil {
+			t.Fatal("read error:", err)
+		}
+		if !bytes.Equal(buf[:n], []byte(wantResp)) {
+			t.Fatalf("wrong response: %s", buf[:n])
+		}
+	}
+}
-
-func TestServerMethodExecution(t *testing.T) {
-	testServerMethodExecution(t, "echo")
-}
-
-func TestServerMethodWithCtx(t *testing.T) {
-	testServerMethodExecution(t, "echoWithCtx")
-}
diff --git a/rpc/subscription_test.go b/rpc/subscription_test.go
index 0ba177e63b..87ab4120af 100644
--- a/rpc/subscription_test.go
+++ b/rpc/subscription_test.go
@@ -17,314 +17,201 @@
 package rpc
 
 import (
-	"context"
 	"encoding/json"
 	"fmt"
 	"net"
-	"sync"
+	"strings"
 	"testing"
 	"time"
 )
 
-type NotificationTestService struct {
-	mu           sync.Mutex
-	unsubscribed bool
-
-	gotHangSubscriptionReq  chan struct{}
-	unblockHangSubscription chan struct{}
-}
-
-func (s *NotificationTestService) Echo(i int) int {
-	return i
-}
-
-func (s *NotificationTestService) wasUnsubCallbackCalled() bool {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	return s.unsubscribed
-}
-
-func (s *NotificationTestService) Unsubscribe(subid string) {
-	s.mu.Lock()
-	s.unsubscribed = true
-	s.mu.Unlock()
-}
-
-func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val int) (*Subscription, error) {
-	notifier, supported := NotifierFromContext(ctx)
-	if !supported {
-		return nil, ErrNotificationsUnsupported
-	}
-
-	// by explicitly creating a subscription we make sure that the subscription id is sent back to the client
-	// before the first subscription.Notify is called. Otherwise the events might be sent before the response
-	// for the eth_subscribe method.
-	subscription := notifier.CreateSubscription()
-
-	go func() {
-		// the test expects n events; if we begin sending events immediately some events
-		// will probably be dropped since the subscription ID might not have been sent to
-		// the client yet.
-		time.Sleep(5 * time.Second)
-		for i := 0; i < n; i++ {
-			if err := notifier.Notify(subscription.ID, val+i); err != nil {
-				return
-			}
-		}
-
-		select {
-		case <-notifier.Closed():
-			s.mu.Lock()
-			s.unsubscribed = true
-			s.mu.Unlock()
-		case <-subscription.Err():
-			s.mu.Lock()
-			s.unsubscribed = true
-			s.mu.Unlock()
-		}
-	}()
-
-	return subscription, nil
-}
-
-// HangSubscription blocks on s.unblockHangSubscription before
-// sending anything.
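The server half of the subscription pattern, as the deleted SomeSubscription above and HangSubscription just below both show, is: fetch the notifier from the call context, create the subscription so its ID reaches the client first, then push values from a goroutine. A condensed hedged sketch (the method name and int payload are illustrative):

	func (s *notificationTestService) CountSubscription(ctx context.Context, n int) (*Subscription, error) {
		notifier, supported := NotifierFromContext(ctx)
		if !supported {
			return nil, ErrNotificationsUnsupported
		}
		// Create the subscription before notifying, so the client sees
		// the subscription ID before the first event.
		sub := notifier.CreateSubscription()
		go func() {
			for i := 0; i < n; i++ {
				if err := notifier.Notify(sub.ID, i); err != nil {
					return // client went away
				}
			}
		}()
		return sub, nil
	}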
-func (s *NotificationTestService) HangSubscription(ctx context.Context, val int) (*Subscription, error) {
-	notifier, supported := NotifierFromContext(ctx)
-	if !supported {
-		return nil, ErrNotificationsUnsupported
-	}
-
-	s.gotHangSubscriptionReq <- struct{}{}
-	<-s.unblockHangSubscription
-	subscription := notifier.CreateSubscription()
-
-	go func() {
-		notifier.Notify(subscription.ID, val)
-	}()
-	return subscription, nil
-}
-
-func TestNotifications(t *testing.T) {
-	server := NewServer()
-	service := &NotificationTestService{}
-
-	if err := server.RegisterName("eth", service); err != nil {
-		t.Fatalf("unable to register test service %v", err)
-	}
-
-	clientConn, serverConn := net.Pipe()
-
-	go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions)
-	out := json.NewEncoder(clientConn)
-	in := json.NewDecoder(clientConn)
-
-	n := 5
-	val := 12345
-	request := map[string]interface{}{
-		"id":      1,
-		"method":  "eth_subscribe",
-		"version": "2.0",
-		"params":  []interface{}{"someSubscription", n, val},
-	}
-
-	// create subscription
-	if err := out.Encode(request); err != nil {
-		t.Fatal(err)
-	}
-
-	var subid string
-	response := jsonSuccessResponse{Result: subid}
-	if err := in.Decode(&response); err != nil {
-		t.Fatal(err)
-	}
-
-	var ok bool
-	if _, ok = response.Result.(string); !ok {
-		t.Fatalf("expected subscription id, got %T", response.Result)
-	}
-
-	for i := 0; i < n; i++ {
-		var notification jsonNotification
-		if err := in.Decode(&notification); err != nil {
-			t.Fatalf("%v", err)
-		}
-
-		if int(notification.Params.Result.(float64)) != val+i {
-			t.Fatalf("expected %d, got %d", val+i, notification.Params.Result)
-		}
-	}
-
-	clientConn.Close() // causes notification unsubscribe callback to be called
-	time.Sleep(1 * time.Second)
-
-	if !service.wasUnsubCallbackCalled() {
-		t.Error("unsubscribe callback not called after closing connection")
-	}
-}
-
+func TestNewID(t *testing.T) {
+	hexchars := "0123456789ABCDEFabcdef"
+	for i := 0; i < 100; i++ {
+		id := string(NewID())
+		if !strings.HasPrefix(id, "0x") {
+			t.Fatalf("invalid ID prefix, want '0x...', got %s", id)
+		}
+
+		id = id[2:]
+		if len(id) == 0 || len(id) > 32 {
+			t.Fatalf("invalid ID length, want len(id) > 0 && len(id) <= 32, got %d", len(id))
+		}
+
+		for i := 0; i < len(id); i++ {
+			if strings.IndexByte(hexchars, id[i]) == -1 {
+				t.Fatalf("unexpected byte, want any valid hex char, got %c", id[i])
+			}
+		}
+	}
+}
+
-func waitForMessages(t *testing.T, in *json.Decoder, successes chan<- jsonSuccessResponse,
-	failures chan<- jsonErrResponse, notifications chan<- jsonNotification, errors chan<- error) {
-
-	// read and parse server messages
-	for {
-		var rmsg json.RawMessage
-		if err := in.Decode(&rmsg); err != nil {
-			return
-		}
-
-		var responses []map[string]interface{}
-		if rmsg[0] == '[' {
-			if err := json.Unmarshal(rmsg, &responses); err != nil {
-				errors <- fmt.Errorf("Received invalid message: %s", rmsg)
-				return
-			}
-		} else {
-			var msg map[string]interface{}
-			if err := json.Unmarshal(rmsg, &msg); err != nil {
-				errors <- fmt.Errorf("Received invalid message: %s", rmsg)
-				return
-			}
-			responses = append(responses, msg)
-		}
-
-		for _, msg := range responses {
-			// determine what kind of msg was received and broadcast
-			// it over the corresponding channel
-			if _, found := msg["result"]; found {
-				successes <- jsonSuccessResponse{
-					Version: msg["jsonrpc"].(string),
-					Id:      msg["id"],
-					Result:  msg["result"],
-				}
-				continue
-			}
-			if _, found := msg["error"]; found {
-				params := msg["params"].(map[string]interface{})
-				failures <- jsonErrResponse{
-					Version: msg["jsonrpc"].(string),
-					Id:      msg["id"],
-					Error:   jsonError{int(params["subscription"].(float64)), params["message"].(string), params["data"]},
-				}
-				continue
-			}
-			if _, found := msg["params"]; found {
msg["params"]; found { - params := msg["params"].(map[string]interface{}) - notifications <- jsonNotification{ - Version: msg["jsonrpc"].(string), - Method: msg["method"].(string), - Params: jsonSubscription{params["subscription"].(string), params["result"]}, - } - continue - } - errors <- fmt.Errorf("Received invalid message: %s", msg) } } } -// TestSubscriptionMultipleNamespaces ensures that subscriptions can exists -// for multiple different namespaces. -func TestSubscriptionMultipleNamespaces(t *testing.T) { +func TestSubscriptions(t *testing.T) { var ( - namespaces = []string{"eth", "shh", "bzz"} + namespaces = []string{"eth", "bzz"} + service = ¬ificationTestService{} + subCount = len(namespaces) + notificationCount = 3 + server = NewServer() - service = NotificationTestService{} clientConn, serverConn = net.Pipe() - - out = json.NewEncoder(clientConn) - in = json.NewDecoder(clientConn) - successes = make(chan jsonSuccessResponse) - failures = make(chan jsonErrResponse) - notifications = make(chan jsonNotification) - - errors = make(chan error, 10) + out = json.NewEncoder(clientConn) + in = json.NewDecoder(clientConn) + successes = make(chan subConfirmation) + notifications = make(chan subscriptionResult) + errors = make(chan error, subCount*notificationCount+1) ) // setup and start server for _, namespace := range namespaces { - if err := server.RegisterName(namespace, &service); err != nil { + if err := server.RegisterName(namespace, service); err != nil { t.Fatalf("unable to register test service %v", err) } } - - go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions) + go server.ServeCodec(NewJSONCodec(serverConn), 0) defer server.Stop() // wait for message and write them to the given channels - go waitForMessages(t, in, successes, failures, notifications, errors) + go waitForMessages(in, successes, notifications, errors) // create subscriptions one by one - n := 3 for i, namespace := range namespaces { request := map[string]interface{}{ "id": i, "method": fmt.Sprintf("%s_subscribe", namespace), - "version": "2.0", - "params": []interface{}{"someSubscription", n, i}, + "jsonrpc": "2.0", + "params": []interface{}{"someSubscription", notificationCount, i}, } - if err := out.Encode(&request); err != nil { t.Fatalf("Could not create subscription: %v", err) } } - // create all subscriptions in 1 batch - var requests []interface{} - for i, namespace := range namespaces { - requests = append(requests, map[string]interface{}{ - "id": i, - "method": fmt.Sprintf("%s_subscribe", namespace), - "version": "2.0", - "params": []interface{}{"someSubscription", n, i}, - }) - } - - if err := out.Encode(&requests); err != nil { - t.Fatalf("Could not create subscription in batch form: %v", err) - } - timeout := time.After(30 * time.Second) - subids := make(map[string]string, 2*len(namespaces)) - count := make(map[string]int, 2*len(namespaces)) - - for { - done := true - for id := range count { - if count, found := count[id]; !found || count < (2*n) { + subids := make(map[string]string, subCount) + count := make(map[string]int, subCount) + allReceived := func() bool { + done := len(count) == subCount + for _, c := range count { + if c < notificationCount { done = false } } - - if done && len(count) == len(namespaces) { - break - } - + return done + } + for !allReceived() { select { + case confirmation := <-successes: // subscription created + subids[namespaces[confirmation.reqid]] = string(confirmation.subid) + case notification := <-notifications: + 
			count[notification.ID]++
 		case err := <-errors:
 			t.Fatal(err)
-		case suc := <-successes: // subscription created
-			subids[namespaces[int(suc.Id.(float64))]] = suc.Result.(string)
-		case failure := <-failures:
-			t.Errorf("received error: %v", failure.Error)
-		case notification := <-notifications:
-			if cnt, found := count[notification.Params.Subscription]; found {
-				count[notification.Params.Subscription] = cnt + 1
-			} else {
-				count[notification.Params.Subscription] = 1
-			}
 		case <-timeout:
 			for _, namespace := range namespaces {
 				subid, found := subids[namespace]
 				if !found {
-					t.Errorf("Subscription for '%s' not created", namespace)
+					t.Errorf("subscription for %q not created", namespace)
 					continue
 				}
-				if count, found := count[subid]; !found || count < n {
-					t.Errorf("Didn't receive all notifications (%d<%d) in time for namespace '%s'", count, n, namespace)
+				if count, found := count[subid]; !found || count < notificationCount {
+					t.Errorf("didn't receive all notifications (%d<%d) in time for namespace %q", count, notificationCount, namespace)
 				}
 			}
-			return
+			t.Fatal("timed out")
 		}
 	}
 }
+
+// This test checks that unsubscribing works.
+func TestServerUnsubscribe(t *testing.T) {
+	// Start the server.
+	server := newTestServer()
+	service := &notificationTestService{unsubscribed: make(chan string)}
+	server.RegisterName("nftest2", service)
+	p1, p2 := net.Pipe()
+	go server.ServeCodec(NewJSONCodec(p1), OptionMethodInvocation|OptionSubscriptions)
+
+	p2.SetDeadline(time.Now().Add(10 * time.Second))
+
+	// Subscribe.
+	p2.Write([]byte(`{"jsonrpc":"2.0","id":1,"method":"nftest2_subscribe","params":["someSubscription",0,10]}`))
+
+	// Handle received messages.
+	resps := make(chan subConfirmation)
+	notifications := make(chan subscriptionResult)
+	errors := make(chan error)
+	go waitForMessages(json.NewDecoder(p2), resps, notifications, errors)
+
+	// Receive the subscription ID.
+	var sub subConfirmation
+	select {
+	case sub = <-resps:
+	case err := <-errors:
+		t.Fatal(err)
+	}
+
+	// Unsubscribe and check that it is handled on the server side.
+	p2.Write([]byte(`{"jsonrpc":"2.0","method":"nftest2_unsubscribe","params":["` + sub.subid + `"]}`))
+	for {
+		select {
+		case id := <-service.unsubscribed:
+			if id != string(sub.subid) {
+				t.Errorf("wrong subscription ID unsubscribed")
+			}
+			return
+		case err := <-errors:
+			t.Fatal(err)
+		case <-notifications:
+			// drop notifications
+		}
+	}
+}
+
+type subConfirmation struct {
+	reqid int
+	subid ID
+}
+
+// waitForMessages reads RPC messages from 'in' and dispatches them into the given channels.
+// It stops if there is an error.
+func waitForMessages(in *json.Decoder, successes chan subConfirmation, notifications chan subscriptionResult, errors chan error) { + for { + resp, notification, err := readAndValidateMessage(in) + if err != nil { + errors <- err + return + } else if resp != nil { + successes <- *resp + } else { + notifications <- *notification + } + } +} + +func readAndValidateMessage(in *json.Decoder) (*subConfirmation, *subscriptionResult, error) { + var msg jsonrpcMessage + if err := in.Decode(&msg); err != nil { + return nil, nil, fmt.Errorf("decode error: %v", err) + } + switch { + case msg.isNotification(): + var res subscriptionResult + if err := json.Unmarshal(msg.Params, &res); err != nil { + return nil, nil, fmt.Errorf("invalid subscription result: %v", err) + } + return nil, &res, nil + case msg.isResponse(): + var c subConfirmation + if msg.Error != nil { + return nil, nil, msg.Error + } else if err := json.Unmarshal(msg.Result, &c.subid); err != nil { + return nil, nil, fmt.Errorf("invalid response: %v", err) + } else { + json.Unmarshal(msg.ID, &c.reqid) + return &c, nil, nil } + default: + return nil, nil, fmt.Errorf("unrecognized message: %v", msg) } } diff --git a/rpc/testdata/invalid-badid.js b/rpc/testdata/invalid-badid.js new file mode 100644 index 0000000000..2202b8ccd2 --- /dev/null +++ b/rpc/testdata/invalid-badid.js @@ -0,0 +1,7 @@ +// This test checks processing of messages with invalid ID. + +--> {"id":[],"method":"test_foo"} +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}} + +--> {"id":{},"method":"test_foo"} +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}} diff --git a/rpc/testdata/invalid-badversion.js b/rpc/testdata/invalid-badversion.js new file mode 100644 index 0000000000..75b5291dc3 --- /dev/null +++ b/rpc/testdata/invalid-badversion.js @@ -0,0 +1,19 @@ +// This test checks processing of messages with invalid Version. + +--> {"jsonrpc":"2.0","id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"result":{"String":"x","Int":3,"Args":null}} + +--> {"jsonrpc":"2.1","id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":"go-ethereum","id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":1,"id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":2.0,"id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"id":1,"method":"test_echo","params":["x", 3]} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} diff --git a/rpc/testdata/invalid-batch.js b/rpc/testdata/invalid-batch.js new file mode 100644 index 0000000000..9e6e27e207 --- /dev/null +++ b/rpc/testdata/invalid-batch.js @@ -0,0 +1,17 @@ +// This test checks the behavior of batches with invalid elements. +// Empty batches are not allowed. Batches may contain junk. 
+ +--> [] +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"empty batch"}} + +--> [1] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}] + +--> [1,2,3] +<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}] + +// --> [null] +// <-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}] + +--> [{"jsonrpc":"2.0","id":1,"method":"test_echo","params":["foo",1]},55,{"jsonrpc":"2.0","id":2,"method":"unknown_method"},{"foo":"bar"}] +<-- [{"jsonrpc":"2.0","id":1,"result":{"String":"foo","Int":1,"Args":null}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}},{"jsonrpc":"2.0","id":2,"error":{"code":-32601,"message":"the method unknown_method does not exist/is not available"}},{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}}] \ No newline at end of file diff --git a/rpc/testdata/invalid-idonly.js b/rpc/testdata/invalid-idonly.js new file mode 100644 index 0000000000..e1983889f9 --- /dev/null +++ b/rpc/testdata/invalid-idonly.js @@ -0,0 +1,7 @@ +// This test checks processing of messages that contain just the ID and nothing else. + +--> {"id":1} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} + +--> {"jsonrpc":"2.0","id":1} +<-- {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"invalid request"}} \ No newline at end of file diff --git a/rpc/testdata/invalid-nonobj.js b/rpc/testdata/invalid-nonobj.js new file mode 100644 index 0000000000..c7fc43a5b8 --- /dev/null +++ b/rpc/testdata/invalid-nonobj.js @@ -0,0 +1,7 @@ +// This test checks behavior for invalid requests. + +--> 1 +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}} + +// --> null +// <-- {"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"invalid request"}} diff --git a/rpc/testdata/invalid-syntax.json b/rpc/testdata/invalid-syntax.json new file mode 100644 index 0000000000..b194299603 --- /dev/null +++ b/rpc/testdata/invalid-syntax.json @@ -0,0 +1,5 @@ +// This test checks that an error is written for invalid JSON requests. + +--> 'f +<-- {"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"invalid character '\\'' looking for beginning of value"}} + diff --git a/rpc/testdata/reqresp-batch.js b/rpc/testdata/reqresp-batch.js new file mode 100644 index 0000000000..977af76630 --- /dev/null +++ b/rpc/testdata/reqresp-batch.js @@ -0,0 +1,8 @@ +// There is no response for all-notification batches. + +--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] + +// This test checks regular batch calls. + +--> [{"jsonrpc":"2.0","id":2,"method":"test_echo","params":[]}, {"jsonrpc":"2.0","id": 3,"method":"test_echo","params":["x",3]}] +<-- [{"jsonrpc":"2.0","id":2,"error":{"code":-32602,"message":"missing value for required argument 0"}},{"jsonrpc":"2.0","id":3,"result":{"String":"x","Int":3,"Args":null}}] diff --git a/rpc/testdata/reqresp-echo.js b/rpc/testdata/reqresp-echo.js new file mode 100644 index 0000000000..7a9e90321c --- /dev/null +++ b/rpc/testdata/reqresp-echo.js @@ -0,0 +1,16 @@ +// This test calls the test_echo method. 
+ +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": []} +<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32602,"message":"missing value for required argument 0"}} + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": ["x"]} +<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32602,"message":"missing value for required argument 1"}} + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": ["x", 3]} +<-- {"jsonrpc":"2.0","id":2,"result":{"String":"x","Int":3,"Args":null}} + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echo", "params": ["x", 3, {"S": "foo"}]} +<-- {"jsonrpc":"2.0","id":2,"result":{"String":"x","Int":3,"Args":{"S":"foo"}}} + +--> {"jsonrpc": "2.0", "id": 2, "method": "test_echoWithCtx", "params": ["x", 3, {"S": "foo"}]} +<-- {"jsonrpc":"2.0","id":2,"result":{"String":"x","Int":3,"Args":{"S":"foo"}}} diff --git a/rpc/testdata/reqresp-namedparam.js b/rpc/testdata/reqresp-namedparam.js new file mode 100644 index 0000000000..9a9372b0a7 --- /dev/null +++ b/rpc/testdata/reqresp-namedparam.js @@ -0,0 +1,5 @@ +// This test checks that an error response is sent for calls +// with named parameters. + +--> {"jsonrpc":"2.0","method":"test_echo","params":{"int":23},"id":3} +<-- {"jsonrpc":"2.0","id":3,"error":{"code":-32602,"message":"non-array args"}} diff --git a/rpc/testdata/reqresp-noargsrets.js b/rpc/testdata/reqresp-noargsrets.js new file mode 100644 index 0000000000..e61cc708ba --- /dev/null +++ b/rpc/testdata/reqresp-noargsrets.js @@ -0,0 +1,4 @@ +// This test calls the test_noArgsRets method. + +--> {"jsonrpc": "2.0", "id": "foo", "method": "test_noArgsRets", "params": []} +<-- {"jsonrpc":"2.0","id":"foo","result":null} diff --git a/rpc/testdata/reqresp-nomethod.js b/rpc/testdata/reqresp-nomethod.js new file mode 100644 index 0000000000..58ea6f3079 --- /dev/null +++ b/rpc/testdata/reqresp-nomethod.js @@ -0,0 +1,4 @@ +// This test calls a method that doesn't exist. + +--> {"jsonrpc": "2.0", "id": 2, "method": "invalid_method", "params": [2, 3]} +<-- {"jsonrpc":"2.0","id":2,"error":{"code":-32601,"message":"the method invalid_method does not exist/is not available"}} diff --git a/rpc/testdata/reqresp-noparam.js b/rpc/testdata/reqresp-noparam.js new file mode 100644 index 0000000000..2edf486d9f --- /dev/null +++ b/rpc/testdata/reqresp-noparam.js @@ -0,0 +1,4 @@ +// This test checks that calls with no parameters work. + +--> {"jsonrpc":"2.0","method":"test_noArgsRets","id":3} +<-- {"jsonrpc":"2.0","id":3,"result":null} diff --git a/rpc/testdata/reqresp-paramsnull.js b/rpc/testdata/reqresp-paramsnull.js new file mode 100644 index 0000000000..8a01bae1bb --- /dev/null +++ b/rpc/testdata/reqresp-paramsnull.js @@ -0,0 +1,4 @@ +// This test checks that calls with "params":null work. + +--> {"jsonrpc":"2.0","method":"test_noArgsRets","params":null,"id":3} +<-- {"jsonrpc":"2.0","id":3,"result":null} diff --git a/rpc/testdata/revcall.js b/rpc/testdata/revcall.js new file mode 100644 index 0000000000..695d9858f8 --- /dev/null +++ b/rpc/testdata/revcall.js @@ -0,0 +1,6 @@ +// This test checks reverse calls. + +--> {"jsonrpc":"2.0","id":2,"method":"test_callMeBack","params":["foo",[1]]} +<-- {"jsonrpc":"2.0","id":1,"method":"foo","params":[1]} +--> {"jsonrpc":"2.0","id":1,"result":"my result"} +<-- {"jsonrpc":"2.0","id":2,"result":"my result"} diff --git a/rpc/testdata/revcall2.js b/rpc/testdata/revcall2.js new file mode 100644 index 0000000000..acab46551e --- /dev/null +++ b/rpc/testdata/revcall2.js @@ -0,0 +1,7 @@ +// This test checks reverse calls. 
+ +--> {"jsonrpc":"2.0","id":2,"method":"test_callMeBackLater","params":["foo",[1]]} +<-- {"jsonrpc":"2.0","id":2,"result":null} +<-- {"jsonrpc":"2.0","id":1,"method":"foo","params":[1]} +--> {"jsonrpc":"2.0","id":1,"result":"my result"} + diff --git a/rpc/testdata/subscription.js b/rpc/testdata/subscription.js new file mode 100644 index 0000000000..9f10073010 --- /dev/null +++ b/rpc/testdata/subscription.js @@ -0,0 +1,12 @@ +// This test checks basic subscription support. + +--> {"jsonrpc":"2.0","id":1,"method":"nftest_subscribe","params":["someSubscription",5,1]} +<-- {"jsonrpc":"2.0","id":1,"result":"0x1"} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":1}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":2}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":3}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":4}} +<-- {"jsonrpc":"2.0","method":"nftest_subscription","params":{"subscription":"0x1","result":5}} + +--> {"jsonrpc":"2.0","id":2,"method":"nftest_echo","params":[11]} +<-- {"jsonrpc":"2.0","id":2,"result":11} diff --git a/rpc/testservice_test.go b/rpc/testservice_test.go new file mode 100644 index 0000000000..470870bacf --- /dev/null +++ b/rpc/testservice_test.go @@ -0,0 +1,180 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package rpc + +import ( + "context" + "encoding/binary" + "errors" + "sync" + "time" +) + +func newTestServer() *Server { + server := NewServer() + server.idgen = sequentialIDGenerator() + if err := server.RegisterName("test", new(testService)); err != nil { + panic(err) + } + if err := server.RegisterName("nftest", new(notificationTestService)); err != nil { + panic(err) + } + return server +} + +func sequentialIDGenerator() func() ID { + var ( + mu sync.Mutex + counter uint64 + ) + return func() ID { + mu.Lock() + defer mu.Unlock() + counter++ + id := make([]byte, 8) + binary.BigEndian.PutUint64(id, counter) + return encodeID(id) + } +} + +type testService struct{} + +type Args struct { + S string +} + +type Result struct { + String string + Int int + Args *Args +} + +func (s *testService) NoArgsRets() {} + +func (s *testService) Echo(str string, i int, args *Args) Result { + return Result{str, i, args} +} + +func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args *Args) Result { + return Result{str, i, args} +} + +func (s *testService) Sleep(ctx context.Context, duration time.Duration) { + time.Sleep(duration) +} + +func (s *testService) Rets() (string, error) { + return "", nil +} + +func (s *testService) InvalidRets1() (error, string) { + return nil, "" +} + +func (s *testService) InvalidRets2() (string, string) { + return "", "" +} + +func (s *testService) InvalidRets3() (string, string, error) { + return "", "", nil +} + +func (s *testService) CallMeBack(ctx context.Context, method string, args []interface{}) (interface{}, error) { + c, ok := ClientFromContext(ctx) + if !ok { + return nil, errors.New("no client") + } + var result interface{} + err := c.Call(&result, method, args...) + return result, err +} + +func (s *testService) CallMeBackLater(ctx context.Context, method string, args []interface{}) error { + c, ok := ClientFromContext(ctx) + if !ok { + return errors.New("no client") + } + go func() { + <-ctx.Done() + var result interface{} + c.Call(&result, method, args...) + }() + return nil +} + +func (s *testService) Subscription(ctx context.Context) (*Subscription, error) { + return nil, nil +} + +type notificationTestService struct { + unsubscribed chan string + gotHangSubscriptionReq chan struct{} + unblockHangSubscription chan struct{} +} + +func (s *notificationTestService) Echo(i int) int { + return i +} + +func (s *notificationTestService) Unsubscribe(subid string) { + if s.unsubscribed != nil { + s.unsubscribed <- subid + } +} + +func (s *notificationTestService) SomeSubscription(ctx context.Context, n, val int) (*Subscription, error) { + notifier, supported := NotifierFromContext(ctx) + if !supported { + return nil, ErrNotificationsUnsupported + } + + // By explicitly creating a subscription we make sure that the subscription id is sent + // back to the client before the first subscription.Notify is called. Otherwise the + // events might be sent before the response for the *_subscribe method. + subscription := notifier.CreateSubscription() + go func() { + for i := 0; i < n; i++ { + if err := notifier.Notify(subscription.ID, val+i); err != nil { + return + } + } + select { + case <-notifier.Closed(): + case <-subscription.Err(): + } + if s.unsubscribed != nil { + s.unsubscribed <- string(subscription.ID) + } + }() + return subscription, nil +} + +// HangSubscription blocks on s.unblockHangSubscription before sending anything.
+func (s *notificationTestService) HangSubscription(ctx context.Context, val int) (*Subscription, error) { + notifier, supported := NotifierFromContext(ctx) + if !supported { + return nil, ErrNotificationsUnsupported + } + s.gotHangSubscriptionReq <- struct{}{} + <-s.unblockHangSubscription + subscription := notifier.CreateSubscription() + + go func() { + notifier.Notify(subscription.ID, val) + }() + return subscription, nil +} diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go new file mode 100644 index 0000000000..5bf3780d62 --- /dev/null +++ b/rpc/websocket_test.go @@ -0,0 +1,54 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rpc + +import "testing" + +func TestWSGetConfigNoAuth(t *testing.T) { + config, err := wsGetConfig("ws://example.com:1234", "") + if err != nil { + t.Logf("wsGetConfig failed: %s", err) + t.Fail() + return + } + if config.Location.User != nil { + t.Log("User should have been stripped from the URL") + t.Fail() + } + if config.Location.Hostname() != "example.com" || + config.Location.Port() != "1234" || config.Location.Scheme != "ws" { + t.Logf("Unexpected URL: %s", config.Location) + t.Fail() + } +} + +func TestWSGetConfigWithBasicAuth(t *testing.T) { + config, err := wsGetConfig("wss://testuser:test-PASS_01@example.com:1234", "") + if err != nil { + t.Logf("wsGetConfig failed: %s", err) + t.Fail() + return + } + if config.Location.User != nil { + t.Log("User should have been stripped from the URL") + t.Fail() + } + if config.Header.Get("Authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { + t.Log("Basic auth header is incorrect") + t.Fail() + } +} From 9fbea60f685f3bf2020479b52c38a2c8bed2a7fc Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 4 Dec 2023 15:21:11 +0700 Subject: [PATCH 108/119] Fix protocol unit tests --- eth/helper_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eth/helper_test.go b/eth/helper_test.go index 85f75f01d1..a4da7286e0 100644 --- a/eth/helper_test.go +++ b/eth/helper_test.go @@ -150,7 +150,7 @@ func newTestPeer(name string, version int, pm *ProtocolManager, shake bool) (*te // Generate a random id and create the peer var id enode.ID - rand.Read(id.Bytes()) + rand.Read(id[:]) peer := pm.newPeer(version, p2p.NewPeer(id, name, nil), net) From ba6b0db008f63dd8bff7cc709a87c61c8c29918d Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 4 Dec 2023 16:40:42 +0700 Subject: [PATCH 109/119] Use gorilla websocket instead of native lib --- go.mod | 11 +- go.sum | 22 ++-- rpc/client.go | 5 +- rpc/ipc_unix.go | 3 +- rpc/json.go | 17 ++- rpc/websocket.go | 204 +++++++++++++--------------------- rpc/websocket_test.go | 251 ++++++++++++++++++++++++++++++++++++++---- 7 files changed, 335 insertions(+), 178 deletions(-) diff --git a/go.mod b/go.mod index 6c16905987..3e3be311db 
100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/go-stack/stack v1.8.1 github.com/golang/protobuf v1.5.2 github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb + github.com/gorilla/websocket v1.5.1 github.com/hashicorp/golang-lru v0.5.3 github.com/holiman/uint256 v1.2.2 github.com/huin/goupnp v1.0.3 @@ -38,10 +39,10 @@ require ( github.com/steakknife/bloomfilter v0.0.0-20180922174646-6819c0d2a570 github.com/stretchr/testify v1.8.1 github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 - golang.org/x/crypto v0.1.0 - golang.org/x/net v0.8.0 + golang.org/x/crypto v0.14.0 + golang.org/x/net v0.17.0 golang.org/x/sync v0.1.0 - golang.org/x/sys v0.7.0 + golang.org/x/sys v0.13.0 golang.org/x/tools v0.7.0 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c gopkg.in/karalabe/cookiejar.v2 v2.0.0-20150724131613-8dcd6a7f4951 @@ -72,8 +73,8 @@ require ( github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3 // indirect golang.org/x/mod v0.11.0 // indirect - golang.org/x/term v0.6.0 // indirect - golang.org/x/text v0.8.0 // indirect + golang.org/x/term v0.13.0 // indirect + golang.org/x/text v0.13.0 // indirect golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 693c5630c6..991dbcfe3b 100644 --- a/go.sum +++ b/go.sum @@ -103,6 +103,8 @@ github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8q github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/golang-lru v0.5.3 h1:YPkqC67at8FYaadspW/6uE0COsBxS2656RLEr8Bppgk= github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= @@ -246,8 +248,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190404164418-38d8ce5564a5/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= -golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= @@ -263,8 +265,8 @@ golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net 
v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -297,20 +299,20 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= -golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= diff --git a/rpc/client.go b/rpc/client.go index 93ca384715..736edbafee 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -41,9 +41,8 @@ var ( const ( // Timeouts - tcpKeepAliveInterval = 30 * time.Second - defaultDialTimeout = 10 * time.Second // used if context has no deadline - subscribeTimeout = 5 * time.Second // overall timeout 
eth_subscribe, rpc_modules calls + defaultDialTimeout = 10 * time.Second // used if context has no deadline + subscribeTimeout = 5 * time.Second // overall timeout eth_subscribe, rpc_modules calls ) const ( diff --git a/rpc/ipc_unix.go b/rpc/ipc_unix.go index 0851ea61e1..1dab2f7a3c 100644 --- a/rpc/ipc_unix.go +++ b/rpc/ipc_unix.go @@ -14,6 +14,7 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . +//go:build darwin || dragonfly || freebsd || linux || nacl || netbsd || openbsd || solaris // +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris package rpc @@ -42,5 +43,5 @@ func ipcListen(endpoint string) (net.Listener, error) { // newIPCConnection will connect to a Unix socket on the given endpoint. func newIPCConnection(ctx context.Context, endpoint string) (net.Conn, error) { - return dialContext(ctx, "unix", endpoint) + return new(net.Dialer).DialContext(ctx, "unix", endpoint) } diff --git a/rpc/json.go b/rpc/json.go index 34c825c025..1fa075a5cf 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -145,6 +145,11 @@ type Conn interface { SetWriteDeadline(time.Time) error } +type deadlineCloser interface { + io.Closer + SetWriteDeadline(time.Time) error +} + // ConnRemoteAddr wraps the RemoteAddr operation, which returns a description // of the peer address of a connection. If a Conn also implements ConnRemoteAddr, this // description is used in log messages. @@ -169,12 +174,10 @@ type jsonCodec struct { decode func(v interface{}) error // decoder to allow multiple transports encMu sync.Mutex // guards the encoder encode func(v interface{}) error // encoder to allow multiple transports - conn Conn + conn deadlineCloser } -// NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based -// on explicitly given encoding and decoding methods. -func NewCodec(conn Conn, encode, decode func(v interface{}) error) ServerCodec { +func newCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec { codec := &jsonCodec{ closed: make(chan interface{}), encode: encode, @@ -187,12 +190,14 @@ func NewCodec(conn Conn, encode, decode func(v interface{}) error) ServerCodec { return codec } -// NewJSONCodec creates a new RPC server codec with support for JSON-RPC 2.0. +// NewJSONCodec creates a codec that reads from the given connection. If conn implements +// ConnRemoteAddr, log messages will use it to include the remote address of the +// connection. func NewJSONCodec(conn Conn) ServerCodec { enc := json.NewEncoder(conn) dec := json.NewDecoder(conn) dec.UseNumber() - return NewCodec(conn, enc.Encode, dec.Decode) + return newCodec(conn, enc.Encode, dec.Decode) } func (c *jsonCodec) RemoteAddr() string { diff --git a/rpc/websocket.go b/rpc/websocket.go index 43ef76959e..f70ec78c95 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -17,41 +17,33 @@ package rpc import ( - "bytes" "context" - "crypto/tls" "encoding/base64" - "encoding/json" - "errors" "fmt" - "net" "net/http" "net/url" "os" "strings" - "time" + "sync" mapset "github.com/deckarep/golang-set" - "golang.org/x/net/websocket" + "github.com/gorilla/websocket" "github.com/tomochain/tomochain/log" ) -// websocketJSONCodec is a custom JSON codec with payload size enforcement and -// special number parsing. -var websocketJSONCodec = websocket.Codec{ - // Marshal is the stock JSON marshaller used by the websocket library too. 
- Marshal: func(v interface{}) ([]byte, byte, error) { - msg, err := json.Marshal(v) - return msg, websocket.TextFrame, err - }, - // Unmarshal is a specialized unmarshaller to properly convert numbers. - Unmarshal: func(msg []byte, payloadType byte, v interface{}) error { - dec := json.NewDecoder(bytes.NewReader(msg)) - dec.UseNumber() - - return dec.Decode(v) - }, +const ( + wsReadBuffer = 1024 + wsWriteBuffer = 1024 +) + +var wsBufferPool = new(sync.Pool) + +// NewWSServer creates a new websocket RPC server around an API provider. +// +// Deprecated: use Server.WebsocketHandler +func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { + return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} } // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. @@ -59,49 +51,27 @@ var websocketJSONCodec = websocket.Codec{ // allowedOrigins should be a comma-separated list of allowed origin URLs. // To allow connections with any origin, pass "*". func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { - return websocket.Server{ - Handshake: wsHandshakeValidator(allowedOrigins), - Handler: func(conn *websocket.Conn) { - codec := newWebsocketCodec(conn) - s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions) - }, - } -} - -func newWebsocketCodec(conn *websocket.Conn) ServerCodec { - // Create a custom encode/decode pair to enforce payload size and number encoding - conn.MaxPayloadBytes = maxRequestContentLength - encoder := func(v interface{}) error { - return websocketJSONCodec.Send(conn, v) - } - decoder := func(v interface{}) error { - return websocketJSONCodec.Receive(conn, v) - } - rpcconn := Conn(conn) - if conn.IsServerConn() { - // Override remote address with the actual socket address because - // package websocket crashes if there is no request origin. - addr := conn.Request().RemoteAddr - if wsaddr := conn.RemoteAddr().(*websocket.Addr); wsaddr.URL != nil { - // Add origin if present. - addr += "(" + wsaddr.URL.String() + ")" + var upgrader = websocket.Upgrader{ + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + CheckOrigin: wsHandshakeValidator(allowedOrigins), + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Debug("WebSocket upgrade failed", "err", err) + return } - rpcconn = connWithRemoteAddr{conn, addr} - } - return NewCodec(rpcconn, encoder, decoder) -} - -// NewWSServer creates a new websocket RPC server around an API provider. -// -// Deprecated: use Server.WebsocketHandler -func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { - return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} + codec := newWebsocketCodec(conn) + s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions) + }) } // wsHandshakeValidator returns a handler that verifies the origin during the // websocket upgrade process. When a '*' is specified as an allowed origins all // connections are accepted. -func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error { +func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { origins := mapset.NewSet() allowAllOrigins := false @@ -113,7 +83,6 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http origins.Add(strings.ToLower(origin)) } } - // allow localhost if no allowedOrigins are specified. 
if len(origins.ToSlice()) == 0 { origins.Add("http://localhost") @@ -121,45 +90,39 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http origins.Add("http://" + strings.ToLower(hostname)) } } - log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) - f := func(cfg *websocket.Config, req *http.Request) error { + f := func(req *http.Request) bool { + // Skip origin verification if no Origin header is present. The origin check + // is supposed to protect against browser based attacks. Browsers always set + // Origin. Non-browser software can put anything in origin and checking it doesn't + // provide additional security. + if _, ok := req.Header["Origin"]; !ok { + return true + } // Verify origin against whitelist. origin := strings.ToLower(req.Header.Get("Origin")) if allowAllOrigins || origins.Contains(origin) { - return nil + return true } log.Warn("Rejected WebSocket connection", "origin", origin) - return errors.New("origin not allowed") + return false } return f } -func wsGetConfig(endpoint, origin string) (*websocket.Config, error) { - if origin == "" { - var err error - if origin, err = os.Hostname(); err != nil { - return nil, err - } - if strings.HasPrefix(endpoint, "wss") { - origin = "https://" + strings.ToLower(origin) - } else { - origin = "http://" + strings.ToLower(origin) - } - } - config, err := websocket.NewConfig(endpoint, origin) - if err != nil { - return nil, err - } +type wsHandshakeError struct { + err error + status string +} - if config.Location.User != nil { - b64auth := base64.StdEncoding.EncodeToString([]byte(config.Location.User.String())) - config.Header.Add("Authorization", "Basic "+b64auth) - config.Location.User = nil +func (e wsHandshakeError) Error() string { + s := e.err.Error() + if e.status != "" { + s += " (HTTP status " + e.status + ")" } - return config, nil + return s } // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server @@ -168,65 +131,46 @@ func wsGetConfig(endpoint, origin string) (*websocket.Config, error) { // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. 
func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { - config, err := wsGetConfig(endpoint, origin) + endpoint, header, err := wsClientHeaders(endpoint, origin) if err != nil { return nil, err } - + dialer := websocket.Dialer{ + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + } return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { - conn, err := wsDialContext(ctx, config) + conn, resp, err := dialer.DialContext(ctx, endpoint, header) if err != nil { - return nil, err + hErr := wsHandshakeError{err: err} + if resp != nil { + hErr.status = resp.Status + } + return nil, hErr } return newWebsocketCodec(conn), nil }) } -func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { - var conn net.Conn - var err error - switch config.Location.Scheme { - case "ws": - conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) - case "wss": - dialer := contextDialer(ctx) - conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) - default: - err = websocket.ErrBadScheme - } +func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { + endpointURL, err := url.Parse(endpoint) if err != nil { - return nil, err + return endpoint, nil, err } - ws, err := websocket.NewClient(config, conn) - if err != nil { - conn.Close() - return nil, err + header := make(http.Header) + if origin != "" { + header.Add("origin", origin) } - return ws, err -} - -var wsPortMap = map[string]string{"ws": "80", "wss": "443"} - -func wsDialAddress(location *url.URL) string { - if _, ok := wsPortMap[location.Scheme]; ok { - if _, _, err := net.SplitHostPort(location.Host); err != nil { - return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) - } + if endpointURL.User != nil { + b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) + header.Add("authorization", "Basic "+b64auth) + endpointURL.User = nil } - return location.Host -} - -func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { - d := &net.Dialer{KeepAlive: tcpKeepAliveInterval} - return d.DialContext(ctx, network, addr) + return endpointURL.String(), header, nil } -func contextDialer(ctx context.Context) *net.Dialer { - dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval} - if deadline, ok := ctx.Deadline(); ok { - dialer.Deadline = deadline - } else { - dialer.Deadline = time.Now().Add(defaultDialTimeout) - } - return dialer +func newWebsocketCodec(conn *websocket.Conn) ServerCodec { + conn.SetReadLimit(maxRequestContentLength) + return newCodec(conn, conn.WriteJSON, conn.ReadJSON) } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 5bf3780d62..a00e8da0f6 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -16,39 +16,244 @@ package rpc -import "testing" +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" -func TestWSGetConfigNoAuth(t *testing.T) { - config, err := wsGetConfig("ws://example.com:1234", "") + "github.com/gorilla/websocket" +) + +func TestWebsocketClientHeaders(t *testing.T) { + t.Parallel() + + endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com") if err != nil { - t.Logf("wsGetConfig failed: %s", err) - t.Fail() - return + t.Fatalf("wsGetConfig failed: %s", err) + } + if endpoint != "wss://example.com:1234" { + t.Fatal("User should have 
been stripped from the URL") + } + if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { + t.Fatal("Basic auth header is incorrect") + } + if header.Get("origin") != "https://example.com" { + t.Fatal("Origin not set") + } +} + +// This test checks that the server rejects connections from disallowed origins. +func TestWebsocketOriginCheck(t *testing.T) { + t.Parallel() + + var ( + srv = newTestServer() + httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) + wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + ) + defer srv.Stop() + defer httpsrv.Close() + + client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com") + if err == nil { + client.Close() + t.Fatal("no error for wrong origin") + } + wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"} + if !reflect.DeepEqual(err, wantErr) { + t.Fatalf("wrong error for wrong origin: %q", err) + } + + // Connections without origin header should work. + client, err = DialWebsocket(context.Background(), wsURL, "") + if err != nil { + t.Fatal("error for empty origin") + } + client.Close() +} + +// This test checks whether calls exceeding the request size limit are rejected. +func TestWebsocketLargeCall(t *testing.T) { + t.Parallel() + + var ( + srv = newTestServer() + httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) + wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + ) + defer srv.Stop() + defer httpsrv.Close() + + client, err := DialWebsocket(context.Background(), wsURL, "") + if err != nil { + t.Fatalf("can't dial: %v", err) + } + defer client.Close() + + // This call sends slightly less than the limit and should work. + var result Result + arg := strings.Repeat("x", maxRequestContentLength-200) + if err := client.Call(&result, "test_echo", arg, 1); err != nil { + t.Fatalf("valid call didn't work: %v", err) + } + if result.String != arg { + t.Fatal("wrong string echoed") + } + + // This call sends twice the allowed size and shouldn't work. + arg = strings.Repeat("x", maxRequestContentLength*2) + err = client.Call(&result, "test_echo", arg) + if err == nil { + t.Fatal("no error for too large call") } - if config.Location.User != nil { - t.Log("User should have been stripped from the URL") - t.Fail() +} + +// This test checks that client handles WebSocket ping frames correctly. +func TestClientWebsocketPing(t *testing.T) { + t.Parallel() + + var ( + sendPing = make(chan struct{}) + server = wsPingTestServer(t, sendPing) + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) + ) + defer cancel() + defer server.Shutdown(ctx) + + client, err := DialContext(ctx, "ws://"+server.Addr) + if err != nil { + t.Fatalf("client dial error: %v", err) + } + resultChan := make(chan int) + sub, err := client.EthSubscribe(ctx, resultChan, "foo") + if err != nil { + t.Fatalf("client subscribe error: %v", err) } - if config.Location.Hostname() != "example.com" || - config.Location.Port() != "1234" || config.Location.Scheme != "ws" { - t.Logf("Unexpected URL: %s", config.Location) - t.Fail() + + // Wait for the context's deadline to be reached before proceeding. + // This is important for reproducing https://github.com/ethereum/go-ethereum/issues/19798 + <-ctx.Done() + close(sendPing) + + // Wait for the subscription result. 
+ timeout := time.NewTimer(5 * time.Second) + for { + select { + case err := <-sub.Err(): + t.Error("client subscription error:", err) + case result := <-resultChan: + t.Log("client got result:", result) + return + case <-timeout.C: + t.Error("didn't get any result within the test timeout") + return + } } } -func TestWSGetConfigWithBasicAuth(t *testing.T) { - config, err := wsGetConfig("wss://testuser:test-PASS_01@example.com:1234", "") +// wsPingTestServer runs a WebSocket server which accepts a single subscription request. +// When a value arrives on sendPing, the server sends a ping frame, waits for a matching +// pong and finally delivers a single subscription result. +func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server { + var srv http.Server + shutdown := make(chan struct{}) + srv.RegisterOnShutdown(func() { + close(shutdown) + }) + srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Upgrade to WebSocket. + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("server WS upgrade error: %v", err) + return + } + defer conn.Close() + + // Handle the connection. + wsPingTestHandler(t, conn, shutdown, sendPing) + }) + + // Start the server. + listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { - t.Logf("wsGetConfig failed: %s", err) - t.Fail() + t.Fatal("can't listen:", err) + } + srv.Addr = listener.Addr().String() + go srv.Serve(listener) + return &srv +} + +func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) { + // Canned responses for the eth_subscribe call in TestClientWebsocketPing. + const ( + subResp = `{"jsonrpc":"2.0","id":1,"result":"0x00"}` + subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}` + ) + + // Handle subscribe request. + if _, _, err := conn.ReadMessage(); err != nil { + t.Errorf("server read error: %v", err) return } - if config.Location.User != nil { - t.Log("User should have been stripped from the URL") - t.Fail() + if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil { + t.Errorf("server write error: %v", err) + return } - if config.Header.Get("Authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { - t.Log("Basic auth header is incorrect") - t.Fail() + + // Read from the connection to process control messages. + var pongCh = make(chan string) + conn.SetPongHandler(func(d string) error { + t.Logf("server got pong: %q", d) + pongCh <- d + return nil + }) + go func() { + for { + typ, msg, err := conn.ReadMessage() + if err != nil { + return + } + t.Logf("server got message (%d): %q", typ, msg) + } + }() + + // Write messages. 
+ var ( + sendResponse <-chan time.Time + wantPong string + ) + for { + select { + case _, open := <-sendPing: + if !open { + sendPing = nil + } + t.Logf("server sending ping") + conn.WriteMessage(websocket.PingMessage, []byte("ping")) + wantPong = "ping" + case data := <-pongCh: + if wantPong == "" { + t.Errorf("unexpected pong") + } else if data != wantPong { + t.Errorf("got pong with wrong data %q", data) + } + wantPong = "" + sendResponse = time.NewTimer(200 * time.Millisecond).C + case <-sendResponse: + t.Logf("server sending response") + conn.WriteMessage(websocket.TextMessage, []byte(subNotify)) + sendResponse = nil + case <-shutdown: + conn.Close() + return + } } } From 772a53b56eda299a7213329d9cf418e3a7355e7c Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 4 Dec 2023 16:41:05 +0700 Subject: [PATCH 110/119] Increase HTTP/Websocket request size limit to 5MB --- rpc/http.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rpc/http.go b/rpc/http.go index 640a8460c1..c8a097e9c3 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -37,7 +37,7 @@ import ( ) const ( - maxRequestContentLength = 1024 * 512 + maxRequestContentLength = 1024 * 1024 * 5 contentType = "application/json" ) From ca97f69b0ed1c0d262ce276047da088bda015da7 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Mon, 4 Dec 2023 16:41:47 +0700 Subject: [PATCH 111/119] Update ETHstats --- ethstats/ethstats.go | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/ethstats/ethstats.go b/ethstats/ethstats.go index 69825c5a10..237e2cb9c6 100644 --- a/ethstats/ethstats.go +++ b/ethstats/ethstats.go @@ -23,13 +23,15 @@ import ( "errors" "fmt" "math/big" - "net" + "net/http" "regexp" "runtime" "strconv" "strings" "time" + "github.com/gorilla/websocket" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/mclock" "github.com/tomochain/tomochain/consensus" @@ -41,7 +43,6 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/p2p" "github.com/tomochain/tomochain/rpc" - "golang.org/x/net/websocket" ) const ( @@ -202,21 +203,21 @@ func (s *Service) loop() { path := fmt.Sprintf("%s/api", s.host) urls := []string{path} - if !strings.Contains(path, "://") { // url.Parse and url.IsAbs is unsuitable (https://github.com/golang/go/issues/19779) + // url.Parse and url.IsAbs is unsuitable (https://github.com/golang/go/issues/19779) + if !strings.Contains(path, "://") { urls = []string{"wss://" + path, "ws://" + path} } // Establish a websocket connection to the server on any supported URL var ( - conf *websocket.Config conn *websocket.Conn err error ) + dialer := websocket.Dialer{HandshakeTimeout: 5 * time.Second} + header := make(http.Header) + header.Set("origin", "http://localhost") for _, url := range urls { - if conf, err = websocket.NewConfig(url, "http://localhost/"); err != nil { - continue - } - conf.Dialer = &net.Dialer{Timeout: 5 * time.Second} - if conn, err = websocket.DialConfig(conf); err == nil { + conn, _, err = dialer.Dial(url, header) + if err == nil { break } } @@ -286,7 +287,7 @@ func (s *Service) readLoop(conn *websocket.Conn) { for { // Retrieve the next generic network packet and bail out on error var msg map[string][]interface{} - if err := websocket.JSON.Receive(conn, &msg); err != nil { + if err := conn.ReadJSON(&msg); err != nil { log.Warn("Failed to decode stats server message", "err", err) return } @@ -401,12 +402,12 @@ func (s *Service) login(conn *websocket.Conn) error { login := 
map[string][]interface{}{ + "emit": {"hello", auth}, } - if err := websocket.JSON.Send(conn, login); err != nil { + if err := conn.WriteJSON(login); err != nil { return err } // Retrieve the remote ack or connection termination var ack map[string][]string - if err := websocket.JSON.Receive(conn, &ack); err != nil || len(ack["emit"]) != 1 || ack["emit"][0] != "ready" { + if err := conn.ReadJSON(&ack); err != nil || len(ack["emit"]) != 1 || ack["emit"][0] != "ready" { return errors.New("unauthorized") } return nil @@ -443,7 +444,7 @@ func (s *Service) reportLatency(conn *websocket.Conn) error { "clientTime": start.String(), }}, } - if err := websocket.JSON.Send(conn, ping); err != nil { + if err := conn.WriteJSON(ping); err != nil { return err } // Wait for the pong request to arrive back @@ -465,7 +466,7 @@ func (s *Service) reportLatency(conn *websocket.Conn) error { "latency": latency, }}, } - return websocket.JSON.Send(conn, stats) + return conn.WriteJSON(stats) } // blockStats is the information to report about individual blocks. @@ -516,7 +517,7 @@ func (s *Service) reportBlock(conn *websocket.Conn, block *types.Block) error { report := map[string][]interface{}{ "emit": {"block", stats}, } - return websocket.JSON.Send(conn, report) + return conn.WriteJSON(report) } // assembleBlockStats retrieves any required metadata to report a single block @@ -630,7 +631,7 @@ func (s *Service) reportHistory(conn *websocket.Conn, list []uint64) error { report := map[string][]interface{}{ "emit": {"history", stats}, } - return websocket.JSON.Send(conn, report) + return conn.WriteJSON(report) } // pendStats is the information to report about pending transactions. @@ -660,7 +661,7 @@ func (s *Service) reportPending(conn *websocket.Conn) error { report := map[string][]interface{}{ "emit": {"pending", stats}, } - return websocket.JSON.Send(conn, report) + return conn.WriteJSON(report) } // nodeStats is the information to report about the local node. @@ -715,5 +716,5 @@ func (s *Service) reportStats(conn *websocket.Conn) error { report := map[string][]interface{}{ "emit": {"stats", stats}, } - return websocket.JSON.Send(conn, report) + return conn.WriteJSON(report) } From c3bb18b26a0815a872fc88812393bcaaff0ddf7d Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 5 Dec 2023 11:44:26 +0700 Subject: [PATCH 112/119] Improve codec abstraction and use gorilla/websocket --- node/node.go | 23 +++++++++++++ rpc/client.go | 12 +++---- rpc/doc.go | 79 +++++++++++++++++----------------- rpc/handler.go | 10 +++--- rpc/http.go | 26 +++++++-------- rpc/inproc.go | 4 +-- rpc/ipc.go | 4 +-- rpc/json.go | 60 +++++++++++++++++----------------- rpc/server.go | 10 +++--- rpc/stdio.go | 20 +++++++++--- rpc/subscription.go | 4 +-- rpc/types.go | 10 +++--- rpc/websocket.go | 4 +-- 13 files changed, 153 insertions(+), 113 deletions(-) diff --git a/node/node.go b/node/node.go index 21fa842604..b40a53af47 100644 --- a/node/node.go +++ b/node/node.go @@ -123,6 +123,29 @@ func New(conf *Config) (*Node, error) { }, nil } +// Close stops the Node and releases resources acquired in +// Node constructor New.
+func (n *Node) Close() error { + var errs []error + + // Terminate all subsystems and collect any errors + if err := n.Stop(); err != nil && err != ErrNodeStopped { + errs = append(errs, err) + } + if err := n.accman.Close(); err != nil { + errs = append(errs, err) + } + // Report any errors that might have occurred + switch len(errs) { + case 0: + return nil + case 1: + return errs[0] + default: + return fmt.Errorf("%v", errs) + } +} + // Register injects a new service into the node's stack. The service created by // the passed constructor must be unique in its type with regard to sibling ones. func (n *Node) Register(constructor ServiceConstructor) error { diff --git a/rpc/client.go b/rpc/client.go index 736edbafee..5da0874897 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -117,7 +117,7 @@ func (c *Client) newClientConn(conn ServerCodec) *clientConn { func (cc *clientConn) close(err error, inflightReq *requestOp) { cc.handler.close(err, inflightReq) - cc.codec.Close() + cc.codec.close() } type readOp struct { @@ -482,7 +482,7 @@ func (c *Client) write(ctx context.Context, msg interface{}) error { return err } } - err := c.writeConn.Write(ctx, msg) + err := c.writeConn.writeJSON(ctx, msg) if err != nil { c.writeConn = nil } @@ -509,7 +509,7 @@ func (c *Client) reconnect(ctx context.Context) error { c.writeConn = newconn return nil case <-c.didClose: - newconn.Close() + newconn.close() return ErrClientQuit } } @@ -556,7 +556,7 @@ func (c *Client) dispatch(codec ServerCodec) { // Reconnect: case newcodec := <-c.reconnected: - log.Debug("RPC client reconnected", "reading", reading, "conn", newcodec.RemoteAddr()) + log.Debug("RPC client reconnected", "reading", reading, "conn", newcodec.remoteAddr()) if reading { // Wait for the previous read loop to exit. This is a rare case which // happens if this loop isn't notified in time after the connection breaks. @@ -610,9 +610,9 @@ func (c *Client) drainRead() { // read decodes RPC messages from a codec, feeding them into dispatch. func (c *Client) read(codec ServerCodec) { for { - msgs, batch, err := codec.Read() + msgs, batch, err := codec.readBatch() if _, ok := err.(*json.SyntaxError); ok { - codec.Write(context.Background(), errorMessage(&parseError{err.Error()})) + codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()})) } if err != nil { c.readErr <- err diff --git a/rpc/doc.go b/rpc/doc.go index 14b3780ade..9cd0e4c7ec 100644 --- a/rpc/doc.go +++ b/rpc/doc.go @@ -22,14 +22,15 @@ conventions can be called remotely. It also has support for the publish/subscribe pattern. Methods that satisfy the following criteria are made available for remote access: - - object must be exported - - method must be exported - - method returns 0, 1 (response or error) or 2 (response and error) values - - method argument(s) must be exported or builtin types - - method returned value(s) must be exported or builtin types + - object must be exported + - method must be exported + - method returns 0, 1 (response or error) or 2 (response and error) values + - method argument(s) must be exported or builtin types + - method returned value(s) must be exported or builtin types An example method: - func (s *CalcService) Add(a, b int) (int, error) + + func (s *CalcService) Add(a, b int) (int, error) When the returned error isn't nil the returned integer is ignored and the error is sent back to the client. Otherwise the returned integer is sent back to the client.
@@ -38,7 +39,7 @@ Optional arguments are supported by accepting pointer values as arguments. E.g. if we want to do the addition in an optional finite field we can accept a mod argument as pointer value. - func (s *CalService) Add(a, b int, mod *int) (int, error) + func (s *CalService) Add(a, b int, mod *int) (int, error) This RPC method can be called with 2 integers and a null value as third argument. In that case the mod argument will be nil. Or it can be called with 3 integers, @@ -52,47 +53,49 @@ client using the codec. The server can execute requests concurrently. Responses can be sent back to the client out of order. An example server which uses the JSON codec: - type CalculatorService struct {} - func (s *CalculatorService) Add(a, b int) int { - return a + b - } + type CalculatorService struct {} - func (s *CalculatorService Div(a, b int) (int, error) { - if b == 0 { - return 0, errors.New("divide by zero") - } - return a/b, nil - } + func (s *CalculatorService) Add(a, b int) int { + return a + b + } + + func (s *CalculatorService) Div(a, b int) (int, error) { + if b == 0 { + return 0, errors.New("divide by zero") + } + return a/b, nil + } - calculator := new(CalculatorService) - server := NewServer() - server.RegisterName("calculator", calculator") + calculator := new(CalculatorService) + server := NewServer() + server.RegisterName("calculator", calculator) - l, _ := net.ListenUnix("unix", &net.UnixAddr{Net: "unix", Name: "/tmp/calculator.sock"}) - for { - c, _ := l.AcceptUnix() - codec := v2.NewJSONCodec(c) - go server.ServeCodec(codec) - } + l, _ := net.ListenUnix("unix", &net.UnixAddr{Net: "unix", Name: "/tmp/calculator.sock"}) + for { + c, _ := l.AcceptUnix() + codec := v2.NewJSONCodec(c) + go server.ServeCodec(codec) + } The package also supports the publish subscribe pattern through the use of subscriptions. A method that is considered eligible for notifications must satisfy the following criteria: - - object must be exported - - method must be exported - - first method argument type must be context.Context - - method argument(s) must be exported or builtin types - - method must return the tuple Subscription, error + - object must be exported + - method must be exported + - first method argument type must be context.Context + - method argument(s) must be exported or builtin types + - method must return the tuple Subscription, error An example method: - func (s *BlockChainService) NewBlocks(ctx context.Context) (Subscription, error) { - ... - } + + func (s *BlockChainService) NewBlocks(ctx context.Context) (Subscription, error) { + ... + } Subscriptions are deleted when: - - the user sends an unsubscribe request - - the connection which was used to create the subscription is closed. This can be initiated - by the client and server. The server will close the connection on an write error or when - the queue of buffered notifications gets too big. + - the user sends an unsubscribe request + - the connection which was used to create the subscription is closed. This can be initiated + by the client and server. The server will close the connection on a write error or when + the queue of buffered notifications gets too big.
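+
+ A client-side sketch of consuming such a subscription; the endpoint, the
+ namespace and the subscription name are illustrative:
+
+ client, _ := Dial("ws://127.0.0.1:8546")
+ blocks := make(chan Block)
+ sub, err := client.Subscribe(context.Background(), "blockChain", blocks, "newBlocks")
+ if err != nil { ... }
+ defer sub.Unsubscribe()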
*/ package rpc diff --git a/rpc/handler.go b/rpc/handler.go index 3507dd8411..d4f0cee36d 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -84,8 +84,8 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg * serverSubs: make(map[ID]*Subscription), log: log.Root(), } - if conn.RemoteAddr() != "" { - h.log = h.log.New("conn", conn.RemoteAddr()) + if conn.remoteAddr() != "" { + h.log = h.log.New("conn", conn.remoteAddr()) } h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe)) return h @@ -96,7 +96,7 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { // Emit error response for empty batches: if len(msgs) == 0 { h.startCallProc(func(cp *callProc) { - h.conn.Write(cp.ctx, errorMessage(&invalidRequestError{"empty batch"})) + h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"})) }) return } @@ -121,7 +121,7 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { } h.addSubscriptions(cp.notifiers) if len(answers) > 0 { - h.conn.Write(cp.ctx, answers) + h.conn.writeJSON(cp.ctx, answers) } for _, n := range cp.notifiers { n.activate() @@ -138,7 +138,7 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) { answer := h.handleCallMsg(cp, msg) h.addSubscriptions(cp.notifiers) if answer != nil { - h.conn.Write(cp.ctx, answer) + h.conn.writeJSON(cp.ctx, answer) } for _, n := range cp.notifiers { n.activate() diff --git a/rpc/http.go b/rpc/http.go index c8a097e9c3..1eb9417d01 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -48,29 +48,29 @@ type httpConn struct { client *http.Client req *http.Request closeOnce sync.Once - closed chan interface{} + closeCh chan interface{} } // httpConn is treated specially by Client. -func (hc *httpConn) Write(context.Context, interface{}) error { - panic("Write called on httpConn") +func (hc *httpConn) writeJSON(context.Context, interface{}) error { + panic("writeJSON called on httpConn") } -func (hc *httpConn) RemoteAddr() string { +func (hc *httpConn) remoteAddr() string { return hc.req.URL.String() } -func (hc *httpConn) Read() ([]*jsonrpcMessage, bool, error) { - <-hc.closed +func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) { + <-hc.closeCh return nil, false, io.EOF } -func (hc *httpConn) Close() { - hc.closeOnce.Do(func() { close(hc.closed) }) +func (hc *httpConn) close() { + hc.closeOnce.Do(func() { close(hc.closeCh) }) } -func (hc *httpConn) Closed() <-chan interface{} { - return hc.closed +func (hc *httpConn) closed() <-chan interface{} { + return hc.closeCh } // HTTPTimeouts represents the configuration params for the HTTP RPC server. @@ -117,7 +117,7 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) { initctx := context.Background() return newClient(initctx, func(context.Context) (ServerCodec, error) { - return &httpConn{client: client, req: req, closed: make(chan interface{})}, nil + return &httpConn{client: client, req: req, closeCh: make(chan interface{})}, nil }) } @@ -196,7 +196,7 @@ type httpServerConn struct { func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec { body := io.LimitReader(r.Body, maxRequestContentLength) conn := &httpServerConn{Reader: body, Writer: w, r: r} - return NewJSONCodec(conn) + return NewCodec(conn) } // Close does nothing and always returns nil. 
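+// DialHTTPWithClient (in the hunk above) accepts a caller-supplied
+// *http.Client, so transport behaviour such as timeouts stays configurable.
+// A short sketch; the endpoint, timeout and method are illustrative:
+//
+//	httpClient := &http.Client{Timeout: 10 * time.Second}
+//	client, err := rpc.DialHTTPWithClient("http://127.0.0.1:8545", httpClient)
+//	if err != nil {
+//		panic(err)
+//	}
+//	var head string
+//	err = client.Call(&head, "eth_blockNumber")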
@@ -266,7 +266,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("content-type", contentType) codec := newHTTPServerConn(r, w) - defer codec.Close() + defer codec.close() s.serveSingleRequest(ctx, codec) } diff --git a/rpc/inproc.go b/rpc/inproc.go index c4456cfc4b..bafe4f59be 100644 --- a/rpc/inproc.go +++ b/rpc/inproc.go @@ -26,8 +26,8 @@ func DialInProc(handler *Server) *Client { initctx := context.Background() c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) { p1, p2 := net.Pipe() - go handler.ServeCodec(NewJSONCodec(p1), OptionMethodInvocation|OptionSubscriptions) - return NewJSONCodec(p2), nil + go handler.ServeCodec(NewCodec(p1), OptionMethodInvocation|OptionSubscriptions) + return NewCodec(p2), nil }) return c } diff --git a/rpc/ipc.go b/rpc/ipc.go index b17e021cf4..77d7bf4cf0 100644 --- a/rpc/ipc.go +++ b/rpc/ipc.go @@ -35,7 +35,7 @@ func (s *Server) ServeListener(l net.Listener) error { return err } log.Trace("Accepted RPC connection", "conn", conn.RemoteAddr()) - go s.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions) + go s.ServeCodec(NewCodec(conn), OptionMethodInvocation|OptionSubscriptions) } } @@ -51,6 +51,6 @@ func DialIPC(ctx context.Context, endpoint string) (*Client, error) { if err != nil { return nil, err } - return NewJSONCodec(conn), err + return NewCodec(conn), err }) } diff --git a/rpc/json.go b/rpc/json.go index 1fa075a5cf..35d9859209 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -168,43 +168,45 @@ func (c connWithRemoteAddr) RemoteAddr() string { return c.addr } // jsonCodec reads and writes JSON-RPC messages to the underlying connection. It also has // support for parsing arguments and serializing (result) objects. type jsonCodec struct { - remoteAddr string - closer sync.Once // close closed channel once - closed chan interface{} // closed on Close - decode func(v interface{}) error // decoder to allow multiple transports - encMu sync.Mutex // guards the encoder - encode func(v interface{}) error // encoder to allow multiple transports - conn deadlineCloser -} - -func newCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec { + remote string + closer sync.Once // close closed channel once + closeCh chan interface{} // closed on Close + decode func(v interface{}) error // decoder to allow multiple transports + encMu sync.Mutex // guards the encoder + encode func(v interface{}) error // encoder to allow multiple transports + conn deadlineCloser +} + +// NewFuncCodec creates a codec which uses the given functions to read and write. If conn +// implements ConnRemoteAddr, log messages will use it to include the remote address of +// the connection. +func NewFuncCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec { codec := &jsonCodec{ - closed: make(chan interface{}), - encode: encode, - decode: decode, - conn: conn, + closeCh: make(chan interface{}), + encode: encode, + decode: decode, + conn: conn, } if ra, ok := conn.(ConnRemoteAddr); ok { - codec.remoteAddr = ra.RemoteAddr() + codec.remote = ra.RemoteAddr() } return codec } -// NewJSONCodec creates a codec that reads from the given connection. If conn implements -// ConnRemoteAddr, log messages will use it to include the remote address of the -// connection. -func NewJSONCodec(conn Conn) ServerCodec { +// NewCodec creates a codec on the given connection. If conn implements ConnRemoteAddr, log +// messages will use it to include the remote address of the connection. 
+func NewCodec(conn Conn) ServerCodec { enc := json.NewEncoder(conn) dec := json.NewDecoder(conn) dec.UseNumber() - return newCodec(conn, enc.Encode, dec.Decode) + return NewFuncCodec(conn, enc.Encode, dec.Decode) } -func (c *jsonCodec) RemoteAddr() string { - return c.remoteAddr +func (c *jsonCodec) remoteAddr() string { + return c.remote } -func (c *jsonCodec) Read() (msg []*jsonrpcMessage, batch bool, err error) { +func (c *jsonCodec) readBatch() (msg []*jsonrpcMessage, batch bool, err error) { // Decode the next JSON object in the input stream. // This verifies basic syntax, etc. var rawmsg json.RawMessage @@ -215,8 +217,8 @@ func (c *jsonCodec) Read() (msg []*jsonrpcMessage, batch bool, err error) { return msg, batch, nil } -// Write sends a message to client. -func (c *jsonCodec) Write(ctx context.Context, v interface{}) error { +// writeJSON sends a message to client. +func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error { c.encMu.Lock() defer c.encMu.Unlock() @@ -229,16 +231,16 @@ func (c *jsonCodec) Write(ctx context.Context, v interface{}) error { } // Close the underlying connection -func (c *jsonCodec) Close() { +func (c *jsonCodec) close() { c.closer.Do(func() { - close(c.closed) + close(c.closeCh) c.conn.Close() }) } // Closed returns a channel which will be closed when Close is called -func (c *jsonCodec) Closed() <-chan interface{} { - return c.closed +func (c *jsonCodec) closed() <-chan interface{} { + return c.closeCh } // parseMessage parses raw bytes as a (batch of) JSON-RPC message(s). There are no error diff --git a/rpc/server.go b/rpc/server.go index e8eca78564..aa455e374b 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -73,7 +73,7 @@ func (s *Server) RegisterName(name string, receiver interface{}) error { // // Note that codec options are no longer supported. func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { - defer codec.Close() + defer codec.close() // Don't serve if server is stopped. if atomic.LoadInt32(&s.run) == 0 { @@ -85,7 +85,7 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { defer s.codecs.Remove(codec) c := initClient(codec, s.idgen, &s.services) - <-codec.Closed() + <-codec.closed() c.Close() } @@ -102,10 +102,10 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { h.allowSubscribe = false defer h.close(io.EOF, nil) - reqs, batch, err := codec.Read() + reqs, batch, err := codec.readBatch() if err != nil { if err != io.EOF { - codec.Write(ctx, errorMessage(&invalidMessageError{"parse error"})) + codec.writeJSON(ctx, errorMessage(&invalidMessageError{"parse error"})) } return } @@ -123,7 +123,7 @@ func (s *Server) Stop() { if atomic.CompareAndSwapInt32(&s.run, 1, 0) { log.Debug("RPC server shutting down") s.codecs.Each(func(c interface{}) bool { - c.(ServerCodec).Close() + c.(ServerCodec).close() return true }) } diff --git a/rpc/stdio.go b/rpc/stdio.go index 8f6b7bd4bf..be2bab1c98 100644 --- a/rpc/stdio.go +++ b/rpc/stdio.go @@ -19,6 +19,7 @@ package rpc import ( "context" "errors" + "io" "net" "os" "time" @@ -26,19 +27,30 @@ import ( // DialStdIO creates a client on stdin/stdout. 
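+// It is equivalent to DialIO(ctx, os.Stdin, os.Stdout).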
func DialStdIO(ctx context.Context) (*Client, error) { + return DialIO(ctx, os.Stdin, os.Stdout) +} + +// DialIO creates a client which uses the given IO channels +func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) { return newClient(ctx, func(_ context.Context) (ServerCodec, error) { - return NewJSONCodec(stdioConn{}), nil + return NewCodec(stdioConn{ + in: in, + out: out, + }), nil }) } -type stdioConn struct{} +type stdioConn struct { + in io.Reader + out io.Writer +} func (io stdioConn) Read(b []byte) (n int, err error) { - return os.Stdin.Read(b) + return io.in.Read(b) } func (io stdioConn) Write(b []byte) (n int, err error) { - return os.Stdout.Write(b) + return io.out.Write(b) } func (io stdioConn) Close() error { diff --git a/rpc/subscription.go b/rpc/subscription.go index c1e869b8a3..153e24063e 100644 --- a/rpc/subscription.go +++ b/rpc/subscription.go @@ -141,7 +141,7 @@ func (n *Notifier) Notify(id ID, data interface{}) error { // Closed returns a channel that is closed when the RPC connection is closed. // Deprecated: use subscription error channel func (n *Notifier) Closed() <-chan interface{} { - return n.h.conn.Closed() + return n.h.conn.closed() } // takeSubscription returns the subscription (if one has been created). No subscription can @@ -172,7 +172,7 @@ func (n *Notifier) activate() error { func (n *Notifier) send(sub *Subscription, data json.RawMessage) error { params, _ := json.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data}) ctx := context.Background() - return n.h.conn.Write(ctx, &jsonrpcMessage{ + return n.h.conn.writeJSON(ctx, &jsonrpcMessage{ Version: vsn, Method: n.namespace + notificationMethodSuffix, Params: params, diff --git a/rpc/types.go b/rpc/types.go index c7539a2b20..c2d7780c92 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -43,19 +43,19 @@ type Error interface { // a RPC session. Implementations must be go-routine safe since the codec can be called in // multiple go-routines concurrently. type ServerCodec interface { - Read() (msgs []*jsonrpcMessage, isBatch bool, err error) - Close() + readBatch() (msgs []*jsonrpcMessage, isBatch bool, err error) + close() jsonWriter } // jsonWriter can write JSON messages to its underlying connection. // Implementations must be safe for concurrent use. type jsonWriter interface { - Write(context.Context, interface{}) error + writeJSON(context.Context, interface{}) error // Closed returns a channel which is closed when the connection is closed. - Closed() <-chan interface{} + closed() <-chan interface{} // RemoteAddr returns the peer address of the connection. 
- RemoteAddr() string + remoteAddr() string } type BlockNumber int64 diff --git a/rpc/websocket.go b/rpc/websocket.go index f70ec78c95..f7663979c3 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -64,7 +64,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { return } codec := newWebsocketCodec(conn) - s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions) + s.ServeCodec(codec, 0) }) } @@ -172,5 +172,5 @@ func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { func newWebsocketCodec(conn *websocket.Conn) ServerCodec { conn.SetReadLimit(maxRequestContentLength) - return newCodec(conn, conn.WriteJSON, conn.ReadJSON) + return NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON) } From d7cb1cb6f85ed2bc3f2c4081b399a402fa096c65 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 5 Dec 2023 11:48:52 +0700 Subject: [PATCH 113/119] Use gorilla websocket in p2p/simulations --- p2p/simulations/adapters/exec.go | 304 ++++++++++++++---------- p2p/simulations/adapters/inproc.go | 186 +++++++++------ p2p/simulations/adapters/inproc_test.go | 259 ++++++++++++++++++++ p2p/simulations/adapters/types.go | 62 ++++- p2p/simulations/http.go | 29 ++- p2p/simulations/http_test.go | 84 +++++-- 6 files changed, 696 insertions(+), 228 deletions(-) create mode 100644 p2p/simulations/adapters/inproc_test.go diff --git a/p2p/simulations/adapters/exec.go b/p2p/simulations/adapters/exec.go index 58e2613123..bfbdb424b5 100644 --- a/p2p/simulations/adapters/exec.go +++ b/p2p/simulations/adapters/exec.go @@ -17,7 +17,7 @@ package adapters import ( - "bufio" + "bytes" "context" "crypto/ecdsa" "encoding/json" @@ -25,31 +25,34 @@ import ( "fmt" "io" "net" + "net/http" "os" "os/exec" "os/signal" "path/filepath" - "regexp" "strings" "sync" "syscall" "time" "github.com/docker/docker/pkg/reexec" + "github.com/gorilla/websocket" + "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/rpc" - "golang.org/x/net/websocket" ) -// ExecAdapter is a NodeAdapter which runs simulation nodes by executing the -// current binary as a child process. -// -// An init hook is used so that the child process executes the node services -// (rather than whataver the main() function would normally do), see the -// execP2PNode function for more information. +func init() { + // Register a reexec function to start a simulation node when the current binary is + // executed as "p2p-node" (rather than whatever the main() function would normally do). + reexec.Register("p2p-node", execP2PNode) +} + +// ExecAdapter is a NodeAdapter which runs simulation nodes by executing the current binary +// as a child process. type ExecAdapter struct { // BaseDir is the directory under which the data directories for each // simulation node are created. 
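The reexec.Register call above is only half of the mechanism: the host binary must give reexec a chance to intercept startup. A minimal sketch of that side, illustrative only, since the actual main package wiring sits outside this diff:

	package main

	import (
		"os"

		"github.com/docker/docker/pkg/reexec"
	)

	func main() {
		// reexec.Init runs the function registered for argv[0] ("p2p-node"
		// when ExecAdapter spawns a child) and reports whether one ran.
		if reexec.Init() {
			os.Exit(0)
		}
		// Normal startup continues here for ordinary invocations.
	}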
@@ -90,24 +93,35 @@ func (e *ExecAdapter) NewNode(config *NodeConfig) (Node, error) { return nil, fmt.Errorf("error creating node directory: %s", err) } + err := config.initDummyEnode() + if err != nil { + return nil, err + } + // generate the config conf := &execNodeConfig{ Stack: node.DefaultConfig, Node: config, } - conf.Stack.DataDir = filepath.Join(dir, "data") + if config.DataDir != "" { + conf.Stack.DataDir = config.DataDir + } else { + conf.Stack.DataDir = filepath.Join(dir, "data") + } + + // these parameters are crucial for execadapter node to run correctly conf.Stack.WSHost = "127.0.0.1" conf.Stack.WSPort = 0 conf.Stack.WSOrigins = []string{"*"} conf.Stack.WSExposeAll = true - conf.Stack.P2P.EnableMsgEvents = false + conf.Stack.P2P.EnableMsgEvents = config.EnableMsgEvents conf.Stack.P2P.NoDiscovery = true conf.Stack.P2P.NAT = nil conf.Stack.NoUSB = true - // listen on a random localhost port (we'll get the actual port after - // starting the node through the RPC admin.nodeInfo method) - conf.Stack.P2P.ListenAddr = "127.0.0.1:0" + // Listen on a localhost port, which we set when we + // initialise NodeConfig (usually a random port) + conf.Stack.P2P.ListenAddr = fmt.Sprintf(":%d", config.Port) node := &ExecNode{ ID: config.ID, @@ -150,20 +164,14 @@ func (n *ExecNode) Client() (*rpc.Client, error) { return n.client, nil } -// wsAddrPattern is a regex used to read the WebSocket address from the node's -// log -var wsAddrPattern = regexp.MustCompile(`ws://[\d.:]+`) - // Start exec's the node passing the ID and service as command line arguments -// and the node config encoded as JSON in the _P2P_NODE_CONFIG environment -// variable +// and the node config encoded as JSON in an environment variable. func (n *ExecNode) Start(snapshots map[string][]byte) (err error) { if n.Cmd != nil { return errors.New("already started") } defer func() { if err != nil { - log.Error("node failed to start", "err", err) n.Stop() } }() @@ -180,59 +188,78 @@ func (n *ExecNode) Start(snapshots map[string][]byte) (err error) { return fmt.Errorf("error generating node config: %s", err) } - // use a pipe for stderr so we can both copy the node's stderr to - // os.Stderr and read the WebSocket address from the logs - stderrR, stderrW := io.Pipe() - stderr := io.MultiWriter(os.Stderr, stderrW) + // start the one-shot server that waits for startup information + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + statusURL, statusC := n.waitForStartupJSON(ctx) // start the node cmd := n.newCmd() cmd.Stdout = os.Stdout - cmd.Stderr = stderr - cmd.Env = append(os.Environ(), fmt.Sprintf("_P2P_NODE_CONFIG=%s", confData)) + cmd.Stderr = os.Stderr + cmd.Env = append(os.Environ(), + envStatusURL+"="+statusURL, + envNodeConfig+"="+string(confData), + ) if err := cmd.Start(); err != nil { return fmt.Errorf("error starting node: %s", err) } n.Cmd = cmd - // read the WebSocket address from the stderr logs - var wsAddr string - wsAddrC := make(chan string) - go func() { - s := bufio.NewScanner(stderrR) - for s.Scan() { - if strings.Contains(s.Text(), "WebSocket endpoint opened:") { - wsAddrC <- wsAddrPattern.FindString(s.Text()) - } - } - }() - select { - case wsAddr = <-wsAddrC: - if wsAddr == "" { - return errors.New("failed to read WebSocket address from stderr") - } - case <-time.After(10 * time.Second): - return errors.New("timed out waiting for WebSocket address on stderr") + // Wait for the node to start. 
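+	// The status channel is buffered; waitForStartupJSON also feeds it an
+	// error report if the context above times out first.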
+ status := <-statusC + if status.Err != "" { + return errors.New(status.Err) } - - // create the RPC client and load the node info - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - client, err := rpc.DialWebsocket(ctx, wsAddr, "") + client, err := rpc.DialWebsocket(ctx, status.WSEndpoint, "") if err != nil { - return fmt.Errorf("error dialing rpc websocket: %s", err) - } - var info p2p.NodeInfo - if err := client.CallContext(ctx, &info, "admin_nodeInfo"); err != nil { - return fmt.Errorf("error getting node info: %s", err) + return fmt.Errorf("can't connect to RPC server: %v", err) } - n.client = client - n.wsAddr = wsAddr - n.Info = &info + // Node ready :) + n.client = client + n.wsAddr = status.WSEndpoint + n.Info = status.NodeInfo return nil } +// waitForStartupJSON runs a one-shot HTTP server to receive a startup report. +func (n *ExecNode) waitForStartupJSON(ctx context.Context) (string, chan nodeStartupJSON) { + var ( + ch = make(chan nodeStartupJSON, 1) + quitOnce sync.Once + srv http.Server + ) + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ch <- nodeStartupJSON{Err: err.Error()} + return "", ch + } + quit := func(status nodeStartupJSON) { + quitOnce.Do(func() { + l.Close() + ch <- status + }) + } + srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var status nodeStartupJSON + if err := json.NewDecoder(r.Body).Decode(&status); err != nil { + status.Err = fmt.Sprintf("can't decode startup report: %v", err) + } + quit(status) + }) + // Run the HTTP server, but don't wait forever and shut it down + // if the context is canceled. + go srv.Serve(l) + go func() { + <-ctx.Done() + quit(nodeStartupJSON{Err: "didn't get startup report"}) + }() + + url := "http://" + l.Addr().String() + return url, ch +} + // execCommand returns a command which runs the node locally by exec'ing // the current binary but setting argv[0] to "p2p-node" so that the child // runs execP2PNode @@ -288,31 +315,37 @@ func (n *ExecNode) NodeInfo() *p2p.NodeInfo { // ServeRPC serves RPC requests over the given connection by dialling the // node's WebSocket address and joining the two connections -func (n *ExecNode) ServeRPC(clientConn net.Conn) error { - conn, err := websocket.Dial(n.wsAddr, "", "http://localhost") +func (n *ExecNode) ServeRPC(clientConn *websocket.Conn) error { + conn, _, err := websocket.DefaultDialer.Dial(n.wsAddr, nil) if err != nil { return err } var wg sync.WaitGroup wg.Add(2) - join := func(src, dst net.Conn) { - defer wg.Done() - io.Copy(dst, src) - // close the write end of the destination connection - if cw, ok := dst.(interface { - CloseWrite() error - }); ok { - cw.CloseWrite() - } else { - dst.Close() - } - } - go join(conn, clientConn) - go join(clientConn, conn) + go wsCopy(&wg, conn, clientConn) + go wsCopy(&wg, clientConn, conn) wg.Wait() + conn.Close() return nil } +func wsCopy(wg *sync.WaitGroup, src, dst *websocket.Conn) { + defer wg.Done() + for { + msgType, r, err := src.NextReader() + if err != nil { + return + } + w, err := dst.NextWriter(msgType) + if err != nil { + return + } + if _, err = io.Copy(w, r); err != nil { + return + } + } +} + // Snapshots creates snapshots of the services by calling the // simulation_snapshot RPC method func (n *ExecNode) Snapshots() (map[string][]byte, error) { @@ -323,12 +356,6 @@ func (n *ExecNode) Snapshots() (map[string][]byte, error) { return snapshots, n.client.Call(&snapshots, "simulation_snapshot") } -func init() { - // register a reexec function 
to start a devp2p node when the current - // binary is executed as "p2p-node" - reexec.Register("p2p-node", execP2PNode) -} - // execNodeConfig is used to serialize the node configuration so it can be // passed to the child process as a JSON encoded environment variable type execNodeConfig struct { @@ -338,54 +365,76 @@ type execNodeConfig struct { PeerAddrs map[string]string `json:"peer_addrs,omitempty"` } -// execP2PNode starts a devp2p node when the current binary is executed with +// execP2PNode starts a simulation node when the current binary is executed with // argv[0] being "p2p-node", reading the service / ID from argv[1] / argv[2] -// and the node config from the _P2P_NODE_CONFIG environment variable +// and the node config from an environment variable. func execP2PNode() { glogger := log.NewGlogHandler(log.StreamHandler(os.Stderr, log.LogfmtFormat())) glogger.Verbosity(log.LvlInfo) log.Root().SetHandler(glogger) + statusURL := os.Getenv(envStatusURL) + if statusURL == "" { + log.Crit("missing " + envStatusURL) + } + + // Start the node and gather startup report. + var status nodeStartupJSON + stack, stackErr := startExecNodeStack() + if stackErr != nil { + status.Err = stackErr.Error() + } else { + status.WSEndpoint = "ws://" + stack.WSEndpoint() + status.NodeInfo = stack.Server().NodeInfo() + } + // Send status to the host. + statusJSON, _ := json.Marshal(status) + if _, err := http.Post(statusURL, "application/json", bytes.NewReader(statusJSON)); err != nil { + log.Crit("Can't post startup info", "url", statusURL, "err", err) + } + if stackErr != nil { + os.Exit(1) + } + + // Stop the stack if we get a SIGTERM signal. + go func() { + sigc := make(chan os.Signal, 1) + signal.Notify(sigc, syscall.SIGTERM) + defer signal.Stop(sigc) + <-sigc + log.Info("Received SIGTERM, shutting down...") + stack.Stop() + }() + stack.Wait() // Wait for the stack to exit. 
+}
+
+func startExecNodeStack() (*node.Node, error) {
 	// read the services from argv
 	serviceNames := strings.Split(os.Args[1], ",")
 
 	// decode the config
-	confEnv := os.Getenv("_P2P_NODE_CONFIG")
+	confEnv := os.Getenv(envNodeConfig)
 	if confEnv == "" {
-		log.Crit("missing _P2P_NODE_CONFIG")
+		return nil, fmt.Errorf("missing " + envNodeConfig)
 	}
 	var conf execNodeConfig
 	if err := json.Unmarshal([]byte(confEnv), &conf); err != nil {
-		log.Crit("error decoding _P2P_NODE_CONFIG", "err", err)
+		return nil, fmt.Errorf("error decoding %s: %v", envNodeConfig, err)
 	}
-	conf.Stack.P2P.PrivateKey = conf.Node.PrivateKey
-	conf.Stack.Logger = log.New("node.id", conf.Node.ID.String())
 
-	// use explicit IP address in ListenAddr so that Enode URL is usable
-	externalIP := func() string {
-		addrs, err := net.InterfaceAddrs()
-		if err != nil {
-			log.Crit("error getting IP address", "err", err)
-		}
-		for _, addr := range addrs {
-			if ip, ok := addr.(*net.IPNet); ok && !ip.IP.IsLoopback() {
-				return ip.IP.String()
-			}
-		}
-		log.Crit("unable to determine explicit IP address")
-		return ""
-	}
-	if strings.HasPrefix(conf.Stack.P2P.ListenAddr, ":") {
-		conf.Stack.P2P.ListenAddr = externalIP() + conf.Stack.P2P.ListenAddr
-	}
-	if conf.Stack.WSHost == "0.0.0.0" {
-		conf.Stack.WSHost = externalIP()
+	// create enode record
+	nodeTcpConn, _ := net.ResolveTCPAddr("tcp", conf.Stack.P2P.ListenAddr)
+	if nodeTcpConn.IP == nil {
+		nodeTcpConn.IP = net.IPv4(127, 0, 0, 1)
 	}
+	conf.Node.initEnode(nodeTcpConn.IP, nodeTcpConn.Port, nodeTcpConn.Port)
+	conf.Stack.P2P.PrivateKey = conf.Node.PrivateKey
+	conf.Stack.Logger = log.New("node.id", conf.Node.ID.String())
 
 	// initialize the devp2p stack
 	stack, err := node.New(&conf.Stack)
 	if err != nil {
-		log.Crit("error creating node stack", "err", err)
+		return nil, fmt.Errorf("error creating node stack: %v", err)
 	}
 
 	// register the services, collecting them into a map so we can wrap
@@ -394,7 +443,7 @@ func execP2PNode() {
 	for _, name := range serviceNames {
 		serviceFunc, exists := serviceFuncs[name]
 		if !exists {
-			log.Crit("unknown node service", "name", name)
+			return nil, fmt.Errorf("unknown node service %q", name)
 		}
 		constructor := func(nodeCtx *node.ServiceContext) (node.Service, error) {
 			ctx := &ServiceContext{
@@ -413,34 +462,35 @@ func execP2PNode() {
 			return service, nil
 		}
 		if err := stack.Register(constructor); err != nil {
-			log.Crit("error starting service", "name", name, "err", err)
+			return stack, fmt.Errorf("error registering service %q: %v", name, err)
 		}
 	}
 
 	// register the snapshot service
-	if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) {
+	err = stack.Register(func(ctx *node.ServiceContext) (node.Service, error) {
 		return &snapshotService{services}, nil
-	}); err != nil {
-		log.Crit("error starting snapshot service", "err", err)
+	})
+	if err != nil {
+		return stack, fmt.Errorf("error starting snapshot service: %v", err)
 	}
 
 	// start the stack
-	if err := stack.Start(); err != nil {
-		log.Crit("error stating node stack", "err", err)
+	if err = stack.Start(); err != nil {
+		err = fmt.Errorf("error starting stack: %v", err)
 	}
+	return stack, err
+}
 
-	// stop the stack if we get a SIGTERM signal
-	go func() {
-		sigc := make(chan os.Signal, 1)
-		signal.Notify(sigc, syscall.SIGTERM)
-		defer signal.Stop(sigc)
-		<-sigc
-		log.Info("Received SIGTERM, shutting down...")
-		stack.Stop()
-	}()
+const (
+	envStatusURL  = "_P2P_STATUS_URL"
+	envNodeConfig = "_P2P_NODE_CONFIG"
+)
 
-	// wait for the stack to exit
-	stack.Wait()
+// nodeStartupJSON is sent to the
simulation host after startup. +type nodeStartupJSON struct { + Err string + WSEndpoint string + NodeInfo *p2p.NodeInfo } // snapshotService is a node.Service which wraps a list of services and @@ -449,6 +499,8 @@ type snapshotService struct { services map[string]node.Service } +func (s *snapshotService) SaveData() {} + func (s *snapshotService) APIs() []rpc.API { return []rpc.API{{ Namespace: "simulation", @@ -465,8 +517,6 @@ func (s *snapshotService) Start(*p2p.Server) error { return nil } -func (s *snapshotService) SaveData() { -} func (s *snapshotService) Stop() error { return nil } diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go index 3a21dd2792..56bcdb0fd1 100644 --- a/p2p/simulations/adapters/inproc.go +++ b/p2p/simulations/adapters/inproc.go @@ -23,17 +23,21 @@ import ( "net" "sync" + "github.com/gorilla/websocket" + "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" "github.com/tomochain/tomochain/p2p/enode" + "github.com/tomochain/tomochain/p2p/simulations/pipes" "github.com/tomochain/tomochain/rpc" ) // SimAdapter is a NodeAdapter which creates in-memory simulation nodes and -// connects them using in-memory net.Pipe connections +// connects them using net.Pipe type SimAdapter struct { + pipe func() (net.Conn, net.Conn, error) mtx sync.RWMutex nodes map[enode.ID]*SimNode services map[string]ServiceFunc @@ -42,8 +46,18 @@ type SimAdapter struct { // NewSimAdapter creates a SimAdapter which is capable of running in-memory // simulation nodes running any of the given services (the services to run on a // particular node are passed to the NewNode function in the NodeConfig) +// the adapter uses a net.Pipe for in-memory simulated network connections func NewSimAdapter(services map[string]ServiceFunc) *SimAdapter { return &SimAdapter{ + pipe: pipes.NetPipe, + nodes: make(map[enode.ID]*SimNode), + services: services, + } +} + +func NewTCPAdapter(services map[string]ServiceFunc) *SimAdapter { + return &SimAdapter{ + pipe: pipes.TCPPipe, nodes: make(map[enode.ID]*SimNode), services: services, } @@ -59,8 +73,13 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) { s.mtx.Lock() defer s.mtx.Unlock() - // check a node with the ID doesn't already exist id := config.ID + // verify that the node has a private key in the config + if config.PrivateKey == nil { + return nil, fmt.Errorf("node is missing private key: %s", id) + } + + // check a node with the ID doesn't already exist if _, exists := s.nodes[id]; exists { return nil, fmt.Errorf("node already exists: %s", id) } @@ -75,13 +94,18 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) { } } + err := config.initDummyEnode() + if err != nil { + return nil, err + } + n, err := node.New(&node.Config{ P2P: p2p.Config{ PrivateKey: config.PrivateKey, MaxPeers: math.MaxInt32, NoDiscovery: true, Dialer: s, - EnableMsgEvents: true, + EnableMsgEvents: config.EnableMsgEvents, }, NoUSB: true, Logger: log.New("node.id", id.String()), @@ -91,34 +115,36 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) { } simNode := &SimNode{ - ID: id, - config: config, - node: n, - adapter: s, - running: make(map[string]node.Service), - connected: make(map[enode.ID]bool), + ID: id, + config: config, + node: n, + adapter: s, + running: make(map[string]node.Service), } s.nodes[id] = simNode return simNode, nil } // Dial implements the p2p.NodeDialer interface by connecting to the node using -// 
an in-memory net.Pipe connection +// an in-memory net.Pipe func (s *SimAdapter) Dial(dest *enode.Node) (conn net.Conn, err error) { node, ok := s.GetNode(dest.ID()) if !ok { return nil, fmt.Errorf("unknown node: %s", dest.ID()) } - if node.connected[dest.ID()] { - return nil, fmt.Errorf("dialed node: %s", dest.ID()) - } srv := node.Server() if srv == nil { return nil, fmt.Errorf("node not running: %s", dest.ID()) } - pipe1, pipe2 := net.Pipe() + // SimAdapter.pipe is net.Pipe (NewSimAdapter) + pipe1, pipe2, err := s.pipe() + if err != nil { + return nil, err + } + // this is simulated 'listening' + // asynchronously call the dialed destination node's p2p server + // to set up connection on the 'listening' side go srv.SetupConn(pipe1, 0, nil) - node.connected[dest.ID()] = true return pipe2, nil } @@ -145,8 +171,8 @@ func (s *SimAdapter) GetNode(id enode.ID) (*SimNode, bool) { } // SimNode is an in-memory simulation node which connects to other nodes using -// an in-memory net.Pipe connection (see SimAdapter.Dial), running devp2p -// protocols directly over that pipe +// net.Pipe (see SimAdapter.Dial), running devp2p protocols directly over that +// pipe type SimNode struct { lock sync.RWMutex ID enode.ID @@ -156,12 +182,17 @@ type SimNode struct { running map[string]node.Service client *rpc.Client registerOnce sync.Once - connected map[enode.ID]bool +} + +// Close closes the underlaying node.Node to release +// acquired resources. +func (sn *SimNode) Close() error { + return sn.node.Close() } // Addr returns the node's discovery address -func (self *SimNode) Addr() []byte { - return []byte(self.Node().String()) +func (sn *SimNode) Addr() []byte { + return []byte(sn.Node().String()) } // Node returns a node descriptor representing the SimNode @@ -171,35 +202,36 @@ func (sn *SimNode) Node() *enode.Node { // Client returns an rpc.Client which can be used to communicate with the // underlying services (it is set once the node has started) -func (self *SimNode) Client() (*rpc.Client, error) { - self.lock.RLock() - defer self.lock.RUnlock() - if self.client == nil { +func (sn *SimNode) Client() (*rpc.Client, error) { + sn.lock.RLock() + defer sn.lock.RUnlock() + if sn.client == nil { return nil, errors.New("node not started") } - return self.client, nil + return sn.client, nil } // ServeRPC serves RPC requests over the given connection by creating an -// in-memory client to the node's RPC server -func (self *SimNode) ServeRPC(conn net.Conn) error { - handler, err := self.node.RPCHandler() +// in-memory client to the node's RPC server. 
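+// The gorilla connection's WriteJSON and ReadJSON methods are wired
+// straight into rpc.NewFuncCodec, so no extra framing layer is involved.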
+func (sn *SimNode) ServeRPC(conn *websocket.Conn) error { + handler, err := sn.node.RPCHandler() if err != nil { return err } - handler.ServeCodec(rpc.NewJSONCodec(conn), rpc.OptionMethodInvocation|rpc.OptionSubscriptions) + codec := rpc.NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON) + handler.ServeCodec(codec, 0) return nil } // Snapshots creates snapshots of the services by calling the // simulation_snapshot RPC method -func (self *SimNode) Snapshots() (map[string][]byte, error) { - self.lock.RLock() - services := make(map[string]node.Service, len(self.running)) - for name, service := range self.running { +func (sn *SimNode) Snapshots() (map[string][]byte, error) { + sn.lock.RLock() + services := make(map[string]node.Service, len(sn.running)) + for name, service := range sn.running { services[name] = service } - self.lock.RUnlock() + sn.lock.RUnlock() if len(services) == 0 { return nil, errors.New("no running services") } @@ -219,23 +251,23 @@ func (self *SimNode) Snapshots() (map[string][]byte, error) { } // Start registers the services and starts the underlying devp2p node -func (self *SimNode) Start(snapshots map[string][]byte) error { +func (sn *SimNode) Start(snapshots map[string][]byte) error { newService := func(name string) func(ctx *node.ServiceContext) (node.Service, error) { return func(nodeCtx *node.ServiceContext) (node.Service, error) { ctx := &ServiceContext{ - RPCDialer: self.adapter, + RPCDialer: sn.adapter, NodeContext: nodeCtx, - Config: self.config, + Config: sn.config, } if snapshots != nil { ctx.Snapshot = snapshots[name] } - serviceFunc := self.adapter.services[name] + serviceFunc := sn.adapter.services[name] service, err := serviceFunc(ctx) if err != nil { return nil, err } - self.running[name] = service + sn.running[name] = service return service, nil } } @@ -243,11 +275,11 @@ func (self *SimNode) Start(snapshots map[string][]byte) error { // ensure we only register the services once in the case of the node // being stopped and then started again var regErr error - self.registerOnce.Do(func() { - for _, name := range self.config.Services { - if err := self.node.Register(newService(name)); err != nil { + sn.registerOnce.Do(func() { + for _, name := range sn.config.Services { + if err := sn.node.Register(newService(name)); err != nil { regErr = err - return + break } } }) @@ -255,54 +287,72 @@ func (self *SimNode) Start(snapshots map[string][]byte) error { return regErr } - if err := self.node.Start(); err != nil { + if err := sn.node.Start(); err != nil { return err } // create an in-process RPC client - handler, err := self.node.RPCHandler() + handler, err := sn.node.RPCHandler() if err != nil { return err } - self.lock.Lock() - self.client = rpc.DialInProc(handler) - self.lock.Unlock() + sn.lock.Lock() + sn.client = rpc.DialInProc(handler) + sn.lock.Unlock() return nil } // Stop closes the RPC client and stops the underlying devp2p node -func (self *SimNode) Stop() error { - self.lock.Lock() - if self.client != nil { - self.client.Close() - self.client = nil +func (sn *SimNode) Stop() error { + sn.lock.Lock() + if sn.client != nil { + sn.client.Close() + sn.client = nil } - self.lock.Unlock() - return self.node.Stop() + sn.lock.Unlock() + return sn.node.Stop() +} + +// Service returns a running service by name +func (sn *SimNode) Service(name string) node.Service { + sn.lock.RLock() + defer sn.lock.RUnlock() + return sn.running[name] } // Services returns a copy of the underlying services -func (self *SimNode) Services() []node.Service { - self.lock.RLock() 
- defer self.lock.RUnlock() - services := make([]node.Service, 0, len(self.running)) - for _, service := range self.running { +func (sn *SimNode) Services() []node.Service { + sn.lock.RLock() + defer sn.lock.RUnlock() + services := make([]node.Service, 0, len(sn.running)) + for _, service := range sn.running { services = append(services, service) } return services } +// ServiceMap returns a map by names of the underlying services +func (sn *SimNode) ServiceMap() map[string]node.Service { + sn.lock.RLock() + defer sn.lock.RUnlock() + services := make(map[string]node.Service, len(sn.running)) + for name, service := range sn.running { + services[name] = service + } + return services +} + // Server returns the underlying p2p.Server -func (self *SimNode) Server() *p2p.Server { - return self.node.Server() +func (sn *SimNode) Server() *p2p.Server { + return sn.node.Server() } // SubscribeEvents subscribes the given channel to peer events from the // underlying p2p.Server -func (self *SimNode) SubscribeEvents(ch chan *p2p.PeerEvent) event.Subscription { - srv := self.Server() +func (sn *SimNode) SubscribeEvents(ch chan *p2p.PeerEvent) event.Subscription { + srv := sn.Server() if srv == nil { panic("node not running") } @@ -310,12 +360,12 @@ func (self *SimNode) SubscribeEvents(ch chan *p2p.PeerEvent) event.Subscription } // NodeInfo returns information about the node -func (self *SimNode) NodeInfo() *p2p.NodeInfo { - server := self.Server() +func (sn *SimNode) NodeInfo() *p2p.NodeInfo { + server := sn.Server() if server == nil { return &p2p.NodeInfo{ - ID: self.ID.String(), - Enode: self.Node().String(), + ID: sn.ID.String(), + Enode: sn.Node().String(), } } return server.NodeInfo() diff --git a/p2p/simulations/adapters/inproc_test.go b/p2p/simulations/adapters/inproc_test.go new file mode 100644 index 0000000000..32a9e1ac2a --- /dev/null +++ b/p2p/simulations/adapters/inproc_test.go @@ -0,0 +1,259 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package adapters + +import ( + "bytes" + "encoding/binary" + "fmt" + "testing" + "time" + + "github.com/tomochain/tomochain/p2p/simulations/pipes" +) + +func TestTCPPipe(t *testing.T) { + c1, c2, err := pipes.TCPPipe() + if err != nil { + t.Fatal(err) + } + + done := make(chan struct{}) + + go func() { + msgs := 50 + size := 1024 + for i := 0; i < msgs; i++ { + msg := make([]byte, size) + _ = binary.PutUvarint(msg, uint64(i)) + + _, err := c1.Write(msg) + if err != nil { + t.Fatal(err) + } + } + + for i := 0; i < msgs; i++ { + msg := make([]byte, size) + _ = binary.PutUvarint(msg, uint64(i)) + + out := make([]byte, size) + _, err := c2.Read(out) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(msg, out) { + t.Fatalf("expected %#v, got %#v", msg, out) + } + } + done <- struct{}{} + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("test timeout") + } +} + +func TestTCPPipeBidirections(t *testing.T) { + c1, c2, err := pipes.TCPPipe() + if err != nil { + t.Fatal(err) + } + + done := make(chan struct{}) + + go func() { + msgs := 50 + size := 7 + for i := 0; i < msgs; i++ { + msg := []byte(fmt.Sprintf("ping %02d", i)) + + _, err := c1.Write(msg) + if err != nil { + t.Fatal(err) + } + } + + for i := 0; i < msgs; i++ { + expected := []byte(fmt.Sprintf("ping %02d", i)) + + out := make([]byte, size) + _, err := c2.Read(out) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(expected, out) { + t.Fatalf("expected %#v, got %#v", out, expected) + } else { + msg := []byte(fmt.Sprintf("pong %02d", i)) + _, err := c2.Write(msg) + if err != nil { + t.Fatal(err) + } + } + } + + for i := 0; i < msgs; i++ { + expected := []byte(fmt.Sprintf("pong %02d", i)) + + out := make([]byte, size) + _, err := c1.Read(out) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(expected, out) { + t.Fatalf("expected %#v, got %#v", out, expected) + } + } + done <- struct{}{} + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("test timeout") + } +} + +func TestNetPipe(t *testing.T) { + c1, c2, err := pipes.NetPipe() + if err != nil { + t.Fatal(err) + } + + done := make(chan struct{}) + + go func() { + msgs := 50 + size := 1024 + // netPipe is blocking, so writes are emitted asynchronously + go func() { + for i := 0; i < msgs; i++ { + msg := make([]byte, size) + _ = binary.PutUvarint(msg, uint64(i)) + + _, err := c1.Write(msg) + if err != nil { + t.Fatal(err) + } + } + }() + + for i := 0; i < msgs; i++ { + msg := make([]byte, size) + _ = binary.PutUvarint(msg, uint64(i)) + + out := make([]byte, size) + _, err := c2.Read(out) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(msg, out) { + t.Fatalf("expected %#v, got %#v", msg, out) + } + } + + done <- struct{}{} + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("test timeout") + } +} + +func TestNetPipeBidirections(t *testing.T) { + c1, c2, err := pipes.NetPipe() + if err != nil { + t.Fatal(err) + } + + done := make(chan struct{}) + + go func() { + msgs := 1000 + size := 8 + pingTemplate := "ping %03d" + pongTemplate := "pong %03d" + + // netPipe is blocking, so writes are emitted asynchronously + go func() { + for i := 0; i < msgs; i++ { + msg := []byte(fmt.Sprintf(pingTemplate, i)) + + _, err := c1.Write(msg) + if err != nil { + t.Fatal(err) + } + } + }() + + // netPipe is blocking, so reads for pong are emitted asynchronously + go func() { + for i := 0; i < msgs; i++ { + expected := []byte(fmt.Sprintf(pongTemplate, i)) + + out := make([]byte, size) + _, err 
:= c1.Read(out) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(expected, out) { + t.Fatalf("expected %#v, got %#v", expected, out) + } + } + + done <- struct{}{} + }() + + // expect to read pings, and respond with pongs to the alternate connection + for i := 0; i < msgs; i++ { + expected := []byte(fmt.Sprintf(pingTemplate, i)) + + out := make([]byte, size) + _, err := c2.Read(out) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(expected, out) { + t.Fatalf("expected %#v, got %#v", expected, out) + } else { + msg := []byte(fmt.Sprintf(pongTemplate, i)) + + _, err := c2.Write(msg) + if err != nil { + t.Fatal(err) + } + } + } + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("test timeout") + } +} diff --git a/p2p/simulations/adapters/types.go b/p2p/simulations/adapters/types.go index 089f50ea20..68690aebea 100644 --- a/p2p/simulations/adapters/types.go +++ b/p2p/simulations/adapters/types.go @@ -26,11 +26,14 @@ import ( "strconv" "github.com/docker/docker/pkg/reexec" + "github.com/gorilla/websocket" "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" "github.com/tomochain/tomochain/p2p/enode" + "github.com/tomochain/tomochain/p2p/enr" "github.com/tomochain/tomochain/rpc" ) @@ -49,7 +52,7 @@ type Node interface { Client() (*rpc.Client, error) // ServeRPC serves RPC requests over the given connection - ServeRPC(net.Conn) error + ServeRPC(*websocket.Conn) error // Start starts the node with the given snapshots Start(snapshots map[string][]byte) error @@ -90,12 +93,26 @@ type NodeConfig struct { // Name is a human friendly name for the node like "node01" Name string + // Use an existing database instead of a temporary one if non-empty + DataDir string + // Services are the names of the services which should be run when // starting the node (for SimNodes it should be the names of services // contained in SimAdapter.services, for other nodes it should be // services registered by calling the RegisterService function) Services []string + // Properties are the names of the properties this node should hold + // within running services (e.g. "bootnode", "lightnode" or any custom values) + // These values need to be checked and acted upon by node Services + Properties []string + + // Enode + node *enode.Node + + // ENR Record with entries to overwrite + Record enr.Record + // function to sanction or prevent suggesting a peer Reachable func(id enode.ID) bool @@ -109,6 +126,7 @@ type nodeConfigJSON struct { PrivateKey string `json:"private_key"` Name string `json:"name"` Services []string `json:"services"` + Properties []string `json:"properties"` EnableMsgEvents bool `json:"enable_msg_events"` Port uint16 `json:"port"` } @@ -120,6 +138,7 @@ func (n *NodeConfig) MarshalJSON() ([]byte, error) { ID: n.ID.String(), Name: n.Name, Services: n.Services, + Properties: n.Properties, Port: n.Port, EnableMsgEvents: n.EnableMsgEvents, } @@ -157,6 +176,7 @@ func (n *NodeConfig) UnmarshalJSON(data []byte) error { n.Name = confJSON.Name n.Services = confJSON.Services + n.Properties = confJSON.Properties n.Port = confJSON.Port n.EnableMsgEvents = confJSON.EnableMsgEvents @@ -165,26 +185,27 @@ func (n *NodeConfig) UnmarshalJSON(data []byte) error { // Node returns the node descriptor represented by the config. 
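+// The descriptor is built by initEnode and is nil until that has run.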
func (n *NodeConfig) Node() *enode.Node {
-	return enode.NewV4(&n.PrivateKey.PublicKey, net.IP{127, 0, 0, 1}, int(n.Port), int(n.Port))
+	return n.node
 }
 
 // RandomNodeConfig returns node configuration with a randomly generated ID and
 // PrivateKey
 func RandomNodeConfig() *NodeConfig {
-	key, err := crypto.GenerateKey()
+	prvkey, err := crypto.GenerateKey()
 	if err != nil {
 		panic("unable to generate key")
 	}
-	id := enode.PubkeyToIDV4(&key.PublicKey)
 
 	port, err := assignTCPPort()
 	if err != nil {
 		panic("unable to assign tcp port")
 	}
+
+	enodId := enode.PubkeyToIDV4(&prvkey.PublicKey)
 	return &NodeConfig{
-		ID:              id,
-		Name:            fmt.Sprintf("node_%s", id.String()),
-		PrivateKey:      key,
+		PrivateKey:      prvkey,
+		ID:              enodId,
+		Name:            fmt.Sprintf("node_%s", enodId.String()),
 		Port:            port,
 		EnableMsgEvents: true,
 	}
@@ -254,3 +275,30 @@ func RegisterServices(services Services) {
 		os.Exit(0)
 	}
 }
+
+// initEnode adds the host part to the configuration's ENR, signs it and
+// creates the corresponding enode object in the configuration
+func (n *NodeConfig) initEnode(ip net.IP, tcpport int, udpport int) error {
+	enrIp := enr.IP(ip)
+	n.Record.Set(&enrIp)
+	enrTcpPort := enr.TCP(tcpport)
+	n.Record.Set(&enrTcpPort)
+	enrUdpPort := enr.UDP(udpport)
+	n.Record.Set(&enrUdpPort)
+
+	err := enode.SignV4(&n.Record, n.PrivateKey)
+	if err != nil {
+		return fmt.Errorf("unable to generate ENR: %v", err)
+	}
+	nod, err := enode.New(enode.V4ID{}, &n.Record)
+	if err != nil {
+		return fmt.Errorf("unable to create enode: %v", err)
+	}
+	log.Trace("simnode new", "record", n.Record)
+	n.node = nod
+	return nil
+}
+
+func (n *NodeConfig) initDummyEnode() error {
+	return n.initEnode(net.IPv4(127, 0, 0, 1), 0, 0)
+}
diff --git a/p2p/simulations/http.go b/p2p/simulations/http.go
index d7ed380a4e..6faac7f751 100644
--- a/p2p/simulations/http.go
+++ b/p2p/simulations/http.go
@@ -29,14 +29,14 @@ import (
 	"strings"
 	"sync"
 
-	"github.com/tomochain/tomochain/p2p/enode"
-
+	"github.com/gorilla/websocket"
 	"github.com/julienschmidt/httprouter"
+
 	"github.com/tomochain/tomochain/event"
 	"github.com/tomochain/tomochain/p2p"
+	"github.com/tomochain/tomochain/p2p/enode"
 	"github.com/tomochain/tomochain/p2p/simulations/adapters"
 	"github.com/tomochain/tomochain/rpc"
-	"golang.org/x/net/websocket"
 )
 
 // DefaultClient is the default simulation API client which expects the API
@@ -562,7 +562,8 @@ func (s *Server) LoadSnapshot(w http.ResponseWriter, req *http.Request) {
 
 // CreateNode creates a node in the network using the given configuration
 func (s *Server) CreateNode(w http.ResponseWriter, req *http.Request) {
-	config := adapters.RandomNodeConfig()
+	config := &adapters.NodeConfig{}
+
 	err := json.NewDecoder(req.Body).Decode(config)
 	if err != nil && err != io.EOF {
 		http.Error(w, err.Error(), http.StatusBadRequest)
@@ -654,16 +655,20 @@ func (s *Server) Options(w http.ResponseWriter, req *http.Request) {
 	w.WriteHeader(http.StatusOK)
 }
 
+var wsUpgrade = websocket.Upgrader{
+	CheckOrigin: func(*http.Request) bool { return true },
+}
+
 // NodeRPC forwards RPC requests to a node in the network via a WebSocket
 // connection
 func (s *Server) NodeRPC(w http.ResponseWriter, req *http.Request) {
-	node := req.Context().Value("node").(*Node)
-
-	handler := func(conn *websocket.Conn) {
-		node.ServeRPC(conn)
+	conn, err := wsUpgrade.Upgrade(w, req, nil)
+	if err != nil {
+		return
 	}
-
-	websocket.Server{Handler: handler}.ServeHTTP(w, req)
+	defer conn.Close()
+	node := req.Context().Value("node").(*Node)
+	node.ServeRPC(conn)
 }
 
 // ServeHTTP implements the http.Handler interface by
delegating to the @@ -699,14 +704,14 @@ func (s *Server) JSON(w http.ResponseWriter, status int, data interface{}) { json.NewEncoder(w).Encode(data) } -// wrapHandler returns an httprouter.Handle which wraps an http.HandlerFunc by +// wrapHandler returns a httprouter.Handle which wraps a http.HandlerFunc by // populating request.Context with any objects from the URL params func (s *Server) wrapHandler(handler http.HandlerFunc) httprouter.Handle { return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - ctx := req.Context() + ctx := context.Background() if id := params.ByName("nodeid"); id != "" { var nodeID enode.ID diff --git a/p2p/simulations/http_test.go b/p2p/simulations/http_test.go index d557e5f996..50e7fb1a7a 100644 --- a/p2p/simulations/http_test.go +++ b/p2p/simulations/http_test.go @@ -18,16 +18,21 @@ package simulations import ( "context" + "flag" "fmt" "math/rand" "net/http/httptest" + "os" "reflect" "sync" "sync/atomic" "testing" "time" + "github.com/mattn/go-colorable" + "github.com/tomochain/tomochain/event" + "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" "github.com/tomochain/tomochain/p2p/enode" @@ -35,6 +40,15 @@ import ( "github.com/tomochain/tomochain/rpc" ) +func TestMain(m *testing.M) { + loglevel := flag.Int("loglevel", 2, "verbosity of logs") + + flag.Parse() + log.PrintOrigins(true) + log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(colorable.NewColorableStderr(), log.TerminalFormat(true)))) + os.Exit(m.Run()) +} + // testService implements the node.Service interface and provides protocols // and APIs which are useful for testing nodes in a simulation network type testService struct { @@ -51,6 +65,8 @@ type testService struct { state atomic.Value } +func (t *testService) SaveData() {} + func newTestService(ctx *adapters.ServiceContext) (node.Service, error) { svc := &testService{ id: ctx.Config.ID, @@ -117,8 +133,6 @@ func (t *testService) Start(server *p2p.Server) error { return nil } -func (t *testService) SaveData() { -} func (t *testService) Stop() error { return nil } @@ -282,6 +296,7 @@ var testServices = adapters.Services{ } func testHTTPServer(t *testing.T) (*Network, *httptest.Server) { + t.Helper() adapter := adapters.NewSimAdapter(testServices) network := NewNetwork(adapter, &NetworkConfig{ DefaultService: "test", @@ -350,7 +365,8 @@ func startTestNetwork(t *testing.T, client *Client) []string { nodeCount := 2 nodeIDs := make([]string, nodeCount) for i := 0; i < nodeCount; i++ { - node, err := client.CreateNode(nil) + config := adapters.RandomNodeConfig() + node, err := client.CreateNode(config) if err != nil { t.Fatalf("error creating node: %s", err) } @@ -407,14 +423,15 @@ type expectEvents struct { } func (t *expectEvents) nodeEvent(id string, up bool) *Event { + node := Node{ + Config: &adapters.NodeConfig{ + ID: enode.HexID(id), + }, + Up: up, + } return &Event{ Type: EventTypeNode, - Node: &Node{ - Config: &adapters.NodeConfig{ - ID: enode.HexID(id), - }, - Up: up, - }, + Node: &node, } } @@ -466,6 +483,7 @@ loop: } func (t *expectEvents) expect(events ...*Event) { + t.Helper() timeout := time.After(10 * time.Second) i := 0 for { @@ -529,7 +547,9 @@ func TestHTTPNodeRPC(t *testing.T) { // start a node in the network client := NewClient(s.URL) - node, err := client.CreateNode(nil) + + config := 
adapters.RandomNodeConfig() + node, err := client.CreateNode(config) if err != nil { t.Fatalf("error creating node: %s", err) } @@ -583,15 +603,33 @@ func TestHTTPNodeRPC(t *testing.T) { // TestHTTPSnapshot tests creating and loading network snapshots func TestHTTPSnapshot(t *testing.T) { // start the server - _, s := testHTTPServer(t) + network, s := testHTTPServer(t) defer s.Close() + var eventsDone = make(chan struct{}) + count := 1 + eventsDoneChan := make(chan *Event) + eventSub := network.Events().Subscribe(eventsDoneChan) + go func() { + defer eventSub.Unsubscribe() + for event := range eventsDoneChan { + if event.Type == EventTypeConn && !event.Control { + count-- + if count == 0 { + eventsDone <- struct{}{} + return + } + } + } + }() + // create a two-node network client := NewClient(s.URL) nodeCount := 2 nodes := make([]*p2p.NodeInfo, nodeCount) for i := 0; i < nodeCount; i++ { - node, err := client.CreateNode(nil) + config := adapters.RandomNodeConfig() + node, err := client.CreateNode(config) if err != nil { t.Fatalf("error creating node: %s", err) } @@ -618,7 +656,7 @@ func TestHTTPSnapshot(t *testing.T) { } states[i] = state } - + <-eventsDone // create a snapshot snap, err := client.CreateSnapshot() if err != nil { @@ -632,9 +670,23 @@ func TestHTTPSnapshot(t *testing.T) { } // create another network - _, s = testHTTPServer(t) + network2, s := testHTTPServer(t) defer s.Close() client = NewClient(s.URL) + count = 1 + eventSub = network2.Events().Subscribe(eventsDoneChan) + go func() { + defer eventSub.Unsubscribe() + for event := range eventsDoneChan { + if event.Type == EventTypeConn && !event.Control { + count-- + if count == 0 { + eventsDone <- struct{}{} + return + } + } + } + }() // subscribe to events so we can check them later events := make(chan *Event, 100) @@ -649,6 +701,7 @@ func TestHTTPSnapshot(t *testing.T) { if err := client.LoadSnapshot(snap); err != nil { t.Fatalf("error loading snapshot: %s", err) } + <-eventsDone // check the nodes and connection exists net, err := client.GetNetwork() @@ -674,6 +727,9 @@ func TestHTTPSnapshot(t *testing.T) { if conn.Other.String() != nodes[1].ID { t.Fatalf("expected connection to have other=%q, got other=%q", nodes[1].ID, conn.Other) } + if !conn.Up { + t.Fatal("should be up") + } // check the node states were restored for i, node := range nodes { From c283f8a2c12abdb159034d980517bce2f3134134 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 5 Dec 2023 12:38:14 +0700 Subject: [PATCH 114/119] Fix unit tests --- p2p/protocols/protocol.go | 311 -------------------------- p2p/protocols/protocol_test.go | 389 --------------------------------- rpc/server_test.go | 2 +- rpc/subscription_test.go | 4 +- 4 files changed, 3 insertions(+), 703 deletions(-) delete mode 100644 p2p/protocols/protocol.go delete mode 100644 p2p/protocols/protocol_test.go diff --git a/p2p/protocols/protocol.go b/p2p/protocols/protocol.go deleted file mode 100644 index cb334d318e..0000000000 --- a/p2p/protocols/protocol.go +++ /dev/null @@ -1,311 +0,0 @@ -// Copyright 2017 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. 
-// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -/* -Package protocols is an extension to p2p. It offers a user friendly simple way to define -devp2p subprotocols by abstracting away code standardly shared by protocols. - -* automate assigments of code indexes to messages -* automate RLP decoding/encoding based on reflecting -* provide the forever loop to read incoming messages -* standardise error handling related to communication -* standardised handshake negotiation -* TODO: automatic generation of wire protocol specification for peers - -*/ -package protocols - -import ( - "context" - "fmt" - "reflect" - "sync" - - "github.com/tomochain/tomochain/p2p" -) - -// error codes used by this protocol scheme -const ( - ErrMsgTooLong = iota - ErrDecode - ErrWrite - ErrInvalidMsgCode - ErrInvalidMsgType - ErrHandshake - ErrNoHandler - ErrHandler -) - -// error description strings associated with the codes -var errorToString = map[int]string{ - ErrMsgTooLong: "Message too long", - ErrDecode: "Invalid message (RLP error)", - ErrWrite: "Error sending message", - ErrInvalidMsgCode: "Invalid message code", - ErrInvalidMsgType: "Invalid message type", - ErrHandshake: "Handshake error", - ErrNoHandler: "No handler registered error", - ErrHandler: "Message handler error", -} - -/* -Error implements the standard go error interface. -Use: - - errorf(code, format, params ...interface{}) - -Prints as: - - :
- -where description is given by code in errorToString -and details is fmt.Sprintf(format, params...) - -exported field Code can be checked -*/ -type Error struct { - Code int - message string - format string - params []interface{} -} - -func (e Error) Error() (message string) { - if len(e.message) == 0 { - name, ok := errorToString[e.Code] - if !ok { - panic("invalid message code") - } - e.message = name - if e.format != "" { - e.message += ": " + fmt.Sprintf(e.format, e.params...) - } - } - return e.message -} - -func errorf(code int, format string, params ...interface{}) *Error { - return &Error{ - Code: code, - format: format, - params: params, - } -} - -// Spec is a protocol specification including its name and version as well as -// the types of messages which are exchanged -type Spec struct { - // Name is the name of the protocol, often a three-letter word - Name string - - // Version is the version number of the protocol - Version uint - - // MaxMsgSize is the maximum accepted length of the message payload - MaxMsgSize uint32 - - // Messages is a list of message data types which this protocol uses, with - // each message type being sent with its array index as the code (so - // [&foo{}, &bar{}, &baz{}] would send foo, bar and baz with codes - // 0, 1 and 2 respectively) - // each message must have a single unique data type - Messages []interface{} - - initOnce sync.Once - codes map[reflect.Type]uint64 - types map[uint64]reflect.Type -} - -func (s *Spec) init() { - s.initOnce.Do(func() { - s.codes = make(map[reflect.Type]uint64, len(s.Messages)) - s.types = make(map[uint64]reflect.Type, len(s.Messages)) - for i, msg := range s.Messages { - code := uint64(i) - typ := reflect.TypeOf(msg) - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - s.codes[typ] = code - s.types[code] = typ - } - }) -} - -// Length returns the number of message types in the protocol -func (s *Spec) Length() uint64 { - return uint64(len(s.Messages)) -} - -// GetCode returns the message code of a type, and boolean second argument is -// false if the message type is not found -func (s *Spec) GetCode(msg interface{}) (uint64, bool) { - s.init() - typ := reflect.TypeOf(msg) - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - code, ok := s.codes[typ] - return code, ok -} - -// NewMsg construct a new message type given the code -func (s *Spec) NewMsg(code uint64) (interface{}, bool) { - s.init() - typ, ok := s.types[code] - if !ok { - return nil, false - } - return reflect.New(typ).Interface(), true -} - -// Peer represents a remote peer or protocol instance that is running on a peer connection with -// a remote peer -type Peer struct { - *p2p.Peer // the p2p.Peer object representing the remote - rw p2p.MsgReadWriter // p2p.MsgReadWriter to send messages to and read messages from - spec *Spec -} - -// NewPeer constructs a new peer -// this constructor is called by the p2p.Protocol#Run function -// the first two arguments are the arguments passed to p2p.Protocol.Run function -// the third argument is the Spec describing the protocol -func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer { - return &Peer{ - Peer: p, - rw: rw, - spec: spec, - } -} - -// Run starts the forever loop that handles incoming messages -// called within the p2p.Protocol#Run function -// the handler argument is a function which is called for each message received -// from the remote peer, a returned error causes the loop to exit -// resulting in disconnection -func (p *Peer) Run(handler func(msg interface{}) error) error { 
-	for {
-		if err := p.handleIncoming(handler); err != nil {
-			return err
-		}
-	}
-}
-
-// Drop disconnects a peer.
-// TODO: may need to implement protocol drop only? don't want to kick off the peer
-// if they are useful for other protocols
-func (p *Peer) Drop(err error) {
-	p.Disconnect(p2p.DiscSubprotocolError)
-}
-
-// Send takes a message, encodes it in RLP, finds the right message code and sends the
-// message off to the peer
-// this low level call will be wrapped by libraries providing routed or broadcast sends
-// but often just used to forward and push messages to directly connected peers
-func (p *Peer) Send(msg interface{}) error {
-	code, found := p.spec.GetCode(msg)
-	if !found {
-		return errorf(ErrInvalidMsgType, "%v", code)
-	}
-	return p2p.Send(p.rw, code, msg)
-}
-
-// handleIncoming is called on each cycle of the main forever loop that
-// dispatches incoming messages
-// if this returns an error the loop returns and the peer is disconnected with the error
-// this generic handler
-// * checks message size,
-// * checks for out-of-range message codes,
-// * handles decoding with reflection,
-// * calls handlers as callbacks
-func (p *Peer) handleIncoming(handle func(msg interface{}) error) error {
-	msg, err := p.rw.ReadMsg()
-	if err != nil {
-		return err
-	}
-	// make sure that the payload has been fully consumed
-	defer msg.Discard()
-
-	if msg.Size > p.spec.MaxMsgSize {
-		return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize)
-	}
-
-	val, ok := p.spec.NewMsg(msg.Code)
-	if !ok {
-		return errorf(ErrInvalidMsgCode, "%v", msg.Code)
-	}
-	if err := msg.Decode(val); err != nil {
-		return errorf(ErrDecode, "<= %v: %v", msg, err)
-	}
-
-	// call the registered handler callbacks
-	// a registered callback takes the decoded message as an argument, as an interface
-	// which the handler is supposed to cast to the appropriate type
-	// it is entirely safe not to check the cast in the handler since the handler is
-	// chosen based on the proper type in the first place
-	if err := handle(val); err != nil {
-		return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err)
-	}
-	return nil
-}
-
-// Handshake negotiates a handshake on the peer connection
-// * arguments
-//   * context
-//   * the local handshake to be sent to the remote peer
-//   * function to be called on the remote handshake (can be nil)
-// * expects a remote handshake back of the same type
-// * the dialing peer needs to send the handshake first and then waits for remote
-// * the listening peer waits for the remote handshake and then sends it
-// returns the remote handshake and an error
-func (p *Peer) Handshake(ctx context.Context, hs interface{}, verify func(interface{}) error) (rhs interface{}, err error) {
-	if _, ok := p.spec.GetCode(hs); !ok {
-		return nil, errorf(ErrHandshake, "unknown handshake message type: %T", hs)
-	}
-	errc := make(chan error, 2)
-	handle := func(msg interface{}) error {
-		rhs = msg
-		if verify != nil {
-			return verify(rhs)
-		}
-		return nil
-	}
-	send := func() { errc <- p.Send(hs) }
-	receive := func() { errc <- p.handleIncoming(handle) }
-
-	go func() {
-		if p.Inbound() {
-			receive()
-			send()
-		} else {
-			send()
-			receive()
-		}
-	}()
-
-	for i := 0; i < 2; i++ {
-		select {
-		case err = <-errc:
-		case <-ctx.Done():
-			err = ctx.Err()
-		}
-		if err != nil {
-			return nil, errorf(ErrHandshake, err.Error())
-		}
-	}
-	return rhs, nil
-}
diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go
deleted file mode 100644
index 0e4523e403..0000000000
--- a/p2p/protocols/protocol_test.go
+++ /dev/null
@@ -1,389 +0,0 @@
-// Copyright 2017 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package protocols
-
-import (
-	"context"
-	"errors"
-	"fmt"
-	"testing"
-	"time"
-
-	"github.com/tomochain/tomochain/p2p"
-	"github.com/tomochain/tomochain/p2p/enode"
-	"github.com/tomochain/tomochain/p2p/simulations/adapters"
-	p2ptest "github.com/tomochain/tomochain/p2p/testing"
-)
-
-// handshake message type
-type hs0 struct {
-	C uint
-}
-
-// message to kill/drop the peer with nodeID
-type kill struct {
-	C enode.ID
-}
-
-// message to drop connection
-type drop struct {
-}
-
-// protoHandshake represents module-independent aspects of the protocol and is
-// the first message peers send and receive as part of the initial exchange
-type protoHandshake struct {
-	Version   uint   // local and remote peer should have identical version
-	NetworkID string // local and remote peer should have identical network id
-}
-
-// checkProtoHandshake verifies local and remote protoHandshakes match
-func checkProtoHandshake(testVersion uint, testNetworkID string) func(interface{}) error {
-	return func(rhs interface{}) error {
-		remote := rhs.(*protoHandshake)
-		if remote.NetworkID != testNetworkID {
-			return fmt.Errorf("%s (!= %s)", remote.NetworkID, testNetworkID)
-		}
-
-		if remote.Version != testVersion {
-			return fmt.Errorf("%d (!= %d)", remote.Version, testVersion)
-		}
-		return nil
-	}
-}
-
-// newProtocol sets up a protocol
-// the run function here demonstrates a typical protocol using peerPool, handshake
-// and messages registered to handlers
-func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) error {
-	spec := &Spec{
-		Name:       "test",
-		Version:    42,
-		MaxMsgSize: 10 * 1024,
-		Messages: []interface{}{
-			protoHandshake{},
-			hs0{},
-			kill{},
-			drop{},
-		},
-	}
-	return func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
-		peer := NewPeer(p, rw, spec)
-
-		// initiate one-off protohandshake and check validity
-		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-		defer cancel()
-		phs := &protoHandshake{42, "420"}
-		hsCheck := checkProtoHandshake(phs.Version, phs.NetworkID)
-		_, err := peer.Handshake(ctx, phs, hsCheck)
-		if err != nil {
-			return err
-		}
-
-		lhs := &hs0{42}
-		// module handshake demonstrating a simple repeatable exchange of same-type message
-		hs, err := peer.Handshake(ctx, lhs, nil)
-		if err != nil {
-			return err
-		}
-
-		if rmhs := hs.(*hs0); rmhs.C > lhs.C {
-			return fmt.Errorf("handshake mismatch remote %v > local %v", rmhs.C, lhs.C)
-		}
-
-		handle := func(msg interface{}) error {
-			switch msg := msg.(type) {
-
-			case *protoHandshake:
-				return errors.New("duplicate handshake")
-
-			case *hs0:
-				rhs := msg
-				if rhs.C > lhs.C {
-					return fmt.Errorf("handshake mismatch remote %v > local %v", rhs.C, lhs.C)
-				}
-				lhs.C
+= rhs.C - return peer.Send(lhs) - - case *kill: - // demonstrates use of peerPool, killing another peer connection as a response to a message - id := msg.C - pp.Get(id).Drop(errors.New("killed")) - return nil - - case *drop: - // for testing we can trigger self induced disconnect upon receiving drop message - return errors.New("dropped") - - default: - return fmt.Errorf("unknown message type: %T", msg) - } - } - - pp.Add(peer) - defer pp.Remove(peer) - return peer.Run(handle) - } -} - -func protocolTester(t *testing.T, pp *p2ptest.TestPeerPool) *p2ptest.ProtocolTester { - conf := adapters.RandomNodeConfig() - return p2ptest.NewProtocolTester(t, conf.ID, 2, newProtocol(pp)) -} - -func protoHandshakeExchange(id enode.ID, proto *protoHandshake) []p2ptest.Exchange { - - return []p2ptest.Exchange{ - { - Expects: []p2ptest.Expect{ - { - Code: 0, - Msg: &protoHandshake{42, "420"}, - Peer: id, - }, - }, - }, - { - Triggers: []p2ptest.Trigger{ - { - Code: 0, - Msg: proto, - Peer: id, - }, - }, - }, - } -} - -func runProtoHandshake(t *testing.T, proto *protoHandshake, errs ...error) { - pp := p2ptest.NewTestPeerPool() - s := protocolTester(t, pp) - // TODO: make this more than one handshake - id := s.IDs[0] - if err := s.TestExchanges(protoHandshakeExchange(id, proto)...); err != nil { - t.Fatal(err) - } - var disconnects []*p2ptest.Disconnect - for i, err := range errs { - disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) - } - if err := s.TestDisconnected(disconnects...); err != nil { - t.Fatal(err) - } -} - -func TestProtoHandshakeVersionMismatch(t *testing.T) { - runProtoHandshake(t, &protoHandshake{41, "420"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 41 (!= 42)").Error())) -} - -func TestProtoHandshakeNetworkIDMismatch(t *testing.T) { - runProtoHandshake(t, &protoHandshake{42, "421"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 421 (!= 420)").Error())) -} - -func TestProtoHandshakeSuccess(t *testing.T) { - runProtoHandshake(t, &protoHandshake{42, "420"}) -} - -func moduleHandshakeExchange(id enode.ID, resp uint) []p2ptest.Exchange { - - return []p2ptest.Exchange{ - { - Expects: []p2ptest.Expect{ - { - Code: 1, - Msg: &hs0{42}, - Peer: id, - }, - }, - }, - { - Triggers: []p2ptest.Trigger{ - { - Code: 1, - Msg: &hs0{resp}, - Peer: id, - }, - }, - }, - } -} - -func runModuleHandshake(t *testing.T, resp uint, errs ...error) { - pp := p2ptest.NewTestPeerPool() - s := protocolTester(t, pp) - id := s.IDs[0] - if err := s.TestExchanges(protoHandshakeExchange(id, &protoHandshake{42, "420"})...); err != nil { - t.Fatal(err) - } - if err := s.TestExchanges(moduleHandshakeExchange(id, resp)...); err != nil { - t.Fatal(err) - } - var disconnects []*p2ptest.Disconnect - for i, err := range errs { - disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) - } - if err := s.TestDisconnected(disconnects...); err != nil { - t.Fatal(err) - } -} - -func TestModuleHandshakeError(t *testing.T) { - runModuleHandshake(t, 43, fmt.Errorf("handshake mismatch remote 43 > local 42")) -} - -func TestModuleHandshakeSuccess(t *testing.T) { - runModuleHandshake(t, 42) -} - -// testing complex interactions over multiple peers, relaying, dropping -func testMultiPeerSetup(a, b enode.ID) []p2ptest.Exchange { - - return []p2ptest.Exchange{ - { - Label: "primary handshake", - Expects: []p2ptest.Expect{ - { - Code: 0, - Msg: &protoHandshake{42, "420"}, - Peer: a, - }, - { - Code: 0, - Msg: &protoHandshake{42, "420"}, - Peer: b, - }, - }, - }, - { - 
Label: "module handshake",
-			Triggers: []p2ptest.Trigger{
-				{
-					Code: 0,
-					Msg:  &protoHandshake{42, "420"},
-					Peer: a,
-				},
-				{
-					Code: 0,
-					Msg:  &protoHandshake{42, "420"},
-					Peer: b,
-				},
-			},
-			Expects: []p2ptest.Expect{
-				{
-					Code: 1,
-					Msg:  &hs0{42},
-					Peer: a,
-				},
-				{
-					Code: 1,
-					Msg:  &hs0{42},
-					Peer: b,
-				},
-			},
-		},
-
-		{Label: "alternative module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{41}, Peer: a},
-			{Code: 1, Msg: &hs0{41}, Peer: b}}},
-		{Label: "repeated module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{1}, Peer: a}}},
-		{Label: "receiving repeated module handshake", Expects: []p2ptest.Expect{{Code: 1, Msg: &hs0{43}, Peer: a}}}}
-}
-
-func runMultiplePeers(t *testing.T, peer int, errs ...error) {
-	pp := p2ptest.NewTestPeerPool()
-	s := protocolTester(t, pp)
-
-	if err := s.TestExchanges(testMultiPeerSetup(s.IDs[0], s.IDs[1])...); err != nil {
-		t.Fatal(err)
-	}
-	// after some exchanges of messages, we can test state changes
-	// here this is simply demonstrated by the peerPool
-	// after the handshake negotiations peers must be added to the pool
-	// time.Sleep(1)
-	tick := time.NewTicker(10 * time.Millisecond)
-	timeout := time.NewTimer(1 * time.Second)
-WAIT:
-	for {
-		select {
-		case <-tick.C:
-			if pp.Has(s.IDs[0]) {
-				break WAIT
-			}
-		case <-timeout.C:
-			t.Fatal("timeout")
-		}
-	}
-	if !pp.Has(s.IDs[1]) {
-		t.Fatalf("missing peer test-1: %v (%v)", pp, s.IDs)
-	}
-
-	// peer 0 sends a kill request for the peer at the given index
-	err := s.TestExchanges(p2ptest.Exchange{
-		Triggers: []p2ptest.Trigger{
-			{
-				Code: 2,
-				Msg:  &kill{s.IDs[peer]},
-				Peer: s.IDs[0],
-			},
-		},
-	})
-
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	// the peer not killed sends a drop request
-	err = s.TestExchanges(p2ptest.Exchange{
-		Triggers: []p2ptest.Trigger{
-			{
-				Code: 3,
-				Msg:  &drop{},
-				Peer: s.IDs[(peer+1)%2],
-			},
-		},
-	})
-
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	// check the actual disconnect errors on the individual peers
-	var disconnects []*p2ptest.Disconnect
-	for i, err := range errs {
-		disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err})
-	}
-	if err := s.TestDisconnected(disconnects...); err != nil {
-		t.Fatal(err)
-	}
-	// test if disconnected peers have been removed from peerPool
-	if pp.Has(s.IDs[peer]) {
-		t.Fatalf("peer test-%v not dropped: %v (%v)", peer, pp, s.IDs)
-	}
-
-}
-
-func TestMultiplePeersDropSelf(t *testing.T) {
-	runMultiplePeers(t, 0,
-		fmt.Errorf("subprotocol error"),
-		fmt.Errorf("Message handler error: (msg code 3): dropped"),
-	)
-}
-
-func TestMultiplePeersDropOther(t *testing.T) {
-	runMultiplePeers(t, 1,
-		fmt.Errorf("Message handler error: (msg code 3): dropped"),
-		fmt.Errorf("subprotocol error"),
-	)
-}
diff --git a/rpc/server_test.go b/rpc/server_test.go
index 2a6926abf9..454d23a29a 100644
--- a/rpc/server_test.go
+++ b/rpc/server_test.go
@@ -77,7 +77,7 @@ func runTestScript(t *testing.T, file string) {
 	clientConn, serverConn := net.Pipe()
 	defer clientConn.Close()
 
-	go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions)
+	go server.ServeCodec(NewCodec(serverConn), OptionMethodInvocation|OptionSubscriptions)
 	readbuf := bufio.NewReader(clientConn)
 	for _, line := range strings.Split(string(content), "\n") {
 		line = strings.TrimSpace(line)
diff --git a/rpc/subscription_test.go b/rpc/subscription_test.go
index 87ab4120af..cfec7c620a 100644
--- a/rpc/subscription_test.go
+++ b/rpc/subscription_test.go
@@ -68,7 +68,7 @@ func TestSubscriptions(t *testing.T) {
t.Fatalf("unable to register test service %v", err) } } - go server.ServeCodec(NewJSONCodec(serverConn), 0) + go server.ServeCodec(NewCodec(serverConn), 0) defer server.Stop() // wait for message and write them to the given channels @@ -130,7 +130,7 @@ func TestServerUnsubscribe(t *testing.T) { service := ¬ificationTestService{unsubscribed: make(chan string)} server.RegisterName("nftest2", service) p1, p2 := net.Pipe() - go server.ServeCodec(NewJSONCodec(p1), OptionMethodInvocation|OptionSubscriptions) + go server.ServeCodec(NewCodec(p1), OptionMethodInvocation|OptionSubscriptions) p2.SetDeadline(time.Now().Add(10 * time.Second)) From a62b058693489d8678504b1a14f2b31a4a027f36 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 5 Dec 2023 12:46:15 +0700 Subject: [PATCH 115/119] Use gorilla websocket for cmd/faucet --- cmd/faucet/faucet.go | 258 ++++++++++++++++++------------------------- common/format.go | 42 +++++++ 2 files changed, 147 insertions(+), 153 deletions(-) diff --git a/cmd/faucet/faucet.go b/cmd/faucet/faucet.go index 33a17f35b7..114f5a1cb6 100644 --- a/cmd/faucet/faucet.go +++ b/cmd/faucet/faucet.go @@ -41,6 +41,8 @@ import ( "sync" "time" + "github.com/gorilla/websocket" + "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/accounts/keystore" "github.com/tomochain/tomochain/common" @@ -58,7 +60,6 @@ import ( "github.com/tomochain/tomochain/p2p/enode" "github.com/tomochain/tomochain/p2p/nat" "github.com/tomochain/tomochain/params" - "golang.org/x/net/websocket" ) var ( @@ -77,9 +78,6 @@ var ( accJSONFlag = flag.String("account.json", "", "Key json file to fund user requests with") accPassFlag = flag.String("account.pass", "", "Decryption password to access faucet funds") - githubUser = flag.String("github.user", "", "GitHub user to authenticate with for Gist access") - githubToken = flag.String("github.token", "", "GitHub personal token to access Gists with") - captchaToken = flag.String("captcha.token", "", "Recaptcha site key to authenticate client side") captchaSecret = flag.String("captcha.secret", "", "Recaptcha secret key to authenticate server side") @@ -91,6 +89,11 @@ var ( ether = new(big.Int).Exp(big.NewInt(10), big.NewInt(18), nil) ) +var ( + gitCommit = "" // Git SHA1 commit hash of the release (set via linker flags) + gitDate = "" // Git commit date YYYYMMDD of the release (set via linker flags) +) + func main() { // Parse the flags and set up the logger to print everything requested flag.Parse() @@ -160,7 +163,8 @@ func main() { if blob, err = ioutil.ReadFile(*accPassFlag); err != nil { log.Crit("Failed to read account password contents", "file", *accPassFlag, "err", err) } - pass := string(blob) + // Delete trailing newline in password + pass := strings.TrimSuffix(string(blob), "\n") ks := keystore.NewKeyStore(filepath.Join(os.Getenv("HOME"), ".faucet", "keys"), keystore.StandardScryptN, keystore.StandardScryptP) if blob, err = ioutil.ReadFile(*accJSONFlag); err != nil { @@ -200,7 +204,9 @@ type faucet struct { index []byte // Index page to serve up on the web keystore *keystore.KeyStore // Keystore containing the single signer - account accounts.Account // StateAccount funding user faucet requests + account accounts.Account // Account funding user faucet requests + head *types.Header // Current head header of the faucet + balance *big.Int // Current balance of the faucet nonce uint64 // Current pending nonce of the faucet price *big.Int // Current gas price to issue funds with @@ -215,8 +221,8 @@ type faucet struct { func 
newFaucet(genesis *core.Genesis, port int, enodes []*discv5.Node, network uint64, stats string, ks *keystore.KeyStore, index []byte) (*faucet, error) { // Assemble the raw devp2p protocol stack stack, err := node.New(&node.Config{ - Name: "tomo", - Version: params.Version, + Name: "geth", + Version: params.VersionWithCommit(gitCommit), DataDir: filepath.Join(os.Getenv("HOME"), ".faucet"), P2P: p2p.Config{ NAT: nat.Any(), @@ -255,8 +261,10 @@ func newFaucet(genesis *core.Genesis, port int, enodes []*discv5.Node, network u return nil, err } for _, boot := range enodes { - old, _ := enode.ParseV4(boot.String()) - stack.Server().AddPeer(old) + old, err := enode.Parse(enode.ValidSchemes, boot.String()) + if err == nil { + stack.Server().AddPeer(old) + } } // Attach to the client and retrieve and interesting metadatas api, err := stack.Attach() @@ -280,7 +288,7 @@ func newFaucet(genesis *core.Genesis, port int, enodes []*discv5.Node, network u // close terminates the Ethereum connection and tears down the faucet. func (f *faucet) close() error { - return f.stack.Stop() + return f.stack.Close() } // listenAndServe registers the HTTP handlers for the faucet and boots it up @@ -289,8 +297,7 @@ func (f *faucet) listenAndServe(port int) error { go f.loop() http.HandleFunc("/", f.webHandler) - http.Handle("/api", websocket.Handler(f.apiHandler)) - + http.HandleFunc("/api", f.apiHandler) return http.ListenAndServe(fmt.Sprintf(":%d", port), nil) } @@ -301,7 +308,13 @@ func (f *faucet) webHandler(w http.ResponseWriter, r *http.Request) { } // apiHandler handles requests for Ether grants and transaction statuses. -func (f *faucet) apiHandler(conn *websocket.Conn) { +func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + // Start tracking the connection and drop at the end defer conn.Close() @@ -324,35 +337,31 @@ func (f *faucet) apiHandler(conn *websocket.Conn) { head *types.Header balance *big.Int nonce uint64 - err error ) - for { - // Attempt to retrieve the stats, may error on no faucet connectivity - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - head, err = f.client.HeaderByNumber(ctx, nil) - if err == nil { - balance, err = f.client.BalanceAt(ctx, f.account.Address, head.Number) - if err == nil { - nonce, err = f.client.NonceAt(ctx, f.account.Address, nil) - } + for head == nil || balance == nil { + // Retrieve the current stats cached by the faucet + f.lock.RLock() + if f.head != nil { + head = types.CopyHeader(f.head) } - cancel() + if f.balance != nil { + balance = new(big.Int).Set(f.balance) + } + nonce = f.nonce + f.lock.RUnlock() - // If stats retrieval failed, wait a bit and retry - if err != nil { - if err = sendError(conn, errors.New("Faucet offline: "+err.Error())); err != nil { + if head == nil || balance == nil { + // Report the faucet offline until initial stats are ready + if err = sendError(conn, errors.New("Faucet offline")); err != nil { log.Warn("Failed to send faucet error to client", "err", err) return } time.Sleep(3 * time.Second) - continue } - // Initial stats reported successfully, proceed with user interaction - break } // Send over the initial stats and the latest header if err = send(conn, map[string]interface{}{ - "funds": balance.Div(balance, ether), + "funds": new(big.Int).Div(balance, ether), "funded": nonce, "peers": f.stack.Server().PeerCount(), "requests": f.reqs, @@ -372,7 +381,7 @@ func (f *faucet) 
apiHandler(conn *websocket.Conn) { Tier uint `json:"tier"` Captcha string `json:"captcha"` } - if err = websocket.JSON.Receive(conn, &msg); err != nil { + if err = conn.ReadJSON(&msg); err != nil { return } if !*noauthFlag && !strings.HasPrefix(msg.URL, "https://gist.github.com/") && !strings.HasPrefix(msg.URL, "https://twitter.com/") && @@ -441,16 +450,20 @@ func (f *faucet) apiHandler(conn *websocket.Conn) { return } continue + case strings.HasPrefix(msg.URL, "https://plus.google.com/"): + if err = sendError(conn, errors.New("Google+ authentication discontinued as the service was sunset")); err != nil { + log.Warn("Failed to send Google+ deprecation to client", "err", err) + return + } + continue case strings.HasPrefix(msg.URL, "https://twitter.com/"): username, avatar, address, err = authTwitter(msg.URL) - case strings.HasPrefix(msg.URL, "https://plus.google.com/"): - username, avatar, address, err = authGooglePlus(msg.URL) case strings.HasPrefix(msg.URL, "https://www.facebook.com/"): username, avatar, address, err = authFacebook(msg.URL) case *noauthFlag: username, avatar, address, err = authNoAuth(msg.URL) default: - err = errors.New("Something funky happened, please open an issue at https://github.com/tomochain/tomochain/issues") + err = errors.New("Something funky happened, please open an issue at https://github.com/ethereum/go-ethereum/issues") } if err != nil { if err = sendError(conn, err); err != nil { @@ -498,7 +511,10 @@ func (f *faucet) apiHandler(conn *websocket.Conn) { Time: time.Now(), Tx: signed, }) - f.timeouts[username] = time.Now().Add(time.Duration(*minutesFlag*int(math.Pow(3, float64(msg.Tier)))) * time.Minute) + timeout := time.Duration(*minutesFlag*int(math.Pow(3, float64(msg.Tier)))) * time.Minute + grace := timeout / 288 // 24h timeout => 5m grace + + f.timeouts[username] = time.Now().Add(timeout - grace) fund = true } f.lock.Unlock() @@ -522,6 +538,47 @@ func (f *faucet) apiHandler(conn *websocket.Conn) { } } +// refresh attempts to retrieve the latest header from the chain and extract the +// associated faucet balance and nonce for connectivity caching. +func (f *faucet) refresh(head *types.Header) error { + // Ensure a state update does not run for too long + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // If no header was specified, use the current chain head + var err error + if head == nil { + if head, err = f.client.HeaderByNumber(ctx, nil); err != nil { + return err + } + } + // Retrieve the balance, nonce and gas price from the current head + var ( + balance *big.Int + nonce uint64 + price *big.Int + ) + if balance, err = f.client.BalanceAt(ctx, f.account.Address, head.Number); err != nil { + return err + } + if nonce, err = f.client.NonceAt(ctx, f.account.Address, head.Number); err != nil { + return err + } + if price, err = f.client.SuggestGasPrice(ctx); err != nil { + return err + } + // Everything succeeded, update the cached stats and eject old requests + f.lock.Lock() + f.head, f.balance = head, balance + f.price, f.nonce = price, nonce + for len(f.reqs) > 0 && f.reqs[0].Tx.Nonce() < f.nonce { + f.reqs = f.reqs[1:] + } + f.lock.Unlock() + + return nil +} + // loop keeps waiting for interesting events and pushes them out to connected // websockets. 
func (f *faucet) loop() { @@ -539,45 +596,27 @@ func (f *faucet) loop() { go func() { for head := range update { // New chain head arrived, query the current stats and stream to clients - var ( - balance *big.Int - nonce uint64 - price *big.Int - err error - ) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - balance, err = f.client.BalanceAt(ctx, f.account.Address, head.Number) - if err == nil { - nonce, err = f.client.NonceAt(ctx, f.account.Address, nil) - if err == nil { - price, err = f.client.SuggestGasPrice(ctx) - } + timestamp := time.Unix(head.Time.Int64(), 0) + if time.Since(timestamp) > time.Hour { + log.Warn("Skipping faucet refresh, head too old", "number", head.Number, "hash", head.Hash(), "age", common.PrettyAge(timestamp)) + continue } - cancel() - - // If querying the data failed, try for the next block - if err != nil { + if err := f.refresh(head); err != nil { log.Warn("Failed to update faucet state", "block", head.Number, "hash", head.Hash(), "err", err) continue - } else { - log.Info("Updated faucet state", "block", head.Number, "hash", head.Hash(), "balance", balance, "nonce", nonce, "price", price) } // Faucet state retrieved, update locally and send to clients - balance = new(big.Int).Div(balance, ether) + f.lock.RLock() + log.Info("Updated faucet state", "number", head.Number, "hash", head.Hash(), "age", common.PrettyAge(timestamp), "balance", f.balance, "nonce", f.nonce, "price", f.price) - f.lock.Lock() - f.price, f.nonce = price, nonce - for len(f.reqs) > 0 && f.reqs[0].Tx.Nonce() < f.nonce { - f.reqs = f.reqs[1:] - } - f.lock.Unlock() + balance := new(big.Int).Div(f.balance, ether) + peers := f.stack.Server().PeerCount() - f.lock.RLock() for _, conn := range f.conns { if err := send(conn, map[string]interface{}{ "funds": balance, "funded": f.nonce, - "peers": f.stack.Server().PeerCount(), + "peers": peers, "requests": f.reqs, }, time.Second); err != nil { log.Warn("Failed to send stats to client", "err", err) @@ -623,7 +662,7 @@ func send(conn *websocket.Conn, value interface{}, timeout time.Duration) error timeout = 60 * time.Second } conn.SetWriteDeadline(time.Now().Add(timeout)) - return websocket.JSON.Send(conn, value) + return conn.WriteJSON(value) } // sendError transmits an error to the remote end of the websocket, also setting @@ -638,59 +677,6 @@ func sendSuccess(conn *websocket.Conn, msg string) error { return send(conn, map[string]string{"success": msg}, time.Second) } -// authGitHub tries to authenticate a faucet request using GitHub gists, returning -// the username, avatar URL and Ethereum address to fund on success. 
-func authGitHub(url string) (string, string, common.Address, error) { - // Retrieve the gist from the GitHub Gist APIs - parts := strings.Split(url, "/") - req, _ := http.NewRequest("GET", "https://api.github.com/gists/"+parts[len(parts)-1], nil) - if *githubUser != "" { - req.SetBasicAuth(*githubUser, *githubToken) - } - res, err := http.DefaultClient.Do(req) - if err != nil { - return "", "", common.Address{}, err - } - var gist struct { - Owner struct { - Login string `json:"login"` - } `json:"owner"` - Files map[string]struct { - Content string `json:"content"` - } `json:"files"` - } - err = json.NewDecoder(res.Body).Decode(&gist) - res.Body.Close() - if err != nil { - return "", "", common.Address{}, err - } - if gist.Owner.Login == "" { - return "", "", common.Address{}, errors.New("Anonymous Gists not allowed") - } - // Iterate over all the files and look for Ethereum addresses - var address common.Address - for _, file := range gist.Files { - content := strings.TrimSpace(file.Content) - if len(content) == 2+common.AddressLength*2 { - address = common.HexToAddress(content) - } - } - if address == (common.Address{}) { - return "", "", common.Address{}, errors.New("No Ethereum address found to fund") - } - // Validate the user's existence since the API is unhelpful here - if res, err = http.Head("https://github.com/" + gist.Owner.Login); err != nil { - return "", "", common.Address{}, err - } - res.Body.Close() - - if res.StatusCode != 200 { - return "", "", common.Address{}, errors.New("Invalid user... boom!") - } - // Everything passed validation, return the gathered infos - return gist.Owner.Login + "@github", fmt.Sprintf("https://github.com/%s.png?size=64", gist.Owner.Login), address, nil -} - // authTwitter tries to authenticate a faucet request using Twitter posts, returning // the username, avatar URL and Ethereum address to fund on success. func authTwitter(url string) (string, string, common.Address, error) { @@ -730,40 +716,6 @@ func authTwitter(url string) (string, string, common.Address, error) { return username + "@twitter", avatar, address, nil } -// authGooglePlus tries to authenticate a faucet request using GooglePlus posts, -// returning the username, avatar URL and Ethereum address to fund on success. -func authGooglePlus(url string) (string, string, common.Address, error) { - // Ensure the user specified a meaningful URL, no fancy nonsense - parts := strings.Split(url, "/") - if len(parts) < 4 || parts[len(parts)-2] != "posts" { - return "", "", common.Address{}, errors.New("Invalid Google+ post URL") - } - username := parts[len(parts)-3] - - // Google's API isn't really friendly with direct links. Still, we don't - // want to do ask read permissions from users, so just load the public posts and - // scrape it for the Ethereum address and profile URL. 
- res, err := http.Get(url) - if err != nil { - return "", "", common.Address{}, err - } - defer res.Body.Close() - - body, err := ioutil.ReadAll(res.Body) - if err != nil { - return "", "", common.Address{}, err - } - address := common.HexToAddress(string(regexp.MustCompile("0x[0-9a-fA-F]{40}").Find(body))) - if address == (common.Address{}) { - return "", "", common.Address{}, errors.New("No Ethereum address found to fund") - } - var avatar string - if parts = regexp.MustCompile("src=\"([^\"]+googleusercontent.com[^\"]+photo.jpg)\"").FindStringSubmatch(string(body)); len(parts) == 2 { - avatar = parts[1] - } - return username + "@google+", avatar, address, nil -} - // authFacebook tries to authenticate a faucet request using Facebook posts, // returning the username, avatar URL and Ethereum address to fund on success. func authFacebook(url string) (string, string, common.Address, error) { diff --git a/common/format.go b/common/format.go index fccc299620..6fc21af719 100644 --- a/common/format.go +++ b/common/format.go @@ -38,3 +38,45 @@ func (d PrettyDuration) String() string { } return label } + +// PrettyAge is a pretty printed version of a time.Duration value that rounds +// the values up to a single most significant unit, days/weeks/years included. +type PrettyAge time.Time + +// ageUnits is a list of units the age pretty printing uses. +var ageUnits = []struct { + Size time.Duration + Symbol string +}{ + {12 * 30 * 24 * time.Hour, "y"}, + {30 * 24 * time.Hour, "mo"}, + {7 * 24 * time.Hour, "w"}, + {24 * time.Hour, "d"}, + {time.Hour, "h"}, + {time.Minute, "m"}, + {time.Second, "s"}, +} + +// String implements the Stringer interface, allowing pretty printing of duration +// values rounded to the most significant time unit. +func (t PrettyAge) String() string { + // Calculate the time difference and handle the 0 cornercase + diff := time.Since(time.Time(t)) + if diff < time.Second { + return "0" + } + // Accumulate a precision of 3 components before returning + result, prec := "", 0 + + for _, unit := range ageUnits { + if diff > unit.Size { + result = fmt.Sprintf("%s%d%s", result, diff/unit.Size, unit.Symbol) + diff %= unit.Size + + if prec += 1; prec >= 3 { + break + } + } + } + return result +} From 3952c1e573f9706efca5e8cd7e8c5f29c1a1685c Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Tue, 5 Dec 2023 12:52:35 +0700 Subject: [PATCH 116/119] Send websocket ping when connection is idle --- rpc/endpoints.go | 30 -------------- rpc/websocket.go | 101 +++++++++++++++++++++++++++++++++++++---------- 2 files changed, 80 insertions(+), 51 deletions(-) diff --git a/rpc/endpoints.go b/rpc/endpoints.go index d91b00fea1..0b42f1c8ca 100644 --- a/rpc/endpoints.go +++ b/rpc/endpoints.go @@ -51,36 +51,6 @@ func StartHTTPEndpoint(endpoint string, apis []API, modules []string, cors []str return listener, handler, err } -// StartWSEndpoint starts a websocket endpoint -func StartWSEndpoint(endpoint string, apis []API, modules []string, wsOrigins []string, exposeAll bool) (net.Listener, *Server, error) { - // Generate the whitelist based on the allowed modules - whitelist := make(map[string]bool) - for _, module := range modules { - whitelist[module] = true - } - // Register all the APIs exposed by the services - handler := NewServer() - for _, api := range apis { - if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { - if err := handler.RegisterName(api.Namespace, api.Service); err != nil { - return nil, nil, err - } - log.Debug("WebSocket registered", "service", 
api.Service, "namespace", api.Namespace) - } - } - // All APIs registered, start the HTTP listener - var ( - listener net.Listener - err error - ) - if listener, err = net.Listen("tcp", endpoint); err != nil { - return nil, nil, err - } - go NewWSServer(wsOrigins, handler).Serve(listener) - return listener, handler, err - -} - // StartIPCEndpoint starts an IPC endpoint. func StartIPCEndpoint(ipcEndpoint string, apis []API) (net.Listener, *Server, error) { // Register all the APIs exposed by the services. diff --git a/rpc/websocket.go b/rpc/websocket.go index f7663979c3..2b44ea239e 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -25,6 +25,7 @@ import ( "os" "strings" "sync" + "time" mapset "github.com/deckarep/golang-set" "github.com/gorilla/websocket" @@ -33,19 +34,14 @@ import ( ) const ( - wsReadBuffer = 1024 - wsWriteBuffer = 1024 + wsReadBuffer = 1024 + wsWriteBuffer = 1024 + wsPingInterval = 60 * time.Second + wsPingWriteTimeout = 5 * time.Second ) var wsBufferPool = new(sync.Pool) -// NewWSServer creates a new websocket RPC server around an API provider. -// -// Deprecated: use Server.WebsocketHandler -func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { - return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} -} - // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. // // allowedOrigins should be a comma-separated list of allowed origin URLs. @@ -125,21 +121,13 @@ func (e wsHandshakeError) Error() string { return s } -// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server -// that is listening on the given endpoint. -// -// The context is used for the initial connection establishment. It does not -// affect subsequent interactions with the client. -func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { +// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server +// that is listening on the given endpoint using the provided dialer. +func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { endpoint, header, err := wsClientHeaders(endpoint, origin) if err != nil { return nil, err } - dialer := websocket.Dialer{ - ReadBufferSize: wsReadBuffer, - WriteBufferSize: wsWriteBuffer, - WriteBufferPool: wsBufferPool, - } return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { conn, resp, err := dialer.DialContext(ctx, endpoint, header) if err != nil { @@ -153,6 +141,20 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error }) } +// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server +// that is listening on the given endpoint. +// +// The context is used for the initial connection establishment. It does not +// affect subsequent interactions with the client. 
+func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { + dialer := websocket.Dialer{ + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + } + return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) +} + func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { endpointURL, err := url.Parse(endpoint) if err != nil { @@ -170,7 +172,64 @@ func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { return endpointURL.String(), header, nil } +type websocketCodec struct { + *jsonCodec + conn *websocket.Conn + + wg sync.WaitGroup + pingReset chan struct{} +} + func newWebsocketCodec(conn *websocket.Conn) ServerCodec { conn.SetReadLimit(maxRequestContentLength) - return NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON) + wc := &websocketCodec{ + jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), + conn: conn, + pingReset: make(chan struct{}, 1), + } + wc.wg.Add(1) + go wc.pingLoop() + return wc +} + +func (wc *websocketCodec) close() { + wc.jsonCodec.close() + wc.wg.Wait() +} + +func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { + err := wc.jsonCodec.writeJSON(ctx, v) + if err == nil { + // Notify pingLoop to delay the next idle ping. + select { + case wc.pingReset <- struct{}{}: + default: + } + } + return err +} + +// pingLoop sends periodic ping frames when the connection is idle. +func (wc *websocketCodec) pingLoop() { + var timer = time.NewTimer(wsPingInterval) + defer wc.wg.Done() + defer timer.Stop() + + for { + select { + case <-wc.closed(): + return + case <-wc.pingReset: + if !timer.Stop() { + <-timer.C + } + timer.Reset(wsPingInterval) + case <-timer.C: + wc.jsonCodec.encMu.Lock() + wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) + wc.conn.WriteMessage(websocket.PingMessage, nil) + wc.jsonCodec.encMu.Unlock() + timer.Reset(wsPingInterval) + } + } } From 09ad227001b0e7bca0b5e73910ebf5aa6b28bd8d Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 7 Dec 2023 14:27:16 +0700 Subject: [PATCH 117/119] Re-add NewWSServer for compatibility --- rpc/websocket.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/rpc/websocket.go b/rpc/websocket.go index 2b44ea239e..80c9e2d0bc 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -42,6 +42,13 @@ const ( var wsBufferPool = new(sync.Pool) +// NewWSServer creates a new websocket RPC server around an API provider. +// +// Deprecated: use Server.WebsocketHandler +func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { + return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} +} + // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. // // allowedOrigins should be a comma-separated list of allowed origin URLs. 
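
The ping loop added in this patch keys off write activity: every successful application write pushes the next ping back via the buffered pingReset channel, so ping frames only go out on a genuinely idle connection. The sketch below distills that pattern into a standalone helper, assuming only gorilla/websocket; the idlePinger type, its field names, and the 5-second write deadline are illustrative choices, not part of this patch series.

package wsping

import (
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

// idlePinger sends a ping frame whenever the connection has seen no
// application writes for a full interval. Writers call activity() after
// every successful write, mirroring the pingReset channel in the codec.
type idlePinger struct {
	conn     *websocket.Conn
	writeMu  sync.Mutex    // must be held for every write to conn
	reset    chan struct{} // buffered, capacity 1
	closed   chan struct{} // closed when the connection shuts down
	interval time.Duration
}

// activity delays the next idle ping; safe to call from any goroutine.
func (p *idlePinger) activity() {
	select {
	case p.reset <- struct{}{}:
	default: // a reset is already pending, dropping this one is fine
	}
}

// loop runs until closed is closed, pinging only when idle.
func (p *idlePinger) loop() {
	timer := time.NewTimer(p.interval)
	defer timer.Stop()
	for {
		select {
		case <-p.closed:
			return
		case <-p.reset:
			// Stop and drain before Reset: the timer may have fired
			// already, leaving an unconsumed tick in its channel.
			if !timer.Stop() {
				<-timer.C
			}
			timer.Reset(p.interval)
		case <-timer.C:
			p.writeMu.Lock()
			p.conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
			p.conn.WriteMessage(websocket.PingMessage, nil)
			p.writeMu.Unlock()
			timer.Reset(p.interval)
		}
	}
}

The stop/drain/reset sequence on the timer is the important detail: calling Reset on a timer that has already fired, without draining its channel, would leave a stale tick behind and trigger a spurious ping.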
From 94663c60bbe604d426f6604aa6a0d8c92def96ce Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 7 Dec 2023 14:39:59 +0700 Subject: [PATCH 118/119] Separate and increase wsMessageSizeLimit to 32 MB --- rpc/http_test.go | 91 ++++++++++++++++++++++++++++++++++++----- rpc/testservice_test.go | 10 +++++ rpc/websocket.go | 5 ++- rpc/websocket_test.go | 27 ++++++++++++ 4 files changed, 121 insertions(+), 12 deletions(-) diff --git a/rpc/http_test.go b/rpc/http_test.go index b3f694d8af..b75af67c52 100644 --- a/rpc/http_test.go +++ b/rpc/http_test.go @@ -23,32 +23,103 @@ import ( "testing" ) +func confirmStatusCode(t *testing.T, got, want int) { + t.Helper() + if got == want { + return + } + if gotName := http.StatusText(got); len(gotName) > 0 { + if wantName := http.StatusText(want); len(wantName) > 0 { + t.Fatalf("response status code: got %d (%s), want %d (%s)", got, gotName, want, wantName) + } + } + t.Fatalf("response status code: got %d, want %d", got, want) +} + +func confirmRequestValidationCode(t *testing.T, method, contentType, body string, expectedStatusCode int) { + t.Helper() + request := httptest.NewRequest(method, "http://url.com", strings.NewReader(body)) + if len(contentType) > 0 { + request.Header.Set("Content-Type", contentType) + } + code, err := validateRequest(request) + if code == 0 { + if err != nil { + t.Errorf("validation: got error %v, expected nil", err) + } + } else if err == nil { + t.Errorf("validation: code %d: got nil, expected error", code) + } + confirmStatusCode(t, code, expectedStatusCode) +} + func TestHTTPErrorResponseWithDelete(t *testing.T) { - testHTTPErrorResponse(t, http.MethodDelete, contentType, "", http.StatusMethodNotAllowed) + confirmRequestValidationCode(t, http.MethodDelete, contentType, "", http.StatusMethodNotAllowed) } func TestHTTPErrorResponseWithPut(t *testing.T) { - testHTTPErrorResponse(t, http.MethodPut, contentType, "", http.StatusMethodNotAllowed) + confirmRequestValidationCode(t, http.MethodPut, contentType, "", http.StatusMethodNotAllowed) } func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) { body := make([]rune, maxRequestContentLength+1) - testHTTPErrorResponse(t, + confirmRequestValidationCode(t, http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge) } func TestHTTPErrorResponseWithEmptyContentType(t *testing.T) { - testHTTPErrorResponse(t, http.MethodPost, "", "", http.StatusUnsupportedMediaType) + confirmRequestValidationCode(t, http.MethodPost, "", "", http.StatusUnsupportedMediaType) } func TestHTTPErrorResponseWithValidRequest(t *testing.T) { - testHTTPErrorResponse(t, http.MethodPost, contentType, "", 0) + confirmRequestValidationCode(t, http.MethodPost, contentType, "", 0) } -func testHTTPErrorResponse(t *testing.T, method, contentType, body string, expected int) { - request := httptest.NewRequest(method, "http://url.com", strings.NewReader(body)) - request.Header.Set("content-type", contentType) - if code, _ := validateRequest(request); code != expected { - t.Fatalf("response code should be %d not %d", expected, code) +func confirmHTTPRequestYieldsStatusCode(t *testing.T, method, contentType, body string, expectedStatusCode int) { + t.Helper() + s := Server{} + ts := httptest.NewServer(&s) + defer ts.Close() + + request, err := http.NewRequest(method, ts.URL, strings.NewReader(body)) + if err != nil { + t.Fatalf("failed to create a valid HTTP request: %v", err) + } + if len(contentType) > 0 { + request.Header.Set("Content-Type", contentType) + } + resp, err := 
http.DefaultClient.Do(request) + if err != nil { + t.Fatalf("request failed: %v", err) + } + confirmStatusCode(t, resp.StatusCode, expectedStatusCode) +} + +func TestHTTPResponseWithEmptyGet(t *testing.T) { + confirmHTTPRequestYieldsStatusCode(t, http.MethodGet, "", "", http.StatusOK) +} + +// This checks that maxRequestContentLength is not applied to the response of a request. +func TestHTTPRespBodyUnlimited(t *testing.T) { + const respLength = maxRequestContentLength * 3 + + s := NewServer() + defer s.Stop() + s.RegisterName("test", largeRespService{respLength}) + ts := httptest.NewServer(s) + defer ts.Close() + + c, err := DialHTTP(ts.URL) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + var r string + if err := c.Call(&r, "test_largeResp"); err != nil { + t.Fatal(err) + } + if len(r) != respLength { + t.Fatalf("response has wrong length %d, want %d", len(r), respLength) } } diff --git a/rpc/testservice_test.go b/rpc/testservice_test.go index 470870bacf..32aefbb75e 100644 --- a/rpc/testservice_test.go +++ b/rpc/testservice_test.go @@ -20,6 +20,7 @@ import ( "context" "encoding/binary" "errors" + "strings" "sync" "time" ) @@ -178,3 +179,12 @@ func (s *notificationTestService) HangSubscription(ctx context.Context, val int) }() return subscription, nil } + +// largeRespService generates arbitrary-size JSON responses. +type largeRespService struct { + length int +} + +func (x largeRespService) LargeResp() string { + return strings.Repeat("x", x.length) +} diff --git a/rpc/websocket.go b/rpc/websocket.go index 80c9e2d0bc..0b54ef6d48 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -36,8 +36,9 @@ import ( const ( wsReadBuffer = 1024 wsWriteBuffer = 1024 - wsPingInterval = 60 * time.Second + wsPingInterval = 30 * time.Second wsPingWriteTimeout = 5 * time.Second + wsMessageSizeLimit = 32 * 1024 * 1024 ) var wsBufferPool = new(sync.Pool) @@ -188,7 +189,7 @@ type websocketCodec struct { } func newWebsocketCodec(conn *websocket.Conn) ServerCodec { - conn.SetReadLimit(maxRequestContentLength) + conn.SetReadLimit(wsMessageSizeLimit) wc := &websocketCodec{ jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), conn: conn, diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index a00e8da0f6..d67aacc296 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -156,6 +156,33 @@ func TestClientWebsocketPing(t *testing.T) { } } +// This checks that the websocket transport can deal with large messages. +func TestClientWebsocketLargeMessage(t *testing.T) { + var ( + srv = NewServer() + httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) + wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + ) + defer srv.Stop() + defer httpsrv.Close() + + respLength := wsMessageSizeLimit - 50 + srv.RegisterName("test", largeRespService{respLength}) + + c, err := DialWebsocket(context.Background(), wsURL, "") + if err != nil { + t.Fatal(err) + } + + var r string + if err := c.Call(&r, "test_largeResp"); err != nil { + t.Fatal("call failed:", err) + } + if len(r) != respLength { + t.Fatalf("response has wrong length %d, want %d", len(r), respLength) + } +} + // wsPingTestServer runs a WebSocket server which accepts a single subscription request. // When a value arrives on sendPing, the server sends a ping frame, waits for a matching // pong and finally delivers a single subscription result. 
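
The read limit set above is enforced per message on the receiving side: once an inbound message grows past the limit, gorilla/websocket sends a close frame to the peer and returns ErrReadLimit to the reader. That is why the websocket limit has to be decoupled from maxRequestContentLength before large responses can flow. Below is a minimal standalone sketch of that failure mode, independent of this codebase; the 1 KiB and 2 KiB sizes are arbitrary.

package main

import (
	"fmt"
	"log"
	"net/http"
	"net/http/httptest"
	"strings"

	"github.com/gorilla/websocket"
)

func main() {
	// Server: upgrade the connection and send one 2 KiB text message.
	upgrader := websocket.Upgrader{}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()
		conn.WriteMessage(websocket.TextMessage, []byte(strings.Repeat("x", 2048)))
	}))
	defer srv.Close()

	// Client: allow at most 1 KiB per inbound message.
	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	conn.SetReadLimit(1024)

	// The oversized message trips the limit: ReadMessage returns
	// websocket.ErrReadLimit and the connection is closed.
	if _, _, err := conn.ReadMessage(); err != nil {
		fmt.Println("read failed as expected:", err)
	}
}

Note the asymmetry the patch preserves: plain HTTP requests stay capped at maxRequestContentLength, while websocket messages get the larger, dedicated wsMessageSizeLimit.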
From 1c3d200827d8900703006c44c6cbbcfc28398756 Mon Sep 17 00:00:00 2001 From: Dang Nhat Trinh Date: Thu, 7 Dec 2023 17:03:18 +0700 Subject: [PATCH 119/119] Set websocket pong deadline timeout --- rpc/websocket.go | 6 +++ rpc/websocket_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/rpc/websocket.go b/rpc/websocket.go index 0b54ef6d48..7dbb8ca932 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -38,6 +38,7 @@ const ( wsWriteBuffer = 1024 wsPingInterval = 30 * time.Second wsPingWriteTimeout = 5 * time.Second + wsPongTimeout = 30 * time.Second wsMessageSizeLimit = 32 * 1024 * 1024 ) @@ -190,6 +191,10 @@ type websocketCodec struct { func newWebsocketCodec(conn *websocket.Conn) ServerCodec { conn.SetReadLimit(wsMessageSizeLimit) + conn.SetPongHandler(func(appData string) error { + conn.SetReadDeadline(time.Time{}) + return nil + }) wc := &websocketCodec{ jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), conn: conn, @@ -236,6 +241,7 @@ func (wc *websocketCodec) pingLoop() { wc.jsonCodec.encMu.Lock() wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) wc.conn.WriteMessage(websocket.PingMessage, nil) + wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout)) wc.jsonCodec.encMu.Unlock() timer.Reset(wsPingInterval) } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index d67aacc296..9676bdab21 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -18,11 +18,15 @@ package rpc import ( "context" + "io" "net" "net/http" "net/http/httptest" + "net/http/httputil" + "net/url" "reflect" "strings" + "sync/atomic" "testing" "time" @@ -183,6 +187,63 @@ func TestClientWebsocketLargeMessage(t *testing.T) { } } +func TestClientWebsocketSevered(t *testing.T) { + t.Parallel() + + var ( + server = wsPingTestServer(t, nil) + ctx = context.Background() + ) + defer server.Shutdown(ctx) + + u, err := url.Parse("http://" + server.Addr) + if err != nil { + t.Fatal(err) + } + rproxy := httputil.NewSingleHostReverseProxy(u) + var severable *severableReadWriteCloser + rproxy.ModifyResponse = func(response *http.Response) error { + severable = &severableReadWriteCloser{ReadWriteCloser: response.Body.(io.ReadWriteCloser)} + response.Body = severable + return nil + } + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + wsURL := "ws:" + strings.TrimPrefix(frontendProxy.URL, "http:") + client, err := DialWebsocket(ctx, wsURL, "") + if err != nil { + t.Fatalf("client dial error: %v", err) + } + defer client.Close() + + resultChan := make(chan int) + sub, err := client.EthSubscribe(ctx, resultChan, "foo") + if err != nil { + t.Fatalf("client subscribe error: %v", err) + } + + // sever the connection + severable.Sever() + + // Wait for subscription error. + timeout := time.NewTimer(3 * wsPingInterval) + defer timeout.Stop() + for { + select { + case err := <-sub.Err(): + t.Log("client subscription error:", err) + return + case result := <-resultChan: + t.Error("unexpected result:", result) + return + case <-timeout.C: + t.Error("didn't get any error within the test timeout") + return + } + } +} + // wsPingTestServer runs a WebSocket server which accepts a single subscription request. // When a value arrives on sendPing, the server sends a ping frame, waits for a matching // pong and finally delivers a single subscription result. 
@@ -284,3 +345,31 @@ func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <- } } } + +// severableReadWriteCloser wraps an io.ReadWriteCloser and provides a Sever() method to drop writes and read empty. +type severableReadWriteCloser struct { + io.ReadWriteCloser + severed int32 // atomic +} + +func (s *severableReadWriteCloser) Sever() { + atomic.StoreInt32(&s.severed, 1) +} + +func (s *severableReadWriteCloser) Read(p []byte) (n int, err error) { + if atomic.LoadInt32(&s.severed) > 0 { + return 0, nil + } + return s.ReadWriteCloser.Read(p) +} + +func (s *severableReadWriteCloser) Write(p []byte) (n int, err error) { + if atomic.LoadInt32(&s.severed) > 0 { + return len(p), nil + } + return s.ReadWriteCloser.Write(p) +} + +func (s *severableReadWriteCloser) Close() error { + return s.ReadWriteCloser.Close() +}
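
The pong deadline is what turns the ping loop into real dead-peer detection: each ping arms a read deadline, and only a timely pong, handled on the reader goroutine, disarms it, so a severed connection now fails the blocked read instead of idling forever. Below is a hedged, standalone version of the same handshake, assuming gorilla/websocket. The keepAlive helper and its constants are illustrative names; it presumes another goroutine is blocked reading on conn (gorilla only runs control handlers from within reads), and real code must serialize the ping write with other writers, as the codec does with encMu.

package wskeepalive

import (
	"time"

	"github.com/gorilla/websocket"
)

const (
	pingInterval = 30 * time.Second // idle gap between pings
	pongTimeout  = 30 * time.Second // how long to wait for the pong
	writeTimeout = 5 * time.Second  // deadline for writing the ping itself
)

// keepAlive pings the peer periodically and arms a read deadline after each
// ping. Unless the pong handler clears the deadline in time, the blocked
// read on the connection fails with a timeout and the caller can tear the
// connection down.
func keepAlive(conn *websocket.Conn, stop <-chan struct{}) {
	conn.SetPongHandler(func(string) error {
		// A pong arrived in time: disarm the deadline until the next ping.
		return conn.SetReadDeadline(time.Time{})
	})
	ticker := time.NewTicker(pingInterval)
	defer ticker.Stop()
	for {
		select {
		case <-stop:
			return
		case <-ticker.C:
			conn.SetWriteDeadline(time.Now().Add(writeTimeout))
			if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
			// If neither a pong nor any other frame arrives before this
			// deadline, the reader fails: exactly the severed-connection
			// case TestClientWebsocketSevered exercises above.
			conn.SetReadDeadline(time.Now().Add(pongTimeout))
		}
	}
}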