refactor: make IPC generic over AsyncRead/Write (#264)
* refactor: make IPC generic over AsyncRead/Write * chore(ipc): fix typo
This commit is contained in:
parent
42b10cca9a
commit
66a503d294
|
@ -907,6 +907,7 @@ version = "0.2.2"
|
|||
dependencies = [
|
||||
"async-trait",
|
||||
"auto_impl",
|
||||
"bytes",
|
||||
"ethers",
|
||||
"ethers-core",
|
||||
"futures-channel",
|
||||
|
|
|
@ -40,6 +40,7 @@ tracing-futures = { version = "0.2.5", default-features = false, features = ["st
|
|||
tokio = { version = "1.4", default-features = false, optional = true }
|
||||
tokio-tungstenite = { version = "0.13.0", default-features = false, features = ["connect", "tls"], optional = true }
|
||||
tokio-util = { version = "0.6.5", default-features = false, features = ["io"], optional = true }
|
||||
bytes = { version = "1.0.1", default-features = false, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
ethers = { version = "0.2", path = "../ethers" }
|
||||
|
@ -50,4 +51,4 @@ tempfile = "3.2.0"
|
|||
default = ["ws", "ipc"]
|
||||
celo = ["ethers-core/celo"]
|
||||
ws = ["tokio", "tokio-tungstenite"]
|
||||
ipc = ["tokio", "tokio/io-util", "tokio-util"]
|
||||
ipc = ["tokio", "tokio/io-util", "tokio-util", "bytes"]
|
||||
|
|
|
@ -728,7 +728,7 @@ impl Provider<crate::Ws> {
|
|||
impl Provider<crate::Ipc> {
|
||||
/// Direct connection to an IPC socket.
|
||||
pub async fn connect_ipc(path: impl AsRef<std::path::Path>) -> Result<Self, ProviderError> {
|
||||
let ipc = crate::Ipc::new(path).await?;
|
||||
let ipc = crate::Ipc::connect(path).await?;
|
||||
Ok(Self::new(ipc))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ use ethers_core::types::U256;
|
|||
|
||||
use async_trait::async_trait;
|
||||
use futures_channel::mpsc;
|
||||
use futures_util::stream::StreamExt;
|
||||
use futures_util::stream::{Fuse, StreamExt};
|
||||
use oneshot::error::RecvError;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::sync::atomic::Ordering;
|
||||
|
@ -17,7 +17,11 @@ use std::{
|
|||
sync::{atomic::AtomicU64, Arc},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::{io::AsyncWriteExt, net::UnixStream, sync::oneshot};
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf},
|
||||
net::UnixStream,
|
||||
sync::oneshot,
|
||||
};
|
||||
use tokio_util::io::ReaderStream;
|
||||
use tracing::{error, warn};
|
||||
|
||||
|
@ -28,24 +32,40 @@ pub struct Ipc {
|
|||
messages_tx: mpsc::UnboundedSender<TransportMessage>,
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
type Pending = oneshot::Sender<serde_json::Value>;
|
||||
type Subscription = mpsc::UnboundedSender<serde_json::Value>;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TransportMessage {
|
||||
Request {
|
||||
id: u64,
|
||||
request: String,
|
||||
sender: Pending,
|
||||
},
|
||||
Subscribe {
|
||||
id: U256,
|
||||
sink: Subscription,
|
||||
},
|
||||
Unsubscribe {
|
||||
id: U256,
|
||||
},
|
||||
}
|
||||
|
||||
impl Ipc {
|
||||
/// Creates a new IPC transport from a given path.
|
||||
///
|
||||
/// IPC is only available on Unix.
|
||||
pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self, IpcError> {
|
||||
let stream = UnixStream::connect(path).await?;
|
||||
|
||||
Ok(Self::with_stream(stream))
|
||||
}
|
||||
|
||||
fn with_stream(stream: UnixStream) -> Self {
|
||||
/// Creates a new IPC transport from a Async Reader / Writer
|
||||
fn new<S: AsyncRead + AsyncWrite + Send + 'static>(stream: S) -> Self {
|
||||
let id = Arc::new(AtomicU64::new(1));
|
||||
let (messages_tx, messages_rx) = mpsc::unbounded();
|
||||
|
||||
tokio::spawn(run_server(stream, messages_rx));
|
||||
IpcServer::new(stream, messages_rx).spawn();
|
||||
Self { id, messages_tx }
|
||||
}
|
||||
|
||||
Ipc { id, messages_tx }
|
||||
/// Creates a new IPC transport from a given path using Unix sockets
|
||||
#[cfg(unix)]
|
||||
pub async fn connect<P: AsRef<Path>>(path: P) -> Result<Self, IpcError> {
|
||||
let ipc = UnixStream::connect(path).await?;
|
||||
Ok(Self::new(ipc))
|
||||
}
|
||||
|
||||
fn send(&self, msg: TransportMessage) -> Result<(), IpcError> {
|
||||
|
@ -104,142 +124,186 @@ impl PubsubClient for Ipc {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TransportMessage {
|
||||
Request {
|
||||
id: u64,
|
||||
request: String,
|
||||
sender: oneshot::Sender<serde_json::Value>,
|
||||
},
|
||||
Subscribe {
|
||||
id: U256,
|
||||
sink: mpsc::UnboundedSender<serde_json::Value>,
|
||||
},
|
||||
Unsubscribe {
|
||||
id: U256,
|
||||
},
|
||||
struct IpcServer<T> {
|
||||
socket_reader: Fuse<ReaderStream<ReadHalf<T>>>,
|
||||
socket_writer: WriteHalf<T>,
|
||||
requests: Fuse<mpsc::UnboundedReceiver<TransportMessage>>,
|
||||
pending: HashMap<u64, Pending>,
|
||||
subscriptions: HashMap<U256, Subscription>,
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn run_server(
|
||||
unix_stream: UnixStream,
|
||||
messages_rx: mpsc::UnboundedReceiver<TransportMessage>,
|
||||
) -> Result<(), IpcError> {
|
||||
let (socket_reader, mut socket_writer) = unix_stream.into_split();
|
||||
let mut pending_response_txs = HashMap::default();
|
||||
let mut subscription_txs = HashMap::default();
|
||||
impl<T> IpcServer<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
{
|
||||
/// Instantiates the Websocket Server
|
||||
pub fn new(ipc: T, requests: mpsc::UnboundedReceiver<TransportMessage>) -> Self {
|
||||
let (socket_reader, socket_writer) = tokio::io::split(ipc);
|
||||
let socket_reader = ReaderStream::new(socket_reader).fuse();
|
||||
Self {
|
||||
socket_reader,
|
||||
socket_writer,
|
||||
requests: requests.fuse(),
|
||||
pending: HashMap::default(),
|
||||
subscriptions: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
let mut socket_reader = ReaderStream::new(socket_reader);
|
||||
let mut messages_rx = messages_rx.fuse();
|
||||
let mut read_buffer = vec![];
|
||||
let mut closed = false;
|
||||
/// Spawns the event loop
|
||||
fn spawn(mut self)
|
||||
where
|
||||
T: 'static + Send,
|
||||
{
|
||||
let f = async move {
|
||||
let mut read_buffer = Vec::new();
|
||||
loop {
|
||||
let closed = self
|
||||
.process(&mut read_buffer)
|
||||
.await
|
||||
.expect("WS Server panic");
|
||||
if closed && self.pending.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
while !closed || !pending_response_txs.is_empty() {
|
||||
tokio::select! {
|
||||
message = messages_rx.next() => match message {
|
||||
Some(TransportMessage::Subscribe{ id, sink }) => {
|
||||
if subscription_txs.insert(id, sink).is_some() {
|
||||
warn!("Replacing a subscription with id {:?}", id);
|
||||
}
|
||||
},
|
||||
Some(TransportMessage::Unsubscribe{ id }) => {
|
||||
if subscription_txs.remove(&id).is_none() {
|
||||
warn!("Unsubscribing not subscribed id {:?}", id);
|
||||
}
|
||||
},
|
||||
Some(TransportMessage::Request{ id, request, sender }) => {
|
||||
if pending_response_txs.insert(id, sender).is_some() {
|
||||
warn!("Replacing a pending request with id {:?}", id);
|
||||
}
|
||||
tokio::spawn(f);
|
||||
}
|
||||
|
||||
if let Err(err) = socket_writer.write(&request.as_bytes()).await {
|
||||
pending_response_txs.remove(&id);
|
||||
error!("IPC write error: {:?}", err);
|
||||
}
|
||||
},
|
||||
None => closed = true,
|
||||
/// Processes 1 item selected from the incoming `requests` or `socket`
|
||||
#[allow(clippy::single_match)]
|
||||
async fn process(&mut self, read_buffer: &mut Vec<u8>) -> Result<bool, IpcError> {
|
||||
futures_util::select! {
|
||||
// Handle requests
|
||||
msg = self.requests.next() => match msg {
|
||||
Some(msg) => self.handle_request(msg).await?,
|
||||
None => return Ok(true),
|
||||
},
|
||||
bytes = socket_reader.next() => match bytes {
|
||||
Some(Ok(bytes)) => {
|
||||
// Extend buffer of previously unread with the new read bytes
|
||||
read_buffer.extend_from_slice(&bytes);
|
||||
|
||||
let read_len = {
|
||||
// Deserialize as many full elements from the stream as exists
|
||||
let mut de: serde_json::StreamDeserializer<_, serde_json::Value> =
|
||||
serde_json::Deserializer::from_slice(&read_buffer).into_iter();
|
||||
|
||||
// Iterate through these elements, and handle responses/notifications
|
||||
while let Some(Ok(value)) = de.next() {
|
||||
if let Ok(notification) = serde_json::from_value::<Notification<serde_json::Value>>(value.clone()) {
|
||||
// Send notify response if okay.
|
||||
if let Err(e) = notify(&mut subscription_txs, notification) {
|
||||
error!("Failed to send IPC notification: {}", e)
|
||||
}
|
||||
} else if let Ok(response) = serde_json::from_value::<Response<serde_json::Value>>(value) {
|
||||
if let Err(e) = respond(&mut pending_response_txs, response) {
|
||||
error!("Failed to send IPC response: {}", e)
|
||||
}
|
||||
} else {
|
||||
warn!("JSON from IPC stream is not a response or notification");
|
||||
}
|
||||
}
|
||||
|
||||
// Get the offset of bytes to handle partial buffer reads
|
||||
de.byte_offset()
|
||||
};
|
||||
|
||||
// Reset buffer to just include the partial value bytes.
|
||||
read_buffer.copy_within(read_len.., 0);
|
||||
read_buffer.truncate(read_buffer.len() - read_len);
|
||||
},
|
||||
// Handle socket messages
|
||||
msg = self.socket_reader.next() => match msg {
|
||||
Some(Ok(msg)) => self.handle_socket(read_buffer, msg).await?,
|
||||
Some(Err(err)) => {
|
||||
error!("IPC read error: {:?}", err);
|
||||
return Err(err.into());
|
||||
},
|
||||
None => break,
|
||||
None => {},
|
||||
},
|
||||
// finished
|
||||
complete => {},
|
||||
};
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn handle_request(&mut self, msg: TransportMessage) -> Result<(), IpcError> {
|
||||
match msg {
|
||||
TransportMessage::Request {
|
||||
id,
|
||||
request,
|
||||
sender,
|
||||
} => {
|
||||
if self.pending.insert(id, sender).is_some() {
|
||||
warn!("Replacing a pending request with id {:?}", id);
|
||||
}
|
||||
|
||||
if let Err(err) = self.socket_writer.write(&request.as_bytes()).await {
|
||||
error!("WS connection error: {:?}", err);
|
||||
self.pending.remove(&id);
|
||||
}
|
||||
}
|
||||
TransportMessage::Subscribe { id, sink } => {
|
||||
if self.subscriptions.insert(id, sink).is_some() {
|
||||
warn!("Replacing already-registered subscription with id {:?}", id);
|
||||
}
|
||||
}
|
||||
TransportMessage::Unsubscribe { id } => {
|
||||
if self.subscriptions.remove(&id).is_none() {
|
||||
warn!(
|
||||
"Unsubscribing from non-existent subscription with id {:?}",
|
||||
id
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
async fn handle_socket(
|
||||
&mut self,
|
||||
read_buffer: &mut Vec<u8>,
|
||||
bytes: bytes::Bytes,
|
||||
) -> Result<(), IpcError> {
|
||||
// Extend buffer of previously unread with the new read bytes
|
||||
read_buffer.extend_from_slice(&bytes);
|
||||
|
||||
/// Sends notification through the channel based on the ID of the subscription.
|
||||
/// This handles streaming responses.
|
||||
fn notify(
|
||||
subscription_txs: &mut HashMap<U256, mpsc::UnboundedSender<serde_json::Value>>,
|
||||
notification: Notification<serde_json::Value>,
|
||||
) -> Result<(), IpcError> {
|
||||
let id = notification.params.subscription;
|
||||
if let Some(tx) = subscription_txs.get(&id) {
|
||||
tx.unbounded_send(notification.params.result)
|
||||
.map_err(|_| IpcError::ChannelError(format!("Subscription receiver {} dropped", id)))?;
|
||||
let read_len = {
|
||||
// Deserialize as many full elements from the stream as exists
|
||||
let mut de: serde_json::StreamDeserializer<_, serde_json::Value> =
|
||||
serde_json::Deserializer::from_slice(&read_buffer).into_iter();
|
||||
|
||||
// Iterate through these elements, and handle responses/notifications
|
||||
while let Some(Ok(value)) = de.next() {
|
||||
if let Ok(notification) =
|
||||
serde_json::from_value::<Notification<serde_json::Value>>(value.clone())
|
||||
{
|
||||
// Send notify response if okay.
|
||||
if let Err(e) = self.notify(notification) {
|
||||
error!("Failed to send IPC notification: {}", e)
|
||||
}
|
||||
} else if let Ok(response) =
|
||||
serde_json::from_value::<Response<serde_json::Value>>(value)
|
||||
{
|
||||
if let Err(e) = self.respond(response) {
|
||||
error!("Failed to send IPC response: {}", e)
|
||||
}
|
||||
} else {
|
||||
warn!("JSON from IPC stream is not a response or notification");
|
||||
}
|
||||
}
|
||||
|
||||
// Get the offset of bytes to handle partial buffer reads
|
||||
de.byte_offset()
|
||||
};
|
||||
|
||||
// Reset buffer to just include the partial value bytes.
|
||||
read_buffer.copy_within(read_len.., 0);
|
||||
read_buffer.truncate(read_buffer.len() - read_len);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
/// Sends notification through the channel based on the ID of the subscription.
|
||||
/// This handles streaming responses.
|
||||
fn notify(&mut self, notification: Notification<serde_json::Value>) -> Result<(), IpcError> {
|
||||
let id = notification.params.subscription;
|
||||
if let Some(tx) = self.subscriptions.get(&id) {
|
||||
tx.unbounded_send(notification.params.result).map_err(|_| {
|
||||
IpcError::ChannelError(format!("Subscription receiver {} dropped", id))
|
||||
})?;
|
||||
}
|
||||
|
||||
/// Sends JSON response through the channel based on the ID in that response.
|
||||
/// This handles RPC calls with only one response, and the channel entry is dropped after sending.
|
||||
fn respond(
|
||||
pending_response_txs: &mut HashMap<u64, oneshot::Sender<serde_json::Value>>,
|
||||
output: Response<serde_json::Value>,
|
||||
) -> Result<(), IpcError> {
|
||||
let id = output.id;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Converts output into result, to send data if valid response.
|
||||
let value = output.data.into_result()?;
|
||||
/// Sends JSON response through the channel based on the ID in that response.
|
||||
/// This handles RPC calls with only one response, and the channel entry is dropped after sending.
|
||||
fn respond(&mut self, output: Response<serde_json::Value>) -> Result<(), IpcError> {
|
||||
let id = output.id;
|
||||
|
||||
let response_tx = pending_response_txs.remove(&id).ok_or_else(|| {
|
||||
IpcError::ChannelError("No response channel exists for the response ID".to_string())
|
||||
})?;
|
||||
// Converts output into result, to send data if valid response.
|
||||
let value = output.data.into_result()?;
|
||||
|
||||
response_tx.send(value).map_err(|_| {
|
||||
IpcError::ChannelError("Receiver channel for response has been dropped".to_string())
|
||||
})?;
|
||||
let response_tx = self.pending.remove(&id).ok_or_else(|| {
|
||||
IpcError::ChannelError("No response channel exists for the response ID".to_string())
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
response_tx.send(value).map_err(|_| {
|
||||
IpcError::ChannelError("Receiver channel for response has been dropped".to_string())
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
|
@ -283,7 +347,7 @@ mod test {
|
|||
let temp_file = NamedTempFile::new().unwrap();
|
||||
let path = temp_file.into_temp_path().to_path_buf();
|
||||
let _geth = Geth::new().block_time(1u64).ipc_path(&path).spawn();
|
||||
let ipc = Ipc::new(path).await.unwrap();
|
||||
let ipc = Ipc::connect(path).await.unwrap();
|
||||
|
||||
let block_num: U256 = ipc.request("eth_blockNumber", ()).await.unwrap();
|
||||
std::thread::sleep(std::time::Duration::new(3, 0));
|
||||
|
@ -295,8 +359,11 @@ mod test {
|
|||
async fn subscription() {
|
||||
let temp_file = NamedTempFile::new().unwrap();
|
||||
let path = temp_file.into_temp_path().to_path_buf();
|
||||
let _geth = Geth::new().block_time(1u64).ipc_path(&path).spawn();
|
||||
let ipc = Ipc::new(path).await.unwrap();
|
||||
let _geth = Geth::new().block_time(2u64).ipc_path(&path).spawn();
|
||||
let ipc = Ipc::connect(path).await.unwrap();
|
||||
|
||||
let sub_id: U256 = ipc.request("eth_subscribe", ["newHeads"]).await.unwrap();
|
||||
let mut stream = ipc.subscribe(sub_id).unwrap();
|
||||
|
||||
// Subscribing requires sending the sub request and then subscribing to
|
||||
// the returned sub_id
|
||||
|
@ -305,9 +372,6 @@ mod test {
|
|||
.await
|
||||
.unwrap()
|
||||
.as_u64();
|
||||
let sub_id: U256 = ipc.request("eth_subscribe", ["newHeads"]).await.unwrap();
|
||||
let mut stream = ipc.subscribe(sub_id).unwrap();
|
||||
|
||||
let mut blocks = Vec::new();
|
||||
for _ in 0..3 {
|
||||
let item = stream.next().await.unwrap();
|
||||
|
|
|
@ -3,8 +3,9 @@ use std::time::Duration;
|
|||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let ws = Ipc::new("~/.ethereum/geth.ipc").await?;
|
||||
let provider = Provider::new(ws).interval(Duration::from_millis(2000));
|
||||
let provider = Provider::connect_ipc("~/.ethereum/geth.ipc")
|
||||
.await?
|
||||
.interval(Duration::from_millis(2000));
|
||||
let block = provider.get_block_number().await?;
|
||||
println!("Current block: {}", block);
|
||||
let mut stream = provider.watch_blocks().await?.stream();
|
||||
|
|
Loading…
Reference in New Issue