diff --git a/Cargo.lock b/Cargo.lock index 6466ad2b..0026891f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -918,9 +918,11 @@ dependencies = [ "reqwest", "serde", "serde_json", + "tempfile", "thiserror", "tokio", "tokio-tungstenite", + "tokio-util", "tracing", "tracing-futures", "url", diff --git a/ethers-core/src/utils/geth.rs b/ethers-core/src/utils/geth.rs index e5cec3da..8787e38a 100644 --- a/ethers-core/src/utils/geth.rs +++ b/ethers-core/src/utils/geth.rs @@ -1,6 +1,7 @@ use super::unused_port; use std::{ io::{BufRead, BufReader}, + path::PathBuf, process::{Child, Command}, time::{Duration, Instant}, }; @@ -20,6 +21,7 @@ const GETH: &str = "geth"; pub struct GethInstance { pid: Child, port: u16, + ipc: Option, } impl GethInstance { @@ -37,6 +39,10 @@ impl GethInstance { pub fn ws_endpoint(&self) -> String { format!("ws://localhost:{}", self.port) } + + pub fn ipc_path(&self) -> &Option { + &self.ipc + } } impl Drop for GethInstance { @@ -70,6 +76,7 @@ impl Drop for GethInstance { pub struct Geth { port: Option, block_time: Option, + ipc_path: Option, } impl Geth { @@ -91,6 +98,12 @@ impl Geth { self } + /// Manually sets the IPC path for the socket manually. + pub fn ipc_path>(mut self, path: T) -> Self { + self.ipc_path = Some(path.into()); + self + } + /// Consumes the builder and spawns `geth` with stdout redirected /// to /dev/null. pub fn spawn(self) -> GethInstance { @@ -119,6 +132,10 @@ impl Geth { cmd.arg("--dev.period").arg(block_time.to_string()); } + if let Some(ref ipc) = self.ipc_path { + cmd.arg("--ipcpath").arg(ipc); + } + let mut child = cmd.spawn().expect("couldnt start geth"); let stdout = child @@ -146,6 +163,10 @@ impl Geth { child.stderr = Some(reader.into_inner()); - GethInstance { pid: child, port } + GethInstance { + pid: child, + port, + ipc: self.ipc_path, + } } } diff --git a/ethers-providers/Cargo.toml b/ethers-providers/Cargo.toml index 56bb45d1..7169e870 100644 --- a/ethers-providers/Cargo.toml +++ b/ethers-providers/Cargo.toml @@ -39,12 +39,15 @@ tracing-futures = { version = "0.2.5", default-features = false, features = ["st # tokio tokio = { version = "1.4", default-features = false, optional = true } tokio-tungstenite = { version = "0.13.0", default-features = false, features = ["connect", "tls"], optional = true } +tokio-util = { version = "0.6.5", default-features = false, features = ["io"], optional = true } [dev-dependencies] ethers = { version = "0.2", path = "../ethers" } tokio = { version = "1.4", default-features = false, features = ["rt", "macros"] } +tempfile = "3.2.0" [features] -default = ["ws"] +default = ["ws", "ipc"] celo = ["ethers-core/celo"] ws = ["tokio", "tokio-tungstenite"] +ipc = ["tokio", "tokio/io-util", "tokio-util"] diff --git a/ethers-providers/src/provider.rs b/ethers-providers/src/provider.rs index 36e2de73..0970a070 100644 --- a/ethers-providers/src/provider.rs +++ b/ethers-providers/src/provider.rs @@ -724,6 +724,15 @@ impl Provider { } } +#[cfg(feature = "ipc")] +impl Provider { + /// Direct connection to an IPC socket. + pub async fn connect_ipc(path: impl AsRef) -> Result { + let ipc = crate::Ipc::new(path).await?; + Ok(Self::new(ipc)) + } +} + impl Provider { /// Returns a `Provider` instantiated with an internal "mock" transport. /// diff --git a/ethers-providers/src/transports/ipc.rs b/ethers-providers/src/transports/ipc.rs new file mode 100644 index 00000000..57d6b8e3 --- /dev/null +++ b/ethers-providers/src/transports/ipc.rs @@ -0,0 +1,320 @@ +use crate::{ + provider::ProviderError, + transports::common::{JsonRpcError, Notification, Request, Response}, + JsonRpcClient, PubsubClient, +}; +use ethers_core::types::U256; + +use async_trait::async_trait; +use futures_channel::mpsc; +use futures_util::stream::StreamExt; +use oneshot::error::RecvError; +use serde::{de::DeserializeOwned, Serialize}; +use std::sync::atomic::Ordering; +use std::{ + collections::HashMap, + path::Path, + sync::{atomic::AtomicU64, Arc}, +}; +use thiserror::Error; +use tokio::{io::AsyncWriteExt, net::UnixStream, sync::oneshot}; +use tokio_util::io::ReaderStream; +use tracing::{error, warn}; + +/// Unix Domain Sockets (IPC) transport. +#[derive(Debug, Clone)] +pub struct Ipc { + id: Arc, + messages_tx: mpsc::UnboundedSender, +} + +#[cfg(unix)] +impl Ipc { + /// Creates a new IPC transport from a given path. + /// + /// IPC is only available on Unix. + pub async fn new>(path: P) -> Result { + let stream = UnixStream::connect(path).await?; + + Ok(Self::with_stream(stream)) + } + + fn with_stream(stream: UnixStream) -> Self { + let id = Arc::new(AtomicU64::new(1)); + let (messages_tx, messages_rx) = mpsc::unbounded(); + + tokio::spawn(run_server(stream, messages_rx)); + + Ipc { id, messages_tx } + } + + fn send(&self, msg: TransportMessage) -> Result<(), IpcError> { + self.messages_tx + .unbounded_send(msg) + .map_err(|_| IpcError::ChannelError("IPC server receiver dropped".to_string()))?; + + Ok(()) + } +} + +#[async_trait] +impl JsonRpcClient for Ipc { + type Error = IpcError; + + async fn request( + &self, + method: &str, + params: T, + ) -> Result { + let next_id = self.id.fetch_add(1, Ordering::SeqCst); + + // Create the request and initialize the response channel + let (sender, receiver) = oneshot::channel(); + let payload = TransportMessage::Request { + id: next_id, + request: serde_json::to_string(&Request::new(next_id, method, params))?, + sender, + }; + + // Send the request to the IPC server to be handled. + self.send(payload)?; + + // Wait for the response from the IPC server. + let res = receiver.await?; + + // Parse JSON response. + Ok(serde_json::from_value(res)?) + } +} + +impl PubsubClient for Ipc { + type NotificationStream = mpsc::UnboundedReceiver; + + fn subscribe>(&self, id: T) -> Result { + let (sink, stream) = mpsc::unbounded(); + self.send(TransportMessage::Subscribe { + id: id.into(), + sink, + })?; + Ok(stream) + } + + fn unsubscribe>(&self, id: T) -> Result<(), IpcError> { + self.send(TransportMessage::Unsubscribe { id: id.into() }) + } +} + +#[derive(Debug)] +enum TransportMessage { + Request { + id: u64, + request: String, + sender: oneshot::Sender, + }, + Subscribe { + id: U256, + sink: mpsc::UnboundedSender, + }, + Unsubscribe { + id: U256, + }, +} + +#[cfg(unix)] +async fn run_server( + unix_stream: UnixStream, + messages_rx: mpsc::UnboundedReceiver, +) -> Result<(), IpcError> { + let (socket_reader, mut socket_writer) = unix_stream.into_split(); + let mut pending_response_txs = HashMap::default(); + let mut subscription_txs = HashMap::default(); + + let mut socket_reader = ReaderStream::new(socket_reader); + let mut messages_rx = messages_rx.fuse(); + let mut read_buffer = vec![]; + let mut closed = false; + + while !closed || !pending_response_txs.is_empty() { + tokio::select! { + message = messages_rx.next() => match message { + Some(TransportMessage::Subscribe{ id, sink }) => { + if subscription_txs.insert(id, sink).is_some() { + warn!("Replacing a subscription with id {:?}", id); + } + }, + Some(TransportMessage::Unsubscribe{ id }) => { + if subscription_txs.remove(&id).is_none() { + warn!("Unsubscribing not subscribed id {:?}", id); + } + }, + Some(TransportMessage::Request{ id, request, sender }) => { + if pending_response_txs.insert(id, sender).is_some() { + warn!("Replacing a pending request with id {:?}", id); + } + + if let Err(err) = socket_writer.write(&request.as_bytes()).await { + pending_response_txs.remove(&id); + error!("IPC write error: {:?}", err); + } + }, + None => closed = true, + }, + bytes = socket_reader.next() => match bytes { + Some(Ok(bytes)) => { + // Extend buffer of previously unread with the new read bytes + read_buffer.extend_from_slice(&bytes); + + let read_len = { + // Deserialize as many full elements from the stream as exists + let mut de: serde_json::StreamDeserializer<_, serde_json::Value> = + serde_json::Deserializer::from_slice(&read_buffer).into_iter(); + + // Iterate through these elements, and handle responses/notifications + while let Some(Ok(value)) = de.next() { + if let Ok(notification) = serde_json::from_value::>(value.clone()) { + // Send notify response if okay. + if let Err(e) = notify(&mut subscription_txs, notification) { + error!("Failed to send IPC notification: {}", e) + } + } else if let Ok(response) = serde_json::from_value::>(value) { + if let Err(e) = respond(&mut pending_response_txs, response) { + error!("Failed to send IPC response: {}", e) + } + } else { + warn!("JSON from IPC stream is not a response or notification"); + } + } + + // Get the offset of bytes to handle partial buffer reads + de.byte_offset() + }; + + // Reset buffer to just include the partial value bytes. + read_buffer.copy_within(read_len.., 0); + read_buffer.truncate(read_buffer.len() - read_len); + }, + Some(Err(err)) => { + error!("IPC read error: {:?}", err); + return Err(err.into()); + }, + None => break, + } + }; + } + + Ok(()) +} + +/// Sends notification through the channel based on the ID of the subscription. +/// This handles streaming responses. +fn notify( + subscription_txs: &mut HashMap>, + notification: Notification, +) -> Result<(), IpcError> { + let id = notification.params.subscription; + if let Some(tx) = subscription_txs.get(&id) { + tx.unbounded_send(notification.params.result) + .map_err(|_| IpcError::ChannelError(format!("Subscription receiver {} dropped", id)))?; + } + + Ok(()) +} + +/// Sends JSON response through the channel based on the ID in that response. +/// This handles RPC calls with only one response, and the channel entry is dropped after sending. +fn respond( + pending_response_txs: &mut HashMap>, + output: Response, +) -> Result<(), IpcError> { + let id = output.id; + + // Converts output into result, to send data if valid response. + let value = output.data.into_result()?; + + let response_tx = pending_response_txs.remove(&id).ok_or_else(|| { + IpcError::ChannelError("No response channel exists for the response ID".to_string()) + })?; + + response_tx.send(value).map_err(|_| { + IpcError::ChannelError("Receiver channel for response has been dropped".to_string()) + })?; + + Ok(()) +} + +#[derive(Error, Debug)] +/// Error thrown when sending or receiving an IPC message. +pub enum IpcError { + /// Thrown if deserialization failed + #[error(transparent)] + JsonError(#[from] serde_json::Error), + + /// std IO error forwarding. + #[error(transparent)] + IoError(#[from] std::io::Error), + + #[error(transparent)] + /// Thrown if the response could not be parsed + JsonRpcError(#[from] JsonRpcError), + + #[error("{0}")] + ChannelError(String), + + #[error(transparent)] + Canceled(#[from] RecvError), +} + +impl From for ProviderError { + fn from(src: IpcError) -> Self { + ProviderError::JsonRpcClientError(Box::new(src)) + } +} + +#[cfg(all(test, unix))] +#[cfg(not(feature = "celo"))] +mod test { + use super::*; + use ethers::utils::Geth; + use ethers_core::types::{Block, TxHash, U256}; + use tempfile::NamedTempFile; + + #[tokio::test] + async fn request() { + let temp_file = NamedTempFile::new().unwrap(); + let path = temp_file.into_temp_path().to_path_buf(); + let _geth = Geth::new().block_time(1u64).ipc_path(&path).spawn(); + let ipc = Ipc::new(path).await.unwrap(); + + let block_num: U256 = ipc.request("eth_blockNumber", ()).await.unwrap(); + std::thread::sleep(std::time::Duration::new(3, 0)); + let block_num2: U256 = ipc.request("eth_blockNumber", ()).await.unwrap(); + assert!(block_num2 > block_num); + } + + #[tokio::test] + async fn subscription() { + let temp_file = NamedTempFile::new().unwrap(); + let path = temp_file.into_temp_path().to_path_buf(); + let _geth = Geth::new().block_time(1u64).ipc_path(&path).spawn(); + let ipc = Ipc::new(path).await.unwrap(); + + // Subscribing requires sending the sub request and then subscribing to + // the returned sub_id + let block_num: u64 = ipc + .request::<_, U256>("eth_blockNumber", ()) + .await + .unwrap() + .as_u64(); + let sub_id: U256 = ipc.request("eth_subscribe", ["newHeads"]).await.unwrap(); + let mut stream = ipc.subscribe(sub_id).unwrap(); + + let mut blocks = Vec::new(); + for _ in 0..3 { + let item = stream.next().await.unwrap(); + let block = serde_json::from_value::>(item).unwrap(); + blocks.push(block.number.unwrap_or_default().as_u64()); + } + + assert_eq!(blocks, &[block_num + 1, block_num + 2, block_num + 3]) + } +} diff --git a/ethers-providers/src/transports/mod.rs b/ethers-providers/src/transports/mod.rs index 394b6fdc..9b8dab5f 100644 --- a/ethers-providers/src/transports/mod.rs +++ b/ethers-providers/src/transports/mod.rs @@ -8,5 +8,10 @@ mod ws; #[cfg(feature = "ws")] pub use ws::Ws; +#[cfg(feature = "ipc")] +mod ipc; +#[cfg(feature = "ipc")] +pub use ipc::Ipc; + mod mock; pub use mock::{MockError, MockProvider}; diff --git a/ethers/examples/ipc.rs b/ethers/examples/ipc.rs new file mode 100644 index 00000000..427f3a16 --- /dev/null +++ b/ethers/examples/ipc.rs @@ -0,0 +1,16 @@ +use ethers::prelude::*; +use std::time::Duration; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let ws = Ipc::new("~/.ethereum/geth.ipc").await?; + let provider = Provider::new(ws).interval(Duration::from_millis(2000)); + let block = provider.get_block_number().await?; + println!("Current block: {}", block); + let mut stream = provider.watch_blocks().await?.stream(); + while let Some(block) = stream.next().await { + dbg!(block); + } + + Ok(()) +}