forked from ethereum/go-ethereum
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tracker.go
206 lines (191 loc) · 6.27 KB
/
tracker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
package test
import (
"context"
"errors"
"fmt"
"runtime/debug"
"sync"
ethereum "github.com/celo-org/celo-blockchain"
"github.com/celo-org/celo-blockchain/common"
"github.com/celo-org/celo-blockchain/core/types"
"github.com/celo-org/celo-blockchain/ethclient"
"github.com/celo-org/celo-blockchain/event"
)
// errStopped is the sentinel returned to waiters when the tracker's stop
// channel is closed before the awaited condition became true.
var errStopped = errors.New("transaction tracker closed")
// Tracker tracks processed blocks and transactions through a subscription with
// an ethclient. It provides the ability to check whether blocks or
// transactions have been processed and to wait till those blocks or
// transactions have been processed.
//
// A Tracker is created with NewTracker, started with StartTracking and shut
// down with StopTracking.
type Tracker struct {
// client is set by StartTracking and used by track to fetch the full block
// for each received head.
client *ethclient.Client
// heads receives the new-head notifications delivered by the subscription.
heads chan *types.Header
// sub is the active new-head subscription; it is nil while the tracker is
// not started (StartTracking and StopTracking use it as the started flag).
sub ethereum.Subscription
// wg is used by StopTracking to wait for the track goroutine to exit.
wg sync.WaitGroup
// processedTxs maps transaction hashes to the block they were processed in.
processedTxs map[common.Hash]*types.Block
// processedBlocks maps block number to processed blocks.
processedBlocks map[uint64]*types.Block
// processedMu guards processedTxs and processedBlocks.
processedMu sync.Mutex
// stopCh is created by StartTracking and closed by StopTracking to signal
// the track goroutine and any pending await calls.
stopCh chan struct{}
// newBlock is sent on by track after each block has been recorded; await
// subscribes to it in order to re-check its condition.
newBlock event.Feed
}
// NewTracker returns a Tracker with its head channel (buffered to 10 entries)
// and its lookup maps allocated. The returned tracker is idle until
// StartTracking is called on it.
func NewTracker() *Tracker {
	tr := new(Tracker)
	tr.heads = make(chan *types.Header, 10)
	tr.processedTxs = map[common.Hash]*types.Block{}
	tr.processedBlocks = map[uint64]*types.Block{}
	return tr
}
// GetProcessedTx returns the processed transaction with the given hash or nil
// if the tracker has not seen a processed transaction with the given hash.
//
// The lookup is performed while holding processedMu, so it is safe to call
// concurrently with the tracking goroutine.
func (tr *Tracker) GetProcessedTx(hash common.Hash) *types.Transaction {
	tr.processedMu.Lock()
	defer tr.processedMu.Unlock()

	// A hash that has not been recorded yields a nil *types.Block from the
	// map. Return nil explicitly instead of invoking Transaction on that nil
	// block, so the documented "nil if not seen" contract does not depend on
	// Block.Transaction tolerating a nil receiver.
	block, ok := tr.processedTxs[hash]
	if !ok || block == nil {
		return nil
	}
	return block.Transaction(hash)
}
// GetProcessedBlockForTx looks up the block in which the transaction with the
// given hash was processed. The result is nil when the tracker has not
// recorded a transaction with that hash.
func (tr *Tracker) GetProcessedBlockForTx(hash common.Hash) *types.Block {
	tr.processedMu.Lock()
	block := tr.processedTxs[hash]
	tr.processedMu.Unlock()
	return block
}
// GetProcessedBlock looks up the processed block recorded under the given
// block number. The result is nil when the tracker has not recorded a block
// with that number.
func (tr *Tracker) GetProcessedBlock(num uint64) *types.Block {
	tr.processedMu.Lock()
	block := tr.processedBlocks[num]
	tr.processedMu.Unlock()
	return block
}
// StartTracking subscribes to new head events on client and launches the
// goroutine that processes them. It returns an error if the tracker is already
// started or if the subscription cannot be established; in both cases no
// goroutine is launched.
func (tr *Tracker) StartTracking(client *ethclient.Client) error {
	// A non-nil subscription marks the tracker as running.
	if tr.sub != nil {
		return errors.New("attempted to start already started tracker")
	}

	// The subscription client will buffer 20000 notifications before closing
	// the subscription, if that happens the Err() chan will return
	// ErrSubscriptionQueueOverflow
	headSub, subErr := client.SubscribeNewHead(context.Background(), tr.heads)
	if subErr != nil {
		return subErr
	}

	tr.stopCh = make(chan struct{})
	tr.sub = headSub
	tr.client = client

	tr.wg.Add(1)
	go func() {
		defer tr.wg.Done()
		// track only returns when stopped, unsubscribed or on failure; a
		// failure is reported on stdout since there is no caller to return
		// it to.
		if trackErr := tr.track(); trackErr != nil {
			fmt.Printf("track failed with error: %v\n", trackErr)
		}
	}()
	return nil
}
// track is the body of the tracking goroutine. For every header received on
// the heads channel it fetches the full block by hash, records the block in
// processedBlocks and each of its transactions in processedTxs, and then
// sends on the newBlock feed so that pending await calls re-check their
// conditions.
//
// It returns nil when stopCh is closed, the value received from the
// subscription's error channel when that fires, or the error from fetching a
// block.
func (tr *Tracker) track() error {
	for {
		select {
		case <-tr.stopCh:
			return nil

		case subErr := <-tr.sub.Err():
			// Will be nil if closed by calling Unsubscribe()
			return subErr

		case head := <-tr.heads:
			block, err := tr.client.BlockByHash(context.Background(), head.Hash())
			if err != nil {
				return err
			}

			tr.processedMu.Lock()
			tr.processedBlocks[block.NumberU64()] = block
			// Ranging over an empty transaction list is a no-op, so no
			// explicit length check is needed.
			for _, tx := range block.Transactions() {
				tr.processedTxs[tx.Hash()] = block
			}
			tr.processedMu.Unlock()

			// Wake up everyone waiting in await.
			tr.newBlock.Send(struct{}{})
		}
	}
}
// AwaitTransactions blocks until every transaction listed in hashes has been
// recorded as processed. It returns ctx.Err() if ctx expires first, or
// errStopped if the tracker's stop channel is closed first.
func (tr *Tracker) AwaitTransactions(ctx context.Context, hashes []common.Hash) error {
	// pending holds the hashes that have not been observed yet; duplicates in
	// hashes collapse into a single entry.
	pending := make(map[common.Hash]struct{}, len(hashes))
	for _, h := range hashes {
		pending[h] = struct{}{}
	}

	// allProcessed is invoked by await with processedMu held. Hashes that have
	// been seen are dropped from pending so they are not re-checked on later
	// invocations.
	allProcessed := func() bool {
		for h := range pending {
			if _, seen := tr.processedTxs[h]; seen {
				delete(pending, h)
			}
		}
		return len(pending) == 0
	}
	return tr.await(ctx, allProcessed)
}
// AwaitBlock blocks until a block with the given number has been recorded as
// processed. It returns ctx.Err() if ctx expires first, or errStopped if the
// tracker's stop channel is closed first.
func (tr *Tracker) AwaitBlock(ctx context.Context, num uint64) error {
	// The closure is invoked by await with processedMu held.
	return tr.await(ctx, func() bool {
		block := tr.processedBlocks[num]
		return block != nil
	})
}
// await blocks until condition reports true. The condition is evaluated once
// immediately and again each time the tracker announces a new block on the
// newBlock feed; every evaluation happens with processedMu held.
//
// It returns nil once the condition holds, ctx.Err() if ctx is done first, or
// errStopped if the tracker's stop channel is closed first.
func (tr *Tracker) await(ctx context.Context, condition func() bool) error {
	// Subscribe before the first check so that a block recorded between the
	// check and the select below still produces a wake-up.
	wake := make(chan struct{}, 10)
	feedSub := tr.newBlock.Subscribe(wake)
	defer feedSub.Unsubscribe()

	for {
		tr.processedMu.Lock()
		satisfied := condition()
		tr.processedMu.Unlock()
		if satisfied {
			return nil
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-tr.stopCh:
			return errStopped
		case <-wake:
			// A new block was recorded; loop around and re-check.
		}
	}
}
// StopTracking cancels the head subscription, closes the stop channel and
// waits for the tracking goroutine to exit. It returns an error, including the
// caller's stack trace, when the tracker is not currently started.
func (tr *Tracker) StopTracking() error {
	if tr.sub == nil {
		stack := string(debug.Stack())
		return fmt.Errorf("attempted to stop already stopped tracker - stack: \n%s", stack)
	}

	tr.sub.Unsubscribe()
	close(tr.stopCh)
	tr.wg.Wait()

	// Clearing sub marks the tracker as stopped. It must happen only after the
	// wait above, because the track goroutine reads tr.sub until it exits.
	tr.sub = nil
	tr.wg = sync.WaitGroup{}
	return nil
}