Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1002 Handle error when interacting with a bad contract #3194

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
6 changes: 4 additions & 2 deletions tests/core/contracts/test_contract_call_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from web3.exceptions import (
BadFunctionCallOutput,
BlockNumberOutofRange,
ContractLogicError,
FallbackNotFound,
InvalidAddress,
MismatchedABI,
Expand Down Expand Up @@ -477,9 +478,10 @@ def test_call_missing_function(mismatched_math_contract, call):

def test_call_undeployed_contract(undeployed_math_contract, call):
expected_undeployed_call_error_message = (
"Could not transact with/call contract function"
"Could not transact with/call contract function, is contract "
+ "deployed correctly and chain synced?"
)
with pytest.raises(BadFunctionCallOutput) as exception_info:
with pytest.raises(ContractLogicError) as exception_info:
call(contract=undeployed_math_contract, contract_function="return13")
assert expected_undeployed_call_error_message in str(exception_info.value)

Expand Down
4 changes: 2 additions & 2 deletions tests/core/contracts/test_contract_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
ens_addresses,
)
from web3.exceptions import (
BadFunctionCallOutput,
ContractLogicError,
NameNotFound,
)

Expand Down Expand Up @@ -68,7 +68,7 @@ def test_contract_with_name_address_changing(math_contract_factory, math_addr):

# what happens when name returns address to different contract
with contract_ens_addresses(mc, [("thedao.eth", "0x" + "11" * 20)]):
with pytest.raises(BadFunctionCallOutput):
with pytest.raises(ContractLogicError):
mc.functions.return13().call({"from": caller})

# contract works again when name resolves correctly
Expand Down
33 changes: 33 additions & 0 deletions tests/core/contracts/test_contract_logic_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from unittest.mock import (
patch,
)

from web3 import (
Web3,
)
from web3._utils.contract_sources.contract_data.contract_logic_error_tester import (
MEAS_PUB_DATA,
)
from web3.exceptions import (
ContractLogicError,
)


@pytest.fixture
def web3_instance():
return Web3(Web3.EthereumTesterProvider())


def test_get_subscriber_count_with_invalid_address(web3_instance):
invalid_contract_address = "0x0000000000000000000000000000000000000000"
contract_instance = web3_instance.eth.contract(
address=invalid_contract_address, abi=MEAS_PUB_DATA["abi"]
)

# Mock the responses for eth_call and get_code
with patch.object(web3_instance.eth, "call", return_value=b""):
with patch.object(web3_instance.eth, "get_code", return_value=""):
# Expect the ContractLogicError due to the mocked conditions
with pytest.raises(ContractLogicError):
contract_instance.functions.getSubscriberCount().call()
69 changes: 69 additions & 0 deletions web3/_utils/contract_sources/ContractLogicErrorTester.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.23;

contract MeasPub {
address public publisher;
string public description;
uint public price_per_second; // wei per second

mapping (address => uint) private balances;
mapping (address => uint) private last_publish_block;
mapping (uint => address) private subscriber_index;
uint private subscriber_count;
mapping (address => string) private subscriber_pubkey;

event LogPublished(address indexed _subscriber, bytes _pwdenc, bytes _ipfsaddr, uint _cost);
event LogDebug(string _msg);

constructor() {
publisher = msg.sender;
price_per_second = 0; // Wei
}

function setPricePerSecond(uint _price) public {
if (msg.sender == publisher) {
price_per_second = _price;
}
}

function publish(address _subscriber, bytes memory _pwdenc, bytes memory _ipfsaddr) public returns (bool covered) {
if (msg.sender != publisher) {
emit LogDebug("only publisher can publish");
return false;
}
uint cost = (block.timestamp - last_publish_block[_subscriber]) * price_per_second;
if (balances[_subscriber] < cost) {
emit LogDebug("subscriber has insufficient funds");
return false;
}
balances[_subscriber] -= cost;
payable(publisher).transfer(cost);
last_publish_block[_subscriber] = block.timestamp;
emit LogPublished(_subscriber, _pwdenc, _ipfsaddr, cost);
return true;
}

function getSubscriberCount() public view returns (uint count) {
return subscriber_count;
}

function getSubscriber(uint _index) public view returns (address _subscriber, string memory _pubkey) {
if (msg.sender != publisher) return (address(0), "");
return (subscriber_index[_index], subscriber_pubkey[subscriber_index[_index]]);
}

function subscribe(string memory _pubkey) public payable returns (bool success) {
if (last_publish_block[msg.sender] != 0) return false;
last_publish_block[msg.sender] = block.timestamp;
subscriber_index[subscriber_count] = msg.sender;
subscriber_count += 1;
subscriber_pubkey[msg.sender] = _pubkey;
balances[msg.sender] += msg.value;
emit LogDebug("new subscription successful");
return true;
}

function kill() public {
if (msg.sender == publisher) selfdestruct(payable(publisher));
}
}

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion web3/contract/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from web3.exceptions import (
BadFunctionCallOutput,
ContractLogicError,
)
from web3.types import (
ABI,
Expand Down Expand Up @@ -121,12 +122,13 @@ def call_contract_function(
"Could not transact with/call contract function, is contract "
"deployed correctly and chain synced?"
)
raise ContractLogicError(msg) from e
else:
msg = (
f"Could not decode contract function call to {function_identifier} "
f"with return data: {str(return_data)}, output_types: {output_types}"
)
raise BadFunctionCallOutput(msg) from e
raise BadFunctionCallOutput(msg) from e

_normalizers = itertools.chain(
BASE_RETURN_NORMALIZERS,
Expand Down