diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fe15e5bccc..f30f4174300 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,11 @@ * `ERC721`, `ERC1155`: simplified revert reasons. ([#3254](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3254), ([#3438](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3438))) * `ERC721`: removed redundant require statement. ([#3434](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3434)) * `PaymentSplitter`: add `releasable` getters. ([#3350](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3350)) + * `Initializable`: refactored implementation of modifiers for easier understanding. ([#3450](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3450)) + +### Breaking changes + + * `Initializable`: functions decorated with the modifier `reinitializer(1)` may no longer invoke each other. ## 4.6.0 (2022-04-26) diff --git a/contracts/mocks/InitializableMock.sol b/contracts/mocks/InitializableMock.sol index 91ceabe91c0..b24db08911c 100644 --- a/contracts/mocks/InitializableMock.sol +++ b/contracts/mocks/InitializableMock.sol @@ -100,3 +100,23 @@ contract ReinitializerMock is Initializable { counter++; } } + +contract DisableNew is Initializable { + constructor() { + _disableInitializers(); + } +} + +contract DisableOld is Initializable { + constructor() initializer {} +} + +contract DisableBad1 is DisableNew, DisableOld {} + +contract DisableBad2 is Initializable { + constructor() initializer { + _disableInitializers(); + } +} + +contract DisableOk is DisableOld, DisableNew {} diff --git a/contracts/proxy/utils/Initializable.sol b/contracts/proxy/utils/Initializable.sol index 9e8bb6e67f6..2a56119aa12 100644 --- a/contracts/proxy/utils/Initializable.sol +++ b/contracts/proxy/utils/Initializable.sol @@ -76,7 +76,12 @@ abstract contract Initializable { * `onlyInitializing` functions can be used to initialize parent contracts. Equivalent to `reinitializer(1)`. */ modifier initializer() { - bool isTopLevelCall = _setInitializedVersion(1); + bool isTopLevelCall = !_initializing; + require( + (isTopLevelCall && _initialized < 1) || (!Address.isContract(address(this)) && _initialized == 1), + "Initializable: contract is already initialized" + ); + _initialized = 1; if (isTopLevelCall) { _initializing = true; } @@ -100,15 +105,12 @@ abstract contract Initializable { * a contract, executing them in the right order is up to the developer or operator. */ modifier reinitializer(uint8 version) { - bool isTopLevelCall = _setInitializedVersion(version); - if (isTopLevelCall) { - _initializing = true; - } + require(!_initializing && _initialized < version, "Initializable: contract is already initialized"); + _initialized = version; + _initializing = true; _; - if (isTopLevelCall) { - _initializing = false; - emit Initialized(version); - } + _initializing = false; + emit Initialized(version); } /** @@ -127,27 +129,10 @@ abstract contract Initializable { * through proxies. */ function _disableInitializers() internal virtual { - _setInitializedVersion(type(uint8).max); - } - - function _setInitializedVersion(uint8 version) private returns (bool) { - // If the contract is initializing we ignore whether _initialized is set in order to support multiple - // inheritance patterns, but we only do this in the context of a constructor, and for the lowest level - // of initializers, because in other contexts the contract may have been reentered. - - bool isTopLevelCall = !_initializing; // cache sload - uint8 currentVersion = _initialized; // cache sload - - require( - (isTopLevelCall && version > currentVersion) || // not nested with increasing version or - (!Address.isContract(address(this)) && (version == 1 || version == type(uint8).max)), // contract being constructed - "Initializable: contract is already initialized" - ); - - if (isTopLevelCall) { - _initialized = version; + require(!_initializing, "Initializable: contract is initializing"); + if (_initialized < type(uint8).max) { + _initialized = type(uint8).max; + emit Initialized(type(uint8).max); } - - return isTopLevelCall; } } diff --git a/package.json b/package.json index 6074a428893..60509ecd027 100644 --- a/package.json +++ b/package.json @@ -29,7 +29,7 @@ "release": "scripts/release/release.sh", "version": "scripts/release/version.sh", "test": "hardhat test", - "test:inheritance": "scripts/checks/inheritanceOrdering.js artifacts/build-info/*", + "test:inheritance": "scripts/checks/inheritance-ordering.js artifacts/build-info/*", "test:generation": "scripts/checks/generation.sh", "gas-report": "env ENABLE_GAS_REPORT=true npm run test", "slither": "npm run clean && slither . --detect reentrancy-eth,reentrancy-no-eth,reentrancy-unlimited-gas" diff --git a/scripts/checks/inheritanceOrdering.js b/scripts/checks/inheritance-ordering.js similarity index 96% rename from scripts/checks/inheritanceOrdering.js rename to scripts/checks/inheritance-ordering.js index 9d332cba330..3ade7409a7a 100755 --- a/scripts/checks/inheritanceOrdering.js +++ b/scripts/checks/inheritance-ordering.js @@ -13,6 +13,10 @@ for (const artifact of artifacts) { const linearized = []; for (const source in solcOutput.contracts) { + if (source.includes('/mocks/')) { + continue; + } + for (const contractDef of findAll('ContractDefinition', solcOutput.sources[source].ast)) { names[contractDef.id] = contractDef.name; linearized.push(contractDef.linearizedBaseContracts); diff --git a/test/proxy/utils/Initializable.test.js b/test/proxy/utils/Initializable.test.js index 28f272adca8..664bd899d24 100644 --- a/test/proxy/utils/Initializable.test.js +++ b/test/proxy/utils/Initializable.test.js @@ -6,6 +6,9 @@ const ConstructorInitializableMock = artifacts.require('ConstructorInitializable const ChildConstructorInitializableMock = artifacts.require('ChildConstructorInitializableMock'); const ReinitializerMock = artifacts.require('ReinitializerMock'); const SampleChild = artifacts.require('SampleChild'); +const DisableBad1 = artifacts.require('DisableBad1'); +const DisableBad2 = artifacts.require('DisableBad2'); +const DisableOk = artifacts.require('DisableOk'); contract('Initializable', function (accounts) { describe('basic testing without inheritance', function () { @@ -184,4 +187,17 @@ contract('Initializable', function (accounts) { expect(await this.contract.child()).to.be.bignumber.equal(child); }); }); + + describe('disabling initialization', function () { + it('old and new patterns in bad sequence', async function () { + await expectRevert(DisableBad1.new(), 'Initializable: contract is already initialized'); + await expectRevert(DisableBad2.new(), 'Initializable: contract is initializing'); + }); + + it('old and new patterns in good sequence', async function () { + const ok = await DisableOk.new(); + await expectEvent.inConstruction(ok, 'Initialized', { version: '1' }); + await expectEvent.inConstruction(ok, 'Initialized', { version: '255' }); + }); + }); });