From 4fd742f8cea5fd4de1577f1a8541f8f9c4240ceb Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Tue, 3 Jan 2023 15:15:51 +0100 Subject: [PATCH] feat: windows ipc provider (named pipe) (#1976) * fmt: imports * fix: ipc tests * fmt * chore: move ws macros * chore: gate ipc to unix family * chore: make tokio optional * feat: initial named_pipe * feat: windows ipc * chore: update Provider * chore: clippy * chore: use Path instead of OsStr * chore: clippy * fix: docs * lf * lf * test: better subscription tests * docs * fix: ipc doctest * chore: make winapi optional * fix: optional tokio --- Cargo.lock | 5 +- ethers-providers/Cargo.toml | 8 +- ethers-providers/src/lib.rs | 9 +- ethers-providers/src/provider.rs | 27 +- ethers-providers/src/transports/common.rs | 7 +- ethers-providers/src/transports/http.rs | 5 +- ethers-providers/src/transports/ipc.rs | 299 +++++++++++++++++----- ethers-providers/src/transports/mock.rs | 1 - ethers-providers/src/transports/mod.rs | 28 +- ethers-providers/src/transports/quorum.rs | 13 +- ethers-providers/src/transports/rw.rs | 3 - ethers-providers/src/transports/ws.rs | 57 +++-- ethers-providers/tests/ws_errors.rs | 37 ++- 13 files changed, 335 insertions(+), 164 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f6f73319..75cf5208 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1533,6 +1533,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-timer", "web-sys", + "winapi", "ws_stream_wasm", ] @@ -3702,9 +3703,9 @@ dependencies = [ [[package]] name = "scoped-tls" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea6a9290e3c9cf0f18145ef7ffa62d68ee0bf5fcd651017e586dc7fd5da448c2" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" [[package]] name = "scopeguard" diff --git a/ethers-providers/Cargo.toml b/ethers-providers/Cargo.toml index e3117d68..e81bafcf 100644 --- a/ethers-providers/Cargo.toml +++ b/ethers-providers/Cargo.toml @@ -46,9 +46,12 @@ bytes = { version = "1.3.0", default-features = false, optional = true } once_cell = "1.17.0" hashers = "1.0.1" +[target.'cfg(target_family = "windows")'.dependencies] +winapi = { version = "0.3", optional = true } + [target.'cfg(not(target_arch = "wasm32"))'.dependencies] # tokio -tokio = { version = "1.18", features = ["time"] } +tokio = { version = "1.18", default-features = false, features = ["time"] } tokio-tungstenite = { version = "0.18.0", default-features = false, features = [ "connect", ], optional = true } @@ -71,8 +74,9 @@ tempfile = "3.3.0" [features] default = ["ws", "rustls"] celo = ["ethers-core/celo"] + ws = ["tokio-tungstenite", "futures-channel"] -ipc = ["tokio/io-util", "bytes", "futures-channel"] +ipc = ["tokio/io-util", "bytes", "futures-channel", "winapi"] openssl = ["tokio-tungstenite/native-tls", "reqwest/native-tls"] # we use the webpki roots so we can build static binaries w/o any root cert dependencies diff --git a/ethers-providers/src/lib.rs b/ethers-providers/src/lib.rs index 31ef49fe..fd30d5b1 100644 --- a/ethers-providers/src/lib.rs +++ b/ethers-providers/src/lib.rs @@ -3,8 +3,8 @@ #![deny(rustdoc::broken_intra_doc_links)] #![allow(clippy::type_complexity)] #![doc = include_str!("../README.md")] + mod transports; -use futures_util::future::join_all; pub use transports::*; mod provider; @@ -40,7 +40,11 @@ pub mod erc; use async_trait::async_trait; use auto_impl::auto_impl; -use ethers_core::types::transaction::{eip2718::TypedTransaction, eip2930::AccessListWithGasUsed}; +use ethers_core::types::{ + transaction::{eip2718::TypedTransaction, eip2930::AccessListWithGasUsed}, + *, +}; +use futures_util::future::join_all; use serde::{de::DeserializeOwned, Serialize}; use std::{error::Error, fmt::Debug, future::Future, pin::Pin}; use url::Url; @@ -75,7 +79,6 @@ pub trait JsonRpcClient: Debug + Send + Sync { R: DeserializeOwned; } -use ethers_core::types::*; pub trait FromErr { fn from(src: T) -> Self; } diff --git a/ethers-providers/src/provider.rs b/ethers-providers/src/provider.rs index 97d5a995..37e85262 100644 --- a/ethers-providers/src/provider.rs +++ b/ethers-providers/src/provider.rs @@ -21,26 +21,24 @@ use ethers_core::{ abi::{self, Detokenize, ParamType}, types::{ transaction::{eip2718::TypedTransaction, eip2930::AccessListWithGasUsed}, - Address, Block, BlockId, BlockNumber, BlockTrace, Bytes, EIP1186ProofResponse, FeeHistory, - Filter, FilterBlockOption, GethDebugTracingCallOptions, GethDebugTracingOptions, GethTrace, - Log, NameOrAddress, Selector, Signature, Trace, TraceFilter, TraceType, Transaction, - TransactionReceipt, TransactionRequest, TxHash, TxpoolContent, TxpoolInspect, TxpoolStatus, - H256, U256, U64, + Address, Block, BlockId, BlockNumber, BlockTrace, Bytes, Chain, EIP1186ProofResponse, + FeeHistory, Filter, FilterBlockOption, GethDebugTracingCallOptions, + GethDebugTracingOptions, GethTrace, Log, NameOrAddress, Selector, Signature, Trace, + TraceFilter, TraceType, Transaction, TransactionReceipt, TransactionRequest, TxHash, + TxpoolContent, TxpoolInspect, TxpoolStatus, H256, U256, U64, }, utils, }; +use futures_util::{lock::Mutex, try_join}; use hex::FromHex; use serde::{de::DeserializeOwned, Serialize}; -use thiserror::Error; -use url::{ParseError, Url}; - -use ethers_core::types::Chain; -use futures_util::{lock::Mutex, try_join}; use std::{ collections::VecDeque, convert::TryFrom, fmt::Debug, str::FromStr, sync::Arc, time::Duration, }; +use thiserror::Error; use tracing::trace; use tracing_futures::Instrument; +use url::{ParseError, Url}; #[derive(Copy, Clone)] pub enum NodeClient { @@ -1415,9 +1413,14 @@ impl Provider { } } -#[cfg(all(target_family = "unix", feature = "ipc"))] +#[cfg(all(feature = "ipc", any(unix, windows)))] impl Provider { - /// Direct connection to an IPC socket. + #[cfg_attr(unix, doc = "Connects to the Unix socket at the provided path.")] + #[cfg_attr(windows, doc = "Connects to the named pipe at the provided path.\n")] + #[cfg_attr( + windows, + doc = r"Note: the path must be the fully qualified, like: `\\.\pipe\`." + )] pub async fn connect_ipc(path: impl AsRef) -> Result { let ipc = crate::Ipc::connect(path).await?; Ok(Self::new(ipc)) diff --git a/ethers-providers/src/transports/common.rs b/ethers-providers/src/transports/common.rs index 45494e33..747e1405 100644 --- a/ethers-providers/src/transports/common.rs +++ b/ethers-providers/src/transports/common.rs @@ -1,17 +1,16 @@ // Code adapted from: https://github.com/althea-net/guac_rs/tree/master/web3/src/jsonrpc -use std::fmt; +use ethers_core::types::U256; use serde::{ de::{self, MapAccess, Unexpected, Visitor}, Deserialize, Serialize, }; use serde_json::{value::RawValue, Value}; +use std::fmt; use thiserror::Error; -use ethers_core::types::U256; - -#[derive(Deserialize, Debug, Clone, Error)] /// A JSON-RPC 2.0 error +#[derive(Deserialize, Debug, Clone, Error)] pub struct JsonRpcError { /// The error code pub code: i64, diff --git a/ethers-providers/src/transports/http.rs b/ethers-providers/src/transports/http.rs index 976aab64..b110c66f 100644 --- a/ethers-providers/src/transports/http.rs +++ b/ethers-providers/src/transports/http.rs @@ -1,6 +1,7 @@ // Code adapted from: https://github.com/althea-net/guac_rs/tree/master/web3/src/jsonrpc -use crate::{provider::ProviderError, JsonRpcClient}; +use super::common::{Authorization, JsonRpcError, Request, Response}; +use crate::{provider::ProviderError, JsonRpcClient}; use async_trait::async_trait; use reqwest::{header::HeaderValue, Client, Error as ReqwestError}; use serde::{de::DeserializeOwned, Serialize}; @@ -11,8 +12,6 @@ use std::{ use thiserror::Error; use url::Url; -use super::common::{Authorization, JsonRpcError, Request, Response}; - /// A low-level JSON-RPC Client over HTTP. /// /// # Example diff --git a/ethers-providers/src/transports/ipc.rs b/ethers-providers/src/transports/ipc.rs index acd111cf..8f3c597b 100644 --- a/ethers-providers/src/transports/ipc.rs +++ b/ethers-providers/src/transports/ipc.rs @@ -1,7 +1,22 @@ +use super::common::Params; +use crate::{ + provider::ProviderError, + transports::common::{JsonRpcError, Request, Response}, + JsonRpcClient, PubsubClient, +}; +use async_trait::async_trait; +use bytes::{Buf, BytesMut}; +use ethers_core::types::U256; +use futures_channel::mpsc; +use futures_util::stream::StreamExt; +use hashers::fx_hash::FxHasher64; +use serde::{de::DeserializeOwned, Serialize}; +use serde_json::{value::RawValue, Deserializer}; use std::{ cell::RefCell, convert::Infallible, hash::BuildHasherDefault, + io, path::Path, sync::{ atomic::{AtomicU64, Ordering}, @@ -9,40 +24,194 @@ use std::{ }, thread, }; - -use async_trait::async_trait; -use bytes::{Buf as _, BytesMut}; -use ethers_core::types::U256; -use futures_channel::mpsc; -use futures_util::stream::StreamExt as _; -use hashers::fx_hash::FxHasher64; -use serde::{de::DeserializeOwned, Serialize}; -use serde_json::{value::RawValue, Deserializer}; use thiserror::Error; use tokio::{ - io::{AsyncReadExt as _, AsyncWriteExt as _, BufReader}, - net::{ - unix::{ReadHalf, WriteHalf}, - UnixStream, - }, + io::{AsyncReadExt, AsyncWriteExt, BufReader}, runtime, sync::oneshot::{self, error::RecvError}, }; -use crate::{ - provider::ProviderError, - transports::common::{JsonRpcError, Request, Response}, - JsonRpcClient, PubsubClient, -}; - -use super::common::Params; - type FxHashMap = std::collections::HashMap>; type Pending = oneshot::Sender, JsonRpcError>>; type Subscription = mpsc::UnboundedSender>; -/// Unix Domain Sockets (IPC) transport. +#[cfg(unix)] +#[doc(hidden)] +mod imp { + pub(super) use tokio::net::{ + unix::{ReadHalf, WriteHalf}, + UnixStream as Stream, + }; +} + +#[cfg(windows)] +#[doc(hidden)] +mod imp { + use super::*; + use std::{ + ops::{Deref, DerefMut}, + pin::Pin, + task::{Context, Poll}, + time::Duration, + }; + use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::windows::named_pipe::{ClientOptions, NamedPipeClient}, + time::sleep, + }; + use winapi::shared::winerror; + + /// Wrapper around [NamedPipeClient] to have the same methods as a UnixStream. + /// + /// Should not be exported. + #[repr(transparent)] + pub(super) struct Stream(pub NamedPipeClient); + + impl Deref for Stream { + type Target = NamedPipeClient; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl DerefMut for Stream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + impl Stream { + pub async fn connect(addr: impl AsRef) -> Result { + let addr = addr.as_ref().as_os_str(); + loop { + match ClientOptions::new().open(addr) { + Ok(client) => break Ok(Self(client)), + Err(e) if e.raw_os_error() == Some(winerror::ERROR_PIPE_BUSY as i32) => (), + Err(e) => break Err(e), + } + + sleep(Duration::from_millis(50)).await; + } + } + + #[allow(unsafe_code)] + pub fn split(&mut self) -> (ReadHalf, WriteHalf) { + // SAFETY: ReadHalf cannot write but still needs a mutable reference for polling. + // NamedPipeClient calls its `io` using immutable references, but it's private. + let self1 = unsafe { &mut *(self as *mut Self) }; + let self2 = self; + (ReadHalf(self1), WriteHalf(self2)) + } + } + + impl AsyncRead for Stream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = Pin::new(&mut self.get_mut().0); + this.poll_read(cx, buf) + } + } + + impl AsyncWrite for Stream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = Pin::new(&mut self.get_mut().0); + this.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + let this = Pin::new(&mut self.get_mut().0); + this.poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } + } + + pub(super) struct ReadHalf<'a>(pub &'a mut Stream); + + pub(super) struct WriteHalf<'a>(pub &'a mut Stream); + + impl AsyncRead for ReadHalf<'_> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = Pin::new(&mut self.get_mut().0 .0); + this.poll_read(cx, buf) + } + } + + impl AsyncWrite for WriteHalf<'_> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = Pin::new(&mut self.get_mut().0 .0); + this.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + let this = Pin::new(&mut self.get_mut().0 .0); + this.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::new(&mut self.get_mut().0 .0); + this.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } + } +} + +use self::imp::*; + +#[cfg_attr(unix, doc = "A JSON-RPC Client over Unix IPC.")] +#[cfg_attr(windows, doc = "A JSON-RPC Client over named pipes.")] +/// +/// # Example +/// +/// ```no_run +/// # async fn foo() -> Result<(), Box> { +/// use ethers_providers::Ipc; +/// +/// // the ipc's path +#[cfg_attr(unix, doc = r#"let path = "/home/user/.local/share/reth/reth.ipc";"#)] +#[cfg_attr(windows, doc = r#"let path = r"\\.\pipe\reth.ipc";"#)] +/// let ipc = Ipc::connect(path).await?; +/// # Ok(()) +/// # } +/// ``` #[derive(Debug, Clone)] pub struct Ipc { id: Arc, @@ -57,12 +226,17 @@ enum TransportMessage { } impl Ipc { - /// Creates a new IPC transport from a given path using Unix sockets. + #[cfg_attr(unix, doc = "Connects to the Unix socket at the provided path.")] + #[cfg_attr(windows, doc = "Connects to the named pipe at the provided path.\n")] + #[cfg_attr( + windows, + doc = r"Note: the path must be the fully qualified, like: `\\.\pipe\`." + )] pub async fn connect(path: impl AsRef) -> Result { let id = Arc::new(AtomicU64::new(1)); let (request_tx, request_rx) = mpsc::unbounded(); - let stream = UnixStream::connect(path).await?; + let stream = Stream::connect(path).await?; spawn_ipc_server(stream, request_rx); Ok(Self { id, request_tx }) @@ -121,11 +295,11 @@ impl PubsubClient for Ipc { } } -fn spawn_ipc_server(stream: UnixStream, request_rx: mpsc::UnboundedReceiver) { - // 65 KiB should be more than enough for this thread, as all unbounded data +fn spawn_ipc_server(stream: Stream, request_rx: mpsc::UnboundedReceiver) { + // 256 Kb should be more than enough for this thread, as all unbounded data // growth occurs on heap-allocated data structures and buffers and the call // stack is not going to do anything crazy either - const STACK_SIZE: usize = 1 << 16; + const STACK_SIZE: usize = 1 << 18; // spawn a light-weight thread with a thread-local async runtime just for // sending and receiving data over the IPC socket let _ = thread::Builder::new() @@ -142,10 +316,7 @@ fn spawn_ipc_server(stream: UnixStream, request_rx: mpsc::UnboundedReceiver, -) { +async fn run_ipc_server(mut stream: Stream, request_rx: mpsc::UnboundedReceiver) { // the shared state for both reads & writes let shared = Shared { pending: FxHashMap::with_capacity_and_hasher(64, BuildHasherDefault::default()).into(), @@ -289,8 +460,8 @@ impl Shared { } } -#[derive(Error, Debug)] /// Error thrown when sending or receiving an IPC message. +#[derive(Debug, Error)] pub enum IpcError { /// Thrown if deserialization failed #[error(transparent)] @@ -298,7 +469,7 @@ pub enum IpcError { /// std IO error forwarding. #[error(transparent)] - IoError(#[from] std::io::Error), + IoError(#[from] io::Error), #[error(transparent)] /// Thrown if the response could not be parsed @@ -319,22 +490,30 @@ impl From for ProviderError { ProviderError::JsonRpcClientError(Box::new(src)) } } -#[cfg(all(test, target_family = "unix"))] -#[cfg(not(feature = "celo"))] -mod test { + +#[cfg(test)] +mod tests { use super::*; - use ethers_core::{ - types::{Block, TxHash, U256}, - utils::Geth, - }; + use ethers_core::utils::{Geth, GethInstance}; use tempfile::NamedTempFile; + async fn connect() -> (Ipc, GethInstance) { + let temp_file = NamedTempFile::new().unwrap(); + let path = temp_file.into_temp_path().to_path_buf(); + let geth = Geth::new().block_time(1u64).ipc_path(&path).spawn(); + + // [Windows named pipes](https://learn.microsoft.com/en-us/windows/win32/ipc/named-pipes) + // are located at `\\\pipe\`. + #[cfg(windows)] + let path = format!(r"\\.\pipe\{}", path.display()); + let ipc = Ipc::connect(path).await.unwrap(); + + (ipc, geth) + } + #[tokio::test] async fn request() { - let temp_file = NamedTempFile::new().unwrap(); - let path = temp_file.into_temp_path().to_path_buf(); - let _geth = Geth::new().block_time(1u64).ipc_path(&path).spawn(); - let ipc = Ipc::connect(path).await.unwrap(); + let (ipc, _geth) = connect().await; let block_num: U256 = ipc.request("eth_blockNumber", ()).await.unwrap(); std::thread::sleep(std::time::Duration::new(3, 0)); @@ -343,25 +522,25 @@ mod test { } #[tokio::test] + #[cfg(not(feature = "celo"))] async fn subscription() { - let temp_file = NamedTempFile::new().unwrap(); - let path = temp_file.into_temp_path().to_path_buf(); - let _geth = Geth::new().block_time(2u64).ipc_path(&path).spawn(); - let ipc = Ipc::connect(path).await.unwrap(); + use ethers_core::types::{Block, TxHash}; - let sub_id: U256 = ipc.request("eth_subscribe", ["newHeads"]).await.unwrap(); - let mut stream = ipc.subscribe(sub_id).unwrap(); + let (ipc, _geth) = connect().await; // Subscribing requires sending the sub request and then subscribing to // the returned sub_id - let block_num: u64 = ipc.request::<_, U256>("eth_blockNumber", ()).await.unwrap().as_u64(); - let mut blocks = Vec::new(); - for _ in 0..3 { - let item = stream.next().await.unwrap(); - let block: Block = serde_json::from_str(item.get()).unwrap(); - blocks.push(block.number.unwrap_or_default().as_u64()); - } - let offset = blocks[0] - block_num; - assert_eq!(blocks, &[block_num + offset, block_num + offset + 1, block_num + offset + 2]) + let sub_id: U256 = ipc.request("eth_subscribe", ["newHeads"]).await.unwrap(); + let stream = ipc.subscribe(sub_id).unwrap(); + + let blocks: Vec = stream + .take(3) + .map(|item| { + let block: Block = serde_json::from_str(item.get()).unwrap(); + block.number.unwrap_or_default().as_u64() + }) + .collect() + .await; + assert_eq!(blocks, vec![1, 2, 3]); } } diff --git a/ethers-providers/src/transports/mock.rs b/ethers-providers/src/transports/mock.rs index e6ff8f5c..0e0cedb1 100644 --- a/ethers-providers/src/transports/mock.rs +++ b/ethers-providers/src/transports/mock.rs @@ -1,5 +1,4 @@ use crate::{JsonRpcClient, ProviderError}; - use async_trait::async_trait; use serde::{de::DeserializeOwned, Serialize}; use serde_json::Value; diff --git a/ethers-providers/src/transports/mod.rs b/ethers-providers/src/transports/mod.rs index 5623675d..feb4ed37 100644 --- a/ethers-providers/src/transports/mod.rs +++ b/ethers-providers/src/transports/mod.rs @@ -1,32 +1,14 @@ mod common; pub use common::Authorization; -// only used with WS -#[cfg(feature = "ws")] -macro_rules! if_wasm { - ($($item:item)*) => {$( - #[cfg(target_arch = "wasm32")] - $item - )*} -} - -// only used with WS -#[cfg(feature = "ws")] -macro_rules! if_not_wasm { - ($($item:item)*) => {$( - #[cfg(not(target_arch = "wasm32"))] - $item - )*} -} - -#[cfg(all(target_family = "unix", feature = "ipc"))] -mod ipc; -#[cfg(all(target_family = "unix", feature = "ipc"))] -pub use ipc::{Ipc, IpcError}; - mod http; pub use self::http::{ClientError as HttpClientError, Provider as Http}; +#[cfg(all(feature = "ipc", any(unix, windows)))] +mod ipc; +#[cfg(all(feature = "ipc", any(unix, windows)))] +pub use ipc::{Ipc, IpcError}; + #[cfg(feature = "ws")] mod ws; #[cfg(feature = "ws")] diff --git a/ethers-providers/src/transports/quorum.rs b/ethers-providers/src/transports/quorum.rs index aba508d8..3faf58b9 100644 --- a/ethers-providers/src/transports/quorum.rs +++ b/ethers-providers/src/transports/quorum.rs @@ -1,10 +1,3 @@ -use std::{ - fmt::Debug, - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - use crate::{provider::ProviderError, JsonRpcClient, PubsubClient}; use async_trait::async_trait; use ethers_core::types::{U256, U64}; @@ -12,6 +5,12 @@ use futures_core::Stream; use futures_util::{future::join_all, FutureExt, StreamExt}; use serde::{de::DeserializeOwned, Serialize}; use serde_json::{value::RawValue, Value}; +use std::{ + fmt::Debug, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; use thiserror::Error; /// A provider that bundles multiple providers and only returns a value to the diff --git a/ethers-providers/src/transports/rw.rs b/ethers-providers/src/transports/rw.rs index 58499d20..f3c649d2 100644 --- a/ethers-providers/src/transports/rw.rs +++ b/ethers-providers/src/transports/rw.rs @@ -2,11 +2,8 @@ //! and uses a dedicated client for read and the other for write operations use crate::{provider::ProviderError, JsonRpcClient}; - use async_trait::async_trait; - use serde::{de::DeserializeOwned, Serialize}; - use thiserror::Error; /// A client containing two clients. diff --git a/ethers-providers/src/transports/ws.rs b/ethers-providers/src/transports/ws.rs index 696c51da..a8865a8e 100644 --- a/ethers-providers/src/transports/ws.rs +++ b/ethers-providers/src/transports/ws.rs @@ -1,11 +1,11 @@ +use super::common::{Params, Response}; use crate::{ provider::ProviderError, transports::common::{JsonRpcError, Request}, JsonRpcClient, PubsubClient, }; -use ethers_core::types::U256; - use async_trait::async_trait; +use ethers_core::types::U256; use futures_channel::{mpsc, oneshot}; use futures_util::{ sink::{Sink, SinkExt}, @@ -24,7 +24,19 @@ use std::{ use thiserror::Error; use tracing::trace; -use super::common::{Params, Response}; +macro_rules! if_wasm { + ($($item:item)*) => {$( + #[cfg(target_arch = "wasm32")] + $item + )*} +} + +macro_rules! if_not_wasm { + ($($item:item)*) => {$( + #[cfg(not(target_arch = "wasm32"))] + $item + )*} +} if_wasm! { use wasm_bindgen::prelude::*; @@ -84,11 +96,13 @@ enum Instruction { /// A JSON-RPC Client over Websockets. /// +/// # Example +/// /// ```no_run /// # async fn foo() -> Result<(), Box> { /// use ethers_providers::Ws; /// -/// let ws = Ws::connect("wss://localhost:8545").await?; +/// let ws = Ws::connect("ws://localhost:8545").await?; /// # Ok(()) /// # } /// ``` @@ -427,8 +441,8 @@ fn to_client_error(err: T) -> ClientError { ClientError::ChannelError(format!("{err:?}")) } -#[derive(Error, Debug)] /// Error thrown when sending a WS message +#[derive(Debug, Error)] pub enum ClientError { /// Thrown if deserialization failed #[error(transparent)] @@ -488,15 +502,10 @@ impl From for ProviderError { } } -#[cfg(test)] -#[cfg(not(feature = "celo"))] -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(test, not(target_arch = "wasm32")))] mod tests { use super::*; - use ethers_core::{ - types::{Block, TxHash, U256}, - utils::Anvil, - }; + use ethers_core::{types::U256, utils::Anvil}; #[tokio::test] async fn request() { @@ -504,29 +513,33 @@ mod tests { let ws = Ws::connect(anvil.ws_endpoint()).await.unwrap(); let block_num: U256 = ws.request("eth_blockNumber", ()).await.unwrap(); - std::thread::sleep(std::time::Duration::new(3, 0)); + tokio::time::sleep(std::time::Duration::from_secs(3)).await; let block_num2: U256 = ws.request("eth_blockNumber", ()).await.unwrap(); assert!(block_num2 > block_num); } #[tokio::test] + #[cfg(not(feature = "celo"))] async fn subscription() { + use ethers_core::types::{Block, TxHash}; + let anvil = Anvil::new().block_time(1u64).spawn(); let ws = Ws::connect(anvil.ws_endpoint()).await.unwrap(); // Subscribing requires sending the sub request and then subscribing to // the returned sub_id let sub_id: U256 = ws.request("eth_subscribe", ["newHeads"]).await.unwrap(); - let mut stream = ws.subscribe(sub_id).unwrap(); + let stream = ws.subscribe(sub_id).unwrap(); - let mut blocks = Vec::new(); - for _ in 0..3 { - let item = stream.next().await.unwrap(); - let block: Block = serde_json::from_str(item.get()).unwrap(); - blocks.push(block.number.unwrap_or_default().as_u64()); - } - - assert_eq!(blocks, vec![1, 2, 3]) + let blocks: Vec = stream + .take(3) + .map(|item| { + let block: Block = serde_json::from_str(item.get()).unwrap(); + block.number.unwrap_or_default().as_u64() + }) + .collect() + .await; + assert_eq!(blocks, vec![1, 2, 3]); } #[tokio::test] diff --git a/ethers-providers/tests/ws_errors.rs b/ethers-providers/tests/ws_errors.rs index b70b2c96..6ae85213 100644 --- a/ethers-providers/tests/ws_errors.rs +++ b/ethers-providers/tests/ws_errors.rs @@ -1,4 +1,5 @@ -#![allow(unused)] +#![cfg(not(feature = "celo"))] + use ethers_providers::{Middleware, Provider, StreamExt, Ws}; use futures_util::SinkExt; use std::time::Duration; @@ -15,29 +16,21 @@ use tungstenite::protocol::Message; const WS_ENDPOINT: &str = "127.0.0.1:9002"; -#[cfg(not(feature = "celo"))] -mod eth_tests { - use super::*; - use ethers_core::types::Filter; - use tokio_tungstenite::connect_async; +use ethers_core::types::Filter; +use tokio_tungstenite::connect_async; - #[tokio::test] - async fn graceful_disconnect_on_ws_errors() { - // Spawn a fake Ws server that will drop our connection after a while - spawn_ws_server().await; +#[tokio::test] +async fn graceful_disconnect_on_ws_errors() { + // Spawn a fake Ws server that will drop our connection after a while + spawn_ws_server().await; - // Connect to the fake server - let (ws, _) = connect_async(format!("ws://{}", WS_ENDPOINT)).await.unwrap(); - let provider = Provider::new(Ws::new(ws)); - let filter = Filter::new().event("Transfer(address,address,uint256)"); - let mut stream = provider.subscribe_logs(&filter).await.unwrap(); + // Connect to the fake server + let (ws, _) = connect_async(format!("ws://{}", WS_ENDPOINT)).await.unwrap(); + let provider = Provider::new(Ws::new(ws)); + let filter = Filter::new().event("Transfer(address,address,uint256)"); + let mut stream = provider.subscribe_logs(&filter).await.unwrap(); - while let Some(_) = stream.next().await { - assert!(false); // force test to fail - } - - assert!(true); - } + assert!(stream.next().await.is_none()); } async fn spawn_ws_server() { @@ -52,7 +45,7 @@ async fn spawn_ws_server() { async fn handle_conn(stream: TcpStream) -> Result<(), Error> { let mut ws_stream = accept_async(stream).await?; - while let Some(_) = ws_stream.next().await { + while ws_stream.next().await.is_some() { let res: String = "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xcd0c3e8af590364c09d0fa6a1210faf5\"}" .into();