From e7904293550e7d3fe7afc4514be69f0b5ed78f12 Mon Sep 17 00:00:00 2001 From: James Prestwich <10149425+prestwich@users.noreply.github.com> Date: Fri, 20 Aug 2021 10:23:39 -0700 Subject: [PATCH] Refactor WS handling code (#397) * refactor: simplify handle_text message * refactor: rename 'TransportMessage' to 'Instruction' * nit: update docstring on tick method * refactor: rename dispatch to service * nit: rename WsServer.requests to WsServer.instructions for consistency * refactor: simplify handle_text, remove unwrap * nit: fix misspelled comment --- ethers-providers/src/pubsub.rs | 2 +- ethers-providers/src/transports/ws.rs | 226 ++++++++++++++------------ 2 files changed, 127 insertions(+), 101 deletions(-) diff --git a/ethers-providers/src/pubsub.rs b/ethers-providers/src/pubsub.rs index 668704cc..6bd6bc44 100644 --- a/ethers-providers/src/pubsub.rs +++ b/ethers-providers/src/pubsub.rs @@ -48,7 +48,7 @@ where /// /// ### Note /// Most providers treat `SubscriptionStream` IDs as global singletons. - /// Instanitating this directly with a known ID will likely cause any + /// Instantiating this directly with a known ID will likely cause any /// existing streams with that ID to end. To avoid this, start a new stream /// using [`Provider::subscribe`] instead of `SubscriptionStream::new`. pub fn new(id: U256, provider: &'a Provider

) -> Result { diff --git a/ethers-providers/src/transports/ws.rs b/ethers-providers/src/transports/ws.rs index 38b9c7aa..5d029aa1 100644 --- a/ethers-providers/src/transports/ws.rs +++ b/ethers-providers/src/transports/ws.rs @@ -30,6 +30,30 @@ use tokio_tungstenite::{ }; use tracing::{error, warn}; +type Pending = oneshot::Sender>; +type Subscription = mpsc::UnboundedSender; + +/// Instructions for the `WsServer`. +enum Instruction { + /// JSON-RPC request + Request { + id: u64, + request: String, + sender: Pending, + }, + /// Create a new subscription + Subscribe { id: U256, sink: Subscription }, + /// Cancel an existing subscription + Unsubscribe { id: U256 }, +} + +#[derive(Debug, serde::Deserialize)] +#[serde(untagged)] +enum Incoming { + Notification(Notification), + Response(Response), +} + /// A JSON-RPC Client over Websockets. /// /// ```no_run @@ -43,25 +67,7 @@ use tracing::{error, warn}; #[derive(Clone)] pub struct Ws { id: Arc, - requests: mpsc::UnboundedSender, -} - -type Pending = oneshot::Sender>; -type Subscription = mpsc::UnboundedSender; - -enum TransportMessage { - Request { - id: u64, - request: String, - sender: Pending, - }, - Subscribe { - id: U256, - sink: Subscription, - }, - Unsubscribe { - id: U256, - }, + instructions: mpsc::UnboundedSender, } impl Debug for Ws { @@ -90,13 +96,13 @@ impl Ws { Self { id: Arc::new(AtomicU64::new(0)), - requests: sink, + instructions: sink, } } /// Returns true if the WS connection is active, false otherwise pub fn ready(&self) -> bool { - !self.requests.is_closed() + !self.instructions.is_closed() } /// Initializes a new WebSocket Client @@ -107,8 +113,10 @@ impl Ws { Ok(Self::new(ws)) } - fn send(&self, msg: TransportMessage) -> Result<(), ClientError> { - self.requests.unbounded_send(msg).map_err(to_client_error) + fn send(&self, msg: Instruction) -> Result<(), ClientError> { + self.instructions + .unbounded_send(msg) + .map_err(to_client_error) } } @@ -125,14 +133,14 @@ impl JsonRpcClient for Ws { // send the message let (sender, receiver) = oneshot::channel(); - let payload = TransportMessage::Request { + let payload = Instruction::Request { id: next_id, request: serde_json::to_string(&Request::new(next_id, method, params))?, sender, }; // send the data - self.send(payload).map_err(to_client_error)?; + self.send(payload)?; // wait for the response let res = receiver.await?; @@ -150,7 +158,7 @@ impl PubsubClient for Ws { fn subscribe>(&self, id: T) -> Result { let (sink, stream) = mpsc::unbounded(); - self.send(TransportMessage::Subscribe { + self.send(Instruction::Subscribe { id: id.into(), sink, })?; @@ -158,13 +166,13 @@ impl PubsubClient for Ws { } fn unsubscribe>(&self, id: T) -> Result<(), ClientError> { - self.send(TransportMessage::Unsubscribe { id: id.into() }) + self.send(Instruction::Unsubscribe { id: id.into() }) } } struct WsServer { ws: Fuse, - requests: Fuse>, + instructions: Fuse>, pending: BTreeMap, subscriptions: BTreeMap, @@ -179,12 +187,12 @@ where + Unpin, { /// Instantiates the Websocket Server - fn new(ws: S, requests: mpsc::UnboundedReceiver) -> Self { + fn new(ws: S, requests: mpsc::UnboundedReceiver) -> Self { Self { // Fuse the 2 steams together, so that we can `select` them in the // Stream implementation ws: ws.fuse(), - requests: requests.fuse(), + instructions: requests.fuse(), pending: BTreeMap::default(), subscriptions: BTreeMap::default(), } @@ -197,13 +205,13 @@ where { let f = async move { loop { - match self.process().await { + match self.tick().await { Err(ClientError::UnexpectedClose) => { tracing::error!("{}", ClientError::UnexpectedClose); break; } - Err(_) => { - panic!("WS Server panic"); + Err(e) => { + panic!("WS Server panic: {}", e); } _ => {} } @@ -213,63 +221,84 @@ where tokio::spawn(f); } - /// Processes 1 item selected from the incoming `requests` or `ws` - #[allow(clippy::single_match)] - async fn process(&mut self) -> Result<(), ClientError> { - futures_util::select! { - // Handle requests - msg = self.requests.select_next_some() => { - self.handle_request(msg).await?; - }, - // Handle ws messages - msg = self.ws.next() => match msg { - Some(Ok(msg)) => self.handle_ws(msg).await?, - // TODO: Log the error? - Some(Err(_)) => {}, - None => { - return Err(ClientError::UnexpectedClose); - }, - } - }; + // dispatch an RPC request + async fn service_request( + &mut self, + id: u64, + request: String, + sender: Pending, + ) -> Result<(), ClientError> { + if self.pending.insert(id, sender).is_some() { + warn!("Replacing a pending request with id {:?}", id); + } + if let Err(e) = self.ws.send(Message::Text(request)).await { + error!("WS connection error: {:?}", e); + self.pending.remove(&id); + } Ok(()) } - async fn handle_request(&mut self, msg: TransportMessage) -> Result<(), ClientError> { - match msg { - TransportMessage::Request { + /// Dispatch a subscription request + async fn service_subscribe(&mut self, id: U256, sink: Subscription) -> Result<(), ClientError> { + if self.subscriptions.insert(id, sink).is_some() { + warn!("Replacing already-registered subscription with id {:?}", id); + } + Ok(()) + } + + /// Dispatch a unsubscribe request + async fn service_unsubscribe(&mut self, id: U256) -> Result<(), ClientError> { + if self.subscriptions.remove(&id).is_none() { + warn!( + "Unsubscribing from non-existent subscription with id {:?}", + id + ); + } + Ok(()) + } + + /// Dispatch an outgoing message + async fn service(&mut self, instruction: Instruction) -> Result<(), ClientError> { + match instruction { + Instruction::Request { id, request, sender, - } => { - if self.pending.insert(id, sender).is_some() { - warn!("Replacing a pending request with id {:?}", id); - } - - if let Err(e) = self.ws.send(Message::Text(request)).await { - error!("WS connection error: {:?}", e); - self.pending.remove(&id); - } - } - TransportMessage::Subscribe { id, sink } => { - if self.subscriptions.insert(id, sink).is_some() { - warn!("Replacing already-registered subscription with id {:?}", id); - } - } - TransportMessage::Unsubscribe { id } => { - if self.subscriptions.remove(&id).is_none() { - warn!( - "Unsubscribing from non-existent subscription with id {:?}", - id - ); - } - } - }; + } => self.service_request(id, request, sender).await, + Instruction::Subscribe { id, sink } => self.service_subscribe(id, sink).await, + Instruction::Unsubscribe { id } => self.service_unsubscribe(id).await, + } + } + async fn handle_ping(&mut self, inner: Vec) -> Result<(), ClientError> { + self.ws.send(Message::Pong(inner)).await?; Ok(()) } - async fn handle_ws(&mut self, resp: Message) -> Result<(), ClientError> { + async fn handle_text(&mut self, inner: String) -> Result<(), ClientError> { + match serde_json::from_str::(&inner) { + Err(_) => {} + Ok(Incoming::Response(resp)) => { + if let Some(request) = self.pending.remove(&resp.id) { + request + .send(resp.data.into_result()) + .map_err(to_client_error)?; + } + } + Ok(Incoming::Notification(notification)) => { + let id = notification.params.subscription; + if let Some(stream) = self.subscriptions.get(&id) { + stream + .unbounded_send(notification.params.result) + .map_err(to_client_error)?; + } + } + } + Ok(()) + } + + async fn handle(&mut self, resp: Message) -> Result<(), ClientError> { match resp { Message::Text(inner) => self.handle_text(inner).await, Message::Ping(inner) => self.handle_ping(inner).await, @@ -280,28 +309,25 @@ where } } - async fn handle_ping(&mut self, inner: Vec) -> Result<(), ClientError> { - self.ws.send(Message::Pong(inner)).await?; - Ok(()) - } + /// Processes 1 instruction or 1 incoming websocket message + #[allow(clippy::single_match)] + async fn tick(&mut self) -> Result<(), ClientError> { + futures_util::select! { + // Handle requests + instruction = self.instructions.select_next_some() => { + self.service(instruction).await?; + }, + // Handle ws messages + resp = self.ws.next() => match resp { + Some(Ok(resp)) => self.handle(resp).await?, + // TODO: Log the error? + Some(Err(_)) => {}, + None => { + return Err(ClientError::UnexpectedClose); + }, + } + }; - async fn handle_text(&mut self, inner: String) -> Result<(), ClientError> { - if let Ok(resp) = serde_json::from_str::>(&inner) { - if let Some(request) = self.pending.remove(&resp.id) { - request - .send(resp.data.into_result()) - .map_err(to_client_error)?; - } - } else if let Ok(notification) = - serde_json::from_str::>(&inner) - { - let id = notification.params.subscription; - if let Some(stream) = self.subscriptions.get(&id) { - stream - .unbounded_send(notification.params.result) - .map_err(to_client_error)?; - } - } Ok(()) } }