 #include "main.cpp"
 #include "extra.h"
 
+void print_tok_vec(std::vector<llama_token> & embd)
+{
+    std::cout << "[";
+    bool first = true;
+    for (auto i: embd) {
+        if (!first)
+        {
+            std::cout << ',';
+        }
+        first = false;
+        std::cout << i;
+    }
+    std::cout << "]";
+}
+
 extern "C" {
 
     struct load_model_inputs
@@ -31,7 +46,6 @@ extern "C" {
         const float top_p;
         const float rep_pen;
         const int rep_pen_range;
-        const bool reset_state = true; // determines if we can continue off the previous prompt state
     };
     struct generation_outputs
     {
@@ -43,12 +57,12 @@ extern "C" {
     llama_context_params ctx_params;
     gpt_params params;
     int n_past = 0;
-    llama_token old_embd_id = -1;
     int n_threads = 4;
     int n_batch = 8;
     std::string model;
     llama_context * ctx;
     std::vector<llama_token> last_n_tokens;
+    std::vector<llama_token> current_context_tokens;
 
     bool load_model(const load_model_inputs inputs)
     {
@@ -80,6 +94,10 @@ extern "C" {
             printf("\n---\nWarning: Your model is using an OUTDATED format. Please reconvert it for better results!\n");
         }
 
+        // determine mem per token
+        const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
+        llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
+
         return true;
     }
 
@@ -96,12 +114,6 @@ extern "C" {
         params.n_ctx = inputs.max_context_length;
         params.n_batch = n_batch;
         params.n_threads = n_threads;
-
-        bool reset_state = inputs.reset_state;
-        if (n_past == 0)
-        {
-            reset_state = true;
-        }
 
         if (params.repeat_last_n < 1)
         {
@@ -115,12 +127,9 @@ extern "C" {
         {
             params.seed = time(NULL);
         }
-
-        if (reset_state)
-        {
-            params.prompt.insert(0, 1, ' ');
-        }
-
+
+        params.prompt.insert(0, 1, ' ');
+
         // tokenize the prompt
         std::vector<llama_token> embd_inp;
         if (legacy_format)
@@ -135,7 +144,10 @@ extern "C" {
         if (embd_inp.size() + params.n_predict > params.n_ctx) {
             int offset = embd_inp.size() - params.n_ctx + params.n_predict;
             embd_inp = std::vector<llama_token>(embd_inp.begin() + offset, embd_inp.end());
-        }
+        }
+
+        // determine how much n_past we have to rewind from the current state
+
         std::vector<llama_token> embd;
 
         int last_n_size = params.repeat_last_n;
@@ -145,26 +157,30 @@ extern "C" {
         // std::string tst = " ";
         // char * tst2 = (char*)tst.c_str();
         // gpt_print_usage(1,&tst2,params);
+
+        std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+        n_past = 0;
 
-        if (reset_state)
+        // fast forward the past based on identical tokens, stop once a divergence is noted
+        for (int i = 0; i < current_context_tokens.size(); ++i)
         {
-            const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
-            llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
-            std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
-            n_past = 0;
-        }
-        else
-        {
-            // strip out the reset token (1) at the start of the embedding
-            if (embd_inp.size() > 0)
+            if (current_context_tokens[i] == embd_inp[0])
             {
+                n_past += 1;
                 embd_inp.erase(embd_inp.begin());
+                last_n_tokens.erase(last_n_tokens.begin());
+                last_n_tokens.push_back(current_context_tokens[i]);
             }
-            if (old_embd_id != -1)
+            else
             {
-                embd.push_back(old_embd_id);
+                break;
+            }
+            if (embd_inp.size() <= 1)
+            {
+                break;
             }
         }
+        current_context_tokens.resize(n_past);
 
         int remaining_tokens = params.n_predict;
         int input_consumed = 0;
@@ -180,11 +196,8 @@ extern "C" {
             // predict
             if (embd.size() > 0)
             {
-                printf("|");
-                // for (auto i: embd) {
-                //     std::cout << i << ',';
-                // }
-                // printf("\nnp:%d embd:%d",n_past,embd.size());
+                printf("|");
+                // printf("\nnp:%d embd:%d txt:%s",n_past,embd.size(),llama_token_to_str(ctx, embd[0]));
                 if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads))
                 {
                     fprintf(stderr, "Failed to predict\n");
@@ -222,13 +235,12 @@ extern "C" {
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
+                current_context_tokens.push_back(id);
             }
 
             // add it to the context
-            old_embd_id = id;
             embd.push_back(id);
 
-
             // decrement remaining sampling budget
             --remaining_tokens;
             // printf("\nid:%d word:%s\n",id,llama_token_to_str(ctx, id));
@@ -239,10 +251,10 @@ extern "C" {
             // some user input remains from prompt or interaction, forward it to processing
             while ((int) embd_inp.size() > input_consumed)
             {
-                old_embd_id = embd_inp[input_consumed];
                 embd.push_back(embd_inp[input_consumed]);
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(embd_inp[input_consumed]);
+                current_context_tokens.push_back(embd_inp[input_consumed]);
                 ++input_consumed;
                 if ((int) embd.size() >= params.n_batch)
                 {
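
A minimal standalone sketch of the prefix fast-forward idea this diff introduces, using plain ints in place of llama_token values; the helper name common_prefix is hypothetical and not part of the patch. It keeps at least one prompt token unmatched, mirroring the embd_inp.size() <= 1 early break above, so there is always something left to feed llama_eval.

// Standalone illustration only; not code from the diff above.
#include <cstdio>
#include <vector>

// Count how many leading tokens of the new prompt match the cached context,
// so evaluation can resume at that offset instead of starting from position 0.
// Always leaves at least one prompt token unmatched.
static size_t common_prefix(const std::vector<int> &cached, const std::vector<int> &prompt)
{
    size_t n = 0;
    while (n < cached.size() && n + 1 < prompt.size() && cached[n] == prompt[n])
    {
        ++n;
    }
    return n;
}

int main()
{
    std::vector<int> current_context_tokens = {1, 15, 27, 42, 99}; // tokens already evaluated
    std::vector<int> embd_inp = {1, 15, 27, 42, 7, 8};             // tokens of the new prompt

    size_t n_past = common_prefix(current_context_tokens, embd_inp);

    // Drop the already-evaluated prefix and trim the cache to the shared part,
    // analogous to embd_inp.erase(...) and current_context_tokens.resize(n_past) in the patch.
    embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
    current_context_tokens.resize(n_past);

    std::printf("reusing %zu cached tokens, %zu still to evaluate\n", n_past, embd_inp.size());
    return 0;
}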