From 0eba5112c8f0a2e084fa69aacd8e282cf996365b Mon Sep 17 00:00:00 2001 From: Hadrien Croubois Date: Tue, 22 Mar 2022 19:06:29 +0100 Subject: [PATCH] Allow the re-initialization of contracts (#3232) * allow re-initialization of contracts * fix lint * use a private function to avoid code duplication * use oz-retyped-from syntax * add documentation * rephrase * documentation * Update contracts/proxy/utils/Initializable.sol Co-authored-by: Francisco Giordano * reinitialize test * lint * typos and style * add note about relation between initializer and reinitializer * lint * set _initializing in the modifier * remove unnecessary variable set * rename _preventInitialize -> _disableInitializers * rename preventInitialize -> disableInitializers * test nested reinitializers in reverse order * docs typos and style * edit docs for consistency between initializer and reinitializer Co-authored-by: Francisco Giordano --- contracts/mocks/InitializableMock.sol | 29 ++++++++ contracts/proxy/utils/Initializable.sol | 94 ++++++++++++++++++++----- test/proxy/utils/Initializable.test.js | 83 ++++++++++++++++++---- 3 files changed, 176 insertions(+), 30 deletions(-) diff --git a/contracts/mocks/InitializableMock.sol b/contracts/mocks/InitializableMock.sol index 630e8bbfad6..bdf53991fbc 100644 --- a/contracts/mocks/InitializableMock.sol +++ b/contracts/mocks/InitializableMock.sol @@ -59,3 +59,32 @@ contract ConstructorInitializableMock is Initializable { onlyInitializingRan = true; } } + +contract ReinitializerMock is Initializable { + uint256 public counter; + + function initialize() public initializer { + doStuff(); + } + + function reinitialize(uint8 i) public reinitializer(i) { + doStuff(); + } + + function nestedReinitialize(uint8 i, uint8 j) public reinitializer(i) { + reinitialize(j); + } + + function chainReinitialize(uint8 i, uint8 j) public { + reinitialize(i); + reinitialize(j); + } + + function disableInitializers() public { + _disableInitializers(); + } + + function doStuff() public onlyInitializing { + counter++; + } +} diff --git a/contracts/proxy/utils/Initializable.sol b/contracts/proxy/utils/Initializable.sol index 1fe4583a2b9..e6b2a4423e3 100644 --- a/contracts/proxy/utils/Initializable.sol +++ b/contracts/proxy/utils/Initializable.sol @@ -11,6 +11,26 @@ import "../../utils/Address.sol"; * external initializer function, usually called `initialize`. It then becomes necessary to protect this initializer * function so it can only be called once. The {initializer} modifier provided by this contract will have this effect. * + * The initialization functions use a version number. Once a version number is used, it is consumed and cannot be + * reused. This mechanism prevents re-execution of each "step" but allows the creation of new initialization steps in + * case an upgrade adds a module that needs to be initialized. + * + * For example: + * + * [.hljs-theme-light.nopadding] + * ``` + * contract MyToken is ERC20Upgradeable { + * function initialize() initializer public { + * __ERC20_init("MyToken", "MTK"); + * } + * } + * contract MyTokenV2 is MyToken, ERC20PermitUpgradeable { + * function initializeV2() reinitializer(2) public { + * __ERC20Permit_init("MyToken"); + * } + * } + * ``` + * * TIP: To avoid leaving the proxy in an uninitialized state, the initializer function should be called as early as * possible by providing the encoded function call as the `_data` argument to {ERC1967Proxy-constructor}. * @@ -22,21 +42,24 @@ import "../../utils/Address.sol"; * Avoid leaving a contract uninitialized. * * An uninitialized contract can be taken over by an attacker. This applies to both a proxy and its implementation - * contract, which may impact the proxy. To initialize the implementation contract, you can either invoke the - * initializer manually, or you can include a constructor to automatically mark it as initialized when it is deployed: + * contract, which may impact the proxy. To prevent the implementation contract from being used, you should invoke + * the {_disableInitializers} function in the constructor to automatically lock it when it is deployed: * * [.hljs-theme-light.nopadding] * ``` * /// @custom:oz-upgrades-unsafe-allow constructor - * constructor() initializer {} + * constructor() { + * _disableInitializers(); + * } * ``` * ==== */ abstract contract Initializable { /** * @dev Indicates that the contract has been initialized. + * @custom:oz-retyped-from bool */ - bool private _initialized; + uint8 private _initialized; /** * @dev Indicates that the contract is in the process of being initialized. @@ -44,22 +67,38 @@ abstract contract Initializable { bool private _initializing; /** - * @dev Modifier to protect an initializer function from being invoked twice. + * @dev A modifier that defines a protected initializer function that can be invoked at most once. In its scope, + * `onlyInitializing` functions can be used to initialize parent contracts. Equivalent to `reinitializer(1)`. */ modifier initializer() { - // If the contract is initializing we ignore whether _initialized is set in order to support multiple - // inheritance patterns, but we only do this in the context of a constructor, because in other contexts the - // contract may have been reentered. - require(_initializing ? _isConstructor() : !_initialized, "Initializable: contract is already initialized"); - - bool isTopLevelCall = !_initializing; + bool isTopLevelCall = _setInitializedVersion(1); if (isTopLevelCall) { _initializing = true; - _initialized = true; } - _; + if (isTopLevelCall) { + _initializing = false; + } + } + /** + * @dev A modifier that defines a protected reinitializer function that can be invoked at most once, and only if the + * contract hasn't been initialized to a greater version before. In its scope, `onlyInitializing` functions can be + * used to initialize parent contracts. + * + * `initializer` is equivalent to `reinitializer(1)`, so a reinitializer may be used after the original + * initialization step. This is essential to configure modules that are added through upgrades and that require + * initialization. + * + * Note that versions can jump in increments greater than 1; this implies that if multiple reinitializers coexist in + * a contract, executing them in the right order is up to the developer or operator. + */ + modifier reinitializer(uint8 version) { + bool isTopLevelCall = _setInitializedVersion(version); + if (isTopLevelCall) { + _initializing = true; + } + _; if (isTopLevelCall) { _initializing = false; } @@ -67,14 +106,37 @@ abstract contract Initializable { /** * @dev Modifier to protect an initialization function so that it can only be invoked by functions with the - * {initializer} modifier, directly or indirectly. + * {initializer} and {reinitializer} modifiers, directly or indirectly. */ modifier onlyInitializing() { require(_initializing, "Initializable: contract is not initializing"); _; } - function _isConstructor() private view returns (bool) { - return !Address.isContract(address(this)); + /** + * @dev Locks the contract, preventing any future reinitialization. This cannot be part of an initializer call. + * Calling this in the constructor of a contract will prevent that contract from being initialized or reinitialized + * to any version. It is recommended to use this to lock implementation contracts that are designed to be called + * through proxies. + */ + function _disableInitializers() internal virtual { + _setInitializedVersion(type(uint8).max); + } + + function _setInitializedVersion(uint8 version) private returns (bool) { + // If the contract is initializing we ignore whether _initialized is set in order to support multiple + // inheritance patterns, but we only do this in the context of a constructor, and for the lowest level + // of initializers, because in other contexts the contract may have been reentered. + if (_initializing) { + require( + version == 1 && !Address.isContract(address(this)), + "Initializable: contract is already initialized" + ); + return false; + } else { + require(_initialized < version, "Initializable: contract is already initialized"); + _initialized = version; + return true; + } } } diff --git a/test/proxy/utils/Initializable.test.js b/test/proxy/utils/Initializable.test.js index 04884a1d45e..1efb728504a 100644 --- a/test/proxy/utils/Initializable.test.js +++ b/test/proxy/utils/Initializable.test.js @@ -1,8 +1,9 @@ const { expectRevert } = require('@openzeppelin/test-helpers'); -const { assert } = require('chai'); +const { expect } = require('chai'); const InitializableMock = artifacts.require('InitializableMock'); const ConstructorInitializableMock = artifacts.require('ConstructorInitializableMock'); +const ReinitializerMock = artifacts.require('ReinitializerMock'); const SampleChild = artifacts.require('SampleChild'); contract('Initializable', function (accounts) { @@ -13,7 +14,7 @@ contract('Initializable', function (accounts) { context('before initialize', function () { it('initializer has not run', async function () { - assert.isFalse(await this.contract.initializerRan()); + expect(await this.contract.initializerRan()).to.equal(false); }); }); @@ -23,7 +24,7 @@ contract('Initializable', function (accounts) { }); it('initializer has run', async function () { - assert.isTrue(await this.contract.initializerRan()); + expect(await this.contract.initializerRan()).to.equal(true); }); it('initializer does not run again', async function () { @@ -38,7 +39,7 @@ contract('Initializable', function (accounts) { it('onlyInitializing modifier succeeds', async function () { await this.contract.onlyInitializingNested(); - assert.isTrue(await this.contract.onlyInitializingRan()); + expect(await this.contract.onlyInitializingRan()).to.equal(true); }); }); @@ -49,15 +50,69 @@ contract('Initializable', function (accounts) { it('nested initializer can run during construction', async function () { const contract2 = await ConstructorInitializableMock.new(); - assert.isTrue(await contract2.initializerRan()); - assert.isTrue(await contract2.onlyInitializingRan()); + expect(await contract2.initializerRan()).to.equal(true); + expect(await contract2.onlyInitializingRan()).to.equal(true); + }); + + describe('reinitialization', function () { + beforeEach('deploying', async function () { + this.contract = await ReinitializerMock.new(); + }); + + it('can reinitialize', async function () { + expect(await this.contract.counter()).to.be.bignumber.equal('0'); + await this.contract.initialize(); + expect(await this.contract.counter()).to.be.bignumber.equal('1'); + await this.contract.reinitialize(2); + expect(await this.contract.counter()).to.be.bignumber.equal('2'); + await this.contract.reinitialize(3); + expect(await this.contract.counter()).to.be.bignumber.equal('3'); + }); + + it('can jump multiple steps', async function () { + expect(await this.contract.counter()).to.be.bignumber.equal('0'); + await this.contract.initialize(); + expect(await this.contract.counter()).to.be.bignumber.equal('1'); + await this.contract.reinitialize(128); + expect(await this.contract.counter()).to.be.bignumber.equal('2'); + }); + + it('cannot nest reinitializers', async function () { + expect(await this.contract.counter()).to.be.bignumber.equal('0'); + await expectRevert(this.contract.nestedReinitialize(2, 3), 'Initializable: contract is already initialized'); + await expectRevert(this.contract.nestedReinitialize(3, 2), 'Initializable: contract is already initialized'); + }); + + it('can chain reinitializers', async function () { + expect(await this.contract.counter()).to.be.bignumber.equal('0'); + await this.contract.chainReinitialize(2, 3); + expect(await this.contract.counter()).to.be.bignumber.equal('2'); + }); + + describe('contract locking', function () { + it('prevents initialization', async function () { + await this.contract.disableInitializers(); + await expectRevert(this.contract.initialize(), 'Initializable: contract is already initialized'); + }); + + it('prevents re-initialization', async function () { + await this.contract.disableInitializers(); + await expectRevert(this.contract.reinitialize(255), 'Initializable: contract is already initialized'); + }); + + it('can lock contract after initialization', async function () { + await this.contract.initialize(); + await this.contract.disableInitializers(); + await expectRevert(this.contract.reinitialize(255), 'Initializable: contract is already initialized'); + }); + }); }); describe('complex testing with inheritance', function () { - const mother = 12; + const mother = '12'; const gramps = '56'; - const father = 34; - const child = 78; + const father = '34'; + const child = '78'; beforeEach('deploying', async function () { this.contract = await SampleChild.new(); @@ -68,23 +123,23 @@ contract('Initializable', function (accounts) { }); it('initializes human', async function () { - assert.equal(await this.contract.isHuman(), true); + expect(await this.contract.isHuman()).to.be.equal(true); }); it('initializes mother', async function () { - assert.equal(await this.contract.mother(), mother); + expect(await this.contract.mother()).to.be.bignumber.equal(mother); }); it('initializes gramps', async function () { - assert.equal(await this.contract.gramps(), gramps); + expect(await this.contract.gramps()).to.be.bignumber.equal(gramps); }); it('initializes father', async function () { - assert.equal(await this.contract.father(), father); + expect(await this.contract.father()).to.be.bignumber.equal(father); }); it('initializes child', async function () { - assert.equal(await this.contract.child(), child); + expect(await this.contract.child()).to.be.bignumber.equal(child); }); }); });