 # and "Punica: Multi-Tenant LoRA Serving"

 import logging
-from typing import Dict, List, Set, Tuple
+from typing import Dict, Set, Tuple

 import torch

@@ -45,7 +45,6 @@ class LoRAManager:
     def __init__(
         self,
         base_model: torch.nn.Module,
-        lora_paths: Dict[str, str],
         base_hf_config: AutoConfig,
         max_loras_per_batch: int,
         load_config: LoadConfig,
@@ -55,7 +54,6 @@ def __init__(
         tp_rank: int = 0,
     ):
         self.base_model: torch.nn.Module = base_model
-        self.lora_paths: Dict[str, str] = lora_paths
         self.base_hf_config: AutoConfig = base_hf_config
         self.max_loras_per_batch: int = max_loras_per_batch
         self.load_config: LoadConfig = load_config
@@ -69,8 +67,8 @@ def __init__(
         backend_type = get_backend_from_name(lora_backend)
         self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

-        self.init_loras()
-        self.init_lora_memory_pool()
+        # Initialize mutable internal state of the LoRAManager.
+        self.init_state()

     def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -100,72 +98,49 @@ def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
             ],
         )

-    def init_loras(self):
-        # Config of each LoRA adapter
-        self.configs: Dict[str, LoRAConfig] = {}
+    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+        """
+        Load LoRA adapters from the specified paths.
+        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
+
+        Args:
+            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
+                If a LoRA adapter is already loaded, it will be skipped with a warning.
+        """
+
+        for lora_name, lora_path in lora_paths.items():
+            if lora_name in self.loras:
+                logger.warning(
+                    f"LoRA adapter {lora_name} is already loaded. "
+                    "If you want to reload it, please unload it first."
+                )
+                continue

-        # Target module names in huggingface lora configs.
-        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-        self.hf_target_names: Set[str] = set()
-        for name, path in self.lora_paths.items():
-            self.configs[name] = LoRAConfig(path)
-            self.hf_target_names.update(self.configs[name].target_modules)
+            self.configs[lora_name] = LoRAConfig(lora_path)

-        # Target lora weight names for lora_a and lora_b modules respectively.
-        weights_A: List[str] = []
-        weights_B: List[str] = []
-        for module in self.hf_target_names:
-            lora_A, lora_B = get_normalized_lora_weight_names(module)
-            weights_A += lora_A
-            weights_B += lora_B
-        self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
+        self.update_state_from_configs()

-        # load all weights to cpu
-        self.loras: Dict[str, LoRAAdapter] = {}
-        for name in self.lora_paths.keys():
-            lora_adapter = LoRAAdapter(
-                name,
-                self.configs[name],
-                self.base_hf_config,
-                self.load_config,
-                self.lora_backend,
-            )
-            lora_adapter.initialize_weights()
-            self.loras[name] = lora_adapter
-
-        # misc lora configs
-        self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+    def unload_lora_adapters(self, lora_names: Set[str]):
+        """
+        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
+        delete the corresponding LoRA modules.

-        if self.lora_backend == "flashinfer":
-            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
-            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
-            scaling = list(self.loras.values())[0].scaling
-            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
-            assert all(x.scaling == scaling for x in self.loras.values())
+        Args:
+            lora_names (Set[str]): A set of LoRA adapter names to unload.
+        """
+        for lora_name in lora_names:
+            if lora_name in self.loras:
+                del self.configs[lora_name]
+            else:
+                logger.warning(f"LoRA adapter {lora_name} is not loaded.")

-        # Convert original model layers to layers with LoRA
-        self.convert_to_lora_layers()
-
-    def init_lora_memory_pool(self):
-        # Initialize memory pool
-        self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config,
-            self.max_loras_per_batch,
-            self.max_lora_dim,
-            self.dtype,
-            self.tp_size,
-            self.tp_rank,
-            self.lora_modules,
-        )
-
-        # Initialize target lora modules in memory pool
-        self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
+        self.update_state_from_configs()

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
-        self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
+        self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)

         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
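With this hunk, adapter registration moves out of the constructor: `load_lora_adapters` only records new configs, `unload_lora_adapters` only removes them, and both delegate the actual work to `update_state_from_configs`. Below is a minimal usage sketch of how a caller might drive the new entry points; the `manager` object, adapter names, and paths are placeholders, not actual sglang server wiring.

    # Hedged usage sketch: assumes `manager` is an already-constructed LoRAManager;
    # adapter names and paths are illustrative only.
    manager.load_lora_adapters(
        {
            "adapter_a": "/path/to/adapter_a",
            "adapter_b": "/path/to/adapter_b",
        }
    )

    # Loading an already-loaded name is skipped with a warning.
    manager.load_lora_adapters({"adapter_a": "/path/to/adapter_a"})

    # Unloading drops the config and, via update_state_from_configs(),
    # evicts the cached adapter weights.
    manager.unload_lora_adapters({"adapter_b"})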
@@ -267,9 +242,16 @@ def transfer_adapter_info(
         )
         self.lora_backend.set_batch_info(batch_info)

-        # call set_lora_info for each lora modules
-        for layer_id, modules in self.lora_modules.items():
-            for module_name, module in modules:
+        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
+        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
+        self.update_lora_info()
+
+    def update_lora_info(self):
+        """
+        Update all LoRA modules to associate them with the latest memory buffer.
+        """
+        for layer_id, layer_modules in self.lora_modules.items():
+            for module_name, module in layer_modules.items():
                 if "qkv_proj" in module_name:
                     module.set_lora_info(
                         self.memory_pool.get_tensor(
@@ -295,23 +277,139 @@ def transfer_adapter_info(
                 ),
             )

+    def init_state(self):
+        """
+        Initialize the internal (mutable) state of the LoRAManager.
+
+        These states are mutated by `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
+        """
+
+        # Configs of all active LoRA adapters.
+        self.configs: Dict[str, LoRAConfig] = {}
+
+        # LoRA adapter weights cached in CPU memory.
+        self.loras: Dict[str, LoRAAdapter] = {}
+
+        # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str]] = (set(), set())
+
+        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+        self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
+            i: {} for i in range(self.base_hf_config.num_hidden_layers)
+        }
+
+        # Initialize memory pool
+        self.memory_pool = LoRAMemoryPool(
+            self.base_hf_config,
+            self.max_loras_per_batch,
+            self.dtype,
+            self.tp_size,
+            self.tp_rank,
+        )
+
+    def update_state_from_configs(self):
+        """
+        Update the internal state of the LoRAManager based on the current `self.configs`. This method
+        should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
+
+        This includes:
+        - Initializing LoRA adapters if they are not already loaded.
+        - Collecting all LoRA weight names based on the currently loaded adapters.
+        - Lazily monkey-patching the base model to use LoRA layers where applicable.
+        - Preparing the GPU buffer pool for active LoRA weights.
+        """
+
+        # Target module names in huggingface lora configs.
+        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+        hf_target_module_names: Set[str] = set()
+        for config in self.configs.values():
+            hf_target_module_names.update(config.target_modules)
+        max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+
+        # Load / unload LoRA adapters based on the latest configs.
+        self.update_lora_adapters()
+
+        # Lazily update states for new LoRA weight names (e.g., qkv_proj) as needed.
+        #
+        # Please note that the following update operations are "monotonic" by design, meaning that we update
+        # multiple places to support the new weight names when the first adapter targeting such weight names
+        # is loaded. However, we never "roll back" the support (e.g., convert a LoRA layer back to a base layer)
+        # even if the associated adapters are unloaded later, for both simplicity and practicality reasons: the
+        # list of LoRA weight names is expected to be extremely finite and stable.
+        self.update_lora_weight_names(hf_target_module_names)
+        self.update_lora_modules(hf_target_module_names)
+        self.update_memory_buffers(max_lora_dim)
+
+    def update_lora_weight_names(self, hf_target_names: Set[str]):
+        """
+        Add new LoRA weight names if needed based on the current `self.configs`.
+        """
+
+        # Target lora weight names for lora_a and lora_b modules respectively.
+        for module in hf_target_names:
+            lora_A, lora_B = get_normalized_lora_weight_names(module)
+            self.lora_weight_names[0].update(lora_A)
+            self.lora_weight_names[1].update(lora_B)
+
+    def update_lora_adapters(self):
+        """
+        Update the LoRA adapters in CPU memory based on the current `self.configs`.
+        It loads any new adapters that are not already loaded, and unloads any adapters
+        that are no longer in `self.configs` (e.g., unloaded).
+        """
+
+        # Load new adapter weights to CPU
+        for name, config in self.configs.items():
+            if name not in self.loras:
+                logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
+                lora_adapter = LoRAAdapter(
+                    name,
+                    config,
+                    self.base_hf_config,
+                    self.load_config,
+                    self.lora_backend,
+                )
+                lora_adapter.initialize_weights()
+                self.loras[name] = lora_adapter
+
+        # Clean up unused LoRA adapters (iterate over a snapshot so deletion is safe)
+        for name in list(self.loras):
+            if name not in self.configs:
+                logger.info(f"Unloading LoRA adapter {name}")
+                del self.loras[name]
+
+        # Additional checks for flashinfer backend
+        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+        if self.lora_backend == "flashinfer":
+            lora_dims = set(x.hf_config["r"] for x in self.configs.values())
+            scalings = set(x.scaling for x in self.loras.values())
+            assert (
+                len(lora_dims) == 1 and len(scalings) == 1
+            ), "Flashinfer backend currently only supports a single LoRA rank and scaling across all adapters."
+
+    def update_memory_buffers(self, max_lora_dim: int):
+        """
+        Update the LoRA memory pool buffers based on the current LoRA configurations and update
+        LoRA modules to use the new buffers. This method should be called after the LoRA configurations
+        are set or updated.
+        """
+
+        self.memory_pool.init_buffers(
+            self.lora_weight_names, self.base_model, max_lora_dim
+        )
+
     def set_lora_module(self, module_name, module):
         lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module

-    def convert_to_lora_layers(self):
+    def update_lora_modules(self, hf_target_names: Set[str]):
         # Target module names of customized layers defined in python/sglang/srt/layers
         # e.g., {"qkv_proj", "o_proj"}
         customized_target_names = get_customized_names_from_hf_names(
-            self.hf_target_names, self.base_model
+            hf_target_names, self.base_model
         )

-        # Monkey patch to use the LoRA version layers
-        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
-            i: [] for i in range(self.base_hf_config.num_hidden_layers)
-        }
-
         for module_name, module in self.base_model.named_modules():
             # TODO (lifuhuang): in the future, we should consider generalizing the
             # should_apply_lora function to support mapping by full module name instead
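In the hunk above, the new `update_lora_adapters` treats `self.configs` as the desired state and reconciles the CPU cache in `self.loras` against it: load whatever is configured but missing, drop whatever is cached but no longer configured. The following self-contained sketch illustrates that reconciliation pattern; the `Adapter` class, `reconcile` function, and paths are stand-ins for illustration, not sglang types. Note the snapshot via `list(...)` so entries can be deleted while looping.

    from typing import Dict

    class Adapter:
        """Stand-in for LoRAAdapter: just records where it was 'loaded' from."""
        def __init__(self, name: str, path: str):
            self.name, self.path = name, path

    def reconcile(configs: Dict[str, str], cache: Dict[str, Adapter]) -> None:
        # Load anything configured but not yet cached.
        for name, path in configs.items():
            if name not in cache:
                cache[name] = Adapter(name, path)
        # Unload anything cached but no longer configured; iterate over a
        # snapshot so deleting entries does not invalidate the iteration.
        for name in list(cache):
            if name not in configs:
                del cache[name]

    cache: Dict[str, Adapter] = {}
    reconcile({"sql": "/loras/sql", "chat": "/loras/chat"}, cache)
    reconcile({"chat": "/loras/chat"}, cache)  # "sql" is unloaded here
    assert set(cache) == {"chat"}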
@@ -326,6 +424,7 @@ def convert_to_lora_layers(self):
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
                 layer_id = get_layer_id(module_name)
-                self.lora_modules[layer_id].append(
-                    (module_name, self.set_lora_module(module_name, module))
-                )
+                if module_name not in self.lora_modules[layer_id]:
+                    self.lora_modules[layer_id][module_name] = self.set_lora_module(
+                        module_name, module
+                    )
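This last hunk switches `lora_modules` from per-layer lists of `(name, module)` tuples to per-layer dicts keyed by module name, so repeated reconciliation passes (one per `update_state_from_configs`) never wrap the same base layer twice. Here is a self-contained sketch of that idempotent-registration pattern, using stand-in classes (`BaseLayer`, `LoRALayer`, `wrap_with_lora`) rather than the real sglang layer types.

    from typing import Dict

    # Stand-ins for the real base / LoRA-wrapped layer classes.
    class BaseLayer: ...
    class LoRALayer(BaseLayer): ...

    def wrap_with_lora(module: BaseLayer) -> LoRALayer:
        # Placeholder for set_lora_module(): returns the monkey-patched module.
        return LoRALayer()

    # layer_id -> {module_name -> wrapped module}, mirroring the new lora_modules layout.
    lora_modules: Dict[int, Dict[str, BaseLayer]] = {0: {}}

    def register(layer_id: int, module_name: str, module: BaseLayer) -> None:
        # Registration is idempotent: a module already present is left untouched,
        # so repeated passes never double-wrap a layer.
        if module_name not in lora_modules[layer_id]:
            lora_modules[layer_id][module_name] = wrap_with_lora(module)

    qkv = BaseLayer()
    register(0, "model.layers.0.self_attn.qkv_proj", qkv)
    first = lora_modules[0]["model.layers.0.self_attn.qkv_proj"]
    register(0, "model.layers.0.self_attn.qkv_proj", qkv)  # second pass is a no-op
    assert lora_modules[0]["model.layers.0.self_attn.qkv_proj"] is first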