//! This crate supplies [`SubscriptionCoordinator`] and //! [`SubscriptionConnection`] implementations for the //! [juniper](https://github.com/graphql-rust/juniper) crate. //! //! You need both this and `juniper` crate. //! //! [`SubscriptionCoordinator`]: juniper::SubscriptionCoordinator //! [`SubscriptionConnection`]: juniper::SubscriptionConnection #![deny(missing_docs)] #![deny(warnings)] #![doc(html_root_url = "https://docs.rs/juniper_subscriptions/0.14.2")] use std::{iter::FromIterator, pin::Pin}; use futures::{task::Poll, Stream}; use juniper::{ http::{GraphQLRequest, GraphQLResponse}, BoxFuture, ExecutionError, GraphQLError, GraphQLSubscriptionType, GraphQLTypeAsync, Object, ScalarValue, SubscriptionConnection, SubscriptionCoordinator, Value, ValuesStream, }; /// Simple [`SubscriptionCoordinator`] implementation: /// - contains the schema /// - handles subscription start pub struct Coordinator<'a, QueryT, MutationT, SubscriptionT, CtxT, S> where S: ScalarValue + Send + Sync, QueryT: GraphQLTypeAsync + Send + Sync, QueryT::TypeInfo: Send + Sync, MutationT: GraphQLTypeAsync + Send + Sync, MutationT::TypeInfo: Send + Sync, SubscriptionT: GraphQLSubscriptionType + Send + Sync, SubscriptionT::TypeInfo: Send + Sync, CtxT: Send + Sync, { root_node: juniper::RootNode<'a, QueryT, MutationT, SubscriptionT, S>, } impl<'a, QueryT, MutationT, SubscriptionT, CtxT, S> Coordinator<'a, QueryT, MutationT, SubscriptionT, CtxT, S> where S: ScalarValue + Send + Sync, QueryT: GraphQLTypeAsync + Send + Sync, QueryT::TypeInfo: Send + Sync, MutationT: GraphQLTypeAsync + Send + Sync, MutationT::TypeInfo: Send + Sync, SubscriptionT: GraphQLSubscriptionType + Send + Sync, SubscriptionT::TypeInfo: Send + Sync, CtxT: Send + Sync, { /// Builds new [`Coordinator`] with specified `root_node` pub fn new(root_node: juniper::RootNode<'a, QueryT, MutationT, SubscriptionT, S>) -> Self { Self { root_node } } } impl<'a, QueryT, MutationT, SubscriptionT, CtxT, S> SubscriptionCoordinator<'a, CtxT, S> for Coordinator<'a, QueryT, MutationT, SubscriptionT, CtxT, S> where S: ScalarValue + Send + Sync + 'a, QueryT: GraphQLTypeAsync + Send + Sync, QueryT::TypeInfo: Send + Sync, MutationT: GraphQLTypeAsync + Send + Sync, MutationT::TypeInfo: Send + Sync, SubscriptionT: GraphQLSubscriptionType + Send + Sync, SubscriptionT::TypeInfo: Send + Sync, CtxT: Send + Sync, { type Connection = Connection<'a, S>; type Error = GraphQLError<'a>; fn subscribe( &'a self, req: &'a GraphQLRequest, context: &'a CtxT, ) -> BoxFuture<'a, Result> { let rn = &self.root_node; Box::pin(async move { let (stream, errors) = juniper::http::resolve_into_stream(req, rn, context).await?; Ok(Connection::from_stream(stream, errors)) }) } } /// Simple [`SubscriptionConnection`] implementation. /// /// Resolves `Value` into `Stream` using the following /// logic: /// /// [`Value::Null`] - returns [`Value::Null`] once /// [`Value::Scalar`] - returns `Ok` value or [`Value::Null`] and errors vector /// [`Value::List`] - resolves each stream from the list using current logic and returns /// values in the order received /// [`Value::Object`] - waits while each field of the [`Object`] is returned, then yields the whole object /// `Value::Object>` - returns [`Value::Null`] if [`Value::Object`] consists of sub-objects pub struct Connection<'a, S> { stream: Pin> + Send + 'a>>, } impl<'a, S> Connection<'a, S> where S: ScalarValue + Send + Sync + 'a, { /// Creates new [`Connection`] from values stream and errors pub fn from_stream(stream: Value>, errors: Vec>) -> Self { Self { stream: whole_responses_stream(stream, errors), } } } impl<'a, S> SubscriptionConnection<'a, S> for Connection<'a, S> where S: ScalarValue + Send + Sync + 'a { } impl<'a, S> futures::Stream for Connection<'a, S> where S: ScalarValue + Send + Sync + 'a, { type Item = GraphQLResponse<'a, S>; fn poll_next( self: Pin<&mut Self>, cx: &mut futures::task::Context<'_>, ) -> Poll> { // this is safe as stream is only mutated here and is not moved anywhere let Connection { stream } = unsafe { self.get_unchecked_mut() }; let stream = unsafe { Pin::new_unchecked(stream) }; stream.poll_next(cx) } } /// Creates [`futures::Stream`] that yields [`GraphQLResponse`]s depending on the given [`Value`]: /// /// [`Value::Null`] - returns [`Value::Null`] once /// [`Value::Scalar`] - returns `Ok` value or [`Value::Null`] and errors vector /// [`Value::List`] - resolves each stream from the list using current logic and returns /// values in the order received /// [`Value::Object`] - waits while each field of the [`Object`] is returned, then yields the whole object /// `Value::Object>` - returns [`Value::Null`] if [`Value::Object`] consists of sub-objects fn whole_responses_stream<'a, S>( stream: Value>, errors: Vec>, ) -> Pin> + Send + 'a>> where S: ScalarValue + Send + Sync + 'a, { use futures::stream::{self, StreamExt as _}; if !errors.is_empty() { return Box::pin(stream::once(async move { GraphQLResponse::from_result(Ok((Value::Null, errors))) })); } match stream { Value::Null => Box::pin(stream::once(async move { GraphQLResponse::from_result(Ok((Value::Null, vec![]))) })), Value::Scalar(s) => Box::pin(s.map(|res| match res { Ok(val) => GraphQLResponse::from_result(Ok((val, vec![]))), Err(err) => GraphQLResponse::from_result(Ok((Value::Null, vec![err]))), })), Value::List(list) => { let mut streams = vec![]; for s in list.into_iter() { streams.push(whole_responses_stream(s, vec![])); } Box::pin(stream::select_all(streams)) } Value::Object(mut object) => { let obj_len = object.field_count(); if obj_len == 0 { return Box::pin(stream::once(async move { GraphQLResponse::from_result(Ok((Value::Null, vec![]))) })); } let mut filled_count = 0; let mut ready_vec = Vec::with_capacity(obj_len); for _ in 0..obj_len { ready_vec.push(None); } let stream = futures::stream::poll_fn( move |mut ctx| -> Poll>> { let mut obj_iterator = object.iter_mut(); // Due to having to modify `ready_vec` contents (by-move pattern) // and only being able to iterate over `object`'s mutable references (by-ref pattern) // `ready_vec` and `object` cannot be iterated simultaneously. // TODO: iterate over i and (ref field_name, ref val) once // [this RFC](https://github.com/rust-lang/rust/issues/68354) // is implemented for ready in ready_vec.iter_mut().take(obj_len) { let (field_name, val) = match obj_iterator.next() { Some(v) => v, None => break, }; if ready.is_some() { continue; } match val { Value::Scalar(stream) => { match Pin::new(stream).poll_next(&mut ctx) { Poll::Ready(None) => return Poll::Ready(None), Poll::Ready(Some(value)) => { *ready = Some((field_name.clone(), value)); filled_count += 1; } Poll::Pending => { /* check back later */ } } } _ => { // For now only `Object` is supported *ready = Some((field_name.clone(), Ok(Value::Null))); filled_count += 1; } } } if filled_count == obj_len { filled_count = 0; let new_vec = (0..obj_len).map(|_| None).collect::>(); let ready_vec = std::mem::replace(&mut ready_vec, new_vec); let ready_vec_iterator = ready_vec.into_iter().map(|el| { let (name, val) = el.unwrap(); if let Ok(value) = val { (name, value) } else { (name, Value::Null) } }); let obj = Object::from_iter(ready_vec_iterator); Poll::Ready(Some(GraphQLResponse::from_result(Ok(( Value::Object(obj), vec![], ))))) } else { Poll::Pending } }, ); Box::pin(stream) } } } #[cfg(test)] mod whole_responses_stream { use super::*; use futures::{stream, StreamExt as _}; use juniper::{DefaultScalarValue, ExecutionError, FieldError}; #[tokio::test] async fn with_error() { let expected = vec![GraphQLResponse::::error( FieldError::new("field error", Value::Null), )]; let expected = serde_json::to_string(&expected).unwrap(); let result = whole_responses_stream::( Value::Null, vec![ExecutionError::at_origin(FieldError::new( "field error", Value::Null, ))], ) .collect::>() .await; let result = serde_json::to_string(&result).unwrap(); assert_eq!(result, expected); } #[tokio::test] async fn value_null() { let expected = vec![GraphQLResponse::::from_result(Ok(( Value::Null, vec![], )))]; let expected = serde_json::to_string(&expected).unwrap(); let result = whole_responses_stream::(Value::Null, vec![]) .collect::>() .await; let result = serde_json::to_string(&result).unwrap(); assert_eq!(result, expected); } type PollResult = Result, ExecutionError>; #[tokio::test] async fn value_scalar() { let expected = vec![ GraphQLResponse::from_result(Ok(( Value::Scalar(DefaultScalarValue::Int(1i32)), vec![], ))), GraphQLResponse::from_result(Ok(( Value::Scalar(DefaultScalarValue::Int(2i32)), vec![], ))), GraphQLResponse::from_result(Ok(( Value::Scalar(DefaultScalarValue::Int(3i32)), vec![], ))), GraphQLResponse::from_result(Ok(( Value::Scalar(DefaultScalarValue::Int(4i32)), vec![], ))), GraphQLResponse::from_result(Ok(( Value::Scalar(DefaultScalarValue::Int(5i32)), vec![], ))), ]; let expected = serde_json::to_string(&expected).unwrap(); let mut counter = 0; let stream = stream::poll_fn(move |_| -> Poll> { if counter == 5 { return Poll::Ready(None); } counter += 1; Poll::Ready(Some(Ok(Value::Scalar(DefaultScalarValue::Int(counter))))) }); let result = whole_responses_stream::(Value::Scalar(Box::pin(stream)), vec![]) .collect::>() .await; let result = serde_json::to_string(&result).unwrap(); assert_eq!(result, expected); } #[tokio::test] async fn value_list() { let expected = vec![ GraphQLResponse::from_result(Ok(( Value::Scalar(DefaultScalarValue::Int(1i32)), vec![], ))), GraphQLResponse::from_result(Ok(( Value::Scalar(DefaultScalarValue::Int(2i32)), vec![], ))), GraphQLResponse::from_result(Ok((Value::Null, vec![]))), GraphQLResponse::from_result(Ok(( Value::Scalar(DefaultScalarValue::Int(4i32)), vec![], ))), ]; let expected = serde_json::to_string(&expected).unwrap(); let streams: Vec> = vec![ Value::Scalar(Box::pin(stream::once(async { PollResult::Ok(Value::Scalar(DefaultScalarValue::Int(1i32))) }))), Value::Scalar(Box::pin(stream::once(async { PollResult::Ok(Value::Scalar(DefaultScalarValue::Int(2i32))) }))), Value::Null, Value::Scalar(Box::pin(stream::once(async { PollResult::Ok(Value::Scalar(DefaultScalarValue::Int(4i32))) }))), ]; let result = whole_responses_stream::(Value::List(streams), vec![]) .collect::>() .await; let result = serde_json::to_string(&result).unwrap(); assert_eq!(result, expected); } #[tokio::test] async fn value_object() { let expected = vec![ GraphQLResponse::from_result(Ok(( Value::Object(Object::from_iter( vec![ ("one", Value::Scalar(DefaultScalarValue::Int(1i32))), ("two", Value::Scalar(DefaultScalarValue::Int(1i32))), ] .into_iter(), )), vec![], ))), GraphQLResponse::from_result(Ok(( Value::Object(Object::from_iter( vec![ ("one", Value::Scalar(DefaultScalarValue::Int(2i32))), ("two", Value::Scalar(DefaultScalarValue::Int(2i32))), ] .into_iter(), )), vec![], ))), ]; let expected = serde_json::to_string(&expected).unwrap(); let mut counter = 0; let big_stream = stream::poll_fn(move |_| -> Poll> { if counter == 2 { return Poll::Ready(None); } counter += 1; Poll::Ready(Some(Ok(Value::Scalar(DefaultScalarValue::Int(counter))))) }); let mut counter = 0; let small_stream = stream::poll_fn(move |_| -> Poll> { if counter == 2 { return Poll::Ready(None); } counter += 1; Poll::Ready(Some(Ok(Value::Scalar(DefaultScalarValue::Int(counter))))) }); let vals: Vec<(&str, Value)> = vec![ ("one", Value::Scalar(Box::pin(big_stream))), ("two", Value::Scalar(Box::pin(small_stream))), ]; let result = whole_responses_stream::( Value::Object(Object::from_iter(vals.into_iter())), vec![], ) .collect::>() .await; let result = serde_json::to_string(&result).unwrap(); assert_eq!(result, expected); } }