Skip to content

Commit

Permalink
routing+server: use cached graph interface
Browse files Browse the repository at this point in the history
  • Loading branch information
guggero committed Aug 26, 2021
1 parent a47f02b commit 64a4966
Show file tree
Hide file tree
Showing 12 changed files with 63 additions and 93 deletions.
26 changes: 13 additions & 13 deletions routing/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (
"github.com/lightningnetwork/lnd/routing/route"
)

// routingGraph is an abstract interface that provides information about nodes
// Graph is an abstract interface that provides information about nodes
// and edges to pathfinding.
type routingGraph interface {
type Graph interface {
// forEachNodeChannel calls the callback for every channel of the given node.
forEachNodeChannel(nodePub route.Vertex,
cb func(channel *channeldb.DirectedChannel) error) error
Expand All @@ -20,48 +20,48 @@ type routingGraph interface {
fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error)
}

// dbRoutingTx is a routingGraph implementation that retrieves from the
// CachedGraph is a Graph implementation that retrieves from the
// database.
type dbRoutingTx struct {
type CachedGraph struct {
graph *channeldb.ChannelGraph
source route.Vertex
}

// newDbRoutingTx instantiates a new db-connected routing graph. It implictly
// NewCachedGraph instantiates a new db-connected routing graph. It implictly
// instantiates a new read transaction.
func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) {
func NewCachedGraph(graph *channeldb.ChannelGraph) (*CachedGraph, error) {
sourceNode, err := graph.SourceNode()
if err != nil {
return nil, err
}

return &dbRoutingTx{
return &CachedGraph{
graph: graph,
source: sourceNode.PubKeyBytes,
}, nil
}

// forEachNodeChannel calls the callback for every channel of the given node.
//
// NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex,
// NOTE: Part of the Graph interface.
func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex,
cb func(channel *channeldb.DirectedChannel) error) error {

return g.graph.ForEachNodeChannel(nodePub, cb)
}

// sourceNode returns the source node of the graph.
//
// NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) sourceNode() route.Vertex {
// NOTE: Part of the Graph interface.
func (g *CachedGraph) sourceNode() route.Vertex {
return g.source
}

// fetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
// NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) (
// NOTE: Part of the Graph interface.
func (g *CachedGraph) fetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {

return g.graph.FetchNodeFeatures(nodePub)
Expand Down
6 changes: 1 addition & 5 deletions routing/integrated_routing_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
}

session, err := newPaymentSession(
&payment, getBandwidthHints,
func() (routingGraph, func(), error) {
return c.graph, func() {}, nil
},
mc, c.pathFindingCfg,
&payment, getBandwidthHints, c.graph, mc, c.pathFindingCfg,
)
if err != nil {
c.t.Fatal(err)
Expand Down
20 changes: 10 additions & 10 deletions routing/mock_graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte,

// forEachNodeChannel calls the callback for every channel of the given node.
//
// NOTE: Part of the routingGraph interface.
// NOTE: Part of the Graph interface.
func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex,
cb func(channel *channeldb.DirectedChannel) error) error {

Expand All @@ -182,18 +182,18 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex,
// Call the per channel callback.
err := cb(
&channeldb.DirectedChannel{
ChannelID: channel.id,
IsNode1: nodePub == node1,
OtherNode: peer,
Capacity: channel.capacity,
ChannelID: channel.id,
IsNode1: nodePub == node1,
OtherNode: peer,
Capacity: channel.capacity,
OutPolicySet: true,
InPolicy: &channeldb.CachedEdgePolicy{
ChannelID: channel.id,
ToNodePubKey: func() route.Vertex {
return nodePub
},
ToNodeFeatures: lnwire.EmptyFeatureVector(),
FeeBaseMSat: peerNode.baseFee,
FeeBaseMSat: peerNode.baseFee,
},
},
)
Expand All @@ -206,14 +206,14 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex,

// sourceNode returns the source node of the graph.
//
// NOTE: Part of the routingGraph interface.
// NOTE: Part of the Graph interface.
func (m *mockGraph) sourceNode() route.Vertex {
return m.source.pubkey
}

// fetchNodeFeatures returns the features of the given node.
//
// NOTE: Part of the routingGraph interface.
// NOTE: Part of the Graph interface.
func (m *mockGraph) fetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {

Expand Down Expand Up @@ -263,5 +263,5 @@ func (m *mockGraph) sendHtlc(route *route.Route) (htlcResult, error) {
return source.fwd(nil, next)
}

// Compile-time check for the routingGraph interface.
var _ routingGraph = &mockGraph{}
// Compile-time check for the Graph interface.
var _ Graph = &mockGraph{}
4 changes: 2 additions & 2 deletions routing/pathfind.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func edgeWeight(lockedAmt lnwire.MilliSatoshi, fee lnwire.MilliSatoshi,
// graphParams wraps the set of graph parameters passed to findPath.
type graphParams struct {
// graph is the ChannelGraph to be used during path finding.
graph routingGraph
graph Graph

// additionalEdges is an optional set of edges that should be
// considered during path finding, that is not already found in the
Expand Down Expand Up @@ -356,7 +356,7 @@ type PathFindingConfig struct {
// available balance.
func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
bandwidthHints map[uint64]lnwire.MilliSatoshi,
g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) {
g Graph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) {

var max, total lnwire.MilliSatoshi
cb := func(channel *channeldb.DirectedChannel) error {
Expand Down
2 changes: 1 addition & 1 deletion routing/pathfind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3024,7 +3024,7 @@ func dbFindPath(graph *channeldb.ChannelGraph,
source, target route.Vertex, amt lnwire.MilliSatoshi,
finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) {

routingTx, err := newDbRoutingTx(graph)
routingTx, err := NewCachedGraph(graph)
if err != nil {
return nil, err
}
Expand Down
19 changes: 5 additions & 14 deletions routing/payment_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ type paymentSession struct {

pathFinder pathFinder

getRoutingGraph func() (routingGraph, func(), error)
routingGraph Graph

// pathFindingConfig defines global parameters that control the
// trade-off in path finding between fees and probabiity.
Expand All @@ -193,7 +193,7 @@ type paymentSession struct {
// newPaymentSession instantiates a new payment session.
func newPaymentSession(p *LightningPayment,
getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error),
getRoutingGraph func() (routingGraph, func(), error),
routingGraph Graph,
missionControl MissionController, pathFindingConfig PathFindingConfig) (
*paymentSession, error) {

Expand All @@ -209,7 +209,7 @@ func newPaymentSession(p *LightningPayment,
getBandwidthHints: getBandwidthHints,
payment: p,
pathFinder: findPath,
getRoutingGraph: getRoutingGraph,
routingGraph: routingGraph,
pathFindingConfig: pathFindingConfig,
missionControl: missionControl,
minShardAmt: DefaultShardMinAmt,
Expand Down Expand Up @@ -287,29 +287,20 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,

p.log.Debugf("pathfinding for amt=%v", maxAmt)

// Get a routing graph.
routingGraph, cleanup, err := p.getRoutingGraph()
if err != nil {
return nil, err
}

sourceVertex := routingGraph.sourceNode()
sourceVertex := p.routingGraph.sourceNode()

// Find a route for the current amount.
path, err := p.pathFinder(
&graphParams{
additionalEdges: p.additionalEdges,
bandwidthHints: bandwidthHints,
graph: routingGraph,
graph: p.routingGraph,
},
restrictions, &p.pathFindingConfig,
sourceVertex, p.payment.Target,
maxAmt, finalHtlcExpiry,
)

// Close routing graph.
cleanup()

switch {
case err == errNoPathFound:
// Don't split if this is a legacy payment without mpp
Expand Down
21 changes: 3 additions & 18 deletions routing/payment_session_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ var _ PaymentSessionSource = (*SessionSource)(nil)
type SessionSource struct {
// Graph is the channel graph that will be used to gather metrics from
// and also to carry out path finding queries.
Graph *channeldb.ChannelGraph
Graph Graph

// QueryBandwidth is a method that allows querying the lower link layer
// to determine the up to date available bandwidth at a prospective link
Expand All @@ -40,38 +40,23 @@ type SessionSource struct {
PathFindingConfig PathFindingConfig
}

// getRoutingGraph returns a routing graph and a clean-up function for
// pathfinding.
func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
routingTx, err := newDbRoutingTx(m.Graph)
if err != nil {
return nil, nil, err
}
return routingTx, func() {}, nil
}

// NewPaymentSession creates a new payment session backed by the latest prune
// view from Mission Control. An optional set of routing hints can be provided
// in order to populate additional edges to explore when finding a path to the
// payment's destination.
func (m *SessionSource) NewPaymentSession(p *LightningPayment) (
PaymentSession, error) {

sourceNode, err := m.Graph.SourceNode()
if err != nil {
return nil, err
}

getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi,
error) {

return generateBandwidthHints(
sourceNode.PubKeyBytes, m.Graph, m.QueryBandwidth,
m.Graph.sourceNode(), m.Graph, m.QueryBandwidth,
)
}

session, err := newPaymentSession(
p, getBandwidthHints, m.getRoutingGraph,
p, getBandwidthHints, m.Graph,
m.MissionControl, m.PathFindingConfig,
)
if err != nil {
Expand Down
10 changes: 3 additions & 7 deletions routing/payment_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,7 @@ func TestUpdateAdditionalEdge(t *testing.T) {

return nil, nil
},
func() (routingGraph, func(), error) {
return &sessionGraph{}, func() {}, nil
},
&sessionGraph{},
&MissionControl{},
PathFindingConfig{},
)
Expand Down Expand Up @@ -203,9 +201,7 @@ func TestRequestRoute(t *testing.T) {

return nil, nil
},
func() (routingGraph, func(), error) {
return &sessionGraph{}, func() {}, nil
},
&sessionGraph{},
&MissionControl{},
PathFindingConfig{},
)
Expand Down Expand Up @@ -255,7 +251,7 @@ func TestRequestRoute(t *testing.T) {
}

type sessionGraph struct {
routingGraph
Graph
}

func (g *sessionGraph) sourceNode() route.Vertex {
Expand Down

0 comments on commit 64a4966

Please sign in to comment.