
Commit ca2d455

lifuhuang authored and whybeyoung committed
Refactor LoRAManager and LoRAMemoryPool state management logic for dynamic LoRA loading support (sgl-project#7412)
1 parent 739f303 commit ca2d455

4 files changed, +228 -121 lines changed

python/sglang/srt/lora/lora_manager.py

Lines changed: 173 additions & 74 deletions
@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"

 import logging
-from typing import Dict, List, Set, Tuple
+from typing import Dict, Set, Tuple

 import torch

@@ -45,7 +45,6 @@ class LoRAManager:
     def __init__(
         self,
         base_model: torch.nn.Module,
-        lora_paths: Dict[str, str],
         base_hf_config: AutoConfig,
         max_loras_per_batch: int,
         load_config: LoadConfig,
@@ -55,7 +54,6 @@ def __init__(
         tp_rank: int = 0,
     ):
         self.base_model: torch.nn.Module = base_model
-        self.lora_paths: Dict[str, str] = lora_paths
         self.base_hf_config: AutoConfig = base_hf_config
         self.max_loras_per_batch: int = max_loras_per_batch
         self.load_config: LoadConfig = load_config
@@ -69,8 +67,8 @@ def __init__(
         backend_type = get_backend_from_name(lora_backend)
         self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

-        self.init_loras()
-        self.init_lora_memory_pool()
+        # Initialize mutable internal state of the LoRAManager.
+        self.init_state()

     def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
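Note on the hunk above: the constructor no longer receives `lora_paths`; it only sets up empty mutable state via `init_state()`, and adapters are registered afterwards. A minimal standalone sketch of that construct-then-load lifecycle (the `ToyLoRAManager` below is a hypothetical stand-in, not the real sglang `LoRAManager`):

```python
# Hypothetical, simplified model of the construct-then-load lifecycle enabled by
# this change (ToyLoRAManager is illustrative only, not the sglang LoRAManager).
from typing import Dict


class ToyLoRAManager:
    def __init__(self, max_loras_per_batch: int):
        self.max_loras_per_batch = max_loras_per_batch
        # Mutable state starts empty; no adapter paths are required up front.
        self.init_state()

    def init_state(self) -> None:
        # Adapter "configs" keyed by adapter name (paths stand in for LoRAConfig).
        self.configs: Dict[str, str] = {}

    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> None:
        # Adapters can be registered at any time after construction; names that
        # are already present are left untouched, mirroring the skip-with-warning
        # behaviour of the real method.
        for name, path in lora_paths.items():
            self.configs.setdefault(name, path)


manager = ToyLoRAManager(max_loras_per_batch=4)
manager.load_lora_adapters({"my-adapter": "/path/to/adapter"})  # dynamic load
print(sorted(manager.configs))
```

The real manager layers LoRAConfig parsing, CPU weight loading, and memory-pool updates on top of the same idea.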
@@ -100,72 +98,49 @@ def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
             ],
         )

-    def init_loras(self):
-        # Config of each LoRA adapter
-        self.configs: Dict[str, LoRAConfig] = {}
+    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+        """
+        Load LoRA adapters from the specified paths.
+        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
+
+        Args:
+            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
+                If a LoRA adapter is already loaded, it will be skipped with a warning.
+        """
+
+        for lora_name, lora_path in lora_paths.items():
+            if lora_name in self.loras:
+                logger.warning(
+                    f"LoRA adapter {lora_name} is already loaded."
+                    "If you want to reload it, please unload it first."
+                )
+                continue

-        # Target module names in huggingface lora configs.
-        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-        self.hf_target_names: Set[str] = set()
-        for name, path in self.lora_paths.items():
-            self.configs[name] = LoRAConfig(path)
-            self.hf_target_names.update(self.configs[name].target_modules)
+            self.configs[lora_name] = LoRAConfig(lora_path)

-        # Target lora weight names for lora_a and lora_b modules respectively.
-        weights_A: List[str] = []
-        weights_B: List[str] = []
-        for module in self.hf_target_names:
-            lora_A, lora_B = get_normalized_lora_weight_names(module)
-            weights_A += lora_A
-            weights_B += lora_B
-        self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
+        self.update_state_from_configs()

-        # load all weights to cpu
-        self.loras: Dict[str, LoRAAdapter] = {}
-        for name in self.lora_paths.keys():
-            lora_adapter = LoRAAdapter(
-                name,
-                self.configs[name],
-                self.base_hf_config,
-                self.load_config,
-                self.lora_backend,
-            )
-            lora_adapter.initialize_weights()
-            self.loras[name] = lora_adapter
-
-        # misc lora configs
-        self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+    def unload_lora_adapters(self, lora_names: Set[str]):
+        """
+        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
+        delete the corresponding LoRA modules.

-        if self.lora_backend == "flashinfer":
-            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
-            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
-            scaling = list(self.loras.values())[0].scaling
-            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
-            assert all(x.scaling == scaling for x in self.loras.values())
+        Args:
+            lora_names (Set[str]): A set of LoRA adapter names to unload.
+        """
+        for lora_name in lora_names:
+            if lora_name in self.loras:
+                del self.configs[lora_name]
+            else:
+                logger.warning(f"LoRA adapter {lora_name} is not loaded.")

-        # Convert original model layers to layers with LoRA
-        self.convert_to_lora_layers()
-
-    def init_lora_memory_pool(self):
-        # Initialize memory pool
-        self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config,
-            self.max_loras_per_batch,
-            self.max_lora_dim,
-            self.dtype,
-            self.tp_size,
-            self.tp_rank,
-            self.lora_modules,
-        )
-
-        # Initialize target lora modules in memory pool
-        self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
+        self.update_state_from_configs()

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
-        self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
+        self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)

         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
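The two new public methods above are intentionally thin: both only edit `self.configs` and then funnel into a single `update_state_from_configs()` reconciliation step. A self-contained sketch of that config-driven pattern, with a simplified `Registry` class and string "weights" standing in for real adapters (these names are illustrative, not sglang APIs):

```python
# Sketch of the config-driven pattern: the load and unload paths only edit a
# `configs` dict (desired state) and then call one reconciliation routine that
# brings the actual state in line with it.
import logging
from typing import Dict, Set

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("lora-registry")


class Registry:
    def __init__(self) -> None:
        self.configs: Dict[str, str] = {}   # desired state: name -> path
        self.loaded: Dict[str, str] = {}    # actual state: name -> "weights"

    def load(self, paths: Dict[str, str]) -> None:
        for name, path in paths.items():
            if name in self.loaded:
                logger.warning("adapter %s already loaded; unload it first", name)
                continue
            self.configs[name] = path
        self.reconcile()

    def unload(self, names: Set[str]) -> None:
        for name in names:
            if name in self.loaded:
                del self.configs[name]
            else:
                logger.warning("adapter %s is not loaded", name)
        self.reconcile()

    def reconcile(self) -> None:
        # Load anything present in configs but not yet loaded.
        for name, path in self.configs.items():
            self.loaded.setdefault(name, f"weights@{path}")
        # Drop anything loaded that is no longer configured; iterate over a
        # snapshot so deleting keys is safe.
        for name in list(self.loaded):
            if name not in self.configs:
                del self.loaded[name]


reg = Registry()
reg.load({"a": "/tmp/a"})
reg.unload({"a", "b"})   # "b" only triggers a warning; "a" is removed
print(reg.loaded)        # {}
```

Iterating over a `list(...)` snapshot when dropping stale entries is a general Python precaution for deleting keys during iteration; the real reconciliation additionally rebuilds weight names, LoRA modules, and GPU buffers.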
@@ -267,9 +242,16 @@ def transfer_adapter_info(
         )
         self.lora_backend.set_batch_info(batch_info)

-        # call set_lora_info for each lora modules
-        for layer_id, modules in self.lora_modules.items():
-            for module_name, module in modules:
+        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
+        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
+        self.update_lora_info()
+
+    def update_lora_info(self):
+        """
+        Update all LoRA modules to associate them with the latest memory buffer.
+        """
+        for layer_id, layer_modules in self.lora_modules.items():
+            for module_name, module in layer_modules.items():
                 if "qkv_proj" in module_name:
                     module.set_lora_info(
                         self.memory_pool.get_tensor(
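This hunk also reflects the new shape of `self.lora_modules`: a per-layer dict keyed by module name instead of a per-layer list of `(name, module)` tuples, which is what lets `update_lora_info` re-point every module at the latest memory-pool buffers. A tiny standalone illustration of that lookup-table shape and iteration pattern (plain strings stand in for the real `BaseLayerWithLoRA` modules):

```python
# Standalone illustration of the lookup table used by update_lora_info:
# layer_index -> {module_name -> module}.
from typing import Dict

num_hidden_layers = 2
lora_modules: Dict[int, Dict[str, str]] = {i: {} for i in range(num_hidden_layers)}

# Registering a module is now a keyed insert instead of a list append.
lora_modules[0]["model.layers.0.self_attn.qkv_proj"] = "<lora qkv_proj layer>"
lora_modules[0]["model.layers.0.self_attn.o_proj"] = "<lora o_proj layer>"

for layer_id, layer_modules in lora_modules.items():
    for module_name, module in layer_modules.items():
        # In the real code, module.set_lora_info(...) is called here to attach
        # the current memory-pool tensors.
        print(layer_id, module_name, module)
```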
@@ -295,23 +277,139 @@ def transfer_adapter_info(
                         ),
                     )

+    def init_state(self):
+        """
+        Initialize the internal (mutable) state of the LoRAManager.
+
+        These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
+        """
+
+        # Configs of all active LoRA adapters.
+        self.configs: Dict[str, LoRAConfig] = {}
+
+        # LoRA adapter weights cached in CPU memory.
+        self.loras: Dict[str, LoRAAdapter] = {}
+
+        # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str]] = (set(), set())
+
+        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+        self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
+            i: {} for i in range(self.base_hf_config.num_hidden_layers)
+        }
+
+        # Initialize memory pool
+        self.memory_pool = LoRAMemoryPool(
+            self.base_hf_config,
+            self.max_loras_per_batch,
+            self.dtype,
+            self.tp_size,
+            self.tp_rank,
+        )
+
+    def update_state_from_configs(self):
+        """
+        Update the internal state of the LoRAManager based on the current `self.configs`. This method
+        should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
+
+        This includes:
+        - Initializing LoRA adapters if they are not already loaded.
+        - Collect all LoRA weight names based on the current loaded adapters.
+        - Lazily monkey-patching the base model to use LoRA layers where applicable.
+        - Preparing the GPU buffer pool for active LoRA weights.
+        """
+
+        # Target module names in huggingface lora configs.
+        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+        hf_target_module_names: Set[str] = set()
+        for config in self.configs.values():
+            hf_target_module_names.update(config.target_modules)
+        max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+
+        # Loads / unloads LoRA adapters based on the latest configs.
+        self.update_lora_adapters()
+
+        # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
+        #
+        # Please note that the following update operations are "monotonic" by design, meaning that we update
+        # multiple places to support the new weight names when the first adapter targeting such weight names
+        # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
+        # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
+        # list of LoRA weight names is expected to be extremely finite and stable.
+        self.update_lora_weight_names(hf_target_module_names)
+        self.update_lora_modules(hf_target_module_names)
+        self.update_memory_buffers(max_lora_dim)
+
+    def update_lora_weight_names(self, hf_target_names: Set[str]):
+        """
+        Add new LoRA weight names if needed based on the current `self.configs`.
+        """
+
+        # Target lora weight names for lora_a and lora_b modules respectively.
+        for module in hf_target_names:
+            lora_A, lora_B = get_normalized_lora_weight_names(module)
+            self.lora_weight_names[0].update(lora_A)
+            self.lora_weight_names[1].update(lora_B)
+
+    def update_lora_adapters(self):
+        """
+        Update the LoRA adapters in CPU memory based on the current `self.configs`.
+        It loads any new adapters that are not already loaded, and unloads any adapters
+        that are no longer in `self.configs` (e.g., unloaded).
+        """
+
+        # Load new adapter weights to cpu
+        for name, config in self.configs.items():
+            if name not in self.loras:
+                logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
+                lora_adapter = LoRAAdapter(
+                    name,
+                    config,
+                    self.base_hf_config,
+                    self.load_config,
+                    self.lora_backend,
+                )
+                lora_adapter.initialize_weights()
+                self.loras[name] = lora_adapter
+
+        # Clean up unused LoRA adapters
+        for name in self.loras:
+            if name not in self.configs:
+                logger.info(f"Unloading LoRA adapter {name}")
+                del self.loras[name]
+
+        # Additional checks for flashinfer backend
+        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+        if self.lora_backend == "flashinfer":
+            lora_dims = set(x.hf_config["r"] for x in self.configs.values())
+            scalings = set(x.scaling for x in self.loras.values())
+            assert (
+                len(lora_dims) == 1 and len(scalings) == 1
+            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
+
+    def update_memory_buffers(self, max_lora_dim: int):
+        """
+        Update the LoRA memory pool buffers based on the current LoRA configurations and update
+        LoRA modules to use the new buffers. This method should be called after the LoRA configurations
+        are set or updated.
+        """
+
+        self.memory_pool.init_buffers(
+            self.lora_weight_names, self.base_model, max_lora_dim
+        )
+
     def set_lora_module(self, module_name, module):
         lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module

-    def convert_to_lora_layers(self):
+    def update_lora_modules(self, hf_target_names: Set[str]):
         # Target module names of customized layers defined in python/sglang/srt/layers
         # e.g., {"qkv_proj", "o_proj"}
         customized_target_names = get_customized_names_from_hf_names(
-            self.hf_target_names, self.base_model
+            hf_target_names, self.base_model
         )

-        # Monkey patch to use the LoRA version layers
-        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
-            i: [] for i in range(self.base_hf_config.num_hidden_layers)
-        }
-
         for module_name, module in self.base_model.named_modules():
             # TODO (lifuhuang): in the future, we should consider generalizing the
             # should_apply_lora function to support mapping by full module name instead
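The comment block in `update_state_from_configs` describes the weight-name bookkeeping as "monotonic": the `(lora_A, lora_B)` name sets only grow as new adapters arrive and are never shrunk on unload. A small sketch of that accumulation, with a guessed normalization table standing in for `get_normalized_lora_weight_names` (the mapping below is illustrative, not sglang's actual implementation):

```python
# Illustrative sketch of the monotonic weight-name union. The normalization
# table is an assumption for demonstration purposes only.
from typing import List, Set, Tuple

# Hypothetical mapping from HF target modules to stacked/normalized weight names.
_NORMALIZED = {
    "q_proj": (["qkv_proj"], ["q_proj"]),
    "k_proj": (["qkv_proj"], ["kv_proj"]),
    "v_proj": (["qkv_proj"], ["kv_proj"]),
    "o_proj": (["o_proj"], ["o_proj"]),
}


def normalized_names(module: str) -> Tuple[List[str], List[str]]:
    # Fall back to the module's own name if it needs no stacking.
    return _NORMALIZED.get(module, ([module], [module]))


lora_weight_names: Tuple[Set[str], Set[str]] = (set(), set())

for module in {"q_proj", "k_proj", "v_proj", "o_proj"}:
    lora_a, lora_b = normalized_names(module)
    lora_weight_names[0].update(lora_a)   # union only; never shrinks on unload
    lora_weight_names[1].update(lora_b)

print(lora_weight_names)
```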
@@ -326,6 +424,7 @@ def convert_to_lora_layers(self):
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
                 layer_id = get_layer_id(module_name)
-                self.lora_modules[layer_id].append(
-                    (module_name, self.set_lora_module(module_name, module))
-                )
+                if module_name not in self.lora_modules[layer_id]:
+                    self.lora_modules[layer_id][module_name] = self.set_lora_module(
+                        module_name, module
+                    )
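The membership check added in this last hunk is what makes `update_lora_modules` safe to run on every `update_state_from_configs` call: a base module is monkey-patched at most once, no matter how many adapters are loaded later. A minimal sketch of that idempotent conversion (the wrapper below is illustrative, not sglang's `get_lora_layer`):

```python
# Sketch of idempotent module conversion: a module is wrapped at most once,
# however often the update pass runs.
from typing import Dict


class LoRAWrapped:
    """Toy stand-in for a LoRA-enabled layer wrapping a base module."""

    def __init__(self, base: object) -> None:
        self.base = base


def convert_once(lora_modules: Dict[str, LoRAWrapped], name: str, base_module: object) -> LoRAWrapped:
    if name not in lora_modules:                # skip modules converted earlier
        lora_modules[name] = LoRAWrapped(base_module)
    return lora_modules[name]


registry: Dict[str, LoRAWrapped] = {}
first = convert_once(registry, "layers.0.qkv_proj", object())
second = convert_once(registry, "layers.0.qkv_proj", object())
assert first is second   # the second pass reuses the existing LoRA layer
```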
