diff --git a/ethers-providers/src/lib.rs b/ethers-providers/src/lib.rs index 126430eb..dd7f0412 100644 --- a/ethers-providers/src/lib.rs +++ b/ethers-providers/src/lib.rs @@ -76,7 +76,7 @@ pub use pending_transaction::PendingTransaction; mod stream; pub use futures_util::StreamExt; -pub use stream::{interval, FilterWatcher, DEFAULT_POLL_INTERVAL}; +pub use stream::{interval, FilterWatcher, TransactionStream, DEFAULT_POLL_INTERVAL}; mod pubsub; pub use pubsub::{PubsubClient, SubscriptionStream}; diff --git a/ethers-providers/src/pubsub.rs b/ethers-providers/src/pubsub.rs index fbd1e13b..f3a1dcb2 100644 --- a/ethers-providers/src/pubsub.rs +++ b/ethers-providers/src/pubsub.rs @@ -1,6 +1,6 @@ -use crate::{JsonRpcClient, Middleware, Provider}; +use crate::{JsonRpcClient, Middleware, Provider, TransactionStream}; -use ethers_core::types::U256; +use ethers_core::types::{TxHash, U256}; use futures_util::stream::Stream; use pin_project::{pin_project, pinned_drop}; @@ -97,3 +97,17 @@ where let _ = (*self.provider).as_ref().unsubscribe(self.id); } } + +impl<'a, P> SubscriptionStream<'a, P, TxHash> +where + P: PubsubClient, +{ + /// Returns a stream that yields the `Transaction`s for the transaction hashes this stream yields. + /// + /// This internally calls `Provider::get_transaction` with every new transaction. + /// No more than n futures will be buffered at any point in time, and less than n may also be + /// buffered depending on the state of each future. + pub fn transactions_unordered(self, n: usize) -> TransactionStream<'a, P, Self> { + TransactionStream::new(self.provider, self, n) + } +} diff --git a/ethers-providers/src/stream.rs b/ethers-providers/src/stream.rs index 1269400a..0c8d4613 100644 --- a/ethers-providers/src/stream.rs +++ b/ethers-providers/src/stream.rs @@ -1,12 +1,15 @@ -use crate::{JsonRpcClient, Middleware, PinBoxFut, Provider}; +use crate::{JsonRpcClient, Middleware, PinBoxFut, Provider, ProviderError}; -use ethers_core::types::U256; +use ethers_core::types::{Transaction, TxHash, U256}; use futures_core::stream::Stream; +use futures_core::Future; use futures_timer::Delay; +use futures_util::stream::FuturesUnordered; use futures_util::{stream, FutureExt, StreamExt}; use pin_project::pin_project; use serde::{de::DeserializeOwned, Serialize}; +use std::collections::VecDeque; use std::{ fmt::Debug, pin::Pin, @@ -120,3 +123,258 @@ where Poll::Pending } } + +impl<'a, P> FilterWatcher<'a, P, TxHash> +where + P: JsonRpcClient, +{ + /// Returns a stream that yields the `Transaction`s for the transaction hashes this stream yields. + /// + /// This internally calls `Provider::get_transaction` with every new transaction. + /// No more than n futures will be buffered at any point in time, and less than n may also be + /// buffered depending on the state of each future. + pub fn transactions_unordered(self, n: usize) -> TransactionStream<'a, P, Self> { + TransactionStream::new(self.provider, self, n) + } +} + +/// Errors `TransactionStream` can throw +#[derive(Debug, thiserror::Error)] +pub enum GetTransactionError { + #[error("Failed to get transaction `{0}`: {1}")] + ProviderError(TxHash, ProviderError), + /// `get_transaction` resulted in a `None` + #[error("Transaction `{0}` not found")] + NotFound(TxHash), +} + +impl From for ProviderError { + fn from(err: GetTransactionError) -> Self { + match err { + GetTransactionError::ProviderError(_, err) => err, + err @ GetTransactionError::NotFound(_) => ProviderError::CustomError(err.to_string()), + } + } +} + +type TransactionFut<'a> = Pin + 'a>>; + +type TransactionResult = Result; + +/// Drains a stream of transaction hashes and yields entire `Transaction`. +#[must_use = "streams do nothing unless polled"] +pub struct TransactionStream<'a, P, St> { + /// Currently running futures pending completion. + pending: FuturesUnordered>, + /// Temporary buffered transaction that get started as soon as another future finishes. + buffered: VecDeque, + /// The provider that gets the transaction + provider: &'a Provider

, + /// A stream of transaction hashes. + stream: St, + /// max allowed futures to execute at once. + max_concurrent: usize, +} + +impl<'a, P: JsonRpcClient, St> TransactionStream<'a, P, St> { + /// Create a new `TransactionStream` instance + pub fn new(provider: &'a Provider

, stream: St, max_concurrent: usize) -> Self { + Self { + pending: Default::default(), + buffered: Default::default(), + provider, + stream, + max_concurrent, + } + } + + /// Push a future into the set + fn push_tx(&mut self, tx: TxHash) { + let fut = self + .provider + .get_transaction(tx) + .then(move |res| match res { + Ok(Some(tx)) => futures_util::future::ok(tx), + Ok(None) => futures_util::future::err(GetTransactionError::NotFound(tx)), + Err(err) => futures_util::future::err(GetTransactionError::ProviderError(tx, err)), + }); + self.pending.push(Box::pin(fut)); + } +} + +impl<'a, P, St> Stream for TransactionStream<'a, P, St> +where + P: JsonRpcClient, + St: Stream + Unpin + 'a, +{ + type Item = TransactionResult; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + // drain buffered transactions first + while this.pending.len() < this.max_concurrent { + if let Some(tx) = this.buffered.pop_front() { + this.push_tx(tx); + } else { + break; + } + } + + let mut stream_done = false; + loop { + match Stream::poll_next(Pin::new(&mut this.stream), cx) { + Poll::Ready(Some(tx)) => { + if this.pending.len() < this.max_concurrent { + this.push_tx(tx); + } else { + this.buffered.push_back(tx); + } + } + Poll::Ready(None) => { + stream_done = true; + break; + } + _ => break, + } + } + + // poll running futures + if let tx @ Poll::Ready(Some(_)) = this.pending.poll_next_unpin(cx) { + return tx; + } + + if stream_done && this.pending.is_empty() { + // all done + return Poll::Ready(None); + } + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Http, Ws}; + use ethers_core::{ + types::{TransactionReceipt, TransactionRequest}, + utils::{Ganache, Geth}, + }; + use futures_util::{FutureExt, StreamExt}; + use std::collections::HashSet; + use std::convert::TryFrom; + + #[tokio::test] + async fn can_stream_pending_transactions() { + let num_txs = 5; + let geth = Geth::new().block_time(2u64).spawn(); + let provider = Provider::::try_from(geth.endpoint()) + .unwrap() + .interval(Duration::from_millis(1000)); + let ws = Ws::connect(geth.ws_endpoint()).await.unwrap(); + let ws_provider = Provider::new(ws); + + let accounts = provider.get_accounts().await.unwrap(); + let tx = TransactionRequest::new() + .from(accounts[0]) + .to(accounts[0]) + .value(1e18 as u64); + + let mut sending = futures_util::future::join_all( + std::iter::repeat(tx.clone()).take(num_txs).map(|tx| async { + provider + .send_transaction(tx, None) + .await + .unwrap() + .await + .unwrap() + }), + ) + .fuse(); + + let mut watch_tx_stream = provider + .watch_pending_transactions() + .await + .unwrap() + .transactions_unordered(num_txs) + .fuse(); + + let mut sub_tx_stream = ws_provider + .subscribe_pending_txs() + .await + .unwrap() + .transactions_unordered(2) + .fuse(); + + let mut sent: Option> = None; + let mut watch_received: Vec = Vec::with_capacity(num_txs); + let mut sub_received: Vec = Vec::with_capacity(num_txs); + + loop { + futures_util::select! { + txs = sending => { + sent = Some(txs) + }, + tx = watch_tx_stream.next() => watch_received.push(tx.unwrap().unwrap()), + tx = sub_tx_stream.next() => sub_received.push(tx.unwrap().unwrap()), + }; + if watch_received.len() == num_txs && sub_received.len() == num_txs { + if let Some(ref sent) = sent { + assert_eq!(sent.len(), watch_received.len()); + let sent_txs = sent + .into_iter() + .map(|tx| tx.transaction_hash) + .collect::>(); + assert_eq!(sent_txs, watch_received.iter().map(|tx| tx.hash).collect()); + assert_eq!(sent_txs, sub_received.iter().map(|tx| tx.hash).collect()); + break; + } + } + } + } + + #[tokio::test] + async fn can_stream_transactions() { + let ganache = Ganache::new().block_time(2u64).spawn(); + let provider = Provider::::try_from(ganache.endpoint()) + .unwrap() + .with_sender(ganache.addresses()[0]); + + let accounts = provider.get_accounts().await.unwrap(); + + let tx = TransactionRequest::new() + .from(accounts[0]) + .to(accounts[0]) + .value(1e18 as u64); + + let txs = + futures_util::future::join_all(std::iter::repeat(tx.clone()).take(3).map(|tx| async { + provider + .send_transaction(tx, None) + .await + .unwrap() + .await + .unwrap() + })) + .await; + + let stream = TransactionStream::new( + &provider, + stream::iter(txs.iter().map(|tx| tx.transaction_hash)), + 10, + ); + let res = stream + .collect::>() + .await + .into_iter() + .collect::, _>>() + .unwrap(); + + assert_eq!(res.len(), txs.len()); + assert_eq!( + res.into_iter().map(|tx| tx.hash).collect::>(), + txs.into_iter().map(|tx| tx.transaction_hash).collect() + ); + } +}