From 37a37d462ff680dab14776e602b506a8a0092e8d Mon Sep 17 00:00:00 2001 From: imspace Date: Thu, 25 Jun 2020 12:23:13 +0800 Subject: [PATCH] fix: endless poll on an errored ws stream (#683) * fix: endless poll on an errored ws stream --- juniper_warp/src/lib.rs | 179 ++++++++++++++++++++-------------------- 1 file changed, 91 insertions(+), 88 deletions(-) diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs index 018f23e0..d7c4a932 100644 --- a/juniper_warp/src/lib.rs +++ b/juniper_warp/src/lib.rs @@ -420,7 +420,7 @@ pub mod subscriptions { }, }; - use futures::{channel::mpsc, Future, StreamExt as _}; + use futures::{channel::mpsc, Future, StreamExt as _, TryStreamExt as _}; use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _}; use juniper_subscriptions::Coordinator; use serde::{Deserialize, Serialize}; @@ -458,71 +458,71 @@ pub mod subscriptions { let context = Arc::new(context); let running = Arc::new(AtomicBool::new(false)); let got_close_signal = Arc::new(AtomicBool::new(false)); + let got_close_signal2 = got_close_signal.clone(); - sink_rx.fold(Ok(()), move |_, msg| { - let coordinator = coordinator.clone(); - let context = context.clone(); - let running = running.clone(); - let got_close_signal = got_close_signal.clone(); - let ws_tx = ws_tx.clone(); + sink_rx + .map_err(move |e| { + got_close_signal2.store(true, Ordering::Relaxed); + failure::format_err!("Websocket error: {}", e) + }) + .try_fold((), move |_, msg| { + let coordinator = coordinator.clone(); + let context = context.clone(); + let running = running.clone(); + let got_close_signal = got_close_signal.clone(); + let ws_tx = ws_tx.clone(); - async move { - let msg = match msg { - Ok(m) => m, - Err(e) => { - got_close_signal.store(true, Ordering::Relaxed); - return Err(failure::format_err!("Websocket error: {}", e)); + async move { + if msg.is_close() { + return Ok(()); } - }; - if msg.is_close() { - return Ok(()); - } + let msg = msg + .to_str() + .map_err(|_| failure::format_err!("Non-text messages are not accepted"))?; + let request: WsPayload = serde_json::from_str(msg) + .map_err(|e| failure::format_err!("Invalid WsPayload: {}", e))?; - let msg = msg - .to_str() - .map_err(|_| failure::format_err!("Non-text messages are not accepted"))?; - let request: WsPayload = serde_json::from_str(msg) - .map_err(|e| failure::format_err!("Invalid WsPayload: {}", e))?; + match request.type_name.as_str() { + "connection_init" => {} + "start" => { + { + let closed = got_close_signal.load(Ordering::Relaxed); + if closed { + return Ok(()); + } - match request.type_name.as_str() { - "connection_init" => {} - "start" => { - { - let closed = got_close_signal.load(Ordering::Relaxed); - if closed { - return Ok(()); + if running.load(Ordering::Relaxed) { + return Ok(()); + } + running.store(true, Ordering::Relaxed); } - if running.load(Ordering::Relaxed) { - return Ok(()); + let ws_tx = ws_tx.clone(); + + if let Some(ref payload) = request.payload { + if payload.query.is_none() { + return Err(failure::format_err!("Query not found")); + } + } else { + return Err(failure::format_err!("Payload not found")); } - running.store(true, Ordering::Relaxed); - } - let ws_tx = ws_tx.clone(); + tokio::task::spawn(async move { + let payload = request.payload.unwrap(); - if let Some(ref payload) = request.payload { - if payload.query.is_none() { - return Err(failure::format_err!("Query not found")); - } - } else { - return Err(failure::format_err!("Payload not found")); - } + let request_id = request.id.unwrap_or("1".to_owned()); - tokio::task::spawn(async move { - let payload = request.payload.unwrap(); + let graphql_request = GraphQLRequest::::new( + payload.query.unwrap(), + None, + payload.variables, + ); - let request_id = request.id.unwrap_or("1".to_owned()); - - let graphql_request = GraphQLRequest::::new( - payload.query.unwrap(), - None, - payload.variables, - ); - - let values_stream = - match coordinator.subscribe(&graphql_request, &context).await { + let values_stream = match coordinator + .subscribe(&graphql_request, &context) + .await + { Ok(s) => s, Err(err) => { let _ = @@ -546,48 +546,51 @@ pub mod subscriptions { } }; - values_stream - .take_while(move |response| { - let request_id = request_id.clone(); - let closed = got_close_signal.load(Ordering::Relaxed); - if !closed { - let mut response_text = serde_json::to_string(&response) + values_stream + .take_while(move |response| { + let request_id = request_id.clone(); + let closed = got_close_signal.load(Ordering::Relaxed); + if !closed { + let mut response_text = serde_json::to_string( + &response, + ) .unwrap_or("Error deserializing response".to_owned()); - response_text = format!( - r#"{{"type":"data","id":"{}","payload":{} }}"#, - request_id, response_text - ); + response_text = format!( + r#"{{"type":"data","id":"{}","payload":{} }}"#, + request_id, response_text + ); - let _ = ws_tx - .unbounded_send(Some(Ok(Message::text(response_text)))); - } + let _ = ws_tx.unbounded_send(Some(Ok(Message::text( + response_text, + )))); + } - async move { !closed } - }) - .for_each(|_| async {}) - .await; - }); + async move { !closed } + }) + .for_each(|_| async {}) + .await; + }); + } + "stop" => { + got_close_signal.store(true, Ordering::Relaxed); + + let request_id = request.id.unwrap_or("1".to_owned()); + let close_message = format!( + r#"{{"type":"complete","id":"{}","payload":null}}"#, + request_id + ); + let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message)))); + + // close channel + let _ = ws_tx.unbounded_send(None); + } + _ => {} } - "stop" => { - got_close_signal.store(true, Ordering::Relaxed); - let request_id = request.id.unwrap_or("1".to_owned()); - let close_message = format!( - r#"{{"type":"complete","id":"{}","payload":null}}"#, - request_id - ); - let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message)))); - - // close channel - let _ = ws_tx.unbounded_send(None); - } - _ => {} + Ok(()) } - - Ok(()) - } - }) + }) } #[derive(Deserialize)]