diff --git a/CHANGELOG.md b/CHANGELOG.md index eddab7cc4..6fb03e735 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * `EIP712`: cache `address(this)` to immutable storage to avoid potential issues if a vanilla contract is used in a delegatecall context. ([#2852](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/#2852)) * Add internal `_setApprovalForAll` to `ERC721` and `ERC1155`. ([#2834](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2834)) * `Governor`: shift vote start and end by one block to better match Compound's GovernorBravo and prevent voting at the Governor level if the voting snapshot is not ready. ([#2892](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/#2892)) + * `PaymentSplitter`: now supports ERC20 assets in addition to Ether. ([#2858](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/#2858)) ## 4.3.2 (2021-09-14) diff --git a/contracts/finance/PaymentSplitter.sol b/contracts/finance/PaymentSplitter.sol index 4d6810411..83f8316a0 100644 --- a/contracts/finance/PaymentSplitter.sol +++ b/contracts/finance/PaymentSplitter.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.0; +import "../token/ERC20/utils/SafeERC20.sol"; import "../utils/Address.sol"; import "../utils/Context.sol"; @@ -17,10 +18,15 @@ import "../utils/Context.sol"; * `PaymentSplitter` follows a _pull payment_ model. This means that payments are not automatically forwarded to the * accounts but kept in this contract, and the actual transfer is triggered as a separate step by calling the {release} * function. + * + * NOTE: This contract assumes that ERC20 tokens will behave similarly to native tokens (Ether). Rebasing tokens, and + * tokens that apply fees during transfers, are likely to not be supported as expected. If in doubt, we encourage you + * to run tests before sending real value to this contract. */ contract PaymentSplitter is Context { event PayeeAdded(address account, uint256 shares); event PaymentReleased(address to, uint256 amount); + event ERC20PaymentReleased(IERC20 indexed token, address to, uint256 amount); event PaymentReceived(address from, uint256 amount); uint256 private _totalShares; @@ -30,6 +36,9 @@ contract PaymentSplitter is Context { mapping(address => uint256) private _released; address[] private _payees; + mapping(IERC20 => uint256) private _erc20TotalReleased; + mapping(IERC20 => mapping(address => uint256)) private _erc20Released; + /** * @dev Creates an instance of `PaymentSplitter` where each account in `payees` is assigned the number of shares at * the matching position in the `shares` array. @@ -73,6 +82,14 @@ contract PaymentSplitter is Context { return _totalReleased; } + /** + * @dev Getter for the total amount of `token` already released. `token` should be the address of an IERC20 + * contract. + */ + function totalReleased(IERC20 token) public view returns (uint256) { + return _erc20TotalReleased[token]; + } + /** * @dev Getter for the amount of shares held by an account. */ @@ -87,6 +104,14 @@ contract PaymentSplitter is Context { return _released[account]; } + /** + * @dev Getter for the amount of `token` tokens already released to a payee. `token` should be the address of an + * IERC20 contract. + */ + function released(IERC20 token, address account) public view returns (uint256) { + return _erc20Released[token][account]; + } + /** * @dev Getter for the address of the payee number `index`. */ @@ -101,18 +126,50 @@ contract PaymentSplitter is Context { function release(address payable account) public virtual { require(_shares[account] > 0, "PaymentSplitter: account has no shares"); - uint256 totalReceived = address(this).balance + _totalReleased; - uint256 payment = (totalReceived * _shares[account]) / _totalShares - _released[account]; + uint256 totalReceived = address(this).balance + totalReleased(); + uint256 payment = _pendingPayment(account, totalReceived, released(account)); require(payment != 0, "PaymentSplitter: account is not due payment"); - _released[account] = _released[account] + payment; - _totalReleased = _totalReleased + payment; + _released[account] += payment; + _totalReleased += payment; Address.sendValue(account, payment); emit PaymentReleased(account, payment); } + /** + * @dev Triggers a transfer to `account` of the amount of `token` tokens they are owed, according to their + * percentage of the total shares and their previous withdrawals. `token` must be the address of an IERC20 + * contract. + */ + function release(IERC20 token, address account) public virtual { + require(_shares[account] > 0, "PaymentSplitter: account has no shares"); + + uint256 totalReceived = token.balanceOf(address(this)) + totalReleased(token); + uint256 payment = _pendingPayment(account, totalReceived, released(token, account)); + + require(payment != 0, "PaymentSplitter: account is not due payment"); + + _erc20Released[token][account] += payment; + _erc20TotalReleased[token] += payment; + + SafeERC20.safeTransfer(token, account, payment); + emit ERC20PaymentReleased(token, account, payment); + } + + /** + * @dev internal logic for computing the pending payment of an `account` given the token historical balances and + * already released amounts. + */ + function _pendingPayment( + address account, + uint256 totalReceived, + uint256 alreadyReleased + ) private view returns (uint256) { + return (totalReceived * _shares[account]) / _totalShares - alreadyReleased; + } + /** * @dev Add a new payee to the contract. * @param account The address of the payee to add. diff --git a/test/finance/PaymentSplitter.test.js b/test/finance/PaymentSplitter.test.js index 8e1ad16c6..6df0c4cb9 100644 --- a/test/finance/PaymentSplitter.test.js +++ b/test/finance/PaymentSplitter.test.js @@ -4,6 +4,7 @@ const { ZERO_ADDRESS } = constants; const { expect } = require('chai'); const PaymentSplitter = artifacts.require('PaymentSplitter'); +const Token = artifacts.require('ERC20Mock'); contract('PaymentSplitter', function (accounts) { const [ owner, payee1, payee2, payee3, nonpayee1, payer1 ] = accounts; @@ -50,6 +51,7 @@ contract('PaymentSplitter', function (accounts) { this.shares = [20, 10, 70]; this.contract = await PaymentSplitter.new(this.payees, this.shares); + this.token = await Token.new('MyToken', 'MT', owner, ether('1000')); }); it('has total shares', async function () { @@ -63,10 +65,18 @@ contract('PaymentSplitter', function (accounts) { })); }); - it('accepts payments', async function () { - await send.ether(owner, this.contract.address, amount); + describe('accepts payments', async function () { + it('Ether', async function () { + await send.ether(owner, this.contract.address, amount); - expect(await balance.current(this.contract.address)).to.be.bignumber.equal(amount); + expect(await balance.current(this.contract.address)).to.be.bignumber.equal(amount); + }); + + it('Token', async function () { + await this.token.transfer(this.contract.address, amount, { from: owner }); + + expect(await this.token.balanceOf(this.contract.address)).to.be.bignumber.equal(amount); + }); }); describe('shares', async function () { @@ -80,51 +90,107 @@ contract('PaymentSplitter', function (accounts) { }); describe('release', async function () { - it('reverts if no funds to claim', async function () { - await expectRevert(this.contract.release(payee1), - 'PaymentSplitter: account is not due payment', - ); + describe('Ether', async function () { + it('reverts if no funds to claim', async function () { + await expectRevert(this.contract.release(payee1), + 'PaymentSplitter: account is not due payment', + ); + }); + it('reverts if non-payee want to claim', async function () { + await send.ether(payer1, this.contract.address, amount); + await expectRevert(this.contract.release(nonpayee1), + 'PaymentSplitter: account has no shares', + ); + }); }); - it('reverts if non-payee want to claim', async function () { - await send.ether(payer1, this.contract.address, amount); - await expectRevert(this.contract.release(nonpayee1), - 'PaymentSplitter: account has no shares', - ); + + describe('Token', async function () { + it('reverts if no funds to claim', async function () { + await expectRevert(this.contract.release(this.token.address, payee1), + 'PaymentSplitter: account is not due payment', + ); + }); + it('reverts if non-payee want to claim', async function () { + await this.token.transfer(this.contract.address, amount, { from: owner }); + await expectRevert(this.contract.release(this.token.address, nonpayee1), + 'PaymentSplitter: account has no shares', + ); + }); }); }); - it('distributes funds to payees', async function () { - await send.ether(payer1, this.contract.address, amount); + describe('distributes funds to payees', async function () { + it('Ether', async function () { + await send.ether(payer1, this.contract.address, amount); - // receive funds - const initBalance = await balance.current(this.contract.address); - expect(initBalance).to.be.bignumber.equal(amount); + // receive funds + const initBalance = await balance.current(this.contract.address); + expect(initBalance).to.be.bignumber.equal(amount); - // distribute to payees + // distribute to payees - const tracker1 = await balance.tracker(payee1); - const { logs: logs1 } = await this.contract.release(payee1); - const profit1 = await tracker1.delta(); - expect(profit1).to.be.bignumber.equal(ether('0.20')); - expectEvent.inLogs(logs1, 'PaymentReleased', { to: payee1, amount: profit1 }); + const tracker1 = await balance.tracker(payee1); + const { logs: logs1 } = await this.contract.release(payee1); + const profit1 = await tracker1.delta(); + expect(profit1).to.be.bignumber.equal(ether('0.20')); + expectEvent.inLogs(logs1, 'PaymentReleased', { to: payee1, amount: profit1 }); - const tracker2 = await balance.tracker(payee2); - const { logs: logs2 } = await this.contract.release(payee2); - const profit2 = await tracker2.delta(); - expect(profit2).to.be.bignumber.equal(ether('0.10')); - expectEvent.inLogs(logs2, 'PaymentReleased', { to: payee2, amount: profit2 }); + const tracker2 = await balance.tracker(payee2); + const { logs: logs2 } = await this.contract.release(payee2); + const profit2 = await tracker2.delta(); + expect(profit2).to.be.bignumber.equal(ether('0.10')); + expectEvent.inLogs(logs2, 'PaymentReleased', { to: payee2, amount: profit2 }); - const tracker3 = await balance.tracker(payee3); - const { logs: logs3 } = await this.contract.release(payee3); - const profit3 = await tracker3.delta(); - expect(profit3).to.be.bignumber.equal(ether('0.70')); - expectEvent.inLogs(logs3, 'PaymentReleased', { to: payee3, amount: profit3 }); + const tracker3 = await balance.tracker(payee3); + const { logs: logs3 } = await this.contract.release(payee3); + const profit3 = await tracker3.delta(); + expect(profit3).to.be.bignumber.equal(ether('0.70')); + expectEvent.inLogs(logs3, 'PaymentReleased', { to: payee3, amount: profit3 }); - // end balance should be zero - expect(await balance.current(this.contract.address)).to.be.bignumber.equal('0'); + // end balance should be zero + expect(await balance.current(this.contract.address)).to.be.bignumber.equal('0'); - // check correct funds released accounting - expect(await this.contract.totalReleased()).to.be.bignumber.equal(initBalance); + // check correct funds released accounting + expect(await this.contract.totalReleased()).to.be.bignumber.equal(initBalance); + }); + + it('Token', async function () { + expect(await this.token.balanceOf(payee1)).to.be.bignumber.equal('0'); + expect(await this.token.balanceOf(payee2)).to.be.bignumber.equal('0'); + expect(await this.token.balanceOf(payee3)).to.be.bignumber.equal('0'); + + await this.token.transfer(this.contract.address, amount, { from: owner }); + + expectEvent( + await this.contract.release(this.token.address, payee1), + 'ERC20PaymentReleased', + { token: this.token.address, to: payee1, amount: ether('0.20') }, + ); + + await this.token.transfer(this.contract.address, amount, { from: owner }); + + expectEvent( + await this.contract.release(this.token.address, payee1), + 'ERC20PaymentReleased', + { token: this.token.address, to: payee1, amount: ether('0.20') }, + ); + + expectEvent( + await this.contract.release(this.token.address, payee2), + 'ERC20PaymentReleased', + { token: this.token.address, to: payee2, amount: ether('0.20') }, + ); + + expectEvent( + await this.contract.release(this.token.address, payee3), + 'ERC20PaymentReleased', + { token: this.token.address, to: payee3, amount: ether('1.40') }, + ); + + expect(await this.token.balanceOf(payee1)).to.be.bignumber.equal(ether('0.40')); + expect(await this.token.balanceOf(payee2)).to.be.bignumber.equal(ether('0.20')); + expect(await this.token.balanceOf(payee3)).to.be.bignumber.equal(ether('1.40')); + }); }); }); });