Skip to content

Commit 8638365

Browse files
refactor: code review polish — dedup, simplify, consistency
- Move current_time_millis() to agent-tunnel-proto (R1: eliminate duplication) - Delete DomainInfo, use DomainAdvertisement directly in AgentInfo (R2) - Merge enroll_agent/bootstrap_and_persist into single function (I1) - Agent task_handles: Vec<JoinHandle> → JoinSet with reaping (I4) - Same-epoch route refresh: mutate updated_at in place, no clone (I5) - Add #[must_use] on enrollment_store::redeem() (I6) - connect_via_agent: cleaner error extraction with if-let (I3) - Add TODO for active_stream_count tracking (I2) - SECS_PER_DAY constant replaces magic 86400 (P4) - Consistent .context() for ProtoError instead of map_err (P7) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4a01ddb commit 8638365

9 files changed

Lines changed: 39 additions & 77 deletions

File tree

crates/agent-tunnel-proto/src/lib.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,14 @@ pub use error::ProtoError;
2424
pub use session::{ConnectRequest, ConnectResponse, MAX_SESSION_MESSAGE_SIZE};
2525
pub use stream::{ControlRecvStream, ControlSendStream, ControlStream, SessionStream};
2626
pub use version::{CURRENT_PROTOCOL_VERSION, MIN_SUPPORTED_VERSION, validate_protocol_version};
27+
28+
/// Current wall-clock time in milliseconds since UNIX epoch.
29+
pub fn current_time_millis() -> u64 {
30+
u64::try_from(
31+
std::time::SystemTime::now()
32+
.duration_since(std::time::UNIX_EPOCH)
33+
.expect("system time should be after unix epoch")
34+
.as_millis(),
35+
)
36+
.expect("millisecond timestamp should fit in u64")
37+
}

devolutions-agent/src/enrollment.rs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,6 @@ pub async fn enroll_agent(
5656
enrollment_token: &str,
5757
agent_name: &str,
5858
advertise_subnets: Vec<String>,
59-
) -> anyhow::Result<()> {
60-
bootstrap_and_persist(gateway_url, enrollment_token, agent_name, advertise_subnets).await?;
61-
Ok(())
62-
}
63-
64-
pub async fn bootstrap_and_persist(
65-
gateway_url: &str,
66-
enrollment_token: &str,
67-
agent_name: &str,
68-
advertise_subnets: Vec<String>,
6959
) -> anyhow::Result<PersistedEnrollment> {
7060
// Generate key pair and CSR locally — the private key never leaves this machine.
7161
let (key_pem, csr_pem) = generate_key_and_csr(agent_name)?;

devolutions-agent/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ fn main() {
301301

302302
let rt = tokio::runtime::Runtime::new().expect("failed to create tokio runtime");
303303
let result = rt.block_on(async {
304-
devolutions_agent::enrollment::bootstrap_and_persist(
304+
devolutions_agent::enrollment::enroll_agent(
305305
&command.gateway_url,
306306
&command.enrollment_token,
307307
&command.agent_name,

devolutions-agent/src/tunnel.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@
66
use std::sync::Arc;
77
use std::time::Duration;
88

9-
use agent_tunnel_proto::{ConnectResponse, ControlMessage, ControlRecvStream, ControlStream, SessionStream};
9+
use agent_tunnel_proto::{
10+
ConnectResponse, ControlMessage, ControlRecvStream, ControlStream, SessionStream, current_time_millis,
11+
};
1012
use anyhow::{Context as _, bail};
1113
use async_trait::async_trait;
1214
use devolutions_gateway_task::{ShutdownSignal, Task};
1315
use ipnetwork::Ipv4Network;
1416
use sha2::Digest as _;
1517

1618
use crate::config::ConfHandle;
17-
use crate::tunnel_helpers::{Target, connect_to_target, current_time_millis, resolve_target};
19+
use crate::tunnel_helpers::{Target, connect_to_target, resolve_target};
1820

1921
// ---------------------------------------------------------------------------
2022
// Custom TLS verifier: chain + hostname validation + SPKI pinning
@@ -349,8 +351,8 @@ async fn run_single_connection(conf_handle: &ConfHandle, shutdown_signal: &mut S
349351

350352
// Split: recv half goes to a reader task, send half stays for periodic messages.
351353
let (mut ctrl_send, ctrl_recv) = ctrl.into_split();
352-
let mut task_handles: Vec<tokio::task::JoinHandle<()>> = Vec::new();
353-
task_handles.push(tokio::spawn(run_control_reader(ctrl_recv)));
354+
let mut task_handles = tokio::task::JoinSet::new();
355+
task_handles.spawn(run_control_reader(ctrl_recv));
354356

355357
// -- Main loop: accept incoming session streams + periodic tasks --
356358

@@ -380,6 +382,7 @@ async fn run_single_connection(conf_handle: &ConfHandle, shutdown_signal: &mut S
380382
}
381383

382384
_ = heartbeat_tick.tick() => {
385+
// TODO: track actual active_stream_count instead of hardcoded 0.
383386
let msg = ControlMessage::heartbeat(current_time_millis(), 0);
384387
let _ = ctrl_send.send(&msg).await
385388
.inspect(|_| trace!("Sent Heartbeat"))
@@ -389,18 +392,15 @@ async fn run_single_connection(conf_handle: &ConfHandle, shutdown_signal: &mut S
389392
result = connection.accept_bi() => {
390393
let (send, recv) = result.context("accept incoming bidi stream")?;
391394
let subnets = advertise_subnets.clone();
392-
task_handles.push(tokio::spawn(run_session_proxy(subnets, send, recv)));
395+
task_handles.spawn(run_session_proxy(subnets, send, recv));
393396
}
397+
398+
// Reap completed session tasks.
399+
Some(_) = task_handles.join_next() => {}
394400
}
395401
}
396402

397-
// Abort all spawned tasks on shutdown.
398-
for handle in &task_handles {
399-
handle.abort();
400-
}
401-
for handle in task_handles {
402-
let _ = handle.await;
403-
}
403+
task_handles.shutdown().await;
404404

405405
Ok(())
406406
}

devolutions-agent/src/tunnel_helpers.rs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,3 @@ pub(crate) async fn connect_to_target(candidates: &[SocketAddr]) -> anyhow::Resu
8383

8484
Err(error).with_context(|| format!("TCP connect failed for {candidate}"))
8585
}
86-
87-
/// Current wall-clock time in milliseconds since UNIX epoch.
88-
pub(crate) fn current_time_millis() -> u64 {
89-
let elapsed = std::time::SystemTime::now()
90-
.duration_since(std::time::UNIX_EPOCH)
91-
.expect("system time should be after unix epoch");
92-
93-
u64::try_from(elapsed.as_millis()).expect("millisecond timestamp should fit in u64")
94-
}

devolutions-gateway/src/agent_tunnel/cert.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ const CA_VALIDITY_DAYS: u32 = 3650; // ~10 years
2020
const SERVER_CERT_VALIDITY_DAYS: u32 = 365; // 1 year
2121
const AGENT_CERT_VALIDITY_DAYS: u32 = 365; // 1 year
2222

23+
const SECS_PER_DAY: u64 = 86_400;
2324
const CA_COMMON_NAME: &str = "Devolutions Gateway Agent Tunnel CA";
2425
const CA_ORG_NAME: &str = "Devolutions Inc.";
2526

@@ -33,7 +34,8 @@ fn make_ca_params() -> CertificateParams {
3334
params.key_usages.push(KeyUsagePurpose::KeyCertSign);
3435
params.key_usages.push(KeyUsagePurpose::CrlSign);
3536
params.not_before = time::OffsetDateTime::now_utc();
36-
params.not_after = time::OffsetDateTime::now_utc() + Duration::from_secs(u64::from(CA_VALIDITY_DAYS) * 86400);
37+
params.not_after =
38+
time::OffsetDateTime::now_utc() + Duration::from_secs(u64::from(CA_VALIDITY_DAYS) * SECS_PER_DAY);
3739
params
3840
}
3941

@@ -154,7 +156,7 @@ impl CaManager {
154156
.push(ExtendedKeyUsagePurpose::ClientAuth);
155157
agent_params.not_before = time::OffsetDateTime::now_utc();
156158
agent_params.not_after =
157-
time::OffsetDateTime::now_utc() + Duration::from_secs(u64::from(AGENT_CERT_VALIDITY_DAYS) * 86400);
159+
time::OffsetDateTime::now_utc() + Duration::from_secs(u64::from(AGENT_CERT_VALIDITY_DAYS) * SECS_PER_DAY);
158160

159161
// Sign with the CA, embedding the public key from the CSR.
160162
let ca_cert = self.reconstruct_ca_cert()?;
@@ -201,7 +203,7 @@ impl CaManager {
201203
.push(ExtendedKeyUsagePurpose::ServerAuth);
202204
server_params.not_before = time::OffsetDateTime::now_utc();
203205
server_params.not_after =
204-
time::OffsetDateTime::now_utc() + Duration::from_secs(u64::from(SERVER_CERT_VALIDITY_DAYS) * 86400);
206+
time::OffsetDateTime::now_utc() + Duration::from_secs(u64::from(SERVER_CERT_VALIDITY_DAYS) * SECS_PER_DAY);
205207

206208
let ca_cert = self.reconstruct_ca_cert()?;
207209

devolutions-gateway/src/agent_tunnel/enrollment_store.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ impl EnrollmentTokenStore {
5252
///
5353
/// Returns `true` if the token was valid and has been redeemed (removed).
5454
/// Returns `false` if the token doesn't exist or is expired.
55+
#[must_use = "check whether the token was valid"]
5556
pub fn redeem(&self, token: &str) -> bool {
5657
let now = current_time_secs();
5758

devolutions-gateway/src/agent_tunnel/listener.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,19 +73,15 @@ impl AgentTunnelHandle {
7373
session
7474
.send_request(&connect_msg)
7575
.await
76-
.map_err(|e| anyhow::anyhow!("send ConnectRequest: {e}"))?;
76+
.context("send ConnectRequest")?;
7777

7878
// Read ConnectResponse (with timeout to prevent stalled peers).
7979
let response = tokio::time::timeout(Duration::from_secs(30), session.recv_response())
8080
.await
8181
.map_err(|_| anyhow::anyhow!("session handshake timeout (30s)"))?
82-
.map_err(|e| anyhow::anyhow!("recv ConnectResponse: {e}"))?;
82+
.context("recv ConnectResponse")?;
8383

84-
if !response.is_success() {
85-
let reason = match &response {
86-
ConnectResponse::Error { reason, .. } => reason.clone(),
87-
_ => "unknown".to_owned(),
88-
};
84+
if let ConnectResponse::Error { reason, .. } = &response {
8985
anyhow::bail!("agent refused connection: {reason}");
9086
}
9187

devolutions-gateway/src/agent_tunnel/registry.rs

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::sync::Arc;
22
use std::sync::atomic::{AtomicU64, Ordering};
3-
use std::time::{Duration, SystemTime, UNIX_EPOCH};
3+
use std::time::{Duration, SystemTime};
44

55
use agent_tunnel_proto::DomainAdvertisement;
66
use dashmap::DashMap;
@@ -124,16 +124,10 @@ impl AgentPeer {
124124
agent_id = %self.agent_id,
125125
epoch,
126126
subnet_count = subnets.len(),
127-
domain_count = current.domains.len(),
127+
domain_count = state.domains.len(),
128128
"Refreshing route advertisement (same epoch)"
129129
);
130-
*state = RouteAdvertisementState {
131-
epoch,
132-
subnets: current.subnets.clone(),
133-
domains: current.domains.clone(),
134-
received_at: current.received_at,
135-
updated_at: now,
136-
};
130+
state.updated_at = now;
137131
} else {
138132
// New epoch (or first advertisement): replace everything.
139133
info!(
@@ -234,13 +228,6 @@ impl Default for AgentRegistry {
234228
}
235229
}
236230

237-
/// Domain info with source tracking for API responses.
238-
#[derive(Debug, Clone, Serialize)]
239-
pub struct DomainInfo {
240-
pub domain: String,
241-
pub auto_detected: bool,
242-
}
243-
244231
/// Serializable snapshot of an agent's state, suitable for API responses.
245232
#[derive(Debug, Clone, Serialize)]
246233
pub struct AgentInfo {
@@ -250,7 +237,7 @@ pub struct AgentInfo {
250237
pub is_online: bool,
251238
pub last_seen_ms: u64,
252239
pub subnets: Vec<String>,
253-
pub domains: Vec<DomainInfo>,
240+
pub domains: Vec<DomainAdvertisement>,
254241
pub route_epoch: u64,
255242
}
256243

@@ -264,29 +251,13 @@ impl From<&Arc<AgentPeer>> for AgentInfo {
264251
is_online: agent.is_online(AGENT_OFFLINE_TIMEOUT),
265252
last_seen_ms: agent.last_seen_ms(),
266253
subnets: route_state.subnets.iter().map(ToString::to_string).collect(),
267-
domains: route_state
268-
.domains
269-
.iter()
270-
.map(|d| DomainInfo {
271-
domain: d.domain.clone(),
272-
auto_detected: d.auto_detected,
273-
})
274-
.collect(),
254+
domains: route_state.domains.clone(),
275255
route_epoch: route_state.epoch,
276256
}
277257
}
278258
}
279259

280-
/// Returns the current time as milliseconds since UNIX epoch.
281-
fn current_time_millis() -> u64 {
282-
u64::try_from(
283-
SystemTime::now()
284-
.duration_since(UNIX_EPOCH)
285-
.unwrap_or(Duration::ZERO)
286-
.as_millis(),
287-
)
288-
.expect("millisecond timestamp should fit in u64")
289-
}
260+
use agent_tunnel_proto::current_time_millis;
290261

291262
#[cfg(test)]
292263
mod tests {

0 commit comments

Comments
 (0)