From 8d7ba8295c6a37f9300b2fb7a099f4e6eb87af7a Mon Sep 17 00:00:00 2001
From: Mihai Dinculescu <mihai.dinculescu@outlook.com>
Date: Sun, 9 Aug 2020 23:19:34 +0100
Subject: [PATCH] Impl subscriptions for juniper_actix (#716)

* Impl subscriptions for juniper_actix

* Add random_human example subscription

* Add actix_subscriptions example to CI

* fixup! Add random_human example subscription

* Migrate actix subscriptions to juniper_graphql_ws

* Simplify error handling

* Change unwrap to expect

* Close connection on server serialization error

Co-authored-by: Christian Legnitto <LegNeato@users.noreply.github.com>
---
 Cargo.toml                                   |   1 +
 examples/actix_subscriptions/Cargo.toml      |  22 ++
 examples/actix_subscriptions/Makefile.toml   |  15 ++
 examples/actix_subscriptions/src/main.rs     | 140 ++++++++++
 juniper/src/tests/fixtures/starwars/model.rs |   4 +-
 juniper_actix/CHANGELOG.md                   |   2 +-
 juniper_actix/Cargo.toml                     |  29 ++-
 juniper_actix/src/lib.rs                     | 261 +++++++++++++++++++
 8 files changed, 461 insertions(+), 13 deletions(-)
 create mode 100644 examples/actix_subscriptions/Cargo.toml
 create mode 100644 examples/actix_subscriptions/Makefile.toml
 create mode 100644 examples/actix_subscriptions/src/main.rs

diff --git a/Cargo.toml b/Cargo.toml
index d37670be..ddb54b5b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,6 +7,7 @@ members = [
   "examples/basic_subscriptions",
   "examples/warp_async",
   "examples/warp_subscriptions",
+  "examples/actix_subscriptions",
   "integration_tests/juniper_tests",
   "integration_tests/async_await",
   "integration_tests/codegen_fail",
diff --git a/examples/actix_subscriptions/Cargo.toml b/examples/actix_subscriptions/Cargo.toml
new file mode 100644
index 00000000..6825e404
--- /dev/null
+++ b/examples/actix_subscriptions/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "actix_subscriptions"
+version = "0.1.0"
+authors = ["Mihai Dinculescu <mihai.dinculescu@outlook.com>"]
+edition = "2018"
+publish = false
+
+[dependencies]
+actix-web = "2.0.0"
+actix-rt = "1.1.1"
+actix-cors = "0.2.0"
+
+futures = "0.3.5"
+tokio = { version = "0.2", features = ["rt-core", "macros"] }
+env_logger = "0.7.1"
+serde = "1.0.114"
+serde_json = "1.0.57"
+rand = "0.7.3"
+
+juniper = { path = "../../juniper", features = ["expose-test-schema", "serde_json"] }
+juniper_actix = { path = "../../juniper_actix", features = ["subscriptions"] }
+juniper_graphql_ws = { path = "../../juniper_graphql_ws" }
diff --git a/examples/actix_subscriptions/Makefile.toml b/examples/actix_subscriptions/Makefile.toml
new file mode 100644
index 00000000..dccba933
--- /dev/null
+++ b/examples/actix_subscriptions/Makefile.toml
@@ -0,0 +1,15 @@
+[tasks.run]
+disabled = true
+
+[tasks.release]
+disabled = true
+[tasks.release-some]
+disabled = true
+[tasks.release-local-test]
+disabled = true
+[tasks.release-some-local-test]
+disabled = true
+[tasks.release-dry-run]
+disabled = true
+[tasks.release-some-dry-run]
+disabled = true
diff --git a/examples/actix_subscriptions/src/main.rs b/examples/actix_subscriptions/src/main.rs
new file mode 100644
index 00000000..e9f884ea
--- /dev/null
+++ b/examples/actix_subscriptions/src/main.rs
@@ -0,0 +1,140 @@
+use std::{env, pin::Pin, time::Duration};
+
+use actix_cors::Cors;
+use actix_web::{middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer};
+use futures::Stream;
+
+use juniper::{
+    tests::fixtures::starwars::{model::Database, schema::Query},
+    DefaultScalarValue, EmptyMutation, FieldError, RootNode,
+};
+use juniper_actix::{graphql_handler, playground_handler, subscriptions::subscriptions_handler};
+use juniper_graphql_ws::ConnectionConfig;
+
+type Schema = RootNode<'static, Query, EmptyMutation<Database>, Subscription>;
+
+fn schema() -> Schema {
+    Schema::new(Query, EmptyMutation::<Database>::new(), Subscription)
+}
+
+async fn playground() -> Result<HttpResponse, Error> {
+    playground_handler("/graphql", Some("/subscriptions")).await
+}
+
+async fn graphql(
+    req: actix_web::HttpRequest,
+    payload: actix_web::web::Payload,
+    schema: web::Data<Schema>,
+) -> Result<HttpResponse, Error> {
+    let context = Database::new();
+    graphql_handler(&schema, &context, req, payload).await
+}
+
+struct Subscription;
+
+struct RandomHuman {
+    id: String,
+    name: String,
+}
+
+// TODO: remove this when async interfaces are merged
+#[juniper::graphql_object(Context = Database)]
+impl RandomHuman {
+    fn id(&self) -> &str {
+        &self.id
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+#[juniper::graphql_subscription(Context = Database)]
+impl Subscription {
+    #[graphql(
+        description = "A random humanoid creature in the Star Wars universe every 3 seconds. Second result will be an error."
+    )]
+    async fn random_human(
+        context: &Database,
+    ) -> Pin<Box<dyn Stream<Item = Result<RandomHuman, FieldError>> + Send>> {
+        let mut counter = 0;
+
+        let context = (*context).clone();
+
+        use rand::{rngs::StdRng, Rng, SeedableRng};
+        let mut rng = StdRng::from_entropy();
+
+        let stream = tokio::time::interval(Duration::from_secs(3)).map(move |_| {
+            counter += 1;
+
+            if counter == 2 {
+                Err(FieldError::new(
+                    "some field error from handler",
+                    Value::Scalar(DefaultScalarValue::String(
+                        "some additional string".to_string(),
+                    )),
+                ))
+            } else {
+                let random_id = rng.gen_range(1000, 1005).to_string();
+                let human = context.get_human(&random_id).unwrap();
+
+                Ok(RandomHuman {
+                    id: human.id().to_owned(),
+                    name: human.name().to_owned(),
+                })
+            }
+        });
+
+        Box::pin(stream)
+    }
+}
+
+async fn subscriptions(
+    req: HttpRequest,
+    stream: web::Payload,
+    schema: web::Data<Schema>,
+) -> Result<HttpResponse, Error> {
+    let context = Database::new();
+    let schema = schema.into_inner();
+    let config = ConnectionConfig::new(context);
+    // set the keep alive interval to 15 secs so that it doesn't timeout in playground
+    // playground has a hard-coded timeout set to 20 secs
+    let config = config.with_keep_alive_interval(Duration::from_secs(15));
+
+    subscriptions_handler(req, stream, schema, config).await
+}
+
+#[actix_rt::main]
+async fn main() -> std::io::Result<()> {
+    env::set_var("RUST_LOG", "info");
+    env_logger::init();
+
+    HttpServer::new(move || {
+        App::new()
+            .data(schema())
+            .wrap(middleware::Compress::default())
+            .wrap(middleware::Logger::default())
+            .wrap(
+                Cors::new()
+                    .allowed_methods(vec!["POST", "GET"])
+                    .supports_credentials()
+                    .max_age(3600)
+                    .finish(),
+            )
+            .service(web::resource("/subscriptions").route(web::get().to(subscriptions)))
+            .service(
+                web::resource("/graphql")
+                    .route(web::post().to(graphql))
+                    .route(web::get().to(graphql)),
+            )
+            .service(web::resource("/playground").route(web::get().to(playground)))
+            .default_service(web::route().to(|| {
+                HttpResponse::Found()
+                    .header("location", "/playground")
+                    .finish()
+            }))
+    })
+    .bind(format!("{}:{}", "127.0.0.1", 8080))?
+    .run()
+    .await
+}
diff --git a/juniper/src/tests/fixtures/starwars/model.rs b/juniper/src/tests/fixtures/starwars/model.rs
index 354c0b31..ae250607 100644
--- a/juniper/src/tests/fixtures/starwars/model.rs
+++ b/juniper/src/tests/fixtures/starwars/model.rs
@@ -29,6 +29,7 @@ pub trait Droid: Character {
     fn primary_function(&self) -> &Option<String>;
 }
 
+#[derive(Clone)]
 struct HumanData {
     id: String,
     name: String,
@@ -38,6 +39,7 @@ struct HumanData {
     home_planet: Option<String>,
 }
 
+#[derive(Clone)]
 struct DroidData {
     id: String,
     name: String,
@@ -101,7 +103,7 @@ impl Droid for DroidData {
     }
 }
 
-#[derive(Default)]
+#[derive(Default, Clone)]
 pub struct Database {
     humans: HashMap<String, HumanData>,
     droids: HashMap<String, DroidData>,
diff --git a/juniper_actix/CHANGELOG.md b/juniper_actix/CHANGELOG.md
index 05232472..cf41ac35 100644
--- a/juniper_actix/CHANGELOG.md
+++ b/juniper_actix/CHANGELOG.md
@@ -1,3 +1,3 @@
 # master
-
+- Subscription support
 - Initial Release
diff --git a/juniper_actix/Cargo.toml b/juniper_actix/Cargo.toml
index ba24d40b..7f9be579 100644
--- a/juniper_actix/Cargo.toml
+++ b/juniper_actix/Cargo.toml
@@ -8,25 +8,32 @@ documentation = "https://docs.rs/juniper_actix"
 repository = "https://github.com/graphql-rust/juniper"
 edition = "2018"
 
+[features]
+subscriptions = ["juniper_graphql_ws"]
 
 [dependencies]
 actix = "0.9.0"
-actix-rt = "1.0.0"
+actix-rt = "1.1.1"
 actix-web = { version = "2.0.0", features = ["rustls"] }
 actix-web-actors = "2.0.0"
-futures = { version = "0.3.1", features = ["compat"] }
-juniper = { version = "0.14.2", path = "../juniper", default-features = false  }
+
+futures = { version = "0.3.5", features = ["compat"] }
 tokio = { version = "0.2", features = ["time"] }
-serde = { version = "1.0.75", features = ["derive"] }
-serde_json = "1.0.24"
+serde = { version = "1.0.114", features = ["derive"] }
+serde_json = "1.0.57"
 anyhow = "1.0"
 thiserror = "1.0"
 
+juniper = { version = "0.14.2", path = "../juniper", default-features = false  }
+juniper_graphql_ws = { path = "../juniper_graphql_ws", optional = true }
+
 [dev-dependencies]
-juniper = { version = "0.14.2", path = "../juniper", features = ["expose-test-schema", "serde_json"] }
-env_logger = "0.7.1"
-log = "0.4.3"
-tokio = { version = "0.2", features = ["rt-core", "macros", "blocking"] }
 actix-cors = "0.2.0"
-actix-identity = "0.2.0"
-bytes = "0.5.4"
+actix-identity = "0.2.1"
+
+bytes = "0.5.6"
+env_logger = "0.7.1"
+log = "0.4.11"
+tokio = { version = "0.2", features = ["rt-core", "macros", "blocking"] }
+
+juniper = { version = "0.14.2", path = "../juniper", features = ["expose-test-schema", "serde_json"] }
diff --git a/juniper_actix/src/lib.rs b/juniper_actix/src/lib.rs
index a3c5d00e..e6a2c28d 100644
--- a/juniper_actix/src/lib.rs
+++ b/juniper_actix/src/lib.rs
@@ -217,6 +217,267 @@ pub async fn playground_handler(
         .body(html))
 }
 
+/// `juniper_actix` subscriptions handler implementation.
+/// Cannot be merged to `juniper_actix` yet as GraphQL over WS[1]
+/// is not fully supported in current implementation.
+///
+/// *Note: this implementation is in an alpha state.*
+///
+/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
+#[cfg(feature = "subscriptions")]
+pub mod subscriptions {
+    use std::{fmt, sync::Arc};
+
+    use actix::prelude::*;
+    use actix::{Actor, StreamHandler};
+    use actix_web::http::header::{HeaderName, HeaderValue};
+    use actix_web::{web, HttpRequest, HttpResponse};
+    use actix_web_actors::ws;
+
+    use futures::SinkExt;
+    use tokio::sync::Mutex;
+
+    use juniper::futures::stream::{SplitSink, SplitStream, StreamExt};
+    use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue};
+    use juniper_graphql_ws::{ArcSchema, ClientMessage, Connection, Init, ServerMessage};
+
+    /// Serves the graphql-ws protocol over a WebSocket connection.
+    ///
+    /// The `init` argument is used to provide the context and additional configuration for
+    /// connections. This can be a `juniper_graphql_ws::ConnectionConfig` if the context and
+    /// configuration are already known, or it can be a closure that gets executed asynchronously
+    /// when the client sends the ConnectionInit message. Using a closure allows you to perform
+    /// authentication based on the parameters provided by the client.
+    pub async fn subscriptions_handler<Query, Mutation, Subscription, CtxT, S, I>(
+        req: HttpRequest,
+        stream: web::Payload,
+        root_node: Arc<RootNode<'static, Query, Mutation, Subscription, S>>,
+        init: I,
+    ) -> Result<HttpResponse, actix_web::Error>
+    where
+        Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+        Query::TypeInfo: Send + Sync,
+        Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+        Mutation::TypeInfo: Send + Sync,
+        Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
+        Subscription::TypeInfo: Send + Sync,
+        CtxT: Unpin + Send + Sync + 'static,
+        S: ScalarValue + Send + Sync + 'static,
+        I: Init<S, CtxT> + Send,
+    {
+        let (s_tx, s_rx) = Connection::new(ArcSchema(root_node), init).split::<Message>();
+
+        let mut resp = ws::start(
+            SubscriptionActor {
+                graphql_tx: Arc::new(Mutex::new(s_tx)),
+                graphql_rx: Arc::new(Mutex::new(s_rx)),
+            },
+            &req,
+            stream,
+        )?;
+
+        resp.headers_mut().insert(
+            HeaderName::from_static("sec-websocket-protocol"),
+            HeaderValue::from_static("graphql-ws"),
+        );
+
+        Ok(resp)
+    }
+
+    type ConnectionSplitSink<Query, Mutation, Subscription, CtxT, S, I> = Arc<
+        Mutex<SplitSink<Connection<ArcSchema<Query, Mutation, Subscription, CtxT, S>, I>, Message>>,
+    >;
+
+    type ConnectionSplitStream<Query, Mutation, Subscription, CtxT, S, I> =
+        Arc<Mutex<SplitStream<Connection<ArcSchema<Query, Mutation, Subscription, CtxT, S>, I>>>>;
+
+    /// Subscription Actor
+    /// coordinates messages between actix_web and juniper_graphql_ws
+    /// ws message -> actor -> juniper
+    /// juniper -> actor -> ws response
+    struct SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I>
+    where
+        Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+        Query::TypeInfo: Send + Sync,
+        Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+        Mutation::TypeInfo: Send + Sync,
+        Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
+        Subscription::TypeInfo: Send + Sync,
+        CtxT: Unpin + Send + Sync + 'static,
+        S: ScalarValue + Send + Sync + 'static,
+        I: Init<S, CtxT> + Send,
+    {
+        graphql_tx: ConnectionSplitSink<Query, Mutation, Subscription, CtxT, S, I>,
+        graphql_rx: ConnectionSplitStream<Query, Mutation, Subscription, CtxT, S, I>,
+    }
+
+    /// ws message -> actor -> juniper
+    impl<Query, Mutation, Subscription, CtxT, S, I>
+        StreamHandler<Result<ws::Message, ws::ProtocolError>>
+        for SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I>
+    where
+        Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+        Query::TypeInfo: Send + Sync,
+        Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+        Mutation::TypeInfo: Send + Sync,
+        Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
+        Subscription::TypeInfo: Send + Sync,
+        CtxT: Unpin + Send + Sync + 'static,
+        S: ScalarValue + Send + Sync + 'static,
+        I: Init<S, CtxT> + Send,
+    {
+        fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
+            let msg = msg.map(|r| Message(r));
+
+            match msg {
+                Ok(msg) => {
+                    let tx = self.graphql_tx.clone();
+
+                    async move {
+                        let mut tx = tx.lock().await;
+                        tx.send(msg)
+                            .await
+                            .expect("Infallible: this should not happen");
+                    }
+                    .into_actor(self)
+                    .wait(ctx);
+                }
+                Err(_) => {
+                    // TODO: trace
+                    // ignore the message if there's a transport error
+                }
+            }
+        }
+    }
+
+    /// juniper -> actor
+    impl<Query, Mutation, Subscription, CtxT, S, I> Actor
+        for SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I>
+    where
+        Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+        Query::TypeInfo: Send + Sync,
+        Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+        Mutation::TypeInfo: Send + Sync,
+        Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
+        Subscription::TypeInfo: Send + Sync,
+        CtxT: Unpin + Send + Sync + 'static,
+        S: ScalarValue + Send + Sync + 'static,
+        I: Init<S, CtxT> + Send,
+    {
+        type Context = ws::WebsocketContext<Self>;
+
+        fn started(&mut self, ctx: &mut Self::Context) {
+            let stream = self.graphql_rx.clone();
+            let addr = ctx.address();
+
+            let fut = async move {
+                let mut stream = stream.lock().await;
+                while let Some(message) = stream.next().await {
+                    // sending the message to self so that it can be forwarded back to the client
+                    addr.do_send(ServerMessageWrapper { message });
+                }
+            }
+            .into_actor(self);
+
+            // TODO: trace
+            ctx.spawn(fut);
+        }
+
+        fn stopped(&mut self, _: &mut Self::Context) {
+            // TODO: trace
+        }
+    }
+
+    /// actor -> websocket response
+    impl<Query, Mutation, Subscription, CtxT, S, I> Handler<ServerMessageWrapper<S>>
+        for SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I>
+    where
+        Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+        Query::TypeInfo: Send + Sync,
+        Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+        Mutation::TypeInfo: Send + Sync,
+        Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
+        Subscription::TypeInfo: Send + Sync,
+        CtxT: Unpin + Send + Sync + 'static,
+        S: ScalarValue + Send + Sync + 'static,
+        I: Init<S, CtxT> + Send,
+    {
+        type Result = ();
+
+        fn handle(
+            &mut self,
+            msg: ServerMessageWrapper<S>,
+            ctx: &mut ws::WebsocketContext<Self>,
+        ) -> Self::Result {
+            let msg = serde_json::to_string(&msg.message);
+
+            match msg {
+                Ok(msg) => {
+                    ctx.text(msg);
+                }
+                Err(e) => {
+                    let reason = ws::CloseReason {
+                        code: ws::CloseCode::Error,
+                        description: Some(format!("error serializing response: {}", e)),
+                    };
+
+                    // TODO: trace
+                    ctx.close(Some(reason));
+                }
+            }
+        }
+    }
+
+    #[derive(Message)]
+    #[rtype(result = "()")]
+    struct ServerMessageWrapper<S>
+    where
+        S: ScalarValue + Send + Sync + 'static,
+    {
+        message: ServerMessage<S>,
+    }
+
+    #[derive(Debug)]
+    struct Message(ws::Message);
+
+    impl<S: ScalarValue> std::convert::TryFrom<Message> for ClientMessage<S> {
+        type Error = Error;
+
+        fn try_from(msg: Message) -> Result<Self, Self::Error> {
+            match msg.0 {
+                ws::Message::Text(text) => {
+                    serde_json::from_slice(text.as_bytes()).map_err(|e| Error::Serde(e))
+                }
+                ws::Message::Close(_) => Ok(ClientMessage::ConnectionTerminate),
+                _ => Err(Error::UnexpectedClientMessage),
+            }
+        }
+    }
+
+    /// Errors that can happen while handling client messages
+    #[derive(Debug)]
+    enum Error {
+        /// Errors that can happen while deserializing client messages
+        Serde(serde_json::Error),
+
+        /// Error for unexpected client messages
+        UnexpectedClientMessage,
+    }
+
+    impl fmt::Display for Error {
+        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+            match self {
+                Self::Serde(e) => write!(f, "serde error: {}", e),
+                Self::UnexpectedClientMessage => {
+                    write!(f, "unexpected message received from client")
+                }
+            }
+        }
+    }
+
+    impl std::error::Error for Error {}
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;