Skip to content

Commit ba2b8e4

Browse files
mickqian authored and thyecust committed
refactor: multimodal data (sgl-project#4754)
1 parent c34051c commit ba2b8e4

36 files changed

+989
-1138
lines changed

benchmark/mmmu/bench_hf.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,38 @@ def eval_mmmu(args):
7272
if suffix:
7373
contents += [{"type": "text", "text": suffix}]
7474
messages = [{"role": "user", "content": contents}]
75-
model_inputs = processor.apply_chat_template(
76-
messages,
77-
tokenize=True,
78-
return_dict=True,
79-
add_generation_prompt=True,
80-
return_tensors="pt",
81-
).to(model.device)
82-
input_len = model_inputs["input_ids"].shape[-1]
83-
generation = model.generate(**model_inputs, generation_config=generation_config)
84-
generation = generation[0][input_len:]
85-
response = processor.decode(generation, skip_special_tokens=True)
75+
try:
76+
model_inputs = processor.tokenizer.apply_chat_template(
77+
messages,
78+
tokenize=True,
79+
return_dict=True,
80+
add_generation_prompt=True,
81+
return_tensors="pt",
82+
).to(model.device)
83+
input_len = model_inputs["input_ids"].shape[-1]
84+
generation = model.generate(
85+
**model_inputs, generation_config=generation_config
86+
)
87+
generation = generation[0][input_len:]
88+
response = processor.decode(generation, skip_special_tokens=True)
89+
except:
90+
contents = []
91+
if prefix:
92+
contents += [prefix]
93+
image = PIL.Image.open(sample["image_path"])
94+
contents += [image]
95+
if suffix:
96+
contents += [suffix]
97+
messages = [{"role": "user", "content": contents}]
98+
response = model.chat(
99+
msgs=messages,
100+
tokenizer=processor.tokenizer,
101+
sampling=False,
102+
max_new_tokens=sampling_params["max_new_tokens"],
103+
use_tts_template=False,
104+
generate_audio=False,
105+
temperature=0.0,
106+
)
86107
print(f"response: {response}")
87108
process_result(response, sample, answer_dict, out_samples)
88109

benchmark/mmmu/eval_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,8 @@ def calculate_ins_level_acc(results: Dict):
442442

443443

444444
def process_result(response, sample, answer_dict, out_samples):
445+
if response is None:
446+
return
445447
if sample["question_type"] == "multiple-choice":
446448
pred_ans = parse_multi_choice_response(
447449
response, sample["all_choices"], sample["index2ans"]

0 commit comments

Comments
 (0)