Impl subscriptions for juniper_actix (#716)

* Impl subscriptions for juniper_actix

* Add random_human example subscription

* Add actix_subscriptions example to CI

* fixup! Add random_human example subscription

* Migrate actix subscriptions to juniper_graphql_ws

* Simplify error handling

* Change unwrap to expect

* Close connection on server serialization error

Co-authored-by: Christian Legnitto <LegNeato@users.noreply.github.com>
This commit is contained in:
Mihai Dinculescu 2020-08-09 23:19:34 +01:00 committed by GitHub
parent bdc8745a56
commit 8d7ba8295c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 461 additions and 13 deletions

View file

@ -7,6 +7,7 @@ members = [
"examples/basic_subscriptions",
"examples/warp_async",
"examples/warp_subscriptions",
"examples/actix_subscriptions",
"integration_tests/juniper_tests",
"integration_tests/async_await",
"integration_tests/codegen_fail",

View file

@ -0,0 +1,22 @@
[package]
name = "actix_subscriptions"
version = "0.1.0"
authors = ["Mihai Dinculescu <mihai.dinculescu@outlook.com>"]
edition = "2018"
publish = false
[dependencies]
actix-web = "2.0.0"
actix-rt = "1.1.1"
actix-cors = "0.2.0"
futures = "0.3.5"
tokio = { version = "0.2", features = ["rt-core", "macros"] }
env_logger = "0.7.1"
serde = "1.0.114"
serde_json = "1.0.57"
rand = "0.7.3"
juniper = { path = "../../juniper", features = ["expose-test-schema", "serde_json"] }
juniper_actix = { path = "../../juniper_actix", features = ["subscriptions"] }
juniper_graphql_ws = { path = "../../juniper_graphql_ws" }

View file

@ -0,0 +1,15 @@
[tasks.run]
disabled = true
[tasks.release]
disabled = true
[tasks.release-some]
disabled = true
[tasks.release-local-test]
disabled = true
[tasks.release-some-local-test]
disabled = true
[tasks.release-dry-run]
disabled = true
[tasks.release-some-dry-run]
disabled = true

View file

@ -0,0 +1,140 @@
use std::{env, pin::Pin, time::Duration};
use actix_cors::Cors;
use actix_web::{middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer};
use futures::Stream;
use juniper::{
tests::fixtures::starwars::{model::Database, schema::Query},
DefaultScalarValue, EmptyMutation, FieldError, RootNode,
};
use juniper_actix::{graphql_handler, playground_handler, subscriptions::subscriptions_handler};
use juniper_graphql_ws::ConnectionConfig;
type Schema = RootNode<'static, Query, EmptyMutation<Database>, Subscription>;
fn schema() -> Schema {
Schema::new(Query, EmptyMutation::<Database>::new(), Subscription)
}
async fn playground() -> Result<HttpResponse, Error> {
playground_handler("/graphql", Some("/subscriptions")).await
}
async fn graphql(
req: actix_web::HttpRequest,
payload: actix_web::web::Payload,
schema: web::Data<Schema>,
) -> Result<HttpResponse, Error> {
let context = Database::new();
graphql_handler(&schema, &context, req, payload).await
}
struct Subscription;
struct RandomHuman {
id: String,
name: String,
}
// TODO: remove this when async interfaces are merged
#[juniper::graphql_object(Context = Database)]
impl RandomHuman {
fn id(&self) -> &str {
&self.id
}
fn name(&self) -> &str {
&self.name
}
}
#[juniper::graphql_subscription(Context = Database)]
impl Subscription {
#[graphql(
description = "A random humanoid creature in the Star Wars universe every 3 seconds. Second result will be an error."
)]
async fn random_human(
context: &Database,
) -> Pin<Box<dyn Stream<Item = Result<RandomHuman, FieldError>> + Send>> {
let mut counter = 0;
let context = (*context).clone();
use rand::{rngs::StdRng, Rng, SeedableRng};
let mut rng = StdRng::from_entropy();
let stream = tokio::time::interval(Duration::from_secs(3)).map(move |_| {
counter += 1;
if counter == 2 {
Err(FieldError::new(
"some field error from handler",
Value::Scalar(DefaultScalarValue::String(
"some additional string".to_string(),
)),
))
} else {
let random_id = rng.gen_range(1000, 1005).to_string();
let human = context.get_human(&random_id).unwrap();
Ok(RandomHuman {
id: human.id().to_owned(),
name: human.name().to_owned(),
})
}
});
Box::pin(stream)
}
}
async fn subscriptions(
req: HttpRequest,
stream: web::Payload,
schema: web::Data<Schema>,
) -> Result<HttpResponse, Error> {
let context = Database::new();
let schema = schema.into_inner();
let config = ConnectionConfig::new(context);
// set the keep alive interval to 15 secs so that it doesn't timeout in playground
// playground has a hard-coded timeout set to 20 secs
let config = config.with_keep_alive_interval(Duration::from_secs(15));
subscriptions_handler(req, stream, schema, config).await
}
#[actix_rt::main]
async fn main() -> std::io::Result<()> {
env::set_var("RUST_LOG", "info");
env_logger::init();
HttpServer::new(move || {
App::new()
.data(schema())
.wrap(middleware::Compress::default())
.wrap(middleware::Logger::default())
.wrap(
Cors::new()
.allowed_methods(vec!["POST", "GET"])
.supports_credentials()
.max_age(3600)
.finish(),
)
.service(web::resource("/subscriptions").route(web::get().to(subscriptions)))
.service(
web::resource("/graphql")
.route(web::post().to(graphql))
.route(web::get().to(graphql)),
)
.service(web::resource("/playground").route(web::get().to(playground)))
.default_service(web::route().to(|| {
HttpResponse::Found()
.header("location", "/playground")
.finish()
}))
})
.bind(format!("{}:{}", "127.0.0.1", 8080))?
.run()
.await
}

View file

@ -29,6 +29,7 @@ pub trait Droid: Character {
fn primary_function(&self) -> &Option<String>;
}
#[derive(Clone)]
struct HumanData {
id: String,
name: String,
@ -38,6 +39,7 @@ struct HumanData {
home_planet: Option<String>,
}
#[derive(Clone)]
struct DroidData {
id: String,
name: String,
@ -101,7 +103,7 @@ impl Droid for DroidData {
}
}
#[derive(Default)]
#[derive(Default, Clone)]
pub struct Database {
humans: HashMap<String, HumanData>,
droids: HashMap<String, DroidData>,

View file

@ -1,3 +1,3 @@
# master
- Subscription support
- Initial Release

View file

@ -8,25 +8,32 @@ documentation = "https://docs.rs/juniper_actix"
repository = "https://github.com/graphql-rust/juniper"
edition = "2018"
[features]
subscriptions = ["juniper_graphql_ws"]
[dependencies]
actix = "0.9.0"
actix-rt = "1.0.0"
actix-rt = "1.1.1"
actix-web = { version = "2.0.0", features = ["rustls"] }
actix-web-actors = "2.0.0"
futures = { version = "0.3.1", features = ["compat"] }
juniper = { version = "0.14.2", path = "../juniper", default-features = false }
futures = { version = "0.3.5", features = ["compat"] }
tokio = { version = "0.2", features = ["time"] }
serde = { version = "1.0.75", features = ["derive"] }
serde_json = "1.0.24"
serde = { version = "1.0.114", features = ["derive"] }
serde_json = "1.0.57"
anyhow = "1.0"
thiserror = "1.0"
juniper = { version = "0.14.2", path = "../juniper", default-features = false }
juniper_graphql_ws = { path = "../juniper_graphql_ws", optional = true }
[dev-dependencies]
juniper = { version = "0.14.2", path = "../juniper", features = ["expose-test-schema", "serde_json"] }
env_logger = "0.7.1"
log = "0.4.3"
tokio = { version = "0.2", features = ["rt-core", "macros", "blocking"] }
actix-cors = "0.2.0"
actix-identity = "0.2.0"
bytes = "0.5.4"
actix-identity = "0.2.1"
bytes = "0.5.6"
env_logger = "0.7.1"
log = "0.4.11"
tokio = { version = "0.2", features = ["rt-core", "macros", "blocking"] }
juniper = { version = "0.14.2", path = "../juniper", features = ["expose-test-schema", "serde_json"] }

View file

@ -217,6 +217,267 @@ pub async fn playground_handler(
.body(html))
}
/// `juniper_actix` subscriptions handler implementation.
/// Cannot be merged to `juniper_actix` yet as GraphQL over WS[1]
/// is not fully supported in current implementation.
///
/// *Note: this implementation is in an alpha state.*
///
/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
#[cfg(feature = "subscriptions")]
pub mod subscriptions {
use std::{fmt, sync::Arc};
use actix::prelude::*;
use actix::{Actor, StreamHandler};
use actix_web::http::header::{HeaderName, HeaderValue};
use actix_web::{web, HttpRequest, HttpResponse};
use actix_web_actors::ws;
use futures::SinkExt;
use tokio::sync::Mutex;
use juniper::futures::stream::{SplitSink, SplitStream, StreamExt};
use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue};
use juniper_graphql_ws::{ArcSchema, ClientMessage, Connection, Init, ServerMessage};
/// Serves the graphql-ws protocol over a WebSocket connection.
///
/// The `init` argument is used to provide the context and additional configuration for
/// connections. This can be a `juniper_graphql_ws::ConnectionConfig` if the context and
/// configuration are already known, or it can be a closure that gets executed asynchronously
/// when the client sends the ConnectionInit message. Using a closure allows you to perform
/// authentication based on the parameters provided by the client.
pub async fn subscriptions_handler<Query, Mutation, Subscription, CtxT, S, I>(
req: HttpRequest,
stream: web::Payload,
root_node: Arc<RootNode<'static, Query, Mutation, Subscription, S>>,
init: I,
) -> Result<HttpResponse, actix_web::Error>
where
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
I: Init<S, CtxT> + Send,
{
let (s_tx, s_rx) = Connection::new(ArcSchema(root_node), init).split::<Message>();
let mut resp = ws::start(
SubscriptionActor {
graphql_tx: Arc::new(Mutex::new(s_tx)),
graphql_rx: Arc::new(Mutex::new(s_rx)),
},
&req,
stream,
)?;
resp.headers_mut().insert(
HeaderName::from_static("sec-websocket-protocol"),
HeaderValue::from_static("graphql-ws"),
);
Ok(resp)
}
type ConnectionSplitSink<Query, Mutation, Subscription, CtxT, S, I> = Arc<
Mutex<SplitSink<Connection<ArcSchema<Query, Mutation, Subscription, CtxT, S>, I>, Message>>,
>;
type ConnectionSplitStream<Query, Mutation, Subscription, CtxT, S, I> =
Arc<Mutex<SplitStream<Connection<ArcSchema<Query, Mutation, Subscription, CtxT, S>, I>>>>;
/// Subscription Actor
/// coordinates messages between actix_web and juniper_graphql_ws
/// ws message -> actor -> juniper
/// juniper -> actor -> ws response
struct SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I>
where
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
I: Init<S, CtxT> + Send,
{
graphql_tx: ConnectionSplitSink<Query, Mutation, Subscription, CtxT, S, I>,
graphql_rx: ConnectionSplitStream<Query, Mutation, Subscription, CtxT, S, I>,
}
/// ws message -> actor -> juniper
impl<Query, Mutation, Subscription, CtxT, S, I>
StreamHandler<Result<ws::Message, ws::ProtocolError>>
for SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I>
where
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
I: Init<S, CtxT> + Send,
{
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
let msg = msg.map(|r| Message(r));
match msg {
Ok(msg) => {
let tx = self.graphql_tx.clone();
async move {
let mut tx = tx.lock().await;
tx.send(msg)
.await
.expect("Infallible: this should not happen");
}
.into_actor(self)
.wait(ctx);
}
Err(_) => {
// TODO: trace
// ignore the message if there's a transport error
}
}
}
}
/// juniper -> actor
impl<Query, Mutation, Subscription, CtxT, S, I> Actor
for SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I>
where
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
I: Init<S, CtxT> + Send,
{
type Context = ws::WebsocketContext<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
let stream = self.graphql_rx.clone();
let addr = ctx.address();
let fut = async move {
let mut stream = stream.lock().await;
while let Some(message) = stream.next().await {
// sending the message to self so that it can be forwarded back to the client
addr.do_send(ServerMessageWrapper { message });
}
}
.into_actor(self);
// TODO: trace
ctx.spawn(fut);
}
fn stopped(&mut self, _: &mut Self::Context) {
// TODO: trace
}
}
/// actor -> websocket response
impl<Query, Mutation, Subscription, CtxT, S, I> Handler<ServerMessageWrapper<S>>
for SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I>
where
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
I: Init<S, CtxT> + Send,
{
type Result = ();
fn handle(
&mut self,
msg: ServerMessageWrapper<S>,
ctx: &mut ws::WebsocketContext<Self>,
) -> Self::Result {
let msg = serde_json::to_string(&msg.message);
match msg {
Ok(msg) => {
ctx.text(msg);
}
Err(e) => {
let reason = ws::CloseReason {
code: ws::CloseCode::Error,
description: Some(format!("error serializing response: {}", e)),
};
// TODO: trace
ctx.close(Some(reason));
}
}
}
}
#[derive(Message)]
#[rtype(result = "()")]
struct ServerMessageWrapper<S>
where
S: ScalarValue + Send + Sync + 'static,
{
message: ServerMessage<S>,
}
#[derive(Debug)]
struct Message(ws::Message);
impl<S: ScalarValue> std::convert::TryFrom<Message> for ClientMessage<S> {
type Error = Error;
fn try_from(msg: Message) -> Result<Self, Self::Error> {
match msg.0 {
ws::Message::Text(text) => {
serde_json::from_slice(text.as_bytes()).map_err(|e| Error::Serde(e))
}
ws::Message::Close(_) => Ok(ClientMessage::ConnectionTerminate),
_ => Err(Error::UnexpectedClientMessage),
}
}
}
/// Errors that can happen while handling client messages
#[derive(Debug)]
enum Error {
/// Errors that can happen while deserializing client messages
Serde(serde_json::Error),
/// Error for unexpected client messages
UnexpectedClientMessage,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Serde(e) => write!(f, "serde error: {}", e),
Self::UnexpectedClientMessage => {
write!(f, "unexpected message received from client")
}
}
}
}
impl std::error::Error for Error {}
}
#[cfg(test)]
mod tests {
use super::*;