Skip to content

Commit 94d49e5

Browse files
committed
fix process manager follow stream regression
1 parent 9caf11a commit 94d49e5

3 files changed

Lines changed: 144 additions & 38 deletions

File tree

src/cli/mod.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@ pub mod serve;
44
use crate::cli::client::{new_client, new_streaming_client, DEFAULT_ADDR};
55
use anyhow::{Context, Result};
66

7+
fn is_active_session(session: &serde_json::Value) -> bool {
8+
matches!(
9+
session["state"].as_str(),
10+
Some("starting" | "running" | "stopping")
11+
)
12+
}
13+
714
async fn read_json_response(resp: reqwest::Response) -> Result<serde_json::Value> {
815
let status = resp.status();
916
let body = resp.json::<serde_json::Value>().await?;
@@ -38,9 +45,10 @@ async fn resolve_id(id: Option<String>) -> Result<String> {
3845
let arr = sessions["sessions"]
3946
.as_array()
4047
.ok_or_else(|| anyhow::anyhow!("unexpected response from daemon"))?;
41-
match arr.len() {
48+
let active: Vec<&serde_json::Value> = arr.iter().filter(|s| is_active_session(s)).collect();
49+
match active.len() {
4250
0 => anyhow::bail!("no active sessions"),
43-
1 => arr[0]["id"]
51+
1 => active[0]["id"]
4452
.as_str()
4553
.map(|s| s.to_string())
4654
.ok_or_else(|| anyhow::anyhow!("session has no id")),
@@ -131,6 +139,7 @@ pub async fn cmd_stop_all() -> Result<()> {
131139
.as_array()
132140
.map(|arr| {
133141
arr.iter()
142+
.filter(|s| is_active_session(s))
134143
.filter_map(|s| s["id"].as_str().map(|s| s.to_string()))
135144
.collect()
136145
})

src/daemon/handlers.rs

Lines changed: 81 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ use futures::StreamExt;
1010
use serde::Deserialize;
1111
use std::collections::HashMap;
1212
use std::path::{Path as FsPath, PathBuf};
13-
use tokio_stream::wrappers::BroadcastStream;
1413
use uuid::Uuid;
1514

1615
use crate::daemon::buffer::{LogEntry, Stream};
@@ -327,44 +326,57 @@ pub async fn get_logs(
327326
.ok_or_else(|| AppError::NotFound("session not found".into()))?;
328327

329328
// Streaming follow mode
330-
// Both stream_name and format are owned Strings — safe to move into the closure
331-
let stream_name_clone = stream_name.clone();
332329
let format_clone = format.clone();
333-
let stream_mode = stream_name.clone();
334-
let stream = BroadcastStream::new(rx).flat_map(move |msg| {
335-
let line = match msg {
336-
Ok(entry) => {
337-
if entry.seq < snapshot_next_seq {
338-
return futures::stream::empty().left_stream();
339-
}
340-
// Filter: for blended, show all; for stdout/stderr, show only that stream; for unknown, show nothing
341-
let should_include = match stream_name_clone.as_str() {
342-
"blended" => true,
343-
"stdout" => entry.stream == Stream::Stdout,
344-
"stderr" => entry.stream == Stream::Stderr,
345-
_ => false, // unknown stream name → no entries
346-
};
347-
if !should_include {
348-
return futures::stream::empty().left_stream();
349-
}
350-
format_entry(&entry, &stream_mode, &format_clone)
351-
}
352-
Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => {
353-
if format_clone == "text" {
354-
format!("[korun: dropped {n} lines]\n")
355-
} else {
356-
let e = serde_json::json!({
357-
"seq": 0u64,
358-
"ts": Utc::now().to_rfc3339(),
359-
"stream": "system",
360-
"line": format!("dropped {n} lines"),
361-
});
362-
format!("{e}\n")
330+
let mgr_for_follow = mgr.clone();
331+
let stream = futures::stream::unfold(
332+
FollowState {
333+
rx,
334+
mgr: mgr_for_follow,
335+
session_id: id,
336+
stream_name: stream_name.clone(),
337+
format: format_clone,
338+
cutoff_seq: snapshot_next_seq,
339+
last_seen_seq: snapshot_next_seq,
340+
},
341+
|mut state| async move {
342+
loop {
343+
tokio::select! {
344+
msg = state.rx.recv() => {
345+
match msg {
346+
Ok(entry) => {
347+
if entry.seq < state.cutoff_seq {
348+
continue;
349+
}
350+
if !should_include_stream(&state.stream_name, entry.stream) {
351+
continue;
352+
}
353+
state.last_seen_seq = state.last_seen_seq.max(entry.seq.saturating_add(1));
354+
let line = format_entry(&entry, &state.stream_name, &state.format);
355+
return Some((Ok::<_, std::convert::Infallible>(line), state));
356+
}
357+
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
358+
let line = format_lag_notice(n, &state.format);
359+
return Some((Ok::<_, std::convert::Infallible>(line), state));
360+
}
361+
Err(tokio::sync::broadcast::error::RecvError::Closed) => return None,
362+
}
363+
}
364+
_ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
365+
let (is_terminal, next_seq) = match state
366+
.mgr
367+
.with(&state.session_id, |s| (matches!(s.state, SessionState::Exited | SessionState::Failed), s.next_seq()))
368+
{
369+
Some(values) => values,
370+
None => return None,
371+
};
372+
if is_terminal && next_seq <= state.last_seen_seq {
373+
return None;
374+
}
375+
}
363376
}
364377
}
365-
};
366-
futures::stream::once(async move { Ok::<_, std::convert::Infallible>(line) }).right_stream()
367-
});
378+
},
379+
);
368380

369381
// First send buffered entries, then stream new ones.
370382
// Use into_iter() + move to avoid capturing &format (which would prevent 'static bound).
@@ -454,6 +466,39 @@ fn format_entry(entry: &LogEntry, requested_stream: &str, format: &str) -> Strin
454466
}
455467
}
456468

469+
fn should_include_stream(stream_name: &str, stream: Stream) -> bool {
470+
match stream_name {
471+
"blended" => true,
472+
"stdout" => stream == Stream::Stdout,
473+
"stderr" => stream == Stream::Stderr,
474+
_ => false,
475+
}
476+
}
477+
478+
fn format_lag_notice(n: u64, format: &str) -> String {
479+
if format == "text" {
480+
format!("[korun: dropped {n} lines]\n")
481+
} else {
482+
let e = serde_json::json!({
483+
"seq": 0u64,
484+
"ts": Utc::now().to_rfc3339(),
485+
"stream": "system",
486+
"line": format!("dropped {n} lines"),
487+
});
488+
format!("{e}\n")
489+
}
490+
}
491+
492+
struct FollowState {
493+
rx: tokio::sync::broadcast::Receiver<LogEntry>,
494+
mgr: AppState,
495+
session_id: Uuid,
496+
stream_name: String,
497+
format: String,
498+
cutoff_seq: u64,
499+
last_seen_seq: u64,
500+
}
501+
457502
fn validate_create_session_request(
458503
command: &[String],
459504
cwd: &str,

tests/http_test.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use korun::daemon::session_mgr::SessionManager;
55
use serde_json::Value;
66
use tempfile::tempdir;
77
use tower::util::ServiceExt;
8+
use tokio::time::{timeout, Duration};
89

910
#[tokio::test]
1011
async fn healthz_returns_200() {
@@ -156,3 +157,54 @@ async fn create_session_retains_watcher_handle() {
156157
let id: uuid::Uuid = json["id"].as_str().unwrap().parse().unwrap();
157158
assert!(mgr.with(&id, |s| s.watcher.is_some()).unwrap());
158159
}
160+
161+
#[tokio::test]
162+
async fn follow_logs_closes_after_process_exits() {
163+
let mgr = SessionManager::new();
164+
let app = build_router(mgr);
165+
166+
let body = serde_json::json!({
167+
"command": ["sh", "-c", "echo hi"]
168+
});
169+
170+
let response = app
171+
.clone()
172+
.oneshot(
173+
Request::builder()
174+
.method("POST")
175+
.uri("/v1/sessions")
176+
.header("content-type", "application/json")
177+
.body(Body::from(serde_json::to_vec(&body).unwrap()))
178+
.unwrap(),
179+
)
180+
.await
181+
.unwrap();
182+
183+
assert_eq!(response.status(), StatusCode::CREATED);
184+
let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
185+
.await
186+
.unwrap();
187+
let json: Value = serde_json::from_slice(&bytes).unwrap();
188+
let id = json["id"].as_str().unwrap();
189+
190+
let response = app
191+
.oneshot(
192+
Request::builder()
193+
.uri(format!("/v1/sessions/{id}/logs?follow=1&format=text"))
194+
.body(Body::empty())
195+
.unwrap(),
196+
)
197+
.await
198+
.unwrap();
199+
200+
assert_eq!(response.status(), StatusCode::OK);
201+
let bytes = timeout(
202+
Duration::from_secs(2),
203+
axum::body::to_bytes(response.into_body(), usize::MAX),
204+
)
205+
.await
206+
.expect("follow stream should close after process exit")
207+
.unwrap();
208+
let text = String::from_utf8(bytes.to_vec()).unwrap();
209+
assert!(text.contains("[stdout] hi"));
210+
}

0 commit comments

Comments
 (0)