fix: endless poll on an errored ws stream (#683)
* fix: endless poll on an errored ws stream
This commit is contained in:
parent
a08ce0760d
commit
37a37d462f
1 changed files with 91 additions and 88 deletions
|
@ -420,7 +420,7 @@ pub mod subscriptions {
|
|||
},
|
||||
};
|
||||
|
||||
use futures::{channel::mpsc, Future, StreamExt as _};
|
||||
use futures::{channel::mpsc, Future, StreamExt as _, TryStreamExt as _};
|
||||
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
|
||||
use juniper_subscriptions::Coordinator;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -458,71 +458,71 @@ pub mod subscriptions {
|
|||
let context = Arc::new(context);
|
||||
let running = Arc::new(AtomicBool::new(false));
|
||||
let got_close_signal = Arc::new(AtomicBool::new(false));
|
||||
let got_close_signal2 = got_close_signal.clone();
|
||||
|
||||
sink_rx.fold(Ok(()), move |_, msg| {
|
||||
let coordinator = coordinator.clone();
|
||||
let context = context.clone();
|
||||
let running = running.clone();
|
||||
let got_close_signal = got_close_signal.clone();
|
||||
let ws_tx = ws_tx.clone();
|
||||
sink_rx
|
||||
.map_err(move |e| {
|
||||
got_close_signal2.store(true, Ordering::Relaxed);
|
||||
failure::format_err!("Websocket error: {}", e)
|
||||
})
|
||||
.try_fold((), move |_, msg| {
|
||||
let coordinator = coordinator.clone();
|
||||
let context = context.clone();
|
||||
let running = running.clone();
|
||||
let got_close_signal = got_close_signal.clone();
|
||||
let ws_tx = ws_tx.clone();
|
||||
|
||||
async move {
|
||||
let msg = match msg {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
got_close_signal.store(true, Ordering::Relaxed);
|
||||
return Err(failure::format_err!("Websocket error: {}", e));
|
||||
async move {
|
||||
if msg.is_close() {
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
if msg.is_close() {
|
||||
return Ok(());
|
||||
}
|
||||
let msg = msg
|
||||
.to_str()
|
||||
.map_err(|_| failure::format_err!("Non-text messages are not accepted"))?;
|
||||
let request: WsPayload<S> = serde_json::from_str(msg)
|
||||
.map_err(|e| failure::format_err!("Invalid WsPayload: {}", e))?;
|
||||
|
||||
let msg = msg
|
||||
.to_str()
|
||||
.map_err(|_| failure::format_err!("Non-text messages are not accepted"))?;
|
||||
let request: WsPayload<S> = serde_json::from_str(msg)
|
||||
.map_err(|e| failure::format_err!("Invalid WsPayload: {}", e))?;
|
||||
match request.type_name.as_str() {
|
||||
"connection_init" => {}
|
||||
"start" => {
|
||||
{
|
||||
let closed = got_close_signal.load(Ordering::Relaxed);
|
||||
if closed {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match request.type_name.as_str() {
|
||||
"connection_init" => {}
|
||||
"start" => {
|
||||
{
|
||||
let closed = got_close_signal.load(Ordering::Relaxed);
|
||||
if closed {
|
||||
return Ok(());
|
||||
if running.load(Ordering::Relaxed) {
|
||||
return Ok(());
|
||||
}
|
||||
running.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
if running.load(Ordering::Relaxed) {
|
||||
return Ok(());
|
||||
let ws_tx = ws_tx.clone();
|
||||
|
||||
if let Some(ref payload) = request.payload {
|
||||
if payload.query.is_none() {
|
||||
return Err(failure::format_err!("Query not found"));
|
||||
}
|
||||
} else {
|
||||
return Err(failure::format_err!("Payload not found"));
|
||||
}
|
||||
running.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
let ws_tx = ws_tx.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let payload = request.payload.unwrap();
|
||||
|
||||
if let Some(ref payload) = request.payload {
|
||||
if payload.query.is_none() {
|
||||
return Err(failure::format_err!("Query not found"));
|
||||
}
|
||||
} else {
|
||||
return Err(failure::format_err!("Payload not found"));
|
||||
}
|
||||
let request_id = request.id.unwrap_or("1".to_owned());
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let payload = request.payload.unwrap();
|
||||
let graphql_request = GraphQLRequest::<S>::new(
|
||||
payload.query.unwrap(),
|
||||
None,
|
||||
payload.variables,
|
||||
);
|
||||
|
||||
let request_id = request.id.unwrap_or("1".to_owned());
|
||||
|
||||
let graphql_request = GraphQLRequest::<S>::new(
|
||||
payload.query.unwrap(),
|
||||
None,
|
||||
payload.variables,
|
||||
);
|
||||
|
||||
let values_stream =
|
||||
match coordinator.subscribe(&graphql_request, &context).await {
|
||||
let values_stream = match coordinator
|
||||
.subscribe(&graphql_request, &context)
|
||||
.await
|
||||
{
|
||||
Ok(s) => s,
|
||||
Err(err) => {
|
||||
let _ =
|
||||
|
@ -546,48 +546,51 @@ pub mod subscriptions {
|
|||
}
|
||||
};
|
||||
|
||||
values_stream
|
||||
.take_while(move |response| {
|
||||
let request_id = request_id.clone();
|
||||
let closed = got_close_signal.load(Ordering::Relaxed);
|
||||
if !closed {
|
||||
let mut response_text = serde_json::to_string(&response)
|
||||
values_stream
|
||||
.take_while(move |response| {
|
||||
let request_id = request_id.clone();
|
||||
let closed = got_close_signal.load(Ordering::Relaxed);
|
||||
if !closed {
|
||||
let mut response_text = serde_json::to_string(
|
||||
&response,
|
||||
)
|
||||
.unwrap_or("Error deserializing response".to_owned());
|
||||
|
||||
response_text = format!(
|
||||
r#"{{"type":"data","id":"{}","payload":{} }}"#,
|
||||
request_id, response_text
|
||||
);
|
||||
response_text = format!(
|
||||
r#"{{"type":"data","id":"{}","payload":{} }}"#,
|
||||
request_id, response_text
|
||||
);
|
||||
|
||||
let _ = ws_tx
|
||||
.unbounded_send(Some(Ok(Message::text(response_text))));
|
||||
}
|
||||
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(
|
||||
response_text,
|
||||
))));
|
||||
}
|
||||
|
||||
async move { !closed }
|
||||
})
|
||||
.for_each(|_| async {})
|
||||
.await;
|
||||
});
|
||||
async move { !closed }
|
||||
})
|
||||
.for_each(|_| async {})
|
||||
.await;
|
||||
});
|
||||
}
|
||||
"stop" => {
|
||||
got_close_signal.store(true, Ordering::Relaxed);
|
||||
|
||||
let request_id = request.id.unwrap_or("1".to_owned());
|
||||
let close_message = format!(
|
||||
r#"{{"type":"complete","id":"{}","payload":null}}"#,
|
||||
request_id
|
||||
);
|
||||
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message))));
|
||||
|
||||
// close channel
|
||||
let _ = ws_tx.unbounded_send(None);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
"stop" => {
|
||||
got_close_signal.store(true, Ordering::Relaxed);
|
||||
|
||||
let request_id = request.id.unwrap_or("1".to_owned());
|
||||
let close_message = format!(
|
||||
r#"{{"type":"complete","id":"{}","payload":null}}"#,
|
||||
request_id
|
||||
);
|
||||
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message))));
|
||||
|
||||
// close channel
|
||||
let _ = ws_tx.unbounded_send(None);
|
||||
}
|
||||
_ => {}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
|
Loading…
Reference in a new issue