Refactor WS handling code (#397)
* refactor: simplify handle_text message * refactor: rename 'TransportMessage' to 'Instruction' * nit: update docstring on tick method * refactor: rename dispatch to service * nit: rename WsServer.requests to WsServer.instructions for consistency * refactor: simplify handle_text, remove unwrap * nit: fix misspelled comment
This commit is contained in:
parent
0eee674ba0
commit
e790429355
|
@ -48,7 +48,7 @@ where
|
|||
///
|
||||
/// ### Note
|
||||
/// Most providers treat `SubscriptionStream` IDs as global singletons.
|
||||
/// Instanitating this directly with a known ID will likely cause any
|
||||
/// Instantiating this directly with a known ID will likely cause any
|
||||
/// existing streams with that ID to end. To avoid this, start a new stream
|
||||
/// using [`Provider::subscribe`] instead of `SubscriptionStream::new`.
|
||||
pub fn new(id: U256, provider: &'a Provider<P>) -> Result<Self, P::Error> {
|
||||
|
|
|
@ -30,6 +30,30 @@ use tokio_tungstenite::{
|
|||
};
|
||||
use tracing::{error, warn};
|
||||
|
||||
type Pending = oneshot::Sender<Result<serde_json::Value, JsonRpcError>>;
|
||||
type Subscription = mpsc::UnboundedSender<serde_json::Value>;
|
||||
|
||||
/// Instructions for the `WsServer`.
|
||||
enum Instruction {
|
||||
/// JSON-RPC request
|
||||
Request {
|
||||
id: u64,
|
||||
request: String,
|
||||
sender: Pending,
|
||||
},
|
||||
/// Create a new subscription
|
||||
Subscribe { id: U256, sink: Subscription },
|
||||
/// Cancel an existing subscription
|
||||
Unsubscribe { id: U256 },
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum Incoming {
|
||||
Notification(Notification<serde_json::Value>),
|
||||
Response(Response<serde_json::Value>),
|
||||
}
|
||||
|
||||
/// A JSON-RPC Client over Websockets.
|
||||
///
|
||||
/// ```no_run
|
||||
|
@ -43,25 +67,7 @@ use tracing::{error, warn};
|
|||
#[derive(Clone)]
|
||||
pub struct Ws {
|
||||
id: Arc<AtomicU64>,
|
||||
requests: mpsc::UnboundedSender<TransportMessage>,
|
||||
}
|
||||
|
||||
type Pending = oneshot::Sender<Result<serde_json::Value, JsonRpcError>>;
|
||||
type Subscription = mpsc::UnboundedSender<serde_json::Value>;
|
||||
|
||||
enum TransportMessage {
|
||||
Request {
|
||||
id: u64,
|
||||
request: String,
|
||||
sender: Pending,
|
||||
},
|
||||
Subscribe {
|
||||
id: U256,
|
||||
sink: Subscription,
|
||||
},
|
||||
Unsubscribe {
|
||||
id: U256,
|
||||
},
|
||||
instructions: mpsc::UnboundedSender<Instruction>,
|
||||
}
|
||||
|
||||
impl Debug for Ws {
|
||||
|
@ -90,13 +96,13 @@ impl Ws {
|
|||
|
||||
Self {
|
||||
id: Arc::new(AtomicU64::new(0)),
|
||||
requests: sink,
|
||||
instructions: sink,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the WS connection is active, false otherwise
|
||||
pub fn ready(&self) -> bool {
|
||||
!self.requests.is_closed()
|
||||
!self.instructions.is_closed()
|
||||
}
|
||||
|
||||
/// Initializes a new WebSocket Client
|
||||
|
@ -107,8 +113,10 @@ impl Ws {
|
|||
Ok(Self::new(ws))
|
||||
}
|
||||
|
||||
fn send(&self, msg: TransportMessage) -> Result<(), ClientError> {
|
||||
self.requests.unbounded_send(msg).map_err(to_client_error)
|
||||
fn send(&self, msg: Instruction) -> Result<(), ClientError> {
|
||||
self.instructions
|
||||
.unbounded_send(msg)
|
||||
.map_err(to_client_error)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -125,14 +133,14 @@ impl JsonRpcClient for Ws {
|
|||
|
||||
// send the message
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
let payload = TransportMessage::Request {
|
||||
let payload = Instruction::Request {
|
||||
id: next_id,
|
||||
request: serde_json::to_string(&Request::new(next_id, method, params))?,
|
||||
sender,
|
||||
};
|
||||
|
||||
// send the data
|
||||
self.send(payload).map_err(to_client_error)?;
|
||||
self.send(payload)?;
|
||||
|
||||
// wait for the response
|
||||
let res = receiver.await?;
|
||||
|
@ -150,7 +158,7 @@ impl PubsubClient for Ws {
|
|||
|
||||
fn subscribe<T: Into<U256>>(&self, id: T) -> Result<Self::NotificationStream, ClientError> {
|
||||
let (sink, stream) = mpsc::unbounded();
|
||||
self.send(TransportMessage::Subscribe {
|
||||
self.send(Instruction::Subscribe {
|
||||
id: id.into(),
|
||||
sink,
|
||||
})?;
|
||||
|
@ -158,13 +166,13 @@ impl PubsubClient for Ws {
|
|||
}
|
||||
|
||||
fn unsubscribe<T: Into<U256>>(&self, id: T) -> Result<(), ClientError> {
|
||||
self.send(TransportMessage::Unsubscribe { id: id.into() })
|
||||
self.send(Instruction::Unsubscribe { id: id.into() })
|
||||
}
|
||||
}
|
||||
|
||||
struct WsServer<S> {
|
||||
ws: Fuse<S>,
|
||||
requests: Fuse<mpsc::UnboundedReceiver<TransportMessage>>,
|
||||
instructions: Fuse<mpsc::UnboundedReceiver<Instruction>>,
|
||||
|
||||
pending: BTreeMap<u64, Pending>,
|
||||
subscriptions: BTreeMap<U256, Subscription>,
|
||||
|
@ -179,12 +187,12 @@ where
|
|||
+ Unpin,
|
||||
{
|
||||
/// Instantiates the Websocket Server
|
||||
fn new(ws: S, requests: mpsc::UnboundedReceiver<TransportMessage>) -> Self {
|
||||
fn new(ws: S, requests: mpsc::UnboundedReceiver<Instruction>) -> Self {
|
||||
Self {
|
||||
// Fuse the 2 steams together, so that we can `select` them in the
|
||||
// Stream implementation
|
||||
ws: ws.fuse(),
|
||||
requests: requests.fuse(),
|
||||
instructions: requests.fuse(),
|
||||
pending: BTreeMap::default(),
|
||||
subscriptions: BTreeMap::default(),
|
||||
}
|
||||
|
@ -197,13 +205,13 @@ where
|
|||
{
|
||||
let f = async move {
|
||||
loop {
|
||||
match self.process().await {
|
||||
match self.tick().await {
|
||||
Err(ClientError::UnexpectedClose) => {
|
||||
tracing::error!("{}", ClientError::UnexpectedClose);
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
panic!("WS Server panic");
|
||||
Err(e) => {
|
||||
panic!("WS Server panic: {}", e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
@ -213,63 +221,84 @@ where
|
|||
tokio::spawn(f);
|
||||
}
|
||||
|
||||
/// Processes 1 item selected from the incoming `requests` or `ws`
|
||||
#[allow(clippy::single_match)]
|
||||
async fn process(&mut self) -> Result<(), ClientError> {
|
||||
futures_util::select! {
|
||||
// Handle requests
|
||||
msg = self.requests.select_next_some() => {
|
||||
self.handle_request(msg).await?;
|
||||
},
|
||||
// Handle ws messages
|
||||
msg = self.ws.next() => match msg {
|
||||
Some(Ok(msg)) => self.handle_ws(msg).await?,
|
||||
// TODO: Log the error?
|
||||
Some(Err(_)) => {},
|
||||
None => {
|
||||
return Err(ClientError::UnexpectedClose);
|
||||
},
|
||||
}
|
||||
};
|
||||
// dispatch an RPC request
|
||||
async fn service_request(
|
||||
&mut self,
|
||||
id: u64,
|
||||
request: String,
|
||||
sender: Pending,
|
||||
) -> Result<(), ClientError> {
|
||||
if self.pending.insert(id, sender).is_some() {
|
||||
warn!("Replacing a pending request with id {:?}", id);
|
||||
}
|
||||
|
||||
if let Err(e) = self.ws.send(Message::Text(request)).await {
|
||||
error!("WS connection error: {:?}", e);
|
||||
self.pending.remove(&id);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_request(&mut self, msg: TransportMessage) -> Result<(), ClientError> {
|
||||
match msg {
|
||||
TransportMessage::Request {
|
||||
/// Dispatch a subscription request
|
||||
async fn service_subscribe(&mut self, id: U256, sink: Subscription) -> Result<(), ClientError> {
|
||||
if self.subscriptions.insert(id, sink).is_some() {
|
||||
warn!("Replacing already-registered subscription with id {:?}", id);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Dispatch a unsubscribe request
|
||||
async fn service_unsubscribe(&mut self, id: U256) -> Result<(), ClientError> {
|
||||
if self.subscriptions.remove(&id).is_none() {
|
||||
warn!(
|
||||
"Unsubscribing from non-existent subscription with id {:?}",
|
||||
id
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Dispatch an outgoing message
|
||||
async fn service(&mut self, instruction: Instruction) -> Result<(), ClientError> {
|
||||
match instruction {
|
||||
Instruction::Request {
|
||||
id,
|
||||
request,
|
||||
sender,
|
||||
} => {
|
||||
if self.pending.insert(id, sender).is_some() {
|
||||
warn!("Replacing a pending request with id {:?}", id);
|
||||
}
|
||||
|
||||
if let Err(e) = self.ws.send(Message::Text(request)).await {
|
||||
error!("WS connection error: {:?}", e);
|
||||
self.pending.remove(&id);
|
||||
}
|
||||
}
|
||||
TransportMessage::Subscribe { id, sink } => {
|
||||
if self.subscriptions.insert(id, sink).is_some() {
|
||||
warn!("Replacing already-registered subscription with id {:?}", id);
|
||||
}
|
||||
}
|
||||
TransportMessage::Unsubscribe { id } => {
|
||||
if self.subscriptions.remove(&id).is_none() {
|
||||
warn!(
|
||||
"Unsubscribing from non-existent subscription with id {:?}",
|
||||
id
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
} => self.service_request(id, request, sender).await,
|
||||
Instruction::Subscribe { id, sink } => self.service_subscribe(id, sink).await,
|
||||
Instruction::Unsubscribe { id } => self.service_unsubscribe(id).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_ping(&mut self, inner: Vec<u8>) -> Result<(), ClientError> {
|
||||
self.ws.send(Message::Pong(inner)).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_ws(&mut self, resp: Message) -> Result<(), ClientError> {
|
||||
async fn handle_text(&mut self, inner: String) -> Result<(), ClientError> {
|
||||
match serde_json::from_str::<Incoming>(&inner) {
|
||||
Err(_) => {}
|
||||
Ok(Incoming::Response(resp)) => {
|
||||
if let Some(request) = self.pending.remove(&resp.id) {
|
||||
request
|
||||
.send(resp.data.into_result())
|
||||
.map_err(to_client_error)?;
|
||||
}
|
||||
}
|
||||
Ok(Incoming::Notification(notification)) => {
|
||||
let id = notification.params.subscription;
|
||||
if let Some(stream) = self.subscriptions.get(&id) {
|
||||
stream
|
||||
.unbounded_send(notification.params.result)
|
||||
.map_err(to_client_error)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle(&mut self, resp: Message) -> Result<(), ClientError> {
|
||||
match resp {
|
||||
Message::Text(inner) => self.handle_text(inner).await,
|
||||
Message::Ping(inner) => self.handle_ping(inner).await,
|
||||
|
@ -280,28 +309,25 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
async fn handle_ping(&mut self, inner: Vec<u8>) -> Result<(), ClientError> {
|
||||
self.ws.send(Message::Pong(inner)).await?;
|
||||
Ok(())
|
||||
}
|
||||
/// Processes 1 instruction or 1 incoming websocket message
|
||||
#[allow(clippy::single_match)]
|
||||
async fn tick(&mut self) -> Result<(), ClientError> {
|
||||
futures_util::select! {
|
||||
// Handle requests
|
||||
instruction = self.instructions.select_next_some() => {
|
||||
self.service(instruction).await?;
|
||||
},
|
||||
// Handle ws messages
|
||||
resp = self.ws.next() => match resp {
|
||||
Some(Ok(resp)) => self.handle(resp).await?,
|
||||
// TODO: Log the error?
|
||||
Some(Err(_)) => {},
|
||||
None => {
|
||||
return Err(ClientError::UnexpectedClose);
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
async fn handle_text(&mut self, inner: String) -> Result<(), ClientError> {
|
||||
if let Ok(resp) = serde_json::from_str::<Response<serde_json::Value>>(&inner) {
|
||||
if let Some(request) = self.pending.remove(&resp.id) {
|
||||
request
|
||||
.send(resp.data.into_result())
|
||||
.map_err(to_client_error)?;
|
||||
}
|
||||
} else if let Ok(notification) =
|
||||
serde_json::from_str::<Notification<serde_json::Value>>(&inner)
|
||||
{
|
||||
let id = notification.params.subscription;
|
||||
if let Some(stream) = self.subscriptions.get(&id) {
|
||||
stream
|
||||
.unbounded_send(notification.params.result)
|
||||
.map_err(to_client_error)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue