Added flux demo #3418
Conversation
Force-pushed from b2eb297 to 6d36077
Can the app display the inference time? It might be nice to have some stats rendered live as you generate.
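A minimal sketch of how the demo could time each generation and surface stats alongside the output (the helper name and the stats dict are hypothetical, not part of this PR):

```python
import time


def timed_generate(pipe, prompt, **kwargs):
    """Run one generation and return (result, stats).

    Hypothetical helper: `pipe` is any callable pipeline. Only wall-clock
    time is measured here; the stats dict could carry more fields (steps,
    images/sec) for live rendering in the UI.
    """
    start = time.perf_counter()
    result = pipe(prompt, **kwargs)
    elapsed = time.perf_counter() - start
    stats = {"inference_time_s": round(elapsed, 3)}
    return result, stats
```

The returned stats dict could then be bound to a live-updating UI element next to the generated image.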
Force-pushed from 48a7c94 to 5a528f1
Force-pushed from 361fb76 to 0aeea36
Force-pushed from 9964674 to cfbc9ea
Force-pushed from c9cca30 to 27dee53
Force-pushed from a791b42 to 3dcf128
Force-pushed from 41139e9 to f536ac6
py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
Outdated
tools/perf/Flux/register_sdpa.py
Outdated
    from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
        clean_up_graph_after_modifications,
    )
Use them from examples.
tools/perf/Flux/flux_quantization.py
Outdated
I think we should avoid copying the whole model scripts just to measure perf. Try the sys.path approach: import the model and add only a perf loop, something like:

    import os
    import sys

    sys.path.append(os.path.join(torchtrt_root, "examples/dynamo/apps"))
    from flux_demo import *

    model = <insert FLUX model (fp16 or fp8)>
    results = measure_flux_perf(....)
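A generic version of the perf loop the reviewer suggests might look like the following (the function name and iteration counts are illustrative assumptions, not code from this PR):

```python
import time


def measure_perf(fn, n_warmup=3, n_iters=10):
    """Warm up, then return average wall-clock seconds per call of fn.

    Hypothetical helper sketching the 'perf loop' idea: fn would wrap a
    single model inference; warmup calls are excluded from timing.
    """
    for _ in range(n_warmup):
        fn()
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    return (time.perf_counter() - start) / n_iters
```

Keeping the loop model-agnostic like this is what lets the perf script import the demo's model rather than duplicating the whole model setup.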
Force-pushed from a031a02 to 9acbee6
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-12 23:00:13.565607+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-12 23:00:38.626299+00:00
@@ -73,11 +73,10 @@
         max_bound = 127
     elif num_bits == 8 and exponent_bits == 4:
         dtype = trt.DataType.FP8
         max_bound = 448
-
     axis = None
     # int8 weight quantization is per-channel quantization(it can have one or multiple amax values)
     if dtype == trt.DataType.INT8 and amax.numel() > 1:
         # if the amax has more than one element, calculate the axis, otherwise axis value will be ignored
         amax_init_shape = amax.shape
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py 2025-06-12 23:00:13.567607+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py 2025-06-12 23:00:39.100847+00:00
@@ -98,16 +98,17 @@
 class _TorchTensorRTConstantFolder(ConstantFolder):  # type: ignore[misc]
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)

     def is_impure(self, node: torch.fx.node.Node) -> bool:
-        # Set of known quantization ops to be excluded from constant folding.
+        # Set of known quantization ops to be excluded from constant folding.
         # Currently, we exclude all quantization ops coming from modelopt library.
         quantization_ops = {}
         try:
-            # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
+            # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
             import modelopt.torch.quantization as mtq
+
             assert torch.ops.tensorrt.quantize_op.default
             quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
         except Exception as e:
             pass
         if quantization_ops and node.target in quantization_ops:
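For context on the constant_folding.py hunk above: the underlying pattern is to collect a set of op targets that must never be constant-folded (here, quantization ops) and test each node's target against it. A standalone sketch of that pattern (simplified; the real code operates on torch.fx nodes and modelopt-registered ops, and the class name here is invented):

```python
class SimpleImpurityChecker:
    """Simplified stand-in for the is_impure() check shown in the diff."""

    def __init__(self, impure_targets):
        # Targets (e.g. quantization ops) that must survive constant folding
        self.impure_targets = set(impure_targets)

    def is_impure(self, target):
        # A target in the set is treated as impure and is not folded away
        return target in self.impure_targets


checker = SimpleImpurityChecker(["tensorrt.quantize_op.default"])
```

Marking these ops impure keeps the folding pass from evaluating and baking in quantize calls that TensorRT must see at conversion time.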
    if args.debug:
        pipe.transformer = FluxTransformer2DModel(
            num_layers=1, num_single_layers=1, guidance_embeds=True
        ).to(torch.float16)
remove this
Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: