|
18 | 18 | from typing import List, Optional, Tuple, Union
|
19 | 19 |
|
20 | 20 | import torch
|
| 21 | +import xgrammar |
21 | 22 | from xgrammar import (
|
22 | 23 | CompiledGrammar,
|
23 | 24 | GrammarCompiler,
|
@@ -58,17 +59,11 @@ def __init__(
|
58 | 59 | self.override_stop_tokens = override_stop_tokens
|
59 | 60 | self.finished = False
|
60 | 61 |
|
61 |
| - # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the |
62 |
| - # class init site to avoid re-initializing CUDA in forked subprocess. |
63 |
| - from xgrammar.kernels import apply_token_bitmask_inplace_kernels |
64 |
| - |
65 |
| - self.use_token_bitmask_triton = get_bool_env_var( |
66 |
| - "SGLANG_TOKEN_BITMASK_TRITON", "false" |
67 |
| - ) |
68 |
| - self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get( |
69 |
| - "cuda", None |
| 62 | + from xgrammar.kernels.apply_token_bitmask_inplace_cpu import ( |
| 63 | + apply_token_bitmask_inplace_cpu, |
70 | 64 | )
|
71 |
| - self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None) |
| 65 | + |
| 66 | + self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu |
72 | 67 |
|
73 | 68 | def accept_token(self, token: int):
|
74 | 69 | assert self.matcher.accept_token(token)
|
@@ -113,15 +108,12 @@ def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
|
113 | 108 | return vocab_mask.to(device, non_blocking=True)
|
114 | 109 |
|
115 | 110 | def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
|
116 |
| - if ( |
117 |
| - not self.use_token_bitmask_triton |
118 |
| - and logits.device.type == "cuda" |
119 |
| - and self.apply_vocab_mask_cuda |
120 |
| - ): |
121 |
| - return self.apply_vocab_mask_cuda(logits, vocab_mask) |
122 |
| - if logits.device.type == "cpu" and self.apply_vocab_mask_cpu: |
123 |
| - return self.apply_vocab_mask_cpu(logits, vocab_mask) |
124 |
| - apply_token_bitmask_inplace_triton(logits, vocab_mask) |
| 111 | + if logits.device.type == "cuda": |
| 112 | + apply_token_bitmask_inplace_triton(logits, vocab_mask) |
| 113 | + elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu: |
| 114 | + self.apply_vocab_mask_cpu(logits, vocab_mask) |
| 115 | + else: |
| 116 | + raise RuntimeError(f"Unsupported device: {logits.device.type}") |
125 | 117 |
|
126 | 118 | def copy(self):
|
127 | 119 | matcher = GrammarMatcher(
|
|
0 commit comments