Skip to content

Commit 1e40e5b

Browse files
authored
[vLLM] add vLLM offline inference (#47)
* [vLLM] add vLLM offline inference * update x2i tokenizer_path * add copyright * refactor requirements * rm requirements.txt * update readme
1 parent 5d6f548 commit 1e40e5b

29 files changed

+3782
-9
lines changed

README.md

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ Emu3.5 Team, BAAI
6464
### Environment Setup
6565

6666
```bash
67-
# Python 3.10 or higher is required.
67+
# Requires Python 3.12 or higher.
6868
git clone https://github.com/baaivision/Emu3.5
6969
cd Emu3.5
70-
pip install -r requirements.txt
70+
pip install -r requirements/transformers.txt
7171
pip install flash_attn==2.8.3 --no-build-isolation
7272
```
7373
### Configuration
@@ -112,6 +112,44 @@ CUDA_VISIBLE_DEVICES=0,1 python inference.py --cfg configs/example_config_visual
112112

113113
Protobuf outputs are written to `outputs/<exp_name>/proto/`. For better throughput, we recommend ≥2 GPUs.
114114

115+
116+
### Run Inference with vLLM
117+
118+
#### vLLM Environment Setup
119+
120+
1. [Optional, recommended] Use a virtual environment
121+
```bash
122+
conda create -n Emu3p5 python=3.12
123+
```
124+
125+
2. Install vLLM and apply the patch files.
126+
```bash
127+
# Requires Python 3.12 or higher.
128+
# Recommended: CUDA 12.8.
129+
pip install -r requirements/vllm.txt
130+
pip install flash_attn==2.8.3 --no-build-isolation
131+
132+
cd Emu3.5
133+
python src/patch/apply.py
134+
```
135+
136+
#### Example Configurations by Task
137+
138+
```bash
139+
# 🖼️ Text-to-Image (T2I) task
140+
CUDA_VISIBLE_DEVICES=0,1 python inference_vllm.py --cfg configs/example_config_t2i.py
141+
142+
# 🔄 Any-to-Image (X2I) task
143+
CUDA_VISIBLE_DEVICES=0,1 python inference_vllm.py --cfg configs/example_config_x2i.py
144+
145+
# 🎯 Visual Guidance task
146+
CUDA_VISIBLE_DEVICES=0,1 python inference_vllm.py --cfg configs/example_config_visual_guidance.py
147+
148+
# 📖 Visual Narrative task
149+
CUDA_VISIBLE_DEVICES=0,1 python inference_vllm.py --cfg configs/example_config_visual_narrative.py
150+
```
151+
152+
115153
### Visualize Protobuf Outputs
116154

117155
To visualize generated protobuf files (--video: Generate video visualizations for interleaved output):

configs/example_config_x2i.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
model_path = "path_to_emu3.5_model" # download from hf
99
vq_path = "path_to_vq_model" # download from hf
1010

11-
tokenizer_path = "path_to_tokenizer"
11+
tokenizer_path = "./src/tokenizer_emu3_ibq"
1212
vq_type = "ibq"
1313

1414
task_type = "x2i"

inference_vllm.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright 2025 BAAI. and/or its affiliates.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import argparse
5+
import os
6+
import torch
7+
8+
import importlib as imp
9+
import os.path as osp
10+
11+
from pathlib import Path
12+
from PIL import Image
13+
from tqdm import tqdm
14+
15+
from src.utils.model_utils import build_emu3p5_vllm
16+
from src.utils.vllm_generation_utils import generate
17+
from src.utils.generation_utils import multimodal_decode
18+
from src.utils.painting_utils import ProtoWriter
19+
from src.utils.input_utils import build_image
20+
21+
22+
def parse_args():
    """Parse command-line options for vLLM offline inference.

    Returns the argparse namespace with the config path, tensor-parallel
    size, GPU memory utilization fraction, and random seed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg", type=str, default="")
    parser.add_argument("--tensor-parallel-size", type=int, default=2)
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.7)
    parser.add_argument("--seed", type=int, default=6666)
    return parser.parse_args()
30+
31+
32+
def inference(
    cfg,
    model,
    tokenizer,
    vq_model,
):
    """Run offline generation for every prompt in ``cfg.prompts``.

    For each ``(name, question)`` pair: optionally load and tokenize
    reference image(s), build the conditional and unconditional token
    sequences, stream results from ``generate``, decode them with
    ``multimodal_decode``, and save a protobuf to
    ``<cfg.save_path>/proto/<name>.pb``. Prompts whose protobuf already
    exists are skipped.

    Args:
        cfg: config module providing ``save_path``, ``prompts``,
            ``template``, ``unc_prompt``, and ``special_token_ids``.
        model: the vLLM model used by ``generate``.
        tokenizer: text tokenizer (encode/decode).
        vq_model: visual tokenizer used for image encode/decode.
    """
    save_path = cfg.save_path

    os.makedirs(save_path, exist_ok=True)
    os.makedirs(f"{save_path}/proto", exist_ok=True)
    proto_writer = ProtoWriter()

    for name, question in tqdm(cfg.prompts, total=len(cfg.prompts)):
        if osp.exists(f"{save_path}/proto/{name}.pb"):
            print(f"[WARNING] Result already exists, skipping {name}", flush=True)
            continue

        torch.cuda.empty_cache()

        # A non-str question is a dict carrying "prompt" plus one or more
        # paths under "reference_image".
        reference_image = None
        if not isinstance(question, str):
            if isinstance(question["reference_image"], list):
                print(f"[INFO] {len(question['reference_image'])} reference images are provided")
                reference_image = [
                    Image.open(img).convert("RGB")
                    for img in question["reference_image"]
                ]
            else:
                print("[INFO] 1 reference image is provided")
                reference_image = Image.open(question["reference_image"]).convert("RGB")
            question = question["prompt"]
        else:
            print("[INFO] No reference image is provided")

        proto_writer.clear()
        proto_writer.extend([["question", question]])
        if reference_image is not None:
            if isinstance(reference_image, list):
                for img in reference_image:
                    proto_writer.extend([["reference_image", img]])
            else:
                proto_writer.extend([["reference_image", reference_image]])

        success = True
        prompt = cfg.template.format(question=question)

        print(f"[INFO] Handling prompt: {prompt}")
        if reference_image is not None:
            # Tokenize the reference image(s) and splice them into both the
            # conditional and unconditional prompts.
            if isinstance(reference_image, list):
                image_str = "".join(
                    build_image(img, cfg, tokenizer, vq_model)
                    for img in reference_image
                )
            else:
                image_str = build_image(reference_image, cfg, tokenizer, vq_model)
            prompt = prompt.replace("<|IMAGE|>", image_str)
            unc_prompt = cfg.unc_prompt.replace("<|IMAGE|>", image_str)
        else:
            unc_prompt = cfg.unc_prompt

        input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)

        # Prepend BOS if the tokenizer did not emit it. BUGFIX: the original
        # used torch.Tensor(..., dtype=...), which raises TypeError because
        # torch.Tensor accepts no dtype kwarg; torch.tensor does.
        if input_ids[0, 0] != cfg.special_token_ids["BOS"]:
            bos = torch.tensor([[cfg.special_token_ids["BOS"]]], dtype=input_ids.dtype)
            input_ids = torch.cat([bos, input_ids], dim=1)

        unconditional_ids = tokenizer.encode(unc_prompt, return_tensors="pt", add_special_tokens=False)

        for result_tokens in generate(cfg, model, tokenizer, input_ids, unconditional_ids):
            try:
                print(f"{result_tokens.shape=}")
                result = tokenizer.decode(result_tokens, skip_special_tokens=False)
                mm_out = multimodal_decode(result, tokenizer, vq_model)
                proto_writer.extend(mm_out)
            except Exception as e:
                # Best-effort: abandon this prompt but keep processing others.
                success = False
                print(f"[ERROR] Failed to generate token sequence: {e}")
                break

        if not success:
            continue

        proto_writer.save(f"{save_path}/proto/{name}.pb")
113+
114+
115+
def main():
    """Entry point: import the config module, build the vLLM stack, run inference."""
    args = parse_args()
    cfg_path = Path(args.cfg)
    package = str(cfg_path.parent).replace("/", ".")
    cfg = imp.import_module(f".{cfg_path.stem}", package=package)

    # Normalize prompts into (name, prompt) pairs.
    if isinstance(cfg.prompts, dict):
        cfg.prompts = list(cfg.prompts.items())
    else:
        cfg.prompts = [(f"{i:03d}", p) for i, p in enumerate(cfg.prompts)]

    # Drop prompts whose protobuf output already exists.
    cfg.prompts = [
        (n, p)
        for n, p in cfg.prompts
        if not osp.exists(f"{cfg.save_path}/proto/{n}.pb")
    ]
    cfg.num_prompts = len(cfg.prompts)

    model, tokenizer, vq_model = build_emu3p5_vllm(
        cfg.model_path,
        cfg.tokenizer_path,
        cfg.vq_path,
        vq_type=cfg.vq_type,
        vq_device=cfg.vq_device,
        seed=cfg.seed,
        tensor_parallel_size=args.tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
        **getattr(cfg, "diffusion_decoder_kwargs", {}),
    )
    print("[INFO] Model loaded successfully")

    # Map each named special token to its first encoded id.
    cfg.special_token_ids = {
        k: tokenizer.encode(v)[0] for k, v in cfg.special_tokens.items()
    }

    inference(
        cfg=cfg,
        model=model,
        tokenizer=tokenizer,
        vq_model=vq_model,
    )
    print("[INFO] Inference finished")


if __name__ == "__main__":
    main()
Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
torch>=2.6.0
2-
torchvision>=0.15.0
3-
torchaudio>=2.0.0
4-
transformers==4.48.2
5-
accelerate>=0.20.0
61
pillow>=9.0.0
72
numpy>=1.21.0
83
tqdm>=4.64.0
@@ -11,4 +6,4 @@ tiktoken>=0.12.0
116
imageio==2.37.0
127
imageio-ffmpeg==0.6.0
138
omegaconf==2.3.0
14-
gradio==5.49.1
9+
gradio==5.49.1

requirements/transformers.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
-r common.txt
2+
3+
torch>=2.6.0
4+
torchvision>=0.15.0
5+
torchaudio>=2.0.0
6+
transformers==4.48.2
7+
accelerate>=0.20.0

requirements/vllm.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
-r common.txt
2+
3+
vllm==0.11.0; python_version > '3.11' # torch==2.8.0

src/patch/apply.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2025 BAAI. and/or its affiliates.
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import argparse
6+
import os
7+
import subprocess
8+
import sys
9+
import shutil
10+
from pathlib import Path
11+
12+
REQUIRED_VLLM_VERSION = "0.11.0"
13+
14+
15+
def get_vllm_site():
    """Locate the installed vLLM package directory, enforcing the pinned version.

    Exits with code 1 when vLLM is not importable and code 10 when the
    installed version differs from ``REQUIRED_VLLM_VERSION``.
    """
    try:
        import vllm
    except ImportError:
        print("[ERROR] vllm is not installed. Please run: pip install vllm==0.11.0")
        sys.exit(1)

    if getattr(vllm, "__version__", None) != REQUIRED_VLLM_VERSION:
        print(f"[FATAL] vLLM version must be {REQUIRED_VLLM_VERSION}, "
              f"but found {vllm.__version__}. Aborting.")
        sys.exit(10)

    print(f"[INFO] vLLM version verified: {vllm.__version__}")
    return Path(vllm.__file__).parent
27+
28+
29+
def run_patch(patch_file, site_dir, dry_run=False):
    """Apply (or dry-run) a unified diff against ``site_dir`` using patch(1).

    The patch is fed on stdin and applied with ``-p2`` so headers of the
    form ``a/vllm/...`` resolve relative to the vLLM package directory.

    Args:
        patch_file: path to the .patch file.
        site_dir: directory to run ``patch`` in (the vLLM site-packages dir).
        dry_run: when True, pass ``--dry-run`` so no files are modified.

    Returns:
        (ok, stdout, stderr) where ``ok`` is True iff patch exited with 0.
    """
    cmd = ["patch", "-p2"]
    if dry_run:
        cmd.insert(1, "--dry-run")
    with open(patch_file, "r") as f:
        result = subprocess.run(
            cmd,
            cwd=str(site_dir),
            stdin=f,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )
    ok = result.returncode == 0
    if not ok:
        print(f"[ERROR] Patch failed: {patch_file}")
        print(result.stdout)
        print(result.stderr)
    elif not dry_run:
        # BUGFIX: only report success after checking the exit code; the
        # original printed "Applied patch" even when patch failed.
        print(f"[INFO] Applied patch: {patch_file}")
    return ok, result.stdout, result.stderr
49+
50+
51+
def extract_patch_targets(patch_file):
    """Collect the vLLM-relative file paths touched by a unified diff.

    Scans ``--- a/...`` / ``+++ b/...`` headers, skips /dev/null entries
    (file creation/deletion), and strips the leading ``a/vllm/`` or
    ``b/vllm/`` prefix so results are relative to the vLLM package dir.

    BUGFIX: the original sliced ``path[idx:-1]`` to drop the trailing
    newline, which chopped the last character of the filename whenever the
    header carried a tab-separated timestamp (the tab split had already
    removed the newline). Stripping whitespace handles both header shapes.

    Returns:
        A sorted, de-duplicated list of relative paths.
    """
    prefix_len = len("a/vllm/")  # same length as "b/vllm/"
    targets = set()
    with open(patch_file, "r") as f:
        for line in f:
            if line.startswith("--- a/") or line.startswith("+++ b/"):
                # Header may be "--- a/path" or "--- a/path\t<timestamp>".
                path = line.split("\t")[0].split(" ", 1)[-1].strip()
                # "/dev/null" is git's spelling for created/deleted files.
                if path in ("a/dev/null", "b/dev/null", "/dev/null"):
                    continue
                targets.add(path[prefix_len:])
    return sorted(targets)
61+
62+
63+
def backup_files(targets, site_dir, backup_root):
    """Copy each existing target file from site_dir into backup_root.

    The relative path of every file is mirrored under ``backup_root``;
    targets that do not exist in ``site_dir`` are silently skipped.
    """
    for rel in targets:
        source = site_dir / rel
        if not source.exists():
            continue
        destination = backup_root / rel
        destination.parent.mkdir(parents=True, exist_ok=True)
        print(f"[INFO] Backing up {source} to {destination}")
        shutil.copy2(source, destination)
71+
72+
73+
def restore_backup(backup_root, site_dir):
    """Copy every backed-up file back to its original location under site_dir.

    A missing ``backup_root`` is reported with a warning and treated as a
    no-op.
    """
    if not backup_root.exists():
        print("[WARN] No backup directory found.")
        return
    for root, _, files in os.walk(backup_root):
        for filename in files:
            backed_up = Path(root) / filename
            relative = backed_up.relative_to(backup_root)
            original = site_dir / relative
            original.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(backed_up, original)
    print("[INFO] Restore completed.")
85+
86+
87+
def main():
    """Verify vLLM, back up the files each patch touches, then apply all patches.

    Every patch is dry-run first; if any real application fails, the backup
    is restored and the script exits with a non-zero status.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--patch-dir", type=str, default="./third_party/vllm/",
                        help="Directory containing .patch files")
    args = parser.parse_args()
    patch_dir = Path(args.patch_dir)

    if not (patch_dir.exists() and patch_dir.is_dir()):
        print(f"[ERROR] patch-dir does not exist: {patch_dir}")
        sys.exit(1)

    site_dir = get_vllm_site()
    print(f"[INFO] vLLM site-packages: {site_dir}")

    patch_files = sorted(patch_dir.rglob("*.patch"))
    if not patch_files:
        print("[ERROR] No patch files found.")
        sys.exit(1)

    print(f"[INFO] Found {len(patch_files)} patch(es).")

    # Start from a fresh backup directory next to the vLLM package.
    backup_root = site_dir.parent / "vllm_patch_backup"
    if backup_root.exists():
        print("[WARN] Removing previous backup...")
        shutil.rmtree(backup_root)
    backup_root.mkdir(parents=True)

    print("[INFO] Running dry-run...")
    for patch_file in patch_files:
        ok, _, err = run_patch(patch_file, site_dir, dry_run=True)
        if not ok:
            print(f"[FATAL] Dry-run failed for patch: {patch_file}\n{err}")
            sys.exit(2)
    print("[INFO] Dry-run passed.")

    print("[INFO] Backing up modified files...")
    for patch_file in patch_files:
        backup_files(extract_patch_targets(patch_file), site_dir, backup_root)

    print("[INFO] Applying patches...")
    for patch_file in patch_files:
        ok, _, err = run_patch(patch_file, site_dir, dry_run=False)
        if not ok:
            print(f"[ERROR] Failed to apply patch: {patch_file}\n{err}")
            print("[INFO] Restoring from backup...")
            restore_backup(backup_root, site_dir)
            sys.exit(3)
    print("[SUCCESS] All patches applied successfully.")
    print(f"[INFO] Backup stored at: {backup_root}")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)