Skip to content

Commit 1413931

Browse files
committed
Add a header bar and redesign the interface (oobabooga#293)
1 parent 9d6a625 commit 1413931

File tree

3 files changed

+102
-74
lines changed

3 files changed

+102
-74
lines changed

extensions/gallery/script.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def generate_html():
7676
return container_html
7777

7878
def ui():
79-
with gr.Accordion("Character gallery"):
79+
with gr.Accordion("Character gallery", open=False):
8080
update = gr.Button("Refresh")
8181
gallery = gr.HTML(value=generate_html())
8282
update.click(generate_html, [], gallery)

modules/ui.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
ol li p, ul li p {
3939
display: inline-block;
4040
}
41+
#main, #settings, #extensions, #chat-settings {
42+
border: 0;
43+
}
4144
"""
4245

4346
chat_css = """
@@ -64,6 +67,12 @@
6467
}
6568
"""
6669

70+
page_js = """
71+
document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px"
72+
document.getElementById("main").parentNode.style = "padding: 0; margin: 0"
73+
document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0"
74+
"""
75+
6776
class ToolButton(gr.Button, gr.components.FormComponent):
6877
"""Small button with single emoji as text, fits inside gradio forms"""
6978

server.py

Lines changed: 92 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ def upload_soft_prompt(file):
101101

102102
return name
103103

104-
def create_settings_menus(default_preset):
105-
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
106-
104+
def create_model_and_preset_menus():
107105
with gr.Row():
108106
with gr.Column():
109107
with gr.Row():
@@ -114,7 +112,11 @@ def create_settings_menus(default_preset):
114112
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
115113
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
116114

117-
with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'):
115+
def create_settings_menus(default_preset):
116+
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
117+
118+
with gr.Box():
119+
gr.Markdown('Custom generation parameters')
118120
with gr.Row():
119121
with gr.Column():
120122
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
@@ -128,9 +130,11 @@ def create_settings_menus(default_preset):
128130
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
129131
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
130132

133+
with gr.Box():
131134
gr.Markdown('Contrastive search:')
132135
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
133136

137+
with gr.Box():
134138
gr.Markdown('Beam search (uses a lot of VRAM):')
135139
with gr.Row():
136140
with gr.Column():
@@ -139,7 +143,8 @@ def create_settings_menus(default_preset):
139143
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
140144
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
141145

142-
with gr.Accordion('Soft prompt', open=False, elem_id='accordion'):
146+
with gr.Box():
147+
gr.Markdown('Soft prompt')
143148
with gr.Row():
144149
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
145150
ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')
@@ -202,26 +207,41 @@ def create_settings_menus(default_preset):
202207

203208
if shared.args.chat or shared.args.cai_chat:
204209
with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
205-
if shared.args.cai_chat:
206-
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
207-
else:
208-
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
209-
shared.gradio['textbox'] = gr.Textbox(label='Input')
210-
with gr.Row():
211-
shared.gradio['Stop'] = gr.Button('Stop')
212-
shared.gradio['Generate'] = gr.Button('Generate')
213-
with gr.Row():
214-
shared.gradio['Impersonate'] = gr.Button('Impersonate')
215-
shared.gradio['Regenerate'] = gr.Button('Regenerate')
216-
with gr.Row():
217-
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
218-
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
219-
shared.gradio['Remove last'] = gr.Button('Remove last')
220-
221-
shared.gradio['Clear history'] = gr.Button('Clear history')
222-
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
223-
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
224-
with gr.Tab('Chat settings'):
210+
with gr.Tab("Text generation", elem_id="main"):
211+
if shared.args.cai_chat:
212+
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
213+
else:
214+
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
215+
shared.gradio['textbox'] = gr.Textbox(label='Input')
216+
with gr.Row():
217+
shared.gradio['Stop'] = gr.Button('Stop')
218+
shared.gradio['Generate'] = gr.Button('Generate')
219+
with gr.Row():
220+
shared.gradio['Impersonate'] = gr.Button('Impersonate')
221+
shared.gradio['Regenerate'] = gr.Button('Regenerate')
222+
with gr.Row():
223+
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
224+
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
225+
shared.gradio['Remove last'] = gr.Button('Remove last')
226+
227+
shared.gradio['Clear history'] = gr.Button('Clear history')
228+
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
229+
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
230+
231+
create_model_and_preset_menus()
232+
233+
with gr.Box():
234+
with gr.Row():
235+
with gr.Column():
236+
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
237+
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
238+
with gr.Column():
239+
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
240+
241+
if shared.args.extensions is not None:
242+
extensions_module.create_extensions_block()
243+
244+
with gr.Tab("Chat settings", elem_id="chat-settings"):
225245
shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
226246
shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
227247
shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
@@ -255,21 +275,11 @@ def create_settings_menus(default_preset):
255275
with gr.Tab('Upload TavernAI Character Card'):
256276
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
257277

258-
with gr.Tab('Generation settings'):
259-
with gr.Row():
260-
with gr.Column():
261-
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
262-
with gr.Column():
263-
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
264-
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
278+
with gr.Tab("Settings", elem_id="settings"):
265279
create_settings_menus(default_preset)
266280

267-
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
268-
if shared.args.extensions is not None:
269-
with gr.Tab('Extensions'):
270-
extensions_module.create_extensions_block()
271-
272281
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
282+
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
273283

274284
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
275285
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
@@ -310,65 +320,74 @@ def create_settings_menus(default_preset):
310320
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
311321
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
312322

323+
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.page_js}}}")
313324
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
314325
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
315326

316327
elif shared.args.notebook:
317328
with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
318-
gr.Markdown(description)
319-
with gr.Tab('Raw'):
320-
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=23)
321-
with gr.Tab('Markdown'):
322-
shared.gradio['markdown'] = gr.Markdown()
323-
with gr.Tab('HTML'):
324-
shared.gradio['html'] = gr.HTML()
325-
326-
shared.gradio['Generate'] = gr.Button('Generate')
327-
shared.gradio['Stop'] = gr.Button('Stop')
328-
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
329-
330-
create_settings_menus(default_preset)
331-
if shared.args.extensions is not None:
332-
extensions_module.create_extensions_block()
329+
with gr.Tab("Text generation", elem_id="main"):
330+
with gr.Tab('Raw'):
331+
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
332+
with gr.Tab('Markdown'):
333+
shared.gradio['markdown'] = gr.Markdown()
334+
with gr.Tab('HTML'):
335+
shared.gradio['html'] = gr.HTML()
336+
337+
with gr.Row():
338+
shared.gradio['Stop'] = gr.Button('Stop')
339+
shared.gradio['Generate'] = gr.Button('Generate')
340+
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
341+
342+
create_model_and_preset_menus()
343+
if shared.args.extensions is not None:
344+
extensions_module.create_extensions_block()
345+
346+
with gr.Tab("Settings", elem_id="settings"):
347+
create_settings_menus(default_preset)
333348

334349
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
335350
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
336351
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
337352
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
338353
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
354+
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.page_js}}}")
339355

340356
else:
341357
with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
342-
gr.Markdown(description)
343-
with gr.Row():
344-
with gr.Column():
345-
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
346-
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
347-
shared.gradio['Generate'] = gr.Button('Generate')
348-
with gr.Row():
349-
with gr.Column():
350-
shared.gradio['Continue'] = gr.Button('Continue')
351-
with gr.Column():
352-
shared.gradio['Stop'] = gr.Button('Stop')
358+
with gr.Tab("Text generation", elem_id="main"):
359+
with gr.Row():
360+
with gr.Column():
361+
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
362+
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
363+
shared.gradio['Generate'] = gr.Button('Generate')
364+
with gr.Row():
365+
with gr.Column():
366+
shared.gradio['Continue'] = gr.Button('Continue')
367+
with gr.Column():
368+
shared.gradio['Stop'] = gr.Button('Stop')
353369

354-
create_settings_menus(default_preset)
355-
if shared.args.extensions is not None:
356-
extensions_module.create_extensions_block()
370+
create_model_and_preset_menus()
371+
if shared.args.extensions is not None:
372+
extensions_module.create_extensions_block()
357373

358-
with gr.Column():
359-
with gr.Tab('Raw'):
360-
shared.gradio['output_textbox'] = gr.Textbox(lines=15, label='Output')
361-
with gr.Tab('Markdown'):
362-
shared.gradio['markdown'] = gr.Markdown()
363-
with gr.Tab('HTML'):
364-
shared.gradio['html'] = gr.HTML()
374+
with gr.Column():
375+
with gr.Tab('Raw'):
376+
shared.gradio['output_textbox'] = gr.Textbox(lines=25, label='Output')
377+
with gr.Tab('Markdown'):
378+
shared.gradio['markdown'] = gr.Markdown()
379+
with gr.Tab('HTML'):
380+
shared.gradio['html'] = gr.HTML()
381+
with gr.Tab("Settings", elem_id="settings"):
382+
create_settings_menus(default_preset)
365383

366384
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
367385
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
368386
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
369387
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
370388
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
371389
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
390+
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.page_js}}}")
372391

373392
shared.gradio['interface'].queue()
374393
if shared.args.listen:

0 commit comments

Comments
 (0)