diff --git a/Cargo.lock b/Cargo.lock index 0026891f..5537a467 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -907,6 +907,7 @@ version = "0.2.2" dependencies = [ "async-trait", "auto_impl", + "bytes", "ethers", "ethers-core", "futures-channel", diff --git a/ethers-providers/Cargo.toml b/ethers-providers/Cargo.toml index 7169e870..cfe9daa1 100644 --- a/ethers-providers/Cargo.toml +++ b/ethers-providers/Cargo.toml @@ -40,6 +40,7 @@ tracing-futures = { version = "0.2.5", default-features = false, features = ["st tokio = { version = "1.4", default-features = false, optional = true } tokio-tungstenite = { version = "0.13.0", default-features = false, features = ["connect", "tls"], optional = true } tokio-util = { version = "0.6.5", default-features = false, features = ["io"], optional = true } +bytes = { version = "1.0.1", default-features = false, optional = true } [dev-dependencies] ethers = { version = "0.2", path = "../ethers" } @@ -50,4 +51,4 @@ tempfile = "3.2.0" default = ["ws", "ipc"] celo = ["ethers-core/celo"] ws = ["tokio", "tokio-tungstenite"] -ipc = ["tokio", "tokio/io-util", "tokio-util"] +ipc = ["tokio", "tokio/io-util", "tokio-util", "bytes"] diff --git a/ethers-providers/src/provider.rs b/ethers-providers/src/provider.rs index 0970a070..84b0f4a3 100644 --- a/ethers-providers/src/provider.rs +++ b/ethers-providers/src/provider.rs @@ -728,7 +728,7 @@ impl Provider { impl Provider { /// Direct connection to an IPC socket. pub async fn connect_ipc(path: impl AsRef) -> Result { - let ipc = crate::Ipc::new(path).await?; + let ipc = crate::Ipc::connect(path).await?; Ok(Self::new(ipc)) } } diff --git a/ethers-providers/src/transports/ipc.rs b/ethers-providers/src/transports/ipc.rs index 57d6b8e3..9240bc81 100644 --- a/ethers-providers/src/transports/ipc.rs +++ b/ethers-providers/src/transports/ipc.rs @@ -7,7 +7,7 @@ use ethers_core::types::U256; use async_trait::async_trait; use futures_channel::mpsc; -use futures_util::stream::StreamExt; +use futures_util::stream::{Fuse, StreamExt}; use oneshot::error::RecvError; use serde::{de::DeserializeOwned, Serialize}; use std::sync::atomic::Ordering; @@ -17,7 +17,11 @@ use std::{ sync::{atomic::AtomicU64, Arc}, }; use thiserror::Error; -use tokio::{io::AsyncWriteExt, net::UnixStream, sync::oneshot}; +use tokio::{ + io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}, + net::UnixStream, + sync::oneshot, +}; use tokio_util::io::ReaderStream; use tracing::{error, warn}; @@ -28,24 +32,40 @@ pub struct Ipc { messages_tx: mpsc::UnboundedSender, } -#[cfg(unix)] +type Pending = oneshot::Sender; +type Subscription = mpsc::UnboundedSender; + +#[derive(Debug)] +enum TransportMessage { + Request { + id: u64, + request: String, + sender: Pending, + }, + Subscribe { + id: U256, + sink: Subscription, + }, + Unsubscribe { + id: U256, + }, +} + impl Ipc { - /// Creates a new IPC transport from a given path. - /// - /// IPC is only available on Unix. - pub async fn new>(path: P) -> Result { - let stream = UnixStream::connect(path).await?; - - Ok(Self::with_stream(stream)) - } - - fn with_stream(stream: UnixStream) -> Self { + /// Creates a new IPC transport from a Async Reader / Writer + fn new(stream: S) -> Self { let id = Arc::new(AtomicU64::new(1)); let (messages_tx, messages_rx) = mpsc::unbounded(); - tokio::spawn(run_server(stream, messages_rx)); + IpcServer::new(stream, messages_rx).spawn(); + Self { id, messages_tx } + } - Ipc { id, messages_tx } + /// Creates a new IPC transport from a given path using Unix sockets + #[cfg(unix)] + pub async fn connect>(path: P) -> Result { + let ipc = UnixStream::connect(path).await?; + Ok(Self::new(ipc)) } fn send(&self, msg: TransportMessage) -> Result<(), IpcError> { @@ -104,142 +124,186 @@ impl PubsubClient for Ipc { } } -#[derive(Debug)] -enum TransportMessage { - Request { - id: u64, - request: String, - sender: oneshot::Sender, - }, - Subscribe { - id: U256, - sink: mpsc::UnboundedSender, - }, - Unsubscribe { - id: U256, - }, +struct IpcServer { + socket_reader: Fuse>>, + socket_writer: WriteHalf, + requests: Fuse>, + pending: HashMap, + subscriptions: HashMap, } -#[cfg(unix)] -async fn run_server( - unix_stream: UnixStream, - messages_rx: mpsc::UnboundedReceiver, -) -> Result<(), IpcError> { - let (socket_reader, mut socket_writer) = unix_stream.into_split(); - let mut pending_response_txs = HashMap::default(); - let mut subscription_txs = HashMap::default(); +impl IpcServer +where + T: AsyncRead + AsyncWrite, +{ + /// Instantiates the Websocket Server + pub fn new(ipc: T, requests: mpsc::UnboundedReceiver) -> Self { + let (socket_reader, socket_writer) = tokio::io::split(ipc); + let socket_reader = ReaderStream::new(socket_reader).fuse(); + Self { + socket_reader, + socket_writer, + requests: requests.fuse(), + pending: HashMap::default(), + subscriptions: HashMap::default(), + } + } - let mut socket_reader = ReaderStream::new(socket_reader); - let mut messages_rx = messages_rx.fuse(); - let mut read_buffer = vec![]; - let mut closed = false; + /// Spawns the event loop + fn spawn(mut self) + where + T: 'static + Send, + { + let f = async move { + let mut read_buffer = Vec::new(); + loop { + let closed = self + .process(&mut read_buffer) + .await + .expect("WS Server panic"); + if closed && self.pending.is_empty() { + break; + } + } + }; - while !closed || !pending_response_txs.is_empty() { - tokio::select! { - message = messages_rx.next() => match message { - Some(TransportMessage::Subscribe{ id, sink }) => { - if subscription_txs.insert(id, sink).is_some() { - warn!("Replacing a subscription with id {:?}", id); - } - }, - Some(TransportMessage::Unsubscribe{ id }) => { - if subscription_txs.remove(&id).is_none() { - warn!("Unsubscribing not subscribed id {:?}", id); - } - }, - Some(TransportMessage::Request{ id, request, sender }) => { - if pending_response_txs.insert(id, sender).is_some() { - warn!("Replacing a pending request with id {:?}", id); - } + tokio::spawn(f); + } - if let Err(err) = socket_writer.write(&request.as_bytes()).await { - pending_response_txs.remove(&id); - error!("IPC write error: {:?}", err); - } - }, - None => closed = true, + /// Processes 1 item selected from the incoming `requests` or `socket` + #[allow(clippy::single_match)] + async fn process(&mut self, read_buffer: &mut Vec) -> Result { + futures_util::select! { + // Handle requests + msg = self.requests.next() => match msg { + Some(msg) => self.handle_request(msg).await?, + None => return Ok(true), }, - bytes = socket_reader.next() => match bytes { - Some(Ok(bytes)) => { - // Extend buffer of previously unread with the new read bytes - read_buffer.extend_from_slice(&bytes); - - let read_len = { - // Deserialize as many full elements from the stream as exists - let mut de: serde_json::StreamDeserializer<_, serde_json::Value> = - serde_json::Deserializer::from_slice(&read_buffer).into_iter(); - - // Iterate through these elements, and handle responses/notifications - while let Some(Ok(value)) = de.next() { - if let Ok(notification) = serde_json::from_value::>(value.clone()) { - // Send notify response if okay. - if let Err(e) = notify(&mut subscription_txs, notification) { - error!("Failed to send IPC notification: {}", e) - } - } else if let Ok(response) = serde_json::from_value::>(value) { - if let Err(e) = respond(&mut pending_response_txs, response) { - error!("Failed to send IPC response: {}", e) - } - } else { - warn!("JSON from IPC stream is not a response or notification"); - } - } - - // Get the offset of bytes to handle partial buffer reads - de.byte_offset() - }; - - // Reset buffer to just include the partial value bytes. - read_buffer.copy_within(read_len.., 0); - read_buffer.truncate(read_buffer.len() - read_len); - }, + // Handle socket messages + msg = self.socket_reader.next() => match msg { + Some(Ok(msg)) => self.handle_socket(read_buffer, msg).await?, Some(Err(err)) => { error!("IPC read error: {:?}", err); return Err(err.into()); }, - None => break, + None => {}, + }, + // finished + complete => {}, + }; + + Ok(false) + } + + async fn handle_request(&mut self, msg: TransportMessage) -> Result<(), IpcError> { + match msg { + TransportMessage::Request { + id, + request, + sender, + } => { + if self.pending.insert(id, sender).is_some() { + warn!("Replacing a pending request with id {:?}", id); + } + + if let Err(err) = self.socket_writer.write(&request.as_bytes()).await { + error!("WS connection error: {:?}", err); + self.pending.remove(&id); + } + } + TransportMessage::Subscribe { id, sink } => { + if self.subscriptions.insert(id, sink).is_some() { + warn!("Replacing already-registered subscription with id {:?}", id); + } + } + TransportMessage::Unsubscribe { id } => { + if self.subscriptions.remove(&id).is_none() { + warn!( + "Unsubscribing from non-existent subscription with id {:?}", + id + ); + } } }; + + Ok(()) } - Ok(()) -} + async fn handle_socket( + &mut self, + read_buffer: &mut Vec, + bytes: bytes::Bytes, + ) -> Result<(), IpcError> { + // Extend buffer of previously unread with the new read bytes + read_buffer.extend_from_slice(&bytes); -/// Sends notification through the channel based on the ID of the subscription. -/// This handles streaming responses. -fn notify( - subscription_txs: &mut HashMap>, - notification: Notification, -) -> Result<(), IpcError> { - let id = notification.params.subscription; - if let Some(tx) = subscription_txs.get(&id) { - tx.unbounded_send(notification.params.result) - .map_err(|_| IpcError::ChannelError(format!("Subscription receiver {} dropped", id)))?; + let read_len = { + // Deserialize as many full elements from the stream as exists + let mut de: serde_json::StreamDeserializer<_, serde_json::Value> = + serde_json::Deserializer::from_slice(&read_buffer).into_iter(); + + // Iterate through these elements, and handle responses/notifications + while let Some(Ok(value)) = de.next() { + if let Ok(notification) = + serde_json::from_value::>(value.clone()) + { + // Send notify response if okay. + if let Err(e) = self.notify(notification) { + error!("Failed to send IPC notification: {}", e) + } + } else if let Ok(response) = + serde_json::from_value::>(value) + { + if let Err(e) = self.respond(response) { + error!("Failed to send IPC response: {}", e) + } + } else { + warn!("JSON from IPC stream is not a response or notification"); + } + } + + // Get the offset of bytes to handle partial buffer reads + de.byte_offset() + }; + + // Reset buffer to just include the partial value bytes. + read_buffer.copy_within(read_len.., 0); + read_buffer.truncate(read_buffer.len() - read_len); + + Ok(()) } - Ok(()) -} + /// Sends notification through the channel based on the ID of the subscription. + /// This handles streaming responses. + fn notify(&mut self, notification: Notification) -> Result<(), IpcError> { + let id = notification.params.subscription; + if let Some(tx) = self.subscriptions.get(&id) { + tx.unbounded_send(notification.params.result).map_err(|_| { + IpcError::ChannelError(format!("Subscription receiver {} dropped", id)) + })?; + } -/// Sends JSON response through the channel based on the ID in that response. -/// This handles RPC calls with only one response, and the channel entry is dropped after sending. -fn respond( - pending_response_txs: &mut HashMap>, - output: Response, -) -> Result<(), IpcError> { - let id = output.id; + Ok(()) + } - // Converts output into result, to send data if valid response. - let value = output.data.into_result()?; + /// Sends JSON response through the channel based on the ID in that response. + /// This handles RPC calls with only one response, and the channel entry is dropped after sending. + fn respond(&mut self, output: Response) -> Result<(), IpcError> { + let id = output.id; - let response_tx = pending_response_txs.remove(&id).ok_or_else(|| { - IpcError::ChannelError("No response channel exists for the response ID".to_string()) - })?; + // Converts output into result, to send data if valid response. + let value = output.data.into_result()?; - response_tx.send(value).map_err(|_| { - IpcError::ChannelError("Receiver channel for response has been dropped".to_string()) - })?; + let response_tx = self.pending.remove(&id).ok_or_else(|| { + IpcError::ChannelError("No response channel exists for the response ID".to_string()) + })?; - Ok(()) + response_tx.send(value).map_err(|_| { + IpcError::ChannelError("Receiver channel for response has been dropped".to_string()) + })?; + + Ok(()) + } } #[derive(Error, Debug)] @@ -283,7 +347,7 @@ mod test { let temp_file = NamedTempFile::new().unwrap(); let path = temp_file.into_temp_path().to_path_buf(); let _geth = Geth::new().block_time(1u64).ipc_path(&path).spawn(); - let ipc = Ipc::new(path).await.unwrap(); + let ipc = Ipc::connect(path).await.unwrap(); let block_num: U256 = ipc.request("eth_blockNumber", ()).await.unwrap(); std::thread::sleep(std::time::Duration::new(3, 0)); @@ -295,8 +359,11 @@ mod test { async fn subscription() { let temp_file = NamedTempFile::new().unwrap(); let path = temp_file.into_temp_path().to_path_buf(); - let _geth = Geth::new().block_time(1u64).ipc_path(&path).spawn(); - let ipc = Ipc::new(path).await.unwrap(); + let _geth = Geth::new().block_time(2u64).ipc_path(&path).spawn(); + let ipc = Ipc::connect(path).await.unwrap(); + + let sub_id: U256 = ipc.request("eth_subscribe", ["newHeads"]).await.unwrap(); + let mut stream = ipc.subscribe(sub_id).unwrap(); // Subscribing requires sending the sub request and then subscribing to // the returned sub_id @@ -305,9 +372,6 @@ mod test { .await .unwrap() .as_u64(); - let sub_id: U256 = ipc.request("eth_subscribe", ["newHeads"]).await.unwrap(); - let mut stream = ipc.subscribe(sub_id).unwrap(); - let mut blocks = Vec::new(); for _ in 0..3 { let item = stream.next().await.unwrap(); diff --git a/ethers/examples/ipc.rs b/ethers/examples/ipc.rs index 427f3a16..361c7f9c 100644 --- a/ethers/examples/ipc.rs +++ b/ethers/examples/ipc.rs @@ -3,8 +3,9 @@ use std::time::Duration; #[tokio::main] async fn main() -> anyhow::Result<()> { - let ws = Ipc::new("~/.ethereum/geth.ipc").await?; - let provider = Provider::new(ws).interval(Duration::from_millis(2000)); + let provider = Provider::connect_ipc("~/.ethereum/geth.ipc") + .await? + .interval(Duration::from_millis(2000)); let block = provider.get_block_number().await?; println!("Current block: {}", block); let mut stream = provider.watch_blocks().await?.stream();