Fix juniper_warp subscriptions (#707)
* use anyhow::anyhow for subscriptions mod * remove unnecessary Clone * fix simultaneous subscriptions * rm unnecessary .clone
This commit is contained in:
parent
f5839c034e
commit
825a35c686
1 changed files with 33 additions and 24 deletions
|
@ -417,7 +417,8 @@ pub mod subscriptions {
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use futures::{channel::mpsc, Future, StreamExt as _, TryStreamExt as _};
|
use anyhow::anyhow;
|
||||||
|
use futures::{channel::mpsc, Future, StreamExt as _, TryFutureExt as _, TryStreamExt as _};
|
||||||
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
|
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
|
||||||
use juniper_subscriptions::Coordinator;
|
use juniper_subscriptions::Coordinator;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -439,7 +440,7 @@ pub mod subscriptions {
|
||||||
Mutation::TypeInfo: Send + Sync,
|
Mutation::TypeInfo: Send + Sync,
|
||||||
Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
||||||
Subscription::TypeInfo: Send + Sync,
|
Subscription::TypeInfo: Send + Sync,
|
||||||
CtxT: Clone + Send + Sync + 'static,
|
CtxT: Send + Sync + 'static,
|
||||||
S: ScalarValue + Send + Sync + 'static,
|
S: ScalarValue + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
let (sink_tx, sink_rx) = websocket.split();
|
let (sink_tx, sink_rx) = websocket.split();
|
||||||
|
@ -452,25 +453,28 @@ pub mod subscriptions {
|
||||||
);
|
);
|
||||||
|
|
||||||
let context = Arc::new(context);
|
let context = Arc::new(context);
|
||||||
let running = Arc::new(AtomicBool::new(false));
|
|
||||||
let got_close_signal = Arc::new(AtomicBool::new(false));
|
let got_close_signal = Arc::new(AtomicBool::new(false));
|
||||||
let got_close_signal2 = got_close_signal.clone();
|
let got_close_signal2 = got_close_signal.clone();
|
||||||
|
|
||||||
|
struct SubscriptionState {
|
||||||
|
should_stop: AtomicBool,
|
||||||
|
}
|
||||||
|
let subscription_states = HashMap::<String, Arc<SubscriptionState>>::new();
|
||||||
|
|
||||||
sink_rx
|
sink_rx
|
||||||
.map_err(move |e| {
|
.map_err(move |e| {
|
||||||
got_close_signal2.store(true, Ordering::Relaxed);
|
got_close_signal2.store(true, Ordering::Relaxed);
|
||||||
anyhow!("Websocket error: {}", e)
|
anyhow!("Websocket error: {}", e)
|
||||||
})
|
})
|
||||||
.try_fold((), move |_, msg| {
|
.try_fold(subscription_states, move |mut subscription_states, msg| {
|
||||||
let coordinator = coordinator.clone();
|
let coordinator = coordinator.clone();
|
||||||
let context = context.clone();
|
let context = context.clone();
|
||||||
let running = running.clone();
|
|
||||||
let got_close_signal = got_close_signal.clone();
|
let got_close_signal = got_close_signal.clone();
|
||||||
let ws_tx = ws_tx.clone();
|
let ws_tx = ws_tx.clone();
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
if msg.is_close() {
|
if msg.is_close() {
|
||||||
return Ok(());
|
return Ok(subscription_states);
|
||||||
}
|
}
|
||||||
|
|
||||||
let msg = msg
|
let msg = msg
|
||||||
|
@ -482,18 +486,20 @@ pub mod subscriptions {
|
||||||
match request.type_name.as_str() {
|
match request.type_name.as_str() {
|
||||||
"connection_init" => {}
|
"connection_init" => {}
|
||||||
"start" => {
|
"start" => {
|
||||||
{
|
if got_close_signal.load(Ordering::Relaxed) {
|
||||||
let closed = got_close_signal.load(Ordering::Relaxed);
|
return Ok(subscription_states);
|
||||||
if closed {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
if running.load(Ordering::Relaxed) {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
running.store(true, Ordering::Relaxed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let request_id = request.id.clone().unwrap_or("1".to_owned());
|
||||||
|
|
||||||
|
if let Some(existing) = subscription_states.get(&request_id) {
|
||||||
|
existing.should_stop.store(true, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
let state = Arc::new(SubscriptionState {
|
||||||
|
should_stop: AtomicBool::new(false),
|
||||||
|
});
|
||||||
|
subscription_states.insert(request_id, state.clone());
|
||||||
|
|
||||||
let ws_tx = ws_tx.clone();
|
let ws_tx = ws_tx.clone();
|
||||||
|
|
||||||
if let Some(ref payload) = request.payload {
|
if let Some(ref payload) = request.payload {
|
||||||
|
@ -507,8 +513,6 @@ pub mod subscriptions {
|
||||||
tokio::task::spawn(async move {
|
tokio::task::spawn(async move {
|
||||||
let payload = request.payload.unwrap();
|
let payload = request.payload.unwrap();
|
||||||
|
|
||||||
let request_id = request.id.unwrap_or("1".to_owned());
|
|
||||||
|
|
||||||
let graphql_request = GraphQLRequest::<S>::new(
|
let graphql_request = GraphQLRequest::<S>::new(
|
||||||
payload.query.unwrap(),
|
payload.query.unwrap(),
|
||||||
None,
|
None,
|
||||||
|
@ -545,8 +549,9 @@ pub mod subscriptions {
|
||||||
values_stream
|
values_stream
|
||||||
.take_while(move |response| {
|
.take_while(move |response| {
|
||||||
let request_id = request_id.clone();
|
let request_id = request_id.clone();
|
||||||
let closed = got_close_signal.load(Ordering::Relaxed);
|
let should_stop = state.should_stop.load(Ordering::Relaxed)
|
||||||
if !closed {
|
|| got_close_signal.load(Ordering::Relaxed);
|
||||||
|
if !should_stop {
|
||||||
let mut response_text = serde_json::to_string(
|
let mut response_text = serde_json::to_string(
|
||||||
&response,
|
&response,
|
||||||
)
|
)
|
||||||
|
@ -562,16 +567,19 @@ pub mod subscriptions {
|
||||||
))));
|
))));
|
||||||
}
|
}
|
||||||
|
|
||||||
async move { !closed }
|
async move { !should_stop }
|
||||||
})
|
})
|
||||||
.for_each(|_| async {})
|
.for_each(|_| async {})
|
||||||
.await;
|
.await;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
"stop" => {
|
"stop" => {
|
||||||
got_close_signal.store(true, Ordering::Relaxed);
|
|
||||||
|
|
||||||
let request_id = request.id.unwrap_or("1".to_owned());
|
let request_id = request.id.unwrap_or("1".to_owned());
|
||||||
|
if let Some(existing) = subscription_states.get(&request_id) {
|
||||||
|
existing.should_stop.store(true, Ordering::Relaxed);
|
||||||
|
subscription_states.remove(&request_id);
|
||||||
|
}
|
||||||
|
|
||||||
let close_message = format!(
|
let close_message = format!(
|
||||||
r#"{{"type":"complete","id":"{}","payload":null}}"#,
|
r#"{{"type":"complete","id":"{}","payload":null}}"#,
|
||||||
request_id
|
request_id
|
||||||
|
@ -584,9 +592,10 @@ pub mod subscriptions {
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(subscription_states)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
.map_ok(|_| ())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
|
Loading…
Add table
Reference in a new issue