From 825a35c6860981fcfd69113ae36d30ed550ccf4e Mon Sep 17 00:00:00 2001 From: Chris <ccbrown112@gmail.com> Date: Sat, 18 Jul 2020 17:00:17 -0400 Subject: [PATCH] Fix juniper_warp subscriptions (#707) * use anyhow::anyhow for subscriptions mod * remove unnecessary Clone * fix simultaneous subscriptions * rm unnecessary .clone --- juniper_warp/src/lib.rs | 57 ++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs index 158813c8..d07b1916 100644 --- a/juniper_warp/src/lib.rs +++ b/juniper_warp/src/lib.rs @@ -417,7 +417,8 @@ pub mod subscriptions { }, }; - use futures::{channel::mpsc, Future, StreamExt as _, TryStreamExt as _}; + use anyhow::anyhow; + use futures::{channel::mpsc, Future, StreamExt as _, TryFutureExt as _, TryStreamExt as _}; use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _}; use juniper_subscriptions::Coordinator; use serde::{Deserialize, Serialize}; @@ -439,7 +440,7 @@ pub mod subscriptions { Mutation::TypeInfo: Send + Sync, Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, Subscription::TypeInfo: Send + Sync, - CtxT: Clone + Send + Sync + 'static, + CtxT: Send + Sync + 'static, S: ScalarValue + Send + Sync + 'static, { let (sink_tx, sink_rx) = websocket.split(); @@ -452,25 +453,28 @@ pub mod subscriptions { ); let context = Arc::new(context); - let running = Arc::new(AtomicBool::new(false)); let got_close_signal = Arc::new(AtomicBool::new(false)); let got_close_signal2 = got_close_signal.clone(); + struct SubscriptionState { + should_stop: AtomicBool, + } + let subscription_states = HashMap::<String, Arc<SubscriptionState>>::new(); + sink_rx .map_err(move |e| { got_close_signal2.store(true, Ordering::Relaxed); anyhow!("Websocket error: {}", e) }) - .try_fold((), move |_, msg| { + .try_fold(subscription_states, move |mut subscription_states, msg| { let coordinator = coordinator.clone(); let context = context.clone(); - let running = running.clone(); let got_close_signal = got_close_signal.clone(); let ws_tx = ws_tx.clone(); async move { if msg.is_close() { - return Ok(()); + return Ok(subscription_states); } let msg = msg @@ -482,18 +486,20 @@ pub mod subscriptions { match request.type_name.as_str() { "connection_init" => {} "start" => { - { - let closed = got_close_signal.load(Ordering::Relaxed); - if closed { - return Ok(()); - } - - if running.load(Ordering::Relaxed) { - return Ok(()); - } - running.store(true, Ordering::Relaxed); + if got_close_signal.load(Ordering::Relaxed) { + return Ok(subscription_states); } + let request_id = request.id.clone().unwrap_or("1".to_owned()); + + if let Some(existing) = subscription_states.get(&request_id) { + existing.should_stop.store(true, Ordering::Relaxed); + } + let state = Arc::new(SubscriptionState { + should_stop: AtomicBool::new(false), + }); + subscription_states.insert(request_id, state.clone()); + let ws_tx = ws_tx.clone(); if let Some(ref payload) = request.payload { @@ -507,8 +513,6 @@ pub mod subscriptions { tokio::task::spawn(async move { let payload = request.payload.unwrap(); - let request_id = request.id.unwrap_or("1".to_owned()); - let graphql_request = GraphQLRequest::<S>::new( payload.query.unwrap(), None, @@ -545,8 +549,9 @@ pub mod subscriptions { values_stream .take_while(move |response| { let request_id = request_id.clone(); - let closed = got_close_signal.load(Ordering::Relaxed); - if !closed { + let should_stop = state.should_stop.load(Ordering::Relaxed) + || got_close_signal.load(Ordering::Relaxed); + if !should_stop { let mut response_text = serde_json::to_string( &response, ) @@ -562,16 +567,19 @@ pub mod subscriptions { )))); } - async move { !closed } + async move { !should_stop } }) .for_each(|_| async {}) .await; }); } "stop" => { - got_close_signal.store(true, Ordering::Relaxed); - let request_id = request.id.unwrap_or("1".to_owned()); + if let Some(existing) = subscription_states.get(&request_id) { + existing.should_stop.store(true, Ordering::Relaxed); + subscription_states.remove(&request_id); + } + let close_message = format!( r#"{{"type":"complete","id":"{}","payload":null}}"#, request_id @@ -584,9 +592,10 @@ pub mod subscriptions { _ => {} } - Ok(()) + Ok(subscription_states) } }) + .map_ok(|_| ()) } #[derive(Deserialize)]