refactor: make IPC generic over AsyncRead/Write (#264)

* refactor: make IPC generic over AsyncRead/Write

* chore(ipc): fix typo
This commit is contained in:
Georgios Konstantopoulos 2021-04-08 11:44:48 +03:00 committed by GitHub
parent 42b10cca9a
commit 66a503d294
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 206 additions and 139 deletions

1
Cargo.lock generated
View File

@ -907,6 +907,7 @@ version = "0.2.2"
dependencies = [
"async-trait",
"auto_impl",
"bytes",
"ethers",
"ethers-core",
"futures-channel",

View File

@ -40,6 +40,7 @@ tracing-futures = { version = "0.2.5", default-features = false, features = ["st
tokio = { version = "1.4", default-features = false, optional = true }
tokio-tungstenite = { version = "0.13.0", default-features = false, features = ["connect", "tls"], optional = true }
tokio-util = { version = "0.6.5", default-features = false, features = ["io"], optional = true }
bytes = { version = "1.0.1", default-features = false, optional = true }
[dev-dependencies]
ethers = { version = "0.2", path = "../ethers" }
@ -50,4 +51,4 @@ tempfile = "3.2.0"
default = ["ws", "ipc"]
celo = ["ethers-core/celo"]
ws = ["tokio", "tokio-tungstenite"]
ipc = ["tokio", "tokio/io-util", "tokio-util"]
ipc = ["tokio", "tokio/io-util", "tokio-util", "bytes"]

View File

@ -728,7 +728,7 @@ impl Provider<crate::Ws> {
impl Provider<crate::Ipc> {
/// Direct connection to an IPC socket.
pub async fn connect_ipc(path: impl AsRef<std::path::Path>) -> Result<Self, ProviderError> {
let ipc = crate::Ipc::new(path).await?;
let ipc = crate::Ipc::connect(path).await?;
Ok(Self::new(ipc))
}
}

View File

@ -7,7 +7,7 @@ use ethers_core::types::U256;
use async_trait::async_trait;
use futures_channel::mpsc;
use futures_util::stream::StreamExt;
use futures_util::stream::{Fuse, StreamExt};
use oneshot::error::RecvError;
use serde::{de::DeserializeOwned, Serialize};
use std::sync::atomic::Ordering;
@ -17,7 +17,11 @@ use std::{
sync::{atomic::AtomicU64, Arc},
};
use thiserror::Error;
use tokio::{io::AsyncWriteExt, net::UnixStream, sync::oneshot};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf},
net::UnixStream,
sync::oneshot,
};
use tokio_util::io::ReaderStream;
use tracing::{error, warn};
@ -28,24 +32,40 @@ pub struct Ipc {
messages_tx: mpsc::UnboundedSender<TransportMessage>,
}
#[cfg(unix)]
type Pending = oneshot::Sender<serde_json::Value>;
type Subscription = mpsc::UnboundedSender<serde_json::Value>;
#[derive(Debug)]
enum TransportMessage {
Request {
id: u64,
request: String,
sender: Pending,
},
Subscribe {
id: U256,
sink: Subscription,
},
Unsubscribe {
id: U256,
},
}
impl Ipc {
/// Creates a new IPC transport from a given path.
///
/// IPC is only available on Unix.
pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self, IpcError> {
let stream = UnixStream::connect(path).await?;
Ok(Self::with_stream(stream))
}
fn with_stream(stream: UnixStream) -> Self {
/// Creates a new IPC transport from a Async Reader / Writer
fn new<S: AsyncRead + AsyncWrite + Send + 'static>(stream: S) -> Self {
let id = Arc::new(AtomicU64::new(1));
let (messages_tx, messages_rx) = mpsc::unbounded();
tokio::spawn(run_server(stream, messages_rx));
IpcServer::new(stream, messages_rx).spawn();
Self { id, messages_tx }
}
Ipc { id, messages_tx }
/// Creates a new IPC transport from a given path using Unix sockets
#[cfg(unix)]
pub async fn connect<P: AsRef<Path>>(path: P) -> Result<Self, IpcError> {
let ipc = UnixStream::connect(path).await?;
Ok(Self::new(ipc))
}
fn send(&self, msg: TransportMessage) -> Result<(), IpcError> {
@ -104,142 +124,186 @@ impl PubsubClient for Ipc {
}
}
#[derive(Debug)]
enum TransportMessage {
Request {
id: u64,
request: String,
sender: oneshot::Sender<serde_json::Value>,
},
Subscribe {
id: U256,
sink: mpsc::UnboundedSender<serde_json::Value>,
},
Unsubscribe {
id: U256,
},
struct IpcServer<T> {
socket_reader: Fuse<ReaderStream<ReadHalf<T>>>,
socket_writer: WriteHalf<T>,
requests: Fuse<mpsc::UnboundedReceiver<TransportMessage>>,
pending: HashMap<u64, Pending>,
subscriptions: HashMap<U256, Subscription>,
}
#[cfg(unix)]
async fn run_server(
unix_stream: UnixStream,
messages_rx: mpsc::UnboundedReceiver<TransportMessage>,
) -> Result<(), IpcError> {
let (socket_reader, mut socket_writer) = unix_stream.into_split();
let mut pending_response_txs = HashMap::default();
let mut subscription_txs = HashMap::default();
impl<T> IpcServer<T>
where
T: AsyncRead + AsyncWrite,
{
/// Instantiates the Websocket Server
pub fn new(ipc: T, requests: mpsc::UnboundedReceiver<TransportMessage>) -> Self {
let (socket_reader, socket_writer) = tokio::io::split(ipc);
let socket_reader = ReaderStream::new(socket_reader).fuse();
Self {
socket_reader,
socket_writer,
requests: requests.fuse(),
pending: HashMap::default(),
subscriptions: HashMap::default(),
}
}
let mut socket_reader = ReaderStream::new(socket_reader);
let mut messages_rx = messages_rx.fuse();
let mut read_buffer = vec![];
let mut closed = false;
/// Spawns the event loop
fn spawn(mut self)
where
T: 'static + Send,
{
let f = async move {
let mut read_buffer = Vec::new();
loop {
let closed = self
.process(&mut read_buffer)
.await
.expect("WS Server panic");
if closed && self.pending.is_empty() {
break;
}
}
};
while !closed || !pending_response_txs.is_empty() {
tokio::select! {
message = messages_rx.next() => match message {
Some(TransportMessage::Subscribe{ id, sink }) => {
if subscription_txs.insert(id, sink).is_some() {
warn!("Replacing a subscription with id {:?}", id);
}
},
Some(TransportMessage::Unsubscribe{ id }) => {
if subscription_txs.remove(&id).is_none() {
warn!("Unsubscribing not subscribed id {:?}", id);
}
},
Some(TransportMessage::Request{ id, request, sender }) => {
if pending_response_txs.insert(id, sender).is_some() {
warn!("Replacing a pending request with id {:?}", id);
}
tokio::spawn(f);
}
if let Err(err) = socket_writer.write(&request.as_bytes()).await {
pending_response_txs.remove(&id);
error!("IPC write error: {:?}", err);
}
},
None => closed = true,
/// Processes 1 item selected from the incoming `requests` or `socket`
#[allow(clippy::single_match)]
async fn process(&mut self, read_buffer: &mut Vec<u8>) -> Result<bool, IpcError> {
futures_util::select! {
// Handle requests
msg = self.requests.next() => match msg {
Some(msg) => self.handle_request(msg).await?,
None => return Ok(true),
},
bytes = socket_reader.next() => match bytes {
Some(Ok(bytes)) => {
// Extend buffer of previously unread with the new read bytes
read_buffer.extend_from_slice(&bytes);
let read_len = {
// Deserialize as many full elements from the stream as exists
let mut de: serde_json::StreamDeserializer<_, serde_json::Value> =
serde_json::Deserializer::from_slice(&read_buffer).into_iter();
// Iterate through these elements, and handle responses/notifications
while let Some(Ok(value)) = de.next() {
if let Ok(notification) = serde_json::from_value::<Notification<serde_json::Value>>(value.clone()) {
// Send notify response if okay.
if let Err(e) = notify(&mut subscription_txs, notification) {
error!("Failed to send IPC notification: {}", e)
}
} else if let Ok(response) = serde_json::from_value::<Response<serde_json::Value>>(value) {
if let Err(e) = respond(&mut pending_response_txs, response) {
error!("Failed to send IPC response: {}", e)
}
} else {
warn!("JSON from IPC stream is not a response or notification");
}
}
// Get the offset of bytes to handle partial buffer reads
de.byte_offset()
};
// Reset buffer to just include the partial value bytes.
read_buffer.copy_within(read_len.., 0);
read_buffer.truncate(read_buffer.len() - read_len);
},
// Handle socket messages
msg = self.socket_reader.next() => match msg {
Some(Ok(msg)) => self.handle_socket(read_buffer, msg).await?,
Some(Err(err)) => {
error!("IPC read error: {:?}", err);
return Err(err.into());
},
None => break,
None => {},
},
// finished
complete => {},
};
Ok(false)
}
async fn handle_request(&mut self, msg: TransportMessage) -> Result<(), IpcError> {
match msg {
TransportMessage::Request {
id,
request,
sender,
} => {
if self.pending.insert(id, sender).is_some() {
warn!("Replacing a pending request with id {:?}", id);
}
if let Err(err) = self.socket_writer.write(&request.as_bytes()).await {
error!("WS connection error: {:?}", err);
self.pending.remove(&id);
}
}
TransportMessage::Subscribe { id, sink } => {
if self.subscriptions.insert(id, sink).is_some() {
warn!("Replacing already-registered subscription with id {:?}", id);
}
}
TransportMessage::Unsubscribe { id } => {
if self.subscriptions.remove(&id).is_none() {
warn!(
"Unsubscribing from non-existent subscription with id {:?}",
id
);
}
}
};
Ok(())
}
Ok(())
}
async fn handle_socket(
&mut self,
read_buffer: &mut Vec<u8>,
bytes: bytes::Bytes,
) -> Result<(), IpcError> {
// Extend buffer of previously unread with the new read bytes
read_buffer.extend_from_slice(&bytes);
/// Sends notification through the channel based on the ID of the subscription.
/// This handles streaming responses.
fn notify(
subscription_txs: &mut HashMap<U256, mpsc::UnboundedSender<serde_json::Value>>,
notification: Notification<serde_json::Value>,
) -> Result<(), IpcError> {
let id = notification.params.subscription;
if let Some(tx) = subscription_txs.get(&id) {
tx.unbounded_send(notification.params.result)
.map_err(|_| IpcError::ChannelError(format!("Subscription receiver {} dropped", id)))?;
let read_len = {
// Deserialize as many full elements from the stream as exists
let mut de: serde_json::StreamDeserializer<_, serde_json::Value> =
serde_json::Deserializer::from_slice(&read_buffer).into_iter();
// Iterate through these elements, and handle responses/notifications
while let Some(Ok(value)) = de.next() {
if let Ok(notification) =
serde_json::from_value::<Notification<serde_json::Value>>(value.clone())
{
// Send notify response if okay.
if let Err(e) = self.notify(notification) {
error!("Failed to send IPC notification: {}", e)
}
} else if let Ok(response) =
serde_json::from_value::<Response<serde_json::Value>>(value)
{
if let Err(e) = self.respond(response) {
error!("Failed to send IPC response: {}", e)
}
} else {
warn!("JSON from IPC stream is not a response or notification");
}
}
// Get the offset of bytes to handle partial buffer reads
de.byte_offset()
};
// Reset buffer to just include the partial value bytes.
read_buffer.copy_within(read_len.., 0);
read_buffer.truncate(read_buffer.len() - read_len);
Ok(())
}
Ok(())
}
/// Sends notification through the channel based on the ID of the subscription.
/// This handles streaming responses.
fn notify(&mut self, notification: Notification<serde_json::Value>) -> Result<(), IpcError> {
let id = notification.params.subscription;
if let Some(tx) = self.subscriptions.get(&id) {
tx.unbounded_send(notification.params.result).map_err(|_| {
IpcError::ChannelError(format!("Subscription receiver {} dropped", id))
})?;
}
/// Sends JSON response through the channel based on the ID in that response.
/// This handles RPC calls with only one response, and the channel entry is dropped after sending.
fn respond(
pending_response_txs: &mut HashMap<u64, oneshot::Sender<serde_json::Value>>,
output: Response<serde_json::Value>,
) -> Result<(), IpcError> {
let id = output.id;
Ok(())
}
// Converts output into result, to send data if valid response.
let value = output.data.into_result()?;
/// Sends JSON response through the channel based on the ID in that response.
/// This handles RPC calls with only one response, and the channel entry is dropped after sending.
fn respond(&mut self, output: Response<serde_json::Value>) -> Result<(), IpcError> {
let id = output.id;
let response_tx = pending_response_txs.remove(&id).ok_or_else(|| {
IpcError::ChannelError("No response channel exists for the response ID".to_string())
})?;
// Converts output into result, to send data if valid response.
let value = output.data.into_result()?;
response_tx.send(value).map_err(|_| {
IpcError::ChannelError("Receiver channel for response has been dropped".to_string())
})?;
let response_tx = self.pending.remove(&id).ok_or_else(|| {
IpcError::ChannelError("No response channel exists for the response ID".to_string())
})?;
Ok(())
response_tx.send(value).map_err(|_| {
IpcError::ChannelError("Receiver channel for response has been dropped".to_string())
})?;
Ok(())
}
}
#[derive(Error, Debug)]
@ -283,7 +347,7 @@ mod test {
let temp_file = NamedTempFile::new().unwrap();
let path = temp_file.into_temp_path().to_path_buf();
let _geth = Geth::new().block_time(1u64).ipc_path(&path).spawn();
let ipc = Ipc::new(path).await.unwrap();
let ipc = Ipc::connect(path).await.unwrap();
let block_num: U256 = ipc.request("eth_blockNumber", ()).await.unwrap();
std::thread::sleep(std::time::Duration::new(3, 0));
@ -295,8 +359,11 @@ mod test {
async fn subscription() {
let temp_file = NamedTempFile::new().unwrap();
let path = temp_file.into_temp_path().to_path_buf();
let _geth = Geth::new().block_time(1u64).ipc_path(&path).spawn();
let ipc = Ipc::new(path).await.unwrap();
let _geth = Geth::new().block_time(2u64).ipc_path(&path).spawn();
let ipc = Ipc::connect(path).await.unwrap();
let sub_id: U256 = ipc.request("eth_subscribe", ["newHeads"]).await.unwrap();
let mut stream = ipc.subscribe(sub_id).unwrap();
// Subscribing requires sending the sub request and then subscribing to
// the returned sub_id
@ -305,9 +372,6 @@ mod test {
.await
.unwrap()
.as_u64();
let sub_id: U256 = ipc.request("eth_subscribe", ["newHeads"]).await.unwrap();
let mut stream = ipc.subscribe(sub_id).unwrap();
let mut blocks = Vec::new();
for _ in 0..3 {
let item = stream.next().await.unwrap();

View File

@ -3,8 +3,9 @@ use std::time::Duration;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let ws = Ipc::new("~/.ethereum/geth.ipc").await?;
let provider = Provider::new(ws).interval(Duration::from_millis(2000));
let provider = Provider::connect_ipc("~/.ethereum/geth.ipc")
.await?
.interval(Duration::from_millis(2000));
let block = provider.get_block_number().await?;
println!("Current block: {}", block);
let mut stream = provider.watch_blocks().await?.stream();