diff --git a/CHANGELOG.md b/CHANGELOG.md index c3bb4800..c9884ccb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,6 +61,8 @@ - Fix `http Provider` data race when generating new request `id`s. - Add support for `net_version` RPC method. [595](https://github.com/gakonst/ethers-rs/pull/595) +- Add support for `evm_snapshot` and `evm_revert` dev RPC methods. + [640](https://github.com/gakonst/ethers-rs/pull/640) ### Unreleased diff --git a/Cargo.toml b/Cargo.toml index 7877c4cf..c051efcb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,6 +67,7 @@ ws = ["ethers-providers/ws"] ipc = ["ethers-providers/ipc"] rustls = ["ethers-providers/rustls"] openssl = ["ethers-providers/openssl"] +dev-rpc = ["ethers-providers/dev-rpc"] ## signers ledger = ["ethers-signers/ledger"] yubi = ["ethers-signers/yubi"] diff --git a/ethers-providers/Cargo.toml b/ethers-providers/Cargo.toml index d5b0618b..c2d7e164 100644 --- a/ethers-providers/Cargo.toml +++ b/ethers-providers/Cargo.toml @@ -65,3 +65,4 @@ ipc = ["tokio", "tokio/io-util", "tokio-util", "bytes"] openssl = ["tokio-tungstenite/native-tls", "reqwest/native-tls"] rustls = ["tokio-tungstenite/rustls-tls", "reqwest/rustls-tls"] +dev-rpc = [] diff --git a/ethers-providers/src/lib.rs b/ethers-providers/src/lib.rs index 99dd869d..ba538a8a 100644 --- a/ethers-providers/src/lib.rs +++ b/ethers-providers/src/lib.rs @@ -32,6 +32,10 @@ use std::{error::Error, fmt::Debug, future::Future, pin::Pin, str::FromStr}; pub use provider::{FilterKind, Provider, ProviderError}; +// feature-enabled support for dev-rpc methods +#[cfg(feature = "dev-rpc")] +pub use provider::dev_rpc::DevRpcMiddleware; + /// A simple gas escalation policy pub type EscalationPolicy = Box U256 + Send + Sync>; diff --git a/ethers-providers/src/provider.rs b/ethers-providers/src/provider.rs index b4b8a12b..01c9b678 100644 --- a/ethers-providers/src/provider.rs +++ b/ethers-providers/src/provider.rs @@ -1021,6 +1021,183 @@ impl TryFrom for Provider { } } +/// A middleware supporting development-specific JSON RPC methods +/// +/// # Example +/// +///``` +/// use ethers_providers::{Provider, Http, Middleware, DevRpcMiddleware}; +/// use ethers_core::types::TransactionRequest; +/// use ethers_core::utils::Ganache; +/// use std::convert::TryFrom; +/// +/// # #[tokio::main] +/// # async fn main() -> Result<(), Box> { +/// let ganache = Ganache::new().spawn(); +/// let provider = Provider::::try_from(ganache.endpoint()).unwrap(); +/// let client = DevRpcMiddleware::new(provider); +/// +/// // snapshot the initial state +/// let block0 = client.get_block_number().await.unwrap(); +/// let snap_id = client.snapshot().await.unwrap(); +/// +/// // send a transaction +/// let accounts = client.get_accounts().await?; +/// let from = accounts[0]; +/// let to = accounts[1]; +/// let balance_before = client.get_balance(to, None).await?; +/// let tx = TransactionRequest::new().to(to).value(1000).from(from); +/// client.send_transaction(tx, None).await?.await?; +/// let balance_after = client.get_balance(to, None).await?; +/// assert_eq!(balance_after, balance_before + 1000); +/// +/// // revert to snapshot +/// client.revert_to_snapshot(snap_id).await.unwrap(); +/// let balance_after_revert = client.get_balance(to, None).await?; +/// assert_eq!(balance_after_revert, balance_before); +/// # Ok(()) +/// # } +/// ``` +#[cfg(feature = "dev-rpc")] +pub mod dev_rpc { + use crate::{FromErr, Middleware, ProviderError}; + use async_trait::async_trait; + use ethers_core::types::U256; + use thiserror::Error; + + use std::fmt::Debug; + + #[derive(Clone, Debug)] + pub struct DevRpcMiddleware(M); + + #[derive(Error, Debug)] + pub enum DevRpcMiddlewareError { + #[error("{0}")] + MiddlewareError(M::Error), + + #[error("{0}")] + ProviderError(ProviderError), + + #[error("Could not revert to snapshot")] + NoSnapshot, + } + + #[async_trait] + impl Middleware for DevRpcMiddleware { + type Error = DevRpcMiddlewareError; + type Provider = M::Provider; + type Inner = M; + + fn inner(&self) -> &M { + &self.0 + } + } + + impl FromErr for DevRpcMiddlewareError { + fn from(src: M::Error) -> DevRpcMiddlewareError { + DevRpcMiddlewareError::MiddlewareError(src) + } + } + + impl From for DevRpcMiddlewareError + where + M: Middleware, + { + fn from(src: ProviderError) -> Self { + Self::ProviderError(src) + } + } + + impl DevRpcMiddleware { + pub fn new(inner: M) -> Self { + Self(inner) + } + + // both ganache and hardhat increment snapshot id even if no state has changed + pub async fn snapshot(&self) -> Result> { + self.provider().request::<(), U256>("evm_snapshot", ()).await.map_err(From::from) + } + + pub async fn revert_to_snapshot(&self, id: U256) -> Result<(), DevRpcMiddlewareError> { + let ok = self + .provider() + .request::<[U256; 1], bool>("evm_revert", [id]) + .await + .map_err(DevRpcMiddlewareError::ProviderError)?; + if ok { + Ok(()) + } else { + Err(DevRpcMiddlewareError::NoSnapshot) + } + } + } + #[cfg(test)] + // Celo blocks can not get parsed when used with Ganache + #[cfg(not(feature = "celo"))] + mod tests { + use super::*; + use crate::{Http, Provider}; + use ethers_core::utils::Ganache; + use std::convert::TryFrom; + + #[tokio::test] + async fn test_snapshot() { + // launch ganache + let ganache = Ganache::new().spawn(); + let provider = Provider::::try_from(ganache.endpoint()).unwrap(); + let client = DevRpcMiddleware::new(provider); + + // snapshot initial state + let block0 = client.get_block_number().await.unwrap(); + let time0 = client.get_block(block0).await.unwrap().unwrap().timestamp; + let snap_id0 = client.snapshot().await.unwrap(); + + // mine a new block + client.provider().mine(1).await.unwrap(); + + // snapshot state + let block1 = client.get_block_number().await.unwrap(); + let time1 = client.get_block(block1).await.unwrap().unwrap().timestamp; + let snap_id1 = client.snapshot().await.unwrap(); + + // mine some blocks + client.provider().mine(5).await.unwrap(); + + // snapshot state + let block2 = client.get_block_number().await.unwrap(); + let time2 = client.get_block(block2).await.unwrap().unwrap().timestamp; + let snap_id2 = client.snapshot().await.unwrap(); + + // mine some blocks + client.provider().mine(5).await.unwrap(); + + // revert_to_snapshot should reset state to snap id + client.revert_to_snapshot(snap_id2).await.unwrap(); + let block = client.get_block_number().await.unwrap(); + let time = client.get_block(block).await.unwrap().unwrap().timestamp; + assert_eq!(block, block2); + assert_eq!(time, time2); + + client.revert_to_snapshot(snap_id1).await.unwrap(); + let block = client.get_block_number().await.unwrap(); + let time = client.get_block(block).await.unwrap().unwrap().timestamp; + assert_eq!(block, block1); + assert_eq!(time, time1); + + // revert_to_snapshot should throw given non-existent or + // previously used snapshot + let result = client.revert_to_snapshot(snap_id1).await; + assert!(result.is_err()); + + client.revert_to_snapshot(snap_id0).await.unwrap(); + let block = client.get_block_number().await.unwrap(); + let time = client.get_block(block).await.unwrap().unwrap().timestamp; + assert_eq!(block, block0); + assert_eq!(time, time0); + } + } +} + #[cfg(test)] #[cfg(not(target_arch = "wasm32"))] mod tests {