From 8d7ba8295c6a37f9300b2fb7a099f4e6eb87af7a Mon Sep 17 00:00:00 2001 From: Mihai Dinculescu <mihai.dinculescu@outlook.com> Date: Sun, 9 Aug 2020 23:19:34 +0100 Subject: [PATCH] Impl subscriptions for juniper_actix (#716) * Impl subscriptions for juniper_actix * Add random_human example subscription * Add actix_subscriptions example to CI * fixup! Add random_human example subscription * Migrate actix subscriptions to juniper_graphql_ws * Simplify error handling * Change unwrap to expect * Close connection on server serialization error Co-authored-by: Christian Legnitto <LegNeato@users.noreply.github.com> --- Cargo.toml | 1 + examples/actix_subscriptions/Cargo.toml | 22 ++ examples/actix_subscriptions/Makefile.toml | 15 ++ examples/actix_subscriptions/src/main.rs | 140 ++++++++++ juniper/src/tests/fixtures/starwars/model.rs | 4 +- juniper_actix/CHANGELOG.md | 2 +- juniper_actix/Cargo.toml | 29 ++- juniper_actix/src/lib.rs | 261 +++++++++++++++++++ 8 files changed, 461 insertions(+), 13 deletions(-) create mode 100644 examples/actix_subscriptions/Cargo.toml create mode 100644 examples/actix_subscriptions/Makefile.toml create mode 100644 examples/actix_subscriptions/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index d37670be..ddb54b5b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "examples/basic_subscriptions", "examples/warp_async", "examples/warp_subscriptions", + "examples/actix_subscriptions", "integration_tests/juniper_tests", "integration_tests/async_await", "integration_tests/codegen_fail", diff --git a/examples/actix_subscriptions/Cargo.toml b/examples/actix_subscriptions/Cargo.toml new file mode 100644 index 00000000..6825e404 --- /dev/null +++ b/examples/actix_subscriptions/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "actix_subscriptions" +version = "0.1.0" +authors = ["Mihai Dinculescu <mihai.dinculescu@outlook.com>"] +edition = "2018" +publish = false + +[dependencies] +actix-web = "2.0.0" +actix-rt = "1.1.1" +actix-cors = "0.2.0" + +futures = "0.3.5" +tokio = { version = "0.2", features = ["rt-core", "macros"] } +env_logger = "0.7.1" +serde = "1.0.114" +serde_json = "1.0.57" +rand = "0.7.3" + +juniper = { path = "../../juniper", features = ["expose-test-schema", "serde_json"] } +juniper_actix = { path = "../../juniper_actix", features = ["subscriptions"] } +juniper_graphql_ws = { path = "../../juniper_graphql_ws" } diff --git a/examples/actix_subscriptions/Makefile.toml b/examples/actix_subscriptions/Makefile.toml new file mode 100644 index 00000000..dccba933 --- /dev/null +++ b/examples/actix_subscriptions/Makefile.toml @@ -0,0 +1,15 @@ +[tasks.run] +disabled = true + +[tasks.release] +disabled = true +[tasks.release-some] +disabled = true +[tasks.release-local-test] +disabled = true +[tasks.release-some-local-test] +disabled = true +[tasks.release-dry-run] +disabled = true +[tasks.release-some-dry-run] +disabled = true diff --git a/examples/actix_subscriptions/src/main.rs b/examples/actix_subscriptions/src/main.rs new file mode 100644 index 00000000..e9f884ea --- /dev/null +++ b/examples/actix_subscriptions/src/main.rs @@ -0,0 +1,140 @@ +use std::{env, pin::Pin, time::Duration}; + +use actix_cors::Cors; +use actix_web::{middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer}; +use futures::Stream; + +use juniper::{ + tests::fixtures::starwars::{model::Database, schema::Query}, + DefaultScalarValue, EmptyMutation, FieldError, RootNode, +}; +use juniper_actix::{graphql_handler, playground_handler, subscriptions::subscriptions_handler}; +use juniper_graphql_ws::ConnectionConfig; + +type Schema = RootNode<'static, Query, EmptyMutation<Database>, Subscription>; + +fn schema() -> Schema { + Schema::new(Query, EmptyMutation::<Database>::new(), Subscription) +} + +async fn playground() -> Result<HttpResponse, Error> { + playground_handler("/graphql", Some("/subscriptions")).await +} + +async fn graphql( + req: actix_web::HttpRequest, + payload: actix_web::web::Payload, + schema: web::Data<Schema>, +) -> Result<HttpResponse, Error> { + let context = Database::new(); + graphql_handler(&schema, &context, req, payload).await +} + +struct Subscription; + +struct RandomHuman { + id: String, + name: String, +} + +// TODO: remove this when async interfaces are merged +#[juniper::graphql_object(Context = Database)] +impl RandomHuman { + fn id(&self) -> &str { + &self.id + } + + fn name(&self) -> &str { + &self.name + } +} + +#[juniper::graphql_subscription(Context = Database)] +impl Subscription { + #[graphql( + description = "A random humanoid creature in the Star Wars universe every 3 seconds. Second result will be an error." + )] + async fn random_human( + context: &Database, + ) -> Pin<Box<dyn Stream<Item = Result<RandomHuman, FieldError>> + Send>> { + let mut counter = 0; + + let context = (*context).clone(); + + use rand::{rngs::StdRng, Rng, SeedableRng}; + let mut rng = StdRng::from_entropy(); + + let stream = tokio::time::interval(Duration::from_secs(3)).map(move |_| { + counter += 1; + + if counter == 2 { + Err(FieldError::new( + "some field error from handler", + Value::Scalar(DefaultScalarValue::String( + "some additional string".to_string(), + )), + )) + } else { + let random_id = rng.gen_range(1000, 1005).to_string(); + let human = context.get_human(&random_id).unwrap(); + + Ok(RandomHuman { + id: human.id().to_owned(), + name: human.name().to_owned(), + }) + } + }); + + Box::pin(stream) + } +} + +async fn subscriptions( + req: HttpRequest, + stream: web::Payload, + schema: web::Data<Schema>, +) -> Result<HttpResponse, Error> { + let context = Database::new(); + let schema = schema.into_inner(); + let config = ConnectionConfig::new(context); + // set the keep alive interval to 15 secs so that it doesn't timeout in playground + // playground has a hard-coded timeout set to 20 secs + let config = config.with_keep_alive_interval(Duration::from_secs(15)); + + subscriptions_handler(req, stream, schema, config).await +} + +#[actix_rt::main] +async fn main() -> std::io::Result<()> { + env::set_var("RUST_LOG", "info"); + env_logger::init(); + + HttpServer::new(move || { + App::new() + .data(schema()) + .wrap(middleware::Compress::default()) + .wrap(middleware::Logger::default()) + .wrap( + Cors::new() + .allowed_methods(vec!["POST", "GET"]) + .supports_credentials() + .max_age(3600) + .finish(), + ) + .service(web::resource("/subscriptions").route(web::get().to(subscriptions))) + .service( + web::resource("/graphql") + .route(web::post().to(graphql)) + .route(web::get().to(graphql)), + ) + .service(web::resource("/playground").route(web::get().to(playground))) + .default_service(web::route().to(|| { + HttpResponse::Found() + .header("location", "/playground") + .finish() + })) + }) + .bind(format!("{}:{}", "127.0.0.1", 8080))? + .run() + .await +} diff --git a/juniper/src/tests/fixtures/starwars/model.rs b/juniper/src/tests/fixtures/starwars/model.rs index 354c0b31..ae250607 100644 --- a/juniper/src/tests/fixtures/starwars/model.rs +++ b/juniper/src/tests/fixtures/starwars/model.rs @@ -29,6 +29,7 @@ pub trait Droid: Character { fn primary_function(&self) -> &Option<String>; } +#[derive(Clone)] struct HumanData { id: String, name: String, @@ -38,6 +39,7 @@ struct HumanData { home_planet: Option<String>, } +#[derive(Clone)] struct DroidData { id: String, name: String, @@ -101,7 +103,7 @@ impl Droid for DroidData { } } -#[derive(Default)] +#[derive(Default, Clone)] pub struct Database { humans: HashMap<String, HumanData>, droids: HashMap<String, DroidData>, diff --git a/juniper_actix/CHANGELOG.md b/juniper_actix/CHANGELOG.md index 05232472..cf41ac35 100644 --- a/juniper_actix/CHANGELOG.md +++ b/juniper_actix/CHANGELOG.md @@ -1,3 +1,3 @@ # master - +- Subscription support - Initial Release diff --git a/juniper_actix/Cargo.toml b/juniper_actix/Cargo.toml index ba24d40b..7f9be579 100644 --- a/juniper_actix/Cargo.toml +++ b/juniper_actix/Cargo.toml @@ -8,25 +8,32 @@ documentation = "https://docs.rs/juniper_actix" repository = "https://github.com/graphql-rust/juniper" edition = "2018" +[features] +subscriptions = ["juniper_graphql_ws"] [dependencies] actix = "0.9.0" -actix-rt = "1.0.0" +actix-rt = "1.1.1" actix-web = { version = "2.0.0", features = ["rustls"] } actix-web-actors = "2.0.0" -futures = { version = "0.3.1", features = ["compat"] } -juniper = { version = "0.14.2", path = "../juniper", default-features = false } + +futures = { version = "0.3.5", features = ["compat"] } tokio = { version = "0.2", features = ["time"] } -serde = { version = "1.0.75", features = ["derive"] } -serde_json = "1.0.24" +serde = { version = "1.0.114", features = ["derive"] } +serde_json = "1.0.57" anyhow = "1.0" thiserror = "1.0" +juniper = { version = "0.14.2", path = "../juniper", default-features = false } +juniper_graphql_ws = { path = "../juniper_graphql_ws", optional = true } + [dev-dependencies] -juniper = { version = "0.14.2", path = "../juniper", features = ["expose-test-schema", "serde_json"] } -env_logger = "0.7.1" -log = "0.4.3" -tokio = { version = "0.2", features = ["rt-core", "macros", "blocking"] } actix-cors = "0.2.0" -actix-identity = "0.2.0" -bytes = "0.5.4" +actix-identity = "0.2.1" + +bytes = "0.5.6" +env_logger = "0.7.1" +log = "0.4.11" +tokio = { version = "0.2", features = ["rt-core", "macros", "blocking"] } + +juniper = { version = "0.14.2", path = "../juniper", features = ["expose-test-schema", "serde_json"] } diff --git a/juniper_actix/src/lib.rs b/juniper_actix/src/lib.rs index a3c5d00e..e6a2c28d 100644 --- a/juniper_actix/src/lib.rs +++ b/juniper_actix/src/lib.rs @@ -217,6 +217,267 @@ pub async fn playground_handler( .body(html)) } +/// `juniper_actix` subscriptions handler implementation. +/// Cannot be merged to `juniper_actix` yet as GraphQL over WS[1] +/// is not fully supported in current implementation. +/// +/// *Note: this implementation is in an alpha state.* +/// +/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +#[cfg(feature = "subscriptions")] +pub mod subscriptions { + use std::{fmt, sync::Arc}; + + use actix::prelude::*; + use actix::{Actor, StreamHandler}; + use actix_web::http::header::{HeaderName, HeaderValue}; + use actix_web::{web, HttpRequest, HttpResponse}; + use actix_web_actors::ws; + + use futures::SinkExt; + use tokio::sync::Mutex; + + use juniper::futures::stream::{SplitSink, SplitStream, StreamExt}; + use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue}; + use juniper_graphql_ws::{ArcSchema, ClientMessage, Connection, Init, ServerMessage}; + + /// Serves the graphql-ws protocol over a WebSocket connection. + /// + /// The `init` argument is used to provide the context and additional configuration for + /// connections. This can be a `juniper_graphql_ws::ConnectionConfig` if the context and + /// configuration are already known, or it can be a closure that gets executed asynchronously + /// when the client sends the ConnectionInit message. Using a closure allows you to perform + /// authentication based on the parameters provided by the client. + pub async fn subscriptions_handler<Query, Mutation, Subscription, CtxT, S, I>( + req: HttpRequest, + stream: web::Payload, + root_node: Arc<RootNode<'static, Query, Mutation, Subscription, S>>, + init: I, + ) -> Result<HttpResponse, actix_web::Error> + where + Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + Query::TypeInfo: Send + Sync, + Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, + I: Init<S, CtxT> + Send, + { + let (s_tx, s_rx) = Connection::new(ArcSchema(root_node), init).split::<Message>(); + + let mut resp = ws::start( + SubscriptionActor { + graphql_tx: Arc::new(Mutex::new(s_tx)), + graphql_rx: Arc::new(Mutex::new(s_rx)), + }, + &req, + stream, + )?; + + resp.headers_mut().insert( + HeaderName::from_static("sec-websocket-protocol"), + HeaderValue::from_static("graphql-ws"), + ); + + Ok(resp) + } + + type ConnectionSplitSink<Query, Mutation, Subscription, CtxT, S, I> = Arc< + Mutex<SplitSink<Connection<ArcSchema<Query, Mutation, Subscription, CtxT, S>, I>, Message>>, + >; + + type ConnectionSplitStream<Query, Mutation, Subscription, CtxT, S, I> = + Arc<Mutex<SplitStream<Connection<ArcSchema<Query, Mutation, Subscription, CtxT, S>, I>>>>; + + /// Subscription Actor + /// coordinates messages between actix_web and juniper_graphql_ws + /// ws message -> actor -> juniper + /// juniper -> actor -> ws response + struct SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I> + where + Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + Query::TypeInfo: Send + Sync, + Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, + I: Init<S, CtxT> + Send, + { + graphql_tx: ConnectionSplitSink<Query, Mutation, Subscription, CtxT, S, I>, + graphql_rx: ConnectionSplitStream<Query, Mutation, Subscription, CtxT, S, I>, + } + + /// ws message -> actor -> juniper + impl<Query, Mutation, Subscription, CtxT, S, I> + StreamHandler<Result<ws::Message, ws::ProtocolError>> + for SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I> + where + Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + Query::TypeInfo: Send + Sync, + Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, + I: Init<S, CtxT> + Send, + { + fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) { + let msg = msg.map(|r| Message(r)); + + match msg { + Ok(msg) => { + let tx = self.graphql_tx.clone(); + + async move { + let mut tx = tx.lock().await; + tx.send(msg) + .await + .expect("Infallible: this should not happen"); + } + .into_actor(self) + .wait(ctx); + } + Err(_) => { + // TODO: trace + // ignore the message if there's a transport error + } + } + } + } + + /// juniper -> actor + impl<Query, Mutation, Subscription, CtxT, S, I> Actor + for SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I> + where + Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + Query::TypeInfo: Send + Sync, + Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, + I: Init<S, CtxT> + Send, + { + type Context = ws::WebsocketContext<Self>; + + fn started(&mut self, ctx: &mut Self::Context) { + let stream = self.graphql_rx.clone(); + let addr = ctx.address(); + + let fut = async move { + let mut stream = stream.lock().await; + while let Some(message) = stream.next().await { + // sending the message to self so that it can be forwarded back to the client + addr.do_send(ServerMessageWrapper { message }); + } + } + .into_actor(self); + + // TODO: trace + ctx.spawn(fut); + } + + fn stopped(&mut self, _: &mut Self::Context) { + // TODO: trace + } + } + + /// actor -> websocket response + impl<Query, Mutation, Subscription, CtxT, S, I> Handler<ServerMessageWrapper<S>> + for SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I> + where + Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + Query::TypeInfo: Send + Sync, + Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, + I: Init<S, CtxT> + Send, + { + type Result = (); + + fn handle( + &mut self, + msg: ServerMessageWrapper<S>, + ctx: &mut ws::WebsocketContext<Self>, + ) -> Self::Result { + let msg = serde_json::to_string(&msg.message); + + match msg { + Ok(msg) => { + ctx.text(msg); + } + Err(e) => { + let reason = ws::CloseReason { + code: ws::CloseCode::Error, + description: Some(format!("error serializing response: {}", e)), + }; + + // TODO: trace + ctx.close(Some(reason)); + } + } + } + } + + #[derive(Message)] + #[rtype(result = "()")] + struct ServerMessageWrapper<S> + where + S: ScalarValue + Send + Sync + 'static, + { + message: ServerMessage<S>, + } + + #[derive(Debug)] + struct Message(ws::Message); + + impl<S: ScalarValue> std::convert::TryFrom<Message> for ClientMessage<S> { + type Error = Error; + + fn try_from(msg: Message) -> Result<Self, Self::Error> { + match msg.0 { + ws::Message::Text(text) => { + serde_json::from_slice(text.as_bytes()).map_err(|e| Error::Serde(e)) + } + ws::Message::Close(_) => Ok(ClientMessage::ConnectionTerminate), + _ => Err(Error::UnexpectedClientMessage), + } + } + } + + /// Errors that can happen while handling client messages + #[derive(Debug)] + enum Error { + /// Errors that can happen while deserializing client messages + Serde(serde_json::Error), + + /// Error for unexpected client messages + UnexpectedClientMessage, + } + + impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Serde(e) => write!(f, "serde error: {}", e), + Self::UnexpectedClientMessage => { + write!(f, "unexpected message received from client") + } + } + } + } + + impl std::error::Error for Error {} +} + #[cfg(test)] mod tests { use super::*;