diff --git a/openapi/SwarmCommon.yaml b/openapi/SwarmCommon.yaml index f0b4b8ed719..8bb49e61777 100644 --- a/openapi/SwarmCommon.yaml +++ b/openapi/SwarmCommon.yaml @@ -586,6 +586,18 @@ components: stakedAmount: $ref: "#/components/schemas/BigInt" + StakeDepositResponse: + type: object + properties: + txHash: + $ref: "#/components/schemas/TransactionHash" + + WithdrawAllStakeResponse: + type: object + properties: + txHash: + $ref: "#/components/schemas/TransactionHash" + SwarmOnlyReference: oneOf: - $ref: "#/components/schemas/SwarmAddress" diff --git a/openapi/SwarmDebug.yaml b/openapi/SwarmDebug.yaml index 907659b3f99..38ff38587f4 100644 --- a/openapi/SwarmDebug.yaml +++ b/openapi/SwarmDebug.yaml @@ -1056,6 +1056,23 @@ paths: $ref: "SwarmCommon.yaml#/components/responses/500" default: description: Default response + delete: + summary: Withdraw all staked amount. + description: Be aware, this endpoint creates an on-chain transactions and transfers BZZ from the node's Ethereum account and hence directly manipulates the wallet balance. + tags: + - Staking + parameters: + - $ref: "SwarmCommon.yaml#/components/parameters/GasPriceParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/GasLimitParameter" + responses: + "200": + $ref: "SwarmCommon.yaml#/components/schemas/WithdrawAllStakeResponse" + "400": + $ref: "SwarmCommon.yaml#/components/responses/400" + "500": + $ref: "SwarmCommon.yaml#/components/responses/500" + default: + description: Default response "/loggers": get: diff --git a/pkg/api/export_test.go b/pkg/api/export_test.go index f5d6fe16ec0..75fba7c4e24 100644 --- a/pkg/api/export_test.go +++ b/pkg/api/export_test.go @@ -100,6 +100,7 @@ type ( BucketData = bucketData WalletResponse = walletResponse GetStakeResponse = getStakeResponse + WithdrawAllStakeResponse = withdrawAllStakeResponse ) var ( diff --git a/pkg/api/router.go b/pkg/api/router.go index 5bf51509d5a..e7574f54fb3 100644 --- a/pkg/api/router.go +++ b/pkg/api/router.go @@ -562,8 +562,10 @@ func (s *Service) mountBusinessDebug(restricted bool) { handle("/stake", web.ChainHandlers( s.stakingAccessHandler, + s.gasConfigMiddleware("get or withdraw stake"), web.FinalHandler(jsonhttp.MethodHandler{ - "GET": http.HandlerFunc(s.getStakedAmountHandler), + "GET": http.HandlerFunc(s.getStakedAmountHandler), + "DELETE": http.HandlerFunc(s.withdrawAllStakeHandler), })), ) } diff --git a/pkg/api/staking.go b/pkg/api/staking.go index 1e50f0baff1..77e8a7e836c 100644 --- a/pkg/api/staking.go +++ b/pkg/api/staking.go @@ -6,10 +6,11 @@ package api import ( "errors" - "github.com/ethersphere/bee/pkg/bigint" "math/big" "net/http" + "github.com/ethersphere/bee/pkg/bigint" + "github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/storageincentives/staking" "github.com/gorilla/mux" @@ -36,6 +37,10 @@ type stakeDepositResponse struct { TxHash string `json:"txhash"` } +type withdrawAllStakeResponse struct { + TxHash string `json:"txhash"` +} + func (s *Service) stakingDepositHandler(w http.ResponseWriter, r *http.Request) { logger := s.logger.WithName("post_stake_deposit").Build() @@ -90,3 +95,23 @@ func (s *Service) getStakedAmountHandler(w http.ResponseWriter, r *http.Request) jsonhttp.OK(w, getStakeResponse{StakedAmount: bigint.Wrap(stakedAmount)}) } + +func (s *Service) withdrawAllStakeHandler(w http.ResponseWriter, r *http.Request) { + logger := s.logger.WithName("delete_withdraw_all_stake").Build() + + txHash, err := s.stakingContract.WithdrawAllStake(r.Context()) + if err != nil { + if errors.Is(err, staking.ErrInsufficientStake) { + logger.Debug("insufficient stake", "overlayAddr", s.overlay, "error", err) + logger.Error(nil, "insufficient stake") + jsonhttp.BadRequest(w, "insufficient stake to withdraw") + return + } + logger.Debug("withdraw stake failed", "error", err) + logger.Error(nil, "withdraw stake failed") + jsonhttp.InternalServerError(w, "cannot withdraw stake") + return + } + + jsonhttp.OK(w, withdrawAllStakeResponse{TxHash: txHash.String()}) +} diff --git a/pkg/api/staking_test.go b/pkg/api/staking_test.go index 51477c94b81..2ca045cbcaa 100644 --- a/pkg/api/staking_test.go +++ b/pkg/api/staking_test.go @@ -7,12 +7,13 @@ package api_test import ( "context" "fmt" - "github.com/ethereum/go-ethereum/common" - "github.com/ethersphere/bee/pkg/bigint" "math/big" "net/http" "testing" + "github.com/ethereum/go-ethereum/common" + "github.com/ethersphere/bee/pkg/bigint" + "github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" @@ -170,3 +171,70 @@ func Test_stakingDepositHandler_invalidInputs(t *testing.T) { }) } } + +func TestWithdrawAllStake(t *testing.T) { + t.Parallel() + + txHash := common.HexToHash("0x1234") + + t.Run("ok", func(t *testing.T) { + t.Parallel() + + contract := stakingContractMock.New( + stakingContractMock.WithWithdrawAllStake(func(ctx context.Context) (common.Hash, error) { + return txHash, nil + }), + ) + ts, _, _, _ := newTestServer(t, testServerOptions{DebugAPI: true, StakingContract: contract}) + jsonhttptest.Request(t, ts, http.MethodDelete, "/stake", http.StatusOK, jsonhttptest.WithExpectedJSONResponse( + &api.WithdrawAllStakeResponse{TxHash: txHash.String()})) + }) + + t.Run("with invalid stake amount", func(t *testing.T) { + t.Parallel() + + contract := stakingContractMock.New( + stakingContractMock.WithWithdrawAllStake(func(ctx context.Context) (common.Hash, error) { + return common.Hash{}, staking.ErrInsufficientStake + }), + ) + ts, _, _, _ := newTestServer(t, testServerOptions{DebugAPI: true, StakingContract: contract}) + jsonhttptest.Request(t, ts, http.MethodDelete, "/stake", http.StatusBadRequest, + jsonhttptest.WithExpectedJSONResponse(&jsonhttp.StatusResponse{Code: http.StatusBadRequest, Message: "insufficient stake to withdraw"})) + }) + + t.Run("internal error", func(t *testing.T) { + t.Parallel() + + contract := stakingContractMock.New( + stakingContractMock.WithWithdrawAllStake(func(ctx context.Context) (common.Hash, error) { + return common.Hash{}, fmt.Errorf("some error") + }), + ) + ts, _, _, _ := newTestServer(t, testServerOptions{DebugAPI: true, StakingContract: contract}) + jsonhttptest.Request(t, ts, http.MethodDelete, "/stake", http.StatusInternalServerError) + jsonhttptest.WithExpectedJSONResponse(&jsonhttp.StatusResponse{Code: http.StatusInternalServerError, Message: "cannot withdraw stake"}) + }) + + t.Run("gas limit header", func(t *testing.T) { + t.Parallel() + + contract := stakingContractMock.New( + stakingContractMock.WithWithdrawAllStake(func(ctx context.Context) (common.Hash, error) { + gasLimit := sctx.GetGasLimit(ctx) + if gasLimit != 2000000 { + t.Fatalf("want 2000000, got %d", gasLimit) + } + return txHash, nil + }), + ) + ts, _, _, _ := newTestServer(t, testServerOptions{ + DebugAPI: true, + StakingContract: contract, + }) + + jsonhttptest.Request(t, ts, http.MethodDelete, "/stake", http.StatusOK, + jsonhttptest.WithRequestHeader("Gas-Limit", "2000000"), + ) + }) +} diff --git a/pkg/storageincentives/staking/contract.go b/pkg/storageincentives/staking/contract.go index b12a3bf89bf..fcf99935277 100644 --- a/pkg/storageincentives/staking/contract.go +++ b/pkg/storageincentives/staking/contract.go @@ -27,15 +27,19 @@ var ( ErrInsufficientStakeAmount = errors.New("insufficient stake amount") ErrInsufficientFunds = errors.New("insufficient token balance") + ErrInsufficientStake = errors.New("insufficient stake") ErrNotImplemented = errors.New("not implemented") + ErrNotPaused = errors.New("contract is not paused") - approveDescription = "Approve tokens for stake deposit operations" - depositStakeDescription = "Deposit Stake" + approveDescription = "Approve tokens for stake deposit operations" + depositStakeDescription = "Deposit Stake" + withdrawStakeDescription = "Withdraw stake" ) type Contract interface { DepositStake(ctx context.Context, stakedAmount *big.Int) (common.Hash, error) GetStake(ctx context.Context) (*big.Int, error) + WithdrawAllStake(ctx context.Context) (common.Hash, error) } type contract struct { @@ -158,42 +162,46 @@ func (c *contract) getStake(ctx context.Context, overlay swarm.Address) (*big.In if err != nil { return nil, err } + + if len(results) == 0 { + return nil, errors.New("unexpected empty results") + } + return abi.ConvertType(results[0], new(big.Int)).(*big.Int), nil } -func (c *contract) DepositStake(ctx context.Context, stakedAmount *big.Int) (txHash common.Hash, err error) { +func (c *contract) DepositStake(ctx context.Context, stakedAmount *big.Int) (common.Hash, error) { prevStakedAmount, err := c.GetStake(ctx) if err != nil { - return + return common.Hash{}, err } if len(prevStakedAmount.Bits()) == 0 { if stakedAmount.Cmp(MinimumStakeAmount) == -1 { - err = ErrInsufficientStakeAmount - return + return common.Hash{}, ErrInsufficientStakeAmount } } balance, err := c.getBalance(ctx) if err != nil { - return + return common.Hash{}, err } if balance.Cmp(stakedAmount) < 0 { - err = ErrInsufficientFunds - return + return common.Hash{}, ErrInsufficientFunds } _, err = c.sendApproveTransaction(ctx, stakedAmount) if err != nil { - return + return common.Hash{}, err } receipt, err := c.sendDepositStakeTransaction(ctx, c.owner, stakedAmount, c.overlayNonce) - if receipt != nil { - txHash = receipt.TxHash + if err != nil { + return common.Hash{}, err } - return + + return receipt.TxHash, nil } func (c *contract) GetStake(ctx context.Context) (*big.Int, error) { @@ -222,5 +230,86 @@ func (c *contract) getBalance(ctx context.Context) (*big.Int, error) { if err != nil { return nil, err } + + if len(results) == 0 { + return nil, errors.New("unexpected empty results") + } + return abi.ConvertType(results[0], new(big.Int)).(*big.Int), nil } + +func (c *contract) WithdrawAllStake(ctx context.Context) (txHash common.Hash, err error) { + isPaused, err := c.paused(ctx) + if err != nil { + return + } + if !isPaused { + return common.Hash{}, ErrNotPaused + } + + stakedAmount, err := c.getStake(ctx, c.overlay) + if err != nil { + return + } + + if stakedAmount.Cmp(big.NewInt(0)) <= 0 { + return common.Hash{}, ErrInsufficientStake + } + + _, err = c.sendApproveTransaction(ctx, stakedAmount) + if err != nil { + return common.Hash{}, err + } + + receipt, err := c.withdrawFromStake(ctx, stakedAmount) + if err != nil { + return common.Hash{}, err + } + if receipt != nil { + txHash = receipt.TxHash + } + return txHash, nil +} + +func (c *contract) withdrawFromStake(ctx context.Context, stakedAmount *big.Int) (*types.Receipt, error) { + var overlayAddr [32]byte + copy(overlayAddr[:], c.overlay.Bytes()) + + callData, err := c.stakingContractABI.Pack("withdrawFromStake", overlayAddr, stakedAmount) + if err != nil { + return nil, err + } + + receipt, err := c.sendTransaction(ctx, callData, withdrawStakeDescription) + if err != nil { + return nil, fmt.Errorf("withdraw stake: stakedAmount %d: %w", stakedAmount, err) + } + + return receipt, nil +} + +func (c *contract) paused(ctx context.Context) (bool, error) { + callData, err := c.stakingContractABI.Pack("paused") + if err != nil { + return false, err + } + + result, err := c.transactionService.Call(ctx, &transaction.TxRequest{ + To: &c.stakingContractAddress, + Data: callData, + }) + if err != nil { + return false, err + } + + results, err := c.stakingContractABI.Unpack("paused", result) + if err != nil { + return false, err + } + + if len(results) == 0 { + return false, errors.New("unexpected empty results") + } + + return results[0].(bool), nil +} diff --git a/pkg/storageincentives/staking/contract_test.go b/pkg/storageincentives/staking/contract_test.go index b924de003d5..96d5986040e 100644 --- a/pkg/storageincentives/staking/contract_test.go +++ b/pkg/storageincentives/staking/contract_test.go @@ -635,3 +635,404 @@ func TestGetStake(t *testing.T) { } }) } + +func TestWithdrawStake(t *testing.T) { + t.Parallel() + + ctx := context.Background() + owner := common.HexToAddress("abcd") + stakingContractAddress := common.HexToAddress("ffff") + bzzTokenAddress := common.HexToAddress("eeee") + nonce := common.BytesToHash(make([]byte, 32)) + stakedAmount := big.NewInt(100000000000000000) + addr := swarm.MustParseHexAddress("f30c0aa7e9e2a0ef4c9b1b750ebfeaeb7c7c24da700bb089da19a46e3677824b") + txHashApprove := common.HexToHash("abb0") + + t.Run("ok", func(t *testing.T) { + t.Parallel() + txHashWithdrawn := common.HexToHash("c3a1") + expected := big.NewInt(1) + + expectedCallDataForPaused, err := stakingContractABI.Pack("paused") + if err != nil { + t.Fatal(err) + } + + expectedCallDataForWithdraw, err := stakingContractABI.Pack("withdrawFromStake", common.BytesToHash(addr.Bytes()), stakedAmount) + if err != nil { + t.Fatal(err) + } + + expectedCallDataForGetStake, err := stakingContractABI.Pack("stakeOfOverlay", common.BytesToHash(addr.Bytes())) + if err != nil { + t.Fatal(err) + } + + contract := staking.New( + addr, + owner, + stakingContractAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithSendFunc(func(ctx context.Context, request *transaction.TxRequest, boost int) (txHash common.Hash, err error) { + if *request.To == bzzTokenAddress { + return txHashApprove, nil + } + if *request.To == stakingContractAddress { + if !bytes.Equal(expectedCallDataForWithdraw[:], request.Data[:]) { + return common.Hash{}, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallDataForWithdraw, request.Data) + } + return txHashWithdrawn, nil + } + return common.Hash{}, errors.New("sent to wrong contract") + }), + transactionMock.WithWaitForReceiptFunc(func(ctx context.Context, txHash common.Hash) (receipt *types.Receipt, err error) { + if txHash == txHashApprove { + return &types.Receipt{ + Status: 1, + }, nil + } + if txHash == txHashWithdrawn { + return &types.Receipt{ + Status: 1, + }, nil + } + return nil, errors.New("unknown tx hash") + }), + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingContractAddress { + if bytes.Equal(expectedCallDataForPaused[:], request.Data[:]) { + return expected.FillBytes(make([]byte, 32)), nil + } + if bytes.Equal(expectedCallDataForGetStake[:64], request.Data[:64]) { + return stakedAmount.FillBytes(make([]byte, 32)), nil + } + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + ) + + _, err = contract.WithdrawAllStake(ctx) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("is paused", func(t *testing.T) { + t.Parallel() + expected := big.NewInt(0) + + expectedCallDataForPaused, err := stakingContractABI.Pack("paused") + if err != nil { + t.Fatal(err) + } + + contract := staking.New( + addr, + owner, + stakingContractAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingContractAddress { + if bytes.Equal(expectedCallDataForPaused[:], request.Data[:]) { + return expected.FillBytes(make([]byte, 32)), nil + } + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + ) + + _, err = contract.WithdrawAllStake(ctx) + if !errors.Is(err, staking.ErrNotPaused) { + t.Fatal(err) + } + }) + + t.Run("has no stake", func(t *testing.T) { + t.Parallel() + expected := big.NewInt(1) + + expectedCallDataForPaused, err := stakingContractABI.Pack("paused") + if err != nil { + t.Fatal(err) + } + + invalidStakedAmount := big.NewInt(0) + + expectedCallDataForGetStake, err := stakingContractABI.Pack("stakeOfOverlay", common.BytesToHash(addr.Bytes())) + if err != nil { + t.Fatal(err) + } + + contract := staking.New( + addr, + owner, + stakingContractAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingContractAddress { + if bytes.Equal(expectedCallDataForPaused[:], request.Data[:]) { + return expected.FillBytes(make([]byte, 32)), nil + } + if bytes.Equal(expectedCallDataForGetStake[:64], request.Data[:64]) { + return invalidStakedAmount.FillBytes(make([]byte, 32)), nil + } + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + ) + + _, err = contract.WithdrawAllStake(ctx) + if !errors.Is(err, staking.ErrInsufficientStake) { + t.Fatal(err) + } + }) + + t.Run("invalid call data", func(t *testing.T) { + t.Parallel() + _, err := stakingContractABI.Pack("paused", addr) + if err == nil { + t.Fatal(err) + } + _, err = stakingContractABI.Pack("withdrawFromStake", stakedAmount) + if err == nil { + t.Fatal(err) + } + + _, err = stakingContractABI.Pack("stakeOfOverlay", stakedAmount) + if err == nil { + t.Fatal(err) + } + }) + + t.Run("send tx failed", func(t *testing.T) { + t.Parallel() + txHashWithdrawn := common.HexToHash("c3a1") + expected := big.NewInt(1) + + expectedCallDataForPaused, err := stakingContractABI.Pack("paused") + if err != nil { + t.Fatal(err) + } + + expectedCallDataForWithdraw, err := stakingContractABI.Pack("withdrawFromStake", common.BytesToHash(addr.Bytes()), stakedAmount) + if err != nil { + t.Fatal(err) + } + + expectedCallDataForGetStake, err := stakingContractABI.Pack("stakeOfOverlay", common.BytesToHash(addr.Bytes())) + if err != nil { + t.Fatal(err) + } + + contract := staking.New( + addr, + owner, + stakingContractAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithSendFunc(func(ctx context.Context, request *transaction.TxRequest, boost int) (txHash common.Hash, err error) { + if *request.To == bzzTokenAddress { + return txHashApprove, nil + } + if *request.To == stakingContractAddress { + if !bytes.Equal(expectedCallDataForWithdraw[:], request.Data[:]) { + return common.Hash{}, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallDataForWithdraw, request.Data) + } + return common.Hash{}, errors.New("send tx failed") + } + return common.Hash{}, errors.New("sent to wrong contract") + }), + transactionMock.WithWaitForReceiptFunc(func(ctx context.Context, txHash common.Hash) (receipt *types.Receipt, err error) { + if txHash == txHashApprove { + return &types.Receipt{ + Status: 1, + }, nil + } + if txHash == txHashWithdrawn { + return &types.Receipt{ + Status: 1, + }, nil + } + return nil, errors.New("unknown tx hash") + }), + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingContractAddress { + if bytes.Equal(expectedCallDataForPaused[:], request.Data[:]) { + return expected.FillBytes(make([]byte, 32)), nil + } + if bytes.Equal(expectedCallDataForGetStake[:64], request.Data[:64]) { + return stakedAmount.FillBytes(make([]byte, 32)), nil + } + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + ) + + _, err = contract.WithdrawAllStake(ctx) + if err == nil { + t.Fatal(err) + } + }) + + t.Run("tx reverted", func(t *testing.T) { + t.Parallel() + txHashWithdrawn := common.HexToHash("c3a1") + expected := big.NewInt(1) + + expectedCallDataForPaused, err := stakingContractABI.Pack("paused") + if err != nil { + t.Fatal(err) + } + + expectedCallDataForGetStake, err := stakingContractABI.Pack("stakeOfOverlay", common.BytesToHash(addr.Bytes())) + if err != nil { + t.Fatal(err) + } + + expectedCallDataForWithdraw, err := stakingContractABI.Pack("withdrawFromStake", common.BytesToHash(addr.Bytes()), stakedAmount) + if err != nil { + t.Fatal(err) + } + + contract := staking.New( + addr, + owner, + stakingContractAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithSendFunc(func(ctx context.Context, request *transaction.TxRequest, boost int) (txHash common.Hash, err error) { + if *request.To == bzzTokenAddress { + return txHashApprove, nil + } + if *request.To == stakingContractAddress { + if !bytes.Equal(expectedCallDataForWithdraw[:], request.Data[:]) { + return common.Hash{}, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallDataForWithdraw, request.Data) + } + return txHashWithdrawn, nil + } + return common.Hash{}, errors.New("sent to wrong contract") + }), + transactionMock.WithWaitForReceiptFunc(func(ctx context.Context, txHash common.Hash) (receipt *types.Receipt, err error) { + if txHash == txHashApprove { + return &types.Receipt{ + Status: 1, + }, nil + } + if txHash == txHashWithdrawn { + return &types.Receipt{ + Status: 0, + }, nil + } + return nil, errors.New("unknown tx hash") + }), + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingContractAddress { + if bytes.Equal(expectedCallDataForPaused[:], request.Data[:]) { + return expected.FillBytes(make([]byte, 32)), nil + } + if bytes.Equal(expectedCallDataForGetStake[:64], request.Data[:64]) { + return stakedAmount.FillBytes(make([]byte, 32)), nil + } + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + ) + + _, err = contract.WithdrawAllStake(ctx) + if err == nil { + t.Fatal(err) + } + }) + + t.Run("is paused with err", func(t *testing.T) { + t.Parallel() + expectedCallDataForPaused, err := stakingContractABI.Pack("paused") + if err != nil { + t.Fatal(err) + } + + contract := staking.New( + addr, + owner, + stakingContractAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingContractAddress { + if bytes.Equal(expectedCallDataForPaused[:], request.Data[:]) { + return nil, fmt.Errorf("some error") + } + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + ) + + _, err = contract.WithdrawAllStake(ctx) + if err == nil { + t.Fatal(err) + } + }) + + t.Run("get stake with err", func(t *testing.T) { + t.Parallel() + expected := big.NewInt(1) + + expectedCallDataForPaused, err := stakingContractABI.Pack("paused") + if err != nil { + t.Fatal(err) + } + + expectedCallDataForGetStake, err := stakingContractABI.Pack("stakeOfOverlay", common.BytesToHash(addr.Bytes())) + if err != nil { + t.Fatal(err) + } + + contract := staking.New( + addr, + owner, + stakingContractAddress, + stakingContractABI, + bzzTokenAddress, + transactionMock.New( + transactionMock.WithCallFunc(func(ctx context.Context, request *transaction.TxRequest) (result []byte, err error) { + if *request.To == stakingContractAddress { + if bytes.Equal(expectedCallDataForPaused[:], request.Data[:]) { + return expected.FillBytes(make([]byte, 32)), nil + } + if bytes.Equal(expectedCallDataForGetStake[:64], request.Data[:64]) { + return nil, fmt.Errorf("some error") + } + } + return nil, errors.New("unexpected call") + }), + ), + nonce, + ) + + _, err = contract.WithdrawAllStake(ctx) + if err == nil { + t.Fatal(err) + } + }) +} diff --git a/pkg/storageincentives/staking/mock/contract.go b/pkg/storageincentives/staking/mock/contract.go index eaf8730930a..58e78b9ca65 100644 --- a/pkg/storageincentives/staking/mock/contract.go +++ b/pkg/storageincentives/staking/mock/contract.go @@ -6,15 +6,17 @@ package mock import ( "context" - "github.com/ethereum/go-ethereum/common" "math/big" + "github.com/ethereum/go-ethereum/common" + "github.com/ethersphere/bee/pkg/storageincentives/staking" ) type stakingContractMock struct { - depositStake func(ctx context.Context, stakedAmount *big.Int) (common.Hash, error) - getStake func(ctx context.Context) (*big.Int, error) + depositStake func(ctx context.Context, stakedAmount *big.Int) (common.Hash, error) + getStake func(ctx context.Context) (*big.Int, error) + withdrawAllStake func(ctx context.Context) (common.Hash, error) } func (s *stakingContractMock) DepositStake(ctx context.Context, stakedAmount *big.Int) (common.Hash, error) { @@ -25,6 +27,10 @@ func (s *stakingContractMock) GetStake(ctx context.Context) (*big.Int, error) { return s.getStake(ctx) } +func (s *stakingContractMock) WithdrawAllStake(ctx context.Context) (common.Hash, error) { + return s.withdrawAllStake(ctx) +} + // Option is a an option passed to New type Option func(mock *stakingContractMock) @@ -50,3 +56,9 @@ func WithGetStake(f func(ctx context.Context) (*big.Int, error)) Option { mock.getStake = f } } + +func WithWithdrawAllStake(f func(ctx context.Context) (common.Hash, error)) Option { + return func(mock *stakingContractMock) { + mock.withdrawAllStake = f + } +}