summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaindropsSys <raindrops@equestria.dev>2024-03-10 16:49:49 +0100
committerRaindropsSys <raindrops@equestria.dev>2024-03-10 16:49:49 +0100
commit91cee164518399db9dfc3399f2af1c4aadc9239d (patch)
treec2a31390f525cb6da795b197e5d08ae1fed43719
parent871329b8712a401514c51b517be38a69819c9052 (diff)
parent6eb7683d34dde376233d49f94294193464b46010 (diff)
downloadwhere-rs-91cee164518399db9dfc3399f2af1c4aadc9239d.tar.gz
where-rs-91cee164518399db9dfc3399f2af1c4aadc9239d.tar.bz2
where-rs-91cee164518399db9dfc3399f2af1c4aadc9239d.zip
Fix merge conflict
-rw-r--r--where-rs/src/main.rs7
-rw-r--r--where-shared/src/error.rs16
-rw-r--r--where-shared/src/lib.rs136
-rw-r--r--where-shared/src/parse.rs44
4 files changed, 99 insertions, 104 deletions
diff --git a/where-rs/src/main.rs b/where-rs/src/main.rs
index 9bbfce5..6783a68 100644
--- a/where-rs/src/main.rs
+++ b/where-rs/src/main.rs
@@ -21,7 +21,7 @@ fn start_client() -> WhereResult<()> {
let mut entries = vec![];
for server in servers {
- entries.extend(process_server(server)?.into_vec());
+ entries.extend(process_server(server, "My Computer")?.into_vec());
}
entries.sort_by_key(|s| s.login_time);
@@ -99,7 +99,7 @@ fn start_client() -> WhereResult<()> {
Ok(())
}
-fn process_server(server: &str) -> WhereResult<SessionCollection> {
+fn process_server(server: &str, host: &str) -> WhereResult<SessionCollection> {
let socket = UdpSocket::bind("0.0.0.0:0")?;
socket.set_read_timeout(Some(TIMEOUT))?;
@@ -110,7 +110,8 @@ fn process_server(server: &str) -> WhereResult<SessionCollection> {
match socket.recv_from(&mut buf) {
Ok(_) => {
- return Ok(SessionCollection::from_udp_payload(buf, "My Computer")?);
+ let collection = SessionCollection::from_udp_payload(buf, host)?;
+ return Ok(collection);
},
Err(e) if e.kind() == ErrorKind::TimedOut || e.kind() == ErrorKind::WouldBlock => continue,
Err(e) => return Err(WhereError::from(e)),
diff --git a/where-shared/src/error.rs b/where-shared/src/error.rs
index 5908077..7ae0099 100644
--- a/where-shared/src/error.rs
+++ b/where-shared/src/error.rs
@@ -1,4 +1,5 @@
use std::fmt::Display;
+use std::string::FromUtf8Error;
use std::{fmt, io};
use std::time::Duration;
use crate::{MAX_ENTRY_LENGTH, MAX_PAYLOAD_LENGTH};
@@ -15,6 +16,7 @@ pub enum EncodeDecodeError {
BadMagic([u8; 4]),
IncorrectEntryCount,
StringSizeLimitExceeded(u32, usize),
+ StringDecodeError(FromUtf8Error),
NonbinaryBoolean,
EmptyRemote,
IOErrorWhileTranscoding(io::Error)
@@ -35,12 +37,25 @@ impl From<EncodeDecodeError> for WhereError {
}
}
+impl From<FromUtf8Error> for WhereError {
+ fn from(value: FromUtf8Error) -> Self {
+ Self::EncodeDecodeError(EncodeDecodeError::StringDecodeError(value))
+ }
+}
+
impl From<io::Error> for EncodeDecodeError {
fn from(value: io::Error) -> Self {
Self::IOErrorWhileTranscoding(value)
}
}
+
+impl From<FromUtf8Error> for EncodeDecodeError {
+ fn from (value: FromUtf8Error) -> Self {
+ Self::StringDecodeError(value)
+ }
+}
+
impl Display for EncodeDecodeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
@@ -48,6 +63,7 @@ impl Display for EncodeDecodeError {
Self::InvalidPayloadLength(s) => write!(f, "Invalid full payload length: {s} but maximum is {MAX_PAYLOAD_LENGTH}"),
Self::BadMagic(m) => write!(f, "Invalid packet magic ({}), possible corruption or invalid server", String::from_utf8_lossy(m)),
Self::IncorrectEntryCount => write!(f, "Invalid amount of entries decoded"),
+ Self::StringDecodeError(e) => write!(f, "String decoding error: {e}"),
Self::StringSizeLimitExceeded(curr, max) => write!(f, "Exceeded length limit for payload string ({curr} > {max})"),
Self::NonbinaryBoolean => write!(f, "Boolean value is not 0 or 1"),
Self::EmptyRemote => write!(f, "Remote tag set but no remote host is present"),
diff --git a/where-shared/src/lib.rs b/where-shared/src/lib.rs
index 06cc5ba..774e2fd 100644
--- a/where-shared/src/lib.rs
+++ b/where-shared/src/lib.rs
@@ -1,8 +1,9 @@
use std::io::Cursor;
use coreutils_core::os::utmpx::*;
-use std::io::Read;
-use crate::error::{EncodeDecodeError, EncodeDecodeResult};
+use crate::error::{WhereResult, EncodeDecodeResult, EncodeDecodeError};
+
+mod parse;
pub mod error;
pub const WHERED_MAGIC: [u8; 4] = *b"WHRD";
@@ -12,15 +13,18 @@ pub const MAX_ENTRY_LENGTH: usize = MAX_REMOTE_LENGTH + MAX_USER_TTY_LENGTH * 2
pub const MAX_PAYLOAD_LENGTH: usize = 65501;
pub const MAX_PAYLOAD_ENTRIES: usize = MAX_PAYLOAD_LENGTH / MAX_ENTRY_LENGTH;
+type Payload = [u8; MAX_PAYLOAD_LENGTH];
+type PayloadCursor = Cursor<Payload>;
+
#[derive(Debug)]
pub struct Session {
pub host: Option<String>,
- pub user: String,
pub pid: i32,
+ pub login_time: i64,
+ pub user: String,
pub tty: String,
pub remote: Option<String>,
pub active: bool,
- pub login_time: i64
}
#[derive(Debug)]
@@ -71,26 +75,23 @@ impl SessionCollection {
}
}
- pub fn from_udp_payload(buffer: [u8; MAX_PAYLOAD_LENGTH], host: &str) -> EncodeDecodeResult<Self> {
- let mut buf = Cursor::new(buffer);
+ pub fn from_udp_payload(buffer: Payload, host: &str) -> WhereResult<Self> {
+ let mut cursor = Cursor::new(buffer);
let mut inner = vec![];
- let mut magic = [0u8; 4];
- let mut length = [0u8; 2];
- Session::read_field(&mut buf, &mut magic)?;
- Session::read_field(&mut buf, &mut length)?;
- let entry_count = u16::from_be_bytes(length);
+ // Check magic
+ parse::read_field(&mut cursor, |buf| {
+ if buf != WHERED_MAGIC {
+ Err(EncodeDecodeError::BadMagic(buf))?
+ } else {
+ Ok(())
+ }
+ })?;
- if magic != WHERED_MAGIC {
- return Err(EncodeDecodeError::BadMagic(magic));
- }
+ let entry_count = parse::read_field(&mut cursor, |buf| Ok(u32::from_be_bytes(buf)))?;
for _ in 0..entry_count {
- inner.push(Session::from_udp_payload(&mut buf, &host)?);
- }
-
- if inner.len() != entry_count as usize {
- return Err(EncodeDecodeError::IncorrectEntryCount);
+ inner.push(Session::from_udp_payload(&mut cursor, &host)?);
}
Ok(Self {
@@ -100,103 +101,36 @@ impl SessionCollection {
}
impl Session {
- pub fn from_udp_payload(cursor: &mut Cursor<[u8; MAX_PAYLOAD_LENGTH]>, host: &str) -> EncodeDecodeResult<Self> {
- let mut username_length = [0u8; 4];
- let mut pid = [0u8; 4];
- let mut tty_length = [0u8; 4];
- let mut remote_tag = [0u8; 1];
- let mut remote_length = [0u8; 4];
- let mut active = [0u8; 1];
- let mut login_time = [0u8; 8];
-
- Session::read_field(cursor, &mut pid)?;
- Session::read_field(cursor, &mut login_time)?;
-
- Session::read_field(cursor, &mut username_length)?;
- let username_length = u32::from_be_bytes(username_length);
- if username_length as usize > MAX_USER_TTY_LENGTH {
- return Err(EncodeDecodeError::StringSizeLimitExceeded(username_length, MAX_USER_TTY_LENGTH));
- }
-
- let mut user = vec![0u8; username_length as usize];
- Session::read_field(cursor, &mut user)?;
-
- Session::read_field(cursor, &mut tty_length)?;
- let tty_length = u32::from_be_bytes(tty_length);
- if tty_length as usize > MAX_USER_TTY_LENGTH {
- return Err(EncodeDecodeError::StringSizeLimitExceeded(tty_length, MAX_USER_TTY_LENGTH));
- }
-
- let mut tty = vec![0u8; tty_length as usize];
- Session::read_field(cursor, &mut tty)?;
-
- Session::read_field(cursor, &mut remote_tag)?;
- if remote_tag[0] > 1 {
- return Err(EncodeDecodeError::NonbinaryBoolean);
- }
-
- let has_remote_tag = remote_tag[0] == 1;
-
- let remote = if has_remote_tag {
- Session::read_field(cursor, &mut remote_length)?;
- let remote_length = u32::from_be_bytes(remote_length);
- if remote_length as usize > MAX_USER_TTY_LENGTH {
- return Err(EncodeDecodeError::StringSizeLimitExceeded(username_length, MAX_USER_TTY_LENGTH));
- }
-
- if remote_length == 0 {
- return Err(EncodeDecodeError::EmptyRemote);
+ pub fn from_udp_payload(cursor: &mut PayloadCursor, host: &str) -> WhereResult<Self> {
+ let pid = parse::read_field(cursor, |buf| Ok(i32::from_be_bytes(buf)))?;
+ let login_time = parse::read_field(cursor, |buf| Ok(i64::from_be_bytes(buf)))?;
+ let user = parse::read_string_field(cursor)?;
+ let tty = parse::read_string_field(cursor)?;
+
+ let remote = {
+ let has_remote_tag = parse::read_bool_field(cursor)?;
+ if has_remote_tag {
+ Some(parse::read_string_field(cursor)?)
+ } else {
+ None
}
-
- let mut remote = vec![0u8; remote_length as usize];
- Session::read_field(cursor, &mut remote)?;
-
- Some(String::from_utf8_lossy(&remote).to_string())
- } else {
- None
};
- Session::read_field(cursor, &mut active)?;
- if active[0] > 1 {
- return Err(EncodeDecodeError::NonbinaryBoolean);
- }
-
- let user = String::from_utf8_lossy(&user).to_string();
- let pid = i32::from_be_bytes(pid);
- let tty = String::from_utf8_lossy(&tty).to_string();
- let active = active[0] == 1;
- let login_time = i64::from_be_bytes(login_time);
+ let active = parse::read_bool_field(cursor)?;
let host = Some(host.to_string());
Ok(Self {
host,
- user,
pid,
+ login_time,
+ user,
tty,
remote,
active,
- login_time
})
}
- fn read_field(cursor: &mut Cursor<[u8; MAX_PAYLOAD_LENGTH]>, buffer: &mut [u8]) -> EncodeDecodeResult<()> {
- cursor.read_exact(buffer)?;
- Ok(())
- }
-
- /*fn read_field<T, F>(cursor: &mut Cursor<&[u8]>, convert_func: F) -> WhereResult<T>
- where
- N: const usize,
- F: FnOnce([u8; N]) -> EncodeDecodeError,
- {
- let mut buf = [0u8; N];
- cursor.read_exact(&mut buf)?;
-
- let value = convert_func(buf)?;
- Ok(value)
- }*/
-
pub fn to_udp_payload(self) -> Vec<u8> {
let mut bytes: Vec<u8> = vec![];
diff --git a/where-shared/src/parse.rs b/where-shared/src/parse.rs
new file mode 100644
index 0000000..e624ce6
--- /dev/null
+++ b/where-shared/src/parse.rs
@@ -0,0 +1,44 @@
+use std::io::Read;
+
+use crate::error::{EncodeDecodeError, WhereError, WhereResult};
+use crate::MAX_USER_TTY_LENGTH;
+use crate::PayloadCursor;
+
+pub fn read_field<const N: usize, F, T>(cursor: &mut PayloadCursor, convert_func: F) -> WhereResult<T>
+where
+ F: Fn([u8; N]) -> WhereResult<T>
+{
+ let mut buffer = [0u8; N];
+ cursor.read_exact(&mut buffer)?;
+
+ let value = convert_func(buffer)?;
+ Ok(value)
+}
+
+pub fn read_field_dynamic<F, T>(cursor: &mut PayloadCursor, size: usize, convert_func: F) -> WhereResult<T>
+where
+ F: Fn(Vec<u8>) -> WhereResult<T>
+{
+ let mut buffer = vec![0u8; size];
+ cursor.read_exact(&mut buffer)?;
+
+ let value = convert_func(buffer)?;
+ Ok(value)
+}
+
+pub fn read_bool_field(cursor: &mut PayloadCursor) -> WhereResult<bool> {
+ let value = read_field::<1, _, _>(cursor, |buf| Ok(buf[0] == 1))?;
+ Ok(value)
+}
+
+pub fn read_string_field(cursor: &mut PayloadCursor) -> WhereResult<String> {
+ let string_length = read_field(cursor, |buf| Ok(u32::from_be_bytes(buf)))?;
+
+ if string_length > MAX_USER_TTY_LENGTH as u32 {
+ return Err(WhereError::EncodeDecodeError(EncodeDecodeError::StringSizeLimitExceeded(string_length, MAX_USER_TTY_LENGTH)));
+ }
+
+ let string = read_field_dynamic(cursor, string_length as usize, |buf| Ok(String::from_utf8(buf)?))?;
+
+ Ok(string)
+}