Fix juniper_warp subscriptions ()

* use anyhow::anyhow for subscriptions mod

* remove unnecessary Clone

* fix simultaneous subscriptions

* rm unnecessary .clone
This commit is contained in:
Chris 2020-07-18 17:00:17 -04:00 committed by GitHub
parent f5839c034e
commit 825a35c686
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -417,7 +417,8 @@ pub mod subscriptions {
}, },
}; };
use futures::{channel::mpsc, Future, StreamExt as _, TryStreamExt as _}; use anyhow::anyhow;
use futures::{channel::mpsc, Future, StreamExt as _, TryFutureExt as _, TryStreamExt as _};
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _}; use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
use juniper_subscriptions::Coordinator; use juniper_subscriptions::Coordinator;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -439,7 +440,7 @@ pub mod subscriptions {
Mutation::TypeInfo: Send + Sync, Mutation::TypeInfo: Send + Sync,
Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync, Subscription::TypeInfo: Send + Sync,
CtxT: Clone + Send + Sync + 'static, CtxT: Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static, S: ScalarValue + Send + Sync + 'static,
{ {
let (sink_tx, sink_rx) = websocket.split(); let (sink_tx, sink_rx) = websocket.split();
@ -452,25 +453,28 @@ pub mod subscriptions {
); );
let context = Arc::new(context); let context = Arc::new(context);
let running = Arc::new(AtomicBool::new(false));
let got_close_signal = Arc::new(AtomicBool::new(false)); let got_close_signal = Arc::new(AtomicBool::new(false));
let got_close_signal2 = got_close_signal.clone(); let got_close_signal2 = got_close_signal.clone();
struct SubscriptionState {
should_stop: AtomicBool,
}
let subscription_states = HashMap::<String, Arc<SubscriptionState>>::new();
sink_rx sink_rx
.map_err(move |e| { .map_err(move |e| {
got_close_signal2.store(true, Ordering::Relaxed); got_close_signal2.store(true, Ordering::Relaxed);
anyhow!("Websocket error: {}", e) anyhow!("Websocket error: {}", e)
}) })
.try_fold((), move |_, msg| { .try_fold(subscription_states, move |mut subscription_states, msg| {
let coordinator = coordinator.clone(); let coordinator = coordinator.clone();
let context = context.clone(); let context = context.clone();
let running = running.clone();
let got_close_signal = got_close_signal.clone(); let got_close_signal = got_close_signal.clone();
let ws_tx = ws_tx.clone(); let ws_tx = ws_tx.clone();
async move { async move {
if msg.is_close() { if msg.is_close() {
return Ok(()); return Ok(subscription_states);
} }
let msg = msg let msg = msg
@ -482,18 +486,20 @@ pub mod subscriptions {
match request.type_name.as_str() { match request.type_name.as_str() {
"connection_init" => {} "connection_init" => {}
"start" => { "start" => {
{ if got_close_signal.load(Ordering::Relaxed) {
let closed = got_close_signal.load(Ordering::Relaxed); return Ok(subscription_states);
if closed {
return Ok(());
}
if running.load(Ordering::Relaxed) {
return Ok(());
}
running.store(true, Ordering::Relaxed);
} }
let request_id = request.id.clone().unwrap_or("1".to_owned());
if let Some(existing) = subscription_states.get(&request_id) {
existing.should_stop.store(true, Ordering::Relaxed);
}
let state = Arc::new(SubscriptionState {
should_stop: AtomicBool::new(false),
});
subscription_states.insert(request_id, state.clone());
let ws_tx = ws_tx.clone(); let ws_tx = ws_tx.clone();
if let Some(ref payload) = request.payload { if let Some(ref payload) = request.payload {
@ -507,8 +513,6 @@ pub mod subscriptions {
tokio::task::spawn(async move { tokio::task::spawn(async move {
let payload = request.payload.unwrap(); let payload = request.payload.unwrap();
let request_id = request.id.unwrap_or("1".to_owned());
let graphql_request = GraphQLRequest::<S>::new( let graphql_request = GraphQLRequest::<S>::new(
payload.query.unwrap(), payload.query.unwrap(),
None, None,
@ -545,8 +549,9 @@ pub mod subscriptions {
values_stream values_stream
.take_while(move |response| { .take_while(move |response| {
let request_id = request_id.clone(); let request_id = request_id.clone();
let closed = got_close_signal.load(Ordering::Relaxed); let should_stop = state.should_stop.load(Ordering::Relaxed)
if !closed { || got_close_signal.load(Ordering::Relaxed);
if !should_stop {
let mut response_text = serde_json::to_string( let mut response_text = serde_json::to_string(
&response, &response,
) )
@ -562,16 +567,19 @@ pub mod subscriptions {
)))); ))));
} }
async move { !closed } async move { !should_stop }
}) })
.for_each(|_| async {}) .for_each(|_| async {})
.await; .await;
}); });
} }
"stop" => { "stop" => {
got_close_signal.store(true, Ordering::Relaxed);
let request_id = request.id.unwrap_or("1".to_owned()); let request_id = request.id.unwrap_or("1".to_owned());
if let Some(existing) = subscription_states.get(&request_id) {
existing.should_stop.store(true, Ordering::Relaxed);
subscription_states.remove(&request_id);
}
let close_message = format!( let close_message = format!(
r#"{{"type":"complete","id":"{}","payload":null}}"#, r#"{{"type":"complete","id":"{}","payload":null}}"#,
request_id request_id
@ -584,9 +592,10 @@ pub mod subscriptions {
_ => {} _ => {}
} }
Ok(()) Ok(subscription_states)
} }
}) })
.map_ok(|_| ())
} }
#[derive(Deserialize)] #[derive(Deserialize)]