From 792054749f76c06876ed7c79547bef6dd0803b43 Mon Sep 17 00:00:00 2001 From: Kyle Johnsen Date: Fri, 18 Mar 2022 08:59:41 -0400 Subject: [PATCH] important change, added 5s timeout on all connections. Must send heartbeat --- src/main.rs | 98 +++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 76 insertions(+), 22 deletions(-) diff --git a/src/main.rs b/src/main.rs index aecb577..cdd38ec 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,10 @@ use async_std::prelude::*; +use async_std::future; use async_std::net::TcpListener; use async_std::net::TcpStream; use async_std::net::UdpSocket; use async_notify::Notify; use futures::stream::StreamExt; -use futures::future; use futures::join; use futures::select; use futures::pin_mut; @@ -15,7 +15,7 @@ use std::rc::{Rc}; use std::cell::{RefCell}; use std::fs; use chrono::Local; -use std::time; +use std::time::Duration; use serde::{Serialize, Deserialize}; use std::error::Error; use std::ptr; @@ -102,15 +102,19 @@ async fn main() { let tcp_future = tcp_listener .incoming() - .for_each_concurrent(None, |tcpstream| process_client(tcpstream.unwrap(), udp_socket.clone(), clients.clone(),rooms.clone(),last_client_id.clone())); + .for_each_concurrent(None, |tcpstream| process_client(tcpstream.unwrap(), udp_socket.clone(), clients.clone(),rooms.clone(),last_client_id.clone(),&config)); let udp_future = process_udp(udp_socket.clone(),clients.clone(),rooms.clone()); join!(tcp_future,udp_future); } -async fn process_client(socket: TcpStream, udp_socket: Rc>, clients: Rc>>>>, rooms: Rc>>>>, last_client_id: Rc>){ +async fn process_client(socket: TcpStream, udp_socket: Rc>, clients: Rc>>>>, rooms: Rc>>>>, last_client_id: Rc>, config: &Config){ println!("started tcp"); + + socket.set_nodelay(true).unwrap(); + //socket.set_read_timeout(Some(time::Duration::new(config.tcp_timeout,0))).unwrap(); + //socket.set_write_timeout(Some(time::Duration::new(config.tcp_timeout,0))).unwrap(); let my_id; { @@ -156,7 +160,6 @@ async fn process_client(socket: TcpStream, udp_socket: Rc>, c } - let read_async = client_read(client.clone(), socket.clone(), clients.clone(), rooms.clone()).fuse(); let write_async = client_write(client.clone(), socket, client_notify.clone()).fuse(); let write_async_udp = client_write_udp(client.clone(), udp_socket.clone(), client_notify_udp.clone()).fuse(); @@ -180,17 +183,39 @@ async fn process_client(socket: TcpStream, udp_socket: Rc>, c } +async fn read_timeout(mut socket: &TcpStream, buf: &mut [u8], duration: u64) -> Result> { + + match future::timeout(Duration::from_millis(duration), socket.read(buf)).await { + Ok(r) => { + match r { + + Ok(n) if n == 0 => { + + return Err(format!("{}", "no bytes read"))? + + }, + Ok(n) => { + return Ok(n); + }, + Err(e) => {return Err(format!("{}", e.to_string()))?} + } + + }, + + Err(e) => {return Err(format!("{}", e.to_string()))?} + + } + +} + async fn client_read(client: Rc>, mut socket: TcpStream, clients: Rc>>>>, rooms: Rc>>>>){ let mut buf = [0; 1]; loop { - match socket.read(&mut buf).await { - // socket closed - Ok(n) if n == 0 => { - println!("client read ended naturally?"); - return; - }, + + match read_timeout(&mut socket, &mut buf, 5000).await { + Ok(_) => { let t = buf[0]; @@ -211,14 +236,14 @@ async fn client_read(client: Rc>, mut socket: TcpStream, clients }; } else if t == FromClientTCPMessageType::JoinRoom as u8 {//[2:u8][roomname.length():u8][roomname:shortstring] match read_join_message(socket.clone(), client.clone(), rooms.clone()).await{ - Ok(_)=>(), + Ok(_)=>(), Err(_)=>{eprintln!("failed to read from socket"); return;} }; } else if t == FromClientTCPMessageType::SendMessageOthersUnbuffered as u8 || - t == FromClientTCPMessageType::SendMessageAllUnbuffered as u8 || - t == FromClientTCPMessageType::SendMessageGroupUnbuffered as u8 || - t == FromClientTCPMessageType::SendMessageOthersBuffered as u8 || - t == FromClientTCPMessageType::SendMessageAllBuffered as u8 { //others,all,group[t:u8][message.length():i32][message:u8array] + t == FromClientTCPMessageType::SendMessageAllUnbuffered as u8 || + t == FromClientTCPMessageType::SendMessageGroupUnbuffered as u8 || + t == FromClientTCPMessageType::SendMessageOthersBuffered as u8 || + t == FromClientTCPMessageType::SendMessageAllBuffered as u8 { //others,all,group[t:u8][message.length():i32][message:u8array] match read_send_message(socket.clone(), client.clone(), rooms.clone(), t).await{ Ok(_)=>(), Err(_)=>{eprintln!("failed to read from socket"); return;} @@ -240,9 +265,37 @@ async fn client_read(client: Rc>, mut socket: TcpStream, clients //remove the client return; } + + }; } } + + +async fn write_timeout(mut socket: &TcpStream, buf: &[u8], duration: u64) -> Result> { + + match future::timeout(Duration::from_millis(duration), socket.write(buf)).await { + Ok(r) => { + match r { + + Ok(n)=> { + + return Ok(n); + + }, + Err(e) => {return Err(format!("{}", e.to_string()))?} + + } + + }, + + Err(e) => {return Err(format!("{}", e.to_string()))?} + + } + +} + + async fn client_write(client: Rc>, mut socket: TcpStream, notify: Rc){ //wait on messages in my queue @@ -261,7 +314,7 @@ async fn client_write(client: Rc>, mut socket: TcpStream, notify client_ref.message_queue.clear(); } - match socket.write(&to_write).await { + match write_timeout(&mut socket, &to_write, 5000).await { Ok(_) => (), Err(_) => {eprintln!("failed to write to the tcp socket"); return;} } @@ -882,32 +935,33 @@ fn send_group_message(sender: Rc>, message: &Vec, group: &St async fn read_u8(stream: &mut TcpStream) -> Result> { let mut buf = [0; 1]; - stream.read_exact(&mut buf).await?; + read_timeout(stream, &mut buf, 5000).await?; return Ok(buf[0]); + } async fn read_u32(stream: &mut TcpStream) -> Result> { let mut buf:[u8;4] = [0; 4]; - stream.read_exact(&mut buf).await?; + read_timeout(stream, &mut buf, 5000).await?; let size = u32::from_be_bytes(buf); return Ok(size); } async fn _read_string(stream: &mut TcpStream) -> Result> { let size = read_u32(stream).await?; let mut string_bytes = vec![0;size as usize]; - stream.read_exact(&mut string_bytes).await?; + read_timeout(stream, &mut string_bytes, 5000).await?; return Ok(String::from_utf8(string_bytes).unwrap()); } async fn read_short_string(stream: &mut TcpStream) -> Result> { let size = read_u8(stream).await?; let mut string_bytes = vec![0;size as usize]; - stream.read_exact(&mut string_bytes).await?; + read_timeout(stream, &mut string_bytes, 5000).await?; return Ok(String::from_utf8(string_bytes).unwrap()); } async fn read_vec(stream: &mut TcpStream) -> Result,Box> { let message_size = read_u32(stream).await?; let mut message = vec![0u8;message_size as usize]; - stream.read_exact(&mut message).await?; + read_timeout(stream, &mut message, 5000).await?; return Ok(message); }