RRR: Reconnection & Request Reissuance (#2181)

* wip: ws2

* ws2 backend compiles

* refactor: rename PubSubItem and BackendDriver

* feature: dispatch request to end subscription

* refactor: move ws2 to ws, fix reconnection and deser on subs

* chore: improve use of tracing in manager

* refactor: feature legacy_ws to enable backwards compatibility

* nit: mod file ordering

* docs: copy PR description to ws structs

* fixes: remove unused macros file, remove err formats

* docs: add comments to struct fields

* docs: comment client struct fields

* chore: changelog

* fix: unused imports in ws_errors test

* docs: missing comment

Co-authored-by: Georgios Konstantopoulos <me@gakonst.com>

* fix: legacy-ws feature in root crate, hyphen not underscore

* fix: a couple bad imports/exports

---------

Co-authored-by: Georgios Konstantopoulos <me@gakonst.com>
This commit is contained in:
James Prestwich 2023-02-28 17:25:59 -08:00 committed by GitHub
parent 20375e291b
commit 73636a906e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1262 additions and 54 deletions

View File

@ -239,6 +239,9 @@
### Unreleased
- Breaking: WS now includes reconnection logic and a changed `connect`
interface. Old behavior can be accessed via the `legacy_ws` feature
[#2181](https://github.com/gakonst/ethers-rs/pull/2181)
- Re-organize the crate. #[2150](https://github.com/gakonst/ethers-rs/pull/2159)
- Convert provider errors to arbitrary middleware errors
[#1920](https://github.com/gakonst/ethers-rs/pull/1920)

62
Cargo.lock generated
View File

@ -1512,6 +1512,7 @@ dependencies = [
"tokio-tungstenite",
"tracing",
"tracing-futures",
"tracing-test",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
@ -1549,6 +1550,7 @@ dependencies = [
"thiserror",
"tokio",
"tracing",
"tracing-subscriber",
"trezor-client",
"yubihsm",
]
@ -2633,6 +2635,16 @@ dependencies = [
"static_assertions",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]]
name = "num-integer"
version = "0.1.45"
@ -2798,6 +2810,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "p256"
version = "0.11.1"
@ -4414,6 +4432,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
@ -4426,6 +4445,17 @@ dependencies = [
"tracing",
]
[[package]]
name = "tracing-log"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922"
dependencies = [
"lazy_static",
"log",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.16"
@ -4433,12 +4463,38 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]
[[package]]
name = "tracing-test"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a2c0ff408fe918a94c428a3f2ad04e4afd5c95bbc08fcf868eff750c15728a4"
dependencies = [
"lazy_static",
"tracing-core",
"tracing-subscriber",
"tracing-test-macro",
]
[[package]]
name = "tracing-test-macro"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "258bc1c4f8e2e73a977812ab339d503e6feeb92700f6d07a6de4d321522d5c08"
dependencies = [
"lazy_static",
"quote",
"syn",
]
[[package]]
@ -4583,6 +4639,12 @@ dependencies = [
"serde",
]
[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]]
name = "vcpkg"
version = "0.2.15"

View File

@ -39,7 +39,7 @@ pin-project = { version = "1.0.11", default-features = false }
enr = { version = "0.7.0", default-features = false, features = ["k256", "serde"] }
# tracing
tracing = { version = "0.1.37", default-features = false }
tracing = { version = "0.1.37", default-features = false, features = ["attributes"] }
tracing-futures = { version = "0.2.5", default-features = false, features = ["std-future"] }
bytes = { version = "1.4.0", default-features = false, optional = true }
@ -76,6 +76,7 @@ default = ["ws", "rustls"]
celo = ["ethers-core/celo"]
ws = ["tokio-tungstenite", "futures-channel"]
legacy-ws = ["ws"]
ipc = ["tokio/io-util", "bytes", "futures-channel", "winapi"]
openssl = ["tokio-tungstenite/native-tls", "reqwest/native-tls"]
@ -83,3 +84,6 @@ openssl = ["tokio-tungstenite/native-tls", "reqwest/native-tls"]
# on the host
rustls = ["tokio-tungstenite/rustls-tls-webpki-roots", "reqwest/rustls-tls"]
dev-rpc = []
[dev-dependencies]
tracing-test = { version = "0.2.4", features = ["no-env-filter"] }

View File

@ -11,8 +11,6 @@ use crate::{
MockProvider, NodeInfo, PeerInfo, PendingTransaction, QuorumProvider, RwClient,
};
#[cfg(all(not(target_arch = "wasm32"), feature = "ws"))]
use crate::Authorization;
#[cfg(not(target_arch = "wasm32"))]
use crate::{HttpRateLimitRetryPolicy, RetryClient};
@ -1217,35 +1215,6 @@ impl<P: JsonRpcClient> Provider<P> {
}
}
#[cfg(feature = "ws")]
impl Provider<crate::Ws> {
/// Direct connection to a websocket endpoint
#[cfg(not(target_arch = "wasm32"))]
pub async fn connect(
url: impl tokio_tungstenite::tungstenite::client::IntoClientRequest + Unpin,
) -> Result<Self, ProviderError> {
let ws = crate::Ws::connect(url).await?;
Ok(Self::new(ws))
}
/// Direct connection to a websocket endpoint
#[cfg(target_arch = "wasm32")]
pub async fn connect(url: &str) -> Result<Self, ProviderError> {
let ws = crate::Ws::connect(url).await?;
Ok(Self::new(ws))
}
/// Connect to a WS RPC provider with authentication details
#[cfg(not(target_arch = "wasm32"))]
pub async fn connect_with_auth(
url: impl tokio_tungstenite::tungstenite::client::IntoClientRequest + Unpin,
auth: Authorization,
) -> Result<Self, ProviderError> {
let ws = crate::Ws::connect_with_auth(url, auth).await?;
Ok(Self::new(ws))
}
}
#[cfg(all(feature = "ipc", any(unix, windows)))]
impl Provider<crate::Ipc> {
#[cfg_attr(unix, doc = "Connects to the Unix socket at the provided path.")]

View File

@ -571,3 +571,31 @@ mod tests {
resp.unwrap_err();
}
}
impl crate::Provider<Ws> {
/// Direct connection to a websocket endpoint
#[cfg(not(target_arch = "wasm32"))]
pub async fn connect(
url: impl tokio_tungstenite::tungstenite::client::IntoClientRequest + Unpin,
) -> Result<Self, ProviderError> {
let ws = crate::Ws::connect(url).await?;
Ok(Self::new(ws))
}
/// Direct connection to a websocket endpoint
#[cfg(target_arch = "wasm32")]
pub async fn connect(url: &str) -> Result<Self, ProviderError> {
let ws = crate::Ws::connect(url).await?;
Ok(Self::new(ws))
}
/// Connect to a WS RPC provider with authentication details
#[cfg(not(target_arch = "wasm32"))]
pub async fn connect_with_auth(
url: impl tokio_tungstenite::tungstenite::client::IntoClientRequest + Unpin,
auth: Authorization,
) -> Result<Self, ProviderError> {
let ws = crate::Ws::connect_with_auth(url, auth).await?;
Ok(Self::new(ws))
}
}

View File

@ -9,11 +9,6 @@ mod ipc;
#[cfg(all(feature = "ipc", any(unix, windows)))]
pub use ipc::{Ipc, IpcError};
#[cfg(feature = "ws")]
mod ws;
#[cfg(feature = "ws")]
pub use ws::{ClientError as WsClientError, Ws};
mod quorum;
pub use quorum::{JsonRpcClientWrapper, Quorum, QuorumError, QuorumProvider, WeightedProvider};
@ -23,5 +18,16 @@ pub use rw::{RwClient, RwClientError};
mod retry;
pub use retry::*;
#[cfg(all(feature = "ws", not(feature = "legacy-ws")))]
mod ws;
#[cfg(all(feature = "ws", not(feature = "legacy-ws")))]
pub use ws::{ConnectionDetails, WsClient as Ws, WsClientError};
/// archival websocket
#[cfg(feature = "legacy-ws")]
pub mod legacy_ws;
#[cfg(feature = "legacy-ws")]
pub use legacy_ws::{ClientError as WsClientError, Ws};
mod mock;
pub use mock::{MockError, MockProvider};

View File

@ -0,0 +1,210 @@
use futures_channel::{mpsc, oneshot};
use futures_util::{select, sink::SinkExt, stream::StreamExt, FutureExt};
use serde_json::value::RawValue;
use super::{types::*, WsClientError};
use tracing::{error, trace};
/// `BackendDriver` drives a specific `WsBackend`. It can be used to issue
/// requests, receive responses, see errors, and shut down the backend.
pub struct BackendDriver {
// Pubsub items from the backend, received via WS
pub to_handle: mpsc::UnboundedReceiver<PubSubItem>,
// Notification from the backend of a terminal error
pub error: oneshot::Receiver<()>,
// Requests that the backend should dispatch
pub dispatcher: mpsc::UnboundedSender<Box<RawValue>>,
// Notify the backend of intentional shutdown
shutdown: oneshot::Sender<()>,
}
impl BackendDriver {
pub fn shutdown(self) {
// don't care if it fails, as that means the backend is gone anyway
let _ = self.shutdown.send(());
}
}
/// `WsBackend` dispatches requests and routes responses and notifications. It
/// also has a simple ping-based keepalive (when not compiled to wasm), to
/// prevent inactivity from triggering server-side closes
///
/// The `WsBackend` shuts down when instructed to by the `RequestManager` or
/// when the `RequestManager` drops (because the inbound channel will close)
pub struct WsBackend {
server: InternalStream,
// channel to the manager, through which to send items received via WS
handler: mpsc::UnboundedSender<PubSubItem>,
// notify manager of an error causing this task to halt
error: oneshot::Sender<()>,
// channel of inbound requests to dispatch
to_dispatch: mpsc::UnboundedReceiver<Box<RawValue>>,
// notification from manager of intentional shutdown
shutdown: oneshot::Receiver<()>,
}
impl WsBackend {
#[cfg(target_arch = "wasm32")]
pub async fn connect(
details: ConnectionDetails,
) -> Result<(Self, BackendDriver), WsClientError> {
let wsio = WsMeta::connect(details.url, None)
.await
.expect_throw("Could not create websocket")
.1
.fuse();
Ok(Self::new(wsio))
}
#[cfg(not(target_arch = "wasm32"))]
pub async fn connect(
details: ConnectionDetails,
) -> Result<(Self, BackendDriver), WsClientError> {
let ws = connect_async(details).await?.0.fuse();
Ok(Self::new(ws))
}
pub fn new(server: InternalStream) -> (Self, BackendDriver) {
let (handler, to_handle) = mpsc::unbounded();
let (dispatcher, to_dispatch) = mpsc::unbounded();
let (error_tx, error_rx) = oneshot::channel();
let (shutdown_tx, shutdown_rx) = oneshot::channel();
(
WsBackend { server, handler, error: error_tx, to_dispatch, shutdown: shutdown_rx },
BackendDriver { to_handle, error: error_rx, dispatcher, shutdown: shutdown_tx },
)
}
pub async fn handle_text(&mut self, t: String) -> Result<(), WsClientError> {
trace!(text = t, "Received message");
match serde_json::from_str(&t) {
Ok(item) => {
trace!(%item, "Deserialized message");
let res = self.handler.unbounded_send(item);
if res.is_err() {
return Err(WsClientError::DeadChannel)
}
}
Err(e) => {
error!(e = %e, "Failed to deserialize message");
}
}
Ok(())
}
#[cfg(not(target_arch = "wasm32"))]
async fn handle(&mut self, item: WsStreamItem) -> Result<(), WsClientError> {
match item {
Ok(item) => match item {
Message::Text(t) => self.handle_text(t).await,
Message::Ping(data) => {
if self.server.send(Message::Pong(data)).await.is_err() {
return Err(WsClientError::UnexpectedClose)
}
Ok(())
}
Message::Pong(_) => Ok(()),
Message::Frame(_) => Ok(()),
Message::Binary(buf) => Err(WsClientError::UnexpectedBinary(buf)),
Message::Close(frame) => {
if frame.is_some() {
error!("Close frame: {}", frame.unwrap());
}
Err(WsClientError::UnexpectedClose)
}
},
Err(e) => {
error!(err = %e, "Error response from WS");
Err(e.into())
}
}
}
#[cfg(target_arch = "wasm32")]
async fn handle(&mut self, item: WsStreamItem) -> Result<(), WsClientError> {
match item {
Message::Text(inner) => self.handle_text(inner).await,
Message::Binary(buf) => Err(WsClientError::UnexpectedBinary(buf)),
}
}
pub fn spawn(mut self) {
let fut = async move {
let mut err = false;
loop {
#[cfg(not(target_arch = "wasm32"))]
let keepalive = tokio::time::sleep(std::time::Duration::from_secs(10)).fuse();
#[cfg(not(target_arch = "wasm32"))]
tokio::pin!(keepalive);
// in wasm, we don't ping. as ping doesn't exist in our wasm lib
#[cfg(target_arch = "wasm32")]
let mut keepalive = futures_util::future::pending::<()>().fuse();
select! {
_ = keepalive => {
#[cfg(not(target_arch = "wasm32"))]
if let Err(e) = self.server.send(Message::Ping(vec![])).await {
error!(err = %e, "WS connection error");
err = true;
break
}
#[cfg(target_arch = "wasm32")]
unreachable!();
}
resp = self.server.next() => {
match resp {
Some(item) => {
err = self.handle(item).await.is_err();
if err { break }
},
None => {
error!("WS server has gone away");
err = true;
break
},
}
}
// we've received a new dispatch, so we send it via
// websocket
inst = self.to_dispatch.next() => {
match inst {
Some(msg) => {
if let Err(e) = self.server.send(Message::Text(msg.to_string())).await {
error!(err = %e, "WS connection error");
err = true;
break
}
},
// dispatcher has gone away
None => {
break
},
}
},
// break on shutdown recv, or on shutdown recv error
_ = &mut self.shutdown => {
break
},
}
}
if err {
let _ = self.error.send(());
}
};
#[cfg(target_arch = "wasm32")]
super::spawn_local(fut);
#[cfg(not(target_arch = "wasm32"))]
tokio::spawn(fut);
}
}

View File

@ -0,0 +1,63 @@
use ethers_core::types::U256;
use crate::{JsonRpcError, ProviderError};
use super::WsError;
#[derive(Debug, thiserror::Error)]
pub enum WsClientError {
/// Thrown if deserialization failed
#[error(transparent)]
JsonError(#[from] serde_json::Error),
/// Thrown if the response could not be parsed
#[error(transparent)]
JsonRpcError(#[from] JsonRpcError),
/// Internal lib error
#[error(transparent)]
InternalError(#[from] WsError),
/// Remote server sent a Close message
#[error("Websocket closed unexpectedly")]
UnexpectedClose,
/// Unexpected channel closure
#[error("Unexpected internal channel closure. This is likely a bug. Please report via github")]
DeadChannel,
/// Thrown if the websocket responds with binary data
#[error("Websocket responded with unexpected binary data")]
UnexpectedBinary(Vec<u8>),
/// PubSubClient asked to listen to an unknown subscription id
#[error("Attempted to listen to unknown subscription: {0:?}")]
UnknownSubscription(U256),
/// Too Many Reconnects
#[error("Reconnect limit reached")]
TooManyReconnects,
}
impl crate::RpcError for WsClientError {
fn as_error_response(&self) -> Option<&JsonRpcError> {
if let WsClientError::JsonRpcError(err) = self {
Some(err)
} else {
None
}
}
fn as_serde_error(&self) -> Option<&serde_json::Error> {
match self {
WsClientError::JsonError(err) => Some(err),
_ => None,
}
}
}
impl From<WsClientError> for ProviderError {
fn from(src: WsClientError) -> Self {
ProviderError::JsonRpcClientError(Box::new(src))
}
}

View File

@ -0,0 +1,409 @@
use std::{
collections::{BTreeMap, HashMap},
sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
},
};
use ethers_core::types::U256;
use futures_channel::{mpsc, oneshot};
use futures_util::{select_biased, StreamExt};
use serde_json::value::RawValue;
use crate::JsonRpcError;
use super::{
backend::{BackendDriver, WsBackend},
ActiveSub, ConnectionDetails, InFlight, Instruction, Notification, PubSubItem, Response, SubId,
WsClient, WsClientError,
};
pub type SharedChannelMap = Arc<Mutex<HashMap<U256, mpsc::UnboundedReceiver<Box<RawValue>>>>>;
pub const DEFAULT_RECONNECTS: usize = 5;
/// This struct manages the relationship between the u64 request ID, and U256
/// server-side subscription ID. It does this by aliasing the server ID to the
/// request ID, and returning the Request ID to the caller (hiding the server
/// ID in the SubscriptionManager internals.) Giving the caller a "fake"
/// subscription id allows the subscription to behave consistently across
/// reconnections
pub struct SubscriptionManager {
// Active subs indexed by request id
subs: BTreeMap<u64, ActiveSub>,
// Maps active server-side IDs to local subscription IDs
aliases: HashMap<U256, u64>,
// Used to share notification channels with the WsClient(s)
channel_map: SharedChannelMap,
}
impl SubscriptionManager {
fn new(channel_map: SharedChannelMap) -> Self {
Self { subs: Default::default(), aliases: Default::default(), channel_map }
}
fn count(&self) -> usize {
self.subs.len()
}
fn add_alias(&mut self, sub: U256, id: u64) {
if let Some(entry) = self.subs.get_mut(&id) {
entry.current_server_id = Some(sub);
}
self.aliases.insert(sub, id);
}
fn remove_alias(&mut self, server_id: U256) {
if let Some(id) = self.aliases.get(&server_id) {
if let Some(sub) = self.subs.get_mut(id) {
sub.current_server_id = None;
}
}
self.aliases.remove(&server_id);
}
#[tracing::instrument(skip(self))]
fn end_subscription(&mut self, id: u64) -> Option<Box<RawValue>> {
if let Some(sub) = self.subs.remove(&id) {
if let Some(server_id) = sub.current_server_id {
tracing::debug!(server_id = format!("0x{server_id:x}"), "Ending subscription");
self.remove_alias(server_id);
// drop the receiver as we don't need the result
let (channel, _) = oneshot::channel();
// Serialization errors are ignored, and result in the request
// not being dispatched. This is fine, as worst case it will
// result in the server sending us notifications we ignore
let unsub_request = InFlight {
method: "eth_unsubscribe".to_string(),
params: SubId(server_id).serialize_raw().ok()?,
channel,
};
// reuse the RPC ID. this is somewhat dirty.
return unsub_request.serialize_raw(id).ok()
}
tracing::trace!("No current server id");
}
tracing::trace!("Cannot end unknown subscription");
None
}
#[tracing::instrument(skip_all, fields(server_id = ?notification.subscription))]
fn handle_notification(&mut self, notification: Notification) {
let server_id = notification.subscription;
// If no alias, just return
let id_opt = self.aliases.get(&server_id).copied();
if id_opt.is_none() {
tracing::debug!(
server_id = format!("0x{server_id:x}"),
"No aliased subscription found"
);
return
}
let id = id_opt.unwrap();
// alias exists, or should be dropped from alias table
let sub_opt = self.subs.get(&id);
if sub_opt.is_none() {
tracing::trace!(id, "Aliased subscription found, but not active");
self.aliases.remove(&server_id);
}
let active = sub_opt.unwrap();
tracing::debug!(id, "Forwarding notification to listener");
// send the notification over the channel
let send_res = active.channel.unbounded_send(notification.result);
// receiver has dropped, so we drop the sub
if send_res.is_err() {
tracing::debug!(id, "Listener dropped. Dropping alias and subs");
// TODO: end subcription here?
self.aliases.remove(&server_id);
self.subs.remove(&id);
}
}
fn req_success(&mut self, id: u64, result: Box<RawValue>) -> Box<RawValue> {
if let Ok(server_id) = serde_json::from_str::<SubId>(result.get()) {
tracing::debug!(id, server_id = %server_id.0, "Registering new sub alias");
self.add_alias(server_id.0, id);
let result = U256::from(id);
RawValue::from_string(format!("\"0x{result:x}\"")).unwrap()
} else {
result
}
}
fn has(&self, id: u64) -> bool {
self.subs.contains_key(&id)
}
fn to_reissue(&self) -> impl Iterator<Item = (&u64, &ActiveSub)> {
self.subs.iter()
}
fn service_subscription_request(
&mut self,
id: u64,
params: Box<RawValue>,
) -> Result<Box<RawValue>, WsClientError> {
let (tx, rx) = mpsc::unbounded();
let active_sub = ActiveSub { params, channel: tx, current_server_id: None };
let req = active_sub.serialize_raw(id)?;
// Explicit scope for the lock
// This insertion should be made BEFORE the request returns.
// So we make it before the request is even dispatched :)
{
self.channel_map.lock().unwrap().insert(id.into(), rx);
}
self.subs.insert(id, active_sub);
Ok(req)
}
}
/// The `RequestManager` holds copies of all pending requests (as `InFlight`),
/// and active subscriptions (as `ActiveSub`). When reconnection occurs, all
/// pending requests are re-dispatched to the new backend, and all active subs
/// are re-subscribed
///
/// `RequestManager` holds a `BackendDriver`, to communicate with the current
/// backend. Reconnection is accomplished by instantiating a new `WsBackend` and
/// swapping out the manager's `BackendDriver`.
///
/// In order to provide continuity of subscription IDs to the client, the
/// `RequestManager` also keeps a `SubscriptionManager`. See the
/// `SubscriptionManager` docstring for more complete details
///
/// The behavior is accessed by the WsClient frontend, which implements ]
/// `JsonRpcClient`. The `WsClient` is cloneable, so no need for an arc :). It
/// communicates to the request manager via a channel, and receives
/// notifications in a shared map for the client to retrieve
///
/// The `RequestManager` shuts down and drops when all `WsClient` instances have
/// been dropped (because all instruction channel `UnboundedSender` instances
/// will have dropped).
pub struct RequestManager {
// Next JSON-RPC Request ID
id: AtomicU64,
// How many times we should reconnect the backend before erroring
reconnects: usize,
// Subscription manager
subs: SubscriptionManager,
// Requests for which a response has not been receivedc
reqs: BTreeMap<u64, InFlight>,
// Control of the active WS backend
backend: BackendDriver,
// The URL and optional auth info for the connection
conn: ConnectionDetails,
// Instructions from the user-facing providers
instructions: mpsc::UnboundedReceiver<Instruction>,
}
impl RequestManager {
fn next_id(&mut self) -> u64 {
self.id.fetch_add(1, Ordering::Relaxed)
}
pub async fn connect(conn: ConnectionDetails) -> Result<(Self, WsClient), WsClientError> {
Self::connect_with_reconnects(conn, DEFAULT_RECONNECTS).await
}
pub async fn connect_with_reconnects(
conn: ConnectionDetails,
reconnects: usize,
) -> Result<(Self, WsClient), WsClientError> {
let (ws, backend) = WsBackend::connect(conn.clone()).await?;
let (instructions_tx, instructions_rx) = mpsc::unbounded();
let channel_map: SharedChannelMap = Default::default();
ws.spawn();
Ok((
Self {
id: Default::default(),
reconnects,
subs: SubscriptionManager::new(channel_map.clone()),
reqs: Default::default(),
backend,
conn,
instructions: instructions_rx,
},
WsClient { instructions: instructions_tx, channel_map },
))
}
async fn reconnect(&mut self) -> Result<(), WsClientError> {
if self.reconnects == 0 {
return Err(WsClientError::TooManyReconnects)
}
self.reconnects -= 1;
tracing::info!(remaining = self.reconnects, url = self.conn.url, "Reconnecting to backend");
// create the new backend
let (s, mut backend) = WsBackend::connect(self.conn.clone()).await?;
// spawn the new backend
s.spawn();
// swap out the backend
std::mem::swap(&mut self.backend, &mut backend);
// rename for clarity
let mut old_backend = backend;
// Drain anything in the backend
tracing::debug!("Draining old backend to_handle channel");
while let Some(to_handle) = old_backend.to_handle.next().await {
self.handle(to_handle);
}
// issue a shutdown command (even though it's likely gone)
old_backend.shutdown();
tracing::debug!(count = self.subs.count(), "Re-starting active subscriptions");
// reissue subscriptionps
for (id, sub) in self.subs.to_reissue() {
self.backend
.dispatcher
.unbounded_send(sub.serialize_raw(*id)?)
.map_err(|_| WsClientError::DeadChannel)?;
}
tracing::debug!(count = self.reqs.len(), "Re-issuing pending requests");
// reissue requests. We filter these to prevent in-flight requests for
// subscriptions to be re-issued twice (once in above loop, once in this loop).
for (id, req) in self.reqs.iter().filter(|(id, _)| !self.subs.has(**id)) {
self.backend
.dispatcher
.unbounded_send(req.serialize_raw(*id)?)
.map_err(|_| WsClientError::DeadChannel)?;
}
tracing::info!(subs = self.subs.count(), reqs = self.reqs.len(), "Re-connection complete");
Ok(())
}
#[tracing::instrument(skip(self, result))]
fn req_success(&mut self, id: u64, result: Box<RawValue>) {
// pending fut is missing, this is fine
tracing::trace!(%result, "Success response received");
if let Some(req) = self.reqs.remove(&id) {
tracing::debug!("Sending result to request listener");
// Allow subscription manager to rewrite the result if the request
// corresponds to a known ID
let result = if self.subs.has(id) { self.subs.req_success(id, result) } else { result };
let _ = req.channel.send(Ok(result));
} else {
tracing::trace!("No InFlight found");
}
}
fn req_fail(&mut self, id: u64, error: JsonRpcError) {
// pending fut is missing, this is fine
if let Some(req) = self.reqs.remove(&id) {
// pending fut has been dropped, this is fine
let _ = req.channel.send(Err(error));
}
}
fn handle(&mut self, item: PubSubItem) {
match item {
PubSubItem::Success { id, result } => self.req_success(id, result),
PubSubItem::Error { id, error } => self.req_fail(id, error),
PubSubItem::Notification { params } => self.subs.handle_notification(params),
}
}
#[tracing::instrument(skip(self, params, sender))]
fn service_request(
&mut self,
id: u64,
method: String,
params: Box<RawValue>,
sender: oneshot::Sender<Response>,
) -> Result<(), WsClientError> {
let in_flight = InFlight { method, params, channel: sender };
let req = in_flight.serialize_raw(id)?;
// Ordering matters here. We want this block above the unbounded send,
// and after the serialization
if in_flight.method == "eth_subscribe" {
self.subs.service_subscription_request(id, in_flight.params.clone())?;
}
// Must come after self.subs.service_subscription_request. Do not re-order
tracing::debug!("Dispatching request to backend");
self.backend.dispatcher.unbounded_send(req).map_err(|_| WsClientError::DeadChannel)?;
self.reqs.insert(id, in_flight);
Ok(())
}
fn service_instruction(&mut self, instruction: Instruction) -> Result<(), WsClientError> {
match instruction {
Instruction::Request { method, params, sender } => {
let id = self.next_id();
self.service_request(id, method, params, sender)?;
}
Instruction::Unsubscribe { id } => {
if let Some(req) = self.subs.end_subscription(id.low_u64()) {
self.backend
.dispatcher
.unbounded_send(req)
.map_err(|_| WsClientError::DeadChannel)?;
}
}
}
Ok(())
}
pub fn spawn(mut self) {
let fut = async move {
let result = loop {
// We bias the loop so that we always handle messages before
// reconnecting, and always reconnect before dispatching new
// requests
select_biased! {
item_opt = self.backend.to_handle.next() => {
match item_opt {
Some(item) => self.handle(item),
// Backend is gone, so reconnect
None => if let Err(e) = self.reconnect().await {
break Err(e);
}
}
},
_ = &mut self.backend.error => {
if let Err(e) = self.reconnect().await {
break Err(e);
}
},
inst_opt = self.instructions.next() => {
match inst_opt {
Some(instruction) => if let Err(e) = self.service_instruction(instruction) { break Err(e)},
// User-facing side is gone, so just exit
None => break Ok(()),
}
}
}
};
if let Err(err) = result {
tracing::error!(%err, "Error during reconnection");
}
// Issue the shutdown command. we don't care if it is received
self.backend.shutdown();
};
#[cfg(target_arch = "wasm32")]
super::spawn_local(fut);
#[cfg(not(target_arch = "wasm32"))]
tokio::spawn(fut);
}
}

View File

@ -0,0 +1,145 @@
#![allow(missing_docs)]
mod backend;
mod manager;
use manager::{RequestManager, SharedChannelMap};
mod types;
pub use types::ConnectionDetails;
pub(self) use types::*;
mod error;
pub use error::*;
use async_trait::async_trait;
use ethers_core::types::U256;
use futures_channel::{mpsc, oneshot};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::value::RawValue;
use crate::{JsonRpcClient, ProviderError, PubsubClient};
#[cfg(not(target_arch = "wasm32"))]
use crate::Authorization;
#[derive(Debug, Clone)]
pub struct WsClient {
// Used to send instructions to the `RequestManager`
instructions: mpsc::UnboundedSender<Instruction>,
// Used to receive sub notifications channels with the backend
channel_map: SharedChannelMap,
}
impl WsClient {
pub async fn connect(conn: impl Into<ConnectionDetails>) -> Result<Self, WsClientError> {
let (man, this) = RequestManager::connect(conn.into()).await?;
man.spawn();
Ok(this)
}
pub async fn connect_with_reconnects(
conn: impl Into<ConnectionDetails>,
reconnects: usize,
) -> Result<Self, WsClientError> {
let (man, this) = RequestManager::connect_with_reconnects(conn.into(), reconnects).await?;
man.spawn();
Ok(this)
}
#[tracing::instrument(skip(self, params), err)]
async fn make_request<R>(&self, method: &str, params: Box<RawValue>) -> Result<R, WsClientError>
where
R: DeserializeOwned,
{
let (tx, rx) = oneshot::channel();
let instruction = Instruction::Request { method: method.to_owned(), params, sender: tx };
self.instructions
.unbounded_send(instruction)
.map_err(|_| WsClientError::UnexpectedClose)?;
let res = rx.await.map_err(|_| WsClientError::UnexpectedClose)??;
tracing::trace!(res = %res, "Received response from request manager");
let resp = serde_json::from_str(res.get())?;
tracing::trace!("Deserialization success");
Ok(resp)
}
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl JsonRpcClient for WsClient {
type Error = WsClientError;
async fn request<T, R>(&self, method: &str, params: T) -> Result<R, WsClientError>
where
T: Serialize + Send + Sync,
R: DeserializeOwned,
{
let params = serde_json::to_string(&params)?;
let params = RawValue::from_string(params)?;
let res = self.make_request(method, params).await?;
Ok(res)
}
}
impl PubsubClient for WsClient {
type NotificationStream = mpsc::UnboundedReceiver<Box<RawValue>>;
fn subscribe<T: Into<U256>>(&self, id: T) -> Result<Self::NotificationStream, WsClientError> {
// due to the behavior of the request manager, we know this map has
// been populated by the time the `request()` call returns
let id = id.into();
self.channel_map.lock().unwrap().remove(&id).ok_or(WsClientError::UnknownSubscription(id))
}
fn unsubscribe<T: Into<U256>>(&self, id: T) -> Result<(), WsClientError> {
self.instructions
.unbounded_send(Instruction::Unsubscribe { id: id.into() })
.map_err(|_| WsClientError::UnexpectedClose)
}
}
impl crate::Provider<WsClient> {
/// Direct connection to a websocket endpoint. Defaults to 5 reconnects.
pub async fn connect(url: impl Into<ConnectionDetails>) -> Result<Self, ProviderError> {
let ws = crate::Ws::connect(url).await?;
Ok(Self::new(ws))
}
/// Direct connection to a websocket endpoint, with a set number of
/// reconnection attempts
pub async fn connect_with_reconnects(
url: impl Into<ConnectionDetails>,
reconnects: usize,
) -> Result<Self, ProviderError> {
let ws = crate::Ws::connect_with_reconnects(url, reconnects).await?;
Ok(Self::new(ws))
}
/// Connect to a WS RPC provider with authentication details
#[cfg(not(target_arch = "wasm32"))]
pub async fn connect_with_auth(
url: impl AsRef<str>,
auth: Authorization,
) -> Result<Self, ProviderError> {
let conn = ConnectionDetails::new(url, Some(auth));
let ws = crate::Ws::connect(conn).await?;
Ok(Self::new(ws))
}
#[cfg(not(target_arch = "wasm32"))]
/// Connect to a WS RPC provider with authentication details and a set
/// number of reconnection attempts
pub async fn connect_with_auth_and_reconnects(
url: impl AsRef<str>,
auth: Authorization,
reconnects: usize,
) -> Result<Self, ProviderError> {
let conn = ConnectionDetails::new(url, Some(auth));
let ws = crate::Ws::connect_with_reconnects(conn, reconnects).await?;
Ok(Self::new(ws))
}
}

View File

@ -0,0 +1,308 @@
use std::fmt;
use ethers_core::types::U256;
use futures_channel::{mpsc, oneshot};
use serde::{de, Deserialize};
use serde_json::value::RawValue;
use crate::{common::Request, JsonRpcError};
// Normal JSON-RPC response
pub type Response = Result<Box<RawValue>, JsonRpcError>;
#[derive(serde::Deserialize, serde::Serialize)]
pub struct SubId(pub U256);
impl SubId {
pub(super) fn serialize_raw(&self) -> Result<Box<RawValue>, serde_json::Error> {
let s = serde_json::to_string(&self)?;
RawValue::from_string(s)
}
}
#[derive(Deserialize, Debug, Clone)]
pub struct Notification {
pub subscription: U256,
pub result: Box<RawValue>,
}
#[derive(Debug, Clone)]
pub enum PubSubItem {
Success { id: u64, result: Box<RawValue> },
Error { id: u64, error: JsonRpcError },
Notification { params: Notification },
}
// FIXME: ideally, this could be auto-derived as an untagged enum, but due to
// https://github.com/serde-rs/serde/issues/1183 this currently fails
impl<'de> Deserialize<'de> for PubSubItem {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct ResponseVisitor;
impl<'de> de::Visitor<'de> for ResponseVisitor {
type Value = PubSubItem;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a valid jsonrpc 2.0 response object")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: de::MapAccess<'de>,
{
let mut jsonrpc = false;
// response & error
let mut id = None;
// only response
let mut result = None;
// only error
let mut error = None;
// only notification
let mut method = None;
let mut params = None;
while let Some(key) = map.next_key()? {
match key {
"jsonrpc" => {
if jsonrpc {
return Err(de::Error::duplicate_field("jsonrpc"))
}
let value = map.next_value()?;
if value != "2.0" {
return Err(de::Error::invalid_value(
de::Unexpected::Str(value),
&"2.0",
))
}
jsonrpc = true;
}
"id" => {
if id.is_some() {
return Err(de::Error::duplicate_field("id"))
}
let value: u64 = map.next_value()?;
id = Some(value);
}
"result" => {
if result.is_some() {
return Err(de::Error::duplicate_field("result"))
}
let value: Box<RawValue> = map.next_value()?;
result = Some(value);
}
"error" => {
if error.is_some() {
return Err(de::Error::duplicate_field("error"))
}
let value: JsonRpcError = map.next_value()?;
error = Some(value);
}
"method" => {
if method.is_some() {
return Err(de::Error::duplicate_field("method"))
}
let value: String = map.next_value()?;
method = Some(value);
}
"params" => {
if params.is_some() {
return Err(de::Error::duplicate_field("params"))
}
let value: Notification = map.next_value()?;
params = Some(value);
}
key => {
return Err(de::Error::unknown_field(
key,
&["id", "jsonrpc", "result", "error", "params", "method"],
))
}
}
}
// jsonrpc version must be present in all responses
if !jsonrpc {
return Err(de::Error::missing_field("jsonrpc"))
}
match (id, result, error, method, params) {
(Some(id), Some(result), None, None, None) => {
Ok(PubSubItem::Success { id, result })
}
(Some(id), None, Some(error), None, None) => {
Ok(PubSubItem::Error { id, error })
}
(None, None, None, Some(_), Some(params)) => {
Ok(PubSubItem::Notification { params })
}
_ => Err(de::Error::custom(
"response must be either a success/error or notification object",
)),
}
}
}
deserializer.deserialize_map(ResponseVisitor)
}
}
impl std::fmt::Display for PubSubItem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PubSubItem::Success { id, .. } => write!(f, "Req success. ID: {id}"),
PubSubItem::Error { id, .. } => write!(f, "Req error. ID: {id}"),
PubSubItem::Notification { params } => {
write!(f, "Notification for sub: {:?}", params.subscription)
}
}
}
}
#[derive(Debug, Clone)]
pub struct ConnectionDetails {
pub url: String,
#[cfg(not(target_arch = "wasm32"))]
pub auth: Option<crate::Authorization>,
}
impl ConnectionDetails {
#[cfg(not(target_arch = "wasm32"))]
pub fn new(url: impl AsRef<str>, auth: Option<crate::Authorization>) -> Self {
Self { url: url.as_ref().to_string(), auth }
}
#[cfg(target_arch = "wasm32")]
pub fn new(url: impl AsRef<str>) -> Self {
Self { url: url.as_ref().to_string() }
}
}
impl<T> From<T> for ConnectionDetails
where
T: AsRef<str>,
{
#[cfg(not(target_arch = "wasm32"))]
fn from(value: T) -> Self {
ConnectionDetails { url: value.as_ref().to_string(), auth: None }
}
#[cfg(target_arch = "wasm32")]
fn from(value: T) -> Self {
ConnectionDetails { url: value.as_ref().to_string() }
}
}
#[derive(Debug)]
pub(super) struct InFlight {
pub method: String,
pub params: Box<RawValue>,
pub channel: oneshot::Sender<Response>,
}
impl InFlight {
pub(super) fn to_request(&self, id: u64) -> Request<'_, Box<RawValue>> {
Request::new(id, &self.method, self.params.clone())
}
pub(super) fn serialize_raw(&self, id: u64) -> Result<Box<RawValue>, serde_json::Error> {
let s = serde_json::to_string(&self.to_request(id))?;
RawValue::from_string(s)
}
}
#[derive(Debug)]
pub(super) struct ActiveSub {
pub params: Box<RawValue>,
pub channel: mpsc::UnboundedSender<Box<RawValue>>,
pub current_server_id: Option<U256>,
}
impl ActiveSub {
pub(super) fn to_request(&self, id: u64) -> Request<'static, Box<RawValue>> {
Request::new(id, "eth_subscribe", self.params.clone())
}
pub(super) fn serialize_raw(&self, id: u64) -> Result<Box<RawValue>, serde_json::Error> {
let s = serde_json::to_string(&self.to_request(id))?;
RawValue::from_string(s)
}
}
/// Instructions for the `WsServer`.
pub enum Instruction {
/// JSON-RPC request
Request { method: String, params: Box<RawValue>, sender: oneshot::Sender<Response> },
/// Cancel an existing subscription
Unsubscribe { id: U256 },
}
#[cfg(target_arch = "wasm32")]
mod aliases {
pub use wasm_bindgen::prelude::*;
pub use wasm_bindgen_futures::spawn_local;
pub use ws_stream_wasm::*;
pub type Message = WsMessage;
pub type WsError = ws_stream_wasm::WsErr;
pub type WsStreamItem = Message;
pub type InternalStream = futures_util::stream::Fuse<WsStream>;
}
#[cfg(not(target_arch = "wasm32"))]
mod aliases {
pub use tokio_tungstenite::{
connect_async,
tungstenite::{self, protocol::CloseFrame},
};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
pub type Message = tungstenite::protocol::Message;
pub type WsError = tungstenite::Error;
pub type WsStreamItem = Result<Message, WsError>;
pub use http::Request as HttpRequest;
pub use tracing::{debug, error, trace, warn};
pub use tungstenite::client::IntoClientRequest;
pub use tokio::time::sleep;
pub type InternalStream =
futures_util::stream::Fuse<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>>;
impl IntoClientRequest for super::ConnectionDetails {
fn into_client_request(
self,
) -> tungstenite::Result<tungstenite::handshake::client::Request> {
let mut request: HttpRequest<()> = self.url.into_client_request()?;
if let Some(auth) = self.auth {
let mut auth_value = http::HeaderValue::from_str(&auth.to_string())?;
auth_value.set_sensitive(true);
request.headers_mut().insert(http::header::AUTHORIZATION, auth_value);
}
request.into_client_request()
}
}
}
pub use aliases::*;
#[cfg(test)]
mod test {
use super::*;
#[test]
fn it_desers_pubsub_items() {
let a = "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xcd0c3e8af590364c09d0fa6a1210faf5\"}";
serde_json::from_str::<PubSubItem>(a).unwrap();
}
}

View File

@ -1,10 +1,10 @@
use ethers_core::types::Filter;
use ethers_providers::{Middleware, Provider, StreamExt, Ws};
use ethers_providers::{Middleware, Provider, StreamExt};
use futures_util::SinkExt;
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::{
accept_async, connect_async,
accept_async,
tungstenite::{
self,
protocol::{frame::coding::CloseCode, CloseFrame},
@ -15,20 +15,6 @@ use tungstenite::protocol::Message;
const WS_ENDPOINT: &str = "127.0.0.1:9002";
#[tokio::test]
async fn graceful_disconnect_on_ws_errors() {
// Spawn a fake Ws server that will drop our connection after a while
spawn_ws_server().await;
// Connect to the fake server
let (ws, _) = connect_async(format!("ws://{WS_ENDPOINT}")).await.unwrap();
let provider = Provider::new(Ws::new(ws));
let filter = Filter::new().event("Transfer(address,address,uint256)");
let mut stream = provider.subscribe_logs(&filter).await.unwrap();
assert!(stream.next().await.is_none());
}
async fn spawn_ws_server() {
let listener = TcpListener::bind(&WS_ENDPOINT).await.expect("Can't listen");
tokio::spawn(async move {
@ -43,7 +29,7 @@ async fn handle_conn(stream: TcpStream) -> Result<(), Error> {
while ws_stream.next().await.is_some() {
let res: String =
"{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xcd0c3e8af590364c09d0fa6a1210faf5\"}"
"{\"jsonrpc\":\"2.0\",\"id\":0,\"result\":\"0xcd0c3e8af590364c09d0fa6a1210faf5\"}"
.into();
// Answer with a valid RPC response to keep the connection alive
@ -64,3 +50,17 @@ async fn handle_conn(stream: TcpStream) -> Result<(), Error> {
Ok(())
}
#[tokio::test]
async fn graceful_disconnect_on_ws_errors() {
// Spawn a fake Ws server that will drop our connection after a while
spawn_ws_server().await;
// Connect to the fake server
let provider =
Provider::connect_with_reconnects(format!("ws://{WS_ENDPOINT}"), 1).await.unwrap();
let filter = Filter::new().event("Transfer(address,address,uint256)");
let mut stream = provider.subscribe_logs(&filter).await.unwrap();
assert!(stream.next().await.is_none());
}

View File

@ -51,6 +51,7 @@ serde_json = { version = "1.0.64" }
yubihsm = { version = "0.41.0", features = ["secp256k1", "usb", "mockhsm"] }
tokio = { version = "1.18", default-features = false, features = ["macros", "rt"] }
tempfile = "3.4.0"
tracing-subscriber = "0.3.16"
[features]
futures = ["futures-util", "futures-executor"]