From 373fc5889d764adc1d7c26867499b5b5f69cdb65 Mon Sep 17 00:00:00 2001 From: Zoey Date: Thu, 25 Apr 2024 00:01:28 -0700 Subject: [PATCH] replace most `.expect`s and `.unwrap`s with a custom error type, closes #54 --- Cargo.lock | 1 + Cargo.toml | 1 + src/error.rs | 12 +++++++ src/level.rs | 21 ++++-------- src/main.rs | 7 ++-- src/server.rs | 29 +++++++++-------- src/server/network.rs | 55 +++++++++++++++++++------------- src/server/network/extensions.rs | 13 ++++++-- 8 files changed, 82 insertions(+), 57 deletions(-) create mode 100644 src/error.rs diff --git a/Cargo.lock b/Cargo.lock index 223d3c2..0e8bd5b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -127,6 +127,7 @@ dependencies = [ "serde", "serde_json", "strum", + "thiserror", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 74e782a..1476225 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,4 +18,5 @@ safer-bytes = "0.2" serde = {version = "1", features = ["derive"]} serde_json = "1" strum = { version = "0.26", features = ["derive"] } +thiserror = "1" tokio = {version = "1", features = ["full"]} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..ce306ec --- /dev/null +++ b/src/error.rs @@ -0,0 +1,12 @@ +/// error type for the server +#[derive(Debug, thiserror::Error)] +pub enum GeneralError { + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Json(#[from] serde_json::Error), + #[error("{0}")] + Custom(String), + #[error("{0}")] + CustomPrivate(String), +} diff --git a/src/level.rs b/src/level.rs index 03296e9..c79d60f 100644 --- a/src/level.rs +++ b/src/level.rs @@ -6,7 +6,7 @@ use std::{ use serde::{Deserialize, Serialize}; -use crate::{packet::server::ServerPacket, util::neighbors}; +use crate::{error::GeneralError, packet::server::ServerPacket, util::neighbors}; use self::block::BLOCK_INFO; @@ -108,7 +108,7 @@ impl Level { } /// saves the level - pub async fn save

(&self, path: P) -> std::io::Result<()> + pub async fn save

(&self, path: P) -> Result<(), GeneralError> where P: AsRef, { @@ -116,29 +116,22 @@ impl Level { tokio::fs::create_dir_all(path).await?; tokio::fs::write( path.join(LEVEL_INFO_PATH), - serde_json::to_string_pretty(self).unwrap(), + serde_json::to_string_pretty(self)?, ) .await?; let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::best()); - encoder - .write_all(&self.blocks) - .expect("failed to write blocks"); - tokio::fs::write( - path.join(LEVEL_DATA_PATH), - encoder.finish().expect("failed to encode blocks"), - ) - .await + encoder.write_all(&self.blocks)?; + Ok(tokio::fs::write(path.join(LEVEL_DATA_PATH), encoder.finish()?).await?) } /// loads the level - pub async fn load

(path: P) -> std::io::Result + pub async fn load

(path: P) -> Result where P: AsRef, { let path = path.as_ref(); let mut info: Self = - serde_json::from_str(&tokio::fs::read_to_string(path.join(LEVEL_INFO_PATH)).await?) - .expect("failed to deserialize level info"); + serde_json::from_str(&tokio::fs::read_to_string(path.join(LEVEL_INFO_PATH)).await?)?; let blocks_data = tokio::fs::read(path.join(LEVEL_DATA_PATH)).await?; let mut decoder = flate2::read::GzDecoder::new(blocks_data.as_slice()); decoder.read_to_end(&mut info.blocks)?; diff --git a/src/main.rs b/src/main.rs index e8ff4ac..2a3044a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,12 +2,14 @@ use std::path::PathBuf; +use error::GeneralError; use server::{ config::{OptionalServerConfig, ServerConfig}, Server, }; mod command; +mod error; mod level; mod packet; mod player; @@ -18,11 +20,10 @@ const SERVER_NAME: &str = "classics"; const CONFIG_FILE: &str = "./server-config.json"; #[tokio::main] -async fn main() -> std::io::Result<()> { +async fn main() -> Result<(), GeneralError> { let config_path = PathBuf::from(CONFIG_FILE); let config = if config_path.exists() { - serde_json::from_str::(&std::fs::read_to_string(&config_path)?) - .expect("failed to deserialize config") + serde_json::from_str::(&std::fs::read_to_string(&config_path)?)? .build_default() } else { ServerConfig::default() diff --git a/src/server.rs b/src/server.rs index 968941d..b59cfa1 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,6 +6,7 @@ use std::{path::PathBuf, sync::Arc}; use tokio::{net::TcpListener, sync::RwLock}; use crate::{ + error::GeneralError, level::{ block::{BlockType, BLOCK_INFO}, BlockUpdate, Level, @@ -58,7 +59,7 @@ impl ServerData { impl Server { /// creates a new server with a generated level - pub async fn new(config: ServerConfig) -> std::io::Result { + pub async fn new(config: ServerConfig) -> Result { let levels_path = PathBuf::from(LEVELS_PATH); if !levels_path.exists() { std::fs::create_dir_all(&levels_path)?; @@ -84,7 +85,7 @@ impl Server { } /// creates a new server with the given level - pub async fn new_with_level(config: ServerConfig, level: Level) -> std::io::Result { + pub async fn new_with_level(config: ServerConfig, level: Level) -> Result { let listener = TcpListener::bind("0.0.0.0:25565").await?; Ok(Self { @@ -101,11 +102,15 @@ impl Server { } /// starts the server - pub async fn run(self) -> std::io::Result<()> { + pub async fn run(self) -> Result<(), GeneralError> { let data = self.data.clone(); tokio::spawn(async move { loop { - let (stream, addr) = self.listener.accept().await.unwrap(); + let (stream, addr) = self + .listener + .accept() + .await + .expect("failed to accept listener!"); println!("connection from {addr}"); let data = data.clone(); tokio::spawn(async move { @@ -113,7 +118,7 @@ impl Server { }); } }); - handle_ticks(self.data.clone()).await; + handle_ticks(self.data.clone()).await?; tokio::time::sleep(std::time::Duration::from_millis(1)).await; // TODO: cancel pending tasks/send out "Server is stopping" messages *here* instead of elsewhere @@ -129,7 +134,7 @@ impl Server { } /// function to tick the server -async fn handle_ticks(data: Arc>) { +async fn handle_ticks(data: Arc>) -> Result<(), GeneralError> { let mut current_tick = 0; let mut last_auto_save = std::time::Instant::now(); loop { @@ -138,12 +143,7 @@ async fn handle_ticks(data: Arc>) { tick(&mut data, current_tick); if data.config_needs_saving { - std::fs::write( - CONFIG_FILE, - serde_json::to_string_pretty(&data.config) - .expect("failed to serialize default config"), - ) - .expect("failed to save config file"); + tokio::fs::write(CONFIG_FILE, serde_json::to_string_pretty(&data.config)?).await?; data.config_needs_saving = false; } @@ -164,8 +164,7 @@ async fn handle_ticks(data: Arc>) { data.level.save_now = false; data.level .save(PathBuf::from(LEVELS_PATH).join(&data.config.level_name)) - .await - .expect("failed to autosave level"); + .await?; last_auto_save = std::time::Instant::now(); let packet = ServerPacket::Message { @@ -181,6 +180,8 @@ async fn handle_ticks(data: Arc>) { current_tick = current_tick.wrapping_add(1); tokio::time::sleep(TICK_DURATION).await; } + + Ok(()) } /// function which ticks the server once diff --git a/src/server/network.rs b/src/server/network.rs index 8071dbd..e2e94eb 100644 --- a/src/server/network.rs +++ b/src/server/network.rs @@ -13,6 +13,7 @@ use tokio::{ use crate::{ command::Command, + error::GeneralError, level::{block::BLOCK_INFO, BlockUpdate, Level}, packet::{ client::ClientPacket, server::ServerPacket, ExtBitmask, PacketWriter, ARRAY_LENGTH, @@ -24,7 +25,7 @@ use crate::{ use super::ServerData; -async fn next_packet(stream: &mut TcpStream) -> std::io::Result> { +async fn next_packet(stream: &mut TcpStream) -> Result, GeneralError> { let id = stream.read_u8().await?; if let Some(size) = ClientPacket::get_size_from_id(id) { @@ -37,7 +38,7 @@ async fn next_packet(stream: &mut TcpStream) -> std::io::Result(stream: &mut TcpStream, packets: I) -> std::io::Result<()> +async fn write_packets(stream: &mut TcpStream, packets: I) -> Result<(), GeneralError> where I: Iterator, { @@ -84,9 +85,12 @@ pub(super) async fn handle_stream( let r = handle_stream_inner(&mut stream, addr, data.clone(), &mut own_id).await; println!("{addr} is no longer connected"); - match r { - Ok(disconnect_reason) => { - if let Some(disconnect_reason) = disconnect_reason { + if let Err(e) = r { + match e { + // unexpected eof is expected when clients disconnect + GeneralError::Io(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {} + GeneralError::Custom(disconnect_reason) => { + println!("disconnecting <{addr}> for reason: {disconnect_reason}"); let packet = ServerPacket::DisconnectPlayer { disconnect_reason }; let writer = PacketWriter::default().write_u8(packet.get_id()); let msg = packet.write(writer).into_raw_packet(); @@ -94,11 +98,8 @@ pub(super) async fn handle_stream( eprintln!("Failed to write disconnect packet for <{addr}>: {e}"); } } - } - Err(e) => { - // unexpected eof is expected when clients disconnect - if e.kind() != std::io::ErrorKind::UnexpectedEof { - eprintln!("Error in stream handler for <{addr}>: {e}") + _ => { + eprintln!("Error in stream handler for <{addr}>: {e:?}"); } } } @@ -129,7 +130,7 @@ async fn handle_stream_inner( addr: SocketAddr, data: Arc>, own_id: &mut i8, -) -> std::io::Result> { +) -> Result<(), GeneralError> { let mut reply_queue: Vec = Vec::new(); macro_rules! msg { @@ -144,7 +145,7 @@ async fn handle_stream_inner( loop { if let Some(player) = data.read().await.players.iter().find(|p| p.id == *own_id) { if let Some(msg) = &player.should_be_kicked { - return Ok(Some(format!("Kicked: {msg}"))); + return Err(GeneralError::Custom(msg.clone())); } } @@ -157,7 +158,7 @@ async fn handle_stream_inner( magic_number, } => { if protocol_version != 0x07 { - return Ok(Some("Unknown protocol version! Please connect with a classic 0.30-compatible client.".to_string())); + return Err(GeneralError::Custom("Unknown protocol version! Please connect with a classic 0.30-compatible client.".to_string())); } let zero = f16::from_f32(0.0); @@ -168,7 +169,9 @@ async fn handle_stream_inner( ServerProtectionMode::None => {} ServerProtectionMode::Password(password) => { if verification_key != *password { - return Ok(Some("Incorrect password!".to_string())); + return Err(GeneralError::Custom( + "Incorrect password!".to_string(), + )); } } ServerProtectionMode::PasswordsByUser(passwords) => { @@ -177,14 +180,18 @@ async fn handle_stream_inner( .map(|password| verification_key == *password) .unwrap_or_default() { - return Ok(Some("Incorrect password!".to_string())); + return Err(GeneralError::Custom( + "Incorrect password!".to_string(), + )); } } } for player in &data.players { if player.username == username { - return Ok(Some("Player with username already connected!".to_string())); + return Err(GeneralError::Custom( + "Player with username already connected!".to_string(), + )); } } @@ -232,7 +239,7 @@ async fn handle_stream_inner( println!("generating level packets"); reply_queue.extend( - build_level_packets(&data.level, extensions, custom_blocks_support_level) + build_level_packets(&data.level, extensions, custom_blocks_support_level)? .into_iter(), ); @@ -321,7 +328,9 @@ async fn handle_stream_inner( || y.clamp(0, data.level.y_size as i16 - 1) != y || z.clamp(0, data.level.z_size as i16 - 1) != z { - return Ok(Some("Attempt to place block out of bounds".to_string())); + return Err(GeneralError::Custom( + "Attempt to place block out of bounds".to_string(), + )); } let new_block_info = BLOCK_INFO.get(&block_type); @@ -429,7 +438,7 @@ async fn handle_stream_inner( ClientPacket::Extended(_packet) => { // extended packets! - return Ok(Some( + return Err(GeneralError::Custom( "Unexpected extension packet in this phase!".to_string(), )); // match packet { @@ -468,7 +477,7 @@ fn build_level_packets( level: &Level, extensions: ExtBitmask, custom_blocks_support_level: u8, -) -> Vec { +) -> Result, GeneralError> { let mut packets: Vec = vec![ServerPacket::LevelInitialize {}]; let custom_blocks = @@ -490,8 +499,8 @@ fn build_level_packets( })); let mut e = GzEncoder::new(Vec::new(), Compression::best()); - e.write_all(&data).expect("failed to gzip level data"); - let data = e.finish().expect("failed to gzip level data"); + e.write_all(&data)?; + let data = e.finish()?; let data_len = data.len(); let mut total_bytes = 0; @@ -513,5 +522,5 @@ fn build_level_packets( z_size: level.z_size as i16, }); - packets + Ok(packets) } diff --git a/src/server/network/extensions.rs b/src/server/network/extensions.rs index a348f62..95ea696 100644 --- a/src/server/network/extensions.rs +++ b/src/server/network/extensions.rs @@ -1,6 +1,7 @@ use tokio::net::TcpStream; use crate::{ + error::GeneralError, level::block::CUSTOM_BLOCKS_SUPPORT_LEVEL, packet::{ client::ClientPacket, client_extended::ExtendedClientPacket, server::ServerPacket, @@ -10,7 +11,9 @@ use crate::{ use super::{next_packet, write_packets}; -pub async fn get_supported_extensions(stream: &mut TcpStream) -> std::io::Result<(ExtBitmask, u8)> { +pub async fn get_supported_extensions( + stream: &mut TcpStream, +) -> Result<(ExtBitmask, u8), GeneralError> { let extensions = ExtBitmask::all().all_contained_info(); write_packets( @@ -39,7 +42,9 @@ pub async fn get_supported_extensions(stream: &mut TcpStream) -> std::io::Result { client_extensions.push(ExtInfo::new(ext_name, version, ExtBitmask::none())); } else { - panic!("expected ExtEntry packet!"); + return Err(GeneralError::Custom( + "expected ExtEntry packet!".to_string(), + )); } } client_extensions.retain_mut(|cext| { @@ -76,7 +81,9 @@ pub async fn get_supported_extensions(stream: &mut TcpStream) -> std::io::Result { support_level.min(CUSTOM_BLOCKS_SUPPORT_LEVEL) } else { - panic!("expected CustomBlockSupportLevel packet!"); + return Err(GeneralError::Custom( + "expected CustomBlockSupportLevel packet!".to_string(), + )); } } else { 0