summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaindropsSys <raindrops@equestria.dev>2024-03-13 22:36:31 +0100
committerRaindropsSys <raindrops@equestria.dev>2024-03-13 22:36:31 +0100
commitc199fbe28e4aa8ddcd47b3e6cbbee59f11385d5f (patch)
tree8daa2d1ea32af724ea4449b8f18a67511b7feaaf
parent8aa42793332aff699807acb7f4fafb5a7d684fef (diff)
downloadwhere-rs-c199fbe28e4aa8ddcd47b3e6cbbee59f11385d5f.tar.gz
where-rs-c199fbe28e4aa8ddcd47b3e6cbbee59f11385d5f.tar.bz2
where-rs-c199fbe28e4aa8ddcd47b3e6cbbee59f11385d5f.zip
Add initial config file support + bug fixes + UI fixes (warning: this sucks A LOT!)
-rw-r--r--Cargo.lock76
-rw-r--r--config.toml80
-rw-r--r--where-rs/Cargo.toml2
-rw-r--r--where-rs/src/main.rs114
-rw-r--r--where-shared/src/error.rs14
-rw-r--r--where-shared/src/lib.rs6
-rw-r--r--whered/src/main.rs4
7 files changed, 272 insertions, 24 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 392ca06..b943e8d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -112,6 +112,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "212d0f5754cb6769937f4501cc0e67f4f4483c8d2c3e1e922ee9edbe4ab4c7c0"
[[package]]
+name = "equivalent"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
+
+[[package]]
+name = "hashbrown"
+version = "0.14.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604"
+
+[[package]]
name = "iana-time-zone"
version = "0.1.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -135,6 +147,16 @@ dependencies = [
]
[[package]]
+name = "indexmap"
+version = "2.2.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4"
+dependencies = [
+ "equivalent",
+ "hashbrown",
+]
+
+[[package]]
name = "itoa"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -309,6 +331,15 @@ dependencies = [
]
[[package]]
+name = "serde_spanned"
+version = "0.6.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1"
+dependencies = [
+ "serde",
+]
+
+[[package]]
name = "sha1"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -442,6 +473,40 @@ dependencies = [
]
[[package]]
+name = "toml"
+version = "0.8.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "af06656561d28735e9c1cd63dfd57132c8155426aa6af24f36a00a351f88c48e"
+dependencies = [
+ "serde",
+ "serde_spanned",
+ "toml_datetime",
+ "toml_edit",
+]
+
+[[package]]
+name = "toml_datetime"
+version = "0.6.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1"
+dependencies = [
+ "serde",
+]
+
+[[package]]
+name = "toml_edit"
+version = "0.22.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "18769cd1cec395d70860ceb4d932812a0b4d06b1a4bb336745a4d21b9496e992"
+dependencies = [
+ "indexmap",
+ "serde",
+ "serde_spanned",
+ "toml_datetime",
+ "winnow",
+]
+
+[[package]]
name = "unicode-ident"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -512,6 +577,8 @@ name = "where-rs"
version = "0.1.0"
dependencies = [
"chrono",
+ "serde",
+ "toml",
"where-shared",
]
@@ -616,3 +683,12 @@ name = "windows_x86_64_msvc"
version = "0.52.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
+
+[[package]]
+name = "winnow"
+version = "0.6.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dffa400e67ed5a4dd237983829e66475f0a4a26938c4b04c21baede6262215b8"
+dependencies = [
+ "memchr",
+]
diff --git a/config.toml b/config.toml
new file mode 100644
index 0000000..44fb53a
--- /dev/null
+++ b/config.toml
@@ -0,0 +1,80 @@
+# where-rs: .where.toml, v1.0 2024/03/13
+
+# This is the where-rs configuration file. Documentation is provided in-line.
+
+# where-rs is a collection of 2 programs: whered, the server-side implementation
+# of the WHRD/UDP protocol, and where(1), the client-side utility.
+
+# This configuration file covers the client-side part of where-rs.
+# If you don't know about TOML, check <https://toml.io/en/>.
+
+# The following options apply to all the backend servers that are connected.
+[global]
+
+# The default time to wait for a server to respond. If overriden in a server-specific
+# configuration, that value is used instead.
+# Default: 2000
+#timeout = 2000
+
+# How many times where(1) should retry contacting a server if it does not reply
+# within the timeout. If overriden in a server-specific configuration, that value
+# is used instead.
+# Default: 3
+#max_retries = 3
+
+# Whether inactive sessions should be shown in the output or not. This includes
+# users that have logged out but their terminal is still unused, as well as
+# terminals reserved for specific users that have never been used since the system
+# started up.
+# Default: true
+#include_inactive = true
+
+# The default port to use for contacting a whered server if it is not specified in the
+# endpoint address. The WHRD/UDP specification says the port should be 15/udp, but it
+# can be changed to adapt to environments where using port 15/udp is not possible.
+# Default: 15
+#port = 15
+
+# The default "Source" value if there is no SSH client connected. This is the value that
+# is shown for local sessions. This can also be an empty string, which is the default
+# behavior for POSIX's who(1).
+# Default: "Local"
+#source = "Local"
+
+# These are server-specific configurations. There can be as many as you want, and each
+# server will be processed in the order that they are in the configuration file. Only
+# the "endpoint" value is required in each server configuration.
+[[server]]
+
+# This is the address to connect to. It can be any type of address (domain name, IPv4
+# or IPv6) and will use your default DNS server if needed. The port defined in
+# global.port is used if no port is specified (through :<port> at the end).
+endpoint = "127.0.0.1"
+
+# The label that is displayed in the UI to represent this server. If this is not set,
+# the "endpoint" value will be used instead.
+#label = "Computer"
+
+# This allows you to override the timeout value on a per-server basis (to, e.g., set a
+# higher timeout for servers that are farther away). This timeout will take effect on
+# all the attempts to reach this server and only this server.
+# Default: 2000 (unless overriden by global.timeout)
+#timeout = 2000
+
+# This allows you to override the maximum allowed retries on a per-server basis. This
+# value will be applied only for this server and can be used to, for instance, set a
+# higher threshold for more unstable servers.
+# Default: 3 (unless overriden by global.max_retries)
+#max_retries = 3
+
+# Whether where(1) should continue processing data even if this server has not responded
+# within the allowed time range. By default, where(1) will stop if a server does not
+# respond after the maximum allowed retries. If this is set to true, this server will
+# simply be ignored and where(1) will continue processing the next server.
+# Default: false
+#failsafe = false
+
+# Add more server configurations as you see fit:
+#[[server]]
+#endpoint = "10.51.0.2"
+#...
diff --git a/where-rs/Cargo.toml b/where-rs/Cargo.toml
index 42604ab..690d02b 100644
--- a/where-rs/Cargo.toml
+++ b/where-rs/Cargo.toml
@@ -12,3 +12,5 @@ path = "src/main.rs"
[dependencies]
where-shared = { path = "../where-shared" }
chrono = "0.4.35"
+toml = "0.8.11"
+serde = { version = "1.0.197", features = ["derive"] }
diff --git a/where-rs/src/main.rs b/where-rs/src/main.rs
index fcaf832..4391c52 100644
--- a/where-rs/src/main.rs
+++ b/where-rs/src/main.rs
@@ -1,13 +1,39 @@
-use std::net::UdpSocket;
+use std::fs;
+use std::net::{SocketAddr, ToSocketAddrs, UdpSocket};
use std::io::ErrorKind;
use std::time::Duration;
use where_shared::error::{WhereError, WhereResult};
use where_shared::{Session, SessionCollection, MAX_PAYLOAD_LENGTH, WHERED_MAGIC};
use chrono::prelude::*;
+use serde::Deserialize;
-pub const TIMEOUT: Duration = Duration::from_millis(2000);
+pub const TIMEOUT: u64 = 2000;
pub const MAX_SEND_RETRIES: usize = 3;
+#[derive(Deserialize, Debug)]
+struct Config {
+ global: Option<GlobalConfig>,
+ server: Vec<ServerConfig>
+}
+
+#[derive(Deserialize, Debug)]
+struct ServerConfig {
+ endpoint: String,
+ label: Option<String>,
+ timeout: Option<u64>,
+ max_retries: Option<usize>,
+ failsafe: Option<bool>
+}
+
+#[derive(Deserialize, Debug, Clone, Default)]
+struct GlobalConfig {
+ timeout: Option<u64>,
+ max_retries: Option<usize>,
+ include_inactive: Option<bool>,
+ port: Option<u16>,
+ source: Option<String>
+}
+
fn main() {
if let Err(e) = start_client() {
eprintln!("where: {}", e);
@@ -16,29 +42,76 @@ fn main() {
}
fn start_client() -> WhereResult<()> {
- let servers = ["127.0.0.1:15"];
+ // TODO: Make it load from an actual path: /etc/where.toml, or ~/.where.toml if it exists
+ let config_path = "./config.toml";
+
+ let config: Config = toml::from_str(&fs::read_to_string(config_path).unwrap_or_else(|e| {
+ eprintln!("where: Failed to open configuration file: {e}");
+ std::process::exit(1);
+ })).unwrap_or_else(|e| {
+ eprintln!("where: Failed to parse configuration file: {e}");
+ std::process::exit(1);
+ });
+
+ println!("{:?}", config);
+ let global_config = config.global.unwrap_or_default();
+
+ let servers: Vec<ServerConfig> = config.server;
let mut sessions = vec![];
for server in servers {
- sessions.extend(process_server(server, "My Computer")?.into_vec());
+ // I know using .clone() sucks!
+ let res = match process_server(&server, global_config.clone()) {
+ Ok(data) => {
+ data
+ }
+ Err(e) => {
+ eprintln!("where: {e}");
+
+ if !server.failsafe.unwrap_or(false) {
+ std::process::exit(1);
+ }
+
+ SessionCollection::get_empty()
+ }
+ };
+
+ sessions.extend(res.into_vec());
}
- print_summary(sessions);
+ print_summary(sessions, global_config);
Ok(())
}
-fn process_server(server: &str, host: &str) -> WhereResult<SessionCollection> {
- let socket = UdpSocket::bind("0.0.0.0:0")?;
- socket.set_read_timeout(Some(TIMEOUT))?;
+fn process_server(server: &ServerConfig, config: GlobalConfig) -> WhereResult<SessionCollection> {
+ let label = server.label.clone().unwrap_or(server.endpoint.to_owned());
+ let timeout = Duration::from_millis(server.timeout.unwrap_or(config.timeout.unwrap_or(TIMEOUT)));
+ let retries = server.max_retries.unwrap_or(config.max_retries.unwrap_or(MAX_SEND_RETRIES));
+
+ let address: SocketAddr = match server.endpoint.to_socket_addrs() {
+ Ok(addr) => addr.as_slice()[0],
+ Err(_) => {
+ let mut endpoint = server.endpoint.clone();
+ endpoint.push_str(&format!(":{}", config.port.unwrap_or(15)));
+ endpoint.to_socket_addrs()?.as_slice()[0]
+ }
+ };
+
+ let socket = UdpSocket::bind(if address.is_ipv4() {
+ "0.0.0.0:0"
+ } else {
+ "[::]:0"
+ })?;
+ socket.set_read_timeout(Some(timeout))?;
let mut buf = [0; MAX_PAYLOAD_LENGTH];
- for _ in 0..MAX_SEND_RETRIES {
- socket.send_to(&WHERED_MAGIC, server)?;
+ for _ in 0..retries {
+ socket.send_to(&WHERED_MAGIC, address)?;
match socket.recv_from(&mut buf) {
Ok(_) => {
- let collection = SessionCollection::from_udp_payload(buf, host)?;
+ let collection = SessionCollection::from_udp_payload(buf, &label)?;
return Ok(collection);
},
Err(e) if e.kind() == ErrorKind::TimedOut || e.kind() == ErrorKind::WouldBlock => continue,
@@ -46,10 +119,10 @@ fn process_server(server: &str, host: &str) -> WhereResult<SessionCollection> {
}
}
- Err(WhereError::TimedOut(server.to_string(), MAX_SEND_RETRIES, TIMEOUT))
+ Err(WhereError::TimedOut(server.endpoint.to_string(), address.to_string(), retries, timeout))
}
-fn print_summary(mut sessions: Vec<Session>) {
+fn print_summary(mut sessions: Vec<Session>, config: GlobalConfig) {
fn max_key_with_min<T, F>(sessions: &[Session], get_key: F, floor: T) -> T
where
T: Ord + Default,
@@ -64,23 +137,22 @@ fn print_summary(mut sessions: Vec<Session>) {
sessions.sort_unstable_by_key(|s| s.login_time);
- sessions.sort_by_key(|s| s.active);
+ sessions.sort_by_key(|s| !s.active); // We want active first
- const ACTIVE_PADDING: usize = 4;
+ const ACTIVE_PADDING: usize = 2;
let host_padding = max_key_with_min(&sessions, |s| s.host.as_deref().map_or(0, |str| str.len()), 5);
let remote_padding = max_key_with_min(&sessions, |s| s.remote.as_deref().map_or(0, |str| str.len()), 7);
let username_padding = max_key_with_min(&sessions, |s| s.user.len(), 5);
let tty_padding = max_key_with_min(&sessions, |s| s.tty.len(), 4);
let pid_padding = max_key_with_min(&sessions, |s| s.pid.abs().checked_ilog10().unwrap_or_default() + 1 + (s.pid < 0) as u32, 4);
- println!("{:pad_0$} {:<pad_1$} {:<pad_2$} {:<pad_3$} {:<pad_4$} {:<pad_5$} {}",
+ println!("{:pad_0$} {:<pad_1$} {:<pad_2$} {:<pad_3$} {:<pad_4$} {:<pad_5$} Since",
"Act",
"Host",
"Source",
"User",
"TTY",
"PID",
- "Since",
pad_0 = ACTIVE_PADDING,
pad_1 = host_padding,
pad_2 = remote_padding,
@@ -89,6 +161,10 @@ fn print_summary(mut sessions: Vec<Session>) {
pad_5 = pid_padding as usize);
for session in sessions {
+ if !config.include_inactive.unwrap_or(true) && !session.active {
+ continue;
+ }
+
let active = if session.active {
'*'
} else {
@@ -96,12 +172,12 @@ fn print_summary(mut sessions: Vec<Session>) {
};
let host = session.host.unwrap_or_else(|| ' '.to_string());
- let remote = session.remote.unwrap_or_else(|| "Local".to_owned());
+ let remote = session.remote.unwrap_or_else(|| config.source.clone().unwrap_or("Local".to_string()));
let datetime = DateTime::from_timestamp(session.login_time, 0).unwrap();
let time = datetime.format("%Y-%m-%d %H:%M:%S");
- println!("{:<pad_0$} {:<pad_1$} {:<pad_2$} {:<pad_3$} {:<pad_4$} {:<pad_5$} {}",
+ println!(" {:<pad_0$} {:<pad_1$} {:<pad_2$} {:<pad_3$} {:<pad_4$} {:<pad_5$} {}",
active,
host,
remote,
diff --git a/where-shared/src/error.rs b/where-shared/src/error.rs
index 7ae0099..931aefe 100644
--- a/where-shared/src/error.rs
+++ b/where-shared/src/error.rs
@@ -1,13 +1,15 @@
use std::fmt::Display;
use std::string::FromUtf8Error;
use std::{fmt, io};
+use std::net::AddrParseError;
use std::time::Duration;
use crate::{MAX_ENTRY_LENGTH, MAX_PAYLOAD_LENGTH};
pub enum WhereError {
EncodeDecodeError(EncodeDecodeError),
IOError(io::Error),
- TimedOut(String, usize, Duration)
+ TimedOut(String, String, usize, Duration),
+ CannotParseAddress(AddrParseError)
}
pub enum EncodeDecodeError {
@@ -49,13 +51,18 @@ impl From<io::Error> for EncodeDecodeError {
}
}
-
impl From<FromUtf8Error> for EncodeDecodeError {
fn from (value: FromUtf8Error) -> Self {
Self::StringDecodeError(value)
}
}
+impl From<AddrParseError> for WhereError {
+ fn from (value: AddrParseError) -> Self {
+ Self::CannotParseAddress(value)
+ }
+}
+
impl Display for EncodeDecodeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
@@ -77,7 +84,8 @@ impl Display for WhereError {
match self {
Self::EncodeDecodeError(e) => write!(f, "Encode/decode error: {e}"),
Self::IOError(e) => write!(f, "Input/output error: {e}"),
- Self::TimedOut(server, max_retry, timeout) => write!(f, "Timed out waiting for data from {server} after {max_retry} attempts every {} ms", timeout.as_millis())
+ Self::TimedOut(server, address, max_retry, timeout) => write!(f, "Timed out waiting for data from {server} ({address}) after {max_retry} attempts every {} ms", timeout.as_millis()),
+ Self::CannotParseAddress(e) => write!(f, "Unable to parse server address: {e}")
}
}
}
diff --git a/where-shared/src/lib.rs b/where-shared/src/lib.rs
index dd8042a..8073a90 100644
--- a/where-shared/src/lib.rs
+++ b/where-shared/src/lib.rs
@@ -44,6 +44,12 @@ impl SessionCollection {
inner
}
}
+
+ pub fn get_empty() -> Self {
+ Self {
+ inner: vec![]
+ }
+ }
pub fn into_vec(self) -> Vec<Session> {
self.inner
diff --git a/whered/src/main.rs b/whered/src/main.rs
index 47e97de..61e010f 100644
--- a/whered/src/main.rs
+++ b/whered/src/main.rs
@@ -1,4 +1,4 @@
-use std::net::UdpSocket;
+use std::net::{SocketAddr, UdpSocket};
use where_shared::error::WhereResult;
use where_shared::{SessionCollection, WHERED_MAGIC};
@@ -10,7 +10,7 @@ fn main() {
}
fn run_server() -> WhereResult<()> {
- let socket = UdpSocket::bind("0.0.0.0:15")?;
+ let socket = UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], 15)))?;
println!("Now listening on 0.0.0.0:15");
loop {