diff options
author | RaindropsSys <raindrops@equestria.dev> | 2024-03-10 16:49:49 +0100 |
---|---|---|
committer | RaindropsSys <raindrops@equestria.dev> | 2024-03-10 16:49:49 +0100 |
commit | 91cee164518399db9dfc3399f2af1c4aadc9239d (patch) | |
tree | c2a31390f525cb6da795b197e5d08ae1fed43719 | |
parent | 871329b8712a401514c51b517be38a69819c9052 (diff) | |
parent | 6eb7683d34dde376233d49f94294193464b46010 (diff) | |
download | where-rs-91cee164518399db9dfc3399f2af1c4aadc9239d.tar.gz where-rs-91cee164518399db9dfc3399f2af1c4aadc9239d.tar.bz2 where-rs-91cee164518399db9dfc3399f2af1c4aadc9239d.zip |
Fix merge conflict
-rw-r--r-- | where-rs/src/main.rs | 7 | ||||
-rw-r--r-- | where-shared/src/error.rs | 16 | ||||
-rw-r--r-- | where-shared/src/lib.rs | 136 | ||||
-rw-r--r-- | where-shared/src/parse.rs | 44 |
4 files changed, 99 insertions, 104 deletions
diff --git a/where-rs/src/main.rs b/where-rs/src/main.rs index 9bbfce5..6783a68 100644 --- a/where-rs/src/main.rs +++ b/where-rs/src/main.rs @@ -21,7 +21,7 @@ fn start_client() -> WhereResult<()> { let mut entries = vec![]; for server in servers { - entries.extend(process_server(server)?.into_vec()); + entries.extend(process_server(server, "My Computer")?.into_vec()); } entries.sort_by_key(|s| s.login_time); @@ -99,7 +99,7 @@ fn start_client() -> WhereResult<()> { Ok(()) } -fn process_server(server: &str) -> WhereResult<SessionCollection> { +fn process_server(server: &str, host: &str) -> WhereResult<SessionCollection> { let socket = UdpSocket::bind("0.0.0.0:0")?; socket.set_read_timeout(Some(TIMEOUT))?; @@ -110,7 +110,8 @@ fn process_server(server: &str) -> WhereResult<SessionCollection> { match socket.recv_from(&mut buf) { Ok(_) => { - return Ok(SessionCollection::from_udp_payload(buf, "My Computer")?); + let collection = SessionCollection::from_udp_payload(buf, host)?; + return Ok(collection); }, Err(e) if e.kind() == ErrorKind::TimedOut || e.kind() == ErrorKind::WouldBlock => continue, Err(e) => return Err(WhereError::from(e)), diff --git a/where-shared/src/error.rs b/where-shared/src/error.rs index 5908077..7ae0099 100644 --- a/where-shared/src/error.rs +++ b/where-shared/src/error.rs @@ -1,4 +1,5 @@ use std::fmt::Display; +use std::string::FromUtf8Error; use std::{fmt, io}; use std::time::Duration; use crate::{MAX_ENTRY_LENGTH, MAX_PAYLOAD_LENGTH}; @@ -15,6 +16,7 @@ pub enum EncodeDecodeError { BadMagic([u8; 4]), IncorrectEntryCount, StringSizeLimitExceeded(u32, usize), + StringDecodeError(FromUtf8Error), NonbinaryBoolean, EmptyRemote, IOErrorWhileTranscoding(io::Error) @@ -35,12 +37,25 @@ impl From<EncodeDecodeError> for WhereError { } } +impl From<FromUtf8Error> for WhereError { + fn from(value: FromUtf8Error) -> Self { + Self::EncodeDecodeError(EncodeDecodeError::StringDecodeError(value)) + } +} + impl From<io::Error> for EncodeDecodeError { fn from(value: io::Error) -> Self { Self::IOErrorWhileTranscoding(value) } } + +impl From<FromUtf8Error> for EncodeDecodeError { + fn from (value: FromUtf8Error) -> Self { + Self::StringDecodeError(value) + } +} + impl Display for EncodeDecodeError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -48,6 +63,7 @@ impl Display for EncodeDecodeError { Self::InvalidPayloadLength(s) => write!(f, "Invalid full payload length: {s} but maximum is {MAX_PAYLOAD_LENGTH}"), Self::BadMagic(m) => write!(f, "Invalid packet magic ({}), possible corruption or invalid server", String::from_utf8_lossy(m)), Self::IncorrectEntryCount => write!(f, "Invalid amount of entries decoded"), + Self::StringDecodeError(e) => write!(f, "String decoding error: {e}"), Self::StringSizeLimitExceeded(curr, max) => write!(f, "Exceeded length limit for payload string ({curr} > {max})"), Self::NonbinaryBoolean => write!(f, "Boolean value is not 0 or 1"), Self::EmptyRemote => write!(f, "Remote tag set but no remote host is present"), diff --git a/where-shared/src/lib.rs b/where-shared/src/lib.rs index 06cc5ba..774e2fd 100644 --- a/where-shared/src/lib.rs +++ b/where-shared/src/lib.rs @@ -1,8 +1,9 @@ use std::io::Cursor; use coreutils_core::os::utmpx::*; -use std::io::Read; -use crate::error::{EncodeDecodeError, EncodeDecodeResult}; +use crate::error::{WhereResult, EncodeDecodeResult, EncodeDecodeError}; + +mod parse; pub mod error; pub const WHERED_MAGIC: [u8; 4] = *b"WHRD"; @@ -12,15 +13,18 @@ pub const MAX_ENTRY_LENGTH: usize = MAX_REMOTE_LENGTH + MAX_USER_TTY_LENGTH * 2 pub const MAX_PAYLOAD_LENGTH: usize = 65501; pub const MAX_PAYLOAD_ENTRIES: usize = MAX_PAYLOAD_LENGTH / MAX_ENTRY_LENGTH; +type Payload = [u8; MAX_PAYLOAD_LENGTH]; +type PayloadCursor = Cursor<Payload>; + #[derive(Debug)] pub struct Session { pub host: Option<String>, - pub user: String, pub pid: i32, + pub login_time: i64, + pub user: String, pub tty: String, pub remote: Option<String>, pub active: bool, - pub login_time: i64 } #[derive(Debug)] @@ -71,26 +75,23 @@ impl SessionCollection { } } - pub fn from_udp_payload(buffer: [u8; MAX_PAYLOAD_LENGTH], host: &str) -> EncodeDecodeResult<Self> { - let mut buf = Cursor::new(buffer); + pub fn from_udp_payload(buffer: Payload, host: &str) -> WhereResult<Self> { + let mut cursor = Cursor::new(buffer); let mut inner = vec![]; - let mut magic = [0u8; 4]; - let mut length = [0u8; 2]; - Session::read_field(&mut buf, &mut magic)?; - Session::read_field(&mut buf, &mut length)?; - let entry_count = u16::from_be_bytes(length); + // Check magic + parse::read_field(&mut cursor, |buf| { + if buf != WHERED_MAGIC { + Err(EncodeDecodeError::BadMagic(buf))? + } else { + Ok(()) + } + })?; - if magic != WHERED_MAGIC { - return Err(EncodeDecodeError::BadMagic(magic)); - } + let entry_count = parse::read_field(&mut cursor, |buf| Ok(u32::from_be_bytes(buf)))?; for _ in 0..entry_count { - inner.push(Session::from_udp_payload(&mut buf, &host)?); - } - - if inner.len() != entry_count as usize { - return Err(EncodeDecodeError::IncorrectEntryCount); + inner.push(Session::from_udp_payload(&mut cursor, &host)?); } Ok(Self { @@ -100,103 +101,36 @@ impl SessionCollection { } impl Session { - pub fn from_udp_payload(cursor: &mut Cursor<[u8; MAX_PAYLOAD_LENGTH]>, host: &str) -> EncodeDecodeResult<Self> { - let mut username_length = [0u8; 4]; - let mut pid = [0u8; 4]; - let mut tty_length = [0u8; 4]; - let mut remote_tag = [0u8; 1]; - let mut remote_length = [0u8; 4]; - let mut active = [0u8; 1]; - let mut login_time = [0u8; 8]; - - Session::read_field(cursor, &mut pid)?; - Session::read_field(cursor, &mut login_time)?; - - Session::read_field(cursor, &mut username_length)?; - let username_length = u32::from_be_bytes(username_length); - if username_length as usize > MAX_USER_TTY_LENGTH { - return Err(EncodeDecodeError::StringSizeLimitExceeded(username_length, MAX_USER_TTY_LENGTH)); - } - - let mut user = vec![0u8; username_length as usize]; - Session::read_field(cursor, &mut user)?; - - Session::read_field(cursor, &mut tty_length)?; - let tty_length = u32::from_be_bytes(tty_length); - if tty_length as usize > MAX_USER_TTY_LENGTH { - return Err(EncodeDecodeError::StringSizeLimitExceeded(tty_length, MAX_USER_TTY_LENGTH)); - } - - let mut tty = vec![0u8; tty_length as usize]; - Session::read_field(cursor, &mut tty)?; - - Session::read_field(cursor, &mut remote_tag)?; - if remote_tag[0] > 1 { - return Err(EncodeDecodeError::NonbinaryBoolean); - } - - let has_remote_tag = remote_tag[0] == 1; - - let remote = if has_remote_tag { - Session::read_field(cursor, &mut remote_length)?; - let remote_length = u32::from_be_bytes(remote_length); - if remote_length as usize > MAX_USER_TTY_LENGTH { - return Err(EncodeDecodeError::StringSizeLimitExceeded(username_length, MAX_USER_TTY_LENGTH)); - } - - if remote_length == 0 { - return Err(EncodeDecodeError::EmptyRemote); + pub fn from_udp_payload(cursor: &mut PayloadCursor, host: &str) -> WhereResult<Self> { + let pid = parse::read_field(cursor, |buf| Ok(i32::from_be_bytes(buf)))?; + let login_time = parse::read_field(cursor, |buf| Ok(i64::from_be_bytes(buf)))?; + let user = parse::read_string_field(cursor)?; + let tty = parse::read_string_field(cursor)?; + + let remote = { + let has_remote_tag = parse::read_bool_field(cursor)?; + if has_remote_tag { + Some(parse::read_string_field(cursor)?) + } else { + None } - - let mut remote = vec![0u8; remote_length as usize]; - Session::read_field(cursor, &mut remote)?; - - Some(String::from_utf8_lossy(&remote).to_string()) - } else { - None }; - Session::read_field(cursor, &mut active)?; - if active[0] > 1 { - return Err(EncodeDecodeError::NonbinaryBoolean); - } - - let user = String::from_utf8_lossy(&user).to_string(); - let pid = i32::from_be_bytes(pid); - let tty = String::from_utf8_lossy(&tty).to_string(); - let active = active[0] == 1; - let login_time = i64::from_be_bytes(login_time); + let active = parse::read_bool_field(cursor)?; let host = Some(host.to_string()); Ok(Self { host, - user, pid, + login_time, + user, tty, remote, active, - login_time }) } - fn read_field(cursor: &mut Cursor<[u8; MAX_PAYLOAD_LENGTH]>, buffer: &mut [u8]) -> EncodeDecodeResult<()> { - cursor.read_exact(buffer)?; - Ok(()) - } - - /*fn read_field<T, F>(cursor: &mut Cursor<&[u8]>, convert_func: F) -> WhereResult<T> - where - N: const usize, - F: FnOnce([u8; N]) -> EncodeDecodeError, - { - let mut buf = [0u8; N]; - cursor.read_exact(&mut buf)?; - - let value = convert_func(buf)?; - Ok(value) - }*/ - pub fn to_udp_payload(self) -> Vec<u8> { let mut bytes: Vec<u8> = vec![]; diff --git a/where-shared/src/parse.rs b/where-shared/src/parse.rs new file mode 100644 index 0000000..e624ce6 --- /dev/null +++ b/where-shared/src/parse.rs @@ -0,0 +1,44 @@ +use std::io::Read; + +use crate::error::{EncodeDecodeError, WhereError, WhereResult}; +use crate::MAX_USER_TTY_LENGTH; +use crate::PayloadCursor; + +pub fn read_field<const N: usize, F, T>(cursor: &mut PayloadCursor, convert_func: F) -> WhereResult<T> +where + F: Fn([u8; N]) -> WhereResult<T> +{ + let mut buffer = [0u8; N]; + cursor.read_exact(&mut buffer)?; + + let value = convert_func(buffer)?; + Ok(value) +} + +pub fn read_field_dynamic<F, T>(cursor: &mut PayloadCursor, size: usize, convert_func: F) -> WhereResult<T> +where + F: Fn(Vec<u8>) -> WhereResult<T> +{ + let mut buffer = vec![0u8; size]; + cursor.read_exact(&mut buffer)?; + + let value = convert_func(buffer)?; + Ok(value) +} + +pub fn read_bool_field(cursor: &mut PayloadCursor) -> WhereResult<bool> { + let value = read_field::<1, _, _>(cursor, |buf| Ok(buf[0] == 1))?; + Ok(value) +} + +pub fn read_string_field(cursor: &mut PayloadCursor) -> WhereResult<String> { + let string_length = read_field(cursor, |buf| Ok(u32::from_be_bytes(buf)))?; + + if string_length > MAX_USER_TTY_LENGTH as u32 { + return Err(WhereError::EncodeDecodeError(EncodeDecodeError::StringSizeLimitExceeded(string_length, MAX_USER_TTY_LENGTH))); + } + + let string = read_field_dynamic(cursor, string_length as usize, |buf| Ok(String::from_utf8(buf)?))?; + + Ok(string) +} |