Commit 706e19e

added ability to fast forward in time through partially duplicated prompts
1 parent 1166fda commit 706e19e

File tree: 3 files changed, +53 -53 lines

expose.cpp (47 additions, 35 deletions)

@@ -10,6 +10,21 @@
 #include "main.cpp"
 #include "extra.h"
 
+void print_tok_vec(std::vector<llama_token> & embd)
+{
+    std::cout << "[";
+    bool first = true;
+    for (auto i: embd) {
+        if(!first)
+        {
+            std::cout << ',';
+        }
+        first = false;
+        std::cout << i;
+    }
+    std::cout << "]";
+}
+
 extern "C" {
 
     struct load_model_inputs
@@ -31,7 +46,6 @@ extern "C" {
         const float top_p;
         const float rep_pen;
         const int rep_pen_range;
-        const bool reset_state = true; //determines if we can continue off the previous prompt state
     };
     struct generation_outputs
     {
@@ -43,12 +57,12 @@ extern "C" {
     llama_context_params ctx_params;
     gpt_params params;
     int n_past = 0;
-    llama_token old_embd_id = -1;
     int n_threads = 4;
     int n_batch = 8;
     std::string model;
     llama_context * ctx;
     std::vector<llama_token> last_n_tokens;
+    std::vector<llama_token> current_context_tokens;
 
     bool load_model(const load_model_inputs inputs)
     {
@@ -80,6 +94,10 @@ extern "C" {
             printf("\n---\nWarning: Your model is using an OUTDATED format. Please reconvert it for better results!\n");
         }
 
+        //determine mem per token
+        const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
+        llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
+
         return true;
     }
 
@@ -96,12 +114,6 @@ extern "C" {
         params.n_ctx = inputs.max_context_length;
         params.n_batch = n_batch;
         params.n_threads = n_threads;
-
-        bool reset_state = inputs.reset_state;
-        if(n_past==0)
-        {
-            reset_state = true;
-        }
 
         if(params.repeat_last_n<1)
         {
@@ -115,12 +127,9 @@ extern "C" {
         {
            params.seed = time(NULL);
         }
-
-        if(reset_state)
-        {
-            params.prompt.insert(0, 1, ' ');
-        }
-
+
+        params.prompt.insert(0, 1, ' ');
+
         // tokenize the prompt
         std::vector<llama_token> embd_inp;
         if(legacy_format)
@@ -135,7 +144,10 @@ extern "C" {
         if (embd_inp.size() + params.n_predict > params.n_ctx) {
             int offset = embd_inp.size() - params.n_ctx + params.n_predict;
             embd_inp = std::vector<llama_token>(embd_inp.begin() + offset, embd_inp.end());
-        }
+        }
+
+        //determine how much npast we have to rewind from the current state
+
         std::vector<llama_token> embd;
 
         int last_n_size = params.repeat_last_n;
@@ -145,26 +157,30 @@ extern "C" {
         // std::string tst = " ";
         // char * tst2 = (char*)tst.c_str();
         // gpt_print_usage(1,&tst2,params);
+
+        std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+        n_past = 0;
 
-        if(reset_state)
+        //fast forward the past based on identical tokens, stop once a divergence is noted
+        for(int i=0;i<current_context_tokens.size();++i)
         {
-            const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
-            llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
-            std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
-            n_past = 0;
-        }
-        else
-        {
-            //strip out the reset token (1) at the start of the embedding
-            if(embd_inp.size()>0)
+            if(current_context_tokens[i]==embd_inp[0])
             {
+                n_past += 1;
                 embd_inp.erase(embd_inp.begin());
+                last_n_tokens.erase(last_n_tokens.begin());
+                last_n_tokens.push_back(current_context_tokens[i]);
             }
-            if(old_embd_id!=-1)
+            else
             {
-                embd.push_back(old_embd_id);
+                break;
+            }
+            if(embd_inp.size()<=1)
+            {
+                break;
             }
         }
+        current_context_tokens.resize(n_past);
 
         int remaining_tokens = params.n_predict;
         int input_consumed = 0;
@@ -180,11 +196,8 @@ extern "C" {
            // predict
            if (embd.size() > 0)
            {
-                printf("|");
-                // for (auto i: embd) {
-                //     std::cout << i << ',';
-                // }
-                // printf("\nnp:%d embd:%d",n_past,embd.size());
+                printf("|");
+                //printf("\nnp:%d embd:%d txt:%s",n_past,embd.size(),llama_token_to_str(ctx, embd[0]));
                if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads))
                {
                    fprintf(stderr, "Failed to predict\n");
@@ -222,13 +235,12 @@ extern "C" {
 
                last_n_tokens.erase(last_n_tokens.begin());
                last_n_tokens.push_back(id);
+                current_context_tokens.push_back(id);
            }
 
            // add it to the context
-            old_embd_id = id;
            embd.push_back(id);
 
-
            // decrement remaining sampling budget
            --remaining_tokens;
            //printf("\nid:%d word:%s\n",id,llama_token_to_str(ctx, id));
@@ -239,10 +251,10 @@ extern "C" {
            // some user input remains from prompt or interaction, forward it to processing
            while ((int) embd_inp.size() > input_consumed)
            {
-                old_embd_id = embd_inp[input_consumed];
                embd.push_back(embd_inp[input_consumed]);
                last_n_tokens.erase(last_n_tokens.begin());
                last_n_tokens.push_back(embd_inp[input_consumed]);
+                current_context_tokens.push_back(embd_inp[input_consumed]);
                ++input_consumed;
                if ((int) embd.size() >= params.n_batch)
                {
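
The heart of this commit is the new loop in generate() above: instead of relying on a caller-supplied reset_state flag, the library keeps current_context_tokens (everything already evaluated into the llama context), walks it against the freshly tokenized prompt, bumps n_past for every matching token, and stops at the first divergence, so only the unseen tail of the prompt has to go through llama_eval again. Below is a minimal Python sketch of that prefix fast-forward, using hypothetical names (cached_tokens, new_tokens) rather than the real C++ state:

# Illustrative sketch only, not the shipped code: advance over the prefix shared
# by the tokens already in the model context and the newly tokenized prompt.
def fast_forward(cached_tokens, new_tokens):
    n_past = 0
    # stop at the first divergence, and always leave at least one token to evaluate
    while (n_past < len(cached_tokens)
           and len(new_tokens) - n_past > 1
           and cached_tokens[n_past] == new_tokens[n_past]):
        n_past += 1
    # the remaining tokens still need a fresh llama_eval starting at n_past
    return n_past, new_tokens[n_past:]

For example, fast_forward([1, 2, 3, 4], [1, 2, 3, 9, 9]) returns (3, [9, 9]): three tokens are reused from the cached context and only the last two need to be evaluated.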

llama_for_kobold.py (6 additions, 18 deletions)

@@ -23,8 +23,7 @@ class generation_inputs(ctypes.Structure):
                ("top_k", ctypes.c_int),
                ("top_p", ctypes.c_float),
                ("rep_pen", ctypes.c_float),
-                ("rep_pen_range", ctypes.c_int),
-                ("reset_state", ctypes.c_bool)]
+                ("rep_pen_range", ctypes.c_int)]
 
class generation_outputs(ctypes.Structure):
    _fields_ = [("status", ctypes.c_int),
@@ -48,7 +47,7 @@ def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwr
    ret = handle.load_model(inputs)
    return ret
 
-def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1.1,rep_pen_range=128,seed=-1,reset_state=True):
+def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1.1,rep_pen_range=128,seed=-1):
    inputs = generation_inputs()
    outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
    inputs.prompt = prompt.encode("UTF-8")
@@ -60,7 +59,6 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
    inputs.rep_pen = rep_pen
    inputs.rep_pen_range = rep_pen_range
    inputs.seed = seed
-    inputs.reset_state = reset_state
    ret = handle.generate(inputs,outputs)
    if(ret.status==1):
        return ret.text.decode("UTF-8")
@@ -80,7 +78,6 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
maxlen = 128
modelbusy = False
port = 5001
-last_context = ""
embedded_kailite = None
 
class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
@@ -130,7 +127,6 @@ def do_GET(self):
 
    def do_POST(self):
        global modelbusy
-        global last_context
        content_length = int(self.headers['Content-Length'])
        body = self.rfile.read(content_length)
 
@@ -159,18 +155,14 @@ def do_POST(self):
            self.end_headers()
            return
        print("\nInput: " + json.dumps(genparams))
-        fresh_state = True
+
        modelbusy = True
        if kai_api_flag:
            fullprompt = genparams.get('prompt', "")
        else:
            fullprompt = genparams.get('text', "")
        newprompt = fullprompt
-        if last_context!="" and newprompt.startswith(last_context):
-            fresh_state = False
-            newprompt = newprompt[len(last_context):]
-            print("Resuming state, new input len: " + str(len(newprompt)))
-
+
 
        recvtxt = ""
        if kai_api_flag:
@@ -183,11 +175,9 @@ def do_POST(self):
                top_p=genparams.get('top_p', 0.85),
                rep_pen=genparams.get('rep_pen', 1.1),
                rep_pen_range=genparams.get('rep_pen_range', 128),
-                seed=-1,
-                reset_state=fresh_state
+                seed=-1
                )
            print("\nOutput: " + recvtxt)
-            last_context = fullprompt + recvtxt
            res = {"results": [{"text": recvtxt}]}
            self.send_response(200)
            self.end_headers()
@@ -201,11 +191,9 @@ def do_POST(self):
                top_p=genparams.get('top_p', 0.85),
                rep_pen=genparams.get('rep_pen', 1.1),
                rep_pen_range=genparams.get('rep_pen_range', 128),
-                seed=-1,
-                reset_state=fresh_state
+                seed=-1
                )
            print("\nOutput: " + recvtxt)
-            last_context = fullprompt + recvtxt
            res = {"data": {"seqs":[recvtxt]}}
            self.send_response(200)
            self.end_headers()
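
On the Python side the change is purely subtractive: the reset_state ctypes field, the reset_state argument to generate(), and the last_context bookkeeping in the HTTP handler all disappear, since the C++ side now detects the shared prefix on its own. A hedged usage sketch of the updated generate() API follows; the model path and prompt text are made up, and the second call is only cheaper if its prompt really starts with the first prompt plus its output:

# assumes llama_for_kobold.py has been imported so load_model/generate are in scope
load_model("models/ggml-model-q4_0.bin")   # hypothetical model path

story = "It was a dark and stormy night."
out1 = generate(story, max_length=32)      # full prompt is evaluated

story += out1
out2 = generate(story, max_length=32)      # shared prefix is fast-forwarded;
                                           # only the new tokens are evaluated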

llamacpp.dll (1.08 KB): Binary file not shown.
