1
1
use actix_web_lab:: sse;
2
2
use futures_util:: stream:: Stream ;
3
3
use once_cell:: sync:: Lazy ;
4
+ use rand:: rngs:: ThreadRng ;
4
5
use rand:: Rng ;
5
6
use serde:: { Deserialize , Serialize } ;
6
- use tokio:: time:: Sleep ;
7
7
use std:: future:: Future ;
8
8
use std:: pin:: Pin ;
9
9
use std:: task:: { Context , Poll } ;
10
- use rand :: rngs :: ThreadRng ;
10
+ use tokio :: time :: Sleep ;
11
11
12
12
use crate :: common:: MAX_TOKENS ;
13
13
use crate :: routes:: Usage ;
@@ -58,6 +58,8 @@ impl StreamingChunkResponse {
58
58
}
59
59
}
60
60
}
61
+ // TODO
62
+ // this can be combined with the one in routes.rs
61
63
#[ derive( Deserialize , Serialize , Debug , Default ) ]
62
64
struct Choice {
63
65
index : i32 ,
@@ -114,18 +116,20 @@ fn init_template() -> String {
114
116
impl Stream for StringsStream < ' _ > {
115
117
type Item = Result < sse:: Event , std:: convert:: Infallible > ;
116
118
119
+ // high level
120
+ // Starts with state::Input
121
+ // switch to state::Start after a random delay
122
+ // Once it reaches the end of the strings, it will switch to state::Usage if log usage is enabled
123
+ // After state::Usage, it will switch to state::Done
124
+ // If log usage is not enabled, it will switch to state::Done
125
+ // Once it reaches state::Done, it will switch to state::Completed
126
+
127
+ // init a string for faster access
128
+ // let response = StreamingChunkResponse::from_string("[INPUT]".to_string());
129
+ // let output = serde_json::to_string(&response).unwrap();
117
130
fn poll_next ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
118
131
let this = & mut * self ;
119
- // high level
120
- // Starts with state::Start
121
- // Once it reaches the end of the strings, it will switch to state::Usage if log usage is enabled
122
- // After state::Usage, it will switch to state::Done
123
- // If log usage is not enabled, it will switch to state::Done
124
- // Once it reaches state::Done, it will switch to state::Completed
125
132
126
- // init a string for faster access
127
- // let response = StreamingChunkResponse::from_string("[INPUT]".to_string());
128
- // let output = serde_json::to_string(&response).unwrap();
129
133
if let Some ( sleep) = & mut this. sleep {
130
134
if Pin :: new ( sleep) . poll ( cx) . is_pending ( ) {
131
135
return Poll :: Pending ;
@@ -134,19 +138,23 @@ impl Stream for StringsStream<'_> {
134
138
}
135
139
136
140
match this. state {
137
- State :: Input => {
141
+ State :: Input => {
138
142
// Input gives a fake TTFT
139
143
// that is your initial delay from the LLM processing the tokens
140
144
// this can typically be long
141
145
let rand = this. rng . random_range ( 500 ..1000 ) ;
142
- this. sleep = Some ( Box :: pin ( tokio:: time:: sleep ( tokio:: time:: Duration :: from_millis ( rand) ) ) ) ;
146
+ this. sleep = Some ( Box :: pin ( tokio:: time:: sleep (
147
+ tokio:: time:: Duration :: from_millis ( rand) ,
148
+ ) ) ) ;
143
149
this. state = State :: Start ;
144
150
Poll :: Pending
145
151
}
146
152
State :: Start => {
147
153
if this. index < this. max_tokens {
148
154
let rand = this. rng . random_range ( 50 ..100 ) ;
149
- this. sleep = Some ( Box :: pin ( tokio:: time:: sleep ( tokio:: time:: Duration :: from_millis ( rand) ) ) ) ;
155
+ this. sleep = Some ( Box :: pin ( tokio:: time:: sleep (
156
+ tokio:: time:: Duration :: from_millis ( rand) ,
157
+ ) ) ) ;
150
158
let string_item = & this. strings [ this. index ] ;
151
159
this. index += 1 ;
152
160
// let chunk = StreamingChunkResponse::from_string(string_item);
0 commit comments