Skip to content

Commit

Permalink
Merge pull request ElementsProject#246 from YusukeShimizu/max-htlc-sa…
Browse files Browse the repository at this point in the history
…nity-check

Sanity check max htlc amount msat
  • Loading branch information
wtogami committed Sep 28, 2023
2 parents dd3e403 + 70e93bf commit 226f607
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 2 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ test-bitcoin-lnd: test-bins
'Test_LndLnd_Bitcoin_SwapOut|'\
'Test_LndLnd_Bitcoin_SwapIn|'\
'Test_LndCln_Bitcoin_SwapOut|'\
'Test_LndCln_Bitcoin_SwapIn)'\
'Test_LndCln_Bitcoin_SwapIn|'\
'Test_LndLnd_ExcessiveAmount)'\
./test
${INTEGRATION_TEST_ENV} go test $(INTEGRATION_TEST_OPTS) ./lnd
.PHONY: test-bitcoin-lnd
Expand Down
39 changes: 38 additions & 1 deletion lnd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,32 @@ func (l *Client) CanSpend(amtMsat uint64) error {
return nil
}

// getMaxHtlcAmtMsat returns the maximum htlc amount in msat for a channel.
// If for some reason it cannot be retrieved, return 0.
func (l *Client) getMaxHtlcAmtMsat(chanId uint64, pubkey string) (uint64, error) {
var maxHtlcAmtMsat uint64 = 0
r, err := l.lndClient.GetChanInfo(context.Background(), &lnrpc.ChanInfoRequest{
ChanId: chanId,
})
if err != nil {
// Ignore err because channel graph information is not always set.
return maxHtlcAmtMsat, nil
}
if r.Node1Pub == pubkey {
maxHtlcAmtMsat = r.GetNode1Policy().GetMaxHtlcMsat()
} else if r.Node2Pub == pubkey {
maxHtlcAmtMsat = r.GetNode2Policy().GetMaxHtlcMsat()
}
return maxHtlcAmtMsat, nil
}

func min(x, y uint64) uint64 {
if x < y {
return x
}
return y
}

// SpendableMsat returns an estimate of the total we could send through the
// channel with given scid.
func (l *Client) SpendableMsat(scid string) (uint64, error) {
Expand All @@ -96,7 +122,18 @@ func (l *Client) SpendableMsat(scid string) (uint64, error) {
if err = l.checkChannel(ch); err != nil {
return 0, err
}
return uint64(ch.LocalBalance * 1000), nil
maxHtlcAmtMsat, err := l.getMaxHtlcAmtMsat(ch.ChanId, l.pubkey)
if err != nil {
return 0, err
}
spendable := uint64(ch.LocalBalance * 1000)
// since the max htlc limit is not always set reliably,
// the check is skipped if it is not set.
if maxHtlcAmtMsat == 0 {
return spendable, nil
}
return min(maxHtlcAmtMsat, spendable), nil

}
}
return 0, fmt.Errorf("could not find a channel with scid: %s", scid)
Expand Down
9 changes: 9 additions & 0 deletions swap/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,15 @@ func (r *PayFeeInvoiceAction) Execute(services *SwapServices, swap *SwapData) Ev
return swap.HandleError(err)
}

sp, err := ll.SpendableMsat(swap.SwapOutRequest.Scid)
if err != nil {
return swap.HandleError(err)
}

if sp <= swap.SwapOutRequest.Amount*1000 {
return swap.HandleError(err)
}

swap.OpeningTxFee = msatAmt / 1000

expectedFee, err := wallet.GetFlatSwapOutFee()
Expand Down
86 changes: 86 additions & 0 deletions test/bitcoin_lnd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1099,3 +1099,89 @@ func Test_LndCln_Bitcoin_SwapOut(t *testing.T) {
csvClaimTest(t, params)
})
}

func Test_LndLnd_ExcessiveAmount(t *testing.T) {
IsIntegrationTest(t)
t.Parallel()
t.Run("exceed_maxhtlc", func(t *testing.T) {
t.Parallel()
require := require.New(t)

bitcoind, lightningds, peerswapds, scid := lndlndSetup(t, uint64(math.Pow10(9)))
defer func() {
if t.Failed() {
pprintFail(
tailableProcess{
p: bitcoind.DaemonProcess,
lines: defaultLines,
},
tailableProcess{
p: lightningds[0].DaemonProcess,
lines: defaultLines,
},
tailableProcess{
p: lightningds[1].DaemonProcess,
lines: defaultLines,
},
tailableProcess{
p: peerswapds[0].DaemonProcess,
lines: defaultLines,
},
tailableProcess{
p: peerswapds[1].DaemonProcess,
lines: defaultLines,
},
)
}
}()

var channelBalances []uint64
var walletBalances []uint64
for _, lightningd := range lightningds {
b, err := lightningd.GetBtcBalanceSat()
require.NoError(err)
walletBalances = append(walletBalances, b)

b, err = lightningd.GetChannelBalanceSat(scid)
require.NoError(err)
channelBalances = append(channelBalances, b)
}

lcid, err := lightningds[0].ChanIdFromScid(scid)
if err != nil {
t.Fatalf("lightingds[0].ChanIdFromScid() %v", err)
}

params := &testParams{
swapAmt: channelBalances[0] / 2,
scid: scid,
origTakerWallet: walletBalances[0],
origMakerWallet: walletBalances[1],
origTakerBalance: channelBalances[0],
origMakerBalance: channelBalances[1],
takerNode: lightningds[0],
makerNode: lightningds[1],
takerPeerswap: peerswapds[0].DaemonProcess,
makerPeerswap: peerswapds[1].DaemonProcess,
chainRpc: bitcoind.RpcProxy,
chaind: bitcoind,
confirms: BitcoinConfirms,
csv: BitcoinCsv,
swapType: swap.SWAPTYPE_OUT,
}
asset := "btc"

_, err = lightningds[0].SetHtlcMaximumMilliSatoshis(scid, channelBalances[0]*1000/2-1)
assert.NoError(t, err)
// Swap out should fail as the swap_amt is to high.
// Do swap.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err = peerswapds[0].PeerswapClient.SwapOut(ctx, &peerswaprpc.SwapOutRequest{
ChannelId: lcid,
SwapAmount: params.swapAmt,
Asset: asset,
})
assert.Error(t, err)
})
}
36 changes: 36 additions & 0 deletions testframework/clightning.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/elementsproject/glightning/glightning"
"github.com/elementsproject/peerswap/lightning"
)

type CLightningNode struct {
Expand Down Expand Up @@ -519,3 +520,38 @@ func (n *CLightningNode) GetFeeInvoiceAmtSat() (sat uint64, err error) {
}
return feeInvoiceAmt, nil
}

type SetChannel struct {
Id string `json:"id"`
HtlcMaximumMilliSatoshis string `json:"htlcmax,omitempty"`
}

type ChannelInfo struct {
PeerID string `json:"peer_id"`
ChannelID string `json:"channel_id"`
ShortChannelID string `json:"short_channel_id"`
FeeBaseMsat glightning.Amount `json:"fee_base_msat"`
FeeProportionalMillionths glightning.Amount `json:"fee_proportional_millionths"`
MinimumHtlcOutMsat glightning.Amount `json:"minimum_htlc_out_msat"`
MaximumHtlcOutMsat glightning.Amount `json:"maximum_htlc_out_msat"`
}

type SetChannelResponse struct {
Channels []ChannelInfo `json:"channels"`
}

func (r *SetChannel) Name() string {
return "setchannel"
}

func (n *CLightningNode) SetHtlcMaximumMilliSatoshis(scid string, maxHtlcMsat uint64) (msat uint64, err error) {
var res SetChannelResponse
err = n.Rpc.Request(&SetChannel{
Id: lightning.Scid(scid).ClnStyle(),
HtlcMaximumMilliSatoshis: fmt.Sprint(maxHtlcMsat),
}, &res)
if err != nil {
return 0, err
}
return maxHtlcMsat, err
}
1 change: 1 addition & 0 deletions testframework/lightning.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ type LightningNode interface {
GetFeeInvoiceAmtSat() (sat uint64, err error)

Run(waitForReady, swaitForBitcoinSynced bool) error
SetHtlcMaximumMilliSatoshis(scid string, maxHtlcMsat uint64) (msat uint64, err error)
Stop() error
}
50 changes: 50 additions & 0 deletions testframework/lnd.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import (
"os"
"path/filepath"
"regexp"
"strconv"
"strings"

"github.com/elementsproject/peerswap/lightning"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnwire"
)
Expand Down Expand Up @@ -529,3 +532,50 @@ func (n *LndNode) GetFeeInvoiceAmtSat() (sat uint64, err error) {
}
return feeInvoiceAmt, nil
}

func (n *LndNode) SetHtlcMaximumMilliSatoshis(scid string, maxHtlcMsat uint64) (msat uint64, err error) {
s := lightning.Scid(scid)
res, err := n.Rpc.ListChannels(context.Background(), &lnrpc.ListChannelsRequest{})
if err != nil {
return 0, fmt.Errorf("ListChannels() %w", err)
}
for _, ch := range res.GetChannels() {
channelShortId := lnwire.NewShortChanIDFromInt(ch.ChanId)
if channelShortId.String() == s.LndStyle() {
r, err := n.Rpc.GetChanInfo(context.Background(), &lnrpc.ChanInfoRequest{
ChanId: ch.ChanId,
})
if err != nil {
return 0, err
}
parts := strings.Split(r.ChanPoint, ":")
if len(parts) != 2 {
return 0, fmt.Errorf("expected scid to be composed of 3 blocks")
}
txPosition, err := strconv.Atoi(parts[1])
if err != nil {
return 0, err
}
_, err = n.Rpc.UpdateChannelPolicy(context.Background(), &lnrpc.PolicyUpdateRequest{
Scope: &lnrpc.PolicyUpdateRequest_ChanPoint{ChanPoint: &lnrpc.ChannelPoint{
FundingTxid: &lnrpc.ChannelPoint_FundingTxidStr{
FundingTxidStr: parts[0],
},
OutputIndex: uint32(txPosition),
}},
BaseFeeMsat: 1000,
FeeRate: 1,
FeeRatePpm: 0,
TimeLockDelta: 40,
MaxHtlcMsat: maxHtlcMsat,
MinHtlcMsat: msat,
MinHtlcMsatSpecified: false,
})
if err != nil {
return 0, err
}
return maxHtlcMsat, err
}
}
return 0, fmt.Errorf("could not find a channel with scid: %s", scid)
}

0 comments on commit 226f607

Please sign in to comment.