|
| 1 | +# Copyright 2025 BAAI. and/or its affiliates. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +import argparse |
| 5 | +import os |
| 6 | +import torch |
| 7 | + |
| 8 | +import importlib as imp |
| 9 | +import os.path as osp |
| 10 | + |
| 11 | +from pathlib import Path |
| 12 | +from PIL import Image |
| 13 | +from tqdm import tqdm |
| 14 | + |
| 15 | +from src.utils.model_utils import build_emu3p5_vllm |
| 16 | +from src.utils.vllm_generation_utils import generate |
| 17 | +from src.utils.generation_utils import multimodal_decode |
| 18 | +from src.utils.painting_utils import ProtoWriter |
| 19 | +from src.utils.input_utils import build_image |
| 20 | + |
| 21 | + |
def parse_args():
    """Build and evaluate the command-line argument parser.

    Returns:
        argparse.Namespace with ``cfg`` (path to the config module),
        ``tensor_parallel_size``, ``gpu_memory_utilization`` and ``seed``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--cfg", default="", type=str)
    ap.add_argument("--tensor-parallel-size", default=2, type=int)
    ap.add_argument("--gpu-memory-utilization", default=0.7, type=float)
    ap.add_argument("--seed", default=6666, type=int)
    return ap.parse_args()
| 30 | + |
| 31 | + |
def inference(
    cfg,
    model,
    tokenizer,
    vq_model,
):
    """Generate a multimodal result for every prompt in ``cfg.prompts``.

    Results are saved as protobuf files under ``{cfg.save_path}/proto/{name}.pb``;
    prompts whose output file already exists are skipped, making the run resumable.

    Args:
        cfg: Config module providing ``save_path``, ``prompts`` (list of
            ``(name, prompt)`` pairs), ``template``, ``unc_prompt`` and
            ``special_token_ids``.
        model: vLLM model consumed by ``generate``.
        tokenizer: Text tokenizer used to encode prompts and decode outputs.
        vq_model: Visual tokenizer used to encode reference images
            (``build_image``) and decode generated image tokens
            (``multimodal_decode``).
    """
    save_path = cfg.save_path

    os.makedirs(save_path, exist_ok=True)
    os.makedirs(f"{save_path}/proto", exist_ok=True)
    proto_writer = ProtoWriter()

    for name, question in tqdm(cfg.prompts, total=len(cfg.prompts)):
        # Resumability: skip prompts that already produced a result file.
        if osp.exists(f"{save_path}/proto/{name}.pb"):
            print(f"[WARNING] Result already exists, skipping {name}", flush=True)
            continue

        torch.cuda.empty_cache()

        # A prompt entry is either a bare string or a dict with keys
        # "prompt" and "reference_image" (a single path or a list of paths).
        reference_image = None
        if not isinstance(question, str):
            if isinstance(question["reference_image"], list):
                print(f"[INFO] {len(question['reference_image'])} reference images are provided")
                reference_image = [
                    Image.open(img).convert("RGB")
                    for img in question["reference_image"]
                ]
            else:
                print("[INFO] 1 reference image is provided")
                reference_image = Image.open(question["reference_image"]).convert("RGB")
            question = question["prompt"]
        else:
            print("[INFO] No reference image is provided")

        proto_writer.clear()
        proto_writer.extend([["question", question]])
        if reference_image is not None:
            if isinstance(reference_image, list):
                for img in reference_image:
                    proto_writer.extend([["reference_image", img]])
            else:
                proto_writer.extend([["reference_image", reference_image]])

        success = True
        prompt = cfg.template.format(question=question)

        print(f"[INFO] Handling prompt: {prompt}")
        if reference_image is not None:
            # Replace the image placeholder with the VQ-encoded token string(s).
            if isinstance(reference_image, list):
                image_str = "".join(
                    build_image(img, cfg, tokenizer, vq_model)
                    for img in reference_image
                )
            else:
                image_str = build_image(reference_image, cfg, tokenizer, vq_model)
            prompt = prompt.replace("<|IMAGE|>", image_str)
            unc_prompt = cfg.unc_prompt.replace("<|IMAGE|>", image_str)
        else:
            unc_prompt = cfg.unc_prompt

        input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)

        # Make sure the sequence starts with BOS.
        # BUG FIX: the legacy torch.Tensor(...) constructor does not accept a
        # ``dtype`` keyword (it raises TypeError); torch.tensor(...) does.
        if input_ids[0, 0] != cfg.special_token_ids["BOS"]:
            bos = torch.tensor([[cfg.special_token_ids["BOS"]]], dtype=input_ids.dtype)
            input_ids = torch.cat([bos, input_ids], dim=1)

        unconditional_ids = tokenizer.encode(unc_prompt, return_tensors="pt", add_special_tokens=False)

        for result_tokens in generate(cfg, model, tokenizer, input_ids, unconditional_ids):
            try:
                print(f"{result_tokens.shape=}")
                result = tokenizer.decode(result_tokens, skip_special_tokens=False)
                mm_out = multimodal_decode(result, tokenizer, vq_model)
                proto_writer.extend(mm_out)
            except Exception as e:
                # Best-effort: log the failure, abandon this prompt, keep going.
                success = False
                print(f"[ERROR] Failed to generate token sequence: {e}")
                break

        if not success:
            continue

        proto_writer.save(f"{save_path}/proto/{name}.pb")
| 113 | + |
| 114 | + |
def main():
    """Entry point: load the config module and model, then run inference.

    The ``--cfg`` argument is a file path such as ``configs/example.py``; it is
    turned into a module import, e.g. ``import_module(".example", package="configs")``.
    """
    args = parse_args()
    cfg_path = Path(args.cfg)
    # Idiom fix: use str(...) rather than calling __str__() directly.
    cfg_package = str(cfg_path.parent).replace("/", ".")
    cfg = imp.import_module(f".{cfg_path.stem}", package=cfg_package)

    # Normalize prompts to a list of (name, prompt) pairs.
    if isinstance(cfg.prompts, dict):
        cfg.prompts = list(cfg.prompts.items())
    else:
        cfg.prompts = [(f"{idx:03d}", p) for idx, p in enumerate(cfg.prompts)]

    # Drop prompts that already have a saved result (resumable runs).
    cfg.prompts = [(n, p) for n, p in cfg.prompts if not osp.exists(f"{cfg.save_path}/proto/{n}.pb")]
    cfg.num_prompts = len(cfg.prompts)

    model, tokenizer, vq_model = build_emu3p5_vllm(
        cfg.model_path,
        cfg.tokenizer_path,
        cfg.vq_path,
        vq_type=cfg.vq_type,
        vq_device=cfg.vq_device,
        seed=cfg.seed,
        tensor_parallel_size=args.tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
        **getattr(cfg, "diffusion_decoder_kwargs", {}),
    )
    print("[INFO] Model loaded successfully")

    # Map each special token name to its first encoded token id.
    cfg.special_token_ids = {
        k: tokenizer.encode(v)[0] for k, v in cfg.special_tokens.items()
    }

    inference(
        cfg=cfg,
        model=model,
        tokenizer=tokenizer,
        vq_model=vq_model,
    )
    print("[INFO] Inference finished")
| 152 | + |
| 153 | + |
# Script entry point: parse CLI args, load model/config, and run inference.
if __name__ == "__main__":
    main()
0 commit comments