
🐛 [Bug] FLUX Attention Bug #4194

@cehongwang

Bug Description

```
  File "/home/TensorRT/experiments/refit_flux_benchmark.py", line 112, in run_flux_benchmark
    trt_gm = torch_trt.dynamo.compile(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/_compiler.py", line 798, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/_compiler.py", line 1044, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 343, in convert_module
    serialized_interpreter_result = interpret_module_to_result(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 277, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 605, in run
    self._construct_trt_network_def()
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 412, in _construct_trt_network_def
    super().run()
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/interpreter.py", line 200, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 678, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
                              ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/interpreter.py", line 297, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 785, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 4044, in aten_ops_scaled_dot_product_flash_attention
    return impl.attention.scaled_dot_product_flash_attention(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/attention.py", line 285, in scaled_dot_product_flash_attention
    assert attention_layer is not None, "attention layer is None"
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: attention layer is None
```

To Reproduce

Run `flux_demo.py` in `examples/apps`.
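
A minimal standalone sketch of the failing op, in case it helps triage (assumption: `SDPABlock` and its shapes are illustrative, not FLUX's real attention dims; the commented-out compile call mirrors the one in `refit_flux_benchmark.py` and is hypothetical):

```python
import torch
import torch.nn.functional as F


class SDPABlock(torch.nn.Module):
    """Tiny module that lowers to aten.scaled_dot_product_attention,
    the op whose converter raises "attention layer is None"."""

    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


# Illustrative (batch, heads, seq, head_dim) shapes, not FLUX's actual ones.
q = k = v = torch.randn(1, 8, 64, 64)
out = SDPABlock()(q, k, v)
print(out.shape)  # torch.Size([1, 8, 64, 64])

# The assertion fires while converting this op to TensorRT, roughly:
#   import torch_tensorrt
#   ep = torch.export.export(SDPABlock(), (q, k, v))
#   trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=[q, k, v])
```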

Expected behavior

Compilation completes without the assertion error, producing a TensorRT graph module for the FLUX model.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context
