
Commit 0a44a1f

lengrongfu authored and Mu Huai committed
[Model] use AutoWeightsLoader for olmoe,opt,orion,persimmon,phi3_small (vllm-project#16548)
Signed-off-by: rongfu.leng <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent 21f8f85 commit 0a44a1f
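
The refactor is the same in each of the five files: the hand-written per-tensor loading loop moves down into the inner `*Model` class, and the top-level `*ForCausalLM.load_weights` becomes a thin delegation through `AutoWeightsLoader`. A minimal sketch of the resulting wrapper shape is shown below; `MyForCausalLM` is a hypothetical class used only for illustration, and only the `AutoWeightsLoader` usage mirrors the diffs that follow.

```python
# Minimal sketch of the pattern this commit applies. MyForCausalLM is a
# hypothetical wrapper; only the AutoWeightsLoader usage mirrors the diffs.
from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.models.utils import AutoWeightsLoader


class MyForCausalLM(nn.Module):

    def __init__(self, model: nn.Module):
        super().__init__()
        # The inner model keeps its detailed, stacked-parameter-aware
        # load_weights; the wrapper only decides what to skip and delegates.
        self.model = model

    def load_weights(self,
                     weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["rotary_emb.inv_freq"],  # buffers not worth loading
        )
        return loader.load_weights(weights)
```

The return value stays a `Set[str]` of the checkpoint names that were actually consumed, so callers that check for missing weights are unaffected.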

File tree: 5 files changed, +216 -193 lines changed


vllm/model_executor/models/olmoe.py

Lines changed: 58 additions & 54 deletions
@@ -39,7 +39,7 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -255,7 +255,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config

         self.vocab_size = config.vocab_size
-
+        self.config = config
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
@@ -308,56 +308,6 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

-
-class OlmoeForCausalLM(nn.Module, SupportsPP):
-
-    fall_back_to_pt_during_load = False
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = OlmoeModel(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      quant_config=quant_config)
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
-
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -380,8 +330,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
@@ -453,3 +401,59 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class OlmoeForCausalLM(nn.Module, SupportsPP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = OlmoeModel(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=["rotary_emb.inv_freq"],
+        )
+        return loader.load_weights(weights)
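
For readers unfamiliar with the helper, the behaviour the olmoe change relies on can be pictured with a toy loader: checkpoint names matching a skip prefix are dropped, and a submodule that defines its own `load_weights` (here `OlmoeModel`, which keeps the expert and stacked-parameter logic) receives its share of the weights. The sketch below is a simplified illustration under those assumptions, not vLLM's actual `AutoWeightsLoader` implementation.

```python
# Toy illustration only -- NOT vLLM's AutoWeightsLoader. It mimics the two
# behaviours the olmoe diff relies on: skip_prefixes filtering and delegation
# to submodules that define their own load_weights (e.g. OlmoeModel).
from typing import Dict, Iterable, List, Set, Tuple

import torch
import torch.nn as nn


def toy_auto_load(module: nn.Module,
                  weights: Iterable[Tuple[str, torch.Tensor]],
                  skip_prefixes: List[str]) -> Set[str]:
    loaded: Set[str] = set()
    params = dict(module.named_parameters())
    # Bucket weights by their first name component ("model", "lm_head", ...).
    buckets: Dict[str, List[Tuple[str, torch.Tensor]]] = {}
    for name, tensor in weights:
        if any(name.startswith(p) for p in skip_prefixes):
            continue  # e.g. "rotary_emb.inv_freq"
        head, _, rest = name.partition(".")
        buckets.setdefault(head, []).append((rest, tensor))
    for head, items in buckets.items():
        child = getattr(module, head, None)
        if isinstance(child, nn.Module) and hasattr(child, "load_weights"):
            # Delegate; the child returns the relative names it consumed.
            loaded.update(f"{head}.{n}" for n in child.load_weights(items))
        else:
            for rest, tensor in items:
                full = f"{head}.{rest}" if rest else head
                params[full].data.copy_(tensor)
                loaded.add(full)
    return loaded
```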

vllm/model_executor/models/opt.py

Lines changed: 48 additions & 40 deletions
@@ -43,7 +43,7 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -313,13 +313,54 @@ def forward(
                             intermediate_tensors,
                             inputs_embeds=inputs_embeds)

+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class OPTForCausalLM(nn.Module, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"]
     }

+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+        "decoder.": "model.decoder.",
+    })
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -371,42 +412,9 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "lm_head.weight" in name and self.config.tie_word_embeddings:
-                continue
-            if name.startswith("decoder."):
-                name = "model." + name
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head.weight"]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
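
For OPT, the class-level `hf_to_vllm_mapper` takes over the rename the old loop did inline (`name = "model." + name` for checkpoint keys starting with `decoder.`), and the tied `lm_head.weight` is now skipped declaratively instead of with an in-loop `continue`. The sketch below shows the same prefix rewrite as a stand-alone helper; `remap_prefix` is illustrative only, and the real mapping is applied by passing `mapper=self.hf_to_vllm_mapper` into `loader.load_weights`.

```python
# Illustrative stand-in for the prefix rewrite done by hf_to_vllm_mapper;
# remap_prefix is not a vLLM function.
from typing import Dict, Iterable, Iterator, Tuple

import torch

ORIG_TO_NEW_PREFIX: Dict[str, str] = {"decoder.": "model.decoder."}


def remap_prefix(weights: Iterable[Tuple[str, torch.Tensor]]
                 ) -> Iterator[Tuple[str, torch.Tensor]]:
    for name, tensor in weights:
        for old, new in ORIG_TO_NEW_PREFIX.items():
            if name.startswith(old):
                name = new + name[len(old):]
                break
        yield name, tensor


# Example: "decoder.embed_tokens.weight" -> "model.decoder.embed_tokens.weight"
```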

vllm/model_executor/models/orion.py

Lines changed: 49 additions & 42 deletions
@@ -30,7 +30,7 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -260,6 +260,45 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states

+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class OrionForCausalLM(nn.Module, SupportsPP):

@@ -314,46 +353,14 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=([
+                "rotary_emb.inv_freq",
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+                "rotary_emb.cos_cached",
+                "rotary_emb.sin_cached"
+            ]),
+        )
+        return loader.load_weights(weights)
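
The stacked-parameter handling for Orion is unchanged by this commit; it only moved from `OrionForCausalLM` into `OrionModel`. As a reminder of what that mapping does, the sketch below resolves a checkpoint name to its fused parameter and shard; `fused_target` is illustrative, not a vLLM helper.

```python
# Illustrative lookup mirroring OrionModel's stacked_params_mapping;
# fused_target is not a vLLM helper.
from typing import Optional, Tuple, Union

STACKED_PARAMS_MAPPING = [
    # (fused param_name, checkpoint shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def fused_target(name: str) -> Tuple[str, Optional[Union[str, int]]]:
    """Return (parameter name, shard id); shard id is None if unstacked."""
    for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None


# fused_target("model.layers.0.self_attn.q_proj.weight")
#   -> ("model.layers.0.self_attn.qkv_proj.weight", "q")
# fused_target("model.layers.0.mlp.up_proj.weight")
#   -> ("model.layers.0.mlp.gate_up_proj.weight", 1)
```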

0 commit comments

Comments
 (0)