Refactor WS handling code (#397)

* refactor: simplify handle_text message

* refactor: rename 'TransportMessage' to 'Instruction'

* nit: update docstring on tick method

* refactor: rename dispatch to service

* nit: rename WsServer.requests to WsServer.instructions for consistency

* refactor: simplify handle_text, remove unwrap

* nit: fix misspelled comment
This commit is contained in:
James Prestwich 2021-08-20 10:23:39 -07:00 committed by GitHub
parent 0eee674ba0
commit e790429355
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 127 additions and 101 deletions

View File

@ -48,7 +48,7 @@ where
/// ///
/// ### Note /// ### Note
/// Most providers treat `SubscriptionStream` IDs as global singletons. /// Most providers treat `SubscriptionStream` IDs as global singletons.
/// Instanitating this directly with a known ID will likely cause any /// Instantiating this directly with a known ID will likely cause any
/// existing streams with that ID to end. To avoid this, start a new stream /// existing streams with that ID to end. To avoid this, start a new stream
/// using [`Provider::subscribe`] instead of `SubscriptionStream::new`. /// using [`Provider::subscribe`] instead of `SubscriptionStream::new`.
pub fn new(id: U256, provider: &'a Provider<P>) -> Result<Self, P::Error> { pub fn new(id: U256, provider: &'a Provider<P>) -> Result<Self, P::Error> {

View File

@ -30,6 +30,30 @@ use tokio_tungstenite::{
}; };
use tracing::{error, warn}; use tracing::{error, warn};
type Pending = oneshot::Sender<Result<serde_json::Value, JsonRpcError>>;
type Subscription = mpsc::UnboundedSender<serde_json::Value>;
/// Instructions for the `WsServer`.
enum Instruction {
/// JSON-RPC request
Request {
id: u64,
request: String,
sender: Pending,
},
/// Create a new subscription
Subscribe { id: U256, sink: Subscription },
/// Cancel an existing subscription
Unsubscribe { id: U256 },
}
#[derive(Debug, serde::Deserialize)]
#[serde(untagged)]
enum Incoming {
Notification(Notification<serde_json::Value>),
Response(Response<serde_json::Value>),
}
/// A JSON-RPC Client over Websockets. /// A JSON-RPC Client over Websockets.
/// ///
/// ```no_run /// ```no_run
@ -43,25 +67,7 @@ use tracing::{error, warn};
#[derive(Clone)] #[derive(Clone)]
pub struct Ws { pub struct Ws {
id: Arc<AtomicU64>, id: Arc<AtomicU64>,
requests: mpsc::UnboundedSender<TransportMessage>, instructions: mpsc::UnboundedSender<Instruction>,
}
type Pending = oneshot::Sender<Result<serde_json::Value, JsonRpcError>>;
type Subscription = mpsc::UnboundedSender<serde_json::Value>;
enum TransportMessage {
Request {
id: u64,
request: String,
sender: Pending,
},
Subscribe {
id: U256,
sink: Subscription,
},
Unsubscribe {
id: U256,
},
} }
impl Debug for Ws { impl Debug for Ws {
@ -90,13 +96,13 @@ impl Ws {
Self { Self {
id: Arc::new(AtomicU64::new(0)), id: Arc::new(AtomicU64::new(0)),
requests: sink, instructions: sink,
} }
} }
/// Returns true if the WS connection is active, false otherwise /// Returns true if the WS connection is active, false otherwise
pub fn ready(&self) -> bool { pub fn ready(&self) -> bool {
!self.requests.is_closed() !self.instructions.is_closed()
} }
/// Initializes a new WebSocket Client /// Initializes a new WebSocket Client
@ -107,8 +113,10 @@ impl Ws {
Ok(Self::new(ws)) Ok(Self::new(ws))
} }
fn send(&self, msg: TransportMessage) -> Result<(), ClientError> { fn send(&self, msg: Instruction) -> Result<(), ClientError> {
self.requests.unbounded_send(msg).map_err(to_client_error) self.instructions
.unbounded_send(msg)
.map_err(to_client_error)
} }
} }
@ -125,14 +133,14 @@ impl JsonRpcClient for Ws {
// send the message // send the message
let (sender, receiver) = oneshot::channel(); let (sender, receiver) = oneshot::channel();
let payload = TransportMessage::Request { let payload = Instruction::Request {
id: next_id, id: next_id,
request: serde_json::to_string(&Request::new(next_id, method, params))?, request: serde_json::to_string(&Request::new(next_id, method, params))?,
sender, sender,
}; };
// send the data // send the data
self.send(payload).map_err(to_client_error)?; self.send(payload)?;
// wait for the response // wait for the response
let res = receiver.await?; let res = receiver.await?;
@ -150,7 +158,7 @@ impl PubsubClient for Ws {
fn subscribe<T: Into<U256>>(&self, id: T) -> Result<Self::NotificationStream, ClientError> { fn subscribe<T: Into<U256>>(&self, id: T) -> Result<Self::NotificationStream, ClientError> {
let (sink, stream) = mpsc::unbounded(); let (sink, stream) = mpsc::unbounded();
self.send(TransportMessage::Subscribe { self.send(Instruction::Subscribe {
id: id.into(), id: id.into(),
sink, sink,
})?; })?;
@ -158,13 +166,13 @@ impl PubsubClient for Ws {
} }
fn unsubscribe<T: Into<U256>>(&self, id: T) -> Result<(), ClientError> { fn unsubscribe<T: Into<U256>>(&self, id: T) -> Result<(), ClientError> {
self.send(TransportMessage::Unsubscribe { id: id.into() }) self.send(Instruction::Unsubscribe { id: id.into() })
} }
} }
struct WsServer<S> { struct WsServer<S> {
ws: Fuse<S>, ws: Fuse<S>,
requests: Fuse<mpsc::UnboundedReceiver<TransportMessage>>, instructions: Fuse<mpsc::UnboundedReceiver<Instruction>>,
pending: BTreeMap<u64, Pending>, pending: BTreeMap<u64, Pending>,
subscriptions: BTreeMap<U256, Subscription>, subscriptions: BTreeMap<U256, Subscription>,
@ -179,12 +187,12 @@ where
+ Unpin, + Unpin,
{ {
/// Instantiates the Websocket Server /// Instantiates the Websocket Server
fn new(ws: S, requests: mpsc::UnboundedReceiver<TransportMessage>) -> Self { fn new(ws: S, requests: mpsc::UnboundedReceiver<Instruction>) -> Self {
Self { Self {
// Fuse the 2 steams together, so that we can `select` them in the // Fuse the 2 steams together, so that we can `select` them in the
// Stream implementation // Stream implementation
ws: ws.fuse(), ws: ws.fuse(),
requests: requests.fuse(), instructions: requests.fuse(),
pending: BTreeMap::default(), pending: BTreeMap::default(),
subscriptions: BTreeMap::default(), subscriptions: BTreeMap::default(),
} }
@ -197,13 +205,13 @@ where
{ {
let f = async move { let f = async move {
loop { loop {
match self.process().await { match self.tick().await {
Err(ClientError::UnexpectedClose) => { Err(ClientError::UnexpectedClose) => {
tracing::error!("{}", ClientError::UnexpectedClose); tracing::error!("{}", ClientError::UnexpectedClose);
break; break;
} }
Err(_) => { Err(e) => {
panic!("WS Server panic"); panic!("WS Server panic: {}", e);
} }
_ => {} _ => {}
} }
@ -213,35 +221,13 @@ where
tokio::spawn(f); tokio::spawn(f);
} }
/// Processes 1 item selected from the incoming `requests` or `ws` // dispatch an RPC request
#[allow(clippy::single_match)] async fn service_request(
async fn process(&mut self) -> Result<(), ClientError> { &mut self,
futures_util::select! { id: u64,
// Handle requests request: String,
msg = self.requests.select_next_some() => { sender: Pending,
self.handle_request(msg).await?; ) -> Result<(), ClientError> {
},
// Handle ws messages
msg = self.ws.next() => match msg {
Some(Ok(msg)) => self.handle_ws(msg).await?,
// TODO: Log the error?
Some(Err(_)) => {},
None => {
return Err(ClientError::UnexpectedClose);
},
}
};
Ok(())
}
async fn handle_request(&mut self, msg: TransportMessage) -> Result<(), ClientError> {
match msg {
TransportMessage::Request {
id,
request,
sender,
} => {
if self.pending.insert(id, sender).is_some() { if self.pending.insert(id, sender).is_some() {
warn!("Replacing a pending request with id {:?}", id); warn!("Replacing a pending request with id {:?}", id);
} }
@ -250,26 +236,69 @@ where
error!("WS connection error: {:?}", e); error!("WS connection error: {:?}", e);
self.pending.remove(&id); self.pending.remove(&id);
} }
Ok(())
} }
TransportMessage::Subscribe { id, sink } => {
/// Dispatch a subscription request
async fn service_subscribe(&mut self, id: U256, sink: Subscription) -> Result<(), ClientError> {
if self.subscriptions.insert(id, sink).is_some() { if self.subscriptions.insert(id, sink).is_some() {
warn!("Replacing already-registered subscription with id {:?}", id); warn!("Replacing already-registered subscription with id {:?}", id);
} }
Ok(())
} }
TransportMessage::Unsubscribe { id } => {
/// Dispatch a unsubscribe request
async fn service_unsubscribe(&mut self, id: U256) -> Result<(), ClientError> {
if self.subscriptions.remove(&id).is_none() { if self.subscriptions.remove(&id).is_none() {
warn!( warn!(
"Unsubscribing from non-existent subscription with id {:?}", "Unsubscribing from non-existent subscription with id {:?}",
id id
); );
} }
}
};
Ok(()) Ok(())
} }
async fn handle_ws(&mut self, resp: Message) -> Result<(), ClientError> { /// Dispatch an outgoing message
async fn service(&mut self, instruction: Instruction) -> Result<(), ClientError> {
match instruction {
Instruction::Request {
id,
request,
sender,
} => self.service_request(id, request, sender).await,
Instruction::Subscribe { id, sink } => self.service_subscribe(id, sink).await,
Instruction::Unsubscribe { id } => self.service_unsubscribe(id).await,
}
}
async fn handle_ping(&mut self, inner: Vec<u8>) -> Result<(), ClientError> {
self.ws.send(Message::Pong(inner)).await?;
Ok(())
}
async fn handle_text(&mut self, inner: String) -> Result<(), ClientError> {
match serde_json::from_str::<Incoming>(&inner) {
Err(_) => {}
Ok(Incoming::Response(resp)) => {
if let Some(request) = self.pending.remove(&resp.id) {
request
.send(resp.data.into_result())
.map_err(to_client_error)?;
}
}
Ok(Incoming::Notification(notification)) => {
let id = notification.params.subscription;
if let Some(stream) = self.subscriptions.get(&id) {
stream
.unbounded_send(notification.params.result)
.map_err(to_client_error)?;
}
}
}
Ok(())
}
async fn handle(&mut self, resp: Message) -> Result<(), ClientError> {
match resp { match resp {
Message::Text(inner) => self.handle_text(inner).await, Message::Text(inner) => self.handle_text(inner).await,
Message::Ping(inner) => self.handle_ping(inner).await, Message::Ping(inner) => self.handle_ping(inner).await,
@ -280,28 +309,25 @@ where
} }
} }
async fn handle_ping(&mut self, inner: Vec<u8>) -> Result<(), ClientError> { /// Processes 1 instruction or 1 incoming websocket message
self.ws.send(Message::Pong(inner)).await?; #[allow(clippy::single_match)]
Ok(()) async fn tick(&mut self) -> Result<(), ClientError> {
futures_util::select! {
// Handle requests
instruction = self.instructions.select_next_some() => {
self.service(instruction).await?;
},
// Handle ws messages
resp = self.ws.next() => match resp {
Some(Ok(resp)) => self.handle(resp).await?,
// TODO: Log the error?
Some(Err(_)) => {},
None => {
return Err(ClientError::UnexpectedClose);
},
} }
};
async fn handle_text(&mut self, inner: String) -> Result<(), ClientError> {
if let Ok(resp) = serde_json::from_str::<Response<serde_json::Value>>(&inner) {
if let Some(request) = self.pending.remove(&resp.id) {
request
.send(resp.data.into_result())
.map_err(to_client_error)?;
}
} else if let Ok(notification) =
serde_json::from_str::<Notification<serde_json::Value>>(&inner)
{
let id = notification.params.subscription;
if let Some(stream) = self.subscriptions.get(&id) {
stream
.unbounded_send(notification.params.result)
.map_err(to_client_error)?;
}
}
Ok(()) Ok(())
} }
} }