Impl subscriptions for juniper_actix (#716)
* Impl subscriptions for juniper_actix * Add random_human example subscription * Add actix_subscriptions example to CI * fixup! Add random_human example subscription * Migrate actix subscriptions to juniper_graphql_ws * Simplify error handling * Change unwrap to expect * Close connection on server serialization error Co-authored-by: Christian Legnitto <LegNeato@users.noreply.github.com>
This commit is contained in:
parent
bdc8745a56
commit
8d7ba8295c
8 changed files with 461 additions and 13 deletions
|
@ -7,6 +7,7 @@ members = [
|
|||
"examples/basic_subscriptions",
|
||||
"examples/warp_async",
|
||||
"examples/warp_subscriptions",
|
||||
"examples/actix_subscriptions",
|
||||
"integration_tests/juniper_tests",
|
||||
"integration_tests/async_await",
|
||||
"integration_tests/codegen_fail",
|
||||
|
|
22
examples/actix_subscriptions/Cargo.toml
Normal file
22
examples/actix_subscriptions/Cargo.toml
Normal file
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "actix_subscriptions"
|
||||
version = "0.1.0"
|
||||
authors = ["Mihai Dinculescu <mihai.dinculescu@outlook.com>"]
|
||||
edition = "2018"
|
||||
publish = false
|
||||
|
||||
[dependencies]
|
||||
actix-web = "2.0.0"
|
||||
actix-rt = "1.1.1"
|
||||
actix-cors = "0.2.0"
|
||||
|
||||
futures = "0.3.5"
|
||||
tokio = { version = "0.2", features = ["rt-core", "macros"] }
|
||||
env_logger = "0.7.1"
|
||||
serde = "1.0.114"
|
||||
serde_json = "1.0.57"
|
||||
rand = "0.7.3"
|
||||
|
||||
juniper = { path = "../../juniper", features = ["expose-test-schema", "serde_json"] }
|
||||
juniper_actix = { path = "../../juniper_actix", features = ["subscriptions"] }
|
||||
juniper_graphql_ws = { path = "../../juniper_graphql_ws" }
|
15
examples/actix_subscriptions/Makefile.toml
Normal file
15
examples/actix_subscriptions/Makefile.toml
Normal file
|
@ -0,0 +1,15 @@
|
|||
[tasks.run]
|
||||
disabled = true
|
||||
|
||||
[tasks.release]
|
||||
disabled = true
|
||||
[tasks.release-some]
|
||||
disabled = true
|
||||
[tasks.release-local-test]
|
||||
disabled = true
|
||||
[tasks.release-some-local-test]
|
||||
disabled = true
|
||||
[tasks.release-dry-run]
|
||||
disabled = true
|
||||
[tasks.release-some-dry-run]
|
||||
disabled = true
|
140
examples/actix_subscriptions/src/main.rs
Normal file
140
examples/actix_subscriptions/src/main.rs
Normal file
|
@ -0,0 +1,140 @@
|
|||
use std::{env, pin::Pin, time::Duration};
|
||||
|
||||
use actix_cors::Cors;
|
||||
use actix_web::{middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer};
|
||||
use futures::Stream;
|
||||
|
||||
use juniper::{
|
||||
tests::fixtures::starwars::{model::Database, schema::Query},
|
||||
DefaultScalarValue, EmptyMutation, FieldError, RootNode,
|
||||
};
|
||||
use juniper_actix::{graphql_handler, playground_handler, subscriptions::subscriptions_handler};
|
||||
use juniper_graphql_ws::ConnectionConfig;
|
||||
|
||||
type Schema = RootNode<'static, Query, EmptyMutation<Database>, Subscription>;
|
||||
|
||||
fn schema() -> Schema {
|
||||
Schema::new(Query, EmptyMutation::<Database>::new(), Subscription)
|
||||
}
|
||||
|
||||
async fn playground() -> Result<HttpResponse, Error> {
|
||||
playground_handler("/graphql", Some("/subscriptions")).await
|
||||
}
|
||||
|
||||
async fn graphql(
|
||||
req: actix_web::HttpRequest,
|
||||
payload: actix_web::web::Payload,
|
||||
schema: web::Data<Schema>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let context = Database::new();
|
||||
graphql_handler(&schema, &context, req, payload).await
|
||||
}
|
||||
|
||||
struct Subscription;
|
||||
|
||||
struct RandomHuman {
|
||||
id: String,
|
||||
name: String,
|
||||
}
|
||||
|
||||
// TODO: remove this when async interfaces are merged
|
||||
#[juniper::graphql_object(Context = Database)]
|
||||
impl RandomHuman {
|
||||
fn id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
}
|
||||
|
||||
#[juniper::graphql_subscription(Context = Database)]
|
||||
impl Subscription {
|
||||
#[graphql(
|
||||
description = "A random humanoid creature in the Star Wars universe every 3 seconds. Second result will be an error."
|
||||
)]
|
||||
async fn random_human(
|
||||
context: &Database,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<RandomHuman, FieldError>> + Send>> {
|
||||
let mut counter = 0;
|
||||
|
||||
let context = (*context).clone();
|
||||
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
let mut rng = StdRng::from_entropy();
|
||||
|
||||
let stream = tokio::time::interval(Duration::from_secs(3)).map(move |_| {
|
||||
counter += 1;
|
||||
|
||||
if counter == 2 {
|
||||
Err(FieldError::new(
|
||||
"some field error from handler",
|
||||
Value::Scalar(DefaultScalarValue::String(
|
||||
"some additional string".to_string(),
|
||||
)),
|
||||
))
|
||||
} else {
|
||||
let random_id = rng.gen_range(1000, 1005).to_string();
|
||||
let human = context.get_human(&random_id).unwrap();
|
||||
|
||||
Ok(RandomHuman {
|
||||
id: human.id().to_owned(),
|
||||
name: human.name().to_owned(),
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
Box::pin(stream)
|
||||
}
|
||||
}
|
||||
|
||||
async fn subscriptions(
|
||||
req: HttpRequest,
|
||||
stream: web::Payload,
|
||||
schema: web::Data<Schema>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let context = Database::new();
|
||||
let schema = schema.into_inner();
|
||||
let config = ConnectionConfig::new(context);
|
||||
// set the keep alive interval to 15 secs so that it doesn't timeout in playground
|
||||
// playground has a hard-coded timeout set to 20 secs
|
||||
let config = config.with_keep_alive_interval(Duration::from_secs(15));
|
||||
|
||||
subscriptions_handler(req, stream, schema, config).await
|
||||
}
|
||||
|
||||
#[actix_rt::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
env::set_var("RUST_LOG", "info");
|
||||
env_logger::init();
|
||||
|
||||
HttpServer::new(move || {
|
||||
App::new()
|
||||
.data(schema())
|
||||
.wrap(middleware::Compress::default())
|
||||
.wrap(middleware::Logger::default())
|
||||
.wrap(
|
||||
Cors::new()
|
||||
.allowed_methods(vec!["POST", "GET"])
|
||||
.supports_credentials()
|
||||
.max_age(3600)
|
||||
.finish(),
|
||||
)
|
||||
.service(web::resource("/subscriptions").route(web::get().to(subscriptions)))
|
||||
.service(
|
||||
web::resource("/graphql")
|
||||
.route(web::post().to(graphql))
|
||||
.route(web::get().to(graphql)),
|
||||
)
|
||||
.service(web::resource("/playground").route(web::get().to(playground)))
|
||||
.default_service(web::route().to(|| {
|
||||
HttpResponse::Found()
|
||||
.header("location", "/playground")
|
||||
.finish()
|
||||
}))
|
||||
})
|
||||
.bind(format!("{}:{}", "127.0.0.1", 8080))?
|
||||
.run()
|
||||
.await
|
||||
}
|
4
juniper/src/tests/fixtures/starwars/model.rs
vendored
4
juniper/src/tests/fixtures/starwars/model.rs
vendored
|
@ -29,6 +29,7 @@ pub trait Droid: Character {
|
|||
fn primary_function(&self) -> &Option<String>;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct HumanData {
|
||||
id: String,
|
||||
name: String,
|
||||
|
@ -38,6 +39,7 @@ struct HumanData {
|
|||
home_planet: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct DroidData {
|
||||
id: String,
|
||||
name: String,
|
||||
|
@ -101,7 +103,7 @@ impl Droid for DroidData {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
#[derive(Default, Clone)]
|
||||
pub struct Database {
|
||||
humans: HashMap<String, HumanData>,
|
||||
droids: HashMap<String, DroidData>,
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
# master
|
||||
|
||||
- Subscription support
|
||||
- Initial Release
|
||||
|
|
|
@ -8,25 +8,32 @@ documentation = "https://docs.rs/juniper_actix"
|
|||
repository = "https://github.com/graphql-rust/juniper"
|
||||
edition = "2018"
|
||||
|
||||
[features]
|
||||
subscriptions = ["juniper_graphql_ws"]
|
||||
|
||||
[dependencies]
|
||||
actix = "0.9.0"
|
||||
actix-rt = "1.0.0"
|
||||
actix-rt = "1.1.1"
|
||||
actix-web = { version = "2.0.0", features = ["rustls"] }
|
||||
actix-web-actors = "2.0.0"
|
||||
futures = { version = "0.3.1", features = ["compat"] }
|
||||
juniper = { version = "0.14.2", path = "../juniper", default-features = false }
|
||||
|
||||
futures = { version = "0.3.5", features = ["compat"] }
|
||||
tokio = { version = "0.2", features = ["time"] }
|
||||
serde = { version = "1.0.75", features = ["derive"] }
|
||||
serde_json = "1.0.24"
|
||||
serde = { version = "1.0.114", features = ["derive"] }
|
||||
serde_json = "1.0.57"
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
||||
|
||||
juniper = { version = "0.14.2", path = "../juniper", default-features = false }
|
||||
juniper_graphql_ws = { path = "../juniper_graphql_ws", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
juniper = { version = "0.14.2", path = "../juniper", features = ["expose-test-schema", "serde_json"] }
|
||||
env_logger = "0.7.1"
|
||||
log = "0.4.3"
|
||||
tokio = { version = "0.2", features = ["rt-core", "macros", "blocking"] }
|
||||
actix-cors = "0.2.0"
|
||||
actix-identity = "0.2.0"
|
||||
bytes = "0.5.4"
|
||||
actix-identity = "0.2.1"
|
||||
|
||||
bytes = "0.5.6"
|
||||
env_logger = "0.7.1"
|
||||
log = "0.4.11"
|
||||
tokio = { version = "0.2", features = ["rt-core", "macros", "blocking"] }
|
||||
|
||||
juniper = { version = "0.14.2", path = "../juniper", features = ["expose-test-schema", "serde_json"] }
|
||||
|
|
|
@ -217,6 +217,267 @@ pub async fn playground_handler(
|
|||
.body(html))
|
||||
}
|
||||
|
||||
/// `juniper_actix` subscriptions handler implementation.
|
||||
/// Cannot be merged to `juniper_actix` yet as GraphQL over WS[1]
|
||||
/// is not fully supported in current implementation.
|
||||
///
|
||||
/// *Note: this implementation is in an alpha state.*
|
||||
///
|
||||
/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
|
||||
#[cfg(feature = "subscriptions")]
|
||||
pub mod subscriptions {
|
||||
use std::{fmt, sync::Arc};
|
||||
|
||||
use actix::prelude::*;
|
||||
use actix::{Actor, StreamHandler};
|
||||
use actix_web::http::header::{HeaderName, HeaderValue};
|
||||
use actix_web::{web, HttpRequest, HttpResponse};
|
||||
use actix_web_actors::ws;
|
||||
|
||||
use futures::SinkExt;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use juniper::futures::stream::{SplitSink, SplitStream, StreamExt};
|
||||
use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue};
|
||||
use juniper_graphql_ws::{ArcSchema, ClientMessage, Connection, Init, ServerMessage};
|
||||
|
||||
/// Serves the graphql-ws protocol over a WebSocket connection.
|
||||
///
|
||||
/// The `init` argument is used to provide the context and additional configuration for
|
||||
/// connections. This can be a `juniper_graphql_ws::ConnectionConfig` if the context and
|
||||
/// configuration are already known, or it can be a closure that gets executed asynchronously
|
||||
/// when the client sends the ConnectionInit message. Using a closure allows you to perform
|
||||
/// authentication based on the parameters provided by the client.
|
||||
pub async fn subscriptions_handler<Query, Mutation, Subscription, CtxT, S, I>(
|
||||
req: HttpRequest,
|
||||
stream: web::Payload,
|
||||
root_node: Arc<RootNode<'static, Query, Mutation, Subscription, S>>,
|
||||
init: I,
|
||||
) -> Result<HttpResponse, actix_web::Error>
|
||||
where
|
||||
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Query::TypeInfo: Send + Sync,
|
||||
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Mutation::TypeInfo: Send + Sync,
|
||||
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
||||
Subscription::TypeInfo: Send + Sync,
|
||||
CtxT: Unpin + Send + Sync + 'static,
|
||||
S: ScalarValue + Send + Sync + 'static,
|
||||
I: Init<S, CtxT> + Send,
|
||||
{
|
||||
let (s_tx, s_rx) = Connection::new(ArcSchema(root_node), init).split::<Message>();
|
||||
|
||||
let mut resp = ws::start(
|
||||
SubscriptionActor {
|
||||
graphql_tx: Arc::new(Mutex::new(s_tx)),
|
||||
graphql_rx: Arc::new(Mutex::new(s_rx)),
|
||||
},
|
||||
&req,
|
||||
stream,
|
||||
)?;
|
||||
|
||||
resp.headers_mut().insert(
|
||||
HeaderName::from_static("sec-websocket-protocol"),
|
||||
HeaderValue::from_static("graphql-ws"),
|
||||
);
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
type ConnectionSplitSink<Query, Mutation, Subscription, CtxT, S, I> = Arc<
|
||||
Mutex<SplitSink<Connection<ArcSchema<Query, Mutation, Subscription, CtxT, S>, I>, Message>>,
|
||||
>;
|
||||
|
||||
type ConnectionSplitStream<Query, Mutation, Subscription, CtxT, S, I> =
|
||||
Arc<Mutex<SplitStream<Connection<ArcSchema<Query, Mutation, Subscription, CtxT, S>, I>>>>;
|
||||
|
||||
/// Subscription Actor
|
||||
/// coordinates messages between actix_web and juniper_graphql_ws
|
||||
/// ws message -> actor -> juniper
|
||||
/// juniper -> actor -> ws response
|
||||
struct SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I>
|
||||
where
|
||||
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Query::TypeInfo: Send + Sync,
|
||||
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Mutation::TypeInfo: Send + Sync,
|
||||
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
||||
Subscription::TypeInfo: Send + Sync,
|
||||
CtxT: Unpin + Send + Sync + 'static,
|
||||
S: ScalarValue + Send + Sync + 'static,
|
||||
I: Init<S, CtxT> + Send,
|
||||
{
|
||||
graphql_tx: ConnectionSplitSink<Query, Mutation, Subscription, CtxT, S, I>,
|
||||
graphql_rx: ConnectionSplitStream<Query, Mutation, Subscription, CtxT, S, I>,
|
||||
}
|
||||
|
||||
/// ws message -> actor -> juniper
|
||||
impl<Query, Mutation, Subscription, CtxT, S, I>
|
||||
StreamHandler<Result<ws::Message, ws::ProtocolError>>
|
||||
for SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I>
|
||||
where
|
||||
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Query::TypeInfo: Send + Sync,
|
||||
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Mutation::TypeInfo: Send + Sync,
|
||||
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
||||
Subscription::TypeInfo: Send + Sync,
|
||||
CtxT: Unpin + Send + Sync + 'static,
|
||||
S: ScalarValue + Send + Sync + 'static,
|
||||
I: Init<S, CtxT> + Send,
|
||||
{
|
||||
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
|
||||
let msg = msg.map(|r| Message(r));
|
||||
|
||||
match msg {
|
||||
Ok(msg) => {
|
||||
let tx = self.graphql_tx.clone();
|
||||
|
||||
async move {
|
||||
let mut tx = tx.lock().await;
|
||||
tx.send(msg)
|
||||
.await
|
||||
.expect("Infallible: this should not happen");
|
||||
}
|
||||
.into_actor(self)
|
||||
.wait(ctx);
|
||||
}
|
||||
Err(_) => {
|
||||
// TODO: trace
|
||||
// ignore the message if there's a transport error
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// juniper -> actor
|
||||
impl<Query, Mutation, Subscription, CtxT, S, I> Actor
|
||||
for SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I>
|
||||
where
|
||||
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Query::TypeInfo: Send + Sync,
|
||||
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Mutation::TypeInfo: Send + Sync,
|
||||
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
||||
Subscription::TypeInfo: Send + Sync,
|
||||
CtxT: Unpin + Send + Sync + 'static,
|
||||
S: ScalarValue + Send + Sync + 'static,
|
||||
I: Init<S, CtxT> + Send,
|
||||
{
|
||||
type Context = ws::WebsocketContext<Self>;
|
||||
|
||||
fn started(&mut self, ctx: &mut Self::Context) {
|
||||
let stream = self.graphql_rx.clone();
|
||||
let addr = ctx.address();
|
||||
|
||||
let fut = async move {
|
||||
let mut stream = stream.lock().await;
|
||||
while let Some(message) = stream.next().await {
|
||||
// sending the message to self so that it can be forwarded back to the client
|
||||
addr.do_send(ServerMessageWrapper { message });
|
||||
}
|
||||
}
|
||||
.into_actor(self);
|
||||
|
||||
// TODO: trace
|
||||
ctx.spawn(fut);
|
||||
}
|
||||
|
||||
fn stopped(&mut self, _: &mut Self::Context) {
|
||||
// TODO: trace
|
||||
}
|
||||
}
|
||||
|
||||
/// actor -> websocket response
|
||||
impl<Query, Mutation, Subscription, CtxT, S, I> Handler<ServerMessageWrapper<S>>
|
||||
for SubscriptionActor<Query, Mutation, Subscription, CtxT, S, I>
|
||||
where
|
||||
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Query::TypeInfo: Send + Sync,
|
||||
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Mutation::TypeInfo: Send + Sync,
|
||||
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
||||
Subscription::TypeInfo: Send + Sync,
|
||||
CtxT: Unpin + Send + Sync + 'static,
|
||||
S: ScalarValue + Send + Sync + 'static,
|
||||
I: Init<S, CtxT> + Send,
|
||||
{
|
||||
type Result = ();
|
||||
|
||||
fn handle(
|
||||
&mut self,
|
||||
msg: ServerMessageWrapper<S>,
|
||||
ctx: &mut ws::WebsocketContext<Self>,
|
||||
) -> Self::Result {
|
||||
let msg = serde_json::to_string(&msg.message);
|
||||
|
||||
match msg {
|
||||
Ok(msg) => {
|
||||
ctx.text(msg);
|
||||
}
|
||||
Err(e) => {
|
||||
let reason = ws::CloseReason {
|
||||
code: ws::CloseCode::Error,
|
||||
description: Some(format!("error serializing response: {}", e)),
|
||||
};
|
||||
|
||||
// TODO: trace
|
||||
ctx.close(Some(reason));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Message)]
|
||||
#[rtype(result = "()")]
|
||||
struct ServerMessageWrapper<S>
|
||||
where
|
||||
S: ScalarValue + Send + Sync + 'static,
|
||||
{
|
||||
message: ServerMessage<S>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Message(ws::Message);
|
||||
|
||||
impl<S: ScalarValue> std::convert::TryFrom<Message> for ClientMessage<S> {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(msg: Message) -> Result<Self, Self::Error> {
|
||||
match msg.0 {
|
||||
ws::Message::Text(text) => {
|
||||
serde_json::from_slice(text.as_bytes()).map_err(|e| Error::Serde(e))
|
||||
}
|
||||
ws::Message::Close(_) => Ok(ClientMessage::ConnectionTerminate),
|
||||
_ => Err(Error::UnexpectedClientMessage),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors that can happen while handling client messages
|
||||
#[derive(Debug)]
|
||||
enum Error {
|
||||
/// Errors that can happen while deserializing client messages
|
||||
Serde(serde_json::Error),
|
||||
|
||||
/// Error for unexpected client messages
|
||||
UnexpectedClientMessage,
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Serde(e) => write!(f, "serde error: {}", e),
|
||||
Self::UnexpectedClientMessage => {
|
||||
write!(f, "unexpected message received from client")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
Loading…
Reference in a new issue