Refactor WS handling code (#397)

* refactor: simplify handle_text message

* refactor: rename 'TransportMessage' to 'Instruction'

* nit: update docstring on tick method

* refactor: rename dispatch to service

* nit: rename WsServer.requests to WsServer.instructions for consistency

* refactor: simplify handle_text, remove unwrap

* nit: fix misspelled comment
This commit is contained in:
James Prestwich 2021-08-20 10:23:39 -07:00 committed by GitHub
parent 0eee674ba0
commit e790429355
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 127 additions and 101 deletions

View File

@ -48,7 +48,7 @@ where
///
/// ### Note
/// Most providers treat `SubscriptionStream` IDs as global singletons.
/// Instanitating this directly with a known ID will likely cause any
/// Instantiating this directly with a known ID will likely cause any
/// existing streams with that ID to end. To avoid this, start a new stream
/// using [`Provider::subscribe`] instead of `SubscriptionStream::new`.
pub fn new(id: U256, provider: &'a Provider<P>) -> Result<Self, P::Error> {

View File

@ -30,6 +30,30 @@ use tokio_tungstenite::{
};
use tracing::{error, warn};
type Pending = oneshot::Sender<Result<serde_json::Value, JsonRpcError>>;
type Subscription = mpsc::UnboundedSender<serde_json::Value>;
/// Instructions for the `WsServer`.
enum Instruction {
/// JSON-RPC request
Request {
id: u64,
request: String,
sender: Pending,
},
/// Create a new subscription
Subscribe { id: U256, sink: Subscription },
/// Cancel an existing subscription
Unsubscribe { id: U256 },
}
#[derive(Debug, serde::Deserialize)]
#[serde(untagged)]
enum Incoming {
Notification(Notification<serde_json::Value>),
Response(Response<serde_json::Value>),
}
/// A JSON-RPC Client over Websockets.
///
/// ```no_run
@ -43,25 +67,7 @@ use tracing::{error, warn};
#[derive(Clone)]
pub struct Ws {
id: Arc<AtomicU64>,
requests: mpsc::UnboundedSender<TransportMessage>,
}
type Pending = oneshot::Sender<Result<serde_json::Value, JsonRpcError>>;
type Subscription = mpsc::UnboundedSender<serde_json::Value>;
enum TransportMessage {
Request {
id: u64,
request: String,
sender: Pending,
},
Subscribe {
id: U256,
sink: Subscription,
},
Unsubscribe {
id: U256,
},
instructions: mpsc::UnboundedSender<Instruction>,
}
impl Debug for Ws {
@ -90,13 +96,13 @@ impl Ws {
Self {
id: Arc::new(AtomicU64::new(0)),
requests: sink,
instructions: sink,
}
}
/// Returns true if the WS connection is active, false otherwise
pub fn ready(&self) -> bool {
!self.requests.is_closed()
!self.instructions.is_closed()
}
/// Initializes a new WebSocket Client
@ -107,8 +113,10 @@ impl Ws {
Ok(Self::new(ws))
}
fn send(&self, msg: TransportMessage) -> Result<(), ClientError> {
self.requests.unbounded_send(msg).map_err(to_client_error)
fn send(&self, msg: Instruction) -> Result<(), ClientError> {
self.instructions
.unbounded_send(msg)
.map_err(to_client_error)
}
}
@ -125,14 +133,14 @@ impl JsonRpcClient for Ws {
// send the message
let (sender, receiver) = oneshot::channel();
let payload = TransportMessage::Request {
let payload = Instruction::Request {
id: next_id,
request: serde_json::to_string(&Request::new(next_id, method, params))?,
sender,
};
// send the data
self.send(payload).map_err(to_client_error)?;
self.send(payload)?;
// wait for the response
let res = receiver.await?;
@ -150,7 +158,7 @@ impl PubsubClient for Ws {
fn subscribe<T: Into<U256>>(&self, id: T) -> Result<Self::NotificationStream, ClientError> {
let (sink, stream) = mpsc::unbounded();
self.send(TransportMessage::Subscribe {
self.send(Instruction::Subscribe {
id: id.into(),
sink,
})?;
@ -158,13 +166,13 @@ impl PubsubClient for Ws {
}
fn unsubscribe<T: Into<U256>>(&self, id: T) -> Result<(), ClientError> {
self.send(TransportMessage::Unsubscribe { id: id.into() })
self.send(Instruction::Unsubscribe { id: id.into() })
}
}
struct WsServer<S> {
ws: Fuse<S>,
requests: Fuse<mpsc::UnboundedReceiver<TransportMessage>>,
instructions: Fuse<mpsc::UnboundedReceiver<Instruction>>,
pending: BTreeMap<u64, Pending>,
subscriptions: BTreeMap<U256, Subscription>,
@ -179,12 +187,12 @@ where
+ Unpin,
{
/// Instantiates the Websocket Server
fn new(ws: S, requests: mpsc::UnboundedReceiver<TransportMessage>) -> Self {
fn new(ws: S, requests: mpsc::UnboundedReceiver<Instruction>) -> Self {
Self {
// Fuse the 2 steams together, so that we can `select` them in the
// Stream implementation
ws: ws.fuse(),
requests: requests.fuse(),
instructions: requests.fuse(),
pending: BTreeMap::default(),
subscriptions: BTreeMap::default(),
}
@ -197,13 +205,13 @@ where
{
let f = async move {
loop {
match self.process().await {
match self.tick().await {
Err(ClientError::UnexpectedClose) => {
tracing::error!("{}", ClientError::UnexpectedClose);
break;
}
Err(_) => {
panic!("WS Server panic");
Err(e) => {
panic!("WS Server panic: {}", e);
}
_ => {}
}
@ -213,35 +221,13 @@ where
tokio::spawn(f);
}
/// Processes 1 item selected from the incoming `requests` or `ws`
#[allow(clippy::single_match)]
async fn process(&mut self) -> Result<(), ClientError> {
futures_util::select! {
// Handle requests
msg = self.requests.select_next_some() => {
self.handle_request(msg).await?;
},
// Handle ws messages
msg = self.ws.next() => match msg {
Some(Ok(msg)) => self.handle_ws(msg).await?,
// TODO: Log the error?
Some(Err(_)) => {},
None => {
return Err(ClientError::UnexpectedClose);
},
}
};
Ok(())
}
async fn handle_request(&mut self, msg: TransportMessage) -> Result<(), ClientError> {
match msg {
TransportMessage::Request {
id,
request,
sender,
} => {
// dispatch an RPC request
async fn service_request(
&mut self,
id: u64,
request: String,
sender: Pending,
) -> Result<(), ClientError> {
if self.pending.insert(id, sender).is_some() {
warn!("Replacing a pending request with id {:?}", id);
}
@ -250,26 +236,69 @@ where
error!("WS connection error: {:?}", e);
self.pending.remove(&id);
}
Ok(())
}
TransportMessage::Subscribe { id, sink } => {
/// Dispatch a subscription request
async fn service_subscribe(&mut self, id: U256, sink: Subscription) -> Result<(), ClientError> {
if self.subscriptions.insert(id, sink).is_some() {
warn!("Replacing already-registered subscription with id {:?}", id);
}
Ok(())
}
TransportMessage::Unsubscribe { id } => {
/// Dispatch a unsubscribe request
async fn service_unsubscribe(&mut self, id: U256) -> Result<(), ClientError> {
if self.subscriptions.remove(&id).is_none() {
warn!(
"Unsubscribing from non-existent subscription with id {:?}",
id
);
}
}
};
Ok(())
}
async fn handle_ws(&mut self, resp: Message) -> Result<(), ClientError> {
/// Dispatch an outgoing message
async fn service(&mut self, instruction: Instruction) -> Result<(), ClientError> {
match instruction {
Instruction::Request {
id,
request,
sender,
} => self.service_request(id, request, sender).await,
Instruction::Subscribe { id, sink } => self.service_subscribe(id, sink).await,
Instruction::Unsubscribe { id } => self.service_unsubscribe(id).await,
}
}
async fn handle_ping(&mut self, inner: Vec<u8>) -> Result<(), ClientError> {
self.ws.send(Message::Pong(inner)).await?;
Ok(())
}
async fn handle_text(&mut self, inner: String) -> Result<(), ClientError> {
match serde_json::from_str::<Incoming>(&inner) {
Err(_) => {}
Ok(Incoming::Response(resp)) => {
if let Some(request) = self.pending.remove(&resp.id) {
request
.send(resp.data.into_result())
.map_err(to_client_error)?;
}
}
Ok(Incoming::Notification(notification)) => {
let id = notification.params.subscription;
if let Some(stream) = self.subscriptions.get(&id) {
stream
.unbounded_send(notification.params.result)
.map_err(to_client_error)?;
}
}
}
Ok(())
}
async fn handle(&mut self, resp: Message) -> Result<(), ClientError> {
match resp {
Message::Text(inner) => self.handle_text(inner).await,
Message::Ping(inner) => self.handle_ping(inner).await,
@ -280,28 +309,25 @@ where
}
}
async fn handle_ping(&mut self, inner: Vec<u8>) -> Result<(), ClientError> {
self.ws.send(Message::Pong(inner)).await?;
Ok(())
/// Processes 1 instruction or 1 incoming websocket message
#[allow(clippy::single_match)]
async fn tick(&mut self) -> Result<(), ClientError> {
futures_util::select! {
// Handle requests
instruction = self.instructions.select_next_some() => {
self.service(instruction).await?;
},
// Handle ws messages
resp = self.ws.next() => match resp {
Some(Ok(resp)) => self.handle(resp).await?,
// TODO: Log the error?
Some(Err(_)) => {},
None => {
return Err(ClientError::UnexpectedClose);
},
}
};
async fn handle_text(&mut self, inner: String) -> Result<(), ClientError> {
if let Ok(resp) = serde_json::from_str::<Response<serde_json::Value>>(&inner) {
if let Some(request) = self.pending.remove(&resp.id) {
request
.send(resp.data.into_result())
.map_err(to_client_error)?;
}
} else if let Ok(notification) =
serde_json::from_str::<Notification<serde_json::Value>>(&inner)
{
let id = notification.params.subscription;
if let Some(stream) = self.subscriptions.get(&id) {
stream
.unbounded_send(notification.params.result)
.map_err(to_client_error)?;
}
}
Ok(())
}
}