diff --git a/src/client.rs b/src/client.rs index 5b37e67..53af4ed 100644 --- a/src/client.rs +++ b/src/client.rs @@ -12,12 +12,20 @@ // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -use crate::Config; -use std::io::Result; -use tokio::prelude::*; +use crate::{server::read_line, Config}; +use std::{ + io::{Error, ErrorKind, Result}, + os::unix::fs::MetadataExt, + path::PathBuf, +}; +use tokio::{ + fs::{remove_file, File}, + io::{copy, split, BufReader, BufWriter}, + prelude::*, +}; use tokio_libtls::prelude::*; -pub(crate) async fn run(config: Config) -> Result<()> { +pub(crate) async fn run(config: Config, filename: PathBuf) -> Result<()> { let (cert, key, ca) = config.load_keys()?; let mut options = config.load_client_options(); @@ -27,11 +35,103 @@ pub(crate) async fn run(config: Config) -> Result<()> { .build() .unwrap(); let addr = config.address.unwrap(); - let mut tls = AsyncTls::connect(&addr.to_string(), &tls_config, options.build()) + let tls = AsyncTls::connect(&addr.to_string(), &tls_config, options.build()) .await .unwrap(); - let _ = tls.write_all(b"OK\r\n").await; - let mut buf = vec![0u8; 1024]; - let _ = tls.read(&mut buf).await; + let peer_addr: String = addr.to_string(); + + let (reader, writer) = split(tls); + let mut reader = BufReader::new(reader); + let mut writer = BufWriter::new(writer); + + // Send the filename + let name = match filename + .file_name() + .ok_or_else(|| { + debug!("{} failed: file name ({})", peer_addr, filename.display()); + Error::new(ErrorKind::Other, "file") + })? + .to_str() + { + Some(name) => name, + None => { + debug!( + "{} failed: filename format ({})", + peer_addr, + filename.display() + ); + return Err(Error::new(ErrorKind::Other, "file format")); + } + }; + let _ = writer.write_all(format!("{}\n", name).as_bytes()).await; + + // Send the file size + let file_size = filename.metadata()?.size(); + let _ = writer + .write_all(format!("{}\n", file_size).as_bytes()) + .await; + + debug!( + "{} status: sending {} ({} bytes)", + peer_addr, + filename.display(), + file_size + ); + + // Send the actual file + let file = match File::open(&filename).await { + Ok(f) => f, + Err(err) => { + debug!( + "{} failed {}: file ({})", + peer_addr, + filename.display(), + err + ); + return Err(err); + } + }; + + let mut file_reader = file.take(file_size); + let copied = match copy(&mut file_reader, &mut writer).await { + Ok(s) => s, + Err(err) => { + debug!("{} failed: I/O ({:?})", peer_addr, err); + return Err(err); + } + }; + + if copied != file_size { + drop(file_reader); + let _ = remove_file(&filename).await; + warn!( + "{} failed: {} ({}/{} bytes)", + peer_addr, + filename.display(), + copied, + file_size + ); + } else { + info!( + "{} success: {} ({} bytes)", + peer_addr, + filename.display(), + copied + ); + } + + // Read the server result + match read_line(&peer_addr, &mut reader).await { + Ok(s) if s.starts_with("success") => s, + Ok(s) => { + debug!("server: {}", s); + return Err(Error::new(ErrorKind::Other, s)); + } + Err(err) => { + debug!("{}", err); + return Err(err); + } + }; + Ok(()) } diff --git a/src/main.rs b/src/main.rs index 90317e3..1e8a3cc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -110,24 +110,38 @@ async fn main() { "server name", &config.servername.as_ref().unwrap(), ); + opts.optopt("f", "file", "send file as client", "filename"); opts.optflag("s", "server", "run server"); - opts.optflag("c", "client", "connect as client"); opts.optflag("h", "help", "print this help menu"); let matches = match opts.parse(&args[1..]) { Ok(m) => m, Err(f) => panic!(f.to_string()), }; - if matches.opt_present("h") || (matches.opt_present("c") && matches.opt_present("s")) { + if matches.opt_present("h") || (matches.opt_present("f") && matches.opt_present("s")) { usage(&program, opts); } + if let Some(address) = matches.opt_str("a") { + let addr: SocketAddr = address.parse().unwrap(); + config.address = Some(addr); + } + + let addr = match config.address { + Some(SocketAddr::V6(addr)) => addr.clone(), + _ => panic!("invalid address: {:?}", config.address), + }; env_logger::builder() .filter_level(LevelFilter::Debug) .init(); - let addr = match config.address { - Some(SocketAddr::V6(addr)) => addr.clone(), - _ => panic!("invalid address: {:?}", config.address), + let file = if let Some(file) = matches.opt_str("f") { + let file = PathBuf::from(file); + if !file.exists() { + panic!("invalid file: {}", file.display()) + } + Some(file) + } else { + None }; let keypair = KeyPair::new(&config); @@ -140,8 +154,8 @@ async fn main() { addr.port() ); - if matches.opt_present("c") { - client::run(config).await.expect("client"); + if let Some(file) = file { + client::run(config, file).await.expect("client"); } else { server::run(config).await.expect("server"); } diff --git a/src/server.rs b/src/server.rs index 9c9bf6e..ac6ae49 100644 --- a/src/server.rs +++ b/src/server.rs @@ -19,7 +19,9 @@ use std::{ }; use tokio::{ fs::{remove_file, File}, - io::{self, AsyncReadExt}, + io::{ + copy, split, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter, + }, net::TcpListener, }; use tokio_libtls::prelude::*; @@ -44,7 +46,7 @@ pub(crate) async fn run(config: Config) -> Result<()> { let size_limit = config.size_limit as u64; let peer_addr: String = tcp.peer_addr()?.to_string(); let options = options.build(); - let mut tls = match AsyncTls::accept_stream(tcp, &tls_config, options).await { + let tls = match AsyncTls::accept_stream(tcp, &tls_config, options).await { Ok(tls) => { debug!("{} status: connected", peer_addr); tls @@ -55,11 +57,15 @@ pub(crate) async fn run(config: Config) -> Result<()> { } }; + let (reader, mut writer) = split(tls); + let mut reader = BufReader::new(reader); + // let mut writer = BufWriter::new(writer); + tokio::spawn(async move { let mut filename = PathBuf::from("/home/reyk/Downloads"); // Read and validate the filename - let line = match read_line(&peer_addr, &mut tls).await { + let line = match read_line(&peer_addr, &mut reader).await { Ok(s) if !(s.contains(MAIN_SEPARATOR) || s.contains('/') || s.contains('\\')) => s, Err(err) => { debug!("{}", err); @@ -89,7 +95,7 @@ pub(crate) async fn run(config: Config) -> Result<()> { } // Read the file size - let line = match read_line(&peer_addr, &mut tls).await { + let line = match read_line(&peer_addr, &mut reader).await { Ok(s) => s, Err(err) => { debug!("{}", err); @@ -119,7 +125,7 @@ pub(crate) async fn run(config: Config) -> Result<()> { ); // Create output file - let mut file = match File::create(&filename).await { + let file = match File::create(&filename).await { Ok(f) => f, Err(err) => { debug!( @@ -133,8 +139,9 @@ pub(crate) async fn run(config: Config) -> Result<()> { }; // I/O - let mut reader = tls.take(file_size); - let copied = match io::copy(&mut reader, &mut file).await { + let mut file_writer = BufWriter::new(file); + let mut reader = reader.take(file_size); + let copied = match copy(&mut reader, &mut file_writer).await { Ok(s) => s, Err(err) => { debug!("{} failed: I/O ({})", peer_addr, err); @@ -142,9 +149,10 @@ pub(crate) async fn run(config: Config) -> Result<()> { } }; - if copied != file_size { - drop(file); - let _ = remove_file("a.txt").await.is_ok(); + // Check result and send response + let result = if copied != file_size { + drop(file_writer); + let _ = remove_file(&filename).await.is_ok(); warn!( "{} failed: {} ({}/{} bytes)", peer_addr, @@ -152,6 +160,7 @@ pub(crate) async fn run(config: Config) -> Result<()> { copied, file_size ); + "failed: truncated file\n".to_string() } else { info!( "{} success: {} ({} bytes)", @@ -159,28 +168,34 @@ pub(crate) async fn run(config: Config) -> Result<()> { filename.display(), copied ); - } + "success\n".to_string() + }; + + let _ = writer.write_all(result.as_bytes()).await; }); } } -async fn read_line(peer: &str, reader: &mut T) -> Result { - let mut buf = vec![0u8; 1024]; - if let Err(err) = reader.read(&mut buf).await { - return Err(Error::new( - ErrorKind::Other, - format!("{} failed: read ({})", peer, err), - )); - } - let line = match String::from_utf8(buf) { - Ok(s) => s, - Err(err) => { +pub async fn read_line( + peer: &str, + reader: &mut T, +) -> Result { + let mut line = String::with_capacity(1024); + let mut len = 0; + + // Ignore some empty lines + for _ in 0..10 { + if let Err(err) = reader.read_line(&mut line).await { return Err(Error::new( ErrorKind::Other, - format!("{} read failed: line ({})", peer, err), + format!("{} failed: read ({})", peer, err), )); } - }; - let len = line.find(|c: char| c == '\r' || c == '\n').unwrap_or(0); + len = line.find(|c: char| c == '\r' || c == '\n').unwrap_or(0); + if len > 0 { + break; + } + } + Ok((&line[0..len]).to_owned()) }