diff options
author | RaindropsSys <raindrops@equestria.dev> | 2024-03-13 22:36:31 +0100 |
---|---|---|
committer | RaindropsSys <raindrops@equestria.dev> | 2024-03-13 22:36:31 +0100 |
commit | c199fbe28e4aa8ddcd47b3e6cbbee59f11385d5f (patch) | |
tree | 8daa2d1ea32af724ea4449b8f18a67511b7feaaf | |
parent | 8aa42793332aff699807acb7f4fafb5a7d684fef (diff) | |
download | where-rs-c199fbe28e4aa8ddcd47b3e6cbbee59f11385d5f.tar.gz where-rs-c199fbe28e4aa8ddcd47b3e6cbbee59f11385d5f.tar.bz2 where-rs-c199fbe28e4aa8ddcd47b3e6cbbee59f11385d5f.zip |
Add initial config file support + bug fixes + UI fixes (warning: this sucks A LOT!)
-rw-r--r-- | Cargo.lock | 76 | ||||
-rw-r--r-- | config.toml | 80 | ||||
-rw-r--r-- | where-rs/Cargo.toml | 2 | ||||
-rw-r--r-- | where-rs/src/main.rs | 114 | ||||
-rw-r--r-- | where-shared/src/error.rs | 14 | ||||
-rw-r--r-- | where-shared/src/lib.rs | 6 | ||||
-rw-r--r-- | whered/src/main.rs | 4 |
7 files changed, 272 insertions, 24 deletions
@@ -112,6 +112,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "212d0f5754cb6769937f4501cc0e67f4f4483c8d2c3e1e922ee9edbe4ab4c7c0" [[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "hashbrown" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" + +[[package]] name = "iana-time-zone" version = "0.1.60" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -135,6 +147,16 @@ dependencies = [ ] [[package]] +name = "indexmap" +version = "2.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] name = "itoa" version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -309,6 +331,15 @@ dependencies = [ ] [[package]] +name = "serde_spanned" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1" +dependencies = [ + "serde", +] + +[[package]] name = "sha1" version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -442,6 +473,40 @@ dependencies = [ ] [[package]] +name = "toml" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af06656561d28735e9c1cd63dfd57132c8155426aa6af24f36a00a351f88c48e" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18769cd1cec395d70860ceb4d932812a0b4d06b1a4bb336745a4d21b9496e992" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + +[[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -512,6 +577,8 @@ name = "where-rs" version = "0.1.0" dependencies = [ "chrono", + "serde", + "toml", "where-shared", ] @@ -616,3 +683,12 @@ name = "windows_x86_64_msvc" version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" + +[[package]] +name = "winnow" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dffa400e67ed5a4dd237983829e66475f0a4a26938c4b04c21baede6262215b8" +dependencies = [ + "memchr", +] diff --git a/config.toml b/config.toml new file mode 100644 index 0000000..44fb53a --- /dev/null +++ b/config.toml @@ -0,0 +1,80 @@ +# where-rs: .where.toml, v1.0 2024/03/13 + +# This is the where-rs configuration file. Documentation is provided in-line. + +# where-rs is a collection of 2 programs: whered, the server-side implementation +# of the WHRD/UDP protocol, and where(1), the client-side utility. + +# This configuration file covers the client-side part of where-rs. +# If you don't know about TOML, check <https://toml.io/en/>. + +# The following options apply to all the backend servers that are connected. +[global] + +# The default time to wait for a server to respond. If overriden in a server-specific +# configuration, that value is used instead. +# Default: 2000 +#timeout = 2000 + +# How many times where(1) should retry contacting a server if it does not reply +# within the timeout. If overriden in a server-specific configuration, that value +# is used instead. +# Default: 3 +#max_retries = 3 + +# Whether inactive sessions should be shown in the output or not. This includes +# users that have logged out but their terminal is still unused, as well as +# terminals reserved for specific users that have never been used since the system +# started up. +# Default: true +#include_inactive = true + +# The default port to use for contacting a whered server if it is not specified in the +# endpoint address. The WHRD/UDP specification says the port should be 15/udp, but it +# can be changed to adapt to environments where using port 15/udp is not possible. +# Default: 15 +#port = 15 + +# The default "Source" value if there is no SSH client connected. This is the value that +# is shown for local sessions. This can also be an empty string, which is the default +# behavior for POSIX's who(1). +# Default: "Local" +#source = "Local" + +# These are server-specific configurations. There can be as many as you want, and each +# server will be processed in the order that they are in the configuration file. Only +# the "endpoint" value is required in each server configuration. +[[server]] + +# This is the address to connect to. It can be any type of address (domain name, IPv4 +# or IPv6) and will use your default DNS server if needed. The port defined in +# global.port is used if no port is specified (through :<port> at the end). +endpoint = "127.0.0.1" + +# The label that is displayed in the UI to represent this server. If this is not set, +# the "endpoint" value will be used instead. +#label = "Computer" + +# This allows you to override the timeout value on a per-server basis (to, e.g., set a +# higher timeout for servers that are farther away). This timeout will take effect on +# all the attempts to reach this server and only this server. +# Default: 2000 (unless overriden by global.timeout) +#timeout = 2000 + +# This allows you to override the maximum allowed retries on a per-server basis. This +# value will be applied only for this server and can be used to, for instance, set a +# higher threshold for more unstable servers. +# Default: 3 (unless overriden by global.max_retries) +#max_retries = 3 + +# Whether where(1) should continue processing data even if this server has not responded +# within the allowed time range. By default, where(1) will stop if a server does not +# respond after the maximum allowed retries. If this is set to true, this server will +# simply be ignored and where(1) will continue processing the next server. +# Default: false +#failsafe = false + +# Add more server configurations as you see fit: +#[[server]] +#endpoint = "10.51.0.2" +#... diff --git a/where-rs/Cargo.toml b/where-rs/Cargo.toml index 42604ab..690d02b 100644 --- a/where-rs/Cargo.toml +++ b/where-rs/Cargo.toml @@ -12,3 +12,5 @@ path = "src/main.rs" [dependencies] where-shared = { path = "../where-shared" } chrono = "0.4.35" +toml = "0.8.11" +serde = { version = "1.0.197", features = ["derive"] } diff --git a/where-rs/src/main.rs b/where-rs/src/main.rs index fcaf832..4391c52 100644 --- a/where-rs/src/main.rs +++ b/where-rs/src/main.rs @@ -1,13 +1,39 @@ -use std::net::UdpSocket; +use std::fs; +use std::net::{SocketAddr, ToSocketAddrs, UdpSocket}; use std::io::ErrorKind; use std::time::Duration; use where_shared::error::{WhereError, WhereResult}; use where_shared::{Session, SessionCollection, MAX_PAYLOAD_LENGTH, WHERED_MAGIC}; use chrono::prelude::*; +use serde::Deserialize; -pub const TIMEOUT: Duration = Duration::from_millis(2000); +pub const TIMEOUT: u64 = 2000; pub const MAX_SEND_RETRIES: usize = 3; +#[derive(Deserialize, Debug)] +struct Config { + global: Option<GlobalConfig>, + server: Vec<ServerConfig> +} + +#[derive(Deserialize, Debug)] +struct ServerConfig { + endpoint: String, + label: Option<String>, + timeout: Option<u64>, + max_retries: Option<usize>, + failsafe: Option<bool> +} + +#[derive(Deserialize, Debug, Clone, Default)] +struct GlobalConfig { + timeout: Option<u64>, + max_retries: Option<usize>, + include_inactive: Option<bool>, + port: Option<u16>, + source: Option<String> +} + fn main() { if let Err(e) = start_client() { eprintln!("where: {}", e); @@ -16,29 +42,76 @@ fn main() { } fn start_client() -> WhereResult<()> { - let servers = ["127.0.0.1:15"]; + // TODO: Make it load from an actual path: /etc/where.toml, or ~/.where.toml if it exists + let config_path = "./config.toml"; + + let config: Config = toml::from_str(&fs::read_to_string(config_path).unwrap_or_else(|e| { + eprintln!("where: Failed to open configuration file: {e}"); + std::process::exit(1); + })).unwrap_or_else(|e| { + eprintln!("where: Failed to parse configuration file: {e}"); + std::process::exit(1); + }); + + println!("{:?}", config); + let global_config = config.global.unwrap_or_default(); + + let servers: Vec<ServerConfig> = config.server; let mut sessions = vec![]; for server in servers { - sessions.extend(process_server(server, "My Computer")?.into_vec()); + // I know using .clone() sucks! + let res = match process_server(&server, global_config.clone()) { + Ok(data) => { + data + } + Err(e) => { + eprintln!("where: {e}"); + + if !server.failsafe.unwrap_or(false) { + std::process::exit(1); + } + + SessionCollection::get_empty() + } + }; + + sessions.extend(res.into_vec()); } - print_summary(sessions); + print_summary(sessions, global_config); Ok(()) } -fn process_server(server: &str, host: &str) -> WhereResult<SessionCollection> { - let socket = UdpSocket::bind("0.0.0.0:0")?; - socket.set_read_timeout(Some(TIMEOUT))?; +fn process_server(server: &ServerConfig, config: GlobalConfig) -> WhereResult<SessionCollection> { + let label = server.label.clone().unwrap_or(server.endpoint.to_owned()); + let timeout = Duration::from_millis(server.timeout.unwrap_or(config.timeout.unwrap_or(TIMEOUT))); + let retries = server.max_retries.unwrap_or(config.max_retries.unwrap_or(MAX_SEND_RETRIES)); + + let address: SocketAddr = match server.endpoint.to_socket_addrs() { + Ok(addr) => addr.as_slice()[0], + Err(_) => { + let mut endpoint = server.endpoint.clone(); + endpoint.push_str(&format!(":{}", config.port.unwrap_or(15))); + endpoint.to_socket_addrs()?.as_slice()[0] + } + }; + + let socket = UdpSocket::bind(if address.is_ipv4() { + "0.0.0.0:0" + } else { + "[::]:0" + })?; + socket.set_read_timeout(Some(timeout))?; let mut buf = [0; MAX_PAYLOAD_LENGTH]; - for _ in 0..MAX_SEND_RETRIES { - socket.send_to(&WHERED_MAGIC, server)?; + for _ in 0..retries { + socket.send_to(&WHERED_MAGIC, address)?; match socket.recv_from(&mut buf) { Ok(_) => { - let collection = SessionCollection::from_udp_payload(buf, host)?; + let collection = SessionCollection::from_udp_payload(buf, &label)?; return Ok(collection); }, Err(e) if e.kind() == ErrorKind::TimedOut || e.kind() == ErrorKind::WouldBlock => continue, @@ -46,10 +119,10 @@ fn process_server(server: &str, host: &str) -> WhereResult<SessionCollection> { } } - Err(WhereError::TimedOut(server.to_string(), MAX_SEND_RETRIES, TIMEOUT)) + Err(WhereError::TimedOut(server.endpoint.to_string(), address.to_string(), retries, timeout)) } -fn print_summary(mut sessions: Vec<Session>) { +fn print_summary(mut sessions: Vec<Session>, config: GlobalConfig) { fn max_key_with_min<T, F>(sessions: &[Session], get_key: F, floor: T) -> T where T: Ord + Default, @@ -64,23 +137,22 @@ fn print_summary(mut sessions: Vec<Session>) { sessions.sort_unstable_by_key(|s| s.login_time); - sessions.sort_by_key(|s| s.active); + sessions.sort_by_key(|s| !s.active); // We want active first - const ACTIVE_PADDING: usize = 4; + const ACTIVE_PADDING: usize = 2; let host_padding = max_key_with_min(&sessions, |s| s.host.as_deref().map_or(0, |str| str.len()), 5); let remote_padding = max_key_with_min(&sessions, |s| s.remote.as_deref().map_or(0, |str| str.len()), 7); let username_padding = max_key_with_min(&sessions, |s| s.user.len(), 5); let tty_padding = max_key_with_min(&sessions, |s| s.tty.len(), 4); let pid_padding = max_key_with_min(&sessions, |s| s.pid.abs().checked_ilog10().unwrap_or_default() + 1 + (s.pid < 0) as u32, 4); - println!("{:pad_0$} {:<pad_1$} {:<pad_2$} {:<pad_3$} {:<pad_4$} {:<pad_5$} {}", + println!("{:pad_0$} {:<pad_1$} {:<pad_2$} {:<pad_3$} {:<pad_4$} {:<pad_5$} Since", "Act", "Host", "Source", "User", "TTY", "PID", - "Since", pad_0 = ACTIVE_PADDING, pad_1 = host_padding, pad_2 = remote_padding, @@ -89,6 +161,10 @@ fn print_summary(mut sessions: Vec<Session>) { pad_5 = pid_padding as usize); for session in sessions { + if !config.include_inactive.unwrap_or(true) && !session.active { + continue; + } + let active = if session.active { '*' } else { @@ -96,12 +172,12 @@ fn print_summary(mut sessions: Vec<Session>) { }; let host = session.host.unwrap_or_else(|| ' '.to_string()); - let remote = session.remote.unwrap_or_else(|| "Local".to_owned()); + let remote = session.remote.unwrap_or_else(|| config.source.clone().unwrap_or("Local".to_string())); let datetime = DateTime::from_timestamp(session.login_time, 0).unwrap(); let time = datetime.format("%Y-%m-%d %H:%M:%S"); - println!("{:<pad_0$} {:<pad_1$} {:<pad_2$} {:<pad_3$} {:<pad_4$} {:<pad_5$} {}", + println!(" {:<pad_0$} {:<pad_1$} {:<pad_2$} {:<pad_3$} {:<pad_4$} {:<pad_5$} {}", active, host, remote, diff --git a/where-shared/src/error.rs b/where-shared/src/error.rs index 7ae0099..931aefe 100644 --- a/where-shared/src/error.rs +++ b/where-shared/src/error.rs @@ -1,13 +1,15 @@ use std::fmt::Display; use std::string::FromUtf8Error; use std::{fmt, io}; +use std::net::AddrParseError; use std::time::Duration; use crate::{MAX_ENTRY_LENGTH, MAX_PAYLOAD_LENGTH}; pub enum WhereError { EncodeDecodeError(EncodeDecodeError), IOError(io::Error), - TimedOut(String, usize, Duration) + TimedOut(String, String, usize, Duration), + CannotParseAddress(AddrParseError) } pub enum EncodeDecodeError { @@ -49,13 +51,18 @@ impl From<io::Error> for EncodeDecodeError { } } - impl From<FromUtf8Error> for EncodeDecodeError { fn from (value: FromUtf8Error) -> Self { Self::StringDecodeError(value) } } +impl From<AddrParseError> for WhereError { + fn from (value: AddrParseError) -> Self { + Self::CannotParseAddress(value) + } +} + impl Display for EncodeDecodeError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -77,7 +84,8 @@ impl Display for WhereError { match self { Self::EncodeDecodeError(e) => write!(f, "Encode/decode error: {e}"), Self::IOError(e) => write!(f, "Input/output error: {e}"), - Self::TimedOut(server, max_retry, timeout) => write!(f, "Timed out waiting for data from {server} after {max_retry} attempts every {} ms", timeout.as_millis()) + Self::TimedOut(server, address, max_retry, timeout) => write!(f, "Timed out waiting for data from {server} ({address}) after {max_retry} attempts every {} ms", timeout.as_millis()), + Self::CannotParseAddress(e) => write!(f, "Unable to parse server address: {e}") } } } diff --git a/where-shared/src/lib.rs b/where-shared/src/lib.rs index dd8042a..8073a90 100644 --- a/where-shared/src/lib.rs +++ b/where-shared/src/lib.rs @@ -44,6 +44,12 @@ impl SessionCollection { inner } } + + pub fn get_empty() -> Self { + Self { + inner: vec![] + } + } pub fn into_vec(self) -> Vec<Session> { self.inner diff --git a/whered/src/main.rs b/whered/src/main.rs index 47e97de..61e010f 100644 --- a/whered/src/main.rs +++ b/whered/src/main.rs @@ -1,4 +1,4 @@ -use std::net::UdpSocket; +use std::net::{SocketAddr, UdpSocket}; use where_shared::error::WhereResult; use where_shared::{SessionCollection, WHERED_MAGIC}; @@ -10,7 +10,7 @@ fn main() { } fn run_server() -> WhereResult<()> { - let socket = UdpSocket::bind("0.0.0.0:15")?; + let socket = UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], 15)))?; println!("Now listening on 0.0.0.0:15"); loop { |