Description
Checklist
- 1. I have searched related issues but cannot get the expected help.
- 2. The bug has not been fixed in the latest version.
- 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- 4. If the issue you raised is not a bug but a question, please open a discussion at https://github.com/sgl-project/sglang/discussions/new/choose instead; otherwise, it will be closed.
- 5. Please use English, otherwise it will be closed.
Describe the bug
MLA models such as DeepSeek-V2 and MiniCPM3 cannot use --enable-torch-compile.
However, simply adding two lines of code to sglang works around it and improves single-batch speed:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
I request that this error suppression be added until the underlying torch bug is fixed. A sketch of how to apply it without patching sglang is shown below.
For example, with this workaround, MiniCPM3 single-batch decoding throughput rises from 66 token/s to 103 token/s without affecting output quality.
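For reference, here is a minimal sketch of how the workaround can be applied without modifying the sglang source. It assumes that setting the flag in the benchmark process before the model is compiled is sufficient (bench_latency runs the model in-process for tp=1); the wrapper file name bench_with_suppression.py is hypothetical, and where the suppression should ultimately live inside sglang is up to the maintainers.
# bench_with_suppression.py -- minimal sketch of the proposed workaround (hypothetical wrapper)
import runpy
import torch._dynamo
# Fall back to eager execution for any subgraph whose Inductor codegen fails,
# instead of aborting with BackendCompilerFailed as in the log below.
torch._dynamo.config.suppress_errors = True
# Run the benchmark module in this process so the flag is already set when
# --enable-torch-compile kicks in; CLI arguments are passed through unchanged.
runpy.run_module("sglang.bench_latency", run_name="__main__", alter_sys=True)
This wrapper is invoked the same way as the reproduction command, e.g. python3 bench_with_suppression.py --model openbmb/MiniCPM3-4B --trust-remote-code --input-len 1024 --output-len 1024 --batch 1 --enable-torch-compile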
Reproduction
To reproduce:
python3 -m sglang.bench_latency --model openbmb/MiniCPM3-4B --trust-remote-code --input-len 1024 --output-len 1024 --batch 1 --enable-torch-compile
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] Error in codegen for ComputedBuffer(name='buf13', layout=FixedLayout('cuda', torch.bfloat16, size=[s0, 40, 128], stride=[5120, 128, 1]), data=Pointwise(
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] 'cuda',
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] torch.bfloat16,
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] def inner_fn(index):
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] i0, i1, i2 = index
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp0 = ops.index_expr(i2, torch.int64)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp1 = ops.index_expr(96, torch.int64)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp2 = tmp0 < tmp1
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp3 = ops.index_expr(i2, torch.int64)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp4 = ops.index_expr(64, torch.int64)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp5 = tmp3 >= tmp4
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp6 = ops.load(buf12, -64 + i2 + 32 * i0)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp7 = ops.masked(tmp5, tmp6, 0.0)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp8 = ops.index_expr(i2, torch.int64)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp9 = ops.index_expr(64, torch.int64)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp10 = tmp8 < tmp9
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp11 = ops.load(buf7, i2 + 128 * i1 + 5120 * i0)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp12 = ops.masked(tmp10, tmp11, 0.0)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp13 = ops.load(buf11, i2 + 96 * i1 + 3840 * i0)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp14 = ops.where(tmp10, tmp12, tmp13)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp15 = ops.where(tmp5, tmp7, tmp14)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] tmp16 = ops.masked(tmp2, tmp15, 0.0)
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] return tmp16
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] ,
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] ranges=[s0, 40, 128],
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] origin_node=constant_pad_nd_1,
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] origins={slice_scatter_default_2, copy_2, slice_scatter_defau...
[rank0]:C0914 16:10:28.525000 47272548425472 torch/_inductor/scheduler.py:836] [5/62_1] ))
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT forward /mypath/sglang/python/sglang/srt/models/minicpm3.py line 188
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] due to:
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] result = self._inner_convert(
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return _compile(
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return StrobelightCompileTimeProfiler.profile_compile_time(
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return func(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return func(*args, **kwds)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] r = func(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] out_code = transform_code_object(code, transform)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] transformations(instructions, code_options)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return fn(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] tracer.run()
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] super().run()
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] while self.step():
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] self.dispatch_table[inst.opcode](self, inst)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 497, in wrapper
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return handle_graph_break(self, inst, speculation.reason)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 566, in handle_graph_break
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] self.output.compile_subgraph(self, reason=reason)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1123, in compile_subgraph
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return func(*args, **kwds)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1318, in compile_and_call_fx_graph
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] compiled_fn = self.call_user_compiler(gm)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] r = func(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1409, in call_user_compiler
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1390, in call_user_compiler
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] compiled_fn = compiler_fn(gm, self.example_inputs())
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] compiled_gm = compiler_fn(gm, example_inputs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/__init__.py", line 1951, in __call__
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return compile_fx(model_, inputs_, config_patches=self.config)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return func(*args, **kwds)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1261, in compile_fx
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return compile_fx(
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return func(*args, **kwds)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1505, in compile_fx
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return aot_autograd(
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 69, in __call__
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 954, in aot_module_simplified
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] compiled_fn, _ = create_aot_dispatcher_function(
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] r = func(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 687, in create_aot_dispatcher_function
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] compiled_fn, fw_metadata = compiler_fn(
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 168, in aot_dispatch_base
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] compiled_fw = compiler(fw_module, updated_flat_args)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] r = func(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1410, in fw_compiler_base
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return inner_compile(
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return func(*args, **kwds)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 84, in debug_wrapper
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] inner_compiled_fn = compiler_fn(gm, example_inputs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/debug.py", line 304, in inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return fn(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return func(*args, **kwds)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return func(*args, **kwds)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] r = func(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 517, in compile_fx_inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] compiled_graph = FxGraphCache.load(
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1044, in load
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] compiled_graph = compile_fx_fn(gm, example_inputs, **fx_kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return func(*args, **kwds)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 831, in fx_codegen_and_compile
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] compiled_fn = graph.compile_to_fn()
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1749, in compile_to_fn
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return self.compile_to_module().call
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] r = func(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1678, in compile_to_module
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1634, in codegen
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] self.scheduler = Scheduler(self.buffers)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] r = func(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 1409, in __init__
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] self.fuse_nodes()
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 1825, in fuse_nodes
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] self.fuse_nodes_once()
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 2067, in fuse_nodes_once
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] if not self.speedup_by_fusion(node1, node2):
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 1996, in speedup_by_fusion
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] ms_fused, _ = self.benchmark_fused_nodes(node_list_fused)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 1848, in benchmark_fused_nodes
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return backend.benchmark_fused_nodes(nodes)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py", line 81, in benchmark_fused_nodes
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return self._triton_scheduling.benchmark_fused_nodes(nodes)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return func(*args, **kwds)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/triton.py", line 2515, in benchmark_fused_nodes
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/simd.py", line 1686, in generate_kernel_code_from_nodes
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] src_code = self.codegen_template(
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/simd.py", line 1420, in codegen_template
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 834, in codegen
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] self._body(*index_vars)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/ir.py", line 8038, in __call__
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] result = self.root_block()
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/ir.py", line 8200, in __call__
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return InterpreterShim(graph, submodules).run(V.get_ops_handler())
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/ir.py", line 7941, in run
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return super().run(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/fx/interpreter.py", line 146, in run
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] self.env[node] = self.run_node(node)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/ir.py", line 7937, in run_node
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return super().run_node(n)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/fx/interpreter.py", line 203, in run_node
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return getattr(self, n.op)(n.target, args, kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/fx/interpreter.py", line 320, in call_module
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return submod(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/ir.py", line 8116, in shim
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return V.ops.masked(mask, subblock, other)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/common.py", line 1676, in inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/triton.py", line 923, in masked
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] result = body()
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/ir.py", line 8200, in __call__
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return InterpreterShim(graph, submodules).run(V.get_ops_handler())
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/ir.py", line 7941, in run
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return super().run(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/fx/interpreter.py", line 146, in run
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] self.env[node] = self.run_node(node)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/ir.py", line 7937, in run_node
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return super().run_node(n)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/fx/interpreter.py", line 203, in run_node
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return getattr(self, n.op)(n.target, args, kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/fx/interpreter.py", line 320, in call_module
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return submod(*args, **kwargs)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/ir.py", line 8116, in shim
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] return V.ops.masked(mask, subblock, other)
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/common.py", line 1676, in inner
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] File "/my_conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/triton.py", line 926, in masked
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] if result.bounds.is_bool:
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] AttributeError: 'str' object has no attribute 'bounds'
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009]
[rank0]:W0914 16:10:28.557000 47272548425472 torch/_dynamo/convert_frame.py:1009] Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
Environment
Python: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA A800-SXM4-80GB
GPU 0,1,2,3,4,5,6,7 Compute Capability: 8.0
CUDA_HOME: /home/test/test01/cuda-12.1
NVCC: Cuda compilation tools, release 12.1, V12.1.66
CUDA Driver Version: 535.183.06
PyTorch: 2.4.0+cu121
sglang: 0.3.0
flashinfer: 0.1.6
triton: 3.0.0
transformers: 4.43.3
requests: 2.32.3
tqdm: 4.66.4
numpy: 1.26.3
aiohttp: 3.10.0
fastapi: 0.111.1
hf_transfer: 0.1.8
huggingface_hub: 0.24.3
interegular: 0.3.3
packaging: 24.1
PIL: 10.2.0
psutil: 6.0.0
pydantic: 2.8.2
uvicorn: 0.30.3
uvloop: 0.19.0
zmq: 26.0.3
vllm: 0.5.5
multipart: 0.0.9
openai: 1.44.0
anthropic: 0.34.2
NVIDIA Topology:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 NIC4 NIC5 NIC6 NIC7 NIC8 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV8 NV8 NV8 NV8 NV8 NV8 NV8 PXB NODE NODE NODE SYS SYS SYS SYS SYS 0-51 0 N/A
GPU1 NV8 X NV8 NV8 NV8 NV8 NV8 NV8 NODE PXB NODE NODE SYS SYS SYS SYS SYS 0-51 0 N/A
GPU2 NV8 NV8 X NV8 NV8 NV8 NV8 NV8 NODE NODE PXB NODE SYS SYS SYS SYS SYS 0-51 0 N/A
GPU3 NV8 NV8 NV8 X NV8 NV8 NV8 NV8 NODE NODE NODE PXB SYS SYS SYS SYS SYS 0-51 0 N/A
GPU4 NV8 NV8 NV8 NV8 X NV8 NV8 NV8 SYS SYS SYS SYS PXB NODE NODE NODE NODE 52-103 1 N/A
GPU5 NV8 NV8 NV8 NV8 NV8 X NV8 NV8 SYS SYS SYS SYS NODE NODE PXB NODE NODE 52-103 1 N/A
GPU6 NV8 NV8 NV8 NV8 NV8 NV8 X NV8 SYS SYS SYS SYS NODE NODE NODE PXB NODE 52-103 1 N/A
GPU7 NV8 NV8 NV8 NV8 NV8 NV8 NV8 X SYS SYS SYS SYS NODE NODE NODE NODE PXB 52-103 1 N/A
NIC0 PXB NODE NODE NODE SYS SYS SYS SYS X NODE NODE NODE SYS SYS SYS SYS SYS
NIC1 NODE PXB NODE NODE SYS SYS SYS SYS NODE X NODE NODE SYS SYS SYS SYS SYS
NIC2 NODE NODE PXB NODE SYS SYS SYS SYS NODE NODE X NODE SYS SYS SYS SYS SYS
NIC3 NODE NODE NODE PXB SYS SYS SYS SYS NODE NODE NODE X SYS SYS SYS SYS SYS
NIC4 SYS SYS SYS SYS PXB NODE NODE NODE SYS SYS SYS SYS X NODE NODE NODE NODE
NIC5 SYS SYS SYS SYS NODE NODE NODE NODE SYS SYS SYS SYS NODE X NODE NODE NODE
NIC6 SYS SYS SYS SYS NODE PXB NODE NODE SYS SYS SYS SYS NODE NODE X NODE NODE
NIC7 SYS SYS SYS SYS NODE NODE PXB NODE SYS SYS SYS SYS NODE NODE NODE X NODE
NIC8 SYS SYS SYS SYS NODE NODE NODE PXB SYS SYS SYS SYS NODE NODE NODE NODE X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3
NIC4: mlx5_4
NIC5: mlx5_5
NIC6: mlx5_6
NIC7: mlx5_7
NIC8: mlx5_8
ulimit soft: 1000000