Commit 104293f: Add LoRA support

1 parent ee164d1

6 files changed (+51 additions, -8 deletions)

css/main.css

Lines changed: 10 additions & 1 deletion

@@ -1,12 +1,15 @@
 .tabs.svelte-710i53 {
     margin-top: 0
 }
+
 .py-6 {
     padding-top: 2.5rem
 }
+
 .dark #refresh-button {
     background-color: #ffffff1f;
 }
+
 #refresh-button {
     flex: none;
     margin: 0;
@@ -17,22 +20,28 @@
     border-radius: 10px;
     background-color: #0000000d;
 }
+
 #download-label, #upload-label {
     min-height: 0
 }
+
 #accordion {
 }
+
 .dark svg {
     fill: white;
 }
+
 svg {
     display: unset !important;
     vertical-align: middle !important;
     margin: 5px;
 }
+
 ol li p, ul li p {
     display: inline-block;
 }
-#main, #parameters, #chat-settings, #interface-mode {
+
+#main, #parameters, #chat-settings, #interface-mode, #lora {
     border: 0;
 }

download-model.py

Lines changed: 11 additions & 6 deletions

@@ -101,6 +101,7 @@ def get_download_links_from_huggingface(model, branch):
     classifications = []
     has_pytorch = False
     has_safetensors = False
+    is_lora = False
     while True:
         content = requests.get(f"{base}{page}{cursor.decode()}").content
 
@@ -110,8 +111,10 @@ def get_download_links_from_huggingface(model, branch):
 
         for i in range(len(dict)):
             fname = dict[i]['path']
+            if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
+                is_lora = True
 
-            is_pytorch = re.match("pytorch_model.*\.bin", fname)
+            is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
             is_safetensors = re.match("model.*\.safetensors", fname)
             is_tokenizer = re.match("tokenizer.*\.model", fname)
             is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
@@ -130,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
                 has_pytorch = True
                 classifications.append('pytorch')
 
+
         cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
         cursor = base64.b64encode(cursor)
         cursor = cursor.replace(b'=', b'%3D')
@@ -140,7 +144,7 @@ def get_download_links_from_huggingface(model, branch):
         if classifications[i] == 'pytorch':
             links.pop(i)
 
-    return links
+    return links, is_lora
 
 if __name__ == '__main__':
     model = args.MODEL
@@ -159,15 +163,16 @@ def get_download_links_from_huggingface(model, branch):
     except ValueError as err_branch:
         print(f"Error: {err_branch}")
         sys.exit()
+
+    links, is_lora = get_download_links_from_huggingface(model, branch)
+    base_folder = 'models' if not is_lora else 'loras'
    if branch != 'main':
-        output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
+        output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}')
     else:
-        output_folder = Path("models") / model.split('/')[-1]
+        output_folder = Path(base_folder) / model.split('/')[-1]
     if not output_folder.exists():
         output_folder.mkdir()
 
-    links = get_download_links_from_huggingface(model, branch)
-
     # Downloading the files
     print(f"Downloading the model to {output_folder}")
     pool = multiprocessing.Pool(processes=args.threads)
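
The widened regex above is what lets adapter checkpoints through the downloader's file classifier. A quick standalone check of the updated pattern (a sketch; the filenames are illustrative):

    import re

    # Pattern from the diff above: matches both full model shards
    # and LoRA adapter weights.
    pattern = r"(pytorch|adapter)_model.*\.bin"

    for fname in ["pytorch_model-00001-of-00033.bin", "adapter_model.bin", "model.safetensors"]:
        print(fname, "->", bool(re.match(pattern, fname)))
    # pytorch_model-00001-of-00033.bin -> True
    # adapter_model.bin -> True
    # model.safetensors -> False

Together with the is_lora flag, this routes any repository containing adapter_config.json or adapter_model.bin into loras/ instead of models/.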

modules/models.py

Lines changed: 2 additions & 0 deletions

@@ -11,6 +11,8 @@
 from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                           BitsAndBytesConfig)
 
+from peft import PeftModel
+
 import modules.shared as shared
 
 transformers.logging.set_verbosity_error()
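
For context, PeftModel is the peft class that layers LoRA adapter weights on top of an already-loaded base model. A minimal sketch of the usual call (the paths here are illustrative, not from this commit):

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained("models/llama-7b")   # illustrative path
    model = PeftModel.from_pretrained(base, "loras/alpaca-lora-7b")  # wraps base with the adapter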

modules/shared.py

Lines changed: 2 additions & 1 deletion

@@ -2,7 +2,8 @@
 
 model = None
 tokenizer = None
-model_name = ""
+model_name = "None"
+lora_name = "None"
 soft_prompt_tensor = None
 soft_prompt = False
 is_RWKV = False

requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@ flexgen==0.1.7
 gradio==3.18.0
 markdown
 numpy
+peft==0.2.0
 requests
 rwkv==0.4.2
 safetensors==0.3.0

server.py

Lines changed: 25 additions & 0 deletions

@@ -17,6 +17,7 @@
 from modules.html_generator import generate_chat_html
 from modules.models import load_model, load_soft_prompt
 from modules.text_generation import generate_reply
+from modules.LoRA import add_lora_to_model
 
 # Loading custom settings
 settings_file = None
@@ -48,6 +49,9 @@ def get_available_extensions():
 def get_available_softprompts():
     return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
 
+def get_available_loras():
+    return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
+
 def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
         shared.model_name = selected_model
@@ -59,6 +63,13 @@ def load_model_wrapper(selected_model):
 
     return selected_model
 
+def load_lora_wrapper(selected_lora):
+    if not shared.args.cpu:
+        gc.collect()
+        torch.cuda.empty_cache()
+    add_lora_to_model(selected_lora)
+    return selected_lora
+
 def load_preset_values(preset_menu, return_dict=False):
     generate_params = {
         'do_sample': True,
@@ -181,6 +192,7 @@ def set_interface_arguments(interface_mode, extensions, cmd_active):
 available_presets = get_available_presets()
 available_characters = get_available_characters()
 available_softprompts = get_available_softprompts()
+available_loras = get_available_loras()
 
 # Default extensions
 extensions_module.available_extensions = get_available_extensions()
@@ -401,6 +413,19 @@ def create_interface():
         shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
         shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
+        with gr.Tab("LoRA", elem_id="lora"):
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("Load")
+                    with gr.Row():
+                        shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
+                        ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
+                with gr.Column():
+                    gr.Markdown("Train (TODO)")
+                    gr.Button("Practice your button clicking skills")
+
+        shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
+
         with gr.Tab("Interface mode", elem_id="interface-mode"):
             modes = ["default", "notebook", "chat", "cai_chat"]
             current_mode = "default"
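
Note that server.py imports add_lora_to_model from modules.LoRA, but no modules/LoRA.py appears among the six files in this diff. A minimal sketch of what such a function could look like, assuming the conventions visible above (shared.model / shared.tokenizer, the loras/ folder, and "None" meaning no adapter); this is an illustration, not the commit's actual code:

    from pathlib import Path

    from peft import PeftModel

    import modules.shared as shared
    from modules.models import load_model

    def add_lora_to_model(lora_name):
        # Selecting "None" in the dropdown reloads the base model,
        # dropping any previously applied adapter.
        if lora_name == "None":
            shared.model, shared.tokenizer = load_model(shared.model_name)
        else:
            print(f"Adding the LoRA {lora_name} to the model...")
            shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"))

load_lora_wrapper in the diff above would then call this after clearing the CUDA cache.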
