From 825a35c6860981fcfd69113ae36d30ed550ccf4e Mon Sep 17 00:00:00 2001
From: Chris <ccbrown112@gmail.com>
Date: Sat, 18 Jul 2020 17:00:17 -0400
Subject: [PATCH] Fix juniper_warp subscriptions (#707)

* use anyhow::anyhow for subscriptions mod

* remove unnecessary Clone

* fix simultaneous subscriptions

* rm unnecessary .clone
---
 juniper_warp/src/lib.rs | 57 ++++++++++++++++++++++++-----------------
 1 file changed, 33 insertions(+), 24 deletions(-)

diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs
index 158813c8..d07b1916 100644
--- a/juniper_warp/src/lib.rs
+++ b/juniper_warp/src/lib.rs
@@ -417,7 +417,8 @@ pub mod subscriptions {
         },
     };
 
-    use futures::{channel::mpsc, Future, StreamExt as _, TryStreamExt as _};
+    use anyhow::anyhow;
+    use futures::{channel::mpsc, Future, StreamExt as _, TryFutureExt as _, TryStreamExt as _};
     use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
     use juniper_subscriptions::Coordinator;
     use serde::{Deserialize, Serialize};
@@ -439,7 +440,7 @@ pub mod subscriptions {
         Mutation::TypeInfo: Send + Sync,
         Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
         Subscription::TypeInfo: Send + Sync,
-        CtxT: Clone + Send + Sync + 'static,
+        CtxT: Send + Sync + 'static,
         S: ScalarValue + Send + Sync + 'static,
     {
         let (sink_tx, sink_rx) = websocket.split();
@@ -452,25 +453,28 @@ pub mod subscriptions {
         );
 
         let context = Arc::new(context);
-        let running = Arc::new(AtomicBool::new(false));
         let got_close_signal = Arc::new(AtomicBool::new(false));
         let got_close_signal2 = got_close_signal.clone();
 
+        struct SubscriptionState {
+            should_stop: AtomicBool,
+        }
+        let subscription_states = HashMap::<String, Arc<SubscriptionState>>::new();
+
         sink_rx
             .map_err(move |e| {
                 got_close_signal2.store(true, Ordering::Relaxed);
                 anyhow!("Websocket error: {}", e)
             })
-            .try_fold((), move |_, msg| {
+            .try_fold(subscription_states, move |mut subscription_states, msg| {
                 let coordinator = coordinator.clone();
                 let context = context.clone();
-                let running = running.clone();
                 let got_close_signal = got_close_signal.clone();
                 let ws_tx = ws_tx.clone();
 
                 async move {
                     if msg.is_close() {
-                        return Ok(());
+                        return Ok(subscription_states);
                     }
 
                     let msg = msg
@@ -482,18 +486,20 @@ pub mod subscriptions {
                     match request.type_name.as_str() {
                         "connection_init" => {}
                         "start" => {
-                            {
-                                let closed = got_close_signal.load(Ordering::Relaxed);
-                                if closed {
-                                    return Ok(());
-                                }
-
-                                if running.load(Ordering::Relaxed) {
-                                    return Ok(());
-                                }
-                                running.store(true, Ordering::Relaxed);
+                            if got_close_signal.load(Ordering::Relaxed) {
+                                return Ok(subscription_states);
                             }
 
+                            let request_id = request.id.clone().unwrap_or("1".to_owned());
+
+                            if let Some(existing) = subscription_states.get(&request_id) {
+                                existing.should_stop.store(true, Ordering::Relaxed);
+                            }
+                            let state = Arc::new(SubscriptionState {
+                                should_stop: AtomicBool::new(false),
+                            });
+                            subscription_states.insert(request_id, state.clone());
+
                             let ws_tx = ws_tx.clone();
 
                             if let Some(ref payload) = request.payload {
@@ -507,8 +513,6 @@ pub mod subscriptions {
                             tokio::task::spawn(async move {
                                 let payload = request.payload.unwrap();
 
-                                let request_id = request.id.unwrap_or("1".to_owned());
-
                                 let graphql_request = GraphQLRequest::<S>::new(
                                     payload.query.unwrap(),
                                     None,
@@ -545,8 +549,9 @@ pub mod subscriptions {
                                 values_stream
                                     .take_while(move |response| {
                                         let request_id = request_id.clone();
-                                        let closed = got_close_signal.load(Ordering::Relaxed);
-                                        if !closed {
+                                        let should_stop = state.should_stop.load(Ordering::Relaxed)
+                                            || got_close_signal.load(Ordering::Relaxed);
+                                        if !should_stop {
                                             let mut response_text = serde_json::to_string(
                                                 &response,
                                             )
@@ -562,16 +567,19 @@ pub mod subscriptions {
                                             ))));
                                         }
 
-                                        async move { !closed }
+                                        async move { !should_stop }
                                     })
                                     .for_each(|_| async {})
                                     .await;
                             });
                         }
                         "stop" => {
-                            got_close_signal.store(true, Ordering::Relaxed);
-
                             let request_id = request.id.unwrap_or("1".to_owned());
+                            if let Some(existing) = subscription_states.get(&request_id) {
+                                existing.should_stop.store(true, Ordering::Relaxed);
+                                subscription_states.remove(&request_id);
+                            }
+
                             let close_message = format!(
                                 r#"{{"type":"complete","id":"{}","payload":null}}"#,
                                 request_id
@@ -584,9 +592,10 @@ pub mod subscriptions {
                         _ => {}
                     }
 
-                    Ok(())
+                    Ok(subscription_states)
                 }
             })
+            .map_ok(|_| ())
     }
 
     #[derive(Deserialize)]